Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 61 additions & 135 deletions eval_protocol/benchmarks/test_glm_streaming_compliance.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,19 @@
EvaluationRow,
Message,
MetricResult,
ChatCompletionContentPartTextParam,
)
from eval_protocol.pytest.default_single_turn_rollout_process import (
SingleTurnRolloutProcessor,
)
from eval_protocol.pytest.evaluation_test import evaluation_test


DEFAULT_MODEL_ID = "fireworks_ai/accounts/fireworks/models/glm-4p6"
DEFAULT_MODEL_ID = "fireworks_ai/accounts/pyroworks/deployedModels/minimax-m2-zmi4qk9f"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's keep it as serverless.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unless this was meant to be used only for internal use and we don't mind keeping it pointed at pyroworks

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah it is just internal use

DEFAULT_MAX_TOKENS = 10000


def _coerce_content_to_str(
content: str | list[ChatCompletionContentPartTextParam] | None,
content: str | list[Any] | None,
) -> str:
if isinstance(content, list):
texts: list[str] = []
Expand Down Expand Up @@ -153,7 +152,34 @@ def _safe_json_loads(payload: str) -> Any | None:
"content": "Call test_brace_bug with param1='test_value', param2=42, and param3=true",
}
],
"tools": WEATHER_TOOL_DEFINITION,
"tools": [
{
"type": "function",
"function": {
"name": "test_brace_bug",
"description": "A test function to validate JSON brace handling in tool arguments",
"parameters": {
"type": "object",
"properties": {
"param1": {
"type": "string",
"description": "A string parameter",
},
"param2": {
"type": "integer",
"description": "An integer parameter",
},
"param3": {
"type": "boolean",
"description": "A boolean parameter",
},
},
"required": ["param1", "param2", "param3"],
"additionalProperties": False,
},
},
}
],
"temperature": 0.1,
"top_p": 1,
}
Expand Down Expand Up @@ -468,48 +494,6 @@ def _safe_json_loads(payload: str) -> Any | None:
"stream": True,
}

# Request payload for the tool-recovery scenario: the prompt instructs the model
# to retry the `view` tool until it succeeds, so the evaluator can check that
# multiple tool-call attempts are emitted while streaming.
# NOTE(review): `tool_choice: "required"` forces at least one tool call; the
# retry behavior itself depends on the model honoring the prompt.
PEER_TOOL_RECOVERY_FAILURE_PAYLOAD = {
    "messages": [
        {
            "role": "user",
            "content": (
                "View the file at /tmp/test.txt. If that fails, try again with the correct parameters. "
                "Keep retrying until it works."
            ),
        }
    ],
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "view",
                "description": "View a file or directory",
                # strict mode: arguments must match the JSON schema exactly.
                "strict": True,
                "parameters": {
                    "type": "object",
                    "properties": {
                        "path": {
                            "type": "string",
                            "description": "Path to the file or directory to view",
                        },
                        "type": {
                            "type": "string",
                            "enum": ["file", "directory"],
                            "description": "Type of the path (file or directory)",
                        },
                    },
                    "required": ["path", "type"],
                    "additionalProperties": False,
                },
            },
        }
    ],
    "tool_choice": "required",
    "temperature": 0.1,
    "max_tokens": 4000,
    "stream": True,
}


