MOA 12.03
Real Time Analytics for Data Streams
LimAttClassifier.java
Go to the documentation of this file.
00001 /*
00002  *    LimAttClassifier.java
00003  *    Copyright (C) 2010 University of Waikato, Hamilton, New Zealand
00004  *    @author Albert Bifet (abifet at cs dot waikato dot ac dot nz)
00005  *    @author Eibe Frank (eibe{[at]}cs{[dot]}waikato{[dot]}ac{[dot]}nz)
00006  *
00007  *    This program is free software; you can redistribute it and/or modify
00008  *    it under the terms of the GNU General Public License as published by
00009  *    the Free Software Foundation; either version 3 of the License, or
00010  *    (at your option) any later version.
00011  *
00012  *    This program is distributed in the hope that it will be useful,
00013  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
00014  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00015  *    GNU General Public License for more details.
00016  *
00017  *    You should have received a copy of the GNU General Public License
00018  *    along with this program. If not, see <http://www.gnu.org/licenses/>.
00019  *    
00020  */
00021 package moa.classifiers.meta;
00022 
00023 import moa.classifiers.trees.LimAttHoeffdingTree;
00024 import weka.core.Instance;
00025 import weka.core.Utils;
00026 
00027 import java.math.BigInteger;
00028 import java.util.Arrays;
00029 import moa.classifiers.core.driftdetection.ADWIN;
00030 import moa.classifiers.AbstractClassifier;
00031 import moa.classifiers.Classifier;
00032 
00033 import moa.core.Measurement;
00034 import moa.options.ClassOption;
00035 import moa.options.FlagOption;
00036 import moa.options.FloatOption;
00037 import moa.options.IntOption;
00038 
00081 public class LimAttClassifier extends AbstractClassifier {
00082 
    /** One-line description of this learner, shown by the MOA CLI/GUI. */
    @Override
    public String getPurposeString() {
        return "Ensemble Combining Restricted Hoeffding Trees using Stacking";
    }
00087     
00088     /*
00089      * Class that generates all combinations of n elements, taken
00090      * r at a time. The algorithm is described by
00091      *
00092      * Kenneth H. Rosen, Discrete Mathematics and Its Applications,
00093      * 2nd edition (NY: McGraw-Hill, 1991), pp. 284-286.
00094      *
00095      *  @author Michael Gilleland (megilleland@yahoo.com)
00096      */
00097     public class CombinationGenerator {
00098 
00099         private int[] a;
00100 
00101         private int n;
00102 
00103         private int r;
00104 
00105         private BigInteger numLeft;
00106 
00107         private BigInteger total;
00108         //------------
00109         // Constructor
00110         //------------
00111 
00112         public CombinationGenerator(int n, int r) {
00113             if (r > n) {
00114                 throw new IllegalArgumentException();
00115             }
00116             if (n < 1) {
00117                 throw new IllegalArgumentException();
00118             }
00119             this.n = n;
00120             this.r = r;
00121             a = new int[r];
00122             BigInteger nFact = getFactorial(n);
00123             BigInteger rFact = getFactorial(r);
00124             BigInteger nminusrFact = getFactorial(n - r);
00125             total = nFact.divide(rFact.multiply(nminusrFact));
00126             reset();
00127         }
00128         //------
00129         // Reset
00130         //------
00131 
00132         public void reset() {
00133             for (int i = 0; i < a.length; i++) {
00134                 a[i] = i;
00135             }
00136             numLeft = new BigInteger(total.toString());
00137         }
00138         //------------------------------------------------
00139         // Return number of combinations not yet generated
00140         //------------------------------------------------
00141 
00142         public BigInteger getNumLeft() {
00143             return numLeft;
00144         }
00145         //-----------------------------
00146         // Are there more combinations?
00147         //-----------------------------
00148 
00149         public boolean hasMore() {
00150             return numLeft.compareTo(BigInteger.ZERO) == 1;
00151         }
00152         //------------------------------------
00153         // Return total number of combinations
00154         //------------------------------------
00155 
00156         public BigInteger getTotal() {
00157             return total;
00158         }
00159         //------------------
00160         // Compute factorial
00161         //------------------
00162 
00163         private BigInteger getFactorial(int n) {
00164             BigInteger fact = BigInteger.ONE;
00165             for (int i = n; i > 1; i--) {
00166                 fact = fact.multiply(new BigInteger(Integer.toString(i)));
00167             }
00168             return fact;
00169         }
00170         //--------------------------------------------------------
00171         // Generate next combination (algorithm from Rosen p. 286)
00172         //--------------------------------------------------------
00173 
00174         public int[] getNext() {
00175             if (numLeft.equals(total)) {
00176                 numLeft = numLeft.subtract(BigInteger.ONE);
00177                 int[] b = new int[a.length];
00178                 for (int k = 0; k < a.length; k++) {
00179                     b[k] = a[k];
00180                 }
00181                 return b;
00182             }
00183             int i = r - 1;
00184             while (a[i] == n - r + i) {
00185                 i--;
00186             }
00187             a[i] = a[i] + 1;
00188             for (int j = i + 1; j < r; j++) {
00189                 a[j] = a[i] + j - i;
00190             }
00191             numLeft = numLeft.subtract(BigInteger.ONE);
00192             int[] b = new int[a.length];
00193             for (int k = 0; k < a.length; k++) {
00194                 b[k] = a[k];
00195             }
00196             return b;
00197         }
00198     }
00199 
    private static final long serialVersionUID = 1L;

    /** Base learner cloned to create every ensemble member. */
    public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l',
            "Classifier to train.", Classifier.class, "trees.LimAttHoeffdingTree");

    /** Number of attributes (n) each restricted tree may use. */
    public IntOption numAttributesOption = new IntOption("numAttributes", 'n',
            "The number of attributes to use per model.", 1, 1, Integer.MAX_VALUE);

    /** Multiplier applied to the weight of misclassified counts. */
    public FloatOption weightShrinkOption = new FloatOption("weightShrink", 'w',
            "The number to multiply the weight misclassified counts.", 0.5, 0.0, Float.MAX_VALUE);

    /** Confidence parameter (delta) for the per-member ADWIN detectors. */
    public FloatOption deltaAdwinOption = new FloatOption("deltaAdwin", 'a',
            "Delta of Adwin change detection", 0.002, 0.0, 1.0);

    /** Smoothing offset so estimated class odds are never exactly zero. */
    public FloatOption oddsOffsetOption = new FloatOption("oddsOffset", 'o',
            "Offset for odds to avoid probabilities that are zero.", 0.001, 0.0, Float.MAX_VALUE);

    /** If set, only the best numEnsemblePruning members vote at prediction time. */
    public FlagOption pruneOption = new FlagOption("prune", 'x',
            "Enable pruning.");

    /** If set, each tree uses m-n attributes instead of n. */
    public FlagOption bigTreesOption = new FlagOption("bigTrees", 'b',
            "Use m-n attributes on the trees.");

    /** Number of members kept when pruning is enabled. */
    public IntOption numEnsemblePruningOption = new IntOption("numEnsemblePruning", 'm',
            "The pruned number of classifiers to use to predict.", 10, 1, Integer.MAX_VALUE);

    /** If set, a detected change resets only the worst member; otherwise every
     *  member whose own detector fires is reset. */
    public FlagOption adwinReplaceWorstClassifierOption = new FlagOption("adwinReplaceWorstClassifier", 'z',
            "When one Adwin detects change, replace worst classifier.");

    // Ensemble members, one per attribute combination (built lazily on the
    // first training instance).
    protected Classifier[] ensemble;

    // One ADWIN drift detector per ensemble member, fed its 0/1 error.
    protected ADWIN[] ADError;

    // Total ADWIN change detections so far (reported as a measurement).
    protected int numberOfChangesDetected;

    // NOTE(review): appears unused in this class — confirm before removing.
    protected int[][] matrixCodes;

    // NOTE(review): appears unused in this class — confirm before removing.
    protected boolean initMatrixCodes = false;

    // True until the ensemble is built on the first training instance.
    protected boolean initClassifiers = false;

    // Attributes per tree, derived from numAttributesOption (and bigTreesOption).
    protected int numberAttributes = 1;

    // Instance counter driving the perceptron's learning-rate decay.
    protected int numInstances = 0;
