MOA 12.03
Real Time Analytics for Data Streams
LimAttHoeffdingTree.java
Go to the documentation of this file.
00001 /*
00002  *    LimAttHoeffdingTree.java
00003  *    Copyright (C) 2010 University of Waikato, Hamilton, New Zealand
00004  *    @author Albert Bifet (abifet at cs dot waikato dot ac dot nz)
00005  *
00006  *    This program is free software; you can redistribute it and/or modify
00007  *    it under the terms of the GNU General Public License as published by
00008  *    the Free Software Foundation; either version 3 of the License, or
00009  *    (at your option) any later version.
00010  *
00011  *    This program is distributed in the hope that it will be useful,
00012  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
00013  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00014  *    GNU General Public License for more details.
00015  *
00016  *    You should have received a copy of the GNU General Public License
00017  *    along with this program. If not, see <http://www.gnu.org/licenses/>.
00018  *    
00019  */
00020 package moa.classifiers.trees;
00021 
00022 import moa.classifiers.bayes.NaiveBayes;
00023 import moa.classifiers.core.attributeclassobservers.AttributeClassObserver;
00024 import weka.core.Instance;
00025 import weka.core.Utils;
00026 
00055 public class LimAttHoeffdingTree extends HoeffdingTree {
00056 
00057     private static final long serialVersionUID = 1L;
00058 
00059     @Override
00060     public String getPurposeString() {
00061         return "Hoeffding decision trees with a restricted number of attributes for data streams.";
00062     }
00063 
00064     protected int[] listAttributes;
00065 
00066     public void setlistAttributes(int[] list) {
00067         this.listAttributes = list;
00068     }
00069 
00070     public static class LimAttLearningNode extends ActiveLearningNode {
00071 
00072         private static final long serialVersionUID = 1L;
00073 
00074         protected double weightSeenAtLastSplitEvaluation;
00075 
00076         protected int[] listAttributes;
00077 
00078         protected int numAttributes;
00079 
00080         public LimAttLearningNode(double[] initialClassObservations) {
00081             super(initialClassObservations);
00082         }
00083 
00084         public void setlistAttributes(int[] list) {
00085             this.listAttributes = list;
00086             this.numAttributes = list.length;
00087         }
00088 
00089         @Override
00090         public void learnFromInstance(Instance inst, HoeffdingTree ht) {
00091             this.observedClassDistribution.addToValue((int) inst.classValue(),
00092                     inst.weight());
00093             if (this.listAttributes == null) {
00094                 setlistAttributes(((LimAttHoeffdingTree) ht).listAttributes);
00095             }
00096 
00097             for (int j = 0; j < this.numAttributes; j++) {
00098                 int i = this.listAttributes[j];
00099                 int instAttIndex = modelAttIndexToInstanceAttIndex(i, inst);
00100                 AttributeClassObserver obs = this.attributeObservers.get(i);
00101                 if (obs == null) {
00102                     obs = inst.attribute(instAttIndex).isNominal() ? ht.newNominalClassObserver() : ht.newNumericClassObserver();
00103                     this.attributeObservers.set(i, obs);
00104                 }
00105                 obs.observeAttributeClass(inst.value(instAttIndex), (int) inst.classValue(), inst.weight());
00106             }
00107         }
00108     }
00109 
00110     public LimAttHoeffdingTree() {
00111         this.removePoorAttsOption = null;
00112     }
00113 
00114     public static class LearningNodeNB extends LimAttLearningNode {
00115 
00116         private static final long serialVersionUID = 1L;
00117 
00118         public LearningNodeNB(double[] initialClassObservations) {
00119             super(initialClassObservations);
00120         }
00121 
00122         @Override
00123         public double[] getClassVotes(Instance inst, HoeffdingTree ht) {
00124             if (getWeightSeen() >= ht.nbThresholdOption.getValue()) {
00125                 return NaiveBayes.doNaiveBayesPrediction(inst,
00126                         this.observedClassDistribution,
00127                         this.attributeObservers);
00128             }
00129             return super.getClassVotes(inst, ht);
00130         }
00131 
00132         @Override
00133         public void disableAttribute(int attIndex) {
00134             // should not disable poor atts - they are used in NB calc
00135         }
00136     }
00137 
00138     public static class LearningNodeNBAdaptive extends LearningNodeNB {
00139 
00140         private static final long serialVersionUID = 1L;
00141 
00142         protected double mcCorrectWeight = 0.0;
00143 
00144         protected double nbCorrectWeight = 0.0;
00145 
00146         public LearningNodeNBAdaptive(double[] initialClassObservations) {
00147             super(initialClassObservations);
00148         }
00149 
00150         @Override
00151         public void learnFromInstance(Instance inst, HoeffdingTree ht) {
00152             int trueClass = (int) inst.classValue();
00153             if (this.observedClassDistribution.maxIndex() == trueClass) {
00154                 this.mcCorrectWeight += inst.weight();
00155             }
00156             if (Utils.maxIndex(NaiveBayes.doNaiveBayesPrediction(inst,
00157                     this.observedClassDistribution, this.attributeObservers)) == trueClass) {
00158                 this.nbCorrectWeight += inst.weight();
00159             }
00160             super.learnFromInstance(inst, ht);
00161         }
00162 
00163         @Override
00164         public double[] getClassVotes(Instance inst, HoeffdingTree ht) {
00165             if (this.mcCorrectWeight > this.nbCorrectWeight) {
00166                 return this.observedClassDistribution.getArrayCopy();
00167             }
00168             double ret[] = NaiveBayes.doNaiveBayesPrediction(inst,
00169                     this.observedClassDistribution, this.attributeObservers);
00170             for (int i = 0; i < ret.length; i++) {
00171                 ret[i] *= this.observedClassDistribution.sumOfValues();
00172             }
00173             return ret;
00174         }
00175     }
00176 
00177     @Override
00178     protected LearningNode newLearningNode(double[] initialClassObservations) {
00179         LearningNode ret;
00180         int predictionOption = this.leafpredictionOption.getChosenIndex();
00181         if (predictionOption == 0) { //MC
00182             ret = new LimAttLearningNode(initialClassObservations);
00183         } else if (predictionOption == 1) { //NB
00184             ret = new LearningNodeNB(initialClassObservations);
00185         } else { //NBAdaptive
00186             ret = new LearningNodeNBAdaptive(initialClassObservations);
00187         }
00188         return ret;
00189     }
00190 
00191     @Override
00192     public boolean isRandomizable() {
00193         return true;
00194     }
00195 }
 All Classes Namespaces Files Functions Variables Enumerations