import json
import os
import time

import tensorrt
import torch
from polygraphy import cuda

import comfy.model_management as mm
import folder_paths
from comfy.model_management import get_torch_device

from .trt_utilities import Engine
from .utilities import ColoredLogger, download_file
from .vfi_utilities import generate_frames_rife, logger, postprocess_frames, preprocess_frames
913
# Directory where built RIFE TensorRT engines are stored.
ENGINE_DIR = os.path.join(folder_paths.models_dir, "tensorrt", "rife")

# Min/opt/max square frame dimensions used for the TensorRT build profile.
IMAGE_DIM_MIN = 256
IMAGE_DIM_OPT = 512
IMAGE_DIM_MAX = 3840

# Module-level logger shared by the RIFE TensorRT nodes.
rife_logger = ColoredLogger("ComfyUI-Rife-Tensorrt")
24+ # Function to load configuration
25+ def load_node_config (config_filename = "load_rife_config.json" ):
26+ """Loads node configuration from a JSON file."""
27+ current_dir = os .path .dirname (__file__ )
28+ config_path = os .path .join (current_dir , config_filename )
29+
30+ default_config = {
31+ "model" : {
32+ "options" : ["rife49_ensemble_True_scale_1_sim" ],
33+ "default" : "rife49_ensemble_True_scale_1_sim" ,
34+ "tooltip" : "Default model (fallback from code)"
35+ },
36+ "precision" : {
37+ "options" : ["fp16" , "fp32" ],
38+ "default" : "fp16" ,
39+ "tooltip" : "Default precision (fallback from code)"
40+ }
41+ }
42+
43+ try :
44+ with open (config_path , 'r' ) as f :
45+ config = json .load (f )
46+ rife_logger .info (f"Successfully loaded configuration from { config_filename } " )
47+ return config
48+ except FileNotFoundError :
49+ rife_logger .warning (f"Configuration file '{ config_path } ' not found. Using default fallback configuration." )
50+ return default_config
51+ except json .JSONDecodeError :
52+ rife_logger .error (f"Error decoding JSON from '{ config_path } '. Using default fallback configuration." )
53+ return default_config
54+ except Exception as e :
55+ rife_logger .error (f"An unexpected error occurred while loading '{ config_path } ': { e } . Using default fallback." )
56+ return default_config
57+
58+ # Load the configuration once when the module is imported
59+ LOAD_RIFE_NODE_CONFIG = load_node_config ()
60+
class LoadRifeTensorrtModel:
    """ComfyUI node that loads a RIFE TensorRT engine.

    If the engine file does not exist it is built from the corresponding ONNX
    model, which in turn is downloaded from Hugging Face when missing.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Build the node's input schema from the pre-loaded JSON configuration.

        Each field falls back to a sensible hard-coded default when its key is
        missing from LOAD_RIFE_NODE_CONFIG.
        """
        model_config = LOAD_RIFE_NODE_CONFIG.get("model", {})
        precision_config = LOAD_RIFE_NODE_CONFIG.get("precision", {})

        model_options = model_config.get("options", ["rife49_ensemble_True_scale_1_sim"])
        model_default = model_config.get("default", "rife49_ensemble_True_scale_1_sim")
        model_tooltip = model_config.get("tooltip", "Select a RIFE model.")

        precision_options = precision_config.get("options", ["fp16", "fp32"])
        precision_default = precision_config.get("default", "fp16")
        precision_tooltip = precision_config.get("tooltip", "Select precision.")

        return {
            "required": {
                "model": (model_options, {"default": model_default, "tooltip": model_tooltip}),
                "precision": (precision_options, {"default": precision_default, "tooltip": precision_tooltip}),
            }
        }

    RETURN_NAMES = ("rife_trt_model",)
    RETURN_TYPES = ("RIFE_TRT_MODEL",)
    CATEGORY = "tensorrt"
    DESCRIPTION = "Load RIFE tensorrt models, they will be built automatically if not found."
    FUNCTION = "load_rife_tensorrt_model"

    def load_rife_tensorrt_model(self, model, precision):
        """Return a 1-tuple holding a loaded TensorRT ``Engine`` for *model*.

        :param model: ONNX model name (without extension), e.g.
            ``rife49_ensemble_True_scale_1_sim``.
        :param precision: ``"fp16"`` or ``"fp32"``; controls the engine build.
        :raises RuntimeError: if the TensorRT engine build fails.
        """
        tensorrt_models_dir = os.path.join(folder_paths.models_dir, "tensorrt", "rife")
        onnx_models_dir = os.path.join(folder_paths.models_dir, "onnx")

        os.makedirs(tensorrt_models_dir, exist_ok=True)
        os.makedirs(onnx_models_dir, exist_ok=True)

        onnx_model_path = os.path.join(onnx_models_dir, f"{model}.onnx")

        # The min/opt/max input profile (batch, channels, height, width).
        # Defined once and reused for both the engine file name and the build,
        # so the two can never drift apart.
        engine_channel = 3
        min_shape = (1, engine_channel, IMAGE_DIM_MIN, IMAGE_DIM_MIN)
        opt_shape = (1, engine_channel, IMAGE_DIM_OPT, IMAGE_DIM_OPT)
        max_shape = (1, engine_channel, IMAGE_DIM_MAX, IMAGE_DIM_MAX)

        def _shape_tag(shape):
            # e.g. (1, 3, 256, 256) -> "1x3x256x256"
            return "x".join(str(d) for d in shape)

        # Engine file name encodes precision, the full input profile and the
        # TensorRT version so incompatible cached engines are never reused.
        tensorrt_model_path = os.path.join(
            tensorrt_models_dir,
            f"{model}_{precision}_{_shape_tag(min_shape)}_{_shape_tag(opt_shape)}_{_shape_tag(max_shape)}_{tensorrt.__version__}.trt",
        )

        if not os.path.exists(tensorrt_model_path):
            if not os.path.exists(onnx_model_path):
                onnx_model_download_url = f"https://huggingface.co/yuvraj108c/rife-onnx/resolve/main/{model}.onnx"
                rife_logger.info(f"Downloading {onnx_model_download_url}")
                download_file(url=onnx_model_download_url, save_path=onnx_model_path)
            else:
                rife_logger.info(f"ONNX model found at: {onnx_model_path}")

            rife_logger.info(f"Building TensorRT engine for {onnx_model_path}: {tensorrt_model_path}")
            mm.soft_empty_cache()
            s = time.time()
            engine = Engine(tensorrt_model_path)
            ret = engine.build(
                onnx_path=onnx_model_path,
                fp16=(precision == "fp16"),
                input_profile=[
                    {
                        "img0": [min_shape, opt_shape, max_shape],
                        "img1": [min_shape, opt_shape, max_shape],
                    }
                ],
            )
            if ret != 0:
                # Remove a partially written engine file so the next run rebuilds
                # instead of loading a corrupt engine.
                if os.path.exists(tensorrt_model_path):
                    os.remove(tensorrt_model_path)
                raise RuntimeError(f"Failed to build TensorRT engine from {onnx_model_path}")
            e = time.time()
            rife_logger.info(f"Time taken to build: {(e - s)} seconds")

        rife_logger.info(f"Loading TensorRT engine: {tensorrt_model_path}")
        mm.soft_empty_cache()
        engine = Engine(tensorrt_model_path)
        engine.load()

        return (engine,)
