diff --git a/src/spatialdata/_core/query/_utils.py b/src/spatialdata/_core/query/_utils.py index 3b63470ed..1f39212c3 100644 --- a/src/spatialdata/_core/query/_utils.py +++ b/src/spatialdata/_core/query/_utils.py @@ -2,13 +2,22 @@ from typing import Any +import numba as nb +import numpy as np from anndata import AnnData +from datatree import DataTree from xarray import DataArray from spatialdata._core._elements import Tables from spatialdata._core.spatialdata import SpatialData from spatialdata._types import ArrayLike from spatialdata._utils import Number, _parse_list_into_array +from spatialdata.transformations._utils import compute_coordinates +from spatialdata.transformations.transformations import ( + BaseTransformation, + Sequence, + Translation, +) def get_bounding_box_corners( @@ -36,37 +45,146 @@ def get_bounding_box_corners( min_coordinate = _parse_list_into_array(min_coordinate) max_coordinate = _parse_list_into_array(max_coordinate) - if len(min_coordinate) not in (2, 3): + if min_coordinate.ndim == 1: + min_coordinate = min_coordinate[np.newaxis, :] + max_coordinate = max_coordinate[np.newaxis, :] + + if min_coordinate.shape[1] not in (2, 3): raise ValueError("bounding box must be 2D or 3D") - if len(min_coordinate) == 2: + num_boxes = min_coordinate.shape[0] + num_dims = min_coordinate.shape[1] + + if num_dims == 2: # 2D bounding box assert len(axes) == 2 - return DataArray( + corners = np.array( [ - [min_coordinate[0], min_coordinate[1]], - [min_coordinate[0], max_coordinate[1]], - [max_coordinate[0], max_coordinate[1]], - [max_coordinate[0], min_coordinate[1]], - ], - coords={"corner": range(4), "axis": list(axes)}, + [min_coordinate[:, 0], min_coordinate[:, 1]], + [min_coordinate[:, 0], max_coordinate[:, 1]], + [max_coordinate[:, 0], max_coordinate[:, 1]], + [max_coordinate[:, 0], min_coordinate[:, 1]], + ] ) - - # 3D bounding cube - assert len(axes) == 3 - return DataArray( - [ - [min_coordinate[0], min_coordinate[1], min_coordinate[2]], - [min_coordinate[0], min_coordinate[1], max_coordinate[2]], - [min_coordinate[0], max_coordinate[1], max_coordinate[2]], - [min_coordinate[0], max_coordinate[1], min_coordinate[2]], - [max_coordinate[0], min_coordinate[1], min_coordinate[2]], - [max_coordinate[0], min_coordinate[1], max_coordinate[2]], - [max_coordinate[0], max_coordinate[1], max_coordinate[2]], - [max_coordinate[0], max_coordinate[1], min_coordinate[2]], - ], - coords={"corner": range(8), "axis": list(axes)}, + corners = np.transpose(corners, (2, 0, 1)) + else: + # 3D bounding cube + assert len(axes) == 3 + corners = np.array( + [ + [min_coordinate[:, 0], min_coordinate[:, 1], min_coordinate[:, 2]], + [min_coordinate[:, 0], min_coordinate[:, 1], max_coordinate[:, 2]], + [min_coordinate[:, 0], max_coordinate[:, 1], max_coordinate[:, 2]], + [min_coordinate[:, 0], max_coordinate[:, 1], min_coordinate[:, 2]], + [max_coordinate[:, 0], min_coordinate[:, 1], min_coordinate[:, 2]], + [max_coordinate[:, 0], min_coordinate[:, 1], max_coordinate[:, 2]], + [max_coordinate[:, 0], max_coordinate[:, 1], max_coordinate[:, 2]], + [max_coordinate[:, 0], max_coordinate[:, 1], min_coordinate[:, 2]], + ] + ) + corners = np.transpose(corners, (2, 0, 1)) + output = DataArray( + corners, + coords={ + "box": range(num_boxes), + "corner": range(corners.shape[1]), + "axis": list(axes), + }, ) + if num_boxes > 1: + return output + return output.squeeze().drop_vars("box") + + +@nb.njit(parallel=False, nopython=True) +def _create_slices_and_translation( + min_values: nb.types.Array, + max_values: nb.types.Array, +) -> tuple[nb.types.Array, nb.types.Array]: + n_boxes, n_dims = min_values.shape + slices = np.empty((n_boxes, n_dims, 2), dtype=np.float64) # (n_boxes, n_dims, [min, max]) + translation_vectors = np.empty((n_boxes, n_dims), dtype=np.float64) # (n_boxes, n_dims) + + for i in range(n_boxes): + for j in range(n_dims): + slices[i, j, 0] = min_values[i, j] + slices[i, j, 1] = max_values[i, j] + translation_vectors[i, j] = np.ceil(max(min_values[i, j], 0)) + + return slices, translation_vectors + + +def _process_data_tree_query_result(query_result: DataTree) -> DataTree | None: + d = {} + for k, data_tree in query_result.items(): + v = data_tree.values() + assert len(v) == 1 + xdata = v.__iter__().__next__() + if 0 in xdata.shape: + if k == "scale0": + return None + else: + d[k] = xdata + + # Remove scales after finding a missing scale + scales_to_keep = [] + for i, scale_name in enumerate(d.keys()): + if scale_name == f"scale{i}": + scales_to_keep.append(scale_name) + else: + break + + # Case in which scale0 is not present but other scales are + if len(scales_to_keep) == 0: + return None + + d = {k: d[k] for k in scales_to_keep} + result = DataTree.from_dict(d) + + # Rechunk the data to avoid irregular chunks + for scale in result: + result[scale]["image"] = result[scale]["image"].chunk("auto") + + return result + + +def _process_query_result( + result: DataArray | DataTree, translation_vector: ArrayLike, axes: tuple[str, ...] +) -> DataArray | DataTree | None: + from spatialdata.transformations import get_transformation, set_transformation + + if isinstance(result, DataArray): + if 0 in result.shape: + return None + # rechunk the data to avoid irregular chunks + result = result.chunk("auto") + elif isinstance(result, DataTree): + result = _process_data_tree_query_result(result) + if result is None: + return None + + result = compute_coordinates(result) + + if not np.allclose(np.array(translation_vector), 0): + translation_transform = Translation(translation=translation_vector, axes=axes) + + transformations = get_transformation(result, get_all=True) + assert isinstance(transformations, dict) + + new_transformations = {} + for coordinate_system, initial_transform in transformations.items(): + new_transformation: BaseTransformation = Sequence( + [translation_transform, initial_transform], + ) + new_transformations[coordinate_system] = new_transformation + set_transformation(result, new_transformations, set_all=True) + + # let's make a copy of the transformations so that we don't modify the original object + t = get_transformation(result, get_all=True) + assert isinstance(t, dict) + set_transformation(result, t.copy(), set_all=True) + + return result def _get_filtered_or_unfiltered_tables( diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index dea2280a5..a8d0b475d 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -22,6 +22,7 @@ get_bounding_box_corners, ) from spatialdata._core.spatialdata import SpatialData +from spatialdata._docs import docstring_parameter from spatialdata._types import ArrayLike from spatialdata._utils import Number, _parse_list_into_array from spatialdata.models import ( @@ -33,17 +34,24 @@ points_geopandas_to_dask_dataframe, ) from spatialdata.models._utils import ValidAxis_t, get_spatial_axes -from spatialdata.transformations._utils import compute_coordinates from spatialdata.transformations.operations import set_transformation from spatialdata.transformations.transformations import ( Affine, BaseTransformation, - Sequence, - Translation, _get_affine_for_element, ) +MIN_COORDINATE_DOCS = """\ + The upper left hand corners of the bounding boxes (i.e., minimum coordinates along all dimensions). + Shape: (n_boxes, n_axes) or (n_axes,) for a single box. +""" +MAX_COORDINATE_DOCS = """\ + The lower right hand corners of the bounding boxes (i.e., the maximum coordinates along all dimensions). + Shape: (n_boxes, n_axes) +""" + +@docstring_parameter(min_coordinate_docs=MIN_COORDINATE_DOCS, max_coordinate_docs=MAX_COORDINATE_DOCS) def _get_bounding_box_corners_in_intrinsic_coordinates( element: SpatialElement, axes: tuple[str, ...], @@ -60,11 +68,9 @@ def _get_bounding_box_corners_in_intrinsic_coordinates( axes The axes that min_coordinate and max_coordinate refer to. min_coordinate - The upper left hand corner of the bounding box (i.e., minimum coordinates - along all dimensions). + {min_coordinate_docs} max_coordinate - The lower right hand corner of the bounding box (i.e., the maximum coordinates - along all dimensions + {max_coordinate_docs} target_coordinate_system The coordinate system the bounding box is defined in. @@ -86,7 +92,7 @@ def _get_bounding_box_corners_in_intrinsic_coordinates( spatial_transform = Affine(m_without_c, input_axes=input_axes_without_c, output_axes=output_axes_without_c) # we identified 5 cases (see the responsible function for details), cases 1 and 5 correspond to invertible - # transformations; we focus on them + # transformations; we focus on them. The following code triggers a validation that ensures we are in case 1 or 5. m_without_c_linear = m_without_c[:-1, :-1] _ = _get_case_of_bounding_box_query(m_without_c_linear, input_axes_without_c, output_axes_without_c) @@ -120,10 +126,18 @@ def _get_bounding_box_corners_in_intrinsic_coordinates( intrinsic_bounding_box_corners = bounding_box_corners.data @ rotation_matrix.T + translation + if bounding_box_corners.ndim > 2: # multiple boxes + coords = { + "box": range(len(bounding_box_corners)), + "corner": range(bounding_box_corners.shape[1]), + "axis": list(inverse.output_axes), + } + else: + coords = {"corner": range(len(bounding_box_corners)), "axis": list(inverse.output_axes)} return ( DataArray( intrinsic_bounding_box_corners, - coords={"corner": range(len(bounding_box_corners)), "axis": list(inverse.output_axes)}, + coords=coords, ), input_axes_without_c, ) @@ -230,6 +244,9 @@ def _adjust_bounding_box_to_real_axes( The bounding box is defined by the user and its axes may not coincide with the axes of the transformation. """ + # the following variable `axis` is the index of the axis in the variable min_coordinates that corresponds to the + # named axes ('x', 'y', ...). We need it to know at which index to remove/add new named axes + axis = min_coordinate.ndim - 1 if set(axes_bb) != set(axes_out_without_c): axes_only_in_bb = set(axes_bb) - set(axes_out_without_c) axes_only_in_output = set(axes_out_without_c) - set(axes_bb) @@ -238,20 +255,20 @@ def _adjust_bounding_box_to_real_axes( # 3D bounding box) indices_to_remove_from_bb = [axes_bb.index(ax) for ax in axes_only_in_bb] axes_bb = tuple(ax for ax in axes_bb if ax not in axes_only_in_bb) - min_coordinate = np.delete(min_coordinate, indices_to_remove_from_bb) - max_coordinate = np.delete(max_coordinate, indices_to_remove_from_bb) + min_coordinate = np.delete(min_coordinate, indices_to_remove_from_bb, axis=axis) + max_coordinate = np.delete(max_coordinate, indices_to_remove_from_bb, axis=axis) # if there are axes in the output axes that are not in the bounding box, we need to add them to the bounding box # with a range that includes everything (e.g. querying 3D points with a 2D bounding box) + M = np.finfo(np.float32).max - 1 for ax in axes_only_in_output: axes_bb = axes_bb + (ax,) - M = np.finfo(np.float32).max - 1 - min_coordinate = np.append(min_coordinate, -M) - max_coordinate = np.append(max_coordinate, M) + min_coordinate = np.insert(min_coordinate, min_coordinate.shape[axis], -M, axis=axis) + max_coordinate = np.insert(max_coordinate, max_coordinate.shape[axis], M, axis=axis) else: indices = [axes_bb.index(ax) for ax in axes_out_without_c] - min_coordinate = min_coordinate[np.array(indices)] - max_coordinate = max_coordinate[np.array(indices)] + min_coordinate = np.take(min_coordinate, indices, axis=axis) + max_coordinate = np.take(max_coordinate, indices, axis=axis) axes_bb = axes_out_without_c return axes_bb, min_coordinate, max_coordinate @@ -323,6 +340,7 @@ def to_dict(self) -> dict[str, Any]: pass +@docstring_parameter(min_coordinate_docs=MIN_COORDINATE_DOCS, max_coordinate_docs=MAX_COORDINATE_DOCS) @dataclass(frozen=True) class BoundingBoxRequest(BaseSpatialRequest): """Query with an axis-aligned bounding box. @@ -332,11 +350,9 @@ class BoundingBoxRequest(BaseSpatialRequest): axes The axes the coordinates are expressed in. min_coordinate - The coordinate of the lower left hand corner (i.e., minimum values) - of the bounding box. + {min_coordinate_docs} max_coordinate - The coordinate of the upper right hand corner (i.e., maximum values) - of the bounding box + {max_coordinate_docs} """ min_coordinate: ArrayLike @@ -351,7 +367,12 @@ def __post_init__(self) -> None: raise ValueError(f"Non-spatial axes specified: {non_spatial_axes}") # validate the axes - if len(self.axes) != len(self.min_coordinate) or len(self.axes) != len(self.max_coordinate): + if self.min_coordinate.shape != self.max_coordinate.shape: + raise ValueError("The `min_coordinate` and `max_coordinate` must have the same shape.") + + n_axes_coordinate = len(self.min_coordinate) if self.min_coordinate.ndim == 1 else self.min_coordinate.shape[1] + + if len(self.axes) != n_axes_coordinate: raise ValueError("The number of axes must match the number of coordinates.") # validate the coordinates @@ -367,13 +388,14 @@ def to_dict(self) -> dict[str, Any]: } +@docstring_parameter(min_coordinate_docs=MIN_COORDINATE_DOCS, max_coordinate_docs=MAX_COORDINATE_DOCS) def _bounding_box_mask_points( points: DaskDataFrame, axes: tuple[str, ...], min_coordinate: list[Number] | ArrayLike, max_coordinate: list[Number] | ArrayLike, ) -> da.Array: - """Compute a mask that is true for the points inside an axis-aligned bounding box. + """Compute a mask that is true for the points inside axis-aligned bounding boxes. Parameters ---------- @@ -382,30 +404,45 @@ def _bounding_box_mask_points( axes The axes that min_coordinate and max_coordinate refer to. min_coordinate - The upper left hand corner of the bounding box (i.e., minimum coordinates along all dimensions). + PLACEHOLDER + The upper left hand corners of the bounding boxes (i.e., minimum coordinates along all dimensions). + Shape: (n_boxes, n_axes) or (n_axes,) for a single box. + {min_coordinate_docs} max_coordinate - The lower right hand corner of the bounding box (i.e., the maximum coordinates along all dimensions). + The lower right hand corners of the bounding boxes (i.e., the maximum coordinates along all dimensions). + Shape: (n_boxes, n_axes) or (n_axes,) for a single box. + {max_coordinate_docs} Returns ------- - The mask for the points inside the bounding box. + The masks for the points inside the bounding boxes. """ element_axes = get_axes_names(points) + min_coordinate = _parse_list_into_array(min_coordinate) max_coordinate = _parse_list_into_array(max_coordinate) + + # Ensure min_coordinate and max_coordinate are 2D arrays + min_coordinate = min_coordinate[np.newaxis, :] if min_coordinate.ndim == 1 else min_coordinate + max_coordinate = max_coordinate[np.newaxis, :] if max_coordinate.ndim == 1 else max_coordinate + + n_boxes = min_coordinate.shape[0] in_bounding_box_masks = [] - for axis_index, axis_name in enumerate(axes): - if axis_name not in element_axes: - continue - min_value = min_coordinate[axis_index] - in_bounding_box_masks.append(points[axis_name].gt(min_value).to_dask_array(lengths=True)) - for axis_index, axis_name in enumerate(axes): - if axis_name not in element_axes: - continue - max_value = max_coordinate[axis_index] - in_bounding_box_masks.append(points[axis_name].lt(max_value).to_dask_array(lengths=True)) - in_bounding_box_masks = da.stack(in_bounding_box_masks, axis=-1) - return da.all(in_bounding_box_masks, axis=1) + + for box in range(n_boxes): + box_masks = [] + for axis_index, axis_name in enumerate(axes): + if axis_name not in element_axes: + continue + min_value = min_coordinate[box, axis_index] + max_value = max_coordinate[box, axis_index] + box_masks.append( + points[axis_name].gt(min_value).to_dask_array(lengths=True) + & points[axis_name].lt(max_value).to_dask_array(lengths=True) + ) + bounding_box_mask = da.stack(box_masks, axis=-1) + in_bounding_box_masks.append(da.all(bounding_box_mask, axis=1)) + return in_bounding_box_masks def _dict_query_dispatcher( @@ -426,6 +463,7 @@ def _dict_query_dispatcher( return queried_elements +@docstring_parameter(min_coordinate_docs=MIN_COORDINATE_DOCS, max_coordinate_docs=MAX_COORDINATE_DOCS) @singledispatch def bounding_box_query( element: SpatialElement | SpatialData, @@ -445,9 +483,9 @@ def bounding_box_query( axes The axes `min_coordinate` and `max_coordinate` refer to. min_coordinate - The minimum coordinates of the bounding box. + {min_coordinate_docs} max_coordinate - The maximum coordinates of the bounding box. + {max_coordinate_docs} target_coordinate_system The coordinate system the bounding box is defined in. filter_table @@ -503,7 +541,7 @@ def _( max_coordinate: list[Number] | ArrayLike, target_coordinate_system: str, return_request_only: bool = False, -) -> DataArray | DataTree | Mapping[str, slice] | None: +) -> DataArray | DataTree | Mapping[str, slice] | list[DataArray] | list[DataTree] | None: """Implement bounding box query for Spatialdata supported DataArray. Notes @@ -511,7 +549,7 @@ def _( See https://github.com/scverse/spatialdata/pull/151 for a detailed overview of the logic of this code, and for the cases the comments refer to. """ - from spatialdata.transformations import get_transformation, set_transformation + from spatialdata._core.query._utils import _create_slices_and_translation, _process_query_result min_coordinate = _parse_list_into_array(min_coordinate) max_coordinate = _parse_list_into_array(max_coordinate) @@ -530,92 +568,48 @@ def _( if TYPE_CHECKING: assert isinstance(intrinsic_bounding_box_corners, DataArray) - # build the request: now that we have the bounding box corners in the intrinsic coordinate system, we can use them - # to build the request to query the raster data using the xarray APIs - selection = {} - translation_vector = [] - for axis_name in axes: - # get the min value along the axis - min_value = intrinsic_bounding_box_corners.sel(axis=axis_name).min().item() + min_values = intrinsic_bounding_box_corners.min(dim="corner") + max_values = intrinsic_bounding_box_corners.max(dim="corner") - # get max value, slices are open half interval - max_value = intrinsic_bounding_box_corners.sel(axis=axis_name).max().item() + min_values_np = min_values.data + max_values_np = max_values.data - # add the - selection[axis_name] = slice(min_value, max_value) + if min_values_np.ndim == 1: + min_values_np = min_values_np[np.newaxis, :] + max_values_np = max_values_np[np.newaxis, :] - if min_value > 0: - translation_vector.append(np.ceil(min_value).item()) - else: - translation_vector.append(0) + slices, translation_vectors = _create_slices_and_translation(min_values_np, max_values_np) + + if min_values.ndim == 2: # Multiple boxes + selection: list[dict[str, Any]] | dict[str, Any] = [ + { + axis: slice(slices[box_idx, axis_idx, 0], slices[box_idx, axis_idx, 1]) + for axis_idx, axis in enumerate(axes) + } + for box_idx in range(len(min_values_np)) + ] + translation_vectors = translation_vectors.tolist() + else: # Single box + selection = {axis: slice(slices[0, axis_idx, 0], slices[0, axis_idx, 1]) for axis_idx, axis in enumerate(axes)} + translation_vectors = translation_vectors[0].tolist() if return_request_only: return selection # query the data - query_result = image.sel(selection) - if isinstance(image, DataArray): - if 0 in query_result.shape: - return None - assert isinstance(query_result, DataArray) - # rechunk the data to avoid irregular chunks - image = image.chunk("auto") + query_result: DataArray | DataTree | list[DataArray] | list[DataTree] | None = ( + image.sel(selection) if isinstance(selection, dict) else [image.sel(sel) for sel in selection] + ) + + if isinstance(query_result, list): + processed_results = [] + for result, translation_vector in zip(query_result, translation_vectors): + processed_result = _process_query_result(result, translation_vector, axes) + if processed_result is not None: + processed_results.append(processed_result) + query_result = processed_results if processed_results else None else: - assert isinstance(image, DataTree) - assert isinstance(query_result, DataTree) - - d = {} - for k, data_tree in query_result.items(): - v = data_tree.values() - assert len(v) == 1 - xdata = v.__iter__().__next__() - if 0 in xdata.shape: - if k == "scale0": - return None - else: - d[k] = xdata - # the list of scales may not be contiguous when the data has small shape (for instance with yx = 22 and - # rotations we may end up having scale0 and scale2 but not scale1. Practically this may occur in torch tiler if - # the tiles are request to be too small). - # Here we remove scales after we found a scale missing - scales_to_keep = [] - for i, scale_name in enumerate(d.keys()): - if scale_name == f"scale{i}": - scales_to_keep.append(scale_name) - else: - break - # case in which scale0 is not present but other scales are - if len(scales_to_keep) == 0: - return None - d = {k: d[k] for k in scales_to_keep} - - query_result = DataTree.from_dict(d) - # rechunk the data to avoid irregular chunks - for scale in query_result: - query_result[scale]["image"] = query_result[scale]["image"].chunk("auto") - query_result = compute_coordinates(query_result) - - # the bounding box, mapped back to the intrinsic coordinate system is a set of points. The bounding box of these - # points is likely starting away from the origin (this is described by translation_vector), so we need to prepend - # this translation to every transformation in the new queries elements (unless the translation_vector is zero, - # in that case the translation is not needed) - if not np.allclose(np.array(translation_vector), 0): - translation_transform = Translation(translation=translation_vector, axes=axes) - - transformations = get_transformation(query_result, get_all=True) - assert isinstance(transformations, dict) - - new_transformations = {} - for coordinate_system, initial_transform in transformations.items(): - new_transformation: BaseTransformation = Sequence( - [translation_transform, initial_transform], - ) - new_transformations[coordinate_system] = new_transformation - set_transformation(query_result, new_transformations, set_all=True) - # let's make a copy of the transformations so that we don't modify the original object - t = get_transformation(query_result, get_all=True) - assert isinstance(t, dict) - set_transformation(query_result, t.copy(), set_all=True) + query_result = _process_query_result(query_result, translation_vectors, axes) return query_result @@ -626,13 +620,17 @@ def _( min_coordinate: list[Number] | ArrayLike, max_coordinate: list[Number] | ArrayLike, target_coordinate_system: str, -) -> DaskDataFrame | None: +) -> DaskDataFrame | list[DaskDataFrame] | None: from spatialdata import transform from spatialdata.transformations import get_transformation min_coordinate = _parse_list_into_array(min_coordinate) max_coordinate = _parse_list_into_array(max_coordinate) + # Ensure min_coordinate and max_coordinate are 2D arrays + min_coordinate = min_coordinate[np.newaxis, :] if min_coordinate.ndim == 1 else min_coordinate + max_coordinate = max_coordinate[np.newaxis, :] if max_coordinate.ndim == 1 else max_coordinate + # for triggering validation _ = BoundingBoxRequest( target_coordinate_system=target_coordinate_system, @@ -649,9 +647,11 @@ def _( max_coordinate=max_coordinate, target_coordinate_system=target_coordinate_system, ) - intrinsic_bounding_box_corners = intrinsic_bounding_box_corners.data - min_coordinate_intrinsic = intrinsic_bounding_box_corners.min(axis=0) - max_coordinate_intrinsic = intrinsic_bounding_box_corners.max(axis=0) + min_coordinate_intrinsic = intrinsic_bounding_box_corners.min(dim="corner") + max_coordinate_intrinsic = intrinsic_bounding_box_corners.max(dim="corner") + + min_coordinate_intrinsic = min_coordinate_intrinsic.data + max_coordinate_intrinsic = max_coordinate_intrinsic.data # get the points in the intrinsic coordinate bounding box in_intrinsic_bounding_box = _bounding_box_mask_points( @@ -660,10 +660,20 @@ def _( min_coordinate=min_coordinate_intrinsic, max_coordinate=max_coordinate_intrinsic, ) - # if there aren't any points, just return - if in_intrinsic_bounding_box.sum() == 0: + + # assert that the number of bounding boxes is correct + assert len(in_intrinsic_bounding_box) == len(min_coordinate) + points_in_intrinsic_bounding_box: list[DaskDataFrame | None] = [] + for mask in in_intrinsic_bounding_box: + if mask.sum() == 0: + points_in_intrinsic_bounding_box.append(None) + else: + points_in_intrinsic_bounding_box.append(points.loc[mask]) + if len(points_in_intrinsic_bounding_box) == 0: return None - points_in_intrinsic_bounding_box = points.loc[in_intrinsic_bounding_box] + + # assert that the number of queried points is correct + assert len(points_in_intrinsic_bounding_box) == len(min_coordinate) # # we have to reset the index since we have subset # # https://stackoverflow.com/questions/61395351/how-to-reset-index-on-concatenated-dataframe-in-dask @@ -677,25 +687,42 @@ def _( # points_in_intrinsic_bounding_box = points_in_intrinsic_bounding_box.drop(columns=["idx"]) # transform the element to the query coordinate system - points_query_coordinate_system = transform( - points_in_intrinsic_bounding_box, to_coordinate_system=target_coordinate_system, maintain_positioning=False - ) # type: ignore[union-attr] + output: list[DaskDataFrame | None] = [] + for p, min_c, max_c in zip(points_in_intrinsic_bounding_box, min_coordinate, max_coordinate): + if p is None: + output.append(None) + else: + points_query_coordinate_system = transform( + p, to_coordinate_system=target_coordinate_system, maintain_positioning=False + ) - # get a mask for the points in the bounding box - bounding_box_mask = _bounding_box_mask_points( - points=points_query_coordinate_system, - axes=axes, - min_coordinate=min_coordinate, - max_coordinate=max_coordinate, - ) - bounding_box_indices = np.where(bounding_box_mask.compute())[0] - if len(bounding_box_indices) == 0: + # get a mask for the points in the bounding box + bounding_box_mask = _bounding_box_mask_points( + points=points_query_coordinate_system, + axes=axes, + min_coordinate=min_c, + max_coordinate=max_c, + ) + if len(bounding_box_mask) == 1: + bounding_box_mask = bounding_box_mask[0] + bounding_box_indices = np.where(bounding_box_mask.compute())[0] + + if len(bounding_box_indices) == 0: + output.append(None) + else: + points_df = p.compute().iloc[bounding_box_indices] + old_transformations = get_transformation(p, get_all=True) + assert isinstance(old_transformations, dict) + output.append( + PointsModel.parse( + dd.from_pandas(points_df, npartitions=1), transformations=old_transformations.copy() + ) + ) + if len(output) == 0: return None - points_df = points_in_intrinsic_bounding_box.compute().iloc[bounding_box_indices] - old_transformations = get_transformation(points, get_all=True) - assert isinstance(old_transformations, dict) - # an alternative approach is to query for each partition in parallel - return PointsModel.parse(dd.from_pandas(points_df, npartitions=1), transformations=old_transformations.copy()) + if len(output) == 1: + return output[0] + return output @bounding_box_query.register(GeoDataFrame) @@ -705,7 +732,7 @@ def _( min_coordinate: list[Number] | ArrayLike, max_coordinate: list[Number] | ArrayLike, target_coordinate_system: str, -) -> GeoDataFrame | None: +) -> GeoDataFrame | list[GeoDataFrame] | None: from spatialdata.transformations import get_transformation min_coordinate = _parse_list_into_array(min_coordinate) @@ -727,16 +754,32 @@ def _( max_coordinate=max_coordinate, target_coordinate_system=target_coordinate_system, ) - intrinsic_bounding_box_corners = intrinsic_bounding_box_corners.data - bounding_box_non_axes_aligned = Polygon(intrinsic_bounding_box_corners) - indices = polygons.geometry.intersects(bounding_box_non_axes_aligned) - queried = polygons[indices] - if len(queried) == 0: - return None + + # Create a list of Polygons for each bounding box old_transformations = get_transformation(polygons, get_all=True) assert isinstance(old_transformations, dict) - del queried.attrs[ShapesModel.TRANSFORM_KEY] - return ShapesModel.parse(queried, transformations=old_transformations.copy()) + + queried_polygons = [] + intrinsic_bounding_box_corners = ( + intrinsic_bounding_box_corners.expand_dims(dim="box") + if "box" not in intrinsic_bounding_box_corners.dims + else intrinsic_bounding_box_corners + ) + for box_corners in intrinsic_bounding_box_corners: + bounding_box_non_axes_aligned = Polygon(box_corners.data) + indices = polygons.geometry.intersects(bounding_box_non_axes_aligned) + queried = polygons[indices] + if len(queried) == 0: + queried_polygon = None + else: + del queried.attrs[ShapesModel.TRANSFORM_KEY] + queried_polygon = ShapesModel.parse(queried, transformations=old_transformations.copy()) + queried_polygons.append(queried_polygon) + if len(queried_polygons) == 0: + return None + if len(queried_polygons) == 1: + return queried_polygons[0] + return queried_polygons # TODO: we can replace the manually triggered deprecation warning heres with the decorator from Wouter @@ -823,7 +866,6 @@ def _( images: bool = True, labels: bool = True, ) -> SpatialData: - _check_deprecated_kwargs({"shapes": shapes, "points": points, "images": images, "labels": labels}) new_elements = {} for element_type in ["points", "images", "labels", "shapes"]: diff --git a/src/spatialdata/_docs.py b/src/spatialdata/_docs.py new file mode 100644 index 000000000..d3e9dae74 --- /dev/null +++ b/src/spatialdata/_docs.py @@ -0,0 +1,13 @@ +# from https://stackoverflow.com/questions/10307696/how-to-put-a-variable-into-python-docstring +from typing import Any, Callable, TypeVar + +T = TypeVar("T") + + +def docstring_parameter(*args: Any, **kwargs: Any) -> Callable[[T], T]: + def dec(obj: T) -> T: + if obj.__doc__: + obj.__doc__ = obj.__doc__.format(*args, **kwargs) + return obj + + return dec diff --git a/tests/core/query/test_spatial_query.py b/tests/core/query/test_spatial_query.py index 9444d8e9e..bc50a5a4e 100644 --- a/tests/core/query/test_spatial_query.py +++ b/tests/core/query/test_spatial_query.py @@ -108,11 +108,12 @@ def test_bounding_box_request_wrong_coordinate_order(): @pytest.mark.parametrize("is_3d", [True, False]) @pytest.mark.parametrize("is_bb_3d", [True, False]) @pytest.mark.parametrize("with_polygon_query", [True, False]) -def test_query_points(is_3d: bool, is_bb_3d: bool, with_polygon_query: bool): +@pytest.mark.parametrize("multiple_boxes", [True, False]) +def test_query_points(is_3d: bool, is_bb_3d: bool, with_polygon_query: bool, multiple_boxes: bool): """test the points bounding box_query""" - data_x = np.array([10, 20, 20, 20]) - data_y = np.array([10, 20, 30, 30]) - data_z = np.array([100, 200, 200, 300]) + data_x = np.array([10, 20, 20, 20, 40]) + data_y = np.array([10, 20, 30, 30, 50]) + data_z = np.array([100, 200, 200, 300, 500]) data = np.stack((data_x, data_y), axis=1) if is_3d: @@ -125,16 +126,24 @@ def test_query_points(is_3d: bool, is_bb_3d: bool, with_polygon_query: bool): original_z = points_element["z"] if is_bb_3d: - _min_coordinate = np.array([18, 25, 250]) - _max_coordinate = np.array([22, 35, 350]) + if multiple_boxes: + _min_coordinate = np.array([[18, 25, 250], [35, 45, 450], [100, 110, 1100]]) + _max_coordinate = np.array([[22, 35, 350], [45, 55, 550], [110, 120, 1200]]) + else: + _min_coordinate = np.array([18, 25, 250]) + _max_coordinate = np.array([22, 35, 350]) _axes = ("x", "y", "z") else: - _min_coordinate = np.array([18, 25]) - _max_coordinate = np.array([22, 35]) + if multiple_boxes: + _min_coordinate = np.array([[18, 25], [35, 45], [100, 110]]) + _max_coordinate = np.array([[22, 35], [45, 55], [110, 120]]) + else: + _min_coordinate = np.array([18, 25]) + _max_coordinate = np.array([22, 35]) _axes = ("x", "y") if with_polygon_query: - if is_bb_3d: + if is_bb_3d or multiple_boxes: return polygon = Polygon([(18, 25), (18, 35), (22, 35), (22, 25)]) points_result = polygon_query(points_element, polygon=polygon, target_coordinate_system="global") @@ -147,22 +156,49 @@ def test_query_points(is_3d: bool, is_bb_3d: bool, with_polygon_query: bool): target_coordinate_system="global", ) - # Check that the correct point was selected + # Check that the correct points were selected if is_3d: if is_bb_3d: - np.testing.assert_allclose(points_result["x"].compute(), [20]) - np.testing.assert_allclose(points_result["y"].compute(), [30]) - np.testing.assert_allclose(points_result["z"].compute(), [300]) + if multiple_boxes: + np.testing.assert_allclose(points_result[0]["x"].compute(), [20]) + np.testing.assert_allclose(points_result[0]["y"].compute(), [30]) + np.testing.assert_allclose(points_result[0]["z"].compute(), [300]) + np.testing.assert_allclose(points_result[1]["x"].compute(), [40]) + np.testing.assert_allclose(points_result[1]["y"].compute(), [50]) + np.testing.assert_allclose(points_result[1]["z"].compute(), [500]) + else: + np.testing.assert_allclose(points_result["x"].compute(), [20]) + np.testing.assert_allclose(points_result["y"].compute(), [30]) + np.testing.assert_allclose(points_result["z"].compute(), [300]) + else: + if multiple_boxes: + np.testing.assert_allclose(points_result[0]["x"].compute(), [20, 20]) + np.testing.assert_allclose(points_result[0]["y"].compute(), [30, 30]) + np.testing.assert_allclose(points_result[0]["z"].compute(), [200, 300]) + np.testing.assert_allclose(points_result[1]["x"].compute(), [40]) + np.testing.assert_allclose(points_result[1]["y"].compute(), [50]) + np.testing.assert_allclose(points_result[1]["z"].compute(), [500]) + else: + np.testing.assert_allclose(points_result["x"].compute(), [20, 20]) + np.testing.assert_allclose(points_result["y"].compute(), [30, 30]) + np.testing.assert_allclose(points_result["z"].compute(), [200, 300]) + else: + if multiple_boxes: + np.testing.assert_allclose(points_result[0]["x"].compute(), [20, 20]) + np.testing.assert_allclose(points_result[0]["y"].compute(), [30, 30]) + np.testing.assert_allclose(points_result[1]["x"].compute(), [40]) + np.testing.assert_allclose(points_result[1]["y"].compute(), [50]) + assert points_result[2] is None else: np.testing.assert_allclose(points_result["x"].compute(), [20, 20]) np.testing.assert_allclose(points_result["y"].compute(), [30, 30]) - np.testing.assert_allclose(points_result["z"].compute(), [200, 300]) - else: - np.testing.assert_allclose(points_result["x"].compute(), [20, 20]) - np.testing.assert_allclose(points_result["y"].compute(), [30, 30]) # result should be valid points element - PointsModel.validate(points_result) + if multiple_boxes: + for result in points_result: + if result is None: + continue + PointsModel.validate(result) # original element should be unchanged np.testing.assert_allclose(points_element["x"].compute(), original_x) @@ -192,8 +228,15 @@ def test_query_points_no_points(): @pytest.mark.parametrize("is_bb_3d", [True, False]) @pytest.mark.parametrize("with_polygon_query", [True, False]) @pytest.mark.parametrize("return_request_only", [True, False]) +@pytest.mark.parametrize("multiple_boxes", [True, False]) def test_query_raster( - n_channels: int, is_labels: bool, is_3d: bool, is_bb_3d: bool, with_polygon_query: bool, return_request_only: bool + n_channels: int, + is_labels: bool, + is_3d: bool, + is_bb_3d: bool, + with_polygon_query: bool, + return_request_only: bool, + multiple_boxes: bool, ): """Apply a bounding box to a raster element.""" if is_labels and n_channels > 1: @@ -232,16 +275,16 @@ def test_query_raster( for image in images: if is_bb_3d: - _min_coordinate = np.array([2, 5, 0]) - _max_coordinate = np.array([7, 10, 5]) + _min_coordinate = np.array([[2, 5, 0], [1, 4, 0]]) if multiple_boxes else np.array([2, 5, 0]) + _max_coordinate = np.array([[7, 10, 5], [6, 9, 4]]) if multiple_boxes else np.array([7, 10, 5]) _axes = ("z", "y", "x") else: - _min_coordinate = np.array([5, 0]) - _max_coordinate = np.array([10, 5]) + _min_coordinate = np.array([[5, 0], [4, 0]]) if multiple_boxes else np.array([5, 0]) + _max_coordinate = np.array([[10, 5], [9, 4]]) if multiple_boxes else np.array([10, 5]) _axes = ("y", "x") if with_polygon_query: - if is_bb_3d: + if is_bb_3d or multiple_boxes: return # make a triangle whose bounding box is the same as the bounding box specified with the query polygon = Polygon([(0, 5), (5, 5), (5, 10)]) @@ -258,36 +301,67 @@ def test_query_raster( return_request_only=return_request_only, ) - slices = {"y": slice(5, 10), "x": slice(0, 5)} - if is_bb_3d and is_3d: - slices["z"] = slice(2, 7) + if multiple_boxes: + slices = [{"y": slice(5, 10), "x": slice(0, 5)}, {"y": slice(4, 9), "x": slice(0, 4)}] + if is_bb_3d and is_3d: + slices[0]["z"] = slice(2, 7) + slices[1]["z"] = slice(1, 6) + else: + slices = {"y": slice(5, 10), "x": slice(0, 5)} + if is_bb_3d and is_3d: + slices["z"] = slice(2, 7) + if return_request_only: - assert isinstance(image_result, dict) - if not (is_bb_3d and is_3d) and ("z" in image_result): - image_result.pop("z") # remove z from slices if `polygon_query` - for k, v in image_result.items(): - assert isinstance(v, slice) - assert image_result[k] == slices[k] + assert isinstance(image_result, (dict, list)) + if multiple_boxes: + for i, result in enumerate(image_result): + if not (is_bb_3d and is_3d) and ("z" in result): + result.pop("z") # remove z from slices if `polygon_query` + for k, v in result.items(): + assert isinstance(v, slice) + assert result[k] == slices[i][k] + else: + if not (is_bb_3d and is_3d) and ("z" in image_result): + image_result.pop("z") # remove z from slices if `polygon_query` + for k, v in image_result.items(): + assert isinstance(v, slice) + assert image_result[k] == slices[k] return - expected_image = ximage.sel(**slices) + if multiple_boxes: + expected_images = [ximage.sel(**s) for s in slices] + else: + expected_image = ximage.sel(**slices) if isinstance(image, DataArray): - assert isinstance(image, DataArray) - np.testing.assert_allclose(image_result, expected_image) + assert isinstance(image_result, (DataArray, list)) + if multiple_boxes: + for result, expected in zip(image_result, expected_images): + np.testing.assert_allclose(result, expected) + else: + np.testing.assert_allclose(image_result, expected_image) elif isinstance(image, DataTree): - assert isinstance(image_result, DataTree) - v = image_result["scale0"].values() - assert len(v) == 1 - xdata = v.__iter__().__next__() - np.testing.assert_allclose(xdata, expected_image) + assert isinstance(image_result, (DataTree, list)) + if multiple_boxes: + for result, expected in zip(image_result, expected_images): + v = result["scale0"].values() + assert len(v) == 1 + xdata = v.__iter__().__next__() + np.testing.assert_allclose(xdata, expected) + else: + v = image_result["scale0"].values() + assert len(v) == 1 + xdata = v.__iter__().__next__() + np.testing.assert_allclose(xdata, expected_image) else: raise ValueError("Unexpected type") @pytest.mark.parametrize("is_bb_3d", [True, False]) @pytest.mark.parametrize("with_polygon_query", [True, False]) -def test_query_polygons(is_bb_3d: bool, with_polygon_query: bool): +@pytest.mark.parametrize("multiple_boxes", [True, False]) +@pytest.mark.parametrize("box_outside_polygon", [True, False]) +def test_query_polygons(is_bb_3d: bool, with_polygon_query: bool, multiple_boxes: bool, box_outside_polygon: bool): centroids = np.array([[10, 10], [10, 80], [80, 20], [70, 60]]) half_widths = [6] * 4 sd_polygons = _make_squares(centroid_coordinates=centroids, half_widths=half_widths) @@ -303,12 +377,20 @@ def test_query_polygons(is_bb_3d: bool, with_polygon_query: bool): ) else: if is_bb_3d: - _min_coordinate = np.array([2, 40, 40]) - _max_coordinate = np.array([7, 100, 100]) + _min_coordinate = np.array([[2, 40, 40], [2, 50, 50]]) if multiple_boxes else np.array([2, 40, 40]) + _max_coordinate = np.array([[7, 100, 100], [7, 110, 110]]) if multiple_boxes else np.array([7, 100, 100]) + if box_outside_polygon: + _min_coordinate = np.array([[2, 100, 100], [2, 50, 50]]) if multiple_boxes else np.array([2, 40, 40]) + _max_coordinate = ( + np.array([[7, 110, 110], [7, 110, 110]]) if multiple_boxes else np.array([7, 100, 100]) + ) _axes = ("z", "y", "x") else: - _min_coordinate = np.array([40, 40]) - _max_coordinate = np.array([100, 100]) + _min_coordinate = np.array([[40, 40], [50, 50]]) if multiple_boxes else np.array([40, 40]) + _max_coordinate = np.array([[100, 100], [110, 110]]) if multiple_boxes else np.array([100, 100]) + if box_outside_polygon: + _min_coordinate = np.array([[100, 100], [50, 50]]) if multiple_boxes else np.array([40, 40]) + _max_coordinate = np.array([[110, 110], [110, 110]]) if multiple_boxes else np.array([100, 100]) _axes = ("y", "x") polygons_result = bounding_box_query( @@ -319,8 +401,19 @@ def test_query_polygons(is_bb_3d: bool, with_polygon_query: bool): max_coordinate=_max_coordinate, ) - assert len(polygons_result) == 1 - assert polygons_result.index[0] == 3 + if multiple_boxes and not with_polygon_query: + assert isinstance(polygons_result, list) + assert len(polygons_result) == 2 + if box_outside_polygon: + + assert polygons_result[0] is None + assert polygons_result[1].index[0] == 3 + else: + assert polygons_result[0].index[0] == 3 + assert len(polygons_result[1]) == 1 + else: + assert len(polygons_result) == 1 + assert polygons_result.index[0] == 3 @pytest.mark.parametrize("is_bb_3d", [True, False]) @@ -721,3 +814,28 @@ def query_polyon_contains_queried_data(extent: dict[str, tuple[float, float]]) - query_polyon_contains_queried_data(extent_circles) query_polyon_contains_queried_data(extent_polygons) + + +def test_query_multiple_boxes_len_one(sdata_blobs): + """ + Tests that querying by a list of bounding boxes with length one is equivalent to querying by a single bounding box. + """ + min_coordinate = np.array([[80, 80]]) + max_coordinate = np.array([[165, 150]]) + axes = ("x", "y") + + queried0 = bounding_box_query( + sdata_blobs, + axes=axes, + min_coordinate=min_coordinate, + max_coordinate=max_coordinate, + target_coordinate_system="global", + ) + queried1 = bounding_box_query( + sdata_blobs, + axes=axes, + min_coordinate=min_coordinate[0], + max_coordinate=max_coordinate[0], + target_coordinate_system="global", + ) + assert_spatial_data_objects_are_identical(queried0, queried1)