MOA 12.03
Real Time Analytics for Data Streams
|
00001 /* 00002 * OzaBoostAdwin.java 00003 * Copyright (C) 2010 University of Waikato, Hamilton, New Zealand 00004 * @author Albert Bifet (abifet at cs dot waikato dot ac dot nz) 00005 * 00006 * This program is free software; you can redistribute it and/or modify 00007 * it under the terms of the GNU General Public License as published by 00008 * the Free Software Foundation; either version 3 of the License, or 00009 * (at your option) any later version. 00010 * 00011 * This program is distributed in the hope that it will be useful, 00012 * but WITHOUT ANY WARRANTY; without even the implied warranty of 00013 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 00014 * GNU General Public License for more details. 00015 * 00016 * You should have received a copy of the GNU General Public License 00017 * along with this program. If not, see <http://www.gnu.org/licenses/>. 00018 * 00019 */ 00020 package moa.classifiers.meta; 00021 00022 import moa.classifiers.core.driftdetection.ADWIN; 00023 import moa.classifiers.AbstractClassifier; 00024 import moa.classifiers.Classifier; 00025 import moa.core.DoubleVector; 00026 import moa.core.Measurement; 00027 import moa.core.MiscUtils; 00028 import moa.options.ClassOption; 00029 import moa.options.FlagOption; 00030 import moa.options.FloatOption; 00031 import moa.options.IntOption; 00032 import weka.core.Instance; 00033 00040 public class OzaBoostAdwin extends AbstractClassifier { 00041 00042 private static final long serialVersionUID = 1L; 00043 00044 @Override 00045 public String getPurposeString() { 00046 return "Boosting for evolving data streams using ADWIN."; 00047 } 00048 00049 public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', 00050 "Classifier to train.", Classifier.class, "trees.HoeffdingTree"); 00051 00052 public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's', 00053 "The number of models to boost.", 10, 1, Integer.MAX_VALUE); 00054 00055 public FlagOption pureBoostOption = new FlagOption("pureBoost", 'p', 00056 "Boost with weights only; no poisson."); 00057 00058 public FloatOption deltaAdwinOption = new FloatOption("deltaAdwin", 'a', 00059 "Delta of Adwin change detection", 0.002, 0.0, 1.0); 00060 00061 public FlagOption outputCodesOption = new FlagOption("outputCodes", 'o', 00062 "Use Output Codes to use binary classifiers."); 00063 00064 public FlagOption sammeOption = new FlagOption("same", 'e', 00065 "Use Samme Algorithm."); 00066 00067 protected Classifier[] ensemble; 00068 00069 protected double[] scms; 00070 00071 protected double[] swms; 00072 00073 protected ADWIN[] ADError; 00074 00075 protected int numberOfChangesDetected; 00076 00077 protected int[][] matrixCodes; 00078 00079 protected boolean initMatrixCodes = false; 00080 00081 protected double logKm1 = 0.0; 00082 00083 protected int Km1 = 1; 00084 00085 protected boolean initKm1 = false; 00086 00087 @Override 00088 public void resetLearningImpl() { 00089 this.ensemble = new Classifier[this.ensembleSizeOption.getValue()]; 00090 Classifier baseLearner = (Classifier) getPreparedClassOption(this.baseLearnerOption); 00091 baseLearner.resetLearning(); 00092 for (int i = 0; i < this.ensemble.length; i++) { 00093 this.ensemble[i] = baseLearner.copy(); 00094 } 00095 this.scms = new double[this.ensemble.length]; 00096 this.swms = new double[this.ensemble.length]; 00097 this.ADError = new ADWIN[this.ensemble.length]; 00098 for (int i = 0; i < this.ensemble.length; i++) { 00099 this.ADError[i] = new ADWIN((double) this.deltaAdwinOption.getValue()); 00100 } 00101 this.numberOfChangesDetected = 0; 00102 if (this.outputCodesOption.isSet()) { 00103 this.initMatrixCodes = true; 00104 } 00105 if (this.sammeOption.isSet()) { 00106 this.initKm1 = true; 00107 } 00108 00109 } 00110 00111 @Override 00112 public void trainOnInstanceImpl(Instance inst) { 00113 int numClasses = inst.numClasses(); 00114 // Set log (k-1) and (k-1) for SAMME Method 00115 if (this.sammeOption.isSet()) { 00116 this.Km1 = numClasses - 1; 00117 this.logKm1 = Math.log(this.Km1); 00118 this.initKm1 = false; 00119 } 00120 //Output Codes 00121 if (this.initMatrixCodes == true) { 00122 00123 this.matrixCodes = new int[this.ensemble.length][inst.numClasses()]; 00124 for (int i = 0; i < this.ensemble.length; i++) { 00125 int numberOnes; 00126 int numberZeros; 00127 00128 do { // until we have the same number of zeros and ones 00129 numberOnes = 0; 00130 numberZeros = 0; 00131 for (int j = 0; j < numClasses; j++) { 00132 int result = 0; 00133 if (j == 1 && numClasses == 2) { 00134 result = 1 - this.matrixCodes[i][0]; 00135 } else { 00136 result = (this.classifierRandom.nextBoolean() ? 1 : 0); 00137 } 00138 this.matrixCodes[i][j] = result; 00139 if (result == 1) { 00140 numberOnes++; 00141 } else { 00142 numberZeros++; 00143 } 00144 } 00145 } while ((numberOnes - numberZeros) * (numberOnes - numberZeros) > (this.ensemble.length % 2)); 00146 00147 } 00148 this.initMatrixCodes = false; 00149 } 00150 00151 00152 boolean Change = false; 00153 double lambda_d = 1.0; 00154 Instance weightedInst = (Instance) inst.copy(); 00155 for (int i = 0; i < this.ensemble.length; i++) { 00156 double k = this.pureBoostOption.isSet() ? lambda_d : MiscUtils.poisson(lambda_d * this.Km1, this.classifierRandom); 00157 if (k > 0.0) { 00158 if (this.outputCodesOption.isSet()) { 00159 weightedInst.setClassValue((double) this.matrixCodes[i][(int) inst.classValue()]); 00160 } 00161 weightedInst.setWeight(inst.weight() * k); 00162 this.ensemble[i].trainOnInstance(weightedInst); 00163 } 00164 boolean correctlyClassifies = this.ensemble[i].correctlyClassifies(weightedInst); 00165 if (correctlyClassifies) { 00166 this.scms[i] += lambda_d; 00167 lambda_d *= this.trainingWeightSeenByModel / (2 * this.scms[i]); 00168 } else { 00169 this.swms[i] += lambda_d; 00170 lambda_d *= this.trainingWeightSeenByModel / (2 * this.swms[i]); 00171 } 00172 00173 double ErrEstim = this.ADError[i].getEstimation(); 00174 if (this.ADError[i].setInput(correctlyClassifies ? 0 : 1)) { 00175 if (this.ADError[i].getEstimation() > ErrEstim) { 00176 Change = true; 00177 } 00178 } 00179 } 00180 if (Change) { 00181 numberOfChangesDetected++; 00182 double max = 0.0; 00183 int imax = -1; 00184 for (int i = 0; i < this.ensemble.length; i++) { 00185 if (max < this.ADError[i].getEstimation()) { 00186 max = this.ADError[i].getEstimation(); 00187 imax = i; 00188 } 00189 } 00190 if (imax != -1) { 00191 this.ensemble[imax].resetLearning(); 00192 //this.ensemble[imax].trainOnInstance(inst); 00193 this.ADError[imax] = new ADWIN((double) this.deltaAdwinOption.getValue()); 00194 this.scms[imax] = 0; 00195 this.swms[imax] = 0; 00196 } 00197 } 00198 } 00199 00200 protected double getEnsembleMemberWeight(int i) { 00201 double em = this.swms[i] / (this.scms[i] + this.swms[i]); 00202 if ((em == 0.0) || (em > 0.5)) { 00203 return this.logKm1; 00204 } 00205 return Math.log((1.0 - em) / em) + this.logKm1; 00206 } 00207 00208 @Override 00209 public double[] getVotesForInstance(Instance inst) { 00210 if (this.outputCodesOption.isSet()) { 00211 return getVotesForInstanceBinary(inst); 00212 } 00213 DoubleVector combinedVote = new DoubleVector(); 00214 for (int i = 0; i < this.ensemble.length; i++) { 00215 double memberWeight = getEnsembleMemberWeight(i); 00216 if (memberWeight > 0.0) { 00217 DoubleVector vote = new DoubleVector(this.ensemble[i].getVotesForInstance(inst)); 00218 if (vote.sumOfValues() > 0.0) { 00219 vote.normalize(); 00220 vote.scaleValues(memberWeight); 00221 combinedVote.addValues(vote); 00222 } 00223 } else { 00224 break; 00225 } 00226 } 00227 return combinedVote.getArrayRef(); 00228 } 00229 00230 public double[] getVotesForInstanceBinary(Instance inst) { 00231 double combinedVote[] = new double[(int) inst.numClasses()]; 00232 Instance weightedInst = (Instance) inst.copy(); 00233 if (this.initMatrixCodes == false) { 00234 for (int i = 0; i < this.ensemble.length; i++) { 00235 //Replace class by OC 00236 weightedInst.setClassValue((double) this.matrixCodes[i][(int) inst.classValue()]); 00237 00238 double vote[]; 00239 vote = this.ensemble[i].getVotesForInstance(weightedInst); 00240 //Binary Case 00241 int voteClass = 0; 00242 if (vote.length == 2) { 00243 voteClass = (vote[1] > vote[0] ? 1 : 0); 00244 } 00245 //Update votes 00246 for (int j = 0; j < inst.numClasses(); j++) { 00247 if (this.matrixCodes[i][j] == voteClass) { 00248 combinedVote[j] += getEnsembleMemberWeight(i); 00249 } 00250 } 00251 } 00252 } 00253 return combinedVote; 00254 } 00255 00256 @Override 00257 public boolean isRandomizable() { 00258 return true; 00259 } 00260 00261 @Override 00262 public void getModelDescription(StringBuilder out, int indent) { 00263 // TODO Auto-generated method stub 00264 } 00265 00266 @Override 00267 protected Measurement[] getModelMeasurementsImpl() { 00268 return new Measurement[]{new Measurement("ensemble size", 00269 this.ensemble != null ? this.ensemble.length : 0), 00270 new Measurement("change detections", this.numberOfChangesDetected) 00271 }; 00272 } 00273 00274 @Override 00275 public Classifier[] getSubClassifiers() { 00276 return this.ensemble.clone(); 00277 } 00278 }