MOA 12.03
Real Time Analytics for Data Streams
LeveragingBag.java
Go to the documentation of this file.
00001 /*
00002  *    LeveragingBag.java
00003  *    Copyright (C) 2010 University of Waikato, Hamilton, New Zealand
00004  *    @author Albert Bifet (abifet at cs dot waikato dot ac dot nz)
00005  *
00006  *    This program is free software; you can redistribute it and/or modify
00007  *    it under the terms of the GNU General Public License as published by
00008  *    the Free Software Foundation; either version 3 of the License, or
00009  *    (at your option) any later version.
00010  *
00011  *    This program is distributed in the hope that it will be useful,
00012  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
00013  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00014  *    GNU General Public License for more details.
00015  *
00016  *    You should have received a copy of the GNU General Public License
00017  *    along with this program. If not, see <http://www.gnu.org/licenses/>.
00018  *    
00019  */
00020 package moa.classifiers.meta;
00021 
00022 import moa.classifiers.core.driftdetection.ADWIN;
00023 import moa.classifiers.AbstractClassifier;
00024 import moa.classifiers.Classifier;
00025 import weka.core.Instance;
00026 
00027 import moa.core.DoubleVector;
00028 import moa.core.Measurement;
00029 import moa.core.MiscUtils;
00030 import moa.options.*;
00031 
00043 public class LeveragingBag extends AbstractClassifier {
00044 
00045     private static final long serialVersionUID = 1L;
00046 
00047     @Override
00048     public String getPurposeString() {
00049         return "Leveraging Bagging for evolving data streams using ADWIN.";
00050     }
00051 
00052     public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l',
00053             "Classifier to train.", Classifier.class, "trees.HoeffdingTree");
00054 
00055     public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's',
00056             "The number of models in the bag.", 10, 1, Integer.MAX_VALUE);
00057 
00058     public FloatOption weightShrinkOption = new FloatOption("weightShrink", 'w',
00059             "The number to use to compute the weight of new instances.", 6, 0.0, Float.MAX_VALUE);
00060 
00061     public FloatOption deltaAdwinOption = new FloatOption("deltaAdwin", 'a',
00062             "Delta of Adwin change detection", 0.002, 0.0, 1.0);
00063 
00064     // Leveraging Bagging MC: uses this option to use Output Codes
00065     public FlagOption outputCodesOption = new FlagOption("outputCodes", 'o',
00066             "Use Output Codes to use binary classifiers.");
00067 
00068     public MultiChoiceOption leveraginBagAlgorithmOption = new MultiChoiceOption(
00069             "leveraginBagAlgorithm", 'm', "Leveraging Bagging to use.", new String[]{
00070                 "LeveragingBag", "LeveragingBagME", "LeveragingBagHalf", "LeveragingBagWT", "LeveragingSubag"},
00071             new String[]{"Leveraging Bagging for evolving data streams using ADWIN",
00072                 "Leveraging Bagging ME using weight 1 if misclassified, otherwise error/(1-error)",
00073                 "Leveraging Bagging Half using resampling without replacement half of the instances",
00074                 "Leveraging Bagging WT without taking out all instances.",
00075                 "Leveraging Subagging using resampling without replacement."
00076             }, 0);
00077 
00078     protected Classifier[] ensemble;
00079 
00080     protected ADWIN[] ADError;
00081 
00082     protected int numberOfChangesDetected;
00083 
00084     protected int[][] matrixCodes;
00085 
00086     protected boolean initMatrixCodes = false;
00087 
00088     @Override
00089     public void resetLearningImpl() {
00090         this.ensemble = new Classifier[this.ensembleSizeOption.getValue()];
00091         Classifier baseLearner = (Classifier) getPreparedClassOption(this.baseLearnerOption);
00092         baseLearner.resetLearning();
00093         for (int i = 0; i < this.ensemble.length; i++) {
00094             this.ensemble[i] = baseLearner.copy();
00095         }
00096         this.ADError = new ADWIN[this.ensemble.length];
00097         for (int i = 0; i < this.ensemble.length; i++) {
00098             this.ADError[i] = new ADWIN((double) this.deltaAdwinOption.getValue());
00099         }
00100         this.numberOfChangesDetected = 0;
00101         if (this.outputCodesOption.isSet()) {
00102             this.initMatrixCodes = true;
00103         }
00104     }
00105 
00106     @Override
00107     public void trainOnInstanceImpl(Instance inst) {
00108         int numClasses = inst.numClasses();
00109         //Output Codes
00110         if (this.initMatrixCodes == true) {
00111             this.matrixCodes = new int[this.ensemble.length][inst.numClasses()];
00112             for (int i = 0; i < this.ensemble.length; i++) {
00113                 int numberOnes;
00114                 int numberZeros;
00115 
00116                 do { // until we have the same number of zeros and ones
00117                     numberOnes = 0;
00118                     numberZeros = 0;
00119                     for (int j = 0; j < numClasses; j++) {
00120                         int result = 0;
00121                         if (j == 1 && numClasses == 2) {
00122                             result = 1 - this.matrixCodes[i][0];
00123                         } else {
00124                             result = (this.classifierRandom.nextBoolean() ? 1 : 0);
00125                         }
00126                         this.matrixCodes[i][j] = result;
00127                         if (result == 1) {
00128                             numberOnes++;
00129                         } else {
00130                             numberZeros++;
00131                         }
00132                     }
00133                 } while ((numberOnes - numberZeros) * (numberOnes - numberZeros) > (this.ensemble.length % 2));
00134 
00135             }
00136             this.initMatrixCodes = false;
00137         }
00138 
00139 
00140         boolean Change = false;
00141         Instance weightedInst = (Instance) inst.copy();
00142         double w = this.weightShrinkOption.getValue();
00143 
00144         //Train ensemble of classifiers
00145         for (int i = 0; i < this.ensemble.length; i++) {
00146             double k = 0.0;
00147             switch (this.leveraginBagAlgorithmOption.getChosenIndex()) {
00148                 case 0: //LeveragingBag
00149                     k = MiscUtils.poisson(w, this.classifierRandom);
00150                     break;
00151                 case 1: //LeveragingBagME
00152                     double error = this.ADError[i].getEstimation();
00153                     k = !this.ensemble[i].correctlyClassifies(weightedInst) ? 1.0 : (this.classifierRandom.nextDouble() < (error / (1.0 - error)) ? 1.0 : 0.0);
00154                     break;
00155                 case 2: //LeveragingBagHalf
00156                     w = 1.0;
00157                     k = this.classifierRandom.nextBoolean() ? 0.0 : w;
00158                     break;
00159                 case 3: //LeveragingBagWT
00160                     w = 1.0;
00161                     k = 1.0 + MiscUtils.poisson(w, this.classifierRandom);
00162                     break;
00163                 case 4: //LeveragingSubag
00164                     w = 1.0;
00165                     k = MiscUtils.poisson(1, this.classifierRandom);
00166                     k = (k > 0) ? w : 0;
00167                     break;
00168             }
00169             if (k > 0) {
00170                 if (this.outputCodesOption.isSet()) {
00171                     weightedInst.setClassValue((double) this.matrixCodes[i][(int) inst.classValue()]);
00172                 }
00173                 weightedInst.setWeight(inst.weight() * k);
00174                 this.ensemble[i].trainOnInstance(weightedInst);
00175             }
00176             boolean correctlyClassifies = this.ensemble[i].correctlyClassifies(weightedInst);
00177             double ErrEstim = this.ADError[i].getEstimation();
00178             if (this.ADError[i].setInput(correctlyClassifies ? 0 : 1)) {
00179                 if (this.ADError[i].getEstimation() > ErrEstim) {
00180                     Change = true;
00181                 }
00182             }
00183         }
00184         if (Change) {
00185             numberOfChangesDetected++;
00186             double max = 0.0;
00187             int imax = -1;
00188             for (int i = 0; i < this.ensemble.length; i++) {
00189                 if (max < this.ADError[i].getEstimation()) {
00190                     max = this.ADError[i].getEstimation();
00191                     imax = i;
00192                 }
00193             }
00194             if (imax != -1) {
00195                 this.ensemble[imax].resetLearning();
00196                 //this.ensemble[imax].trainOnInstance(inst);
00197                 this.ADError[imax] = new ADWIN((double) this.deltaAdwinOption.getValue());
00198             }
00199         }
00200     }
00201 
00202     @Override
00203     public double[] getVotesForInstance(Instance inst) {
00204         if (this.outputCodesOption.isSet()) {
00205             return getVotesForInstanceBinary(inst);
00206         }
00207         DoubleVector combinedVote = new DoubleVector();
00208         for (int i = 0; i < this.ensemble.length; i++) {
00209             DoubleVector vote = new DoubleVector(this.ensemble[i].getVotesForInstance(inst));
00210             if (vote.sumOfValues() > 0.0) {
00211                 vote.normalize();
00212                 combinedVote.addValues(vote);
00213             }
00214         }
00215         return combinedVote.getArrayRef();
00216     }
00217 
00218     public double[] getVotesForInstanceBinary(Instance inst) {
00219         double combinedVote[] = new double[(int) inst.numClasses()];
00220         Instance weightedInst = (Instance) inst.copy();
00221         if (this.initMatrixCodes == false) {
00222             for (int i = 0; i < this.ensemble.length; i++) {
00223                 //Replace class by OC
00224                 weightedInst.setClassValue((double) this.matrixCodes[i][(int) inst.classValue()]);
00225 
00226                 double vote[];
00227                 vote = this.ensemble[i].getVotesForInstance(weightedInst);
00228                 //Binary Case
00229                 int voteClass = 0;
00230                 if (vote.length == 2) {
00231                     voteClass = (vote[1] > vote[0] ? 1 : 0);
00232                 }
00233                 //Update votes
00234                 for (int j = 0; j < inst.numClasses(); j++) {
00235                     if (this.matrixCodes[i][j] == voteClass) {
00236                         combinedVote[j] += 1;
00237                     }
00238                 }
00239             }
00240         }
00241         return combinedVote;
00242     }
00243 
00244     @Override
00245     public boolean isRandomizable() {
00246         return true;
00247     }
00248 
00249     @Override
00250     public void getModelDescription(StringBuilder out, int indent) {
00251         // TODO Auto-generated method stub
00252     }
00253 
00254     @Override
00255     protected Measurement[] getModelMeasurementsImpl() {
00256         return new Measurement[]{new Measurement("ensemble size",
00257                     this.ensemble != null ? this.ensemble.length : 0),
00258                     new Measurement("change detections", this.numberOfChangesDetected)
00259                 };
00260     }
00261 
00262     @Override
00263     public Classifier[] getSubClassifiers() {
00264         return this.ensemble.clone();
00265     }
00266 }
00267 
 All Classes Namespaces Files Functions Variables Enumerations