-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpngToCutomModel.py
More file actions
118 lines (89 loc) · 3.84 KB
/
pngToCutomModel.py
File metadata and controls
118 lines (89 loc) · 3.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing import image
from sklearn.model_selection import train_test_split
from tensorflow.keras.applications.mobilenet import preprocess_input
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Flatten, Dense, Conv2D, MaxPooling2D
import seaborn as sns
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import BatchNormalization
from sklearn.metrics import ConfusionMatrixDisplay as cmd
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint
from datetime import datetime
from tensorflow.keras.regularizers import l2
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.applications.efficientnet import preprocess_input
from tensorflow.keras.callbacks import ReduceLROnPlateau
def load_images_from_path(path, target_size=(224, 224)):
    """Load every ``.png`` in *path* as an image array with a class label.

    The label is derived from the first character of the file name, which
    must be a digit, shifted down by one: files starting with ``1`` get
    label 0, ``2`` gets label 1, and so on.

    Args:
        path: Directory to scan (non-recursively).
        target_size: ``(height, width)`` each image is resized to on load.
            Defaults to ``(224, 224)``, the input size the model below uses.

    Returns:
        A ``(images, labels)`` tuple of equal-length lists: the arrays
        produced by ``image.img_to_array`` and the matching integer labels.

    Raises:
        ValueError: If a ``.png`` file name does not start with a digit.
    """
    images, labels = [], []
    # Sort so the load order (and therefore the seeded train/test split
    # downstream) does not depend on the filesystem's listdir order.
    for file in sorted(os.listdir(path)):
        if not file.lower().endswith('.png'):
            continue
        img = image.load_img(os.path.join(path, file), target_size=target_size)
        images.append(image.img_to_array(img))
        labels.append(int(file[0]) - 1)
    return images, labels
def show_images(images):
    """Display up to the first eight images side by side in a single row.

    Each image is divided by 255 before display, so the arrays are presumably
    0-255 pixel values (as produced by ``load_images_from_path``) — confirm
    for other callers.

    Args:
        images: Sequence of image arrays. Only the first eight are shown;
            if fewer are given, the remaining subplots are left blank
            instead of raising ``IndexError``.
    """
    _, axes = plt.subplots(1, 8, figsize=(20, 20),
                           subplot_kw={'xticks': [], 'yticks': []})
    # zip stops at the shorter sequence, so any image count is safe.
    for ax, img in zip(axes.flat, images):
        ax.imshow(img / 255)
# ---------------------------------------------------------------------------
# Data: load every spectrogram PNG and split it into train / held-out sets.
# ---------------------------------------------------------------------------
images, labels = load_images_from_path('Spectrograms')
# Independent list copies under the x / y names used below.
x, y = list(images), list(labels)
# Hold out 10% of the samples, stratified by class; random_state fixes the
# split for a given input order.
x_train, x_test, y_train, y_test = train_test_split(x, y, stratify=y, test_size=0.1,
                                                    random_state=0)
y_train = np.array(y_train)
y_test = np.array(y_test)

# Timestamp shared by the checkpoint and plot file names. Dashes instead of
# colons in the time part: ':' is not a legal file-name character on Windows.
timestamp = datetime.now().strftime("%m-%d-%y_%H-%M-%S")
# Make sure the output directory exists before the checkpoint/plot are saved.
os.makedirs('Models', exist_ok=True)
checkpoint = ModelCheckpoint(
    filepath=f'Models/Sonai_{timestamp}.h5',
    monitor='val_accuracy',  # or 'val_loss'
    save_best_only=True,
    mode='max',  # use 'min' for val_loss
    verbose=1
)

# NOTE(review): `preprocess_input` resolves to the efficientnet version,
# because that import comes after (and shadows) the mobilenet one. Per the
# Keras docs the efficientnet function is a pass-through, so no rescaling
# happens here and the model trains on the raw values from img_to_array.
# Left unchanged to stay consistent with any existing inference code —
# confirm this is intended.
x_train_norm = preprocess_input(np.array(x_train))
x_test_norm = preprocess_input(np.array(x_test))

# ---------------------------------------------------------------------------
# Model: five conv/pool stages on 224x224x3 input, then an L2-regularized
# dense head ending in a 5-way softmax (labels are integer-encoded, hence
# the sparse categorical cross-entropy loss).
# ---------------------------------------------------------------------------
model = Sequential()
model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(224, 224, 3)))
model.add(MaxPooling2D(2, 2))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(2, 2))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(2, 2))
model.add(Conv2D(128, (3, 3), activation='relu'))
model.add(MaxPooling2D(2, 2))
model.add(Conv2D(128, (3, 3), activation='relu'))
model.add(MaxPooling2D(2, 2))
model.add(Flatten())
model.add(Dense(512, activation='relu', kernel_regularizer=l2(0.001)))
model.add(BatchNormalization())
model.add(Dropout(0.5))
model.add(Dense(128, activation='relu', kernel_regularizer=l2(0.001)))
model.add(BatchNormalization())
model.add(Dense(5, activation='softmax'))
model.compile(optimizer=Adam(learning_rate=0.0005),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Callbacks
# NOTE(review): the held-out split is passed as validation data, and both the
# checkpoint and early stopping select on it, so the reported validation
# accuracy is an optimistic estimate of performance on unseen data.
early_stop = EarlyStopping(monitor='val_accuracy', patience=15, restore_best_weights=True)
lr_reducer = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, verbose=1)
hist = model.fit(x_train_norm, y_train, validation_data=(x_test_norm, y_test),
                 batch_size=22, epochs=40, callbacks=[checkpoint, early_stop, lr_reducer])

# ---------------------------------------------------------------------------
# Plot training vs. validation accuracy per epoch and save it next to the
# checkpoint.
# ---------------------------------------------------------------------------
sns.set()
acc = hist.history['accuracy']
val_acc = hist.history['val_accuracy']
epochs = range(1, len(acc) + 1)
plt.plot(epochs, acc, '-', label='Training Accuracy')
plt.plot(epochs, val_acc, ':', label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
# Save the plot with a timestamp
plot_filename = f"Models/Plot_{timestamp}.png"
plt.savefig(plot_filename)