-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathimage_processor.py
More file actions
54 lines (42 loc) · 1.71 KB
/
image_processor.py
File metadata and controls
54 lines (42 loc) · 1.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights
from PIL import Image
import torch
class ImageProcessor:
    """Load an image from disk and turn it into a single-image batch tensor."""

    def __init__(self, img_path: str, device: str = "cuda") -> None:
        """
        Initialize the ImageProcessor object.

        :param img_path: Path to the image to be processed.
        :param device: The device to place the resulting tensor on ("cpu" or "cuda").
        """
        self.img_path = img_path
        self.device = device

    def _load_rgb_image(self) -> "Image.Image":
        """Read ``self.img_path`` into memory as a 3-channel RGB image, closing the file."""
        # Image.open is lazy and holds the file handle until the image object is
        # garbage-collected; the context manager closes it deterministically.
        # convert("RGB") loads the pixels and normalizes grayscale / palette /
        # RGBA / CMYK inputs to the 3 channels the normalization step requires.
        with Image.open(self.img_path) as img:
            return img.convert("RGB")

    def _to_batch(self, img_tensor: torch.Tensor) -> torch.Tensor:
        """Prepend a batch dimension of size 1 and move the tensor to ``self.device``."""
        return torch.unsqueeze(img_tensor, 0).to(self.device)

    def process_image(self) -> torch.Tensor:
        """
        Process the image with Resize(256), CenterCrop(224), ToTensor and Normalize.

        Errors raised by PIL while opening or decoding the file propagate to the caller.

        :return: A batch of one transformed image, shape (1, 3, 224, 224),
            on ``self.device``.
        """
        img = self._load_rgb_image()
        transform = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                # Per-channel RGB mean/std; these are the commonly used ImageNet statistics.
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            ]
        )
        return self._to_batch(transform(img))

    def process_image_official(self) -> torch.Tensor:
        """
        Process the image with the pipeline bundled with ``ResNet50_Weights.DEFAULT``.

        The weights object supplies its own resize / crop / normalize pipeline via
        ``transforms()``, so no sizes or statistics are hard-coded here. Errors raised
        by PIL while opening or decoding the file propagate to the caller.

        :return: A batch of one transformed image, shape (1, 3, H, W) with H and W
            determined by the weights' transforms, on ``self.device``.
        """
        img = self._load_rgb_image()
        preprocess = ResNet50_Weights.DEFAULT.transforms()
        return self._to_batch(preprocess(img))