MOA 12.03
Real Time Analytics for Data Streams
CMM_GTAnalysis.java
Go to the documentation of this file.
00001 /*
00002  *    CMM_GTAnalysis.java
00003  *    Copyright (C) 2010 RWTH Aachen University, Germany
00004  *    @author Jansen ([email protected])
00005  *
00006  *    This program is free software; you can redistribute it and/or modify
00007  *    it under the terms of the GNU General Public License as published by
00008  *    the Free Software Foundation; either version 3 of the License, or
00009  *    (at your option) any later version.
00010  *
00011  *    This program is distributed in the hope that it will be useful,
00012  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
00013  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00014  *    GNU General Public License for more details.
00015  *
00016  *    You should have received a copy of the GNU General Public License
00017  *    along with this program. If not, see <http://www.gnu.org/licenses/>.
00018  *    
00019  */
00020 
00021 
00038 /*
00039  * TODO:
00040  * - try to avoid calculating the radius multiple times
00041  * - avoid the full distance map?
00042  * - knn functionality in clusters
00043  * - noise error
00044  */
00045 
00046 package moa.evaluation;
00047 
00048 import java.util.ArrayList;
00049 import java.util.HashMap;
00050 import java.util.Iterator;
00051 import moa.cluster.Clustering;
00052 import moa.core.AutoExpandVector;
00053 import moa.gui.visualization.DataPoint;
00054 import weka.core.Instance;
00055 
00056 public class CMM_GTAnalysis{
00057         
00061     private Clustering gtClustering;
00062     
00066     private ArrayList<CMMPoint> cmmpoints;
00067     
00071     private ArrayList<GTCluster> gt0Clusters;
00072 
00076     private ArrayList<Integer> noise;
00077     
00081     private int numPoints;
00082 
00086     private int numGTClusters;
00087 
00092     private int numGTClasses;
00093 
00097     private int numGT0Classes;
00098 
00102     private int numDims;
00103 
00110     private HashMap<Integer, Integer> mapTrueLabelToWorkLabel;
00111 
00115     private int[] mergeMap;
00116 
00121     private int noiseErrorByModel;
00122 
00127     private int pointErrorByModel;    
00128     
00132     private boolean debug = false;
00133 
00134     
00135     /******* CMM parameter ***********/
00136 
00140     private int knnNeighbourhood = 2;
00141 
00146     private double tauConnection = 0.5;
00147     
00152     private double clusterConnectionMaxPoints = knnNeighbourhood;
00153     
00161     private boolean useExpConnectivity = false;
00162     private double lambdaConnRefXValue = 0.01;
00163     private double lambdaConnX = 4;
00164     private double lamdaConn;
00165     
00166     
00167     /******************************************/
00168     
00169     
00174     protected class CMMPoint extends DataPoint{
00178         protected DataPoint p = null;
00179         
00183         protected int pID = 0;
00184         
00185         
00189         protected int trueClass = -1;
00190 
00191         
00195         protected double connectivity = 1.0;
00196         
00197         
00201         protected double knnInCluster = 0.0; 
00202         
00203         
00207         protected ArrayList<Integer> knnIndices;
00208 
00209         public CMMPoint(DataPoint point, int id) {
00210             //make a copy, but keep reference
00211             super(point,point.getTimestamp());
00212             p = point;
00213             pID = id;
00214             trueClass = (int)point.classValue();
00215         }
00216 
00217         
00224         protected int workclass(){
00225             if(trueClass == -1 )
00226                 return -1;
00227             else
00228                 return mapTrueLabelToWorkLabel.get(trueClass);
00229         }
00230 
00231         protected boolean isNoise(){ return trueClass==-1;}
00232     }
00233 
00234     
00235     
00240     protected class GTCluster{
00242         private ArrayList<Integer> points = new ArrayList<Integer>();
00243         
00250         private ArrayList<Integer> clusterRepresentations = new ArrayList<Integer>();
00251      
00253         private int workclass;
00254         
00256         private final int orgWorkClass;
00257         
00259         private final int label;
00260         
00262         private ArrayList<Integer> mergedWorkLabels = null;
00263         
00265         private double knnMeanAvg = 0;
00266         
00268         private double knnDevAvg = 0;
00269         
00271         private ArrayList<Double> connections = new ArrayList<Double>();
00272         
00273 
00274         private GTCluster(int workclass, int label, int gtClusteringID) {
00275            this.orgWorkClass = workclass;
00276            this.workclass = workclass;
00277            this.label = label;
00278            this.clusterRepresentations.add(gtClusteringID);
00279         }
00280 
00281         
00286         protected int getLabel(){
00287             return label;
00288         }
00289 
00295         protected double getInclusionProbability(CMMPoint point){
00296             double prob = Double.MIN_VALUE;
00297             //check all cluster representatives for coverage
00298             for (int c = 0; c < clusterRepresentations.size(); c++) {
00299                double tmp_prob = gtClustering.get(clusterRepresentations.get(c)).getInclusionProbability(point);
00300                if(tmp_prob > prob) prob = tmp_prob;
00301             }
00302             return prob;
00303         }
00304 
00305         
00310         private void calculateKnn(){
00311             for (int p0 : points) {
00312                 CMMPoint cmdp = cmmpoints.get(p0);
00313                 if(!cmdp.isNoise()){
00314                     AutoExpandVector<Double> knnDist = new AutoExpandVector<Double>();
00315                     AutoExpandVector<Integer> knnPointIndex = new AutoExpandVector<Integer>();
00316                     
00317                     //calculate nearest neighbours 
00318                     getKnnInCluster(cmdp, knnNeighbourhood, points, knnDist,knnPointIndex);
00319 
00320                     //TODO: What to do if we have less then k neighbours?
00321                     double avgKnn = 0;
00322                     for (int i = 0; i < knnDist.size(); i++) {
00323                         avgKnn+= knnDist.get(i);
00324                     }
00325                     if(knnDist.size()!=0)
00326                         avgKnn/=knnDist.size();
00327                     cmdp.knnInCluster = avgKnn;
00328                     cmdp.knnIndices = knnPointIndex;
00329                     cmdp.p.setMeasureValue("knnAvg", cmdp.knnInCluster);
00330 
00331                     knnMeanAvg+=avgKnn;
00332                     knnDevAvg+=Math.pow(avgKnn,2);
00333                 }
00334             }
00335             knnMeanAvg=knnMeanAvg/(double)points.size();
00336             knnDevAvg=knnDevAvg/(double)points.size();
00337 
00338             double variance = knnDevAvg-Math.pow(knnMeanAvg,2.0);
00339             // Due to numerical errors, small negative values can occur.
00340             if (variance <= 0.0) variance = 1e-50;
00341             knnDevAvg = Math.sqrt(variance);
00342 
00343         }
00344 
00345         
00351         private void calculateClusterConnection(int otherCid, boolean initial){
00352             double avgConnection = 0;
00353             if(workclass==otherCid){
00354                 avgConnection = 1;
00355             }
00356             else{
00357                 AutoExpandVector<Double> kmax = new AutoExpandVector<Double>();
00358                 AutoExpandVector<Integer> kmaxIndexes = new AutoExpandVector<Integer>();
00359 
00360                 for(int p : points){
00361                     CMMPoint cmdp = cmmpoints.get(p);
00362                     double con_p_Cj = getConnectionValue(cmmpoints.get(p), otherCid);
00363                     double connection = cmdp.connectivity * con_p_Cj;
00364                     if(initial){
00365                         cmdp.p.setMeasureValue("Connection to C"+otherCid, con_p_Cj);
00366                     }
00367 
00368                     //connection
00369                     if(kmax.size() < clusterConnectionMaxPoints || connection > kmax.get(kmax.size()-1)){
00370                         int index = 0;
00371                         while(index < kmax.size() && connection < kmax.get(index)) {
00372                             index++;
00373                         }
00374                         kmax.add(index, connection);
00375                         kmaxIndexes.add(index, p);
00376                         if(kmax.size() > clusterConnectionMaxPoints){
00377                             kmax.remove(kmax.size()-1);
00378                             kmaxIndexes.add(kmaxIndexes.size()-1);
00379                         }
00380                     }
00381                 }
00382                 //connection
00383                 for (int k = 0; k < kmax.size(); k++) {
00384                     avgConnection+= kmax.get(k);
00385                 }
00386                 avgConnection/=kmax.size();
00387             }
00388 
00389             if(otherCid<connections.size()){
00390                 connections.set(otherCid, avgConnection);
00391             }
00392             else
00393                 if(connections.size() == otherCid){
00394                     connections.add(avgConnection);
00395                 }
00396                 else
00397                     System.out.println("Something is going really wrong with the connection listing!"+knnNeighbourhood+" "+tauConnection);
00398         }
00399 
00400         
00405         private void mergeCluster(int mergeID){
00406             if(mergeID < gt0Clusters.size()){
00407                 //track merging (debugging)
00408                 for (int i = 0; i < numGTClasses; i++) {
00409                     if(mergeMap[i]==mergeID)
00410                         mergeMap[i]=workclass;
00411                     if(mergeMap[i]>mergeID)
00412                         mergeMap[i]--;
00413                 }
00414                 GTCluster gtcMerge  = gt0Clusters.get(mergeID);
00415                 if(debug)
00416                     System.out.println("Merging C"+gtcMerge.workclass+" into C"+workclass+
00417                             " with Con "+connections.get(mergeID)+" / "+gtcMerge.connections.get(workclass));
00418 
00419 
00420                 //update mapTrueLabelToWorkLabel
00421                 mapTrueLabelToWorkLabel.put(gtcMerge.label, workclass);
00422                 Iterator iterator = mapTrueLabelToWorkLabel.keySet().iterator();
00423                 while (iterator.hasNext()) {
00424                     Integer key = (Integer)iterator.next();
00425                     //update pointer of already merged cluster
00426                     int value = mapTrueLabelToWorkLabel.get(key);
00427                     if(value == mergeID)
00428                         mapTrueLabelToWorkLabel.put(key, workclass);
00429                     if(value > mergeID)
00430                         mapTrueLabelToWorkLabel.put(key, value-1);
00431                 }
00432 
00433                 //merge points from B into A
00434                 points.addAll(gtcMerge.points);
00435                 clusterRepresentations.addAll(gtcMerge.clusterRepresentations);
00436                 if(mergedWorkLabels==null){
00437                     mergedWorkLabels = new ArrayList<Integer>();
00438                 }
00439                 mergedWorkLabels.add(gtcMerge.orgWorkClass);
00440                 if(gtcMerge.mergedWorkLabels!=null)
00441                     mergedWorkLabels.addAll(gtcMerge.mergedWorkLabels);
00442 
00443                 gt0Clusters.remove(mergeID);
00444 
00445                 //update workclass labels
00446                 for(int c=mergeID; c < gt0Clusters.size(); c++){
00447                     gt0Clusters.get(c).workclass = c;
00448                 }
00449 
00450                 //update knn distances
00451                 calculateKnn();
00452                 for(int c=0; c < gt0Clusters.size(); c++){
00453                     gt0Clusters.get(c).connections.remove(mergeID);
00454                     
00455                     //recalculate connection from other clusters to the new merged one
00456                     gt0Clusters.get(c).calculateClusterConnection(workclass,false);
00457                     //and from new merged one to other clusters
00458                     gt0Clusters.get(workclass).calculateClusterConnection(c,false);
00459                 }
00460             }
00461             else{
00462                 System.out.println("Merge indices are not valid");
00463             }
00464         }
00465     }
00466 
00467     
00473     public CMM_GTAnalysis(Clustering trueClustering, ArrayList<DataPoint> points, boolean enableClassMerge){
00474         if(debug)
00475             System.out.println("GT Analysis Debug Output");
00476 
00477         noiseErrorByModel = 0;
00478         pointErrorByModel = 0;
00479         if(!enableClassMerge){
00480                 tauConnection = 1.0;
00481         }
00482 
00483         lamdaConn = -Math.log(lambdaConnRefXValue)/Math.log(2)/lambdaConnX;
00484         
00485         this.gtClustering = trueClustering;
00486 
00487         numPoints = points.size();
00488         numDims = points.get(0).numAttributes()-1;
00489         numGTClusters = gtClustering.size();
00490 
00491         //init mappings between work and true labels
00492         mapTrueLabelToWorkLabel = new HashMap<Integer, Integer>();
00493         
00494         //set up base of new clustering
00495         gt0Clusters = new ArrayList<GTCluster>();
00496         int numWorkClasses = 0;
00497         //create label to worklabel mapping as real labels can be just a set of unordered integers
00498         for (int i = 0; i < numGTClusters; i++) {
00499             int label = (int)gtClustering.get(i).getGroundTruth();
00500             if(!mapTrueLabelToWorkLabel.containsKey(label)){
00501                 gt0Clusters.add(new GTCluster(numWorkClasses,label,i));
00502                 mapTrueLabelToWorkLabel.put(label,numWorkClasses);
00503                 numWorkClasses++;
00504             }
00505             else{
00506                 gt0Clusters.get(mapTrueLabelToWorkLabel.get(label)).clusterRepresentations.add(i);
00507             }
00508         }
00509         numGTClasses = numWorkClasses;
00510 
00511         mergeMap = new int[numGTClasses];
00512         for (int i = 0; i < numGTClasses; i++) {
00513             mergeMap[i]=i;
00514         }
00515 
00516         //create cmd point wrapper instances
00517         cmmpoints = new ArrayList<CMMPoint>();
00518         for (int p = 0; p < points.size(); p++) {
00519             CMMPoint cmdp = new CMMPoint(points.get(p), p);
00520             cmmpoints.add(cmdp);
00521         }
00522 
00523 
00524         //split points up into their GTClusters and Noise (according to class labels)
00525         noise = new ArrayList<Integer>();
00526         for (int p = 0; p < numPoints; p++) {
00527             if(cmmpoints.get(p).isNoise()){
00528                 noise.add(p);
00529             }
00530             else{
00531                 gt0Clusters.get(cmmpoints.get(p).workclass()).points.add(p);
00532             }
00533         }
00534 
00535         //calculate initial knnMean and knnDev
00536         for (GTCluster gtc : gt0Clusters) {
00537             gtc.calculateKnn();
00538         }
00539 
00540         //calculate cluster connections
00541         calculateGTClusterConnections();
00542 
00543         //calculate point connections with own clusters
00544         calculateGTPointQualities();
00545 
00546         if(debug)
00547             System.out.println("GT Analysis Debug End");
00548 
00549    }
00550 
00558     //TODO: Cache the connection value for a point to the different clusters???
00559     protected double getConnectionValue(CMMPoint cmmp, int clusterID){
00560         AutoExpandVector<Double> knnDist = new AutoExpandVector<Double>();
00561         AutoExpandVector<Integer> knnPointIndex = new AutoExpandVector<Integer>();
00562         
00563         //calculate the knn distance of the point to the cluster
00564         getKnnInCluster(cmmp, knnNeighbourhood, gt0Clusters.get(clusterID).points, knnDist, knnPointIndex);
00565 
00566         //TODO: What to do if we have less then k neighbors?
00567         double avgDist = 0;
00568         for (int i = 0; i < knnDist.size(); i++) {
00569             avgDist+= knnDist.get(i);
00570         }
00571         //what to do if we only have a single point???
00572         if(knnDist.size()!=0)
00573             avgDist/=knnDist.size();
00574         else
00575             return 0;
00576 
00577         //get the upper knn distance of the cluster
00578         double upperKnn = gt0Clusters.get(clusterID).knnMeanAvg + gt0Clusters.get(clusterID).knnDevAvg;
00579         
00580         /* calculate the connectivity based on knn distance of the point within the cluster
00581            and the upper knn distance of the cluster*/ 
00582         if(avgDist < upperKnn){
00583             return 1;
00584         }
00585         else{
00586             //value that should be reached at upperKnn distance
00587             //Choose connection formula
00588             double conn;
00589             if(useExpConnectivity)
00590                 conn = Math.pow(2,-lamdaConn*(avgDist-upperKnn)/upperKnn);
00591             else
00592                 conn = upperKnn/avgDist;
00593 
00594             if(Double.isNaN(conn))
00595                 System.out.println("Connectivity NaN at "+cmmp.p.getTimestamp());
00596 
00597             return conn;
00598         }
00599     }
00600 
00601     
00609     private void getKnnInCluster(CMMPoint cmmp, int k,
00610                                  ArrayList<Integer> pointIDs,
00611                                  AutoExpandVector<Double> knnDist,
00612                                  AutoExpandVector<Integer> knnPointIndex) {
00613 
00614         //iterate over every point in the choosen cluster, cal distance and insert into list
00615         for (int p1 = 0; p1 < pointIDs.size(); p1++) {
00616             int pid = pointIDs.get(p1);
00617             if(cmmp.pID == pid) continue;
00618             double dist = distance(cmmp,cmmpoints.get(pid));
00619             if(knnDist.size() < k || dist < knnDist.get(knnDist.size()-1)){
00620                 int index = 0;
00621                 while(index < knnDist.size() && dist > knnDist.get(index)) {
00622                     index++;
00623                 }
00624                 knnDist.add(index, dist);
00625                 knnPointIndex.add(index,pid);
00626                 if(knnDist.size() > k){
00627                     knnDist.remove(knnDist.size()-1);
00628                     knnPointIndex.remove(knnPointIndex.size()-1);
00629                 }
00630             }
00631         }
00632     }
00633 
00634 
00635     
00639     private void calculateGTPointQualities(){
00640         for (int p = 0; p < numPoints; p++) {
00641             CMMPoint cmdp = cmmpoints.get(p);
00642             if(!cmdp.isNoise()){
00643                 cmdp.connectivity = getConnectionValue(cmdp, cmdp.workclass());
00644                 cmdp.p.setMeasureValue("Connectivity", cmdp.connectivity);
00645             }
00646         }
00647     }
00648 
00649     
00650     
00655     private void calculateGTClusterConnections(){
00656         for (int c0 = 0; c0 < gt0Clusters.size(); c0++) {
00657             for (int c1 = 0; c1 < gt0Clusters.size(); c1++) {
00658                     gt0Clusters.get(c0).calculateClusterConnection(c1, true);
00659             }
00660         }
00661 
00662         boolean changedConnection = true;
00663         while(changedConnection){
00664             if(debug){
00665                 System.out.println("Cluster Connection");
00666                 for (int c = 0; c < gt0Clusters.size(); c++) {
00667                     System.out.print("C"+gt0Clusters.get(c).label+" --> ");
00668                     for (int c1 = 0; c1 < gt0Clusters.get(c).connections.size(); c1++) {
00669                         System.out.print(" C"+gt0Clusters.get(c1).label+": "+gt0Clusters.get(c).connections.get(c1));
00670                     }
00671                     System.out.println("");
00672                 }
00673                 System.out.println("");
00674             }
00675 
00676             double max = 0;
00677             int maxIndexI = -1;
00678             int maxIndexJ = -1;
00679 
00680             changedConnection = false;
00681             for (int c0 = 0; c0 < gt0Clusters.size(); c0++) {
00682                 for (int c1 = c0+1; c1 < gt0Clusters.size(); c1++) {
00683                     if(c0==c1) continue;
00684                         double min =Math.min(gt0Clusters.get(c0).connections.get(c1), gt0Clusters.get(c1).connections.get(c0));
00685                         if(min > max){
00686                             max = min;
00687                             maxIndexI = c0;
00688                             maxIndexJ = c1;
00689                         }
00690                 }
00691             }
00692             if(maxIndexI!=-1 && max > tauConnection){
00693                 gt0Clusters.get(maxIndexI).mergeCluster(maxIndexJ);
00694                 if(debug)
00695                     System.out.println("Merging "+maxIndexI+" and "+maxIndexJ+" because of connection "+max);
00696 
00697                 changedConnection = true;
00698             }
00699         }
00700         numGT0Classes = gt0Clusters.size();
00701     }
00702 
00703     
00709     public double getClassSeparability(){
00710 //        int totalConn = numGTClasses*(numGTClasses-1)/2;
00711 //        int mergedConn = 0;
00712 //        for(GTCluster gt : gt0Clusters){
00713 //            int merged = gt.clusterRepresentations.size();
00714 //            if(merged > 1)
00715 //                mergedConn+=merged * (merged-1)/2;
00716 //        }
00717 //        if(totalConn == 0)
00718 //            return 0;
00719 //        else
00720 //            return 1-mergedConn/(double)totalConn;
00721         return numGT0Classes/(double)numGTClasses;
00722 
00723     }
00724 
00725     
00731     public double getNoiseSeparability(){
00732         if(noise.isEmpty()) 
00733             return 1;
00734 
00735         double connectivity = 0;
00736         for(int p : noise){
00737             CMMPoint npoint = cmmpoints.get(p);
00738             double maxConnection = 0;
00739 
00740             //TODO: some kind of pruning possible. what about weighting?
00741             for (int c = 0; c < gt0Clusters.size(); c++) {
00742                 double connection = getConnectionValue(npoint, c);
00743                 if(connection > maxConnection)
00744                     maxConnection = connection;
00745             }
00746             connectivity+=maxConnection;
00747             npoint.p.setMeasureValue("MaxConnection", maxConnection);
00748         }
00749 
00750         return 1-(connectivity / noise.size());
00751     }
00752 
00753     
00758     public double getModelQuality(){
00759         for(int p = 0; p < numPoints; p++){
00760             CMMPoint cmdp = cmmpoints.get(p);
00761             for(int hc = 0; hc < numGTClusters;hc++){
00762                 if(gtClustering.get(hc).getGroundTruth() != cmdp.trueClass){
00763                     if(gtClustering.get(hc).getInclusionProbability(cmdp) >= 1){
00764                         if(!cmdp.isNoise())
00765                             pointErrorByModel++;
00766                         else
00767                             noiseErrorByModel++;
00768                         break;
00769                     }
00770                 }
00771             }
00772         }
00773         if(debug)
00774             System.out.println("Error by model: noise "+noiseErrorByModel+" point "+pointErrorByModel);
00775 
00776         return 1-((pointErrorByModel + noiseErrorByModel)/(double) numPoints);
00777     }
00778 
00779     
00785     protected CMMPoint getPoint(int index){
00786         return cmmpoints.get(index);
00787     }
00788 
00789     
00795     protected GTCluster getGT0Cluster(int index){
00796         return gt0Clusters.get(index);
00797     }
00798 
00803     protected int getNumberOfGT0Classes() {
00804         return numGT0Classes;
00805     }
00806     
00813     private double distance(Instance inst1, Instance inst2){
00814           return distance(inst1, inst2.toDoubleArray());
00815 
00816     }
00817     
00824     private double distance(Instance inst1, double[] inst2){
00825         double distance = 0.0;
00826         for (int i = 0; i < numDims; i++) {
00827             double d = inst1.value(i) - inst2[i];
00828             distance += d * d;
00829         }
00830         return Math.sqrt(distance);
00831     }
00832     
00837     public String getParameterString(){
00838         String para = "";
00839         para+="k="+knnNeighbourhood+";";
00840         if(useExpConnectivity){
00841                 para+="lambdaConnX="+lambdaConnX+";";
00842                 para+="lambdaConn="+lamdaConn+";";
00843                 para+="lambdaConnRef="+lambdaConnRefXValue+";";
00844         }
00845         para+="m="+clusterConnectionMaxPoints+";";
00846         para+="tauConn="+tauConnection+";";
00847 
00848         return para;
00849     }    
00850 }
00851 
00852 
 All Classes Namespaces Files Functions Variables Enumerations