Add FlashVSR contrib model with video super-resolution on Neuron #165
jimburtoft wants to merge 11 commits into
DiT accuracy test now uses rtol=0.05, atol=0.1 (validated on trn2.3xlarge):
- max_rel_error: 0.025, cosine_similarity: 0.9997
- 5% rtol is standard for 30-layer BF16 DiT with TP=4 (cf. NxDI MLP test rtol=6e-2)
- Added complementary cosine similarity check (>0.999 threshold)
- Updated README accuracy section with measured values
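The two checks named in this commit can be sketched in plain Python. This is a minimal illustration, assuming the tolerance test follows the common NumPy-style rule `|actual - expected| <= atol + rtol * |expected|`; whether `neuron_allclose` applies exactly this formula is an assumption, and the helper names and sample values below are hypothetical.

```python
import math


def allclose(actual, expected, rtol=0.05, atol=0.1):
    """NumPy-style elementwise tolerance check (assumed formula, not the NxDI helper)."""
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))


def cosine_similarity(x, y):
    """Cosine of the angle between two flat vectors."""
    dot = sum(a * b for a, b in zip(x, y))
    norm_x = math.sqrt(sum(a * a for a in x))
    norm_y = math.sqrt(sum(b * b for b in y))
    return dot / (norm_x * norm_y)


# Hypothetical sample values standing in for CPU reference and Neuron outputs.
reference = [1.0, -2.0, 0.5, 4.0]
candidate = [1.02, -1.97, 0.51, 3.95]

# Both checks must hold: elementwise tolerance and the >0.999 cosine threshold.
passes = allclose(candidate, reference) and cosine_similarity(candidate, reference) > 0.999
```

The cosine check complements the tolerance check because it is insensitive to a few large-magnitude elements that dominate an absolute tolerance.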
Migrate TCDecoder from torch_neuronx.trace() to NxDI ModelBuilder with input_output_aliases for MemBlock state persistence in device HBM.

Key changes:
- tcdecoder.py: Add NeuronTCDecoderStateful (stateful nn.Module with 9 state Parameters), TCDecoderApplication (ModelBuilder wrapper), and decode_video_nxdi() inference helper
- pipeline.py: compile_pipeline() now compiles TCDecoder via NxDI, load_pipeline() loads NxDI TCDecoder with legacy fallback, run_inference() dispatches to NxDI or trace path automatically
- __init__.py: Export new NxDI classes

Performance (validated on trn2.3xlarge, SDK 2.29.1):
- Per-frame latency: 78ms (vs 237ms trace baseline) = 3.0x faster
- Compilation: 2.1s with NEFF cache, ~226s fresh
- Accuracy: cosine >0.9995, neuron_allclose PASS (rtol=0.05)
- Output shape: (4, 3, 768, 1280) per frame
f6b9828 to 6db2715
Update: NxDI TCDecoder with HBM State Persistence (3.0x decode speedup)
Migrated the TCDecoder from
What changed
Performance (trn2.3xlarge, SDK 2.29.1)
Why it's faster
The trace-based approach transfers 9 MemBlock state tensors (total ~100MB) over PCIe on every frame call. With
Validated
Fix TCDecoder output shape issue at TP>1 where NxDI shards the temporal dimension across ranks. Reshape (4, 3, H, W) -> (1, 12, H, W) inside the NEFF to keep batch dim=1, then reshape back after extraction.

Enable co-resident model loading: DiT + TCDecoder both loaded in HBM at startup (total ~15 GB out of 96 GB). Eliminates the 3.2s unload/reload transition between pipeline stages.

Validated on trn2.3xlarge (SDK 2.29.1, TP=4, LNC=2):
- TCDecoder: 91.7 ms/call, 22 calls -> 85 frames in 2.0s
- E2E pipeline: 10.3 FPS (up from 7.3 FPS with transition overhead)
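The reshape described in this commit folds the 4 temporal frames into the channel axis. A minimal sketch of the index mapping, assuming a contiguous row-major layout (the helper names are hypothetical, not from the PR):

```python
FRAMES, CHANNELS = 4, 3  # per-call output described in the commit: (4, 3, H, W)


def fold_index(frame, channel, channels=CHANNELS):
    """Channel position in the folded (1, 12, H, W) tensor for element (frame, channel)."""
    return frame * channels + channel


def unfold_index(folded_channel, channels=CHANNELS):
    """Inverse mapping: recover (frame, channel) from the folded channel position."""
    return divmod(folded_channel, channels)


# Every (frame, channel) pair maps to a unique slot in 0..11 and back again,
# so reshaping to (1, 12, H, W) and back to (4, 3, H, W) loses nothing.
folded_slots = sorted(fold_index(t, c) for t in range(FRAMES) for c in range(CHANNELS))
```

Because the mapping is a bijection over a contiguous buffer, the reshape only changes how the leading dimensions are interpreted; the spatial axes H and W are untouched.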
    if HAS_NXDI:
        try:
            return get_tensor_model_parallel_size()
        except Exception:
Generic exception, might swallow all other possible exceptions thrown.
Fixed in 72d94ed. Narrowed to except RuntimeError which is the only expected exception (parallel context not initialized when running outside NxDI).
    _nxd_trace.__SUPPORTED_SHARDED_MODULES = (
        *_nxd_trace.__SUPPORTED_SHARDED_MODULES,
        DistributedRMSNorm,
    )
Overriding private attributes may not be reliable across NxD/I versions, e.g. in the case of upstream refactors, this may break.
Good point. Fixed in 72d94ed: added hasattr(_nxd_trace, '__SUPPORTED_SHARDED_MODULES') guard so it fails gracefully if the internal registry is renamed/removed in a future NxDI release. Also added an inline comment explaining why the registration is needed (NxD ModelBuilder rejects DistributedRMSNorm during compilation without it).
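The guard described in this reply could look roughly like the following. It is a sketch against a fake module object: `register_sharded_module`, `fake_trace`, and the placeholder `DistributedRMSNorm` class are hypothetical, and only the attribute name comes from the diff.

```python
import types

REGISTRY_ATTR = "__SUPPORTED_SHARDED_MODULES"  # private name taken from the diff above


def register_sharded_module(trace_module, module_cls):
    """Append module_cls to the private registry if it exists; report whether it did."""
    if not hasattr(trace_module, REGISTRY_ATTR):
        # Registry renamed or removed upstream: skip instead of raising AttributeError.
        return False
    current = getattr(trace_module, REGISTRY_ATTR)
    if module_cls not in current:
        setattr(trace_module, REGISTRY_ATTR, (*current, module_cls))
    return True


class DistributedRMSNorm:
    """Placeholder for the real layer class."""


# Fake stand-ins for the NxD trace module, with and without the registry.
fake_trace = types.SimpleNamespace(**{REGISTRY_ATTR: (dict,)})
registered = register_sharded_module(fake_trace, DistributedRMSNorm)
skipped = register_sharded_module(types.SimpleNamespace(), DistributedRMSNorm)
```

Returning a boolean lets the caller log a warning when the registration was skipped, so a later compilation failure is easier to trace back to the missing registry.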
    pipeline = FlashVSRPipeline(config=config)

    # Patch ThreadPoolExecutor for NxDI load
Why is the patch required? An inline comment here to explain would help.
Added detailed inline comment in 72d94ed explaining the rationale: NxDI ModelBuilder uses ThreadPoolExecutor internally to load weights in parallel across TP ranks. In single-process mode (no torchrun), all ranks share one process and the default thread pool can deadlock or race on shared state. Limiting to 1 worker serializes rank loading. The patch is restored after load completes.
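One way to express a patch like this is a context manager that swaps in a single-worker executor and always restores the original. This is a sketch of the general technique, not the code from the PR; `single_worker_thread_pool` is a hypothetical name, and the patch only affects code that looks up `concurrent.futures.ThreadPoolExecutor` at call time.

```python
import concurrent.futures
from contextlib import contextmanager


@contextmanager
def single_worker_thread_pool():
    """Temporarily force every new ThreadPoolExecutor to use exactly one worker."""
    original = concurrent.futures.ThreadPoolExecutor

    class SingleWorkerExecutor(original):
        def __init__(self, max_workers=None, *args, **kwargs):
            # Ignore the requested worker count and serialize all submitted tasks.
            super().__init__(1, *args, **kwargs)

    concurrent.futures.ThreadPoolExecutor = SingleWorkerExecutor
    try:
        yield
    finally:
        # Restore the real executor even if loading raised.
        concurrent.futures.ThreadPoolExecutor = original


with single_worker_thread_pool():
    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as pool:
        # Runs one task at a time despite max_workers=8.
        loaded = list(pool.map(lambda rank: f"rank{rank}", range(4)))
```

Modules that ran `from concurrent.futures import ThreadPoolExecutor` before the patch hold their own reference and would need to be patched at that import site instead.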
    gc.collect()

    # Step 3: Streaming DiT inference
    process_total_num = (num_frames - 1) // 8 - 2
Is this math safe? This seems like process_total_num can be negative, and I'm not too sure what the consequences of that are.
You're right, it could go negative with very short videos. Fixed in 72d94ed: added an explicit if process_total_num < 1: raise ValueError(...) with a clear message stating the minimum requirement (25 frames in 8n+1 format). Also added an inline comment explaining the formula derivation.
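The guard can be sketched as a small helper. The formula and the 25-frame minimum come from the thread above; the function name and error wording are hypothetical.

```python
def compute_process_total_num(num_frames: int) -> int:
    """Number of streaming iterations, using the formula quoted in the diff above."""
    process_total_num = (num_frames - 1) // 8 - 2
    if process_total_num < 1:
        # (num_frames - 1) // 8 must be at least 3, i.e. num_frames >= 25.
        raise ValueError(
            f"Video too short: got {num_frames} frames, need at least 25 (8n+1 format)."
        )
    return process_total_num


shortest_ok = compute_process_total_num(25)   # (24 // 8) - 2
benchmark_clip = compute_process_total_num(85)  # (84 // 8) - 2
```

Failing early with a `ValueError` avoids passing a zero or negative count into the streaming loop, where the symptom would be a silent empty output or a shape error further downstream.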
lutfanm-aws left a comment
I was able to compile and upscale a video, but the output I got did not seem up to standard. Is there a video that you upscaled on your end, and the resulting output that you can share? Are there specific instructions I can follow so I can reproduce on my end? I'm trying to make sure the poor output quality I'm seeing is not my doing.
Full pipeline notebook: LQ projection + DiT streaming + TCDecoder + Neuron AdaIN color correction. Includes F.interpolate identity-skip optimization. All models co-resident on TP=4 (trn2.3xlarge LNC=2). Results: 10.3 FPS Neuron-only, 9.8 FPS full E2E (85 frames @ 768x1280).
…mments
- Narrow bare 'except Exception' to 'except RuntimeError' in _get_tp_degree
- Guard __SUPPORTED_SHARDED_MODULES override with hasattr + explain why needed
- Add detailed comment explaining ThreadPoolExecutor single-worker patch
- Add ValueError guard for process_total_num < 1 with minimum frame requirement
Add notebook-based quick start, troubleshooting section for output quality issues (missing LQ projection, color correction, state reset, frame count). Remove incorrect path references (tcdecoder_seq.pt, lq_proj.pt).
Addressing Output Quality Concern
Thanks for testing! I've pushed fixes for all inline comments (72d94ed) and added a troubleshooting guide to the README (e3eb76b).
Reference Output
A validated output video is now included at
Most Common Causes of Poor Output
Recommended Reproduction Steps
The notebook is the easiest path to reproduce validated results:

    source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate
    cd contrib/models/FlashVSR
    # Assumes weights at ~/FlashVSR-v1.1 and compiled models exist
    jupyter nbconvert --to notebook --execute \
        notebooks/tcdecoder_benchmark.ipynb \
        --output tcdecoder_benchmark_executed.ipynb

The notebook handles all stages including compilation path setup, LQ projection, streaming DiT, TCDecoder with state management, and Neuron-accelerated AdaIN color correction. If you can share a frame from your output, I can help diagnose which stage might be causing the quality issue.
Support compiling multiple stream bucket sizes (e.g., f=8,4,2) within a single NxDI Application. All NEFFs load co-resident in HBM — no swap overhead when switching between bucket sizes at runtime. Set via: FLASHVSR_STREAM_BUCKETS=8,4,2 Includes build_greedy_chunk_schedule() utility that assigns chunks to the largest available bucket first. For 1-min video: reduces DiT calls from 221 (f=2 only) to 56 (f=8 primary, f=4 remainder).
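A largest-bucket-first schedule can be sketched as follows. The function name matches the commit message, but its signature, the `parse_stream_buckets` helper, and the error behavior are assumptions; the example does not attempt to reproduce the 221 and 56 call counts quoted above.

```python
def parse_stream_buckets(value: str) -> tuple:
    """Parse a FLASHVSR_STREAM_BUCKETS-style string such as '8,4,2' (hypothetical helper)."""
    return tuple(int(part) for part in value.split(",") if part.strip())


def build_greedy_chunk_schedule(total_latent_frames: int, buckets=(8, 4, 2)) -> list:
    """Cover total_latent_frames with bucket sizes, always taking the largest that fits."""
    schedule = []
    remaining = total_latent_frames
    for size in sorted(buckets, reverse=True):
        count, remaining = divmod(remaining, size)
        schedule.extend([size] * count)
    if remaining:
        # Assumed policy: reject frame counts the compiled buckets cannot cover exactly.
        raise ValueError(f"{remaining} latent frame(s) left over for buckets {buckets}")
    return schedule


buckets = parse_stream_buckets("8,4,2")
schedule = build_greedy_chunk_schedule(14, buckets)  # one call per entry
```

Each entry in the schedule corresponds to one DiT call, so `len(schedule)` is the call count for a given clip length.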
- neuron_dit_forward now accepts explicit temporal_offset for multi-bucket scheduling (required when frame count varies between chunks)
- Add test_multi_bucket.py: compiles f=8,4,2 co-resident, benchmarks each bucket, simulates 1-min video with greedy scheduler
…r full-attention DiT

Tested f=8, f=4, f=2 co-resident on trn2.3xlarge (SDK 2.30). Per-frame latency scales super-linearly due to full temporal self-attention: f=8 is 77% slower/frame than f=2. Multi-bucket code retained as reference for windowed-attention models.

Also adds SDK 2.30 validation (22% DiT speedup, 12.6 FPS) to compatibility matrix.
    if pipeline.tcdecoder_model is not None and pipeline.tc_pixel_shuffle is not None:
        from .tcdecoder import TCDecoderApplication, decode_video_nxdi

        LQ_cur_idx = process_total_num * 8 + 21 if process_total_num > 0 else 21
I get a RuntimeError: Sizes of tensors must match... Expected 19, got 18 error, which seems to be caused by this line.
Also, your notebook seems to use a different formula to compute this. Can we converge on one (correct) formula?
    python -m src.download_weights --output-dir ~/FlashVSR-v1.1

    # 3. Compile all models (one-time, ~30 min)
    python -m src.pipeline compile \
pipeline.py does not contain a main function that would handle different args, so I'm not confident that these instructions would work.
    ## Known Issues

    - **Resolution constraint:** Input video must produce output dimensions divisible by 128 (e.g., 768x1280). Other resolutions require recompilation.
Can you re-compile for all resolutions? Compiling for a resolution larger than one you've validated might result in a compilation error for having too many instructions.
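For context, the constraint quoted from the README can be checked before any compilation is attempted. This is a hypothetical pre-flight helper, assuming the 4x upscale factor from the PR description; it is not part of the submitted code.

```python
def check_output_resolution(input_height: int, input_width: int,
                            scale: int = 4, multiple: int = 128) -> tuple:
    """Return the upscaled (height, width) or raise if either is not a multiple of 128."""
    out_h, out_w = input_height * scale, input_width * scale
    bad = [name for name, value in (("height", out_h), ("width", out_w)) if value % multiple]
    if bad:
        raise ValueError(
            f"Output {out_h}x{out_w} is not divisible by {multiple} in: {', '.join(bad)}"
        )
    return out_h, out_w


validated = check_output_resolution(192, 320)  # the 768x1280 case from the README
```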
Note: The below template includes items meant for model contributions only. For other contributions such as bug fixes, features, etc., only fill out the relevant portions of the form.
Description
FlashVSR is a video super-resolution model (4x upscaling) using a streaming DiT architecture based on Wan 2.1 1.3B. This contrib packages the DiT backbone for Neuron via NxDI ModelBuilder with TP=4, using NKI tiled flash attention (attention_cte) for 23040-token sequences that would otherwise OOM.

The pipeline processes video in overlapping chunks: a first chunk (6 latent frames → 24 output frames) followed by stream chunks (2 latent frames → 8 output frames each). Single-step DMD denoising enables efficient 4x upscaling at 768×1280 output resolution.
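The 23040-token figure and the chunk sizes can be reproduced with a little arithmetic. The 8x spatial VAE downsampling and 2x2 patch size below are assumptions (they are not stated in this PR), chosen because they yield the quoted sequence length; the chunk frame counts come from the description above.

```python
OUT_H, OUT_W = 768, 1280        # output resolution from the description
VAE_DOWNSAMPLE = 8              # assumed spatial compression of the latent space
PATCH = 2                       # assumed 2x2 spatial patchification in the DiT
FIRST_CHUNK_LATENT_FRAMES = 6   # from the description

tokens_per_latent_frame = (OUT_H // VAE_DOWNSAMPLE // PATCH) * (OUT_W // VAE_DOWNSAMPLE // PATCH)
first_chunk_tokens = tokens_per_latent_frame * FIRST_CHUNK_LATENT_FRAMES


def output_frames(num_stream_chunks: int) -> int:
    """Frames produced by the first chunk (24) plus each stream chunk (8), per the description."""
    return 24 + 8 * num_stream_chunks
```

Under these assumptions each latent frame contributes 48 x 80 tokens, which is why a dense S×S attention matrix over the first chunk is expensive to materialize.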
Model Information
Model Name: FlashVSR v1.1 (JunhaoZhuang/FlashVSR-v1.1)
Model Architecture: 30-layer DiT (dim=1536, 12 heads, head_dim=128) with factored 3D RoPE, LCSA self-attention, text cross-attention, AdaLN modulation, QK-norm with DistributedRMSNorm
Purpose: Video super-resolution (4x spatial upscaling, 480p → 1920p)
Checklist
Required Components
- Accuracy Test (test/integration/test_dit_accuracy.py): neuron_allclose(rtol=0.05, atol=0.1)
- README.md with the following sections:
- Source Code (src/):
  - modeling_flashvsr.py: NxDI-compatible DiT with Application/ModelWrapper/InferenceConfig (1242 lines)
  - pipeline.py: Full inference pipeline orchestration
  - tcdecoder.py: TCDecoder (latent → RGB) wrapper
  - lq_projection.py: LQ conditioning projection wrapper
  - weights.py: Weight format detection and conversion (DiffSynth/diffusers → Neuron)
  - download_weights.py: HuggingFace weight download utility
Optional Components
- test/integration/test_pipeline_e2e.py: PSNR validation
Folder Structure
Testing
How did you test this change?
Compiled and tested on trn2.3xlarge (LNC=2, 4 NeuronCores) with Neuron SDK 2.29.1 (DLAMI 20260502, neuronx-cc 2.24.8799.0, NxDI 0.9.17334). DiT first-chunk compiled via NxDI ModelBuilder (TP=4, BF16) and validated against CPU reference model with identical weights.
Test Results:
Compatibility
Tested with:
Additional Information
- attention_cte NKI kernel from nkilib for tiled flash attention (avoids materializing full S×S attention matrix in HBM)
- attn_mask input is unused in Phase 1 (dense attention); kept for future Phase 2 LCSA block-sparse support on larger instances (trn2.48xlarge TP=16)
Related Issues
None.
vLLM Integration
Not applicable — FlashVSR is a video generation model, not an LLM.
By submitting this PR, I confirm that: