diff --git a/.gitignore b/.gitignore
index 52e67d6..7099d42 100644
--- a/.gitignore
+++ b/.gitignore
@@ -56,6 +56,7 @@ venv.bak/
 # Performance test results
 tests/performance/results/
 .coverage*
+coverage*
 
 # Local test script
 tools/tmp/*
diff --git a/app/api/routes/proxy.py b/app/api/routes/proxy.py
index 782ee88..cc9e5e1 100644
--- a/app/api/routes/proxy.py
+++ b/app/api/routes/proxy.py
@@ -15,6 +15,7 @@
     EmbeddingsRequest,
     ImageEditsRequest,
     ImageGenerationRequest,
+    ResponsesRequest,
 )
 from app.core.async_cache import forge_scope_cache_async, get_forge_scope_cache_async
 from app.core.database import get_async_db
@@ -283,3 +284,51 @@ async def create_embeddings(
         raise HTTPException(
             status_code=500, detail=f"Error processing request: {str(err)}"
         ) from err
+
+@router.post("/responses")
+async def create_responses(
+    request: Request,
+    responses_request: ResponsesRequest,
+    user_details: dict[str, Any] = Depends(get_user_details_by_api_key),
+    db: AsyncSession = Depends(get_async_db),
+) -> Any:
+    """
+    Create a response (OpenAI-compatible Responses endpoint).
+    """
+    try:
+        user = user_details["user"]
+        api_key_id = user_details["api_key_id"]
+        provider_service = await ProviderService.async_get_instance(user, db, api_key_id=api_key_id)
+        allowed_provider_names = await _get_allowed_provider_names(request, db)
+
+        response = await provider_service.process_request(
+            "responses",
+            responses_request.model_dump(mode="json", exclude_unset=True),
+            allowed_provider_names=allowed_provider_names,
+        )
+
+        # Check if it's a streaming response
+        if inspect.isasyncgen(response):
+            headers = {
+                "Content-Type": "text/event-stream",
+                "Cache-Control": "no-cache",
+                "Connection": "keep-alive",
+                "X-Accel-Buffering": "no",  # Prevent Nginx buffering
+            }
+
+            return StreamingResponse(
+                response, media_type="text/event-stream", headers=headers
+            )
+
+        # Otherwise, return the JSON response directly
+        return response
+    except NotImplementedError as err:
+        raise HTTPException(
+            status_code=404, detail=f"Error processing request: {str(err)}"
+        ) from err
+    except ValueError as err:
+        raise HTTPException(status_code=400, detail=str(err)) from err
+    except Exception as err:
+        raise HTTPException(
+            status_code=500, detail=f"Error processing request: {str(err)}"
+        ) from err
diff --git a/app/api/schemas/openai.py b/app/api/schemas/openai.py
index eef3b8e..d397512 100644
--- a/app/api/schemas/openai.py
+++ b/app/api/schemas/openai.py
@@ -252,3 +252,202 @@ class EmbeddingsRequest(BaseModel):
     encoding_format: str | None = 'float'
     # inpput_type is for cohere embeddings only
     input_type: str | None = 'search_document'
+
+
+# ---------------------------------------------------------------------------
+# OpenAI Responses Request
+# https://platform.openai.com/docs/api-reference/responses/create
+# ---------------------------------------------------------------------------
+class ResponsesInputTextItem(BaseModel):
+    text: str
+    type: str  # always input_text
+
+class ResponsesInputImageItem(BaseModel):
+    detail: str | None = 'auto'
+    type: str  # always input_image
+    file_id: str | None = None
+    image_url: str | None = None
+
+class ResponsesInputFileItem(BaseModel):
+    type: str  # always input_file
+    file_data: str | None = None
+    file_id: str | None = None
+    file_url: str | None = None
+    filename: str | None = None
+
+class ResponsesInputAudioItem(BaseModel):
+    input_audio: object
+    type: str  # always input_audio
+
+class ResponsesInputMessageItem(BaseModel):
+    role: str
+    type: str | None = None
+    content: str | list[ResponsesInputTextItem | ResponsesInputImageItem | ResponsesInputFileItem | ResponsesInputAudioItem]
+
+
+class ResponsesItemInputMessage(BaseModel):
+    role: str
+    content: list[ResponsesInputTextItem | ResponsesInputImageItem | ResponsesInputFileItem | ResponsesInputAudioItem]
+    status: str | None = None
+    type: str | None = None
+
+class ResponsesItemOutputMessage(BaseModel):
+    content: list[object]
+    id: str
+    role: str
+    status: str
+    type: str
+
+class ResponsesItemFileSearchToolCall(BaseModel):
+    id: str
+    query: str
+    status: str
+    type: str
+    results: list[object]
+
+class ResponsesItemComputerToolCall(BaseModel):
+    action: object
+    call_id: str
+    id: str
+    pending_safety_checks: list[object]
+    status: str
+    type: str
+
+class ResponsesItemComputerToolCallOutput(BaseModel):
+    call_id: str
+    output: object
+    type: str
+    acknowledged_safety_checks: list[object] | None = None
+    id: str | None = None
+    status: str | None = None
+
+class ResponsesItemWebSearchToolCall(BaseModel):
+    action: object
+    id: str
+    status: str
+    type: str
+
+class ResponsesItemFunctionToolCall(BaseModel):
+    arguments: str
+    call_id: str
+    name: str
+    type: str
+    id: str | None = None
+    status: str | None = None
+
+class ResponsesItemFunctionToolCallOutput(BaseModel):
+    call_id: str
+    output: str | list[object]
+    type: str
+    id: str | None = None
+    status: str | None = None
+
+class ResponsesItemReasoning(BaseModel):
+    id: str
+    summary: list[object]
+    type: str
+    content: list[object] | None = None
+    encrypted_content: str | None = None
+    status: str | None = None
+
+class ResponsesItemImageGenerationCall(BaseModel):
+    id: str
+    result: str
+    status: str
+    type: str
+
+class ResponsesItemCodeInterpreterToolCall(BaseModel):
+    code: str
+    container_id: str
+    id: str
+    outputs: list[object]
+    status: str
+    type: str
+
+class ResponsesItemLocalShellCall(BaseModel):
+    action: object
+    call_id: str
+    id: str
+    status: str
+    type: str
+
+class ResponsesItemLocalShellCallOutput(BaseModel):
+    id: str
+    output: str
+    type: str
+    status: str | None = None
+
+class ResponsesItemMCPListTools(BaseModel):
+    id: str
+    server_label: str
+    tools: list[object]
+    type: str
+    error: str | None = None
+
+class ResponsesItemMCPApprovalRequest(BaseModel):
+    arguments: str
+    id: str
+    name: str
+    server_label: str
+    type: str
+
+class ResponsesItemMCPApprovalResponse(BaseModel):
+    approval_request_id: str
+    approve: bool
+    type: str
+    id: str | None = None
+    reason: str | None = None
+
+class ResponsesItemMCPToolCall(BaseModel):
+    arguments: str
+    id: str
+    name: str
+    server_label: str
+    type: str
+    error: str | None = None
+    output: str | None = None
+
+class ResponsesItemCustomToolCallOutput(BaseModel):
+    call_id: str
+    output: str | list[object]
+    type: str
+    id: str | None = None
+
+class ResponsesItemCustomToolCall(BaseModel):
+    call_id: str
+    input: str
+    name: str
+    type: str
+    id: str | None = None
+
+class ResponsesItemReference(BaseModel):
+    id: str
+    type: str
+
+class ResponsesRequest(BaseModel):
+    background: bool | None = False
+    conversation: str | object | None = None
+    include: list[Any] | None = None
+    input: str | list[ResponsesInputMessageItem | ResponsesItemReference | ResponsesItemInputMessage | ResponsesItemFileSearchToolCall | ResponsesItemComputerToolCall | ResponsesItemWebSearchToolCall | ResponsesItemFunctionToolCall | ResponsesItemReasoning | ResponsesItemImageGenerationCall | ResponsesItemCodeInterpreterToolCall | ResponsesItemLocalShellCall | ResponsesItemMCPListTools | ResponsesItemMCPApprovalRequest | ResponsesItemMCPApprovalResponse | ResponsesItemMCPToolCall | ResponsesItemCustomToolCallOutput | ResponsesItemCustomToolCall] | None = None
+    instructions: str | None = None
+    max_output_tokens: int | None = None
+    max_tool_calls: int | None = None
+    metadata: dict[Any, Any] | None = None
+    model: str | None = None
+    parallel_tool_calls: bool | None = True
+    previous_response_id: str | None = None
+    prompt: object | None = None
+    prompt_cache_key: str | None = None
+    reasoning: object | None = None
+    safety_identifier: str | None = None
+    service_tier: str | None = 'auto'
+    store: bool | None = True
+    stream: bool | None = False
+    stream_options: object | None = None
+    temperature: float | None = 1.0
+    text: object | None = None
+    tool_choice: str | object | None = None
+    tools: list[Any] | None = None
+    top_logprobs: int | None = None
+    top_p: float | None = 1.0
+    truncation: str | None = 'disabled'
diff --git a/app/services/provider_service.py b/app/services/provider_service.py
index ff2446e..e492ad9 100644
--- a/app/services/provider_service.py
+++ b/app/services/provider_service.py
@@ -620,6 +620,12 @@ async def process_request(
                 payload,
                 api_key,
             )
+        elif endpoint == "responses":
+            result = await adapter.process_responses(
+                endpoint,
+                payload,
+                api_key,
+            )
         elif "images/generations" in endpoint:
             # TODO: we only support openai for now
             if provider_name != "openai":
diff --git a/app/services/providers/base.py b/app/services/providers/base.py
index 0ccb392..206461c 100644
--- a/app/services/providers/base.py
+++ b/app/services/providers/base.py
@@ -36,6 +36,27 @@ async def process_completion(
     ) -> Any:
         """Process a completion request"""
         pass
+
+    async def process_responses(
+        self,
+        endpoint: str,
+        payload: dict[str, Any],
+        api_key: str,
+        base_url: str | None = None,
+    ) -> Any:
+        """Process a responses request"""
+        # TODO: currently it's openai only
+        raise NotImplementedError("Process responses is not implemented")
+
+    async def process_conversations(
+        self,
+        endpoint: str,
+        payload: dict[str, Any],
+        api_key: str,
+        base_url: str | None = None,
+    ) -> Any:
+        """Process a conversations request"""
+        raise NotImplementedError("Process conversations is not implemented")
 
     @abstractmethod
     async def process_embeddings(
diff --git a/app/services/providers/openai_adapter.py b/app/services/providers/openai_adapter.py
index 20f5036..7594795 100644
--- a/app/services/providers/openai_adapter.py
+++ b/app/services/providers/openai_adapter.py
@@ -352,3 +352,68 @@ async def process_embeddings(
             "usage": total_usage,
         }
         return final_response
+
+    async def process_responses(
+        self,
+        endpoint: str,
+        payload: dict[str, Any],
+        api_key: str,
+        base_url: str | None = None,
+    ) -> Any:
+        """Process a responses request using the OpenAI API"""
+        headers = {
+            "Authorization": f"Bearer {api_key}",
+            "Content-Type": "application/json",
+        }
+
+        url = f"{base_url or self._base_url}/{endpoint}"
+
+        # Check if streaming is requested
+        streaming = payload.get("stream", False)
+        if streaming:
+            # For streaming, return a streaming generator
+            async def stream_response() -> AsyncGenerator[bytes, None]:
+                async with (
+                    aiohttp.ClientSession() as session,
+                    session.post(
+                        url, headers=headers, json=payload
+                    ) as response,
+                ):
+                    if response.status != HTTPStatus.OK:
+                        error_text = await response.text()
+                        logger.error(
+                            f"Responses Streaming API error for {self.provider_name}: {error_text}"
+                        )
+                        raise ProviderAPIException(
+                            provider_name=self.provider_name,
+                            error_code=response.status,
+                            error_message=error_text,
+                        )
+
+                    # Stream the response back
+                    async for chunk in response.content:
+                        if chunk:
+                            yield chunk
+
+            # Return the streaming generator
+            return stream_response()
+        else:
+            # For non-streaming, use the regular approach
+            async with (
+                aiohttp.ClientSession() as session,
+                session.post(
+                    url, headers=headers, json=payload
+                ) as response,
+            ):
+                if response.status != HTTPStatus.OK:
+                    error_text = await response.text()
+                    logger.error(
+                        f"Responses API error for {self.provider_name}: {error_text}"
+                    )
+                    raise ProviderAPIException(
+                        provider_name=self.provider_name,
+                        error_code=response.status,
+                        error_message=error_text,
+                    )
+
+                return await response.json()
diff --git a/tests/unit_tests/assets/openai/responses_response_1.json b/tests/unit_tests/assets/openai/responses_response_1.json
new file mode 100644
index 0000000..88bd7d8
--- /dev/null
+++ b/tests/unit_tests/assets/openai/responses_response_1.json
@@ -0,0 +1,53 @@
+{
+    "id": "resp_0a379700213743ce0068db3cded47481a29d3552eea69e6939",
+    "object": "response",
+    "created_at": 1759198430,
+    "status": "completed",
+    "background": false,
+    "billing": {"payer": "developer"},
+    "error": null,
+    "incomplete_details": null,
+    "instructions": null,
+    "max_output_tokens": null,
+    "max_tool_calls": null,
+    "model": "gpt-4o-mini-2024-07-18",
+    "output": [
+        {
+            "id": "msg_0a379700213743ce0068db3cdf654481a29d5a27842a0e095f",
+            "type": "message",
+            "status": "completed",
+            "content": [
+                {
+                    "type": "output_text",
+                    "annotations": [],
+                    "logprobs": [],
+                    "text": "Hello! I'm doing well, thank you. How about you?"
+                }
+            ],
+            "role": "assistant"
+        }
+    ],
+    "parallel_tool_calls": true,
+    "previous_response_id": null,
+    "prompt_cache_key": null,
+    "reasoning": {"effort": null, "summary": null},
+    "safety_identifier": null,
+    "service_tier": "default",
+    "store": true,
+    "temperature": 1.0,
+    "text": {"format": {"type": "text"}, "verbosity": "medium"},
+    "tool_choice": "auto",
+    "tools": [],
+    "top_logprobs": 0,
+    "top_p": 1.0,
+    "truncation": "disabled",
+    "usage": {
+        "input_tokens": 13,
+        "input_tokens_details": {"cached_tokens": 0},
+        "output_tokens": 14,
+        "output_tokens_details": {"reasoning_tokens": 0},
+        "total_tokens": 27
+    },
+    "user": null,
+    "metadata": {}
+}
diff --git a/tests/unit_tests/test_openai_provider.py b/tests/unit_tests/test_openai_provider.py
index 7d9edd3..51602dd 100644
--- a/tests/unit_tests/test_openai_provider.py
+++ b/tests/unit_tests/test_openai_provider.py
@@ -6,10 +6,12 @@
 from app.services.providers.openai_adapter import OpenAIAdapter
 from tests.unit_tests.utils.helpers import (
     ClientSessionMock,
-    validate_chat_completion_response,
     OPENAAI_STANDARD_CHAT_COMPLETION_RESPONSE,
+    OPENAI_STANDARD_RESPONSES_RESPONSE,
+    validate_chat_completion_response,
     process_openai_streaming_response,
     validate_chat_completion_streaming_response,
+    validate_responses_response,
 )
 
 CURRENT_DIR = os.path.dirname(__file__)
@@ -30,6 +32,11 @@
 ) as f:
     MOCK_CHAT_COMPLETION_STREAMING_RESPONSE_DATA = json.load(f)
 
+with open(
+    os.path.join(CURRENT_DIR, "assets", "openai", "responses_response_1.json"), "r"
+) as f:
+    MOCK_RESPONSES_RESPONSE_DATA = json.load(f)
+
 with open(
     os.path.join(CURRENT_DIR, "assets", "openai", "embeddings_response.json"), "r"
 ) as f:
@@ -113,6 +120,28 @@
             expected_model="gpt-4o-mini-2024-07-18",
             expected_message=OPENAAI_STANDARD_CHAT_COMPLETION_RESPONSE,
         )
+
+    async def test_responses(self):
+        payload = {
+            "model": "gpt-4o-mini",
+            "input": "Hello, how are you?",
+            "stream": False,
+        }
+        with patch("aiohttp.ClientSession", ClientSessionMock()) as mock_session:
+            mock_session.responses = [
+                (MOCK_RESPONSES_RESPONSE_DATA, 200)
+            ]
+
+            # Call the method
+            result = await self.adapter.process_responses(
+                api_key=self.api_key, payload=payload, endpoint="responses"
+            )
+            # Assert the result contains the expected model and message
+            validate_responses_response(
+                result,
+                expected_model="gpt-4o-mini-2024-07-18",
+                expected_message=OPENAI_STANDARD_RESPONSES_RESPONSE,
+            )
 
     async def test_process_embeddings(self):
         payload = {
diff --git a/tests/unit_tests/utils/helpers.py b/tests/unit_tests/utils/helpers.py
index a62554c..656840e 100644
--- a/tests/unit_tests/utils/helpers.py
+++ b/tests/unit_tests/utils/helpers.py
@@ -93,6 +93,7 @@ def post(self, url, *args, **kwargs):
 
 
 OPENAAI_STANDARD_CHAT_COMPLETION_RESPONSE = "Hello! I'm just a program, so I don't have feelings, but I'm here and ready to help you. How can I assist you today?"
+OPENAI_STANDARD_RESPONSES_RESPONSE = "Hello! I'm doing well, thank you. How about you?"
 ANTHROPIC_STANDARD_CHAT_COMPLETION_RESPONSE = "Hello! I'm doing well, thank you for asking. I'm here and ready to help with whatever you'd like to discuss or work on. How are you doing today?"
 GOOGLE_STANDARD_CHAT_COMPLETION_RESPONSE = (
     "I am doing well, thank you for asking. How are you today?\n"
@@ -182,3 +183,22 @@
         expected_usage["prompt_tokens_details"] = usage["prompt_tokens_details"]
     if "completion_tokens_details" in expected_usage:
         expected_usage["completion_tokens_details"] = usage["completion_tokens_details"]
+
+def validate_responses_response(
+    response: dict,
+    expected_model: str = None,
+    expected_message: str = None,
+    expected_usage: dict = None,
+):
+    assert "model" in response, "model is required"
+    assert "output" in response, "output is required"
+    assert "usage" in response, "usage is required"
+
+    if expected_model:
+        assert response["model"] == expected_model
+    if expected_message:
+        assert response["output"][0]["content"][0]["text"] == expected_message
+    if expected_usage:
+        usage = response["usage"]
+        assert usage["input_tokens"] == expected_usage["input_tokens"]
+        assert usage["output_tokens"] == expected_usage["output_tokens"]