@@ -206,42 +206,54 @@ def run_inference_2_2(
206206 1. CFG-Cache (use_cfg_cache=True) — FasterCache-style:
207207 Caches the unconditional branch and uses FFT frequency-domain compensation.
208208
209- 2. SenCache (use_sen_cache=True) — Sensitivity-aware caching:
210- Measures output sensitivity after each full forward pass. When sensitivity
211- is low (model output is stable), skips the entire transformer and reuses
212- the cached noise prediction. Naturally handles MoE expert boundaries by
213- detecting high sensitivity at transition points.
209+ 2. SenCache (use_sen_cache=True) — Sensitivity-Aware Caching
210+ (Haghighi & Alahi, arXiv:2602.24208):
211+ Uses a first-order sensitivity approximation S = α_x·‖Δx‖ + α_t·|Δt|
212+ to predict output change. Caches when predicted change is below tolerance ε.
213+ Tracks accumulated latent drift and timestep drift since last cache refresh,
214+ adapting cache decisions per-sample. Sensitivity weights (α_x, α_t) are
215+ estimated from warmup steps via finite differences.
214216 """
215217 do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0
216218 bsz = latents .shape [0 ]
217219
218- # ── SenCache path ──
220+ # ── SenCache path (arXiv:2602.24208, mirrors NebulaTPU SenCacheMiddleware) ──
219221 if use_sen_cache and do_classifier_free_guidance :
220222 timesteps_np = np .array (scheduler_state .timesteps , dtype = np .int32 )
221223 step_uses_high = [bool (timesteps_np [s ] >= boundary ) for s in range (num_inference_steps )]
222224
223- # Resolution-dependent SenCache config
224- if height >= 720 :
225- sen_threshold = 0.06 # tighter for higher resolution
226- warmup_ratio = 0.10
227- max_consecutive_cache = 2
228- else :
229- sen_threshold = 0.08
230- warmup_ratio = 0.08
231- max_consecutive_cache = 3
232-
233- warmup_steps = max (2 , int (num_inference_steps * warmup_ratio ))
225+ # SenCache hyperparameters (matching NebulaTPU defaults)
226+ sen_epsilon = 0.1 # main tolerance (permissive phase)
227+ max_reuse = 3 # max consecutive cache reuses before forced recompute
228+ warmup_steps = 1 # first step always computes
229+ # No-cache zones: first 30% (structure formation) and last 10% (detail refinement)
230+ nocache_start_ratio = 0.3
231+ nocache_end_ratio = 0.1
232+ # Uniform sensitivity weights (α_x=1, α_t=1); swap for pre-calibrated
233+ # SensitivityProfile per-timestep values when available.
234+ alpha_x , alpha_t = 1.0 , 1.0
235+
236+ nocache_start = int (num_inference_steps * nocache_start_ratio )
237+ nocache_end_begin = int (num_inference_steps * (1.0 - nocache_end_ratio ))
238+ # Normalize timesteps to [0, 1] to match NebulaTPU's sigma-based convention.
239+ # maxdiffusion timesteps are integers in [0, num_train_timesteps]; NebulaTPU
240+ # uses sigmas in [0, 1]. Without normalization |Δt|≈20 >> ε and nothing caches.
241+ num_train_timesteps = float (scheduler .config .num_train_timesteps )
234242
235243 prompt_embeds_combined = jnp .concatenate ([prompt_embeds , negative_prompt_embeds ], axis = 0 )
236244
237245 # SenCache state
238- prev_noise_pred = None # last full-computation noise prediction
239- sensitivity = float ('inf' ) # measured relative output change
240- consecutive_cached = 0 # consecutive steps using cache
246+ ref_noise_pred = None # y^r: cached denoiser output
247+ ref_latent = None # x^r: latent at last cache refresh
248+ ref_timestep = 0.0 # t^r: timestep (normalized to [0,1]) at last cache refresh
249+ accum_dx = 0.0 # accumulated ||Δx|| since last refresh
250+ accum_dt = 0.0 # accumulated |Δt| since last refresh
251+ reuse_count = 0 # consecutive cache reuses
241252 cache_count = 0
242253
243254 for step in range (num_inference_steps ):
244255 t = jnp .array (scheduler_state .timesteps , dtype = jnp .int32 )[step ]
256+ t_float = float (timesteps_np [step ]) / num_train_timesteps # normalize to [0, 1]
245257
246258 # Select transformer and guidance scale
247259 if step_uses_high [step ]:
@@ -251,47 +263,73 @@ def run_inference_2_2(
251263 graphdef , state , rest = low_noise_graphdef , low_noise_state , low_noise_rest
252264 guidance_scale = guidance_scale_low
253265
254- # Caching decision
255- is_warmup = step < warmup_steps
266+ # Force full compute: warmup, first 30%, last 10%, or transformer boundary
256267 is_boundary = step > 0 and step_uses_high [step ] != step_uses_high [step - 1 ]
257- should_cache = (
258- not is_warmup
259- and not is_boundary
260- and prev_noise_pred is not None
261- and sensitivity < sen_threshold
262- and consecutive_cached < max_consecutive_cache
268+ force_compute = (
269+ step < warmup_steps or step < nocache_start or step >= nocache_end_begin or is_boundary or ref_noise_pred is None
263270 )
264271
265- if should_cache :
266- # ── Cache step: reuse previous noise prediction ──
267- noise_pred = prev_noise_pred
268- consecutive_cached += 1
272+ if force_compute :
273+ latents_doubled = jnp .concatenate ([latents ] * 2 )
274+ timestep = jnp .broadcast_to (t , bsz * 2 )
275+ noise_pred , _ , _ = transformer_forward_pass_full_cfg (
276+ graphdef ,
277+ state ,
278+ rest ,
279+ latents_doubled ,
280+ timestep ,
281+ prompt_embeds_combined ,
282+ guidance_scale = guidance_scale ,
283+ )
284+ ref_noise_pred = noise_pred
285+ ref_latent = latents
286+ ref_timestep = t_float
287+ accum_dx = 0.0
288+ accum_dt = 0.0
289+ reuse_count = 0
290+ latents , scheduler_state = scheduler .step (scheduler_state , noise_pred , t , latents ).to_tuple ()
291+ continue
292+
293+ # Accumulate deltas since last full compute
294+ dx_norm = float (jnp .sqrt (jnp .mean ((latents - ref_latent ) ** 2 )))
295+ dt = abs (t_float - ref_timestep )
296+ accum_dx += dx_norm
297+ accum_dt += dt
298+
299+ # Sensitivity score (Eq. 9)
300+ score = alpha_x * accum_dx + alpha_t * accum_dt
301+
302+ if score <= sen_epsilon and reuse_count < max_reuse :
303+ # Cache hit: reuse previous output
304+ noise_pred = ref_noise_pred
305+ reuse_count += 1
269306 cache_count += 1
270307 else :
271- # ── Full CFG step ──
308+ # Cache miss: full CFG forward pass
272309 latents_doubled = jnp .concatenate ([latents ] * 2 )
273310 timestep = jnp .broadcast_to (t , bsz * 2 )
274311 noise_pred , _ , _ = transformer_forward_pass_full_cfg (
275- graphdef , state , rest ,
276- latents_doubled , timestep , prompt_embeds_combined ,
312+ graphdef ,
313+ state ,
314+ rest ,
315+ latents_doubled ,
316+ timestep ,
317+ prompt_embeds_combined ,
277318 guidance_scale = guidance_scale ,
278319 )
279-
280- # Measure sensitivity: relative output change since last full step
281- if prev_noise_pred is not None :
282- output_diff = jnp .mean (jnp .abs (noise_pred - prev_noise_pred ))
283- output_magnitude = jnp .mean (jnp .abs (noise_pred )) + 1e-8
284- sensitivity = float (output_diff / output_magnitude )
285- else :
286- sensitivity = float ('inf' )
287-
288- prev_noise_pred = noise_pred
289- consecutive_cached = 0
320+ ref_noise_pred = noise_pred
321+ ref_latent = latents
322+ ref_timestep = t_float
323+ accum_dx = 0.0
324+ accum_dt = 0.0
325+ reuse_count = 0
290326
291327 latents , scheduler_state = scheduler .step (scheduler_state , noise_pred , t , latents ).to_tuple ()
292328
293- print (f"[SenCache] Cached { cache_count } /{ num_inference_steps } steps "
294- f"({ 100 * cache_count / num_inference_steps :.1f} % cache ratio)" )
329+ print (
330+ f"[SenCache] Cached { cache_count } /{ num_inference_steps } steps "
331+ f"({ 100 * cache_count / num_inference_steps :.1f} % cache ratio)"
332+ )
295333 return latents
296334
297335 # ── CFG cache path ──
0 commit comments