MOA 12.03
Real Time Analytics for Data Streams
OzaBoostAdwin.java
Go to the documentation of this file.
00001 /*
00002  *    OzaBoostAdwin.java
00003  *    Copyright (C) 2010 University of Waikato, Hamilton, New Zealand
00004  *    @author Albert Bifet (abifet at cs dot waikato dot ac dot nz)
00005  *
00006  *    This program is free software; you can redistribute it and/or modify
00007  *    it under the terms of the GNU General Public License as published by
00008  *    the Free Software Foundation; either version 3 of the License, or
00009  *    (at your option) any later version.
00010  *
00011  *    This program is distributed in the hope that it will be useful,
00012  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
00013  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00014  *    GNU General Public License for more details.
00015  *
00016  *    You should have received a copy of the GNU General Public License
00017  *    along with this program. If not, see <http://www.gnu.org/licenses/>.
00018  *    
00019  */
00020 package moa.classifiers.meta;
00021 
00022 import moa.classifiers.core.driftdetection.ADWIN;
00023 import moa.classifiers.AbstractClassifier;
00024 import moa.classifiers.Classifier;
00025 import moa.core.DoubleVector;
00026 import moa.core.Measurement;
00027 import moa.core.MiscUtils;
00028 import moa.options.ClassOption;
00029 import moa.options.FlagOption;
00030 import moa.options.FloatOption;
00031 import moa.options.IntOption;
00032 import weka.core.Instance;
00033 
00040 public class OzaBoostAdwin extends AbstractClassifier {
00041 
00042     private static final long serialVersionUID = 1L;
00043 
00044     @Override
00045     public String getPurposeString() {
00046         return "Boosting for evolving data streams using ADWIN.";
00047     }
00048         
00049     public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l',
00050             "Classifier to train.", Classifier.class, "trees.HoeffdingTree");
00051 
00052     public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's',
00053             "The number of models to boost.", 10, 1, Integer.MAX_VALUE);
00054 
00055     public FlagOption pureBoostOption = new FlagOption("pureBoost", 'p',
00056             "Boost with weights only; no poisson.");
00057 
00058     public FloatOption deltaAdwinOption = new FloatOption("deltaAdwin", 'a',
00059             "Delta of Adwin change detection", 0.002, 0.0, 1.0);
00060 
00061     public FlagOption outputCodesOption = new FlagOption("outputCodes", 'o',
00062             "Use Output Codes to use binary classifiers.");
00063 
00064     public FlagOption sammeOption = new FlagOption("same", 'e',
00065             "Use Samme Algorithm.");
00066 
00067     protected Classifier[] ensemble;
00068 
00069     protected double[] scms;
00070 
00071     protected double[] swms;
00072 
00073     protected ADWIN[] ADError;
00074 
00075     protected int numberOfChangesDetected;
00076 
00077     protected int[][] matrixCodes;
00078 
00079     protected boolean initMatrixCodes = false;
00080 
00081     protected double logKm1 = 0.0;
00082 
00083     protected int Km1 = 1;
00084 
00085     protected boolean initKm1 = false;
00086 
00087     @Override
00088     public void resetLearningImpl() {
00089         this.ensemble = new Classifier[this.ensembleSizeOption.getValue()];
00090         Classifier baseLearner = (Classifier) getPreparedClassOption(this.baseLearnerOption);
00091         baseLearner.resetLearning();
00092         for (int i = 0; i < this.ensemble.length; i++) {
00093             this.ensemble[i] = baseLearner.copy();
00094         }
00095         this.scms = new double[this.ensemble.length];
00096         this.swms = new double[this.ensemble.length];
00097         this.ADError = new ADWIN[this.ensemble.length];
00098         for (int i = 0; i < this.ensemble.length; i++) {
00099             this.ADError[i] = new ADWIN((double) this.deltaAdwinOption.getValue());
00100         }
00101         this.numberOfChangesDetected = 0;
00102         if (this.outputCodesOption.isSet()) {
00103             this.initMatrixCodes = true;
00104         }
00105         if (this.sammeOption.isSet()) {
00106             this.initKm1 = true;
00107         }
00108 
00109     }
00110 
00111     @Override
00112     public void trainOnInstanceImpl(Instance inst) {
00113         int numClasses = inst.numClasses();
00114         // Set log (k-1) and (k-1) for SAMME Method
00115         if (this.sammeOption.isSet()) {
00116             this.Km1 = numClasses - 1;
00117             this.logKm1 = Math.log(this.Km1);
00118             this.initKm1 = false;
00119         }
00120         //Output Codes
00121         if (this.initMatrixCodes == true) {
00122 
00123             this.matrixCodes = new int[this.ensemble.length][inst.numClasses()];
00124             for (int i = 0; i < this.ensemble.length; i++) {
00125                 int numberOnes;
00126                 int numberZeros;
00127 
00128                 do { // until we have the same number of zeros and ones
00129                     numberOnes = 0;
00130                     numberZeros = 0;
00131                     for (int j = 0; j < numClasses; j++) {
00132                         int result = 0;
00133                         if (j == 1 && numClasses == 2) {
00134                             result = 1 - this.matrixCodes[i][0];
00135                         } else {
00136                             result = (this.classifierRandom.nextBoolean() ? 1 : 0);
00137                         }
00138                         this.matrixCodes[i][j] = result;
00139                         if (result == 1) {
00140                             numberOnes++;
00141                         } else {
00142                             numberZeros++;
00143                         }
00144                     }
00145                 } while ((numberOnes - numberZeros) * (numberOnes - numberZeros) > (this.ensemble.length % 2));
00146 
00147             }
00148             this.initMatrixCodes = false;
00149         }
00150 
00151 
00152         boolean Change = false;
00153         double lambda_d = 1.0;
00154         Instance weightedInst = (Instance) inst.copy();
00155         for (int i = 0; i < this.ensemble.length; i++) {
00156             double k = this.pureBoostOption.isSet() ? lambda_d : MiscUtils.poisson(lambda_d * this.Km1, this.classifierRandom);
00157             if (k > 0.0) {
00158                 if (this.outputCodesOption.isSet()) {
00159                     weightedInst.setClassValue((double) this.matrixCodes[i][(int) inst.classValue()]);
00160                 }
00161                 weightedInst.setWeight(inst.weight() * k);
00162                 this.ensemble[i].trainOnInstance(weightedInst);
00163             }
00164             boolean correctlyClassifies = this.ensemble[i].correctlyClassifies(weightedInst);
00165             if (correctlyClassifies) {
00166                 this.scms[i] += lambda_d;
00167                 lambda_d *= this.trainingWeightSeenByModel / (2 * this.scms[i]);
00168             } else {
00169                 this.swms[i] += lambda_d;
00170                 lambda_d *= this.trainingWeightSeenByModel / (2 * this.swms[i]);
00171             }
00172 
00173             double ErrEstim = this.ADError[i].getEstimation();
00174             if (this.ADError[i].setInput(correctlyClassifies ? 0 : 1)) {
00175                 if (this.ADError[i].getEstimation() > ErrEstim) {
00176                     Change = true;
00177                 }
00178             }
00179         }
00180         if (Change) {
00181             numberOfChangesDetected++;
00182             double max = 0.0;
00183             int imax = -1;
00184             for (int i = 0; i < this.ensemble.length; i++) {
00185                 if (max < this.ADError[i].getEstimation()) {
00186                     max = this.ADError[i].getEstimation();
00187                     imax = i;
00188                 }
00189             }
00190             if (imax != -1) {
00191                 this.ensemble[imax].resetLearning();
00192                 //this.ensemble[imax].trainOnInstance(inst);
00193                 this.ADError[imax] = new ADWIN((double) this.deltaAdwinOption.getValue());
00194                 this.scms[imax] = 0;
00195                 this.swms[imax] = 0;
00196             }
00197         }
00198     }
00199 
00200     protected double getEnsembleMemberWeight(int i) {
00201         double em = this.swms[i] / (this.scms[i] + this.swms[i]);
00202         if ((em == 0.0) || (em > 0.5)) {
00203             return this.logKm1;
00204         }
00205         return Math.log((1.0 - em) / em) + this.logKm1;
00206     }
00207 
00208     @Override
00209     public double[] getVotesForInstance(Instance inst) {
00210         if (this.outputCodesOption.isSet()) {
00211             return getVotesForInstanceBinary(inst);
00212         }
00213         DoubleVector combinedVote = new DoubleVector();
00214         for (int i = 0; i < this.ensemble.length; i++) {
00215             double memberWeight = getEnsembleMemberWeight(i);
00216             if (memberWeight > 0.0) {
00217                 DoubleVector vote = new DoubleVector(this.ensemble[i].getVotesForInstance(inst));
00218                 if (vote.sumOfValues() > 0.0) {
00219                     vote.normalize();
00220                     vote.scaleValues(memberWeight);
00221                     combinedVote.addValues(vote);
00222                 }
00223             } else {
00224                 break;
00225             }
00226         }
00227         return combinedVote.getArrayRef();
00228     }
00229 
00230     public double[] getVotesForInstanceBinary(Instance inst) {
00231         double combinedVote[] = new double[(int) inst.numClasses()];
00232         Instance weightedInst = (Instance) inst.copy();
00233         if (this.initMatrixCodes == false) {
00234             for (int i = 0; i < this.ensemble.length; i++) {
00235                 //Replace class by OC
00236                 weightedInst.setClassValue((double) this.matrixCodes[i][(int) inst.classValue()]);
00237 
00238                 double vote[];
00239                 vote = this.ensemble[i].getVotesForInstance(weightedInst);
00240                 //Binary Case
00241                 int voteClass = 0;
00242                 if (vote.length == 2) {
00243                     voteClass = (vote[1] > vote[0] ? 1 : 0);
00244                 }
00245                 //Update votes
00246                 for (int j = 0; j < inst.numClasses(); j++) {
00247                     if (this.matrixCodes[i][j] == voteClass) {
00248                         combinedVote[j] += getEnsembleMemberWeight(i);
00249                     }
00250                 }
00251             }
00252         }
00253         return combinedVote;
00254     }
00255 
00256     @Override
00257     public boolean isRandomizable() {
00258         return true;
00259     }
00260 
00261     @Override
00262     public void getModelDescription(StringBuilder out, int indent) {
00263         // TODO Auto-generated method stub
00264     }
00265 
00266     @Override
00267     protected Measurement[] getModelMeasurementsImpl() {
00268         return new Measurement[]{new Measurement("ensemble size",
00269                     this.ensemble != null ? this.ensemble.length : 0),
00270                     new Measurement("change detections", this.numberOfChangesDetected)
00271                 };
00272     }
00273 
00274     @Override
00275     public Classifier[] getSubClassifiers() {
00276         return this.ensemble.clone();
00277     }
00278 }
 All Classes Namespaces Files Functions Variables Enumerations