MOA 12.03
Real Time Analytics for Data Streams
SPegasos.java
Go to the documentation of this file.
00001 /*
00002  *    SPegasos.java
00003  *    Copyright (C) 2009 University of Waikato, Hamilton, New Zealand
00004  *    @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
00005  *
00006  *    This program is free software; you can redistribute it and/or modify
00007  *    it under the terms of the GNU General Public License as published by
00008  *    the Free Software Foundation; either version 3 of the License, or
00009  *    (at your option) any later version.
00010  *
00011  *    This program is distributed in the hope that it will be useful,
00012  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
00013  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00014  *    GNU General Public License for more details.
00015  *
00016  *    You should have received a copy of the GNU General Public License
00017  *    along with this program. If not, see <http://www.gnu.org/licenses/>.
00018  *    
00019  */
00020 
00021 /*
00022  *    SPegasos.java
00023  *    Copyright (C) 2009 University of Waikato, Hamilton, New Zealand
00024  *
00025  */
00026 package moa.classifiers.functions;
00027 
00028 import moa.classifiers.AbstractClassifier;
00029 import moa.core.Measurement;
00030 import moa.core.StringUtils;
00031 import moa.options.FloatOption;
00032 import moa.options.MultiChoiceOption;
00033 import weka.core.Instance;
00034 import weka.core.Utils;
00035 
00059 public class SPegasos extends AbstractClassifier {
00060 
00064     private static final long serialVersionUID = -3732968666673530290L;
00065 
00066     @Override
00067     public String getPurposeString() {
00068         return "Stochastic variant of the Pegasos (Primal Estimated sub-GrAdient SOlver for SVM) method of Shalev-Shwartz et al. (2007).";
00069     }
00070 
00074     protected double m_lambda = 0.0001;
00075 
00076     public FloatOption lambdaRegularizationOption = new FloatOption("lambdaRegularization",
00077             'l', "Lambda regularization parameter .",
00078             0.0001, 0.00, Integer.MAX_VALUE);
00079 
00080     protected static final int HINGE = 0;
00081 
00082     protected static final int LOGLOSS = 1;
00083 
00087     protected int m_loss = HINGE;
00088 
00089     public MultiChoiceOption lossFunctionOption = new MultiChoiceOption(
00090             "lossFunction", 'o', "The loss function to use.", new String[]{
00091                 "HINGE", "LOGLOSS"}, new String[]{
00092                 "Hinge loss (SVM)",
00093                 "Log loss (logistic regression)"}, 0);
00094 
00098     protected double[] m_weights;
00099 
00103     protected double m_t;
00104 
00110     public void setLambda(double lambda) {
00111         m_lambda = lambda;
00112     }
00113 
00119     public double getLambda() {
00120         return m_lambda;
00121     }
00122 
00128     public void setLossFunction(int function) {
00129         m_loss = function;
00130     }
00131 
00137     public int getLossFunction() {
00138         return m_loss;
00139     }
00140 
00144     public void reset() {
00145         m_t = 2;
00146         m_weights = null;
00147     }
00148 
00149     protected static double dotProd(Instance inst1, double[] weights, int classIndex) {
00150         double result = 0;
00151 
00152         int n1 = inst1.numValues();
00153         int n2 = weights.length - 1;
00154 
00155         for (int p1 = 0, p2 = 0; p1 < n1 && p2 < n2;) {
00156             int ind1 = inst1.index(p1);
00157             int ind2 = p2;
00158             if (ind1 == ind2) {
00159                 if (ind1 != classIndex && !inst1.isMissingSparse(p1)) {
00160                     result += inst1.valueSparse(p1) * weights[p2];
00161                 }
00162                 p1++;
00163                 p2++;
00164             } else if (ind1 > ind2) {
00165                 p2++;
00166             } else {
00167                 p1++;
00168             }
00169         }
00170         return (result);
00171     }
00172 
00173     protected double dloss(double z) {
00174         if (m_loss == HINGE) {
00175             return (z < 1) ? 1 : 0;
00176         }
00177 
00178         // log loss
00179         if (z < 0) {
00180             return 1.0 / (Math.exp(z) + 1.0);
00181         } else {
00182             double t = Math.exp(-z);
00183             return t / (t + 1);
00184         }
00185     }
00186 
00187     @Override
00188     public void resetLearningImpl() {
00189         reset();
00190         setLambda(this.lambdaRegularizationOption.getValue());
00191         setLossFunction(this.lossFunctionOption.getChosenIndex());
00192     }
00193 
00199     @Override
00200     public void trainOnInstanceImpl(Instance instance) {
00201 
00202         if (m_weights == null) {
00203             m_weights = new double[instance.numAttributes() + 1];
00204         }
00205         if (!instance.classIsMissing()) {
00206 
00207             double learningRate = 1.0 / (m_lambda * m_t);
00208             //double scale = 1.0 - learningRate * m_lambda;
00209             double scale = 1.0 - 1.0 / m_t;
00210             double y = (instance.classValue() == 0) ? -1 : 1;
00211             double wx = dotProd(instance, m_weights, instance.classIndex());
00212             double z = y * (wx + m_weights[m_weights.length - 1]);
00213 
00214             for (int j = 0; j < m_weights.length - 1; j++) {
00215                 if (j != instance.classIndex()) {
00216                     m_weights[j] *= scale;
00217                 }
00218             }
00219 
00220             if (m_loss == LOGLOSS || (z < 1)) {
00221                 double loss = dloss(z);
00222                 int n1 = instance.numValues();
00223                 for (int p1 = 0; p1 < n1; p1++) {
00224                     int indS = instance.index(p1);
00225                     if (indS != instance.classIndex() && !instance.isMissingSparse(p1)) {
00226                         double m = learningRate * loss * (instance.valueSparse(p1) * y);
00227                         m_weights[indS] += m;
00228                     }
00229                 }
00230 
00231                 // update the bias
00232                 m_weights[m_weights.length - 1] += learningRate * loss * y;
00233             }
00234 
00235             double norm = 0;
00236             for (int k = 0; k < m_weights.length - 1; k++) {
00237                 if (k != instance.classIndex()) {
00238                     norm += (m_weights[k] * m_weights[k]);
00239                 }
00240             }
00241 
00242             double scale2 = Math.min(1.0, (1.0 / (m_lambda * norm)));
00243             if (scale2 < 1.0) {
00244                 scale2 = Math.sqrt(scale2);
00245                 for (int j = 0; j < m_weights.length - 1; j++) {
00246                     if (j != instance.classIndex()) {
00247                         m_weights[j] *= scale2;
00248                     }
00249                 }
00250             }
00251             m_t++;
00252         }
00253     }
00254 
00262     @Override
00263     public double[] getVotesForInstance(Instance inst) {
00264 
00265         if (m_weights == null) {
00266             return new double[inst.numAttributes() + 1];
00267         }
00268 
00269         double[] result = new double[2];
00270 
00271         double wx = dotProd(inst, m_weights, inst.classIndex());// * m_wScale;
00272         double z = (wx + m_weights[m_weights.length - 1]);
00273         //System.out.print("" + z + ": ");
00274         // System.out.println(1.0 / (1.0 + Math.exp(-z)));
00275         if (z <= 0) {
00276             //  z = 0;
00277             if (m_loss == LOGLOSS) {
00278                 result[0] = 1.0 / (1.0 + Math.exp(z));
00279                 result[1] = 1.0 - result[0];
00280             } else {
00281                 result[0] = 1;
00282             }
00283         } else {
00284             if (m_loss == LOGLOSS) {
00285                 result[1] = 1.0 / (1.0 + Math.exp(-z));
00286                 result[0] = 1.0 - result[1];
00287             } else {
00288                 result[1] = 1;
00289             }
00290         }
00291         return result;
00292     }
00293 
00294     @Override
00295     public void getModelDescription(StringBuilder result, int indent) {
00296         StringUtils.appendIndented(result, indent, toString());
00297         StringUtils.appendNewline(result);
00298     }
00299 
00305     @Override
00306     public String toString() {
00307         if (m_weights == null) {
00308             return "SPegasos: No model built yet.\n";
00309         }
00310         StringBuffer buff = new StringBuffer();
00311         buff.append("Loss function: ");
00312         if (m_loss == HINGE) {
00313             buff.append("Hinge loss (SVM)\n\n");
00314         } else {
00315             buff.append("Log loss (logistic regression)\n\n");
00316         }
00317         int printed = 0;
00318 
00319         for (int i = 0; i < m_weights.length - 1; i++) {
00320             //   if (i != m_data.classIndex()) {
00321             if (printed > 0) {
00322                 buff.append(" + ");
00323             } else {
00324                 buff.append("   ");
00325             }
00326 
00327             buff.append(Utils.doubleToString(m_weights[i], 12, 4) + " "
00328                     //+ m_data.attribute(i).name()
00329                     + "\n");
00330 
00331             printed++;
00332         }
00333         //}
00334 
00335         if (m_weights[m_weights.length - 1] > 0) {
00336             buff.append(" + " + Utils.doubleToString(m_weights[m_weights.length - 1], 12, 4));
00337         } else {
00338             buff.append(" - " + Utils.doubleToString(-m_weights[m_weights.length - 1], 12, 4));
00339         }
00340 
00341         return buff.toString();
00342     }
00343 
00344     @Override
00345     protected Measurement[] getModelMeasurementsImpl() {
00346         return null;
00347     }
00348 
00349     @Override
00350     public boolean isRandomizable() {
00351         return false;
00352     }
00353 }
 All Classes Namespaces Files Functions Variables Enumerations