Skip to content

Commit 7747a33

Browse files
committed
Implement embedding connector
Signed-off-by: James Huang <syhuang1201@gmail.com>
1 parent 6e3b58b commit 7747a33

File tree

2 files changed

+305
-0
lines changed

2 files changed

+305
-0
lines changed
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
from typing import Optional, Tuple
18+
import jax
19+
import jax.numpy as jnp
20+
from flax import nnx
21+
from maxdiffusion import common_types
22+
from maxdiffusion.models.ltx2.attention_ltx2 import LTX2Attention
23+
from maxdiffusion.models.attention_flax import NNXSimpleFeedForward
24+
25+
Array = common_types.Array
26+
DType = common_types.DType
27+
28+
29+
class _BasicTransformerBlock1D(nnx.Module):
  """Pre-norm transformer block for 1D sequences.

  Structure: RMSNorm -> self-attention -> residual add, then
  RMSNorm -> feed-forward -> residual add.
  """

  def __init__(
      self,
      dim: int,
      heads: int,
      dim_head: int,
      rope_type: str = "interleaved",
      attention_kernel: str = "flash",
      mesh: jax.sharding.Mesh = None,
      rngs: nnx.Rngs = None,
  ):
    # Self-attention over the sequence axis.
    self.attn1 = LTX2Attention(
        query_dim=dim,
        heads=heads,
        dim_head=dim_head,
        rope_type=rope_type,
        bias=True,  # LTX-2 default
        out_bias=True,
        attention_kernel=attention_kernel,
        mesh=mesh,
        rngs=rngs,
    )
    self.ff = NNXSimpleFeedForward(rngs=rngs, dim=dim, dim_out=dim)
    # Pre-attention / pre-MLP norms; computed and stored in float32.
    self.norm1 = nnx.RMSNorm(dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=True, rngs=rngs)
    self.norm2 = nnx.RMSNorm(dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=True, rngs=rngs)

  def __call__(
      self,
      hidden_states: Array,
      attention_mask: Optional[Array] = None,
      rotary_emb: Optional[Tuple[Array, Array]] = None,
  ) -> Array:
    # Attention sub-layer with residual connection.
    residual = hidden_states
    attn_out = self.attn1(self.norm1(residual), attention_mask=attention_mask, rotary_emb=rotary_emb)
    residual = residual + attn_out

    # Feed-forward sub-layer with residual connection.
    residual = residual + self.ff(self.norm2(residual))
    return residual
