diff --git a/.claude/scheduled_tasks.lock b/.claude/scheduled_tasks.lock new file mode 100644 index 0000000..d72653b --- /dev/null +++ b/.claude/scheduled_tasks.lock @@ -0,0 +1 @@ +{"sessionId":"d8c05a9f-ecd4-4918-ba19-035cdd531a1a","pid":805816,"acquiredAt":1778002935354} \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md index fde9b1b..326a4ec 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1,182 +1,301 @@ -# CLAUDE.md +# scope-streamdiffusion Plugin — Claude Code Guide -This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. +Real-time Stable Diffusion pipeline for Daydream Scope using StreamDiffusion. Supports SD 1.5, SDXL, Turbo models with LCM scheduling, ControlNet, TensorRT acceleration, and multi-model orchestration. -## Scope Plugin Expertise +## Design -You are an expert at building Daydream Scope plugins/nodes/pipelines. Reference documentation: -- **https://docs.daydream.live/scope/tutorials/build-video-effects-plugin** -- **https://docs.daydream.live/scope/tutorials/vibe-code-a-scope-plugin** +This is a Scope `Pipeline` subclass that wraps diffusion inference. The plugin is **entry-point discovered** via `pyproject.toml` and loads into Scope's pipeline selector automatically. -## Project Overview +### Core Principles -This is a **Daydream Scope plugin** that integrates StreamDiffusion for real-time Stable Diffusion video generation. It's not a standalone application—it's designed to be installed and discovered by the Daydream Scope framework via Python entry points. +- **Init/Runtime separation.** `__init__()` loads models once; `__call__(**kwargs)` handles per-frame params (prompt, seed, guidance scale, strength). Parameters can change frame-to-frame without reloading. +- **Tensor format aware.** Scope uses `(T, H, W, C)` in [0, 1]; diffusion expects `(B, C, H, W)`. Conversions happen in `__call__()`. +- **Schema-driven config.** All parameters defined in `schema.py` using Pydantic. UI fields auto-generated via `ui_field_config(order=N, label="...")`. +- **Lazy model loading.** Models load on first init; subsequent calls reuse weights. Model changes trigger full reinitialization. -## Installation & Development +## Project Structure -### Install Plugin (Development Mode) -```bash -pip install -e . -``` -Development mode allows code changes to take effect immediately without reinstalling. - -### Verify Plugin Registration -```bash -python -c "import scope_streamdiffusion; print('Plugin loaded successfully')" -``` - -### Testing in Scope -The plugin is automatically discovered by Scope once installed. Start Scope and look for "StreamDiffusion" in the pipeline selector. - -## Architecture - -### Plugin Structure ``` -src/scope_streamdiffusion/ -├── __init__.py # Plugin registration via @hookimpl -├── schema.py # Configuration schema (UI fields + validation) -└── pipeline.py # Pipeline implementation (model + inference) +. +├── CLAUDE.md # This file +├── README.md # User-facing features and usage +├── ADAPTATION_NOTES.md # How StreamDiffusion was adapted to Scope +├── INSTALL.md # Quick install guide +├── pyproject.toml # Package config, entry point, deps +│ +├── src/scope_streamdiffusion/ +│ ├── __init__.py # Plugin registration via hookimpl +│ ├── schema.py # StreamDiffusionConfig (Pydantic + UI) +│ ├── pipeline.py # StreamDiffusionPipeline (main logic) +│ ├── controlnet.py # ControlNet handler for multi-ControlNet support +│ ├── trt_engines.py # TensorRT engine discovery/caching +│ ├── _trt_cache.py # TensorRT compile cache management +│ │ +│ └── _trt/ # TensorRT backend utilities +│ ├── __init__.py +│ ├── models.py # TRT model configs +│ ├── engine.py # Engine compilation and inference +│ ├── builder.py # ONNX → TRT conversion +│ └── utilities.py # Device/precision helpers +│ +└── (no tests directory yet) ``` -### Entry Point System -The plugin is discovered via `pyproject.toml` entry point: -```toml -[project.entry-points."scope"] -scope_streamdiffusion = "scope_streamdiffusion" -``` -Scope automatically loads all registered plugins at startup. +## Key Files -## Critical Architectural Patterns +**Schema & Configuration:** +- `schema.py` — `StreamDiffusionConfig`: Pydantic model with 50+ fields defining model, scheduler, sampler, guidance, seed, ControlNet setup, TensorRT flags. Fields use `ui_field_config()` for Scope UI auto-generation. -### 1. Initialization vs Runtime Separation +**Pipeline Implementation:** +- `pipeline.py` — `StreamDiffusionPipeline`: implements `Pipeline` interface. Methods: + - `get_config_class()`: returns `StreamDiffusionConfig` + - `prepare(**kwargs) → Requirements`: returns resource hints + - `__call__(**kwargs) → dict`: main inference loop; returns `{"video": tensor}` -**This is the most important pattern in the codebase.** +**ControlNet:** +- `controlnet.py` — `ControlNetHandler`: manages multi-ControlNet attachment, caching, and inference integration. Supports Canny, pose, depth, etc. -- **`__init__()`**: One-time model loading, GPU setup, component initialization - - Loads diffusion model from HuggingFace/local path - - Sets up VAE, UNet, text encoder, scheduler - - Initializes Compel for prompt weighting - - NO runtime parameters here +**TensorRT:** +- `trt_engines.py` — discovers cached engines, auto-selects by device/precision +- `_trt/engine.py` — compiles ONNX models to TensorRT `.engine` format with dynamic shapes +- `_trt_cache.py` — caches compiled engines locally for rapid reuse -- **`__call__(**kwargs)`**: Per-frame processing with runtime parameters - - Receives all generation params (prompt, seed, strength, etc.) from kwargs - - Calls `_prepare_runtime_state()` to set up state from kwargs - - Processes frame and returns `{"video": tensor}` - - Parameters can change between frames without reloading model +**Entry Point:** +- `__init__.py` — `@hookimpl` function that Scope's plugin loader calls at discovery -**Why:** Enables efficient real-time streaming where the model stays loaded but parameters can change dynamically. +## Architecture -### 2. Configuration Schema Pattern +### Inference Flow -All pipeline parameters are defined in `schema.py` using: -```python -class StreamDiffusionConfig(BasePipelineConfig): - param_name: type = Field( - default=value, - description="...", - json_schema_extra=ui_field_config(order=N, label="Display Name") - ) +``` +Input frame (from Scope) + ↓ +[Tensor format conversion] (T,H,W,C) → (B,C,H,W) + ↓ +[Load/prepare model] (on first call; cached after) + ↓ +[Encode prompt] (Compel for weighting; cache embeddings) + ↓ +[VAE encode] frame → latent + ↓ +[ControlNet encode] (if enabled; pre-compute for all conditions) + ↓ +[Denoising loop] + for each step in scheduler: + - Add noise if img2img + - Denoise with UNet + - Apply ControlNet + - Apply guidance + ↓ +[VAE decode] latent → image + ↓ +[Tensor format conversion] (B,C,H,W) → (T,H,W,C) + ↓ +Output (PIL.Image or tensor dict) ``` -- Inherits from `BasePipelineConfig` (provided by Scope) -- Each field gets `ui_field_config()` for UI generation -- `order` determines UI layout order -- Validation happens automatically via Pydantic +### Model Loading Lifecycle -### 3. Tensor Format Conversions +1. **First `__init__`:** Load diffusion model from HuggingFace or local path. Setup VAE, UNet, text encoder, scheduler. Warm up GPU. +2. **Subsequent `__init__` calls:** Reuse loaded weights (unless model_id changed). +3. **Model change:** Trigger full reload (detected via signature comparison). +4. **ControlNet attach:** Load and fuse ControlNet weights; cache encoders. -**Scope's tensor format:** `(T, H, W, C)` normalized to [0, 1] -**Diffusion model format:** `(B, C, H, W)` for processing +### Parameter Handling -Conversions happen in `__call__()`: -```python -# Input: Scope format → Model format -frame = video[0] # (H, W, C) -input_tensor = frame.permute(2, 0, 1).unsqueeze(0) # (1, C, H, W) +**Initialization-time (requires model reload):** +- `model_id`: changes which model to load +- `torch_dtype`: precision (float16 vs float32) +- `acceleration`: xformers vs none -# Output: Model format → Scope format -output = result.permute(0, 2, 3, 1).clamp(0, 1) # (T, H, W, C) -``` +**Runtime (can change per-frame):** +- `prompt`: text input (re-encoded each frame or cached if unchanged) +- `seed`: random seed +- `guidance_scale`: classifier-free guidance strength +- `strength`: how much to denoise (img2img) +- `num_inference_steps`: denoising steps +- `scheduler`: LCM, DDPM, etc. (some require reinit) -### 4. Pipeline Interface Methods +### TensorRT Compilation -Must implement: -- `get_config_class()`: Returns config schema class -- `prepare()`: Returns `Requirements` (e.g., input size) -- `__call__(**kwargs)`: Main processing method +- Disabled by default (requires additional setup). +- When enabled: UNet compiled to device-specific `.engine` file. +- Compilation happens on first inference (slow; ~1-5 min depending on model). +- Cached engines reused on subsequent runs (instant load). +- Cache dir: `~/.cache/scope-streamdiffusion/trt/` -## Key Implementation Details +### ControlNet Support -### StreamDiffusion Specifics -- Uses LCM (Latent Consistency Models) scheduler for fast inference -- Supports batch denoising for better performance -- Single-step denoising with `t_index_list = [0]` for Turbo models -- Delta parameter controls temporal consistency in streams +- Multi-ControlNet: attach multiple conditions (e.g., Canny + pose). +- Conditions pre-computed once per prompt. +- Inference: scales applied per denoising step. +- Encoder caching to avoid re-encoding images. -### Model Support -- SD 1.5, SDXL, SD Turbo, SDXL Turbo -- Auto-detects SDXL vs SD 1.5 for proper prompt encoding -- Supports LoRA loading via `load_lora()` and `fuse_lora()` +## Development Workflow -### Prompt Encoding -- Uses Compel library for advanced prompt weighting -- SDXL requires pooled embeddings + add_time_ids -- SD 1.5 uses standard CLIP embeddings +### Before Starting -### ControlNet -- Basic support implemented but not exposed in UI -- Set via `self.controlnet` and `self.controlnet_pipeline` -- To expose: add fields to `StreamDiffusionConfig` +1. **Read `ADAPTATION_NOTES.md`** — explains how the original StreamDiffusion code was adapted to Scope's architecture. +2. **Understand init/runtime separation** — this is the foundation of how parameters flow. +3. **Check `schema.py`** for existing fields — don't add duplicate params. -## Adding New Parameters +### Adding a New Parameter 1. **Add to schema** (`schema.py`): ```python new_param: float = Field( default=1.0, ge=0.0, - le=2.0, - description="Parameter description", - json_schema_extra=ui_field_config(order=99, label="New Parameter"), + le=10.0, + description="What this does", + json_schema_extra=ui_field_config(order=50, label="New Param"), ) ``` 2. **Use in pipeline** (`pipeline.py`): + - If runtime-safe (doesn't require model reload): read from `kwargs` in `__call__()` + - If initialization-time (e.g., model architecture): pass to `__init__()` and track via signature + +3. **Test in Scope:** + - Run Scope (`SCOPE_REPL` or `scope serve`) + - Select StreamDiffusion pipeline + - Parameter should appear in UI with label and order + +### Adding ControlNet Support + +1. **Update schema** to expose ControlNet config: ```python - def __call__(self, **kwargs) -> dict: - new_param = kwargs.get("new_param", 1.0) - # Use new_param in processing... + controlnet_id: Optional[str] = Field( + default=None, + description="ControlNet model ID", + json_schema_extra=ui_field_config(order=40, label="ControlNet"), + ) + controlnet_conditioning: Optional[str] = Field( + default=None, + description="Encoded ControlNet condition", + ) ``` -3. **(Optional) Add to `_prepare_runtime_state()`** if it affects state initialization +2. **Update pipeline** to attach ControlNet: + ```python + if kwargs.get("controlnet_id"): + controlnet = ControlNetHandler(kwargs["controlnet_id"], device=self.device) + # Attach to diffusion pipeline + ``` -## Important Files Referenced +3. **Reference `controlnet.py`** for handler patterns. -- `ADAPTATION_NOTES.md`: Detailed explanation of how original StreamDiffusion code was adapted to Scope's architecture -- `README.md`: User-facing documentation with features and usage -- `INSTALL.md`: Quick installation guide +### TensorRT Integration -## Dependencies +Only attempt if you have CUDA 12+ and understand TensorRT compilation: -Core dependencies defined in `pyproject.toml`: -- `torch`: PyTorch (requires CUDA for GPU) -- `diffusers`: Stable Diffusion models and pipelines -- `compel`: Advanced prompt weighting -- `logfire`: Logging (Scope requirement) -- `numpy`, `pillow`: Image processing +1. Update `_trt/builder.py` if you need new precision/shape configs. +2. Call `_trt_cache.get_engine()` to auto-compile and cache. +3. Swap UNet for TRT engine in inference loop. +4. See `trt_engines.py` for caching logic. -## Debugging +## Testing -Common issues: -- **Plugin not appearing in Scope**: Check entry point registration -- **Model loading fails**: Verify model path and GPU availability -- **Import errors**: Ensure Scope framework is installed -- **Performance issues**: Enable xformers acceleration, reduce inference steps, or use Turbo models +No test suite exists yet. Manual testing approach: -## Development Workflow +```bash +# 1. Install in dev mode +pip install -e . + +# 2. Start Scope +SCOPE_REPL # or: scope serve + +# 3. In Scope, select StreamDiffusion pipeline + +# 4. Set parameters and verify: +# - Prompt changes take effect immediately +# - Model changes trigger reload (check logs) +# - Output looks reasonable +# - Performance is acceptable +``` + +### Debugging + +```bash +# Check plugin is discovered: +python -c "import scope_streamdiffusion; print('OK')" + +# Check config loads: +python -c "from scope_streamdiffusion import StreamDiffusionConfig; print(StreamDiffusionConfig.__fields__.keys())" + +# Test pipeline init: +from scope_streamdiffusion.pipeline import StreamDiffusionPipeline +p = StreamDiffusionPipeline() +print("Pipeline initialized") +``` + +## Important Constraints + +- **Model reloads are expensive.** Changing `model_id`, `torch_dtype`, or `acceleration` causes full reload (10-30s). +- **VRAM is limited.** Default to float16 and xformers acceleration. SDXL needs 8GB+ VRAM. +- **Scheduler matters.** LCM is fast (1-4 steps); DDPM is slow but more flexible (20-50 steps). Some model + scheduler combos don't work well. +- **TensorRT engines are device-specific.** Moving to a different GPU requires recompilation. +- **Prompt encoding is cached.** If prompt doesn't change, embeddings are reused (fast). If it does, encoding happens every frame (slower). + +## Dependencies & xformers + +**Core deps** (in `pyproject.toml`): +- `torch` — deep learning framework +- `diffusers` — HuggingFace diffusion models +- `logfire` — Scope logging integration +- `numpy`, `pillow` — image processing + +**Optional (xformers acceleration):** +xformers is NOT in dependencies because it ships with strict (often wrong) torch pins that break Scope's GPU stack. + +Install manually after setup, choosing the version for your torch: +```bash +torch 2.9.x → uv pip install --no-deps xformers==0.0.33.post2 +torch 2.10.x → uv pip install --no-deps xformers==0.0.34 +``` + +Use `--no-deps` to skip xformers' bogus torch pin. + +## Scope Integration Points + +**Entry point discovery:** +```toml +[project.entry-points."scope"] +scope_streamdiffusion = "scope_streamdiffusion" +``` +Scope calls `hookimpl()` function in `__init__.py` to register the pipeline. + +**Config schema:** +Schema fields with `ui_field_config()` are discovered by Scope and rendered in the pipeline UI. Changes to schema are reflected on next Scope restart. + +**Requirements:** +`prepare()` returns `Requirements` (e.g., minimum VRAM, input resolution). Scope uses this for validation. + +**Tensor I/O:** +`__call__()` receives `video` tensor from Scope in `(T, H, W, C)` format; must return same format. + +## Common Issues + +| Problem | Cause | Solution | +|---------|-------|----------| +| Plugin doesn't appear in Scope | Entry point not registered | Run `pip install -e .` again; restart Scope | +| Model loading fails | HF auth needed or model not found | Check HF cache; verify internet; login to HuggingFace if needed | +| OOM errors | Model too big for GPU | Use SDXL Turbo instead of base SDXL; reduce batch size; enable xformers | +| Slow inference | No GPU acceleration | Install xformers; check `torch.cuda.is_available()` returns True | +| ControlNet not working | Handler not attached properly | Review `controlnet.py` logic; check config passes condition tensor | +| TensorRT compile fails | CUDA version mismatch | Ensure CUDA 12+; check triton compatibility | + +## Code Style & Conventions + +- Type hints required on all public functions. +- Docstrings on classes and complex methods. +- Config validation via Pydantic (no manual validation). +- Logging via `logfire` (not print). +- No magic constants — all tunable params go in schema. + +## References -1. Make code changes in `src/scope_streamdiffusion/` -2. Changes are immediately available (development mode) -3. Restart Scope to reload plugin -4. Test in Scope UI with various parameters -5. Check Scope logs for errors/warnings +- **Scope plugin tutorials:** https://docs.daydream.live/scope/tutorials/build-video-effects-plugin +- **Diffusers docs:** https://huggingface.co/docs/diffusers +- **StreamDiffusion:** https://github.com/cumulo-autumn/StreamDiffusion +- **Scope Pydantic patterns:** Check other Scope pipelines in `daydreamlive-scope` repo diff --git a/docs/plans/LORA_PLAN.md b/docs/plans/LORA_PLAN.md new file mode 100644 index 0000000..34d9e34 --- /dev/null +++ b/docs/plans/LORA_PLAN.md @@ -0,0 +1,155 @@ +# Plan: LoRA Support + +## Context +`schema.py` has `supports_lora = True` already. `pipeline.py` has stub `load_lora` and `fuse_lora` methods that aren't called. Scope has a `download_lora` endpoint already — verify in the parent repo `daydreamlive-scope`. + +## Schema Changes (`src/scope_streamdiffusion/schema.py`) + +Add a `LoraSpec` model and a `loras` list field on `StreamDiffusionConfig`: +```python +class LoraSpec(BaseModel): + repo_id: str # HF repo or local path + weight_name: Optional[str] = None # for repos with multiple files + adapter_name: str # diffusers adapter name; required for stack/swap + scale: float = 1.0 # 0..2 typical + +class StreamDiffusionConfig(BaseModel): + ... + loras: list[LoraSpec] = Field( + default_factory=list, + json_schema_extra=ui_field_config(order=..., label="LoRAs"), + ) +``` + +Order field: place after model selection but before ControlNet config. Reuse Scope's existing LoRA picker UI if one exists in the parent repo's other pipelines. + +## Loader Wiring (`ModelLoader` post-refactor, or `pipeline.py` if pre-refactor) + +LoRAs attach via `pipe.load_lora_weights(repo_id, weight_name=..., adapter_name=...)`. After loading all requested adapters, call `pipe.set_adapters([names...], adapter_weights=[scales...])`. + +**Lifecycle order:** +1. `ModelLoader._load_model` loads the diffusers pipe. +2. SDXL fp16 VAE swap. +3. **LoRA attach.** Iterate `config.loras`, call `pipe.load_lora_weights` per spec. +4. `pipe.set_adapters(...)` with names + scales. +5. **Do NOT call `fuse_lora`.** Keep adapters live so scales/swaps work without reload. Only fuse before TRT compilation (next step). +6. PromptEncoder.attach, ControlNetHandler.attach. +7. TRTLifecycle.attach. **If TRT is enabled, fuse_lora here** before compiling — TRT bakes weights at compile time, so fused-then-compiled is the only correct path (unless using the refit path — see TRT Refit below). + +## Change Detection +Track a "LoRA signature" (sorted tuple of `(repo_id, weight_name, adapter_name, scale)`) on the model loader. On `_swap_model` / `_ensure_pipe_loaded`: +- Same model + same LoRA signature → no-op. +- Same model + different LoRA signature, **eager mode** → call `pipe.unload_lora_weights()`, then re-attach. Cheap, no reload needed. +- Same model + different LoRA signature, **TRT mode without refit** → full reload required. Treat this as a model swap. Surface the cost in the UI — recompiling SDXL UNet is 10+ minutes. +- Same model + different LoRA signature, **TRT mode with refit-capable engine** → refit (see below). 1–10s instead of 10+ min. +- Scale-only change with same adapters loaded, **eager mode** → `pipe.set_adapters(...)` with new weights. No reload. +- Scale-only change, **TRT non-refit** → full reload. **TRT refit** → refit. + +## Cache Coordination with TRT +The TRT cache key (in `_trt_cache.py` / `trt_engines.py`) must include the LoRA signature. Otherwise two different LoRA stacks will collide on the same cache slot and you'll silently load the wrong engine. Hash the sorted signature into the engine filename. + +When using refit, the cache key for the *engine* uses only the base model + refit-capable flag (LoRA signature does NOT affect the engine identity). The fused weights are applied at refit time. The LoRA signature is tracked separately as the "currently refit-applied state" and used only for change detection. + +## Scope Integration +The user mentioned Scope has a `download_lora` endpoint already. Find it in the parent repo (`daydreamlive-scope`) and confirm: +- Whether it returns a local path or a repo_id. +- Whether the UI already has a LoRA picker in other pipelines that we can match. +- Whether LoRA management is per-pipeline or global. + +Match the existing pattern. Don't invent a new one. + +## Testing +1. Eager SD-Turbo + a single style LoRA from CivitAI (download via Scope, attach via config). +2. Live scale change 0.0 → 1.0 → 1.5. Should update without reload. +3. Live LoRA swap (different adapter). Should be fast (unload + load), no model reload. +4. Toggle TRT on with LoRAs attached. Confirm fuse-then-compile path runs and engine is cached with LoRA-aware key. +5. Live LoRA change with TRT on (non-refit) — confirm full reload + recompile triggers and completes. +6. Stack 2 LoRAs simultaneously. Verify `set_adapters` with multiple names works and scales are independent. +7. SDXL + LoRA (eager and TRT). + +## Out of Scope (defer) +- Multi-LoRA blending UI beyond stack-with-scales. +- LoRA training or merging. + +--- + +# Addendum: TRT Refit Path for LoRAs + +The base plan above says "LoRA change with TRT → full reload" — correct but expensive (10+ min for SDXL). TensorRT's **refit** feature lets you update weights in a built engine without rebuilding it. This is the right answer for live LoRA swaps on TRT. + +## What Refit Buys You +- Engine structure (layers, shapes, fusions) stays compiled. +- Only the weight tensors get re-uploaded. +- Typical refit time: **1–10 seconds** for SDXL UNet vs. 10+ minutes for full rebuild. +- Works for scale changes AND adapter swaps, as long as the LoRA targets the same layers. + +## Build-Time Requirements +The engine must be compiled with refit enabled. Two flags in the TRT builder: +- `BuilderFlag.REFIT` — required. +- `BuilderFlag.STRIP_PLAN` (TRT 10+) — optional but recommended; strips weights from the engine file so you ship a smaller cache and refit at load. Trade-off: load is no longer instant — must refit before first inference. + +**Decision:** use `REFIT` only (not `STRIP_PLAN`). Cached engines stay self-sufficient; refit only runs when LoRAs change. The size penalty for `REFIT`-only is small (~5%) and inference perf is unchanged. + +## Implementation Sketch + +### Builder changes (`src/scope_streamdiffusion/_trt/builder.py` or wherever the network config lives) +Add `network_flags` / `builder_config.flags |= 1 << int(trt.BuilderFlag.REFIT)` to all UNet builders (`build_unet_engine`, `build_unet_sdxl_engine`, `build_unet_with_control_engine`, and the new SDXL+control variant). VAE/TAESD/ControlNet engines don't need it — LoRAs target UNet only (cross-attention layers). + +### Refit at runtime (new method on `TRTLifecycle`) +```python +def refit_lora(self, lora_signature): + # 1. Load base UNet weights into a temporary diffusers UNet (CPU OK). + # 2. Apply LoRA stack to that UNet (load_lora_weights + set_adapters + fuse_lora). + # 3. Use trt.Refitter to push the fused weights into the live engine. + # 4. Discard the temp UNet. +``` + +The refitter API: +```python +refitter = trt.Refitter(self._trt_unet_engine.engine, TRT_LOGGER) +for name in refitter.get_all_weights(): # or get_missing() + weights = fused_unet_state_dict[map_trt_name_to_torch(name)] + refitter.set_named_weights(name, weights) +assert refitter.refit_cuda_engine() +``` + +### Name mapping (the hard part) +TRT weight names come from the ONNX export and don't match diffusers' `state_dict` keys 1:1. You need a map. Two approaches: +1. **Build the map at compile time.** During ONNX export, record the `(torch_param_name → onnx_initializer_name)` mapping and persist it next to the engine in the cache. At refit time, load the map and translate. +2. **Reconstruct the map at refit time** by re-running ONNX export on a dummy UNet with the same architecture and reading the resulting initializer names. Slower but simpler. + +Recommend approach 1. Save the map as `.refit_map.json` alongside the engine file. The TRT cache key already covers architecture variants, so the map is valid for the engine. + +### Cache key change +Refit-capable engines and refit-incapable engines are different artifacts. Add `refit=True` to the cache key path component so old (non-refit) cached engines aren't reused. Old engines stay valid for non-LoRA streams; new ones get used when LoRAs are configured. + +## Updated LoRA Lifecycle (replaces "full reload" branch in the base plan) + +| Change | Eager | TRT (refit-capable engine) | TRT (legacy non-refit engine) | +|---|---|---|---| +| Scale only | `set_adapters` | refit | rebuild | +| Adapter swap, same layers | unload + load + `set_adapters` | refit | rebuild | +| Adapter swap, different layers | same | refit (zero out unused) | rebuild | +| Add ControlNet, etc. | rebuild pipeline state | rebuild engine | rebuild | + +"Different layers" case: if a new LoRA targets layers the previous one didn't, those original-weight slots need to be restored to the base model's weights during refit. The fused-state-dict approach handles this naturally since the temp UNet is built from base weights + new LoRA stack. + +## When to Skip Refit +- First time TRT is enabled with LoRAs configured → fuse first, then build (current plan). Refit only helps on subsequent changes. +- Engine compiled before this feature lands → fall back to rebuild. Detect via the cache-key version bump. +- Refitter reports missing weights → log and rebuild. Don't run a partially-refit engine. + +## Testing (Refit-specific) +1. Cold start with one LoRA + TRT. Confirm engine builds with `REFIT` flag (check `engine.refittable`). +2. Live scale change 0.0 → 1.5. Should complete in <10s, no recompile log. +3. Live adapter swap (different LoRA, same target layers). Same speed. +4. Live adapter swap to a LoRA that targets *additional* layers. Confirm refit covers all weights and output is correct. +5. Stress test: 20 rapid scale/adapter changes. Memory should stay stable (the temp UNet must actually free). +6. SDXL refit specifically — name-map size is larger; verify no missing weights. + +## Risk +- TRT refit name mapping is fiddly. Budget time for debugging the ONNX-name ↔ torch-name mapping. +- Some TRT optimizations bake constants. If a LoRA's effective rank changes the optimal kernel choice, refit produces correct but suboptimal output. Acceptable trade-off. +- `STRIP_PLAN` is tempting but adds first-inference latency. Skip it. + +This makes live LoRA swaps on TRT actually viable instead of "technically supported but never used." diff --git a/docs/plans/README.md b/docs/plans/README.md new file mode 100644 index 0000000..c718541 --- /dev/null +++ b/docs/plans/README.md @@ -0,0 +1,19 @@ +# Plans + +Hand-off plans for the next round of work on `sd-multi-model`. Each plan is self-contained and intended to be executed by another agent without needing the originating conversation. + +- [REFACTOR_PLAN.md](REFACTOR_PLAN.md) — decompose `pipeline.py` into helper classes (`TRTLifecycle`, `ModelLoader`, `InferenceCore`) following the `PromptEncoder` / `ControlNetHandler` pattern. +- [SDXL_CONTROLNET_PLAN.md](SDXL_CONTROLNET_PLAN.md) — wire SDXL ControlNet through the eager and TRT paths (currently raises `NotImplementedError` on TRT for SDXL). +- [LORA_PLAN.md](LORA_PLAN.md) — schema, loader wiring, change detection, and the TRT refit path for live LoRA swaps. + +## Recommended order +1. Refactor (lands first — the LoRA plan assumes the `ModelLoader` and `TRTLifecycle` helpers exist). +2. SDXL ControlNet (independent of LoRA). +3. LoRA (depends on refactor; benefits from but does not require ControlNet work). + +## Architectural pattern (read first) +All three plans assume the helper-class composition pattern. The canonical examples in the repo: +- `src/scope_streamdiffusion/prompt_encoder.py` +- `src/scope_streamdiffusion/controlnet.py` + +Helpers take `(device, dtype)` at construction, gain a pipe back-reference via `attach(pipe, sdxl)`, and expose runtime state as instance attributes. diff --git a/docs/plans/REFACTOR_PLAN.md b/docs/plans/REFACTOR_PLAN.md new file mode 100644 index 0000000..34b3382 --- /dev/null +++ b/docs/plans/REFACTOR_PLAN.md @@ -0,0 +1,116 @@ +# Refactor Plan: pipeline.py Decomposition + +## Goal +Reduce `pipeline.py` from ~1900 lines to a thin orchestrator (~400 lines) by extracting cohesive responsibilities into helper classes. Follow the pattern established by `PromptEncoder` (commit `b1b5478`) and the existing `ControlNetHandler`. + +## Architectural Pattern (non-negotiable — already established) +- Helper class lives in its own module under `src/scope_streamdiffusion/`. +- Constructor takes `(device, dtype)` and any static config. +- `attach(pipe, sdxl: bool)` lifecycle method called from `_ensure_pipe_loaded` and `_swap_model` after the diffusers pipeline is loaded. Helpers re-bind to the new pipe here. +- Helper owns its caches and exposes runtime state as instance attributes the pipeline reads through (e.g., `self.prompts.prompt_embeds`). +- Helper has explicit `reset_caches()` / `release()` methods called on model swap or teardown. +- **No mixins.** Composition only. The user explicitly rejected mixins. + +## Reference Files +- `src/scope_streamdiffusion/prompt_encoder.py` — the template. Read this first. +- `src/scope_streamdiffusion/controlnet.py` — second example of the pattern. +- `src/scope_streamdiffusion/pipeline.py` — the source to extract from. + +## Extraction Order (do them in this order, commit between each) + +### Extraction 1: `TRTLifecycle` → `src/scope_streamdiffusion/trt_lifecycle.py` +**Methods to move:** +- `_ensure_trt_taesd` +- `_ensure_trt_controlnet` +- `_ensure_trt_unet` +- `_setup_trt` +- `_reset_trt_state` +- `_set_acceleration_mode` +- `_deactivate_trt` +- `_trt_setup_args_from_config` + +**Compromise to accept:** these methods currently mutate `self.unet`, `self.controlnet`, `self.vae`, `self._taesd_vae` directly. Don't fight it — give the helper a back-reference to the pipeline (`self.pipe = pipe` set in `attach()`) and have it write through. The win is moving 500 lines of TRT-specific lifecycle code out of the orchestrator, not pretending TRT doesn't touch pipeline state. + +**Caches the helper owns:** `_trt_taesd_paths`, `_trt_controlnet_paths`, `_trt_unet_paths`, `_trt_unet_engine`, `_trt_controlnet_engine`, the `acceleration_mode` last-applied value, and the `_trt_cache` adapter handles. The module-scope `_trt_cache._CACHE` stays where it is — it must survive plugin reinit. + +**Pipeline-side after extraction:** +```python +self.trt = TRTLifecycle(device=self.device, dtype=self.dtype) +# in _ensure_pipe_loaded / _swap_model: +self.trt.attach(self, self.sdxl) +# in __call__'s pre-inference setup: +self.trt.ensure_engines(config, want_control=...) +``` + +**Testing checkpoint after this extraction:** +1. Cold-load each model with `acceleration_mode="trt"`: SD-Turbo, SDXL-Turbo, DMD2. +2. Live-swap from SD-Turbo → SDXL-Turbo → DMD2 → SD-Turbo. Confirm no `context=None` crashes (the band-aid `_ensure_activated` in `_trt/engine.py` should still cover this; if it triggers, that's a regression in the swap teardown path). +3. Toggle ControlNet on SD1.5 + SD-Turbo while running. +4. Switch `acceleration_mode` between `none` / `xformers` / `trt` mid-stream. + +--- + +### Extraction 2: `ModelLoader` → `src/scope_streamdiffusion/model_loader.py` +**Methods to move:** +- `_load_model` +- `_load_preset` +- `_release_pipe_state` +- `_swap_model` +- `_install_sdxl_fp16_vae` +- `_set_taesd` +- `load_lora` (currently a stub — leave as-is, the LoRA plan wires it up) +- `fuse_lora` (stub — same) + +**State the helper owns:** the `MODEL_PRESETS` dict (move it to this module), last-loaded `model_id`, last-loaded preset signature, the SDXL fp16 VAE replacement state, TAESD-installed flag. + +**Compromise:** like TRT, this writes through to `self.pipe`, `self.unet`, `self.vae`, `self.text_encoder`, `self.text_encoder_2`, `self.tokenizer`, `self.tokenizer_2`, `self.scheduler`, `self.sdxl`. Use the back-reference; the goal is consolidation, not purity. + +**Order matters in `attach`/swap flow:** ModelLoader runs first, then PromptEncoder.attach, then ControlNetHandler.attach, then TRTLifecycle.attach. Document this in a comment at the top of `pipeline._ensure_pipe_loaded`. + +**Testing checkpoint:** +1. Cold load each preset. +2. Swap each direction. Verify no double-loaded models in VRAM (`nvidia-smi` while swapping). +3. Verify SDXL fp16 VAE replacement still happens on SDXL-Turbo and DMD2. +4. Verify TAESD eager and TRT both still work. + +--- + +### Extraction 3: `InferenceCore` → `src/scope_streamdiffusion/inference_core.py` +**Methods to move:** +- `_set_timesteps` +- `_initialize_noise` +- `_setup_seed_transition` +- `_slerp_noise` +- `_advance_seed_transition` +- `_cancel_seed_transition` +- `_encode_image` +- `_decode_image` +- `_add_noise` +- `_scheduler_step_batch` +- `_unet_step` +- `_predict_x0_batch` + +**State the helper owns:** `alpha_prod_t_sqrt`, `beta_prod_t_sqrt`, `c_skip`, `c_out`, `sub_timesteps_tensor`, `init_noise`, `x_t_latent_buffer`, the seed-transition fields (`_pending_seed`, `_transition_remaining`, etc.). + +**Reads (not writes) from pipeline:** `self.pipe.prompts.prompt_embeds`, `self.pipe.unet`, `self.pipe.controlnet`, `self.pipe.controlnet_input`, `self.pipe.vae`, `self.pipe.scheduler`. Pass these through the back-reference. + +**`__call__` after this extraction shrinks to roughly:** +```python +def __call__(self, **kwargs): + config = self._validate_config(kwargs) + self._prepare_runtime_state(config) + self.prompts.encode_for_frame(...) + if self.controlnet_handler: + self.controlnet_handler.update(...) + latent = self.inference.run_step(video, config) + return {"video": self.inference.to_scope_format(latent)} +``` + +**Testing checkpoint:** full smoke test — every model × (txt2img / img2img / loopback) × (eager / xformers / TRT) × (with/without negative prompt) × seed transitions. + +## Cross-cutting Rules +- **Don't change behavior.** This is a pure move. If you find a bug, note it in a comment — fix it in a separate commit after the refactor lands. +- **Commit per extraction.** Three commits. Each must pass the testing checkpoint before moving to the next. +- **Don't extract `__init__`, `prepare`, `_prepare_runtime_state`, `__call__`, `get_config_class`, or the schema-driven setters.** These are the orchestrator's job. +- **Don't add abstract base classes or interfaces** for the helpers. Three concrete classes is fine. +- **Don't introduce a `BaseHelper` parent class.** They share a pattern, not behavior. diff --git a/docs/plans/SDXL_CONTROLNET_PLAN.md b/docs/plans/SDXL_CONTROLNET_PLAN.md new file mode 100644 index 0000000..7ace9ba --- /dev/null +++ b/docs/plans/SDXL_CONTROLNET_PLAN.md @@ -0,0 +1,63 @@ +# Plan: SDXL ControlNet Support + +## Context +Current state: `_ensure_trt_unet` has `if want_control: if self.sdxl: raise NotImplementedError(...)`. SD1.5 ControlNet (eager + TRT) works. SDXL ControlNet works in eager mode through diffusers but the TRT path is unimplemented. + +Test target model: `diffusers/controlnet-canny-sdxl-1.0` (paired with `stabilityai/stable-diffusion-xl-base-1.0` or SDXL-Turbo). The DMD2 1-step UNet is a swap — SDXL ControlNet against DMD2 is a stretch goal; verify with the base SDXL UNet first. + +## Eager Path (verify first, may already work) +1. Confirm `ControlNetHandler.update()` correctly produces residuals for SDXL-shape inputs. SDXL UNet expects `added_cond_kwargs={"text_embeds": ..., "time_ids": ...}` — the ControlNet model also needs these. Read `diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl` for the canonical wiring. +2. In `_unet_step` (post-extraction: `InferenceCore.unet_step`), when `self.sdxl and self.controlnet`, call ControlNet with the SDXL aug-conditioning, then pass the residuals into `self.unet` along with `added_cond_kwargs`. +3. If this produces correct output, eager SDXL ControlNet is done. Move to TRT. + +## TRT Path + +### Step 1: New ONNX export wrapper +File: `src/scope_streamdiffusion/_trt/models.py`. Add `UNetSDXLWithControlInputs` modeled on the existing `UNetWithControlInputs` (SD1.5) and `UNetSDXL` (SDXL no-control). + +Inputs (in order — must match adapter feed order): +- `sample` (B, 4, H/8, W/8) +- `timestep` (scalar or (B,)) +- `encoder_hidden_states` (B, 77, 2048) +- `text_embeds` (B, 1280) ← SDXL aug +- `time_ids` (B, 6) ← SDXL aug +- `input_control_00` … `input_control_{N-1}` (down residuals) +- `input_control_middle` (mid residual) + +Output: `latent` (same shape as `sample`). + +Forward should call `self.unet(sample, timestep, encoder_hidden_states, added_cond_kwargs={"text_embeds": text_embeds, "time_ids": time_ids}, down_block_additional_residuals=[...], mid_block_additional_residual=...)`. + +### Step 2: New builder +File: `src/scope_streamdiffusion/trt_engines.py`. Add `build_unet_sdxl_with_control_engine(...)` modeled on `build_unet_with_control_engine` + `build_unet_sdxl_engine`. Use the same dynamic-shape ranges as the SDXL UNet build (512–1024). + +**Known constraints:** +- ONNX export of SDXL UNet runs ~5 GB. Use `external_data` format. The existing `build_unet_sdxl_engine` already does this — copy its handling. +- ControlNet residuals are full-resolution feature maps; this multiplies the export size by ~10–20%. Expect 6 GB ONNX. +- Static shape recommended for first cut. Generalize to dynamic only after a static build runs. +- Compile time: 5–15 minutes on a 4090. Cache aggressively. + +### Step 3: Standalone SDXL ControlNet engine +The SD1.5 path uses a separate `ControlNetEngine` (`src/scope_streamdiffusion/_trt/engine.py`) that produces residuals consumed by `UNet2DConditionModelWithControlEngine`. Mirror this for SDXL: +- Add ONNX wrapper for SDXL ControlNet to `_trt/models.py` (it has the same SDXL aug-conditioning inputs as the UNet wrapper). +- Add builder `build_controlnet_sdxl_engine` in `trt_engines.py`. +- The existing `ControlNetEngine` class in `_trt/engine.py` has hard-coded `block_out_channels=(320, 640, 1280, 1280)` for SD1.5. SDXL ControlNet uses `(320, 640, 1280)` (one fewer block) and produces 9 down residuals + 1 mid (versus 12+1 for SD1.5). Add `ControlNetSDXLEngine` or parameterize `ControlNetEngine` by passing `chans` and `spec` at construction. + +### Step 4: New runtime adapter +File: `src/scope_streamdiffusion/trt_engines.py`. Add `TRTUNetSDXLWithControlAdapter` that exposes the diffusers UNet `__call__` signature and dispatches to `UNet2DConditionModelSDXLWithControlEngine` (also new in `_trt/engine.py`) plus the SDXL ControlNet engine. + +### Step 5: Wire into TRTLifecycle +In `_ensure_trt_unet`, replace the `raise NotImplementedError` with the SDXL+control branch. Path resolution and cache key must include both UNet and ControlNet model IDs. + +### Step 6: TAESD +SDXL TAESD (`madebyollin/taesdxl`) already works via the existing `_ensure_trt_taesd` path. No change needed. + +## Testing +1. Eager SDXL + Canny ControlNet on a webcam frame. Output should track edges. +2. TRT SDXL + Canny ControlNet, same prompt. Confirm visual parity within fp16 tolerance. +3. Live-toggle ControlNet on/off mid-stream on SDXL-Turbo. +4. Swap SD-Turbo (SD1.5) ↔ SDXL-Turbo with ControlNet attached. Confirm correct adapter is selected each time. + +## Out of Scope (defer) +- Multi-ControlNet on SDXL (do single first). +- DMD2 + ControlNet (the 1-step UNet swap may not respect ControlNet residuals correctly — needs separate investigation). diff --git a/libndi-get.sh b/libndi-get.sh new file mode 100755 index 0000000..6362887 --- /dev/null +++ b/libndi-get.sh @@ -0,0 +1,103 @@ +#!/bin/bash +set -e + +# This script downloads and installs the NDI SDK for Linux. +# By default it downloads the NDI SDK v6 for Linux and extracts it to a temporary directory. +# +# Add argument "install" to install the library files to your system. +# Usage: ./libndi-get.sh install + + +LIBNDI_INSTALLER_NAME="Install_NDI_SDK_v6_Linux" +LIBNDI_INSTALLER="$LIBNDI_INSTALLER_NAME.tar.gz" +LIBNDI_INSTALLER_URL="https://downloads.ndi.tv/SDK/NDI_SDK_Linux/$LIBNDI_INSTALLER" + +# Use temporary directory +LIBNDI_TMP=$(mktemp --tmpdir -d ndidisk.XXXXXXX) + +# Check if the temp directory exists and is a directory. +if [[ -d "$LIBNDI_TMP" ]]; then + echo "Temporary directory created at $LIBNDI_TMP" +else + echo "Failed to create a temporary directory." + exit 1 +fi + +# While most of the command are with the folder path, this is needed for the libndi install script to run properly +pushd "$LIBNDI_TMP" + +# Download LIBNDI +# The follwoing should work with tmp folder in the user home directory - but not always... So we do not use it. +# curl -o "$LIBNDI_TMP/$LIBNDI_INSTALLER" $LIBNDI_INSTALLER_URL -f --retry 5 + +# The following is required if the temp directory is not in the user home directory. +curl -L "$LIBNDI_INSTALLER_URL" -f --retry 5 > "$LIBNDI_TMP/$LIBNDI_INSTALLER" + + +# Check if download was successful +if [ $? -ne 0 ]; then + echo "Download failed." + exit 1 +fi + +echo "Download complete." + +# Step 3: Uncompress the file. +echo "Uncompressing..." +tar -xzvf "$LIBNDI_TMP/$LIBNDI_INSTALLER" -C "$LIBNDI_TMP" + +# Check if uncompression was successful +if [ $? -ne 0 ]; then + echo "Uncompression failed." + exit 1 +fi + +echo "Uncompression complete." + + +yes | PAGER="cat" sh "$LIBNDI_INSTALLER_NAME.sh" + + +rm -rf "$LIBNDI_TMP/ndisdk" +echo "Moving things to a folder with no space" +mv "$LIBNDI_TMP/NDI SDK for Linux" "$LIBNDI_TMP/ndisdk" +echo +echo "Contents of $LIBNDI_TMP/ndisdk/lib:" +ls -la "$LIBNDI_TMP/ndisdk/lib" +echo +echo "Contents of $LIBNDI_TMP/ndisdk/lib/x86_64-linux-gnu:" +ls -la "$LIBNDI_TMP/ndisdk/lib/x86_64-linux-gnu" +echo + +popd + +if [ "$1" == "install" ]; then + echo "Copying the library files to the long-term location. You might be prompted for authentication." + sudo cp -P "$LIBNDI_TMP/ndisdk/lib/x86_64-linux-gnu/"* /usr/local/lib/ + sudo ldconfig + + echo "libndi installed to /usr/local/lib" + ls -la "/usr/local/lib/"libndi* + + echo "Adding backward compatibility tweaks for older plugins version to work with NDI v6" + sudo ln -s /usr/local/lib/libndi.so.6 /usr/local/lib/libndi.so.5 + + echo "Clean-up : Removing temporary folder" + rm -rf "$LIBNDI_TMP" + if [[ ! -d "$LIBNDI_TMP" ]]; then + echo "Temporary directory $LIBNDI_TMP does not exist anymore (good!)" + else + echo "Failed to clean-up temporary directory." + echo "Please clean this up manually - All should be in $LIBNDI_TMP" + exit 1 + fi + echo "Installation complete." +else + # Allow to keep the temporary files (to use with libndi-package.sh) + echo "No installation requested. The library files are in $LIBNDI_TMP/ndisdk/lib/x86_64-linux-gnu/" + echo "You can copy them manually to your system if needed." + ls -la "$LIBNDI_TMP/ndisdk/lib/x86_64-linux-gnu/libndi"* +fi + +echo "Script execution Complete." +exit 0 diff --git a/src/scope_streamdiffusion/_trt/__init__.py b/src/scope_streamdiffusion/_trt/__init__.py index 16c7d3e..725b128 100644 --- a/src/scope_streamdiffusion/_trt/__init__.py +++ b/src/scope_streamdiffusion/_trt/__init__.py @@ -24,6 +24,7 @@ AutoencoderKLEngine, ControlNetEngine, UNet2DConditionModelEngine, + UNet2DConditionModelSDXLEngine, UNet2DConditionModelWithControlEngine, ) from .models import ( @@ -33,6 +34,8 @@ ControlNetExportWrapper, UNet, UNetExportWrapperWithControl, + UNetSDXL, + UNetSDXLExportWrapper, UNetWithControlInputs, VAEEncoder, ) @@ -109,14 +112,18 @@ def compile_unet( "TorchVAEEncoder", "UNet", "UNet2DConditionModelEngine", + "UNet2DConditionModelSDXLEngine", "UNet2DConditionModelWithControlEngine", "UNetExportWrapperWithControl", + "UNetSDXL", + "UNetSDXLExportWrapper", "UNetWithControlInputs", "VAE", "VAEEncoder", "build_engine", "compile_controlnet", "compile_unet", + "compile_unet_sdxl", "compile_unet_with_control", "compile_vae_decoder", "compile_vae_encoder", @@ -126,6 +133,41 @@ def compile_unet( ] +def compile_unet_sdxl( + unet: UNet2DConditionModel, + model_data: BaseModel, + onnx_path: str, + onnx_opt_path: str, # noqa: ARG001 — kept for API symmetry; SDXL skips the polygraphy optimizer + engine_path: str, + opt_batch_size: int = 1, + engine_build_options: dict = {}, +): + """Compile an SDXL UNet to TRT — wraps text_embeds/time_ids as positional inputs. + + Differs from `compile_unet` in one important way: **the polygraphy + ONNX optimization pass is skipped**. SDXL's UNet exports to a ~5 GB + ONNX file, and the polygraphy `optimize_onnx` step runs onnxruntime + shape-inference + Unsqueeze elimination passes that load the entire + graph into RAM with ~3-5× overhead — peaks at 20-25 GB and OOMs on + a 32 GB host. TensorRT's builder does its own graph optimization + during engine construction, so the pre-pass is double-work anyway. + + Mechanism: pass the same path for both raw and "optimized" ONNX. + After export the file exists at `onnx_opt_path`, so the EngineBuilder + skips the optimize step and feeds the raw ONNX directly to TRT. + """ + wrapped = UNetSDXLExportWrapper(unet).to( + torch.device("cuda"), dtype=torch.float16 + ).eval() + builder = EngineBuilder(model_data, wrapped, device=torch.device("cuda")) + builder.build( + onnx_path, onnx_path, engine_path, # same path twice: skip polygraphy optimizer + opt_batch_size=opt_batch_size, + use_external_data=True, # SDXL UNet fp16 is ~2.6 GB → must use external-data format + **engine_build_options, + ) + + def compile_unet_with_control( unet, model_data: BaseModel, diff --git a/src/scope_streamdiffusion/_trt/builder.py b/src/scope_streamdiffusion/_trt/builder.py index b2e7154..5c78c33 100644 --- a/src/scope_streamdiffusion/_trt/builder.py +++ b/src/scope_streamdiffusion/_trt/builder.py @@ -47,6 +47,7 @@ def build( force_engine_build: bool = False, force_onnx_export: bool = False, force_onnx_optimize: bool = False, + use_external_data: bool = False, ): if not force_onnx_export and os.path.exists(onnx_path): print(f"Found cached model: {onnx_path}") @@ -58,7 +59,7 @@ def build( self.network, self.controlnet_model ) - + export_onnx( self.network, onnx_path=onnx_path, @@ -67,6 +68,7 @@ def build( opt_image_width=opt_image_width, opt_batch_size=opt_batch_size, onnx_opset=onnx_opset, + use_external_data=use_external_data, ) del self.network gc.collect() diff --git a/src/scope_streamdiffusion/_trt/engine.py b/src/scope_streamdiffusion/_trt/engine.py index 20a4f6d..640c29b 100644 --- a/src/scope_streamdiffusion/_trt/engine.py +++ b/src/scope_streamdiffusion/_trt/engine.py @@ -56,6 +56,60 @@ def forward(self, *args, **kwargs): pass +class UNet2DConditionModelSDXLEngine: + """SDXL UNet engine — adds text_embeds + time_ids inputs to the plain UNet engine.""" + + def __init__(self, filepath: str, stream: cuda.Stream, use_cuda_graph: bool = False): + self.engine = Engine(filepath) + self.stream = stream + self.use_cuda_graph = use_cuda_graph + self.engine.load() + self.engine.activate() + + def __call__( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + text_embeds: torch.Tensor, + time_ids: torch.Tensor, + **kwargs, + ) -> Any: + if timestep.dtype != torch.float32: + timestep = timestep.float() + + self.engine.allocate_buffers( + shape_dict={ + "sample": latent_model_input.shape, + "timestep": timestep.shape, + "encoder_hidden_states": encoder_hidden_states.shape, + "text_embeds": text_embeds.shape, + "time_ids": time_ids.shape, + "latent": latent_model_input.shape, + }, + device=latent_model_input.device, + ) + + noise_pred = self.engine.infer( + { + "sample": latent_model_input, + "timestep": timestep, + "encoder_hidden_states": encoder_hidden_states, + "text_embeds": text_embeds, + "time_ids": time_ids, + }, + self.stream, + use_cuda_graph=self.use_cuda_graph, + )["latent"] + return UNet2DConditionOutput(sample=noise_pred) + + def to(self, *args, **kwargs): + pass + + def forward(self, *args, **kwargs): + pass + + class UNet2DConditionModelWithControlEngine: """UNet engine variant that accepts ControlNet residuals as runtime inputs. @@ -233,7 +287,24 @@ def __init__( self.encoder.activate() self.decoder.activate() + def _ensure_activated(self, engine): + """Re-activate the TRT engine if its execution context was lost. + + The cached adapter survives across plugin reinit (module-scope + cache in ``_trt_cache``), but rapid acceleration_mode / model + swaps can leave a previously-activated engine with context=None + if its activation was torn down by a sibling teardown path. + Repairing here keeps the streaming loop alive instead of dying + on `AttributeError: 'NoneType' object has no attribute + 'set_input_shape'`. + """ + if engine.context is None: + if engine.engine is None: + engine.load() + engine.activate() + def encode(self, images: torch.Tensor, **kwargs): + self._ensure_activated(self.encoder) self.encoder.allocate_buffers( shape_dict={ "images": images.shape, @@ -254,6 +325,7 @@ def encode(self, images: torch.Tensor, **kwargs): return AutoencoderTinyOutput(latents=latents) def decode(self, latent: torch.Tensor, **kwargs): + self._ensure_activated(self.decoder) self.decoder.allocate_buffers( shape_dict={ "latent": latent.shape, diff --git a/src/scope_streamdiffusion/_trt/models.py b/src/scope_streamdiffusion/_trt/models.py index 32a8692..693f15f 100644 --- a/src/scope_streamdiffusion/_trt/models.py +++ b/src/scope_streamdiffusion/_trt/models.py @@ -484,6 +484,141 @@ def get_sample_input(self, batch_size, image_height, image_width): ) +class UNetSDXL(BaseModel): + """TRT I/O spec for the SDXL UNet. + + Adds the SDXL-specific aug-conditioning inputs that the SD1.5 UNet + doesn't have: + - text_embeds: pooled output of the 2nd text encoder (dim=1280) + - time_ids: resolution conditioning (dim=6 = orig_size[2] + + crops_top_left[2] + target_size[2]) + + Without these, SDXL UNet's `get_aug_embed` raises `TypeError: argument + of type 'NoneType' is not iterable` because it expects a dict at + `added_cond_kwargs`. + + SDXL standard config: + cross_attention_dim = 2048 + addition_time_embed_dim = 256 + projection_class_embeddings_input_dim = 2816 (= 1280 + 256*6) + """ + + def __init__( + self, + fp16=False, + device="cuda", + max_batch_size=16, + min_batch_size=1, + embedding_dim=2048, + text_maxlen=77, + unet_dim=4, + text_embeds_dim=1280, + time_ids_dim=6, + ): + super(UNetSDXL, self).__init__( + fp16=fp16, + device=device, + max_batch_size=max_batch_size, + min_batch_size=min_batch_size, + embedding_dim=embedding_dim, + text_maxlen=text_maxlen, + ) + self.unet_dim = unet_dim + self.text_embeds_dim = text_embeds_dim + self.time_ids_dim = time_ids_dim + self.name = "UNetSDXL" + + def get_input_names(self): + return ["sample", "timestep", "encoder_hidden_states", "text_embeds", "time_ids"] + + def get_output_names(self): + return ["latent"] + + def get_dynamic_axes(self): + return { + "sample": {0: "2B", 2: "H", 3: "W"}, + "timestep": {0: "2B"}, + "encoder_hidden_states": {0: "2B"}, + "text_embeds": {0: "2B"}, + "time_ids": {0: "2B"}, + "latent": {0: "2B", 2: "H", 3: "W"}, + } + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + ( + min_batch, max_batch, _, _, _, _, + min_latent_height, max_latent_height, min_latent_width, max_latent_width, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape) + return { + "sample": [ + (min_batch, self.unet_dim, min_latent_height, min_latent_width), + (batch_size, self.unet_dim, latent_height, latent_width), + (max_batch, self.unet_dim, max_latent_height, max_latent_width), + ], + "timestep": [(min_batch,), (batch_size,), (max_batch,)], + "encoder_hidden_states": [ + (min_batch, self.text_maxlen, self.embedding_dim), + (batch_size, self.text_maxlen, self.embedding_dim), + (max_batch, self.text_maxlen, self.embedding_dim), + ], + "text_embeds": [ + (min_batch, self.text_embeds_dim), + (batch_size, self.text_embeds_dim), + (max_batch, self.text_embeds_dim), + ], + "time_ids": [ + (min_batch, self.time_ids_dim), + (batch_size, self.time_ids_dim), + (max_batch, self.time_ids_dim), + ], + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), + "timestep": (2 * batch_size,), + "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim), + "text_embeds": (2 * batch_size, self.text_embeds_dim), + "time_ids": (2 * batch_size, self.time_ids_dim), + "latent": (2 * batch_size, 4, latent_height, latent_width), + } + + def get_sample_input(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + dtype = torch.float16 if self.fp16 else torch.float32 + return ( + torch.randn(2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device), + torch.ones((2 * batch_size,), dtype=torch.float32, device=self.device), + torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), + torch.randn(2 * batch_size, self.text_embeds_dim, dtype=dtype, device=self.device), + torch.randn(2 * batch_size, self.time_ids_dim, dtype=dtype, device=self.device), + ) + + +class UNetSDXLExportWrapper(torch.nn.Module): + """Wraps the SDXL UNet so text_embeds/time_ids are positional args. + + Diffusers' UNet expects them inside an `added_cond_kwargs` dict, but + ONNX export prefers positional inputs. Reconstruct the dict here. + """ + + def __init__(self, unet): + super().__init__() + self.unet = unet + + def forward(self, sample, timestep, encoder_hidden_states, text_embeds, time_ids): + out = self.unet( + sample, + timestep, + encoder_hidden_states=encoder_hidden_states, + added_cond_kwargs={"text_embeds": text_embeds, "time_ids": time_ids}, + return_dict=False, + ) + return out[0] + + class UNetExportWrapperWithControl(torch.nn.Module): """Wraps the diffusers UNet so ControlNet residuals are positional inputs. diff --git a/src/scope_streamdiffusion/_trt/utilities.py b/src/scope_streamdiffusion/_trt/utilities.py index 6e82bbc..04997e4 100644 --- a/src/scope_streamdiffusion/_trt/utilities.py +++ b/src/scope_streamdiffusion/_trt/utilities.py @@ -409,12 +409,18 @@ def export_onnx( opt_image_width: int, opt_batch_size: int, onnx_opset: int, + use_external_data: bool = False, ): with torch.inference_mode(), torch.autocast("cuda"): inputs = model_data.get_sample_input(opt_batch_size, opt_image_height, opt_image_width) print('exporting onnx') print(model_data.get_input_names()) print(model_data.get_dynamic_axes()) + # use_external_data: SDXL UNet fp16 is ~2.6 GB, exceeding protobuf's + # 2 GB single-message limit. Without external-data format, the export + # produces malformed ONNX that TRT's parser rejects with + # `Invalid Engine`. SD 1.5 fits in 2 GB so we leave the default off + # for that path to keep the existing single-file cache layout. torch.onnx.export( model, inputs, @@ -428,11 +434,54 @@ def export_onnx( dynamo=False, # force legacy trace-based exporter — dynamo path # produces ONNX with op variants that polygraphy's version_converter # can't migrate (e.g. Resize). Legacy is what prism was tested on. + external_data=use_external_data, # torch 2.9+ name (was use_external_data_format) ) del model gc.collect() torch.cuda.empty_cache() + if use_external_data: + # PyTorch's exporter writes one external file per tensor with only a + # `location` field — no explicit `offset` / `length`. Some tensors + # then trip TRT's parser: + # [E] WeightsContextMemoryMap.cpp:124: Failed to open file: ... + # Consolidate into a single weights.bin with explicit offsets so + # TRT's mmap path sees the canonical layout. Bonus: cache dir goes + # from ~1500 files to 2. + import os, shutil + onnx_dir = os.path.dirname(onnx_path) + weights_name = os.path.basename(onnx_path) + ".weights" + m = onnx.load(onnx_path, load_external_data=True) + # Stage rewrite into a sibling temp dir so we can clear the original + # sidecars cleanly without colliding with the in-progress write. + staging = os.path.join(onnx_dir, "_consolidate_staging") + if os.path.isdir(staging): + shutil.rmtree(staging) + os.makedirs(staging) + staged_onnx = os.path.join(staging, os.path.basename(onnx_path)) + onnx.save_model( + m, staged_onnx, + save_as_external_data=True, + all_tensors_to_one_file=True, + location=weights_name, + size_threshold=1024, + convert_attribute=False, + ) + # Remove the per-tensor sidecar files (everything in onnx_dir except + # the master .onnx and the new weights file we're about to move in). + for entry in os.listdir(onnx_dir): + full = os.path.join(onnx_dir, entry) + if full == staging or full == onnx_path: + continue + if os.path.isfile(full): + os.unlink(full) + # Move staged master + weights into place, then drop staging. + os.replace(staged_onnx, onnx_path) + os.replace(os.path.join(staging, weights_name), os.path.join(onnx_dir, weights_name)) + shutil.rmtree(staging) + del m + gc.collect() + def optimize_onnx( onnx_path: str, diff --git a/src/scope_streamdiffusion/_trt_cache.py b/src/scope_streamdiffusion/_trt_cache.py new file mode 100644 index 0000000..bbcfd3b --- /dev/null +++ b/src/scope_streamdiffusion/_trt_cache.py @@ -0,0 +1,83 @@ +"""Process-wide cache of built TRT adapters keyed by graph node id. + +Scope rebuilds plugin instances on every graph edit (see +`scope/src/scope/server/graph_executor.py`); a fresh `StreamDiffusionPipeline` +loses its in-memory `_trt_*_built` flags and rebuilds engines on first call, +even when the on-disk engine cache hits. Loading and binding a TRT engine +context costs ~hundreds of ms per engine, and ONNX→TRT compile costs minutes +when the disk cache misses. Both are visible stalls during graph edits. + +This module holds the built adapters at module scope so a new pipeline +instance for the same logical node can swap them straight back in without +touching the engine builder. + +Cache key: the user-supplied graph node id when Scope passes it through +`__init__` kwargs. Until that upstream change lands the plugin falls back to +`_anon_`, which is correct for the common single-SD-node setup but +collides if two SD nodes ever coexist with different engine signatures. + +Engines are tied to (model_id, image_height, image_width); changing any of +those invalidates the cached state and forces a clean rebuild. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + + +@dataclass +class CachedTRTState: + signature: tuple # (model_id, height, width) + cuda_stream: Any | None = None + unet_adapter: Any | None = None + unet_has_controlnet: bool = False + cn_adapters: dict[str, Any] = field(default_factory=dict) + taesd_adapter: Any | None = None + + +_CACHE: dict[str, CachedTRTState] = {} + + +def cache_key(node_id: str | None, model_id: str) -> str: + """Return the cache key for this pipeline instance. + + Prefers the node id (stable across graph edits, unique per logical node); + falls back to a model-id-scoped anon key for compatibility with Scope + versions that don't yet pass node_id to plugin __init__. + """ + if node_id: + return f"node:{node_id}" + return f"_anon_{model_id}" + + +def get_or_create(key: str, signature: tuple) -> tuple[CachedTRTState, bool]: + """Look up an entry; return (state, restored). + + `restored=True` means the cached signature matched and the caller should + reuse `state.*_adapter`. `restored=False` means either no entry existed or + the signature changed (engines built for different dims/model); the entry + is reset to a fresh state so callers can populate it after building. + """ + existing = _CACHE.get(key) + if existing is not None and existing.signature == signature: + return existing, True + fresh = CachedTRTState(signature=signature) + _CACHE[key] = fresh + return fresh, False + + +def peek(key: str) -> CachedTRTState | None: + return _CACHE.get(key) + + +def clear(key: str | None = None) -> None: + """Drop one entry, or the whole cache when key is None. + + Adapters hold CUDA memory; clearing here releases the only strong ref + once the previous pipeline instance is also gone. + """ + if key is None: + _CACHE.clear() + else: + _CACHE.pop(key, None) diff --git a/src/scope_streamdiffusion/controlnet.py b/src/scope_streamdiffusion/controlnet.py index a3b9de2..0303e68 100644 --- a/src/scope_streamdiffusion/controlnet.py +++ b/src/scope_streamdiffusion/controlnet.py @@ -85,6 +85,28 @@ def __init__(self, device: torch.device, dtype: torch.dtype): self.input: Optional[torch.Tensor] = None self.scale: float = 1.0 + def release(self) -> None: + """Drop all GPU-resident models and tensors held by this handler. + + Call before swapping the diffusion model — otherwise SD1.5 ControlNets, + depth-anything, and scribble weights stay resident across the swap and + contend with the new model's allocation. Caller is expected to run + ``torch.cuda.empty_cache()`` after this returns. + """ + self._controlnet_cache.clear() + self._depth_model = None + self._depth_hidden_state = None + self._last_depth_shape = None + self._depth_min_ema = None + self._depth_max_ema = None + self._prev_depth_input = None + self._scribble_model = None + self._prev_scribble_input = None + self._depth_norm_mean = None + self._depth_norm_std = None + self.model = None + self.input = None + def update( self, mode: str, diff --git a/src/scope_streamdiffusion/pipeline.py b/src/scope_streamdiffusion/pipeline.py index e59b4f6..848333f 100644 --- a/src/scope_streamdiffusion/pipeline.py +++ b/src/scope_streamdiffusion/pipeline.py @@ -9,21 +9,57 @@ DiffusionPipeline, LCMScheduler, StableDiffusionXLPipeline, + UNet2DConditionModel, ) from diffusers.image_processor import VaeImageProcessor from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import ( retrieve_latents, ) from scope.core.pipelines.interface import Pipeline, Requirements -from scope.core.pipelines.blending import EmbeddingBlender, parse_transition_config +from . import _trt_cache from .controlnet import ControlNetHandler +from .prompt_encoder import PromptEncoder, normalize_prompts from .schema import StreamDiffusionConfig if TYPE_CHECKING: from scope.core.pipelines.base_schema import BasePipelineConfig +# Curated presets — model_id strings that aren't direct HuggingFace repos but +# describe a (base, distillation) recipe. Extending this dict is how we add +# new 1-step / few-step models to the dropdown without exposing the user to +# the underlying repo plumbing. +# +# Schema currently exposes the `unet_swap` shape. Future shapes: +# "lora": (lora_repo, lora_filename) — fuse a step-LoRA at scale=1.0 onto +# the base. Works for Hyper-SD-1step / SDXL-Lightning-1step ONLY +# after `_set_timesteps` is taught about TCD / Euler schedulers +# (it currently calls LCM-specific +# `scheduler.get_scalings_for_boundary_condition_discrete`). +# "scheduler": SchedulerClass — override the LCMScheduler default in +# _swap_model. Same caveat as above re: `_set_timesteps`. +# "timesteps_override": [int, ...] — pin specific timesteps (Hyper-SD-1step +# wants [800] with TCD). +MODEL_PRESETS: Dict[str, dict] = { + "dmd2-sdxl-1step": { + "base": "stabilityai/stable-diffusion-xl-base-1.0", + # tianweiy/DMD2 ships several distilled UNet checkpoints; the + # 1-step fp16 variant is the SDXL-Turbo equivalent. + "unet_swap": ("tianweiy/DMD2", "dmd2_sdxl_1step_unet_fp16.bin"), + # DMD2 was distilled at this specific timestep — feeding it + # LCMScheduler's default 1-step pick (~979) produces noise. + "timesteps_override": [399], + # DMD2 has CFG distilled into its weights, so its single-shot + # output already looks like a guidance-shaped result. Implicit + # txt2img→img2img loopback re-applies that CFG-shape every frame + # and the chain blows up within a few iterations. Skip the + # implicit fallback; explicit image_loopback=True still works. + "implicit_loopback": False, + }, +} + + # Import or inline the helper utilities class SimilarImageFilter: """Simple similar image filter implementation.""" @@ -56,7 +92,8 @@ def get_config_class(cls) -> type["BasePipelineConfig"]: def __init__( self, device: Optional[torch.device] = None, - model_id: str = "stabilityai/sd-turbo", + model_id: Optional[str] = None, + model_id_or_path: Optional[str] = None, torch_dtype: torch.dtype = torch.float16, **kwargs, # noqa: ARG002 ) -> None: @@ -64,7 +101,9 @@ def __init__( Args: device: Torch device to use - model_id: Model ID or path to load + model_id / model_id_or_path: Model ID or path to load. The schema + field is ``model_id_or_path``; ``model_id`` is accepted as an + alias so older callers keep working. torch_dtype: Data type for tensors """ self.device = ( @@ -78,19 +117,39 @@ def __init__( self.config = kwargs.get("config") or kwargs.get("pipeline_config") print(f"Init - Config object: {self.config}") - # Load the base model - print(f"Loading model: {model_id}") - self._model_id_for_trt = model_id # cache for TRT engine path keying - self.pipe = self._load_model(model_id) - print(f"Model loaded: {self.pipe.__class__.__name__}") - - # Model components - self.text_encoder = self.pipe.text_encoder - self.unet = self.pipe.unet - self.vae = self.pipe.vae - self._full_vae = self.vae # keep reference for toggling + # The schema's field is ``model_id_or_path``. Scope's pipeline_manager + # merges schema defaults into the init kwargs by their declared name, + # so what we see at __init__ is the *schema default*, not the user's + # UI selection — that only arrives via runtime kwargs/config on the + # first __call__. To avoid a spurious "load SD-Turbo, then immediately + # swap to the user's pick" on every startup, defer the actual model + # load to ``_ensure_pipe_loaded`` (called from __call__ once we have + # the runtime selection). The init-time arg is just a tentative + # default in case nothing more authoritative shows up at runtime. + config_model = getattr(self.config, "model_id_or_path", None) if self.config else None + model_id = model_id or config_model or model_id_or_path or "stabilityai/sd-turbo" + self.model_id = model_id + preset = MODEL_PRESETS.get(model_id, {}) + self._timesteps_override = preset.get("timesteps_override") + # CFG-distilled models (DMD2, future Hyper-SD / Lightning) explode in + # the implicit txt2img→img2img loopback because each iteration re- + # applies the model's baked-in guidance shaping. Default True for + # everything else (SD-Turbo, SDXL-Turbo) since the iterative + # refinement is what gives those models their polished t2i look. + self._implicit_loopback: bool = preset.get("implicit_loopback", True) + print(f"[StreamDiffusion] Tentative model: {model_id} (load deferred to first __call__)") + + # Model-dependent attrs are populated by ``_ensure_pipe_loaded``. + self.pipe = None + self.sdxl: bool = False + self.text_encoder = None + self.unet = None + self.vae = None + self._full_vae = None # populated on load self._taesd_vae = None self._using_taesd = False + self.scheduler = None + self.image_processor = None # legacy torch.compile flag — kept so other code paths that read # `_unet_compiled` (e.g. _ensure_trt_unet's "restore eager" branch) @@ -114,7 +173,10 @@ def __init__( self._trt_eager_controlnets: dict[str, Any] = {} # mode -> diffusers ControlNetModel (fallback) self._trt_cuda_stream = None self._trt_eager_unet = None # original; kept for fallback - self._model_id_for_trt: str | None = None + # (height, width, controlnet_mode, use_taesd) of the last _setup_trt call. + # __call__ compares the current values against this and re-runs setup + # only on real divergence — otherwise the per-frame TRT block is a no-op. + self._trt_setup_signature: tuple | None = None # Read acceleration_mode at init from schema defaults / load_params. # The runtime kwargs path is unreliable because moth's 30fps param flood @@ -126,28 +188,26 @@ def __init__( if self._acceleration_mode == "trt": print(f"[TRT] acceleration_mode='trt' detected at init") - # Check if SDXL - self.sdxl: bool = type(self.pipe) is StableDiffusionXLPipeline - - # Setup scheduler - self.scheduler: LCMScheduler = LCMScheduler.from_config( - self.pipe.scheduler.config - ) + # Identify this pipeline instance for the cross-instance TRT adapter + # cache. node_id is the user-supplied graph node id from Scope and is + # stable across graph edits; until upstream Scope passes it through, + # we fall back to a model-scoped anon key (correct for a single SD + # node, the common case). + self._node_id: str | None = kwargs.get("node_id") + self._trt_cache_key: str = _trt_cache.cache_key(self._node_id, model_id) - # Setup image processor - self.image_processor: VaeImageProcessor = VaeImageProcessor( - self.pipe.vae_scale_factor - ) + # Scheduler / image_processor are model-dependent — populated by + # ``_ensure_pipe_loaded`` on the first __call__. - # Setup embedding blender for prompt weighting and interpolation - self.embedding_blender = EmbeddingBlender( - device=self.device, - dtype=self.dtype, - ) + # Prompt encoding (text-encode, blending, transitions, negative + # subtraction) lives on its own helper. ``attach()`` wires it to + # the live pipe at load time and on every model swap. Inference + # reads ``self.prompts.prompt_embeds`` / ``add_text_embeds`` / + # ``add_time_ids`` directly. + self.prompts = PromptEncoder(self.device, self.dtype) # State that will be set during runtime self.generator = torch.Generator(device=self.device) - self._previous_prompt_embeddings = None self.similar_filter = SimilarImageFilter() self.prev_image_result = None self.inference_time_ema = 0 @@ -181,21 +241,64 @@ def __init__( ) self._last_seed: int | None = None self._noise_shape: tuple | None = None # (batch_size, latent_h, latent_w) - self._prompts_key: tuple | None = None - self._cached_base_embed: torch.Tensor | None = None # (1, seq_len, hidden_dim) - # Transition state — the main embedding queue lives inside - # EmbeddingBlender; the pooled embedding (SDXL only) is interpolated - # linearly in lockstep here so `add_text_embeds` tracks the morph. - self._last_transition_id: str | None = None - self._pooled_source: torch.Tensor | None = None - self._pooled_target: torch.Tensor | None = None - self._transition_total_steps: int = 0 + # Seed transition state — when seed_transition_steps > 0, lerp + # `init_noise` from the previous seed's tensor to the new seed's + # tensor over N frames instead of hard-swapping. SDXL-Turbo / + # DMD2-1step have weaker stock_noise feedback than SD-Turbo, so + # without this seed changes read as hard cuts. + self._seed_transition_source: torch.Tensor | None = None + self._seed_transition_target: torch.Tensor | None = None + self._seed_transition_progress: int = 0 + self._seed_transition_total: int = 0 # Mode-transition tracking — detect video↔text switches without a pipeline reload self._last_mode: str | None = None - print("StreamDiffusion pipeline initialized") + # TRT setup is deferred along with the model load — engines need + # ``self.pipe.unet`` to exist. ``_ensure_pipe_loaded`` runs + # ``_setup_trt`` immediately after loading when acceleration_mode + # is 'trt', so the first frame still pays the build cost up-front + # rather than mid-stream. + + print("StreamDiffusion pipeline initialized (model load deferred)") + + def _ensure_pipe_loaded(self, model_id: str) -> None: + """Load the diffusion model and populate model-dependent state. + + Called once from the first ``__call__`` with the user's actual + ``model_id_or_path`` from runtime kwargs/config. Doing the load here + instead of in ``__init__`` avoids a wasted "load schema default, + immediately swap to user's pick" cycle, since Scope's + pipeline_manager only forwards schema defaults at __init__ time. + Subsequent runtime model changes go through ``_swap_model``. + """ + if self.pipe is not None: + return + print(f"[StreamDiffusion] Loading model: {model_id}") + self.model_id = model_id + preset = MODEL_PRESETS.get(model_id, {}) + self._timesteps_override = preset.get("timesteps_override") + self._implicit_loopback = preset.get("implicit_loopback", True) + self._trt_cache_key = _trt_cache.cache_key(self._node_id, model_id) + self.pipe = self._load_model(model_id) + print(f"[StreamDiffusion] Model loaded: {self.pipe.__class__.__name__}") + + self.sdxl = type(self.pipe) is StableDiffusionXLPipeline + if self.sdxl and self.dtype == torch.float16: + self._install_sdxl_fp16_vae() + + self.text_encoder = self.pipe.text_encoder + self.unet = self.pipe.unet + self.vae = self.pipe.vae + self._full_vae = self.vae + self._using_taesd = False + self.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config) + self.image_processor = VaeImageProcessor(self.pipe.vae_scale_factor) + self.prompts.attach(self.pipe, self.sdxl) + + if self._acceleration_mode == "trt": + self._setup_trt(**self._trt_setup_args_from_config()) def _ensure_trt_taesd(self) -> None: """Build TRT engines for the TAESD encoder + decoder, swap self.vae. @@ -208,6 +311,23 @@ def _ensure_trt_taesd(self) -> None: return if self._taesd_vae is None: return + + signature = (self.model_id, int(self.height), int(self.width)) + cache_state, restored = _trt_cache.get_or_create(self._trt_cache_key, signature) + if self._trt_cuda_stream is None and cache_state.cuda_stream is not None: + self._trt_cuda_stream = cache_state.cuda_stream + if restored and cache_state.taesd_adapter is not None: + self._trt_eager_taesd = self._taesd_vae + self._taesd_vae = cache_state.taesd_adapter + if self._using_taesd: + self.vae = cache_state.taesd_adapter + self._trt_taesd_built = True + print( + f"[TRT] TAESD adapter restored from cache (key={self._trt_cache_key})", + flush=True, + ) + return + self._trt_taesd_built = True # prevent retry on failure from .trt_engines import ( @@ -217,14 +337,16 @@ def _ensure_trt_taesd(self) -> None: ) if self._trt_cuda_stream is None: self._trt_cuda_stream = make_cuda_stream() + cache_state.cuda_stream = self._trt_cuda_stream print( "[TRT] Preparing TAESD engines — first build takes ~1 min, cached after", flush=True, ) try: + taesd_model_id = "madebyollin/taesdxl" if self.sdxl else "madebyollin/taesd" enc_path, dec_path = build_taesd_engines( self._taesd_vae, - model_id="madebyollin/taesd", + model_id=taesd_model_id, image_height=int(self.height), image_width=int(self.width), min_batch_size=1, @@ -245,6 +367,7 @@ def _ensure_trt_taesd(self) -> None: self._taesd_vae = adapter if self._using_taesd: self.vae = adapter + cache_state.taesd_adapter = adapter print(f"[TRT] TAESD engines active: enc={enc_path.name}, dec={dec_path.name}", flush=True) def _ensure_trt_controlnet(self, mode: str) -> None: @@ -265,6 +388,24 @@ def _ensure_trt_controlnet(self, mode: str) -> None: self._trt_eager_controlnets[mode] = self._cn.model self.controlnet = adapter return + + signature = (self.model_id, int(self.height), int(self.width)) + cache_state, restored = _trt_cache.get_or_create(self._trt_cache_key, signature) + if self._trt_cuda_stream is None and cache_state.cuda_stream is not None: + self._trt_cuda_stream = cache_state.cuda_stream + cached_cn = cache_state.cn_adapters.get(mode) if restored else None + if cached_cn is not None: + self._trt_eager_controlnets[mode] = self._cn.model + self._trt_cn_engines[mode] = cached_cn + self._trt_cn_built_modes.add(mode) + self.controlnet = cached_cn + print( + f"[TRT] ControlNet adapter restored from cache " + f"(mode={mode}, key={self._trt_cache_key})", + flush=True, + ) + return + self._trt_cn_built_modes.add(mode) # mark before build to prevent retry storm from .trt_engines import ( @@ -275,6 +416,7 @@ def _ensure_trt_controlnet(self, mode: str) -> None: if self._trt_cuda_stream is None: self._trt_cuda_stream = make_cuda_stream() + cache_state.cuda_stream = self._trt_cuda_stream # ControlNet ONNX export needs default attention too (same xformers # issue as the UNet path). @@ -290,7 +432,7 @@ def _ensure_trt_controlnet(self, mode: str) -> None: try: engine_path = build_controlnet_engine( self._cn.model, - model_id=self._model_id_for_trt or "stabilityai/sd-turbo", + model_id=self.model_id, controlnet_id=mode, image_height=int(self.height), image_width=int(self.width), @@ -304,9 +446,15 @@ def _ensure_trt_controlnet(self, mode: str) -> None: self._trt_eager_controlnets[mode] = self._cn.model self._trt_cn_engines[mode] = adapter self.controlnet = adapter + cache_state.cn_adapters[mode] = adapter print(f"[TRT] ControlNet engine active ({mode}): {engine_path}", flush=True) - def _ensure_trt_unet(self, controlnet_mode: str = "none") -> None: + def _ensure_trt_unet( + self, + controlnet_mode: str = "none", + image_height: int | None = None, + image_width: int | None = None, + ) -> None: """Build TRT engine for the UNet and swap self.unet to the adapter. Two variants depending on controlnet_mode: @@ -317,6 +465,12 @@ def _ensure_trt_unet(self, controlnet_mode: str = "none") -> None: residuals get silently dropped and ControlNet conditioning has no effect on the output. + ``image_height`` / ``image_width`` should be the runtime spatial + dims for this build. Falls back to ``self.height`` / ``self.width`` + when omitted, but caller should pass them explicitly because + ``_prepare_runtime_state`` (which sets ``self.{height,width}``) + normally runs *after* this method. + Engines are cached separately on disk because they have different signatures. Switching modes mid-process may trigger a rebuild. """ @@ -331,6 +485,28 @@ def _ensure_trt_unet(self, controlnet_mode: str = "none") -> None: f"(had={self._trt_unet_has_controlnet}, want={want_control}); rebuilding" ) + eff_h = int(image_height if image_height is not None else self.height) + eff_w = int(image_width if image_width is not None else self.width) + signature = (self.model_id, eff_h, eff_w) + cache_state, restored = _trt_cache.get_or_create(self._trt_cache_key, signature) + if self._trt_cuda_stream is None and cache_state.cuda_stream is not None: + self._trt_cuda_stream = cache_state.cuda_stream + if ( + restored + and cache_state.unet_adapter is not None + and cache_state.unet_has_controlnet == want_control + ): + self._trt_eager_unet = self.pipe.unet + self.unet = cache_state.unet_adapter + self._trt_unet_built = True + self._trt_unet_has_controlnet = want_control + print( + f"[TRT] UNet adapter restored from cache " + f"(want_control={want_control}, key={self._trt_cache_key})", + flush=True, + ) + return + # Set sticky flags before the build so failures don't retry every frame. self._trt_unet_built = True self._trt_unet_has_controlnet = want_control @@ -353,25 +529,39 @@ def _ensure_trt_unet(self, controlnet_mode: str = "none") -> None: from .trt_engines import ( TRTUNetAdapter, + TRTUNetSDXLAdapter, TRTUNetWithControlAdapter, build_unet_engine, + build_unet_sdxl_engine, build_unet_with_control_engine, make_cuda_stream, ) if self._trt_cuda_stream is None: self._trt_cuda_stream = make_cuda_stream() + cache_state.cuda_stream = self._trt_cuda_stream if want_control: + if self.sdxl: + # SDXL + ControlNet + TRT: not yet wired. The ControlNet + # path uses UNetWithControlInputs which assumes SD1.5 + # signature (no text_embeds/time_ids). Falling through to + # eager keeps SDXL+ControlNet working until that variant + # gets the same SDXL aug-conditioning treatment as + # build_unet_sdxl_engine. + raise NotImplementedError( + "SDXL + ControlNet + TRT not yet supported. Use " + "acceleration_mode='none' with controlnet on SDXL models." + ) print( "[TRT] Preparing UNet+ctrl engine — first build takes 5-10 min, cached after", flush=True, ) engine_path = build_unet_with_control_engine( self.pipe.unet, - model_id=self._model_id_for_trt or "stabilityai/sd-turbo", - image_height=int(self.height), - image_width=int(self.width), + model_id=self.model_id, + image_height=eff_h, + image_width=eff_w, min_batch_size=1, max_batch_size=4, num_down_residuals=12, @@ -380,7 +570,53 @@ def _ensure_trt_unet(self, controlnet_mode: str = "none") -> None: self.unet = TRTUNetWithControlAdapter( engine_path, self._trt_cuda_stream, num_down_residuals=12, ) + cache_state.unet_adapter = self.unet + cache_state.unet_has_controlnet = True print(f"[TRT] UNet+ctrl engine active: {engine_path}", flush=True) + elif self.sdxl: + print( + "[TRT] Preparing SDXL UNet engine — first build takes 5-10 min, cached after", + flush=True, + ) + # TRT's builder TACTIC_DRAM allocator can race our resident + # pipeline allocations during engine build. Free what we can + # (VAE + text encoders, ~5 GB combined) without disturbing + # the UNet — the ONNX tracer needs it on GPU. Restore after + # the build completes. + print("[TRT] Moving VAE + text encoders to CPU during build", flush=True) + cpu_components = [] + for attr in ("vae", "text_encoder", "text_encoder_2"): + comp = getattr(self.pipe, attr, None) + if comp is not None and hasattr(comp, "to"): + try: + comp.to("cpu") + cpu_components.append((attr, comp)) + except Exception as e: + print(f"[TRT] could not move {attr} to CPU: {e}", flush=True) + torch.cuda.empty_cache() + # batch=1 + static shape (set in build_unet_sdxl_engine) make + # TRT's tactic search bounded enough to fit on a 24 GB card. + # Engine is only valid at the (height, width, batch=1) profile + # it was built for; resolution changes will trigger a rebuild. + engine_path = build_unet_sdxl_engine( + self.pipe.unet, + model_id=self.model_id, + image_height=eff_h, + image_width=eff_w, + min_batch_size=1, + max_batch_size=1, + ) + print("[TRT] Restoring VAE + text encoders to GPU", flush=True) + for attr, comp in cpu_components: + try: + comp.to(self.device) + except Exception as e: + print(f"[TRT] could not restore {attr} to {self.device}: {e}", flush=True) + self._trt_eager_unet = self.pipe.unet + self.unet = TRTUNetSDXLAdapter(engine_path, self._trt_cuda_stream) + cache_state.unet_adapter = self.unet + cache_state.unet_has_controlnet = False + print(f"[TRT] SDXL UNet engine active: {engine_path}", flush=True) else: print( "[TRT] Preparing UNet engine — first build takes 5-10 min, cached after", @@ -388,28 +624,171 @@ def _ensure_trt_unet(self, controlnet_mode: str = "none") -> None: ) engine_path = build_unet_engine( self.pipe.unet, - model_id=self._model_id_for_trt or "stabilityai/sd-turbo", - image_height=int(self.height), - image_width=int(self.width), + model_id=self.model_id, + image_height=eff_h, + image_width=eff_w, min_batch_size=1, max_batch_size=4, ) self._trt_eager_unet = self.pipe.unet self.unet = TRTUNetAdapter(engine_path, self._trt_cuda_stream) + cache_state.unet_adapter = self.unet + cache_state.unet_has_controlnet = False print(f"[TRT] UNet engine active: {engine_path}", flush=True) - def _load_model(self, model_id: str) -> DiffusionPipeline: - """Load the diffusion model.""" + def _setup_trt( + self, + *, + height: int, + width: int, + controlnet_mode: str, + use_taesd: bool, + ) -> None: + """Build or attach TRT engines for the current model. + + Called at load time (``__init__`` and ``_swap_model``) so the first + frame doesn't stall on a 5-10 minute compile, and again from + ``__call__`` only when ``(height, width, controlnet_mode, use_taesd)`` + diverges from the last setup. The inner ``_ensure_trt_*`` methods + short-circuit when nothing needs to change. + """ + if self._acceleration_mode != "trt": + return try: - pipe = DiffusionPipeline.from_pretrained( - model_id, - torch_dtype=self.dtype, - variant="fp16" if self.dtype == torch.float16 else None, + self._ensure_trt_unet( + controlnet_mode, + image_height=int(height), + image_width=int(width), ) + except Exception as e: + print(f"[TRT] UNet engine swap failed, falling back to eager: {e}") + import traceback + traceback.print_exc() + if self._trt_eager_unet is not None: + self.unet = self._trt_eager_unet + if controlnet_mode in ("depth", "scribble"): + try: + self._ensure_trt_controlnet(controlnet_mode) + except Exception as e: + print( + f"[TRT] ControlNet engine swap failed for {controlnet_mode}, using eager: {e}" + ) + import traceback + traceback.print_exc() + if use_taesd: + try: + self._ensure_trt_taesd() + except Exception as e: + print(f"[TRT] TAESD engine swap failed, using eager: {e}") + import traceback + traceback.print_exc() + self._trt_setup_signature = ( + int(height), + int(width), + controlnet_mode, + bool(use_taesd), + ) + + def _reset_trt_state(self, new_model_id: str) -> None: + """Invalidate TRT sticky state so the next ``_setup_trt`` rebuilds. + + Called from ``_swap_model`` before loading the new model. Without + this, the sticky ``_trt_unet_built`` / ``_trt_taesd_built`` flags + cause subsequent ``_ensure_trt_*`` calls to short-circuit and the + new model runs eager regardless of ``acceleration_mode``. + """ + # Drop the module-scope cache entry for the previous model. Without + # this its ``unet_adapter`` / ``cn_adapters`` / ``taesd_adapter`` + # references stay live in ``_trt_cache._CACHE`` and pin engine + # memory across the swap — direct cause of OOM on a 24 GB card + # when going SD1.5 → SDXL with TRT on. + old_key = getattr(self, "_trt_cache_key", None) + if old_key: + _trt_cache.clear(old_key) + self._trt_unet_built = False + self._trt_unet_has_controlnet = False + self._trt_taesd_built = False + self._trt_eager_unet = None + self._trt_eager_taesd = None + self._trt_cn_built_modes.clear() + self._trt_cn_engines.clear() + self._trt_eager_controlnets.clear() + self._trt_cache_key = _trt_cache.cache_key(self._node_id, new_model_id) + self._trt_setup_signature = None + + def _set_acceleration_mode(self, mode: str) -> None: + """Swap between TRT-accelerated and eager modules at runtime. + + TRT engines themselves are immutable after build, but the choice of + which UNet / ControlNet / TAESD module ``self.*`` points at *can* be + flipped per frame. Cached adapters (in ``_trt_cache._CACHE`` and on + the instance) stay alive across the swap so toggling back to 'trt' + is instant after the first build. + """ + if mode not in ("none", "trt") or mode == self._acceleration_mode: + return + print( + f"[StreamDiffusion] acceleration_mode swap: " + f"{self._acceleration_mode} -> {mode}" + ) + if mode == "none": + self._deactivate_trt() + self._acceleration_mode = "none" + else: + self._acceleration_mode = "trt" + self._setup_trt(**self._trt_setup_args_from_config()) + + def _deactivate_trt(self) -> None: + """Restore eager UNet / ControlNet / TAESD; keep adapters cached. + + Resets the sticky ``_trt_*_built`` flags so a future ``_setup_trt`` + re-enters the cache-restore path and re-attaches the same adapters + without rebuilding. + """ + if self._trt_eager_unet is not None and self.unet is not self._trt_eager_unet: + self.unet = self._trt_eager_unet + if self._trt_eager_taesd is not None: + self._taesd_vae = self._trt_eager_taesd + if self._using_taesd: + self.vae = self._taesd_vae + if self._cn.model is not None: + self.controlnet = self._cn.model + self._trt_unet_built = False + self._trt_unet_has_controlnet = False + self._trt_taesd_built = False + self._trt_cn_built_modes.clear() + self._trt_setup_signature = None + + def _trt_setup_args_from_config(self) -> dict: + """Resolve _setup_trt args from self.config, with schema-default fallbacks.""" + cfg = self.config + return { + "height": int(getattr(cfg, "height", 512)) if cfg else 512, + "width": int(getattr(cfg, "width", 512)) if cfg else 512, + "controlnet_mode": getattr(cfg, "controlnet_mode", "none") if cfg else "none", + "use_taesd": bool(getattr(cfg, "use_taesd", True)) if cfg else True, + } + + def _load_model(self, model_id: str) -> DiffusionPipeline: + """Load the diffusion model. + + For HuggingFace model IDs, loads via DiffusionPipeline.from_pretrained + directly. For curated presets in MODEL_PRESETS, follows the preset's + recipe (base load + UNet swap, etc.). + """ + try: + preset = MODEL_PRESETS.get(model_id) + if preset is not None: + pipe = self._load_preset(preset) + else: + pipe = DiffusionPipeline.from_pretrained( + model_id, + torch_dtype=self.dtype, + variant="fp16" if self.dtype == torch.float16 else None, + ) pipe = pipe.to(self.device) # Enable xformers memory-efficient attention if available. - # The schema declares acceleration="xformers" but this was never called. try: pipe.enable_xformers_memory_efficient_attention() print("[StreamDiffusion] xformers memory-efficient attention enabled") @@ -421,6 +800,186 @@ def _load_model(self, model_id: str) -> DiffusionPipeline: print(f"Failed to load model {model_id}: {e}") raise + def _load_preset(self, preset: dict) -> DiffusionPipeline: + """Build a DiffusionPipeline from a MODEL_PRESETS recipe. + + Currently supports the ``unet_swap`` shape — load the base pipeline, + then override its UNet weights from a distilled checkpoint. Other + recipe shapes (LoRA fuse, scheduler override, timesteps_override) + will land alongside the `_set_timesteps` refactor needed to support + non-LCM schedulers. + """ + base = preset["base"] + print(f"[StreamDiffusion] Loading preset base: {base}") + pipe = DiffusionPipeline.from_pretrained( + base, + torch_dtype=self.dtype, + variant="fp16" if self.dtype == torch.float16 else None, + ) + + unet_swap = preset.get("unet_swap") + if unet_swap is not None: + from huggingface_hub import hf_hub_download + from huggingface_hub.utils import LocalEntryNotFoundError + + unet_repo, unet_file = unet_swap + # Probe the local cache first so we can log accurately. The unconditional + # "Downloading" print was misleading on every cached load. + try: + ckpt_path = hf_hub_download(unet_repo, unet_file, local_files_only=True) + print(f"[StreamDiffusion] Loading cached distilled UNet: {unet_repo}/{unet_file}") + except LocalEntryNotFoundError: + print(f"[StreamDiffusion] Downloading distilled UNet: {unet_repo}/{unet_file}") + ckpt_path = hf_hub_download(unet_repo, unet_file) + # Distilled-UNet repos (DMD2, SDXL-Lightning, etc.) often ship + # weights only — no config.json — because the architecture is + # identical to the base UNet. Reuse the base pipeline's UNet + # module and override its state_dict. + if unet_file.endswith(".safetensors"): + from safetensors.torch import load_file + state_dict = load_file(ckpt_path) + else: + state_dict = torch.load(ckpt_path, map_location="cpu") + pipe.unet.load_state_dict(state_dict) + print("[StreamDiffusion] Distilled UNet weights loaded") + return pipe + + return pipe + + def _release_pipe_state(self) -> None: + """Drop every GPU-resident reference owned by the pipeline. + + Called from :meth:`_swap_model` before loading the new model. + Clears module references (``unet`` / ``vae`` / ``text_encoder`` / + any TRT adapter still pinned in ``self.unet``), per-step cached + tensors, prompt-embedding caches, and the ControlNet handler's + sub-models. Caller (``_swap_model``) is expected to have already + run :meth:`_reset_trt_state` so the cache-state-held adapter + references are gone too. Finishes with a ``gc.collect`` + + ``torch.cuda.empty_cache`` so the next allocation starts clean. + """ + import gc + + # Module references — these are the big-ticket allocations. + # self.unet may be a TRT adapter that owns engine memory; nulling + # it here is what actually releases the engine. + self.unet = None + self.vae = None + self.text_encoder = None + self._taesd_vae = None + self._full_vae = None + self.controlnet = None + self.controlnet_input = None + if hasattr(self, "_cn") and self._cn is not None: + self._cn.release() + + # Cached per-step tensors. + for attr in ( + "init_noise", + "stock_noise", + "x_t_latent_buffer", + "prev_image_result", + "alpha_prod_t_sqrt", + "beta_prod_t_sqrt", + "c_skip", + "c_out", + "sub_timesteps_tensor", + "timesteps", + ): + if hasattr(self, attr): + setattr(self, attr, None) + + # Reset prompt-encoder caches (text-encoder-specific; the new + # model will have a different text encoder). + if hasattr(self, "prompts"): + self.prompts.reset_caches() + + # Seed-transition state. + self._seed_transition_source = None + self._seed_transition_target = None + + # Drop the pipeline last so any of the above that aliased its + # submodules have already been nulled. + self.pipe = None + + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + + def _swap_model(self, new_model_id: str) -> None: + """Replace the loaded model in place. + + Scope routes ``model_id_or_path`` through both the load-time path + (which would reinit the pipeline cleanly) and the runtime + ``setNodeParams`` path (which only updates kwargs and never touches + ``__init__``). When the runtime kwarg disagrees with what we loaded, + rebuild the model parts here so picking a model in the UI actually + swaps it. Stalls the frame loop while loading — same as a fresh load. + """ + print(f"[StreamDiffusion] Swapping model: {self.model_id} -> {new_model_id}") + # Reset TRT sticky state for the new model. Without this the + # ``_trt_unet_built`` / ``_trt_taesd_built`` flags from the previous + # model cause the next ``_setup_trt`` to short-circuit, leaving the + # new model running eager regardless of acceleration_mode. + self._reset_trt_state(new_model_id) + # Tear down everything the old model holds on the GPU before loading + # the new one — without this we peak at 2x model weights + engines + # and OOM on large models (SDXL UNet alone is ~5 GB fp16, plus a + # 2 GB+ TRT engine, plus VAE / text encoders / cached tensors). + self._release_pipe_state() + + self.model_id = new_model_id + preset = MODEL_PRESETS.get(new_model_id, {}) + self._timesteps_override = preset.get("timesteps_override") + self._implicit_loopback = preset.get("implicit_loopback", True) + self.pipe = self._load_model(new_model_id) + print(f"[StreamDiffusion] Model loaded: {self.pipe.__class__.__name__}") + self.sdxl = type(self.pipe) is StableDiffusionXLPipeline + if self.sdxl and self.dtype == torch.float16: + self._install_sdxl_fp16_vae() + + self.text_encoder = self.pipe.text_encoder + self.unet = self.pipe.unet + self.vae = self.pipe.vae + self._full_vae = self.vae + self._using_taesd = False + self.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config) + self.image_processor = VaeImageProcessor(self.pipe.vae_scale_factor) + self.prompts.attach(self.pipe, self.sdxl) + + # Invalidate runtime caches so the next __call__ rebuilds the + # timestep schedule and noise buffers against the new model. + # Prompt-encoder caches are reset by ``prompts.attach()`` above. + self._schedule_key = None + self._noise_shape = None + self.prev_image_result = None + self._cancel_seed_transition() + + # Build TRT engines for the new model now so the next frame doesn't stall. + if self._acceleration_mode == "trt": + self._setup_trt(**self._trt_setup_args_from_config()) + + def _install_sdxl_fp16_vae(self) -> None: + """Swap SDXL's default VAE for madebyollin/sdxl-vae-fp16-fix. + + Stability AI's SDXL VAE overflows on certain inputs in fp16 and decodes + to NaN — even from a perfectly valid UNet prediction. The community + fp16-fix VAE is a drop-in replacement with the same architecture and + quality, retuned to be numerically stable in fp16. + """ + from diffusers import AutoencoderKL + + try: + print("[StreamDiffusion] Installing madebyollin/sdxl-vae-fp16-fix") + new_vae = AutoencoderKL.from_pretrained( + "madebyollin/sdxl-vae-fp16-fix", torch_dtype=self.dtype + ).to(self.device) + self.pipe.vae = new_vae + print("[StreamDiffusion] SDXL fp16-fix VAE installed") + except Exception as e: + print(f"[StreamDiffusion] Failed to install fp16-fix VAE: {e}") + def _set_taesd(self, enabled: bool) -> None: """Switch between TAESD (fast) and full VAE decoder.""" if enabled == self._using_taesd: @@ -492,6 +1051,7 @@ def _prepare_runtime_state( do_add_noise: bool, transition: Optional[dict] = None, transition_steps: int = 0, + seed_transition_steps: int = 0, cfg_type: Literal["none", "full", "self", "initialize"] = "self", t_index_list: Optional[List[int]] = None, ): @@ -510,6 +1070,14 @@ def _prepare_runtime_state( self.latent_height = int(height // self.pipe.vae_scale_factor) self.latent_width = int(width // self.pipe.vae_scale_factor) + # --- Scheduler defaults --- + # `num_inference_steps` is the user-facing sharpness lever: more steps = + # sharper detail (SD-Turbo proper is the exception — it's distilled for + # 1 step). When the caller didn't pin a `t_index_list`, walk every step + # in the schedule so the UNet sees the full LCM timestep range. + if t_index_list is None: + t_index_list = list(range(num_inference_steps)) + # --- Cheap scalar assignments --- self.strength = strength self.guidance_scale = guidance_scale @@ -542,381 +1110,61 @@ def _prepare_runtime_state( seed_changed = seed != self._last_seed shape_changed = noise_shape != self._noise_shape or dims_changed - if seed_changed: + if shape_changed: + # Different latent shape can't be lerped against the old buffer; + # hard-reset and cancel any in-flight seed transition. self.generator.manual_seed(seed) self._last_seed = seed - - if seed_changed or shape_changed: - if self.denoising_steps_num > 1: - self.x_t_latent_buffer = torch.zeros( - ( - (self.denoising_steps_num - 1) * self.frame_bff_size, - 4, - self.latent_height, - self.latent_width, - ), - dtype=self.dtype, - device=self.device, - ) - else: - self.x_t_latent_buffer = None + self._cancel_seed_transition() + self.x_t_latent_buffer = None self._initialize_noise() self._noise_shape = noise_shape + elif seed_changed: + # Hard cut when seed_transition_steps == 0; multi-frame lerp otherwise. + self._setup_seed_transition(seed, seed_transition_steps) - # --- Prompt embeddings & transitions --- - # The key includes spatial dims for SDXL because add_time_ids depend on them. - # When an explicit transition dict is present, its target_prompts is the - # authoritative destination; keying against the incoming source prompts - # would make prompts_changed flap during/after the transition and snap - # steady state back to the source. - key_prompts = prompts - if transition is not None: - target_raw = transition.get("target_prompts") - if target_raw: - key_prompts = self._normalize_prompts(target_raw) - new_prompts_key = self._make_prompts_key( - key_prompts, prompt_interpolation_method, width, height - ) - prompts_changed = new_prompts_key != self._prompts_key - - # Hash the explicit transition dict so repeated sends don't restart it. - transition_id = self._hash_transition(transition) if transition else None - new_explicit_transition = ( - transition_id is not None and transition_id != self._last_transition_id - ) - - started_transition = False - - # Cancel any in-flight transition if a new target has arrived so we - # redirect from the current interpolated position rather than snapping - # after the old transition drains. - if self.embedding_blender.is_transitioning() and ( - new_explicit_transition - or (transition is None and transition_steps > 0 and prompts_changed) - ): - self.embedding_blender.cancel_transition() - self._finish_pooled_transition() - - # 1) Explicit transition (transition dict with target_prompts). - if new_explicit_transition and not self.embedding_blender.is_transitioning(): - transition_config = parse_transition_config(transition) - target_prompts_raw = transition.get("target_prompts", []) - if transition_config.num_steps > 0 and target_prompts_raw: - target_prompts = self._normalize_prompts(target_prompts_raw) - started_transition = self._begin_transition( - target_prompts=target_prompts, - interpolation_method=prompt_interpolation_method, - num_steps=transition_config.num_steps, - temporal_method=transition_config.temporal_interpolation_method, - width=width, - height=height, - ) - self._last_transition_id = transition_id - - # 2) Auto-transition when `prompts` changes with transition_steps > 0. - elif ( - transition is None - and transition_steps > 0 - and prompts_changed - and self._previous_prompt_embeddings is not None - and not self.embedding_blender.is_transitioning() - ): - started_transition = self._begin_transition( - target_prompts=prompts, - interpolation_method=prompt_interpolation_method, - num_steps=transition_steps, - temporal_method=prompt_interpolation_method, - width=width, - height=height, - ) - - # --- Produce prompt_embeds for this frame --- - if self.embedding_blender.is_transitioning(): - next_embedding = self.embedding_blender.get_next_embedding() - if next_embedding is not None: - self.prompt_embeds = next_embedding.repeat(self.batch_size, 1, 1) - self._advance_pooled_transition() - else: - self.prompt_embeds = self._cached_base_embed.repeat( - self.batch_size, 1, 1 - ) - self._finish_pooled_transition() - else: - # Steady state — re-encode if prompts changed and we didn't start a - # transition for it (hard cut path, e.g. transition_steps == 0). - if prompts_changed and not started_transition: - raw_embeds, _ = self._encode_prompts_array( - key_prompts, prompt_interpolation_method - ) - self._cached_base_embed = raw_embeds[0:1] - self._prompts_key = new_prompts_key - # Drop the transition-id guard once the explicit dict is gone so a - # later identical dict is treated as a fresh request. - if transition is None: - self._last_transition_id = None - self._finish_pooled_transition() - self.prompt_embeds = self._cached_base_embed.repeat(self.batch_size, 1, 1) - - # Cache embedding as source for the next transition. - self._previous_prompt_embeddings = self.prompt_embeds[0:1].detach() - - def _make_prompts_key( - self, - prompts: list[dict], - interpolation_method: str, - width: int, - height: int, - ) -> tuple: - """Identity key for a prompts payload; SDXL includes dims for add_time_ids.""" - return ( - tuple((p.get("text", ""), p.get("weight", 1.0)) for p in prompts), - interpolation_method, - (width, height) if self.sdxl else (), - ) - - @staticmethod - def _hash_transition(transition: dict) -> str: - """Stable identity for a transition dict so repeated sends don't restart it.""" - import hashlib - import json - - payload = { - "num_steps": int(transition.get("num_steps", 0) or 0), - "method": transition.get("temporal_interpolation_method", "linear"), - "target": [ - { - "text": p.get("text", "") if isinstance(p, dict) else str(p), - "weight": float(p.get("weight", 1.0)) if isinstance(p, dict) else 1.0, - } - for p in (transition.get("target_prompts") or []) - ], - } - encoded = json.dumps(payload, sort_keys=True).encode("utf-8") - return hashlib.sha1(encoded).hexdigest() + # Advance any in-flight seed transition by one frame. No-op when idle. + self._advance_seed_transition() - def _begin_transition( - self, - target_prompts: list[dict], - interpolation_method: str, - num_steps: int, - temporal_method: str, - width: int, - height: int, - ) -> bool: - """Start a temporal transition from the last emitted embedding toward - the target prompts. Eagerly advances `_cached_base_embed` and - `_prompts_key` to the target so steady state lands there when the queue - drains. Returns True if a transition was actually started. - """ - source_embedding = self._previous_prompt_embeddings - if source_embedding is None: - return False - - # Encode and blend target in main embedding space + pooled (SDXL). - target_embed, target_pooled = self._encode_prompts_array( - target_prompts, interpolation_method, apply_sdxl_conditioning=False - ) - target_embed_single = target_embed[0:1] - - # Eagerly move the steady-state cache to the target so once the queue - # drains we land on the target prompts with no bounce-back. - self._cached_base_embed = target_embed_single - self._prompts_key = self._make_prompts_key( - target_prompts, interpolation_method, width, height + # --- Prompt embeddings & transitions --- + # All prompt encoding, blending, transition handling, and SDXL aug- + # conditioning lives in the PromptEncoder helper. After this call, + # ``self.prompts.prompt_embeds`` (and add_text_embeds / add_time_ids + # for SDXL) holds the conditioning for this frame. + self.prompts.encode_for_frame( + prompts=prompts, + interpolation_method=prompt_interpolation_method, + width=width, + height=height, + batch_size=self.batch_size, + transition=transition, + transition_steps=transition_steps, ) - # Slerp is not supported here: upstream EmbeddingBlender.slerp runs - # torch.acos on the native dtype; at fp16 the [-1, 1] clamp isn't - # enough to prevent acos(1.0) → NaN at certain token positions, which - # nukes the whole conditioning tensor. Until that's fixed upstream, - # fall back to linear and warn once. - if temporal_method == "slerp": - if not getattr(self, "_slerp_fallback_warned", False): - print( - "[StreamDiffusion] slerp temporal interpolation is not " - "supported (fp16 NaN in upstream blender); falling back " - "to linear." - ) - self._slerp_fallback_warned = True - temporal_method = "linear" - - self.embedding_blender.start_transition( - source_embedding=source_embedding, - target_embedding=target_embed_single, - num_steps=num_steps, - temporal_interpolation_method=temporal_method, - ) + def _set_timesteps(self, num_inference_steps: int, strength: float): + """Set the timesteps for the diffusion process. - # Pooled interpolation runs in lockstep with the main queue for SDXL. - if self.sdxl and target_pooled is not None: - self._pooled_source = ( - self.add_text_embeds.detach().clone() - if hasattr(self, "add_text_embeds") and self.add_text_embeds is not None - else target_pooled.clone() - ) - self._pooled_target = target_pooled.clone() - self._transition_total_steps = max(1, num_steps) - else: - self._pooled_source = None - self._pooled_target = None - self._transition_total_steps = 0 - - # start_transition short-circuits when source ≈ target - # (MIN_EMBEDDING_DIFF_THRESHOLD); report accurately so the caller falls - # to steady state instead of assuming a transition is live. - if not self.embedding_blender.is_transitioning(): - self._finish_pooled_transition() - return False - return True - - def _advance_pooled_transition(self) -> None: - """Linearly interpolate `add_text_embeds` toward the target pooled. - - Uses the blender's remaining queue length to compute progress so - pooled and main embeds stay in lockstep even if start_transition - short-circuited. + Honors `MODEL_PRESETS[...]["timesteps_override"]` when present. + Distilled 1-step models (DMD2, Hyper-SD, Lightning) are trained at + a specific timestep and produce garbage at any other one — letting + LCMScheduler pick the default would feed them ~t=979 (near max + noise) where they were never trained. """ - if not self.sdxl or self._pooled_target is None: - return - if self._transition_total_steps <= 0: - return - remaining = len(self.embedding_blender._transition_queue) - done_steps = self._transition_total_steps - remaining - t = min(1.0, max(0.0, done_steps / self._transition_total_steps)) - source = ( - self._pooled_source - if self._pooled_source is not None - else self._pooled_target - ) - self.add_text_embeds = torch.lerp(source, self._pooled_target, t).to( - dtype=self.dtype, device=self.device - ) - - def _finish_pooled_transition(self) -> None: - """Snap pooled to the target and clear transition state.""" - if self.sdxl and self._pooled_target is not None: - self.add_text_embeds = self._pooled_target.to( - dtype=self.dtype, device=self.device + if self._timesteps_override is not None: + # Pin the override; still call set_timesteps so the scheduler + # internals (timestep_scaling, etc.) are populated for any + # downstream lookups. + self.scheduler.set_timesteps( + num_inference_steps, self.device, strength=strength ) - self._pooled_source = None - self._pooled_target = None - self._transition_total_steps = 0 - - @staticmethod - def _normalize_prompts(prompts: str | list[str] | list[dict]) -> list[dict]: - """Normalize prompts to list[dict] format.""" - if isinstance(prompts, str): - return [{"text": prompts, "weight": 1.0}] - if isinstance(prompts, list): - if len(prompts) == 0: - return [{"text": "", "weight": 1.0}] - # Check if it's a list of strings - if isinstance(prompts[0], str): - return [{"text": text, "weight": 1.0} for text in prompts] - # Already list[dict] - return prompts - return [{"text": str(prompts), "weight": 1.0}] - - def _encode_single_prompt( - self, prompt_text: str - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - """Encode a single prompt string to embeddings. - - Returns: - (prompt_embeds, pooled_embeds) tuple - """ - # Use diffusers' built-in encoding - encoder_output = self.pipe.encode_prompt( - prompt=prompt_text, - device=self.device, - num_images_per_prompt=1, - do_classifier_free_guidance=False, - negative_prompt=None, - ) - prompt_embeds = encoder_output[0] # [1, seq_len, hidden_dim] - pooled_embeds = encoder_output[2] if self.sdxl else None - - return prompt_embeds, pooled_embeds - - def _encode_prompts_array( - self, - prompt_items: list[dict], - interpolation_method: str = "linear", - apply_sdxl_conditioning: bool = True, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - """Encode multiple weighted prompts and blend them. - - Args: - prompt_items: List of {"text": str, "weight": float} - interpolation_method: "linear" or "slerp" - apply_sdxl_conditioning: When True (default, steady-state encode), - also updates `self.add_text_embeds` and `self.add_time_ids` - for SDXL. Set False when encoding a transition target so the - in-flight pooled/time_ids aren't overwritten mid-morph. - - Returns: - (blended_prompt_embeds, blended_pooled_embeds) tuple - """ - if not prompt_items: - prompt_items = [{"text": "", "weight": 1.0}] - - # Extract texts and weights - texts = [item.get("text", "") for item in prompt_items] - weights = [item.get("weight", 1.0) for item in prompt_items] - - # Encode each prompt - all_prompt_embeds = [] - all_pooled_embeds = [] if self.sdxl else None - - for text in texts: - prompt_embeds, pooled_embeds = self._encode_single_prompt(text) - all_prompt_embeds.append(prompt_embeds) - if self.sdxl and pooled_embeds is not None: - all_pooled_embeds.append(pooled_embeds) - - # Blend embeddings - blended_prompt_embeds = self.embedding_blender.blend( - all_prompt_embeds, - weights, - interpolation_method, - cache_result=True, - ) - - blended_pooled_embeds = None - if self.sdxl and all_pooled_embeds: - blended_pooled_embeds = self.embedding_blender.blend( - all_pooled_embeds, - weights, - interpolation_method, - cache_result=False, + self.timesteps = torch.tensor( + self._timesteps_override, device=self.device, dtype=torch.long ) - - # Handle SDXL additional embeddings (skipped for transition-target - # encoding so the live pooled/time_ids aren't overwritten mid-morph). - if apply_sdxl_conditioning and self.sdxl and blended_pooled_embeds is not None: - self.add_text_embeds = blended_pooled_embeds - original_size = (self.height, self.width) - crops_coords_top_left = (0, 0) - target_size = (self.height, self.width) - text_encoder_projection_dim = int(self.add_text_embeds.shape[-1]) - self.add_time_ids = self._get_add_time_ids( - original_size, - crops_coords_top_left, - target_size, - dtype=self.dtype, - text_encoder_projection_dim=text_encoder_projection_dim, + else: + self.scheduler.set_timesteps( + num_inference_steps, self.device, strength=strength ) - - return blended_prompt_embeds.repeat( - self.batch_size, 1, 1 - ), blended_pooled_embeds - - def _set_timesteps(self, num_inference_steps: int, strength: float): - """Set the timesteps for the diffusion process.""" - self.scheduler.set_timesteps( - num_inference_steps, self.device, strength=strength - ) - self.timesteps = self.scheduler.timesteps.to(self.device) + self.timesteps = self.scheduler.timesteps.to(self.device) # Make sub timesteps list self.sub_timesteps = [] @@ -956,16 +1204,15 @@ def _set_timesteps(self, num_inference_steps: int, strength: float): # Calculate alpha/beta values alpha_prod_t_sqrt_list = [] beta_prod_t_sqrt_list = [] + ac = self.scheduler.alphas_cumprod + last_idx = len(ac) - 1 for timestep in self.sub_timesteps: - if timestep >= len(self.scheduler.alphas_cumprod): - print( - f"Warning: timestep {timestep} is greater than the number of timesteps {len(self.scheduler.alphas_cumprod)}" - ) - continue - alpha_prod_t_sqrt = self.scheduler.alphas_cumprod[timestep].sqrt() - beta_prod_t_sqrt = (1 - self.scheduler.alphas_cumprod[timestep]).sqrt() - alpha_prod_t_sqrt_list.append(alpha_prod_t_sqrt) - beta_prod_t_sqrt_list.append(beta_prod_t_sqrt) + # Clamp into range instead of skipping — skipping would make the + # downstream .view(len(t_list), 1, 1, 1) reshape fail when any + # timestep happened to land out of range for this scheduler. + idx = min(int(timestep), last_idx) + alpha_prod_t_sqrt_list.append(ac[idx].sqrt()) + beta_prod_t_sqrt_list.append((1 - ac[idx]).sqrt()) alpha_prod_t_sqrt = ( torch.stack(alpha_prod_t_sqrt_list) @@ -999,31 +1246,86 @@ def _initialize_noise(self): self.stock_noise = torch.zeros_like(self.init_noise) - def _get_add_time_ids( - self, - original_size, - crops_coords_top_left, - target_size, - dtype, - text_encoder_projection_dim=None, - ): - """Get additional time IDs for SDXL.""" - add_time_ids = list(original_size + crops_coords_top_left + target_size) + def _setup_seed_transition(self, new_seed: int, total_steps: int) -> None: + """Begin a multi-frame lerp from the current init_noise to the new seed. + + Falls back to a hard cut (re-seed + regenerate immediately) when + ``total_steps <= 0`` or no prior ``init_noise`` exists. The first + frame after this runs at the source noise; subsequent frames lerp + toward the target via :meth:`_advance_seed_transition`. + """ + self._cancel_seed_transition() + if total_steps <= 0 or self.init_noise is None: + self.generator.manual_seed(new_seed) + self._last_seed = new_seed + self.x_t_latent_buffer = None + self._initialize_noise() + return + + self._seed_transition_source = self.init_noise.detach().clone() + self.generator.manual_seed(new_seed) + self._seed_transition_target = torch.randn( + self.init_noise.shape, + generator=self.generator, + device=self.device, + dtype=self.dtype, + ) + self._seed_transition_progress = 0 + self._seed_transition_total = total_steps + self._last_seed = new_seed + # Match the hard-cut path's stock_noise reset so the StreamDiffusion + # feedback term doesn't carry the previous seed's accumulator. + self.stock_noise = torch.zeros_like(self.init_noise) + self.x_t_latent_buffer = None - passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) - + text_encoder_projection_dim + @staticmethod + def _slerp_noise(a: torch.Tensor, b: torch.Tensor, t: float) -> torch.Tensor: + """Spherical interpolation between two noise tensors. + + Linear interpolation drops the variance of standard-normal noise to + ``(1-t)² + t²`` mid-blend (0.5 at t=0.5), which the diffusion model + renders as washed-out / blurry output. Slerp keeps the result on the + same hypersphere as the endpoints, preserving variance and producing + a perceptually smooth crossfade between scenes. + """ + a_flat = a.flatten().float() + b_flat = b.flatten().float() + a_norm = a_flat.norm() + b_norm = b_flat.norm() + cos_omega = (a_flat @ b_flat) / (a_norm * b_norm + 1e-8) + cos_omega = cos_omega.clamp(-1.0, 1.0) + omega = torch.acos(cos_omega) + sin_omega = torch.sin(omega) + # Collinear endpoints — degenerate to lerp to avoid divide-by-zero. + if sin_omega.abs() < 1e-6: + return torch.lerp(a, b, t) + w_a = torch.sin((1.0 - t) * omega) / sin_omega + w_b = torch.sin(t * omega) / sin_omega + return (w_a * a + w_b * b).to(dtype=a.dtype) + + def _advance_seed_transition(self) -> None: + """Slerp ``init_noise`` one step toward the target. No-op when idle.""" + if self._seed_transition_total <= 0: + return + self._seed_transition_progress += 1 + if self._seed_transition_progress >= self._seed_transition_total: + self.init_noise = self._seed_transition_target.clone() + self._cancel_seed_transition() + return + t = self._seed_transition_progress / self._seed_transition_total + self.init_noise = self._slerp_noise( + self._seed_transition_source, + self._seed_transition_target, + t, ) - expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features - if expected_add_embed_dim != passed_add_embed_dim: - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, " - f"but a vector of {passed_add_embed_dim} was created." - ) + def _cancel_seed_transition(self) -> None: + """Drop any in-flight seed transition without snapping init_noise.""" + self._seed_transition_source = None + self._seed_transition_target = None + self._seed_transition_progress = 0 + self._seed_transition_total = 0 - add_time_ids = torch.tensor([add_time_ids], dtype=dtype) - return add_time_ids def _encode_image( self, image_tensors: torch.Tensor, add_noise: bool = True @@ -1114,7 +1416,7 @@ def _unet_step( down_block_res_samples, mid_block_res_sample = self.controlnet( x_t_latent_plus_uc, t_list, - encoder_hidden_states=self.prompt_embeds, + encoder_hidden_states=self.prompts.prompt_embeds, controlnet_cond=cond_image, conditioning_scale=self.controlnet_conditioning_scale, return_dict=False, @@ -1123,7 +1425,7 @@ def _unet_step( model_pred = self.unet( x_t_latent_plus_uc, t_list, - encoder_hidden_states=self.prompt_embeds, + encoder_hidden_states=self.prompts.prompt_embeds, added_cond_kwargs=added_cond_kwargs, down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, @@ -1174,16 +1476,15 @@ def _predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor: if self.use_denoising_batch: t_list = self.sub_timesteps_tensor - if self.denoising_steps_num > 1: - x_t_latent = torch.cat((x_t_latent, prev_latent_batch), dim=0) - self.stock_noise = torch.cat( - (self.init_noise[0:1], self.stock_noise[:-1]), dim=0 - ) if self.sdxl: - added_cond_kwargs = { - "text_embeds": self.add_text_embeds.to(self.device), - "time_ids": self.add_time_ids.to(self.device), - } + batch = x_t_latent.shape[0] + te = self.prompts.add_text_embeds.to(self.device) + ti = self.prompts.add_time_ids.to(self.device) + if te.shape[0] != batch: + te = te[:1].expand(batch, -1) + if ti.shape[0] != batch: + ti = ti[:1].expand(batch, -1) + added_cond_kwargs = {"text_embeds": te, "time_ids": ti} x_t_latent = x_t_latent.to(self.device) t_list = t_list.to(self.device) @@ -1191,20 +1492,8 @@ def _predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor: x_t_latent, t_list, added_cond_kwargs=added_cond_kwargs ) - if self.denoising_steps_num > 1: - x_0_pred_out = x_0_pred_batch[-1].unsqueeze(0) - if self.do_add_noise: - self.x_t_latent_buffer = ( - self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1] - + self.beta_prod_t_sqrt[1:] * self.init_noise[1:] - ) - else: - self.x_t_latent_buffer = ( - self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1] - ) - else: - x_0_pred_out = x_0_pred_batch - self.x_t_latent_buffer = None + x_0_pred_out = x_0_pred_batch + self.x_t_latent_buffer = None else: self.init_noise = x_t_latent for idx, t in enumerate(self.sub_timesteps_tensor): @@ -1215,8 +1504,8 @@ def _predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor: ) if self.sdxl: added_cond_kwargs = { - "text_embeds": self.add_text_embeds.to(self.device), - "time_ids": self.add_time_ids.to(self.device), + "text_embeds": self.prompts.add_text_embeds.to(self.device), + "time_ids": self.prompts.add_time_ids.to(self.device), } x_0_pred, _model_pred = self._unet_step( x_t_latent, t, idx=idx, added_cond_kwargs=added_cond_kwargs @@ -1287,7 +1576,7 @@ def __call__(self, **kwargs) -> dict: prompts = kwargs.get("prompts", []) # Normalize to list[dict] format prompts = ( - self._normalize_prompts(prompts) + normalize_prompts(prompts) if prompts else [{"text": "", "weight": 1.0}] ) @@ -1314,19 +1603,31 @@ def get_param(key, default): # Finally use default return default + # Resolve the user's model selection from runtime kwargs/config. + # On the first call the pipe isn't loaded yet — __init__ defers the + # load specifically so we can pick the *real* selection here instead + # of the schema default that pipeline_manager hands us at __init__. + # On subsequent calls a runtime change (UI swap) routes through + # _swap_model. + requested_model = get_param("model_id_or_path", None) or self.model_id + if self.pipe is None: + self._ensure_pipe_loaded(requested_model) + elif requested_model and requested_model != self.model_id: + self._swap_model(requested_model) + # Extract all parameters with config fallback prompt_interpolation_method = get_param("prompt_interpolation_method", "linear") guidance_scale = get_param("guidance_scale", 0.0) - # SD Turbo: Use single timestep (t_index_list=[0]) but set schedule length - # This matches your working project setup - num_inference_steps = get_param("num_inference_steps", 3) + # SD-Turbo and SDXL-Turbo are both 1-step distillations. + num_inference_steps = 1 # For img2img with SD Turbo, need higher strength for visible changes # 0.5-0.7 = moderate, 0.8-0.95 = heavy transformation strength = get_param("strength", 0.9) seed = get_param("seed", 42) + seed_transition_steps = get_param("seed_transition_steps", 0) delta = get_param("delta", 1.0) width = get_param("width", 512) height = get_param("height", 512) @@ -1334,6 +1635,8 @@ def get_param(key, default): do_add_noise = get_param("do_add_noise", True) similar_image_filter_enabled = get_param("similar_image_filter_enabled", False) image_loopback = get_param("image_loopback", False) + negative_prompt = get_param("negative_prompt", "") + negative_prompt_scale = float(get_param("negative_prompt_scale", 1.0)) controlnet_mode = get_param("controlnet_mode", "none") controlnet_scale = get_param("controlnet_scale", 1.0) controlnet_temporal_smoothing = get_param("controlnet_temporal_smoothing", 0.5) @@ -1347,8 +1650,14 @@ def get_param(key, default): # at ~40 ms/call vs TAESD's ~5 ms. Big perf cliff if the param isn't # propagated from moth (e.g. queue-drop or absent from project file). use_taesd = get_param("use_taesd", True) - # acceleration_mode is locked at init (see __init__) — runtime updates - # don't change it because TRT engines can't be hot-swapped. + # acceleration_mode is hot-swappable: the engines themselves can't be + # rebuilt at runtime, but the module references (self.unet etc.) can + # flip between TRT adapters and eager modules. _set_acceleration_mode + # swaps; first 'trt' activation builds (slow), subsequent ones hit + # the cached adapters (instant). + requested_mode = get_param("acceleration_mode", self._acceleration_mode) + if requested_mode != self._acceleration_mode: + self._set_acceleration_mode(requested_mode) acceleration_mode = self._acceleration_mode # --- Safeguard: prevent invalid strength / num_inference_steps combos --- @@ -1389,33 +1698,20 @@ def get_param(key, default): self.controlnet = self._cn.model self.controlnet_input = self._cn.input - # TRT engine swap — UNet always, ControlNet additionally when active. - # Two separate engines (each <2 GB ONNX) instead of a single combined - # graph that hits TRT's cask-convolution bug. + # TRT engines are normally built at load time (in __init__ / + # _swap_model). This guard catches the residual cases where runtime + # values diverge from what was used at load — e.g. the user changes + # resolution, toggles controlnet on/off, or flips use_taesd in the + # UI. Fast no-op when nothing changed. if acceleration_mode == "trt": - try: - self._ensure_trt_unet(controlnet_mode) - except Exception as e: - print(f"[TRT] UNet engine swap failed, falling back to eager: {e}") - import traceback - traceback.print_exc() - if self._trt_eager_unet is not None: - self.unet = self._trt_eager_unet - if controlnet_mode in ("depth", "scribble"): - try: - self._ensure_trt_controlnet(controlnet_mode) - except Exception as e: - print(f"[TRT] ControlNet engine swap failed for {controlnet_mode}, using eager: {e}") - import traceback - traceback.print_exc() - # TAESD TRT — saves ~3-5 ms vs eager TAESD (which is already fast) - if use_taesd: - try: - self._ensure_trt_taesd() - except Exception as e: - print(f"[TRT] TAESD engine swap failed, using eager: {e}") - import traceback - traceback.print_exc() + sig = (int(height), int(width), controlnet_mode, bool(use_taesd)) + if sig != self._trt_setup_signature: + self._setup_trt( + height=int(height), + width=int(width), + controlnet_mode=controlnet_mode, + use_taesd=bool(use_taesd), + ) self.controlnet_conditioning_scale = self._cn.scale # Extract transition (explicit transition overrides auto-transition) @@ -1437,14 +1733,31 @@ def get_param(key, default): do_add_noise=do_add_noise, transition=transition, transition_steps=transition_steps, + seed_transition_steps=seed_transition_steps, ) + # Apply embedding-space negative subtraction *after* prompt embeds + # are settled (including any prompt transition / SDXL pooled + # update). Acts on whatever this frame's conditioning happens to + # be, which is the right thing during transitions too. + self.prompts.apply_negative_subtraction(negative_prompt, negative_prompt_scale) + frame = None - # Process input - if image_loopback or ( - (video is None or len(video) == 0) and self.prev_image_result is not None - ): + # Process input. In text-only mode (no video stream) we fall back to + # the previous frame's output as input — the implicit-loopback path. + # This is what gives txt2img its iterative refinement: frame 1 is a + # cold t2i pass and frames 2+ are img2img on the previous output, so + # SD-Turbo's single-step recovery sharpens detail across frames. + # Disabled per-model for CFG-distilled checkpoints (DMD2) where the + # baked-in guidance shaping compounds catastrophically across the + # feedback loop. Explicit image_loopback=True still wins regardless, + # so the user can force loopback on DMD2 if they want the stylized + # divergence (or for testing). + implicit_ok = self._implicit_loopback and ( + video is None or len(video) == 0 + ) and self.prev_image_result is not None + if image_loopback or implicit_ok: frame = self.prev_image_result elif video is not None and len(video) > 0: # Convert Scope tensor format to pipeline format @@ -1489,18 +1802,16 @@ def get_param(key, default): return {"video": output.permute(0, 2, 3, 1).clamp(0, 1)} input_tensor = filtered - # Encode to latent space - input_latent = self._encode_image(input_tensor) + input_latent = self._encode_image(input_tensor, add_noise=True) else: - # Text-to-image mode - input_latent = torch.randn( - (1, 4, self.latent_height, self.latent_width), - device=self.device, - dtype=self.dtype, - ) + # Text-to-image mode — use the seeded `init_noise` instead of a + # fresh unseeded randn. With a fresh randn per call, every frame + # would generate a different scene; the seeded buffer keeps the + # output stable across frames for the same seed (and lets the + # user reseed deterministically by changing `seed`). + input_latent = self.init_noise[0:1].clone() - # Run diffusion x_0_pred_out = self._predict_x0_batch(input_latent) # Decode to image space x_output = self._decode_image(x_0_pred_out).detach().clone() @@ -1558,7 +1869,6 @@ def main(): test_params = { "prompt": "A beautiful sunset over mountains", "negative_prompt": "ugly, blurry, low quality", - "num_inference_steps": 4, "guidance_scale": 0.0, "strength": 0.99, "seed": 42, @@ -1571,7 +1881,6 @@ def main(): print("\nTest parameters:") print(f" Prompt: {test_params['prompt']}") - print(f" Steps: {test_params['num_inference_steps']}") print(f" Size: {test_params['width']}x{test_params['height']}") print("\nRunning pipeline 10 times...\n") diff --git a/src/scope_streamdiffusion/prompt_encoder.py b/src/scope_streamdiffusion/prompt_encoder.py new file mode 100644 index 0000000..0ff3b1d --- /dev/null +++ b/src/scope_streamdiffusion/prompt_encoder.py @@ -0,0 +1,527 @@ +"""Prompt encoding, blending, transitions, and negative subtraction. + +Owns everything text-encoder-related so the main pipeline doesn't have to. +The pipeline holds an instance as ``self.prompts`` and calls +``encode_for_frame()`` once per ``__call__``, then optionally +``apply_negative_subtraction()``. Inference reads the produced embeds via +``self.prompts.prompt_embeds`` / ``add_text_embeds`` / ``add_time_ids``. + +Lifecycle: ``attach(pipe, sdxl)`` after a model load (or model swap) wires +us to the live pipeline and resets all text-encoder-dependent caches. +``reset_caches()`` is the lighter version called during teardown without +re-attaching. +""" + +from __future__ import annotations + +import hashlib +import json +from typing import Any, Optional + +import torch + +from scope.core.pipelines.blending import EmbeddingBlender, parse_transition_config + + +def normalize_prompts(prompts: str | list[str] | list[dict]) -> list[dict]: + """Coerce a prompts payload into ``list[{"text": str, "weight": float}]``. + + Module-level so the pipeline can call it on raw kwargs before any + PromptEncoder method that expects normalized input. + """ + if isinstance(prompts, str): + return [{"text": prompts, "weight": 1.0}] + if isinstance(prompts, list): + if len(prompts) == 0: + return [{"text": "", "weight": 1.0}] + if isinstance(prompts[0], str): + return [{"text": text, "weight": 1.0} for text in prompts] + return prompts + return [{"text": str(prompts), "weight": 1.0}] + + +class PromptEncoder: + """Per-frame prompt encoding with transitions, caching, and negative + subtraction. + + Attach to a loaded ``DiffusionPipeline`` via ``attach(pipe, sdxl)``; + re-attach on every model swap because the text encoder identity (and + hidden dim, for SDXL) changes between SD 1.5 and SDXL. + """ + + def __init__(self, device: torch.device, dtype: torch.dtype) -> None: + self.device = device + self.dtype = dtype + + # Live pipe references — set via ``attach()``. Until then the encoder + # is inert; calling encode_for_frame would raise. + self.pipe: Any = None + self.sdxl: bool = False + + self.embedding_blender = EmbeddingBlender(device=device, dtype=dtype) + + # Current-frame outputs the inference path reads. Inference accesses + # ``self.prompts.prompt_embeds`` / ``add_text_embeds`` / ``add_time_ids``. + self.prompt_embeds: Optional[torch.Tensor] = None + self.add_text_embeds: Optional[torch.Tensor] = None + self.add_time_ids: Optional[torch.Tensor] = None + + # Per-text-encoder caches. All invalidate on attach() and reset_caches(). + self._cached_base_embed: Optional[torch.Tensor] = None + self._previous_prompt_embeddings: Optional[torch.Tensor] = None + self._prompts_key: Optional[tuple] = None + + # Negative-prompt cache. + self._cached_negative_text: Optional[str] = None + self._cached_negative_embed: Optional[torch.Tensor] = None + self._cached_negative_pooled: Optional[torch.Tensor] = None + + # Pooled (SDXL) transition state — main embedding queue lives in + # ``embedding_blender``; pooled is interpolated linearly in lockstep. + self._pooled_source: Optional[torch.Tensor] = None + self._pooled_target: Optional[torch.Tensor] = None + self._transition_total_steps: int = 0 + + # Transition-id guard so repeated identical explicit transition dicts + # don't restart the transition every frame. + self._last_transition_id: Optional[str] = None + + # One-shot warning when slerp is requested for temporal interpolation + # (fp16 NaN bug in upstream blender). We fall back to linear silently + # after the first warn. + self._slerp_fallback_warned: bool = False + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def attach(self, pipe: Any, sdxl: bool) -> None: + """Wire to a loaded pipeline and reset all caches. + + Call from ``_ensure_pipe_loaded`` and ``_swap_model`` after the new + pipe is available; SD 1.5 and SDXL have different text-encoder + hidden dims, so cached embeds from the prior model would mismatch. + """ + self.pipe = pipe + self.sdxl = sdxl + self.reset_caches() + + def reset_caches(self) -> None: + """Drop every cached tensor and cancel any in-flight transition.""" + self._cached_base_embed = None + self._previous_prompt_embeddings = None + self._prompts_key = None + self._cached_negative_text = None + self._cached_negative_embed = None + self._cached_negative_pooled = None + self._pooled_source = None + self._pooled_target = None + self._transition_total_steps = 0 + self._last_transition_id = None + try: + self.embedding_blender.cancel_transition() + except Exception: + pass + + # ------------------------------------------------------------------ + # Per-frame encode + # ------------------------------------------------------------------ + + def encode_for_frame( + self, + prompts: list[dict], + interpolation_method: str, + width: int, + height: int, + batch_size: int, + transition: Optional[dict] = None, + transition_steps: int = 0, + ) -> None: + """Update ``self.prompt_embeds`` (and SDXL extras) for this frame. + + Handles: prompts-changed re-encoding, explicit transition-dict + starts, auto-transitions on prompt change, blender advance for + in-flight transitions, and pooled (SDXL) lockstep lerp. + """ + # When an explicit transition dict is present, its target_prompts is + # the authoritative destination; keying against the source prompts + # would make prompts_changed flap during/after the transition and + # snap steady state back to the source. + key_prompts = prompts + if transition is not None: + target_raw = transition.get("target_prompts") + if target_raw: + key_prompts = normalize_prompts(target_raw) + new_prompts_key = self._make_prompts_key( + key_prompts, interpolation_method, width, height + ) + prompts_changed = new_prompts_key != self._prompts_key + + transition_id = self._hash_transition(transition) if transition else None + new_explicit_transition = ( + transition_id is not None and transition_id != self._last_transition_id + ) + + started_transition = False + + # Cancel any in-flight transition if a new target has arrived so we + # redirect from the current interpolated position rather than + # snapping after the old transition drains. + if self.embedding_blender.is_transitioning() and ( + new_explicit_transition + or (transition is None and transition_steps > 0 and prompts_changed) + ): + self.embedding_blender.cancel_transition() + self._finish_pooled_transition() + + if new_explicit_transition and not self.embedding_blender.is_transitioning(): + transition_config = parse_transition_config(transition) + target_prompts_raw = transition.get("target_prompts", []) + if transition_config.num_steps > 0 and target_prompts_raw: + target_prompts = normalize_prompts(target_prompts_raw) + started_transition = self._begin_transition( + target_prompts=target_prompts, + interpolation_method=interpolation_method, + num_steps=transition_config.num_steps, + temporal_method=transition_config.temporal_interpolation_method, + width=width, + height=height, + ) + self._last_transition_id = transition_id + elif ( + transition is None + and transition_steps > 0 + and prompts_changed + and self._previous_prompt_embeddings is not None + and not self.embedding_blender.is_transitioning() + ): + started_transition = self._begin_transition( + target_prompts=prompts, + interpolation_method=interpolation_method, + num_steps=transition_steps, + temporal_method=interpolation_method, + width=width, + height=height, + ) + + # --- Produce prompt_embeds for this frame --- + if self.embedding_blender.is_transitioning(): + next_embedding = self.embedding_blender.get_next_embedding() + if next_embedding is not None: + self.prompt_embeds = next_embedding.repeat(batch_size, 1, 1) + self._advance_pooled_transition() + else: + self.prompt_embeds = self._cached_base_embed.repeat(batch_size, 1, 1) + self._finish_pooled_transition() + else: + # Steady state — re-encode if prompts changed and we didn't start + # a transition for it (hard cut path, e.g. transition_steps == 0). + if prompts_changed and not started_transition: + raw_embeds, _ = self._encode_prompts_array( + key_prompts, + interpolation_method, + width=width, + height=height, + batch_size=batch_size, + ) + self._cached_base_embed = raw_embeds[0:1] + self._prompts_key = new_prompts_key + # Drop the transition-id guard once the explicit dict is gone so + # a later identical dict is treated as a fresh request. + if transition is None: + self._last_transition_id = None + self._finish_pooled_transition() + self.prompt_embeds = self._cached_base_embed.repeat(batch_size, 1, 1) + + # Cache embedding as source for the next transition. + self._previous_prompt_embeddings = self.prompt_embeds[0:1].detach() + + # ------------------------------------------------------------------ + # Negative-prompt subtraction (single-pass models) + # ------------------------------------------------------------------ + + def apply_negative_subtraction( + self, negative_prompt: str, negative_prompt_scale: float + ) -> None: + """Norm-preserving negative subtraction in embedding space. + + Single-pass models (Turbo, DMD2) can't use standard CFG without + doubling UNet cost. Embedding subtraction is the cheap alternative, + but raw ``pos - scale * neg`` blows up the L2 norm of each token, + knocking the conditioning out of the training distribution and + the UNet predicts pure noise. + + We do the subtraction directionally and then renormalize each + token's embedding back to the original L2 norm. Same treatment + applied to SDXL's pooled ``add_text_embeds``. ``add_time_ids`` + are positional / size-derived, not text-derived, so they stay put. + + Encoded negative is cached on text; empty text or scale 0 is a + no-op. Cache invalidates on model swap (text-encoder dim changes). + """ + if negative_prompt_scale <= 0 or not negative_prompt: + return + if self.prompt_embeds is None: + return + if ( + self._cached_negative_text != negative_prompt + or self._cached_negative_embed is None + ): + neg_embed, neg_pooled = self._encode_single_prompt(negative_prompt) + self._cached_negative_text = negative_prompt + self._cached_negative_embed = neg_embed.detach() + self._cached_negative_pooled = ( + neg_pooled.detach() if neg_pooled is not None else None + ) + + self.prompt_embeds = _norm_preserving_subtract( + self.prompt_embeds, self._cached_negative_embed, negative_prompt_scale + ) + if self.sdxl and self._cached_negative_pooled is not None and self.add_text_embeds is not None: + self.add_text_embeds = _norm_preserving_subtract( + self.add_text_embeds, + self._cached_negative_pooled, + negative_prompt_scale, + ) + + # ------------------------------------------------------------------ + # Internals + # ------------------------------------------------------------------ + + def _make_prompts_key( + self, + prompts: list[dict], + interpolation_method: str, + width: int, + height: int, + ) -> tuple: + return ( + tuple((p.get("text", ""), p.get("weight", 1.0)) for p in prompts), + interpolation_method, + (width, height) if self.sdxl else (), + ) + + @staticmethod + def _hash_transition(transition: dict) -> str: + payload = { + "num_steps": int(transition.get("num_steps", 0) or 0), + "method": transition.get("temporal_interpolation_method", "linear"), + "target": [ + { + "text": p.get("text", "") if isinstance(p, dict) else str(p), + "weight": float(p.get("weight", 1.0)) if isinstance(p, dict) else 1.0, + } + for p in (transition.get("target_prompts") or []) + ], + } + encoded = json.dumps(payload, sort_keys=True).encode("utf-8") + return hashlib.sha1(encoded).hexdigest() + + def _begin_transition( + self, + target_prompts: list[dict], + interpolation_method: str, + num_steps: int, + temporal_method: str, + width: int, + height: int, + ) -> bool: + source_embedding = self._previous_prompt_embeddings + if source_embedding is None: + return False + + target_embed, target_pooled = self._encode_prompts_array( + target_prompts, + interpolation_method, + apply_sdxl_conditioning=False, + width=width, + height=height, + batch_size=1, + ) + target_embed_single = target_embed[0:1] + + # Eagerly move steady-state cache to the target so once the queue + # drains we land on the target prompts with no bounce-back. + self._cached_base_embed = target_embed_single + self._prompts_key = self._make_prompts_key( + target_prompts, interpolation_method, width, height + ) + + # Slerp NaNs at fp16 in the upstream blender (acos at the [-1, 1] + # boundary) — fall back to linear with a one-shot warn. + if temporal_method == "slerp": + if not self._slerp_fallback_warned: + print( + "[StreamDiffusion] slerp temporal interpolation is not " + "supported (fp16 NaN in upstream blender); falling back " + "to linear." + ) + self._slerp_fallback_warned = True + temporal_method = "linear" + + self.embedding_blender.start_transition( + source_embedding=source_embedding, + target_embedding=target_embed_single, + num_steps=num_steps, + temporal_interpolation_method=temporal_method, + ) + + if self.sdxl and target_pooled is not None: + self._pooled_source = ( + self.add_text_embeds.detach().clone() + if self.add_text_embeds is not None + else target_pooled.clone() + ) + self._pooled_target = target_pooled.clone() + self._transition_total_steps = max(1, num_steps) + else: + self._pooled_source = None + self._pooled_target = None + self._transition_total_steps = 0 + + # start_transition short-circuits when source ≈ target + # (MIN_EMBEDDING_DIFF_THRESHOLD); report accurately so the caller + # falls to steady state instead of assuming a transition is live. + if not self.embedding_blender.is_transitioning(): + self._finish_pooled_transition() + return False + return True + + def _advance_pooled_transition(self) -> None: + """Linearly interpolate ``add_text_embeds`` toward the target pooled.""" + if not self.sdxl or self._pooled_target is None: + return + if self._transition_total_steps <= 0: + return + remaining = len(self.embedding_blender._transition_queue) + done_steps = self._transition_total_steps - remaining + t = min(1.0, max(0.0, done_steps / self._transition_total_steps)) + source = ( + self._pooled_source + if self._pooled_source is not None + else self._pooled_target + ) + self.add_text_embeds = torch.lerp(source, self._pooled_target, t).to( + dtype=self.dtype, device=self.device + ) + + def _finish_pooled_transition(self) -> None: + """Snap pooled to the target and clear transition state.""" + if self.sdxl and self._pooled_target is not None: + self.add_text_embeds = self._pooled_target.to( + dtype=self.dtype, device=self.device + ) + self._pooled_source = None + self._pooled_target = None + self._transition_total_steps = 0 + + def _encode_single_prompt( + self, prompt_text: str + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + encoder_output = self.pipe.encode_prompt( + prompt=prompt_text, + device=self.device, + num_images_per_prompt=1, + do_classifier_free_guidance=False, + negative_prompt=None, + ) + prompt_embeds = encoder_output[0] + pooled_embeds = encoder_output[2] if self.sdxl else None + return prompt_embeds, pooled_embeds + + def _encode_prompts_array( + self, + prompt_items: list[dict], + interpolation_method: str, + *, + width: int, + height: int, + batch_size: int, + apply_sdxl_conditioning: bool = True, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + if not prompt_items: + prompt_items = [{"text": "", "weight": 1.0}] + + texts = [item.get("text", "") for item in prompt_items] + weights = [item.get("weight", 1.0) for item in prompt_items] + + all_prompt_embeds = [] + all_pooled_embeds = [] if self.sdxl else None + + for text in texts: + prompt_embeds, pooled_embeds = self._encode_single_prompt(text) + all_prompt_embeds.append(prompt_embeds) + if self.sdxl and pooled_embeds is not None: + all_pooled_embeds.append(pooled_embeds) + + blended_prompt_embeds = self.embedding_blender.blend( + all_prompt_embeds, + weights, + interpolation_method, + cache_result=True, + ) + + blended_pooled_embeds = None + if self.sdxl and all_pooled_embeds: + blended_pooled_embeds = self.embedding_blender.blend( + all_pooled_embeds, + weights, + interpolation_method, + cache_result=False, + ) + + # SDXL aug-conditioning: write add_text_embeds and add_time_ids for + # the steady-state encode. Skipped for transition-target encodes so + # the in-flight pooled / time_ids aren't overwritten mid-morph. + if apply_sdxl_conditioning and self.sdxl and blended_pooled_embeds is not None: + self.add_text_embeds = blended_pooled_embeds + self.add_time_ids = self._compute_add_time_ids( + width=width, height=height, dtype=self.dtype + ) + + return blended_prompt_embeds.repeat(batch_size, 1, 1), blended_pooled_embeds + + def _compute_add_time_ids( + self, width: int, height: int, dtype: torch.dtype + ) -> torch.Tensor: + """Build SDXL aug-conditioning time_ids from the current dims. + + Reads ``self.pipe.unet.config.addition_time_embed_dim`` and + ``self.pipe.unet.add_embedding.linear_1.in_features`` to validate + the vector length matches what the UNet expects. Raises if not. + """ + original_size = (height, width) + crops_coords_top_left = (0, 0) + target_size = (height, width) + text_encoder_projection_dim = int(self.add_text_embeds.shape[-1]) + + add_time_ids_list = list(original_size + crops_coords_top_left + target_size) + unet = self.pipe.unet + passed_add_embed_dim = ( + unet.config.addition_time_embed_dim * len(add_time_ids_list) + + text_encoder_projection_dim + ) + expected_add_embed_dim = unet.add_embedding.linear_1.in_features + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length " + f"{expected_add_embed_dim}, but a vector of " + f"{passed_add_embed_dim} was created." + ) + return torch.tensor([add_time_ids_list], dtype=dtype) + + +def _norm_preserving_subtract( + positive: torch.Tensor, negative: torch.Tensor, scale: float +) -> torch.Tensor: + """Subtract ``scale * negative`` then rescale to match positive's + original per-row L2 norm. Direction shifts, magnitude is preserved, + UNet stays inside training distribution. + """ + neg = negative.to(device=positive.device, dtype=positive.dtype) + if neg.shape[0] != positive.shape[0]: + neg = neg[:1].expand_as(positive) + orig_norm = positive.norm(dim=-1, keepdim=True) + shifted = positive - scale * neg + new_norm = shifted.norm(dim=-1, keepdim=True).clamp(min=1e-6) + return shifted * (orig_norm / new_norm) diff --git a/src/scope_streamdiffusion/schema.py b/src/scope_streamdiffusion/schema.py index d872037..9051f0a 100644 --- a/src/scope_streamdiffusion/schema.py +++ b/src/scope_streamdiffusion/schema.py @@ -1,8 +1,9 @@ """Configuration schema for StreamDiffusion pipeline.""" +from enum import IntEnum, StrEnum from typing import Literal -from pydantic import Field +from pydantic import Field, field_validator from scope.core.pipelines.base_schema import ( BasePipelineConfig, InputMode, @@ -11,6 +12,37 @@ ) +class ModelId(StrEnum): + """Supported StreamDiffusion models (all 1-step distillations).""" + + SD_TURBO = "stabilityai/sd-turbo" + SDXL_TURBO = "stabilityai/sdxl-turbo" + DMD2_SDXL_1STEP = "dmd2-sdxl-1step" + + +class Resolution(IntEnum): + """Allowed pixel dimensions for width/height. + + Multiples of 64 in [256, 1024]. UNet downsamples latents 3x and ControlNet + residuals land at /8 in latent space, so pixel dims must divide by 64. TRT + engines are built for this dynamic range. + """ + + R256 = 256 + R320 = 320 + R384 = 384 + R448 = 448 + R512 = 512 + R576 = 576 + R640 = 640 + R704 = 704 + R768 = 768 + R832 = 832 + R896 = 896 + R960 = 960 + R1024 = 1024 + + class StreamDiffusionConfig(BasePipelineConfig): """Configuration for the StreamDiffusion pipeline.""" @@ -44,32 +76,45 @@ class StreamDiffusionConfig(BasePipelineConfig): enabled: bool = Field( default=True, description="Enable pipeline processing. When disabled, input video is passed through unchanged.", - json_schema_extra=ui_field_config(order=0, label="Enabled"), + #json_schema_extra=ui_field_config(order=0, label="Enabled"), ) input_mode: InputMode = Field( default="text", description="Input mode: 'text' generates from prompts only, 'video' transforms input frames", - json_schema_extra=ui_field_config(order=1, label="Input Mode"), + #json_schema_extra=ui_field_config(order=1, label="Input Mode"), ) # ======================================== # Model Configuration # ======================================== - model_id_or_path: str = Field( - default="stabilityai/sd-turbo", - description="Model ID from HuggingFace or local path to model", + model_id_or_path: ModelId = Field( + default=ModelId.SD_TURBO, + description=( + "Model selection. All entries are 1-step distillations. " + "'dmd2-sdxl-1step' is SDXL-base with the DMD2 distilled UNet " + "(tianweiy/DMD2) swapped in — quality bump over SDXL-Turbo per " + "the DMD2 paper. SDXL-derived entries auto-install the fp16-fix VAE." + ), + json_schema_extra=ui_field_config(order=8, label="Model"), ) acceleration_mode: Literal["none", "trt"] = Field( default="trt", description=( - "TRT-compile UNet (and ControlNet) for ~2-3x denoising speedup. " - "First build per (model, batch range) takes 5-10 min and caches to " - "~/.cache/scope-streamdiffusion-trt/. Set at session start; changing " - "requires pipeline reload. Engines support dynamic resolution 256-1024 " - "and batch 1-4." + "TRT-compile UNet (and ControlNet on SD 1.5) for 2-8x denoising " + "speedup. First build per model takes 5-10 min and caches to " + "~/.cache/scope-streamdiffusion-trt/. Hot-swappable at runtime: " + "toggling restores cached engines from process-scope cache " + "(instant) or builds them on first activation (stalls the " + "stream). SD 1.5 engines support dynamic resolution 256-1024 " + "and batch 1-4. SDXL engines (sdxl-turbo, dmd2-sdxl-1step) " + "support dynamic resolution 512-1024 with static batch=1 — " + "different envelope to fit a 24 GB VRAM build budget. SDXL + " + "ControlNet + TRT is not yet supported (raises " + "NotImplementedError); use acceleration_mode='none' with " + "controlnet on SDXL until that lands." ), json_schema_extra=ui_field_config(order=2, label="Acceleration"), ) @@ -77,7 +122,7 @@ class StreamDiffusionConfig(BasePipelineConfig): use_taesd: bool = Field( default=True, description="Use Tiny AutoEncoder (TAESD) for ~10x faster VAE decoding at slight quality cost", - json_schema_extra=ui_field_config(order=2, label="Use TAESD"), + #json_schema_extra=ui_field_config(order=2, label="Use TAESD"), ) controlnet_mode: Literal["none", "depth", "scribble"] = Field( @@ -99,7 +144,7 @@ class StreamDiffusionConfig(BasePipelineConfig): ge=0.0, le=1, description="Minimum depth value for ControlNet", - json_schema_extra=ui_field_config(order=5, label="Depth Min"), + #json_schema_extra=ui_field_config(order=5, label="Depth Min"), ) depth_max: float = Field( @@ -107,27 +152,27 @@ class StreamDiffusionConfig(BasePipelineConfig): ge=0.0, le=1, description="Maximum depth value for ControlNet", - json_schema_extra=ui_field_config(order=6, label="Depth Max"), + #json_schema_extra=ui_field_config(order=6, label="Depth Max"), ) depth_skip_interval: int = Field( - default=3, + default=2, ge=1, le=10, description="Run depth model every Nth frame; reuse cached depth map on intermediate frames. Higher = less GPU cost, more temporal lag.", - json_schema_extra=ui_field_config(order=7, label="Depth Skip Interval"), + #json_schema_extra=ui_field_config(order=7, label="Depth Skip Interval"), ) depth_input_size: Literal[252, 364, 518] = Field( - default=518, + default=252, description="Resolution the depth model runs at (must be multiple of 14). Lower = faster but coarser depth. 252 ≈ 4× faster than 518; the depth map is bilinear-upsampled to controlnet resolution either way.", - json_schema_extra=ui_field_config(order=8, label="Depth Input Size"), + #json_schema_extra=ui_field_config(order=8, label="Depth Input Size"), ) depth_temporal_cache: bool = Field( default=True, description="Use the video model's temporal hidden-state cache for inter-frame consistency. Disabling skips the temporal motion modules entirely (faster, slightly more flicker). Combined with skip interval > 1 the cache buys little, so toggle off for speed.", - json_schema_extra=ui_field_config(order=9, label="Depth Temporal Cache"), + #json_schema_extra=ui_field_config(order=9, label="Depth Temporal Cache"), ) controlnet_temporal_smoothing: float = Field( @@ -135,7 +180,7 @@ class StreamDiffusionConfig(BasePipelineConfig): ge=0.0, le=1.0, description="Temporal blending of the ControlNet conditioning map. 0.0 = fully smoothed (previous frame only), 1.0 = no smoothing (current frame only). Lower values reduce flicker; higher values reduce latency.", - json_schema_extra=ui_field_config(order=5, label="ControlNet Smoothing"), + #json_schema_extra=ui_field_config(order=5, label="ControlNet Smoothing"), ) # ======================================== @@ -146,14 +191,22 @@ class StreamDiffusionConfig(BasePipelineConfig): negative_prompt: str = Field( default="", description="Negative prompt — what to avoid in the generated image", - json_schema_extra=ui_field_config(order=11, label="Negative Prompt"), + #json_schema_extra=ui_field_config(order=11, label="Negative Prompt"), ) negative_prompt_scale: float = Field( - default=1.0, + default=0.5, ge=0.0, le=2.0, - description="Strength of embedding-space negative guidance (used when guidance_scale=0). Subtracts the negative prompt embedding from the positive. 0 = disabled, 1 = full subtraction.", + description=( + "Strength of embedding-space negative guidance for single-pass " + "models (Turbo, DMD2) that can't use standard CFG. The negative " + "embedding is subtracted from the positive then rescaled to " + "preserve magnitude — direction shifts but the result stays in " + "the UNet's training distribution. 0 = disabled. 0.3-0.7 is " + "typical; >1.0 starts to push out of distribution and outputs " + "may degrade or go to noise." + ), json_schema_extra=ui_field_config(order=12, label="Negative Scale"), ) @@ -171,7 +224,7 @@ class StreamDiffusionConfig(BasePipelineConfig): "0 = hard cut (can cause garbage frames); 8-30 is typical for smooth " "prompt morphs. Ignored when an explicit transition dict is sent." ), - json_schema_extra=ui_field_config(order=10, label="Transition Steps"), + #json_schema_extra=ui_field_config(order=10, label="Transition Steps"), ) seed: int = Field( @@ -182,6 +235,19 @@ class StreamDiffusionConfig(BasePipelineConfig): json_schema_extra=ui_field_config(order=13, label="Seed"), ) + seed_transition_steps: int = Field( + default=0, + ge=0, + le=240, + description=( + "Lerp the seed noise toward the new seed over N frames on each " + "seed change. 0 = hard cut. SDXL-Turbo and DMD2-1step have less " + "natural frame-to-frame correlation than SD-Turbo; this gives a " + "deterministic settle independent of the model." + ), + json_schema_extra=ui_field_config(order=14, label="Seed Transition Steps"), + ) + # ======================================== # Diffusion Parameters # ======================================== @@ -194,14 +260,6 @@ class StreamDiffusionConfig(BasePipelineConfig): # json_schema_extra=ui_field_config(order=20, label="Guidance Scale"), ) - num_inference_steps: int = Field( - default=2, - ge=1, - le=50, - description="Number of denoising steps", - # json_schema_extra=ui_field_config(order=21, label="Inference Steps"), - ) - strength: float = Field( default=0.99, ge=0.0, @@ -265,7 +323,7 @@ class StreamDiffusionConfig(BasePipelineConfig): image_loopback: bool = Field( default=False, description="Use last frame as input for the next generation", - json_schema_extra=ui_field_config(order=49, label="Image Loopback"), + #json_schema_extra=ui_field_config(order=49, label="Image Loopback"), ) # ======================================== @@ -279,7 +337,7 @@ class StreamDiffusionConfig(BasePipelineConfig): "mask. SD output goes where mask=1, original goes where mask=0. " "Flip directions by toggling the upstream segmenter's Invert Mask." ), - json_schema_extra=ui_field_config(order=55, label="Mask Compositing"), + #json_schema_extra=ui_field_config(order=55, label="Mask Compositing"), ) mask_feather: float = Field( @@ -290,7 +348,7 @@ class StreamDiffusionConfig(BasePipelineConfig): "Soft mask edges (pixels). 0 = hard edge. Cheap box-blur applied " "to the mask before compositing." ), - json_schema_extra=ui_field_config(order=56, label="Mask Feather"), + #json_schema_extra=ui_field_config(order=56, label="Mask Feather"), ) mask_strength: float = Field( @@ -301,24 +359,29 @@ class StreamDiffusionConfig(BasePipelineConfig): "Overall mask blend strength. 0 disables compositing, 1 is full effect. " "Use intermediate values to ghost the original through the SD output." ), - json_schema_extra=ui_field_config(order=57, label="Mask Strength"), + #json_schema_extra=ui_field_config(order=57, label="Mask Strength"), ) - # Resolution settings — must be a multiple of 64 (UNet downsamples latents - # 3x; ControlNet residuals go to /8 in latent space, so pixel dim has to - # divide by 64). TRT engines are built for the 256-1024 dynamic range. - width: Literal[ - 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024 - ] = Field( - default=512, + width: Resolution = Field( + default=Resolution.R512, description="Output width (multiple of 64, 256-1024)", - json_schema_extra=ui_field_config(order=60, label="Width"), + #json_schema_extra=ui_field_config(order=60, label="Width"), ) - height: Literal[ - 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024 - ] = Field( - default=512, + height: Resolution = Field( + default=Resolution.R512, description="Output height (multiple of 64, 256-1024)", - json_schema_extra=ui_field_config(order=61, label="Height"), + #json_schema_extra=ui_field_config(order=61, label="Height"), ) + + @field_validator("width", "height", mode="before") + @classmethod + def _validate_resolution(cls, v: object) -> Resolution: + try: + return Resolution(int(v)) # type: ignore[arg-type] + except (ValueError, TypeError) as e: + allowed = ", ".join(str(r.value) for r in Resolution) + raise ValueError( + f"Resolution must be one of: {allowed} (multiples of 64 in [256, 1024]); got {v!r}" + ) from e + diff --git a/src/scope_streamdiffusion/trt_engines.py b/src/scope_streamdiffusion/trt_engines.py index 74d81d7..5ed74db 100644 --- a/src/scope_streamdiffusion/trt_engines.py +++ b/src/scope_streamdiffusion/trt_engines.py @@ -369,6 +369,147 @@ def build_unet_engine( return engine_path +def build_unet_sdxl_engine( + unet: UNet2DConditionModel, + *, + model_id: str, + image_height: int = 1024, + image_width: int = 1024, + min_batch_size: int = 1, + max_batch_size: int = 1, + min_image_resolution: int = 512, + max_image_resolution: int = 1024, +) -> Path: + """Build (or reuse) a TRT engine for an SDXL UNet. + + Differs from `build_unet_engine` only in the I/O spec — adds + `text_embeds` and `time_ids` as engine inputs so SDXL's `get_aug_embed` + has the kwargs it expects. Without these the ONNX export crashes with + `TypeError: argument of type 'NoneType' is not iterable`. + + Dynamic-shape build over [min_image_resolution, max_image_resolution] + on both axes — runtime can pick any resolution in that range without + triggering a rebuild. The opt point (image_height, image_width) is + where TRT's tactic selection is centered; runs at the opt size are + fastest, runs at min/max get slightly suboptimal tactics. + + Default range 512–1024 is the tightest envelope that covers SDXL's + sweet spot. Wider ranges (256–1024) blow past the builder's memory + budget on 24 GB cards. Static batch (max=1) is kept because + guidance_scale=0 (default for Turbo / DMD2) means inference never + uses batch>1; allowing batch>1 doubles the workspace. + """ + from ._trt import UNetSDXL, compile_unet_sdxl, create_onnx_path + + suffix = ( + f"unet_sdxl_b{min_batch_size}-{max_batch_size}_" + f"h{min_image_resolution}-{max_image_resolution}_" + f"w{min_image_resolution}-{max_image_resolution}" + ) + cache_dir = _model_cache_dir(model_id, suffix) + onnx_dir = cache_dir / "onnx" + onnx_dir.mkdir(parents=True, exist_ok=True) + engine_path = cache_dir / "unet_sdxl.engine" + + if engine_path.exists(): + logger.info(f"[TRT] Reusing cached SDXL UNet engine: {engine_path}") + return engine_path + + logger.info(f"[TRT] Building SDXL UNet engine -> {engine_path} (5-10 min on first build)") + + unet_model = UNetSDXL( + fp16=True, + device=str(unet.device) if unet.device.type != "meta" else "cuda", + max_batch_size=max_batch_size, + min_batch_size=min_batch_size, + embedding_dim=unet.config.cross_attention_dim, + unet_dim=unet.config.in_channels, + ) + compile_unet_sdxl( + unet, unet_model, + str(create_onnx_path("unet_sdxl", str(onnx_dir), opt=False)), + str(create_onnx_path("unet_sdxl", str(onnx_dir), opt=True)), + str(engine_path), + opt_batch_size=min_batch_size, + engine_build_options={ + "build_dynamic_shape": True, + "build_static_batch": True, + # opt point — TRT's tactic selection is centered here. + "opt_image_height": image_height, + "opt_image_width": image_width, + # Min/max bounds for the dynamic shape envelope. + "min_image_resolution": min_image_resolution, + "max_image_resolution": max_image_resolution, + }, + ) + import shutil + if onnx_dir.exists(): + shutil.rmtree(onnx_dir, ignore_errors=True) + logger.info(f"[TRT] SDXL UNet engine built: {engine_path}") + return engine_path + + +class TRTUNetSDXLAdapter: + """Drop-in for diffusers SDXL UNet — accepts added_cond_kwargs.""" + + def __init__(self, engine_path: Path, cuda_stream, *, use_cuda_graph: bool = False): + from ._trt import UNet2DConditionModelSDXLEngine + self.engine = UNet2DConditionModelSDXLEngine( + str(engine_path), cuda_stream, use_cuda_graph=use_cuda_graph, + ) + self._use_cuda_graph = use_cuda_graph + self.config = _ConfigShim(sdxl=True) + # SDXL pipelines read `unet.add_embedding.linear_1.in_features` to + # size the add_time_ids tensor (must equal the original UNet's + # projection_class_embeddings_input_dim — 2816 for stock SDXL = + # 1280 text_embeds + 6 * 256 addition_time_embed_dim). + class _AddEmbeddingShim: + class _Linear1Shim: + in_features = 2816 + linear_1 = _Linear1Shim() + self.add_embedding = _AddEmbeddingShim() + + def __call__( + self, + sample: torch.Tensor, + timestep, + encoder_hidden_states: torch.Tensor, + added_cond_kwargs: dict | None = None, + return_dict: bool = True, + **kwargs, + ): + if not isinstance(timestep, torch.Tensor): + timestep = torch.tensor(timestep, device=sample.device) + if timestep.ndim == 0: + timestep = timestep.unsqueeze(0) + + if ( + added_cond_kwargs is None + or "text_embeds" not in added_cond_kwargs + or "time_ids" not in added_cond_kwargs + ): + raise RuntimeError( + "TRTUNetSDXLAdapter requires added_cond_kwargs with 'text_embeds' and 'time_ids'." + ) + + out = self.engine( + latent_model_input=sample, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + text_embeds=added_cond_kwargs["text_embeds"], + time_ids=added_cond_kwargs["time_ids"], + ) + if return_dict: + return out + return (out.sample,) + + def to(self, *args, **kwargs): + return self + + def eval(self): + return self + + class TRTUNetAdapter: """Thin wrapper for the vendored UNet engine. @@ -435,13 +576,23 @@ def eval(self): class _ConfigShim: - """Diffusers config object surface — read by pipeline._prepare_runtime_state.""" + """Diffusers config object surface — read by pipeline._prepare_runtime_state. + + Defaults match SD 1.5/2.1: addition_time_embed_dim=None (no aug + conditioning), cross_attention_dim=1024. Pass `sdxl=True` for SDXL + where the pipeline reads addition_time_embed_dim to size add_time_ids + and cross_attention_dim for the encoder hidden state. + """ - def __init__(self): - self.addition_time_embed_dim = None + def __init__(self, sdxl: bool = False): + if sdxl: + self.addition_time_embed_dim = 256 # SDXL standard + self.cross_attention_dim = 2048 + else: + self.addition_time_embed_dim = None + self.cross_attention_dim = 1024 self.in_channels = 4 self.out_channels = 4 - self.cross_attention_dim = 1024 def build_taesd_engines(