@@ -86,58 +86,61 @@ class VaceInputInfo:
8686
8787@dataclass
8888class S2VInputInfo :
89- infer_steps : int | Any = UNSET
90- seed : int | Any = UNSET
91- prompt : str | Any = UNSET
92- prompt_enhanced : str | Any = UNSET
93- negative_prompt : str | Any = UNSET
94- image_path : str | Any = UNSET
95- audio_path : str | Any = UNSET
96- audio_num : int | Any = UNSET
97- video_duration : float | Any = UNSET
98- with_mask : bool | Any = UNSET
99- return_result_tensor : bool | Any = UNSET
100- save_result_path : str | Any = UNSET
101- return_result_tensor : bool | Any = UNSET
102- stream_config : dict | Any = UNSET
103- resize_mode : str | Any = UNSET
104- target_shape : list | Any = UNSET
89+ seed : int = field (default_factory = int )
90+ prompt : str = field (default_factory = str )
91+ prompt_enhanced : str = field (default_factory = str )
92+ negative_prompt : str = field (default_factory = str )
93+ image_path : str = field (default_factory = str )
94+ audio_path : str = field (default_factory = str )
95+ audio_num : int = field (default_factory = int )
96+ with_mask : bool = field (default_factory = lambda : False )
97+ save_result_path : str = field (default_factory = str )
98+ return_result_tensor : bool = field (default_factory = lambda : False )
99+ stream_config : dict = field (default_factory = dict )
100+ # shape related
101+ resize_mode : str = field (default_factory = str )
102+ original_shape : list = field (default_factory = list )
103+ resized_shape : list = field (default_factory = list )
104+ latent_shape : list = field (default_factory = list )
105+ target_shape : list = field (default_factory = list )
106+
105107 # prev info
106- overlap_frame : torch .Tensor | Any = UNSET
107- overlap_latent : torch .Tensor | Any = UNSET
108+ overlap_frame : torch .Tensor = field ( default_factory = lambda : None )
109+ overlap_latent : torch .Tensor = field ( default_factory = lambda : None )
108110 # input preprocess audio
109- audio_clip : torch .Tensor | Any = UNSET
110-
111- @classmethod
112- def from_args (cls , args , ** overrides ):
113- """
114- Build InputInfo from argparse.Namespace (or any object with __dict__)
115- Priority:
116- args < overrides
117- """
118- field_names = {f .name for f in fields (cls )}
119- data = {k : v for k , v in vars (args ).items () if k in field_names }
120- data .update (overrides )
121- return cls (** data )
122-
123- def normalize_unset_to_none (self ):
124- """
125- Replace all UNSET fields with None.
126- Call this right before running / inference.
127- """
128- for f in fields (self ):
129- if getattr (self , f .name ) is UNSET :
130- setattr (self , f .name , None )
131- return self
111+ audio_clip : torch .Tensor = field (default_factory = lambda : None )
132112
133113
134114@dataclass
135- class RS2VInputInfo (S2VInputInfo ):
115+ class RS2VInputInfo :
116+ seed : int = field (default_factory = int )
117+ prompt : str = field (default_factory = str )
118+ prompt_enhanced : str = field (default_factory = str )
119+ negative_prompt : str = field (default_factory = str )
120+ image_path : str = field (default_factory = str )
121+ audio_path : str = field (default_factory = str )
122+ audio_num : int = field (default_factory = int )
123+ with_mask : bool = field (default_factory = lambda : False )
124+ save_result_path : str = field (default_factory = str )
125+ return_result_tensor : bool = field (default_factory = lambda : False )
126+ stream_config : dict = field (default_factory = dict )
127+ # shape related
128+ resize_mode : str = field (default_factory = str )
129+ original_shape : list = field (default_factory = list )
130+ resized_shape : list = field (default_factory = list )
131+ latent_shape : list = field (default_factory = list )
132+ target_shape : list = field (default_factory = list )
133+
134+ # prev info
135+ overlap_frame : torch .Tensor = field (default_factory = lambda : None )
136+ overlap_latent : torch .Tensor = field (default_factory = lambda : None )
137+ # input preprocess audio
138+ audio_clip : torch .Tensor = field (default_factory = lambda : None )
136139 # input reference state
137- ref_state : int | Any = UNSET
140+ ref_state : int = field ( default_factory = int )
138141 # flags for first and last clip
139- is_first : bool | Any = UNSET
140- is_last : bool | Any = UNSET
142+ is_first : bool = field ( default_factory = lambda : False )
143+ is_last : bool = field ( default_factory = lambda : False )
141144
142145
143146# Need Check
@@ -280,15 +283,6 @@ class WorldPlayT2VInputInfo:
280283 action : torch .Tensor = field (default_factory = lambda : None )
281284
282285
283- def init_input_info_from_args (task , args , ** overrides ):
284- if task == "s2v" :
285- return S2VInputInfo .from_args (args , ** overrides )
286- elif task == "rs2v" :
287- return RS2VInputInfo .from_args (args , ** overrides )
288- else :
289- raise ValueError (f"Unsupported task: { task } " )
290-
291-
292286def init_empty_input_info (task ):
293287 if task == "t2v" :
294288 return T2VInputInfo ()
@@ -320,6 +314,68 @@ def init_empty_input_info(task):
320314 raise ValueError (f"Unsupported task: { task } " )
321315
322316
@dataclass
class SekoTalkInputs:
    """Input record whose fields all default to the ``UNSET`` sentinel.

    Build one with :meth:`from_args`, then call
    :meth:`normalize_unset_to_none` to turn any still-``UNSET`` field into
    ``None``.
    """

    infer_steps: int | Any = UNSET
    seed: int | Any = UNSET
    prompt: str | Any = UNSET
    prompt_enhanced: str | Any = UNSET
    negative_prompt: str | Any = UNSET
    image_path: str | Any = UNSET
    audio_path: str | Any = UNSET
    audio_num: int | Any = UNSET
    video_duration: float | Any = UNSET
    with_mask: bool | Any = UNSET
    # Declared once only; a second identical declaration used to follow
    # save_result_path and was redundant.
    return_result_tensor: bool | Any = UNSET
    save_result_path: str | Any = UNSET
    stream_config: dict | Any = UNSET

    resize_mode: str | Any = UNSET
    target_shape: list | Any = UNSET

    # prev info
    overlap_frame: torch.Tensor | Any = UNSET
    overlap_latent: torch.Tensor | Any = UNSET
    # input preprocess audio
    audio_clip: torch.Tensor | Any = UNSET

    # input reference state
    ref_state: int | Any = UNSET
    # flags for first and last clip
    is_first: bool | Any = UNSET
    is_last: bool | Any = UNSET

    @classmethod
    def from_args(cls, args, **overrides):
        """Build an instance from ``args`` plus keyword overrides.

        Args:
            args: An ``argparse.Namespace`` or any object with a ``__dict__``.
                Only attributes whose names match a dataclass field are used.
            **overrides: Field values that take priority over ``args``. They
                are passed to the constructor unfiltered, so an unknown name
                raises ``TypeError``.

        Returns:
            A new instance of ``cls``.
        """
        known = {f.name for f in fields(cls)}
        data = {name: value for name, value in vars(args).items() if name in known}
        data.update(overrides)  # priority: args < overrides
        return cls(**data)

    def normalize_unset_to_none(self):
        """Replace every field still equal to ``UNSET`` with ``None``, in place.

        Call this right before running / inference.

        Returns:
            ``self``, so the call can be chained.
        """
        for f in fields(self):
            if getattr(self, f.name) is UNSET:
                setattr(self, f.name, None)
        return self
370+
371+
def init_input_info_from_args(task, args, **overrides):
    """Build the input record for ``task`` from ``args`` and keyword overrides.

    Both ``"s2v"`` and ``"rs2v"`` produce a ``SekoTalkInputs`` via
    ``SekoTalkInputs.from_args``; any other task raises ``ValueError``.
    """
    if task not in ("s2v", "rs2v"):
        raise ValueError(f"Unsupported task: { task } ")
    return SekoTalkInputs.from_args(args, **overrides)
377+
378+
323379def fill_input_info_from_defaults (input_info , defaults ):
324380 for key in input_info .__dataclass_fields__ :
325381 if key in defaults and getattr (input_info , key ) is UNSET :
0 commit comments