diff --git a/README.md b/README.md index 0fd3e1ebb..ada04b456 100644 --- a/README.md +++ b/README.md @@ -273,6 +273,7 @@ See example scripts in `examples/` directory. The OpenEnv CLI provides commands to manage environments: - **`openenv init `** - Initialize a new environment from template +- **`openenv import --name --output-dir `** - Wrap a supported third-party source environment, including ORS/OpenReward and Verifiers, as OpenEnv - **`openenv push [--repo-id ] [--private]`** - Deploy environment to Hugging Face Spaces ### Quick Start @@ -281,12 +282,15 @@ The OpenEnv CLI provides commands to manage environments: # Create a new environment openenv init my_game_env +# Or import an ORS/OpenReward or Verifiers source environment +openenv import path/to/source --name my_game_env --output-dir . + # Deploy to Hugging Face (will prompt for login if needed) cd my_game_env openenv push ``` -For detailed options: `openenv init --help` and `openenv push --help`. +For detailed options: `openenv init --help`, `openenv import --help`, and `openenv push --help`. ## Design Principles diff --git a/docs/source/getting_started/environment-builder.md b/docs/source/getting_started/environment-builder.md index 8d38d9d4c..e96d75e38 100644 --- a/docs/source/getting_started/environment-builder.md +++ b/docs/source/getting_started/environment-builder.md @@ -30,6 +30,7 @@ Already familiar with OpenEnv? Here's the 8-step process at a glance: | Command | Description | |---------|-------------| | `openenv init NAME` | Scaffold new environment | +| `openenv import SOURCE --name NAME --output-dir DIR` | Wrap a supported third-party environment source tree | | `openenv serve` | Start local dev server | | `openenv build` | Build Docker image | | `openenv validate --verbose` | Validate environment structure | @@ -101,6 +102,12 @@ my_env/ Python classes are generated for the action, observation, environment, and client. For example, you will find `MyEnvironment`, `MyAction`, `MyObservation`, and `MyEnv` (client) in the `my_env` directory based on the name you provided. The environment uses the core `State` class from `openenv.core.env_server.types`. +If you already have an ORS/OpenReward or Prime Intellect Verifiers environment +source tree, use `openenv import SOURCE --name my_env --output-dir /Users/you/envs` +instead. The importer detects the source type from the code, vendors the source +under the generated package, and emits an OpenEnv wrapper with task/split and +MCP-style tool actions. + ### 2. Define Models Edit `models.py` to describe your action and observation using Pydantic: diff --git a/docs/source/guides/environment-anatomy.md b/docs/source/guides/environment-anatomy.md index 0bd6804ce..7849e7b31 100644 --- a/docs/source/guides/environment-anatomy.md +++ b/docs/source/guides/environment-anatomy.md @@ -101,6 +101,18 @@ app = create_app( This is what the environment's `server/app.py` entry point typically does — see `envs/echo_env/server/app.py` for a minimal real example. +## Optional Task APIs + +Environments that expose reusable datasets can implement the optional task and +split `TaskProvider` protocol: `list_splits`, `list_tasks`, `num_tasks`, +`get_task`, and `get_task_range`. These are discovery methods, not part of the +core step/reset environment contract, and should be side-effect-free because +compatibility routes may call them on short-lived environment instances. OpenEnv +serves them through ORS-compatible endpoints such as `/list_environments`, +`/{env_name}/splits`, and `/{env_name}/task_range`. Existing environments that +do not implement these methods continue to work; task routes return `501` for +unsupported APIs. + ## Rewards via the Rubric Rewards are computed **inside the environment**, not by external code. The base `Environment` accepts an optional `rubric` on `__init__` — pass it to `super().__init__(rubric=...)`, call `self._reset_rubric()` from `reset`, and `self._apply_rubric(action, observation)` from `step` (or `_apply_rubric_async` from `step_async`). The [Rubrics tutorial](../tutorials/rubrics.md) covers the composable API end-to-end. diff --git a/docs/source/reference/cli.md b/docs/source/reference/cli.md index 1959b3b36..7a8736b03 100644 --- a/docs/source/reference/cli.md +++ b/docs/source/reference/cli.md @@ -11,6 +11,30 @@ The `openenv` CLI provides a set of commands for building, validating, and pushi :show-inheritance: ``` +## `openenv import` + +Import a supported third-party source environment into a generated OpenEnv +wrapper package. The command detects the source format from the directory +contents, so ORS/OpenReward and Prime Intellect Verifiers sources do not +require `--type` in the common case. + +The generated wrapper vendors the source tree into the package. The importer +skips VCS/cache/build directories and common secret file patterns such as +`.env`, `secrets.yaml`, and private key files; review the generated `vendor/` +directory before publishing a wrapper. + +```bash +openenv import path/to/source --name my_env --output-dir ./envs +openenv import path/to/source --name my_env --output-dir ./envs --env-class MyEnv +``` + +```{eval-rst} +.. automodule:: openenv.cli.commands.import_env + :members: + :undoc-members: + :show-inheritance: +``` + ## `openenv build` ```{eval-rst} diff --git a/examples/import_integrations_demo.sh b/examples/import_integrations_demo.sh new file mode 100755 index 000000000..4b5b0ee9d --- /dev/null +++ b/examples/import_integrations_demo.sh @@ -0,0 +1,355 @@ +#!/usr/bin/env bash +set -euo pipefail + +REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +WORK_DIR="${OPENENV_IMPORT_DEMO_DIR:-$(mktemp -d "${TMPDIR:-/tmp}/openenv-import-demo.XXXXXX")}" +GENERATED_DIR="$WORK_DIR/generated" +ORS_SOURCE="$WORK_DIR/ors_source" +VERIFIERS_SOURCE="$WORK_DIR/verifiers_source" +PIDS=() + +cleanup() { + for pid in "${PIDS[@]:-}"; do + if kill -0 "$pid" >/dev/null 2>&1; then + kill "$pid" >/dev/null 2>&1 || true + wait "$pid" >/dev/null 2>&1 || true + fi + done + if [[ -z "${OPENENV_IMPORT_DEMO_KEEP:-}" ]]; then + rm -rf "$WORK_DIR" + else + printf 'Keeping demo workspace: %s\n' "$WORK_DIR" + fi +} +trap cleanup EXIT + +real_uv="$(command -v uv)" +mkdir -p "$WORK_DIR/fakebin" +cat > "$WORK_DIR/fakebin/uv" < bool: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + try: + sock.bind(("127.0.0.1", port)) + except OSError: + return False + return True + +if preferred > 0 and available(preferred): + print(preferred) +else: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + print(sock.getsockname()[1]) +PY +} + +write_sources() { + mkdir -p "$ORS_SOURCE/ors" "$VERIFIERS_SOURCE/verifiers" "$GENERATED_DIR" + + cat > "$ORS_SOURCE/ors/__init__.py" <<'PY' +from .environment import ( + Environment, + ListToolsOutput, + RunToolOutput, + Split, + TextBlock, + ToolOutput, + ToolSpec, +) +PY + + cat > "$ORS_SOURCE/ors/environment.py" <<'PY' +class Model: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + def model_dump(self): + return dict(self.__dict__) + + +class Split(Model): + pass + + +class ToolSpec(Model): + pass + + +class ListToolsOutput(Model): + pass + + +class TextBlock(Model): + def __init__(self, text, detail=None, type="text"): + super().__init__(text=text, detail=detail, type=type) + + +class ToolOutput(Model): + pass + + +class RunToolSuccess: + ok = True + + def __init__(self, output): + self.output = output + + +class RunToolOutput: + def __init__(self, output): + self.root = RunToolSuccess(output) + + +class Environment: + def __init__(self, task_spec=None, secrets=None): + self.task_spec = task_spec or {} + self.secrets = secrets or {} + + def setup(self): + return None + + def teardown(self): + return None +PY + + cat > "$ORS_SOURCE/demo_env.py" <<'PY' +from ors import Environment, ListToolsOutput, RunToolOutput, Split, TextBlock, ToolOutput, ToolSpec + + +class DemoEnvironment(Environment): + @classmethod + def list_splits(cls): + return [Split(name="train", type="train")] + + @classmethod + def list_tasks(cls, split): + return [{"id": "task-1", "question": "What is 2 + 2?", "answer": "4"}] + + @classmethod + def num_tasks(cls, split): + return len(cls.list_tasks(split)) + + @classmethod + def get_task(cls, split, index): + return cls.list_tasks(split)[index] + + @classmethod + def get_task_range(cls, split, start=None, stop=None): + return cls.list_tasks(split)[slice(start, stop)] + + @classmethod + def list_tools(cls): + return ListToolsOutput( + tools=[ + ToolSpec( + name="answer", + description="Submit an answer", + input_schema={"type": "object", "properties": {"value": {"type": "string"}}}, + ) + ] + ) + + def get_prompt(self): + return [TextBlock(text=self.task_spec["question"])] + + async def _call_tool(self, name, input): + value = str(input.get("value", "")) + correct = value == self.task_spec["answer"] + return RunToolOutput( + ToolOutput( + blocks=[TextBlock(text="correct" if correct else "wrong")], + metadata={"submitted": value}, + reward=1.0 if correct else 0.0, + finished=True, + ) + ) +PY + + cat > "$VERIFIERS_SOURCE/verifiers/__init__.py" <<'PY' +class Environment: + pass + + +class Rubric: + def __init__(self, funcs=None): + self.funcs = funcs or [self.exact_match] + + async def exact_match(self, completion, answer, **kwargs): + text = completion[-1]["content"] if completion else "" + return 1.0 if answer and answer in text else 0.0 + + async def score_rollout(self, state): + metrics = {} + reward = 0.0 + for func in self.funcs: + score = await func( + completion=state.get("completion") or [], + answer=state.get("answer") or state.get("task", {}).get("answer", ""), + state=state, + ) + metrics[getattr(func, "__name__", "reward")] = float(score) + reward += float(score) + state["reward"] = reward + state["metrics"] = metrics + + +class SingleTurnEnv(Environment): + def __init__(self, dataset, eval_dataset=None, rubric=None): + self._dataset = dataset + self._eval_dataset = eval_dataset or dataset + self.rubric = rubric or Rubric() + + def get_dataset(self): + return self._dataset + + def get_eval_dataset(self): + return self._eval_dataset +PY + + cat > "$VERIFIERS_SOURCE/simple_math.py" <<'PY' +import verifiers as vf + + +def load_environment() -> vf.Environment: + dataset = [ + { + "prompt": [{"role": "user", "content": "What is 2 + 2?"}], + "answer": "4", + "example_id": 0, + } + ] + return vf.SingleTurnEnv(dataset=dataset, rubric=vf.Rubric()) +PY +} + +import_env() { + local source_dir="$1" + local name="$2" + + printf '\n==> openenv import %s --name %s\n' "$source_dir" "$name" + uv run python -m openenv.cli.__main__ import "$source_dir" \ + --name "$name" \ + --output-dir "$GENERATED_DIR" +} + +start_server() { + local package="$1" + local port="$2" + local log_file="$WORK_DIR/$package.log" + + printf '==> starting %s on http://127.0.0.1:%s\n' "$package" "$port" + PYTHONPATH="$REPO_ROOT/src:$GENERATED_DIR${PYTHONPATH:+:$PYTHONPATH}" \ + uv run python -m "$package.server.app" --port "$port" >"$log_file" 2>&1 & + local pid="$!" + PIDS+=("$pid") + + for _ in $(seq 1 80); do + if curl -fsS "http://127.0.0.1:$port/health" >/dev/null 2>&1; then + return 0 + fi + if ! kill -0 "$pid" >/dev/null 2>&1; then + printf 'Server %s exited early. Log:\n' "$package" >&2 + cat "$log_file" >&2 + return 1 + fi + sleep 0.25 + done + + printf 'Timed out waiting for %s. Log:\n' "$package" >&2 + cat "$log_file" >&2 + return 1 +} + +exercise_server() { + local label="$1" + local env_name="$2" + local port="$3" + local tool_name="$4" + local arguments_json="$5" + + LABEL="$label" ENV_NAME="$env_name" PORT="$port" TOOL_NAME="$tool_name" ARGUMENTS_JSON="$arguments_json" \ + uv run python <<'PY' +import json +import os +import urllib.request + +base = f"http://127.0.0.1:{os.environ['PORT']}" +env_name = os.environ["ENV_NAME"] +tool_name = os.environ["TOOL_NAME"] +arguments = json.loads(os.environ["ARGUMENTS_JSON"]) + + +def request(method, path, payload=None): + data = None if payload is None else json.dumps(payload).encode() + headers = {} + if data is not None: + headers["content-type"] = "application/json" + req = urllib.request.Request(base + path, data=data, method=method, headers=headers) + with urllib.request.urlopen(req, timeout=10) as response: + return json.loads(response.read().decode()) + + +def show(name, value): + print(f"\n{os.environ['LABEL']} {name}") + print(json.dumps(value, indent=2, sort_keys=True)) + + +show("environments", request("GET", "/list_environments")) +show("splits", request("GET", f"/{env_name}/splits")) +show("task", request("POST", f"/{env_name}/task", {"split": "train", "index": 0})) +show("tools/list", request("POST", "/mcp", { + "jsonrpc": "2.0", + "id": 1, + "method": "tools/list", + "params": {}, +})) +show("tools/call", request("POST", "/mcp", { + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": {"name": tool_name, "arguments": arguments}, +})) +PY +} + +main() { + printf 'Demo workspace: %s\n' "$WORK_DIR" + write_sources + + import_env "$ORS_SOURCE" "ors_openenv_demo" + import_env "$VERIFIERS_SOURCE" "verifiers_openenv_demo" + + ORS_PORT="${ORS_PORT:-$(choose_port 8000)}" + VERIFIERS_PORT="${VERIFIERS_PORT:-$(choose_port 8001)}" + if [[ "$ORS_PORT" == "$VERIFIERS_PORT" ]]; then + VERIFIERS_PORT="$(choose_port 0)" + fi + + start_server "ors_openenv_demo" "$ORS_PORT" + start_server "verifiers_openenv_demo" "$VERIFIERS_PORT" + + exercise_server "ORS/OpenReward" "ors_openenv_demo" "$ORS_PORT" "answer" '{"value":"4"}' + exercise_server "Verifiers" "verifiers_openenv_demo" "$VERIFIERS_PORT" "submit" '{"completion":"The answer is 4."}' + + printf '\nDemo completed successfully.\n' +} + +main "$@" diff --git a/src/openenv/cli/__main__.py b/src/openenv/cli/__main__.py index b80e5b9fd..5569e4416 100644 --- a/src/openenv/cli/__main__.py +++ b/src/openenv/cli/__main__.py @@ -18,6 +18,7 @@ build, collect, fork, + import_env, init, push, serve, @@ -34,6 +35,9 @@ # Register commands app.command(name="init", help="Initialize a new OpenEnv environment")(init.init) +app.command(name="import", help="Import a third-party environment into OpenEnv")( + import_env.import_env +) app.command(name="build", help="Build Docker images for OpenEnv environments")( build.build ) diff --git a/src/openenv/cli/commands/__init__.py b/src/openenv/cli/commands/__init__.py index f351a32ff..7455fc8a8 100644 --- a/src/openenv/cli/commands/__init__.py +++ b/src/openenv/cli/commands/__init__.py @@ -6,6 +6,15 @@ """OpenEnv CLI commands.""" -from . import build, fork, init, push, serve, skills, validate +from . import build, fork, import_env, init, push, serve, skills, validate -__all__ = ["build", "fork", "init", "push", "serve", "skills", "validate"] +__all__ = [ + "build", + "fork", + "import_env", + "init", + "push", + "serve", + "skills", + "validate", +] diff --git a/src/openenv/cli/commands/import_env.py b/src/openenv/cli/commands/import_env.py new file mode 100644 index 000000000..45acc8059 --- /dev/null +++ b/src/openenv/cli/commands/import_env.py @@ -0,0 +1,145 @@ +"""Import third-party environments into OpenEnv wrappers.""" + +from __future__ import annotations + +import shutil +from pathlib import Path +from typing import Annotated + +import typer +from openenv.cli.importers import DEFAULT_IMPORTERS, ImporterRegistry + +from .._cli_utils import console +from .init import _generate_uv_lock, _validate_env_name + + +def _select_match( + registry: ImporterRegistry, + source: Path, + source_type: str | None, + env_class: str | None, +): + try: + matches = registry.detect(source, source_type=source_type) + except ValueError as e: + raise typer.BadParameter(str(e)) from e + if not matches: + supported = ", ".join(registry.supported_types) + raise typer.BadParameter( + f"No supported environment found in {source}. Supported source types: {supported}." + ) + + if env_class: + matches = [ + match + for match in matches + if match[1].class_name == env_class or match[1].qualified_name == env_class + ] + if not matches: + raise typer.BadParameter(f"No detected environment matched {env_class!r}.") + + detected_types = {importer.source_type for importer, _ in matches} + if source_type is None and len(detected_types) > 1: + raise typer.BadParameter( + "Multiple source formats were detected. Re-run with --type " + f"({'/'.join(sorted(detected_types))})." + ) + + if len(matches) > 1: + choices = ", ".join(detected.qualified_name for _, detected in matches) + raise typer.BadParameter( + f"Multiple environment entrypoints were detected ({choices}). " + "Re-run with --env-class." + ) + + return matches[0] + + +def import_env( + source: Annotated[ + str, + typer.Argument(help="Local source repository or directory to import"), + ], + name: Annotated[ + str, + typer.Option("--name", "-n", help="Name for the generated OpenEnv package"), + ], + output_dir: Annotated[ + str, + typer.Option( + "--output-dir", + "-o", + help="Directory where the generated package will be created", + ), + ], + env_class: Annotated[ + str | None, + typer.Option( + "--env-class", + help="Environment class name or module:Class when detection is ambiguous", + ), + ] = None, + source_type: Annotated[ + str | None, + typer.Option( + "--type", + help="Optional source type override, such as 'ors'", + ), + ] = None, +) -> None: + """Deterministically import a third-party environment into OpenEnv.""" + env_name = _validate_env_name(name) + source_path = Path(source).expanduser().resolve() + if not source_path.exists() or not source_path.is_dir(): + raise typer.BadParameter(f"Source must be an existing directory: {source_path}") + + base_dir = Path(output_dir).expanduser().resolve() + env_dir = base_dir / env_name + try: + env_dir.relative_to(source_path) + except ValueError: + pass + else: + raise typer.BadParameter("Output directory must not be inside the source tree") + + if env_dir.exists(): + if env_dir.is_file(): + raise typer.BadParameter(f"Path '{env_dir}' exists and is a file") + if any(env_dir.iterdir()): + raise typer.BadParameter( + f"Directory '{env_dir}' already exists and is not empty." + ) + + registry = ImporterRegistry(DEFAULT_IMPORTERS) + importer, detected = _select_match( + registry, + source_path, + source_type=source_type, + env_class=env_class, + ) + + try: + env_dir.mkdir(parents=True, exist_ok=True) + console.print( + "[bold cyan]Importing environment[/bold cyan] " + f"{detected.qualified_name} as '{env_name}' ({importer.source_type})" + ) + importer.generate( + source=source_path, + destination=env_dir, + env_name=env_name, + detected=detected, + ) + + console.print("[bold green]OK[/bold green] Generated OpenEnv wrapper") + if _generate_uv_lock(env_dir): + console.print("[green]OK[/green] Generated uv.lock") + else: + console.print("[yellow]Warning:[/yellow] Could not generate uv.lock") + + console.print(f"[bold green]Environment created at: {env_dir}[/bold green]") + except Exception as e: + if env_dir.exists() and env_dir.is_dir(): + shutil.rmtree(env_dir, ignore_errors=True) + console.print(f"[bold red]Error:[/bold red] {e}") + raise typer.Exit(1) from e diff --git a/src/openenv/cli/importers/__init__.py b/src/openenv/cli/importers/__init__.py new file mode 100644 index 000000000..9519278cf --- /dev/null +++ b/src/openenv/cli/importers/__init__.py @@ -0,0 +1,16 @@ +"""Deterministic source importers for OpenEnv environments.""" + +from .base import DetectedEnvironment, EnvironmentImporter, ImporterRegistry +from .ors import ORSImporter +from .verifiers import VerifiersImporter + +DEFAULT_IMPORTERS = [ORSImporter(), VerifiersImporter()] + +__all__ = [ + "DEFAULT_IMPORTERS", + "DetectedEnvironment", + "EnvironmentImporter", + "ImporterRegistry", + "ORSImporter", + "VerifiersImporter", +] diff --git a/src/openenv/cli/importers/base.py b/src/openenv/cli/importers/base.py new file mode 100644 index 000000000..64e63aa9a --- /dev/null +++ b/src/openenv/cli/importers/base.py @@ -0,0 +1,237 @@ +"""Shared importer registry types.""" + +from __future__ import annotations + +import fnmatch +import shutil +import textwrap +from dataclasses import dataclass +from pathlib import Path +from typing import Protocol + +try: + import tomllib +except ModuleNotFoundError: # pragma: no cover - Python 3.10 fallback + import tomli as tomllib # type: ignore[no-redef] + +import tomli_w + + +_EXCLUDED_DIRS = { + ".git", + ".hg", + ".mypy_cache", + ".pytest_cache", + ".ruff_cache", + ".tox", + ".venv", + "__pycache__", + "build", + "dist", + "node_modules", + "venv", +} + +_EXCLUDED_FILE_SUFFIXES = { + ".key", + ".p12", + ".pfx", + ".pem", + ".pyc", + ".pyo", +} + +_EXCLUDED_FILE_NAMES = { + ".env", + ".netrc", + "credentials.json", + "secrets.json", + "secrets.toml", + "secrets.yaml", + "secrets.yml", +} + +_EXCLUDED_FILE_PATTERNS = { + ".env.*", + "*_secret.*", + "*_secrets.*", + "id_ed25519*", + "id_rsa*", +} + + +def _is_excluded(path: Path) -> bool: + if any(part in _EXCLUDED_DIRS for part in path.parts): + return True + name = path.name + return ( + name in _EXCLUDED_FILE_NAMES + or path.suffix in _EXCLUDED_FILE_SUFFIXES + or any(fnmatch.fnmatch(name, pattern) for pattern in _EXCLUDED_FILE_PATTERNS) + ) + + +def iter_python_files(source: Path) -> list[Path]: + return [ + path + for path in sorted(source.rglob("*.py")) + if not _is_excluded(path.relative_to(source)) + ] + + +def module_path(source: Path, file_path: Path) -> str: + rel = file_path.relative_to(source).with_suffix("") + parts = list(rel.parts) + if parts[-1] == "__init__": + parts = parts[:-1] + return ".".join(parts) + + +def safe_vendor_dir_name(source: Path) -> str: + name = source.name.strip().replace("-", "_") + return name if name.isidentifier() else "source" + + +def copy_source_tree(source: Path, destination: Path) -> None: + if destination.exists(): + raise FileExistsError(f"Vendored source path already exists: {destination}") + + def ignore(_dir: str, names: list[str]) -> set[str]: + return {name for name in names if _is_excluded(Path(name))} + + shutil.copytree(source, destination, ignore=ignore) + + +def ensure_vendor_package(vendor_dir: Path) -> None: + for path in (vendor_dir.parent, vendor_dir): + init_file = path / "__init__.py" + if not init_file.exists(): + init_file.write_text("", encoding="utf-8", newline="\n") + + +def write_text(path: Path, content: str) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(textwrap.dedent(content).lstrip(), encoding="utf-8", newline="\n") + + +def _dependency_name(requirement: str) -> str: + requirement = requirement.strip() + for marker in ("[", "<", ">", "=", "!", "~", ";", " "): + if marker in requirement: + requirement = requirement.split(marker, 1)[0] + return requirement.lower().replace("_", "-") + + +def append_dependency_files( + env_dir: Path, + env_name: str, + dependencies: list[str], +) -> None: + requirements = env_dir / "server" / "requirements.txt" + if requirements.exists(): + content = requirements.read_text(encoding="utf-8") + existing = { + _dependency_name(line) + for line in content.splitlines() + if line.strip() and not line.lstrip().startswith("#") + } + missing = [ + dependency + for dependency in dependencies + if _dependency_name(dependency) not in existing + ] + if missing: + requirements.write_text( + content.rstrip() + "\n" + "\n".join(missing) + "\n", + encoding="utf-8", + newline="\n", + ) + + pyproject = env_dir / "pyproject.toml" + if not pyproject.exists(): + return + + data = tomllib.loads(pyproject.read_text(encoding="utf-8")) + project = data.setdefault("project", {}) + project_dependencies = list(project.get("dependencies") or []) + existing = {_dependency_name(dependency) for dependency in project_dependencies} + for dependency in dependencies: + if _dependency_name(dependency) not in existing: + project_dependencies.append(dependency) + existing.add(_dependency_name(dependency)) + project["dependencies"] = project_dependencies + + tool = data.setdefault("tool", {}) + setuptools = tool.setdefault("setuptools", {}) + package_data = setuptools.setdefault("package-data", {}) + vendor_data = list(package_data.get(env_name) or []) + if "vendor/**/*" not in vendor_data: + vendor_data.append("vendor/**/*") + package_data[env_name] = vendor_data + + pyproject.write_text(tomli_w.dumps(data), encoding="utf-8", newline="\n") + + +@dataclass(frozen=True) +class DetectedEnvironment: + """A source environment class detected without importing user code.""" + + source_type: str + class_name: str + module_path: str + file_path: Path + + @property + def qualified_name(self) -> str: + return f"{self.module_path}:{self.class_name}" + + +class EnvironmentImporter(Protocol): + source_type: str + + def detect(self, source: Path) -> list[DetectedEnvironment]: + """Return environments supported by this importer.""" + ... + + def generate( + self, + *, + source: Path, + destination: Path, + env_name: str, + detected: DetectedEnvironment, + ) -> None: + """Generate an OpenEnv wrapper package.""" + ... + + +class ImporterRegistry: + """Registry of deterministic environment importers.""" + + def __init__(self, importers: list[EnvironmentImporter]): + self._importers = importers + + @property + def supported_types(self) -> list[str]: + return [importer.source_type for importer in self._importers] + + def get(self, source_type: str) -> EnvironmentImporter: + for importer in self._importers: + if importer.source_type == source_type: + return importer + supported = ", ".join(self.supported_types) + raise ValueError( + f"Unsupported source type {source_type!r}. Supported: {supported}" + ) + + def detect( + self, + source: Path, + source_type: str | None = None, + ) -> list[tuple[EnvironmentImporter, DetectedEnvironment]]: + importers = [self.get(source_type)] if source_type else self._importers + matches: list[tuple[EnvironmentImporter, DetectedEnvironment]] = [] + for importer in importers: + for detected in importer.detect(source): + matches.append((importer, detected)) + return matches diff --git a/src/openenv/cli/importers/ors.py b/src/openenv/cli/importers/ors.py new file mode 100644 index 000000000..6cc75e23c --- /dev/null +++ b/src/openenv/cli/importers/ors.py @@ -0,0 +1,546 @@ +"""Open Reward Standard source importer.""" + +from __future__ import annotations + +import ast +from pathlib import Path + +from .base import ( + append_dependency_files, + copy_source_tree, + DetectedEnvironment, + ensure_vendor_package, + iter_python_files, + module_path, + safe_vendor_dir_name, + write_text, +) + + +_ORS_MODULES = { + "ors", + "ors.environment", + "openreward", + "openreward.environment", + "openreward.environments", + "openreward.environments.environment", + "openrewardstandard", + "openrewardstandard.environment", +} + +_ORS_ROOT_DEPENDENCIES = { + "openreward": "openreward", + "openrewardstandard": "openrewardstandard", + "ors": "ors", +} + + +def _dotted_name(node: ast.AST) -> str | None: + if isinstance(node, ast.Name): + return node.id + if isinstance(node, ast.Attribute): + base = _dotted_name(node.value) + if base: + return f"{base}.{node.attr}" + if isinstance(node, ast.Subscript): + return _dotted_name(node.value) + return None + + +def _collect_environment_aliases( + tree: ast.AST, +) -> tuple[set[str], dict[str, str]]: + environment_aliases: set[str] = set() + module_aliases: dict[str, str] = {} + + for node in ast.walk(tree): + if isinstance(node, ast.ImportFrom): + module = node.module or "" + if module in _ORS_MODULES: + for alias in node.names: + if alias.name == "Environment": + environment_aliases.add(alias.asname or alias.name) + elif isinstance(node, ast.Import): + for alias in node.names: + if alias.name in _ORS_MODULES: + module_aliases[alias.asname or alias.name.split(".", 1)[0]] = ( + alias.name + ) + + return environment_aliases, module_aliases + + +def _ors_dependency_roots(tree: ast.AST) -> set[str]: + roots: set[str] = set() + for node in ast.walk(tree): + if isinstance(node, ast.ImportFrom): + module = node.module or "" + root = module.split(".", 1)[0] + if module in _ORS_MODULES and root in _ORS_ROOT_DEPENDENCIES: + roots.add(root) + elif isinstance(node, ast.Import): + for alias in node.names: + root = alias.name.split(".", 1)[0] + if alias.name in _ORS_MODULES and root in _ORS_ROOT_DEPENDENCIES: + roots.add(root) + return roots + + +def _inherits_ors_environment( + base: ast.AST, + environment_aliases: set[str], + module_aliases: dict[str, str], +) -> bool: + dotted = _dotted_name(base) + if dotted is None: + return False + if dotted in environment_aliases: + return True + + for alias, module in module_aliases.items(): + root = module.split(".", 1)[0] + if dotted == f"{alias}.Environment": + return module in _ORS_MODULES or root in _ORS_ROOT_DEPENDENCIES + if dotted == f"{alias}.environment.Environment": + return module in _ORS_MODULES or root in _ORS_ROOT_DEPENDENCIES + if dotted == f"{alias}.environments.Environment": + return module in _ORS_MODULES or root in _ORS_ROOT_DEPENDENCIES + if dotted == f"{alias}.Environment" and ( + module.endswith(".environment") or module.endswith(".environments") + ): + return True + return False + + +def detect_ors_environments(source: Path) -> list[DetectedEnvironment]: + """Detect ORS/OpenReward environment classes without importing source files.""" + source = source.resolve() + matches: list[DetectedEnvironment] = [] + + for file_path in iter_python_files(source): + try: + tree = ast.parse(file_path.read_text(encoding="utf-8")) + except (SyntaxError, UnicodeDecodeError): + continue + + environment_aliases, module_aliases = _collect_environment_aliases(tree) + if not environment_aliases and not module_aliases: + continue + + for node in tree.body: + if not isinstance(node, ast.ClassDef): + continue + if any( + _inherits_ors_environment(base, environment_aliases, module_aliases) + for base in node.bases + ): + matches.append( + DetectedEnvironment( + source_type="ors", + class_name=node.name, + module_path=module_path(source, file_path), + file_path=file_path, + ) + ) + + return matches + + +def detect_ors_dependencies(source: Path) -> list[str]: + roots: set[str] = set() + for file_path in iter_python_files(source.resolve()): + try: + tree = ast.parse(file_path.read_text(encoding="utf-8")) + except (SyntaxError, UnicodeDecodeError): + continue + roots.update(_ors_dependency_roots(tree)) + + dependencies = [] + for root in sorted(roots): + if (source / root).exists() or (source / f"{root}.py").exists(): + continue + dependencies.append(_ORS_ROOT_DEPENDENCIES[root]) + return dependencies + + +def _wrapper_source( + *, + env_name: str, + class_name_prefix: str, + source_module: str, + source_class: str, + vendor_dir: str, +) -> str: + source_import_module = f"{env_name}.vendor.{vendor_dir}" + if source_module: + source_import_module = f"{source_import_module}.{source_module}" + return f''' + from __future__ import annotations + + import asyncio + import contextlib + import inspect + import sys + import threading + from importlib import import_module + from pathlib import Path + from typing import Any + from uuid import uuid4 + + from openenv.core.env_server.interfaces import Environment + from openenv.core.env_server.mcp_types import ( + CallToolAction, + CallToolObservation, + ListToolsAction, + ListToolsObservation, + Tool, + ToolError, + ToolErrorType, + ) + from openenv.core.env_server.types import Observation, State + + + _VENDORED_SOURCE_ROOT = Path(__file__).resolve().parents[1] / "vendor" / "{vendor_dir}" + _SOURCE_MODULE = "{source_import_module}" + + + @contextlib.contextmanager + def _vendored_source_path(): + source_path = str(_VENDORED_SOURCE_ROOT) + inserted = source_path not in sys.path + if inserted: + sys.path.insert(0, source_path) + try: + yield + finally: + if inserted: + try: + sys.path.remove(source_path) + except ValueError: + pass + + + with _vendored_source_path(): + _ORIGINAL_ENV_CLASS = getattr(import_module(_SOURCE_MODULE), "{source_class}") + + + def _run_sync(value: Any) -> Any: + if not inspect.isawaitable(value): + return value + try: + loop = asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(value) + if not loop.is_running(): + return loop.run_until_complete(value) + + result: dict[str, Any] = {{}} + + def runner() -> None: + try: + result["value"] = asyncio.run(value) + except BaseException as exc: + result["error"] = exc + + thread = threading.Thread(target=runner, daemon=True) + thread.start() + thread.join() + if "error" in result: + raise result["error"] + return result.get("value") + + + def _call_vendored(func: Any, *args: Any, **kwargs: Any) -> Any: + with _vendored_source_path(): + return _run_sync(func(*args, **kwargs)) + + + def _dump(value: Any) -> Any: + if value is None or isinstance(value, (str, int, float, bool)): + return value + if isinstance(value, list): + return [_dump(item) for item in value] + if isinstance(value, tuple): + return [_dump(item) for item in value] + if isinstance(value, dict): + return {{str(key): _dump(item) for key, item in value.items()}} + if hasattr(value, "model_dump"): + return _dump(value.model_dump()) + if hasattr(value, "__dict__"): + return _dump(value.__dict__) + return str(value) + + + def _normalize_split(split: Any) -> dict[str, Any]: + value = _dump(split) + if isinstance(value, dict): + return value + split_name = str(value) + split_type = split_name if split_name in {{"train", "validation", "test"}} else "validation" + return {{"name": split_name, "type": split_type}} + + + def _tool_from_ors(tool: Any) -> Tool: + value = _dump(tool) + input_schema = value.get("input_schema") or value.get("inputSchema") + if input_schema is None: + input_schema = {{"type": "object", "properties": {{}}}} + return Tool( + name=value["name"], + description=value.get("description") or "", + input_schema=input_schema, + ) + + + class {class_name_prefix}Environment(Environment): + """OpenEnv wrapper around a vendored ORS/OpenReward environment.""" + + SUPPORTS_CONCURRENT_SESSIONS = False + + def __init__(self): + self._ors_cls = _ORIGINAL_ENV_CLASS + self._ors_env: Any | None = None + self._state = State(episode_id=str(uuid4()), step_count=0) + self._task_spec: Any | None = None + self._last_reward: float | None = None + self._done = False + + def list_splits(self) -> list[dict[str, Any]]: + splits = _call_vendored(self._ors_cls.list_splits) + return [_normalize_split(split) for split in splits] + + def list_tasks(self, split: str) -> list[Any]: + return _dump(_call_vendored(self._ors_cls.list_tasks, split)) + + def num_tasks(self, split: str) -> int: + return int(_call_vendored(self._ors_cls.num_tasks, split)) + + def get_task(self, split: str, index: int) -> Any: + return _dump(_call_vendored(self._ors_cls.get_task, split, index)) + + def get_task_range( + self, + split: str, + start: int | None = None, + stop: int | None = None, + ) -> list[Any]: + return _dump(_call_vendored(self._ors_cls.get_task_range, split, start, stop)) + + def _first_task(self) -> tuple[str, int, Any]: + splits = self.list_splits() + if not splits: + raise RuntimeError("ORS environment has no splits") + split = splits[0]["name"] + return split, 0, self.get_task(split, 0) + + def reset( + self, + seed: int | None = None, + episode_id: str | None = None, + task_spec: dict[str, Any] | None = None, + split: str | None = None, + index: int | None = None, + secrets: dict[str, str] | None = None, + **kwargs: Any, + ) -> Observation: + self.close() + if task_spec is None: + if split is None and index is None: + split, index, task_spec = self._first_task() + elif split is None or index is None: + raise ValueError("split and index must be provided together") + else: + task_spec = self.get_task(split, index) + + self._task_spec = _dump(task_spec) + self._ors_env = _call_vendored( + self._ors_cls, + task_spec=task_spec, + secrets=secrets or {{}}, + ) + _call_vendored(self._ors_env.setup) + prompt = _dump(_call_vendored(self._ors_env.get_prompt)) + self._last_reward = None + self._done = False + self._state = State( + episode_id=episode_id or str(uuid4()), + step_count=0, + source_type="ors", + original_env_class="{source_class}", + task_spec=self._task_spec, + split=split, + index=index, + ) + return Observation( + done=False, + reward=None, + metadata={{ + "source_type": "ors", + "original_env_class": "{source_class}", + "task_spec": self._task_spec, + "prompt": prompt, + }}, + ) + + def _ensure_session(self) -> None: + if self._ors_env is None: + raise RuntimeError("Call reset() before invoking ORS tools") + + def _all_tools(self) -> list[Tool]: + shared = _call_vendored(self._ors_cls.list_tools) + tools = [_tool_from_ors(tool) for tool in getattr(shared, "tools", [])] + if self._ors_env is not None: + task_tools = _call_vendored(self._ors_env.list_task_tools) + tools.extend(_tool_from_ors(tool) for tool in getattr(task_tools, "tools", [])) + return tools + + def step( + self, + action: Any, + timeout_s: float | None = None, + **kwargs: Any, + ) -> Observation: + if isinstance(action, ListToolsAction): + return ListToolsObservation(tools=self._all_tools()) + if not isinstance(action, CallToolAction): + raise TypeError(f"Unsupported action type: {{type(action).__name__}}") + + self._ensure_session() + assert self._ors_env is not None + result = _call_vendored(self._ors_env._call_tool, action.tool_name, action.arguments) + root = getattr(result, "root", result) + ok = getattr(root, "ok", False) + self._state.step_count += 1 + + if ok: + output = root.output + blocks = _dump(getattr(output, "blocks", [])) + metadata = _dump(getattr(output, "metadata", None)) or {{}} + reward = getattr(output, "reward", None) + done = bool(getattr(output, "finished", False)) + self._last_reward = reward + self._done = done + return CallToolObservation( + tool_name=action.tool_name, + result={{"blocks": blocks, "metadata": metadata}}, + reward=reward, + done=done, + metadata=metadata, + ) + + message = str(getattr(root, "error", "ORS tool call failed")) + return CallToolObservation( + tool_name=action.tool_name, + result=None, + error=ToolError( + error_type=ToolErrorType.EXECUTION_ERROR, + message=message, + ), + reward=None, + done=False, + ) + + @property + def state(self) -> State: + return self._state + + def close(self) -> None: + if self._ors_env is None: + return + try: + _call_vendored(self._ors_env.teardown) + finally: + self._ors_env = None + ''' + + +def _app_source(*, env_name: str, class_name_prefix: str) -> str: + return f''' + from __future__ import annotations + + from openenv.core.env_server.http_server import create_app + from openenv.core.env_server.mcp_types import CallToolAction, CallToolObservation + + from .{env_name}_environment import {class_name_prefix}Environment + + + app = create_app( + {class_name_prefix}Environment, + CallToolAction, + CallToolObservation, + env_name="{env_name}", + max_concurrent_envs=1, + ) + + + def main(host: str = "0.0.0.0", port: int = 8000): + import uvicorn + + uvicorn.run(app, host=host, port=port) + + + if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--port", type=int, default=8000) + args = parser.parse_args() + main(port=args.port) + ''' + + +class ORSImporter: + """Importer for source repos that define ORS/OpenReward environments.""" + + source_type = "ors" + + def detect(self, source: Path) -> list[DetectedEnvironment]: + return detect_ors_environments(source) + + def generate( + self, + *, + source: Path, + destination: Path, + env_name: str, + detected: DetectedEnvironment, + ) -> None: + from openenv.cli.commands.init import ( + _copy_template_directory, + _create_template_replacements, + ) + + replacements = _create_template_replacements(env_name) + _copy_template_directory( + "openenv.cli.templates.openenv_env", + "", + destination, + replacements, + env_name, + ) + + vendor_dir = safe_vendor_dir_name(source) + vendor_path = destination / "vendor" / vendor_dir + copy_source_tree(source, vendor_path) + ensure_vendor_package(vendor_path) + + prefix = replacements["__ENV_CLASS_NAME__"] + write_text( + destination / "server" / f"{env_name}_environment.py", + _wrapper_source( + env_name=env_name, + class_name_prefix=prefix, + source_module=detected.module_path, + source_class=detected.class_name, + vendor_dir=vendor_dir, + ), + ) + write_text( + destination / "server" / "app.py", + _app_source(env_name=env_name, class_name_prefix=prefix), + ) + append_dependency_files( + destination, + env_name, + detect_ors_dependencies(source), + ) diff --git a/src/openenv/cli/importers/verifiers.py b/src/openenv/cli/importers/verifiers.py new file mode 100644 index 000000000..e47b20404 --- /dev/null +++ b/src/openenv/cli/importers/verifiers.py @@ -0,0 +1,562 @@ +"""Prime Intellect Verifiers source importer.""" + +from __future__ import annotations + +import ast +from pathlib import Path + +from .base import ( + append_dependency_files, + copy_source_tree, + DetectedEnvironment, + ensure_vendor_package, + iter_python_files, + module_path, + safe_vendor_dir_name, + write_text, +) + + +_VERIFIERS_MODULES = { + "verifiers", + "verifiers.envs.environment", + "verifiers.v1", +} + + +def _imports_verifiers(tree: ast.AST) -> bool: + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + if alias.name == "verifiers" or alias.name.startswith("verifiers."): + return True + elif isinstance(node, ast.ImportFrom): + module = node.module or "" + if module in _VERIFIERS_MODULES or module.startswith("verifiers."): + return True + return False + + +def detect_verifiers_environments(source: Path) -> list[DetectedEnvironment]: + """Detect Verifiers load_environment entrypoints without importing source.""" + source = source.resolve() + matches: list[DetectedEnvironment] = [] + + for file_path in iter_python_files(source): + try: + tree = ast.parse(file_path.read_text(encoding="utf-8")) + except (SyntaxError, UnicodeDecodeError): + continue + + if not _imports_verifiers(tree): + continue + + for node in tree.body: + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + if node.name == "load_environment": + matches.append( + DetectedEnvironment( + source_type="verifiers", + class_name=node.name, + module_path=module_path(source, file_path), + file_path=file_path, + ) + ) + + return matches + + +def _wrapper_source( + *, + env_name: str, + class_name_prefix: str, + source_module: str, + vendor_dir: str, +) -> str: + source_import_module = f"{env_name}.vendor.{vendor_dir}" + if source_module: + source_import_module = f"{source_import_module}.{source_module}" + return f''' + from __future__ import annotations + + import asyncio + import contextlib + import inspect + import sys + import threading + from importlib import import_module + from pathlib import Path + from typing import Any + from uuid import uuid4 + + from openenv.core.env_server.interfaces import Environment + from openenv.core.env_server.mcp_types import ( + CallToolAction, + CallToolObservation, + ListToolsAction, + ListToolsObservation, + Tool, + ToolError, + ToolErrorType, + ) + from openenv.core.env_server.types import Observation, State + + + _VENDORED_SOURCE_ROOT = Path(__file__).resolve().parents[1] / "vendor" / "{vendor_dir}" + _SOURCE_MODULE = "{source_import_module}" + + + @contextlib.contextmanager + def _vendored_source_path(): + source_path = str(_VENDORED_SOURCE_ROOT) + inserted = source_path not in sys.path + if inserted: + sys.path.insert(0, source_path) + try: + yield + finally: + if inserted: + try: + sys.path.remove(source_path) + except ValueError: + pass + + + with _vendored_source_path(): + _LOAD_ENVIRONMENT = getattr(import_module(_SOURCE_MODULE), "load_environment") + + + def _run_sync(value: Any) -> Any: + if not inspect.isawaitable(value): + return value + try: + loop = asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(value) + if not loop.is_running(): + return loop.run_until_complete(value) + + result: dict[str, Any] = {{}} + + def runner() -> None: + try: + result["value"] = asyncio.run(value) + except BaseException as exc: + result["error"] = exc + + thread = threading.Thread(target=runner, daemon=True) + thread.start() + thread.join() + if "error" in result: + raise result["error"] + return result.get("value") + + + def _call_vendored(func: Any, *args: Any, **kwargs: Any) -> Any: + with _vendored_source_path(): + return _run_sync(func(*args, **kwargs)) + + + def _dump(value: Any) -> Any: + if value is None or isinstance(value, (str, int, float, bool)): + return value + if isinstance(value, list): + return [_dump(item) for item in value] + if isinstance(value, tuple): + return [_dump(item) for item in value] + if isinstance(value, dict): + return {{str(key): _dump(item) for key, item in value.items()}} + if hasattr(value, "model_dump"): + return _dump(value.model_dump()) + if hasattr(value, "__dict__"): + return _dump(value.__dict__) + return str(value) + + + def _coerce_prompt(task: dict[str, Any]) -> list[dict[str, Any]]: + prompt = task.get("prompt") + if isinstance(prompt, list): + return _dump(prompt) + if isinstance(prompt, str): + return [{{"role": "user", "content": prompt}}] + question = task.get("question") + if question is not None: + return [{{"role": "user", "content": str(question)}}] + return [] + + + def _completion_messages(arguments: dict[str, Any]) -> list[dict[str, Any]]: + messages = arguments.get("messages") + if isinstance(messages, list): + return _dump(messages) + completion = arguments.get("completion", arguments.get("answer", "")) + return [{{"role": "assistant", "content": str(completion)}}] + + + def _load_environment() -> Any: + sig = inspect.signature(_LOAD_ENVIRONMENT) + kwargs: dict[str, Any] = {{}} + config_param = sig.parameters.get("config") + if config_param is not None and config_param.default is inspect.Parameter.empty: + annotation = config_param.annotation + if isinstance(annotation, type): + kwargs["config"] = annotation() + return _call_vendored(_LOAD_ENVIRONMENT, **kwargs) + + + def _dataset_to_tasks(dataset: Any) -> list[dict[str, Any]]: + if dataset is None: + return [] + rows: list[dict[str, Any]] = [] + for index in range(len(dataset)): + row = _dump(dataset[index]) + if isinstance(row, dict): + row.setdefault("example_id", index) + rows.append(row) + return rows + + + class {class_name_prefix}Environment(Environment): + """OpenEnv wrapper around a vendored Prime Intellect Verifiers environment.""" + + SUPPORTS_CONCURRENT_SESSIONS = False + + def __init__(self): + self._vf_env: Any | None = None + self._state = State(episode_id=str(uuid4()), step_count=0) + self._task_spec: dict[str, Any] | None = None + self._prompt: list[dict[str, Any]] = [] + self._last_reward: float | None = None + self._done = False + + def _env(self) -> Any: + if self._vf_env is None: + self._vf_env = _load_environment() + return self._vf_env + + def _rows_for_split(self, split: str) -> list[dict[str, Any]]: + env = self._env() + taskset = getattr(env, "taskset", None) + if taskset is not None: + if split in {{"eval", "validation", "test"}} and hasattr(taskset, "eval_rows"): + rows = _dump(_call_vendored(taskset.eval_rows)) + elif hasattr(taskset, "rows"): + rows = _dump(_call_vendored(taskset.rows)) + else: + rows = [] + tasks = [] + for index, row in enumerate(rows): + if not isinstance(row, dict): + continue + row_split = row.get("split") + if row_split is not None and split not in {{"eval", "validation", "test"}} and row_split != split: + continue + task = _call_vendored(taskset.task, row) if hasattr(taskset, "task") else row + dumped = _dump(task) + if isinstance(dumped, dict): + dumped.setdefault("example_id", index) + tasks.append(dumped) + return tasks + + if split in {{"eval", "validation", "test"}} and hasattr(env, "get_eval_dataset"): + return _dataset_to_tasks(_call_vendored(env.get_eval_dataset)) + if hasattr(env, "get_dataset"): + return _dataset_to_tasks(_call_vendored(env.get_dataset)) + return [] + + def list_splits(self) -> list[dict[str, Any]]: + env = self._env() + taskset = getattr(env, "taskset", None) + names: list[str] = [] + if taskset is not None and hasattr(taskset, "rows"): + for row in _dump(_call_vendored(taskset.rows)): + if isinstance(row, dict) and row.get("split"): + names.append(str(row["split"])) + if hasattr(taskset, "eval_rows"): + try: + if len(_dump(_call_vendored(taskset.eval_rows))) > 0: + names.append("eval") + except Exception: + pass + else: + if hasattr(env, "get_dataset"): + names.append("train") + if hasattr(env, "get_eval_dataset"): + names.append("eval") + if not names: + names.append("train") + + seen = set() + splits = [] + for name in names: + if name in seen: + continue + seen.add(name) + split_type = name if name in {{"train", "validation", "test"}} else "validation" + splits.append({{"name": name, "type": split_type}}) + return splits + + def list_tasks(self, split: str) -> list[dict[str, Any]]: + return self._rows_for_split(split) + + def num_tasks(self, split: str) -> int: + return len(self.list_tasks(split)) + + def get_task(self, split: str, index: int) -> dict[str, Any]: + return self.list_tasks(split)[index] + + def get_task_range( + self, + split: str, + start: int | None = None, + stop: int | None = None, + ) -> list[dict[str, Any]]: + return self.list_tasks(split)[slice(start, stop)] + + def _first_task(self) -> tuple[str, int, dict[str, Any]]: + splits = self.list_splits() + if not splits: + raise RuntimeError("Verifiers environment has no splits") + split = splits[0]["name"] + return split, 0, self.get_task(split, 0) + + def reset( + self, + seed: int | None = None, + episode_id: str | None = None, + task_spec: dict[str, Any] | None = None, + split: str | None = None, + index: int | None = None, + **kwargs: Any, + ) -> Observation: + if task_spec is None: + if split is None and index is None: + split, index, task_spec = self._first_task() + elif split is None or index is None: + raise ValueError("split and index must be provided together") + else: + task_spec = self.get_task(split, index) + + self._task_spec = _dump(task_spec) + self._prompt = _coerce_prompt(self._task_spec) + self._last_reward = None + self._done = False + self._state = State( + episode_id=episode_id or str(uuid4()), + step_count=0, + source_type="verifiers", + task_spec=self._task_spec, + split=split, + index=index, + ) + return Observation( + done=False, + reward=None, + metadata={{ + "source_type": "verifiers", + "task_spec": self._task_spec, + "prompt": self._prompt, + }}, + ) + + def _ensure_session(self) -> None: + if self._task_spec is None: + raise RuntimeError("Call reset() before submitting Verifiers completions") + + def step( + self, + action: Any, + timeout_s: float | None = None, + **kwargs: Any, + ) -> Observation: + if isinstance(action, ListToolsAction): + return ListToolsObservation( + tools=[ + Tool( + name="submit", + description="Submit a completion to score with the Verifiers environment.", + input_schema={{ + "type": "object", + "properties": {{ + "completion": {{"type": "string"}}, + "messages": {{"type": "array"}}, + }}, + }}, + ) + ] + ) + if not isinstance(action, CallToolAction): + raise TypeError(f"Unsupported action type: {{type(action).__name__}}") + if action.tool_name != "submit": + return CallToolObservation( + tool_name=action.tool_name, + result=None, + error=ToolError( + error_type=ToolErrorType.TOOL_NOT_FOUND, + message=f"Unknown Verifiers wrapper tool: {{action.tool_name}}", + ), + done=False, + reward=None, + ) + + self._ensure_session() + assert self._task_spec is not None + completion = _completion_messages(action.arguments) + score = self._score_completion(completion) + self._state.step_count += 1 + self._last_reward = score.get("reward") + self._done = True + return CallToolObservation( + tool_name=action.tool_name, + result=score, + reward=score.get("reward"), + done=True, + metadata={{"metrics": score.get("metrics", {{}})}}, + ) + + def _score_completion(self, completion: list[dict[str, Any]]) -> dict[str, Any]: + env = self._env() + assert self._task_spec is not None + state: dict[str, Any] = {{ + "input": dict(self._task_spec), + "task": dict(self._task_spec), + "prompt": self._prompt, + "completion": completion, + "answer": self._task_spec.get("answer", ""), + "info": self._task_spec.get("info", {{}}), + "trajectory": [], + "reward": None, + "metrics": None, + "is_completed": True, + "is_truncated": False, + }} + + taskset = getattr(env, "taskset", None) + harness = getattr(env, "harness", None) + if taskset is not None and harness is not None: + task = _call_vendored(taskset.to_task, self._task_spec) if hasattr(taskset, "to_task") else self._task_spec + state["task"] = _dump(task) + with _vendored_source_path(): + maybe_state_cls = getattr(import_module("verifiers"), "State", None) + if maybe_state_cls is not None and hasattr(maybe_state_cls, "for_task"): + try: + vf_state = maybe_state_cls.for_task(task) + vf_state.update(state) + state = vf_state + except Exception: + pass + if hasattr(harness, "score_group"): + _call_vendored(harness.score_group, [task], [state]) + + elif hasattr(env, "rubric") and hasattr(env.rubric, "score_rollout"): + _call_vendored(env.rubric.score_rollout, state) + + reward = state.get("reward") + return {{ + "completion": completion, + "reward": reward, + "metrics": _dump(state.get("metrics") or {{}}), + "state": _dump(state), + }} + + @property + def state(self) -> State: + return self._state + + def close(self) -> None: + if self._vf_env is None: + return + try: + teardown = getattr(self._vf_env, "_teardown", None) + if callable(teardown): + _call_vendored(teardown) + finally: + self._vf_env = None + ''' + + +def _app_source(*, env_name: str, class_name_prefix: str) -> str: + return f''' + from __future__ import annotations + + from openenv.core.env_server.http_server import create_app + from openenv.core.env_server.mcp_types import CallToolAction, CallToolObservation + + from .{env_name}_environment import {class_name_prefix}Environment + + + app = create_app( + {class_name_prefix}Environment, + CallToolAction, + CallToolObservation, + env_name="{env_name}", + max_concurrent_envs=1, + ) + + + def main(host: str = "0.0.0.0", port: int = 8000): + import uvicorn + + uvicorn.run(app, host=host, port=port) + + + if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--port", type=int, default=8000) + args = parser.parse_args() + main(port=args.port) + ''' + + +class VerifiersImporter: + """Importer for Prime Intellect Verifiers environment modules.""" + + source_type = "verifiers" + + def detect(self, source: Path) -> list[DetectedEnvironment]: + return detect_verifiers_environments(source) + + def generate( + self, + *, + source: Path, + destination: Path, + env_name: str, + detected: DetectedEnvironment, + ) -> None: + from openenv.cli.commands.init import ( + _copy_template_directory, + _create_template_replacements, + ) + + replacements = _create_template_replacements(env_name) + _copy_template_directory( + "openenv.cli.templates.openenv_env", + "", + destination, + replacements, + env_name, + ) + + vendor_dir = safe_vendor_dir_name(source) + vendor_path = destination / "vendor" / vendor_dir + copy_source_tree(source, vendor_path) + ensure_vendor_package(vendor_path) + + prefix = replacements["__ENV_CLASS_NAME__"] + write_text( + destination / "server" / f"{env_name}_environment.py", + _wrapper_source( + env_name=env_name, + class_name_prefix=prefix, + source_module=detected.module_path, + vendor_dir=vendor_dir, + ), + ) + write_text( + destination / "server" / "app.py", + _app_source(env_name=env_name, class_name_prefix=prefix), + ) + append_dependency_files(destination, env_name, ["verifiers>=0.1.14"]) diff --git a/src/openenv/core/env_server/__init__.py b/src/openenv/core/env_server/__init__.py index 2c0f1f284..a46ce9970 100644 --- a/src/openenv/core/env_server/__init__.py +++ b/src/openenv/core/env_server/__init__.py @@ -16,7 +16,7 @@ SessionNotFoundError, ) from .http_server import create_app, create_fastapi_app, HTTPEnvServer -from .interfaces import Environment, Message, ModelTokenizer, Transform +from .interfaces import Environment, Message, ModelTokenizer, TaskProvider, Transform try: from .mcp_environment import MCPEnvironment @@ -51,8 +51,12 @@ Action, BaseMessage, ConcurrencyConfig, + GetTaskRangeRequest, + GetTaskRequest, HealthResponse, HealthStatus, + ListTasksRequest, + NumTasksRequest, Observation, SchemaResponse, ServerCapacityStatus, @@ -80,6 +84,7 @@ # Core interfaces "Environment", "Transform", + "TaskProvider", "Message", "ModelTokenizer", # Types @@ -88,6 +93,10 @@ "State", "SchemaResponse", "HealthResponse", + "ListTasksRequest", + "NumTasksRequest", + "GetTaskRequest", + "GetTaskRangeRequest", # Enums "HealthStatus", "ServerMode", diff --git a/src/openenv/core/env_server/http_server.py b/src/openenv/core/env_server/http_server.py index b69b93d31..8710858f8 100644 --- a/src/openenv/core/env_server/http_server.py +++ b/src/openenv/core/env_server/http_server.py @@ -40,9 +40,13 @@ from .interfaces import Environment from .mcp_environment import get_server_tools from .mcp_types import ( + CallToolAction, + CallToolObservation, JsonRpcErrorCode, JsonRpcRequest, JsonRpcResponse, + ListToolsAction, + ListToolsObservation, McpMethod, WSMCPMessage, WSMCPResponse, @@ -53,8 +57,12 @@ Action, ConcurrencyConfig, EnvironmentMetadata, + GetTaskRangeRequest, + GetTaskRequest, HealthResponse, HealthStatus, + ListTasksRequest, + NumTasksRequest, Observation, ResetRequest, ResetResponse, @@ -106,6 +114,18 @@ def _make_json_serializable(obj: Any) -> Any: return str(obj) +async def _maybe_await(value: Any) -> Any: + """Await values returned by async task APIs while preserving sync APIs.""" + if inspect.isawaitable(value): + return await value + return value + + +def _overrides_method(method: Any, base_method: Any) -> bool: + """Return whether a bound method differs from the base implementation.""" + return getattr(method, "__func__", method) is not base_method + + from .exceptions import ( ConcurrencyConfigurationError, EnvironmentFactoryError, @@ -150,6 +170,7 @@ def __init__( observation_cls: Type[Observation], max_concurrent_envs: Optional[int] = None, concurrency_config: Optional[ConcurrencyConfig] = None, + env_name: Optional[str] = None, ): """ Initialize HTTP server wrapper. @@ -163,6 +184,7 @@ def __init__( Mutually exclusive with concurrency_config. concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. Mutually exclusive with max_concurrent_envs. + env_name: Public environment name used by task/split endpoints. Raises: ValueError: If both max_concurrent_envs and concurrency_config are provided. @@ -206,6 +228,7 @@ def __init__( self.action_cls = action_cls self.observation_cls = observation_cls + self.env_name = env_name or self._default_env_name() # Session management for WebSocket connections self._sessions: Dict[str, Optional[Environment]] = {} @@ -231,6 +254,12 @@ def __init__( ) self._reaper_task: Optional[asyncio.Task[None]] = None + def _default_env_name(self) -> str: + factory = self._env_factory + if inspect.isclass(factory): + return factory.__name__ + return getattr(factory, "__name__", "environment") + def _validate_concurrency_safety(self) -> None: """ Validate that the environment supports the configured concurrency level. @@ -610,7 +639,7 @@ async def reset_handler( try: kwargs = request.model_dump(exclude_unset=True) - is_async = _env.reset_async.__func__ is not Environment.reset_async + is_async = _overrides_method(_env.reset_async, Environment.reset_async) if is_async: sig = inspect.signature(_env.reset_async) @@ -645,7 +674,7 @@ async def step_handler(request: StepRequest) -> StepResponse: try: kwargs = request.model_dump(exclude_unset=True, exclude={"action"}) - is_async = _env.step_async.__func__ is not Environment.step_async + is_async = _overrides_method(_env.step_async, Environment.step_async) if is_async: sig = inspect.signature(_env.step_async) @@ -812,9 +841,40 @@ async def mcp_handler( mcp_server = getattr(_env, "mcp_server", None) mcp_session_factory = getattr(_env, "mcp_session", None) + async def call_mcp_style_step(action: Action) -> Observation: + is_async = _overrides_method( + _env.step_async, Environment.step_async + ) + if is_async: + return await _env.step_async(action) + if managed_session_id: + return await self._run_in_session_executor( + managed_session_id, + _env.step, + action, + ) + return await self._run_sync_in_thread_pool(_env.step, action) + + supports_mcp_style_actions = self.action_cls in { + CallToolAction, + ListToolsAction, + } + if method == McpMethod.TOOLS_LIST: # Check if environment is MCP-enabled if mcp_client is None and mcp_server is None: + if supports_mcp_style_actions: + observation = await call_mcp_style_step(ListToolsAction()) + if isinstance(observation, ListToolsObservation): + return JsonRpcResponse.success( + result={ + "tools": [ + tool.model_dump() + for tool in observation.tools + ] + }, + request_id=request_id, + ) return JsonRpcResponse.error_response( JsonRpcErrorCode.INTERNAL_ERROR, "Environment does not support MCP", @@ -875,17 +935,29 @@ async def mcp_handler( tool_name = params.get("name") arguments = params.get("arguments", {}) - if mcp_client is None and mcp_server is None: + if not tool_name: return JsonRpcResponse.error_response( - JsonRpcErrorCode.INTERNAL_ERROR, - "Environment does not support MCP", + JsonRpcErrorCode.INVALID_PARAMS, + "Missing 'name' in params", request_id=request_id, ) - if not tool_name: + if mcp_client is None and mcp_server is None: + if supports_mcp_style_actions: + observation = await call_mcp_style_step( + CallToolAction( + tool_name=tool_name, + arguments=arguments, + ) + ) + if isinstance(observation, CallToolObservation): + return JsonRpcResponse.success( + result=_make_json_serializable(observation), + request_id=request_id, + ) return JsonRpcResponse.error_response( - JsonRpcErrorCode.INVALID_PARAMS, - "Missing 'name' in params", + JsonRpcErrorCode.INTERNAL_ERROR, + "Environment does not support MCP", request_id=request_id, ) @@ -961,6 +1033,96 @@ async def mcp_handler( if should_close: _env.close() + def _check_env_name(env_name: str) -> None: + if env_name.lower() != self.env_name.lower(): + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"unknown environment {env_name!r}", + ) + + def _normalize_split(split: Any) -> Dict[str, Any]: + if hasattr(split, "model_dump"): + return cast(Dict[str, Any], split.model_dump()) + if isinstance(split, dict): + return _make_json_serializable(split) + if isinstance(split, str): + split_type = ( + split if split in {"train", "validation", "test"} else "validation" + ) + return {"name": split, "type": split_type} + return {"name": str(split), "type": "validation"} + + async def _call_task_method(method_name: str, *args: Any) -> Any: + _env = self._env_factory() + try: + method = getattr(_env, method_name, None) + if not callable(method): + raise NotImplementedError + return await _maybe_await(method(*args)) + except NotImplementedError as e: + raise HTTPException( + status_code=status.HTTP_501_NOT_IMPLEMENTED, + detail=f"{method_name} is not supported for this environment", + ) from e + except IndexError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid task index", + ) from e + finally: + _env.close() + + @app.get("/list_environments", tags=["Task API"]) + async def list_environments() -> list[str]: + return [self.env_name] + + @app.get("/{env_name}/splits", tags=["Task API"]) + async def list_splits(env_name: str) -> list[Dict[str, Any]]: + _check_env_name(env_name) + splits = await _call_task_method("list_splits") + return [_normalize_split(split) for split in splits] + + @app.post("/{env_name}/tasks", tags=["Task API"]) + async def list_tasks( + env_name: str, + request: ListTasksRequest, + ) -> Dict[str, Any]: + _check_env_name(env_name) + tasks = await _call_task_method("list_tasks", request.split) + return { + "tasks": _make_json_serializable(tasks), + "env_name": self.env_name, + } + + @app.post("/{env_name}/num_tasks", tags=["Task API"]) + async def num_tasks( + env_name: str, + request: NumTasksRequest, + ) -> Dict[str, Any]: + _check_env_name(env_name) + count = await _call_task_method("num_tasks", request.split) + return {"num_tasks": count} + + @app.post("/{env_name}/task", tags=["Task API"]) + async def get_task( + env_name: str, + request: GetTaskRequest, + ) -> Dict[str, Any]: + _check_env_name(env_name) + task = await _call_task_method("get_task", request.split, request.index) + return {"task": _make_json_serializable(task)} + + @app.post("/{env_name}/task_range", tags=["Task API"]) + async def get_task_range( + env_name: str, + request: GetTaskRangeRequest, + ) -> Dict[str, Any]: + _check_env_name(env_name) + tasks = await _call_task_method( + "get_task_range", request.split, request.start, request.stop + ) + return {"tasks": _make_json_serializable(tasks)} + # Register MCP WebSocket endpoint (available in both production and simulation modes) @app.websocket("/mcp") async def mcp_websocket_endpoint(websocket: WebSocket): @@ -1354,9 +1516,9 @@ async def websocket_endpoint(websocket: WebSocket): case "reset": msg = WSResetMessage(**message_dict) - is_async = ( - session_env.reset_async.__func__ - is not Environment.reset_async + is_async = _overrides_method( + session_env.reset_async, + Environment.reset_async, ) if is_async: @@ -1392,9 +1554,9 @@ async def websocket_endpoint(websocket: WebSocket): msg.data, self.action_cls ) - is_async = ( - session_env.step_async.__func__ - is not Environment.step_async + is_async = _overrides_method( + session_env.step_async, + Environment.step_async, ) if is_async: @@ -1586,7 +1748,12 @@ def create_app( else: # Use standard FastAPI app without web interface return create_fastapi_app( - env, action_cls, observation_cls, max_concurrent_envs, concurrency_config + env, + action_cls, + observation_cls, + max_concurrent_envs, + concurrency_config, + env_name=env_name, ) @@ -1596,6 +1763,7 @@ def create_fastapi_app( observation_cls: Type[Observation], max_concurrent_envs: Optional[int] = None, concurrency_config: Optional[ConcurrencyConfig] = None, + env_name: Optional[str] = None, ) -> FastAPI: """ Create a FastAPI application with comprehensive documentation. @@ -1608,6 +1776,7 @@ def create_fastapi_app( Mutually exclusive with concurrency_config. concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. Mutually exclusive with max_concurrent_envs. + env_name: Optional environment name for task/split endpoints. Returns: FastAPI application instance @@ -1685,6 +1854,7 @@ def create_fastapi_app( observation_cls, max_concurrent_envs, concurrency_config=concurrency_config, + env_name=env_name, ) server.register_routes(app) return app diff --git a/src/openenv/core/env_server/interfaces.py b/src/openenv/core/env_server/interfaces.py index 4c77266d9..83de41c66 100644 --- a/src/openenv/core/env_server/interfaces.py +++ b/src/openenv/core/env_server/interfaces.py @@ -74,6 +74,40 @@ def decode( ... +class TaskProvider(Protocol): + """Optional task discovery API for dataset-backed environments. + + Task provider methods are for metadata/discovery only and should be + side-effect-free. The HTTP compatibility routes may call them on a + short-lived environment instance. + """ + + def list_splits(self) -> list[Any]: + """Return task split descriptors supported by this environment.""" + ... + + def list_tasks(self, split: str) -> list[Any]: + """Return all task specs for a split.""" + ... + + def num_tasks(self, split: str) -> int: + """Return the number of task specs in a split.""" + ... + + def get_task(self, split: str, index: int) -> Any: + """Return one task spec by split and index.""" + ... + + def get_task_range( + self, + split: str, + start: Optional[int] = None, + stop: Optional[int] = None, + ) -> list[Any]: + """Return task specs for Python slice-style range bounds.""" + ... + + class Transform(ABC, Generic[ObsT]): """Transform observations to add rewards, metrics, or other modifications. diff --git a/src/openenv/core/env_server/types.py b/src/openenv/core/env_server/types.py index 34a198013..00f229da9 100644 --- a/src/openenv/core/env_server/types.py +++ b/src/openenv/core/env_server/types.py @@ -245,6 +245,33 @@ class HealthResponse(BaseMessage): ) +class ListTasksRequest(BaseMessage): + """Request model for ORS-compatible task listing.""" + + split: str = Field(description="Task split name") + + +class NumTasksRequest(BaseMessage): + """Request model for ORS-compatible task counts.""" + + split: str = Field(description="Task split name") + + +class GetTaskRequest(BaseMessage): + """Request model for ORS-compatible task lookup.""" + + split: str = Field(description="Task split name") + index: int = Field(description="Task index within the split") + + +class GetTaskRangeRequest(BaseMessage): + """Request model for ORS-compatible task range lookup.""" + + split: str = Field(description="Task split name") + start: Optional[int] = Field(default=None, description="Inclusive start index") + stop: Optional[int] = Field(default=None, description="Exclusive stop index") + + class WSResetMessage(BaseMessage): """WebSocket message to reset the environment.""" diff --git a/src/openenv/core/env_server/web_interface.py b/src/openenv/core/env_server/web_interface.py index 5555afae0..996471fed 100644 --- a/src/openenv/core/env_server/web_interface.py +++ b/src/openenv/core/env_server/web_interface.py @@ -32,6 +32,11 @@ from .serialization import deserialize_action_with_preprocessing, serialize_observation from .types import Action, EnvironmentMetadata, Observation, State + +def _overrides_method(method: Any, base_method: Any) -> bool: + return getattr(method, "__func__", method) is not base_method + + # Quick Start markdown template; placeholders match init suffixes (__ENV_NAME__, __ENV_CLASS_NAME__*). DEFAULT_QUICK_START_MARKDOWN = """ ### Connect to this environment @@ -347,7 +352,7 @@ async def reset_environment( """Reset the environment and update state.""" reset_kwargs = reset_kwargs or {} - is_async = self.env.reset_async.__func__ is not Environment.reset_async + is_async = _overrides_method(self.env.reset_async, Environment.reset_async) sig = inspect.signature(self.env.reset_async if is_async else self.env.reset) valid_kwargs = self._get_valid_kwargs(sig, reset_kwargs) @@ -475,7 +480,12 @@ def create_web_interface_app( # Create the base environment app app = create_fastapi_app( - env, action_cls, observation_cls, max_concurrent_envs, concurrency_config + env, + action_cls, + observation_cls, + max_concurrent_envs, + concurrency_config, + env_name=env_name, ) # Load environment metadata diff --git a/tests/core/test_task_api.py b/tests/core/test_task_api.py new file mode 100644 index 000000000..8292f3a6d --- /dev/null +++ b/tests/core/test_task_api.py @@ -0,0 +1,225 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Tests for ORS-compatible task and split endpoints.""" + +import functools + +from fastapi import FastAPI +from fastapi.testclient import TestClient +from openenv.core.env_server.http_server import create_app, HTTPEnvServer +from openenv.core.env_server.interfaces import Environment +from openenv.core.env_server.mcp_types import ( + CallToolAction, + CallToolObservation, + ListToolsAction, + ListToolsObservation, +) +from openenv.core.env_server.types import Action, Observation, State + + +class TaskAction(Action): + value: str = "" + + +class TaskObservation(Observation): + message: str = "" + + +class TaskEnvironment(Environment): + def reset(self, **kwargs) -> TaskObservation: + return TaskObservation(message="ready") + + def step(self, action: TaskAction, **kwargs) -> TaskObservation: + return TaskObservation(message=action.value, reward=1.0) + + @property + def state(self) -> State: + return State() + + def list_splits(self) -> list[str]: + return ["train", "holdout"] + + def list_tasks(self, split: str) -> list[dict[str, str]]: + return [{"id": f"{split}-0"}, {"id": f"{split}-1"}] + + def num_tasks(self, split: str) -> int: + return 2 + + def get_task(self, split: str, index: int) -> dict[str, str | int]: + return {"id": f"{split}-{index}", "index": index} + + def get_task_range( + self, split: str, start: int | None = None, stop: int | None = None + ) -> list[dict[str, str | int]]: + start = 0 if start is None else start + stop = 2 if stop is None else stop + return [{"id": f"{split}-{i}", "index": i} for i in range(start, stop)] + + +class UnsupportedTaskEnvironment(Environment): + def reset(self, **kwargs) -> TaskObservation: + return TaskObservation(message="ready") + + def step(self, action: TaskAction, **kwargs) -> TaskObservation: + return TaskObservation(message=action.value) + + @property + def state(self) -> State: + return State() + + +class PartialStepAsyncEnvironment(Environment): + def __init__(self): + async def step_async(action: CallToolAction, **kwargs) -> CallToolObservation: + return CallToolObservation(tool_name=action.tool_name, result={"ok": True}) + + self.step_async = functools.partial(step_async) # type: ignore[method-assign] + + def reset(self, **kwargs) -> TaskObservation: + return TaskObservation(message="ready") + + def step(self, action: CallToolAction, **kwargs) -> CallToolObservation: + return CallToolObservation(tool_name=action.tool_name, result={"sync": True}) + + @property + def state(self) -> State: + return State() + + +class PartialMcpStepAsyncEnvironment(Environment): + def __init__(self): + async def step_async(action: Action, **kwargs) -> Observation: + if isinstance(action, ListToolsAction): + return ListToolsObservation(tools=[]) + return CallToolObservation(tool_name="unknown", result=None) + + self.step_async = functools.partial(step_async) # type: ignore[method-assign] + + def reset(self, **kwargs) -> TaskObservation: + return TaskObservation(message="ready") + + def step(self, action: Action, **kwargs) -> Observation: + return TaskObservation(message="sync") + + @property + def state(self) -> State: + return State() + + +def test_task_routes_expose_ors_compatible_shapes() -> None: + app = FastAPI() + server = HTTPEnvServer( + env=TaskEnvironment, + action_cls=TaskAction, + observation_cls=TaskObservation, + env_name="task_env", + ) + server.register_routes(app) + client = TestClient(app) + + assert client.get("/list_environments").json() == ["task_env"] + assert client.get("/task_env/splits").json() == [ + {"name": "train", "type": "train"}, + {"name": "holdout", "type": "validation"}, + ] + assert client.post("/task_env/tasks", json={"split": "train"}).json() == { + "tasks": [{"id": "train-0"}, {"id": "train-1"}], + "env_name": "task_env", + } + assert client.post("/task_env/num_tasks", json={"split": "train"}).json() == { + "num_tasks": 2 + } + assert client.post( + "/task_env/task", json={"split": "train", "index": 1} + ).json() == {"task": {"id": "train-1", "index": 1}} + assert client.post( + "/task_env/task_range", + json={"split": "train", "start": 0, "stop": 2}, + ).json() == { + "tasks": [{"id": "train-0", "index": 0}, {"id": "train-1", "index": 1}] + } + + +def test_task_routes_reject_unknown_environment_name() -> None: + app = FastAPI() + server = HTTPEnvServer( + env=TaskEnvironment, + action_cls=TaskAction, + observation_cls=TaskObservation, + env_name="task_env", + ) + server.register_routes(app) + client = TestClient(app) + + response = client.get("/other_env/splits") + + assert response.status_code == 404 + + +def test_task_routes_return_501_when_environment_does_not_support_tasks() -> None: + app = FastAPI() + server = HTTPEnvServer( + env=UnsupportedTaskEnvironment, + action_cls=TaskAction, + observation_cls=TaskObservation, + env_name="plain_env", + ) + server.register_routes(app) + client = TestClient(app) + + response = client.get("/plain_env/splits") + + assert response.status_code == 501 + + +def test_create_app_threads_env_name_to_task_routes() -> None: + app = create_app( + TaskEnvironment, + TaskAction, + TaskObservation, + env_name="created_env", + ) + client = TestClient(app) + + assert client.get("/list_environments").json() == ["created_env"] + assert client.get("/created_env/splits").status_code == 200 + + +def test_step_route_handles_partial_step_async() -> None: + app = create_app( + PartialStepAsyncEnvironment, + CallToolAction, + CallToolObservation, + env_name="partial_env", + ) + client = TestClient(app) + + response = client.post( + "/step", + json={"action": {"tool_name": "submit", "arguments": {}}}, + ) + + assert response.status_code == 200 + assert response.json()["observation"]["result"] == {"ok": True} + + +def test_mcp_style_step_handles_partial_step_async() -> None: + app = create_app( + PartialMcpStepAsyncEnvironment, + CallToolAction, + CallToolObservation, + env_name="partial_mcp_env", + ) + client = TestClient(app) + + response = client.post( + "/mcp", + json={"jsonrpc": "2.0", "method": "tools/list", "params": {}, "id": 1}, + ) + + assert response.status_code == 200 + assert response.json()["result"] == {"tools": []} diff --git a/tests/test_cli/test_import_env.py b/tests/test_cli/test_import_env.py new file mode 100644 index 000000000..2cd16a6ce --- /dev/null +++ b/tests/test_cli/test_import_env.py @@ -0,0 +1,598 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Tests for deterministic OpenEnv environment import.""" + +from __future__ import annotations + +import sys +from pathlib import Path +from unittest.mock import patch + +import pytest +from fastapi.testclient import TestClient +from openenv.cli.__main__ import app +from openenv.cli.importers.ors import detect_ors_dependencies, detect_ors_environments +from openenv.cli.importers.verifiers import detect_verifiers_environments +from openenv.core.env_server.mcp_types import CallToolAction, ListToolsAction +from typer.testing import CliRunner + + +runner = CliRunner() + + +def _write_fake_ors_sdk(root: Path) -> None: + ors_dir = root / "ors" + ors_dir.mkdir(parents=True) + (ors_dir / "__init__.py").write_text( + "from .environment import Environment, ListToolsOutput, Split, TextBlock, " + "ToolOutput, ToolSpec\n", + encoding="utf-8", + ) + (ors_dir / "environment.py").write_text( + """ +class _Model: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + def model_dump(self): + return dict(self.__dict__) + + +class Split(_Model): + pass + + +class ToolSpec(_Model): + pass + + +class ListToolsOutput(_Model): + pass + + +class TextBlock(_Model): + def __init__(self, text, detail=None, type="text"): + super().__init__(text=text, detail=detail, type=type) + + +class ToolOutput(_Model): + pass + + +class _RunToolSuccess: + ok = True + + def __init__(self, output): + self.output = output + + +class RunToolOutput: + def __init__(self, output): + self.root = _RunToolSuccess(output) + + +class Environment: + def __init__(self, task_spec=None, secrets=None): + self.task_spec = task_spec or {} + self.secrets = secrets or {} + self.setup_called = False + self.teardown_called = False + + def setup(self): + self.setup_called = True + + def teardown(self): + self.teardown_called = True +""".lstrip(), + encoding="utf-8", + ) + + +def _write_single_fake_ors_env(root: Path) -> None: + _write_fake_ors_sdk(root) + (root / "demo_env.py").write_text( + """ +from ors import Environment, ListToolsOutput, Split, TextBlock, ToolOutput, ToolSpec + + +class DemoEnvironment(Environment): + @classmethod + def list_splits(cls): + return [Split(name="train", type="train")] + + @classmethod + def list_tasks(cls, split): + return [{"id": "alpha", "goal": "answer"}] + + @classmethod + def num_tasks(cls, split): + return 1 + + @classmethod + def get_task(cls, split, index): + return cls.list_tasks(split)[index] + + @classmethod + def get_task_range(cls, split, start=None, stop=None): + return cls.list_tasks(split)[slice(start, stop)] + + @classmethod + def list_tools(cls): + return ListToolsOutput( + tools=[ + ToolSpec( + name="answer", + description="Submit an answer", + input_schema={"type": "object", "properties": {"value": {"type": "string"}}}, + ) + ] + ) + + def list_task_tools(self): + return ListToolsOutput( + tools=[ + ToolSpec( + name="hint", + description="Get a hint", + input_schema={"type": "object", "properties": {}}, + ) + ] + ) + + def get_prompt(self): + return [TextBlock(text=f"Task: {self.task_spec['id']}")] + + def _call_tool(self, name, input): + return __import__("ors.environment").environment.RunToolOutput( + ToolOutput( + blocks=[TextBlock(text=f"{name}:{input.get('value', '')}")], + metadata={"tool": name}, + reward=1.0, + finished=True, + ) + ) +""".lstrip(), + encoding="utf-8", + ) + + +def _write_fake_verifiers_sdk(root: Path) -> None: + verifiers_dir = root / "verifiers" + verifiers_dir.mkdir(parents=True) + (verifiers_dir / "__init__.py").write_text( + """ +class Environment: + pass + + +class Rubric: + async def score_rollout(self, state): + answer = state.get("answer") or state.get("task", {}).get("answer") + completion = state.get("completion") or [] + text = completion[-1].get("content", "") if completion else "" + state["reward"] = 1.0 if answer and answer in text else 0.0 + state["metrics"] = {"contains_answer": state["reward"]} + + +class SingleTurnEnv(Environment): + def __init__(self, dataset, eval_dataset=None, rubric=None): + self._dataset = dataset + self._eval_dataset = eval_dataset or dataset + self.rubric = rubric or Rubric() + + def get_dataset(self): + return self._dataset + + def get_eval_dataset(self): + return self._eval_dataset +""".lstrip(), + encoding="utf-8", + ) + + +def _write_single_fake_verifiers_env(root: Path) -> None: + _write_fake_verifiers_sdk(root) + (root / "vf_demo.py").write_text( + """ +import verifiers as vf + + +def load_environment() -> vf.Environment: + train = [ + {"prompt": [{"role": "user", "content": "Say alpha"}], "answer": "alpha", "example_id": 0}, + {"prompt": [{"role": "user", "content": "Say beta"}], "answer": "beta", "example_id": 1}, + ] + eval_rows = [ + {"prompt": [{"role": "user", "content": "Say gamma"}], "answer": "gamma", "example_id": 0} + ] + return vf.SingleTurnEnv(dataset=train, eval_dataset=eval_rows) +""".lstrip(), + encoding="utf-8", + ) + + +def test_ors_detector_finds_environment_class_without_importing_source( + tmp_path: Path, +) -> None: + source = tmp_path / "source" + source.mkdir() + (source / "envs").mkdir() + (source / "envs" / "sample.py").write_text( + """ +from ors import Environment as ORSEnvironment + +SIDE_EFFECT = 0 + + +class SampleEnv(ORSEnvironment): + pass +""".lstrip(), + encoding="utf-8", + ) + + matches = detect_ors_environments(source) + + assert len(matches) == 1 + assert matches[0].class_name == "SampleEnv" + assert matches[0].module_path == "envs.sample" + assert matches[0].source_type == "ors" + + +def test_ors_detector_finds_openreward_environments_import_path( + tmp_path: Path, +) -> None: + source = tmp_path / "source" + source.mkdir() + (source / "sample.py").write_text( + """ +from openreward.environments import Environment + + +class SampleEnv(Environment): + pass +""".lstrip(), + encoding="utf-8", + ) + + matches = detect_ors_environments(source) + + assert len(matches) == 1 + assert matches[0].class_name == "SampleEnv" + assert detect_ors_dependencies(source) == ["openreward"] + + +def test_ors_detector_returns_no_matches_for_unrelated_source(tmp_path: Path) -> None: + source = tmp_path / "source" + source.mkdir() + (source / "plain.py").write_text("class Plain: pass\n", encoding="utf-8") + + assert detect_ors_environments(source) == [] + + +def test_verifiers_detector_finds_load_environment_without_importing_source( + tmp_path: Path, +) -> None: + source = tmp_path / "source" + source.mkdir() + (source / "demo.py").write_text( + """ +import verifiers as vf + +SIDE_EFFECT = 0 + + +def load_environment() -> vf.Environment: + raise RuntimeError("should not import") +""".lstrip(), + encoding="utf-8", + ) + + matches = detect_verifiers_environments(source) + + assert len(matches) == 1 + assert matches[0].source_type == "verifiers" + assert matches[0].class_name == "load_environment" + assert matches[0].module_path == "demo" + + +def test_import_command_requires_env_class_when_multiple_ors_classes( + tmp_path: Path, +) -> None: + source = tmp_path / "source" + source.mkdir() + (source / "a.py").write_text( + "from ors import Environment\nclass First(Environment): pass\n", + encoding="utf-8", + ) + (source / "b.py").write_text( + "from ors import Environment\nclass Second(Environment): pass\n", + encoding="utf-8", + ) + + result = runner.invoke( + app, + [ + "import", + str(source), + "--name", + "imported_env", + "--output-dir", + str(tmp_path), + ], + ) + + assert result.exit_code != 0 + assert "Multiple environment entrypoints" in result.output + assert "env" in result.output + assert "class" in result.output + + +def test_import_command_detects_ors_and_generates_working_wrapper( + tmp_path: Path, +) -> None: + source = tmp_path / "source" + source.mkdir() + _write_single_fake_ors_env(source) + output_dir = tmp_path / "out" + + with patch("openenv.cli.commands.import_env._generate_uv_lock", return_value=True): + result = runner.invoke( + app, + [ + "import", + str(source), + "--name", + "imported_env", + "--output-dir", + str(output_dir), + ], + ) + + assert result.exit_code == 0, result.output + env_dir = output_dir / "imported_env" + assert (env_dir / "server" / "imported_env_environment.py").exists() + assert (env_dir / "vendor" / "source" / "demo_env.py").exists() + assert (env_dir / "vendor" / "source" / "ors" / "environment.py").exists() + + sys.path.insert(0, str(output_dir)) + try: + from imported_env.server.imported_env_environment import ( # type: ignore + ImportedEnvironment, + ) + + env = ImportedEnvironment() + assert env.list_splits() == [{"name": "train", "type": "train"}] + assert env.get_task("train", 0) == {"id": "alpha", "goal": "answer"} + with pytest.raises(RuntimeError, match="reset"): + ImportedEnvironment().step( + CallToolAction(tool_name="answer", arguments={"value": "42"}) + ) + + reset_obs = env.reset(split="train", index=0) + assert reset_obs.metadata["task_spec"] == {"id": "alpha", "goal": "answer"} + assert reset_obs.metadata["prompt"][0]["text"] == "Task: alpha" + + tools_obs = env.step(ListToolsAction()) + assert [tool.name for tool in tools_obs.tools] == ["answer", "hint"] + + call_obs = env.step( + CallToolAction(tool_name="answer", arguments={"value": "42"}) + ) + assert call_obs.reward == 1.0 + assert call_obs.done is True + assert call_obs.result["blocks"][0]["text"] == "answer:42" + + from imported_env.server.app import app as generated_app # type: ignore + + client = TestClient(generated_app) + assert client.get("/list_environments").json() == ["imported_env"] + assert client.get("/imported_env/splits").status_code == 200 + mcp_tools = client.post( + "/mcp", + json={ + "jsonrpc": "2.0", + "method": "tools/list", + "params": {}, + "id": 1, + }, + ).json() + assert mcp_tools["result"]["tools"][0]["name"] == "answer" + mcp_call = client.post( + "/mcp", + json={ + "jsonrpc": "2.0", + "method": "tools/call", + "params": {"name": "answer", "arguments": {"value": "42"}}, + "id": 2, + }, + ).json() + assert "reset" in mcp_call["error"]["message"] + finally: + sys.path.remove(str(output_dir)) + + +def test_import_command_handles_source_module_matching_generated_package( + tmp_path: Path, +) -> None: + source = tmp_path / "source" + source.mkdir() + _write_fake_ors_sdk(source) + (source / "collision_env.py").write_text( + """ +from ors import Environment, Split + + +class CollisionEnvironment(Environment): + @classmethod + def list_splits(cls): + return [Split(name="train", type="train")] +""".lstrip(), + encoding="utf-8", + ) + output_dir = tmp_path / "out" + + with patch("openenv.cli.commands.import_env._generate_uv_lock", return_value=True): + result = runner.invoke( + app, + [ + "import", + str(source), + "--name", + "collision_env", + "--output-dir", + str(output_dir), + ], + ) + + assert result.exit_code == 0, result.output + sys.path.insert(0, str(output_dir)) + try: + from collision_env.server.collision_env_environment import ( # type: ignore + CollisionEnvironment, + ) + + env = CollisionEnvironment() + assert env.list_splits() == [{"name": "train", "type": "train"}] + finally: + sys.path.remove(str(output_dir)) + + +def test_import_command_excludes_common_secret_files(tmp_path: Path) -> None: + source = tmp_path / "source" + source.mkdir() + _write_single_fake_ors_env(source) + (source / ".env").write_text("TOKEN=secret\n", encoding="utf-8") + (source / "secrets.yaml").write_text("token: secret\n", encoding="utf-8") + (source / "private.pem").write_text("secret\n", encoding="utf-8") + output_dir = tmp_path / "out" + + with patch("openenv.cli.commands.import_env._generate_uv_lock", return_value=True): + result = runner.invoke( + app, + [ + "import", + str(source), + "--name", + "secret_env", + "--output-dir", + str(output_dir), + ], + ) + + assert result.exit_code == 0, result.output + vendor_dir = output_dir / "secret_env" / "vendor" / "source" + assert not (vendor_dir / ".env").exists() + assert not (vendor_dir / "secrets.yaml").exists() + assert not (vendor_dir / "private.pem").exists() + + +def test_import_command_uses_detected_ors_dependency(tmp_path: Path) -> None: + source = tmp_path / "source" + source.mkdir() + (source / "demo.py").write_text( + """ +from openreward.environments import Environment + + +class DemoEnvironment(Environment): + pass +""".lstrip(), + encoding="utf-8", + ) + output_dir = tmp_path / "out" + + with patch("openenv.cli.commands.import_env._generate_uv_lock", return_value=True): + result = runner.invoke( + app, + [ + "import", + str(source), + "--name", + "openreward_env", + "--output-dir", + str(output_dir), + ], + ) + + assert result.exit_code == 0, result.output + requirements = ( + output_dir / "openreward_env" / "server" / "requirements.txt" + ).read_text(encoding="utf-8") + pyproject = (output_dir / "openreward_env" / "pyproject.toml").read_text( + encoding="utf-8" + ) + assert "openreward" in requirements + assert "openreward" in pyproject + assert "ors-sdk" not in requirements + assert "ors-sdk" not in pyproject + + +def test_import_command_detects_verifiers_and_generates_working_wrapper( + tmp_path: Path, +) -> None: + source = tmp_path / "vf_source" + source.mkdir() + _write_single_fake_verifiers_env(source) + output_dir = tmp_path / "out" + + with patch("openenv.cli.commands.import_env._generate_uv_lock", return_value=True): + result = runner.invoke( + app, + [ + "import", + str(source), + "--name", + "vf_imported_env", + "--output-dir", + str(output_dir), + ], + ) + + assert result.exit_code == 0, result.output + env_dir = output_dir / "vf_imported_env" + assert (env_dir / "server" / "vf_imported_env_environment.py").exists() + assert (env_dir / "vendor" / "vf_source" / "vf_demo.py").exists() + assert (env_dir / "vendor" / "vf_source" / "verifiers" / "__init__.py").exists() + + sys.path.insert(0, str(output_dir)) + try: + from vf_imported_env.server.vf_imported_env_environment import ( # type: ignore + VfImportedEnvironment, + ) + + env = VfImportedEnvironment() + assert env.list_splits() == [ + {"name": "train", "type": "train"}, + {"name": "eval", "type": "validation"}, + ] + with pytest.raises(RuntimeError, match="reset"): + VfImportedEnvironment().step( + CallToolAction(tool_name="submit", arguments={"completion": "alpha"}) + ) + assert env.num_tasks("train") == 2 + assert env.get_task("train", 1)["answer"] == "beta" + + reset_obs = env.reset(split="train", index=0) + assert reset_obs.metadata["prompt"][0]["content"] == "Say alpha" + + tools_obs = env.step(ListToolsAction()) + assert [tool.name for tool in tools_obs.tools] == ["submit"] + + call_obs = env.step( + CallToolAction(tool_name="submit", arguments={"completion": "alpha"}) + ) + assert call_obs.reward == 1.0 + assert call_obs.done is True + assert call_obs.result["reward"] == 1.0 + + from vf_imported_env.server.app import app as generated_app # type: ignore + + client = TestClient(generated_app) + assert client.get("/list_environments").json() == ["vf_imported_env"] + assert client.get("/vf_imported_env/splits").status_code == 200 + finally: + sys.path.remove(str(output_dir))