Skip to content

Commit 869fade

Browse files
committed
redis update
1 parent da7fc9d commit 869fade

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

eval_protocol/proxy/proxy_core/app.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55

66
from fastapi import FastAPI, Depends, Request, Query
7-
from typing import Optional, List
7+
from typing import Optional, List, Callable
88
import os
99
import redis
1010
import logging
@@ -105,6 +105,7 @@ def create_app(
105105
auth_provider: AuthProvider = NoAuthProvider(),
106106
preprocess_chat_request: Optional[ChatRequestHook] = None,
107107
preprocess_traces_request: Optional[TracesRequestHook] = None,
108+
extra_routes: Optional[Callable[[FastAPI], None]] = None,
108109
) -> FastAPI:
109110
@asynccontextmanager
110111
async def lifespan(app: FastAPI):
@@ -288,6 +289,9 @@ async def pointwise_get_langfuse_trace(
288289
params=params,
289290
)
290291

292+
if extra_routes is not None:
293+
extra_routes(app)
294+
291295
# Health
292296
@app.get("/health")
293297
async def health():

eval_protocol/proxy/proxy_core/redis_utils.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,12 @@
88

99
logger = logging.getLogger(__name__)
1010

11+
DEFAULT_ROLLOUT_TTL_SECONDS = 60 * 60 * 24
1112

12-
def register_insertion_id(redis_client: redis.Redis, rollout_id: str, insertion_id: str) -> bool:
13+
14+
def register_insertion_id(
15+
redis_client: redis.Redis, rollout_id: str, insertion_id: str, ttl_seconds: int = DEFAULT_ROLLOUT_TTL_SECONDS
16+
) -> bool:
1317
"""Register an insertion_id for a rollout_id in Redis.
1418
1519
Tracks all expected completion insertion_ids for this rollout.
@@ -22,7 +26,10 @@ def register_insertion_id(redis_client: redis.Redis, rollout_id: str, insertion_
2226
True if successful, False otherwise
2327
"""
2428
try:
25-
redis_client.sadd(rollout_id, insertion_id)
29+
pipe = redis_client.pipeline()
30+
pipe.sadd(rollout_id, insertion_id)
31+
pipe.expire(rollout_id, int(ttl_seconds))
32+
pipe.execute()
2633
logger.info(f"Registered insertion_id {insertion_id} for rollout {rollout_id}")
2734
return True
2835
except Exception as e:

0 commit comments

Comments
 (0)