Skip to content

Commit 092f5cc

Browse files
authored
Merge branch 'master' into copilot/migrate-to-python-312
2 parents 9131a98 + 4e341d4 commit 092f5cc

1 file changed

Lines changed: 28 additions & 0 deletions

File tree

src/onnx_cuda_inference.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from src.onnx_inference import ONNXInference
2+
from src.onnx_exporter import ONNXExporter
3+
import onnxruntime as ort
4+
import os
5+
6+
class ONNXInferenceCUDA(ONNXInference):
7+
def __init__(self, model_loader, model_path, debug_mode=False):
8+
"""
9+
Initialize the ONNXInference object.
10+
11+
:param model_loader: Object responsible for loading the model and categories.
12+
:param model_path: Path to the ONNX model.
13+
:param debug_mode: If True, print additional debug information.
14+
"""
15+
super().__init__(model_loader, model_path, debug_mode=debug_mode)
16+
17+
def load_model(self):
18+
"""
19+
Load the ONNX model. If the model does not exist, export it.
20+
21+
:return: Loaded ONNX model.
22+
"""
23+
if not os.path.exists(self.onnx_path):
24+
onnx_exporter = ONNXExporter(
25+
self.model_loader.model, self.model_loader.device, self.onnx_path
26+
)
27+
onnx_exporter.export_model()
28+
return ort.InferenceSession(self.onnx_path, providers=["CUDAExecutionProvider"])

0 commit comments

Comments
 (0)