Skip to content

Commit de3aba0

Browse files
authored
new tests (#362)
1 parent 0902602 commit de3aba0

File tree

1 file changed

+61
-135
lines changed

1 file changed

+61
-135
lines changed

eval_protocol/benchmarks/test_glm_streaming_compliance.py

Lines changed: 61 additions & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,19 @@
1010
EvaluationRow,
1111
Message,
1212
MetricResult,
13-
ChatCompletionContentPartTextParam,
1413
)
1514
from eval_protocol.pytest.default_single_turn_rollout_process import (
1615
SingleTurnRolloutProcessor,
1716
)
1817
from eval_protocol.pytest.evaluation_test import evaluation_test
1918

2019

21-
DEFAULT_MODEL_ID = "fireworks_ai/accounts/fireworks/models/glm-4p6"
20+
DEFAULT_MODEL_ID = "fireworks_ai/accounts/pyroworks/deployedModels/minimax-m2-zmi4qk9f"
2221
DEFAULT_MAX_TOKENS = 10000
2322

2423

2524
def _coerce_content_to_str(
26-
content: str | list[ChatCompletionContentPartTextParam] | None,
25+
content: str | list[Any] | None,
2726
) -> str:
2827
if isinstance(content, list):
2928
texts: list[str] = []
@@ -153,7 +152,34 @@ def _safe_json_loads(payload: str) -> Any | None:
153152
"content": "Call test_brace_bug with param1='test_value', param2=42, and param3=true",
154153
}
155154
],
156-
"tools": WEATHER_TOOL_DEFINITION,
155+
"tools": [
156+
{
157+
"type": "function",
158+
"function": {
159+
"name": "test_brace_bug",
160+
"description": "A test function to validate JSON brace handling in tool arguments",
161+
"parameters": {
162+
"type": "object",
163+
"properties": {
164+
"param1": {
165+
"type": "string",
166+
"description": "A string parameter",
167+
},
168+
"param2": {
169+
"type": "integer",
170+
"description": "An integer parameter",
171+
},
172+
"param3": {
173+
"type": "boolean",
174+
"description": "A boolean parameter",
175+
},
176+
},
177+
"required": ["param1", "param2", "param3"],
178+
"additionalProperties": False,
179+
},
180+
},
181+
}
182+
],
157183
"temperature": 0.1,
158184
"top_p": 1,
159185
}
@@ -468,48 +494,6 @@ def _safe_json_loads(payload: str) -> Any | None:
468494
"stream": True,
469495
}
470496

471-
PEER_TOOL_RECOVERY_FAILURE_PAYLOAD = {
472-
"messages": [
473-
{
474-
"role": "user",
475-
"content": (
476-
"View the file at /tmp/test.txt. If that fails, try again with the correct parameters. "
477-
"Keep retrying until it works."
478-
),
479-
}
480-
],
481-
"tools": [
482-
{
483-
"type": "function",
484-
"function": {
485-
"name": "view",
486-
"description": "View a file or directory",
487-
"strict": True,
488-
"parameters": {
489-
"type": "object",
490-
"properties": {
491-
"path": {
492-
"type": "string",
493-
"description": "Path to the file or directory to view",
494-
},
495-
"type": {
496-
"type": "string",
497-
"enum": ["file", "directory"],
498-
"description": "Type of the path (file or directory)",
499-
},
500-
},
501-
"required": ["path", "type"],
502-
"additionalProperties": False,
503-
},
504-
},
505-
}
506-
],
507-
"tool_choice": "required",
508-
"temperature": 0.1,
509-
"max_tokens": 4000,
510-
"stream": True,
511-
}
512-
513497

