Skip to content

Commit 8a83da1

Browse files
committed
add feature extraction
1 parent ffb8ed8 commit 8a83da1

2 files changed

Lines changed: 167 additions & 1 deletion

File tree

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
from enum import Enum
2+
import pickle
3+
from typing import Dict, Optional, Tuple
4+
import torch
5+
from torch.utils.data import DataLoader
6+
import numpy as np
7+
8+
from models.resnet_50_base import load_pretrained_model, create_feature_extractor, MODEL_NAMES
9+
from pybbbc import BBBC021, constants
10+
11+
import os
12+
13+
14+
def extract_moa_features(model_name: MODEL_NAMES, device, batch_size: int = 16, data_root: str = "/scratch/cv-course2025/group8", compounds: Optional[list[str]] = None) -> None:
    """
    Extract features for the BBBC021 dataset using a pretrained ResNet50 model.

    For each (compound, concentration, moa) group, the mean feature vector is
    computed over all of the group's images and pickled as a
    ``(key, avg_features)`` tuple to
    ``<data_root>/bbbc021_features/<model_name.value>/<compound>_<concentration>.pkl``.

    Args:
        model_name: Name of the model to use. Is of type MODEL_NAMES.
        device: Device to run the model on.
        batch_size: Batch size for data loading.
        data_root: Root directory where the BBBC021 dataset is stored.
        compounds: List of compounds to process. If None (or empty), all
            compounds are processed.

    Raises:
        ValueError: If any requested compound is not a known BBBC021 compound.
    """
    # Load pretrained ResNet50 model and wrap it as a feature extractor.
    pretrained_model = load_pretrained_model(model_name)
    feature_extractor = create_feature_extractor(pretrained_model)

    if not compounds:
        compounds = constants.COMPOUNDS
    else:
        for compound in compounds:
            if compound not in constants.COMPOUNDS:
                raise ValueError(f"Compound '{compound}' is not a valid compound. "
                                 f"Valid compounds are: {constants.COMPOUNDS}")

    # Create output directory with model name
    output_dir = os.path.join(data_root, "bbbc021_features", model_name.value)
    os.makedirs(output_dir, exist_ok=True)

    # Move to the target device and freeze in eval mode.
    feature_extractor = feature_extractor.to(device)
    feature_extractor.eval()

    # Process each compound separately so only one compound's images are held
    # in memory at a time.
    for compound in compounds:
        data = BBBC021(root_path=data_root, compound=compound)
        print(f"Processing Compound: {compound} with {len(data.images)} images")

        # Dictionary to store images grouped by (compound, concentration, moa)
        image_groups: Dict[Tuple[str, float, str], list[torch.Tensor]] = {}

        # Collect images for this compound
        for image, metadata in data:
            # The dataset marks an unknown mechanism of action with the
            # string 'null'; such images carry no usable label, so skip them.
            if metadata.compound.moa == 'null':
                print(f"Skipping image with null MOA for compound {compound}.")
                continue

            key = (metadata.compound.compound,
                   metadata.compound.concentration,
                   metadata.compound.moa)

            # Convert numpy array to tensor if needed
            if isinstance(image, np.ndarray):
                image = torch.from_numpy(image).float()
            image_groups.setdefault(key, []).append(image)

        # Process each group for this compound immediately
        for key, images in image_groups.items():
            compound_name, concentration, moa = key

            if len(images) == 0:
                print(f"Warning: No images for group {compound_name}_{concentration}. Skipping...")
                continue

            print(f"Extracting features for {compound_name}@{concentration}({moa}) - {len(images)} images")

            try:
                # Create DataLoader for this group
                dataset = torch.utils.data.TensorDataset(torch.stack(images))
                dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

                # Extract features
                all_features = []
                with torch.no_grad():
                    for batch in dataloader:
                        batch_images = batch[0].to(device)
                        features = feature_extractor(batch_images)
                        # Collapse trailing (spatial) dimensions but KEEP the
                        # batch dimension. A bare .squeeze() also dropped the
                        # batch axis whenever a batch held a single image,
                        # which broke torch.cat(dim=0) below.
                        features = torch.flatten(features, start_dim=1)
                        all_features.append(features.cpu())

                # Compute average features
                all_features = torch.cat(all_features, dim=0)
                avg_features = torch.mean(all_features, dim=0)

                # Create result as tuple (key, feature)
                result = (key, avg_features)

                # Create filename: compound_concentration (sanitize separators)
                filename = f"{compound_name}_{concentration}.pkl".replace(' ', '_').replace('/', '_')
                filepath = os.path.join(output_dir, filename)

                # Save to file
                with open(filepath, 'wb') as f:
                    pickle.dump(result, f)

                print(f"Saved features to {filepath}")

            except Exception as e:
                # Best-effort: report and skip a failing group so remaining
                # groups and compounds still get processed.
                print(f"Error processing group {compound_name}_{concentration}: {e}. Skipping...")
                continue
