Skip to content

Commit 2ff684b

Browse files
author
Shrey Modi
committed
linterrors
1 parent 38e29dc commit 2ff684b

File tree

11 files changed: +465 -28 lines changed

eval_protocol/pytest/github_action_rollout_processor.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -97,6 +97,7 @@ def _dispatch_workflow():
9797
"model": model,
9898
"metadata": init_request.metadata.model_dump_json(),
9999
"model_base_url": init_request.model_base_url,
100+
"completion_params": json.dumps(init_request.completion_params),
100101
},
101102
}
102103
r = requests.post(url, json=payload, headers=self._headers(), timeout=30)

eval_protocol/pytest/tracing_utils.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -82,21 +82,21 @@ def build_init_request(
8282

8383
# Build completion_params from row and config
8484
completion_params_dict: Dict[str, Any] = {}
85-
85+
8686
# Start with config-level completion_params
8787
if config.completion_params and isinstance(config.completion_params, dict):
8888
completion_params_dict.update(config.completion_params)
89-
89+
9090
# Override with row-specific completion_params
9191
if row.input_metadata and row.input_metadata.completion_params:
9292
row_cp = row.input_metadata.completion_params
9393
if isinstance(row_cp, dict):
9494
completion_params_dict.update(row_cp)
95-
95+
9696
# Validate model is present
9797
if not completion_params_dict.get("model"):
9898
raise ValueError("Model must be provided in completion_params")
99-
99+
100100
# Extract base_url from completion_params
101101
completion_params_base_url: Optional[str] = completion_params_dict.get("base_url")
102102

eval_protocol/types/remote_rollout_processor.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,7 @@ class InitRequest(BaseModel):
4646

4747
completion_params: Dict[str, Any] = Field(
4848
default_factory=dict,
49-
description="Completion parameters including model and optional model_kwargs, temperature, etc."
49+
description="Completion parameters including model and optional model_kwargs, temperature, etc.",
5050
)
5151
elastic_search_config: Optional[ElasticsearchConfig] = None
5252
messages: Optional[List[Message]] = None

tests/github_actions/rollout_worker.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -58,13 +58,20 @@ def main():
5858

5959
try:
6060
completion_kwargs = {"model": args.model, "messages": messages}
61-
62-
if completion_params.get("model_kwargs"):
63-
completion_kwargs.update(completion_params["model_kwargs"])
61+
# Parse and apply completion_params if provided
62+
if args.completion_params:
63+
try:
64+
cp = json.loads(args.completion_params)
65+
if cp.get("model_kwargs"):
66+
completion_kwargs.update(cp["model_kwargs"])
67+
print(f" Applied model_kwargs: {cp.get('model_kwargs')}")
68+
except Exception as e:
69+
print(f"⚠️ Failed to parse completion_params: {e}")
6470

6571
client = OpenAI(base_url=args.model_base_url, api_key=os.environ.get("FIREWORKS_API_KEY"))
6672

6773
print("📡 Calling OpenAI completion...")
74+
print(f" Completion kwargs: {completion_kwargs}")
6875
completion = client.chat.completions.create(**completion_kwargs)
6976

7077
print(f"✅ Rollout {rollout_id} completed successfully")

tests/github_actions/test_github_actions_rollout.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -54,7 +54,10 @@ def rows() -> List[EvaluationRow]:
5454

5555

5656
@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)")
57-
@pytest.mark.parametrize("completion_params", [{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}])
57+
@pytest.mark.parametrize(
58+
"completion_params",
59+
[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b", "model_kwargs": {"temperature": 0.5}}],
60+
)
5861
@evaluation_test(
5962
data_loaders=DynamicDataLoader(
6063
generators=[rows],

tests/remote_server/remote_server.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,7 @@ def _worker():
3535
try:
3636
if not req.messages:
3737
raise ValueError("messages is required")
38-
38+
3939
model = req.completion_params.get("model")
4040
if not model:
4141
raise ValueError("model is required in completion_params")
@@ -44,7 +44,7 @@ def _worker():
4444
"model": model,
4545
"messages": req.messages,
4646
}
47-
47+
4848
# Apply model_kwargs if present
4949
if req.completion_params.get("model_kwargs"):
5050
model_kwargs = req.completion_params["model_kwargs"]
@@ -55,7 +55,7 @@ def _worker():
5555
completion_kwargs["tools"] = req.tools
5656

5757
logger.info(f"Final completion_kwargs: {completion_kwargs}")
58-
58+
5959
client = OpenAI(base_url=req.model_base_url, api_key=os.environ.get("FIREWORKS_API_KEY"))
6060

6161
logger.info(f"Sending completion request to model {model}")

tests/remote_server/remote_server_multi_turn.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@ def _worker():
3131
try:
3232
if not req.messages:
3333
raise ValueError("messages is required")
34-
34+
3535
model = req.completion_params.get("model")
3636
if not model:
3737
raise ValueError("model is required in completion_params")
@@ -52,13 +52,12 @@ def _worker():
5252
"What else can you share about this topic?",
5353
]
5454

55-
5655
# First completion (turns 1-2: initial user message + assistant response)
5756
logger.info(f"Turn 1-2: Sending initial completion request to model {model}")
5857
completion = client.chat.completions.create(
5958
model=model,
6059
messages=conversation_history, # type: ignore,
61-
**completion_kwargs
60+
**completion_kwargs,
6261
)
6362
assistant_message = completion.choices[0].message
6463
assistant_content = assistant_message.content or ""

tests/remote_server/test_remote_fireworks.py

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -58,9 +58,10 @@ def rows() -> List[EvaluationRow]:
5858

5959

6060
@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)")
61-
@pytest.mark.parametrize("completion_params", [{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b",
62-
"model_kwargs": {"temperature": 0.5}
63-
}])
61+
@pytest.mark.parametrize(
62+
"completion_params",
63+
[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b", "model_kwargs": {"temperature": 0.5}}],
64+
)
6465
@evaluation_test(
6566
data_loaders=DynamicDataLoader(
6667
generators=[rows],
@@ -84,6 +85,8 @@ async def test_remote_rollout_and_fetch_fireworks(row: EvaluationRow) -> Evaluat
8485
assert row.execution_metadata.rollout_id in ROLLOUT_IDS, (
8586
f"Row rollout_id {row.execution_metadata.rollout_id} should be in tracked rollout_ids: {ROLLOUT_IDS}"
8687
)
87-
assert row.input_metadata.completion_params["model_kwargs"] == {"temperature": 0.5}, "Row should have correct model_kwargs"
88-
88+
assert row.input_metadata.completion_params["model_kwargs"] == {"temperature": 0.5}, (
89+
"Row should have correct model_kwargs"
90+
)
91+
8992
return row

0 commit comments

Comments
 (0)