From 8476f2225154db87807ff6953909825e23ca2f68 Mon Sep 17 00:00:00 2001 From: Lee Penkman Date: Tue, 8 Jul 2025 13:41:08 +1200 Subject: [PATCH 1/2] Add DFloat11 integration and example --- flux_dfloat11.py | 79 ++++++++++++++++++++++++++++++++ main.py | 18 ++++++++ readme.md | 9 ++++ requirements.in | 3 +- requirements.txt | 1 + tests/unit/test_env_vars.py | 17 +++++++ tests/unit/test_flux_dfloat11.py | 14 ++++++ 7 files changed, 140 insertions(+), 1 deletion(-) create mode 100644 flux_dfloat11.py create mode 100644 tests/unit/test_env_vars.py create mode 100644 tests/unit/test_flux_dfloat11.py diff --git a/flux_dfloat11.py b/flux_dfloat11.py new file mode 100644 index 0000000..3712051 --- /dev/null +++ b/flux_dfloat11.py @@ -0,0 +1,79 @@ +import os +import torch +from diffusers import FluxPipeline, FluxControlNetPipeline, ControlNetModel +from dfloat11 import DFloat11Model +from argparse import ArgumentParser + + +def is_dfloat11_available() -> bool: + try: + import dfloat11 # noqa: F401 + return True + except Exception: + return False + +parser = ArgumentParser( + description="Generate an image using FLUX with DFloat11 weights" +) +parser.add_argument( + "--prompt", + type=str, + default="A futuristic cityscape at sunset, with flying cars, neon lights, and reflective water canals", +) +parser.add_argument("--save_path", type=str, default="image.png") +parser.add_argument( + "--controlnet", action="store_true", help="Enable line controlnet LoRA" +) + + +def main() -> None: + args = parser.parse_args() + + pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16 + ) + pipe.enable_model_cpu_offload() + + # Load DFloat11 weights for the text transformer + model_path = os.getenv("DF11_MODEL_PATH", "DFloat11/FLUX.1-dev-DF11") + DFloat11Model.from_pretrained( + model_path, + device="cpu", + bfloat16_model=pipe.transformer, + ) + + if args.controlnet: + try: + controlnet = ControlNetModel.from_pretrained( + "black-forest-labs/flux-controlnet-canny", + torch_dtype=torch.bfloat16, + ) + cpipe = FluxControlNetPipeline(controlnet=controlnet, **pipe.components) + cpipe.enable_model_cpu_offload() + try: + lora_path = os.getenv( + "CONTROLNET_LORA", "black-forest-labs/flux-controlnet-line-lora" + ) + cpipe.load_lora_weights(lora_path, adapter_name="line") + cpipe.set_adapters(["line"], adapter_weights=[1.0]) + except Exception: + pass + pipe = cpipe + except Exception as e: + print(f"Failed to load ControlNet: {e}") + + image = pipe( + args.prompt, + width=1920, + height=1440, + guidance_scale=3.5, + num_inference_steps=50, + max_sequence_length=512, + generator=torch.Generator(device="cpu").manual_seed(0), + ).images[0] + + image.save(args.save_path) + + +if __name__ == "__main__": + main() diff --git a/main.py b/main.py index f072743..d7f48d0 100644 --- a/main.py +++ b/main.py @@ -102,6 +102,16 @@ "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16 ) flux_pipe.enable_model_cpu_offload() +try: + from dfloat11 import DFloat11Model + dfloat_path = os.getenv("DF11_MODEL_PATH", "DFloat11/FLUX.1-schnell-DF11") + DFloat11Model.from_pretrained( + dfloat_path, + device="cpu", + bfloat16_model=flux_pipe.transformer, + ) +except Exception as e: + logger.error(f"Failed to load DFloat11 weights: {e}") try: flux_controlnet = ControlNetModel.from_pretrained( @@ -111,6 +121,14 @@ controlnet=flux_controlnet, **flux_pipe.components ) flux_controlnetpipe.enable_model_cpu_offload() + try: + lora_path = os.getenv( + "CONTROLNET_LORA", "black-forest-labs/flux-controlnet-line-lora" + ) + flux_controlnetpipe.load_lora_weights(lora_path, adapter_name="line") + flux_controlnetpipe.set_adapters(["line"], adapter_weights=[1.0]) + except Exception as e: + logger.error(f"Failed to load ControlNet LoRA: {e}") except Exception as e: logger.error(f"Failed to load Flux ControlNet: {e}") flux_controlnetpipe = None diff --git a/readme.md b/readme.md index 150c61c..9c97995 100644 --- a/readme.md +++ b/readme.md @@ -64,6 +64,15 @@ python flux_schnell.py ``` This will generate `flux-schnell.png` using bf16 precision. +### DFloat11 Example +The repository includes an optional script for running the FLUX model using the +experimental DFloat11 weights which drastically reduce memory usage. After +installing the `dfloat11` package, generate an image with: +```bash +python flux_dfloat11.py --prompt "A futuristic city" --save_path myimage.png +``` +Use the `--controlnet` flag to enable the line ControlNet LoRA. + ![gradio demo](gradioimg.png) diff --git a/requirements.in b/requirements.in index 6437b66..747f2b9 100644 --- a/requirements.in +++ b/requirements.in @@ -75,4 +75,5 @@ peft gradio pillow-avif-plugin optimum-quanto -xformers \ No newline at end of file +xformers +dfloat11[cuda12] diff --git a/requirements.txt b/requirements.txt index 1ff99b8..bd0fc81 100644 --- a/requirements.txt +++ b/requirements.txt @@ -588,6 +588,7 @@ websockets==12.0 # gradio-client xformers==0.0.28.post1 # via -r requirements.in +dfloat11[cuda12]==0.2.0 yarl==1.18.0 # via aiohttp zipp==3.21.0 diff --git a/tests/unit/test_env_vars.py b/tests/unit/test_env_vars.py new file mode 100644 index 0000000..5a807b9 --- /dev/null +++ b/tests/unit/test_env_vars.py @@ -0,0 +1,17 @@ +import os +import sys +import pytest +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))) + +try: + from flux_dfloat11 import parser +except Exception: + pytest.skip("flux_dfloat11 dependencies missing", allow_module_level=True) + + +def test_env_override(monkeypatch): + monkeypatch.setenv("DF11_MODEL_PATH", "test-model") + args = parser.parse_args([]) + assert args.save_path == "image.png" + # ensure env var accessible + assert os.getenv("DF11_MODEL_PATH") == "test-model" diff --git a/tests/unit/test_flux_dfloat11.py b/tests/unit/test_flux_dfloat11.py new file mode 100644 index 0000000..9b599f0 --- /dev/null +++ b/tests/unit/test_flux_dfloat11.py @@ -0,0 +1,14 @@ +import os +import sys +import pytest +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))) +try: + from flux_dfloat11 import parser +except Exception: # pragma: no cover - optional dependency + pytest.skip("flux_dfloat11 dependencies missing", allow_module_level=True) + + +def test_default_args(): + args = parser.parse_args([]) + assert args.prompt + assert args.save_path == "image.png" From 95bba474d9e6ec66e596c63a299ebd9ffff003d3 Mon Sep 17 00:00:00 2001 From: Lee Penkman Date: Tue, 8 Jul 2025 14:43:43 +1200 Subject: [PATCH 2/2] Add options to df11 script --- flux_dfloat11.py | 14 +++++++++++--- main.py | 13 ++++++++----- readme.md | 9 --------- tests/unit/test_env_vars.py | 2 ++ tests/unit/test_flux_dfloat11.py | 2 ++ 5 files changed, 23 insertions(+), 17 deletions(-) diff --git a/flux_dfloat11.py b/flux_dfloat11.py index 3712051..e608e29 100644 --- a/flux_dfloat11.py +++ b/flux_dfloat11.py @@ -24,6 +24,13 @@ def is_dfloat11_available() -> bool: parser.add_argument( "--controlnet", action="store_true", help="Enable line controlnet LoRA" ) +parser.add_argument("--seed", type=int, default=0, help="Random seed") +parser.add_argument( + "--steps", + type=int, + default=50, + help="Number of inference steps", +) def main() -> None: @@ -36,9 +43,10 @@ def main() -> None: # Load DFloat11 weights for the text transformer model_path = os.getenv("DF11_MODEL_PATH", "DFloat11/FLUX.1-dev-DF11") + device = os.getenv("DF11_DEVICE") or ("cuda" if torch.cuda.is_available() else "cpu") DFloat11Model.from_pretrained( model_path, - device="cpu", + device=device, bfloat16_model=pipe.transformer, ) @@ -67,9 +75,9 @@ def main() -> None: width=1920, height=1440, guidance_scale=3.5, - num_inference_steps=50, + num_inference_steps=args.steps, max_sequence_length=512, - generator=torch.Generator(device="cpu").manual_seed(0), + generator=torch.Generator(device=device).manual_seed(args.seed), ).images[0] image.save(args.save_path) diff --git a/main.py b/main.py index d7f48d0..07e5a32 100644 --- a/main.py +++ b/main.py @@ -91,11 +91,14 @@ old_scheduler = pipe.scheduler pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config) -if os.path.exists("models/lcm-lora-sdxl"): - pipe.load_lora_weights("models/lcm-lora-sdxl", adapter_name="lcm") -else: - pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl", adapter_name="lcm") -pipe.set_adapters(["lcm"], adapter_weights=[1.0]) +if os.getenv("LOAD_LCM_LORA", "0") == "1": + if os.path.exists("models/lcm-lora-sdxl"): + pipe.load_lora_weights("models/lcm-lora-sdxl", adapter_name="lcm") + else: + pipe.load_lora_weights( + "latent-consistency/lcm-lora-sdxl", adapter_name="lcm" + ) + pipe.set_adapters(["lcm"], adapter_weights=[1.0]) # Load Flux Schnell pipeline for efficient text-to-image flux_pipe = FluxPipeline.from_pretrained( diff --git a/readme.md b/readme.md index 9c97995..150c61c 100644 --- a/readme.md +++ b/readme.md @@ -64,15 +64,6 @@ python flux_schnell.py ``` This will generate `flux-schnell.png` using bf16 precision. -### DFloat11 Example -The repository includes an optional script for running the FLUX model using the -experimental DFloat11 weights which drastically reduce memory usage. After -installing the `dfloat11` package, generate an image with: -```bash -python flux_dfloat11.py --prompt "A futuristic city" --save_path myimage.png -``` -Use the `--controlnet` flag to enable the line ControlNet LoRA. - ![gradio demo](gradioimg.png) diff --git a/tests/unit/test_env_vars.py b/tests/unit/test_env_vars.py index 5a807b9..a77c061 100644 --- a/tests/unit/test_env_vars.py +++ b/tests/unit/test_env_vars.py @@ -11,7 +11,9 @@ def test_env_override(monkeypatch): monkeypatch.setenv("DF11_MODEL_PATH", "test-model") + monkeypatch.setenv("DF11_DEVICE", "cpu") args = parser.parse_args([]) assert args.save_path == "image.png" # ensure env var accessible assert os.getenv("DF11_MODEL_PATH") == "test-model" + assert os.getenv("DF11_DEVICE") == "cpu" diff --git a/tests/unit/test_flux_dfloat11.py b/tests/unit/test_flux_dfloat11.py index 9b599f0..034cc8c 100644 --- a/tests/unit/test_flux_dfloat11.py +++ b/tests/unit/test_flux_dfloat11.py @@ -12,3 +12,5 @@ def test_default_args(): args = parser.parse_args([]) assert args.prompt assert args.save_path == "image.png" + assert args.seed == 0 + assert args.steps == 50