diff --git a/envs/coding_env/server/__init__.py b/envs/coding_env/server/__init__.py index dab6b748a..33f7e7894 100644 --- a/envs/coding_env/server/__init__.py +++ b/envs/coding_env/server/__init__.py @@ -4,8 +4,27 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -"""Coding environment server components.""" +"""Coding environment server components. -from .python_codeact_env import PythonCodeActEnv +Keep imports lazy so utility modules (for example transforms) remain importable +without pulling optional runtime dependencies like smolagents. +""" + +from typing import Any, TYPE_CHECKING + +if TYPE_CHECKING: + from .python_codeact_env import PythonCodeActEnv __all__ = ["PythonCodeActEnv"] + + +def __dir__() -> list[str]: + return sorted({*globals(), *__all__}) + + +def __getattr__(name: str) -> Any: + if name == "PythonCodeActEnv": + from .python_codeact_env import PythonCodeActEnv + + return PythonCodeActEnv + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/envs/coding_env/server/app.py b/envs/coding_env/server/app.py index 4c712916b..2271b69de 100644 --- a/envs/coding_env/server/app.py +++ b/envs/coding_env/server/app.py @@ -21,9 +21,10 @@ python -m envs.coding_env.server.app """ +from openenv.core.env_server import create_app + from coding_env.models import CodeAction, CodeObservation from coding_env.server.python_codeact_env import PythonCodeActEnv -from openenv.core.env_server import create_app # Create the app with web interface and README integration # Pass the class (factory) instead of an instance for WebSocket session support diff --git a/envs/coding_env/server/python_codeact_env.py b/envs/coding_env/server/python_codeact_env.py index dbfc39e6a..aa96bfa15 100644 --- a/envs/coding_env/server/python_codeact_env.py +++ b/envs/coding_env/server/python_codeact_env.py @@ -12,6 +12,7 @@ """ import uuid +from typing import Any, Optional from openenv.core.env_server.interfaces import Action, Environment, Observation @@ -50,15 +51,33 @@ def __init__( self._executor = PyExecutor() self._state = CodeState() - def reset(self) -> Observation: + def reset( + self, + seed: Optional[int] = None, + episode_id: Optional[str] = None, + **kwargs: Any, + ) -> Observation: """ Reset environment and start fresh execution session. + Args: + seed: Accepted for API compatibility. This deterministic executor + has no random state to seed. + episode_id: Optional episode identifier override. If omitted or + empty, a new episode ID is generated. + **kwargs: Forward-compatible reset parameters accepted by the base + Environment API but unused by this environment. + Returns: Initial observation with empty stdout/stderr and exit_code=0 """ + del seed, kwargs + # Initialize fresh state - self._state = CodeState(episode_id=str(uuid.uuid4()), step_count=0) + self._state = CodeState( + episode_id=episode_id or str(uuid.uuid4()), + step_count=0, + ) # Add last_exit_code to state self._state.last_exit_code = 0 @@ -77,12 +96,21 @@ def reset(self) -> Observation: return self._apply_transform(observation) - def step(self, action: Action) -> Observation: + def step( + self, + action: Action, + timeout_s: Optional[float] = None, + **kwargs: Any, + ) -> Observation: """ Execute code action and return observation. Args: action: CodeAction containing the code to execute + timeout_s: Accepted for Environment API compatibility. PyExecutor + does not currently expose per-call timeout control. + **kwargs: Forward-compatible step parameters accepted by the base + Environment API but unused by this environment. Returns: CodeObservation with execution results (stdout, stderr, exit_code) @@ -90,6 +118,8 @@ def step(self, action: Action) -> Observation: Raises: ValueError: If action is not a CodeAction instance """ + del timeout_s, kwargs + if not isinstance(action, CodeAction): raise ValueError(f"Expected CodeAction, got {type(action)}") diff --git a/envs/coding_env/server/transforms.py b/envs/coding_env/server/transforms.py index fc92e89ba..a03a77cff 100644 --- a/envs/coding_env/server/transforms.py +++ b/envs/coding_env/server/transforms.py @@ -17,33 +17,80 @@ class CodeSafetyTransform(Transform): - """Evaluates code safety and assigns penalties for dangerous patterns.""" + """ + Assign penalties for obviously unsafe coding patterns. + + This is a reward heuristic, not a security sandbox. Container isolation is + the security boundary; this transform only shapes rewards for common cases. + """ def __init__(self, penalty: float = -1.0): self.penalty = penalty - self.dangerous_patterns = [ - r"import\s+os", - r"import\s+subprocess", - r"eval\(", - r"exec\(", - r"__import__", - r"open\(", + self._fallback_patterns = [ + (re.compile(r"\bimport\s+os\b"), "import os"), + (re.compile(r"\bimport\s+subprocess\b"), "import subprocess"), + (re.compile(r"\beval\s*\("), "eval"), + (re.compile(r"\bexec\s*\("), "exec"), + (re.compile(r"\b__import__\s*\("), "__import__"), + (re.compile(r"\bopen\s*\("), "open"), + (re.compile(r"\.open\s*\("), "open"), ] + def _detect_text_violation(self, code: str) -> str | None: + for pattern, violation in self._fallback_patterns: + if pattern.search(code): + return violation + return None + + def _detect_violation(self, code: str) -> str | None: + """ + Detect dangerous operations using AST analysis. + + AST-based detection avoids false positives from harmless string literals + (e.g. ``print("import os")``) or similarly named user functions + (e.g. ``myopen()``). + """ + try: + tree = ast.parse(code) + except (SyntaxError, RecursionError, ValueError): + # Fall back to the previous raw-text heuristic when AST parsing + # cannot inspect malformed or pathologically nested code. + return self._detect_text_violation(code) + + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + top_level_module = alias.name.split(".", 1)[0] + if top_level_module in {"os", "subprocess"}: + return f"import {top_level_module}" + + if isinstance(node, ast.ImportFrom) and node.module: + top_level_module = node.module.split(".", 1)[0] + if top_level_module in {"os", "subprocess"}: + return f"import {top_level_module}" + + if isinstance(node, ast.Call): + if isinstance(node.func, ast.Name): + called_name = node.func.id + if called_name in {"eval", "exec", "open", "__import__"}: + return called_name + if isinstance(node.func, ast.Attribute) and node.func.attr == "open": + return "open" + + return None + def __call__(self, observation: Observation) -> Observation: if not isinstance(observation, CodeObservation): return observation if "last_code" in observation.metadata: code = observation.metadata["last_code"] - for pattern in self.dangerous_patterns: - if re.search(pattern, code): - observation.reward = self.penalty - observation.metadata["safety_violation"] = pattern - break - else: - if observation.reward is None: - observation.reward = 0.0 + violation = self._detect_violation(code) + if violation is not None: + observation.reward = self.penalty + observation.metadata["safety_violation"] = violation + elif observation.reward is None: + observation.reward = 0.0 return observation @@ -77,7 +124,7 @@ def __call__(self, observation: Observation) -> Observation: # Check syntax (redundant but useful for quality assessment) try: ast.parse(code) - except SyntaxError: + except (SyntaxError, RecursionError, ValueError): quality_score += self.syntax_penalty # Add to existing reward diff --git a/tests/envs/test_coding_safety_transform.py b/tests/envs/test_coding_safety_transform.py new file mode 100644 index 000000000..7e9429eda --- /dev/null +++ b/tests/envs/test_coding_safety_transform.py @@ -0,0 +1,138 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Tests for coding_env safety transform false-positive handling.""" + +from unittest.mock import patch + +from coding_env.models import CodeObservation +from coding_env.server.transforms import CodeQualityTransform, CodeSafetyTransform + + +def _apply_safety_transform(code: str) -> CodeObservation: + transform = CodeSafetyTransform() + observation = CodeObservation( + stdout="", + stderr="", + exit_code=0, + metadata={"last_code": code}, + ) + transformed = transform(observation) + assert isinstance(transformed, CodeObservation) + return transformed + + +def test_blocks_real_dangerous_import(): + observation = _apply_safety_transform("import os\nprint('x')") + assert observation.reward == -1.0 + assert "safety_violation" in observation.metadata + + +def test_blocks_import_with_alias(): + observation = _apply_safety_transform("import os as operating_system") + assert observation.reward == -1.0 + assert observation.metadata["safety_violation"] == "import os" + + +def test_blocks_subprocess_import(): + observation = _apply_safety_transform("import subprocess") + assert observation.reward == -1.0 + assert observation.metadata["safety_violation"] == "import subprocess" + + +def test_blocks_from_subprocess_import(): + observation = _apply_safety_transform("from subprocess import run") + assert observation.reward == -1.0 + assert observation.metadata["safety_violation"] == "import subprocess" + + +def test_blocks_from_os_path_import(): + observation = _apply_safety_transform("from os.path import join") + assert observation.reward == -1.0 + assert observation.metadata["safety_violation"] == "import os" + + +def test_blocks_builtin_open_call(): + observation = _apply_safety_transform( + "with open('f.txt') as f:\n data = f.read()" + ) + assert observation.reward == -1.0 + assert "safety_violation" in observation.metadata + + +def test_blocks_attribute_open_call(): + observation = _apply_safety_transform("Path('f.txt').open()") + assert observation.reward == -1.0 + assert observation.metadata["safety_violation"] == "open" + + +def test_blocks_raw_text_violation_when_parse_fails(): + observation = _apply_safety_transform("import os\n\x00") + assert observation.reward == -1.0 + assert observation.metadata["safety_violation"] == "import os" + + +def test_blocks_builtin_eval_call(): + observation = _apply_safety_transform("result = eval('1 + 1')") + assert observation.reward == -1.0 + assert observation.metadata["safety_violation"] == "eval" + + +def test_blocks_builtin_exec_call(): + observation = _apply_safety_transform("exec('x = 1')") + assert observation.reward == -1.0 + assert observation.metadata["safety_violation"] == "exec" + + +def test_blocks_builtin_import_call(): + observation = _apply_safety_transform("__import__('os')") + assert observation.reward == -1.0 + assert observation.metadata["safety_violation"] == "__import__" + + +def test_does_not_flag_string_literal_with_dangerous_text(): + observation = _apply_safety_transform("print('import os')") + assert observation.reward == 0.0 + assert "safety_violation" not in observation.metadata + + +def test_does_not_flag_user_defined_myopen_function(): + observation = _apply_safety_transform( + "def myopen():\n return 1\nresult = myopen()" + ) + assert observation.reward == 0.0 + assert "safety_violation" not in observation.metadata + + +def test_does_not_flag_attribute_method_named_exec(): + observation = _apply_safety_transform( + "class DB:\n" + " def exec(self, sql):\n" + " return sql\n" + "db = DB()\n" + "result = db.exec('SELECT 1')" + ) + assert observation.reward == 0.0 + assert "safety_violation" not in observation.metadata + + +def test_quality_transform_handles_ast_recursion_error(): + def raise_recursion_error(_code: str): + raise RecursionError("pathologically nested code") + + transform = CodeQualityTransform(concise_bonus=0.0, syntax_penalty=-0.2) + observation = CodeObservation( + stdout="", + stderr="", + exit_code=0, + metadata={"last_code": "x = 1"}, + ) + + with patch("coding_env.server.transforms.ast.parse", raise_recursion_error): + transformed = transform(observation) + + assert isinstance(transformed, CodeObservation) + assert transformed.reward == -0.2 diff --git a/tests/envs/test_python_codeact_reset.py b/tests/envs/test_python_codeact_reset.py index b4d8b59f1..a7336a9e9 100644 --- a/tests/envs/test_python_codeact_reset.py +++ b/tests/envs/test_python_codeact_reset.py @@ -166,3 +166,43 @@ def test_reset_changes_episode_id(): # Episode IDs should be different assert episode_id_1 != episode_id_2 + + +def test_reset_accepts_episode_id_override(): + """Test that reset() accepts an explicit episode_id.""" + env = PythonCodeActEnv() + + env.reset(episode_id="episode-123") + + assert env.state.episode_id == "episode-123" + assert env.state.step_count == 0 + + +def test_reset_accepts_seed_parameter(): + """Test that reset() accepts a seed for API compatibility.""" + env = PythonCodeActEnv() + + obs = env.reset(seed=42) + + assert obs.exit_code == 0 + assert env.state.step_count == 0 + + +def test_reset_replaces_empty_episode_id_override(): + """Test that reset() replaces an empty episode_id with a generated ID.""" + env = PythonCodeActEnv() + + env.reset(episode_id="") + + assert env.state.episode_id + assert env.state.step_count == 0 + + +def test_step_accepts_timeout_parameter(): + """Test that step() accepts timeout_s for API compatibility.""" + env = PythonCodeActEnv() + env.reset() + + obs = env.step(CodeAction(code="print('ok')"), timeout_s=30.0) + + assert obs.exit_code == 0