Skip to content

Commit e7c216d

Browse files
authored
Merge pull request #1 from ComfyAssets/fix/pr14-review-fixes
Fix error handling and CUDA cleanup from PR yuvraj108c#14
2 parents 41686c5 + 5d54ca6 commit e7c216d

File tree

6 files changed

+374
-46
lines changed

6 files changed

+374
-46
lines changed

__init__.py

Lines changed: 150 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,153 @@
33
from comfy.model_management import get_torch_device
44
from .vfi_utilities import preprocess_frames, postprocess_frames, generate_frames_rife, logger
55
from .trt_utilities import Engine
6+
from .utilities import download_file, ColoredLogger
67
import folder_paths
78
import time
89
from polygraphy import cuda
10+
import comfy.model_management as mm
11+
import tensorrt
12+
import json
913

1014
ENGINE_DIR = os.path.join(folder_paths.models_dir, "tensorrt", "rife")
1115

16+
# Spatial bounds (in pixels) for the TensorRT engine optimization profile:
# engines are built to accept any resolution between MIN and MAX, tuned for OPT.
IMAGE_DIM_MIN = 256
IMAGE_DIM_OPT = 512
IMAGE_DIM_MAX = 3840

# Module-wide logger used by the RIFE TensorRT nodes.
rife_logger = ColoredLogger("ComfyUI-Rife-Tensorrt")
23+
24+
# Function to load configuration
25+
def load_node_config(config_filename="load_rife_config.json"):
    """Load node configuration from a JSON file next to this module.

    Args:
        config_filename: Name of the JSON config file, resolved relative to
            this module's directory.

    Returns:
        dict: The parsed configuration, or a built-in default configuration
        if the file is missing, malformed, or unreadable. This function
        never raises; failures are logged and the fallback is returned so
        the node can always be constructed.
    """
    current_dir = os.path.dirname(__file__)
    config_path = os.path.join(current_dir, config_filename)

    # Fallback used whenever the JSON file cannot be loaded.
    default_config = {
        "model": {
            "options": ["rife49_ensemble_True_scale_1_sim"],
            "default": "rife49_ensemble_True_scale_1_sim",
            "tooltip": "Default model (fallback from code)"
        },
        "precision": {
            "options": ["fp16", "fp32"],
            "default": "fp16",
            "tooltip": "Default precision (fallback from code)"
        }
    }

    try:
        # Explicit UTF-8: tooltips may contain non-ASCII characters, and the
        # platform default encoding is not guaranteed to be UTF-8.
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        rife_logger.info(f"Successfully loaded configuration from {config_filename}")
        return config
    except FileNotFoundError:
        rife_logger.warning(f"Configuration file '{config_path}' not found. Using default fallback configuration.")
        return default_config
    except json.JSONDecodeError:
        rife_logger.error(f"Error decoding JSON from '{config_path}'. Using default fallback configuration.")
        return default_config
    except Exception as e:
        # Catch-all boundary: config loading must never prevent the module
        # from importing, so log and fall back.
        rife_logger.error(f"An unexpected error occurred while loading '{config_path}': {e}. Using default fallback.")
        return default_config
