@@ -225,7 +225,9 @@ def __init__(
225225 self .per_channel_scale2 = None
226226
227227 if timestep_conditioning :
228- self .scale_shift_table = nnx .Param (jax .random .normal (rngs .params (), (4 , in_channels ), dtype = dtype ) / (in_channels ** 0.5 ))
228+ self .scale_shift_table = nnx .Param (
229+ jax .random .normal (rngs .params (), (4 , in_channels ), dtype = dtype ) / (in_channels ** 0.5 )
230+ )
229231 else :
230232 self .scale_shift_table = None
231233
@@ -1261,40 +1263,37 @@ def enable_tiling(
12611263 def blend_v (self , a : jax .Array , b : jax .Array , blend_extent : int ) -> jax .Array :
12621264 blend_extent = min (a .shape [2 ], b .shape [2 ], blend_extent )
12631265 if blend_extent <= 0 :
1264- return b
1265-
1266+ return b
1267+
12661268 # Create broadcastable blending weights: (1, 1, blend_extent, 1, 1)
12671269 y = jnp .arange (blend_extent , dtype = a .dtype ).reshape (1 , 1 , - 1 , 1 , 1 )
1268-
1269- val = a [:, :, - blend_extent :, :, :] * (1.0 - y / blend_extent ) + \
1270- b [:, :, :blend_extent , :, :] * (y / blend_extent )
1271-
1270+
1271+ val = a [:, :, - blend_extent :, :, :] * (1.0 - y / blend_extent ) + b [:, :, :blend_extent , :, :] * (y / blend_extent )
1272+
12721273 return b .at [:, :, :blend_extent , :, :].set (val )
12731274
12741275 def blend_h (self , a : jax .Array , b : jax .Array , blend_extent : int ) -> jax .Array :
12751276 blend_extent = min (a .shape [3 ], b .shape [3 ], blend_extent )
12761277 if blend_extent <= 0 :
1277- return b
1278-
1278+ return b
1279+
12791280 # Create broadcastable blending weights: (1, 1, 1, blend_extent, 1)
12801281 x = jnp .arange (blend_extent , dtype = a .dtype ).reshape (1 , 1 , 1 , - 1 , 1 )
1281-
1282- val = a [:, :, :, - blend_extent :, :] * (1.0 - x / blend_extent ) + \
1283- b [:, :, :, :blend_extent , :] * (x / blend_extent )
1284-
1282+
1283+ val = a [:, :, :, - blend_extent :, :] * (1.0 - x / blend_extent ) + b [:, :, :, :blend_extent , :] * (x / blend_extent )
1284+
12851285 return b .at [:, :, :, :blend_extent , :].set (val )
12861286
12871287 def blend_t (self , a : jax .Array , b : jax .Array , blend_extent : int ) -> jax .Array :
12881288 blend_extent = min (a .shape [1 ], b .shape [1 ], blend_extent )
12891289 if blend_extent <= 0 :
1290- return b
1291-
1290+ return b
1291+
12921292 # Create broadcastable blending weights: (1, blend_extent, 1, 1, 1)
12931293 x = jnp .arange (blend_extent , dtype = a .dtype ).reshape (1 , - 1 , 1 , 1 , 1 )
1294-
1295- val = a [:, - blend_extent :, :, :, :] * (1.0 - x / blend_extent ) + \
1296- b [:, :blend_extent , :, :, :] * (x / blend_extent )
1297-
1294+
1295+ val = a [:, - blend_extent :, :, :, :] * (1.0 - x / blend_extent ) + b [:, :blend_extent , :, :, :] * (x / blend_extent )
1296+
12981297 return b .at [:, :blend_extent , :, :, :].set (val )
12991298
13001299 def tiled_encode (self , x : jax .Array , key : Optional [jax .Array ] = None , causal : Optional [bool ] = None ) -> jax .Array :
0 commit comments