Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 115 additions & 44 deletions microgen/rve.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
"""Representative Volume Element (RVE).

The ``Rve.box`` attribute is a :class:`~microgen.cad.CadShape` wrapping an
OCCT box; it is built lazily on first access and requires the ``[cad]``
extra (``cadquery-ocp-novtk``).
Frozen, immutable container for the RVE bounding box and its periodicity flags.

``Rve.box`` is a :class:`~microgen.cad.CadShape` wrapping an OCCT box; it is
built lazily on first access and requires the ``[cad]`` extra. ``Rve.grid``
returns a :class:`pyvista.StructuredGrid` aligned with the cell — used by
implicit shapes to sample SDF/level-set fields.
"""

from __future__ import annotations

from dataclasses import dataclass
from functools import cached_property
from typing import TYPE_CHECKING

import numpy as np
Expand All @@ -16,78 +21,144 @@
if TYPE_CHECKING:
from collections.abc import Sequence

from .cad import CadShape

Vector3DType = tuple[float, float, float] | Sequence[float]
import numpy.typing as npt
import pyvista as pv

from .cad import CadShape

Vector3DType = (
tuple[float, float, float] | Sequence[float] | npt.NDArray[np.float64]
)
ResolutionType = int | tuple[int, int, int] | Sequence[int]


def _validate_center(center: object) -> np.ndarray:
if isinstance(center, (tuple, list)) and len(center) == _DIM:
return np.asarray(center, dtype=float)
if isinstance(center, np.ndarray) and center.shape == (_DIM,):
return center.astype(float, copy=True)
err_msg = f"center must be an array or Sequence of length {_DIM}"
raise ValueError(err_msg)


def _validate_dim(dim: object) -> np.ndarray:
if isinstance(dim, (int, float)) and not isinstance(dim, bool):
arr = np.array([dim] * _DIM, dtype=float)
elif isinstance(dim, (tuple, list)) and len(dim) == _DIM:
arr = np.asarray(dim, dtype=float)
elif isinstance(dim, np.ndarray) and dim.shape == (_DIM,):
arr = dim.astype(float, copy=True)
else:
err_msg = f"dim must be an array or Sequence of length {_DIM}"
raise ValueError(err_msg)
if np.any(arr <= 0):
err_msg = f"dimensions of the RVE must be greater than 0, got {arr.tolist()}"
raise ValueError(err_msg)
return arr


def _validate_pbc(pbc: object) -> tuple[bool, bool, bool]:
if isinstance(pbc, bool):
return (pbc, pbc, pbc)
if isinstance(pbc, (tuple, list)) and len(pbc) == _DIM:
return (bool(pbc[0]), bool(pbc[1]), bool(pbc[2]))
err_msg = f"pbc must be a bool or Sequence[bool] of length {_DIM}"
raise ValueError(err_msg)


@dataclass(frozen=True, init=False, eq=False, repr=False)
class Rve:
"""Representative Volume Element (RVE).
"""Representative Volume Element (RVE) — frozen.

:param center: center of the RVE
:param dim: dimensions of the RVE
:param dim: dimensions of the RVE (scalar → cube)
:param pbc: periodic-boundary-condition flags per axis ``(x, y, z)``;
defaults to fully periodic. A single ``bool`` is broadcast to all axes.
"""

center: np.ndarray
dim: np.ndarray
pbc: tuple[bool, bool, bool]

def __init__(
self: Rve,
center: Vector3DType = (0, 0, 0),
dim: float | Vector3DType = 1,
pbc: bool | tuple[bool, bool, bool] | Sequence[bool] = (True, True, True),
) -> None:
"""Initialize the RVE."""
if isinstance(center, (tuple, list)) and len(center) == _DIM:
self.center = np.array(center)
elif isinstance(center, np.ndarray) and center.shape == (_DIM,):
self.center = center
else:
err_msg = f"center must be an array or Sequence of length {_DIM}"
raise ValueError(err_msg)

if isinstance(dim, (int, float)):
self.dim = np.array([dim for _ in range(_DIM)])
elif isinstance(dim, (tuple, list)) and len(dim) == _DIM:
self.dim = np.array(dim)
elif isinstance(dim, np.ndarray) and dim.shape == (_DIM,):
self.dim = dim
else:
err_msg = f"dim must be an array or Sequence of length {_DIM}"
raise ValueError(err_msg)
object.__setattr__(self, "center", _validate_center(center))
object.__setattr__(self, "dim", _validate_dim(dim))
object.__setattr__(self, "pbc", _validate_pbc(pbc))
Comment on lines +90 to +92
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't look good but there is probably no cleaner way to do this


@cached_property
def min_point(self: Rve) -> np.ndarray:
"""Min corner ``center - 0.5 * dim``."""
return self.center - 0.5 * self.dim

@cached_property
def max_point(self: Rve) -> np.ndarray:
"""Max corner ``center + 0.5 * dim``."""
return self.center + 0.5 * self.dim

@cached_property
def box(self: Rve) -> CadShape:
"""Return a :class:`~microgen.cad.CadShape` box of the RVE (cached).

