Skip to content

Commit 02dbc99

Browse files
committed
feat: add ltx2 text encoders wrappers
Signed-off-by: James Huang <syhuang1201@gmail.com>
1 parent 780b7fc commit 02dbc99

File tree

2 files changed

+253
-0
lines changed

2 files changed

+253
-0
lines changed
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
from typing import Tuple, Union, List
18+
import jax
19+
import jax.numpy as jnp
20+
from flax import nnx
21+
from maxdiffusion import common_types
22+
23+
from .feature_extractor_ltx2 import LTX2GemmaFeatureExtractor
24+
from .embeddings_connector_ltx2 import Embeddings1DConnector
25+
26+
Array = common_types.Array
27+
DType = common_types.DType
28+
29+
30+
class LTX2VideoGemmaTextEncoder(nnx.Module):
  """
  Text-encoder wrapper for video-only tasks.

  Data flow: Gemma hidden states -> LTX2GemmaFeatureExtractor -> Embeddings1DConnector -> video embeddings.
  """

  def __init__(
      self,
      # Feature Extractor Config
      gemma_dim: int = 3840,  # Gemma-3-12b
      gemma_layers: int = 49,  # Gemma-3 has 48 layers + 1 embedding layer output = 49 hidden states
      projection_dim: int = 3840,  # LTX-2 conditioning dim
      # Connector Config
      # NOTE(review): 32 * 128 = 4096 differs from projection_dim (3840), and
      # LTX2AudioVideoGemmaTextEncoder in this file defaults to 30 heads -- confirm 32 is intended.
      connector_heads: int = 32,
      connector_head_dim: int = 128,
      connector_layers: int = 2,
      num_thinking_tokens: int = 128,
      dtype: DType = jnp.float32,
      attention_kernel: str = "flash",
      mesh: jax.sharding.Mesh = None,
      rngs: nnx.Rngs = None,
  ):
    """
    Builds the feature extractor and the single (video) connector.

    Args:
      gemma_dim: Width of one Gemma hidden state.
      gemma_layers: Number of hidden states supplied per call; the extractor input
        width is gemma_dim * gemma_layers.
      projection_dim: Output width of the feature extractor and input width of the connector.
      connector_heads: Forwarded to the connector as `heads`.
      connector_head_dim: Forwarded to the connector as `head_dim`.
      connector_layers: Forwarded to the connector as `layers`.
      num_thinking_tokens: Forwarded to the connector as `num_learnable_registers`.
      dtype: Forwarded to the feature extractor only; the connector is not given a dtype here.
      attention_kernel: Forwarded to the connector.
      mesh: Forwarded to the connector.
      rngs: Forwarded to both sub-modules.
    """
    # The extractor is created first, then the connector; both receive the same `rngs` object.
    extractor_kwargs = dict(
        input_dim=gemma_dim * gemma_layers,
        output_dim=projection_dim,
        dtype=dtype,
        rngs=rngs,
    )
    self.feature_extractor = LTX2GemmaFeatureExtractor(**extractor_kwargs)

    connector_kwargs = dict(
        input_dim=projection_dim,
        heads=connector_heads,
        head_dim=connector_head_dim,
        layers=connector_layers,
        num_learnable_registers=num_thinking_tokens,
        rope_type="interleaved",
        attention_kernel=attention_kernel,
        mesh=mesh,
        rngs=rngs,
    )
    self.embeddings_connector = Embeddings1DConnector(**connector_kwargs)

  def __call__(
      self,
      hidden_states: Union[Tuple[Array, ...], List[Array]],
      attention_mask: Array,
  ) -> Array:
    """
    Args:
      hidden_states: From Gemma output.hidden_states (Tuple of [B, T, D])
      attention_mask: [B, T]

    Returns:
      The connector output computed from the extracted features.
    """
    # Step 1: feature extraction (stack -> norm -> project, per the original author's note).
    extracted = self.feature_extractor(hidden_states, attention_mask)

    # Step 2: connector (refinement + thinking tokens), using the same attention mask.
    return self.embeddings_connector(extracted, attention_mask)
90+
91+
92+
class LTX2AudioVideoGemmaTextEncoder(nnx.Module):
  """
  Text-encoder wrapper for audio-video tasks.

  Data flow: Gemma hidden states -> shared LTX2GemmaFeatureExtractor -> two separate
  Embeddings1DConnector instances (video and audio).
  """

  def __init__(
      self,
      # Feature Extractor Config (Shared)
      gemma_dim: int = 3840,  # Gemma-3-12b
      gemma_layers: int = 49,  # Gemma-3 has 48 layers + 1 embedding layer output = 49 hidden states
      projection_dim: int = 3840,
      # Connector Config
      connector_heads: int = 30,
      connector_head_dim: int = 128,
      connector_layers: int = 2,
      num_thinking_tokens: int = 128,
      dtype: DType = jnp.float32,
      attention_kernel: str = "flash",
      mesh: jax.sharding.Mesh = None,
      rngs: nnx.Rngs = None,
  ):
    """
    Builds the shared feature extractor and the video / audio connectors.

    Args:
      gemma_dim: Width of one Gemma hidden state.
      gemma_layers: Number of hidden states supplied per call; the extractor input
        width is gemma_dim * gemma_layers.
      projection_dim: Output width of the feature extractor and input width of both connectors.
      connector_heads: Forwarded to both connectors as `heads`.
      connector_head_dim: Forwarded to both connectors as `head_dim`.
      connector_layers: Forwarded to both connectors as `layers`.
      num_thinking_tokens: Forwarded to both connectors as `num_learnable_registers`.
      dtype: Forwarded to the feature extractor only; the connectors are not given a dtype here.
      attention_kernel: Forwarded to both connectors.
      mesh: Forwarded to both connectors.
      rngs: Forwarded to all three sub-modules.
    """
    self.feature_extractor = LTX2GemmaFeatureExtractor(
        input_dim=gemma_dim * gemma_layers,
        output_dim=projection_dim,
        dtype=dtype,
        rngs=rngs,
    )

    # Both connectors share one configuration but are separate module instances.
    # Construction order (video, then audio) is kept as in the original.
    connector_kwargs = dict(
        input_dim=projection_dim,
        heads=connector_heads,
        head_dim=connector_head_dim,
        layers=connector_layers,
        num_learnable_registers=num_thinking_tokens,
        rope_type="interleaved",
        attention_kernel=attention_kernel,
        mesh=mesh,
        rngs=rngs,
    )
    self.video_embeddings_connector = Embeddings1DConnector(**connector_kwargs)
    self.audio_embeddings_connector = Embeddings1DConnector(**connector_kwargs)

  def __call__(
      self,
      hidden_states: Union[Tuple[Array, ...], List[Array]],
      attention_mask: Array,
  ) -> Tuple[Array, Array]:
    """
    Args:
      hidden_states: Gemma hidden states, passed unchanged to the feature extractor.
      attention_mask: Passed to the feature extractor and to both connectors.

    Returns:
      (video_embeds, audio_embeds)
    """
    # Features are extracted once and fed to both connectors.
    shared_features = self.feature_extractor(hidden_states, attention_mask)

    # Tuple elements evaluate left to right: video connector first, then audio.
    return (
        self.video_embeddings_connector(shared_features, attention_mask),
        self.audio_embeddings_connector(shared_features, attention_mask),
    )
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import unittest
18+
import jax.numpy as jnp
19+
import numpy as np
20+
from flax import nnx
21+
from ..models.ltx2.text_encoders.text_encoders_ltx2 import LTX2VideoGemmaTextEncoder, LTX2AudioVideoGemmaTextEncoder
22+
23+
24+
class LTX2TextEncodersTest(unittest.TestCase):
  """Forward-pass shape tests for the LTX-2 Gemma text encoder wrappers."""

  def setUp(self):
    """Creates small mock Gemma hidden states and an all-ones attention mask."""
    self.rng = nnx.Rngs(0)
    self.B = 2
    self.T = 16
    self.gemma_dim = 32
    self.gemma_layers = 3
    self.proj_dim = 64

    # Mock Gemma hidden states.
    # Drawn from a locally seeded generator (instead of the unseeded global
    # np.random state) so every run sees the same inputs and failures are reproducible.
    data_rng = np.random.default_rng(0)
    self.hidden_states = [
        jnp.array(data_rng.standard_normal((self.B, self.T, self.gemma_dim))) for _ in range(self.gemma_layers)
    ]

    self.attention_mask = jnp.ones((self.B, self.T))

  def test_video_encoder_forward(self):
    """The video-only encoder returns one array of shape [B, T, proj_dim]."""
    encoder = LTX2VideoGemmaTextEncoder(
        gemma_dim=self.gemma_dim,
        gemma_layers=self.gemma_layers,
        projection_dim=self.proj_dim,
        connector_heads=4,
        connector_head_dim=16,
        connector_layers=1,
        num_thinking_tokens=8,
        attention_kernel="dot_product",
        mesh=None,
        rngs=self.rng,
    )

    output = encoder(tuple(self.hidden_states), self.attention_mask)

    # Expected shape: [B, T, proj_dim]
    self.assertEqual(output.shape, (self.B, self.T, self.proj_dim))
    print("\n[PASS] Video Encoder Forward Pass Verified.")

  def test_av_encoder_forward(self):
    """The audio-video encoder returns two [B, T, proj_dim] arrays that are not equal."""
    encoder = LTX2AudioVideoGemmaTextEncoder(
        gemma_dim=self.gemma_dim,
        gemma_layers=self.gemma_layers,
        projection_dim=self.proj_dim,
        connector_heads=4,
        connector_head_dim=16,
        connector_layers=1,
        num_thinking_tokens=8,
        attention_kernel="dot_product",
        mesh=None,
        rngs=self.rng,
    )

    video_out, audio_out = encoder(tuple(self.hidden_states), self.attention_mask)

    # Expected shapes: Both [B, T, proj_dim]
    self.assertEqual(video_out.shape, (self.B, self.T, self.proj_dim))
    self.assertEqual(audio_out.shape, (self.B, self.T, self.proj_dim))

    # The two connectors are separate module instances, so their outputs are
    # expected to differ (this relies on them receiving different initial weights).
    self.assertFalse(
        jnp.allclose(video_out, audio_out), "Video and Audio outputs should differ due to different connector weights"
    )

    print("\n[PASS] Audio-Video Encoder Forward Pass Verified.")
86+
87+
88+
if __name__ == "__main__":
  # Run this module's test cases when the file is executed directly.
  unittest.main()

0 commit comments

Comments
 (0)