-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathKMeansnet.py
More file actions
36 lines (29 loc) · 1.02 KB
/
KMeansnet.py
File metadata and controls
36 lines (29 loc) · 1.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import numpy as np
class Kmeansnet:
    """Competitive-learning ("k-means network") clusterer.

    Each cluster is represented by one weight vector (a row of
    ``self.weights``).  For every training sample the unit whose weight
    vector has the largest dot product with the sample "wins" and is moved
    towards that sample by a fraction ``eta`` of the difference.

    Attributes:
        data: Training samples, one per row.  Replaced by its row-normalised
            version the first time :meth:`train` runs.
        n_dim: Number of columns (features) in ``data``.
        num_clusters: Number of weight vectors / clusters.
        weights: Array of shape ``(num_clusters, n_dim)``, initialised with
            uniform random values in ``[0, 1)``.
        eta: Learning rate used in the weight update.
    """

    def __init__(self, data, clusters: int, eta: float):
        """Store the training data and randomly initialise the weights.

        Args:
            data: 2-D array-like of shape ``(n_samples, n_dim)``.
            clusters: Number of clusters (weight vectors) to learn.
            eta: Learning rate for the update rule in :meth:`train`.
        """
        self.data = np.asarray(data, dtype=float)
        self.n_dim = self.data.shape[1]
        self.num_clusters = clusters
        self.weights = np.random.rand(self.num_clusters, self.n_dim)
        self.eta = eta

    def calc_dist(self, inp, weights):
        """Return the dot product of ``inp`` with every row of ``weights``.

        Despite the name this is a similarity (activation), not a distance:
        a larger value means a closer match, which is why callers pick the
        winner with ``argmax``.

        Args:
            inp: 1-D sample of length ``n_dim``.
            weights: 2-D array whose rows are compared against ``inp``.

        Returns:
            1-D array with one activation per row of ``weights``.
        """
        return np.sum(weights * inp, axis=1)

    def normalise_data(self, data):
        """Scale every row of ``data`` to unit Euclidean length.

        Args:
            data: 2-D array-like, one sample per row.  It may have any number
                of rows; it does not have to match the training set.

        Returns:
            Float array of the same shape with unit-norm rows.  All-zero
            rows are returned unchanged rather than turned into NaNs.
        """
        data = np.asarray(data, dtype=float)
        # keepdims keeps the norms as a column so they broadcast per row,
        # independent of how many rows the training data has.
        norms = np.sqrt(np.sum(data ** 2, axis=1, keepdims=True))
        # A zero vector has no direction; dividing by 1 leaves it at zero.
        norms[norms == 0] = 1.0
        return data / norms

    def train(self, epochs: int):
        """Run online competitive learning over the stored training data.

        ``self.data`` is overwritten with its row-normalised form and
        ``self.weights`` is updated in place.

        Args:
            epochs: Number of full passes over the training samples.
        """
        self.data = self.normalise_data(self.data)
        for _ in range(epochs):
            for sample in self.data:
                winner = np.argmax(self.calc_dist(sample, self.weights))
                # Move only the winning unit towards the sample.  The
                # parentheses matter: without them the old weight cancels
                # out and the row is simply overwritten with eta * sample.
                self.weights[winner] += self.eta * (sample - self.weights[winner])

    def predict(self, inp):
        """Return the index of the cluster that best matches ``inp``.

        The winner is the weight row with the largest dot product, so
        scaling ``inp`` by a positive factor does not change the result.

        Args:
            inp: 1-D sample of length ``n_dim``.

        Returns:
            Integer index of the winning weight row.
        """
        return np.argmax(self.calc_dist(inp, self.weights))

    def predict_all(self, data):
        """Return the winning cluster index for every row of ``data``.

        Args:
            data: 2-D array-like, one sample per row.

        Returns:
            Float array of shape ``(n_samples, 1)`` holding cluster indices.
        """
        rows = np.asarray(data)
        winners = [self.predict(row) for row in rows]
        # Float dtype and column shape are kept for existing callers.
        return np.array(winners, dtype=float).reshape(-1, 1)