Add FlashVSR contrib model with video super-resolution on Neuron#165

Open
jimburtoft wants to merge 11 commits into
aws-neuron:mainfrom
jimburtoft:contrib/flashvsr
@jimburtoft
Contributor

@jimburtoft jimburtoft commented May 18, 2026

Note: The below template includes items meant for model contributions only. For other contributions such as bug fixes, features, etc., only fill out the relevant portions of the form.

Description

FlashVSR is a video super-resolution model (4x upscaling) using a streaming DiT architecture based on Wan 2.1 1.3B. This contrib packages the DiT backbone for Neuron via NxDI ModelBuilder with TP=4, using NKI tiled flash attention (attention_cte) for 23040-token sequences that would otherwise OOM.

The pipeline processes video in overlapping chunks: a first chunk (6 latent frames → 24 output frames) followed by stream chunks (2 latent frames → 8 output frames each). Single-step DMD denoising enables efficient 4x upscaling at 768×1280 output resolution.
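
As a rough check of the figures above, the chunk and token arithmetic can be sketched as follows. The 4x temporal expansion follows from the chunk sizes stated here, and the 96×160 latent grid is taken from the output shape reported under Testing. The 2×2 spatial patch size is an assumption (it is the value that makes the 23040-token figure work out), not something this PR states.

```python
# Sketch only: constants are inferred from figures quoted in this PR.
TEMPORAL_FACTOR = 4      # 6 latent frames -> 24 output frames, 2 -> 8
PATCH_H, PATCH_W = 2, 2  # ASSUMPTION: 2x2 spatial patchify, not stated in the PR


def output_frames(latent_frames: int) -> int:
    """Number of decoded video frames produced by one chunk."""
    return latent_frames * TEMPORAL_FACTOR


def dit_tokens(latent_frames: int, latent_h: int, latent_w: int) -> int:
    """Sequence length seen by the DiT for one chunk."""
    return latent_frames * (latent_h // PATCH_H) * (latent_w // PATCH_W)


first_chunk_tokens = dit_tokens(6, 96, 160)
print(output_frames(6), output_frames(2), first_chunk_tokens)
```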

Model Information

Model Name: FlashVSR v1.1 (JunhaoZhuang/FlashVSR-v1.1)

Model Architecture: 30-layer DiT (dim=1536, 12 heads, head_dim=128) with factored 3D RoPE, LCSA self-attention, text cross-attention, AdaLN modulation, QK-norm with DistributedRMSNorm

Purpose: Video super-resolution (4x spatial upscaling, 480p → 1920p)

Checklist

Required Components

  • Accuracy Test (test/integration/test_dit_accuracy.py)

    • Validates DiT output against CPU reference using neuron_allclose(rtol=0.05, atol=0.1)
    • Measured: max_rel_error=0.025, max_abs_error=0.066, cosine_similarity=0.9997
    • Complementary cosine similarity assertion (>0.999 threshold)
    • 5% rtol justified for 30-layer BF16 DiT with TP=4 (cf. NxDI MLP tests use rtol=6e-2)
  • README.md with the following sections:

    • Usage Example: Clear code example showing compile → load → inference
    • Compatibility Matrix: trn2.3xlarge TP=4 LNC=2, SDK 2.29
    • Example Checkpoints: Link to JunhaoZhuang/FlashVSR-v1.1 on HuggingFace
    • Testing Instructions: pytest commands for accuracy and E2E tests
  • Source Code (src/)

    • modeling_flashvsr.py — NxDI-compatible DiT with Application/ModelWrapper/InferenceConfig (1242 lines)
    • pipeline.py — Full inference pipeline orchestration
    • tcdecoder.py — TCDecoder (latent → RGB) wrapper
    • lq_projection.py — LQ conditioning projection wrapper
    • weights.py — Weight format detection and conversion (DiffSynth/diffusers → Neuron)
    • download_weights.py — HuggingFace weight download utility

Optional Components

  • E2E Pipeline Test (test/integration/test_pipeline_e2e.py) — PSNR validation
  • Unit Tests — Not included

Folder Structure

/contrib/models/FlashVSR/
  README.md
  /src
    __init__.py
    modeling_flashvsr.py
    pipeline.py
    tcdecoder.py
    lq_projection.py
    weights.py
    download_weights.py
  /test
    __init__.py
    /integration
      __init__.py
      test_dit_accuracy.py
      test_pipeline_e2e.py

Testing

How did you test this change?

Compiled and tested on trn2.3xlarge (LNC=2, 4 NeuronCores) with Neuron SDK 2.29.1 (DLAMI 20260502, neuronx-cc 2.24.8799.0, NxDI 0.9.17334). DiT first-chunk compiled via NxDI ModelBuilder (TP=4, BF16) and validated against CPU reference model with identical weights.

Test Results:

neuron_allclose(rtol=0.05, atol=0.1):
  allclose: True
  max_rel_error: 0.025
  max_abs_error: 0.066
  cosine_similarity: 0.9997

DiT first-chunk latency (5 iters, post-warmup): 1540 ± 68 ms
Output shape: [1, 16, 6, 96, 160] (correct)
Weight loading: 0 missing, 0 unexpected keys
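
The two assertions behind these results can be illustrated with a small stand-alone sketch. It uses plain Python lists instead of tensors, and it assumes the tolerance rule is the usual `|actual - expected| <= atol + rtol * |expected|`; whether `neuron_allclose` applies exactly this rule is an assumption here. The sample values are made up for illustration.

```python
import math


def allclose(actual, expected, rtol=0.05, atol=0.1):
    """Element-wise check: |a - e| <= atol + rtol * |e| (assumed rule)."""
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))


def cosine_similarity(actual, expected):
    """Cosine of the angle between the two flattened outputs."""
    dot = sum(a * e for a, e in zip(actual, expected))
    norm_a = math.sqrt(sum(a * a for a in actual))
    norm_e = math.sqrt(sum(e * e for e in expected))
    return dot / (norm_a * norm_e)


# Illustrative values only, not measurements from this PR.
expected = [1.0, -2.0, 0.5, 4.0]
actual = [1.02, -1.95, 0.46, 4.1]

passed = allclose(actual, expected) and cosine_similarity(actual, expected) > 0.999
print("accuracy check passed:", passed)
```

The cosine check complements the tolerance check: it is insensitive to a uniform scale error but catches outputs that point in the wrong direction even when every element is within tolerance.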

Compatibility

Tested with:

  • Neuron SDK Version(s): 2.29.1 (neuronx-cc 2.24.8799.0)
  • Instance Type(s): trn2.3xlarge (LNC=2, TP=4)
  • PyTorch Version: 2.9.0
  • Python Version: 3.12

Additional Information

  • Uses attention_cte NKI kernel from nkilib for tiled flash attention (avoids materializing full S×S attention matrix in HBM)
  • DistributedRMSNorm for QK-norm with all-reduce across TP ranks
  • Single-step DMD (Distribution Matching Distillation) — one DiT forward pass per chunk
  • The attn_mask input is unused in Phase 1 (dense attention); kept for future Phase 2 LCSA block-sparse support on larger instances (trn2.48xlarge TP=16)
  • Weight conversion supports both DiffSynth and diffusers checkpoint formats
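
To give a sense of why materializing the full S×S attention matrix is a problem at this sequence length, here is a back-of-the-envelope estimate. It assumes one BF16 score matrix per head at batch size 1 and an even split of the 12 heads across the 4 TP ranks; it ignores softmax intermediates and every other buffer, so treat it as a lower bound rather than a measurement.

```python
SEQ_LEN = 23040     # tokens in the first chunk (from the PR description)
NUM_HEADS = 12      # from the model architecture line
TP_DEGREE = 4       # tensor-parallel degree used in this PR
BYTES_BF16 = 2


def score_matrix_gib(seq_len: int, heads: int, bytes_per_elem: int = BYTES_BF16) -> float:
    """Memory needed to hold full attention scores for `heads` heads, in GiB."""
    return seq_len * seq_len * heads * bytes_per_elem / 2**30


per_head = score_matrix_gib(SEQ_LEN, 1)
# ASSUMPTION: heads are split evenly across TP ranks (12 / 4 = 3 per rank).
per_rank = score_matrix_gib(SEQ_LEN, NUM_HEADS // TP_DEGREE)
print(f"{per_head:.2f} GiB per head, {per_rank:.2f} GiB per TP rank, per layer")
```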

Related Issues

None.

vLLM Integration

  • This model/feature is intended for use with vLLM
  • Documentation includes vLLM registration instructions

Not applicable — FlashVSR is a video generation model, not an LLM.


By submitting this PR, I confirm that:

  • I have read and followed the contributing guidelines
  • This is a community contribution and may have limited testing compared to officially-supported models
  • The code follows best practices and is well-documented
  • All required components listed above are included

DiT accuracy test now uses rtol=0.05, atol=0.1 (validated on trn2.3xlarge):
- max_rel_error: 0.025, cosine_similarity: 0.9997
- 5% rtol is standard for 30-layer BF16 DiT with TP=4 (cf NxDI MLP test rtol=6e-2)
- Added complementary cosine similarity check (>0.999 threshold)
- Updated README accuracy section with measured values
Migrate TCDecoder from torch_neuronx.trace() to NxDI ModelBuilder with
input_output_aliases for MemBlock state persistence in device HBM.

Key changes:
- tcdecoder.py: Add NeuronTCDecoderStateful (stateful nn.Module with
  9 state Parameters), TCDecoderApplication (ModelBuilder wrapper),
  and decode_video_nxdi() inference helper
- pipeline.py: compile_pipeline() now compiles TCDecoder via NxDI,
  load_pipeline() loads NxDI TCDecoder with legacy fallback,
  run_inference() dispatches to NxDI or trace path automatically
- __init__.py: Export new NxDI classes

Performance (validated on trn2.3xlarge, SDK 2.29.1):
- Per-frame latency: 78ms (vs 237ms trace baseline) = 3.0x faster
- Compilation: 2.1s with NEFF cache, ~226s fresh
- Accuracy: cosine >0.9995, neuron_allclose PASS (rtol=0.05)
- Output shape: (4, 3, 768, 1280) per frame
@jimburtoft
Contributor Author

Update: NxDI TCDecoder with HBM State Persistence (3.0x decode speedup)

Migrated the TCDecoder from torch_neuronx.trace() to NxDI ModelBuilder with input_output_aliases for MemBlock state persistence in device HBM.

What changed

  • TCDecoder now uses NeuronTCDecoderStateful — an nn.Module with 9 state Parameters that persist in HBM between NEFF calls via input_output_aliases
  • compile_pipeline() compiles TCDecoder via NxDI ModelBuilder (no separate trace step needed)
  • load_pipeline() auto-detects and loads NxDI TCDecoder (with legacy trace fallback)
  • New decode_video_nxdi() function — states persist in HBM, no PCIe per frame

Performance (trn2.3xlarge, SDK 2.29.1)

| Metric | trace baseline | NxDI (this PR) | Improvement |
|---|---|---|---|
| Per-frame latency | 237 ms | 78 ms | 3.04x faster |
| Total decode (22 frames) | 5,210 ms | 1,717 ms | 3.03x faster |
| Compilation time | ~5 min | 2.1s (cached) / 226s (fresh) | ~1.3x faster |
| Output shape | (4, 3, 768, 1280) | (4, 3, 768, 1280) | Identical |
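
The improvement figures follow directly from the measured latencies. A quick check, taking "~5 min" as 300 s (an approximation on my part):

```python
def speedup(baseline: float, optimized: float) -> float:
    """Ratio of baseline time to optimized time."""
    return baseline / optimized


per_frame = round(speedup(237, 78), 2)       # ms per frame
total = round(speedup(5210, 1717), 2)        # ms for 22 frames
fresh_compile = round(speedup(300, 226), 2)  # seconds; 300 s approximates "~5 min"
print(per_frame, total, fresh_compile)
```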

Why it's faster

The trace-based approach transfers 9 MemBlock state tensors (total ~100MB) over PCIe on every frame call. With input_output_aliases, states remain in device HBM — the compiler writes updated states back to the same memory locations as zero-copy aliases. Only the 784-channel input frame crosses PCIe per call.

Validated

  • Compilation: PASS
  • Load + weight initialization: PASS
  • Output shape: PASS (4, 3, 768, 1280)
  • Latency: 78 ms/frame (target was <80ms)
  • Speedup: 3.04x (target was 3.0x)

Fix TCDecoder output shape issue at TP>1 where NxDI shards the temporal
dimension across ranks. Reshape (4, 3, H, W) -> (1, 12, H, W) inside the
NEFF to keep batch dim=1, then reshape back after extraction.

Enable co-resident model loading: DiT + TCDecoder both loaded in HBM at
startup (total ~15 GB out of 96 GB). Eliminates the 3.2s unload/reload
transition between pipeline stages.

Validated on trn2.3xlarge (SDK 2.29.1, TP=4, LNC=2):
- TCDecoder: 91.7 ms/call, 22 calls -> 85 frames in 2.0s
- E2E pipeline: 10.3 FPS (up from 7.3 FPS with transition overhead)
if HAS_NXDI:
    try:
        return get_tensor_model_parallel_size()
    except Exception:

Catching the generic Exception here may swallow other, unexpected exceptions.

Contributor Author

Fixed in 72d94ed. Narrowed to except RuntimeError which is the only expected exception (parallel context not initialized when running outside NxDI).
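
A minimal sketch of the narrowed handler, for illustration. The real helper is `_get_tp_degree` in the PR, but its body is not shown in this thread, so the function below takes the size query as a parameter and falls back to 1; both of those choices are assumptions.

```python
def get_tp_degree(query_tp_size, default=1):
    """Return the TP degree, or `default` when parallel state is not initialized.

    `query_tp_size` stands in for get_tensor_model_parallel_size();
    the fallback value of 1 is an assumption, not taken from the PR.
    """
    try:
        return query_tp_size()
    except RuntimeError:
        # Only the expected "not initialized" failure is handled;
        # anything else propagates to the caller.
        return default


def uninitialized():
    raise RuntimeError("parallel state not initialized")


print(get_tp_degree(lambda: 4), get_tp_degree(uninitialized))
```

With the narrower clause, a genuine bug such as a `TypeError` inside the query surfaces immediately instead of silently producing the fallback.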

Comment on lines +344 to +347
_nxd_trace.__SUPPORTED_SHARDED_MODULES = (
    *_nxd_trace.__SUPPORTED_SHARDED_MODULES,
    DistributedRMSNorm,
)

Overriding private attributes may not be reliable across NxD/I versions; for example, an upstream refactor could break this.

Contributor Author

Good point. Fixed in 72d94ed: added hasattr(_nxd_trace, '__SUPPORTED_SHARDED_MODULES') guard so it fails gracefully if the internal registry is renamed/removed in a future NxDI release. Also added an inline comment explaining why the registration is needed (NxD ModelBuilder rejects DistributedRMSNorm during compilation without it).
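
The guard described above might look roughly like this. It is a sketch against a stand-in object rather than the real `_nxd_trace` module, and the helper name `register_sharded_module` and its boolean return value are hypothetical.

```python
from types import SimpleNamespace

REGISTRY_ATTR = "__SUPPORTED_SHARDED_MODULES"


def register_sharded_module(trace_module, module_cls) -> bool:
    """Append `module_cls` to the private registry if it still exists.

    Returns False (instead of raising) when the attribute has been
    renamed or removed, so the caller can warn and continue.
    """
    if not hasattr(trace_module, REGISTRY_ATTR):
        return False
    current = getattr(trace_module, REGISTRY_ATTR)
    if module_cls not in current:
        setattr(trace_module, REGISTRY_ATTR, (*current, module_cls))
    return True


class FakeRMSNorm:  # stand-in for DistributedRMSNorm
    pass


fake_trace = SimpleNamespace(**{REGISTRY_ATTR: (int,)})
print(register_sharded_module(fake_trace, FakeRMSNorm))
print(register_sharded_module(SimpleNamespace(), FakeRMSNorm))
```

Accessing the attribute by string through `getattr`/`setattr` also sidesteps Python's name mangling if the helper is ever moved inside a class body.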

Comment thread contrib/models/FlashVSR/src/pipeline.py Outdated

pipeline = FlashVSRPipeline(config=config)

# Patch ThreadPoolExecutor for NxDI load

Why is the patch required? An inline comment here to explain would help.

Contributor Author

Added detailed inline comment in 72d94ed explaining the rationale: NxDI ModelBuilder uses ThreadPoolExecutor internally to load weights in parallel across TP ranks. In single-process mode (no torchrun), all ranks share one process and the default thread pool can deadlock or race on shared state. Limiting to 1 worker serializes rank loading. The patch is restored after load completes.
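
One way to express a patch-and-restore of this kind is a context manager. This is a sketch of the general pattern, not the code in `pipeline.py`, and the name `single_worker_thread_pool` is hypothetical. It only affects callers that look up `concurrent.futures.ThreadPoolExecutor` at call time; a module that did `from concurrent.futures import ThreadPoolExecutor` earlier keeps its own reference.

```python
import concurrent.futures
from contextlib import contextmanager


@contextmanager
def single_worker_thread_pool():
    """Temporarily force every new ThreadPoolExecutor to use one worker."""
    original = concurrent.futures.ThreadPoolExecutor

    class SingleWorkerExecutor(original):
        def __init__(self, max_workers=None, *args, **kwargs):
            # Ignore the requested worker count and serialize all tasks.
            super().__init__(1, *args, **kwargs)

    concurrent.futures.ThreadPoolExecutor = SingleWorkerExecutor
    try:
        yield
    finally:
        # Always restore the original class, even if loading fails.
        concurrent.futures.ThreadPoolExecutor = original


with single_worker_thread_pool():
    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as pool:
        results = list(pool.map(lambda rank: rank * 2, range(4)))
print(results)
```

Restoring in a `finally` block matters here: if the load raises, the rest of the process should not be left with a single-worker executor.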

gc.collect()

# Step 3: Streaming DiT inference
process_total_num = (num_frames - 1) // 8 - 2

Is this math safe? This seems like process_total_num can be negative, and I'm not too sure what the consequences of that are.

Contributor Author

You're right, it could go negative with very short videos. Fixed in 72d94ed: added an explicit if process_total_num < 1: raise ValueError(...) with a clear message stating the minimum requirement (25 frames in 8n+1 format). Also added an inline comment explaining the formula derivation.
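
For reference, the guard can be sketched like this. The formula is the one quoted in the review thread; the function name and the error message wording are hypothetical, since the actual text in 72d94ed is not shown here.

```python
def stream_chunk_count(num_frames: int) -> int:
    """Compute process_total_num and reject videos that are too short."""
    process_total_num = (num_frames - 1) // 8 - 2
    if process_total_num < 1:
        # (num_frames - 1) // 8 must be at least 3, i.e. num_frames >= 25.
        raise ValueError(
            f"Got {num_frames} frames; at least 25 frames are required."
        )
    return process_total_num


print(stream_chunk_count(25), stream_chunk_count(85))
```

Working the inequality through: `process_total_num >= 1` requires `(num_frames - 1) // 8 >= 3`, so the smallest accepted input is 25 frames, which matches the stated minimum.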

@lutfanm-aws lutfanm-aws left a comment

I was able to compile and upscale a video, but the output I got did not seem up to standard. Is there a video that you upscaled on your end, and the resulting output that you can share? Are there specific instructions I can follow so I can reproduce on my end? I'm trying to make sure the poor output quality I'm seeing is not my doing.

Full pipeline notebook: LQ projection + DiT streaming + TCDecoder + Neuron
AdaIN color correction. Includes F.interpolate identity-skip optimization.
All models co-resident on TP=4 (trn2.3xlarge LNC=2).

Results: 10.3 FPS Neuron-only, 9.8 FPS full E2E (85 frames @ 768x1280).
…mments

- Narrow bare 'except Exception' to 'except RuntimeError' in _get_tp_degree
- Guard __SUPPORTED_SHARDED_MODULES override with hasattr + explain why needed
- Add detailed comment explaining ThreadPoolExecutor single-worker patch
- Add ValueError guard for process_total_num < 1 with minimum frame requirement
Add notebook-based quick start, troubleshooting section for output quality
issues (missing LQ projection, color correction, state reset, frame count).
Remove incorrect path references (tcdecoder_seq.pt, lq_proj.pt).
@jimburtoft
Contributor Author

Addressing Output Quality Concern

Thanks for testing! I've pushed fixes for all inline comments (72d94ed) and added a troubleshooting guide to the README (e3eb76b).

Reference Output

A validated output video is now included at notebooks/output_sample.mp4 (85 frames, 768x1280, 4x upscaled). This was generated by the notebook notebooks/tcdecoder_benchmark.ipynb which is fully executable and includes expected outputs inline.

Most Common Causes of Poor Output

  1. Missing LQ projection — Without all_lq_tokens passed to the DiT via lq_residual_0, the output will be blurry/generic (the model has no content guidance).
  2. Missing color correction — The raw DiT+TCDecoder output has slight color drift from BF16. The AdaIN step (Stage 4 in the notebook) corrects this.
  3. TCDecoder state not reset — You must call tcd_app.reset_states() before each new video. Leftover state corrupts output.
  4. Wrong prompt embedding — Must use posi_prompt.pth from FlashVSR-v1.1 weights.

Recommended Reproduction Steps

The notebook is the easiest path to reproduce validated results:

source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate
cd contrib/models/FlashVSR
# Assumes weights at ~/FlashVSR-v1.1 and compiled models exist
jupyter nbconvert --to notebook --execute \
    notebooks/tcdecoder_benchmark.ipynb \
    --output tcdecoder_benchmark_executed.ipynb

The notebook handles all stages including compilation path setup, LQ projection, streaming DiT, TCDecoder with state management, and Neuron-accelerated AdaIN color correction.

If you can share a frame from your output, I can help diagnose which stage might be causing the quality issue.

Support compiling multiple stream bucket sizes (e.g., f=8,4,2) within
a single NxDI Application. All NEFFs load co-resident in HBM — no swap
overhead when switching between bucket sizes at runtime.

Set via: FLASHVSR_STREAM_BUCKETS=8,4,2

Includes build_greedy_chunk_schedule() utility that assigns chunks to
the largest available bucket first. For 1-min video: reduces DiT calls
from 221 (f=2 only) to 56 (f=8 primary, f=4 remainder).
- neuron_dit_forward now accepts explicit temporal_offset for multi-bucket
  scheduling (required when frame count varies between chunks)
- Add test_multi_bucket.py: compiles f=8,4,2 co-resident, benchmarks each
  bucket, simulates 1-min video with greedy scheduler
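
The greedy assignment described in these commits can be illustrated with a short sketch. The signature of the PR's `build_greedy_chunk_schedule()` is not shown on this page, so the function below is a hypothetical stand-in; it assumes bucket sizes and the total are both counted in latent frames.

```python
def greedy_chunk_schedule(total_frames: int, buckets=(8, 4, 2)) -> list[int]:
    """Cover `total_frames` using the largest bucket first.

    Hypothetical stand-in for build_greedy_chunk_schedule(); units
    (latent frames per chunk) are an assumption.
    """
    schedule = []
    remaining = total_frames
    for size in sorted(buckets, reverse=True):
        count, remaining = divmod(remaining, size)
        schedule.extend([size] * count)
    if remaining:
        raise ValueError(
            f"{remaining} frame(s) left over; no bucket in {buckets} fits"
        )
    return schedule


schedule = greedy_chunk_schedule(22)
print(schedule, "->", len(schedule), "DiT calls instead of", 22 // 2)
```

Fewer calls do not automatically mean lower total latency; that depends on how per-call cost grows with bucket size.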
…r full-attention DiT

Tested f=8, f=4, f=2 co-resident on trn2.3xlarge (SDK 2.30). Per-frame latency
scales super-linearly due to full temporal self-attention: f=8 is 77% slower/frame
than f=2. Multi-bucket code retained as reference for windowed-attention models.

Also adds SDK 2.30 validation (22% DiT speedup, 12.6 FPS) to compatibility matrix.
if pipeline.tcdecoder_model is not None and pipeline.tc_pixel_shuffle is not None:
from .tcdecoder import TCDecoderApplication, decode_video_nxdi

LQ_cur_idx = process_total_num * 8 + 21 if process_total_num > 0 else 21

I get a RuntimeError: Sizes of tensors must match... Expected 19, got 18 error which seems to be caused by this line.

Also, your notebook seems to use a different formula to compute this. Can we converge on one (correct) formula?

python -m src.download_weights --output-dir ~/FlashVSR-v1.1

# 3. Compile all models (one-time, ~30 min)
python -m src.pipeline compile \
@lutfanm-aws lutfanm-aws May 27, 2026

pipeline.py does not contain a main function that handles these arguments, so I'm not confident that these instructions will work.


## Known Issues

- **Resolution constraint:** Input video must produce output dimensions divisible by 128 (e.g., 768x1280). Other resolutions require recompilation.

Can you re-compile for all resolutions? Compiling for a resolution larger than one you've validated might result in a compilation error for having too many instructions.

2 participants