-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path train.py
More file actions
70 lines (47 loc) · 1.18 KB
/
train.py
File metadata and controls
70 lines (47 loc) · 1.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from cost_funcs import CrossEntropyCost
from network import Network
from matplotlib import gridspec, pyplot
import pickle
# Train a fully connected network with mini-batch SGD, save the per-epoch
# metrics to progress_data.pkl, then plot the accuracy and cost curves.

# Hyperparameters.
epochs = 60
mini_batch_size = 10
learning_rate = 0.1
# Spelled `lmbda` because `lambda` is a reserved word in Python; presumably
# the regularization strength -- TODO confirm against Network.SGD.
lmbda = 5.0

cost = CrossEntropyCost()
# Presumably the layer sizes (input, hidden, output): 784 = 28 * 28 MNIST
# pixels and 10 digit classes -- verify against Network.__init__.
net = Network([784, 100, 10], cost)

# NOTE(review): the four trailing positional flags are passed through
# unchanged; they presumably enable monitoring of the four returned metric
# series. Confirm the parameter order against Network.SGD and consider
# passing them by keyword so the call is self-describing.
ea, ec, ta, tc = net.SGD(
    "mnist/data/training.gz",
    epochs,
    mini_batch_size,
    learning_rate,
    lmbda,
    "mnist/data/testing.gz",
    True,
    True,
    True,
    True,
)

# Persist the metrics (evaluation accuracy, evaluation cost, training
# accuracy, training cost) so the curves can be re-plotted without retraining.
# The dump must live inside the `with` so the file is flushed and closed.
with open("progress_data.pkl", "wb") as f:
    pickle.dump((ea, ec, ta, tc), f)

# One x position per epoch; assumes SGD returns one value per epoch for
# every series.
x = list(range(epochs))

# Three stacked full-width rows sharing one x axis; only the bottom row
# shows the epoch tick labels.
gs = gridspec.GridSpec(3, 1)
fig = pyplot.figure()
acc_ax = fig.add_subplot(gs[0, 0])
eval_cost_ax = fig.add_subplot(gs[1, 0], sharex=acc_ax)
train_cost_ax = fig.add_subplot(gs[2, 0], sharex=acc_ax)

# Row 1: evaluation and training accuracy together.
acc_ax.plot(x, ea, label="evaluation accuracy")
acc_ax.plot(x, ta, label="training accuracy")
acc_ax.set_ylabel("accuracy")
acc_ax.legend()
acc_ax.tick_params(labelbottom=False)

# Row 2: evaluation cost (kept on its own row, as in the original layout).
eval_cost_ax.plot(x, ec, label="evaluation cost")
eval_cost_ax.set_ylabel("cost")
eval_cost_ax.legend()
eval_cost_ax.tick_params(labelbottom=False)

# Row 3: training cost, carrying the shared epoch axis labels.
train_cost_ax.plot(x, tc, label="training cost")
train_cost_ax.set_xlabel("epochs")
train_cost_ax.set_ylabel("cost")
train_cost_ax.legend()

pyplot.show()
pyplot.close(fig)