New scripts to cluster based on C->S/S->C direction filter; improving the plot's...
[pingpong.git] / python_ml / plotting-dbscan-metric.py
1 from sklearn.cluster import DBSCAN
2 from sklearn import metrics
3 import sys
4 import math
5 import matplotlib.cm as cm
6 import numpy as np
7 import matplotlib.pyplot as plt
8
9 # metric function for clustering
10 def metric(x, y):
11         # Compare 2 datapoints in array element 2 and 3 that contains C or S
12         if x[2] != y[2] or x[3] != y[3]:
13                 # We are not going to cluster these together since they have different directions
14                 return sys.maxsize;
15         else:
16                 # Compute Euclidian distance here
17                 return math.sqrt((x[0] - y[0])**2 + (x[1] - y[1])**2)
18
19 # Create a subplot with 1 row and 2 columns
20 fig, (ax2) = plt.subplots(1, 1)
21 fig.set_size_inches(20, 20)
22
23
24 # Read from file
25 # TODO: Just change the following path and filename 
26 #       when needed to read from a different file
27 path = "/scratch/July-2018/Pairs3/"
28 device = "kwikset-off-phone-side"
29 filename = device + ".txt"
30 plt.ylim(0, 2000)
31 plt.xlim(0, 2000)
32
33 # Number of triggers
34 trig = 50
35
36 # Read and create an array of pairs
37 with open(path + filename, "r") as pairs:
38         pairsArr = []
39         pairsSrcLabels = []
40         for line in pairs:
41                 # We will see a pair and we need to split it into xpoint and ypoint
42                 xpoint, ypoint, srcHost1, srcHost2, src1, src2 = line.split(", ")
43                 # Assign 1000 for client and 0 for server to create distance
44                 src1Val = 1000 if src1 == 'C' else 0
45                 src2Val = 1000 if src2 == 'C' else 0
46                 pair = [int(xpoint), int(ypoint), int(src1Val), int(src2Val)]
47                 pairSrc = [int(xpoint), int(ypoint), srcHost1, srcHost2, src1, src2]
48                 # Array of actual points
49                 pairsArr.append(pair)
50                 # Array of source labels
51                 pairsSrcLabels.append(pairSrc)
52
53 # Formed array of pairs         
54 #print(pairsArr)
55 X = np.array(pairsArr);
56
57 # Compute DBSCAN
58 # eps = distances
59 # min_samples = minimum number of members of a cluster
60 #db = DBSCAN(eps=20, min_samples=trig - 5).fit(X)
61 # TODO: This is just for seeing more clusters
62 db = DBSCAN(eps=20, min_samples=trig - 45, metric=metric).fit(X)
63 core_samples_mask = np.zeros_like(db.labels_, dtype=bool)
64 core_samples_mask[db.core_sample_indices_] = True
65 labels = db.labels_
66
67 # Number of clusters in labels, ignoring noise if present.
68 n_clusters_ = len(set(labels)) - (1 if -1 in labels else 0)
69
70 #print('Estimated number of clusters: %d' % n_clusters_)
71
72 import matplotlib.pyplot as plt
73
74 # Black removed and is used for noise instead.
75 unique_labels = set(labels)
76 #print("Labels: " + str(labels))
77
78 colors = [plt.cm.Spectral(each)
79           for each in np.linspace(0, 1, len(unique_labels))]
80 for k, col in zip(unique_labels, colors):
81     cluster_col = [1, 0, 0, 1]
82     if k == -1:
83         # Black used for noise.
84         col = [0, 0, 0, 1]
85
86     class_member_mask = (labels == k)
87
88     xy = X[class_member_mask & core_samples_mask]
89     plt.plot(xy[:, 0], xy[:, 1], 'o', markerfacecolor=tuple(cluster_col),
90              markeredgecolor='k', markersize=10)
91
92     xy = X[class_member_mask & ~core_samples_mask]
93     plt.plot(xy[:, 0], xy[:, 1], 'o', markerfacecolor=tuple(col),
94              markeredgecolor='k', markersize=6)
95
96 # Print lengths
97 count = 0
98 for pair in pairsArr:
99         if labels[count] == -1:
100                 plt.text(pair[0], pair[1], str(pair[0]) + ", " + str(pair[1]), fontsize=10)
101         else:
102         # Only print the frequency when this is a real cluster
103                 plt.text(pair[0], pair[1], str(pair[0]) + ", " + str(pair[1]) + 
104                         " f: " + str(labels.tolist().count(labels[count])), fontsize=10)
105         count = count + 1
106
107 # Print source-destination labels
108 count = 0
109 for pair in pairsSrcLabels:
110         # Only print the frequency when this is a real cluster
111         plt.text(pair[0], pair[1], str(pair[4]) + "->" + str(pair[5]))
112         count = count + 1
113         
114 plt.title(device + ' - Clusters: %d' % n_clusters_)
115 plt.show()
116