Skip to content

Commit d00561f

Browse files
committed
reuse session
1 parent dca171a commit d00561f

File tree

2 files changed

+32
-16
lines changed

2 files changed

+32
-16
lines changed

eval_protocol/adapters/fireworks_tracing.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,8 @@ def __init__(
264264
self.project_id = project_id
265265
self.base_url = base_url.rstrip("/")
266266
self.timeout = timeout
267+
# Reuse a single session for connection pooling and to avoid leaking FDs.
268+
self._session = requests.Session()
267269

268270
def search_logs(self, tags: List[str], limit: int = 100, hours_back: int = 24) -> List[Dict[str, Any]]:
269271
"""Fetch logs from Fireworks tracing gateway /logs endpoint.
@@ -287,14 +289,14 @@ def search_logs(self, tags: List[str], limit: int = 100, hours_back: int = 24) -
287289
last_error: Optional[str] = None
288290
for url in urls_to_try:
289291
try:
290-
response = requests.get(url, params=params, timeout=self.timeout, headers=headers)
291-
if response.status_code == 404:
292-
# Try next variant
293-
last_error = f"404 for {url}"
294-
continue
295-
response.raise_for_status()
296-
data = response.json() or {}
297-
break
292+
with self._session.get(url, params=params, timeout=self.timeout, headers=headers) as response:
293+
if response.status_code == 404:
294+
# Try next variant; leaving the `with` block closes the response and releases the connection
295+
last_error = f"404 for {url}"
296+
continue
297+
response.raise_for_status()
298+
data = response.json() or {}
299+
break
298300
except requests.exceptions.RequestException as e:
299301
last_error = str(e)
300302
continue
@@ -412,9 +414,9 @@ def get_evaluation_rows(
412414

413415
result = None
414416
try:
415-
response = requests.get(url, params=params, timeout=self.timeout, headers=headers)
416-
response.raise_for_status()
417-
result = response.json()
417+
with self._session.get(url, params=params, timeout=self.timeout, headers=headers) as response:
418+
response.raise_for_status()
419+
result = response.json()
418420
except requests.exceptions.HTTPError as e:
419421
error_msg = str(e)
420422

@@ -451,3 +453,10 @@ def get_evaluation_rows(
451453

452454
logger.info("Successfully converted %d traces to evaluation rows", len(eval_rows))
453455
return eval_rows
456+
457+
def close(self) -> None:
458+
"""Close the shared requests session and release its pooled connections."""
459+
try:
460+
self._session.close()
461+
except Exception:
462+
pass

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Any, Dict, List, Optional, Callable
44

55
import requests
6+
from requests.adapters import HTTPAdapter
67

78
from eval_protocol.models import EvaluationRow, Status
89
from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
@@ -55,6 +56,8 @@ def __init__(
5556
self._output_data_loader = output_data_loader or default_fireworks_output_data_loader
5657
self._tracing_adapter = FireworksTracingAdapter(base_url=self._model_base_url)
5758

59+
self._session = requests.Session()
60+
5861
def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
5962
tasks: List[asyncio.Task[EvaluationRow]] = []
6063

@@ -94,8 +97,8 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
9497
def _post_init() -> None:
9598
url = f"{remote_base_url}/init"
9699
try:
97-
r = requests.post(url, json=init_payload.model_dump(), timeout=300)
98-
r.raise_for_status()
100+
with self._session.post(url, json=init_payload.model_dump(), timeout=300) as r:
101+
r.raise_for_status()
99102
except requests.exceptions.Timeout:
100103
raise TimeoutError(
101104
f"The /init endpoint tried {url} with {init_payload.model_dump()} but timed out after 300 seconds."
@@ -108,9 +111,9 @@ def _post_init() -> None:
108111

109112
def _get_status() -> Dict[str, Any]:
110113
url = f"{remote_base_url}/status"
111-
r = requests.get(url, params={"rollout_id": row.execution_metadata.rollout_id}, timeout=15)
112-
r.raise_for_status()
113-
return r.json()
114+
with self._session.get(url, params={"rollout_id": row.execution_metadata.rollout_id}, timeout=15) as r:
115+
r.raise_for_status()
116+
return r.json()
114117

115118
continue_polling_status = True
116119
while time.time() < deadline:
@@ -204,4 +207,8 @@ async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
204207
return tasks
205208

206209
def cleanup(self) -> None:
210+
try:
211+
self._session.close()
212+
except Exception:
213+
pass
207214
return None

0 commit comments

Comments
 (0)