def _build_row_from_payload(case: str, payload: dict[str, Any]) -> EvaluationRow:
messages = [
Expand Down Expand Up @@ -1329,47 +1313,50 @@ def test_streaming_multiple_tool_calls(row: EvaluationRow) -> EvaluationRow:
return row


_PEER_TOOL_MISSING_REQUIRED_ROW = _build_row_from_payload(
"peer-tool-missing-required-param", PEER_TOOL_MISSING_REQUIRED_PARAM_PAYLOAD
_PEER_TOOL_REQUIRED_PARAMS_ROW = _build_row_from_payload(
"peer-tool-required-params", PEER_TOOL_MISSING_REQUIRED_PARAM_PAYLOAD
)


@evaluation_test(
input_rows=[[_PEER_TOOL_MISSING_REQUIRED_ROW]],
input_rows=[[_PEER_TOOL_REQUIRED_PARAMS_ROW]],
completion_params=[_build_completion_params_from_payload(PEER_TOOL_MISSING_REQUIRED_PARAM_PAYLOAD)],
rollout_processor=SingleTurnRolloutProcessor(),
aggregation_method="mean",
passed_threshold=0.0,
num_runs=1,
mode="pointwise",
)
def test_streaming_tool_missing_required_param(row: EvaluationRow) -> EvaluationRow:
"""Detect whether required parameters are omitted during streaming."""
def test_streaming_tool_required_params_present(row: EvaluationRow) -> EvaluationRow:
"""Verify that tool calls include all required parameters during streaming."""

assistant_msg = row.last_assistant_message()
finish_reason = row.execution_metadata.finish_reason
_debug_log_assistant_message("tool_missing_required_param", assistant_msg, finish_reason)
_debug_log_assistant_message("tool_required_params", assistant_msg, finish_reason)
content_str = _coerce_content_to_str(assistant_msg.content) if assistant_msg else ""
reasoning_str = (assistant_msg.reasoning_content or "").strip() if assistant_msg else ""
calls = _collect_tool_calls(assistant_msg.tool_calls if assistant_msg else [])

missing_required = False
required_params_present = False
arguments = None
for _, args in calls:
if args:
arguments = args
missing_required = "type" not in args or args.get("type") not in {"file", "directory"}
# Check that required 'type' param is present and valid
required_params_present = "type" in args and args.get("type") in {"file", "directory"}

metrics = {
"tool_call_emitted": MetricResult(
score=1.0 if calls else 0.0,
is_score_valid=True,
reason="Tool call emitted" if calls else "No tool call emitted",
),
"missing_required_param": MetricResult(
score=1.0 if missing_required else 0.0,
"required_params_present": MetricResult(
score=1.0 if required_params_present else 0.0,
is_score_valid=bool(calls),
reason="Required parameter missing or invalid" if missing_required else "All required parameters present",
reason="All required parameters present"
if required_params_present
else "Required parameter missing or invalid",
data={"arguments": arguments},
),
"finish_reason": MetricResult(
Expand All @@ -1386,15 +1373,19 @@ def test_streaming_tool_missing_required_param(row: EvaluationRow) -> Evaluation
)

all_checks_passed = (
missing_required and finish_reason_present and no_forbidden_tags and no_xml_tags and no_reasoning_leakage
required_params_present
and finish_reason_present
and no_forbidden_tags
and no_xml_tags
and no_reasoning_leakage
)

row.evaluation_result = EvaluateResult(
score=1.0 if all_checks_passed else 0.0,
is_score_valid=True,
reason="Detected missing required parameter"
reason="All required parameters included in tool call"
if all_checks_passed
else "Required parameters satisfied or response invalid",
else "Required parameters missing or response invalid",
metrics=metrics,
)
return row
Expand Down Expand Up @@ -1674,71 +1665,6 @@ def test_streaming_tool_parameter_types(row: EvaluationRow) -> EvaluationRow:
return row


# Pre-built evaluation row for the tool-recovery scenario defined above.
_PEER_TOOL_RECOVERY_ROW = _build_row_from_payload("peer-tool-recovery-failure", PEER_TOOL_RECOVERY_FAILURE_PAYLOAD)


@evaluation_test(
    input_rows=[[_PEER_TOOL_RECOVERY_ROW]],
    completion_params=[_build_completion_params_from_payload(PEER_TOOL_RECOVERY_FAILURE_PAYLOAD)],
    rollout_processor=SingleTurnRolloutProcessor(),
    aggregation_method="mean",
    passed_threshold=0.0,
    num_runs=1,
    mode="pointwise",
)
def test_streaming_tool_retry_behavior(row: EvaluationRow) -> EvaluationRow:
    """Check whether the assistant retries tool calls when instructed to recover.

    The row's prompt tells the model to keep retrying the ``view`` tool until it
    works. The overall score is 1.0 only when the assistant emitted at least two
    tool-call attempts, finished with an acceptable ``finish_reason``
    (``tool_calls`` or ``stop``), and passed the common streaming-compliance
    checks (no forbidden tags, no XML tags, no reasoning leakage).
    """
    assistant_msg = row.last_assistant_message()
    finish_reason = row.execution_metadata.finish_reason
    # Use the shared structured debug logger instead of an ad-hoc print()
    # (the original left a stray `print(f"assistant_msg: ...")` here that
    # duplicated this call and polluted test output).
    _debug_log_assistant_message("tool_recovery", assistant_msg, finish_reason)
    content_str = _coerce_content_to_str(assistant_msg.content) if assistant_msg else ""
    calls = _collect_tool_calls(assistant_msg.tool_calls if assistant_msg else [])
    reasoning = (assistant_msg.reasoning_content or "").strip() if assistant_msg else ""

    # "Recovery" is observed as a retry: two or more tool-call attempts.
    multiple_attempts = len(calls) >= 2
    metrics = {
        "tool_call_attempts": MetricResult(
            score=1.0 if multiple_attempts else 0.0,
            is_score_valid=True,
            reason="Multiple tool call attempts" if multiple_attempts else "Single/no tool call attempt",
            data={"tool_call_count": len(calls)},
        ),
        "reasoning_present": MetricResult(
            score=1.0 if reasoning else 0.0,
            is_score_valid=True,
            reason="Reasoning present" if reasoning else "No reasoning provided",
            # Truncate so metric payloads stay small in logs.
            data={"reasoning": reasoning[:160]},
        ),
        "finish_reason": MetricResult(
            score=1.0 if finish_reason in {"tool_calls", "stop"} else 0.0,
            is_score_valid=True,
            reason="finish_reason acceptable"
            if finish_reason in {"tool_calls", "stop"}
            else f"Unexpected finish_reason: {finish_reason}",
        ),
    }

    finish_reason_present, no_forbidden_tags, no_xml_tags, no_reasoning_leakage = _augment_metrics_with_common_checks(
        metrics, finish_reason, content_str, reasoning
    )

    all_checks_passed = (
        multiple_attempts and finish_reason_present and no_forbidden_tags and no_xml_tags and no_reasoning_leakage
    )

    row.evaluation_result = EvaluateResult(
        score=1.0 if all_checks_passed else 0.0,
        is_score_valid=True,
        reason="Multiple recovery attempts observed"
        if all_checks_passed
        else "Recovery attempts missing or response invalid",
        metrics=metrics,
    )
    return row


# ============================================================================
# Reasoning Effort Tests
# ============================================================================
Expand All @@ -1759,7 +1685,7 @@ def test_streaming_tool_retry_behavior(row: EvaluationRow) -> EvaluationRow:
input_rows=[[REASONING_DISABLED_ROW]],
completion_params=[
{
"model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1", # Reasoning-capable model
"model": DEFAULT_MODEL_ID, # Reasoning-capable model
"reasoning_effort": "none", # Explicitly disable reasoning
"max_tokens": DEFAULT_MAX_TOKENS,
"temperature": 0.0,
Expand Down Expand Up @@ -1869,7 +1795,7 @@ def test_reasoning_effort_none_no_reasoning(row: EvaluationRow) -> EvaluationRow
input_rows=[[REASONING_ENABLED_ROW]],
completion_params=[
{
"model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1", # Reasoning-capable model
"model": DEFAULT_MODEL_ID, # Reasoning-capable model
"reasoning_effort": "low", # Enable reasoning
"max_tokens": DEFAULT_MAX_TOKENS,
"temperature": 0.0,
Expand Down Expand Up @@ -2004,7 +1930,7 @@ def test_reasoning_effort_low_has_reasoning(row: EvaluationRow) -> EvaluationRow
input_rows=[[TOOLS_WITH_REASONING_ROW]],
completion_params=[
{
"model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1", # Reasoning-capable model
"model": DEFAULT_MODEL_ID, # Reasoning-capable model
"reasoning_effort": "low", # Enable reasoning
"max_tokens": DEFAULT_MAX_TOKENS,
"temperature": 0.0,
Expand Down Expand Up @@ -2727,7 +2653,7 @@ def test_non_streaming_multiple_tool_calls(row: EvaluationRow) -> EvaluationRow:
input_rows=[[REASONING_DISABLED_NON_STREAM_ROW]],
completion_params=[
{
"model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1",
"model": DEFAULT_MODEL_ID,
"reasoning_effort": "none",
"max_tokens": DEFAULT_MAX_TOKENS,
"temperature": 0.0,
Expand Down Expand Up @@ -2834,7 +2760,7 @@ def test_reasoning_effort_none_no_reasoning_non_stream(row: EvaluationRow) -> Ev
input_rows=[[REASONING_ENABLED_NON_STREAM_ROW]],
completion_params=[
{
"model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1",
"model": DEFAULT_MODEL_ID,
"reasoning_effort": "low",
"max_tokens": DEFAULT_MAX_TOKENS,
"temperature": 0.0,
Expand Down Expand Up @@ -2962,7 +2888,7 @@ def test_reasoning_effort_low_has_reasoning_non_stream(row: EvaluationRow) -> Ev
input_rows=[[TOOLS_WITH_REASONING_NON_STREAM_ROW]],
completion_params=[
{
"model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1",
"model": DEFAULT_MODEL_ID,
"reasoning_effort": "low",
"max_tokens": DEFAULT_MAX_TOKENS,
"temperature": 0.0,
Expand Down Expand Up @@ -3108,7 +3034,7 @@ def test_non_streaming_tools_with_reasoning(row: EvaluationRow) -> EvaluationRow
input_rows=[[STRUCTURED_OUTPUT_WITH_REASONING_ROW]],
completion_params=[
{
"model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1",
"model": DEFAULT_MODEL_ID,
"stream": True,
"reasoning_effort": "low",
"response_format": STRUCTURED_JSON_SCHEMA,
Expand Down Expand Up @@ -3211,7 +3137,7 @@ def test_streaming_structured_output_with_reasoning(row: EvaluationRow) -> Evalu
input_rows=[[STRUCTURED_OUTPUT_WITH_REASONING_NON_STREAM_ROW]],
completion_params=[
{
"model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1",
"model": DEFAULT_MODEL_ID,
"stream": False,
"reasoning_effort": "low",
"response_format": STRUCTURED_JSON_SCHEMA,
Expand Down Expand Up @@ -3334,7 +3260,7 @@ def test_non_streaming_structured_output_with_reasoning(row: EvaluationRow) -> E
input_rows=[[MULTIPLE_TOOLS_WITH_REASONING_ROW]],
completion_params=[
{
"model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1",
"model": DEFAULT_MODEL_ID,
"stream": True,
"reasoning_effort": "low",
"temperature": 0.0,
Expand Down Expand Up @@ -3461,7 +3387,7 @@ def test_streaming_multiple_tools_with_reasoning(row: EvaluationRow) -> Evaluati
input_rows=[[MULTIPLE_TOOLS_WITH_REASONING_NON_STREAM_ROW]],
completion_params=[
{
"model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p1",
"model": DEFAULT_MODEL_ID,
"stream": False,
"reasoning_effort": "low",
"temperature": 0.0,
Expand Down
Loading