55import json
66import base64
77import logging
8- import os
98from uuid6 import uuid7
109from fastapi import Request , Response , HTTPException
10+ from fastapi .responses import StreamingResponse
1111import redis
12+ import openai
1213from litellm import acompletion
1314
1415from .redis_utils import register_insertion_id
1718logger = logging .getLogger (__name__ )
1819
1920
def _configure_langfuse_otel(config: ProxyConfig, project_id: str) -> None:
    """Publish one project's Langfuse credentials through environment variables.

    The key pair stored under ``config.langfuse_keys[project_id]`` is written to
    ``LANGFUSE_PUBLIC_KEY`` and ``LANGFUSE_SECRET_KEY``, replacing any previous
    values. ``LANGFUSE_HOST`` is only filled in from ``config.langfuse_host``
    when the variable is not already set.

    Args:
        config: Proxy configuration exposing ``langfuse_keys`` and
            ``langfuse_host``.
        project_id: Key into ``config.langfuse_keys`` selecting the key pair.

    Raises:
        KeyError: If ``project_id`` or one of its ``"public_key"`` /
            ``"secret_key"`` entries is missing.
    """
    # NOTE(review): os.environ is shared by the whole process, so calls made
    # for different projects overwrite each other's keys -- confirm that the
    # callers never interleave.
    project_keys = config.langfuse_keys[project_id]
    pub = project_keys["public_key"]
    sec = project_keys["secret_key"]

    env = os.environ
    env["LANGFUSE_PUBLIC_KEY"] = pub
    env["LANGFUSE_SECRET_KEY"] = sec
    # setdefault keeps a host that is already present and returns the value
    # that is in effect afterwards.
    active_host = env.setdefault("LANGFUSE_HOST", config.langfuse_host)

    # Only the first 20 characters of the public key are logged.
    key_prefix = pub[:20]
    logger.info(
        f"Langfuse OTEL configured: project={project_id} , host={active_host} , public_key={key_prefix} ..."
    )
32-
33-
3421async def handle_chat_completion (
3522 config : ProxyConfig ,
3623 redis_client : redis .Redis ,
@@ -78,7 +65,7 @@ async def handle_chat_completion(
7865 if auth_header .startswith ("Bearer " ):
7966 data ["api_key" ] = auth_header .replace ("Bearer " , "" ).strip ()
8067
81- # Build metadata with tags for Langfuse OTEL
68+ # Build metadata with tags for Langfuse
8269 insertion_id = None
8370 metadata = data .pop ("metadata" , {}) or {}
8471 tags = list (metadata .pop ("tags" , []) or [])
@@ -96,47 +83,62 @@ async def handle_chat_completion(
9683 ]
9784 )
9885
99- # Configure Langfuse OTEL
100- _configure_langfuse_otel (config , project_id )
101-
102- # Build Langfuse OTEL metadata (becomes span attributes prefixed with langfuse.*)
86+ # Build Langfuse metadata (tags, trace context)
10387 litellm_metadata = {"tags" : tags , ** metadata }
10488 if rollout_id is not None :
10589 litellm_metadata ["trace_id" ] = rollout_id
10690 litellm_metadata ["generation_name" ] = f"chat-{ insertion_id } "
10791
92+ langfuse_keys = config .langfuse_keys [project_id ]
93+
94+ # Check if streaming is requested
95+ is_streaming = data .get ("stream" , False )
96+
10897 try :
10998 # Make the completion call - pass all params through
11099 response = await acompletion (
111100 ** data ,
112101 metadata = litellm_metadata ,
113102 timeout = config .request_timeout ,
103+ langfuse_public_key = langfuse_keys ["public_key" ],
104+ langfuse_secret_key = langfuse_keys ["secret_key" ],
114105 )
115106
116107 # Register insertion_id in Redis on success
117108 if insertion_id is not None and rollout_id is not None :
118109 register_insertion_id (redis_client , rollout_id , insertion_id )
119110
120- # Convert ModelResponse to JSON
121- return Response (
122- content = response .model_dump_json (),
123- status_code = 200 ,
124- media_type = "application/json" ,
125- )
111+ if is_streaming :
112+ # For streaming, return a StreamingResponse with SSE format
113+ async def stream_generator ():
114+ async for chunk in response : # type: ignore[union-attr]
115+ yield f"data: { chunk .model_dump_json ()} \n \n "
116+ yield "data: [DONE]\n \n "
117+
118+ return StreamingResponse (
119+ stream_generator (),
120+ media_type = "text/event-stream" ,
121+ headers = {
122+ "Cache-Control" : "no-cache" ,
123+ "Connection" : "keep-alive" ,
124+ },
125+ )
126+ else :
127+ # Non-streaming: return JSON response
128+ return Response (
129+ content = response .model_dump_json (),
130+ status_code = 200 ,
131+ media_type = "application/json" ,
132+ )
126133
127134 except HTTPException :
128135 raise
129- except Exception as e :
130- logger .error (f"LiteLLM error: { e } " , exc_info = True )
131- return Response (
132- content = json .dumps (
133- {
134- "error" : {
135- "message" : str (e ),
136- "type" : type (e ).__name__ ,
137- }
138- }
139- ),
140- status_code = 500 ,
141- media_type = "application/json" ,
136+ except openai .APIError as e :
137+ # Convert to HTTPException and let FastAPI handle it
138+ raise HTTPException (
139+ status_code = getattr (e , "status_code" , 500 ),
140+ detail = str (e ),
142141 )
142+ except Exception as e :
143+ logger .error (f"Unexpected error: { e } " , exc_info = True )
144+ raise HTTPException (status_code = 500 , detail = str (e ))
0 commit comments