diff --git a/burr/core/__init__.py b/burr/core/__init__.py index aa2f75a46..ddc1f53e6 100644 --- a/burr/core/__init__.py +++ b/burr/core/__init__.py @@ -15,7 +15,17 @@ # specific language governing permissions and limitations # under the License. -from burr.core.action import Action, Condition, Result, action, default, expr, type_eraser, when +from burr.core.action import ( + Action, + Condition, + Result, + action, + default, + expr, + safe_expr, + type_eraser, + when, +) from burr.core.application import ( Application, ApplicationBuilder, @@ -35,6 +45,7 @@ "Condition", "default", "expr", + "safe_expr", "type_eraser", "Result", "State", diff --git a/burr/core/action.py b/burr/core/action.py index 08771d411..a15d6f6f8 100644 --- a/burr/core/action.py +++ b/burr/core/action.py @@ -380,6 +380,320 @@ def tags(self) -> list[str]: return [] +# --------------------------------------------------------------------------- +# Safe expression evaluator used by Condition.safe_expr. +# +# These are intentionally module-private. The validator walks the AST and +# rejects any node that isn't on the allowlist; the interpreter then walks +# the (validated) tree and produces a value. eval()/compile() are never +# called on the parsed tree -- that is the entire point. +# --------------------------------------------------------------------------- + +# Builtins we are willing to expose inside safe_expr. Everything here returns +# a value (no side effects), takes only basic Python data, and cannot be used +# to reach the import system or the interpreter internals. +_SAFE_EXPR_BUILTINS: typing.Dict[str, Callable] = { + "len": len, + "abs": abs, + "min": min, + "max": max, + "sum": sum, + "all": all, + "any": any, + "str": str, + "int": int, + "float": float, + "bool": bool, +} + +# AST BinOp / UnaryOp / BoolOp / Compare operator classes we accept. +_SAFE_BINOPS = ( + ast.Add, + ast.Sub, + ast.Mult, + ast.Div, + ast.FloorDiv, + ast.Mod, + ast.Pow, +) +_SAFE_UNARYOPS = (ast.Not, ast.USub, ast.UAdd) +_SAFE_BOOLOPS = (ast.And, ast.Or) +_SAFE_CMPOPS = ( + ast.Eq, + ast.NotEq, + ast.Lt, + ast.LtE, + ast.Gt, + ast.GtE, + ast.In, + ast.NotIn, + ast.Is, + ast.IsNot, +) + + +class _SafeExprValidator(ast.NodeVisitor): + """Walks an AST and raises ``ValueError`` on any node not on the allowlist. + + Run once at ``safe_expr()`` call time -- the resulting Condition is only + built if validation passes, so rejected expressions never reach runtime. + """ + + def _reject(self, node: ast.AST, why: str) -> None: + raise ValueError( + f"safe_expr: disallowed construct in expression: {why} " + f"(at line {getattr(node, 'lineno', '?')}, col {getattr(node, 'col_offset', '?')})" + ) + + # The Expression wrapper produced by ast.parse(mode="eval"). + def visit_Expression(self, node: ast.Expression) -> None: + self.visit(node.body) + + def visit_Constant(self, node: ast.Constant) -> None: + if isinstance(node.value, (int, float, str, bool)) or node.value is None: + return + self._reject(node, f"constant of type {type(node.value).__name__}") + + def visit_Name(self, node: ast.Name) -> None: + # Reading is fine; assignment (Store/Del) is unreachable from mode="eval" + # but be explicit. + if not isinstance(node.ctx, ast.Load): + self._reject(node, "assignment / deletion") + + def visit_Attribute(self, node: ast.Attribute) -> None: + if node.attr.startswith("__"): + self._reject(node, f"dunder attribute access '{node.attr}'") + self.visit(node.value) + + def visit_Subscript(self, node: ast.Subscript) -> None: + self.visit(node.value) + self.visit(node.slice) + + def visit_Slice(self, node: ast.Slice) -> None: + for child in (node.lower, node.upper, node.step): + if child is not None: + self.visit(child) + + # Python <3.9 wraps subscript indices in ast.Index; keep a permissive visitor. + def visit_Index(self, node) -> None: # pragma: no cover - legacy py + self.visit(node.value) + + def visit_Compare(self, node: ast.Compare) -> None: + for op in node.ops: + if not isinstance(op, _SAFE_CMPOPS): + self._reject(node, f"comparison operator {type(op).__name__}") + self.visit(node.left) + for cmp in node.comparators: + self.visit(cmp) + + def visit_BoolOp(self, node: ast.BoolOp) -> None: + if not isinstance(node.op, _SAFE_BOOLOPS): + self._reject(node, f"boolean operator {type(node.op).__name__}") + for v in node.values: + self.visit(v) + + def visit_UnaryOp(self, node: ast.UnaryOp) -> None: + if not isinstance(node.op, _SAFE_UNARYOPS): + self._reject(node, f"unary operator {type(node.op).__name__}") + self.visit(node.operand) + + def visit_BinOp(self, node: ast.BinOp) -> None: + if not isinstance(node.op, _SAFE_BINOPS): + self._reject(node, f"binary operator {type(node.op).__name__}") + self.visit(node.left) + self.visit(node.right) + + def visit_Tuple(self, node: ast.Tuple) -> None: + for elt in node.elts: + self.visit(elt) + + def visit_List(self, node: ast.List) -> None: + for elt in node.elts: + self.visit(elt) + + def visit_Set(self, node: ast.Set) -> None: + for elt in node.elts: + self.visit(elt) + + def visit_Dict(self, node: ast.Dict) -> None: + for k in node.keys: + if k is not None: + self.visit(k) + for v in node.values: + self.visit(v) + + def visit_Call(self, node: ast.Call) -> None: + # Only direct calls to allowlisted builtins by bare Name. + if not isinstance(node.func, ast.Name): + self._reject(node, "indirect call (only bare builtin names allowed)") + if node.func.id not in _SAFE_EXPR_BUILTINS: + self._reject(node, f"call to disallowed function '{node.func.id}'") + if any(isinstance(a, ast.Starred) for a in node.args): + self._reject(node, "starred argument") + if node.keywords: + self._reject(node, "keyword arguments to builtin call") + for a in node.args: + self.visit(a) + + # Catch-all: anything we didn't explicitly allow is rejected. This is the + # important rule -- additions to the grammar require explicit opt-in. + def generic_visit(self, node: ast.AST) -> None: + self._reject(node, f"node type {type(node).__name__}") + + +class _SafeExprInterpreter: + """Direct AST interpreter for the safe_expr grammar. + + Only handles nodes already approved by :class:`_SafeExprValidator`. We + raise ``ValueError`` defensively if an unknown node sneaks in -- but in + practice the validator should have caught it first. + """ + + def __init__(self, names: typing.Mapping[str, typing.Any]): + self._names = names + + def eval(self, node: ast.AST) -> typing.Any: # noqa: A003 - matches ast naming + if isinstance(node, ast.Expression): + return self.eval(node.body) + method = getattr(self, f"_eval_{type(node).__name__}", None) + if method is None: + raise ValueError( + f"safe_expr: interpreter encountered unsupported node {type(node).__name__}" + ) + return method(node) + + def _eval_Constant(self, node: ast.Constant): + return node.value + + def _eval_Name(self, node: ast.Name): + # Builtins on the allowlist resolve to the builtin; otherwise look up state. + if node.id in _SAFE_EXPR_BUILTINS: + return _SAFE_EXPR_BUILTINS[node.id] + if node.id in self._names: + return self._names[node.id] + raise NameError(f"safe_expr: name '{node.id}' is not defined") + + def _eval_Attribute(self, node: ast.Attribute): + # Validator already rejected dunder access. + return getattr(self.eval(node.value), node.attr) + + def _eval_Subscript(self, node: ast.Subscript): + value = self.eval(node.value) + slice_node = node.slice + # py<3.9 wraps in ast.Index + if hasattr(ast, "Index") and isinstance(slice_node, ast.Index): # pragma: no cover + slice_node = slice_node.value # type: ignore[attr-defined] + if isinstance(slice_node, ast.Slice): + lower = self.eval(slice_node.lower) if slice_node.lower is not None else None + upper = self.eval(slice_node.upper) if slice_node.upper is not None else None + step = self.eval(slice_node.step) if slice_node.step is not None else None + return value[slice(lower, upper, step)] + return value[self.eval(slice_node)] + + def _eval_Compare(self, node: ast.Compare): + left = self.eval(node.left) + for op, right_node in zip(node.ops, node.comparators): + right = self.eval(right_node) + if isinstance(op, ast.Eq): + ok = left == right + elif isinstance(op, ast.NotEq): + ok = left != right + elif isinstance(op, ast.Lt): + ok = left < right + elif isinstance(op, ast.LtE): + ok = left <= right + elif isinstance(op, ast.Gt): + ok = left > right + elif isinstance(op, ast.GtE): + ok = left >= right + elif isinstance(op, ast.In): + ok = left in right + elif isinstance(op, ast.NotIn): + ok = left not in right + elif isinstance(op, ast.Is): + ok = left is right + elif isinstance(op, ast.IsNot): + ok = left is not right + else: # pragma: no cover - validator catches this + raise ValueError(f"safe_expr: unsupported comparator {type(op).__name__}") + if not ok: + return False + left = right + return True + + def _eval_BoolOp(self, node: ast.BoolOp): + if isinstance(node.op, ast.And): + result = True + for v in node.values: + result = self.eval(v) + if not result: + return result + return result + # Or + result = False + for v in node.values: + result = self.eval(v) + if result: + return result + return result + + def _eval_UnaryOp(self, node: ast.UnaryOp): + operand = self.eval(node.operand) + if isinstance(node.op, ast.Not): + return not operand + if isinstance(node.op, ast.USub): + return -operand + if isinstance(node.op, ast.UAdd): + return +operand + raise ValueError( # pragma: no cover + f"safe_expr: unsupported unary op {type(node.op).__name__}" + ) + + def _eval_BinOp(self, node: ast.BinOp): + left = self.eval(node.left) + right = self.eval(node.right) + op = node.op + if isinstance(op, ast.Add): + return left + right + if isinstance(op, ast.Sub): + return left - right + if isinstance(op, ast.Mult): + return left * right + if isinstance(op, ast.Div): + return left / right + if isinstance(op, ast.FloorDiv): + return left // right + if isinstance(op, ast.Mod): + return left % right + if isinstance(op, ast.Pow): + return left**right + raise ValueError( # pragma: no cover + f"safe_expr: unsupported binary op {type(op).__name__}" + ) + + def _eval_Tuple(self, node: ast.Tuple): + return tuple(self.eval(e) for e in node.elts) + + def _eval_List(self, node: ast.List): + return [self.eval(e) for e in node.elts] + + def _eval_Set(self, node: ast.Set): + return {self.eval(e) for e in node.elts} + + def _eval_Dict(self, node: ast.Dict): + return { + (self.eval(k) if k is not None else None): self.eval(v) + for k, v in zip(node.keys, node.values) + } + + def _eval_Call(self, node: ast.Call): + # Validator guarantees node.func is a Name in _SAFE_EXPR_BUILTINS, + # no keywords, no starred args. + func = _SAFE_EXPR_BUILTINS[node.func.id] # type: ignore[attr-defined] + args = [self.eval(a) for a in node.args] + return func(*args) + + class Condition(Function): KEY = "PROCEED" @@ -414,6 +728,10 @@ def expr(expr: str) -> "Condition": only state variables and Python operators. Do not trust that anything else will work. Do not accept expressions generated from user-inputted text, this has the potential to be unsafe. + Internally this uses :func:`eval`, so the expression can execute arbitrary Python and should only + be used with developer-authored strings. If you need to accept expressions from less-trusted + sources (dashboards, YAML, user input), use :meth:`Condition.safe_expr` which restricts + evaluation to a small allowlisted AST grammar interpreted directly (no :func:`eval`). You can also refer to this as ``from burr.core import expr`` in the API. @@ -444,6 +762,78 @@ def condition_func(state: State) -> bool: return Condition(keys, condition_func, name=expr) + @staticmethod + def safe_expr(expr: str) -> "Condition": + """Returns a condition that evaluates ``expr`` under a restricted, allowlisted AST grammar. + + This is the opt-in safe sibling of :meth:`Condition.expr`. Whereas ``expr()`` uses + :func:`eval` and therefore accepts the full Python expression grammar (and is unsafe + for untrusted input), ``safe_expr()`` parses the expression, validates every node + against an allowlist at call time, and then *interprets* the validated tree directly. + :func:`eval` is never invoked on the parsed tree. + + Use ``safe_expr`` when the expression string comes from a less-trusted source -- + for example a dashboard rule editor, a YAML-driven graph definition, or any user + input. Use ``expr`` when the expression is authored by a developer and checked in. + + The allowed grammar is intentionally small: + + - Constants: ``int``, ``float``, ``str``, ``bool``, ``None``. + - ``Name`` lookups, resolved against the state via :meth:`State.get_all`. + - ``Attribute`` access on names / other attributes. Any attribute whose name + starts with ``__`` (dunder) is rejected -- this closes the standard + ``().__class__.__bases__[0].__subclasses__()`` sandbox-escape pattern. + - ``Subscript`` (``state["foo"]``, ``items[0]``, slices). + - ``Compare``: ``==``, ``!=``, ``<``, ``>``, ``<=``, ``>=``, ``in``, ``not in``, + ``is``, ``is not``. + - ``BoolOp``: ``and``, ``or``. + - ``UnaryOp``: ``not``, unary ``-``, unary ``+``. + - ``BinOp`` arithmetic: ``+``, ``-``, ``*``, ``/``, ``//``, ``%``, ``**``. + - Literal containers: tuple, list, set, dict. + - ``Call`` only to a tight allowlist of safe builtins by name: + ``len``, ``abs``, ``min``, ``max``, ``sum``, ``all``, ``any``, ``str``, + ``int``, ``float``, ``bool``. All other calls are rejected. + + Everything else is rejected at ``safe_expr()`` call time (not at run time), + including: lambdas, conditional expressions (``a if b else c``), + comprehensions and generator expressions, the walrus operator, + ``await``, ``yield``, imports, and any ``Call`` not on the builtin allowlist. + + :param expr: Expression to evaluate + :return: A condition that evaluates the given expression + :raises ValueError: if the expression contains any disallowed construct + (raised at call time -- the condition is rejected before it ever runs). + :raises SyntaxError: if the expression is not syntactically valid Python. + """ + # Parse first. This raises SyntaxError for malformed input, which is fine. + tree = ast.parse(expr, mode="eval") + # Validate the whole tree against the allowlist *now*, at call time. If any + # disallowed node exists we raise here, before constructing the Condition. + _SafeExprValidator().visit(tree) + + # Collect Name references for keys, mirroring expr(). + all_builtins = builtins.__dict__ + + class _NameCollector(ast.NodeVisitor): + def __init__(self): + self.names = set() + + def visit_Name(self, node): + if node.id not in all_builtins: + self.names.add(node.id) + + collector = _NameCollector() + collector.visit(tree) + keys = list(collector.names) + + def condition_func(state: State) -> bool: + # Interpret the validated tree directly. We deliberately do NOT call + # eval()/compile() on the tree -- the whole safety argument rests on + # this. The interpreter only implements the allowlisted node types. + return bool(_SafeExprInterpreter(state.get_all()).eval(tree)) + + return Condition(keys, condition_func, name=expr) + @staticmethod def lmda(resolver: Callable[[State], bool], state_keys: List[str]) -> "Condition": """Returns a condition that evaluates the given function of State. @@ -650,6 +1040,7 @@ def __invert__(self): default = Condition.default when = Condition.when expr = Condition.expr +safe_expr = Condition.safe_expr lmda = Condition.lmda # exists = Condition.exists diff --git a/tests/core/test_action.py b/tests/core/test_action.py index 367ee58e8..0095bdafe 100644 --- a/tests/core/test_action.py +++ b/tests/core/test_action.py @@ -317,6 +317,269 @@ def test_condition_expr_complex(): assert cond.run(State({"foo": "baz", "baz": "corge"})) == {Condition.KEY: False} +# --------------------------------------------------------------------------- +# Condition.safe_expr -- restricted-AST evaluator +# --------------------------------------------------------------------------- + + +# --- positive cases: every allowed grammar construct ------------------------ + + +def test_safe_expr_basic_equality(): + cond = Condition.safe_expr("foo == 'bar'") + assert cond.name == "foo == 'bar'" + assert cond.reads == ["foo"] + assert cond.run(State({"foo": "bar"})) == {Condition.KEY: True} + assert cond.run(State({"foo": "baz"})) == {Condition.KEY: False} + + +def test_safe_expr_all_comparisons(): + cases = [ + ("x < 5", {"x": 4}, True), + ("x <= 5", {"x": 5}, True), + ("x > 5", {"x": 6}, True), + ("x >= 5", {"x": 5}, True), + ("x != 5", {"x": 4}, True), + ("x in items", {"x": 2, "items": [1, 2, 3]}, True), + ("x not in items", {"x": 9, "items": [1, 2, 3]}, True), + ("x is None", {"x": None}, True), + ("x is not None", {"x": 1}, True), + ] + for expr_str, state, expected in cases: + cond = Condition.safe_expr(expr_str) + assert cond.run(State(state)) == {Condition.KEY: expected}, expr_str + + +def test_safe_expr_chained_comparison(): + cond = Condition.safe_expr("0 < x < 10") + assert cond.run(State({"x": 5})) == {Condition.KEY: True} + assert cond.run(State({"x": 11})) == {Condition.KEY: False} + + +def test_safe_expr_boolean_ops_nested(): + cond = Condition.safe_expr("(a > 0 and b < 10) or c == 'ok'") + assert sorted(cond.reads) == ["a", "b", "c"] + assert cond.run(State({"a": 1, "b": 5, "c": "no"})) == {Condition.KEY: True} + assert cond.run(State({"a": -1, "b": 5, "c": "ok"})) == {Condition.KEY: True} + assert cond.run(State({"a": -1, "b": 5, "c": "no"})) == {Condition.KEY: False} + + +def test_safe_expr_arithmetic(): + cond = Condition.safe_expr("(a + b) * 2 - c / 2 == 9.0") + assert cond.run(State({"a": 1, "b": 2, "c": 2})) == {Condition.KEY: False} + # (1+2)*2 - 4/2 = 6 - 2 = 4 + cond2 = Condition.safe_expr("(a + b) * 2 - c // 2") + assert _eval_value(cond2, {"a": 1, "b": 2, "c": 5}) == 6 - 2 + + +def test_safe_expr_arithmetic_mod_pow(): + assert _eval_value(Condition.safe_expr("x % 3"), {"x": 10}) == 1 + assert _eval_value(Condition.safe_expr("x ** 3"), {"x": 2}) == 8 + + +def test_safe_expr_unary_ops(): + assert _eval_value(Condition.safe_expr("-x"), {"x": 3}) == -3 + assert _eval_value(Condition.safe_expr("+x"), {"x": 3}) == 3 + assert Condition.safe_expr("not flag").run(State({"flag": False})) == {Condition.KEY: True} + + +def test_safe_expr_negative_literal(): + cond = Condition.safe_expr("x == -1") + assert cond.run(State({"x": -1})) == {Condition.KEY: True} + + +def test_safe_expr_none_comparison(): + cond = Condition.safe_expr("x is None") + assert cond.run(State({"x": None})) == {Condition.KEY: True} + assert cond.run(State({"x": 0})) == {Condition.KEY: False} + + +def test_safe_expr_subscript_index_and_slice(): + cond = Condition.safe_expr("items[0] == 'a'") + assert cond.run(State({"items": ["a", "b"]})) == {Condition.KEY: True} + cond2 = Condition.safe_expr("mapping['k'] == 1") + assert cond2.run(State({"mapping": {"k": 1}})) == {Condition.KEY: True} + assert _eval_value(Condition.safe_expr("items[1:3]"), {"items": [0, 1, 2, 3]}) == [1, 2] + + +def test_safe_expr_attribute_access(): + class Obj: + def __init__(self): + self.val = 42 + + cond = Condition.safe_expr("obj.val == 42") + assert cond.run(State({"obj": Obj()})) == {Condition.KEY: True} + + +def test_safe_expr_literal_containers(): + assert _eval_value(Condition.safe_expr("[1, 2, 3]"), {}) == [1, 2, 3] + assert _eval_value(Condition.safe_expr("(1, 2, 3)"), {}) == (1, 2, 3) + assert _eval_value(Condition.safe_expr("{1, 2, 3}"), {}) == {1, 2, 3} + assert _eval_value(Condition.safe_expr("{'a': 1, 'b': 2}"), {}) == {"a": 1, "b": 2} + + +def test_safe_expr_all_allowed_builtins(): + pairs = [ + ("len(items)", {"items": [1, 2, 3]}, 3), + ("abs(x)", {"x": -4}, 4), + ("min(a, b)", {"a": 1, "b": 2}, 1), + ("max(a, b)", {"a": 1, "b": 2}, 2), + ("sum(items)", {"items": [1, 2, 3]}, 6), + ("all(items)", {"items": [True, True]}, True), + ("any(items)", {"items": [False, True]}, True), + ("str(x)", {"x": 5}, "5"), + ("int(x)", {"x": "5"}, 5), + ("float(x)", {"x": "1.5"}, 1.5), + ("bool(x)", {"x": 0}, False), + ] + for expr_str, state, expected in pairs: + assert _eval_value(Condition.safe_expr(expr_str), state) == expected, expr_str + + +def test_safe_expr_reads_collected_correctly(): + cond = Condition.safe_expr("foo == 'bar' and len(baz) == 3") + assert sorted(cond.reads) == ["baz", "foo"] + + +def test_safe_expr_only_name_lookup_no_attributes(): + """Edge case: an expression that uses *only* Name lookups still works.""" + cond = Condition.safe_expr("a and b") + assert cond.run(State({"a": True, "b": True})) == {Condition.KEY: True} + assert cond.run(State({"a": True, "b": False})) == {Condition.KEY: False} + + +def test_safe_expr_validate_missing_state_key(): + cond = Condition.safe_expr("foo == 'bar'") + with pytest.raises(ValueError, match="foo"): + cond._validate(State({"baz": "bar"})) + + +# --- determinism ------------------------------------------------------------ + + +def test_safe_expr_determinism_success(): + state = State({"x": 5, "y": 3}) + cond_a = Condition.safe_expr("x + y == 8") + cond_b = Condition.safe_expr("x + y == 8") + assert cond_a.run(state) == cond_b.run(state) == {Condition.KEY: True} + # Same condition, called twice, identical state -> identical result. + assert cond_a.run(state) == cond_a.run(state) + + +def test_safe_expr_determinism_rejection(): + # Same disallowed expression rejected the same way every time, at call time. + for _ in range(3): + with pytest.raises(ValueError): + Condition.safe_expr("lambda: 1") + + +# --- negative cases: every rejected node category --------------------------- + + +@pytest.mark.parametrize( + "expr_str", + [ + # Attack patterns explicitly called out in the issue: + '__import__("os").system("touch /tmp/pwn")', + "(0).__class__.__bases__[0].__subclasses__()", + 'open("/etc/passwd").read()', + "lambda: 1", + "[x for x in range(10)]", + # Other rejected categories: + "{x for x in range(10)}", # set comp + "{x: x for x in range(10)}", # dict comp + "(x for x in range(10))", # generator exp + "1 if a else 2", # IfExp + "(x := 5) > 0", # walrus / NamedExpr + "foo()", # Call to non-allowlisted name + "foo.bar()", # Call via Attribute (indirect) + "obj.__class__", # dunder attribute + "obj.__dict__", # dunder attribute + "len(items, key=str)", # keyword arg to builtin + "min(*items)", # starred arg + "a & b", # bitwise (not on allowlist) + "a | b", # bitwise (not on allowlist) + "a << 1", # bitshift + "f'hello {name}'", # f-string / JoinedStr + "...", # Ellipsis constant + "b'bytes'", # bytes constant + ], +) +def test_safe_expr_rejects_at_call_time(expr_str): + """Every disallowed construct must raise at safe_expr() call time, not at run time.""" + with pytest.raises((ValueError, SyntaxError)): + Condition.safe_expr(expr_str) + + +def test_safe_expr_rejects_import_dunder_attack(): + # `__import__("os").system(...)` -- the disallowed thing here is the dunder + # Name reference; we reject because `__import__` resolves nowhere (not a + # safe builtin, not in state), but specifically Call to non-allowlisted name. + with pytest.raises(ValueError, match="(?i)disallowed"): + Condition.safe_expr('__import__("os").system("touch /tmp/pwn")') + + +def test_safe_expr_rejects_class_bases_sandbox_escape(): + # The expression is rejected -- could be on either the dunder attribute + # access OR the call to a non-allowlisted name (the validator hits the + # outer Call first). Either rejection closes the escape; we just confirm + # the expression doesn't make it to a Condition. + with pytest.raises(ValueError, match="(?i)disallowed"): + Condition.safe_expr("(0).__class__.__bases__[0].__subclasses__()") + # And the bare dunder attribute (no call) is rejected with the dunder reason. + with pytest.raises(ValueError, match="(?i)dunder"): + Condition.safe_expr("(0).__class__") + + +def test_safe_expr_rejects_open_call(): + with pytest.raises(ValueError, match="(?i)disallowed"): + Condition.safe_expr('open("/etc/passwd").read()') + + +def test_safe_expr_rejects_lambda(): + with pytest.raises(ValueError): + Condition.safe_expr("lambda x: x") + + +def test_safe_expr_rejects_list_comprehension(): + with pytest.raises(ValueError): + Condition.safe_expr("[x for x in range(10)]") + + +def test_safe_expr_rejects_syntax_error(): + with pytest.raises(SyntaxError): + Condition.safe_expr("foo == ") + + +def test_safe_expr_does_not_eval_during_validation(): + """If validation accidentally executed code, this would raise at call time. + + Build an expression that *would* error at run time but is syntactically + allowed. Confirm safe_expr() returns a Condition without raising; the + error only happens when .run() is called. + """ + cond = Condition.safe_expr("1 / x") # ZeroDivisionError at run time only + with pytest.raises(ZeroDivisionError): + cond.run(State({"x": 0})) + + +def _eval_value(cond: Condition, state: dict): + """Helper: get the raw (pre-bool-coerce) interpreter result for a safe_expr. + + ``Condition.safe_expr`` wraps the interpreter result in ``bool()`` because + a Condition is fundamentally a predicate. For tests of arithmetic / + container / builtin expressions we want the underlying value, so we + re-parse-and-interpret using the same internal machinery. This mirrors + what the Condition does at run time minus the ``bool()`` cast. + """ + import ast as _ast + + from burr.core.action import _SafeExprInterpreter # type: ignore + + tree = _ast.parse(cond.name, mode="eval") + return _SafeExprInterpreter(state).eval(tree) + + def test_condition__validate_success(): cond = Condition.when(foo="bar") cond._validate(State({"foo": "bar"}))