MOA 12.03
Real Time Analytics for Data Streams
WeightedMajorityAlgorithm.java
Go to the documentation of this file.
00001 /*
00002  *    WeightedMajorityAlgorithm.java
00003  *    Copyright (C) 2007 University of Waikato, Hamilton, New Zealand
00004  *    @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
00005  *
00006  *    This program is free software; you can redistribute it and/or modify
00007  *    it under the terms of the GNU General Public License as published by
00008  *    the Free Software Foundation; either version 3 of the License, or
00009  *    (at your option) any later version.
00010  *
00011  *    This program is distributed in the hope that it will be useful,
00012  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
00013  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00014  *    GNU General Public License for more details.
00015  *
00016  *    You should have received a copy of the GNU General Public License
00017  *    along with this program. If not, see <http://www.gnu.org/licenses/>.
00018  *    
00019  */
00020 package moa.classifiers.meta;
00021 
00022 import moa.classifiers.AbstractClassifier;
00023 import moa.classifiers.Classifier;
00024 import moa.core.DoubleVector;
00025 import moa.core.Measurement;
00026 import moa.core.ObjectRepository;
00027 import moa.options.ClassOption;
00028 import moa.options.FlagOption;
00029 import moa.options.FloatOption;
00030 import moa.options.ListOption;
00031 import moa.options.Option;
00032 import moa.tasks.TaskMonitor;
00033 import weka.core.Instance;
00034 import weka.core.Utils;
00035 
00042 public class WeightedMajorityAlgorithm extends AbstractClassifier {
00043 
00044     private static final long serialVersionUID = 1L;
00045     
00046     @Override
00047     public String getPurposeString() {
00048         return "Weighted majority algorithm for data streams.";
00049     }
00050         
00051     public ListOption learnerListOption = new ListOption(
00052             "learners",
00053             'l',
00054             "The learners to combine.",
00055             new ClassOption("learner", ' ', "", Classifier.class,
00056             "trees.HoeffdingTree"),
00057             new Option[]{
00058                 new ClassOption("", ' ', "", Classifier.class,
00059                 "trees.HoeffdingTree -l MC"),
00060                 new ClassOption("", ' ', "", Classifier.class,
00061                 "trees.HoeffdingTree -l NB"),
00062                 new ClassOption("", ' ', "", Classifier.class,
00063                 "trees.HoeffdingTree -l NBAdaptive"),
00064                 new ClassOption("", ' ', "", Classifier.class, "bayes.NaiveBayes")},
00065             ',');
00066 
00067     public FloatOption betaOption = new FloatOption("beta", 'b',
00068             "Factor to punish mistakes by.", 0.9, 0.0, 1.0);
00069 
00070     public FloatOption gammaOption = new FloatOption("gamma", 'g',
00071             "Minimum fraction of weight per model.", 0.01, 0.0, 0.5);
00072 
00073     public FlagOption pruneOption = new FlagOption("prune", 'p',
00074             "Prune poorly performing models from ensemble.");
00075 
00076     protected Classifier[] ensemble;
00077 
00078     protected double[] ensembleWeights;
00079 
00080     @Override
00081     public void prepareForUseImpl(TaskMonitor monitor,
00082             ObjectRepository repository) {
00083         Option[] learnerOptions = this.learnerListOption.getList();
00084         this.ensemble = new Classifier[learnerOptions.length];
00085         for (int i = 0; i < learnerOptions.length; i++) {
00086             monitor.setCurrentActivity("Materializing learner " + (i + 1)
00087                     + "...", -1.0);
00088             this.ensemble[i] = (Classifier) ((ClassOption) learnerOptions[i]).materializeObject(monitor, repository);
00089             if (monitor.taskShouldAbort()) {
00090                 return;
00091             }
00092             monitor.setCurrentActivity("Preparing learner " + (i + 1) + "...",
00093                     -1.0);
00094             this.ensemble[i].prepareForUse(monitor, repository);
00095             if (monitor.taskShouldAbort()) {
00096                 return;
00097             }
00098         }
00099         super.prepareForUseImpl(monitor, repository);
00100     }
00101 
00102     @Override
00103     public void resetLearningImpl() {
00104         this.ensembleWeights = new double[this.ensemble.length];
00105         for (int i = 0; i < this.ensemble.length; i++) {
00106             this.ensemble[i].resetLearning();
00107             this.ensembleWeights[i] = 1.0;
00108         }
00109     }
00110 
00111     @Override
00112     public void trainOnInstanceImpl(Instance inst) {
00113         double totalWeight = 0.0;
00114         for (int i = 0; i < this.ensemble.length; i++) {
00115             boolean prune = false;
00116             if (!this.ensemble[i].correctlyClassifies(inst)) {
00117                 if (this.ensembleWeights[i] > this.gammaOption.getValue()
00118                         / this.ensembleWeights.length) {
00119                     this.ensembleWeights[i] *= this.betaOption.getValue()
00120                             * inst.weight();
00121                 } else if (this.pruneOption.isSet()) {
00122                     prune = true;
00123                     discardModel(i);
00124                     i--;
00125                 }
00126             }
00127             if (!prune) {
00128                 totalWeight += this.ensembleWeights[i];
00129                 this.ensemble[i].trainOnInstance(inst);
00130             }
00131         }
00132         // normalize weights
00133         for (int i = 0; i < this.ensembleWeights.length; i++) {
00134             this.ensembleWeights[i] /= totalWeight;
00135         }
00136     }
00137 
00138     public double[] getVotesForInstance(Instance inst) {
00139         DoubleVector combinedVote = new DoubleVector();
00140         if (this.trainingWeightSeenByModel > 0.0) {
00141             for (int i = 0; i < this.ensemble.length; i++) {
00142                 if (this.ensembleWeights[i] > 0.0) {
00143                     DoubleVector vote = new DoubleVector(this.ensemble[i].getVotesForInstance(inst));
00144                     if (vote.sumOfValues() > 0.0) {
00145                         vote.normalize();
00146                         vote.scaleValues(this.ensembleWeights[i]);
00147                         combinedVote.addValues(vote);
00148                     }
00149                 }
00150             }
00151         }
00152         return combinedVote.getArrayRef();
00153     }
00154 
00155     @Override
00156     public void getModelDescription(StringBuilder out, int indent) {
00157         // TODO Auto-generated method stub
00158     }
00159 
00160     @Override
00161     protected Measurement[] getModelMeasurementsImpl() {
00162         Measurement[] measurements = null;
00163         if (this.ensembleWeights != null) {
00164             measurements = new Measurement[this.ensembleWeights.length];
00165             for (int i = 0; i < this.ensembleWeights.length; i++) {
00166                 measurements[i] = new Measurement("member weight " + (i + 1),
00167                         this.ensembleWeights[i]);
00168             }
00169         }
00170         return measurements;
00171     }
00172 
00173     @Override
00174     public boolean isRandomizable() {
00175         return false;
00176     }
00177 
00178     @Override
00179     public Classifier[] getSubClassifiers() {
00180         return this.ensemble.clone();
00181     }
00182 
00183     public void discardModel(int index) {
00184         Classifier[] newEnsemble = new Classifier[this.ensemble.length - 1];
00185         double[] newEnsembleWeights = new double[newEnsemble.length];
00186         int oldPos = 0;
00187         for (int i = 0; i < newEnsemble.length; i++) {
00188             if (oldPos == index) {
00189                 oldPos++;
00190             }
00191             newEnsemble[i] = this.ensemble[oldPos];
00192             newEnsembleWeights[i] = this.ensembleWeights[oldPos];
00193             oldPos++;
00194         }
00195         this.ensemble = newEnsemble;
00196         this.ensembleWeights = newEnsembleWeights;
00197     }
00198 
00199     protected int removePoorestModelBytes() {
00200         int poorestIndex = Utils.minIndex(this.ensembleWeights);
00201         int byteSize = this.ensemble[poorestIndex].measureByteSize();
00202         discardModel(poorestIndex);
00203         return byteSize;
00204     }
00205 }
 All Classes Namespaces Files Functions Variables Enumerations