From d1fa219cb7eaa2850d94bf00b06399ddb978a61f Mon Sep 17 00:00:00 2001 From: disk02 <130608498+disk02@users.noreply.github.com> Date: Wed, 21 Jan 2026 15:43:43 -0800 Subject: [PATCH 01/16] Add tiling CLI flags and validation --- src/sharp/cli/predict.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/src/sharp/cli/predict.py b/src/sharp/cli/predict.py index 3cf8707d..b99f3a91 100644 --- a/src/sharp/cli/predict.py +++ b/src/sharp/cli/predict.py @@ -275,6 +275,33 @@ def _center_crop(img: np.ndarray, target_h: int, target_w: int) -> np.ndarray: default=False, help="Render SBS preview using predicted-space gaussians (skips world conversion).", ) +@click.option( + "--tiling", + is_flag=True, + default=False, + help="Enable tiled inference mode (requires --sbs-image and --fast-preview-render).", +) +@click.option( + "--tile-size", + type=int, + default=1536, + show_default=True, + help="Nominal tile size in pixels (square tiles assumed).", +) +@click.option( + "--tile-overlap", + type=float, + default=0.25, + show_default=True, + help="Fractional overlap between tiles in [0.0, 0.5).", +) +@click.option( + "--tile-keep", + type=float, + default=None, + show_default=True, + help="Optional keep region fraction for tiles (defaults to derive from overlap).", +) @click.option( "--sbs-min-opacity", type=float, @@ -389,6 +416,10 @@ def predict_cli( sbs_image_frame: int, stereo_strength: float, fast_preview_render: bool, + tiling: bool, + tile_size: int, + tile_overlap: float, + tile_keep: float | None, sbs_min_opacity: float, sbs_min_scale: float, sbs_max_splats: int | None, @@ -426,6 +457,14 @@ def predict_cli( return if batch_size < 1: raise click.ClickException("--batch-size must be >= 1.") + if tiling and (sbs_image is None or not fast_preview_render): + raise click.ClickException( + "--tiling is only supported with --sbs-image and --fast-preview-render." + ) + if tile_overlap < 0.0 or tile_overlap >= 0.5: + raise click.ClickException("--tile-overlap must be in [0.0, 0.5).") + if tile_keep is not None and (tile_keep <= 0.0 or tile_keep > 1.0): + raise click.ClickException("--tile-keep must be in (0.0, 1.0].") def _natural_sort_key(path: Path) -> list[tuple[int, object]]: relative_path = path.relative_to(input_path).as_posix() From ca836789e9a37dbc95e984032a9972a91eb1e9b9 Mon Sep 17 00:00:00 2001 From: disk02 <130608498+disk02@users.noreply.github.com> Date: Wed, 21 Jan 2026 15:56:02 -0800 Subject: [PATCH 02/16] Add tiling utility helpers --- src/sharp/utils/tiling.py | 106 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) create mode 100644 src/sharp/utils/tiling.py diff --git a/src/sharp/utils/tiling.py b/src/sharp/utils/tiling.py new file mode 100644 index 00000000..95f53939 --- /dev/null +++ b/src/sharp/utils/tiling.py @@ -0,0 +1,106 @@ +"""Utilities for tiled inference and camera compensation. + +For licensing see accompanying LICENSE file. +Copyright (C) 2025 Apple Inc. All Rights Reserved. +""" + +from __future__ import annotations + +from dataclasses import dataclass + +import torch + + +@dataclass(frozen=True) +class Tile: + """Axis-aligned integer tile bounds in pixel coordinates. + + Coordinates follow (x0, y0, x1, y1) with a half-open interval: + x in [x0, x1), y in [y0, y1). + """ + + x0: int + y0: int + x1: int + y1: int + + +def make_tiles(W: int, H: int, tile_size: int, overlap: float) -> list[Tile]: + """Generate a grid of overlapping tiles that cover an image. + + Args: + W: Full image width in pixels. + H: Full image height in pixels. + tile_size: Nominal tile edge length in pixels (square tiles). + overlap: Fractional overlap in [0.0, 0.5). + """ + if W <= 0 or H <= 0: + raise ValueError("W and H must be positive.") + if tile_size <= 0: + raise ValueError("tile_size must be positive.") + if overlap < 0.0 or overlap >= 0.5: + raise ValueError("overlap must be in [0.0, 0.5).") + + if tile_size >= W or tile_size >= H: + return [Tile(0, 0, W, H)] + + stride = max(1, int(round(tile_size * (1.0 - overlap)))) + + def _starts(full: int) -> list[int]: + starts: list[int] = [] + pos = 0 + while pos + tile_size < full: + starts.append(pos) + pos += stride + final_start = max(0, full - tile_size) + if not starts or starts[-1] != final_start: + starts.append(final_start) + return starts + + x_starts = _starts(W) + y_starts = _starts(H) + + tiles: list[Tile] = [] + for y0 in y_starts: + for x0 in x_starts: + x1 = min(W, x0 + tile_size) + y1 = min(H, y0 + tile_size) + x0_tile = x0 + y0_tile = y0 + if x1 - x0_tile < tile_size: + x0_tile = max(0, x1 - tile_size) + if y1 - y0_tile < tile_size: + y0_tile = max(0, y1 - tile_size) + tiles.append(Tile(x0_tile, y0_tile, x1, y1)) + + return tiles + + +def shift_intrinsics_for_tile( + K_full: torch.Tensor, x0: float, y0: float +) -> torch.Tensor: + """Shift intrinsics for a tile offset in full-image coordinates.""" + if K_full.shape != (3, 3): + raise ValueError("K_full must have shape (3, 3).") + K_tile = K_full.clone() + K_tile[0, 2] = K_tile[0, 2] - x0 + K_tile[1, 2] = K_tile[1, 2] - y0 + return K_tile + + +def scale_intrinsics_for_resize( + K: torch.Tensor, src_wh: tuple[int, int], dst_wh: tuple[int, int] +) -> torch.Tensor: + """Scale intrinsics to match a resize from src_wh to dst_wh (width, height).""" + if K.shape != (3, 3): + raise ValueError("K must have shape (3, 3).") + src_w, src_h = src_wh + dst_w, dst_h = dst_wh + if src_w <= 0 or src_h <= 0 or dst_w <= 0 or dst_h <= 0: + raise ValueError("src_wh and dst_wh must be positive.") + scale_x = dst_w / src_w + scale_y = dst_h / src_h + K_scaled = K.clone() + K_scaled[0, :] = K_scaled[0, :] * scale_x + K_scaled[1, :] = K_scaled[1, :] * scale_y + return K_scaled From 20ddfd7a8f3748d1b8b64408bf76e4cd86ec061c Mon Sep 17 00:00:00 2001 From: disk02 <130608498+disk02@users.noreply.github.com> Date: Wed, 21 Jan 2026 17:31:10 -0800 Subject: [PATCH 03/16] Fix tiling early return condition --- src/sharp/utils/tiling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sharp/utils/tiling.py b/src/sharp/utils/tiling.py index 95f53939..940d749d 100644 --- a/src/sharp/utils/tiling.py +++ b/src/sharp/utils/tiling.py @@ -41,7 +41,7 @@ def make_tiles(W: int, H: int, tile_size: int, overlap: float) -> list[Tile]: if overlap < 0.0 or overlap >= 0.5: raise ValueError("overlap must be in [0.0, 0.5).") - if tile_size >= W or tile_size >= H: + if tile_size >= W and tile_size >= H: return [Tile(0, 0, W, H)] stride = max(1, int(round(tile_size * (1.0 - overlap)))) From 1a9fd6c7cb1a1dbc63ea1103101d8754abe86ab9 Mon Sep 17 00:00:00 2001 From: disk02 <130608498+disk02@users.noreply.github.com> Date: Wed, 21 Jan 2026 18:36:18 -0800 Subject: [PATCH 04/16] Add reference width for disparity scaling --- src/sharp/cli/predict.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/sharp/cli/predict.py b/src/sharp/cli/predict.py index b99f3a91..54a03257 100644 --- a/src/sharp/cli/predict.py +++ b/src/sharp/cli/predict.py @@ -1037,6 +1037,7 @@ def preprocess_one( *, target_size_wh: tuple[int, int], dtype: torch.dtype, + reference_width: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor, dict]: target_w, target_h = target_size_wh image_np = np.ascontiguousarray(image_np) @@ -1050,7 +1051,12 @@ def preprocess_one( ) image_pt = image_pt / 255.0 _, height, width = image_pt.shape - disparity_factor_pt = torch.tensor([f_px / width], dtype=dtype, device=device) + width_for_disparity = width + if reference_width is not None: + if not isinstance(reference_width, int) or reference_width <= 0: + raise click.ClickException("--reference_width must be a positive integer.") + width_for_disparity = reference_width + disparity_factor_pt = torch.tensor([f_px / width_for_disparity], dtype=dtype, device=device) image_resized_pt = TF.resize( image_pt, [target_h, target_w], From 7f859ed1fe669a4642a9a98282610e95266a2738 Mon Sep 17 00:00:00 2001 From: disk02 <130608498+disk02@users.noreply.github.com> Date: Wed, 21 Jan 2026 21:13:16 -0800 Subject: [PATCH 05/16] Clarify reference_width error message --- src/sharp/cli/predict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sharp/cli/predict.py b/src/sharp/cli/predict.py index 54a03257..a4261f68 100644 --- a/src/sharp/cli/predict.py +++ b/src/sharp/cli/predict.py @@ -1054,7 +1054,7 @@ def preprocess_one( width_for_disparity = width if reference_width is not None: if not isinstance(reference_width, int) or reference_width <= 0: - raise click.ClickException("--reference_width must be a positive integer.") + raise click.ClickException("reference_width must be a positive integer.") width_for_disparity = reference_width disparity_factor_pt = torch.tensor([f_px / width_for_disparity], dtype=dtype, device=device) image_resized_pt = TF.resize( From 5d2bb30221efb4159a3701ef3862a66efa04af96 Mon Sep 17 00:00:00 2001 From: disk02 <130608498+disk02@users.noreply.github.com> Date: Wed, 21 Jan 2026 21:13:28 -0800 Subject: [PATCH 06/16] Add tiled prediction path --- src/sharp/cli/predict.py | 166 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 154 insertions(+), 12 deletions(-) diff --git a/src/sharp/cli/predict.py b/src/sharp/cli/predict.py index a4261f68..7b3906da 100644 --- a/src/sharp/cli/predict.py +++ b/src/sharp/cli/predict.py @@ -29,6 +29,7 @@ RGBGaussianPredictor, create_predictor, ) +from sharp.utils.tiling import make_tiles, scale_intrinsics_for_resize, shift_intrinsics_for_tile from sharp.utils import io from sharp.utils import logging as logging_utils from sharp.utils.gaussians import ( @@ -457,6 +458,10 @@ def predict_cli( return if batch_size < 1: raise click.ClickException("--batch-size must be >= 1.") + if tiling and batch_size > 1: + raise click.ClickException( + "--tiling is not yet supported with batched inputs; use --batch-size=1." + ) if tiling and (sbs_image is None or not fast_preview_render): raise click.ClickException( "--tiling is only supported with --sbs-image and --fast-preview-render." @@ -810,6 +815,8 @@ def _finalize_prediction( want_world_for_predict = want_world_for_render or ( effective_save_ply and not defer_export_world ) + if tiling: + want_world_for_predict = True if skip_world_conversion and (want_world_for_predict or effective_save_ply): raise click.ClickException( "World-space conversion is required for rendering or PLY export. " @@ -826,18 +833,32 @@ def _finalize_prediction( LOGGER.info("Computing world-space gaussians: %s", "yes" if space == "world" else "no") predict_start = perf_counter() - prediction = predict_image( - gaussian_predictor, - image, - f_px, - torch.device(device), - amp_enabled=amp, - amp_dtype=amp_dtype_to_use, - metrics=metrics, - return_world=want_world_for_predict, - return_unprojection=fast_preview_render, - return_unprojection_context=fast_preview_render or effective_save_ply, - ) + if tiling: + prediction = predict_image_tiled( + gaussian_predictor, + image, + f_px, + torch.device(device), + tile_size=tile_size, + tile_overlap=tile_overlap, + target_size_wh=(1536, 1536), + amp_enabled=amp, + amp_dtype=amp_dtype_to_use, + metrics=metrics, + ) + else: + prediction = predict_image( + gaussian_predictor, + image, + f_px, + torch.device(device), + amp_enabled=amp, + amp_dtype=amp_dtype_to_use, + metrics=metrics, + return_world=want_world_for_predict, + return_unprojection=fast_preview_render, + return_unprojection_context=fast_preview_render or effective_save_ply, + ) metrics.add_time("predict_total", perf_counter() - predict_start) _finalize_prediction( prediction=prediction, @@ -1289,6 +1310,127 @@ def predict_image( ) +def predict_image_tiled( + predictor: RGBGaussianPredictor, + image: np.ndarray, + f_px: float, + device: torch.device, + *, + tile_size: int, + tile_overlap: float, + target_size_wh: tuple[int, int], + amp_enabled: bool, + amp_dtype: torch.dtype | None, + metrics: Metrics | None = None, +) -> PredictionResult: + """Predict Gaussians per tile and unproject each tile to world space.""" + height_full, width_full = image.shape[:2] + fx = f_px + fy = f_px + cx = (width_full - 1) / 2.0 + cy = (height_full - 1) / 2.0 + k_full = torch.tensor( + [ + [fx, 0.0, cx], + [0.0, fy, cy], + [0.0, 0.0, 1.0], + ], + device=device, + dtype=torch.float32, + ) + tiles = make_tiles(width_full, height_full, tile_size, tile_overlap) + LOGGER.info( + "Tiling enabled: tile_size=%d tile_overlap=%.2f tiles=%d", + tile_size, + tile_overlap, + len(tiles), + ) + + amp_dtype_to_use = amp_dtype if amp_enabled else None + target_w, target_h = target_size_wh + + tile_worlds: list[Gaussians3D] = [] + for tile in tiles: + tile_img = image[tile.y0 : tile.y1, tile.x0 : tile.x1] + preprocess_start = perf_counter() if metrics else None + tile_resized_pt, tile_disp_factor, tile_aux = preprocess_one( + tile_img, + f_px, + device, + target_size_wh=target_size_wh, + dtype=torch.float32, + reference_width=width_full, + ) + tile_aux["metrics"] = metrics + if metrics and preprocess_start is not None: + metrics.add_time("tile_preprocess", perf_counter() - preprocess_start) + + forward_start = perf_counter() if metrics else None + gaussians_ndc_batch = model_forward_batch( + predictor, + tile_resized_pt, + tile_disp_factor, + amp_dtype=amp_dtype_to_use, + ) + if metrics and forward_start is not None: + metrics.add_time("tile_forward", perf_counter() - forward_start) + + gaussians_ndc_one = _slice_gaussians(gaussians_ndc_batch, 0) + prediction = postprocess_one( + gaussians_ndc_one, + tile_aux, + return_world=False, + return_unprojection=False, + device=device, + ) + + k_tile = shift_intrinsics_for_tile(k_full, tile.x0, tile.y0) + k_tile_resized = scale_intrinsics_for_resize( + k_tile, + src_wh=(width_full, height_full), + dst_wh=target_size_wh, + ) + intrinsics_4 = torch.eye(4, device=device, dtype=k_tile_resized.dtype) + intrinsics_4[:3, :3] = k_tile_resized + extrinsics = torch.eye(4, device=device, dtype=k_tile_resized.dtype) + _ = get_unprojection_matrix(extrinsics, intrinsics_4, (target_w, target_h)) + + unproject_start = perf_counter() if metrics else None + world_gaussians_tile = unproject_gaussians( + prediction.pred, + extrinsics, + intrinsics_4, + (target_w, target_h), + metrics=metrics, + ) + if metrics and unproject_start is not None: + metrics.add_time("tile_unproject", perf_counter() - unproject_start) + + tile_worlds.append(world_gaussians_tile) + + def _concat_gaussians(items: list[Gaussians3D]) -> Gaussians3D: + if not items: + raise ValueError("No tiles produced for tiled prediction.") + return Gaussians3D( + mean_vectors=torch.cat([g.mean_vectors for g in items], dim=0), + singular_values=torch.cat([g.singular_values for g in items], dim=0), + quaternions=torch.cat([g.quaternions for g in items], dim=0), + colors=torch.cat([g.colors for g in items], dim=0), + opacities=torch.cat([g.opacities for g in items], dim=0), + ) + + world_gaussians = _concat_gaussians(tile_worlds) + pred_gaussians = world_gaussians + unprojection_matrix = torch.eye(4, device=device, dtype=torch.float32) + + return PredictionResult( + pred=pred_gaussians, + world=world_gaussians, + unprojection_matrix=unprojection_matrix, + unprojection_context=None, + ) + + def _cast_gaussians(gaussians: Gaussians3D, dtype: torch.dtype) -> Gaussians3D: return Gaussians3D( mean_vectors=gaussians.mean_vectors.to(dtype=dtype), From 06648e4fe8e872e5b5fc5794b149a87f7ef76849 Mon Sep 17 00:00:00 2001 From: disk02 <130608498+disk02@users.noreply.github.com> Date: Wed, 21 Jan 2026 21:57:40 -0800 Subject: [PATCH 07/16] Remap tiled gaussians to global pred space --- src/sharp/cli/predict.py | 49 ++++++++++++++++++++++++++++------------ 1 file changed, 34 insertions(+), 15 deletions(-) diff --git a/src/sharp/cli/predict.py b/src/sharp/cli/predict.py index 7b3906da..af6cabe9 100644 --- a/src/sharp/cli/predict.py +++ b/src/sharp/cli/predict.py @@ -35,6 +35,7 @@ from sharp.utils.gaussians import ( Gaussians3D, SceneMetaData, + apply_transform, get_unprojection_matrix, save_ply, unproject_gaussians, @@ -1323,7 +1324,7 @@ def predict_image_tiled( amp_dtype: torch.dtype | None, metrics: Metrics | None = None, ) -> PredictionResult: - """Predict Gaussians per tile and unproject each tile to world space.""" + """Predict Gaussians per tile and remap into a shared pred-space.""" height_full, width_full = image.shape[:2] fx = f_px fy = f_px @@ -1349,7 +1350,22 @@ def predict_image_tiled( amp_dtype_to_use = amp_dtype if amp_enabled else None target_w, target_h = target_size_wh - tile_worlds: list[Gaussians3D] = [] + k_full_resized = scale_intrinsics_for_resize( + k_full, + src_wh=(width_full, height_full), + dst_wh=target_size_wh, + ) + intrinsics_full = torch.eye(4, device=device, dtype=k_full_resized.dtype) + intrinsics_full[:3, :3] = k_full_resized + extrinsics_full = torch.eye(4, device=device, dtype=k_full_resized.dtype) + u_global = get_unprojection_matrix( + extrinsics_full, + intrinsics_full, + (target_w, target_h), + ) + u_global_inv = torch.linalg.inv(u_global) + + tile_preds_global: list[Gaussians3D] = [] for tile in tiles: tile_img = image[tile.y0 : tile.y1, tile.x0 : tile.x1] preprocess_start = perf_counter() if metrics else None @@ -1393,20 +1409,20 @@ def predict_image_tiled( intrinsics_4 = torch.eye(4, device=device, dtype=k_tile_resized.dtype) intrinsics_4[:3, :3] = k_tile_resized extrinsics = torch.eye(4, device=device, dtype=k_tile_resized.dtype) - _ = get_unprojection_matrix(extrinsics, intrinsics_4, (target_w, target_h)) + u_tile = get_unprojection_matrix(extrinsics, intrinsics_4, (target_w, target_h)) - unproject_start = perf_counter() if metrics else None - world_gaussians_tile = unproject_gaussians( + # Remap tile pred-space gaussians into the global pred-space so fast preview renders once. + remap_start = perf_counter() if metrics else None + tile_to_global_pred = u_global_inv @ u_tile + global_pred_tile = apply_transform( prediction.pred, - extrinsics, - intrinsics_4, - (target_w, target_h), + tile_to_global_pred[:3], metrics=metrics, ) - if metrics and unproject_start is not None: - metrics.add_time("tile_unproject", perf_counter() - unproject_start) + if metrics and remap_start is not None: + metrics.add_time("tile_remap", perf_counter() - remap_start) - tile_worlds.append(world_gaussians_tile) + tile_preds_global.append(global_pred_tile) def _concat_gaussians(items: list[Gaussians3D]) -> Gaussians3D: if not items: @@ -1419,13 +1435,16 @@ def _concat_gaussians(items: list[Gaussians3D]) -> Gaussians3D: opacities=torch.cat([g.opacities for g in items], dim=0), ) - world_gaussians = _concat_gaussians(tile_worlds) - pred_gaussians = world_gaussians - unprojection_matrix = torch.eye(4, device=device, dtype=torch.float32) + pred_gaussians = _concat_gaussians(tile_preds_global) + if pred_gaussians.mean_vectors.numel() == 0: + raise ValueError("Tiled prediction produced no Gaussians.") + unprojection_matrix = u_global + if unprojection_matrix is None: + raise ValueError("Missing global unprojection matrix for tiled prediction.") return PredictionResult( pred=pred_gaussians, - world=world_gaussians, + world=None, unprojection_matrix=unprojection_matrix, unprojection_context=None, ) From 59040f905efc92e0f25aeebf3a8ccc40b4a37240 Mon Sep 17 00:00:00 2001 From: disk02 <130608498+disk02@users.noreply.github.com> Date: Wed, 21 Jan 2026 22:33:53 -0800 Subject: [PATCH 08/16] Suppress tiled edge gaussians --- src/sharp/cli/predict.py | 63 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 62 insertions(+), 1 deletion(-) diff --git a/src/sharp/cli/predict.py b/src/sharp/cli/predict.py index af6cabe9..080b25a1 100644 --- a/src/sharp/cli/predict.py +++ b/src/sharp/cli/predict.py @@ -842,6 +842,7 @@ def _finalize_prediction( torch.device(device), tile_size=tile_size, tile_overlap=tile_overlap, + tile_keep=tile_keep, target_size_wh=(1536, 1536), amp_enabled=amp, amp_dtype=amp_dtype_to_use, @@ -1319,6 +1320,7 @@ def predict_image_tiled( *, tile_size: int, tile_overlap: float, + tile_keep: float | None, target_size_wh: tuple[int, int], amp_enabled: bool, amp_dtype: torch.dtype | None, @@ -1366,7 +1368,7 @@ def predict_image_tiled( u_global_inv = torch.linalg.inv(u_global) tile_preds_global: list[Gaussians3D] = [] - for tile in tiles: + for tile_index, tile in enumerate(tiles, start=1): tile_img = image[tile.y0 : tile.y1, tile.x0 : tile.x1] preprocess_start = perf_counter() if metrics else None tile_resized_pt, tile_disp_factor, tile_aux = preprocess_one( @@ -1400,6 +1402,65 @@ def predict_image_tiled( device=device, ) + tile_width = tile.x1 - tile.x0 + tile_height = tile.y1 - tile.y0 + if tile_keep is not None: + keep_margin_px = int(round((1.0 - tile_keep) * tile_size / 2.0)) + else: + keep_margin_px = int(tile_size * tile_overlap / 2.0) + keep_margin_px = max(0, keep_margin_px) + max_margin = max(0, min(tile_width, tile_height) // 2) + keep_margin_px = min(keep_margin_px, max_margin) + x_keep0 = keep_margin_px + y_keep0 = keep_margin_px + x_keep1 = tile_width - keep_margin_px + y_keep1 = tile_height - keep_margin_px + if x_keep1 <= x_keep0 or y_keep1 <= y_keep0: + x_keep0 = 0 + y_keep0 = 0 + x_keep1 = tile_width + y_keep1 = tile_height + + mean_vectors = prediction.pred.mean_vectors + if mean_vectors.ndim == 3: + mean_xy = mean_vectors[0, :, :2] + elif mean_vectors.ndim == 2: + mean_xy = mean_vectors[:, :2] + else: + raise ValueError("Unsupported gaussians mean_vectors shape for tiling.") + gx = (mean_xy[:, 0] + 1.0) * 0.5 * tile_width + gy = (mean_xy[:, 1] + 1.0) * 0.5 * tile_height + keep_mask = ( + (gx >= x_keep0) + & (gx < x_keep1) + & (gy >= y_keep0) + & (gy < y_keep1) + ) + kept_count = int(keep_mask.sum().item()) + total_count = int(keep_mask.numel()) + LOGGER.debug("Tile %d: kept %d / %d gaussians", tile_index, kept_count, total_count) + if kept_count == 0: + LOGGER.warning("Tile %d produced no gaussians after edge suppression.", tile_index) + continue + + def _index(tensor: torch.Tensor) -> torch.Tensor: + dim = 1 if tensor.ndim >= 3 else 0 + indices = keep_mask.nonzero(as_tuple=False).flatten() + return tensor.index_select(dim, indices) + + prediction = PredictionResult( + pred=Gaussians3D( + mean_vectors=_index(prediction.pred.mean_vectors), + singular_values=_index(prediction.pred.singular_values), + quaternions=_index(prediction.pred.quaternions), + colors=_index(prediction.pred.colors), + opacities=_index(prediction.pred.opacities), + ), + world=None, + unprojection_matrix=None, + unprojection_context=None, + ) + k_tile = shift_intrinsics_for_tile(k_full, tile.x0, tile.y0) k_tile_resized = scale_intrinsics_for_resize( k_tile, From 517537b202094da3ebca11e8c0704f53d080b6ec Mon Sep 17 00:00:00 2001 From: disk02 <130608498+disk02@users.noreply.github.com> Date: Wed, 21 Jan 2026 22:47:04 -0800 Subject: [PATCH 09/16] Fix tile edge suppression coordinates --- src/sharp/cli/predict.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/src/sharp/cli/predict.py b/src/sharp/cli/predict.py index 080b25a1..87b8af78 100644 --- a/src/sharp/cli/predict.py +++ b/src/sharp/cli/predict.py @@ -1402,24 +1402,18 @@ def predict_image_tiled( device=device, ) - tile_width = tile.x1 - tile.x0 - tile_height = tile.y1 - tile.y0 + # Gaussians are in pred-space for the inference canvas; keep mask uses target pixels. + min_target = min(target_w, target_h) if tile_keep is not None: - keep_margin_px = int(round((1.0 - tile_keep) * tile_size / 2.0)) + keep_margin_px = int(round((1.0 - tile_keep) * min_target / 2.0)) else: - keep_margin_px = int(tile_size * tile_overlap / 2.0) - keep_margin_px = max(0, keep_margin_px) - max_margin = max(0, min(tile_width, tile_height) // 2) - keep_margin_px = min(keep_margin_px, max_margin) + keep_margin_px = int(round(min_target * tile_overlap / 2.0)) + max_margin = max(0, min_target // 2 - 1) + keep_margin_px = max(0, min(keep_margin_px, max_margin)) x_keep0 = keep_margin_px y_keep0 = keep_margin_px - x_keep1 = tile_width - keep_margin_px - y_keep1 = tile_height - keep_margin_px - if x_keep1 <= x_keep0 or y_keep1 <= y_keep0: - x_keep0 = 0 - y_keep0 = 0 - x_keep1 = tile_width - y_keep1 = tile_height + x_keep1 = target_w - keep_margin_px + y_keep1 = target_h - keep_margin_px mean_vectors = prediction.pred.mean_vectors if mean_vectors.ndim == 3: @@ -1428,8 +1422,8 @@ def predict_image_tiled( mean_xy = mean_vectors[:, :2] else: raise ValueError("Unsupported gaussians mean_vectors shape for tiling.") - gx = (mean_xy[:, 0] + 1.0) * 0.5 * tile_width - gy = (mean_xy[:, 1] + 1.0) * 0.5 * tile_height + gx = (mean_xy[:, 0] + 1.0) * 0.5 * target_w + gy = (mean_xy[:, 1] + 1.0) * 0.5 * target_h keep_mask = ( (gx >= x_keep0) & (gx < x_keep1) From de9b2e94761e1554856c3e5aed05266c2793bd1b Mon Sep 17 00:00:00 2001 From: disk02 <130608498+disk02@users.noreply.github.com> Date: Wed, 21 Jan 2026 23:01:16 -0800 Subject: [PATCH 10/16] Fix tiled gaussian concatenation shape --- src/sharp/cli/predict.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/sharp/cli/predict.py b/src/sharp/cli/predict.py index 87b8af78..407b88fe 100644 --- a/src/sharp/cli/predict.py +++ b/src/sharp/cli/predict.py @@ -1482,12 +1482,15 @@ def _index(tensor: torch.Tensor) -> torch.Tensor: def _concat_gaussians(items: list[Gaussians3D]) -> Gaussians3D: if not items: raise ValueError("No tiles produced for tiled prediction.") + # Gaussians are batch-first (B, N, ...); keep B=1 and concat along splat dim. + if not all(g.mean_vectors.shape[0] == 1 for g in items): + raise ValueError("Tiled gaussians must have batch dimension of 1.") return Gaussians3D( - mean_vectors=torch.cat([g.mean_vectors for g in items], dim=0), - singular_values=torch.cat([g.singular_values for g in items], dim=0), - quaternions=torch.cat([g.quaternions for g in items], dim=0), - colors=torch.cat([g.colors for g in items], dim=0), - opacities=torch.cat([g.opacities for g in items], dim=0), + mean_vectors=torch.cat([g.mean_vectors for g in items], dim=1), + singular_values=torch.cat([g.singular_values for g in items], dim=1), + quaternions=torch.cat([g.quaternions for g in items], dim=1), + colors=torch.cat([g.colors for g in items], dim=1), + opacities=torch.cat([g.opacities for g in items], dim=1), ) pred_gaussians = _concat_gaussians(tile_preds_global) From 6628b4082820bea8d4bcee7c2097ba107791c722 Mon Sep 17 00:00:00 2001 From: disk02 <130608498+disk02@users.noreply.github.com> Date: Wed, 21 Jan 2026 23:24:56 -0800 Subject: [PATCH 11/16] Guard tiled mode against export and compare --- src/sharp/cli/predict.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/sharp/cli/predict.py b/src/sharp/cli/predict.py index 407b88fe..8c43a626 100644 --- a/src/sharp/cli/predict.py +++ b/src/sharp/cli/predict.py @@ -463,6 +463,14 @@ def predict_cli( raise click.ClickException( "--tiling is not yet supported with batched inputs; use --batch-size=1." ) + if tiling and fast_preview_compare: + raise click.ClickException( + "--tiling does not support compare; use non-tiled mode." + ) + if tiling and save_ply: + raise click.ClickException( + "--tiling does not support PLY export; use non-tiled mode." + ) if tiling and (sbs_image is None or not fast_preview_render): raise click.ClickException( "--tiling is only supported with --sbs-image and --fast-preview-render." @@ -578,7 +586,9 @@ def _natural_sort_key(path: Path) -> list[tuple[int, object]]: raise click.ClickException( "--fast-preview-render is only supported for --sbs-image (not --render)." ) - if save_ply is None and want_sbs_image: + if tiling: + effective_save_ply = False + elif save_ply is None and want_sbs_image: effective_save_ply = False elif save_ply is None and not want_sbs_image: effective_save_ply = True @@ -817,7 +827,7 @@ def _finalize_prediction( effective_save_ply and not defer_export_world ) if tiling: - want_world_for_predict = True + want_world_for_predict = False if skip_world_conversion and (want_world_for_predict or effective_save_ply): raise click.ClickException( "World-space conversion is required for rendering or PLY export. " From 31d1a21b7d1ad3da4b857180179c59a197950cd2 Mon Sep 17 00:00:00 2001 From: disk02 <130608498+disk02@users.noreply.github.com> Date: Wed, 21 Jan 2026 23:40:38 -0800 Subject: [PATCH 12/16] Fix tiled gaussian indexing --- src/sharp/cli/predict.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/sharp/cli/predict.py b/src/sharp/cli/predict.py index 8c43a626..7f1b0deb 100644 --- a/src/sharp/cli/predict.py +++ b/src/sharp/cli/predict.py @@ -1448,9 +1448,11 @@ def predict_image_tiled( continue def _index(tensor: torch.Tensor) -> torch.Tensor: - dim = 1 if tensor.ndim >= 3 else 0 - indices = keep_mask.nonzero(as_tuple=False).flatten() - return tensor.index_select(dim, indices) + if tensor.ndim == 1: + return tensor[keep_mask] + if tensor.ndim == 2: + return tensor[:, keep_mask] + return tensor[:, keep_mask, ...] prediction = PredictionResult( pred=Gaussians3D( From 2d7fcb8a842b37a9b2016a1f1a1e2409ee1f17b3 Mon Sep 17 00:00:00 2001 From: disk02 <130608498+disk02@users.noreply.github.com> Date: Thu, 22 Jan 2026 22:14:38 -0800 Subject: [PATCH 13/16] Fix tiled intrinsics scaling --- src/sharp/cli/predict.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/sharp/cli/predict.py b/src/sharp/cli/predict.py index 7f1b0deb..7215be84 100644 --- a/src/sharp/cli/predict.py +++ b/src/sharp/cli/predict.py @@ -1467,12 +1467,26 @@ def _index(tensor: torch.Tensor) -> torch.Tensor: unprojection_context=None, ) + tile_w = tile.x1 - tile.x0 + tile_h = tile.y1 - tile.y0 k_tile = shift_intrinsics_for_tile(k_full, tile.x0, tile.y0) k_tile_resized = scale_intrinsics_for_resize( k_tile, - src_wh=(width_full, height_full), + src_wh=(tile_w, tile_h), dst_wh=target_size_wh, ) + LOGGER.debug( + "tile=%d x0=%d y0=%d w=%d h=%d cx=%.2f cx_resized=%.2f cy=%.2f cy_resized=%.2f", + tile_index, + tile.x0, + tile.y0, + tile_w, + tile_h, + k_tile[0, 2].item(), + k_tile_resized[0, 2].item(), + k_tile[1, 2].item(), + k_tile_resized[1, 2].item(), + ) intrinsics_4 = torch.eye(4, device=device, dtype=k_tile_resized.dtype) intrinsics_4[:3, :3] = k_tile_resized extrinsics = torch.eye(4, device=device, dtype=k_tile_resized.dtype) From 46d7b922a3b678ee9a1596a2ae9702006ba2da1e Mon Sep 17 00:00:00 2001 From: disk02 <130608498+disk02@users.noreply.github.com> Date: Thu, 22 Jan 2026 22:53:01 -0800 Subject: [PATCH 14/16] Fix tiled edge suppression margins --- src/sharp/cli/predict.py | 55 +++++++++++++++++++++++++++++++++------- 1 file changed, 46 insertions(+), 9 deletions(-) diff --git a/src/sharp/cli/predict.py b/src/sharp/cli/predict.py index 7215be84..21dddbee 100644 --- a/src/sharp/cli/predict.py +++ b/src/sharp/cli/predict.py @@ -1413,17 +1413,24 @@ def predict_image_tiled( ) # Gaussians are in pred-space for the inference canvas; keep mask uses target pixels. - min_target = min(target_w, target_h) if tile_keep is not None: - keep_margin_px = int(round((1.0 - tile_keep) * min_target / 2.0)) + base_x_margin = int(round((1.0 - tile_keep) * target_w / 2.0)) + base_y_margin = int(round((1.0 - tile_keep) * target_h / 2.0)) else: - keep_margin_px = int(round(min_target * tile_overlap / 2.0)) - max_margin = max(0, min_target // 2 - 1) - keep_margin_px = max(0, min(keep_margin_px, max_margin)) - x_keep0 = keep_margin_px - y_keep0 = keep_margin_px - x_keep1 = target_w - keep_margin_px - y_keep1 = target_h - keep_margin_px + base_x_margin = int(round(tile_overlap * target_w / 2.0)) + base_y_margin = int(round(tile_overlap * target_h / 2.0)) + base_x_margin = max(0, min(base_x_margin, target_w // 2 - 1)) + base_y_margin = max(0, min(base_y_margin, target_h // 2 - 1)) + + left_margin = base_x_margin if tile.x0 > 0 else 0 + right_margin = base_x_margin if tile.x1 < width_full else 0 + top_margin = base_y_margin if tile.y0 > 0 else 0 + bottom_margin = base_y_margin if tile.y1 < height_full else 0 + + x_keep0 = left_margin + y_keep0 = top_margin + x_keep1 = target_w - right_margin + y_keep1 = target_h - bottom_margin mean_vectors = prediction.pred.mean_vectors if mean_vectors.ndim == 3: @@ -1434,6 +1441,36 @@ def predict_image_tiled( raise ValueError("Unsupported gaussians mean_vectors shape for tiling.") gx = (mean_xy[:, 0] + 1.0) * 0.5 * target_w gy = (mean_xy[:, 1] + 1.0) * 0.5 * target_h + if LOGGER.isEnabledFor(logging.DEBUG): + mean_x_min = mean_xy[:, 0].min().item() + mean_x_max = mean_xy[:, 0].max().item() + mean_y_min = mean_xy[:, 1].min().item() + mean_y_max = mean_xy[:, 1].max().item() + gx_min = gx.min().item() + gx_max = gx.max().item() + gy_min = gy.min().item() + gy_max = gy.max().item() + LOGGER.debug( + "tile=%d x0=%d y0=%d x1=%d y1=%d mean_x=[%.3f,%.3f] mean_y=[%.3f,%.3f] " + "gx=[%.1f,%.1f] gy=[%.1f,%.1f] keep=[%d:%d,%d:%d]", + tile_index, + tile.x0, + tile.y0, + tile.x1, + tile.y1, + mean_x_min, + mean_x_max, + mean_y_min, + mean_y_max, + gx_min, + gx_max, + gy_min, + gy_max, + x_keep0, + x_keep1, + y_keep0, + y_keep1, + ) keep_mask = ( (gx >= x_keep0) & (gx < x_keep1) From be7d01d202d5b2505e4eb5792a85793351730108 Mon Sep 17 00:00:00 2001 From: disk02 <130608498+disk02@users.noreply.github.com> Date: Thu, 22 Jan 2026 23:23:44 -0800 Subject: [PATCH 15/16] Fix tiled edge suppression projection --- src/sharp/cli/predict.py | 89 ++++++++++++++++++++++++++-------------- 1 file changed, 58 insertions(+), 31 deletions(-) diff --git a/src/sharp/cli/predict.py b/src/sharp/cli/predict.py index 21dddbee..fe481026 100644 --- a/src/sharp/cli/predict.py +++ b/src/sharp/cli/predict.py @@ -1412,6 +1412,27 @@ def predict_image_tiled( device=device, ) + tile_w = tile.x1 - tile.x0 + tile_h = tile.y1 - tile.y0 + k_tile = shift_intrinsics_for_tile(k_full, tile.x0, tile.y0) + k_tile_resized = scale_intrinsics_for_resize( + k_tile, + src_wh=(tile_w, tile_h), + dst_wh=target_size_wh, + ) + LOGGER.debug( + "tile=%d x0=%d y0=%d w=%d h=%d cx=%.2f cx_resized=%.2f cy=%.2f cy_resized=%.2f", + tile_index, + tile.x0, + tile.y0, + tile_w, + tile_h, + k_tile[0, 2].item(), + k_tile_resized[0, 2].item(), + k_tile[1, 2].item(), + k_tile_resized[1, 2].item(), + ) + # Gaussians are in pred-space for the inference canvas; keep mask uses target pixels. if tile_keep is not None: base_x_margin = int(round((1.0 - tile_keep) * target_w / 2.0)) @@ -1439,20 +1460,46 @@ def predict_image_tiled( mean_xy = mean_vectors[:, :2] else: raise ValueError("Unsupported gaussians mean_vectors shape for tiling.") - gx = (mean_xy[:, 0] + 1.0) * 0.5 * target_w - gy = (mean_xy[:, 1] + 1.0) * 0.5 * target_h + + def _project_to_pixels( + mean_vectors: torch.Tensor, intrinsics: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + if mean_vectors.ndim == 2: + mean_vectors = mean_vectors.unsqueeze(0) + squeeze = True + elif mean_vectors.ndim == 3: + squeeze = False + else: + raise ValueError("Unsupported gaussians mean_vectors shape for projection.") + x = mean_vectors[..., 0] + y = mean_vectors[..., 1] + z = mean_vectors[..., 2].clamp_min(1e-6) + fx = intrinsics[0, 0] + fy = intrinsics[1, 1] + cx = intrinsics[0, 2] + cy = intrinsics[1, 2] + u = fx * (x / z) + cx + v = fy * (y / z) + cy + if squeeze: + return u[0], v[0] + return u, v + + gx, gy = _project_to_pixels(prediction.pred.mean_vectors, k_tile_resized) + if gx.ndim > 1: + gx = gx[0] + gy = gy[0] if LOGGER.isEnabledFor(logging.DEBUG): mean_x_min = mean_xy[:, 0].min().item() mean_x_max = mean_xy[:, 0].max().item() mean_y_min = mean_xy[:, 1].min().item() mean_y_max = mean_xy[:, 1].max().item() - gx_min = gx.min().item() - gx_max = gx.max().item() - gy_min = gy.min().item() - gy_max = gy.max().item() + u_min = gx.min().item() + u_max = gx.max().item() + v_min = gy.min().item() + v_max = gy.max().item() LOGGER.debug( "tile=%d x0=%d y0=%d x1=%d y1=%d mean_x=[%.3f,%.3f] mean_y=[%.3f,%.3f] " - "gx=[%.1f,%.1f] gy=[%.1f,%.1f] keep=[%d:%d,%d:%d]", + "u=[%.1f,%.1f] v=[%.1f,%.1f] keep=[%d:%d,%d:%d]", tile_index, tile.x0, tile.y0, @@ -1462,10 +1509,10 @@ def predict_image_tiled( mean_x_max, mean_y_min, mean_y_max, - gx_min, - gx_max, - gy_min, - gy_max, + u_min, + u_max, + v_min, + v_max, x_keep0, x_keep1, y_keep0, @@ -1504,26 +1551,6 @@ def _index(tensor: torch.Tensor) -> torch.Tensor: unprojection_context=None, ) - tile_w = tile.x1 - tile.x0 - tile_h = tile.y1 - tile.y0 - k_tile = shift_intrinsics_for_tile(k_full, tile.x0, tile.y0) - k_tile_resized = scale_intrinsics_for_resize( - k_tile, - src_wh=(tile_w, tile_h), - dst_wh=target_size_wh, - ) - LOGGER.debug( - "tile=%d x0=%d y0=%d w=%d h=%d cx=%.2f cx_resized=%.2f cy=%.2f cy_resized=%.2f", - tile_index, - tile.x0, - tile.y0, - tile_w, - tile_h, - k_tile[0, 2].item(), - k_tile_resized[0, 2].item(), - k_tile[1, 2].item(), - k_tile_resized[1, 2].item(), - ) intrinsics_4 = torch.eye(4, device=device, dtype=k_tile_resized.dtype) intrinsics_4[:3, :3] = k_tile_resized extrinsics = torch.eye(4, device=device, dtype=k_tile_resized.dtype) From a272d2198bce42204a074f2505b7c73b68b90b38 Mon Sep 17 00:00:00 2001 From: disk02 <130608498+disk02@users.noreply.github.com> Date: Thu, 22 Jan 2026 23:40:52 -0800 Subject: [PATCH 16/16] Project tiled gaussians with unprojection inverse --- src/sharp/cli/predict.py | 65 ++++++++++++++++------------------------ 1 file changed, 25 insertions(+), 40 deletions(-) diff --git a/src/sharp/cli/predict.py b/src/sharp/cli/predict.py index fe481026..1cd09805 100644 --- a/src/sharp/cli/predict.py +++ b/src/sharp/cli/predict.py @@ -1461,42 +1461,32 @@ def predict_image_tiled( else: raise ValueError("Unsupported gaussians mean_vectors shape for tiling.") - def _project_to_pixels( - mean_vectors: torch.Tensor, intrinsics: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: - if mean_vectors.ndim == 2: - mean_vectors = mean_vectors.unsqueeze(0) - squeeze = True - elif mean_vectors.ndim == 3: - squeeze = False - else: - raise ValueError("Unsupported gaussians mean_vectors shape for projection.") - x = mean_vectors[..., 0] - y = mean_vectors[..., 1] - z = mean_vectors[..., 2].clamp_min(1e-6) - fx = intrinsics[0, 0] - fy = intrinsics[1, 1] - cx = intrinsics[0, 2] - cy = intrinsics[1, 2] - u = fx * (x / z) + cx - v = fy * (y / z) + cy - if squeeze: - return u[0], v[0] - return u, v - - gx, gy = _project_to_pixels(prediction.pred.mean_vectors, k_tile_resized) - if gx.ndim > 1: - gx = gx[0] - gy = gy[0] + intrinsics_4 = torch.eye(4, device=device, dtype=k_tile_resized.dtype) + intrinsics_4[:3, :3] = k_tile_resized + extrinsics = torch.eye(4, device=device, dtype=k_tile_resized.dtype) + u_tile = get_unprojection_matrix(extrinsics, intrinsics_4, (target_w, target_h)) + u_tile_inv = torch.linalg.inv(u_tile) + + if mean_vectors.ndim == 3: + pred_points = mean_vectors[0] + else: + pred_points = mean_vectors + ones = torch.ones( + (pred_points.shape[0], 1), device=pred_points.device, dtype=pred_points.dtype + ) + pred_points_h = torch.cat([pred_points, ones], dim=-1) + px_h = (u_tile_inv @ pred_points_h.T).T + u = px_h[:, 0] / px_h[:, 3] + v = px_h[:, 1] / px_h[:, 3] if LOGGER.isEnabledFor(logging.DEBUG): mean_x_min = mean_xy[:, 0].min().item() mean_x_max = mean_xy[:, 0].max().item() mean_y_min = mean_xy[:, 1].min().item() mean_y_max = mean_xy[:, 1].max().item() - u_min = gx.min().item() - u_max = gx.max().item() - v_min = gy.min().item() - v_max = gy.max().item() + u_min = u.min().item() + u_max = u.max().item() + v_min = v.min().item() + v_max = v.max().item() LOGGER.debug( "tile=%d x0=%d y0=%d x1=%d y1=%d mean_x=[%.3f,%.3f] mean_y=[%.3f,%.3f] " "u=[%.1f,%.1f] v=[%.1f,%.1f] keep=[%d:%d,%d:%d]", @@ -1519,10 +1509,10 @@ def _project_to_pixels( y_keep1, ) keep_mask = ( - (gx >= x_keep0) - & (gx < x_keep1) - & (gy >= y_keep0) - & (gy < y_keep1) + (u >= x_keep0) + & (u < x_keep1) + & (v >= y_keep0) + & (v < y_keep1) ) kept_count = int(keep_mask.sum().item()) total_count = int(keep_mask.numel()) @@ -1551,11 +1541,6 @@ def _index(tensor: torch.Tensor) -> torch.Tensor: unprojection_context=None, ) - intrinsics_4 = torch.eye(4, device=device, dtype=k_tile_resized.dtype) - intrinsics_4[:3, :3] = k_tile_resized - extrinsics = torch.eye(4, device=device, dtype=k_tile_resized.dtype) - u_tile = get_unprojection_matrix(extrinsics, intrinsics_4, (target_w, target_h)) - # Remap tile pred-space gaussians into the global pred-space so fast preview renders once. remap_start = perf_counter() if metrics else None tile_to_global_pred = u_global_inv @ u_tile