MOA 12.03
Real Time Analytics for Data Streams
NaiveBayesMultinomial.java
Go to the documentation of this file.
00001 /*
00002  *    NaiveBayesMultinomial.java
00003  *    Copyright (C) 2011 University of Waikato, Hamilton, New Zealand
00004  *    @author Bernhard Pfahringer (bernhard@cs.waikato.ac.nz)
00005  *
00006  *    This program is free software; you can redistribute it and/or modify
00007  *    it under the terms of the GNU General Public License as published by
00008  *    the Free Software Foundation; either version 3 of the License, or
00009  *    (at your option) any later version.
00010  *
00011  *    This program is distributed in the hope that it will be useful,
00012  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
00013  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00014  *    GNU General Public License for more details.
00015  *
00016  *    You should have received a copy of the GNU General Public License
00017  *    along with this program. If not, see <http://www.gnu.org/licenses/>.
00018  *    
00019  */
00020 package moa.classifiers.bayes;
00021 
00022 import moa.core.Measurement;
00023 import moa.core.StringUtils;
00024 import moa.options.FloatOption;
00025 
00026 import weka.core.*;
00027 
00028 import java.util.*;
00029 import moa.classifiers.AbstractClassifier;
00030 
00054 public class NaiveBayesMultinomial extends AbstractClassifier {
00055 
00056     public FloatOption laplaceCorrectionOption = new FloatOption("laplaceCorrection",
00057             'l', "Laplace correction factor.",
00058             1.0, 0.00, Integer.MAX_VALUE);
00059 
00063     private static final long serialVersionUID = -7204398796974263187L;
00064 
00065     @Override
00066     public String getPurposeString() {
00067         return "Multinomial Naive Bayes classifier: performs classic bayesian prediction while making naive assumption that all inputs are independent.";
00068     }
00069 
00073     protected double[] m_classTotals;
00074 
00078     protected Instances m_headerInfo;
00079 
00083     protected int m_numClasses;
00084 
00088     protected double[] m_probOfClass;
00089 
00094     protected double[][] m_wordTotalForClass;
00095 
00096     protected boolean reset = false;
00097 
00098     @Override
00099     public void resetLearningImpl() {
00100         this.reset = true;
00101     }
00102 
00108     @Override
00109     public void trainOnInstanceImpl(Instance inst) {
00110         if (this.reset == true) {
00111             this.m_numClasses = inst.numClasses();
00112             double laplace = this.laplaceCorrectionOption.getValue();
00113             int numAttributes = inst.numAttributes();
00114 
00115             m_probOfClass = new double[m_numClasses];
00116             Arrays.fill(m_probOfClass, laplace);
00117 
00118             m_classTotals = new double[m_numClasses];
00119             Arrays.fill(m_classTotals, laplace * numAttributes);
00120 
00121             m_wordTotalForClass = new double[numAttributes][m_numClasses];
00122             for (double[] wordTotal : m_wordTotalForClass) {
00123                 Arrays.fill(wordTotal, laplace);
00124             }
00125             this.reset = false;
00126         }
00127         // Update classifier
00128         int classIndex = inst.classIndex();
00129         int classValue = (int) inst.value(classIndex);
00130 
00131         double w = inst.weight();
00132         m_probOfClass[classValue] += w;
00133 
00134         m_classTotals[classValue] += w * totalSize(inst);
00135         double total = m_classTotals[classValue];
00136 
00137         for (int i = 0; i < inst.numValues(); i++) {
00138             int index = inst.index(i);
00139             if (index != classIndex && !inst.isMissing(i)) {
00140                 m_wordTotalForClass[index][classValue] += w * inst.valueSparse(i);
00141             }
00142         }
00143     }
00144 
00152     @Override
00153     public double[] getVotesForInstance(Instance instance) {
00154         if (this.reset == true) {
00155             return new double[2];
00156         }
00157         double[] probOfClassGivenDoc = new double[m_numClasses];
00158         double totalSize = totalSize(instance);
00159 
00160         for (int i = 0; i < m_numClasses; i++) {
00161             probOfClassGivenDoc[i] = Math.log(m_probOfClass[i]) - totalSize * Math.log(m_classTotals[i]);
00162         }
00163 
00164         for (int i = 0; i < instance.numValues(); i++) {
00165 
00166             int index = instance.index(i);
00167             if (index == instance.classIndex() || instance.isMissing(i)) {
00168                 continue;
00169             }
00170 
00171             double wordCount = instance.valueSparse(i);
00172             for (int c = 0; c < m_numClasses; c++) {
00173                 probOfClassGivenDoc[c] += wordCount * Math.log(m_wordTotalForClass[index][c]);
00174             }
00175         }
00176 
00177         return Utils.logs2probs(probOfClassGivenDoc);
00178     }
00179 
00180     public double totalSize(Instance instance) {
00181         int classIndex = instance.classIndex();
00182         double total = 0.0;
00183         for (int i = 0; i < instance.numValues(); i++) {
00184             int index = instance.index(i);
00185             if (index == classIndex || instance.isMissing(i)) {
00186                 continue;
00187             }
00188             double count = instance.valueSparse(i);
00189             if (count >= 0) {
00190                 total += count;
00191             } else {
00192                 //throw new Exception("Numeric attribute value is not >= 0. " + i + " " + index + " " +
00193                 //                  instance.valueSparse(i) + " " + " " + instance);
00194             }
00195         }
00196         return total;
00197     }
00198 
00199     @Override
00200     protected Measurement[] getModelMeasurementsImpl() {
00201         return null;
00202     }
00203 
00204     @Override
00205     public void getModelDescription(StringBuilder result, int indent) {
00206         StringUtils.appendIndented(result, indent, "xxx MNB1 xxx\n\n");
00207 
00208         result.append("The independent probability of a class\n");
00209         result.append("--------------------------------------\n");
00210 
00211         for (int c = 0; c < m_numClasses; c++) {
00212             result.append(m_headerInfo.classAttribute().value(c)).append("\t").
00213                     append(Double.toString(m_probOfClass[c])).append("\n");
00214         }
00215 
00216         result.append("\nThe probability of a word given the class\n");
00217         result.append("-----------------------------------------\n\t");
00218 
00219         for (int c = 0; c < m_numClasses; c++) {
00220             result.append(m_headerInfo.classAttribute().value(c)).append("\t");
00221         }
00222 
00223         result.append("\n");
00224 
00225         for (int w = 0; w < m_headerInfo.numAttributes(); w++) {
00226             if (w == m_headerInfo.classIndex()) {
00227                 continue;
00228             }
00229             result.append(m_headerInfo.attribute(w).name()).append("\t");
00230             for (int c = 0; c < m_numClasses; c++) {
00231                 result.append(m_wordTotalForClass[w][c] / m_classTotals[c]).append("\t");
00232             }
00233             result.append("\n");
00234         }
00235         StringUtils.appendNewline(result);
00236     }
00237 
00238     @Override
00239     public boolean isRandomizable() {
00240         return false;
00241     }
00242 }
 All Classes Namespaces Files Functions Variables Enumerations