141+
12142class RifeTensorrt :
13143 @classmethod
14144 def INPUT_TYPES (s ):
15145 return {
16146 "required" : {
17- "frames" : ("IMAGE" , ),
18- "engine " : (os . listdir ( ENGINE_DIR ), ),
19- "clear_cache_after_n_frames" : ("INT" , {"default" : 100 , "min" : 1 , "max" : 1000 }),
20- "multiplier" : ("INT" , {"default" : 2 , "min" : 1 }),
21- "use_cuda_graph" : ("BOOLEAN" , {"default" : True }),
22- "keep_model_loaded" : ("BOOLEAN" , {"default" : False }),
147+ "frames" : ("IMAGE" , { "tooltip" : "Input frames for video frame interpolation" } ),
148+ "rife_trt_model " : ("RIFE_TRT_MODEL" , { "tooltip" : "Tensorrt model built and loaded" } ),
149+ "clear_cache_after_n_frames" : ("INT" , {"default" : 100 , "min" : 1 , "max" : 1000 , "tooltip" : "Clear CUDA cache after processing this many frames" }),
150+ "multiplier" : ("INT" , {"default" : 2 , "min" : 1 , "tooltip" : "Frame interpolation multiplier" }),
151+ "use_cuda_graph" : ("BOOLEAN" , {"default" : True , "tooltip" : "Use CUDA graph for better performance" }),
152+ "keep_model_loaded" : ("BOOLEAN" , {"default" : False , "tooltip" : "Keep model loaded in memory after processing" }),
23153 },
24154 }
25155
@@ -31,7 +161,7 @@ def INPUT_TYPES(s):
31161 def vfi (
32162 self ,
33163 frames ,
34- engine ,
164+ rife_trt_model ,
35165 clear_cache_after_n_frames = 100 ,
36166 multiplier = 2 ,
37167 use_cuda_graph = True ,
@@ -45,24 +175,21 @@ def vfi(
45175 }
46176
47177 cudaStream = cuda .Stream ()
48- engine_path = os .path .join (ENGINE_DIR , engine )
49- if (not hasattr (self , 'engine' ) or self .engine_label != engine ):
50- self .engine = Engine (engine_path )
51- logger (f"Loading TensorRT engine: { engine_path } " )
52- self .engine .load ()
53- self .engine .activate ()
54- self .engine_label = engine
55- else :
56- logger (f"Using cached TensorRT engine: { engine_path } " )
57-
58- self .engine .allocate_buffers (shape_dict = shape_dict )
178+
179+ # Use the provided model directly
180+ engine = rife_trt_model
181+ logger (f"Using loaded TensorRT engine" )
182+
183+ # Activate and allocate buffers for the engine
184+ engine .activate ()
185+ engine .allocate_buffers (shape_dict = shape_dict )
59186
60187 frames = preprocess_frames (frames )
61188
62189 def return_middle_frame (frame_0 , frame_1 , timestep ):
63190 timestep_t = torch .tensor ([timestep ], dtype = torch .float32 ).to (get_torch_device ())
64191 # s = time.time()
65- output = self . engine .infer ({"img0" : frame_0 , "img1" : frame_1 , "timestep" : timestep_t }, cudaStream , use_cuda_graph )
192+ output = engine .infer ({"img0" : frame_0 , "img1" : frame_1 , "timestep" : timestep_t }, cudaStream , use_cuda_graph )
66193 # e = time.time()
67194 # print(f"Time taken to infer: {(e-s)*1000} ms")
68195
@@ -71,19 +198,21 @@ def return_middle_frame(frame_0, frame_1, timestep):
71198
72199 result = generate_frames_rife (frames , clear_cache_after_n_frames , multiplier , return_middle_frame )
73200 out = postprocess_frames (result )
74-
201+
75202 if not keep_model_loaded :
76- del self . engine , self . engine_label
203+ engine . reset ()
77204
78205 return (out ,)
79206
80207
# Registration tables consumed by ComfyUI on import.
NODE_CLASS_MAPPINGS = {
    "RifeTensorrt": RifeTensorrt,
    "LoadRifeTensorrtModel": LoadRifeTensorrtModel,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "RifeTensorrt": "⚡ Rife Tensorrt",
    "LoadRifeTensorrtModel": "Load Rife Tensorrt Model",
}

__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']