73+
74+
75+
class Embeddings1DConnector(nnx.Module):
  """
  Applies 1D transformer processing with Thinking Tokens (Learnable Registers).
  Uses nnx.scan for efficient JAX-idiomatic layer execution.

  Padded (masked-out) positions of the input are overwritten with tiled
  learnable register embeddings, the mask is reset to all-ones so the
  registers participate in attention, and the sequence is then processed
  by a scanned stack of 1D transformer blocks followed by a final RMSNorm.
  """

  def __init__(
      self,
      input_dim: int,
      heads: int = 30,
      head_dim: int = 128,
      layers: int = 2,
      theta: float = 10000.0,
      num_learnable_registers: int = 128,
      rope_type: str = "interleaved",
      attention_kernel: str = "flash",
      mesh: jax.sharding.Mesh = None,
      rngs: nnx.Rngs = None,
  ):
    """
    Args:
      input_dim: Feature dimension of the incoming embeddings (block width).
      heads: Attention heads per transformer block.
      head_dim: Dimension of each attention head.
      layers: Number of stacked transformer blocks (scanned over at call time).
      theta: Base frequency for the 1D rotary position embedding.
      num_learnable_registers: Number of "thinking token" register vectors;
        0 disables register substitution entirely.
      rope_type: RoPE layout forwarded to the attention layer.
      attention_kernel: Attention implementation selector (e.g. "flash").
      mesh: Optional device mesh forwarded to the attention layer.
      rngs: NNX random-number generators for parameter initialization.
    """
    self.dim = input_dim
    self.theta = theta
    self.num_learnable_registers = num_learnable_registers
    self.num_layers = layers

    # 1. Initialize Stacked Layers using vmap
    # This creates a single module where parameters have an extra leading dimension [layers, ...]
    # We need to ensure rngs are split for each layer
    @nnx.split_rngs(splits=layers)
    @nnx.vmap(in_axes=0, out_axes=0, axis_size=layers)
    def create_block(rngs):
      return _BasicTransformerBlock1D(
          dim=input_dim,
          heads=heads,
          dim_head=head_dim,
          rope_type=rope_type,
          attention_kernel=attention_kernel,
          mesh=mesh,
          rngs=rngs,
      )

    # Call the vmapped constructor
    self.stacked_blocks = create_block(rngs)

    # 2. Thinking Tokens
    # Registers are initialized uniformly in [-1, 1) in bfloat16.
    if num_learnable_registers > 0:
      key = rngs.params()
      self.learnable_registers = nnx.Param(
          jax.random.uniform(key, (num_learnable_registers, self.dim), dtype=jnp.bfloat16) * 2.0 - 1.0
      )

    self.final_norm = nnx.RMSNorm(
        self.dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=True, rngs=rngs
    )

  def _replace_padded_with_learnable_registers(self, hidden_states: Array, attention_mask: Array) -> Tuple[Array, Array]:
    """
    Replace padded positions with tiled learnable register values.

    Args:
      hidden_states: [batch, seq, dim] token embeddings.
      attention_mask: [batch, seq] or [batch, seq, 1] mask; values > 0.5 mark
        valid tokens that are kept unchanged.

    Returns:
      Tuple of (output, new_mask): embeddings with padding overwritten by the
      registers (tiled along the sequence axis to length ``seq``), and an
      all-ones mask with the same shape/dtype as ``attention_mask``.

    Raises:
      ValueError: If ``seq`` is not a multiple of ``num_learnable_registers``.
    """
    # b and d are unused; only t (sequence length) drives the tiling below.
    b, t, d = hidden_states.shape
    if t % self.num_learnable_registers != 0:
      raise ValueError(f"Sequence length {t} must be divisible by {self.num_learnable_registers}")

    # Repeat the [num_registers, dim] table until it covers the full sequence,
    # then add a leading batch axis so it broadcasts across the batch.
    num_duplications = t // self.num_learnable_registers
    registers = jnp.tile(self.learnable_registers[...], (num_duplications, 1))
    registers = jnp.expand_dims(registers, 0)

    # Normalize the mask to [batch, seq, 1] so it broadcasts over features.
    if attention_mask.ndim == 2:
      mask = attention_mask[:, :, None]
    else:
      mask = attention_mask

    # Keep valid tokens; substitute registers at padded positions.
    output = jnp.where(mask > 0.5, hidden_states, registers)
    new_mask = jnp.ones_like(attention_mask)
    return output, new_mask

  def _compute_1d_rope(self, seq_len: int, dtype: DType) -> Tuple[Array, Array]:
    """
    Compute (cos, sin) rotary embeddings for positions [0, seq_len).

    Frequencies follow the standard RoPE schedule 1 / theta^(2i/dim); each
    frequency is repeated twice along the last axis to match an interleaved
    pair layout. Both returned arrays have shape [1, seq_len, dim]
    (assumes ``self.dim`` is even — TODO confirm).

    NOTE(review): the ``dtype`` parameter is currently unused; cos/sin are
    computed and returned in float32. Confirm whether a cast to the
    activation dtype was intended.
    """
    t = jnp.arange(seq_len, dtype=jnp.float32)
    freqs = 1.0 / (self.theta ** (jnp.arange(0, self.dim, 2, dtype=jnp.float32) / self.dim))
    emb = jnp.outer(t, freqs)
    cos = jnp.cos(emb)
    sin = jnp.sin(emb)
    # Duplicate each frequency so cos/sin cover all `dim` channels.
    cos = jnp.repeat(cos, 2, axis=-1)
    sin = jnp.repeat(sin, 2, axis=-1)
    return cos[None, ...], sin[None, ...]

  def __call__(
      self,
      hidden_states: Array,
      attention_mask: Optional[Array] = None,
  ) -> Array:
    """
    Run the connector over a batch of embeddings.

    Args:
      hidden_states: [batch, seq, dim] input embeddings.
      attention_mask: Optional validity mask; register substitution only
        happens when a mask is provided and registers are enabled.

    Returns:
      [batch, seq, dim] processed embeddings after the scanned transformer
      stack and final RMSNorm.
    """
    # 1. Thinking Tokens
    if self.num_learnable_registers > 0 and attention_mask is not None:
      hidden_states, attention_mask = self._replace_padded_with_learnable_registers(hidden_states, attention_mask)

    # 2. RoPE
    seq_len = hidden_states.shape[1]
    rotary_emb = self._compute_1d_rope(seq_len, hidden_states.dtype)

    # 3. Transformer Blocks (Scan)

    # Scan function signature: (carry, x) -> (carry, y)
    def block_scan_fn(carry, block_module):
      hidden_states = carry
      # block_module is a sliced view of the vmapped module
      hidden_states = block_module(hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb)
      return hidden_states, None

    # Execute scan
    hidden_states, _ = nnx.scan(
        block_scan_fn,
        length=self.num_layers,
        in_axes=(nnx.Carry, 0),  # Scan over the layers dimension (0) of block_module
        out_axes=(nnx.Carry, 0),
    )(hidden_states, self.stacked_blocks)

    # 4. Final Norm
    hidden_states = self.final_norm(hidden_states)

    return hidden_states
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import unittest
18+
import jax
19+
import jax.numpy as jnp
20+
import numpy as np
21+
from flax import nnx
22+
from ..models.ltx2.text_encoders.embeddings_connector_ltx2 import Embeddings1DConnector
23+
24+
25+
class Embeddings1DConnectorTest(unittest.TestCase):
  """Tests for Embeddings1DConnector: register substitution and full forward pass."""

  def setUp(self):
    self.rng = nnx.Rngs(0)
    self.B = 2
    self.T = 16  # Must be divisible by num_learnable_registers if we want tiling to work simply
    self.D = 64  # inner_dim

    # Test config
    self.num_learnable_registers = 8
    self.heads = 4
    self.head_dim = 16

    # input dim = heads * head_dim = 64

  def test_thinking_tokens_replacement(self):
    connector = Embeddings1DConnector(
        input_dim=self.D,
        heads=self.heads,
        head_dim=self.head_dim,
        layers=1,
        num_learnable_registers=self.num_learnable_registers,
        mesh=None,
        rngs=self.rng,
    )

    # All-zero input [B, T, D] so untouched positions are trivially recognizable.
    hidden_states = jnp.zeros((self.B, self.T, self.D))

    # Mask [B, T]: batch 0 keeps 4 valid tokens, batch 1 keeps 8; rest is padding.
    valid_lengths = (4, 8)
    mask = np.zeros((self.B, self.T), dtype=np.int32)
    for row, n_valid in enumerate(valid_lengths):
      mask[row, :n_valid] = 1

    # Explicitly run replacement method
    output, new_mask = connector._replace_padded_with_learnable_registers(hidden_states, jnp.array(mask))

    # 1. The returned mask must be fully reset.
    self.assertTrue(jnp.all(new_mask == 1.0), "New mask should be all 1s")

    # 2. Valid tokens must pass through untouched (input was all zeros).
    self.assertTrue(jnp.all(output[0, :4, :] == 0.0), "Valid tokens should remain unchanged")

    # 3. Padding positions must carry the tiled register values.
    #    Registers are [8, 64]; with T=16 they are tiled twice to [16, 64].
    registers_val = connector.learnable_registers[...]  # [8, 64]
    tiled_regs = jnp.tile(registers_val, (2, 1))  # [16, 64]

    np.testing.assert_allclose(
        output[0, 4:, :], tiled_regs[4:, :], err_msg="Padding should be replaced by corresponding register values"
    )
    print("\n[PASS] Thinking Tokens Replacement Logic Verified.")

  def test_forward_shape_and_run(self):
    connector = Embeddings1DConnector(
        input_dim=self.D,
        heads=self.heads,
        head_dim=self.head_dim,
        layers=2,
        num_learnable_registers=self.num_learnable_registers,
        attention_kernel="dot_product",  # Use dot_product for testing on CPU
        mesh=None,
        rngs=self.rng,
    )

    inputs = jnp.array(np.random.randn(self.B, self.T, self.D))
    full_mask = jnp.ones((self.B, self.T))  # All valid

    result = connector(inputs, full_mask)

    # Shape is preserved end-to-end and the output is numerically sane.
    self.assertEqual(result.shape, (self.B, self.T, self.D))
    self.assertFalse(jnp.isnan(result).any(), "Output should not contain NaNs")
    print("\n[PASS] Embeddings1DConnector Forward Pass Verified.")
112+
113+
114+
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
  unittest.main()

0 commit comments

Comments
 (0)