Skip to content

Commit 20a78b3

Browse files
committed
Improved example [skip ci]
1 parent dd42b18 commit 20a78b3

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

examples/pytorch_image_search.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,15 @@
2525

2626

2727
# load pretrained model
28+
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
2829
model = torchvision.models.resnet18(weights='DEFAULT')
2930
model.fc = torch.nn.Identity()
31+
model.to(device)
3032
model.eval()
3133

3234

3335
def generate_embeddings(inputs):
34-
return model(inputs).detach().numpy()
36+
return model(inputs.to(device)).detach().cpu().numpy()
3537

3638

3739
# generate, save, and index embeddings
@@ -53,7 +55,8 @@ def show_images(dataset_images):
5355
grid = torchvision.utils.make_grid(dataset_images)
5456
img = (grid / 2 + 0.5).permute(1, 2, 0).numpy()
5557
plt.imshow(img)
56-
plt.waitforbuttonpress()
58+
plt.draw()
59+
plt.waitforbuttonpress(timeout=3)
5760

5861

5962
# load 5 random unseen images

0 commit comments

Comments
 (0)