Skip to content

Commit 2dad518

Browse files
author
Shrey Modi
committed
linterrors
1 parent 71f4165 commit 2dad518

File tree

7 files changed

+182
-145
lines changed

7 files changed

+182
-145
lines changed

examples/swebench/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,4 +54,4 @@ pytest examples/swebench/tests/test_swebench.py -v -s
5454
Notes
5555
- The test currently generates 10 rows by numeric index (0–9)
5656
- Each request triggers the server to run one SWE-bench instance and write to its own `row_{index}`
57-
- Control harness workers via: `export SWEBENCH_EVAL_WORKERS=5`
57+
- Control harness workers via: `export SWEBENCH_EVAL_WORKERS=5`

examples/swebench/SWE-bench

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Subproject commit 5cd4be9fb23971679cbbafe5a0ecade27cef99be

examples/swebench/run_swe_agent_fw.py

Lines changed: 94 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -45,113 +45,110 @@ class FireworksCompatibleModel(LitellmModel):
4545
"""
4646

4747
def __init__(self, **kwargs):
48-
if model_id := os.environ.get('FIREWORKS_MODEL_ID'):
49-
kwargs['model_name'] = model_id
48+
if model_id := os.environ.get("FIREWORKS_MODEL_ID"):
49+
kwargs["model_name"] = model_id
5050
print(f"kwargs: {kwargs}")
51-
if 'model_kwargs' not in kwargs:
52-
kwargs['model_kwargs'] = {}
53-
51+
if "model_kwargs" not in kwargs:
52+
kwargs["model_kwargs"] = {}
53+
5454
# CRITICAL: Set drop_params to False so stop sequences aren't stripped!
55-
kwargs['model_kwargs']['drop_params'] = False
56-
55+
kwargs["model_kwargs"]["drop_params"] = False
56+
5757
# Get existing stop sequences
58-
existing_stop = kwargs['model_kwargs'].get('stop', [])
58+
existing_stop = kwargs["model_kwargs"].get("stop", [])
5959
if isinstance(existing_stop, str):
6060
existing_stop = [existing_stop]
6161
elif existing_stop is None:
6262
existing_stop = []
63-
63+
6464
# Add stop sequences (only the non-natural ones)
6565
stop_sequences = existing_stop + [
66-
# ASCII versions
66+
# ASCII versions
6767
"<|User|>",
6868
"<|Assistant|>",
69-
7069
# Full-width PIPE versions (U+FF5C)
71-
"<|User|>", # \uff5c
70+
"<|User|>", # \uff5c
7271
"<|Assistant|>",
7372
"```<|",
7473
"<|User",
7574
"<|Ass",
76-
77-
# Full-width LETTER L versions (U+FF4C)
78-
"<lUser|>", # \uff4c
75+
# Full-width LETTER L versions (U+FF4C)
76+
"<lUser|>", # \uff4c
7977
"<lAssistant|>",
8078
"```<l",
8179
"<lUser",
8280
"<lAss",
8381
]
84-
kwargs['model_kwargs']['stop'] = stop_sequences
85-
kwargs['model_kwargs']['max_tokens'] = 1024 # Reduce to 1024 to save tokens
86-
87-
if 'temperature' not in kwargs['model_kwargs']:
88-
kwargs['model_kwargs']['temperature'] = 0.0
82+
kwargs["model_kwargs"]["stop"] = stop_sequences
83+
kwargs["model_kwargs"]["max_tokens"] = 1024 # Reduce to 1024 to save tokens
84+
85+
if "temperature" not in kwargs["model_kwargs"]:
86+
kwargs["model_kwargs"]["temperature"] = 0.0
8987

9088
# Apply per-run overrides injected by the wrapper (no environment variables)
91-
overrides = globals().get('WRAPPER_MODEL_OVERRIDES')
89+
overrides = globals().get("WRAPPER_MODEL_OVERRIDES")
9290
if isinstance(overrides, dict):
93-
if overrides.get('reasoning') in ('low', 'medium', 'high'):
94-
kwargs['model_kwargs']['reasoning_effort'] = overrides['reasoning']
95-
if overrides.get('temperature') is not None:
91+
if overrides.get("reasoning") in ("low", "medium", "high"):
92+
kwargs["model_kwargs"]["reasoning_effort"] = overrides["reasoning"]
93+
if overrides.get("temperature") is not None:
9694
try:
97-
kwargs['model_kwargs']['temperature'] = float(overrides['temperature'])
95+
kwargs["model_kwargs"]["temperature"] = float(overrides["temperature"])
9896
except Exception:
9997
pass
100-
if overrides.get('max_tokens') is not None:
98+
if overrides.get("max_tokens") is not None:
10199
try:
102-
kwargs['model_kwargs']['max_tokens'] = int(overrides['max_tokens'])
100+
kwargs["model_kwargs"]["max_tokens"] = int(overrides["max_tokens"])
103101
except Exception:
104102
pass
105-
103+
106104
super().__init__(**kwargs)
107105

108106
def _query(self, messages: list[dict[str, str]], **kwargs):
109107
"""Remove non-standard fields before sending to Fireworks API."""
110108
# Keep only standard OpenAI-compatible fields
111109
clean_messages = []
112110
for msg in messages:
113-
clean_msg = {
114-
"role": msg["role"],
115-
"content": msg["content"]
116-
}
111+
clean_msg = {"role": msg["role"], "content": msg["content"]}
117112
if "tool_calls" in msg:
118113
clean_msg["tool_calls"] = msg["tool_calls"]
119114
if "name" in msg:
120115
clean_msg["name"] = msg["name"]
121116
clean_messages.append(clean_msg)
122-
117+
123118
# IMPORTANT: Ensure drop_params stays False in the actual query
124119
kwargs_with_stop = kwargs.copy()
125-
if 'drop_params' not in kwargs_with_stop:
126-
kwargs_with_stop['drop_params'] = False
127-
120+
if "drop_params" not in kwargs_with_stop:
121+
kwargs_with_stop["drop_params"] = False
122+
128123
return super()._query(clean_messages, **kwargs_with_stop)
129124

125+
130126
def __get_api_key():
131127
"""Get Fireworks API key from environment or mini-swe-agent config."""
132128
# Environment variable takes precedence
133-
if api_key := os.environ.get('FIREWORKS_API_KEY'):
129+
if api_key := os.environ.get("FIREWORKS_API_KEY"):
134130
return api_key
135131

136132
# Try to get API key from mini-swe-agent's config system
137133
try:
138134
from minisweagent.config import get_config
135+
139136
config = get_config()
140-
return config.get('FIREWORKS_API_KEY')
137+
return config.get("FIREWORKS_API_KEY")
141138
except (ImportError, AttributeError, KeyError):
142139
# Fallback: check common config file locations
143140
config_paths = [
144141
Path.home() / ".config" / "mini-swe-agent" / ".env",
145-
Path.home() / "Library" / "Application Support" / "mini-swe-agent" / ".env"
142+
Path.home() / "Library" / "Application Support" / "mini-swe-agent" / ".env",
146143
]
147144

148145
for config_path in config_paths:
149146
if config_path.exists():
150147
try:
151148
with open(config_path) as f:
152149
for line in f:
153-
if line.startswith('FIREWORKS_API_KEY='):
154-
value = line.split('=', 1)[1].strip()
150+
if line.startswith("FIREWORKS_API_KEY="):
151+
value = line.split("=", 1)[1].strip()
155152
return value.strip("'\"")
156153
except (IOError, OSError):
157154
continue
@@ -170,7 +167,7 @@ def __test_model(model_id):
170167
return False
171168

172169
# Configure environment for litellm
173-
os.environ['FIREWORKS_API_KEY'] = api_key
170+
os.environ["FIREWORKS_API_KEY"] = api_key
174171
# Assume model_id is fully qualified
175172
model_name = model_id
176173

@@ -182,7 +179,7 @@ def __test_model(model_id):
182179
model=model_name,
183180
messages=[{"role": "user", "content": "Test message. Reply with OK."}],
184181
temperature=0.0,
185-
max_tokens=10
182+
max_tokens=10,
186183
)
187184

188185
print(f"Success. Response: {response.choices[0].message.content}")
@@ -201,8 +198,6 @@ def __validate_environment():
201198
print("Set it with: mini-extra config set FIREWORKS_API_KEY <key>")
202199

203200

204-
205-
206201
def __build_command(args, wrapper_module_path):
207202
"""Build mini-swe-agent command with appropriate arguments."""
208203
# Construct model class path
@@ -212,12 +207,17 @@ def __build_command(args, wrapper_module_path):
212207
# Base command - assume model_id is fully qualified
213208
cmd = [
214209
sys.executable,
215-
"-m", "minisweagent.run.mini_extra",
210+
"-m",
211+
"minisweagent.run.mini_extra",
216212
"swebench-single" if args.single is not None else "swebench",
217-
"--model", args.model_id,
218-
"--model-class", model_class,
219-
"--subset", args.subset,
220-
"--split", args.split
213+
"--model",
214+
args.model_id,
215+
"--model-class",
216+
model_class,
217+
"--subset",
218+
args.subset,
219+
"--split",
220+
args.split,
221221
]
222222
if args.model_class:
223223
cmd.extend(["--model-class", args.model_class])
@@ -230,18 +230,26 @@ def __build_command(args, wrapper_module_path):
230230
if args.single is not None:
231231
# Use batch mode for a single index via slice and write to a per-row directory
232232
from pathlib import Path
233-
slice_spec = f"{args.single}:{args.single+1}"
233+
234+
slice_spec = f"{args.single}:{args.single + 1}"
234235
row_dir = str((Path(args.output) if args.output else Path.cwd()) / f"row_{args.single}")
235236
cmd = [
236237
sys.executable,
237-
"-m", "minisweagent.run.mini_extra",
238+
"-m",
239+
"minisweagent.run.mini_extra",
238240
"swebench",
239-
"--model", args.model_id,
240-
"--model-class", model_class,
241-
"--subset", args.subset,
242-
"--split", args.split,
243-
"--slice", slice_spec,
244-
"--output", row_dir,
241+
"--model",
242+
args.model_id,
243+
"--model-class",
244+
model_class,
245+
"--subset",
246+
args.subset,
247+
"--split",
248+
args.split,
249+
"--slice",
250+
slice_spec,
251+
"--output",
252+
row_dir,
245253
]
246254
if args.model_class:
247255
cmd.extend(["--model-class", args.model_class])
@@ -253,31 +261,35 @@ def __build_command(args, wrapper_module_path):
253261

254262
return cmd
255263

256-
257-
258264

259265
def main():
260266
parser = argparse.ArgumentParser(
261-
description='Run mini-swe-agent with Fireworks models on SWE-bench',
267+
description="Run mini-swe-agent with Fireworks models on SWE-bench",
262268
formatter_class=argparse.RawDescriptionHelpFormatter,
263-
epilog=__doc__
269+
epilog=__doc__,
264270
)
265271

266272
# Required model ID
267-
parser.add_argument('model_id', help='Fireworks model ID')
268-
parser.add_argument('--model-class', type=str, default=None, help='Optional mini-swe-agent model-class')
273+
parser.add_argument("model_id", help="Fireworks model ID")
274+
parser.add_argument("--model-class", type=str, default=None, help="Optional mini-swe-agent model-class")
269275
# Execution options
270-
parser.add_argument('--instances', type=int, help='Number of instances to run')
271-
parser.add_argument('--workers', type=int, default=1, help='Parallel workers (default: 1)')
272-
parser.add_argument('--output', help='Output directory')
273-
parser.add_argument('--subset', default='verified', choices=['verified', 'lite', 'full'])
274-
parser.add_argument('--split', default='test', choices=['dev', 'test'])
275-
parser.add_argument('--single', type=int, metavar='INDEX', help='Run single instance')
276-
parser.add_argument('--exit-immediately', action='store_true')
277-
parser.add_argument('--test', action='store_true', help='Test model connectivity')
278-
parser.add_argument('--reasoning', type=str, choices=['low', 'medium', 'high'], default=None, help='Provider-specific reasoning effort')
279-
parser.add_argument('--temperature', type=float, default=None, help='Model temperature override')
280-
parser.add_argument('--max-tokens', type=int, default=None, help='Max tokens override')
276+
parser.add_argument("--instances", type=int, help="Number of instances to run")
277+
parser.add_argument("--workers", type=int, default=1, help="Parallel workers (default: 1)")
278+
parser.add_argument("--output", help="Output directory")
279+
parser.add_argument("--subset", default="verified", choices=["verified", "lite", "full"])
280+
parser.add_argument("--split", default="test", choices=["dev", "test"])
281+
parser.add_argument("--single", type=int, metavar="INDEX", help="Run single instance")
282+
parser.add_argument("--exit-immediately", action="store_true")
283+
parser.add_argument("--test", action="store_true", help="Test model connectivity")
284+
parser.add_argument(
285+
"--reasoning",
286+
type=str,
287+
choices=["low", "medium", "high"],
288+
default=None,
289+
help="Provider-specific reasoning effort",
290+
)
291+
parser.add_argument("--temperature", type=float, default=None, help="Model temperature override")
292+
parser.add_argument("--max-tokens", type=int, default=None, help="Max tokens override")
281293
args = parser.parse_args()
282294

283295
# Handle test mode
@@ -291,11 +303,11 @@ def main():
291303
if args.output is None:
292304
safe_model_id = args.model_id.replace("/", "-").replace(":", "-")
293305
script_dir = Path(__file__).parent.resolve()
294-
args.output = str(script_dir / f'swebench-{safe_model_id}-results')
306+
args.output = str(script_dir / f"swebench-{safe_model_id}-results")
295307

296308
# Create temporary module for importing FireworksCompatibleModel
297-
with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
298-
with open(__file__, 'r') as current_file:
309+
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
310+
with open(__file__, "r") as current_file:
299311
f.write(current_file.read())
300312
# Inject per-run model overrides directly into the temp module
301313
f.write("\n# --- Injected by wrapper: per-run model overrides ---\n")
@@ -309,14 +321,14 @@ def main():
309321
try:
310322
# Configure environment
311323
env = os.environ.copy()
312-
env['PYTHONPATH'] = f"{temp_module_path.parent}:{env.get('PYTHONPATH', '')}"
324+
env["PYTHONPATH"] = f"{temp_module_path.parent}:{env.get('PYTHONPATH', '')}"
313325
# Pass the fully qualified model path to the subprocess
314-
env['FIREWORKS_MODEL_ID'] = args.model_id
326+
env["FIREWORKS_MODEL_ID"] = args.model_id
315327

316328
# Ensure API key is passed to subprocess
317329
api_key = __get_api_key()
318330
if api_key:
319-
env['FIREWORKS_API_KEY'] = api_key
331+
env["FIREWORKS_API_KEY"] = api_key
320332

321333
# No environment variables for model kwargs; overrides are injected into the temp module
322334

@@ -343,5 +355,5 @@ def main():
343355
temp_module_path.unlink()
344356

345357

346-
if __name__ == '__main__':
358+
if __name__ == "__main__":
347359
main()

0 commit comments

Comments
 (0)