MOA 12.03
Real Time Analytics for Data Streams
OzaBoost.java
Go to the documentation of this file.
00001 /*
00002  *    OzaBoost.java
00003  *    Copyright (C) 2007 University of Waikato, Hamilton, New Zealand
00004  *    @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
00005  *
00006  *    This program is free software; you can redistribute it and/or modify
00007  *    it under the terms of the GNU General Public License as published by
00008  *    the Free Software Foundation; either version 3 of the License, or
00009  *    (at your option) any later version.
00010  *
00011  *    This program is distributed in the hope that it will be useful,
00012  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
00013  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00014  *    GNU General Public License for more details.
00015  *
00016  *    You should have received a copy of the GNU General Public License
00017  *    along with this program. If not, see <http://www.gnu.org/licenses/>.
00018  *    
00019  */
00020 package moa.classifiers.meta;
00021 
00022 import moa.classifiers.AbstractClassifier;
00023 import moa.classifiers.Classifier;
00024 import weka.core.Instance;
00025 
00026 import moa.core.DoubleVector;
00027 import moa.core.Measurement;
00028 import moa.core.MiscUtils;
00029 import moa.options.ClassOption;
00030 import moa.options.FlagOption;
00031 import moa.options.IntOption;
00032 
00054 public class OzaBoost extends AbstractClassifier {
00055 
00056     private static final long serialVersionUID = 1L;
00057 
00058     @Override
00059     public String getPurposeString() {
00060         return "Incremental on-line boosting of Oza and Russell.";
00061     }
00062 
00063     public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l',
00064             "Classifier to train.", Classifier.class, "trees.HoeffdingTree");
00065 
00066     public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's',
00067             "The number of models to boost.", 10, 1, Integer.MAX_VALUE);
00068 
00069     public FlagOption pureBoostOption = new FlagOption("pureBoost", 'p',
00070             "Boost with weights only; no poisson.");
00071 
00072     protected Classifier[] ensemble;
00073 
00074     protected double[] scms;
00075 
00076     protected double[] swms;
00077 
00078     @Override
00079     public void resetLearningImpl() {
00080         this.ensemble = new Classifier[this.ensembleSizeOption.getValue()];
00081         Classifier baseLearner = (Classifier) getPreparedClassOption(this.baseLearnerOption);
00082         baseLearner.resetLearning();
00083         for (int i = 0; i < this.ensemble.length; i++) {
00084             this.ensemble[i] = baseLearner.copy();
00085         }
00086         this.scms = new double[this.ensemble.length];
00087         this.swms = new double[this.ensemble.length];
00088     }
00089 
00090     @Override
00091     public void trainOnInstanceImpl(Instance inst) {
00092         double lambda_d = 1.0;
00093         for (int i = 0; i < this.ensemble.length; i++) {
00094             double k = this.pureBoostOption.isSet() ? lambda_d : MiscUtils.poisson(lambda_d, this.classifierRandom);
00095             if (k > 0.0) {
00096                 Instance weightedInst = (Instance) inst.copy();
00097                 weightedInst.setWeight(inst.weight() * k);
00098                 this.ensemble[i].trainOnInstance(weightedInst);
00099             }
00100             if (this.ensemble[i].correctlyClassifies(inst)) {
00101                 this.scms[i] += lambda_d;
00102                 lambda_d *= this.trainingWeightSeenByModel / (2 * this.scms[i]);
00103             } else {
00104                 this.swms[i] += lambda_d;
00105                 lambda_d *= this.trainingWeightSeenByModel / (2 * this.swms[i]);
00106             }
00107         }
00108     }
00109 
00110     protected double getEnsembleMemberWeight(int i) {
00111         double em = this.swms[i] / (this.scms[i] + this.swms[i]);
00112         if ((em == 0.0) || (em > 0.5)) {
00113             return 0.0;
00114         }
00115         double Bm = em / (1.0 - em);
00116         return Math.log(1.0 / Bm);
00117     }
00118 
00119     public double[] getVotesForInstance(Instance inst) {
00120         DoubleVector combinedVote = new DoubleVector();
00121         for (int i = 0; i < this.ensemble.length; i++) {
00122             double memberWeight = getEnsembleMemberWeight(i);
00123             if (memberWeight > 0.0) {
00124                 DoubleVector vote = new DoubleVector(this.ensemble[i].getVotesForInstance(inst));
00125                 if (vote.sumOfValues() > 0.0) {
00126                     vote.normalize();
00127                     vote.scaleValues(memberWeight);
00128                     combinedVote.addValues(vote);
00129                 }
00130             } else {
00131                 break;
00132             }
00133         }
00134         return combinedVote.getArrayRef();
00135     }
00136 
00137     public boolean isRandomizable() {
00138         return true;
00139     }
00140 
00141     @Override
00142     public void getModelDescription(StringBuilder out, int indent) {
00143         // TODO Auto-generated method stub
00144     }
00145 
00146     @Override
00147     protected Measurement[] getModelMeasurementsImpl() {
00148         return new Measurement[]{new Measurement("ensemble size",
00149                     this.ensemble != null ? this.ensemble.length : 0)};
00150     }
00151 
00152     @Override
00153     public Classifier[] getSubClassifiers() {
00154         return this.ensemble.clone();
00155     }
00156 }
 All Classes Namespaces Files Functions Variables Enumerations