Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 53 additions & 5 deletions rslearn/config/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,16 @@ class BandSetConfig(BaseModel):
default=None, description="Optional nodata value for each band."
)

# Optional explicit spatial dimensions for the materialized output. When set,
# the projection resolution is adjusted so that the window's geographic extent
# maps to exactly spatial_size pixels (height, width).
# This is useful for coarse-resolution layers (e.g. ERA5 at 0.1 deg) where
# only 1 pixel covers a typical window.
spatial_size: tuple[int, int] | None = Field(
default=None,
description="Optional (height, width) output size. Mutually exclusive with non-zero zoom_offset.",
)

@model_validator(mode="after")
def after_validator(self) -> "BandSetConfig":
"""Ensure the BandSetConfig is valid, and handle the num_bands field."""
Expand All @@ -197,24 +207,62 @@ def after_validator(self) -> "BandSetConfig":
self.bands = [f"B{band_idx}" for band_idx in range(self.num_bands)]
self.num_bands = None

if self.spatial_size is not None and self.zoom_offset != 0:
raise ValueError(
"spatial_size and non-zero zoom_offset are mutually exclusive"
)

if self.spatial_size is not None and any(v <= 0 for v in self.spatial_size):
raise ValueError("spatial_size values must be positive integers")

return self

def get_final_projection_and_bounds(
self, projection: Projection, bounds: PixelBounds
) -> tuple[Projection, PixelBounds]:
"""Gets the final projection/bounds based on band set config.

The band set config may apply a non-zero zoom offset that modifies the window's
projection.
The band set config may apply a non-zero zoom offset or a fixed spatial_size
that modifies the window's projection and bounds.

When ``spatial_size`` is set, the projection resolution is scaled so that the
window's geographic extent maps to exactly ``spatial_size`` pixels. The returned
bounds will have width = spatial_size[1] and height = spatial_size[0].

Note: the ``spatial_size`` path uses ``round()`` to compute the new pixel-space
origin, which can shift the geographic origin by up to half of the new
(coarser) pixel size. This is the same rounding behaviour used by
``ResolutionFactor.multiply_bounds`` and is negligible for coarse-resolution
layers (e.g. ERA5 at 0.1 deg) where sub-pixel shifts are irrelevant.

Args:
projection: the window's projection
bounds: the window's bounds (optional)
band_set: band set configuration object
bounds: the window's bounds

Returns:
tuple of updated projection and bounds with zoom offset applied
tuple of updated projection and bounds
"""
if self.spatial_size is not None:
target_h, target_w = self.spatial_size
cur_w = bounds[2] - bounds[0]
cur_h = bounds[3] - bounds[1]

x_factor = target_w / cur_w
y_factor = target_h / cur_h

new_projection = Projection(
projection.crs,
projection.x_resolution / x_factor,
projection.y_resolution / y_factor,
)
new_bounds = (
round(bounds[0] * x_factor),
round(bounds[1] * y_factor),
round(bounds[0] * x_factor) + target_w,
round(bounds[1] * y_factor) + target_h,
)
return (new_projection, new_bounds)

if self.zoom_offset >= 0:
factor = ResolutionFactor(numerator=2**self.zoom_offset)
else:
Expand Down
14 changes: 11 additions & 3 deletions rslearn/train/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from rslearn.utils.feature import Feature
from rslearn.utils.geometry import PixelBounds, ResolutionFactor
from rslearn.utils.mp import make_pool_and_star_imap_unordered
from rslearn.utils.raster_format import NumpyRasterFormat

from .model_context import SampleMetadata
from .tasks import Task
Expand Down Expand Up @@ -448,8 +449,15 @@ def read_raster_layer_for_data_input(
# resampling. If it really is much faster to handle it via torch, then it may
# make sense to bring back that functionality.

decode_kwargs: dict[str, Any] = {}
if isinstance(raster_format, NumpyRasterFormat):
decode_kwargs["expect_bounds_mismatch"] = band_set.spatial_size is not None
raster_array = raster_format.decode_raster(
raster_dir, final_projection, final_bounds, resampling=Resampling.nearest
raster_dir,
final_projection,
final_bounds,
resampling=Resampling.nearest,
**decode_kwargs,
)
src = raster_array.array # (C, T, H, W)

Expand All @@ -459,8 +467,8 @@ def read_raster_layer_for_data_input(
(
len(needed_bands),
t,
final_bounds[3] - final_bounds[1],
final_bounds[2] - final_bounds[0],
src.shape[2],
src.shape[3],
),
dtype=get_torch_dtype(data_input.dtype),
)
Expand Down
111 changes: 111 additions & 0 deletions rslearn/utils/raster_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,3 +773,114 @@ def decode_raster(
array=array[:, np.newaxis, :, :],
timestamps=image_metadata.timestamps,
)


class NumpyRasterMetadata(pydantic.BaseModel):
"""Metadata sidecar for NumpyRasterFormat."""

projection: dict[str, Any]
bounds: PixelBounds
timestamps: list[tuple[datetime, datetime]] | None = None


class NumpyRasterFormat(RasterFormat):
"""A raster format that stores data as a NumPy ``.npy`` file.

This avoids GeoTIFF overhead for small spatial arrays (e.g. 1x1 pixels)
and/or arrays with many bands (e.g. C*T > 1000).

The directory contains two files:
- ``data.npy``: the raw (C, T, H, W) array.
- ``metadata.json``: projection, bounds, dtype, channel/timestep counts,
and optional timestamps.

``decode_raster`` returns the stored array as-is without any reprojection
or resampling -- data is assumed to have been materialized at the target
resolution already.
"""

data_fname = "data.npy"

def encode_raster(
self,
path: UPath,
projection: Projection,
bounds: PixelBounds,
raster: RasterArray,
) -> None:
"""Encode a RasterArray to ``data.npy`` + ``metadata.json``.

Args:
path: directory to write into.
projection: the projection of the raster data.
bounds: the bounds of the raster data in the projection.
raster: the (C, T, H, W) RasterArray to store.
"""
path.mkdir(parents=True, exist_ok=True)

# Write the raw array.
with (path / self.data_fname).open("wb") as f:
np.save(f, raster.array)

# Write the metadata sidecar.
metadata = NumpyRasterMetadata(
projection=projection.serialize(),
bounds=bounds,
timestamps=raster.timestamps,
)
with (path / METADATA_FNAME).open("w") as f:
f.write(metadata.model_dump_json())

def decode_raster(
self,
path: UPath,
projection: Projection,
bounds: PixelBounds,
resampling: Resampling = Resampling.bilinear,
expect_bounds_mismatch: bool = False,
) -> RasterArray:
"""Decode a previously stored ``data.npy`` + ``metadata.json``.

The returned array is the stored array *as-is* -- no reprojection or
resampling is performed. The ``projection``, ``bounds``, and
``resampling`` parameters are accepted for interface conformance but
are not used for spatial transformation.

Args:
path: directory to read from.
projection: used to verify consistency with stored projection.
bounds: used to verify consistency with stored bounds.
resampling: ignored (kept for interface conformance).
expect_bounds_mismatch: if True, a bounds mismatch is expected
(e.g. because spatial_size was used at materialization time,
which stores data in a different pixel coordinate system) and
only triggers a debug log. If False, a mismatch raises
ValueError.

Returns:
the (C, T, H, W) RasterArray.
"""
with (path / METADATA_FNAME).open() as f:
metadata = NumpyRasterMetadata.model_validate_json(f.read())

if metadata.bounds != tuple(bounds):
if expect_bounds_mismatch:
logger.debug(
"NumpyRasterFormat: requested bounds %s differ from stored "
"bounds %s (expected due to spatial_size) "
"— returning stored data as-is",
bounds,
metadata.bounds,
)
else:
raise ValueError(
f"NumpyRasterFormat: requested bounds {bounds} differ from "
f"stored bounds {metadata.bounds}. Unlike GeotiffRasterFormat, "
f"NumpyRasterFormat cannot reproject or crop to different "
f"bounds."
)

with (path / self.data_fname).open("rb") as f:
array = np.load(f)

return RasterArray(array=array, timestamps=metadata.timestamps)
80 changes: 77 additions & 3 deletions tests/unit/config/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,18 @@

import pytest
from pydantic import ValidationError

from rslearn.config.dataset import DType, LayerConfig, QueryConfig, TimeMode
from rasterio.crs import CRS

from rslearn.config.dataset import (
BandSetConfig,
DType,
LayerConfig,
QueryConfig,
TimeMode,
)
from rslearn.data_sources.planetary_computer import Sentinel1, Sentinel2
from rslearn.utils.raster_format import SingleImageRasterFormat
from rslearn.utils.geometry import Projection
from rslearn.utils.raster_format import NumpyRasterFormat, SingleImageRasterFormat
from rslearn.utils.vector_format import TileVectorFormat


Expand Down Expand Up @@ -287,3 +295,69 @@ def test_warning_when_time_mode_set(self) -> None:
dumped = query_config.model_dump()
assert "space_mode" in dumped
assert "time_mode" not in dumped


class TestBandSetConfigSpatialSize:
"""Tests for the spatial_size option on BandSetConfig."""

def test_spatial_size_default_none(self) -> None:
"""Default spatial_size should be None."""
bs = BandSetConfig(dtype=DType.FLOAT32, bands=["a"])
assert bs.spatial_size is None

def test_spatial_size_projection_and_bounds(self) -> None:
"""spatial_size should adjust projection and bounds to target dimensions."""
bs = BandSetConfig(dtype=DType.FLOAT32, bands=["a"], spatial_size=(1, 1))
projection = Projection(CRS.from_epsg(3857), 10.0, -10.0)
bounds = (100, 200, 228, 328) # 128 x 128 pixels

new_proj, new_bounds = bs.get_final_projection_and_bounds(projection, bounds)

# Output should be 1x1 pixels.
assert new_bounds[2] - new_bounds[0] == 1
assert new_bounds[3] - new_bounds[1] == 1

# The resolution should scale up by the original pixel count.
assert new_proj.x_resolution == pytest.approx(10.0 / (1 / 128))
# x_resolution: 10 / (1/128) = 1280 -- each output pixel covers 128 original pixels
assert new_proj.x_resolution == pytest.approx(10.0 * 128)

def test_spatial_size_non_square(self) -> None:
"""spatial_size with non-square dimensions should work correctly."""
bs = BandSetConfig(dtype=DType.FLOAT32, bands=["a"], spatial_size=(2, 4))
projection = Projection(CRS.from_epsg(3857), 1.0, -1.0)
bounds = (0, 0, 100, 200) # 100 x 200 pixels

new_proj, new_bounds = bs.get_final_projection_and_bounds(projection, bounds)

assert new_bounds[2] - new_bounds[0] == 4 # width
assert new_bounds[3] - new_bounds[1] == 2 # height

def test_spatial_size_mutually_exclusive_with_zoom_offset(self) -> None:
"""spatial_size and non-zero zoom_offset should raise an error."""
with pytest.raises(ValidationError, match="mutually exclusive"):
BandSetConfig(
dtype=DType.FLOAT32, bands=["a"], spatial_size=(1, 1), zoom_offset=1
)

def test_spatial_size_zero_rejected(self) -> None:
"""spatial_size with zero value should raise an error."""
with pytest.raises(ValidationError, match="positive integers"):
BandSetConfig(dtype=DType.FLOAT32, bands=["a"], spatial_size=(0, 1))

def test_spatial_size_negative_rejected(self) -> None:
"""spatial_size with negative value should raise an error."""
with pytest.raises(ValidationError, match="positive integers"):
BandSetConfig(dtype=DType.FLOAT32, bands=["a"], spatial_size=(-1, 1))

def test_numpy_raster_format_from_config(self) -> None:
"""NumpyRasterFormat should be instantiable from config via jsonargparse."""
bs = BandSetConfig(
dtype=DType.FLOAT32,
bands=["a"],
format={
"class_path": "rslearn.utils.raster_format.NumpyRasterFormat",
},
)
fmt = bs.instantiate_raster_format()
assert isinstance(fmt, NumpyRasterFormat)
Loading
Loading