Skip to content

Commit 09f17ac

Browse files
committed
Merge remote-tracking branch 'origin/main' into prisha/ltx2_transformer
2 parents d954532 + cddbf6a commit 09f17ac

18 files changed

+952
-43
lines changed

src/maxdiffusion/generate.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@
2626
from absl import app
2727
from maxdiffusion import (pyconfig, FlaxDDIMScheduler, max_utils)
2828

29+
from maxdiffusion.train_utils import transformer_engine_context
2930
from maxdiffusion.maxdiffusion_utils import rescale_noise_cfg
3031
from flax.linen import partitioning as nn_partitioning
3132
from maxdiffusion.image_processor import VaeImageProcessor
@@ -261,4 +262,5 @@ def main(argv: Sequence[str]) -> None:
261262

262263

263264
if __name__ == "__main__":
264-
app.run(main)
265+
with transformer_engine_context():
266+
app.run(main)

src/maxdiffusion/generate_flux.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,7 @@
3333

3434
from maxdiffusion import FlaxAutoencoderKL, pyconfig, max_logging
3535
from maxdiffusion.models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
36+
from maxdiffusion.train_utils import transformer_engine_context
3637
from maxdiffusion.max_utils import (
3738
device_put_replicated,
3839
get_memory_allocations,
@@ -492,4 +493,5 @@ def main(argv: Sequence[str]) -> None:
492493

493494

494495
if __name__ == "__main__":
495-
app.run(main)
496+
with transformer_engine_context():
497+
app.run(main)

src/maxdiffusion/generate_flux_pipeline.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@
2626
from maxdiffusion import pyconfig, max_logging, max_utils
2727

2828
from maxdiffusion.checkpointing.checkpointing_utils import load_params_from_path
29+
from maxdiffusion.train_utils import transformer_engine_context
2930
from maxdiffusion.max_utils import setup_initial_state
3031

3132

@@ -123,4 +124,5 @@ def main(argv: Sequence[str]) -> None:
123124

124125

125126
if __name__ == "__main__":
126-
app.run(main)
127+
with transformer_engine_context():
128+
app.run(main)

src/maxdiffusion/generate_ltx_video.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121
from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline, ConditioningItem
2222
import maxdiffusion.pipelines.ltx_video.crf_compressor as crf_compressor
2323
from maxdiffusion import pyconfig, max_logging
24+
from maxdiffusion.train_utils import transformer_engine_context
2425
import torchvision.transforms.functional as TVF
2526
import imageio
2627
from datetime import datetime
@@ -267,4 +268,5 @@ def main(argv: Sequence[str]) -> None:
267268

268269

269270
if __name__ == "__main__":
270-
app.run(main)
271+
with transformer_engine_context():
272+
app.run(main)

src/maxdiffusion/generate_sdxl.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@
2929

3030
from maxdiffusion import pyconfig, max_utils
3131
from maxdiffusion.image_processor import VaeImageProcessor
32+
from maxdiffusion.train_utils import transformer_engine_context
3233
from maxdiffusion.maxdiffusion_utils import (
3334
get_add_time_ids,
3435
rescale_noise_cfg,
@@ -322,4 +323,5 @@ def main(argv: Sequence[str]) -> None:
322323

323324

324325
if __name__ == "__main__":
325-
app.run(main)
326+
with transformer_engine_context():
327+
app.run(main)

src/maxdiffusion/generate_wan.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323
from maxdiffusion.checkpointing.wan_checkpointer_i2v_2p2 import WanCheckpointerI2V_2_2
2424
from maxdiffusion import pyconfig, max_logging, max_utils
2525
from absl import app
26+
from maxdiffusion.train_utils import transformer_engine_context
2627
from maxdiffusion.utils import export_to_video
2728
from maxdiffusion.utils.loading_utils import load_image
2829
from google.cloud import storage
@@ -296,4 +297,5 @@ def main(argv: Sequence[str]) -> None:
296297

297298

298299
if __name__ == "__main__":
299-
app.run(main)
300+
with transformer_engine_context():
301+
app.run(main)

src/maxdiffusion/models/attention_flax.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1011,7 +1011,7 @@ def __init__(
10111011
),
10121012
)
10131013

1014-
self.drop_out = nnx.Dropout(dropout)
1014+
self.drop_out = nnx.Dropout(dropout, deterministic=False)
10151015

10161016
self.norm_q = nnx.data(None)
10171017
self.norm_k = nnx.data(None)
Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,15 @@
1+
"""
2+
Copyright 2026 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""

0 commit comments

Comments
 (0)