Skip to content

Commit a995884

Browse files
committed
debug added
1 parent e32fa7f commit a995884

File tree

1 file changed

+2
-14
lines changed

1 file changed

+2
-14
lines changed

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1359,24 +1359,12 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
13591359
return LTX2PipelineOutput(frames=video, audio=audio)
13601360

13611361

1362-
@partial(
1363-
jax.jit,
1364-
static_argnames=(
1365-
"do_classifier_free_guidance",
1366-
"guidance_scale",
1367-
"latent_num_frames",
1368-
"latent_height",
1369-
"latent_width",
1370-
"audio_num_frames",
1371-
"fps",
1372-
),
1373-
)
13741362

13751363
def print_shardings(pytree, prefix=""):
13761364
flat_tree, _ = jax.tree_util.tree_flatten(pytree)
13771365
for i, leaf in enumerate(flat_tree):
13781366
if hasattr(leaf, 'sharding'):
1379-
print(f"{prefix}leaf_{i} sharding: {leaf.sharding}")
1367+
print(f"{prefix}leaf_{i} sharding: {{leaf.sharding}}")
13801368
else:
13811369
print(f"{prefix}leaf_{i} has no sharding attribute")
13821370

@@ -1390,7 +1378,7 @@ def print_shardings(pytree, prefix=""):
13901378
"latent_width",
13911379
"audio_num_frames",
13921380
"fps",
1393-
),
1381+
),
13941382
)
13951383
def transformer_forward_pass(
13961384
graphdef,

0 commit comments

Comments
 (0)