@@ -99,16 +99,13 @@ def sde_step_with_logprob(
9999 # This is also reproducible, because I have set global seed in the trainer.
100100 # Some local seeding would not impact the global seed, and thus the reproducibility.
101101 )
102- prev_sample = (
103- prev_sample_mean + std_dev_t * torch .sqrt (- 1 * dt ) * variance_noise
104- )
102+ prev_sample = prev_sample_mean + std_dev_t * torch .sqrt (- 1 * dt ) * variance_noise
105103
106104 if deterministic :
107105 prev_sample = sample + dt * model_output
108106
109107 log_prob = (
110- - ((prev_sample .detach () - prev_sample_mean ) ** 2 )
111- / (2 * ((std_dev_t * torch .sqrt (- 1 * dt )) ** 2 ))
108+ - ((prev_sample .detach () - prev_sample_mean ) ** 2 ) / (2 * ((std_dev_t * torch .sqrt (- 1 * dt )) ** 2 ))
112109 - torch .log (std_dev_t * torch .sqrt (- 1 * dt ))
113110 - torch .log (torch .sqrt (2 * torch .as_tensor (math .pi )))
114111 )
@@ -117,9 +114,9 @@ def sde_step_with_logprob(
117114 std_dev_t = sigma_prev * math .sin (noise_level * math .pi / 2 ) # sigma_t in paper
118115 pred_original_sample = sample - sigma * model_output # predicted x_0 in paper
119116 noise_estimate = sample + model_output * (1 - sigma ) # predicted x_1 in paper
120- prev_sample_mean = pred_original_sample * (
121- 1 - sigma_prev
122- ) + noise_estimate * torch . sqrt ( sigma_prev ** 2 - std_dev_t ** 2 )
117+ prev_sample_mean = pred_original_sample * (1 - sigma_prev ) + noise_estimate * torch . sqrt (
118+ sigma_prev ** 2 - std_dev_t ** 2
119+ )
123120
124121 if prev_sample is None :
125122 variance_noise = randn_tensor (
@@ -131,18 +128,14 @@ def sde_step_with_logprob(
131128 prev_sample = prev_sample_mean + std_dev_t * variance_noise
132129
133130 if deterministic :
134- prev_sample = (
135- pred_original_sample * (1 - sigma_prev ) + noise_estimate * sigma_prev
136- )
131+ prev_sample = pred_original_sample * (1 - sigma_prev ) + noise_estimate * sigma_prev
137132
138133 # remove all constants
139134 log_prob = - ((prev_sample .detach () - prev_sample_mean ) ** 2 )
140135
141136 else :
142137 msg = f"Unknown sde_type: { sde_type } . Must be 'flow_sde' or 'flow_cps'."
143- raise ValueError (
144- msg
145- )
138+ raise ValueError (msg )
146139
147140 # mean along all but batch dimension
148141 log_prob = log_prob .mean (dim = tuple (range (1 , log_prob .ndim )))
@@ -212,12 +205,7 @@ def wan_pipeline_with_logprob(
212205 )
213206
214207 if num_frames % self .vae_scale_factor_temporal != 1 :
215- num_frames = (
216- num_frames
217- // self .vae_scale_factor_temporal
218- * self .vae_scale_factor_temporal
219- + 1
220- )
208+ num_frames = num_frames // self .vae_scale_factor_temporal * self .vae_scale_factor_temporal + 1
221209 num_frames = max (num_frames , 1 )
222210
223211 self ._guidance_scale = guidance_scale
@@ -277,24 +265,18 @@ def wan_pipeline_with_logprob(
277265 f"sde_window_range span ({ sde_window_range [1 ] - sde_window_range [0 ]} ) "
278266 f"must be >= sde_window_size ({ sde_window_size } )"
279267 )
280- raise ValueError (
281- msg
282- )
268+ raise ValueError (msg )
283269 # Use generator if provided (for training reproducibility), otherwise fallback to random
284270 if generator is not None :
285271 # Extract generator from list if needed
286272 gen = generator [0 ] if isinstance (generator , list ) and len (generator ) > 0 else generator
287273 # Use torch.randint with generator for deterministic randomness
288274 max_start = sde_window_range [1 ] - sde_window_size
289- start = torch .randint (
290- sde_window_range [0 ], max_start + 1 , (1 ,), generator = gen , device = device
291- ).item ()
275+ start = torch .randint (sde_window_range [0 ], max_start + 1 , (1 ,), generator = gen , device = device ).item ()
292276 else :
293277 # Fallback to Python random (for eval, where generator may not be provided)
294278 # This is safe because eval uses deterministic=True and set_seed at the start
295- start = random .randint (
296- sde_window_range [0 ], sde_window_range [1 ] - sde_window_size
297- )
279+ start = random .randint (sde_window_range [0 ], sde_window_range [1 ] - sde_window_size )
298280 end = start + sde_window_size
299281 sde_window = (start , end )
300282 # In window mode, initialize all_latents as empty list (will be populated in the loop)
@@ -401,9 +383,7 @@ def wan_pipeline_with_logprob(
401383
402384 latents = callback_outputs .pop ("latents" , latents )
403385 prompt_embeds = callback_outputs .pop ("prompt_embeds" , prompt_embeds )
404- negative_prompt_embeds = callback_outputs .pop (
405- "negative_prompt_embeds" , negative_prompt_embeds
406- )
386+ negative_prompt_embeds = callback_outputs .pop ("negative_prompt_embeds" , negative_prompt_embeds )
407387
408388 # Compute KL reward
409389 if use_window :
@@ -412,9 +392,7 @@ def wan_pipeline_with_logprob(
412392 if in_window :
413393 if kl_reward > 0 and not deterministic :
414394 latent_model_input = (
415- torch .cat ([latents_ori ] * 2 )
416- if self .do_classifier_free_guidance
417- else latents_ori
395+ torch .cat ([latents_ori ] * 2 ) if self .do_classifier_free_guidance else latents_ori
418396 )
419397 ref_model = getattr (self , "ref_transformer" , None )
420398 if ref_model is not None :
@@ -440,9 +418,7 @@ def wan_pipeline_with_logprob(
440418 # perform guidance
441419 if self .do_classifier_free_guidance :
442420 noise_pred_uncond , noise_pred_text = noise_pred .chunk (2 )
443- noise_pred = noise_pred_uncond + self .guidance_scale * (
444- noise_pred_text - noise_pred_uncond
445- )
421+ noise_pred = noise_pred_uncond + self .guidance_scale * (noise_pred_text - noise_pred_uncond )
446422
447423 (
448424 _ ,
@@ -464,21 +440,15 @@ def wan_pipeline_with_logprob(
464440 diffusion_clip_value = diffusion_clip_value ,
465441 )
466442 assert std_dev_t == ref_std_dev_t
467- kl = (prev_latents_mean - ref_prev_latents_mean ) ** 2 / (
468- 2 * std_dev_t ** 2
469- )
443+ kl = (prev_latents_mean - ref_prev_latents_mean ) ** 2 / (2 * std_dev_t ** 2 )
470444 kl = kl .mean (dim = tuple (range (1 , kl .ndim )))
471445 all_kl .append (kl )
472446 else :
473447 # In window but no KL reward, append zero KL
474448 all_kl .append (torch .zeros (len (latents ), device = latents .device ))
475449 # Original mode: compute KL for all timesteps (sde_window_size == 0)
476450 elif kl_reward > 0 and not deterministic :
477- latent_model_input = (
478- torch .cat ([latents_ori ] * 2 )
479- if self .do_classifier_free_guidance
480- else latents_ori
481- )
451+ latent_model_input = torch .cat ([latents_ori ] * 2 ) if self .do_classifier_free_guidance else latents_ori
482452 ref_model = getattr (self , "ref_transformer" , None )
483453 if ref_model is not None :
484454 ref_ctx = contextlib .nullcontext ()
@@ -503,9 +473,7 @@ def wan_pipeline_with_logprob(
503473 # perform guidance
504474 if self .do_classifier_free_guidance :
505475 noise_pred_uncond , noise_pred_text = noise_pred .chunk (2 )
506- noise_pred = noise_pred_uncond + self .guidance_scale * (
507- noise_pred_text - noise_pred_uncond
508- )
476+ noise_pred = noise_pred_uncond + self .guidance_scale * (noise_pred_text - noise_pred_uncond )
509477
510478 (
511479 _ ,
@@ -527,19 +495,15 @@ def wan_pipeline_with_logprob(
527495 diffusion_clip_value = diffusion_clip_value ,
528496 )
529497 assert std_dev_t == ref_std_dev_t
530- kl = (prev_latents_mean - ref_prev_latents_mean ) ** 2 / (
531- 2 * std_dev_t ** 2
532- )
498+ kl = (prev_latents_mean - ref_prev_latents_mean ) ** 2 / (2 * std_dev_t ** 2 )
533499 kl = kl .mean (dim = tuple (range (1 , kl .ndim )))
534500 all_kl .append (kl )
535501 else :
536502 # no kl reward, we do not need to compute, just put a pre-position value, kl will be 0
537503 all_kl .append (torch .zeros (len (latents ), device = latents .device ))
538504
539505 # call the callback, if provided
540- if i == len (timesteps ) - 1 or (
541- (i + 1 ) > num_warmup_steps and (i + 1 ) % self .scheduler .order == 0
542- ):
506+ if i == len (timesteps ) - 1 or ((i + 1 ) > num_warmup_steps and (i + 1 ) % self .scheduler .order == 0 ):
543507 progress_bar .update ()
544508
545509 self ._current_timestep = None
@@ -551,9 +515,9 @@ def wan_pipeline_with_logprob(
551515 .view (1 , self .vae .config .z_dim , 1 , 1 , 1 )
552516 .to (latents .device , latents .dtype )
553517 )
554- latents_std = 1.0 / torch .tensor (self .vae .config .latents_std ).view (
555- 1 , self . vae . config . z_dim , 1 , 1 , 1
556- ). to ( latents . device , latents . dtype )
518+ latents_std = 1.0 / torch .tensor (self .vae .config .latents_std ).view (1 , self . vae . config . z_dim , 1 , 1 , 1 ). to (
519+ latents . device , latents . dtype
520+ )
557521 latents = latents / latents_std + latents_mean
558522 # Decode one sample at a time to reduce peak memory.
559523 decoded_videos = []
0 commit comments