diff --git a/.gitignore b/.gitignore
index 8e8d30c7..379f16be 100644
--- a/.gitignore
+++ b/.gitignore
@@ -41,3 +41,4 @@ data
 # data folder
 data/
 tests/data
+uv.lock
diff --git a/pyproject.toml b/pyproject.toml
index 1ab82a31..8c0c1087 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,6 +36,7 @@ dependencies = [
     "readfcs",
     "tifffile>=2023.8.12",
     "ome-types",
+    "xmltodict",
 ]
 
 [project.optional-dependencies]
diff --git a/src/spatialdata_io/_constants/_constants.py b/src/spatialdata_io/_constants/_constants.py
index f9e6f8b8..46f983b1 100644
--- a/src/spatialdata_io/_constants/_constants.py
+++ b/src/spatialdata_io/_constants/_constants.py
@@ -68,7 +68,6 @@ class SeqfishKeys(ModeEnum):
     TIFF_FILE = ".tiff"
     GEOJSON_FILE = ".geojson"
     # file identifiers
-    ROI = "Roi"
     TRANSCRIPT_COORDINATES = "TranscriptList"
     DAPI = "DAPI"
     COUNTS_FILE = "CellxGene"
@@ -78,6 +77,7 @@ class SeqfishKeys(ModeEnum):
     # transcripts
     TRANSCRIPTS_X = "x"
     TRANSCRIPTS_Y = "y"
+    TRANSCRIPTS_Z = "z"
     FEATURE_KEY = "name"
     INSTANCE_KEY_POINTS = "cell"
     # cells
@@ -88,8 +88,6 @@ class SeqfishKeys(ModeEnum):
     SPATIAL_KEY = "spatial"
     REGION_KEY = "region"
     INSTANCE_KEY_TABLE = "instance_id"
-    SCALEFEFACTOR_X = "PhysicalSizeX"
-    SCALEFEFACTOR_Y = "PhysicalSizeY"
 
 
 @unique
diff --git a/src/spatialdata_io/readers/_utils/_utils.py b/src/spatialdata_io/readers/_utils/_utils.py
index 6072fbc8..e4156d36 100644
--- a/src/spatialdata_io/readers/_utils/_utils.py
+++ b/src/spatialdata_io/readers/_utils/_utils.py
@@ -4,6 +4,7 @@
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Union
 
+from anndata import AnnData
 from anndata.io import read_text
 from h5py import File
 from ome_types import from_tiff
diff --git a/src/spatialdata_io/readers/seqfish.py b/src/spatialdata_io/readers/seqfish.py
index 276f27b5..e0f26d04 100644
--- a/src/spatialdata_io/readers/seqfish.py
+++ b/src/spatialdata_io/readers/seqfish.py
@@ -2,7 +2,9 @@
 
 import os
 import re
+import warnings
 import xml.etree.ElementTree as ET
+from collections.abc import Mapping
 from pathlib import Path
 from types import MappingProxyType
 from typing import TYPE_CHECKING, Any
@@ -11,8 +13,10 @@
 import numpy as np
 import pandas as pd
 import tifffile
+import xmltodict
 from dask_image.imread import imread
 from spatialdata import SpatialData
