diff --git a/sdk/python/agentfield/node_logs.py b/sdk/python/agentfield/node_logs.py index f5768417..909b7c6a 100644 --- a/sdk/python/agentfield/node_logs.py +++ b/sdk/python/agentfield/node_logs.py @@ -5,6 +5,7 @@ from __future__ import annotations import json +import io import os import queue import secrets @@ -93,7 +94,9 @@ def append(self, stream: str, text: str, max_line_bytes: int) -> None: ts = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z" with self._lock: self._seq += 1 - entry = LogEntry(seq=self._seq, ts=ts, stream=stream, line=raw, truncated=truncated) + entry = LogEntry( + seq=self._seq, ts=ts, stream=stream, line=raw, truncated=truncated + ) self._entries.append(entry) self._approx_bytes += len(entry.line.encode("utf-8")) + 64 while self._approx_bytes > self._max_bytes and len(self._entries) > 1: @@ -101,7 +104,9 @@ def append(self, stream: str, text: str, max_line_bytes: int) -> None: self._approx_bytes -= len(old.line.encode("utf-8")) + 64 _notify_followers() - def snapshot_after(self, since_seq: int, limit: Optional[int] = None) -> List[LogEntry]: + def snapshot_after( + self, since_seq: int, limit: Optional[int] = None + ) -> List[LogEntry]: with self._lock: items = [e for e in self._entries if e.seq > since_seq] if limit is not None and limit > 0: @@ -119,7 +124,7 @@ def max_seq(self) -> int: return self._seq -class _TeeTextIO(TextIO): +class _TeeTextIO(io.TextIOBase): """Write-through to original stream and log ring (line-buffered by \\n).""" def __init__( @@ -146,10 +151,33 @@ def write(self, s: str) -> int: self._ring.append(self._stream_name, line, self._max_line_bytes) return len(s) + def writelines(self, lines) -> None: + for line in lines: + self.write(line) + def flush(self) -> None: self._original.flush() - # Minimal TextIO protocol for print() + def fileno(self) -> int: + return self._original.fileno() + + def readable(self) -> bool: + return bool(self._original.readable()) + + def writable(self) -> bool: + return bool(self._original.writable()) + + def seekable(self) -> bool: + return bool(self._original.seekable()) + + def close(self) -> None: + if self.closed: + return + if self._buf: + self._ring.append(self._stream_name, self._buf, self._max_line_bytes) + self._buf = "" + super().close() + @property def encoding(self) -> str: return getattr(self._original, "encoding", "utf-8") or "utf-8" @@ -221,7 +249,9 @@ def iter_tail_ndjson( ring = get_ring() cap_tail = tail_lines if since_seq > 0: - entries = ring.snapshot_after(since_seq, limit=cap_tail if cap_tail > 0 else None) + entries = ring.snapshot_after( + since_seq, limit=cap_tail if cap_tail > 0 else None + ) else: n = cap_tail if cap_tail > 0 else 200 entries = ring.tail(n) diff --git a/sdk/python/tests/test_harness_cli.py b/sdk/python/tests/test_harness_cli.py new file mode 100644 index 00000000..314c9b91 --- /dev/null +++ b/sdk/python/tests/test_harness_cli.py @@ -0,0 +1,172 @@ +"""Tests for shared subprocess helpers used by CLI harness providers.""" + +from __future__ import annotations + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from agentfield.harness._cli import ( + estimate_cli_cost, + extract_final_text, + parse_jsonl, + run_cli, + strip_ansi, +) + + +def test_strip_ansi_removes_colors(): + assert strip_ansi("\x1b[31mError\x1b[0m") == "Error" + + +@pytest.mark.asyncio +async def test_run_cli_success(): + process = MagicMock() + process.communicate = AsyncMock(return_value=(b"OK", b"")) + process.returncode = 0 + + create_process = AsyncMock(return_value=process) + + with patch("asyncio.create_subprocess_exec", create_process): + stdout, stderr, returncode = await run_cli( + ["agentfield", "status"], + env={"AGENTFIELD_TEST": "1"}, + cwd=".", + timeout=1, + ) + + assert stdout == "OK" + assert stderr == "" + assert returncode == 0 + create_process.assert_awaited_once() + _, kwargs = create_process.call_args + assert kwargs["env"]["AGENTFIELD_TEST"] == "1" + assert kwargs["cwd"] == "." + assert kwargs["stdout"] is asyncio.subprocess.PIPE + assert kwargs["stderr"] is asyncio.subprocess.PIPE + + +@pytest.mark.asyncio +async def test_run_cli_timeout(): + class HangingProcess: + returncode = None + + def __init__(self) -> None: + self.killed = False + self.wait = AsyncMock(return_value=None) + + async def communicate(self): + await asyncio.sleep(1) + return b"", b"" + + def kill(self): + self.killed = True + + process = HangingProcess() + + with patch("asyncio.create_subprocess_exec", AsyncMock(return_value=process)): + with pytest.raises(TimeoutError, match="CLI command timed out"): + await run_cli(["agentfield", "hang"], timeout=0.01) + + assert process.killed is True + process.wait.assert_awaited_once() + + +def test_parse_jsonl_skips_invalid(): + events = parse_jsonl('{"type":"a"}\nnot-json\n{"type":"b"}') + + assert events == [{"type": "a"}, {"type": "b"}] + + +def test_extract_final_text_codex_style(): + events = [ + {"type": "item.completed", "item": {"type": "agent_message", "text": "first"}}, + { + "type": "item.completed", + "item": {"type": "agent_message", "text": "final answer"}, + }, + ] + + assert extract_final_text(events) == "final answer" + + +@pytest.mark.parametrize( + ("events", "expected"), + [ + ([{"type": "result", "result": "result answer"}], "result answer"), + ([{"type": "result", "text": "text answer"}], "text answer"), + ([{"type": "turn.completed", "text": "turn answer"}], "turn answer"), + ([{"type": "message", "content": "message answer"}], "message answer"), + ([{"type": "assistant", "text": "assistant answer"}], "assistant answer"), + ], +) +def test_extract_final_text_event_variants(events, expected): + assert extract_final_text(events) == expected + + +def test_extract_final_text_empty_events(): + assert extract_final_text([]) is None + + +def test_estimate_cli_cost_calls_litellm(): + mock_litellm = MagicMock() + mock_litellm.completion_cost.return_value = 0.05 + + with patch.dict("sys.modules", {"litellm": mock_litellm}): + cost = estimate_cli_cost( + model="openai/gpt-4o", + prompt="Summarize this run", + result_text="Done", + ) + + assert cost == 0.05 + mock_litellm.completion_cost.assert_called_once_with( + model="openai/gpt-4o", + prompt="Summarize this run", + completion="Done", + ) + + +def test_estimate_cli_cost_returns_none_without_model(): + assert estimate_cli_cost(model="", prompt="prompt", result_text="Done") is None + + +def test_estimate_cli_cost_returns_none_when_litellm_missing(): + with patch.dict("sys.modules", {"litellm": None}): + cost = estimate_cli_cost( + model="openai/gpt-4o", + prompt="Summarize this run", + result_text="Done", + ) + + assert cost is None + + +@pytest.mark.parametrize("raw_cost", [0, None]) +def test_estimate_cli_cost_returns_none_for_non_positive_cost(raw_cost): + mock_litellm = MagicMock() + mock_litellm.completion_cost.return_value = raw_cost + + with patch.dict("sys.modules", {"litellm": mock_litellm}): + cost = estimate_cli_cost( + model="openai/gpt-4o", + prompt="Summarize this run", + result_text="Done", + ) + + assert cost is None + + +def test_estimate_cli_cost_returns_none_when_litellm_raises(): + mock_litellm = MagicMock() + mock_litellm.completion_cost.side_effect = RuntimeError("pricing unavailable") + + with patch.dict("sys.modules", {"litellm": mock_litellm}): + cost = estimate_cli_cost( + model="openai/gpt-4o", + prompt="Summarize this run", + result_text="Done", + ) + + assert cost is None diff --git a/sdk/python/tests/test_node_logs.py b/sdk/python/tests/test_node_logs.py index 2bd27f56..36469ded 100644 --- a/sdk/python/tests/test_node_logs.py +++ b/sdk/python/tests/test_node_logs.py @@ -1,18 +1,25 @@ """ Tests for agentfield.node_logs — ProcessLogRing and related helpers. """ + from __future__ import annotations +import io import json +import queue +import sys import threading +import pytest from agentfield.node_logs import ( LogEntry, ProcessLogRing, + _TeeTextIO, + get_ring, + install_stdio_tee, iter_tail_ndjson, verify_internal_bearer, - get_ring, ) @@ -23,7 +30,9 @@ class TestLogEntryNdjson: def test_stdout_produces_info_level(self): - entry = LogEntry(seq=1, ts="2024-01-01T00:00:00.000Z", stream="stdout", line="hello") + entry = LogEntry( + seq=1, ts="2024-01-01T00:00:00.000Z", stream="stdout", line="hello" + ) data = json.loads(entry.to_ndjson_line().decode()) assert data["level"] == "info" assert data["line"] == "hello" @@ -31,12 +40,16 @@ def test_stdout_produces_info_level(self): assert data["source"] == "process" def test_stderr_produces_error_level(self): - entry = LogEntry(seq=2, ts="2024-01-01T00:00:00.000Z", stream="stderr", line="err") + entry = LogEntry( + seq=2, ts="2024-01-01T00:00:00.000Z", stream="stderr", line="err" + ) data = json.loads(entry.to_ndjson_line().decode()) assert data["level"] == "error" def test_other_stream_produces_log_level(self): - entry = LogEntry(seq=3, ts="2024-01-01T00:00:00.000Z", stream="custom", line="msg") + entry = LogEntry( + seq=3, ts="2024-01-01T00:00:00.000Z", stream="custom", line="msg" + ) data = json.loads(entry.to_ndjson_line().decode()) assert data["level"] == "log" @@ -55,7 +68,9 @@ def test_ndjson_ends_with_newline(self): assert entry.to_ndjson_line().endswith(b"\n") def test_seq_and_ts_preserved(self): - entry = LogEntry(seq=42, ts="2024-06-15T10:00:00.000Z", stream="stdout", line="data") + entry = LogEntry( + seq=42, ts="2024-06-15T10:00:00.000Z", stream="stdout", line="data" + ) data = json.loads(entry.to_ndjson_line().decode()) assert data["seq"] == 42 assert data["ts"] == "2024-06-15T10:00:00.000Z" @@ -157,7 +172,9 @@ def test_long_line_is_truncated(self): ring.append("stdout", long_text, max_line_bytes=10) entries = ring.tail(1) assert entries[0].truncated is True - assert len(entries[0].line.encode("utf-8")) <= 10 + 3 # allow for replacement chars + assert ( + len(entries[0].line.encode("utf-8")) <= 10 + 3 + ) # allow for replacement chars def test_short_line_is_not_truncated(self): ring = ProcessLogRing(max_bytes=1024 * 1024) @@ -275,6 +292,279 @@ def test_iter_tail_empty_ring(self, monkeypatch): assert chunks == [] +# --------------------------------------------------------------------------- +# _TeeTextIO and install_stdio_tee +# --------------------------------------------------------------------------- + + +class TestTeeTextIO: + def test_tee_text_io_writes_to_original(self): + original = io.StringIO() + ring = ProcessLogRing(max_bytes=1024 * 1024) + tee = _TeeTextIO("stdout", original, ring, max_line_bytes=1024) + + written = tee.write("hello\n") + + assert written == len("hello\n") + assert original.getvalue() == "hello\n" + + def test_tee_text_io_appends_to_ring(self): + original = io.StringIO() + ring = ProcessLogRing(max_bytes=1024 * 1024) + tee = _TeeTextIO("stdout", original, ring, max_line_bytes=1024) + + tee.write("one line\n") + + entries = ring.tail(1) + assert len(entries) == 1 + assert entries[0].stream == "stdout" + assert entries[0].line == "one line" + + def test_tee_text_io_buffers_until_newline(self): + original = io.StringIO() + ring = ProcessLogRing(max_bytes=1024 * 1024) + tee = _TeeTextIO("stderr", original, ring, max_line_bytes=1024) + + tee.write("partial") + assert ring.tail(1) == [] + + tee.write(" line\n") + entries = ring.tail(1) + assert entries[0].stream == "stderr" + assert entries[0].line == "partial line" + + def test_installed_tee_exposes_text_io_methods(self, monkeypatch): + import agentfield.node_logs as nl + + class TextStream(io.StringIO): + def fileno(self): + return 42 + + previous_stdout = sys.stdout + previous_stderr = sys.stderr + original_stdout = TextStream() + original_stderr = TextStream() + ring = ProcessLogRing(max_bytes=1024 * 1024) + + monkeypatch.setenv("AGENTFIELD_LOGS_ENABLED", "true") + monkeypatch.setattr(sys, "__stdout__", original_stdout) + monkeypatch.setattr(sys, "__stderr__", original_stderr) + monkeypatch.setattr(nl, "_global_ring", ring) + monkeypatch.setattr(nl, "_tee_installed", False) + + try: + install_stdio_tee() + assert isinstance(sys.stdout, _TeeTextIO) + assert sys.stdout.fileno() == 42 + assert sys.stdout.readable() is True + assert sys.stdout.writable() is True + assert sys.stdout.seekable() is True + + sys.stdout.writelines(["first\n", "second\n"]) + assert original_stdout.getvalue() == "first\nsecond\n" + assert [entry.line for entry in ring.tail(2)] == ["first", "second"] + + sys.stdout.write("partial") + sys.stdout.close() + assert original_stdout.closed is False + assert ring.tail(1)[0].line == "partial" + original_stdout.write(" still usable") + assert original_stdout.getvalue().endswith("partial still usable") + finally: + sys.stdout = previous_stdout + sys.stderr = previous_stderr + nl._tee_installed = False + + def test_install_stdio_tee_replaces_sys_stdout(self, monkeypatch): + import agentfield.node_logs as nl + + previous_stdout = sys.stdout + previous_stderr = sys.stderr + original_stdout = io.StringIO() + original_stderr = io.StringIO() + ring = ProcessLogRing(max_bytes=1024 * 1024) + + monkeypatch.setenv("AGENTFIELD_LOGS_ENABLED", "true") + monkeypatch.setattr(sys, "__stdout__", original_stdout) + monkeypatch.setattr(sys, "__stderr__", original_stderr) + monkeypatch.setattr(nl, "_global_ring", ring) + monkeypatch.setattr(nl, "_tee_installed", False) + + try: + install_stdio_tee() + assert isinstance(sys.stdout, _TeeTextIO) + assert isinstance(sys.stderr, _TeeTextIO) + first_stdout = sys.stdout + install_stdio_tee() + assert sys.stdout is first_stdout + assert sys.stdout._original is original_stdout + + sys.stdout.write("captured\n") + assert original_stdout.getvalue() == "captured\n" + assert ring.tail(1)[0].line == "captured" + finally: + sys.stdout = previous_stdout + sys.stderr = previous_stderr + nl._tee_installed = False + + def test_install_stdio_tee_disabled_env_leaves_streams_unchanged(self, monkeypatch): + import agentfield.node_logs as nl + + previous_stdout = sys.stdout + previous_stderr = sys.stderr + original_stdout = io.StringIO() + original_stderr = io.StringIO() + + monkeypatch.setenv("AGENTFIELD_LOGS_ENABLED", "false") + monkeypatch.setattr(sys, "__stdout__", original_stdout) + monkeypatch.setattr(sys, "__stderr__", original_stderr) + monkeypatch.setattr(nl, "_global_ring", ProcessLogRing(max_bytes=1024 * 1024)) + monkeypatch.setattr(nl, "_tee_installed", False) + + install_stdio_tee() + + assert sys.stdout is previous_stdout + assert sys.stderr is previous_stderr + assert nl._tee_installed is False + + +class TestIterTailNdjsonFollow: + def test_iter_tail_ndjson_follow_mode(self, monkeypatch): + import agentfield.node_logs as nl + + ring = ProcessLogRing(max_bytes=1024 * 1024) + monkeypatch.setattr(nl, "_global_ring", ring) + monkeypatch.setattr(nl, "_follow_queues", []) + queue_registered = threading.Event() + original_register_follow_queue = nl.register_follow_queue + + def register_follow_queue(q): + original_register_follow_queue(q) + queue_registered.set() + + monkeypatch.setattr(nl, "register_follow_queue", register_follow_queue) + + chunks: list[bytes] = [] + errors: list[BaseException] = [] + generator = iter_tail_ndjson(tail_lines=0, since_seq=0, follow=True) + + def read_next(): + try: + chunks.append(next(generator)) + except Exception as exc: # pragma: no cover - assertion reports details + errors.append(exc) + + thread = threading.Thread(target=read_next) + thread.start() + assert queue_registered.wait(timeout=2) + + ring.append("stdout", "new log", max_line_bytes=1024) + thread.join(timeout=2) + generator.close() + + assert errors == [] + assert len(chunks) == 1 + assert json.loads(chunks[0].decode())["line"] == "new log" + + def test_iter_tail_ndjson_follow_emits_tail_then_new_entries(self, monkeypatch): + import agentfield.node_logs as nl + + ring = ProcessLogRing(max_bytes=1024 * 1024) + for i in range(3): + ring.append("stdout", f"line{i}", max_line_bytes=1024) + monkeypatch.setattr(nl, "_global_ring", ring) + monkeypatch.setattr(nl, "_follow_queues", []) + queue_registered = threading.Event() + original_register_follow_queue = nl.register_follow_queue + + def register_follow_queue(q): + original_register_follow_queue(q) + queue_registered.set() + + monkeypatch.setattr(nl, "register_follow_queue", register_follow_queue) + + generator = iter_tail_ndjson(tail_lines=2, since_seq=0, follow=True) + prelude = [json.loads(next(generator).decode()) for _ in range(2)] + chunks: list[bytes] = [] + errors: list[BaseException] = [] + + def read_next(): + try: + chunks.append(next(generator)) + except Exception as exc: # pragma: no cover - assertion reports details + errors.append(exc) + + thread = threading.Thread(target=read_next) + thread.start() + assert queue_registered.wait(timeout=2) + + ring.append("stdout", "followed", max_line_bytes=1024) + thread.join(timeout=2) + generator.close() + + assert [entry["line"] for entry in prelude] == ["line1", "line2"] + assert errors == [] + assert len(chunks) == 1 + assert json.loads(chunks[0].decode())["line"] == "followed" + + def test_iter_tail_ndjson_unregisters_on_close(self, monkeypatch): + import agentfield.node_logs as nl + + class ClosingQueue: + def __init__(self, maxsize: int) -> None: + self.maxsize = maxsize + + def put_nowait(self, _item): + return None + + def get(self, timeout: float): + assert timeout == 0.5 + raise GeneratorExit + + ring = ProcessLogRing(max_bytes=1024 * 1024) + monkeypatch.setattr(nl, "_global_ring", ring) + monkeypatch.setattr(nl, "_follow_queues", []) + monkeypatch.setattr(nl.queue, "Queue", ClosingQueue) + + generator = iter_tail_ndjson(tail_lines=0, since_seq=0, follow=True) + with pytest.raises(GeneratorExit): + next(generator) + + assert nl._follow_queues == [] + + def test_iter_tail_ndjson_queue_timeout(self, monkeypatch): + import agentfield.node_logs as nl + + ring = ProcessLogRing(max_bytes=1024 * 1024) + + class TimeoutQueue: + def __init__(self, maxsize: int) -> None: + self.maxsize = maxsize + self._appended = False + + def put_nowait(self, _item): + return None + + def get(self, timeout: float): + assert timeout == 0.5 + if not self._appended: + self._appended = True + ring.append("stdout", "after timeout", max_line_bytes=1024) + raise queue.Empty + + monkeypatch.setattr(nl, "_global_ring", ring) + monkeypatch.setattr(nl, "_follow_queues", []) + monkeypatch.setattr(nl.queue, "Queue", TimeoutQueue) + + generator = iter_tail_ndjson(tail_lines=0, since_seq=0, follow=True) + try: + chunk = next(generator) + finally: + generator.close() + + assert json.loads(chunk.decode())["line"] == "after timeout" + + # --------------------------------------------------------------------------- # verify_internal_bearer # ---------------------------------------------------------------------------