MOA 12.03
Real Time Analytics for Data Streams
EvaluatePeriodicHeldOutTest.java
Go to the documentation of this file.
00001 /*
00002  *    EvaluatePeriodicHeldOutTest.java
00003  *    Copyright (C) 2007 University of Waikato, Hamilton, New Zealand
00004  *    @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
00005  *    @author Ammar Shaker (shaker@mathematik.uni-marburg.de)
00006  *
00007  *    This program is free software; you can redistribute it and/or modify
00008  *    it under the terms of the GNU General Public License as published by
00009  *    the Free Software Foundation; either version 3 of the License, or
00010  *    (at your option) any later version.
00011  *
00012  *    This program is distributed in the hope that it will be useful,
00013  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
00014  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00015  *    GNU General Public License for more details.
00016  *
00017  *    You should have received a copy of the GNU General Public License
00018  *    along with this program. If not, see <http://www.gnu.org/licenses/>.
00019  *    
00020  */
00021 package moa.tasks;
00022 
00023 import java.io.File;
00024 import java.io.FileOutputStream;
00025 import java.io.PrintStream;
00026 import java.util.ArrayList;
00027 import java.util.List;
00028 
00029 import moa.classifiers.Classifier;
00030 import moa.core.Measurement;
00031 import moa.core.ObjectRepository;
00032 import moa.core.StringUtils;
00033 import moa.core.TimingUtils;
00034 import moa.evaluation.ClassificationPerformanceEvaluator;
00035 import moa.evaluation.LearningCurve;
00036 import moa.evaluation.LearningEvaluation;
00037 import moa.options.ClassOption;
00038 import moa.options.FileOption;
00039 import moa.options.FlagOption;
00040 import moa.options.IntOption;
00041 import moa.streams.CachedInstancesStream;
00042 import moa.streams.InstanceStream;
00043 import weka.core.Instance;
00044 import weka.core.Instances;
00045 
00052 public class EvaluatePeriodicHeldOutTest extends MainTask {
00053 
00054     @Override
00055     public String getPurposeString() {
00056         return "Evaluates a classifier on a stream by periodically testing on a heldout set.";
00057     }
00058 
00059     private static final long serialVersionUID = 1L;
00060 
00061     public ClassOption learnerOption = new ClassOption("learner", 'l',
00062             "Classifier to train.", Classifier.class, "trees.HoeffdingTree");
00063 
00064     public ClassOption streamOption = new ClassOption("stream", 's',
00065             "Stream to learn from.", InstanceStream.class,
00066             "generators.RandomTreeGenerator");
00067 
00068     public ClassOption evaluatorOption = new ClassOption("evaluator", 'e',
00069             "Classification performance evaluation method.",
00070             ClassificationPerformanceEvaluator.class,
00071             "BasicClassificationPerformanceEvaluator");
00072 
00073     public IntOption testSizeOption = new IntOption("testSize", 'n',
00074             "Number of testing examples.", 1000000, 0, Integer.MAX_VALUE);
00075 
00076     public IntOption trainSizeOption = new IntOption("trainSize", 'i',
00077             "Number of training examples, <1 = unlimited.", 0, 0,
00078             Integer.MAX_VALUE);
00079 
00080     public IntOption trainTimeOption = new IntOption("trainTime", 't',
00081             "Number of training seconds.", 10 * 60 * 60, 0, Integer.MAX_VALUE);
00082 
00083     public IntOption sampleFrequencyOption = new IntOption(
00084             "sampleFrequency",
00085             'f',
00086             "Number of training examples between samples of learning performance.",
00087             100000, 0, Integer.MAX_VALUE);
00088 
00089     public FileOption dumpFileOption = new FileOption("dumpFile", 'd',
00090             "File to append intermediate csv results to.", null, "csv", true);
00091 
00092     public FlagOption cacheTestOption = new FlagOption("cacheTest", 'c',
00093             "Cache test instances in memory.");
00094 
00095     @Override
00096     protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) {
00097         Classifier learner = (Classifier) getPreparedClassOption(this.learnerOption);
00098         InstanceStream stream = (InstanceStream) getPreparedClassOption(this.streamOption);
00099         ClassificationPerformanceEvaluator evaluator = (ClassificationPerformanceEvaluator) getPreparedClassOption(this.evaluatorOption);
00100         learner.setModelContext(stream.getHeader());
00101         long instancesProcessed = 0;
00102         LearningCurve learningCurve = new LearningCurve("evaluation instances");
00103         File dumpFile = this.dumpFileOption.getFile();
00104         PrintStream immediateResultStream = null;
00105         if (dumpFile != null) {
00106             try {
00107                 if (dumpFile.exists()) {
00108                     immediateResultStream = new PrintStream(
00109                             new FileOutputStream(dumpFile, true), true);
00110                 } else {
00111                     immediateResultStream = new PrintStream(
00112                             new FileOutputStream(dumpFile), true);
00113                 }
00114             } catch (Exception ex) {
00115                 throw new RuntimeException(
00116                         "Unable to open immediate result file: " + dumpFile, ex);
00117             }
00118         }
00119         boolean firstDump = true;
00120         InstanceStream testStream = null;
00121         int testSize = this.testSizeOption.getValue();
00122         if (this.cacheTestOption.isSet()) {
00123             monitor.setCurrentActivity("Caching test examples...", -1.0);
00124             Instances testInstances = new Instances(stream.getHeader(),
00125                     this.testSizeOption.getValue());
00126             while (testInstances.numInstances() < testSize) {
00127                 testInstances.add(stream.nextInstance());
00128                 if (testInstances.numInstances()
00129                         % INSTANCES_BETWEEN_MONITOR_UPDATES == 0) {
00130                     if (monitor.taskShouldAbort()) {
00131                         return null;
00132                     }
00133                     monitor.setCurrentActivityFractionComplete((double) testInstances.numInstances()
00134                             / (double) (this.testSizeOption.getValue()));
00135                 }
00136             }
00137             testStream = new CachedInstancesStream(testInstances);
00138         } else {
00139             //testStream = (InstanceStream) stream.copy();
00140             testStream = stream;
00141             /*monitor.setCurrentActivity("Skipping test examples...", -1.0);
00142             for (int i = 0; i < testSize; i++) {
00143             stream.nextInstance();
00144             }*/
00145         }
00146         instancesProcessed = 0;
00147         TimingUtils.enablePreciseTiming();
00148         double totalTrainTime = 0.0;
00149         while ((this.trainSizeOption.getValue() < 1
00150                 || instancesProcessed < this.trainSizeOption.getValue())
00151                 && stream.hasMoreInstances() == true) {
00152             monitor.setCurrentActivityDescription("Training...");
00153             long instancesTarget = instancesProcessed
00154                     + this.sampleFrequencyOption.getValue();
00155             long trainStartTime = TimingUtils.getNanoCPUTimeOfCurrentThread();
00156             while (instancesProcessed < instancesTarget && stream.hasMoreInstances() == true) {
00157                 learner.trainOnInstance(stream.nextInstance());
00158                 instancesProcessed++;
00159                 if (instancesProcessed % INSTANCES_BETWEEN_MONITOR_UPDATES == 0) {
00160                     if (monitor.taskShouldAbort()) {
00161                         return null;
00162                     }
00163                     monitor.setCurrentActivityFractionComplete((double) (instancesProcessed)
00164                             / (double) (this.trainSizeOption.getValue()));
00165                 }
00166             }
00167             double lastTrainTime = TimingUtils.nanoTimeToSeconds(TimingUtils.getNanoCPUTimeOfCurrentThread()
00168                     - trainStartTime);
00169             totalTrainTime += lastTrainTime;
00170             if (totalTrainTime > this.trainTimeOption.getValue()) {
00171                 break;
00172             }
00173             if (this.cacheTestOption.isSet()) {
00174                 testStream.restart();
00175             } 
00176             evaluator.reset();
00177             long testInstancesProcessed = 0;
00178             monitor.setCurrentActivityDescription("Testing (after "
00179                     + StringUtils.doubleToString(
00180                     ((double) (instancesProcessed)
00181                     / (double) (this.trainSizeOption.getValue()) * 100.0), 2)
00182                     + "% training)...");
00183             long testStartTime = TimingUtils.getNanoCPUTimeOfCurrentThread();
00184             int instCount = 0 ;
00185             for (instCount = 0; instCount < testSize; instCount++) {
00186                                 if (stream.hasMoreInstances() == false) {
00187                                         break;
00188                                 }
00189                 Instance testInst = (Instance) testStream.nextInstance().copy();
00190                 double trueClass = testInst.classValue();
00191                 testInst.setClassMissing();
00192                 double[] prediction = learner.getVotesForInstance(testInst);
00193                 testInst.setClassValue(trueClass);
00194                 evaluator.addResult(testInst, prediction);
00195                 testInstancesProcessed++;
00196                 if (testInstancesProcessed % INSTANCES_BETWEEN_MONITOR_UPDATES == 0) {
00197                     if (monitor.taskShouldAbort()) {
00198                         return null;
00199                     }
00200                     monitor.setCurrentActivityFractionComplete((double) testInstancesProcessed
00201                             / (double) (testSize));
00202                 }
00203             }
00204                 if ( instCount != testSize) {
00205                                 break;
00206                         }
00207             double testTime = TimingUtils.nanoTimeToSeconds(TimingUtils.getNanoCPUTimeOfCurrentThread()
00208                     - testStartTime);
00209             List<Measurement> measurements = new ArrayList<Measurement>();
00210             measurements.add(new Measurement("evaluation instances",                            
00211                     instancesProcessed));
00212             measurements.add(new Measurement("total train time", totalTrainTime));
00213             measurements.add(new Measurement("total train speed",
00214                     instancesProcessed / totalTrainTime));
00215             measurements.add(new Measurement("last train time", lastTrainTime));
00216             measurements.add(new Measurement("last train speed",
00217                     this.sampleFrequencyOption.getValue() / lastTrainTime));
00218             measurements.add(new Measurement("test time", testTime));
00219             measurements.add(new Measurement("test speed", this.testSizeOption.getValue()
00220                     / testTime));
00221             Measurement[] performanceMeasurements = evaluator.getPerformanceMeasurements();
00222             for (Measurement measurement : performanceMeasurements) {
00223                 measurements.add(measurement);
00224             }
00225             Measurement[] modelMeasurements = learner.getModelMeasurements();
00226             for (Measurement measurement : modelMeasurements) {
00227                 measurements.add(measurement);
00228             }
00229             learningCurve.insertEntry(new LearningEvaluation(measurements.toArray(new Measurement[measurements.size()])));
00230             if (immediateResultStream != null) {
00231                 if (firstDump) {
00232                     immediateResultStream.println(learningCurve.headerToString());
00233                     firstDump = false;
00234                 }
00235                 immediateResultStream.println(learningCurve.entryToString(learningCurve.numEntries() - 1));
00236                 immediateResultStream.flush();
00237             }
00238             if (monitor.resultPreviewRequested()) {
00239                 monitor.setLatestResultPreview(learningCurve.copy());
00240             }
00241             // if (learner instanceof HoeffdingTree
00242             // || learner instanceof HoeffdingOptionTree) {
00243             // int numActiveNodes = (int) Measurement.getMeasurementNamed(
00244             // "active learning leaves",
00245             // modelMeasurements).getValue();
00246             // // exit if tree frozen
00247             // if (numActiveNodes < 1) {
00248             // break;
00249             // }
00250             // int numNodes = (int) Measurement.getMeasurementNamed(
00251             // "tree size (nodes)", modelMeasurements)
00252             // .getValue();
00253             // if (numNodes == lastNumNodes) {
00254             // noGrowthCount++;
00255             // } else {
00256             // noGrowthCount = 0;
00257             // }
00258             // lastNumNodes = numNodes;
00259             // } else if (learner instanceof OzaBoost || learner instanceof
00260             // OzaBag) {
00261             // double numActiveNodes = Measurement.getMeasurementNamed(
00262             // "[avg] active learning leaves",
00263             // modelMeasurements).getValue();
00264             // // exit if all trees frozen
00265             // if (numActiveNodes == 0.0) {
00266             // break;
00267             // }
00268             // int numNodes = (int) (Measurement.getMeasurementNamed(
00269             // "[avg] tree size (nodes)",
00270             // learner.getModelMeasurements()).getValue() * Measurement
00271             // .getMeasurementNamed("ensemble size",
00272             // modelMeasurements).getValue());
00273             // if (numNodes == lastNumNodes) {
00274             // noGrowthCount++;
00275             // } else {
00276             // noGrowthCount = 0;
00277             // }
00278             // lastNumNodes = numNodes;
00279             // }
00280         }
00281         if (immediateResultStream != null) {
00282             immediateResultStream.close();
00283         }
00284         return learningCurve;
00285     }
00286 
00287     @Override
00288     public Class<?> getTaskResultType() {
00289         return LearningCurve.class;
00290     }
00291 }
 All Classes Namespaces Files Functions Variables Enumerations