1+ from enum import Enum
2+ import pickle
3+ from typing import Dict , Optional , Tuple
4+ import torch
5+ from torch .utils .data import DataLoader
6+ import numpy as np
7+
8+ from models .resnet_50_base import load_pretrained_model , create_feature_extractor , MODEL_NAMES
9+ from pybbbc import BBBC021 , constants
10+
11+ import os
12+
13+
def _group_images_by_treatment(data, compound: str) -> Dict[Tuple[str, float, str], list[torch.Tensor]]:
    """Group the images of one compound by (compound, concentration, moa).

    Args:
        data: Iterable yielding ``(image, metadata)`` pairs, where ``metadata.compound``
            exposes ``compound``, ``concentration`` and ``moa`` attributes.
        compound: Compound name, used only for the skip message.

    Returns:
        Mapping from ``(compound, concentration, moa)`` to the list of image tensors
        belonging to that treatment. Images whose MOA is the string ``'null'`` are skipped.
    """
    image_groups: Dict[Tuple[str, float, str], list[torch.Tensor]] = {}

    for image, metadata in data:
        if metadata.compound.moa == 'null':
            print(f"Skipping image with null MOA for compound {compound}.")
            continue

        key = (
            metadata.compound.compound,
            metadata.compound.concentration,
            metadata.compound.moa,
        )

        # Convert numpy array to tensor if needed
        if isinstance(image, np.ndarray):
            image = torch.from_numpy(image).float()

        # A group only ever comes into existence together with its first image,
        # so no group in the returned mapping is empty.
        image_groups.setdefault(key, []).append(image)

    return image_groups


def _mean_group_features(feature_extractor, images: list[torch.Tensor], device, batch_size: int) -> torch.Tensor:
    """Run the feature extractor over ``images`` in batches and average the results.

    Args:
        feature_extractor: Callable module mapping an image batch to a feature batch.
        images: Non-empty list of equally shaped image tensors.
        device: Device the batches are moved to before the forward pass.
        batch_size: Number of images per forward pass.

    Returns:
        The mean over all images of the per-image feature vectors, on the CPU.
    """
    dataset = torch.utils.data.TensorDataset(torch.stack(images))
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    all_features = []
    with torch.no_grad():
        for batch in dataloader:
            batch_images = batch[0].to(device)
            features = feature_extractor(batch_images)
            # Collapse everything after the batch axis. A bare ``squeeze()`` would
            # also drop the batch axis when a batch holds a single image, which
            # breaks the concatenation below (or, for a one-image group, turns the
            # mean into a scalar).
            # NOTE(review): assumes the extractor output is (batch, channels, 1, 1)
            # or (batch, channels) - confirm against create_feature_extractor.
            features = features.flatten(start_dim=1)
            all_features.append(features.cpu())

    # Average over the image axis -> one feature vector per treatment
    return torch.cat(all_features, dim=0).mean(dim=0)


def extract_moa_features(model_name: MODEL_NAMES, device, batch_size: int = 16, data_root: str = "/scratch/cv-course2025/group8", compounds: Optional[list[str]] = None) -> None:
    """
    Extract features for the BBBC021 dataset using a pretrained ResNet50 model.

    For every requested compound the images are grouped by
    (compound, concentration, moa); each group's features are averaged and the
    tuple ``(key, mean_features)`` is pickled to
    ``<data_root>/bbbc021_features/<model_name.value>/<compound>_<concentration>.pkl``.
    A group that fails is reported and skipped so the remaining groups still run.

    Args:
        model_name: Name of the model to use. Is of type MODEL_NAMES.
        device: Device to run the model on
        batch_size: Batch size for data loading
        data_root: Root directory where the BBBC021 dataset is stored.
        compounds: List of compounds to process. If None (or empty), all compounds will be processed.

    Raises:
        ValueError: If a requested compound is not in ``constants.COMPOUNDS``.
    """

    # Validate the request first so a typo fails before the model is loaded
    if not compounds:
        compounds = constants.COMPOUNDS
    else:
        for compound in compounds:
            if compound not in constants.COMPOUNDS:
                raise ValueError(f"Compound '{compound}' is not a valid compound. "
                                 f"Valid compounds are: {constants.COMPOUNDS}")

    # Load pretrained ResNet50 model and wrap it as a feature extractor
    pretrained_model = load_pretrained_model(model_name)
    feature_extractor = create_feature_extractor(pretrained_model)

    # Create output directory with model name
    output_dir = os.path.join(data_root, "bbbc021_features", model_name.value)
    os.makedirs(output_dir, exist_ok=True)

    # Set device and switch to inference mode
    feature_extractor = feature_extractor.to(device)
    feature_extractor.eval()

    # Process one compound at a time so only its images are held in memory
    for compound in compounds:
        data = BBBC021(root_path=data_root, compound=compound)
        print(f"Processing Compound: {compound} with {len(data.images)} images")

        image_groups = _group_images_by_treatment(data, compound)

        # Process each group for this compound immediately
        for key, images in image_groups.items():
            compound_name, concentration, moa = key

            print(f"Extracting features for {compound_name} @{concentration} ({moa}) - {len(images)} images")

            try:
                avg_features = _mean_group_features(feature_extractor, images, device, batch_size)

                # Create result as tuple (key, feature)
                result = (key, avg_features)

                # Create filename: compound_concentration (spaces and slashes are not path-safe)
                filename = f"{compound_name}_{concentration}.pkl".replace(' ', '_').replace('/', '_')
                filepath = os.path.join(output_dir, filename)

                # Save to file
                with open(filepath, 'wb') as f:
                    pickle.dump(result, f)

                print(f"Saved features to {filepath}")

            except Exception as e:
                # Deliberate best-effort: one failing group must not abort the whole run.
                print(f"Error processing group {compound_name}_{concentration}: {e}. Skipping...")
                continue
116+
117+
# Script entry point: extract base-ResNet features for every known compound.
if __name__ == "__main__":
    # Expose only GPU index 1 to CUDA; this is set before the availability query below.
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"

    scratch_root = "/scratch/cv-course2025/group8"
    # Fall back to the CPU when no CUDA device is available.
    run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    extract_moa_features(
        model_name=MODEL_NAMES.BASE_RESNET,
        device=run_device,
        batch_size=16,
        data_root=scratch_root,
        compounds=constants.COMPOUNDS,
    )
0 commit comments