Skip to content

Commit e32fa7f

Browse files
committed
debug added
1 parent a0f171d commit e32fa7f

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from functools import partial
1919

2020
import numpy as np
21+
import functools
2122
import torch
2223
import jax
2324
import jax.numpy as jnp
@@ -1231,6 +1232,9 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
12311232
)
12321233

12331234
for i, t in enumerate(timesteps):
1235+
print(f"\n--- DEBUG SHARDING Step {i} Time {t} ---")
1236+
print_shardings(state, prefix="transformer_state.")
1237+
12341238
noise_pred, noise_pred_audio = transformer_forward_pass(
12351239
graphdef,
12361240
state,
@@ -1367,6 +1371,27 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
13671371
"fps",
13681372
),
13691373
)
1374+
1375+
def print_shardings(pytree, prefix=""):
1376+
flat_tree, _ = jax.tree_util.tree_flatten(pytree)
1377+
for i, leaf in enumerate(flat_tree):
1378+
if hasattr(leaf, 'sharding'):
1379+
print(f"{prefix}leaf_{i} sharding: {leaf.sharding}")
1380+
else:
1381+
print(f"{prefix}leaf_{i} has no sharding attribute")
1382+
1383+
@partial(
1384+
jax.jit,
1385+
static_argnames=(
1386+
"do_classifier_free_guidance",
1387+
"guidance_scale",
1388+
"latent_num_frames",
1389+
"latent_height",
1390+
"latent_width",
1391+
"audio_num_frames",
1392+
"fps",
1393+
),
1394+
)
13701395
def transformer_forward_pass(
13711396
graphdef,
13721397
state,

0 commit comments

Comments
 (0)