From 85c47c6d41785f410cace743d65e64f99c5644e8 Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Thu, 4 Dec 2025 23:01:27 +0530 Subject: [PATCH 01/41] Add WebSocket support for environment interactions and enhance HTTP server capabilities - Introduced WebSocketEnvClient for persistent sessions with multi-step interactions. - Updated HTTPEnvServer to support WebSocket connections and manage multiple concurrent environments. - Added WebSocket message types and responses for better communication. - Enhanced Environment interface with concurrency safety attributes. --- src/openenv/core/__init__.py | 2 + src/openenv/core/env_server/__init__.py | 26 +- src/openenv/core/env_server/http_server.py | 293 +++++++++++++++++++- src/openenv/core/env_server/interfaces.py | 14 + src/openenv/core/env_server/types.py | 68 +++++ src/openenv/core/ws_env_client.py | 305 +++++++++++++++++++++ 6 files changed, 692 insertions(+), 16 deletions(-) create mode 100644 src/openenv/core/ws_env_client.py diff --git a/src/openenv/core/__init__.py b/src/openenv/core/__init__.py index 99507ab55..3592ead53 100644 --- a/src/openenv/core/__init__.py +++ b/src/openenv/core/__init__.py @@ -10,10 +10,12 @@ from .env_server import * from .client_types import StepResult from .http_env_client import HTTPEnvClient +from .ws_env_client import WebSocketEnvClient # Note: MCP module doesn't export anything yet __all__ = [ "HTTPEnvClient", + "WebSocketEnvClient", "StepResult", ] diff --git a/src/openenv/core/env_server/__init__.py b/src/openenv/core/env_server/__init__.py index 4e1c2d7ac..92ebbeb2d 100644 --- a/src/openenv/core/env_server/__init__.py +++ b/src/openenv/core/env_server/__init__.py @@ -15,7 +15,22 @@ deserialize_action_with_preprocessing, serialize_observation, ) -from .types import Action, Observation, State, SchemaResponse, HealthResponse +from .types import ( + Action, + Observation, + State, + SchemaResponse, + HealthResponse, + # WebSocket message types + 
WSMessage, + WSResetMessage, + WSStepMessage, + WSStateMessage, + WSCloseMessage, + WSObservationResponse, + WSStateResponse, + WSErrorResponse, +) from .web_interface import create_web_interface_app, WebInterfaceManager __all__ = [ @@ -30,6 +45,15 @@ "State", "SchemaResponse", "HealthResponse", + # WebSocket message types + "WSMessage", + "WSResetMessage", + "WSStepMessage", + "WSStateMessage", + "WSCloseMessage", + "WSObservationResponse", + "WSStateResponse", + "WSErrorResponse", # Base transforms "CompositeTransform", "NullTransform", diff --git a/src/openenv/core/env_server/http_server.py b/src/openenv/core/env_server/http_server.py index 7fa7c0f32..41cc32315 100644 --- a/src/openenv/core/env_server/http_server.py +++ b/src/openenv/core/env_server/http_server.py @@ -8,18 +8,21 @@ HTTP server wrapper for Environment instances. This module provides utilities to wrap any Environment subclass and expose it -over HTTP endpoints that HTTPEnvClient can consume. +over HTTP endpoints that HTTPEnvClient can consume. Also supports WebSocket +connections for persistent sessions with multi-environment concurrency. """ from __future__ import annotations import asyncio import inspect +import json import os +import uuid from concurrent.futures import ThreadPoolExecutor -from typing import Optional, Type +from typing import Any, Callable, Dict, Optional, Type, Union -from fastapi import Body, FastAPI, HTTPException, status +from fastapi import Body, FastAPI, HTTPException, WebSocket, WebSocketDisconnect, status from pydantic import ValidationError from .interfaces import Environment @@ -39,6 +42,13 @@ EnvironmentMetadata, SchemaResponse, HealthResponse, + WSResetMessage, + WSStepMessage, + WSStateMessage, + WSCloseMessage, + WSObservationResponse, + WSStateResponse, + WSErrorResponse, ) @@ -47,7 +57,8 @@ class HTTPEnvServer: HTTP server wrapper for Environment instances. 
This class wraps an Environment and exposes its reset(), step(), and state - methods as HTTP endpoints compatible with HTTPEnvClient. + methods as HTTP endpoints compatible with HTTPEnvClient. Also supports + WebSocket connections for persistent sessions with multi-environment concurrency. The server expects: - Action deserialization: Converts JSON dict to Action subclass @@ -57,9 +68,16 @@ class HTTPEnvServer: >>> from core.env_server import HTTPEnvServer >>> from envs.coding_env.server import CodeExecutionEnvironment >>> + >>> # Single environment (backward compatible) >>> env = CodeExecutionEnvironment() >>> server = HTTPEnvServer(env) >>> + >>> # Factory pattern for concurrent sessions + >>> server = HTTPEnvServer( + ... env=CodeExecutionEnvironment, # Pass class, not instance + ... max_concurrent_envs=4, + ... ) + >>> >>> # Register routes with FastAPI >>> from fastapi import FastAPI >>> app = FastAPI() @@ -68,21 +86,50 @@ class HTTPEnvServer: def __init__( self, - env: Environment, - action_cls: Type[Action], - observation_cls: Type[Observation], + env: Union[Environment, Callable[[], Environment], Type[Environment]], + action_cls: Type[Action] = None, + observation_cls: Type[Observation] = None, + max_concurrent_envs: int = 1, ): """ Initialize HTTP server wrapper. Args: - env: The Environment instance to wrap + env: The Environment instance, factory callable, or class to wrap. + - If an instance is provided, it's used directly (single-env mode) + - If a callable/class is provided, it's called to create new + environments for each WebSocket session (factory mode) action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns + max_concurrent_envs: Maximum number of concurrent WebSocket sessions. + Only applies when env is a factory. Default is 1. 
""" - self.env = env + self._env_factory: Optional[Callable[[], Environment]] = None + self._max_concurrent_envs = max_concurrent_envs + + # Determine if env is an instance or factory + if isinstance(env, Environment): + # Single instance mode (backward compatible) + self.env = env + self._env_factory = None + elif callable(env): + # Factory mode - env is a class or callable + self._env_factory = env + # Create a single instance for HTTP endpoints (backward compat) + self.env = env() + else: + raise TypeError( + f"env must be an Environment instance or callable, got {type(env)}" + ) + self.action_cls = action_cls self.observation_cls = observation_cls + + # Session management for WebSocket connections + self._sessions: Dict[str, Environment] = {} + self._session_executors: Dict[str, ThreadPoolExecutor] = {} + self._session_lock = asyncio.Lock() + # Create thread pool for running sync code in async context # This is needed for environments using sync libraries (e.g., Playwright sync API) self._executor = ThreadPoolExecutor(max_workers=1) @@ -110,6 +157,80 @@ def _get_valid_kwargs(self, sig, kwargs, skip_params=None): return valid_kwargs + async def _create_session(self) -> tuple[str, Environment]: + """ + Create a new WebSocket session with its own environment instance. 
+ + Returns: + Tuple of (session_id, environment) + + Raises: + RuntimeError: If max concurrent sessions reached or no factory available + """ + async with self._session_lock: + if len(self._sessions) >= self._max_concurrent_envs: + raise RuntimeError( + f"Maximum concurrent environments ({self._max_concurrent_envs}) reached" + ) + + if self._env_factory is None: + # Single instance mode - use shared env (limited concurrency) + if self._sessions: + raise RuntimeError( + "Single instance mode: only one WebSocket session allowed" + ) + session_id = str(uuid.uuid4()) + self._sessions[session_id] = self.env + else: + # Factory mode - create new environment + session_id = str(uuid.uuid4()) + env = self._env_factory() + self._sessions[session_id] = env + + # Create dedicated executor for this session + self._session_executors[session_id] = ThreadPoolExecutor(max_workers=1) + + return session_id, self._sessions[session_id] + + async def _destroy_session(self, session_id: str) -> None: + """ + Destroy a WebSocket session and cleanup resources. 
+ + Args: + session_id: The session ID to destroy + """ + async with self._session_lock: + if session_id in self._sessions: + env = self._sessions.pop(session_id) + # Call close() if environment has it + if hasattr(env, 'close') and callable(env.close): + try: + env.close() + except Exception: + pass # Best effort cleanup + + if session_id in self._session_executors: + executor = self._session_executors.pop(session_id) + executor.shutdown(wait=False) + + async def _run_in_session_executor( + self, session_id: str, func: Callable, *args, **kwargs + ) -> Any: + """Run a synchronous function in the session's thread pool executor.""" + executor = self._session_executors.get(session_id, self._executor) + loop = asyncio.get_event_loop() + return await loop.run_in_executor(executor, lambda: func(*args, **kwargs)) + + @property + def active_sessions(self) -> int: + """Return the number of active WebSocket sessions.""" + return len(self._sessions) + + @property + def max_concurrent_envs(self) -> int: + """Return the maximum number of concurrent environments.""" + return self._max_concurrent_envs + def register_routes(self, app: FastAPI) -> None: """ Register HTTP routes on a FastAPI application. @@ -339,12 +460,141 @@ async def get_schemas() -> SchemaResponse: state=State.model_json_schema(), ) + # Register WebSocket endpoint for persistent sessions + @app.websocket("/ws") + async def websocket_endpoint(websocket: WebSocket): + """ + WebSocket endpoint for persistent environment sessions. + + Each WebSocket connection gets its own environment instance (when using + factory mode) or shares the single instance (backward compatible mode). 
+ + Message Protocol: + - Client sends: {"type": "reset|step|state|close", "data": {...}} + - Server responds: {"type": "observation|state|error", "data": {...}} + """ + await websocket.accept() + + session_id = None + session_env = None + + try: + # Create session with dedicated environment + session_id, session_env = await self._create_session() + + while True: + # Receive message from client + raw_message = await websocket.receive_text() + + try: + message = json.loads(raw_message) + except json.JSONDecodeError as e: + error_resp = WSErrorResponse( + data={"message": f"Invalid JSON: {e}", "code": "INVALID_JSON"} + ) + await websocket.send_text(error_resp.model_dump_json()) + continue + + msg_type = message.get("type", "") + msg_data = message.get("data", {}) + + try: + if msg_type == "reset": + # Handle reset + sig = inspect.signature(session_env.reset) + valid_kwargs = self._get_valid_kwargs(sig, msg_data) + + observation = await self._run_in_session_executor( + session_id, session_env.reset, **valid_kwargs + ) + + response = WSObservationResponse( + data=serialize_observation(observation) + ) + await websocket.send_text(response.model_dump_json()) + + elif msg_type == "step": + # Handle step + if not msg_data: + error_resp = WSErrorResponse( + data={"message": "Missing action data", "code": "MISSING_ACTION"} + ) + await websocket.send_text(error_resp.model_dump_json()) + continue + + # Deserialize action with Pydantic validation + try: + action = deserialize_action(msg_data, self.action_cls) + except ValidationError as e: + error_resp = WSErrorResponse( + data={"message": str(e), "code": "VALIDATION_ERROR", "errors": e.errors()} + ) + await websocket.send_text(error_resp.model_dump_json()) + continue + + observation = await self._run_in_session_executor( + session_id, session_env.step, action + ) + + response = WSObservationResponse( + data=serialize_observation(observation) + ) + await websocket.send_text(response.model_dump_json()) + + elif msg_type == 
"state": + # Handle state request + state = session_env.state + if hasattr(state, 'model_dump'): + state_data = state.model_dump() + else: + state_data = dict(state) if state else {} + + response = WSStateResponse(data=state_data) + await websocket.send_text(response.model_dump_json()) + + elif msg_type == "close": + # Client requested close + break + + else: + error_resp = WSErrorResponse( + data={"message": f"Unknown message type: {msg_type}", "code": "UNKNOWN_TYPE"} + ) + await websocket.send_text(error_resp.model_dump_json()) + + except Exception as e: + error_resp = WSErrorResponse( + data={"message": str(e), "code": "EXECUTION_ERROR"} + ) + await websocket.send_text(error_resp.model_dump_json()) + + except WebSocketDisconnect: + pass # Client disconnected normally + except RuntimeError as e: + # Could not create session (max concurrent reached) + try: + error_resp = WSErrorResponse( + data={"message": str(e), "code": "SESSION_ERROR"} + ) + await websocket.send_text(error_resp.model_dump_json()) + except Exception: + pass + finally: + # Cleanup session + if session_id: + await self._destroy_session(session_id) + try: + await websocket.close() + except Exception: + pass + def create_app( - env: Environment, + env: Union[Environment, Callable[[], Environment], Type[Environment]], action_cls: Type[Action], observation_cls: Type[Observation], env_name: Optional[str] = None, + max_concurrent_envs: int = 1, ) -> FastAPI: """ Create a FastAPI application with or without web interface. @@ -353,10 +603,11 @@ def create_app( including README integration for better user experience. 
Args: - env: The Environment instance to serve + env: The Environment instance, factory callable, or class to serve action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns env_name: Optional environment name for README loading + max_concurrent_envs: Maximum concurrent WebSocket sessions (default: 1) Returns: FastAPI application instance with or without web interface and README integration @@ -376,15 +627,27 @@ def create_app( return create_web_interface_app(env, action_cls, observation_cls, env_name) else: # Use standard FastAPI app without web interface - return create_fastapi_app(env, action_cls, observation_cls) + return create_fastapi_app(env, action_cls, observation_cls, max_concurrent_envs) def create_fastapi_app( - env: Environment, + env: Union[Environment, Callable[[], Environment], Type[Environment]], action_cls: Type[Action], observation_cls: Type[Observation], + max_concurrent_envs: int = 1, ) -> FastAPI: - """Create a FastAPI application with comprehensive documentation.""" + """ + Create a FastAPI application with comprehensive documentation. 
+ + Args: + env: The Environment instance, factory callable, or class to serve + action_cls: The Action subclass this environment expects + observation_cls: The Observation subclass this environment returns + max_concurrent_envs: Maximum concurrent WebSocket sessions (default: 1) + + Returns: + FastAPI application instance + """ try: from fastapi import FastAPI except ImportError: @@ -452,6 +715,6 @@ def create_fastapi_app( }, ) - server = HTTPEnvServer(env, action_cls, observation_cls) + server = HTTPEnvServer(env, action_cls, observation_cls, max_concurrent_envs) server.register_routes(app) return app diff --git a/src/openenv/core/env_server/interfaces.py b/src/openenv/core/env_server/interfaces.py index b438cd667..196e7ac82 100644 --- a/src/openenv/core/env_server/interfaces.py +++ b/src/openenv/core/env_server/interfaces.py @@ -90,7 +90,21 @@ class Environment(ABC): Args: transform: Optional transform to apply to observations + + Class Attributes: + CONCURRENCY_SAFE: Whether this environment supports concurrent sessions. + When True, multiple WebSocket connections can each have their own + environment instance (up to max_concurrent_envs). When False (default), + the environment should only be used with a single session at a time. 
+ + Set this to True in your Environment subclass if: + - The environment uses proper session isolation (e.g., unique working dirs) + - No shared mutable state exists between instances + - External resources (databases, APIs) can handle concurrent access """ + + # Class-level flag indicating whether this environment supports concurrent sessions + CONCURRENCY_SAFE: bool = False def __init__(self, transform: Transform | None = None): self.transform = transform diff --git a/src/openenv/core/env_server/types.py b/src/openenv/core/env_server/types.py index c3ee689c0..765d6382d 100644 --- a/src/openenv/core/env_server/types.py +++ b/src/openenv/core/env_server/types.py @@ -212,3 +212,71 @@ class HealthResponse(BaseModel): ) status: str = Field(description="Health status of the environment server") + +class WSMessage(BaseModel): + """Base class for WebSocket messages.""" + + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + ) + + type: str = Field(description="Message type identifier") + + +class WSResetMessage(WSMessage): + """WebSocket message to reset the environment.""" + + type: str = Field(default="reset", description="Message type") + data: Dict[str, Any] = Field( + default_factory=dict, + description="Optional reset parameters (seed, episode_id, etc.)", + ) + + +class WSStepMessage(WSMessage): + """WebSocket message to execute a step.""" + + type: str = Field(default="step", description="Message type") + data: Dict[str, Any] = Field( + ..., description="Action data conforming to environment's action schema" + ) + + +class WSStateMessage(WSMessage): + """WebSocket message to request current state.""" + + type: str = Field(default="state", description="Message type") + + +class WSCloseMessage(WSMessage): + """WebSocket message to close the session.""" + + type: str = Field(default="close", description="Message type") + + +class WSObservationResponse(BaseModel): + """WebSocket response containing an observation.""" + + model_config = 
ConfigDict(extra="forbid") + + type: str = Field(default="observation", description="Response type") + data: Dict[str, Any] = Field(description="Observation data") + + +class WSStateResponse(BaseModel): + """WebSocket response containing environment state.""" + + model_config = ConfigDict(extra="forbid") + + type: str = Field(default="state", description="Response type") + data: Dict[str, Any] = Field(description="State data") + + +class WSErrorResponse(BaseModel): + """WebSocket response for errors.""" + + model_config = ConfigDict(extra="forbid") + + type: str = Field(default="error", description="Response type") + data: Dict[str, Any] = Field(description="Error details including message and code") diff --git a/src/openenv/core/ws_env_client.py b/src/openenv/core/ws_env_client.py new file mode 100644 index 000000000..c6f054e85 --- /dev/null +++ b/src/openenv/core/ws_env_client.py @@ -0,0 +1,305 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +WebSocket-based environment client for persistent sessions. + +This module provides a WebSocket client that maintains a persistent connection +to an environment server, enabling efficient multi-step interactions without +the overhead of HTTP request/response cycles. 
+""" + +from __future__ import annotations + +import json +from abc import ABC, abstractmethod +from typing import Any, Dict, Generic, Optional, Type, TYPE_CHECKING, TypeVar + +from .client_types import StepResult +from .containers.runtime import LocalDockerProvider + +if TYPE_CHECKING: + from .containers.runtime import ContainerProvider + from websockets.sync.client import ClientConnection + +try: + import websockets + from websockets.sync.client import connect as ws_connect +except ImportError: + websockets = None # type: ignore + ws_connect = None # type: ignore + +ActT = TypeVar("ActT") +ObsT = TypeVar("ObsT") +WSEnvClientT = TypeVar("WSEnvClientT", bound="WebSocketEnvClient") + + +class WebSocketEnvClient(ABC, Generic[ActT, ObsT]): + """ + WebSocket-based environment client for persistent sessions. + + This client maintains a persistent WebSocket connection to an environment + server, enabling efficient multi-step interactions. Each client instance + corresponds to a dedicated environment session on the server. + + Compared to HTTPEnvClient: + - Lower latency for sequential interactions + - Session state is maintained server-side + - Better suited for long-running episodes + + Example: + >>> from envs.coding_env.client import CodingEnvWS + >>> + >>> # Connect to a server via WebSocket + >>> with CodingEnvWS(base_url="ws://localhost:8000") as env: + ... result = env.reset(seed=42) + ... while not result.done: + ... action = agent.predict(result.observation) + ... result = env.step(action) + """ + + def __init__( + self, + base_url: str, + connect_timeout_s: float = 10.0, + message_timeout_s: float = 60.0, + provider: Optional["ContainerProvider"] = None, + ): + """ + Initialize WebSocket client. + + Args: + base_url: Base URL of the environment server (http:// or ws://). + Will be converted to ws:// if http:// is provided. 
+ connect_timeout_s: Timeout for establishing WebSocket connection + message_timeout_s: Timeout for receiving responses to messages + provider: Optional container provider for lifecycle management + """ + if websockets is None: + raise ImportError( + "websockets library is required for WebSocketEnvClient. " + "Install with: pip install websockets" + ) + + # Convert HTTP URL to WebSocket URL + ws_url = base_url.rstrip("/") + if ws_url.startswith("http://"): + ws_url = "ws://" + ws_url[7:] + elif ws_url.startswith("https://"): + ws_url = "wss://" + ws_url[8:] + elif not ws_url.startswith("ws://") and not ws_url.startswith("wss://"): + ws_url = "ws://" + ws_url + + self._ws_url = f"{ws_url}/ws" + self._connect_timeout = connect_timeout_s + self._message_timeout = message_timeout_s + self._provider = provider + self._ws: Optional[ClientConnection] = None + + def connect(self) -> "WebSocketEnvClient": + """ + Establish WebSocket connection to the server. + + Returns: + self for method chaining + + Raises: + ConnectionError: If connection cannot be established + """ + if self._ws is not None: + return self + + try: + self._ws = ws_connect( + self._ws_url, + open_timeout=self._connect_timeout, + ) + except Exception as e: + raise ConnectionError(f"Failed to connect to {self._ws_url}: {e}") from e + + return self + + def disconnect(self) -> None: + """Close the WebSocket connection.""" + if self._ws is not None: + try: + # Send close message + self._send({"type": "close"}) + except Exception: + pass # Best effort + try: + self._ws.close() + except Exception: + pass + self._ws = None + + def _ensure_connected(self) -> None: + """Ensure WebSocket connection is established.""" + if self._ws is None: + self.connect() + + def _send(self, message: Dict[str, Any]) -> None: + """Send a message over the WebSocket.""" + self._ensure_connected() + assert self._ws is not None + self._ws.send(json.dumps(message)) + + def _receive(self) -> Dict[str, Any]: + """Receive and parse a 
message from the WebSocket.""" + assert self._ws is not None + raw = self._ws.recv(timeout=self._message_timeout) + return json.loads(raw) + + def _send_and_receive(self, message: Dict[str, Any]) -> Dict[str, Any]: + """Send a message and wait for response.""" + self._send(message) + response = self._receive() + + # Check for error response + if response.get("type") == "error": + error_data = response.get("data", {}) + raise RuntimeError( + f"Server error: {error_data.get('message', 'Unknown error')} " + f"(code: {error_data.get('code', 'UNKNOWN')})" + ) + + return response + + @classmethod + def from_docker_image( + cls: Type[WSEnvClientT], + image: str, + provider: Optional["ContainerProvider"] = None, + **kwargs: Any, + ) -> WSEnvClientT: + """ + Create a WebSocket environment client by spinning up a Docker container. + + Args: + image: Docker image name to run (e.g., "coding-env:latest") + provider: Container provider to use (defaults to LocalDockerProvider) + **kwargs: Additional arguments to pass to provider.start_container() + + Returns: + Connected WebSocket client instance + """ + if provider is None: + provider = LocalDockerProvider() + + # Start container + base_url = provider.start_container(image, **kwargs) + + # Wait for server to be ready + provider.wait_for_ready(base_url) + + # Create and connect client + client = cls(base_url=base_url, provider=provider) + client.connect() + + return client + + @classmethod + def from_hub( + cls: Type[WSEnvClientT], + repo_id: str, + provider: Optional["ContainerProvider"] = None, + **kwargs: Any, + ) -> WSEnvClientT: + """ + Create a WebSocket client by pulling from a Hugging Face model hub. 
+ """ + if provider is None: + provider = LocalDockerProvider() + + tag = kwargs.pop("tag", "latest") + base_url = f"registry.hf.space/{repo_id.replace('/', '-')}:{tag}" + + return cls.from_docker_image(image=base_url, provider=provider, **kwargs) + + @abstractmethod + def _step_payload(self, action: ActT) -> dict: + """Convert an Action object to the JSON data expected by the env server.""" + raise NotImplementedError + + @abstractmethod + def _parse_result(self, payload: dict) -> StepResult[ObsT]: + """Convert a JSON response from the env server to StepResult[ObsT].""" + raise NotImplementedError + + @abstractmethod + def _parse_state(self, payload: dict) -> Any: + """Convert a JSON response from the state endpoint to a State object.""" + raise NotImplementedError + + def reset(self, **kwargs: Any) -> StepResult[ObsT]: + """ + Reset the environment with optional parameters. + + Args: + **kwargs: Optional parameters passed to the environment's reset method. + Common parameters include: + - seed: Random seed for reproducibility + - episode_id: Custom episode identifier + + Returns: + StepResult containing initial observation + """ + message = { + "type": "reset", + "data": kwargs, + } + response = self._send_and_receive(message) + return self._parse_result(response.get("data", {})) + + def step(self, action: ActT, **kwargs: Any) -> StepResult[ObsT]: + """ + Execute an action in the environment. + + Args: + action: The action to execute + **kwargs: Optional parameters (currently ignored for WebSocket) + + Returns: + StepResult containing observation, reward, and done status + """ + message = { + "type": "step", + "data": self._step_payload(action), + } + response = self._send_and_receive(message) + return self._parse_result(response.get("data", {})) + + def state(self) -> Any: + """ + Get the current environment state from the server. 
+ + Returns: + State object with environment state information + """ + message = {"type": "state"} + response = self._send_and_receive(message) + return self._parse_state(response.get("data", {})) + + def close(self) -> None: + """ + Close the WebSocket connection and clean up resources. + + If this client was created via from_docker_image(), this will also + stop and remove the associated container. + """ + self.disconnect() + + if self._provider is not None: + self._provider.stop_container() + + def __enter__(self) -> "WebSocketEnvClient": + """Enter context manager, ensuring connection is established.""" + self.connect() + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """Exit context manager, closing connection.""" + self.close() From e0a063d5833c5ff421bdf4368539adb131ad8b55 Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Thu, 4 Dec 2025 23:43:09 +0530 Subject: [PATCH 02/41] impl concurrency management and session handling --- src/openenv/core/env_server/__init__.py | 23 ++- src/openenv/core/env_server/exceptions.py | 105 ++++++++++++ src/openenv/core/env_server/http_server.py | 176 +++++++++++++++++++-- src/openenv/core/env_server/types.py | 96 +++++++++++ 4 files changed, 384 insertions(+), 16 deletions(-) create mode 100644 src/openenv/core/env_server/exceptions.py diff --git a/src/openenv/core/env_server/__init__.py b/src/openenv/core/env_server/__init__.py index 92ebbeb2d..e1014540e 100644 --- a/src/openenv/core/env_server/__init__.py +++ b/src/openenv/core/env_server/__init__.py @@ -21,7 +21,6 @@ State, SchemaResponse, HealthResponse, - # WebSocket message types WSMessage, WSResetMessage, WSStepMessage, @@ -30,6 +29,17 @@ WSObservationResponse, WSStateResponse, WSErrorResponse, + ConcurrencyConfig, + ServerCapacityStatus, + SessionInfo, +) +from .exceptions import ( + OpenEnvError, + ConcurrencyConfigurationError, + SessionCapacityError, + SessionNotFoundError, + SessionCreationError, + 
EnvironmentFactoryError, ) from .web_interface import create_web_interface_app, WebInterfaceManager @@ -54,6 +64,17 @@ "WSObservationResponse", "WSStateResponse", "WSErrorResponse", + # Concurrency types + "ConcurrencyConfig", + "ServerCapacityStatus", + "SessionInfo", + # Exceptions + "OpenEnvError", + "ConcurrencyConfigurationError", + "SessionCapacityError", + "SessionNotFoundError", + "SessionCreationError", + "EnvironmentFactoryError", # Base transforms "CompositeTransform", "NullTransform", diff --git a/src/openenv/core/env_server/exceptions.py b/src/openenv/core/env_server/exceptions.py new file mode 100644 index 000000000..41a8235bb --- /dev/null +++ b/src/openenv/core/env_server/exceptions.py @@ -0,0 +1,105 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Custom exceptions for environment server operations.""" + +from typing import Optional + + +class OpenEnvError(Exception): + """Base exception for all OpenEnv errors.""" + + pass + + +class ConcurrencyConfigurationError(OpenEnvError): + """ + Raised when an environment is misconfigured for concurrent sessions. + + This error is raised during server startup when max_concurrent_envs > 1 + is specified for an environment that is not marked as CONCURRENCY_SAFE. + """ + + def __init__( + self, + environment_name: str, + max_concurrent_envs: int, + message: Optional[str] = None, + ): + self.environment_name = environment_name + self.max_concurrent_envs = max_concurrent_envs + + if message is None: + message = ( + f"Environment '{environment_name}' is not marked as CONCURRENCY_SAFE. " + f"Cannot run with max_concurrent_envs={max_concurrent_envs}. " + f"Either set max_concurrent_envs=1 or ensure the environment " + f"properly isolates session state and set CONCURRENCY_SAFE=True." 
+ ) + + super().__init__(message) + + +class SessionCapacityError(OpenEnvError): + """ + Raised when the server cannot accept new sessions due to capacity limits. + + This error is raised when a new WebSocket connection is attempted but + the server has already reached max_concurrent_envs active sessions. + """ + + def __init__( + self, + active_sessions: int, + max_sessions: int, + message: Optional[str] = None, + ): + self.active_sessions = active_sessions + self.max_sessions = max_sessions + + if message is None: + message = ( + f"Server at capacity: {active_sessions}/{max_sessions} sessions active. " + f"Cannot accept new connections." + ) + + super().__init__(message) + + +class SessionNotFoundError(OpenEnvError): + """Raised when attempting to access a session that does not exist.""" + + def __init__(self, session_id: str, message: Optional[str] = None): + self.session_id = session_id + + if message is None: + message = f"Session '{session_id}' not found." + + super().__init__(message) + + +class SessionCreationError(OpenEnvError): + """Raised when a session cannot be created.""" + + def __init__(self, reason: str, message: Optional[str] = None): + self.reason = reason + + if message is None: + message = f"Failed to create session: {reason}" + + super().__init__(message) + + +class EnvironmentFactoryError(OpenEnvError): + """Raised when the environment factory fails to create an instance.""" + + def __init__(self, factory_name: str, cause: Exception): + self.factory_name = factory_name + self.cause = cause + + message = f"Environment factory '{factory_name}' failed to create instance: {cause}" + + super().__init__(message) diff --git a/src/openenv/core/env_server/http_server.py b/src/openenv/core/env_server/http_server.py index 41cc32315..50eaac13d 100644 --- a/src/openenv/core/env_server/http_server.py +++ b/src/openenv/core/env_server/http_server.py @@ -49,6 +49,14 @@ WSObservationResponse, WSStateResponse, WSErrorResponse, + ConcurrencyConfig, + 
ServerCapacityStatus, + SessionInfo, +) +from .exceptions import ( + ConcurrencyConfigurationError, + SessionCapacityError, + EnvironmentFactoryError, ) @@ -90,6 +98,7 @@ def __init__( action_cls: Type[Action] = None, observation_cls: Type[Observation] = None, max_concurrent_envs: int = 1, + skip_concurrency_check: bool = False, ): """ Initialize HTTP server wrapper. @@ -103,9 +112,19 @@ def __init__( observation_cls: The Observation subclass this environment returns max_concurrent_envs: Maximum number of concurrent WebSocket sessions. Only applies when env is a factory. Default is 1. + skip_concurrency_check: If True, skip concurrency safety validation. + Use with caution for advanced users who understand + the isolation requirements. + + Raises: + ConcurrencyConfigurationError: If max_concurrent_envs > 1 for an + environment that is not marked as CONCURRENCY_SAFE. """ self._env_factory: Optional[Callable[[], Environment]] = None self._max_concurrent_envs = max_concurrent_envs + self._skip_concurrency_check = skip_concurrency_check or os.getenv( + "OPENENV_SKIP_CONCURRENCY_CHECK", "" + ).lower() in ("1", "true", "yes") # Determine if env is an instance or factory if isinstance(env, Environment): @@ -116,24 +135,67 @@ def __init__( # Factory mode - env is a class or callable self._env_factory = env # Create a single instance for HTTP endpoints (backward compat) - self.env = env() + try: + self.env = env() + except Exception as e: + factory_name = getattr(env, "__name__", str(env)) + raise EnvironmentFactoryError(factory_name, e) from e else: raise TypeError( f"env must be an Environment instance or callable, got {type(env)}" ) + # Validate concurrency configuration + self._validate_concurrency_safety() + self.action_cls = action_cls self.observation_cls = observation_cls # Session management for WebSocket connections self._sessions: Dict[str, Environment] = {} self._session_executors: Dict[str, ThreadPoolExecutor] = {} + self._session_info: Dict[str, SessionInfo] = 
{} self._session_lock = asyncio.Lock() # Create thread pool for running sync code in async context # This is needed for environments using sync libraries (e.g., Playwright sync API) self._executor = ThreadPoolExecutor(max_workers=1) + def _validate_concurrency_safety(self) -> None: + """ + Validate that the environment supports the configured concurrency level. + + Raises: + ConcurrencyConfigurationError: If max_concurrent_envs > 1 for an + environment that is not marked as CONCURRENCY_SAFE. + """ + if self._max_concurrent_envs <= 1: + return + + if self._skip_concurrency_check: + return + + is_concurrency_safe = getattr(self.env, "CONCURRENCY_SAFE", False) + + if not is_concurrency_safe: + env_name = type(self.env).__name__ + raise ConcurrencyConfigurationError( + environment_name=env_name, + max_concurrent_envs=self._max_concurrent_envs, + ) + + def get_capacity_status(self) -> ServerCapacityStatus: + """ + Get the current capacity status of the server. + + Returns: + ServerCapacityStatus with current session counts and availability. 
+ """ + return ServerCapacityStatus.from_counts( + active=len(self._sessions), + max_sessions=self._max_concurrent_envs, + ) + async def _run_sync_in_thread_pool(self, func, *args, **kwargs): """Run a synchronous function in the thread pool executor.""" loop = asyncio.get_event_loop() @@ -165,32 +227,53 @@ async def _create_session(self) -> tuple[str, Environment]: Tuple of (session_id, environment) Raises: - RuntimeError: If max concurrent sessions reached or no factory available + SessionCapacityError: If max concurrent sessions reached + EnvironmentFactoryError: If the factory fails to create an environment """ + import time + async with self._session_lock: if len(self._sessions) >= self._max_concurrent_envs: - raise RuntimeError( - f"Maximum concurrent environments ({self._max_concurrent_envs}) reached" + raise SessionCapacityError( + active_sessions=len(self._sessions), + max_sessions=self._max_concurrent_envs, ) + session_id = str(uuid.uuid4()) + current_time = time.time() + if self._env_factory is None: # Single instance mode - use shared env (limited concurrency) if self._sessions: - raise RuntimeError( - "Single instance mode: only one WebSocket session allowed" + raise SessionCapacityError( + active_sessions=len(self._sessions), + max_sessions=1, + message="Single instance mode: only one WebSocket session allowed", ) - session_id = str(uuid.uuid4()) - self._sessions[session_id] = self.env + env = self.env else: # Factory mode - create new environment - session_id = str(uuid.uuid4()) - env = self._env_factory() - self._sessions[session_id] = env + try: + env = self._env_factory() + except Exception as e: + factory_name = getattr(self._env_factory, "__name__", str(self._env_factory)) + raise EnvironmentFactoryError(factory_name, e) from e + + self._sessions[session_id] = env # Create dedicated executor for this session self._session_executors[session_id] = ThreadPoolExecutor(max_workers=1) - return session_id, self._sessions[session_id] + # Track session 
metadata + self._session_info[session_id] = SessionInfo( + session_id=session_id, + created_at=current_time, + last_activity_at=current_time, + step_count=0, + environment_type=type(env).__name__, + ) + + return session_id, env async def _destroy_session(self, session_id: str) -> None: """ @@ -212,7 +295,37 @@ async def _destroy_session(self, session_id: str) -> None: if session_id in self._session_executors: executor = self._session_executors.pop(session_id) executor.shutdown(wait=False) + + # Remove session metadata + self._session_info.pop(session_id, None) + def _update_session_activity(self, session_id: str, increment_step: bool = False) -> None: + """ + Update session activity timestamp and optionally increment step count. + + Args: + session_id: The session ID to update + increment_step: If True, increment the step count + """ + import time + + if session_id in self._session_info: + self._session_info[session_id].last_activity_at = time.time() + if increment_step: + self._session_info[session_id].step_count += 1 + + def get_session_info(self, session_id: str) -> Optional[SessionInfo]: + """ + Get information about a specific session. + + Args: + session_id: The session ID to query + + Returns: + SessionInfo if the session exists, None otherwise + """ + return self._session_info.get(session_id) + async def _run_in_session_executor( self, session_id: str, func: Callable, *args, **kwargs ) -> Any: @@ -231,6 +344,11 @@ def max_concurrent_envs(self) -> int: """Return the maximum number of concurrent environments.""" return self._max_concurrent_envs + @property + def is_concurrency_safe(self) -> bool: + """Return whether the environment is marked as concurrency safe.""" + return getattr(self.env, "CONCURRENCY_SAFE", False) + def register_routes(self, app: FastAPI) -> None: """ Register HTTP routes on a FastAPI application. 
@@ -508,6 +626,8 @@ async def websocket_endpoint(websocket: WebSocket): session_id, session_env.reset, **valid_kwargs ) + self._update_session_activity(session_id) + response = WSObservationResponse( data=serialize_observation(observation) ) @@ -536,6 +656,8 @@ async def websocket_endpoint(websocket: WebSocket): session_id, session_env.step, action ) + self._update_session_activity(session_id, increment_step=True) + response = WSObservationResponse( data=serialize_observation(observation) ) @@ -569,9 +691,33 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.send_text(error_resp.model_dump_json()) except WebSocketDisconnect: - pass # Client disconnected normally - except RuntimeError as e: - # Could not create session (max concurrent reached) + pass + except SessionCapacityError as e: + try: + error_resp = WSErrorResponse( + data={ + "message": str(e), + "code": "CAPACITY_REACHED", + "active_sessions": e.active_sessions, + "max_sessions": e.max_sessions, + } + ) + await websocket.send_text(error_resp.model_dump_json()) + except Exception: + pass + except EnvironmentFactoryError as e: + try: + error_resp = WSErrorResponse( + data={ + "message": str(e), + "code": "FACTORY_ERROR", + "factory_name": e.factory_name, + } + ) + await websocket.send_text(error_resp.model_dump_json()) + except Exception: + pass + except Exception as e: try: error_resp = WSErrorResponse( data={"message": str(e), "code": "SESSION_ERROR"} diff --git a/src/openenv/core/env_server/types.py b/src/openenv/core/env_server/types.py index 765d6382d..39074595f 100644 --- a/src/openenv/core/env_server/types.py +++ b/src/openenv/core/env_server/types.py @@ -280,3 +280,99 @@ class WSErrorResponse(BaseModel): type: str = Field(default="error", description="Response type") data: Dict[str, Any] = Field(description="Error details including message and code") + + +class ConcurrencySafetyLevel(str): + """ + Classification of environment concurrency safety. 
+ + Environments are classified based on their ability to safely handle + multiple concurrent sessions within a single container. + """ + + UNSAFE = "unsafe" + SAFE = "safe" + + +class ConcurrencyConfig(BaseModel): + """Configuration for concurrent environment sessions.""" + + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + ) + + max_concurrent_envs: int = Field( + default=1, + ge=1, + le=1000, + description="Maximum number of concurrent WebSocket sessions allowed", + ) + session_timeout_seconds: Optional[float] = Field( + default=None, + gt=0, + description="Timeout in seconds for inactive sessions. None means no timeout.", + ) + reject_on_capacity: bool = Field( + default=True, + description="If True, reject new connections when at capacity. If False, queue them.", + ) + + +class ServerCapacityStatus(BaseModel): + """Status of server capacity for concurrent sessions.""" + + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + ) + + active_sessions: int = Field( + ge=0, + description="Number of currently active sessions", + ) + max_sessions: int = Field( + ge=1, + description="Maximum number of allowed sessions", + ) + available_slots: int = Field( + ge=0, + description="Number of available session slots", + ) + is_at_capacity: bool = Field( + description="Whether the server has reached maximum capacity", + ) + + @classmethod + def from_counts(cls, active: int, max_sessions: int) -> "ServerCapacityStatus": + """Create status from active and max session counts.""" + available = max(0, max_sessions - active) + return cls( + active_sessions=active, + max_sessions=max_sessions, + available_slots=available, + is_at_capacity=active >= max_sessions, + ) + + +class SessionInfo(BaseModel): + """Information about an active session.""" + + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + ) + + session_id: str = Field(description="Unique identifier for the session") + created_at: float = 
Field(description="Unix timestamp when the session was created") + last_activity_at: float = Field( + description="Unix timestamp of the last activity in the session" + ) + step_count: int = Field( + default=0, + ge=0, + description="Number of steps executed in this session", + ) + environment_type: str = Field( + description="Type name of the environment class for this session" + ) From 95563b0afdeb8806d37ded906544ddc9f6aceaad Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Sat, 6 Dec 2025 09:57:43 +0100 Subject: [PATCH 03/41] add async to http server --- src/openenv/core/env_server/http_server.py | 50 +++++++++++++--------- 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/src/openenv/core/env_server/http_server.py b/src/openenv/core/env_server/http_server.py index 7fa7c0f32..d301fa7e9 100644 --- a/src/openenv/core/env_server/http_server.py +++ b/src/openenv/core/env_server/http_server.py @@ -84,8 +84,14 @@ def __init__( self.action_cls = action_cls self.observation_cls = observation_cls # Create thread pool for running sync code in async context - # This is needed for environments using sync libraries (e.g., Playwright sync API) - self._executor = ThreadPoolExecutor(max_workers=1) + # This is needed for environments using sync libraries (e.g., Playwright) + # Configurable via OPENENV_THREAD_POOL_SIZE (default: 32) + pool_size = int(os.getenv("OPENENV_THREAD_POOL_SIZE", "32")) + self._executor = ThreadPoolExecutor(max_workers=pool_size) + + # Check if environment has async methods for better concurrency + self._has_step_async = hasattr(env, "step_async") and asyncio.iscoroutinefunction(env.step_async) + self._has_reset_async = hasattr(env, "reset_async") and asyncio.iscoroutinefunction(env.reset_async) async def _run_sync_in_thread_pool(self, func, *args, **kwargs): """Run a synchronous function in the thread pool executor.""" @@ -99,9 +105,7 @@ def _get_valid_kwargs(self, sig, kwargs, skip_params=None): valid_kwargs = {} - has_kwargs = any( - p.kind 
== inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() - ) + has_kwargs = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()) for k, v in kwargs.items(): if k in sig.parameters or has_kwargs: @@ -128,13 +132,17 @@ async def reset_handler( kwargs = request.model_dump(exclude_unset=True) # Pass arguments only if environment accepts them - sig = inspect.signature(self.env.reset) + if self._has_reset_async: + sig = inspect.signature(self.env.reset_async) + else: + sig = inspect.signature(self.env.reset) valid_kwargs = self._get_valid_kwargs(sig, kwargs) - # Run synchronous reset in thread pool to avoid blocking event loop - observation = await self._run_sync_in_thread_pool( - self.env.reset, **valid_kwargs - ) + # Use async method if available for better concurrency + if self._has_reset_async: + observation = await self.env.reset_async(**valid_kwargs) + else: + observation = await self._run_sync_in_thread_pool(self.env.reset, **valid_kwargs) return ResetResponse(**serialize_observation(observation)) # Helper function to handle step endpoint @@ -147,22 +155,24 @@ async def step_handler(request: StepRequest) -> StepResponse: action = deserialize_action(action_data, self.action_cls) except ValidationError as e: # Return HTTP 422 with detailed validation errors - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, detail=e.errors() - ) + raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, detail=e.errors()) # Handle optional parameters # Start with all fields from the request, including extra ones, but exclude 'action' kwargs = request.model_dump(exclude_unset=True, exclude={"action"}) # Pass arguments only if environment accepts them - sig = inspect.signature(self.env.step) + if self._has_step_async: + sig = inspect.signature(self.env.step_async) + else: + sig = inspect.signature(self.env.step) valid_kwargs = self._get_valid_kwargs(sig, kwargs, skip_params={"action"}) - # Run synchronous step in 
thread pool to avoid blocking event loop - observation = await self._run_sync_in_thread_pool( - self.env.step, action, **valid_kwargs - ) + # Use async method if available for better concurrency + if self._has_step_async: + observation = await self.env.step_async(action, **valid_kwargs) + else: + observation = await self._run_sync_in_thread_pool(self.env.step, action, **valid_kwargs) # Return serialized observation return StepResponse(**serialize_observation(observation)) @@ -388,9 +398,7 @@ def create_fastapi_app( try: from fastapi import FastAPI except ImportError: - raise ImportError( - "FastAPI is required. Install with: pip install fastapi uvicorn" - ) + raise ImportError("FastAPI is required. Install with: pip install fastapi uvicorn") app = FastAPI( title="OpenEnv Environment HTTP API", From 3601357a9727c75f7a805c6b1364118884ce7ae8 Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Mon, 8 Dec 2025 01:40:40 +0530 Subject: [PATCH 04/41] concurrency config --- src/openenv/core/__init__.py | 13 +- src/openenv/core/env_server/http_server.py | 138 ++++++++++++++++--- src/openenv/core/env_server/serialization.py | 2 +- 3 files changed, 123 insertions(+), 30 deletions(-) diff --git a/src/openenv/core/__init__.py b/src/openenv/core/__init__.py index 3592ead53..93ae09786 100644 --- a/src/openenv/core/__init__.py +++ b/src/openenv/core/__init__.py @@ -7,15 +7,10 @@ """Core components for agentic environments.""" # Re-export main components from submodules for convenience -from .env_server import * -from .client_types import StepResult -from .http_env_client import HTTPEnvClient -from .ws_env_client import WebSocketEnvClient +from .env_server import * # noqa: F403 +from .env_server import __all__ as _env_server_all + # Note: MCP module doesn't export anything yet -__all__ = [ - "HTTPEnvClient", - "WebSocketEnvClient", - "StepResult", -] +__all__ = list(_env_server_all) \ No newline at end of file diff --git 
a/src/openenv/core/env_server/http_server.py b/src/openenv/core/env_server/http_server.py index 517809655..8dd144987 100644 --- a/src/openenv/core/env_server/http_server.py +++ b/src/openenv/core/env_server/http_server.py @@ -99,6 +99,7 @@ def __init__( observation_cls: Type[Observation] = None, max_concurrent_envs: int = 1, skip_concurrency_check: bool = False, + concurrency_config: Optional[ConcurrencyConfig] = None, ): """ Initialize HTTP server wrapper. @@ -112,16 +113,33 @@ def __init__( observation_cls: The Observation subclass this environment returns max_concurrent_envs: Maximum number of concurrent WebSocket sessions. Only applies when env is a factory. Default is 1. + If concurrency_config is provided, this parameter is ignored. skip_concurrency_check: If True, skip concurrency safety validation. Use with caution for advanced users who understand the isolation requirements. + concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. + If provided, overrides max_concurrent_envs and allows + configuration of session timeout and capacity behavior. Raises: ConcurrencyConfigurationError: If max_concurrent_envs > 1 for an environment that is not marked as CONCURRENCY_SAFE. 
""" self._env_factory: Optional[Callable[[], Environment]] = None - self._max_concurrent_envs = max_concurrent_envs + + # Handle concurrency configuration + if concurrency_config is not None: + self._concurrency_config = concurrency_config + self._max_concurrent_envs = concurrency_config.max_concurrent_envs + else: + # Use legacy parameters + self._concurrency_config = ConcurrencyConfig( + max_concurrent_envs=max_concurrent_envs, + session_timeout_seconds=None, + reject_on_capacity=True, + ) + self._max_concurrent_envs = max_concurrent_envs + self._skip_concurrency_check = skip_concurrency_check or os.getenv( "OPENENV_SKIP_CONCURRENCY_CHECK", "" ).lower() in ("1", "true", "yes") @@ -238,10 +256,18 @@ async def _create_session(self) -> tuple[str, Environment]: async with self._session_lock: if len(self._sessions) >= self._max_concurrent_envs: - raise SessionCapacityError( - active_sessions=len(self._sessions), - max_sessions=self._max_concurrent_envs, - ) + if self._concurrency_config.reject_on_capacity: + raise SessionCapacityError( + active_sessions=len(self._sessions), + max_sessions=self._max_concurrent_envs, + ) + else: + # TODO: Implement queuing mechanism when reject_on_capacity=False + raise SessionCapacityError( + active_sessions=len(self._sessions), + max_sessions=self._max_concurrent_envs, + message="Session queuing not yet implemented", + ) session_id = str(uuid.uuid4()) current_time = time.time() @@ -353,6 +379,11 @@ def is_concurrency_safe(self) -> bool: """Return whether the environment is marked as concurrency safe.""" return getattr(self.env, "CONCURRENCY_SAFE", False) + @property + def concurrency_config(self) -> ConcurrencyConfig: + """Return the concurrency configuration.""" + return self._concurrency_config + def register_routes(self, app: FastAPI) -> None: """ Register HTTP routes on a FastAPI application. 
@@ -539,6 +570,25 @@ async def step(request: StepRequest) -> StepResponse: ] register_get_endpoints(app, get_endpoints) + # Register concurrency config endpoint + @app.get( + "/concurrency", + response_model=ConcurrencyConfig, + tags=["Environment Info"], + summary="Get concurrency configuration", + description=""" +Get the current concurrency configuration for this server. + +Returns information about: +- **max_concurrent_envs**: Maximum number of concurrent WebSocket sessions +- **session_timeout_seconds**: Timeout for inactive sessions (None if no timeout) +- **reject_on_capacity**: Whether to reject or queue connections at capacity + """, + ) + async def get_concurrency_config() -> ConcurrencyConfig: + """Return concurrency configuration.""" + return self._concurrency_config + # Register combined schema endpoint @app.get( "/schema", @@ -598,8 +648,8 @@ async def websocket_endpoint(websocket: WebSocket): factory mode) or shares the single instance (backward compatible mode). Message Protocol: - - Client sends: {"type": "reset|step|state|close", "data": {...}} - - Server responds: {"type": "observation|state|error", "data": {...}} + - Client sends: WSResetMessage | WSStepMessage | WSStateMessage | WSCloseMessage + - Server responds: WSObservationResponse | WSStateResponse | WSErrorResponse """ await websocket.accept() @@ -615,7 +665,7 @@ async def websocket_endpoint(websocket: WebSocket): raw_message = await websocket.receive_text() try: - message = json.loads(raw_message) + message_dict = json.loads(raw_message) except json.JSONDecodeError as e: error_resp = WSErrorResponse( data={"message": f"Invalid JSON: {e}", "code": "INVALID_JSON"} @@ -623,14 +673,23 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.send_text(error_resp.model_dump_json()) continue - msg_type = message.get("type", "") - msg_data = message.get("data", {}) + msg_type = message_dict.get("type", "") try: if msg_type == "reset": + # Parse and validate reset message + try: + 
msg = WSResetMessage(**message_dict) + except ValidationError as e: + error_resp = WSErrorResponse( + data={"message": "Invalid reset message", "code": "VALIDATION_ERROR", "errors": e.errors()} + ) + await websocket.send_text(error_resp.model_dump_json()) + continue + # Handle reset sig = inspect.signature(session_env.reset) - valid_kwargs = self._get_valid_kwargs(sig, msg_data) + valid_kwargs = self._get_valid_kwargs(sig, msg.data) observation = await self._run_in_session_executor( session_id, session_env.reset, **valid_kwargs @@ -644,17 +703,19 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.send_text(response.model_dump_json()) elif msg_type == "step": - # Handle step - if not msg_data: + # Parse and validate step message + try: + msg = WSStepMessage(**message_dict) + except ValidationError as e: error_resp = WSErrorResponse( - data={"message": "Missing action data", "code": "MISSING_ACTION"} + data={"message": "Invalid step message", "code": "VALIDATION_ERROR", "errors": e.errors()} ) await websocket.send_text(error_resp.model_dump_json()) continue # Deserialize action with Pydantic validation try: - action = deserialize_action(msg_data, self.action_cls) + action = deserialize_action(msg.data, self.action_cls) except ValidationError as e: error_resp = WSErrorResponse( data={"message": str(e), "code": "VALIDATION_ERROR", "errors": e.errors()} @@ -674,6 +735,16 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.send_text(response.model_dump_json()) elif msg_type == "state": + # Parse and validate state message + try: + msg = WSStateMessage(**message_dict) + except ValidationError as e: + error_resp = WSErrorResponse( + data={"message": "Invalid state message", "code": "VALIDATION_ERROR", "errors": e.errors()} + ) + await websocket.send_text(error_resp.model_dump_json()) + continue + # Handle state request state = session_env.state if hasattr(state, 'model_dump'): @@ -685,6 +756,16 @@ async def 
websocket_endpoint(websocket: WebSocket): await websocket.send_text(response.model_dump_json()) elif msg_type == "close": + # Parse and validate close message + try: + msg = WSCloseMessage(**message_dict) + except ValidationError as e: + error_resp = WSErrorResponse( + data={"message": "Invalid close message", "code": "VALIDATION_ERROR", "errors": e.errors()} + ) + await websocket.send_text(error_resp.model_dump_json()) + continue + # Client requested close break @@ -751,6 +832,7 @@ def create_app( observation_cls: Type[Observation], env_name: Optional[str] = None, max_concurrent_envs: int = 1, + concurrency_config: Optional[ConcurrencyConfig] = None, ) -> FastAPI: """ Create a FastAPI application with or without web interface. @@ -763,7 +845,10 @@ def create_app( action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns env_name: Optional environment name for README loading - max_concurrent_envs: Maximum concurrent WebSocket sessions (default: 1) + max_concurrent_envs: Maximum concurrent WebSocket sessions (default: 1). + Ignored if concurrency_config is provided. + concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. + If provided, overrides max_concurrent_envs. 
Returns: FastAPI application instance with or without web interface and README integration @@ -780,10 +865,16 @@ def create_app( # Import web interface only when needed from .web_interface import create_web_interface_app - return create_web_interface_app(env, action_cls, observation_cls, env_name) + return create_web_interface_app( + env, action_cls, observation_cls, env_name, + max_concurrent_envs, concurrency_config + ) else: # Use standard FastAPI app without web interface - return create_fastapi_app(env, action_cls, observation_cls, max_concurrent_envs) + return create_fastapi_app( + env, action_cls, observation_cls, + max_concurrent_envs, concurrency_config + ) def create_fastapi_app( @@ -791,6 +882,7 @@ def create_fastapi_app( action_cls: Type[Action], observation_cls: Type[Observation], max_concurrent_envs: int = 1, + concurrency_config: Optional[ConcurrencyConfig] = None, ) -> FastAPI: """ Create a FastAPI application with comprehensive documentation. @@ -799,7 +891,10 @@ def create_fastapi_app( env: The Environment instance, factory callable, or class to serve action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns - max_concurrent_envs: Maximum concurrent WebSocket sessions (default: 1) + max_concurrent_envs: Maximum concurrent WebSocket sessions (default: 1). + Ignored if concurrency_config is provided. + concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. + If provided, overrides max_concurrent_envs. 
Returns: FastAPI application instance @@ -869,6 +964,9 @@ def create_fastapi_app( }, ) - server = HTTPEnvServer(env, action_cls, observation_cls, max_concurrent_envs) + server = HTTPEnvServer( + env, action_cls, observation_cls, + max_concurrent_envs, concurrency_config=concurrency_config + ) server.register_routes(app) return app diff --git a/src/openenv/core/env_server/serialization.py b/src/openenv/core/env_server/serialization.py index a97a05283..df06592f5 100644 --- a/src/openenv/core/env_server/serialization.py +++ b/src/openenv/core/env_server/serialization.py @@ -80,7 +80,7 @@ def deserialize_action_with_preprocessing( value = [] if isinstance(value, list): try: - import torch + import torch # type: ignore processed_data[key] = torch.tensor(value, dtype=torch.long) except ImportError: From 0d8fe57b16c29d5223250e095224ef5e3aa3696b Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Mon, 8 Dec 2025 13:54:19 +0100 Subject: [PATCH 05/41] update docs with repo restructure --- docs/environment-builder.md | 69 +++++++++++++++++++------------------ 1 file changed, 36 insertions(+), 33 deletions(-) diff --git a/docs/environment-builder.md b/docs/environment-builder.md index 9fefc9ee1..20a793ced 100644 --- a/docs/environment-builder.md +++ b/docs/environment-builder.md @@ -58,33 +58,26 @@ my_env/ └── Dockerfile ``` -Python classes are generated for the action, observation, and state, and a client is generated for the environment. For example, you will find `MyEnvironment`, `MyAction`, `MyObservation`, and `MyState` in the `my_env` directory based on the name of the environment you provided. +Python classes are generated for the action, observation, environment, and client. For example, you will find `MyEnvironment`, `MyAction`, `MyObservation`, and `MyEnv` (client) in the `my_env` directory based on the name you provided. The environment uses the core `State` class from `openenv.core.env_server.types`. ### 2. 
Define Models -Edit `models.py` to describe your action, observation, and state dataclasses: +Edit `models.py` to describe your action and observation using Pydantic: ```python # models.py -from dataclasses import dataclass -from openenv.core.env_server import Action, Observation, State +from pydantic import Field +from openenv.core.env_server.types import Action, Observation -@dataclass class MyAction(Action): """Your custom action.""" - command: str - parameters: dict + command: str = Field(..., description="Command to execute") + parameters: dict = Field(default_factory=dict, description="Command parameters") -@dataclass class MyObservation(Observation): """Your custom observation.""" - result: str - success: bool - -@dataclass -class MyState(State): - """Custom state fields.""" - custom_field: int = 0 + result: str = Field(..., description="Result of the action") + success: bool = Field(..., description="Whether the action succeeded") ``` ### 3. Implement Environment Logic @@ -93,42 +86,42 @@ Customize `server/my_environment.py` by extending `Environment`: ```python # server/my_environment.py -import uuid -from openenv.core.env_server import Environment -from ..models import MyAction, MyObservation, MyState +from uuid import uuid4 +from openenv.core.env_server.interfaces import Environment +from openenv.core.env_server.types import State +from models import MyAction, MyObservation class MyEnvironment(Environment): def __init__(self): - super().__init__() - self._state = MyState() + self._state = State(episode_id=str(uuid4()), step_count=0) def reset(self) -> MyObservation: - self._state = MyState(episode_id=str(uuid.uuid4())) - return MyObservation(result="Ready", success=True) + self._state = State(episode_id=str(uuid4()), step_count=0) + return MyObservation(result="Ready", success=True, done=False, reward=0.0) def step(self, action: MyAction) -> MyObservation: # Implement your logic here self._state.step_count += 1 result = 
self._execute_command(action.command) - return MyObservation(result=result, success=True) + return MyObservation(result=result, success=True, done=False, reward=1.0) @property - def state(self) -> MyState: + def state(self) -> State: return self._state ``` ### 4. Create the FastAPI Server -`server/app.py` should expose the environment through `create_fastapi_app`: +`server/app.py` should expose the environment through `create_app`: ```python # server/app.py -from openenv.core.env_server import create_fastapi_app -from ..models import MyAction, MyObservation +from openenv.core.env_server.http_server import create_app +from my_env.models import MyAction, MyObservation from .my_environment import MyEnvironment env = MyEnvironment() -app = create_fastapi_app(env, MyAction, MyObservation) +app = create_app(env, MyAction, MyObservation, env_name="my_env") ``` ### 5. Implement the Client @@ -138,23 +131,33 @@ app = create_fastapi_app(env, MyAction, MyObservation) ```python # client.py from openenv.core.http_env_client import HTTPEnvClient -from openenv.core.types import StepResult -from .models import MyAction, MyObservation, MyState +from openenv.core.client_types import StepResult +from openenv.core.env_server.types import State +from .models import MyAction, MyObservation class MyEnv(HTTPEnvClient[MyAction, MyObservation]): def _step_payload(self, action: MyAction) -> dict: return {"command": action.command, "parameters": action.parameters} def _parse_result(self, payload: dict) -> StepResult[MyObservation]: - obs = MyObservation(**payload["observation"]) + obs_data = payload.get("observation", {}) + obs = MyObservation( + result=obs_data.get("result", ""), + success=obs_data.get("success", False), + done=payload.get("done", False), + reward=payload.get("reward"), + ) return StepResult( observation=obs, reward=payload.get("reward"), done=payload.get("done", False), ) - def _parse_state(self, payload: dict) -> MyState: - return MyState(**payload) + def 
_parse_state(self, payload: dict) -> State: + return State( + episode_id=payload.get("episode_id"), + step_count=payload.get("step_count", 0), + ) ``` ### 6. Configure Dependencies & Dockerfile From 360f878d845677a44f352abeb49020c4777cfbed Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Mon, 8 Dec 2025 13:54:31 +0100 Subject: [PATCH 06/41] update echo with pydantic --- envs/echo_env/models.py | 10 ++++------ src/openenv/cli/templates/openenv_env/models.py | 10 ++++------ 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/envs/echo_env/models.py b/envs/echo_env/models.py index 4cbf1016c..3032b7511 100644 --- a/envs/echo_env/models.py +++ b/envs/echo_env/models.py @@ -10,7 +10,7 @@ The Echo environment is a simple test environment that echoes back messages. """ -from dataclasses import dataclass +from pydantic import Field # Support both in-repo and standalone imports try: @@ -21,16 +21,14 @@ from openenv.core.env_server.types import Action, Observation -@dataclass(kw_only=True) class EchoAction(Action): """Action for the Echo environment - just a message to echo.""" - message: str + message: str = Field(..., min_length=1, description="Message to echo back") -@dataclass(kw_only=True) class EchoObservation(Observation): """Observation from the Echo environment - the echoed message.""" - echoed_message: str - message_length: int = 0 \ No newline at end of file + echoed_message: str = Field(..., description="The echoed message from the environment") + message_length: int = Field(default=0, ge=0, description="Length of the echoed message") \ No newline at end of file diff --git a/src/openenv/cli/templates/openenv_env/models.py b/src/openenv/cli/templates/openenv_env/models.py index 64010449b..57e2d1fca 100644 --- a/src/openenv/cli/templates/openenv_env/models.py +++ b/src/openenv/cli/templates/openenv_env/models.py @@ -10,22 +10,20 @@ The __ENV_NAME__ environment is a simple test environment that echoes back messages. 
""" -from dataclasses import dataclass +from pydantic import Field from openenv.core.env_server.types import Action, Observation -@dataclass(kw_only=True) class __ENV_CLASS_NAME__Action(Action): """Action for the __ENV_TITLE_NAME__ environment - just a message to echo.""" - message: str + message: str = Field(..., description="Message to echo back") -@dataclass(kw_only=True) class __ENV_CLASS_NAME__Observation(Observation): """Observation from the __ENV_TITLE_NAME__ environment - the echoed message.""" - echoed_message: str - message_length: int = 0 + echoed_message: str = Field(..., description="The echoed message") + message_length: int = Field(default=0, description="Length of the echoed message") From 600acb41e952525bbb564ae3fbeb8559f3131694 Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Mon, 8 Dec 2025 18:59:59 +0530 Subject: [PATCH 07/41] chore: add websockets to pyproject.toml --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 811c068c9..edb6c1f17 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,8 @@ dependencies = [ "huggingface_hub>=0.20.0", "openai>=2.7.2", "tomli>=2.3.0", - "tomli-w>=1.2.0" + "tomli-w>=1.2.0", + "websockets>=15.0.1", ] [project.optional-dependencies] From a98851a2e5a1ce12b13595f95aa632f2c19f0fd4 Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Wed, 10 Dec 2025 14:45:35 +0100 Subject: [PATCH 08/41] add concurrency safe pram --- .../openenv_env/server/__ENV_NAME___environment.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/openenv/cli/templates/openenv_env/server/__ENV_NAME___environment.py b/src/openenv/cli/templates/openenv_env/server/__ENV_NAME___environment.py index e2a9ce0b7..72db6472f 100644 --- a/src/openenv/cli/templates/openenv_env/server/__ENV_NAME___environment.py +++ b/src/openenv/cli/templates/openenv_env/server/__ENV_NAME___environment.py @@ -36,6 +36,12 @@ class 
__ENV_CLASS_NAME__Environment(Environment): >>> print(obs.message_length) # 5 """ + # Enable concurrent WebSocket sessions. + # Set to True if your environment isolates state between instances. + # When True, multiple WebSocket clients can connect simultaneously, each + # getting their own environment instance (when using factory mode in app.py). + CONCURRENCY_SAFE: bool = True + def __init__(self): """Initialize the __ENV_NAME__ environment.""" self._state = State(episode_id=str(uuid4()), step_count=0) From 8197d6f29c1f3dd6a8b7abdc364c69cd33354429 Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Wed, 10 Dec 2025 14:45:54 +0100 Subject: [PATCH 09/41] use factory in template app --- .../cli/templates/openenv_env/server/app.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/openenv/cli/templates/openenv_env/server/app.py b/src/openenv/cli/templates/openenv_env/server/app.py index db216fb06..87e3db6dc 100644 --- a/src/openenv/cli/templates/openenv_env/server/app.py +++ b/src/openenv/cli/templates/openenv_env/server/app.py @@ -8,7 +8,14 @@ FastAPI application for the __ENV_TITLE_NAME__ Environment. This module creates an HTTP server that exposes the __ENV_CLASS_NAME__Environment -over HTTP endpoints, making it compatible with HTTPEnvClient. +over HTTP and WebSocket endpoints, compatible with HTTPEnvClient and WebSocketEnvClient. 
+ +Endpoints: + - POST /reset: Reset the environment + - POST /step: Execute an action + - GET /state: Get current environment state + - GET /schema: Get action/observation schemas + - WS /ws: WebSocket endpoint for persistent sessions Usage: # Development (with auto-reload): @@ -31,15 +38,14 @@ from __ENV_NAME__.models import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation from .__ENV_NAME___environment import __ENV_CLASS_NAME__Environment -# Create the environment instance -env = __ENV_CLASS_NAME__Environment() # Create the app with web interface and README integration app = create_app( - env, + __ENV_CLASS_NAME__Environment, __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation, env_name="__ENV_NAME__", + max_concurrent_envs=1, # increase this number to allow more concurrent WebSocket sessions ) From f72b6dad63275127b536c6448daa7f9a4730d4c5 Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Wed, 10 Dec 2025 14:46:16 +0100 Subject: [PATCH 10/41] us WS in client --- .../cli/templates/openenv_env/client.py | 95 ++++++++++++++++++- 1 file changed, 92 insertions(+), 3 deletions(-) diff --git a/src/openenv/cli/templates/openenv_env/client.py b/src/openenv/cli/templates/openenv_env/client.py index 703b28a85..0775f2536 100644 --- a/src/openenv/cli/templates/openenv_env/client.py +++ b/src/openenv/cli/templates/openenv_env/client.py @@ -5,10 +5,11 @@ # LICENSE file in the root directory of this source tree. """ -__ENV_TITLE_NAME__ Environment HTTP Client. +__ENV_TITLE_NAME__ Environment Clients. -This module provides the client for connecting to a __ENV_TITLE_NAME__ Environment server -over HTTP. 
+This module provides clients for connecting to a __ENV_TITLE_NAME__ Environment server: +- __ENV_CLASS_NAME__Env: HTTP client for request/response interactions +- __ENV_CLASS_NAME__EnvWS: WebSocket client for persistent sessions """ from typing import Any, Dict @@ -16,6 +17,7 @@ from openenv.core.client_types import StepResult from openenv.core.env_server.types import State from openenv.core.http_env_client import HTTPEnvClient +from openenv.core.ws_env_client import WebSocketEnvClient from .models import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation @@ -98,3 +100,90 @@ def _parse_state(self, payload: Dict) -> State: episode_id=payload.get("episode_id"), step_count=payload.get("step_count", 0), ) + + +class __ENV_CLASS_NAME__EnvWS(WebSocketEnvClient[__ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation]): + """ + WebSocket client for the __ENV_TITLE_NAME__ Environment. + + This client maintains a persistent WebSocket connection to the environment server, + enabling efficient multi-step interactions with lower latency than HTTP. + Each client instance has its own dedicated environment session on the server. + + Advantages over HTTP client: + - Lower latency for sequential interactions (no connection overhead per request) + - Session state is maintained server-side + - Better suited for long-running episodes + + Example: + >>> # Connect to a running server via WebSocket + >>> with __ENV_CLASS_NAME__EnvWS(base_url="http://localhost:8000") as client: + ... result = client.reset() + ... print(result.observation.echoed_message) + ... + ... result = client.step(__ENV_CLASS_NAME__Action(message="Hello!")) + ... print(result.observation.echoed_message) + + Example with Docker: + >>> # Automatically start container and connect via WebSocket + >>> client = __ENV_CLASS_NAME__EnvWS.from_docker_image("__ENV_NAME__-env:latest") + >>> try: + ... result = client.reset() + ... result = client.step(__ENV_CLASS_NAME__Action(message="Test")) + ... finally: + ... 
client.close() + """ + + def _step_payload(self, action: __ENV_CLASS_NAME__Action) -> Dict: + """ + Convert __ENV_CLASS_NAME__Action to JSON payload for step message. + + Args: + action: __ENV_CLASS_NAME__Action instance + + Returns: + Dictionary representation suitable for JSON encoding + """ + return { + "message": action.message, + } + + def _parse_result(self, payload: Dict) -> StepResult[__ENV_CLASS_NAME__Observation]: + """ + Parse WebSocket response into StepResult[__ENV_CLASS_NAME__Observation]. + + Args: + payload: JSON response data from server + + Returns: + StepResult with __ENV_CLASS_NAME__Observation + """ + obs_data = payload.get("observation", {}) + observation = __ENV_CLASS_NAME__Observation( + echoed_message=obs_data.get("echoed_message", ""), + message_length=obs_data.get("message_length", 0), + done=payload.get("done", False), + reward=payload.get("reward"), + metadata=obs_data.get("metadata", {}), + ) + + return StepResult( + observation=observation, + reward=payload.get("reward"), + done=payload.get("done", False), + ) + + def _parse_state(self, payload: Dict) -> State: + """ + Parse WebSocket state response into State object. 
+ + Args: + payload: JSON response from state request + + Returns: + State object with episode_id and step_count + """ + return State( + episode_id=payload.get("episode_id"), + step_count=payload.get("step_count", 0), + ) From 26b1148eab604a566c00b29a651a2a0a7bed2fb5 Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Wed, 10 Dec 2025 14:46:22 +0100 Subject: [PATCH 11/41] expose ws classes --- src/openenv/cli/templates/openenv_env/__init__.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/openenv/cli/templates/openenv_env/__init__.py b/src/openenv/cli/templates/openenv_env/__init__.py index 656800a55..aed293ba8 100644 --- a/src/openenv/cli/templates/openenv_env/__init__.py +++ b/src/openenv/cli/templates/openenv_env/__init__.py @@ -6,8 +6,12 @@ """__ENV_TITLE_NAME__ Environment - A simple test environment for HTTP server.""" -from .client import __ENV_CLASS_NAME__Env +from .client import __ENV_CLASS_NAME__Env, __ENV_CLASS_NAME__EnvWS from .models import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation -__all__ = ["__ENV_CLASS_NAME__Action", "__ENV_CLASS_NAME__Observation", "__ENV_CLASS_NAME__Env"] - +__all__ = [ + "__ENV_CLASS_NAME__Action", + "__ENV_CLASS_NAME__Observation", + "__ENV_CLASS_NAME__Env", + "__ENV_CLASS_NAME__EnvWS", +] From 1ddd8d8537f29c8360b255bcb0200c7a6395a0b7 Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Wed, 10 Dec 2025 14:46:49 +0100 Subject: [PATCH 12/41] add websocket examples to template readme --- .../cli/templates/openenv_env/README.md | 60 ++++++++++++++++++- 1 file changed, 58 insertions(+), 2 deletions(-) diff --git a/src/openenv/cli/templates/openenv_env/README.md b/src/openenv/cli/templates/openenv_env/README.md index ef238dfb7..f6a5c0292 100644 --- a/src/openenv/cli/templates/openenv_env/README.md +++ b/src/openenv/cli/templates/openenv_env/README.md @@ -114,6 +114,7 @@ The deployed space includes: - **Web Interface** at `/web` - Interactive UI for exploring the environment - **API 
Documentation** at `/docs` - Full OpenAPI/Swagger interface - **Health Check** at `/health` - Container health monitoring +- **WebSocket** at `/ws` - Persistent session endpoint for low-latency interactions ## Environment Details @@ -154,6 +155,61 @@ result = __ENV_NAME__env.step(__ENV_CLASS_NAME__Action(message="Hello!")) Note: When connecting to an existing server, `__ENV_NAME__env.close()` will NOT stop the server. +### WebSocket Client for Persistent Sessions + +For long-running episodes or when you need lower latency, use the WebSocket client: + +```python +from __ENV_NAME__ import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__EnvWS + +# Connect via WebSocket (maintains persistent connection) +with __ENV_CLASS_NAME__EnvWS(base_url="http://localhost:8000") as env: + result = env.reset() + print(f"Reset: {result.observation.echoed_message}") + # Multiple steps with low latency + for msg in ["Hello", "World", "!"]: + result = env.step(__ENV_CLASS_NAME__Action(message=msg)) + print(f"Echoed: {result.observation.echoed_message}") +``` + +WebSocket advantages: +- **Lower latency**: No HTTP connection overhead per request +- **Persistent session**: Server maintains your environment state +- **Efficient for episodes**: Better for many sequential steps + +### Concurrent WebSocket Sessions + +The server supports multiple concurrent WebSocket connections. 
To enable this, +modify `server/app.py` to use factory mode: + +```python +# In server/app.py - use factory mode for concurrent sessions +app = create_app( + __ENV_CLASS_NAME__Environment, # Pass class, not instance + __ENV_CLASS_NAME__Action, + __ENV_CLASS_NAME__Observation, + max_concurrent_envs=4, # Allow 4 concurrent sessions +) +``` + +Then multiple clients can connect simultaneously: + +```python +from __ENV_NAME__ import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__EnvWS +from concurrent.futures import ThreadPoolExecutor + +def run_episode(client_id: int): + with __ENV_CLASS_NAME__EnvWS(base_url="http://localhost:8000") as env: + result = env.reset() + for i in range(10): + result = env.step(__ENV_CLASS_NAME__Action(message=f"Client {client_id}, step {i}")) + return client_id, result.observation.message_length + +# Run 4 episodes concurrently +with ThreadPoolExecutor(max_workers=4) as executor: + results = list(executor.map(run_episode, range(4))) +``` + ## Development & Testing ### Direct Environment Testing @@ -189,11 +245,11 @@ __ENV_NAME__/ ├── openenv.yaml # OpenEnv manifest ├── pyproject.toml # Project metadata and dependencies ├── uv.lock # Locked dependencies (generated) -├── client.py # __ENV_CLASS_NAME__Env client implementation +├── client.py # __ENV_CLASS_NAME__Env (HTTP) and __ENV_CLASS_NAME__EnvWS (WebSocket) clients ├── models.py # Action and Observation models └── server/ ├── __init__.py # Server module exports ├── __ENV_NAME___environment.py # Core environment logic - ├── app.py # FastAPI application + ├── app.py # FastAPI application (HTTP + WebSocket endpoints) └── Dockerfile # Container image definition ``` From 7138716eef49164e612e637fff40576d850762de Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Wed, 10 Dec 2025 15:23:18 +0100 Subject: [PATCH 13/41] add note to toml for github install --- src/openenv/cli/templates/openenv_env/pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git 
a/src/openenv/cli/templates/openenv_env/pyproject.toml b/src/openenv/cli/templates/openenv_env/pyproject.toml index 55b90113f..4c6b948ff 100644 --- a/src/openenv/cli/templates/openenv_env/pyproject.toml +++ b/src/openenv/cli/templates/openenv_env/pyproject.toml @@ -15,6 +15,8 @@ description = "__ENV_TITLE_NAME__ environment for OpenEnv" requires-python = ">=3.10" dependencies = [ # Core OpenEnv runtime (provides FastAPI server + HTTP client types) + # install from github + # "openenv[core] @ git+https://github.com/meta-pytorch/OpenEnv.git", "openenv[core]>=0.2.0", # Environment-specific dependencies # Add all dependencies needed for your environment here From 438f96647c63c317f55abf3992fbcd9930209a83 Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Wed, 10 Dec 2025 22:34:18 +0530 Subject: [PATCH 14/41] refactor: enforce env factory usage and drop instance mode --- src/openenv/core/env_server/http_server.py | 76 ++++++-------------- src/openenv/core/env_server/web_interface.py | 21 ++++-- 2 files changed, 36 insertions(+), 61 deletions(-) diff --git a/src/openenv/core/env_server/http_server.py b/src/openenv/core/env_server/http_server.py index 8dd144987..fd25739b2 100644 --- a/src/openenv/core/env_server/http_server.py +++ b/src/openenv/core/env_server/http_server.py @@ -76,13 +76,9 @@ class HTTPEnvServer: >>> from core.env_server import HTTPEnvServer >>> from envs.coding_env.server import CodeExecutionEnvironment >>> - >>> # Single environment (backward compatible) - >>> env = CodeExecutionEnvironment() - >>> server = HTTPEnvServer(env) - >>> - >>> # Factory pattern for concurrent sessions + >>> # Pass environment class (factory pattern) >>> server = HTTPEnvServer( - ... env=CodeExecutionEnvironment, # Pass class, not instance + ... env=CodeExecutionEnvironment, ... max_concurrent_envs=4, ... 
) >>> @@ -94,9 +90,9 @@ class HTTPEnvServer: def __init__( self, - env: Union[Environment, Callable[[], Environment], Type[Environment]], - action_cls: Type[Action] = None, - observation_cls: Type[Observation] = None, + env: Union[Callable[[], Environment], Type[Environment]], + action_cls: Type[Action], + observation_cls: Type[Observation], max_concurrent_envs: int = 1, skip_concurrency_check: bool = False, concurrency_config: Optional[ConcurrencyConfig] = None, @@ -105,14 +101,11 @@ def __init__( Initialize HTTP server wrapper. Args: - env: The Environment instance, factory callable, or class to wrap. - - If an instance is provided, it's used directly (single-env mode) - - If a callable/class is provided, it's called to create new - environments for each WebSocket session (factory mode) + env: Environment factory (callable or class) that creates new instances. + Will be called to create a new environment for each WebSocket session. action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns - max_concurrent_envs: Maximum number of concurrent WebSocket sessions. - Only applies when env is a factory. Default is 1. + max_concurrent_envs: Maximum number of concurrent WebSocket sessions (default: 1). If concurrency_config is provided, this parameter is ignored. skip_concurrency_check: If True, skip concurrency safety validation. Use with caution for advanced users who understand @@ -125,7 +118,14 @@ def __init__( ConcurrencyConfigurationError: If max_concurrent_envs > 1 for an environment that is not marked as CONCURRENCY_SAFE. """ - self._env_factory: Optional[Callable[[], Environment]] = None + # Validate that env is callable + if not callable(env): + raise TypeError( + f"env must be a callable (class or factory function), got {type(env)}. " + f"Pass the environment class (e.g., MyEnvironment) not an instance (e.g., MyEnvironment())." 
+ ) + + self._env_factory: Callable[[], Environment] = env # Handle concurrency configuration if concurrency_config is not None: @@ -144,24 +144,7 @@ def __init__( "OPENENV_SKIP_CONCURRENCY_CHECK", "" ).lower() in ("1", "true", "yes") - # Determine if env is an instance or factory - if isinstance(env, Environment): - # Single instance mode (backward compatible) - self.env = env - self._env_factory = None - elif callable(env): - # Factory mode - env is a class or callable - self._env_factory = env - # Create a single instance for HTTP endpoints (backward compat) - try: - self.env = env() - except Exception as e: - factory_name = getattr(env, "__name__", str(env)) - raise EnvironmentFactoryError(factory_name, e) from e - else: - raise TypeError( - f"env must be an Environment instance or callable, got {type(env)}" - ) + self.env = env() # Validate concurrency configuration self._validate_concurrency_safety() @@ -272,22 +255,7 @@ async def _create_session(self) -> tuple[str, Environment]: session_id = str(uuid.uuid4()) current_time = time.time() - if self._env_factory is None: - # Single instance mode - use shared env (limited concurrency) - if self._sessions: - raise SessionCapacityError( - active_sessions=len(self._sessions), - max_sessions=1, - message="Single instance mode: only one WebSocket session allowed", - ) - env = self.env - else: - # Factory mode - create new environment - try: - env = self._env_factory() - except Exception as e: - factory_name = getattr(self._env_factory, "__name__", str(self._env_factory)) - raise EnvironmentFactoryError(factory_name, e) from e + env = self._env_factory() self._sessions[session_id] = env @@ -827,7 +795,7 @@ async def websocket_endpoint(websocket: WebSocket): def create_app( - env: Union[Environment, Callable[[], Environment], Type[Environment]], + env: Union[Callable[[], Environment], Type[Environment]], action_cls: Type[Action], observation_cls: Type[Observation], env_name: Optional[str] = None, @@ -841,7 +809,7 @@ def 
create_app( including README integration for better user experience. Args: - env: The Environment instance, factory callable, or class to serve + env: Environment factory (callable or class) that creates new instances action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns env_name: Optional environment name for README loading @@ -878,7 +846,7 @@ def create_app( def create_fastapi_app( - env: Union[Environment, Callable[[], Environment], Type[Environment]], + env: Union[Callable[[], Environment], Type[Environment]], action_cls: Type[Action], observation_cls: Type[Observation], max_concurrent_envs: int = 1, @@ -888,7 +856,7 @@ def create_fastapi_app( Create a FastAPI application with comprehensive documentation. Args: - env: The Environment instance, factory callable, or class to serve + env: Environment factory (callable or class) that creates new instances action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns max_concurrent_envs: Maximum concurrent WebSocket sessions (default: 1). 
diff --git a/src/openenv/core/env_server/web_interface.py b/src/openenv/core/env_server/web_interface.py index b370cfa53..52ce4a113 100644 --- a/src/openenv/core/env_server/web_interface.py +++ b/src/openenv/core/env_server/web_interface.py @@ -14,7 +14,7 @@ from __future__ import annotations import json -from typing import Any, Dict, List, Optional, Type +from typing import Any, Callable, Dict, List, Optional, Type, Union from datetime import datetime from fastapi import FastAPI, WebSocket, WebSocketDisconnect @@ -23,7 +23,7 @@ from .interfaces import Environment from .serialization import deserialize_action_with_preprocessing, serialize_observation -from .types import Action, Observation, State, EnvironmentMetadata +from .types import Action, Observation, State, EnvironmentMetadata, ConcurrencyConfig def load_environment_metadata( @@ -251,19 +251,23 @@ def get_state(self) -> Dict[str, Any]: def create_web_interface_app( - env: Environment, + env: Union[Callable[[], Environment], Type[Environment]], action_cls: Type[Action], observation_cls: Type[Observation], env_name: Optional[str] = None, + max_concurrent_envs: int = 1, + concurrency_config: Optional[ConcurrencyConfig] = None, ) -> FastAPI: """ Create a FastAPI application with web interface for the given environment. 
Args: - env: The Environment instance to serve + env: Environment factory (callable or class) that creates new instances action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns env_name: Optional environment name for README loading + max_concurrent_envs: Maximum concurrent WebSocket sessions (default: 1) + concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings Returns: FastAPI application instance with web interface @@ -271,13 +275,16 @@ def create_web_interface_app( from .http_server import create_fastapi_app # Create the base environment app - app = create_fastapi_app(env, action_cls, observation_cls) + app = create_fastapi_app(env, action_cls, observation_cls, max_concurrent_envs, concurrency_config) + + # Create a test instance for metadata + env_instance = env() # Load environment metadata - metadata = load_environment_metadata(env, env_name) + metadata = load_environment_metadata(env_instance, env_name) # Create web interface manager - web_manager = WebInterfaceManager(env, action_cls, observation_cls, metadata) + web_manager = WebInterfaceManager(env_instance, action_cls, observation_cls, metadata) # Add web interface routes @app.get("/web", response_class=HTMLResponse) From 7319be0aa2e3a382366fcc18601fdff259c02097 Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Wed, 10 Dec 2025 22:37:47 +0530 Subject: [PATCH 15/41] refactor(ws): replace WSMessage with typed BaseMessage + discriminated WSIncomingMessage --- src/openenv/core/env_server/__init__.py | 6 +++-- src/openenv/core/env_server/types.py | 32 +++++++++++++++---------- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/src/openenv/core/env_server/__init__.py b/src/openenv/core/env_server/__init__.py index e1014540e..ed0d41278 100644 --- a/src/openenv/core/env_server/__init__.py +++ b/src/openenv/core/env_server/__init__.py @@ -21,7 +21,8 @@ State, 
SchemaResponse, HealthResponse, - WSMessage, + BaseMessage, + WSIncomingMessage, WSResetMessage, WSStepMessage, WSStateMessage, @@ -56,7 +57,8 @@ "SchemaResponse", "HealthResponse", # WebSocket message types - "WSMessage", + "BaseMessage", + "WSIncomingMessage", "WSResetMessage", "WSStepMessage", "WSStateMessage", diff --git a/src/openenv/core/env_server/types.py b/src/openenv/core/env_server/types.py index 39074595f..279726f6d 100644 --- a/src/openenv/core/env_server/types.py +++ b/src/openenv/core/env_server/types.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Union, Literal, Annotated from pydantic import BaseModel, Field, ConfigDict @@ -213,46 +213,52 @@ class HealthResponse(BaseModel): status: str = Field(description="Health status of the environment server") -class WSMessage(BaseModel): - """Base class for WebSocket messages.""" + +class BaseMessage(BaseModel): + """Base class for WebSocket messages with shared configuration.""" model_config = ConfigDict( extra="forbid", validate_assignment=True, ) - type: str = Field(description="Message type identifier") - -class WSResetMessage(WSMessage): +class WSResetMessage(BaseMessage): """WebSocket message to reset the environment.""" - type: str = Field(default="reset", description="Message type") + type: Literal["reset"] = Field(default="reset", description="Message type") data: Dict[str, Any] = Field( default_factory=dict, description="Optional reset parameters (seed, episode_id, etc.)", ) -class WSStepMessage(WSMessage): +class WSStepMessage(BaseMessage): """WebSocket message to execute a step.""" - type: str = Field(default="step", description="Message type") + type: Literal["step"] = Field(default="step", description="Message type") data: Dict[str, Any] = Field( ..., description="Action data conforming to environment's action 
schema" ) -class WSStateMessage(WSMessage): +class WSStateMessage(BaseMessage): """WebSocket message to request current state.""" - type: str = Field(default="state", description="Message type") + type: Literal["state"] = Field(default="state", description="Message type") -class WSCloseMessage(WSMessage): +class WSCloseMessage(BaseMessage): """WebSocket message to close the session.""" - type: str = Field(default="close", description="Message type") + type: Literal["close"] = Field(default="close", description="Message type") + + +# Discriminated union for incoming WebSocket messages +WSIncomingMessage = Annotated[ + WSResetMessage | WSStepMessage | WSStateMessage | WSCloseMessage, + Field(discriminator="type") +] class WSObservationResponse(BaseModel): From 561f9023b73eda7bf303c65326c2468bf4562848 Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Wed, 10 Dec 2025 22:38:09 +0530 Subject: [PATCH 16/41] refactor: remove redundant ConcurrencySafetyLevel --- src/openenv/core/env_server/types.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/openenv/core/env_server/types.py b/src/openenv/core/env_server/types.py index 279726f6d..3c7d18b05 100644 --- a/src/openenv/core/env_server/types.py +++ b/src/openenv/core/env_server/types.py @@ -288,18 +288,6 @@ class WSErrorResponse(BaseModel): data: Dict[str, Any] = Field(description="Error details including message and code") -class ConcurrencySafetyLevel(str): - """ - Classification of environment concurrency safety. - - Environments are classified based on their ability to safely handle - multiple concurrent sessions within a single container. 
- """ - - UNSAFE = "unsafe" - SAFE = "safe" - - class ConcurrencyConfig(BaseModel): """Configuration for concurrent environment sessions.""" From c90cca06b614c242d770b2741044a03e093b6dc2 Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Wed, 10 Dec 2025 20:11:08 +0100 Subject: [PATCH 17/41] update web interface --- src/openenv/core/env_server/web_interface.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/openenv/core/env_server/web_interface.py b/src/openenv/core/env_server/web_interface.py index b370cfa53..d1b527f14 100644 --- a/src/openenv/core/env_server/web_interface.py +++ b/src/openenv/core/env_server/web_interface.py @@ -255,6 +255,8 @@ def create_web_interface_app( action_cls: Type[Action], observation_cls: Type[Observation], env_name: Optional[str] = None, + max_concurrent_envs: int = 1, + concurrency_config: Optional[Any] = None, ) -> FastAPI: """ Create a FastAPI application with web interface for the given environment. @@ -264,14 +266,21 @@ def create_web_interface_app( action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns env_name: Optional environment name for README loading + max_concurrent_envs: Maximum concurrent WebSocket sessions (default: 1). + Ignored if concurrency_config is provided. + concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. + If provided, overrides max_concurrent_envs. 
Returns: FastAPI application instance with web interface """ from .http_server import create_fastapi_app - # Create the base environment app - app = create_fastapi_app(env, action_cls, observation_cls) + # Create the base environment app with concurrency settings + app = create_fastapi_app( + env, action_cls, observation_cls, + max_concurrent_envs, concurrency_config + ) # Load environment metadata metadata = load_environment_metadata(env, env_name) From f57b36f615061184374cafab290eaedf631d4a32 Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Wed, 10 Dec 2025 20:18:26 +0100 Subject: [PATCH 18/41] make web interface compatible with websockets --- src/openenv/core/env_server/web_interface.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/openenv/core/env_server/web_interface.py b/src/openenv/core/env_server/web_interface.py index d1b527f14..404abba35 100644 --- a/src/openenv/core/env_server/web_interface.py +++ b/src/openenv/core/env_server/web_interface.py @@ -262,7 +262,7 @@ def create_web_interface_app( Create a FastAPI application with web interface for the given environment. 
Args: - env: The Environment instance to serve + env: The Environment instance, factory callable, or class to serve action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns env_name: Optional environment name for README loading @@ -282,11 +282,22 @@ def create_web_interface_app( max_concurrent_envs, concurrency_config ) + # If env is a class/factory, instantiate it for the web interface + # (the HTTPEnvServer in create_fastapi_app handles this separately) + if isinstance(env, Environment): + env_instance = env + elif callable(env): + env_instance = env() + else: + raise TypeError( + f"env must be an Environment instance or callable, got {type(env)}" + ) + # Load environment metadata - metadata = load_environment_metadata(env, env_name) + metadata = load_environment_metadata(env_instance, env_name) # Create web interface manager - web_manager = WebInterfaceManager(env, action_cls, observation_cls, metadata) + web_manager = WebInterfaceManager(env_instance, action_cls, observation_cls, metadata) # Add web interface routes @app.get("/web", response_class=HTMLResponse) From bd2a1636a1376ceccab0b38c8ae04ffee1650329 Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Wed, 10 Dec 2025 20:19:06 +0100 Subject: [PATCH 19/41] format --- src/openenv/core/env_server/web_interface.py | 69 +++++--------------- 1 file changed, 17 insertions(+), 52 deletions(-) diff --git a/src/openenv/core/env_server/web_interface.py b/src/openenv/core/env_server/web_interface.py index 404abba35..119845177 100644 --- a/src/openenv/core/env_server/web_interface.py +++ b/src/openenv/core/env_server/web_interface.py @@ -26,9 +26,7 @@ from .types import Action, Observation, State, EnvironmentMetadata -def load_environment_metadata( - env: Environment, env_name: Optional[str] = None -) -> EnvironmentMetadata: +def load_environment_metadata(env: Environment, env_name: Optional[str] = None) -> EnvironmentMetadata: """ Load environment 
metadata including README content. @@ -106,9 +104,7 @@ class ActionLog(BaseModel): timestamp: str = Field(description="Timestamp when action was taken") action: Dict[str, Any] = Field(description="Action that was taken") observation: Dict[str, Any] = Field(description="Observation returned from action") - reward: Optional[float] = Field( - default=None, description="Reward received from action" - ) + reward: Optional[float] = Field(default=None, description="Reward received from action") done: bool = Field(description="Whether the episode is done after this action") step_count: int = Field(description="Step count when this action was taken") @@ -120,15 +116,9 @@ class EpisodeState(BaseModel): episode_id: Optional[str] = Field(default=None, description="Current episode ID") step_count: int = Field(description="Current step count in episode") - current_observation: Optional[Dict[str, Any]] = Field( - default=None, description="Current observation" - ) - action_logs: List[ActionLog] = Field( - default_factory=list, description="List of action logs" - ) - is_reset: bool = Field( - default=True, description="Whether the episode has been reset" - ) + current_observation: Optional[Dict[str, Any]] = Field(default=None, description="Current observation") + action_logs: List[ActionLog] = Field(default_factory=list, description="List of action logs") + is_reset: bool = Field(default=True, description="Whether the episode has been reset") class WebInterfaceManager: @@ -211,9 +201,7 @@ async def reset_environment(self) -> Dict[str, Any]: async def step_environment(self, action_data: Dict[str, Any]) -> Dict[str, Any]: """Execute a step in the environment and update state.""" # Deserialize action with preprocessing for web interface special cases - action: Action = deserialize_action_with_preprocessing( - action_data, self.action_cls - ) + action: Action = deserialize_action_with_preprocessing(action_data, self.action_cls) # Execute step observation: Observation = 
self.env.step(action) @@ -277,10 +265,7 @@ def create_web_interface_app( from .http_server import create_fastapi_app # Create the base environment app with concurrency settings - app = create_fastapi_app( - env, action_cls, observation_cls, - max_concurrent_envs, concurrency_config - ) + app = create_fastapi_app(env, action_cls, observation_cls, max_concurrent_envs, concurrency_config) # If env is a class/factory, instantiate it for the web interface # (the HTTPEnvServer in create_fastapi_app handles this separately) @@ -289,9 +274,7 @@ def create_web_interface_app( elif callable(env): env_instance = env() else: - raise TypeError( - f"env must be an Environment instance or callable, got {type(env)}" - ) + raise TypeError(f"env must be an Environment instance or callable, got {type(env)}") # Load environment metadata metadata = load_environment_metadata(env_instance, env_name) @@ -348,9 +331,7 @@ async def web_state(): return app -def get_web_interface_html( - action_cls: Type[Action], metadata: Optional[EnvironmentMetadata] = None -) -> str: +def get_web_interface_html(action_cls: Type[Action], metadata: Optional[EnvironmentMetadata] = None) -> str: """Generate the HTML for the web interface.""" # Check if this is a chat environment by looking for tokens field @@ -1332,9 +1313,7 @@ def _extract_action_fields(action_cls: Type[Action]) -> List[Dict[str, Any]]: return action_fields -def _determine_input_type_from_schema( - field_info: Dict[str, Any], field_name: str -) -> str: +def _determine_input_type_from_schema(field_info: Dict[str, Any], field_name: str) -> str: """Determine the appropriate HTML input type from JSON schema info.""" schema_type = field_info.get("type") @@ -1406,15 +1385,9 @@ def _markdown_to_html(markdown: str) -> str: html_content = html.escape(markdown) # Convert headers - html_content = re.sub( - r"^# (.*?)$", r"

\1

", html_content, flags=re.MULTILINE - ) - html_content = re.sub( - r"^## (.*?)$", r"

\1

", html_content, flags=re.MULTILINE - ) - html_content = re.sub( - r"^### (.*?)$", r"

\1

", html_content, flags=re.MULTILINE - ) + html_content = re.sub(r"^# (.*?)$", r"

\1

", html_content, flags=re.MULTILINE) + html_content = re.sub(r"^## (.*?)$", r"

\1

", html_content, flags=re.MULTILINE) + html_content = re.sub(r"^### (.*?)$", r"

\1

", html_content, flags=re.MULTILINE) # Convert code blocks html_content = re.sub( @@ -1430,12 +1403,8 @@ def _markdown_to_html(markdown: str) -> str: html_content = re.sub(r"\*(.*?)\*", r"\1", html_content) # Convert lists - html_content = re.sub( - r"^- (.*?)$", r"
  • \1
  • ", html_content, flags=re.MULTILINE - ) - html_content = re.sub( - r"(
  • .*
  • )", r"
      \1
    ", html_content, flags=re.DOTALL - ) + html_content = re.sub(r"^- (.*?)$", r"
  • \1
  • ", html_content, flags=re.MULTILINE) + html_content = re.sub(r"(
  • .*
  • )", r"
      \1
    ", html_content, flags=re.DOTALL) # Convert line breaks html_content = html_content.replace("\n", "
    ") @@ -1443,9 +1412,7 @@ def _markdown_to_html(markdown: str) -> str: return html_content -def _generate_action_interface( - action_fields: List[Dict[str, Any]], is_chat_env: bool -) -> str: +def _generate_action_interface(action_fields: List[Dict[str, Any]], is_chat_env: bool) -> str: """Generate either a chat interface or action form based on environment type.""" if is_chat_env: return _generate_chat_interface() @@ -1569,9 +1536,7 @@ def _generate_single_field(field: Dict[str, Any]) -> str: for choice in choices: selected = "selected" if str(choice) == str(default_value) else "" - options_html.append( - f'' - ) + options_html.append(f'') return f'''
    From 3e116f8c0526a361ea22db977ca5cb1be0b9c5b5 Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Wed, 10 Dec 2025 21:22:38 +0100 Subject: [PATCH 20/41] relative imports in template --- src/openenv/cli/templates/openenv_env/server/app.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/openenv/cli/templates/openenv_env/server/app.py b/src/openenv/cli/templates/openenv_env/server/app.py index 87e3db6dc..5100b1050 100644 --- a/src/openenv/cli/templates/openenv_env/server/app.py +++ b/src/openenv/cli/templates/openenv_env/server/app.py @@ -35,7 +35,8 @@ "openenv is required for the web interface. Install dependencies with '\n uv sync\n'" ) from e -from __ENV_NAME__.models import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation +# Import from local models.py (PYTHONPATH includes /app/env in Docker) +from models import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation from .__ENV_NAME___environment import __ENV_CLASS_NAME__Environment From 25b7cfaf26e62a6495121482983adf285c00f21a Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Wed, 10 Dec 2025 21:22:55 +0100 Subject: [PATCH 21/41] use pydantic in template --- src/openenv/cli/templates/openenv_env/models.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/openenv/cli/templates/openenv_env/models.py b/src/openenv/cli/templates/openenv_env/models.py index 64010449b..4540d5a29 100644 --- a/src/openenv/cli/templates/openenv_env/models.py +++ b/src/openenv/cli/templates/openenv_env/models.py @@ -10,22 +10,20 @@ The __ENV_NAME__ environment is a simple test environment that echoes back messages. 
""" -from dataclasses import dataclass +from pydantic import Field from openenv.core.env_server.types import Action, Observation -@dataclass(kw_only=True) class __ENV_CLASS_NAME__Action(Action): """Action for the __ENV_TITLE_NAME__ environment - just a message to echo.""" - message: str + message: str = Field(..., description="Message to echo back") -@dataclass(kw_only=True) class __ENV_CLASS_NAME__Observation(Observation): """Observation from the __ENV_TITLE_NAME__ environment - the echoed message.""" - echoed_message: str - message_length: int = 0 + echoed_message: str = Field(default="", description="The echoed message") + message_length: int = Field(default=0, description="Length of the echoed message") From 8f23dc42f175bcf6d7e9c774c91590d03b7be87b Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Thu, 11 Dec 2025 21:29:07 +0530 Subject: [PATCH 22/41] rename to session_timeout --- src/openenv/core/env_server/http_server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/openenv/core/env_server/http_server.py b/src/openenv/core/env_server/http_server.py index fd25739b2..bc2a09040 100644 --- a/src/openenv/core/env_server/http_server.py +++ b/src/openenv/core/env_server/http_server.py @@ -135,7 +135,7 @@ def __init__( # Use legacy parameters self._concurrency_config = ConcurrencyConfig( max_concurrent_envs=max_concurrent_envs, - session_timeout_seconds=None, + session_timeout=None, reject_on_capacity=True, ) self._max_concurrent_envs = max_concurrent_envs @@ -549,7 +549,7 @@ async def step(request: StepRequest) -> StepResponse: Returns information about: - **max_concurrent_envs**: Maximum number of concurrent WebSocket sessions -- **session_timeout_seconds**: Timeout for inactive sessions (None if no timeout) +- **session_timeout**: Timeout in seconds for inactive sessions (None if no timeout) - **reject_on_capacity**: Whether to reject or queue connections at capacity """, ) From 
0d56b834c142295795e1e410018892b16adb69ba Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Thu, 11 Dec 2025 21:30:13 +0530 Subject: [PATCH 23/41] ConcurrencyConfig, ServerCapacityStatus, and SessionInfo inherit from BaseMessage --- src/openenv/core/env_server/types.py | 23 ++++------------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/src/openenv/core/env_server/types.py b/src/openenv/core/env_server/types.py index 3c7d18b05..0821437fc 100644 --- a/src/openenv/core/env_server/types.py +++ b/src/openenv/core/env_server/types.py @@ -288,21 +288,16 @@ class WSErrorResponse(BaseModel): data: Dict[str, Any] = Field(description="Error details including message and code") -class ConcurrencyConfig(BaseModel): +class ConcurrencyConfig(BaseMessage): """Configuration for concurrent environment sessions.""" - model_config = ConfigDict( - extra="forbid", - validate_assignment=True, - ) - max_concurrent_envs: int = Field( default=1, ge=1, le=1000, description="Maximum number of concurrent WebSocket sessions allowed", ) - session_timeout_seconds: Optional[float] = Field( + session_timeout: Optional[float] = Field( default=None, gt=0, description="Timeout in seconds for inactive sessions. 
None means no timeout.", @@ -313,14 +308,9 @@ class ConcurrencyConfig(BaseModel): ) -class ServerCapacityStatus(BaseModel): +class ServerCapacityStatus(BaseMessage): """Status of server capacity for concurrent sessions.""" - model_config = ConfigDict( - extra="forbid", - validate_assignment=True, - ) - active_sessions: int = Field( ge=0, description="Number of currently active sessions", @@ -349,14 +339,9 @@ def from_counts(cls, active: int, max_sessions: int) -> "ServerCapacityStatus": ) -class SessionInfo(BaseModel): +class SessionInfo(BaseMessage): """Information about an active session.""" - model_config = ConfigDict( - extra="forbid", - validate_assignment=True, - ) - session_id: str = Field(description="Unique identifier for the session") created_at: float = Field(description="Unix timestamp when the session was created") last_activity_at: float = Field( From 9cd2aacbba661f971152fcb17ea892fc1040a0a1 Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Thu, 11 Dec 2025 21:47:59 +0530 Subject: [PATCH 24/41] message classes to inherit from BaseMessage for shared config --- src/openenv/core/env_server/types.py | 48 ++++++++-------------------- 1 file changed, 14 insertions(+), 34 deletions(-) diff --git a/src/openenv/core/env_server/types.py b/src/openenv/core/env_server/types.py index 0821437fc..4d0cacb70 100644 --- a/src/openenv/core/env_server/types.py +++ b/src/openenv/core/env_server/types.py @@ -127,6 +127,15 @@ class StepResponse(BaseModel): done: bool = Field(default=False, description="Whether the episode has terminated") +class BaseMessage(BaseModel): + """Base class for WebSocket messages with shared configuration.""" + + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + ) + + class State(BaseModel): """Base class for environment state. 
@@ -149,27 +158,17 @@ class State(BaseModel): ) -class CodeExecResult(BaseModel): +class CodeExecResult(BaseMessage): """Result of code execution containing stdout, stderr, and exit code.""" - model_config = ConfigDict( - extra="forbid", - validate_assignment=True, - ) - stdout: str = Field(description="Standard output from code execution") stderr: str = Field(description="Standard error from code execution") exit_code: int = Field(description="Exit code from code execution") -class EnvironmentMetadata(BaseModel): +class EnvironmentMetadata(BaseMessage): """Metadata about an environment for documentation and UI purposes.""" - model_config = ConfigDict( - extra="forbid", - validate_assignment=True, - ) - name: str = Field(description="Name of the environment") description: str = Field(description="Description of what the environment does") readme_content: Optional[str] = Field( @@ -184,14 +183,9 @@ class EnvironmentMetadata(BaseModel): ) -class SchemaResponse(BaseModel): +class SchemaResponse(BaseMessage): """Response model for the combined schema endpoint.""" - model_config = ConfigDict( - extra="forbid", - validate_assignment=True, - ) - action: Dict[str, Any] = Field( description="JSON schema for actions accepted by this environment" ) @@ -203,26 +197,12 @@ class SchemaResponse(BaseModel): ) -class HealthResponse(BaseModel): +class HealthResponse(BaseMessage): """Response model for health check endpoint.""" - model_config = ConfigDict( - extra="forbid", - validate_assignment=True, - ) - status: str = Field(description="Health status of the environment server") -class BaseMessage(BaseModel): - """Base class for WebSocket messages with shared configuration.""" - - model_config = ConfigDict( - extra="forbid", - validate_assignment=True, - ) - - class WSResetMessage(BaseMessage): """WebSocket message to reset the environment.""" @@ -257,7 +237,7 @@ class WSCloseMessage(BaseMessage): # Discriminated union for incoming WebSocket messages WSIncomingMessage = Annotated[ 
WSResetMessage | WSStepMessage | WSStateMessage | WSCloseMessage, - Field(discriminator="type") + Field(discriminator="type"), ] From 77a8c832bbe68a3a2e9d2f7528bc97219c4725f0 Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Thu, 11 Dec 2025 22:19:45 +0530 Subject: [PATCH 25/41] refactor: rename CONCURRENCY_SAFE to SUPPORTS_CONCURRENT_SESSIONS --- .../server/__ENV_NAME___environment.py | 2 +- src/openenv/core/env_server/exceptions.py | 6 +- src/openenv/core/env_server/http_server.py | 283 +++++++++++------- src/openenv/core/env_server/interfaces.py | 2 +- src/openenv/core/env_server/types.py | 33 +- 5 files changed, 197 insertions(+), 129 deletions(-) diff --git a/src/openenv/cli/templates/openenv_env/server/__ENV_NAME___environment.py b/src/openenv/cli/templates/openenv_env/server/__ENV_NAME___environment.py index 72db6472f..454ea6808 100644 --- a/src/openenv/cli/templates/openenv_env/server/__ENV_NAME___environment.py +++ b/src/openenv/cli/templates/openenv_env/server/__ENV_NAME___environment.py @@ -40,7 +40,7 @@ class __ENV_CLASS_NAME__Environment(Environment): # Set to True if your environment isolates state between instances. # When True, multiple WebSocket clients can connect simultaneously, each # getting their own environment instance (when using factory mode in app.py). - CONCURRENCY_SAFE: bool = True + SUPPORTS_CONCURRENT_SESSIONS: bool = True def __init__(self): """Initialize the __ENV_NAME__ environment.""" diff --git a/src/openenv/core/env_server/exceptions.py b/src/openenv/core/env_server/exceptions.py index 41a8235bb..a16715721 100644 --- a/src/openenv/core/env_server/exceptions.py +++ b/src/openenv/core/env_server/exceptions.py @@ -20,7 +20,7 @@ class ConcurrencyConfigurationError(OpenEnvError): Raised when an environment is misconfigured for concurrent sessions. This error is raised during server startup when max_concurrent_envs > 1 - is specified for an environment that is not marked as CONCURRENCY_SAFE. 
+ is specified for an environment that is not marked as SUPPORTS_CONCURRENT_SESSIONS. """ def __init__( @@ -34,10 +34,10 @@ def __init__( if message is None: message = ( - f"Environment '{environment_name}' is not marked as CONCURRENCY_SAFE. " + f"Environment '{environment_name}' is not marked as SUPPORTS_CONCURRENT_SESSIONS. " f"Cannot run with max_concurrent_envs={max_concurrent_envs}. " f"Either set max_concurrent_envs=1 or ensure the environment " - f"properly isolates session state and set CONCURRENCY_SAFE=True." + f"properly isolates session state and set SUPPORTS_CONCURRENT_SESSIONS=True." ) super().__init__(message) diff --git a/src/openenv/core/env_server/http_server.py b/src/openenv/core/env_server/http_server.py index bc2a09040..3752bb50a 100644 --- a/src/openenv/core/env_server/http_server.py +++ b/src/openenv/core/env_server/http_server.py @@ -20,7 +20,7 @@ import os import uuid from concurrent.futures import ThreadPoolExecutor -from typing import Any, Callable, Dict, Optional, Type, Union +from typing import Any, Awaitable, Callable, Dict, Optional, Type, Union, cast from fastapi import Body, FastAPI, HTTPException, WebSocket, WebSocketDisconnect, status from pydantic import ValidationError @@ -113,10 +113,10 @@ def __init__( concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. If provided, overrides max_concurrent_envs and allows configuration of session timeout and capacity behavior. - + Raises: ConcurrencyConfigurationError: If max_concurrent_envs > 1 for an - environment that is not marked as CONCURRENCY_SAFE. + environment that is not marked as SUPPORTS_CONCURRENT_SESSIONS. """ # Validate that env is callable if not callable(env): @@ -124,9 +124,9 @@ def __init__( f"env must be a callable (class or factory function), got {type(env)}. " f"Pass the environment class (e.g., MyEnvironment) not an instance (e.g., MyEnvironment())." 
) - + self._env_factory: Callable[[], Environment] = env - + # Handle concurrency configuration if concurrency_config is not None: self._concurrency_config = concurrency_config @@ -139,51 +139,63 @@ def __init__( reject_on_capacity=True, ) self._max_concurrent_envs = max_concurrent_envs - + self._skip_concurrency_check = skip_concurrency_check or os.getenv( "OPENENV_SKIP_CONCURRENCY_CHECK", "" ).lower() in ("1", "true", "yes") - + self.env = env() - + # Validate concurrency configuration self._validate_concurrency_safety() - + self.action_cls = action_cls self.observation_cls = observation_cls - + # Session management for WebSocket connections self._sessions: Dict[str, Environment] = {} self._session_executors: Dict[str, ThreadPoolExecutor] = {} self._session_info: Dict[str, SessionInfo] = {} self._session_lock = asyncio.Lock() - + # Create thread pool for running sync code in async context # This is needed for environments using sync libraries (e.g., Playwright) # Configurable via OPENENV_THREAD_POOL_SIZE (default: 32) pool_size = int(os.getenv("OPENENV_THREAD_POOL_SIZE", "32")) self._executor = ThreadPoolExecutor(max_workers=pool_size) - # Check if environment has async methods for better concurrency - self._has_step_async = hasattr(env, "step_async") and asyncio.iscoroutinefunction(env.step_async) - self._has_reset_async = hasattr(env, "reset_async") and asyncio.iscoroutinefunction(env.reset_async) + self._reset_async: Optional[Callable[..., Awaitable[Observation]]] = None + if hasattr(self.env, "reset_async"): + reset_method = getattr(self.env, "reset_async") + if asyncio.iscoroutinefunction(reset_method): + self._reset_async = cast( + Callable[..., Awaitable[Observation]], reset_method + ) + + self._step_async: Optional[Callable[..., Awaitable[Observation]]] = None + if hasattr(self.env, "step_async"): + step_method = getattr(self.env, "step_async") + if asyncio.iscoroutinefunction(step_method): + self._step_async = cast( + Callable[..., 
Awaitable[Observation]], step_method + ) def _validate_concurrency_safety(self) -> None: """ Validate that the environment supports the configured concurrency level. - + Raises: ConcurrencyConfigurationError: If max_concurrent_envs > 1 for an - environment that is not marked as CONCURRENCY_SAFE. + environment that is not marked as SUPPORTS_CONCURRENT_SESSIONS. """ if self._max_concurrent_envs <= 1: return - + if self._skip_concurrency_check: return - - is_concurrency_safe = getattr(self.env, "CONCURRENCY_SAFE", False) - + + is_concurrency_safe = getattr(self.env, "SUPPORTS_CONCURRENT_SESSIONS", False) + if not is_concurrency_safe: env_name = type(self.env).__name__ raise ConcurrencyConfigurationError( @@ -194,7 +206,7 @@ def _validate_concurrency_safety(self) -> None: def get_capacity_status(self) -> ServerCapacityStatus: """ Get the current capacity status of the server. - + Returns: ServerCapacityStatus with current session counts and availability. """ @@ -203,19 +215,28 @@ def get_capacity_status(self) -> ServerCapacityStatus: max_sessions=self._max_concurrent_envs, ) - async def _run_sync_in_thread_pool(self, func, *args, **kwargs): + async def _run_sync_in_thread_pool( + self, func: Callable[..., Observation], *args, **kwargs + ) -> Observation: """Run a synchronous function in the thread pool executor.""" loop = asyncio.get_event_loop() return await loop.run_in_executor(self._executor, lambda: func(*args, **kwargs)) - def _get_valid_kwargs(self, sig, kwargs, skip_params=None): + def _get_valid_kwargs( + self, + sig: inspect.Signature, + kwargs: Dict[str, Any], + skip_params: Optional[set[str]] = None, + ) -> Dict[str, Any]: """Filter kwargs to only include parameters accepted by the function signature.""" if skip_params is None: skip_params = set() valid_kwargs = {} - has_kwargs = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()) + has_kwargs = any( + p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() + ) for 
k, v in kwargs.items(): if k in sig.parameters or has_kwargs: @@ -227,16 +248,16 @@ def _get_valid_kwargs(self, sig, kwargs, skip_params=None): async def _create_session(self) -> tuple[str, Environment]: """ Create a new WebSocket session with its own environment instance. - + Returns: Tuple of (session_id, environment) - + Raises: SessionCapacityError: If max concurrent sessions reached EnvironmentFactoryError: If the factory fails to create an environment """ import time - + async with self._session_lock: if len(self._sessions) >= self._max_concurrent_envs: if self._concurrency_config.reject_on_capacity: @@ -251,17 +272,16 @@ async def _create_session(self) -> tuple[str, Environment]: max_sessions=self._max_concurrent_envs, message="Session queuing not yet implemented", ) - + session_id = str(uuid.uuid4()) current_time = time.time() - + env = self._env_factory() - + self._sessions[session_id] = env - - # Create dedicated executor for this session + self._session_executors[session_id] = ThreadPoolExecutor(max_workers=1) - + # Track session metadata self._session_info[session_id] = SessionInfo( session_id=session_id, @@ -270,73 +290,74 @@ async def _create_session(self) -> tuple[str, Environment]: step_count=0, environment_type=type(env).__name__, ) - + return session_id, env - + async def _destroy_session(self, session_id: str) -> None: """ Destroy a WebSocket session and cleanup resources. 
- + Args: session_id: The session ID to destroy """ async with self._session_lock: if session_id in self._sessions: env = self._sessions.pop(session_id) - # Call close() if environment has it - if hasattr(env, 'close') and callable(env.close): + if hasattr(env, "close") and callable(getattr(env, "close")): try: - env.close() + getattr(env, "close")() except Exception: - pass # Best effort cleanup - + pass + if session_id in self._session_executors: executor = self._session_executors.pop(session_id) executor.shutdown(wait=False) - + # Remove session metadata self._session_info.pop(session_id, None) - - def _update_session_activity(self, session_id: str, increment_step: bool = False) -> None: + + def _update_session_activity( + self, session_id: str, increment_step: bool = False + ) -> None: """ Update session activity timestamp and optionally increment step count. - + Args: session_id: The session ID to update increment_step: If True, increment the step count """ import time - + if session_id in self._session_info: self._session_info[session_id].last_activity_at = time.time() if increment_step: self._session_info[session_id].step_count += 1 - + def get_session_info(self, session_id: str) -> Optional[SessionInfo]: """ Get information about a specific session. 
- + Args: session_id: The session ID to query - + Returns: SessionInfo if the session exists, None otherwise """ return self._session_info.get(session_id) async def _run_in_session_executor( - self, session_id: str, func: Callable, *args, **kwargs - ) -> Any: + self, session_id: str, func: Callable[..., Observation], *args, **kwargs + ) -> Observation: """Run a synchronous function in the session's thread pool executor.""" executor = self._session_executors.get(session_id, self._executor) loop = asyncio.get_event_loop() return await loop.run_in_executor(executor, lambda: func(*args, **kwargs)) - + @property def active_sessions(self) -> int: """Return the number of active WebSocket sessions.""" return len(self._sessions) - + @property def max_concurrent_envs(self) -> int: """Return the maximum number of concurrent environments.""" @@ -345,7 +366,7 @@ def max_concurrent_envs(self) -> int: @property def is_concurrency_safe(self) -> bool: """Return whether the environment is marked as concurrency safe.""" - return getattr(self.env, "CONCURRENCY_SAFE", False) + return getattr(self.env, "SUPPORTS_CONCURRENT_SESSIONS", False) @property def concurrency_config(self) -> ConcurrencyConfig: @@ -369,18 +390,18 @@ async def reset_handler( # Start with all fields from the request, including extra ones kwargs = request.model_dump(exclude_unset=True) - # Pass arguments only if environment accepts them - if self._has_reset_async: - sig = inspect.signature(self.env.reset_async) + if self._reset_async: + sig = inspect.signature(self._reset_async) else: sig = inspect.signature(self.env.reset) valid_kwargs = self._get_valid_kwargs(sig, kwargs) - # Use async method if available for better concurrency - if self._has_reset_async: - observation = await self.env.reset_async(**valid_kwargs) + if self._reset_async: + observation = await self._reset_async(**valid_kwargs) else: - observation = await self._run_sync_in_thread_pool(self.env.reset, **valid_kwargs) + observation = await 
self._run_sync_in_thread_pool( + self.env.reset, **valid_kwargs + ) return ResetResponse(**serialize_observation(observation)) # Helper function to handle step endpoint @@ -393,24 +414,26 @@ async def step_handler(request: StepRequest) -> StepResponse: action = deserialize_action(action_data, self.action_cls) except ValidationError as e: # Return HTTP 422 with detailed validation errors - raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, detail=e.errors()) + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, detail=e.errors() + ) # Handle optional parameters # Start with all fields from the request, including extra ones, but exclude 'action' kwargs = request.model_dump(exclude_unset=True, exclude={"action"}) - # Pass arguments only if environment accepts them - if self._has_step_async: - sig = inspect.signature(self.env.step_async) + if self._step_async: + sig = inspect.signature(self._step_async) else: sig = inspect.signature(self.env.step) valid_kwargs = self._get_valid_kwargs(sig, kwargs, skip_params={"action"}) - # Use async method if available for better concurrency - if self._has_step_async: - observation = await self.env.step_async(action, **valid_kwargs) + if self._step_async: + observation = await self._step_async(action, **valid_kwargs) else: - observation = await self._run_sync_in_thread_pool(self.env.step, action, **valid_kwargs) + observation = await self._run_sync_in_thread_pool( + self.env.step, action, **valid_kwargs + ) # Return serialized observation return StepResponse(**serialize_observation(observation)) @@ -611,38 +634,41 @@ async def get_schemas() -> SchemaResponse: async def websocket_endpoint(websocket: WebSocket): """ WebSocket endpoint for persistent environment sessions. - + Each WebSocket connection gets its own environment instance (when using factory mode) or shares the single instance (backward compatible mode). 
- + Message Protocol: - Client sends: WSResetMessage | WSStepMessage | WSStateMessage | WSCloseMessage - Server responds: WSObservationResponse | WSStateResponse | WSErrorResponse """ await websocket.accept() - + session_id = None session_env = None - + try: # Create session with dedicated environment session_id, session_env = await self._create_session() - + while True: # Receive message from client raw_message = await websocket.receive_text() - + try: message_dict = json.loads(raw_message) except json.JSONDecodeError as e: error_resp = WSErrorResponse( - data={"message": f"Invalid JSON: {e}", "code": "INVALID_JSON"} + data={ + "message": f"Invalid JSON: {e}", + "code": "INVALID_JSON", + } ) await websocket.send_text(error_resp.model_dump_json()) continue - + msg_type = message_dict.get("type", "") - + try: if msg_type == "reset": # Parse and validate reset message @@ -650,105 +676,130 @@ async def websocket_endpoint(websocket: WebSocket): msg = WSResetMessage(**message_dict) except ValidationError as e: error_resp = WSErrorResponse( - data={"message": "Invalid reset message", "code": "VALIDATION_ERROR", "errors": e.errors()} + data={ + "message": "Invalid reset message", + "code": "VALIDATION_ERROR", + "errors": e.errors(), + } ) await websocket.send_text(error_resp.model_dump_json()) continue - + # Handle reset sig = inspect.signature(session_env.reset) valid_kwargs = self._get_valid_kwargs(sig, msg.data) - + observation = await self._run_in_session_executor( session_id, session_env.reset, **valid_kwargs ) - + self._update_session_activity(session_id) - + response = WSObservationResponse( data=serialize_observation(observation) ) await websocket.send_text(response.model_dump_json()) - + elif msg_type == "step": # Parse and validate step message try: msg = WSStepMessage(**message_dict) except ValidationError as e: error_resp = WSErrorResponse( - data={"message": "Invalid step message", "code": "VALIDATION_ERROR", "errors": e.errors()} + data={ + "message": 
"Invalid step message", + "code": "VALIDATION_ERROR", + "errors": e.errors(), + } ) await websocket.send_text(error_resp.model_dump_json()) continue - + # Deserialize action with Pydantic validation try: action = deserialize_action(msg.data, self.action_cls) except ValidationError as e: error_resp = WSErrorResponse( - data={"message": str(e), "code": "VALIDATION_ERROR", "errors": e.errors()} + data={ + "message": str(e), + "code": "VALIDATION_ERROR", + "errors": e.errors(), + } ) await websocket.send_text(error_resp.model_dump_json()) continue - + observation = await self._run_in_session_executor( session_id, session_env.step, action ) - - self._update_session_activity(session_id, increment_step=True) - + + self._update_session_activity( + session_id, increment_step=True + ) + response = WSObservationResponse( data=serialize_observation(observation) ) await websocket.send_text(response.model_dump_json()) - + elif msg_type == "state": # Parse and validate state message try: msg = WSStateMessage(**message_dict) except ValidationError as e: error_resp = WSErrorResponse( - data={"message": "Invalid state message", "code": "VALIDATION_ERROR", "errors": e.errors()} + data={ + "message": "Invalid state message", + "code": "VALIDATION_ERROR", + "errors": e.errors(), + } ) await websocket.send_text(error_resp.model_dump_json()) continue - + # Handle state request state = session_env.state - if hasattr(state, 'model_dump'): + if hasattr(state, "model_dump"): state_data = state.model_dump() else: state_data = dict(state) if state else {} - + response = WSStateResponse(data=state_data) await websocket.send_text(response.model_dump_json()) - + elif msg_type == "close": # Parse and validate close message try: msg = WSCloseMessage(**message_dict) except ValidationError as e: error_resp = WSErrorResponse( - data={"message": "Invalid close message", "code": "VALIDATION_ERROR", "errors": e.errors()} + data={ + "message": "Invalid close message", + "code": "VALIDATION_ERROR", + 
"errors": e.errors(), + } ) await websocket.send_text(error_resp.model_dump_json()) continue - + # Client requested close break - + else: error_resp = WSErrorResponse( - data={"message": f"Unknown message type: {msg_type}", "code": "UNKNOWN_TYPE"} + data={ + "message": f"Unknown message type: {msg_type}", + "code": "UNKNOWN_TYPE", + } ) await websocket.send_text(error_resp.model_dump_json()) - + except Exception as e: error_resp = WSErrorResponse( data={"message": str(e), "code": "EXECUTION_ERROR"} ) await websocket.send_text(error_resp.model_dump_json()) - + except WebSocketDisconnect: pass except SessionCapacityError as e: @@ -834,14 +885,17 @@ def create_app( from .web_interface import create_web_interface_app return create_web_interface_app( - env, action_cls, observation_cls, env_name, - max_concurrent_envs, concurrency_config + env, + action_cls, + observation_cls, + env_name, + max_concurrent_envs, + concurrency_config, ) else: # Use standard FastAPI app without web interface return create_fastapi_app( - env, action_cls, observation_cls, - max_concurrent_envs, concurrency_config + env, action_cls, observation_cls, max_concurrent_envs, concurrency_config ) @@ -854,7 +908,7 @@ def create_fastapi_app( ) -> FastAPI: """ Create a FastAPI application with comprehensive documentation. - + Args: env: Environment factory (callable or class) that creates new instances action_cls: The Action subclass this environment expects @@ -863,14 +917,16 @@ def create_fastapi_app( Ignored if concurrency_config is provided. concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. If provided, overrides max_concurrent_envs. - + Returns: FastAPI application instance """ try: from fastapi import FastAPI except ImportError: - raise ImportError("FastAPI is required. Install with: pip install fastapi uvicorn") + raise ImportError( + "FastAPI is required. 
Install with: pip install fastapi uvicorn" + ) app = FastAPI( title="OpenEnv Environment HTTP API", @@ -933,8 +989,11 @@ def create_fastapi_app( ) server = HTTPEnvServer( - env, action_cls, observation_cls, - max_concurrent_envs, concurrency_config=concurrency_config + env, + action_cls, + observation_cls, + max_concurrent_envs, + concurrency_config=concurrency_config, ) server.register_routes(app) return app diff --git a/src/openenv/core/env_server/interfaces.py b/src/openenv/core/env_server/interfaces.py index 196e7ac82..f147589d3 100644 --- a/src/openenv/core/env_server/interfaces.py +++ b/src/openenv/core/env_server/interfaces.py @@ -104,7 +104,7 @@ class Environment(ABC): """ # Class-level flag indicating whether this environment supports concurrent sessions - CONCURRENCY_SAFE: bool = False + SUPPORTS_CONCURRENT_SESSIONS: bool = False def __init__(self, transform: Transform | None = None): self.transform = transform diff --git a/src/openenv/core/env_server/types.py b/src/openenv/core/env_server/types.py index 4d0cacb70..8993d280c 100644 --- a/src/openenv/core/env_server/types.py +++ b/src/openenv/core/env_server/types.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
from typing import Any, Dict, Optional, Union, Literal, Annotated -from pydantic import BaseModel, Field, ConfigDict +from pydantic import BaseModel, Field, ConfigDict, model_validator # Type aliases @@ -299,23 +299,32 @@ class ServerCapacityStatus(BaseMessage): ge=1, description="Maximum number of allowed sessions", ) - available_slots: int = Field( - ge=0, - description="Number of available session slots", - ) - is_at_capacity: bool = Field( - description="Whether the server has reached maximum capacity", - ) + + @model_validator(mode="after") + def check_capacity_bounds(self) -> "ServerCapacityStatus": + if self.active_sessions > self.max_sessions: + raise ValueError( + f"active_sessions ({self.active_sessions}) cannot exceed " + f"max_sessions ({self.max_sessions})" + ) + return self + + @property + def available_slots(self) -> int: + """Number of available session slots.""" + return self.max_sessions - self.active_sessions + + @property + def is_at_capacity(self) -> bool: + """Whether the server has reached maximum capacity.""" + return self.available_slots == 0 @classmethod def from_counts(cls, active: int, max_sessions: int) -> "ServerCapacityStatus": """Create status from active and max session counts.""" - available = max(0, max_sessions - active) return cls( active_sessions=active, max_sessions=max_sessions, - available_slots=available, - is_at_capacity=active >= max_sessions, ) @@ -333,5 +342,5 @@ class SessionInfo(BaseMessage): description="Number of steps executed in this session", ) environment_type: str = Field( - description="Type name of the environment class for this session" + description="Environment type for this session (e.g. 
`CodingEnv`)" ) From 86a222da891403f4088be524952b803c9be64c7b Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Fri, 12 Dec 2025 21:37:08 +0530 Subject: [PATCH 26/41] refactor: core types for better inference and fix async detection - Genericize `Environment`, `HTTPEnvClient`, and `WebSocketEnvClient` with `ActT`, `ObsT`, and `StateT` to improve type inference in IDEs. - Update client methods to use `Dict[str, Any]` for stricter typing of JSON payloads. - Remove conditional `websockets` import in `ws_env_client.py` and simplify connection logic.
- Fix async method detection in `HTTPEnvServer` to correctly handle factory functions and avoid unnecessary instantiation during --- src/openenv/core/client_types.py | 5 +- src/openenv/core/env_server/http_server.py | 236 +++++++++++---------- src/openenv/core/env_server/interfaces.py | 56 ++++- src/openenv/core/http_env_client.py | 12 +- src/openenv/core/ws_env_client.py | 25 +-- 5 files changed, 186 insertions(+), 148 deletions(-) diff --git a/src/openenv/core/client_types.py b/src/openenv/core/client_types.py index 8808e96bf..c7501c656 100644 --- a/src/openenv/core/client_types.py +++ b/src/openenv/core/client_types.py @@ -1,9 +1,10 @@ # Type definitions for EnvTorch from dataclasses import dataclass -from typing import Any, Generic, Optional, TypeVar +from typing import Generic, Optional, TypeVar # Generic type for observations -ObsT = TypeVar("ObsT") # TypeVar for typehinting in IDEs +ObsT = TypeVar("ObsT") +StateT = TypeVar("StateT") @dataclass diff --git a/src/openenv/core/env_server/http_server.py b/src/openenv/core/env_server/http_server.py index 3752bb50a..56b73b3fa 100644 --- a/src/openenv/core/env_server/http_server.py +++ b/src/openenv/core/env_server/http_server.py @@ -20,7 +20,7 @@ import os import uuid from concurrent.futures import ThreadPoolExecutor -from typing import Any, Awaitable, Callable, Dict, Optional, Type, Union, cast +from typing import Any, Callable, Dict, Optional, Type, Union from fastapi import Body, FastAPI, HTTPException, WebSocket, WebSocketDisconnect, status from pydantic import ValidationError @@ -75,10 +75,13 @@ class HTTPEnvServer: Example: >>> from core.env_server import HTTPEnvServer >>> from envs.coding_env.server import CodeExecutionEnvironment + >>> from envs.coding_env.models import CodeAction, CodeObservation >>> >>> # Pass environment class (factory pattern) >>> server = HTTPEnvServer( ... env=CodeExecutionEnvironment, + ... action_cls=CodeAction, + ... observation_cls=CodeObservation, ... max_concurrent_envs=4, ... 
) >>> @@ -144,8 +147,6 @@ def __init__( "OPENENV_SKIP_CONCURRENCY_CHECK", "" ).lower() in ("1", "true", "yes") - self.env = env() - # Validate concurrency configuration self._validate_concurrency_safety() @@ -164,22 +165,6 @@ def __init__( pool_size = int(os.getenv("OPENENV_THREAD_POOL_SIZE", "32")) self._executor = ThreadPoolExecutor(max_workers=pool_size) - self._reset_async: Optional[Callable[..., Awaitable[Observation]]] = None - if hasattr(self.env, "reset_async"): - reset_method = getattr(self.env, "reset_async") - if asyncio.iscoroutinefunction(reset_method): - self._reset_async = cast( - Callable[..., Awaitable[Observation]], reset_method - ) - - self._step_async: Optional[Callable[..., Awaitable[Observation]]] = None - if hasattr(self.env, "step_async"): - step_method = getattr(self.env, "step_async") - if asyncio.iscoroutinefunction(step_method): - self._step_async = cast( - Callable[..., Awaitable[Observation]], step_method - ) - def _validate_concurrency_safety(self) -> None: """ Validate that the environment supports the configured concurrency level. 
@@ -194,10 +179,17 @@ def _validate_concurrency_safety(self) -> None: if self._skip_concurrency_check: return - is_concurrency_safe = getattr(self.env, "SUPPORTS_CONCURRENT_SESSIONS", False) + if inspect.isclass(self._env_factory): + is_concurrency_safe = getattr(self._env_factory, "SUPPORTS_CONCURRENT_SESSIONS", False) + env_name = self._env_factory.__name__ + else: + _temp_env = self._env_factory() + is_concurrency_safe = getattr(_temp_env, "SUPPORTS_CONCURRENT_SESSIONS", False) + env_name = type(_temp_env).__name__ + _temp_env.close() + del _temp_env if not is_concurrency_safe: - env_name = type(self.env).__name__ raise ConcurrencyConfigurationError( environment_name=env_name, max_concurrent_envs=self._max_concurrent_envs, @@ -303,17 +295,12 @@ async def _destroy_session(self, session_id: str) -> None: async with self._session_lock: if session_id in self._sessions: env = self._sessions.pop(session_id) - if hasattr(env, "close") and callable(getattr(env, "close")): - try: - getattr(env, "close")() - except Exception: - pass + env.close() if session_id in self._session_executors: executor = self._session_executors.pop(session_id) executor.shutdown(wait=False) - # Remove session metadata self._session_info.pop(session_id, None) def _update_session_activity( @@ -366,7 +353,15 @@ def max_concurrent_envs(self) -> int: @property def is_concurrency_safe(self) -> bool: """Return whether the environment is marked as concurrency safe.""" - return getattr(self.env, "SUPPORTS_CONCURRENT_SESSIONS", False) + import inspect + if inspect.isclass(self._env_factory): + return getattr(self._env_factory, "SUPPORTS_CONCURRENT_SESSIONS", False) + else: + _temp_env = self._env_factory() + result = getattr(_temp_env, "SUPPORTS_CONCURRENT_SESSIONS", False) + _temp_env.close() + del _temp_env + return result @property def concurrency_config(self) -> ConcurrencyConfig: @@ -386,57 +381,64 @@ async def reset_handler( request: ResetRequest = Body(default_factory=ResetRequest), ) -> 
ResetResponse: """Reset endpoint - returns initial observation.""" - # Handle optional parameters - # Start with all fields from the request, including extra ones - kwargs = request.model_dump(exclude_unset=True) - - if self._reset_async: - sig = inspect.signature(self._reset_async) - else: - sig = inspect.signature(self.env.reset) - valid_kwargs = self._get_valid_kwargs(sig, kwargs) - - if self._reset_async: - observation = await self._reset_async(**valid_kwargs) - else: - observation = await self._run_sync_in_thread_pool( - self.env.reset, **valid_kwargs - ) - return ResetResponse(**serialize_observation(observation)) + _env = self._env_factory() + + try: + kwargs = request.model_dump(exclude_unset=True) + + is_async = _env.reset_async.__func__ is not Environment.reset_async + + if is_async: + sig = inspect.signature(_env.reset_async) + else: + sig = inspect.signature(_env.reset) + valid_kwargs = self._get_valid_kwargs(sig, kwargs) + + if is_async: + observation = await _env.reset_async(**valid_kwargs) + else: + observation = await self._run_sync_in_thread_pool( + _env.reset, **valid_kwargs + ) + return ResetResponse(**serialize_observation(observation)) + finally: + _env.close() # Helper function to handle step endpoint async def step_handler(request: StepRequest) -> StepResponse: """Step endpoint - executes action and returns observation.""" action_data = request.action - # Deserialize action with Pydantic validation try: action = deserialize_action(action_data, self.action_cls) except ValidationError as e: - # Return HTTP 422 with detailed validation errors raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, detail=e.errors() ) - # Handle optional parameters - # Start with all fields from the request, including extra ones, but exclude 'action' - kwargs = request.model_dump(exclude_unset=True, exclude={"action"}) - - if self._step_async: - sig = inspect.signature(self._step_async) - else: - sig = inspect.signature(self.env.step) - 
valid_kwargs = self._get_valid_kwargs(sig, kwargs, skip_params={"action"}) - - if self._step_async: - observation = await self._step_async(action, **valid_kwargs) - else: - observation = await self._run_sync_in_thread_pool( - self.env.step, action, **valid_kwargs - ) + _env = self._env_factory() + + try: + kwargs = request.model_dump(exclude_unset=True, exclude={"action"}) + + is_async = _env.step_async.__func__ is not Environment.step_async + + if is_async: + sig = inspect.signature(_env.step_async) + else: + sig = inspect.signature(_env.step) + valid_kwargs = self._get_valid_kwargs(sig, kwargs, skip_params={"action"}) + + if is_async: + observation = await _env.step_async(action, **valid_kwargs) + else: + observation = await self._run_sync_in_thread_pool( + _env.step, action, **valid_kwargs + ) - # Return serialized observation - return StepResponse(**serialize_observation(observation)) + return StepResponse(**serialize_observation(observation)) + finally: + _env.close() # Register routes using the helpers @app.post( @@ -522,24 +524,36 @@ async def reset( async def step(request: StepRequest) -> StepResponse: return await step_handler(request) - # Configure and register GET endpoints declaratively + def get_state_handler() -> State: + _env = self._env_factory() + try: + return _env.state + finally: + _env.close() + + def get_metadata_handler() -> EnvironmentMetadata: + _env = self._env_factory() + try: + return _env.get_metadata() + finally: + _env.close() + get_endpoints = [ GetEndpointConfig( path="/state", - handler=lambda: self.env.state, + handler=get_state_handler, response_model=State, tag="State Management", summary="Get current environment state", description=""" Retrieve the current internal state of the environment. -This endpoint allows inspection of the environment state without modifying it. The structure of the state object is defined by the environment's State model. 
""", ), GetEndpointConfig( path="/metadata", - handler=self.env.get_metadata, + handler=get_metadata_handler, response_model=EnvironmentMetadata, tag="Environment Info", summary="Get environment metadata", @@ -686,12 +700,18 @@ async def websocket_endpoint(websocket: WebSocket): continue # Handle reset - sig = inspect.signature(session_env.reset) - valid_kwargs = self._get_valid_kwargs(sig, msg.data) + is_async = session_env.reset_async.__func__ is not Environment.reset_async - observation = await self._run_in_session_executor( - session_id, session_env.reset, **valid_kwargs - ) + if is_async: + sig = inspect.signature(session_env.reset_async) + valid_kwargs = self._get_valid_kwargs(sig, msg.data) + observation = await session_env.reset_async(**valid_kwargs) + else: + sig = inspect.signature(session_env.reset) + valid_kwargs = self._get_valid_kwargs(sig, msg.data) + observation = await self._run_in_session_executor( + session_id, session_env.reset, **valid_kwargs + ) self._update_session_activity(session_id) @@ -729,9 +749,14 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.send_text(error_resp.model_dump_json()) continue - observation = await self._run_in_session_executor( - session_id, session_env.step, action - ) + is_async = session_env.step_async.__func__ is not Environment.step_async + + if is_async: + observation = await session_env.step_async(action) + else: + observation = await self._run_in_session_executor( + session_id, session_env.step, action + ) self._update_session_activity( session_id, increment_step=True @@ -803,46 +828,33 @@ async def websocket_endpoint(websocket: WebSocket): except WebSocketDisconnect: pass except SessionCapacityError as e: - try: - error_resp = WSErrorResponse( - data={ - "message": str(e), - "code": "CAPACITY_REACHED", - "active_sessions": e.active_sessions, - "max_sessions": e.max_sessions, - } - ) - await websocket.send_text(error_resp.model_dump_json()) - except Exception: - pass + error_resp = 
WSErrorResponse( + data={ + "message": str(e), + "code": "CAPACITY_REACHED", + "active_sessions": e.active_sessions, + "max_sessions": e.max_sessions, + } + ) + await websocket.send_text(error_resp.model_dump_json()) except EnvironmentFactoryError as e: - try: - error_resp = WSErrorResponse( - data={ - "message": str(e), - "code": "FACTORY_ERROR", - "factory_name": e.factory_name, - } - ) - await websocket.send_text(error_resp.model_dump_json()) - except Exception: - pass + error_resp = WSErrorResponse( + data={ + "message": str(e), + "code": "FACTORY_ERROR", + "factory_name": e.factory_name, + } + ) + await websocket.send_text(error_resp.model_dump_json()) except Exception as e: - try: - error_resp = WSErrorResponse( - data={"message": str(e), "code": "SESSION_ERROR"} - ) - await websocket.send_text(error_resp.model_dump_json()) - except Exception: - pass + error_resp = WSErrorResponse( + data={"message": str(e), "code": "SESSION_ERROR"} + ) + await websocket.send_text(error_resp.model_dump_json()) finally: - # Cleanup session if session_id: await self._destroy_session(session_id) - try: - await websocket.close() - except Exception: - pass + await websocket.close() def create_app( diff --git a/src/openenv/core/env_server/interfaces.py b/src/openenv/core/env_server/interfaces.py index f147589d3..03f1ddb21 100644 --- a/src/openenv/core/env_server/interfaces.py +++ b/src/openenv/core/env_server/interfaces.py @@ -5,10 +5,14 @@ # LICENSE file in the root directory of this source tree. from abc import ABC, abstractmethod -from typing import Any, Optional, Protocol, TypedDict +from typing import Any, Generic, Optional, Protocol, TypedDict, TypeVar from .types import Action, Observation, State, EnvironmentMetadata +ActT = TypeVar("ActT", bound=Action) +ObsT = TypeVar("ObsT", bound=Observation) +StateT = TypeVar("StateT", bound=State) + class Message(TypedDict): """A message in a conversation. @@ -64,7 +68,7 @@ def decode( ... 
-class Transform(ABC): +class Transform(ABC, Generic[ObsT]): """Transform observations to add rewards, metrics, or other modifications. Transforms follow the TorchRL pattern where they take an observation @@ -73,7 +77,7 @@ class Transform(ABC): """ @abstractmethod - def __call__(self, observation: Observation) -> Observation: + def __call__(self, observation: ObsT) -> ObsT: """Transform an observation. Args: @@ -85,7 +89,7 @@ def __call__(self, observation: Observation) -> Observation: pass -class Environment(ABC): +class Environment(ABC, Generic[ActT, ObsT, StateT]): """Base class for all environment servers following Gym/Gymnasium API. Args: @@ -106,7 +110,7 @@ class Environment(ABC): # Class-level flag indicating whether this environment supports concurrent sessions SUPPORTS_CONCURRENT_SESSIONS: bool = False - def __init__(self, transform: Transform | None = None): + def __init__(self, transform: Optional[Transform[ObsT]] = None): self.transform = transform @abstractmethod @@ -115,23 +119,47 @@ def reset( seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs: Any, - ) -> Observation: + ) -> ObsT: """Reset the environment and return initial observation.""" pass + async def reset_async( + self, + seed: Optional[int] = None, + episode_id: Optional[str] = None, + **kwargs: Any, + ) -> ObsT: + """Async version of reset. Default implementation calls sync reset. + + Override to provide true async implementation. + """ + return self.reset(seed=seed, episode_id=episode_id, **kwargs) + @abstractmethod def step( self, - action: Action, + action: ActT, timeout_s: Optional[float] = None, **kwargs: Any, - ) -> Observation: + ) -> ObsT: """Take a step in the environment.""" pass + async def step_async( + self, + action: ActT, + timeout_s: Optional[float] = None, + **kwargs: Any, + ) -> ObsT: + """Async version of step. Default implementation calls sync step. + + Override to provide true async implementation. 
+ """ + return self.step(action, timeout_s=timeout_s, **kwargs) + @property @abstractmethod - def state(self) -> State: + def state(self) -> StateT: """Get the current environment state.""" pass @@ -151,8 +179,16 @@ def get_metadata(self) -> EnvironmentMetadata: version="1.0.0", ) - def _apply_transform(self, observation: Observation) -> Observation: + def _apply_transform(self, observation: ObsT) -> ObsT: """Apply transform if one is provided.""" if self.transform is not None: return self.transform(observation) return observation + + def close(self) -> None: + """Clean up resources used by the environment. + + Override this method to implement custom cleanup logic. + Called when the environment is being destroyed or reset. + """ + pass diff --git a/src/openenv/core/http_env_client.py b/src/openenv/core/http_env_client.py index 007ef6a5f..0f25363d4 100644 --- a/src/openenv/core/http_env_client.py +++ b/src/openenv/core/http_env_client.py @@ -16,7 +16,7 @@ import requests -from .client_types import StepResult +from .client_types import StepResult, StateT from .containers.runtime import LocalDockerProvider if TYPE_CHECKING: @@ -27,7 +27,7 @@ EnvClientT = TypeVar("EnvClientT", bound="HTTPEnvClient") -class HTTPEnvClient(ABC, Generic[ActT, ObsT]): +class HTTPEnvClient(ABC, Generic[ActT, ObsT, StateT]): def __init__( self, base_url: str, @@ -129,17 +129,17 @@ def from_hub( return cls.from_docker_image(image=base_url, provider=provider) @abstractmethod - def _step_payload(self, action: ActT) -> dict: + def _step_payload(self, action: ActT) -> Dict[str, Any]: """Convert an Action object to the JSON body expected by the env server.""" raise NotImplementedError @abstractmethod - def _parse_result(self, payload: dict) -> StepResult[ObsT]: + def _parse_result(self, payload: Dict[str, Any]) -> StepResult[ObsT]: """Convert a JSON response from the env server to StepResult[ObsT].""" raise NotImplementedError @abstractmethod - def _parse_state(self, payload: dict) -> Any: + def 
_parse_state(self, payload: Dict[str, Any]) -> StateT: """Convert a JSON response from the state endpoint to a State object.""" raise NotImplementedError @@ -203,7 +203,7 @@ def step(self, action: ActT, **kwargs: Any) -> StepResult[ObsT]: r.raise_for_status() return self._parse_result(r.json()) - def state(self) -> Any: + def state(self) -> StateT: """ Get the current environment state from the server. diff --git a/src/openenv/core/ws_env_client.py b/src/openenv/core/ws_env_client.py index c6f054e85..6c1d6a4ab 100644 --- a/src/openenv/core/ws_env_client.py +++ b/src/openenv/core/ws_env_client.py @@ -18,26 +18,21 @@ from abc import ABC, abstractmethod from typing import Any, Dict, Generic, Optional, Type, TYPE_CHECKING, TypeVar -from .client_types import StepResult +from .client_types import StepResult, StateT from .containers.runtime import LocalDockerProvider if TYPE_CHECKING: from .containers.runtime import ContainerProvider from websockets.sync.client import ClientConnection -try: - import websockets - from websockets.sync.client import connect as ws_connect -except ImportError: - websockets = None # type: ignore - ws_connect = None # type: ignore +from websockets.sync.client import connect as ws_connect ActT = TypeVar("ActT") ObsT = TypeVar("ObsT") WSEnvClientT = TypeVar("WSEnvClientT", bound="WebSocketEnvClient") -class WebSocketEnvClient(ABC, Generic[ActT, ObsT]): +class WebSocketEnvClient(ABC, Generic[ActT, ObsT, StateT]): """ WebSocket-based environment client for persistent sessions. @@ -78,12 +73,6 @@ def __init__( message_timeout_s: Timeout for receiving responses to messages provider: Optional container provider for lifecycle management """ - if websockets is None: - raise ImportError( - "websockets library is required for WebSocketEnvClient. 
" - "Install with: pip install websockets" - ) - # Convert HTTP URL to WebSocket URL ws_url = base_url.rstrip("/") if ws_url.startswith("http://"): @@ -220,17 +209,17 @@ def from_hub( return cls.from_docker_image(image=base_url, provider=provider, **kwargs) @abstractmethod - def _step_payload(self, action: ActT) -> dict: + def _step_payload(self, action: ActT) -> Dict[str, Any]: """Convert an Action object to the JSON data expected by the env server.""" raise NotImplementedError @abstractmethod - def _parse_result(self, payload: dict) -> StepResult[ObsT]: + def _parse_result(self, payload: Dict[str, Any]) -> StepResult[ObsT]: """Convert a JSON response from the env server to StepResult[ObsT].""" raise NotImplementedError @abstractmethod - def _parse_state(self, payload: dict) -> Any: + def _parse_state(self, payload: Dict[str, Any]) -> StateT: """Convert a JSON response from the state endpoint to a State object.""" raise NotImplementedError @@ -272,7 +261,7 @@ def step(self, action: ActT, **kwargs: Any) -> StepResult[ObsT]: response = self._send_and_receive(message) return self._parse_result(response.get("data", {})) - def state(self) -> Any: + def state(self) -> StateT: """ Get the current environment state from the server. 
From e95f8b14b9e61100cba7722cd9a984dd7bb72e80 Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Fri, 12 Dec 2025 22:20:39 +0530 Subject: [PATCH 27/41] fix: concurrency handling and improve exception messages --- src/openenv/core/__init__.py | 7 ++- src/openenv/core/env_server/exceptions.py | 5 +- src/openenv/core/env_server/http_server.py | 61 +++++++++++--------- src/openenv/core/env_server/interfaces.py | 2 +- src/openenv/core/env_server/types.py | 1 - src/openenv/core/env_server/web_interface.py | 8 ++- 6 files changed, 45 insertions(+), 39 deletions(-) diff --git a/src/openenv/core/__init__.py b/src/openenv/core/__init__.py index 93ae09786..e9bbf2365 100644 --- a/src/openenv/core/__init__.py +++ b/src/openenv/core/__init__.py @@ -8,9 +8,10 @@ # Re-export main components from submodules for convenience from .env_server import * # noqa: F403 -from .env_server import __all__ as _env_server_all - +from . import env_server +from .ws_env_client import WebSocketEnvClient +from .http_env_client import HTTPEnvClient # Note: MCP module doesn't export anything yet -__all__ = list(_env_server_all) \ No newline at end of file +__all__ = ["WebSocketEnvClient", "HTTPEnvClient"] + env_server.__all__ # type: ignore \ No newline at end of file diff --git a/src/openenv/core/env_server/exceptions.py b/src/openenv/core/env_server/exceptions.py index a16715721..23fed6567 100644 --- a/src/openenv/core/env_server/exceptions.py +++ b/src/openenv/core/env_server/exceptions.py @@ -96,10 +96,9 @@ def __init__(self, reason: str, message: Optional[str] = None): class EnvironmentFactoryError(OpenEnvError): """Raised when the environment factory fails to create an instance.""" - def __init__(self, factory_name: str, cause: Exception): + def __init__(self, factory_name: str): self.factory_name = factory_name - self.cause = cause - message = f"Environment factory '{factory_name}' failed to create instance: {cause}" + message = f"Environment factory 
'{factory_name}' failed to create instance." super().__init__(message) diff --git a/src/openenv/core/env_server/http_server.py b/src/openenv/core/env_server/http_server.py index 56b73b3fa..604600f79 100644 --- a/src/openenv/core/env_server/http_server.py +++ b/src/openenv/core/env_server/http_server.py @@ -96,8 +96,7 @@ def __init__( env: Union[Callable[[], Environment], Type[Environment]], action_cls: Type[Action], observation_cls: Type[Observation], - max_concurrent_envs: int = 1, - skip_concurrency_check: bool = False, + max_concurrent_envs: Optional[int] = None, concurrency_config: Optional[ConcurrencyConfig] = None, ): """ @@ -108,16 +107,13 @@ def __init__( Will be called to create a new environment for each WebSocket session. action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns - max_concurrent_envs: Maximum number of concurrent WebSocket sessions (default: 1). - If concurrency_config is provided, this parameter is ignored. - skip_concurrency_check: If True, skip concurrency safety validation. - Use with caution for advanced users who understand - the isolation requirements. + max_concurrent_envs: Maximum number of concurrent WebSocket sessions. + Mutually exclusive with concurrency_config. concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. - If provided, overrides max_concurrent_envs and allows - configuration of session timeout and capacity behavior. + Mutually exclusive with max_concurrent_envs. Raises: + ValueError: If both max_concurrent_envs and concurrency_config are provided. ConcurrencyConfigurationError: If max_concurrent_envs > 1 for an environment that is not marked as SUPPORTS_CONCURRENT_SESSIONS. 
""" @@ -131,21 +127,29 @@ def __init__( self._env_factory: Callable[[], Environment] = env # Handle concurrency configuration + if max_concurrent_envs is not None and concurrency_config is not None: + raise ValueError( + "Cannot specify both 'max_concurrent_envs' and 'concurrency_config'. " + "Please use only one method to configure concurrency." + ) + if concurrency_config is not None: self._concurrency_config = concurrency_config - self._max_concurrent_envs = concurrency_config.max_concurrent_envs - else: - # Use legacy parameters + elif max_concurrent_envs is not None: self._concurrency_config = ConcurrencyConfig( max_concurrent_envs=max_concurrent_envs, session_timeout=None, reject_on_capacity=True, ) - self._max_concurrent_envs = max_concurrent_envs + else: + # Default configuration + self._concurrency_config = ConcurrencyConfig( + max_concurrent_envs=1, + session_timeout=None, + reject_on_capacity=True, + ) - self._skip_concurrency_check = skip_concurrency_check or os.getenv( - "OPENENV_SKIP_CONCURRENCY_CHECK", "" - ).lower() in ("1", "true", "yes") + self._max_concurrent_envs = self._concurrency_config.max_concurrent_envs # Validate concurrency configuration self._validate_concurrency_safety() @@ -176,9 +180,6 @@ def _validate_concurrency_safety(self) -> None: if self._max_concurrent_envs <= 1: return - if self._skip_concurrency_check: - return - if inspect.isclass(self._env_factory): is_concurrency_safe = getattr(self._env_factory, "SUPPORTS_CONCURRENT_SESSIONS", False) env_name = self._env_factory.__name__ @@ -268,7 +269,11 @@ async def _create_session(self) -> tuple[str, Environment]: session_id = str(uuid.uuid4()) current_time = time.time() - env = self._env_factory() + try: + env = self._env_factory() + except Exception as e: + factory_name = getattr(self._env_factory, "__name__", str(self._env_factory)) + raise EnvironmentFactoryError(factory_name) from e self._sessions[session_id] = env @@ -862,7 +867,7 @@ def create_app( action_cls: Type[Action], 
observation_cls: Type[Observation], env_name: Optional[str] = None, - max_concurrent_envs: int = 1, + max_concurrent_envs: Optional[int] = None, concurrency_config: Optional[ConcurrencyConfig] = None, ) -> FastAPI: """ @@ -876,10 +881,10 @@ def create_app( action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns env_name: Optional environment name for README loading - max_concurrent_envs: Maximum concurrent WebSocket sessions (default: 1). - Ignored if concurrency_config is provided. + max_concurrent_envs: Maximum concurrent WebSocket sessions. + Mutually exclusive with concurrency_config. concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. - If provided, overrides max_concurrent_envs. + Mutually exclusive with max_concurrent_envs. Returns: FastAPI application instance with or without web interface and README integration @@ -915,7 +920,7 @@ def create_fastapi_app( env: Union[Callable[[], Environment], Type[Environment]], action_cls: Type[Action], observation_cls: Type[Observation], - max_concurrent_envs: int = 1, + max_concurrent_envs: Optional[int] = None, concurrency_config: Optional[ConcurrencyConfig] = None, ) -> FastAPI: """ @@ -925,10 +930,10 @@ def create_fastapi_app( env: Environment factory (callable or class) that creates new instances action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns - max_concurrent_envs: Maximum concurrent WebSocket sessions (default: 1). - Ignored if concurrency_config is provided. + max_concurrent_envs: Maximum concurrent WebSocket sessions. + Mutually exclusive with concurrency_config. concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. - If provided, overrides max_concurrent_envs. + Mutually exclusive with max_concurrent_envs. 
Returns: FastAPI application instance diff --git a/src/openenv/core/env_server/interfaces.py b/src/openenv/core/env_server/interfaces.py index 03f1ddb21..c02ba4a05 100644 --- a/src/openenv/core/env_server/interfaces.py +++ b/src/openenv/core/env_server/interfaces.py @@ -96,7 +96,7 @@ class Environment(ABC, Generic[ActT, ObsT, StateT]): transform: Optional transform to apply to observations Class Attributes: - CONCURRENCY_SAFE: Whether this environment supports concurrent sessions. + SUPPORTS_CONCURRENT_SESSIONS: Whether this environment supports concurrent sessions. When True, multiple WebSocket connections can each have their own environment instance (up to max_concurrent_envs). When False (default), the environment should only be used with a single session at a time. diff --git a/src/openenv/core/env_server/types.py b/src/openenv/core/env_server/types.py index 8993d280c..273994479 100644 --- a/src/openenv/core/env_server/types.py +++ b/src/openenv/core/env_server/types.py @@ -274,7 +274,6 @@ class ConcurrencyConfig(BaseMessage): max_concurrent_envs: int = Field( default=1, ge=1, - le=1000, description="Maximum number of concurrent WebSocket sessions allowed", ) session_timeout: Optional[float] = Field( diff --git a/src/openenv/core/env_server/web_interface.py b/src/openenv/core/env_server/web_interface.py index be55b9146..5711d0ef0 100644 --- a/src/openenv/core/env_server/web_interface.py +++ b/src/openenv/core/env_server/web_interface.py @@ -239,7 +239,7 @@ def create_web_interface_app( action_cls: Type[Action], observation_cls: Type[Observation], env_name: Optional[str] = None, - max_concurrent_envs: int = 1, + max_concurrent_envs: Optional[int] = None, concurrency_config: Optional[ConcurrencyConfig] = None, ) -> FastAPI: """ @@ -250,8 +250,10 @@ def create_web_interface_app( action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns env_name: Optional environment name for README loading - 
max_concurrent_envs: Maximum concurrent WebSocket sessions (default: 1) - concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings + max_concurrent_envs: Maximum concurrent WebSocket sessions. + Mutually exclusive with concurrency_config. + concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. + Mutually exclusive with max_concurrent_envs. Returns: FastAPI application instance with web interface From 05e6da08dc6276a603db925652ae9c78d718fe91 Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Sat, 13 Dec 2025 23:10:52 +0530 Subject: [PATCH 28/41] chore: clean up exception handling and remove unused concurrency config field --- src/openenv/core/env_server/exceptions.py | 13 +- src/openenv/core/env_server/http_server.py | 237 +++++++-------------- src/openenv/core/env_server/types.py | 4 - 3 files changed, 80 insertions(+), 174 deletions(-) diff --git a/src/openenv/core/env_server/exceptions.py b/src/openenv/core/env_server/exceptions.py index 23fed6567..4fb4a6ec8 100644 --- a/src/openenv/core/env_server/exceptions.py +++ b/src/openenv/core/env_server/exceptions.py @@ -31,7 +31,7 @@ def __init__( ): self.environment_name = environment_name self.max_concurrent_envs = max_concurrent_envs - + if message is None: message = ( f"Environment '{environment_name}' is not marked as SUPPORTS_CONCURRENT_SESSIONS. " @@ -39,7 +39,7 @@ def __init__( f"Either set max_concurrent_envs=1 or ensure the environment " f"properly isolates session state and set SUPPORTS_CONCURRENT_SESSIONS=True." 
) - + super().__init__(message) @@ -96,9 +96,10 @@ def __init__(self, reason: str, message: Optional[str] = None): class EnvironmentFactoryError(OpenEnvError): """Raised when the environment factory fails to create an instance.""" - def __init__(self, factory_name: str): + def __init__(self, factory_name: str, message: Optional[str] = None): self.factory_name = factory_name - - message = f"Environment factory '{factory_name}' failed to create instance." - + + if message is None: + message = f"Environment factory '{factory_name}' failed to create instance." + super().__init__(message) diff --git a/src/openenv/core/env_server/http_server.py b/src/openenv/core/env_server/http_server.py index 604600f79..1b1797cc7 100644 --- a/src/openenv/core/env_server/http_server.py +++ b/src/openenv/core/env_server/http_server.py @@ -18,6 +18,7 @@ import inspect import json import os +import time import uuid from concurrent.futures import ThreadPoolExecutor from typing import Any, Callable, Dict, Optional, Type, Union @@ -139,14 +140,12 @@ def __init__( self._concurrency_config = ConcurrencyConfig( max_concurrent_envs=max_concurrent_envs, session_timeout=None, - reject_on_capacity=True, ) else: # Default configuration self._concurrency_config = ConcurrencyConfig( max_concurrent_envs=1, session_timeout=None, - reject_on_capacity=True, ) self._max_concurrent_envs = self._concurrency_config.max_concurrent_envs @@ -165,9 +164,7 @@ def __init__( # Create thread pool for running sync code in async context # This is needed for environments using sync libraries (e.g., Playwright) - # Configurable via OPENENV_THREAD_POOL_SIZE (default: 32) - pool_size = int(os.getenv("OPENENV_THREAD_POOL_SIZE", "32")) - self._executor = ThreadPoolExecutor(max_workers=pool_size) + self._executor = ThreadPoolExecutor(max_workers=32) def _validate_concurrency_safety(self) -> None: """ @@ -181,18 +178,16 @@ def _validate_concurrency_safety(self) -> None: return if inspect.isclass(self._env_factory): - 
is_concurrency_safe = getattr(self._env_factory, "SUPPORTS_CONCURRENT_SESSIONS", False) - env_name = self._env_factory.__name__ + env_cls = self._env_factory else: _temp_env = self._env_factory() - is_concurrency_safe = getattr(_temp_env, "SUPPORTS_CONCURRENT_SESSIONS", False) - env_name = type(_temp_env).__name__ + env_cls = type(_temp_env) _temp_env.close() del _temp_env - if not is_concurrency_safe: + if not getattr(env_cls, "SUPPORTS_CONCURRENT_SESSIONS", False): raise ConcurrencyConfigurationError( - environment_name=env_name, + environment_name=env_cls.__name__, max_concurrent_envs=self._max_concurrent_envs, ) @@ -249,22 +244,12 @@ async def _create_session(self) -> tuple[str, Environment]: SessionCapacityError: If max concurrent sessions reached EnvironmentFactoryError: If the factory fails to create an environment """ - import time - async with self._session_lock: if len(self._sessions) >= self._max_concurrent_envs: - if self._concurrency_config.reject_on_capacity: - raise SessionCapacityError( - active_sessions=len(self._sessions), - max_sessions=self._max_concurrent_envs, - ) - else: - # TODO: Implement queuing mechanism when reject_on_capacity=False - raise SessionCapacityError( - active_sessions=len(self._sessions), - max_sessions=self._max_concurrent_envs, - message="Session queuing not yet implemented", - ) + raise SessionCapacityError( + active_sessions=len(self._sessions), + max_sessions=self._max_concurrent_envs, + ) session_id = str(uuid.uuid4()) current_time = time.time() @@ -318,8 +303,6 @@ def _update_session_activity( session_id: The session ID to update increment_step: If True, increment the step count """ - import time - if session_id in self._session_info: self._session_info[session_id].last_activity_at = time.time() if increment_step: @@ -580,24 +563,6 @@ def get_metadata_handler() -> EnvironmentMetadata: ] register_get_endpoints(app, get_endpoints) - # Register concurrency config endpoint - @app.get( - "/concurrency", - 
response_model=ConcurrencyConfig, - tags=["Environment Info"], - summary="Get concurrency configuration", - description=""" -Get the current concurrency configuration for this server. - -Returns information about: -- **max_concurrent_envs**: Maximum number of concurrent WebSocket sessions -- **session_timeout**: Timeout in seconds for inactive sessions (None if no timeout) -- **reject_on_capacity**: Whether to reject or queue connections at capacity - """, - ) - async def get_concurrency_config() -> ConcurrencyConfig: - """Return concurrency configuration.""" - return self._concurrency_config # Register combined schema endpoint @app.get( @@ -654,8 +619,7 @@ async def websocket_endpoint(websocket: WebSocket): """ WebSocket endpoint for persistent environment sessions. - Each WebSocket connection gets its own environment instance (when using - factory mode) or shares the single instance (backward compatible mode). + Each WebSocket connection gets its own environment instance. Message Protocol: - Client sends: WSResetMessage | WSStepMessage | WSStateMessage | WSCloseMessage @@ -689,141 +653,83 @@ async def websocket_endpoint(websocket: WebSocket): msg_type = message_dict.get("type", "") try: - if msg_type == "reset": - # Parse and validate reset message - try: + match msg_type: + case "reset": msg = WSResetMessage(**message_dict) - except ValidationError as e: - error_resp = WSErrorResponse( - data={ - "message": "Invalid reset message", - "code": "VALIDATION_ERROR", - "errors": e.errors(), - } - ) - await websocket.send_text(error_resp.model_dump_json()) - continue - - # Handle reset - is_async = session_env.reset_async.__func__ is not Environment.reset_async - - if is_async: - sig = inspect.signature(session_env.reset_async) - valid_kwargs = self._get_valid_kwargs(sig, msg.data) - observation = await session_env.reset_async(**valid_kwargs) - else: - sig = inspect.signature(session_env.reset) - valid_kwargs = self._get_valid_kwargs(sig, msg.data) - observation = 
await self._run_in_session_executor( - session_id, session_env.reset, **valid_kwargs - ) - self._update_session_activity(session_id) + is_async = session_env.reset_async.__func__ is not Environment.reset_async - response = WSObservationResponse( - data=serialize_observation(observation) - ) - await websocket.send_text(response.model_dump_json()) + if is_async: + sig = inspect.signature(session_env.reset_async) + valid_kwargs = self._get_valid_kwargs(sig, msg.data) + observation = await session_env.reset_async(**valid_kwargs) + else: + sig = inspect.signature(session_env.reset) + valid_kwargs = self._get_valid_kwargs(sig, msg.data) + observation = await self._run_in_session_executor( + session_id, session_env.reset, **valid_kwargs + ) - elif msg_type == "step": - # Parse and validate step message - try: - msg = WSStepMessage(**message_dict) - except ValidationError as e: - error_resp = WSErrorResponse( - data={ - "message": "Invalid step message", - "code": "VALIDATION_ERROR", - "errors": e.errors(), - } + self._update_session_activity(session_id) + + response = WSObservationResponse( + data=serialize_observation(observation) ) - await websocket.send_text(error_resp.model_dump_json()) - continue - # Deserialize action with Pydantic validation - try: + case "step": + msg = WSStepMessage(**message_dict) action = deserialize_action(msg.data, self.action_cls) - except ValidationError as e: - error_resp = WSErrorResponse( - data={ - "message": str(e), - "code": "VALIDATION_ERROR", - "errors": e.errors(), - } - ) - await websocket.send_text(error_resp.model_dump_json()) - continue - is_async = session_env.step_async.__func__ is not Environment.step_async + is_async = session_env.step_async.__func__ is not Environment.step_async - if is_async: - observation = await session_env.step_async(action) - else: - observation = await self._run_in_session_executor( - session_id, session_env.step, action - ) + if is_async: + observation = await session_env.step_async(action) + else: 
+ observation = await self._run_in_session_executor( + session_id, session_env.step, action + ) - self._update_session_activity( - session_id, increment_step=True - ) + self._update_session_activity( + session_id, increment_step=True + ) - response = WSObservationResponse( - data=serialize_observation(observation) - ) - await websocket.send_text(response.model_dump_json()) + response = WSObservationResponse( + data=serialize_observation(observation) + ) - elif msg_type == "state": - # Parse and validate state message - try: + case "state": msg = WSStateMessage(**message_dict) - except ValidationError as e: - error_resp = WSErrorResponse( - data={ - "message": "Invalid state message", - "code": "VALIDATION_ERROR", - "errors": e.errors(), - } - ) - await websocket.send_text(error_resp.model_dump_json()) - continue - - # Handle state request - state = session_env.state - if hasattr(state, "model_dump"): - state_data = state.model_dump() - else: - state_data = dict(state) if state else {} - - response = WSStateResponse(data=state_data) - await websocket.send_text(response.model_dump_json()) - - elif msg_type == "close": - # Parse and validate close message - try: + state = session_env.state + if hasattr(state, "model_dump"): + state_data = state.model_dump() + else: + state_data = dict(state) if state else {} + + response = WSStateResponse(data=state_data) + + case "close": msg = WSCloseMessage(**message_dict) - except ValidationError as e: - error_resp = WSErrorResponse( + break + + case _: + response = WSErrorResponse( data={ - "message": "Invalid close message", - "code": "VALIDATION_ERROR", - "errors": e.errors(), + "message": f"Unknown message type: {msg_type}", + "code": "UNKNOWN_TYPE", } ) - await websocket.send_text(error_resp.model_dump_json()) - continue - # Client requested close - break - - else: - error_resp = WSErrorResponse( - data={ - "message": f"Unknown message type: {msg_type}", - "code": "UNKNOWN_TYPE", - } - ) - await 
websocket.send_text(error_resp.model_dump_json()) + await websocket.send_text(response.model_dump_json()) + except ValidationError as e: + error_resp = WSErrorResponse( + data={ + "message": "Invalid message", + "code": "VALIDATION_ERROR", + "errors": e.errors(), + } + ) + await websocket.send_text(error_resp.model_dump_json()) except Exception as e: error_resp = WSErrorResponse( data={"message": str(e), "code": "EXECUTION_ERROR"} @@ -859,7 +765,10 @@ async def websocket_endpoint(websocket: WebSocket): finally: if session_id: await self._destroy_session(session_id) - await websocket.close() + try: + await websocket.close() + except RuntimeError: + pass def create_app( diff --git a/src/openenv/core/env_server/types.py b/src/openenv/core/env_server/types.py index 273994479..a22914b73 100644 --- a/src/openenv/core/env_server/types.py +++ b/src/openenv/core/env_server/types.py @@ -281,10 +281,6 @@ class ConcurrencyConfig(BaseMessage): gt=0, description="Timeout in seconds for inactive sessions. None means no timeout.", ) - reject_on_capacity: bool = Field( - default=True, - description="If True, reject new connections when at capacity. 
If False, queue them.", - ) class ServerCapacityStatus(BaseMessage): From d52850f646f97292ea9435bc1748c6f3ce2ad91b Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Sun, 14 Dec 2025 23:16:09 +0530 Subject: [PATCH 29/41] refactor: simplify environment factory type annotations and add utility function for URL conversion --- src/openenv/core/env_server/http_server.py | 12 ++++----- src/openenv/core/env_server/web_interface.py | 4 +-- src/openenv/core/utils.py | 26 ++++++++++++++++++++ src/openenv/core/ws_env_client.py | 9 ++----- 4 files changed, 36 insertions(+), 15 deletions(-) create mode 100644 src/openenv/core/utils.py diff --git a/src/openenv/core/env_server/http_server.py b/src/openenv/core/env_server/http_server.py index 1b1797cc7..b816b3d62 100644 --- a/src/openenv/core/env_server/http_server.py +++ b/src/openenv/core/env_server/http_server.py @@ -94,7 +94,7 @@ class HTTPEnvServer: def __init__( self, - env: Union[Callable[[], Environment], Type[Environment]], + env: Callable[[], Environment], action_cls: Type[Action], observation_cls: Type[Observation], max_concurrent_envs: Optional[int] = None, @@ -104,7 +104,7 @@ def __init__( Initialize HTTP server wrapper. Args: - env: Environment factory (callable or class) that creates new instances. + env: Environment factory (callable) that creates new instances. Will be called to create a new environment for each WebSocket session. action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns @@ -772,7 +772,7 @@ async def websocket_endpoint(websocket: WebSocket): def create_app( - env: Union[Callable[[], Environment], Type[Environment]], + env: Callable[[], Environment], action_cls: Type[Action], observation_cls: Type[Observation], env_name: Optional[str] = None, @@ -786,7 +786,7 @@ def create_app( including README integration for better user experience. 
Args: - env: Environment factory (callable or class) that creates new instances + env: Environment factory (callable) that creates new instances action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns env_name: Optional environment name for README loading @@ -826,7 +826,7 @@ def create_app( def create_fastapi_app( - env: Union[Callable[[], Environment], Type[Environment]], + env: Callable[[], Environment], action_cls: Type[Action], observation_cls: Type[Observation], max_concurrent_envs: Optional[int] = None, @@ -836,7 +836,7 @@ def create_fastapi_app( Create a FastAPI application with comprehensive documentation. Args: - env: Environment factory (callable or class) that creates new instances + env: Environment factory (callable) that creates new instances action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns max_concurrent_envs: Maximum concurrent WebSocket sessions. diff --git a/src/openenv/core/env_server/web_interface.py b/src/openenv/core/env_server/web_interface.py index 5711d0ef0..703025375 100644 --- a/src/openenv/core/env_server/web_interface.py +++ b/src/openenv/core/env_server/web_interface.py @@ -235,7 +235,7 @@ def get_state(self) -> Dict[str, Any]: def create_web_interface_app( - env: Union[Callable[[], Environment], Type[Environment]], + env: Callable[[], Environment], action_cls: Type[Action], observation_cls: Type[Observation], env_name: Optional[str] = None, @@ -246,7 +246,7 @@ def create_web_interface_app( Create a FastAPI application with web interface for the given environment. 
Args: - env: Environment factory (callable or class) that creates new instances + env: Environment factory (callable) that creates new instances action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns env_name: Optional environment name for README loading diff --git a/src/openenv/core/utils.py b/src/openenv/core/utils.py new file mode 100644 index 000000000..42e9cee82 --- /dev/null +++ b/src/openenv/core/utils.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Utility functions for OpenEnv core.""" + +def convert_to_ws_url(url: str) -> str: + """ + Convert an HTTP/HTTPS URL to a WS/WSS URL. + + Args: + url: The URL to convert. + + Returns: + The converted WebSocket URL. + """ + ws_url = url.rstrip("/") + if ws_url.startswith("http://"): + ws_url = "ws://" + ws_url[7:] + elif ws_url.startswith("https://"): + ws_url = "wss://" + ws_url[8:] + elif not ws_url.startswith("ws://") and not ws_url.startswith("wss://"): + ws_url = "ws://" + ws_url + return ws_url diff --git a/src/openenv/core/ws_env_client.py b/src/openenv/core/ws_env_client.py index 6c1d6a4ab..efa829f64 100644 --- a/src/openenv/core/ws_env_client.py +++ b/src/openenv/core/ws_env_client.py @@ -20,6 +20,7 @@ from .client_types import StepResult, StateT from .containers.runtime import LocalDockerProvider +from .utils import convert_to_ws_url if TYPE_CHECKING: from .containers.runtime import ContainerProvider @@ -74,13 +75,7 @@ def __init__( provider: Optional container provider for lifecycle management """ # Convert HTTP URL to WebSocket URL - ws_url = base_url.rstrip("/") - if ws_url.startswith("http://"): - ws_url = "ws://" + ws_url[7:] - elif ws_url.startswith("https://"): - ws_url = "wss://" + ws_url[8:] - elif not ws_url.startswith("ws://") and not 
ws_url.startswith("wss://"): - ws_url = "ws://" + ws_url + ws_url = convert_to_ws_url(base_url) self._ws_url = f"{ws_url}/ws" self._connect_timeout = connect_timeout_s From 737386bd9db47c25924e15607653485756dc529f Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Mon, 15 Dec 2025 20:51:05 +0100 Subject: [PATCH 30/41] update cli template with naming --- .../cli/templates/openenv_env/README.md | 18 +-- .../cli/templates/openenv_env/__init__.py | 5 +- .../cli/templates/openenv_env/client.py | 116 ++---------------- .../cli/templates/openenv_env/server/app.py | 2 +- 4 files changed, 24 insertions(+), 117 deletions(-) diff --git a/src/openenv/cli/templates/openenv_env/README.md b/src/openenv/cli/templates/openenv_env/README.md index f6a5c0292..3f14526a0 100644 --- a/src/openenv/cli/templates/openenv_env/README.md +++ b/src/openenv/cli/templates/openenv_env/README.md @@ -155,15 +155,15 @@ result = __ENV_NAME__env.step(__ENV_CLASS_NAME__Action(message="Hello!")) Note: When connecting to an existing server, `__ENV_NAME__env.close()` will NOT stop the server. 
-### WebSocket Client for Persistent Sessions +### Using the Context Manager -For long-running episodes or when you need lower latency, use the WebSocket client: +The client supports context manager usage for automatic connection management: ```python -from __ENV_NAME__ import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__EnvWS +from __ENV_NAME__ import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Env -# Connect via WebSocket (maintains persistent connection) -with __ENV_CLASS_NAME__EnvWS(base_url="http://localhost:8000") as env: +# Connect with context manager (auto-connects and closes) +with __ENV_CLASS_NAME__Env(base_url="http://localhost:8000") as env: result = env.reset() print(f"Reset: {result.observation.echoed_message}") # Multiple steps with low latency @@ -172,7 +172,7 @@ with __ENV_CLASS_NAME__EnvWS(base_url="http://localhost:8000") as env: print(f"Echoed: {result.observation.echoed_message}") ``` -WebSocket advantages: +The client uses WebSocket connections for: - **Lower latency**: No HTTP connection overhead per request - **Persistent session**: Server maintains your environment state - **Efficient for episodes**: Better for many sequential steps @@ -195,11 +195,11 @@ app = create_app( Then multiple clients can connect simultaneously: ```python -from __ENV_NAME__ import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__EnvWS +from __ENV_NAME__ import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Env from concurrent.futures import ThreadPoolExecutor def run_episode(client_id: int): - with __ENV_CLASS_NAME__EnvWS(base_url="http://localhost:8000") as env: + with __ENV_CLASS_NAME__Env(base_url="http://localhost:8000") as env: result = env.reset() for i in range(10): result = env.step(__ENV_CLASS_NAME__Action(message=f"Client {client_id}, step {i}")) @@ -245,7 +245,7 @@ __ENV_NAME__/ ├── openenv.yaml # OpenEnv manifest ├── pyproject.toml # Project metadata and dependencies ├── uv.lock # Locked dependencies (generated) -├── client.py # __ENV_CLASS_NAME__Env (HTTP) and 
__ENV_CLASS_NAME__EnvWS (WebSocket) clients +├── client.py # __ENV_CLASS_NAME__Env client ├── models.py # Action and Observation models └── server/ ├── __init__.py # Server module exports diff --git a/src/openenv/cli/templates/openenv_env/__init__.py b/src/openenv/cli/templates/openenv_env/__init__.py index aed293ba8..cbe07a082 100644 --- a/src/openenv/cli/templates/openenv_env/__init__.py +++ b/src/openenv/cli/templates/openenv_env/__init__.py @@ -4,14 +4,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -"""__ENV_TITLE_NAME__ Environment - A simple test environment for HTTP server.""" +"""__ENV_TITLE_NAME__ Environment.""" -from .client import __ENV_CLASS_NAME__Env, __ENV_CLASS_NAME__EnvWS +from .client import __ENV_CLASS_NAME__Env from .models import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation __all__ = [ "__ENV_CLASS_NAME__Action", "__ENV_CLASS_NAME__Observation", "__ENV_CLASS_NAME__Env", - "__ENV_CLASS_NAME__EnvWS", ] diff --git a/src/openenv/cli/templates/openenv_env/client.py b/src/openenv/cli/templates/openenv_env/client.py index 0775f2536..6be3eefd9 100644 --- a/src/openenv/cli/templates/openenv_env/client.py +++ b/src/openenv/cli/templates/openenv_env/client.py @@ -4,120 +4,28 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -""" -__ENV_TITLE_NAME__ Environment Clients. 
+"""__ENV_TITLE_NAME__ Environment Client.""" -This module provides clients for connecting to a __ENV_TITLE_NAME__ Environment server: -- __ENV_CLASS_NAME__Env: HTTP client for request/response interactions -- __ENV_CLASS_NAME__EnvWS: WebSocket client for persistent sessions -""" - -from typing import Any, Dict +from typing import Dict from openenv.core.client_types import StepResult from openenv.core.env_server.types import State -from openenv.core.http_env_client import HTTPEnvClient -from openenv.core.ws_env_client import WebSocketEnvClient +from openenv.core import EnvClient from .models import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation -class __ENV_CLASS_NAME__Env(HTTPEnvClient[__ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation]): - """ - HTTP client for the __ENV_TITLE_NAME__ Environment. - - This client connects to a __ENV_CLASS_NAME__Environment HTTP server and provides - methods to interact with it: reset(), step(), and state access. - - Example: - >>> # Connect to a running server - >>> client = __ENV_CLASS_NAME__Env(base_url="http://localhost:8000") - >>> result = client.reset() - >>> print(result.observation.echoed_message) - >>> - >>> # Send a message - >>> result = client.step(__ENV_CLASS_NAME__Action(message="Hello!")) - >>> print(result.observation.echoed_message) - >>> print(result.reward) - - Example with Docker: - >>> # Automatically start container and connect - >>> client = __ENV_CLASS_NAME__Env.from_docker_image("__ENV_NAME__-env:latest") - >>> result = client.reset() - >>> result = client.step(__ENV_CLASS_NAME__Action(message="Test")) - """ - - def _step_payload(self, action: __ENV_CLASS_NAME__Action) -> Dict: - """ - Convert __ENV_CLASS_NAME__Action to JSON payload for step request. 
- - Args: - action: __ENV_CLASS_NAME__Action instance - - Returns: - Dictionary representation suitable for JSON encoding - """ - return { - "message": action.message, - } - - def _parse_result(self, payload: Dict) -> StepResult[__ENV_CLASS_NAME__Observation]: - """ - Parse server response into StepResult[__ENV_CLASS_NAME__Observation]. - - Args: - payload: JSON response from server - - Returns: - StepResult with __ENV_CLASS_NAME__Observation - """ - obs_data = payload.get("observation", {}) - observation = __ENV_CLASS_NAME__Observation( - echoed_message=obs_data.get("echoed_message", ""), - message_length=obs_data.get("message_length", 0), - done=payload.get("done", False), - reward=payload.get("reward"), - metadata=obs_data.get("metadata", {}), - ) - - return StepResult( - observation=observation, - reward=payload.get("reward"), - done=payload.get("done", False), - ) - - def _parse_state(self, payload: Dict) -> State: - """ - Parse server response into State object. - - Args: - payload: JSON response from /state endpoint - - Returns: - State object with episode_id and step_count - """ - return State( - episode_id=payload.get("episode_id"), - step_count=payload.get("step_count", 0), - ) - - -class __ENV_CLASS_NAME__EnvWS(WebSocketEnvClient[__ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation]): +class __ENV_CLASS_NAME__Env(EnvClient[__ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation]): """ - WebSocket client for the __ENV_TITLE_NAME__ Environment. + Client for the __ENV_TITLE_NAME__ Environment. This client maintains a persistent WebSocket connection to the environment server, - enabling efficient multi-step interactions with lower latency than HTTP. + enabling efficient multi-step interactions with lower latency. Each client instance has its own dedicated environment session on the server. 
- Advantages over HTTP client: - - Lower latency for sequential interactions (no connection overhead per request) - - Session state is maintained server-side - - Better suited for long-running episodes - Example: - >>> # Connect to a running server via WebSocket - >>> with __ENV_CLASS_NAME__EnvWS(base_url="http://localhost:8000") as client: + >>> # Connect to a running server + >>> with __ENV_CLASS_NAME__Env(base_url="http://localhost:8000") as client: ... result = client.reset() ... print(result.observation.echoed_message) ... @@ -125,8 +33,8 @@ class __ENV_CLASS_NAME__EnvWS(WebSocketEnvClient[__ENV_CLASS_NAME__Action, __ENV ... print(result.observation.echoed_message) Example with Docker: - >>> # Automatically start container and connect via WebSocket - >>> client = __ENV_CLASS_NAME__EnvWS.from_docker_image("__ENV_NAME__-env:latest") + >>> # Automatically start container and connect + >>> client = __ENV_CLASS_NAME__Env.from_docker_image("__ENV_NAME__-env:latest") >>> try: ... result = client.reset() ... result = client.step(__ENV_CLASS_NAME__Action(message="Test")) @@ -150,7 +58,7 @@ def _step_payload(self, action: __ENV_CLASS_NAME__Action) -> Dict: def _parse_result(self, payload: Dict) -> StepResult[__ENV_CLASS_NAME__Observation]: """ - Parse WebSocket response into StepResult[__ENV_CLASS_NAME__Observation]. + Parse server response into StepResult[__ENV_CLASS_NAME__Observation]. Args: payload: JSON response data from server @@ -175,7 +83,7 @@ def _parse_result(self, payload: Dict) -> StepResult[__ENV_CLASS_NAME__Observati def _parse_state(self, payload: Dict) -> State: """ - Parse WebSocket state response into State object. + Parse server response into State object. 
Args: payload: JSON response from state request diff --git a/src/openenv/cli/templates/openenv_env/server/app.py b/src/openenv/cli/templates/openenv_env/server/app.py index 5100b1050..025920a1b 100644 --- a/src/openenv/cli/templates/openenv_env/server/app.py +++ b/src/openenv/cli/templates/openenv_env/server/app.py @@ -8,7 +8,7 @@ FastAPI application for the __ENV_TITLE_NAME__ Environment. This module creates an HTTP server that exposes the __ENV_CLASS_NAME__Environment -over HTTP and WebSocket endpoints, compatible with HTTPEnvClient and WebSocketEnvClient. +over HTTP and WebSocket endpoints, compatible with EnvClient. Endpoints: - POST /reset: Reset the environment From 4bdfb6bc0d34ffcb3526f8a035b0dd27d634d0ee Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Mon, 15 Dec 2025 20:51:25 +0100 Subject: [PATCH 31/41] update prociders docstring --- src/openenv/core/containers/runtime/providers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/openenv/core/containers/runtime/providers.py b/src/openenv/core/containers/runtime/providers.py index a8022ddca..f6f2b0ca6 100644 --- a/src/openenv/core/containers/runtime/providers.py +++ b/src/openenv/core/containers/runtime/providers.py @@ -8,7 +8,7 @@ Container provider abstractions for running environment servers. This module provides a pluggable architecture for different container providers -(local Docker, Kubernetes, cloud providers, etc.) to be used with HTTPEnvClient. +(local Docker, Kubernetes, cloud providers, etc.) to be used with EnvClient. 
""" from __future__ import annotations From 227ca93ca58596f65ec80d454a71015418516f5e Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Mon, 15 Dec 2025 20:51:47 +0100 Subject: [PATCH 32/41] remove http from env server --- src/openenv/core/env_server/http_server.py | 6 ++---- src/openenv/core/env_server/serialization.py | 6 +++--- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/openenv/core/env_server/http_server.py b/src/openenv/core/env_server/http_server.py index b816b3d62..ad3f8b365 100644 --- a/src/openenv/core/env_server/http_server.py +++ b/src/openenv/core/env_server/http_server.py @@ -8,8 +8,7 @@ HTTP server wrapper for Environment instances. This module provides utilities to wrap any Environment subclass and expose it -over HTTP endpoints that HTTPEnvClient can consume. Also supports WebSocket -connections for persistent sessions with multi-environment concurrency. +over HTTP and WebSocket endpoints that EnvClient can consume. """ from __future__ import annotations @@ -66,8 +65,7 @@ class HTTPEnvServer: HTTP server wrapper for Environment instances. This class wraps an Environment and exposes its reset(), step(), and state - methods as HTTP endpoints compatible with HTTPEnvClient. Also supports - WebSocket connections for persistent sessions with multi-environment concurrency. + methods as HTTP and WebSocket endpoints compatible with EnvClient. 
The server expects: - Action deserialization: Converts JSON dict to Action subclass diff --git a/src/openenv/core/env_server/serialization.py b/src/openenv/core/env_server/serialization.py index df06592f5..9e88a33c9 100644 --- a/src/openenv/core/env_server/serialization.py +++ b/src/openenv/core/env_server/serialization.py @@ -109,9 +109,9 @@ def serialize_observation(observation: Observation) -> Dict[str, Any]: observation: Observation instance Returns: - Dictionary compatible with HTTPEnvClient._parse_result() + Dictionary compatible with EnvClient._parse_result() - The format matches what HTTPEnvClient expects: + The format matches what EnvClient expects: { "observation": {...}, # Observation fields "reward": float | None, @@ -131,7 +131,7 @@ def serialize_observation(observation: Observation) -> Dict[str, Any]: reward = observation.reward done = observation.done - # Return in HTTPEnvClient expected format + # Return in EnvClient expected format return { "observation": obs_dict, "reward": reward, From 402d144e97ad6e96a73835770ab35ed834beb666 Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Mon, 15 Dec 2025 20:52:16 +0100 Subject: [PATCH 33/41] rename in core to envclient --- src/openenv/core/README.md | 10 +- src/openenv/core/__init__.py | 5 +- .../core/{ws_env_client.py => env_client.py} | 578 +++++++++--------- src/openenv/core/http_env_client.py | 236 ------- 4 files changed, 296 insertions(+), 533 deletions(-) rename src/openenv/core/{ws_env_client.py => env_client.py} (86%) delete mode 100644 src/openenv/core/http_env_client.py diff --git a/src/openenv/core/README.md b/src/openenv/core/README.md index 2251e10a6..ebfa579aa 100644 --- a/src/openenv/core/README.md +++ b/src/openenv/core/README.md @@ -22,8 +22,8 @@ Core components for OpenEnv - a framework for building HTTP-based agentic enviro ## Features -- **HTTPEnvClient**: Generic HTTP client for interacting with remote environments -- **HTTPEnvServer**: FastAPI-based server wrapper for exposing 
environments over HTTP +- **EnvClient**: Generic client for interacting with remote environments +- **HTTPEnvServer**: FastAPI-based server wrapper for exposing environments over HTTP/WebSocket - **Container Providers**: Pluggable architecture for running containers (Docker, Kubernetes, etc.) - **Type System**: Strongly-typed Action/Observation/State interfaces - **Web Interface**: Optional web UI for interacting with environments @@ -44,7 +44,7 @@ pip install "openenv[core]" ### Creating an Environment Client ```python -from openenv.core import HTTPEnvClient, StepResult +from openenv.core import EnvClient, StepResult from dataclasses import dataclass @dataclass @@ -55,7 +55,7 @@ class MyAction: class MyObservation: response: str -class MyEnvClient(HTTPEnvClient[MyAction, MyObservation]): +class MyEnvClient(EnvClient[MyAction, MyObservation]): def _step_payload(self, action: MyAction) -> dict: return {"text": action.text} @@ -141,7 +141,7 @@ provider.stop_container() ## API Reference -### HTTPEnvClient +### EnvClient Base class for environment clients with these abstract methods: diff --git a/src/openenv/core/__init__.py b/src/openenv/core/__init__.py index e9bbf2365..5a7af20db 100644 --- a/src/openenv/core/__init__.py +++ b/src/openenv/core/__init__.py @@ -9,9 +9,8 @@ # Re-export main components from submodules for convenience from .env_server import * # noqa: F403 from . 
import env_server -from .ws_env_client import WebSocketEnvClient -from .http_env_client import HTTPEnvClient +from .env_client import EnvClient # Note: MCP module doesn't export anything yet -__all__ = ["WebSocketEnvClient", "HTTPEnvClient"] + env_server.__all__ # type: ignore \ No newline at end of file +__all__ = ["EnvClient"] + env_server.__all__ # type: ignore diff --git a/src/openenv/core/ws_env_client.py b/src/openenv/core/env_client.py similarity index 86% rename from src/openenv/core/ws_env_client.py rename to src/openenv/core/env_client.py index efa829f64..356fe72c9 100644 --- a/src/openenv/core/ws_env_client.py +++ b/src/openenv/core/env_client.py @@ -1,289 +1,289 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -WebSocket-based environment client for persistent sessions. - -This module provides a WebSocket client that maintains a persistent connection -to an environment server, enabling efficient multi-step interactions without -the overhead of HTTP request/response cycles. -""" - -from __future__ import annotations - -import json -from abc import ABC, abstractmethod -from typing import Any, Dict, Generic, Optional, Type, TYPE_CHECKING, TypeVar - -from .client_types import StepResult, StateT -from .containers.runtime import LocalDockerProvider -from .utils import convert_to_ws_url - -if TYPE_CHECKING: - from .containers.runtime import ContainerProvider - from websockets.sync.client import ClientConnection - -from websockets.sync.client import connect as ws_connect - -ActT = TypeVar("ActT") -ObsT = TypeVar("ObsT") -WSEnvClientT = TypeVar("WSEnvClientT", bound="WebSocketEnvClient") - - -class WebSocketEnvClient(ABC, Generic[ActT, ObsT, StateT]): - """ - WebSocket-based environment client for persistent sessions. 
- - This client maintains a persistent WebSocket connection to an environment - server, enabling efficient multi-step interactions. Each client instance - corresponds to a dedicated environment session on the server. - - Compared to HTTPEnvClient: - - Lower latency for sequential interactions - - Session state is maintained server-side - - Better suited for long-running episodes - - Example: - >>> from envs.coding_env.client import CodingEnvWS - >>> - >>> # Connect to a server via WebSocket - >>> with CodingEnvWS(base_url="ws://localhost:8000") as env: - ... result = env.reset(seed=42) - ... while not result.done: - ... action = agent.predict(result.observation) - ... result = env.step(action) - """ - - def __init__( - self, - base_url: str, - connect_timeout_s: float = 10.0, - message_timeout_s: float = 60.0, - provider: Optional["ContainerProvider"] = None, - ): - """ - Initialize WebSocket client. - - Args: - base_url: Base URL of the environment server (http:// or ws://). - Will be converted to ws:// if http:// is provided. - connect_timeout_s: Timeout for establishing WebSocket connection - message_timeout_s: Timeout for receiving responses to messages - provider: Optional container provider for lifecycle management - """ - # Convert HTTP URL to WebSocket URL - ws_url = convert_to_ws_url(base_url) - - self._ws_url = f"{ws_url}/ws" - self._connect_timeout = connect_timeout_s - self._message_timeout = message_timeout_s - self._provider = provider - self._ws: Optional[ClientConnection] = None - - def connect(self) -> "WebSocketEnvClient": - """ - Establish WebSocket connection to the server. 
- - Returns: - self for method chaining - - Raises: - ConnectionError: If connection cannot be established - """ - if self._ws is not None: - return self - - try: - self._ws = ws_connect( - self._ws_url, - open_timeout=self._connect_timeout, - ) - except Exception as e: - raise ConnectionError(f"Failed to connect to {self._ws_url}: {e}") from e - - return self - - def disconnect(self) -> None: - """Close the WebSocket connection.""" - if self._ws is not None: - try: - # Send close message - self._send({"type": "close"}) - except Exception: - pass # Best effort - try: - self._ws.close() - except Exception: - pass - self._ws = None - - def _ensure_connected(self) -> None: - """Ensure WebSocket connection is established.""" - if self._ws is None: - self.connect() - - def _send(self, message: Dict[str, Any]) -> None: - """Send a message over the WebSocket.""" - self._ensure_connected() - assert self._ws is not None - self._ws.send(json.dumps(message)) - - def _receive(self) -> Dict[str, Any]: - """Receive and parse a message from the WebSocket.""" - assert self._ws is not None - raw = self._ws.recv(timeout=self._message_timeout) - return json.loads(raw) - - def _send_and_receive(self, message: Dict[str, Any]) -> Dict[str, Any]: - """Send a message and wait for response.""" - self._send(message) - response = self._receive() - - # Check for error response - if response.get("type") == "error": - error_data = response.get("data", {}) - raise RuntimeError( - f"Server error: {error_data.get('message', 'Unknown error')} " - f"(code: {error_data.get('code', 'UNKNOWN')})" - ) - - return response - - @classmethod - def from_docker_image( - cls: Type[WSEnvClientT], - image: str, - provider: Optional["ContainerProvider"] = None, - **kwargs: Any, - ) -> WSEnvClientT: - """ - Create a WebSocket environment client by spinning up a Docker container. 
- - Args: - image: Docker image name to run (e.g., "coding-env:latest") - provider: Container provider to use (defaults to LocalDockerProvider) - **kwargs: Additional arguments to pass to provider.start_container() - - Returns: - Connected WebSocket client instance - """ - if provider is None: - provider = LocalDockerProvider() - - # Start container - base_url = provider.start_container(image, **kwargs) - - # Wait for server to be ready - provider.wait_for_ready(base_url) - - # Create and connect client - client = cls(base_url=base_url, provider=provider) - client.connect() - - return client - - @classmethod - def from_hub( - cls: Type[WSEnvClientT], - repo_id: str, - provider: Optional["ContainerProvider"] = None, - **kwargs: Any, - ) -> WSEnvClientT: - """ - Create a WebSocket client by pulling from a Hugging Face model hub. - """ - if provider is None: - provider = LocalDockerProvider() - - tag = kwargs.pop("tag", "latest") - base_url = f"registry.hf.space/{repo_id.replace('/', '-')}:{tag}" - - return cls.from_docker_image(image=base_url, provider=provider, **kwargs) - - @abstractmethod - def _step_payload(self, action: ActT) -> Dict[str, Any]: - """Convert an Action object to the JSON data expected by the env server.""" - raise NotImplementedError - - @abstractmethod - def _parse_result(self, payload: Dict[str, Any]) -> StepResult[ObsT]: - """Convert a JSON response from the env server to StepResult[ObsT].""" - raise NotImplementedError - - @abstractmethod - def _parse_state(self, payload: Dict[str, Any]) -> StateT: - """Convert a JSON response from the state endpoint to a State object.""" - raise NotImplementedError - - def reset(self, **kwargs: Any) -> StepResult[ObsT]: - """ - Reset the environment with optional parameters. - - Args: - **kwargs: Optional parameters passed to the environment's reset method. 
- Common parameters include: - - seed: Random seed for reproducibility - - episode_id: Custom episode identifier - - Returns: - StepResult containing initial observation - """ - message = { - "type": "reset", - "data": kwargs, - } - response = self._send_and_receive(message) - return self._parse_result(response.get("data", {})) - - def step(self, action: ActT, **kwargs: Any) -> StepResult[ObsT]: - """ - Execute an action in the environment. - - Args: - action: The action to execute - **kwargs: Optional parameters (currently ignored for WebSocket) - - Returns: - StepResult containing observation, reward, and done status - """ - message = { - "type": "step", - "data": self._step_payload(action), - } - response = self._send_and_receive(message) - return self._parse_result(response.get("data", {})) - - def state(self) -> StateT: - """ - Get the current environment state from the server. - - Returns: - State object with environment state information - """ - message = {"type": "state"} - response = self._send_and_receive(message) - return self._parse_state(response.get("data", {})) - - def close(self) -> None: - """ - Close the WebSocket connection and clean up resources. - - If this client was created via from_docker_image(), this will also - stop and remove the associated container. - """ - self.disconnect() - - if self._provider is not None: - self._provider.stop_container() - - def __enter__(self) -> "WebSocketEnvClient": - """Enter context manager, ensuring connection is established.""" - self.connect() - return self - - def __exit__(self, exc_type, exc_val, exc_tb) -> None: - """Exit context manager, closing connection.""" - self.close() +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Environment client for persistent sessions. 
+ +This module provides a WebSocket-based client that maintains a persistent connection +to an environment server, enabling efficient multi-step interactions without +the overhead of HTTP request/response cycles. +""" + +from __future__ import annotations + +import json +from abc import ABC, abstractmethod +from typing import Any, Dict, Generic, Optional, Type, TYPE_CHECKING, TypeVar + +from .client_types import StepResult, StateT +from .containers.runtime import LocalDockerProvider +from .utils import convert_to_ws_url + +if TYPE_CHECKING: + from .containers.runtime import ContainerProvider + from websockets.sync.client import ClientConnection + +from websockets.sync.client import connect as ws_connect + +ActT = TypeVar("ActT") +ObsT = TypeVar("ObsT") +EnvClientT = TypeVar("EnvClientT", bound="EnvClient") + + +class EnvClient(ABC, Generic[ActT, ObsT, StateT]): + """ + Environment client for persistent sessions. + + This client maintains a persistent WebSocket connection to an environment + server, enabling efficient multi-step interactions. Each client instance + corresponds to a dedicated environment session on the server. + + Features: + - Lower latency for sequential interactions + - Session state is maintained server-side + - Better suited for long-running episodes + + Example: + >>> from envs.coding_env.client import CodingEnv + >>> + >>> # Connect to a server + >>> with CodingEnv(base_url="ws://localhost:8000") as env: + ... result = env.reset(seed=42) + ... while not result.done: + ... action = agent.predict(result.observation) + ... result = env.step(action) + """ + + def __init__( + self, + base_url: str, + connect_timeout_s: float = 10.0, + message_timeout_s: float = 60.0, + provider: Optional["ContainerProvider"] = None, + ): + """ + Initialize environment client. + + Args: + base_url: Base URL of the environment server (http:// or ws://). + Will be converted to ws:// if http:// is provided. 
+ connect_timeout_s: Timeout for establishing WebSocket connection + message_timeout_s: Timeout for receiving responses to messages + provider: Optional container provider for lifecycle management + """ + # Convert HTTP URL to WebSocket URL + ws_url = convert_to_ws_url(base_url) + + self._ws_url = f"{ws_url}/ws" + self._connect_timeout = connect_timeout_s + self._message_timeout = message_timeout_s + self._provider = provider + self._ws: Optional[ClientConnection] = None + + def connect(self) -> "EnvClient": + """ + Establish WebSocket connection to the server. + + Returns: + self for method chaining + + Raises: + ConnectionError: If connection cannot be established + """ + if self._ws is not None: + return self + + try: + self._ws = ws_connect( + self._ws_url, + open_timeout=self._connect_timeout, + ) + except Exception as e: + raise ConnectionError(f"Failed to connect to {self._ws_url}: {e}") from e + + return self + + def disconnect(self) -> None: + """Close the WebSocket connection.""" + if self._ws is not None: + try: + # Send close message + self._send({"type": "close"}) + except Exception: + pass # Best effort + try: + self._ws.close() + except Exception: + pass + self._ws = None + + def _ensure_connected(self) -> None: + """Ensure WebSocket connection is established.""" + if self._ws is None: + self.connect() + + def _send(self, message: Dict[str, Any]) -> None: + """Send a message over the WebSocket.""" + self._ensure_connected() + assert self._ws is not None + self._ws.send(json.dumps(message)) + + def _receive(self) -> Dict[str, Any]: + """Receive and parse a message from the WebSocket.""" + assert self._ws is not None + raw = self._ws.recv(timeout=self._message_timeout) + return json.loads(raw) + + def _send_and_receive(self, message: Dict[str, Any]) -> Dict[str, Any]: + """Send a message and wait for response.""" + self._send(message) + response = self._receive() + + # Check for error response + if response.get("type") == "error": + error_data = 
response.get("data", {}) + raise RuntimeError( + f"Server error: {error_data.get('message', 'Unknown error')} " + f"(code: {error_data.get('code', 'UNKNOWN')})" + ) + + return response + + @classmethod + def from_docker_image( + cls: Type[EnvClientT], + image: str, + provider: Optional["ContainerProvider"] = None, + **kwargs: Any, + ) -> EnvClientT: + """ + Create an environment client by spinning up a Docker container. + + Args: + image: Docker image name to run (e.g., "coding-env:latest") + provider: Container provider to use (defaults to LocalDockerProvider) + **kwargs: Additional arguments to pass to provider.start_container() + + Returns: + Connected client instance + """ + if provider is None: + provider = LocalDockerProvider() + + # Start container + base_url = provider.start_container(image, **kwargs) + + # Wait for server to be ready + provider.wait_for_ready(base_url) + + # Create and connect client + client = cls(base_url=base_url, provider=provider) + client.connect() + + return client + + @classmethod + def from_hub( + cls: Type[EnvClientT], + repo_id: str, + provider: Optional["ContainerProvider"] = None, + **kwargs: Any, + ) -> EnvClientT: + """ + Create a client by pulling from a Hugging Face model hub. 
+ """ + if provider is None: + provider = LocalDockerProvider() + + tag = kwargs.pop("tag", "latest") + base_url = f"registry.hf.space/{repo_id.replace('/', '-')}:{tag}" + + return cls.from_docker_image(image=base_url, provider=provider, **kwargs) + + @abstractmethod + def _step_payload(self, action: ActT) -> Dict[str, Any]: + """Convert an Action object to the JSON data expected by the env server.""" + raise NotImplementedError + + @abstractmethod + def _parse_result(self, payload: Dict[str, Any]) -> StepResult[ObsT]: + """Convert a JSON response from the env server to StepResult[ObsT].""" + raise NotImplementedError + + @abstractmethod + def _parse_state(self, payload: Dict[str, Any]) -> StateT: + """Convert a JSON response from the state endpoint to a State object.""" + raise NotImplementedError + + def reset(self, **kwargs: Any) -> StepResult[ObsT]: + """ + Reset the environment with optional parameters. + + Args: + **kwargs: Optional parameters passed to the environment's reset method. + Common parameters include: + - seed: Random seed for reproducibility + - episode_id: Custom episode identifier + + Returns: + StepResult containing initial observation + """ + message = { + "type": "reset", + "data": kwargs, + } + response = self._send_and_receive(message) + return self._parse_result(response.get("data", {})) + + def step(self, action: ActT, **kwargs: Any) -> StepResult[ObsT]: + """ + Execute an action in the environment. + + Args: + action: The action to execute + **kwargs: Optional parameters (currently ignored) + + Returns: + StepResult containing observation, reward, and done status + """ + message = { + "type": "step", + "data": self._step_payload(action), + } + response = self._send_and_receive(message) + return self._parse_result(response.get("data", {})) + + def state(self) -> StateT: + """ + Get the current environment state from the server. 
+ + Returns: + State object with environment state information + """ + message = {"type": "state"} + response = self._send_and_receive(message) + return self._parse_state(response.get("data", {})) + + def close(self) -> None: + """ + Close the WebSocket connection and clean up resources. + + If this client was created via from_docker_image(), this will also + stop and remove the associated container. + """ + self.disconnect() + + if self._provider is not None: + self._provider.stop_container() + + def __enter__(self) -> "EnvClient": + """Enter context manager, ensuring connection is established.""" + self.connect() + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """Exit context manager, closing connection.""" + self.close() diff --git a/src/openenv/core/http_env_client.py b/src/openenv/core/http_env_client.py deleted file mode 100644 index 0f25363d4..000000000 --- a/src/openenv/core/http_env_client.py +++ /dev/null @@ -1,236 +0,0 @@ -""" -core/runner_env.py -Minimal HTTP-based environment client. 
-- Talks to a single env worker exposing: POST /reset, POST /step - -Future hooks (commented below) for: -- episode_id, seed on reset -- request_id on step -- custom headers (auth/trace) -""" - -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import Any, Dict, Generic, Optional, Type, TYPE_CHECKING, TypeVar - -import requests - -from .client_types import StepResult, StateT -from .containers.runtime import LocalDockerProvider - -if TYPE_CHECKING: - from .containers.runtime import ContainerProvider - -ActT = TypeVar("ActT") -ObsT = TypeVar("ObsT") -EnvClientT = TypeVar("EnvClientT", bound="HTTPEnvClient") - - -class HTTPEnvClient(ABC, Generic[ActT, ObsT, StateT]): - def __init__( - self, - base_url: str, - request_timeout_s: float = 15.0, - default_headers: Optional[Dict[str, str]] = None, - provider: Optional["ContainerProvider"] = None, - ): - self._base = base_url.rstrip("/") - self._timeout = float(request_timeout_s) - self._http = requests.Session() - self._headers = default_headers or {} - self._provider = provider - - @classmethod - def from_docker_image( - cls: Type[EnvClientT], - image: str, - provider: Optional["ContainerProvider"] = None, - **kwargs: Any, - ) -> EnvClientT: - """ - Create an environment client by spinning up a Docker container locally. - - This is a development utility that: - 1. Starts a Docker container from the specified image - 2. Waits for the server to be ready - 3. Creates and returns a client instance connected to the container - - Note: The container lifecycle management is left to the user or higher-level - orchestration. The container will keep running until manually stopped. 
- - Args: - image: Docker image name to run (e.g., "echo-env:latest") - provider: Container provider to use (defaults to LocalDockerProvider) - **kwargs: Additional arguments to pass to provider.start_container() - (e.g., env_vars, port) - - Returns: - An instance of the client class connected to the running container - - Example: - >>> from envs.coding_env.client import CodingEnv - >>> from envs.coding_env.models import CodeAction - >>> - >>> # Create environment from image - >>> env = CodingEnv.from_docker_image("coding-env:latest") - >>> - >>> # Create environment with custom env vars - >>> env = CodingEnv.from_docker_image( - ... "coding-env:latest", - ... env_vars={"MY_VAR": "value"} - ... ) - >>> - >>> # Use the environment - >>> result = env.reset() - >>> print(result.observation) - >>> - >>> step_result = env.step(CodeAction(code="print('hello')")) - >>> print(step_result.observation.stdout) - >>> - >>> # Cleanup (optional) - >>> env.close() - """ - - # Use default provider if none provided - if provider is None: - provider = LocalDockerProvider() - - # 1. Start container with optional kwargs (e.g., env_vars, port) - base_url = provider.start_container(image, **kwargs) - - # 2. Wait for server to be ready - provider.wait_for_ready(base_url) - - # 3. Create and return client instance with provider reference - return cls(base_url=base_url, provider=provider) - - @classmethod - def from_hub( - cls: Type[EnvClientT], - repo_id: str, - provider: Optional["ContainerProvider"] = None, - **kwargs: Any, - ) -> EnvClientT: - """ - Create an environment client by pulling from a Hugging Face model hub. 
- """ - - if provider is None: - provider = LocalDockerProvider() - - if "tag" in kwargs: - tag = kwargs["tag"] - else: - tag = "latest" - - base_url = f"registry.hf.space/{repo_id.replace('/', '-')}:{tag}" - - return cls.from_docker_image(image=base_url, provider=provider) - - @abstractmethod - def _step_payload(self, action: ActT) -> Dict[str, Any]: - """Convert an Action object to the JSON body expected by the env server.""" - raise NotImplementedError - - @abstractmethod - def _parse_result(self, payload: Dict[str, Any]) -> StepResult[ObsT]: - """Convert a JSON response from the env server to StepResult[ObsT].""" - raise NotImplementedError - - @abstractmethod - def _parse_state(self, payload: Dict[str, Any]) -> StateT: - """Convert a JSON response from the state endpoint to a State object.""" - raise NotImplementedError - - # ---------- Environment Server Interface Methods ---------- - def reset(self, **kwargs: Any) -> StepResult[ObsT]: - """ - Reset the environment with optional parameters. - - Args: - **kwargs: Optional parameters passed to the environment's reset method. - Common parameters include: - - seed: Random seed for reproducibility - - episode_id: Custom episode identifier - - Any environment-specific reset parameters - - Returns: - StepResult containing initial observation - - Example: - >>> env.reset(seed=42, episode_id="ep-001") - """ - body: Dict[str, Any] = kwargs.copy() - r = self._http.post( - f"{self._base}/reset", - json=body, - headers=self._headers, - timeout=self._timeout, - ) - r.raise_for_status() - return self._parse_result(r.json()) - - def step(self, action: ActT, **kwargs: Any) -> StepResult[ObsT]: - """ - Execute an action in the environment with optional parameters. - - Args: - action: The action to execute - **kwargs: Optional parameters passed to the environment's step method. 
- Common parameters include: - - timeout_s: Execution timeout in seconds - - request_id: Request identifier for tracking - - render: Whether to render the environment - - Any environment-specific step parameters - - Returns: - StepResult containing observation, reward, and done status - - Example: - >>> env.step(action, timeout_s=30.0, request_id="req-123", render=True) - """ - body: Dict[str, Any] = { - "action": self._step_payload(action), - **kwargs # Forward all additional parameters - } - r = self._http.post( - f"{self._base}/step", - json=body, - headers=self._headers, - timeout=self._timeout, - ) - r.raise_for_status() - return self._parse_result(r.json()) - - def state(self) -> StateT: - """ - Get the current environment state from the server. - - Returns: - State object with environment state information (e.g., episode_id, step_count) - - Example: - >>> client = EchoEnv.from_docker_image("echo-env:latest") - >>> result = client.reset() - >>> state = client.state() - >>> print(state.episode_id) - >>> print(state.step_count) - """ - r = self._http.get( - f"{self._base}/state", - headers=self._headers, - timeout=self._timeout, - ) - r.raise_for_status() - return self._parse_state(r.json()) - - def close(self) -> None: - """ - Close the environment and clean up resources. - - If this client was created via from_docker_image(), this will stop - and remove the associated container. 
- """ - if self._provider is not None: - self._provider.stop_container() From c42781254ada0c4c964b233541d5d901ba3626ef Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Thu, 18 Dec 2025 08:50:07 +0100 Subject: [PATCH 34/41] fix websocket ui --- src/openenv/core/env_server/web_interface.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/openenv/core/env_server/web_interface.py b/src/openenv/core/env_server/web_interface.py index 703025375..fe2a1aee2 100644 --- a/src/openenv/core/env_server/web_interface.py +++ b/src/openenv/core/env_server/web_interface.py @@ -283,9 +283,14 @@ async def web_metadata(): """Get environment metadata.""" return web_manager.metadata.model_dump() - @app.websocket("/ws") - async def websocket_endpoint(websocket: WebSocket): - """WebSocket endpoint for real-time updates.""" + @app.websocket("/ws/ui") + async def websocket_ui_endpoint(websocket: WebSocket): + """WebSocket endpoint for web UI real-time updates. + + Note: This endpoint is separate from /ws which is used for + concurrent environment sessions. This endpoint is specifically + for the web interface state updates. + """ await web_manager.connect_websocket(websocket) try: while True: @@ -943,7 +948,7 @@ class OpenEnvWebInterface {{ connectWebSocket() {{ const protocol = window.location.protocol === 'https:' ? 
'wss:' : 'ws:'; - const wsUrl = `${{protocol}}//${{window.location.host}}/ws`; + const wsUrl = `${{protocol}}//${{window.location.host}}/ws/ui`; this.ws = new WebSocket(wsUrl); From cde46608a9ca46af66fd3f0f31ec952a58107d6f Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Thu, 18 Dec 2025 08:52:08 +0100 Subject: [PATCH 35/41] update docs in environment builder to use ws --- docs/environment-builder.md | 82 ++++++++++++++++++++++++++----------- 1 file changed, 57 insertions(+), 25 deletions(-) diff --git a/docs/environment-builder.md b/docs/environment-builder.md index 9fefc9ee1..4e9728344 100644 --- a/docs/environment-builder.md +++ b/docs/environment-builder.md @@ -10,7 +10,7 @@ A typical workflow looks like: 1. Scaffold a new environment with `openenv init`. 2. Customize your models, environment logic, and FastAPI server. -3. Implement a typed `HTTPEnvClient`. +3. Implement a typed `EnvClient` (WebSocket-based for persistent sessions). 4. Configure dependencies and the Dockerfile once. 5. Use the CLI (`openenv build`, `openenv validate`, `openenv push`) to package and share your work. @@ -119,29 +119,52 @@ class MyEnvironment(Environment): ### 4. Create the FastAPI Server -`server/app.py` should expose the environment through `create_fastapi_app`: +`server/app.py` should expose the environment through `create_app`. 
+ +**Important:** You must pass a class or factory function (not an instance) to enable WebSocket-based concurrent sessions: ```python # server/app.py -from openenv.core.env_server import create_fastapi_app +from openenv.core.env_server import create_app from ..models import MyAction, MyObservation from .my_environment import MyEnvironment -env = MyEnvironment() -app = create_fastapi_app(env, MyAction, MyObservation) +# Pass the class (factory) - each WebSocket session gets its own instance +app = create_app(MyEnvironment, MyAction, MyObservation, env_name="my_env") +``` + +For environments with constructor arguments, create a factory function: + +```python +# server/app.py +import os +from openenv.core.env_server import create_app +from ..models import MyAction, MyObservation +from .my_environment import MyEnvironment + +# Read config from environment variables +api_key = os.getenv("MY_API_KEY") +timeout = int(os.getenv("MY_TIMEOUT", "30")) + +def create_my_environment(): + """Factory function that creates MyEnvironment with config.""" + return MyEnvironment(api_key=api_key, timeout=timeout) + +# Pass the factory function +app = create_app(create_my_environment, MyAction, MyObservation, env_name="my_env") ``` ### 5. 
Implement the Client -`client.py` extends `HTTPEnvClient` so users can interact with your server over HTTP or Docker: +`client.py` extends `EnvClient` so users can interact with your server via WebSocket for persistent sessions: ```python # client.py -from openenv.core.http_env_client import HTTPEnvClient -from openenv.core.types import StepResult +from openenv.core.env_client import EnvClient +from openenv.core.client_types import StepResult from .models import MyAction, MyObservation, MyState -class MyEnv(HTTPEnvClient[MyAction, MyObservation]): +class MyEnv(EnvClient[MyAction, MyObservation, MyState]): def _step_payload(self, action: MyAction) -> dict: return {"command": action.command, "parameters": action.parameters} @@ -157,6 +180,8 @@ class MyEnv(HTTPEnvClient[MyAction, MyObservation]): return MyState(**payload) ``` +The `EnvClient` maintains a persistent WebSocket connection to the server, enabling efficient multi-step interactions with lower latency compared to HTTP. Each client instance gets its own dedicated environment session on the server. + ### 6. Configure Dependencies & Dockerfile The CLI template ships with `pyproject.toml` and `server/Dockerfile`. You should manage your python dependencies with `uv` or `pip` in the `pyproject.toml` file. Other dependencies should be installed in the Dockerfile. 
@@ -322,22 +347,29 @@ client = MyEnv.from_hub("my-org/my-env") # Or, connect to the local server client = MyEnv(base_url="http://localhost:8000") -# Reset -result = client.reset() -print(result.observation.result) # "Ready" - -# Execute actions -result = client.step(MyAction(command="test", parameters={})) -print(result.observation.result) -print(result.observation.success) - -# Get state -state = client.state() -print(state.episode_id) -print(state.step_count) - -# Cleanup -client.close() +# Use context manager for automatic cleanup (recommended) +with client: + # Reset + result = client.reset() + print(result.observation.result) # "Ready" + + # Execute actions + result = client.step(MyAction(command="test", parameters={})) + print(result.observation.result) + print(result.observation.success) + + # Get state + state = client.state() + print(state.episode_id) + print(state.step_count) + +# Or manually manage the connection +try: + client = MyEnv(base_url="http://localhost:8000") + result = client.reset() + result = client.step(MyAction(command="test", parameters={})) +finally: + client.close() ``` ## Nice work! You've now built and used your own OpenEnv environment. 
From c41c826d2bc79972575dd45ef5f23ed248cba9ab Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Thu, 18 Dec 2025 08:52:42 +0100 Subject: [PATCH 36/41] formatting in web interface --- src/openenv/core/env_server/web_interface.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/src/openenv/core/env_server/web_interface.py b/src/openenv/core/env_server/web_interface.py index fe2a1aee2..210a7804b 100644 --- a/src/openenv/core/env_server/web_interface.py +++ b/src/openenv/core/env_server/web_interface.py @@ -134,9 +134,7 @@ def __init__( name=env.__class__.__name__, description=f"{env.__class__.__name__} environment", ) - self.episode_state = EpisodeState( - episode_id=None, step_count=0, current_observation=None, action_logs=[] - ) + self.episode_state = EpisodeState(episode_id=None, step_count=0, current_observation=None, action_logs=[]) self.connected_clients: List[WebSocket] = [] async def connect_websocket(self, websocket: WebSocket): @@ -262,7 +260,7 @@ def create_web_interface_app( # Create the base environment app app = create_fastapi_app(env, action_cls, observation_cls, max_concurrent_envs, concurrency_config) - + # Create a test instance for metadata env_instance = env() @@ -286,7 +284,7 @@ async def web_metadata(): @app.websocket("/ws/ui") async def websocket_ui_endpoint(websocket: WebSocket): """WebSocket endpoint for web UI real-time updates. - + Note: This endpoint is separate from /ws which is used for concurrent environment sessions. This endpoint is specifically for the web interface state updates. 
@@ -1329,11 +1327,7 @@ def _determine_input_type_from_schema(field_info: Dict[str, Any], field_name: st if schema_type == "string": # Check if it should be a textarea - if ( - field_info.get("maxLength", 0) > 100 - or "message" in field_name.lower() - or "code" in field_name.lower() - ): + if field_info.get("maxLength", 0) > 100 or "message" in field_name.lower() or "code" in field_name.lower(): return "textarea" return "text" From 56f8922be99737a5f0f04f5a3006d7ecbf0b2206 Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Thu, 18 Dec 2025 09:40:48 +0100 Subject: [PATCH 37/41] update all envs to use factory method --- envs/atari_env/client.py | 34 +++--- envs/atari_env/server/app.py | 29 ++--- envs/browsergym_env/client.py | 9 +- envs/browsergym_env/models.py | 4 - envs/browsergym_env/server/app.py | 25 +++-- envs/chat_env/client.py | 44 ++++---- envs/chat_env/server/app.py | 16 +-- envs/coding_env/client.py | 13 +-- envs/coding_env/server/app.py | 8 +- envs/coding_env/server/python_codeact_env.py | 4 +- envs/connect4_env/client.py | 27 +++-- envs/connect4_env/models.py | 16 +-- envs/connect4_env/server/app.py | 9 +- envs/dipg_safety_env/client.py | 14 +-- envs/dipg_safety_env/server/app.py | 50 +++++---- .../server/dipg_environment.py | 2 +- envs/echo_env/client.py | 39 +++---- envs/echo_env/models.py | 4 - envs/echo_env/server/app.py | 8 +- envs/finrl_env/client.py | 102 ++++++++---------- envs/finrl_env/server/app.py | 15 ++- envs/git_env/client.py | 50 ++++----- envs/git_env/server/app.py | 21 ++-- envs/openspiel_env/client.py | 34 +++--- envs/openspiel_env/server/app.py | 21 ++-- envs/sumo_rl_env/client.py | 45 ++++---- envs/sumo_rl_env/server/app.py | 34 +++--- envs/textarena_env/client.py | 8 +- envs/textarena_env/server/app.py | 25 +++-- 29 files changed, 372 insertions(+), 338 deletions(-) diff --git a/envs/atari_env/client.py b/envs/atari_env/client.py index cbdb373f5..458895454 100644 --- a/envs/atari_env/client.py +++ b/envs/atari_env/client.py @@ -5,10 
+5,10 @@ # LICENSE file in the root directory of this source tree. """ -Atari Environment HTTP Client. +Atari Environment Client. This module provides the client for connecting to an Atari Environment server -over HTTP. +via WebSocket for persistent sessions. """ from __future__ import annotations @@ -17,7 +17,7 @@ from openenv.core.client_types import StepResult -from openenv.core.http_env_client import HTTPEnvClient +from openenv.core.env_client import EnvClient from .models import AtariAction, AtariObservation, AtariState @@ -25,28 +25,30 @@ from openenv.core.containers.runtime import ContainerProvider -class AtariEnv(HTTPEnvClient[AtariAction, AtariObservation]): +class AtariEnv(EnvClient[AtariAction, AtariObservation, AtariState]): """ - HTTP client for Atari Environment. + Client for Atari Environment. - This client connects to an AtariEnvironment HTTP server and provides - methods to interact with it: reset(), step(), and state access. + This client maintains a persistent WebSocket connection to the environment + server, enabling efficient multi-step interactions with lower latency. Example: >>> # Connect to a running server - >>> client = AtariEnv(base_url="http://localhost:8000") - >>> result = client.reset() - >>> print(result.observation.screen_shape) - >>> - >>> # Take an action - >>> result = client.step(AtariAction(action_id=2)) # UP - >>> print(result.reward, result.done) + >>> with AtariEnv(base_url="http://localhost:8000") as client: + ... result = client.reset() + ... print(result.observation.screen_shape) + ... + ... result = client.step(AtariAction(action_id=2)) # UP + ... print(result.reward, result.done) Example with Docker: >>> # Automatically start container and connect >>> client = AtariEnv.from_docker_image("atari-env:latest") - >>> result = client.reset() - >>> result = client.step(AtariAction(action_id=0)) # NOOP + >>> try: + ... result = client.reset() + ... result = client.step(AtariAction(action_id=0)) # NOOP + ... finally: + ... 
client.close() """ def _step_payload(self, action: AtariAction) -> Dict[str, Any]: diff --git a/envs/atari_env/server/app.py b/envs/atari_env/server/app.py index 14254f6d9..036e44ef3 100644 --- a/envs/atari_env/server/app.py +++ b/envs/atari_env/server/app.py @@ -8,7 +8,7 @@ FastAPI application for the Atari Environment. This module creates an HTTP server that exposes Atari games -over HTTP endpoints, making them compatible with HTTPEnvClient. +over HTTP and WebSocket endpoints, compatible with EnvClient. Usage: # Development (with auto-reload): @@ -52,19 +52,24 @@ mode = int(mode) if mode is not None else None difficulty = int(difficulty) if difficulty is not None else None -# Create the environment instance -env = AtariEnvironment( - game_name=game_name, - obs_type=obs_type, - full_action_space=full_action_space, - mode=mode, - difficulty=difficulty, - repeat_action_probability=repeat_action_prob, - frameskip=frameskip, -) + +# Factory function to create AtariEnvironment instances +def create_atari_environment(): + """Factory function that creates AtariEnvironment with config.""" + return AtariEnvironment( + game_name=game_name, + obs_type=obs_type, + full_action_space=full_action_space, + mode=mode, + difficulty=difficulty, + repeat_action_probability=repeat_action_prob, + frameskip=frameskip, + ) + # Create the FastAPI app with web interface and README integration -app = create_app(env, AtariAction, AtariObservation, env_name="atari_env") +# Pass the factory function instead of an instance for WebSocket session support +app = create_app(create_atari_environment, AtariAction, AtariObservation, env_name="atari_env") if __name__ == "__main__": diff --git a/envs/browsergym_env/client.py b/envs/browsergym_env/client.py index 5b6d3772d..cb7437f9d 100644 --- a/envs/browsergym_env/client.py +++ b/envs/browsergym_env/client.py @@ -1,8 +1,9 @@ -"""HTTP client for the BrowserGym environment.""" +"""Client for the BrowserGym environment.""" from typing import Any, Dict 
-from openenv.core.http_env_client import HTTPEnvClient, StepResult +from openenv.core.client_types import StepResult +from openenv.core.env_client import EnvClient from .models import ( BrowserGymAction, BrowserGymObservation, @@ -10,8 +11,8 @@ ) -class BrowserGymEnv(HTTPEnvClient[BrowserGymAction, BrowserGymObservation]): - """Client for interacting with the BrowserGym environment over HTTP. +class BrowserGymEnv(EnvClient[BrowserGymAction, BrowserGymObservation, BrowserGymState]): + """Client for interacting with the BrowserGym environment. BrowserGym provides unified access to multiple web navigation benchmarks: - MiniWoB++: 100+ training tasks (no external infrastructure needed!) diff --git a/envs/browsergym_env/models.py b/envs/browsergym_env/models.py index f62bcf773..c783abc0b 100644 --- a/envs/browsergym_env/models.py +++ b/envs/browsergym_env/models.py @@ -5,13 +5,11 @@ and more under a single Gymnasium-compatible API. """ -from dataclasses import dataclass from typing import List, Optional from openenv.core.env_server.types import Action, Observation, State -@dataclass(kw_only=True) class BrowserGymAction(Action): """Action to be executed in the BrowserGym environment. @@ -30,7 +28,6 @@ class BrowserGymAction(Action): """Natural language action string (e.g., "click('Submit')")""" -@dataclass(kw_only=True) class BrowserGymObservation(Observation): """Observation returned from the BrowserGym environment. @@ -63,7 +60,6 @@ class BrowserGymObservation(Observation): """Whether the last action resulted in an error""" -@dataclass class BrowserGymState(State): """State of the BrowserGym environment. 
diff --git a/envs/browsergym_env/server/app.py b/envs/browsergym_env/server/app.py index 488b66974..fa8214dc3 100644 --- a/envs/browsergym_env/server/app.py +++ b/envs/browsergym_env/server/app.py @@ -15,19 +15,24 @@ timeout = float(os.environ.get("BROWSERGYM_TIMEOUT", "10000")) port = int(os.environ.get("BROWSERGYM_PORT", "8000")) -# Create the environment instance -env = BrowserGymEnvironment( - benchmark=benchmark, - task_name=task_name, - headless=headless, - viewport_width=viewport_width, - viewport_height=viewport_height, - timeout=timeout, -) + +# Factory function to create BrowserGymEnvironment instances +def create_browsergym_environment(): + """Factory function that creates BrowserGymEnvironment with config.""" + return BrowserGymEnvironment( + benchmark=benchmark, + task_name=task_name, + headless=headless, + viewport_width=viewport_width, + viewport_height=viewport_height, + timeout=timeout, + ) + # Create the FastAPI app +# Pass the factory function instead of an instance for WebSocket session support app = create_app( - env, + create_browsergym_environment, BrowserGymAction, BrowserGymObservation, env_name="browsergym_env", diff --git a/envs/chat_env/client.py b/envs/chat_env/client.py index d14829f74..a1b265cd4 100644 --- a/envs/chat_env/client.py +++ b/envs/chat_env/client.py @@ -5,10 +5,10 @@ # LICENSE file in the root directory of this source tree. """ -Chat Environment HTTP Client. +Chat Environment Client. This module provides the client for connecting to a Chat Environment server -over HTTP. +via WebSocket for persistent sessions. 
""" from typing import Any, Dict @@ -17,40 +17,42 @@ from openenv.core.client_types import StepResult from openenv.core.env_server.interfaces import Message -from openenv.core.env_server.types import State -from openenv.core.http_env_client import HTTPEnvClient +from openenv.core.env_client import EnvClient from .models import ChatAction, ChatObservation, ChatState -class ChatEnv(HTTPEnvClient[ChatAction, ChatObservation]): +class ChatEnv(EnvClient[ChatAction, ChatObservation, ChatState]): """ - HTTP client for the Chat Environment. + Client for the Chat Environment. - This client connects to a ChatEnvironment HTTP server and provides - methods to interact with it: reset(), step(), and state access. + This client maintains a persistent WebSocket connection to the environment + server, enabling efficient multi-step interactions with lower latency. - Note: Since ChatEnvironment works with PyTorch tensors, the HTTP layer + Note: Since ChatEnvironment works with PyTorch tensors, the client serializes tokens as lists for transport and deserializes them back to tensors. Example: >>> # Connect to a running server - >>> client = ChatEnv(base_url="http://localhost:8000") - >>> result = client.reset() - >>> print(result.observation.messages) - >>> - >>> # Send an action with tokens - >>> import torch - >>> tokens = torch.tensor([[1, 2, 3, 4, 5]]) - >>> result = client.step(ChatAction(tokens=tokens)) - >>> print(result.observation.messages) - >>> print(result.reward) + >>> with ChatEnv(base_url="http://localhost:8000") as client: + ... result = client.reset() + ... print(result.observation.messages) + ... + ... # Send an action with tokens + ... import torch + ... tokens = torch.tensor([[1, 2, 3, 4, 5]]) + ... result = client.step(ChatAction(tokens=tokens)) + ... print(result.observation.messages) + ... 
print(result.reward) Example with Docker: >>> # Automatically start container and connect >>> client = ChatEnv.from_docker_image("chat-env:latest") - >>> result = client.reset() - >>> result = client.step(ChatAction(tokens=torch.tensor([[1, 2, 3]]))) + >>> try: + ... result = client.reset() + ... result = client.step(ChatAction(tokens=torch.tensor([[1, 2, 3]]))) + ... finally: + ... client.close() """ def _step_payload(self, action: ChatAction) -> Dict: diff --git a/envs/chat_env/server/app.py b/envs/chat_env/server/app.py index 719b5ede8..88b9694f7 100644 --- a/envs/chat_env/server/app.py +++ b/envs/chat_env/server/app.py @@ -8,7 +8,7 @@ FastAPI application for the Chat Environment. This module creates an HTTP server that exposes the ChatEnvironment -over HTTP endpoints, making it compatible with HTTPEnvClient. +over HTTP and WebSocket endpoints, compatible with EnvClient. Note: This server requires a tokenizer to be initialized. The tokenizer must be specified when starting the server. 
@@ -27,7 +27,6 @@ import os from openenv.core.env_server import create_app -from openenv.core.env_server.web_interface import create_web_interface_app from ..models import ChatAction, ChatObservation from .chat_environment import ChatEnvironment @@ -64,12 +63,17 @@ def get_tokenizer(): # Get system prompt from environment system_prompt = os.environ.get("SYSTEM_PROMPT", None) -# Create the environment instance with tokenizer -tokenizer = get_tokenizer() -env = ChatEnvironment(tokenizer=tokenizer, system_prompt=system_prompt) + +# Factory function to create ChatEnvironment instances +def create_chat_environment(): + """Factory function that creates ChatEnvironment with tokenizer.""" + tokenizer = get_tokenizer() + return ChatEnvironment(tokenizer=tokenizer, system_prompt=system_prompt) + # Create the FastAPI app with web interface and README integration -app = create_app(env, ChatAction, ChatObservation, env_name="chat_env") +# Pass the factory function instead of an instance for WebSocket session support +app = create_app(create_chat_environment, ChatAction, ChatObservation, env_name="chat_env") if __name__ == "__main__": diff --git a/envs/coding_env/client.py b/envs/coding_env/client.py index 544b6a6e0..a05db092e 100644 --- a/envs/coding_env/client.py +++ b/envs/coding_env/client.py @@ -2,11 +2,13 @@ CodingEnv --------- Client-side wrapper for the Coding environment server. -Talks HTTP to a single base_url exposing: /reset and /step. + +This client maintains a persistent WebSocket connection to the environment +server, enabling efficient multi-step interactions with lower latency. - users instantiate CodingEnv with a base_url provided by the higher-level vector/orchestration layer. -- Environment authors ship the Docker image that serves the HTTP API. +- Environment authors ship the Docker image that serves the API. (Seeds, episode IDs, request IDs, capabilities can be added later in the payloads.) 
""" @@ -14,13 +16,12 @@ from __future__ import annotations from openenv.core.client_types import StepResult +from openenv.core.env_client import EnvClient -from openenv.core.http_env_client import HTTPEnvClient - -from coding_env.models import CodeAction, CodeObservation, CodeState +from .models import CodeAction, CodeObservation, CodeState -class CodingEnv(HTTPEnvClient[CodeAction, CodeObservation]): +class CodingEnv(EnvClient[CodeAction, CodeObservation, CodeState]): # --- HTTPEnvClient abstract hooks --- def _step_payload(self, action: CodeAction) -> dict: diff --git a/envs/coding_env/server/app.py b/envs/coding_env/server/app.py index b636d0784..4859585fa 100644 --- a/envs/coding_env/server/app.py +++ b/envs/coding_env/server/app.py @@ -8,7 +8,7 @@ FastAPI application for the Coding Environment. This module creates an HTTP server that exposes the PythonCodeActEnv -over HTTP endpoints, making it compatible with HTTPEnvClient. +over HTTP and WebSocket endpoints, compatible with EnvClient. 
Usage: # Development (with auto-reload): @@ -26,11 +26,9 @@ from coding_env.models import CodeAction, CodeObservation from coding_env.server.python_codeact_env import PythonCodeActEnv -# Create the environment instance -env = PythonCodeActEnv() - # Create the app with web interface and README integration -app = create_app(env, CodeAction, CodeObservation, env_name="coding_env") +# Pass the class (factory) instead of an instance for WebSocket session support +app = create_app(PythonCodeActEnv, CodeAction, CodeObservation, env_name="coding_env") if __name__ == "__main__": diff --git a/envs/coding_env/server/python_codeact_env.py b/envs/coding_env/server/python_codeact_env.py index ed95135d1..a73ed1e55 100644 --- a/envs/coding_env/server/python_codeact_env.py +++ b/envs/coding_env/server/python_codeact_env.py @@ -14,9 +14,9 @@ import uuid from openenv.core.env_server.interfaces import Action, Environment, Observation -from coding_env.server.python_executor import PyExecutor +from .python_executor import PyExecutor -from coding_env.models import CodeAction, CodeObservation, CodeState +from ..models import CodeAction, CodeObservation, CodeState from .transforms import create_safe_coding_transform diff --git a/envs/connect4_env/client.py b/envs/connect4_env/client.py index a462929a0..d9f6c2165 100644 --- a/envs/connect4_env/client.py +++ b/envs/connect4_env/client.py @@ -5,10 +5,10 @@ # LICENSE file in the root directory of this source tree. """ -Connect4 Environment HTTP Client. +Connect4 Environment Client. This module provides the client for connecting to a Connect4 Environment server -over HTTP. +via WebSocket for persistent sessions. 
""" from __future__ import annotations @@ -16,7 +16,7 @@ from typing import Any, Dict, TYPE_CHECKING from openenv.core.client_types import StepResult -from openenv.core.http_env_client import HTTPEnvClient +from openenv.core.env_client import EnvClient from .models import Connect4Action, Connect4Observation, Connect4State @@ -24,21 +24,20 @@ from openenv.core.containers.runtime import ContainerProvider -class Connect4Env(HTTPEnvClient[Connect4Action, Connect4Observation]): +class Connect4Env(EnvClient[Connect4Action, Connect4Observation, Connect4State]): """ - HTTP client for Connect4 Environment. + Client for Connect4 Environment. - This client connects to a Connect4Environment HTTP server and provides - methods to interact with it: reset(), step(), and state access. + This client maintains a persistent WebSocket connection to the environment + server, enabling efficient multi-step interactions with lower latency. Example: - >>> client = Connect4Env(base_url="http://localhost:8000") - >>> result = client.reset() - >>> print(result.observation.board) - >>> - >>> # Take an action - >>> result = client.step(Connect4Action(column=3)) - >>> print(result.reward, result.done) + >>> with Connect4Env(base_url="http://localhost:8000") as client: + ... result = client.reset() + ... print(result.observation.board) + ... + ... result = client.step(Connect4Action(column=3)) + ... 
print(result.reward, result.done) """ def _step_payload(self, action: Connect4Action) -> Dict[str, Any]: diff --git a/envs/connect4_env/models.py b/envs/connect4_env/models.py index 8cf3309a8..4d1109c2d 100644 --- a/envs/connect4_env/models.py +++ b/envs/connect4_env/models.py @@ -12,14 +12,12 @@ """ from __future__ import annotations -from dataclasses import dataclass, field -import numpy as np -from typing import List +from typing import List, Dict, Any +from pydantic import Field from openenv.core.env_server import Action, Observation, State -@dataclass class Connect4Action(Action): """ Action for Connect4 environment. @@ -30,7 +28,6 @@ class Connect4Action(Action): column: int -@dataclass(kw_only=True) class Connect4Observation(Observation): """ Observation for Connect4 environment. @@ -45,13 +42,8 @@ class Connect4Observation(Observation): board: List[List[int]] legal_actions: List[int] - done: bool = False - reward: float = 0.0 - metadata: dict = field(default_factory=dict) - -@dataclass(kw_only=True) class Connect4State(State): """ State for Connect4 environment. @@ -62,7 +54,5 @@ class Connect4State(State): next_player: Whose turn it is (1 or -1). step_count: Number of steps taken in the game. 
""" - episode_id: str - board: List[List[int]] = field(default_factory=lambda: np.zeros((6,7), dtype=int).tolist()) + board: List[List[int]] = Field(default_factory=lambda: [[0]*7 for _ in range(6)]) next_player: int = 1 - step_count: int = 0 diff --git a/envs/connect4_env/server/app.py b/envs/connect4_env/server/app.py index 143ee1770..2025b2c37 100644 --- a/envs/connect4_env/server/app.py +++ b/envs/connect4_env/server/app.py @@ -1,9 +1,12 @@ -from openenv.core.env_server import create_fastapi_app +"""FastAPI application for the Connect4 Environment.""" + +from openenv.core.env_server import create_app from ..models import Connect4Action, Connect4Observation from .connect4_environment import Connect4Environment -env = Connect4Environment() -app = create_fastapi_app(env, Connect4Action, Connect4Observation) +# Create the FastAPI app +# Pass the class (factory) instead of an instance for WebSocket session support +app = create_app(Connect4Environment, Connect4Action, Connect4Observation, env_name="connect4_env") if __name__ == "__main__": diff --git a/envs/dipg_safety_env/client.py b/envs/dipg_safety_env/client.py index 9e556481f..2d11503b3 100644 --- a/envs/dipg_safety_env/client.py +++ b/envs/dipg_safety_env/client.py @@ -3,22 +3,24 @@ Client implementation for the custom DIPGSafetyEnv. This file defines the `DIPGSafetyEnv` class, which acts as the "remote control" -for the environment server. Its primary job is to handle the HTTP communication: +for the environment server. It maintains a persistent WebSocket connection +for efficient multi-step interactions: 1. It takes Python objects (like an Action) from the agent's code. 2. It converts them into JSON to send to the server. 3. It receives JSON responses from the server. 4. It parses that JSON back into useful Python objects (like Observations and Rewards). 
""" -from openenv.core.http_env_client import HTTPEnvClient, StepResult +from openenv.core.client_types import StepResult +from openenv.core.env_client import EnvClient from .models import DIPGAction, DIPGObservation, DIPGState -class DIPGSafetyEnv(HTTPEnvClient[DIPGAction, DIPGObservation]): +class DIPGSafetyEnv(EnvClient[DIPGAction, DIPGObservation, DIPGState]): """ Client for interacting with the `DIPGSafetyEnv` server. - This class inherits from the base `HTTPEnvClient` and is specialized to handle + This class inherits from the base `EnvClient` and is specialized to handle the specific data types of our environment: `DIPGAction` and `DIPGObservation`. """ @@ -31,8 +33,8 @@ def __init__(self, base_url: str, timeout: float = 60.0): timeout: The number of seconds to wait for a server response. """ # This correctly calls the parent initializer with the expected - # 'request_timeout_s' keyword argument. - super().__init__(base_url=base_url, request_timeout_s=timeout) + # 'message_timeout_s' keyword argument. + super().__init__(base_url=base_url, message_timeout_s=timeout) # ---------------------------------------- def _step_payload(self, action: DIPGAction) -> dict: diff --git a/envs/dipg_safety_env/server/app.py b/envs/dipg_safety_env/server/app.py index 5c079d171..2e8c524cc 100644 --- a/envs/dipg_safety_env/server/app.py +++ b/envs/dipg_safety_env/server/app.py @@ -1,4 +1,11 @@ # envs/dipg_safety_env/server/app.py +""" +FastAPI application for the DIPG Safety Environment. + +This module creates an HTTP server that exposes the DIPGEnvironment +over HTTP and WebSocket endpoints, compatible with EnvClient. +""" + import os from openenv.core.env_server import create_app from .dipg_environment import DIPGEnvironment @@ -24,22 +31,27 @@ FINAL_CHANNEL_START = os.environ.get("FINAL_CHANNEL_START", "<|channel|>final<|message|>") CHANNEL_END = os.environ.get("CHANNEL_END", "<|end|>") -# Create the environment instance, passing the path and rewards to it. 
-env = DIPGEnvironment( - dataset_path=DATASET_PATH, - conflict_reward=CONFLICT_REWARD, - conflict_penalty=CONFLICT_PENALTY, - abstain_reward=ABSTAIN_REWARD, - abstain_penalty=ABSTAIN_PENALTY, - format_mismatch_penalty=FORMAT_MISMATCH_PENALTY, - exact_format_reward=EXACT_FORMAT_REWARD, - hallucination_penalty=HALLUCINATION_PENALTY, - no_hallucination_reward=NO_HALLUCINATION_REWARD, - missing_answer_penalty=MISSING_ANSWER_PENALTY, - analysis_channel_start=ANALYSIS_CHANNEL_START, - final_channel_start=FINAL_CHANNEL_START, - channel_end=CHANNEL_END, -) - -# The rest is the same. -app = create_app(env, DIPGAction, DIPGObservation, env_name="dipg_safety_env") \ No newline at end of file + +# Factory function to create DIPGEnvironment instances +def create_dipg_environment(): + """Factory function that creates DIPGEnvironment with config.""" + return DIPGEnvironment( + dataset_path=DATASET_PATH, + conflict_reward=CONFLICT_REWARD, + conflict_penalty=CONFLICT_PENALTY, + abstain_reward=ABSTAIN_REWARD, + abstain_penalty=ABSTAIN_PENALTY, + format_mismatch_penalty=FORMAT_MISMATCH_PENALTY, + exact_format_reward=EXACT_FORMAT_REWARD, + hallucination_penalty=HALLUCINATION_PENALTY, + no_hallucination_reward=NO_HALLUCINATION_REWARD, + missing_answer_penalty=MISSING_ANSWER_PENALTY, + analysis_channel_start=ANALYSIS_CHANNEL_START, + final_channel_start=FINAL_CHANNEL_START, + channel_end=CHANNEL_END, + ) + + +# Create the FastAPI app +# Pass the factory function instead of an instance for WebSocket session support +app = create_app(create_dipg_environment, DIPGAction, DIPGObservation, env_name="dipg_safety_env") \ No newline at end of file diff --git a/envs/dipg_safety_env/server/dipg_environment.py b/envs/dipg_safety_env/server/dipg_environment.py index f154c7db6..70a7e5a7b 100644 --- a/envs/dipg_safety_env/server/dipg_environment.py +++ b/envs/dipg_safety_env/server/dipg_environment.py @@ -3,7 +3,7 @@ import json import random from pathlib import Path -from 
openenv.core.http_env_client import StepResult +from openenv.core.client_types import StepResult from openenv.core.env_server import Environment from ..models import DIPGAction, DIPGObservation, DIPGState import re diff --git a/envs/echo_env/client.py b/envs/echo_env/client.py index fcb82e5ca..9c7ee2c64 100644 --- a/envs/echo_env/client.py +++ b/envs/echo_env/client.py @@ -5,10 +5,10 @@ # LICENSE file in the root directory of this source tree. """ -Echo Environment HTTP Client. +Echo Environment Client. This module provides the client for connecting to an Echo Environment server -over HTTP. +via WebSocket for persistent sessions. """ from typing import Any, Dict @@ -18,39 +18,42 @@ # In-repo imports (when running from OpenEnv repository) from openenv.core.client_types import StepResult from openenv.core.env_server.types import State - from openenv.core.http_env_client import HTTPEnvClient + from openenv.core.env_client import EnvClient from .models import EchoAction, EchoObservation except ImportError: # Standalone imports (when environment is standalone with openenv from pip) from openenv.core.client_types import StepResult from openenv.core.env_server.types import State - from openenv.core.http_env_client import HTTPEnvClient + from openenv.core.env_client import EnvClient from models import EchoAction, EchoObservation -class EchoEnv(HTTPEnvClient[EchoAction, EchoObservation]): +class EchoEnv(EnvClient[EchoAction, EchoObservation, State]): """ - HTTP client for the Echo Environment. + Client for the Echo Environment. - This client connects to an EchoEnvironment HTTP server and provides - methods to interact with it: reset(), step(), and state access. + This client maintains a persistent WebSocket connection to the environment + server, enabling efficient multi-step interactions with lower latency. + Each client instance has its own dedicated environment session on the server. 
Example: >>> # Connect to a running server - >>> client = EchoEnv(base_url="http://localhost:8000") - >>> result = client.reset() - >>> print(result.observation.echoed_message) - >>> - >>> # Send a message - >>> result = client.step(EchoAction(message="Hello!")) - >>> print(result.observation.echoed_message) - >>> print(result.reward) + >>> with EchoEnv(base_url="http://localhost:8000") as client: + ... result = client.reset() + ... print(result.observation.echoed_message) + ... + ... result = client.step(EchoAction(message="Hello!")) + ... print(result.observation.echoed_message) + ... print(result.reward) Example with Docker: >>> # Automatically start container and connect >>> client = EchoEnv.from_docker_image("echo-env:latest") - >>> result = client.reset() - >>> result = client.step(EchoAction(message="Test")) + >>> try: + ... result = client.reset() + ... result = client.step(EchoAction(message="Test")) + ... finally: + ... client.close() """ def _step_payload(self, action: EchoAction) -> Dict: diff --git a/envs/echo_env/models.py b/envs/echo_env/models.py index 4cbf1016c..c3c2e5a86 100644 --- a/envs/echo_env/models.py +++ b/envs/echo_env/models.py @@ -10,8 +10,6 @@ The Echo environment is a simple test environment that echoes back messages. """ -from dataclasses import dataclass - # Support both in-repo and standalone imports try: # In-repo imports (when running from OpenEnv repository) @@ -21,14 +19,12 @@ from openenv.core.env_server.types import Action, Observation -@dataclass(kw_only=True) class EchoAction(Action): """Action for the Echo environment - just a message to echo.""" message: str -@dataclass(kw_only=True) class EchoObservation(Observation): """Observation from the Echo environment - the echoed message.""" diff --git a/envs/echo_env/server/app.py b/envs/echo_env/server/app.py index 96c803040..07fe59ecb 100644 --- a/envs/echo_env/server/app.py +++ b/envs/echo_env/server/app.py @@ -8,7 +8,7 @@ FastAPI application for the Echo Environment. 
This module creates an HTTP server that exposes the EchoEnvironment -over HTTP endpoints, making it compatible with HTTPEnvClient. +over HTTP and WebSocket endpoints, compatible with EnvClient. Usage: # Development (with auto-reload): @@ -33,11 +33,9 @@ from models import EchoAction, EchoObservation from server.echo_environment import EchoEnvironment -# Create the environment instance -env = EchoEnvironment() - # Create the app with web interface and README integration -app = create_app(env, EchoAction, EchoObservation, env_name="echo_env") +# Pass the class (factory) instead of an instance for WebSocket session support +app = create_app(EchoEnvironment, EchoAction, EchoObservation, env_name="echo_env") def main(): diff --git a/envs/finrl_env/client.py b/envs/finrl_env/client.py index 38ab07382..9fb1a51ed 100644 --- a/envs/finrl_env/client.py +++ b/envs/finrl_env/client.py @@ -5,10 +5,10 @@ # LICENSE file in the root directory of this source tree. """ -FinRL Environment HTTP Client. +FinRL Environment Client. This module provides the client for connecting to a FinRL Environment server -over HTTP. +via WebSocket for persistent sessions. """ from typing import Any, Dict @@ -16,81 +16,69 @@ from openenv.core.client_types import StepResult from openenv.core.env_server.types import State -from openenv.core.http_env_client import HTTPEnvClient +from openenv.core.env_client import EnvClient from .models import FinRLAction, FinRLObservation -class FinRLEnv(HTTPEnvClient[FinRLAction, FinRLObservation]): +class FinRLEnv(EnvClient[FinRLAction, FinRLObservation, State]): """ - HTTP client for the FinRL Environment. + Client for the FinRL Environment. - This client connects to a FinRLEnvironment HTTP server and provides - methods to interact with it for stock trading RL tasks. + This client maintains a persistent WebSocket connection to the environment + server, enabling efficient multi-step interactions for stock trading RL tasks. 
Example: >>> # Connect to a running server - >>> client = FinRLEnv(base_url="http://localhost:8000") - >>> result = client.reset() - >>> print(result.observation.state) - >>> print(result.observation.portfolio_value) - >>> - >>> # Execute a trading action - >>> action = FinRLAction(actions=[0.5, -0.3]) # Buy stock 0, sell stock 1 - >>> result = client.step(action) - >>> print(result.reward) - >>> print(result.observation.portfolio_value) + >>> with FinRLEnv(base_url="http://localhost:8000") as client: + ... result = client.reset() + ... print(result.observation.state) + ... print(result.observation.portfolio_value) + ... + ... # Execute a trading action + ... action = FinRLAction(actions=[0.5, -0.3]) # Buy stock 0, sell stock 1 + ... result = client.step(action) + ... print(result.reward) + ... print(result.observation.portfolio_value) Example with Docker: >>> # Automatically start container and connect >>> client = FinRLEnv.from_docker_image("finrl-env:latest") - >>> result = client.reset() - >>> result = client.step(FinRLAction(actions=[0.1])) - >>> client.close() + >>> try: + ... result = client.reset() + ... result = client.step(FinRLAction(actions=[0.1])) + ... finally: + ... 
client.close() Example training loop: >>> import numpy as np >>> from envs.finrl_env import FinRLEnv, FinRLAction >>> - >>> client = FinRLEnv(base_url="http://localhost:8000") - >>> - >>> # Training loop - >>> for episode in range(10): - >>> result = client.reset() - >>> done = False - >>> episode_reward = 0 - >>> - >>> while not done: - >>> # Get state - >>> state = result.observation.state - >>> - >>> # Simple random policy (replace with your RL agent) - >>> num_stocks = len(state) // 7 # Simplified calculation - >>> actions = np.random.uniform(-1, 1, size=num_stocks).tolist() - >>> - >>> # Execute action - >>> result = client.step(FinRLAction(actions=actions)) - >>> - >>> episode_reward += result.reward or 0 - >>> done = result.done - >>> - >>> print(f"Episode {episode}: reward={episode_reward:.2f}, " - >>> f"final value={result.observation.portfolio_value:.2f}") - >>> - >>> client.close() + >>> with FinRLEnv(base_url="http://localhost:8000") as client: + ... # Training loop + ... for episode in range(10): + ... result = client.reset() + ... done = False + ... episode_reward = 0 + ... + ... while not done: + ... # Get state + ... state = result.observation.state + ... + ... # Simple random policy (replace with your RL agent) + ... num_stocks = len(state) // 7 # Simplified calculation + ... actions = np.random.uniform(-1, 1, size=num_stocks).tolist() + ... + ... # Execute action + ... result = client.step(FinRLAction(actions=actions)) + ... + ... episode_reward += result.reward or 0 + ... done = result.done + ... + ... print(f"Episode {episode}: reward={episode_reward:.2f}, " + ... f"final value={result.observation.portfolio_value:.2f}") """ - def get_config(self) -> Dict[str, Any]: - """ - Get the environment configuration from the server. 
- - Returns: - Dictionary containing environment configuration - """ - response = self.session.get(f"{self.base_url}/config") - response.raise_for_status() - return response.json() - def _step_payload(self, action: FinRLAction) -> Dict: """ Convert FinRLAction to JSON payload for step request. diff --git a/envs/finrl_env/server/app.py b/envs/finrl_env/server/app.py index 1e4a34ca9..f02f659c7 100644 --- a/envs/finrl_env/server/app.py +++ b/envs/finrl_env/server/app.py @@ -8,7 +8,7 @@ FastAPI application for the FinRL Environment. This module creates an HTTP server that exposes the FinRLEnvironment -over HTTP endpoints, making it compatible with HTTPEnvClient. +over HTTP and WebSocket endpoints, compatible with EnvClient. The server expects environment configuration to be provided either: 1. Through environment variables (FINRL_CONFIG_PATH) @@ -32,7 +32,7 @@ from pathlib import Path import pandas as pd -from openenv.core.env_server import create_fastapi_app +from openenv.core.env_server import create_app from ..models import FinRLAction, FinRLObservation from .finrl_environment import FinRLEnvironment @@ -116,11 +116,16 @@ def load_finrl_config(): # Load configuration finrl_env_class, finrl_config = load_finrl_config() -# Create the environment instance -env = FinRLEnvironment(finrl_env_class=finrl_env_class, finrl_env_config=finrl_config) + +# Factory function to create FinRLEnvironment instances +def create_finrl_environment(): + """Factory function that creates FinRLEnvironment with config.""" + return FinRLEnvironment(finrl_env_class=finrl_env_class, finrl_env_config=finrl_config) + # Create the FastAPI app with routes -app = create_fastapi_app(env, FinRLAction, FinRLObservation) +# Pass the factory function instead of an instance for WebSocket session support +app = create_app(create_finrl_environment, FinRLAction, FinRLObservation, env_name="finrl_env") @app.get("/config") diff --git a/envs/git_env/client.py b/envs/git_env/client.py index 28824a578..efbf6182d 
100644 --- a/envs/git_env/client.py +++ b/envs/git_env/client.py @@ -3,7 +3,9 @@ GitEnv Client ------------- Client-side wrapper for the Git environment server. -Talks HTTP to a single base_url exposing: /reset and /step. + +This client maintains a persistent WebSocket connection to the environment +server, enabling efficient multi-step interactions with lower latency. """ from __future__ import annotations @@ -11,7 +13,7 @@ from typing import TYPE_CHECKING from openenv.core.client_types import StepResult -from openenv.core.http_env_client import HTTPEnvClient +from openenv.core.env_client import EnvClient from .models import GitAction, GitObservation, GitState @@ -19,12 +21,12 @@ from openenv.core.containers.runtime import ContainerProvider -class GitEnv(HTTPEnvClient[GitAction, GitObservation]): +class GitEnv(EnvClient[GitAction, GitObservation, GitState]): """ Client for Git Environment with Gitea server. - This client communicates with the Git environment server over HTTP, - allowing agents to perform Git operations through a simple API. + This client maintains a persistent WebSocket connection to the environment + server, enabling efficient multi-step interactions for Git operations. The environment connects to a shared external Gitea service. Repositories must be pre-migrated to Gitea before use. @@ -32,25 +34,25 @@ class GitEnv(HTTPEnvClient[GitAction, GitObservation]): Example: >>> # From Docker image >>> client = GitEnv.from_docker_image("git-env:latest") - >>> result = client.reset() - >>> - >>> # List available repositories - >>> from envs.git_env import GitAction - >>> result = client.step(GitAction(action_type="list_repos")) - >>> print(result.observation.repos) - >>> - >>> # Clone repository to workspace - >>> result = client.step(GitAction(action_type="clone_repo", repo_name="OpenEnv")) - >>> - >>> # Execute git commands - >>> result = client.step(GitAction( - ... action_type="execute_git_command", - ... command="status", - ... 
working_dir="OpenEnv" - ... )) - >>> - >>> # Cleanup - >>> client.close() + >>> try: + ... result = client.reset() + ... + ... # List available repositories + ... from envs.git_env import GitAction + ... result = client.step(GitAction(action_type="list_repos")) + ... print(result.observation.repos) + ... + ... # Clone repository to workspace + ... result = client.step(GitAction(action_type="clone_repo", repo_name="OpenEnv")) + ... + ... # Execute git commands + ... result = client.step(GitAction( + ... action_type="execute_git_command", + ... command="status", + ... working_dir="OpenEnv" + ... )) + ... finally: + ... client.close() """ def _step_payload(self, action: GitAction) -> dict: diff --git a/envs/git_env/server/app.py b/envs/git_env/server/app.py index 3246c4af5..a73e22973 100644 --- a/envs/git_env/server/app.py +++ b/envs/git_env/server/app.py @@ -44,16 +44,21 @@ if not gitea_password: raise RuntimeError("GITEA_PASSWORD environment variable is required") -# Create the environment instance (connects to external Gitea) -env = GitTaskEnvironment( - gitea_url=gitea_url, - username=gitea_username, - password=gitea_password, - workspace_dir=workspace_dir, -) + +# Factory function to create GitTaskEnvironment instances +def create_git_environment(): + """Factory function that creates GitTaskEnvironment with config.""" + return GitTaskEnvironment( + gitea_url=gitea_url, + username=gitea_username, + password=gitea_password, + workspace_dir=workspace_dir, + ) + # Create the app with web interface and README integration -app = create_app(env, GitAction, GitObservation, env_name="git_env") +# Pass the factory function instead of an instance for WebSocket session support +app = create_app(create_git_environment, GitAction, GitObservation, env_name="git_env") if __name__ == "__main__": diff --git a/envs/openspiel_env/client.py b/envs/openspiel_env/client.py index cb80e8f68..946cd1fdd 100644 --- a/envs/openspiel_env/client.py +++ b/envs/openspiel_env/client.py @@ -5,10 
+5,10 @@ # LICENSE file in the root directory of this source tree. """ -OpenSpielEnv HTTP Client. +OpenSpielEnv Client. This module provides the client for connecting to an OpenSpiel Environment server -over HTTP. +via WebSocket for persistent sessions. """ from __future__ import annotations @@ -17,7 +17,7 @@ from openenv.core.client_types import StepResult -from openenv.core.http_env_client import HTTPEnvClient +from openenv.core.env_client import EnvClient from .models import OpenSpielAction, OpenSpielObservation, OpenSpielState @@ -25,28 +25,30 @@ from openenv.core.containers.runtime import ContainerProvider -class OpenSpielEnv(HTTPEnvClient[OpenSpielAction, OpenSpielObservation]): +class OpenSpielEnv(EnvClient[OpenSpielAction, OpenSpielObservation, OpenSpielState]): """ - HTTP client for OpenSpiel Environment. + Client for OpenSpiel Environment. - This client connects to an OpenSpielEnvironment HTTP server and provides - methods to interact with it: reset(), step(), and state access. + This client maintains a persistent WebSocket connection to the environment + server, enabling efficient multi-step interactions with lower latency. Example: >>> # Connect to a running server - >>> client = OpenSpielEnv(base_url="http://localhost:8000") - >>> result = client.reset() - >>> print(result.observation.info_state) - >>> - >>> # Take an action - >>> result = client.step(OpenSpielAction(action_id=1, game_name="catch")) - >>> print(result.observation.reward) + >>> with OpenSpielEnv(base_url="http://localhost:8000") as client: + ... result = client.reset() + ... print(result.observation.info_state) + ... + ... result = client.step(OpenSpielAction(action_id=1, game_name="catch")) + ... print(result.observation.reward) Example with Docker: >>> # Automatically start container and connect >>> client = OpenSpielEnv.from_docker_image("openspiel-env:latest") - >>> result = client.reset() - >>> result = client.step(OpenSpielAction(action_id=0)) + >>> try: + ... 
result = client.reset() + ... result = client.step(OpenSpielAction(action_id=0)) + ... finally: + ... client.close() """ def _step_payload(self, action: OpenSpielAction) -> Dict[str, Any]: diff --git a/envs/openspiel_env/server/app.py b/envs/openspiel_env/server/app.py index 11107fbd4..01dc35218 100644 --- a/envs/openspiel_env/server/app.py +++ b/envs/openspiel_env/server/app.py @@ -8,7 +8,7 @@ FastAPI application for the OpenSpiel Environment. This module creates an HTTP server that exposes OpenSpiel games -over HTTP endpoints, making them compatible with HTTPEnvClient. +over HTTP and WebSocket endpoints, compatible with EnvClient. Usage: # Development (with auto-reload): @@ -38,15 +38,20 @@ agent_player = int(os.getenv("OPENSPIEL_AGENT_PLAYER", "0")) opponent_policy = os.getenv("OPENSPIEL_OPPONENT_POLICY", "random") -# Create the environment instance -env = OpenSpielEnvironment( - game_name=game_name, - agent_player=agent_player, - opponent_policy=opponent_policy, -) + +# Factory function to create OpenSpielEnvironment instances +def create_openspiel_environment(): + """Factory function that creates OpenSpielEnvironment with config.""" + return OpenSpielEnvironment( + game_name=game_name, + agent_player=agent_player, + opponent_policy=opponent_policy, + ) + # Create the FastAPI app with web interface and README integration -app = create_app(env, OpenSpielAction, OpenSpielObservation, env_name="openspiel_env") +# Pass the factory function instead of an instance for WebSocket session support +app = create_app(create_openspiel_environment, OpenSpielAction, OpenSpielObservation, env_name="openspiel_env") if __name__ == "__main__": diff --git a/envs/sumo_rl_env/client.py b/envs/sumo_rl_env/client.py index 19fb5bd36..89390398d 100644 --- a/envs/sumo_rl_env/client.py +++ b/envs/sumo_rl_env/client.py @@ -5,47 +5,46 @@ # LICENSE file in the root directory of this source tree. """ -HTTP client for SUMO-RL environment. +Client for SUMO-RL environment. 
This module provides a client to interact with the SUMO traffic signal -control environment over HTTP. +control environment via WebSocket for persistent sessions. """ from typing import Any, Dict from openenv.core.client_types import StepResult -from openenv.core.http_env_client import HTTPEnvClient +from openenv.core.env_client import EnvClient from .models import SumoAction, SumoObservation, SumoState -class SumoRLEnv(HTTPEnvClient[SumoAction, SumoObservation]): +class SumoRLEnv(EnvClient[SumoAction, SumoObservation, SumoState]): """ - HTTP client for SUMO-RL traffic signal control environment. + Client for SUMO-RL traffic signal control environment. - This client communicates with a SUMO environment server to control - traffic signals using reinforcement learning. + This client maintains a persistent WebSocket connection to a SUMO + environment server to control traffic signals using reinforcement learning. Example: >>> # Start container and connect >>> env = SumoRLEnv.from_docker_image("sumo-rl-env:latest") - >>> - >>> # Reset environment - >>> result = env.reset() - >>> print(f"Observation shape: {result.observation.observation_shape}") - >>> print(f"Action space: {result.observation.action_mask}") - >>> - >>> # Take action - >>> result = env.step(SumoAction(phase_id=1)) - >>> print(f"Reward: {result.reward}, Done: {result.done}") - >>> - >>> # Get state - >>> state = env.state() - >>> print(f"Sim time: {state.sim_time}, Total vehicles: {state.total_vehicles}") - >>> - >>> # Cleanup - >>> env.close() + >>> try: + ... # Reset environment + ... result = env.reset() + ... print(f"Observation shape: {result.observation.observation_shape}") + ... print(f"Action space: {result.observation.action_mask}") + ... + ... # Take action + ... result = env.step(SumoAction(phase_id=1)) + ... print(f"Reward: {result.reward}, Done: {result.done}") + ... + ... # Get state + ... state = env.state() + ... 
print(f"Sim time: {state.sim_time}, Total vehicles: {state.total_vehicles}") + ... finally: + ... env.close() Example with custom network: >>> # Use custom SUMO network via volume mount diff --git a/envs/sumo_rl_env/server/app.py b/envs/sumo_rl_env/server/app.py index 3240902c2..b0f5ea7d3 100644 --- a/envs/sumo_rl_env/server/app.py +++ b/envs/sumo_rl_env/server/app.py @@ -13,7 +13,7 @@ import os -from openenv.core.env_server import create_fastapi_app +from openenv.core.env_server import create_app from ..models import SumoAction, SumoObservation from .sumo_environment import SumoEnvironment @@ -29,19 +29,23 @@ reward_fn = os.getenv("SUMO_REWARD_FN", "diff-waiting-time") sumo_seed = int(os.getenv("SUMO_SEED", "42")) -# Create single environment instance -# This is reused for all HTTP requests (avoids TraCI connection issues) -env = SumoEnvironment( - net_file=net_file, - route_file=route_file, - num_seconds=num_seconds, - delta_time=delta_time, - yellow_time=yellow_time, - min_green=min_green, - max_green=max_green, - reward_fn=reward_fn, - sumo_seed=sumo_seed, -) + +# Factory function to create SumoEnvironment instances +def create_sumo_environment(): + """Factory function that creates SumoEnvironment with config.""" + return SumoEnvironment( + net_file=net_file, + route_file=route_file, + num_seconds=num_seconds, + delta_time=delta_time, + yellow_time=yellow_time, + min_green=min_green, + max_green=max_green, + reward_fn=reward_fn, + sumo_seed=sumo_seed, + ) + # Create FastAPI app -app = create_fastapi_app(env, SumoAction, SumoObservation) +# Pass the factory function instead of an instance for WebSocket session support +app = create_app(create_sumo_environment, SumoAction, SumoObservation, env_name="sumo_rl_env") diff --git a/envs/textarena_env/client.py b/envs/textarena_env/client.py index 36f59716a..9c2b52a01 100644 --- a/envs/textarena_env/client.py +++ b/envs/textarena_env/client.py @@ -4,14 +4,14 @@ # This source code is licensed under the BSD-style license 
found in the # LICENSE file in the root directory of this source tree. -"""HTTP client for the generic TextArena environment.""" +"""Client for the generic TextArena environment.""" from __future__ import annotations from typing import Any, Dict, TYPE_CHECKING from openenv.core.client_types import StepResult -from openenv.core.http_env_client import HTTPEnvClient +from openenv.core.env_client import EnvClient from .models import ( TextArenaAction, @@ -24,8 +24,8 @@ from openenv.core.containers.runtime import ContainerProvider -class TextArenaEnv(HTTPEnvClient[TextArenaAction, TextArenaObservation]): - """HTTP client for the TextArena environment server.""" +class TextArenaEnv(EnvClient[TextArenaAction, TextArenaObservation, TextArenaState]): + """Client for the TextArena environment server.""" def _step_payload(self, action: TextArenaAction) -> Dict[str, Any]: return {"message": action.message} diff --git a/envs/textarena_env/server/app.py b/envs/textarena_env/server/app.py index 83d8d09ec..900a138c0 100644 --- a/envs/textarena_env/server/app.py +++ b/envs/textarena_env/server/app.py @@ -35,15 +35,22 @@ def _parse_env_kwargs(prefix: str = "TEXTARENA_KW_") -> dict[str, str]: extra_kwargs = _parse_env_kwargs() -environment = TextArenaEnvironment( - env_id=env_id, - num_players=num_players, - max_turns=max_turns, - download_nltk=download_nltk, - env_kwargs=extra_kwargs, -) - -app = create_app(environment, TextArenaAction, TextArenaObservation, env_name="textarena_env") + +# Factory function to create TextArenaEnvironment instances +def create_textarena_environment(): + """Factory function that creates TextArenaEnvironment with config.""" + return TextArenaEnvironment( + env_id=env_id, + num_players=num_players, + max_turns=max_turns, + download_nltk=download_nltk, + env_kwargs=extra_kwargs, + ) + + +# Create the FastAPI app +# Pass the factory function instead of an instance for WebSocket session support +app = create_app(create_textarena_environment, TextArenaAction, 
TextArenaObservation, env_name="textarena_env") if __name__ == "__main__": From f39e5a1fe96c02eddd5c8b8211e446136d6a3afb Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Thu, 18 Dec 2025 09:46:50 +0100 Subject: [PATCH 38/41] use pydantic in connect4 env --- envs/connect4_env/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/envs/connect4_env/models.py b/envs/connect4_env/models.py index 4d1109c2d..90ee90742 100644 --- a/envs/connect4_env/models.py +++ b/envs/connect4_env/models.py @@ -40,8 +40,8 @@ class Connect4Observation(Observation): reward: Reward for the last action. """ - board: List[List[int]] - legal_actions: List[int] + board: List[List[int]] = Field(default_factory=list) + legal_actions: List[int] = Field(default_factory=list) class Connect4State(State): From ce16e84d920613c61a1115310d7755c520ea3b9d Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Thu, 18 Dec 2025 09:47:01 +0100 Subject: [PATCH 39/41] use pydantic in dipg --- envs/dipg_safety_env/models.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/envs/dipg_safety_env/models.py b/envs/dipg_safety_env/models.py index dbd9e04ec..a770e7355 100644 --- a/envs/dipg_safety_env/models.py +++ b/envs/dipg_safety_env/models.py @@ -1,24 +1,25 @@ # envs/dipg_safety_env/models.py -from dataclasses import dataclass, field +from typing import Dict, Any +from pydantic import Field from openenv.core.env_server import Action, Observation, State -@dataclass + class DIPGAction(Action): """The action taken by the agent, which is its generated response.""" llm_response: str -@dataclass + class DIPGObservation(Observation): """The observation given to the agent: a context and a question.""" - context: str - question: str + context: str = "" + question: str = "" + -@dataclass class DIPGState(State): """The internal state of the environment for tracking the current challenge.""" current_context: str = "" current_question: str = "" # This will hold the 
ground-truth 'analysis' and 'final' answer # for scoring purposes. - expected_answer: dict = field(default_factory=dict) \ No newline at end of file + expected_answer: Dict[str, Any] = Field(default_factory=dict) \ No newline at end of file From 0e186eabd5ea6fc083932c9ddc61ce5c9746ca09 Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Thu, 18 Dec 2025 11:07:39 +0100 Subject: [PATCH 40/41] add async to web interface --- src/openenv/core/env_server/web_interface.py | 3152 +++++++++--------- 1 file changed, 1585 insertions(+), 1567 deletions(-) diff --git a/src/openenv/core/env_server/web_interface.py b/src/openenv/core/env_server/web_interface.py index 210a7804b..2def62bda 100644 --- a/src/openenv/core/env_server/web_interface.py +++ b/src/openenv/core/env_server/web_interface.py @@ -1,1567 +1,1585 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -Web interface for OpenEnv environments. - -This module provides a web-based interface for interacting with OpenEnv environments, -including a two-pane layout for HumanAgent interaction and state observation. -""" - -from __future__ import annotations - -import json -from typing import Any, Callable, Dict, List, Optional, Type, Union -from datetime import datetime - -from fastapi import FastAPI, WebSocket, WebSocketDisconnect -from fastapi.responses import HTMLResponse -from pydantic import Field - -from .interfaces import Environment -from .serialization import deserialize_action_with_preprocessing, serialize_observation -from .types import Action, Observation, State, EnvironmentMetadata, ConcurrencyConfig, BaseMessage - - -def load_environment_metadata(env: Environment, env_name: Optional[str] = None) -> EnvironmentMetadata: - """ - Load environment metadata including README content. 
- - Args: - env: The environment instance - env_name: Optional environment name for README file lookup - - Returns: - EnvironmentMetadata with loaded information - """ - # Try to get metadata from environment if it has a method for it - if hasattr(env, "get_metadata"): - return env.get_metadata() - - # Default metadata - metadata = EnvironmentMetadata( - name=env_name or env.__class__.__name__, - description=f"{env.__class__.__name__} environment", - version="1.0.0", - ) - - # Try to load README from file system - readme_content = _load_readme_from_filesystem(env_name) - if readme_content: - metadata.readme_content = readme_content - - return metadata - - -def _load_readme_from_filesystem(env_name: Optional[str]) -> Optional[str]: - """ - Load README content from the filesystem. - - Tries multiple locations: - 1. Container filesystem: /app/README.md - 2. Local development: src/envs/{env_name}/README.md - 3. Environment variable: ENV_README_PATH - """ - import os - from pathlib import Path - - # Try container filesystem first - container_readme = Path("/app/README.md") - if container_readme.exists(): - try: - return container_readme.read_text(encoding="utf-8") - except Exception: - pass - - # Try environment variable path - custom_path = os.environ.get("ENV_README_PATH") - if custom_path and Path(custom_path).exists(): - try: - return Path(custom_path).read_text(encoding="utf-8") - except Exception: - pass - - # Try local development path - if env_name: - local_readme = Path(f"src/envs/{env_name}/README.md") - if local_readme.exists(): - try: - return local_readme.read_text(encoding="utf-8") - except Exception: - pass - - return None - - -class ActionLog(BaseMessage): - """Log entry for an action taken.""" - - timestamp: str = Field(description="Timestamp when action was taken") - action: Dict[str, Any] = Field(description="Action that was taken") - observation: Dict[str, Any] = Field(description="Observation returned from action") - reward: Optional[float] = 
Field(default=None, description="Reward received from action") - done: bool = Field(description="Whether the episode is done after this action") - step_count: int = Field(description="Step count when this action was taken") - - -class EpisodeState(BaseMessage): - """Current episode state for the web interface.""" - - episode_id: Optional[str] = Field(default=None, description="Current episode ID") - step_count: int = Field(description="Current step count in episode") - current_observation: Optional[Dict[str, Any]] = Field(default=None, description="Current observation") - action_logs: List[ActionLog] = Field(default_factory=list, description="List of action logs") - is_reset: bool = Field(default=True, description="Whether the episode has been reset") - - -class WebInterfaceManager: - """Manages the web interface for an environment.""" - - def __init__( - self, - env: Environment, - action_cls: Type[Action], - observation_cls: Type[Observation], - metadata: Optional[EnvironmentMetadata] = None, - ): - self.env = env - self.action_cls = action_cls - self.observation_cls = observation_cls - self.metadata = metadata or EnvironmentMetadata( - name=env.__class__.__name__, - description=f"{env.__class__.__name__} environment", - ) - self.episode_state = EpisodeState(episode_id=None, step_count=0, current_observation=None, action_logs=[]) - self.connected_clients: List[WebSocket] = [] - - async def connect_websocket(self, websocket: WebSocket): - """Connect a new WebSocket client.""" - await websocket.accept() - self.connected_clients.append(websocket) - - # Send current state to the new client - await self._send_state_update() - - async def disconnect_websocket(self, websocket: WebSocket): - """Disconnect a WebSocket client.""" - if websocket in self.connected_clients: - self.connected_clients.remove(websocket) - - async def _send_state_update(self): - """Send current state to all connected clients.""" - if not self.connected_clients: - return - - state_data = { - 
"type": "state_update", - "episode_state": self.episode_state.model_dump(), - } - - # Send to all connected clients - disconnected_clients = [] - for client in self.connected_clients: - try: - await client.send_text(json.dumps(state_data)) - except Exception: - disconnected_clients.append(client) - - # Remove disconnected clients - for client in disconnected_clients: - self.connected_clients.remove(client) - - async def reset_environment(self) -> Dict[str, Any]: - """Reset the environment and update state.""" - observation: Observation = self.env.reset() - state: State = self.env.state - - # Serialize observation once using shared utility - serialized = serialize_observation(observation) - - # Update episode state - self.episode_state.episode_id = state.episode_id - self.episode_state.step_count = 0 - self.episode_state.current_observation = serialized["observation"] - self.episode_state.action_logs = [] - self.episode_state.is_reset = True - - # Send state update - await self._send_state_update() - - return serialized - - async def step_environment(self, action_data: Dict[str, Any]) -> Dict[str, Any]: - """Execute a step in the environment and update state.""" - # Deserialize action with preprocessing for web interface special cases - action: Action = deserialize_action_with_preprocessing(action_data, self.action_cls) - - # Execute step - observation: Observation = self.env.step(action) - state: State = self.env.state - - # Serialize observation once using shared utility - serialized = serialize_observation(observation) - - # Create action log - action_log = ActionLog( - timestamp=datetime.now().isoformat(), - action=action.model_dump(exclude={"metadata"}), - observation=serialized["observation"], - reward=observation.reward, - done=observation.done, - step_count=state.step_count, - ) - - # Update episode state - self.episode_state.episode_id = state.episode_id - self.episode_state.step_count = state.step_count - self.episode_state.current_observation = 
serialized["observation"] - self.episode_state.action_logs.append(action_log) - self.episode_state.is_reset = False - - # Send state update - await self._send_state_update() - - return serialized - - def get_state(self) -> Dict[str, Any]: - """Get current environment state.""" - state: State = self.env.state - return state.model_dump() - - -def create_web_interface_app( - env: Callable[[], Environment], - action_cls: Type[Action], - observation_cls: Type[Observation], - env_name: Optional[str] = None, - max_concurrent_envs: Optional[int] = None, - concurrency_config: Optional[ConcurrencyConfig] = None, -) -> FastAPI: - """ - Create a FastAPI application with web interface for the given environment. - - Args: - env: Environment factory (callable) that creates new instances - action_cls: The Action subclass this environment expects - observation_cls: The Observation subclass this environment returns - env_name: Optional environment name for README loading - max_concurrent_envs: Maximum concurrent WebSocket sessions. - Mutually exclusive with concurrency_config. - concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. - Mutually exclusive with max_concurrent_envs. 
- - Returns: - FastAPI application instance with web interface - """ - from .http_server import create_fastapi_app - - # Create the base environment app - app = create_fastapi_app(env, action_cls, observation_cls, max_concurrent_envs, concurrency_config) - - # Create a test instance for metadata - env_instance = env() - - # Load environment metadata - metadata = load_environment_metadata(env_instance, env_name) - - # Create web interface manager - web_manager = WebInterfaceManager(env_instance, action_cls, observation_cls, metadata) - - # Add web interface routes - @app.get("/web", response_class=HTMLResponse) - async def web_interface(): - """Serve the web interface.""" - return get_web_interface_html(action_cls, web_manager.metadata) - - @app.get("/web/metadata") - async def web_metadata(): - """Get environment metadata.""" - return web_manager.metadata.model_dump() - - @app.websocket("/ws/ui") - async def websocket_ui_endpoint(websocket: WebSocket): - """WebSocket endpoint for web UI real-time updates. - - Note: This endpoint is separate from /ws which is used for - concurrent environment sessions. This endpoint is specifically - for the web interface state updates. 
- """ - await web_manager.connect_websocket(websocket) - try: - while True: - # Keep connection alive - await websocket.receive_text() - except WebSocketDisconnect: - await web_manager.disconnect_websocket(websocket) - - @app.post("/web/reset") - async def web_reset(): - """Reset endpoint for web interface.""" - return await web_manager.reset_environment() - - @app.post("/web/step") - async def web_step(request: Dict[str, Any]): - """Step endpoint for web interface.""" - if "message" in request: - message = request["message"] - if hasattr(web_manager.env, "message_to_action"): - action = getattr(web_manager.env, "message_to_action")(message) - action_data = {"tokens": action.tokens.tolist()} - else: - action_data = request.get("action", {}) - else: - action_data = request.get("action", {}) - - return await web_manager.step_environment(action_data) - - @app.get("/web/state") - async def web_state(): - """State endpoint for web interface.""" - return web_manager.get_state() - - return app - - -def get_web_interface_html(action_cls: Type[Action], metadata: Optional[EnvironmentMetadata] = None) -> str: - """Generate the HTML for the web interface.""" - - # Check if this is a chat environment by looking for tokens field - is_chat_env = False - if hasattr(action_cls, "model_fields"): - for field_name, field_info in action_cls.model_fields.items(): - if ( - field_name == "tokens" - and field_info.annotation is not None - and hasattr(field_info.annotation, "__name__") - and "Tensor" in field_info.annotation.__name__ - ): - is_chat_env = True - break - - # Get action fields for dynamic form generation with enhanced metadata - action_fields = _extract_action_fields(action_cls) - - return f""" - - - - - - OpenEnv Web Interface - - - -
    - -
    -
    - - HumanAgent Interface -
    -
    - - {_generate_instructions_section(metadata)} - - - {_generate_action_interface(action_fields, is_chat_env)} - - -
    - - -
    - - -
    -

    Current State

    -
    -
    - Status: - Not initialized -
    -
    - Episode ID: - - -
    -
    - Step Count: - 0 -
    -
    -
    -
    -
    - - -
    -
    - State Observer -
    -
    - -
    -

    Current Observation

    -
    - No observation yet -
    -
    - - -
    -

    Action History

    -
    - No actions taken yet -
    -
    -
    -
    -
    - - - - - """.replace( - "{_generate_action_form_fields(action_fields)}", - _generate_action_form_fields(action_fields), - ) - - -def _generate_instructions_section(metadata: Optional[EnvironmentMetadata]) -> str: - """Generate the instructions section with environment documentation.""" - if not metadata or not metadata.readme_content: - return "" - - html_content = _markdown_to_html(metadata.readme_content) - - return f""" - -
    -
    -

    {metadata.name}

    - -
    -
    -
    - {html_content} -
    -
    -
    - """ - - -def _extract_action_fields(action_cls: Type[Action]) -> List[Dict[str, Any]]: - """Extract enhanced field metadata from Action class for form generation.""" - # Use Pydantic's JSON schema generation for robust metadata extraction - try: - schema = action_cls.model_json_schema() - except AttributeError: - # Fallback for non-Pydantic v2 models or if something goes wrong - return [] - - properties = schema.get("properties", {}) - required_fields = schema.get("required", []) - - action_fields = [] - - for field_name, field_info in properties.items(): - if field_name == "metadata": - continue - - # JSON schema "type" can be a string or list/undefined - # Determine our internal input type - input_type = _determine_input_type_from_schema(field_info, field_name) - - is_required = field_name in required_fields - - action_fields.append( - { - "name": field_name, - "type": input_type, - "required": is_required, - "description": field_info.get("description", ""), - "default_value": field_info.get("default"), - "choices": field_info.get("enum"), - "min_value": field_info.get("minimum"), - "max_value": field_info.get("maximum"), - "min_length": field_info.get("minLength"), - "max_length": field_info.get("maxLength"), - "pattern": field_info.get("pattern"), - "placeholder": _generate_placeholder(field_name, field_info), - "help_text": _generate_help_text(field_name, field_info), - } - ) - - return action_fields - - -def _determine_input_type_from_schema(field_info: Dict[str, Any], field_name: str) -> str: - """Determine the appropriate HTML input type from JSON schema info.""" - schema_type = field_info.get("type") - - # Check for specific tensor field convention - if "tokens" in field_name.lower(): - return "tensor" - - if "enum" in field_info: - return "select" - - if schema_type == "boolean": - return "checkbox" - - if schema_type == "integer" or schema_type == "number": - return "number" - - if schema_type == "string": - # Check if it should be a textarea - if 
field_info.get("maxLength", 0) > 100 or "message" in field_name.lower() or "code" in field_name.lower(): - return "textarea" - return "text" - - # Default fallback - return "text" - - -def _generate_placeholder(field_name: str, field_info: Dict[str, Any]) -> str: - """Generate placeholder text.""" - if "message" in field_name.lower(): - return f"Enter {field_name.replace('_', ' ')}..." - elif "code" in field_name.lower(): - return "Enter Python code here..." - elif "tokens" in field_name.lower(): - return "Enter comma-separated token IDs (e.g., 1,2,3,4,5)" - else: - return f"Enter {field_name.replace('_', ' ')}..." - - -def _generate_help_text(field_name: str, field_info: Dict[str, Any]) -> str: - """Generate help text.""" - description = field_info.get("description", "") - if description: - return description - - if "action_id" in field_name.lower(): - return "The action ID to execute in environment" - elif "game_name" in field_name.lower(): - return "Name of game or environment" - elif "tokens" in field_name.lower(): - return "Token IDs as a comma-separated list of integers" - elif "code" in field_name.lower(): - return "Python code to execute in environment" - elif "message" in field_name.lower(): - return "Text message to send" - - return "" - - -def _markdown_to_html(markdown: str) -> str: - """Convert basic markdown to HTML for README display.""" - import html - import re - - # Escape HTML first - html_content = html.escape(markdown) - - # Convert headers - html_content = re.sub(r"^# (.*?)$", r"

    \1

    ", html_content, flags=re.MULTILINE) - html_content = re.sub(r"^## (.*?)$", r"

    \1

    ", html_content, flags=re.MULTILINE) - html_content = re.sub(r"^### (.*?)$", r"

    \1

    ", html_content, flags=re.MULTILINE) - - # Convert code blocks - html_content = re.sub( - r"```(.*?)\n(.*?)\n```", - r"
    \2
    ", - html_content, - flags=re.DOTALL, - ) - html_content = re.sub(r"`([^`]+)`", r"\1", html_content) - - # Convert bold and italic - html_content = re.sub(r"\*\*(.*?)\*\*", r"\1", html_content) - html_content = re.sub(r"\*(.*?)\*", r"\1", html_content) - - # Convert lists - html_content = re.sub(r"^- (.*?)$", r"
  • \1
  • ", html_content, flags=re.MULTILINE) - html_content = re.sub(r"(
  • .*
  • )", r"
      \1
    ", html_content, flags=re.DOTALL) - - # Convert line breaks - html_content = html_content.replace("\n", "
    ") - - return html_content - - -def _generate_action_interface(action_fields: List[Dict[str, Any]], is_chat_env: bool) -> str: - """Generate either a chat interface or action form based on environment type.""" - if is_chat_env: - return _generate_chat_interface() - else: - return _generate_action_form(action_fields) - - -def _generate_chat_interface() -> str: - """Generate a chat-style interface for chat environments.""" - return """ - -
    -

    Chat Interface

    -
    -
    -
    System
    -
    Chat environment ready. Send a message to start the conversation.
    -
    -
    -
    -
    - - -
    -
    - - -
    -
    -
    - """ - - -def _generate_action_form(action_fields: List[Dict[str, Any]]) -> str: - """Generate a traditional action form for non-chat environments.""" - return f""" - -
    -

    Take Action

    -
    - {_generate_action_form_fields(action_fields)} - -
    -
    - """ - - -def _generate_action_form_fields(action_fields: List[Dict[str, Any]]) -> str: - """Generate HTML form fields for action input with enhanced metadata.""" - if not action_fields: - return "

    No action fields available

    " - - fields_html = [] - for field in action_fields: - field_html = _generate_single_field(field) - fields_html.append(field_html) - - return "\n".join(fields_html) - - -def _generate_single_field(field: Dict[str, Any]) -> str: - """Generate HTML for a single form field with enhanced metadata.""" - field_name = field["name"] - field_type = field["type"] - required = field["required"] - placeholder = field.get("placeholder", "") - help_text = field.get("help_text", "") - choices = field.get("choices", []) - min_value = field.get("min_value") - max_value = field.get("max_value") - default_value = field.get("default_value") - min_length = field.get("min_length") - max_length = field.get("max_length") - pattern = field.get("pattern") - - # Build label with required indicator - label_text = field_name.replace("_", " ").title() - if required: - label_text += ' *' - - # Build input attributes - input_attrs = [] - if required: - input_attrs.append("required") - if placeholder: - input_attrs.append(f'placeholder="{placeholder}"') - if min_value is not None: - input_attrs.append(f'min="{min_value}"') - if max_value is not None: - input_attrs.append(f'max="{max_value}"') - if min_length is not None: - input_attrs.append(f'minlength="{min_length}"') - if max_length is not None: - input_attrs.append(f'maxlength="{max_length}"') - if pattern is not None: - input_attrs.append(f'pattern="{pattern}"') - if default_value is not None: - input_attrs.append(f'value="{default_value}"') - - attrs_str = " ".join(input_attrs) - - if field_type == "checkbox": - checked = "checked" if default_value is True else "" - return f''' -
    - - {f'{help_text}' if help_text else ""} -
    - ''' - - elif field_type == "select": - options_html = [] - if not required: - options_html.append(f'') - - for choice in choices: - selected = "selected" if str(choice) == str(default_value) else "" - options_html.append(f'') - - return f''' -
    - - - {f'{help_text}' if help_text else ""} -
    - ''' - - elif field_type == "tensor": - return f''' -
    - - - {help_text or "Enter token IDs as comma-separated integers (e.g., 1,2,3,4,5)"} -
    - ''' - - elif field_type == "textarea": - return f''' -
    - - - {f'{help_text}' if help_text else ""} -
    - ''' - - else: - return f''' -
    - - - {f'{help_text}' if help_text else ""} -
    - ''' +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Web interface for OpenEnv environments. + +This module provides a web-based interface for interacting with OpenEnv environments, +including a two-pane layout for HumanAgent interaction and state observation. +""" + +from __future__ import annotations + +import asyncio +import json +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Dict, List, Optional, Type +from datetime import datetime + +from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from fastapi.responses import HTMLResponse +from pydantic import BaseModel, Field, ConfigDict + +from .interfaces import Environment +from .serialization import ( + deserialize_action_with_preprocessing, + serialize_observation, +) +from .types import Action, Observation, State, EnvironmentMetadata + + +def load_environment_metadata(env: Environment, env_name: Optional[str] = None) -> EnvironmentMetadata: + """ + Load environment metadata including README content. + + Args: + env: The environment instance + env_name: Optional environment name for README file lookup + + Returns: + EnvironmentMetadata with loaded information + """ + # Try to get metadata from environment if it has a method for it + if hasattr(env, "get_metadata"): + return env.get_metadata() + + # Default metadata + metadata = EnvironmentMetadata( + name=env_name or env.__class__.__name__, + description=f"{env.__class__.__name__} environment", + version="1.0.0", + ) + + # Try to load README from file system + readme_content = _load_readme_from_filesystem(env_name) + if readme_content: + metadata.readme_content = readme_content + + return metadata + + +def _load_readme_from_filesystem(env_name: Optional[str]) -> Optional[str]: + """ + Load README content from the filesystem. 
+ + Tries multiple locations: + 1. Container filesystem: /app/README.md + 2. Local development: src/envs/{env_name}/README.md + 3. Environment variable: ENV_README_PATH + """ + import os + from pathlib import Path + + # Try container filesystem first + container_readme = Path("/app/README.md") + if container_readme.exists(): + try: + return container_readme.read_text(encoding="utf-8") + except Exception: + pass + + # Try environment variable path + custom_path = os.environ.get("ENV_README_PATH") + if custom_path and Path(custom_path).exists(): + try: + return Path(custom_path).read_text(encoding="utf-8") + except Exception: + pass + + # Try local development path + if env_name: + local_readme = Path(f"src/envs/{env_name}/README.md") + if local_readme.exists(): + try: + return local_readme.read_text(encoding="utf-8") + except Exception: + pass + + return None + + +class ActionLog(BaseModel): + """Log entry for an action taken.""" + + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + timestamp: str = Field(description="Timestamp when action was taken") + action: Dict[str, Any] = Field(description="Action that was taken") + observation: Dict[str, Any] = Field(description="Observation returned from action") + reward: Optional[float] = Field(default=None, description="Reward received from action") + done: bool = Field(description="Whether the episode is done after this action") + step_count: int = Field(description="Step count when this action was taken") + + +class EpisodeState(BaseModel): + """Current episode state for the web interface.""" + + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + episode_id: Optional[str] = Field(default=None, description="Current episode ID") + step_count: int = Field(description="Current step count in episode") + current_observation: Optional[Dict[str, Any]] = Field(default=None, description="Current observation") + action_logs: List[ActionLog] = Field(default_factory=list, description="List 
of action logs") + is_reset: bool = Field(default=True, description="Whether the episode has been reset") + + +class WebInterfaceManager: + """Manages the web interface for an environment.""" + + def __init__( + self, + env: Environment, + action_cls: Type[Action], + observation_cls: Type[Observation], + metadata: Optional[EnvironmentMetadata] = None, + ): + self.env = env + self.action_cls = action_cls + self.observation_cls = observation_cls + self.metadata = metadata or EnvironmentMetadata( + name=env.__class__.__name__, + description=f"{env.__class__.__name__} environment", + ) + self.episode_state = EpisodeState( + episode_id=None, + step_count=0, + current_observation=None, + action_logs=[], + ) + self.connected_clients: List[WebSocket] = [] + # Thread pool for running sync code (e.g., Playwright sync API) in async context + self._executor = ThreadPoolExecutor(max_workers=1) + + async def _run_sync_in_thread_pool(self, func, *args, **kwargs): + """Run a synchronous function in the thread pool executor. + + This is needed for environments using sync libraries (e.g., Playwright sync API) + that cannot be called directly from an async context. 
+ """ + loop = asyncio.get_event_loop() + return await loop.run_in_executor(self._executor, lambda: func(*args, **kwargs)) + + async def connect_websocket(self, websocket: WebSocket): + """Connect a new WebSocket client.""" + await websocket.accept() + self.connected_clients.append(websocket) + + # Send current state to the new client + await self._send_state_update() + + async def disconnect_websocket(self, websocket: WebSocket): + """Disconnect a WebSocket client.""" + if websocket in self.connected_clients: + self.connected_clients.remove(websocket) + + async def _send_state_update(self): + """Send current state to all connected clients.""" + if not self.connected_clients: + return + + state_data = { + "type": "state_update", + "episode_state": self.episode_state.model_dump(), + } + + # Send to all connected clients + disconnected_clients = [] + for client in self.connected_clients: + try: + await client.send_text(json.dumps(state_data)) + except Exception: + disconnected_clients.append(client) + + # Remove disconnected clients + for client in disconnected_clients: + self.connected_clients.remove(client) + + async def reset_environment(self) -> Dict[str, Any]: + """Reset the environment and update state.""" + # Run sync reset in thread pool to avoid blocking event loop + # and to support environments using sync libraries (e.g., Playwright) + observation: Observation = await self._run_sync_in_thread_pool(self.env.reset) + state: State = self.env.state + + # Serialize observation once using shared utility + serialized = serialize_observation(observation) + + # Update episode state + self.episode_state.episode_id = state.episode_id + self.episode_state.step_count = 0 + self.episode_state.current_observation = serialized["observation"] + self.episode_state.action_logs = [] + self.episode_state.is_reset = True + + # Send state update + await self._send_state_update() + + return serialized + + async def step_environment(self, action_data: Dict[str, Any]) -> Dict[str, 
Any]: + """Execute a step in the environment and update state.""" + # Deserialize action with preprocessing for web interface special cases + action: Action = deserialize_action_with_preprocessing(action_data, self.action_cls) + + # Run sync step in thread pool to avoid blocking event loop + # and to support environments using sync libraries (e.g., Playwright) + observation: Observation = await self._run_sync_in_thread_pool(self.env.step, action) + state: State = self.env.state + + # Serialize observation once using shared utility + serialized = serialize_observation(observation) + + # Create action log + action_log = ActionLog( + timestamp=datetime.now().isoformat(), + action=action.model_dump(exclude={"metadata"}), + observation=serialized["observation"], + reward=observation.reward, + done=observation.done, + step_count=state.step_count, + ) + + # Update episode state + self.episode_state.episode_id = state.episode_id + self.episode_state.step_count = state.step_count + self.episode_state.current_observation = serialized["observation"] + self.episode_state.action_logs.append(action_log) + self.episode_state.is_reset = False + + # Send state update + await self._send_state_update() + + return serialized + + def get_state(self) -> Dict[str, Any]: + """Get current environment state.""" + state: State = self.env.state + return state.model_dump() + + +def create_web_interface_app( + env: Environment, + action_cls: Type[Action], + observation_cls: Type[Observation], + env_name: Optional[str] = None, +) -> FastAPI: + """ + Create a FastAPI application with web interface for the given environment. 
+ + Args: + env: The Environment instance to serve + action_cls: The Action subclass this environment expects + observation_cls: The Observation subclass this environment returns + env_name: Optional environment name for README loading + + Returns: + FastAPI application instance with web interface + """ + from .http_server import create_fastapi_app + + # Create the base environment app + app = create_fastapi_app(env, action_cls, observation_cls) + + # Load environment metadata + metadata = load_environment_metadata(env, env_name) + + # Create web interface manager + web_manager = WebInterfaceManager(env, action_cls, observation_cls, metadata) + + # Add web interface routes + @app.get("/web", response_class=HTMLResponse) + async def web_interface(): + """Serve the web interface.""" + return get_web_interface_html(action_cls, web_manager.metadata) + + @app.get("/web/metadata") + async def web_metadata(): + """Get environment metadata.""" + return web_manager.metadata.model_dump() + + @app.websocket("/ws/ui") + async def websocket_ui_endpoint(websocket: WebSocket): + """WebSocket endpoint for web UI real-time updates. + + Note: Uses /ws/ui to avoid conflict with /ws in http_server.py + which is used for concurrent environment sessions. 
+ """ + await web_manager.connect_websocket(websocket) + try: + while True: + # Keep connection alive + await websocket.receive_text() + except WebSocketDisconnect: + await web_manager.disconnect_websocket(websocket) + + @app.post("/web/reset") + async def web_reset(): + """Reset endpoint for web interface.""" + return await web_manager.reset_environment() + + @app.post("/web/step") + async def web_step(request: Dict[str, Any]): + """Step endpoint for web interface.""" + # Check if this is a message-based request (chat environment) + if "message" in request: + message = request["message"] + # Convert message to action using the environment's message_to_action method + action = web_manager.env.message_to_action(message) + action_data = {"tokens": action.tokens.tolist()} + else: + action_data = request.get("action", {}) + + return await web_manager.step_environment(action_data) + + @app.get("/web/state") + async def web_state(): + """State endpoint for web interface.""" + return web_manager.get_state() + + return app + + +def get_web_interface_html(action_cls: Type[Action], metadata: Optional[EnvironmentMetadata] = None) -> str: + """Generate the HTML for the web interface.""" + + # Check if this is a chat environment by looking for tokens field + is_chat_env = False + if hasattr(action_cls, "model_fields"): + for field_name, field_info in action_cls.model_fields.items(): + if ( + field_name == "tokens" + and hasattr(field_info.annotation, "__name__") + and "Tensor" in field_info.annotation.__name__ + ): + is_chat_env = True + break + + # Get action fields for dynamic form generation with enhanced metadata + action_fields = _extract_action_fields(action_cls) + + return f""" + + + + + + OpenEnv Web Interface + + + +
    + +
    +
    + + HumanAgent Interface +
    +
    + + {_generate_instructions_section(metadata)} + + + {_generate_action_interface(action_fields, is_chat_env)} + + +
    + + +
    + + +
    +

    Current State

    +
    +
    + Status: + Not initialized +
    +
    + Episode ID: + - +
    +
    + Step Count: + 0 +
    +
    +
    +
    +
    + + +
    +
    + State Observer +
    +
    + +
    +

    Current Observation

    +
    + No observation yet +
    +
    + + +
    +

    Action History

    +
    + No actions taken yet +
    +
    +
    +
    +
    + + + + + """.replace( + "{_generate_action_form_fields(action_fields)}", + _generate_action_form_fields(action_fields), + ) + + +def _generate_instructions_section( + metadata: Optional[EnvironmentMetadata], +) -> str: + """Generate the instructions section with environment documentation.""" + if not metadata or not metadata.readme_content: + return "" + + html_content = _markdown_to_html(metadata.readme_content) + + return f""" + +
    +
    +

    {metadata.name}

    + +
    +
    +
    + {html_content} +
    +
    +
    + """ + + +def _extract_action_fields(action_cls: Type[Action]) -> List[Dict[str, Any]]: + """Extract enhanced field metadata from Action class for form generation.""" + # Use Pydantic's JSON schema generation for robust metadata extraction + try: + schema = action_cls.model_json_schema() + except AttributeError: + # Fallback for non-Pydantic v2 models or if something goes wrong + return [] + + properties = schema.get("properties", {}) + required_fields = schema.get("required", []) + + action_fields = [] + + for field_name, field_info in properties.items(): + if field_name == "metadata": + continue + + # JSON schema "type" can be a string or list/undefined + # Determine our internal input type + input_type = _determine_input_type_from_schema(field_info, field_name) + + is_required = field_name in required_fields + + action_fields.append( + { + "name": field_name, + "type": input_type, + "required": is_required, + "description": field_info.get("description", ""), + "default_value": field_info.get("default"), + "choices": field_info.get("enum"), + "min_value": field_info.get("minimum"), + "max_value": field_info.get("maximum"), + "min_length": field_info.get("minLength"), + "max_length": field_info.get("maxLength"), + "pattern": field_info.get("pattern"), + "placeholder": _generate_placeholder(field_name, field_info), + "help_text": _generate_help_text(field_name, field_info), + } + ) + + return action_fields + + +def _determine_input_type_from_schema(field_info: Dict[str, Any], field_name: str) -> str: + """Determine the appropriate HTML input type from JSON schema info.""" + schema_type = field_info.get("type") + + # Check for specific tensor field convention + if "tokens" in field_name.lower(): + return "tensor" + + if "enum" in field_info: + return "select" + + if schema_type == "boolean": + return "checkbox" + + if schema_type == "integer" or schema_type == "number": + return "number" + + if schema_type == "string": + # Check if it should be a textarea + if 
field_info.get("maxLength", 0) > 100 or "message" in field_name.lower() or "code" in field_name.lower(): + return "textarea" + return "text" + + # Default fallback + return "text" + + +def _generate_placeholder(field_name: str, field_info: Dict[str, Any]) -> str: + """Generate placeholder text.""" + if "message" in field_name.lower(): + return f"Enter {field_name.replace('_', ' ')}..." + elif "code" in field_name.lower(): + return "Enter Python code here..." + elif "tokens" in field_name.lower(): + return "Enter comma-separated token IDs (e.g., 1,2,3,4,5)" + else: + return f"Enter {field_name.replace('_', ' ')}..." + + +def _generate_help_text(field_name: str, field_info: Dict[str, Any]) -> str: + """Generate help text.""" + description = field_info.get("description", "") + if description: + return description + + if "action_id" in field_name.lower(): + return "The action ID to execute in environment" + elif "game_name" in field_name.lower(): + return "Name of game or environment" + elif "tokens" in field_name.lower(): + return "Token IDs as a comma-separated list of integers" + elif "code" in field_name.lower(): + return "Python code to execute in environment" + elif "message" in field_name.lower(): + return "Text message to send" + + return "" + + +def _markdown_to_html(markdown: str) -> str: + """Convert basic markdown to HTML for README display.""" + import html + import re + + # Escape HTML first + html_content = html.escape(markdown) + + # Convert headers + html_content = re.sub(r"^# (.*?)$", r"

    \1

    ", html_content, flags=re.MULTILINE) + html_content = re.sub(r"^## (.*?)$", r"

    \1

    ", html_content, flags=re.MULTILINE) + html_content = re.sub(r"^### (.*?)$", r"

    \1

    ", html_content, flags=re.MULTILINE) + + # Convert code blocks + html_content = re.sub( + r"```(.*?)\n(.*?)\n```", + r"
    \2
    ", + html_content, + flags=re.DOTALL, + ) + html_content = re.sub(r"`([^`]+)`", r"\1", html_content) + + # Convert bold and italic + html_content = re.sub(r"\*\*(.*?)\*\*", r"\1", html_content) + html_content = re.sub(r"\*(.*?)\*", r"\1", html_content) + + # Convert lists + html_content = re.sub(r"^- (.*?)$", r"
  • \1
  • ", html_content, flags=re.MULTILINE) + html_content = re.sub(r"(
  • .*
  • )", r"
      \1
    ", html_content, flags=re.DOTALL) + + # Convert line breaks + html_content = html_content.replace("\n", "
    ") + + return html_content + + +def _generate_action_interface(action_fields: List[Dict[str, Any]], is_chat_env: bool) -> str: + """Generate either a chat interface or action form based on environment type.""" + if is_chat_env: + return _generate_chat_interface() + else: + return _generate_action_form(action_fields) + + +def _generate_chat_interface() -> str: + """Generate a chat-style interface for chat environments.""" + return """ + +
    +

    Chat Interface

    +
    +
    +
    System
    +
    Chat environment ready. Send a message to start the conversation.
    +
    +
    +
    +
    + + +
    +
    + + +
    +
    +
    + """ + + +def _generate_action_form(action_fields: List[Dict[str, Any]]) -> str: + """Generate a traditional action form for non-chat environments.""" + return f""" + +
    +

    Take Action

    +
    + {_generate_action_form_fields(action_fields)} + +
    +
    + """ + + +def _generate_action_form_fields(action_fields: List[Dict[str, Any]]) -> str: + """Generate HTML form fields for action input with enhanced metadata.""" + if not action_fields: + return "

    No action fields available

    " + + fields_html = [] + for field in action_fields: + field_html = _generate_single_field(field) + fields_html.append(field_html) + + return "\n".join(fields_html) + + +def _generate_single_field(field: Dict[str, Any]) -> str: + """Generate HTML for a single form field with enhanced metadata.""" + field_name = field["name"] + field_type = field["type"] + required = field["required"] + placeholder = field.get("placeholder", "") + help_text = field.get("help_text", "") + choices = field.get("choices", []) + min_value = field.get("min_value") + max_value = field.get("max_value") + default_value = field.get("default_value") + min_length = field.get("min_length") + max_length = field.get("max_length") + pattern = field.get("pattern") + + # Build label with required indicator + label_text = field_name.replace("_", " ").title() + if required: + label_text += ' *' + + # Build input attributes + input_attrs = [] + if required: + input_attrs.append("required") + if placeholder: + input_attrs.append(f'placeholder="{placeholder}"') + if min_value is not None: + input_attrs.append(f'min="{min_value}"') + if max_value is not None: + input_attrs.append(f'max="{max_value}"') + if min_length is not None: + input_attrs.append(f'minlength="{min_length}"') + if max_length is not None: + input_attrs.append(f'maxlength="{max_length}"') + if pattern is not None: + input_attrs.append(f'pattern="{pattern}"') + if default_value is not None: + input_attrs.append(f'value="{default_value}"') + + attrs_str = " ".join(input_attrs) + + if field_type == "checkbox": + checked = "checked" if default_value is True else "" + return f''' +
    + + {f'{help_text}' if help_text else ""} +
    + ''' + + elif field_type == "select": + options_html = [] + if not required: + options_html.append(f'') + + for choice in choices: + selected = "selected" if str(choice) == str(default_value) else "" + options_html.append(f'') + + return f''' +
    + + + {f'{help_text}' if help_text else ""} +
    + ''' + + elif field_type == "tensor": + return f''' +
    + + + {help_text or "Enter token IDs as comma-separated integers (e.g., 1,2,3,4,5)"} +
    + ''' + + elif field_type == "textarea": + return f''' +
    + + + {f'{help_text}' if help_text else ""} +
    + ''' + + else: + return f''' +
    + + + {f'{help_text}' if help_text else ""} +
    + ''' From 6ccc4d25723537866d5c0c6c03a9a674cb0e1e70 Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Thu, 18 Dec 2025 11:14:50 +0100 Subject: [PATCH 41/41] add websocket dependency --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index edb6c1f17..b7fa6794a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ core = [ "pydantic>=2.0.0", "uvicorn>=0.24.0", "requests>=2.25.0", + "websockets>=15.0.1", ] cli = [ "typer>=0.9.0",