diff --git a/src/linkml_map/transformer/object_transformer.py b/src/linkml_map/transformer/object_transformer.py index e6c05e2..a6c50f7 100644 --- a/src/linkml_map/transformer/object_transformer.py +++ b/src/linkml_map/transformer/object_transformer.py @@ -15,6 +15,7 @@ from linkml_runtime.linkml_model import SlotDefinition from linkml_runtime.utils.yamlutils import YAMLRoot from pydantic import BaseModel +from simpleeval import InvalidExpression from linkml_map.datamodel.transformer_model import ( ClassDerivation, @@ -363,14 +364,9 @@ def _derive_from_expr(self, slot_derivation: SlotDerivation, bindings: Bindings) """Evaluate a slot derivation expression, with fallback to asteval for unrestricted mode.""" try: return eval_expr_with_mapping(slot_derivation.expr, bindings) - except Exception as err: - # Broad catch is intentional: simpleeval raises various exception types - # (NameNotDefined, FeatureNotAvailable, etc.) for expressions outside its - # safe subset. Should also handle KeyError, TypeError in the future. + except (InvalidExpression, TypeError, ValueError): if not self.unrestricted_eval: - logger.warning(f"Expression evaluation failed for '{slot_derivation.name}': {err}") - msg = f"Expression not in safe subset: {slot_derivation.expr}" - raise RuntimeError(msg) from err + raise ctxt_obj, _ = bindings.get_ctxt_obj_and_dict() aeval = Interpreter(usersyms={"src": ctxt_obj, "target": None, "uuid5": _uuid5}) aeval(slot_derivation.expr) diff --git a/src/linkml_map/utils/eval_utils.py b/src/linkml_map/utils/eval_utils.py index a190a82..43e7f56 100644 --- a/src/linkml_map/utils/eval_utils.py +++ b/src/linkml_map/utils/eval_utils.py @@ -18,12 +18,15 @@ """ import ast +import logging import uuid from collections.abc import Mapping from typing import Any from simpleeval import EvalWithCompoundTypes, NameNotDefined +logger = logging.getLogger(__name__) + def eval_conditional(*conds: tuple[bool, Any]) -> Any: # noqa: ANN401 """ @@ -73,6 +76,57 @@ def _uuid5(namespace: str, name: str) -> str: return str(uuid.uuid5(ns, name)) +def _try_numeric(value: Any) -> Any: # noqa: ANN401 + """Attempt to coerce a value to a numeric type. + + Returns the value as-is if already numeric (int/float, not bool), + coerces numeric strings to float, and returns None for anything else. + + >>> _try_numeric(5) + 5 + >>> _try_numeric(3.14) + 3.14 + >>> _try_numeric("3.14") + 3.14 + >>> _try_numeric("abc") + >>> _try_numeric(None) + >>> _try_numeric(True) + """ + if isinstance(value, bool): + return None + if isinstance(value, (int, float)): + return value + if isinstance(value, str): + try: + return float(value) + except ValueError: + return None + return None + + +def _is_numeric(value: Any) -> bool: # noqa: ANN401 + """ + Check whether a value can be converted to float. + + >>> _is_numeric("3.14") + True + >>> _is_numeric("abc") + False + >>> _is_numeric(5) + True + >>> _is_numeric("") + False + >>> _is_numeric(None) + False + >>> _is_numeric(True) + False + + :param value: The value to check. + :return: True if float(value) would succeed, False otherwise. + """ + return _try_numeric(value) is not None + + def _null_safe(func): # noqa: ANN001, ANN202 """Wrap a function to return None if any argument is None.""" @@ -124,6 +178,7 @@ def wrapper(*args: Any) -> Any: # noqa: ANN401 **_LIST_FUNCTIONS, **{name: _distributing(func) for name, func in _SCALAR_FUNCTIONS.items()}, "case": eval_conditional, + "is_numeric": _is_numeric, } @@ -161,12 +216,30 @@ def _maybe_coerce_numeric(left: Any, right: Any) -> tuple[Any, Any]: # noqa: AN def _null_propagating(op): # noqa: ANN001, ANN202 - """Wrap a binary operator to return None if either operand is None.""" + """Wrap a binary operator with null propagation and numeric coercion fallback. + + Handles four cases: + - Either operand is None → None (null propagation) + - Operation succeeds natively → return result (e.g. str + str is concat) + - Operation fails but operands are numeric strings → coerce to float and retry + - Operands can't be made numeric → None with warning (enables case() guards) + + Note: ``+`` on two strings succeeds natively as concatenation and is not + coerced. Use ``x + 0 + y`` or explicit ``float()`` if numeric addition of + string values is needed. + """ def wrapper(left: Any, right: Any) -> Any: # noqa: ANN401 if left is None or right is None: return None - return op(left, right) + try: + return op(left, right) + except (TypeError, ValueError): + left_n, right_n = _try_numeric(left), _try_numeric(right) + if left_n is None or right_n is None: + logger.warning(f"Non-numeric operand in {op.__name__}: {left!r}, {right!r}; returning None") + return None + return op(left_n, right_n) return wrapper diff --git a/tests/test_transformer/test_object_transformer.py b/tests/test_transformer/test_object_transformer.py index 722ffd2..ebe3edc 100644 --- a/tests/test_transformer/test_object_transformer.py +++ b/tests/test_transformer/test_object_transformer.py @@ -606,7 +606,7 @@ def test_derive_from_expr_unrestricted_fallback() -> None: def test_derive_from_expr_restricted_raises() -> None: - """Expressions rejected by simpleeval raise RuntimeError when unrestricted_eval=False.""" + """Expressions rejected by simpleeval raise TransformationError when unrestricted_eval=False.""" source_schema: dict[str, Any] = yaml.safe_load(open(str(PERSONINFO_SRC_SCHEMA))) target_schema: dict[str, Any] = yaml.safe_load(open(str(PERSONINFO_TGT_SCHEMA))) transform_spec: dict[str, Any] = yaml.safe_load(open(str(PERSONINFO_TR))) @@ -614,7 +614,7 @@ def test_derive_from_expr_restricted_raises() -> None: transform_spec.setdefault("class_derivations", {}).setdefault("Agent", {}).setdefault("slot_derivations", {})[ "label" ] = { - "expr": "target = name", + "expr": "lambda x: x", } obj_tr = ObjectTransformer(unrestricted_eval=False) @@ -623,7 +623,7 @@ def test_derive_from_expr_restricted_raises() -> None: obj_tr.create_transformer_specification(transform_spec) person_dict: dict[str, Any] = yaml.safe_load(open(str(PERSONINFO_DATA))) - with pytest.raises(TransformationError, match="Expression not in safe subset"): + with pytest.raises(TransformationError, match="(?i)lambda"): obj_tr.map_object(person_dict, source_type="Person") diff --git a/tests/test_transformer/test_range_override.py b/tests/test_transformer/test_range_override.py index d810a10..486946a 100644 --- a/tests/test_transformer/test_range_override.py +++ b/tests/test_transformer/test_range_override.py @@ -20,7 +20,7 @@ """ import copy -from typing import Any, Optional +from typing import Any import pytest from linkml.utils.schema_builder import SchemaBuilder @@ -133,9 +133,11 @@ def _run( transform_spec: dict[str, Any], input_data: dict[str, Any], source_type: str, + *, + unrestricted_eval: bool = True, ) -> dict[str, Any]: """Instantiate an ObjectTransformer and map a single object.""" - tr = ObjectTransformer(unrestricted_eval=True) + tr = ObjectTransformer(unrestricted_eval=unrestricted_eval) tr.source_schemaview = SchemaView(source_schema.schema) tr.target_schemaview = SchemaView(target_schema.schema) tr.create_transformer_specification(copy.deepcopy(transform_spec)) @@ -206,44 +208,60 @@ def test_parse_string_into_object() -> None: # --------------------------------------------------------------------------- -@pytest.mark.parametrize( - "depth_input", - [ - pytest.param(None, id="null_input"), - pytest.param("", id="empty_string"), - pytest.param("5", id="no_unit"), - pytest.param("five m", id="non_numeric_value"), - ], -) -def test_parse_expr_malformed_input_yields_none(depth_input: Optional[str]) -> None: - """Malformed depth strings cause expr evaluation errors caught by simpleeval.""" +def test_parse_expr_null_input_yields_none() -> None: + """Null depth input propagates None through the expression.""" result = _run( source_schema=_source_schema_string(), target_schema=_target_schema_quantity(), transform_spec=TRANSFORM_PARSE, - input_data={"id": "samp1", "depth": depth_input}, + input_data={"id": "samp1", "depth": None}, source_type="StringSample", ) assert result["id"] == "samp1" assert result["depth"] is None +@pytest.mark.parametrize( + "depth_input", + [ + pytest.param("", id="empty_string"), + pytest.param("5", id="no_unit"), + pytest.param("five m", id="non_numeric_value"), + ], +) +def test_parse_expr_malformed_input_raises(depth_input: str) -> None: + """Malformed depth strings raise TransformationError in restricted mode.""" + from linkml_map.transformer.errors import TransformationError + + with pytest.raises(TransformationError): + _run( + source_schema=_source_schema_string(), + target_schema=_target_schema_quantity(), + transform_spec=TRANSFORM_PARSE, + input_data={"id": "samp1", "depth": depth_input}, + source_type="StringSample", + unrestricted_eval=False, + ) + + # --------------------------------------------------------------------------- # Test E -- non-numeric depth_value in construct expr # --------------------------------------------------------------------------- -def test_construct_non_numeric_depth_value_yields_none() -> None: - """float('five') fails; simpleeval catches the error and returns None.""" - result = _run( - source_schema=_source_schema_flat(), - target_schema=_target_schema_quantity(), - transform_spec=TRANSFORM_CONSTRUCT, - input_data={"id": "samp1", "depth_value": "five", "depth_unit": "m"}, - source_type="FlatSample", - ) - assert result["id"] == "samp1" - assert result["depth"] is None +def test_construct_non_numeric_depth_value_raises() -> None: + """float('five') raises TransformationError in restricted mode.""" + from linkml_map.transformer.errors import TransformationError + + with pytest.raises(TransformationError, match="could not convert string to float"): + _run( + source_schema=_source_schema_flat(), + target_schema=_target_schema_quantity(), + transform_spec=TRANSFORM_CONSTRUCT, + input_data={"id": "samp1", "depth_value": "five", "depth_unit": "m"}, + source_type="FlatSample", + unrestricted_eval=False, + ) # --------------------------------------------------------------------------- diff --git a/tests/test_utils/test_eval_utils.py b/tests/test_utils/test_eval_utils.py index fdc1ede..0610727 100644 --- a/tests/test_utils/test_eval_utils.py +++ b/tests/test_utils/test_eval_utils.py @@ -266,6 +266,38 @@ def test_null_in_numeric_guard_pattern() -> None: assert eval_expr("case(({x} <= 0, None), (True, {x} * 2.54))", x=0) is None +def test_is_numeric() -> None: + """is_numeric() checks whether a value can be converted to float.""" + assert eval_expr("is_numeric(x)", x="3.14") is True + assert eval_expr("is_numeric(x)", x="abc") is False + assert eval_expr("is_numeric(x)", x=5) is True + assert eval_expr("is_numeric(x)", x="") is False + assert eval_expr("is_numeric(x)", x=None) is False + assert eval_expr("is_numeric(x)", x="0") is True + + +def test_is_numeric_guard_pattern() -> None: + """is_numeric() enables guarded numeric branching in case() expressions.""" + expr = "case((is_numeric(x), x * 2.54), (True, None))" + assert eval_expr(expr, x="5") == 12.7 + assert eval_expr(expr, x="abc") is None + assert eval_expr(expr, x="") is None + assert eval_expr(expr, x=None) is None + + +def test_arithmetic_coerces_numeric_strings() -> None: + """Arithmetic operators coerce numeric strings to float.""" + assert eval_expr("x / y * 10", x="100", y="50") == 20.0 + assert eval_expr("{x} / 100.0 * {y}", x="200", y="50") == 100.0 + + +def test_arithmetic_non_numeric_string_returns_none() -> None: + """Non-numeric strings in arithmetic return None with a warning instead of crashing.""" + assert eval_expr("x / y", x="100", y="abc") is None + assert eval_expr("x * y", x="abc", y="10") is None + assert eval_expr("x + y", x="abc", y=10) is None + + def test_null_in_function_call() -> None: """None propagates through function calls.""" assert eval_expr("float(x)", x=None) is None