-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtracing_model.py
More file actions
155 lines (123 loc) · 5.95 KB
/
tracing_model.py
File metadata and controls
155 lines (123 loc) · 5.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""
Custom model classes for integrating mini-swe-agent with eval-protocol's tracing infrastructure.
## Why This File Exists
mini-swe-agent is an autonomous agent that makes 20-100+ LLM API calls per SWE-bench instance
(e.g., reading files, editing code, running tests). To debug agent behavior and display results
in eval-protocol's UI, we need to capture and analyze every LLM call.
This file bridges mini-swe-agent (which uses LitellmModel) with the Fireworks tracing proxy
(which requires specific URL patterns and SDK usage).
## What These Classes Do
### FireworksCompatibleModel (Base)
- Extends mini-swe-agent's LitellmModel
- Handles Fireworks API compatibility:
* Strips non-standard message fields that Fireworks API rejects
* Forces drop_params=False so stop sequences are not stripped from requests
* Applies temperature/reasoning/max_tokens overrides from wrapper script
- Used when tracing isn't needed (direct Fireworks API calls)
### TracingFireworksModel (For eval-protocol integration)
- Extends FireworksCompatibleModel
- Routes ALL LLM calls through Fireworks tracing proxy instead of direct API
- Uses OpenAI SDK (not litellm) to preserve full model names
"""
import os
from urllib.parse import urlparse
from minisweagent.models.litellm_model import LitellmModel
import litellm
def _sanitize_openai_compatible_base_url(raw: str) -> str:
    """
    Reduce a full endpoint URL to the api_base an OpenAI-compatible client expects.

    Surrounding whitespace and trailing slashes are removed first; then at most
    one endpoint suffix ("/chat/completions", otherwise "/completions") is cut
    off the end. A falsy ``raw`` (empty string or None) yields "".

    Examples:
    - https://api.fireworks.ai/inference/v1/chat/completions -> https://api.fireworks.ai/inference/v1
    - https://api.fireworks.ai/inference/v1/ -> https://api.fireworks.ai/inference/v1
    """
    base = raw.strip() if raw else ""
    base = base.rstrip("/")
    # Candidates are tried in order and only the first match is removed.
    matched = next(
        (tail for tail in ("/chat/completions", "/completions") if base.endswith(tail)),
        None,
    )
    if matched is None:
        return base
    return base[: len(base) - len(matched)]
def _infer_use_litellm_proxy(api_base: str | None) -> bool:
"""
Heuristic:
- If caller explicitly sets USE_LITELLM_PROXY, honor it.
- Otherwise, if api_base points at api.fireworks.ai, treat it as direct Fireworks inference (no proxy prefix).
- Else assume it's a LiteLLM/OpenAI-compatible proxy (e.g. tracing.fireworks.ai, localhost) and use litellm_proxy/.
"""
forced = os.environ.get("USE_LITELLM_PROXY")
if forced is not None:
return forced.strip() in ("1", "true", "TRUE", "yes", "YES")
if not api_base:
return False
try:
host = urlparse(api_base).hostname or ""
except Exception:
return True
return host != "api.fireworks.ai"
class FireworksCompatibleModel(LitellmModel):
    """
    Fireworks-compatible wrapper for LitellmModel.

    On construction it resolves the model name and api_base from environment
    variables, forces a handful of request parameters (drop_params, max_tokens,
    default temperature), applies optional per-run overrides, and then defers to
    LitellmModel. On every query it strips message fields outside a small
    OpenAI-compatible whitelist.
    """

    # Optional message keys forwarded to the API next to the mandatory
    # "role"/"content". "tool_call_id" is what pairs a role "tool" message with
    # the assistant tool call it answers in OpenAI-compatible chat APIs.
    _OPTIONAL_MESSAGE_FIELDS = ("tool_calls", "tool_call_id", "name")

    def __init__(self, **kwargs):
        """
        Build the model configuration and hand it to LitellmModel.

        Environment variables read:
        - FIREWORKS_MODEL_ID: if set, replaces ``model_name``.
        - FIREWORKS_API_KEY: copied into LITELLM_PROXY_API_KEY in proxy mode.
        - LITELLM_API_BASE (or TRACING_BASE_URL for back-compat): api_base override.

        Side effects: may set ``litellm.api_base`` and
        ``os.environ["LITELLM_PROXY_API_KEY"]`` process-wide.

        Args:
            **kwargs: Forwarded to ``LitellmModel.__init__`` after adjustment.
                ``model_kwargs`` may be omitted or None; the caller's dict is
                copied, never modified.
        """
        model_id = os.environ.get("FIREWORKS_MODEL_ID")
        if model_id:
            kwargs["model_name"] = model_id

        # Work on a private copy so the caller's dict is not mutated, and so a
        # missing or None model_kwargs is handled the same way.
        model_kwargs = dict(kwargs.get("model_kwargs") or {})
        kwargs["model_kwargs"] = model_kwargs

        fw_key = os.environ.get("FIREWORKS_API_KEY")
        # Optional override of OpenAI-compatible api_base (used for either tracing proxy or direct inference).
        # Back-compat: TRACING_BASE_URL.
        api_base_raw = os.environ.get("LITELLM_API_BASE") or os.environ.get("TRACING_BASE_URL")
        api_base = _sanitize_openai_compatible_base_url(api_base_raw) if api_base_raw else None
        if api_base:
            litellm.api_base = api_base

        mn = kwargs.get("model_name")
        if _infer_use_litellm_proxy(api_base):
            # When routing through a LiteLLM proxy, use the proxy api key env var that LiteLLM expects.
            if fw_key:
                os.environ["LITELLM_PROXY_API_KEY"] = fw_key
            if isinstance(mn, str) and not mn.startswith("litellm_proxy/"):
                kwargs["model_name"] = f"litellm_proxy/{mn}"

        # CRITICAL: Set drop_params to False so stop sequences aren't stripped!
        # NOTE: no stop sequences are injected here; any "stop" value already
        # present in model_kwargs is passed through unchanged.
        model_kwargs["drop_params"] = False

        # Increased to avoid truncating long heredocs/commands. This replaces any
        # caller-supplied max_tokens; only the wrapper override below can change it.
        model_kwargs["max_tokens"] = 20000
        model_kwargs.setdefault("temperature", 0.0)

        # Apply per-run overrides injected by the wrapper (no environment variables)
        overrides = globals().get("WRAPPER_MODEL_OVERRIDES")
        if isinstance(overrides, dict):
            reasoning = overrides.get("reasoning")
            if reasoning in ("low", "medium", "high"):
                model_kwargs["reasoning_effort"] = reasoning
            if overrides.get("temperature") is not None:
                try:
                    model_kwargs["temperature"] = float(overrides["temperature"])
                except (TypeError, ValueError, OverflowError):
                    # Best-effort: an unusable override leaves the current value in place.
                    pass
            if overrides.get("max_tokens") is not None:
                try:
                    model_kwargs["max_tokens"] = int(overrides["max_tokens"])
                except (TypeError, ValueError, OverflowError):
                    # Best-effort: an unusable override leaves the current value in place.
                    pass

        super().__init__(**kwargs)

    def _query(self, messages: list[dict[str, str]], **kwargs):
        """
        Remove non-standard fields before sending to Fireworks API.

        Each outgoing message keeps "role" and "content" (both required; a
        message lacking either raises KeyError) plus whichever of "tool_calls",
        "tool_call_id" and "name" it carries. Everything else is dropped.

        Args:
            messages: Conversation messages; the input dicts are not modified.
            **kwargs: Extra query arguments, forwarded to the parent ``_query``.
                ``drop_params`` defaults to False when the caller does not set it.

        Returns:
            Whatever the parent class's ``_query`` returns.
        """
        # Keep only standard OpenAI-compatible fields
        clean_messages = []
        for msg in messages:
            clean_msg = {"role": msg["role"], "content": msg["content"]}
            clean_msg.update(
                (field, msg[field]) for field in self._OPTIONAL_MESSAGE_FIELDS if field in msg
            )
            clean_messages.append(clean_msg)

        # IMPORTANT: Ensure drop_params stays False in the actual query
        query_kwargs = dict(kwargs)
        query_kwargs.setdefault("drop_params", False)
        return super()._query(clean_messages, **query_kwargs)