diff --git a/microgen/rve.py b/microgen/rve.py index 116692a6..a9c2f02a 100644 --- a/microgen/rve.py +++ b/microgen/rve.py @@ -1,12 +1,17 @@ """Representative Volume Element (RVE). -The ``Rve.box`` attribute is a :class:`~microgen.cad.CadShape` wrapping an -OCCT box; it is built lazily on first access and requires the ``[cad]`` -extra (``cadquery-ocp-novtk``). +Frozen, immutable container for the RVE bounding box and its periodicity flags. + +``Rve.box`` is a :class:`~microgen.cad.CadShape` wrapping an OCCT box; it is +built lazily on first access and requires the ``[cad]`` extra. ``Rve.grid`` +returns a :class:`pyvista.StructuredGrid` aligned with the cell — used by +implicit shapes to sample SDF/level-set fields. """ from __future__ import annotations +from dataclasses import dataclass +from functools import cached_property from typing import TYPE_CHECKING import numpy as np @@ -16,78 +21,144 @@ if TYPE_CHECKING: from collections.abc import Sequence - from .cad import CadShape - - Vector3DType = tuple[float, float, float] | Sequence[float] + import numpy.typing as npt + import pyvista as pv + from .cad import CadShape + Vector3DType = ( + tuple[float, float, float] | Sequence[float] | npt.NDArray[np.float64] + ) + ResolutionType = int | tuple[int, int, int] | Sequence[int] + + +def _validate_center(center: object) -> np.ndarray: + if isinstance(center, (tuple, list)) and len(center) == _DIM: + return np.asarray(center, dtype=float) + if isinstance(center, np.ndarray) and center.shape == (_DIM,): + return center.astype(float, copy=True) + err_msg = f"center must be an array or Sequence of length {_DIM}" + raise ValueError(err_msg) + + +def _validate_dim(dim: object) -> np.ndarray: + if isinstance(dim, (int, float)) and not isinstance(dim, bool): + arr = np.array([dim] * _DIM, dtype=float) + elif isinstance(dim, (tuple, list)) and len(dim) == _DIM: + arr = np.asarray(dim, dtype=float) + elif isinstance(dim, np.ndarray) and dim.shape == (_DIM,): + arr = dim.astype(float, copy=True) + else: + err_msg = f"dim must be an array or Sequence of length {_DIM}" + raise ValueError(err_msg) + if np.any(arr <= 0): + err_msg = f"dimensions of the RVE must be greater than 0, got {arr.tolist()}" + raise ValueError(err_msg) + return arr + + +def _validate_pbc(pbc: object) -> tuple[bool, bool, bool]: + if isinstance(pbc, bool): + return (pbc, pbc, pbc) + if isinstance(pbc, (tuple, list)) and len(pbc) == _DIM: + return (bool(pbc[0]), bool(pbc[1]), bool(pbc[2])) + err_msg = f"pbc must be a bool or Sequence[bool] of length {_DIM}" + raise ValueError(err_msg) + + +@dataclass(frozen=True, init=False, eq=False, repr=False) class Rve: - """Representative Volume Element (RVE). + """Representative Volume Element (RVE) — frozen. :param center: center of the RVE - :param dim: dimensions of the RVE + :param dim: dimensions of the RVE (scalar → cube) + :param pbc: periodic-boundary-condition flags per axis ``(x, y, z)``; + defaults to fully periodic. A single ``bool`` is broadcast to all axes. """ + center: np.ndarray + dim: np.ndarray + pbc: tuple[bool, bool, bool] + def __init__( self: Rve, center: Vector3DType = (0, 0, 0), dim: float | Vector3DType = 1, + pbc: bool | tuple[bool, bool, bool] | Sequence[bool] = (True, True, True), ) -> None: """Initialize the RVE.""" - if isinstance(center, (tuple, list)) and len(center) == _DIM: - self.center = np.array(center) - elif isinstance(center, np.ndarray) and center.shape == (_DIM,): - self.center = center - else: - err_msg = f"center must be an array or Sequence of length {_DIM}" - raise ValueError(err_msg) - - if isinstance(dim, (int, float)): - self.dim = np.array([dim for _ in range(_DIM)]) - elif isinstance(dim, (tuple, list)) and len(dim) == _DIM: - self.dim = np.array(dim) - elif isinstance(dim, np.ndarray) and dim.shape == (_DIM,): - self.dim = dim - else: - err_msg = f"dim must be an array or Sequence of length {_DIM}" - raise ValueError(err_msg) + object.__setattr__(self, "center", _validate_center(center)) + object.__setattr__(self, "dim", _validate_dim(dim)) + object.__setattr__(self, "pbc", _validate_pbc(pbc)) + + @cached_property + def min_point(self: Rve) -> np.ndarray: + """Min corner ``center - 0.5 * dim``.""" + return self.center - 0.5 * self.dim + + @cached_property + def max_point(self: Rve) -> np.ndarray: + """Max corner ``center + 0.5 * dim``.""" + return self.center + 0.5 * self.dim + + @cached_property + def box(self: Rve) -> CadShape: + """Return a :class:`~microgen.cad.CadShape` box of the RVE (cached). - if np.any(self.dim <= 0): - err_msg = f"dimensions of the RVE must be greater than 0, got {self.dim}" - raise ValueError(err_msg) + Requires the ``[cad]`` extra. + """ + from .cad import make_box # noqa: PLC0415 - self.min_point = self.center - 0.5 * self.dim - self.max_point = self.center + 0.5 * self.dim + return make_box(tuple(self.dim), tuple(self.center)) - self._cached_box: CadShape | None = None + def grid(self: Rve, resolution: ResolutionType) -> pv.StructuredGrid: + """Return a structured grid aligned with the RVE. - @property - def box(self) -> CadShape: - """Return a :class:`~microgen.cad.CadShape` box of the RVE (cached). - - Requires the ``[cad]`` extra. + :param resolution: points per axis — int (broadcast to all 3 axes) or + length-3 sequence. The grid spans ``min_point`` → ``max_point`` + with endpoints included on every axis. """ - if self._cached_box is None: - from .cad import make_box # noqa: PLC0415 + import pyvista as pv # noqa: PLC0415 + + if isinstance(resolution, int): + nx, ny, nz = resolution, resolution, resolution + elif isinstance(resolution, (tuple, list)) and len(resolution) == _DIM: + nx, ny, nz = (int(r) for r in resolution) + elif isinstance(resolution, np.ndarray) and resolution.shape == (_DIM,): + nx, ny, nz = (int(r) for r in resolution.tolist()) + else: + err_msg = f"resolution must be an int or Sequence of length {_DIM}" + raise ValueError(err_msg) - self._cached_box = make_box(tuple(self.dim), tuple(self.center)) - return self._cached_box + xi = np.linspace(self.min_point[0], self.max_point[0], nx) + yi = np.linspace(self.min_point[1], self.max_point[1], ny) + zi = np.linspace(self.min_point[2], self.max_point[2], nz) + x, y, z = np.meshgrid(xi, yi, zi, indexing="ij") + return pv.StructuredGrid(x, y, z) - @box.setter - def box(self, value: CadShape) -> None: - self._cached_box = value + def __repr__(self: Rve) -> str: + return ( + f"Rve(center={tuple(self.center.tolist())}, " + f"dim={tuple(self.dim.tolist())}, pbc={self.pbc})" + ) @classmethod def from_min_max( cls: type[Rve], min_point: Vector3DType = (-0.5, -0.5, -0.5), max_point: Vector3DType = (0.5, 0.5, 0.5), + pbc: bool | tuple[bool, bool, bool] | Sequence[bool] = (True, True, True), ) -> Rve: """Generate a Rve from min and max corner points. :param min_point: ``(x_min, y_min, z_min)`` corner of the RVE :param max_point: ``(x_max, y_max, z_max)`` corner of the RVE + :param pbc: periodic-boundary-condition flags (see ``__init__``) """ lo = np.asarray(min_point, dtype=float) hi = np.asarray(max_point, dtype=float) - return cls(center=tuple(0.5 * (lo + hi)), dim=tuple(np.abs(hi - lo))) + return cls( + center=tuple(0.5 * (lo + hi)), + dim=tuple(np.abs(hi - lo)), + pbc=pbc, + ) diff --git a/tests/test_rve.py b/tests/test_rve.py index 63dbdbf6..21397342 100644 --- a/tests/test_rve.py +++ b/tests/test_rve.py @@ -1,7 +1,10 @@ """Tests for Rve class.""" +import dataclasses + import numpy as np import pytest +import pyvista as pv from microgen import Rve @@ -80,3 +83,84 @@ def test_rve_from_min_max_must_return_expected_rve() -> None: rve = Rve.from_min_max(min_point=(-1, 0, -3), max_point=(0, 2, 3)) assert np.all(rve.center == [-0.5, 1, 0]) assert np.all(rve.dim == [1, 2, 6]) + + +def test_rve_pbc_default_is_fully_periodic() -> None: + """``pbc`` defaults to ``(True, True, True)``.""" + assert Rve().pbc == (True, True, True) + + +def test_rve_pbc_scalar_broadcasts_to_all_axes() -> None: + """A single bool is broadcast to all three axes.""" + assert Rve(pbc=False).pbc == (False, False, False) + assert Rve(pbc=True).pbc == (True, True, True) + + +def test_rve_pbc_accepts_per_axis_tuple() -> None: + """A length-3 sequence sets each axis independently.""" + assert Rve(pbc=(True, False, True)).pbc == (True, False, True) + assert Rve(pbc=[False, False, True]).pbc == (False, False, True) + + +def test_rve_invalid_pbc_raises() -> None: + """``pbc`` of wrong length or type raises ``ValueError``.""" + with pytest.raises(ValueError, match="pbc must be a bool"): + Rve(pbc=(True, False)) # type: ignore[arg-type] + with pytest.raises(ValueError, match="pbc must be a bool"): + Rve(pbc=1.5) # type: ignore[arg-type] + + +def test_rve_is_frozen() -> None: + """Reassigning ``center``/``dim``/``pbc`` after construction raises FrozenInstanceError.""" + rve = Rve(dim=2) + with pytest.raises(dataclasses.FrozenInstanceError): + rve.center = np.array([1.0, 2.0, 3.0]) # type: ignore[misc] + with pytest.raises(dataclasses.FrozenInstanceError): + rve.dim = np.array([3.0, 3.0, 3.0]) # type: ignore[misc] + with pytest.raises(dataclasses.FrozenInstanceError): + rve.pbc = (False, False, False) # type: ignore[misc] + + +def test_rve_min_max_points_are_consistent() -> None: + """``min_point`` / ``max_point`` are derived from center ± 0.5·dim.""" + rve = Rve(center=(1.0, 2.0, 3.0), dim=(2.0, 4.0, 6.0)) + assert np.allclose(rve.min_point, [0.0, 0.0, 0.0]) + assert np.allclose(rve.max_point, [2.0, 4.0, 6.0]) + + +def test_rve_grid_with_scalar_resolution() -> None: + """``Rve.grid(n)`` returns a StructuredGrid spanning the cell with n^3 points.""" + rve = Rve(center=(0.0, 0.0, 0.0), dim=(2.0, 4.0, 6.0)) + grid = rve.grid(5) + assert isinstance(grid, pv.StructuredGrid) + assert grid.dimensions == (5, 5, 5) + pts = np.asarray(grid.points) + assert np.isclose(pts[:, 0].min(), -1.0) + assert np.isclose(pts[:, 0].max(), 1.0) + assert np.isclose(pts[:, 1].min(), -2.0) + assert np.isclose(pts[:, 1].max(), 2.0) + assert np.isclose(pts[:, 2].min(), -3.0) + assert np.isclose(pts[:, 2].max(), 3.0) + + +def test_rve_grid_with_per_axis_resolution() -> None: + """``Rve.grid((nx, ny, nz))`` honors per-axis resolution.""" + rve = Rve(center=(0.0, 0.0, 0.0), dim=1.0) + grid = rve.grid((3, 4, 5)) + assert grid.dimensions == (3, 4, 5) + + +def test_rve_grid_invalid_resolution_raises() -> None: + """Wrong-length resolution raises ``ValueError``.""" + rve = Rve() + with pytest.raises(ValueError, match="resolution must be an int"): + rve.grid((3, 4)) # type: ignore[arg-type] + + +def test_rve_repr_round_trips_inputs() -> None: + """``repr(rve)`` contains the canonical ``center``/``dim``/``pbc`` triple.""" + rve = Rve(center=(1.0, 2.0, 3.0), dim=(2.0, 2.0, 2.0), pbc=(True, False, True)) + text = repr(rve) + assert "center=(1.0, 2.0, 3.0)" in text + assert "dim=(2.0, 2.0, 2.0)" in text + assert "pbc=(True, False, True)" in text