Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 169 additions & 0 deletions contrib/models/S3Diff/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
# Contrib Model: S3Diff

S3Diff one-step 4x super-resolution on AWS Neuron using `torch_neuronx.trace()`.
Supports arbitrary input resolutions via tiling with Gaussian blending.

## Model Information

- **HuggingFace ID:** `zhangap/S3Diff` (weights), base model `stabilityai/sd-turbo`
- **Model Type:** One-step diffusion model for image super-resolution
- **Parameters:** ~1.3B total (~2 GB on disk)
- **Architecture:** SD-Turbo UNet with degradation-guided dynamic LoRA modulation (DEResNet encoder, CLIP text encoder, VAE encoder/decoder with per-layer LoRA, UNet with per-layer LoRA)
- **Paper:** "Degradation-Guided One-Step Image Super-Resolution with Diffusion Priors" (ECCV 2024)
- **License:** Check model cards for SD-Turbo and S3Diff

## Key Architecture Notes

S3Diff is unusual among diffusion models:

1. **Single denoising step**: Only one UNet forward pass per image (at t=999), making it extremely fast.
2. **Dynamic LoRA modulation**: A DEResNet encoder estimates input degradation and produces per-layer LoRA scaling matrices. These `[rank, rank]` modulation matrices are injected between `lora_A` and `lora_B` via einsum operations, conditioning the UNet on the specific degradation pattern of each input.
3. **Two LoRA ranks**: VAE uses rank=16 (6 blocks), UNet uses rank=32 (10 blocks).
4. **Small model**: Total size ~2 GB, fits on a single NeuronCore with no tensor parallelism needed.
5. **Arbitrary resolution via tiling**: All components are compiled at a fixed tile size (512x512 pixels). Images larger than this are split into overlapping tiles, processed independently, and blended with Gaussian weights for seamless output.

This contrib uses `torch_neuronx.trace()` rather than NxDI tensor parallelism, which is appropriate for the model's small size and non-autoregressive architecture.

## Validation Results

**Validated:** 2026-05-06
**Instance:** trn2.3xlarge (LNC=2)
**SDK:** Neuron SDK 2.29 (DLAMI 20260410), PyTorch 2.9

### Benchmark Results (multi-resolution, single step)

| Input Size | Output Size | Tiles | Time | Throughput |
|-----------|-------------|-------|------|------------|
| 128x128 | 512x512 | 1 | 0.545s | 1.8 img/s |
| 256x256 | 1024x1024 | 9 | 4.809s | 0.21 img/s |
| 512x512 | 2048x2048 | 25 | 13.346s | 0.075 img/s |

### Component Timing (single tile, 512x512 output)

| Component | Time |
|-----------|------|
| DEResNet | 3.8ms |
| Modulation (CPU) | 0.5ms |
| VAE Encode | 83.2ms |
| UNet x2 (CFG) | 218.8ms |
| VAE Decode | 164.6ms |
| **Total (single tile)** | **0.471s** |

| Metric | Value |
|--------|-------|
| Inference steps | 1 (one-step model) |
| Total compile time | ~21 min |
| CPU baseline (128->512) | 11.53s |
| Speedup vs CPU | ~21x |

### Accuracy Validation

Visual quality validated against CPU reference output. The model produces high-quality 4x upscaled images with correct degradation-aware enhancement. Tiled outputs show pixel std > 20 (indicating good visual detail) with seamless blending at tile boundaries.

## Usage

```python
from S3Diff.src.modeling_s3diff import S3DiffNeuronPipeline
from PIL import Image

pipeline = S3DiffNeuronPipeline(
sd_turbo_path="/shared/sd-turbo/",
s3diff_weights_path="/shared/s3diff/s3diff.pkl",
de_net_path="/shared/s3diff/de_net.pth",
compile_dir="/tmp/s3diff/compiled/",
lr_size=128, # DEResNet fixed input (always 128)
tile_size=512, # HR tile size (default)
tile_overlap=128, # Overlap for blending (default)
)
pipeline.load()
pipeline.compile()

# Works with any input size -- tiling is automatic
lr_image = Image.open("input.png").convert("RGB")
sr_image = pipeline(lr_image) # 4x upscaled output
sr_image.save("output.png")
```

Or use the provided script:

```bash
# 128x128 -> 512x512 (single tile, fast)
python src/generate_s3diff.py \
--input_image input_128.png \
--output_image output_512.png

# 256x256 -> 1024x1024 (tiled)
python src/generate_s3diff.py \
--input_image input_256.png \
--output_image output_1024.png

# Custom tile settings
python src/generate_s3diff.py \
--input_image input.png \
--output_image output.png \
--tile_size 512 \
--tile_overlap 128
```

## Setup

