MOA 12.03
Real Time Analytics for Data Streams
HoeffdingAdaptiveTree.java
Go to the documentation of this file.
00001 /*
00002  *    HoeffdingAdaptiveTree.java
00003  *    Copyright (C) 2008 University of Waikato, Hamilton, New Zealand
00004  *    @author Albert Bifet (abifet at cs dot waikato dot ac dot nz)
00005  *
00006  *    This program is free software; you can redistribute it and/or modify
00007  *    it under the terms of the GNU General Public License as published by
00008  *    the Free Software Foundation; either version 3 of the License, or
00009  *    (at your option) any later version.
00010  *
00011  *    This program is distributed in the hope that it will be useful,
00012  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
00013  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00014  *    GNU General Public License for more details.
00015  *
00016  *    You should have received a copy of the GNU General Public License
00017  *    along with this program. If not, see <http://www.gnu.org/licenses/>.
00018  *    
00019  */
00020 package moa.classifiers.trees;
00021 
00022 import java.util.LinkedList;
00023 import java.util.List;
00024 import java.util.Random;
00025 import moa.classifiers.bayes.NaiveBayes;
00026 import moa.classifiers.core.conditionaltests.InstanceConditionalTest;
00027 import moa.classifiers.core.driftdetection.ADWIN;
00028 import moa.core.DoubleVector;
00029 import moa.core.MiscUtils;
00030 import moa.options.MultiChoiceOption;
00031 import weka.core.Instance;
00032 import weka.core.Utils;
00033 
00053 public class HoeffdingAdaptiveTree extends HoeffdingTree {
00054 
00055     private static final long serialVersionUID = 1L;
00056 
00057     @Override
00058     public String getPurposeString() {
00059         return "Hoeffding Adaptive Tree for evolving data streams that uses ADWIN to replace branches for new ones.";
00060     }
00061     
00062  /*   public MultiChoiceOption leafpredictionOption = new MultiChoiceOption(
00063             "leafprediction", 'l', "Leaf prediction to use.", new String[]{
00064                 "MC", "NB", "NBAdaptive"}, new String[]{
00065                 "Majority class",
00066                 "Naive Bayes",
00067                 "Naive Bayes Adaptive"}, 2);*/
00068 
00069     public interface NewNode {
00070 
00071         // Change for adwin
00072         //public boolean getErrorChange();
00073         public int numberLeaves();
00074 
00075         public double getErrorEstimation();
00076 
00077         public double getErrorWidth();
00078 
00079         public boolean isNullError();
00080 
00081         public void killTreeChilds(HoeffdingAdaptiveTree ht);
00082 
00083         public void learnFromInstance(Instance inst, HoeffdingAdaptiveTree ht, SplitNode parent, int parentBranch);
00084 
00085         public void filterInstanceToLeaves(Instance inst, SplitNode myparent, int parentBranch, List<FoundNode> foundNodes,
00086                 boolean updateSplitterCounts);
00087     }
00088 
00089     public static class AdaSplitNode extends SplitNode implements NewNode {
00090 
00091         private static final long serialVersionUID = 1L;
00092 
00093         protected Node alternateTree;
00094 
00095         protected ADWIN estimationErrorWeight;
00096         //public boolean isAlternateTree = false;
00097 
00098         public boolean ErrorChange = false;
00099 
00100         protected int randomSeed = 1;
00101 
00102         protected Random classifierRandom;
00103 
00104         //public boolean getErrorChange() {
00105         //              return ErrorChange;
00106         //}
00107         @Override
00108         public int calcByteSizeIncludingSubtree() {
00109             int byteSize = calcByteSize();
00110             if (alternateTree != null) {
00111                 byteSize += alternateTree.calcByteSizeIncludingSubtree();
00112             }
00113             if (estimationErrorWeight != null) {
00114                 byteSize += estimationErrorWeight.measureByteSize();
00115             }
00116             for (Node child : this.children) {
00117                 if (child != null) {
00118                     byteSize += child.calcByteSizeIncludingSubtree();
00119                 }
00120             }
00121             return byteSize;
00122         }
00123 
00124         public AdaSplitNode(InstanceConditionalTest splitTest,
00125                 double[] classObservations) {
00126             super(splitTest, classObservations);
00127             this.classifierRandom = new Random(this.randomSeed);
00128         }
00129 
00130         @Override
00131         public int numberLeaves() {
00132             int numLeaves = 0;
00133             for (Node child : this.children) {
00134                 if (child != null) {
00135                     numLeaves += ((NewNode) child).numberLeaves();
00136                 }
00137             }
00138             return numLeaves + 1;
00139         }
00140 
00141         @Override
00142         public double getErrorEstimation() {
00143             return this.estimationErrorWeight.getEstimation();
00144         }
00145 
00146         @Override
00147         public double getErrorWidth() {
00148             double w = 0.0;
00149             if (isNullError() == false) {
00150                 w = this.estimationErrorWeight.getWidth();
00151             }
00152             return w;
00153         }
00154 
00155         @Override
00156         public boolean isNullError() {
00157             return (this.estimationErrorWeight == null);
00158         }
00159 
00160         // SplitNodes can have alternative trees, but LearningNodes can't
00161         // LearningNodes can split, but SplitNodes can't
00162         // Parent nodes are allways SplitNodes
00163         @Override
00164         public void learnFromInstance(Instance inst, HoeffdingAdaptiveTree ht, SplitNode parent, int parentBranch) {
00165             int trueClass = (int) inst.classValue();
00166             //New option vore
00167             int k = MiscUtils.poisson(1.0, this.classifierRandom);
00168             Instance weightedInst = (Instance) inst.copy();
00169             if (k > 0) {
00170                 //weightedInst.setWeight(inst.weight() * k);
00171             }
00172             //Compute ClassPrediction using filterInstanceToLeaf
00173             //int ClassPrediction = Utils.maxIndex(filterInstanceToLeaf(inst, null, -1).node.getClassVotes(inst, ht));
00174             int ClassPrediction = 0;
00175             if (filterInstanceToLeaf(inst, parent, parentBranch).node != null) {
00176                 ClassPrediction = Utils.maxIndex(filterInstanceToLeaf(inst, parent, parentBranch).node.getClassVotes(inst, ht));
00177             }
00178 
00179             boolean blCorrect = (trueClass == ClassPrediction);
00180 
00181             if (this.estimationErrorWeight == null) {
00182                 this.estimationErrorWeight = new ADWIN();
00183             }
00184             double oldError = this.getErrorEstimation();
00185             this.ErrorChange = this.estimationErrorWeight.setInput(blCorrect == true ? 0.0 : 1.0);
00186             if (this.ErrorChange == true && oldError > this.getErrorEstimation()) {
00187                 //if error is decreasing, don't do anything
00188                 this.ErrorChange = false;
00189             }
00190 
00191             // Check condition to build a new alternate tree
00192             //if (this.isAlternateTree == false) {
00193             if (this.ErrorChange == true) {//&& this.alternateTree == null) {
00194                 //Start a new alternative tree : learning node
00195                 this.alternateTree = ht.newLearningNode();
00196                 //this.alternateTree.isAlternateTree = true;
00197                 ht.alternateTrees++;
00198             } // Check condition to replace tree
00199             else if (this.alternateTree != null && ((NewNode) this.alternateTree).isNullError() == false) {
00200                 if (this.getErrorWidth() > 300 && ((NewNode) this.alternateTree).getErrorWidth() > 300) {
00201                     double oldErrorRate = this.getErrorEstimation();
00202                     double altErrorRate = ((NewNode) this.alternateTree).getErrorEstimation();
00203                     double fDelta = .05;
00204                     //if (gNumAlts>0) fDelta=fDelta/gNumAlts;
00205                     double fN = 1.0 / ((double) ((NewNode) this.alternateTree).getErrorWidth()) + 1.0 / ((double) this.getErrorWidth());
00206                     double Bound = (double) Math.sqrt((double) 2.0 * oldErrorRate * (1.0 - oldErrorRate) * Math.log(2.0 / fDelta) * fN);
00207                     if (Bound < oldErrorRate - altErrorRate) {
00208                         // Switch alternate tree
00209                         ht.activeLeafNodeCount -= this.numberLeaves();
00210                         ht.activeLeafNodeCount += ((NewNode) this.alternateTree).numberLeaves();
00211                         killTreeChilds(ht);
00212                         if (parent != null) {
00213                             parent.setChild(parentBranch, this.alternateTree);
00214                             //((AdaSplitNode) parent.getChild(parentBranch)).alternateTree = null;
00215                         } else {
00216                             // Switch root tree
00217                             ht.treeRoot = ((AdaSplitNode) ht.treeRoot).alternateTree;
00218                         }
00219                         ht.switchedAlternateTrees++;
00220                     } else if (Bound < altErrorRate - oldErrorRate) {
00221                         // Erase alternate tree
00222                         if (this.alternateTree instanceof ActiveLearningNode) {
00223                             this.alternateTree = null;
00224                             ht.activeLeafNodeCount--;
00225                         } else if (this.alternateTree instanceof InactiveLearningNode) {
00226                             this.alternateTree = null;
00227                             ht.inactiveLeafNodeCount--;
00228                         } else {
00229                             ((AdaSplitNode) this.alternateTree).killTreeChilds(ht);
00230                         }
00231                         ht.prunedAlternateTrees++;
00232                     }
00233                 }
00234             }
00235             //}
00236             //learnFromInstance alternate Tree and Child nodes
00237             if (this.alternateTree != null) {
00238                 ((NewNode) this.alternateTree).learnFromInstance(weightedInst, ht, parent, parentBranch);
00239             }
00240             int childBranch = this.instanceChildIndex(inst);
00241             Node child = this.getChild(childBranch);
00242             if (child != null) {
00243                 ((NewNode) child).learnFromInstance(weightedInst, ht, this, childBranch);
00244             }
00245         }
00246 
00247         @Override
00248         public void killTreeChilds(HoeffdingAdaptiveTree ht) {
00249             for (Node child : this.children) {
00250                 if (child != null) {
00251                     //Delete alternate tree if it exists
00252                     if (child instanceof AdaSplitNode && ((AdaSplitNode) child).alternateTree != null) {
00253                         ((NewNode) ((AdaSplitNode) child).alternateTree).killTreeChilds(ht);
00254                         ht.prunedAlternateTrees++;
00255                     }
00256                     //Recursive delete of SplitNodes
00257                     if (child instanceof AdaSplitNode) {
00258                         ((NewNode) child).killTreeChilds(ht);
00259                     }
00260                     if (child instanceof ActiveLearningNode) {
00261                         child = null;
00262                         ht.activeLeafNodeCount--;
00263                     } else if (child instanceof InactiveLearningNode) {
00264                         child = null;
00265                         ht.inactiveLeafNodeCount--;
00266                     }
00267                 }
00268             }
00269         }
00270 
00271         //New for option votes
00272         //@Override
00273         public void filterInstanceToLeaves(Instance inst, SplitNode myparent,
00274                 int parentBranch, List<FoundNode> foundNodes,
00275                 boolean updateSplitterCounts) {
00276             if (updateSplitterCounts) {
00277                 this.observedClassDistribution.addToValue((int) inst.classValue(), inst.weight());
00278             }
00279             int childIndex = instanceChildIndex(inst);
00280             if (childIndex >= 0) {
00281                 Node child = getChild(childIndex);
00282                 if (child != null) {
00283                     ((NewNode) child).filterInstanceToLeaves(inst, this, childIndex,
00284                             foundNodes, updateSplitterCounts);
00285                 } else {
00286                     foundNodes.add(new FoundNode(null, this, childIndex));
00287                 }
00288             }
00289             if (this.alternateTree != null) {
00290                 ((NewNode) this.alternateTree).filterInstanceToLeaves(inst, this, -999,
00291                         foundNodes, updateSplitterCounts);
00292             }
00293         }
00294     }
00295 
00296     public static class AdaLearningNode extends LearningNodeNBAdaptive implements NewNode {
00297 
00298         private static final long serialVersionUID = 1L;
00299 
00300         protected ADWIN estimationErrorWeight;
00301 
00302         public boolean ErrorChange = false;
00303 
00304         protected int randomSeed = 1;
00305 
00306         protected Random classifierRandom;
00307 
00308         @Override
00309         public int calcByteSize() {
00310             int byteSize = super.calcByteSize();
00311             if (estimationErrorWeight != null) {
00312                 byteSize += estimationErrorWeight.measureByteSize();
00313             }
00314             return byteSize;
00315         }
00316 
00317         public AdaLearningNode(double[] initialClassObservations) {
00318             super(initialClassObservations);
00319             this.classifierRandom = new Random(this.randomSeed);
00320         }
00321 
00322         @Override
00323         public int numberLeaves() {
00324             return 1;
00325         }
00326 
00327         @Override
00328         public double getErrorEstimation() {
00329             if (this.estimationErrorWeight != null) {
00330                 return this.estimationErrorWeight.getEstimation();
00331             } else {
00332                 return 0;
00333             }
00334         }
00335 
00336         @Override
00337         public double getErrorWidth() {
00338             return this.estimationErrorWeight.getWidth();
00339         }
00340 
00341         @Override
00342         public boolean isNullError() {
00343             return (this.estimationErrorWeight == null);
00344         }
00345 
00346         @Override
00347         public void killTreeChilds(HoeffdingAdaptiveTree ht) {
00348         }
00349 
00350         @Override
00351         public void learnFromInstance(Instance inst, HoeffdingAdaptiveTree ht, SplitNode parent, int parentBranch) {
00352             int trueClass = (int) inst.classValue();
00353             //New option vore
00354             int k = MiscUtils.poisson(1.0, this.classifierRandom);
00355             Instance weightedInst = (Instance) inst.copy();
00356             if (k > 0) {
00357                 weightedInst.setWeight(inst.weight() * k);
00358             }
00359             //Compute ClassPrediction using filterInstanceToLeaf
00360             int ClassPrediction = Utils.maxIndex(this.getClassVotes(inst, ht));
00361 
00362             boolean blCorrect = (trueClass == ClassPrediction);
00363 
00364             if (this.estimationErrorWeight == null) {
00365                 this.estimationErrorWeight = new ADWIN();
00366             }
00367             double oldError = this.getErrorEstimation();
00368             this.ErrorChange = this.estimationErrorWeight.setInput(blCorrect == true ? 0.0 : 1.0);
00369             if (this.ErrorChange == true && oldError > this.getErrorEstimation()) {
00370                 this.ErrorChange = false;
00371             }
00372 
00373             //Update statistics
00374             learnFromInstance(weightedInst, ht);        //inst
00375 
00376             //Check for Split condition
00377             double weightSeen = this.getWeightSeen();
00378             if (weightSeen
00379                     - this.getWeightSeenAtLastSplitEvaluation() >= ht.gracePeriodOption.getValue()) {
00380                 ht.attemptToSplit(this, parent,
00381                         parentBranch);
00382                 this.setWeightSeenAtLastSplitEvaluation(weightSeen);
00383             }
00384 
00385 
00386             //learnFromInstance alternate Tree and Child nodes
00387                         /*if (this.alternateTree != null)  {
00388             this.alternateTree.learnFromInstance(inst,ht);
00389             }
00390             for (Node child : this.children) {
00391             if (child != null) {
00392             child.learnFromInstance(inst,ht);
00393             }
00394             }*/
00395         }
00396 
00397         @Override
00398         public double[] getClassVotes(Instance inst, HoeffdingTree ht) {
00399             double[] dist;
00400             int predictionOption = ((HoeffdingAdaptiveTree) ht).leafpredictionOption.getChosenIndex();
00401             if (predictionOption == 0) { //MC
00402                 dist = this.observedClassDistribution.getArrayCopy();
00403             } else if (predictionOption == 1) { //NB
00404                 dist = NaiveBayes.doNaiveBayesPrediction(inst,
00405                         this.observedClassDistribution, this.attributeObservers);
00406             } else { //NBAdaptive
00407                 if (this.mcCorrectWeight > this.nbCorrectWeight) {
00408                     dist = this.observedClassDistribution.getArrayCopy();
00409                 } else {
00410                     dist = NaiveBayes.doNaiveBayesPrediction(inst,
00411                             this.observedClassDistribution, this.attributeObservers);
00412                 }
00413             }
00414             //New for option votes
00415             double distSum = Utils.sum(dist);
00416             if (distSum * this.getErrorEstimation() * this.getErrorEstimation() > 0.0) {
00417                 Utils.normalize(dist, distSum * this.getErrorEstimation() * this.getErrorEstimation()); //Adding weight
00418             }
00419             return dist;
00420         }
00421 
00422         //New for option votes
00423         @Override
00424         public void filterInstanceToLeaves(Instance inst,
00425                 SplitNode splitparent, int parentBranch,
00426                 List<FoundNode> foundNodes, boolean updateSplitterCounts) {
00427             foundNodes.add(new FoundNode(this, splitparent, parentBranch));
00428         }
00429     }
00430 
    // Bookkeeping counters maintained by the nodes of this tree.
    // NOTE(review): the two leaf counters may shadow same-named fields of the
    // base class - confirm against HoeffdingTree.

    /** Active learning nodes; set to 1 when the root is created, adjusted on switch/prune. */
    protected int activeLeafNodeCount;

    /** Inactive learning nodes; decremented when such a node is disposed of. */
    protected int inactiveLeafNodeCount;

    /** Number of alternate subtrees started so far. */
    protected int alternateTrees;

    /** Number of alternate subtrees discarded so far. */
    protected int prunedAlternateTrees;

    /** Number of times an alternate subtree replaced the node it was grown for. */
    protected int switchedAlternateTrees;
00440 
00441     @Override
00442     protected LearningNode newLearningNode(double[] initialClassObservations) {
00443         // IDEA: to choose different learning nodes depending on predictionOption
00444         return new AdaLearningNode(initialClassObservations);
00445     }
00446 
00447     //@Override
00448     @Override
00449     protected SplitNode newSplitNode(InstanceConditionalTest splitTest,
00450             double[] classObservations) {
00451         return new AdaSplitNode(splitTest, classObservations);
00452     }
00453 
00454     @Override
00455     public void trainOnInstanceImpl(Instance inst) {
00456         if (this.treeRoot == null) {
00457             this.treeRoot = newLearningNode();
00458             this.activeLeafNodeCount = 1;
00459         }
00460         ((NewNode) this.treeRoot).learnFromInstance(inst, this, null, -1);
00461     }
00462 
00463     //New for options vote
00464     public FoundNode[] filterInstanceToLeaves(Instance inst,
00465             SplitNode parent, int parentBranch, boolean updateSplitterCounts) {
00466         List<FoundNode> nodes = new LinkedList<FoundNode>();
00467         ((NewNode) this.treeRoot).filterInstanceToLeaves(inst, parent, parentBranch, nodes,
00468                 updateSplitterCounts);
00469         return nodes.toArray(new FoundNode[nodes.size()]);
00470     }
00471 
00472     @Override
00473     public double[] getVotesForInstance(Instance inst) {
00474         if (this.treeRoot != null) {
00475             FoundNode[] foundNodes = filterInstanceToLeaves(inst,
00476                     null, -1, false);
00477             DoubleVector result = new DoubleVector();
00478             int predictionPaths = 0;
00479             for (FoundNode foundNode : foundNodes) {
00480                 if (foundNode.parentBranch != -999) {
00481                     Node leafNode = foundNode.node;
00482                     if (leafNode == null) {
00483                         leafNode = foundNode.parent;
00484                     }
00485                     double[] dist = leafNode.getClassVotes(inst, this);
00486                     //Albert: changed for weights
00487                     //double distSum = Utils.sum(dist);
00488                     //if (distSum > 0.0) {
00489                     //  Utils.normalize(dist, distSum);
00490                     //}
00491                     result.addValues(dist);
00492                     //predictionPaths++;
00493                 }
00494             }
00495             //if (predictionPaths > this.maxPredictionPaths) {
00496             //  this.maxPredictionPaths++;
00497             //}
00498             return result.getArrayRef();
00499         }
00500         return new double[0];
00501     }
00502 }
 All Classes Namespaces Files Functions Variables Enumerations