-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata_loader.py
More file actions
47 lines (40 loc) · 1.45 KB
/
data_loader.py
File metadata and controls
47 lines (40 loc) · 1.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
class DataLoader:
    """Build Keras directory iterators for training and validation images.

    The training iterator applies random augmentation; the validation
    iterator only rescales pixel values and keeps a fixed file order.

    Args:
        train_dir: Directory handed to ``flow_from_directory`` for training.
        val_dir: Directory handed to ``flow_from_directory`` for validation.
        img_size: Target image size. Either a two-item sequence (forwarded
            to Keras as ``target_size``) or a single number ``n``, which is
            expanded to the square size ``(n, n)``.
        batch_size: Batch size used by both iterators.
        class_mode: Label mode forwarded to ``flow_from_directory``.
            Defaults to ``'binary'``, the value this class always used.
    """

    def __init__(self, train_dir, val_dir, img_size, batch_size,
                 class_mode='binary'):
        self.train_dir = train_dir
        self.val_dir = val_dir
        # Normalise once: snapshot sequences as an immutable tuple so later
        # mutation of the caller's list cannot change the loader, and accept
        # a bare scalar as shorthand for a square image.
        try:
            self.img_size = tuple(img_size)
        except TypeError:
            self.img_size = (img_size, img_size)
        self.batch_size = batch_size
        self.class_mode = class_mode

    def load_data(self):
        """Create the training and validation iterators.

        Returns:
            A ``(train_gen, val_gen)`` tuple of the iterators returned by
            ``ImageDataGenerator.flow_from_directory``.
        """
        # Training pipeline: rescale by 1/255 plus random augmentation.
        train_datagen = ImageDataGenerator(
            rescale=1./255,
            rotation_range=40,
            width_shift_range=0.2,
            height_shift_range=0.2,
            shear_range=0.2,
            zoom_range=0.2,
            horizontal_flip=True,
            fill_mode='nearest',
        )
        # Validation pipeline: rescale only, no augmentation.
        val_datagen = ImageDataGenerator(rescale=1./255)

        # Arguments shared by both iterators, kept in one place so the two
        # calls cannot drift apart.
        common = dict(
            target_size=self.img_size,
            batch_size=self.batch_size,
            class_mode=self.class_mode,
        )

        train_gen = train_datagen.flow_from_directory(self.train_dir, **common)
        val_gen = val_datagen.flow_from_directory(
            self.val_dir,
            shuffle=False,  # keep validation samples in a stable order
            **common,
        )
        return train_gen, val_gen