|
1 | 1 | """ |
2 | | -Vision Transformer (ViT) |
3 | | -======================== |
| 2 | +Vision Transformer (ViT) Module |
| 3 | +================================ |
4 | 4 |
|
5 | | -This module demonstrates how to use a pretrained Vision Transformer (ViT) |
6 | | -for image classification using Hugging Face's Transformers library. |
| 5 | +Classify images using a pretrained Vision Transformer (ViT) |
| 6 | +from Hugging Face Transformers. |
| 7 | +
|
| 8 | +Can be used as a demo or imported in other scripts. |
7 | 9 |
|
8 | 10 | Source: |
9 | 11 | https://huggingface.co/docs/transformers/model_doc/vit |
10 | 12 | """ |
11 | 13 |
|
| 14 | +try: |
| 15 | + import requests |
| 16 | + import torch |
| 17 | + from io import BytesIO |
| 18 | + from PIL import Image |
| 19 | + from transformers import ViTForImageClassification, ViTImageProcessor |
| 20 | +except ImportError as e: |
| 21 | + raise ImportError( |
| 22 | + "This module requires 'torch', 'transformers', 'PIL', and 'requests'. " |
| 23 | + "Install them with: pip install torch transformers pillow requests" |
| 24 | + ) from e |
12 | 25 |
|
13 | | -def vision_transformer_demo() -> None: |
14 | | - """ |
15 | | - Demonstrates Vision Transformer (ViT) on a sample image. |
16 | 26 |
|
17 | | - Example: |
18 | | - >>> vision_transformer_demo() # doctest: +SKIP |
19 | | - Predicted label: tabby, tabby cat |
20 | | - """ |
21 | | - try: |
22 | | - import requests |
23 | | - import torch |
24 | | - from PIL import Image |
25 | | - from transformers import ViTForImageClassification, ViTImageProcessor |
26 | | - except ImportError as e: |
27 | | - raise ImportError( |
28 | | - "This demo requires 'torch', 'transformers', 'PIL', and 'requests' packages. " |
29 | | - "Install them with: pip install torch transformers pillow requests" |
30 | | - ) from e |
31 | | - |
32 | | - # Load a sample image |
33 | | - url = ( |
34 | | - "https://huggingface.co/datasets/huggingface/documentation-images/" |
35 | | - "resolve/main/cat_sample.jpeg" |
36 | | - ) |
37 | | - image = Image.open(requests.get(url, stream=True, timeout=10).raw) |
38 | | - |
39 | | - # Load pretrained model and processor |
def classify_image(image: Image.Image) -> str:
    """Return the predicted ImageNet label for *image*.

    Uses the pretrained ``google/vit-base-patch16-224`` checkpoint from
    Hugging Face Transformers.

    Args:
        image: A PIL image to classify.

    Returns:
        The human-readable label of the top-scoring class.
    """
    checkpoint = "google/vit-base-patch16-224"
    processor = ViTImageProcessor.from_pretrained(checkpoint)
    model = ViTForImageClassification.from_pretrained(checkpoint)

    # Resize/normalize the image and pack it into a batched tensor.
    encoded = processor(images=image, return_tensors="pt")

    # Inference only -- gradients are not needed.
    with torch.no_grad():
        logits = model(**encoded).logits

    best_idx = logits.argmax(-1).item()
    return model.config.id2label[best_idx]
| 40 | + |
| 41 | + |
def demo(url: str | None = None) -> None:
    """
    Run a demo: download an image and print its predicted ViT label.

    Args:
        url: URL of the image to classify. If ``None``, a default sample
            cat photo is used.
    """
    if url is None:
        # default example image
        url = "https://images.unsplash.com/photo-1592194996308-7b43878e84a6"

    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
    except Exception as e:  # best-effort demo: report the failure, don't crash
        print(f"Failed to load image from {url}. Error: {e}")
        return

    label = classify_image(image)
    print(f"Predicted label: {label}")
55 | 62 |
|
56 | 63 |
|
# Run the demo with the default sample image when executed as a script.
if __name__ == "__main__":
    demo()
0 commit comments