-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot.py
More file actions
114 lines (102 loc) · 4.1 KB
/
plot.py
File metadata and controls
114 lines (102 loc) · 4.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import matplotlib.pyplot as plt
import numpy as np
def _nearest_centroid_labels(points, centroids):
    """Return, for each row of ``points``, the index of its closest centroid (Euclidean)."""
    # Only the first D centroid coordinates are compared, matching what gets plotted.
    offsets = points[:, np.newaxis, :] - centroids[np.newaxis, :, : points.shape[1]]
    return np.argmin(np.linalg.norm(offsets, axis=2), axis=1)


def _scatter_columns(points, dimension):
    """Split ``points`` into the per-axis coordinate arrays expected by ``ax.scatter``."""
    # A single centroid arrives as a 1-D vector; treat it as one row.
    points = np.atleast_2d(points)
    columns = [points[:, axis] for axis in range(dimension)]
    if dimension == 1:
        # 1-D data is drawn along the x axis, so every point gets a zero y coordinate.
        columns.append(np.zeros(points.shape[0]))
    return columns


def plot(centroids, X_train, X_test=None, y_train=None, y_test=None, projection=False):
    """
    Plots the data points and centroids of the clusters for the kmeans library

    For every centroid, three scatter series are drawn in one shared colour: the
    training samples of that cluster, the centroid itself (large marker) and,
    when a test set is given, the test samples of that cluster ("x" marker).
    One-dimensional data is drawn along the x axis with y fixed at 0.

    :param centroids: The centroids obtained after kmeans training, one row per cluster
    :type centroids: list or np.ndarray
    :param X_train: The training set, a numpy array of shape (N, D) containing N examples with D dimensions
    :type X_train: list or np.ndarray
    :param X_test: (Optional) The testing set, a numpy array of shape (N, D) containing N examples with D dimensions
    :type X_test: list or np.ndarray
    :param y_train: (Optional) Label of cluster for each training sample. When omitted,
        each sample is assigned to its nearest centroid (Euclidean distance)
    :type y_train: list or np.ndarray
    :param y_test: (Optional) Label of cluster for each test sample. When omitted,
        each sample is assigned to its nearest centroid (Euclidean distance)
    :type y_test: list or np.ndarray
    :param projection: (Optional) Plot in either 2D or 3D space (True = Plot in 3D).
        Three-dimensional data is always drawn on 3D axes
    :type projection: bool
    :return: Matplotlib figure and axis object
    :raises AssertionError: If an argument has an unsupported type, the test set
        dimension differs from the training set, or D is not between 1 and 3
    """
    assert X_train is not None, "X_train is required"
    assert isinstance(X_train, (np.ndarray, list)), "X_train must be a list or np.ndarray"
    assert isinstance(projection, bool), "projection must be a bool"
    for name, value in (("X_test", X_test), ("y_train", y_train), ("y_test", y_test)):
        assert value is None or isinstance(
            value, (np.ndarray, list)
        ), f"{name} must be a list or np.ndarray"

    X_train = np.asarray(X_train)
    if X_test is not None:
        X_test = np.asarray(X_test)
    if y_train is not None:
        y_train = np.asarray(y_train)
    if y_test is not None:
        y_test = np.asarray(y_test)

    dimension = X_train.shape[1]
    if X_test is not None:
        assert dimension == X_test.shape[1], "X_test and X_train dimensions differ"
    assert 1 <= dimension <= 3, "only 1-D, 2-D and 3-D data can be plotted"

    # Accept list-of-lists centroids as well as arrays for every dimension.
    centroids = np.asarray(centroids)

    # Without labels the boolean mask below would select nothing, so infer the
    # cluster of each sample from the nearest centroid instead.
    if len(centroids) > 0:
        if y_train is None:
            y_train = _nearest_centroid_labels(X_train, centroids)
        if X_test is not None and y_test is None:
            y_test = _nearest_centroid_labels(X_test, centroids)

    fig = plt.figure()
    use_3d = projection or dimension == 3
    ax = fig.add_subplot(projection="3d" if use_3d else "rectilinear")

    for index, centroid in enumerate(centroids):
        # Artist order (train, centroid, test) matters: the legend labels below
        # are matched positionally to the first artists added to the axes.
        train = ax.scatter(*_scatter_columns(X_train[y_train == index], dimension))
        # Reuse the automatically chosen colour so a cluster shares one colour.
        color = train.get_edgecolor()
        ax.scatter(*_scatter_columns(centroid, dimension), s=200, c=color)
        if X_test is not None:
            ax.scatter(
                *_scatter_columns(X_test[y_test == index], dimension),
                marker="x",
                c=color,
            )

    if X_test is not None:
        legend = ax.legend(["Train Samples", "Centroids", "Test Samples"])
    else:
        legend = ax.legend([" Samples", "Centroids"])

    # ``legendHandles`` was renamed to ``legend_handles`` in Matplotlib 3.7 and
    # the old name was removed in 3.9; support both.
    handles = getattr(legend, "legend_handles", None)
    if handles is None:
        handles = legend.legendHandles
    # Legend markers use one neutral colour (presumably so they do not suggest
    # a particular cluster's colour).
    for handle in handles:
        handle.set_color("brown")
    return fig, ax