@@ -255,6 +255,8 @@ def encode(self, x: torch.Tensor) -> torch.Tensor:
255255 return (self .taesd_encoder (x * 0.5 + 0.5 ) / self .vae_scale ) + self .vae_shift
256256
257257
# Module-level cache of TAESD decoder instances, keyed by
# (latent_channels, flux), so repeated preview calls reuse the
# already-initialized model instead of constructing a new one each time.
_taesd_cache = {}
def taesd_preview(x: torch.Tensor, flux: bool = False):
    """#### Preview the batched latent tensors as images.

    #### Args:
        - `x` (torch.Tensor): Batched latent tensor of shape (B, C, H, W).
        - `flux` (bool, optional): Whether using flux model (for channel ordering). Defaults to False.
    """
    if app_instance.app.previewer_var.get() is True:
        # Reuse a cached TAESD instance keyed by (latent_channels, flux)
        # so the decoder is not re-initialized on every preview call.
        latent_channels = x.shape[1]
        cache_key = (latent_channels, flux)
        taesd_instance = _taesd_cache.get(cache_key)
        if taesd_instance is None:
            taesd_instance = TAESD(latent_channels=latent_channels)
            # Keep the decoder on the same device as the latents.
            taesd_instance.to(x.device)
            _taesd_cache[cache_key] = taesd_instance

        # Decode the entire batch in one forward pass; a preview never
        # needs gradients.
        with torch.no_grad():
            decoded_batch = taesd_instance.decode(x)

        # A single-channel decode cannot be shown as RGB; replicate it
        # across three channels (preserves the pre-batching behavior).
        if decoded_batch.shape[1] == 1:
            decoded_batch = decoded_batch.repeat(1, 3, 1, 1)

        if flux:
            # Flux decodes in BGR order; reorder to RGB, then map the
            # clamped [-1, 1] range to [0, 1].
            decoded_batch = (
                decoded_batch[:, [2, 1, 0], :, :].clamp(-1, 1).add(1.0).mul(0.5)
            )
        else:
            # Standard normalization: map [-1, 1] to [0, 1] and clamp.
            decoded_batch = decoded_batch.add(1.0).mul(0.5).clamp(0, 1)

        # Scale to [0, 255] and move to CPU as uint8.
        # NOTE(review): the copy must stay blocking — a non_blocking
        # transfer into pageable CPU memory followed immediately by
        # .numpy() may observe the buffer before the async copy finishes.
        decoded_np = decoded_batch.mul(255.0).to("cpu", dtype=torch.uint8).numpy()

        images = []
        for frame in decoded_np:
            # CHW -> HWC layout for PIL.
            img = Image.fromarray(np.transpose(frame, (1, 2, 0)), mode='RGB')
            images.append(img)

        # Update display with all images
        app_instance.app.update_image(images)
0 commit comments