MOA 12.03
Real Time Analytics for Data Streams
OzaBagAdwin.java
Go to the documentation of this file.
00001 /*
00002  *    OzaBagAdwin.java
00003  *    Copyright (C) 2008 University of Waikato, Hamilton, New Zealand
00004  *    @author Albert Bifet (abifet at cs dot waikato dot ac dot nz)
00005  *
00006  *    This program is free software; you can redistribute it and/or modify
00007  *    it under the terms of the GNU General Public License as published by
00008  *    the Free Software Foundation; either version 3 of the License, or
00009  *    (at your option) any later version.
00010  *
00011  *    This program is distributed in the hope that it will be useful,
00012  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
00013  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00014  *    GNU General Public License for more details.
00015  *
00016  *    You should have received a copy of the GNU General Public License
00017  *    along with this program. If not, see <http://www.gnu.org/licenses/>.
00018  *    
00019  */
00020 package moa.classifiers.meta;
00021 
00022 import moa.classifiers.core.driftdetection.ADWIN;
00023 import moa.classifiers.AbstractClassifier;
00024 import moa.classifiers.Classifier;
00025 import weka.core.Instance;
00026 
00027 import moa.core.DoubleVector;
00028 import moa.core.Measurement;
00029 import moa.core.MiscUtils;
00030 import moa.options.ClassOption;
00031 import moa.options.IntOption;
00032 
00082 public class OzaBagAdwin extends AbstractClassifier {
00083 
00084     private static final long serialVersionUID = 1L;
00085 
00086     @Override
00087     public String getPurposeString() {
00088         return "Bagging for evolving data streams using ADWIN.";
00089     }    
00090     
00091     public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l',
00092             "Classifier to train.", Classifier.class, "trees.HoeffdingTree");
00093 
00094     public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's',
00095             "The number of models in the bag.", 10, 1, Integer.MAX_VALUE);
00096 
00097     protected Classifier[] ensemble;
00098 
00099     protected ADWIN[] ADError;
00100 
00101     @Override
00102     public void resetLearningImpl() {
00103         this.ensemble = new Classifier[this.ensembleSizeOption.getValue()];
00104         Classifier baseLearner = (Classifier) getPreparedClassOption(this.baseLearnerOption);
00105         baseLearner.resetLearning();
00106         for (int i = 0; i < this.ensemble.length; i++) {
00107             this.ensemble[i] = baseLearner.copy();
00108         }
00109         this.ADError = new ADWIN[this.ensemble.length];
00110         for (int i = 0; i < this.ensemble.length; i++) {
00111             this.ADError[i] = new ADWIN();
00112         }
00113     }
00114 
00115     @Override
00116     public void trainOnInstanceImpl(Instance inst) {
00117         boolean Change = false;
00118         for (int i = 0; i < this.ensemble.length; i++) {
00119             int k = MiscUtils.poisson(1.0, this.classifierRandom);
00120             if (k > 0) {
00121                 Instance weightedInst = (Instance) inst.copy();
00122                 weightedInst.setWeight(inst.weight() * k);
00123                 this.ensemble[i].trainOnInstance(weightedInst);
00124             }
00125             boolean correctlyClassifies = this.ensemble[i].correctlyClassifies(inst);
00126             double ErrEstim = this.ADError[i].getEstimation();
00127             if (this.ADError[i].setInput(correctlyClassifies ? 0 : 1)) {
00128                 if (this.ADError[i].getEstimation() > ErrEstim) {
00129                     Change = true;
00130                 }
00131             }
00132         }
00133         if (Change) {
00134             double max = 0.0;
00135             int imax = -1;
00136             for (int i = 0; i < this.ensemble.length; i++) {
00137                 if (max < this.ADError[i].getEstimation()) {
00138                     max = this.ADError[i].getEstimation();
00139                     imax = i;
00140                 }
00141             }
00142             if (imax != -1) {
00143                 this.ensemble[imax].resetLearning();
00144                 //this.ensemble[imax].trainOnInstance(inst);
00145                 this.ADError[imax] = new ADWIN();
00146             }
00147         }
00148     }
00149 
00150     @Override
00151     public double[] getVotesForInstance(Instance inst) {
00152         DoubleVector combinedVote = new DoubleVector();
00153         for (int i = 0; i < this.ensemble.length; i++) {
00154             DoubleVector vote = new DoubleVector(this.ensemble[i].getVotesForInstance(inst));
00155             if (vote.sumOfValues() > 0.0) {
00156                 vote.normalize();
00157                 combinedVote.addValues(vote);
00158             }
00159         }
00160         return combinedVote.getArrayRef();
00161     }
00162 
00163     @Override
00164     public boolean isRandomizable() {
00165         return true;
00166     }
00167 
00168     @Override
00169     public void getModelDescription(StringBuilder out, int indent) {
00170         // TODO Auto-generated method stub
00171     }
00172 
00173     @Override
00174     protected Measurement[] getModelMeasurementsImpl() {
00175         return new Measurement[]{new Measurement("ensemble size",
00176                     this.ensemble != null ? this.ensemble.length : 0)};
00177     }
00178 
00179     @Override
00180     public Classifier[] getSubClassifiers() {
00181         return this.ensemble.clone();
00182     }
00183 }
 All Classes Namespaces Files Functions Variables Enumerations