Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
234 changes: 234 additions & 0 deletions backend/PtToONNX.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "code",
"source": [
"!pip install ultralytics\n",
"!pip install onnxruntime"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Sam1qCq56Ic8",
"outputId": "4e4a7801-ce5b-41c9-e8b5-f5e47ae91dc3"
},
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Collecting ultralytics\n",
" Downloading ultralytics-8.3.212-py3-none-any.whl.metadata (37 kB)\n",
"Requirement already satisfied: numpy>=1.23.0 in /usr/local/lib/python3.12/dist-packages (from ultralytics) (2.0.2)\n",
"Requirement already satisfied: matplotlib>=3.3.0 in /usr/local/lib/python3.12/dist-packages (from ultralytics) (3.10.0)\n",
"Requirement already satisfied: opencv-python>=4.6.0 in /usr/local/lib/python3.12/dist-packages (from ultralytics) (4.12.0.88)\n",
"Requirement already satisfied: pillow>=7.1.2 in /usr/local/lib/python3.12/dist-packages (from ultralytics) (11.3.0)\n",
"Requirement already satisfied: pyyaml>=5.3.1 in /usr/local/lib/python3.12/dist-packages (from ultralytics) (6.0.3)\n",
"Requirement already satisfied: requests>=2.23.0 in /usr/local/lib/python3.12/dist-packages (from ultralytics) (2.32.4)\n",
"Requirement already satisfied: scipy>=1.4.1 in /usr/local/lib/python3.12/dist-packages (from ultralytics) (1.16.2)\n",
"Requirement already satisfied: torch>=1.8.0 in /usr/local/lib/python3.12/dist-packages (from ultralytics) (2.8.0+cu126)\n",
"Requirement already satisfied: torchvision>=0.9.0 in /usr/local/lib/python3.12/dist-packages (from ultralytics) (0.23.0+cu126)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.12/dist-packages (from ultralytics) (5.9.5)\n",
"Requirement already satisfied: polars in /usr/local/lib/python3.12/dist-packages (from ultralytics) (1.25.2)\n",
"Collecting ultralytics-thop>=2.0.0 (from ultralytics)\n",
" Downloading ultralytics_thop-2.0.17-py3-none-any.whl.metadata (14 kB)\n",
"Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.3.0->ultralytics) (1.3.3)\n",
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.3.0->ultralytics) (0.12.1)\n",
"Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.3.0->ultralytics) (4.60.1)\n",
"Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.3.0->ultralytics) (1.4.9)\n",
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.3.0->ultralytics) (25.0)\n",
"Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.3.0->ultralytics) (3.2.5)\n",
"Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.3.0->ultralytics) (2.9.0.post0)\n",
"Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests>=2.23.0->ultralytics) (3.4.3)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests>=2.23.0->ultralytics) (3.10)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests>=2.23.0->ultralytics) (2.5.0)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests>=2.23.0->ultralytics) (2025.10.5)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (3.20.0)\n",
"Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (4.15.0)\n",
"Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (75.2.0)\n",
"Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (1.13.3)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (3.5)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (3.1.6)\n",
"Requirement already satisfied: fsspec in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (2025.3.0)\n",
"Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (12.6.77)\n",
"Requirement already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (12.6.77)\n",
"Requirement already satisfied: nvidia-cuda-cupti-cu12==12.6.80 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (12.6.80)\n",
"Requirement already satisfied: nvidia-cudnn-cu12==9.10.2.21 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (9.10.2.21)\n",
"Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (12.6.4.1)\n",
"Requirement already satisfied: nvidia-cufft-cu12==11.3.0.4 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (11.3.0.4)\n",
"Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (10.3.7.77)\n",
"Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (11.7.1.2)\n",
"Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (12.5.4.2)\n",
"Requirement already satisfied: nvidia-cusparselt-cu12==0.7.1 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (0.7.1)\n",
"Requirement already satisfied: nvidia-nccl-cu12==2.27.3 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (2.27.3)\n",
"Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (12.6.77)\n",
"Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (12.6.85)\n",
"Requirement already satisfied: nvidia-cufile-cu12==1.11.1.6 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (1.11.1.6)\n",
"Requirement already satisfied: triton==3.4.0 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (3.4.0)\n",
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.12/dist-packages (from python-dateutil>=2.7->matplotlib>=3.3.0->ultralytics) (1.17.0)\n",
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch>=1.8.0->ultralytics) (1.3.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->torch>=1.8.0->ultralytics) (3.0.3)\n",
"Downloading ultralytics-8.3.212-py3-none-any.whl (1.1 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m22.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading ultralytics_thop-2.0.17-py3-none-any.whl (28 kB)\n",
"Installing collected packages: ultralytics-thop, ultralytics\n",
"Successfully installed ultralytics-8.3.212 ultralytics-thop-2.0.17\n",
"Collecting onnxruntime\n",
" Downloading onnxruntime-1.23.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.0 kB)\n",
"Collecting coloredlogs (from onnxruntime)\n",
" Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)\n",
"Requirement already satisfied: flatbuffers in /usr/local/lib/python3.12/dist-packages (from onnxruntime) (25.9.23)\n",
"Requirement already satisfied: numpy>=1.21.6 in /usr/local/lib/python3.12/dist-packages (from onnxruntime) (2.0.2)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from onnxruntime) (25.0)\n",
"Requirement already satisfied: protobuf in /usr/local/lib/python3.12/dist-packages (from onnxruntime) (5.29.5)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.12/dist-packages (from onnxruntime) (1.13.3)\n",
"Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)\n",
" Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)\n",
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy->onnxruntime) (1.3.0)\n",
"Downloading onnxruntime-1.23.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (17.4 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m17.4/17.4 MB\u001b[0m \u001b[31m105.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m46.0/46.0 kB\u001b[0m \u001b[31m3.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m86.8/86.8 kB\u001b[0m \u001b[31m5.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hInstalling collected packages: humanfriendly, coloredlogs, onnxruntime\n",
"Successfully installed coloredlogs-15.0.1 humanfriendly-10.0 onnxruntime-1.23.1\n"
]
}
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "pFYmWFk76B1r",
"outputId": "12528276-42e1-4f18-c4d8-6300a9884200"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Creating new Ultralytics Settings v0.0.6 file ✅ \n",
"View Ultralytics Settings with 'yolo settings' or at '/root/.config/Ultralytics/settings.json'\n",
"Update Settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'. For help see https://docs.ultralytics.com/quickstart/#ultralytics-settings.\n",
"Ultralytics 8.3.212 🚀 Python-3.12.11 torch-2.8.0+cu126 CPU (AMD EPYC 7B12)\n",
"Model summary (fused): 112 layers, 68,124,531 parameters, 0 gradients, 257.4 GFLOPs\n",
"\n",
"\u001b[34m\u001b[1mPyTorch:\u001b[0m starting from 'best.pt' with input shape (1, 3, 640, 640) BCHW and output shape(s) (1, 5, 8400) (130.4 MB)\n",
"\u001b[31m\u001b[1mrequirements:\u001b[0m Ultralytics requirements ['onnx>=1.12.0', 'onnxslim>=0.1.71'] not found, attempting AutoUpdate...\n",
"\n",
"\u001b[31m\u001b[1mrequirements:\u001b[0m AutoUpdate success ✅ 2.4s\n",
"WARNING ⚠️ \u001b[31m\u001b[1mrequirements:\u001b[0m \u001b[1mRestart runtime or rerun command for updates to take effect\u001b[0m\n",
"\n",
"\n",
"\u001b[34m\u001b[1mONNX:\u001b[0m starting export with onnx 1.19.1 opset 22...\n",
"\u001b[34m\u001b[1mONNX:\u001b[0m slimming with onnxslim 0.1.71...\n",
"\u001b[34m\u001b[1mONNX:\u001b[0m export success ✅ 18.8s, saved as 'best.onnx' (260.1 MB)\n",
"\n",
"Export complete (26.4s)\n",
"Results saved to \u001b[1m/content\u001b[0m\n",
"Predict: yolo predict task=detect model=best.onnx imgsz=640 \n",
"Validate: yolo val task=detect model=best.onnx imgsz=640 data=/home/hice1/wchia7/scratch/Anole_classifier/Dataset/yolo_training/lizard_10000_v2/data.yaml \n",
"Visualize: https://netron.app\n",
"Model exported to ONNX using ultralytics export function.\n"
]
}
],
"source": [
"from ultralytics import YOLO\n",
"\n",
"# Load your trained YOLOv8 model\n",
"# Replace 'best.pt' with the path to your model file if it's different\n",
"model = YOLO('best.pt')\n",
"\n",
"# Export the model to ONNX format\n",
"# The exported ONNX model will be saved in the same directory as your model file\n",
"model.export(format='onnx')\n",
"\n",
"print(\"Model exported to ONNX using ultralytics export function.\")"
]
},
{
"cell_type": "code",
"source": [
"import onnxruntime\n",
"import numpy as np\n",
"from PIL import Image\n",
"import torchvision.transforms as transforms\n",
"\n",
"# Load the ONNX model\n",
"onnx_model_path = 'best.onnx'\n",
"ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=[\"CPUExecutionProvider\"])\n",
"\n",
"# Load and preprocess the image\n",
"image_path = '/content/36455_4751560.jpg'\n",
"img = Image.open(image_path).convert('RGB') # YOLOv8 expects 3 channels\n",
"\n",
"# The model was trained with an input size of 640x640, so we resize the image\n",
"transform = transforms.Compose([\n",
" transforms.Resize((640, 640)),\n",
" transforms.ToTensor(),\n",
"])\n",
"\n",
"input_tensor = transform(img)\n",
"\n",
"# ONNX Runtime expects numpy arrays, so convert the input tensor\n",
"input_numpy = input_tensor.unsqueeze(0).numpy() # Add batch dimension and convert to numpy\n",
"\n",
"# Get input name from ONNX model\n",
"input_name = ort_session.get_inputs()[0].name\n",
"\n",
"# Run inference with ONNX Runtime\n",
"onnxruntime_input = {input_name: input_numpy}\n",
"results_onnx = ort_session.run(None, onnxruntime_input)\n",
"\n",
"print(\"Inference with ONNX Runtime completed.\")\n",
"# The output format of ONNX Runtime might be different from PyTorch's,\n",
"# you'll likely need to post-process results_onnx to interpret detections.\n",
"print(results_onnx)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "njSoSjHw6Els",
"outputId": "4cd2149d-15dd-425b-b368-2c4df92e0045"
},
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Inference with ONNX Runtime completed.\n",
"[array([[[ 15.158, 20.97, 26.594, ..., 520.43, 559.43, 626.25],\n",
" [ 23.222, 19.643, 13.583, ..., 582.18, 580.32, 606.13],\n",
" [ 30.388, 40.937, 52.213, ..., 237.06, 159.98, 188.47],\n",
" [ 46.662, 39.313, 27.193, ..., 117.63, 144.51, 158.6],\n",
" [ 7.4506e-07, 2.0862e-07, 2.3842e-07, ..., 8.9407e-08, 0, 0]]], dtype=float32)]\n"
]
}
]
}
]
}