00244 
00245     @Override
00246     public void resetLearningImpl() {
00247         this.initClassifiers = true;
00248         this.reset = true;
00249     }
00250 
    /**
     * Trains the ensemble on one instance: lazily builds the ensemble,
     * converts member votes to log-odds for the stacking perceptron,
     * runs ADWIN-based change detection, then updates the perceptron
     * and every member tree.
     *
     * @param inst the training instance
     */
    @Override
    public void trainOnInstanceImpl(Instance inst) {
        int numClasses = inst.numClasses();
        //Init Ensemble: deferred to the first instance, when the number of
        //attributes is finally known.
        if (this.initClassifiers == true) {
            numberAttributes = numAttributesOption.getValue();
            if (bigTreesOption.isSet()) {
                // "Big trees": each tree uses all but n attributes
                // (class attribute excluded).
                numberAttributes = inst.numAttributes() - 1 - numAttributesOption.getValue();
            }
            // One member per subset of numberAttributes attributes, so the
            // ensemble size is C(numAttributes(inst) - 1, numberAttributes).
            CombinationGenerator x = new CombinationGenerator(inst.numAttributes() - 1, this.numberAttributes);
            int numberClassifiers = x.getTotal().intValue();
            this.ensemble = new Classifier[numberClassifiers];
            Classifier baseLearner = (Classifier) getPreparedClassOption(this.baseLearnerOption);
            baseLearner.resetLearning();
            for (int i = 0; i < this.ensemble.length; i++) {
                this.ensemble[i] = baseLearner.copy();
            }
            // One ADWIN drift detector per member.
            this.ADError = new ADWIN[this.ensemble.length];
            for (int i = 0; i < this.ensemble.length; i++) {
                this.ADError[i] = new ADWIN((double) this.deltaAdwinOption.getValue());
            }
            this.numberOfChangesDetected = 0;
            //Prepare combinations: assign each member its attribute subset.
            int i = 0;
            if (baseLearner instanceof LimAttHoeffdingTree) {
                while (x.hasMore()) {
                    ((LimAttHoeffdingTree) this.ensemble[i]).setlistAttributes(x.getNext());
                    i++;
                }
            }

            this.initClassifiers = false;
        }

        boolean Change = false;
        Instance weightedInst = (Instance) inst.copy();

        //Train Perceptron: turn each member's class distribution into
        //log-odds inputs for the stacking perceptron.
        double[][] votes = new double[this.ensemble.length + 1][numClasses];
        for (int i = 0; i < this.ensemble.length; i++) {
            // Start from the smoothing offset so no odds are exactly 0.
            double[] v = new double[numClasses];
            for (int j = 0; j < v.length; j++) {
                v[j] = (double) this.oddsOffsetOption.getValue();
            }
            double[] vt = this.ensemble[i].getVotesForInstance(inst);
            double sum = Utils.sum(vt);
            if (!Double.isNaN(sum) && (sum > 0)) {
                // Normalize the votes to a probability distribution.
                for (int j = 0; j < vt.length; j++) {
                    vt[j] /= sum;
                }
            } else {
                // Just in case the base learner returns NaN
                for (int k = 0; k < vt.length; k++) {
                    vt[k] = 0.0;
                }
            }
            sum = numClasses * (double) this.oddsOffsetOption.getValue();
            for (int j = 0; j < vt.length; j++) {
                v[j] += vt[j];
                sum += vt[j];
            }
            // Log-odds of each class with the offset folded in.
            for (int j = 0; j < vt.length; j++) {
                votes[i][j] = Math.log(v[j] / (sum - v[j]));
            }
        }

        if (adwinReplaceWorstClassifierOption.isSet() == false) {
            //Train ensemble of classifiers: reset every member whose own
            //ADWIN reports a significant error increase.
            for (int i = 0; i < this.ensemble.length; i++) {
                boolean correctlyClassifies = this.ensemble[i].correctlyClassifies(weightedInst);
                double ErrEstim = this.ADError[i].getEstimation();
                if (this.ADError[i].setInput(correctlyClassifies ? 0 : 1)) {
                    // Restart the perceptron's learning-rate decay so it can
                    // re-adapt after drift — presumably intentional; confirm.
                    numInstances = initialNumInstancesOption.getValue();
                    if (this.ADError[i].getEstimation() > ErrEstim) {
                        Change = true;
                        //Replace classifier if ADWIN has detected change
                        numberOfChangesDetected++;
                        this.ensemble[i].resetLearning();
                        this.ADError[i] = new ADWIN((double) this.deltaAdwinOption.getValue());
                        // Zero the perceptron weights of the replaced member.
                        for (int ii = 0; ii < inst.numClasses(); ii++) {
                            weightAttribute[ii][i] = 0.0;// 0.2 * Math.random() - 0.1;
                        }
                    }
                }
            }
        } else {
            //Train ensemble of classifiers: first only record whether any
            //member's error increased significantly.
            for (int i = 0; i < this.ensemble.length; i++) {
                boolean correctlyClassifies = this.ensemble[i].correctlyClassifies(weightedInst);
                double ErrEstim = this.ADError[i].getEstimation();
                if (this.ADError[i].setInput(correctlyClassifies ? 0 : 1)) {
                    if (this.ADError[i].getEstimation() > ErrEstim) {
                        Change = true;
                    }
                }
            }
            //Replace classifier with higher error if ADWIN has detected change
            if (Change) {
                numberOfChangesDetected++;
                double max = 0.0;
                int imax = -1;
                // Find the member with the largest estimated error rate.
                for (int i = 0; i < this.ensemble.length; i++) {
                    if (max < this.ADError[i].getEstimation()) {
                        max = this.ADError[i].getEstimation();
                        imax = i;
                    }
                }
                if (imax != -1) {
                    this.ensemble[imax].resetLearning();
                    this.ADError[imax] = new ADWIN((double) this.deltaAdwinOption.getValue());
                    // Zero the perceptron weights of the replaced member.
                    for (int ii = 0; ii < inst.numClasses(); ii++) {
                        weightAttribute[ii][imax] = 0.0;
                    }
                }
            }
        }

        // Update the stacking perceptron, then train every member tree.
        trainOnInstanceImplPerceptron(inst.numClasses(), (int) inst.classValue(), votes);

        for (int i = 0; i < this.ensemble.length; i++) {
            this.ensemble[i].trainOnInstance(inst);
        }
    }
