diff --git a/agentlightning/verl/async_server.py b/agentlightning/verl/async_server.py
index 433d99c98..07f82be83 100644
--- a/agentlightning/verl/async_server.py
+++ b/agentlightning/verl/async_server.py
@@ -10,7 +10,8 @@
 from verl.workers.rollout.vllm_rollout.vllm_async_server import AsyncvLLMServer
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ErrorResponse
 
-from agentlightning.instrumentation.vllm import ChatCompletionResponsePatched, instrument_vllm
+from agentlightning.instrumentation.vllm import instrument_vllm
+from agentlightning.logging import configure_logger
 
 
 def _unwrap_ray_remote(cls):
@@ -19,9 +20,11 @@ def _unwrap_ray_remote(cls):
     return cls
 
 
+logger = configure_logger()
+
+
 @ray.remote(num_cpus=1)
 class PatchedvLLMServer(_unwrap_ray_remote(AsyncvLLMServer)):
-
     def __init__(self, *args, **kwargs):
         instrument_vllm()
         super().__init__(*args, **kwargs)
@@ -36,10 +39,14 @@ async def chat_completion(self, raw_request: Request):
         """
         request_json = await raw_request.json()
         request = ChatCompletionRequest(**request_json)
-        generator = await self.openai_serving_chat.create_chat_completion(request, raw_request)
+        generator = await self.openai_serving_chat.create_chat_completion(
+            request, raw_request
+        )
 
         if isinstance(generator, ErrorResponse):
-            return JSONResponse(content=generator.model_dump(), status_code=generator.code)
+            status_code = getattr(generator, "code", None) or 500
+            logger.error("vLLM chat completion error: %s", generator.model_dump())
+            return JSONResponse(content=generator.model_dump(), status_code=status_code)
         if request.stream:
             return StreamingResponse(content=generator, media_type="text/event-stream")
         else:
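
A minimal sketch of the status-code fallback used in the error branch above, assuming only the `getattr(generator, "code", None) or 500` expression from the hunk; `resolve_status_code` and the stand-in error objects are illustrative and not part of the patch or of vLLM's API.

```python
# Sketch of the fallback logic: prefer the error's own HTTP code,
# return 500 when the attribute is missing or falsy (e.g. None).
from types import SimpleNamespace


def resolve_status_code(error) -> int:
    # Mirrors the patched branch: getattr(generator, "code", None) or 500
    return getattr(error, "code", None) or 500


assert resolve_status_code(SimpleNamespace(code=400)) == 400   # explicit code is kept
assert resolve_status_code(SimpleNamespace(code=None)) == 500  # None falls back to 500
assert resolve_status_code(object()) == 500                    # missing attribute falls back to 500
```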