```bash
# Activate NxDI environment
source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate

# Install dependencies
pip install diffusers transformers peft accelerate torchvision

# Download weights
python src/generate_s3diff.py --download

# Or manually:
# SD-Turbo: huggingface-cli download stabilityai/sd-turbo --local-dir /shared/sd-turbo/
# S3Diff: huggingface-cli download zhangap/S3Diff --local-dir /shared/s3diff/
# DEResNet: git clone https://github.com/ArcticHare105/S3Diff.git /tmp/s3diff_repo
# cp /tmp/s3diff_repo/assets/mm-realsr/de_net.pth /shared/s3diff/
```

## Tiling Design

For images whose HR output (input x 4) exceeds 512x512 pixels, the pipeline automatically:

1. **Upscales** the input image to HR resolution via bicubic interpolation
2. **Splits** the HR image into overlapping 512x512 tiles (128px overlap by default)
3. **Processes** each tile independently through VAE encode -> UNet -> VAE decode
4. **Blends** tile outputs using Gaussian weights (smooth center-to-edge falloff)

This approach avoids recompilation for different resolutions and produces seamless outputs. The degradation estimation (DEResNet) runs once on the full image resized to 128x128, producing global modulation parameters shared across all tiles.

## Compatibility Matrix

| Instance/Version | SDK 2.29 | SDK 2.28 |
|------------------|----------|----------|
| trn2.3xlarge | VALIDATED | Not tested |

## Example Checkpoints

