From b6ea38c18e78da30976a99d925effbad67b32562 Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Tue, 4 Mar 2025 19:23:38 +0100 Subject: [PATCH 1/4] mvp --- src/squidpy/__init__.py | 2 +- src/squidpy/pp/__init__.py | 5 ++ src/squidpy/pp/_simple.py | 137 +++++++++++++++++++++++++++++++++++++ 3 files changed, 143 insertions(+), 1 deletion(-) create mode 100644 src/squidpy/pp/__init__.py create mode 100644 src/squidpy/pp/_simple.py diff --git a/src/squidpy/__init__.py b/src/squidpy/__init__.py index 5fb2b848a..d52a64cbd 100644 --- a/src/squidpy/__init__.py +++ b/src/squidpy/__init__.py @@ -3,7 +3,7 @@ from importlib import metadata from importlib.metadata import PackageMetadata -from squidpy import datasets, gr, im, pl, read, tl +from squidpy import datasets, gr, im, pl, pp, read, tl try: md: PackageMetadata = metadata.metadata(__name__) diff --git a/src/squidpy/pp/__init__.py b/src/squidpy/pp/__init__.py new file mode 100644 index 000000000..26edb04a8 --- /dev/null +++ b/src/squidpy/pp/__init__.py @@ -0,0 +1,5 @@ +"""Basic pre-processing functions adapted from scanpy.""" + +from __future__ import annotations + +from squidpy.pp._simple import filter_cells diff --git a/src/squidpy/pp/_simple.py b/src/squidpy/pp/_simple.py new file mode 100644 index 000000000..35a588128 --- /dev/null +++ b/src/squidpy/pp/_simple.py @@ -0,0 +1,137 @@ +from __future__ import annotations + +import anndata as ad +import numpy as np +import pandas as pd +import scanpy as sc +import spatialdata as sd + + +def filter_cells( + data: ad.AnnData | sd.SpatialData, + table: str | None = None, + min_counts: int | None = None, + min_genes: int | None = None, + max_counts: int | None = None, + max_genes: int | None = None, + inplace: bool = True, +) -> ad.AnnData | sd.SpatialData | None: + if not isinstance(data, ad.AnnData | sd.SpatialData): + raise ValueError(f"Expected `AnnData` or `SpatialData`, found `{type(data)}`") + + if isinstance(data, ad.AnnData) and table is not None: + raise ValueError("When filtering `AnnData`, `table` is not used.") + + tables_to_use: list[str] = [] + + if isinstance(data, sd.SpatialData) and table is not None: + if isinstance(table, str): + tables_to_use = [table] + + if isinstance(data, sd.SpatialData) and table is None: + if isinstance(table, str): + tables_to_use = list(data.tables.keys()) + + if tables_to_use is not None and len(tables_to_use) == 0: + raise ValueError("Expected at least one table to be filtered, found `0`") + + if any(t not in data.tables for t in tables_to_use): + raise ValueError(f"Expected all tables to be in `{data.tables.keys()}`.`") + + # mimic scanpy's behavior in only allowing one filtering parameter per call + n_given_options = sum(option is not None for option in [min_genes, min_counts, max_genes, max_counts]) + if n_given_options > 1: + raise ValueError("Only one filtering parameter can be provided per call (scanpy behavior).") + + for param_name, param_value in [ + ("min_counts", min_counts), + ("min_genes", min_genes), + ("max_counts", max_counts), + ("max_genes", max_genes), + ]: + if param_value is not None and not isinstance(param_value, int): + raise ValueError(f"Expected `{param_name}` to be an integer, found `{type(param_value)}`") + + if not isinstance(inplace, bool): + raise ValueError(f"Expected `inplace` to be a boolean, found `{type(inplace)}`") + + def _apply_anndata_filters( + data: ad.AnnData, + min_counts: int | None, + min_genes: int | None, + max_counts: int | None, + max_genes: int | None, + inplace: bool = True, + ) -> ad.AnnData | None: + result = data if inplace else data.copy() + + # robust way to feed in whichever filtering parameters is not None + filter_params = { + "min_counts": min_counts, + "min_genes": min_genes, + "max_counts": max_counts, + "max_genes": max_genes, + } + + for param_name, param_value in filter_params.items(): + if param_value is not None: + # Always modify result in place since we're using our own copy + sc.pp.filter_cells(result, **{param_name: param_value}, inplace=inplace) + + # Return the filtered data if not in place + return None if inplace else result + + if isinstance(data, ad.AnnData): + data_out = data if inplace else data.copy() + + _apply_anndata_filters(data_out, min_counts, min_genes, max_counts, max_genes, inplace=inplace) + + return None if inplace else data_out + + # if it's SpatialData, we need to filter other elements in the object + elif isinstance(data, sd.SpatialData): + if not inplace: + data_out = sd.SpatialData( + images=data.images if data.images is not None else None, + labels=data.labels if data.labels is not None else None, + points=data.points if data.points is not None else None, + shapes=data.shapes if data.shapes is not None else None, + tables=data.tables if data.tables is not None else None, + ) + else: + data_out = data + + for t in tables_to_use: + if "spatialdata_attrs" in data.tables[t].uns: + instance_key = data.tables[t].uns["spatialdata_attrs"]["instance_key"] + region_key = data.tables[t].uns["spatialdata_attrs"]["region_key"] + region = data.tables[t].uns["spatialdata_attrs"]["region"] + + filter_params = { + "min_counts": min_counts, + "min_genes": min_genes, + "max_counts": max_counts, + "max_genes": max_genes, + } + + # remove the rows from the table + table_old = data.tables[t].copy() + for param_name, param_value in filter_params.items(): + if param_value is not None: + # Always modify result in place since we're using our own copy + mask, _ = sc.pp.filter_cells(table_old, **{param_name: param_value}, inplace=inplace) + + data_out.tables[t] = table_old[~mask] + + # remove the rows from the shapes + removed_obs = table_old.obs[mask][[instance_key, region_key]] + + assert removed_obs[region_key].unique() == region + + idx_to_remove = removed_obs[instance_key].values.tolist() + ele_to_modify = data.shapes[region].copy() + filtered_gdf = ele_to_modify[~ele_to_modify.index.isin(idx_to_remove)] + + data_out.shapes[region] = filtered_gdf + + return data_out From 1a6b5f8a8b8d27bf3feedfd5d8d07e21e0e41f92 Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Tue, 11 Mar 2025 00:22:35 +0100 Subject: [PATCH 2/4] got labels to work --- src/squidpy/pp/_simple.py | 243 ++++++++++++++++++++++++++++---------- 1 file changed, 180 insertions(+), 63 deletions(-) diff --git a/src/squidpy/pp/_simple.py b/src/squidpy/pp/_simple.py index 35a588128..0b1d1753e 100644 --- a/src/squidpy/pp/_simple.py +++ b/src/squidpy/pp/_simple.py @@ -5,6 +5,11 @@ import pandas as pd import scanpy as sc import spatialdata as sd +import xarray as xr +from xarray import DataTree +from spatialdata._logging import logger as logg +from spatialdata.models import Labels2DModel, PointsModel, ShapesModel, get_model +from spatialdata.transformations import get_transformation def filter_cells( @@ -15,6 +20,7 @@ def filter_cells( max_counts: int | None = None, max_genes: int | None = None, inplace: bool = True, + filter_labels: bool = True, ) -> ad.AnnData | sd.SpatialData | None: if not isinstance(data, ad.AnnData | sd.SpatialData): raise ValueError(f"Expected `AnnData` or `SpatialData`, found `{type(data)}`") @@ -29,19 +35,26 @@ def filter_cells( tables_to_use = [table] if isinstance(data, sd.SpatialData) and table is None: - if isinstance(table, str): - tables_to_use = list(data.tables.keys()) + tables_to_use = list(data.tables.keys()) - if tables_to_use is not None and len(tables_to_use) == 0: + if ( + not isinstance(data, ad.AnnData) + and tables_to_use is not None + and not tables_to_use + ): raise ValueError("Expected at least one table to be filtered, found `0`") if any(t not in data.tables for t in tables_to_use): raise ValueError(f"Expected all tables to be in `{data.tables.keys()}`.`") # mimic scanpy's behavior in only allowing one filtering parameter per call - n_given_options = sum(option is not None for option in [min_genes, min_counts, max_genes, max_counts]) + n_given_options = sum( + option is not None for option in [min_genes, min_counts, max_genes, max_counts] + ) if n_given_options > 1: - raise ValueError("Only one filtering parameter can be provided per call (scanpy behavior).") + raise ValueError( + "Only one filtering parameter can be provided per call (scanpy behavior)." + ) for param_name, param_value in [ ("min_counts", min_counts), @@ -50,22 +63,20 @@ def filter_cells( ("max_genes", max_genes), ]: if param_value is not None and not isinstance(param_value, int): - raise ValueError(f"Expected `{param_name}` to be an integer, found `{type(param_value)}`") + raise ValueError( + f"Expected `{param_name}` to be an integer, found `{type(param_value)}`" + ) if not isinstance(inplace, bool): raise ValueError(f"Expected `inplace` to be a boolean, found `{type(inplace)}`") - def _apply_anndata_filters( - data: ad.AnnData, - min_counts: int | None, - min_genes: int | None, - max_counts: int | None, - max_genes: int | None, - inplace: bool = True, - ) -> ad.AnnData | None: - result = data if inplace else data.copy() - - # robust way to feed in whichever filtering parameters is not None + # if it's an AnnData object, we add a pseudo tablename + if isinstance(data, ad.AnnData): + tables_to_use = ["adata"] + + # we need to filter the adata object either way + for t in tables_to_use: + filter_params = { "min_counts": min_counts, "min_genes": min_genes, @@ -73,65 +84,171 @@ def _apply_anndata_filters( "max_genes": max_genes, } + if isinstance(data, ad.AnnData): + table_old = data + else: + table_old = data.tables[t] if inplace else data.tables[t].copy() + for param_name, param_value in filter_params.items(): if param_value is not None: - # Always modify result in place since we're using our own copy - sc.pp.filter_cells(result, **{param_name: param_value}, inplace=inplace) + if inplace and isinstance(data, ad.AnnData): + sc.pp.filter_cells( + table_old, **{param_name: param_value}, inplace=True + ) + elif not inplace: + # inplace=False gives us boolean vector of which rows to remove + mask_to_remove, _ = sc.pp.filter_cells( + table_old, **{param_name: param_value}, inplace=False + ) + + if isinstance(data, ad.AnnData): + + return table_old[~mask_to_remove] + + # we're SpatialData now + assert isinstance(data, sd.SpatialData) + + if not inplace: + logg.warning("Creating a deepcopy of the SpatialData object, depending on the size of the object this can take a while.") + data_out = sd.deepcopy(data) + + # elements_dict = {} + # for _, element_name, element in data.gen_elements(): + # elements_dict[element_name] = sd.deepcopy(element) + # deepcopied_attrs = data.attrs + # data_out = sd.SpatialData.from_elements_dict(elements_dict, attrs=deepcopied_attrs) + + else: + data_out = data + + table_filtered = table_old[~mask_to_remove] + if table_filtered.n_obs == 0 or table_filtered.n_vars == 0: + raise ValueError( + f"Filter results in empty table when filtering table `{t}`." + ) + data_out.tables[t] = table_filtered + + # if this doesn't exist, the table doesn't annotate anything + if "spatialdata_attrs" not in data.tables[t].uns: + raise ValueError( + f"Table `{t}` does not have 'spatialdata_attrs' to indicate what it annotates." + ) + + instance_key = data.tables[t].uns["spatialdata_attrs"][ + "instance_key" + ] + region_key = data.tables[t].uns["spatialdata_attrs"]["region_key"] + + # region can annotate one (dtype str) or multiple (dtype list[str]) + region = data.tables[t].uns["spatialdata_attrs"]["region"] + if isinstance(region, str): + region = [region] + + removed_obs = table_old.obs[mask_to_remove][ + [instance_key, region_key] + ] + + # iterate over all elements that the table annotates (region var) + for r in region: + element_model = get_model(data_out[r]) + + ids_to_remove = removed_obs.query(f"{region_key} == '{r}'")[ + instance_key + ].tolist() + if element_model == ShapesModel: + data_out.shapes[r] = _filter_ShapesModel_by_instance_ids( + element=data_out.shapes[r], ids_to_remove=ids_to_remove + ) + + if filter_labels: + logg.warning("Filtering labels, this can be slow depending on the resolution.") + if element_model == Labels2DModel: + new_label = _filter_Labels2DModel_by_instance_ids( + element=data_out.labels[r], ids_to_remove=ids_to_remove + ) + + del data_out.labels[r] + + data_out.labels[r] = new_label + + if not inplace: + return data_out - # Return the filtered data if not in place - return None if inplace else result - if isinstance(data, ad.AnnData): - data_out = data if inplace else data.copy() +def _filter_ShapesModel_by_instance_ids( + element: ShapesModel, ids_to_remove: list[str] +) -> ShapesModel: - _apply_anndata_filters(data_out, min_counts, min_genes, max_counts, max_genes, inplace=inplace) + return element[~element.index.isin(ids_to_remove)] - return None if inplace else data_out - # if it's SpatialData, we need to filter other elements in the object - elif isinstance(data, sd.SpatialData): - if not inplace: - data_out = sd.SpatialData( - images=data.images if data.images is not None else None, - labels=data.labels if data.labels is not None else None, - points=data.points if data.points is not None else None, - shapes=data.shapes if data.shapes is not None else None, - tables=data.tables if data.tables is not None else None, - ) - else: - data_out = data +def _filter_Labels2DModel_by_instance_ids( + element: Labels2DModel, ids_to_remove: list[str] +) -> Labels2DModel: - for t in tables_to_use: - if "spatialdata_attrs" in data.tables[t].uns: - instance_key = data.tables[t].uns["spatialdata_attrs"]["instance_key"] - region_key = data.tables[t].uns["spatialdata_attrs"]["region_key"] - region = data.tables[t].uns["spatialdata_attrs"]["region"] + def set_ids_in_label_to_zero( + image: xr.DataArray, ids_to_remove: list[int] + ) -> xr.DataArray: + # Use apply_ufunc for efficient processing + def _mask_block(block): + # Create a copy to avoid modifying read-only array + result = block.copy() + result[np.isin(result, masks)] = 0 + return result - filter_params = { - "min_counts": min_counts, - "min_genes": min_genes, - "max_counts": max_counts, - "max_genes": max_genes, - } + processed = xr.apply_ufunc( + _mask_block, + image, + input_core_dims=[["y", "x"]], + output_core_dims=[["y", "x"]], + vectorize=True, + dask="parallelized", + output_dtypes=[image.dtype], + dask_gufunc_kwargs={"allow_rechunk": True}, + ) - # remove the rows from the table - table_old = data.tables[t].copy() - for param_name, param_value in filter_params.items(): - if param_value is not None: - # Always modify result in place since we're using our own copy - mask, _ = sc.pp.filter_cells(table_old, **{param_name: param_value}, inplace=inplace) + # Force computation to ensure the changes are materialized + computed_result = processed.compute() - data_out.tables[t] = table_old[~mask] + # Create a new DataArray to ensure persistence + result = xr.DataArray( + data=computed_result.data, + coords=image.coords, + dims=image.dims, + attrs=image.attrs.copy(), # Preserve all attributes + ) - # remove the rows from the shapes - removed_obs = table_old.obs[mask][[instance_key, region_key]] + return result - assert removed_obs[region_key].unique() == region + if isinstance(element, xr.DataArray): + return Labels2DModel.parse(set_ids_in_label_to_zero(element, ids_to_remove)) - idx_to_remove = removed_obs[instance_key].values.tolist() - ele_to_modify = data.shapes[region].copy() - filtered_gdf = ele_to_modify[~ele_to_modify.index.isin(idx_to_remove)] + if isinstance(element, DataTree): + # we extract the info to just reconstruct the DataTree after filtering the max scale + max_scale = list(element.keys())[0] + scale_factors = _get_scale_factors(element) + scale_factors = [int(sf[0]) for sf in scale_factors] - data_out.shapes[region] = filtered_gdf + return Labels2DModel.parse( + data=set_ids_in_label_to_zero(element[max_scale].image, ids_to_remove), + scale_factors=scale_factors, + ) - return data_out + +def _get_scale_factors(labels_element: Labels2DModel) -> list[tuple[float, float]]: + + scales = list(labels_element.keys()) + + # Calculate relative scale factors between consecutive scales + scale_factors = [] + for i in range(len(scales) - 1): + y_size_current = labels_element[scales[i]].image.shape[0] + x_size_current = labels_element[scales[i]].image.shape[1] + y_size_next = labels_element[scales[i + 1]].image.shape[0] + x_size_next = labels_element[scales[i + 1]].image.shape[1] + y_factor = y_size_current / y_size_next + x_factor = x_size_current / x_size_next + + scale_factors.append((y_factor, x_factor)) + + return scale_factors From fa7cfbab5ab5615001dc28cbce73a7d1d956ae5c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 10 Mar 2025 23:23:03 +0000 Subject: [PATCH 3/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/squidpy/pp/_simple.py | 65 ++++++++++----------------------------- 1 file changed, 17 insertions(+), 48 deletions(-) diff --git a/src/squidpy/pp/_simple.py b/src/squidpy/pp/_simple.py index 0b1d1753e..9a11fca40 100644 --- a/src/squidpy/pp/_simple.py +++ b/src/squidpy/pp/_simple.py @@ -6,10 +6,10 @@ import scanpy as sc import spatialdata as sd import xarray as xr -from xarray import DataTree from spatialdata._logging import logger as logg from spatialdata.models import Labels2DModel, PointsModel, ShapesModel, get_model from spatialdata.transformations import get_transformation +from xarray import DataTree def filter_cells( @@ -37,24 +37,16 @@ def filter_cells( if isinstance(data, sd.SpatialData) and table is None: tables_to_use = list(data.tables.keys()) - if ( - not isinstance(data, ad.AnnData) - and tables_to_use is not None - and not tables_to_use - ): + if not isinstance(data, ad.AnnData) and tables_to_use is not None and not tables_to_use: raise ValueError("Expected at least one table to be filtered, found `0`") if any(t not in data.tables for t in tables_to_use): raise ValueError(f"Expected all tables to be in `{data.tables.keys()}`.`") # mimic scanpy's behavior in only allowing one filtering parameter per call - n_given_options = sum( - option is not None for option in [min_genes, min_counts, max_genes, max_counts] - ) + n_given_options = sum(option is not None for option in [min_genes, min_counts, max_genes, max_counts]) if n_given_options > 1: - raise ValueError( - "Only one filtering parameter can be provided per call (scanpy behavior)." - ) + raise ValueError("Only one filtering parameter can be provided per call (scanpy behavior).") for param_name, param_value in [ ("min_counts", min_counts), @@ -63,9 +55,7 @@ def filter_cells( ("max_genes", max_genes), ]: if param_value is not None and not isinstance(param_value, int): - raise ValueError( - f"Expected `{param_name}` to be an integer, found `{type(param_value)}`" - ) + raise ValueError(f"Expected `{param_name}` to be an integer, found `{type(param_value)}`") if not isinstance(inplace, bool): raise ValueError(f"Expected `inplace` to be a boolean, found `{type(inplace)}`") @@ -76,7 +66,6 @@ def filter_cells( # we need to filter the adata object either way for t in tables_to_use: - filter_params = { "min_counts": min_counts, "min_genes": min_genes, @@ -92,24 +81,21 @@ def filter_cells( for param_name, param_value in filter_params.items(): if param_value is not None: if inplace and isinstance(data, ad.AnnData): - sc.pp.filter_cells( - table_old, **{param_name: param_value}, inplace=True - ) + sc.pp.filter_cells(table_old, **{param_name: param_value}, inplace=True) elif not inplace: # inplace=False gives us boolean vector of which rows to remove - mask_to_remove, _ = sc.pp.filter_cells( - table_old, **{param_name: param_value}, inplace=False - ) + mask_to_remove, _ = sc.pp.filter_cells(table_old, **{param_name: param_value}, inplace=False) if isinstance(data, ad.AnnData): - return table_old[~mask_to_remove] # we're SpatialData now assert isinstance(data, sd.SpatialData) if not inplace: - logg.warning("Creating a deepcopy of the SpatialData object, depending on the size of the object this can take a while.") + logg.warning( + "Creating a deepcopy of the SpatialData object, depending on the size of the object this can take a while." + ) data_out = sd.deepcopy(data) # elements_dict = {} @@ -123,9 +109,7 @@ def filter_cells( table_filtered = table_old[~mask_to_remove] if table_filtered.n_obs == 0 or table_filtered.n_vars == 0: - raise ValueError( - f"Filter results in empty table when filtering table `{t}`." - ) + raise ValueError(f"Filter results in empty table when filtering table `{t}`.") data_out.tables[t] = table_filtered # if this doesn't exist, the table doesn't annotate anything @@ -134,9 +118,7 @@ def filter_cells( f"Table `{t}` does not have 'spatialdata_attrs' to indicate what it annotates." ) - instance_key = data.tables[t].uns["spatialdata_attrs"][ - "instance_key" - ] + instance_key = data.tables[t].uns["spatialdata_attrs"]["instance_key"] region_key = data.tables[t].uns["spatialdata_attrs"]["region_key"] # region can annotate one (dtype str) or multiple (dtype list[str]) @@ -144,17 +126,13 @@ def filter_cells( if isinstance(region, str): region = [region] - removed_obs = table_old.obs[mask_to_remove][ - [instance_key, region_key] - ] + removed_obs = table_old.obs[mask_to_remove][[instance_key, region_key]] # iterate over all elements that the table annotates (region var) for r in region: element_model = get_model(data_out[r]) - ids_to_remove = removed_obs.query(f"{region_key} == '{r}'")[ - instance_key - ].tolist() + ids_to_remove = removed_obs.query(f"{region_key} == '{r}'")[instance_key].tolist() if element_model == ShapesModel: data_out.shapes[r] = _filter_ShapesModel_by_instance_ids( element=data_out.shapes[r], ids_to_remove=ids_to_remove @@ -175,20 +153,12 @@ def filter_cells( return data_out -def _filter_ShapesModel_by_instance_ids( - element: ShapesModel, ids_to_remove: list[str] -) -> ShapesModel: - +def _filter_ShapesModel_by_instance_ids(element: ShapesModel, ids_to_remove: list[str]) -> ShapesModel: return element[~element.index.isin(ids_to_remove)] -def _filter_Labels2DModel_by_instance_ids( - element: Labels2DModel, ids_to_remove: list[str] -) -> Labels2DModel: - - def set_ids_in_label_to_zero( - image: xr.DataArray, ids_to_remove: list[int] - ) -> xr.DataArray: +def _filter_Labels2DModel_by_instance_ids(element: Labels2DModel, ids_to_remove: list[str]) -> Labels2DModel: + def set_ids_in_label_to_zero(image: xr.DataArray, ids_to_remove: list[int]) -> xr.DataArray: # Use apply_ufunc for efficient processing def _mask_block(block): # Create a copy to avoid modifying read-only array @@ -236,7 +206,6 @@ def _mask_block(block): def _get_scale_factors(labels_element: Labels2DModel) -> list[tuple[float, float]]: - scales = list(labels_element.keys()) # Calculate relative scale factors between consecutive scales From 975475edb126215d4c08b5df732140ff1890309e Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Tue, 11 Mar 2025 00:39:29 +0100 Subject: [PATCH 4/4] fixed typo --- src/squidpy/pp/_simple.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/squidpy/pp/_simple.py b/src/squidpy/pp/_simple.py index 0b1d1753e..b684fff67 100644 --- a/src/squidpy/pp/_simple.py +++ b/src/squidpy/pp/_simple.py @@ -193,7 +193,7 @@ def set_ids_in_label_to_zero( def _mask_block(block): # Create a copy to avoid modifying read-only array result = block.copy() - result[np.isin(result, masks)] = 0 + result[np.isin(result, ids_to_remove)] = 0 return result processed = xr.apply_ufunc(