Skip to content

Commit d189d05

Browse files
committed
Add LTX2 Vocoder
1 parent cddbf6a commit d189d05

File tree

2 files changed

+610
-0
lines changed

2 files changed

+610
-0
lines changed
Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
"""
2+
Copyright 2026 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import math
18+
from typing import Sequence
19+
20+
import jax
21+
import jax.numpy as jnp
22+
from flax import nnx
23+
from ... import common_types
24+
from maxdiffusion.configuration_utils import ConfigMixin, register_to_config
25+
from maxdiffusion.models.modeling_flax_utils import FlaxModelMixin
26+
27+
Array = common_types.Array
28+
DType = common_types.DType
29+
30+
31+
class ResBlock(nnx.Module):
  """HiFi-GAN-style residual block used by the LTX-2 vocoder.

  Each stage applies LeakyReLU -> dilated conv -> LeakyReLU -> plain conv,
  and adds the result back onto the input (one stage per entry in
  ``dilations``).
  """

  def __init__(
      self,
      channels: int,
      kernel_size: int = 3,
      stride: int = 1,
      dilations: Sequence[int] = (1, 3, 5),
      leaky_relu_negative_slope: float = 0.1,
      *,
      rngs: nnx.Rngs,
      dtype: DType = jnp.float32,
  ):
    self.dilations = dilations
    self.negative_slope = leaky_relu_negative_slope

    def _make_conv(dilation: int) -> nnx.Conv:
      # All convs share channel count, kernel size and stride; only the
      # dilation differs between the two conv stacks.
      return nnx.Conv(
          in_features=channels,
          out_features=channels,
          kernel_size=(kernel_size,),
          strides=(stride,),
          kernel_dilation=(dilation,),
          padding="SAME",
          rngs=rngs,
          dtype=dtype,
      )

    # First stack: one conv per requested dilation rate.
    self.convs1 = nnx.List([_make_conv(d) for d in dilations])
    # Second stack: same depth, but always undilated.
    self.convs2 = nnx.List([_make_conv(1) for _ in dilations])

  def __call__(self, x: Array) -> Array:
    for dilated_conv, plain_conv in zip(self.convs1, self.convs2):
      residual = x
      x = jax.nn.leaky_relu(x, negative_slope=self.negative_slope)
      x = dilated_conv(x)
      x = jax.nn.leaky_relu(x, negative_slope=self.negative_slope)
      x = plain_conv(x)
      x = residual + x
    return x
90+
91+
92+
class LTX2Vocoder(nnx.Module, FlaxModelMixin, ConfigMixin):
  """LTX 2.0 vocoder that converts generated mel spectrograms back to audio waveforms.

  A HiFi-GAN-style stack: an input conv, a sequence of transposed-conv
  upsamplers (each halving the channel count), with a group of ``ResBlock``s
  after every upsampler whose outputs are averaged, and a final conv + tanh.
  """

  @register_to_config
  def __init__(
      self,
      in_channels: int = 128,
      hidden_channels: int = 1024,
      out_channels: int = 2,
      upsample_kernel_sizes: Sequence[int] = (16, 15, 8, 4, 4),
      upsample_factors: Sequence[int] = (6, 5, 2, 2, 2),
      resnet_kernel_sizes: Sequence[int] = (3, 7, 11),
      resnet_dilations: Sequence[Sequence[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
      leaky_relu_negative_slope: float = 0.1,
      # output_sampling_rate is unused in model structure but kept for config compat
      output_sampling_rate: int = 24000,
      *,
      rngs: nnx.Rngs,
      dtype: DType = jnp.float32,
  ):
    self.num_upsample_layers = len(upsample_kernel_sizes)
    self.resnets_per_upsample = len(resnet_kernel_sizes)
    self.out_channels = out_channels
    self.total_upsample_factor = math.prod(upsample_factors)
    self.negative_slope = leaky_relu_negative_slope
    self.dtype = dtype

    if len(upsample_factors) != self.num_upsample_layers:
      raise ValueError(
          f"`upsample_kernel_sizes` and `upsample_factors` should be lists of the same length but are length"
          f" {self.num_upsample_layers} and {len(upsample_factors)}, respectively."
      )

    if len(resnet_dilations) != self.resnets_per_upsample:
      raise ValueError(
          f"`resnet_kernel_sizes` and `resnet_dilations` should be lists of the same length but are length"
          f" {self.resnets_per_upsample} and {len(resnet_dilations)}, respectively."
      )

    # PyTorch Conv1d expects (Batch, Channels, Length); Flax convs are NWC,
    # i.e. (Batch, Length, Channels) — the data layout is handled in __call__.
    self.conv_in = nnx.Conv(
        in_features=in_channels,
        out_features=hidden_channels,
        kernel_size=(7,),
        strides=(1,),
        padding="SAME",
        rngs=rngs,
        dtype=self.dtype,
    )

    self.upsamplers = nnx.List()
    self.resnets = nnx.List()
    ch_in = hidden_channels

    for up_kernel, up_stride in zip(upsample_kernel_sizes, upsample_factors):
      ch_out = ch_in // 2

      # ConvTranspose with padding='SAME' matches PyTorch's specific padding
      # logic for these standard HiFi-GAN upsampling configurations.
      self.upsamplers.append(
          nnx.ConvTranspose(
              in_features=ch_in,
              out_features=ch_out,
              kernel_size=(up_kernel,),
              strides=(up_stride,),
              padding="SAME",
              rngs=rngs,
              dtype=self.dtype,
          )
      )

      # One group of ResBlocks per upsampler; their outputs are averaged
      # in __call__.
      for res_kernel, res_dils in zip(resnet_kernel_sizes, resnet_dilations):
        self.resnets.append(
            ResBlock(
                channels=ch_out,
                kernel_size=res_kernel,
                dilations=res_dils,
                leaky_relu_negative_slope=leaky_relu_negative_slope,
                rngs=rngs,
                dtype=self.dtype,
            )
        )

      ch_in = ch_out

    self.conv_out = nnx.Conv(
        in_features=ch_in,
        out_features=out_channels,
        kernel_size=(7,),
        strides=(1,),
        padding="SAME",
        rngs=rngs,
        dtype=self.dtype,
    )

  def __call__(self, hidden_states: Array, time_last: bool = False) -> Array:
    """
    Forward pass of the vocoder.

    Args:
      hidden_states: Input Mel spectrogram tensor.
        Shape: `(B, C, T, F)` or `(B, C, F, T)`
      time_last: Legacy flag for input layout.

    Returns:
      Audio waveform: `(B, OutChannels, AudioLength)`
    """
    if not time_last:
      # Normalize layout to (Batch, Channels, MelBins, Time).
      hidden_states = jnp.transpose(hidden_states, (0, 1, 3, 2))

    batch, channels, mel_bins, num_frames = hidden_states.shape
    # Fold channels and mel bins into one feature axis -> (B, Features, T),
    # then move time ahead of features for Flax NWC convolutions.
    hidden_states = hidden_states.reshape(batch, channels * mel_bins, num_frames)
    hidden_states = jnp.transpose(hidden_states, (0, 2, 1))

    hidden_states = self.conv_in(hidden_states)

    for layer_idx, upsampler in enumerate(self.upsamplers):
      hidden_states = jax.nn.leaky_relu(hidden_states, negative_slope=self.negative_slope)
      hidden_states = upsampler(hidden_states)

      # Accumulate this layer's ResNet outputs instead of stacking them
      # (memory optimization); dividing afterwards matches PyTorch's
      # mean(stack(...)).
      base = layer_idx * self.resnets_per_upsample
      acc = 0.0
      for offset in range(self.resnets_per_upsample):
        acc = acc + self.resnets[base + offset](hidden_states)
      hidden_states = acc / self.resnets_per_upsample

    # Final post-processing.
    # Note: using 0.01 slope here specifically (matches Diffusers implementation quirk)
    hidden_states = jax.nn.leaky_relu(hidden_states, negative_slope=0.01)
    hidden_states = jnp.tanh(self.conv_out(hidden_states))

    # Transpose back to (Batch, Channels, Time) to match PyTorch/Diffusers output format.
    return jnp.transpose(hidden_states, (0, 2, 1))

0 commit comments

Comments
 (0)