MOA 12.03
Real Time Analytics for Data Streams
|
00001 /* 00002 * DecisionStump.java 00003 * Copyright (C) 2007 University of Waikato, Hamilton, New Zealand 00004 * @author Richard Kirkby ([email protected]) 00005 * 00006 * This program is free software; you can redistribute it and/or modify 00007 * it under the terms of the GNU General Public License as published by 00008 * the Free Software Foundation; either version 3 of the License, or 00009 * (at your option) any later version. 00010 * 00011 * This program is distributed in the hope that it will be useful, 00012 * but WITHOUT ANY WARRANTY; without even the implied warranty of 00013 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 00014 * GNU General Public License for more details. 00015 * 00016 * You should have received a copy of the GNU General Public License 00017 * along with this program. If not, see <http://www.gnu.org/licenses/>. 00018 * 00019 */ 00020 package moa.classifiers.trees; 00021 00022 import moa.classifiers.AbstractClassifier; 00023 import moa.classifiers.core.attributeclassobservers.AttributeClassObserver; 00024 import moa.classifiers.core.AttributeSplitSuggestion; 00025 import moa.classifiers.core.attributeclassobservers.GaussianNumericAttributeClassObserver; 00026 import moa.classifiers.core.attributeclassobservers.NominalAttributeClassObserver; 00027 import moa.classifiers.core.splitcriteria.SplitCriterion; 00028 import moa.core.AutoExpandVector; 00029 import moa.core.DoubleVector; 00030 import moa.core.Measurement; 00031 import moa.options.ClassOption; 00032 import moa.options.FlagOption; 00033 import moa.options.IntOption; 00034 import weka.core.Instance; 00035 00050 public class DecisionStump extends AbstractClassifier { 00051 00052 private static final long serialVersionUID = 1L; 00053 00054 @Override 00055 public String getPurposeString() { 00056 return "Decision trees of one level."; 00057 } 00058 00059 public IntOption gracePeriodOption = new IntOption("gracePeriod", 'g', 00060 "The number of instances to observe between model changes.", 1000, 00061 0, Integer.MAX_VALUE); 00062 00063 public FlagOption binarySplitsOption = new FlagOption("binarySplits", 'b', 00064 "Only allow binary splits."); 00065 00066 public ClassOption splitCriterionOption = new ClassOption("splitCriterion", 00067 'c', "Split criterion to use.", SplitCriterion.class, 00068 "InfoGainSplitCriterion"); 00069 00070 protected AttributeSplitSuggestion bestSplit; 00071 00072 protected DoubleVector observedClassDistribution; 00073 00074 protected AutoExpandVector<AttributeClassObserver> attributeObservers; 00075 00076 protected double weightSeenAtLastSplit; 00077 00078 @Override 00079 public void resetLearningImpl() { 00080 this.bestSplit = null; 00081 this.observedClassDistribution = new DoubleVector(); 00082 this.attributeObservers = new AutoExpandVector<AttributeClassObserver>(); 00083 this.weightSeenAtLastSplit = 0.0; 00084 } 00085 00086 @Override 00087 protected Measurement[] getModelMeasurementsImpl() { 00088 return null; 00089 } 00090 00091 @Override 00092 public void getModelDescription(StringBuilder out, int indent) { 00093 // TODO Auto-generated method stub 00094 } 00095 00096 @Override 00097 public void trainOnInstanceImpl(Instance inst) { 00098 this.observedClassDistribution.addToValue((int) inst.classValue(), inst.weight()); 00099 for (int i = 0; i < inst.numAttributes() - 1; i++) { 00100 int instAttIndex = modelAttIndexToInstanceAttIndex(i, inst); 00101 AttributeClassObserver obs = this.attributeObservers.get(i); 00102 if (obs == null) { 00103 obs = inst.attribute(instAttIndex).isNominal() ? newNominalClassObserver() 00104 : newNumericClassObserver(); 00105 this.attributeObservers.set(i, obs); 00106 } 00107 obs.observeAttributeClass(inst.value(instAttIndex), (int) inst.classValue(), inst.weight()); 00108 } 00109 if (this.trainingWeightSeenByModel - this.weightSeenAtLastSplit >= this.gracePeriodOption.getValue()) { 00110 this.bestSplit = findBestSplit((SplitCriterion) getPreparedClassOption(this.splitCriterionOption)); 00111 this.weightSeenAtLastSplit = this.trainingWeightSeenByModel; 00112 } 00113 } 00114 00115 @Override 00116 public double[] getVotesForInstance(Instance inst) { 00117 if (this.bestSplit != null) { 00118 int branch = this.bestSplit.splitTest.branchForInstance(inst); 00119 if (branch >= 0) { 00120 return this.bestSplit.resultingClassDistributionFromSplit(branch); 00121 } 00122 } 00123 return this.observedClassDistribution.getArrayCopy(); 00124 } 00125 00126 @Override 00127 public boolean isRandomizable() { 00128 return false; 00129 } 00130 00131 protected AttributeClassObserver newNominalClassObserver() { 00132 return new NominalAttributeClassObserver(); 00133 } 00134 00135 protected AttributeClassObserver newNumericClassObserver() { 00136 return new GaussianNumericAttributeClassObserver(); 00137 } 00138 00139 protected AttributeSplitSuggestion findBestSplit(SplitCriterion criterion) { 00140 AttributeSplitSuggestion bestFound = null; 00141 double bestMerit = Double.NEGATIVE_INFINITY; 00142 double[] preSplitDist = this.observedClassDistribution.getArrayCopy(); 00143 for (int i = 0; i < this.attributeObservers.size(); i++) { 00144 AttributeClassObserver obs = this.attributeObservers.get(i); 00145 if (obs != null) { 00146 AttributeSplitSuggestion suggestion = obs.getBestEvaluatedSplitSuggestion(criterion, 00147 preSplitDist, i, this.binarySplitsOption.isSet()); 00148 if (suggestion.merit > bestMerit) { 00149 bestMerit = suggestion.merit; 00150 bestFound = suggestion; 00151 } 00152 } 00153 } 00154 return bestFound; 00155 } 00156 }