MOA 12.03
Real Time Analytics for Data Streams
|
00001 /* 00002 * CMM_GTAnalysis.java 00003 * Copyright (C) 2010 RWTH Aachen University, Germany 00004 * @author Jansen ([email protected]) 00005 * 00006 * This program is free software; you can redistribute it and/or modify 00007 * it under the terms of the GNU General Public License as published by 00008 * the Free Software Foundation; either version 3 of the License, or 00009 * (at your option) any later version. 00010 * 00011 * This program is distributed in the hope that it will be useful, 00012 * but WITHOUT ANY WARRANTY; without even the implied warranty of 00013 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 00014 * GNU General Public License for more details. 00015 * 00016 * You should have received a copy of the GNU General Public License 00017 * along with this program. If not, see <http://www.gnu.org/licenses/>. 00018 * 00019 */ 00020 00021 00038 /* 00039 * TODO: 00040 * - try to avoid calcualting the radius multiple times 00041 * - avoid the full distance map? 00042 * - knn functionality in clusters 00043 * - noise error 00044 */ 00045 00046 package moa.evaluation; 00047 00048 import java.util.ArrayList; 00049 import java.util.HashMap; 00050 import java.util.Iterator; 00051 import moa.cluster.Clustering; 00052 import moa.core.AutoExpandVector; 00053 import moa.gui.visualization.DataPoint; 00054 import weka.core.Instance; 00055 00056 public class CMM_GTAnalysis{ 00057 00061 private Clustering gtClustering; 00062 00066 private ArrayList<CMMPoint> cmmpoints; 00067 00071 private ArrayList<GTCluster> gt0Clusters; 00072 00076 private ArrayList<Integer> noise; 00077 00081 private int numPoints; 00082 00086 private int numGTClusters; 00087 00092 private int numGTClasses; 00093 00097 private int numGT0Classes; 00098 00102 private int numDims; 00103 00110 private HashMap<Integer, Integer> mapTrueLabelToWorkLabel; 00111 00115 private int[] mergeMap; 00116 00121 private int noiseErrorByModel; 00122 00127 private int pointErrorByModel; 00128 00132 private boolean debug = false; 00133 00134 00135 /******* CMM parameter ***********/ 00136 00140 private int knnNeighbourhood = 2; 00141 00146 private double tauConnection = 0.5; 00147 00152 private double clusterConnectionMaxPoints = knnNeighbourhood; 00153 00161 private boolean useExpConnectivity = false; 00162 private double lambdaConnRefXValue = 0.01; 00163 private double lambdaConnX = 4; 00164 private double lamdaConn; 00165 00166 00167 /******************************************/ 00168 00169 00174 protected class CMMPoint extends DataPoint{ 00178 protected DataPoint p = null; 00179 00183 protected int pID = 0; 00184 00185 00189 protected int trueClass = -1; 00190 00191 00195 protected double connectivity = 1.0; 00196 00197 00201 protected double knnInCluster = 0.0; 00202 00203 00207 protected ArrayList<Integer> knnIndices; 00208 00209 public CMMPoint(DataPoint point, int id) { 00210 //make a copy, but keep reference 00211 super(point,point.getTimestamp()); 00212 p = point; 00213 pID = id; 00214 trueClass = (int)point.classValue(); 00215 } 00216 00217 00224 protected int workclass(){ 00225 if(trueClass == -1 ) 00226 return -1; 00227 else 00228 return mapTrueLabelToWorkLabel.get(trueClass); 00229 } 00230 00231 protected boolean isNoise(){ return trueClass==-1;} 00232 } 00233 00234 00235 00240 protected class GTCluster{ 00242 private ArrayList<Integer> points = new ArrayList<Integer>(); 00243 00250 private ArrayList<Integer> clusterRepresentations = new ArrayList<Integer>(); 00251 00253 private int workclass; 00254 00256 private final int orgWorkClass; 00257 00259 private final int label; 00260 00262 private ArrayList<Integer> mergedWorkLabels = null; 00263 00265 private double knnMeanAvg = 0; 00266 00268 private double knnDevAvg = 0; 00269 00271 private ArrayList<Double> connections = new ArrayList<Double>(); 00272 00273 00274 private GTCluster(int workclass, int label, int gtClusteringID) { 00275 this.orgWorkClass = workclass; 00276 this.workclass = workclass; 00277 this.label = label; 00278 this.clusterRepresentations.add(gtClusteringID); 00279 } 00280 00281 00286 protected int getLabel(){ 00287 return label; 00288 } 00289 00295 protected double getInclusionProbability(CMMPoint point){ 00296 double prob = Double.MIN_VALUE; 00297 //check all cluster representatives for coverage 00298 for (int c = 0; c < clusterRepresentations.size(); c++) { 00299 double tmp_prob = gtClustering.get(clusterRepresentations.get(c)).getInclusionProbability(point); 00300 if(tmp_prob > prob) prob = tmp_prob; 00301 } 00302 return prob; 00303 } 00304 00305 00310 private void calculateKnn(){ 00311 for (int p0 : points) { 00312 CMMPoint cmdp = cmmpoints.get(p0); 00313 if(!cmdp.isNoise()){ 00314 AutoExpandVector<Double> knnDist = new AutoExpandVector<Double>(); 00315 AutoExpandVector<Integer> knnPointIndex = new AutoExpandVector<Integer>(); 00316 00317 //calculate nearest neighbours 00318 getKnnInCluster(cmdp, knnNeighbourhood, points, knnDist,knnPointIndex); 00319 00320 //TODO: What to do if we have less then k neighbours? 00321 double avgKnn = 0; 00322 for (int i = 0; i < knnDist.size(); i++) { 00323 avgKnn+= knnDist.get(i); 00324 } 00325 if(knnDist.size()!=0) 00326 avgKnn/=knnDist.size(); 00327 cmdp.knnInCluster = avgKnn; 00328 cmdp.knnIndices = knnPointIndex; 00329 cmdp.p.setMeasureValue("knnAvg", cmdp.knnInCluster); 00330 00331 knnMeanAvg+=avgKnn; 00332 knnDevAvg+=Math.pow(avgKnn,2); 00333 } 00334 } 00335 knnMeanAvg=knnMeanAvg/(double)points.size(); 00336 knnDevAvg=knnDevAvg/(double)points.size(); 00337 00338 double variance = knnDevAvg-Math.pow(knnMeanAvg,2.0); 00339 // Due to numerical errors, small negative values can occur. 00340 if (variance <= 0.0) variance = 1e-50; 00341 knnDevAvg = Math.sqrt(variance); 00342 00343 } 00344 00345 00351 private void calculateClusterConnection(int otherCid, boolean initial){ 00352 double avgConnection = 0; 00353 if(workclass==otherCid){ 00354 avgConnection = 1; 00355 } 00356 else{ 00357 AutoExpandVector<Double> kmax = new AutoExpandVector<Double>(); 00358 AutoExpandVector<Integer> kmaxIndexes = new AutoExpandVector<Integer>(); 00359 00360 for(int p : points){ 00361 CMMPoint cmdp = cmmpoints.get(p); 00362 double con_p_Cj = getConnectionValue(cmmpoints.get(p), otherCid); 00363 double connection = cmdp.connectivity * con_p_Cj; 00364 if(initial){ 00365 cmdp.p.setMeasureValue("Connection to C"+otherCid, con_p_Cj); 00366 } 00367 00368 //connection 00369 if(kmax.size() < clusterConnectionMaxPoints || connection > kmax.get(kmax.size()-1)){ 00370 int index = 0; 00371 while(index < kmax.size() && connection < kmax.get(index)) { 00372 index++; 00373 } 00374 kmax.add(index, connection); 00375 kmaxIndexes.add(index, p); 00376 if(kmax.size() > clusterConnectionMaxPoints){ 00377 kmax.remove(kmax.size()-1); 00378 kmaxIndexes.add(kmaxIndexes.size()-1); 00379 } 00380 } 00381 } 00382 //connection 00383 for (int k = 0; k < kmax.size(); k++) { 00384 avgConnection+= kmax.get(k); 00385 } 00386 avgConnection/=kmax.size(); 00387 } 00388 00389 if(otherCid<connections.size()){ 00390 connections.set(otherCid, avgConnection); 00391 } 00392 else 00393 if(connections.size() == otherCid){ 00394 connections.add(avgConnection); 00395 } 00396 else 00397 System.out.println("Something is going really wrong with the connection listing!"+knnNeighbourhood+" "+tauConnection); 00398 } 00399 00400 00405 private void mergeCluster(int mergeID){ 00406 if(mergeID < gt0Clusters.size()){ 00407 //track merging (debugging) 00408 for (int i = 0; i < numGTClasses; i++) { 00409 if(mergeMap[i]==mergeID) 00410 mergeMap[i]=workclass; 00411 if(mergeMap[i]>mergeID) 00412 mergeMap[i]--; 00413 } 00414 GTCluster gtcMerge = gt0Clusters.get(mergeID); 00415 if(debug) 00416 System.out.println("Merging C"+gtcMerge.workclass+" into C"+workclass+ 00417 " with Con "+connections.get(mergeID)+" / "+gtcMerge.connections.get(workclass)); 00418 00419 00420 //update mapTrueLabelToWorkLabel 00421 mapTrueLabelToWorkLabel.put(gtcMerge.label, workclass); 00422 Iterator iterator = mapTrueLabelToWorkLabel.keySet().iterator(); 00423 while (iterator.hasNext()) { 00424 Integer key = (Integer)iterator.next(); 00425 //update pointer of already merged cluster 00426 int value = mapTrueLabelToWorkLabel.get(key); 00427 if(value == mergeID) 00428 mapTrueLabelToWorkLabel.put(key, workclass); 00429 if(value > mergeID) 00430 mapTrueLabelToWorkLabel.put(key, value-1); 00431 } 00432 00433 //merge points from B into A 00434 points.addAll(gtcMerge.points); 00435 clusterRepresentations.addAll(gtcMerge.clusterRepresentations); 00436 if(mergedWorkLabels==null){ 00437 mergedWorkLabels = new ArrayList<Integer>(); 00438 } 00439 mergedWorkLabels.add(gtcMerge.orgWorkClass); 00440 if(gtcMerge.mergedWorkLabels!=null) 00441 mergedWorkLabels.addAll(gtcMerge.mergedWorkLabels); 00442 00443 gt0Clusters.remove(mergeID); 00444 00445 //update workclass labels 00446 for(int c=mergeID; c < gt0Clusters.size(); c++){ 00447 gt0Clusters.get(c).workclass = c; 00448 } 00449 00450 //update knn distances 00451 calculateKnn(); 00452 for(int c=0; c < gt0Clusters.size(); c++){ 00453 gt0Clusters.get(c).connections.remove(mergeID); 00454 00455 //recalculate connection from other clusters to the new merged one 00456 gt0Clusters.get(c).calculateClusterConnection(workclass,false); 00457 //and from new merged one to other clusters 00458 gt0Clusters.get(workclass).calculateClusterConnection(c,false); 00459 } 00460 } 00461 else{ 00462 System.out.println("Merge indices are not valid"); 00463 } 00464 } 00465 } 00466 00467 00473 public CMM_GTAnalysis(Clustering trueClustering, ArrayList<DataPoint> points, boolean enableClassMerge){ 00474 if(debug) 00475 System.out.println("GT Analysis Debug Output"); 00476 00477 noiseErrorByModel = 0; 00478 pointErrorByModel = 0; 00479 if(!enableClassMerge){ 00480 tauConnection = 1.0; 00481 } 00482 00483 lamdaConn = -Math.log(lambdaConnRefXValue)/Math.log(2)/lambdaConnX; 00484 00485 this.gtClustering = trueClustering; 00486 00487 numPoints = points.size(); 00488 numDims = points.get(0).numAttributes()-1; 00489 numGTClusters = gtClustering.size(); 00490 00491 //init mappings between work and true labels 00492 mapTrueLabelToWorkLabel = new HashMap<Integer, Integer>(); 00493 00494 //set up base of new clustering 00495 gt0Clusters = new ArrayList<GTCluster>(); 00496 int numWorkClasses = 0; 00497 //create label to worklabel mapping as real labels can be just a set of unordered integers 00498 for (int i = 0; i < numGTClusters; i++) { 00499 int label = (int)gtClustering.get(i).getGroundTruth(); 00500 if(!mapTrueLabelToWorkLabel.containsKey(label)){ 00501 gt0Clusters.add(new GTCluster(numWorkClasses,label,i)); 00502 mapTrueLabelToWorkLabel.put(label,numWorkClasses); 00503 numWorkClasses++; 00504 } 00505 else{ 00506 gt0Clusters.get(mapTrueLabelToWorkLabel.get(label)).clusterRepresentations.add(i); 00507 } 00508 } 00509 numGTClasses = numWorkClasses; 00510 00511 mergeMap = new int[numGTClasses]; 00512 for (int i = 0; i < numGTClasses; i++) { 00513 mergeMap[i]=i; 00514 } 00515 00516 //create cmd point wrapper instances 00517 cmmpoints = new ArrayList<CMMPoint>(); 00518 for (int p = 0; p < points.size(); p++) { 00519 CMMPoint cmdp = new CMMPoint(points.get(p), p); 00520 cmmpoints.add(cmdp); 00521 } 00522 00523 00524 //split points up into their GTClusters and Noise (according to class labels) 00525 noise = new ArrayList<Integer>(); 00526 for (int p = 0; p < numPoints; p++) { 00527 if(cmmpoints.get(p).isNoise()){ 00528 noise.add(p); 00529 } 00530 else{ 00531 gt0Clusters.get(cmmpoints.get(p).workclass()).points.add(p); 00532 } 00533 } 00534 00535 //calculate initial knnMean and knnDev 00536 for (GTCluster gtc : gt0Clusters) { 00537 gtc.calculateKnn(); 00538 } 00539 00540 //calculate cluster connections 00541 calculateGTClusterConnections(); 00542 00543 //calculate point connections with own clusters 00544 calculateGTPointQualities(); 00545 00546 if(debug) 00547 System.out.println("GT Analysis Debug End"); 00548 00549 } 00550 00558 //TODO: Cache the connection value for a point to the different clusters??? 00559 protected double getConnectionValue(CMMPoint cmmp, int clusterID){ 00560 AutoExpandVector<Double> knnDist = new AutoExpandVector<Double>(); 00561 AutoExpandVector<Integer> knnPointIndex = new AutoExpandVector<Integer>(); 00562 00563 //calculate the knn distance of the point to the cluster 00564 getKnnInCluster(cmmp, knnNeighbourhood, gt0Clusters.get(clusterID).points, knnDist, knnPointIndex); 00565 00566 //TODO: What to do if we have less then k neighbors? 00567 double avgDist = 0; 00568 for (int i = 0; i < knnDist.size(); i++) { 00569 avgDist+= knnDist.get(i); 00570 } 00571 //what to do if we only have a single point??? 00572 if(knnDist.size()!=0) 00573 avgDist/=knnDist.size(); 00574 else 00575 return 0; 00576 00577 //get the upper knn distance of the cluster 00578 double upperKnn = gt0Clusters.get(clusterID).knnMeanAvg + gt0Clusters.get(clusterID).knnDevAvg; 00579 00580 /* calculate the connectivity based on knn distance of the point within the cluster 00581 and the upper knn distance of the cluster*/ 00582 if(avgDist < upperKnn){ 00583 return 1; 00584 } 00585 else{ 00586 //value that should be reached at upperKnn distance 00587 //Choose connection formula 00588 double conn; 00589 if(useExpConnectivity) 00590 conn = Math.pow(2,-lamdaConn*(avgDist-upperKnn)/upperKnn); 00591 else 00592 conn = upperKnn/avgDist; 00593 00594 if(Double.isNaN(conn)) 00595 System.out.println("Connectivity NaN at "+cmmp.p.getTimestamp()); 00596 00597 return conn; 00598 } 00599 } 00600 00601 00609 private void getKnnInCluster(CMMPoint cmmp, int k, 00610 ArrayList<Integer> pointIDs, 00611 AutoExpandVector<Double> knnDist, 00612 AutoExpandVector<Integer> knnPointIndex) { 00613 00614 //iterate over every point in the choosen cluster, cal distance and insert into list 00615 for (int p1 = 0; p1 < pointIDs.size(); p1++) { 00616 int pid = pointIDs.get(p1); 00617 if(cmmp.pID == pid) continue; 00618 double dist = distance(cmmp,cmmpoints.get(pid)); 00619 if(knnDist.size() < k || dist < knnDist.get(knnDist.size()-1)){ 00620 int index = 0; 00621 while(index < knnDist.size() && dist > knnDist.get(index)) { 00622 index++; 00623 } 00624 knnDist.add(index, dist); 00625 knnPointIndex.add(index,pid); 00626 if(knnDist.size() > k){ 00627 knnDist.remove(knnDist.size()-1); 00628 knnPointIndex.remove(knnPointIndex.size()-1); 00629 } 00630 } 00631 } 00632 } 00633 00634 00635 00639 private void calculateGTPointQualities(){ 00640 for (int p = 0; p < numPoints; p++) { 00641 CMMPoint cmdp = cmmpoints.get(p); 00642 if(!cmdp.isNoise()){ 00643 cmdp.connectivity = getConnectionValue(cmdp, cmdp.workclass()); 00644 cmdp.p.setMeasureValue("Connectivity", cmdp.connectivity); 00645 } 00646 } 00647 } 00648 00649 00650 00655 private void calculateGTClusterConnections(){ 00656 for (int c0 = 0; c0 < gt0Clusters.size(); c0++) { 00657 for (int c1 = 0; c1 < gt0Clusters.size(); c1++) { 00658 gt0Clusters.get(c0).calculateClusterConnection(c1, true); 00659 } 00660 } 00661 00662 boolean changedConnection = true; 00663 while(changedConnection){ 00664 if(debug){ 00665 System.out.println("Cluster Connection"); 00666 for (int c = 0; c < gt0Clusters.size(); c++) { 00667 System.out.print("C"+gt0Clusters.get(c).label+" --> "); 00668 for (int c1 = 0; c1 < gt0Clusters.get(c).connections.size(); c1++) { 00669 System.out.print(" C"+gt0Clusters.get(c1).label+": "+gt0Clusters.get(c).connections.get(c1)); 00670 } 00671 System.out.println(""); 00672 } 00673 System.out.println(""); 00674 } 00675 00676 double max = 0; 00677 int maxIndexI = -1; 00678 int maxIndexJ = -1; 00679 00680 changedConnection = false; 00681 for (int c0 = 0; c0 < gt0Clusters.size(); c0++) { 00682 for (int c1 = c0+1; c1 < gt0Clusters.size(); c1++) { 00683 if(c0==c1) continue; 00684 double min =Math.min(gt0Clusters.get(c0).connections.get(c1), gt0Clusters.get(c1).connections.get(c0)); 00685 if(min > max){ 00686 max = min; 00687 maxIndexI = c0; 00688 maxIndexJ = c1; 00689 } 00690 } 00691 } 00692 if(maxIndexI!=-1 && max > tauConnection){ 00693 gt0Clusters.get(maxIndexI).mergeCluster(maxIndexJ); 00694 if(debug) 00695 System.out.println("Merging "+maxIndexI+" and "+maxIndexJ+" because of connection "+max); 00696 00697 changedConnection = true; 00698 } 00699 } 00700 numGT0Classes = gt0Clusters.size(); 00701 } 00702 00703 00709 public double getClassSeparability(){ 00710 // int totalConn = numGTClasses*(numGTClasses-1)/2; 00711 // int mergedConn = 0; 00712 // for(GTCluster gt : gt0Clusters){ 00713 // int merged = gt.clusterRepresentations.size(); 00714 // if(merged > 1) 00715 // mergedConn+=merged * (merged-1)/2; 00716 // } 00717 // if(totalConn == 0) 00718 // return 0; 00719 // else 00720 // return 1-mergedConn/(double)totalConn; 00721 return numGT0Classes/(double)numGTClasses; 00722 00723 } 00724 00725 00731 public double getNoiseSeparability(){ 00732 if(noise.isEmpty()) 00733 return 1; 00734 00735 double connectivity = 0; 00736 for(int p : noise){ 00737 CMMPoint npoint = cmmpoints.get(p); 00738 double maxConnection = 0; 00739 00740 //TODO: some kind of pruning possible. what about weighting? 00741 for (int c = 0; c < gt0Clusters.size(); c++) { 00742 double connection = getConnectionValue(npoint, c); 00743 if(connection > maxConnection) 00744 maxConnection = connection; 00745 } 00746 connectivity+=maxConnection; 00747 npoint.p.setMeasureValue("MaxConnection", maxConnection); 00748 } 00749 00750 return 1-(connectivity / noise.size()); 00751 } 00752 00753 00758 public double getModelQuality(){ 00759 for(int p = 0; p < numPoints; p++){ 00760 CMMPoint cmdp = cmmpoints.get(p); 00761 for(int hc = 0; hc < numGTClusters;hc++){ 00762 if(gtClustering.get(hc).getGroundTruth() != cmdp.trueClass){ 00763 if(gtClustering.get(hc).getInclusionProbability(cmdp) >= 1){ 00764 if(!cmdp.isNoise()) 00765 pointErrorByModel++; 00766 else 00767 noiseErrorByModel++; 00768 break; 00769 } 00770 } 00771 } 00772 } 00773 if(debug) 00774 System.out.println("Error by model: noise "+noiseErrorByModel+" point "+pointErrorByModel); 00775 00776 return 1-((pointErrorByModel + noiseErrorByModel)/(double) numPoints); 00777 } 00778 00779 00785 protected CMMPoint getPoint(int index){ 00786 return cmmpoints.get(index); 00787 } 00788 00789 00795 protected GTCluster getGT0Cluster(int index){ 00796 return gt0Clusters.get(index); 00797 } 00798 00803 protected int getNumberOfGT0Classes() { 00804 return numGT0Classes; 00805 } 00806 00813 private double distance(Instance inst1, Instance inst2){ 00814 return distance(inst1, inst2.toDoubleArray()); 00815 00816 } 00817 00824 private double distance(Instance inst1, double[] inst2){ 00825 double distance = 0.0; 00826 for (int i = 0; i < numDims; i++) { 00827 double d = inst1.value(i) - inst2[i]; 00828 distance += d * d; 00829 } 00830 return Math.sqrt(distance); 00831 } 00832 00837 public String getParameterString(){ 00838 String para = ""; 00839 para+="k="+knnNeighbourhood+";"; 00840 if(useExpConnectivity){ 00841 para+="lambdaConnX="+lambdaConnX+";"; 00842 para+="lambdaConn="+lamdaConn+";"; 00843 para+="lambdaConnRef="+lambdaConnRefXValue+";"; 00844 } 00845 para+="m="+clusterConnectionMaxPoints+";"; 00846 para+="tauConn="+tauConnection+";"; 00847 00848 return para; 00849 } 00850 } 00851 00852