Skip to content

Commit d1d4bf6

Browse files
committed
reformatted
1 parent 4d44e21 commit d1d4bf6

1 file changed

Lines changed: 18 additions & 19 deletions

File tree

src/maxdiffusion/models/ltx2/autoencoder_kl_ltx2.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,9 @@ def __init__(
225225
self.per_channel_scale2 = None
226226

227227
if timestep_conditioning:
228-
self.scale_shift_table = nnx.Param(jax.random.normal(rngs.params(), (4, in_channels), dtype=dtype) / (in_channels**0.5))
228+
self.scale_shift_table = nnx.Param(
229+
jax.random.normal(rngs.params(), (4, in_channels), dtype=dtype) / (in_channels**0.5)
230+
)
229231
else:
230232
self.scale_shift_table = None
231233

@@ -1261,40 +1263,37 @@ def enable_tiling(
12611263
def blend_v(self, a: jax.Array, b: jax.Array, blend_extent: int) -> jax.Array:
12621264
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
12631265
if blend_extent <= 0:
1264-
return b
1265-
1266+
return b
1267+
12661268
# Create broadcastable blending weights: (1, 1, blend_extent, 1, 1)
12671269
y = jnp.arange(blend_extent, dtype=a.dtype).reshape(1, 1, -1, 1, 1)
1268-
1269-
val = a[:, :, -blend_extent:, :, :] * (1.0 - y / blend_extent) + \
1270-
b[:, :, :blend_extent, :, :] * (y / blend_extent)
1271-
1270+
1271+
val = a[:, :, -blend_extent:, :, :] * (1.0 - y / blend_extent) + b[:, :, :blend_extent, :, :] * (y / blend_extent)
1272+
12721273
return b.at[:, :, :blend_extent, :, :].set(val)
12731274

12741275
def blend_h(self, a: jax.Array, b: jax.Array, blend_extent: int) -> jax.Array:
12751276
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
12761277
if blend_extent <= 0:
1277-
return b
1278-
1278+
return b
1279+
12791280
# Create broadcastable blending weights: (1, 1, 1, blend_extent, 1)
12801281
x = jnp.arange(blend_extent, dtype=a.dtype).reshape(1, 1, 1, -1, 1)
1281-
1282-
val = a[:, :, :, -blend_extent:, :] * (1.0 - x / blend_extent) + \
1283-
b[:, :, :, :blend_extent, :] * (x / blend_extent)
1284-
1282+
1283+
val = a[:, :, :, -blend_extent:, :] * (1.0 - x / blend_extent) + b[:, :, :, :blend_extent, :] * (x / blend_extent)
1284+
12851285
return b.at[:, :, :, :blend_extent, :].set(val)
12861286

12871287
def blend_t(self, a: jax.Array, b: jax.Array, blend_extent: int) -> jax.Array:
12881288
blend_extent = min(a.shape[1], b.shape[1], blend_extent)
12891289
if blend_extent <= 0:
1290-
return b
1291-
1290+
return b
1291+
12921292
# Create broadcastable blending weights: (1, blend_extent, 1, 1, 1)
12931293
x = jnp.arange(blend_extent, dtype=a.dtype).reshape(1, -1, 1, 1, 1)
1294-
1295-
val = a[:, -blend_extent:, :, :, :] * (1.0 - x / blend_extent) + \
1296-
b[:, :blend_extent, :, :, :] * (x / blend_extent)
1297-
1294+
1295+
val = a[:, -blend_extent:, :, :, :] * (1.0 - x / blend_extent) + b[:, :blend_extent, :, :, :] * (x / blend_extent)
1296+
12981297
return b.at[:, :blend_extent, :, :, :].set(val)
12991298

13001299
def tiled_encode(self, x: jax.Array, key: Optional[jax.Array] = None, causal: Optional[bool] = None) -> jax.Array:

0 commit comments

Comments
 (0)