Skip to content

Commit ca99d98

Browse files
committed
Implement LTX-2 embedding connector
Signed-off-by: James Huang <syhuang1201@gmail.com>
1 parent cddbf6a commit ca99d98

File tree

2 files changed

+304
-0
lines changed

2 files changed

+304
-0
lines changed
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
from typing import Optional, Tuple
18+
import jax
19+
import jax.numpy as jnp
20+
from flax import nnx
21+
from maxdiffusion import common_types
22+
from maxdiffusion.models.ltx2.attention_ltx2 import LTX2Attention
23+
from maxdiffusion.models.attention_flax import NNXSimpleFeedForward
24+
25+
# Short aliases for the shared array/dtype types used in annotations below.
Array = common_types.Array
DType = common_types.DType
27+
28+
29+
class _BasicTransformerBlock1D(nnx.Module):
  """A single pre-norm transformer block over a 1D token sequence.

  Layout: RMSNorm -> self-attention -> residual add, then
  RMSNorm -> feed-forward -> residual add. Both norms are scale-free
  float32 RMSNorms; activations are cast back to the input dtype
  before being fed to the attention / feed-forward sub-layers.
  """

  def __init__(
      self,
      dim: int,
      heads: int,
      dim_head: int,
      rope_type: str = "interleaved",
      attention_kernel: str = "flash",
      mesh: jax.sharding.Mesh = None,
      rngs: nnx.Rngs = None,
  ):
    # NOTE: submodules are created in the same order as callers expect
    # (attn1, ff, norm1, norm2) so RNG key consumption is deterministic.
    self.attn1 = LTX2Attention(
        query_dim=dim,
        heads=heads,
        dim_head=dim_head,
        rope_type=rope_type,
        bias=True,  # LTX-2 default
        out_bias=True,
        attention_kernel=attention_kernel,
        mesh=mesh,
        rngs=rngs,
    )
    self.ff = NNXSimpleFeedForward(rngs=rngs, dim=dim, dim_out=dim)
    # Scale-free RMSNorm kept in float32 for numerical stability.
    norm_kwargs = dict(epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=False, rngs=rngs)
    self.norm1 = nnx.RMSNorm(dim, **norm_kwargs)
    self.norm2 = nnx.RMSNorm(dim, **norm_kwargs)

  def __call__(
      self,
      hidden_states: Array,
      attention_mask: Optional[Array] = None,
      rotary_emb: Optional[Tuple[Array, Array]] = None,
  ) -> Array:
    # Attention sub-layer: pre-norm, attend, add residual.
    attended = self.attn1(
        self.norm1(hidden_states).astype(hidden_states.dtype),
        attention_mask=attention_mask,
        rotary_emb=rotary_emb,
    )
    hidden_states = hidden_states + attended

    # Feed-forward sub-layer: pre-norm, project, add residual.
    projected = self.ff(self.norm2(hidden_states).astype(hidden_states.dtype))
    return hidden_states + projected
