From 6b6b05df72a2aa78929eafe0f5bb8435e750f590 Mon Sep 17 00:00:00 2001 From: Christopher Maher Date: Mon, 23 Mar 2026 11:43:29 -0700 Subject: [PATCH 1/4] feat: upgrade to FLUX.2-klein-4B with FP8 quantization Replace FLUX.1-schnell (12B) with FLUX.2-klein-4B (4B params, Apache 2.0) for significantly improved image quality at lower VRAM usage. Key changes: - Switch from FluxPipeline to Flux2KleinPipeline (diffusers main) - Add FP8 quantization via optimum-quanto (~8GB VRAM vs ~13GB BF16) - Update guidance_scale from 0.0 to 1.0 (Klein uses light guidance) - Use enable_model_cpu_offload instead of sequential offloading - Add FLUX_QUANTIZE_FP8 env var (default true, set false to disable) - Reduce K8s memory request from 16Gi to 12Gi Signed-off-by: Christopher Maher --- deploy/k8s/flux-server/deployment.yaml | 4 ++- pyproject.toml | 3 +- requirements-flux.txt | 3 +- src/flux_server/model_loader.py | 42 +++++++++++++++----------- src/flux_server/server.py | 10 +++--- 5 files changed, 37 insertions(+), 25 deletions(-) diff --git a/deploy/k8s/flux-server/deployment.yaml b/deploy/k8s/flux-server/deployment.yaml index f23a903..43d3241 100644 --- a/deploy/k8s/flux-server/deployment.yaml +++ b/deploy/k8s/flux-server/deployment.yaml @@ -33,6 +33,8 @@ spec: value: "8081" - name: HF_HOME value: "/models/flux" + - name: FLUX_QUANTIZE_FP8 + value: "true" # Optional: HuggingFace token for gated models # - name: HF_TOKEN # valueFrom: @@ -42,7 +44,7 @@ spec: resources: requests: cpu: "2" - memory: "16Gi" + memory: "12Gi" nvidia.com/gpu: "1" limits: nvidia.com/gpu: "1" diff --git a/pyproject.toml b/pyproject.toml index 402c2eb..1a1f8d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,9 +18,10 @@ dependencies = [ [project.optional-dependencies] flux = [ "torch>=2.4.0", - "diffusers>=0.33.0", + "diffusers @ git+https://github.com/huggingface/diffusers.git", "transformers>=4.44.0", "accelerate>=0.33.0", + "optimum-quanto>=0.2.0", "sentencepiece>=0.2.0", "protobuf>=5.28.0", ] diff --git a/requirements-flux.txt b/requirements-flux.txt index a2865c1..92ec20d 100644 --- a/requirements-flux.txt +++ b/requirements-flux.txt @@ -1,7 +1,8 @@ torch>=2.4.0 -diffusers>=0.33.0 +diffusers @ git+https://github.com/huggingface/diffusers.git transformers>=4.44.0 accelerate>=0.33.0 +optimum-quanto>=0.2.0 sentencepiece>=0.2.0 protobuf>=5.28.0 fastapi>=0.115.0 diff --git a/src/flux_server/model_loader.py b/src/flux_server/model_loader.py index 302e3e3..9c31e9c 100644 --- a/src/flux_server/model_loader.py +++ b/src/flux_server/model_loader.py @@ -1,27 +1,29 @@ -"""Flux model loader with memory optimization for 16GB VRAM GPUs.""" +"""Flux model loader with FP8 quantization for 16GB VRAM GPUs.""" import logging import os from pathlib import Path import torch -from diffusers import FluxPipeline +from diffusers import Flux2KleinPipeline logger = logging.getLogger(__name__) -# Default model for FLUX.1-schnell (Apache 2.0) -DEFAULT_MODEL_ID = "black-forest-labs/FLUX.1-schnell" +# Default model: FLUX.2-klein-4B (Apache 2.0, 4B params) +DEFAULT_MODEL_ID = "black-forest-labs/FLUX.2-klein-4B" def load_flux_pipeline( model_id: str = DEFAULT_MODEL_ID, cache_dir: str | None = None, device: str = "cuda", -) -> FluxPipeline: - """Load Flux pipeline with memory optimization for 16GB VRAM. + quantize_fp8: bool = True, +) -> Flux2KleinPipeline: + """Load FLUX.2-klein pipeline with optional FP8 quantization. - Uses BF16 precision and CPU offloading to fit within a single - 16GB VRAM GPU. + At BF16 the model uses ~13GB VRAM. With FP8 quantization on + the transformer, VRAM drops to ~8GB — leaving headroom on + 16GB GPUs for larger resolutions or batch work. """ if cache_dir is None: cache_dir = os.environ.get("HF_HOME", "/models/flux") @@ -31,31 +33,37 @@ def load_flux_pipeline( logger.info("Loading Flux pipeline: %s (cache: %s)", model_id, cache_dir) - pipe = FluxPipeline.from_pretrained( + pipe = Flux2KleinPipeline.from_pretrained( model_id, torch_dtype=torch.bfloat16, cache_dir=cache_dir, ) - # Sequential CPU offloading moves one layer at a time to GPU, - # keeping peak VRAM usage well within 16GB - pipe.enable_sequential_cpu_offload(device=device) + if quantize_fp8: + from optimum.quanto import freeze, qfloat8, quantize - logger.info("Flux pipeline loaded successfully with CPU offloading") + logger.info("Quantizing transformer to FP8...") + quantize(pipe.transformer, weights=qfloat8) + freeze(pipe.transformer) + logger.info("FP8 quantization complete") + + pipe.enable_model_cpu_offload(device=device) + + logger.info("Flux pipeline loaded successfully (fp8=%s)", quantize_fp8) return pipe def generate_image( - pipe: FluxPipeline, + pipe: Flux2KleinPipeline, prompt: str, width: int = 1024, height: int = 1024, num_inference_steps: int = 4, seed: int | None = None, ): - """Generate an image using Flux.1-schnell. + """Generate an image using FLUX.2-klein-4B. - FLUX.1-schnell uses 4 inference steps and no CFG (guidance_scale=0.0). + Uses 4 inference steps with guidance_scale=1.0. Returns a PIL Image. """ generator = None @@ -67,7 +75,7 @@ def generate_image( width=width, height=height, num_inference_steps=num_inference_steps, - guidance_scale=0.0, + guidance_scale=1.0, generator=generator, ) diff --git a/src/flux_server/server.py b/src/flux_server/server.py index 890407c..c79e084 100644 --- a/src/flux_server/server.py +++ b/src/flux_server/server.py @@ -4,6 +4,7 @@ import base64 import io import logging +import os import time from contextlib import asynccontextmanager @@ -39,8 +40,9 @@ class GenerateResponse(BaseModel): @asynccontextmanager async def lifespan(app: FastAPI): global _pipe - logger.info("Loading Flux model...") - _pipe = load_flux_pipeline() + quantize_fp8 = os.environ.get("FLUX_QUANTIZE_FP8", "true").lower() in ("true", "1", "yes") + logger.info("Loading Flux model (fp8=%s)...", quantize_fp8) + _pipe = load_flux_pipeline(quantize_fp8=quantize_fp8) logger.info("Flux model ready") yield _pipe = None @@ -54,7 +56,7 @@ async def lifespan(app: FastAPI): async def health(): if _pipe is None: raise HTTPException(status_code=503, detail="Model not loaded") - return {"status": "healthy", "model": "FLUX.1-schnell"} + return {"status": "healthy", "model": "FLUX.2-klein-4B"} @app.post("/generate", response_model=GenerateResponse) @@ -105,8 +107,6 @@ async def generate(req: GenerateRequest): def main(): - import os - import uvicorn host = os.environ.get("FLUX_HOST", "0.0.0.0") From 5f41068cec577dff55cc52caa4e28818996f9cd8 Mon Sep 17 00:00:00 2001 From: Christopher Maher Date: Mon, 23 Mar 2026 18:28:16 -0700 Subject: [PATCH 2/4] fix: use CUDA devel image for FP8 quantization support The runtime image lacks nvcc/build tools needed by optimum-quanto to JIT-compile Marlin FP8 CUDA kernels. Also adds git for pip install of diffusers from source. Signed-off-by: Christopher Maher --- docker/Dockerfile.flux-server | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/Dockerfile.flux-server b/docker/Dockerfile.flux-server index a2a25a7..1dcd411 100644 --- a/docker/Dockerfile.flux-server +++ b/docker/Dockerfile.flux-server @@ -1,11 +1,11 @@ -FROM nvidia/cuda:12.4.1-runtime-ubuntu22.04 +FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 ENV DEBIAN_FRONTEND=noninteractive ENV PYTHONUNBUFFERED=1 ENV HF_HOME=/models/flux RUN apt-get update && apt-get install -y --no-install-recommends \ - python3 python3-pip python3-venv \ + python3 python3-pip python3-venv git \ && rm -rf /var/lib/apt/lists/* WORKDIR /app From 19ce918d79a00771b039e7dfd161fbcd6a5b18d6 Mon Sep 17 00:00:00 2001 From: Christopher Maher Date: Mon, 23 Mar 2026 19:40:16 -0700 Subject: [PATCH 3/4] fix: bump CUDA to 12.8 for Blackwell GPU support, add python3-dev CUDA 12.4 nvcc doesn't support compute_120 (Blackwell/RTX 50 series). CUDA 12.8 adds Blackwell support. Also adds python3-dev for the Python.h headers needed by optimum-quanto's JIT kernel compilation. Signed-off-by: Christopher Maher --- docker/Dockerfile.flux-server | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/Dockerfile.flux-server b/docker/Dockerfile.flux-server index 1dcd411..137a151 100644 --- a/docker/Dockerfile.flux-server +++ b/docker/Dockerfile.flux-server @@ -1,11 +1,11 @@ -FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 +FROM nvidia/cuda:12.8.0-devel-ubuntu22.04 ENV DEBIAN_FRONTEND=noninteractive ENV PYTHONUNBUFFERED=1 ENV HF_HOME=/models/flux RUN apt-get update && apt-get install -y --no-install-recommends \ - python3 python3-pip python3-venv git \ + python3 python3-pip python3-venv python3-dev git \ && rm -rf /var/lib/apt/lists/* WORKDIR /app From 0fad868bd55666f5dfde501c730be7bdf5d8241a Mon Sep 17 00:00:00 2001 From: Christopher Maher Date: Mon, 23 Mar 2026 22:26:03 -0700 Subject: [PATCH 4/4] fix: replace optimum-quanto with diffusers layerwise casting for FP8 optimum-quanto's Marlin FP8 kernels require JIT compilation and hit runtime contiguity bugs. Switch to diffusers' built-in enable_layerwise_casting (stores weights in FP8, computes in BF16) which needs no external dependencies or CUDA compilation. This also allows switching back to the smaller runtime base image. Signed-off-by: Christopher Maher --- docker/Dockerfile.flux-server | 4 ++-- pyproject.toml | 1 - requirements-flux.txt | 1 - src/flux_server/model_loader.py | 12 ++++++------ 4 files changed, 8 insertions(+), 10 deletions(-) diff --git a/docker/Dockerfile.flux-server b/docker/Dockerfile.flux-server index 137a151..841cad9 100644 --- a/docker/Dockerfile.flux-server +++ b/docker/Dockerfile.flux-server @@ -1,11 +1,11 @@ -FROM nvidia/cuda:12.8.0-devel-ubuntu22.04 +FROM nvidia/cuda:12.8.0-runtime-ubuntu22.04 ENV DEBIAN_FRONTEND=noninteractive ENV PYTHONUNBUFFERED=1 ENV HF_HOME=/models/flux RUN apt-get update && apt-get install -y --no-install-recommends \ - python3 python3-pip python3-venv python3-dev git \ + python3 python3-pip python3-venv git \ && rm -rf /var/lib/apt/lists/* WORKDIR /app diff --git a/pyproject.toml b/pyproject.toml index 1a1f8d1..284c5be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,6 @@ flux = [ "diffusers @ git+https://github.com/huggingface/diffusers.git", "transformers>=4.44.0", "accelerate>=0.33.0", - "optimum-quanto>=0.2.0", "sentencepiece>=0.2.0", "protobuf>=5.28.0", ] diff --git a/requirements-flux.txt b/requirements-flux.txt index 92ec20d..481b573 100644 --- a/requirements-flux.txt +++ b/requirements-flux.txt @@ -2,7 +2,6 @@ torch>=2.4.0 diffusers @ git+https://github.com/huggingface/diffusers.git transformers>=4.44.0 accelerate>=0.33.0 -optimum-quanto>=0.2.0 sentencepiece>=0.2.0 protobuf>=5.28.0 fastapi>=0.115.0 diff --git a/src/flux_server/model_loader.py b/src/flux_server/model_loader.py index 9c31e9c..9eea5ef 100644 --- a/src/flux_server/model_loader.py +++ b/src/flux_server/model_loader.py @@ -40,12 +40,12 @@ def load_flux_pipeline( ) if quantize_fp8: - from optimum.quanto import freeze, qfloat8, quantize - - logger.info("Quantizing transformer to FP8...") - quantize(pipe.transformer, weights=qfloat8) - freeze(pipe.transformer) - logger.info("FP8 quantization complete") + logger.info("Enabling FP8 layerwise casting on transformer...") + pipe.transformer.enable_layerwise_casting( + storage_dtype=torch.float8_e4m3fn, + compute_dtype=torch.bfloat16, + ) + logger.info("FP8 layerwise casting enabled") pipe.enable_model_cpu_offload(device=device)