Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,12 @@
)
from flink_agents.api.tools.tool import Tool
from flink_agents.e2e_tests.e2e_tests_integration.react_agent_tools import add, multiply
from flink_agents.e2e_tests.test_utils import pull_model
from flink_agents.e2e_tests.test_utils import (
assert_tool_invoked,
collect_tool_invocations,
pull_model,
tool_invocations_from_events,
)

current_dir = Path(__file__).parent

Expand Down Expand Up @@ -132,7 +137,13 @@ def test_react_agent_on_local_runner(monkeypatch: pytest.MonkeyPatch) -> None:
assert len(output_list) == 1, (
"This may be caused by the LLM response does not match the output schema, you can rerun this case."
)
assert output_list[0]["0001"].result == 1386528
assert int(output_list[0]["0001"].result) == 1386528

# multiply's first arg (4444 = 2123 + 2321) proves the addition was computed
# correctly and the multiply tool was used; the model often does the addition
# without the add tool, so add is not a reliable signal to assert on.
invocations = tool_invocations_from_events(env.get_tool_request_events())
assert_tool_invoked(invocations, "multiply", {"a": 4444, "b": 312})


@pytest.mark.skipif(
Expand All @@ -149,7 +160,7 @@ def test_react_agent_on_remote_runner(
t_env = StreamTableEnvironment.create(stream_execution_environment=stream_env)

table = t_env.from_elements(
elements=[(1, 2, 3)],
elements=[(2123, 2321, 312)],
schema=DataTypes.ROW(
[
DataTypes.FIELD("a", DataTypes.INT()),
Expand All @@ -169,6 +180,10 @@ def test_react_agent_on_remote_runner(

env.get_config().set(AgentExecutionOptions.MAX_RETRIES, 3)

log_dir = tmp_path / "event_logs"
log_dir.mkdir(parents=True, exist_ok=True)
env.get_config().set_str("baseLogDir", str(log_dir))

# register resource to execution environment
(
env.add_resource(
Expand Down Expand Up @@ -243,4 +258,12 @@ def test_react_agent_on_remote_runner(
assert len(actual_result) == 1, (
"This may be caused by the LLM response does not match the output schema, you can rerun this case."
)
assert "result" in json.loads(actual_result[0].strip())
assert json.loads(actual_result[0].strip())["result"] == 1386528

# multiply's first arg (4444 = 2123 + 2321) proves the addition was computed
# correctly and threaded into multiply; the model often does the addition
# without the add tool, so add is not a reliable signal to assert on. This
# exercises the same reasoning chain as the local-runner test, but read back
# through the event-log capture path.
invocations = collect_tool_invocations(log_dir)
assert_tool_invoked(invocations, "multiply", {"a": 4444, "b": 312})
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@
from flink_agents.e2e_tests.e2e_tests_resource_cross_language.chat_model_cross_language_agent import (
ChatModelCrossLanguageAgent,
)
from flink_agents.e2e_tests.test_utils import pull_model
from flink_agents.e2e_tests.test_utils import (
assert_tool_invoked,
collect_tool_invocations,
pull_model,
)

current_dir = Path(__file__).parent

Expand Down Expand Up @@ -72,6 +76,9 @@ def test_java_chat_model_integration(
deserialize_datastream = input_datastream.map(lambda x: str(x))

agents_env = AgentsExecutionEnvironment.get_execution_environment(env=env)
log_dir = tmp_path / "event_logs"
log_dir.mkdir(parents=True, exist_ok=True)
agents_env.get_config().set_str("baseLogDir", str(log_dir))
output_datastream = (
agents_env.from_datastream(
input=deserialize_datastream, key_selector=lambda x: "orderKey"
Expand Down Expand Up @@ -106,6 +113,8 @@ def test_java_chat_model_integration(
with file.open() as f:
actual_result.extend(f.readlines())

invocations = collect_tool_invocations(log_dir)
assert_tool_invoked(invocations, "add", {"a": 1, "b": 2})

joined = "\n".join(actual_result).lower()
assert "3" in joined, f"math answer missing '3': {actual_result!r}"
assert "cat" in joined, f"creative answer missing 'cat': {actual_result!r}"
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,11 @@
)

from flink_agents.api.execution_environment import AgentsExecutionEnvironment
from flink_agents.e2e_tests.test_utils import pull_model
from flink_agents.e2e_tests.test_utils import (
assert_tool_invoked,
collect_tool_invocations,
pull_model,
)

current_dir = Path(__file__).parent
_RESOURCES = current_dir.parent / "resources"
Expand Down Expand Up @@ -116,6 +120,9 @@ def test_yaml_cross_language_agent(
deserialize_datastream = input_datastream.map(lambda x: str(x))

agents_env = AgentsExecutionEnvironment.get_execution_environment(env=env)
log_dir = tmp_path / "event_logs"
log_dir.mkdir(parents=True, exist_ok=True)
agents_env.get_config().set_str("baseLogDir", str(log_dir))
agents_env.load_yaml(_RESOURCES / "yaml_cross_language_agent.yaml")

output_datastream = (
Expand Down Expand Up @@ -152,12 +159,16 @@ def test_yaml_cross_language_agent(
with file.open() as f:
actual_result.extend(f.readlines())

# Math path went through the Java ``calculateBMI`` tool:
# 70 / (1.75 * 1.75) ≈ 22.86, so the final answer should mention 22.
# Creative path doesn't use any tool.
# Math path went through the Java ``calculateBMI`` tool, called with the
# weight/height parsed from the input ("1.75 meters tall and weighs 70 kg").
assert_tool_invoked(
collect_tool_invocations(log_dir),
"calculateBMI",
{"weightKg": 70, "heightM": 1.75},
)
# Creative path doesn't use any tool; its answer mentions a cat.
# NOTE: We join all results and search without relying on order, because
# StreamingFileSink may produce multiple part files and iterdir() does not
# guarantee a deterministic traversal order across platforms.
joined = "\n".join(actual_result).lower()
assert "22" in joined, f"math answer missing '22': {actual_result!r}"
assert "cat" in joined, f"creative answer missing 'cat': {actual_result!r}"
115 changes: 115 additions & 0 deletions python/flink_agents/e2e_tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,129 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#################################################################################
import json
import subprocess
from pathlib import Path

from ollama import Client

from flink_agents.api.events.tool_event import ToolRequestEvent

current_dir = Path(__file__).parent


def _normalize_arguments(arguments: object) -> dict:
"""Return tool-call arguments as a dict, parsing a JSON string if needed.

Args:
arguments: Tool-call arguments, either a mapping (Ollama path), a
JSON-encoded string (some providers ``json.dumps`` the arguments),
or ``None`` for a no-argument tool call.

Returns:
The arguments as a dict; an empty dict when ``arguments`` is ``None``.
"""
if arguments is None:
return {}
if isinstance(arguments, str):
return json.loads(arguments)
return dict(arguments)


def collect_tool_invocations(log_dir: str | Path) -> list[dict]:
"""Read ``events-*.log`` under ``log_dir`` and return tool invocations in order.

Globs the per-subtask event-log files the ``FileEventLogger`` writes, parses
each JSONL record, and extracts every ``_tool_request_event`` tool call. The
tool-call dict is nested under ``function`` in the wire format.

Args:
log_dir: Directory containing the ``events-*.log`` files (the configured
``baseLogDir``).

Returns:
Ordered list of ``{"name": str, "arguments": dict | str}``. Empty when the
model invoked no tool (a legitimate, assertable outcome).
"""
invocations = []
for log_file in sorted(Path(log_dir).glob("events-*.log")):
with log_file.open() as handle:
for line in handle:
if not line.strip():
continue
record = json.loads(line)
if record.get("eventType") != "_tool_request_event":
continue
tool_calls = record["event"]["attributes"].get("tool_calls", [])
for tool_call in tool_calls:
function = tool_call["function"]
invocations.append(
{
"name": function["name"],
"arguments": function["arguments"],
}
)
return invocations


def tool_invocations_from_events(events: list[ToolRequestEvent]) -> list[dict]:
    """Flatten live ``ToolRequestEvent`` objects into ``{name, arguments}`` dicts.

    This is the in-memory counterpart of :func:`collect_tool_invocations`: it
    produces the same invocation shape from captured event objects, so either
    source can be handed to :func:`assert_tool_invoked`. Every entry of an
    event's ``tool_calls`` is read as a nested ``{id, type, function:{name,
    arguments}}`` dict, and only the ``function`` part is kept.

    Args:
        events: ``ToolRequestEvent`` objects captured during a local run.

    Returns:
        One ``{"name": str, "arguments": dict | str}`` entry per tool call, in
        event order and, within an event, in ``tool_calls`` order.
    """
    collected: list[dict] = []
    for request_event in events:
        # Only the nested ``function`` payload carries the name and arguments.
        functions = (call["function"] for call in request_event.tool_calls)
        collected.extend(
            {"name": fn["name"], "arguments": fn["arguments"]} for fn in functions
        )
    return collected


def assert_tool_invoked(invocations: list[dict], name: str, arguments: dict) -> None:
    """Assert some invocation called tool ``name`` with arguments equal to ``arguments``.

    Argument values are compared after normalizing both sides to a dict (a
    JSON-string ``arguments`` is parsed first), so the comparison is
    order-independent and tolerant of providers that encode arguments as a string.

    An invocation whose arguments cannot be normalized (for example a malformed
    JSON string) is treated as a non-match rather than aborting the scan, so a
    later well-formed invocation can still satisfy the assertion. Such
    invocations remain visible in the failure message.

    Args:
        invocations: Tool invocations as returned by :func:`collect_tool_invocations`.
        name: Expected tool name.
        arguments: Expected tool arguments.

    Raises:
        AssertionError: If no invocation matches both ``name`` and ``arguments``;
            the message dumps the actual invocations.
    """
    expected_args = _normalize_arguments(arguments)
    for invocation in invocations:
        if invocation["name"] != name:
            continue
        try:
            actual_args = _normalize_arguments(invocation["arguments"])
        except (TypeError, ValueError):
            # ``json.JSONDecodeError`` is a ``ValueError``. An unparseable
            # payload cannot equal the expected arguments, so skip it and keep
            # looking instead of failing with an unrelated exception type.
            continue
        if actual_args == expected_args:
            return
    message = (
        f"No invocation of tool {name!r} with arguments {expected_args!r}; "
        f"actual invocations: {invocations!r}"
    )
    raise AssertionError(message)


def pull_model(ollama_model: str) -> Client:
"""Run ollama pull ollama_model."""
try:
Expand Down
Loading
Loading