forked from pepisg/simple_segmentation_toolkit
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinfer.py
More file actions
executable file
·126 lines (99 loc) · 4.43 KB
/
infer.py
File metadata and controls
executable file
·126 lines (99 loc) · 4.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#!/usr/bin/env python3
"""Simple inference script for semantic segmentation."""
import argparse
import cv2
import random
import yaml
from pathlib import Path
import torch
from super_gradients.training import models
from training.visualizer import PredictionVisualizer
from training.dataset import SegmentationDataset
def find_latest_checkpoint(checkpoint_dir: Path) -> tuple:
    """
    Locate the most recently modified ``*.pth`` checkpoint under ``checkpoint_dir``.

    The model architecture name is derived from the checkpoint's parent
    directory (e.g. ``checkpoints/pp_lite_b_seg_custom/best.pth`` ->
    ``pp_lite_b_seg``), stripping a ``_custom`` or ``_segmentation`` suffix
    when present.

    Returns:
        (checkpoint_path, model_name) tuple

    Raises:
        FileNotFoundError: when no ``*.pth`` file exists under the directory.
    """
    candidates = [p for p in checkpoint_dir.glob("**/*.pth")]
    if not candidates:
        raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
    newest = max(candidates, key=lambda path: path.stat().st_mtime)
    # The parent directory name encodes the architecture, possibly with a
    # training-run suffix attached.
    parent_name = newest.parent.name
    for suffix in ("_custom", "_segmentation"):
        if suffix in parent_name:
            return newest, parent_name.replace(suffix, "")
    return newest, parent_name
def find_random_image(data_dir: Path = Path("data")) -> Path:
    """
    Pick a random ``*.png`` or ``*.jpg`` image under ``data_dir``.

    Searches ``raw_images`` first, then ``labeling/accepted/images``, and
    returns a random image from the first directory that contains any.

    Raises:
        FileNotFoundError: when neither directory yields an image.
    """
    candidate_dirs = (
        data_dir / "raw_images",
        data_dir / "labeling" / "accepted" / "images",
    )
    for folder in candidate_dirs:
        if not folder.exists():
            continue
        found = [*folder.glob("*.png"), *folder.glob("*.jpg")]
        if found:
            return random.choice(found)
    raise FileNotFoundError("No images found in data/raw_images or data/labeling/accepted/images")
def main():
    """Entry point: parse CLI args, load model and image, run inference, show overlay.

    Raises:
        FileNotFoundError: if the input image cannot be read, or no
            checkpoint/image can be auto-discovered.
    """
    parser = argparse.ArgumentParser(description="Run inference on an image")
    parser.add_argument("image", type=str, nargs='?', default=None, help="Path to input image (random if not specified)")
    parser.add_argument("--config", type=str, default="configs/ontology.yaml", help="Config file")
    parser.add_argument("--checkpoint", type=str, default=None, help="Model checkpoint (finds latest if not specified)")
    parser.add_argument("--model", type=str, default="ddrnet_39", help="Model architecture")
    args = parser.parse_args()

    print("\n" + "="*60)
    print("INFERENCE SCRIPT")
    print("="*60)

    # Resolve the input image: explicit path, or a random one from the data dirs.
    image_path = Path(args.image) if args.image else find_random_image()
    print(f"✓ Image: {image_path}")

    # Load ontology config: per-class names and colors.
    with open(args.config, 'r') as f:
        config = yaml.safe_load(f)
    class_names = [cls['name'] for cls in config['ontology']['classes']]
    class_colors = [cls['color'] for cls in config['ontology']['classes']]
    num_classes = len(class_names) + 1  # +1 presumably for background — TODO confirm against training

    # Resolve checkpoint; when not given, pick the newest one and infer the
    # architecture name from its directory.
    if args.checkpoint:
        checkpoint_path = Path(args.checkpoint)
        model_name = args.model
    else:
        checkpoint_path, inferred_model = find_latest_checkpoint(Path("models/checkpoints"))
        model_name = inferred_model
        print(f"✓ Model: {model_name} (auto-detected)")
    print(f"✓ Checkpoint: {checkpoint_path}")

    # Build the architecture and load trained weights on CPU.
    model = models.get(model_name, num_classes=num_classes)
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    # Fix: tolerate checkpoints saved as a raw state-dict (no
    # 'model_state_dict' wrapper key) instead of raising KeyError.
    # Behavior is unchanged for wrapped checkpoints.
    state_dict = checkpoint.get('model_state_dict', checkpoint)
    model.load_state_dict(state_dict)
    model.eval()

    # Load and preprocess image.
    image = cv2.imread(str(image_path))
    if image is None:
        # Fix: cv2.imread returns None for missing/unreadable files; fail
        # with a clear error instead of a cryptic cvtColor crash.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    input_tensor = torch.from_numpy(rgb_image).permute(2, 0, 1).float() / 255.0
    # Apply ImageNet normalization (should match training preprocessing — TODO confirm).
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    input_tensor = (input_tensor - mean) / std

    # Run inference: argmax over class logits -> per-pixel class-id map.
    with torch.no_grad():
        output = model(input_tensor.unsqueeze(0))
    if isinstance(output, (list, tuple)):
        # Some architectures return auxiliary heads; the primary output is first.
        output = output[0]
    prediction = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()

    # Blend the colorized prediction over the original BGR image.
    visualizer = PredictionVisualizer(class_names, class_colors)
    pred_colored = visualizer.create_colored_mask(prediction)
    overlay = cv2.addWeighted(image, 0.7, pred_colored, 0.3, 0)

    # Show result until a key is pressed.
    cv2.imshow(f"Segmentation Result - {image_path.name}", overlay)
    print("\n✓ Showing result... Press any key to close")
    print("="*60)
    cv2.waitKey(0)
    cv2.destroyAllWindows()


if __name__ == "__main__":
    main()