MOA 12.03
Real Time Analytics for Data Streams
|
00001 /* 00002 * LeveragingBag.java 00003 * Copyright (C) 2010 University of Waikato, Hamilton, New Zealand 00004 * @author Albert Bifet (abifet at cs dot waikato dot ac dot nz) 00005 * 00006 * This program is free software; you can redistribute it and/or modify 00007 * it under the terms of the GNU General Public License as published by 00008 * the Free Software Foundation; either version 3 of the License, or 00009 * (at your option) any later version. 00010 * 00011 * This program is distributed in the hope that it will be useful, 00012 * but WITHOUT ANY WARRANTY; without even the implied warranty of 00013 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 00014 * GNU General Public License for more details. 00015 * 00016 * You should have received a copy of the GNU General Public License 00017 * along with this program. If not, see <http://www.gnu.org/licenses/>. 00018 * 00019 */ 00020 package moa.classifiers.meta; 00021 00022 import moa.classifiers.core.driftdetection.ADWIN; 00023 import moa.classifiers.AbstractClassifier; 00024 import moa.classifiers.Classifier; 00025 import weka.core.Instance; 00026 00027 import moa.core.DoubleVector; 00028 import moa.core.Measurement; 00029 import moa.core.MiscUtils; 00030 import moa.options.*; 00031 00043 public class LeveragingBag extends AbstractClassifier { 00044 00045 private static final long serialVersionUID = 1L; 00046 00047 @Override 00048 public String getPurposeString() { 00049 return "Leveraging Bagging for evolving data streams using ADWIN."; 00050 } 00051 00052 public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', 00053 "Classifier to train.", Classifier.class, "trees.HoeffdingTree"); 00054 00055 public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's', 00056 "The number of models in the bag.", 10, 1, Integer.MAX_VALUE); 00057 00058 public FloatOption weightShrinkOption = new FloatOption("weightShrink", 'w', 00059 "The number to use to compute the weight of new instances.", 6, 0.0, Float.MAX_VALUE); 00060 00061 public FloatOption deltaAdwinOption = new FloatOption("deltaAdwin", 'a', 00062 "Delta of Adwin change detection", 0.002, 0.0, 1.0); 00063 00064 // Leveraging Bagging MC: uses this option to use Output Codes 00065 public FlagOption outputCodesOption = new FlagOption("outputCodes", 'o', 00066 "Use Output Codes to use binary classifiers."); 00067 00068 public MultiChoiceOption leveraginBagAlgorithmOption = new MultiChoiceOption( 00069 "leveraginBagAlgorithm", 'm', "Leveraging Bagging to use.", new String[]{ 00070 "LeveragingBag", "LeveragingBagME", "LeveragingBagHalf", "LeveragingBagWT", "LeveragingSubag"}, 00071 new String[]{"Leveraging Bagging for evolving data streams using ADWIN", 00072 "Leveraging Bagging ME using weight 1 if misclassified, otherwise error/(1-error)", 00073 "Leveraging Bagging Half using resampling without replacement half of the instances", 00074 "Leveraging Bagging WT without taking out all instances.", 00075 "Leveraging Subagging using resampling without replacement." 00076 }, 0); 00077 00078 protected Classifier[] ensemble; 00079 00080 protected ADWIN[] ADError; 00081 00082 protected int numberOfChangesDetected; 00083 00084 protected int[][] matrixCodes; 00085 00086 protected boolean initMatrixCodes = false; 00087 00088 @Override 00089 public void resetLearningImpl() { 00090 this.ensemble = new Classifier[this.ensembleSizeOption.getValue()]; 00091 Classifier baseLearner = (Classifier) getPreparedClassOption(this.baseLearnerOption); 00092 baseLearner.resetLearning(); 00093 for (int i = 0; i < this.ensemble.length; i++) { 00094 this.ensemble[i] = baseLearner.copy(); 00095 } 00096 this.ADError = new ADWIN[this.ensemble.length]; 00097 for (int i = 0; i < this.ensemble.length; i++) { 00098 this.ADError[i] = new ADWIN((double) this.deltaAdwinOption.getValue()); 00099 } 00100 this.numberOfChangesDetected = 0; 00101 if (this.outputCodesOption.isSet()) { 00102 this.initMatrixCodes = true; 00103 } 00104 } 00105 00106 @Override 00107 public void trainOnInstanceImpl(Instance inst) { 00108 int numClasses = inst.numClasses(); 00109 //Output Codes 00110 if (this.initMatrixCodes == true) { 00111 this.matrixCodes = new int[this.ensemble.length][inst.numClasses()]; 00112 for (int i = 0; i < this.ensemble.length; i++) { 00113 int numberOnes; 00114 int numberZeros; 00115 00116 do { // until we have the same number of zeros and ones 00117 numberOnes = 0; 00118 numberZeros = 0; 00119 for (int j = 0; j < numClasses; j++) { 00120 int result = 0; 00121 if (j == 1 && numClasses == 2) { 00122 result = 1 - this.matrixCodes[i][0]; 00123 } else { 00124 result = (this.classifierRandom.nextBoolean() ? 1 : 0); 00125 } 00126 this.matrixCodes[i][j] = result; 00127 if (result == 1) { 00128 numberOnes++; 00129 } else { 00130 numberZeros++; 00131 } 00132 } 00133 } while ((numberOnes - numberZeros) * (numberOnes - numberZeros) > (this.ensemble.length % 2)); 00134 00135 } 00136 this.initMatrixCodes = false; 00137 } 00138 00139 00140 boolean Change = false; 00141 Instance weightedInst = (Instance) inst.copy(); 00142 double w = this.weightShrinkOption.getValue(); 00143 00144 //Train ensemble of classifiers 00145 for (int i = 0; i < this.ensemble.length; i++) { 00146 double k = 0.0; 00147 switch (this.leveraginBagAlgorithmOption.getChosenIndex()) { 00148 case 0: //LeveragingBag 00149 k = MiscUtils.poisson(w, this.classifierRandom); 00150 break; 00151 case 1: //LeveragingBagME 00152 double error = this.ADError[i].getEstimation(); 00153 k = !this.ensemble[i].correctlyClassifies(weightedInst) ? 1.0 : (this.classifierRandom.nextDouble() < (error / (1.0 - error)) ? 1.0 : 0.0); 00154 break; 00155 case 2: //LeveragingBagHalf 00156 w = 1.0; 00157 k = this.classifierRandom.nextBoolean() ? 0.0 : w; 00158 break; 00159 case 3: //LeveragingBagWT 00160 w = 1.0; 00161 k = 1.0 + MiscUtils.poisson(w, this.classifierRandom); 00162 break; 00163 case 4: //LeveragingSubag 00164 w = 1.0; 00165 k = MiscUtils.poisson(1, this.classifierRandom); 00166 k = (k > 0) ? w : 0; 00167 break; 00168 } 00169 if (k > 0) { 00170 if (this.outputCodesOption.isSet()) { 00171 weightedInst.setClassValue((double) this.matrixCodes[i][(int) inst.classValue()]); 00172 } 00173 weightedInst.setWeight(inst.weight() * k); 00174 this.ensemble[i].trainOnInstance(weightedInst); 00175 } 00176 boolean correctlyClassifies = this.ensemble[i].correctlyClassifies(weightedInst); 00177 double ErrEstim = this.ADError[i].getEstimation(); 00178 if (this.ADError[i].setInput(correctlyClassifies ? 0 : 1)) { 00179 if (this.ADError[i].getEstimation() > ErrEstim) { 00180 Change = true; 00181 } 00182 } 00183 } 00184 if (Change) { 00185 numberOfChangesDetected++; 00186 double max = 0.0; 00187 int imax = -1; 00188 for (int i = 0; i < this.ensemble.length; i++) { 00189 if (max < this.ADError[i].getEstimation()) { 00190 max = this.ADError[i].getEstimation(); 00191 imax = i; 00192 } 00193 } 00194 if (imax != -1) { 00195 this.ensemble[imax].resetLearning(); 00196 //this.ensemble[imax].trainOnInstance(inst); 00197 this.ADError[imax] = new ADWIN((double) this.deltaAdwinOption.getValue()); 00198 } 00199 } 00200 } 00201 00202 @Override 00203 public double[] getVotesForInstance(Instance inst) { 00204 if (this.outputCodesOption.isSet()) { 00205 return getVotesForInstanceBinary(inst); 00206 } 00207 DoubleVector combinedVote = new DoubleVector(); 00208 for (int i = 0; i < this.ensemble.length; i++) { 00209 DoubleVector vote = new DoubleVector(this.ensemble[i].getVotesForInstance(inst)); 00210 if (vote.sumOfValues() > 0.0) { 00211 vote.normalize(); 00212 combinedVote.addValues(vote); 00213 } 00214 } 00215 return combinedVote.getArrayRef(); 00216 } 00217 00218 public double[] getVotesForInstanceBinary(Instance inst) { 00219 double combinedVote[] = new double[(int) inst.numClasses()]; 00220 Instance weightedInst = (Instance) inst.copy(); 00221 if (this.initMatrixCodes == false) { 00222 for (int i = 0; i < this.ensemble.length; i++) { 00223 //Replace class by OC 00224 weightedInst.setClassValue((double) this.matrixCodes[i][(int) inst.classValue()]); 00225 00226 double vote[]; 00227 vote = this.ensemble[i].getVotesForInstance(weightedInst); 00228 //Binary Case 00229 int voteClass = 0; 00230 if (vote.length == 2) { 00231 voteClass = (vote[1] > vote[0] ? 1 : 0); 00232 } 00233 //Update votes 00234 for (int j = 0; j < inst.numClasses(); j++) { 00235 if (this.matrixCodes[i][j] == voteClass) { 00236 combinedVote[j] += 1; 00237 } 00238 } 00239 } 00240 } 00241 return combinedVote; 00242 } 00243 00244 @Override 00245 public boolean isRandomizable() { 00246 return true; 00247 } 00248 00249 @Override 00250 public void getModelDescription(StringBuilder out, int indent) { 00251 // TODO Auto-generated method stub 00252 } 00253 00254 @Override 00255 protected Measurement[] getModelMeasurementsImpl() { 00256 return new Measurement[]{new Measurement("ensemble size", 00257 this.ensemble != null ? this.ensemble.length : 0), 00258 new Measurement("change detections", this.numberOfChangesDetected) 00259 }; 00260 } 00261 00262 @Override 00263 public Classifier[] getSubClassifiers() { 00264 return this.ensemble.clone(); 00265 } 00266 } 00267