-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
63 lines (50 loc) · 1.86 KB
/
train.py
File metadata and controls
63 lines (50 loc) · 1.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from loader import Loader
from config import *
from model import Network
from outer import Outer
import tensorflow as tf
import numpy as np
import time
def train(loader, config, network_upside=None, network_downside=None):
    """Train the upside and downside networks for ``config.train_episodes`` rounds.

    Each round draws a fresh batch of ``config.train_buffer_size`` samples via
    ``loader.sample`` and calls ``train`` once on each network.

    Args:
        loader: object whose ``sample(n)`` returns the 4-tuple
            ``(input_upside, input_downside, output_upside, output_downside)``.
        config: object providing ``train_episodes`` and ``train_buffer_size``.
        network_upside: network fed the upside buffers. Defaults to the
            module-level ``network_upside`` created in the ``__main__`` section.
        network_downside: network fed the downside buffers. Defaults to the
            module-level ``network_downside``.
    """
    # Fall back to the module-level networks so existing
    # ``train(loader, config)`` calls keep working unchanged.
    if network_upside is None:
        network_upside = globals()["network_upside"]
    if network_downside is None:
        network_downside = globals()["network_downside"]

    for _ in range(config.train_episodes):
        (input_upside_buffer, input_downside_buffer,
         output_upside_buffer, output_downside_buffer) = loader.sample(
            config.train_buffer_size)
        network_upside.train(input_upside_buffer, output_upside_buffer)
        network_downside.train(input_downside_buffer, output_downside_buffer)
    print("train done.")
def test(loader, config):
    """Evaluate the upside and downside networks on the loader's test set.

    For each side, the input and output buffers from ``loader.test_set`` are
    reshaped to ``(-1, 128, 128, 128, 1)`` and ``(-1, 16 * 7)`` respectively,
    a "The loss of ..." header is printed, and the module-level network's
    ``test`` method is called with the resulting dict. Reporting the loss value
    itself is presumably done inside ``Network.test`` (not visible here).

    Args:
        loader: object with a ``test_set`` mapping holding the keys
            ``input_upside_buffer``, ``output_upside_buffer``,
            ``input_downside_buffer`` and ``output_downside_buffer``.
        config: unused; kept for interface compatibility.
    """
    test_set = loader.test_set
    _evaluate(network_upside,
              test_set["input_upside_buffer"],
              test_set["output_upside_buffer"],
              "The loss of upside:")
    _evaluate(network_downside,
              test_set["input_downside_buffer"],
              test_set["output_downside_buffer"],
              "The loss of downside:")


def _evaluate(network, images, targets, label,
              image_shape=(-1, 128, 128, 128, 1),
              target_shape=(-1, 16 * 7)):
    """Reshape one side's test buffers, print *label*, then run ``network.test``."""
    # A fresh dict per side, so one network never sees data meant for the other
    # even if it keeps a reference to the dict it was given.
    data = {
        "image_input": np.reshape(images, image_shape),
        "standard_mat": np.reshape(targets, target_shape),
    }
    print(label)
    network.test(data)
if __name__ == "__main__":
    import sys

    # numpy config: print arrays in full instead of summarising them with "...".
    # NumPy rejects a NaN threshold (ValueError); sys.maxsize is the documented
    # value for an untruncated representation.
    np.set_printoptions(threshold=sys.maxsize)
    # initialize the setting and model
    config = Config()
    loader = Loader(128, 1000, config.training_set_percent)
    loader.read_data_file()
    loader.initialize_output()
    print(loader.output)
    loader.sets_apart()
    # about the model: one network per side, distinguished by name suffix
    network_upside = Network(config, "_upside")
    network_downside = Network(config, "_downside")
    # train & test
    if config.train:
        train(loader, config)
    if config.test:
        test(loader, config)
    # call the Outer to export an Excel file
    output = Outer(config, loader, network_upside, network_downside)
    if config.output:
        output.out("tooth_result_clin3.xlsx")