diff --git a/effectful/handlers/llm/evaluation.py b/effectful/handlers/llm/evaluation.py
index 271e55f6..a673eb29 100644
--- a/effectful/handlers/llm/evaluation.py
+++ b/effectful/handlers/llm/evaluation.py
@@ -5,6 +5,14 @@
 from types import CodeType
 from typing import Any
 
+from RestrictedPython import (
+    Eval,
+    Guards,
+    RestrictingNodeTransformer,
+    compile_restricted,
+    safe_globals,
+)
+
 from effectful.ops.syntax import ObjectInterpretation, defop, implements
 
 
@@ -86,3 +94,93 @@ def exec(
 
         # Execute module-style so top-level defs land in `env`.
         builtins.exec(bytecode, env, env)
+
+
+class RestrictedEvalProvider(ObjectInterpretation):
+    """
+    Safer provider using RestrictedPython.
+
+    RestrictedPython is not a complete sandbox, but it enforces a restricted
+    language subset and expects you to provide a constrained exec environment.
+
+    Parameters
+    ----------
+    policy : type[RestrictingNodeTransformer], optional
+        Node transformer class passed to ``compile_restricted`` as ``policy``.
+        ``RestrictingNodeTransformer`` is used when omitted or ``None``.
+    """
+
+    policy: type[RestrictingNodeTransformer] | None = None
+
+    def __init__(
+        self,
+        *,
+        policy: type[RestrictingNodeTransformer] | None = None,
+    ):
+        self.policy = policy
+
+    @implements(parse)
+    def parse(self, source: str, filename: str) -> ast.Module:
+        """Parse ``source`` as a module and register it with ``linecache``."""
+        # Keep inspect.getsource() working for dynamically-defined objects.
+        linecache.cache[filename] = (
+            len(source),
+            None,
+            source.splitlines(True),
+            filename,
+        )
+        return ast.parse(source, filename=filename, mode="exec")
+
+    @implements(compile)
+    def compile(self, module: ast.Module, filename: str) -> CodeType:
+        """Compile ``module`` with ``compile_restricted`` and ``self.policy``."""
+        # RestrictedPython can compile from an AST directly.
+        return compile_restricted(
+            module,
+            filename=filename,
+            mode="exec",
+            policy=self.policy or RestrictingNodeTransformer,
+        )
+
+    @implements(exec)
+    def exec(
+        self,
+        bytecode: CodeType,
+        env: dict[str, Any],
+    ) -> None:
+        """Execute ``bytecode`` in restricted globals, then update ``env``.
+
+        Names that the executed code binds or rebinds are written back to
+        ``env``, including names ``env`` already contained. A ``__builtins__``
+        entry in ``env`` is ignored.
+        """
+        # Build restricted globals from RestrictedPython's defaults
+        rglobals: dict[str, Any] = safe_globals.copy()
+
+        # Enable class definitions (required for Python 3)
+        rglobals["__metaclass__"] = type
+        rglobals["__name__"] = "restricted"
+
+        # Layer `env` on top (without letting callers replace the restricted builtins).
+        rglobals.update({k: v for k, v in env.items() if k != "__builtins__"})
+
+        # Enable for loops and comprehensions
+        rglobals["_getiter_"] = Eval.default_guarded_getiter
+        # Enable sequence unpacking in comprehensions and for loops
+        rglobals["_iter_unpack_sequence_"] = Guards.guarded_iter_unpack_sequence
+
+        rglobals["getattr"] = Guards.safer_getattr
+        rglobals["setattr"] = Guards.guarded_setattr
+        rglobals["_write_"] = lambda x: x
+
+        # Snapshot the bindings so rebound names are detected, not only new ones.
+        bindings_before = dict(rglobals)
+
+        builtins.exec(bytecode, rglobals, rglobals)
+
+        # Copy back every name the code bound or rebound. Comparing identities
+        # (instead of only checking for new keys) also propagates redefinitions
+        # of names `env` already held, as `exec(bytecode, env, env)` would.
+        for key, value in rglobals.items():
+            if key not in bindings_before or bindings_before[key] is not value:
+                env[key] = value
diff --git a/pyproject.toml b/pyproject.toml
index 38ff53ba..c75c21d9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,6 +44,7 @@ llm = [
     "litellm",
     "pillow",
     "pydantic",
+    "restrictedpython>=8.1"
 ]
 prettyprinter = ["prettyprinter"]
 docs = [
diff --git a/tests/test_handlers_llm_encoding.py b/tests/test_handlers_llm_encoding.py
index 775d5f72..472db961 100644
--- a/tests/test_handlers_llm_encoding.py
+++ b/tests/test_handlers_llm_encoding.py
@@ -1,3 +1,4 @@
+import builtins
 from collections.abc import Callable
 from dataclasses import asdict, dataclass
 from typing import Any, NamedTuple, TypedDict
@@ -5,12 +6,19 @@
 import pydantic
 import pytest
 from PIL import Image
+from RestrictedPython import
RestrictingNodeTransformer from effectful.handlers.llm.encoding import Encodable, SynthesizedFunction -from effectful.handlers.llm.evaluation import UnsafeEvalProvider +from effectful.handlers.llm.evaluation import RestrictedEvalProvider, UnsafeEvalProvider from effectful.ops.semantics import handler from effectful.ops.types import Operation, Term +# Eval providers for parameterized tests +EVAL_PROVIDERS = [ + pytest.param(UnsafeEvalProvider(), id="unsafe"), + pytest.param(RestrictedEvalProvider(), id="restricted"), +] + def test_type_to_encodable_type_term(): with pytest.raises(TypeError): @@ -726,7 +734,8 @@ class Person(pydantic.BaseModel): class TestCallableEncodable: """Tests for CallableEncodable - encoding/decoding callables as SynthesizedFunction.""" - def test_encode_decode_function(self): + @pytest.mark.parametrize("eval_provider", EVAL_PROVIDERS) + def test_encode_decode_function(self, eval_provider): def add(a: int, b: int) -> int: return a + b @@ -737,13 +746,14 @@ def add(a: int, b: int) -> int: assert "def add" in encoded.module_code assert "return a + b" in encoded.module_code - with handler(UnsafeEvalProvider()): + with handler(eval_provider): decoded = encodable.decode(encoded) assert callable(decoded) assert decoded(2, 3) == 5 assert decoded.__name__ == "add" - def test_decode_with_ellipsis_params(self): + @pytest.mark.parametrize("eval_provider", EVAL_PROVIDERS) + def test_decode_with_ellipsis_params(self, eval_provider): # Callable[..., int] allows any params but validates return type encodable = Encodable.define(Callable[..., int], {}) @@ -751,12 +761,13 @@ def test_decode_with_ellipsis_params(self): func_source = SynthesizedFunction( module_code="def double(x) -> int:\n return x * 2" ) - with handler(UnsafeEvalProvider()): + with handler(eval_provider): decoded = encodable.decode(func_source) assert callable(decoded) assert decoded(5) == 10 - def test_decode_with_env(self): + @pytest.mark.parametrize("eval_provider", EVAL_PROVIDERS) + def 
test_decode_with_env(self, eval_provider): # Test decoding a function that uses env variables encodable = Encodable.define(Callable[..., int], {"factor": 3}) source = SynthesizedFunction( @@ -764,7 +775,7 @@ def test_decode_with_env(self): return x * factor""" ) - with handler(UnsafeEvalProvider()): + with handler(eval_provider): decoded = encodable.decode(source) assert callable(decoded) assert decoded(4) == 12 @@ -796,17 +807,19 @@ def __call__(self): with pytest.raises(RuntimeError, match="no source code and no docstring"): encodable.encode(NoDocCallable()) - def test_decode_no_function_at_end_raises(self): + @pytest.mark.parametrize("eval_provider", EVAL_PROVIDERS) + def test_decode_no_function_at_end_raises(self, eval_provider): encodable = Encodable.define(Callable[..., int], {}) # Source code where last statement is not a function definition source = SynthesizedFunction(module_code="x = 42") with pytest.raises( ValueError, match="last statement to be a function definition" ): - with handler(UnsafeEvalProvider()): + with handler(eval_provider): encodable.decode(source) - def test_decode_multiple_functions_uses_last(self): + @pytest.mark.parametrize("eval_provider", EVAL_PROVIDERS) + def test_decode_multiple_functions_uses_last(self, eval_provider): encodable = Encodable.define(Callable[..., int], {}) # Source code that defines multiple functions - should use the last one source = SynthesizedFunction( @@ -816,13 +829,14 @@ def test_decode_multiple_functions_uses_last(self): def bar() -> int: return 2""" ) - with handler(UnsafeEvalProvider()): + with handler(eval_provider): decoded = encodable.decode(source) assert callable(decoded) assert decoded.__name__ == "bar" assert decoded() == 2 - def test_decode_class_raises(self): + @pytest.mark.parametrize("eval_provider", EVAL_PROVIDERS) + def test_decode_class_raises(self, eval_provider): encodable = Encodable.define(Callable[..., int], {}) # Classes are callable but the last statement must be a function definition 
source = SynthesizedFunction( @@ -837,15 +851,16 @@ def greet(self): with pytest.raises( ValueError, match="last statement to be a function definition" ): - with handler(UnsafeEvalProvider()): + with handler(eval_provider): encodable.decode(source) - def test_roundtrip(self): + @pytest.mark.parametrize("eval_provider", EVAL_PROVIDERS) + def test_roundtrip(self, eval_provider): def greet(name: str) -> str: return f"Hello, {name}!" encodable = Encodable.define(Callable[[str], str], {}) - with handler(UnsafeEvalProvider()): + with handler(eval_provider): encoded = encodable.encode(greet) decoded = encodable.decode(encoded) @@ -871,7 +886,8 @@ def add(a: int, b: int) -> int: assert isinstance(deserialized, SynthesizedFunction) assert "def add" in deserialized.module_code - def test_decode_validates_last_statement(self): + @pytest.mark.parametrize("eval_provider", EVAL_PROVIDERS) + def test_decode_validates_last_statement(self, eval_provider): encodable = Encodable.define(Callable[..., int], {}) # Helper function followed by assignment - should fail @@ -884,7 +900,7 @@ def test_decode_validates_last_statement(self): with pytest.raises( ValueError, match="last statement to be a function definition" ): - with handler(UnsafeEvalProvider()): + with handler(eval_provider): encodable.decode(source) def test_typed_callable_includes_signature_in_docstring(self): @@ -894,7 +910,8 @@ def test_typed_callable_includes_signature_in_docstring(self): assert "Callable[[int, int], int]" in encodable.enc.__doc__ assert "" in encodable.enc.__doc__ - def test_typed_callable_validates_param_count(self): + @pytest.mark.parametrize("eval_provider", EVAL_PROVIDERS) + def test_typed_callable_validates_param_count(self, eval_provider): encodable = Encodable.define(Callable[[int, int], int], {}) # Function with wrong number of parameters @@ -903,10 +920,11 @@ def test_typed_callable_validates_param_count(self): return a""" ) with pytest.raises(ValueError, match="expected function with 2 
parameters"): - with handler(UnsafeEvalProvider()): + with handler(eval_provider): encodable.decode(source) - def test_typed_callable_validates_return_type(self): + @pytest.mark.parametrize("eval_provider", EVAL_PROVIDERS) + def test_typed_callable_validates_return_type(self, eval_provider): encodable = Encodable.define(Callable[[int, int], int], {}) # Function with wrong return type @@ -915,10 +933,11 @@ def test_typed_callable_validates_return_type(self): return str(a + b)""" ) with pytest.raises(ValueError, match="expected function with return type int"): - with handler(UnsafeEvalProvider()): + with handler(eval_provider): encodable.decode(source) - def test_typed_callable_requires_return_annotation(self): + @pytest.mark.parametrize("eval_provider", EVAL_PROVIDERS) + def test_typed_callable_requires_return_annotation(self, eval_provider): encodable = Encodable.define(Callable[[int, int], int], {}) # Function missing return type annotation @@ -930,10 +949,11 @@ def test_typed_callable_requires_return_annotation(self): ValueError, match="requires synthesized function to have a return type annotation", ): - with handler(UnsafeEvalProvider()): + with handler(eval_provider): encodable.decode(source) - def test_typed_callable_accepts_correct_signature(self): + @pytest.mark.parametrize("eval_provider", EVAL_PROVIDERS) + def test_typed_callable_accepts_correct_signature(self, eval_provider): encodable = Encodable.define(Callable[[int, int], int], {}) # Function with correct signature @@ -941,12 +961,13 @@ def test_typed_callable_accepts_correct_signature(self): module_code="""def add(a: int, b: int) -> int: return a + b""" ) - with handler(UnsafeEvalProvider()): + with handler(eval_provider): result = encodable.decode(source) assert callable(result) assert result(2, 3) == 5 - def test_ellipsis_callable_skips_param_validation(self): + @pytest.mark.parametrize("eval_provider", EVAL_PROVIDERS) + def test_ellipsis_callable_skips_param_validation(self, eval_provider): # 
Callable[..., int] should skip param validation but still validate return encodable = Encodable.define(Callable[..., int], {}) @@ -954,7 +975,7 @@ def test_ellipsis_callable_skips_param_validation(self): module_code="""def anything(a, b, c, d, e) -> int: return 42""" ) - with handler(UnsafeEvalProvider()): + with handler(eval_provider): result = encodable.decode(source) assert callable(result) assert result(1, 2, 3, 4, 5) == 42 @@ -983,7 +1004,8 @@ def test_typed_callable_json_schema_different_signatures(self): assert "Callable[[str], str]" in schema1["description"] assert "Callable[[int, int, int], bool]" in schema2["description"] - def test_validates_param_count_via_ast(self): + @pytest.mark.parametrize("eval_provider", EVAL_PROVIDERS) + def test_validates_param_count_via_ast(self, eval_provider): # Test that param validation happens via AST analysis encodable = Encodable.define(Callable[[int, int], int], {}) @@ -993,10 +1015,11 @@ def test_validates_param_count_via_ast(self): return a + b + c""" ) with pytest.raises(ValueError, match="expected function with 2 parameters"): - with handler(UnsafeEvalProvider()): + with handler(eval_provider): encodable.decode(source) - def test_validates_param_count_zero_params(self): + @pytest.mark.parametrize("eval_provider", EVAL_PROVIDERS) + def test_validates_param_count_zero_params(self, eval_provider): # Test callable with no params encodable = Encodable.define(Callable[[], int], {}) @@ -1006,10 +1029,11 @@ def test_validates_param_count_zero_params(self): return x""" ) with pytest.raises(ValueError, match="expected function with 0 parameters"): - with handler(UnsafeEvalProvider()): + with handler(eval_provider): encodable.decode(source) - def test_validates_accepts_zero_params(self): + @pytest.mark.parametrize("eval_provider", EVAL_PROVIDERS) + def test_validates_accepts_zero_params(self, eval_provider): # Test callable with no params - correct signature encodable = Encodable.define(Callable[[], int], {}) @@ -1017,7 
+1041,7 @@ def test_validates_accepts_zero_params(self): module_code="""def get_value() -> int: return 42""" ) - with handler(UnsafeEvalProvider()): + with handler(eval_provider): result = encodable.decode(source) assert callable(result) assert result() == 42 @@ -1031,7 +1055,8 @@ def test_ellipsis_callable_json_schema_includes_signature(self): assert "Callable[[...], int]" in schema["description"] assert "" in schema["description"] - def test_ellipsis_callable_validates_return_type(self): + @pytest.mark.parametrize("eval_provider", EVAL_PROVIDERS) + def test_ellipsis_callable_validates_return_type(self, eval_provider): # Callable[..., int] should still validate return type encodable = Encodable.define(Callable[..., int], {}) @@ -1040,41 +1065,44 @@ def test_ellipsis_callable_validates_return_type(self): return "wrong type\"""" ) with pytest.raises(ValueError, match="expected function with return type int"): - with handler(UnsafeEvalProvider()): + with handler(eval_provider): encodable.decode(source) - def test_callable_with_single_param(self): + @pytest.mark.parametrize("eval_provider", EVAL_PROVIDERS) + def test_callable_with_single_param(self, eval_provider): encodable = Encodable.define(Callable[[str], int], {}) source = SynthesizedFunction( module_code="""def count_chars(s: str) -> int: return len(s)""" ) - with handler(UnsafeEvalProvider()): + with handler(eval_provider): result = encodable.decode(source) assert callable(result) assert result("hello") == 5 - def test_callable_with_many_params(self): + @pytest.mark.parametrize("eval_provider", EVAL_PROVIDERS) + def test_callable_with_many_params(self, eval_provider): encodable = Encodable.define(Callable[[int, int, int, int], int], {}) source = SynthesizedFunction( module_code="""def sum_four(a: int, b: int, c: int, d: int) -> int: return a + b + c + d""" ) - with handler(UnsafeEvalProvider()): + with handler(eval_provider): result = encodable.decode(source) assert callable(result) assert result(1, 2, 3, 4) == 
10 - def test_callable_with_bool_return(self): + @pytest.mark.parametrize("eval_provider", EVAL_PROVIDERS) + def test_callable_with_bool_return(self, eval_provider): encodable = Encodable.define(Callable[[int], bool], {}) source = SynthesizedFunction( module_code="""def is_positive(x: int) -> bool: return x > 0""" ) - with handler(UnsafeEvalProvider()): + with handler(eval_provider): result = encodable.decode(source) assert callable(result) assert result(5) is True @@ -1098,3 +1126,96 @@ def test_callable_type_variations_schema(self): f"Expected {expected_sig} in schema for {callable_type}, " f"got: {schema['description'][:100]}..." ) + + +class TestRestrictedEvalProviderConfig: + """Tests for RestrictedEvalProvider policy plumbing and blocked operations.""" + + def test_restricted_blocks_private_attribute_access(self): + """RestrictedPython blocks access to underscore-prefixed attributes by default.""" + encodable = Encodable.define(Callable[[str], int], {}) + source = SynthesizedFunction( + module_code="""def get_private(s: str) -> int: + return s.__class__.__name__""" + ) + # Should raise due to restricted attribute access + with pytest.raises(Exception): # Could be NameError or AttributeError + with handler(RestrictedEvalProvider()): + fn = encodable.decode(source) + fn("test") + + def test_restricted_with_custom_policy(self): + """A custom policy class can be passed via the ``policy`` keyword argument.""" + + # Create a custom policy that's the same as default (just to test the plumbing) + class CustomPolicy(RestrictingNodeTransformer): + pass + + encodable = Encodable.define(Callable[[int, int], int], {}) + source = SynthesizedFunction( + module_code="""def add(a: int, b: int) -> int: + return a + b""" + ) + with handler(RestrictedEvalProvider(policy=CustomPolicy)): + fn = encodable.decode(source) + assert fn(2, 3) == 5 + + def test_builtins_in_env_does_not_bypass_security(self): + """Including __builtins__ in env should not bypass RestrictedEvalProvider security. 
+ + RestrictedEvalProvider explicitly filters out __builtins__ from the env + to prevent callers from replacing the restricted builtins with full Python builtins. + This test verifies that even if __builtins__ is passed in the context, + dangerous operations remain blocked. + """ + + # Attempt to pass full builtins in the context, which should be filtered out + dangerous_ctx = {"__builtins__": builtins.__dict__} + + # Test 1: open() should not be usable even with __builtins__ in context + # The function may fail at compile/exec time or at call time, but either way + # it should not be able to actually open files + encodable_open = Encodable.define(Callable[[str], str], dangerous_ctx) + source_open = SynthesizedFunction( + module_code="""def read_file(path: str) -> str: + return open(path).read()""" + ) + with pytest.raises(Exception): # Could be NameError, ValueError, or other + with handler(RestrictedEvalProvider()): + fn = encodable_open.decode(source_open) + # Even if decode succeeds, the call itself must still fail + fn("/etc/passwd") + + # Test 2: __import__ should not be usable + encodable_import = Encodable.define(Callable[[], str], dangerous_ctx) + source_import = SynthesizedFunction( + module_code="""def get_os_name() -> str: + os = __import__('os') + return os.name""" + ) + with pytest.raises(Exception): + with handler(RestrictedEvalProvider()): + fn = encodable_import.decode(source_import) + fn() + + # Test 3: Verify safe code still works with dangerous context + # This confirms we're not just breaking everything + encodable_safe = Encodable.define(Callable[[int, int], int], dangerous_ctx) + source_safe = SynthesizedFunction( + module_code="""def add(a: int, b: int) -> int: + return a + b""" + ) + with handler(RestrictedEvalProvider()): + fn = encodable_safe.decode(source_safe) + assert fn(2, 3) == 5, "Safe code should still work" + + # Test 4: Private attribute access should still be blocked + encodable_private = Encodable.define(Callable[[str], str], 
dangerous_ctx) + source_private = SynthesizedFunction( + module_code="""def get_class(s: str) -> str: + return s.__class__.__name__""" + ) + with pytest.raises(Exception): + with handler(RestrictedEvalProvider()): + fn = encodable_private.decode(source_private) + fn("test")