-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path53_inference_classifier.py
More file actions
381 lines (320 loc) · 16.4 KB
/
53_inference_classifier.py
File metadata and controls
381 lines (320 loc) · 16.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
import os
import torch
import torch.nn as nn
from torchvision import models, transforms
import pandas as pd
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm
import logging
import random
import argparse
import json
import sys
# --- Default Configuration ---
DEFAULT_MASTER_CSV_PATH = "inference_out/numeric/results.csv"
DEFAULT_CUTOUTS_DIR = "inference_out/cutouts_masked"
DEFAULT_GLOBAL_MODEL_DIR = "runs/classifier/all_species"
DEFAULT_PER_SPECIES_MODEL_DIR = "runs/classifier/species"
DEFAULT_LEGACY_MODEL_PATH = "runs/classifier/best_multihead_classifier.pt"
# Same file as DEFAULT_MASTER_CSV_PATH: by default the classifier columns are written back
# into the master results CSV.
DEFAULT_OUTPUT_CSV_PATH = "inference_out/numeric/results.csv"
DEFAULT_VISUALIZE_DIR = "inference_out/visualized_classifier_inference"
DEFAULT_LOG_FILE = "master_log.txt"
ANNOTATION_CSV_PATH = "inference_out/cutouts_masked/annotations.csv"
# --- Sample Data Configuration ---
# NOTE: intentionally the same directory as DEFAULT_CUTOUTS_DIR. If the pipeline has been run
# before, bundled example cutouts would not line up with the rows of the results CSV, so the
# pipeline's own cutouts directory is used instead.
SAMPLE_CUTOUTS_DIR = "inference_out/cutouts_masked"
SAMPLE_ANNOTATION_CSV_PATH = "models/pretrained_model_data/classification/training_input_data/annotations.csv"
# Placeholder, filled in at runtime by run_inference(). Maps each label group (annotation column)
# to {'0': <ignore label, present only if one exists>, '1': <class label>, '2': ..., ...}.
LABEL_GROUPS = {}
def setup_logging(log_file):
    """Configure the root logger to write INFO-and-above records to *log_file* and stdout.

    Handlers already attached to the root logger are removed and closed first, so calling this
    more than once neither duplicates output nor leaves the previous log file open.

    Args:
        log_file: Path of the log file; it is opened in append mode with UTF-8 encoding.
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    # Iterate over a copy: removeHandler mutates logger.handlers.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()  # releases the file descriptor held by a previous FileHandler

    # File handler (e.g. master_log.txt), with timestamps.
    file_handler = logging.FileHandler(log_file, mode='a', encoding="utf-8")
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
    logger.addHandler(file_handler)

    # Console handler for INFO and above, shorter format.
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(logging.Formatter('%(levelname)s: %(message)s'))
    logger.addHandler(console_handler)
def save_groups_json(groups, path):
    """Serialise the label-group mapping *groups* to *path* as 2-space-indented UTF-8 JSON.

    Failures (unwritable path, unserialisable data, ...) are logged as a warning, not raised.
    """
    try:
        with open(path, mode='w', encoding='utf-8') as out_file:
            json.dump(obj=groups, fp=out_file, indent=2)
    except Exception as err:
        logging.warning(f"Could not save label_groups.json: {err}")
def load_groups_json(path):
    """Read a label-group mapping previously written as JSON (see save_groups_json).

    Args:
        path: Path to a ``label_groups.json`` file.

    Returns:
        The decoded dict, or ``None`` if the file does not exist, cannot be read or decoded,
        or does not contain a JSON object. Read/decode problems are logged as warnings.
    """
    if not os.path.exists(path):
        return None
    try:
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logging.warning(f"Could not read label groups from {path}: {e}")
        return None
    if not isinstance(data, dict):
        # Callers use the result as {group: {index: label}}; anything else would break model building.
        logging.warning(f"Ignoring {path}: expected a JSON object, got {type(data).__name__}")
        return None
    return data
def build_label_groups(annotations_df, target_columns=None):
    """
    Dynamically build LABEL_GROUPS from the annotation DataFrame.

    Rules:
    - Rows with 'skipped' (or NaN) in a column are ignored when collecting that column's classes.
    - Label values are handled as strings; non-string values (e.g. an all-numeric column) are
      converted with str() so they can be matched and serialised to JSON.
    - If a label starting with 'unknown_' exists OR a label equals 'not_visible', the first such
      label (in order of appearance) becomes the ignore label under key '0'.
    - Remaining labels get keys '1', '2', ... in alphabetical order.
    - If no ignore label exists, only numbered keys ('1', '2', ...) are created and no '0' key.

    Args:
        annotations_df: Annotation table; one column per label group.
        target_columns: Columns to turn into groups. Defaults to every column except
            'filename' and 'species'. 'filename' is always skipped.

    Returns:
        dict mapping column name -> {key: label}.
    """
    if target_columns is None:
        # All columns except filename (and any added columns like species)
        target_columns = [c for c in annotations_df.columns if c not in ('filename', 'species')]
    groups = {}
    for col in target_columns:
        if col == 'filename':
            continue
        column = annotations_df[col]
        kept = column[column != 'skipped'].dropna().unique()
        # str(): non-string values have no .startswith and are not JSON serialisable.
        # dict.fromkeys de-duplicates (e.g. 1 and '1') while keeping first-appearance order.
        valid_labels = list(dict.fromkeys(str(value) for value in kept))
        ignore_labels = [l for l in valid_labels if l.startswith('unknown_') or l == 'not_visible']
        ignore_set = set(ignore_labels)
        # The actual classes are everything that is not an ignore label.
        class_labels = sorted(l for l in valid_labels if l not in ignore_set)
        group_map = {}
        if ignore_labels:
            # The first one is the representative 'unknown' class for this group.
            group_map['0'] = ignore_labels[0]
        for i, label in enumerate(class_labels, start=1):
            group_map[str(i)] = label
        groups[col] = group_map
    return groups
# --- Model Definition ---
class MultiHeadResNet(nn.Module):
    """Shared backbone with one linear classification head per label group.

    The backbone is every child module of *base_model* except the last one (its ``fc`` layer).
    Each head maps the flattened backbone features to ``len(classes)`` logits for its group.
    """

    def __init__(self, base_model, label_groups):
        super().__init__()
        # Kept so callers can recover the index -> label mapping this model was built with.
        self.label_groups = label_groups
        backbone_layers = list(base_model.children())[:-1]  # drop the original classifier
        self.base = nn.Sequential(*backbone_layers)
        feature_dim = base_model.fc.in_features
        heads = nn.ModuleDict()
        for group_name, class_map in label_groups.items():
            heads[group_name] = nn.Linear(feature_dim, len(class_map))
        self.heads = heads

    def forward(self, x):
        """Return a dict {group name: logits tensor} for the input batch *x*."""
        features = torch.flatten(self.base(x), start_dim=1)
        return {name: layer(features) for name, layer in self.heads.items()}
def load_label_groups(args):
    """
    Build the label groups from the annotation CSV named by ``args.annotation_file``.

    Returns the groups dict (see build_label_groups), or None if the file is missing or
    cannot be read/processed; every outcome is logged.
    """
    source = args.annotation_file
    logging.info(f"Building label groups from annotation file: {source}")
    if os.path.exists(source):
        try:
            groups = build_label_groups(pd.read_csv(source))
            logging.info(f"Successfully built label groups from {source}")
            # Pretty-printed so the mapping can be checked in the log.
            logging.info(f"Built Label Groups:\n{json.dumps(groups, indent=2)}")
            return groups
        except Exception as exc:
            logging.error(f"Failed to read or process annotation file {source}: {exc}")
            return None
    logging.error(f"Annotation file not found: {source}. Cannot build label groups.")
    return None
def load_models(device, args, global_groups):
    """
    Load the global classifier and any per-species classifiers.

    The global weights come from ``<global_model_dir>/best_multihead_classifier.pt`` or, if that
    file is absent, from ``args.legacy_model_path``. Every
    ``<per_species_model_dir>/<name>/best_multihead_classifier.pt`` is loaded as species model
    ``<name>``. A ``label_groups.json`` next to the weights (in the global model directory or a
    species directory) defines that model's heads; without one, *global_groups* is used.

    Args:
        device: torch device the models are mapped and moved to.
        args: Parsed CLI namespace (global_model_dir, legacy_model_path, per_species_model_dir).
        global_groups: Label groups built from the annotation file.

    Returns:
        ``(models_cache, groups_used)``: models_cache maps ``'__global__'`` and each species name to
        an eval-mode MultiHeadResNet; groups_used are the global model's label groups.
        ``({}, None)`` when there are no label groups or no global weights. Errors while loading the
        global weights propagate; a species model that fails to load is skipped with a warning.
    """
    if not global_groups:
        logging.error("No label groups available; aborting model loading.")
        return {}, None

    def build_and_load(weight_path, groups):
        # Build the architecture for *groups*, load the checkpoint and switch to eval mode.
        # NOTE(review): torch.load unpickles the checkpoint; only load files from trusted sources.
        model = MultiHeadResNet(models.resnet18(weights=None), groups)
        state = torch.load(weight_path, map_location=device)
        model.load_state_dict(state)
        model.to(device).eval()
        return model

    # Global model: prefer the all_species directory, fall back to the legacy single file.
    global_path = os.path.join(args.global_model_dir, "best_multihead_classifier.pt")
    if os.path.exists(global_path):
        fallback_used = False
        groups_path = os.path.join(args.global_model_dir, "label_groups.json")
        json_groups = load_groups_json(groups_path)
        if json_groups:
            # Same rule as for species models: groups saved next to the weights win.
            logging.info(f"Using label groups from {groups_path} for the global model.")
            model_groups_used = json_groups
        else:
            model_groups_used = global_groups
    elif os.path.exists(args.legacy_model_path):
        # Legacy single file (assume same groups)
        global_path = args.legacy_model_path
        fallback_used = True
        model_groups_used = global_groups
    else:
        logging.error("No global model weights found.")
        return {}, None

    models_cache = {'__global__': build_and_load(global_path, model_groups_used)}
    if fallback_used:
        logging.info(f"Loaded legacy global model: {global_path}")
    else:
        logging.info(f"Loaded global model: {global_path}")

    # Species-specific models (sorted for a deterministic load/log order).
    if os.path.isdir(args.per_species_model_dir):
        for sp in sorted(os.listdir(args.per_species_model_dir)):
            sp_dir = os.path.join(args.per_species_model_dir, sp)
            weight_path = os.path.join(sp_dir, "best_multihead_classifier.pt")
            if not os.path.isfile(weight_path):
                continue
            sp_groups = load_groups_json(os.path.join(sp_dir, "label_groups.json")) or global_groups
            try:
                models_cache[sp] = build_and_load(weight_path, sp_groups)
                logging.info(f"Loaded species model '{sp}' with groups: {list(sp_groups.keys())}")
            except Exception as e:
                logging.warning(f"Failed loading species model {weight_path}: {e}")

    if fallback_used:
        logging.warning("Using legacy path (no all_species directory).")
    return models_cache, model_groups_used
def _cutout_path(row, cutouts_dir):
    """Return the masked-cutout path for a results row: '<image stem>_<instance_number - 1>.tiff'.

    pandas stores an integer column as float when it contains NaN; an integral float instance
    number is converted back to int so the name ends in '_0.tiff' rather than '_0.0.tiff'.
    """
    base = os.path.splitext(row['image_name'])[0]
    inst = row['instance_number']
    if isinstance(inst, float) and inst.is_integer():
        inst = int(inst)
    return os.path.join(cutouts_dir, f"{base}_{inst-1}.tiff")


def _save_visualization(image, img_path, model_tag, predicted_labels, visualize_dir):
    """Save *image* with the model tag and predicted labels written underneath it.

    The output file gets the basename of *img_path* and is written into *visualize_dir*.
    """
    os.makedirs(visualize_dir, exist_ok=True)
    try:
        font = ImageFont.truetype("arial.ttf", size=15)
    except (OSError, ImportError):
        # Arial is not available everywhere (or FreeType support is missing): use PIL's default font.
        font = ImageFont.load_default()
    lines = [f"Model: {model_tag}"] + [f"{k}:{v}" for k, v in predicted_labels.items()]
    text = "\n".join(lines)
    pad = 10
    # Measure the text on a throw-away 1x1 image to size the output canvas.
    bbox = ImageDraw.Draw(Image.new('RGB', (1, 1))).textbbox((0, 0), text, font=font)
    t_w = bbox[2] - bbox[0]
    t_h = bbox[3] - bbox[1]
    new_w = max(image.width, t_w + 2 * pad)
    new_h = image.height + t_h + 2 * pad
    canvas = Image.new("RGB", (new_w, new_h), "white")
    canvas.paste(image, (0, 0))
    ImageDraw.Draw(canvas).text((pad, image.height + pad), text, fill="black", font=font)
    canvas.save(os.path.join(visualize_dir, os.path.basename(img_path)))


def run_inference(args):
    """Classify every cutout listed in the master CSV and write the predictions back to a CSV.

    Steps: build label groups from the annotation file, load the global and per-species models,
    pick a model per row (via the 'class' or 'species' column when present), predict one label
    per label group, and write one column per group plus 'classifier_model_used' to
    ``args.output_csv_path``. Rows whose cutout is missing get 'not_found'; rows that fail to
    process get 'error'; groups the selected model has no head for are left empty.
    """
    global LABEL_GROUPS
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Label groups are built from the annotation CSV (args.annotation_file).
    LABEL_GROUPS = load_label_groups(args)
    if not LABEL_GROUPS:
        logging.error("Could not load label groups. Exiting.")
        return
    if args.show_groups:
        print("=== LABEL_GROUPS from annotation file ===")
        print(json.dumps(LABEL_GROUPS, indent=2))
        return
    logging.info("--- Starting Inference ---")
    if not os.path.exists(args.master_csv_path):
        logging.error(f"Master CSV not found: {args.master_csv_path}")
        return
    df = pd.read_csv(args.master_csv_path)

    # Determine which column to use for model selection ('class' or 'species').
    model_key_col = None
    if 'class' in df.columns:
        model_key_col = 'class'
        logging.info("Using 'class' column for model selection.")
        # dropna/astype(str): sorted() cannot order NaN or mixed-type values.
        unique_classes = sorted(df['class'].dropna().astype(str).unique())
        logging.info(f"Unique classes in 'class' column: {unique_classes}")
    elif 'species' in df.columns:
        model_key_col = 'species'
        logging.info("Using 'species' column for model selection.")
    else:
        logging.warning("Neither 'class' nor 'species' column found in master CSV. Only the global model will be used.")

    models_cache, groups_used = load_models(device, args, LABEL_GROUPS)
    if not models_cache:
        return

    # Each model decodes with the label groups it was built with: species models may carry their
    # own label_groups.json, whose class order differs from the global groups.
    idx_to_label = {
        tag: {group: dict(enumerate(classes.values())) for group, classes in model.label_groups.items()}
        for tag, model in models_cache.items()
    }
    # Output columns: the global groups first, then any group only some species model has.
    output_groups = list(groups_used)
    for model in models_cache.values():
        for group in model.label_groups:
            if group not in output_groups:
                output_groups.append(group)

    # Visualization sampling
    indices_to_visualize = set()
    if args.visualize and len(df) > 0:
        indices_to_visualize = set(random.sample(df.index.tolist(), min(args.visualize_count, len(df))))

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    results = {g: [] for g in output_groups}
    used_model_col = []
    with torch.no_grad():
        for index, row in tqdm(df.iterrows(), total=df.shape[0], desc="Running Inference"):
            img_path = _cutout_path(row, args.cutouts_dir)
            model_key = row.get(model_key_col, '') if model_key_col else ''
            model_tag = model_key if model_key in models_cache else '__global__'
            used_model_col.append(model_tag)

            if not os.path.exists(img_path):
                logging.warning(f"Missing cutout: {img_path}")
                for g in output_groups:
                    results[g].append('not_found')
                continue

            try:
                with Image.open(img_path) as opened:  # closes the file handle promptly
                    img = opened.convert('RGB')
                tensor = transform(img).unsqueeze(0).to(device)
                outputs = models_cache[model_tag](tensor)
                predicted_labels = {}
                for g, logits in outputs.items():
                    _, pred_idx = torch.max(logits, 1)
                    predicted_labels[g] = idx_to_label[model_tag][g][pred_idx.item()]
            except Exception as e:
                logging.error(f"Error processing {img_path}: {e}")
                for g in output_groups:
                    results[g].append('error')
                continue

            # Append only once every head decoded, so all result columns keep the same length.
            for g in output_groups:
                results[g].append(predicted_labels.get(g, ''))

            if args.visualize and index in indices_to_visualize:
                # A visualisation failure must not affect the predictions already recorded.
                try:
                    _save_visualization(img, img_path, model_tag, predicted_labels, args.visualize_dir)
                except Exception as e:
                    logging.error(f"Error visualizing {img_path}: {e}")

    for g, preds in results.items():
        df[g] = preds
    df['classifier_model_used'] = used_model_col
    out_dir = os.path.dirname(args.output_csv_path)
    if out_dir:  # os.makedirs('') raises for a bare file name
        os.makedirs(out_dir, exist_ok=True)
    df.to_csv(args.output_csv_path, index=False)
    logging.info(f"Inference saved to {args.output_csv_path}")
def main(argv=None):
    """Parse command-line options, configure logging and run classifier inference.

    Args:
        argv: Optional list of argument strings. ``None`` (default) parses ``sys.argv[1:]``,
            which is what the original no-argument call did.
    """
    parser = argparse.ArgumentParser(description="Run inference with a multi-head classifier on leaf cutouts.",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # --- Path Arguments ---
    parser.add_argument("--master-csv-path", type=str, default=DEFAULT_MASTER_CSV_PATH, help="Path to the master results CSV file from the segmenter.")
    parser.add_argument("--cutouts-dir", type=str, default=DEFAULT_CUTOUTS_DIR, help="Directory containing the masked leaf cutouts.")
    parser.add_argument("--global-model-dir", type=str, default=DEFAULT_GLOBAL_MODEL_DIR, help="Directory for the global model (contains best_multihead_classifier.pt).")
    parser.add_argument("--per-species-model-dir", type=str, default=DEFAULT_PER_SPECIES_MODEL_DIR, help="Base directory for per-species models.")
    parser.add_argument("--legacy-model-path", type=str, default=DEFAULT_LEGACY_MODEL_PATH, help="Path to a single legacy model file (fallback).")
    parser.add_argument("--output-csv-path", type=str, default=DEFAULT_OUTPUT_CSV_PATH, help="Path to save the updated CSV with classification results.")
    parser.add_argument("--visualize-dir", type=str, default=DEFAULT_VISUALIZE_DIR, help="Directory to save visualized inference images.")
    parser.add_argument("--log-file", type=str, default=DEFAULT_LOG_FILE, help="Path to the master log file.")
    parser.add_argument("--annotation-file", type=str, default=ANNOTATION_CSV_PATH, help="Path to the annotation CSV file used to build the label groups.")
    # --- Control Arguments ---
    parser.add_argument("--visualize-count", type=int, default=300, help="Number of random samples to visualize.")
    parser.add_argument("--no-visualize", action="store_false", dest="visualize", help="Disable saving of visualized inference images.")
    parser.add_argument("--show-groups", action="store_true", help="Display dynamic label groups and exit without running inference.")
    parser.add_argument("--use-sample-data", action="store_true", help="Use the sample annotation file (with the default results CSV and cutouts directory) for validation.")
    parser.add_argument("--use-sample-model", action="store_true", help="Use the pretrained sample model for inference.")
    args = parser.parse_args(argv)

    # --- Override paths if using sample data / models (these take precedence over explicit paths) ---
    if args.use_sample_data:
        args.master_csv_path = DEFAULT_MASTER_CSV_PATH
        args.cutouts_dir = SAMPLE_CUTOUTS_DIR
        args.annotation_file = SAMPLE_ANNOTATION_CSV_PATH
    if args.use_sample_model:
        args.global_model_dir = "models/pretrained_model_data/classification/pretrained_models/all_species"
        args.per_species_model_dir = "models/pretrained_model_data/classification/pretrained_models/species"
        args.legacy_model_path = DEFAULT_LEGACY_MODEL_PATH
        args.annotation_file = SAMPLE_ANNOTATION_CSV_PATH

    setup_logging(args.log_file)
    logging.info("--- Starting script: 53_inference_classifier.py ---")
    run_inference(args)
    logging.info("--- Finished script: 53_inference_classifier.py ---")


if __name__ == "__main__":
    main()