diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index 6d06218c..cd4e5686 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -302,8 +302,11 @@ guidance_scale_high: 4.0 # timestep to switch between low noise and high noise transformer boundary_ratio: 0.875 -# Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only) +# Diffusion CFG cache (FasterCache-style) use_cfg_cache: False +# SenCache: Sensitivity-Aware Caching (arXiv:2602.24208) — skip forward pass +# when predicted output change (based on accumulated latent/timestep drift) is small +use_sen_cache: False # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index d9d3af7c..56c5a8a0 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -139,6 +139,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt): guidance_scale_low=config.guidance_scale_low, guidance_scale_high=config.guidance_scale_high, use_cfg_cache=config.use_cfg_cache, + use_sen_cache=config.use_sen_cache, ) else: raise ValueError(f"Unsupported model_name for T2Vin config: {model_key}") @@ -179,6 +180,7 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None): max_logging.log("Could not retrieve Git commit hash.") if pipeline is None: + load_start = time.perf_counter() model_type = config.model_type if model_key == WAN2_1: if model_type == "I2V": @@ -193,6 +195,10 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None): else: raise ValueError(f"Unsupported model_name for checkpointer: {model_key}") pipeline, _, _ = checkpoint_loader.load_checkpoint() + load_time = time.perf_counter() - load_start + max_logging.log(f"load_time: {load_time:.1f}s") + else: + load_time = 0.0 # If LoRA is specified, inject layers and load weights. 
if ( @@ -276,6 +282,17 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None): max_logging.log(f"generation time per video: {generation_time_per_video}") else: max_logging.log("Warning: Number of videos is zero, cannot calculate generation_time_per_video.") + max_logging.log( + f"\n{'=' * 50}\n" + f" TIMING SUMMARY\n" + f"{'=' * 50}\n" + f" Load (checkpoint): {load_time:>7.1f}s\n" + f" Compile: {compile_time:>7.1f}s\n" + f" {'─' * 40}\n" + f" Inference: {generation_time:>7.1f}s\n" + f"{'=' * 50}" + ) + s0 = time.perf_counter() if config.enable_profiler: max_utils.activate_profiler(config) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py index b8f818e3..f6a3d937 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py @@ -111,7 +111,11 @@ def __call__( negative_prompt_embeds: jax.Array = None, vae_only: bool = False, use_cfg_cache: bool = False, + use_sen_cache: bool = False, ): + if use_cfg_cache and use_sen_cache: + raise ValueError("use_cfg_cache and use_sen_cache are mutually exclusive. Enable only one.") + if use_cfg_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0): raise ValueError( f"use_cfg_cache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 " @@ -119,6 +123,13 @@ def __call__( "CFG cache accelerates classifier-free guidance, which must be enabled for both transformer phases." ) + if use_sen_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0): + raise ValueError( + f"use_sen_cache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 " + f"(got {guidance_scale_low}, {guidance_scale_high}). " + "SenCache requires classifier-free guidance to be enabled for both transformer phases." 
+ ) + + latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_model_inputs( prompt, negative_prompt, @@ -148,6 +159,7 @@ def __call__( scheduler=self.scheduler, scheduler_state=scheduler_state, use_cfg_cache=use_cfg_cache, + use_sen_cache=use_sen_cache, height=height, ) @@ -184,22 +196,142 @@ def run_inference_2_2( scheduler: FlaxUniPCMultistepScheduler, scheduler_state, use_cfg_cache: bool = False, + use_sen_cache: bool = False, height: int = 480, ): - """Denoising loop for WAN 2.2 T2V with optional FasterCache CFG-Cache. - - Dual-transformer CFG-Cache strategy (enabled via use_cfg_cache=True): - - High-noise phase (t >= boundary): always full CFG — short phase, critical - for establishing video structure. - - Low-noise phase (t < boundary): FasterCache alternation — full CFG every N - steps, FFT frequency-domain compensation on cache steps (batch×1). - - Boundary transition: mandatory full CFG step to populate cache for the - low-noise transformer. - - FFT compensation identical to WAN 2.1 (Lv et al., ICLR 2025). + """Denoising loop for WAN 2.2 T2V with optional caching acceleration. + + Supports two caching strategies: + + 1. CFG-Cache (use_cfg_cache=True) — FasterCache-style: + Caches the unconditional branch and uses FFT frequency-domain compensation. + + 2. SenCache (use_sen_cache=True) — Sensitivity-Aware Caching + (Haghighi & Alahi, arXiv:2602.24208): + Uses a first-order sensitivity approximation S = α_x·‖Δx‖ + α_t·|Δt| + to predict output change. Caches when predicted change is below tolerance ε. + Tracks latent drift and timestep drift since the last cache refresh and makes + one cache decision per step for the whole batch. Sensitivity weights (α_x, α_t) + are currently fixed to 1.0 rather than calibrated per timestep.
+ """ do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0 bsz = latents.shape[0] + # ── SenCache path (arXiv:2602.24208) ── + if use_sen_cache and do_classifier_free_guidance: + timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32) + step_uses_high = [bool(timesteps_np[s] >= boundary) for s in range(num_inference_steps)] + + # SenCache hyperparameters + sen_epsilon = 0.1 # main tolerance (permissive phase) + max_reuse = 3 # max consecutive cache reuses before forced recompute + warmup_steps = 1 # first step always computes + # No-cache zones: first 30% (structure formation) and last 10% (detail refinement) + nocache_start_ratio = 0.3 + nocache_end_ratio = 0.1 + # Uniform sensitivity weights (α_x=1, α_t=1); swap for pre-calibrated + # SensitivityProfile per-timestep values when available. + alpha_x, alpha_t = 1.0, 1.0 + + nocache_start = int(num_inference_steps * nocache_start_ratio) + nocache_end_begin = int(num_inference_steps * (1.0 - nocache_end_ratio)) + # Normalize timesteps to [0, 1]. + # maxdiffusion timesteps are integers in [0, num_train_timesteps], whereas ε assumes + # time in [0, 1] (presumably the reference's sigma scale — confirm). Without normalization |Δt|≈20 >> ε and nothing caches.
+ num_train_timesteps = float(scheduler.config.num_train_timesteps) + + prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) + + # SenCache state + ref_noise_pred = None # y^r: cached denoiser output + ref_latent = None # x^r: latent at last cache refresh + ref_timestep = 0.0 # t^r: timestep (normalized to [0,1]) at last cache refresh + accum_dx = 0.0 # latent drift (RMS of x − x^r) since last refresh + accum_dt = 0.0 # timestep drift |t − t^r| since last refresh + reuse_count = 0 # consecutive cache reuses + cache_count = 0 + + for step in range(num_inference_steps): + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + t_float = float(timesteps_np[step]) / num_train_timesteps # normalize to [0, 1] + + # Select transformer and guidance scale + if step_uses_high[step]: + graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest + guidance_scale = guidance_scale_high + else: + graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest + guidance_scale = guidance_scale_low + + # Force full compute: warmup, first 30%, last 10%, or transformer boundary + is_boundary = step > 0 and step_uses_high[step] != step_uses_high[step - 1] + force_compute = ( + step < warmup_steps or step < nocache_start or step >= nocache_end_begin or is_boundary or ref_noise_pred is None + ) + + if force_compute: + latents_doubled = jnp.concatenate([latents] * 2) + timestep = jnp.broadcast_to(t, bsz * 2) + noise_pred, _, _ = transformer_forward_pass_full_cfg( + graphdef, + state, + rest, + latents_doubled, + timestep, + prompt_embeds_combined, + guidance_scale=guidance_scale, + ) + ref_noise_pred = noise_pred + ref_latent = latents + ref_timestep = t_float + accum_dx = 0.0 + accum_dt = 0.0 + reuse_count = 0 + latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + continue + + # Drift relative to the last full compute (x^r, t^r) + dx_norm = float(jnp.sqrt(jnp.mean((latents - ref_latent) ** 2)))
+ dt = abs(t_float - ref_timestep) + accum_dx = dx_norm # already measured against ref_latent; += double-counted earlier steps + accum_dt = dt # already |t − t^r|; ε may need retuning now that drift is not inflated + + # Sensitivity score (Eq. 9) + score = alpha_x * accum_dx + alpha_t * accum_dt + + if score <= sen_epsilon and reuse_count < max_reuse: + # Cache hit: reuse previous output + noise_pred = ref_noise_pred + reuse_count += 1 + cache_count += 1 + else: + # Cache miss: full CFG forward pass + latents_doubled = jnp.concatenate([latents] * 2) + timestep = jnp.broadcast_to(t, bsz * 2) + noise_pred, _, _ = transformer_forward_pass_full_cfg( + graphdef, + state, + rest, + latents_doubled, + timestep, + prompt_embeds_combined, + guidance_scale=guidance_scale, + ) + ref_noise_pred = noise_pred + ref_latent = latents + ref_timestep = t_float + accum_dx = 0.0 + accum_dt = 0.0 + reuse_count = 0 + + latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + + print( + f"[SenCache] Cached {cache_count}/{num_inference_steps} steps " + f"({100*cache_count/num_inference_steps:.1f}% cache ratio)" + ) + return latents + # ── CFG cache path ── if use_cfg_cache and do_classifier_free_guidance: # Get timesteps as numpy for Python-level scheduling decisions diff --git a/src/maxdiffusion/tests/wan_sen_cache_test.py b/src/maxdiffusion/tests/wan_sen_cache_test.py new file mode 100644 index 00000000..1d2fe76c --- /dev/null +++ b/src/maxdiffusion/tests/wan_sen_cache_test.py @@ -0,0 +1,354 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+""" + +import os +import time +import unittest + +import numpy as np +import pytest +from absl.testing import absltest + +from maxdiffusion.pipelines.wan.wan_pipeline_2_2 import WanPipeline2_2 + +IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) + + +class WanSenCacheValidationTest(unittest.TestCase): + """Tests that use_sen_cache validation raises correct errors.""" + + def _make_pipeline(self): + pipeline = WanPipeline2_2.__new__(WanPipeline2_2) # bypass __init__ (no model is constructed) + return pipeline + + def test_sen_cache_with_both_scales_low_raises(self): + pipeline = self._make_pipeline() + with self.assertRaises(ValueError) as ctx: + pipeline( + prompt=["test"], + guidance_scale_low=1.0, + guidance_scale_high=1.0, + use_sen_cache=True, + ) + self.assertIn("use_sen_cache", str(ctx.exception)) + + def test_sen_cache_with_low_scale_low_raises(self): + pipeline = self._make_pipeline() + with self.assertRaises(ValueError) as ctx: + pipeline( + prompt=["test"], + guidance_scale_low=0.5, + guidance_scale_high=4.0, + use_sen_cache=True, + ) + self.assertIn("use_sen_cache", str(ctx.exception)) + + def test_sen_cache_with_high_scale_low_raises(self): + pipeline = self._make_pipeline() + with self.assertRaises(ValueError) as ctx: + pipeline( + prompt=["test"], + guidance_scale_low=3.0, + guidance_scale_high=1.0, + use_sen_cache=True, + ) + self.assertIn("use_sen_cache", str(ctx.exception)) + + def test_sen_cache_mutually_exclusive_with_cfg_cache(self): + pipeline = self._make_pipeline() + with self.assertRaises(ValueError) as ctx: + pipeline( + prompt=["test"], + guidance_scale_low=3.0, + guidance_scale_high=4.0, + use_cfg_cache=True, + use_sen_cache=True, + ) + self.assertIn("mutually exclusive", str(ctx.exception)) + + def test_sen_cache_with_valid_scales_no_validation_error(self): + """Both guidance_scales > 1.0 should pass validation (may fail later without model).""" + pipeline = self._make_pipeline() + try: + pipeline( + prompt=["test"],
+ guidance_scale_low=3.0, + guidance_scale_high=4.0, + use_sen_cache=True, + ) + except ValueError as e: + if "use_sen_cache" in str(e): + self.fail(f"Unexpected validation error: {e}") + except Exception: + pass + + def test_no_sen_cache_with_low_scales_no_error(self): + """use_sen_cache=False should never raise our ValueError.""" + pipeline = self._make_pipeline() + try: + pipeline( + prompt=["test"], + guidance_scale_low=0.5, + guidance_scale_high=0.5, + use_sen_cache=False, + ) + except ValueError as e: + if "use_sen_cache" in str(e): + self.fail(f"Unexpected validation error: {e}") + except Exception: + pass + + +class WanSenCacheScheduleTest(unittest.TestCase): + """Tests the SenCache schedule logic (force-compute zones and sensitivity gating). + + Mirrors the schedule computation in run_inference_2_2 to verify correctness + of force_compute zones. The actual sensitivity gating (score <= epsilon) is + data-dependent, so we test the deterministic scheduling constraints here. + """ + + def _get_force_compute_schedule(self, num_inference_steps, boundary_ratio=0.875, num_train_timesteps=1000): + """Extract which steps are forced to compute (cannot be cached). + + Returns (force_compute, step_uses_high) lists.
+ """ + boundary = boundary_ratio * num_train_timesteps + timesteps = np.linspace(num_train_timesteps - 1, 0, num_inference_steps, dtype=np.int32) + step_uses_high = [bool(timesteps[s] >= boundary) for s in range(num_inference_steps)] + + # SenCache hyperparameters (mirrored from run_inference_2_2) + warmup_steps = 1 + nocache_start_ratio = 0.3 + nocache_end_ratio = 0.1 + + nocache_start = int(num_inference_steps * nocache_start_ratio) + nocache_end_begin = int(num_inference_steps * (1.0 - nocache_end_ratio)) + + force_compute = [] + for s in range(num_inference_steps): + is_boundary = s > 0 and step_uses_high[s] != step_uses_high[s - 1] + forced = ( + s < warmup_steps + or s < nocache_start + or s >= nocache_end_begin + or is_boundary + or s == 0 # ref_noise_pred is None on first step + ) + force_compute.append(forced) + + return force_compute, step_uses_high + + def test_first_step_always_forced(self): + """Step 0 must always compute (warmup + ref_noise_pred is None).""" + force_compute, _ = self._get_force_compute_schedule(50) + self.assertTrue(force_compute[0]) + + def test_first_30_percent_always_forced(self): + """First 30% of steps are in the no-cache zone.""" + force_compute, _ = self._get_force_compute_schedule(50) + nocache_start = int(50 * 0.3) # 15 + self.assertTrue(all(force_compute[:nocache_start])) + + def test_last_10_percent_always_forced(self): + """Last 10% of steps are in the no-cache zone.""" + force_compute, _ = self._get_force_compute_schedule(50) + nocache_end_begin = int(50 * 0.9) # 45 + self.assertTrue(all(force_compute[nocache_end_begin:])) + + def test_boundary_transition_forced(self): + """Steps at high-to-low transformer transitions are forced.""" + force_compute, step_uses_high = self._get_force_compute_schedule(50) + for s in range(1, 50): + if step_uses_high[s] != step_uses_high[s - 1]: + self.assertTrue(force_compute[s], f"Boundary step {s} should be forced") + + def test_cacheable_window_exists(self): + """There should be steps in
[30%, 90%) that are NOT forced (eligible for caching).""" + force_compute, _ = self._get_force_compute_schedule(50) + nocache_start = int(50 * 0.3) + nocache_end_begin = int(50 * 0.9) + cacheable = [not force_compute[s] for s in range(nocache_start, nocache_end_begin)] + self.assertGreater(sum(cacheable), 0, "Should have cacheable steps in the middle window") + + def test_short_run_all_forced(self): + """With 3 steps every step is forced: warmup (0), transformer boundary (1), end no-cache zone (2).""" + force_compute, _ = self._get_force_compute_schedule(3) + self.assertTrue(all(force_compute), "3 steps is too short — all should be forced") + + def test_max_reuse_limit(self): + """Stand-alone simulation of max_reuse=3: even if score stays low, a recompute follows 3 reuses.""" + max_reuse = 3 + # Simulate a sequence of cache decisions where score is always below epsilon + reuse_count = 0 + recompute_happened = False + for _ in range(10): + if reuse_count < max_reuse: + reuse_count += 1 + else: + reuse_count = 0 + recompute_happened = True + self.assertTrue(recompute_happened, "Should force recompute after max_reuse consecutive reuses") + + def test_sensitivity_score_formula(self): + """Verify the sensitivity score formula: S = α_x·‖Δx‖ + α_t·|Δt|.""" + alpha_x, alpha_t = 1.0, 1.0 + sen_epsilon = 0.1 + + # Case 1: small deltas => cache hit + score = alpha_x * 0.03 + alpha_t * 0.02 + self.assertLessEqual(score, sen_epsilon, "Small deltas should yield score <= epsilon") + + # Case 2: large latent drift => cache miss + score = alpha_x * 0.5 + alpha_t * 0.02 + self.assertGreater(score, sen_epsilon, "Large dx should yield score > epsilon") + + # Case 3: large timestep drift => cache miss + score = alpha_x * 0.01 + alpha_t * 0.5 + self.assertGreater(score, sen_epsilon, "Large dt should yield score > epsilon") + + def test_all_high_noise_no_cacheable_window(self): + """If boundary_ratio=0, every step uses the high-noise transformer (the only property asserted here).""" + force_compute, step_uses_high =
self._get_force_compute_schedule(50, boundary_ratio=0.0) + self.assertTrue(all(step_uses_high), "All steps should be high-noise") + + def test_nocache_zones_scale_with_steps(self): + """No-cache zones should scale proportionally with num_inference_steps.""" + for n_steps in [20, 50, 100]: + force_compute, _ = self._get_force_compute_schedule(n_steps) + nocache_start = int(n_steps * 0.3) + nocache_end_begin = int(n_steps * 0.9) + self.assertTrue(all(force_compute[:nocache_start]), f"First 30% forced for {n_steps} steps") + self.assertTrue(all(force_compute[nocache_end_begin:]), f"Last 10% forced for {n_steps} steps") + + +@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Requires TPU v7-8 and model weights") +class WanSenCacheSmokeTest(unittest.TestCase): + """End-to-end smoke test: SenCache should be faster with PSNR >= 30 dB and SSIM >= 0.95. + + Runs on TPU v7-8 (8 chips, context_parallelism=8) with WAN 2.2 27B, 720p. + Skipped in CI (GitHub Actions) — run locally with: + python -m pytest src/maxdiffusion/tests/wan_sen_cache_test.py::WanSenCacheSmokeTest -v + """ + + @classmethod + def setUpClass(cls): + from maxdiffusion import pyconfig + from maxdiffusion.checkpointing.wan_checkpointer_2_2 import WanCheckpointer2_2 + + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_27b.yml"), + "num_inference_steps=50", + "height=720", + "width=1280", + "num_frames=81", + "fps=24", + "guidance_scale_low=3.0", + "guidance_scale_high=4.0", + "boundary_ratio=0.875", + "flow_shift=3.0", + "seed=11234567893", + "attention=flash", + "remat_policy=FULL", + "allow_split_physical_axes=True", + "skip_jax_distributed_system=True", + "weights_dtype=bfloat16", + "activations_dtype=bfloat16", + "per_device_batch_size=0.125", + "ici_data_parallelism=1", + "ici_fsdp_parallelism=1", + "ici_context_parallelism=8", + "ici_tensor_parallelism=1", + "flash_min_seq_length=0", + 'flash_block_sizes={"block_q": 2048, "block_kv_compute": 1024, "block_kv": 2048, "block_q_dkv": 2048,
"block_kv_dkv": 2048, "block_kv_dkv_compute": 2048, "use_fused_bwd_kernel": true}', + ], + unittest=True, + ) + cls.config = pyconfig.config + checkpoint_loader = WanCheckpointer2_2(config=cls.config) + cls.pipeline, _, _ = checkpoint_loader.load_checkpoint() + + cls.prompt = [cls.config.prompt] * cls.config.global_batch_size_to_train_on + cls.negative_prompt = [cls.config.negative_prompt] * cls.config.global_batch_size_to_train_on + + # Warmup both XLA code paths + for use_cache in [False, True]: + cls.pipeline( + prompt=cls.prompt, + negative_prompt=cls.negative_prompt, + height=cls.config.height, + width=cls.config.width, + num_frames=cls.config.num_frames, + num_inference_steps=cls.config.num_inference_steps, + guidance_scale_low=cls.config.guidance_scale_low, + guidance_scale_high=cls.config.guidance_scale_high, + use_sen_cache=use_cache, + ) + + def _run_pipeline(self, use_sen_cache): + t0 = time.perf_counter() + videos = self.pipeline( + prompt=self.prompt, + negative_prompt=self.negative_prompt, + height=self.config.height, + width=self.config.width, + num_frames=self.config.num_frames, + num_inference_steps=self.config.num_inference_steps, + guidance_scale_low=self.config.guidance_scale_low, + guidance_scale_high=self.config.guidance_scale_high, + use_sen_cache=use_sen_cache, + ) + return videos, time.perf_counter() - t0 + + def test_sen_cache_speedup_and_fidelity(self): + """SenCache must be faster than baseline with PSNR >= 30 dB and SSIM >= 0.95.""" + videos_baseline, t_baseline = self._run_pipeline(use_sen_cache=False) + videos_cached, t_cached = self._run_pipeline(use_sen_cache=True) + + # Speed check + speedup = t_baseline / t_cached + print(f"Baseline: {t_baseline:.2f}s, SenCache: {t_cached:.2f}s, Speedup: {speedup:.3f}x") + self.assertGreater(speedup, 1.0, f"SenCache should be faster.
Speedup={speedup:.3f}x") + + # Fidelity checks + v1 = np.array(videos_baseline[0], dtype=np.float64) + v2 = np.array(videos_cached[0], dtype=np.float64) + + # PSNR with peak value 1.0 (assumes videos are scaled to [0, 1] — TODO confirm) + mse = np.mean((v1 - v2) ** 2) + psnr = 10.0 * np.log10(1.0 / mse) if mse > 0 else float("inf") + print(f"PSNR: {psnr:.2f} dB") + self.assertGreaterEqual(psnr, 30.0, f"PSNR={psnr:.2f} dB < 30 dB") + + # SSIM (per-frame, whole-frame statistics rather than windowed; constants assume a [0, 1] range) + C1, C2 = 0.01**2, 0.03**2 + ssim_scores = [] + for f in range(v1.shape[0]): + mu1, mu2 = np.mean(v1[f]), np.mean(v2[f]) + sigma1_sq, sigma2_sq = np.var(v1[f]), np.var(v2[f]) + sigma12 = np.mean((v1[f] - mu1) * (v2[f] - mu2)) + ssim = ((2 * mu1 * mu2 + C1) * (2 * sigma12 + C2)) / ((mu1**2 + mu2**2 + C1) * (sigma1_sq + sigma2_sq + C2)) + ssim_scores.append(float(ssim)) + + mean_ssim = np.mean(ssim_scores) + print(f"SSIM: mean={mean_ssim:.4f}, min={np.min(ssim_scores):.4f}") + self.assertGreaterEqual(mean_ssim, 0.95, f"Mean SSIM={mean_ssim:.4f} < 0.95") + + +if __name__ == "__main__": + absltest.main()