From 12cb7abbfeb83f797d5dc06a5bee7662e761f8c2 Mon Sep 17 00:00:00 2001 From: disk02 <130608498+disk02@users.noreply.github.com> Date: Fri, 9 Jan 2026 18:53:56 -0800 Subject: [PATCH 1/2] Add highlight rolloff in preprocessing --- src/sharp/cli/predict.py | 41 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/src/sharp/cli/predict.py b/src/sharp/cli/predict.py index b34326cf..1326981a 100644 --- a/src/sharp/cli/predict.py +++ b/src/sharp/cli/predict.py @@ -888,6 +888,13 @@ def _finalize_prediction( target_size_wh=(1536, 1536), dtype=torch.float32, ) + if aux.get("highlight_rolloff_applied"): + LOGGER.info( + "Highlight rolloff enabled (sat_frac=%.6f, threshold=%.6f): %s", + aux.get("highlight_sat_frac", 0.0), + aux.get("highlight_sat_frac_threshold", 0.0), + image_path, + ) aux["metrics"] = metrics preprocess_elapsed = 0.0 if metrics and preprocess_start is not None: @@ -1010,6 +1017,30 @@ def preprocess_one( image_np = image_np.copy() image_pt = torch.from_numpy(image_np).to(dtype=dtype, device=device).permute(2, 0, 1) image_pt = image_pt / 255.0 + sat_threshold = 0.98 + sat_frac_threshold = 0.001 + knee_start = 0.98 + rolloff_strength = 4.0 + max_rgb = image_pt.max(dim=0).values + sat_mask = max_rgb >= sat_threshold + sat_frac_pt = sat_mask.float().mean() + sat_frac = float(sat_frac_pt.item()) + highlight_rolloff_applied = False + if sat_frac >= sat_frac_threshold: + knee_start_pt = torch.tensor(knee_start, dtype=image_pt.dtype, device=image_pt.device) + one_minus_knee = torch.tensor( + 1.0 - knee_start, dtype=image_pt.dtype, device=image_pt.device + ) + rolloff_strength_pt = torch.tensor( + rolloff_strength, dtype=image_pt.dtype, device=image_pt.device + ) + normalized = (image_pt - knee_start_pt) / one_minus_knee + rolloff = knee_start_pt + one_minus_knee * ( + 1.0 - torch.exp(-rolloff_strength_pt * normalized) + ) + image_pt = torch.where(image_pt > knee_start_pt, rolloff, image_pt) + image_pt = image_pt.clamp(0.0, 1.0) + highlight_rolloff_applied = True _, height, width = image_pt.shape disparity_factor_pt = torch.tensor([f_px / width], dtype=dtype, device=device) image_resized_pt = F.interpolate( @@ -1031,6 +1062,9 @@ def preprocess_one( "f_px": f_px, "target_w": target_w, "target_h": target_h, + "highlight_rolloff_applied": highlight_rolloff_applied, + "highlight_sat_frac": sat_frac, + "highlight_sat_frac_threshold": sat_frac_threshold, } return image_resized_pt, disparity_factor_pt, aux @@ -1203,6 +1237,13 @@ def predict_image( target_size_wh=target_size_wh, dtype=torch.float32, ) + if aux.get("highlight_rolloff_applied"): + LOGGER.info( + "Highlight rolloff enabled (sat_frac=%.6f, threshold=%.6f): %s", + aux.get("highlight_sat_frac", 0.0), + aux.get("highlight_sat_frac_threshold", 0.0), + "", + ) aux["metrics"] = metrics if metrics and preprocess_start is not None: metrics.add_time("preprocess", perf_counter() - preprocess_start) From d7ce027ce47bcbaf3d2ea24365e2729f9bb266b0 Mon Sep 17 00:00:00 2001 From: disk02 <130608498+disk02@users.noreply.github.com> Date: Fri, 9 Jan 2026 19:52:20 -0800 Subject: [PATCH 2/2] Refine highlight rolloff detection --- src/sharp/cli/predict.py | 50 ++++++++++++++++++++++++++++++---------- 1 file changed, 38 insertions(+), 12 deletions(-) diff --git a/src/sharp/cli/predict.py b/src/sharp/cli/predict.py index 1326981a..a01155a2 100644 --- a/src/sharp/cli/predict.py +++ b/src/sharp/cli/predict.py @@ -890,10 +890,14 @@ def _finalize_prediction( ) if aux.get("highlight_rolloff_applied"): LOGGER.info( - "Highlight rolloff enabled (sat_frac=%.6f, threshold=%.6f): %s", - aux.get("highlight_sat_frac", 0.0), - aux.get("highlight_sat_frac_threshold", 0.0), + "Highlight rolloff enabled [trigger=%s] for %s (white_frac=%.6f, " + "sat_frac=%.6f; white_thr=%.6f, sat_fallback=%.6f)", + aux.get("highlight_rolloff_trigger", "unknown"), image_path, + aux.get("highlight_white_frac", 0.0), + aux.get("highlight_sat_frac", 0.0), + aux.get("highlight_white_frac_threshold", 0.0), + aux.get("highlight_sat_frac_fallback_threshold", 0.0), ) aux["metrics"] = metrics preprocess_elapsed = 0.0 @@ -1018,15 +1022,28 @@ def preprocess_one( image_pt = torch.from_numpy(image_np).to(dtype=dtype, device=device).permute(2, 0, 1) image_pt = image_pt / 255.0 sat_threshold = 0.98 - sat_frac_threshold = 0.001 + white_threshold = 0.98 + chroma_threshold = 0.08 + white_frac_threshold = 0.01 + sat_frac_fallback_threshold = 0.30 knee_start = 0.98 rolloff_strength = 4.0 max_rgb = image_pt.max(dim=0).values + min_rgb = image_pt.min(dim=0).values + chroma = max_rgb - min_rgb sat_mask = max_rgb >= sat_threshold - sat_frac_pt = sat_mask.float().mean() - sat_frac = float(sat_frac_pt.item()) + white_mask = (min_rgb >= white_threshold) & (chroma <= chroma_threshold) + sat_frac = float(sat_mask.float().mean().item()) + white_frac = float(white_mask.float().mean().item()) highlight_rolloff_applied = False - if sat_frac >= sat_frac_threshold: + highlight_rolloff_trigger = "none" + if white_frac >= white_frac_threshold: + highlight_rolloff_applied = True + highlight_rolloff_trigger = "white" + elif sat_frac >= sat_frac_fallback_threshold: + highlight_rolloff_applied = True + highlight_rolloff_trigger = "sat_fallback" + if highlight_rolloff_applied: knee_start_pt = torch.tensor(knee_start, dtype=image_pt.dtype, device=image_pt.device) one_minus_knee = torch.tensor( 1.0 - knee_start, dtype=image_pt.dtype, device=image_pt.device @@ -1040,7 +1057,6 @@ def preprocess_one( ) image_pt = torch.where(image_pt > knee_start_pt, rolloff, image_pt) image_pt = image_pt.clamp(0.0, 1.0) - highlight_rolloff_applied = True _, height, width = image_pt.shape disparity_factor_pt = torch.tensor([f_px / width], dtype=dtype, device=device) image_resized_pt = F.interpolate( @@ -1063,8 +1079,14 @@ def preprocess_one( "target_w": target_w, "target_h": target_h, "highlight_rolloff_applied": highlight_rolloff_applied, + "highlight_rolloff_trigger": highlight_rolloff_trigger, "highlight_sat_frac": sat_frac, - "highlight_sat_frac_threshold": sat_frac_threshold, + "highlight_white_frac": white_frac, + "highlight_sat_threshold": sat_threshold, + "highlight_white_threshold": white_threshold, + "highlight_chroma_threshold": chroma_threshold, + "highlight_white_frac_threshold": white_frac_threshold, + "highlight_sat_frac_fallback_threshold": sat_frac_fallback_threshold, } return image_resized_pt, disparity_factor_pt, aux @@ -1239,10 +1261,14 @@ def predict_image( ) if aux.get("highlight_rolloff_applied"): LOGGER.info( - "Highlight rolloff enabled (sat_frac=%.6f, threshold=%.6f): %s", - aux.get("highlight_sat_frac", 0.0), - aux.get("highlight_sat_frac_threshold", 0.0), + "Highlight rolloff enabled [trigger=%s] for %s (white_frac=%.6f, sat_frac=%.6f; " + "white_thr=%.6f, sat_fallback=%.6f)", + aux.get("highlight_rolloff_trigger", "unknown"), "", + aux.get("highlight_white_frac", 0.0), + aux.get("highlight_sat_frac", 0.0), + aux.get("highlight_white_frac_threshold", 0.0), + aux.get("highlight_sat_frac_fallback_threshold", 0.0), ) aux["metrics"] = metrics if metrics and preprocess_start is not None: