MOA 12.03
Real Time Analytics for Data Streams
00001 /* 00002 * WeightedMajorityAlgorithm.java 00003 * Copyright (C) 2007 University of Waikato, Hamilton, New Zealand 00004 * @author Richard Kirkby ([email protected]) 00005 * 00006 * This program is free software; you can redistribute it and/or modify 00007 * it under the terms of the GNU General Public License as published by 00008 * the Free Software Foundation; either version 3 of the License, or 00009 * (at your option) any later version. 00010 * 00011 * This program is distributed in the hope that it will be useful, 00012 * but WITHOUT ANY WARRANTY; without even the implied warranty of 00013 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 00014 * GNU General Public License for more details. 00015 * 00016 * You should have received a copy of the GNU General Public License 00017 * along with this program. If not, see <http://www.gnu.org/licenses/>. 00018 * 00019 */ 00020 package moa.classifiers.meta; 00021 00022 import moa.classifiers.AbstractClassifier; 00023 import moa.classifiers.Classifier; 00024 import moa.core.DoubleVector; 00025 import moa.core.Measurement; 00026 import moa.core.ObjectRepository; 00027 import moa.options.ClassOption; 00028 import moa.options.FlagOption; 00029 import moa.options.FloatOption; 00030 import moa.options.ListOption; 00031 import moa.options.Option; 00032 import moa.tasks.TaskMonitor; 00033 import weka.core.Instance; 00034 import weka.core.Utils; 00035 00042 public class WeightedMajorityAlgorithm extends AbstractClassifier { 00043 00044 private static final long serialVersionUID = 1L; 00045 00046 @Override 00047 public String getPurposeString() { 00048 return "Weighted majority algorithm for data streams."; 00049 } 00050 00051 public ListOption learnerListOption = new ListOption( 00052 "learners", 00053 'l', 00054 "The learners to combine.", 00055 new ClassOption("learner", ' ', "", Classifier.class, 00056 "trees.HoeffdingTree"), 00057 new Option[]{ 00058 new ClassOption("", ' ', "", Classifier.class, 00059 
"trees.HoeffdingTree -l MC"), 00060 new ClassOption("", ' ', "", Classifier.class, 00061 "trees.HoeffdingTree -l NB"), 00062 new ClassOption("", ' ', "", Classifier.class, 00063 "trees.HoeffdingTree -l NBAdaptive"), 00064 new ClassOption("", ' ', "", Classifier.class, "bayes.NaiveBayes")}, 00065 ','); 00066 00067 public FloatOption betaOption = new FloatOption("beta", 'b', 00068 "Factor to punish mistakes by.", 0.9, 0.0, 1.0); 00069 00070 public FloatOption gammaOption = new FloatOption("gamma", 'g', 00071 "Minimum fraction of weight per model.", 0.01, 0.0, 0.5); 00072 00073 public FlagOption pruneOption = new FlagOption("prune", 'p', 00074 "Prune poorly performing models from ensemble."); 00075 00076 protected Classifier[] ensemble; 00077 00078 protected double[] ensembleWeights; 00079 00080 @Override 00081 public void prepareForUseImpl(TaskMonitor monitor, 00082 ObjectRepository repository) { 00083 Option[] learnerOptions = this.learnerListOption.getList(); 00084 this.ensemble = new Classifier[learnerOptions.length]; 00085 for (int i = 0; i < learnerOptions.length; i++) { 00086 monitor.setCurrentActivity("Materializing learner " + (i + 1) 00087 + "...", -1.0); 00088 this.ensemble[i] = (Classifier) ((ClassOption) learnerOptions[i]).materializeObject(monitor, repository); 00089 if (monitor.taskShouldAbort()) { 00090 return; 00091 } 00092 monitor.setCurrentActivity("Preparing learner " + (i + 1) + "...", 00093 -1.0); 00094 this.ensemble[i].prepareForUse(monitor, repository); 00095 if (monitor.taskShouldAbort()) { 00096 return; 00097 } 00098 } 00099 super.prepareForUseImpl(monitor, repository); 00100 } 00101 00102 @Override 00103 public void resetLearningImpl() { 00104 this.ensembleWeights = new double[this.ensemble.length]; 00105 for (int i = 0; i < this.ensemble.length; i++) { 00106 this.ensemble[i].resetLearning(); 00107 this.ensembleWeights[i] = 1.0; 00108 } 00109 } 00110 00111 @Override 00112 public void trainOnInstanceImpl(Instance inst) { 00113 double 
totalWeight = 0.0; 00114 for (int i = 0; i < this.ensemble.length; i++) { 00115 boolean prune = false; 00116 if (!this.ensemble[i].correctlyClassifies(inst)) { 00117 if (this.ensembleWeights[i] > this.gammaOption.getValue() 00118 / this.ensembleWeights.length) { 00119 this.ensembleWeights[i] *= this.betaOption.getValue() 00120 * inst.weight(); 00121 } else if (this.pruneOption.isSet()) { 00122 prune = true; 00123 discardModel(i); 00124 i--; 00125 } 00126 } 00127 if (!prune) { 00128 totalWeight += this.ensembleWeights[i]; 00129 this.ensemble[i].trainOnInstance(inst); 00130 } 00131 } 00132 // normalize weights 00133 for (int i = 0; i < this.ensembleWeights.length; i++) { 00134 this.ensembleWeights[i] /= totalWeight; 00135 } 00136 } 00137 00138 public double[] getVotesForInstance(Instance inst) { 00139 DoubleVector combinedVote = new DoubleVector(); 00140 if (this.trainingWeightSeenByModel > 0.0) { 00141 for (int i = 0; i < this.ensemble.length; i++) { 00142 if (this.ensembleWeights[i] > 0.0) { 00143 DoubleVector vote = new DoubleVector(this.ensemble[i].getVotesForInstance(inst)); 00144 if (vote.sumOfValues() > 0.0) { 00145 vote.normalize(); 00146 vote.scaleValues(this.ensembleWeights[i]); 00147 combinedVote.addValues(vote); 00148 } 00149 } 00150 } 00151 } 00152 return combinedVote.getArrayRef(); 00153 } 00154 00155 @Override 00156 public void getModelDescription(StringBuilder out, int indent) { 00157 // TODO Auto-generated method stub 00158 } 00159 00160 @Override 00161 protected Measurement[] getModelMeasurementsImpl() { 00162 Measurement[] measurements = null; 00163 if (this.ensembleWeights != null) { 00164 measurements = new Measurement[this.ensembleWeights.length]; 00165 for (int i = 0; i < this.ensembleWeights.length; i++) { 00166 measurements[i] = new Measurement("member weight " + (i + 1), 00167 this.ensembleWeights[i]); 00168 } 00169 } 00170 return measurements; 00171 } 00172 00173 @Override 00174 public boolean isRandomizable() { 00175 return false; 00176 } 
00177 00178 @Override 00179 public Classifier[] getSubClassifiers() { 00180 return this.ensemble.clone(); 00181 } 00182 00183 public void discardModel(int index) { 00184 Classifier[] newEnsemble = new Classifier[this.ensemble.length - 1]; 00185 double[] newEnsembleWeights = new double[newEnsemble.length]; 00186 int oldPos = 0; 00187 for (int i = 0; i < newEnsemble.length; i++) { 00188 if (oldPos == index) { 00189 oldPos++; 00190 } 00191 newEnsemble[i] = this.ensemble[oldPos]; 00192 newEnsembleWeights[i] = this.ensembleWeights[oldPos]; 00193 oldPos++; 00194 } 00195 this.ensemble = newEnsemble; 00196 this.ensembleWeights = newEnsembleWeights; 00197 } 00198 00199 protected int removePoorestModelBytes() { 00200 int poorestIndex = Utils.minIndex(this.ensembleWeights); 00201 int byteSize = this.ensemble[poorestIndex].measureByteSize(); 00202 discardModel(poorestIndex); 00203 return byteSize; 00204 } 00205 }