-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathwrapper_inference.py
More file actions
148 lines (120 loc) · 5.04 KB
/
wrapper_inference.py
File metadata and controls
148 lines (120 loc) · 5.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
"""
================================================================================
Author: Andrea Iommi
Code Ownership:
- All Python source code in this file is written solely by the author.
Documentation Notice:
- All docstrings and inline documentation are written by ChatGPT,
but thoroughly checked and approved by the author for accuracy.
================================================================================
"""
from pathlib import Path
import torch
from PIL.Image import Image
from pandas import DataFrame
from torchvision.transforms import transforms as T
from model.resnet_model import CNNClassifier
from utils import load_json
class ModelInference:
    """
    A utility class for running inference on images using a trained CNN model.
    This class handles:
    - Loading the model and its configuration (cached between calls, reloaded
      only when a different model directory is requested).
    - Preparing and transforming input images.
    - Running predictions and returning results as a pandas DataFrame.
    Notes
    -----
    - Code authored entirely by the project author.
    - Documentation generated by ChatGPT and subsequently reviewed by the author.
    """

    def __init__(self):
        """
        Initializes the ModelInference class by setting placeholders for model
        attributes and defining preprocessing transformations.
        Attributes
        ----------
        model_dir : str or None
            Path to the directory of the currently loaded model. Used to avoid reloading unnecessarily.
        model : torch.nn.Module or None
            The loaded CNN model instance, moved to the desired device.
        params : dict or None
            Model parameters loaded from `params.json`, including number of classes and pretrained flag.
        idx2class : dict or None
            Mapping of class indices to class labels, loaded from `idx2class.json`.
        device : str or None
            Device the cached model currently lives on.
        tfms : torchvision.transforms.Compose
            A composition of image preprocessing transformations:
            - Resize to (400, 400)
            - Convert to tensor
            - Normalize pixel values
        """
        self.model_dir = None
        self.model = None
        self.params = None
        self.idx2class = None
        self.device = None
        self.tfms = T.Compose([
            T.Resize((400, 400)),  # Resize images to 400x400 pixels
            T.ToTensor(),  # Convert PIL Image to PyTorch Tensor
            T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize image pixel values
        ])

    def prepare_img(self, img: Image) -> torch.Tensor:
        """
        Applies preprocessing transformations to an input image and prepares it for inference.
        Parameters
        ----------
        img : PIL.Image.Image
            The input image to preprocess. Non-RGB PIL images (grayscale,
            palette, RGBA, ...) are converted to RGB first, because the
            normalization step expects exactly three channels.
        Returns
        -------
        torch.Tensor
            A tensor of shape (1, C, H, W), ready for model inference.
        """
        if isinstance(img, Image) and img.mode != "RGB":
            # Normalize() carries 3-channel statistics; 1- or 4-channel
            # tensors would fail to broadcast against them.
            img = img.convert("RGB")
        return self.tfms(img).unsqueeze(0)

    @staticmethod
    def _ordered_labels(idx2class: dict) -> list:
        """
        Return the class labels ordered by their integer class index.

        JSON object keys are strings, and their stored order need not match
        index order (e.g. "0", "1", "10", "2" when written with sorted keys),
        so labels are sorted by ``int(key)`` to line up with the logit
        columns. If a key is not integer-like, the mapping's own insertion
        order is used instead.
        """
        try:
            pairs = sorted(idx2class.items(), key=lambda kv: int(kv[0]))
        except (TypeError, ValueError):
            pairs = list(idx2class.items())
        return [label for _, label in pairs]

    def _load_model(self, model_name, device: str) -> None:
        """
        Load config, label mapping and weights from ``model_name`` and cache them.

        Everything is loaded into locals first and committed to ``self`` only
        after all steps succeed. A failed load therefore never leaves a new
        ``model_dir``/``idx2class`` paired with a stale model.
        """
        model_path = Path(model_name)
        # Load model configuration and label mapping
        params = load_json(model_path / "params.json")
        idx2class = load_json(model_path / "idx2class.json")
        # Initialize and load model weights.
        # NOTE(review): if `pretrained` makes CNNClassifier fetch backbone
        # weights, that work is overwritten by load_state_dict below — confirm
        # whether the flag is needed here.
        model = CNNClassifier(
            num_classes=params["num_classes"],
            pretrained=params["pretrained"],
            model_name="resnet18",
        )
        # map_location lets weights saved on a GPU be loaded on a CPU-only host.
        state_dict = torch.load(
            model_path / "model.pth", map_location=device, weights_only=True
        )
        model.load_state_dict(state_dict)
        model.eval()
        model = model.to(device)
        # Commit the new state only after every step has succeeded.
        self.model_dir = model_name
        self.params = params
        self.idx2class = idx2class
        self.model = model
        self.device = device

    def inference(self, img: Image, device: str, model_name: str) -> DataFrame:
        """
        Performs inference on a single image using a specified model.
        Parameters
        ----------
        img : PIL.Image.Image
            The input image to classify.
        device : str
            Device to run inference on, e.g. "cpu" or "cuda". A cached model
            is moved to this device if it currently lives elsewhere.
        model_name : str
            Path to the model directory. Must contain:
            - `params.json` : configuration file (e.g., number of classes, pretrained flag).
            - `idx2class.json` : mapping from class indices to labels.
            - `model.pth` : trained model weights.
        Returns
        -------
        pandas.DataFrame
            A DataFrame with two columns:
            - "index": the class label
            - "prob": the predicted probability for that class
            Sorted in ascending order by probability.
        """
        # Load model only if nothing is cached yet or a different directory is requested
        if (
            self.model is None
            or self.model_dir is None
            or Path(self.model_dir) != Path(model_name)
        ):
            self._load_model(model_name, device)
        elif self.device != device:
            # Cached model lives on another device; keep model and input together.
            self.model = self.model.to(device)
            self.device = device

        # Preprocess and move image to device
        batch = self.prepare_img(img).to(device)
        # Run inference without tracking gradients
        with torch.no_grad():
            logits = self.model(batch)
        probs = torch.softmax(logits, dim=1).cpu().numpy()
        # One row ("prob") x one column per class, transposed into
        # (label, prob) rows and sorted ascending by probability.
        return DataFrame(
            data=probs,
            columns=self._ordered_labels(self.idx2class),
            index=["prob"],
        ).T.reset_index().sort_values(by="prob", ascending=True)