00374 
00375     @Override
00376     public double[] getVotesForInstance(Instance inst) {
00377         if (this.initClassifiers == true) {
00378             return new double[0];
00379         }
00380         int numClasses = inst.numClasses();
00381 
00382         int sizeEnsemble = this.ensemble.length;
00383         if (pruneOption.isSet()) {
00384             sizeEnsemble = this.numEnsemblePruningOption.getValue();
00385         }
00386 
00387         double[][] votes = new double[sizeEnsemble + 1][numClasses];
00388         int[] bestClassifiers = new int[sizeEnsemble];
00389         if (pruneOption.isSet()) {
00390             //Check for the best classifiers
00391             double[] weight = new double[this.ensemble.length];
00392             for (int i = 0; i < numClasses; i++) {
00393                 for (int j = 0; j < this.ensemble.length; j++) {
00394                     weight[j] += weightAttribute[i][j];
00395                 }
00396             }
00397             Arrays.sort(weight);
00398             double cutValue = weight[this.ensemble.length - sizeEnsemble]; //reverse order
00399             int ii = 0;
00400             for (int j = 0; j < this.ensemble.length; j++) {
00401                 if (weight[j] >= cutValue && ii < sizeEnsemble) {
00402                     bestClassifiers[ii] = j;
00403                     ii++;
00404                 }
00405             }
00406         } else { //Not pruning: all classifiers
00407             for (int ii = 0; ii < sizeEnsemble; ii++) {
00408                 bestClassifiers[ii] = ii;
00409             }
00410         }
00411         for (int ii = 0; ii < sizeEnsemble; ii++) {
00412             int i = bestClassifiers[ii];
00413             double[] v = new double[numClasses];
00414             for (int j = 0; j < v.length; j++) {
00415                 v[j] = (double) this.oddsOffsetOption.getValue();
00416             }
00417             double[] vt = this.ensemble[i].getVotesForInstance(inst);
00418             double sum = Utils.sum(vt);
00419             if (!Double.isNaN(sum) && (sum > 0)) {
00420                 for (int j = 0; j < vt.length; j++) {
00421                     vt[j] /= sum;
00422                 }
00423             } else {
00424                 // Just in case the base learner returns NaN
00425                 for (int k = 0; k < vt.length; k++) {
00426                     vt[k] = 0.0;
00427                 }
00428             }
00429             sum = numClasses * (double) this.oddsOffsetOption.getValue();
00430             for (int j = 0; j < vt.length; j++) {
00431                 v[j] += vt[j];
00432                 sum += vt[j];
00433             }
00434             for (int j = 0; j < vt.length; j++) {
00435                 votes[ii][j] = Math.log(v[j] / (sum - v[j]));
00436                 //                    votes[i][j] = vt[j];
00437             }
00438         }
00439         return getVotesForInstancePerceptron(votes, bestClassifiers, inst.numClasses());
00440     }
00441 
    /** Declared randomizable so the framework may supply a random seed. */
    @Override
    public boolean isRandomizable() {
        return true;
    }
