Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions flux_dfloat11.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import os
import torch
from diffusers import FluxPipeline, FluxControlNetPipeline, ControlNetModel
from dfloat11 import DFloat11Model
from argparse import ArgumentParser


def is_dfloat11_available() -> bool:
try:
import dfloat11 # noqa: F401
return True
except Exception:
return False

parser = ArgumentParser(
description="Generate an image using FLUX with DFloat11 weights"
)
parser.add_argument(
"--prompt",
type=str,
default="A futuristic cityscape at sunset, with flying cars, neon lights, and reflective water canals",
)
parser.add_argument("--save_path", type=str, default="image.png")
parser.add_argument(
"--controlnet", action="store_true", help="Enable line controlnet LoRA"
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--steps",
type=int,
default=50,
help="Number of inference steps",
)


def main() -> None:
args = parser.parse_args()

pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()

# Load DFloat11 weights for the text transformer
model_path = os.getenv("DF11_MODEL_PATH", "DFloat11/FLUX.1-dev-DF11")
device = os.getenv("DF11_DEVICE") or ("cuda" if torch.cuda.is_available() else "cpu")
DFloat11Model.from_pretrained(
model_path,
device=device,
bfloat16_model=pipe.transformer,
)

if args.controlnet:
try:
controlnet = ControlNetModel.from_pretrained(
"black-forest-labs/flux-controlnet-canny",
torch_dtype=torch.bfloat16,
)
cpipe = FluxControlNetPipeline(controlnet=controlnet, **pipe.components)
cpipe.enable_model_cpu_offload()
try:
lora_path = os.getenv(
"CONTROLNET_LORA", "black-forest-labs/flux-controlnet-line-lora"
)
cpipe.load_lora_weights(lora_path, adapter_name="line")
cpipe.set_adapters(["line"], adapter_weights=[1.0])
except Exception:
pass
pipe = cpipe
except Exception as e:
print(f"Failed to load ControlNet: {e}")

image = pipe(
args.prompt,
width=1920,
height=1440,
guidance_scale=3.5,
num_inference_steps=args.steps,
max_sequence_length=512,
generator=torch.Generator(device=device).manual_seed(args.seed),
).images[0]

image.save(args.save_path)


if __name__ == "__main__":
main()
31 changes: 26 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,17 +91,30 @@
old_scheduler = pipe.scheduler
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

if os.path.exists("models/lcm-lora-sdxl"):
pipe.load_lora_weights("models/lcm-lora-sdxl", adapter_name="lcm")
else:
pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl", adapter_name="lcm")
pipe.set_adapters(["lcm"], adapter_weights=[1.0])
if os.getenv("LOAD_LCM_LORA", "0") == "1":
if os.path.exists("models/lcm-lora-sdxl"):
pipe.load_lora_weights("models/lcm-lora-sdxl", adapter_name="lcm")
else:
pipe.load_lora_weights(
"latent-consistency/lcm-lora-sdxl", adapter_name="lcm"
)
pipe.set_adapters(["lcm"], adapter_weights=[1.0])

# Load Flux Schnell pipeline for efficient text-to-image
flux_pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
)
flux_pipe.enable_model_cpu_offload()
try:
from dfloat11 import DFloat11Model
dfloat_path = os.getenv("DF11_MODEL_PATH", "DFloat11/FLUX.1-schnell-DF11")
DFloat11Model.from_pretrained(
dfloat_path,
device="cpu",
bfloat16_model=flux_pipe.transformer,
)
except Exception as e:
logger.error(f"Failed to load DFloat11 weights: {e}")

try:
flux_controlnet = ControlNetModel.from_pretrained(
Expand All @@ -111,6 +124,14 @@
controlnet=flux_controlnet, **flux_pipe.components
)
flux_controlnetpipe.enable_model_cpu_offload()
try:
lora_path = os.getenv(
"CONTROLNET_LORA", "black-forest-labs/flux-controlnet-line-lora"
)
flux_controlnetpipe.load_lora_weights(lora_path, adapter_name="line")
flux_controlnetpipe.set_adapters(["line"], adapter_weights=[1.0])
except Exception as e:
logger.error(f"Failed to load ControlNet LoRA: {e}")
except Exception as e:
logger.error(f"Failed to load Flux ControlNet: {e}")
flux_controlnetpipe = None
Expand Down
3 changes: 2 additions & 1 deletion requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,5 @@ peft
gradio
pillow-avif-plugin
optimum-quanto
xformers
xformers
dfloat11[cuda12]
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,7 @@ websockets==12.0
# gradio-client
xformers==0.0.28.post1
# via -r requirements.in
dfloat11[cuda12]==0.2.0
yarl==1.18.0
# via aiohttp
zipp==3.21.0
Expand Down
19 changes: 19 additions & 0 deletions tests/unit/test_env_vars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import os
import sys
import pytest
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))

try:
from flux_dfloat11 import parser
except Exception:
pytest.skip("flux_dfloat11 dependencies missing", allow_module_level=True)


def test_env_override(monkeypatch):
monkeypatch.setenv("DF11_MODEL_PATH", "test-model")
monkeypatch.setenv("DF11_DEVICE", "cpu")
args = parser.parse_args([])
assert args.save_path == "image.png"
# ensure env var accessible
assert os.getenv("DF11_MODEL_PATH") == "test-model"
assert os.getenv("DF11_DEVICE") == "cpu"
16 changes: 16 additions & 0 deletions tests/unit/test_flux_dfloat11.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import os
import sys
import pytest
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
try:
from flux_dfloat11 import parser
except Exception: # pragma: no cover - optional dependency
pytest.skip("flux_dfloat11 dependencies missing", allow_module_level=True)


def test_default_args():
args = parser.parse_args([])
assert args.prompt
assert args.save_path == "image.png"
assert args.seed == 0
assert args.steps == 50
Loading