Improving colors and looks of graph plots.
[pingpong.git] / python_ml / silhouette.py
1 from __future__ import print_function
2
3 from sklearn.datasets import make_blobs
4 from sklearn.cluster import KMeans
5 from sklearn.metrics import silhouette_samples, silhouette_score
6
7 import matplotlib.pyplot as plt
8 import matplotlib.cm as cm
9 import numpy as np
10
11 print(__doc__)
12
13 # Generating the sample data from make_blobs
14 # This particular setting has one distinct cluster and 3 clusters placed close
15 # together.
16 '''X, y = make_blobs(n_samples=500,
17                   n_features=2,
18                   centers=4,
19                   cluster_std=1,
20                   center_box=(-10.0, 10.0),
21                   shuffle=True,
22                   random_state=1)  # For reproducibility'''
23
24 X = np.array([[132, 192], [117, 960], [117, 962], [1343, 0], [117, 1116], [117, 1117], [117, 1118], [117, 1119], [1015, 0], [117, 966]])
25
26 range_n_clusters = [2, 3, 4, 5, 6]
27
28 for n_clusters in range_n_clusters:
29     # Create a subplot with 1 row and 2 columns
30     fig, (ax1, ax2) = plt.subplots(1, 2)
31     fig.set_size_inches(18, 7)
32
33     # The 1st subplot is the silhouette plot
34     # The silhouette coefficient can range from -1, 1 but in this example all
35     # lie within [-0.1, 1]
36     ax1.set_xlim([-0.1, 1])
37     # The (n_clusters+1)*10 is for inserting blank space between silhouette
38     # plots of individual clusters, to demarcate them clearly.
39     ax1.set_ylim([0, len(X) + (n_clusters + 1) * 10])
40
41     # Initialize the clusterer with n_clusters value and a random generator
42     # seed of 10 for reproducibility.
43     clusterer = KMeans(n_clusters=n_clusters, random_state=10)
44     cluster_labels = clusterer.fit_predict(X)
45
46     # The silhouette_score gives the average value for all the samples.
47     # This gives a perspective into the density and separation of the formed
48     # clusters
49     silhouette_avg = silhouette_score(X, cluster_labels)
50     print("For n_clusters =", n_clusters,
51           "The average silhouette_score is :", silhouette_avg)
52
53     # Compute the silhouette scores for each sample
54     sample_silhouette_values = silhouette_samples(X, cluster_labels)
55
56     y_lower = 10
57     for i in range(n_clusters):
58         # Aggregate the silhouette scores for samples belonging to
59         # cluster i, and sort them
60         ith_cluster_silhouette_values = \
61             sample_silhouette_values[cluster_labels == i]
62
63         ith_cluster_silhouette_values.sort()
64
65         size_cluster_i = ith_cluster_silhouette_values.shape[0]
66         y_upper = y_lower + size_cluster_i
67
68         color = cm.nipy_spectral(float(i) / n_clusters)
69         ax1.fill_betweenx(np.arange(y_lower, y_upper),
70                           0, ith_cluster_silhouette_values,
71                           facecolor=color, edgecolor=color, alpha=0.7)
72
73         # Label the silhouette plots with their cluster numbers at the middle
74         ax1.text(-0.05, y_lower + 0.5 * size_cluster_i, str(i))
75
76         # Compute the new y_lower for next plot
77         y_lower = y_upper + 10  # 10 for the 0 samples
78
79     ax1.set_title("The silhouette plot for the various clusters.")
80     ax1.set_xlabel("The silhouette coefficient values")
81     ax1.set_ylabel("Cluster label")
82
83     # The vertical line for average silhouette score of all the values
84     ax1.axvline(x=silhouette_avg, color="red", linestyle="--")
85
86     ax1.set_yticks([])  # Clear the yaxis labels / ticks
87     ax1.set_xticks([-0.1, 0, 0.2, 0.4, 0.6, 0.8, 1])
88
89     # 2nd Plot showing the actual clusters formed
90     colors = cm.nipy_spectral(cluster_labels.astype(float) / n_clusters)
91     ax2.scatter(X[:, 0], X[:, 1], marker='.', s=30, lw=0, alpha=0.7,
92                 c=colors, edgecolor='k')
93
94     # Labeling the clusters
95     centers = clusterer.cluster_centers_
96     # Draw white circles at cluster centers
97     ax2.scatter(centers[:, 0], centers[:, 1], marker='o',
98                 c="white", alpha=1, s=200, edgecolor='k')
99
100     for i, c in enumerate(centers):
101         ax2.scatter(c[0], c[1], marker='$%d$' % i, alpha=1,
102                     s=50, edgecolor='k')
103
104     ax2.set_title("The visualization of the clustered data.")
105     ax2.set_xlabel("Feature space for the 1st feature")
106     ax2.set_ylabel("Feature space for the 2nd feature")
107
108     plt.suptitle(("Silhouette analysis for KMeans clustering on sample data "
109                   "with n_clusters = %d" % n_clusters),
110                  fontsize=14, fontweight='bold')
111
112     plt.show()
113