diff --git a/pyproject.toml b/pyproject.toml index 6bada7e2..45bbd798 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ dynamic= [ license = {file = "LICENSE"} readme = "README.md" dependencies = [ - "spatialdata>=0.2.6", + "spatialdata>=0.3.0", "matplotlib", "scikit-learn", "scanpy", diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index 3cdf77fc..e9dd630d 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -169,6 +169,7 @@ def render_shapes( scale: float | int = 1.0, method: str | None = None, table_name: str | None = None, + table_layer: str | None = None, **kwargs: Any, ) -> sd.SpatialData: """ @@ -228,6 +229,9 @@ def render_shapes( Name of the table containing the color(s) columns. If one name is given than the table is used for each spatial element to be plotted if the table annotates it. If you want to use different tables for particular elements, as specified under element. + table_layer: str | None + Layer of the table to use for coloring if `color` is in :attr:`sdata.table.var_names`. If None, the data in + :attr:`sdata.table.X` is used for coloring. **kwargs : Any Additional arguments for customization. This can include: @@ -271,6 +275,7 @@ def render_shapes( norm=norm, scale=scale, table_name=table_name, + table_layer=table_layer, method=method, ds_reduction=kwargs.get("datashader_reduction"), ) @@ -298,6 +303,7 @@ def render_shapes( fill_alpha=param_values["fill_alpha"], transfunc=kwargs.get("transfunc"), table_name=param_values["table_name"], + table_layer=param_values["table_layer"], zorder=n_steps, method=param_values["method"], ds_reduction=param_values["ds_reduction"], @@ -320,6 +326,7 @@ def render_points( size: float | int = 1.0, method: str | None = None, table_name: str | None = None, + table_layer: str | None = None, **kwargs: Any, ) -> sd.SpatialData: """ @@ -370,6 +377,9 @@ def render_points( Name of the table containing the color(s) columns. If one name is given than the table is used for each spatial element to be plotted if the table annotates it. If you want to use different tables for particular elements, as specified under element. + table_layer: str | None + Layer of the table to use for coloring if `color` is in :attr:`sdata.table.var_names`. If None, the data in + :attr:`sdata.table.X` is used for coloring. **kwargs : Any Additional arguments for customization. This can include: @@ -403,6 +413,7 @@ def render_points( norm=norm, size=size, table_name=table_name, + table_layer=table_layer, ds_reduction=kwargs.get("datashader_reduction"), ) @@ -433,6 +444,7 @@ def render_points( transfunc=kwargs.get("transfunc"), size=param_values["size"], table_name=param_values["table_name"], + table_layer=param_values["table_layer"], zorder=n_steps, method=method, ds_reduction=param_values["ds_reduction"], @@ -573,6 +585,7 @@ def render_labels( fill_alpha: float | int = 0.4, scale: str | None = None, table_name: str | None = None, + table_layer: str | None = None, **kwargs: Any, ) -> sd.SpatialData: """ @@ -590,10 +603,10 @@ def render_labels( The name of the labels element to render. If `None`, all label elements in the `SpatialData` object will be used and all parameters will be broadcasted if possible. color : list[str] | str | None - Can either be string representing a color-like or key in :attr:`sdata.table.obs`. The latter can be used to - color by categorical or continuous variables. If the color column is found in multiple locations, please - provide the table_name to be used for the element if you would like a specific table to be used. By default - one table will automatically be choosen. + Can either be string representing a color-like or key in :attr:`sdata.table.obs` or in the index of + :attr:`sdata.table.var`. The latter can be used to color by categorical or continuous variables. If the + color column is found in multiple locations, please provide the table_name to be used for the element if you + would like a specific table to be used. By default one table will automatically be choosen. groups : list[str] | str | None When using `color` and the key represents discrete labels, `groups` can be used to show only a subset of them. Other values are set to NA. The list can contain multiple discrete labels to be visualized. @@ -626,6 +639,9 @@ def render_labels( with the highest resolution is selected. This can lead to long computing times for large images! table_name: str | None Name of the table containing the color columns. + table_layer: str | None + Layer of the AnnData table to use for coloring if `color` is in :attr:`sdata.table.var_names`. If None, + :attr:`sdata.table.X` of the default table is used for coloring. kwargs Additional arguments to be passed to cmap and norm. @@ -654,6 +670,7 @@ def render_labels( palette=palette, scale=scale, table_name=table_name, + table_layer=table_layer, ) sdata = self._copy() @@ -678,6 +695,7 @@ def render_labels( transfunc=kwargs.get("transfunc"), scale=param_values["scale"], table_name=param_values["table_name"], + table_layer=param_values["table_layer"], zorder=n_steps, ) n_steps += 1 @@ -811,7 +829,6 @@ def show( ax_x_min, ax_x_max = ax.get_xlim() ax_y_max, ax_y_min = ax.get_ylim() # (0, 0) is top-left - # handle coordinate system coordinate_systems = sdata.coordinate_systems if coordinate_systems is None else coordinate_systems if isinstance(coordinate_systems, str): coordinate_systems = [coordinate_systems] diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 8f392844..5bafe7a8 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -17,7 +17,7 @@ from matplotlib.cm import ScalarMappable from matplotlib.colors import ListedColormap, Normalize from scanpy._settings import settings as sc_settings -from spatialdata import get_extent, join_spatialelement_table +from spatialdata import get_extent, get_values, join_spatialelement_table from spatialdata.models import PointsModel, ShapesModel, get_table_keys from spatialdata.transformations import get_transformation, set_transformation from spatialdata.transformations.transformations import Identity @@ -70,6 +70,7 @@ def _render_shapes( element = render_params.element col_for_color = render_params.col_for_color groups = render_params.groups + table_layer = render_params.table_layer sdata_filt = sdata.filter_by_coordinate_system( coordinate_system=coordinate_system, @@ -115,6 +116,7 @@ def _render_shapes( na_color=render_params.color or render_params.cmap_params.na_color, cmap_params=render_params.cmap_params, table_name=table_name, + table_layer=table_layer, ) values_are_categorical = color_source_vector is not None @@ -397,6 +399,7 @@ def _render_points( element = render_params.element col_for_color = render_params.col_for_color table_name = render_params.table_name + table_layer = render_params.table_layer color = render_params.color groups = render_params.groups palette = render_params.palette @@ -409,10 +412,22 @@ def _render_points( points = sdata.points[element] coords = ["x", "y"] - if col_for_color is None or (table_name is not None and col_for_color in sdata_filt[table_name].obs.columns): + if table_name is not None and col_for_color not in points.columns: + warnings.warn( + f"Annotating points with {col_for_color} which is stored in the table `{table_name}`. " + f"To improve performance, it is advisable to store point annotations directly in the .parquet file.", + UserWarning, + stacklevel=2, + ) + + if col_for_color is None or ( + table_name is not None + and (col_for_color in sdata_filt[table_name].obs.columns or col_for_color in sdata_filt[table_name].var_names) + ): points = points[coords].compute() if ( col_for_color + and col_for_color in sdata_filt[table_name].obs.columns and (color_col := sdata_filt[table_name].obs[col_for_color]).dtype == "O" and not _is_coercable_to_float(color_col) ): @@ -428,7 +443,20 @@ def _render_points( points = points[coords].compute() if groups is not None and col_for_color is not None: - points = points[points[col_for_color].isin(groups)] + if col_for_color in points.columns: + points_color_values = points[col_for_color] + else: + points_color_values = get_values( + value_key=col_for_color, + sdata=sdata_filt, + element_name=element, + table_name=table_name, + table_layer=table_layer, + ) + points_color_values = points.merge(points_color_values, how="left", left_index=True, right_index=True)[ + col_for_color + ] + points = points[points_color_values.isin(groups)] if len(points) <= 0: raise ValueError(f"None of the groups {groups} could be found in the column '{col_for_color}'.") @@ -438,9 +466,18 @@ def _render_points( X=points[["x", "y"]].values, obs=points[coords].reset_index(), dtype=points[["x", "y"]].values.dtype ) else: + adata_obs = sdata_filt[table_name].obs + # if the points are colored by values in X (or a different layer), add the values to obs + if col_for_color in sdata_filt[table_name].var_names: + if table_layer is None: + adata_obs[col_for_color] = sdata_filt[table_name][:, col_for_color].X.flatten().copy() + else: + adata_obs[col_for_color] = sdata_filt[table_name][:, col_for_color].layers[table_layer].flatten().copy() + if groups is not None: + adata_obs = adata_obs[adata_obs[col_for_color].isin(groups)] adata = AnnData( X=points[["x", "y"]].values, - obs=sdata_filt[table_name].obs, + obs=adata_obs, dtype=points[["x", "y"]].values.dtype, uns=sdata_filt[table_name].uns, ) @@ -847,6 +884,7 @@ def _render_labels( ) -> None: element = render_params.element table_name = render_params.table_name + table_layer = render_params.table_layer palette = render_params.palette color = render_params.color groups = render_params.groups @@ -882,7 +920,7 @@ def _render_labels( extent=extent, ) - # the avove adds a useless c dimension of 1 (y, x) -> (1, y, x) + # the above adds a useless c dimension of 1 (y, x) -> (1, y, x) label = label.squeeze() if table_name is None: @@ -907,6 +945,7 @@ def _render_labels( na_color=render_params.cmap_params.na_color, cmap_params=render_params.cmap_params, table_name=table_name, + table_layer=table_layer, ) def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float) -> matplotlib.image.AxesImage: diff --git a/src/spatialdata_plot/pl/render_params.py b/src/spatialdata_plot/pl/render_params.py index ee37c57f..b44175c3 100644 --- a/src/spatialdata_plot/pl/render_params.py +++ b/src/spatialdata_plot/pl/render_params.py @@ -89,6 +89,7 @@ class ShapesRenderParams: method: str | None = None zorder: int = 0 table_name: str | None = None + table_layer: str | None = None ds_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None @@ -108,6 +109,7 @@ class PointsRenderParams: method: str | None = None zorder: int = 0 table_name: str | None = None + table_layer: str | None = None ds_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None @@ -141,4 +143,5 @@ class LabelsRenderParams: transfunc: Callable[[float], float] | None = None scale: str | None = None table_name: str | None = None + table_layer: str | None = None zorder: int = 0 diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 2fe377cc..03941dd8 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -56,9 +56,9 @@ from skimage.segmentation import find_boundaries from skimage.util import map_array from spatialdata import SpatialData, get_element_annotators, get_extent, get_values, rasterize -from spatialdata._core.query.relational_query import _locate_value, _ValueOrigin +from spatialdata._core.query.relational_query import _locate_value from spatialdata._types import ArrayLike -from spatialdata.models import Image2DModel, Labels2DModel, PointsModel, SpatialElement, get_model +from spatialdata.models import Image2DModel, Labels2DModel, SpatialElement # from spatialdata.transformations.transformations import Scale from spatialdata.transformations import Affine, Identity, MapAxis, Scale, Translation @@ -703,17 +703,6 @@ def _get_colors_for_categorical_obs( return palette[:len_cat] # type: ignore[return-value] -# TODO consider move to relational query in spatialdata -def get_values_point_table(sdata: SpatialData, origin: _ValueOrigin, table_name: str) -> pd.Series: - """Get a particular column stored in _ValueOrigin from the table in the spatialdata object.""" - table = sdata[table_name] - if origin.origin == "obs": - return table.obs[origin.value_key] - if origin.origin == "var": - return table[:, table.var_names.isin([origin.value_key])].X.copy() - raise ValueError(f"Color column `{origin.value_key}` not found in table {table_name}") - - def _set_color_source_vec( sdata: sd.SpatialData, element: SpatialElement | None, @@ -724,6 +713,7 @@ def _set_color_source_vec( palette: list[str] | str | None = None, cmap_params: CmapParams | None = None, table_name: str | None = None, + table_layer: str | None = None, ) -> tuple[ArrayLike | pd.Series | None, ArrayLike, bool]: if value_to_plot is None and element is not None: color = np.full(len(element), na_color) @@ -738,13 +728,13 @@ def _set_color_source_vec( ) if len(origins) == 1: - color_source_vector = _robust_get_value( + color_source_vector = get_values( + value_key=value_to_plot, sdata=sdata, - origin=origins[0], - value_to_plot=value_to_plot, element_name=element_name, table_name=table_name, - ) + table_layer=table_layer, + )[value_to_plot] # numerical case, return early # TODO temporary split until refactor is complete @@ -1610,13 +1600,13 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st raise ValueError("Parameter 'na_color' must be color-like.") if (norm := param_dict.get("norm")) is not None: - if element_type in ["images", "labels"] and not isinstance(norm, Normalize): + if element_type in {"images", "labels"} and not isinstance(norm, Normalize): raise TypeError("Parameter 'norm' must be of type Normalize.") if element_type in ["shapes", "points"] and not isinstance(norm, bool | Normalize): raise TypeError("Parameter 'norm' must be a boolean or a mpl.Normalize.") if (scale := param_dict.get("scale")) is not None: - if element_type in ["images", "labels"] and not isinstance(scale, str): + if element_type in {"images", "labels"} and not isinstance(scale, str): raise TypeError("Parameter 'scale' must be a string if specified.") if element_type == "shapes": if not isinstance(scale, float | int): @@ -1630,10 +1620,53 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st if size < 0: raise ValueError("Parameter 'size' must be a positive number.") - if param_dict.get("table_name") and not isinstance(param_dict["table_name"], str): - raise TypeError("Parameter 'table_name' must be a string .") + table_name = param_dict.get("table_name") + table_layer = param_dict.get("table_layer") + if table_name and not isinstance(param_dict["table_name"], str): + raise TypeError("Parameter 'table_name' must be a string.") + + if table_layer and not isinstance(param_dict["table_layer"], str): + raise TypeError("Parameter 'table_layer' must be a string.") + + def _ensure_table_and_layer_exist_in_sdata( + sdata: SpatialData, table_name: str | None, table_layer: str | None + ) -> bool: + """Ensure that table_name and table_layer are valid; throw error if not.""" + if table_name: + if table_layer: + if table_layer in sdata.tables[table_name].layers: + return True + raise ValueError(f"Layer '{table_layer}' not found in table '{table_name}'.") + return True # using sdata.tables[table_name].X + + if table_layer: + # user specified a layer but we have no tables => invalid + if len(sdata.tables) == 0: + raise ValueError("Trying to use 'table_layer' but no tables are present in the SpatialData object.") + if len(sdata.tables) == 1: + single_table_name = list(sdata.tables.keys())[0] + if table_layer in sdata.tables[single_table_name].layers: + return True + raise ValueError(f"Layer '{table_layer}' not found in table '{single_table_name}'.") + # more than one tables, try to find which one has the given layer + found_table = False + for tname in sdata.tables: + if table_layer in sdata.tables[tname].layers: + if found_table: + raise ValueError( + "Trying to guess 'table_name' based on 'table_layer', " "but found multiple matches." + ) + found_table = True + + if found_table: + return True + + raise ValueError(f"Layer '{table_layer}' not found in any table.") + + return True # not using any table + + assert _ensure_table_and_layer_exist_in_sdata(param_dict.get("sdata"), table_name, table_layer) - # like this because the following would assign True/False to 'method' if (method := param_dict.get("method")) not in ["matplotlib", "datashader", None]: raise ValueError("If specified, parameter 'method' must be either 'matplotlib' or 'datashader'.") @@ -1672,6 +1705,7 @@ def _validate_label_render_params( outline_alpha: float | int, scale: str | None, table_name: str | None, + table_layer: str | None, ) -> dict[str, dict[str, Any]]: param_dict: dict[str, Any] = { "sdata": sdata, @@ -1687,6 +1721,7 @@ def _validate_label_render_params( "norm": norm, "scale": scale, "table_name": table_name, + "table_layer": table_layer, } param_dict = _type_check_params(param_dict, "labels") @@ -1700,11 +1735,11 @@ def _validate_label_render_params( element_params[el]["na_color"] = param_dict["na_color"] element_params[el]["cmap"] = param_dict["cmap"] element_params[el]["norm"] = param_dict["norm"] - element_params[el]["color"] = param_dict["color"] element_params[el]["fill_alpha"] = param_dict["fill_alpha"] element_params[el]["scale"] = param_dict["scale"] element_params[el]["outline_alpha"] = param_dict["outline_alpha"] element_params[el]["contour_px"] = param_dict["contour_px"] + element_params[el]["table_layer"] = param_dict["table_layer"] element_params[el]["table_name"] = None element_params[el]["color"] = None @@ -1731,6 +1766,7 @@ def _validate_points_render_params( norm: Normalize | None, size: float | int, table_name: str | None, + table_layer: str | None, ds_reduction: str | None, ) -> dict[str, dict[str, Any]]: param_dict: dict[str, Any] = { @@ -1745,6 +1781,7 @@ def _validate_points_render_params( "norm": norm, "size": size, "table_name": table_name, + "table_layer": table_layer, "ds_reduction": ds_reduction, } param_dict = _type_check_params(param_dict, "points") @@ -1762,6 +1799,7 @@ def _validate_points_render_params( element_params[el]["color"] = param_dict["color"] element_params[el]["size"] = param_dict["size"] element_params[el]["alpha"] = param_dict["alpha"] + element_params[el]["table_layer"] = param_dict["table_layer"] element_params[el]["table_name"] = None element_params[el]["col_for_color"] = None @@ -1794,6 +1832,7 @@ def _validate_shape_render_params( norm: Normalize | None, scale: float | int, table_name: str | None, + table_layer: str | None, method: str | None, ds_reduction: str | None, ) -> dict[str, dict[str, Any]]: @@ -1812,6 +1851,7 @@ def _validate_shape_render_params( "norm": norm, "scale": scale, "table_name": table_name, + "table_layer": table_layer, "method": method, "ds_reduction": ds_reduction, } @@ -1832,6 +1872,7 @@ def _validate_shape_render_params( element_params[el]["cmap"] = param_dict["cmap"] element_params[el]["norm"] = param_dict["norm"] element_params[el]["scale"] = param_dict["scale"] + element_params[el]["table_layer"] = param_dict["table_layer"] element_params[el]["color"] = param_dict["color"] @@ -2185,21 +2226,6 @@ def _datshader_get_how_kw_for_spread( return reduction_to_how_map[reduction] -def _robust_get_value( - sdata: sd.SpatialData, - origin: _ValueOrigin, - value_to_plot: str | None, - element_name: list[str] | str | None = None, - table_name: str | None = None, -) -> pd.Series | None: - """Locate the value to plot in the spatial data object.""" - model = get_model(sdata[element_name]) - if model == PointsModel and table_name is not None: - return get_values_point_table(sdata=sdata, origin=origin, table_name=table_name) - vals = get_values(value_key=value_to_plot, sdata=sdata, element_name=element_name, table_name=table_name) - return vals[value_to_plot] - - def _prepare_transformation( element: DataArray | GeoDataFrame | dask.dataframe.core.DataFrame, coordinate_system: str, ax: Axes | None = None ) -> tuple[matplotlib.transforms.Affine2D, matplotlib.transforms.CompositeGenericTransform | None]: diff --git a/tests/_images/Labels_can_annotate_labels_with_table_layer.png b/tests/_images/Labels_can_annotate_labels_with_table_layer.png new file mode 100644 index 00000000..ce2c179a Binary files /dev/null and b/tests/_images/Labels_can_annotate_labels_with_table_layer.png differ diff --git a/tests/_images/Points_can_annotate_points_with_table_X.png b/tests/_images/Points_can_annotate_points_with_table_X.png new file mode 100644 index 00000000..bad4e3e6 Binary files /dev/null and b/tests/_images/Points_can_annotate_points_with_table_X.png differ diff --git a/tests/_images/Points_can_annotate_points_with_table_and_groups.png b/tests/_images/Points_can_annotate_points_with_table_and_groups.png new file mode 100644 index 00000000..4941bfd2 Binary files /dev/null and b/tests/_images/Points_can_annotate_points_with_table_and_groups.png differ diff --git a/tests/_images/Points_can_annotate_points_with_table_layer.png b/tests/_images/Points_can_annotate_points_with_table_layer.png new file mode 100644 index 00000000..bee86536 Binary files /dev/null and b/tests/_images/Points_can_annotate_points_with_table_layer.png differ diff --git a/tests/_images/Points_can_annotate_points_with_table_obs.png b/tests/_images/Points_can_annotate_points_with_table_obs.png new file mode 100644 index 00000000..8002e8f9 Binary files /dev/null and b/tests/_images/Points_can_annotate_points_with_table_obs.png differ diff --git a/tests/_images/Shapes_can_annotate_shapes_with_table_layer.png b/tests/_images/Shapes_can_annotate_shapes_with_table_layer.png new file mode 100644 index 00000000..99a04ad8 Binary files /dev/null and b/tests/_images/Shapes_can_annotate_shapes_with_table_layer.png differ diff --git a/tests/pl/test_render_labels.py b/tests/pl/test_render_labels.py index 1e609f83..d7697bd7 100644 --- a/tests/pl/test_render_labels.py +++ b/tests/pl/test_render_labels.py @@ -233,3 +233,7 @@ def _make_tablemodel_with_categorical_labels(self, sdata_blobs, labels_name: str ) sdata_blobs["other_table"] = table sdata_blobs["other_table"].obs["category"] = sdata_blobs["other_table"].obs["category"].astype("category") + + def test_plot_can_annotate_labels_with_table_layer(self, sdata_blobs: SpatialData): + sdata_blobs["table"].layers["normalized"] = RNG.random(sdata_blobs["table"].X.shape) + sdata_blobs.pl.render_labels("blobs_labels", color="channel_0_sum", table_layer="normalized").pl.show() diff --git a/tests/pl/test_render_points.py b/tests/pl/test_render_points.py index f1ae56aa..3ddef2bb 100644 --- a/tests/pl/test_render_points.py +++ b/tests/pl/test_render_points.py @@ -224,3 +224,78 @@ def test_plot_datashader_can_transform_points(self, sdata_blobs: SpatialData): _set_transformations(sdata_blobs["blobs_points"], {"global": seq}) sdata_blobs.pl.render_points("blobs_points", method="datashader", color="black", size=5).pl.show() + + def test_plot_can_annotate_points_with_table_obs(self, sdata_blobs: SpatialData): + nrows, ncols = 200, 3 + feature_matrix = RNG.random((nrows, ncols)) + var_names = [f"feature{i}" for i in range(ncols)] + + obs_indices = sdata_blobs["blobs_points"].index + + obs = pd.DataFrame() + obs["instance_id"] = obs_indices + obs["region"] = "blobs_points" + obs["region"].astype("category") + obs["extra_feature"] = [1, 2] * 100 + + table = AnnData(X=feature_matrix, var=pd.DataFrame(index=var_names), obs=obs) + table = TableModel.parse(table, region="blobs_points", region_key="region", instance_key="instance_id") + sdata_blobs["points_table"] = table + + sdata_blobs.pl.render_points("blobs_points", color="extra_feature", size=10).pl.show() + + def test_plot_can_annotate_points_with_table_X(self, sdata_blobs: SpatialData): + nrows, ncols = 200, 3 + feature_matrix = RNG.random((nrows, ncols)) + var_names = [f"feature{i}" for i in range(ncols)] + + obs_indices = sdata_blobs["blobs_points"].index + + obs = pd.DataFrame() + obs["instance_id"] = obs_indices + obs["region"] = "blobs_points" + obs["region"].astype("category") + + table = AnnData(X=feature_matrix, var=pd.DataFrame(index=var_names), obs=obs) + table = TableModel.parse(table, region="blobs_points", region_key="region", instance_key="instance_id") + sdata_blobs["points_table"] = table + + sdata_blobs.pl.render_points("blobs_points", color="feature0", size=10).pl.show() + + def test_plot_can_annotate_points_with_table_and_groups(self, sdata_blobs: SpatialData): + nrows, ncols = 200, 3 + feature_matrix = RNG.random((nrows, ncols)) + var_names = [f"feature{i}" for i in range(ncols)] + + obs_indices = sdata_blobs["blobs_points"].index + + obs = pd.DataFrame() + obs["instance_id"] = obs_indices + obs["region"] = "blobs_points" + obs["region"].astype("category") + obs["extra_feature_cat"] = ["one", "two"] * 100 + + table = AnnData(X=feature_matrix, var=pd.DataFrame(index=var_names), obs=obs) + table = TableModel.parse(table, region="blobs_points", region_key="region", instance_key="instance_id") + sdata_blobs["points_table"] = table + + sdata_blobs.pl.render_points("blobs_points", color="extra_feature_cat", groups="two", size=10).pl.show() + + def test_plot_can_annotate_points_with_table_layer(self, sdata_blobs: SpatialData): + nrows, ncols = 200, 3 + feature_matrix = RNG.random((nrows, ncols)) + var_names = [f"feature{i}" for i in range(ncols)] + + obs_indices = sdata_blobs["blobs_points"].index + + obs = pd.DataFrame() + obs["instance_id"] = obs_indices + obs["region"] = "blobs_points" + obs["region"].astype("category") + + table = AnnData(X=feature_matrix, var=pd.DataFrame(index=var_names), obs=obs) + table = TableModel.parse(table, region="blobs_points", region_key="region", instance_key="instance_id") + sdata_blobs["points_table"] = table + sdata_blobs["points_table"].layers["normalized"] = RNG.random((nrows, ncols)) + + sdata_blobs.pl.render_points("blobs_points", color="feature0", size=10, table_layer="normalized").pl.show() diff --git a/tests/pl/test_render_shapes.py b/tests/pl/test_render_shapes.py index dcc24fac..7abb0783 100644 --- a/tests/pl/test_render_shapes.py +++ b/tests/pl/test_render_shapes.py @@ -455,3 +455,22 @@ def test_plot_can_do_non_matching_table(self, sdata_blobs: SpatialData): sdata_blobs["new_table"] = table_shapes sdata_blobs.pl.render_shapes("blobs_circles", color="instance_id").pl.show() + + def test_plot_can_annotate_shapes_with_table_layer(self, sdata_blobs: SpatialData): + nrows, ncols = 5, 3 + feature_matrix = RNG.random((nrows, ncols)) + var_names = [f"feature{i}" for i in range(ncols)] + + obs_indices = sdata_blobs["blobs_circles"].index + + obs = pd.DataFrame() + obs["instance_id"] = obs_indices + obs["region"] = "blobs_circles" + obs["region"].astype("category") + + table = AnnData(X=feature_matrix, var=pd.DataFrame(index=var_names), obs=obs) + table = TableModel.parse(table, region="blobs_circles", region_key="region", instance_key="instance_id") + sdata_blobs["circle_table"] = table + sdata_blobs["circle_table"].layers["normalized"] = RNG.random((nrows, ncols)) + + sdata_blobs.pl.render_shapes("blobs_circles", color="feature0", table_layer="normalized").pl.show()