Skip to content

Commit 782d353

Browse files
committed
Add SekoTalkInputs
1 parent 709d6a0 commit 782d353

1 file changed

Lines changed: 110 additions & 54 deletions

File tree

lightx2v/utils/input_info.py

Lines changed: 110 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -86,58 +86,61 @@ class VaceInputInfo:
8686

8787
@dataclass
8888
class S2VInputInfo:
89-
infer_steps: int | Any = UNSET
90-
seed: int | Any = UNSET
91-
prompt: str | Any = UNSET
92-
prompt_enhanced: str | Any = UNSET
93-
negative_prompt: str | Any = UNSET
94-
image_path: str | Any = UNSET
95-
audio_path: str | Any = UNSET
96-
audio_num: int | Any = UNSET
97-
video_duration: float | Any = UNSET
98-
with_mask: bool | Any = UNSET
99-
return_result_tensor: bool | Any = UNSET
100-
save_result_path: str | Any = UNSET
101-
return_result_tensor: bool | Any = UNSET
102-
stream_config: dict | Any = UNSET
103-
resize_mode: str | Any = UNSET
104-
target_shape: list | Any = UNSET
89+
seed: int = field(default_factory=int)
90+
prompt: str = field(default_factory=str)
91+
prompt_enhanced: str = field(default_factory=str)
92+
negative_prompt: str = field(default_factory=str)
93+
image_path: str = field(default_factory=str)
94+
audio_path: str = field(default_factory=str)
95+
audio_num: int = field(default_factory=int)
96+
with_mask: bool = field(default_factory=lambda: False)
97+
save_result_path: str = field(default_factory=str)
98+
return_result_tensor: bool = field(default_factory=lambda: False)
99+
stream_config: dict = field(default_factory=dict)
100+
# shape related
101+
resize_mode: str = field(default_factory=str)
102+
original_shape: list = field(default_factory=list)
103+
resized_shape: list = field(default_factory=list)
104+
latent_shape: list = field(default_factory=list)
105+
target_shape: list = field(default_factory=list)
106+
105107
# prev info
106-
overlap_frame: torch.Tensor | Any = UNSET
107-
overlap_latent: torch.Tensor | Any = UNSET
108+
overlap_frame: torch.Tensor = field(default_factory=lambda: None)
109+
overlap_latent: torch.Tensor = field(default_factory=lambda: None)
108110
# input preprocess audio
109-
audio_clip: torch.Tensor | Any = UNSET
110-
111-
@classmethod
112-
def from_args(cls, args, **overrides):
113-
"""
114-
Build InputInfo from argparse.Namespace (or any object with __dict__)
115-
Priority:
116-
args < overrides
117-
"""
118-
field_names = {f.name for f in fields(cls)}
119-
data = {k: v for k, v in vars(args).items() if k in field_names}
120-
data.update(overrides)
121-
return cls(**data)
122-
123-
def normalize_unset_to_none(self):
124-
"""
125-
Replace all UNSET fields with None.
126-
Call this right before running / inference.
127-
"""
128-
for f in fields(self):
129-
if getattr(self, f.name) is UNSET:
130-
setattr(self, f.name, None)
131-
return self
111+
audio_clip: torch.Tensor = field(default_factory=lambda: None)
132112

133113

134114
@dataclass
135-
class RS2VInputInfo(S2VInputInfo):
115+
class RS2VInputInfo:
116+
seed: int = field(default_factory=int)
117+
prompt: str = field(default_factory=str)
118+
prompt_enhanced: str = field(default_factory=str)
119+
negative_prompt: str = field(default_factory=str)
120+
image_path: str = field(default_factory=str)
121+
audio_path: str = field(default_factory=str)
122+
audio_num: int = field(default_factory=int)
123+
with_mask: bool = field(default_factory=lambda: False)
124+
save_result_path: str = field(default_factory=str)
125+
return_result_tensor: bool = field(default_factory=lambda: False)
126+
stream_config: dict = field(default_factory=dict)
127+
# shape related
128+
resize_mode: str = field(default_factory=str)
129+
original_shape: list = field(default_factory=list)
130+
resized_shape: list = field(default_factory=list)
131+
latent_shape: list = field(default_factory=list)
132+
target_shape: list = field(default_factory=list)
133+
134+
# prev info
135+
overlap_frame: torch.Tensor = field(default_factory=lambda: None)
136+
overlap_latent: torch.Tensor = field(default_factory=lambda: None)
137+
# input preprocess audio
138+
audio_clip: torch.Tensor = field(default_factory=lambda: None)
136139
# input reference state
137-
ref_state: int | Any = UNSET
140+
ref_state: int = field(default_factory=int)
138141
# flags for first and last clip
139-
is_first: bool | Any = UNSET
140-
is_last: bool | Any = UNSET
142+
is_first: bool = field(default_factory=lambda: False)
143+
is_last: bool = field(default_factory=lambda: False)
141144