if np.any(self.dim <= 0):
err_msg = f"dimensions of the RVE must be greater than 0, got {self.dim}"
raise ValueError(err_msg)
Requires the ``[cad]`` extra.
"""
from .cad import make_box # noqa: PLC0415

self.min_point = self.center - 0.5 * self.dim
self.max_point = self.center + 0.5 * self.dim
return make_box(tuple(self.dim), tuple(self.center))

self._cached_box: CadShape | None = None
def grid(self: Rve, resolution: ResolutionType) -> pv.StructuredGrid:
"""Return a structured grid aligned with the RVE.

@property
def box(self) -> CadShape:
"""Return a :class:`~microgen.cad.CadShape` box of the RVE (cached).

Requires the ``[cad]`` extra.
:param resolution: points per axis — int (broadcast to all 3 axes) or
length-3 sequence. The grid spans ``min_point`` → ``max_point``
with endpoints included on every axis.
"""
if self._cached_box is None:
from .cad import make_box # noqa: PLC0415
import pyvista as pv # noqa: PLC0415

if isinstance(resolution, int):
nx, ny, nz = resolution, resolution, resolution
elif isinstance(resolution, (tuple, list)) and len(resolution) == _DIM:
nx, ny, nz = (int(r) for r in resolution)
elif isinstance(resolution, np.ndarray) and resolution.shape == (_DIM,):
nx, ny, nz = (int(r) for r in resolution.tolist())
else:
err_msg = f"resolution must be an int or Sequence of length {_DIM}"
raise ValueError(err_msg)

self._cached_box = make_box(tuple(self.dim), tuple(self.center))
return self._cached_box
xi = np.linspace(self.min_point[0], self.max_point[0], nx)
yi = np.linspace(self.min_point[1], self.max_point[1], ny)
zi = np.linspace(self.min_point[2], self.max_point[2], nz)
x, y, z = np.meshgrid(xi, yi, zi, indexing="ij")
return pv.StructuredGrid(x, y, z)

@box.setter
def box(self, value: CadShape) -> None:
self._cached_box = value
def __repr__(self: Rve) -> str:
return (
f"Rve(center={tuple(self.center.tolist())}, "
f"dim={tuple(self.dim.tolist())}, pbc={self.pbc})"
)

