MOA 12.03
Real Time Analytics for Data Streams
HoeffdingTree.java
Go to the documentation of this file.
00001 /*
00002  *    HoeffdingTree.java
00003  *    Copyright (C) 2007 University of Waikato, Hamilton, New Zealand
00004  *    @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
00005  *
00006  *    This program is free software; you can redistribute it and/or modify
00007  *    it under the terms of the GNU General Public License as published by
00008  *    the Free Software Foundation; either version 3 of the License, or
00009  *    (at your option) any later version.
00010  *
00011  *    This program is distributed in the hope that it will be useful,
00012  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
00013  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00014  *    GNU General Public License for more details.
00015  *
00016  *    You should have received a copy of the GNU General Public License
00017  *    along with this program. If not, see <http://www.gnu.org/licenses/>.
00018  *    
00019  */
00020 package moa.classifiers.trees;
00021 
00022 import java.util.Arrays;
00023 import java.util.Comparator;
00024 import java.util.HashSet;
00025 import java.util.LinkedList;
00026 import java.util.List;
00027 import java.util.Set;
00028 
00029 import moa.AbstractMOAObject;
00030 import moa.classifiers.AbstractClassifier;
00031 import moa.classifiers.bayes.NaiveBayes;
00032 import moa.classifiers.core.attributeclassobservers.AttributeClassObserver;
00033 import moa.classifiers.core.AttributeSplitSuggestion;
00034 import moa.classifiers.core.attributeclassobservers.DiscreteAttributeClassObserver;
00035 import moa.classifiers.core.conditionaltests.InstanceConditionalTest;
00036 import moa.classifiers.core.attributeclassobservers.NullAttributeClassObserver;
00037 import moa.classifiers.core.attributeclassobservers.NumericAttributeClassObserver;
00038 import moa.classifiers.core.splitcriteria.SplitCriterion;
00039 import moa.core.AutoExpandVector;
00040 import moa.core.DoubleVector;
00041 import moa.core.Measurement;
00042 import moa.core.StringUtils;
00043 import moa.options.ClassOption;
00044 import moa.options.FlagOption;
00045 import moa.options.FloatOption;
00046 import moa.options.IntOption;
00047 import moa.core.SizeOf;
00048 import moa.options.*;
00049 import weka.core.Instance;
00050 import weka.core.Utils;
00051 
00096 public class HoeffdingTree extends AbstractClassifier {
00097 
00098     private static final long serialVersionUID = 1L;
00099 
00100     @Override
00101     public String getPurposeString() {
00102         return "Hoeffding Tree or VFDT.";
00103     }
00104 
00105     public IntOption maxByteSizeOption = new IntOption("maxByteSize", 'm',
00106             "Maximum memory consumed by the tree.", 33554432, 0,
00107             Integer.MAX_VALUE);
00108 
00109     /*
00110      * public MultiChoiceOption numericEstimatorOption = new MultiChoiceOption(
00111      * "numericEstimator", 'n', "Numeric estimator to use.", new String[]{
00112      * "GAUSS10", "GAUSS100", "GK10", "GK100", "GK1000", "VFML10", "VFML100",
00113      * "VFML1000", "BINTREE"}, new String[]{ "Gaussian approximation evaluating
00114      * 10 splitpoints", "Gaussian approximation evaluating 100 splitpoints",
00115      * "Greenwald-Khanna quantile summary with 10 tuples", "Greenwald-Khanna
00116      * quantile summary with 100 tuples", "Greenwald-Khanna quantile summary
00117      * with 1000 tuples", "VFML method with 10 bins", "VFML method with 100
00118      * bins", "VFML method with 1000 bins", "Exhaustive binary tree"}, 0);
00119      */
00120     public ClassOption numericEstimatorOption = new ClassOption("numericEstimator",
00121             'n', "Numeric estimator to use.", NumericAttributeClassObserver.class,
00122             "GaussianNumericAttributeClassObserver");
00123 
00124     public ClassOption nominalEstimatorOption = new ClassOption("nominalEstimator",
00125             'd', "Nominal estimator to use.", DiscreteAttributeClassObserver.class,
00126             "NominalAttributeClassObserver");
00127 
00128     public IntOption memoryEstimatePeriodOption = new IntOption(
00129             "memoryEstimatePeriod", 'e',
00130             "How many instances between memory consumption checks.", 1000000,
00131             0, Integer.MAX_VALUE);
00132 
00133     public IntOption gracePeriodOption = new IntOption(
00134             "gracePeriod",
00135             'g',
00136             "The number of instances a leaf should observe between split attempts.",
00137             200, 0, Integer.MAX_VALUE);
00138 
00139     public ClassOption splitCriterionOption = new ClassOption("splitCriterion",
00140             's', "Split criterion to use.", SplitCriterion.class,
00141             "InfoGainSplitCriterion");
00142 
00143     public FloatOption splitConfidenceOption = new FloatOption(
00144             "splitConfidence",
00145             'c',
00146             "The allowable error in split decision, values closer to 0 will take longer to decide.",
00147             0.0000001, 0.0, 1.0);
00148 
00149     public FloatOption tieThresholdOption = new FloatOption("tieThreshold",
00150             't', "Threshold below which a split will be forced to break ties.",
00151             0.05, 0.0, 1.0);
00152 
00153     public FlagOption binarySplitsOption = new FlagOption("binarySplits", 'b',
00154             "Only allow binary splits.");
00155 
00156     public FlagOption stopMemManagementOption = new FlagOption(
00157             "stopMemManagement", 'z',
00158             "Stop growing as soon as memory limit is hit.");
00159 
00160     public FlagOption removePoorAttsOption = new FlagOption("removePoorAtts",
00161             'r', "Disable poor attributes.");
00162 
00163     public FlagOption noPrePruneOption = new FlagOption("noPrePrune", 'p',
00164             "Disable pre-pruning.");
00165 
00166     public static class FoundNode {
00167 
00168         public Node node;
00169 
00170         public SplitNode parent;
00171 
00172         public int parentBranch;
00173 
00174         public FoundNode(Node node, SplitNode parent, int parentBranch) {
00175             this.node = node;
00176             this.parent = parent;
00177             this.parentBranch = parentBranch;
00178         }
00179     }
00180 
00181     public static class Node extends AbstractMOAObject {
00182 
00183         private static final long serialVersionUID = 1L;
00184 
00185         protected DoubleVector observedClassDistribution;
00186 
00187         public Node(double[] classObservations) {
00188             this.observedClassDistribution = new DoubleVector(classObservations);
00189         }
00190 
00191         public int calcByteSize() {
00192             return (int) (SizeOf.sizeOf(this) + SizeOf.fullSizeOf(this.observedClassDistribution));
00193         }
00194 
00195         public int calcByteSizeIncludingSubtree() {
00196             return calcByteSize();
00197         }
00198 
00199         public boolean isLeaf() {
00200             return true;
00201         }
00202 
00203         public FoundNode filterInstanceToLeaf(Instance inst, SplitNode parent,
00204                 int parentBranch) {
00205             return new FoundNode(this, parent, parentBranch);
00206         }
00207 
00208         public double[] getObservedClassDistribution() {
00209             return this.observedClassDistribution.getArrayCopy();
00210         }
00211 
00212         public double[] getClassVotes(Instance inst, HoeffdingTree ht) {
00213             return this.observedClassDistribution.getArrayCopy();
00214         }
00215 
00216         public boolean observedClassDistributionIsPure() {
00217             return this.observedClassDistribution.numNonZeroEntries() < 2;
00218         }
00219 
00220         public void describeSubtree(HoeffdingTree ht, StringBuilder out,
00221                 int indent) {
00222             StringUtils.appendIndented(out, indent, "Leaf ");
00223             out.append(ht.getClassNameString());
00224             out.append(" = ");
00225             out.append(ht.getClassLabelString(this.observedClassDistribution.maxIndex()));
00226             out.append(" weights: ");
00227             this.observedClassDistribution.getSingleLineDescription(out,
00228                     ht.treeRoot.observedClassDistribution.numValues());
00229             StringUtils.appendNewline(out);
00230         }
00231 
00232         public int subtreeDepth() {
00233             return 0;
00234         }
00235 
00236         public double calculatePromise() {
00237             double totalSeen = this.observedClassDistribution.sumOfValues();
00238             return totalSeen > 0.0 ? (totalSeen - this.observedClassDistribution.getValue(this.observedClassDistribution.maxIndex()))
00239                     : 0.0;
00240         }
00241 
00242         @Override
00243         public void getDescription(StringBuilder sb, int indent) {
00244             describeSubtree(null, sb, indent);
00245         }
00246     }
00247 
00248     public static class SplitNode extends Node {
00249 
00250         private static final long serialVersionUID = 1L;
00251 
00252         protected InstanceConditionalTest splitTest;
00253 
00254         protected AutoExpandVector<Node> children = new AutoExpandVector<Node>();
00255 
00256         @Override
00257         public int calcByteSize() {
00258             return super.calcByteSize()
00259                     + (int) (SizeOf.sizeOf(this.children) + SizeOf.fullSizeOf(this.splitTest));
00260         }
00261 
00262         @Override
00263         public int calcByteSizeIncludingSubtree() {
00264             int byteSize = calcByteSize();
00265             for (Node child : this.children) {
00266                 if (child != null) {
00267                     byteSize += child.calcByteSizeIncludingSubtree();
00268                 }
00269             }
00270             return byteSize;
00271         }
00272 
00273         public SplitNode(InstanceConditionalTest splitTest,
00274                 double[] classObservations) {
00275             super(classObservations);
00276             this.splitTest = splitTest;
00277         }
00278 
00279         public int numChildren() {
00280             return this.children.size();
00281         }
00282 
00283         public void setChild(int index, Node child) {
00284             if ((this.splitTest.maxBranches() >= 0)
00285                     && (index >= this.splitTest.maxBranches())) {
00286                 throw new IndexOutOfBoundsException();
00287             }
00288             this.children.set(index, child);
00289         }
00290 
00291         public Node getChild(int index) {
00292             return this.children.get(index);
00293         }
00294 
00295         public int instanceChildIndex(Instance inst) {
00296             return this.splitTest.branchForInstance(inst);
00297         }
00298 
00299         @Override
00300         public boolean isLeaf() {
00301             return false;
00302         }
00303 
00304         @Override
00305         public FoundNode filterInstanceToLeaf(Instance inst, SplitNode parent,
00306                 int parentBranch) {
00307             int childIndex = instanceChildIndex(inst);
00308             if (childIndex >= 0) {
00309                 Node child = getChild(childIndex);
00310                 if (child != null) {
00311                     return child.filterInstanceToLeaf(inst, this, childIndex);
00312                 }
00313                 return new FoundNode(null, this, childIndex);
00314             }
00315             return new FoundNode(this, parent, parentBranch);
00316         }
00317 
00318         @Override
00319         public void describeSubtree(HoeffdingTree ht, StringBuilder out,
00320                 int indent) {
00321             for (int branch = 0; branch < numChildren(); branch++) {
00322                 Node child = getChild(branch);
00323                 if (child != null) {
00324                     StringUtils.appendIndented(out, indent, "if ");
00325                     out.append(this.splitTest.describeConditionForBranch(branch,
00326                             ht.getModelContext()));
00327                     out.append(": ");
00328                     StringUtils.appendNewline(out);
00329                     child.describeSubtree(ht, out, indent + 2);
00330                 }
00331             }
00332         }
00333 
00334         @Override
00335         public int subtreeDepth() {
00336             int maxChildDepth = 0;
00337             for (Node child : this.children) {
00338                 if (child != null) {
00339                     int depth = child.subtreeDepth();
00340                     if (depth > maxChildDepth) {
00341                         maxChildDepth = depth;
00342                     }
00343                 }
00344             }
00345             return maxChildDepth + 1;
00346         }
00347     }
00348 
00349     public static abstract class LearningNode extends Node {
00350 
00351         private static final long serialVersionUID = 1L;
00352 
00353         public LearningNode(double[] initialClassObservations) {
00354             super(initialClassObservations);
00355         }
00356 
00357         public abstract void learnFromInstance(Instance inst, HoeffdingTree ht);
00358     }
00359 
00360     public static class InactiveLearningNode extends LearningNode {
00361 
00362         private static final long serialVersionUID = 1L;
00363 
00364         public InactiveLearningNode(double[] initialClassObservations) {
00365             super(initialClassObservations);
00366         }
00367 
00368         @Override
00369         public void learnFromInstance(Instance inst, HoeffdingTree ht) {
00370             this.observedClassDistribution.addToValue((int) inst.classValue(),
00371                     inst.weight());
00372         }
00373     }
00374 
00375     public static class ActiveLearningNode extends LearningNode {
00376 
00377         private static final long serialVersionUID = 1L;
00378 
00379         protected double weightSeenAtLastSplitEvaluation;
00380 
00381         protected AutoExpandVector<AttributeClassObserver> attributeObservers = new AutoExpandVector<AttributeClassObserver>();
00382 
00383         public ActiveLearningNode(double[] initialClassObservations) {
00384             super(initialClassObservations);
00385             this.weightSeenAtLastSplitEvaluation = getWeightSeen();
00386         }
00387 
00388         @Override
00389         public int calcByteSize() {
00390             return super.calcByteSize()
00391                     + (int) (SizeOf.fullSizeOf(this.attributeObservers));
00392         }
00393 
00394         @Override
00395         public void learnFromInstance(Instance inst, HoeffdingTree ht) {
00396             this.observedClassDistribution.addToValue((int) inst.classValue(),
00397                     inst.weight());
00398             for (int i = 0; i < inst.numAttributes() - 1; i++) {
00399                 int instAttIndex = modelAttIndexToInstanceAttIndex(i, inst);
00400                 AttributeClassObserver obs = this.attributeObservers.get(i);
00401                 if (obs == null) {
00402                     obs = inst.attribute(instAttIndex).isNominal() ? ht.newNominalClassObserver() : ht.newNumericClassObserver();
00403                     this.attributeObservers.set(i, obs);
00404                 }
00405                 obs.observeAttributeClass(inst.value(instAttIndex), (int) inst.classValue(), inst.weight());
00406             }
00407         }
00408 
00409         public double getWeightSeen() {
00410             return this.observedClassDistribution.sumOfValues();
00411         }
00412 
00413         public double getWeightSeenAtLastSplitEvaluation() {
00414             return this.weightSeenAtLastSplitEvaluation;
00415         }
00416 
00417         public void setWeightSeenAtLastSplitEvaluation(double weight) {
00418             this.weightSeenAtLastSplitEvaluation = weight;
00419         }
00420 
00421         public AttributeSplitSuggestion[] getBestSplitSuggestions(
00422                 SplitCriterion criterion, HoeffdingTree ht) {
00423             List<AttributeSplitSuggestion> bestSuggestions = new LinkedList<AttributeSplitSuggestion>();
00424             double[] preSplitDist = this.observedClassDistribution.getArrayCopy();
00425             if (!ht.noPrePruneOption.isSet()) {
00426                 // add null split as an option
00427                 bestSuggestions.add(new AttributeSplitSuggestion(null,
00428                         new double[0][], criterion.getMeritOfSplit(
00429                         preSplitDist,
00430                         new double[][]{preSplitDist})));
00431             }
00432             for (int i = 0; i < this.attributeObservers.size(); i++) {
00433                 AttributeClassObserver obs = this.attributeObservers.get(i);
00434                 if (obs != null) {
00435                     AttributeSplitSuggestion bestSuggestion = obs.getBestEvaluatedSplitSuggestion(criterion,
00436                             preSplitDist, i, ht.binarySplitsOption.isSet());
00437                     if (bestSuggestion != null) {
00438                         bestSuggestions.add(bestSuggestion);
00439                     }
00440                 }
00441             }
00442             return bestSuggestions.toArray(new AttributeSplitSuggestion[bestSuggestions.size()]);
00443         }
00444 
00445         public void disableAttribute(int attIndex) {
00446             this.attributeObservers.set(attIndex,
00447                     new NullAttributeClassObserver());
00448         }
00449     }
00450 
00451     protected Node treeRoot;
00452 
00453     protected int decisionNodeCount;
00454 
00455     protected int activeLeafNodeCount;
00456 
00457     protected int inactiveLeafNodeCount;
00458 
00459     protected double inactiveLeafByteSizeEstimate;
00460 
00461     protected double activeLeafByteSizeEstimate;
00462 
00463     protected double byteSizeEstimateOverheadFraction;
00464 
00465     protected boolean growthAllowed;
00466 
00467     public int calcByteSize() {
00468         int size = (int) SizeOf.sizeOf(this);
00469         if (this.treeRoot != null) {
00470             size += this.treeRoot.calcByteSizeIncludingSubtree();
00471         }
00472         return size;
00473     }
00474 
00475     @Override
00476     public int measureByteSize() {
00477         return calcByteSize();
00478     }
00479 
00480     @Override
00481     public void resetLearningImpl() {
00482         this.treeRoot = null;
00483         this.decisionNodeCount = 0;
00484         this.activeLeafNodeCount = 0;
00485         this.inactiveLeafNodeCount = 0;
00486         this.inactiveLeafByteSizeEstimate = 0.0;
00487         this.activeLeafByteSizeEstimate = 0.0;
00488         this.byteSizeEstimateOverheadFraction = 1.0;
00489         this.growthAllowed = true;
00490         if (this.leafpredictionOption.getChosenIndex()>0) { 
00491             this.removePoorAttsOption = null;
00492         }
00493     }
00494 
00495     @Override
00496     public void trainOnInstanceImpl(Instance inst) {
00497         if (this.treeRoot == null) {
00498             this.treeRoot = newLearningNode();
00499             this.activeLeafNodeCount = 1;
00500         }
00501         FoundNode foundNode = this.treeRoot.filterInstanceToLeaf(inst, null, -1);
00502         Node leafNode = foundNode.node;
00503         if (leafNode == null) {
00504             leafNode = newLearningNode();
00505             foundNode.parent.setChild(foundNode.parentBranch, leafNode);
00506             this.activeLeafNodeCount++;
00507         }
00508         if (leafNode instanceof LearningNode) {
00509             LearningNode learningNode = (LearningNode) leafNode;
00510             learningNode.learnFromInstance(inst, this);
00511             if (this.growthAllowed
00512                     && (learningNode instanceof ActiveLearningNode)) {
00513                 ActiveLearningNode activeLearningNode = (ActiveLearningNode) learningNode;
00514                 double weightSeen = activeLearningNode.getWeightSeen();
00515                 if (weightSeen
00516                         - activeLearningNode.getWeightSeenAtLastSplitEvaluation() >= this.gracePeriodOption.getValue()) {
00517                     attemptToSplit(activeLearningNode, foundNode.parent,
00518                             foundNode.parentBranch);
00519                     activeLearningNode.setWeightSeenAtLastSplitEvaluation(weightSeen);
00520                 }
00521             }
00522         }
00523         if (this.trainingWeightSeenByModel
00524                 % this.memoryEstimatePeriodOption.getValue() == 0) {
00525             estimateModelByteSizes();
00526         }
00527     }
00528 
00529     @Override
00530     public double[] getVotesForInstance(Instance inst) {
00531         if (this.treeRoot != null) {
00532             FoundNode foundNode = this.treeRoot.filterInstanceToLeaf(inst,
00533                     null, -1);
00534             Node leafNode = foundNode.node;
00535             if (leafNode == null) {
00536                 leafNode = foundNode.parent;
00537             }
00538             return leafNode.getClassVotes(inst, this);
00539         }
00540         return new double[0];
00541     }
00542 
00543     @Override
00544     protected Measurement[] getModelMeasurementsImpl() {
00545         return new Measurement[]{
00546                     new Measurement("tree size (nodes)", this.decisionNodeCount
00547                     + this.activeLeafNodeCount + this.inactiveLeafNodeCount),
00548                     new Measurement("tree size (leaves)", this.activeLeafNodeCount
00549                     + this.inactiveLeafNodeCount),
00550                     new Measurement("active learning leaves",
00551                     this.activeLeafNodeCount),
00552                     new Measurement("tree depth", measureTreeDepth()),
00553                     new Measurement("active leaf byte size estimate",
00554                     this.activeLeafByteSizeEstimate),
00555                     new Measurement("inactive leaf byte size estimate",
00556                     this.inactiveLeafByteSizeEstimate),
00557                     new Measurement("byte size estimate overhead",
00558                     this.byteSizeEstimateOverheadFraction)};
00559     }
00560 
00561     public int measureTreeDepth() {
00562         if (this.treeRoot != null) {
00563             return this.treeRoot.subtreeDepth();
00564         }
00565         return 0;
00566     }
00567 
00568     @Override
00569     public void getModelDescription(StringBuilder out, int indent) {
00570         this.treeRoot.describeSubtree(this, out, indent);
00571     }
00572 
00573     @Override
00574     public boolean isRandomizable() {
00575         return false;
00576     }
00577 
00578     public static double computeHoeffdingBound(double range, double confidence,
00579             double n) {
00580         return Math.sqrt(((range * range) * Math.log(1.0 / confidence))
00581                 / (2.0 * n));
00582     }
00583 
00584     //Procedure added for Hoeffding Adaptive Trees (ADWIN)
00585     protected SplitNode newSplitNode(InstanceConditionalTest splitTest,
00586             double[] classObservations) {
00587         return new SplitNode(splitTest, classObservations);
00588     }
00589 
00590     protected AttributeClassObserver newNominalClassObserver() {
00591         AttributeClassObserver nominalClassObserver = (AttributeClassObserver) getPreparedClassOption(this.nominalEstimatorOption);
00592         return (AttributeClassObserver) nominalClassObserver.copy();
00593     }
00594 
00595     protected AttributeClassObserver newNumericClassObserver() {
00596         AttributeClassObserver numericClassObserver = (AttributeClassObserver) getPreparedClassOption(this.numericEstimatorOption);
00597         return (AttributeClassObserver) numericClassObserver.copy();
00598     }
00599 
00600     protected void attemptToSplit(ActiveLearningNode node, SplitNode parent,
00601             int parentIndex) {
00602         if (!node.observedClassDistributionIsPure()) {
00603             SplitCriterion splitCriterion = (SplitCriterion) getPreparedClassOption(this.splitCriterionOption);
00604             AttributeSplitSuggestion[] bestSplitSuggestions = node.getBestSplitSuggestions(splitCriterion, this);
00605             Arrays.sort(bestSplitSuggestions);
00606             boolean shouldSplit = false;
00607             if (bestSplitSuggestions.length < 2) {
00608                 shouldSplit = bestSplitSuggestions.length > 0;
00609             } else {
00610                 double hoeffdingBound = computeHoeffdingBound(splitCriterion.getRangeOfMerit(node.getObservedClassDistribution()),
00611                         this.splitConfidenceOption.getValue(), node.getWeightSeen());
00612                 AttributeSplitSuggestion bestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 1];
00613                 AttributeSplitSuggestion secondBestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 2];
00614                 if ((bestSuggestion.merit - secondBestSuggestion.merit > hoeffdingBound)
00615                         || (hoeffdingBound < this.tieThresholdOption.getValue())) {
00616                     shouldSplit = true;
00617                 }
00618                 // }
00619                 if ((this.removePoorAttsOption != null)
00620                         && this.removePoorAttsOption.isSet()) {
00621                     Set<Integer> poorAtts = new HashSet<Integer>();
00622                     // scan 1 - add any poor to set
00623                     for (int i = 0; i < bestSplitSuggestions.length; i++) {
00624                         if (bestSplitSuggestions[i].splitTest != null) {
00625                             int[] splitAtts = bestSplitSuggestions[i].splitTest.getAttsTestDependsOn();
00626                             if (splitAtts.length == 1) {
00627                                 if (bestSuggestion.merit
00628                                         - bestSplitSuggestions[i].merit > hoeffdingBound) {
00629                                     poorAtts.add(new Integer(splitAtts[0]));
00630                                 }
00631                             }
00632                         }
00633                     }
00634                     // scan 2 - remove good ones from set
00635                     for (int i = 0; i < bestSplitSuggestions.length; i++) {
00636                         if (bestSplitSuggestions[i].splitTest != null) {
00637                             int[] splitAtts = bestSplitSuggestions[i].splitTest.getAttsTestDependsOn();
00638                             if (splitAtts.length == 1) {
00639                                 if (bestSuggestion.merit
00640                                         - bestSplitSuggestions[i].merit < hoeffdingBound) {
00641                                     poorAtts.remove(new Integer(splitAtts[0]));
00642                                 }
00643                             }
00644                         }
00645                     }
00646                     for (int poorAtt : poorAtts) {
00647                         node.disableAttribute(poorAtt);
00648                     }
00649                 }
00650             }
00651             if (shouldSplit) {
00652                 AttributeSplitSuggestion splitDecision = bestSplitSuggestions[bestSplitSuggestions.length - 1];
00653                 if (splitDecision.splitTest == null) {
00654                     // preprune - null wins
00655                     deactivateLearningNode(node, parent, parentIndex);
00656                 } else {
00657                     SplitNode newSplit = newSplitNode(splitDecision.splitTest,
00658                             node.getObservedClassDistribution());
00659                     for (int i = 0; i < splitDecision.numSplits(); i++) {
00660                         Node newChild = newLearningNode(splitDecision.resultingClassDistributionFromSplit(i));
00661                         newSplit.setChild(i, newChild);
00662                     }
00663                     this.activeLeafNodeCount--;
00664                     this.decisionNodeCount++;
00665                     this.activeLeafNodeCount += splitDecision.numSplits();
00666                     if (parent == null) {
00667                         this.treeRoot = newSplit;
00668                     } else {
00669                         parent.setChild(parentIndex, newSplit);
00670                     }
00671                 }
00672                 // manage memory
00673                 enforceTrackerLimit();
00674             }
00675         }
00676     }
00677 
00678     public void enforceTrackerLimit() {
00679         if ((this.inactiveLeafNodeCount > 0)
00680                 || ((this.activeLeafNodeCount * this.activeLeafByteSizeEstimate + this.inactiveLeafNodeCount
00681                 * this.inactiveLeafByteSizeEstimate)
00682                 * this.byteSizeEstimateOverheadFraction > this.maxByteSizeOption.getValue())) {
00683             if (this.stopMemManagementOption.isSet()) {
00684                 this.growthAllowed = false;
00685                 return;
00686             }
00687             FoundNode[] learningNodes = findLearningNodes();
00688             Arrays.sort(learningNodes, new Comparator<FoundNode>() {
00689 
00690                 @Override
00691                 public int compare(FoundNode fn1, FoundNode fn2) {
00692                     return Double.compare(fn1.node.calculatePromise(), fn2.node.calculatePromise());
00693                 }
00694             });
00695             int maxActive = 0;
00696             while (maxActive < learningNodes.length) {
00697                 maxActive++;
00698                 if ((maxActive * this.activeLeafByteSizeEstimate + (learningNodes.length - maxActive)
00699                         * this.inactiveLeafByteSizeEstimate)
00700                         * this.byteSizeEstimateOverheadFraction > this.maxByteSizeOption.getValue()) {
00701                     maxActive--;
00702                     break;
00703                 }
00704             }
00705             int cutoff = learningNodes.length - maxActive;
00706             for (int i = 0; i < cutoff; i++) {
00707                 if (learningNodes[i].node instanceof ActiveLearningNode) {
00708                     deactivateLearningNode(
00709                             (ActiveLearningNode) learningNodes[i].node,
00710                             learningNodes[i].parent,
00711                             learningNodes[i].parentBranch);
00712                 }
00713             }
00714             for (int i = cutoff; i < learningNodes.length; i++) {
00715                 if (learningNodes[i].node instanceof InactiveLearningNode) {
00716                     activateLearningNode(
00717                             (InactiveLearningNode) learningNodes[i].node,
00718                             learningNodes[i].parent,
00719                             learningNodes[i].parentBranch);
00720                 }
00721             }
00722         }
00723     }
00724 
00725     public void estimateModelByteSizes() {
00726         FoundNode[] learningNodes = findLearningNodes();
00727         long totalActiveSize = 0;
00728         long totalInactiveSize = 0;
00729         for (FoundNode foundNode : learningNodes) {
00730             if (foundNode.node instanceof ActiveLearningNode) {
00731                 totalActiveSize += SizeOf.fullSizeOf(foundNode.node);
00732             } else {
00733                 totalInactiveSize += SizeOf.fullSizeOf(foundNode.node);
00734             }
00735         }
00736         if (totalActiveSize > 0) {
00737             this.activeLeafByteSizeEstimate = (double) totalActiveSize
00738                     / this.activeLeafNodeCount;
00739         }
00740         if (totalInactiveSize > 0) {
00741             this.inactiveLeafByteSizeEstimate = (double) totalInactiveSize
00742                     / this.inactiveLeafNodeCount;
00743         }
00744         int actualModelSize = this.measureByteSize();
00745         double estimatedModelSize = (this.activeLeafNodeCount
00746                 * this.activeLeafByteSizeEstimate + this.inactiveLeafNodeCount
00747                 * this.inactiveLeafByteSizeEstimate);
00748         this.byteSizeEstimateOverheadFraction = actualModelSize
00749                 / estimatedModelSize;
00750         if (actualModelSize > this.maxByteSizeOption.getValue()) {
00751             enforceTrackerLimit();
00752         }
00753     }
00754 
00755     public void deactivateAllLeaves() {
00756         FoundNode[] learningNodes = findLearningNodes();
00757         for (int i = 0; i < learningNodes.length; i++) {
00758             if (learningNodes[i].node instanceof ActiveLearningNode) {
00759                 deactivateLearningNode(
00760                         (ActiveLearningNode) learningNodes[i].node,
00761                         learningNodes[i].parent, learningNodes[i].parentBranch);
00762             }
00763         }
00764     }
00765 
00766     protected void deactivateLearningNode(ActiveLearningNode toDeactivate,
00767             SplitNode parent, int parentBranch) {
00768         Node newLeaf = new InactiveLearningNode(toDeactivate.getObservedClassDistribution());
00769         if (parent == null) {
00770             this.treeRoot = newLeaf;
00771         } else {
00772             parent.setChild(parentBranch, newLeaf);
00773         }
00774         this.activeLeafNodeCount--;
00775         this.inactiveLeafNodeCount++;
00776     }
00777 
00778     protected void activateLearningNode(InactiveLearningNode toActivate,
00779             SplitNode parent, int parentBranch) {
00780         Node newLeaf = newLearningNode(toActivate.getObservedClassDistribution());
00781         if (parent == null) {
00782             this.treeRoot = newLeaf;
00783         } else {
00784             parent.setChild(parentBranch, newLeaf);
00785         }
00786         this.activeLeafNodeCount++;
00787         this.inactiveLeafNodeCount--;
00788     }
00789 
00790     protected FoundNode[] findLearningNodes() {
00791         List<FoundNode> foundList = new LinkedList<FoundNode>();
00792         findLearningNodes(this.treeRoot, null, -1, foundList);
00793         return foundList.toArray(new FoundNode[foundList.size()]);
00794     }
00795 
00796     protected void findLearningNodes(Node node, SplitNode parent,
00797             int parentBranch, List<FoundNode> found) {
00798         if (node != null) {
00799             if (node instanceof LearningNode) {
00800                 found.add(new FoundNode(node, parent, parentBranch));
00801             }
00802             if (node instanceof SplitNode) {
00803                 SplitNode splitNode = (SplitNode) node;
00804                 for (int i = 0; i < splitNode.numChildren(); i++) {
00805                     findLearningNodes(splitNode.getChild(i), splitNode, i,
00806                             found);
00807                 }
00808             }
00809         }
00810     }
00811 
00812     public MultiChoiceOption leafpredictionOption = new MultiChoiceOption(
00813             "leafprediction", 'l', "Leaf prediction to use.", new String[]{
00814                 "MC", "NB", "NBAdaptive"}, new String[]{
00815                 "Majority class",
00816                 "Naive Bayes",
00817                 "Naive Bayes Adaptive"}, 2);
00818 
00819     public IntOption nbThresholdOption = new IntOption(
00820             "nbThreshold",
00821             'q',
00822             "The number of instances a leaf should observe before permitting Naive Bayes.",
00823             0, 0, Integer.MAX_VALUE);
00824 
00825     public static class LearningNodeNB extends ActiveLearningNode {
00826 
00827         private static final long serialVersionUID = 1L;
00828 
00829         public LearningNodeNB(double[] initialClassObservations) {
00830             super(initialClassObservations);
00831         }
00832 
00833         @Override
00834         public double[] getClassVotes(Instance inst, HoeffdingTree ht) {
00835             if (getWeightSeen() >= ht.nbThresholdOption.getValue()) {
00836                 return NaiveBayes.doNaiveBayesPrediction(inst,
00837                         this.observedClassDistribution,
00838                         this.attributeObservers);
00839             }
00840             return super.getClassVotes(inst, ht);
00841         }
00842 
00843         @Override
00844         public void disableAttribute(int attIndex) {
00845             // should not disable poor atts - they are used in NB calc
00846         }
00847     }
00848 
00849     public static class LearningNodeNBAdaptive extends LearningNodeNB {
00850 
00851         private static final long serialVersionUID = 1L;
00852 
00853         protected double mcCorrectWeight = 0.0;
00854 
00855         protected double nbCorrectWeight = 0.0;
00856 
00857         public LearningNodeNBAdaptive(double[] initialClassObservations) {
00858             super(initialClassObservations);
00859         }
00860 
00861         @Override
00862         public void learnFromInstance(Instance inst, HoeffdingTree ht) {
00863             int trueClass = (int) inst.classValue();
00864             if (this.observedClassDistribution.maxIndex() == trueClass) {
00865                 this.mcCorrectWeight += inst.weight();
00866             }
00867             if (Utils.maxIndex(NaiveBayes.doNaiveBayesPrediction(inst,
00868                     this.observedClassDistribution, this.attributeObservers)) == trueClass) {
00869                 this.nbCorrectWeight += inst.weight();
00870             }
00871             super.learnFromInstance(inst, ht);
00872         }
00873 
00874         @Override
00875         public double[] getClassVotes(Instance inst, HoeffdingTree ht) {
00876             if (this.mcCorrectWeight > this.nbCorrectWeight) {
00877                 return this.observedClassDistribution.getArrayCopy();
00878             }
00879             return NaiveBayes.doNaiveBayesPrediction(inst,
00880                     this.observedClassDistribution, this.attributeObservers);
00881         }
00882     }
00883 
00884     protected LearningNode newLearningNode() {
00885         return newLearningNode(new double[0]);
00886     }
00887 
00888     protected LearningNode newLearningNode(double[] initialClassObservations) {
00889         LearningNode ret;
00890         int predictionOption = this.leafpredictionOption.getChosenIndex();
00891         if (predictionOption == 0) { //MC
00892             ret = new ActiveLearningNode(initialClassObservations);
00893         } else if (predictionOption == 1) { //NB
00894             ret = new LearningNodeNB(initialClassObservations);
00895         } else { //NBAdaptive
00896             ret = new LearningNodeNBAdaptive(initialClassObservations);
00897         }
00898         return ret;
00899     }
00900 }
 All Classes Namespaces Files Functions Variables Enumerations