57+
58+
# Read the node configuration exactly once, at import time; INPUT_TYPES
# consults this cached dict on every call.
LOAD_RIFE_NODE_CONFIG = load_node_config()
60+
61+
class LoadRifeTensorrtModel:
    """ComfyUI node that downloads a RIFE ONNX model (if needed), builds a
    TensorRT engine for it (if needed), and returns the loaded engine."""

    @classmethod
    def INPUT_TYPES(cls):
        # Use the pre-loaded configuration; every lookup has a code-level
        # fallback so a partially-populated config file still works.
        model_config = LOAD_RIFE_NODE_CONFIG.get("model", {})
        precision_config = LOAD_RIFE_NODE_CONFIG.get("precision", {})

        model_options = model_config.get("options", ["rife49_ensemble_True_scale_1_sim"])
        model_default = model_config.get("default", "rife49_ensemble_True_scale_1_sim")
        model_tooltip = model_config.get("tooltip", "Select a RIFE model.")

        precision_options = precision_config.get("options", ["fp16", "fp32"])
        precision_default = precision_config.get("default", "fp16")
        precision_tooltip = precision_config.get("tooltip", "Select precision.")

        return {
            "required": {
                "model": (model_options, {"default": model_default, "tooltip": model_tooltip}),
                "precision": (precision_options, {"default": precision_default, "tooltip": precision_tooltip}),
            }
        }

    RETURN_NAMES = ("rife_trt_model",)
    RETURN_TYPES = ("RIFE_TRT_MODEL",)
    CATEGORY = "tensorrt"
    DESCRIPTION = "Load RIFE tensorrt models, they will be built automatically if not found."
    FUNCTION = "load_rife_tensorrt_model"

    def load_rife_tensorrt_model(self, model, precision):
        """Return a loaded TensorRT ``Engine`` for *model* at *precision*.

        Downloads the ONNX model from HuggingFace and builds the engine on
        first use; subsequent calls reuse the cached ``.trt`` file.

        Args:
            model: Base name of the RIFE ONNX model (without extension).
            precision: Either ``"fp16"`` or ``"fp32"``.

        Returns:
            tuple: One-element tuple containing the activated-ready engine.

        Raises:
            RuntimeError: If the TensorRT engine build fails.
        """
        tensorrt_models_dir = os.path.join(folder_paths.models_dir, "tensorrt", "rife")
        onnx_models_dir = os.path.join(folder_paths.models_dir, "onnx")

        os.makedirs(tensorrt_models_dir, exist_ok=True)
        os.makedirs(onnx_models_dir, exist_ok=True)

        onnx_model_path = os.path.join(onnx_models_dir, f"{model}.onnx")

        # Build tensorrt model path with detailed naming: the filename encodes
        # precision, the full min/opt/max profile, and the TensorRT version so
        # engines built under different settings never collide.
        engine_channel = 3
        engine_min_batch, engine_opt_batch, engine_max_batch = 1, 1, 1
        engine_min_h, engine_opt_h, engine_max_h = IMAGE_DIM_MIN, IMAGE_DIM_OPT, IMAGE_DIM_MAX
        engine_min_w, engine_opt_w, engine_max_w = IMAGE_DIM_MIN, IMAGE_DIM_OPT, IMAGE_DIM_MAX
        tensorrt_model_path = os.path.join(tensorrt_models_dir, f"{model}_{precision}_{engine_min_batch}x{engine_channel}x{engine_min_h}x{engine_min_w}_{engine_opt_batch}x{engine_channel}x{engine_opt_h}x{engine_opt_w}_{engine_max_batch}x{engine_channel}x{engine_max_h}x{engine_max_w}_{tensorrt.__version__}.trt")

        if not os.path.exists(tensorrt_model_path):
            if not os.path.exists(onnx_model_path):
                onnx_model_download_url = f"https://huggingface.co/yuvraj108c/rife-onnx/resolve/main/{model}.onnx"
                rife_logger.info(f"Downloading {onnx_model_download_url}")
                download_file(url=onnx_model_download_url, save_path=onnx_model_path)
            else:
                rife_logger.info(f"ONNX model found at: {onnx_model_path}")

            rife_logger.info(f"Building TensorRT engine for {onnx_model_path}: {tensorrt_model_path}")
            # Free as much VRAM as possible before the memory-hungry build.
            mm.soft_empty_cache()
            s = time.time()
            engine = Engine(tensorrt_model_path)
            ret = engine.build(
                onnx_path=onnx_model_path,
                # Direct boolean instead of the redundant `True if ... else False`.
                fp16=(precision == "fp16"),
                input_profile=[
                    {
                        "img0": [(engine_min_batch, engine_channel, engine_min_h, engine_min_w), (engine_opt_batch, engine_channel, engine_opt_h, engine_opt_w), (engine_max_batch, engine_channel, engine_max_h, engine_max_w)],
                        "img1": [(engine_min_batch, engine_channel, engine_min_h, engine_min_w), (engine_opt_batch, engine_channel, engine_opt_h, engine_opt_w), (engine_max_batch, engine_channel, engine_max_h, engine_max_w)],
                    }
                ],
            )
            # Non-zero return indicates a failed build; remove any partial
            # engine file so a corrupt artifact is never picked up later.
            if ret != 0:
                if os.path.exists(tensorrt_model_path):
                    os.remove(tensorrt_model_path)
                raise RuntimeError(f"Failed to build TensorRT engine from {onnx_model_path}")
            e = time.time()
            rife_logger.info(f"Time taken to build: {(e-s)} seconds")

        rife_logger.info(f"Loading TensorRT engine: {tensorrt_model_path}")
        mm.soft_empty_cache()
        engine = Engine(tensorrt_model_path)
        engine.load()

        return (engine,)
141+
12142
class RifeTensorrt:
13143
@classmethod
14144
def INPUT_TYPES(s):
15145
return {
16146
"required": {
17-
"frames": ("IMAGE", ),
18-
"engine": (os.listdir(ENGINE_DIR),),
19-
"clear_cache_after_n_frames": ("INT", {"default": 100, "min": 1, "max": 1000}),
20-
"multiplier": ("INT", {"default": 2, "min": 1}),
21-
"use_cuda_graph": ("BOOLEAN", {"default": True}),
22-
"keep_model_loaded": ("BOOLEAN", {"default": False}),
147+
"frames": ("IMAGE", {"tooltip": "Input frames for video frame interpolation"}),
148+
"rife_trt_model": ("RIFE_TRT_MODEL", {"tooltip": "Tensorrt model built and loaded"}),
149+
"clear_cache_after_n_frames": ("INT", {"default": 100, "min": 1, "max": 1000, "tooltip": "Clear CUDA cache after processing this many frames"}),
150+
"multiplier": ("INT", {"default": 2, "min": 1, "tooltip": "Frame interpolation multiplier"}),
151+
"use_cuda_graph": ("BOOLEAN", {"default": True, "tooltip": "Use CUDA graph for better performance"}),
152+
"keep_model_loaded": ("BOOLEAN", {"default": False, "tooltip": "Keep model loaded in memory after processing"}),
23153
},
24154
}
25155

@@ -31,7 +161,7 @@ def INPUT_TYPES(s):
31161
def vfi(
32162
self,
33163
frames,
34-
engine,
164+
rife_trt_model,
35165
clear_cache_after_n_frames=100,
36166
multiplier=2,
37167
use_cuda_graph=True,
@@ -45,24 +175,21 @@ def vfi(
45175
}
46176

47177
cudaStream = cuda.Stream()
48-
engine_path = os.path.join(ENGINE_DIR, engine)
49-
if (not hasattr(self, 'engine') or self.engine_label != engine):
50-
self.engine = Engine(engine_path)
51-
logger(f"Loading TensorRT engine: {engine_path}")
52-
self.engine.load()
53-
self.engine.activate()
54-
self.engine_label = engine
55-
else:
56-
logger(f"Using cached TensorRT engine: {engine_path}")
57-
58-
self.engine.allocate_buffers(shape_dict=shape_dict)
178+
179+
# Use the provided model directly
180+
engine = rife_trt_model
181+
logger(f"Using loaded TensorRT engine")
182+
183+
# Activate and allocate buffers for the engine
184+
engine.activate()
185+
engine.allocate_buffers(shape_dict=shape_dict)
59186

60187
frames = preprocess_frames(frames)
61188

62189
def return_middle_frame(frame_0, frame_1, timestep):
63190
timestep_t = torch.tensor([timestep], dtype=torch.float32).to(get_torch_device())
64191
# s = time.time()
65-
output = self.engine.infer({"img0": frame_0, "img1": frame_1, "timestep": timestep_t}, cudaStream, use_cuda_graph)
192+
output = engine.infer({"img0": frame_0, "img1": frame_1, "timestep": timestep_t}, cudaStream, use_cuda_graph)
66193
# e = time.time()
67194
# print(f"Time taken to infer: {(e-s)*1000} ms")
68195

@@ -71,19 +198,21 @@ def return_middle_frame(frame_0, frame_1, timestep):
71198

72199
result = generate_frames_rife(frames, clear_cache_after_n_frames, multiplier, return_middle_frame)
73200
out = postprocess_frames(result)
74-
201+
75202
if not keep_model_loaded:
76-
del self.engine, self.engine_label
203+
engine.reset()
77204

78205
return (out,)
79206

80207

81208
# Registry consumed by ComfyUI: internal node id -> implementing class.
NODE_CLASS_MAPPINGS = {
    "RifeTensorrt": RifeTensorrt,
    "LoadRifeTensorrtModel": LoadRifeTensorrtModel,
}

# Human-readable titles shown in the ComfyUI node picker.
NODE_DISPLAY_NAME_MAPPINGS = {
    "RifeTensorrt": "⚡ Rife Tensorrt",
    "LoadRifeTensorrtModel": "Load Rife Tensorrt Model",
}

# Explicit public API of this package.
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]

load_rife_config.json

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
{
2+
"model": {
3+
"options": [
4+
"rife47_ensemble_True_scale_1_sim",
5+
"rife48_ensemble_True_scale_1_sim",
6+
"rife49_ensemble_True_scale_1_sim"
7+
],
8+
"default": "rife49_ensemble_True_scale_1_sim",
9+
"tooltip": "RIFE models for video frame interpolation. These models have been tested with tensorrt. Loaded from config."
10+
},
11+
"precision": {
12+
"options": ["fp16", "fp32"],
13+
"default": "fp16",
14+
"tooltip": "Precision to build the tensorrt engines. Loaded from config."
15+
}
16+
}

readme.md

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
# ComfyUI Rife TensorRT ⚡
44