117+
118+
# Script entry point: extract features for every BBBC021 compound.
if __name__ == "__main__":
    # Restrict the job to GPU 1 on the shared machine.
    os.environ['CUDA_VISIBLE_DEVICES'] = '1'

    run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    extract_moa_features(
        model_name=MODEL_NAMES.BASE_RESNET,
        device=run_device,
        batch_size=16,
        data_root="/scratch/cv-course2025/group8",
        compounds=constants.COMPOUNDS)

models/resnet_50_base.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,23 @@
1+
from enum import Enum
12
from torchvision import models
23
import torch.nn as nn
34
import torch
45

6+
class MODEL_NAMES(Enum):
    """Supported pretrained-backbone variants.

    Each member's ``value`` doubles as the on-disk name (e.g. the feature
    output subdirectory is named after it).
    """

    BASE_RESNET = "base_resnet"
    SIMCLR = "resnet_simclr"
    DINO = "resnet_wsdino"
10+
11+
def load_pretrained_model(model_name: MODEL_NAMES, weight_path: str = '/scratch/cv-course2025/group8/model_weights'):
    """Load a pretrained ResNet50 variant.

    Args:
        model_name: Which backbone to load (member of MODEL_NAMES).
        weight_path: Directory holding custom weight files; only consulted
            for the non-ImageNet variants.

    Returns:
        nn.Module: The pretrained model.

    Raises:
        ValueError: If no loader is implemented for ``model_name``
            (e.g. MODEL_NAMES.DINO has no weights yet).
    """
    if model_name == MODEL_NAMES.BASE_RESNET:
        # Torchvision ImageNet weights; no local weight file needed.
        return load_pretrained_resnet50(weights="IMAGENET1K_V2")

    if model_name == MODEL_NAMES.SIMCLR:
        return load_pretrained_model_from_weights("resnet50_simclr", weight_path)

    # Previously unhandled members (e.g. DINO) fell through and implicitly
    # returned None, failing far from the real cause. Fail fast instead.
    raise ValueError(f"No loader implemented for model '{model_name}'.")
20+
521
def load_pretrained_resnet50(weights: str = "IMAGENET1K_V2") -> object:
622
"""Load pretrained ResNet50 model.
723
@@ -26,7 +42,30 @@ def load_pretrained_resnet50(weights: str = "IMAGENET1K_V2") -> object:
2642
pretrained_model.eval()
2743
return pretrained_model
2844

29-
45+
def load_pretrained_model_from_weights(model_name: str, weight_path: str) -> nn.Module:
    # TODO: Test this after we trained models
    """Load pretrained ResNet50 model from custom weights.

    Args:
        model_name: Name of the weight file, without the '.pth' suffix.
        weight_path: Directory containing the weights file.

    Returns:
        nn.Module: Pretrained ResNet50 model in eval mode.

    Raises:
        ValueError: If the weight file does not exist under ``weight_path``.
    """
    print(f"Loading pretrained ResNet50 from {weight_path}...")

    # Instantiate the architecture without any pretrained weights.
    pretrained_model = models.resnet50(weights=None)

    # Load the weights. Bug fix: model_name is a plain str here, so the
    # previous ``model_name.value`` raised AttributeError on every call.
    try:
        pretrained_model.load_state_dict(torch.load(f"{weight_path}/{model_name}.pth"))
    except FileNotFoundError as err:
        raise ValueError(f"Weight file '{model_name}.pth' not found in '{weight_path}'") from err

    pretrained_model.eval()
    return pretrained_model
3069

3170
def create_feature_extractor(pretrained_model: nn.Module) -> nn.Module:
3271
"""Create a feature extractor from a pretrained ResNet50 model.

0 commit comments

Comments
 (0)