514498
def _build_row_from_payload(case: str, payload: dict[str, Any]) -> EvaluationRow:
515499
messages = [
@@ -1329,47 +1313,50 @@ def test_streaming_multiple_tool_calls(row: EvaluationRow) -> EvaluationRow:
13291313
return row
13301314

13311315

1332-
_PEER_TOOL_MISSING_REQUIRED_ROW = _build_row_from_payload(
1333-
"peer-tool-missing-required-param", PEER_TOOL_MISSING_REQUIRED_PARAM_PAYLOAD
1316+
_PEER_TOOL_REQUIRED_PARAMS_ROW = _build_row_from_payload(
1317+
"peer-tool-required-params", PEER_TOOL_MISSING_REQUIRED_PARAM_PAYLOAD
13341318
)
13351319

13361320

13371321
@evaluation_test(
1338-
input_rows=[[_PEER_TOOL_MISSING_REQUIRED_ROW]],
1322+
input_rows=[[_PEER_TOOL_REQUIRED_PARAMS_ROW]],
13391323
completion_params=[_build_completion_params_from_payload(PEER_TOOL_MISSING_REQUIRED_PARAM_PAYLOAD)],
13401324
rollout_processor=SingleTurnRolloutProcessor(),
13411325
aggregation_method="mean",
13421326
passed_threshold=0.0,
13431327
num_runs=1,
13441328
mode="pointwise",
13451329
)
1346-
def test_streaming_tool_missing_required_param(row: EvaluationRow) -> EvaluationRow:
1347-
"""Detect whether required parameters are omitted during streaming."""
1330+
def test_streaming_tool_required_params_present(row: EvaluationRow) -> EvaluationRow:
1331+
"""Verify that tool calls include all required parameters during streaming."""
13481332

13491333
assistant_msg = row.last_assistant_message()
13501334
finish_reason = row.execution_metadata.finish_reason
1351-
_debug_log_assistant_message("tool_missing_required_param", assistant_msg, finish_reason)
1335+
_debug_log_assistant_message("tool_required_params", assistant_msg, finish_reason)
13521336
content_str = _coerce_content_to_str(assistant_msg.content) if assistant_msg else ""
13531337
reasoning_str = (assistant_msg.reasoning_content or "").strip() if assistant_msg else ""
13541338
calls = _collect_tool_calls(assistant_msg.tool_calls if assistant_msg else [])
13551339

1356-
missing_required = False
1340+
required_params_present = False
13571341
arguments = None
13581342
for _, args in calls:
13591343
if args:
13601344
arguments = args
1361-
missing_required = "type" not in args or args.get("type") not in {"file", "directory"}
1345+
# Check that required 'type' param is present and valid
1346+
required_params_present = "type" in args and args.get("type") in {"file", "directory"}
13621347

13631348
metrics = {
13641349
"tool_call_emitted": MetricResult(
13651350
score=1.0 if calls else 0.0,
13661351
is_score_valid=True,
13671352
reason="Tool call emitted" if calls else "No tool call emitted",
13681353
),
1369-
"missing_required_param": MetricResult(
1370-
score=1.0 if missing_required else 0.0,
1354+
"required_params_present": MetricResult(
1355+
score=1.0 if required_params_present else 0.0,
13711356
is_score_valid=bool(calls),
1372-
reason="Required parameter missing or invalid" if missing_required else "All required parameters present",
1357+
reason="All required parameters present"
1358+
if required_params_present
1359+
else "Required parameter missing or invalid",
13731360
data={"arguments": arguments},
13741361
),
13751362
"finish_reason": MetricResult(
@@ -1386,15 +1373,19 @@ def test_streaming_tool_missing_required_param(row: EvaluationRow) -> Evaluation
13861373
)
13871374

13881375
all_checks_passed = (
1389-
missing_required and finish_reason_present and no_forbidden_tags and no_xml_tags and no_reasoning_leakage
1376+
required_params_present
1377+
and finish_reason_present
1378+
and no_forbidden_tags
1379+
and no_xml_tags
1380+
and no_reasoning_leakage
13901381
)
13911382

13921383
row.evaluation_result = EvaluateResult(
13931384
score=1.0 if all_checks_passed else 0.0,
13941385
is_score_valid=True,
1395-
reason="Detected missing required parameter"
1386+
reason="All required parameters included in tool call"
13961387
if all_checks_passed
1397-
else "Required parameters satisfied or response invalid",
1388+
else "Required parameters missing or response invalid",
13981389
metrics=metrics,
13991390
)
14001391
return row
@@ -1674,71 +1665,6 @@ def test_streaming_tool_parameter_types(row: EvaluationRow) -> EvaluationRow:
16741665
return row
16751666

16761667

1677-
_PEER_TOOL_RECOVERY_ROW = _build_row_from_payload("peer-tool-recovery-failure", PEER_TOOL_RECOVERY_FAILURE_PAYLOAD)
1678-
1679-
1680-
@evaluation_test(
1681-
input_rows=[[_PEER_TOOL_RECOVERY_ROW]],
1682-
completion_params=[_build_completion_params_from_payload(PEER_TOOL_RECOVERY_FAILURE_PAYLOAD)],
1683-
rollout_processor=SingleTurnRolloutProcessor(),
1684-
aggregation_method="mean",
1685-
passed_threshold=0.0,
1686-
num_runs=1,
1687-
mode="pointwise",
1688-
)
1689-
def test_streaming_tool_retry_behavior(row: EvaluationRow) -> EvaluationRow:
1690-
"""Check whether the assistant retries tool calls when instructed to recover."""
1691-
1692-
assistant_msg = row.last_assistant_message()
1693-
print(f"assistant_msg: {assistant_msg}")
1694-
finish_reason = row.execution_metadata.finish_reason
1695-
_debug_log_assistant_message("tool_recovery", assistant_msg, finish_reason)
1696-
content_str = _coerce_content_to_str(assistant_msg.content) if assistant_msg else ""
1697-
calls = _collect_tool_calls(assistant_msg.tool_calls if assistant_msg else [])
1698-
reasoning = (assistant_msg.reasoning_content or "").strip() if assistant_msg else ""
1699-
1700-
multiple_attempts = len(calls) >= 2
1701-
metrics = {
1702-
"tool_call_attempts": MetricResult(
1703-
score=1.0 if multiple_attempts else 0.0,
1704-
is_score_valid=True,
1705-
reason="Multiple tool call attempts" if multiple_attempts else "Single/no tool call attempt",
1706-
data={"tool_call_count": len(calls)},
1707-
),
1708-
"reasoning_present": MetricResult(
1709-
score=1.0 if reasoning else 0.0,
1710-
is_score_valid=True,
1711-
reason="Reasoning present" if reasoning else "No reasoning provided",
1712-
data={"reasoning": reasoning[:160]},
1713-
),
1714-
"finish_reason": MetricResult(
1715-
score=1.0 if finish_reason in {"tool_calls", "stop"} else 0.0,
1716-
is_score_valid=True,
1717-
reason="finish_reason acceptable"
1718-
if finish_reason in {"tool_calls", "stop"}
1719-
else f"Unexpected finish_reason: {finish_reason}",
1720-
),
1721-
}
1722-
1723-
finish_reason_present, no_forbidden_tags, no_xml_tags, no_reasoning_leakage = _augment_metrics_with_common_checks(
1724-
metrics, finish_reason, content_str, reasoning
1725-
)
1726-
1727-
all_checks_passed = (
1728-
multiple_attempts and finish_reason_present and no_forbidden_tags and no_xml_tags and no_reasoning_leakage
1729-
)
1730-
1731-
row.evaluation_result = EvaluateResult(
1732-
score=1.0 if all_checks_passed else 0.0,
1733-
is_score_valid=True,
1734-
reason="Multiple recovery attempts observed"
1735-
if all_checks_passed
1736-
else "Recovery attempts missing or response invalid",
1737-
metrics=metrics,
1738-
)
1739-
return row
1740-
1741-
17421668
# ============================================================================
17431669
# Reasoning Effort Tests
17441670
# ============================================================================
@@ -1759,7 +1685,7 @@ def test_streaming_tool_retry_behavior(row: EvaluationRow) -> EvaluationRow:
17591685
input_rows=[[REASONING_DISABLED_ROW]],
17601686
completion_params=[
17611687
{
1762-
"model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1", # Reasoning-capable model
1688+
"model": DEFAULT_MODEL_ID, # Reasoning-capable model
17631689
"reasoning_effort": "none", # Explicitly disable reasoning
17641690
"max_tokens": DEFAULT_MAX_TOKENS,
17651691
"temperature": 0.0,
@@ -1869,7 +1795,7 @@ def test_reasoning_effort_none_no_reasoning(row: EvaluationRow) -> EvaluationRow
18691795
input_rows=[[REASONING_ENABLED_ROW]],
18701796
completion_params=[
18711797
{
1872-
"model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1", # Reasoning-capable model
1798+
"model": DEFAULT_MODEL_ID, # Reasoning-capable model
18731799
"reasoning_effort": "low", # Enable reasoning
18741800
"max_tokens": DEFAULT_MAX_TOKENS,
18751801
"temperature": 0.0,
@@ -2004,7 +1930,7 @@ def test_reasoning_effort_low_has_reasoning(row: EvaluationRow) -> EvaluationRow
20041930
input_rows=[[TOOLS_WITH_REASONING_ROW]],
20051931
completion_params=[
20061932
{
2007-
"model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1", # Reasoning-capable model
1933+
"model": DEFAULT_MODEL_ID, # Reasoning-capable model
20081934
"reasoning_effort": "low", # Enable reasoning
20091935
"max_tokens": DEFAULT_MAX_TOKENS,
20101936
"temperature": 0.0,
@@ -2727,7 +2653,7 @@ def test_non_streaming_multiple_tool_calls(row: EvaluationRow) -> EvaluationRow:
27272653
input_rows=[[REASONING_DISABLED_NON_STREAM_ROW]],
27282654
completion_params=[
27292655
{
2730-
"model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1",
2656+
"model": DEFAULT_MODEL_ID,
27312657
"reasoning_effort": "none",
27322658
"max_tokens": DEFAULT_MAX_TOKENS,
27332659
"temperature": 0.0,
@@ -2834,7 +2760,7 @@ def test_reasoning_effort_none_no_reasoning_non_stream(row: EvaluationRow) -> Ev
28342760
input_rows=[[REASONING_ENABLED_NON_STREAM_ROW]],
28352761
completion_params=[
28362762
{
2837-
"model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1",
2763+
"model": DEFAULT_MODEL_ID,
28382764
"reasoning_effort": "low",
28392765
"max_tokens": DEFAULT_MAX_TOKENS,
28402766
"temperature": 0.0,
@@ -2962,7 +2888,7 @@ def test_reasoning_effort_low_has_reasoning_non_stream(row: EvaluationRow) -> Ev
29622888
input_rows=[[TOOLS_WITH_REASONING_NON_STREAM_ROW]],
29632889
completion_params=[
29642890
{
2965-
"model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1",
2891+
"model": DEFAULT_MODEL_ID,
29662892
"reasoning_effort": "low",
29672893
"max_tokens": DEFAULT_MAX_TOKENS,
29682894
"temperature": 0.0,
@@ -3108,7 +3034,7 @@ def test_non_streaming_tools_with_reasoning(row: EvaluationRow) -> EvaluationRow
31083034
input_rows=[[STRUCTURED_OUTPUT_WITH_REASONING_ROW]],
31093035
completion_params=[
31103036
{
3111-
"model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1",
3037+
"model": DEFAULT_MODEL_ID,
31123038
"stream": True,
31133039
"reasoning_effort": "low",
31143040
"response_format": STRUCTURED_JSON_SCHEMA,
@@ -3211,7 +3137,7 @@ def test_streaming_structured_output_with_reasoning(row: EvaluationRow) -> Evalu
32113137
input_rows=[[STRUCTURED_OUTPUT_WITH_REASONING_NON_STREAM_ROW]],
32123138
completion_params=[
32133139
{
3214-
"model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1",
3140+
"model": DEFAULT_MODEL_ID,
32153141
"stream": False,
32163142
"reasoning_effort": "low",
32173143
"response_format": STRUCTURED_JSON_SCHEMA,
@@ -3334,7 +3260,7 @@ def test_non_streaming_structured_output_with_reasoning(row: EvaluationRow) -> E
33343260
input_rows=[[MULTIPLE_TOOLS_WITH_REASONING_ROW]],
33353261
completion_params=[
33363262
{
3337-
"model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1",
3263+
"model": DEFAULT_MODEL_ID,
33383264
"stream": True,
33393265
"reasoning_effort": "low",
33403266
"temperature": 0.0,
@@ -3461,7 +3387,7 @@ def test_streaming_multiple_tools_with_reasoning(row: EvaluationRow) -> Evaluati
34613387
input_rows=[[MULTIPLE_TOOLS_WITH_REASONING_NON_STREAM_ROW]],
34623388
completion_params=[
34633389
{
3464-
"model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1",
3390+
"model": DEFAULT_MODEL_ID,
34653391
"stream": False,
34663392
"reasoning_effort": "low",
34673393
"temperature": 0.0,

0 commit comments

Comments
 (0)