142145

143146
# Need Check
@@ -280,15 +283,6 @@ class WorldPlayT2VInputInfo:
280283
action: torch.Tensor = field(default_factory=lambda: None)
281284

282285

283-
def init_input_info_from_args(task, args, **overrides):
284-
if task == "s2v":
285-
return S2VInputInfo.from_args(args, **overrides)
286-
elif task == "rs2v":
287-
return RS2VInputInfo.from_args(args, **overrides)
288-
else:
289-
raise ValueError(f"Unsupported task: {task}")
290-
291-
292286
def init_empty_input_info(task):
293287
if task == "t2v":
294288
return T2VInputInfo()
@@ -320,6 +314,68 @@ def init_empty_input_info(task):
320314
raise ValueError(f"Unsupported task: {task}")
321315

322316

317+
@dataclass
class SekoTalkInputs:
    """Input container for SekoTalk (s2v / rs2v) inference.

    All fields default to the module-level ``UNSET`` sentinel so that
    "not provided" can be distinguished from an explicit falsy value;
    call :meth:`normalize_unset_to_none` right before inference to map
    the remaining sentinels to ``None``.
    """

    infer_steps: int | Any = UNSET
    seed: int | Any = UNSET
    prompt: str | Any = UNSET
    prompt_enhanced: str | Any = UNSET
    negative_prompt: str | Any = UNSET
    image_path: str | Any = UNSET
    audio_path: str | Any = UNSET
    audio_num: int | Any = UNSET
    video_duration: float | Any = UNSET
    with_mask: bool | Any = UNSET
    # BUG FIX: return_result_tensor was declared twice; the duplicate
    # declaration has been removed (the second silently shadowed the first).
    return_result_tensor: bool | Any = UNSET
    save_result_path: str | Any = UNSET
    stream_config: dict | Any = UNSET

    # shape related
    resize_mode: str | Any = UNSET
    target_shape: list | Any = UNSET

    # prev info
    overlap_frame: torch.Tensor | Any = UNSET
    overlap_latent: torch.Tensor | Any = UNSET
    # input preprocess audio
    audio_clip: torch.Tensor | Any = UNSET

    # input reference state
    ref_state: int | Any = UNSET
    # flags for first and last clip
    is_first: bool | Any = UNSET
    is_last: bool | Any = UNSET

    @classmethod
    def from_args(cls, args, **overrides):
        """
        Build InputInfo from argparse.Namespace (or any object with __dict__)
        Priority:
            args < overrides
        """
        field_names = {f.name for f in fields(cls)}
        # keep only attributes that correspond to declared fields
        data = {k: v for k, v in vars(args).items() if k in field_names}
        data.update(overrides)
        return cls(**data)

    def normalize_unset_to_none(self):
        """
        Replace all UNSET fields with None.
        Call this right before running / inference.
        """
        for f in fields(self):
            if getattr(self, f.name) is UNSET:
                setattr(self, f.name, None)
        return self
370+
371+
372+
def init_input_info_from_args(task, args, **overrides):
    """Build the input-info dataclass for *task* from argparse args.

    Both speech-to-video task flavors ("s2v" and "rs2v") share the
    SekoTalkInputs container; any other task name is rejected.
    """
    if task not in ("s2v", "rs2v"):
        raise ValueError(f"Unsupported task: {task}")
    return SekoTalkInputs.from_args(args, **overrides)
377+
378+
323379
def fill_input_info_from_defaults(input_info, defaults):
324380
for key in input_info.__dataclass_fields__:
325381
if key in defaults and getattr(input_info, key) is UNSET:

0 commit comments

Comments
 (0)