diff --git a/requirements.txt b/requirements.txt
index 6f59d745..bfd95bc5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -26,6 +26,7 @@ fsspec==2025.9.0
     # via
     #   huggingface-hub
     #   torch
+geocalib @ git+https://github.com/cvg/GeoCalib.git
 gsplat==1.5.3
     # via sharp
 hf-xet==1.1.10 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
diff --git a/src/sharp/cli/predict.py b/src/sharp/cli/predict.py
index b34326cf..92a9ed6e 100644
--- a/src/sharp/cli/predict.py
+++ b/src/sharp/cli/predict.py
@@ -345,6 +345,18 @@ def _center_crop(img: np.ndarray, target_h: int, target_w: int) -> np.ndarray:
     default=False,
     help="Compute PLY export conversions in fp32 (even if inference uses AMP).",
 )
+@click.option(
+    "--geocalib",
+    is_flag=True,
+    default=False,
+    help="Enable GeoCalib focal length estimation.",
+)
+@click.option(
+    "--geocalib-per-folder",
+    is_flag=True,
+    default=False,
+    help="Run GeoCalib once per folder and reuse for images in that folder.",
+)
 @click.option(
     "--device",
     type=str,
@@ -398,6 +410,8 @@ def predict_cli(
     skip_world_conversion: bool,
     defer_world_conversion_for_export: bool,
     export_fp32: bool,
+    geocalib: bool,
+    geocalib_per_folder: bool,
     device: str,
     amp: bool | None,
     amp_dtype: str,
@@ -425,6 +439,9 @@ def predict_cli(
         return
     if batch_size < 1:
         raise click.ClickException("--batch-size must be >= 1.")
+    if geocalib_per_folder and not geocalib:
+        LOGGER.warning("--geocalib-per-folder ignored because --geocalib was not set.")
+        geocalib_per_folder = False
 
     def _natural_sort_key(path: Path) -> list[tuple[int, object]]:
         relative_path = path.relative_to(input_path).as_posix()
@@ -484,6 +501,30 @@ def _natural_sort_key(path: Path) -> list[tuple[int, object]]:
         LOGGER.warning("Can only run rendering with gsplat on CUDA. Rendering is disabled.")
         with_rendering = False
 
+    geocalib_runner = None
+    folder_fpx_cache: dict[Path, float] = {}
+    if geocalib:
+        from sharp.utils.geocalib import GeoCalibRunner
+
+        geocalib_device = torch.device("cpu" if device == "mps" else device)
+        geocalib_runner = GeoCalibRunner(geocalib_device)
+        if geocalib_per_folder and input_is_dir:
+            folder_map: dict[Path, list[Path]] = {}
+            for path in image_paths:
+                folder_map.setdefault(path.parent, []).append(path)
+            for folder, folder_images in folder_map.items():
+                try:
+                    f_px_folder = geocalib_runner.calibrate_folder(folder_images)
+                except Exception as exc:  # pragma: no cover - best-effort fallback
+                    LOGGER.warning(
+                        "GeoCalib folder calibration failed for %s: %s. Falling back to EXIF/default.",
+                        folder,
+                        exc,
+                    )
+                    continue
+                folder_fpx_cache[folder] = f_px_folder
+                LOGGER.info("GeoCalib folder focal computed: %s -> %.2f", folder, f_px_folder)
+
     # Load or download checkpoint
     if checkpoint_path is None:
         LOGGER.info("No checkpoint provided. Downloading default model from %s", DEFAULT_MODEL_URL)
@@ -738,6 +779,26 @@ def _finalize_prediction(
             LOGGER.info("Skipping .ply save because --no-save-ply was requested.")
         metrics.add_time("per_image_total", perf_counter() - image_start)
 
+    def _resolve_f_px(image_path: Path, f_px_exif: float) -> float:
+        """Choose the focal length for ``image_path``.
+
+        Returns ``f_px_exif`` when GeoCalib is disabled, when per-folder mode has no
+        cached value for the image's folder, or when per-image calibration raises.
+        """
+        if geocalib_runner is None:
+            return f_px_exif
+        if geocalib_per_folder and input_is_dir:
+            return folder_fpx_cache.get(image_path.parent, f_px_exif)
+        try:
+            return geocalib_runner.calibrate_image(image_path)
+        except Exception as exc:  # pragma: no cover - best-effort fallback
+            LOGGER.warning(
+                "GeoCalib failed for %s: %s. Falling back to EXIF/default.",
+                image_path,
+                exc,
+            )
+            return f_px_exif
+
     run_start = perf_counter()
     try:
         if batch_size <= 1 or len(image_paths) == 1:
@@ -749,11 +810,12 @@ def _finalize_prediction(
                 LOGGER.info("Processing %s (%d/%d)", image_path, index, len(image_paths))
                 io_start = perf_counter()
                 try:
-                    image, _, f_px = io.load_rgb(image_path)
+                    image, _, f_px_exif = io.load_rgb(image_path)
                 except (OSError, UnidentifiedImageError, ValueError) as exc:
                     LOGGER.warning("Skipping unreadable image %s: %s", image_path, exc)
                     continue
                 metrics.add_time("io_decode", perf_counter() - io_start)
+                f_px = _resolve_f_px(image_path, f_px_exif)
                 height, width = image.shape[:2]
                 intrinsics = torch.tensor(
                     [
@@ -834,11 +896,12 @@ def _finalize_prediction(
                 LOGGER.info("Processing %s (%d/%d)", image_path, index, total_images)
                 io_start = perf_counter()
                 try:
-                    image, _, f_px = io.load_rgb(image_path)
+                    image, _, f_px_exif = io.load_rgb(image_path)
                 except (OSError, UnidentifiedImageError, ValueError) as exc:
                     LOGGER.warning("Skipping unreadable image %s: %s", image_path, exc)
                     continue
                 metrics.add_time("io_decode", perf_counter() - io_start)
+                f_px = _resolve_f_px(image_path, f_px_exif)
                 height, width = image.shape[:2]
                 intrinsics = torch.tensor(
                     [
diff --git a/src/sharp/utils/geocalib.py b/src/sharp/utils/geocalib.py
new file mode 100644
index 00000000..e6f1b779
--- /dev/null
+++ b/src/sharp/utils/geocalib.py
@@ -0,0 +1,118 @@
+"""GeoCalib adapter for focal length estimation."""
+
+from __future__ import annotations
+
+import logging
+import math
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any
+
+import torch
+
+LOGGER = logging.getLogger(__name__)
+
+
+def _get_attr(obj: object, name: str) -> Any | None:
+    """Return ``obj[name]`` for dicts or ``obj.name`` otherwise; ``None`` if absent."""
+    if isinstance(obj, dict):
+        return obj.get(name)
+    return getattr(obj, name, None)
+
+
+def _to_scalar(value: Any) -> float:
+    """Collapse a focal value to one float, averaging multi-element containers.
+
+    Plain numbers and one-element tensors go through ``float()`` unchanged.
+    Anything ``float()`` rejects (for example a tensor holding both ``fx`` and
+    ``fy``, or one value per batch item) is averaged over all of its elements.
+    """
+    try:
+        return float(value)
+    except (TypeError, ValueError, RuntimeError):
+        values = torch.as_tensor(value).detach().float()
+    if values.numel() == 0:
+        raise RuntimeError("GeoCalib calibration returned an empty focal length value.")
+    return float(values.mean())
+
+
+def _extract_f_px(calibration: object) -> float:
+    """Extract a single focal length from a GeoCalib calibration result.
+
+    Looks for a ``camera`` entry (dict key or attribute) and falls back to the
+    result object itself. ``fx``/``fy`` are averaged when both exist; otherwise
+    ``f`` is used, read from the camera first and then from a top-level dict key.
+
+    Raises:
+        RuntimeError: If no focal value is present, or the value is not a finite
+            positive number.
+    """
+    camera = _get_attr(calibration, "camera")
+    if camera is None:
+        camera = calibration
+    fx = _get_attr(camera, "fx")
+    fy = _get_attr(camera, "fy")
+    if fx is not None and fy is not None:
+        f_px = _to_scalar((fx + fy) / 2.0)
+    else:
+        f = _get_attr(camera, "f")
+        if f is None and isinstance(calibration, dict):
+            f = calibration.get("f")
+        if f is None:
+            raise RuntimeError("GeoCalib calibration did not include focal length values.")
+        f_px = _to_scalar(f)
+    # Reject unusable estimates instead of returning them to the caller.
+    if not math.isfinite(f_px) or f_px <= 0.0:
+        raise RuntimeError(f"GeoCalib returned an invalid focal length: {f_px!r}.")
+    return f_px
+
+
+@dataclass
+class GeoCalibRunner:
+    """Owns a GeoCalib model and turns its calibration results into focal lengths."""
+
+    # Device the model and loaded images are moved to.
+    device: torch.device
+    # GeoCalib model instance; created in ``__post_init__``, not by the caller.
+    _model: Any = field(init=False, repr=False)
+
+    def __post_init__(self) -> None:
+        """Create the GeoCalib model and move it to ``self.device`` when supported."""
+        # Imported here rather than at module level, so importing this module
+        # does not itself require the geocalib package.
+        try:
+            from geocalib import GeoCalib
+        except ImportError as exc:
+            raise ImportError(
+                "GeoCalib is not installed. Install it with "
+                "`pip install git+https://github.com/cvg/GeoCalib.git`."
+            ) from exc
+
+        self._model = GeoCalib()
+        if hasattr(self._model, "to"):
+            self._model = self._model.to(self.device)
+
+    def _load(self, image_path: Path) -> Any:
+        """Load an image through the model and move it to ``self.device`` if possible."""
+        image = self._model.load_image(str(image_path))
+        if hasattr(image, "to"):
+            image = image.to(self.device)
+        return image
+
+    def calibrate_image(self, image_path: Path) -> float:
+        """Calibrate one image and return its estimated focal length."""
+        calibration = self._model.calibrate(self._load(image_path))
+        return _extract_f_px(calibration)
+
+    def calibrate_folder(self, image_paths: list[Path]) -> float:
+        """Calibrate several images jointly and return one shared focal length.
+
+        Raises:
+            ValueError: If ``image_paths`` is empty.
+        """
+        if not image_paths:
+            raise ValueError("calibrate_folder requires at least one image path.")
+        # NOTE(review): all images are loaded up front and passed as a list;
+        # confirm GeoCalib.calibrate accepts a list and that large folders fit in memory.
+        images = [self._load(path) for path in image_paths]
+        calibration = self._model.calibrate(images, shared_intrinsics=True)
+        return _extract_f_px(calibration)