Skip to content

Conversation

@lawrence-cj
Copy link
Contributor

@lawrence-cj lawrence-cj commented Nov 26, 2025

This PR supports LongSANA: a minute-length real-time video generation model

Related links:

project: https://nvlabs.github.io/Sana/Video
code: https://github.com/NVlabs/Sana
paper: https://arxiv.org/pdf/2509.24695

PR feature:

LongSANA uses Causal Linear Attention KV Cache during inference, which is crucial for long video generation(FlashAttention may need other PR). This PR adds Causal computation logi for both Linear Attention and Mix-FFN (Conv in MLP)

Added classes and functions

  1. add SanaVideoCausalTransformerBlock and SanaVideoCausalTransformer3DModel;
  2. add LongSanaVideoPipeline for Linear Attention KV-Cache;
  3. support LongSANA converting from pth to diffusers safetensor;

Cc: @sayakpaul @dg845
Co-author: @HeliosZhao

Code snap:

from diffusers import LongSanaVideoPipeline
from diffusers.utils import export_to_video

pipe = LongSanaVideoPipeline.from_pretrained("Efficient-Large-Model/SANA-Video_2B_480p_LongLive_diffusers", torch_dtype=torch.bfloat16)

pipe.scheduler = FlowMatchEulerDiscreteScheduler()
pipe.vae.to(torch.float32)
pipe.text_encoder.to(torch.bfloat16)
pipe.to("cuda")

prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

video = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=480,
    width=832,
    frames=161,
    guidance_scale=1.0,
    timesteps=[1000, 960, 889, 727, 0],  # Multi-step denoising per chunk
    generator=torch.Generator(device="cuda").manual_seed(42),
).frames[0]
export_to_video(video, "longsana.mp4", fps=16)

lawrence-cj and others added 4 commits November 26, 2025 07:32
1. add `SanaVideoCausalTransformerBlock` and `SanaVideoCausalTransformer3DModel`;
2. add `LongSanaVideoPipeline` for Linear Attention KV-Cache;
3. support LongSANA converting from pth to diffusers safetensor;
Co-authored-by: Yuyang Zhao <43061147+HeliosZhao@users.noreply.github.com>
@sayakpaul sayakpaul requested review from DN6, dg845 and yiyixuxu November 26, 2025 15:48
@sayakpaul
Copy link
Member

FlashAttention may need other PR

We can actually leverage our attention backends:
https://huggingface.co/docs/diffusers/main/en/optimization/attention_backends

@lawrence-cj
Copy link
Contributor Author

FlashAttention may need other PR

We can actually leverage our attention backends: https://huggingface.co/docs/diffusers/main/en/optimization/attention_backends

Is KV cache is supported in any backends?
Actually, in my PR, the kv-cache part is not well organized. So we do need your kind help to do it better to match diffusers style.

@lawrence-cj
Copy link
Contributor Author

Gentle ping @dg845

@dg845
Copy link
Collaborator

dg845 commented Dec 5, 2025

Hi @lawrence-cj, is the Efficient-Large-Model/SANA-Video_2B_480p_LongLive_diffusers model available on HF Hub? If I try the sample code above, I get an error when trying to load the checkpoint with LongSanaVideoPipeline.from_pretrained. On the hub, I see that there is a Efficient-Large-Model/SANA-Video_2B_480p_LongLive repo but it doesn't look like there is a diffusers variant.

return hidden_states


class CachedGLUMBConvTemp(nn.Module):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are the additions of the LongSANA modeling blocks (CachedGLUMBConvTemp, SanaCausalLinearAttnProcessor1_0, SanaVideoCausalTransformerBlock, and SanaVideoCausalTransformer3DModel) to transformer_sana_video.py intended? It looks like all of these blocks are also defined in transformer_sana_video_causal.py and it doesn't look like any of the previous Sana Video models are being modified to use these blocks.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I must have forgotten to delete them.

@lawrence-cj
Copy link
Contributor Author

Hi @lawrence-cj, is the Efficient-Large-Model/SANA-Video_2B_480p_LongLive_diffusers model available on HF Hub? If I try the sample code above, I get an error when trying to load the checkpoint with LongSanaVideoPipeline.from_pretrained. On the hub, I see that there is a Efficient-Large-Model/SANA-Video_2B_480p_LongLive repo but it doesn't look like there is a diffusers variant.

There is a Efficient-Large-Model/SANA-Video_2B_480p_LongLive_diffusers for diffusers pipeline, but it's now private. Can you access it through internal API?

@dg845
Copy link
Collaborator

dg845 commented Dec 6, 2025

Hi @lawrence-cj, I don't think I can access it unless I'm specifically given permission (for example, via a read access token).

Copy link
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I cleaned up & refactored the KV cache implementation in this commit: 8b177ff

Please take a look and feel free to cherry-pick if it works for you. Main changes are

  • Created SanaBlockKvCache to replace [None, None, None] - It is just more readable and we can know exactly what we are caching
  • Created LongSanaKvCachclass for pipeline, to abstract away the logic related to initialize and accumulate kv_caches across different chunks inside the pipeline
  • kv_cache is always passed/returned (can be None when not used) -> this simplify the code a bit so that our input/output format is consistent
  • we have a enable_save flag on the cache class — this way we don't need to pass save_kv_cache argument through every layer, we just have to enable/disable inside pipeline

Additionally, do you think we should create a custom scheduler for SANA long video? There's quite a bit of logic in the pipeline that I think belongs in a scheduler instead.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants