|
| 1 | +""" |
| 2 | +Copyright 2025 Google LLC |
| 3 | +
|
| 4 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +you may not use this file except in compliance with the License. |
| 6 | +You may obtain a copy of the License at |
| 7 | +
|
| 8 | + https://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +
|
| 10 | +Unless required by applicable law or agreed to in writing, software |
| 11 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +See the License for the specific language governing permissions and |
| 14 | +limitations under the License. |
| 15 | +""" |
| 16 | + |
| 17 | +from typing import Tuple, Union |
| 18 | +import jax.numpy as jnp |
| 19 | +from flax import nnx |
| 20 | +from maxdiffusion import common_types |
| 21 | + |
| 22 | +Array = common_types.Array |
| 23 | +DType = common_types.DType |
| 24 | + |
| 25 | + |
| 26 | +def _norm_and_concat_padded_batch( |
| 27 | + encoded_text: Array, |
| 28 | + sequence_lengths: Array, |
| 29 | + padding_side: str = "right", |
| 30 | +) -> Array: |
| 31 | + """Normalize and flatten multi-layer hidden states, respecting padding. |
| 32 | + Performs per-batch, per-layer normalization using masked mean and range, |
| 33 | + then concatenates across the layer dimension. |
| 34 | +
|
| 35 | + Args: |
| 36 | + encoded_text: Hidden states of shape [batch, seq_len, hidden_dim, num_layers]. |
| 37 | + sequence_lengths: Number of valid (non-padded) tokens per batch item. |
| 38 | + padding_side: Whether padding is on "left" or "right". |
| 39 | +
|
| 40 | + Returns: |
| 41 | + Normalized tensor of shape [batch, seq_len, hidden_dim * num_layers], |
| 42 | + with padded positions zeroed out. |
| 43 | + """ |
| 44 | + b, t, d, l = encoded_text.shape |
| 45 | + |
| 46 | + # Build mask: [B, T] -> [B, T, 1, 1] |
| 47 | + # token_indices: [1, T] |
| 48 | + token_indices = jnp.arange(t)[None, :] |
| 49 | + |
| 50 | + if padding_side == "right": |
| 51 | + # Valid: indices < lengths |
| 52 | + mask = token_indices < sequence_lengths[:, None] |
| 53 | + elif padding_side == "left": |
| 54 | + # Valid: indices >= (T - lengths) |
| 55 | + start_indices = t - sequence_lengths[:, None] |
| 56 | + mask = token_indices >= start_indices |
| 57 | + else: |
| 58 | + raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") |
| 59 | + |
| 60 | + # [B, T, 1, 1] |
| 61 | + mask = mask[:, :, None, None] |
| 62 | + |
| 63 | + eps = 1e-6 |
| 64 | + |
| 65 | + # 1. Compute Masked Mean |
| 66 | + # Masked sum: [B, 1, 1, L] (sum over T, D) |
| 67 | + # Using jnp.where to zero-out padding |
| 68 | + masked_text = jnp.where(mask, encoded_text, 0.0) |
| 69 | + sum_vals = jnp.sum(masked_text, axis=(1, 2), keepdims=True) |
| 70 | + |
| 71 | + # Denom: sequence_length * D |
| 72 | + denom = (sequence_lengths * d).reshape(b, 1, 1, 1) |
| 73 | + mean = sum_vals / (denom + eps) |
| 74 | + |
| 75 | + # 2. Compute Masked Min/Max for Range |
| 76 | + # Use jnp.inf / -jnp.inf for padding to ignore them in min/max |
| 77 | + safe_text_min = jnp.where(mask, encoded_text, jnp.inf) |
| 78 | + safe_text_max = jnp.where(mask, encoded_text, -jnp.inf) |
| 79 | + |
| 80 | + x_min = jnp.min(safe_text_min, axis=(1, 2), keepdims=True) |
| 81 | + x_max = jnp.max(safe_text_max, axis=(1, 2), keepdims=True) |
| 82 | + |
| 83 | + range_val = x_max - x_min |
| 84 | + |
| 85 | + # 3. Normalize |
| 86 | + # Only valid tokens are normalized. Padding will be garbage but masked out later. |
| 87 | + normed = 8.0 * (encoded_text - mean) / (range_val + eps) |
| 88 | + |
| 89 | + # 4. Concatenate/Flatten Layers |
| 90 | + # [B, T, D, L] -> [B, T, D * L] |
| 91 | + normed = normed.reshape(b, t, -1) |
| 92 | + |
| 93 | + # 5. Apply Mask to Output |
| 94 | + # Ensure padding positions are exactly 0.0 |
| 95 | + # mask: [B, T, 1, 1] -> [B, T, 1] |
| 96 | + output_mask = mask.squeeze(-1).squeeze(-1)[:, :, None] |
| 97 | + normed = jnp.where(output_mask, normed, 0.0) |
| 98 | + |
| 99 | + return normed |
| 100 | + |
| 101 | + |
class LTX2GemmaFeatureExtractor(nnx.Module):
  """
  Feature extractor module for Gemma models in LTX-2.
  Applies mean-centered, range-scaled normalization to the stacked hidden
  states and then a bias-free linear projection.
  """

  def __init__(
      self,
      input_dim: int,
      output_dim: int,
      dtype: DType = jnp.float32,
      rngs: Union[nnx.Rngs, None] = None,
  ):
    """
    Args:
      input_dim: Dimension of flattened hidden states (Gemma dim * Num layers).
      output_dim: Target dimension for diffusion conditioning.
      dtype: dtype forwarded to the linear projection.
      rngs: Random number generators used to initialize the projection
        weights. Required, despite the ``None`` default kept for signature
        compatibility.

    Raises:
      ValueError: If ``rngs`` is None.
    """
    # Fail early with a clear message instead of an opaque error from inside
    # the layer constructor.
    if rngs is None:
      raise ValueError("rngs must be provided to initialize LTX2GemmaFeatureExtractor")

    # LTX-2 uses bias=False for the projection
    self.linear = nnx.Linear(input_dim, output_dim, use_bias=False, dtype=dtype, rngs=rngs)

  def __call__(
      self, hidden_states: Union[Tuple[Array, ...], Array], attention_mask: Array, padding_side: str = "right"
  ) -> Array:
    """
    Args:
      hidden_states: Tuple (or list) of arrays from Gemma, each [B, T, D].
        Or pre-stacked array [B, T, D, L].
      attention_mask: Mask [B, T] (1 for valid, 0 for padding). Only its
        per-row sum is used, so the valid tokens are assumed to be contiguous
        on the side opposite to ``padding_side``.
      padding_side: "right" or "left".

    Returns:
      Projected features [B, T, OutputDim].
    """
    # 1. Stack hidden states along a new trailing layer axis if needed.
    if isinstance(hidden_states, (tuple, list)):
      # [B, T, D] * L -> [B, T, D, L]
      x = jnp.stack(hidden_states, axis=-1)
    else:
      x = hidden_states

    # 2. Number of valid tokens per batch item: [B]
    sequence_lengths = jnp.sum(attention_mask, axis=-1)

    # 3. Per-layer normalization, then flatten layers: [B, T, D * L]
    x_norm = _norm_and_concat_padded_batch(x, sequence_lengths, padding_side=padding_side)

    # 4. Projection: [B, T, D * L] -> [B, T, OutputDim]
    return self.linear(x_norm)
0 commit comments