This repository was archived by the owner on Apr 14, 2026. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathfit.py
More file actions
87 lines (71 loc) · 2.42 KB
/
fit.py
File metadata and controls
87 lines (71 loc) · 2.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
from tensorflow.keras.callbacks import ModelCheckpoint
# Set training params:
model_name = 'AlcoNet.h5'  # path to the not-yet-trained model
epochs = 100
batch_size = 64  # orig paper trained all networks with batch_size = 128
# Batches drawn per epoch. None derives it from the training set after the
# generator is built: ceil(number_of_samples / batch_size), i.e. one full pass
# over the data per epoch. Set an int here to override.
steps_per_epoch = None
img_size = (224, 224)  # size of input images (H,W) | ResNet standard size is 224x224
# Paths to dirs with train and validation datasets:
train_path = './data/train/'
validation_path = './data/validation/'
# Saving path for checkpoints (after each epoch):
save_path = './AlcoNet_trained{epoch:03d}.h5'
# Saving path for final trained model (after all epochs):
final_save_path = 'AlcoNet_trained{}_final.h5'.format(epochs)
# Metric watched by the checkpoint callback.
# NOTE(review): the logged name depends on how AlcoNet.h5 was compiled and on
# the Keras version: TF 1.x or metrics=['acc'] logs 'val_acc', while TF 2.x
# with metrics=['accuracy'] logs 'val_accuracy'. With save_best_only=True,
# Keras skips saving when the monitored metric is missing - confirm the name.
checkpoint_monitor = 'val_acc'
# Set True to evaluate the trained model on the validation data before saving.
evaluate_after_training = False

# Generator for train data -->
train_datagen = image.ImageDataGenerator(
    # Rescales and preprocesses the images, and applies random augmentation
    # to enlarge the effective dataset. See the Keras ImageDataGenerator
    # documentation to tune these values.
    # NOTE(review): Keras runs preprocessing_function before rescale, so the
    # model sees preprocess_input(x) / 255 rather than the plain
    # preprocess_input(x) that ImageNet-pretrained ResNet50 weights expect.
    # This is fine for a from-scratch model if inference uses the same
    # pipeline; revisit if AlcoNet is built on pretrained weights.
    rescale=1./255,
    rotation_range=25,
    zoom_range=0.1,
    shear_range=0.2,
    fill_mode='constant',
    cval=255,  # pixels exposed by rotation/shear are filled with white
    preprocessing_function=preprocess_input)
train_generator = train_datagen.flow_from_directory(
    # Generates train data from the selected dir (train_path) for each epoch.
    # Classes are taken from the sub-folders of the selected dir.
    train_path,
    target_size=img_size,
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=True)
if steps_per_epoch is None:
    # len() of a Keras directory iterator is ceil(samples / batch_size), which
    # keeps the step count in sync with the dataset instead of a stale constant.
    steps_per_epoch = len(train_generator)

# Generator for validation data (same preprocessing, no augmentation) -->
test_datagen = image.ImageDataGenerator(
    rescale=1./255,
    preprocessing_function=preprocess_input)
test_generator = test_datagen.flow_from_directory(
    validation_path,
    target_size=img_size,
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=True)

# Checkpoints callback -->
checkpoint = ModelCheckpoint(
    filepath=save_path,
    monitor=checkpoint_monitor,
    verbose=1,
    save_best_only=True,  # if True, save only when the monitored metric improves
    mode='max')
callbacks = [checkpoint]

# Model training -->
model = load_model(model_name)
model.fit_generator(
    train_generator,
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
    verbose=1,
    callbacks=callbacks,
    validation_data=test_generator)

# Evaluate trained model -->
if evaluate_after_training:
    # Assumes the model is compiled with loss plus a single accuracy metric,
    # so the result is [loss, accuracy] - confirm against the compiled model.
    test = model.evaluate_generator(test_generator, verbose=1)
    print('Test loss:', test[0])
    print('Test accuracy:', test[1])
model.save(final_save_path)