Skip to content

Commit ce7d1f4

Browse files
committed
feat: LTX-2 feature extractor
Signed-off-by: James Huang <syhuang1201@gmail.com>
1 parent 85ba65e commit ce7d1f4

File tree

3 files changed

+293
-0
lines changed

3 files changed

+293
-0
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
from typing import Tuple, Union
18+
import jax.numpy as jnp
19+
from flax import nnx
20+
from maxdiffusion import common_types
21+
22+
Array = common_types.Array
23+
DType = common_types.DType
24+
25+
26+
def _norm_and_concat_padded_batch(
27+
encoded_text: Array,
28+
sequence_lengths: Array,
29+
padding_side: str = "right",
30+
) -> Array:
31+
"""Normalize and flatten multi-layer hidden states, respecting padding.
32+
Performs per-batch, per-layer normalization using masked mean and range,
33+
then concatenates across the layer dimension.
34+
35+
Args:
36+
encoded_text: Hidden states of shape [batch, seq_len, hidden_dim, num_layers].
37+
sequence_lengths: Number of valid (non-padded) tokens per batch item.
38+
padding_side: Whether padding is on "left" or "right".
39+
40+
Returns:
41+
Normalized tensor of shape [batch, seq_len, hidden_dim * num_layers],
42+
with padded positions zeroed out.
43+
"""
44+
b, t, d, l = encoded_text.shape
45+
46+
# Build mask: [B, T] -> [B, T, 1, 1]
47+
# token_indices: [1, T]
48+
token_indices = jnp.arange(t)[None, :]
49+
50+
if padding_side == "right":
51+
# Valid: indices < lengths
52+
mask = token_indices < sequence_lengths[:, None]
53+
elif padding_side == "left":
54+
# Valid: indices >= (T - lengths)
55+
start_indices = t - sequence_lengths[:, None]
56+
mask = token_indices >= start_indices
57+
else:
58+
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
59+
60+
# [B, T, 1, 1]
61+
mask = mask[:, :, None, None]
62+
63+
eps = 1e-6
64+
65+
# 1. Compute Masked Mean
66+
# Masked sum: [B, 1, 1, L] (sum over T, D)
67+
# Using jnp.where to zero-out padding
68+
masked_text = jnp.where(mask, encoded_text, 0.0)
69+
sum_vals = jnp.sum(masked_text, axis=(1, 2), keepdims=True)
70+
71+
# Denom: sequence_length * D
72+
denom = (sequence_lengths * d).reshape(b, 1, 1, 1)
73+
mean = sum_vals / (denom + eps)
74+
75+
# 2. Compute Masked Min/Max for Range
76+
# Use jnp.inf / -jnp.inf for padding to ignore them in min/max
77+
safe_text_min = jnp.where(mask, encoded_text, jnp.inf)
78+
safe_text_max = jnp.where(mask, encoded_text, -jnp.inf)
79+
80+
x_min = jnp.min(safe_text_min, axis=(1, 2), keepdims=True)
81+
x_max = jnp.max(safe_text_max, axis=(1, 2), keepdims=True)
82+
83+
range_val = x_max - x_min
84+
85+
# 3. Normalize
86+
# Only valid tokens are normalized. Padding will be garbage but masked out later.
87+
normed = 8.0 * (encoded_text - mean) / (range_val + eps)
88+
89+
# 4. Concatenate/Flatten Layers
90+
# [B, T, D, L] -> [B, T, D * L]
91+
normed = normed.reshape(b, t, -1)
92+
93+
# 5. Apply Mask to Output
94+
# Ensure padding positions are exactly 0.0
95+
# mask: [B, T, 1, 1] -> [B, T, 1]
96+
output_mask = mask.squeeze(-1).squeeze(-1)[:, :, None]
97+
normed = jnp.where(output_mask, normed, 0.0)
98+
99+
return normed
100+
101+
102+
class LTX2GemmaFeatureExtractor(nnx.Module):
  """Feature extractor for Gemma hidden states in LTX-2.

  Normalizes multi-layer hidden states (mean-centered, range-scaled, with
  padding respected) and projects them down to the diffusion conditioning
  dimension with a bias-free linear layer.
  """

  def __init__(
      self,
      input_dim: int,
      output_dim: int,
      dtype: DType = jnp.float32,
      rngs: nnx.Rngs = None,
  ):
    """Builds the projection layer.

    Args:
      input_dim: Dimension of flattened hidden states (Gemma dim * Num layers).
      output_dim: Target dimension for diffusion conditioning.
      dtype: Computation dtype of the linear projection.
      rngs: RNG collection used to initialize the projection weights.
    """
    # LTX-2's reference projection omits the bias term.
    self.linear = nnx.Linear(input_dim, output_dim, use_bias=False, dtype=dtype, rngs=rngs)

  def __call__(
      self, hidden_states: Union[Tuple[Array, ...], Array], attention_mask: Array, padding_side: str = "right"
  ) -> Array:
    """Normalizes, flattens, and projects Gemma hidden states.

    Args:
      hidden_states: Tuple/list of per-layer arrays, each [B, T, D], or a
        pre-stacked array [B, T, D, L].
      attention_mask: Mask [B, T] (1 for valid tokens, 0 for padding).
      padding_side: "right" or "left".

    Returns:
      Projected features of shape [B, T, OutputDim].
    """
    # Per-layer outputs may arrive as a sequence; stack them into [B, T, D, L].
    stacked = (
        jnp.stack(hidden_states, axis=-1)
        if isinstance(hidden_states, (tuple, list))
        else hidden_states
    )

    # Valid-token counts per batch item, derived from the attention mask.
    lengths = jnp.sum(attention_mask, axis=-1)

    # Masked normalization + layer flattening, then the linear projection.
    features = _norm_and_concat_padded_batch(stacked, lengths, padding_side=padding_side)
    return self.linear(features)
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import unittest
18+
import torch
19+
import numpy as np
20+
import jax.numpy as jnp
21+
from flax import nnx
22+
23+
from ..models.ltx2.text_encoders.feature_extractor_ltx2 import LTX2GemmaFeatureExtractor, _norm_and_concat_padded_batch
24+
25+
26+
# ==========================================
27+
# PyTorch Reference Logic
28+
# ==========================================
29+
def pt_norm_and_concat_padded_batch(
    encoded_text: torch.Tensor,
    sequence_lengths: torch.Tensor,
    padding_side: str = "right",
) -> torch.Tensor:
  """PyTorch reference for the JAX `_norm_and_concat_padded_batch`.

  Performs per-batch, per-layer normalization using a masked mean and a
  masked min/max range computed over valid tokens only, then flattens the
  layer axis into the hidden axis.

  Args:
    encoded_text: Hidden states of shape [batch, seq_len, hidden_dim, num_layers].
    sequence_lengths: Number of valid (non-padded) tokens per batch item.
    padding_side: Whether padding is on "left" or "right".

  Returns:
    Tensor of shape [batch, seq_len, hidden_dim * num_layers] with padded
    positions set to 0.0.

  Raises:
    ValueError: If `padding_side` is neither "left" nor "right".
  """
  b, t, d, l = encoded_text.shape
  device = encoded_text.device

  token_indices = torch.arange(t, device=device)[None, :]
  if padding_side == "right":
    mask = token_indices < sequence_lengths[:, None]
  elif padding_side == "left":
    start_indices = t - sequence_lengths[:, None]
    mask = token_indices >= start_indices
  else:
    # Fix: was a bare `raise ValueError`; carry the same diagnostic message
    # as the JAX implementation this function is a reference for.
    raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")

  mask = mask[:, :, None, None]  # [B, T, 1, 1]

  eps = 1e-6
  # Masked mean over the token and hidden axes: [B, 1, 1, L].
  masked = encoded_text.masked_fill(~mask, 0.0)
  denom = (sequence_lengths * d).view(b, 1, 1, 1)
  mean = masked.sum(dim=(1, 2), keepdim=True) / (denom + eps)

  # Masked extrema: +/-inf fill so padded positions never win the min/max.
  x_min = encoded_text.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
  x_max = encoded_text.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
  range_ = x_max - x_min

  normed = 8 * (encoded_text - mean) / (range_ + eps)
  normed = normed.reshape(b, t, -1)

  # Zero out padded positions in the flattened output.
  mask_flattened = mask.view(b, t, 1).expand(-1, -1, d * l)
  normed = normed.masked_fill(~mask_flattened, 0.0)

  return normed
