diff --git a/deploy/k8s/flux-server/deployment.yaml b/deploy/k8s/flux-server/deployment.yaml index f23a903..43d3241 100644 --- a/deploy/k8s/flux-server/deployment.yaml +++ b/deploy/k8s/flux-server/deployment.yaml @@ -33,6 +33,8 @@ spec: value: "8081" - name: HF_HOME value: "/models/flux" + - name: FLUX_QUANTIZE_FP8 + value: "true" # Optional: HuggingFace token for gated models # - name: HF_TOKEN # valueFrom: @@ -42,7 +44,7 @@ spec: resources: requests: cpu: "2" - memory: "16Gi" + memory: "12Gi" nvidia.com/gpu: "1" limits: nvidia.com/gpu: "1" diff --git a/docker/Dockerfile.flux-server b/docker/Dockerfile.flux-server index a2a25a7..841cad9 100644 --- a/docker/Dockerfile.flux-server +++ b/docker/Dockerfile.flux-server @@ -1,11 +1,11 @@ -FROM nvidia/cuda:12.4.1-runtime-ubuntu22.04 +FROM nvidia/cuda:12.8.0-runtime-ubuntu22.04 ENV DEBIAN_FRONTEND=noninteractive ENV PYTHONUNBUFFERED=1 ENV HF_HOME=/models/flux RUN apt-get update && apt-get install -y --no-install-recommends \ - python3 python3-pip python3-venv \ + python3 python3-pip python3-venv git \ && rm -rf /var/lib/apt/lists/* WORKDIR /app diff --git a/pyproject.toml b/pyproject.toml index 402c2eb..284c5be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ dependencies = [ [project.optional-dependencies] flux = [ "torch>=2.4.0", - "diffusers>=0.33.0", + "diffusers @ git+https://github.com/huggingface/diffusers.git", "transformers>=4.44.0", "accelerate>=0.33.0", "sentencepiece>=0.2.0", diff --git a/requirements-flux.txt b/requirements-flux.txt index a2865c1..481b573 100644 --- a/requirements-flux.txt +++ b/requirements-flux.txt @@ -1,5 +1,5 @@ torch>=2.4.0 -diffusers>=0.33.0 +diffusers @ git+https://github.com/huggingface/diffusers.git transformers>=4.44.0 accelerate>=0.33.0 sentencepiece>=0.2.0 diff --git a/src/flux_server/model_loader.py b/src/flux_server/model_loader.py index 302e3e3..9eea5ef 100644 --- a/src/flux_server/model_loader.py +++ b/src/flux_server/model_loader.py @@ -1,27 +1,29 @@ -"""Flux model loader with memory optimization for 16GB VRAM GPUs.""" +"""Flux model loader with FP8 quantization for 16GB VRAM GPUs.""" import logging import os from pathlib import Path import torch -from diffusers import FluxPipeline +from diffusers import Flux2KleinPipeline logger = logging.getLogger(__name__) -# Default model for FLUX.1-schnell (Apache 2.0) -DEFAULT_MODEL_ID = "black-forest-labs/FLUX.1-schnell" +# Default model: FLUX.2-klein-4B (Apache 2.0, 4B params) +DEFAULT_MODEL_ID = "black-forest-labs/FLUX.2-klein-4B" def load_flux_pipeline( model_id: str = DEFAULT_MODEL_ID, cache_dir: str | None = None, device: str = "cuda", -) -> FluxPipeline: - """Load Flux pipeline with memory optimization for 16GB VRAM. + quantize_fp8: bool = True, +) -> Flux2KleinPipeline: + """Load FLUX.2-klein pipeline with optional FP8 quantization. - Uses BF16 precision and CPU offloading to fit within a single - 16GB VRAM GPU. + At BF16 the model uses ~13GB VRAM. With FP8 quantization on + the transformer, VRAM drops to ~8GB — leaving headroom on + 16GB GPUs for larger resolutions or batch work. """ if cache_dir is None: cache_dir = os.environ.get("HF_HOME", "/models/flux") @@ -31,31 +33,37 @@ def load_flux_pipeline( logger.info("Loading Flux pipeline: %s (cache: %s)", model_id, cache_dir) - pipe = FluxPipeline.from_pretrained( + pipe = Flux2KleinPipeline.from_pretrained( model_id, torch_dtype=torch.bfloat16, cache_dir=cache_dir, ) - # Sequential CPU offloading moves one layer at a time to GPU, - # keeping peak VRAM usage well within 16GB - pipe.enable_sequential_cpu_offload(device=device) + if quantize_fp8: + logger.info("Enabling FP8 layerwise casting on transformer...") + pipe.transformer.enable_layerwise_casting( + storage_dtype=torch.float8_e4m3fn, + compute_dtype=torch.bfloat16, + ) + logger.info("FP8 layerwise casting enabled") - logger.info("Flux pipeline loaded successfully with CPU offloading") + pipe.enable_model_cpu_offload(device=device) + + logger.info("Flux pipeline loaded successfully (fp8=%s)", quantize_fp8) return pipe def generate_image( - pipe: FluxPipeline, + pipe: Flux2KleinPipeline, prompt: str, width: int = 1024, height: int = 1024, num_inference_steps: int = 4, seed: int | None = None, ): - """Generate an image using Flux.1-schnell. + """Generate an image using FLUX.2-klein-4B. - FLUX.1-schnell uses 4 inference steps and no CFG (guidance_scale=0.0). + Uses 4 inference steps with guidance_scale=1.0. Returns a PIL Image. """ generator = None @@ -67,7 +75,7 @@ def generate_image( width=width, height=height, num_inference_steps=num_inference_steps, - guidance_scale=0.0, + guidance_scale=1.0, generator=generator, ) diff --git a/src/flux_server/server.py b/src/flux_server/server.py index 890407c..c79e084 100644 --- a/src/flux_server/server.py +++ b/src/flux_server/server.py @@ -4,6 +4,7 @@ import base64 import io import logging +import os import time from contextlib import asynccontextmanager @@ -39,8 +40,9 @@ class GenerateResponse(BaseModel): @asynccontextmanager async def lifespan(app: FastAPI): global _pipe - logger.info("Loading Flux model...") - _pipe = load_flux_pipeline() + quantize_fp8 = os.environ.get("FLUX_QUANTIZE_FP8", "true").lower() in ("true", "1", "yes") + logger.info("Loading Flux model (fp8=%s)...", quantize_fp8) + _pipe = load_flux_pipeline(quantize_fp8=quantize_fp8) logger.info("Flux model ready") yield _pipe = None @@ -54,7 +56,7 @@ async def lifespan(app: FastAPI): async def health(): if _pipe is None: raise HTTPException(status_code=503, detail="Model not loaded") - return {"status": "healthy", "model": "FLUX.1-schnell"} + return {"status": "healthy", "model": "FLUX.2-klein-4B"} @app.post("/generate", response_model=GenerateResponse) @@ -105,8 +107,6 @@ async def generate(req: GenerateRequest): def main(): - import os - import uvicorn host = os.environ.get("FLUX_HOST", "0.0.0.0")