diff --git a/src/sharp/cli/predict.py b/src/sharp/cli/predict.py
index 3cf8707d..1cd09805 100644
--- a/src/sharp/cli/predict.py
+++ b/src/sharp/cli/predict.py
@@ -29,11 +29,13 @@
     RGBGaussianPredictor,
     create_predictor,
 )
+from sharp.utils.tiling import make_tiles, scale_intrinsics_for_resize, shift_intrinsics_for_tile
 from sharp.utils import io
 from sharp.utils import logging as logging_utils
 from sharp.utils.gaussians import (
     Gaussians3D,
     SceneMetaData,
+    apply_transform,
     get_unprojection_matrix,
     save_ply,
     unproject_gaussians,
@@ -275,6 +277,33 @@ def _center_crop(img: np.ndarray, target_h: int, target_w: int) -> np.ndarray:
     default=False,
     help="Render SBS preview using predicted-space gaussians (skips world conversion).",
 )
+@click.option(
+    "--tiling",
+    is_flag=True,
+    default=False,
+    help="Enable tiled inference mode (requires --sbs-image and --fast-preview-render).",
+)
+@click.option(
+    "--tile-size",
+    type=int,
+    default=1536,
+    show_default=True,
+    help="Nominal tile edge length in pixels (tiles are square).",
+)
+@click.option(
+    "--tile-overlap",
+    type=float,
+    default=0.25,
+    show_default=True,
+    help="Fractional overlap between tiles in [0.0, 0.5).",
+)
+@click.option(
+    "--tile-keep",
+    type=float,
+    default=None,
+    show_default=True,
+    help="Optional keep-region fraction per tile (default: derived from --tile-overlap).",
+)
 @click.option(
     "--sbs-min-opacity",
     type=float,
@@ -389,6 +418,10 @@ def predict_cli(
     sbs_image_frame: int,
     stereo_strength: float,
     fast_preview_render: bool,
+    tiling: bool,
+    tile_size: int,
+    tile_overlap: float,
+    tile_keep: float | None,
     sbs_min_opacity: float,
     sbs_min_scale: float,
     sbs_max_splats: int | None,
@@ -426,6 +459,26 @@ def predict_cli(
         return
     if batch_size < 1:
         raise click.ClickException("--batch-size must be >= 1.")
+    if tiling and batch_size > 1:
+        raise click.ClickException(
+            "--tiling is not yet supported with batched inputs; use --batch-size=1."
+        )
+    if tiling and fast_preview_compare:
+        raise click.ClickException(
+            "--tiling does not support --fast-preview-compare; use non-tiled mode."
+        )
+    if tiling and save_ply:
+        raise click.ClickException(
+            "--tiling does not support PLY export; use non-tiled mode."
+        )
+    if tiling and (sbs_image is None or not fast_preview_render):
+        raise click.ClickException(
+            "--tiling is only supported with --sbs-image and --fast-preview-render."
+        )
+    if tile_overlap < 0.0 or tile_overlap >= 0.5:
+        raise click.ClickException("--tile-overlap must be in [0.0, 0.5).")
+    if tile_keep is not None and (tile_keep <= 0.0 or tile_keep > 1.0):
+        raise click.ClickException("--tile-keep must be in (0.0, 1.0].")

     def _natural_sort_key(path: Path) -> list[tuple[int, object]]:
         relative_path = path.relative_to(input_path).as_posix()
@@ -533,7 +586,9 @@ def _natural_sort_key(path: Path) -> list[tuple[int, object]]:
         raise click.ClickException(
             "--fast-preview-render is only supported for --sbs-image (not --render)."
         )
-    if save_ply is None and want_sbs_image:
+    if tiling:
+        effective_save_ply = False
+    elif save_ply is None and want_sbs_image:
         effective_save_ply = False
     elif save_ply is None and not want_sbs_image:
         effective_save_ply = True
@@ -771,6 +826,8 @@ def _finalize_prediction(
     want_world_for_predict = want_world_for_render or (
         effective_save_ply and not defer_export_world
     )
+    if tiling:
+        want_world_for_predict = False
     if skip_world_conversion and (want_world_for_predict or effective_save_ply):
         raise click.ClickException(
             "World-space conversion is required for rendering or PLY export. "
" @@ -787,18 +844,33 @@ def _finalize_prediction( LOGGER.info("Computing world-space gaussians: %s", "yes" if space == "world" else "no") predict_start = perf_counter() - prediction = predict_image( - gaussian_predictor, - image, - f_px, - torch.device(device), - amp_enabled=amp, - amp_dtype=amp_dtype_to_use, - metrics=metrics, - return_world=want_world_for_predict, - return_unprojection=fast_preview_render, - return_unprojection_context=fast_preview_render or effective_save_ply, - ) + if tiling: + prediction = predict_image_tiled( + gaussian_predictor, + image, + f_px, + torch.device(device), + tile_size=tile_size, + tile_overlap=tile_overlap, + tile_keep=tile_keep, + target_size_wh=(1536, 1536), + amp_enabled=amp, + amp_dtype=amp_dtype_to_use, + metrics=metrics, + ) + else: + prediction = predict_image( + gaussian_predictor, + image, + f_px, + torch.device(device), + amp_enabled=amp, + amp_dtype=amp_dtype_to_use, + metrics=metrics, + return_world=want_world_for_predict, + return_unprojection=fast_preview_render, + return_unprojection_context=fast_preview_render or effective_save_ply, + ) metrics.add_time("predict_total", perf_counter() - predict_start) _finalize_prediction( prediction=prediction, @@ -998,6 +1070,7 @@ def preprocess_one( *, target_size_wh: tuple[int, int], dtype: torch.dtype, + reference_width: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor, dict]: target_w, target_h = target_size_wh image_np = np.ascontiguousarray(image_np) @@ -1011,7 +1084,12 @@ def preprocess_one( ) image_pt = image_pt / 255.0 _, height, width = image_pt.shape - disparity_factor_pt = torch.tensor([f_px / width], dtype=dtype, device=device) + width_for_disparity = width + if reference_width is not None: + if not isinstance(reference_width, int) or reference_width <= 0: + raise click.ClickException("reference_width must be a positive integer.") + width_for_disparity = reference_width + disparity_factor_pt = torch.tensor([f_px / width_for_disparity], dtype=dtype, device=device) image_resized_pt = TF.resize( image_pt, [target_h, target_w], @@ -1244,6 +1322,267 @@ def predict_image( ) +def predict_image_tiled( + predictor: RGBGaussianPredictor, + image: np.ndarray, + f_px: float, + device: torch.device, + *, + tile_size: int, + tile_overlap: float, + tile_keep: float | None, + target_size_wh: tuple[int, int], + amp_enabled: bool, + amp_dtype: torch.dtype | None, + metrics: Metrics | None = None, +) -> PredictionResult: + """Predict Gaussians per tile and remap into a shared pred-space.""" + height_full, width_full = image.shape[:2] + fx = f_px + fy = f_px + cx = (width_full - 1) / 2.0 + cy = (height_full - 1) / 2.0 + k_full = torch.tensor( + [ + [fx, 0.0, cx], + [0.0, fy, cy], + [0.0, 0.0, 1.0], + ], + device=device, + dtype=torch.float32, + ) + tiles = make_tiles(width_full, height_full, tile_size, tile_overlap) + LOGGER.info( + "Tiling enabled: tile_size=%d tile_overlap=%.2f tiles=%d", + tile_size, + tile_overlap, + len(tiles), + ) + + amp_dtype_to_use = amp_dtype if amp_enabled else None + target_w, target_h = target_size_wh + + k_full_resized = scale_intrinsics_for_resize( + k_full, + src_wh=(width_full, height_full), + dst_wh=target_size_wh, + ) + intrinsics_full = torch.eye(4, device=device, dtype=k_full_resized.dtype) + intrinsics_full[:3, :3] = k_full_resized + extrinsics_full = torch.eye(4, device=device, dtype=k_full_resized.dtype) + u_global = get_unprojection_matrix( + extrinsics_full, + intrinsics_full, + (target_w, target_h), + ) + u_global_inv = torch.linalg.inv(u_global) + + 
+    tile_preds_global: list[Gaussians3D] = []
+    for tile_index, tile in enumerate(tiles, start=1):
+        tile_img = image[tile.y0 : tile.y1, tile.x0 : tile.x1]
+        preprocess_start = perf_counter() if metrics else None
+        tile_resized_pt, tile_disp_factor, tile_aux = preprocess_one(
+            tile_img,
+            f_px,
+            device,
+            target_size_wh=target_size_wh,
+            dtype=torch.float32,
+            reference_width=width_full,
+        )
+        tile_aux["metrics"] = metrics
+        if metrics and preprocess_start is not None:
+            metrics.add_time("tile_preprocess", perf_counter() - preprocess_start)
+
+        forward_start = perf_counter() if metrics else None
+        gaussians_ndc_batch = model_forward_batch(
+            predictor,
+            tile_resized_pt,
+            tile_disp_factor,
+            amp_dtype=amp_dtype_to_use,
+        )
+        if metrics and forward_start is not None:
+            metrics.add_time("tile_forward", perf_counter() - forward_start)
+
+        gaussians_ndc_one = _slice_gaussians(gaussians_ndc_batch, 0)
+        prediction = postprocess_one(
+            gaussians_ndc_one,
+            tile_aux,
+            return_world=False,
+            return_unprojection=False,
+            device=device,
+        )
+
+        tile_w = tile.x1 - tile.x0
+        tile_h = tile.y1 - tile.y0
+        k_tile = shift_intrinsics_for_tile(k_full, tile.x0, tile.y0)
+        k_tile_resized = scale_intrinsics_for_resize(
+            k_tile,
+            src_wh=(tile_w, tile_h),
+            dst_wh=target_size_wh,
+        )
+        LOGGER.debug(
+            "tile=%d x0=%d y0=%d w=%d h=%d cx=%.2f cx_resized=%.2f cy=%.2f cy_resized=%.2f",
+            tile_index,
+            tile.x0,
+            tile.y0,
+            tile_w,
+            tile_h,
+            k_tile[0, 2].item(),
+            k_tile_resized[0, 2].item(),
+            k_tile[1, 2].item(),
+            k_tile_resized[1, 2].item(),
+        )
+
+        # Gaussians are in pred-space for the inference canvas; keep mask uses target pixels.
+        if tile_keep is not None:
+            base_x_margin = int(round((1.0 - tile_keep) * target_w / 2.0))
+            base_y_margin = int(round((1.0 - tile_keep) * target_h / 2.0))
+        else:
+            base_x_margin = int(round(tile_overlap * target_w / 2.0))
+            base_y_margin = int(round(tile_overlap * target_h / 2.0))
+        base_x_margin = max(0, min(base_x_margin, target_w // 2 - 1))
+        base_y_margin = max(0, min(base_y_margin, target_h // 2 - 1))
+
+        left_margin = base_x_margin if tile.x0 > 0 else 0
+        right_margin = base_x_margin if tile.x1 < width_full else 0
+        top_margin = base_y_margin if tile.y0 > 0 else 0
+        bottom_margin = base_y_margin if tile.y1 < height_full else 0
+
+        x_keep0 = left_margin
+        y_keep0 = top_margin
+        x_keep1 = target_w - right_margin
+        y_keep1 = target_h - bottom_margin
+
+        mean_vectors = prediction.pred.mean_vectors
+        if mean_vectors.ndim == 3:
+            pred_points = mean_vectors[0]
+        elif mean_vectors.ndim == 2:
+            pred_points = mean_vectors
+        else:
+            raise ValueError("Unsupported gaussians mean_vectors shape for tiling.")
+        mean_xy = pred_points[:, :2]
+
+        intrinsics_4 = torch.eye(4, device=device, dtype=k_tile_resized.dtype)
+        intrinsics_4[:3, :3] = k_tile_resized
+        extrinsics = torch.eye(4, device=device, dtype=k_tile_resized.dtype)
+        u_tile = get_unprojection_matrix(extrinsics, intrinsics_4, (target_w, target_h))
+        u_tile_inv = torch.linalg.inv(u_tile)
+
+        ones = torch.ones(
+            (pred_points.shape[0], 1), device=pred_points.device, dtype=pred_points.dtype
+        )
+        pred_points_h = torch.cat([pred_points, ones], dim=-1)
+        px_h = (u_tile_inv @ pred_points_h.T).T
+        u = px_h[:, 0] / px_h[:, 3]
+        v = px_h[:, 1] / px_h[:, 3]
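+        # u, v are pixel coordinates on the tile's inference canvas. Splats are
+        # kept only inside the half-open window [x_keep0, x_keep1) x
+        # [y_keep0, y_keep1); margins collapse to zero on sides that touch the
+        # image border, so suppression never trims the outer frame edge.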
+        if LOGGER.isEnabledFor(logging.DEBUG):
+            mean_x_min = mean_xy[:, 0].min().item()
+            mean_x_max = mean_xy[:, 0].max().item()
+            mean_y_min = mean_xy[:, 1].min().item()
+            mean_y_max = mean_xy[:, 1].max().item()
+            u_min = u.min().item()
+            u_max = u.max().item()
+            v_min = v.min().item()
+            v_max = v.max().item()
+            LOGGER.debug(
+                "tile=%d x0=%d y0=%d x1=%d y1=%d mean_x=[%.3f,%.3f] mean_y=[%.3f,%.3f] "
+                "u=[%.1f,%.1f] v=[%.1f,%.1f] keep=[%d:%d,%d:%d]",
+                tile_index,
+                tile.x0,
+                tile.y0,
+                tile.x1,
+                tile.y1,
+                mean_x_min,
+                mean_x_max,
+                mean_y_min,
+                mean_y_max,
+                u_min,
+                u_max,
+                v_min,
+                v_max,
+                x_keep0,
+                x_keep1,
+                y_keep0,
+                y_keep1,
+            )
+        keep_mask = (
+            (u >= x_keep0)
+            & (u < x_keep1)
+            & (v >= y_keep0)
+            & (v < y_keep1)
+        )
+        kept_count = int(keep_mask.sum().item())
+        total_count = int(keep_mask.numel())
+        LOGGER.debug("Tile %d: kept %d / %d gaussians", tile_index, kept_count, total_count)
+        if kept_count == 0:
+            LOGGER.warning("Tile %d produced no gaussians after edge suppression.", tile_index)
+            continue
+
+        def _index(tensor: torch.Tensor) -> torch.Tensor:
+            if tensor.ndim == 1:
+                return tensor[keep_mask]
+            if tensor.ndim == 2:
+                return tensor[:, keep_mask]
+            return tensor[:, keep_mask, ...]
+
+        prediction = PredictionResult(
+            pred=Gaussians3D(
+                mean_vectors=_index(prediction.pred.mean_vectors),
+                singular_values=_index(prediction.pred.singular_values),
+                quaternions=_index(prediction.pred.quaternions),
+                colors=_index(prediction.pred.colors),
+                opacities=_index(prediction.pred.opacities),
+            ),
+            world=None,
+            unprojection_matrix=None,
+            unprojection_context=None,
+        )
+
+        # Remap tile pred-space gaussians into the global pred-space so the
+        # fast preview renders once.
+        remap_start = perf_counter() if metrics else None
+        tile_to_global_pred = u_global_inv @ u_tile
+        global_pred_tile = apply_transform(
+            prediction.pred,
+            tile_to_global_pred[:3],
+            metrics=metrics,
+        )
+        if metrics and remap_start is not None:
+            metrics.add_time("tile_remap", perf_counter() - remap_start)
+
+        tile_preds_global.append(global_pred_tile)
+
+    def _concat_gaussians(items: list[Gaussians3D]) -> Gaussians3D:
+        if not items:
+            raise ValueError("No tiles produced for tiled prediction.")
+        # Gaussians are batch-first (B, N, ...); keep B=1 and concat along splat dim.
+        if not all(g.mean_vectors.shape[0] == 1 for g in items):
+            raise ValueError("Tiled gaussians must have batch dimension of 1.")
+        return Gaussians3D(
+            mean_vectors=torch.cat([g.mean_vectors for g in items], dim=1),
+            singular_values=torch.cat([g.singular_values for g in items], dim=1),
+            quaternions=torch.cat([g.quaternions for g in items], dim=1),
+            colors=torch.cat([g.colors for g in items], dim=1),
+            opacities=torch.cat([g.opacities for g in items], dim=1),
+        )
+
+    pred_gaussians = _concat_gaussians(tile_preds_global)
+    if pred_gaussians.mean_vectors.numel() == 0:
+        raise ValueError("Tiled prediction produced no Gaussians.")
+    unprojection_matrix = u_global
+    if unprojection_matrix is None:
+        raise ValueError("Missing global unprojection matrix for tiled prediction.")
+
+    return PredictionResult(
+        pred=pred_gaussians,
+        world=None,
+        unprojection_matrix=unprojection_matrix,
+        unprojection_context=None,
+    )
+
+
 def _cast_gaussians(gaussians: Gaussians3D, dtype: torch.dtype) -> Gaussians3D:
     return Gaussians3D(
         mean_vectors=gaussians.mean_vectors.to(dtype=dtype),
diff --git a/src/sharp/utils/tiling.py b/src/sharp/utils/tiling.py
new file mode 100644
index 00000000..940d749d
--- /dev/null
+++ b/src/sharp/utils/tiling.py
@@ -0,0 +1,106 @@
+"""Utilities for tiled inference and camera compensation.
+
+For licensing see accompanying LICENSE file.
+Copyright (C) 2025 Apple Inc. All Rights Reserved.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+import torch
+
+
+@dataclass(frozen=True)
+class Tile:
+    """Axis-aligned integer tile bounds in pixel coordinates.
+
+    Coordinates follow (x0, y0, x1, y1) with half-open intervals:
+    x in [x0, x1), y in [y0, y1).
+    """
+
+    x0: int
+    y0: int
+    x1: int
+    y1: int
+
+
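+# A worked example (illustrative): make_tiles(4032, 3024, 1536, 0.25) uses
+# stride round(1536 * 0.75) = 1152, giving x starts [0, 1152, 2304, 2496] and
+# y starts [0, 1152, 1488]: a 4x3 grid of twelve 1536x1536 tiles whose last
+# row/column are shifted back so every tile stays inside the image.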
+""" + +from __future__ import annotations + +from dataclasses import dataclass + +import torch + + +@dataclass(frozen=True) +class Tile: + """Axis-aligned integer tile bounds in pixel coordinates. + + Coordinates follow (x0, y0, x1, y1) with a half-open interval: + x in [x0, x1), y in [y0, y1). + """ + + x0: int + y0: int + x1: int + y1: int + + +def make_tiles(W: int, H: int, tile_size: int, overlap: float) -> list[Tile]: + """Generate a grid of overlapping tiles that cover an image. + + Args: + W: Full image width in pixels. + H: Full image height in pixels. + tile_size: Nominal tile edge length in pixels (square tiles). + overlap: Fractional overlap in [0.0, 0.5). + """ + if W <= 0 or H <= 0: + raise ValueError("W and H must be positive.") + if tile_size <= 0: + raise ValueError("tile_size must be positive.") + if overlap < 0.0 or overlap >= 0.5: + raise ValueError("overlap must be in [0.0, 0.5).") + + if tile_size >= W and tile_size >= H: + return [Tile(0, 0, W, H)] + + stride = max(1, int(round(tile_size * (1.0 - overlap)))) + + def _starts(full: int) -> list[int]: + starts: list[int] = [] + pos = 0 + while pos + tile_size < full: + starts.append(pos) + pos += stride + final_start = max(0, full - tile_size) + if not starts or starts[-1] != final_start: + starts.append(final_start) + return starts + + x_starts = _starts(W) + y_starts = _starts(H) + + tiles: list[Tile] = [] + for y0 in y_starts: + for x0 in x_starts: + x1 = min(W, x0 + tile_size) + y1 = min(H, y0 + tile_size) + x0_tile = x0 + y0_tile = y0 + if x1 - x0_tile < tile_size: + x0_tile = max(0, x1 - tile_size) + if y1 - y0_tile < tile_size: + y0_tile = max(0, y1 - tile_size) + tiles.append(Tile(x0_tile, y0_tile, x1, y1)) + + return tiles + + +def shift_intrinsics_for_tile( + K_full: torch.Tensor, x0: float, y0: float +) -> torch.Tensor: + """Shift intrinsics for a tile offset in full-image coordinates.""" + if K_full.shape != (3, 3): + raise ValueError("K_full must have shape (3, 3).") + K_tile = K_full.clone() + K_tile[0, 2] = K_tile[0, 2] - x0 + K_tile[1, 2] = K_tile[1, 2] - y0 + return K_tile + + +def scale_intrinsics_for_resize( + K: torch.Tensor, src_wh: tuple[int, int], dst_wh: tuple[int, int] +) -> torch.Tensor: + """Scale intrinsics to match a resize from src_wh to dst_wh (width, height).""" + if K.shape != (3, 3): + raise ValueError("K must have shape (3, 3).") + src_w, src_h = src_wh + dst_w, dst_h = dst_wh + if src_w <= 0 or src_h <= 0 or dst_w <= 0 or dst_h <= 0: + raise ValueError("src_wh and dst_wh must be positive.") + scale_x = dst_w / src_w + scale_y = dst_h / src_h + K_scaled = K.clone() + K_scaled[0, :] = K_scaled[0, :] * scale_x + K_scaled[1, :] = K_scaled[1, :] * scale_y + return K_scaled