65+
66+
67+
class LTX2FeatureExtractorTest(unittest.TestCase):
  """Parity and forward-pass tests for the LTX-2 Gemma feature extractor."""

  def setUp(self):
    self.rng = nnx.Rngs(0)
    self.B = 2
    self.T = 10
    self.D = 8
    self.L = 3
    self.target_dim = 16

  def test_norm_parity(self):
    """JAX normalization must match the PyTorch reference on padded input."""
    features = np.random.randn(self.B, self.T, self.D, self.L).astype(np.float32)

    # Valid lengths per batch item: 5 and 8 out of 10 tokens.
    valid_counts = np.array([5, 8], dtype=np.int32)

    # Reference output via PyTorch.
    reference = pt_norm_and_concat_padded_batch(
        torch.from_numpy(features), torch.from_numpy(valid_counts)
    )

    # Output under test via the JAX implementation.
    candidate = _norm_and_concat_padded_batch(jnp.array(features), jnp.array(valid_counts))

    diff = np.abs(reference.numpy() - np.array(candidate)).max()
    print(f"\n[Norm Parity] Max Diff: {diff:.6f}")

    np.testing.assert_allclose(reference.numpy(), np.array(candidate), atol=1e-5)
    print("[PASS] Normalization Logic Parity Verified.")

  def test_module_forward(self):
    """Full-module forward pass: output shape plus zeroed padding."""
    extractor = LTX2GemmaFeatureExtractor(
        input_dim=self.D * self.L, output_dim=self.target_dim, rngs=self.rng
    )

    # Simulate per-layer Gemma outputs, each [B, T, D].
    layer_outputs = tuple(
        jnp.array(np.random.randn(self.B, self.T, self.D)) for _ in range(self.L)
    )

    # Attention mask [B, T]: item 0 keeps 5 tokens, item 1 keeps 8.
    np_mask = np.zeros((self.B, self.T), dtype=np.int32)
    np_mask[0, :5] = 1
    np_mask[1, :8] = 1

    result = extractor(layer_outputs, jnp.array(np_mask))

    self.assertEqual(result.shape, (self.B, self.T, self.target_dim))

    # The padded tail of batch item 0 (indices 5..) must be exactly zero.
    self.assertTrue(jnp.all(result[0, 5:, :] == 0.0), "Padding region should be zero")

    print("\n[PASS] Feature Extractor Module Forward Pass Verified.")
124+
125+
126+
# Allow running this test file directly with `python <file>`.
if __name__ == "__main__":
  unittest.main()

0 commit comments

Comments
 (0)