Commit 3c9e649

Merge branch 'main' into ltx2-attention

2 parents: 14e8490 + 095502a
14 files changed: 34 additions & 41 deletions

src/maxdiffusion/generate.py

Lines changed: 3 additions & 1 deletion

@@ -26,6 +26,7 @@
 from absl import app
 from maxdiffusion import (pyconfig, FlaxDDIMScheduler, max_utils)

+from maxdiffusion.train_utils import transformer_engine_context
 from maxdiffusion.maxdiffusion_utils import rescale_noise_cfg
 from flax.linen import partitioning as nn_partitioning
 from maxdiffusion.image_processor import VaeImageProcessor
@@ -261,4 +262,5 @@ def main(argv: Sequence[str]) -> None:


 if __name__ == "__main__":
-  app.run(main)
+  with transformer_engine_context():
+    app.run(main)

src/maxdiffusion/generate_flux.py

Lines changed: 3 additions & 1 deletion

@@ -33,6 +33,7 @@

 from maxdiffusion import FlaxAutoencoderKL, pyconfig, max_logging
 from maxdiffusion.models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
+from maxdiffusion.train_utils import transformer_engine_context
 from maxdiffusion.max_utils import (
     device_put_replicated,
     get_memory_allocations,
@@ -492,4 +493,5 @@ def main(argv: Sequence[str]) -> None:


 if __name__ == "__main__":
-  app.run(main)
+  with transformer_engine_context():
+    app.run(main)

src/maxdiffusion/generate_flux_pipeline.py

Lines changed: 3 additions & 1 deletion

@@ -26,6 +26,7 @@
 from maxdiffusion import pyconfig, max_logging, max_utils

 from maxdiffusion.checkpointing.checkpointing_utils import load_params_from_path
+from maxdiffusion.train_utils import transformer_engine_context
 from maxdiffusion.max_utils import setup_initial_state


@@ -123,4 +124,5 @@ def main(argv: Sequence[str]) -> None:


 if __name__ == "__main__":
-  app.run(main)
+  with transformer_engine_context():
+    app.run(main)

src/maxdiffusion/generate_ltx_video.py

Lines changed: 3 additions & 1 deletion

@@ -21,6 +21,7 @@
 from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline, ConditioningItem
 import maxdiffusion.pipelines.ltx_video.crf_compressor as crf_compressor
 from maxdiffusion import pyconfig, max_logging
+from maxdiffusion.train_utils import transformer_engine_context
 import torchvision.transforms.functional as TVF
 import imageio
 from datetime import datetime
@@ -267,4 +268,5 @@ def main(argv: Sequence[str]) -> None:


 if __name__ == "__main__":
-  app.run(main)
+  with transformer_engine_context():
+    app.run(main)

src/maxdiffusion/generate_sdxl.py

Lines changed: 3 additions & 1 deletion

@@ -29,6 +29,7 @@

 from maxdiffusion import pyconfig, max_utils
 from maxdiffusion.image_processor import VaeImageProcessor
+from maxdiffusion.train_utils import transformer_engine_context
 from maxdiffusion.maxdiffusion_utils import (
     get_add_time_ids,
     rescale_noise_cfg,
@@ -322,4 +323,5 @@ def main(argv: Sequence[str]) -> None:


 if __name__ == "__main__":
-  app.run(main)
+  with transformer_engine_context():
+    app.run(main)

src/maxdiffusion/generate_wan.py

Lines changed: 3 additions & 1 deletion

@@ -23,6 +23,7 @@
 from maxdiffusion.checkpointing.wan_checkpointer_i2v_2p2 import WanCheckpointerI2V_2_2
 from maxdiffusion import pyconfig, max_logging, max_utils
 from absl import app
+from maxdiffusion.train_utils import transformer_engine_context
 from maxdiffusion.utils import export_to_video
 from maxdiffusion.utils.loading_utils import load_image
 from google.cloud import storage
@@ -296,4 +297,5 @@ def main(argv: Sequence[str]) -> None:


 if __name__ == "__main__":
-  app.run(main)
+  with transformer_engine_context():
+    app.run(main)

src/maxdiffusion/models/attention_flax.py

Lines changed: 1 addition & 1 deletion

@@ -1011,7 +1011,7 @@ def __init__(
      ),
    )

-    self.drop_out = nnx.Dropout(dropout)
+    self.drop_out = nnx.Dropout(dropout, deterministic=False)

    self.norm_q = nnx.data(None)
    self.norm_k = nnx.data(None)

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 1 addition & 1 deletion

@@ -237,7 +237,7 @@ def __init__(
    else:
      raise NotImplementedError(f"{activation_fn} is not implemented.")

-    self.drop_out = nnx.Dropout(dropout)
+    self.drop_out = nnx.Dropout(dropout, deterministic=False)
    self.proj_out = nnx.Linear(
        rngs=rngs,
        in_features=inner_dim,
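
Both model files make the same one-line change: the block's nnx.Dropout is now constructed with an explicit deterministic=False, so dropout masks are actually sampled unless a caller overrides the flag. A minimal standalone sketch of that behaviour, assuming a recent flax.nnx API; the 0.1 rate, the toy input, and the Rngs wiring below are illustrative, not taken from the repo:

import jax.numpy as jnp
from flax import nnx

# Same construction pattern as the diff, plus an explicit rngs stream so this
# standalone snippet can sample a dropout mask.
drop = nnx.Dropout(0.1, deterministic=False, rngs=nnx.Rngs(dropout=0))
x = jnp.ones((2, 8))

y_train = drop(x)                     # stochastic: elements zeroed at rate 0.1
y_eval = drop(x, deterministic=True)  # per-call override disables dropout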

src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py

Lines changed: 1 addition & 9 deletions

@@ -17,7 +17,6 @@
 from typing import List, Union, Optional
 from ...pyconfig import HyperParameters
 from functools import partial
-from contextlib import nullcontext
 from flax import nnx
 from flax.linen import partitioning as nn_partitioning
 import jax
@@ -116,15 +115,8 @@ def __call__(
        scheduler=self.scheduler,
        scheduler_state=scheduler_state,
    )
-    # Set the TE shard_guard context_manager if using TE cudnn_flash attention
-    if self.config.attention == "cudnn_flash_te":
-      from transformer_engine.jax.sharding import global_shard_guard, MeshResource  # pytype: disable=import-error

-      shard_guard = global_shard_guard(MeshResource(cp_resource="context"))
-    else:
-      shard_guard = nullcontext()
-
-    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules), shard_guard:
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
      latents = p_run_inference(
          graphdef=graphdef,
          sharded_state=state,

src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py

Lines changed: 1 addition & 9 deletions

@@ -17,7 +17,6 @@
 from typing import List, Union, Optional
 from ...pyconfig import HyperParameters
 from functools import partial
-from contextlib import nullcontext
 from flax import nnx
 from flax.linen import partitioning as nn_partitioning
 import jax
@@ -140,15 +139,8 @@ def __call__(
        scheduler=self.scheduler,
        scheduler_state=scheduler_state,
    )
-    # Set the TE shard_guard context_manager if using TE cudnn_flash attention
-    if self.config.attention == "cudnn_flash_te":
-      from transformer_engine.jax.sharding import global_shard_guard, MeshResource  # pytype: disable=import-error

-      shard_guard = global_shard_guard(MeshResource(cp_resource="context"))
-    else:
-      shard_guard = nullcontext()
-
-    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules), shard_guard:
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
      latents = p_run_inference(
          low_noise_graphdef=low_noise_graphdef,
          low_noise_state=low_noise_state,
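
The last two pipeline files delete the per-call Transformer Engine shard-guard setup, while every generate_* entrypoint above now wraps app.run(main) in transformer_engine_context() from maxdiffusion.train_utils. That helper's definition is not part of this excerpt, so the following is only a hypothetical reconstruction based on the logic removed from the WAN pipelines:

from contextlib import contextmanager, nullcontext


@contextmanager
def transformer_engine_context():
  """Hypothetical sketch; the real helper lives in maxdiffusion.train_utils.

  Enters Transformer Engine's global shard guard, mapping TE's context-parallel
  resource onto the mesh axis named "context" (the same MeshResource the removed
  pipeline code used), and falls back to a no-op when transformer_engine is not
  installed.
  """
  try:
    from transformer_engine.jax.sharding import global_shard_guard, MeshResource  # pytype: disable=import-error

    guard = global_shard_guard(MeshResource(cp_resource="context"))
  except ImportError:
    guard = nullcontext()
  with guard:
    yield

Whichever condition the real helper checks (the removed code keyed off config.attention == "cudnn_flash_te"), hoisting the guard to the process entrypoint means one context now spans model construction and inference, instead of being re-entered inside each pipeline's __call__.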