00446 
    /** Intentionally empty: no textual model description is produced. */
    @Override
    public void getModelDescription(StringBuilder out, int indent) {
        // TODO Auto-generated method stub
    }
00451 
00452     @Override
00453     protected Measurement[] getModelMeasurementsImpl() {
00454         return new Measurement[]{new Measurement("ensemble size",
00455                     this.ensemble != null ? this.ensemble.length : 0),
00456                     new Measurement("change detections", this.numberOfChangesDetected)
00457                 };
00458     }
00459 
00460     @Override
00461     public Classifier[] getSubClassifiers() {
00462         return this.ensemble.clone();
00463     }
00464 
    //Perceptron
    /** Learning rate for the stacking perceptron. */
    public FloatOption learningRatioOption = new FloatOption("learningRatio", 'r', "Learning ratio", 1);

    /** L2 penalty (lambda) applied to the perceptron weights. */
    public FloatOption penaltyFactorOption = new FloatOption("lambda", 'p', "Lambda", 0.0);

    /** Initial value of the instance counter used for learning-rate decay. */
    public IntOption initialNumInstancesOption = new IntOption("initialNumInstances", 'i', "initialNumInstances", 10);

    // Perceptron weights indexed [class][member], with one extra column per
    // class holding the bias term.
    protected double[][] weightAttribute;

    // True until the perceptron weights are (re)initialized.
    protected boolean reset;
