-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathtrain.py
More file actions
46 lines (37 loc) · 1.45 KB
/
train.py
File metadata and controls
46 lines (37 loc) · 1.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from __future__ import absolute_import, division, print_function
import tensorflow as tf
from config import EPOCHS, BATCH_SIZE, model_dir
from prepare_data import get_datasets
from models.alexnet import AlexNet
from models.vgg16 import VGG16
from models.vgg19 import VGG19
def get_model(model_name="vgg16"):
    """Build and compile one of the available classification networks.

    Args:
        model_name: Architecture to instantiate, matched case-insensitively.
            One of ``"alexnet"``, ``"vgg16"`` or ``"vgg19"``. Defaults to
            ``"vgg16"``, the architecture that used to be hard-coded here.

    Returns:
        The model, compiled with categorical cross-entropy loss, an Adam
        optimizer (learning rate 0.001) and the ``accuracy`` metric.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    key = model_name.lower()
    if key == "alexnet":
        model = AlexNet()
    elif key == "vgg16":
        model = VGG16()
    elif key == "vgg19":
        model = VGG19()
    else:
        raise ValueError(
            "Unknown model name {!r}; expected one of "
            "'alexnet', 'vgg16', 'vgg19'.".format(model_name)
        )

    model.compile(loss=tf.keras.losses.categorical_crossentropy,
                  optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
                  metrics=['accuracy'])
    return model
def main():
    """Train the model returned by ``get_model`` and save it to ``model_dir``.

    Steps: enable on-demand GPU memory growth, load the data generators and
    sample counts from ``get_datasets``, train for ``EPOCHS`` epochs with
    TensorBoard logging into the ``log`` directory, then save the whole model.
    """
    # GPU settings: ask TensorFlow to grow GPU memory as needed. Iterating an
    # empty device list is a no-op, so no separate emptiness check is needed.
    for gpu in tf.config.experimental.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(gpu, True)

    # The test generator and test count are returned too but are not used
    # during training.
    (train_generator, valid_generator, test_generator,
     train_num, valid_num, test_num) = get_datasets()

    # Use command tensorboard --logdir "log" to start tensorboard
    tensorboard = tf.keras.callbacks.TensorBoard(log_dir='log')
    callback_list = [tensorboard]

    model = get_model()
    model.summary()

    # Start training. Model.fit accepts generators directly; the separate
    # Model.fit_generator entry point is deprecated and no longer available
    # in recent Keras releases.
    # Floor division: a final partial batch is not counted in the step totals.
    model.fit(train_generator,
              epochs=EPOCHS,
              steps_per_epoch=train_num // BATCH_SIZE,
              validation_data=valid_generator,
              validation_steps=valid_num // BATCH_SIZE,
              callbacks=callback_list)

    # Save the whole model.
    model.save(model_dir)


if __name__ == '__main__':
    main()