Skip to content

Commit 68bfdc6

Browse files
committed
address comments
1 parent 3be3dcc commit 68bfdc6

File tree

1 file changed: 41 additions and 39 deletions

1 file changed

+41
-39
lines changed

eval_protocol/proxy/proxy_core/litellm.py

Lines changed: 41 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
import json
66
import base64
77
import logging
8-
import os
98
from uuid6 import uuid7
109
from fastapi import Request, Response, HTTPException
10+
from fastapi.responses import StreamingResponse
1111
import redis
12+
import openai
1213
from litellm import acompletion
1314

1415
from .redis_utils import register_insertion_id
@@ -17,20 +18,6 @@
1718
logger = logging.getLogger(__name__)
1819

1920

20-
def _configure_langfuse_otel(config: ProxyConfig, project_id: str) -> None:
21-
"""Configure Langfuse OTEL credentials via environment variables."""
22-
public_key = config.langfuse_keys[project_id]["public_key"]
23-
secret_key = config.langfuse_keys[project_id]["secret_key"]
24-
25-
os.environ["LANGFUSE_PUBLIC_KEY"] = public_key
26-
os.environ["LANGFUSE_SECRET_KEY"] = secret_key
27-
os.environ.setdefault("LANGFUSE_HOST", config.langfuse_host)
28-
29-
logger.info(
30-
f"Langfuse OTEL configured: project={project_id}, host={os.environ['LANGFUSE_HOST']}, public_key={public_key[:20]}..."
31-
)
32-
33-
3421
async def handle_chat_completion(
3522
config: ProxyConfig,
3623
redis_client: redis.Redis,
@@ -78,7 +65,7 @@ async def handle_chat_completion(
7865
if auth_header.startswith("Bearer "):
7966
data["api_key"] = auth_header.replace("Bearer ", "").strip()
8067

81-
# Build metadata with tags for Langfuse OTEL
68+
# Build metadata with tags for Langfuse
8269
insertion_id = None
8370
metadata = data.pop("metadata", {}) or {}
8471
tags = list(metadata.pop("tags", []) or [])
@@ -96,47 +83,62 @@ async def handle_chat_completion(
9683
]
9784
)
9885

99-
# Configure Langfuse OTEL
100-
_configure_langfuse_otel(config, project_id)
101-
102-
# Build Langfuse OTEL metadata (becomes span attributes prefixed with langfuse.*)
86+
# Build Langfuse metadata (tags, trace context)
10387
litellm_metadata = {"tags": tags, **metadata}
10488
if rollout_id is not None:
10589
litellm_metadata["trace_id"] = rollout_id
10690
litellm_metadata["generation_name"] = f"chat-{insertion_id}"
10791

92+
langfuse_keys = config.langfuse_keys[project_id]
93+
94+
# Check if streaming is requested
95+
is_streaming = data.get("stream", False)
96+
10897
try:
10998
# Make the completion call - pass all params through
11099
response = await acompletion(
111100
**data,
112101
metadata=litellm_metadata,
113102
timeout=config.request_timeout,
103+
langfuse_public_key=langfuse_keys["public_key"],
104+
langfuse_secret_key=langfuse_keys["secret_key"],
114105
)
115106

116107
# Register insertion_id in Redis on success
117108
if insertion_id is not None and rollout_id is not None:
118109
register_insertion_id(redis_client, rollout_id, insertion_id)
119110

120-
# Convert ModelResponse to JSON
121-
return Response(
122-
content=response.model_dump_json(),
123-
status_code=200,
124-
media_type="application/json",
125-
)
111+
if is_streaming:
112+
# For streaming, return a StreamingResponse with SSE format
113+
async def stream_generator():
114+
async for chunk in response: # type: ignore[union-attr]
115+
yield f"data: {chunk.model_dump_json()}\n\n"
116+
yield "data: [DONE]\n\n"
117+
118+
return StreamingResponse(
119+
stream_generator(),
120+
media_type="text/event-stream",
121+
headers={
122+
"Cache-Control": "no-cache",
123+
"Connection": "keep-alive",
124+
},
125+
)
126+
else:
127+
# Non-streaming: return JSON response
128+
return Response(
129+
content=response.model_dump_json(),
130+
status_code=200,
131+
media_type="application/json",
132+
)
126133

127134
except HTTPException:
128135
raise
129-
except Exception as e:
130-
logger.error(f"LiteLLM error: {e}", exc_info=True)
131-
return Response(
132-
content=json.dumps(
133-
{
134-
"error": {
135-
"message": str(e),
136-
"type": type(e).__name__,
137-
}
138-
}
139-
),
140-
status_code=500,
141-
media_type="application/json",
136+
except openai.APIError as e:
137+
# Convert to HTTPException and let FastAPI handle it
138+
raise HTTPException(
139+
status_code=getattr(e, "status_code", 500),
140+
detail=str(e),
142141
)
142+
except Exception as e:
143+
logger.error(f"Unexpected error: {e}", exc_info=True)
144+
raise HTTPException(status_code=500, detail=str(e))

0 commit comments

Comments
 (0)