9 changes: 9 additions & 0 deletions .Rbuildignore
@@ -1,3 +1,12 @@
^.*\.Rproj$
^\.Rproj\.user$
^CLAUDE\.md$
^\.github$
^\.claude$
^TORCHSCRIPT_MIGRATION\.md$
^cat\.png$
^cat2\.png$
^gambling_cat\.png$
^fyi\.md$
^man-md$
^inst/validation/\.venv$
44 changes: 44 additions & 0 deletions CLAUDE.md
@@ -330,6 +330,8 @@ See cornyverse CLAUDE.md for safetensors package setup (use cornball-ai fork unt
- [x] Pipeline integration (txt2vid_ltx2)
- [x] Video output utilities (save_video)
- [x] Weight loading from HuggingFace safetensors
- [x] Two-stage distilled pipeline (Wan2GP parity) - see learnings below
- [x] Latent upsampler (Conv3d ResBlocks + SpatialRationalResampler)

#### LTX-2 Weight Loading

@@ -488,6 +490,48 @@ text <- decode_bpe(tok, result$input_ids)
- Torch tensor output
- SentencePiece-style space markers (▁)

#### LTX-2 Pipeline Debugging (February 2026)

Pipeline produced noise instead of video. Four bugs found by comparing against diffusers reference using `pyrotechnics::py2r_file()`:

1. **Scheduler step treated velocity as x0** (CRITICAL): LTX-2 DiT predicts velocity directly. The step is `latents = latents + dt * model_output`, not the x0-to-velocity derivation.

2. **Missing latent denormalization** (CRITICAL): Must denormalize latents (`latents * std / scaling_factor + mean`) before `vae$decode()`. The VAE's `latents_mean`/`latents_std` buffers are loaded from `per_channel_statistics` in safetensors weights.

3. **VAE encoder negative indexing**: R's `[, -1,,,]` excludes the last channel; Python's `[:, -1:]` selects it. Fixed to `[, hidden_states$shape[2],,,, drop = FALSE]`.

4. **Text encoder skipped embedding layer**: Gemma3 outputs 49 hidden states (1 embedding + 48 layers). Code was taking `[2:length]` (48 states) but `text_proj_in_factor=49` expects all 49.

**Lesson**: Use `pyrotechnics::py2r_file()` to auto-convert reference Python to R for side-by-side comparison. The converted code isn't runnable but reveals algorithmic differences clearly.
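
Bugs 1 and 2 reduce to one-liners; a minimal R sketch for reference (hypothetical function names, plain numerics for illustration):

```r
# Bug 1: flow-matching Euler step -- model_output IS the velocity.
# dt is negative because sigmas run from 1 down to 0.
euler_velocity_step <- function(latents, model_output, sigma, sigma_next) {
  dt <- sigma_next - sigma
  latents + dt * model_output
}

# Bug 2: denormalize before vae$decode(); mean/std come from the VAE's
# per_channel_statistics buffers, scaling_factor from its config.
denormalize_latents <- function(latents, mean, std, scaling_factor) {
  latents * std / scaling_factor + mean
}
```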

#### LTX-2 Two-Stage Distilled Pipeline (February 2026)

Wan2GP's distilled pipeline uses two stages for higher quality output:

**Architecture:**
1. Stage 1: Denoise at half resolution (H/2, W/2) with 8 steps using `DISTILLED_SIGMA_VALUES`
2. Upsampler: Un-normalize latents → Conv3d ResBlocks + SpatialRationalResampler (2x) → Re-normalize
3. Stage 2: Add noise at `noise_scale=0.909375`, denoise at full resolution with 3 steps using `STAGE_2_DISTILLED_SIGMA_VALUES`

**Sigma schedules (hardcoded):**
- Stage 1: `[1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]`
- Stage 2: `[0.909375, 0.725, 0.421875, 0.0]`
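
Since the step rule is `latents + dt * velocity`, the per-step `dt` values fall straight out of these schedules; a quick sanity check in R:

```r
stage_1_sigmas <- c(1.0, 0.99375, 0.9875, 0.98125, 0.975,
                    0.909375, 0.725, 0.421875, 0.0)
stage_2_sigmas <- c(0.909375, 0.725, 0.421875, 0.0)

dt_1 <- diff(stage_1_sigmas)  # 8 negative steps, summing to -1
dt_2 <- diff(stage_2_sigmas)  # 3 steps; stage 2 reuses the tail of stage 1
```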

**Upsampler model details:**
- Weight file: `ltx-2-spatial-upscaler-x2-1.0.safetensors` (~950MB)
- Architecture: `dims=3` (Conv3d), `mid_channels=1024`, `in_channels=128`
- Uses `SpatialRationalResampler(scale=2.0)`: rearranges to per-frame 2D, applies Conv2d(1024→4096) + PixelShuffle(2), then back to 5D
- For scale=2.0: `num=2, den=1`, so BlurDownsample is identity (stride=1)
- Weight key `upsampler.blur_down.kernel` is a buffer, safely skipped during loading
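
The Conv2d/PixelShuffle pairing keeps the channel count stable: PixelShuffle(r) trades r^2 channels for an r-fold spatial upscale, which is why the conv expands 1024 to 4096. A sketch of the shape bookkeeping (plain arithmetic, not a runnable model):

```r
mid_channels <- 1024
r <- 2                          # SpatialRationalResampler scale
conv_out <- mid_channels * r^2  # Conv2d(1024 -> 4096)
after_shuffle <- conv_out / r^2 # PixelShuffle(2): back to 1024 channels,
                                # H and W each doubled
```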

**Stage 2 noise injection:**
```r
noise <- torch_randn_like(latents)
latents <- noise * noise_scale + latents * (1 - noise_scale)
```
Where `noise_scale = stage_2_sigmas[1] = 0.909375`.

**Resolution constraint:** For two-stage, resolution must be divisible by 64 (not 32).
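
The factor of 64 follows from stage 1 running at half resolution, which must itself satisfy the usual divisible-by-32 constraint; a hedged helper (hypothetical name):

```r
# two-stage runs stage 1 at (h/2, w/2), which must still be divisible
# by 32, so the full resolution needs 64 | h and 64 | w
valid_two_stage_resolution <- function(h, w) {
  (h %% 64 == 0) && (w %% 64 == 0)
}
```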
## R torch API Quirks

Important differences between R torch and Python PyTorch:
4 changes: 3 additions & 1 deletion DESCRIPTION
@@ -9,7 +9,7 @@ Description: A native R implementation of diffusion models providing a functiona
using text prompts through models like Stable Diffusion without Python dependencies.
The package provides a streamlined, idiomatic R experience with support for multiple
diffusion schedulers and device acceleration.
License: Apache License 2.0
License: Apache License (>= 2)
URL: https://github.com/cornball-ai/diffuseR
BugReports: https://github.com/cornball-ai/diffuseR/issues
Encoding: UTF-8
@@ -26,4 +26,6 @@ Suggests:
hfhub,
safetensors,
tinytest
Remotes:
cornball-ai/gpu.ctl
RoxygenNote: 7.3.3
3 changes: 3 additions & 0 deletions NAMESPACE
Expand Up @@ -40,6 +40,7 @@ export(load_int4_weights)
export(load_int4_weights_into_model)
export(load_ltx2_connectors)
export(load_ltx2_transformer)
export(load_ltx2_upsampler)
export(load_ltx2_vae)
export(load_model_component)
export(load_pipeline)
@@ -103,3 +104,5 @@ export(vocab_size)
export(vram_report)

S3method(print,bpe_tokenizer)

importFrom(utils,head)
5 changes: 5 additions & 0 deletions R/dit_ltx2.R
@@ -524,6 +524,11 @@ ltx2_video_transformer_3d_model <- torch::nn_module(
audio_hidden_states <- self$audio_proj_in(audio_hidden_states)

# 3. Prepare timestep embeddings
# Scale timesteps from [0,1] to [0, timestep_scale_multiplier] for sinusoidal embeddings
# (matches WanGP convention: model receives sigma in [0,1], scales internally)
timestep <- timestep * self$timestep_scale_multiplier
audio_timestep <- audio_timestep * self$timestep_scale_multiplier

timestep_cross_attn_gate_scale_factor <- self$cross_attn_timestep_scale_multiplier / self$timestep_scale_multiplier

# 3.1 Global timestep embedding
13 changes: 6 additions & 7 deletions R/gemma3_text_encoder.R
@@ -516,9 +516,9 @@ gemma3_text_model <- torch::nn_module(
# Final normalization
hidden_states <- self$norm(hidden_states)

if (output_hidden_states) {
all_hidden_states <- c(all_hidden_states, list(hidden_states))
}
# NOTE: do NOT append post-norm output to all_hidden_states.
# LTX-2 connectors expect exactly 49 states (embedding + 48 layers).
# The post-norm output is only returned as last_hidden_state.

list(
last_hidden_state = hidden_states,
@@ -852,13 +852,12 @@ encode_with_gemma3 <- function(
output <- model(input_ids, attention_mask = attention_mask, output_hidden_states = TRUE)
})

# Stack hidden states from transformer layers (skip embedding layer at index 1)
# Stack ALL hidden states (embedding + 48 transformer layers = 49 total)
hidden_states_list <- output$hidden_states
# hidden_states_list[[1]] is embedding layer, [[2]] onwards are transformer layers
# hidden_states_list[[1]] is embedding layer, [[2]]..[[49]] are transformer layers
# LTX-2 connectors expect 49 layers (text_proj_in_factor=49)
transformer_hidden_states <- hidden_states_list[2:length(hidden_states_list)]
# Stack: [batch, seq_len, hidden_size, num_layers]
hidden_states_stacked <- torch::torch_stack(transformer_hidden_states, dim = - 1L)
hidden_states_stacked <- torch::torch_stack(hidden_states_list, dim = -1L)

# Compute sequence lengths from attention mask
sequence_lengths <- as.integer(attention_mask$sum(dim = 2L)$cpu())