forked from yuvraj108c/ComfyUI-Rife-Tensorrt
-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy path__init__.py
More file actions
324 lines (272 loc) · 14 KB
/
__init__.py
File metadata and controls
324 lines (272 loc) · 14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
import torch
import os
import sys
from pathlib import Path
from comfy.model_management import get_torch_device
from .vfi_utilities import preprocess_frames, postprocess_frames, generate_frames_rife, logger
from .trt_utilities import Engine
from .utilities import download_file, ColoredLogger
import folder_paths
import time
# Auto-detect CUDA toolkit and add DLL path before importing polygraphy
def _setup_cuda_dll_path():
"""Auto-detect CUDA toolkit and add cudart64 DLL path on Windows."""
if not sys.platform.startswith("win"):
return
cuda_root = None
# Check for CUDA_PATH or CUDA_HOME environment variables
cuda_root = os.environ.get("CUDA_PATH") or os.environ.get("CUDA_HOME")
if not cuda_root:
# Try default Windows install location
program_files = os.environ.get("PROGRAMFILES")
if program_files:
cuda_base = Path(program_files) / "NVIDIA GPU Computing Toolkit" / "CUDA"
if cuda_base.exists():
# Find highest version directory
versions = sorted([d for d in cuda_base.iterdir() if d.is_dir()], reverse=True)
if versions:
cuda_root = str(versions[0])
if cuda_root:
cuda_path = Path(cuda_root)
# CUDA 13.0+ puts cudart64 in bin/x64 subdirectory
cuda_bin_x64 = cuda_path / "bin" / "x64"
if cuda_bin_x64.exists() and any(cuda_bin_x64.glob("cudart64*.dll")):
os.add_dll_directory(str(cuda_bin_x64))
return
# Fallback to regular bin directory for older CUDA versions
cuda_bin = cuda_path / "bin"
if cuda_bin.exists() and any(cuda_bin.glob("cudart64*.dll")):
os.add_dll_directory(str(cuda_bin))
return
# CUDA toolkit not found - print warning with download link
print("[ComfyUI-Rife-TensorRT] WARNING: CUDA toolkit not found.")
print(" Set CUDA_PATH environment variable or install CUDA toolkit.")
print(" Download: https://developer.nvidia.com/cuda-13-0-2-download-archive")
_setup_cuda_dll_path()
from polygraphy import cuda
import comfy.model_management as mm
import tensorrt
import json
ENGINE_DIR = os.path.join(folder_paths.models_dir, "tensorrt", "rife")
# Default resolution profiles (fallback if config is missing)
DEFAULT_RESOLUTION_PROFILES = {
"small": {"min": 384, "opt": 720, "max": 1080},
"medium": {"min": 672, "opt": 1080, "max": 1312}
}
# Logger for this module
rife_logger = ColoredLogger("ComfyUI-Rife-Tensorrt")
# Function to load configuration
def load_node_config(config_filename="load_rife_config.json"):
"""Loads node configuration from a JSON file."""
current_dir = os.path.dirname(__file__)
config_path = os.path.join(current_dir, config_filename)
default_config = {
"model": {
"options": ["rife49_ensemble_True_scale_1_sim"],
"default": "rife49_ensemble_True_scale_1_sim",
"tooltip": "Default model (fallback from code)"
},
"precision": {
"options": ["fp16", "fp32"],
"default": "fp16",
"tooltip": "Default precision (fallback from code)"
}
}
try:
with open(config_path, 'r') as f:
config = json.load(f)
rife_logger.info(f"Successfully loaded configuration from {config_filename}")
return config
except FileNotFoundError:
rife_logger.warning(f"Configuration file '{config_path}' not found. Using default fallback configuration.")
return default_config
except json.JSONDecodeError:
rife_logger.error(f"Error decoding JSON from '{config_path}'. Using default fallback configuration.")
return default_config
except Exception as e:
rife_logger.error(f"An unexpected error occurred while loading '{config_path}': {e}. Using default fallback.")
return default_config
# Load the configuration once when the module is imported
LOAD_RIFE_NODE_CONFIG = load_node_config()
class CustomResolutionConfig:
"""Node to configure custom resolution dimensions for TensorRT engine building."""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"min_dim": ("INT", {"default": 384, "min": 64, "max": 4096, "step": 8, "tooltip": "Minimum resolution dimension"}),
"opt_dim": ("INT", {"default": 720, "min": 64, "max": 4096, "step": 8, "tooltip": "Optimal resolution dimension (most common)"}),
"max_dim": ("INT", {"default": 1312, "min": 64, "max": 4096, "step": 8, "tooltip": "Maximum resolution dimension"}),
}
}
RETURN_TYPES = ("RIFE_RESOLUTION_CONFIG",)
RETURN_NAMES = ("resolution_config",)
FUNCTION = "configure"
CATEGORY = "tensorrt"
DESCRIPTION = "Configure custom resolution dimensions for RIFE TensorRT engine."
def configure(self, min_dim, opt_dim, max_dim):
config = {
"min": min_dim,
"opt": opt_dim,
"max": max_dim,
}
return (config,)
class AutoLoadRifeTensorrtModel:
@classmethod
def INPUT_TYPES(cls):
# Use the pre-loaded configuration
model_config = LOAD_RIFE_NODE_CONFIG.get("model", {})
precision_config = LOAD_RIFE_NODE_CONFIG.get("precision", {})
# Provide sensible defaults if keys are missing in the config
model_options = model_config.get("options", ["rife49_ensemble_True_scale_1_sim"])
model_default = model_config.get("default", "rife49_ensemble_True_scale_1_sim")
model_tooltip = model_config.get("tooltip", "Select a RIFE model.")
precision_options = precision_config.get("options", ["fp16", "fp32"])
precision_default = precision_config.get("default", "fp16")
precision_tooltip = precision_config.get("tooltip", "Select precision.")
# Resolution profile configuration
profile_config = LOAD_RIFE_NODE_CONFIG.get("resolution_profile", {})
profile_options = profile_config.get("options", ["small", "medium"])
# Ensure 'custom' is always available
if "custom" not in profile_options:
profile_options = profile_options + ["custom"]
profile_default = profile_config.get("default", "small")
profile_tooltip = profile_config.get("tooltip", "Resolution range for TensorRT engine. Use 'custom' with the INT inputs below.")
return {
"required": {
"model": (model_options, {"default": model_default, "tooltip": model_tooltip}),
"precision": (precision_options, {"default": precision_default, "tooltip": precision_tooltip}),
"resolution_profile": (profile_options, {"default": profile_default, "tooltip": profile_tooltip}),
},
"optional": {
"custom_config": ("RIFE_RESOLUTION_CONFIG", {"tooltip": "Custom resolution config (used when profile='custom')"}),
}
}
RETURN_NAMES = ("rife_trt_model",)
RETURN_TYPES = ("RIFE_TRT_MODEL",)
CATEGORY = "tensorrt"
DESCRIPTION = "Load RIFE tensorrt models, they will be built automatically if not found."
FUNCTION = "load_rife_tensorrt_model"
def load_rife_tensorrt_model(self, model, precision, resolution_profile, custom_config=None):
tensorrt_models_dir = os.path.join(folder_paths.models_dir, "tensorrt", "rife")
onnx_models_dir = os.path.join(folder_paths.models_dir, "onnx")
os.makedirs(tensorrt_models_dir, exist_ok=True)
os.makedirs(onnx_models_dir, exist_ok=True)
onnx_model_path = os.path.join(onnx_models_dir, f"{model}.onnx")
# Get resolution dimensions based on profile
if resolution_profile == "custom":
if custom_config is None:
rife_logger.warning("Custom profile selected but no custom_config provided. Using defaults that cover both small and medium ranges.")
dim_min, dim_opt, dim_max = 384, 720, 1312
else:
dim_min = custom_config.get("min", 384)
dim_opt = custom_config.get("opt", 720)
dim_max = custom_config.get("max", 1312)
# Use dimensions in profile name for custom engines
profile_name = f"custom_{dim_min}_{dim_opt}_{dim_max}"
else:
profiles = LOAD_RIFE_NODE_CONFIG.get("resolution_profiles", DEFAULT_RESOLUTION_PROFILES)
profile = profiles.get(resolution_profile, DEFAULT_RESOLUTION_PROFILES["small"])
dim_min = profile.get("min", 384)
dim_opt = profile.get("opt", 720)
dim_max = profile.get("max", 1080)
profile_name = resolution_profile
rife_logger.info(f"Using resolution profile '{profile_name}': min={dim_min}, opt={dim_opt}, max={dim_max}")
# Build tensorrt model path with detailed naming (includes profile)
engine_channel = 3
engine_min_batch, engine_opt_batch, engine_max_batch = 1, 1, 1
engine_min_h, engine_opt_h, engine_max_h = dim_min, dim_opt, dim_max
engine_min_w, engine_opt_w, engine_max_w = dim_min, dim_opt, dim_max
tensorrt_model_path = os.path.join(tensorrt_models_dir, f"{model}_{precision}_{profile_name}_{engine_min_batch}x{engine_channel}x{engine_min_h}x{engine_min_w}_{engine_opt_batch}x{engine_channel}x{engine_opt_h}x{engine_opt_w}_{engine_max_batch}x{engine_channel}x{engine_max_h}x{engine_max_w}_{tensorrt.__version__}.trt")
if not os.path.exists(tensorrt_model_path):
if not os.path.exists(onnx_model_path):
onnx_model_download_url = f"https://huggingface.co/yuvraj108c/rife-onnx/resolve/main/{model}.onnx"
rife_logger.info(f"Downloading {onnx_model_download_url}")
download_file(url=onnx_model_download_url, save_path=onnx_model_path)
else:
rife_logger.info(f"ONNX model found at: {onnx_model_path}")
rife_logger.info(f"Building TensorRT engine for {onnx_model_path}: {tensorrt_model_path}")
mm.soft_empty_cache()
s = time.time()
engine = Engine(tensorrt_model_path)
engine.build(
onnx_path=onnx_model_path,
fp16=True if precision == "fp16" else False,
input_profile=[
{
"img0": [(engine_min_batch, engine_channel, engine_min_h, engine_min_w), (engine_opt_batch, engine_channel, engine_opt_h, engine_opt_w), (engine_max_batch, engine_channel, engine_max_h, engine_max_w)],
"img1": [(engine_min_batch, engine_channel, engine_min_h, engine_min_w), (engine_opt_batch, engine_channel, engine_opt_h, engine_opt_w), (engine_max_batch, engine_channel, engine_max_h, engine_max_w)],
}
],
)
e = time.time()
rife_logger.info(f"Time taken to build: {(e-s)} seconds")
rife_logger.info(f"Loading TensorRT engine: {tensorrt_model_path}")
mm.soft_empty_cache()
engine = Engine(tensorrt_model_path)
engine.load()
return (engine,)
class AutoRifeTensorrt:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"frames": ("IMAGE", {"tooltip": "Input frames for video frame interpolation"}),
"rife_trt_model": ("RIFE_TRT_MODEL", {"tooltip": "Tensorrt model built and loaded"}),
"clear_cache_after_n_frames": ("INT", {"default": 100, "min": 1, "max": 1000, "tooltip": "Clear CUDA cache after processing this many frames"}),
"multiplier": ("INT", {"default": 2, "min": 1, "tooltip": "Frame interpolation multiplier"}),
"use_cuda_graph": ("BOOLEAN", {"default": True, "tooltip": "Use CUDA graph for better performance"}),
"keep_model_loaded": ("BOOLEAN", {"default": False, "tooltip": "Keep model loaded in memory after processing"}),
},
}
RETURN_TYPES = ("IMAGE", )
FUNCTION = "vfi"
CATEGORY = "tensorrt"
OUTPUT_NODE=True
def vfi(
self,
frames,
rife_trt_model,
clear_cache_after_n_frames=100,
multiplier=2,
use_cuda_graph=True,
keep_model_loaded=False,
):
B, H, W, C = frames.shape
shape_dict = {
"img0": {"shape": (1, 3, H, W)},
"img1": {"shape": (1, 3, H, W)},
"output": {"shape": (1, 3, H, W)},
}
cudaStream = cuda.Stream()
# Use the provided model directly
engine = rife_trt_model
logger(f"Using loaded TensorRT engine")
# Activate and allocate buffers for the engine
engine.activate()
engine.allocate_buffers(shape_dict=shape_dict)
frames = preprocess_frames(frames)
def return_middle_frame(frame_0, frame_1, timestep):
timestep_t = torch.tensor([timestep], dtype=torch.float32).to(get_torch_device())
# s = time.time()
output = engine.infer({"img0": frame_0, "img1": frame_1, "timestep": timestep_t}, cudaStream, use_cuda_graph)
# e = time.time()
# print(f"Time taken to infer: {(e-s)*1000} ms")
result = output['output']
return result
result = generate_frames_rife(frames, clear_cache_after_n_frames, multiplier, return_middle_frame)
out = postprocess_frames(result)
if not keep_model_loaded:
engine.reset()
return (out,)
NODE_CLASS_MAPPINGS = {
"AutoRifeTensorrt": AutoRifeTensorrt,
"AutoLoadRifeTensorrtModel": AutoLoadRifeTensorrtModel,
"CustomResolutionConfig": CustomResolutionConfig,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"AutoRifeTensorrt": "Auto RIFE TensorRT",
"AutoLoadRifeTensorrtModel": "(Down)load RIFE TensorRT Model",
"CustomResolutionConfig": "RIFE Custom Resolution Config",
}
__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']