Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion deploy/k8s/flux-server/deployment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ spec:
value: "8081"
- name: HF_HOME
value: "/models/flux"
- name: FLUX_QUANTIZE_FP8
value: "true"
# Optional: HuggingFace token for gated models
# - name: HF_TOKEN
# valueFrom:
Expand All @@ -42,7 +44,7 @@ spec:
resources:
requests:
cpu: "2"
memory: "16Gi"
memory: "12Gi"
nvidia.com/gpu: "1"
limits:
nvidia.com/gpu: "1"
Expand Down
4 changes: 2 additions & 2 deletions docker/Dockerfile.flux-server
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
FROM nvidia/cuda:12.4.1-runtime-ubuntu22.04
FROM nvidia/cuda:12.8.0-runtime-ubuntu22.04

ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1
ENV HF_HOME=/models/flux

RUN apt-get update && apt-get install -y --no-install-recommends \
python3 python3-pip python3-venv \
python3 python3-pip python3-venv git \
&& rm -rf /var/lib/apt/lists/*

WORKDIR /app
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ dependencies = [
[project.optional-dependencies]
flux = [
"torch>=2.4.0",
"diffusers>=0.33.0",
"diffusers @ git+https://github.com/huggingface/diffusers.git",
"transformers>=4.44.0",
"accelerate>=0.33.0",
"sentencepiece>=0.2.0",
Expand Down
2 changes: 1 addition & 1 deletion requirements-flux.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
torch>=2.4.0
diffusers>=0.33.0
diffusers @ git+https://github.com/huggingface/diffusers.git
transformers>=4.44.0
accelerate>=0.33.0
sentencepiece>=0.2.0
Expand Down
42 changes: 25 additions & 17 deletions src/flux_server/model_loader.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,29 @@
"""Flux model loader with memory optimization for 16GB VRAM GPUs."""
"""Flux model loader with FP8 quantization for 16GB VRAM GPUs."""

import logging
import os
from pathlib import Path

import torch
from diffusers import FluxPipeline
from diffusers import Flux2KleinPipeline

logger = logging.getLogger(__name__)

# Default model for FLUX.1-schnell (Apache 2.0)
DEFAULT_MODEL_ID = "black-forest-labs/FLUX.1-schnell"
# Default model: FLUX.2-klein-4B (Apache 2.0, 4B params)
DEFAULT_MODEL_ID = "black-forest-labs/FLUX.2-klein-4B"


def load_flux_pipeline(
model_id: str = DEFAULT_MODEL_ID,
cache_dir: str | None = None,
device: str = "cuda",
) -> FluxPipeline:
"""Load Flux pipeline with memory optimization for 16GB VRAM.
quantize_fp8: bool = True,
) -> Flux2KleinPipeline:
"""Load FLUX.2-klein pipeline with optional FP8 quantization.

Uses BF16 precision and CPU offloading to fit within a single
16GB VRAM GPU.
At BF16 the model uses ~13GB VRAM. With FP8 quantization on
the transformer, VRAM drops to ~8GB — leaving headroom on
16GB GPUs for larger resolutions or batch work.
"""
if cache_dir is None:
cache_dir = os.environ.get("HF_HOME", "/models/flux")
Expand All @@ -31,31 +33,37 @@ def load_flux_pipeline(

logger.info("Loading Flux pipeline: %s (cache: %s)", model_id, cache_dir)

pipe = FluxPipeline.from_pretrained(
pipe = Flux2KleinPipeline.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
cache_dir=cache_dir,
)

# Sequential CPU offloading moves one layer at a time to GPU,
# keeping peak VRAM usage well within 16GB
pipe.enable_sequential_cpu_offload(device=device)
if quantize_fp8:
logger.info("Enabling FP8 layerwise casting on transformer...")
pipe.transformer.enable_layerwise_casting(
storage_dtype=torch.float8_e4m3fn,
compute_dtype=torch.bfloat16,
)
logger.info("FP8 layerwise casting enabled")

logger.info("Flux pipeline loaded successfully with CPU offloading")
pipe.enable_model_cpu_offload(device=device)

logger.info("Flux pipeline loaded successfully (fp8=%s)", quantize_fp8)
return pipe


def generate_image(
pipe: FluxPipeline,
pipe: Flux2KleinPipeline,
prompt: str,
width: int = 1024,
height: int = 1024,
num_inference_steps: int = 4,
seed: int | None = None,
):
"""Generate an image using Flux.1-schnell.
"""Generate an image using FLUX.2-klein-4B.

FLUX.1-schnell uses 4 inference steps and no CFG (guidance_scale=0.0).
Uses 4 inference steps with guidance_scale=1.0.
Returns a PIL Image.
"""
generator = None
Expand All @@ -67,7 +75,7 @@ def generate_image(
width=width,
height=height,
num_inference_steps=num_inference_steps,
guidance_scale=0.0,
guidance_scale=1.0,
generator=generator,
)

Expand Down
10 changes: 5 additions & 5 deletions src/flux_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import base64
import io
import logging
import os
import time
from contextlib import asynccontextmanager

Expand Down Expand Up @@ -39,8 +40,9 @@ class GenerateResponse(BaseModel):
@asynccontextmanager
async def lifespan(app: FastAPI):
global _pipe
logger.info("Loading Flux model...")
_pipe = load_flux_pipeline()
quantize_fp8 = os.environ.get("FLUX_QUANTIZE_FP8", "true").lower() in ("true", "1", "yes")
logger.info("Loading Flux model (fp8=%s)...", quantize_fp8)
_pipe = load_flux_pipeline(quantize_fp8=quantize_fp8)
logger.info("Flux model ready")
yield
_pipe = None
Expand All @@ -54,7 +56,7 @@ async def lifespan(app: FastAPI):
async def health():
if _pipe is None:
raise HTTPException(status_code=503, detail="Model not loaded")
return {"status": "healthy", "model": "FLUX.1-schnell"}
return {"status": "healthy", "model": "FLUX.2-klein-4B"}


@app.post("/generate", response_model=GenerateResponse)
Expand Down Expand Up @@ -105,8 +107,6 @@ async def generate(req: GenerateRequest):


def main():
import os

import uvicorn

host = os.environ.get("FLUX_HOST", "0.0.0.0")
Expand Down