Skip to content

Commit cd7412c

Browse files
authored
feat: Update torchvision model loading (#427)
1 parent 395a780 commit cd7412c

1 file changed

Lines changed: 3 additions & 17 deletions

File tree

examples/instance_kind/model.py

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# Copyright 2023-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
#
33
# Redistribution and use in source and binary forms, with or without
44
# modification, are permitted provided that the following conditions
@@ -24,8 +24,8 @@
2424
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
2525
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2626

27-
import numpy as np
2827
import torch
28+
from torchvision import models
2929
import triton_python_backend_utils as pb_utils
3030
from torch.utils.dlpack import to_dlpack
3131

@@ -49,21 +49,7 @@ def initialize(self, args):
4949
device = "cuda" if args["model_instance_kind"] == "GPU" else "cpu"
5050
device_id = args["model_instance_device_id"]
5151
self.device = f"{device}:{device_id}"
52-
# This example is configured to work with torch=1.13
53-
# and torchvision=0.14. Thus, we need to provide a proper tag `0.14.1`
54-
# to make sure loaded Resnet50 is compatible with
55-
# installed `torchvision`.
56-
# Refer to README for installation instructions.
57-
self.model = (
58-
torch.hub.load(
59-
"pytorch/vision:v0.14.1",
60-
"resnet50",
61-
weights="IMAGENET1K_V2",
62-
skip_validation=True,
63-
)
64-
.to(self.device)
65-
.eval()
66-
)
52+
self.model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2).to(self.device).eval()
6753

6854
def execute(self, requests):
6955
"""

0 commit comments

Comments
 (0)