Skip to content

Commit 1ba4fe1

Browse files
committed
Add Vision Transformer demo for image classification
1 parent 451b7ec commit 1ba4fe1

File tree

1 file changed

+42
-35
lines changed

1 file changed

+42
-35
lines changed
Lines changed: 42 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,65 @@
11
"""
2-
Vision Transformer (ViT)
3-
========================
2+
Vision Transformer (ViT) Module
3+
================================
44
5-
This module demonstrates how to use a pretrained Vision Transformer (ViT)
6-
for image classification using Hugging Face's Transformers library.
5+
Classify images using a pretrained Vision Transformer (ViT)
6+
from Hugging Face Transformers.
7+
8+
Can be used as a demo or imported in other scripts.
79
810
Source:
911
https://huggingface.co/docs/transformers/model_doc/vit
1012
"""
1113

14+
from functools import lru_cache
from io import BytesIO
from typing import Optional

try:
    import requests
    import torch
    from PIL import Image
    from transformers import ViTForImageClassification, ViTImageProcessor
except ImportError as e:
    raise ImportError(
        "This module requires 'torch', 'transformers', 'PIL', and 'requests'. "
        "Install them with: pip install torch transformers pillow requests"
    ) from e
1225

13-
def vision_transformer_demo() -> None:
14-
"""
15-
Demonstrates Vision Transformer (ViT) on a sample image.
1626

17-
Example:
18-
>>> vision_transformer_demo() # doctest: +SKIP
19-
Predicted label: tabby, tabby cat
20-
"""
21-
try:
22-
import requests
23-
import torch
24-
from PIL import Image
25-
from transformers import ViTForImageClassification, ViTImageProcessor
26-
except ImportError as e:
27-
raise ImportError(
28-
"This demo requires 'torch', 'transformers', 'PIL', and 'requests' packages. "
29-
"Install them with: pip install torch transformers pillow requests"
30-
) from e
31-
32-
# Load a sample image
33-
url = (
34-
"https://huggingface.co/datasets/huggingface/documentation-images/"
35-
"resolve/main/cat_sample.jpeg"
36-
)
37-
image = Image.open(requests.get(url, stream=True, timeout=10).raw)
38-
39-
# Load pretrained model and processor
27+
@lru_cache(maxsize=2)
def _load_vit(model_name: str):
    """Load and cache the pretrained ViT processor and model for *model_name*.

    `from_pretrained` downloads/deserializes weights, which is expensive;
    caching makes repeated `classify_image` calls pay that cost only once
    per checkpoint name.
    """
    processor = ViTImageProcessor.from_pretrained(model_name)
    model = ViTForImageClassification.from_pretrained(model_name)
    return processor, model


def classify_image(
    image: Image.Image,
    model_name: str = "google/vit-base-patch16-224",
) -> str:
    """Classify a PIL image using a pretrained Vision Transformer.

    Args:
        image: Input image; the processor handles resizing/normalization.
        model_name: Hugging Face model checkpoint to use. Defaults to the
            ImageNet-1k fine-tuned ViT-Base.

    Returns:
        The human-readable predicted class label.
    """
    processor, model = _load_vit(model_name)

    # Resize/normalize the image and batch it as PyTorch tensors.
    inputs = processor(images=image, return_tensors="pt")

    # Inference only — disable gradient tracking.
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]
40+
41+
42+
def demo(url: Optional[str] = None) -> None:
    """
    Run a classification demo on an image fetched from a URL.

    Args:
        url: URL of the image to classify. If None, a default sample
            image is used.
    """
    if url is None:
        url = "https://images.unsplash.com/photo-1592194996308-7b43878e84a6"  # default example image

    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
    except (requests.RequestException, OSError) as e:
        # requests.RequestException covers connection/timeout/HTTP errors;
        # PIL raises OSError (incl. UnidentifiedImageError) on undecodable data.
        print(f"Failed to load image from {url}. Error: {e}")
        return

    label = classify_image(image)
    print(f"Predicted label: {label}")
5562

5663

5764
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    demo()

0 commit comments

Comments
 (0)