5-
[![python](https://img.shields.io/badge/python-3.10.12-green)](https://www.python.org/downloads/release/python-31012/)
6-
[![cuda](https://img.shields.io/badge/cuda-12.4-green)](https://developer.nvidia.com/cuda-downloads)
7-
[![trt](https://img.shields.io/badge/TRT-10.4.0-green)](https://developer.nvidia.com/tensorrt)
5+
[![python](https://img.shields.io/badge/python-3.12.11-green)](https://www.python.org/downloads/release/python-31211/)
6+
[![cuda](https://img.shields.io/badge/cuda-12.9-green)](https://developer.nvidia.com/cuda-downloads)
7+
[![trt](https://img.shields.io/badge/TRT-10.13.3.9-green)](https://developer.nvidia.com/tensorrt)
88
[![by-nc-sa/4.0](https://img.shields.io/badge/license-CC--BY--NC--SA--4.0-lightgrey)](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en)
99

1010
![node](https://github.com/user-attachments/assets/5fd6d529-300c-42a5-b9cf-46e031f0bcb5)
@@ -40,27 +40,40 @@ cd ./ComfyUI-Rife-Tensorrt
4040
pip install -r requirements.txt
4141
```
4242

43-
## 🛠️ Building Tensorrt Engine
43+
## 🛠️ Supported Models
4444

45-
1. Download one of the following onnx models:
46-
- [rife49_ensemble_True_scale_1_sim.onnx](https://huggingface.co/yuvraj108c/rife-onnx/resolve/main/rife49_ensemble_True_scale_1_sim.onnx)
47-
- [rife48_ensemble_True_scale_1_sim.onnx](https://huggingface.co/yuvraj108c/rife-onnx/resolve/main/rife48_ensemble_True_scale_1_sim.onnx)
48-
- [rife47_ensemble_True_scale_1_sim.onnx](https://huggingface.co/yuvraj108c/rife-onnx/resolve/main/rife47_ensemble_True_scale_1_sim.onnx)
49-
2. Edit onnx/trt paths inside [export_trt.py](./export_trt.py) and build tensorrt engine by running:
50-
- `python export_trt.py`
45+
The following RIFE models are supported and will be automatically downloaded and built:
46+
- **rife49_ensemble_True_scale_1_sim** (default) - Latest and most accurate
47+
- **rife48_ensemble_True_scale_1_sim** - Good balance of speed and quality
48+
- **rife47_ensemble_True_scale_1_sim** - Fastest option
5149

52-
3. Place the exported engine inside ComfyUI `/models/tensorrt/rife` directory
50+
Models are automatically downloaded from [HuggingFace](https://huggingface.co/yuvraj108c/rife-onnx) and TensorRT engines are built on first use.
5351

5452
## ☀️ Usage
5553

56-
- Insert node by `Right Click -> tensorrt -> Rife Tensorrt`
57-
- Image resolutions between `256x256` and `3840x3840` will work with the tensorrt engines
54+
1. **Load Model**: Insert `Right Click -> Add Node -> tensorrt -> Load Rife Tensorrt Model`
55+
- Choose your preferred RIFE model (rife47, rife48, or rife49)
56+
- Select precision (fp16 recommended for speed, fp32 for maximum accuracy)
57+
- The model will be automatically downloaded and TensorRT engine built on first use
58+
59+
2. **Process Frames**: Insert `Right Click -> Add Node -> tensorrt -> Rife Tensorrt`
60+
- Connect the loaded model from step 1
61+
- Input your video frames
62+
- Configure interpolation settings (multiplier, CUDA graph, etc.)
63+
- Image resolutions between `256x256` and `3840x3840` are supported
5864

5965
## 🤖 Environment tested
6066

61-
- Ubuntu 22.04 LTS, Cuda 12.4, Tensorrt 10.4.0, Python 3.10, RTX 3070 GPU
67+
- WSL Ubuntu 24.04.3 LTS, Cuda 12.9, Tensorrt 10.13.3.9, Python 3.12.11, RTX 5080 GPU
6268
- Windows (Not tested, but should work)
6369

70+
## 🚨 Updates
71+
72+
### December 2025
73+
- **Automatic Model Management**: No more manual downloads! Models are automatically downloaded from HuggingFace and TensorRT engines are built on demand
74+
- **Improved Workflow**: New two-node system with `Load Rife Tensorrt Model` + `Rife Tensorrt` for better organization
75+
- **Updated Dependencies**: TensorRT updated to 10.13.3.9 for better performance and compatibility
76+
6477
## 👏 Credits
6578

6679
- https://github.com/styler00dollar/VSGAN-tensorrt-docker

requirements.txt

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
einops
22
colored
33
polygraphy
4-
tensorrt==10.4.0
5-
cuda-python
4+
#tensorrt==10.13.3.9
5+
tensorrt
6+
cuda-python
7+
requests
8+
9+
tqdm

0 commit comments

Comments
 (0)