@classmethod
def from_min_max(
cls: type[Rve],
min_point: Vector3DType = (-0.5, -0.5, -0.5),
max_point: Vector3DType = (0.5, 0.5, 0.5),
pbc: bool | tuple[bool, bool, bool] | Sequence[bool] = (True, True, True),
) -> Rve:
"""Generate a Rve from min and max corner points.

:param min_point: ``(x_min, y_min, z_min)`` corner of the RVE
:param max_point: ``(x_max, y_max, z_max)`` corner of the RVE
:param pbc: periodic-boundary-condition flags (see ``__init__``)
"""
lo = np.asarray(min_point, dtype=float)
hi = np.asarray(max_point, dtype=float)
return cls(center=tuple(0.5 * (lo + hi)), dim=tuple(np.abs(hi - lo)))
return cls(
center=tuple(0.5 * (lo + hi)),
dim=tuple(np.abs(hi - lo)),
pbc=pbc,
)
84 changes: 84 additions & 0 deletions tests/test_rve.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""Tests for Rve class."""

import dataclasses

import numpy as np
import pytest
import pyvista as pv

from microgen import Rve

Expand Down Expand Up @@ -80,3 +83,84 @@ def test_rve_from_min_max_must_return_expected_rve() -> None:
rve = Rve.from_min_max(min_point=(-1, 0, -3), max_point=(0, 2, 3))
assert np.all(rve.center == [-0.5, 1, 0])
assert np.all(rve.dim == [1, 2, 6])


def test_rve_pbc_default_is_fully_periodic() -> None:
"""``pbc`` defaults to ``(True, True, True)``."""
assert Rve().pbc == (True, True, True)


def test_rve_pbc_scalar_broadcasts_to_all_axes() -> None:
"""A single bool is broadcast to all three axes."""
assert Rve(pbc=False).pbc == (False, False, False)
assert Rve(pbc=True).pbc == (True, True, True)


def test_rve_pbc_accepts_per_axis_tuple() -> None:
"""A length-3 sequence sets each axis independently."""
assert Rve(pbc=(True, False, True)).pbc == (True, False, True)
assert Rve(pbc=[False, False, True]).pbc == (False, False, True)


def test_rve_invalid_pbc_raises() -> None:
"""``pbc`` of wrong length or type raises ``ValueError``."""
with pytest.raises(ValueError, match="pbc must be a bool"):
Rve(pbc=(True, False)) # type: ignore[arg-type]
with pytest.raises(ValueError, match="pbc must be a bool"):
Rve(pbc=1.5) # type: ignore[arg-type]


def test_rve_is_frozen() -> None:
"""Reassigning ``center``/``dim``/``pbc`` after construction raises FrozenInstanceError."""
rve = Rve(dim=2)
with pytest.raises(dataclasses.FrozenInstanceError):
rve.center = np.array([1.0, 2.0, 3.0]) # type: ignore[misc]
with pytest.raises(dataclasses.FrozenInstanceError):
rve.dim = np.array([3.0, 3.0, 3.0]) # type: ignore[misc]
with pytest.raises(dataclasses.FrozenInstanceError):
rve.pbc = (False, False, False) # type: ignore[misc]


def test_rve_min_max_points_are_consistent() -> None:
"""``min_point`` / ``max_point`` are derived from center ± 0.5·dim."""
rve = Rve(center=(1.0, 2.0, 3.0), dim=(2.0, 4.0, 6.0))
assert np.allclose(rve.min_point, [0.0, 0.0, 0.0])
assert np.allclose(rve.max_point, [2.0, 4.0, 6.0])


def test_rve_grid_with_scalar_resolution() -> None:
"""``Rve.grid(n)`` returns a StructuredGrid spanning the cell with n^3 points."""
rve = Rve(center=(0.0, 0.0, 0.0), dim=(2.0, 4.0, 6.0))
grid = rve.grid(5)
assert isinstance(grid, pv.StructuredGrid)
assert grid.dimensions == (5, 5, 5)
pts = np.asarray(grid.points)
assert np.isclose(pts[:, 0].min(), -1.0)
assert np.isclose(pts[:, 0].max(), 1.0)
assert np.isclose(pts[:, 1].min(), -2.0)
assert np.isclose(pts[:, 1].max(), 2.0)
assert np.isclose(pts[:, 2].min(), -3.0)
assert np.isclose(pts[:, 2].max(), 3.0)


def test_rve_grid_with_per_axis_resolution() -> None:
"""``Rve.grid((nx, ny, nz))`` honors per-axis resolution."""
rve = Rve(center=(0.0, 0.0, 0.0), dim=1.0)
grid = rve.grid((3, 4, 5))
assert grid.dimensions == (3, 4, 5)


def test_rve_grid_invalid_resolution_raises() -> None:
"""Wrong-length resolution raises ``ValueError``."""
rve = Rve()
with pytest.raises(ValueError, match="resolution must be an int"):
rve.grid((3, 4)) # type: ignore[arg-type]


def test_rve_repr_round_trips_inputs() -> None:
"""``repr(rve)`` contains the canonical ``center``/``dim``/``pbc`` triple."""
rve = Rve(center=(1.0, 2.0, 3.0), dim=(2.0, 2.0, 2.0), pbc=(True, False, True))
text = repr(rve)
assert "center=(1.0, 2.0, 3.0)" in text
assert "dim=(2.0, 2.0, 2.0)" in text
assert "pbc=(True, False, True)" in text
Loading