diff --git a/docs/environment-builder.md b/docs/environment-builder.md index 9fefc9ee1..5007ae1a0 100644 --- a/docs/environment-builder.md +++ b/docs/environment-builder.md @@ -10,7 +10,7 @@ A typical workflow looks like: 1. Scaffold a new environment with `openenv init`. 2. Customize your models, environment logic, and FastAPI server. -3. Implement a typed `HTTPEnvClient`. +3. Implement a typed `EnvClient` (WebSocket-based for persistent sessions). 4. Configure dependencies and the Dockerfile once. 5. Use the CLI (`openenv build`, `openenv validate`, `openenv push`) to package and share your work. @@ -58,33 +58,26 @@ my_env/ └── Dockerfile ``` -Python classes are generated for the action, observation, and state, and a client is generated for the environment. For example, you will find `MyEnvironment`, `MyAction`, `MyObservation`, and `MyState` in the `my_env` directory based on the name of the environment you provided. +Python classes are generated for the action, observation, environment, and client. For example, you will find `MyEnvironment`, `MyAction`, `MyObservation`, and `MyEnv` (client) in the `my_env` directory based on the name you provided. The environment uses the core `State` class from `openenv.core.env_server.types`. ### 2. 
Define Models -Edit `models.py` to describe your action, observation, and state dataclasses: +Edit `models.py` to describe your action and observation using Pydantic: ```python # models.py -from dataclasses import dataclass -from openenv.core.env_server import Action, Observation, State +from pydantic import Field +from openenv.core.env_server.types import Action, Observation -@dataclass class MyAction(Action): """Your custom action.""" - command: str - parameters: dict + command: str = Field(..., description="Command to execute") + parameters: dict = Field(default_factory=dict, description="Command parameters") -@dataclass class MyObservation(Observation): """Your custom observation.""" - result: str - success: bool - -@dataclass -class MyState(State): - """Custom state fields.""" - custom_field: int = 0 + result: str = Field(..., description="Result of the action") + success: bool = Field(..., description="Whether the action succeeded") ``` ### 3. Implement Environment Logic @@ -93,70 +86,104 @@ Customize `server/my_environment.py` by extending `Environment`: ```python # server/my_environment.py -import uuid -from openenv.core.env_server import Environment -from ..models import MyAction, MyObservation, MyState +from uuid import uuid4 +from openenv.core.env_server.interfaces import Environment +from openenv.core.env_server.types import State +from models import MyAction, MyObservation class MyEnvironment(Environment): def __init__(self): - super().__init__() - self._state = MyState() + self._state = State(episode_id=str(uuid4()), step_count=0) def reset(self) -> MyObservation: - self._state = MyState(episode_id=str(uuid.uuid4())) - return MyObservation(result="Ready", success=True) + self._state = State(episode_id=str(uuid4()), step_count=0) + return MyObservation(result="Ready", success=True, done=False, reward=0.0) def step(self, action: MyAction) -> MyObservation: # Implement your logic here self._state.step_count += 1 result = 
self._execute_command(action.command) - return MyObservation(result=result, success=True) + return MyObservation(result=result, success=True, done=False, reward=1.0) @property - def state(self) -> MyState: + def state(self) -> State: return self._state ``` ### 4. Create the FastAPI Server -`server/app.py` should expose the environment through `create_fastapi_app`: +`server/app.py` should expose the environment through `create_app`. + +**Important:** You must pass a class or factory function (not an instance) to enable WebSocket-based concurrent sessions: + +```python +# server/app.py +from openenv.core.env_server import create_app +from ..models import MyAction, MyObservation +from .my_environment import MyEnvironment + +# Pass the class (factory) - each WebSocket session gets its own instance +app = create_app(MyEnvironment, MyAction, MyObservation, env_name="my_env") +``` + +For environments with constructor arguments, create a factory function: ```python # server/app.py -from openenv.core.env_server import create_fastapi_app +import os +from openenv.core.env_server import create_app from ..models import MyAction, MyObservation from .my_environment import MyEnvironment -env = MyEnvironment() -app = create_fastapi_app(env, MyAction, MyObservation) +# Read config from environment variables +api_key = os.getenv("MY_API_KEY") +timeout = int(os.getenv("MY_TIMEOUT", "30")) + +def create_my_environment(): + """Factory function that creates MyEnvironment with config.""" + return MyEnvironment(api_key=api_key, timeout=timeout) + +# Pass the factory function +app = create_app(create_my_environment, MyAction, MyObservation, env_name="my_env") ``` ### 5. 
Implement the Client -`client.py` extends `HTTPEnvClient` so users can interact with your server over HTTP or Docker: +`client.py` extends `EnvClient` so users can interact with your server via WebSocket for persistent sessions: ```python # client.py -from openenv.core.http_env_client import HTTPEnvClient -from openenv.core.types import StepResult +from openenv.core.env_client import EnvClient +from openenv.core.client_types import StepResult +from openenv.core.env_server.types import State -from .models import MyAction, MyObservation, MyState +from .models import MyAction, MyObservation -class MyEnv(HTTPEnvClient[MyAction, MyObservation]): +class MyEnv(EnvClient[MyAction, MyObservation, State]): def _step_payload(self, action: MyAction) -> dict: return {"command": action.command, "parameters": action.parameters} def _parse_result(self, payload: dict) -> StepResult[MyObservation]: - obs = MyObservation(**payload["observation"]) + obs_data = payload.get("observation", {}) + obs = MyObservation( + result=obs_data.get("result", ""), + success=obs_data.get("success", False), + done=payload.get("done", False), + reward=payload.get("reward"), + ) return StepResult( observation=obs, reward=payload.get("reward"), done=payload.get("done", False), ) - def _parse_state(self, payload: dict) -> MyState: - return MyState(**payload) + def _parse_state(self, payload: dict) -> State: + return State( + episode_id=payload.get("episode_id"), + step_count=payload.get("step_count", 0), + ) ``` +The `EnvClient` maintains a persistent WebSocket connection to the server, enabling efficient multi-step interactions with lower latency compared to HTTP. Each client instance gets its own dedicated environment session on the server. + ### 6. Configure Dependencies & Dockerfile The CLI template ships with `pyproject.toml` and `server/Dockerfile`. You should manage your python dependencies with `uv` or `pip` in the `pyproject.toml` file. Other dependencies should be installed in the Dockerfile.
@@ -322,22 +349,29 @@ client = MyEnv.from_hub("my-org/my-env") # Or, connect to the local server client = MyEnv(base_url="http://localhost:8000") -# Reset -result = client.reset() -print(result.observation.result) # "Ready" - -# Execute actions -result = client.step(MyAction(command="test", parameters={})) -print(result.observation.result) -print(result.observation.success) - -# Get state -state = client.state() -print(state.episode_id) -print(state.step_count) - -# Cleanup -client.close() +# Use context manager for automatic cleanup (recommended) +with client: + # Reset + result = client.reset() + print(result.observation.result) # "Ready" + + # Execute actions + result = client.step(MyAction(command="test", parameters={})) + print(result.observation.result) + print(result.observation.success) + + # Get state + state = client.state() + print(state.episode_id) + print(state.step_count) + +# Or manually manage the connection +client = MyEnv(base_url="http://localhost:8000") +try: + result = client.reset() + result = client.step(MyAction(command="test", parameters={})) +finally: + client.close() ``` ## Nice work! You've now built and used your own OpenEnv environment. diff --git a/envs/atari_env/client.py b/envs/atari_env/client.py index cbdb373f5..458895454 100644 --- a/envs/atari_env/client.py +++ b/envs/atari_env/client.py @@ -5,10 +5,10 @@ # LICENSE file in the root directory of this source tree. """ -Atari Environment HTTP Client. +Atari Environment Client. This module provides the client for connecting to an Atari Environment server -over HTTP. +via WebSocket for persistent sessions.
""" from __future__ import annotations @@ -17,7 +17,7 @@ from openenv.core.client_types import StepResult -from openenv.core.http_env_client import HTTPEnvClient +from openenv.core.env_client import EnvClient from .models import AtariAction, AtariObservation, AtariState @@ -25,28 +25,30 @@ from openenv.core.containers.runtime import ContainerProvider -class AtariEnv(HTTPEnvClient[AtariAction, AtariObservation]): +class AtariEnv(EnvClient[AtariAction, AtariObservation, AtariState]): """ - HTTP client for Atari Environment. + Client for Atari Environment. - This client connects to an AtariEnvironment HTTP server and provides - methods to interact with it: reset(), step(), and state access. + This client maintains a persistent WebSocket connection to the environment + server, enabling efficient multi-step interactions with lower latency. Example: >>> # Connect to a running server - >>> client = AtariEnv(base_url="http://localhost:8000") - >>> result = client.reset() - >>> print(result.observation.screen_shape) - >>> - >>> # Take an action - >>> result = client.step(AtariAction(action_id=2)) # UP - >>> print(result.reward, result.done) + >>> with AtariEnv(base_url="http://localhost:8000") as client: + ... result = client.reset() + ... print(result.observation.screen_shape) + ... + ... result = client.step(AtariAction(action_id=2)) # UP + ... print(result.reward, result.done) Example with Docker: >>> # Automatically start container and connect >>> client = AtariEnv.from_docker_image("atari-env:latest") - >>> result = client.reset() - >>> result = client.step(AtariAction(action_id=0)) # NOOP + >>> try: + ... result = client.reset() + ... result = client.step(AtariAction(action_id=0)) # NOOP + ... finally: + ... 
client.close() """ def _step_payload(self, action: AtariAction) -> Dict[str, Any]: diff --git a/envs/atari_env/server/app.py b/envs/atari_env/server/app.py index 14254f6d9..036e44ef3 100644 --- a/envs/atari_env/server/app.py +++ b/envs/atari_env/server/app.py @@ -8,7 +8,7 @@ FastAPI application for the Atari Environment. This module creates an HTTP server that exposes Atari games -over HTTP endpoints, making them compatible with HTTPEnvClient. +over HTTP and WebSocket endpoints, compatible with EnvClient. Usage: # Development (with auto-reload): @@ -52,19 +52,24 @@ mode = int(mode) if mode is not None else None difficulty = int(difficulty) if difficulty is not None else None -# Create the environment instance -env = AtariEnvironment( - game_name=game_name, - obs_type=obs_type, - full_action_space=full_action_space, - mode=mode, - difficulty=difficulty, - repeat_action_probability=repeat_action_prob, - frameskip=frameskip, -) + +# Factory function to create AtariEnvironment instances +def create_atari_environment(): + """Factory function that creates AtariEnvironment with config.""" + return AtariEnvironment( + game_name=game_name, + obs_type=obs_type, + full_action_space=full_action_space, + mode=mode, + difficulty=difficulty, + repeat_action_probability=repeat_action_prob, + frameskip=frameskip, + ) + # Create the FastAPI app with web interface and README integration -app = create_app(env, AtariAction, AtariObservation, env_name="atari_env") +# Pass the factory function instead of an instance for WebSocket session support +app = create_app(create_atari_environment, AtariAction, AtariObservation, env_name="atari_env") if __name__ == "__main__": diff --git a/envs/browsergym_env/client.py b/envs/browsergym_env/client.py index 5b6d3772d..cb7437f9d 100644 --- a/envs/browsergym_env/client.py +++ b/envs/browsergym_env/client.py @@ -1,8 +1,9 @@ -"""HTTP client for the BrowserGym environment.""" +"""Client for the BrowserGym environment.""" from typing import Any, Dict 
-from openenv.core.http_env_client import HTTPEnvClient, StepResult +from openenv.core.client_types import StepResult +from openenv.core.env_client import EnvClient from .models import ( BrowserGymAction, BrowserGymObservation, @@ -10,8 +11,8 @@ ) -class BrowserGymEnv(HTTPEnvClient[BrowserGymAction, BrowserGymObservation]): - """Client for interacting with the BrowserGym environment over HTTP. +class BrowserGymEnv(EnvClient[BrowserGymAction, BrowserGymObservation, BrowserGymState]): + """Client for interacting with the BrowserGym environment. BrowserGym provides unified access to multiple web navigation benchmarks: - MiniWoB++: 100+ training tasks (no external infrastructure needed!) diff --git a/envs/browsergym_env/server/app.py b/envs/browsergym_env/server/app.py index 488b66974..fa8214dc3 100644 --- a/envs/browsergym_env/server/app.py +++ b/envs/browsergym_env/server/app.py @@ -15,19 +15,24 @@ timeout = float(os.environ.get("BROWSERGYM_TIMEOUT", "10000")) port = int(os.environ.get("BROWSERGYM_PORT", "8000")) -# Create the environment instance -env = BrowserGymEnvironment( - benchmark=benchmark, - task_name=task_name, - headless=headless, - viewport_width=viewport_width, - viewport_height=viewport_height, - timeout=timeout, -) + +# Factory function to create BrowserGymEnvironment instances +def create_browsergym_environment(): + """Factory function that creates BrowserGymEnvironment with config.""" + return BrowserGymEnvironment( + benchmark=benchmark, + task_name=task_name, + headless=headless, + viewport_width=viewport_width, + viewport_height=viewport_height, + timeout=timeout, + ) + # Create the FastAPI app +# Pass the factory function instead of an instance for WebSocket session support app = create_app( - env, + create_browsergym_environment, BrowserGymAction, BrowserGymObservation, env_name="browsergym_env", diff --git a/envs/chat_env/client.py b/envs/chat_env/client.py index d14829f74..a1b265cd4 100644 --- a/envs/chat_env/client.py +++ 
b/envs/chat_env/client.py @@ -5,10 +5,10 @@ # LICENSE file in the root directory of this source tree. """ -Chat Environment HTTP Client. +Chat Environment Client. This module provides the client for connecting to a Chat Environment server -over HTTP. +via WebSocket for persistent sessions. """ from typing import Any, Dict @@ -17,40 +17,42 @@ from openenv.core.client_types import StepResult from openenv.core.env_server.interfaces import Message -from openenv.core.env_server.types import State -from openenv.core.http_env_client import HTTPEnvClient +from openenv.core.env_client import EnvClient from .models import ChatAction, ChatObservation, ChatState -class ChatEnv(HTTPEnvClient[ChatAction, ChatObservation]): +class ChatEnv(EnvClient[ChatAction, ChatObservation, ChatState]): """ - HTTP client for the Chat Environment. + Client for the Chat Environment. - This client connects to a ChatEnvironment HTTP server and provides - methods to interact with it: reset(), step(), and state access. + This client maintains a persistent WebSocket connection to the environment + server, enabling efficient multi-step interactions with lower latency. - Note: Since ChatEnvironment works with PyTorch tensors, the HTTP layer + Note: Since ChatEnvironment works with PyTorch tensors, the client serializes tokens as lists for transport and deserializes them back to tensors. Example: >>> # Connect to a running server - >>> client = ChatEnv(base_url="http://localhost:8000") - >>> result = client.reset() - >>> print(result.observation.messages) - >>> - >>> # Send an action with tokens - >>> import torch - >>> tokens = torch.tensor([[1, 2, 3, 4, 5]]) - >>> result = client.step(ChatAction(tokens=tokens)) - >>> print(result.observation.messages) - >>> print(result.reward) + >>> with ChatEnv(base_url="http://localhost:8000") as client: + ... result = client.reset() + ... print(result.observation.messages) + ... + ... # Send an action with tokens + ... import torch + ... 
tokens = torch.tensor([[1, 2, 3, 4, 5]]) + ... result = client.step(ChatAction(tokens=tokens)) + ... print(result.observation.messages) + ... print(result.reward) Example with Docker: >>> # Automatically start container and connect >>> client = ChatEnv.from_docker_image("chat-env:latest") - >>> result = client.reset() - >>> result = client.step(ChatAction(tokens=torch.tensor([[1, 2, 3]]))) + >>> try: + ... result = client.reset() + ... result = client.step(ChatAction(tokens=torch.tensor([[1, 2, 3]]))) + ... finally: + ... client.close() """ def _step_payload(self, action: ChatAction) -> Dict: diff --git a/envs/chat_env/server/app.py b/envs/chat_env/server/app.py index 719b5ede8..88b9694f7 100644 --- a/envs/chat_env/server/app.py +++ b/envs/chat_env/server/app.py @@ -8,7 +8,7 @@ FastAPI application for the Chat Environment. This module creates an HTTP server that exposes the ChatEnvironment -over HTTP endpoints, making it compatible with HTTPEnvClient. +over HTTP and WebSocket endpoints, compatible with EnvClient. Note: This server requires a tokenizer to be initialized. The tokenizer must be specified when starting the server. 
@@ -27,7 +27,6 @@ import os from openenv.core.env_server import create_app -from openenv.core.env_server.web_interface import create_web_interface_app from ..models import ChatAction, ChatObservation from .chat_environment import ChatEnvironment @@ -64,12 +63,17 @@ def get_tokenizer(): # Get system prompt from environment system_prompt = os.environ.get("SYSTEM_PROMPT", None) -# Create the environment instance with tokenizer -tokenizer = get_tokenizer() -env = ChatEnvironment(tokenizer=tokenizer, system_prompt=system_prompt) + +# Factory function to create ChatEnvironment instances +def create_chat_environment(): + """Factory function that creates ChatEnvironment with tokenizer.""" + tokenizer = get_tokenizer() + return ChatEnvironment(tokenizer=tokenizer, system_prompt=system_prompt) + # Create the FastAPI app with web interface and README integration -app = create_app(env, ChatAction, ChatObservation, env_name="chat_env") +# Pass the factory function instead of an instance for WebSocket session support +app = create_app(create_chat_environment, ChatAction, ChatObservation, env_name="chat_env") if __name__ == "__main__": diff --git a/envs/coding_env/client.py b/envs/coding_env/client.py index 544b6a6e0..a05db092e 100644 --- a/envs/coding_env/client.py +++ b/envs/coding_env/client.py @@ -2,11 +2,13 @@ CodingEnv --------- Client-side wrapper for the Coding environment server. -Talks HTTP to a single base_url exposing: /reset and /step. + +This client maintains a persistent WebSocket connection to the environment +server, enabling efficient multi-step interactions with lower latency. - users instantiate CodingEnv with a base_url provided by the higher-level vector/orchestration layer. -- Environment authors ship the Docker image that serves the HTTP API. +- Environment authors ship the Docker image that serves the API. (Seeds, episode IDs, request IDs, capabilities can be added later in the payloads.)
""" @@ -14,13 +16,12 @@ from __future__ import annotations from openenv.core.client_types import StepResult +from openenv.core.env_client import EnvClient -from openenv.core.http_env_client import HTTPEnvClient - -from coding_env.models import CodeAction, CodeObservation, CodeState +from .models import CodeAction, CodeObservation, CodeState -class CodingEnv(HTTPEnvClient[CodeAction, CodeObservation]): +class CodingEnv(EnvClient[CodeAction, CodeObservation, CodeState]): - # --- HTTPEnvClient abstract hooks --- + # --- EnvClient abstract hooks --- def _step_payload(self, action: CodeAction) -> dict: diff --git a/envs/coding_env/server/app.py b/envs/coding_env/server/app.py index b636d0784..4859585fa 100644 --- a/envs/coding_env/server/app.py +++ b/envs/coding_env/server/app.py @@ -8,7 +8,7 @@ FastAPI application for the Coding Environment. This module creates an HTTP server that exposes the PythonCodeActEnv -over HTTP endpoints, making it compatible with HTTPEnvClient. +over HTTP and WebSocket endpoints, compatible with EnvClient.
Usage: # Development (with auto-reload): @@ -26,11 +26,9 @@ from coding_env.models import CodeAction, CodeObservation from coding_env.server.python_codeact_env import PythonCodeActEnv -# Create the environment instance -env = PythonCodeActEnv() - # Create the app with web interface and README integration -app = create_app(env, CodeAction, CodeObservation, env_name="coding_env") +# Pass the class (factory) instead of an instance for WebSocket session support +app = create_app(PythonCodeActEnv, CodeAction, CodeObservation, env_name="coding_env") if __name__ == "__main__": diff --git a/envs/coding_env/server/python_codeact_env.py b/envs/coding_env/server/python_codeact_env.py index ed95135d1..a73ed1e55 100644 --- a/envs/coding_env/server/python_codeact_env.py +++ b/envs/coding_env/server/python_codeact_env.py @@ -14,9 +14,9 @@ import uuid from openenv.core.env_server.interfaces import Action, Environment, Observation -from coding_env.server.python_executor import PyExecutor +from .python_executor import PyExecutor -from coding_env.models import CodeAction, CodeObservation, CodeState +from ..models import CodeAction, CodeObservation, CodeState from .transforms import create_safe_coding_transform diff --git a/envs/connect4_env/client.py b/envs/connect4_env/client.py index a462929a0..d9f6c2165 100644 --- a/envs/connect4_env/client.py +++ b/envs/connect4_env/client.py @@ -5,10 +5,10 @@ # LICENSE file in the root directory of this source tree. """ -Connect4 Environment HTTP Client. +Connect4 Environment Client. This module provides the client for connecting to a Connect4 Environment server -over HTTP. +via WebSocket for persistent sessions. 
""" from __future__ import annotations @@ -16,7 +16,7 @@ from typing import Any, Dict, TYPE_CHECKING from openenv.core.client_types import StepResult -from openenv.core.http_env_client import HTTPEnvClient +from openenv.core.env_client import EnvClient from .models import Connect4Action, Connect4Observation, Connect4State @@ -24,21 +24,20 @@ from openenv.core.containers.runtime import ContainerProvider -class Connect4Env(HTTPEnvClient[Connect4Action, Connect4Observation]): +class Connect4Env(EnvClient[Connect4Action, Connect4Observation, Connect4State]): """ - HTTP client for Connect4 Environment. + Client for Connect4 Environment. - This client connects to a Connect4Environment HTTP server and provides - methods to interact with it: reset(), step(), and state access. + This client maintains a persistent WebSocket connection to the environment + server, enabling efficient multi-step interactions with lower latency. Example: - >>> client = Connect4Env(base_url="http://localhost:8000") - >>> result = client.reset() - >>> print(result.observation.board) - >>> - >>> # Take an action - >>> result = client.step(Connect4Action(column=3)) - >>> print(result.reward, result.done) + >>> with Connect4Env(base_url="http://localhost:8000") as client: + ... result = client.reset() + ... print(result.observation.board) + ... + ... result = client.step(Connect4Action(column=3)) + ... 
print(result.reward, result.done) """ def _step_payload(self, action: Connect4Action) -> Dict[str, Any]: diff --git a/envs/connect4_env/models.py b/envs/connect4_env/models.py index 8cf3309a8..90ee90742 100644 --- a/envs/connect4_env/models.py +++ b/envs/connect4_env/models.py @@ -12,14 +12,12 @@ """ from __future__ import annotations -from dataclasses import dataclass, field -import numpy as np -from typing import List +from typing import List, Dict, Any +from pydantic import Field from openenv.core.env_server import Action, Observation, State -@dataclass class Connect4Action(Action): """ Action for Connect4 environment. @@ -30,7 +28,6 @@ class Connect4Action(Action): column: int -@dataclass(kw_only=True) class Connect4Observation(Observation): """ Observation for Connect4 environment. @@ -43,15 +40,10 @@ class Connect4Observation(Observation): reward: Reward for the last action. """ - board: List[List[int]] - legal_actions: List[int] - done: bool = False - reward: float = 0.0 - metadata: dict = field(default_factory=dict) - + board: List[List[int]] = Field(default_factory=list) + legal_actions: List[int] = Field(default_factory=list) -@dataclass(kw_only=True) class Connect4State(State): """ State for Connect4 environment. @@ -62,7 +54,5 @@ class Connect4State(State): next_player: Whose turn it is (1 or -1). step_count: Number of steps taken in the game. 
""" - episode_id: str - board: List[List[int]] = field(default_factory=lambda: np.zeros((6,7), dtype=int).tolist()) + board: List[List[int]] = Field(default_factory=lambda: [[0]*7 for _ in range(6)]) next_player: int = 1 - step_count: int = 0 diff --git a/envs/connect4_env/server/app.py b/envs/connect4_env/server/app.py index 143ee1770..2025b2c37 100644 --- a/envs/connect4_env/server/app.py +++ b/envs/connect4_env/server/app.py @@ -1,9 +1,12 @@ -from openenv.core.env_server import create_fastapi_app +"""FastAPI application for the Connect4 Environment.""" + +from openenv.core.env_server import create_app from ..models import Connect4Action, Connect4Observation from .connect4_environment import Connect4Environment -env = Connect4Environment() -app = create_fastapi_app(env, Connect4Action, Connect4Observation) +# Create the FastAPI app +# Pass the class (factory) instead of an instance for WebSocket session support +app = create_app(Connect4Environment, Connect4Action, Connect4Observation, env_name="connect4_env") if __name__ == "__main__": diff --git a/envs/dipg_safety_env/client.py b/envs/dipg_safety_env/client.py index 9e556481f..2d11503b3 100644 --- a/envs/dipg_safety_env/client.py +++ b/envs/dipg_safety_env/client.py @@ -3,22 +3,24 @@ Client implementation for the custom DIPGSafetyEnv. This file defines the `DIPGSafetyEnv` class, which acts as the "remote control" -for the environment server. Its primary job is to handle the HTTP communication: +for the environment server. It maintains a persistent WebSocket connection +for efficient multi-step interactions: 1. It takes Python objects (like an Action) from the agent's code. 2. It converts them into JSON to send to the server. 3. It receives JSON responses from the server. 4. It parses that JSON back into useful Python objects (like Observations and Rewards). 
""" -from openenv.core.http_env_client import HTTPEnvClient, StepResult +from openenv.core.client_types import StepResult +from openenv.core.env_client import EnvClient from .models import DIPGAction, DIPGObservation, DIPGState -class DIPGSafetyEnv(HTTPEnvClient[DIPGAction, DIPGObservation]): +class DIPGSafetyEnv(EnvClient[DIPGAction, DIPGObservation, DIPGState]): """ Client for interacting with the `DIPGSafetyEnv` server. - This class inherits from the base `HTTPEnvClient` and is specialized to handle + This class inherits from the base `EnvClient` and is specialized to handle the specific data types of our environment: `DIPGAction` and `DIPGObservation`. """ @@ -31,8 +33,8 @@ def __init__(self, base_url: str, timeout: float = 60.0): timeout: The number of seconds to wait for a server response. """ # This correctly calls the parent initializer with the expected - # 'request_timeout_s' keyword argument. - super().__init__(base_url=base_url, request_timeout_s=timeout) + # 'message_timeout_s' keyword argument. 
+ super().__init__(base_url=base_url, message_timeout_s=timeout) # ---------------------------------------- def _step_payload(self, action: DIPGAction) -> dict: diff --git a/envs/dipg_safety_env/models.py b/envs/dipg_safety_env/models.py index dbd9e04ec..a770e7355 100644 --- a/envs/dipg_safety_env/models.py +++ b/envs/dipg_safety_env/models.py @@ -1,24 +1,25 @@ # envs/dipg_safety_env/models.py -from dataclasses import dataclass, field +from typing import Dict, Any +from pydantic import Field from openenv.core.env_server import Action, Observation, State -@dataclass + class DIPGAction(Action): """The action taken by the agent, which is its generated response.""" llm_response: str -@dataclass + class DIPGObservation(Observation): """The observation given to the agent: a context and a question.""" - context: str - question: str + context: str = "" + question: str = "" + -@dataclass class DIPGState(State): """The internal state of the environment for tracking the current challenge.""" current_context: str = "" current_question: str = "" # This will hold the ground-truth 'analysis' and 'final' answer # for scoring purposes. - expected_answer: dict = field(default_factory=dict) \ No newline at end of file + expected_answer: Dict[str, Any] = Field(default_factory=dict) \ No newline at end of file diff --git a/envs/dipg_safety_env/server/app.py b/envs/dipg_safety_env/server/app.py index 74d9fe87d..f6dcaa8a1 100644 --- a/envs/dipg_safety_env/server/app.py +++ b/envs/dipg_safety_env/server/app.py @@ -1,4 +1,11 @@ # envs/dipg_safety_env/server/app.py +""" +FastAPI application for the DIPG Safety Environment. + +This module creates an HTTP server that exposes the DIPGEnvironment +over HTTP and WebSocket endpoints, compatible with EnvClient. 
+""" + import os from openenv.core.env_server import create_app from .dipg_environment import DIPGEnvironment @@ -49,33 +56,27 @@ FINAL_CHANNEL_START = os.environ.get("FINAL_CHANNEL_START", "<|channel|>final<|message|>") CHANNEL_END = os.environ.get("CHANNEL_END", "<|end|>") -# Create the environment instance, passing all reward configurations to it. -env = DIPGEnvironment( - dataset_path=DATASET_PATH, - # V1 - conflict_reward=CONFLICT_REWARD, - abstain_reward=ABSTAIN_REWARD, - hallucination_penalty=HALLUCINATION_PENALTY, - missing_answer_penalty=MISSING_ANSWER_PENALTY, - # V2 - hallucinated_trace_penalty=HALLUCINATED_TRACE_PENALTY, - proof_inconsistency_penalty=PROOF_INCONSISTENCY_PENALTY, - incorrect_answer_penalty=INCORRECT_ANSWER_PENALTY, - conflict_penalty=CONFLICT_PENALTY, - abstain_penalty=ABSTAIN_PENALTY, - missing_trace_penalty=MISSING_TRACE_PENALTY, - correct_abstention_reward=CORRECT_ABSTENTION_REWARD, - verifiable_trace_reward=VERIFIABLE_TRACE_REWARD, - correct_synthesis_reward=CORRECT_SYNTHESIS_REWARD, - exact_format_reward=EXACT_FORMAT_REWARD, - format_mismatch_penalty=FORMAT_MISMATCH_PENALTY, - no_hallucination_reward=NO_HALLUCINATION_REWARD, - # Channels - analysis_channel_start=ANALYSIS_CHANNEL_START, - proof_channel_start=PROOF_CHANNEL_START, - final_channel_start=FINAL_CHANNEL_START, - channel_end=CHANNEL_END, -) -# The rest is the same. 
-app = create_app(env, DIPGAction, DIPGObservation, env_name="dipg_safety_env") \ No newline at end of file +# Factory function to create DIPGEnvironment instances +def create_dipg_environment(): + """Factory function that creates DIPGEnvironment with config.""" + return DIPGEnvironment( + dataset_path=DATASET_PATH, + conflict_reward=CONFLICT_REWARD, + conflict_penalty=CONFLICT_PENALTY, + abstain_reward=ABSTAIN_REWARD, + abstain_penalty=ABSTAIN_PENALTY, + format_mismatch_penalty=FORMAT_MISMATCH_PENALTY, + exact_format_reward=EXACT_FORMAT_REWARD, + hallucination_penalty=HALLUCINATION_PENALTY, + no_hallucination_reward=NO_HALLUCINATION_REWARD, + missing_answer_penalty=MISSING_ANSWER_PENALTY, + analysis_channel_start=ANALYSIS_CHANNEL_START, + final_channel_start=FINAL_CHANNEL_START, + channel_end=CHANNEL_END, + ) + + +# Create the FastAPI app +# Pass the factory function instead of an instance for WebSocket session support +app = create_app(create_dipg_environment, DIPGAction, DIPGObservation, env_name="dipg_safety_env") diff --git a/envs/dipg_safety_env/server/dipg_environment.py b/envs/dipg_safety_env/server/dipg_environment.py index cc70cf616..c8a596ba8 100644 --- a/envs/dipg_safety_env/server/dipg_environment.py +++ b/envs/dipg_safety_env/server/dipg_environment.py @@ -3,7 +3,7 @@ import json import random from pathlib import Path -from openenv.core.http_env_client import StepResult +from openenv.core.client_types import StepResult from openenv.core.env_server import Environment from ..models import DIPGAction, DIPGObservation, DIPGState import re diff --git a/envs/echo_env/client.py b/envs/echo_env/client.py index fcb82e5ca..9c7ee2c64 100644 --- a/envs/echo_env/client.py +++ b/envs/echo_env/client.py @@ -5,10 +5,10 @@ # LICENSE file in the root directory of this source tree. """ -Echo Environment HTTP Client. +Echo Environment Client. This module provides the client for connecting to an Echo Environment server -over HTTP. +via WebSocket for persistent sessions. 
""" from typing import Any, Dict @@ -18,39 +18,42 @@ # In-repo imports (when running from OpenEnv repository) from openenv.core.client_types import StepResult from openenv.core.env_server.types import State - from openenv.core.http_env_client import HTTPEnvClient + from openenv.core.env_client import EnvClient from .models import EchoAction, EchoObservation except ImportError: # Standalone imports (when environment is standalone with openenv from pip) from openenv.core.client_types import StepResult from openenv.core.env_server.types import State - from openenv.core.http_env_client import HTTPEnvClient + from openenv.core.env_client import EnvClient from models import EchoAction, EchoObservation -class EchoEnv(HTTPEnvClient[EchoAction, EchoObservation]): +class EchoEnv(EnvClient[EchoAction, EchoObservation, State]): """ - HTTP client for the Echo Environment. + Client for the Echo Environment. - This client connects to an EchoEnvironment HTTP server and provides - methods to interact with it: reset(), step(), and state access. + This client maintains a persistent WebSocket connection to the environment + server, enabling efficient multi-step interactions with lower latency. + Each client instance has its own dedicated environment session on the server. Example: >>> # Connect to a running server - >>> client = EchoEnv(base_url="http://localhost:8000") - >>> result = client.reset() - >>> print(result.observation.echoed_message) - >>> - >>> # Send a message - >>> result = client.step(EchoAction(message="Hello!")) - >>> print(result.observation.echoed_message) - >>> print(result.reward) + >>> with EchoEnv(base_url="http://localhost:8000") as client: + ... result = client.reset() + ... print(result.observation.echoed_message) + ... + ... result = client.step(EchoAction(message="Hello!")) + ... print(result.observation.echoed_message) + ... 
print(result.reward) Example with Docker: >>> # Automatically start container and connect >>> client = EchoEnv.from_docker_image("echo-env:latest") - >>> result = client.reset() - >>> result = client.step(EchoAction(message="Test")) + >>> try: + ... result = client.reset() + ... result = client.step(EchoAction(message="Test")) + ... finally: + ... client.close() """ def _step_payload(self, action: EchoAction) -> Dict: diff --git a/envs/echo_env/models.py b/envs/echo_env/models.py index 4cbf1016c..3032b7511 100644 --- a/envs/echo_env/models.py +++ b/envs/echo_env/models.py @@ -10,7 +10,7 @@ The Echo environment is a simple test environment that echoes back messages. """ -from dataclasses import dataclass +from pydantic import Field # Support both in-repo and standalone imports try: @@ -21,16 +21,14 @@ from openenv.core.env_server.types import Action, Observation -@dataclass(kw_only=True) class EchoAction(Action): """Action for the Echo environment - just a message to echo.""" - message: str + message: str = Field(..., min_length=1, description="Message to echo back") -@dataclass(kw_only=True) class EchoObservation(Observation): """Observation from the Echo environment - the echoed message.""" - echoed_message: str - message_length: int = 0 \ No newline at end of file + echoed_message: str = Field(..., description="The echoed message from the environment") + message_length: int = Field(default=0, ge=0, description="Length of the echoed message") \ No newline at end of file diff --git a/envs/echo_env/server/app.py b/envs/echo_env/server/app.py index 96c803040..07fe59ecb 100644 --- a/envs/echo_env/server/app.py +++ b/envs/echo_env/server/app.py @@ -8,7 +8,7 @@ FastAPI application for the Echo Environment. This module creates an HTTP server that exposes the EchoEnvironment -over HTTP endpoints, making it compatible with HTTPEnvClient. +over HTTP and WebSocket endpoints, compatible with EnvClient. 
Usage: # Development (with auto-reload): @@ -33,11 +33,9 @@ from models import EchoAction, EchoObservation from server.echo_environment import EchoEnvironment -# Create the environment instance -env = EchoEnvironment() - # Create the app with web interface and README integration -app = create_app(env, EchoAction, EchoObservation, env_name="echo_env") +# Pass the class (factory) instead of an instance for WebSocket session support +app = create_app(EchoEnvironment, EchoAction, EchoObservation, env_name="echo_env") def main(): diff --git a/envs/finrl_env/client.py b/envs/finrl_env/client.py index 38ab07382..9fb1a51ed 100644 --- a/envs/finrl_env/client.py +++ b/envs/finrl_env/client.py @@ -5,10 +5,10 @@ # LICENSE file in the root directory of this source tree. """ -FinRL Environment HTTP Client. +FinRL Environment Client. This module provides the client for connecting to a FinRL Environment server -over HTTP. +via WebSocket for persistent sessions. """ from typing import Any, Dict @@ -16,81 +16,69 @@ from openenv.core.client_types import StepResult from openenv.core.env_server.types import State -from openenv.core.http_env_client import HTTPEnvClient +from openenv.core.env_client import EnvClient from .models import FinRLAction, FinRLObservation -class FinRLEnv(HTTPEnvClient[FinRLAction, FinRLObservation]): +class FinRLEnv(EnvClient[FinRLAction, FinRLObservation, State]): """ - HTTP client for the FinRL Environment. + Client for the FinRL Environment. - This client connects to a FinRLEnvironment HTTP server and provides - methods to interact with it for stock trading RL tasks. + This client maintains a persistent WebSocket connection to the environment + server, enabling efficient multi-step interactions for stock trading RL tasks. 
Example: >>> # Connect to a running server - >>> client = FinRLEnv(base_url="http://localhost:8000") - >>> result = client.reset() - >>> print(result.observation.state) - >>> print(result.observation.portfolio_value) - >>> - >>> # Execute a trading action - >>> action = FinRLAction(actions=[0.5, -0.3]) # Buy stock 0, sell stock 1 - >>> result = client.step(action) - >>> print(result.reward) - >>> print(result.observation.portfolio_value) + >>> with FinRLEnv(base_url="http://localhost:8000") as client: + ... result = client.reset() + ... print(result.observation.state) + ... print(result.observation.portfolio_value) + ... + ... # Execute a trading action + ... action = FinRLAction(actions=[0.5, -0.3]) # Buy stock 0, sell stock 1 + ... result = client.step(action) + ... print(result.reward) + ... print(result.observation.portfolio_value) Example with Docker: >>> # Automatically start container and connect >>> client = FinRLEnv.from_docker_image("finrl-env:latest") - >>> result = client.reset() - >>> result = client.step(FinRLAction(actions=[0.1])) - >>> client.close() + >>> try: + ... result = client.reset() + ... result = client.step(FinRLAction(actions=[0.1])) + ... finally: + ... 
client.close() Example training loop: >>> import numpy as np >>> from envs.finrl_env import FinRLEnv, FinRLAction >>> - >>> client = FinRLEnv(base_url="http://localhost:8000") - >>> - >>> # Training loop - >>> for episode in range(10): - >>> result = client.reset() - >>> done = False - >>> episode_reward = 0 - >>> - >>> while not done: - >>> # Get state - >>> state = result.observation.state - >>> - >>> # Simple random policy (replace with your RL agent) - >>> num_stocks = len(state) // 7 # Simplified calculation - >>> actions = np.random.uniform(-1, 1, size=num_stocks).tolist() - >>> - >>> # Execute action - >>> result = client.step(FinRLAction(actions=actions)) - >>> - >>> episode_reward += result.reward or 0 - >>> done = result.done - >>> - >>> print(f"Episode {episode}: reward={episode_reward:.2f}, " - >>> f"final value={result.observation.portfolio_value:.2f}") - >>> - >>> client.close() + >>> with FinRLEnv(base_url="http://localhost:8000") as client: + ... # Training loop + ... for episode in range(10): + ... result = client.reset() + ... done = False + ... episode_reward = 0 + ... + ... while not done: + ... # Get state + ... state = result.observation.state + ... + ... # Simple random policy (replace with your RL agent) + ... num_stocks = len(state) // 7 # Simplified calculation + ... actions = np.random.uniform(-1, 1, size=num_stocks).tolist() + ... + ... # Execute action + ... result = client.step(FinRLAction(actions=actions)) + ... + ... episode_reward += result.reward or 0 + ... done = result.done + ... + ... print(f"Episode {episode}: reward={episode_reward:.2f}, " + ... f"final value={result.observation.portfolio_value:.2f}") """ - def get_config(self) -> Dict[str, Any]: - """ - Get the environment configuration from the server. 
- - Returns: - Dictionary containing environment configuration - """ - response = self.session.get(f"{self.base_url}/config") - response.raise_for_status() - return response.json() - def _step_payload(self, action: FinRLAction) -> Dict: """ Convert FinRLAction to JSON payload for step request. diff --git a/envs/finrl_env/server/app.py b/envs/finrl_env/server/app.py index 1e4a34ca9..f02f659c7 100644 --- a/envs/finrl_env/server/app.py +++ b/envs/finrl_env/server/app.py @@ -8,7 +8,7 @@ FastAPI application for the FinRL Environment. This module creates an HTTP server that exposes the FinRLEnvironment -over HTTP endpoints, making it compatible with HTTPEnvClient. +over HTTP and WebSocket endpoints, compatible with EnvClient. The server expects environment configuration to be provided either: 1. Through environment variables (FINRL_CONFIG_PATH) @@ -32,7 +32,7 @@ from pathlib import Path import pandas as pd -from openenv.core.env_server import create_fastapi_app +from openenv.core.env_server import create_app from ..models import FinRLAction, FinRLObservation from .finrl_environment import FinRLEnvironment @@ -116,11 +116,16 @@ def load_finrl_config(): # Load configuration finrl_env_class, finrl_config = load_finrl_config() -# Create the environment instance -env = FinRLEnvironment(finrl_env_class=finrl_env_class, finrl_env_config=finrl_config) + +# Factory function to create FinRLEnvironment instances +def create_finrl_environment(): + """Factory function that creates FinRLEnvironment with config.""" + return FinRLEnvironment(finrl_env_class=finrl_env_class, finrl_env_config=finrl_config) + # Create the FastAPI app with routes -app = create_fastapi_app(env, FinRLAction, FinRLObservation) +# Pass the factory function instead of an instance for WebSocket session support +app = create_app(create_finrl_environment, FinRLAction, FinRLObservation, env_name="finrl_env") @app.get("/config") diff --git a/envs/git_env/client.py b/envs/git_env/client.py index 28824a578..efbf6182d 
100644 --- a/envs/git_env/client.py +++ b/envs/git_env/client.py @@ -3,7 +3,9 @@ GitEnv Client ------------- Client-side wrapper for the Git environment server. -Talks HTTP to a single base_url exposing: /reset and /step. + +This client maintains a persistent WebSocket connection to the environment +server, enabling efficient multi-step interactions with lower latency. """ from __future__ import annotations @@ -11,7 +13,7 @@ from typing import TYPE_CHECKING from openenv.core.client_types import StepResult -from openenv.core.http_env_client import HTTPEnvClient +from openenv.core.env_client import EnvClient from .models import GitAction, GitObservation, GitState @@ -19,12 +21,12 @@ from openenv.core.containers.runtime import ContainerProvider -class GitEnv(HTTPEnvClient[GitAction, GitObservation]): +class GitEnv(EnvClient[GitAction, GitObservation, GitState]): """ Client for Git Environment with Gitea server. - This client communicates with the Git environment server over HTTP, - allowing agents to perform Git operations through a simple API. + This client maintains a persistent WebSocket connection to the environment + server, enabling efficient multi-step interactions for Git operations. The environment connects to a shared external Gitea service. Repositories must be pre-migrated to Gitea before use. @@ -32,25 +34,25 @@ class GitEnv(HTTPEnvClient[GitAction, GitObservation]): Example: >>> # From Docker image >>> client = GitEnv.from_docker_image("git-env:latest") - >>> result = client.reset() - >>> - >>> # List available repositories - >>> from envs.git_env import GitAction - >>> result = client.step(GitAction(action_type="list_repos")) - >>> print(result.observation.repos) - >>> - >>> # Clone repository to workspace - >>> result = client.step(GitAction(action_type="clone_repo", repo_name="OpenEnv")) - >>> - >>> # Execute git commands - >>> result = client.step(GitAction( - ... action_type="execute_git_command", - ... command="status", - ... 
working_dir="OpenEnv" - ... )) - >>> - >>> # Cleanup - >>> client.close() + >>> try: + ... result = client.reset() + ... + ... # List available repositories + ... from envs.git_env import GitAction + ... result = client.step(GitAction(action_type="list_repos")) + ... print(result.observation.repos) + ... + ... # Clone repository to workspace + ... result = client.step(GitAction(action_type="clone_repo", repo_name="OpenEnv")) + ... + ... # Execute git commands + ... result = client.step(GitAction( + ... action_type="execute_git_command", + ... command="status", + ... working_dir="OpenEnv" + ... )) + ... finally: + ... client.close() """ def _step_payload(self, action: GitAction) -> dict: diff --git a/envs/git_env/server/app.py b/envs/git_env/server/app.py index 3246c4af5..a73e22973 100644 --- a/envs/git_env/server/app.py +++ b/envs/git_env/server/app.py @@ -44,16 +44,21 @@ if not gitea_password: raise RuntimeError("GITEA_PASSWORD environment variable is required") -# Create the environment instance (connects to external Gitea) -env = GitTaskEnvironment( - gitea_url=gitea_url, - username=gitea_username, - password=gitea_password, - workspace_dir=workspace_dir, -) + +# Factory function to create GitTaskEnvironment instances +def create_git_environment(): + """Factory function that creates GitTaskEnvironment with config.""" + return GitTaskEnvironment( + gitea_url=gitea_url, + username=gitea_username, + password=gitea_password, + workspace_dir=workspace_dir, + ) + # Create the app with web interface and README integration -app = create_app(env, GitAction, GitObservation, env_name="git_env") +# Pass the factory function instead of an instance for WebSocket session support +app = create_app(create_git_environment, GitAction, GitObservation, env_name="git_env") if __name__ == "__main__": diff --git a/envs/openspiel_env/client.py b/envs/openspiel_env/client.py index cb80e8f68..946cd1fdd 100644 --- a/envs/openspiel_env/client.py +++ b/envs/openspiel_env/client.py @@ -5,10 
+5,10 @@ # LICENSE file in the root directory of this source tree. """ -OpenSpielEnv HTTP Client. +OpenSpielEnv Client. This module provides the client for connecting to an OpenSpiel Environment server -over HTTP. +via WebSocket for persistent sessions. """ from __future__ import annotations @@ -17,7 +17,7 @@ from openenv.core.client_types import StepResult -from openenv.core.http_env_client import HTTPEnvClient +from openenv.core.env_client import EnvClient from .models import OpenSpielAction, OpenSpielObservation, OpenSpielState @@ -25,28 +25,30 @@ from openenv.core.containers.runtime import ContainerProvider -class OpenSpielEnv(HTTPEnvClient[OpenSpielAction, OpenSpielObservation]): +class OpenSpielEnv(EnvClient[OpenSpielAction, OpenSpielObservation, OpenSpielState]): """ - HTTP client for OpenSpiel Environment. + Client for OpenSpiel Environment. - This client connects to an OpenSpielEnvironment HTTP server and provides - methods to interact with it: reset(), step(), and state access. + This client maintains a persistent WebSocket connection to the environment + server, enabling efficient multi-step interactions with lower latency. Example: >>> # Connect to a running server - >>> client = OpenSpielEnv(base_url="http://localhost:8000") - >>> result = client.reset() - >>> print(result.observation.info_state) - >>> - >>> # Take an action - >>> result = client.step(OpenSpielAction(action_id=1, game_name="catch")) - >>> print(result.observation.reward) + >>> with OpenSpielEnv(base_url="http://localhost:8000") as client: + ... result = client.reset() + ... print(result.observation.info_state) + ... + ... result = client.step(OpenSpielAction(action_id=1, game_name="catch")) + ... print(result.observation.reward) Example with Docker: >>> # Automatically start container and connect >>> client = OpenSpielEnv.from_docker_image("openspiel-env:latest") - >>> result = client.reset() - >>> result = client.step(OpenSpielAction(action_id=0)) + >>> try: + ... 
result = client.reset() + ... result = client.step(OpenSpielAction(action_id=0)) + ... finally: + ... client.close() """ def _step_payload(self, action: OpenSpielAction) -> Dict[str, Any]: diff --git a/envs/openspiel_env/server/app.py b/envs/openspiel_env/server/app.py index 11107fbd4..01dc35218 100644 --- a/envs/openspiel_env/server/app.py +++ b/envs/openspiel_env/server/app.py @@ -8,7 +8,7 @@ FastAPI application for the OpenSpiel Environment. This module creates an HTTP server that exposes OpenSpiel games -over HTTP endpoints, making them compatible with HTTPEnvClient. +over HTTP and WebSocket endpoints, compatible with EnvClient. Usage: # Development (with auto-reload): @@ -38,15 +38,20 @@ agent_player = int(os.getenv("OPENSPIEL_AGENT_PLAYER", "0")) opponent_policy = os.getenv("OPENSPIEL_OPPONENT_POLICY", "random") -# Create the environment instance -env = OpenSpielEnvironment( - game_name=game_name, - agent_player=agent_player, - opponent_policy=opponent_policy, -) + +# Factory function to create OpenSpielEnvironment instances +def create_openspiel_environment(): + """Factory function that creates OpenSpielEnvironment with config.""" + return OpenSpielEnvironment( + game_name=game_name, + agent_player=agent_player, + opponent_policy=opponent_policy, + ) + # Create the FastAPI app with web interface and README integration -app = create_app(env, OpenSpielAction, OpenSpielObservation, env_name="openspiel_env") +# Pass the factory function instead of an instance for WebSocket session support +app = create_app(create_openspiel_environment, OpenSpielAction, OpenSpielObservation, env_name="openspiel_env") if __name__ == "__main__": diff --git a/envs/sumo_rl_env/client.py b/envs/sumo_rl_env/client.py index 19fb5bd36..89390398d 100644 --- a/envs/sumo_rl_env/client.py +++ b/envs/sumo_rl_env/client.py @@ -5,47 +5,46 @@ # LICENSE file in the root directory of this source tree. """ -HTTP client for SUMO-RL environment. +Client for SUMO-RL environment. 
This module provides a client to interact with the SUMO traffic signal -control environment over HTTP. +control environment via WebSocket for persistent sessions. """ from typing import Any, Dict from openenv.core.client_types import StepResult -from openenv.core.http_env_client import HTTPEnvClient +from openenv.core.env_client import EnvClient from .models import SumoAction, SumoObservation, SumoState -class SumoRLEnv(HTTPEnvClient[SumoAction, SumoObservation]): +class SumoRLEnv(EnvClient[SumoAction, SumoObservation, SumoState]): """ - HTTP client for SUMO-RL traffic signal control environment. + Client for SUMO-RL traffic signal control environment. - This client communicates with a SUMO environment server to control - traffic signals using reinforcement learning. + This client maintains a persistent WebSocket connection to a SUMO + environment server to control traffic signals using reinforcement learning. Example: >>> # Start container and connect >>> env = SumoRLEnv.from_docker_image("sumo-rl-env:latest") - >>> - >>> # Reset environment - >>> result = env.reset() - >>> print(f"Observation shape: {result.observation.observation_shape}") - >>> print(f"Action space: {result.observation.action_mask}") - >>> - >>> # Take action - >>> result = env.step(SumoAction(phase_id=1)) - >>> print(f"Reward: {result.reward}, Done: {result.done}") - >>> - >>> # Get state - >>> state = env.state() - >>> print(f"Sim time: {state.sim_time}, Total vehicles: {state.total_vehicles}") - >>> - >>> # Cleanup - >>> env.close() + >>> try: + ... # Reset environment + ... result = env.reset() + ... print(f"Observation shape: {result.observation.observation_shape}") + ... print(f"Action space: {result.observation.action_mask}") + ... + ... # Take action + ... result = env.step(SumoAction(phase_id=1)) + ... print(f"Reward: {result.reward}, Done: {result.done}") + ... + ... # Get state + ... state = env.state() + ... 
print(f"Sim time: {state.sim_time}, Total vehicles: {state.total_vehicles}") + ... finally: + ... env.close() Example with custom network: >>> # Use custom SUMO network via volume mount diff --git a/envs/sumo_rl_env/server/app.py b/envs/sumo_rl_env/server/app.py index 3240902c2..b0f5ea7d3 100644 --- a/envs/sumo_rl_env/server/app.py +++ b/envs/sumo_rl_env/server/app.py @@ -13,7 +13,7 @@ import os -from openenv.core.env_server import create_fastapi_app +from openenv.core.env_server import create_app from ..models import SumoAction, SumoObservation from .sumo_environment import SumoEnvironment @@ -29,19 +29,23 @@ reward_fn = os.getenv("SUMO_REWARD_FN", "diff-waiting-time") sumo_seed = int(os.getenv("SUMO_SEED", "42")) -# Create single environment instance -# This is reused for all HTTP requests (avoids TraCI connection issues) -env = SumoEnvironment( - net_file=net_file, - route_file=route_file, - num_seconds=num_seconds, - delta_time=delta_time, - yellow_time=yellow_time, - min_green=min_green, - max_green=max_green, - reward_fn=reward_fn, - sumo_seed=sumo_seed, -) + +# Factory function to create SumoEnvironment instances +def create_sumo_environment(): + """Factory function that creates SumoEnvironment with config.""" + return SumoEnvironment( + net_file=net_file, + route_file=route_file, + num_seconds=num_seconds, + delta_time=delta_time, + yellow_time=yellow_time, + min_green=min_green, + max_green=max_green, + reward_fn=reward_fn, + sumo_seed=sumo_seed, + ) + # Create FastAPI app -app = create_fastapi_app(env, SumoAction, SumoObservation) +# Pass the factory function instead of an instance for WebSocket session support +app = create_app(create_sumo_environment, SumoAction, SumoObservation, env_name="sumo_rl_env") diff --git a/envs/textarena_env/client.py b/envs/textarena_env/client.py index 36f59716a..9c2b52a01 100644 --- a/envs/textarena_env/client.py +++ b/envs/textarena_env/client.py @@ -4,14 +4,14 @@ # This source code is licensed under the BSD-style license 
found in the # LICENSE file in the root directory of this source tree. -"""HTTP client for the generic TextArena environment.""" +"""Client for the generic TextArena environment.""" from __future__ import annotations from typing import Any, Dict, TYPE_CHECKING from openenv.core.client_types import StepResult -from openenv.core.http_env_client import HTTPEnvClient +from openenv.core.env_client import EnvClient from .models import ( TextArenaAction, @@ -24,8 +24,8 @@ from openenv.core.containers.runtime import ContainerProvider -class TextArenaEnv(HTTPEnvClient[TextArenaAction, TextArenaObservation]): - """HTTP client for the TextArena environment server.""" +class TextArenaEnv(EnvClient[TextArenaAction, TextArenaObservation, TextArenaState]): + """Client for the TextArena environment server.""" def _step_payload(self, action: TextArenaAction) -> Dict[str, Any]: return {"message": action.message} diff --git a/envs/textarena_env/server/app.py b/envs/textarena_env/server/app.py index 83d8d09ec..900a138c0 100644 --- a/envs/textarena_env/server/app.py +++ b/envs/textarena_env/server/app.py @@ -35,15 +35,22 @@ def _parse_env_kwargs(prefix: str = "TEXTARENA_KW_") -> dict[str, str]: extra_kwargs = _parse_env_kwargs() -environment = TextArenaEnvironment( - env_id=env_id, - num_players=num_players, - max_turns=max_turns, - download_nltk=download_nltk, - env_kwargs=extra_kwargs, -) - -app = create_app(environment, TextArenaAction, TextArenaObservation, env_name="textarena_env") + +# Factory function to create TextArenaEnvironment instances +def create_textarena_environment(): + """Factory function that creates TextArenaEnvironment with config.""" + return TextArenaEnvironment( + env_id=env_id, + num_players=num_players, + max_turns=max_turns, + download_nltk=download_nltk, + env_kwargs=extra_kwargs, + ) + + +# Create the FastAPI app +# Pass the factory function instead of an instance for WebSocket session support +app = create_app(create_textarena_environment, TextArenaAction, 
TextArenaObservation, env_name="textarena_env") if __name__ == "__main__": diff --git a/pyproject.toml b/pyproject.toml index 811c068c9..b7fa6794a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,8 @@ dependencies = [ "huggingface_hub>=0.20.0", "openai>=2.7.2", "tomli>=2.3.0", - "tomli-w>=1.2.0" + "tomli-w>=1.2.0", + "websockets>=15.0.1", ] [project.optional-dependencies] @@ -32,6 +33,7 @@ core = [ "pydantic>=2.0.0", "uvicorn>=0.24.0", "requests>=2.25.0", + "websockets>=15.0.1", ] cli = [ "typer>=0.9.0", diff --git a/src/openenv/cli/templates/openenv_env/README.md b/src/openenv/cli/templates/openenv_env/README.md index ef238dfb7..3f14526a0 100644 --- a/src/openenv/cli/templates/openenv_env/README.md +++ b/src/openenv/cli/templates/openenv_env/README.md @@ -114,6 +114,7 @@ The deployed space includes: - **Web Interface** at `/web` - Interactive UI for exploring the environment - **API Documentation** at `/docs` - Full OpenAPI/Swagger interface - **Health Check** at `/health` - Container health monitoring +- **WebSocket** at `/ws` - Persistent session endpoint for low-latency interactions ## Environment Details @@ -154,6 +155,61 @@ result = __ENV_NAME__env.step(__ENV_CLASS_NAME__Action(message="Hello!")) Note: When connecting to an existing server, `__ENV_NAME__env.close()` will NOT stop the server. 
+### Using the Context Manager + +The client supports context manager usage for automatic connection management: + +```python +from __ENV_NAME__ import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Env + +# Connect with context manager (auto-connects and closes) +with __ENV_CLASS_NAME__Env(base_url="http://localhost:8000") as env: + result = env.reset() + print(f"Reset: {result.observation.echoed_message}") + # Multiple steps with low latency + for msg in ["Hello", "World", "!"]: + result = env.step(__ENV_CLASS_NAME__Action(message=msg)) + print(f"Echoed: {result.observation.echoed_message}") +``` + +The client uses WebSocket connections for: +- **Lower latency**: No HTTP connection overhead per request +- **Persistent session**: Server maintains your environment state +- **Efficient for episodes**: Better for many sequential steps + +### Concurrent WebSocket Sessions + +The server supports multiple concurrent WebSocket connections. To enable this, +modify `server/app.py` to use factory mode: + +```python +# In server/app.py - use factory mode for concurrent sessions +app = create_app( + __ENV_CLASS_NAME__Environment, # Pass class, not instance + __ENV_CLASS_NAME__Action, + __ENV_CLASS_NAME__Observation, + max_concurrent_envs=4, # Allow 4 concurrent sessions +) +``` + +Then multiple clients can connect simultaneously: + +```python +from __ENV_NAME__ import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Env +from concurrent.futures import ThreadPoolExecutor + +def run_episode(client_id: int): + with __ENV_CLASS_NAME__Env(base_url="http://localhost:8000") as env: + result = env.reset() + for i in range(10): + result = env.step(__ENV_CLASS_NAME__Action(message=f"Client {client_id}, step {i}")) + return client_id, result.observation.message_length + +# Run 4 episodes concurrently +with ThreadPoolExecutor(max_workers=4) as executor: + results = list(executor.map(run_episode, range(4))) +``` + ## Development & Testing ### Direct Environment Testing @@ -189,11 +245,11 @@ 
__ENV_NAME__/ ├── openenv.yaml # OpenEnv manifest ├── pyproject.toml # Project metadata and dependencies ├── uv.lock # Locked dependencies (generated) -├── client.py # __ENV_CLASS_NAME__Env client implementation +├── client.py # __ENV_CLASS_NAME__Env client ├── models.py # Action and Observation models └── server/ ├── __init__.py # Server module exports ├── __ENV_NAME___environment.py # Core environment logic - ├── app.py # FastAPI application + ├── app.py # FastAPI application (HTTP + WebSocket endpoints) └── Dockerfile # Container image definition ``` diff --git a/src/openenv/cli/templates/openenv_env/__init__.py b/src/openenv/cli/templates/openenv_env/__init__.py index 656800a55..cbe07a082 100644 --- a/src/openenv/cli/templates/openenv_env/__init__.py +++ b/src/openenv/cli/templates/openenv_env/__init__.py @@ -4,10 +4,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -"""__ENV_TITLE_NAME__ Environment - A simple test environment for HTTP server.""" +"""__ENV_TITLE_NAME__ Environment.""" from .client import __ENV_CLASS_NAME__Env from .models import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation -__all__ = ["__ENV_CLASS_NAME__Action", "__ENV_CLASS_NAME__Observation", "__ENV_CLASS_NAME__Env"] - +__all__ = [ + "__ENV_CLASS_NAME__Action", + "__ENV_CLASS_NAME__Observation", + "__ENV_CLASS_NAME__Env", +] diff --git a/src/openenv/cli/templates/openenv_env/client.py b/src/openenv/cli/templates/openenv_env/client.py index 703b28a85..6be3eefd9 100644 --- a/src/openenv/cli/templates/openenv_env/client.py +++ b/src/openenv/cli/templates/openenv_env/client.py @@ -4,50 +4,47 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -""" -__ENV_TITLE_NAME__ Environment HTTP Client. 
+"""__ENV_TITLE_NAME__ Environment Client.""" -This module provides the client for connecting to a __ENV_TITLE_NAME__ Environment server -over HTTP. -""" - -from typing import Any, Dict +from typing import Dict from openenv.core.client_types import StepResult from openenv.core.env_server.types import State -from openenv.core.http_env_client import HTTPEnvClient +from openenv.core import EnvClient from .models import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation -class __ENV_CLASS_NAME__Env(HTTPEnvClient[__ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation]): +class __ENV_CLASS_NAME__Env(EnvClient[__ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation]): """ - HTTP client for the __ENV_TITLE_NAME__ Environment. + Client for the __ENV_TITLE_NAME__ Environment. - This client connects to a __ENV_CLASS_NAME__Environment HTTP server and provides - methods to interact with it: reset(), step(), and state access. + This client maintains a persistent WebSocket connection to the environment server, + enabling efficient multi-step interactions with lower latency. + Each client instance has its own dedicated environment session on the server. Example: >>> # Connect to a running server - >>> client = __ENV_CLASS_NAME__Env(base_url="http://localhost:8000") - >>> result = client.reset() - >>> print(result.observation.echoed_message) - >>> - >>> # Send a message - >>> result = client.step(__ENV_CLASS_NAME__Action(message="Hello!")) - >>> print(result.observation.echoed_message) - >>> print(result.reward) + >>> with __ENV_CLASS_NAME__Env(base_url="http://localhost:8000") as client: + ... result = client.reset() + ... print(result.observation.echoed_message) + ... + ... result = client.step(__ENV_CLASS_NAME__Action(message="Hello!")) + ... 
print(result.observation.echoed_message) Example with Docker: >>> # Automatically start container and connect >>> client = __ENV_CLASS_NAME__Env.from_docker_image("__ENV_NAME__-env:latest") - >>> result = client.reset() - >>> result = client.step(__ENV_CLASS_NAME__Action(message="Test")) + >>> try: + ... result = client.reset() + ... result = client.step(__ENV_CLASS_NAME__Action(message="Test")) + ... finally: + ... client.close() """ def _step_payload(self, action: __ENV_CLASS_NAME__Action) -> Dict: """ - Convert __ENV_CLASS_NAME__Action to JSON payload for step request. + Convert __ENV_CLASS_NAME__Action to JSON payload for step message. Args: action: __ENV_CLASS_NAME__Action instance @@ -64,7 +61,7 @@ def _parse_result(self, payload: Dict) -> StepResult[__ENV_CLASS_NAME__Observati Parse server response into StepResult[__ENV_CLASS_NAME__Observation]. Args: - payload: JSON response from server + payload: JSON response data from server Returns: StepResult with __ENV_CLASS_NAME__Observation @@ -89,7 +86,7 @@ def _parse_state(self, payload: Dict) -> State: Parse server response into State object. Args: - payload: JSON response from /state endpoint + payload: JSON response from state request Returns: State object with episode_id and step_count diff --git a/src/openenv/cli/templates/openenv_env/models.py b/src/openenv/cli/templates/openenv_env/models.py index 64010449b..4540d5a29 100644 --- a/src/openenv/cli/templates/openenv_env/models.py +++ b/src/openenv/cli/templates/openenv_env/models.py @@ -10,22 +10,20 @@ The __ENV_NAME__ environment is a simple test environment that echoes back messages. 
""" -from dataclasses import dataclass +from pydantic import Field from openenv.core.env_server.types import Action, Observation -@dataclass(kw_only=True) class __ENV_CLASS_NAME__Action(Action): """Action for the __ENV_TITLE_NAME__ environment - just a message to echo.""" - message: str + message: str = Field(..., description="Message to echo back") -@dataclass(kw_only=True) class __ENV_CLASS_NAME__Observation(Observation): """Observation from the __ENV_TITLE_NAME__ environment - the echoed message.""" - echoed_message: str - message_length: int = 0 + echoed_message: str = Field(default="", description="The echoed message") + message_length: int = Field(default=0, description="Length of the echoed message") diff --git a/src/openenv/cli/templates/openenv_env/pyproject.toml b/src/openenv/cli/templates/openenv_env/pyproject.toml index 55b90113f..4c6b948ff 100644 --- a/src/openenv/cli/templates/openenv_env/pyproject.toml +++ b/src/openenv/cli/templates/openenv_env/pyproject.toml @@ -15,6 +15,8 @@ description = "__ENV_TITLE_NAME__ environment for OpenEnv" requires-python = ">=3.10" dependencies = [ # Core OpenEnv runtime (provides FastAPI server + HTTP client types) + # install from github + # "openenv[core] @ git+https://github.com/meta-pytorch/OpenEnv.git", "openenv[core]>=0.2.0", # Environment-specific dependencies # Add all dependencies needed for your environment here diff --git a/src/openenv/cli/templates/openenv_env/server/__ENV_NAME___environment.py b/src/openenv/cli/templates/openenv_env/server/__ENV_NAME___environment.py index e2a9ce0b7..454ea6808 100644 --- a/src/openenv/cli/templates/openenv_env/server/__ENV_NAME___environment.py +++ b/src/openenv/cli/templates/openenv_env/server/__ENV_NAME___environment.py @@ -36,6 +36,12 @@ class __ENV_CLASS_NAME__Environment(Environment): >>> print(obs.message_length) # 5 """ + # Enable concurrent WebSocket sessions. + # Set to True if your environment isolates state between instances. 
+ # When True, multiple WebSocket clients can connect simultaneously, each + # getting their own environment instance (when using factory mode in app.py). + SUPPORTS_CONCURRENT_SESSIONS: bool = True + def __init__(self): """Initialize the __ENV_NAME__ environment.""" self._state = State(episode_id=str(uuid4()), step_count=0) diff --git a/src/openenv/cli/templates/openenv_env/server/app.py b/src/openenv/cli/templates/openenv_env/server/app.py index db216fb06..025920a1b 100644 --- a/src/openenv/cli/templates/openenv_env/server/app.py +++ b/src/openenv/cli/templates/openenv_env/server/app.py @@ -8,7 +8,14 @@ FastAPI application for the __ENV_TITLE_NAME__ Environment. This module creates an HTTP server that exposes the __ENV_CLASS_NAME__Environment -over HTTP endpoints, making it compatible with HTTPEnvClient. +over HTTP and WebSocket endpoints, compatible with EnvClient. + +Endpoints: + - POST /reset: Reset the environment + - POST /step: Execute an action + - GET /state: Get current environment state + - GET /schema: Get action/observation schemas + - WS /ws: WebSocket endpoint for persistent sessions Usage: # Development (with auto-reload): @@ -28,18 +35,18 @@ "openenv is required for the web interface. 
Install dependencies with '\n uv sync\n'" ) from e -from __ENV_NAME__.models import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation +# Import from local models.py (PYTHONPATH includes /app/env in Docker) +from models import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation from .__ENV_NAME___environment import __ENV_CLASS_NAME__Environment -# Create the environment instance -env = __ENV_CLASS_NAME__Environment() # Create the app with web interface and README integration app = create_app( - env, + __ENV_CLASS_NAME__Environment, __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation, env_name="__ENV_NAME__", + max_concurrent_envs=1, # increase this number to allow more concurrent WebSocket sessions ) diff --git a/src/openenv/core/README.md b/src/openenv/core/README.md index 2251e10a6..ebfa579aa 100644 --- a/src/openenv/core/README.md +++ b/src/openenv/core/README.md @@ -22,8 +22,8 @@ Core components for OpenEnv - a framework for building HTTP-based agentic enviro ## Features -- **HTTPEnvClient**: Generic HTTP client for interacting with remote environments -- **HTTPEnvServer**: FastAPI-based server wrapper for exposing environments over HTTP +- **EnvClient**: Generic client for interacting with remote environments +- **HTTPEnvServer**: FastAPI-based server wrapper for exposing environments over HTTP/WebSocket - **Container Providers**: Pluggable architecture for running containers (Docker, Kubernetes, etc.) 
- **Type System**: Strongly-typed Action/Observation/State interfaces - **Web Interface**: Optional web UI for interacting with environments @@ -44,7 +44,7 @@ pip install "openenv[core]" ### Creating an Environment Client ```python -from openenv.core import HTTPEnvClient, StepResult +from openenv.core import EnvClient, State, StepResult from dataclasses import dataclass @dataclass @@ -55,7 +55,7 @@ class MyAction: class MyObservation: response: str -class MyEnvClient(HTTPEnvClient[MyAction, MyObservation]): +class MyEnvClient(EnvClient[MyAction, MyObservation, State]): def _step_payload(self, action: MyAction) -> dict: return {"text": action.text} @@ -141,7 +141,7 @@ provider.stop_container() ## API Reference -### HTTPEnvClient +### EnvClient Base class for environment clients with these abstract methods: diff --git a/src/openenv/core/__init__.py b/src/openenv/core/__init__.py index 99507ab55..5a7af20db 100644 --- a/src/openenv/core/__init__.py +++ b/src/openenv/core/__init__.py @@ -7,13 +7,10 @@ """Core components for agentic environments.""" # Re-export main components from submodules for convenience -from .env_server import * -from .client_types import StepResult -from .http_env_client import HTTPEnvClient +from .env_server import * # noqa: F403 +from .
import env_server +from .env_client import EnvClient # Note: MCP module doesn't export anything yet -__all__ = [ - "HTTPEnvClient", - "StepResult", -] +__all__ = ["EnvClient"] + env_server.__all__ # type: ignore diff --git a/src/openenv/core/client_types.py b/src/openenv/core/client_types.py index 8808e96bf..c7501c656 100644 --- a/src/openenv/core/client_types.py +++ b/src/openenv/core/client_types.py @@ -1,9 +1,10 @@ # Type definitions for EnvTorch from dataclasses import dataclass -from typing import Any, Generic, Optional, TypeVar +from typing import Generic, Optional, TypeVar # Generic type for observations -ObsT = TypeVar("ObsT") # TypeVar for typehinting in IDEs +ObsT = TypeVar("ObsT") +StateT = TypeVar("StateT") @dataclass diff --git a/src/openenv/core/containers/runtime/providers.py b/src/openenv/core/containers/runtime/providers.py index a8022ddca..f6f2b0ca6 100644 --- a/src/openenv/core/containers/runtime/providers.py +++ b/src/openenv/core/containers/runtime/providers.py @@ -8,7 +8,7 @@ Container provider abstractions for running environment servers. This module provides a pluggable architecture for different container providers -(local Docker, Kubernetes, cloud providers, etc.) to be used with HTTPEnvClient. +(local Docker, Kubernetes, cloud providers, etc.) to be used with EnvClient. """ from __future__ import annotations diff --git a/src/openenv/core/env_client.py b/src/openenv/core/env_client.py new file mode 100644 index 000000000..356fe72c9 --- /dev/null +++ b/src/openenv/core/env_client.py @@ -0,0 +1,289 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Environment client for persistent sessions. 
+ +This module provides a WebSocket-based client that maintains a persistent connection +to an environment server, enabling efficient multi-step interactions without +the overhead of HTTP request/response cycles. +""" + +from __future__ import annotations + +import json +from abc import ABC, abstractmethod +from typing import Any, Dict, Generic, Optional, Type, TYPE_CHECKING, TypeVar + +from .client_types import StepResult, StateT +from .containers.runtime import LocalDockerProvider +from .utils import convert_to_ws_url + +if TYPE_CHECKING: + from .containers.runtime import ContainerProvider + from websockets.sync.client import ClientConnection + +from websockets.sync.client import connect as ws_connect + +ActT = TypeVar("ActT") +ObsT = TypeVar("ObsT") +EnvClientT = TypeVar("EnvClientT", bound="EnvClient") + + +class EnvClient(ABC, Generic[ActT, ObsT, StateT]): + """ + Environment client for persistent sessions. + + This client maintains a persistent WebSocket connection to an environment + server, enabling efficient multi-step interactions. Each client instance + corresponds to a dedicated environment session on the server. + + Features: + - Lower latency for sequential interactions + - Session state is maintained server-side + - Better suited for long-running episodes + + Example: + >>> from envs.coding_env.client import CodingEnv + >>> + >>> # Connect to a server + >>> with CodingEnv(base_url="ws://localhost:8000") as env: + ... result = env.reset(seed=42) + ... while not result.done: + ... action = agent.predict(result.observation) + ... result = env.step(action) + """ + + def __init__( + self, + base_url: str, + connect_timeout_s: float = 10.0, + message_timeout_s: float = 60.0, + provider: Optional["ContainerProvider"] = None, + ): + """ + Initialize environment client. + + Args: + base_url: Base URL of the environment server (http:// or ws://). + Will be converted to ws:// if http:// is provided. 
+ connect_timeout_s: Timeout for establishing WebSocket connection + message_timeout_s: Timeout for receiving responses to messages + provider: Optional container provider for lifecycle management + """ + # Convert HTTP URL to WebSocket URL + ws_url = convert_to_ws_url(base_url) + + self._ws_url = f"{ws_url}/ws" + self._connect_timeout = connect_timeout_s + self._message_timeout = message_timeout_s + self._provider = provider + self._ws: Optional[ClientConnection] = None + + def connect(self) -> "EnvClient": + """ + Establish WebSocket connection to the server. + + Returns: + self for method chaining + + Raises: + ConnectionError: If connection cannot be established + """ + if self._ws is not None: + return self + + try: + self._ws = ws_connect( + self._ws_url, + open_timeout=self._connect_timeout, + ) + except Exception as e: + raise ConnectionError(f"Failed to connect to {self._ws_url}: {e}") from e + + return self + + def disconnect(self) -> None: + """Close the WebSocket connection.""" + if self._ws is not None: + try: + # Send close message + self._send({"type": "close"}) + except Exception: + pass # Best effort + try: + self._ws.close() + except Exception: + pass + self._ws = None + + def _ensure_connected(self) -> None: + """Ensure WebSocket connection is established.""" + if self._ws is None: + self.connect() + + def _send(self, message: Dict[str, Any]) -> None: + """Send a message over the WebSocket.""" + self._ensure_connected() + assert self._ws is not None + self._ws.send(json.dumps(message)) + + def _receive(self) -> Dict[str, Any]: + """Receive and parse a message from the WebSocket.""" + assert self._ws is not None + raw = self._ws.recv(timeout=self._message_timeout) + return json.loads(raw) + + def _send_and_receive(self, message: Dict[str, Any]) -> Dict[str, Any]: + """Send a message and wait for response.""" + self._send(message) + response = self._receive() + + # Check for error response + if response.get("type") == "error": + error_data = 
response.get("data", {}) + raise RuntimeError( + f"Server error: {error_data.get('message', 'Unknown error')} " + f"(code: {error_data.get('code', 'UNKNOWN')})" + ) + + return response + + @classmethod + def from_docker_image( + cls: Type[EnvClientT], + image: str, + provider: Optional["ContainerProvider"] = None, + **kwargs: Any, + ) -> EnvClientT: + """ + Create an environment client by spinning up a Docker container. + + Args: + image: Docker image name to run (e.g., "coding-env:latest") + provider: Container provider to use (defaults to LocalDockerProvider) + **kwargs: Additional arguments to pass to provider.start_container() + + Returns: + Connected client instance + """ + if provider is None: + provider = LocalDockerProvider() + + # Start container + base_url = provider.start_container(image, **kwargs) + + # Wait for server to be ready + provider.wait_for_ready(base_url) + + # Create and connect client + client = cls(base_url=base_url, provider=provider) + client.connect() + + return client + + @classmethod + def from_hub( + cls: Type[EnvClientT], + repo_id: str, + provider: Optional["ContainerProvider"] = None, + **kwargs: Any, + ) -> EnvClientT: + """ + Create a client from a Hugging Face Space's Docker image (registry.hf.space).
+ """ + if provider is None: + provider = LocalDockerProvider() + + tag = kwargs.pop("tag", "latest") + base_url = f"registry.hf.space/{repo_id.replace('/', '-')}:{tag}" + + return cls.from_docker_image(image=base_url, provider=provider, **kwargs) + + @abstractmethod + def _step_payload(self, action: ActT) -> Dict[str, Any]: + """Convert an Action object to the JSON data expected by the env server.""" + raise NotImplementedError + + @abstractmethod + def _parse_result(self, payload: Dict[str, Any]) -> StepResult[ObsT]: + """Convert a JSON response from the env server to StepResult[ObsT].""" + raise NotImplementedError + + @abstractmethod + def _parse_state(self, payload: Dict[str, Any]) -> StateT: + """Convert a JSON response from the state endpoint to a State object.""" + raise NotImplementedError + + def reset(self, **kwargs: Any) -> StepResult[ObsT]: + """ + Reset the environment with optional parameters. + + Args: + **kwargs: Optional parameters passed to the environment's reset method. + Common parameters include: + - seed: Random seed for reproducibility + - episode_id: Custom episode identifier + + Returns: + StepResult containing initial observation + """ + message = { + "type": "reset", + "data": kwargs, + } + response = self._send_and_receive(message) + return self._parse_result(response.get("data", {})) + + def step(self, action: ActT, **kwargs: Any) -> StepResult[ObsT]: + """ + Execute an action in the environment. + + Args: + action: The action to execute + **kwargs: Optional parameters (currently ignored) + + Returns: + StepResult containing observation, reward, and done status + """ + message = { + "type": "step", + "data": self._step_payload(action), + } + response = self._send_and_receive(message) + return self._parse_result(response.get("data", {})) + + def state(self) -> StateT: + """ + Get the current environment state from the server. 
+ + Returns: + State object with environment state information + """ + message = {"type": "state"} + response = self._send_and_receive(message) + return self._parse_state(response.get("data", {})) + + def close(self) -> None: + """ + Close the WebSocket connection and clean up resources. + + If this client was created via from_docker_image(), this will also + stop and remove the associated container. + """ + self.disconnect() + + if self._provider is not None: + self._provider.stop_container() + + def __enter__(self) -> "EnvClient": + """Enter context manager, ensuring connection is established.""" + self.connect() + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """Exit context manager, closing connection.""" + self.close() diff --git a/src/openenv/core/env_server/__init__.py b/src/openenv/core/env_server/__init__.py index 4e1c2d7ac..ed0d41278 100644 --- a/src/openenv/core/env_server/__init__.py +++ b/src/openenv/core/env_server/__init__.py @@ -15,7 +15,33 @@ deserialize_action_with_preprocessing, serialize_observation, ) -from .types import Action, Observation, State, SchemaResponse, HealthResponse +from .types import ( + Action, + Observation, + State, + SchemaResponse, + HealthResponse, + BaseMessage, + WSIncomingMessage, + WSResetMessage, + WSStepMessage, + WSStateMessage, + WSCloseMessage, + WSObservationResponse, + WSStateResponse, + WSErrorResponse, + ConcurrencyConfig, + ServerCapacityStatus, + SessionInfo, +) +from .exceptions import ( + OpenEnvError, + ConcurrencyConfigurationError, + SessionCapacityError, + SessionNotFoundError, + SessionCreationError, + EnvironmentFactoryError, +) from .web_interface import create_web_interface_app, WebInterfaceManager __all__ = [ @@ -30,6 +56,27 @@ "State", "SchemaResponse", "HealthResponse", + # WebSocket message types + "BaseMessage", + "WSIncomingMessage", + "WSResetMessage", + "WSStepMessage", + "WSStateMessage", + "WSCloseMessage", + "WSObservationResponse", + "WSStateResponse", + 
"WSErrorResponse", + # Concurrency types + "ConcurrencyConfig", + "ServerCapacityStatus", + "SessionInfo", + # Exceptions + "OpenEnvError", + "ConcurrencyConfigurationError", + "SessionCapacityError", + "SessionNotFoundError", + "SessionCreationError", + "EnvironmentFactoryError", # Base transforms "CompositeTransform", "NullTransform", diff --git a/src/openenv/core/env_server/exceptions.py b/src/openenv/core/env_server/exceptions.py new file mode 100644 index 000000000..4fb4a6ec8 --- /dev/null +++ b/src/openenv/core/env_server/exceptions.py @@ -0,0 +1,105 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Custom exceptions for environment server operations.""" + +from typing import Optional + + +class OpenEnvError(Exception): + """Base exception for all OpenEnv errors.""" + + pass + + +class ConcurrencyConfigurationError(OpenEnvError): + """ + Raised when an environment is misconfigured for concurrent sessions. + + This error is raised during server startup when max_concurrent_envs > 1 + is specified for an environment that is not marked as SUPPORTS_CONCURRENT_SESSIONS. + """ + + def __init__( + self, + environment_name: str, + max_concurrent_envs: int, + message: Optional[str] = None, + ): + self.environment_name = environment_name + self.max_concurrent_envs = max_concurrent_envs + + if message is None: + message = ( + f"Environment '{environment_name}' is not marked as SUPPORTS_CONCURRENT_SESSIONS. " + f"Cannot run with max_concurrent_envs={max_concurrent_envs}. " + f"Either set max_concurrent_envs=1 or ensure the environment " + f"properly isolates session state and set SUPPORTS_CONCURRENT_SESSIONS=True." + ) + + super().__init__(message) + + +class SessionCapacityError(OpenEnvError): + """ + Raised when the server cannot accept new sessions due to capacity limits. 
+ + This error is raised when a new WebSocket connection is attempted but + the server has already reached max_concurrent_envs active sessions. + """ + + def __init__( + self, + active_sessions: int, + max_sessions: int, + message: Optional[str] = None, + ): + self.active_sessions = active_sessions + self.max_sessions = max_sessions + + if message is None: + message = ( + f"Server at capacity: {active_sessions}/{max_sessions} sessions active. " + f"Cannot accept new connections." + ) + + super().__init__(message) + + +class SessionNotFoundError(OpenEnvError): + """Raised when attempting to access a session that does not exist.""" + + def __init__(self, session_id: str, message: Optional[str] = None): + self.session_id = session_id + + if message is None: + message = f"Session '{session_id}' not found." + + super().__init__(message) + + +class SessionCreationError(OpenEnvError): + """Raised when a session cannot be created.""" + + def __init__(self, reason: str, message: Optional[str] = None): + self.reason = reason + + if message is None: + message = f"Failed to create session: {reason}" + + super().__init__(message) + + +class EnvironmentFactoryError(OpenEnvError): + """Raised when the environment factory fails to create an instance.""" + + def __init__(self, factory_name: str, message: Optional[str] = None): + self.factory_name = factory_name + + if message is None: + message = f"Environment factory '{factory_name}' failed to create instance." + + super().__init__(message) diff --git a/src/openenv/core/env_server/http_server.py b/src/openenv/core/env_server/http_server.py index 7fa7c0f32..ad3f8b365 100644 --- a/src/openenv/core/env_server/http_server.py +++ b/src/openenv/core/env_server/http_server.py @@ -8,18 +8,21 @@ HTTP server wrapper for Environment instances. This module provides utilities to wrap any Environment subclass and expose it -over HTTP endpoints that HTTPEnvClient can consume. +over HTTP and WebSocket endpoints that EnvClient can consume. 
""" from __future__ import annotations import asyncio import inspect +import json import os +import time +import uuid from concurrent.futures import ThreadPoolExecutor -from typing import Optional, Type +from typing import Any, Callable, Dict, Optional, Type, Union -from fastapi import Body, FastAPI, HTTPException, status +from fastapi import Body, FastAPI, HTTPException, WebSocket, WebSocketDisconnect, status from pydantic import ValidationError from .interfaces import Environment @@ -39,6 +42,21 @@ EnvironmentMetadata, SchemaResponse, HealthResponse, + WSResetMessage, + WSStepMessage, + WSStateMessage, + WSCloseMessage, + WSObservationResponse, + WSStateResponse, + WSErrorResponse, + ConcurrencyConfig, + ServerCapacityStatus, + SessionInfo, +) +from .exceptions import ( + ConcurrencyConfigurationError, + SessionCapacityError, + EnvironmentFactoryError, ) @@ -47,7 +65,7 @@ class HTTPEnvServer: HTTP server wrapper for Environment instances. This class wraps an Environment and exposes its reset(), step(), and state - methods as HTTP endpoints compatible with HTTPEnvClient. + methods as HTTP and WebSocket endpoints compatible with EnvClient. The server expects: - Action deserialization: Converts JSON dict to Action subclass @@ -56,9 +74,15 @@ class HTTPEnvServer: Example: >>> from core.env_server import HTTPEnvServer >>> from envs.coding_env.server import CodeExecutionEnvironment + >>> from envs.coding_env.models import CodeAction, CodeObservation >>> - >>> env = CodeExecutionEnvironment() - >>> server = HTTPEnvServer(env) + >>> # Pass environment class (factory pattern) + >>> server = HTTPEnvServer( + ... env=CodeExecutionEnvironment, + ... action_cls=CodeAction, + ... observation_cls=CodeObservation, + ... max_concurrent_envs=4, + ... 
) >>> >>> # Register routes with FastAPI >>> from fastapi import FastAPI @@ -68,31 +92,128 @@ class HTTPEnvServer: def __init__( self, - env: Environment, + env: Callable[[], Environment], action_cls: Type[Action], observation_cls: Type[Observation], + max_concurrent_envs: Optional[int] = None, + concurrency_config: Optional[ConcurrencyConfig] = None, ): """ Initialize HTTP server wrapper. Args: - env: The Environment instance to wrap + env: Environment factory (callable) that creates new instances. + Will be called to create a new environment for each WebSocket session. action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns + max_concurrent_envs: Maximum number of concurrent WebSocket sessions. + Mutually exclusive with concurrency_config. + concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. + Mutually exclusive with max_concurrent_envs. + + Raises: + ValueError: If both max_concurrent_envs and concurrency_config are provided. + ConcurrencyConfigurationError: If max_concurrent_envs > 1 for an + environment that is not marked as SUPPORTS_CONCURRENT_SESSIONS. """ - self.env = env + # Validate that env is callable + if not callable(env): + raise TypeError( + f"env must be a callable (class or factory function), got {type(env)}. " + f"Pass the environment class (e.g., MyEnvironment) not an instance (e.g., MyEnvironment())." + ) + + self._env_factory: Callable[[], Environment] = env + + # Handle concurrency configuration + if max_concurrent_envs is not None and concurrency_config is not None: + raise ValueError( + "Cannot specify both 'max_concurrent_envs' and 'concurrency_config'. " + "Please use only one method to configure concurrency." 
+ ) + + if concurrency_config is not None: + self._concurrency_config = concurrency_config + elif max_concurrent_envs is not None: + self._concurrency_config = ConcurrencyConfig( + max_concurrent_envs=max_concurrent_envs, + session_timeout=None, + ) + else: + # Default configuration + self._concurrency_config = ConcurrencyConfig( + max_concurrent_envs=1, + session_timeout=None, + ) + + self._max_concurrent_envs = self._concurrency_config.max_concurrent_envs + + # Validate concurrency configuration + self._validate_concurrency_safety() + self.action_cls = action_cls self.observation_cls = observation_cls + + # Session management for WebSocket connections + self._sessions: Dict[str, Environment] = {} + self._session_executors: Dict[str, ThreadPoolExecutor] = {} + self._session_info: Dict[str, SessionInfo] = {} + self._session_lock = asyncio.Lock() + # Create thread pool for running sync code in async context - # This is needed for environments using sync libraries (e.g., Playwright sync API) - self._executor = ThreadPoolExecutor(max_workers=1) + # This is needed for environments using sync libraries (e.g., Playwright) + self._executor = ThreadPoolExecutor(max_workers=32) + + def _validate_concurrency_safety(self) -> None: + """ + Validate that the environment supports the configured concurrency level. + + Raises: + ConcurrencyConfigurationError: If max_concurrent_envs > 1 for an + environment that is not marked as SUPPORTS_CONCURRENT_SESSIONS. 
+ """ + if self._max_concurrent_envs <= 1: + return + + if inspect.isclass(self._env_factory): + env_cls = self._env_factory + else: + _temp_env = self._env_factory() + env_cls = type(_temp_env) + _temp_env.close() + del _temp_env + + if not getattr(env_cls, "SUPPORTS_CONCURRENT_SESSIONS", False): + raise ConcurrencyConfigurationError( + environment_name=env_cls.__name__, + max_concurrent_envs=self._max_concurrent_envs, + ) + + def get_capacity_status(self) -> ServerCapacityStatus: + """ + Get the current capacity status of the server. + + Returns: + ServerCapacityStatus with current session counts and availability. + """ + return ServerCapacityStatus.from_counts( + active=len(self._sessions), + max_sessions=self._max_concurrent_envs, + ) - async def _run_sync_in_thread_pool(self, func, *args, **kwargs): + async def _run_sync_in_thread_pool( + self, func: Callable[..., Observation], *args, **kwargs + ) -> Observation: """Run a synchronous function in the thread pool executor.""" loop = asyncio.get_event_loop() return await loop.run_in_executor(self._executor, lambda: func(*args, **kwargs)) - def _get_valid_kwargs(self, sig, kwargs, skip_params=None): + def _get_valid_kwargs( + self, + sig: inspect.Signature, + kwargs: Dict[str, Any], + skip_params: Optional[set[str]] = None, + ) -> Dict[str, Any]: """Filter kwargs to only include parameters accepted by the function signature.""" if skip_params is None: skip_params = set() @@ -110,6 +231,129 @@ def _get_valid_kwargs(self, sig, kwargs, skip_params=None): return valid_kwargs + async def _create_session(self) -> tuple[str, Environment]: + """ + Create a new WebSocket session with its own environment instance. 
+ + Returns: + Tuple of (session_id, environment) + + Raises: + SessionCapacityError: If max concurrent sessions reached + EnvironmentFactoryError: If the factory fails to create an environment + """ + async with self._session_lock: + if len(self._sessions) >= self._max_concurrent_envs: + raise SessionCapacityError( + active_sessions=len(self._sessions), + max_sessions=self._max_concurrent_envs, + ) + + session_id = str(uuid.uuid4()) + current_time = time.time() + + try: + env = self._env_factory() + except Exception as e: + factory_name = getattr(self._env_factory, "__name__", str(self._env_factory)) + raise EnvironmentFactoryError(factory_name) from e + + self._sessions[session_id] = env + + self._session_executors[session_id] = ThreadPoolExecutor(max_workers=1) + + # Track session metadata + self._session_info[session_id] = SessionInfo( + session_id=session_id, + created_at=current_time, + last_activity_at=current_time, + step_count=0, + environment_type=type(env).__name__, + ) + + return session_id, env + + async def _destroy_session(self, session_id: str) -> None: + """ + Destroy a WebSocket session and cleanup resources. + + Args: + session_id: The session ID to destroy + """ + async with self._session_lock: + if session_id in self._sessions: + env = self._sessions.pop(session_id) + env.close() + + if session_id in self._session_executors: + executor = self._session_executors.pop(session_id) + executor.shutdown(wait=False) + + self._session_info.pop(session_id, None) + + def _update_session_activity( + self, session_id: str, increment_step: bool = False + ) -> None: + """ + Update session activity timestamp and optionally increment step count. 
+ + Args: + session_id: The session ID to update + increment_step: If True, increment the step count + """ + if session_id in self._session_info: + self._session_info[session_id].last_activity_at = time.time() + if increment_step: + self._session_info[session_id].step_count += 1 + + def get_session_info(self, session_id: str) -> Optional[SessionInfo]: + """ + Get information about a specific session. + + Args: + session_id: The session ID to query + + Returns: + SessionInfo if the session exists, None otherwise + """ + return self._session_info.get(session_id) + + async def _run_in_session_executor( + self, session_id: str, func: Callable[..., Observation], *args, **kwargs + ) -> Observation: + """Run a synchronous function in the session's thread pool executor.""" + executor = self._session_executors.get(session_id, self._executor) + loop = asyncio.get_event_loop() + return await loop.run_in_executor(executor, lambda: func(*args, **kwargs)) + + @property + def active_sessions(self) -> int: + """Return the number of active WebSocket sessions.""" + return len(self._sessions) + + @property + def max_concurrent_envs(self) -> int: + """Return the maximum number of concurrent environments.""" + return self._max_concurrent_envs + + @property + def is_concurrency_safe(self) -> bool: + """Return whether the environment is marked as concurrency safe.""" + import inspect + if inspect.isclass(self._env_factory): + return getattr(self._env_factory, "SUPPORTS_CONCURRENT_SESSIONS", False) + else: + _temp_env = self._env_factory() + result = getattr(_temp_env, "SUPPORTS_CONCURRENT_SESSIONS", False) + _temp_env.close() + del _temp_env + return result + + @property + def concurrency_config(self) -> ConcurrencyConfig: + """Return the concurrency configuration.""" + return self._concurrency_config + def register_routes(self, app: FastAPI) -> None: """ Register HTTP routes on a FastAPI application. 
@@ -123,49 +367,64 @@ async def reset_handler( request: ResetRequest = Body(default_factory=ResetRequest), ) -> ResetResponse: """Reset endpoint - returns initial observation.""" - # Handle optional parameters - # Start with all fields from the request, including extra ones - kwargs = request.model_dump(exclude_unset=True) - - # Pass arguments only if environment accepts them - sig = inspect.signature(self.env.reset) - valid_kwargs = self._get_valid_kwargs(sig, kwargs) - - # Run synchronous reset in thread pool to avoid blocking event loop - observation = await self._run_sync_in_thread_pool( - self.env.reset, **valid_kwargs - ) - return ResetResponse(**serialize_observation(observation)) + _env = self._env_factory() + + try: + kwargs = request.model_dump(exclude_unset=True) + + is_async = _env.reset_async.__func__ is not Environment.reset_async + + if is_async: + sig = inspect.signature(_env.reset_async) + else: + sig = inspect.signature(_env.reset) + valid_kwargs = self._get_valid_kwargs(sig, kwargs) + + if is_async: + observation = await _env.reset_async(**valid_kwargs) + else: + observation = await self._run_sync_in_thread_pool( + _env.reset, **valid_kwargs + ) + return ResetResponse(**serialize_observation(observation)) + finally: + _env.close() # Helper function to handle step endpoint async def step_handler(request: StepRequest) -> StepResponse: """Step endpoint - executes action and returns observation.""" action_data = request.action - # Deserialize action with Pydantic validation try: action = deserialize_action(action_data, self.action_cls) except ValidationError as e: - # Return HTTP 422 with detailed validation errors raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, detail=e.errors() ) - # Handle optional parameters - # Start with all fields from the request, including extra ones, but exclude 'action' - kwargs = request.model_dump(exclude_unset=True, exclude={"action"}) + _env = self._env_factory() + + try: + kwargs = 
request.model_dump(exclude_unset=True, exclude={"action"}) - # Pass arguments only if environment accepts them - sig = inspect.signature(self.env.step) - valid_kwargs = self._get_valid_kwargs(sig, kwargs, skip_params={"action"}) + is_async = _env.step_async.__func__ is not Environment.step_async - # Run synchronous step in thread pool to avoid blocking event loop - observation = await self._run_sync_in_thread_pool( - self.env.step, action, **valid_kwargs - ) + if is_async: + sig = inspect.signature(_env.step_async) + else: + sig = inspect.signature(_env.step) + valid_kwargs = self._get_valid_kwargs(sig, kwargs, skip_params={"action"}) - # Return serialized observation - return StepResponse(**serialize_observation(observation)) + if is_async: + observation = await _env.step_async(action, **valid_kwargs) + else: + observation = await self._run_sync_in_thread_pool( + _env.step, action, **valid_kwargs + ) + + return StepResponse(**serialize_observation(observation)) + finally: + _env.close() # Register routes using the helpers @app.post( @@ -251,24 +510,36 @@ async def reset( async def step(request: StepRequest) -> StepResponse: return await step_handler(request) - # Configure and register GET endpoints declaratively + def get_state_handler() -> State: + _env = self._env_factory() + try: + return _env.state + finally: + _env.close() + + def get_metadata_handler() -> EnvironmentMetadata: + _env = self._env_factory() + try: + return _env.get_metadata() + finally: + _env.close() + get_endpoints = [ GetEndpointConfig( path="/state", - handler=lambda: self.env.state, + handler=get_state_handler, response_model=State, tag="State Management", summary="Get current environment state", description=""" Retrieve the current internal state of the environment. -This endpoint allows inspection of the environment state without modifying it. The structure of the state object is defined by the environment's State model. 
""", ), GetEndpointConfig( path="/metadata", - handler=self.env.get_metadata, + handler=get_metadata_handler, response_model=EnvironmentMetadata, tag="Environment Info", summary="Get environment metadata", @@ -290,6 +561,7 @@ async def step(request: StepRequest) -> StepResponse: ] register_get_endpoints(app, get_endpoints) + # Register combined schema endpoint @app.get( "/schema", @@ -339,12 +611,171 @@ async def get_schemas() -> SchemaResponse: state=State.model_json_schema(), ) + # Register WebSocket endpoint for persistent sessions + @app.websocket("/ws") + async def websocket_endpoint(websocket: WebSocket): + """ + WebSocket endpoint for persistent environment sessions. + + Each WebSocket connection gets its own environment instance. + + Message Protocol: + - Client sends: WSResetMessage | WSStepMessage | WSStateMessage | WSCloseMessage + - Server responds: WSObservationResponse | WSStateResponse | WSErrorResponse + """ + await websocket.accept() + + session_id = None + session_env = None + + try: + # Create session with dedicated environment + session_id, session_env = await self._create_session() + + while True: + # Receive message from client + raw_message = await websocket.receive_text() + + try: + message_dict = json.loads(raw_message) + except json.JSONDecodeError as e: + error_resp = WSErrorResponse( + data={ + "message": f"Invalid JSON: {e}", + "code": "INVALID_JSON", + } + ) + await websocket.send_text(error_resp.model_dump_json()) + continue + + msg_type = message_dict.get("type", "") + + try: + match msg_type: + case "reset": + msg = WSResetMessage(**message_dict) + + is_async = session_env.reset_async.__func__ is not Environment.reset_async + + if is_async: + sig = inspect.signature(session_env.reset_async) + valid_kwargs = self._get_valid_kwargs(sig, msg.data) + observation = await session_env.reset_async(**valid_kwargs) + else: + sig = inspect.signature(session_env.reset) + valid_kwargs = self._get_valid_kwargs(sig, msg.data) + observation = await 
self._run_in_session_executor( + session_id, session_env.reset, **valid_kwargs + ) + + self._update_session_activity(session_id) + + response = WSObservationResponse( + data=serialize_observation(observation) + ) + + case "step": + msg = WSStepMessage(**message_dict) + action = deserialize_action(msg.data, self.action_cls) + + is_async = session_env.step_async.__func__ is not Environment.step_async + + if is_async: + observation = await session_env.step_async(action) + else: + observation = await self._run_in_session_executor( + session_id, session_env.step, action + ) + + self._update_session_activity( + session_id, increment_step=True + ) + + response = WSObservationResponse( + data=serialize_observation(observation) + ) + + case "state": + msg = WSStateMessage(**message_dict) + state = session_env.state + if hasattr(state, "model_dump"): + state_data = state.model_dump() + else: + state_data = dict(state) if state else {} + + response = WSStateResponse(data=state_data) + + case "close": + msg = WSCloseMessage(**message_dict) + break + + case _: + response = WSErrorResponse( + data={ + "message": f"Unknown message type: {msg_type}", + "code": "UNKNOWN_TYPE", + } + ) + + await websocket.send_text(response.model_dump_json()) + + except ValidationError as e: + error_resp = WSErrorResponse( + data={ + "message": "Invalid message", + "code": "VALIDATION_ERROR", + "errors": e.errors(), + } + ) + await websocket.send_text(error_resp.model_dump_json()) + except Exception as e: + error_resp = WSErrorResponse( + data={"message": str(e), "code": "EXECUTION_ERROR"} + ) + await websocket.send_text(error_resp.model_dump_json()) + + except WebSocketDisconnect: + pass + except SessionCapacityError as e: + error_resp = WSErrorResponse( + data={ + "message": str(e), + "code": "CAPACITY_REACHED", + "active_sessions": e.active_sessions, + "max_sessions": e.max_sessions, + } + ) + await websocket.send_text(error_resp.model_dump_json()) + except EnvironmentFactoryError as e: + 
error_resp = WSErrorResponse( + data={ + "message": str(e), + "code": "FACTORY_ERROR", + "factory_name": e.factory_name, + } + ) + await websocket.send_text(error_resp.model_dump_json()) + except Exception as e: + error_resp = WSErrorResponse( + data={"message": str(e), "code": "SESSION_ERROR"} + ) + await websocket.send_text(error_resp.model_dump_json()) + finally: + if session_id: + await self._destroy_session(session_id) + try: + await websocket.close() + except RuntimeError: + pass + def create_app( - env: Environment, + env: Callable[[], Environment], action_cls: Type[Action], observation_cls: Type[Observation], env_name: Optional[str] = None, + max_concurrent_envs: Optional[int] = None, + concurrency_config: Optional[ConcurrencyConfig] = None, ) -> FastAPI: """ Create a FastAPI application with or without web interface. @@ -353,10 +784,14 @@ def create_app( including README integration for better user experience. Args: - env: The Environment instance to serve + env: Environment factory (callable) that creates new instances action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns env_name: Optional environment name for README loading + max_concurrent_envs: Maximum concurrent WebSocket sessions. + Mutually exclusive with concurrency_config. + concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. + Mutually exclusive with max_concurrent_envs. 
Returns: FastAPI application instance with or without web interface and README integration @@ -373,18 +808,43 @@ def create_app( # Import web interface only when needed from .web_interface import create_web_interface_app - return create_web_interface_app(env, action_cls, observation_cls, env_name) + return create_web_interface_app( + env, + action_cls, + observation_cls, + env_name, + max_concurrent_envs, + concurrency_config, + ) else: # Use standard FastAPI app without web interface - return create_fastapi_app(env, action_cls, observation_cls) + return create_fastapi_app( + env, action_cls, observation_cls, max_concurrent_envs, concurrency_config + ) def create_fastapi_app( - env: Environment, + env: Callable[[], Environment], action_cls: Type[Action], observation_cls: Type[Observation], + max_concurrent_envs: Optional[int] = None, + concurrency_config: Optional[ConcurrencyConfig] = None, ) -> FastAPI: - """Create a FastAPI application with comprehensive documentation.""" + """ + Create a FastAPI application with comprehensive documentation. + + Args: + env: Environment factory (callable) that creates new instances + action_cls: The Action subclass this environment expects + observation_cls: The Observation subclass this environment returns + max_concurrent_envs: Maximum concurrent WebSocket sessions. + Mutually exclusive with concurrency_config. + concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. + Mutually exclusive with max_concurrent_envs. 
+ + Returns: + FastAPI application instance + """ try: from fastapi import FastAPI except ImportError: @@ -452,6 +912,12 @@ def create_fastapi_app( }, ) - server = HTTPEnvServer(env, action_cls, observation_cls) + server = HTTPEnvServer( + env, + action_cls, + observation_cls, + max_concurrent_envs, + concurrency_config=concurrency_config, + ) server.register_routes(app) return app diff --git a/src/openenv/core/env_server/interfaces.py b/src/openenv/core/env_server/interfaces.py index b438cd667..c02ba4a05 100644 --- a/src/openenv/core/env_server/interfaces.py +++ b/src/openenv/core/env_server/interfaces.py @@ -5,10 +5,14 @@ # LICENSE file in the root directory of this source tree. from abc import ABC, abstractmethod -from typing import Any, Optional, Protocol, TypedDict +from typing import Any, Generic, Optional, Protocol, TypedDict, TypeVar from .types import Action, Observation, State, EnvironmentMetadata +ActT = TypeVar("ActT", bound=Action) +ObsT = TypeVar("ObsT", bound=Observation) +StateT = TypeVar("StateT", bound=State) + class Message(TypedDict): """A message in a conversation. @@ -64,7 +68,7 @@ def decode( ... -class Transform(ABC): +class Transform(ABC, Generic[ObsT]): """Transform observations to add rewards, metrics, or other modifications. Transforms follow the TorchRL pattern where they take an observation @@ -73,7 +77,7 @@ class Transform(ABC): """ @abstractmethod - def __call__(self, observation: Observation) -> Observation: + def __call__(self, observation: ObsT) -> ObsT: """Transform an observation. Args: @@ -85,14 +89,28 @@ def __call__(self, observation: Observation) -> Observation: pass -class Environment(ABC): +class Environment(ABC, Generic[ActT, ObsT, StateT]): """Base class for all environment servers following Gym/Gymnasium API. Args: transform: Optional transform to apply to observations + + Class Attributes: + SUPPORTS_CONCURRENT_SESSIONS: Whether this environment supports concurrent sessions. 
+ When True, multiple WebSocket connections can each have their own + environment instance (up to max_concurrent_envs). When False (default), + the environment should only be used with a single session at a time. + + Set this to True in your Environment subclass if: + - The environment uses proper session isolation (e.g., unique working dirs) + - No shared mutable state exists between instances + - External resources (databases, APIs) can handle concurrent access """ + + # Class-level flag indicating whether this environment supports concurrent sessions + SUPPORTS_CONCURRENT_SESSIONS: bool = False - def __init__(self, transform: Transform | None = None): + def __init__(self, transform: Optional[Transform[ObsT]] = None): self.transform = transform @abstractmethod @@ -101,23 +119,47 @@ def reset( seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs: Any, - ) -> Observation: + ) -> ObsT: """Reset the environment and return initial observation.""" pass + async def reset_async( + self, + seed: Optional[int] = None, + episode_id: Optional[str] = None, + **kwargs: Any, + ) -> ObsT: + """Async version of reset. Default implementation calls sync reset. + + Override to provide true async implementation. + """ + return self.reset(seed=seed, episode_id=episode_id, **kwargs) + @abstractmethod def step( self, - action: Action, + action: ActT, timeout_s: Optional[float] = None, **kwargs: Any, - ) -> Observation: + ) -> ObsT: """Take a step in the environment.""" pass + async def step_async( + self, + action: ActT, + timeout_s: Optional[float] = None, + **kwargs: Any, + ) -> ObsT: + """Async version of step. Default implementation calls sync step. + + Override to provide true async implementation. 
+ """ + return self.step(action, timeout_s=timeout_s, **kwargs) + @property @abstractmethod - def state(self) -> State: + def state(self) -> StateT: """Get the current environment state.""" pass @@ -137,8 +179,16 @@ def get_metadata(self) -> EnvironmentMetadata: version="1.0.0", ) - def _apply_transform(self, observation: Observation) -> Observation: + def _apply_transform(self, observation: ObsT) -> ObsT: """Apply transform if one is provided.""" if self.transform is not None: return self.transform(observation) return observation + + def close(self) -> None: + """Clean up resources used by the environment. + + Override this method to implement custom cleanup logic. + Called when the environment is being destroyed or reset. + """ + pass diff --git a/src/openenv/core/env_server/serialization.py b/src/openenv/core/env_server/serialization.py index a97a05283..9e88a33c9 100644 --- a/src/openenv/core/env_server/serialization.py +++ b/src/openenv/core/env_server/serialization.py @@ -80,7 +80,7 @@ def deserialize_action_with_preprocessing( value = [] if isinstance(value, list): try: - import torch + import torch # type: ignore processed_data[key] = torch.tensor(value, dtype=torch.long) except ImportError: @@ -109,9 +109,9 @@ def serialize_observation(observation: Observation) -> Dict[str, Any]: observation: Observation instance Returns: - Dictionary compatible with HTTPEnvClient._parse_result() + Dictionary compatible with EnvClient._parse_result() - The format matches what HTTPEnvClient expects: + The format matches what EnvClient expects: { "observation": {...}, # Observation fields "reward": float | None, @@ -131,7 +131,7 @@ def serialize_observation(observation: Observation) -> Dict[str, Any]: reward = observation.reward done = observation.done - # Return in HTTPEnvClient expected format + # Return in EnvClient expected format return { "observation": obs_dict, "reward": reward, diff --git a/src/openenv/core/env_server/types.py b/src/openenv/core/env_server/types.py 
index c3ee689c0..a22914b73 100644 --- a/src/openenv/core/env_server/types.py +++ b/src/openenv/core/env_server/types.py @@ -4,8 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Dict, Optional, Union -from pydantic import BaseModel, Field, ConfigDict +from typing import Any, Dict, Optional, Union, Literal, Annotated +from pydantic import BaseModel, Field, ConfigDict, model_validator # Type aliases @@ -127,6 +127,15 @@ class StepResponse(BaseModel): done: bool = Field(default=False, description="Whether the episode has terminated") +class BaseMessage(BaseModel): + """Base class for WebSocket messages with shared configuration.""" + + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + ) + + class State(BaseModel): """Base class for environment state. @@ -149,27 +158,17 @@ class State(BaseModel): ) -class CodeExecResult(BaseModel): +class CodeExecResult(BaseMessage): """Result of code execution containing stdout, stderr, and exit code.""" - model_config = ConfigDict( - extra="forbid", - validate_assignment=True, - ) - stdout: str = Field(description="Standard output from code execution") stderr: str = Field(description="Standard error from code execution") exit_code: int = Field(description="Exit code from code execution") -class EnvironmentMetadata(BaseModel): +class EnvironmentMetadata(BaseMessage): """Metadata about an environment for documentation and UI purposes.""" - model_config = ConfigDict( - extra="forbid", - validate_assignment=True, - ) - name: str = Field(description="Name of the environment") description: str = Field(description="Description of what the environment does") readme_content: Optional[str] = Field( @@ -184,14 +183,9 @@ class EnvironmentMetadata(BaseModel): ) -class SchemaResponse(BaseModel): +class SchemaResponse(BaseMessage): """Response model for the combined schema endpoint.""" - model_config = 
ConfigDict( - extra="forbid", - validate_assignment=True, - ) - action: Dict[str, Any] = Field( description="JSON schema for actions accepted by this environment" ) @@ -203,12 +197,145 @@ class SchemaResponse(BaseModel): ) -class HealthResponse(BaseModel): +class HealthResponse(BaseMessage): """Response model for health check endpoint.""" - model_config = ConfigDict( - extra="forbid", - validate_assignment=True, + status: str = Field(description="Health status of the environment server") + + +class WSResetMessage(BaseMessage): + """WebSocket message to reset the environment.""" + + type: Literal["reset"] = Field(default="reset", description="Message type") + data: Dict[str, Any] = Field( + default_factory=dict, + description="Optional reset parameters (seed, episode_id, etc.)", ) - status: str = Field(description="Health status of the environment server") + +class WSStepMessage(BaseMessage): + """WebSocket message to execute a step.""" + + type: Literal["step"] = Field(default="step", description="Message type") + data: Dict[str, Any] = Field( + ..., description="Action data conforming to environment's action schema" + ) + + +class WSStateMessage(BaseMessage): + """WebSocket message to request current state.""" + + type: Literal["state"] = Field(default="state", description="Message type") + + +class WSCloseMessage(BaseMessage): + """WebSocket message to close the session.""" + + type: Literal["close"] = Field(default="close", description="Message type") + + +# Discriminated union for incoming WebSocket messages +WSIncomingMessage = Annotated[ + WSResetMessage | WSStepMessage | WSStateMessage | WSCloseMessage, + Field(discriminator="type"), +] + + +class WSObservationResponse(BaseModel): + """WebSocket response containing an observation.""" + + model_config = ConfigDict(extra="forbid") + + type: str = Field(default="observation", description="Response type") + data: Dict[str, Any] = Field(description="Observation data") + + +class WSStateResponse(BaseModel): + 
"""WebSocket response containing environment state.""" + + model_config = ConfigDict(extra="forbid") + + type: str = Field(default="state", description="Response type") + data: Dict[str, Any] = Field(description="State data") + + +class WSErrorResponse(BaseModel): + """WebSocket response for errors.""" + + model_config = ConfigDict(extra="forbid") + + type: str = Field(default="error", description="Response type") + data: Dict[str, Any] = Field(description="Error details including message and code") + + +class ConcurrencyConfig(BaseMessage): + """Configuration for concurrent environment sessions.""" + + max_concurrent_envs: int = Field( + default=1, + ge=1, + description="Maximum number of concurrent WebSocket sessions allowed", + ) + session_timeout: Optional[float] = Field( + default=None, + gt=0, + description="Timeout in seconds for inactive sessions. None means no timeout.", + ) + + +class ServerCapacityStatus(BaseMessage): + """Status of server capacity for concurrent sessions.""" + + active_sessions: int = Field( + ge=0, + description="Number of currently active sessions", + ) + max_sessions: int = Field( + ge=1, + description="Maximum number of allowed sessions", + ) + + @model_validator(mode="after") + def check_capacity_bounds(self) -> "ServerCapacityStatus": + if self.active_sessions > self.max_sessions: + raise ValueError( + f"active_sessions ({self.active_sessions}) cannot exceed " + f"max_sessions ({self.max_sessions})" + ) + return self + + @property + def available_slots(self) -> int: + """Number of available session slots.""" + return self.max_sessions - self.active_sessions + + @property + def is_at_capacity(self) -> bool: + """Whether the server has reached maximum capacity.""" + return self.available_slots == 0 + + @classmethod + def from_counts(cls, active: int, max_sessions: int) -> "ServerCapacityStatus": + """Create status from active and max session counts.""" + return cls( + active_sessions=active, + max_sessions=max_sessions, + ) + + 
+class SessionInfo(BaseMessage): + """Information about an active session.""" + + session_id: str = Field(description="Unique identifier for the session") + created_at: float = Field(description="Unix timestamp when the session was created") + last_activity_at: float = Field( + description="Unix timestamp of the last activity in the session" + ) + step_count: int = Field( + default=0, + ge=0, + description="Number of steps executed in this session", + ) + environment_type: str = Field( + description="Environment type for this session (e.g. `CodingEnv`)" + ) diff --git a/src/openenv/core/http_env_client.py b/src/openenv/core/http_env_client.py deleted file mode 100644 index 007ef6a5f..000000000 --- a/src/openenv/core/http_env_client.py +++ /dev/null @@ -1,236 +0,0 @@ -""" -core/runner_env.py -Minimal HTTP-based environment client. -- Talks to a single env worker exposing: POST /reset, POST /step - -Future hooks (commented below) for: -- episode_id, seed on reset -- request_id on step -- custom headers (auth/trace) -""" - -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import Any, Dict, Generic, Optional, Type, TYPE_CHECKING, TypeVar - -import requests - -from .client_types import StepResult -from .containers.runtime import LocalDockerProvider - -if TYPE_CHECKING: - from .containers.runtime import ContainerProvider - -ActT = TypeVar("ActT") -ObsT = TypeVar("ObsT") -EnvClientT = TypeVar("EnvClientT", bound="HTTPEnvClient") - - -class HTTPEnvClient(ABC, Generic[ActT, ObsT]): - def __init__( - self, - base_url: str, - request_timeout_s: float = 15.0, - default_headers: Optional[Dict[str, str]] = None, - provider: Optional["ContainerProvider"] = None, - ): - self._base = base_url.rstrip("/") - self._timeout = float(request_timeout_s) - self._http = requests.Session() - self._headers = default_headers or {} - self._provider = provider - - @classmethod - def from_docker_image( - cls: Type[EnvClientT], - image: str, - provider: 
Optional["ContainerProvider"] = None, - **kwargs: Any, - ) -> EnvClientT: - """ - Create an environment client by spinning up a Docker container locally. - - This is a development utility that: - 1. Starts a Docker container from the specified image - 2. Waits for the server to be ready - 3. Creates and returns a client instance connected to the container - - Note: The container lifecycle management is left to the user or higher-level - orchestration. The container will keep running until manually stopped. - - Args: - image: Docker image name to run (e.g., "echo-env:latest") - provider: Container provider to use (defaults to LocalDockerProvider) - **kwargs: Additional arguments to pass to provider.start_container() - (e.g., env_vars, port) - - Returns: - An instance of the client class connected to the running container - - Example: - >>> from envs.coding_env.client import CodingEnv - >>> from envs.coding_env.models import CodeAction - >>> - >>> # Create environment from image - >>> env = CodingEnv.from_docker_image("coding-env:latest") - >>> - >>> # Create environment with custom env vars - >>> env = CodingEnv.from_docker_image( - ... "coding-env:latest", - ... env_vars={"MY_VAR": "value"} - ... ) - >>> - >>> # Use the environment - >>> result = env.reset() - >>> print(result.observation) - >>> - >>> step_result = env.step(CodeAction(code="print('hello')")) - >>> print(step_result.observation.stdout) - >>> - >>> # Cleanup (optional) - >>> env.close() - """ - - # Use default provider if none provided - if provider is None: - provider = LocalDockerProvider() - - # 1. Start container with optional kwargs (e.g., env_vars, port) - base_url = provider.start_container(image, **kwargs) - - # 2. Wait for server to be ready - provider.wait_for_ready(base_url) - - # 3. 
Create and return client instance with provider reference - return cls(base_url=base_url, provider=provider) - - @classmethod - def from_hub( - cls: Type[EnvClientT], - repo_id: str, - provider: Optional["ContainerProvider"] = None, - **kwargs: Any, - ) -> EnvClientT: - """ - Create an environment client by pulling from a Hugging Face model hub. - """ - - if provider is None: - provider = LocalDockerProvider() - - if "tag" in kwargs: - tag = kwargs["tag"] - else: - tag = "latest" - - base_url = f"registry.hf.space/{repo_id.replace('/', '-')}:{tag}" - - return cls.from_docker_image(image=base_url, provider=provider) - - @abstractmethod - def _step_payload(self, action: ActT) -> dict: - """Convert an Action object to the JSON body expected by the env server.""" - raise NotImplementedError - - @abstractmethod - def _parse_result(self, payload: dict) -> StepResult[ObsT]: - """Convert a JSON response from the env server to StepResult[ObsT].""" - raise NotImplementedError - - @abstractmethod - def _parse_state(self, payload: dict) -> Any: - """Convert a JSON response from the state endpoint to a State object.""" - raise NotImplementedError - - # ---------- Environment Server Interface Methods ---------- - def reset(self, **kwargs: Any) -> StepResult[ObsT]: - """ - Reset the environment with optional parameters. - - Args: - **kwargs: Optional parameters passed to the environment's reset method. 
- Common parameters include: - - seed: Random seed for reproducibility - - episode_id: Custom episode identifier - - Any environment-specific reset parameters - - Returns: - StepResult containing initial observation - - Example: - >>> env.reset(seed=42, episode_id="ep-001") - """ - body: Dict[str, Any] = kwargs.copy() - r = self._http.post( - f"{self._base}/reset", - json=body, - headers=self._headers, - timeout=self._timeout, - ) - r.raise_for_status() - return self._parse_result(r.json()) - - def step(self, action: ActT, **kwargs: Any) -> StepResult[ObsT]: - """ - Execute an action in the environment with optional parameters. - - Args: - action: The action to execute - **kwargs: Optional parameters passed to the environment's step method. - Common parameters include: - - timeout_s: Execution timeout in seconds - - request_id: Request identifier for tracking - - render: Whether to render the environment - - Any environment-specific step parameters - - Returns: - StepResult containing observation, reward, and done status - - Example: - >>> env.step(action, timeout_s=30.0, request_id="req-123", render=True) - """ - body: Dict[str, Any] = { - "action": self._step_payload(action), - **kwargs # Forward all additional parameters - } - r = self._http.post( - f"{self._base}/step", - json=body, - headers=self._headers, - timeout=self._timeout, - ) - r.raise_for_status() - return self._parse_result(r.json()) - - def state(self) -> Any: - """ - Get the current environment state from the server. 
- - Returns: - State object with environment state information (e.g., episode_id, step_count) - - Example: - >>> client = EchoEnv.from_docker_image("echo-env:latest") - >>> result = client.reset() - >>> state = client.state() - >>> print(state.episode_id) - >>> print(state.step_count) - """ - r = self._http.get( - f"{self._base}/state", - headers=self._headers, - timeout=self._timeout, - ) - r.raise_for_status() - return self._parse_state(r.json()) - - def close(self) -> None: - """ - Close the environment and clean up resources. - - If this client was created via from_docker_image(), this will stop - and remove the associated container. - """ - if self._provider is not None: - self._provider.stop_container() diff --git a/src/openenv/core/utils.py b/src/openenv/core/utils.py new file mode 100644 index 000000000..42e9cee82 --- /dev/null +++ b/src/openenv/core/utils.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Utility functions for OpenEnv core.""" + +def convert_to_ws_url(url: str) -> str: + """ + Convert an HTTP/HTTPS URL to a WS/WSS URL. + + Args: + url: The URL to convert. + + Returns: + The converted WebSocket URL. + """ + ws_url = url.rstrip("/") + if ws_url.startswith("http://"): + ws_url = "ws://" + ws_url[7:] + elif ws_url.startswith("https://"): + ws_url = "wss://" + ws_url[8:] + elif not ws_url.startswith("ws://") and not ws_url.startswith("wss://"): + ws_url = "ws://" + ws_url + return ws_url