From a3cf0ed111a5b8cd1942fe78e10f8388e413a0ea Mon Sep 17 00:00:00 2001
From: Lingtong Lu
Date: Thu, 16 Oct 2025 23:21:32 -0700
Subject: [PATCH] Update the token tracking (#248)

---
 app/api/routes/__init__.py                     | 82 +++++++++++++++++++
 app/api/routes/proxy.py                        | 37 ++-------
 app/services/provider_service.py               | 60 +++++---------
 .../providers/usage_tracker_service.py         |  1 -
 4 files changed, 108 insertions(+), 72 deletions(-)

diff --git a/app/api/routes/__init__.py b/app/api/routes/__init__.py
index e69de29..b1abe4e 100644
--- a/app/api/routes/__init__.py
+++ b/app/api/routes/__init__.py
@@ -0,0 +1,82 @@
+from typing import AsyncGenerator
+import json
+
+from fastapi import HTTPException
+from starlette.responses import StreamingResponse
+
+from app.exceptions.exceptions import ProviderAPIException
+
+
+async def wrap_streaming_response_with_error_handling(
+    logger, async_gen: AsyncGenerator[bytes, None]
+) -> StreamingResponse:
+    """
+    Wraps an async generator to catch and properly handle errors in streaming responses.
+    Returns a StreamingResponse that will:
+    - Return proper HTTP error status if error occurs before first chunk
+    - Send error as SSE event if error occurs mid-stream
+
+    Args:
+        logger: Logger instance for error logging
+        async_gen: The async generator producing the stream chunks
+
+    Returns:
+        StreamingResponse with proper error handling
+
+    Raises:
+        HTTPException: If error occurs before streaming starts
+    """
+    # Try to get the first chunk BEFORE creating StreamingResponse
+    # This allows us to catch immediate errors and return proper HTTP status
+    try:
+        first_chunk = await async_gen.__anext__()
+    except StopAsyncIteration:
+        # Empty stream
+        logger.error("Empty stream response")
+        raise HTTPException(status_code=500, detail="Empty stream response")
+    except ProviderAPIException as e:
+        logger.error(f"Provider API error: {str(e)}")
+        raise HTTPException(status_code=e.error_code, detail=e.error_message) from e
+    except Exception as e:
+        # Convert other exceptions to HTTPException
+        logger.error(f"Error before streaming started: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e)) from e
+
+    # Success! Now create generator that replays first chunk + rest
+    async def response_generator():
+        # Yield the first chunk we already got
+        yield first_chunk
+
+        try:
+            # Continue with the rest of the stream
+            async for chunk in async_gen:
+                yield chunk
+        except Exception as e:
+            # Error occurred mid-stream - HTTP status already sent
+            # Send error as SSE event to inform the client
+            logger.error(f"Error during streaming: {str(e)}")
+
+            error_message = str(e)
+            error_event = {
+                "error": {
+                    "message": error_message,
+                    "type": "stream_error",
+                    "code": "provider_error"
+                }
+            }
+            yield f"data: {json.dumps(error_event)}\n\n".encode()
+
+        # Send [DONE] to properly close the stream
+        yield b"data: [DONE]\n\n"
+
+    # Set appropriate headers for streaming
+    headers = {
+        "Content-Type": "text/event-stream",
+        "Cache-Control": "no-cache",
+        "Connection": "keep-alive",
+        "X-Accel-Buffering": "no",  # Prevent Nginx buffering
+    }
+
+    return StreamingResponse(
+        response_generator(), media_type="text/event-stream", headers=headers
+    )
\ No newline at end of file
diff --git a/app/api/routes/proxy.py b/app/api/routes/proxy.py
index cc9e5e1..f38a66a 100644
--- a/app/api/routes/proxy.py
+++ b/app/api/routes/proxy.py
@@ -23,6 +23,7 @@
 from app.models.forge_api_key import ForgeApiKey
 from app.models.user import User
 from app.services.provider_service import ProviderService
+from app.api.routes import wrap_streaming_response_with_error_handling
 
 router = APIRouter()
 logger = get_logger(name="proxy")
@@ -96,23 +97,15 @@ async def create_chat_completion(
 
         # Check if it's a streaming response by checking if it's an async generator
         if inspect.isasyncgen(response):
-            # Set appropriate headers for streaming
-            headers = {
-                "Content-Type": "text/event-stream",
-                "Cache-Control": "no-cache",
-                "Connection": "keep-alive",
-                "X-Accel-Buffering": "no",  # Prevent Nginx buffering
-            }
-
-            return StreamingResponse(
-                response, media_type="text/event-stream", headers=headers
-            )
+            return await wrap_streaming_response_with_error_handling(logger, response)
 
         # Otherwise, return the JSON response directly
         return response
     except ValueError as err:
         logger.exception(f"Error processing chat completion request: {str(err)}")
         raise HTTPException(status_code=400, detail=str(err)) from err
+    except HTTPException as err:
+        raise err
     except Exception as err:
         logger.exception(f"Error processing chat completion request: {str(err)}")
         raise HTTPException(
@@ -144,16 +137,7 @@ async def create_completion(
 
         # Check if it's a streaming response
         if inspect.isasyncgen(response):
-            headers = {
-                "Content-Type": "text/event-stream",
-                "Cache-Control": "no-cache",
-                "Connection": "keep-alive",
-                "X-Accel-Buffering": "no",  # Prevent Nginx buffering
-            }
-
-            return StreamingResponse(
-                response, media_type="text/event-stream", headers=headers
-            )
+            return await wrap_streaming_response_with_error_handling(logger, response)
 
         # Otherwise, return the JSON response directly
         return response
@@ -309,16 +293,7 @@ async def create_responses(
 
         # Check if it's a streaming response
         if inspect.isasyncgen(response):
-            headers = {
-                "Content-Type": "text/event-stream",
-                "Cache-Control": "no-cache",
-                "Connection": "keep-alive",
-                "X-Accel-Buffering": "no",  # Prevent Nginx buffering
-            }
-
-            return StreamingResponse(
-                response, media_type="text/event-stream", headers=headers
-            )
+            return await wrap_streaming_response_with_error_handling(logger, response)
 
         # Otherwise, return the JSON response directly
         return response
diff --git a/app/services/provider_service.py b/app/services/provider_service.py
index e492ad9..8a3797d 100644
--- a/app/services/provider_service.py
+++ b/app/services/provider_service.py
@@ -690,21 +690,21 @@ async def process_request(
                 # re-calculate output tokens
                 output_tokens = max(output_tokens, total_tokens - input_tokens)
 
-            asyncio.create_task(
-                update_usage_in_background(
-                    usage_tracker_id,
-                    input_tokens,
-                    output_tokens,
-                    cached_tokens,
-                    reasoning_tokens,
+            if input_tokens > 0 or output_tokens > 0:
+                asyncio.create_task(
+                    update_usage_in_background(
+                        usage_tracker_id,
+                        input_tokens,
+                        output_tokens,
+                        cached_tokens,
+                        reasoning_tokens,
+                    )
                 )
-            )
             return result
         else:
             # For streaming responses, wrap the generator to count tokens
             async def token_counting_stream() -> AsyncGenerator[bytes, None]:
-                approximate_input_tokens = 0
-                approximate_output_tokens = 0
+                update_usage = True
                 output_tokens = 0
                 input_tokens = 0
                 total_tokens = 0
@@ -720,13 +720,6 @@ async def token_counting_stream() -> AsyncGenerator[bytes, None]:
 
                 messages = payload.get("messages", [])
 
-                # Rough estimate of input tokens based on message length
-                for msg in messages:
-                    content = msg.get("content", "")
-                    if isinstance(content, str):
-                        # Rough approximation: 4 chars ~= 1 token
-                        approximate_input_tokens += len(content) // 4
-
                 try:
                     async for chunk in result:
                         chunks_processed += 1
@@ -767,21 +760,6 @@ async def token_counting_stream() -> AsyncGenerator[bytes, None]:
 
                                 # re-calculate output tokens
                                 output_tokens = max(output_tokens, total_tokens - input_tokens)
-
-                            # Extract content from the chunk based on OpenAI format
-                            if "choices" in data:
-                                for choice in data["choices"]:
-                                    if (
-                                        "delta" in choice
-                                        and "content" in choice["delta"]
-                                    ):
-                                        content = choice["delta"]["content"]
-                                        # Only count tokens if we don't have final usage data
-                                        if content:
-                                            # Count tokens in content (approx)
-                                            approximate_output_tokens += (
-                                                len(content) // 4
-                                            )
                         except json.JSONDecodeError:
                             # If JSON parsing fails, just continue
                             pass
@@ -799,6 +777,7 @@ async def token_counting_stream() -> AsyncGenerator[bytes, None]:
                         "Error in streaming response: {}", str(e), exc_info=True
                     )
                     # Re-raise to propagate the error
+                    update_usage = False
                     raise
                 finally:
                     logger.debug(
@@ -807,15 +786,16 @@ async def token_counting_stream() -> AsyncGenerator[bytes, None]:
                        f"output_tokens={output_tokens}, cached_tokens={cached_tokens}, reasoning_tokens={reasoning_tokens}"
                     )
 
-                    asyncio.create_task(
-                        update_usage_in_background(
-                            usage_tracker_id,
-                            input_tokens or approximate_input_tokens,
-                            output_tokens or approximate_output_tokens,
-                            cached_tokens,
-                            reasoning_tokens,
+                    if update_usage and (input_tokens > 0 or output_tokens > 0):
+                        asyncio.create_task(
+                            update_usage_in_background(
+                                usage_tracker_id,
+                                input_tokens,
+                                output_tokens,
+                                cached_tokens,
+                                reasoning_tokens,
+                            )
                         )
-                    )
 
             return token_counting_stream()
 
diff --git a/app/services/providers/usage_tracker_service.py b/app/services/providers/usage_tracker_service.py
index cf9e6a7..b158e3b 100644
--- a/app/services/providers/usage_tracker_service.py
+++ b/app/services/providers/usage_tracker_service.py
@@ -65,7 +65,6 @@ async def update_usage_tracker(
     )
     usage_tracker = result.scalar_one_or_none()
     now = datetime.now(UTC)
-    logger.info(f"provider_key: {usage_tracker.provider_key.provider_name}")
     price_info = await PricingService.calculate_usage_cost(
         db,
         usage_tracker.provider_key.provider_name.lower(),