MOA 12.03
Real Time Analytics for Data Streams
00001 /* 00002 * HoeffdingAdaptiveTree.java 00003 * Copyright (C) 2008 University of Waikato, Hamilton, New Zealand 00004 * @author Albert Bifet (abifet at cs dot waikato dot ac dot nz) 00005 * 00006 * This program is free software; you can redistribute it and/or modify 00007 * it under the terms of the GNU General Public License as published by 00008 * the Free Software Foundation; either version 3 of the License, or 00009 * (at your option) any later version. 00010 * 00011 * This program is distributed in the hope that it will be useful, 00012 * but WITHOUT ANY WARRANTY; without even the implied warranty of 00013 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 00014 * GNU General Public License for more details. 00015 * 00016 * You should have received a copy of the GNU General Public License 00017 * along with this program. If not, see <http://www.gnu.org/licenses/>. 00018 * 00019 */ 00020 package moa.classifiers.trees; 00021 00022 import java.util.LinkedList; 00023 import java.util.List; 00024 import java.util.Random; 00025 import moa.classifiers.bayes.NaiveBayes; 00026 import moa.classifiers.core.conditionaltests.InstanceConditionalTest; 00027 import moa.classifiers.core.driftdetection.ADWIN; 00028 import moa.core.DoubleVector; 00029 import moa.core.MiscUtils; 00030 import moa.options.MultiChoiceOption; 00031 import weka.core.Instance; 00032 import weka.core.Utils; 00033 00053 public class HoeffdingAdaptiveTree extends HoeffdingTree { 00054 00055 private static final long serialVersionUID = 1L; 00056 00057 @Override 00058 public String getPurposeString() { 00059 return "Hoeffding Adaptive Tree for evolving data streams that uses ADWIN to replace branches for new ones."; 00060 } 00061 00062 /* public MultiChoiceOption leafpredictionOption = new MultiChoiceOption( 00063 "leafprediction", 'l', "Leaf prediction to use.", new String[]{ 00064 "MC", "NB", "NBAdaptive"}, new String[]{ 00065 "Majority class", 00066 "Naive Bayes", 00067 "Naive Bayes 
Adaptive"}, 2);*/ 00068 00069 public interface NewNode { 00070 00071 // Change for adwin 00072 //public boolean getErrorChange(); 00073 public int numberLeaves(); 00074 00075 public double getErrorEstimation(); 00076 00077 public double getErrorWidth(); 00078 00079 public boolean isNullError(); 00080 00081 public void killTreeChilds(HoeffdingAdaptiveTree ht); 00082 00083 public void learnFromInstance(Instance inst, HoeffdingAdaptiveTree ht, SplitNode parent, int parentBranch); 00084 00085 public void filterInstanceToLeaves(Instance inst, SplitNode myparent, int parentBranch, List<FoundNode> foundNodes, 00086 boolean updateSplitterCounts); 00087 } 00088 00089 public static class AdaSplitNode extends SplitNode implements NewNode { 00090 00091 private static final long serialVersionUID = 1L; 00092 00093 protected Node alternateTree; 00094 00095 protected ADWIN estimationErrorWeight; 00096 //public boolean isAlternateTree = false; 00097 00098 public boolean ErrorChange = false; 00099 00100 protected int randomSeed = 1; 00101 00102 protected Random classifierRandom; 00103 00104 //public boolean getErrorChange() { 00105 // return ErrorChange; 00106 //} 00107 @Override 00108 public int calcByteSizeIncludingSubtree() { 00109 int byteSize = calcByteSize(); 00110 if (alternateTree != null) { 00111 byteSize += alternateTree.calcByteSizeIncludingSubtree(); 00112 } 00113 if (estimationErrorWeight != null) { 00114 byteSize += estimationErrorWeight.measureByteSize(); 00115 } 00116 for (Node child : this.children) { 00117 if (child != null) { 00118 byteSize += child.calcByteSizeIncludingSubtree(); 00119 } 00120 } 00121 return byteSize; 00122 } 00123 00124 public AdaSplitNode(InstanceConditionalTest splitTest, 00125 double[] classObservations) { 00126 super(splitTest, classObservations); 00127 this.classifierRandom = new Random(this.randomSeed); 00128 } 00129 00130 @Override 00131 public int numberLeaves() { 00132 int numLeaves = 0; 00133 for (Node child : this.children) { 00134 if 
(child != null) { 00135 numLeaves += ((NewNode) child).numberLeaves(); 00136 } 00137 } 00138 return numLeaves + 1; 00139 } 00140 00141 @Override 00142 public double getErrorEstimation() { 00143 return this.estimationErrorWeight.getEstimation(); 00144 } 00145 00146 @Override 00147 public double getErrorWidth() { 00148 double w = 0.0; 00149 if (isNullError() == false) { 00150 w = this.estimationErrorWeight.getWidth(); 00151 } 00152 return w; 00153 } 00154 00155 @Override 00156 public boolean isNullError() { 00157 return (this.estimationErrorWeight == null); 00158 } 00159 00160 // SplitNodes can have alternative trees, but LearningNodes can't 00161 // LearningNodes can split, but SplitNodes can't 00162 // Parent nodes are allways SplitNodes 00163 @Override 00164 public void learnFromInstance(Instance inst, HoeffdingAdaptiveTree ht, SplitNode parent, int parentBranch) { 00165 int trueClass = (int) inst.classValue(); 00166 //New option vore 00167 int k = MiscUtils.poisson(1.0, this.classifierRandom); 00168 Instance weightedInst = (Instance) inst.copy(); 00169 if (k > 0) { 00170 //weightedInst.setWeight(inst.weight() * k); 00171 } 00172 //Compute ClassPrediction using filterInstanceToLeaf 00173 //int ClassPrediction = Utils.maxIndex(filterInstanceToLeaf(inst, null, -1).node.getClassVotes(inst, ht)); 00174 int ClassPrediction = 0; 00175 if (filterInstanceToLeaf(inst, parent, parentBranch).node != null) { 00176 ClassPrediction = Utils.maxIndex(filterInstanceToLeaf(inst, parent, parentBranch).node.getClassVotes(inst, ht)); 00177 } 00178 00179 boolean blCorrect = (trueClass == ClassPrediction); 00180 00181 if (this.estimationErrorWeight == null) { 00182 this.estimationErrorWeight = new ADWIN(); 00183 } 00184 double oldError = this.getErrorEstimation(); 00185 this.ErrorChange = this.estimationErrorWeight.setInput(blCorrect == true ? 
0.0 : 1.0); 00186 if (this.ErrorChange == true && oldError > this.getErrorEstimation()) { 00187 //if error is decreasing, don't do anything 00188 this.ErrorChange = false; 00189 } 00190 00191 // Check condition to build a new alternate tree 00192 //if (this.isAlternateTree == false) { 00193 if (this.ErrorChange == true) {//&& this.alternateTree == null) { 00194 //Start a new alternative tree : learning node 00195 this.alternateTree = ht.newLearningNode(); 00196 //this.alternateTree.isAlternateTree = true; 00197 ht.alternateTrees++; 00198 } // Check condition to replace tree 00199 else if (this.alternateTree != null && ((NewNode) this.alternateTree).isNullError() == false) { 00200 if (this.getErrorWidth() > 300 && ((NewNode) this.alternateTree).getErrorWidth() > 300) { 00201 double oldErrorRate = this.getErrorEstimation(); 00202 double altErrorRate = ((NewNode) this.alternateTree).getErrorEstimation(); 00203 double fDelta = .05; 00204 //if (gNumAlts>0) fDelta=fDelta/gNumAlts; 00205 double fN = 1.0 / ((double) ((NewNode) this.alternateTree).getErrorWidth()) + 1.0 / ((double) this.getErrorWidth()); 00206 double Bound = (double) Math.sqrt((double) 2.0 * oldErrorRate * (1.0 - oldErrorRate) * Math.log(2.0 / fDelta) * fN); 00207 if (Bound < oldErrorRate - altErrorRate) { 00208 // Switch alternate tree 00209 ht.activeLeafNodeCount -= this.numberLeaves(); 00210 ht.activeLeafNodeCount += ((NewNode) this.alternateTree).numberLeaves(); 00211 killTreeChilds(ht); 00212 if (parent != null) { 00213 parent.setChild(parentBranch, this.alternateTree); 00214 //((AdaSplitNode) parent.getChild(parentBranch)).alternateTree = null; 00215 } else { 00216 // Switch root tree 00217 ht.treeRoot = ((AdaSplitNode) ht.treeRoot).alternateTree; 00218 } 00219 ht.switchedAlternateTrees++; 00220 } else if (Bound < altErrorRate - oldErrorRate) { 00221 // Erase alternate tree 00222 if (this.alternateTree instanceof ActiveLearningNode) { 00223 this.alternateTree = null; 00224 ht.activeLeafNodeCount--; 
00225 } else if (this.alternateTree instanceof InactiveLearningNode) { 00226 this.alternateTree = null; 00227 ht.inactiveLeafNodeCount--; 00228 } else { 00229 ((AdaSplitNode) this.alternateTree).killTreeChilds(ht); 00230 } 00231 ht.prunedAlternateTrees++; 00232 } 00233 } 00234 } 00235 //} 00236 //learnFromInstance alternate Tree and Child nodes 00237 if (this.alternateTree != null) { 00238 ((NewNode) this.alternateTree).learnFromInstance(weightedInst, ht, parent, parentBranch); 00239 } 00240 int childBranch = this.instanceChildIndex(inst); 00241 Node child = this.getChild(childBranch); 00242 if (child != null) { 00243 ((NewNode) child).learnFromInstance(weightedInst, ht, this, childBranch); 00244 } 00245 } 00246 00247 @Override 00248 public void killTreeChilds(HoeffdingAdaptiveTree ht) { 00249 for (Node child : this.children) { 00250 if (child != null) { 00251 //Delete alternate tree if it exists 00252 if (child instanceof AdaSplitNode && ((AdaSplitNode) child).alternateTree != null) { 00253 ((NewNode) ((AdaSplitNode) child).alternateTree).killTreeChilds(ht); 00254 ht.prunedAlternateTrees++; 00255 } 00256 //Recursive delete of SplitNodes 00257 if (child instanceof AdaSplitNode) { 00258 ((NewNode) child).killTreeChilds(ht); 00259 } 00260 if (child instanceof ActiveLearningNode) { 00261 child = null; 00262 ht.activeLeafNodeCount--; 00263 } else if (child instanceof InactiveLearningNode) { 00264 child = null; 00265 ht.inactiveLeafNodeCount--; 00266 } 00267 } 00268 } 00269 } 00270 00271 //New for option votes 00272 //@Override 00273 public void filterInstanceToLeaves(Instance inst, SplitNode myparent, 00274 int parentBranch, List<FoundNode> foundNodes, 00275 boolean updateSplitterCounts) { 00276 if (updateSplitterCounts) { 00277 this.observedClassDistribution.addToValue((int) inst.classValue(), inst.weight()); 00278 } 00279 int childIndex = instanceChildIndex(inst); 00280 if (childIndex >= 0) { 00281 Node child = getChild(childIndex); 00282 if (child != null) { 00283 
((NewNode) child).filterInstanceToLeaves(inst, this, childIndex, 00284 foundNodes, updateSplitterCounts); 00285 } else { 00286 foundNodes.add(new FoundNode(null, this, childIndex)); 00287 } 00288 } 00289 if (this.alternateTree != null) { 00290 ((NewNode) this.alternateTree).filterInstanceToLeaves(inst, this, -999, 00291 foundNodes, updateSplitterCounts); 00292 } 00293 } 00294 } 00295 00296 public static class AdaLearningNode extends LearningNodeNBAdaptive implements NewNode { 00297 00298 private static final long serialVersionUID = 1L; 00299 00300 protected ADWIN estimationErrorWeight; 00301 00302 public boolean ErrorChange = false; 00303 00304 protected int randomSeed = 1; 00305 00306 protected Random classifierRandom; 00307 00308 @Override 00309 public int calcByteSize() { 00310 int byteSize = super.calcByteSize(); 00311 if (estimationErrorWeight != null) { 00312 byteSize += estimationErrorWeight.measureByteSize(); 00313 } 00314 return byteSize; 00315 } 00316 00317 public AdaLearningNode(double[] initialClassObservations) { 00318 super(initialClassObservations); 00319 this.classifierRandom = new Random(this.randomSeed); 00320 } 00321 00322 @Override 00323 public int numberLeaves() { 00324 return 1; 00325 } 00326 00327 @Override 00328 public double getErrorEstimation() { 00329 if (this.estimationErrorWeight != null) { 00330 return this.estimationErrorWeight.getEstimation(); 00331 } else { 00332 return 0; 00333 } 00334 } 00335 00336 @Override 00337 public double getErrorWidth() { 00338 return this.estimationErrorWeight.getWidth(); 00339 } 00340 00341 @Override 00342 public boolean isNullError() { 00343 return (this.estimationErrorWeight == null); 00344 } 00345 00346 @Override 00347 public void killTreeChilds(HoeffdingAdaptiveTree ht) { 00348 } 00349 00350 @Override 00351 public void learnFromInstance(Instance inst, HoeffdingAdaptiveTree ht, SplitNode parent, int parentBranch) { 00352 int trueClass = (int) inst.classValue(); 00353 //New option vore 00354 int k = 
MiscUtils.poisson(1.0, this.classifierRandom); 00355 Instance weightedInst = (Instance) inst.copy(); 00356 if (k > 0) { 00357 weightedInst.setWeight(inst.weight() * k); 00358 } 00359 //Compute ClassPrediction using filterInstanceToLeaf 00360 int ClassPrediction = Utils.maxIndex(this.getClassVotes(inst, ht)); 00361 00362 boolean blCorrect = (trueClass == ClassPrediction); 00363 00364 if (this.estimationErrorWeight == null) { 00365 this.estimationErrorWeight = new ADWIN(); 00366 } 00367 double oldError = this.getErrorEstimation(); 00368 this.ErrorChange = this.estimationErrorWeight.setInput(blCorrect == true ? 0.0 : 1.0); 00369 if (this.ErrorChange == true && oldError > this.getErrorEstimation()) { 00370 this.ErrorChange = false; 00371 } 00372 00373 //Update statistics 00374 learnFromInstance(weightedInst, ht); //inst 00375 00376 //Check for Split condition 00377 double weightSeen = this.getWeightSeen(); 00378 if (weightSeen 00379 - this.getWeightSeenAtLastSplitEvaluation() >= ht.gracePeriodOption.getValue()) { 00380 ht.attemptToSplit(this, parent, 00381 parentBranch); 00382 this.setWeightSeenAtLastSplitEvaluation(weightSeen); 00383 } 00384 00385 00386 //learnFromInstance alternate Tree and Child nodes 00387 /*if (this.alternateTree != null) { 00388 this.alternateTree.learnFromInstance(inst,ht); 00389 } 00390 for (Node child : this.children) { 00391 if (child != null) { 00392 child.learnFromInstance(inst,ht); 00393 } 00394 }*/ 00395 } 00396 00397 @Override 00398 public double[] getClassVotes(Instance inst, HoeffdingTree ht) { 00399 double[] dist; 00400 int predictionOption = ((HoeffdingAdaptiveTree) ht).leafpredictionOption.getChosenIndex(); 00401 if (predictionOption == 0) { //MC 00402 dist = this.observedClassDistribution.getArrayCopy(); 00403 } else if (predictionOption == 1) { //NB 00404 dist = NaiveBayes.doNaiveBayesPrediction(inst, 00405 this.observedClassDistribution, this.attributeObservers); 00406 } else { //NBAdaptive 00407 if (this.mcCorrectWeight > 
this.nbCorrectWeight) { 00408 dist = this.observedClassDistribution.getArrayCopy(); 00409 } else { 00410 dist = NaiveBayes.doNaiveBayesPrediction(inst, 00411 this.observedClassDistribution, this.attributeObservers); 00412 } 00413 } 00414 //New for option votes 00415 double distSum = Utils.sum(dist); 00416 if (distSum * this.getErrorEstimation() * this.getErrorEstimation() > 0.0) { 00417 Utils.normalize(dist, distSum * this.getErrorEstimation() * this.getErrorEstimation()); //Adding weight 00418 } 00419 return dist; 00420 } 00421 00422 //New for option votes 00423 @Override 00424 public void filterInstanceToLeaves(Instance inst, 00425 SplitNode splitparent, int parentBranch, 00426 List<FoundNode> foundNodes, boolean updateSplitterCounts) { 00427 foundNodes.add(new FoundNode(this, splitparent, parentBranch)); 00428 } 00429 } 00430 00431 protected int activeLeafNodeCount; 00432 00433 protected int inactiveLeafNodeCount; 00434 00435 protected int alternateTrees; 00436 00437 protected int prunedAlternateTrees; 00438 00439 protected int switchedAlternateTrees; 00440 00441 @Override 00442 protected LearningNode newLearningNode(double[] initialClassObservations) { 00443 // IDEA: to choose different learning nodes depending on predictionOption 00444 return new AdaLearningNode(initialClassObservations); 00445 } 00446 00447 //@Override 00448 @Override 00449 protected SplitNode newSplitNode(InstanceConditionalTest splitTest, 00450 double[] classObservations) { 00451 return new AdaSplitNode(splitTest, classObservations); 00452 } 00453 00454 @Override 00455 public void trainOnInstanceImpl(Instance inst) { 00456 if (this.treeRoot == null) { 00457 this.treeRoot = newLearningNode(); 00458 this.activeLeafNodeCount = 1; 00459 } 00460 ((NewNode) this.treeRoot).learnFromInstance(inst, this, null, -1); 00461 } 00462 00463 //New for options vote 00464 public FoundNode[] filterInstanceToLeaves(Instance inst, 00465 SplitNode parent, int parentBranch, boolean updateSplitterCounts) { 00466 
List<FoundNode> nodes = new LinkedList<FoundNode>(); 00467 ((NewNode) this.treeRoot).filterInstanceToLeaves(inst, parent, parentBranch, nodes, 00468 updateSplitterCounts); 00469 return nodes.toArray(new FoundNode[nodes.size()]); 00470 } 00471 00472 @Override 00473 public double[] getVotesForInstance(Instance inst) { 00474 if (this.treeRoot != null) { 00475 FoundNode[] foundNodes = filterInstanceToLeaves(inst, 00476 null, -1, false); 00477 DoubleVector result = new DoubleVector(); 00478 int predictionPaths = 0; 00479 for (FoundNode foundNode : foundNodes) { 00480 if (foundNode.parentBranch != -999) { 00481 Node leafNode = foundNode.node; 00482 if (leafNode == null) { 00483 leafNode = foundNode.parent; 00484 } 00485 double[] dist = leafNode.getClassVotes(inst, this); 00486 //Albert: changed for weights 00487 //double distSum = Utils.sum(dist); 00488 //if (distSum > 0.0) { 00489 // Utils.normalize(dist, distSum); 00490 //} 00491 result.addValues(dist); 00492 //predictionPaths++; 00493 } 00494 } 00495 //if (predictionPaths > this.maxPredictionPaths) { 00496 // this.maxPredictionPaths++; 00497 //} 00498 return result.getArrayRef(); 00499 } 00500 return new double[0]; 00501 } 00502 }