82 changes: 82 additions & 0 deletions app/api/routes/__init__.py
@@ -0,0 +1,82 @@
from typing import AsyncGenerator
import json
from fastapi import HTTPException
from starlette.responses import StreamingResponse
from app.exceptions.exceptions import ProviderAPIException

async def wrap_streaming_response_with_error_handling(
logger, async_gen: AsyncGenerator[bytes, None]
) -> StreamingResponse:
"""
Wraps an async generator to catch and properly handle errors in streaming responses.
Returns a StreamingResponse that will:
- Return proper HTTP error status if error occurs before first chunk
- Send error as SSE event if error occurs mid-stream

Args:
logger: Logger instance for error logging
async_gen: The async generator producing the stream chunks

Returns:
StreamingResponse with proper error handling

Raises:
HTTPException: If error occurs before streaming starts
"""

# Try to get the first chunk BEFORE creating StreamingResponse
# This allows us to catch immediate errors and return proper HTTP status
try:
first_chunk = await async_gen.__anext__()
except StopAsyncIteration:
# Empty stream
logger.error("Empty stream response")
raise HTTPException(status_code=500, detail="Empty stream response")
except ProviderAPIException as e:
logger.error(f"Provider API error: {str(e)}")
raise HTTPException(status_code=e.error_code, detail=e.error_message) from e
except Exception as e:
# Convert other exceptions to HTTPException
logger.error(f"Error before streaming started: {str(e)}")
raise HTTPException(status_code=500, detail=str(e)) from e

# Success! Now create generator that replays first chunk + rest
async def response_generator():
# Yield the first chunk we already got
yield first_chunk

try:
# Continue with the rest of the stream
async for chunk in async_gen:
yield chunk
except Exception as e:
# Error occurred mid-stream - HTTP status already sent
# Send error as SSE event to inform the client
logger.error(f"Error during streaming: {str(e)}")

error_message = str(e)
error_event = {
"error": {
"message": error_message,
"type": "stream_error",
"code": "provider_error"
}
}
yield f"data: {json.dumps(error_event)}\n\n".encode()

# Send [DONE] to properly close the stream
yield b"data: [DONE]\n\n"

# Set appropriate headers for streaming
headers = {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no", # Prevent Nginx buffering
}

return StreamingResponse(
response_generator(),
media_type="text/event-stream",
headers=headers
)
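
A minimal usage sketch of this helper, mirroring the proxy.py change below. The route path, the fake_provider_stream stub, and the logging setup are illustrative assumptions, not part of this PR:

import inspect
import logging
from typing import AsyncGenerator

from fastapi import APIRouter

from app.api.routes import wrap_streaming_response_with_error_handling

router = APIRouter()
logger = logging.getLogger("proxy")


async def fake_provider_stream() -> AsyncGenerator[bytes, None]:
    # Stand-in for the provider service response; yields SSE-formatted chunks.
    yield b'data: {"choices": [{"delta": {"content": "hello"}}]}\n\n'
    yield b"data: [DONE]\n\n"


@router.post("/v1/chat/completions")
async def chat_completion_example():
    response = fake_provider_stream()

    # Streaming responses are async generators; the wrapper turns failures
    # before the first chunk into proper HTTP errors and mid-stream failures
    # into an SSE error event followed by [DONE].
    if inspect.isasyncgen(response):
        return await wrap_streaming_response_with_error_handling(logger, response)

    # Non-streaming responses are returned as JSON directly.
    return response

With this wrapper in place, a client reading the SSE stream sees a frame like data: {"error": {"message": "...", "type": "stream_error", "code": "provider_error"}} followed by data: [DONE] when the provider fails mid-stream, instead of the connection simply dropping.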
37 changes: 6 additions & 31 deletions app/api/routes/proxy.py
@@ -23,6 +23,7 @@
from app.models.forge_api_key import ForgeApiKey
from app.models.user import User
from app.services.provider_service import ProviderService
from app.api.routes import wrap_streaming_response_with_error_handling

router = APIRouter()
logger = get_logger(name="proxy")
@@ -96,23 +97,15 @@ async def create_chat_completion(

# Check if it's a streaming response by checking if it's an async generator
if inspect.isasyncgen(response):
# Set appropriate headers for streaming
headers = {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no", # Prevent Nginx buffering
}

return StreamingResponse(
response, media_type="text/event-stream", headers=headers
)
return await wrap_streaming_response_with_error_handling(logger, response)

# Otherwise, return the JSON response directly
return response
except ValueError as err:
logger.exception(f"Error processing chat completion request: {str(err)}")
raise HTTPException(status_code=400, detail=str(err)) from err
except HTTPException as err:
raise err
except Exception as err:
logger.exception(f"Error processing chat completion request: {str(err)}")
raise HTTPException(
@@ -144,16 +137,7 @@ async def create_completion(

# Check if it's a streaming response
if inspect.isasyncgen(response):
headers = {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no", # Prevent Nginx buffering
}

return StreamingResponse(
response, media_type="text/event-stream", headers=headers
)
return await wrap_streaming_response_with_error_handling(logger, response)

# Otherwise, return the JSON response directly
return response
@@ -309,16 +293,7 @@ async def create_responses(

# Check if it's a streaming response
if inspect.isasyncgen(response):
headers = {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no", # Prevent Nginx buffering
}

return StreamingResponse(
response, media_type="text/event-stream", headers=headers
)
return await wrap_streaming_response_with_error_handling(logger, response)

# Otherwise, return the JSON response directly
return response
60 changes: 20 additions & 40 deletions app/services/provider_service.py
@@ -690,21 +690,21 @@ async def process_request(
# re-calculate output tokens
output_tokens = max(output_tokens, total_tokens - input_tokens)

asyncio.create_task(
update_usage_in_background(
usage_tracker_id,
input_tokens,
output_tokens,
cached_tokens,
reasoning_tokens,
if input_tokens > 0 or output_tokens > 0:
asyncio.create_task(
update_usage_in_background(
usage_tracker_id,
input_tokens,
output_tokens,
cached_tokens,
reasoning_tokens,
)
)
)
return result
else:
# For streaming responses, wrap the generator to count tokens
async def token_counting_stream() -> AsyncGenerator[bytes, None]:
approximate_input_tokens = 0
approximate_output_tokens = 0
update_usage = True
output_tokens = 0
input_tokens = 0
total_tokens = 0
@@ -720,13 +720,6 @@ async def token_counting_stream() -> AsyncGenerator[bytes, None]:

messages = payload.get("messages", [])

# Rough estimate of input tokens based on message length
for msg in messages:
content = msg.get("content", "")
if isinstance(content, str):
# Rough approximation: 4 chars ~= 1 token
approximate_input_tokens += len(content) // 4

try:
async for chunk in result:
chunks_processed += 1
@@ -767,21 +760,6 @@

# re-calculate output tokens
output_tokens = max(output_tokens, total_tokens - input_tokens)

# Extract content from the chunk based on OpenAI format
if "choices" in data:
for choice in data["choices"]:
if (
"delta" in choice
and "content" in choice["delta"]
):
content = choice["delta"]["content"]
# Only count tokens if we don't have final usage data
if content:
# Count tokens in content (approx)
approximate_output_tokens += (
len(content) // 4
)
except json.JSONDecodeError:
# If JSON parsing fails, just continue
pass
@@ -799,6 +777,7 @@ async def token_counting_stream() -> AsyncGenerator[bytes, None]:
"Error in streaming response: {}", str(e), exc_info=True
)
# Re-raise to propagate the error
update_usage = False
raise
finally:
logger.debug(
@@ -807,15 +786,16 @@ async def token_counting_stream() -> AsyncGenerator[bytes, None]:
f"output_tokens={output_tokens}, cached_tokens={cached_tokens}, reasoning_tokens={reasoning_tokens}"
)

asyncio.create_task(
update_usage_in_background(
usage_tracker_id,
input_tokens or approximate_input_tokens,
output_tokens or approximate_output_tokens,
cached_tokens,
reasoning_tokens,
if update_usage and (input_tokens > 0 or output_tokens > 0):
asyncio.create_task(
update_usage_in_background(
usage_tracker_id,
input_tokens,
output_tokens,
cached_tokens,
reasoning_tokens,
)
)
)

return token_counting_stream()

1 change: 0 additions & 1 deletion app/services/providers/usage_tracker_service.py
@@ -65,7 +65,6 @@ async def update_usage_tracker(
)
usage_tracker = result.scalar_one_or_none()
now = datetime.now(UTC)
logger.info(f"provider_key: {usage_tracker.provider_key.provider_name}")
price_info = await PricingService.calculate_usage_cost(
db,
usage_tracker.provider_key.provider_name.lower(),