MOA 12.03
Real Time Analytics for Data Streams
OzaBagASHT.java
Go to the documentation of this file.
00001 /*
00002  *    OzaBagASHT.java
00003  *    Copyright (C) 2008 University of Waikato, Hamilton, New Zealand
00004  *    @author Albert Bifet (abifet at cs dot waikato dot ac dot nz)
00005  *
00006  *    This program is free software; you can redistribute it and/or modify
00007  *    it under the terms of the GNU General Public License as published by
00008  *    the Free Software Foundation; either version 3 of the License, or
00009  *    (at your option) any later version.
00010  *
00011  *    This program is distributed in the hope that it will be useful,
00012  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
00013  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00014  *    GNU General Public License for more details.
00015  *
00016  *    You should have received a copy of the GNU General Public License
00017  *    along with this program. If not, see <http://www.gnu.org/licenses/>.
00018  *    
00019  */
00020 package moa.classifiers.meta;
00021 
00022 import moa.classifiers.Classifier;
00023 import moa.classifiers.trees.ASHoeffdingTree;
00024 import moa.core.DoubleVector;
00025 import moa.core.MiscUtils;
00026 import moa.options.IntOption;
00027 import moa.options.FlagOption;
00028 import weka.core.Instance;
00029 import weka.core.Utils;
00030 
00085 public class OzaBagASHT extends OzaBag {
00086 
00087     private static final long serialVersionUID = 1L;
00088 
00089     @Override
00090     public String getPurposeString() {
00091         return "Bagging using trees of different size.";
00092     }
00093     
00094     public IntOption firstClassifierSizeOption = new IntOption("firstClassifierSize", 'f',
00095             "The size of first classifier in the bag.", 1, 1, Integer.MAX_VALUE);
00096 
00097     public FlagOption useWeightOption = new FlagOption("useWeight",
00098             'u', "Enable weight classifiers.");
00099 
00100     public FlagOption resetTreesOption = new FlagOption("resetTrees",
00101             'r', "Reset trees when size is higher than the max.");
00102 
00103     protected double[] error;
00104 
00105     protected double alpha = 0.01;
00106 
00107     @Override
00108     public void resetLearningImpl() {
00109         this.ensemble = new Classifier[this.ensembleSizeOption.getValue()];
00110         this.error = new double[this.ensembleSizeOption.getValue()];
00111         Classifier baseLearner = (Classifier) getPreparedClassOption(this.baseLearnerOption);
00112         baseLearner.resetLearning();
00113         int pow = this.firstClassifierSizeOption.getValue(); //EXTENSION TO ASHT
00114         for (int i = 0; i < this.ensemble.length; i++) {
00115             this.ensemble[i] = baseLearner.copy();
00116             this.error[i] = 0.0;
00117             ((ASHoeffdingTree) this.ensemble[i]).setMaxSize(pow); //EXTENSION TO ASHT
00118             if ((this.resetTreesOption != null)
00119                     && this.resetTreesOption.isSet()) {
00120                 ((ASHoeffdingTree) this.ensemble[i]).setResetTree();
00121             }
00122             pow *= 2; //EXTENSION TO ASHT
00123         }
00124     }
00125 
00126     @Override
00127     public void trainOnInstanceImpl(Instance inst) {
00128         int trueClass = (int) inst.classValue();
00129         for (int i = 0; i < this.ensemble.length; i++) {
00130             int k = MiscUtils.poisson(1.0, this.classifierRandom);
00131             if (k > 0) {
00132                 Instance weightedInst = (Instance) inst.copy();
00133                 weightedInst.setWeight(inst.weight() * k);
00134                 if (Utils.maxIndex(this.ensemble[i].getVotesForInstance(inst)) == trueClass) {
00135                     this.error[i] += alpha * (0.0 - this.error[i]); //EWMA
00136                 } else {
00137                     this.error[i] += alpha * (1.0 - this.error[i]); //EWMA
00138                 }
00139                 this.ensemble[i].trainOnInstance(weightedInst);
00140             }
00141         }
00142     }
00143 
00144     public double[] getVotesForInstance(Instance inst) {
00145         DoubleVector combinedVote = new DoubleVector();
00146         for (int i = 0; i < this.ensemble.length; i++) {
00147             DoubleVector vote = new DoubleVector(this.ensemble[i].getVotesForInstance(inst));
00148             if (vote.sumOfValues() > 0.0) {
00149                 vote.normalize();
00150                 if ((this.useWeightOption != null)
00151                         && this.useWeightOption.isSet()) {
00152                     vote.scaleValues(1.0 / (this.error[i] * this.error[i]));
00153                 }
00154                 combinedVote.addValues(vote);
00155             }
00156         }
00157         return combinedVote.getArrayRef();
00158     }
00159 
00160     @Override
00161     public void getModelDescription(StringBuilder out, int indent) {
00162         // TODO Auto-generated method stub
00163     }
00164 }
 All Classes Namespaces Files Functions Variables Enumerations