-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel_with_data_augmentation.py
More file actions
109 lines (93 loc) · 3.12 KB
/
model_with_data_augmentation.py
File metadata and controls
109 lines (93 loc) · 3.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
from tensorflow.keras.preprocessing import image_dataset_from_directory
import matplotlib.pyplot as plt
# === CONFIG ===
# Root folder whose immediate subdirectories become the class labels
# (labels='inferred' in image_dataset_from_directory below).
# Can be overridden without editing the file via the BREAKHIS_DIR environment
# variable. Inside a raw string "\\" is two literal characters, so only single
# backslashes are used as path separators here.
DATASET_DIR = os.environ.get(
    "BREAKHIS_DIR",
    r"C:\Users\hgmar\Downloads\archive\BreaKHis_v1\BreaKHis_v1\histology_slides\breast",
)
BATCH_SIZE = 32
IMG_SIZE = (224, 224)  # (height, width) used for both loading and MobileNetV2 input
EPOCHS = 10
# === LOAD DATA WITH AUGMENTATION ===
print("Loading dataset with augmentation...")
AUTOTUNE = tf.data.AUTOTUNE

# Random, label-preserving transforms; applied to the training split only.
data_augmentation = tf.keras.Sequential([
    layers.RandomFlip('horizontal'),
    layers.RandomRotation(0.1),
    layers.RandomZoom(0.1),
])

# Both subsets are built with identical shuffle / seed / validation_split
# arguments. Keras shuffles the file list (using `seed`) before taking the
# split, so matching arguments are what keep the two subsets disjoint --
# which is why shuffle=True is kept on the validation call as well.
# NOTE(review): the split is per image file; if the directory tree groups
# images by patient, the same patient can appear in both subsets -- consider
# a patient-level split for an unbiased validation score.
train_dataset = image_dataset_from_directory(
    DATASET_DIR,
    labels='inferred',
    label_mode='binary',
    batch_size=BATCH_SIZE,
    image_size=IMG_SIZE,
    shuffle=True,
    validation_split=0.2,
    subset='training',
    seed=123
)
val_dataset = image_dataset_from_directory(
    DATASET_DIR,
    labels='inferred',
    label_mode='binary',
    batch_size=BATCH_SIZE,
    image_size=IMG_SIZE,
    shuffle=True,
    validation_split=0.2,
    subset='validation',
    seed=123
)

# Apply data augmentation to the training set; batches are augmented in
# parallel instead of one at a time.
train_dataset = train_dataset.map(
    lambda x, y: (data_augmentation(x, training=True), y),
    num_parallel_calls=AUTOTUNE,
)
# Prefetch so input preparation overlaps with the training steps.
train_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE)
val_dataset = val_dataset.prefetch(buffer_size=AUTOTUNE)
# === BUILD MODEL ===
print("Building MobileNetV2 model...")
input_shape = IMG_SIZE + (3,)  # (height, width, 3 colour channels)

# ImageNet-pretrained MobileNetV2 without its classification head, frozen so
# only the new head below is trained.
base_model = MobileNetV2(
    input_shape=input_shape,
    include_top=False,
    weights='imagenet',
)
base_model.trainable = False

# Functional graph: raw pixels -> MobileNetV2 input preprocessing -> frozen
# backbone -> global average pooling -> one sigmoid unit (binary probability).
inputs = tf.keras.Input(shape=input_shape)
scaled = preprocess_input(inputs)
# training=False runs the backbone in inference mode (e.g. normalization
# layers use their stored statistics) even while the head is being trained.
features = base_model(scaled, training=False)
pooled = layers.GlobalAveragePooling2D()(features)
outputs = layers.Dense(1, activation='sigmoid')(pooled)
model = models.Model(inputs, outputs)

# === COMPILE MODEL ===
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy'],
)
# === TRAIN MODEL ===
print("Training model...")
history = model.fit(
    train_dataset,
    epochs=EPOCHS,
    validation_data=val_dataset,
)

# === SAVE MODEL ===
# The ".keras" extension selects the native Keras save format.
model_path = "breakhis_mobilenet_model.keras"
model.save(model_path)
print(f"Model saved as {model_path}")
# === PLOT RESULTS ===
# Per-epoch metrics recorded by model.fit: the compiled metric name
# ('accuracy'), the loss, and their 'val_' counterparts from validation_data.
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
# Derive the x-axis from what was actually recorded, not from EPOCHS, so the
# plot stays correct if training stops early or EPOCHS is changed later.
# Epochs are shown 1-based.
epochs_range = range(1, len(acc) + 1)

plt.figure(figsize=(12, 6))

plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epoch')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')

plt.tight_layout()
plt.show()