* [zhangap/S3Diff](https://huggingface.co/zhangap/S3Diff) -- S3Diff LoRA weights
* [stabilityai/sd-turbo](https://huggingface.co/stabilityai/sd-turbo) -- Base SD-Turbo model

## Testing Instructions

```bash
export SD_TURBO_PATH=/shared/sd-turbo/
export S3DIFF_WEIGHTS=/shared/s3diff/s3diff.pkl
export DE_NET_WEIGHTS=/shared/s3diff/de_net.pth

cd contrib/models/S3Diff/
pytest test/integration/test_model.py -v

# Or standalone
python test/integration/test_model.py
```

## Known Issues

- **LoRA + `--auto-cast=matmult` produces NaN**: The LoRA modulation einsum operations are numerically unstable when `--auto-cast=matmult` casts them to BF16. All components with LoRA use `--model-type=unet-inference` instead. Only DEResNet and text encoder (no LoRA) use `--auto-cast=matmult`.
- **Compilation time**: ~21 minutes total (UNet is the slowest at ~12 min). Compiled models are cached for reuse.
- **CFG is sequential**: Two separate UNet passes (positive + negative prompt), not batched. Batching with batch_size=2 would halve UNet wall time but requires recompilation.
- **Neuron runtime HBM**: Once loaded, compiled models stay in HBM even if the Python object is deleted (within the same process). Plan memory accordingly.
- **Tiling artifacts at very high resolution**: At 4K+ output, very minor blending seams may be visible in uniform regions. Increasing `--tile_overlap` to 192 or 256 reduces this at the cost of more tiles.
3 changes: 3 additions & 0 deletions contrib/models/S3Diff/src/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# S3Diff one-step super-resolution on AWS Neuron
# Uses torch_neuronx.trace() for compilation -- not NxDI TP sharding
# (model is ~2 GB, fits on a single NeuronCore)
187 changes: 187 additions & 0 deletions contrib/models/S3Diff/src/generate_s3diff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""
S3Diff one-step 4x super-resolution on AWS Neuron.

Downloads required weights (SD-Turbo, S3Diff LoRA, DEResNet), compiles all
components, and runs super-resolution inference. Supports arbitrary input
resolutions via tiling (images whose 4x upscaled size exceeds 512x512 are
automatically processed with overlapping tiles).

Usage:
python generate_s3diff.py \
--input_image /path/to/lr_image.png \
--output_image /path/to/sr_output.png \
--compile_dir /tmp/s3diff/compiled/

# Multi-resolution examples:
# 128x128 input -> 512x512 output (single tile, ~0.5s)
# 256x256 input -> 1024x1024 output (9 tiles, ~4.8s)
# 512x512 input -> 2048x2048 output (25 tiles, ~13.3s)

Requirements:
pip install diffusers transformers peft accelerate torchvision
"""

import argparse
import os
import time

import torch
from PIL import Image

try:
from .modeling_s3diff import S3DiffNeuronPipeline
except ImportError:
from modeling_s3diff import S3DiffNeuronPipeline


DEFAULT_SD_TURBO_PATH = "/shared/sd-turbo/"
DEFAULT_S3DIFF_WEIGHTS = "/shared/s3diff/s3diff.pkl"
DEFAULT_DE_NET_WEIGHTS = "/shared/s3diff/de_net.pth"
DEFAULT_COMPILE_DIR = "/tmp/s3diff/compiled/"


def download_weights(sd_turbo_path, s3diff_weights_path, de_net_path):
"""Download model weights if not already present."""
from huggingface_hub import hf_hub_download, snapshot_download

if not os.path.exists(sd_turbo_path):
print("Downloading SD-Turbo...")
snapshot_download("stabilityai/sd-turbo", local_dir=sd_turbo_path)

if not os.path.exists(s3diff_weights_path):
print("Downloading S3Diff weights...")
os.makedirs(os.path.dirname(s3diff_weights_path), exist_ok=True)
hf_hub_download(
"zhangap/S3Diff",
filename="s3diff.pkl",
local_dir=os.path.dirname(s3diff_weights_path),
)

if not os.path.exists(de_net_path):
print("Downloading DEResNet weights...")
os.makedirs(os.path.dirname(de_net_path), exist_ok=True)
# DEResNet weights are in the S3Diff GitHub repo
import subprocess

repo_dir = "/tmp/s3diff_repo"
if not os.path.exists(repo_dir):
subprocess.run(
[
"git",
"clone",
"https://github.com/ArcticHare105/S3Diff.git",
repo_dir,
],
check=True,
)
import shutil

shutil.copy2(
os.path.join(repo_dir, "assets", "mm-realsr", "de_net.pth"),
de_net_path,
)


def main():
parser = argparse.ArgumentParser(description="S3Diff 4x SR on AWS Neuron")
parser.add_argument(
"--input_image",
type=str,
default=None,
help="Path to input low-resolution image (any size; will be 4x upscaled)",
)
parser.add_argument(
"--output_image", type=str, default="sr_output.png", help="Output path"
)
parser.add_argument("--sd_turbo_path", type=str, default=DEFAULT_SD_TURBO_PATH)
parser.add_argument("--s3diff_weights", type=str, default=DEFAULT_S3DIFF_WEIGHTS)
parser.add_argument("--de_net_weights", type=str, default=DEFAULT_DE_NET_WEIGHTS)
parser.add_argument("--compile_dir", type=str, default=DEFAULT_COMPILE_DIR)
parser.add_argument("--num_images", type=int, default=3)
parser.add_argument("--warmup_rounds", type=int, default=5)
parser.add_argument(
"--tile_size",
type=int,
default=512,
help="Pixel-space tile size for VAE/UNet (default: 512). Must be divisible by 8.",
)
parser.add_argument(
"--tile_overlap",
type=int,
default=128,
help="Pixel-space overlap between tiles (default: 128). Must be divisible by 8.",
)
parser.add_argument("--download", action="store_true", help="Download weights")
args = parser.parse_args()

if args.download:
download_weights(args.sd_turbo_path, args.s3diff_weights, args.de_net_weights)

# Create a test image if none provided
if args.input_image is None:
print("No input image provided, creating a 128x128 test pattern...")
import numpy as np

test_img = np.random.randint(0, 255, (128, 128, 3), dtype=np.uint8)
lr_image = Image.fromarray(test_img)
else:
lr_image = Image.open(args.input_image).convert("RGB")

lr_w, lr_h = lr_image.size
hr_w, hr_h = lr_w * 4, lr_h * 4
print(f"Input image: {lr_w}x{lr_h} -> Output: {hr_w}x{hr_h}")
if hr_h > args.tile_size or hr_w > args.tile_size:
print(
f"Tiling enabled (tile_size={args.tile_size}, overlap={args.tile_overlap})"
)

# Build pipeline (lr_size is always 128 for DEResNet)
pipeline = S3DiffNeuronPipeline(
sd_turbo_path=args.sd_turbo_path,
s3diff_weights_path=args.s3diff_weights,
de_net_path=args.de_net_weights,
compile_dir=args.compile_dir,
lr_size=128,
tile_size=args.tile_size,
tile_overlap=args.tile_overlap,
)

print("\nLoading model...")
pipeline.load()

print("\nCompiling...")
t0 = time.time()
pipeline.compile()
compile_time = time.time() - t0
print(f"Total compilation: {compile_time:.1f}s")

# Warmup
print(f"\nWarming up ({args.warmup_rounds} rounds)...")
for _ in range(args.warmup_rounds):
pipeline(lr_image)

# Benchmark
print(f"\nGenerating {args.num_images} images...")
total_time = 0
for i in range(args.num_images):
t0 = time.time()
sr_image = pipeline(lr_image)
elapsed = time.time() - t0
total_time += elapsed
print(f" Image {i + 1}: {elapsed:.3f}s")

avg_time = total_time / args.num_images
print(f"\nResults:")
print(f" Average time: {avg_time:.3f}s")
print(f" Throughput: {1.0 / avg_time:.2f} img/s")
print(f" Compilation: {compile_time:.1f}s")

sr_image.save(args.output_image)
print(f" Saved: {args.output_image}")


if __name__ == "__main__":
main()
Loading