From 1402a88ebed5e843285e890a13803fcffe404738 Mon Sep 17 00:00:00 2001 From: Tomaz Vieira Date: Tue, 8 Dec 2020 21:38:59 +0100 Subject: [PATCH 1/6] [WIP] Creates Interval classes. No mo\re floats. Allows None --- ndstructs/point5D.py | 541 +++++++++++++++++++++---------------------- 1 file changed, 266 insertions(+), 275 deletions(-) diff --git a/ndstructs/point5D.py b/ndstructs/point5D.py index 82c49c5..161b0e8 100644 --- a/ndstructs/point5D.py +++ b/ndstructs/point5D.py @@ -1,4 +1,4 @@ -from itertools import product +import itertools import functools from ndstructs.utils.JsonSerializable import Referencer import operator @@ -30,19 +30,22 @@ def reversed(self) -> "KeyMap": return KeyMap(**{v: k for k, v in self._map.items()}) +INF = float("inf") +NINF = -INF + + class Point5D(JsonSerializable): LABELS = "txyzc" SPATIAL_LABELS = "xyz" LABEL_MAP = {label: index for index, label in enumerate(LABELS)} DTYPE = np.float64 - INF = float("inf") - NINF = -INF - def __init__(self, *, t: float = 0, x: float = 0, y: float = 0, z: float = 0, c: float = 0): - assert all( - v in (self.INF, self.NINF) or int(v) == v for v in (t, c, x, y, z) - ), f"Point5D accepts only ints or 'inf' {(t,c,x,y,z)}" - self._coords = {"t": t, "c": c, "x": x, "y": y, "z": z} + def __init__(self, *, t: int = 0, x: int = 0, y: int = 0, z: int = 0, c: int = 0): + self.x = x + self.y = y + self.z = z + self.t = t + self.c = c def __hash__(self) -> int: return hash(self.to_tuple(self.LABELS)) @@ -57,85 +60,55 @@ def from_np(cls: Type[PT], arr: np.ndarray, labels: str) -> PT: return cls.from_tuple(tuple(float(e) for e in arr), labels) def to_tuple(self, axis_order: str, type_converter: Callable[[float], T] = lambda x: float(x)) -> Tuple[T, ...]: - return tuple(type_converter(self._coords[label]) for label in axis_order) + return tuple(type_converter(self[label]) for label in axis_order) def to_dict(self) -> Dict[str, float]: - return self._coords.copy() + return {k: self[k] for k in self.LABELS} def to_np(self, axis_order: str = LABELS) -> np.ndarray: return np.asarray(self.to_tuple(axis_order)) def __repr__(self) -> str: - contents = ",".join((f"{label}:{val}" for label, val in self._coords.items())) + contents = ",".join((f"{label}:{val}" for label, val in self.to_dict().items())) return f"{self.__class__.__name__}({contents})" @staticmethod - def inf(*, t: float = None, x: float = None, y: float = None, z: float = None, c: float = None) -> "Point5D": - return Point5D( - t=Point5D.INF if t is None else t, - x=Point5D.INF if x is None else x, - y=Point5D.INF if y is None else y, - z=Point5D.INF if z is None else z, - c=Point5D.INF if c is None else c, - ) - - @staticmethod - def ninf(*, t: float = None, x: float = None, y: float = None, z: float = None, c: float = None) -> "Point5D": - return Point5D( - t=Point5D.NINF if t is None else t, - x=Point5D.NINF if x is None else x, - y=Point5D.NINF if y is None else y, - z=Point5D.NINF if z is None else z, - c=Point5D.NINF if c is None else c, - ) - - @staticmethod - def zero(*, t: float = 0, x: float = 0, y: float = 0, z: float = 0, c: float = 0) -> "Point5D": - return Point5D(t=t or 0, x=x or 0, y=y or 0, z=z or 0, c=c or 0) + def zero(*, t: int = 0, x: int = 0, y: int = 0, z: int = 0, c: int = 0) -> "Point5D": + return Point5D(t=t, x=x, y=y, z=z, c=c) @staticmethod - def one(*, t: float = 1, x: float = 1, y: float = 1, z: float = 1, c: float = 1) -> "Point5D": + def one(*, t: int = 1, x: int = 1, y: int = 1, z: int = 1, c: int = 1) -> "Point5D": return Point5D(t=t, x=x, y=y, z=z, c=c) - def __getitem__(self, key: str) -> float: - return self._coords[key] - - @property - def t(self) -> float: - return self["t"] - - @property - def x(self) -> float: - return self["x"] - - @property - def y(self) -> float: - return self["y"] - - @property - def z(self) -> float: - return self["z"] - - @property - def c(self) -> float: - return self["c"] + def __getitem__(self, key: str) -> int: + if key == "x": + return self.x + if key == "y": + return self.y + if key == "z": + return self.z + if key == "t": + return self.t + if key == "c": + return self.c + raise KeyError(key) def with_coord( self: PT, *, - t: Optional[float] = None, - c: Optional[float] = None, - x: Optional[float] = None, - y: Optional[float] = None, - z: Optional[float] = None, + t: Optional[int] = None, + c: Optional[int] = None, + x: Optional[int] = None, + y: Optional[int] = None, + z: Optional[int] = None, ) -> PT: - params = self.to_dict() - params["t"] = t if t is not None else params["t"] - params["c"] = c if c is not None else params["c"] - params["x"] = x if x is not None else params["x"] - params["y"] = y if y is not None else params["y"] - params["z"] = z if z is not None else params["z"] - return self.__class__(**params) + return self.__class__( + t=t if t is not None else self.t, + c=c if c is not None else self.c, + x=x if x is not None else self.x, + y=y if y is not None else self.y, + z=z if z is not None else self.z, + ) def __np_op(self: PT, other: PT_OPERABLE, op: str) -> PT: if isinstance(other, Point5D): @@ -188,8 +161,8 @@ def __mul__(self: PT, other: PT_OPERABLE) -> PT: return self.__np_op(other, "__mul__") def clamped(self: PT, minimum: "Point5D" = None, maximum: "Point5D" = None) -> PT: - minimum = minimum or Point5D.ninf() - maximum = maximum or Point5D.inf() + minimum = minimum or self + maximum = maximum or self result = np.maximum(self.to_np(self.LABELS), minimum.to_np(self.LABELS)) result = np.minimum(result, maximum.to_np(self.LABELS)) return self.__class__(**{label: val for label, val in zip(self.LABELS, result)}) @@ -241,7 +214,8 @@ def ensure_matching(cls, raw_shape: Tuple[int, ...], axiskeys: str): class Shape5D(Point5D): - def __init__(self, *, t: float = 1, x: float = 1, y: float = 1, z: float = 1, c: float = 1): + def __init__(self, *, t: int = 1, x: int = 1, y: int = 1, z: int = 1, c: int = 1): + assert all(coord >= 0 for coord in (x, y, z, t, c)) super().__init__(t=t, x=x, y=y, z=z, c=c) @classmethod @@ -254,22 +228,19 @@ def hypercube(cls, length: int) -> "Shape5D": return cls(t=length, x=length, y=length, z=length, c=length) def __repr__(self) -> str: - contents = ",".join((f"{label}:{val}" for label, val in self._coords.items() if val != 1)) + contents = ",".join((f"{label}:{val}" for label, val in self.to_dict().items() if val != 1)) return f"{self.__class__.__name__}({contents or 1})" - def to_tuple(self, axis_order: str) -> Tuple[float, ...]: - return tuple(int(v) for v in super().to_tuple(axis_order)) - @property - def spatial_axes(self) -> Dict[str, float]: - return {k: self._coords[k] for k in self.SPATIAL_LABELS} + def spatial_axes(self) -> Dict[str, int]: + return {k: self[k] for k in self.SPATIAL_LABELS} @property - def missing_spatial_axes(self) -> Dict[str, float]: + def missing_spatial_axes(self) -> Dict[str, int]: return {k: v for k, v in self.spatial_axes.items() if v == 1} @property - def present_spatial_axes(self) -> Dict[str, float]: + def present_spatial_axes(self) -> Dict[str, int]: return {k: v for k, v in self.spatial_axes.items() if k not in self.missing_spatial_axes} @property @@ -296,264 +267,286 @@ def volume(self) -> float: def hypervolume(self) -> float: return functools.reduce(operator.mul, self.to_tuple(Point5D.LABELS)) - def to_slice_5d(self, offset: Point5D = Point5D.zero()) -> "Slice5D": - return Slice5D.create_from_start_stop(offset, self + offset) + def to_slice_5d(self, offset: Point5D = Point5D.zero()) -> "Interval5D": + return Interval5D.create_from_start_stop(offset, self + offset) @classmethod def from_point(cls: Type[PT], point: Point5D) -> PT: return cls(**{k: v or 1 for k, v in point.to_dict().items()}) -SLC = TypeVar("SLC", bound="Slice5D", covariant=True) -SLC_PARAM = Union[slice, float, int] +INTERVALABLE = Union["Interval", int, None, Tuple[Optional[int], Optional[int]]] + + +class Interval: + """A contiguous interval in space of indicies between start (inclusive) and stop (exclusive)""" + + def __init__(self, start: int = 0, stop: Optional[int] = None): + self.start = start + self.stop = stop + assert self.stop >= self.start + + @classmethod + def create(cls, value: INTERVALABLE) -> "Interval": + if isinstance(value, int): + return Interval(value, value + 1) + if isinstance(value, Interval): + return value + if isinstance(value, tuple): + return Interval(value[0] or 0, value[1]) + return Interval.all() + + def __eq__(self, other: INTERVALABLE) -> bool: + other_interval = Interval.create(other) + return self.start == other_interval.start and self.stop == other_interval.stop + def __hash__(self) -> int: + return hash((self.start, self.stop)) -class Slice5D(JsonSerializable): - """A labeled 5D slice""" + @classmethod + def all(cls) -> "Interval": + return cls() @classmethod - def ensure_slice(cls, value: Optional[SLC_PARAM]) -> slice: - if value is None: - return slice(None) - if isinstance(value, slice): - start = None if value.start in (None, Point5D.NINF) else int(value.start) - stop = None if value.stop in (None, Point5D.INF) else int(value.stop) + def zero(cls) -> "Interval": + return cls(start=0, stop=1) + + def to_slice(self) -> slice: + return slice(self.start, self.stop) + + def is_defined(self) -> bool: + return self.stop != None + + def defined_with(self, limit: "Interval") -> "Interval": + assert limit.is_defined() + return Interval(self.start, self.stop if self.stop != INF else limit.stop) + + def contains(self, other: "Interval") -> bool: + if self.stop == None or other.stop == None: + return False + return self.start <= other.start and self.stop >= other.stop + + def split(self, step: int, clamp: bool = True) -> Iterable["Interval"]: + start = self.start + while self.stop == None or start < self.stop: + stop = start + step + piece = Interval(start, stop) + if clamp: + piece = piece.clamped(self) + yield piece + start = stop + + def get_tiles(self, tile_side: int, clamp: bool) -> Iterable["Interval"]: + start = (self.start // tile_side) * tile_side + return Interval(start, self.stop).split(tile_side, clamp=clamp) + + def clamped(self, limits: "Interval") -> "Interval": + if limits.stop == None: + stop = self.stop + elif self.stop == None: + stop = limits.stop else: - start = int(value) - stop = start + 1 - return slice(start, stop) + stop = min(self.stop, limits.stop) + return Interval(max(self.start, limits.start), stop) + + def enlarged(self, radius: int) -> "Interval": + return Interval(self.start - radius, None if self.stop == None else self.stop + radius) + + def translated(self, offset: int) -> "Interval": + return Interval(self.start + offset, None if self.stop == None else self.stop + offset) + + +INTERVAL_5D = TypeVar("INTERVAL_5D", bound="Interval5D", covariant=True) + + +class Interval5D(JsonSerializable): + """A labeled 5D interval""" def __init__( self, *, - t: SLC_PARAM = slice(None), - c: SLC_PARAM = slice(None), - x: SLC_PARAM = slice(None), - y: SLC_PARAM = slice(None), - z: SLC_PARAM = slice(None), + t: INTERVALABLE = Interval.all(), + c: INTERVALABLE = Interval.all(), + x: INTERVALABLE = Interval.all(), + y: INTERVALABLE = Interval.all(), + z: INTERVALABLE = Interval.all(), ): - self._slices = { - "t": self.ensure_slice(t), - "c": self.ensure_slice(c), - "x": self.ensure_slice(x), - "y": self.ensure_slice(y), - "z": self.ensure_slice(z), - } - - self.start = Point5D.ninf(**{label: slc.start for label, slc in self._slices.items()}) - self.stop = Point5D.inf(**{label: slc.stop for label, slc in self._slices.items()}) + self.x = Interval.create(x) + self.y = Interval.create(y) + self.z = Interval.create(z) + self.t = Interval.create(t) + self.c = Interval.create(c) + self.start = Point5D(x=self.x.start, y=self.y.start, z=self.z.start, t=self.t.start, c=self.c.start) + + def get_stop(self) -> Optional[Point5D]: + x = self.x.stop + y = self.y.stop + z = self.z.stop + t = self.t.stop + c = self.c.stop + if x is None or y is None or z is None or t is None or c is None: + return None + return Point5D(x=x, y=y, z=z, t=t, c=c) @staticmethod - def zero(*, t: SLC_PARAM = 0, c: SLC_PARAM = 0, x: SLC_PARAM = 0, y: SLC_PARAM = 0, z: SLC_PARAM = 0) -> "Slice5D": + def zero( + *, + t: INTERVALABLE = Interval.zero(), + c: INTERVALABLE = Interval.zero(), + x: INTERVALABLE = Interval.zero(), + y: INTERVALABLE = Interval.zero(), + z: INTERVALABLE = Interval.zero(), + ) -> "Interval5D": """Creates a slice with coords defaulting to slice(0, 1), except where otherwise specified""" - return Slice5D(t=t, c=c, x=x, y=y, z=z) + return Interval5D(t=t, c=c, x=x, y=y, z=z) - def relabeled(self: SLC, keymap: KeyMap) -> SLC: + def relabeled(self: INTERVAL_5D, keymap: KeyMap) -> INTERVAL_5D: params = {target_key: self[src_key] for src_key, target_key in keymap.items()} return self.with_coord(**params) def __eq__(self, other: object) -> bool: - if not isinstance(other, Slice5D): + if not isinstance(other, Interval5D): return False - return self.start == other.start and self.stop == other.stop + return self.to_tuple(Point5D.LABELS) == other.to_tuple(Point5D.LABELS) def __hash__(self) -> int: - return hash((self.start, self.stop)) + return hash(self.to_tuple(Point5D.LABELS)) - def contains(self, other: "Slice5D") -> bool: - assert other.is_defined() - return self.start <= other.start and self.stop >= other.stop + def contains(self, other: "Interval5D") -> bool: + return all(self[k].contains(other[k]) for k in Point5D.LABELS) def is_defined(self) -> bool: - if any(slc.stop is None for slc in self._slices.values()): - return False - if any(slc.start is None for slc in self._slices.values()): - return False - return True - - def defined_with(self: SLC, limits: Union[Shape5D, "Slice5D"]) -> SLC: - """Slice5D can have slices which are open to interpretation, like slice(None). This method - forces those slices expand into their interpretation within an array of shape 'shape'""" - limits_slice = limits if isinstance(limits, Slice5D) else limits.to_slice_5d() - assert limits_slice.is_defined() - params = {} - for key in Point5D.LABELS: - this_slc = self[key] - limit_slc = limits_slice[key] - - start = limit_slc.start if this_slc.start is None else this_slc.start - stop = limit_slc.stop if this_slc.stop is None else this_slc.stop - params[key] = slice(start, stop) - return self.with_coord(**params) + return all(i.is_defined() for i in self.to_tuple(Point5D.LABELS)) + + def defined_with(self: INTERVAL_5D, limits: Union[Shape5D, "Interval5D"]) -> INTERVAL_5D: + """Interval5D can have intervals which are open to interpretation, like Interval(0, None). This method + forces those slices expand into their interpretation within the boundaries of 'limits'""" + limits_interval = limits if isinstance(limits, Interval5D) else limits.to_slice_5d() + return self.with_coord(**{k: self[k].defined_with(limits_interval[k]) for k in Point5D.LABELS}) - def to_dict(self) -> Dict[str, slice]: - return self._slices.copy() + def to_dict(self) -> Dict[str, Interval]: + return {k: self[k] for k in Point5D.LABELS} @staticmethod def all( - t: SLC_PARAM = slice(None), - c: SLC_PARAM = slice(None), - x: SLC_PARAM = slice(None), - y: SLC_PARAM = slice(None), - z: SLC_PARAM = slice(None), - ) -> "Slice5D": - return Slice5D(t=t, c=c, x=x, y=y, z=z) + t: Interval = Interval.all(), + c: Interval = Interval.all(), + x: Interval = Interval.all(), + y: Interval = Interval.all(), + z: Interval = Interval.all(), + ) -> "Interval5D": + return Interval5D(t=t, c=c, x=x, y=y, z=z) @classmethod - def make_slices(cls, start: Point5D, stop: Point5D) -> Dict[str, slice]: - slices = {} - for label in Point5D.LABELS: - slice_start = None if start[label] == Point5D.NINF else start[label] - slice_stop = None if stop[label] == Point5D.INF else stop[label] - slices[label] = slice(slice_start, slice_stop) - return slices + def make_intervals(cls, start: Point5D, stop: Point5D) -> Dict[str, Interval]: + return {k: Interval(int(start[k]), int(stop[k])) for k in Point5D.LABELS} @staticmethod - def create_from_start_stop(start: Point5D, stop: Point5D) -> "Slice5D": - return Slice5D(**Slice5D.make_slices(start, stop)) + def create_from_start_stop(start: Point5D, stop: Point5D) -> "Interval5D": + return Interval5D(**Interval5D.make_intervals(start, stop)) @staticmethod - def from_json_data(data: dict, dereferencer: Optional[Dereferencer] = None) -> "Slice5D": + def from_json_data(data: dict, dereferencer: Optional[Dereferencer] = None) -> "Interval5D": start = Point5D.from_json_data(data["start"]) stop = Point5D.from_json_data(data["stop"]) - return Slice5D.create_from_start_stop(start, stop) + return Interval5D.create_from_start_stop(start, stop) def to_json_data(self, referencer: Referencer = lambda obj: None) -> dict: - return {"start": self.start.to_json_data(), "stop": self.stop.to_json_data()} + self_tuple = self.to_tuple(Point5D.LABELS) + return {"start": self_tuple[0], "stop": self_tuple[1]} - def from_start_stop(self: SLC, start: Point5D, stop: Point5D) -> SLC: - slices = self.make_slices(start, stop) + def from_start_stop(self: INTERVAL_5D, start: Point5D, stop: Point5D) -> INTERVAL_5D: + slices = self.make_intervals(start, stop) return self.with_coord(**slices) - def _ranges(self, block_shape: Shape5D) -> Iterator[List[float]]: - starts = self.start.to_np(Point5D.LABELS) - ends = self.stop.to_np(Point5D.LABELS) - steps = block_shape.to_np(Point5D.LABELS) - for start, end, step in zip(starts, ends, steps): - yield list(np.arange(start, end, step)) - - def split(self: SLC, block_shape: Shape5D) -> Iterator[SLC]: - assert self.is_defined() - for begin_tuple in product(*self._ranges(block_shape)): - start = Point5D.from_tuple(begin_tuple, Point5D.LABELS) - stop = (start + block_shape).clamped(maximum=self.stop) - yield self.from_start_stop(start, stop) - - def get_tiles(self: SLC, tile_shape: Shape5D) -> Iterator[SLC]: - assert self.is_defined() - start = Point5D.as_floor(self.start.to_np() / tile_shape.to_np()) * tile_shape - stop = Point5D.as_ceil(self.stop.to_np() / tile_shape.to_np()) * tile_shape - return self.from_start_stop(start, stop).split(tile_shape) - - @property - def t(self) -> slice: - return self._slices["t"] - - @property - def c(self) -> slice: - return self._slices["c"] + def split(self: INTERVAL_5D, block_shape: Shape5D) -> Iterator[INTERVAL_5D]: + """Splits self into multiple Interval5D instances, starting from self.start. Every piece shall have + shape == block_shape excedpt for the last one, which will be clamped to self.stop""" - @property - def x(self) -> slice: - return self._slices["x"] + yield from itertools.product([self[k].split(int(block_shape[k])) for k in Point5D.LABELS]) - @property - def y(self) -> slice: - return self._slices["y"] + def get_tiles(self: INTERVAL_5D, tile_shape: Shape5D, clamp: bool) -> Iterator[INTERVAL_5D]: + """Gets all tiles that would cover the entirety of self. Tiles that overflow self can be clamped + by setting `clamp` to True""" - @property - def z(self) -> slice: - return self._slices["z"] + yield from itertools.product([self[k].get_tiles(int(tile_shape[k]), clamp=clamp) for k in Point5D.LABELS]) - def __getitem__(self, key: str) -> slice: - return self._slices[key] + def __getitem__(self, key: str) -> Interval: + if key == "x": + return self.x + if key == "y": + return self.y + if key == "z": + return self.z + if key == "t": + return self.t + if key == "c": + return self.c + raise KeyError(key) + # override this in subclasses so that it returns an instance of self.__class__ def with_coord( - self: SLC, + self: INTERVAL_5D, *, - t: Optional[SLC_PARAM] = None, - c: Optional[SLC_PARAM] = None, - x: Optional[SLC_PARAM] = None, - y: Optional[SLC_PARAM] = None, - z: Optional[SLC_PARAM] = None, - ) -> SLC: - params = {} - params["t"] = self.t if t is None else t - params["c"] = self.c if c is None else c - params["x"] = self.x if x is None else x - params["y"] = self.y if y is None else y - params["z"] = self.z if z is None else z - return self.__class__(**params) - - def with_full_c(self: SLC) -> SLC: - return self.with_coord(c=slice(None)) + t: INTERVALABLE = None, + c: INTERVALABLE = None, + x: INTERVALABLE = None, + y: INTERVALABLE = None, + z: INTERVALABLE = None, + ) -> INTERVAL_5D: + return self.__class__( + t=self.t if t is None else t, + c=self.c if c is None else c, + x=self.x if x is None else x, + y=self.y if y is None else y, + z=self.z if z is None else z, + ) - @property - def shape(self) -> Shape5D: - assert self.is_defined() - return Shape5D(**(self.stop - self.start).to_dict()) + def with_full_c(self: INTERVAL_5D) -> INTERVAL_5D: + return self.with_coord(c=Interval.all()) - def clamped(self: SLC, roi: Union[Shape5D, "Slice5D"]) -> SLC: - slc = roi if isinstance(roi, Slice5D) else roi.to_slice_5d() - return self.from_start_stop(self.start.clamped(slc.start, slc.stop), self.stop.clamped(slc.start, slc.stop)) + def clamped(self: INTERVAL_5D, roi: Union[Shape5D, "Interval5D"]) -> INTERVAL_5D: + interv = roi if isinstance(roi, Interval5D) else roi.to_slice_5d() + return self.with_coord(**{k: self[k].clamped(interv[k]) for k in Point5D.LABELS}) - def enlarged(self: SLC, radius: Point5D) -> SLC: - start = self.start - radius - stop = self.stop + radius - return self.from_start_stop(start, stop) + def enlarged(self: INTERVAL_5D, radius: Point5D) -> INTERVAL_5D: + return self.with_coord(**{k: self[k].enlarged(int(radius[k])) for k in Point5D.LABELS}) - def translated(self: SLC, offset: Point5D) -> SLC: - return self.from_start_stop(self.start + offset, self.stop + offset) + def translated(self: INTERVAL_5D, offset: Point5D) -> INTERVAL_5D: + return self.with_coord(**{k: self[k].translated(offset[k]) for k in Point5D.LABELS}) def to_slices(self, axis_order: str = Point5D.LABELS) -> Tuple[slice, ...]: - slices = [] - for axis in axis_order: - slc = self._slices[axis] - start = slc.start if slc.start is None else int(slc.start) - stop = slc.stop if slc.stop is None else int(slc.stop) - slices.append(slice(start, stop)) - return tuple(slices) - - def to_np_tuple(self, axis_order: str) -> Tuple[float, ...]: - assert self.is_defined() - return (self.start.to_np(axis_order), self.stop.to_np(axis_order)) - - def to_tuple(self, axis_order: str) -> Tuple[Tuple[Optional[int], ...], Tuple[Optional[int], ...]]: - start = tuple(self._slices[k].start for k in axis_order) - stop = tuple(self._slices[k].stop for k in axis_order) + return tuple(self[axis].to_slice() for axis in axis_order) + + def to_tuple(self, axis_order: str) -> Tuple[Interval, ...]: + return tuple(self[k] for k in axis_order) + + def to_start_stop(self, axis_order: str) -> Tuple[Tuple[Optional[int], ...], Tuple[Optional[int], ...]]: + start = tuple(self[k].start for k in axis_order) + stop = tuple(self[k].stop for k in axis_order) return (start, stop) - def to_ilastik_cutout_subregion(self, axiskeys: str) -> str: - start = [slc.start for slc in self.to_slices(axiskeys)] - stop = [slc.stop for slc in self.to_slices(axiskeys)] - return str([tuple(start), tuple(stop)]) + def to_ilastik_cutout_subregion(self, axis_order: str) -> str: + return str(list(self.to_start_stop(axis_order=axis_order))) def __repr__(self) -> str: - slice_reprs = [] - starts = self.start.to_tuple(Point5D.LABELS) - stops = self.stop.to_tuple(Point5D.LABELS) - for label, start, stop in zip(Point5D.LABELS, starts, stops): - if start == Point5D.NINF and stop == Point5D.INF: - continue - if stop - start == 1: - label_repr = str(int(start)) - else: - start_str = int(start) if start != Point5D.NINF else start - stop_str = int(stop) if stop != Point5D.INF else stop - label_repr = f"{start_str}_{stop_str}" - slice_reprs.append(f"{label}:{label_repr}") - return ",".join(slice_reprs) + interval_reprs = ", ".join( + f"{k}:{self[k].start}_{self[k].stop}" for k in Point5D.LABELS if self[k] != Interval.all() + ) + return f"{self.__class__.__name__}({interval_reprs})" - def get_borders(self: SLC, thickness: Shape5D) -> Iterable[SLC]: + def get_borders(self: INTERVAL_5D, thickness: Shape5D) -> Iterable[INTERVAL_5D]: """Returns subslices of self, such that these subslices are at the borders of self (i.e.: touching the start or end of self) No axis of thickness should exceed self.shape[axis], since the subslices must be contained in self Axis where thickness[axis] == 0 will produce no borders: - slc.get_borders(Slice5D.zero(x=1, y=1)) will produce 4 borders (left, right, top, bottom) + slc.get_borders(Interval5D.zero(x=1, y=1)) will produce 4 borders (left, right, top, bottom) If, for any axis, thickness[axis] == self.shape[axis], then there will be duplicated borders in the output """ - assert self.shape >= thickness + thickness_interval = thickness.to_slice_5d(offset=self.start) + assert all(self[k].contains(thickness_interval[k]) for k in Point5D.LABELS) + # FIXME: I haven't ported this yet!!!!! for axis, axis_thickness in thickness.to_dict().items(): if axis_thickness == 0: continue @@ -561,23 +554,21 @@ def get_borders(self: SLC, thickness: Shape5D) -> Iterable[SLC]: yield self.with_coord(**{axis: slice(slc.start, slc.start + axis_thickness)}) yield self.with_coord(**{axis: slice(slc.stop - axis_thickness, slc.stop)}) - def mod_tile(self: SLC, tile_shape: Shape5D) -> SLC: - assert self.is_defined() + def mod_tile(self: INTERVAL_5D, tile_shape: Shape5D) -> INTERVAL_5D: assert self.shape <= tile_shape offset = self.start - (self.start % tile_shape) return self.from_start_stop(self.start - offset, self.stop - offset) - def get_neighboring_tiles(self: SLC, tile_shape: Shape5D) -> Iterator[SLC]: - assert self.is_defined() + def get_neighboring_tiles(self: INTERVAL_5D, tile_shape: Shape5D) -> Iterator[INTERVAL_5D]: assert self.shape <= tile_shape for axis in Point5D.LABELS: for axis_offset in (tile_shape[axis], -tile_shape[axis]): offset = Point5D.zero(**{axis: axis_offset}) yield self.translated(offset) - def get_neighbor_tile_adjacent_to(self: SLC, *, anchor: "Slice5D", tile_shape: Shape5D) -> Optional[SLC]: - assert self.is_defined() - anchor = anchor.defined_with(self.shape) + def get_neighbor_tile_adjacent_to( + self: INTERVAL_5D, *, anchor: "Interval5D", tile_shape: Shape5D + ) -> Optional[INTERVAL_5D]: assert self.contains(anchor) direction_axis: Optional[str] = None @@ -605,7 +596,7 @@ def get_neighbor_tile_adjacent_to(self: SLC, *, anchor: "Slice5D", tile_shape: S raise ValueError(f"Bad anchor for slice {self}: {anchor}") @staticmethod - def enclosing(points: Iterable[Union[Point5D, "Slice5D"]]) -> "Slice5D": + def enclosing(points: Iterable[Union[Point5D, "Interval5D"]]) -> "Interval5D": all_points = [] for p in points: if isinstance(p, Point5D): @@ -613,7 +604,7 @@ def enclosing(points: Iterable[Union[Point5D, "Slice5D"]]) -> "Slice5D": else: all_points += [p.start, p.stop - Point5D.one()] if not all_points: - return Slice5D.create_from_start_stop(Point5D.zero(), Point5D.zero()) + return Interval5D.create_from_start_stop(Point5D.zero(), Point5D.zero()) start = Point5D.min_coords(all_points) stop = Point5D.max_coords(all_points) + Point5D.one() - return Slice5D.create_from_start_stop(start=start, stop=stop) + return Interval5D.create_from_start_stop(start=start, stop=stop) From f874ff4cdb0281488639559e00aa469aa12a5a88 Mon Sep 17 00:00:00 2001 From: Tomaz Vieira Date: Wed, 9 Dec 2020 18:40:18 +0100 Subject: [PATCH 2/6] [WIP]Simplifies Interval5D to use Tuple[int, int] instead of Interval --- ndstructs/__init__.py | 4 +- ndstructs/array5D.py | 142 +++++++++++------- ndstructs/point5D.py | 339 +++++++++++++++--------------------------- tests/test_array5D.py | 49 ++++-- tests/test_point5D.py | 41 ++--- tests/test_slice5D.py | 209 ++++++++++++-------------- 6 files changed, 350 insertions(+), 434 deletions(-) diff --git a/ndstructs/__init__.py b/ndstructs/__init__.py index b5c650b..a765aa7 100644 --- a/ndstructs/__init__.py +++ b/ndstructs/__init__.py @@ -1,5 +1,5 @@ -from .array5D import Array5D +from .array5D import Array5D, All from .array5D import Image, ScalarImage, LinearData, ScalarData, ScalarLine, StaticLine -from .point5D import Point5D, Shape5D, Slice5D, KeyMap +from .point5D import Point5D, Shape5D, Interval5D, SPAN, KeyMap __version__ = "0.0.5dev0" diff --git a/ndstructs/array5D.py b/ndstructs/array5D.py index 506d249..c7ff816 100644 --- a/ndstructs/array5D.py +++ b/ndstructs/array5D.py @@ -1,5 +1,4 @@ -import itertools -from typing import Iterator, Tuple, Iterable, Optional, Union, TypeVar, Type, cast, Sequence +from typing import Iterator, Iterable, Optional, Union, TypeVar, Type, cast, Sequence import numpy as np from skimage import measure as skmeasure import skimage.io @@ -8,7 +7,7 @@ import uuid from numbers import Number -from .point5D import Point5D, Slice5D, Shape5D, KeyMap +from .point5D import Point5D, Interval5D, Shape5D, KeyMap, SPAN from ndstructs.utils import JsonSerializable Arr = TypeVar("Arr", bound="Array5D") @@ -28,20 +27,29 @@ ] +class All: + pass + + +SPAN_OVERRIDE = Union[SPAN, All] + + class Array5D(JsonSerializable): """A wrapper around np.ndarray with labeled axes. Enforces 5D, even if some - dimensions are of size 1. Sliceable with Slice5D's""" + dimensions are of size 1. Sliceable with Interval5D's""" LINEAR_RAW_AXISKEYS = "txyzc" def __init__(self, arr: np.ndarray, axiskeys: str, location: Point5D = Point5D.zero()): assert len(arr.shape) == len(axiskeys) missing_keys = [key for key in Point5D.LABELS if key not in axiskeys] - self._axiskeys = "".join(missing_keys) + axiskeys - assert sorted(self._axiskeys) == sorted(Point5D.LABELS) + self.axiskeys = "".join(missing_keys) + axiskeys + assert sorted(self.axiskeys) == sorted(Point5D.LABELS) slices = tuple([np.newaxis for key in missing_keys] + [...]) self._data = arr[slices] self.location = location + self.shape = Shape5D(**{key: value for key, value in zip(self.axiskeys, self._data.shape)}) + self.dtype = arr.dtype def relabeled(self: Arr, keymap: KeyMap) -> Arr: new_location = self.location.relabeled(keymap) @@ -76,15 +84,15 @@ def from_file(cls: Type[Arr], filelike: io.IOBase, location: Point5D = Point5D.z return cls(data, "yxc"[: len(data.shape)], location=location) def __repr__(self) -> str: - return f"<{self.__class__.__name__} {self.to_slice_5d()}>" + return f"<{self.__class__.__name__} {self.interval}>" @classmethod def allocate( - cls: Type[Arr], slc: Union[Slice5D, Shape5D], dtype: DTYPE, axiskeys: str = Point5D.LABELS, value: int = None + cls: Type[Arr], slc: Union[Interval5D, Shape5D], dtype: DTYPE, axiskeys: str = Point5D.LABELS, value: int = None ) -> Arr: - slc = slc.to_slice_5d() if isinstance(slc, Shape5D) else slc + slc = slc.to_interval5d() if isinstance(slc, Shape5D) else slc assert sorted(axiskeys) == sorted(Point5D.LABELS) - assert slc.is_defined() # FIXME: Create DefinedSlice class? + assert slc.shape.hypervolume != float("inf") arr = np.empty(slc.shape.to_tuple(axiskeys), dtype=dtype) arr = cls(arr, axiskeys, location=slc.start) if value is not None: @@ -95,26 +103,10 @@ def allocate( def allocate_like( cls: Type[Arr], arr: "Array5D", dtype: Optional[DTYPE], axiskeys: str = "", value: int = None ) -> Arr: - return cls.allocate(arr.roi, dtype=dtype or arr.dtype, axiskeys=axiskeys or arr.axiskeys, value=value) - - @property - def dtype(self) -> Type: - return self._data.dtype - - @property - def axiskeys(self) -> str: - return self._axiskeys - - @property - def _shape(self) -> Tuple: - return self._data.shape - - @property - def shape(self) -> Shape5D: - return Shape5D(**{key: value for key, value in zip(self.axiskeys, self._shape)}) + return cls.allocate(arr.interval, dtype=dtype or arr.dtype, axiskeys=axiskeys or arr.axiskeys, value=value) def split(self: Arr, shape: Shape5D) -> Iterator[Arr]: - for slc in self.roi.split(shape): + for slc in self.interval.split(shape): yield self.cut(slc) def as_mask(self) -> "Array5D": @@ -127,7 +119,7 @@ def sample_channels(self, mask: "ScalarData") -> "LinearData": (N, c) where N is the number of True-valued elements in 'mask', and c is the number of channels in self.""" - assert self.shape.with_coord(c=1) == mask.shape + assert self.shape.updated(c=1) == mask.shape assert mask.dtype == bool # FIXME: create "Mask" type? # mask has singleton channel axis, so 'c' must be in the end to index self.raw @@ -159,7 +151,7 @@ def setflags(self, *, write: bool) -> None: self._data.setflags(write=write) def normalized(self: Arr, step: Optional[Shape5D] = None) -> Arr: - step = step if step is not None else self.roi.with_coord(c=1, t=1).defined_with(self.shape).shape + step = step if step is not None else self.interval.updated(c=1, t=1).clamped(self.shape).shape normalized = self.allocate(self.shape, self.dtype, self.axiskeys) for source, dest in zip(normalized.split(step), self.split(step)): source_raw = source.raw(self.axiskeys) @@ -176,7 +168,7 @@ def rebuild(self: Arr, arr: np.ndarray, *, axiskeys: str, location: Point5D = No return self.__class__(arr, axiskeys, location) def translated(self: Arr, offset: Point5D) -> Arr: - return self.rebuild(self._data, axiskeys=self._axiskeys, location=self.location + offset) + return self.rebuild(self._data, axiskeys=self.axiskeys, location=self.location + offset) def raw(self, axiskeys: str) -> np.ndarray: """Returns a raw view of the underlying np.ndarray, containing only the axes @@ -214,36 +206,82 @@ def reordered(self: Arr, axiskeys: str) -> Arr: return self.rebuild(moved_arr, axiskeys=new_axes) - def local_cut(self: Arr, roi: Slice5D, *, copy: bool = False) -> Arr: - defined_roi = roi.defined_with(self.shape) - slices = defined_roi.to_slices(self.axiskeys) + def local_cut( + self: Arr, + interval: Interval5D = None, + *, + x: Optional[SPAN_OVERRIDE] = None, + y: Optional[SPAN_OVERRIDE] = None, + z: Optional[SPAN_OVERRIDE] = None, + t: Optional[SPAN_OVERRIDE] = None, + c: Optional[SPAN_OVERRIDE] = None, + copy: bool = False, + ) -> Arr: + local_interval = self.shape.to_interval5d() + interval = (interval or local_interval).updated( + x=local_interval.x if isinstance(x, All) else x, + y=local_interval.y if isinstance(y, All) else y, + z=local_interval.z if isinstance(z, All) else z, + t=local_interval.t if isinstance(t, All) else t, + c=local_interval.c if isinstance(c, All) else c, + ) + slices = interval.to_slices(self.axiskeys) + if any(slc.start < 0 for slc in slices): + raise ValueError(f"Cant't cut locally with negative indices: {interval}") if copy: cut_data = np.copy(self._data[slices]) else: cut_data = self._data[slices] - return self.rebuild(cut_data, axiskeys=self.axiskeys, location=self.location + defined_roi.start) - - def cut(self: Arr, roi: Slice5D, *, copy: bool = False) -> Arr: - return self.local_cut(roi.translated(-self.location), copy=copy) # TODO: define before translate? + return self.rebuild(cut_data, axiskeys=self.axiskeys, location=self.location + interval.start) + + def cut( + self: Arr, + interval: Interval5D = None, + *, + x: Optional[SPAN_OVERRIDE] = None, + y: Optional[SPAN_OVERRIDE] = None, + z: Optional[SPAN_OVERRIDE] = None, + t: Optional[SPAN_OVERRIDE] = None, + c: Optional[SPAN_OVERRIDE] = None, + copy: bool = False, + ) -> Arr: + interval = ( + (interval or self.interval) + .updated( + x=self.interval.x if isinstance(x, All) else x, + y=self.interval.y if isinstance(y, All) else y, + z=self.interval.z if isinstance(z, All) else z, + t=self.interval.t if isinstance(t, All) else t, + c=self.interval.c if isinstance(c, All) else c, + ) + .translated(-self.location) + ) + return self.local_cut(interval, copy=copy) def duplicate(self: Arr) -> Arr: - return self.cut(self.roi, copy=True) - - def clamped(self: Arr, roi: Slice5D) -> Arr: - return self.cut(self.roi.clamped(roi)) - - def to_slice_5d(self) -> Slice5D: - return self.shape.to_slice_5d().translated(self.location) + return self.cut(self.interval, copy=True) + + def clamped( + self: Arr, + interval: Union[Shape5D, Interval5D, None] = None, + *, + x: Optional[SPAN] = None, + y: Optional[SPAN] = None, + z: Optional[SPAN] = None, + t: Optional[SPAN] = None, + c: Optional[SPAN] = None, + ) -> Arr: + return self.cut(self.interval.clamped(interval, x=x, y=y, z=z, t=t, c=c)) @property - def roi(self) -> Slice5D: - return self.to_slice_5d() + def interval(self) -> Interval5D: + return self.shape.to_interval5d().translated(self.location) def set(self, value: "Array5D", autocrop: bool = False, mask_value: Optional[Number] = None) -> None: if autocrop: - value_slc = value.roi.clamped(self.roi) + value_slc = value.interval.clamped(self.interval) value = value.cut(value_slc) - self.cut(value.roi).localSet(value.translated(-self.location), mask_value=mask_value) + self.cut(value.interval).localSet(value.translated(-self.location), mask_value=mask_value) def localSet(self, value: "Array5D", mask_value: Optional[Number] = None) -> None: self_raw = self.raw(Point5D.LABELS) @@ -264,7 +302,7 @@ def as_uint8(self, normalized: bool = True) -> "Array5D": return Array5D((self._data * multi).astype(np.uint8), axiskeys=self.axiskeys) def get_borders(self: Arr, thickness: Shape5D) -> Iterable[Arr]: - for border_slc in self.roi.get_borders(thickness): + for border_slc in self.interval.get_borders(thickness): yield self.cut(border_slc) def unique_border_colors(self, border_thickness: Optional[Shape5D] = None) -> "StaticLine": @@ -284,7 +322,7 @@ def threshold(self: Arr, threshold: float) -> Arr: return out def connected_components(self: Arr, background: int = 0, connectivity: str = "xyz") -> Arr: - piece_shape = self.shape.with_coord(**{axis: 1 for axis in set("xyztc").difference(connectivity)}) + piece_shape = self.shape.updated(**{axis: 1 for axis in set("xyztc").difference(connectivity)}) output = Array5D.allocate_like(self, dtype=np.int64) for piece in self.split(piece_shape): raw = piece.raw(connectivity) @@ -299,7 +337,7 @@ def paint_point(self, point: Point5D, value: Number, local: bool = False): self._data[np_selection] = value def combine(self: Arr, others: Sequence[Arr]) -> Arr: - out_roi = Slice5D.enclosing([self.roi] + [o.roi for o in others]) + out_roi = Interval5D.enclosing([self.interval] + [o.interval for o in others]) out = self.allocate(slc=out_roi, dtype=self.dtype, axiskeys=self.axiskeys, value=0) out.set(self) for other in others: diff --git a/ndstructs/point5D.py b/ndstructs/point5D.py index 161b0e8..d398c66 100644 --- a/ndstructs/point5D.py +++ b/ndstructs/point5D.py @@ -1,19 +1,13 @@ import itertools import functools -from ndstructs.utils.JsonSerializable import Referencer import operator -import numpy as np -from typing import Dict, Tuple, Iterator, List, Iterable, TypeVar, Type, Union, Optional, Callable, Any +from typing import Dict, Tuple, Iterator, List, Iterable, TypeVar, Type, Union, Optional from numbers import Number +import numpy as np from ndstructs.utils import JsonSerializable, Dereferencer, Referencer -PT = TypeVar("PT", bound="Point5D", covariant=True) -PT_OPERABLE = Union["Point5D", Number] - -T = TypeVar("T") - class KeyMap: def __init__(self, x: str = "x", y: str = "y", z: str = "z", t: str = "t", c: str = "c"): @@ -30,12 +24,12 @@ def reversed(self) -> "KeyMap": return KeyMap(**{v: k for k, v in self._map.items()}) -INF = float("inf") -NINF = -INF +PT = TypeVar("PT", bound="Point5D", covariant=True) +PT_OPERABLE = Union["Point5D", int] class Point5D(JsonSerializable): - LABELS = "txyzc" + LABELS = "txyzc" # if you change this order, also change self._array order SPATIAL_LABELS = "xyz" LABEL_MAP = {label: index for index, label in enumerate(LABELS)} DTYPE = np.float64 @@ -46,21 +40,23 @@ def __init__(self, *, t: int = 0, x: int = 0, y: int = 0, z: int = 0, c: int = 0 self.z = z self.t = t self.c = c + self._array = np.asarray([t, x, y, z, c]) def __hash__(self) -> int: return hash(self.to_tuple(self.LABELS)) @classmethod - def from_tuple(cls: Type[PT], tup: Tuple[float, ...], labels: str) -> PT: - assert len(tup) == len(labels) + def from_tuple(cls: Type[PT], tup: Tuple[int, ...], labels: str) -> PT: + if len(tup) != len(labels): + raise ValueError(f"Mismatched args: {tup} , {labels}") return cls(**{label: value for label, value in zip(labels, tup)}) @classmethod def from_np(cls: Type[PT], arr: np.ndarray, labels: str) -> PT: - return cls.from_tuple(tuple(float(e) for e in arr), labels) + return cls.from_tuple(tuple(int(e) for e in arr), labels) - def to_tuple(self, axis_order: str, type_converter: Callable[[float], T] = lambda x: float(x)) -> Tuple[T, ...]: - return tuple(type_converter(self[label]) for label in axis_order) + def to_tuple(self, axis_order: str) -> Tuple[int, ...]: + return tuple(self[label] for label in axis_order) def to_dict(self) -> Dict[str, float]: return {k: self[k] for k in self.LABELS} @@ -72,9 +68,9 @@ def __repr__(self) -> str: contents = ",".join((f"{label}:{val}" for label, val in self.to_dict().items())) return f"{self.__class__.__name__}({contents})" - @staticmethod - def zero(*, t: int = 0, x: int = 0, y: int = 0, z: int = 0, c: int = 0) -> "Point5D": - return Point5D(t=t, x=x, y=y, z=z, c=c) + @classmethod + def zero(cls: Type[PT], *, t: int = 0, x: int = 0, y: int = 0, z: int = 0, c: int = 0) -> PT: + return cls(t=t, x=x, y=y, z=z, c=c) @staticmethod def one(*, t: int = 1, x: int = 1, y: int = 1, z: int = 1, c: int = 1) -> "Point5D": @@ -93,7 +89,7 @@ def __getitem__(self, key: str) -> int: return self.c raise KeyError(key) - def with_coord( + def updated( self: PT, *, t: Optional[int] = None, @@ -160,29 +156,20 @@ def __floordiv__(self: PT, other: PT_OPERABLE) -> PT: def __mul__(self: PT, other: PT_OPERABLE) -> PT: return self.__np_op(other, "__mul__") - def clamped(self: PT, minimum: "Point5D" = None, maximum: "Point5D" = None) -> PT: - minimum = minimum or self - maximum = maximum or self - result = np.maximum(self.to_np(self.LABELS), minimum.to_np(self.LABELS)) - result = np.minimum(result, maximum.to_np(self.LABELS)) - return self.__class__(**{label: val for label, val in zip(self.LABELS, result)}) + def clamped(self: PT, minimum: Optional["Point5D"] = None, maximum: Optional["Point5D"] = None) -> PT: + result = self.to_np(self.LABELS) + if minimum is not None: + result = np.maximum(self.to_np(self.LABELS), minimum.to_np(self.LABELS)) + if maximum is not None: + result = np.minimum(result, maximum.to_np(self.LABELS)) + return self.from_np(result, labels=self.LABELS) def as_shape(self) -> "Shape5D": return Shape5D(**self.to_dict()) - @classmethod - def as_ceil(cls: Type[PT], arr: np.ndarray, axis_order: str = LABELS) -> PT: - raw = np.ceil(arr) - return cls.from_np(raw, axis_order) - - @classmethod - def as_floor(cls: Type[PT], arr: np.ndarray, axis_order: str = LABELS) -> PT: - raw = np.floor(arr) - return cls.from_np(raw, axis_order) - def relabeled(self: PT, keymap: KeyMap) -> PT: params = {target_key: self[src_key] for src_key, target_key in keymap.items()} - return self.with_coord(**params) + return self.updated(**params) def interpolate_until(self, endpoint: "Point5D") -> Iterable["Point5D"]: start = self.to_np(self.LABELS) @@ -267,7 +254,7 @@ def volume(self) -> float: def hypervolume(self) -> float: return functools.reduce(operator.mul, self.to_tuple(Point5D.LABELS)) - def to_slice_5d(self, offset: Point5D = Point5D.zero()) -> "Interval5D": + def to_interval5d(self, offset: Point5D = Point5D.zero()) -> "Interval5D": return Interval5D.create_from_start_stop(offset, self + offset) @classmethod @@ -275,86 +262,8 @@ def from_point(cls: Type[PT], point: Point5D) -> PT: return cls(**{k: v or 1 for k, v in point.to_dict().items()}) -INTERVALABLE = Union["Interval", int, None, Tuple[Optional[int], Optional[int]]] - - -class Interval: - """A contiguous interval in space of indicies between start (inclusive) and stop (exclusive)""" - - def __init__(self, start: int = 0, stop: Optional[int] = None): - self.start = start - self.stop = stop - assert self.stop >= self.start - - @classmethod - def create(cls, value: INTERVALABLE) -> "Interval": - if isinstance(value, int): - return Interval(value, value + 1) - if isinstance(value, Interval): - return value - if isinstance(value, tuple): - return Interval(value[0] or 0, value[1]) - return Interval.all() - - def __eq__(self, other: INTERVALABLE) -> bool: - other_interval = Interval.create(other) - return self.start == other_interval.start and self.stop == other_interval.stop - - def __hash__(self) -> int: - return hash((self.start, self.stop)) - - @classmethod - def all(cls) -> "Interval": - return cls() - - @classmethod - def zero(cls) -> "Interval": - return cls(start=0, stop=1) - - def to_slice(self) -> slice: - return slice(self.start, self.stop) - - def is_defined(self) -> bool: - return self.stop != None - - def defined_with(self, limit: "Interval") -> "Interval": - assert limit.is_defined() - return Interval(self.start, self.stop if self.stop != INF else limit.stop) - - def contains(self, other: "Interval") -> bool: - if self.stop == None or other.stop == None: - return False - return self.start <= other.start and self.stop >= other.stop - - def split(self, step: int, clamp: bool = True) -> Iterable["Interval"]: - start = self.start - while self.stop == None or start < self.stop: - stop = start + step - piece = Interval(start, stop) - if clamp: - piece = piece.clamped(self) - yield piece - start = stop - - def get_tiles(self, tile_side: int, clamp: bool) -> Iterable["Interval"]: - start = (self.start // tile_side) * tile_side - return Interval(start, self.stop).split(tile_side, clamp=clamp) - - def clamped(self, limits: "Interval") -> "Interval": - if limits.stop == None: - stop = self.stop - elif self.stop == None: - stop = limits.stop - else: - stop = min(self.stop, limits.stop) - return Interval(max(self.start, limits.start), stop) - - def enlarged(self, radius: int) -> "Interval": - return Interval(self.start - radius, None if self.stop == None else self.stop + radius) - - def translated(self, offset: int) -> "Interval": - return Interval(self.start + offset, None if self.stop == None else self.stop + offset) - +INTERVAL = Tuple[int, int] +SPAN = Union[int, INTERVAL] INTERVAL_5D = TypeVar("INTERVAL_5D", bound="Interval5D", covariant=True) @@ -362,84 +271,43 @@ def translated(self, offset: int) -> "Interval": class Interval5D(JsonSerializable): """A labeled 5D interval""" - def __init__( - self, - *, - t: INTERVALABLE = Interval.all(), - c: INTERVALABLE = Interval.all(), - x: INTERVALABLE = Interval.all(), - y: INTERVALABLE = Interval.all(), - z: INTERVALABLE = Interval.all(), - ): - self.x = Interval.create(x) - self.y = Interval.create(y) - self.z = Interval.create(z) - self.t = Interval.create(t) - self.c = Interval.create(c) - self.start = Point5D(x=self.x.start, y=self.y.start, z=self.z.start, t=self.t.start, c=self.c.start) - - def get_stop(self) -> Optional[Point5D]: - x = self.x.stop - y = self.y.stop - z = self.z.stop - t = self.t.stop - c = self.c.stop - if x is None or y is None or z is None or t is None or c is None: - return None - return Point5D(x=x, y=y, z=z, t=t, c=c) + def __init__(self, *, t: SPAN, c: SPAN, x: SPAN, y: SPAN, z: SPAN): + self.x = (x, x + 1) if isinstance(x, int) else x + self.y = (y, y + 1) if isinstance(y, int) else y + self.z = (z, z + 1) if isinstance(z, int) else z + self.t = (t, t + 1) if isinstance(t, int) else t + self.c = (c, c + 1) if isinstance(c, int) else c + if any(interval[0] > interval[1] for interval in (self.x, self.y, self.z, self.t, self.c)): + raise ValueError(f"Intervals must have start <= stop") + self.start = Point5D(x=self.x[0], y=self.y[0], z=self.z[0], t=self.t[0], c=self.c[0]) + self.stop = Point5D(x=self.x[1], y=self.y[1], z=self.z[1], t=self.t[1], c=self.c[1]) @staticmethod - def zero( - *, - t: INTERVALABLE = Interval.zero(), - c: INTERVALABLE = Interval.zero(), - x: INTERVALABLE = Interval.zero(), - y: INTERVALABLE = Interval.zero(), - z: INTERVALABLE = Interval.zero(), - ) -> "Interval5D": + def zero(*, t: SPAN = 0, c: SPAN = 0, x: SPAN = 0, y: SPAN = 0, z: SPAN = 0) -> "Interval5D": """Creates a slice with coords defaulting to slice(0, 1), except where otherwise specified""" return Interval5D(t=t, c=c, x=x, y=y, z=z) def relabeled(self: INTERVAL_5D, keymap: KeyMap) -> INTERVAL_5D: params = {target_key: self[src_key] for src_key, target_key in keymap.items()} - return self.with_coord(**params) + return self.updated(**params) def __eq__(self, other: object) -> bool: if not isinstance(other, Interval5D): return False - return self.to_tuple(Point5D.LABELS) == other.to_tuple(Point5D.LABELS) + return self.start == other.start and self.stop == other.stop def __hash__(self) -> int: return hash(self.to_tuple(Point5D.LABELS)) def contains(self, other: "Interval5D") -> bool: - return all(self[k].contains(other[k]) for k in Point5D.LABELS) - - def is_defined(self) -> bool: - return all(i.is_defined() for i in self.to_tuple(Point5D.LABELS)) - - def defined_with(self: INTERVAL_5D, limits: Union[Shape5D, "Interval5D"]) -> INTERVAL_5D: - """Interval5D can have intervals which are open to interpretation, like Interval(0, None). This method - forces those slices expand into their interpretation within the boundaries of 'limits'""" - limits_interval = limits if isinstance(limits, Interval5D) else limits.to_slice_5d() - return self.with_coord(**{k: self[k].defined_with(limits_interval[k]) for k in Point5D.LABELS}) + return self.start <= other.start and self.stop >= other.stop - def to_dict(self) -> Dict[str, Interval]: + def to_dict(self) -> Dict[str, INTERVAL]: return {k: self[k] for k in Point5D.LABELS} - @staticmethod - def all( - t: Interval = Interval.all(), - c: Interval = Interval.all(), - x: Interval = Interval.all(), - y: Interval = Interval.all(), - z: Interval = Interval.all(), - ) -> "Interval5D": - return Interval5D(t=t, c=c, x=x, y=y, z=z) - @classmethod - def make_intervals(cls, start: Point5D, stop: Point5D) -> Dict[str, Interval]: - return {k: Interval(int(start[k]), int(stop[k])) for k in Point5D.LABELS} + def make_intervals(cls, start: Point5D, stop: Point5D) -> Dict[str, INTERVAL]: + return {k: (start[k], stop[k]) for k in Point5D.LABELS} @staticmethod def create_from_start_stop(start: Point5D, stop: Point5D) -> "Interval5D": @@ -452,26 +320,37 @@ def from_json_data(data: dict, dereferencer: Optional[Dereferencer] = None) -> " return Interval5D.create_from_start_stop(start, stop) def to_json_data(self, referencer: Referencer = lambda obj: None) -> dict: - self_tuple = self.to_tuple(Point5D.LABELS) - return {"start": self_tuple[0], "stop": self_tuple[1]} + return {"start": self.start.to_tuple(Point5D.LABELS), "stop": self.stop.to_tuple(Point5D.LABELS)} def from_start_stop(self: INTERVAL_5D, start: Point5D, stop: Point5D) -> INTERVAL_5D: slices = self.make_intervals(start, stop) - return self.with_coord(**slices) + return self.updated(**slices) + + def _ranges(self, block_shape: Shape5D) -> Iterator[List[int]]: + starts = self.start.to_np(Point5D.LABELS) + ends = self.stop.to_np(Point5D.LABELS) + steps = block_shape.to_np(Point5D.LABELS) + for start, end, step in zip(starts, ends, steps): + yield list(np.arange(start, end, step)) def split(self: INTERVAL_5D, block_shape: Shape5D) -> Iterator[INTERVAL_5D]: """Splits self into multiple Interval5D instances, starting from self.start. Every piece shall have shape == block_shape excedpt for the last one, which will be clamped to self.stop""" + for begin_tuple in itertools.product(*self._ranges(block_shape)): + start = Point5D.from_tuple(begin_tuple, Point5D.LABELS) + stop = (start + block_shape).clamped(maximum=self.stop) + yield self.from_start_stop(start, stop) - yield from itertools.product([self[k].split(int(block_shape[k])) for k in Point5D.LABELS]) - - def get_tiles(self: INTERVAL_5D, tile_shape: Shape5D, clamp: bool) -> Iterator[INTERVAL_5D]: + def get_tiles(self: INTERVAL_5D, tile_shape: Shape5D) -> Iterator[INTERVAL_5D]: """Gets all tiles that would cover the entirety of self. Tiles that overflow self can be clamped by setting `clamp` to True""" + start = (self.start // tile_shape) * tile_shape + tile_shape_raw = tile_shape.to_np(Point5D.LABELS) + stop_raw = np.ceil(self.stop.to_np(Point5D.LABELS) / tile_shape_raw) * tile_shape_raw + stop = Point5D.from_np(stop_raw, labels=Point5D.LABELS) + yield from self.from_start_stop(start, stop).split(tile_shape) - yield from itertools.product([self[k].get_tiles(int(tile_shape[k]), clamp=clamp) for k in Point5D.LABELS]) - - def __getitem__(self, key: str) -> Interval: + def __getitem__(self, key: str) -> INTERVAL: if key == "x": return self.x if key == "y": @@ -485,14 +364,14 @@ def __getitem__(self, key: str) -> Interval: raise KeyError(key) # override this in subclasses so that it returns an instance of self.__class__ - def with_coord( + def updated( self: INTERVAL_5D, *, - t: INTERVALABLE = None, - c: INTERVALABLE = None, - x: INTERVALABLE = None, - y: INTERVALABLE = None, - z: INTERVALABLE = None, + t: Optional[SPAN] = None, + c: Optional[SPAN] = None, + x: Optional[SPAN] = None, + y: Optional[SPAN] = None, + z: Optional[SPAN] = None, ) -> INTERVAL_5D: return self.__class__( t=self.t if t is None else t, @@ -502,38 +381,56 @@ def with_coord( z=self.z if z is None else z, ) - def with_full_c(self: INTERVAL_5D) -> INTERVAL_5D: - return self.with_coord(c=Interval.all()) + @property + def shape(self) -> Shape5D: + return Shape5D(**(self.stop - self.start).to_dict()) - def clamped(self: INTERVAL_5D, roi: Union[Shape5D, "Interval5D"]) -> INTERVAL_5D: - interv = roi if isinstance(roi, Interval5D) else roi.to_slice_5d() - return self.with_coord(**{k: self[k].clamped(interv[k]) for k in Point5D.LABELS}) + def clamped( + self: INTERVAL_5D, + limits: Union[Shape5D, "Interval5D", None] = None, + *, + x: Optional[SPAN] = None, + y: Optional[SPAN] = None, + z: Optional[SPAN] = None, + t: Optional[SPAN] = None, + c: Optional[SPAN] = None, + ) -> INTERVAL_5D: + limits = limits or self + limits_interval = limits if isinstance(limits, Interval5D) else limits.to_interval5d() + updated_limits = limits_interval.updated(x=x, y=y, z=z, t=t, c=c) + return self.from_start_stop( + self.start.clamped(updated_limits.start, updated_limits.stop), + self.stop.clamped(updated_limits.start, updated_limits.stop), + ) def enlarged(self: INTERVAL_5D, radius: Point5D) -> INTERVAL_5D: - return self.with_coord(**{k: self[k].enlarged(int(radius[k])) for k in Point5D.LABELS}) + return self.from_start_stop(self.start - radius, self.stop + radius) def translated(self: INTERVAL_5D, offset: Point5D) -> INTERVAL_5D: - return self.with_coord(**{k: self[k].translated(offset[k]) for k in Point5D.LABELS}) + return self.from_start_stop(self.start + offset, self.stop + offset) def to_slices(self, axis_order: str = Point5D.LABELS) -> Tuple[slice, ...]: - return tuple(self[axis].to_slice() for axis in axis_order) + return tuple(slice(self[k][0], self[k][1]) for k in axis_order) - def to_tuple(self, axis_order: str) -> Tuple[Interval, ...]: + def to_tuple(self, axis_order: str) -> Tuple[INTERVAL, ...]: return tuple(self[k] for k in axis_order) - def to_start_stop(self, axis_order: str) -> Tuple[Tuple[Optional[int], ...], Tuple[Optional[int], ...]]: - start = tuple(self[k].start for k in axis_order) - stop = tuple(self[k].stop for k in axis_order) - return (start, stop) + def to_start_stop_tuple(self, axis_order: str) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: + return (self.start.to_tuple(axis_order), self.stop.to_tuple(axis_order)) def to_ilastik_cutout_subregion(self, axis_order: str) -> str: - return str(list(self.to_start_stop(axis_order=axis_order))) + return str(list(self.to_start_stop_tuple(axis_order=axis_order))) def __repr__(self) -> str: - interval_reprs = ", ".join( - f"{k}:{self[k].start}_{self[k].stop}" for k in Point5D.LABELS if self[k] != Interval.all() - ) - return f"{self.__class__.__name__}({interval_reprs})" + reprs: List[str] = [] + for k, span in self.to_dict().items(): + if span[1] - span[0] == 1: + if span[0] != 0: + reprs.append(f"{k}:{span[0]}") + else: + reprs.append(f"{k}:{span[0]}_{span[1]}") + spans = ", ".join(reprs) + return self.__class__.__name__ + f"({spans})" def get_borders(self: INTERVAL_5D, thickness: Shape5D) -> Iterable[INTERVAL_5D]: """Returns subslices of self, such that these subslices are at the borders @@ -544,23 +441,18 @@ def get_borders(self: INTERVAL_5D, thickness: Shape5D) -> Iterable[INTERVAL_5D]: slc.get_borders(Interval5D.zero(x=1, y=1)) will produce 4 borders (left, right, top, bottom) If, for any axis, thickness[axis] == self.shape[axis], then there will be duplicated borders in the output """ - thickness_interval = thickness.to_slice_5d(offset=self.start) - assert all(self[k].contains(thickness_interval[k]) for k in Point5D.LABELS) + thickness_interval = thickness.to_interval5d(offset=self.start) + if not self.contains(thickness_interval): + raise ValueError(f"Bad thickness {thickness} for interval {self}") # FIXME: I haven't ported this yet!!!!! for axis, axis_thickness in thickness.to_dict().items(): if axis_thickness == 0: continue - slc = self[axis] - yield self.with_coord(**{axis: slice(slc.start, slc.start + axis_thickness)}) - yield self.with_coord(**{axis: slice(slc.stop - axis_thickness, slc.stop)}) - - def mod_tile(self: INTERVAL_5D, tile_shape: Shape5D) -> INTERVAL_5D: - assert self.shape <= tile_shape - offset = self.start - (self.start % tile_shape) - return self.from_start_stop(self.start - offset, self.stop - offset) + span = self[axis] + yield self.updated(**{axis: (span[0], span[0] + axis_thickness)}) + yield self.updated(**{axis: (span[1] - axis_thickness, span[1])}) def get_neighboring_tiles(self: INTERVAL_5D, tile_shape: Shape5D) -> Iterator[INTERVAL_5D]: - assert self.shape <= tile_shape for axis in Point5D.LABELS: for axis_offset in (tile_shape[axis], -tile_shape[axis]): offset = Point5D.zero(**{axis: axis_offset}) @@ -569,7 +461,8 @@ def get_neighboring_tiles(self: INTERVAL_5D, tile_shape: Shape5D) -> Iterator[IN def get_neighbor_tile_adjacent_to( self: INTERVAL_5D, *, anchor: "Interval5D", tile_shape: Shape5D ) -> Optional[INTERVAL_5D]: - assert self.contains(anchor) + if not self.contains(anchor): + raise ValueError(f"Anchor {anchor} is not contained within {self}") direction_axis: Optional[str] = None for axis in Point5D.LABELS: @@ -584,11 +477,11 @@ def get_neighbor_tile_adjacent_to( # a neighbor has all but one coords equal offset = Point5D.zero(**{direction_axis: tile_shape[direction_axis]}) - if anchor[direction_axis].stop == self[direction_axis].stop: + if anchor[direction_axis][1] == self[direction_axis][1]: if self.shape != tile_shape: # Getting a further tile from a partial tile return None return self.translated(offset) - if anchor[direction_axis].start == self[direction_axis].start: + if anchor[direction_axis][0] == self[direction_axis][0]: if self.start - offset < Point5D.zero(): # no negative neighbors return None return self.translated(-offset) diff --git a/tests/test_array5D.py b/tests/test_array5D.py index 0aaee86..2e4b2bf 100644 --- a/tests/test_array5D.py +++ b/tests/test_array5D.py @@ -1,4 +1,4 @@ -from ndstructs import Point5D, Shape5D, Slice5D, Array5D +from ndstructs import Point5D, Shape5D, Interval5D, Array5D, All, ScalarData, StaticLine import numpy @@ -9,11 +9,11 @@ def test_creation(): def test_allocation(): - arr = Array5D.allocate(Slice5D.zero(x=slice(100, 200), y=slice(200, 300)), numpy.uint8) + arr = Array5D.allocate(Interval5D.zero(x=(100, 200), y=(200, 300)), numpy.uint8) assert arr.shape == Shape5D(x=100, y=100) assert arr.location == Point5D.zero(x=100, y=200) - arr = Array5D.allocate(Slice5D.zero(x=slice(-100, 200), y=slice(200, 300)), numpy.uint8) + arr = Array5D.allocate(Interval5D.zero(x=(-100, 200), y=(200, 300)), numpy.uint8) assert arr.shape == Shape5D(x=300, y=100) assert arr.location == Point5D.zero(x=-100, y=200) @@ -156,16 +156,42 @@ def test_cut(): ]) # fmt: on arr = Array5D(raw, "zy") - piece = arr.cut(Slice5D(y=slice(1, 3))) + piece = arr.cut(y=(1, 3)) assert (piece.raw("zy") == expected_piece).all() assert piece.location == Point5D.zero(y=1) + assert piece.shape == Shape5D(y=2, z=4) - global_sub_piece = piece.cut(Slice5D(y=2)) + global_sub_piece = piece.cut(y=2) assert (global_sub_piece.raw("zy") == expected_global_sub_piece).all() - local_sub_piece = piece.local_cut(Slice5D(y=1)) + local_sub_piece = piece.local_cut(y=1) assert (local_sub_piece.raw("zy") == global_sub_piece.raw("zy")).all() + slice_z2_y2 = arr.cut(Interval5D.zero(z=2, y=(2, 4))) + assert (slice_z2_y2.raw("zy") == numpy.asarray([[13, 14]])).all() + + slice_z0_2__yall = arr.cut(Interval5D.zero(z=(0, 2), y=123456), y=All()) + assert (slice_z0_2__yall.raw("zy") == numpy.asarray([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]])).all() + + +def test_local_cut(): + # fmt: off + data = Array5D(numpy.asarray([ + [1, 2, 3, 4, 5], + [6, 7, 8, 9, 10], + [11, 12, 13, 14, 15], + [16, 17, 18, 19, 20], + ]), axiskeys="yx", location=Point5D.zero(y=1, x=123)) + # fmt: on + + piece = data.cut(Interval5D.zero(y=2), x=All()) + assert (piece.raw("yx") == numpy.asarray([[6, 7, 8, 9, 10]])).all() + assert piece.location == Point5D.zero(y=2, x=123) + + local_piece = data.local_cut(Interval5D.zero(y=1), x=All()) + assert (local_piece.raw("yx") == numpy.asarray([[6, 7, 8, 9, 10]])).all() + assert local_piece.location == Point5D.zero(y=2, x=123) + def test_setting_rois(): # fmt: off @@ -242,7 +268,7 @@ def test_clamping(): ]) # fmt: on arr = Array5D(raw, "zyx") - clamped_raw = arr.clamped(Slice5D(z=1, x=slice(1, 4), y=slice(1, 3))).raw("zyx") + clamped_raw = arr.clamped(z=1, x=(1, 4), y=(1, 3)).raw("zyx") assert (clamped_raw == expected_clamped_array).all() @@ -265,7 +291,7 @@ def test_sample_channels(): [13, 23, 33, 43, 53]], ]), "cyx") - mask = Array5D(numpy.asarray([ + mask = ScalarData(numpy.asarray([ [1, 1, 1, 0, 0], [0, 0, 1, 0, 0], [0, 0, 1, 0, 0], @@ -408,7 +434,7 @@ def test_color_filter(): [ 17, 27, 37, 47]] ]), axiskeys="cyx") - color = Array5D(numpy.asarray([100, 200]), axiskeys="c") + color = StaticLine(numpy.asarray([100, 200]), axiskeys="c") expected_color_filtered = Array5D(numpy.asarray([ [[100, 0, 0, 100], @@ -561,12 +587,11 @@ def test_from_stack(): z_stacked = Array5D.from_stack(stack, stack_along="z") for i in range(len(stack)): - assert (z_stacked.cut(Slice5D(z=i)).raw("yx") == stack[i].raw("yx")).all() + assert (z_stacked.cut(z=i).raw("yx") == stack[i].raw("yx")).all() y_stacked = Array5D.from_stack(stack, stack_along="y") for i in range(len(stack)): - stack_slc = Slice5D(y=slice(3 * i, 3 * (i + 1))) - assert (y_stacked.cut(stack_slc).raw("yx") == stack[i].raw("yx")).all() + assert (y_stacked.cut(y=(3 * i, 3 * (i + 1))).raw("yx") == stack[i].raw("yx")).all() def test_combine(): diff --git a/tests/test_point5D.py b/tests/test_point5D.py index eb75d6f..0433d6c 100644 --- a/tests/test_point5D.py +++ b/tests/test_point5D.py @@ -3,6 +3,10 @@ import pytest +inf = Point5D.one() * 999999 +ninf = Point5D.one() * -999999 + + def test_labeled_coords_constructor_property_assignment(): p = Point5D(x=1, y=2, z=3, t=4, c=5) assert p.x == 1 @@ -30,30 +34,6 @@ def test_one_factory_method_defaults_coords_to_one(): assert p.c == 456 -def test_inf_factory_method_defaults_coords_to_inf(): - p = Point5D.inf(c=123, y=456) - assert p.x == Point5D.INF - assert p.y == 456 - assert p.z == Point5D.INF - assert p.t == Point5D.INF - assert p.c == 123 - - -def test_ininf_factory_method_defaults_coords_to_ninf(): - p = Point5D.ninf(c=123, y=456) - assert p.x == Point5D.NINF - assert p.y == 456 - assert p.z == Point5D.NINF - assert p.t == Point5D.NINF - assert p.c == 123 - - -def test_as_ceil_factory(): - raw = numpy.asarray([1.1, 2.2, 3.3, 4.0, 5.0]) - p = Point5D.as_ceil(raw, "xyztc") - assert p == Point5D(x=2, y=3, z=4, t=4, c=5) - - def test_to_tuple_respects_given_axis_order(): p = Point5D(x=1, y=2, z=3, t=4, c=5) assert p.to_tuple("xyz") == (1, 2, 3) == (p.x, p.y, p.z) @@ -67,17 +47,17 @@ def test_to_dict_consistent(): assert p.to_dict() == {"x": 1, "y": 2, "z": 3, "t": 4, "c": 5} -def test_with_coord_modifies_coords_and_keeps_original_intact(): +def test_updated_modifies_coords_and_keeps_original_intact(): p = Point5D(x=1, y=2, z=3, t=4, c=5) - assert p.with_coord(z=1000).z == 1000 - assert p.with_coord(x=99, y=88, c=77).to_tuple("xyctz") == (99, 88, 77, 4, 3) + assert p.updated(z=1000).z == 1000 + assert p.updated(x=99, y=88, c=77).to_tuple("xyctz") == (99, 88, 77, 4, 3) assert p.to_tuple("xyztc") == (1, 2, 3, 4, 5) def test_clamped_keeps_values_within_limits(): p = Point5D(x=100, y=200, z=300, t=400, c=500) - assert p.clamped(maximum=Point5D.inf(y=50, c=600)).to_tuple("yc") == (50, 500) - assert p.clamped(minimum=Point5D.ninf(y=300, x=90)).to_tuple("yx") == (300, 100) + assert p.clamped(maximum=inf.updated(y=50, c=600)).to_tuple("yc") == (50, 500) + assert p.clamped(minimum=ninf.updated(y=300, x=90)).to_tuple("yx") == (300, 100) min_pt = Point5D(x=10, y=20, z=30, t=40, c=1000) assert p.clamped(minimum=min_pt).to_tuple("xyztc") == (100, 200, 300, 400, 1000) @@ -85,7 +65,7 @@ def test_clamped_keeps_values_within_limits(): max_pt = Point5D(x=1, y=2, z=3, t=4, c=1000) assert p.clamped(maximum=max_pt).to_tuple("xyztc") == (1, 2, 3, 4, 500) - clamped_pt = p.clamped(minimum=Point5D.ninf(x=20, t=50), maximum=Point5D.inf(x=120, t=500)) + clamped_pt = p.clamped(minimum=ninf.updated(x=20, t=50), maximum=inf.updated(x=120, t=500)) assert clamped_pt.to_tuple("xt") == (100, 400) @@ -98,7 +78,6 @@ def test_point_equality(): def test_point_arithmetic(): p = Point5D(x=100, y=200, z=300, t=400, c=500) assert p + Point5D.zero(x=100) == Point5D(x=200, y=200, z=300, t=400, c=500) - assert p + Point5D.inf(x=100) == Point5D.inf(x=200) assert p + Point5D(x=1, y=2, z=3, t=4, c=5) == Point5D(x=101, y=202, z=303, t=404, c=505) other = Point5D(x=1, y=2, z=3, t=4, c=5) diff --git a/tests/test_slice5D.py b/tests/test_slice5D.py index 672173e..4a6bdd0 100644 --- a/tests/test_slice5D.py +++ b/tests/test_slice5D.py @@ -1,96 +1,81 @@ -from ndstructs import Point5D, Shape5D, Slice5D, KeyMap +from ndstructs import Point5D, Shape5D, Interval5D, KeyMap import numpy import pytest -def test_all_constructor(): - slc = Slice5D.all(x=3, z=slice(10, 20)) - assert slc.to_slices("xyztc") == (slice(3, 4), slice(None), slice(10, 20), slice(None), slice(None)) - - def test_from_start_stop(): start = Point5D(x=10, y=20, z=30, t=40, c=50) stop = start + 10 - slc = Slice5D.create_from_start_stop(start, stop) - assert slc == Slice5D(x=slice(10, 20), y=slice(20, 30), z=slice(30, 40), t=slice(40, 50), c=slice(50, 60)) + slc = Interval5D.create_from_start_stop(start, stop) + assert slc == Interval5D(x=(10, 20), y=(20, 30), z=(30, 40), t=(40, 50), c=(50, 60)) def test_slice_translation(): - slc = Slice5D(x=slice(10, 100), y=slice(20, 200)) + slc = Interval5D.zero(x=(10, 100), y=(20, 200)) translated_slc = slc.translated(Point5D(x=1, y=2, z=3, t=4, c=5)) - assert translated_slc == Slice5D(x=slice(11, 101), y=slice(22, 202)) + assert translated_slc == Interval5D(x=(11, 101), y=(22, 202), z=(3, 4), t=(4, 5), c=(5, 6)) - slc = Slice5D(x=slice(10, 100), y=slice(20, 200), z=0, t=0, c=0) + slc = Interval5D(x=(10, 100), y=(20, 200), z=0, t=0, c=0) translated_slc = slc.translated(Point5D(x=-1, y=-2, z=-3, t=-4, c=-5000)) - assert translated_slc == Slice5D( - x=slice(9, 99), y=slice(18, 198), z=slice(-3, -2), t=slice(-4, -3), c=slice(-5000, -4999) - ) + assert translated_slc == Interval5D(x=(9, 99), y=(18, 198), z=(-3, -2), t=(-4, -3), c=(-5000, -4999)) def test_slice_enlarge(): - slc = Slice5D(x=slice(10, 100), y=slice(20, 200)) + slc = Interval5D.zero(x=(10, 100), y=(20, 200)) enlarged = slc.enlarged(radius=Point5D(x=1, y=2, z=3, t=4, c=5)) - assert enlarged == Slice5D(x=slice(9, 101), y=slice(18, 202)) + assert enlarged == Interval5D(x=(9, 101), y=(18, 202), z=(-3, 4), t=(-4, 5), c=(-5, 6)) - slc = Slice5D(x=slice(10, 100), y=slice(20, 200), z=0, t=0, c=0) - enlarged = slc.enlarged(radius=Point5D(x=1, y=2, z=3, t=4, c=5)) - assert enlarged == Slice5D(x=slice(9, 101), y=slice(18, 202), z=slice(-3, 4), t=slice(-4, 5), c=slice(-5, 6)) + slc2 = Interval5D(x=(10, 100), y=(20, 200), z=0, t=0, c=0) + enlarged2 = slc2.enlarged(radius=Point5D(x=1, y=2, z=3, t=4, c=5)) + assert enlarged2 == Interval5D(x=(9, 101), y=(18, 202), z=(-3, 4), t=(-4, 5), c=(-5, 6)) def test_slice_contains_smaller_slice(): - outer_slice = Slice5D(x=slice(10, 100), y=slice(20, 200)) - inner_slice = Slice5D(x=slice(20, 50), y=slice(30, 40), z=0, t=0, c=0) + outer_slice = Interval5D.zero(x=(10, 100), y=(20, 200)) + inner_slice = Interval5D(x=(20, 50), y=(30, 40), z=0, t=0, c=0) assert outer_slice.contains(inner_slice) def test_slice_does_not_contain_translated_slice(): - slc = Slice5D(x=slice(10, 100), y=slice(20, 200), z=0, t=0, c=0) + slc = Interval5D(x=(10, 100), y=(20, 200), z=0, t=0, c=0) translated_slc = slc.translated(Point5D.zero(x=10)) assert not slc.contains(translated_slc) def test_slice_clamp(): - outer = Slice5D(x=slice(10, 100), y=slice(20, 200)) - inner = Slice5D(x=slice(20, 50), y=slice(30, 40), z=0, t=0, c=0) + outer = Interval5D.zero(x=(10, 100), y=(20, 200)) + inner = Interval5D.zero(x=(20, 50), y=(30, 40)) assert outer.clamped(inner) == inner assert inner.clamped(outer) == inner - intersecting_outer = Slice5D(x=slice(50, 200), y=slice(30, 900)) - assert intersecting_outer.clamped(outer) == Slice5D(x=slice(50, 100), y=slice(30, 200)) - - intersecting_outer = Slice5D(x=slice(-100, 50), y=slice(10, 100)) - assert intersecting_outer.clamped(outer) == Slice5D(x=slice(10, 50), y=slice(20, 100)) + intersecting_outer = Interval5D.zero(x=(50, 200), y=(30, 900)) + assert intersecting_outer.clamped(outer) == Interval5D.zero(x=(50, 100), y=(30, 200)) - outside_outer = Slice5D(x=slice(200, 300), y=slice(400, 500)) - assert outside_outer.clamped(outer).defined_with(Shape5D()).shape.volume == 0 + intersecting_outer = Interval5D.zero(x=(-100, 50), y=(10, 100)) + assert intersecting_outer.clamped(outer) == Interval5D.zero(x=(10, 50), y=(20, 100)) - -def test_slice_defined_with(): - slc = Slice5D(x=slice(10, 20)) - - assert slc.defined_with(Shape5D(x=100, y=15, z=17)) == Slice5D.zero(x=slice(10, 20), y=slice(0, 15), z=slice(0, 17)) - - assert slc.defined_with(Slice5D.zero(x=slice(1, 3), y=slice(10, 20))) == Slice5D.zero( - x=slice(10, 20), y=slice(10, 20) - ) + outside_outer = Interval5D.zero(x=(200, 300), y=(400, 500)) + a = outside_outer.clamped(outer) + vol = a.shape.volume + assert a.shape.volume == 0 def test_to_slices(): - slc = Slice5D(x=1, y=2, z=slice(10, 20)) - assert slc.to_slices("xyztc") == (slice(1, 2), slice(2, 3), slice(10, 20), slice(None), slice(None)) - assert slc.to_slices("ytzcx") == (slice(2, 3), slice(None), slice(10, 20), slice(None), slice(1, 2)) + slc = Interval5D.zero(x=1, y=2, z=(10, 20)) + assert slc.to_slices("xyztc") == (slice(1, 2), slice(2, 3), slice(10, 20), slice(0, 1), slice(0, 1)) + assert slc.to_slices("ytzcx") == (slice(2, 3), slice(0, 1), slice(10, 20), slice(0, 1), slice(1, 2)) def test_with_coord(): - slc = Slice5D(x=0, y=1, z=2, t=3, c=4) - assert slc.with_coord(z=slice(10, 20)).to_slices("xyztc") == ( + slc = Interval5D(x=0, y=1, z=2, t=3, c=4) + assert slc.updated(z=(10, 20)).to_slices("xyztc") == ( slice(0, 1), slice(1, 2), slice(10, 20), slice(3, 4), slice(4, 5), ) - assert slc.with_coord(x=123).to_slices("xyztc") == ( + assert slc.updated(x=123).to_slices("xyztc") == ( slice(123, 124), slice(1, 2), slice(2, 3), @@ -100,136 +85,132 @@ def test_with_coord(): def test_split_when_slice_is_multiple_of_block_shape(): - slc = Slice5D.zero(x=slice(100, 200), y=slice(200, 300)) + slc = Interval5D.zero(x=(100, 200), y=(200, 300)) pieces = list(slc.split(Shape5D(x=50, y=50))) - assert Slice5D.zero(x=slice(100, 150), y=slice(200, 250)) in pieces - assert Slice5D.zero(x=slice(100, 150), y=slice(250, 300)) in pieces - assert Slice5D.zero(x=slice(150, 200), y=slice(200, 250)) in pieces - assert Slice5D.zero(x=slice(150, 200), y=slice(250, 300)) in pieces + assert Interval5D.zero(x=(100, 150), y=(200, 250)) in pieces + assert Interval5D.zero(x=(100, 150), y=(250, 300)) in pieces + assert Interval5D.zero(x=(150, 200), y=(200, 250)) in pieces + assert Interval5D.zero(x=(150, 200), y=(250, 300)) in pieces assert len(pieces) == 4 def test_split_when_slice_is_NOT_multiple_of_block_shape(): - slc = Slice5D.zero(x=slice(100, 210), y=slice(200, 320)) + slc = Interval5D.zero(x=(100, 210), y=(200, 320)) pieces = list(slc.split(Shape5D(x=50, y=50))) - assert Slice5D.zero(x=slice(100, 150), y=slice(200, 250)) in pieces - assert Slice5D.zero(x=slice(100, 150), y=slice(250, 300)) in pieces - assert Slice5D.zero(x=slice(100, 150), y=slice(300, 320)) in pieces + assert Interval5D.zero(x=(100, 150), y=(200, 250)) in pieces + assert Interval5D.zero(x=(100, 150), y=(250, 300)) in pieces + assert Interval5D.zero(x=(100, 150), y=(300, 320)) in pieces - assert Slice5D.zero(x=slice(150, 200), y=slice(200, 250)) in pieces - assert Slice5D.zero(x=slice(150, 200), y=slice(250, 300)) in pieces - assert Slice5D.zero(x=slice(150, 200), y=slice(300, 320)) in pieces + assert Interval5D.zero(x=(150, 200), y=(200, 250)) in pieces + assert Interval5D.zero(x=(150, 200), y=(250, 300)) in pieces + assert Interval5D.zero(x=(150, 200), y=(300, 320)) in pieces - assert Slice5D.zero(x=slice(200, 210), y=slice(200, 250)) in pieces - assert Slice5D.zero(x=slice(200, 210), y=slice(250, 300)) in pieces - assert Slice5D.zero(x=slice(200, 210), y=slice(300, 320)) in pieces + assert Interval5D.zero(x=(200, 210), y=(200, 250)) in pieces + assert Interval5D.zero(x=(200, 210), y=(250, 300)) in pieces + assert Interval5D.zero(x=(200, 210), y=(300, 320)) in pieces assert len(pieces) == 9 def test_get_tiles_when_slice_is_multiple_of_tile(): - slc = Slice5D.zero(x=slice(100, 200), y=slice(200, 300)) + slc = Interval5D.zero(x=(100, 200), y=(200, 300)) tiles = list(slc.get_tiles(Shape5D(x=50, y=50))) - assert Slice5D.zero(x=slice(100, 150), y=slice(200, 250)) in tiles - assert Slice5D.zero(x=slice(100, 150), y=slice(250, 300)) in tiles - assert Slice5D.zero(x=slice(150, 200), y=slice(200, 250)) in tiles - assert Slice5D.zero(x=slice(150, 200), y=slice(250, 300)) in tiles + assert Interval5D.zero(x=(100, 150), y=(200, 250)) in tiles + assert Interval5D.zero(x=(100, 150), y=(250, 300)) in tiles + assert Interval5D.zero(x=(150, 200), y=(200, 250)) in tiles + assert Interval5D.zero(x=(150, 200), y=(250, 300)) in tiles assert len(tiles) == 4 def test_get_tiles_when_slice_is_NOT_multiple_of_tile(): - slc = Slice5D.zero(x=slice(90, 210), y=slice(200, 320), z=slice(10, 20)) + slc = Interval5D.zero(x=(90, 210), y=(200, 320), z=(10, 20)) pieces = list(slc.get_tiles(Shape5D(x=50, y=50, z=10))) - assert Slice5D.zero(x=slice(50, 100), y=slice(200, 250), z=slice(10, 20)) in pieces - assert Slice5D.zero(x=slice(50, 100), y=slice(250, 300), z=slice(10, 20)) in pieces - assert Slice5D.zero(x=slice(50, 100), y=slice(300, 350), z=slice(10, 20)) in pieces + assert Interval5D.zero(x=(50, 100), y=(200, 250), z=(10, 20)) in pieces + assert Interval5D.zero(x=(50, 100), y=(250, 300), z=(10, 20)) in pieces + assert Interval5D.zero(x=(50, 100), y=(300, 350), z=(10, 20)) in pieces - assert Slice5D.zero(x=slice(100, 150), y=slice(200, 250), z=slice(10, 20)) in pieces - assert Slice5D.zero(x=slice(100, 150), y=slice(250, 300), z=slice(10, 20)) in pieces - assert Slice5D.zero(x=slice(100, 150), y=slice(300, 350), z=slice(10, 20)) in pieces + assert Interval5D.zero(x=(100, 150), y=(200, 250), z=(10, 20)) in pieces + assert Interval5D.zero(x=(100, 150), y=(250, 300), z=(10, 20)) in pieces + assert Interval5D.zero(x=(100, 150), y=(300, 350), z=(10, 20)) in pieces - assert Slice5D.zero(x=slice(150, 200), y=slice(200, 250), z=slice(10, 20)) in pieces - assert Slice5D.zero(x=slice(150, 200), y=slice(250, 300), z=slice(10, 20)) in pieces - assert Slice5D.zero(x=slice(150, 200), y=slice(300, 350), z=slice(10, 20)) in pieces + assert Interval5D.zero(x=(150, 200), y=(200, 250), z=(10, 20)) in pieces + assert Interval5D.zero(x=(150, 200), y=(250, 300), z=(10, 20)) in pieces + assert Interval5D.zero(x=(150, 200), y=(300, 350), z=(10, 20)) in pieces - assert Slice5D.zero(x=slice(200, 250), y=slice(200, 250), z=slice(10, 20)) in pieces - assert Slice5D.zero(x=slice(200, 250), y=slice(250, 300), z=slice(10, 20)) in pieces - assert Slice5D.zero(x=slice(200, 250), y=slice(300, 350), z=slice(10, 20)) in pieces + assert Interval5D.zero(x=(200, 250), y=(200, 250), z=(10, 20)) in pieces + assert Interval5D.zero(x=(200, 250), y=(250, 300), z=(10, 20)) in pieces + assert Interval5D.zero(x=(200, 250), y=(300, 350), z=(10, 20)) in pieces assert len(pieces) == 12 def test_get_borders(): - slc = Slice5D.zero(x=slice(100, 200), y=slice(300, 400), c=slice(0, 4)) + slc = Interval5D.zero(x=(100, 200), y=(300, 400), c=(0, 4)) thickness = Shape5D.zero(x=1, y=1) expected_borders = { - slc.with_coord(x=slice(100, 101)), - slc.with_coord(y=slice(300, 301)), - slc.with_coord(x=slice(199, 200)), - slc.with_coord(y=slice(399, 400)), + slc.updated(x=(100, 101)), + slc.updated(y=(300, 301)), + slc.updated(x=(199, 200)), + slc.updated(y=(399, 400)), } assert expected_borders == set(slc.get_borders(thickness)) assert len(list(slc.get_borders(thickness))) == 4 thickness = Shape5D.zero(x=10, y=20) expected_thick_borders = { - slc.with_coord(x=slice(100, 110)), - slc.with_coord(x=slice(190, 200)), - slc.with_coord(y=slice(300, 320)), - slc.with_coord(y=slice(380, 400)), + slc.updated(x=(100, 110)), + slc.updated(x=(190, 200)), + slc.updated(y=(300, 320)), + slc.updated(y=(380, 400)), } assert expected_thick_borders == set(slc.get_borders(thickness=thickness)) assert len(list(slc.get_borders(thickness=thickness))) == 4 - z2_slc = Slice5D.zero(x=slice(100, 200), y=slice(300, 400), z=slice(8, 10)) + z2_slc = Interval5D.zero(x=(100, 200), y=(300, 400), z=(8, 10)) thickness = Shape5D.zero(x=10, z=2) - expected_z2_borders = { - z2_slc.with_coord(x=slice(100, 110)), - z2_slc.with_coord(x=slice(190, 200)), - z2_slc.with_coord(z=slice(8, 10)), - } + expected_z2_borders = {z2_slc.updated(x=(100, 110)), z2_slc.updated(x=(190, 200)), z2_slc.updated(z=(8, 10))} assert expected_z2_borders == set(z2_slc.get_borders(thickness=thickness)) assert len(list(z2_slc.get_borders(thickness=thickness))) == 4 def test_get_neighbor_tile_adjacent_to(): - source_tile = Slice5D(x=slice(100, 200), y=slice(300, 400), c=slice(0, 3), z=1, t=1) + source_tile = Interval5D(x=(100, 200), y=(300, 400), c=(0, 3), z=1, t=1) - right_border = source_tile.with_coord(x=slice(199, 200)) + right_border = source_tile.updated(x=(199, 200)) right_neighbor = source_tile.get_neighbor_tile_adjacent_to(anchor=right_border, tile_shape=source_tile.shape) - assert right_neighbor == source_tile.with_coord(x=slice(200, 300)) + assert right_neighbor == source_tile.updated(x=(200, 300)) - left_border = source_tile.with_coord(x=slice(100, 101)) + left_border = source_tile.updated(x=(100, 101)) left_neighbor = source_tile.get_neighbor_tile_adjacent_to(anchor=left_border, tile_shape=source_tile.shape) - assert left_neighbor == source_tile.with_coord(x=slice(0, 100)) + assert left_neighbor == source_tile.updated(x=(0, 100)) - top_border = source_tile.with_coord(y=slice(399, 400)) + top_border = source_tile.updated(y=(399, 400)) top_neighbor = source_tile.get_neighbor_tile_adjacent_to(anchor=top_border, tile_shape=source_tile.shape) - assert top_neighbor == source_tile.with_coord(y=slice(400, 500)) + assert top_neighbor == source_tile.updated(y=(400, 500)) - bottom_border = source_tile.with_coord(y=slice(300, 301)) + bottom_border = source_tile.updated(y=(300, 301)) bottom_neighbor = source_tile.get_neighbor_tile_adjacent_to(anchor=bottom_border, tile_shape=source_tile.shape) - assert bottom_neighbor == source_tile.with_coord(y=slice(200, 300)) + assert bottom_neighbor == source_tile.updated(y=(200, 300)) - partial_tile = Slice5D(x=slice(100, 200), y=slice(400, 470), c=slice(0, 3), z=1, t=1) + partial_tile = Interval5D(x=(100, 200), y=(400, 470), c=(0, 3), z=1, t=1) - right_border = partial_tile.with_coord(x=slice(199, 200)) + right_border = partial_tile.updated(x=(199, 200)) assert partial_tile.get_neighbor_tile_adjacent_to(anchor=right_border, tile_shape=source_tile.shape) == None - left_border = partial_tile.with_coord(x=slice(100, 101)) + left_border = partial_tile.updated(x=(100, 101)) left_neighbor = partial_tile.get_neighbor_tile_adjacent_to(anchor=left_border, tile_shape=source_tile.shape) - assert left_neighbor == partial_tile.with_coord(x=slice(0, 100)) + assert left_neighbor == partial_tile.updated(x=(0, 100)) def test_slice_relabeling_swap(): - slc = Slice5D(x=100, y=200, z=300, t=400, c=500) + slc = Interval5D(x=100, y=200, z=300, t=400, c=500) keymap = KeyMap(x="y", y="x") - assert slc.relabeled(keymap) == Slice5D(y=100, x=200, z=300, t=400, c=500) + assert slc.relabeled(keymap) == Interval5D(y=100, x=200, z=300, t=400, c=500) def test_slice_relabeling_shift(): - slc = Slice5D(x=100, y=200, z=300, t=400, c=500) + slc = Interval5D(x=100, y=200, z=300, t=400, c=500) keymap = KeyMap(x="y", y="z", z="x") - assert slc.relabeled(keymap) == Slice5D(y=100, z=200, x=300, t=400, c=500) + assert slc.relabeled(keymap) == Interval5D(y=100, z=200, x=300, t=400, c=500) def test_slice_enclosing(): @@ -238,7 +219,7 @@ def test_slice_enclosing(): p3 = Point5D.zero(t=3, x=4) p4 = Point5D.zero(t=100, y=400) - expected_slice = Slice5D(x=slice(-13, 4 + 1), y=slice(40, 400 + 1), z=slice(-1, -1 + 1), c=slice(6, 6 + 1)) - assert Slice5D.enclosing([p1, p2, p3, p4]) + expected_slice = Interval5D.zero(x=(-13, 4 + 1), y=(40, 400 + 1), z=(-1, -1 + 1), c=(6, 6 + 1)) + assert Interval5D.enclosing([p1, p2, p3, p4]) - assert Slice5D.enclosing([p2]).start == p2 + assert Interval5D.enclosing([p2]).start == p2 From 4e57276720047c144194ba8b8ef4cd94eecd6df6 Mon Sep 17 00:00:00 2001 From: Tomaz Vieira Date: Thu, 10 Dec 2020 16:42:26 +0100 Subject: [PATCH 3/6] Finished migrating to Interval5D --- ndstructs/datasink/DataSink.py | 13 +- ndstructs/datasink/N5DataSink.py | 12 +- ndstructs/datasource/DataRoi.py | 98 +++++++++++++++ ndstructs/datasource/DataSource.py | 55 +++++--- ndstructs/datasource/DataSourceSlice.py | 117 ------------------ ndstructs/datasource/N5DataSource.py | 6 +- .../datasource/PrecomputedChunksDataSource.py | 8 +- ndstructs/datasource/SequenceDataSource.py | 14 +-- ndstructs/datasource/__init__.py | 2 +- tests/test_datasink.py | 20 +-- tests/test_datasource.py | 74 +++++------ 11 files changed, 206 insertions(+), 213 deletions(-) create mode 100644 ndstructs/datasource/DataRoi.py delete mode 100644 ndstructs/datasource/DataSourceSlice.py diff --git a/ndstructs/datasink/DataSink.py b/ndstructs/datasink/DataSink.py index 535cdb0..7950ba6 100644 --- a/ndstructs/datasink/DataSink.py +++ b/ndstructs/datasink/DataSink.py @@ -1,21 +1,20 @@ from abc import abstractmethod from typing import Optional -from ndstructs import Point5D, Shape5D, Slice5D, Array5D +from ndstructs import Point5D, Shape5D, Interval5D, Array5D from ndstructs.datasource import UnsupportedUrlException from ndstructs.datasource.DataSource import DataSource, AddressMode -from ndstructs.datasource.DataSourceSlice import DataSourceSlice +from ndstructs.datasource.DataRoi import DataRoi class DataSink: - def __init__(self, *, data_slice: DataSourceSlice, tile_shape: Optional[Shape5D] = None): + def __init__(self, *, data_slice: DataRoi, tile_shape: Optional[Shape5D] = None): self.data_slice = data_slice self.tile_shape = tile_shape or data_slice.tile_shape - def process(self, roi: Slice5D = Slice5D.all(), address_mode: AddressMode = AddressMode.BLACK) -> None: - defined_roi = roi.defined_with(self.data_slice) - assert self.data_slice.contains(defined_roi) - for piece in defined_roi.split(self.tile_shape): + def process(self, roi: Interval5D, address_mode: AddressMode = AddressMode.BLACK) -> None: + assert self.data_slice.contains(roi) + for piece in roi.split(self.tile_shape): source_data = self.data_slice.datasource.retrieve(piece, address_mode=address_mode) self._process_tile(source_data) diff --git a/ndstructs/datasink/N5DataSink.py b/ndstructs/datasink/N5DataSink.py index 5e00e1a..52b0f2a 100644 --- a/ndstructs/datasink/N5DataSink.py +++ b/ndstructs/datasink/N5DataSink.py @@ -8,11 +8,11 @@ from fs.base import FS from fs.osfs import OSFS -from ndstructs.point5D import Point5D, Slice5D, Shape5D +from ndstructs.point5D import Point5D, Interval5D, Shape5D from ndstructs.array5D import Array5D from ndstructs.datasource.DataSource import DataSource, UnsupportedUrlException from ndstructs.datasource.N5DataSource import N5Block -from ndstructs.datasource.DataSourceSlice import DataSourceSlice +from ndstructs.datasource.DataRoi import DataRoi from ndstructs.datasink.DataSink import DataSink @@ -29,7 +29,7 @@ def __init__( self, *, path: Path, # dataset path, e.g. "mydata.n5/mydataset" - data_slice: DataSourceSlice, + data_slice: DataRoi, axiskeys: str = "tzyxc", compression_type: str = "raw", tile_shape: Optional[Shape5D] = None, @@ -77,17 +77,17 @@ def __init__( self.filesystem.makedirs(dir_path) created_dirs.add(dir_path) - def get_tile_dataset_path(self, global_roi: Slice5D) -> str: + def get_tile_dataset_path(self, global_roi: Interval5D) -> str: "Gets the relative path into the n5 dataset where 'tile' should be stored" local_roi = global_roi.translated(-self.data_slice.start) slice_address_components = (local_roi.start // self.tile_shape).to_np(self.axiskeys[::-1]).astype(np.uint32) return "/".join(map(str, slice_address_components)) - def get_tile_dir_dataset_path(self, global_roi: Slice5D) -> str: + def get_tile_dir_dataset_path(self, global_roi: Interval5D) -> str: return "/".join(self.get_tile_dataset_path(global_roi).split("/")[:-1]) def _process_tile(self, tile: Array5D) -> None: tile = N5Block.fromArray5D(tile) - tile_path = self.get_tile_dataset_path(global_roi=tile.roi) + tile_path = self.get_tile_dataset_path(global_roi=tile.interval) with self.filesystem.openbin(tile_path, "w") as f: f.write(tile.to_n5_bytes(axiskeys=self.axiskeys, compression_type=self.compression_type)) diff --git a/ndstructs/datasource/DataRoi.py b/ndstructs/datasource/DataRoi.py new file mode 100644 index 0000000..de66ee7 --- /dev/null +++ b/ndstructs/datasource/DataRoi.py @@ -0,0 +1,98 @@ +from ndstructs.datasource.DataSource import DataSource, AddressMode +from ndstructs import Interval5D, Shape5D, Array5D, Point5D +from ndstructs.point5D import INTERVAL_5D, SPAN +from typing import Iterator, Optional, Union + + +class DataRoi(Interval5D): + def __init__( + self, datasource: DataSource, *, t: SPAN = None, c: SPAN = None, x: SPAN = None, y: SPAN = None, z: SPAN = None + ): + super().__init__( + t=t if t is not None else datasource.interval.t, + c=c if c is not None else datasource.interval.c, + x=x if x is not None else datasource.interval.x, + y=y if y is not None else datasource.interval.y, + z=z if z is not None else datasource.interval.z, + ) + self.datasource = datasource + + def __hash__(self) -> int: + return hash((super().__hash__(), self.datasource)) + + def __eq__(self, other: object) -> bool: + if not super().__eq__(other): + return False + if isinstance(other, DataRoi) and self.datasource != other.datasource: + return False + return True + + def updated( + self, + *, + x: Optional[SPAN] = None, + y: Optional[SPAN] = None, + z: Optional[SPAN] = None, + t: Optional[SPAN] = None, + c: Optional[SPAN] = None, + ) -> "DataRoi": + inter = self.interval.updated(t=t, c=c, x=x, y=y, z=z) + return self.__class__(datasource=self.datasource, x=inter.x, y=inter.y, z=inter.z, t=inter.t, c=inter.c) + + def __repr__(self) -> str: + return super().__repr__() + " " + self.datasource.url + + def full(self) -> "DataRoi": + return self.updated(**self.full_shape.to_interval5d().to_dict()) + + @property + def full_shape(self) -> Shape5D: + return self.datasource.shape + + @property + def tile_shape(self) -> Shape5D: + return self.datasource.tile_shape + + @property + def dtype(self): + return self.datasource.dtype + + def is_tile(self, tile_shape: Shape5D = None) -> bool: + tile_shape = tile_shape or self.tile_shape + has_tile_start = self.start % tile_shape == Point5D.zero() + has_tile_end = self.stop % tile_shape == Point5D.zero() or self.stop == self.full().stop + return has_tile_start and has_tile_end + + @property + def interval(self) -> Interval5D: + return Interval5D(t=self.t, c=self.c, x=self.x, y=self.y, z=self.z) + + def retrieve(self, address_mode: AddressMode = AddressMode.BLACK) -> Array5D: + return self.datasource.retrieve(self.interval, address_mode=address_mode) + + def split(self, block_shape: Optional[Shape5D] = None) -> Iterator["DataRoi"]: + yield from super().split(block_shape or self.tile_shape) + + def get_tiles(self, tile_shape: Shape5D = None, clamp: bool = True) -> Iterator["DataRoi"]: + for tile in super().get_tiles(tile_shape or self.tile_shape): + if clamp: + clamped = tile.clamped(self) + if not self.contains(clamped): + continue + yield clamped + else: + yield tile + + # for this and the next method, tile_shape is needed because self could be an edge tile, and therefor + # self.shape would not return a typical tile shape + def get_neighboring_tiles(self, tile_shape: Shape5D) -> Iterator["DataRoi"]: + for neighbor in super().get_neighboring_tiles(tile_shape): + neighbor = neighbor.clamped(self.full()) + if neighbor.shape.hypervolume > 0 and neighbor != self: + yield neighbor + + def get_neighbor_tile_adjacent_to(self, *, anchor: Interval5D, tile_shape: Shape5D) -> Optional["DataRoi"]: + neighbor = super().get_neighbor_tile_adjacent_to(anchor=anchor, tile_shape=tile_shape) + if not self.full().contains(neighbor): + return None + return neighbor.clamped(self.full()) diff --git a/ndstructs/datasource/DataSource.py b/ndstructs/datasource/DataSource.py index d1153c7..0c95ecf 100644 --- a/ndstructs/datasource/DataSource.py +++ b/ndstructs/datasource/DataSource.py @@ -14,8 +14,9 @@ from fs.osfs import OSFS -from ndstructs import Array5D, Shape5D, Slice5D, Point5D -from ndstructs.utils import JsonSerializable, to_json_data +from ndstructs import Array5D, Shape5D, Interval5D, Point5D +from ndstructs.array5D import SPAN_OVERRIDE, All +from ndstructs.utils import JsonSerializable, to_json_data, Referencer from .UnsupportedUrlException import UnsupportedUrlException @@ -30,8 +31,6 @@ @enum.unique class AddressMode(IntEnum): BLACK = 0 - MIRROR = enum.auto() - WRAP = enum.auto() # DS_CTOR = Callable[[str, Optional[Shape5D], str], "DataSource"] @@ -72,18 +71,18 @@ def __init__( axiskeys: str, ): self.url = url - self.tile_shape = (tile_shape or Shape5D.hypercube(256)).to_slice_5d().clamped(shape.to_slice_5d()).shape + self.tile_shape = (tile_shape or Shape5D.hypercube(256)).to_interval5d().clamped(shape.to_interval5d()).shape self.dtype = dtype self.name = name or self.url.split("/")[-1] self.shape = shape - self.roi = shape.to_slice_5d(offset=location) + self.interval = shape.to_interval5d(offset=location) self.location = location self.axiskeys = axiskeys def __str__(self) -> str: return f"<{self.__class__.__name__} {self.shape} {self.url}>" - def to_json_data(self, referencer: Callable[[Any], str] = lambda obj: None) -> Dict: + def to_json_data(self, referencer: Referencer = lambda obj: None) -> Dict: return to_json_data( { "__class__": self.__class__.__name__, @@ -93,7 +92,7 @@ def to_json_data(self, referencer: Callable[[Any], str] = lambda obj: None) -> D "dtype": self.dtype.name, "name": self.name, "shape": self.shape, - "roi": self.roi, + "interval": self.interval, } ) @@ -109,23 +108,39 @@ def __eq__(self, other: object) -> bool: return self.url == other.url and self.tile_shape == other.tile_shape @ndstructs_datasource_cache - def get_tile(self, tile: Slice5D) -> Array5D: + def get_tile(self, tile: Interval5D) -> Array5D: return self._get_tile(tile) @abstractmethod - def _get_tile(self, tile: Slice5D) -> Array5D: + def _get_tile(self, tile: Interval5D) -> Array5D: pass def close(self) -> None: pass - def _allocate(self, roi: Union[Shape5D, Slice5D], fill_value: int) -> Array5D: - return Array5D.allocate(roi, dtype=self.dtype, value=fill_value) + def _allocate(self, interval: Union[Shape5D, Interval5D], fill_value: int) -> Array5D: + return Array5D.allocate(interval, dtype=self.dtype, value=fill_value) - def retrieve(self, roi: Slice5D, address_mode: AddressMode = AddressMode.BLACK) -> Array5D: - # FIXME: Remove address_mode or implement all variations and make feature extractors use the correct one - out = self._allocate(roi.defined_with(self.shape).translated(-self.location), fill_value=0) - local_data_roi = roi.clamped(self.roi).translated(-self.location) + def retrieve( + self, + interval: Optional[Interval5D] = None, + *, + x: Optional[SPAN_OVERRIDE] = None, + y: Optional[SPAN_OVERRIDE] = None, + z: Optional[SPAN_OVERRIDE] = None, + t: Optional[SPAN_OVERRIDE] = None, + c: Optional[SPAN_OVERRIDE] = None, + address_mode: AddressMode = AddressMode.BLACK, + ) -> Array5D: + interval = (interval or self.interval).updated( + x=self.interval.x if isinstance(x, All) else x, + y=self.interval.y if isinstance(y, All) else y, + z=self.interval.z if isinstance(z, All) else z, + t=self.interval.t if isinstance(t, All) else t, + c=self.interval.c if isinstance(c, All) else c, + ) + out = self._allocate(interval.translated(-self.location), fill_value=0) + local_data_roi = interval.clamped(self.interval).translated(-self.location) for tile in local_data_roi.get_tiles(self.tile_shape): tile_within_bounds = tile.clamped(self.shape) tile_data = self.get_tile(tile_within_bounds) @@ -155,7 +170,7 @@ def __init__(self, path: Path, *, location: Point5D = Point5D.zero(), filesystem self._dataset.file.close() raise e - def _get_tile(self, tile: Slice5D) -> Array5D: + def _get_tile(self, tile: Interval5D) -> Array5D: slices = tile.to_slices(self.axiskeys) raw = cast(h5py.Dataset, self._dataset)[slices] return Array5D(raw, axiskeys=self.axiskeys, location=tile.start) @@ -241,11 +256,11 @@ def __init__( def from_array5d(cls, arr, *, tile_shape: Optional[Shape5D] = None, location: Point5D = Point5D.zero()): return cls(data=arr.raw(Point5D.LABELS), axiskeys=Point5D.LABELS, location=location, tile_shape=tile_shape) - def _get_tile(self, tile: Slice5D) -> Array5D: + def _get_tile(self, tile: Interval5D) -> Array5D: return self._data.cut(tile, copy=True) - def _allocate(self, roi: Union[Shape5D, Slice5D], fill_value: int) -> Array5D: - return self._data.__class__.allocate(roi, dtype=self.dtype, value=fill_value) + def _allocate(self, interval: Union[Shape5D, Interval5D], fill_value: int) -> Array5D: + return self._data.__class__.allocate(interval, dtype=self.dtype, value=fill_value) class SkimageDataSource(ArrayDataSource): diff --git a/ndstructs/datasource/DataSourceSlice.py b/ndstructs/datasource/DataSourceSlice.py deleted file mode 100644 index 7af2da0..0000000 --- a/ndstructs/datasource/DataSourceSlice.py +++ /dev/null @@ -1,117 +0,0 @@ -from ndstructs.datasource.DataSource import DataSource, AddressMode -from ndstructs import Slice5D, Shape5D, Array5D, Point5D -from ndstructs.point5D import SLC, SLC_PARAM -from typing import Iterator, Optional - - -class DataSourceSlice(Slice5D): - def __init__( - self, - datasource: DataSource, - *, - t: Optional[SLC_PARAM] = None, - c: Optional[SLC_PARAM] = None, - x: Optional[SLC_PARAM] = None, - y: Optional[SLC_PARAM] = None, - z: Optional[SLC_PARAM] = None, - ): - super().__init__( - t=t if t is not None else datasource.roi.t, - c=c if c is not None else datasource.roi.c, - x=x if x is not None else datasource.roi.x, - y=y if y is not None else datasource.roi.y, - z=z if z is not None else datasource.roi.z, - ) - self.datasource = datasource - - def __hash__(self) -> int: - return hash((super().__hash__(), self.datasource)) - - def __eq__(self, other: object) -> bool: - if not super().__eq__(other): - return False - if isinstance(other, DataSourceSlice) and self.datasource != other.datasource: - return False - return True - - def with_coord( - self, - *, - t: Optional[SLC_PARAM] = None, - c: Optional[SLC_PARAM] = None, - x: Optional[SLC_PARAM] = None, - y: Optional[SLC_PARAM] = None, - z: Optional[SLC_PARAM] = None, - ) -> "DataSourceSlice": - slc = self.roi.with_coord(t=t, c=c, x=x, y=y, z=z) - return self.__class__(datasource=self.datasource, **slc.to_dict()) - - def defined(self) -> "DataSourceSlice": - return self.defined_with(self.full_shape) - - def __repr__(self) -> str: - return super().__repr__() + " " + self.datasource.url - - def full(self) -> "DataSourceSlice": - return self.with_coord(**self.full_shape.to_slice_5d().to_dict()) - - @property - def full_shape(self) -> Shape5D: - return self.datasource.shape - - def contains(self, slc: Slice5D) -> bool: - return super().contains(slc.defined_with(self.full_shape)) - - @property - def tile_shape(self) -> Shape5D: - return self.datasource.tile_shape - - @property - def dtype(self): - return self.datasource.dtype - - def is_tile(self, tile_shape: Shape5D = None) -> bool: - tile_shape = tile_shape or self.tile_shape - has_tile_start = self.start % tile_shape == Point5D.zero() - has_tile_end = self.stop % tile_shape == Point5D.zero() or self.stop == self.full().stop - return has_tile_start and has_tile_end - - @property - def roi(self) -> Slice5D: - return Slice5D(t=self.t, c=self.c, x=self.x, y=self.y, z=self.z) - - def retrieve(self, address_mode: AddressMode = AddressMode.BLACK) -> Array5D: - return self.datasource.retrieve(self.roi, address_mode=address_mode) - - def split(self: SLC, block_shape: Optional[Shape5D] = None) -> Iterator[SLC]: - if not self.is_defined(): - return self.defined().split(block_shape=block_shape) - yield from super().split(block_shape or self.tile_shape) - - def get_tiles(self, tile_shape: Shape5D = None, clamp: bool = True) -> Iterator["DataSourceSlice"]: - if not self.is_defined(): - return self.defined().get_tiles(tile_shape=tile_shape, clamp=clamp) - for tile in super().get_tiles(tile_shape or self.tile_shape): - if clamp: - clamped = tile.clamped(self) - if not self.contains(clamped): - continue - yield clamped - else: - yield tile - - # for this and the next method, tile_shape is needed because self could be an edge tile, and therefor - # self.shape would not return a typical tile shape - def get_neighboring_tiles(self, tile_shape: Shape5D) -> Iterator["DataSourceSlice"]: - if not self.is_defined(): - return self.defined().get_neighboring_tiles(tile_shape=tile_shape) - for neighbor in super().get_neighboring_tiles(tile_shape): - neighbor = neighbor.clamped(self.full()) - if neighbor.shape.hypervolume > 0 and neighbor != self: - yield neighbor - - def get_neighbor_tile_adjacent_to(self, *, anchor: Slice5D, tile_shape: Shape5D) -> Optional["DataSourceSlice"]: - neighbor = super().get_neighboring_tiles(anchor=anchor, tile_shape=tile_shape) - if not self.full().contains(neighbor): - return None - return neighbor.clamped(self.full()) diff --git a/ndstructs/datasource/N5DataSource.py b/ndstructs/datasource/N5DataSource.py index 81e1112..3495794 100644 --- a/ndstructs/datasource/N5DataSource.py +++ b/ndstructs/datasource/N5DataSource.py @@ -12,11 +12,11 @@ import numpy as np -from ndstructs import Point5D, Shape5D, Slice5D, Array5D +from ndstructs import Point5D, Shape5D, Interval5D, Array5D from ndstructs.datasource.DataSource import DataSource, guess_axiskeys from .UnsupportedUrlException import UnsupportedUrlException -from ndstructs.datasource.DataSourceSlice import DataSourceSlice +from ndstructs.datasource.DataRoi import DataRoi from fs import open_fs from fs.base import FS @@ -115,7 +115,7 @@ def __init__(self, path: Path, *, location: Point5D = Point5D.zero(), filesystem if self.compression_type not in N5Block.DECOMPRESSORS.keys(): raise NotImplementedError(f"Don't know how to decompress from {self.compression_type}") - def _get_tile(self, tile: Slice5D) -> Array5D: + def _get_tile(self, tile: Interval5D) -> Array5D: f_axiskeys = self.axiskeys[::-1] slice_address_components = (tile.start // self.tile_shape).to_tuple(f_axiskeys) slice_address = "/".join(str(int(comp)) for comp in slice_address_components) diff --git a/ndstructs/datasource/PrecomputedChunksDataSource.py b/ndstructs/datasource/PrecomputedChunksDataSource.py index e20a6f8..340f6d3 100644 --- a/ndstructs/datasource/PrecomputedChunksDataSource.py +++ b/ndstructs/datasource/PrecomputedChunksDataSource.py @@ -12,11 +12,9 @@ from fs.base import FS from fs.osfs import OSFS -from ndstructs import Point5D, Shape5D, Slice5D, Array5D +from ndstructs import Point5D, Shape5D, Interval5D, Array5D from ndstructs.datasource.DataSource import DataSource -from .UnsupportedUrlException import UnsupportedUrlException -from ndstructs.datasource.DataSourceSlice import DataSourceSlice from ndstructs.utils import JsonSerializable, Dereferencer, Referencer @@ -148,9 +146,9 @@ def __init__( self.decompressor = noop self.compressor = noop else: - raise NotImplementedError(f"Don't know how to decompress {compression_type}") + raise NotImplementedError(f"Don't know how to decompress {encoding_type}") - def _get_tile(self, tile: Slice5D) -> Array5D: + def _get_tile(self, tile: Interval5D) -> Array5D: slice_address = "_".join(f"{s.start}-{s.stop}" for s in tile.to_slices(self.scale.spatial_axiskeys)) path = self.scale.key + "/" + slice_address with self.filesystem.openbin(path) as f: diff --git a/ndstructs/datasource/SequenceDataSource.py b/ndstructs/datasource/SequenceDataSource.py index b7dba9b..b2180c1 100644 --- a/ndstructs/datasource/SequenceDataSource.py +++ b/ndstructs/datasource/SequenceDataSource.py @@ -5,8 +5,8 @@ import itertools from ndstructs.datasource.DataSource import DataSource -from ndstructs.datasource.DataSourceSlice import DataSourceSlice -from ndstructs import Shape5D, Slice5D, Point5D, Array5D +from ndstructs.datasource.DataRoi import DataRoi +from ndstructs import Shape5D, Interval5D, Point5D, Array5D class SequenceDataSource(DataSource): @@ -31,13 +31,13 @@ def __init__( self.layer_offsets.append(layer_offset[stack_axis]) layer_offset += Point5D.zero(**{stack_axis: layer.shape[stack_axis]}) - if len(set(layer.shape.with_coord(**{stack_axis: 1}) for layer in self.layers)) > 1: + if len(set(layer.shape.updated(**{stack_axis: 1}) for layer in self.layers)) > 1: raise ValueError("Provided files have different dimensions on the non-stacking axis") if any(layer.dtype != self.layers[0].dtype for layer in self.layers): raise ValueError("All layers must have the same data type!") stack_size = sum(layer.shape[self.stack_axis] for layer in self.layers) - full_shape = self.layers[0].shape.with_coord(**{self.stack_axis: stack_size}) + full_shape = self.layers[0].shape.updated(**{self.stack_axis: stack_size}) super().__init__( url=":".join(p.as_posix() for p in paths), @@ -48,13 +48,13 @@ def __init__( axiskeys=stack_axis + Point5D.LABELS.replace(stack_axis, ""), ) - def _get_tile(self, tile: Slice5D) -> Array5D: + def _get_tile(self, tile: Interval5D) -> Array5D: first_layer_idx = bisect.bisect_left(self.layer_offsets, tile.start[self.stack_axis]) - out = self._allocate(roi=tile, fill_value=0) + out = self._allocate(interval=tile, fill_value=0) for layer, layer_offset in zip(self.layers[first_layer_idx:], self.layer_offsets[first_layer_idx:]): if layer_offset > tile.stop[self.stack_axis]: break - layer_tile = tile.clamped(layer.roi) + layer_tile = tile.clamped(layer.interval) layer_data = layer.retrieve(layer_tile) out.set(layer_data, autocrop=True) diff --git a/ndstructs/datasource/__init__.py b/ndstructs/datasource/__init__.py index 0f2e5f4..3f3b9f3 100644 --- a/ndstructs/datasource/__init__.py +++ b/ndstructs/datasource/__init__.py @@ -1,5 +1,5 @@ from .DataSource import DataSource, ArrayDataSource, SkimageDataSource, H5DataSource -from .DataSourceSlice import DataSourceSlice +from .DataRoi import DataRoi from .N5DataSource import N5DataSource from .PrecomputedChunksDataSource import PrecomputedChunksDataSource from .SequenceDataSource import SequenceDataSource diff --git a/tests/test_datasink.py b/tests/test_datasink.py index 227f91a..8644ee5 100644 --- a/tests/test_datasink.py +++ b/tests/test_datasink.py @@ -3,9 +3,9 @@ import pytest import numpy as np -from ndstructs import Point5D, Slice5D, Array5D, Shape5D +from ndstructs import Point5D, Interval5D, Array5D, Shape5D from ndstructs.datasource.DataSource import ArrayDataSource, DataSource -from ndstructs.datasource.DataSourceSlice import DataSourceSlice +from ndstructs.datasource.DataRoi import DataRoi from ndstructs.datasource.N5DataSource import N5DataSource from ndstructs.datasink import N5DataSink @@ -22,27 +22,27 @@ def datasource(data: Array5D): def test_n5_datasink(tmp_path: Path, data: Array5D, datasource: DataSource): dataset_path = tmp_path / "test_n5_datasink.n5/data" - sink = N5DataSink(path=dataset_path, data_slice=DataSourceSlice(datasource), tile_shape=Shape5D(x=10, y=10)) - sink.process(Slice5D.all()) + sink = N5DataSink(path=dataset_path, data_slice=DataRoi(datasource), tile_shape=Shape5D(x=10, y=10)) + sink.process(sink.data_slice) n5ds = DataSource.create(dataset_path) - assert n5ds.retrieve(Slice5D.all()) == data + assert n5ds.retrieve() == data def test_n5_datasink_saves_roi(tmp_path: Path, data: Array5D, datasource: DataSource): - roi = DataSourceSlice(datasource, x=slice(5, 8), y=slice(2, 4)) + roi = DataRoi(datasource, x=(5, 8), y=(2, 4)) dataset_path = tmp_path / "test_n5_datasink_saves_roi.n5/data" sink = N5DataSink(path=dataset_path, data_slice=roi, tile_shape=Shape5D(x=10, y=10)) - sink.process(Slice5D.all()) + sink.process(sink.data_slice) n5ds = DataSource.create(dataset_path) - assert n5ds.retrieve(Slice5D.all()) == roi.retrieve() + assert n5ds.retrieve() == roi.retrieve() def test_distributed_n5_datasink(tmp_path: Path, data: Array5D, datasource: DataSource): dataset_path = tmp_path / "test_distributed_n5_datasink.n5/data" - data_slice = DataSourceSlice(datasource) + data_slice = DataRoi(datasource) sinks = [ N5DataSink(path=dataset_path, data_slice=data_slice, mode=N5DataSink.Mode.CREATE), @@ -56,4 +56,4 @@ def test_distributed_n5_datasink(tmp_path: Path, data: Array5D, datasource: Data sink.process(piece) n5ds = DataSource.create(dataset_path) - assert n5ds.retrieve(Slice5D.all()) == data + assert n5ds.retrieve() == data diff --git a/tests/test_datasource.py b/tests/test_datasource.py index df8cbb3..b8d7ba5 100644 --- a/tests/test_datasource.py +++ b/tests/test_datasource.py @@ -1,11 +1,11 @@ import pytest -from typing import Optional +from typing import Optional, Iterator import os import tempfile from pathlib import Path import numpy as np import pickle -from ndstructs import Shape5D, Slice5D, Array5D, Point5D, KeyMap +from ndstructs import Shape5D, Interval5D, Array5D, Point5D, KeyMap from ndstructs.datasource import ( DataSource, SkimageDataSource, @@ -16,7 +16,7 @@ ) from ndstructs.datasink import N5DataSink from fs.osfs import OSFS -from ndstructs.datasource import DataSourceSlice +from ndstructs.datasource import DataRoi import h5py import json import shutil @@ -69,20 +69,20 @@ def create_png(array: Array5D) -> Path: def create_n5(array: Array5D, axiskeys: str = "xyztc", chunk_size: Optional[Shape5D] = None): - data_slice = DataSourceSlice(ArrayDataSource.from_array5d(array)) + data_slice = DataRoi(ArrayDataSource.from_array5d(array)) chunk_size = chunk_size or Shape5D.hypercube(10) path = Path(tempfile.mkstemp()[1] + ".n5/data") sink = N5DataSink(path=path, data_slice=data_slice, axiskeys=axiskeys, tile_shape=chunk_size) - sink.process() + sink.process(data_slice) return path.as_posix() def create_h5(array: Array5D, axiskeys_style: str, chunk_shape: Shape5D = None, axiskeys="xyztc"): - chunk_shape = (chunk_shape or Shape5D() * 2).clamped(maximum=array.shape).to_tuple(axiskeys) + raw_chunk_shape = (chunk_shape or Shape5D() * 2).clamped(maximum=array.shape).to_tuple(axiskeys) path = tempfile.mkstemp()[1] + ".h5" f = h5py.File(path, "w") - ds = f.create_dataset("data", chunks=chunk_shape, data=array.raw(axiskeys)) + ds = f.create_dataset("data", chunks=raw_chunk_shape, data=array.raw(axiskeys)) if axiskeys_style == "dims": for key, dim in zip(axiskeys, ds.dims): dim.label = key @@ -97,7 +97,7 @@ def create_h5(array: Array5D, axiskeys_style: str, chunk_shape: Shape5D = None, @pytest.fixture -def png_image() -> Path: +def png_image() -> Iterator[Path]: png_path = create_png(Array5D(raw, axiskeys="yx")) yield png_path os.remove(png_path) @@ -123,9 +123,9 @@ def test_retrieve_roi_smaller_than_tile(): # fmt: on path = Path(create_n5(data, chunk_size=Shape5D(c=2, y=4, x=4))) ds = DataSource.create(path) - print(f"\n\n====>> tile shape: {ds.roi}") + print(f"\n\n====>> tile shape: {ds.shape}") - smaller_than_tile = ds.retrieve(Slice5D.all(c=1, y=slice(0, 4), x=slice(0, 4))) + smaller_than_tile = ds.retrieve(c=1, y=(0, 4), x=(0, 4)) print(smaller_than_tile.raw("cyx")) @@ -149,10 +149,10 @@ def test_n5_datasource(): [6, 7, 8] ]).astype(np.uint8), axiskeys="yx") # fmt: on - assert ds.retrieve(Slice5D(x=slice(0, 3), y=slice(0, 2))) == expected_raw_piece + assert ds.retrieve(x=(0, 3), y=(0, 2)) == expected_raw_piece ds2 = pickle.loads(pickle.dumps(ds)) - assert ds2.retrieve(Slice5D(x=slice(0, 3), y=slice(0, 2))) == expected_raw_piece + assert ds2.retrieve(x=(0, 3), y=(0, 2)) == expected_raw_piece try: @@ -173,7 +173,7 @@ def test_h5_datasource(): assert ds.shape == data_2d.shape assert ds.tile_shape == Shape5D(x=3, y=3) - slc = Slice5D(x=slice(0, 3), y=slice(0, 2)) + slc = ds.interval.updated(x=(0, 3), y=(0, 2)) assert (ds.retrieve(slc).raw("yx") == data_2d.cut(slc).raw("yx")).all() data_3d = Array5D(np.arange(10 * 10 * 10).reshape(10, 10, 10), axiskeys="zyx") @@ -182,25 +182,25 @@ def test_h5_datasource(): assert ds.shape == data_3d.shape assert ds.tile_shape == Shape5D(x=3, y=3) - slc = Slice5D(x=slice(0, 3), y=slice(0, 2), z=3) + slc = ds.interval.updated(x=(0, 3), y=(0, 2), z=3) assert (ds.retrieve(slc).raw("yxz") == data_3d.cut(slc).raw("yxz")).all() def test_skimage_datasource_tiles(png_image: Path): - bs = DataSourceSlice(SkimageDataSource(png_image, filesystem=OSFS("/"))) + bs = DataRoi(SkimageDataSource(png_image, filesystem=OSFS("/"))) num_checked_tiles = 0 for tile in bs.split(Shape5D(x=2, y=2)): - if tile == Slice5D.zero(x=slice(0, 2), y=slice(0, 2)): + if tile == Interval5D.zero(x=(0, 2), y=(0, 2)): expected_raw = raw_0_2x0_2y - elif tile == Slice5D.zero(x=slice(0, 2), y=slice(2, 4)): + elif tile == Interval5D.zero(x=(0, 2), y=(2, 4)): expected_raw = raw_0_2x2_4y - elif tile == Slice5D.zero(x=slice(2, 4), y=slice(0, 2)): + elif tile == Interval5D.zero(x=(2, 4), y=(0, 2)): expected_raw = raw_2_4x0_2y - elif tile == Slice5D.zero(x=slice(2, 4), y=slice(2, 4)): + elif tile == Interval5D.zero(x=(2, 4), y=(2, 4)): expected_raw = raw_2_4x2_4y - elif tile == Slice5D.zero(x=slice(4, 5), y=slice(0, 2)): + elif tile == Interval5D.zero(x=(4, 5), y=(0, 2)): expected_raw = raw_4_5x0_2y - elif tile == Slice5D.zero(x=slice(4, 5), y=slice(2, 4)): + elif tile == Interval5D.zero(x=(4, 5), y=(2, 4)): expected_raw = raw_4_5x2_4y else: raise Exception(f"Unexpected tile {tile}") @@ -228,7 +228,7 @@ def test_neighboring_tiles(): ds = DataSource.create(create_png(arr)) - fifties_slice = DataSourceSlice(ds, x=slice(3, 6), y=slice(3, 6)) + fifties_slice = DataRoi(ds, x=(3, 6), y=(3, 6)) expected_fifties_slice = Array5D(np.asarray([ [50, 51, 52], [53, 54, 55], @@ -236,11 +236,11 @@ def test_neighboring_tiles(): ]), axiskeys="yx") # fmt: on - top_slice = DataSourceSlice(ds, x=slice(3, 6), y=slice(0, 3)) - bottom_slice = DataSourceSlice(ds, x=slice(3, 6), y=slice(6, 9)) + top_slice = DataRoi(ds, x=(3, 6), y=(0, 3)) + bottom_slice = DataRoi(ds, x=(3, 6), y=(6, 9)) - right_slice = DataSourceSlice(ds, x=slice(6, 7), y=slice(3, 6)) - left_slice = DataSourceSlice(ds, x=slice(0, 3), y=slice(3, 6)) + right_slice = DataRoi(ds, x=(6, 7), y=(3, 6)) + left_slice = DataRoi(ds, x=(0, 3), y=(3, 6)) # fmt: off fifties_neighbor_data = { @@ -368,7 +368,7 @@ def test_sequence_datasource(): [352, 353]]], ]), axiskeys="zcyx") # fmt: on - slice_x_2_4__y_1_3 = Slice5D(x=slice(2, 4), y=slice(1, 3)) + slice_x_2_4__y_1_3 = {"x": (2, 4), "y": (1, 3)} urls = [ # create_n5(img1_data, axiskeys="cyx"), @@ -381,16 +381,16 @@ def test_sequence_datasource(): seq_ds = SequenceDataSource(urls, stack_axis="z") assert seq_ds.shape == Shape5D(x=5, y=4, c=3, z=3) - data = seq_ds.retrieve(slice_x_2_4__y_1_3) + data = seq_ds.retrieve(**slice_x_2_4__y_1_3) assert (expected_x_2_4__y_1_3.raw("xyzc") == data.raw("xyzc")).all() seq_ds = SequenceDataSource(urls, stack_axis="z") - data = seq_ds.retrieve(slice_x_2_4__y_1_3) + data = seq_ds.retrieve(**slice_x_2_4__y_1_3) assert (expected_x_2_4__y_1_3.raw("xyzc") == data.raw("xyzc")).all() seq_ds = SequenceDataSource(urls, stack_axis="c") expected_c = sum([img1_data.shape.c, img2_data.shape.c, img3_data.shape.c]) - assert seq_ds.shape == img1_data.shape.with_coord(c=expected_c) + assert seq_ds.shape == img1_data.shape.updated(c=expected_c) cstack_data = Array5D.allocate(Shape5D(x=5, y=4, c=expected_c), dtype=img1_data.dtype) cstack_data.set(img1_data.translated(Point5D.zero(c=0))) @@ -398,8 +398,8 @@ def test_sequence_datasource(): cstack_data.set(img3_data.translated(Point5D.zero(c=6))) assert seq_ds.shape == cstack_data.shape - expected_data = cstack_data.cut(slice_x_2_4__y_1_3) - data = seq_ds.retrieve(slice_x_2_4__y_1_3) + expected_data = cstack_data.cut(**slice_x_2_4__y_1_3) + data = seq_ds.retrieve(**slice_x_2_4__y_1_3) assert (expected_data.raw("cxy") == data.raw("cxy")).all() @@ -413,8 +413,8 @@ def test_sequence_datasource(): # adjusted = DataSource.create(png_path, axiskeys="zy") # assert adjusted.shape == Shape5D(z=data.shape.y, y=data.shape.x) # -# data_slc = Slice5D(y=slice(4, 7), x=slice(3, 5)) -# adjusted_slice = Slice5D(z=data_slc.y, y=data_slc.x) +# data_slc = Interval5D(y=(4, 7), x=(3, 5)) +# adjusted_slice = Interval5D(z=data_slc.y, y=data_slc.x) # # assert (data.cut(data_slc).raw("yx") == adjusted.retrieve(adjusted_slice).raw("zy")).all() @@ -430,7 +430,7 @@ def test_datasource_slice_clamped_get_tiles_is_tile_aligned(): # fmt: on ds = ArrayDataSource.from_array5d(data, tile_shape=Shape5D(x=2, y=2)) - data_slice = DataSourceSlice(datasource=ds, x=slice(1, 4), y=slice(0, 3)) + data_slice = DataRoi(datasource=ds, x=(1, 4), y=(0, 3)) # fmt: off dataslice_expected_data = Array5D(np.asarray([ @@ -463,8 +463,8 @@ def test_datasource_slice_clamped_get_tiles_is_tile_aligned(): ]).astype(np.uint8), axiskeys="yx", location=Point5D.zero(x=2, y=2)) ] # fmt: on - expected_slice_dict = {a.roi: a for a in dataslice_expected_slices} + expected_slice_dict = {a.interval: a for a in dataslice_expected_slices} for piece in data_slice.get_tiles(clamp=True): - expected_data = expected_slice_dict.pop(piece.roi) + expected_data = expected_slice_dict.pop(piece.interval) assert expected_data == piece.retrieve() assert len(expected_slice_dict) == 0 From 8c703b0b414dd7c6f0c82d25ed405528f971c44e Mon Sep 17 00:00:00 2001 From: Tomaz Vieira Date: Fri, 11 Dec 2020 13:58:27 +0100 Subject: [PATCH 4/6] Fixes some typing annotations --- ndstructs/array5D.py | 26 ++++++++------------------ ndstructs/point5D.py | 7 +++---- 2 files changed, 11 insertions(+), 22 deletions(-) diff --git a/ndstructs/array5D.py b/ndstructs/array5D.py index c7ff816..5f8299b 100644 --- a/ndstructs/array5D.py +++ b/ndstructs/array5D.py @@ -12,20 +12,6 @@ Arr = TypeVar("Arr", bound="Array5D") -DTYPE = Union[ - Type[np.uint8], - Type[np.uint16], - Type[np.uint32], - Type[np.uint64], - Type[np.int8], - Type[np.int16], - Type[np.int32], - Type[np.int64], - Type[np.float16], - Type[np.float32], - Type[np.float64], -] - class All: pass @@ -88,7 +74,11 @@ def __repr__(self) -> str: @classmethod def allocate( - cls: Type[Arr], slc: Union[Interval5D, Shape5D], dtype: DTYPE, axiskeys: str = Point5D.LABELS, value: int = None + cls: Type[Arr], + slc: Union[Interval5D, Shape5D], + dtype: np.dtype, + axiskeys: str = Point5D.LABELS, + value: int = None, ) -> Arr: slc = slc.to_interval5d() if isinstance(slc, Shape5D) else slc assert sorted(axiskeys) == sorted(Point5D.LABELS) @@ -101,7 +91,7 @@ def allocate( @classmethod def allocate_like( - cls: Type[Arr], arr: "Array5D", dtype: Optional[DTYPE], axiskeys: str = "", value: int = None + cls: Type[Arr], arr: "Array5D", dtype: Optional[np.dtype], axiskeys: str = "", value: int = None ) -> Arr: return cls.allocate(arr.interval, dtype=dtype or arr.dtype, axiskeys=axiskeys or arr.axiskeys, value=value) @@ -263,7 +253,7 @@ def duplicate(self: Arr) -> Arr: def clamped( self: Arr, - interval: Union[Shape5D, Interval5D, None] = None, + limits: Union[Shape5D, Interval5D, None] = None, *, x: Optional[SPAN] = None, y: Optional[SPAN] = None, @@ -271,7 +261,7 @@ def clamped( t: Optional[SPAN] = None, c: Optional[SPAN] = None, ) -> Arr: - return self.cut(self.interval.clamped(interval, x=x, y=y, z=z, t=t, c=c)) + return self.cut(self.interval.clamped(limits, x=x, y=y, z=z, t=t, c=c)) @property def interval(self) -> Interval5D: diff --git a/ndstructs/point5D.py b/ndstructs/point5D.py index d398c66..35a62be 100644 --- a/ndstructs/point5D.py +++ b/ndstructs/point5D.py @@ -32,7 +32,6 @@ class Point5D(JsonSerializable): LABELS = "txyzc" # if you change this order, also change self._array order SPATIAL_LABELS = "xyz" LABEL_MAP = {label: index for index, label in enumerate(LABELS)} - DTYPE = np.float64 def __init__(self, *, t: int = 0, x: int = 0, y: int = 0, z: int = 0, c: int = 0): self.x = x @@ -58,7 +57,7 @@ def from_np(cls: Type[PT], arr: np.ndarray, labels: str) -> PT: def to_tuple(self, axis_order: str) -> Tuple[int, ...]: return tuple(self[label] for label in axis_order) - def to_dict(self) -> Dict[str, float]: + def to_dict(self) -> Dict[str, int]: return {k: self[k] for k in self.LABELS} def to_np(self, axis_order: str = LABELS) -> np.ndarray: @@ -247,11 +246,11 @@ def is_scalar(self) -> bool: return self.c == 1 @property - def volume(self) -> float: + def volume(self) -> int: return self.x * self.y * self.z @property - def hypervolume(self) -> float: + def hypervolume(self) -> int: return functools.reduce(operator.mul, self.to_tuple(Point5D.LABELS)) def to_interval5d(self, offset: Point5D = Point5D.zero()) -> "Interval5D": From 2e6a3df6b09417ec024895157781aad0a9b3b18e Mon Sep 17 00:00:00 2001 From: Tomaz Vieira Date: Fri, 11 Dec 2020 14:14:44 +0100 Subject: [PATCH 5/6] Fixes bad param name in N5DataSource --- ndstructs/datasource/N5DataSource.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ndstructs/datasource/N5DataSource.py b/ndstructs/datasource/N5DataSource.py index 3495794..15c8d15 100644 --- a/ndstructs/datasource/N5DataSource.py +++ b/ndstructs/datasource/N5DataSource.py @@ -126,7 +126,7 @@ def _get_tile(self, tile: Interval5D) -> Array5D: data=raw_tile, on_disk_axiskeys=f_axiskeys, dtype=self.dtype, compression_type=self.compression_type ) except ResourceNotFound as e: - tile_5d = self._allocate(roi=tile, fill_value=0) + tile_5d = self._allocate(interval=tile, fill_value=0) return tile_5d.translated(tile.start) def __getstate__(self) -> Dict[str, Any]: From 86893abfb4106cc4e914d57477e245fb17dde3a4 Mon Sep 17 00:00:00 2001 From: Tomaz Vieira Date: Fri, 5 Feb 2021 14:47:54 +0100 Subject: [PATCH 6/6] Renames parameters from slc to interval --- ndstructs/array5D.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/ndstructs/array5D.py b/ndstructs/array5D.py index 5f8299b..a7308cc 100644 --- a/ndstructs/array5D.py +++ b/ndstructs/array5D.py @@ -75,16 +75,16 @@ def __repr__(self) -> str: @classmethod def allocate( cls: Type[Arr], - slc: Union[Interval5D, Shape5D], + interval: Union[Interval5D, Shape5D], dtype: np.dtype, axiskeys: str = Point5D.LABELS, value: int = None, ) -> Arr: - slc = slc.to_interval5d() if isinstance(slc, Shape5D) else slc + interval = interval.to_interval5d() if isinstance(interval, Shape5D) else interval assert sorted(axiskeys) == sorted(Point5D.LABELS) - assert slc.shape.hypervolume != float("inf") - arr = np.empty(slc.shape.to_tuple(axiskeys), dtype=dtype) - arr = cls(arr, axiskeys, location=slc.start) + assert interval.shape.hypervolume != float("inf") + arr = np.empty(interval.shape.to_tuple(axiskeys), dtype=dtype) + arr = cls(arr, axiskeys, location=interval.start) if value is not None: arr._data[...] = value return arr @@ -327,8 +327,9 @@ def paint_point(self, point: Point5D, value: Number, local: bool = False): self._data[np_selection] = value def combine(self: Arr, others: Sequence[Arr]) -> Arr: + """Pastes self and others together into a single Array5D""" out_roi = Interval5D.enclosing([self.interval] + [o.interval for o in others]) - out = self.allocate(slc=out_roi, dtype=self.dtype, axiskeys=self.axiskeys, value=0) + out = self.allocate(interval=out_roi, dtype=self.dtype, axiskeys=self.axiskeys, value=0) out.set(self) for other in others: out.set(other)