Skip to content

Commit c7f0f30

Browse files
committed
Added audio_format param to make methods in export_utils generic
1 parent 269b107 commit c7f0f30

3 files changed

Lines changed: 29 additions & 7 deletions

File tree

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,4 @@ qwix_module_path: ".*"
9191
jit_initializers: True
9292
enable_single_replica_ckpt_restoring: False
9393
seed: 0
94+
audio_format: "s16"

src/maxdiffusion/generate_ltx2.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,15 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
164164
video_path = f"{filename_prefix}ltx2_output_{getattr(config, 'seed', 0)}_{i}.mp4"
165165
audio_i = audios[i] if audios is not None else None
166166

167+
audio_format = getattr(config, "audio_format", "s16")
168+
167169
export_to_video_with_audio(
168-
video=videos[i], fps=fps, audio=audio_i, audio_sample_rate=audio_sample_rate, output_path=video_path
170+
video=videos[i],
171+
fps=fps,
172+
audio=audio_i,
173+
audio_sample_rate=audio_sample_rate,
174+
output_path=video_path,
175+
audio_format=audio_format,
169176
)
170177

171178
saved_video_path.append(video_path)

src/maxdiffusion/utils/export_utils.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,7 @@ def _write_audio(
271271
audio_stream,
272272
samples: Any,
273273
audio_sample_rate: int,
274+
target_format: str = "s16",
274275
) -> None:
275276
import numpy as np
276277

@@ -286,14 +287,27 @@ def _write_audio(
286287
if samples.shape[1] != 2:
287288
raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.")
288289

289-
if samples.dtype != np.int16:
290-
samples = np.clip(samples, -1.0, 1.0)
291-
samples = (samples * 32767.0).astype(np.int16)
290+
if target_format == "s16":
291+
if samples.dtype != np.int16:
292+
samples = np.clip(samples, -1.0, 1.0)
293+
samples = (samples * 32767.0).astype(np.int16)
294+
elif target_format == "s32":
295+
if samples.dtype != np.int32:
296+
samples = np.clip(samples, -1.0, 1.0)
297+
samples = (samples * 2147483647.0).astype(np.int32)
298+
elif target_format in ["flt", "dbl", "fltp", "dblp"]:
299+
target_dtype = np.float32 if "flt" in target_format else np.float64
300+
if samples.dtype != target_dtype:
301+
samples = samples.astype(target_dtype)
302+
else:
303+
# Only the formats handled above have a conversion rule; reject any other format explicitly.
304+
raise ValueError(f"Unsupported target_format for converting numpy array: {target_format}")
305+
292306
samples_np = np.ascontiguousarray(samples).reshape(1, -1)
293307

294308
frame_in = av.AudioFrame.from_ndarray(
295309
samples_np,
296-
format="s16",
310+
format=target_format,
297311
layout="stereo",
298312
)
299313
frame_in.sample_rate = audio_sample_rate
@@ -302,7 +316,7 @@ def _write_audio(
302316

303317

304318
def export_to_video_with_audio(
305-
video: Any, fps: int, audio: Optional[Any], audio_sample_rate: Optional[int], output_path: str
319+
video: Any, fps: int, audio: Optional[Any], audio_sample_rate: Optional[int], output_path: str, audio_format: str = "s16"
306320
) -> None:
307321
"""
308322
Encodes video (and optionally audio) to a file using PyAV.
@@ -351,6 +365,6 @@ def export_to_video_with_audio(
351365
container.mux(packet)
352366

353367
if audio is not None:
354-
_write_audio(container, audio_stream, audio, audio_sample_rate)
368+
_write_audio(container, audio_stream, audio, audio_sample_rate, target_format=audio_format)
355369

356370
container.close()

0 commit comments

Comments
 (0)