diff --git a/src/openenv/core/env_server/http_server.py b/src/openenv/core/env_server/http_server.py index 7fa7c0f3..d301fa7e 100644 --- a/src/openenv/core/env_server/http_server.py +++ b/src/openenv/core/env_server/http_server.py @@ -84,8 +84,14 @@ def __init__( self.action_cls = action_cls self.observation_cls = observation_cls # Create thread pool for running sync code in async context - # This is needed for environments using sync libraries (e.g., Playwright sync API) - self._executor = ThreadPoolExecutor(max_workers=1) + # This is needed for environments using sync libraries (e.g., Playwright) + # Configurable via OPENENV_THREAD_POOL_SIZE (default: 32) + pool_size = int(os.getenv("OPENENV_THREAD_POOL_SIZE", "32")) + self._executor = ThreadPoolExecutor(max_workers=pool_size) + + # Check if environment has async methods for better concurrency + self._has_step_async = hasattr(env, "step_async") and asyncio.iscoroutinefunction(env.step_async) + self._has_reset_async = hasattr(env, "reset_async") and asyncio.iscoroutinefunction(env.reset_async) async def _run_sync_in_thread_pool(self, func, *args, **kwargs): """Run a synchronous function in the thread pool executor.""" @@ -99,9 +105,7 @@ def _get_valid_kwargs(self, sig, kwargs, skip_params=None): valid_kwargs = {} - has_kwargs = any( - p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() - ) + has_kwargs = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()) for k, v in kwargs.items(): if k in sig.parameters or has_kwargs: @@ -128,13 +132,17 @@ async def reset_handler( kwargs = request.model_dump(exclude_unset=True) # Pass arguments only if environment accepts them - sig = inspect.signature(self.env.reset) + if self._has_reset_async: + sig = inspect.signature(self.env.reset_async) + else: + sig = inspect.signature(self.env.reset) valid_kwargs = self._get_valid_kwargs(sig, kwargs) - # Run synchronous reset in thread pool to avoid blocking event loop - observation = await self._run_sync_in_thread_pool( - self.env.reset, **valid_kwargs - ) + # Use async method if available for better concurrency + if self._has_reset_async: + observation = await self.env.reset_async(**valid_kwargs) + else: + observation = await self._run_sync_in_thread_pool(self.env.reset, **valid_kwargs) return ResetResponse(**serialize_observation(observation)) # Helper function to handle step endpoint @@ -147,22 +155,24 @@ async def step_handler(request: StepRequest) -> StepResponse: action = deserialize_action(action_data, self.action_cls) except ValidationError as e: # Return HTTP 422 with detailed validation errors - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, detail=e.errors() - ) + raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, detail=e.errors()) # Handle optional parameters # Start with all fields from the request, including extra ones, but exclude 'action' kwargs = request.model_dump(exclude_unset=True, exclude={"action"}) # Pass arguments only if environment accepts them - sig = inspect.signature(self.env.step) + if self._has_step_async: + sig = inspect.signature(self.env.step_async) + else: + sig = inspect.signature(self.env.step) valid_kwargs = self._get_valid_kwargs(sig, kwargs, skip_params={"action"}) - # Run synchronous step in thread pool to avoid blocking event loop - observation = await self._run_sync_in_thread_pool( - self.env.step, action, **valid_kwargs - ) + # Use async method if available for better concurrency + if self._has_step_async: + observation = await self.env.step_async(action, **valid_kwargs) + else: + observation = await self._run_sync_in_thread_pool(self.env.step, action, **valid_kwargs) # Return serialized observation return StepResponse(**serialize_observation(observation)) @@ -388,9 +398,7 @@ def create_fastapi_app( try: from fastapi import FastAPI except ImportError: - raise ImportError( - "FastAPI is required. Install with: pip install fastapi uvicorn" - ) + raise ImportError("FastAPI is required. Install with: pip install fastapi uvicorn") app = FastAPI( title="OpenEnv Environment HTTP API",