-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathalg_test.py
More file actions
84 lines (64 loc) · 2.02 KB
/
alg_test.py
File metadata and controls
84 lines (64 loc) · 2.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from pig.scripts.PartitionAlg import PartitionAlg, DataPoint, sum_squares
from random import random
import time
def print_clusters(clusters):
    """Print every cluster as a "Cluster:" header followed by its point ids.

    Each id is written as ``"<id>, "`` (the trailing separator is kept even
    after the last id), all on one line, and that line is ended with a
    newline. A cluster with no points therefore prints an empty line.
    """
    for group in clusters:
        print("Cluster:")
        id_line = "".join("{}, ".format(member.id) for member in group.points)
        print(id_line)
def get_stats(clusters):
    """Return the smallest and largest cluster sizes as a ``(min, max)`` tuple.

    A cluster's size is ``len(cluster.points)``.

    Args:
        clusters: iterable of objects exposing a ``points`` sequence.

    Returns:
        ``(smallest_size, largest_size)``; ``(0, 0)`` when ``clusters`` is
        empty, so callers never print an ``inf`` sentinel.
    """
    # Use the builtins instead of shadowing them with local variables.
    sizes = [len(cluster.points) for cluster in clusters]
    if not sizes:
        return 0, 0
    return min(sizes), max(sizes)
# Benchmark configuration, read as module globals by test_performance and
# test_accuracy.
n, n_points, cluster_size = (
    10,   # length of each random point vector
    100,  # number of DataPoints to generate
    6,    # target cluster size passed to PartitionAlg.k_means / normalize
)
def test_performance():
    """Time partitioning + normalization over 100 growing random data sets.

    Round ``k`` (0-based) builds ``n_points + 100 * k`` DataPoints, each
    holding a vector of ``n`` values from ``random()``. It passes them and
    ``cluster_size`` to ``PartitionAlg.k_means``, then runs
    ``PartitionAlg.normalize`` on the result. Finally it prints the elapsed
    time for those two calls together with the min/max cluster sizes
    reported by ``get_stats``.

    The growing point count is kept in a local variable. The module-level
    ``n_points`` is therefore left unchanged, and a later call such as
    ``test_accuracy`` still sees the configured value.
    """
    size = n_points
    for _ in range(100):
        points = [DataPoint(pid, [random() for _ in range(n)]) for pid in range(size)]

        # perf_counter is monotonic and high-resolution, which suits
        # elapsed-time measurement better than time.time().
        start = time.perf_counter()
        clusters = PartitionAlg.k_means(points, cluster_size)
        PartitionAlg.normalize(clusters, cluster_size)
        elapsed = time.perf_counter() - start

        smallest, largest = get_stats(clusters)
        print("{} students: {} group size min, max: ({}, {})".format(size, elapsed, smallest, largest))
        size += 100
def test_accuracy():
    """Partition random points, normalize the clusters, and report their spread.

    Prints the clusters produced by ``PartitionAlg.k_means`` and prints them
    again after ``PartitionAlg.normalize``. It then asserts that every
    cluster except the last holds exactly ``cluster_size`` points. Finally,
    for each cluster it prints the accumulated ``sum_squares`` between the
    cluster's ``mean`` and each of its points.
    """
    print("\nPartitioning\n")
    # n_points DataPoints, each carrying a vector of n random() values.
    samples = [DataPoint(pid, [random() for _ in range(n)]) for pid in range(n_points)]

    groups = PartitionAlg.k_means(samples, cluster_size)
    print_clusters(groups)

    print("\nNormalizing\n")
    PartitionAlg.normalize(groups, cluster_size)
    print_clusters(groups)

    # The last cluster is excluded from the size check. Presumably it takes
    # the remainder when n_points is not a multiple of cluster_size.
    for group in groups[:-1]:
        assert len(group.points) == cluster_size

    for idx, group in enumerate(groups):
        spread = 0
        for member in group.points:
            spread += sum_squares(group.mean, member.point)
        print(f"Cluster {idx} variation: {spread}")
# Run the accuracy check only when this file is executed as a script.
# Importing the module (pytest collects *_test.py files by default) must not
# trigger the run as a side effect.
if __name__ == "__main__":
    test_accuracy()