MOA 12.03
Real Time Analytics for Data Streams
AbstractClassifier.java
Go to the documentation of this file.
00001 /*
00002  *    AbstractClassifier.java
00003  *    Copyright (C) 2007 University of Waikato, Hamilton, New Zealand
00004  *    @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
00005  *
00006  *    This program is free software; you can redistribute it and/or modify
00007  *    it under the terms of the GNU General Public License as published by
00008  *    the Free Software Foundation; either version 3 of the License, or
00009  *    (at your option) any later version.
00010  *
00011  *    This program is distributed in the hope that it will be useful,
00012  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
00013  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00014  *    GNU General Public License for more details.
00015  *
00016  *    You should have received a copy of the GNU General Public License
00017  *    along with this program. If not, see <http://www.gnu.org/licenses/>.
00018  *    
00019  */
00020 package moa.classifiers;
00021 
00022 import java.util.Arrays;
00023 import java.util.LinkedList;
00024 import java.util.List;
00025 import java.util.Random;
00026 import moa.core.InstancesHeader;
00027 import moa.core.Measurement;
00028 import moa.core.ObjectRepository;
00029 import moa.core.StringUtils;
00030 import moa.gui.AWTRenderer;
00031 import moa.options.AbstractOptionHandler;
00032 import moa.options.IntOption;
00033 import moa.tasks.TaskMonitor;
00034 import weka.core.Instance;
00035 import weka.core.Instances;
00036 import weka.core.Utils;
00037 
00045 public abstract class AbstractClassifier extends AbstractOptionHandler
00046         implements Classifier {
00047 
00048     @Override
00049     public String getPurposeString() {
00050         return "MOA Classifier: " + getClass().getCanonicalName();
00051     }
00052 
00054     protected InstancesHeader modelContext;
00055 
00057     protected double trainingWeightSeenByModel = 0.0;
00058 
00060     protected int randomSeed = 1;
00061 
00063     public IntOption randomSeedOption;
00064 
00066     public Random classifierRandom;
00067 
00072     public AbstractClassifier() {
00073         if (isRandomizable()) {
00074             this.randomSeedOption = new IntOption("randomSeed", 'r',
00075                     "Seed for random behaviour of the classifier.", 1);
00076         }
00077     }
00078 
00079     @Override
00080     public void prepareForUseImpl(TaskMonitor monitor,
00081             ObjectRepository repository) {
00082         if (this.randomSeedOption != null) {
00083             this.randomSeed = this.randomSeedOption.getValue();
00084         }
00085         if (!trainingHasStarted()) {
00086             resetLearning();
00087         }
00088     }
00089 
00090     @Override
00091     public void setModelContext(InstancesHeader ih) {
00092         if ((ih != null) && (ih.classIndex() < 0)) {
00093             throw new IllegalArgumentException(
00094                     "Context for a classifier must include a class to learn");
00095         }
00096         if (trainingHasStarted()
00097                 && (this.modelContext != null)
00098                 && ((ih == null) || !contextIsCompatible(this.modelContext, ih))) {
00099             throw new IllegalArgumentException(
00100                     "New context is not compatible with existing model");
00101         }
00102         this.modelContext = ih;
00103     }
00104 
00105     @Override
00106     public InstancesHeader getModelContext() {
00107         return this.modelContext;
00108     }
00109 
00110     @Override
00111     public void setRandomSeed(int s) {
00112         this.randomSeed = s;
00113         if (this.randomSeedOption != null) {
00114             // keep option consistent
00115             this.randomSeedOption.setValue(s);
00116         }
00117     }
00118 
00119     @Override
00120     public boolean trainingHasStarted() {
00121         return this.trainingWeightSeenByModel > 0.0;
00122     }
00123 
00124     @Override
00125     public double trainingWeightSeenByModel() {
00126         return this.trainingWeightSeenByModel;
00127     }
00128 
00129     @Override
00130     public void resetLearning() {
00131         this.trainingWeightSeenByModel = 0.0;
00132         if (isRandomizable()) {
00133             this.classifierRandom = new Random(this.randomSeed);
00134         }
00135         resetLearningImpl();
00136     }
00137 
00138     @Override
00139     public void trainOnInstance(Instance inst) {
00140         if (inst.weight() > 0.0) {
00141             this.trainingWeightSeenByModel += inst.weight();
00142             trainOnInstanceImpl(inst);
00143         }
00144     }
00145 
00146     @Override
00147     public Measurement[] getModelMeasurements() {
00148         List<Measurement> measurementList = new LinkedList<Measurement>();
00149         measurementList.add(new Measurement("model training instances",
00150                 trainingWeightSeenByModel()));
00151         measurementList.add(new Measurement("model serialized size (bytes)",
00152                 measureByteSize()));
00153         Measurement[] modelMeasurements = getModelMeasurementsImpl();
00154         if (modelMeasurements != null) {
00155             measurementList.addAll(Arrays.asList(modelMeasurements));
00156         }
00157         // add average of sub-model measurements
00158         Classifier[] subModels = getSubClassifiers();
00159         if ((subModels != null) && (subModels.length > 0)) {
00160             List<Measurement[]> subMeasurements = new LinkedList<Measurement[]>();
00161             for (Classifier subModel : subModels) {
00162                 if (subModel != null) {
00163                     subMeasurements.add(subModel.getModelMeasurements());
00164                 }
00165             }
00166             Measurement[] avgMeasurements = Measurement.averageMeasurements(subMeasurements.toArray(new Measurement[subMeasurements.size()][]));
00167             measurementList.addAll(Arrays.asList(avgMeasurements));
00168         }
00169         return measurementList.toArray(new Measurement[measurementList.size()]);
00170     }
00171 
00172     @Override
00173     public void getDescription(StringBuilder out, int indent) {
00174         StringUtils.appendIndented(out, indent, "Model type: ");
00175         out.append(this.getClass().getName());
00176         StringUtils.appendNewline(out);
00177         Measurement.getMeasurementsDescription(getModelMeasurements(), out,
00178                 indent);
00179         StringUtils.appendNewlineIndented(out, indent, "Model description:");
00180         StringUtils.appendNewline(out);
00181         if (trainingHasStarted()) {
00182             getModelDescription(out, indent);
00183         } else {
00184             StringUtils.appendIndented(out, indent,
00185                     "Model has not been trained.");
00186         }
00187     }
00188 
00189     @Override
00190     public Classifier[] getSubClassifiers() {
00191         return null;
00192     }
00193 
00194     @Override
00195     public Classifier copy() {
00196         return (Classifier) super.copy();
00197     }
00198 
00199     @Override
00200     public boolean correctlyClassifies(Instance inst) {
00201         return Utils.maxIndex(getVotesForInstance(inst)) == (int) inst.classValue();
00202     }
00203 
00209     public String getClassNameString() {
00210         return InstancesHeader.getClassNameString(this.modelContext);
00211     }
00212 
00219     public String getClassLabelString(int classLabelIndex) {
00220         return InstancesHeader.getClassLabelString(this.modelContext,
00221                 classLabelIndex);
00222     }
00223 
00230     public String getAttributeNameString(int attIndex) {
00231         return InstancesHeader.getAttributeNameString(this.modelContext,
00232                 attIndex);
00233     }
00234 
00242     public String getNominalValueString(int attIndex, int valIndex) {
00243         return InstancesHeader.getNominalValueString(this.modelContext,
00244                 attIndex, valIndex);
00245     }
00246 
00247 
00265     public static boolean contextIsCompatible(InstancesHeader originalContext,
00266             InstancesHeader newContext) {
00267 
00268         if (newContext.numClasses() < originalContext.numClasses()) {
00269             return false; // rule 1
00270         }
00271         if (newContext.numAttributes() < originalContext.numAttributes()) {
00272             return false; // rule 2
00273         }
00274         int oPos = 0;
00275         int nPos = 0;
00276         while (oPos < originalContext.numAttributes()) {
00277             if (oPos == originalContext.classIndex()) {
00278                 oPos++;
00279                 if (!(oPos < originalContext.numAttributes())) {
00280                     break;
00281                 }
00282             }
00283             if (nPos == newContext.classIndex()) {
00284                 nPos++;
00285             }
00286             if (originalContext.attribute(oPos).isNominal()) {
00287                 if (!newContext.attribute(nPos).isNominal()) {
00288                     return false; // rule 4
00289                 }
00290                 if (newContext.attribute(nPos).numValues() < originalContext.attribute(oPos).numValues()) {
00291                     return false; // rule 3
00292                 }
00293             } else {
00294                 assert (originalContext.attribute(oPos).isNumeric());
00295                 if (!newContext.attribute(nPos).isNumeric()) {
00296                     return false; // rule 4
00297                 }
00298             }
00299             oPos++;
00300             nPos++;
00301         }
00302         return true; // all checks clear
00303     }
00304 
00310     @Override
00311     public AWTRenderer getAWTRenderer() {
00312         // TODO should return a default renderer here
00313         // - or should null be interpreted as the default?
00314         return null;
00315     }
00316 
00317 
00326     public abstract void resetLearningImpl();
00327 
00337     public abstract void trainOnInstanceImpl(Instance inst);
00338 
00348     protected abstract Measurement[] getModelMeasurementsImpl();
00349 
00356     public abstract void getModelDescription(StringBuilder out, int indent);
00357 
00366     protected static int modelAttIndexToInstanceAttIndex(int index,
00367             Instance inst) {
00368         return inst.classIndex() > index ? index : index + 1;
00369     }
00370 
00379     protected static int modelAttIndexToInstanceAttIndex(int index,
00380             Instances insts) {
00381         return insts.classIndex() > index ? index : index + 1;
00382     }
00383 }
 All Classes Namespaces Files Functions Variables Enumerations