Skip to content

Commit 70b2705

Browse files
committed
style: adjust ruff formatting for CI compatibility
- Adjust line breaks to match GitHub Actions ruff version - Ensure consistent formatting across environments
1 parent f03f45b commit 70b2705

2 files changed

Lines changed: 24 additions & 64 deletions

File tree

genrl/diffusers_patch/wan_pipeline_with_logprob.py

Lines changed: 22 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -99,16 +99,13 @@ def sde_step_with_logprob(
9999
# This is also reproducible, because I have set global seed in the trainer.
100100
# Some local seeding would not impact the global seed, and thus the reproducibility.
101101
)
102-
prev_sample = (
103-
prev_sample_mean + std_dev_t * torch.sqrt(-1 * dt) * variance_noise
104-
)
102+
prev_sample = prev_sample_mean + std_dev_t * torch.sqrt(-1 * dt) * variance_noise
105103

106104
if deterministic:
107105
prev_sample = sample + dt * model_output
108106

109107
log_prob = (
110-
-((prev_sample.detach() - prev_sample_mean) ** 2)
111-
/ (2 * ((std_dev_t * torch.sqrt(-1 * dt)) ** 2))
108+
-((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * ((std_dev_t * torch.sqrt(-1 * dt)) ** 2))
112109
- torch.log(std_dev_t * torch.sqrt(-1 * dt))
113110
- torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
114111
)
@@ -117,9 +114,9 @@ def sde_step_with_logprob(
117114
std_dev_t = sigma_prev * math.sin(noise_level * math.pi / 2) # sigma_t in paper
118115
pred_original_sample = sample - sigma * model_output # predicted x_0 in paper
119116
noise_estimate = sample + model_output * (1 - sigma) # predicted x_1 in paper
120-
prev_sample_mean = pred_original_sample * (
121-
1 - sigma_prev
122-
) + noise_estimate * torch.sqrt(sigma_prev**2 - std_dev_t**2)
117+
prev_sample_mean = pred_original_sample * (1 - sigma_prev) + noise_estimate * torch.sqrt(
118+
sigma_prev**2 - std_dev_t**2
119+
)
123120

124121
if prev_sample is None:
125122
variance_noise = randn_tensor(
@@ -131,18 +128,14 @@ def sde_step_with_logprob(
131128
prev_sample = prev_sample_mean + std_dev_t * variance_noise
132129

133130
if deterministic:
134-
prev_sample = (
135-
pred_original_sample * (1 - sigma_prev) + noise_estimate * sigma_prev
136-
)
131+
prev_sample = pred_original_sample * (1 - sigma_prev) + noise_estimate * sigma_prev
137132

138133
# remove all constants
139134
log_prob = -((prev_sample.detach() - prev_sample_mean) ** 2)
140135

141136
else:
142137
msg = f"Unknown sde_type: {sde_type}. Must be 'flow_sde' or 'flow_cps'."
143-
raise ValueError(
144-
msg
145-
)
138+
raise ValueError(msg)
146139

147140
# mean along all but batch dimension
148141
log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
@@ -212,12 +205,7 @@ def wan_pipeline_with_logprob(
212205
)
213206

214207
if num_frames % self.vae_scale_factor_temporal != 1:
215-
num_frames = (
216-
num_frames
217-
// self.vae_scale_factor_temporal
218-
* self.vae_scale_factor_temporal
219-
+ 1
220-
)
208+
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
221209
num_frames = max(num_frames, 1)
222210

223211
self._guidance_scale = guidance_scale
@@ -277,24 +265,18 @@ def wan_pipeline_with_logprob(
277265
f"sde_window_range span ({sde_window_range[1] - sde_window_range[0]}) "
278266
f"must be >= sde_window_size ({sde_window_size})"
279267
)
280-
raise ValueError(
281-
msg
282-
)
268+
raise ValueError(msg)
283269
# Use generator if provided (for training reproducibility), otherwise fallback to random
284270
if generator is not None:
285271
# Extract generator from list if needed
286272
gen = generator[0] if isinstance(generator, list) and len(generator) > 0 else generator
287273
# Use torch.randint with generator for deterministic randomness
288274
max_start = sde_window_range[1] - sde_window_size
289-
start = torch.randint(
290-
sde_window_range[0], max_start + 1, (1,), generator=gen, device=device
291-
).item()
275+
start = torch.randint(sde_window_range[0], max_start + 1, (1,), generator=gen, device=device).item()
292276
else:
293277
# Fallback to Python random (for eval, where generator may not be provided)
294278
# This is safe because eval uses deterministic=True and set_seed at the start
295-
start = random.randint(
296-
sde_window_range[0], sde_window_range[1] - sde_window_size
297-
)
279+
start = random.randint(sde_window_range[0], sde_window_range[1] - sde_window_size)
298280
end = start + sde_window_size
299281
sde_window = (start, end)
300282
# In window mode, initialize all_latents as empty list (will be populated in the loop)
@@ -401,9 +383,7 @@ def wan_pipeline_with_logprob(
401383

402384
latents = callback_outputs.pop("latents", latents)
403385
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
404-
negative_prompt_embeds = callback_outputs.pop(
405-
"negative_prompt_embeds", negative_prompt_embeds
406-
)
386+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
407387

408388
# Compute KL reward
409389
if use_window:
@@ -412,9 +392,7 @@ def wan_pipeline_with_logprob(
412392
if in_window:
413393
if kl_reward > 0 and not deterministic:
414394
latent_model_input = (
415-
torch.cat([latents_ori] * 2)
416-
if self.do_classifier_free_guidance
417-
else latents_ori
395+
torch.cat([latents_ori] * 2) if self.do_classifier_free_guidance else latents_ori
418396
)
419397
ref_model = getattr(self, "ref_transformer", None)
420398
if ref_model is not None:
@@ -440,9 +418,7 @@ def wan_pipeline_with_logprob(
440418
# perform guidance
441419
if self.do_classifier_free_guidance:
442420
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
443-
noise_pred = noise_pred_uncond + self.guidance_scale * (
444-
noise_pred_text - noise_pred_uncond
445-
)
421+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
446422

447423
(
448424
_,
@@ -464,21 +440,15 @@ def wan_pipeline_with_logprob(
464440
diffusion_clip_value=diffusion_clip_value,
465441
)
466442
assert std_dev_t == ref_std_dev_t
467-
kl = (prev_latents_mean - ref_prev_latents_mean) ** 2 / (
468-
2 * std_dev_t**2
469-
)
443+
kl = (prev_latents_mean - ref_prev_latents_mean) ** 2 / (2 * std_dev_t**2)
470444
kl = kl.mean(dim=tuple(range(1, kl.ndim)))
471445
all_kl.append(kl)
472446
else:
473447
# In window but no KL reward, append zero KL
474448
all_kl.append(torch.zeros(len(latents), device=latents.device))
475449
# Original mode: compute KL for all timesteps (sde_window_size == 0)
476450
elif kl_reward > 0 and not deterministic:
477-
latent_model_input = (
478-
torch.cat([latents_ori] * 2)
479-
if self.do_classifier_free_guidance
480-
else latents_ori
481-
)
451+
latent_model_input = torch.cat([latents_ori] * 2) if self.do_classifier_free_guidance else latents_ori
482452
ref_model = getattr(self, "ref_transformer", None)
483453
if ref_model is not None:
484454
ref_ctx = contextlib.nullcontext()
@@ -503,9 +473,7 @@ def wan_pipeline_with_logprob(
503473
# perform guidance
504474
if self.do_classifier_free_guidance:
505475
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
506-
noise_pred = noise_pred_uncond + self.guidance_scale * (
507-
noise_pred_text - noise_pred_uncond
508-
)
476+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
509477

510478
(
511479
_,
@@ -527,19 +495,15 @@ def wan_pipeline_with_logprob(
527495
diffusion_clip_value=diffusion_clip_value,
528496
)
529497
assert std_dev_t == ref_std_dev_t
530-
kl = (prev_latents_mean - ref_prev_latents_mean) ** 2 / (
531-
2 * std_dev_t**2
532-
)
498+
kl = (prev_latents_mean - ref_prev_latents_mean) ** 2 / (2 * std_dev_t**2)
533499
kl = kl.mean(dim=tuple(range(1, kl.ndim)))
534500
all_kl.append(kl)
535501
else:
536502
# no kl reward, we do not need to compute, just put a pre-position value, kl will be 0
537503
all_kl.append(torch.zeros(len(latents), device=latents.device))
538504

539505
# call the callback, if provided
540-
if i == len(timesteps) - 1 or (
541-
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
542-
):
506+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
543507
progress_bar.update()
544508

545509
self._current_timestep = None
@@ -551,9 +515,9 @@ def wan_pipeline_with_logprob(
551515
.view(1, self.vae.config.z_dim, 1, 1, 1)
552516
.to(latents.device, latents.dtype)
553517
)
554-
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(
555-
1, self.vae.config.z_dim, 1, 1, 1
556-
).to(latents.device, latents.dtype)
518+
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
519+
latents.device, latents.dtype
520+
)
557521
latents = latents / latents_std + latents_mean
558522
# Decode one sample at a time to reduce peak memory.
559523
decoded_videos = []

genrl/reward/hpsv3.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,7 @@ def _fn(
145145
# Repeat the general prompt for all frames
146146
frame_prompts = [general_prompt] * len(frame_paths)
147147
with torch.no_grad(), torch.amp.autocast(device_type=device_type):
148-
frame_rewards_raw = inferencer.reward(
149-
frame_prompts, image_paths=frame_paths
150-
)
148+
frame_rewards_raw = inferencer.reward(frame_prompts, image_paths=frame_paths)
151149

152150
# Extract mu values (mean scores)
153151
# HPSv3 returns a list where each element is [mu, sigma] or a tensor
@@ -226,9 +224,7 @@ def _fn(
226224
# Use the same prompt for all frames in the video
227225
frame_prompts = [prompt] * len(frame_paths)
228226
with torch.no_grad(), torch.amp.autocast(device_type=device_type):
229-
frame_rewards_raw = inferencer.reward(
230-
frame_prompts, image_paths=frame_paths
231-
)
227+
frame_rewards_raw = inferencer.reward(frame_prompts, image_paths=frame_paths)
232228

233229
# Extract mu values (mean scores)
234230
# HPSv3 returns a list where each element is [mu, sigma] or a tensor

0 commit comments

Comments
 (0)