
Commit 3cc3847

Fixed data loaders, model has started training but is overfitting
1 parent f0f81f7 commit 3cc3847

4 files changed: 209 additions & 182 deletions

File tree

.DS_Store (0 bytes, binary file not shown)

dataset.py

Lines changed: 96 additions & 105 deletions
@@ -1,74 +1,123 @@
 import os
-import torch
 from torch.utils.data import Dataset
+import torchvision.transforms as transforms
 from PIL import Image
+import torch
 import numpy as np
 import cv2
-import torchvision.transforms as transforms
 
 class ISICDataset(Dataset):
-
-    def __init__(self, img_dir='ISIC2018/ISIC2018_Task1-2_Training_Input_x2',
-                 annot_dir='ISIC2018/ISIC2018_Task1_Training_GroundTruth_x2',
-                 mode='train', transform=None, img_size=640, grid_size=80):
+    def __init__(self, img_dir, annot_dir, mode='train', transform=None, img_size=640, model_output_grid_size=80):
         """
         Initializes the ISICDataset.
-
-        Parameters:
-            img_dir (str): Path to the directory containing images.
-            annot_dir (str): Path to the directory containing annotation files (optional for test mode).
-            mode (str): Mode of the dataset, either 'train' (with annotations) or 'test' (without annotations).
-            transform (callable, optional): Optional transformations to apply to the images.
         """
-
+        print("Initializing ISICDataset...")
+
         self.img_dir = img_dir
-        self.img_size = img_size
         self.annot_dir = annot_dir if mode == 'train' else None
         self.mode = mode
-        self.transform = transform if transform else self.default_transforms()
-        self.img_files = sorted([f for f in os.listdir(img_dir) if f.endswith('.jpg')])
         self.img_size = img_size
-        self.grid_size = grid_size
+        self.transform = transform if transform else self.default_transforms()
         self.num_anchors = 3
-
+        self.grid_size = model_output_grid_size  # Must match the model's output grid size (80 here)
 
-    def default_transforms(self):
-        """
-        Defines default transformations for images if none are provided.
+        # Get the list of image files
+        print("Loading image files...")
+        self.img_files = sorted([f for f in os.listdir(img_dir) if f.endswith('.jpg')])
 
-        Returns:
-            transform (callable): Transformation pipeline with resizing, normalization, and conversion to tensor.
-        """
+        if self.mode == 'train':
+            # Keep only images that have a corresponding annotation file
+            print("Filtering images with corresponding annotations...")
+            annot_files = set([f.replace('_segmentation.png', '') for f in os.listdir(annot_dir) if f.endswith('.png')])
+            self.img_files = [f for f in self.img_files if f.replace('.jpg', '') in annot_files]
+
+            # Safeguard against an empty dataset
+            if not self.img_files:
+                raise ValueError(f"No valid images found in {img_dir} with corresponding annotations in {annot_dir}")
+
+        print(f"Dataset initialized with {len(self.img_files)} images.")
+
+    def default_transforms(self):
+        print("Setting default image transformations...")
         return transforms.Compose([
-            transforms.Resize((self.img_size, self.img_size)),  # Resizing for consistent dimensions
+            transforms.Resize((self.img_size, self.img_size)),
             transforms.RandomHorizontalFlip(),
             transforms.RandomVerticalFlip(),
-            transforms.ToTensor(),  # Converting image to PyTorch tensor
+            transforms.ToTensor(),
             transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
-            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # Normalizing with ImageNet standard
+            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
         ])
 
     def __len__(self):
-        """
-        Returns the total number of images in the dataset.
-
-        Returns:
-            int: Number of images.
-        """
         return len(self.img_files)
-
-    def convert_to_yolo_format(self, bbox, img_width, img_height):
-        """
-        Converts bounding box coordinates from (x_min, y_min, x_max, y_max) to YOLO format.
 
-        Parameters:
-            bbox (list): Bounding box in [x_min, y_min, x_max, y_max] format.
-            img_width (int): Width of the image.
-            img_height (int): Height of the image.
+    def __getitem__(self, idx):
+        print(f"Getting item {idx}...")
+        img_path = os.path.join(self.img_dir, self.img_files[idx])
+        print(f"Loading image from: {img_path}")
 
-        Returns:
-            tuple: Bounding box in YOLO format (x_center, y_center, width, height), normalized to the image size.
-        """
+        try:
+            image = Image.open(img_path).convert("RGB")
+        except Exception as e:
+            print(f"Error loading image {img_path}: {e}")
+            # In case of error, create a dummy image to keep the return format consistent
+            image = torch.zeros((3, self.img_size, self.img_size))
+
+        if self.transform:
+            image = self.transform(image)
+
+        if self.mode == 'train':
+            annot_filename = self.img_files[idx].replace('.jpg', '_segmentation.png')
+            annot_path = os.path.join(self.annot_dir, annot_filename)
+            print(f"Loading annotation from: {annot_path}")
+
+            if not os.path.exists(annot_path):
+                print("Annotation file not found, creating dummy target.")
+                return image, torch.zeros((self.num_anchors, self.grid_size, self.grid_size, 85))
+
+            try:
+                mask = Image.open(annot_path).convert("L")
+            except Exception as e:
+                print(f"Error loading annotation {annot_path}: {e}")
+                return image, torch.zeros((self.num_anchors, self.grid_size, self.grid_size, 85))
+
+            mask = mask.resize((self.img_size, self.img_size))
+            print("Annotation loaded and resized.")
+
+            # Convert the mask to a numpy array and extract bounding boxes
+            boxes = self.mask_to_bounding_boxes(mask)
+
+            # Create a target tensor of shape (num_anchors, grid_size, grid_size, 85) and populate it
+            target_tensor = torch.zeros((self.num_anchors, self.grid_size, self.grid_size, 85))
+
+            # Assign each bounding box to the appropriate grid cell and anchor
+            img_width, img_height = mask.size
+            for box in boxes:
+                x_min, y_min, x_max, y_max = box
+                # Calculate the grid cell containing the box centre
+                grid_x = int((x_min + x_max) / 2 / img_width * self.grid_size)
+                grid_y = int((y_min + y_max) / 2 / img_height * self.grid_size)
+
+                # Ensure the grid coordinates are within bounds
+                grid_x = min(max(grid_x, 0), self.grid_size - 1)
+                grid_y = min(max(grid_y, 0), self.grid_size - 1)
+
+                # Convert the box to YOLO format
+                x_center, y_center, width, height = self.convert_to_yolo_format(box, img_width, img_height)
+
+                # Assign to the target tensor - here, always the first anchor (anchor 0)
+                target_tensor[0, grid_y, grid_x, 0:4] = torch.tensor([x_center, y_center, width, height])
+                target_tensor[0, grid_y, grid_x, 4] = 1.0  # Objectness score
+                # Class probabilities stay at zero: with a single lesion class, objectness alone marks the box
+                target_tensor[0, grid_y, grid_x, 5:] = torch.zeros(80)
+
+            return image, target_tensor
+        else:
+            # Return a dummy target for validation/test to ensure a consistent return format
+            dummy_target = torch.zeros((self.num_anchors, self.grid_size, self.grid_size, 85))
+            return image, dummy_target
+
+    def convert_to_yolo_format(self, bbox, img_width, img_height):
         x_min, y_min, x_max, y_max = bbox
         x_center = (x_min + x_max) / 2.0 / img_width
         y_center = (y_min + y_max) / 2.0 / img_height
@@ -77,71 +126,13 @@ def convert_to_yolo_format(self, bbox, img_width, img_height):
         return x_center, y_center, width, height
 
     def mask_to_bounding_boxes(self, mask):
-        """
-        Extracts bounding boxes from a binary mask image.
-
-        Parameters:
-            mask (PIL.Image): Grayscale mask image.
-
-        Returns:
-            list: List of bounding boxes in [x_min, y_min, x_max, y_max] format.
-        """
         mask_np = np.array(mask)
         boxes = []
 
         contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
         for contour in contours:
             x, y, w, h = cv2.boundingRect(contour)
-            boxes.append([x, y, x + w, y + h])
-
-        return boxes
+            if w > 0 and h > 0:  # Ensure a valid bounding box
+                boxes.append([x, y, x + w, y + h])
 
-    def __getitem__(self, idx):
-        """
-        Retrieves the image (and annotation if in train mode) for a given index.
-
-        Parameters:
-            idx (int): Index of the image to retrieve.
-
-        Returns:
-            tuple: If in train mode, returns (image, yolo_boxes), else returns image.
-        """
-        # Load image
-        img_path = os.path.join(self.img_dir, self.img_files[idx])
-        image = Image.open(img_path).convert("RGB")
-
-        if self.transform:
-            image = self.transform(image)  # Resize and normalize the image
-
-        # Load and process annotations if in training mode
-        if self.mode == 'train':
-            # Load the segmentation mask as an image
-            annot_path = os.path.join(self.annot_dir, self.img_files[idx].replace('.jpg', '_segmentation.png'))
-
-            if not os.path.exists(annot_path):
-                raise FileNotFoundError(f"Annotation file not found: {annot_path}")
-
-            # Open the mask file as a grayscale image
-            mask = Image.open(annot_path).convert("L")  # Convert to grayscale
-            mask = mask.resize((self.img_size, self.img_size))  # Resize to match image dimensions
-            boxes = self.mask_to_bounding_boxes(mask)
-            img_width, img_height = mask.size
-
-            # Convert bounding boxes to YOLO format
-            yolo_boxes = [self.convert_to_yolo_format(box, img_width, img_height) for box in boxes]
-
-            # Prepare YOLO-style target tensor with 85 channels (for single class dataset)
-            target = torch.zeros((self.num_anchors, self.grid_size, self.grid_size, 85))
-
-            for x_center, y_center, width, height in yolo_boxes:
-                grid_x = int(x_center * self.grid_size)
-                grid_y = int(y_center * self.grid_size)
-                target[:, grid_y, grid_x, 0] = x_center
-                target[:, grid_y, grid_x, 1] = y_center
-                target[:, grid_y, grid_x, 2] = width
-                target[:, grid_y, grid_x, 3] = height
-                target[:, grid_y, grid_x, 4] = 1.0  # Object confidence score
-
-            return image, target  # Returning the image and YOLO-style target as a tuple
-        else:
-            return image
+        return boxes
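
Since __getitem__ now returns an (image, target) pair of fixed shape in every mode, the dataset drops straight into a standard DataLoader. Below is a minimal sketch of that wiring; the directory paths reuse the ISIC2018 defaults removed from __init__ in this commit, while the batch size and worker count are illustrative assumptions, not values from the repository:

    from torch.utils.data import DataLoader
    from dataset import ISICDataset

    # Assumed paths: the ISIC2018 archives extracted next to the training script.
    train_dataset = ISICDataset(
        img_dir='ISIC2018/ISIC2018_Task1-2_Training_Input_x2',
        annot_dir='ISIC2018/ISIC2018_Task1_Training_GroundTruth_x2',
        mode='train',
    )

    # Default collation works because every item has identical shapes:
    # image (3, 640, 640) and target (num_anchors=3, 80, 80, 85).
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)

    images, targets = next(iter(train_loader))
    print(images.shape)   # torch.Size([8, 3, 640, 640])
    print(targets.shape)  # torch.Size([8, 3, 80, 80, 85])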

modules.py

Lines changed: 17 additions & 6 deletions
@@ -1,16 +1,19 @@
 import torch
+import torch.nn as nn
 
-class LesionDetectionModel:
+class LesionDetectionModel(nn.Module):
     def __init__(self, model_weights='yolov7.pt', device='cpu'):
         """
-        Initializes the YOLOv7 model for lesion detection using PyTorch Hub.
+        Initializes the YOLOv7 model for lesion detection using PyTorch Hub, with an added dropout layer.
 
         Parameters:
             model_weights (str): Path to the pre-trained YOLOv7 weights.
             device (str): Device to load the model on ('cuda' or 'cpu').
         """
+        super(LesionDetectionModel, self).__init__()
+
         self.device = torch.device('cuda' if device == 'cuda' and torch.cuda.is_available() else 'cpu')
 
         # Load the YOLO model without the autoShape wrapper to get direct access to its layers
         self.model = torch.hub.load('WongKinYiu/yolov7', 'custom', model_weights, source='github', autoshape=False)
         self.model.to(self.device)
@@ -20,6 +23,9 @@ def __init__(self, model_weights='yolov7.pt', device='cpu'):
         for param in self.model.backbone.parameters():
             param.requires_grad = False
 
+        # Add dropout to regularize the detection head
+        self.dropout = nn.Dropout(p=0.2)  # Dropout layer with a 20% drop probability
+
     def forward(self, images):
         """
         Performs a forward pass through the model.
@@ -31,9 +37,14 @@ def forward(self, images):
             torch.Tensor: Model output with predictions for each bounding box.
         """
         images = images.to(self.device)
-        with torch.no_grad():
-            pred = self.model(images)[0]  # Get predictions
-        return pred
+
+        # Perform a forward pass through the wrapped model
+        x = self.model(images)[0]
+
+        # Apply dropout before returning the output
+        x = self.dropout(x)
+
+        return x
 
     def detect(self, images, conf_thres=0.25, iou_thres=0.8):
         """
