File tree Expand file tree Collapse file tree 1 file changed +5
-2
lines changed
Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Original file line number Diff line number Diff line change 2525
2626
2727# load pretrained model
28+ device = torch .device ('mps' if torch .backends .mps .is_available () else 'cpu' )
2829model = torchvision .models .resnet18 (weights = 'DEFAULT' )
2930model .fc = torch .nn .Identity ()
31+ model .to (device )
3032model .eval ()
3133
3234
3335def generate_embeddings (inputs ):
34- return model (inputs ) .detach ().numpy ()
36+ return model (inputs . to ( device )) .detach (). cpu ().numpy ()
3537
3638
3739# generate, save, and index embeddings
@@ -53,7 +55,8 @@ def show_images(dataset_images):
5355 grid = torchvision .utils .make_grid (dataset_images )
5456 img = (grid / 2 + 0.5 ).permute (1 , 2 , 0 ).numpy ()
5557 plt .imshow (img )
56- plt .waitforbuttonpress ()
58+ plt .draw ()
59+ plt .waitforbuttonpress (timeout = 3 )
5760
5861
5962# load 5 random unseen images
You can’t perform that action at this time.
0 commit comments