73+
74+
75+
class Embeddings1DConnector(nnx.Module):
  """
  Applies 1D transformer processing with Thinking Tokens (Learnable Registers).
  Uses nnx.scan for efficient JAX-idiomatic layer execution.

  Padded positions in the input are overwritten with a tiled table of
  learnable "register" embeddings (so every position holds content), the
  attention mask is then reset to all-ones, and the sequence is run
  through `layers` stacked transformer blocks with 1D rotary embeddings,
  followed by a final scale-free RMSNorm.
  """

  def __init__(
      self,
      input_dim: int,
      heads: int = 30,
      head_dim: int = 128,
      layers: int = 2,
      theta: float = 10000.0,
      num_learnable_registers: int = 128,
      rope_type: str = "interleaved",
      attention_kernel: str = "flash",
      mesh: jax.sharding.Mesh = None,
      rngs: nnx.Rngs = None,
  ):
    """Builds the connector.

    Args:
      input_dim: Channel width of the token embeddings (also the RoPE dim).
      heads: Number of attention heads per block.
      head_dim: Per-head channel width.
      layers: Number of stacked transformer blocks (scanned over).
      theta: RoPE base frequency.
      num_learnable_registers: Size of the learnable register table; 0
        disables register substitution entirely.
      rope_type: Rotary embedding layout passed to the attention layer.
      attention_kernel: Attention implementation ("flash", "dot_product", ...).
      mesh: Optional device mesh for sharded attention.
      rngs: NNX RNG container used for parameter initialization.
    """
    self.dim = input_dim
    self.theta = theta
    self.num_learnable_registers = num_learnable_registers
    self.num_layers = layers

    # 1. Initialize Stacked Layers using vmap
    # This creates a single module where parameters have an extra leading dimension [layers, ...]
    # We need to ensure rngs are split for each layer
    @nnx.split_rngs(splits=layers)
    @nnx.vmap(in_axes=0, out_axes=0, axis_size=layers)
    def create_block(rngs):
      return _BasicTransformerBlock1D(
          dim=input_dim,
          heads=heads,
          dim_head=head_dim,
          rope_type=rope_type,
          attention_kernel=attention_kernel,
          mesh=mesh,
          rngs=rngs,
      )

    # Call the vmapped constructor
    self.stacked_blocks = create_block(rngs)

    # 2. Thinking Tokens
    # Register table initialized uniformly in [-1, 1) as bfloat16.
    # NOTE(review): when num_learnable_registers == 0, `learnable_registers`
    # is never created; callers must also pass attention_mask=None or keep
    # the registers enabled (see __call__ guard).
    if num_learnable_registers > 0:
      key = rngs.params()
      self.learnable_registers = nnx.Param(
          jax.random.uniform(key, (num_learnable_registers, self.dim), dtype=jnp.bfloat16) * 2.0 - 1.0
      )

    # Final scale-free RMSNorm in float32 (matches the per-block norms).
    self.final_norm = nnx.RMSNorm(
        self.dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=False, rngs=rngs
    )

  def _replace_padded_with_learnable_registers(self, hidden_states: Array, attention_mask: Array) -> Tuple[Array, Array]:
    """Overwrites padded positions with tiled register embeddings.

    The [num_registers, dim] register table is tiled along the sequence
    axis to length t, and positions where the mask is 0 are replaced by
    the register value at the same sequence index. Returns the patched
    hidden states and a fresh all-ones mask of the same shape/dtype as
    `attention_mask`.

    Raises:
      ValueError: if the sequence length is not a multiple of
        num_learnable_registers (tiling would not line up).
    """
    b, t, d = hidden_states.shape
    if t % self.num_learnable_registers != 0:
      raise ValueError(f"Sequence length {t} must be divisible by {self.num_learnable_registers}")

    num_duplications = t // self.num_learnable_registers
    registers = jnp.tile(self.learnable_registers[...], (num_duplications, 1))
    # Add a broadcast batch axis: [1, t, d].
    registers = jnp.expand_dims(registers, 0)

    # Promote a [b, t] mask to [b, t, 1] so it broadcasts over channels.
    if attention_mask.ndim == 2:
      mask = attention_mask[:, :, None]
    else:
      mask = attention_mask

    # `> 0.5` makes the test robust to int or float masks.
    output = jnp.where(mask > 0.5, hidden_states, registers)
    new_mask = jnp.ones_like(attention_mask)
    return output, new_mask

  def _compute_1d_rope(self, seq_len: int, dtype: DType) -> Tuple[Array, Array]:
    """Computes (cos, sin) 1D rotary tables of shape [1, seq_len, dim].

    Frequencies follow the standard RoPE recipe (theta^(-2i/dim)); each
    frequency is repeated twice along the channel axis to match the
    interleaved pair layout.

    NOTE(review): the `dtype` argument is currently unused — the tables
    are always returned in float32. Confirm whether the attention layer
    expects float32 RoPE or whether a cast to `dtype` was intended.
    """
    t = jnp.arange(seq_len, dtype=jnp.float32)
    freqs = 1.0 / (self.theta ** (jnp.arange(0, self.dim, 2, dtype=jnp.float32) / self.dim))
    emb = jnp.outer(t, freqs)
    cos = jnp.cos(emb)
    sin = jnp.sin(emb)
    cos = jnp.repeat(cos, 2, axis=-1)
    sin = jnp.repeat(sin, 2, axis=-1)
    return cos[None, ...], sin[None, ...]

  def __call__(
      self,
      hidden_states: Array,
      attention_mask: Optional[Array] = None,
  ) -> Array:
    """Runs register substitution, RoPE, the scanned blocks, and final norm.

    Args:
      hidden_states: [batch, seq, dim] token embeddings.
      attention_mask: Optional [batch, seq] (or broadcastable) validity
        mask; when provided together with registers, padding is replaced
        and the mask becomes all-ones before attention.

    Returns:
      [batch, seq, dim] processed embeddings.
    """
    # 1. Thinking Tokens
    if self.num_learnable_registers > 0 and attention_mask is not None:
      hidden_states, attention_mask = self._replace_padded_with_learnable_registers(hidden_states, attention_mask)

    # 2. RoPE
    seq_len = hidden_states.shape[1]
    rotary_emb = self._compute_1d_rope(seq_len, hidden_states.dtype)

    # 3. Transformer Blocks (Scan)

    # Scan function signature: (carry, x) -> (carry, y)
    def block_scan_fn(carry, block_module):
      hidden_states = carry
      # block_module is a sliced view of the vmapped module
      hidden_states = block_module(hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb)
      return hidden_states, None

    # Execute scan
    hidden_states, _ = nnx.scan(
        block_scan_fn,
        length=self.num_layers,
        in_axes=(nnx.Carry, 0),  # Scan over the layers dimension (0) of block_module
        out_axes=(nnx.Carry, 0),
    )(hidden_states, self.stacked_blocks)

    # 4. Final Norm
    hidden_states = self.final_norm(hidden_states)

    return hidden_states
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import unittest
18+
import jax.numpy as jnp
19+
import numpy as np
20+
from flax import nnx
21+
from ..models.ltx2.text_encoders.embeddings_connector_ltx2 import Embeddings1DConnector
22+
23+
24+
class Embeddings1DConnectorTest(unittest.TestCase):
  """Tests for Embeddings1DConnector: register substitution and forward pass."""

  def setUp(self):
    self.rng = nnx.Rngs(0)
    # Batch, sequence length (must divide evenly by the register count),
    # and inner dim (== heads * head_dim).
    self.B, self.T, self.D = 2, 16, 64
    self.num_learnable_registers = 8
    self.heads, self.head_dim = 4, 16

  def _build_connector(self, num_layers, **extra):
    # Shared constructor helper so each test only states what differs.
    return Embeddings1DConnector(
        input_dim=self.D,
        heads=self.heads,
        head_dim=self.head_dim,
        layers=num_layers,
        num_learnable_registers=self.num_learnable_registers,
        mesh=None,
        rngs=self.rng,
        **extra,
    )

  def test_thinking_tokens_replacement(self):
    connector = self._build_connector(num_layers=1)

    # All-zero input [B, T, D] makes "unchanged" easy to assert.
    tokens = jnp.zeros((self.B, self.T, self.D))

    # Batch 0 keeps 4 valid positions, batch 1 keeps 8; the rest is padding.
    mask = np.zeros((self.B, self.T), dtype=np.int32)
    mask[0, :4] = 1
    mask[1, :8] = 1

    # Run the replacement method directly (unit-level, not via __call__).
    output, refreshed_mask = connector._replace_padded_with_learnable_registers(tokens, jnp.array(mask))

    # 1. The mask must be reset to all-valid after substitution.
    self.assertTrue(jnp.all(refreshed_mask == 1.0), "New mask should be all 1s")

    # 2. Valid positions are untouched (inputs were zeros).
    self.assertTrue(jnp.all(output[0, :4, :] == 0.0), "Valid tokens should remain unchanged")

    # 3. Padding positions must equal the register table tiled along the
    # sequence axis: [8, 64] tiled T/8 = 2 times -> [16, 64].
    tiled_registers = jnp.tile(connector.learnable_registers[...], (self.T // self.num_learnable_registers, 1))
    np.testing.assert_allclose(
        output[0, 4:, :], tiled_registers[4:, :], err_msg="Padding should be replaced by corresponding register values"
    )
    print("\n[PASS] Thinking Tokens Replacement Logic Verified.")

  def test_forward_shape_and_run(self):
    # dot_product attention so the test also runs on CPU.
    connector = self._build_connector(num_layers=2, attention_kernel="dot_product")

    inputs = jnp.array(np.random.randn(self.B, self.T, self.D))
    full_mask = jnp.ones((self.B, self.T))  # every position valid

    result = connector(inputs, full_mask)

    self.assertEqual(result.shape, (self.B, self.T, self.D))
    self.assertFalse(jnp.isnan(result).any(), "Output should not contain NaNs")
    print("\n[PASS] Embeddings1DConnector Forward Pass Verified.")
111+
112+
113+
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
  unittest.main()

0 commit comments

Comments
 (0)