-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtoyexample.py
More file actions
29 lines (24 loc) · 1.05 KB
/
toyexample.py
File metadata and controls
29 lines (24 loc) · 1.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import numpy as np
import matplotlib.pyplot as plt
import cv2
import kmeans as km
# Toy 2-D dataset: 15 points that visibly form three groups.
X = np.array([[2, 8], [9, 7], [1, 9], [4, 1], [0, 8], [6, 2], [5, 0], [5, 1],
              [10, 10], [8, 10], [0, 7], [1, 6], [9, 9], [8, 8], [5, 2]])
k = 3          # number of clusters
max_iter = 5   # number of assign/update rounds to run (one plot per round)
# Marker colour per cluster label; indexed modulo its length below so that
# choosing k > len(colors) reuses colours instead of raising IndexError.
colors = ['red', 'green', 'blue']

centroids = km.initialize_centroids(X, k)
for iteration in range(max_iter):
    # Assignment step: km.find_closest_centroids gives, per row of X, the index
    # of its nearest centroid (NOTE(review): assumed 1-D, length len(X)).
    closest = km.find_closest_centroids(X, centroids)
    for cluster in range(k):
        # Select this cluster's members with a boolean mask. Unlike the former
        # "multiply by mask, then drop all-zero rows" approach, this keeps
        # points that lie exactly at the origin (0, 0).
        members = X[np.equal(closest, cluster)]
        plt.scatter(members[:, 0], members[:, 1],
                    color=colors[cluster % len(colors)], marker='o')
    # Overlay the centroids used for this round's assignment.
    plt.scatter(centroids[:, 0], centroids[:, 1], color='black', marker='+', s=100)
    plt.show()
    # Update step: recompute the centroids from the current assignment.
    centroids = km.update_centroids(X, k, closest)
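################################################################################
# Expected result for this dataset. Label numbers are arbitrary (they presumably
# depend on km.initialize_centroids), so another run may permute them.
# ground-truth clusters [0 1 0 2 0 2 2 2 1 1 0 0 1 1 2]
# ground-truth centroids, rows listed for labels 2, 1, 0 (in that order):
# [[5.  1.2]
#  [8.8 8.8]
#  [0.8 7.6]]
################################################################################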