diff --git a/src/ndv/models/_data_wrapper.py b/src/ndv/models/_data_wrapper.py index 0ee4f48a..91112328 100644 --- a/src/ndv/models/_data_wrapper.py +++ b/src/ndv/models/_data_wrapper.py @@ -209,7 +209,7 @@ def create(cls, data: ArrayT) -> DataWrapper[ArrayT]: for subclass in sorted(_recurse_subclasses(cls), key=lambda x: x.PRIORITY): try: if subclass.supports(data): - logger.debug(f"Using {subclass.__name__} to wrap {type(data)}") + logger.debug("Using %s to wrap %s", subclass.__name__, type(data)) return subclass(data) except Exception as e: warnings.warn( diff --git a/src/ndv/views/_pygfx/_array_canvas.py b/src/ndv/views/_pygfx/_array_canvas.py index d7968623..666a43e8 100755 --- a/src/ndv/views/_pygfx/_array_canvas.py +++ b/src/ndv/views/_pygfx/_array_canvas.py @@ -1,7 +1,8 @@ from __future__ import annotations from contextlib import suppress -from typing import TYPE_CHECKING, Any, Literal, cast +from functools import lru_cache +from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast from weakref import ReferenceType, WeakValueDictionary, ref import cmap as _cmap @@ -18,6 +19,7 @@ ) from ndv.models._viewer_model import ArrayViewerModel, InteractionMode from ndv.views._app import filter_mouse_events +from ndv.views._util import downsample_data from ndv.views.bases import ArrayCanvas, CanvasElement, ImageHandle from ndv.views.bases._graphics._canvas_elements import RectangularROIHandle, ROIMoveMode @@ -63,11 +65,20 @@ def __init__(self, image: pygfx.Image | pygfx.Volume, render: Callable) -> None: self._render = render self._grid = cast("Texture", image.geometry.grid) self._material = cast("ImageBasicMaterial", image.material) + # per-axis downsample strides applied to fit GPU texture limits + self._downsample_factors: tuple[int, ...] = () def data(self) -> np.ndarray: return self._grid.data # type: ignore [no-any-return] def set_data(self, data: np.ndarray) -> None: + is_three_d = isinstance(self._image, pygfx.Volume) + data, self._downsample_factors = _downcast_and_downsample( + data, + three_d=is_three_d, + warn=False, + copy=False, + ) # If dimensions are unchanged, reuse the buffer if data.shape == self._grid.data.shape: self._grid.data[:] = data # pyright: ignore[reportOptionalSubscript] @@ -75,10 +86,12 @@ def set_data(self, data: np.ndarray) -> None: # Otherwise, the size (and maybe number of dimensions) changed # - we need a new buffer else: - self._grid = pygfx.Texture(data, dim=2) + dim = 3 if is_three_d else 2 + self._grid = pygfx.Texture(data, dim=dim) self._image.geometry = pygfx.Geometry(grid=self._grid) # RGB images (i.e. 3D datasets) cannot have a colormap - self._material.map = None if self._is_rgb() else self._cmap.to_pygfx() + if not is_three_d: + self._material.map = None if self._is_rgb() else self._cmap.to_pygfx() def visible(self) -> bool: return bool(self._image.visible) @@ -465,11 +478,7 @@ def set_ndim(self, ndim: Literal[2, 3]) -> None: def add_image(self, data: np.ndarray | None = None) -> PyGFXImageHandle: """Add a new Image node to the scene.""" - if data is not None: - # pygfx uses a view of the data without copy, so if we don't - # copy it here, the original data will be modified when the - # texture changes. - data = data.copy() + data, downsample_factors = _downcast_and_downsample(data, three_d=False) tex = pygfx.Texture(data, dim=2) image = pygfx.Image( pygfx.Geometry(grid=tex), @@ -485,15 +494,12 @@ def add_image(self, data: np.ndarray | None = None) -> PyGFXImageHandle: # FIXME: I suspect there are more performant ways to refresh the canvas # look into it. handle = PyGFXImageHandle(image, self.refresh) + handle._downsample_factors = downsample_factors self._elements[image] = handle return handle def add_volume(self, data: np.ndarray | None = None) -> PyGFXImageHandle: - if data is not None: - # pygfx uses a view of the data without copy, so if we don't - # copy it here, the original data will be modified when the - # texture changes. - data = data.copy() + data, downsample_factors = _downcast_and_downsample(data, three_d=True) tex = pygfx.Texture(data, dim=3) vol = pygfx.Volume( pygfx.Geometry(grid=tex), @@ -512,6 +518,7 @@ def add_volume(self, data: np.ndarray | None = None) -> PyGFXImageHandle: # FIXME: I suspect there are more performant ways to refresh the canvas # look into it. handle = PyGFXImageHandle(vol, self.refresh) + handle._downsample_factors = downsample_factors self._elements[vol] = handle return handle @@ -539,10 +546,23 @@ def set_scales(self, scales: tuple[float, ...]) -> None: gfx_scales.append(1.0) sx, sy, sz = gfx_scales[0], gfx_scales[1], gfx_scales[2] has_visuals = False - for child in self._scene.children: - if isinstance(child, (pygfx.Image, pygfx.Volume)): - child.local.scale = (sx, sy, sz) - has_visuals = True + for handle in self._elements.values(): + if not isinstance(handle, PyGFXImageHandle): + continue + child = handle._image + if not isinstance(child, (pygfx.Image, pygfx.Volume)): + continue + _sx, _sy, _sz = sx, sy, sz + # compensate for downsampling so coordinates stay correct + # factors are in data order; pygfx order is (x, y, z) = reversed + factors = handle._downsample_factors + if factors and any(f > 1 for f in factors): + rev = list(reversed(factors)) + _sx *= rev[0] + _sy *= rev[1] if len(rev) > 1 else 1 + _sz *= rev[2] if len(rev) > 2 else 1 + child.local.scale = (_sx, _sy, _sz) + has_visuals = True if has_visuals: self.set_range() @@ -710,3 +730,37 @@ def get_cursor(self, event: MouseMoveEvent) -> CursorType: if cursor := vis.get_cursor(event): return cursor return CursorType.DEFAULT + + +T = TypeVar("T", bound=np.ndarray | None) + + +@lru_cache(maxsize=1) +def _get_max_texture_sizes() -> tuple[int | None, int | None]: + """Return (max_2d, max_3d) texture dimensions from the wgpu adapter.""" + try: + import wgpu + + adapter = wgpu.gpu.request_adapter_sync() + limits = adapter.limits + max_2d = limits.get("max-texture-dimension-2d") + max_3d = limits.get("max-texture-dimension-3d") + return max_2d, max_3d + except Exception: + return None, None + + +def _downcast_and_downsample( + data: T, three_d: bool, *, warn: bool = True, copy: bool = True +) -> tuple[T, tuple[int, ...]]: + downsample_factors: tuple[int, ...] = () + if data is not None: + if copy: + # pygfx uses a view of the data without copy, so if we don't + # copy it here, the original data will be modified when the + # texture changes. + data = data.copy() + maxd = _get_max_texture_sizes()[1 if three_d else 0] + if maxd is not None: + data, downsample_factors = downsample_data(data, maxd, warn=warn) # type: ignore[assignment] + return data, downsample_factors # pyright: ignore[reportReturnType] diff --git a/src/ndv/views/_util.py b/src/ndv/views/_util.py new file mode 100644 index 00000000..73cbc96a --- /dev/null +++ b/src/ndv/views/_util.py @@ -0,0 +1,33 @@ +"""Shared utilities for canvas backends.""" + +from __future__ import annotations + +import logging + +import numpy as np + +logger = logging.getLogger("ndv") + + +def downsample_data( + data: np.ndarray, max_size: int, *, warn: bool = True +) -> tuple[np.ndarray, tuple[int, ...]]: + """Downsample data so no axis exceeds max_size. + + Returns the (possibly downsampled view) data and the per-axis stride factors. + """ + factors = tuple( + int(np.ceil(s / max_size)) if s > max_size else 1 for s in data.shape + ) + if any(f > 1 for f in factors): + if warn: + logger.warning( + "Data shape %s exceeds max texture dimension (%d) and will be " + "downsampled for rendering (strides: %s).", + data.shape, + max_size, + factors, + ) + slices = tuple(slice(None, None, f) for f in factors) + data = data[slices] + return data, factors diff --git a/src/ndv/views/_vispy/_array_canvas.py b/src/ndv/views/_vispy/_array_canvas.py index fc5eedb9..efed67cd 100755 --- a/src/ndv/views/_vispy/_array_canvas.py +++ b/src/ndv/views/_vispy/_array_canvas.py @@ -3,7 +3,7 @@ import warnings from contextlib import suppress -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast from weakref import ReferenceType, WeakValueDictionary import cmap as _cmap @@ -23,6 +23,8 @@ ) from ndv.models._viewer_model import ArrayViewerModel, InteractionMode from ndv.views._app import filter_mouse_events +from ndv.views._util import downsample_data +from ndv.views._vispy._util import get_max_texture_sizes from ndv.views.bases import ArrayCanvas from ndv.views.bases._graphics._canvas_elements import ( CanvasElement, @@ -43,6 +45,8 @@ class VispyImageHandle(ImageHandle): def __init__(self, visual: visuals.ImageVisual | visuals.VolumeVisual) -> None: self._visual = visual self._allowed_dims = {2, 3} if isinstance(visual, visuals.ImageVisual) else {3} + # per-axis downsample strides applied to fit GPU texture limits + self._downsample_factors: tuple[int, ...] = () def data(self) -> np.ndarray: try: @@ -58,6 +62,13 @@ def set_data(self, data: np.ndarray) -> None: stacklevel=2, ) return + + data, downsample_factors = _downcast_and_downsample( + data, + three_d=isinstance(self._visual, visuals.VolumeVisual), + warn=False, + ) + self._downsample_factors = downsample_factors self._visual.set_data(data) def visible(self) -> bool: @@ -358,7 +369,7 @@ def refresh(self) -> None: def add_image(self, data: np.ndarray | None = None) -> VispyImageHandle: """Add a new Image node to the scene.""" - data = _downcast(data) + data, downsample_factors = _downcast_and_downsample(data, three_d=False) try: img = scene.visuals.Image( data, parent=self._view.scene, texture_format="auto" @@ -370,13 +381,14 @@ def add_image(self, data: np.ndarray | None = None) -> VispyImageHandle: img.set_gl_state("additive", depth_test=False) img.interactive = True handle = VispyImageHandle(img) + handle._downsample_factors = downsample_factors self._elements[img] = handle if data is not None: self.set_range() return handle def add_volume(self, data: np.ndarray | None = None) -> VispyImageHandle: - data = _downcast(data) + data, downsample_factors = _downcast_and_downsample(data, three_d=True) try: vol = scene.visuals.Volume( data, @@ -393,6 +405,7 @@ def add_volume(self, data: np.ndarray | None = None) -> VispyImageHandle: vol.set_gl_state("additive", depth_test=False) vol.interactive = True handle = VispyImageHandle(vol) + handle._downsample_factors = downsample_factors self._elements[vol] = handle if data is not None: self.set_range() @@ -418,11 +431,24 @@ def set_scales(self, scales: tuple[float, ...]) -> None: while len(vis_scales) < 3: vis_scales.append(1.0) sx, sy, sz = vis_scales[0], vis_scales[1], vis_scales[2] - for child in self._view.scene.children: - if isinstance(child, (visuals.ImageVisual, visuals.VolumeVisual)): - child.transform = vispy.visuals.transforms.STTransform( - scale=(sx, sy, sz) - ) + for handle in self._elements.values(): + if not isinstance(handle, VispyImageHandle): + continue + child = handle._visual + if not isinstance(child, (visuals.ImageVisual, visuals.VolumeVisual)): + continue + _sx, _sy, _sz = sx, sy, sz + # compensate for downsampling so coordinates stay correct + # factors are in data order; scene order is (x, y, z) = reversed + factors = handle._downsample_factors + if factors and any(f > 1 for f in factors): + rev = list(reversed(factors)) + _sx *= rev[0] + _sy *= rev[1] if len(rev) > 1 else 1 + _sz *= rev[2] if len(rev) > 2 else 1 + child.transform = vispy.visuals.transforms.STTransform( + scale=(_sx, _sy, _sz) + ) self.set_range() def set_range( @@ -561,13 +587,29 @@ def get_cursor(self, event: MouseMoveEvent) -> CursorType: return CursorType.DEFAULT -def _downcast(data: np.ndarray | None) -> np.ndarray | None: +T = TypeVar("T", bound="np.ndarray | None") + + +def _downcast(data: T) -> T: """Downcast >32bit data to 32bit.""" # downcast to 32bit, preserving int/float if data is not None: if np.issubdtype(data.dtype, np.integer) and data.dtype.itemsize > 2: warnings.warn("Downcasting integer data to uint16.", stacklevel=2) - data = data.astype(np.uint16) + data = data.astype(np.uint16) # type: ignore[assignment] elif np.issubdtype(data.dtype, np.floating) and data.dtype.itemsize > 4: - data = data.astype(np.float32) + data = data.astype(np.float32) # type: ignore[assignment] return data + + +def _downcast_and_downsample( + data: T, three_d: bool, warn: bool = True +) -> tuple[T, tuple[int, ...]]: + """Downcast >32bit data to 32bit, and downsample GPU texture limits are exceeded.""" + data = _downcast(data) + downsample_factors: tuple[int, ...] = () + if data is not None: + maxd = get_max_texture_sizes()[1 if three_d else 0] + if maxd is not None: + data, downsample_factors = downsample_data(data, maxd, warn=warn) # type: ignore[assignment] + return data, downsample_factors diff --git a/src/ndv/views/_vispy/_util.py b/src/ndv/views/_vispy/_util.py new file mode 100644 index 00000000..ab7ead76 --- /dev/null +++ b/src/ndv/views/_vispy/_util.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from contextlib import contextmanager +from functools import lru_cache +from typing import TYPE_CHECKING + +from vispy.app import Canvas +from vispy.gloo import gl +from vispy.gloo.context import get_current_canvas + +if TYPE_CHECKING: + from collections.abc import Generator + + +@contextmanager +def _opengl_context() -> Generator[None, None, None]: + """Assure we are running with a valid OpenGL context. + + Only create a Canvas is one doesn't exist. Creating and closing a + Canvas causes vispy to process Qt events which can cause problems. + """ + canvas = Canvas(show=False) if get_current_canvas() is None else None + try: + yield + finally: + if canvas is not None: + canvas.close() + + +@lru_cache +def get_max_texture_sizes() -> tuple[int | None, int | None]: + """Return the maximum texture sizes for 2D and 3D rendering. + + Returns + ------- + Tuple[int | None, int | None] + The max textures sizes for (2d, 3d) rendering. + """ + with _opengl_context(): + max_size_2d = gl.glGetParameter(gl.GL_MAX_TEXTURE_SIZE) + + if not max_size_2d: + max_size_2d = None + + # vispy/gloo doesn't provide the GL_MAX_3D_TEXTURE_SIZE location, + # but it can be found in this list of constants + # http://pyopengl.sourceforge.net/documentation/pydoc/OpenGL.GL.html + with _opengl_context(): + GL_MAX_3D_TEXTURE_SIZE = 32883 + max_size_3d = gl.glGetParameter(GL_MAX_3D_TEXTURE_SIZE) + + if not max_size_3d: + max_size_3d = None + + return max_size_2d, max_size_3d diff --git a/tests/test_downsample_data.py b/tests/test_downsample_data.py new file mode 100644 index 00000000..24dea13a --- /dev/null +++ b/tests/test_downsample_data.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import numpy as np + +from ndv.views._util import downsample_data + + +def test_no_downsample_when_within_limit() -> None: + data = np.zeros((50, 60), dtype=np.float32) + result, factors = downsample_data(data, 64) + assert factors == (1, 1) + assert result is data # no copy, same object + + +def test_exact_boundary_no_downsample() -> None: + data = np.zeros((64, 64, 64), dtype=np.float32) + result, factors = downsample_data(data, 64) + assert factors == (1, 1, 1) + assert result is data + + +def test_single_axis_overflow() -> None: + data = np.zeros((10, 100), dtype=np.float32) + result, factors = downsample_data(data, 64) + assert factors == (1, 2) + assert result.shape == (10, 50) + + +def test_all_axes_overflow() -> None: + data = np.zeros((200, 150, 130), dtype=np.float32) + result, factors = downsample_data(data, 64) + assert factors == (4, 3, 3) + assert result.shape == (50, 50, 44) + + +def test_2d_input() -> None: + data = np.zeros((100, 80), dtype=np.float32) + result, factors = downsample_data(data, 64) + assert factors == (2, 2) + assert result.shape == (50, 40) + + +def test_4d_input() -> None: + data = np.zeros((10, 100, 80, 3), dtype=np.float32) + result, factors = downsample_data(data, 64) + assert len(factors) == 4 + assert factors == (1, 2, 2, 1) + assert result.shape == (10, 50, 40, 3) + + +def test_returns_view_not_copy() -> None: + """Downsampled result should be a view (no memory copy).""" + data = np.ones((100, 100), dtype=np.float32) + result, _ = downsample_data(data, 64) + assert result.base is data + + +def test_max_size_1() -> None: + data = np.zeros((5, 3), dtype=np.float32) + result, factors = downsample_data(data, 1) + assert factors == (5, 3) + assert result.shape == (1, 1) diff --git a/tests/views/_pygfx/test_volume_downsample.py b/tests/views/_pygfx/test_volume_downsample.py new file mode 100644 index 00000000..d4730c57 --- /dev/null +++ b/tests/views/_pygfx/test_volume_downsample.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +from unittest.mock import patch + +import numpy as np +import pygfx +import pytest + +from ndv.models._viewer_model import ArrayViewerModel +from ndv.views._pygfx._array_canvas import GfxArrayCanvas + +PATCH_TARGET = "ndv.views._pygfx._array_canvas._get_max_texture_sizes" + + +def _force_canvas_size(canvas: GfxArrayCanvas, w: int = 600, h: int = 600) -> None: + rc = canvas._canvas + rc._size_info.set_physical_size(w, h, 1.0) + + +@pytest.mark.usefixtures("any_app") +def test_volume_downsampled_when_exceeding_texture_limit() -> None: + """Volume data should be stride-downsampled to fit GPU texture limits.""" + canvas = GfxArrayCanvas(ArrayViewerModel()) + _force_canvas_size(canvas) + canvas.set_ndim(3) + + data = np.zeros((10, 100, 100), dtype=np.float32) + with patch(PATCH_TARGET, return_value=(None, 64)): + handle = canvas.add_volume(data) + + # shape (10, 100, 100) with max 64 -> strides (1, 2, 2) + assert handle._downsample_factors == (1, 2, 2) + assert handle.data().shape == (10, 50, 50) + + # set_data with the same original shape should also downsample + with patch(PATCH_TARGET, return_value=(None, 64)): + handle.set_data(data) + assert handle.data().shape == (10, 50, 50) + + canvas.close() + + +@pytest.mark.usefixtures("any_app") +def test_volume_no_downsample_when_within_limit() -> None: + """Volume data within texture limits should not be downsampled.""" + canvas = GfxArrayCanvas(ArrayViewerModel()) + _force_canvas_size(canvas) + canvas.set_ndim(3) + + data = np.zeros((10, 50, 50), dtype=np.float32) + with patch(PATCH_TARGET, return_value=(None, 64)): + handle = canvas.add_volume(data) + + assert handle._downsample_factors == (1, 1, 1) + assert handle.data().shape == (10, 50, 50) + + canvas.close() + + +@pytest.mark.usefixtures("any_app") +def test_image_downsampled_when_exceeding_2d_texture_limit() -> None: + """2D image data should be stride-downsampled to fit GPU texture limits.""" + canvas = GfxArrayCanvas(ArrayViewerModel()) + _force_canvas_size(canvas) + canvas.set_ndim(2) + + data = np.zeros((100, 100), dtype=np.float32) + with patch(PATCH_TARGET, return_value=(64, None)): + handle = canvas.add_image(data) + + assert handle._downsample_factors == (2, 2) + assert handle.data().shape == (50, 50) + + # set_data should also downsample + with patch(PATCH_TARGET, return_value=(64, None)): + handle.set_data(data) + assert handle.data().shape == (50, 50) + + canvas.close() + + +@pytest.mark.usefixtures("any_app") +def test_set_scales_compensates_for_volume_downsample() -> None: + """set_scales should multiply by downsample factors so world coords stay correct.""" + canvas = GfxArrayCanvas(ArrayViewerModel()) + _force_canvas_size(canvas) + canvas.set_ndim(3) + + data = np.zeros((10, 100, 100), dtype=np.float32) + with patch(PATCH_TARGET, return_value=(None, 64)): + handle = canvas.add_volume(data) + + assert handle._downsample_factors == (1, 2, 2) + + # scales in data order (Z, Y, X) = (0.4, 0.2, 0.2) + with patch(PATCH_TARGET, return_value=(None, 64)): + canvas.set_scales((0.4, 0.2, 0.2)) + + vol = handle._image + assert isinstance(vol, pygfx.Volume) + sx, sy, sz = vol.local.scale + # pygfx order is (x=W, y=H, z=D), reversed from data order + # x scale: 0.2 (X) * 2 (fw) = 0.4 + # y scale: 0.2 (Y) * 2 (fh) = 0.4 + # z scale: 0.4 (Z) * 1 (fd) = 0.4 + assert sx == pytest.approx(0.4) + assert sy == pytest.approx(0.4) + assert sz == pytest.approx(0.4) + + canvas.close() + + +@pytest.mark.usefixtures("any_app") +def test_no_downsample_when_limits_none() -> None: + """When GPU limits are unavailable, data should pass through unchanged.""" + canvas = GfxArrayCanvas(ArrayViewerModel()) + _force_canvas_size(canvas) + canvas.set_ndim(3) + + data = np.zeros((10, 100, 100), dtype=np.float32) + with patch(PATCH_TARGET, return_value=(None, None)): + handle = canvas.add_volume(data) + + assert handle._downsample_factors == () + assert handle.data().shape == (10, 100, 100) + + canvas.close() + + +@pytest.mark.usefixtures("any_app") +def test_set_data_with_different_shape() -> None: + """set_data with a new shape should re-downsample correctly.""" + canvas = GfxArrayCanvas(ArrayViewerModel()) + _force_canvas_size(canvas) + canvas.set_ndim(3) + + data1 = np.zeros((10, 100, 100), dtype=np.float32) + with patch(PATCH_TARGET, return_value=(None, 64)): + handle = canvas.add_volume(data1) + assert handle.data().shape == (10, 50, 50) + + # now set_data with a larger volume + data2 = np.zeros((200, 100, 100), dtype=np.float32) + with patch(PATCH_TARGET, return_value=(None, 64)): + handle.set_data(data2) + assert handle._downsample_factors == (4, 2, 2) + assert handle.data().shape == (50, 50, 50) + + canvas.close() diff --git a/tests/views/_vispy/test_volume_downsample.py b/tests/views/_vispy/test_volume_downsample.py new file mode 100644 index 00000000..19ea0ec4 --- /dev/null +++ b/tests/views/_vispy/test_volume_downsample.py @@ -0,0 +1,168 @@ +from __future__ import annotations + +from unittest.mock import patch + +import numpy as np +import pytest +import vispy.visuals.transforms + +from ndv.models._viewer_model import ArrayViewerModel +from ndv.views._vispy._array_canvas import VispyArrayCanvas + +PATCH_TARGET = "ndv.views._vispy._array_canvas.get_max_texture_sizes" + + +@pytest.mark.usefixtures("any_app") +def test_volume_downsampled_when_exceeding_texture_limit() -> None: + """Volume data should be stride-downsampled to fit GPU texture limits.""" + canvas = VispyArrayCanvas(ArrayViewerModel()) + canvas.set_ndim(3) + + data = np.zeros((10, 100, 100), dtype=np.float32) + with patch(PATCH_TARGET, return_value=(None, 64)): + handle = canvas.add_volume(data) + + # shape (10, 100, 100) with max 64 -> strides (1, 2, 2) + assert handle._downsample_factors == (1, 2, 2) + assert handle.data().shape == (10, 50, 50) + + # set_data with the same original shape should also downsample + with patch(PATCH_TARGET, return_value=(None, 64)): + handle.set_data(data) + assert handle.data().shape == (10, 50, 50) + + canvas.close() + + +@pytest.mark.usefixtures("any_app") +def test_volume_no_downsample_when_within_limit() -> None: + """Volume data within texture limits should not be downsampled.""" + canvas = VispyArrayCanvas(ArrayViewerModel()) + canvas.set_ndim(3) + + data = np.zeros((10, 50, 50), dtype=np.float32) + with patch(PATCH_TARGET, return_value=(None, 64)): + handle = canvas.add_volume(data) + + assert handle._downsample_factors == (1, 1, 1) + assert handle.data().shape == (10, 50, 50) + + canvas.close() + + +@pytest.mark.usefixtures("any_app") +def test_image_downsampled_when_exceeding_2d_texture_limit() -> None: + """2D image data should be stride-downsampled to fit GPU texture limits.""" + canvas = VispyArrayCanvas(ArrayViewerModel()) + canvas.set_ndim(2) + + data = np.zeros((100, 100), dtype=np.float32) + with patch(PATCH_TARGET, return_value=(64, None)): + handle = canvas.add_image(data) + + assert handle._downsample_factors == (2, 2) + assert handle.data().shape == (50, 50) + + # set_data should also downsample + with patch(PATCH_TARGET, return_value=(64, None)): + handle.set_data(data) + assert handle.data().shape == (50, 50) + + canvas.close() + + +@pytest.mark.usefixtures("any_app") +def test_set_scales_compensates_for_volume_downsample() -> None: + """set_scales should multiply by downsample factors so world coords stay correct.""" + canvas = VispyArrayCanvas(ArrayViewerModel()) + canvas.set_ndim(3) + + # original shape (400, 2200, 2200), factors (1, 2, 2) + data = np.zeros((10, 100, 100), dtype=np.float32) + with patch(PATCH_TARGET, return_value=(None, 64)): + handle = canvas.add_volume(data) + + assert handle._downsample_factors == (1, 2, 2) + + # scales in data order (Z, Y, X) = (0.4, 0.2, 0.2) + with patch(PATCH_TARGET, return_value=(None, 64)): + canvas.set_scales((0.4, 0.2, 0.2)) + + tform = handle._visual.transform + assert isinstance(tform, vispy.visuals.transforms.STTransform) + sx, sy, sz = tform.scale[:3] + # scene order is (x=W, y=H, z=D), reversed from data order + # x scale: 0.2 (X) * 2 (fw) = 0.4 + # y scale: 0.2 (Y) * 2 (fh) = 0.4 + # z scale: 0.4 (Z) * 1 (fd) = 0.4 + assert sx == pytest.approx(0.4) + assert sy == pytest.approx(0.4) + assert sz == pytest.approx(0.4) + + canvas.close() + + +@pytest.mark.usefixtures("any_app") +def test_set_range_correct_bounds_after_downsample() -> None: + """set_range should compute world bounds as if data were full-resolution.""" + canvas = VispyArrayCanvas(ArrayViewerModel()) + canvas.set_ndim(3) + + # shape (10, 100, 80) with max 64 -> factors (1, 2, 2) + # downsampled shape: (10, 50, 40) + data = np.zeros((10, 100, 80), dtype=np.float32) + with patch(PATCH_TARGET, return_value=(None, 64)): + handle = canvas.add_volume(data) + canvas.set_scales((1.0, 1.0, 1.0)) + + # After set_scales with (1,1,1), the transform should be (2, 2, 1) + # set_range reads downsampled shape and multiplies by transform scale: + # x = shape[2] * sx = 40 * 2 = 80 (matches original W) + # y = shape[1] * sy = 50 * 2 = 100 (matches original H) + # z = shape[0] * sz = 10 * 1 = 10 (matches original D) + tform = handle._visual.transform + assert isinstance(tform, vispy.visuals.transforms.STTransform) + ds_shape = handle.data().shape + sx, sy, sz = tform.scale[:3] + assert ds_shape[2] * sx == pytest.approx(80.0) + assert ds_shape[1] * sy == pytest.approx(100.0) + assert ds_shape[0] * sz == pytest.approx(10.0) + + canvas.close() + + +@pytest.mark.usefixtures("any_app") +def test_no_downsample_when_limits_none() -> None: + """When GPU limits are unavailable, data should pass through unchanged.""" + canvas = VispyArrayCanvas(ArrayViewerModel()) + canvas.set_ndim(3) + + data = np.zeros((10, 100, 100), dtype=np.float32) + with patch(PATCH_TARGET, return_value=(None, None)): + handle = canvas.add_volume(data) + + assert handle._downsample_factors == () + assert handle.data().shape == (10, 100, 100) + + canvas.close() + + +@pytest.mark.usefixtures("any_app") +def test_set_data_with_different_shape() -> None: + """set_data with a new shape should re-downsample correctly.""" + canvas = VispyArrayCanvas(ArrayViewerModel()) + canvas.set_ndim(3) + + data1 = np.zeros((10, 100, 100), dtype=np.float32) + with patch(PATCH_TARGET, return_value=(None, 64)): + handle = canvas.add_volume(data1) + assert handle.data().shape == (10, 50, 50) + + # now set_data with a larger volume + data2 = np.zeros((200, 100, 100), dtype=np.float32) + with patch(PATCH_TARGET, return_value=(None, 64)): + handle.set_data(data2) + assert handle._downsample_factors == (4, 2, 2) + assert handle.data().shape == (50, 50, 50) + + canvas.close()