Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
336 changes: 332 additions & 4 deletions src/underworld3/function/functions_unit_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@
import underworld3 as uw


@uw.timing.routine_timer_decorator
def evaluate(
def _evaluate_impl(
expr,
coords,
coord_sys=None,
Expand All @@ -49,6 +48,11 @@ def evaluate(
"""
Evaluate expression at coordinates with automatic unit handling.

Internal implementation of :func:`evaluate`. The public wrapper
delegates here and then optionally applies the ``monotone``
bounded-interpolation post-process. This body is unchanged from the
historical ``evaluate`` so that ``monotone=False`` is bit-identical.

This function wraps the Cython evaluate_nd implementation to automatically
handle unit conversions and return unit-aware results.

Expand Down Expand Up @@ -369,8 +373,7 @@ def evaluate(
return raw_values


@uw.timing.routine_timer_decorator
def global_evaluate(
def _global_evaluate_impl(
expr,
coords=None,
coord_sys=None,
Expand All @@ -389,6 +392,12 @@ def global_evaluate(
"""
Global evaluate with automatic unit-aware results.

Internal implementation of :func:`global_evaluate`. The public
wrapper delegates here and then optionally applies the ``monotone``
bounded-interpolation post-process. This body is unchanged from the
historical ``global_evaluate`` so that ``monotone=False`` is
bit-identical.

Similar to evaluate() but performs global evaluation across all processes.
Returns unit-aware objects when expression has units.

Expand Down Expand Up @@ -583,3 +592,322 @@ def global_evaluate(
else:
# Array result - wrap as UnitAwareArray
return UnitAwareArray(raw_result, units=result_units)


# ---------------------------------------------------------------------------
# Monotone (bounded) interpolation
#
# An opt-in post-process that bounds an interpolated result to the local
# data range of its source field. Lifted from the semi-Lagrangian
# advection-diffusion trace-back (``SemiLagrangian.update_pre_solve``),
# where FE Lagrange-P3 overshoot at non-nodal upstream points ignited
# catastrophic ringing in high-Ra free-surface convection (PR #186/#188).
# Exposed here so any resampling path (remesh re-interpolation, projection,
# restart, the SL DDt itself) can request the same bounded result from one
# place. It operates on the already-computed numbers and never touches the
# symbolic evaluation path.
# ---------------------------------------------------------------------------


def _normalize_monotone(monotone):
"""Map the ``monotone`` kwarg to a canonical mode string or ``None``.

``False`` / ``None`` → ``None`` (no limiting); ``True`` → ``"clamp"``
(the cheap, always-safe choice); ``"clamp"`` / ``"pick"`` pass
through. Anything else raises ``ValueError``.
"""
if monotone is False or monotone is None:
return None
if monotone is True:
return "clamp"
if monotone in ("clamp", "pick"):
return monotone
raise ValueError(
f"Unknown monotone option: {monotone!r}. "
f"Use False, True, 'clamp', or 'pick'."
)


def _apply_monotone_limit(
        expr, coords, value, mode, coord_sys=None, other_arguments=None):
    """Bound an interpolated result to the local data range of its source
    field (monotone / bounded interpolation).

    For each evaluation coordinate this finds the ``mesh.dim + 1`` nearest
    source-field DOFs and bounds ``value`` to their nodal min/max:

    - ``mode == "clamp"`` clips the result into ``[nbr_min, nbr_max]``
      (cheap, always bounded).
    - ``mode == "pick"`` keeps the in-bounds result and re-evaluates only
      the out-of-bounds subset via the bounded RBF path (``evalf=True``,
      which is intrinsically neighbour-bounded). ``evalf=True`` is a
      deliberate override -- the FE result is exactly what overshot, so we
      do *not* honour the caller's ``mode``/``rbf``/``force_l2`` here.
      ``coord_sys`` / ``other_arguments`` are propagated so the re-eval
      stays in the same frame.

    MVP scope: a single source MeshVariable. ``expr`` must resolve
    directly to one mesh variable (or one of its components, e.g.
    ``T.sym``); composite / multi-field expressions raise ``ValueError``
    because the neighbour bound across fields is ill-defined.

    The clamp algorithm is lifted verbatim from the SL trace-back limiter
    so that routing the DDt through here is bit-identical.

    Parameters
    ----------
    expr
        The expression that produced ``value``.
    coords
        Evaluation coordinates. Objects exposing ``.magnitude`` are
        non-dimensionalised before the neighbour search; anything else is
        passed through ``np.asarray``.
    value
        The already-computed result at ``coords``.
    mode : str
        ``"clamp"`` selects clipping; any other value takes the "pick"
        branch (the public wrappers pass the output of
        ``_normalize_monotone``).
    coord_sys, other_arguments
        Forwarded unchanged to the "pick" re-evaluation.

    Returns
    -------
    ndarray or UnitAwareArray
        The limited values in the shape of ``value``; wrapped in a
        ``UnitAwareArray`` carrying the source variable's units when the
        variable reports any.

    Raises
    ------
    NotImplementedError
        ``mode == "pick"`` with more than one MPI rank.
    ValueError
        ``expr`` does not reference exactly one mesh, does not resolve to
        a MeshVariable, or ``value`` does not have the same number of
        entries as the neighbour bounds.
    """
    from ..utilities.unit_aware_array import UnitAwareArray
    from .expressions import extract_meshes

    # "pick" re-evaluates the out-of-bounds subset via a collective
    # global_evaluate, gated on a rank-local mask -> would deadlock in
    # parallel (see TODO(parallel) below). Fail fast with a clear message
    # instead of hanging. "clamp" is rank-local-only and parallel-safe.
    if mode == "pick" and uw.mpi.size > 1:
        raise NotImplementedError(
            "monotone='pick' is serial-only: its out-of-bounds "
            "re-evaluation is gated on a rank-local mask around a "
            "collective evaluate and would deadlock under MPI. Use "
            "monotone='clamp' (parallel-safe) instead."
        )

    # --- Resolve the single source MeshVariable --------------------------
    meshes = extract_meshes(expr)
    if len(meshes) != 1:
        raise ValueError(
            "monotone interpolation requires an expression backed by a "
            f"single MeshVariable, but the expression references "
            f"{len(meshes)} mesh(es). Composite / multi-field expressions "
            "are not supported (the neighbour bound across fields is "
            "ill-defined)."
        )
    mesh = next(iter(meshes))
    hit = uw.discretisation.meshVariable_lookup_by_symbol(mesh, expr)
    if hit is None:
        raise ValueError(
            "monotone interpolation requires an expression that is a "
            "single MeshVariable (or one of its components), e.g. `T.sym`. "
            "Composite expressions such as `a*T.sym + b` cannot be bounded "
            "against a single source field."
        )
    var, _comp = hit  # use the full var.data (matches the SL limiter)

    # --- Convert coords to the ND space of the source DOF cloud ----------
    # In unit-aware runs `coords` is dimensional (e.g. metres) while
    # `var.coords_nd` is [0,1] non-dimensional, so the kdtree must compare
    # like with like. Same conversion as the SL trace-back.
    if hasattr(coords, "magnitude"):
        coords_raw = uw.non_dimensionalise(coords)
        if isinstance(coords_raw, UnitAwareArray):
            coords_nd = np.array(coords_raw)
        elif hasattr(coords_raw, "magnitude"):
            coords_nd = coords_raw.magnitude
        else:
            coords_nd = np.asarray(coords_raw)
    else:
        coords_nd = np.asarray(coords)

    psi_coords_nd = np.asarray(var.coords_nd)
    if hasattr(psi_coords_nd, "magnitude"):
        psi_coords_nd = np.asarray(psi_coords_nd.magnitude)

    # --- kNN neighbour stats from the source nodal data ------------------
    # TODO(parallel): the KDTree is built from rank-local `var.coords_nd`,
    # so near a partition seam the neighbour stats bound against a
    # truncated neighbourhood. This matches the validated SL behaviour. For
    # full parallel correctness the bound should include halo / global DOF
    # neighbours -- see the nav-only overlap-clone machinery
    # (project_parallel_point_eval_decision) as the hook if hardened.
    # TODO(units): nbr bounds come from `var.data` (always non-dimensional)
    # while `value` is dimensional in a units-active run -- a pre-existing
    # latent mismatch (scaling is inactive in the validated baseline so it
    # never bites). Do not "fix" without re-validating the trajectory.
    nnn = mesh.dim + 1
    kdt = uw.kdtree.KDTree(np.ascontiguousarray(psi_coords_nd))
    _, idxs = kdt.query(
        np.ascontiguousarray(coords_nd), k=nnn, sqr_dists=False)
    psi_data = np.asarray(var.data)
    # Flatten trailing dims to compute per-coord nbr stats, then restore
    # the original shape afterwards.
    psi_data_flat = psi_data.reshape(psi_data.shape[0], -1)
    nbr_vals = psi_data_flat[idxs]
    nbr_min = nbr_vals.min(axis=1)
    nbr_max = nbr_vals.max(axis=1)

    veep_np = np.asarray(value)
    orig_shape = veep_np.shape
    # The reshape below needs `value` to hold exactly one entry per
    # (coordinate, flattened variable component). A mismatch would
    # otherwise surface as numpy's generic "cannot reshape" ValueError, so
    # report it in terms the caller can act on.
    # NOTE(review): presumably reachable when `expr` is a single component
    # of a multi-component variable, since the bounds are built from the
    # full `var.data` -- confirm against meshVariable_lookup_by_symbol.
    if veep_np.size != nbr_min.size:
        raise ValueError(
            "monotone interpolation: the evaluated result has "
            f"{veep_np.size} entries but the neighbour bounds built from "
            f"the source variable have {nbr_min.size} "
            f"(shape {nbr_min.shape}); the result cannot be bounded "
            "against this variable's data."
        )
    veep_flat = veep_np.reshape(nbr_min.shape)

    if mode == "clamp":
        veep_lim = np.clip(veep_flat, nbr_min, nbr_max)
    else:
        # "pick": re-evaluate via RBF only at the subset of coords whose
        # result is out-of-bounds. Keeps cost dominated by the cheap pass
        # when most points are in-bounds.
        out_of_bounds = ((veep_flat < nbr_min) | (veep_flat > nbr_max))
        oob_mask = out_of_bounds.any(
            axis=tuple(range(1, out_of_bounds.ndim)))
        veep_lim = veep_flat.copy()
        # TODO(parallel): this re-evaluation is gated on a rank-local
        # `oob_mask.any()`, so in parallel one rank may enter the
        # (collective) global_evaluate while another skips it -> deadlock.
        # Guarded against above (pick is serial-only) until the rank-local
        # bound is hardened (see project_parallel_point_eval_decision).
        if oob_mask.any():
            # `coords` is an ndarray / UnitAwareArray here (the evaluator
            # rejects non-array coords upstream, before monotone runs), so
            # boolean masking preserves both the data and any units.
            oob_coords = coords[oob_mask]
            value_rbf_oob = global_evaluate(
                expr, oob_coords, coord_sys=coord_sys,
                other_arguments=other_arguments, evalf=True)
            vrbf_flat = np.asarray(value_rbf_oob).reshape(
                (-1,) + veep_flat.shape[1:])
            # Only overwrite entries individually out of bounds
            # (multi-component case).
            sub_oob = out_of_bounds[oob_mask]
            veep_sub = veep_lim[oob_mask]
            veep_sub = np.where(sub_oob, vrbf_flat, veep_sub)
            veep_lim[oob_mask] = veep_sub

    limited = veep_lim.reshape(orig_shape)

    # Re-wrap units after the numpy ops, mirroring the source field.
    var_units = getattr(var, "units", None)
    if var_units is not None and not isinstance(limited, UnitAwareArray):
        limited = UnitAwareArray(limited, units=var_units)
    return limited


@uw.timing.routine_timer_decorator
def evaluate(
    expr,
    coords,
    coord_sys=None,
    other_arguments=None,
    simplify=False,
    verbose=False,
    evalf=False,
    mode="default",
    data_layout=None,
    check_extrapolated=False,
    smoothing=1e-6,
    # Expert overrides (override mode settings)
    rbf=None,
    force_l2=None,
    monotone=False,
):
    """Evaluate ``expr`` at ``coords`` with automatic unit handling.

    Thin wrapper over :func:`_evaluate_impl`. See that function for the
    full parameter documentation and evaluation-mode notes. With the
    default ``monotone=False`` this is bit-identical to the historical
    ``evaluate``.

    Parameters
    ----------
    monotone : bool or str, optional
        Opt-in bounded (monotone) interpolation, applied as a
        post-process to the computed result. ``False`` (default) → no
        limiting. ``True`` / ``"clamp"`` → clip the result into the
        ``[min, max]`` of the ``mesh.dim + 1`` nearest source-field DOFs.
        ``"pick"`` → keep in-bounds values and re-evaluate only the
        out-of-bounds subset via (bounded) RBF interpolation. Only
        single-MeshVariable expressions are supported; composites raise
        ``ValueError``. See :func:`_apply_monotone_limit`.

    Returns
    -------
    object
        Whatever :func:`_evaluate_impl` returns when ``monotone`` is off.
        With ``monotone`` on, the limited values; if ``check_extrapolated``
        is also set, a ``(limited_values, extrapolated)`` pair in which
        only the first element has been limited.

    Raises
    ------
    ValueError
        ``monotone`` is not one of ``False``, ``None``, ``True``,
        ``"clamp"`` or ``"pick"`` (raised before any evaluation).
    """
    # Validate up front so an unknown option fails fast (no wasted eval).
    monotone_mode = _normalize_monotone(monotone)

    result = _evaluate_impl(
        expr,
        coords,
        coord_sys=coord_sys,
        other_arguments=other_arguments,
        simplify=simplify,
        verbose=verbose,
        evalf=evalf,
        mode=mode,
        data_layout=data_layout,
        check_extrapolated=check_extrapolated,
        smoothing=smoothing,
        rbf=rbf,
        force_l2=force_l2,
    )

    if monotone_mode is None:
        return result

    if check_extrapolated:
        # The impl result is unpacked as a (values, extrapolated) pair
        # here; only the values are limited, the second element is
        # returned untouched.
        value, extrapolated = result
        limited = _apply_monotone_limit(
            expr, coords, value, monotone_mode,
            coord_sys=coord_sys, other_arguments=other_arguments)
        return limited, extrapolated

    return _apply_monotone_limit(
        expr, coords, result, monotone_mode,
        coord_sys=coord_sys, other_arguments=other_arguments)


@uw.timing.routine_timer_decorator
def global_evaluate(
    expr,
    coords=None,
    coord_sys=None,
    other_arguments=None,
    simplify=False,
    verbose=False,
    evalf=False,
    mode="default",
    data_layout=None,
    check_extrapolated=False,
    smoothing=1e-6,
    # Expert overrides (override mode settings)
    rbf=None,
    force_l2=None,
    monotone=False,
):
    """Global counterpart of :func:`evaluate` with unit-aware results.

    Delegates the evaluation itself to :func:`_global_evaluate_impl` and,
    when requested, post-processes the result with
    :func:`_apply_monotone_limit`. Leaving ``monotone`` at its default
    returns the implementation's result untouched.

    Parameters
    ----------
    monotone : bool or str, optional
        Opt-in bounded (monotone) interpolation post-process; accepts the
        same values as the ``monotone`` argument of :func:`evaluate`.
        All remaining parameters are forwarded unchanged to
        :func:`_global_evaluate_impl`.

    Raises
    ------
    ValueError
        ``monotone`` is not a recognised option.
    NotImplementedError
        ``monotone`` is active together with ``check_extrapolated``.
    """
    # Both argument checks run before the evaluation so a bad request
    # never pays for it.
    limiter_mode = _normalize_monotone(monotone)
    limiting_requested = limiter_mode is not None
    if limiting_requested and check_extrapolated:
        raise NotImplementedError(
            "monotone interpolation is not supported together with "
            "check_extrapolated in global_evaluate."
        )

    impl_kwargs = dict(
        coords=coords,
        coord_sys=coord_sys,
        other_arguments=other_arguments,
        simplify=simplify,
        verbose=verbose,
        evalf=evalf,
        mode=mode,
        data_layout=data_layout,
        check_extrapolated=check_extrapolated,
        smoothing=smoothing,
        rbf=rbf,
        force_l2=force_l2,
    )
    raw_result = _global_evaluate_impl(expr, **impl_kwargs)

    if not limiting_requested:
        return raw_result

    return _apply_monotone_limit(
        expr,
        coords,
        raw_result,
        limiter_mode,
        coord_sys=coord_sys,
        other_arguments=other_arguments,
    )
Loading
Loading