00475 
    /**
     * Updates the stacking perceptron with one instance.
     *
     * @param numClasses  number of classes in the stream
     * @param actualClass index of the instance's true class
     * @param votes       log-odds inputs indexed [member][class]; the array
     *                    has one extra row matching the bias column
     */
    public void trainOnInstanceImplPerceptron(int numClasses, int actualClass, double[][] votes) {

        //Init Perceptron: uniform weights over the members; the bias column
        //(index votes.length - 1) is left at Java's default of 0.0.
        if (this.reset == true) {
            this.reset = false;
            this.weightAttribute = new double[numClasses][votes.length];
            for (int i = 0; i < numClasses; i++) {
                for (int j = 0; j < votes.length - 1; j++) {
                    weightAttribute[i][j] = 1.0 / (votes.length - 1.0);
                }
            }
            numInstances = initialNumInstancesOption.getValue();
        }

        // Weight decay: the effective learning rate shrinks as numInstances grows.
        double learningRatio = learningRatioOption.getValue() * 2.0 / (numInstances + (votes.length - 1) + 2.0);
        double lambda = penaltyFactorOption.getValue();
        numInstances++;

        double[] preds = new double[numClasses];

        // Current sigmoid outputs, one per class.
        for (int i = 0; i < numClasses; i++) {
            preds[i] = prediction(votes, i);
        }
        // Gradient step through the sigmoid:
        // delta = (target - output) * output * (1 - output).
        for (int i = 0; i < numClasses; i++) {
            double actual = (i == actualClass) ? 1.0 : 0.0;
            double delta = (actual - preds[i]) * preds[i] * (1 - preds[i]);
            for (int j = 0; j < this.ensemble.length; j++) {
                // L2-regularized update of the weight for member j, class i.
                this.weightAttribute[i][j] += learningRatio * (delta * votes[j][i] - lambda * this.weightAttribute[i][j]);
            }
            // Bias weight: input fixed at 1, no regularization applied.
            this.weightAttribute[i][this.ensemble.length] += learningRatio * delta;
        }
    }
00509 
00510     public double predictionPruning(double[][] votes, int[] bestClassifiers, int classVal) {
00511         double sum = 0.0;
00512         for (int i = 0; i < votes.length - 1; i++) {
00513             sum += (double) weightAttribute[classVal][bestClassifiers[i]] * votes[i][classVal];
00514         }
00515         sum += weightAttribute[classVal][votes.length - 1];
00516         return 1.0 / (1.0 + Math.exp(-sum));
00517     }
00518 
00519     public double prediction(double[][] votes, int classVal) {
00520         double sum = 0.0;
00521         for (int i = 0; i < votes.length - 1; i++) {
00522             sum += (double) weightAttribute[classVal][i] * votes[i][classVal];
00523         }
00524         sum += weightAttribute[classVal][votes.length - 1];
00525         return 1.0 / (1.0 + Math.exp(-sum));
00526     }
00527 
00528     public double[] getVotesForInstancePerceptron(double[][] votesEnsemble, int[] bestClassifiers, int numClasses) {
00529         double[] votes = new double[numClasses];
00530         if (this.reset == false) {
00531             for (int i = 0; i < votes.length; i++) {
00532                 votes[i] = predictionPruning(votesEnsemble, bestClassifiers, i);
00533             }
00534         }
00535         return votes;
00536 
00537     }
00538 }
 All Classes Namespaces Files Functions Variables Enumerations