diff --git a/flux_dfloat11.py b/flux_dfloat11.py new file mode 100644 index 0000000..e608e29 --- /dev/null +++ b/flux_dfloat11.py @@ -0,0 +1,87 @@ +import os +import torch +from diffusers import FluxPipeline, FluxControlNetPipeline, ControlNetModel +from dfloat11 import DFloat11Model +from argparse import ArgumentParser + + +def is_dfloat11_available() -> bool: + try: + import dfloat11 # noqa: F401 + return True + except Exception: + return False + +parser = ArgumentParser( + description="Generate an image using FLUX with DFloat11 weights" +) +parser.add_argument( + "--prompt", + type=str, + default="A futuristic cityscape at sunset, with flying cars, neon lights, and reflective water canals", +) +parser.add_argument("--save_path", type=str, default="image.png") +parser.add_argument( + "--controlnet", action="store_true", help="Enable line controlnet LoRA" +) +parser.add_argument("--seed", type=int, default=0, help="Random seed") +parser.add_argument( + "--steps", + type=int, + default=50, + help="Number of inference steps", +) + + +def main() -> None: + args = parser.parse_args() + + pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16 + ) + pipe.enable_model_cpu_offload() + + # Load DFloat11 weights for the text transformer + model_path = os.getenv("DF11_MODEL_PATH", "DFloat11/FLUX.1-dev-DF11") + device = os.getenv("DF11_DEVICE") or ("cuda" if torch.cuda.is_available() else "cpu") + DFloat11Model.from_pretrained( + model_path, + device=device, + bfloat16_model=pipe.transformer, + ) + + if args.controlnet: + try: + controlnet = ControlNetModel.from_pretrained( + "black-forest-labs/flux-controlnet-canny", + torch_dtype=torch.bfloat16, + ) + cpipe = FluxControlNetPipeline(controlnet=controlnet, **pipe.components) + cpipe.enable_model_cpu_offload() + try: + lora_path = os.getenv( + "CONTROLNET_LORA", "black-forest-labs/flux-controlnet-line-lora" + ) + cpipe.load_lora_weights(lora_path, adapter_name="line") + cpipe.set_adapters(["line"], adapter_weights=[1.0]) + except Exception: + pass + pipe = cpipe + except Exception as e: + print(f"Failed to load ControlNet: {e}") + + image = pipe( + args.prompt, + width=1920, + height=1440, + guidance_scale=3.5, + num_inference_steps=args.steps, + max_sequence_length=512, + generator=torch.Generator(device=device).manual_seed(args.seed), + ).images[0] + + image.save(args.save_path) + + +if __name__ == "__main__": + main() diff --git a/main.py b/main.py index f072743..07e5a32 100644 --- a/main.py +++ b/main.py @@ -91,17 +91,30 @@ old_scheduler = pipe.scheduler pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config) -if os.path.exists("models/lcm-lora-sdxl"): - pipe.load_lora_weights("models/lcm-lora-sdxl", adapter_name="lcm") -else: - pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl", adapter_name="lcm") -pipe.set_adapters(["lcm"], adapter_weights=[1.0]) +if os.getenv("LOAD_LCM_LORA", "0") == "1": + if os.path.exists("models/lcm-lora-sdxl"): + pipe.load_lora_weights("models/lcm-lora-sdxl", adapter_name="lcm") + else: + pipe.load_lora_weights( + "latent-consistency/lcm-lora-sdxl", adapter_name="lcm" + ) + pipe.set_adapters(["lcm"], adapter_weights=[1.0]) # Load Flux Schnell pipeline for efficient text-to-image flux_pipe = FluxPipeline.from_pretrained( "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16 ) flux_pipe.enable_model_cpu_offload() +try: + from dfloat11 import DFloat11Model + dfloat_path = os.getenv("DF11_MODEL_PATH", "DFloat11/FLUX.1-schnell-DF11") + DFloat11Model.from_pretrained( + dfloat_path, + device="cpu", + bfloat16_model=flux_pipe.transformer, + ) +except Exception as e: + logger.error(f"Failed to load DFloat11 weights: {e}") try: flux_controlnet = ControlNetModel.from_pretrained( @@ -111,6 +124,14 @@ controlnet=flux_controlnet, **flux_pipe.components ) flux_controlnetpipe.enable_model_cpu_offload() + try: + lora_path = os.getenv( + "CONTROLNET_LORA", "black-forest-labs/flux-controlnet-line-lora" + ) + flux_controlnetpipe.load_lora_weights(lora_path, adapter_name="line") + flux_controlnetpipe.set_adapters(["line"], adapter_weights=[1.0]) + except Exception as e: + logger.error(f"Failed to load ControlNet LoRA: {e}") except Exception as e: logger.error(f"Failed to load Flux ControlNet: {e}") flux_controlnetpipe = None diff --git a/requirements.in b/requirements.in index 6437b66..747f2b9 100644 --- a/requirements.in +++ b/requirements.in @@ -75,4 +75,5 @@ peft gradio pillow-avif-plugin optimum-quanto -xformers \ No newline at end of file +xformers +dfloat11[cuda12] diff --git a/requirements.txt b/requirements.txt index 1ff99b8..bd0fc81 100644 --- a/requirements.txt +++ b/requirements.txt @@ -588,6 +588,7 @@ websockets==12.0 # gradio-client xformers==0.0.28.post1 # via -r requirements.in +dfloat11[cuda12]==0.2.0 yarl==1.18.0 # via aiohttp zipp==3.21.0 diff --git a/tests/unit/test_env_vars.py b/tests/unit/test_env_vars.py new file mode 100644 index 0000000..a77c061 --- /dev/null +++ b/tests/unit/test_env_vars.py @@ -0,0 +1,19 @@ +import os +import sys +import pytest +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))) + +try: + from flux_dfloat11 import parser +except Exception: + pytest.skip("flux_dfloat11 dependencies missing", allow_module_level=True) + + +def test_env_override(monkeypatch): + monkeypatch.setenv("DF11_MODEL_PATH", "test-model") + monkeypatch.setenv("DF11_DEVICE", "cpu") + args = parser.parse_args([]) + assert args.save_path == "image.png" + # ensure env var accessible + assert os.getenv("DF11_MODEL_PATH") == "test-model" + assert os.getenv("DF11_DEVICE") == "cpu" diff --git a/tests/unit/test_flux_dfloat11.py b/tests/unit/test_flux_dfloat11.py new file mode 100644 index 0000000..034cc8c --- /dev/null +++ b/tests/unit/test_flux_dfloat11.py @@ -0,0 +1,16 @@ +import os +import sys +import pytest +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))) +try: + from flux_dfloat11 import parser +except Exception: # pragma: no cover - optional dependency + pytest.skip("flux_dfloat11 dependencies missing", allow_module_level=True) + + +def test_default_args(): + args = parser.parse_args([]) + assert args.prompt + assert args.save_path == "image.png" + assert args.seed == 0 + assert args.steps == 50