Skip to content

Commit 68e0696

Browse files
Merge pull request #344 from cpersson-amd:main
PiperOrigin-RevId: 878133314
2 parents 20d650a + e15f3ce commit 68e0696

12 files changed

Lines changed: 32 additions & 39 deletions

src/maxdiffusion/generate.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626  from absl import app
2727  from maxdiffusion import (pyconfig, FlaxDDIMScheduler, max_utils)
2828  
  29+ from maxdiffusion.train_utils import transformer_engine_context
2930  from maxdiffusion.maxdiffusion_utils import rescale_noise_cfg
3031  from flax.linen import partitioning as nn_partitioning
3132  from maxdiffusion.image_processor import VaeImageProcessor
@@ -261,4 +262,5 @@ def main(argv: Sequence[str]) -> None:
261262  
262263  
263264  if __name__ == "__main__":
264-      app.run(main)
  265+   with transformer_engine_context():
  266+     app.run(main)

src/maxdiffusion/generate_flux.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333  
3434  from maxdiffusion import FlaxAutoencoderKL, pyconfig, max_logging
3535  from maxdiffusion.models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
  36+ from maxdiffusion.train_utils import transformer_engine_context
3637  from maxdiffusion.max_utils import (
3738      device_put_replicated,
3839      get_memory_allocations,
@@ -492,4 +493,5 @@ def main(argv: Sequence[str]) -> None:
492493  
493494  
494495  if __name__ == "__main__":
495-      app.run(main)
  496+   with transformer_engine_context():
  497+     app.run(main)

src/maxdiffusion/generate_flux_pipeline.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626  from maxdiffusion import pyconfig, max_logging, max_utils
2727  
2828  from maxdiffusion.checkpointing.checkpointing_utils import load_params_from_path
  29+ from maxdiffusion.train_utils import transformer_engine_context
2930  from maxdiffusion.max_utils import setup_initial_state
3031  
3132  
@@ -123,4 +124,5 @@ def main(argv: Sequence[str]) -> None:
123124  
124125  
125126  if __name__ == "__main__":
126-      app.run(main)
  127+   with transformer_engine_context():
  128+     app.run(main)

src/maxdiffusion/generate_ltx_video.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121  from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline, ConditioningItem
2222  import maxdiffusion.pipelines.ltx_video.crf_compressor as crf_compressor
2323  from maxdiffusion import pyconfig, max_logging
  24+ from maxdiffusion.train_utils import transformer_engine_context
2425  import torchvision.transforms.functional as TVF
2526  import imageio
2627  from datetime import datetime
@@ -267,4 +268,5 @@ def main(argv: Sequence[str]) -> None:
267268  
268269  
269270  if __name__ == "__main__":
270-      app.run(main)
  271+   with transformer_engine_context():
  272+     app.run(main)

src/maxdiffusion/generate_sdxl.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929  
3030  from maxdiffusion import pyconfig, max_utils
3131  from maxdiffusion.image_processor import VaeImageProcessor
  32+ from maxdiffusion.train_utils import transformer_engine_context
3233  from maxdiffusion.maxdiffusion_utils import (
3334      get_add_time_ids,
3435      rescale_noise_cfg,
@@ -322,4 +323,5 @@ def main(argv: Sequence[str]) -> None:
322323  
323324  
324325  if __name__ == "__main__":
325-      app.run(main)
  326+   with transformer_engine_context():
  327+     app.run(main)

src/maxdiffusion/generate_wan.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323  from maxdiffusion.checkpointing.wan_checkpointer_i2v_2p2 import WanCheckpointerI2V_2_2
2424  from maxdiffusion import pyconfig, max_logging, max_utils
2525  from absl import app
  26+ from maxdiffusion.train_utils import transformer_engine_context
2627  from maxdiffusion.utils import export_to_video
2728  from maxdiffusion.utils.loading_utils import load_image
2829  from google.cloud import storage
@@ -296,4 +297,5 @@ def main(argv: Sequence[str]) -> None:
296297  
297298  
298299  if __name__ == "__main__":
299-      app.run(main)
  300+   with transformer_engine_context():
  301+     app.run(main)

src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717  from typing import List, Union, Optional
1818  from ...pyconfig import HyperParameters
1919  from functools import partial
20-   from contextlib import nullcontext
2120  from flax import nnx
2221  from flax.linen import partitioning as nn_partitioning
2322  import jax
@@ -116,15 +115,8 @@ def __call__(
116115      scheduler=self.scheduler,
117116      scheduler_state=scheduler_state,
118117  )
119-    # Set the TE shard_guard context_manager if using TE cudnn_flash attention
120-    if self.config.attention == "cudnn_flash_te":
121-      from transformer_engine.jax.sharding import global_shard_guard, MeshResource  # pytype: disable=import-error
122118  
123-      shard_guard = global_shard_guard(MeshResource(cp_resource="context"))
124-    else:
125-      shard_guard = nullcontext()
126-  
127-    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules), shard_guard:
  119+  with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
128120      latents = p_run_inference(
129121          graphdef=graphdef,
130122          sharded_state=state,

src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717  from typing import List, Union, Optional
1818  from ...pyconfig import HyperParameters
1919  from functools import partial
20-   from contextlib import nullcontext
2120  from flax import nnx
2221  from flax.linen import partitioning as nn_partitioning
2322  import jax
@@ -140,15 +139,8 @@ def __call__(
140139      scheduler=self.scheduler,
141140      scheduler_state=scheduler_state,
142141  )
143-    # Set the TE shard_guard context_manager if using TE cudnn_flash attention
144-    if self.config.attention == "cudnn_flash_te":
145-      from transformer_engine.jax.sharding import global_shard_guard, MeshResource  # pytype: disable=import-error
146142  
147-      shard_guard = global_shard_guard(MeshResource(cp_resource="context"))
148-    else:
149-      shard_guard = nullcontext()
150-  
151-    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules), shard_guard:
  143+  with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
152144      latents = p_run_inference(
153145          low_noise_graphdef=low_noise_graphdef,
154146          low_noise_state=low_noise_state,

src/maxdiffusion/train.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222      max_logging,
2323      pyconfig,
2424  )
  25+ from maxdiffusion.train_utils import transformer_engine_context
2526  
2627  from maxdiffusion.train_utils import (
2728      validate_train_config,
@@ -43,4 +44,5 @@ def main(argv: Sequence[str]) -> None:
4344  
4445  
4546  if __name__ == "__main__":
46-      app.run(main)
  47+   with transformer_engine_context():
  48+     app.run(main)

src/maxdiffusion/train_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,11 +206,11 @@ def transformer_engine_context():
206206      from transformer_engine.jax.sharding import global_shard_guard, MeshResource
207207      # Inform TransformerEngine of MaxDiffusion's physical mesh resources.
208208      mesh_resource = MeshResource(
209-         dp_resource="data",
  209+     dp_resource=None,
210210        tp_resource="tensor",
211211        fsdp_resource="fsdp",
212212        pp_resource=None,
213-         cp_resource=None,
  213+     cp_resource="context",
214214    )
215215    with global_shard_guard(mesh_resource):
216216        yield

0 commit comments

Comments (0)