SGD.java
/*
 *    SGD.java
 *    Copyright (C) 2009 University of Waikato, Hamilton, New Zealand
 *    @author Eibe Frank (eibe{[at]}cs{[dot]}waikato{[dot]}ac{[dot]}nz)
 *    @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
 *
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 3 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program. If not, see <http://www.gnu.org/licenses/>.
 *
 */

/*
 *    SGD.java
 *    Copyright (C) 2009 University of Waikato, Hamilton, New Zealand
 *
 */
package moa.classifiers.functions;

import moa.classifiers.AbstractClassifier;
import moa.core.Measurement;
import moa.core.StringUtils;
import moa.options.FloatOption;
import moa.options.MultiChoiceOption;
import weka.core.Instance;
import weka.core.Utils;

/**
 * Implements stochastic gradient descent for learning various linear models:
 * binary class SVM, binary class logistic regression, and linear regression.
 *
 * @author Eibe Frank (eibe{[at]}cs{[dot]}waikato{[dot]}ac{[dot]}nz)
 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
 */
public class SGD extends AbstractClassifier {

    /** For serialization. */
    private static final long serialVersionUID = -3732968666673530290L;

    @Override
    public String getPurposeString() {
        return "Stochastic gradient descent for learning various linear models (binary class SVM, binary class logistic regression and linear regression).";
    }

    /** The regularization parameter. */
    protected double m_lambda = 0.0001;

    public FloatOption lambdaRegularizationOption = new FloatOption("lambdaRegularization",
            'l', "Lambda regularization parameter.",
            0.0001, 0.00, Integer.MAX_VALUE);

    /** The learning rate. */
    protected double m_learningRate = 0.01;

    public FloatOption learningRateOption = new FloatOption("learningRate",
            'r', "Learning rate parameter.",
            0.0001, 0.00, Integer.MAX_VALUE);

    /** Stores the weights (+ bias in the last element). */
    protected double[] m_weights;

    /** Holds the current iteration number. */
    protected double m_t;

    /** The number of training instances. */
    protected double m_numInstances;

    protected static final int HINGE = 0;

    protected static final int LOGLOSS = 1;

    protected static final int SQUAREDLOSS = 2;

    /** The current loss function to minimize. */
    protected int m_loss = HINGE;

    public MultiChoiceOption lossFunctionOption = new MultiChoiceOption(
            "lossFunction", 'o', "The loss function to use.", new String[]{
                "HINGE", "LOGLOSS", "SQUAREDLOSS"}, new String[]{
                "Hinge loss (SVM)",
                "Log loss (logistic regression)",
                "Squared loss (regression)"}, 0);

    /**
     * Set the value of lambda to use.
     *
     * @param lambda the value of lambda to use
     */
    public void setLambda(double lambda) {
        m_lambda = lambda;
    }

    /**
     * Get the current value of lambda.
     *
     * @return the current value of lambda
     */
    public double getLambda() {
        return m_lambda;
    }

    /**
     * Set the loss function to use.
     *
     * @param function the loss function to use
     */
    public void setLossFunction(int function) {
        m_loss = function;
    }

    /**
     * Get the current loss function.
     *
     * @return the current loss function
     */
    public int getLossFunction() {
        return m_loss;
    }

    /**
     * Set the learning rate.
     *
     * @param lr the learning rate to use
     */
    public void setLearningRate(double lr) {
        m_learningRate = lr;
    }

    /**
     * Get the learning rate.
     *
     * @return the current learning rate
     */
    public double getLearningRate() {
        return m_learningRate;
    }

    /**
     * Reset the classifier.
     */
    public void reset() {
        m_t = 1;
        m_weights = null;
    }

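    /**
     * Returns the magnitude of the loss derivative with respect to the raw
     * score z (the sign is supplied by the caller via y): 1 under hinge loss
     * while the margin is violated (z < 1), the logistic sigmoid of -z under
     * log loss (evaluated in a numerically stable way for either sign of z),
     * and the residual z itself under squared loss.
     */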
    protected double dloss(double z) {
        if (m_loss == HINGE) {
            return (z < 1) ? 1 : 0;
        }

        if (m_loss == LOGLOSS) {
            // log loss
            if (z < 0) {
                return 1.0 / (Math.exp(z) + 1.0);
            } else {
                double t = Math.exp(-z);
                return t / (t + 1);
            }
        }

        // squared loss
        return z;
    }

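    /**
     * Computes the dot product between the (possibly sparse) instance and the
     * weight vector, walking both index lists in step. The class attribute and
     * missing values are skipped, and the bias term stored in the last weight
     * slot is excluded (hence weights.length - 1).
     */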
    protected static double dotProd(Instance inst1, double[] weights, int classIndex) {
        double result = 0;

        int n1 = inst1.numValues();
        int n2 = weights.length - 1;

        for (int p1 = 0, p2 = 0; p1 < n1 && p2 < n2;) {
            int ind1 = inst1.index(p1);
            int ind2 = p2;
            if (ind1 == ind2) {
                if (ind1 != classIndex && !inst1.isMissingSparse(p1)) {
                    result += inst1.valueSparse(p1) * weights[p2];
                }
                p1++;
                p2++;
            } else if (ind1 > ind2) {
                p2++;
            } else {
                p1++;
            }
        }
        return result;
    }

    @Override
    public void resetLearningImpl() {
        reset();
        setLambda(this.lambdaRegularizationOption.getValue());
        setLearningRate(this.learningRateOption.getValue());
        setLossFunction(this.lossFunctionOption.getChosenIndex());
    }

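    /**
     * Trains the classifier on the given instance. Each call performs one
     * stochastic gradient step: the weights are first decayed by the L2
     * regularization factor 1 - learningRate * lambda / t (or divided by the
     * number of training instances, when known), and then, if the loss is
     * active, incremented by learningRate * y * dloss(z) times the attribute
     * values, with the same factor added to the bias. For nominal classes, y
     * is -1/+1 and z is the signed margin; for numeric classes, z is the
     * residual between target and prediction.
     *
     * @param instance the instance to train on
     */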
    @Override
    public void trainOnInstanceImpl(Instance instance) {

        if (m_weights == null) {
            m_weights = new double[instance.numAttributes() + 1];
        }

        if (!instance.classIsMissing()) {

            double wx = dotProd(instance, m_weights, instance.classIndex());

            double y;
            double z;
            if (instance.classAttribute().isNominal()) {
                y = (instance.classValue() == 0) ? -1 : 1;
                z = y * (wx + m_weights[m_weights.length - 1]);
            } else {
                y = instance.classValue();
                z = y - (wx + m_weights[m_weights.length - 1]);
                y = 1;
            }

            // Compute the multiplier for weight decay
            double multiplier = 1.0;
            if (m_numInstances == 0) {
                multiplier = 1.0 - (m_learningRate * m_lambda) / m_t;
            } else {
                multiplier = 1.0 - (m_learningRate * m_lambda) / m_numInstances;
            }
            for (int i = 0; i < m_weights.length - 1; i++) {
                m_weights[i] *= multiplier;
            }

            // Only need to do the following if the loss is non-zero
            if (m_loss != HINGE || (z < 1)) {

                // Compute the factor for the updates
                double factor = m_learningRate * y * dloss(z);

                // Update the coefficients for the attributes
                int n1 = instance.numValues();
                for (int p1 = 0; p1 < n1; p1++) {
                    int indS = instance.index(p1);
                    if (indS != instance.classIndex() && !instance.isMissingSparse(p1)) {
                        m_weights[indS] += factor * instance.valueSparse(p1);
                    }
                }

                // Update the bias
                m_weights[m_weights.length - 1] += factor;
            }
            m_t++;
        }
    }

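    /**
     * Calculates the class membership probabilities (or the numeric
     * prediction) for the given test instance. Under log loss the votes are
     * proper logistic probabilities; otherwise the predicted class simply
     * receives a vote of 1.
     *
     * @param inst the instance to be classified
     * @return the class votes, or the prediction in element 0 for numeric classes
     */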
    @Override
    public double[] getVotesForInstance(Instance inst) {

        if (m_weights == null) {
            return new double[inst.numAttributes() + 1];
        }
        double[] result = (inst.classAttribute().isNominal())
                ? new double[2]
                : new double[1];

        double wx = dotProd(inst, m_weights, inst.classIndex());
        double z = (wx + m_weights[m_weights.length - 1]);

        if (inst.classAttribute().isNumeric()) {
            result[0] = z;
            return result;
        }

        if (z <= 0) {
            if (m_loss == LOGLOSS) {
                result[0] = 1.0 / (1.0 + Math.exp(z));
                result[1] = 1.0 - result[0];
            } else {
                result[0] = 1;
            }
        } else {
            if (m_loss == LOGLOSS) {
                result[1] = 1.0 / (1.0 + Math.exp(-z));
                result[0] = 1.0 - result[1];
            } else {
                result[1] = 1;
            }
        }
        return result;
    }

    @Override
    public void getModelDescription(StringBuilder result, int indent) {
        StringUtils.appendIndented(result, indent, toString());
        StringUtils.appendNewline(result);
    }

    /**
     * Prints out the classifier.
     *
     * @return a description of the classifier as a string
     */
    @Override
    public String toString() {
        if (m_weights == null) {
            return "SGD: No model built yet.\n";
        }
        StringBuffer buff = new StringBuffer();
        buff.append("Loss function: ");
        if (m_loss == HINGE) {
            buff.append("Hinge loss (SVM)\n\n");
        } else if (m_loss == LOGLOSS) {
            buff.append("Log loss (logistic regression)\n\n");
        } else {
            buff.append("Squared loss (linear regression)\n\n");
        }

        int printed = 0;

        for (int i = 0; i < m_weights.length - 1; i++) {
            if (printed > 0) {
                buff.append(" + ");
            } else {
                buff.append("   ");
            }

            buff.append(Utils.doubleToString(m_weights[i], 12, 4) + " \n");

            printed++;
        }

        if (m_weights[m_weights.length - 1] > 0) {
            buff.append(" + " + Utils.doubleToString(m_weights[m_weights.length - 1], 12, 4));
        } else {
            buff.append(" - " + Utils.doubleToString(-m_weights[m_weights.length - 1], 12, 4));
        }

        return buff.toString();
    }

    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        return null;
    }

    @Override
    public boolean isRandomizable() {
        return false;
    }
}
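
/*
 * A minimal usage sketch (not part of the original file): prequential
 * (test-then-train) evaluation of this classifier on a synthetic stream.
 * It assumes the weka.core.Instance-based MOA API this file is written
 * against; RandomRBFGenerator, prepareForUse(), nextInstance(), and
 * correctlyClassifies() are standard MOA members, but check your release.
 *
 *     SGD learner = new SGD();
 *     learner.prepareForUse();
 *     moa.streams.generators.RandomRBFGenerator stream =
 *             new moa.streams.generators.RandomRBFGenerator();
 *     stream.prepareForUse();
 *     int correct = 0, seen = 0;
 *     while (seen < 10000 && stream.hasMoreInstances()) {
 *         weka.core.Instance inst = stream.nextInstance();
 *         if (learner.correctlyClassifies(inst)) { correct++; } // test first
 *         learner.trainOnInstance(inst);                        // then train
 *         seen++;
 *     }
 *     System.out.println("Prequential accuracy: " + (double) correct / seen);
 */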