+from spatialdata._logging import logger
 from spatialdata.models import (
     Image2DModel,
     Labels2DModel,
@@ -30,8 +34,10 @@
 
 __all__ = ["seqfish"]
 
+LARGE_IMAGE_THRESHOLD = 100_000_000
 
-@inject_docs(vx=SK)
+
+@inject_docs(vx=SK, megapixels_value=str(int(LARGE_IMAGE_THRESHOLD / 1e6)))
 def seqfish(
     path: str | Path,
     load_images: bool = True,
@@ -39,19 +45,19 @@ def seqfish(
     load_points: bool = True,
     load_shapes: bool = True,
     cells_as_circles: bool = False,
-    rois: list[int] | None = None,
+    rois: list[str] | None = None,
     imread_kwargs: Mapping[str, Any] = MappingProxyType({}),
-    raster_models_scale_factors: list[float] | None = None,
+    raster_models_scale_factors: list[int] | None = None,
 ) -> SpatialData:
     """Read *seqfish* formatted dataset.
 
     This function reads the following files:
 
-        ```{vx.ROI!r}{vx.COUNTS_FILE!r}{vx.CSV_FILE!r}```: Counts and metadata file.
-        ```{vx.ROI!r}{vx.CELL_COORDINATES!r}{vx.CSV_FILE!r}```: Cell coordinates file.
-        ```{vx.ROI!r}{vx.DAPI!r}{vx.TIFF_FILE!r}```: High resolution tiff image.
-        ```{vx.ROI!r}{vx.SEGMENTATION!r}{vx.TIFF_FILE!r}```: Cell mask file.
-        ```{vx.ROI!r}{vx.TRANSCRIPT_COORDINATES!r}{vx.CSV_FILE!r}```: Transcript coordinates file.
+        - ```_{vx.COUNTS_FILE!r}{vx.CSV_FILE!r}```: Counts and metadata file.
+        - ```_{vx.CELL_COORDINATES!r}{vx.CSV_FILE!r}```: Cell coordinates file.
+        - ```_{vx.DAPI!r}{vx.TIFF_FILE!r}```: High resolution tiff image.
+        - ```_{vx.SEGMENTATION!r}{vx.TIFF_FILE!r}```: Cell mask file.
+        - ```_{vx.TRANSCRIPT_COORDINATES!r}{vx.CSV_FILE!r}```: Transcript coordinates file.
 
     .. seealso::
 
@@ -72,9 +78,13 @@ def seqfish(
     cells_as_circles
         Whether to read cells also as circles instead of labels.
     rois
-        Which ROIs (specified as integers) to load. Only necessary if multiple ROIs present.
+        Which ROIs (specified as strings, without trailing "_") to load (the ROI strings are used as prefixes for the
+        filenames). If `None`, all ROIs are loaded.
     imread_kwargs
         Keyword arguments to pass to :func:`dask_image.imread.imread`.
+    raster_models_scale_factors
+        Scale factors to downscale high-resolution images and labels. If `None`, they are set automatically to obtain
+        a multi-scale image for all the images and labels that are larger than {megapixels_value} megapixels.
 
     Returns
     -------
@@ -93,24 +103,29 @@ def seqfish(
     >>> sdata.write("path/to/data.zarr")
     """
     path = Path(path)
-    count_file_pattern = re.compile(rf"(.*?){re.escape(SK.CELL_COORDINATES)}{re.escape(SK.CSV_FILE)}$")
+    count_file_pattern = re.compile(rf"(.*?)_{re.escape(SK.CELL_COORDINATES)}{re.escape(SK.CSV_FILE)}$")
     count_files = [f for f in os.listdir(path) if count_file_pattern.match(f)]
     if not count_files:
         raise ValueError(
             f"No files matching the pattern {count_file_pattern} were found. Cannot infer the naming scheme."
         )
 
-    roi_pattern = re.compile(f"^{SK.ROI}(\\d+)")
-    found_rois = {m.group(1) for i in os.listdir(path) if (m := roi_pattern.match(i))}
-    if rois is None:
-        rois_str = [f"{SK.ROI}{roi}" for roi in found_rois]
-    elif isinstance(rois, list):
+    rois_str_set: set[str] = set()
+    for count_file in count_files:
+        found = count_file_pattern.match(count_file)
+        if found is None:
+            raise ValueError(f"File {count_file} does not match the expected pattern.")
+        rois_str_set.add(found.group(1))
+    logger.info(f"Found ROIs: {rois_str_set}")
+    rois_str = sorted(rois_str_set)
+
+    if isinstance(rois, list):
         for roi in rois:
-            if str(roi) not in found_rois:
-                raise ValueError(f"ROI{roi} not found.")
-        rois_str = [f"{SK.ROI}{roi}" for roi in rois]
-    else:
-        raise ValueError("Invalid type for 'roi'. Must be list[int] or None.")
+            if str(roi) not in rois_str_set:
+                raise ValueError(f"ROI {roi!r} not found. Available ROIs: {sorted(rois_str_set)}.")
+        rois_str = rois
+    elif rois is not None:
+        raise ValueError("Invalid type for 'rois'. Must be list[str] or None.")
 
     def get_cell_file(roi: str) -> str:
         return f"{roi}_{SK.CELL_COORDINATES}{SK.CSV_FILE}"
@@ -168,33 +183,44 @@ def get_transcript_file(roi: str) -> str:
     scaled = {}
     for roi_str in rois_str:
         scaled[roi_str] = Scale(
-            np.array(_get_scale_factors(path / get_dapi_file(roi_str), SK.SCALEFEFACTOR_X, SK.SCALEFEFACTOR_Y)),
-            axes=("y", "x"),
+            np.array(_get_scale_factors_scale0(path / get_dapi_file(roi_str))),
+            axes=("x", "y"),  # same order as the returned [PhysicalSizeX, PhysicalSizeY]
         )
 
+    def _get_scale_factors(raster_path: Path, raster_models_scale_factors: list[int] | None) -> list[int] | None:
+        # explicit scale factors win; the file is only inspected when a default has to be chosen
+        if raster_models_scale_factors is not None:
+            return raster_models_scale_factors
+        is_large = _get_n_pixels(raster_path) > LARGE_IMAGE_THRESHOLD
+        return [2, 2, 2] if is_large else None
+
     if load_images:
-        images = {
-            f"{os.path.splitext(get_dapi_file(x))[0]}": Image2DModel.parse(
-                imread(path / get_dapi_file(x), **imread_kwargs),
+        images = {}
+        for x in rois_str:
+            image_path = path / get_dapi_file(x)
+            scale_factors = _get_scale_factors(image_path, raster_models_scale_factors)
+
+            images[f"{os.path.splitext(get_dapi_file(x))[0]}"] = Image2DModel.parse(
+                imread(image_path, **imread_kwargs),
                 dims=("c", "y", "x"),
-                scale_factors=raster_models_scale_factors,
-                transformations={"global": scaled[x]},
+                scale_factors=scale_factors,
+                transformations={x: scaled[x]},
             )
-            for x in rois_str
-        }
     else:
         images = {}
 
     if load_labels:
-        labels = {
-            f"{os.path.splitext(get_cell_segmentation_labels_file(x))[0]}": Labels2DModel.parse(
-                imread(path / get_cell_segmentation_labels_file(x), **imread_kwargs).squeeze(),
+        labels = {}
+        for x in rois_str:
+            labels_path = path / get_cell_segmentation_labels_file(x)
+            scale_factors = _get_scale_factors(labels_path, raster_models_scale_factors)
+
+            labels[f"{os.path.splitext(get_cell_segmentation_labels_file(x))[0]}"] = Labels2DModel.parse(
+                imread(labels_path, **imread_kwargs).squeeze(),
                 dims=("y", "x"),
-                scale_factors=raster_models_scale_factors,
-                transformations={"global": scaled[x]},
+                scale_factors=scale_factors,
+                transformations={x: scaled[x]},
             )
-            for x in rois_str
-        }
     else:
         labels = {}
 
@@ -206,32 +232,39 @@ def get_transcript_file(roi: str) -> str:
             p = pd.read_csv(path / get_transcript_file(x), delimiter=",")
             instance_key_points = SK.INSTANCE_KEY_POINTS.value if SK.INSTANCE_KEY_POINTS.value in p.columns else None
 
+            coordinates = {"x": SK.TRANSCRIPTS_X, "y": SK.TRANSCRIPTS_Y, "z": SK.TRANSCRIPTS_Z}
+            if SK.TRANSCRIPTS_Z not in p.columns:
+                coordinates.pop("z")
+                warnings.warn(
+                    f"Column {SK.TRANSCRIPTS_Z} not found in {get_transcript_file(x)}.", UserWarning, stacklevel=2
+                )
+
             # call parser
             points[name] = PointsModel.parse(
                 p,
-                coordinates={"x": SK.TRANSCRIPTS_X, "y": SK.TRANSCRIPTS_Y},
+                coordinates=coordinates,
                 feature_key=SK.FEATURE_KEY.value,
                 instance_key=instance_key_points,
-                transformations={"global": Identity()},
+                transformations={x: Identity()},
             )
 
     shapes = {}
     if cells_as_circles:
-        for x, adata in zip(rois_str, tables.values(), strict=False):
+        for x, adata in zip(rois_str, tables.values(), strict=True):
             shapes[f"{os.path.splitext(get_cell_file(x))[0]}"] = ShapesModel.parse(
                 adata.obsm[SK.SPATIAL_KEY],
                 geometry=0,
                 radius=np.sqrt(adata.obs[SK.AREA].to_numpy() / np.pi),
                 index=adata.obs[SK.INSTANCE_KEY_TABLE].copy(),
-                transformations={"global": Identity()},
+                transformations={x: Identity()},
             )
     if load_shapes:
-        for x in rois_str:
+        for x, adata in zip(rois_str, tables.values(), strict=True):
             # this assumes that the index matches the instance key of the table. A more robust approach could be
             # implemented, as described here https://github.com/scverse/spatialdata-io/issues/249
             shapes[f"{os.path.splitext(get_cell_segmentation_shapes_file(x))[0]}"] = ShapesModel.parse(
                 path / get_cell_segmentation_shapes_file(x),
-                transformations={"global": scaled[x]},
+                transformations={x: scaled[x]},
                 index=adata.obs[SK.INSTANCE_KEY_TABLE].copy(),
             )
 
@@ -240,12 +273,49 @@ def get_transcript_file(roi: str) -> str:
     return sdata
 
 
-def _get_scale_factors(DAPI_path: Path, scalefactor_x_key: str, scalefactor_y_key: str) -> list[float]:
-    with tifffile.TiffFile(DAPI_path) as tif:
-        ome_metadata = tif.ome_metadata
-        root = ET.fromstring(ome_metadata)
-        for element in root.iter():
-            if scalefactor_x_key in element.attrib.keys():
-                scalefactor_x = element.attrib[scalefactor_x_key]
-                scalefactor_y = element.attrib[scalefactor_y_key]
-        return [float(scalefactor_x), float(scalefactor_y)]
+def _is_ome_tiff_multiscale(ome_tiff_file: Path) -> bool:
+    """Check if the OME-TIFF file is multi-scale.
+
+    Parameters
+    ----------
+    ome_tiff_file
+        Path to the OME-TIFF file.
+
+    Returns
+    -------
+    Whether the OME-TIFF file is multi-scale.
+    """
+    # for some image files we couldn't find the multiscale information in the omexml metadata, and this method proves to
+    # be more robust
+    try:
+        zarr_tiff_store = tifffile.imread(ome_tiff_file, is_ome=True, level=1, aszarr=True)
+        zarr_tiff_store.close()
+    except IndexError:
+        return False
+    return True
+
+
+def _get_n_pixels(ome_tiff_file: Path) -> int:
+    """Return the number of array elements (product of the shape) of the first page of a TIFF file."""
+    with tifffile.TiffFile(ome_tiff_file, is_ome=True) as tif:
+        # int() narrows the NumPy scalar to a Python int; an `assert` would be stripped under `python -O`
+        return int(np.prod(tif.pages[0].shape))
+
+
+def _get_scale_factors_scale0(DAPI_path: Path) -> list[float]:
+    """Read the physical pixel sizes of the first image from the OME-XML metadata of an OME-TIFF file.
+
+    Returns
+    -------
+    The pixel sizes, ordered as ``[PhysicalSizeX, PhysicalSizeY]``.
+    """
+    with tifffile.TiffFile(DAPI_path, is_ome=True) as tif:
+        ome_xml = tif.ome_metadata
+    if ome_xml is None:
+        raise ValueError(f"No OME-XML metadata found in {DAPI_path}.")
+    image = xmltodict.parse(ome_xml)["OME"]["Image"]
+    # xmltodict yields a list when the metadata holds several Image elements; use the first one
+    if isinstance(image, list):
+        image = image[0]
+    pixels = image["Pixels"]
+    return [float(pixels["@PhysicalSizeX"]), float(pixels["@PhysicalSizeY"])]
diff --git a/tests/test_seqfish.py b/tests/test_seqfish.py
index 47d4fbd7..b71bfa6a 100644
--- a/tests/test_seqfish.py
+++ b/tests/test_seqfish.py
@@ -17,7 +17,7 @@
 @pytest.mark.parametrize(
     "dataset,expected", [("seqfish-2-test-dataset/instrument 2 official", "{'y': (0, 108), 'x': (0, 108)}")]
 )
-@pytest.mark.parametrize("rois", [[1], None])
+@pytest.mark.parametrize("rois", [["Roi1"], None])
 @pytest.mark.parametrize("cells_as_circles", [False, True])
-def test_example_data(dataset: str, expected: str, rois: list[int] | None, cells_as_circles: bool) -> None:
+def test_example_data(dataset: str, expected: str, rois: list[str] | None, cells_as_circles: bool) -> None:
     f = Path("./data") / dataset
@@ -25,8 +25,9 @@ def test_example_data(dataset: str, expected: str, rois: list[int] | None, cells
     sdata = seqfish(f, cells_as_circles=cells_as_circles, rois=rois)
     from spatialdata import get_extent
 
-    extent = get_extent(sdata, exact=False)
+    extent = get_extent(sdata, exact=False, coordinate_system="Roi1")
     extent = {ax: (math.floor(extent[ax][0]), math.ceil(extent[ax][1])) for ax in extent}
+    del extent["z"]
     if cells_as_circles:
         # manual correction required to take into account for the circle radii
         expected = "{'y': (-2, 109), 'x': (-2, 109)}"