forked from weaviate/i2v-pytorch-models
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathimage2vec_vit.py
More file actions
73 lines (58 loc) · 2.53 KB
/
image2vec_vit.py
File metadata and controls
73 lines (58 loc) · 2.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import logging
import threading
import torch
from PIL import Image
from transformers import ViTImageProcessor, ViTModel
# Configure root logging once at import time: timestamped INFO-level lines
# for everything this module (and anything sharing the root config) emits.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)
# Hugging Face model id used for both the model weights and the processor.
# https://huggingface.co/google/vit-base-patch16-224
MODEL_NAME = "google/vit-base-patch16-224"
class Img2VecViT:
    """Turn images into fixed-size vectors with a Vision Transformer.

    Loads the ``MODEL_NAME`` ViT model and its image processor once at
    construction; :meth:`get_vec` then maps an image file to a 768-dim
    embedding (mean of the encoder's last hidden state over all tokens).
    """

    def __init__(self, cuda_support, cuda_core):
        """Load the model and processor onto the requested device.

        :param cuda_support: when truthy, run the model on ``cuda_core``;
            otherwise fall back to CPU.
        :param cuda_core: CUDA device string (e.g. ``"cuda:0"``), used only
            when ``cuda_support`` is truthy.
        :raises ValueError: if the loaded model's hidden size is not 768.
        """
        self.device = torch.device(cuda_core if cuda_support else "cpu")
        self.model = ViTModel.from_pretrained(MODEL_NAME)
        self.layer_output_size = self.model.config.hidden_size
        # Downstream consumers expect exactly 768-dim vectors, so refuse
        # any ViT variant with a different hidden size up front.
        if self.layer_output_size != 768:
            raise ValueError(
                "Only ViT models with hidden size of 768 are supported at the moment"
            )
        self.model = self.model.to(self.device)
        self.model.eval()
        self.processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
        # Serializes inference so concurrent get_vec() callers do not run
        # the model simultaneously.
        self.lock = threading.Lock()

    def get_vec(self, image_path):
        """Return a 768-dim numpy vector for the image at *image_path*.

        :param image_path: path (or file-like object) accepted by
            ``PIL.Image.open``.
        :return: 1-D ``numpy.ndarray`` of length 768.
        """
        img = Image.open(image_path).convert("RGB")
        # If one of the image dimensions is 1 or 3 it can confuse the
        # `infer_channel_dimension_format` function in the `transformers`
        # library, so we force 'channels_last' in that case; None lets the
        # processor infer the format itself.
        input_data_format = (
            "channels_last" if img.width in (1, 3) or img.height in (1, 3) else None
        )
        try:
            inputs = self.processor(
                images=img, return_tensors="pt", input_data_format=input_data_format
            )
        except ValueError:
            # Converting a PIL.Image.Image to a numpy array should yield shape
            # (width, height, channels) as long as .convert("RGB") was applied
            # above, so this branch should never trigger; it is kept as a
            # precaution in case the image conversion method ever changes.
            # logger.exception also records the traceback for diagnosis.
            logger.exception(
                "Unable to infer color channel format, defaulting to 'channels_first'"
            )
            inputs = self.processor(
                images=img, return_tensors="pt", input_data_format="channels_first"
            )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        # Hold the lock for the forward pass only; no gradients are needed
        # for inference.
        with self.lock, torch.no_grad():
            outputs = self.model(**inputs)
            # Mean-pool over the token dimension -> (batch, hidden_size).
            features = outputs.last_hidden_state.mean(dim=1)
        return features.cpu().numpy()[0]