Skip to content

Commit 6f1a088

Browse files
committed
debug added
1 parent 9f8f5be commit 6f1a088

2 files changed

Lines changed: 7 additions & 5 deletions

File tree

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ logical_axis_rules: [
4747
['activation_length', 'context'],
4848
['activation_heads', 'tensor'],
4949
['mlp','tensor'],
50-
['embed', None],
50+
['embed', ['context', 'fsdp']],
5151
['heads', 'tensor'],
5252
['norm', 'tensor'],
5353
['conv_batch', ['data', 'context', 'fsdp']],

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1361,12 +1361,14 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
13611361

13621362

13631363
def print_shardings(pytree, prefix=""):
1364-
flat_tree, _ = jax.tree_util.tree_flatten(pytree)
1365-
for i, leaf in enumerate(flat_tree):
1364+
flat_tree, treedef = jax.tree_util.tree_flatten_with_path(pytree)
1365+
for i, (path, leaf) in enumerate(flat_tree):
1366+
path_str = jax.tree_util.keystr(path)
1367+
shape_str = f"shape: {leaf.shape}" if hasattr(leaf, 'shape') else "shape: N/A"
13661368
if hasattr(leaf, 'sharding'):
1367-
print(f"{prefix}leaf_{i} sharding: {leaf.sharding}")
1369+
print(f"{prefix}leaf_{i} {path_str} {shape_str} sharding: {leaf.sharding}")
13681370
else:
1369-
print(f"{prefix}leaf_{i} has no sharding attribute")
1371+
print(f"{prefix}leaf_{i} {path_str} {shape_str} has no sharding attribute")
13701372

13711373
@partial(
13721374
jax.jit,

0 commit comments

Comments
 (0)