diff --git a/tests/ops/__init__.py b/tests/ops/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/ops/test_config.py b/tests/ops/test_config.py new file mode 100644 index 0000000..27f55b6 --- /dev/null +++ b/tests/ops/test_config.py @@ -0,0 +1,159 @@ +"""AgentOps 配置模块测试""" + +from pathlib import Path + +from jojo_code.ops.config import OpsConfig + + +class TestOpsConfigDefaults: + """OpsConfig 默认值测试""" + + def test_default_enabled(self): + """默认应启用 Ops""" + config = OpsConfig() + assert config.enabled is True + + def test_default_persist_traces(self): + """默认应持久化 Trace""" + config = OpsConfig() + assert config.persist_traces is True + + def test_default_trace_dir(self): + """默认 Trace 目录应为 .jojo-code/traces""" + config = OpsConfig() + assert config.trace_dir == ".jojo-code/traces" + + def test_default_max_traces_in_memory(self): + """默认内存中最大 Trace 数量应为 1000""" + config = OpsConfig() + assert config.max_traces_in_memory == 1000 + + def test_default_real_time_display(self): + """默认不启用实时显示""" + config = OpsConfig() + assert config.real_time_display is False + + +class TestOpsConfigCustomValues: + """OpsConfig 自定义值测试""" + + def test_custom_enabled(self): + """应支持自定义 enabled""" + config = OpsConfig(enabled=False) + assert config.enabled is False + + def test_custom_persist_traces(self): + """应支持自定义 persist_traces""" + config = OpsConfig(persist_traces=False) + assert config.persist_traces is False + + def test_custom_trace_dir(self): + """应支持自定义 trace_dir""" + config = OpsConfig(trace_dir="/tmp/custom_traces") + assert config.trace_dir == "/tmp/custom_traces" + + def test_custom_max_traces_in_memory(self): + """应支持自定义 max_traces_in_memory""" + config = OpsConfig(max_traces_in_memory=500) + assert config.max_traces_in_memory == 500 + + def test_custom_real_time_display(self): + """应支持自定义 real_time_display""" + config = OpsConfig(real_time_display=True) + assert config.real_time_display is True + + +class TestOpsConfigFromEnv: + 
"""OpsConfig.from_env() 环境变量加载测试""" + + def test_from_env_defaults(self, monkeypatch): + """无环境变量时应使用默认值""" + # 清除所有相关环境变量 + for key in [ + "JOJO_CODE_OPS_ENABLED", + "JOJO_CODE_OPS_PERSIST", + "JOJO_CODE_OPS_TRACE_DIR", + "JOJO_CODE_OPS_MAX_TRACES", + "JOJO_CODE_OPS_REALTIME", + ]: + monkeypatch.delenv(key, raising=False) + + config = OpsConfig.from_env() + assert config.enabled is True + assert config.persist_traces is True + assert config.trace_dir == ".jojo-code/traces" + assert config.max_traces_in_memory == 1000 + assert config.real_time_display is False + + def test_from_env_enabled_false(self, monkeypatch): + """JOJO_CODE_OPS_ENABLED=false 应禁用""" + monkeypatch.setenv("JOJO_CODE_OPS_ENABLED", "false") + config = OpsConfig.from_env() + assert config.enabled is False + + def test_from_env_enabled_case_insensitive(self, monkeypatch): + """JOJO_CODE_OPS_ENABLED 应不区分大小写""" + monkeypatch.setenv("JOJO_CODE_OPS_ENABLED", "FALSE") + config = OpsConfig.from_env() + assert config.enabled is False + + def test_from_env_persist_false(self, monkeypatch): + """JOJO_CODE_OPS_PERSIST=false 应禁用持久化""" + monkeypatch.setenv("JOJO_CODE_OPS_PERSIST", "false") + config = OpsConfig.from_env() + assert config.persist_traces is False + + def test_from_env_custom_trace_dir(self, monkeypatch): + """JOJO_CODE_OPS_TRACE_DIR 应设置自定义目录""" + monkeypatch.setenv("JOJO_CODE_OPS_TRACE_DIR", "/tmp/my_traces") + config = OpsConfig.from_env() + assert config.trace_dir == "/tmp/my_traces" + + def test_from_env_custom_max_traces(self, monkeypatch): + """JOJO_CODE_OPS_MAX_TRACES 应设置自定义数量""" + monkeypatch.setenv("JOJO_CODE_OPS_MAX_TRACES", "500") + config = OpsConfig.from_env() + assert config.max_traces_in_memory == 500 + + def test_from_env_realtime_true(self, monkeypatch): + """JOJO_CODE_OPS_REALTIME=true 应启用实时显示""" + monkeypatch.setenv("JOJO_CODE_OPS_REALTIME", "true") + config = OpsConfig.from_env() + assert config.real_time_display is True + + def test_from_env_all_custom(self, monkeypatch): + 
"""应支持同时设置所有环境变量""" + monkeypatch.setenv("JOJO_CODE_OPS_ENABLED", "false") + monkeypatch.setenv("JOJO_CODE_OPS_PERSIST", "false") + monkeypatch.setenv("JOJO_CODE_OPS_TRACE_DIR", "/custom/path") + monkeypatch.setenv("JOJO_CODE_OPS_MAX_TRACES", "200") + monkeypatch.setenv("JOJO_CODE_OPS_REALTIME", "true") + + config = OpsConfig.from_env() + assert config.enabled is False + assert config.persist_traces is False + assert config.trace_dir == "/custom/path" + assert config.max_traces_in_memory == 200 + assert config.real_time_display is True + + +class TestOpsConfigGetTracePath: + """OpsConfig.get_trace_path() 测试""" + + def test_get_trace_path_returns_path(self): + """get_trace_path 应返回 Path 对象""" + config = OpsConfig() + path = config.get_trace_path() + assert isinstance(path, Path) + + def test_get_trace_path_default(self): + """默认路径应为 .jojo-code/traces""" + config = OpsConfig() + path = config.get_trace_path() + assert path == Path(".jojo-code/traces") + + def test_get_trace_path_custom(self): + """自定义路径应正确返回""" + config = OpsConfig(trace_dir="/tmp/custom") + path = config.get_trace_path() + assert path == Path("/tmp/custom") diff --git a/tests/ops/test_dashboard.py b/tests/ops/test_dashboard.py new file mode 100644 index 0000000..4748e40 --- /dev/null +++ b/tests/ops/test_dashboard.py @@ -0,0 +1,473 @@ +"""Tests for the ops.dashboard module.""" + +from datetime import datetime +from io import StringIO + +import pytest +from rich.console import Console + +from jojo_code.ops.dashboard import Dashboard +from jojo_code.ops.metrics import MetricsSummary, TraceMetrics +from jojo_code.ops.models import Span, SpanStatus, SpanType, Trace + + +@pytest.fixture +def dashboard(): + """Create a Dashboard with a string-buffered console.""" + buf = StringIO() + console = Console(file=buf, force_terminal=True, width=120) + d = Dashboard() + d.console = console + return d, buf + + +@pytest.fixture +def sample_trace() -> Trace: + """Create a sample Trace for dashboard testing.""" + 
trace = Trace( + id="dash-trace-001", + session_id="dash-session", + task="Read README.md and summarize the content", + start_time=datetime(2026, 1, 1, 10, 0, 0), + end_time=datetime(2026, 1, 1, 10, 0, 5), + status=SpanStatus.COMPLETED, + ) + trace.spans = [ + Span( + id="span-1", + trace_id="dash-trace-001", + type=SpanType.THINKING, + name="thinking", + status=SpanStatus.COMPLETED, + start_time=datetime(2026, 1, 1, 10, 0, 0), + end_time=datetime(2026, 1, 1, 10, 0, 1), + ), + Span( + id="span-2", + trace_id="dash-trace-001", + type=SpanType.TOOL_CALL, + name="read_file", + status=SpanStatus.COMPLETED, + start_time=datetime(2026, 1, 1, 10, 0, 1), + end_time=datetime(2026, 1, 1, 10, 0, 2), + ), + Span( + id="span-3", + trace_id="dash-trace-001", + type=SpanType.OBSERVE, + name="observe", + status=SpanStatus.COMPLETED, + start_time=datetime(2026, 1, 1, 10, 0, 2), + end_time=datetime(2026, 1, 1, 10, 0, 3), + ), + Span( + id="span-4", + trace_id="dash-trace-001", + type=SpanType.ERROR, + name="error", + status=SpanStatus.FAILED, + error="Something went wrong", + start_time=datetime(2026, 1, 1, 10, 0, 3), + end_time=datetime(2026, 1, 1, 10, 0, 4), + ), + Span( + id="span-5", + trace_id="dash-trace-001", + type=SpanType.TOOL_CALL, + name="write_file", + status=SpanStatus.FAILED, + start_time=datetime(2026, 1, 1, 10, 0, 4), + end_time=datetime(2026, 1, 1, 10, 0, 5), + ), + ] + return trace + + +@pytest.fixture +def sample_metrics() -> MetricsSummary: + """Create a sample MetricsSummary.""" + return MetricsSummary( + total_traces=20, + completed_traces=16, + failed_traces=4, + avg_thinking_rounds=3.2, + avg_tool_calls=5.1, + avg_duration_ms=4500.0, + tool_success_rate=0.88, + task_success_rate=0.8, + tool_usage={"read_file": 30, "write_file": 10, "grep_search": 20, "shell": 5}, + error_types={ + "PermissionError: access denied": 3, + "TimeoutError": 2, + "FileNotFoundError": 1, + }, + start_time=datetime(2026, 1, 1, 0, 0, 0), + end_time=datetime(2026, 1, 1, 23, 59, 59), + 
) + + +@pytest.fixture +def sample_trace_metrics() -> TraceMetrics: + """Create a sample TraceMetrics.""" + return TraceMetrics( + trace_id="tm-001", + task="Search for TODOs", + status=SpanStatus.COMPLETED, + thinking_rounds=2, + tool_calls=3, + errors=0, + duration_ms=2500, + tool_success_rate=1.0, + tools_used=["grep_search", "read_file", "read_file"], + ) + + +class TestDashboardShowCurrentTrace: + """Tests for Dashboard.show_current_trace.""" + + def test_does_not_raise(self, dashboard, sample_trace): + """show_current_trace should not raise exceptions.""" + d, buf = dashboard + d.show_current_trace(sample_trace) + output = buf.getvalue() + assert len(output) > 0 + + def test_contains_task_name(self, dashboard, sample_trace): + """Output should contain the task name.""" + d, buf = dashboard + d.show_current_trace(sample_trace) + output = buf.getvalue() + assert "Read README.md" in output + + def test_contains_span_types(self, dashboard, sample_trace): + """Output should contain span type indicators.""" + d, buf = dashboard + d.show_current_trace(sample_trace) + output = buf.getvalue() + assert "thinking" in output + assert "tool_call" in output + + def test_contains_summary_panel(self, dashboard, sample_trace): + """Output should contain the summary panel with counts.""" + d, buf = dashboard + d.show_current_trace(sample_trace) + output = buf.getvalue() + assert "汇总" in output + + def test_handles_trace_with_no_spans(self, dashboard): + """Should handle trace with no spans gracefully.""" + d, buf = dashboard + trace = Trace( + id="empty", + task="Empty task", + start_time=datetime(2026, 1, 1), + end_time=datetime(2026, 1, 1, 0, 0, 1), + status=SpanStatus.COMPLETED, + ) + d.show_current_trace(trace) + output = buf.getvalue() + assert "Empty task" in output + + +class TestDashboardShowMetrics: + """Tests for Dashboard.show_metrics.""" + + def test_does_not_raise(self, dashboard, sample_metrics): + """show_metrics should not raise exceptions.""" + d, buf = 
dashboard + d.show_metrics(sample_metrics) + output = buf.getvalue() + assert len(output) > 0 + + def test_contains_success_rates(self, dashboard, sample_metrics): + """Output should contain success rate values.""" + d, buf = dashboard + d.show_metrics(sample_metrics) + output = buf.getvalue() + assert "80.0%" in output + assert "88.0%" in output + + def test_contains_tool_usage_table(self, dashboard, sample_metrics): + """Output should contain tool usage statistics.""" + d, buf = dashboard + d.show_metrics(sample_metrics) + output = buf.getvalue() + assert "read_file" in output + assert "write_file" in output + assert "grep_search" in output + + def test_contains_error_table(self, dashboard, sample_metrics): + """Output should contain error statistics.""" + d, buf = dashboard + d.show_metrics(sample_metrics) + output = buf.getvalue() + assert "PermissionError" in output + + def test_no_error_table_when_no_errors(self, dashboard): + """Should not show error table when there are no errors.""" + d, buf = dashboard + metrics = MetricsSummary( + total_traces=5, + completed_traces=5, + failed_traces=0, + tool_usage={"read_file": 5}, + ) + d.show_metrics(metrics) + output = buf.getvalue() + # Should still have tool usage but no error table header + assert "read_file" in output + + def test_no_tool_table_when_no_tools(self, dashboard): + """Should not show tool table when there is no tool usage.""" + d, buf = dashboard + metrics = MetricsSummary( + total_traces=5, + completed_traces=5, + failed_traces=0, + ) + d.show_metrics(metrics) + output = buf.getvalue() + assert len(output) > 0 + + +class TestDashboardShowTraceMetrics: + """Tests for Dashboard.show_trace_metrics.""" + + def test_does_not_raise(self, dashboard, sample_trace_metrics): + """show_trace_metrics should not raise exceptions.""" + d, buf = dashboard + d.show_trace_metrics(sample_trace_metrics) + output = buf.getvalue() + assert len(output) > 0 + + def test_contains_trace_id(self, dashboard, 
sample_trace_metrics): + """Output should contain the trace ID.""" + d, buf = dashboard + d.show_trace_metrics(sample_trace_metrics) + output = buf.getvalue() + assert "tm-001" in output + + def test_contains_task(self, dashboard, sample_trace_metrics): + """Output should contain the task description.""" + d, buf = dashboard + d.show_trace_metrics(sample_trace_metrics) + output = buf.getvalue() + assert "Search for TODOs" in output + + def test_contains_metrics_values(self, dashboard, sample_trace_metrics): + """Output should contain the metrics values.""" + d, buf = dashboard + d.show_trace_metrics(sample_trace_metrics) + output = buf.getvalue() + assert "2500ms" in output + assert "100.0%" in output + + def test_contains_tools_used(self, dashboard, sample_trace_metrics): + """Output should contain the tools used list.""" + d, buf = dashboard + d.show_trace_metrics(sample_trace_metrics) + output = buf.getvalue() + assert "grep_search" in output + assert "read_file" in output + + def test_shows_completed_status_in_green(self, dashboard, sample_trace_metrics): + """Completed status should be styled green.""" + d, buf = dashboard + d.show_trace_metrics(sample_trace_metrics) + output = buf.getvalue() + assert "completed" in output + + def test_shows_failed_status_in_red(self, dashboard): + """Failed status should be styled red.""" + d, buf = dashboard + metrics = TraceMetrics( + trace_id="tm-fail", + task="Failed task", + status=SpanStatus.FAILED, + tools_used=[], + ) + d.show_trace_metrics(metrics) + output = buf.getvalue() + assert "failed" in output + + def test_empty_tools_shows_dash(self, dashboard): + """Should show dash when no tools are used.""" + d, buf = dashboard + metrics = TraceMetrics( + trace_id="tm-empty", + task="No tools task", + status=SpanStatus.COMPLETED, + tools_used=[], + ) + d.show_trace_metrics(metrics) + output = buf.getvalue() + assert "-" in output + + +class TestDashboardShowTracesList: + """Tests for Dashboard.show_traces_list.""" + + def 
test_does_not_raise(self, dashboard, sample_trace): + """show_traces_list should not raise exceptions.""" + d, buf = dashboard + d.show_traces_list([sample_trace]) + output = buf.getvalue() + assert len(output) > 0 + + def test_contains_trace_id(self, dashboard, sample_trace): + """Output should contain trace IDs.""" + d, buf = dashboard + d.show_traces_list([sample_trace]) + output = buf.getvalue() + assert "dash-trace-001" in output + + def test_truncates_long_task_name(self, dashboard): + """Should truncate task names longer than 30 characters.""" + d, buf = dashboard + trace = Trace( + id="long-task", + task="A" * 50, + start_time=datetime(2026, 1, 1), + end_time=datetime(2026, 1, 1, 0, 0, 1), + status=SpanStatus.COMPLETED, + ) + d.show_traces_list([trace]) + output = buf.getvalue() + assert "..." in output + + def test_respects_limit(self, dashboard): + """Should respect the limit parameter.""" + d, buf = dashboard + traces = [] + for i in range(20): + t = Trace( + id=f"trace-{i:03d}", + task=f"Task {i}", + start_time=datetime(2026, 1, 1), + end_time=datetime(2026, 1, 1, 0, 0, 1), + status=SpanStatus.COMPLETED, + ) + traces.append(t) + d.show_traces_list(traces, limit=5) + output = buf.getvalue() + assert "最近 5 个任务" in output + + def test_shows_status_styles(self, dashboard): + """Should show different status styles.""" + d, buf = dashboard + traces = [ + Trace( + id="completed", + task="Done", + start_time=datetime(2026, 1, 1), + end_time=datetime(2026, 1, 1, 0, 0, 1), + status=SpanStatus.COMPLETED, + ), + Trace( + id="failed", + task="Oops", + start_time=datetime(2026, 1, 1), + end_time=datetime(2026, 1, 1, 0, 0, 1), + status=SpanStatus.FAILED, + ), + Trace( + id="started", + task="Running", + start_time=datetime(2026, 1, 1), + status=SpanStatus.STARTED, + ), + ] + d.show_traces_list(traces) + output = buf.getvalue() + assert "completed" in output + assert "failed" in output + assert "started" in output + + +class TestDashboardShowSummaryReport: + """Tests 
for Dashboard.show_summary_report.""" + + def test_does_not_raise(self, dashboard, sample_metrics): + """show_summary_report should not raise exceptions.""" + d, buf = dashboard + d.show_summary_report(sample_metrics) + output = buf.getvalue() + assert len(output) > 0 + + def test_contains_report_title(self, dashboard, sample_metrics): + """Output should contain the report title.""" + d, buf = dashboard + d.show_summary_report(sample_metrics) + output = buf.getvalue() + assert "AgentOps 汇总报告" in output + + def test_contains_performance_stats(self, dashboard, sample_metrics): + """Output should contain performance statistics.""" + d, buf = dashboard + d.show_summary_report(sample_metrics) + output = buf.getvalue() + assert "性能统计" in output + assert "16" in output # completed + assert "4" in output # failed + + def test_skips_performance_panel_when_no_traces(self, dashboard): + """Should skip performance stats when total_traces is 0.""" + d, buf = dashboard + metrics = MetricsSummary(total_traces=0) + d.show_summary_report(metrics) + output = buf.getvalue() + # Should still have title but no performance panel + assert "AgentOps 汇总报告" in output + + +class TestDashboardHelperMethods: + """Tests for Dashboard helper methods.""" + + def test_create_progress_bar_full(self, dashboard): + """Full progress bar should be all filled.""" + d, _ = dashboard + bar = d._create_progress_bar(1.0, width=10) + # Rich markup tags wrap each block, so check individual blocks + assert bar.count("█") == 10 # filled blocks + + def test_create_progress_bar_empty(self, dashboard): + """Empty progress bar should be all empty.""" + d, _ = dashboard + bar = d._create_progress_bar(0.0, width=10) + # Rich markup tags wrap each block, so check individual blocks + assert bar.count("░") == 10 # empty blocks + + def test_create_progress_bar_half(self, dashboard): + """Half progress bar should be half filled.""" + d, _ = dashboard + bar = d._create_progress_bar(0.5, width=10) + assert bar.count("█") == 
5 + assert bar.count("░") == 5 + + def test_print_error(self, dashboard): + """print_error should output an error message.""" + d, buf = dashboard + d.print_error("Something failed") + output = buf.getvalue() + assert "Something failed" in output + + def test_print_success(self, dashboard): + """print_success should output a success message.""" + d, buf = dashboard + d.print_success("Operation completed") + output = buf.getvalue() + assert "Operation completed" in output + + def test_print_info(self, dashboard): + """print_info should output an info message.""" + d, buf = dashboard + d.print_info("Processing data") + output = buf.getvalue() + assert "Processing data" in output + + def test_print_warning(self, dashboard): + """print_warning should output a warning message.""" + d, buf = dashboard + d.print_warning("Low memory") + output = buf.getvalue() + assert "Low memory" in output diff --git a/tests/ops/test_exporter.py b/tests/ops/test_exporter.py new file mode 100644 index 0000000..51915af --- /dev/null +++ b/tests/ops/test_exporter.py @@ -0,0 +1,337 @@ +"""Tests for the ops.exporter module.""" + +import json +from datetime import datetime + +import pytest + +from jojo_code.ops.exporter import Exporter +from jojo_code.ops.models import Span, SpanStatus, SpanType, Trace + + +@pytest.fixture +def sample_trace() -> Trace: + """Create a sample Trace for export testing.""" + trace = Trace( + id="export-trace-001", + session_id="export-session", + task="Test export task", + start_time=datetime(2026, 1, 1, 10, 0, 0), + end_time=datetime(2026, 1, 1, 10, 0, 2), + status=SpanStatus.COMPLETED, + ) + trace.spans = [ + Span( + id="span-1", + trace_id="export-trace-001", + type=SpanType.THINKING, + name="thinking", + status=SpanStatus.COMPLETED, + start_time=datetime(2026, 1, 1, 10, 0, 0), + end_time=datetime(2026, 1, 1, 10, 0, 1), + ), + Span( + id="span-2", + trace_id="export-trace-001", + type=SpanType.TOOL_CALL, + name="read_file", + input={"path": "README.md"}, + 
output="File content here", + status=SpanStatus.COMPLETED, + start_time=datetime(2026, 1, 1, 10, 0, 1), + end_time=datetime(2026, 1, 1, 10, 0, 2), + ), + ] + return trace + + +@pytest.fixture +def sample_traces(sample_trace) -> list[Trace]: + """Create a list of sample traces.""" + trace2 = Trace( + id="export-trace-002", + session_id="export-session", + task="Second task", + start_time=datetime(2026, 1, 2, 10, 0, 0), + end_time=datetime(2026, 1, 2, 10, 0, 3), + status=SpanStatus.FAILED, + ) + trace2.spans = [ + Span( + id="span-3", + trace_id="export-trace-002", + type=SpanType.TOOL_CALL, + name="write_file", + status=SpanStatus.FAILED, + error="PermissionError", + start_time=datetime(2026, 1, 2, 10, 0, 0), + end_time=datetime(2026, 1, 2, 10, 0, 1), + ), + ] + return [sample_trace, trace2] + + +class TestExportTracesJson: + """Tests for Exporter.export_traces_json.""" + + def test_creates_valid_json(self, sample_traces, tmp_path): + """Exported JSON should be valid and parseable.""" + output = tmp_path / "traces.json" + Exporter.export_traces_json(sample_traces, str(output)) + + assert output.exists() + data = json.loads(output.read_text(encoding="utf-8")) + assert isinstance(data, list) + assert len(data) == 2 + + def test_exported_data_has_trace_fields(self, sample_traces, tmp_path): + """Exported JSON should contain expected trace fields.""" + output = tmp_path / "traces.json" + Exporter.export_traces_json(sample_traces, str(output)) + + data = json.loads(output.read_text(encoding="utf-8")) + trace = data[0] + assert "id" in trace + assert "task" in trace + assert "status" in trace + assert "spans" in trace + assert "duration_ms" in trace + + def test_exported_data_has_span_fields(self, sample_traces, tmp_path): + """Exported JSON should contain expected span fields.""" + output = tmp_path / "traces.json" + Exporter.export_traces_json(sample_traces, str(output)) + + data = json.loads(output.read_text(encoding="utf-8")) + span = data[0]["spans"][0] + assert "id" 
in span + assert "type" in span + assert "name" in span + assert "status" in span + + def test_export_preserves_trace_ids(self, sample_traces, tmp_path): + """Exported JSON should preserve original trace IDs.""" + output = tmp_path / "traces.json" + Exporter.export_traces_json(sample_traces, str(output)) + + data = json.loads(output.read_text(encoding="utf-8")) + ids = [t["id"] for t in data] + assert "export-trace-001" in ids + assert "export-trace-002" in ids + + def test_export_empty_list(self, tmp_path): + """Exporting empty list should produce empty JSON array.""" + output = tmp_path / "empty.json" + Exporter.export_traces_json([], str(output)) + + data = json.loads(output.read_text(encoding="utf-8")) + assert data == [] + + def test_export_with_unicode(self, tmp_path): + """Export should handle Unicode characters correctly.""" + trace = Trace( + id="unicode-trace", + task="读取中文文件", + start_time=datetime(2026, 1, 1), + end_time=datetime(2026, 1, 1, 0, 0, 1), + status=SpanStatus.COMPLETED, + ) + output = tmp_path / "unicode.json" + Exporter.export_traces_json([trace], str(output)) + + data = json.loads(output.read_text(encoding="utf-8")) + assert data[0]["task"] == "读取中文文件" + + def test_export_preserves_error_info(self, sample_traces, tmp_path): + """Exported JSON should preserve error information in spans.""" + output = tmp_path / "traces.json" + Exporter.export_traces_json(sample_traces, str(output)) + + data = json.loads(output.read_text(encoding="utf-8")) + failed_span = data[1]["spans"][0] + assert failed_span["error"] == "PermissionError" + + +class TestExportTraceJson: + """Tests for Exporter.export_trace_json.""" + + def test_creates_valid_json(self, sample_trace, tmp_path): + """Single trace export should produce valid JSON object.""" + output = tmp_path / "single.json" + Exporter.export_trace_json(sample_trace, str(output)) + + assert output.exists() + data = json.loads(output.read_text(encoding="utf-8")) + assert isinstance(data, dict) + + def 
test_exported_has_correct_id(self, sample_trace, tmp_path): + """Exported trace should have correct ID.""" + output = tmp_path / "single.json" + Exporter.export_trace_json(sample_trace, str(output)) + + data = json.loads(output.read_text(encoding="utf-8")) + assert data["id"] == "export-trace-001" + + def test_exported_has_summary(self, sample_trace, tmp_path): + """Exported trace should contain summary section.""" + output = tmp_path / "single.json" + Exporter.export_trace_json(sample_trace, str(output)) + + data = json.loads(output.read_text(encoding="utf-8")) + assert "summary" in data + summary = data["summary"] + assert "thinking_count" in summary + assert "tool_call_count" in summary + assert "error_count" in summary + assert "tool_success_rate" in summary + + def test_export_preserves_span_input_output(self, sample_trace, tmp_path): + """Exported trace should preserve span input/output data.""" + output = tmp_path / "single.json" + Exporter.export_trace_json(sample_trace, str(output)) + + data = json.loads(output.read_text(encoding="utf-8")) + tool_span = next(s for s in data["spans"] if s["type"] == "tool_call") + assert tool_span["input"] == {"path": "README.md"} + assert tool_span["output"] == "File content here" + + +class TestExportSummaryMarkdown: + """Tests for Exporter.export_summary_markdown.""" + + def test_returns_string(self): + """Export should return a string.""" + report = Exporter.export_summary_markdown( + total_traces=10, + completed_traces=8, + failed_traces=2, + avg_thinking_rounds=2.5, + avg_tool_calls=4.0, + avg_duration_ms=3000.0, + tool_success_rate=0.95, + task_success_rate=0.8, + tool_usage={"read_file": 10, "write_file": 5}, + ) + assert isinstance(report, str) + + def test_contains_metrics(self): + """Report should contain the provided metrics.""" + report = Exporter.export_summary_markdown( + total_traces=10, + completed_traces=8, + failed_traces=2, + avg_thinking_rounds=2.5, + avg_tool_calls=4.0, + avg_duration_ms=3000.0, + 
tool_success_rate=0.95, + task_success_rate=0.8, + tool_usage={"read_file": 10}, + ) + assert "10" in report + assert "8" in report + assert "80.00%" in report + assert "95.00%" in report + + def test_contains_tool_usage_table(self): + """Report should contain tool usage statistics.""" + report = Exporter.export_summary_markdown( + total_traces=10, + completed_traces=8, + failed_traces=2, + avg_thinking_rounds=2.5, + avg_tool_calls=4.0, + avg_duration_ms=3000.0, + tool_success_rate=0.95, + task_success_rate=0.8, + tool_usage={"read_file": 10, "grep_search": 15}, + ) + assert "read_file" in report + assert "grep_search" in report + + def test_contains_error_types_when_provided(self): + """Report should contain error types when provided.""" + report = Exporter.export_summary_markdown( + total_traces=10, + completed_traces=8, + failed_traces=2, + avg_thinking_rounds=2.5, + avg_tool_calls=4.0, + avg_duration_ms=3000.0, + tool_success_rate=0.95, + task_success_rate=0.8, + tool_usage={"read_file": 10}, + error_types={"TimeoutError": 3, "PermissionError": 1}, + ) + assert "TimeoutError" in report + assert "PermissionError" in report + + def test_no_error_section_when_no_errors(self): + """Report should not have error section when error_types is None.""" + report = Exporter.export_summary_markdown( + total_traces=10, + completed_traces=8, + failed_traces=2, + avg_thinking_rounds=2.5, + avg_tool_calls=4.0, + avg_duration_ms=3000.0, + tool_success_rate=0.95, + task_success_rate=0.8, + tool_usage={"read_file": 10}, + error_types=None, + ) + assert "错误统计" not in report + + def test_writes_to_file(self, tmp_path): + """Report should be written to file when output_path is given.""" + output = tmp_path / "report.md" + report = Exporter.export_summary_markdown( + total_traces=10, + completed_traces=8, + failed_traces=2, + avg_thinking_rounds=2.5, + avg_tool_calls=4.0, + avg_duration_ms=3000.0, + tool_success_rate=0.95, + task_success_rate=0.8, + tool_usage={"read_file": 10}, + 
output_path=str(output), + ) + assert output.exists() + assert output.read_text(encoding="utf-8") == report + + def test_contains_header(self): + """Report should contain the main header.""" + report = Exporter.export_summary_markdown( + total_traces=10, + completed_traces=8, + failed_traces=2, + avg_thinking_rounds=2.5, + avg_tool_calls=4.0, + avg_duration_ms=3000.0, + tool_success_rate=0.95, + task_success_rate=0.8, + tool_usage={}, + ) + assert "AgentOps 报告" in report + + def test_tool_usage_sorted_by_count(self): + """Tool usage should be sorted by count descending.""" + report = Exporter.export_summary_markdown( + total_traces=10, + completed_traces=8, + failed_traces=2, + avg_thinking_rounds=2.5, + avg_tool_calls=4.0, + avg_duration_ms=3000.0, + tool_success_rate=0.95, + task_success_rate=0.8, + tool_usage={"read_file": 5, "write_file": 15, "grep_search": 10}, + ) + lines = report.split("\n") + tool_lines = [ + line for line in lines + if "|" in line + and ("read_file" in line or "write_file" in line or "grep_search" in line) + ] + # write_file (15) should appear before grep_search (10) before read_file (5) + assert "write_file" in tool_lines[0] diff --git a/tests/ops/test_metrics.py b/tests/ops/test_metrics.py new file mode 100644 index 0000000..956119c --- /dev/null +++ b/tests/ops/test_metrics.py @@ -0,0 +1,406 @@ +"""Tests for AgentOps metrics module.""" + +from datetime import datetime, timedelta + +from jojo_code.ops.metrics import MetricsEngine, MetricsSummary, TraceMetrics +from jojo_code.ops.models import Span, SpanStatus, SpanType, Trace + + +def _make_span( + span_type: SpanType = SpanType.TOOL_CALL, + name: str = "read_file", + status: SpanStatus = SpanStatus.COMPLETED, + error: str | None = None, + duration_ms: int = 100, +) -> Span: + """Helper to create a Span with controlled timing.""" + start = datetime(2026, 1, 1, 12, 0, 0) + end = start + timedelta(milliseconds=duration_ms) + return Span( + trace_id="test-trace", + type=span_type, + 
name=name, + status=status, + error=error, + start_time=start, + end_time=end, + ) + + +def _make_trace( + task: str = "test task", + status: SpanStatus = SpanStatus.COMPLETED, + spans: list[Span] | None = None, + session_id: str = "session-1", + start_offset_seconds: int = 0, + duration_ms: int = 500, +) -> Trace: + """Helper to create a Trace with controlled timing.""" + start = datetime(2026, 1, 1, 12, 0, 0) + timedelta(seconds=start_offset_seconds) + end = start + timedelta(milliseconds=duration_ms) + return Trace( + id=f"trace-{start_offset_seconds}", + session_id=session_id, + task=task, + spans=spans or [], + start_time=start, + end_time=end, + status=status, + ) + + +class TestMetricsSummary: + """Tests for MetricsSummary dataclass.""" + + def test_default_values(self): + summary = MetricsSummary() + assert summary.total_traces == 0 + assert summary.completed_traces == 0 + assert summary.failed_traces == 0 + assert summary.avg_thinking_rounds == 0.0 + assert summary.avg_tool_calls == 0.0 + assert summary.avg_duration_ms == 0.0 + assert summary.tool_success_rate == 0.0 + assert summary.task_success_rate == 0.0 + assert summary.tool_usage == {} + assert summary.error_types == {} + assert summary.start_time is None + assert summary.end_time is None + + def test_to_dict(self): + start = datetime(2026, 1, 1, 12, 0, 0) + end = datetime(2026, 1, 1, 13, 0, 0) + summary = MetricsSummary( + total_traces=10, + completed_traces=8, + failed_traces=2, + avg_thinking_rounds=3.5, + avg_tool_calls=5.2, + avg_duration_ms=1500.0, + tool_success_rate=0.95, + task_success_rate=0.8, + tool_usage={"read_file": 10, "write_file": 5}, + error_types={"timeout": 2}, + start_time=start, + end_time=end, + ) + d = summary.to_dict() + assert d["total_traces"] == 10 + assert d["completed_traces"] == 8 + assert d["failed_traces"] == 2 + assert d["avg_thinking_rounds"] == 3.5 + assert d["avg_tool_calls"] == 5.2 + assert d["avg_duration_ms"] == 1500.0 + assert d["tool_success_rate"] == 0.95 + 
assert d["task_success_rate"] == 0.8 + assert d["tool_usage"] == {"read_file": 10, "write_file": 5} + assert d["error_types"] == {"timeout": 2} + assert d["start_time"] == start.isoformat() + assert d["end_time"] == end.isoformat() + + def test_to_dict_none_times(self): + summary = MetricsSummary() + d = summary.to_dict() + assert d["start_time"] is None + assert d["end_time"] is None + + +class TestTraceMetrics: + """Tests for TraceMetrics dataclass.""" + + def test_default_values(self): + tm = TraceMetrics(trace_id="t1", task="test", status=SpanStatus.COMPLETED) + assert tm.trace_id == "t1" + assert tm.task == "test" + assert tm.status == SpanStatus.COMPLETED + assert tm.thinking_rounds == 0 + assert tm.tool_calls == 0 + assert tm.errors == 0 + assert tm.duration_ms == 0 + assert tm.tool_success_rate == 1.0 + assert tm.tools_used == [] + + def test_to_dict(self): + tm = TraceMetrics( + trace_id="t1", + task="read file", + status=SpanStatus.COMPLETED, + thinking_rounds=2, + tool_calls=3, + errors=1, + duration_ms=500, + tool_success_rate=0.67, + tools_used=["read_file", "write_file"], + ) + d = tm.to_dict() + assert d["trace_id"] == "t1" + assert d["task"] == "read file" + assert d["status"] == "completed" + assert d["thinking_rounds"] == 2 + assert d["tool_calls"] == 3 + assert d["errors"] == 1 + assert d["duration_ms"] == 500 + assert d["tool_success_rate"] == 0.67 + assert d["tools_used"] == ["read_file", "write_file"] + + +class TestMetricsEngine: + """Tests for MetricsEngine.""" + + def test_empty_traces(self): + engine = MetricsEngine([]) + summary = engine.calculate() + assert summary.total_traces == 0 + assert summary.completed_traces == 0 + assert summary.failed_traces == 0 + + def test_single_completed_trace(self): + spans = [ + _make_span(SpanType.THINKING, "thinking", SpanStatus.COMPLETED), + _make_span(SpanType.TOOL_CALL, "read_file", SpanStatus.COMPLETED), + ] + trace = _make_trace(status=SpanStatus.COMPLETED, spans=spans, duration_ms=200) + engine = 
MetricsEngine([trace]) + summary = engine.calculate() + + assert summary.total_traces == 1 + assert summary.completed_traces == 1 + assert summary.failed_traces == 0 + assert summary.task_success_rate == 1.0 + assert summary.avg_duration_ms == 200.0 + assert summary.tool_usage == {"read_file": 1} + + def test_multiple_traces_mixed_status(self): + spans_ok = [ + _make_span(SpanType.TOOL_CALL, "read_file", SpanStatus.COMPLETED), + ] + spans_fail = [ + _make_span(SpanType.TOOL_CALL, "write_file", SpanStatus.FAILED), + ] + trace1 = _make_trace(status=SpanStatus.COMPLETED, spans=spans_ok, duration_ms=100) + trace2 = _make_trace( + status=SpanStatus.FAILED, + spans=spans_fail, + start_offset_seconds=10, + duration_ms=200, + ) + engine = MetricsEngine([trace1, trace2]) + summary = engine.calculate() + + assert summary.total_traces == 2 + assert summary.completed_traces == 1 + assert summary.failed_traces == 1 + assert summary.task_success_rate == 0.5 + assert summary.avg_duration_ms == 150.0 + + def test_tool_usage_aggregation(self): + spans = [ + _make_span(SpanType.TOOL_CALL, "read_file", SpanStatus.COMPLETED), + _make_span(SpanType.TOOL_CALL, "read_file", SpanStatus.COMPLETED), + _make_span(SpanType.TOOL_CALL, "write_file", SpanStatus.COMPLETED), + ] + trace = _make_trace(spans=spans) + engine = MetricsEngine([trace]) + summary = engine.calculate() + + assert summary.tool_usage["read_file"] == 2 + assert summary.tool_usage["write_file"] == 1 + + def test_error_types_aggregation(self): + spans = [ + _make_span( + SpanType.TOOL_CALL, "read_file", + SpanStatus.FAILED, error="FileNotFoundError: not found", + ), + _make_span( + SpanType.TOOL_CALL, "write_file", + SpanStatus.FAILED, error="Permission denied", + ), + ] + trace = _make_trace(spans=spans) + engine = MetricsEngine([trace]) + summary = engine.calculate() + + assert "FileNotFoundError: not found" in summary.error_types + assert "Permission denied" in summary.error_types + + def test_time_range(self): + trace1 = 
_make_trace(start_offset_seconds=0, duration_ms=100) + trace2 = _make_trace(start_offset_seconds=60, duration_ms=200) + engine = MetricsEngine([trace1, trace2]) + summary = engine.calculate() + + assert summary.start_time == datetime(2026, 1, 1, 12, 0, 0) + assert summary.end_time == datetime(2026, 1, 1, 12, 1, 0) + timedelta(milliseconds=200) + + +class TestMetricsEngineTraceMetrics: + """Tests for calculate_trace_metrics.""" + + def test_basic_trace_metrics(self): + spans = [ + _make_span(SpanType.THINKING, "thinking", SpanStatus.COMPLETED), + _make_span(SpanType.TOOL_CALL, "read_file", SpanStatus.COMPLETED), + _make_span(SpanType.TOOL_CALL, "write_file", SpanStatus.COMPLETED), + ] + trace = _make_trace( + task="edit file", + status=SpanStatus.COMPLETED, + spans=spans, + duration_ms=300, + ) + engine = MetricsEngine([trace]) + tm = engine.calculate_trace_metrics(trace) + + assert tm.trace_id == trace.id + assert tm.task == "edit file" + assert tm.status == SpanStatus.COMPLETED + assert tm.thinking_rounds == 1 + assert tm.tool_calls == 2 + assert tm.errors == 0 + assert tm.duration_ms == 300 + assert tm.tools_used == ["read_file", "write_file"] + + +class TestMetricsEngineFilters: + """Tests for filter methods.""" + + def test_filter_by_time_start_only(self): + trace1 = _make_trace(start_offset_seconds=0) + trace2 = _make_trace(start_offset_seconds=120) + engine = MetricsEngine([trace1, trace2]) + + filtered = engine.filter_by_time(start=datetime(2026, 1, 1, 12, 1, 0)) + assert len(filtered) == 1 + assert filtered[0].id == "trace-120" + + def test_filter_by_time_end_only(self): + trace1 = _make_trace(start_offset_seconds=0) + trace2 = _make_trace(start_offset_seconds=120) + engine = MetricsEngine([trace1, trace2]) + + filtered = engine.filter_by_time(end=datetime(2026, 1, 1, 12, 0, 30)) + assert len(filtered) == 1 + assert filtered[0].id == "trace-0" + + def test_filter_by_time_range(self): + trace1 = _make_trace(start_offset_seconds=0) + trace2 = 
_make_trace(start_offset_seconds=60) + trace3 = _make_trace(start_offset_seconds=120) + engine = MetricsEngine([trace1, trace2, trace3]) + + filtered = engine.filter_by_time( + start=datetime(2026, 1, 1, 12, 0, 30), + end=datetime(2026, 1, 1, 12, 1, 30), + ) + assert len(filtered) == 1 + assert filtered[0].id == "trace-60" + + def test_filter_by_session(self): + trace1 = _make_trace(session_id="s1", start_offset_seconds=0) + trace2 = _make_trace(session_id="s2", start_offset_seconds=10) + trace3 = _make_trace(session_id="s1", start_offset_seconds=20) + engine = MetricsEngine([trace1, trace2, trace3]) + + filtered = engine.filter_by_session("s1") + assert len(filtered) == 2 + + def test_filter_by_status(self): + trace1 = _make_trace(status=SpanStatus.COMPLETED, start_offset_seconds=0) + trace2 = _make_trace(status=SpanStatus.FAILED, start_offset_seconds=10) + trace3 = _make_trace(status=SpanStatus.COMPLETED, start_offset_seconds=20) + engine = MetricsEngine([trace1, trace2, trace3]) + + assert len(engine.filter_by_status(SpanStatus.COMPLETED)) == 2 + assert len(engine.filter_by_status(SpanStatus.FAILED)) == 1 + + +class TestMetricsEngineRankings: + """Tests for ranking and distribution methods.""" + + def test_tool_usage_ranking(self): + spans = [ + _make_span(SpanType.TOOL_CALL, "read_file", SpanStatus.COMPLETED), + _make_span(SpanType.TOOL_CALL, "read_file", SpanStatus.COMPLETED), + _make_span(SpanType.TOOL_CALL, "write_file", SpanStatus.COMPLETED), + _make_span(SpanType.TOOL_CALL, "grep_search", SpanStatus.COMPLETED), + _make_span(SpanType.TOOL_CALL, "grep_search", SpanStatus.COMPLETED), + _make_span(SpanType.TOOL_CALL, "grep_search", SpanStatus.COMPLETED), + ] + trace = _make_trace(spans=spans) + engine = MetricsEngine([trace]) + + ranking = engine.get_tool_usage_ranking() + assert ranking[0] == ("grep_search", 3) + assert ranking[1] == ("read_file", 2) + assert ranking[2] == ("write_file", 1) + + def test_tool_usage_ranking_with_limit(self): + spans = [ + 
_make_span(SpanType.TOOL_CALL, f"tool_{i}", SpanStatus.COMPLETED) + for i in range(15) + ] + trace = _make_trace(spans=spans) + engine = MetricsEngine([trace]) + + ranking = engine.get_tool_usage_ranking(limit=5) + assert len(ranking) == 5 + + def test_error_distribution(self): + spans = [ + _make_span( + SpanType.TOOL_CALL, "read_file", + SpanStatus.FAILED, error="FileNotFoundError", + ), + _make_span( + SpanType.TOOL_CALL, "write_file", + SpanStatus.FAILED, error="PermissionError", + ), + _make_span( + SpanType.TOOL_CALL, "read_file", + SpanStatus.FAILED, error="FileNotFoundError", + ), + ] + trace = _make_trace(spans=spans) + engine = MetricsEngine([trace]) + + dist = engine.get_error_distribution() + assert dist["FileNotFoundError"] == 2 + assert dist["PermissionError"] == 1 + + def test_error_distribution_empty(self): + trace = _make_trace(spans=[]) + engine = MetricsEngine([trace]) + assert engine.get_error_distribution() == {} + + +class TestMetricsEnginePerformanceStats: + """Tests for get_performance_stats.""" + + def test_empty_traces(self): + engine = MetricsEngine([]) + stats = engine.get_performance_stats() + assert stats["min_duration_ms"] == 0 + assert stats["max_duration_ms"] == 0 + assert stats["median_duration_ms"] == 0 + assert stats["p95_duration_ms"] == 0 + + def test_single_trace(self): + trace = _make_trace(duration_ms=500) + engine = MetricsEngine([trace]) + stats = engine.get_performance_stats() + assert stats["min_duration_ms"] == 500 + assert stats["max_duration_ms"] == 500 + assert stats["median_duration_ms"] == 500 + assert stats["p95_duration_ms"] == 500 + + def test_multiple_traces(self): + traces = [_make_trace(duration_ms=i * 100, start_offset_seconds=i) for i in range(1, 11)] + engine = MetricsEngine(traces) + stats = engine.get_performance_stats() + + assert stats["min_duration_ms"] == 100 + assert stats["max_duration_ms"] == 1000 + # Implementation uses durations[n // 2], for n=10 that is index 5 = 600 + assert 
stats["median_duration_ms"] == 600 + assert stats["p95_duration_ms"] == 1000 diff --git a/tests/ops/test_models.py b/tests/ops/test_models.py new file mode 100644 index 0000000..bc4512e --- /dev/null +++ b/tests/ops/test_models.py @@ -0,0 +1,471 @@ +"""AgentOps 数据模型测试 - Span 和 Trace""" + +from datetime import datetime, timedelta + +from jojo_code.ops.models import Span, SpanStatus, SpanType, Trace + + +class TestSpanType: + """SpanType 枚举测试""" + + def test_thinking_value(self): + """THINKING 值应为 thinking""" + assert SpanType.THINKING.value == "thinking" + + def test_tool_call_value(self): + """TOOL_CALL 值应为 tool_call""" + assert SpanType.TOOL_CALL.value == "tool_call" + + def test_observe_value(self): + """OBSERVE 值应为 observe""" + assert SpanType.OBSERVE.value == "observe" + + def test_error_value(self): + """ERROR 值应为 error""" + assert SpanType.ERROR.value == "error" + + def test_all_types_count(self): + """应有 4 种 Span 类型""" + assert len(SpanType) == 4 + + +class TestSpanStatus: + """SpanStatus 枚举测试""" + + def test_started_value(self): + """STARTED 值应为 started""" + assert SpanStatus.STARTED.value == "started" + + def test_completed_value(self): + """COMPLETED 值应为 completed""" + assert SpanStatus.COMPLETED.value == "completed" + + def test_failed_value(self): + """FAILED 值应为 failed""" + assert SpanStatus.FAILED.value == "failed" + + def test_all_statuses_count(self): + """应有 3 种状态""" + assert len(SpanStatus) == 3 + + +class TestSpanDefaults: + """Span 默认值测试""" + + def test_default_type(self): + """默认类型应为 THINKING""" + span = Span() + assert span.type == SpanType.THINKING + + def test_default_status(self): + """默认状态应为 STARTED""" + span = Span() + assert span.status == SpanStatus.STARTED + + def test_default_parent_id(self): + """默认无父 Span""" + span = Span() + assert span.parent_id is None + + def test_default_error(self): + """默认无错误""" + span = Span() + assert span.error is None + + def test_default_end_time(self): + """默认无结束时间""" + span = Span() + assert 
span.end_time is None + + def test_default_metadata(self): + """默认元数据为空字典""" + span = Span() + assert span.metadata == {} + + def test_default_input_output(self): + """默认输入输出为 None""" + span = Span() + assert span.input is None + assert span.output is None + + def test_default_id_is_unique(self): + """每次创建应生成唯一 ID""" + span1 = Span() + span2 = Span() + assert span1.id != span2.id + + def test_default_start_time_is_recent(self): + """默认开始时间应为当前时间附近""" + before = datetime.now() + span = Span() + after = datetime.now() + assert before <= span.start_time <= after + + +class TestSpanCustomValues: + """Span 自定义值测试""" + + def test_custom_type(self): + """应支持自定义类型""" + span = Span(type=SpanType.TOOL_CALL) + assert span.type == SpanType.TOOL_CALL + + def test_custom_name(self): + """应支持自定义名称""" + span = Span(name="read_file") + assert span.name == "read_file" + + def test_custom_trace_id(self): + """应支持自定义 trace_id""" + span = Span(trace_id="abc123") + assert span.trace_id == "abc123" + + def test_custom_parent_id(self): + """应支持自定义 parent_id""" + span = Span(parent_id="parent-1") + assert span.parent_id == "parent-1" + + def test_custom_input_output(self): + """应支持自定义输入输出""" + span = Span(input={"path": "test.py"}, output="file content") + assert span.input == {"path": "test.py"} + assert span.output == "file content" + + +class TestSpanDurationMs: + """Span.duration_ms 属性测试""" + + def test_duration_no_end_time(self): + """无结束时间时 duration_ms 应为 0""" + span = Span() + assert span.duration_ms == 0 + + def test_duration_with_end_time(self): + """有结束时间时应计算毫秒差""" + start = datetime(2026, 1, 1, 12, 0, 0) + end = start + timedelta(seconds=2, milliseconds=500) + span = Span(start_time=start, end_time=end) + assert span.duration_ms == 2500 + + def test_duration_zero_seconds(self): + """同一时间点应为 0 毫秒""" + now = datetime.now() + span = Span(start_time=now, end_time=now) + assert span.duration_ms == 0 + + +class TestSpanToDict: + """Span.to_dict() 测试""" + + def 
test_to_dict_contains_all_fields(self): + """to_dict 应包含所有必要字段""" + span = Span() + d = span.to_dict() + expected_keys = { + "id", "trace_id", "parent_id", "type", "name", + "input", "output", "error", "status", "start_time", + "end_time", "duration_ms", "metadata", + } + assert set(d.keys()) == expected_keys + + def test_to_dict_type_is_string(self): + """to_dict 中 type 应为字符串值""" + span = Span(type=SpanType.TOOL_CALL) + d = span.to_dict() + assert d["type"] == "tool_call" + + def test_to_dict_status_is_string(self): + """to_dict 中 status 应为字符串值""" + span = Span(status=SpanStatus.COMPLETED) + d = span.to_dict() + assert d["status"] == "completed" + + def test_to_dict_start_time_is_isoformat(self): + """to_dict 中 start_time 应为 ISO 格式""" + span = Span() + d = span.to_dict() + # 验证可以解析回 datetime + datetime.fromisoformat(d["start_time"]) + + def test_to_dict_end_time_none(self): + """无结束时间时 end_time 应为 None""" + span = Span() + d = span.to_dict() + assert d["end_time"] is None + + def test_to_dict_end_time_isoformat(self): + """有结束时间时应为 ISO 格式""" + now = datetime.now() + span = Span(end_time=now) + d = span.to_dict() + assert d["end_time"] == now.isoformat() + + def test_to_dict_serializes_complex_input(self): + """to_dict 应序列化复杂对象为字符串""" + span = Span(input=object()) + d = span.to_dict() + assert isinstance(d["input"], str) + + def test_to_dict_preserves_simple_types(self): + """to_dict 应保留简单类型的输入输出""" + span = Span(input="hello", output=42) + d = span.to_dict() + assert d["input"] == "hello" + assert d["output"] == 42 + + +class TestSpanSerialize: + """Span._serialize() 测试""" + + def test_serialize_none(self): + """None 应返回 None""" + span = Span() + assert span._serialize(None) is None + + def test_serialize_string(self): + """字符串应原样返回""" + span = Span() + assert span._serialize("hello") == "hello" + + def test_serialize_int(self): + """整数应原样返回""" + span = Span() + assert span._serialize(42) == 42 + + def test_serialize_float(self): + """浮点数应原样返回""" + span = Span() + 
assert span._serialize(3.14) == 3.14 + + def test_serialize_bool(self): + """布尔值应原样返回""" + span = Span() + assert span._serialize(True) is True + + def test_serialize_list(self): + """列表应原样返回""" + span = Span() + assert span._serialize([1, 2, 3]) == [1, 2, 3] + + def test_serialize_dict(self): + """字典应原样返回""" + span = Span() + assert span._serialize({"key": "value"}) == {"key": "value"} + + def test_serialize_complex_object(self): + """复杂对象应转为字符串""" + span = Span() + result = span._serialize(object()) + assert isinstance(result, str) + + +class TestTraceDefaults: + """Trace 默认值测试""" + + def test_default_status(self): + """默认状态应为 STARTED""" + trace = Trace() + assert trace.status == SpanStatus.STARTED + + def test_default_spans(self): + """默认 spans 为空列表""" + trace = Trace() + assert trace.spans == [] + + def test_default_end_time(self): + """默认无结束时间""" + trace = Trace() + assert trace.end_time is None + + def test_default_metadata(self): + """默认元数据为空字典""" + trace = Trace() + assert trace.metadata == {} + + def test_default_id_is_unique(self): + """每次创建应生成唯一 ID""" + trace1 = Trace() + trace2 = Trace() + assert trace1.id != trace2.id + + def test_default_start_time_is_recent(self): + """默认开始时间应为当前时间附近""" + before = datetime.now() + trace = Trace() + after = datetime.now() + assert before <= trace.start_time <= after + + +class TestTraceCustomValues: + """Trace 自定义值测试""" + + def test_custom_session_id(self): + """应支持自定义 session_id""" + trace = Trace(session_id="session-123") + assert trace.session_id == "session-123" + + def test_custom_task(self): + """应支持自定义 task""" + trace = Trace(task="读取 README.md") + assert trace.task == "读取 README.md" + + def test_custom_spans(self): + """应支持自定义 spans""" + spans = [Span(name="span1"), Span(name="span2")] + trace = Trace(spans=spans) + assert len(trace.spans) == 2 + + +class TestTraceDurationMs: + """Trace.duration_ms 属性测试""" + + def test_duration_no_end_time(self): + """无结束时间时 duration_ms 应为 0""" + trace = Trace() + assert 
trace.duration_ms == 0 + + def test_duration_with_end_time(self): + """有结束时间时应计算毫秒差""" + start = datetime(2026, 1, 1, 12, 0, 0) + end = start + timedelta(seconds=5) + trace = Trace(start_time=start, end_time=end) + assert trace.duration_ms == 5000 + + +class TestTraceSpanCounts: + """Trace span 计数属性测试""" + + def test_thinking_count_empty(self): + """无 span 时 thinking_count 应为 0""" + trace = Trace() + assert trace.thinking_count == 0 + + def test_thinking_count(self): + """应正确统计 thinking span 数量""" + trace = Trace(spans=[ + Span(type=SpanType.THINKING), + Span(type=SpanType.TOOL_CALL), + Span(type=SpanType.THINKING), + ]) + assert trace.thinking_count == 2 + + def test_tool_call_count_empty(self): + """无 span 时 tool_call_count 应为 0""" + trace = Trace() + assert trace.tool_call_count == 0 + + def test_tool_call_count(self): + """应正确统计 tool_call span 数量""" + trace = Trace(spans=[ + Span(type=SpanType.TOOL_CALL), + Span(type=SpanType.TOOL_CALL), + Span(type=SpanType.ERROR), + ]) + assert trace.tool_call_count == 2 + + def test_error_count_empty(self): + """无 span 时 error_count 应为 0""" + trace = Trace() + assert trace.error_count == 0 + + def test_error_count(self): + """应正确统计 error span 数量""" + trace = Trace(spans=[ + Span(type=SpanType.ERROR), + Span(type=SpanType.ERROR), + Span(type=SpanType.THINKING), + ]) + assert trace.error_count == 2 + + +class TestTraceToolSuccessRate: + """Trace.tool_success_rate 属性测试""" + + def test_success_rate_no_tool_calls(self): + """无工具调用时成功率应为 1.0""" + trace = Trace() + assert trace.tool_success_rate == 1.0 + + def test_success_rate_all_success(self): + """所有工具调用成功时应为 1.0""" + trace = Trace(spans=[ + Span(type=SpanType.TOOL_CALL, status=SpanStatus.COMPLETED), + Span(type=SpanType.TOOL_CALL, status=SpanStatus.COMPLETED), + ]) + assert trace.tool_success_rate == 1.0 + + def test_success_rate_partial_success(self): + """部分成功时应返回正确比率""" + trace = Trace(spans=[ + Span(type=SpanType.TOOL_CALL, status=SpanStatus.COMPLETED), + 
Span(type=SpanType.TOOL_CALL, status=SpanStatus.FAILED), + ]) + assert trace.tool_success_rate == 0.5 + + def test_success_rate_all_failed(self): + """全部失败时应为 0.0""" + trace = Trace(spans=[ + Span(type=SpanType.TOOL_CALL, status=SpanStatus.FAILED), + ]) + assert trace.tool_success_rate == 0.0 + + def test_success_rate_ignores_non_tool_spans(self): + """非工具调用 span 不应影响成功率""" + trace = Trace(spans=[ + Span(type=SpanType.THINKING, status=SpanStatus.STARTED), + Span(type=SpanType.TOOL_CALL, status=SpanStatus.COMPLETED), + Span(type=SpanType.ERROR, status=SpanStatus.FAILED), + ]) + assert trace.tool_success_rate == 1.0 + + +class TestTraceToDict: + """Trace.to_dict() 测试""" + + def test_to_dict_contains_all_fields(self): + """to_dict 应包含所有必要字段""" + trace = Trace() + d = trace.to_dict() + expected_keys = { + "id", "session_id", "task", "spans", "start_time", + "end_time", "duration_ms", "status", "metadata", "summary", + } + assert set(d.keys()) == expected_keys + + def test_to_dict_status_is_string(self): + """to_dict 中 status 应为字符串值""" + trace = Trace(status=SpanStatus.COMPLETED) + d = trace.to_dict() + assert d["status"] == "completed" + + def test_to_dict_spans_are_dicts(self): + """to_dict 中 spans 应为字典列表""" + trace = Trace(spans=[Span(name="test")]) + d = trace.to_dict() + assert len(d["spans"]) == 1 + assert isinstance(d["spans"][0], dict) + assert d["spans"][0]["name"] == "test" + + def test_to_dict_summary_contains_counts(self): + """to_dict 中 summary 应包含计数信息""" + trace = Trace(spans=[ + Span(type=SpanType.THINKING), + Span(type=SpanType.TOOL_CALL, status=SpanStatus.COMPLETED), + Span(type=SpanType.ERROR), + ]) + d = trace.to_dict() + summary = d["summary"] + assert summary["thinking_count"] == 1 + assert summary["tool_call_count"] == 1 + assert summary["error_count"] == 1 + assert summary["tool_success_rate"] == 1.0 + + def test_to_dict_empty_trace(self): + """空 trace 的 summary 应为零值""" + trace = Trace() + d = trace.to_dict() + summary = d["summary"] + assert 
summary["thinking_count"] == 0 + assert summary["tool_call_count"] == 0 + assert summary["error_count"] == 0 + assert summary["tool_success_rate"] == 1.0 diff --git a/tests/ops/test_report.py b/tests/ops/test_report.py index a902bf5..12362a8 100644 --- a/tests/ops/test_report.py +++ b/tests/ops/test_report.py @@ -1,220 +1,437 @@ -"""Report 单元测试""" +"""Tests for the ops.report module.""" from datetime import datetime import pytest -from jojo_code.ops import Span, SpanStatus, SpanType, Trace from jojo_code.ops.evaluator import EvaluationResult, EvaluationScore from jojo_code.ops.metrics import MetricsSummary +from jojo_code.ops.models import Span, SpanStatus, SpanType, Trace from jojo_code.ops.report import ReportGenerator -class TestReportGenerator: - """ReportGenerator 测试""" - - @pytest.fixture - def trace(self): - """创建测试 Trace""" - trace = Trace(task="读取 README.md") - trace.spans.append( - Span( - type=SpanType.TOOL_CALL, - name="read_file", - status=SpanStatus.COMPLETED, - input={"path": "README.md"}, - output="# README\n项目说明", - ) - ) - trace.spans.append(Span(type=SpanType.THINKING, name="thinking")) - trace.start_time = datetime.now() - trace.end_time = datetime.fromtimestamp(trace.start_time.timestamp() + 1) - trace.status = SpanStatus.COMPLETED - return trace - - @pytest.fixture - def score(self): - """创建测试评分""" - return EvaluationScore( - result=EvaluationResult.PASS, - score=0.95, - reason="表现良好", - details={"issues": [], "rules_checked": 5}, - ) - - def test_generate_evaluation_report(self, trace, score): - """测试生成评估报告""" - report = ReportGenerator.generate_evaluation_report(trace, score) - - assert "AgentOps 评估报告" in report - assert trace.id in report - assert trace.task in report +@pytest.fixture +def sample_trace() -> Trace: + """Create a sample Trace with various span types.""" + trace = Trace( + id="test-trace-001", + session_id="test-session", + task="Read README.md and summarize", + start_time=datetime(2026, 1, 1, 10, 0, 0), + 
end_time=datetime(2026, 1, 1, 10, 0, 5), + status=SpanStatus.COMPLETED, + ) + trace.spans = [ + Span( + id="span-1", + trace_id="test-trace-001", + type=SpanType.THINKING, + name="thinking", + status=SpanStatus.COMPLETED, + start_time=datetime(2026, 1, 1, 10, 0, 0), + end_time=datetime(2026, 1, 1, 10, 0, 1), + ), + Span( + id="span-2", + trace_id="test-trace-001", + type=SpanType.TOOL_CALL, + name="read_file", + status=SpanStatus.COMPLETED, + start_time=datetime(2026, 1, 1, 10, 0, 1), + end_time=datetime(2026, 1, 1, 10, 0, 2), + ), + Span( + id="span-3", + trace_id="test-trace-001", + type=SpanType.TOOL_CALL, + name="grep_search", + status=SpanStatus.FAILED, + error="FileNotFoundError: file not found", + start_time=datetime(2026, 1, 1, 10, 0, 2), + end_time=datetime(2026, 1, 1, 10, 0, 3), + ), + Span( + id="span-4", + trace_id="test-trace-001", + type=SpanType.THINKING, + name="thinking", + status=SpanStatus.COMPLETED, + start_time=datetime(2026, 1, 1, 10, 0, 3), + end_time=datetime(2026, 1, 1, 10, 0, 4), + ), + ] + return trace + + +@pytest.fixture +def passing_score() -> EvaluationScore: + """Create a passing evaluation score.""" + return EvaluationScore( + result=EvaluationResult.PASS, + score=0.95, + reason="规划质量良好", + details={"issues": [], "rules_checked": 5}, + ) + + +@pytest.fixture +def failing_score() -> EvaluationScore: + """Create a failing evaluation score.""" + return EvaluationScore( + result=EvaluationResult.FAIL, + score=0.3, + reason="思考轮数过多; 任务执行失败", + details={"issues": ["思考轮数过多: 8", "任务执行失败"]}, + ) + + +@pytest.fixture +def sample_metrics() -> MetricsSummary: + """Create a sample MetricsSummary.""" + return MetricsSummary( + total_traces=10, + completed_traces=8, + failed_traces=2, + avg_thinking_rounds=2.5, + avg_tool_calls=4.0, + avg_duration_ms=3000.0, + tool_success_rate=0.92, + task_success_rate=0.8, + tool_usage={"read_file": 15, "write_file": 5, "grep_search": 10}, + error_types={"PermissionError: access denied": 2, "TimeoutError: 
timeout": 1}, + start_time=datetime(2026, 1, 1, 0, 0, 0), + end_time=datetime(2026, 1, 1, 23, 59, 59), + ) + + +class TestReportGeneratorGenerateEvaluationReport: + """Tests for ReportGenerator.generate_evaluation_report.""" + + def test_returns_string(self, sample_trace, passing_score): + """Report should be a string.""" + report = ReportGenerator.generate_evaluation_report(sample_trace, passing_score) + assert isinstance(report, str) + + def test_contains_trace_info(self, sample_trace, passing_score): + """Report should contain trace ID and task.""" + report = ReportGenerator.generate_evaluation_report(sample_trace, passing_score) + assert "test-trace-001" in report + assert "Read README.md and summarize" in report + + def test_contains_evaluation_result(self, sample_trace, passing_score): + """Report should contain evaluation result.""" + report = ReportGenerator.generate_evaluation_report(sample_trace, passing_score) assert "PASS" in report - assert "95" in report or "0.95" in report - def test_generate_report_with_errors(self, trace): - """测试包含错误的报告""" - trace.spans.append( - Span( - type=SpanType.TOOL_CALL, - name="write_file", - status=SpanStatus.FAILED, - error="权限不足", - ) - ) + def test_contains_score_percentage(self, sample_trace, passing_score): + """Report should contain formatted score.""" + report = ReportGenerator.generate_evaluation_report(sample_trace, passing_score) + assert "95.00%" in report - score = EvaluationScore( - result=EvaluationResult.FAIL, - score=0.3, - reason="任务失败", + def test_contains_tool_usage_table(self, sample_trace, passing_score): + """Report should contain tool usage table.""" + report = ReportGenerator.generate_evaluation_report(sample_trace, passing_score) + assert "read_file" in report + assert "grep_search" in report + + def test_contains_error_details_when_errors_present(self, sample_trace, passing_score): + """Report should contain error details when spans have errors.""" + report = 
ReportGenerator.generate_evaluation_report(sample_trace, passing_score) + assert "FileNotFoundError" in report + + def test_contains_evaluation_details_json(self, sample_trace, passing_score): + """Report should contain evaluation details as JSON.""" + report = ReportGenerator.generate_evaluation_report(sample_trace, passing_score) + assert "rules_checked" in report + + def test_contains_suggestions_for_passing(self, sample_trace, passing_score): + """Report should contain suggestions.""" + report = ReportGenerator.generate_evaluation_report(sample_trace, passing_score) + assert "改进建议" in report + + def test_contains_suggestions_for_failing(self, sample_trace, failing_score): + """Report should contain failure-specific suggestions.""" + report = ReportGenerator.generate_evaluation_report(sample_trace, failing_score) + assert "改进建议" in report + assert "评估未通过" in report + + def test_writes_to_file(self, sample_trace, passing_score, tmp_path): + """Report should be written to file when output_path is given.""" + output = tmp_path / "report.md" + report = ReportGenerator.generate_evaluation_report( + sample_trace, passing_score, output_path=str(output) ) + assert output.exists() + assert output.read_text(encoding="utf-8") == report - report = ReportGenerator.generate_evaluation_report(trace, score) - - assert "权限不足" in report - assert "❌" in report - - def test_generate_report_to_file(self, trace, score, tmp_path): - """测试输出到文件""" - output_file = str(tmp_path / "report.md") - ReportGenerator.generate_evaluation_report(trace, score, output_file) - - import os - - assert os.path.exists(output_file) - with open(output_file) as f: - content = f.read() - assert "AgentOps 评估报告" in content - - def test_generate_summary_report(self): - """测试生成汇总报告""" - metrics = MetricsSummary( - total_traces=100, - completed_traces=85, - failed_traces=15, - avg_thinking_rounds=3.5, - avg_tool_calls=4.2, - avg_duration_ms=2500.0, - tool_success_rate=0.92, - task_success_rate=0.85, - 
tool_usage={"read_file": 150, "write_file": 80, "execute": 50}, - error_types={"文件不存在": 10, "权限不足": 5}, + def test_no_file_written_without_output_path(self, sample_trace, passing_score, tmp_path): + """No file should be written when output_path is None.""" + report = ReportGenerator.generate_evaluation_report( + sample_trace, passing_score, output_path=None ) - - report = ReportGenerator.generate_summary_report(metrics) - - assert "AgentOps 汇总报告" in report - assert "100" in report + assert isinstance(report, str) + + def test_trace_with_no_tool_calls(self, passing_score): + """Report should handle trace with no tool calls.""" + trace = Trace( + id="no-tools", + task="Simple task", + start_time=datetime(2026, 1, 1), + end_time=datetime(2026, 1, 1, 0, 0, 1), + status=SpanStatus.COMPLETED, + ) + report = ReportGenerator.generate_evaluation_report(trace, passing_score) + assert "无工具调用" in report + + def test_failing_score_contains_specific_suggestions(self, sample_trace, failing_score): + """Failing score should trigger specific improvement suggestions.""" + report = ReportGenerator.generate_evaluation_report(sample_trace, failing_score) + assert "评估未通过" in report + + +class TestReportGeneratorGenerateSummaryReport: + """Tests for ReportGenerator.generate_summary_report.""" + + def test_returns_string(self, sample_metrics): + """Report should be a string.""" + report = ReportGenerator.generate_summary_report(sample_metrics) + assert isinstance(report, str) + + def test_contains_metrics(self, sample_metrics): + """Report should contain key metrics.""" + report = ReportGenerator.generate_summary_report(sample_metrics) + assert "10" in report # total_traces + assert "8" in report # completed + assert "2" in report # failed + assert "80.00%" in report # task_success_rate + + def test_contains_tool_usage(self, sample_metrics): + """Report should contain tool usage table.""" + report = ReportGenerator.generate_summary_report(sample_metrics) assert "read_file" in report - - 
def test_generate_summary_report_with_evaluations(self): - """测试带评估的汇总报告""" - metrics = MetricsSummary( - total_traces=10, - completed_traces=8, - failed_traces=2, - avg_thinking_rounds=2.0, - avg_tool_calls=3.0, - avg_duration_ms=1000.0, - tool_success_rate=0.95, - task_success_rate=0.80, - tool_usage={"read_file": 20}, - error_types={}, + assert "write_file" in report + assert "grep_search" in report + + def test_contains_error_types(self, sample_metrics): + """Report should contain error statistics.""" + report = ReportGenerator.generate_summary_report(sample_metrics) + assert "PermissionError" in report + assert "TimeoutError" in report + + def test_contains_time_range(self, sample_metrics): + """Report should contain time range.""" + report = ReportGenerator.generate_summary_report(sample_metrics) + assert "2026-01-01" in report + + def test_with_evaluation_scores(self, sample_metrics, passing_score, failing_score): + """Report should include evaluation summary when scores are provided.""" + scores = [passing_score, failing_score] + report = ReportGenerator.generate_summary_report(sample_metrics, evaluation_scores=scores) + assert "评估汇总" in report + assert "通过" in report + assert "失败" in report + + def test_without_evaluation_scores(self, sample_metrics): + """Report should work without evaluation scores.""" + report = ReportGenerator.generate_summary_report(sample_metrics, evaluation_scores=None) + assert isinstance(report, str) + assert "评估汇总" not in report + + def test_writes_to_file(self, sample_metrics, tmp_path): + """Report should be written to file when output_path is given.""" + output = tmp_path / "summary.md" + report = ReportGenerator.generate_summary_report( + sample_metrics, output_path=str(output) ) + assert output.exists() + assert output.read_text(encoding="utf-8") == report - scores = [ - EvaluationScore(result=EvaluationResult.PASS, score=0.95, reason="良好"), - EvaluationScore(result=EvaluationResult.PASS, score=0.88, reason="良好"), - 
EvaluationScore(result=EvaluationResult.PARTIAL, score=0.72, reason="一般"), - EvaluationScore(result=EvaluationResult.FAIL, score=0.45, reason="失败"), - ] - - report = ReportGenerator.generate_summary_report(metrics, scores) - - assert "评估汇总" in report - assert "平均得分" in report + def test_contains_suggestions(self, sample_metrics): + """Report should contain improvement suggestions.""" + report = ReportGenerator.generate_summary_report(sample_metrics) + assert "改进建议" in report - def test_generate_summary_report_to_file(self, tmp_path): - """测试汇总报告输出到文件""" + def test_no_error_section_when_no_errors(self): + """Report should not have error section when there are no errors.""" metrics = MetricsSummary( - total_traces=10, - completed_traces=8, - failed_traces=2, - avg_thinking_rounds=2.0, - avg_tool_calls=3.0, + total_traces=5, + completed_traces=5, + failed_traces=0, + avg_thinking_rounds=1.0, + avg_tool_calls=2.0, avg_duration_ms=1000.0, - tool_success_rate=0.95, - task_success_rate=0.80, - tool_usage={"read_file": 20}, - error_types={}, + tool_success_rate=1.0, + task_success_rate=1.0, + tool_usage={"read_file": 5}, ) + report = ReportGenerator.generate_summary_report(metrics) + assert "错误统计" not in report - output_file = str(tmp_path / "summary.md") - ReportGenerator.generate_summary_report(metrics, output_path=output_file) - - import os - - assert os.path.exists(output_file) - - def test_generate_suggestions_good_performance(self, trace, score): - """测试良好表现的建议""" - suggestions = ReportGenerator._generate_suggestions(trace, score) - - assert "表现良好" in suggestions[0] - - def test_generate_suggestions_too_many_thinking(self, trace): - """测试思考过多的建议""" - for _ in range(10): - trace.spans.append(Span(type=SpanType.THINKING, name="thinking")) - - score = EvaluationScore(result=EvaluationResult.PARTIAL, score=0.6, reason="") - - suggestions = ReportGenerator._generate_suggestions(trace, score) - assert any("思考轮数" in s or "优化" in s for s in suggestions) +class 
TestReportGeneratorSuggestions: + """Tests for suggestion generation methods.""" - def test_generate_suggestions_tool_failure(self, trace): - """测试工具失败的建议""" + def test_high_thinking_count_suggestion(self, passing_score): + """Should suggest reducing thinking rounds when count > 5.""" + trace = Trace( + id="t", + task="task", + start_time=datetime(2026, 1, 1), + end_time=datetime(2026, 1, 1, 0, 0, 1), + status=SpanStatus.COMPLETED, + ) + # Add 6 thinking spans + for _ in range(6): + trace.spans.append( + Span( + type=SpanType.THINKING, + name="thinking", + status=SpanStatus.COMPLETED, + start_time=datetime(2026, 1, 1), + end_time=datetime(2026, 1, 1, 0, 0, 1), + ) + ) + suggestions = ReportGenerator._generate_suggestions(trace, passing_score) + assert any("思考轮数" in s for s in suggestions) + + def test_low_tool_success_rate_suggestion(self, passing_score): + """Should suggest checking tool parameters when success rate < 0.8.""" + trace = Trace( + id="t", + task="task", + start_time=datetime(2026, 1, 1), + end_time=datetime(2026, 1, 1, 0, 0, 1), + status=SpanStatus.COMPLETED, + ) + # Add a failed tool call trace.spans.append( Span( type=SpanType.TOOL_CALL, - name="bad_tool", + name="write_file", status=SpanStatus.FAILED, - error="错误", + start_time=datetime(2026, 1, 1), + end_time=datetime(2026, 1, 1, 0, 0, 1), ) ) + suggestions = ReportGenerator._generate_suggestions(trace, passing_score) + assert any("工具调用失败" in s for s in suggestions) + + def test_error_count_suggestion(self, passing_score): + """Should suggest error handling when errors > 0.""" + trace = Trace( + id="t", + task="task", + start_time=datetime(2026, 1, 1), + end_time=datetime(2026, 1, 1, 0, 0, 1), + status=SpanStatus.COMPLETED, + ) + trace.spans.append( + Span( + type=SpanType.ERROR, + name="error", + status=SpanStatus.FAILED, + start_time=datetime(2026, 1, 1), + end_time=datetime(2026, 1, 1, 0, 0, 1), + ) + ) + suggestions = ReportGenerator._generate_suggestions(trace, passing_score) + assert 
any("错误" in s for s in suggestions) + + def test_long_duration_suggestion(self, passing_score): + """Should suggest optimization when duration > 10s.""" + trace = Trace( + id="t", + task="task", + start_time=datetime(2026, 1, 1), + end_time=datetime(2026, 1, 1, 0, 0, 15), + status=SpanStatus.COMPLETED, + ) + suggestions = ReportGenerator._generate_suggestions(trace, passing_score) + assert any("耗时" in s for s in suggestions) + + def test_good_performance_suggestion(self, passing_score): + """Should give positive suggestion when everything is fine.""" + trace = Trace( + id="t", + task="task", + start_time=datetime(2026, 1, 1), + end_time=datetime(2026, 1, 1, 0, 0, 1), + status=SpanStatus.COMPLETED, + ) + suggestions = ReportGenerator._generate_suggestions(trace, passing_score) + assert any("表现良好" in s for s in suggestions) - score = EvaluationScore(result=EvaluationResult.FAIL, score=0.3, reason="失败") - - suggestions = ReportGenerator._generate_suggestions(trace, score) - - assert len(suggestions) > 0 + def test_summary_suggestions_low_success_rate(self): + """Should suggest analyzing failures when task success rate < 0.9.""" + metrics = MetricsSummary( + total_traces=10, + completed_traces=7, + failed_traces=3, + task_success_rate=0.7, + ) + suggestions = ReportGenerator._generate_summary_suggestions(metrics, None) + assert any("任务成功率" in s for s in suggestions) - def test_report_markdown_format(self, trace, score): - """测试 Markdown 格式""" - report = ReportGenerator.generate_evaluation_report(trace, score) + def test_summary_suggestions_high_thinking_rounds(self): + """Should suggest optimizing planning when avg thinking > 3.""" + metrics = MetricsSummary( + total_traces=10, + completed_traces=10, + failed_traces=0, + avg_thinking_rounds=4.0, + task_success_rate=1.0, + ) + suggestions = ReportGenerator._generate_summary_suggestions(metrics, None) + assert any("思考轮数" in s for s in suggestions) - # 检查 Markdown 元素 - assert report.startswith("#") # 标题 - assert "|" in 
report # 表格 - assert "---" in report or "##" in report # 分隔线或子标题 + def test_summary_suggestions_top_error(self): + """Should mention most common error.""" + metrics = MetricsSummary( + total_traces=10, + completed_traces=10, + failed_traces=0, + task_success_rate=1.0, + error_types={"ConnectionError": 5}, + ) + suggestions = ReportGenerator._generate_summary_suggestions(metrics, None) + assert any("ConnectionError" in s for s in suggestions) - def test_empty_metrics_report(self): - """测试空指标报告""" + def test_summary_suggestions_high_fail_rate(self): + """Should suggest optimizing when evaluation fail rate > 10%.""" metrics = MetricsSummary( - total_traces=0, - completed_traces=0, + total_traces=10, + completed_traces=10, failed_traces=0, - avg_thinking_rounds=0, - avg_tool_calls=0, - avg_duration_ms=0, - tool_success_rate=0, - task_success_rate=0, - tool_usage={}, - error_types={}, + task_success_rate=1.0, ) + scores = [ + EvaluationScore(result=EvaluationResult.FAIL, score=0.3, reason="fail"), + EvaluationScore(result=EvaluationResult.PASS, score=0.9, reason="ok"), + ] + suggestions = ReportGenerator._generate_summary_suggestions(metrics, scores) + assert any("评估失败率" in s for s in suggestions) - report = ReportGenerator.generate_summary_report(metrics) + def test_summary_suggestions_good_performance(self): + """Should give positive suggestion when everything is fine.""" + metrics = MetricsSummary( + total_traces=10, + completed_traces=10, + failed_traces=0, + avg_thinking_rounds=1.0, + tool_success_rate=1.0, + task_success_rate=1.0, + ) + suggestions = ReportGenerator._generate_summary_suggestions(metrics, None) + assert any("表现良好" in s for s in suggestions) - assert "AgentOps 汇总报告" in report - assert "0" in report + def test_summary_suggestions_low_tool_success_rate(self): + """Should suggest checking tool implementation when tool success rate < 0.95.""" + metrics = MetricsSummary( + total_traces=10, + completed_traces=10, + failed_traces=0, + 
task_success_rate=1.0, + tool_success_rate=0.90, + ) + suggestions = ReportGenerator._generate_summary_suggestions(metrics, None) + assert any("工具成功率" in s for s in suggestions) diff --git a/tests/test_cli/test_main.py b/tests/test_cli/test_main.py new file mode 100644 index 0000000..dde123d --- /dev/null +++ b/tests/test_cli/test_main.py @@ -0,0 +1,262 @@ +"""Tests for CLI main module: argument parsing and config management.""" + +import json +from unittest.mock import MagicMock, patch + +import pytest + +from jojo_code.cli.main import ( + DEFAULT_CONFIG, + config_get, + config_set, + config_show, + load_config, + main, + save_config, +) + +# ========== Config Management Tests ========== + + +class TestLoadConfig: + """Tests for load_config().""" + + def test_returns_default_when_no_file(self, tmp_path, monkeypatch): + """Should return DEFAULT_CONFIG when config file does not exist.""" + monkeypatch.setattr("jojo_code.cli.main.CONFIG_FILE", tmp_path / "nonexistent.json") + monkeypatch.setattr("jojo_code.cli.main.CONFIG_DIR", tmp_path) + result = load_config() + assert result == DEFAULT_CONFIG + + def test_loads_valid_json(self, tmp_path, monkeypatch): + """Should load and return valid JSON config.""" + config_file = tmp_path / "config.json" + expected = {"model": "gpt-4", "port": "9090"} + config_file.write_text(json.dumps(expected), encoding="utf-8") + monkeypatch.setattr("jojo_code.cli.main.CONFIG_FILE", config_file) + monkeypatch.setattr("jojo_code.cli.main.CONFIG_DIR", tmp_path) + result = load_config() + assert result == expected + + def test_returns_default_on_invalid_json(self, tmp_path, monkeypatch): + """Should return DEFAULT_CONFIG when JSON is malformed.""" + config_file = tmp_path / "config.json" + config_file.write_text("not valid json{{", encoding="utf-8") + monkeypatch.setattr("jojo_code.cli.main.CONFIG_FILE", config_file) + monkeypatch.setattr("jojo_code.cli.main.CONFIG_DIR", tmp_path) + result = load_config() + assert result == DEFAULT_CONFIG + + 
+class TestSaveConfig: + """Tests for save_config().""" + + def test_creates_config_dir(self, tmp_path, monkeypatch): + """Should create config directory if it does not exist.""" + config_dir = tmp_path / "new_dir" + config_file = config_dir / "config.json" + monkeypatch.setattr("jojo_code.cli.main.CONFIG_FILE", config_file) + monkeypatch.setattr("jojo_code.cli.main.CONFIG_DIR", config_dir) + save_config({"key": "value"}) + assert config_file.exists() + + def test_writes_json_content(self, tmp_path, monkeypatch): + """Should write config as formatted JSON.""" + config_dir = tmp_path / "config_dir" + config_file = config_dir / "config.json" + monkeypatch.setattr("jojo_code.cli.main.CONFIG_FILE", config_file) + monkeypatch.setattr("jojo_code.cli.main.CONFIG_DIR", config_dir) + data = {"model": "claude-3", "port": "3000"} + save_config(data) + loaded = json.loads(config_file.read_text(encoding="utf-8")) + assert loaded == data + + def test_overwrites_existing(self, tmp_path, monkeypatch): + """Should overwrite existing config file.""" + config_dir = tmp_path / "config_dir" + config_file = config_dir / "config.json" + monkeypatch.setattr("jojo_code.cli.main.CONFIG_FILE", config_file) + monkeypatch.setattr("jojo_code.cli.main.CONFIG_DIR", config_dir) + save_config({"old": "data"}) + save_config({"new": "data"}) + loaded = json.loads(config_file.read_text(encoding="utf-8")) + assert loaded == {"new": "data"} + + +class TestConfigSet: + """Tests for config_set() command handler.""" + + def test_sets_config_value(self, tmp_path, monkeypatch, capsys): + """Should set a config key-value pair.""" + config_dir = tmp_path / "config" + config_file = config_dir / "config.json" + monkeypatch.setattr("jojo_code.cli.main.CONFIG_FILE", config_file) + monkeypatch.setattr("jojo_code.cli.main.CONFIG_DIR", config_dir) + args = MagicMock() + args.key = "model" + args.value = "gpt-4" + config_set(args) + loaded = json.loads(config_file.read_text(encoding="utf-8")) + assert loaded["model"] 
== "gpt-4" + captured = capsys.readouterr() + assert "model" in captured.out + assert "gpt-4" in captured.out + + +class TestConfigShow: + """Tests for config_show() command handler.""" + + def test_displays_config(self, tmp_path, monkeypatch, capsys): + """Should display all config items.""" + config_dir = tmp_path / "config" + config_file = config_dir / "config.json" + monkeypatch.setattr("jojo_code.cli.main.CONFIG_FILE", config_file) + monkeypatch.setattr("jojo_code.cli.main.CONFIG_DIR", config_dir) + save_config({"model": "gpt-4", "port": "8080"}) + args = MagicMock() + config_show(args) + captured = capsys.readouterr() + assert "model" in captured.out + assert "port" in captured.out + + +class TestConfigGet: + """Tests for config_get() command handler.""" + + def test_gets_existing_key(self, tmp_path, monkeypatch, capsys): + """Should print the value for an existing key.""" + config_dir = tmp_path / "config" + config_file = config_dir / "config.json" + monkeypatch.setattr("jojo_code.cli.main.CONFIG_FILE", config_file) + monkeypatch.setattr("jojo_code.cli.main.CONFIG_DIR", config_dir) + save_config({"model": "gpt-4"}) + args = MagicMock() + args.key = "model" + config_get(args) + captured = capsys.readouterr() + assert "gpt-4" in captured.out + + def test_reports_missing_key(self, tmp_path, monkeypatch, capsys): + """Should report when a key does not exist.""" + config_dir = tmp_path / "config" + config_file = config_dir / "config.json" + monkeypatch.setattr("jojo_code.cli.main.CONFIG_FILE", config_file) + monkeypatch.setattr("jojo_code.cli.main.CONFIG_DIR", config_dir) + save_config({}) + args = MagicMock() + args.key = "nonexistent" + config_get(args) + captured = capsys.readouterr() + assert "nonexistent" in captured.out + + +# ========== Argument Parsing Tests ========== + + +class TestArgParsing: + """Tests for CLI argument parsing via main().""" + + def test_version_flag(self, capsys): + """Should print version and exit.""" + with 
pytest.raises(SystemExit) as exc_info: + with patch("sys.argv", ["jojo-code", "--version"]): + main() + assert exc_info.value.code == 0 + captured = capsys.readouterr() + assert "0.2.0" in captured.out + + def test_help_flag(self, capsys): + """Should print help and exit.""" + with pytest.raises(SystemExit) as exc_info: + with patch("sys.argv", ["jojo-code", "--help"]): + main() + assert exc_info.value.code == 0 + captured = capsys.readouterr() + assert "jojo-code" in captured.out + + @patch("jojo_code.cli.main.config_show") + def test_config_show_command(self, mock_show): + """Should dispatch to config_show.""" + with patch("sys.argv", ["jojo-code", "config", "show"]): + main() + mock_show.assert_called_once() + + @patch("jojo_code.cli.main.config_set") + def test_config_set_command(self, mock_set): + """Should dispatch to config_set with key and value.""" + with patch("sys.argv", ["jojo-code", "config", "set", "model", "gpt-4"]): + main() + mock_set.assert_called_once() + args = mock_set.call_args[0][0] + assert args.key == "model" + assert args.value == "gpt-4" + + @patch("jojo_code.cli.main.config_get") + def test_config_get_command(self, mock_get): + """Should dispatch to config_get with key.""" + with patch("sys.argv", ["jojo-code", "config", "get", "model"]): + main() + mock_get.assert_called_once() + args = mock_get.call_args[0][0] + assert args.key == "model" + + @patch("jojo_code.cli.main.setup_wizard") + def test_setup_command(self, mock_setup): + """Should dispatch to setup_wizard.""" + with patch("sys.argv", ["jojo-code", "setup"]): + main() + mock_setup.assert_called_once() + + @patch("jojo_code.cli.main.server_status") + def test_server_status_command(self, mock_status): + """Should dispatch to server_status.""" + with patch("sys.argv", ["jojo-code", "server", "status"]): + main() + mock_status.assert_called_once() + + @patch("jojo_code.cli.main.server_stop") + def test_server_stop_command(self, mock_stop): + """Should dispatch to server_stop.""" + 
with patch("sys.argv", ["jojo-code", "server", "stop"]): + main() + mock_stop.assert_called_once() + + def test_server_start_with_daemon_flag(self): + """Should parse daemon flag for server start.""" + with patch("jojo_code.cli.main.server_start") as mock_start: + with patch("sys.argv", ["jojo-code", "server", "start", "-d"]): + main() + mock_start.assert_called_once() + args = mock_start.call_args[0][0] + assert args.daemon is True + + def test_server_start_with_host_port(self): + """Should parse host and port for server start.""" + with patch("jojo_code.cli.main.server_start") as mock_start: + with patch( + "sys.argv", + ["jojo-code", "server", "start", "--host", "127.0.0.1", "--port", "9090"], + ): + main() + mock_start.assert_called_once() + args = mock_start.call_args[0][0] + assert args.host == "127.0.0.1" + assert args.port == "9090" + + def test_no_server_flag(self): + """Should parse --no-server flag.""" + with patch("jojo_code.cli.main.start_tui") as mock_tui: + with patch("sys.argv", ["jojo-code", "--no-server"]): + main() + mock_tui.assert_called_once() + args = mock_tui.call_args[0][0] + assert args.no_server is True + + def test_server_url_flag(self): + """Should parse --server flag.""" + with patch("jojo_code.cli.main.start_tui") as mock_tui: + with patch("sys.argv", ["jojo-code", "--server", "ws://custom:9090/ws"]): + main() + mock_tui.assert_called_once() + args = mock_tui.call_args[0][0] + assert args.server == "ws://custom:9090/ws" diff --git a/tests/test_context/__init__.py b/tests/test_context/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_context/test_init.py b/tests/test_context/test_init.py new file mode 100644 index 0000000..45ae86d --- /dev/null +++ b/tests/test_context/test_init.py @@ -0,0 +1,276 @@ +"""Tests for context package: project context initialization and utilities.""" + +from pathlib import Path + +from jojo_code.context import ( + LazyIgnoreManager, + find_project_root, + init_project_context, + 
load_project_context, + parse_agents_md, +) + +# ========== find_project_root Tests ========== + + +class TestFindProjectRoot: + """Tests for find_project_root().""" + + def test_finds_root_with_git_marker(self, tmp_path): + """Should find root when .git directory exists.""" + project = tmp_path / "my_project" + project.mkdir() + (project / ".git").mkdir() + sub = project / "src" / "deep" + sub.mkdir(parents=True) + result = find_project_root(sub) + assert result == project + + def test_finds_root_with_pyproject_toml(self, tmp_path): + """Should find root when pyproject.toml exists.""" + project = tmp_path / "my_project" + project.mkdir() + (project / "pyproject.toml").write_text("[project]", encoding="utf-8") + result = find_project_root(project) + assert result == project + + def test_returns_none_when_no_markers(self, tmp_path): + """Should return None when no root markers found up to filesystem root.""" + isolated = tmp_path / "no_root" / "subdir" + isolated.mkdir(parents=True) + # tmp_path has no .git or pyproject.toml, but may have them in parents + # Use a deeply nested path that definitely has no markers + result = find_project_root(isolated) + # If tmp_path itself has markers (unlikely), result might not be None + # The key assertion is that it doesn't crash + assert result is None or isinstance(result, Path) + + def test_start_from_file(self, tmp_path): + """Should handle start path being a file.""" + project = tmp_path / "proj" + project.mkdir() + (project / ".git").mkdir() + some_file = project / "main.py" + some_file.write_text("print('hello')", encoding="utf-8") + result = find_project_root(some_file) + assert result == project + + def test_start_from_none(self): + """Should use cwd when start is None.""" + result = find_project_root(None) + # Should return something (the actual project root from cwd) + assert result is None or isinstance(result, Path) + + def test_finds_nearest_root(self, tmp_path): + """Should find the nearest ancestor with 
markers.""" + outer = tmp_path / "outer" + outer.mkdir() + (outer / ".git").mkdir() + inner = outer / "inner" + inner.mkdir() + (inner / ".git").mkdir() + deep = inner / "src" + deep.mkdir() + result = find_project_root(deep) + assert result == inner + + +# ========== parse_agents_md Tests ========== + + +class TestParseAgentsMd: + """Tests for parse_agents_md().""" + + def test_parses_basic_format(self, tmp_path): + """Should parse a standard AGENTS.md file.""" + agents_file = tmp_path / "AGENTS.md" + agents_file.write_text( + "# AGENTS\n\n## code-reviewer\n- review code\n- find bugs\n\n" + "## researcher\n- search docs\n", + encoding="utf-8", + ) + result = parse_agents_md(agents_file) + assert "code-reviewer" in result + assert result["code-reviewer"] == ["review code", "find bugs"] + assert "researcher" in result + assert result["researcher"] == ["search docs"] + + def test_returns_empty_for_missing_file(self, tmp_path): + """Should return empty dict when file does not exist.""" + result = parse_agents_md(tmp_path / "nonexistent.md") + assert result == {} + + def test_returns_empty_for_empty_file(self, tmp_path): + """Should return empty dict for empty file.""" + agents_file = tmp_path / "AGENTS.md" + agents_file.write_text("", encoding="utf-8") + result = parse_agents_md(agents_file) + assert result == {} + + def test_handles_star_bullets(self, tmp_path): + """Should parse bullet points starting with *.""" + agents_file = tmp_path / "AGENTS.md" + agents_file.write_text( + "## agent-x\n* task one\n* task two\n", + encoding="utf-8", + ) + result = parse_agents_md(agents_file) + assert result["agent-x"] == ["task one", "task two"] + + def test_ignores_content_before_first_header(self, tmp_path): + """Should ignore lines before the first ## header.""" + agents_file = tmp_path / "AGENTS.md" + agents_file.write_text( + "Some intro text\nMore text\n\n## my-agent\n- do stuff\n", + encoding="utf-8", + ) + result = parse_agents_md(agents_file) + assert "my-agent" in 
result + assert result["my-agent"] == ["do stuff"] + + def test_handles_empty_sections(self, tmp_path): + """Should handle sections with no bullet points.""" + agents_file = tmp_path / "AGENTS.md" + agents_file.write_text( + "## empty-agent\n\n## useful-agent\n- do things\n", + encoding="utf-8", + ) + result = parse_agents_md(agents_file) + assert "empty-agent" in result + assert result["empty-agent"] == [] + assert result["useful-agent"] == ["do things"] + + +# ========== load_project_context Tests ========== + + +class TestLoadProjectContext: + """Tests for load_project_context().""" + + def test_returns_root_and_agents(self, tmp_path): + """Should return dict with root and agents keys.""" + project = tmp_path / "proj" + project.mkdir() + (project / ".git").mkdir() + result = load_project_context(project) + assert "root" in result + assert "agents" in result + assert result["root"] == str(project) + + def test_agents_empty_when_no_agents_md(self, tmp_path): + """Should have empty agents dict when AGENTS.md is missing.""" + project = tmp_path / "proj" + project.mkdir() + (project / ".git").mkdir() + result = load_project_context(project) + assert result["agents"] == {} + + def test_agents_parsed_when_file_exists(self, tmp_path): + """Should parse AGENTS.md when present.""" + project = tmp_path / "proj" + project.mkdir() + (project / ".git").mkdir() + (project / "AGENTS.md").write_text("## helper\n- assist user\n", encoding="utf-8") + result = load_project_context(project) + assert "helper" in result["agents"] + assert result["agents"]["helper"] == ["assist user"] + + +# ========== init_project_context Tests ========== + + +class TestInitProjectContext: + """Tests for init_project_context().""" + + def test_returns_tuple_with_path_and_dict(self, tmp_path): + """Should return (agents_md_path, context_dict).""" + project = tmp_path / "proj" + project.mkdir() + (project / ".git").mkdir() + agents_path, context = init_project_context(project) + assert 
isinstance(agents_path, Path) + assert isinstance(context, dict) + assert "root" in context + assert "agents_md" in context + + def test_creates_agents_md_if_missing(self, tmp_path): + """Should create AGENTS.md skeleton when missing.""" + project = tmp_path / "proj" + project.mkdir() + (project / ".git").mkdir() + agents_path, _ = init_project_context(project) + assert agents_path.exists() + content = agents_path.read_text(encoding="utf-8") + assert "AGENTS" in content + + def test_preserves_existing_agents_md(self, tmp_path): + """Should not overwrite existing AGENTS.md.""" + project = tmp_path / "proj" + project.mkdir() + (project / ".git").mkdir() + existing = project / "AGENTS.md" + existing.write_text("## my-agent\n- custom\n", encoding="utf-8") + agents_path, _ = init_project_context(project) + content = agents_path.read_text(encoding="utf-8") + assert "custom" in content + + def test_context_has_correct_root(self, tmp_path): + """Should set root in context dict.""" + project = tmp_path / "proj" + project.mkdir() + (project / ".git").mkdir() + _, context = init_project_context(project) + assert context["root"] == str(project) + + +# ========== LazyIgnoreManager Tests ========== + + +class TestLazyIgnoreManager: + """Tests for LazyIgnoreManager.""" + + def test_should_ignore_basic_pattern(self, tmp_path): + """Should ignore files matching .gitignore patterns.""" + gitignore = tmp_path / ".gitignore" + gitignore.write_text("*.pyc\n__pycache__/\n", encoding="utf-8") + manager = LazyIgnoreManager(tmp_path) + assert manager.should_ignore(tmp_path / "test.pyc") is True + assert manager.should_ignore(tmp_path / "main.py") is False + + def test_should_ignore_directory_pattern(self, tmp_path): + """Should ignore directories matching trailing-slash patterns.""" + gitignore = tmp_path / ".gitignore" + gitignore.write_text("node_modules/\n", encoding="utf-8") + manager = LazyIgnoreManager(tmp_path) + target = tmp_path / "node_modules" + target.mkdir() + assert 
manager.should_ignore(target) is True + + def test_should_not_ignore_unmatched(self, tmp_path): + """Should not ignore files that don't match any pattern.""" + gitignore = tmp_path / ".gitignore" + gitignore.write_text("*.log\n", encoding="utf-8") + manager = LazyIgnoreManager(tmp_path) + assert manager.should_ignore(tmp_path / "readme.md") is False + + def test_handles_missing_gitignore(self, tmp_path): + """Should handle missing .gitignore gracefully.""" + manager = LazyIgnoreManager(tmp_path) + assert manager.should_ignore(tmp_path / "anything.txt") is False + + def test_clear_cache(self, tmp_path): + """Should clear and reload cache.""" + gitignore = tmp_path / ".gitignore" + gitignore.write_text("*.log\n", encoding="utf-8") + manager = LazyIgnoreManager(tmp_path) + assert manager.should_ignore(tmp_path / "test.log") is True + manager.clear_cache() + assert manager.should_ignore(tmp_path / "test.log") is True + + def test_outside_root_not_ignored(self, tmp_path): + """Should not ignore paths outside the project root.""" + gitignore = tmp_path / ".gitignore" + gitignore.write_text("*.pyc\n", encoding="utf-8") + manager = LazyIgnoreManager(tmp_path) + outside = Path("/some/other/path/test.pyc") + assert manager.should_ignore(outside) is False diff --git a/tests/test_context/test_lazy_ignore.py b/tests/test_context/test_lazy_ignore.py new file mode 100644 index 0000000..9c1c3dd --- /dev/null +++ b/tests/test_context/test_lazy_ignore.py @@ -0,0 +1,321 @@ +"""Lazy .gitignore 加载管理器测试""" + +from pathlib import Path + +from jojo_code.context.lazy_ignore import LazyIgnoreManager + + +class TestLazyIgnoreManagerInit: + """测试 LazyIgnoreManager 初始化""" + + def test_init_with_gitignore(self, tmp_path: Path): + """测试有 .gitignore 时初始化""" + gitignore = tmp_path / ".gitignore" + gitignore.write_text("*.pyc\n__pycache__/\n", encoding="utf-8") + manager = LazyIgnoreManager(tmp_path) + assert manager._root == tmp_path.resolve() + + def test_init_without_gitignore(self, tmp_path: 
Path): + """测试没有 .gitignore 时初始化""" + manager = LazyIgnoreManager(tmp_path) + assert manager._root == tmp_path.resolve() + # 空的缓存集合 + assert tmp_path.resolve() in manager._cache + assert len(manager._cache[tmp_path.resolve()]) == 0 + + def test_root_is_resolved(self, tmp_path: Path): + """测试根目录被解析为绝对路径""" + subdir = tmp_path / "project" + subdir.mkdir() + manager = LazyIgnoreManager(subdir) + assert manager._root == subdir.resolve() + + def test_cache_loaded_on_init(self, tmp_path: Path): + """测试初始化时加载缓存""" + gitignore = tmp_path / ".gitignore" + gitignore.write_text("*.log\n*.tmp\n", encoding="utf-8") + manager = LazyIgnoreManager(tmp_path) + patterns = manager._cache[tmp_path.resolve()] + assert "*.log" in patterns + assert "*.tmp" in patterns + + +class TestParseGitignore: + """测试 .gitignore 解析""" + + def test_parse_basic_patterns(self, tmp_path: Path): + """测试解析基本模式""" + gitignore = tmp_path / ".gitignore" + gitignore.write_text("*.pyc\n*.log\nnode_modules/\n", encoding="utf-8") + manager = LazyIgnoreManager(tmp_path) + patterns = manager._cache[tmp_path.resolve()] + assert "*.pyc" in patterns + assert "*.log" in patterns + assert "node_modules/" in patterns + + def test_parse_ignores_comments(self, tmp_path: Path): + """测试忽略注释""" + gitignore = tmp_path / ".gitignore" + gitignore.write_text( + "# This is a comment\n*.pyc\n# Another comment\n*.log\n", + encoding="utf-8", + ) + manager = LazyIgnoreManager(tmp_path) + patterns = manager._cache[tmp_path.resolve()] + assert len(patterns) == 2 + assert "*.pyc" in patterns + assert "*.log" in patterns + + def test_parse_ignores_empty_lines(self, tmp_path: Path): + """测试忽略空行""" + gitignore = tmp_path / ".gitignore" + gitignore.write_text("*.pyc\n\n\n*.log\n\n", encoding="utf-8") + manager = LazyIgnoreManager(tmp_path) + patterns = manager._cache[tmp_path.resolve()] + assert len(patterns) == 2 + + def test_parse_strips_whitespace(self, tmp_path: Path): + """测试去除空白字符""" + gitignore = tmp_path / ".gitignore" + 
gitignore.write_text(" *.pyc \n *.log \n", encoding="utf-8") + manager = LazyIgnoreManager(tmp_path) + patterns = manager._cache[tmp_path.resolve()] + assert "*.pyc" in patterns + assert "*.log" in patterns + + def test_parse_empty_file(self, tmp_path: Path): + """测试解析空文件""" + gitignore = tmp_path / ".gitignore" + gitignore.write_text("", encoding="utf-8") + manager = LazyIgnoreManager(tmp_path) + patterns = manager._cache[tmp_path.resolve()] + assert len(patterns) == 0 + + def test_parse_only_comments(self, tmp_path: Path): + """测试只有注释的文件""" + gitignore = tmp_path / ".gitignore" + gitignore.write_text("# comment1\n# comment2\n", encoding="utf-8") + manager = LazyIgnoreManager(tmp_path) + patterns = manager._cache[tmp_path.resolve()] + assert len(patterns) == 0 + + +class TestShouldIgnore: + """测试 should_ignore 方法""" + + def test_ignore_matching_pattern(self, tmp_path: Path): + """测试匹配模式被忽略""" + gitignore = tmp_path / ".gitignore" + gitignore.write_text("*.pyc\n", encoding="utf-8") + manager = LazyIgnoreManager(tmp_path) + assert manager.should_ignore(tmp_path / "test.pyc") is True + + def test_not_ignore_unmatched(self, tmp_path: Path): + """测试不匹配模式不被忽略""" + gitignore = tmp_path / ".gitignore" + gitignore.write_text("*.pyc\n", encoding="utf-8") + manager = LazyIgnoreManager(tmp_path) + assert manager.should_ignore(tmp_path / "main.py") is False + + def test_ignore_directory_pattern(self, tmp_path: Path): + """测试目录模式""" + gitignore = tmp_path / ".gitignore" + gitignore.write_text("node_modules/\n", encoding="utf-8") + manager = LazyIgnoreManager(tmp_path) + target = tmp_path / "node_modules" + target.mkdir() + assert manager.should_ignore(target) is True + + def test_ignore_nested_file(self, tmp_path: Path): + """测试嵌套文件匹配""" + gitignore = tmp_path / ".gitignore" + gitignore.write_text("*.pyc\n", encoding="utf-8") + manager = LazyIgnoreManager(tmp_path) + nested = tmp_path / "src" / "deep" / "test.pyc" + nested.parent.mkdir(parents=True) + assert 
manager.should_ignore(nested) is True + + def test_ignore_subdirectory_gitignore(self, tmp_path: Path): + """测试子目录 .gitignore""" + root_gitignore = tmp_path / ".gitignore" + root_gitignore.write_text("*.log\n", encoding="utf-8") + subdir = tmp_path / "subdir" + subdir.mkdir() + sub_gitignore = subdir / ".gitignore" + sub_gitignore.write_text("*.tmp\n", encoding="utf-8") + manager = LazyIgnoreManager(tmp_path) + # 根目录模式 + assert manager.should_ignore(tmp_path / "test.log") is True + # 子目录模式 + assert manager.should_ignore(subdir / "test.tmp") is True + # 根目录模式对子目录文件也生效 + assert manager.should_ignore(subdir / "test.log") is True + + def test_outside_root_not_ignored(self, tmp_path: Path): + """测试项目外路径不被忽略""" + gitignore = tmp_path / ".gitignore" + gitignore.write_text("*.pyc\n", encoding="utf-8") + manager = LazyIgnoreManager(tmp_path) + outside = Path("/some/other/path/test.pyc") + assert manager.should_ignore(outside) is False + + def test_negation_pattern(self, tmp_path: Path): + """测试否定模式""" + gitignore = tmp_path / ".gitignore" + gitignore.write_text("*.pyc\n!important.pyc\n", encoding="utf-8") + manager = LazyIgnoreManager(tmp_path) + # 普通 .pyc 被忽略 + assert manager.should_ignore(tmp_path / "test.pyc") is True + # important.pyc 不被忽略(否定模式) + assert manager.should_ignore(tmp_path / "important.pyc") is False + + def test_root_relative_pattern(self, tmp_path: Path): + """测试根目录相对模式(以 / 开头)""" + gitignore = tmp_path / ".gitignore" + gitignore.write_text("/root_only.txt\n", encoding="utf-8") + manager = LazyIgnoreManager(tmp_path) + # 根目录下的文件被忽略 + assert manager.should_ignore(tmp_path / "root_only.txt") is True + # 子目录下的同名文件不被忽略 + subdir = tmp_path / "subdir" + subdir.mkdir() + assert manager.should_ignore(subdir / "root_only.txt") is False + + def test_multiple_patterns(self, tmp_path: Path): + """测试多个模式""" + gitignore = tmp_path / ".gitignore" + gitignore.write_text("*.pyc\n*.log\n*.tmp\nnode_modules/\n.env\n", encoding="utf-8") + manager = LazyIgnoreManager(tmp_path) 
class TestClearCache:
    """Checks the effect of ``clear_cache`` on the pattern cache."""

    def test_clear_cache_reloads_root(self, tmp_path: Path):
        """Root patterns still apply after the cache has been cleared."""
        (tmp_path / ".gitignore").write_text("*.log\n", encoding="utf-8")
        ignore_manager = LazyIgnoreManager(tmp_path)
        log_file = tmp_path / "test.log"

        assert ignore_manager.should_ignore(log_file) is True
        ignore_manager.clear_cache()
        assert ignore_manager.should_ignore(log_file) is True

    def test_clear_cache_removes_subdirectory_cache(self, tmp_path: Path):
        """A subdirectory entry loaded on demand is gone after clearing."""
        (tmp_path / ".gitignore").write_text("*.log\n", encoding="utf-8")
        ignore_manager = LazyIgnoreManager(tmp_path)

        # Querying a path inside the subdirectory loads its cache entry.
        nested_dir = tmp_path / "subdir"
        nested_dir.mkdir()
        (nested_dir / ".gitignore").write_text("*.tmp\n", encoding="utf-8")
        ignore_manager.should_ignore(nested_dir / "test.tmp")
        assert nested_dir.resolve() in ignore_manager._cache

        ignore_manager.clear_cache()
        # The subdirectory entry must not be present after clearing.
        assert nested_dir.resolve() not in ignore_manager._cache
class TestEdgeCases:
    """Boundary-condition checks for ``LazyIgnoreManager``."""

    def test_empty_gitignore(self, tmp_path: Path):
        """With an empty ``.gitignore`` a file is not ignored."""
        (tmp_path / ".gitignore").write_text("", encoding="utf-8")
        ignore_manager = LazyIgnoreManager(tmp_path)
        assert ignore_manager.should_ignore(tmp_path / "any_file.py") is False

    def test_gitignore_with_only_whitespace(self, tmp_path: Path):
        """With a whitespace-only ``.gitignore`` a file is not ignored."""
        (tmp_path / ".gitignore").write_text(" \n \n\n \n", encoding="utf-8")
        ignore_manager = LazyIgnoreManager(tmp_path)
        assert ignore_manager.should_ignore(tmp_path / "any_file.py") is False

    def test_special_characters_in_patterns(self, tmp_path: Path):
        """Glob metacharacters used as patterns do not make the lookup raise."""
        (tmp_path / ".gitignore").write_text("[\n]\n*\n?\n", encoding="utf-8")
        ignore_manager = LazyIgnoreManager(tmp_path)
        # No assertion on the result: the call completing is the check.
        ignore_manager.should_ignore(tmp_path / "test.py")

    def test_deeply_nested_subdirectory(self, tmp_path: Path):
        """Root and deep-directory patterns both apply four levels down."""
        (tmp_path / ".gitignore").write_text("*.log\n", encoding="utf-8")
        ignore_manager = LazyIgnoreManager(tmp_path)

        nested_dir = tmp_path / "a" / "b" / "c" / "d"
        nested_dir.mkdir(parents=True)
        (nested_dir / ".gitignore").write_text("*.secret\n", encoding="utf-8")

        for file_name in ("test.secret", "test.log"):
            assert ignore_manager.should_ignore(nested_dir / file_name) is True

    def test_file_outside_root_returns_false(self, tmp_path: Path):
        """A path unrelated to the managed root is reported as not ignored."""
        ignore_manager = LazyIgnoreManager(tmp_path)
        unrelated = Path("/completely/different/path")
        assert ignore_manager.should_ignore(unrelated) is False
class TestFindProjectRoot:
    """Tests for find_project_root function."""

    def test_from_directory_with_git(self, tmp_path):
        """Should find root when .git is present."""
        (tmp_path / ".git").mkdir()
        result = find_project_root(tmp_path)
        assert result == tmp_path

    def test_from_directory_with_pyproject(self, tmp_path):
        """Should find root when pyproject.toml is present."""
        (tmp_path / "pyproject.toml").touch()
        result = find_project_root(tmp_path)
        assert result == tmp_path

    def test_from_subdirectory(self, tmp_path):
        """Should find root from a subdirectory."""
        (tmp_path / ".git").mkdir()
        subdir = tmp_path / "src" / "jojo_code"
        subdir.mkdir(parents=True)
        result = find_project_root(subdir)
        assert result == tmp_path

    def test_from_nested_subdirectory(self, tmp_path):
        """Should find root from deeply nested subdirectory."""
        (tmp_path / "pyproject.toml").touch()
        deep_dir = tmp_path / "a" / "b" / "c" / "d"
        deep_dir.mkdir(parents=True)
        result = find_project_root(deep_dir)
        assert result == tmp_path

    def test_from_file_path(self, tmp_path):
        """Should find root when start path is a file."""
        (tmp_path / ".git").mkdir()
        file_path = tmp_path / "src" / "main.py"
        file_path.parent.mkdir(parents=True)
        file_path.touch()
        result = find_project_root(file_path)
        assert result == tmp_path

    def test_no_root_found(self, tmp_path):
        """Should return None when no root markers found."""
        # A directory under tmp_path is deliberately not used as the start:
        # an ancestor of tmp_path could carry a marker. The fixed probe path
        # below is used instead.
        # NOTE(review): this still assumes neither /tmp nor / contains a
        # root marker on the machine running the tests -- confirm for CI.
        result = find_project_root(Path("/tmp/jojo_test_no_root"))
        assert result is None

    def test_default_start_is_cwd(self, monkeypatch, tmp_path):
        """Should use cwd when start is None."""
        (tmp_path / ".git").mkdir()
        monkeypatch.chdir(tmp_path)
        result = find_project_root()
        assert result == tmp_path

    def test_finds_first_marker_upwards(self, tmp_path):
        """Should find the nearest root marker when multiple exist."""
        # Nested layout: outer/ has .git, outer/inner/ has pyproject.toml.
        outer = tmp_path / "outer"
        outer.mkdir()
        (outer / ".git").mkdir()

        inner = outer / "inner"
        inner.mkdir()
        (inner / "pyproject.toml").touch()

        # Starting from inner, the closest marker wins.
        result = find_project_root(inner)
        assert result == inner

    def test_marker_at_root(self, tmp_path):
        """Should not raise when the search starts at the filesystem root."""
        result = find_project_root(Path("/"))
        # Whether "/" holds a marker depends on the machine, so only the
        # result type is checked: either no root or a Path.
        assert result is None or isinstance(result, Path)
parse_agents_md(agents_file) + assert "CodeReviewer" in result + assert len(result["CodeReviewer"]) == 3 + assert "Review code quality" in result["CodeReviewer"] + + def test_multiple_agents(self, tmp_path): + """Should parse multiple agents.""" + agents_file = tmp_path / "AGENTS.md" + agents_file.write_text("""## CodeReviewer +- Review code quality +- Check for bugs + +## TestWriter +- Generate unit tests +- Create test fixtures + +## Researcher +- Search documentation +- Find examples +""") + result = parse_agents_md(agents_file) + assert len(result) == 3 + assert "CodeReviewer" in result + assert "TestWriter" in result + assert "Researcher" in result + + def test_agent_with_no_capabilities(self, tmp_path): + """Should handle agent with no bullet points.""" + agents_file = tmp_path / "AGENTS.md" + agents_file.write_text("""## EmptyAgent + +## AnotherAgent +- Has capability +""") + result = parse_agents_md(agents_file) + assert "EmptyAgent" in result + assert result["EmptyAgent"] == [] + assert "AnotherAgent" in result + assert len(result["AnotherAgent"]) == 1 + + def test_mixed_bullet_styles(self, tmp_path): + """Should handle both - and * bullet styles.""" + agents_file = tmp_path / "AGENTS.md" + agents_file.write_text("""## MixedAgent +- Dash bullet +* Star bullet +- Another dash +""") + result = parse_agents_md(agents_file) + assert len(result["MixedAgent"]) == 3 + + def test_ignores_non_header_lines(self, tmp_path): + """Should ignore lines that are not headers or bullets.""" + agents_file = tmp_path / "AGENTS.md" + agents_file.write_text("""Some preamble text + +## MyAgent +- Capability 1 + +Some random text in the middle + +- Capability 2 +""") + result = parse_agents_md(agents_file) + assert "MyAgent" in result + # Both bullets should be captured under MyAgent + assert len(result["MyAgent"]) == 2 + + def test_ignores_h1_headers(self, tmp_path): + """Should ignore # headers (only ## are agent headers).""" + agents_file = tmp_path / "AGENTS.md" + 
class TestLoadProjectContext:
    """Checks for the ``load_project_context`` helper."""

    def test_loads_from_project_root(self, tmp_path):
        """Starting at the root returns that root and its agents."""
        (tmp_path / ".git").mkdir()
        (tmp_path / "AGENTS.md").write_text("## MyAgent\n- Do something\n")

        context = load_project_context(tmp_path)

        assert context["root"] == str(tmp_path)
        assert "MyAgent" in context["agents"]

    def test_loads_from_subdirectory(self, tmp_path):
        """Starting below the root still resolves the root and its agents."""
        (tmp_path / ".git").mkdir()
        (tmp_path / "AGENTS.md").write_text("## TestAgent\n- Test capability\n")
        source_dir = tmp_path / "src"
        source_dir.mkdir()

        context = load_project_context(source_dir)

        assert context["root"] == str(tmp_path)
        assert "TestAgent" in context["agents"]

    def test_no_root_found(self, tmp_path):
        """When no root is reported, the agents mapping is empty."""
        start_dir = tmp_path / "no_root"
        start_dir.mkdir()

        context = load_project_context(start_dir)

        # An ancestor directory may carry a marker, so the agents check only
        # applies when no root was found.
        if context["root"] is None:
            assert context["agents"] == {}

    def test_no_agents_file(self, tmp_path):
        """A root without AGENTS.md yields an empty agents mapping."""
        (tmp_path / ".git").mkdir()

        context = load_project_context(tmp_path)

        assert context["root"] == str(tmp_path)
        assert context["agents"] == {}

    def test_default_start_is_cwd(self, monkeypatch, tmp_path):
        """Calling without arguments starts from the current directory."""
        (tmp_path / ".git").mkdir()
        monkeypatch.chdir(tmp_path)

        context = load_project_context()

        assert context["root"] == str(tmp_path)

    def test_result_structure(self, tmp_path):
        """The result is a dict exposing ``root`` and ``agents`` keys."""
        (tmp_path / ".git").mkdir()

        context = load_project_context(tmp_path)

        assert isinstance(context, dict)
        for key in ("root", "agents"):
            assert key in context

    def test_agents_is_dict(self, tmp_path):
        """``agents`` is a dict with one entry per declared agent."""
        (tmp_path / ".git").mkdir()
        (tmp_path / "AGENTS.md").write_text("## Agent1\n- Cap1\n\n## Agent2\n- Cap2\n")

        agents = load_project_context(tmp_path)["agents"]

        assert isinstance(agents, dict)
        assert len(agents) == 2
class TestJsonError:
    """Checks for the ``json_error`` helper."""

    def test_error_response_status(self):
        """The response carries the status code that was passed in."""
        assert json_error(404, "Not found").status == 404

    def test_error_response_body(self):
        """The body's ``error`` object holds the code and the message."""
        error = json.loads(json_error(400, "Bad request").text)["error"]
        assert error["code"] == 400
        assert error["message"] == "Bad request"

    def test_error_with_details(self):
        """A ``details`` argument is echoed inside the ``error`` object."""
        response = json_error(500, "Server error", details={"trace": "abc"})
        error = json.loads(response.text)["error"]
        assert error["details"] == {"trace": "abc"}

    def test_error_without_details(self):
        """Without ``details`` the field is present and null."""
        error = json.loads(json_error(401, "Unauthorized").text)["error"]
        assert error["details"] is None
class TestAPIRoute:
    """Checks for the ``APIRoute`` record."""

    def test_create_route(self):
        """Fields are stored as given and ``auth_required`` defaults to False."""
        route_handler = AsyncMock()
        created = APIRoute(path="/test", method="GET", handler=route_handler)

        assert (created.path, created.method) == ("/test", "GET")
        assert created.handler is route_handler
        assert created.auth_required is False

    def test_route_with_auth_required(self):
        """An explicit ``auth_required=True`` is kept on the route."""
        protected = APIRoute(
            path="/admin",
            method="POST",
            handler=AsyncMock(),
            auth_required=True,
        )
        assert protected.auth_required is True
class TestCORSMiddleware:
    """Checks for the headers ``CORSMiddleware`` puts on responses."""

    @pytest.mark.asyncio
    async def test_default_origins_allows_all(self):
        """Without configuration the allowed origin is the ``*`` wildcard."""
        outgoing = MagicMock(headers={})
        processed = await CORSMiddleware().process_response(outgoing)
        assert processed.headers["Access-Control-Allow-Origin"] == "*"

    @pytest.mark.asyncio
    async def test_custom_origins(self):
        """Every configured origin shows up in the allow-origin header."""
        cors = CORSMiddleware(allowed_origins=["https://example.com", "https://other.com"])
        outgoing = MagicMock(headers={})

        processed = await cors.process_response(outgoing)

        allow_origin = processed.headers["Access-Control-Allow-Origin"]
        for host in ("example.com", "other.com"):
            assert host in allow_origin

    @pytest.mark.asyncio
    async def test_cors_headers_present(self):
        """The allow-methods and allow-headers entries are both set."""
        outgoing = MagicMock(headers={})
        processed = await CORSMiddleware().process_response(outgoing)
        for header_name in ("Access-Control-Allow-Methods", "Access-Control-Allow-Headers"):
            assert header_name in processed.headers
test_custom_host_port(self): + """应支持自定义 host 和 port""" + server = APIServer(host="127.0.0.1", port=9090) + assert server.host == "127.0.0.1" + assert server.port == 9090 + + def test_builtin_routes_registered(self): + """初始化时应注册内置路由""" + server = APIServer() + paths = [r.path for r in server.routes] + assert "/health" in paths + assert "/api/routes" in paths + + def test_default_middlewares(self): + """默认应包含 Auth、CORS、RateLimit 中间件""" + server = APIServer() + types = [type(m).__name__ for m in server.middlewares] + assert "AuthMiddleware" in types + assert "CORSMiddleware" in types + assert "RateLimitMiddleware" in types + + def test_add_route(self): + """add_route 应添加路由""" + server = APIServer() + initial_count = len(server.routes) + + async def custom_handler(request): + return json_response({"custom": True}) + + server.add_route("GET", "/custom", custom_handler) + assert len(server.routes) == initial_count + 1 + assert any(r.path == "/custom" for r in server.routes) + + def test_add_route_with_auth(self): + """add_route 应支持 auth_required 标志""" + server = APIServer() + + async def handler(request): + return json_response({}) + + server.add_route("POST", "/protected", handler, auth_required=True) + protected = [r for r in server.routes if r.path == "/protected"] + assert len(protected) == 1 + assert protected[0].auth_required is True + + @pytest.mark.asyncio + async def test_health_check(self): + """health_check 应返回健康状态""" + server = APIServer() + request = MagicMock() + result = await server.health_check(request) + body = json.loads(result.text) + assert body["status"] == "healthy" + assert "timestamp" in body + + @pytest.mark.asyncio + async def test_list_routes(self): + """list_routes 应返回所有路由信息""" + server = APIServer() + request = MagicMock() + result = await server.list_routes(request) + body = json.loads(result.text) + assert "routes" in body + assert isinstance(body["routes"], list) + # 至少有内置的 health 和 api/routes + assert len(body["routes"]) >= 2 + + +# 
class TestAPIServerBuilder:
    """Checks for servers produced by ``APIServerBuilder``."""

    def test_build_default_server(self):
        """An untouched builder produces an ``APIServer``."""
        built = APIServerBuilder().build()
        assert isinstance(built, APIServer)

    def test_build_with_custom_host(self):
        """Host and port set on the builder carry over to the server."""
        server_builder = APIServerBuilder()
        server_builder.host = "127.0.0.1"
        server_builder.port = 9090

        built = server_builder.build()

        assert built.host == "127.0.0.1"
        assert built.port == 9090

    def test_build_with_api_key(self):
        """The builder's API key reaches the single auth middleware."""
        server_builder = APIServerBuilder()
        server_builder.api_key = "test-key"

        built = server_builder.build()

        auth_layers = list(filter(lambda m: isinstance(m, AuthMiddleware), built.middlewares))
        assert len(auth_layers) == 1
        assert auth_layers[0].api_key == "test-key"

    def test_build_with_custom_middleware(self):
        """An appended middleware is present in addition to the default one."""
        server_builder = APIServerBuilder()
        server_builder.middlewares.append(CORSMiddleware(allowed_origins=["https://custom.com"]))

        built = server_builder.build()

        cors_layers = list(filter(lambda m: isinstance(m, CORSMiddleware), built.middlewares))
        # Expect the default CORS middleware plus the custom one.
        assert len(cors_layers) >= 2

    def test_build_with_routes(self):
        """A route tuple appended to the builder is registered on the server."""
        server_builder = APIServerBuilder()
        server_builder.routes.append(("GET", "/test", AsyncMock()))

        built = server_builder.build()

        assert "/test" in [route.path for route in built.routes]
conv.title = "Test" + conv.messages = [] + conv.created_at = datetime(2026, 1, 1) + manager.list_conversations.return_value = [conv] + + api = ConversationAPI(manager) + request = MagicMock() + request.query = {"limit": "20", "offset": "0"} + + result = await api.list_conversations(request) + body = json.loads(result.text) + assert len(body["conversations"]) == 1 + assert body["conversations"][0]["id"] == "conv-1" + + @pytest.mark.asyncio + async def test_list_conversations_with_pagination(self): + """应支持分页""" + manager = MagicMock() + conversations = [] + for i in range(5): + conv = MagicMock() + conv.id = f"conv-{i}" + conv.title = f"Conv {i}" + conv.messages = [] + conv.created_at = None + conversations.append(conv) + manager.list_conversations.return_value = conversations + + api = ConversationAPI(manager) + request = MagicMock() + request.query = {"limit": "2", "offset": "1"} + + result = await api.list_conversations(request) + body = json.loads(result.text) + assert len(body["conversations"]) == 2 + + @pytest.mark.asyncio + async def test_get_conversation_found(self): + """get_conversation 应返回对话详情""" + manager = MagicMock() + conv = MagicMock() + conv.id = "conv-1" + conv.title = "Test" + msg = MagicMock() + msg.role = "user" + msg.content = "Hello" + msg.timestamp = datetime(2026, 1, 1) + conv.messages = [msg] + manager.get_conversation.return_value = conv + + api = ConversationAPI(manager) + request = MagicMock() + request.match_info = {"id": "conv-1"} + + result = await api.get_conversation(request) + body = json.loads(result.text) + assert body["id"] == "conv-1" + assert len(body["messages"]) == 1 + + @pytest.mark.asyncio + async def test_get_conversation_not_found(self): + """对话不存在时应返回 404""" + manager = MagicMock() + manager.get_conversation.return_value = None + + api = ConversationAPI(manager) + request = MagicMock() + request.match_info = {"id": "nonexistent"} + + result = await api.get_conversation(request) + assert result.status == 404 + + 
@pytest.mark.asyncio + async def test_create_conversation(self): + """create_conversation 应创建新对话""" + manager = MagicMock() + conv = MagicMock() + conv.id = "new-conv" + conv.title = "New Chat" + manager.create_conversation.return_value = conv + + api = ConversationAPI(manager) + request = AsyncMock() + request.json.return_value = {"title": "New Chat"} + + result = await api.create_conversation(request) + assert result.status == 201 + body = json.loads(result.text) + assert body["id"] == "new-conv" + + @pytest.mark.asyncio + async def test_send_message_not_found(self): + """对话不存在时发送消息应返回 404""" + manager = MagicMock() + manager.get_conversation.return_value = None + + api = ConversationAPI(manager) + request = AsyncMock() + request.match_info = {"id": "nonexistent"} + request.json.return_value = {"content": "Hello"} + + result = await api.send_message(request) + assert result.status == 404 + + @pytest.mark.asyncio + async def test_send_message_empty_content(self): + """空消息应返回 400""" + manager = MagicMock() + conv = MagicMock() + manager.get_conversation.return_value = conv + + api = ConversationAPI(manager) + request = AsyncMock() + request.match_info = {"id": "conv-1"} + request.json.return_value = {"content": ""} + + result = await api.send_message(request) + assert result.status == 400 + + +# ============================================================================= +# 全局函数测试 +# ============================================================================= + + +class TestGlobalFunctions: + """全局函数测试""" + + def test_create_api_server(self): + """create_api_server 应创建并缓存服务器""" + import jojo_code.core.api_server as module + + module._api_server = None + server = create_api_server() + assert isinstance(server, APIServer) + + def test_get_api_server_returns_cached(self): + """get_api_server 应返回缓存的实例""" + import jojo_code.core.api_server as module + + module._api_server = None + create_api_server() + cached = get_api_server() + assert cached is module._api_server + 
+ def test_get_api_server_returns_none_when_not_created(self): + """未创建时 get_api_server 应返回 None""" + import jojo_code.core.api_server as module + + module._api_server = None + assert get_api_server() is None + + +# ============================================================================= +# require_auth 装饰器测试 +# ============================================================================= + + +class TestRequireAuth: + """require_auth 装饰器测试""" + + @pytest.mark.asyncio + async def test_missing_auth_returns_401(self): + """缺少 Authorization 头应返回 401""" + + @require_auth + async def handler(request): + return json_response({"ok": True}) + + request = MagicMock() + request.headers = {} + result = await handler(request) + assert result.status == 401 + + @pytest.mark.asyncio + async def test_valid_auth_passes_through(self): + """有效的 Authorization 头应通过""" + + @require_auth + async def handler(request): + return json_response({"ok": True}) + + request = MagicMock() + request.headers = {"Authorization": "Bearer some-token"} + result = await handler(request) + body = json.loads(result.text) + assert body["ok"] is True diff --git a/tests/test_core/test_database.py b/tests/test_core/test_database.py new file mode 100644 index 0000000..10405f8 --- /dev/null +++ b/tests/test_core/test_database.py @@ -0,0 +1,544 @@ +"""Core 数据库抽象层测试 + +测试 Query, Record, MockDatabaseBackend, MockCursor, +Repository, DatabaseManager, get_db_manager, get_repository。 +""" + +from datetime import datetime + +import pytest + +from jojo_code.core.database import ( + ConnectionError, + DatabaseBackend, + DatabaseError, + DatabaseManager, + MockCursor, + MockDatabaseBackend, + Query, + QueryError, + Record, + Repository, + get_db_manager, + get_repository, +) + +# --- Data classes --- + + +class TestQuery: + def test_default_values(self): + q = Query(table="users") + assert q.table == "users" + assert q.filter == {} + assert q.projection is None + assert q.sort is None + assert q.limit is None + assert 
q.offset is None + + def test_custom_values(self): + q = Query( + table="orders", + filter={"status": "active"}, + projection=["id", "name"], + sort=[("created_at", "desc")], + limit=10, + offset=20, + ) + assert q.table == "orders" + assert q.filter == {"status": "active"} + assert q.projection == ["id", "name"] + assert q.sort == [("created_at", "desc")] + assert q.limit == 10 + assert q.offset == 20 + + +class TestRecord: + def test_default_values(self): + r = Record() + assert r.id is None + assert r.data == {} + assert r.created_at is None + assert r.updated_at is None + + def test_with_values(self): + now = datetime.now() + r = Record(id="abc", data={"name": "test"}, created_at=now, updated_at=now) + assert r.id == "abc" + assert r.data == {"name": "test"} + assert r.created_at == now + assert r.updated_at == now + + +# --- Exceptions --- + + +class TestExceptions: + def test_database_error_is_exception(self): + assert issubclass(DatabaseError, Exception) + + def test_connection_error_is_database_error(self): + assert issubclass(ConnectionError, DatabaseError) + + def test_query_error_is_database_error(self): + assert issubclass(QueryError, DatabaseError) + + def test_database_error_message(self): + err = DatabaseError("something broke") + assert str(err) == "something broke" + + def test_connection_error_message(self): + err = ConnectionError("cannot connect") + assert str(err) == "cannot connect" + + def test_query_error_message(self): + err = QueryError("bad query") + assert str(err) == "bad query" + + +# --- MockCursor --- + + +class TestMockCursor: + def test_rowcount_default(self): + cursor = MockCursor() + assert cursor.rowcount == 0 + + def test_lastrowid(self): + cursor = MockCursor() + assert cursor.lastrowid is None + + +# --- MockDatabaseBackend --- + + +class TestMockDatabaseBackend: + @pytest.fixture + def backend(self): + return MockDatabaseBackend() + + @pytest.mark.asyncio + async def test_connect(self, backend): + await backend.connect() # 
Should not raise + + @pytest.mark.asyncio + async def test_disconnect_clears_data(self, backend): + await backend.create_table("t", {}) + await backend.insert("t", {"x": 1}) + await backend.disconnect() + assert backend._data == {} + + @pytest.mark.asyncio + async def test_execute_returns_mock_cursor(self, backend): + result = await backend.execute("SELECT 1") + assert isinstance(result, MockCursor) + + @pytest.mark.asyncio + async def test_create_table(self, backend): + await backend.create_table("users", {"name": {"type": "TEXT"}}) + assert "users" in backend._data + assert backend._data["users"] == [] + + @pytest.mark.asyncio + async def test_create_table_idempotent(self, backend): + await backend.create_table("users", {}) + await backend.create_table("users", {}) + assert "users" in backend._data + + @pytest.mark.asyncio + async def test_drop_table(self, backend): + await backend.create_table("users", {}) + await backend.drop_table("users") + assert "users" not in backend._data + + @pytest.mark.asyncio + async def test_drop_nonexistent_table(self, backend): + await backend.drop_table("nope") # Should not raise + + @pytest.mark.asyncio + async def test_insert_returns_id(self, backend): + await backend.create_table("t", {}) + id_ = await backend.insert("t", {"name": "Alice"}) + assert isinstance(id_, str) + assert len(id_) > 0 + + @pytest.mark.asyncio + async def test_insert_creates_table_if_missing(self, backend): + id_ = await backend.insert("new_table", {"x": 1}) + assert id_ is not None + assert "new_table" in backend._data + + @pytest.mark.asyncio + async def test_insert_stores_record(self, backend): + await backend.create_table("t", {}) + id_ = await backend.insert("t", {"name": "Alice"}) + records = backend._data["t"] + assert len(records) == 1 + assert records[0].id == id_ + assert records[0].data["name"] == "Alice" + assert records[0].created_at is not None + assert records[0].updated_at is not None + + @pytest.mark.asyncio + async def 
test_update_existing_record(self, backend): + await backend.create_table("t", {}) + id_ = await backend.insert("t", {"name": "Alice"}) + result = await backend.update("t", id_, {"name": "Bob"}) + assert result is True + records = backend._data["t"] + assert records[0].data["name"] == "Bob" + + @pytest.mark.asyncio + async def test_update_nonexistent_record(self, backend): + await backend.create_table("t", {}) + result = await backend.update("t", "fake-id", {"name": "Bob"}) + assert result is False + + @pytest.mark.asyncio + async def test_update_nonexistent_table(self, backend): + result = await backend.update("nope", "id", {"x": 1}) + assert result is False + + @pytest.mark.asyncio + async def test_delete_existing_record(self, backend): + await backend.create_table("t", {}) + id_ = await backend.insert("t", {"name": "Alice"}) + result = await backend.delete("t", id_) + assert result is True + assert len(backend._data["t"]) == 0 + + @pytest.mark.asyncio + async def test_delete_nonexistent_record(self, backend): + await backend.create_table("t", {}) + result = await backend.delete("t", "fake-id") + assert result is False + + @pytest.mark.asyncio + async def test_delete_nonexistent_table(self, backend): + result = await backend.delete("nope", "id") + assert result is False + + @pytest.mark.asyncio + async def test_fetch_one_found(self, backend): + await backend.create_table("t", {}) + id_ = await backend.insert("t", {"name": "Alice"}) + query = Query(table="t", filter={"name": "Alice"}) + record = await backend.fetch_one(query) + assert record is not None + assert record.id == id_ + + @pytest.mark.asyncio + async def test_fetch_one_not_found(self, backend): + await backend.create_table("t", {}) + query = Query(table="t", filter={"name": "Nobody"}) + record = await backend.fetch_one(query) + assert record is None + + @pytest.mark.asyncio + async def test_fetch_one_empty_filter(self, backend): + await backend.create_table("t", {}) + await backend.insert("t", {"name": 
"Alice"}) + query = Query(table="t") + record = await backend.fetch_one(query) + assert record is not None + + @pytest.mark.asyncio + async def test_fetch_many_no_filter(self, backend): + await backend.create_table("t", {}) + await backend.insert("t", {"name": "Alice"}) + await backend.insert("t", {"name": "Bob"}) + query = Query(table="t") + records = await backend.fetch_many(query) + assert len(records) == 2 + + @pytest.mark.asyncio + async def test_fetch_many_with_filter(self, backend): + await backend.create_table("t", {}) + await backend.insert("t", {"role": "admin", "name": "Alice"}) + await backend.insert("t", {"role": "user", "name": "Bob"}) + query = Query(table="t", filter={"role": "admin"}) + records = await backend.fetch_many(query) + assert len(records) == 1 + assert records[0].data["name"] == "Alice" + + @pytest.mark.asyncio + async def test_fetch_many_with_sort_asc(self, backend): + await backend.create_table("t", {}) + await backend.insert("t", {"name": "Charlie"}) + await backend.insert("t", {"name": "Alice"}) + await backend.insert("t", {"name": "Bob"}) + query = Query(table="t", sort=[("name", "asc")]) + records = await backend.fetch_many(query) + assert [r.data["name"] for r in records] == ["Alice", "Bob", "Charlie"] + + @pytest.mark.asyncio + async def test_fetch_many_with_sort_desc(self, backend): + await backend.create_table("t", {}) + await backend.insert("t", {"name": "Charlie"}) + await backend.insert("t", {"name": "Alice"}) + await backend.insert("t", {"name": "Bob"}) + query = Query(table="t", sort=[("name", "desc")]) + records = await backend.fetch_many(query) + assert [r.data["name"] for r in records] == ["Charlie", "Bob", "Alice"] + + @pytest.mark.asyncio + async def test_fetch_many_with_limit(self, backend): + await backend.create_table("t", {}) + for i in range(5): + await backend.insert("t", {"idx": i}) + query = Query(table="t", limit=2) + records = await backend.fetch_many(query) + assert len(records) == 2 + + 
@pytest.mark.asyncio + async def test_fetch_many_with_offset(self, backend): + await backend.create_table("t", {}) + await backend.insert("t", {"name": "A"}) + await backend.insert("t", {"name": "B"}) + await backend.insert("t", {"name": "C"}) + query = Query(table="t", offset=2) + records = await backend.fetch_many(query) + assert len(records) == 1 + assert records[0].data["name"] == "C" + + @pytest.mark.asyncio + async def test_fetch_many_empty_table(self, backend): + await backend.create_table("t", {}) + query = Query(table="t") + records = await backend.fetch_many(query) + assert records == [] + + @pytest.mark.asyncio + async def test_fetch_many_nonexistent_table(self, backend): + query = Query(table="nope") + records = await backend.fetch_many(query) + assert records == [] + + +# --- DatabaseBackend abstract --- + + +class TestDatabaseBackendAbstract: + def test_cannot_instantiate(self): + with pytest.raises(TypeError): + DatabaseBackend() + + +# --- Repository --- + + +class _SimpleModel: + """Simple model for Repository tests.""" + + def __init__(self, name: str = "", value: int = 0): + self.name = name + self.value = value + + +class TestRepository: + @pytest.fixture + def backend(self): + b = MockDatabaseBackend() + return b + + @pytest.fixture + def repo(self, backend): + return Repository(backend, "items", _SimpleModel) + + @pytest.mark.asyncio + async def test_create(self, repo, backend): + await backend.create_table("items", {}) + id_ = await repo.create({"name": "widget", "value": 42}) + assert isinstance(id_, str) + + @pytest.mark.asyncio + async def test_find_by_id(self, repo, backend): + await backend.create_table("items", {}) + id_ = await repo.create({"name": "widget", "value": 42}) + model = await repo.find_by_id(id_) + assert model is not None + assert model.name == "widget" + assert model.value == 42 + + @pytest.mark.asyncio + async def test_find_by_id_not_found(self, repo, backend): + await backend.create_table("items", {}) + model = await 
repo.find_by_id("nonexistent") + assert model is None + + @pytest.mark.asyncio + async def test_find_one(self, repo, backend): + await backend.create_table("items", {}) + await repo.create({"name": "widget", "value": 42}) + model = await repo.find_one({"name": "widget"}) + assert model is not None + assert model.value == 42 + + @pytest.mark.asyncio + async def test_find_one_not_found(self, repo, backend): + await backend.create_table("items", {}) + model = await repo.find_one({"name": "missing"}) + assert model is None + + @pytest.mark.asyncio + async def test_find_many(self, repo, backend): + await backend.create_table("items", {}) + await repo.create({"name": "a", "value": 1}) + await repo.create({"name": "b", "value": 2}) + await repo.create({"name": "c", "value": 3}) + models = await repo.find_many() + assert len(models) == 3 + + @pytest.mark.asyncio + async def test_find_many_with_filter(self, repo, backend): + await backend.create_table("items", {}) + await repo.create({"name": "a", "value": 1}) + await repo.create({"name": "b", "value": 2}) + models = await repo.find_many(filter={"name": "a"}) + assert len(models) == 1 + assert models[0].value == 1 + + @pytest.mark.asyncio + async def test_find_many_with_sort(self, repo, backend): + await backend.create_table("items", {}) + await repo.create({"name": "c", "value": 3}) + await repo.create({"name": "a", "value": 1}) + await repo.create({"name": "b", "value": 2}) + models = await repo.find_many(sort=[("name", "asc")]) + assert [m.name for m in models] == ["a", "b", "c"] + + @pytest.mark.asyncio + async def test_find_many_with_limit_and_offset(self, repo, backend): + await backend.create_table("items", {}) + for i in range(5): + await repo.create({"name": f"item{i}", "value": i}) + models = await repo.find_many(limit=2, offset=1) + assert len(models) == 2 + + @pytest.mark.asyncio + async def test_update(self, repo, backend): + await backend.create_table("items", {}) + id_ = await repo.create({"name": "old", 
"value": 1}) + result = await repo.update(id_, {"name": "new", "value": 2}) + assert result is True + model = await repo.find_by_id(id_) + assert model.name == "new" + assert model.value == 2 + + @pytest.mark.asyncio + async def test_delete(self, repo, backend): + await backend.create_table("items", {}) + id_ = await repo.create({"name": "gone", "value": 0}) + result = await repo.delete(id_) + assert result is True + model = await repo.find_by_id(id_) + assert model is None + + @pytest.mark.asyncio + async def test_count_no_filter(self, repo, backend): + await backend.create_table("items", {}) + await repo.create({"name": "a", "value": 1}) + await repo.create({"name": "b", "value": 2}) + assert await repo.count() == 2 + + @pytest.mark.asyncio + async def test_count_with_filter(self, repo, backend): + await backend.create_table("items", {}) + await repo.create({"name": "a", "value": 1}) + await repo.create({"name": "b", "value": 2}) + assert await repo.count(filter={"name": "a"}) == 1 + + @pytest.mark.asyncio + async def test_count_empty(self, repo, backend): + await backend.create_table("items", {}) + assert await repo.count() == 0 + + +# --- DatabaseManager --- + + +class TestDatabaseManager: + @pytest.fixture + def manager(self): + return DatabaseManager() + + def test_add_backend(self, manager): + backend = MockDatabaseBackend() + manager.add_backend("mock", backend) + assert "mock" in manager.backends + + def test_set_default(self, manager): + backend = MockDatabaseBackend() + manager.add_backend("mock", backend) + manager.set_default("mock") + assert manager._default is backend + + def test_set_default_unknown_name_is_noop(self, manager): + manager.set_default("nonexistent") + assert manager._default is None + + def test_get_backend_by_name(self, manager): + backend = MockDatabaseBackend() + manager.add_backend("mock", backend) + result = manager.get_backend("mock") + assert result is backend + + def test_get_backend_unknown_name_raises(self, manager): + with 
pytest.raises(ValueError, match="Backend .* not found"): + manager.get_backend("nonexistent") + + def test_get_backend_returns_default(self, manager): + backend = MockDatabaseBackend() + manager.add_backend("mock", backend) + manager.set_default("mock") + result = manager.get_backend() + assert result is backend + + def test_get_backend_auto_creates_mock(self, manager): + result = manager.get_backend() + assert isinstance(result, MockDatabaseBackend) + assert "default" in manager.backends + + @pytest.mark.asyncio + async def test_connection_context_manager(self, manager): + backend = MockDatabaseBackend() + manager.add_backend("mock", backend) + manager.set_default("mock") + async with manager.connection() as conn: + assert conn is backend + + @pytest.mark.asyncio + async def test_connection_by_name(self, manager): + backend = MockDatabaseBackend() + manager.add_backend("mock", backend) + async with manager.connection("mock") as conn: + assert conn is backend + + def test_get_repository(self, manager): + repo = manager.get_repository("items", _SimpleModel) + assert isinstance(repo, Repository) + assert repo.table == "items" + + +# --- Module-level helpers --- + + +class TestGlobalHelpers: + def test_get_db_manager_returns_same_instance(self): + # Reset global state first + import jojo_code.core.database as db_mod + + original = db_mod._db_manager + try: + db_mod._db_manager = None + m1 = get_db_manager() + m2 = get_db_manager() + assert m1 is m2 + assert isinstance(m1, DatabaseManager) + finally: + db_mod._db_manager = original + + def test_get_repository_returns_repository(self): + repo = get_repository("items", _SimpleModel) + assert isinstance(repo, Repository) + assert repo.table == "items" diff --git a/tests/test_core/test_error_code.py b/tests/test_core/test_error_code.py new file mode 100644 index 0000000..73756ef --- /dev/null +++ b/tests/test_core/test_error_code.py @@ -0,0 +1,415 @@ +"""Error Code 模块测试 + +测试 ErrorCode, ErrorContext, ErrorCategory, 
ERROR_MESSAGES, +get_error_message, is_retryable_error。 +""" + +from jojo_code.core.error_code import ( + ERROR_MESSAGES, + ErrorCategory, + ErrorCode, + ErrorContext, + get_error_message, + is_retryable_error, +) + +# ============================================================================= +# ErrorCode 枚举测试 +# ============================================================================= + + +class TestErrorCode: + """ErrorCode 枚举测试""" + + def test_config_error_codes(self): + """配置错误码应为 1xxx""" + assert ErrorCode.CONFIG_NOT_FOUND == 1001 + assert ErrorCode.CONFIG_INVALID == 1002 + assert ErrorCode.CONFIG_PERMISSION_DENIED == 1003 + + def test_llm_error_codes(self): + """LLM 错误码应为 2xxx""" + assert ErrorCode.LLM_API_ERROR == 2001 + assert ErrorCode.LLM_TIMEOUT == 2002 + assert ErrorCode.LLM_RATE_LIMIT == 2003 + assert ErrorCode.LLM_INVALID_RESPONSE == 2004 + assert ErrorCode.LLM_MODEL_NOT_FOUND == 2005 + assert ErrorCode.LLM_CONTEXT_OVERFLOW == 2006 + assert ErrorCode.LLM_AUTH_FAILED == 2007 + + def test_tool_error_codes(self): + """工具错误码应为 3xxx""" + assert ErrorCode.TOOL_NOT_FOUND == 3001 + assert ErrorCode.TOOL_EXECUTION_FAILED == 3002 + assert ErrorCode.TOOL_TIMEOUT == 3003 + assert ErrorCode.TOOL_PERMISSION_DENIED == 3004 + assert ErrorCode.TOOL_INVALID_INPUT == 3005 + assert ErrorCode.TOOL_OUTPUT_TOO_LARGE == 3006 + + def test_security_error_codes(self): + """安全错误码应为 4xxx""" + assert ErrorCode.SECURITY_DENIED == 4001 + assert ErrorCode.SECURITY_RISK_HIGH == 4002 + assert ErrorCode.SECURITY_PATH_FORBIDDEN == 4003 + assert ErrorCode.SECURITY_COMMAND_FORBIDDEN == 4004 + + def test_validation_error_codes(self): + """验证错误码应为 5xxx""" + assert ErrorCode.VALIDATION_FAILED == 5001 + assert ErrorCode.VALIDATION_TYPE_ERROR == 5002 + assert ErrorCode.VALIDATION_CONSTRAINT == 5003 + + def test_network_error_codes(self): + """网络错误码应为 6xxx""" + assert ErrorCode.NETWORK_ERROR == 6001 + assert ErrorCode.NETWORK_TIMEOUT == 6002 + assert ErrorCode.NETWORK_CONNECTION_REFUSED 
== 6003 + + def test_task_error_codes(self): + """任务错误码应为 7xxx""" + assert ErrorCode.TASK_NOT_FOUND == 7001 + assert ErrorCode.TASK_FAILED == 7002 + assert ErrorCode.TASK_TIMEOUT == 7003 + assert ErrorCode.TASK_CANCELLED == 7004 + assert ErrorCode.TASK_MAX_RETRIES_EXCEEDED == 7005 + + def test_internal_error_codes(self): + """内部错误码应为 9xxx""" + assert ErrorCode.INTERNAL_ERROR == 9001 + assert ErrorCode.NOT_IMPLEMENTED == 9002 + assert ErrorCode.UNEXPECTED_ERROR == 9003 + + def test_error_codes_are_unique(self): + """所有错误码应该是唯一的""" + codes = [e.value for e in ErrorCode] + assert len(codes) == len(set(codes)) + + def test_error_code_is_int(self): + """ErrorCode 应该是 IntEnum""" + assert isinstance(ErrorCode.CONFIG_NOT_FOUND, int) + assert ErrorCode.CONFIG_NOT_FOUND == 1001 + + def test_error_code_from_value(self): + """应该能从整数值创建 ErrorCode""" + code = ErrorCode(1001) + assert code == ErrorCode.CONFIG_NOT_FOUND + + def test_all_error_codes_have_prefix_ranges(self): + """错误码应按照文档中的范围定义""" + for code in ErrorCode: + value = code.value + # 所有错误码应在 1000-9999 范围内 + assert 1000 <= value <= 9999, f"{code} has value {value}" + + +# ============================================================================= +# ErrorContext 测试 +# ============================================================================= + + +class TestErrorContext: + """ErrorContext 数据类测试""" + + def test_create_error_context(self): + """应该正确创建错误上下文""" + ctx = ErrorContext( + code=ErrorCode.LLM_API_ERROR, + message="API call failed", + ) + assert ctx.code == ErrorCode.LLM_API_ERROR + assert ctx.message == "API call failed" + assert ctx.details == {} + assert ctx.hint is None + assert ctx.source is None + assert ctx.correlation_id is None + + def test_error_context_with_details(self): + """应该支持 details 字段""" + ctx = ErrorContext( + code=ErrorCode.TOOL_EXECUTION_FAILED, + message="Tool failed", + details={"tool": "read_file", "path": "/tmp/test.py"}, + ) + assert ctx.details["tool"] == "read_file" + assert 
ctx.details["path"] == "/tmp/test.py" + + def test_error_context_with_hint(self): + """应该支持 hint 字段""" + ctx = ErrorContext( + code=ErrorCode.CONFIG_NOT_FOUND, + message="Config not found", + hint="Run 'jojo-code setup' to create config", + ) + assert ctx.hint == "Run 'jojo-code setup' to create config" + + def test_error_context_with_source(self): + """应该支持 source 字段""" + ctx = ErrorContext( + code=ErrorCode.INTERNAL_ERROR, + message="Something broke", + source="agent/graph.py", + ) + assert ctx.source == "agent/graph.py" + + def test_error_context_with_correlation_id(self): + """应该支持 correlation_id 字段""" + ctx = ErrorContext( + code=ErrorCode.NETWORK_ERROR, + message="Connection failed", + correlation_id="trace-abc-123", + ) + assert ctx.correlation_id == "trace-abc-123" + + def test_error_context_full(self): + """应该支持所有字段""" + ctx = ErrorContext( + code=ErrorCode.LLM_TIMEOUT, + message="Request timed out", + details={"timeout_ms": 30000}, + hint="Increase timeout in config", + source="core/llm.py", + correlation_id="req-xyz", + ) + assert ctx.code == ErrorCode.LLM_TIMEOUT + assert ctx.message == "Request timed out" + assert ctx.details == {"timeout_ms": 30000} + assert ctx.hint == "Increase timeout in config" + assert ctx.source == "core/llm.py" + assert ctx.correlation_id == "req-xyz" + + def test_details_default_factory(self): + """每个实例应有独立的 details 字典""" + ctx1 = ErrorContext(code=ErrorCode.INTERNAL_ERROR, message="err1") + ctx2 = ErrorContext(code=ErrorCode.INTERNAL_ERROR, message="err2") + ctx1.details["key"] = "value" + assert "key" not in ctx2.details + + +# ============================================================================= +# ErrorCategory 测试 +# ============================================================================= + + +class TestErrorCategory: + """ErrorCategory 错误类别映射测试""" + + def test_config_category(self): + """配置错误应归类为 config""" + assert ErrorCategory.get(ErrorCode.CONFIG_NOT_FOUND) == "config" + assert 
ErrorCategory.get(ErrorCode.CONFIG_INVALID) == "config" + + def test_llm_category(self): + """LLM 错误应归类为 llm""" + assert ErrorCategory.get(ErrorCode.LLM_API_ERROR) == "llm" + assert ErrorCategory.get(ErrorCode.LLM_TIMEOUT) == "llm" + + def test_tool_category(self): + """工具错误应归类为 tool""" + assert ErrorCategory.get(ErrorCode.TOOL_NOT_FOUND) == "tool" + assert ErrorCategory.get(ErrorCode.TOOL_EXECUTION_FAILED) == "tool" + + def test_security_category(self): + """安全错误应归类为 security""" + assert ErrorCategory.get(ErrorCode.SECURITY_DENIED) == "security" + assert ErrorCategory.get(ErrorCode.SECURITY_RISK_HIGH) == "security" + + def test_validation_category(self): + """验证错误应归类为 validation""" + assert ErrorCategory.get(ErrorCode.VALIDATION_FAILED) == "validation" + + def test_network_category(self): + """网络错误应归类为 network""" + assert ErrorCategory.get(ErrorCode.NETWORK_ERROR) == "network" + + def test_task_category(self): + """任务错误应归类为 task""" + assert ErrorCategory.get(ErrorCode.TASK_FAILED) == "task" + + def test_internal_category(self): + """内部错误应归类为 internal""" + assert ErrorCategory.get(ErrorCode.INTERNAL_ERROR) == "internal" + + def test_unknown_category(self): + """未映射的错误码应返回 unknown""" + # SECURITY_PATH_FORBIDDEN 没有在 _CATEGORIES 中映射 + result = ErrorCategory.get(ErrorCode.SECURITY_PATH_FORBIDDEN) + assert result == "unknown" + + def test_retryable_codes(self): + """可重试的错误码应正确标识""" + retryable = [ + ErrorCode.LLM_API_ERROR, + ErrorCode.LLM_TIMEOUT, + ErrorCode.LLM_RATE_LIMIT, + ErrorCode.NETWORK_ERROR, + ErrorCode.NETWORK_TIMEOUT, + ErrorCode.TASK_FAILED, + ] + for code in retryable: + assert ErrorCategory.is_retryable(code) is True, f"{code} should be retryable" + + def test_non_retryable_codes(self): + """不可重试的错误码应正确标识""" + non_retryable = [ + ErrorCode.CONFIG_NOT_FOUND, + ErrorCode.TOOL_NOT_FOUND, + ErrorCode.SECURITY_DENIED, + ErrorCode.VALIDATION_FAILED, + ErrorCode.INTERNAL_ERROR, + ] + for code in non_retryable: + assert ErrorCategory.is_retryable(code) is False, 
f"{code} should not be retryable" + + +# ============================================================================= +# ERROR_MESSAGES 测试 +# ============================================================================== + + +class TestErrorMessages: + """ERROR_MESSAGES 消息映射测试""" + + def test_all_mapped_codes_have_messages(self): + """所有映射的错误码应该有对应消息""" + for code, message in ERROR_MESSAGES.items(): + assert isinstance(code, ErrorCode) + assert isinstance(message, str) + assert len(message) > 0 + + def test_config_messages(self): + """配置错误应有对应消息""" + assert "配置" in ERROR_MESSAGES[ErrorCode.CONFIG_NOT_FOUND] + assert "配置" in ERROR_MESSAGES[ErrorCode.CONFIG_INVALID] + + def test_llm_messages(self): + """LLM 错误应有对应消息""" + assert "LLM" in ERROR_MESSAGES[ErrorCode.LLM_API_ERROR] + assert "超时" in ERROR_MESSAGES[ErrorCode.LLM_TIMEOUT] + + def test_tool_messages(self): + """工具错误应有对应消息""" + assert "工具" in ERROR_MESSAGES[ErrorCode.TOOL_NOT_FOUND] + assert "工具" in ERROR_MESSAGES[ErrorCode.TOOL_EXECUTION_FAILED] + + def test_security_messages(self): + """安全错误应有对应消息""" + assert "安全" in ERROR_MESSAGES[ErrorCode.SECURITY_DENIED] + + def test_network_messages(self): + """网络错误应有对应消息""" + assert "网络" in ERROR_MESSAGES[ErrorCode.NETWORK_ERROR] + + def test_task_messages(self): + """任务错误应有对应消息""" + assert "任务" in ERROR_MESSAGES[ErrorCode.TASK_NOT_FOUND] + assert "任务" in ERROR_MESSAGES[ErrorCode.TASK_FAILED] + + def test_internal_messages(self): + """内部错误应有对应消息""" + assert "内部" in ERROR_MESSAGES[ErrorCode.INTERNAL_ERROR] + + +# ============================================================================= +# get_error_message 测试 +# ============================================================================= + + +class TestGetErrorMessage: + """get_error_message 函数测试""" + + def test_known_code(self): + """已知错误码应返回对应消息""" + msg = get_error_message(ErrorCode.CONFIG_NOT_FOUND) + assert msg == "配置文件未找到" + + def test_llm_error_message(self): + """LLM 错误应返回 LLM 相关消息""" + msg = 
get_error_message(ErrorCode.LLM_API_ERROR) + assert msg == "LLM API 调用失败" + + def test_tool_error_message(self): + """工具错误应返回工具相关消息""" + msg = get_error_message(ErrorCode.TOOL_NOT_FOUND) + assert msg == "工具未找到" + + def test_unknown_code_returns_default(self): + """未映射的错误码应返回默认消息""" + # ErrorCode.SECURITY_PATH_FORBIDDEN 没有在 ERROR_MESSAGES 中 + msg = get_error_message(ErrorCode.SECURITY_PATH_FORBIDDEN) + assert msg == "未知错误" + + def test_all_mapped_codes_return_non_default(self): + """所有映射的错误码不应返回默认消息""" + for code in ERROR_MESSAGES: + msg = get_error_message(code) + assert msg != "未知错误", f"{code} returned default message" + + +# ============================================================================= +# is_retryable_error 测试 +# ============================================================================= + + +class TestIsRetryableError: + """is_retryable_error 函数测试""" + + def test_retryable_llm_errors(self): + """LLM API 错误、超时、限流应可重试""" + assert is_retryable_error(ErrorCode.LLM_API_ERROR) is True + assert is_retryable_error(ErrorCode.LLM_TIMEOUT) is True + assert is_retryable_error(ErrorCode.LLM_RATE_LIMIT) is True + + def test_retryable_network_errors(self): + """网络错误应可重试""" + assert is_retryable_error(ErrorCode.NETWORK_ERROR) is True + assert is_retryable_error(ErrorCode.NETWORK_TIMEOUT) is True + + def test_retryable_task_errors(self): + """任务失败应可重试""" + assert is_retryable_error(ErrorCode.TASK_FAILED) is True + + def test_non_retryable_config_errors(self): + """配置错误不应重试""" + assert is_retryable_error(ErrorCode.CONFIG_NOT_FOUND) is False + assert is_retryable_error(ErrorCode.CONFIG_INVALID) is False + + def test_non_retryable_security_errors(self): + """安全错误不应重试""" + assert is_retryable_error(ErrorCode.SECURITY_DENIED) is False + assert is_retryable_error(ErrorCode.SECURITY_RISK_HIGH) is False + + def test_non_retryable_internal_errors(self): + """内部错误不应重试""" + assert is_retryable_error(ErrorCode.INTERNAL_ERROR) is False + assert 
is_retryable_error(ErrorCode.NOT_IMPLEMENTED) is False + + +# ============================================================================= +# __all__ 导出测试 +# ============================================================================= + + +class TestExports: + """模块导出测试""" + + def test_all_exports_present(self): + """__all__ 中的所有名称都应该可以从模块导入""" + from jojo_code.core.error_code import __all__ + + expected = [ + "ErrorCode", + "ErrorContext", + "ErrorCategory", + "ERROR_MESSAGES", + "get_error_message", + "is_retryable_error", + ] + for name in expected: + assert name in __all__, f"{name} missing from __all__" + + def test_all_exports_importable(self): + """所有导出的名称都应该可以导入""" + import jojo_code.core.error_code as module + + for name in module.__all__: + assert hasattr(module, name), f"{name} not importable" diff --git a/tests/test_core/test_monitoring.py b/tests/test_core/test_monitoring.py new file mode 100644 index 0000000..ed06995 --- /dev/null +++ b/tests/test_core/test_monitoring.py @@ -0,0 +1,864 @@ +"""Core 监控和指标模块测试 + +测试 MetricType, Metric, SystemMetrics, AgentMetrics, +MetricsCollector, SystemMonitor, AgentMonitor, AlertManager, +AlertRule, AlertRules, get_metrics_collector, get_system_monitor, +get_alert_manager。 +""" + +import asyncio +from datetime import datetime, timedelta +from unittest.mock import MagicMock, patch + +import pytest + +from jojo_code.core.monitoring import ( + AgentMetrics, + AgentMonitor, + AlertManager, + AlertRule, + AlertRules, + Metric, + MetricsCollector, + MetricType, + SystemMetrics, + SystemMonitor, + get_alert_manager, + get_metrics_collector, + get_system_monitor, +) + +# --- Enums and data classes --- + + +class TestMetricType: + def test_values(self): + assert MetricType.COUNTER.value == "counter" + assert MetricType.GAUGE.value == "gauge" + assert MetricType.HISTOGRAM.value == "histogram" + assert MetricType.TIMER.value == "timer" + + +class TestMetric: + def test_default_values(self): + m = Metric(name="test", 
class TestAgentMetrics:
    """Checks the field defaults of ``AgentMetrics``."""

    def test_defaults(self):
        """Only ``agent_id`` is supplied; every other field keeps its default."""
        metrics = AgentMetrics(agent_id="agent-1")
        # Insertion order mirrors the order in which the fields are checked.
        expected = {
            "agent_id": "agent-1",
            "total_conversations": 0,
            "total_messages": 0,
            "total_tokens": 0,
            "avg_response_time_ms": 0.0,
            "error_count": 0,
            "success_rate": 1.0,
            "uptime_seconds": 0.0,
        }
        for field_name, default in expected.items():
            assert getattr(metrics, field_name) == default
@pytest.mark.asyncio
async def test_increment(self, collector):
    """Two increments yield two COUNTER samples holding the deltas 1.0 and 5.0."""
    metric_name = "counter"
    await collector.increment(metric_name)
    await collector.increment(metric_name, delta=5.0)

    samples = await collector.get_metrics(metric_name)
    assert len(samples) == 2

    first = samples[0]
    second = samples[1]
    assert first.value == 1.0
    assert second.value == 5.0
    assert first.metric_type == MetricType.COUNTER
@pytest.mark.asyncio
async def test_get_percentile(self, collector):
    """The 50th percentile of the samples 1..100 lies in the middle of the data.

    The previous bound ``1 <= p50 <= 101`` accepted every recorded value
    (and one that was never recorded), so it could not detect a wrong
    percentile computation.
    """
    for i in range(1, 101):
        await collector.record("pct", float(i))

    p50 = await collector.get_percentile("pct", 50)
    assert p50 is not None
    # Nearest-rank, index-based and interpolating percentile definitions all
    # place the median of 1..100 between 50 and 51 inclusive.
    assert 50 <= p50 <= 51
class TestSystemMonitor:
    """Tests for ``SystemMonitor``: lifecycle, callback registry and sampling.

    Every test that starts the monitor stops it in a ``finally`` clause so a
    failing assertion cannot leave the background task running into later
    tests.
    """

    @pytest.fixture
    def monitor(self):
        """Provide a monitor configured with a 0.1 s sampling interval."""
        return SystemMonitor(interval_seconds=0.1)

    @staticmethod
    def _stub_psutil(
        mock_psutil,
        *,
        cpu=10.0,
        memory=20.0,
        disk=30.0,
        memory_used=1024 * 1024,
        memory_available=1024 * 1024 * 4,
        bytes_sent=1024,
        bytes_recv=2048,
    ):
        """Configure the patched ``psutil`` module to return fixed readings."""
        mock_psutil.cpu_percent.return_value = cpu
        mock_psutil.virtual_memory.return_value = MagicMock(
            percent=memory, used=memory_used, available=memory_available
        )
        mock_psutil.disk_usage.return_value = MagicMock(percent=disk)
        mock_psutil.net_io_counters.return_value = MagicMock(
            bytes_sent=bytes_sent, bytes_recv=bytes_recv
        )

    @staticmethod
    async def _run_briefly(monitor, seconds=0.3):
        """Start the monitor, let it run for *seconds*, and always stop it."""
        await monitor.start()
        try:
            # 0.3 s spans a few iterations of the 0.1 s sampling interval.
            await asyncio.sleep(seconds)
        finally:
            await monitor.stop()

    def test_initial_state(self, monitor):
        """A fresh monitor is idle and reports zeroed readings."""
        assert monitor._running is False
        assert monitor._task is None
        assert monitor.cpu_percent == 0.0
        assert monitor.memory_percent == 0.0
        assert monitor.disk_percent == 0.0

    @pytest.mark.asyncio
    async def test_start_and_stop(self, monitor):
        """``start`` creates the task and ``stop`` clears the running flag."""
        await monitor.start()
        try:
            assert monitor._running is True
            assert monitor._task is not None
        finally:
            await monitor.stop()
        assert monitor._running is False

    @pytest.mark.asyncio
    async def test_start_idempotent(self, monitor):
        """Calling ``start`` twice must not create a second task."""
        await monitor.start()
        try:
            first_task = monitor._task
            await monitor.start()
            assert monitor._task is first_task
        finally:
            await monitor.stop()

    @pytest.mark.asyncio
    async def test_add_and_remove_callback(self, monitor):
        """Callbacks can be registered and unregistered."""
        cb = MagicMock()
        monitor.add_callback(cb)
        assert cb in monitor._callbacks
        monitor.remove_callback(cb)
        assert cb not in monitor._callbacks

    @pytest.mark.asyncio
    async def test_remove_callback_not_present(self, monitor):
        """Removing an unknown callback must not raise."""
        cb = MagicMock()
        monitor.remove_callback(cb)

    @patch("jojo_code.core.monitoring.psutil")
    @pytest.mark.asyncio
    async def test_get_current_metrics(self, mock_psutil, monitor):
        """``get_current_metrics`` reports the readings supplied by psutil."""
        self._stub_psutil(
            mock_psutil,
            cpu=25.0,
            memory=50.0,
            disk=75.0,
            memory_used=1024 * 1024 * 512,
            memory_available=1024 * 1024 * 512,
            bytes_sent=1024 * 1024 * 100,
            bytes_recv=1024 * 1024 * 200,
        )

        metrics = await monitor.get_current_metrics()
        assert isinstance(metrics, SystemMetrics)
        assert metrics.cpu_percent == 25.0
        assert metrics.memory_percent == 50.0
        assert metrics.disk_percent == 75.0

    @patch("jojo_code.core.monitoring.psutil")
    @pytest.mark.asyncio
    async def test_monitor_loop_collects_and_updates(self, mock_psutil, monitor):
        """The background loop copies sampled readings onto the monitor."""
        self._stub_psutil(mock_psutil)

        await self._run_briefly(monitor)

        assert monitor.cpu_percent == 10.0
        assert monitor.memory_percent == 20.0
        assert monitor.disk_percent == 30.0

    @patch("jojo_code.core.monitoring.psutil")
    @pytest.mark.asyncio
    async def test_monitor_loop_invokes_callbacks(self, mock_psutil, monitor):
        """Coroutine callbacks receive ``SystemMetrics`` from the loop."""
        self._stub_psutil(mock_psutil)
        received = []

        async def async_callback(metrics):
            received.append(metrics)

        monitor.add_callback(async_callback)
        await self._run_briefly(monitor)

        assert len(received) >= 1
        assert all(isinstance(m, SystemMetrics) for m in received)

    @patch("jojo_code.core.monitoring.psutil")
    @pytest.mark.asyncio
    async def test_monitor_loop_sync_callback(self, mock_psutil, monitor):
        """Plain (non-coroutine) callbacks are invoked by the loop as well."""
        self._stub_psutil(mock_psutil)
        received = []

        def sync_callback(metrics):
            received.append(metrics)

        monitor.add_callback(sync_callback)
        await self._run_briefly(monitor)

        assert len(received) >= 1
@pytest.mark.asyncio
async def test_reset(self, monitor):
    """``reset`` returns every recorded statistic to its initial value.

    A conversation, a message, an error and a response time are recorded
    first, so each assertion below would fail if ``reset`` skipped that
    statistic.
    """
    await monitor.record_conversation()
    await monitor.record_message(tokens=100)
    await monitor.record_error()
    await monitor.record_response_time(500.0)

    await monitor.reset()
    m = await monitor.get_metrics()
    assert m.total_conversations == 0
    assert m.total_messages == 0
    assert m.total_tokens == 0
    assert m.error_count == 0
    assert m.success_rate == 1.0
    # The 500 ms sample recorded above was previously never checked after
    # the reset. NOTE(review): assumes an emptied monitor reports 0.0 here,
    # matching the AgentMetrics default - confirm against AgentMonitor.
    assert m.avg_response_time_ms == 0.0
class TestAlertManager:
    """Tests for ``AlertManager``: rule registry, evaluation and alert storage.

    The collector, the system snapshot and the rule construction are shared
    through fixtures and a small factory so that each test shows only the
    inputs it actually varies.
    """

    @pytest.fixture
    def manager(self):
        """Provide a fresh ``AlertManager``."""
        return AlertManager()

    @pytest.fixture
    def collector(self):
        """Provide an empty ``MetricsCollector`` to pass to ``check``."""
        return MetricsCollector()

    @pytest.fixture
    def sys_metrics(self):
        """Provide a fixed system snapshot; the test rules ignore its values."""
        return SystemMetrics(
            cpu_percent=50.0,
            memory_percent=50.0,
            memory_used_mb=100.0,
            memory_available_mb=100.0,
            disk_percent=50.0,
            network_sent_mb=10.0,
            network_recv_mb=10.0,
        )

    @staticmethod
    def _rule(name, *, severity="info", fires=False, message="msg"):
        """Build an ``AlertRule`` whose condition always returns *fires*."""
        return AlertRule(
            name=name,
            message=message,
            severity=severity,
            condition=lambda m, s: fires,
        )

    @pytest.mark.asyncio
    async def test_add_rule(self, manager):
        """A registered rule is stored under its name."""
        manager.add_rule(self._rule("r1"))
        assert "r1" in manager.rules

    @pytest.mark.asyncio
    async def test_remove_rule(self, manager):
        """A removed rule no longer appears in the registry."""
        manager.add_rule(self._rule("r1"))
        manager.remove_rule("r1")
        assert "r1" not in manager.rules

    @pytest.mark.asyncio
    async def test_remove_nonexistent_rule(self, manager):
        """Removing an unknown rule must not raise."""
        manager.remove_rule("nope")

    @pytest.mark.asyncio
    async def test_check_triggers_alert(self, manager, collector, sys_metrics):
        """A rule whose condition is true is reported by ``check``."""
        manager.add_rule(
            self._rule("always_fire", severity="critical", fires=True, message="Fired!")
        )

        triggered = await manager.check(collector, sys_metrics)
        assert len(triggered) == 1
        assert triggered[0]["rule"] == "always_fire"
        assert triggered[0]["severity"] == "critical"

    @pytest.mark.asyncio
    async def test_check_no_trigger(self, manager, collector, sys_metrics):
        """A rule whose condition is false produces no alert."""
        manager.add_rule(self._rule("never_fire", message="Nope"))

        triggered = await manager.check(collector, sys_metrics)
        assert len(triggered) == 0

    @pytest.mark.asyncio
    async def test_check_stores_alerts(self, manager, collector, sys_metrics):
        """Alerts raised by ``check`` are retrievable through ``get_alerts``."""
        manager.add_rule(
            self._rule("r1", severity="warning", fires=True, message="Fired")
        )

        await manager.check(collector, sys_metrics)
        alerts = await manager.get_alerts()
        assert len(alerts) == 1
        assert alerts[0]["rule"] == "r1"

    @pytest.mark.asyncio
    async def test_get_alerts_filtered_by_severity(self, manager, collector, sys_metrics):
        """``get_alerts(severity=...)`` returns only alerts of that severity."""
        manager.add_rule(self._rule("w", severity="warning", fires=True, message="warn"))
        manager.add_rule(self._rule("c", severity="critical", fires=True, message="crit"))

        await manager.check(collector, sys_metrics)

        warnings = await manager.get_alerts(severity="warning")
        assert len(warnings) == 1
        assert warnings[0]["severity"] == "warning"

        criticals = await manager.get_alerts(severity="critical")
        assert len(criticals) == 1
        # Mirror the warning check so a mislabelled alert cannot slip through.
        assert criticals[0]["severity"] == "critical"

    @pytest.mark.asyncio
    async def test_get_alerts_filtered_by_since(self, manager, collector, sys_metrics):
        """A ``since`` timestamp in the future filters out every stored alert."""
        manager.add_rule(self._rule("r1", fires=True))

        await manager.check(collector, sys_metrics)

        future = datetime.now() + timedelta(hours=1)
        alerts = await manager.get_alerts(since=future)
        assert len(alerts) == 0

    @pytest.mark.asyncio
    async def test_clear_alerts(self, manager, collector, sys_metrics):
        """``clear_alerts`` empties the stored alert list."""
        manager.add_rule(self._rule("r1", fires=True))

        await manager.check(collector, sys_metrics)
        await manager.clear_alerts()
        alerts = await manager.get_alerts()
        assert alerts == []
@pytest.mark.asyncio
async def test_error_rate_high_no_data(self):
    """The error-rate rule does not fire when the collector holds no data."""
    rule = AlertRules.error_rate_high(threshold=0.1)
    assert rule.name == "error_rate_high"
    assert rule.severity == "critical"

    empty_collector = MetricsCollector()
    snapshot_fields = {
        "cpu_percent": 50.0,
        "memory_percent": 50.0,
        "memory_used_mb": 100.0,
        "memory_available_mb": 100.0,
        "disk_percent": 50.0,
        "network_sent_mb": 10.0,
        "network_recv_mb": 10.0,
    }
    snapshot = SystemMetrics(**snapshot_fields)

    # Unlike the threshold rules, this condition is awaited.
    fired = await rule.condition(empty_collector, snapshot)
    assert fired is False
class FakeVersionInfo:
    """Minimal stand-in for ``semver.VersionInfo``.

    Only the numeric ``major.minor.patch`` core is modelled. Pre-release and
    build suffixes are accepted by :meth:`parse` but discarded, so
    ``1.0.0-alpha`` compares equal to ``1.0.0`` here, whereas the SemVer
    specification orders a pre-release below its release.

    Comparisons against objects that are not ``FakeVersionInfo`` return
    ``NotImplemented``: ``==`` then evaluates to ``False`` and ordering
    operators raise ``TypeError`` instead of failing with ``AttributeError``.
    """

    def __init__(self, major=0, minor=0, patch=0):
        # ``patch`` mirrors the semver field name; it only shadows
        # ``unittest.mock.patch`` inside this method.
        self.major = major
        self.minor = minor
        self.patch = patch

    @classmethod
    def parse(cls, version_str: str):
        """Build a version from a string such as ``"1.2.3"``.

        Missing components default to 0 (``"1.2"`` -> 1.2.0), components
        beyond the third are ignored, and anything from the first ``-`` or
        ``+`` onward is dropped.

        Raises:
            ValueError: if a numeric component is not an integer.
        """
        core = version_str.split("+", 1)[0].split("-", 1)[0]
        numbers = [int(part) for part in core.split(".")[:3]]
        numbers.extend([0] * (3 - len(numbers)))
        return cls(*numbers)

    def _key(self):
        """Return the ``(major, minor, patch)`` tuple used for comparisons."""
        return (self.major, self.minor, self.patch)

    def __repr__(self):
        return (
            f"{type(self).__name__}"
            f"(major={self.major}, minor={self.minor}, patch={self.patch})"
        )

    def __hash__(self):
        # Defining __eq__ alone would set __hash__ to None; keep instances
        # usable as set members and dict keys.
        return hash(self._key())

    def __eq__(self, other):
        if not isinstance(other, FakeVersionInfo):
            return NotImplemented
        return self._key() == other._key()

    def __lt__(self, other):
        if not isinstance(other, FakeVersionInfo):
            return NotImplemented
        return self._key() < other._key()

    def __le__(self, other):
        if not isinstance(other, FakeVersionInfo):
            return NotImplemented
        return self._key() <= other._key()

    def __gt__(self, other):
        if not isinstance(other, FakeVersionInfo):
            return NotImplemented
        return self._key() > other._key()

    def __ge__(self, other):
        if not isinstance(other, FakeVersionInfo):
            return NotImplemented
        return self._key() >= other._key()
class TestPlugin:
    """Tests for the ``Plugin`` data class."""

    @staticmethod
    def _metadata():
        """Return the minimal metadata shared by the tests in this class."""
        return PluginMetadata(
            name="test", version="1.0.0", description="test", author="test"
        )

    def test_create_plugin(self):
        """A plugin keeps its metadata and module and starts disabled."""
        meta = self._metadata()
        fake_module = MagicMock()

        created = Plugin(metadata=meta, module=fake_module)

        assert created.metadata is meta
        assert created.module is fake_module
        assert created.enabled is False
        assert created.loaded_at is not None

    def test_loaded_at_auto_set(self):
        """``loaded_at`` is filled in automatically with an ISO-format string."""
        created = Plugin(metadata=self._metadata(), module=MagicMock())
        stamp = created.loaded_at
        assert stamp is not None
        # fromisoformat raises ValueError if the stamp is not valid ISO 8601.
        datetime.fromisoformat(stamp)

    def test_custom_loaded_at(self):
        """An explicitly supplied ``loaded_at`` is stored unchanged."""
        custom_time = "2026-01-01T00:00:00"
        created = Plugin(
            metadata=self._metadata(), module=MagicMock(), loaded_at=custom_time
        )
        assert created.loaded_at == custom_time
def test_init_with_default_dir(self, tmp_path):
    """Without arguments the plugin directory is ``~/.jojo-code/plugins``.

    ``Path.home`` is patched to ``tmp_path`` so the default is resolved
    inside the temporary directory. The earlier check, ``"plugins" in
    str(...)``, passed for any path containing that substring and did not
    verify the documented location.
    """
    with patch("jojo_code.core.plugin.Path.home", return_value=tmp_path):
        manager = PluginManager()
    # NOTE(review): assumes the default is built from Path.home() as the
    # patch target implies - confirm against PluginManager.__init__.
    assert Path(manager.plugins_dir) == tmp_path / ".jojo-code" / "plugins"
@pytest.mark.asyncio
async def test_list_plugins_returns_all(self, tmp_path):
    """``list_plugins`` returns every registered plugin."""
    manager = PluginManager(plugins_dir=tmp_path / "plugins")

    plugin_names = [f"plugin-{i}" for i in range(3)]
    for plugin_name in plugin_names:
        metadata = PluginMetadata(
            name=plugin_name, version="1.0.0", description="test", author="test"
        )
        manager.plugins[plugin_name] = Plugin(metadata=metadata, module=MagicMock())

    listed = manager.list_plugins()
    assert len(listed) == 3
@pytest.mark.asyncio
async def test_enable_plugin_calls_on_enable(self, tmp_path):
    """Enabling a plugin awaits its ``on_enable`` hook and marks it enabled.

    ``on_enable`` is an ``AsyncMock``. ``assert_called_once`` would also
    pass if the manager called the hook without awaiting the coroutine, so
    the await itself is asserted.
    """
    manager = PluginManager(plugins_dir=tmp_path / "plugins")
    meta = PluginMetadata(
        name="test", version="1.0.0", description="test", author="test"
    )
    module = MagicMock()
    module.on_enable = AsyncMock()
    plugin = Plugin(metadata=meta, module=module, enabled=False)
    manager.plugins["test"] = plugin

    await manager.enable_plugin("test")

    module.on_enable.assert_awaited_once()
    # The hook must not prevent the flag from being set.
    assert plugin.enabled is True
+ module.on_disable.assert_not_called() + + +# ============================================================================= +# 版本兼容性测试 +# ============================================================================= + + +class TestVersionCompatibility: + """版本兼容性检查测试""" + + def test_compatible_version(self, tmp_path): + """版本在范围内应返回 True""" + manager = PluginManager(plugins_dir=tmp_path / "plugins") + assert manager._check_version_compatibility("0.2.0", "0.1.0", None) is True + + def test_version_below_minimum(self, tmp_path): + """版本低于最小版本应返回 False""" + manager = PluginManager(plugins_dir=tmp_path / "plugins") + assert manager._check_version_compatibility("0.1.0", "0.2.0", None) is False + + def test_version_above_maximum(self, tmp_path): + """版本高于最大版本应返回 False""" + manager = PluginManager(plugins_dir=tmp_path / "plugins") + assert manager._check_version_compatibility("2.0.0", "0.1.0", "1.0.0") is False + + def test_version_within_range(self, tmp_path): + """版本在范围内应返回 True""" + manager = PluginManager(plugins_dir=tmp_path / "plugins") + assert manager._check_version_compatibility("0.5.0", "0.1.0", "1.0.0") is True + + def test_no_max_version(self, tmp_path): + """没有最大版本限制时应只检查最小版本""" + manager = PluginManager(plugins_dir=tmp_path / "plugins") + assert manager._check_version_compatibility("99.0.0", "0.1.0", None) is True + + +# ============================================================================= +# 钩子系统测试 +# ============================================================================= + + +class TestHookSystem: + """钩子系统测试""" + + def test_register_hook(self, tmp_path): + """注册钩子应添加到 hooks 字典""" + manager = PluginManager(plugins_dir=tmp_path / "plugins") + handler = MagicMock() + manager.register_hook("test_hook", handler) + assert "test_hook" in manager.hooks + assert handler in manager.hooks["test_hook"] + + def test_register_multiple_hooks(self, tmp_path): + """同一个钩子名可以注册多个处理器""" + manager = PluginManager(plugins_dir=tmp_path / "plugins") + handler1 = 
MagicMock() + handler2 = MagicMock() + manager.register_hook("test_hook", handler1) + manager.register_hook("test_hook", handler2) + assert len(manager.hooks["test_hook"]) == 2 + + def test_unregister_hook(self, tmp_path): + """注销钩子应从 hooks 字典移除""" + manager = PluginManager(plugins_dir=tmp_path / "plugins") + handler = MagicMock() + manager.register_hook("test_hook", handler) + manager.unregister_hook("test_hook", handler) + assert handler not in manager.hooks["test_hook"] + + @pytest.mark.asyncio + async def test_trigger_sync_hook(self, tmp_path): + """触发同步钩子应调用处理器""" + manager = PluginManager(plugins_dir=tmp_path / "plugins") + handler = MagicMock(return_value="result") + manager.register_hook("test_hook", handler) + + results = await manager.trigger_hook("test_hook", "arg1", key="value") + handler.assert_called_once_with("arg1", key="value") + assert results == ["result"] + + @pytest.mark.asyncio + async def test_trigger_async_hook(self, tmp_path): + """触发异步钩子应调用处理器""" + manager = PluginManager(plugins_dir=tmp_path / "plugins") + handler = AsyncMock(return_value="async_result") + manager.register_hook("test_hook", handler) + + results = await manager.trigger_hook("test_hook", "arg1") + handler.assert_called_once_with("arg1") + assert results == ["async_result"] + + @pytest.mark.asyncio + async def test_trigger_nonexistent_hook(self, tmp_path): + """触发不存在的钩子应返回空列表""" + manager = PluginManager(plugins_dir=tmp_path / "plugins") + results = await manager.trigger_hook("nonexistent") + assert results == [] + + @pytest.mark.asyncio + async def test_trigger_hook_error_handling(self, tmp_path): + """钩子处理器异常不应中断其他处理器""" + manager = PluginManager(plugins_dir=tmp_path / "plugins") + + def bad_handler(): + raise ValueError("boom") + + good_handler = MagicMock(return_value="ok") + manager.register_hook("test_hook", bad_handler) + manager.register_hook("test_hook", good_handler) + + results = await manager.trigger_hook("test_hook") + assert "ok" in results + + +# 
# =============================================================================
# PluginContext tests
# =============================================================================


class TestPluginContext:
    """Tests for PluginContext, the per-plugin facade over PluginManager."""

    def test_init(self, tmp_path):
        """The context stores the plugin name and the manager it was given."""
        manager = PluginManager(plugins_dir=tmp_path / "plugins")
        ctx = PluginContext("my-plugin", manager)
        assert ctx.plugin_name == "my-plugin"
        assert ctx.manager is manager

    def test_register_command(self, tmp_path):
        """register_command registers a hook named ``command:<name>``."""
        manager = PluginManager(plugins_dir=tmp_path / "plugins")
        ctx = PluginContext("my-plugin", manager)
        handler = MagicMock()
        ctx.register_command("my-cmd", handler)
        assert "command:my-cmd" in manager.hooks

    def test_register_tool(self, tmp_path):
        """register_tool registers a hook named ``tool:<name>``."""
        manager = PluginManager(plugins_dir=tmp_path / "plugins")
        ctx = PluginContext("my-plugin", manager)
        tool_class = MagicMock()
        ctx.register_tool("my-tool", tool_class)
        assert "tool:my-tool" in manager.hooks

    def test_get_config_default(self, tmp_path):
        """Without a config file, get_config returns the supplied default."""
        manager = PluginManager(plugins_dir=tmp_path / "plugins")
        ctx = PluginContext("my-plugin", manager)
        result = ctx.get_config("key", default="default_value")
        assert result == "default_value"

    def test_get_config_from_file(self, tmp_path):
        """get_config reads values from ``<plugins_dir>/<plugin>/config.json``."""
        manager = PluginManager(plugins_dir=tmp_path / "plugins")
        config_dir = tmp_path / "plugins" / "my-plugin"
        config_dir.mkdir(parents=True)
        config_file = config_dir / "config.json"
        # Explicit encoding keeps the fixture independent of the host locale.
        config_file.write_text(json.dumps({"key": "value"}), encoding="utf-8")

        ctx = PluginContext("my-plugin", manager)
        result = ctx.get_config("key")
        assert result == "value"

    def test_set_config(self, tmp_path):
        """set_config creates the config file and stores the new key."""
        manager = PluginManager(plugins_dir=tmp_path / "plugins")
        ctx = PluginContext("my-plugin", manager)
        ctx.set_config("new_key", "new_value")

        config_file = tmp_path / "plugins" / "my-plugin" / "config.json"
        assert config_file.exists()
        config = json.loads(config_file.read_text(encoding="utf-8"))
        assert config["new_key"] == "new_value"

    def test_set_config_preserves_existing(self, tmp_path):
        """set_config keeps keys that were already present in the file."""
        manager = PluginManager(plugins_dir=tmp_path / "plugins")
        config_dir = tmp_path / "plugins" / "my-plugin"
        config_dir.mkdir(parents=True)
        config_file = config_dir / "config.json"
        config_file.write_text(json.dumps({"existing": "val"}), encoding="utf-8")

        ctx = PluginContext("my-plugin", manager)
        ctx.set_config("new_key", "new_value")

        config = json.loads(config_file.read_text(encoding="utf-8"))
        assert config["existing"] == "val"
        assert config["new_key"] == "new_value"


# =============================================================================
# create_plugin_template tests
# =============================================================================


class TestCreatePluginTemplate:
    """Tests for the create_plugin_template function."""

    # The previous body patched ``pathlib.Path.__new__`` for the whole process
    # (with a side effect that called the patched attribute again) and then
    # asserted nothing, so it always passed. Skipping makes the missing
    # coverage visible instead of reporting a vacuous pass.
    # TODO(review): create_plugin_template reportedly writes to a relative
    # path; implement this with a temporary working directory once its
    # signature and output location are confirmed.
    @pytest.mark.skip(
        reason="not implemented: needs a tmp-dir based check of the generated files"
    )
    def test_creates_template_files(self, tmp_path):
        """Should create plugin.py, config.json and README.md (not yet verified)."""

    def test_template_contains_metadata(self):
        """create_plugin_template is importable and callable.

        NOTE(review): the test name suggests checking for PLUGIN_METADATA in
        the generated code, but only callability is asserted here.
        """
        assert callable(create_plugin_template)


# =============================================================================
# get_plugin_manager tests
# =============================================================================


class TestGetPluginManager:
    """Tests for the module-level get_plugin_manager accessor."""

    def test_returns_plugin_manager(self, monkeypatch):
        """A fresh call returns a PluginManager instance."""
        import jojo_code.core.plugin as module

        # monkeypatch restores the previous singleton on teardown, so the
        # reset does not leak into other tests.
        monkeypatch.setattr(module, "_plugin_manager", None, raising=False)
        manager = get_plugin_manager()
        assert isinstance(manager, PluginManager)

    def test_returns_singleton(self, monkeypatch):
        """Repeated calls return the same instance."""
        import jojo_code.core.plugin as module

        monkeypatch.setattr(module, "_plugin_manager", None, raising=False)
        manager1 = get_plugin_manager()
        manager2 = get_plugin_manager()
        assert manager1 is manager2


# =============================================================================
# Integration tests
# =============================================================================


class TestPluginIntegration:
    """Integration tests for the plugin system."""

    @pytest.mark.asyncio
    async def test_full_lifecycle(self, tmp_path):
        """Walk a registered plugin through enable -> disable -> unload."""
        manager = PluginManager(plugins_dir=tmp_path / "plugins")
        meta = PluginMetadata(
            name="lifecycle-test", version="1.0.0", description="test", author="test"
        )
        module = MagicMock()
        module.on_enable = AsyncMock()
        module.on_disable = AsyncMock()
        module.on_unload = AsyncMock()
        plugin = Plugin(metadata=meta, module=module)
        # Register directly instead of loading from disk.
        manager.plugins["lifecycle-test"] = plugin

        # Initial state
        assert plugin.enabled is False

        # Enable
        await manager.enable_plugin("lifecycle-test")
        assert plugin.enabled is True
        module.on_enable.assert_called_once()

        # Disable
        await manager.disable_plugin("lifecycle-test")
        assert plugin.enabled is False
        module.on_disable.assert_called_once()

        # Unload
        await manager.unload_plugin("lifecycle-test")
        assert "lifecycle-test" not in manager.plugins

    @pytest.mark.asyncio
    async def test_hook_integration(self, tmp_path):
        """Hooks registered on the manager fire in trigger order."""
        manager = PluginManager(plugins_dir=tmp_path / "plugins")

        call_log = []

        def on_before_tool(tool_name, args):
            call_log.append(f"before:{tool_name}")

        def on_after_tool(tool_name, result):
            call_log.append(f"after:{tool_name}")

        manager.register_hook("before_tool_call", on_before_tool)
        manager.register_hook("after_tool_call", on_after_tool)

        await manager.trigger_hook("before_tool_call", "read_file", {})
        await manager.trigger_hook("after_tool_call", "read_file", "content")

        assert call_log == ["before:read_file", "after:read_file"]
0000000..3ff9988 --- /dev/null +++ b/tests/test_core/test_webhook.py @@ -0,0 +1,638 @@ +"""Core Webhook 模块测试 + +测试 WebhookEventType, WebhookEvent, WebhookConfig, +WebhookSignatureError, WebhookDeliveryError, WebhookManager, +WebhookBuilder, get_webhook_manager, emit_event。 +""" + +import hashlib +import hmac +import json +from datetime import datetime +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from jojo_code.core.webhook import ( + WebhookBuilder, + WebhookConfig, + WebhookDeliveryError, + WebhookEvent, + WebhookEventType, + WebhookManager, + WebhookSignatureError, + emit_event, + get_webhook_manager, +) + +# --- Enums --- + + +class TestWebhookEventType: + def test_all_event_types_exist(self): + expected = [ + "MESSAGE_RECEIVED", + "MESSAGE_SENT", + "AGENT_STARTED", + "AGENT_STOPPED", + "AGENT_ERROR", + "TOOL_CALLED", + "TOOL_RESULT", + "CONVERSATION_STARTED", + "CONVERSATION_ENDED", + "USER_JOINED", + "USER_LEFT", + ] + for name in expected: + assert hasattr(WebhookEventType, name) + + def test_values_are_strings(self): + for event_type in WebhookEventType: + assert isinstance(event_type.value, str) + + +# --- Data classes --- + + +class TestWebhookEvent: + def test_construction(self): + event = WebhookEvent( + id="evt-1", + type=WebhookEventType.MESSAGE_RECEIVED, + data={"text": "hello"}, + ) + assert event.id == "evt-1" + assert event.type == WebhookEventType.MESSAGE_RECEIVED + assert event.data == {"text": "hello"} + assert isinstance(event.timestamp, datetime) + assert event.source == "jojo-code" + assert event.version == "1.0.0" + + def test_custom_source_and_version(self): + event = WebhookEvent( + id="evt-2", + type=WebhookEventType.AGENT_STARTED, + data={}, + source="custom-source", + version="2.0.0", + ) + assert event.source == "custom-source" + assert event.version == "2.0.0" + + +class TestWebhookConfig: + def test_defaults(self): + config = WebhookConfig(url="https://example.com/hook") + assert config.url == 
"https://example.com/hook" + assert config.secret is None + assert config.timeout == 30 + assert config.retry_count == 3 + assert config.retry_delay == 1.0 + assert config.enabled is True + assert config.events == [] + + def test_custom_values(self): + config = WebhookConfig( + url="https://example.com/hook", + secret="s3cret", + timeout=60, + retry_count=5, + retry_delay=2.0, + enabled=False, + events=[WebhookEventType.AGENT_ERROR], + ) + assert config.secret == "s3cret" + assert config.timeout == 60 + assert config.retry_count == 5 + assert config.retry_delay == 2.0 + assert config.enabled is False + assert config.events == [WebhookEventType.AGENT_ERROR] + + +# --- Exceptions --- + + +class TestExceptions: + def test_webhook_signature_error(self): + err = WebhookSignatureError("bad sig") + assert str(err) == "bad sig" + assert issubclass(WebhookSignatureError, Exception) + + def test_webhook_delivery_error(self): + err = WebhookDeliveryError("delivery failed") + assert str(err) == "delivery failed" + assert issubclass(WebhookDeliveryError, Exception) + + +# --- WebhookManager --- + + +class TestWebhookManager: + @pytest.fixture + def manager(self): + return WebhookManager() + + def test_register_webhook(self, manager): + config = WebhookConfig(url="https://example.com/hook") + manager.register_webhook("test", config) + assert "test" in manager.webhooks + assert manager.webhooks["test"] is config + + def test_register_webhook_initializes_event_handlers(self): + manager = WebhookManager() + config = WebhookConfig( + url="https://example.com/hook", + events=[WebhookEventType.MESSAGE_RECEIVED, WebhookEventType.AGENT_ERROR], + ) + manager.register_webhook("test", config) + assert WebhookEventType.MESSAGE_RECEIVED in manager.event_handlers + assert WebhookEventType.AGENT_ERROR in manager.event_handlers + + def test_unregister_webhook(self, manager): + config = WebhookConfig(url="https://example.com/hook") + manager.register_webhook("test", config) + 
manager.unregister_webhook("test") + assert "test" not in manager.webhooks + + def test_unregister_nonexistent(self, manager): + manager.unregister_webhook("nope") # Should not raise + + def test_get_webhook(self, manager): + config = WebhookConfig(url="https://example.com/hook") + manager.register_webhook("test", config) + assert manager.get_webhook("test") is config + + def test_get_webhook_not_found(self, manager): + assert manager.get_webhook("nope") is None + + def test_list_webhooks(self, manager): + config = WebhookConfig(url="https://example.com/hook") + manager.register_webhook("a", config) + manager.register_webhook("b", config) + names = manager.list_webhooks() + assert sorted(names) == ["a", "b"] + + def test_list_webhooks_empty(self, manager): + assert manager.list_webhooks() == [] + + def test_on_registers_handler(self, manager): + handler = MagicMock() + manager.on(WebhookEventType.MESSAGE_RECEIVED, handler) + assert handler in manager.event_handlers[WebhookEventType.MESSAGE_RECEIVED] + + def test_on_multiple_handlers(self, manager): + h1 = MagicMock() + h2 = MagicMock() + manager.on(WebhookEventType.TOOL_CALLED, h1) + manager.on(WebhookEventType.TOOL_CALLED, h2) + assert len(manager.event_handlers[WebhookEventType.TOOL_CALLED]) == 2 + + def test_off_removes_handler(self, manager): + handler = MagicMock() + manager.on(WebhookEventType.MESSAGE_SENT, handler) + manager.off(WebhookEventType.MESSAGE_SENT, handler) + assert handler not in manager.event_handlers.get(WebhookEventType.MESSAGE_SENT, []) + + def test_off_handler_not_present(self, manager): + handler = MagicMock() + manager.off(WebhookEventType.AGENT_STARTED, handler) # Should not raise + + @pytest.mark.asyncio + async def test_trigger_calls_local_handlers(self, manager): + received = [] + + async def handler(event): + received.append(event) + + manager.on(WebhookEventType.AGENT_STARTED, handler) + + event = WebhookEvent(id="1", type=WebhookEventType.AGENT_STARTED, data={"msg": "hi"}) + results 
= await manager.trigger(event) + + assert len(received) == 1 + assert received[0].id == "1" + # No webhooks registered, only the local handler result + assert len(results) == 1 + assert results[0]["webhook"] == "local" + assert results[0]["success"] is True # handler succeeded + + @pytest.mark.asyncio + async def test_trigger_local_handler_exception(self, manager): + async def bad_handler(event): + raise ValueError("oops") + + manager.on(WebhookEventType.AGENT_ERROR, bad_handler) + + event = WebhookEvent(id="1", type=WebhookEventType.AGENT_ERROR, data={}) + results = await manager.trigger(event) + assert len(results) == 1 + assert results[0]["webhook"] == "local" + assert results[0]["success"] is False + assert "oops" in results[0]["error"] + + @pytest.mark.asyncio + async def test_trigger_skips_disabled_webhook(self, manager): + config = WebhookConfig(url="https://example.com/hook", enabled=False) + manager.register_webhook("disabled", config) + + event = WebhookEvent(id="1", type=WebhookEventType.MESSAGE_RECEIVED, data={}) + + with patch.object(manager, "_deliver", new_callable=AsyncMock) as mock_deliver: + await manager.trigger(event) + mock_deliver.assert_not_called() + + @pytest.mark.asyncio + async def test_trigger_skips_non_matching_event_type(self, manager): + config = WebhookConfig( + url="https://example.com/hook", + events=[WebhookEventType.AGENT_ERROR], + ) + manager.register_webhook("filter", config) + + event = WebhookEvent(id="1", type=WebhookEventType.MESSAGE_RECEIVED, data={}) + + with patch.object(manager, "_deliver", new_callable=AsyncMock) as mock_deliver: + await manager.trigger(event) + mock_deliver.assert_not_called() + + @pytest.mark.asyncio + async def test_trigger_delivers_to_matching_webhook(self, manager): + config = WebhookConfig( + url="https://example.com/hook", + events=[WebhookEventType.MESSAGE_RECEIVED], + ) + manager.register_webhook("target", config) + + event = WebhookEvent(id="1", type=WebhookEventType.MESSAGE_RECEIVED, 
data={"text": "hi"}) + + expected_result = {"webhook": "target", "success": True, "status": 200, "attempt": 1} + with patch.object( + manager, "_deliver", new_callable=AsyncMock, return_value=expected_result + ) as mock_deliver: + results = await manager.trigger(event) + mock_deliver.assert_called_once_with("target", config, event) + assert results == [expected_result] + + def test_build_payload(self, manager): + now = datetime(2024, 1, 15, 10, 30, 0) + event = WebhookEvent( + id="evt-1", + type=WebhookEventType.TOOL_CALLED, + data={"tool": "read_file"}, + timestamp=now, + source="test", + version="2.0", + ) + payload = manager._build_payload(event) + assert payload["id"] == "evt-1" + assert payload["type"] == "tool_called" + assert payload["data"] == {"tool": "read_file"} + assert payload["timestamp"] == "2024-01-15T10:30:00" + assert payload["source"] == "test" + assert payload["version"] == "2.0" + + def test_sign(self, manager): + payload = {"key": "value"} + secret = "my-secret" + sig = manager._sign(payload, secret) + + # Verify signature is correct HMAC-SHA256 + data = json.dumps(payload, sort_keys=True) + expected = hmac.new(secret.encode(), data.encode(), hashlib.sha256).hexdigest() + assert sig == expected + + def test_sign_deterministic(self, manager): + payload = {"a": 1, "b": 2} + sig1 = manager._sign(payload, "secret") + sig2 = manager._sign(payload, "secret") + assert sig1 == sig2 + + def test_verify_signature_valid(self, manager): + payload = {"key": "value"} + secret = "my-secret" + sig = manager._sign(payload, secret) + assert manager.verify_signature(payload, sig, secret) is True + + def test_verify_signature_invalid(self, manager): + payload = {"key": "value"} + secret = "my-secret" + assert manager.verify_signature(payload, "badsig", secret) is False + + def test_record_delivery(self, manager): + result = {"webhook": "test", "success": True, "status": 200} + manager._record_delivery(result) + assert len(manager.delivery_history) == 1 + assert 
manager.delivery_history[0]["webhook"] == "test" + assert "timestamp" in manager.delivery_history[0] + + def test_delivery_history_max_1000(self, manager): + for i in range(1005): + manager._record_delivery({"webhook": "test", "success": True, "idx": i}) + assert len(manager.delivery_history) == 1000 + # Latest entries should be kept + assert manager.delivery_history[-1]["idx"] == 1004 + + def test_get_delivery_history(self, manager): + manager._record_delivery({"webhook": "a", "success": True}) + manager._record_delivery({"webhook": "b", "success": False}) + manager._record_delivery({"webhook": "a", "success": True}) + + history = manager.get_delivery_history("a") + assert len(history) == 2 + assert all(h["webhook"] == "a" for h in history) + + def test_get_delivery_history_limit(self, manager): + for i in range(10): + manager._record_delivery({"webhook": "test", "success": True, "idx": i}) + history = manager.get_delivery_history(limit=3) + assert len(history) == 3 + + def test_get_delivery_history_all(self, manager): + manager._record_delivery({"webhook": "a", "success": True}) + manager._record_delivery({"webhook": "b", "success": True}) + history = manager.get_delivery_history() + assert len(history) == 2 + + @pytest.mark.asyncio + @patch("aiohttp.ClientSession") + async def test_deliver_success(self, mock_session_cls, manager): + config = WebhookConfig(url="https://example.com/hook", retry_count=1) + + mock_response = AsyncMock() + mock_response.status = 200 + + mock_context = AsyncMock() + mock_context.__aenter__ = AsyncMock(return_value=mock_response) + mock_context.__aexit__ = AsyncMock(return_value=False) + + mock_session = AsyncMock() + mock_session.post.return_value = mock_context + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + + mock_session_cls.return_value = mock_session + + event = WebhookEvent(id="1", type=WebhookEventType.MESSAGE_RECEIVED, data={}) + result = await 
manager._deliver("test", config, event) + + assert result["success"] is True + assert result["status"] == 200 + assert result["attempt"] == 1 + + @pytest.mark.asyncio + @patch("aiohttp.ClientSession") + async def test_deliver_with_secret_adds_signature(self, mock_session_cls, manager): + config = WebhookConfig(url="https://example.com/hook", secret="mysecret", retry_count=1) + + mock_response = AsyncMock() + mock_response.status = 200 + + mock_context = AsyncMock() + mock_context.__aenter__ = AsyncMock(return_value=mock_response) + mock_context.__aexit__ = AsyncMock(return_value=False) + + mock_session = AsyncMock() + mock_session.post.return_value = mock_context + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + + mock_session_cls.return_value = mock_session + + event = WebhookEvent(id="1", type=WebhookEventType.MESSAGE_RECEIVED, data={}) + result = await manager._deliver("test", config, event) + + assert result["success"] is True + # Check that post was called with a payload containing signature + call_args = mock_session.post.call_args + payload = call_args[1]["json"] + assert "signature" in payload + assert isinstance(payload["signature"], str) + + @pytest.mark.asyncio + @patch("aiohttp.ClientSession") + async def test_deliver_retries_on_failure(self, mock_session_cls, manager): + config = WebhookConfig(url="https://example.com/hook", retry_count=3, retry_delay=0.01) + + mock_response = AsyncMock() + mock_response.status = 500 + + mock_context = AsyncMock() + mock_context.__aenter__ = AsyncMock(return_value=mock_response) + mock_context.__aexit__ = AsyncMock(return_value=False) + + mock_session = AsyncMock() + mock_session.post.return_value = mock_context + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + + mock_session_cls.return_value = mock_session + + event = WebhookEvent(id="1", 
type=WebhookEventType.MESSAGE_RECEIVED, data={}) + result = await manager._deliver("test", config, event) + + assert result["success"] is False + assert result["attempt"] == 3 + assert mock_session.post.call_count == 3 + + @pytest.mark.asyncio + @patch("aiohttp.ClientSession") + async def test_deliver_timeout_error(self, mock_session_cls, manager): + config = WebhookConfig(url="https://example.com/hook", retry_count=2, retry_delay=0.01) + + mock_session = AsyncMock() + mock_session.post.side_effect = TimeoutError("timed out") + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + + mock_session_cls.return_value = mock_session + + event = WebhookEvent(id="1", type=WebhookEventType.MESSAGE_RECEIVED, data={}) + result = await manager._deliver("test", config, event) + + assert result["success"] is False + assert result["error"] == "Timeout" + assert result["attempt"] == 2 + + @pytest.mark.asyncio + @patch("aiohttp.ClientSession") + async def test_deliver_generic_exception(self, mock_session_cls, manager): + config = WebhookConfig(url="https://example.com/hook", retry_count=1, retry_delay=0.01) + + mock_session = AsyncMock() + mock_session.post.side_effect = ConnectionError("refused") + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + + mock_session_cls.return_value = mock_session + + event = WebhookEvent(id="1", type=WebhookEventType.MESSAGE_RECEIVED, data={}) + result = await manager._deliver("test", config, event) + + assert result["success"] is False + assert "refused" in result["error"] + + @pytest.mark.asyncio + @patch("aiohttp.ClientSession") + async def test_deliver_records_to_history(self, mock_session_cls, manager): + config = WebhookConfig(url="https://example.com/hook", retry_count=1) + + mock_response = AsyncMock() + mock_response.status = 200 + + mock_context = AsyncMock() + mock_context.__aenter__ = 
AsyncMock(return_value=mock_response) + mock_context.__aexit__ = AsyncMock(return_value=False) + + mock_session = AsyncMock() + mock_session.post.return_value = mock_context + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + + mock_session_cls.return_value = mock_session + + event = WebhookEvent(id="1", type=WebhookEventType.MESSAGE_RECEIVED, data={}) + await manager._deliver("test", config, event) + + history = manager.get_delivery_history("test") + assert len(history) == 1 + assert history[0]["success"] is True + + +# --- WebhookBuilder --- + + +class TestWebhookBuilder: + def test_build_minimal(self): + config = WebhookBuilder().url("https://example.com/hook").build() + assert config.url == "https://example.com/hook" + assert config.secret is None + assert config.timeout == 30 + assert config.retry_count == 3 + assert config.retry_delay == 1.0 + assert config.enabled is True + assert config.events == [] + + def test_build_full(self): + config = ( + WebhookBuilder() + .url("https://example.com/hook") + .secret("mysecret") + .timeout(60) + .retries(5, 2.0) + .enabled(False) + .events(WebhookEventType.AGENT_ERROR, WebhookEventType.TOOL_CALLED) + .build() + ) + assert config.url == "https://example.com/hook" + assert config.secret == "mysecret" + assert config.timeout == 60 + assert config.retry_count == 5 + assert config.retry_delay == 2.0 + assert config.enabled is False + assert config.events == [ + WebhookEventType.AGENT_ERROR, + WebhookEventType.TOOL_CALLED, + ] + + def test_url_required(self): + with pytest.raises(ValueError, match="URL is required"): + WebhookBuilder().build() + + def test_url_missing_scheme(self): + with pytest.raises(ValueError, match="missing scheme"): + WebhookBuilder().url("example.com/hook") + + def test_url_with_http(self): + config = WebhookBuilder().url("http://example.com/hook").build() + assert config.url == "http://example.com/hook" + + def 
test_url_with_https(self): + config = WebhookBuilder().url("https://example.com/hook").build() + assert config.url == "https://example.com/hook" + + def test_fluent_chaining(self): + builder = WebhookBuilder() + result = builder.url("https://example.com/hook") + assert result is builder + + result = builder.secret("s") + assert result is builder + + result = builder.timeout(10) + assert result is builder + + result = builder.retries(2) + assert result is builder + + result = builder.enabled(True) + assert result is builder + + result = builder.events(WebhookEventType.MESSAGE_RECEIVED) + assert result is builder + + +# --- Global helpers --- + + +class TestGlobalHelpers: + def test_get_webhook_manager_returns_same_instance(self): + import jojo_code.core.webhook as wh_mod + + original = wh_mod._webhook_manager + try: + wh_mod._webhook_manager = None + m1 = get_webhook_manager() + m2 = get_webhook_manager() + assert m1 is m2 + assert isinstance(m1, WebhookManager) + finally: + wh_mod._webhook_manager = original + + @pytest.mark.asyncio + async def test_emit_event_triggers_manager(self): + import jojo_code.core.webhook as wh_mod + + original = wh_mod._webhook_manager + try: + wh_mod._webhook_manager = None + manager = get_webhook_manager() + + received = [] + + async def handler(event): + received.append(event) + + manager.on(WebhookEventType.TOOL_RESULT, handler) + + await emit_event( + WebhookEventType.TOOL_RESULT, {"tool": "read_file", "result": "ok"} + ) + + assert len(received) == 1 + assert received[0].type == WebhookEventType.TOOL_RESULT + assert received[0].data == {"tool": "read_file", "result": "ok"} + assert received[0].source == "jojo-code" + finally: + wh_mod._webhook_manager = original + + @pytest.mark.asyncio + async def test_emit_event_with_custom_source(self): + import jojo_code.core.webhook as wh_mod + + original = wh_mod._webhook_manager + try: + wh_mod._webhook_manager = None + manager = get_webhook_manager() + + received = [] + + async def 
handler(event): + received.append(event) + + manager.on(WebhookEventType.AGENT_STARTED, handler) + + await emit_event(WebhookEventType.AGENT_STARTED, {"id": "a1"}, source="custom") + + assert received[0].source == "custom" + finally: + wh_mod._webhook_manager = original diff --git a/tests/test_mcp/__init__.py b/tests/test_mcp/__init__.py new file mode 100644 index 0000000..cc3010d --- /dev/null +++ b/tests/test_mcp/__init__.py @@ -0,0 +1 @@ +"""MCP client tests.""" diff --git a/tests/test_mcp/test_client.py b/tests/test_mcp/test_client.py new file mode 100644 index 0000000..f5f134e --- /dev/null +++ b/tests/test_mcp/test_client.py @@ -0,0 +1,273 @@ +"""MCP 客户端测试""" + +from unittest.mock import AsyncMock, patch + +import pytest + +from jojo_code.mcp.client import ( + MCPClient, + MCPClientManager, + MCPConfig, + MCPResource, + MCPTool, + add_mcp_server, + get_mcp_manager, +) + + +class TestMCPConfig: + """测试 MCP 配置""" + + def test_default_config(self): + """测试默认配置""" + config = MCPConfig(name="test", url="http://localhost:8080") + assert config.name == "test" + assert config.url == "http://localhost:8080" + assert config.transport == "stdio" + assert config.auth == {} + assert config.timeout == 30.0 + assert config.retry == 3 + + def test_custom_config(self): + """测试自定义配置""" + config = MCPConfig( + name="my-server", + url="http://example.com/mcp", + transport="http", + auth={"token": "abc123"}, + timeout=60.0, + retry=5, + ) + assert config.transport == "http" + assert config.auth == {"token": "abc123"} + assert config.timeout == 60.0 + + +class TestMCPTool: + """测试 MCP 工具""" + + def test_tool_creation(self): + """测试工具创建""" + tool = MCPTool( + name="read_file", + description="读取文件", + input_schema={"type": "object", "properties": {"path": {"type": "string"}}}, + ) + assert tool.name == "read_file" + assert tool.description == "读取文件" + assert "path" in tool.input_schema["properties"] + + def test_tool_defaults(self): + """测试工具默认值""" + tool = MCPTool(name="test") + 
assert tool.description == "" + assert tool.input_schema == {} + + +class TestMCPResource: + """测试 MCP 资源""" + + def test_resource_creation(self): + """测试资源创建""" + resource = MCPResource( + uri="file:///test.txt", + name="test.txt", + description="测试文件", + mime_type="text/plain", + ) + assert resource.uri == "file:///test.txt" + assert resource.name == "test.txt" + assert resource.mime_type == "text/plain" + + def test_resource_defaults(self): + """测试资源默认值""" + resource = MCPResource(uri="file:///test.txt") + assert resource.name == "" + assert resource.description == "" + assert resource.mime_type == "text/plain" + + +class TestMCPClient: + """测试 MCP 客户端""" + + @pytest.fixture + def config(self): + return MCPConfig(name="test-server", url="echo test", transport="stdio") + + @pytest.fixture + def client(self, config): + return MCPClient(config) + + def test_initial_state(self, client): + """测试初始状态""" + assert client.is_connected is False + assert client.list_tools() == [] + assert client.get_tool("nonexistent") is None + + @pytest.mark.asyncio + async def test_connect_unsupported_transport(self): + """测试不支持的传输方式""" + config = MCPConfig(name="test", url="test", transport="unsupported") + client = MCPClient(config) + with pytest.raises(ValueError, match="不支持的传输方式"): + await client.connect() + + @pytest.mark.asyncio + async def test_call_tool_when_disconnected(self, client): + """测试未连接时调用工具""" + with patch.object(client, "connect", new_callable=AsyncMock): + with patch.object( + client, "_send_request", new_callable=AsyncMock, return_value={"result": "ok"} + ): + result = await client.call_tool("test_tool", {"arg": "value"}) + assert result == "ok" + + @pytest.mark.asyncio + async def test_call_tool_error(self, client): + """测试工具调用错误""" + client._connected = True + with patch.object( + client, + "_send_request", + new_callable=AsyncMock, + return_value={"error": "tool not found"}, + ): + from jojo_code.core.exceptions import NetworkError + + with 
pytest.raises(NetworkError, match="MCP 工具调用失败"): + await client.call_tool("nonexistent", {}) + + @pytest.mark.asyncio + async def test_send_stdio_not_connected(self, client): + """测试 stdio 未连接时发送请求""" + with pytest.raises(RuntimeError, match="未连接"): + await client._send_stdio({"method": "test"}) + + @pytest.mark.asyncio + async def test_send_http_not_connected(self, client): + """测试 HTTP 未连接时发送请求""" + with pytest.raises(RuntimeError, match="未连接"): + await client._send_http({"method": "test"}) + + @pytest.mark.asyncio + async def test_discover_tools(self, client): + """测试工具发现""" + with patch.object( + client, + "_send_request", + new_callable=AsyncMock, + return_value={ + "tools": [ + {"name": "tool1", "description": "Tool 1", "inputSchema": {"type": "object"}}, + {"name": "tool2", "description": "Tool 2"}, + ] + }, + ): + await client._discover_tools() + tools = client.list_tools() + assert len(tools) == 2 + assert tools[0].name == "tool1" + assert client.get_tool("tool2") is not None + + @pytest.mark.asyncio + async def test_list_resources(self, client): + """测试列出资源""" + client._connected = True + with patch.object( + client, + "_send_request", + new_callable=AsyncMock, + return_value={ + "resources": [ + {"uri": "file:///a.txt", "name": "a.txt", "mimeType": "text/plain"}, + ] + }, + ): + resources = await client.list_resources() + assert len(resources) == 1 + assert resources[0].uri == "file:///a.txt" + + @pytest.mark.asyncio + async def test_read_resource(self, client): + """测试读取资源""" + client._connected = True + with patch.object( + client, + "_send_request", + new_callable=AsyncMock, + return_value={"contents": [{"text": "hello world"}]}, + ): + content = await client.read_resource("file:///test.txt") + assert content == "hello world" + + @pytest.mark.asyncio + async def test_read_resource_empty(self, client): + """测试读取空资源""" + client._connected = True + with patch.object(client, "_send_request", new_callable=AsyncMock, return_value={}): + content = await 
client.read_resource("file:///empty.txt") + assert content == "" + + @pytest.mark.asyncio + async def test_close(self, client): + """测试关闭连接""" + client._connected = True + await client.close() + assert client.is_connected is False + + +class TestMCPClientManager: + """测试 MCP 客户端管理器""" + + def test_add_server(self): + """测试添加服务器""" + manager = MCPClientManager() + config = MCPConfig(name="server1", url="http://localhost:8080") + client = manager.add_server(config) + assert isinstance(client, MCPClient) + assert "server1" in manager.list_servers() + + def test_get_client(self): + """测试获取客户端""" + manager = MCPClientManager() + config = MCPConfig(name="server1", url="http://localhost:8080") + manager.add_server(config) + assert manager.get_client("server1") is not None + assert manager.get_client("nonexistent") is None + + @pytest.mark.asyncio + async def test_remove_server(self): + """测试移除服务器""" + manager = MCPClientManager() + config = MCPConfig(name="server1", url="http://localhost:8080") + manager.add_server(config) + assert manager.remove_server("server1") is True + assert manager.remove_server("nonexistent") is False + assert "server1" not in manager.list_servers() + + def test_list_servers(self): + """测试列出服务器""" + manager = MCPClientManager() + manager.add_server(MCPConfig(name="a", url="http://a")) + manager.add_server(MCPConfig(name="b", url="http://b")) + servers = manager.list_servers() + assert len(servers) == 2 + assert "a" in servers + assert "b" in servers + + +class TestGlobalFunctions: + """测试全局函数""" + + def test_get_mcp_manager(self): + """测试获取全局管理器""" + manager = get_mcp_manager() + assert isinstance(manager, MCPClientManager) + + @pytest.mark.asyncio + async def test_add_mcp_server(self): + """测试快速添加服务器""" + client = add_mcp_server("test_quick", "http://localhost:9999") + assert isinstance(client, MCPClient) + # 清理 + get_mcp_manager().remove_server("test_quick") diff --git a/tests/test_mcp/test_integration.py b/tests/test_mcp/test_integration.py new 
file mode 100644 index 0000000..1dd9ef9 --- /dev/null +++ b/tests/test_mcp/test_integration.py @@ -0,0 +1,438 @@ +"""MCP client integration tests. + +Tests the MCP client end-to-end flow: connect, discover tools, call tools, +manage resources, and handle errors. Uses mocks for the transport layer +to avoid real subprocess/network dependencies. +""" + +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from jojo_code.core.exceptions import NetworkError +from jojo_code.mcp.client import ( + MCPClient, + MCPClientManager, + MCPConfig, + MCPResource, + MCPTool, + add_mcp_server, + get_mcp_manager, +) + + +class TestMCPClientConnectFlow: + """Test the full connection lifecycle.""" + + @pytest.mark.asyncio + async def test_connect_stdio_full_flow(self): + """Test connecting via stdio, discovering tools, then closing.""" + config = MCPConfig(name="test-stdio", url="mock-server --stdio", transport="stdio") + client = MCPClient(config) + + mock_process = AsyncMock() + mock_process.stdin = AsyncMock() + mock_process.stdout = AsyncMock() + mock_process.terminate = MagicMock() + mock_process.wait = AsyncMock() + + # Mock the tools/list response for _discover_tools + tools_response = json.dumps({ + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [ + { + "name": "read_file", + "description": "Read a file", + "inputSchema": { + "type": "object", + "properties": {"path": {"type": "string"}}, + }, + }, + { + "name": "write_file", + "description": "Write a file", + "inputSchema": { + "type": "object", + "properties": { + "path": {"type": "string"}, + "content": {"type": "string"}, + }, + }, + }, + ] + }, + }).encode() + b"\n" + + mock_process.stdout.readline.return_value = tools_response + mock_process.stdin.write = MagicMock() # not awaited in client code + mock_process.stdin.drain = AsyncMock() + + with patch("asyncio.create_subprocess_exec", return_value=mock_process): + await client.connect() + + assert client.is_connected is True + 
tools = client.list_tools() + assert len(tools) == 2 + assert tools[0].name == "read_file" + assert tools[1].name == "write_file" + assert client.get_tool("read_file").description == "Read a file" + + # Close + await client.close() + assert client.is_connected is False + + @pytest.mark.asyncio + async def test_connect_http_full_flow(self): + """Test connecting via HTTP, discovering tools, then closing.""" + config = MCPConfig( + name="test-http", + url="http://localhost:8080/mcp", + transport="http", + timeout=10.0, + ) + client = MCPClient(config) + + mock_session = AsyncMock() + mock_response = AsyncMock() + mock_response.json = AsyncMock( + return_value={ + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [ + {"name": "search", "description": "Search tool"}, + ] + }, + } + ) + mock_response.__aenter__ = AsyncMock(return_value=mock_response) + mock_response.__aexit__ = AsyncMock(return_value=False) + mock_session.post = MagicMock(return_value=mock_response) + mock_session.close = AsyncMock() + + with patch("aiohttp.ClientSession", return_value=mock_session): + await client.connect() + + assert client.is_connected is True + assert len(client.list_tools()) == 1 + assert client.get_tool("search") is not None + + await client.close() + assert client.is_connected is False + + @pytest.mark.asyncio + async def test_connect_idempotent(self): + """Calling connect() twice should be a no-op.""" + config = MCPConfig(name="idempotent", url="echo test", transport="stdio") + client = MCPClient(config) + + call_count = 0 + + async def counting_connect(): + nonlocal call_count + call_count += 1 + + client._connect_stdio = counting_connect + client._discover_tools = AsyncMock() + + await client.connect() + await client.connect() + + assert call_count == 1 + assert client.is_connected is True + + +class TestMCPClientToolInvocation: + """Test tool invocation through the MCP client.""" + + @pytest.mark.asyncio + async def test_call_tool_success(self): + """Test successful tool 
invocation.""" + config = MCPConfig(name="invoke-test", url="mock", transport="stdio") + client = MCPClient(config) + client._connected = True + + with patch.object( + client, + "_send_request", + new_callable=AsyncMock, + return_value={"result": {"content": [{"type": "text", "text": "file contents"}]}}, + ): + result = await client.call_tool("read_file", {"path": "/test.txt"}) + assert result == {"content": [{"type": "text", "text": "file contents"}]} + + @pytest.mark.asyncio + async def test_call_tool_auto_connects(self): + """call_tool should auto-connect if not connected.""" + config = MCPConfig(name="auto-connect", url="mock", transport="stdio") + client = MCPClient(config) + + connect_called = False + + async def mock_connect(): + nonlocal connect_called + connect_called = True + client._connected = True + + client.connect = mock_connect + + with patch.object( + client, "_send_request", new_callable=AsyncMock, return_value={"result": "ok"} + ): + result = await client.call_tool("any_tool", {}) + assert connect_called is True + assert result == "ok" + + @pytest.mark.asyncio + async def test_call_tool_raises_on_error(self): + """Tool call with error in response should raise NetworkError.""" + config = MCPConfig(name="err-test", url="mock", transport="stdio") + client = MCPClient(config) + client._connected = True + + with patch.object( + client, + "_send_request", + new_callable=AsyncMock, + return_value={"error": {"code": -32601, "message": "Method not found"}}, + ): + with pytest.raises(NetworkError, match="MCP 工具调用失败"): + await client.call_tool("missing_tool", {}) + + +class TestMCPClientResourceFlow: + """Test resource listing and reading.""" + + @pytest.mark.asyncio + async def test_list_and_read_resources(self): + """Test listing resources then reading one.""" + config = MCPConfig(name="res-test", url="mock", transport="stdio") + client = MCPClient(config) + client._connected = True + + with patch.object( + client, + "_send_request", + 
new_callable=AsyncMock, + return_value={ + "resources": [ + { + "uri": "file:///home/user/project/README.md", + "name": "README.md", + "description": "Project README", + "mimeType": "text/markdown", + }, + { + "uri": "file:///home/user/project/config.json", + "name": "config.json", + "mimeType": "application/json", + }, + ] + }, + ): + resources = await client.list_resources() + assert len(resources) == 2 + assert resources[0].uri == "file:///home/user/project/README.md" + assert resources[0].mime_type == "text/markdown" + assert resources[1].name == "config.json" + + # Now read one + with patch.object( + client, + "_send_request", + new_callable=AsyncMock, + return_value={"contents": [{"text": "# My Project\n\nWelcome."}]}, + ): + content = await client.read_resource("file:///home/user/project/README.md") + assert content == "# My Project\n\nWelcome." + + +class TestMCPClientManagerIntegration: + """Test the client manager with multiple servers.""" + + def test_manager_lifecycle(self): + """Test adding, listing, getting, and removing servers.""" + manager = MCPClientManager() + + c1 = manager.add_server(MCPConfig(name="server-a", url="http://a:8080")) + c2 = manager.add_server(MCPConfig(name="server-b", url="http://b:8080")) + + assert isinstance(c1, MCPClient) + assert isinstance(c2, MCPClient) + assert set(manager.list_servers()) == {"server-a", "server-b"} + assert manager.get_client("server-a") is c1 + assert manager.get_client("server-b") is c2 + assert manager.get_client("missing") is None + + @pytest.mark.asyncio + async def test_manager_connect_all(self): + """Test connecting all servers at once.""" + manager = MCPClientManager() + + c1 = manager.add_server(MCPConfig(name="s1", url="mock1", transport="stdio")) + c2 = manager.add_server(MCPConfig(name="s2", url="mock2", transport="stdio")) + + connect_calls = [] + + async def mock_connect(): + connect_calls.append(True) + + c1.connect = mock_connect + c2.connect = mock_connect + + # Mark one as already 
connected + c1._connected = True + + await manager.connect_all() + # Only c2 should have been connected (c1 already was) + assert len(connect_calls) == 1 + + @pytest.mark.asyncio + async def test_manager_close_all(self): + """Test closing all connections.""" + manager = MCPClientManager() + + c1 = manager.add_server(MCPConfig(name="x", url="http://x")) + c2 = manager.add_server(MCPConfig(name="y", url="http://y")) + + close_calls = [] + + async def mock_close(): + close_calls.append(True) + + c1.close = mock_close + c2.close = mock_close + + await manager.close_all() + assert len(close_calls) == 2 + + +class TestGlobalMCPFunctions: + """Test the global helper functions.""" + + def test_get_mcp_manager_returns_singleton(self): + """get_mcp_manager should return the same instance.""" + m1 = get_mcp_manager() + m2 = get_mcp_manager() + assert m1 is m2 + assert isinstance(m1, MCPClientManager) + + @pytest.mark.asyncio + async def test_add_mcp_server_convenience(self): + """add_mcp_server should create and register a client.""" + client = add_mcp_server("quick-test", "http://quick:9090", transport="http") + assert isinstance(client, MCPClient) + assert client.config.name == "quick-test" + assert client.config.transport == "http" + + # Verify it's in the global manager + manager = get_mcp_manager() + assert manager.get_client("quick-test") is client + + # Cleanup + manager.remove_server("quick-test") + + +class TestMCPClientEdgeCases: + """Test edge cases and error handling.""" + + @pytest.mark.asyncio + async def test_send_stdio_request_format(self): + """Verify the JSON-RPC request format sent via stdio.""" + config = MCPConfig(name="format-test", url="mock", transport="stdio") + client = MCPClient(config) + + mock_process = AsyncMock() + mock_process.stdin = AsyncMock() + mock_process.stdout = AsyncMock() + response = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}).encode() + b"\n" + mock_process.stdout.readline = AsyncMock(return_value=response) + + 
client._process = mock_process + + await client._send_request("tools/list", {"cursor": "abc"}) + + # Verify the written data + mock_process.stdin.write.assert_called_once() + written = mock_process.stdin.write.call_args[0][0] + parsed = json.loads(written.decode().strip()) + assert parsed["jsonrpc"] == "2.0" + assert parsed["method"] == "tools/list" + assert parsed["params"] == {"cursor": "abc"} + + @pytest.mark.asyncio + async def test_send_request_default_params(self): + """Request with no params should send empty dict.""" + config = MCPConfig(name="default-params", url="mock", transport="stdio") + client = MCPClient(config) + + mock_process = AsyncMock() + mock_process.stdin = AsyncMock() + mock_process.stdout = AsyncMock() + response = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}).encode() + b"\n" + mock_process.stdout.readline = AsyncMock(return_value=response) + + client._process = mock_process + + await client._send_request("ping") + + written = mock_process.stdin.write.call_args[0][0] + parsed = json.loads(written.decode().strip()) + assert parsed["params"] == {} + + @pytest.mark.asyncio + async def test_close_without_connection(self): + """Closing a never-connected client should not raise.""" + config = MCPConfig(name="no-conn", url="mock", transport="stdio") + client = MCPClient(config) + # Should not raise + await client.close() + assert client.is_connected is False + + def test_mcp_tool_dataclass(self): + """Test MCPTool dataclass creation.""" + tool = MCPTool( + name="complex_tool", + description="A complex tool", + input_schema={ + "type": "object", + "properties": { + "file_path": {"type": "string", "description": "Path to file"}, + "options": { + "type": "object", + "properties": {"verbose": {"type": "boolean"}}, + }, + }, + "required": ["file_path"], + }, + ) + assert tool.name == "complex_tool" + assert tool.input_schema["required"] == ["file_path"] + + def test_mcp_resource_dataclass(self): + """Test MCPResource dataclass creation.""" + 
resource = MCPResource( + uri="https://example.com/api/data", + name="API Data", + description="Remote API data", + mime_type="application/json", + ) + assert resource.uri == "https://example.com/api/data" + assert resource.mime_type == "application/json" + + def test_mcp_config_auth(self): + """Test MCPConfig with authentication.""" + config = MCPConfig( + name="auth-server", + url="https://secure.example.com/mcp", + transport="http", + auth={"Authorization": "Bearer secret-token", "X-API-Key": "key123"}, + timeout=60.0, + retry=5, + ) + assert config.auth["Authorization"] == "Bearer secret-token" + assert config.timeout == 60.0 + assert config.retry == 5 diff --git a/tests/test_memory/test_long_term.py b/tests/test_memory/test_long_term.py new file mode 100644 index 0000000..683a200 --- /dev/null +++ b/tests/test_memory/test_long_term.py @@ -0,0 +1,444 @@ +"""Tests for LongTermMemory.""" + +import json +from datetime import datetime, timedelta + +import pytest + +from jojo_code.memory.long_term import LongTermMemory, create_longterm_memory +from jojo_code.memory.types import MemoryItem, MemoryType, SearchResult + + +@pytest.fixture +def memory(tmp_path): + """Create a LongTermMemory instance with tmp storage.""" + return LongTermMemory(storage_dir=tmp_path / "memory", retention_days=0) + + +@pytest.fixture +def memory_with_retention(tmp_path): + """Create a LongTermMemory with retention enabled.""" + return LongTermMemory(storage_dir=tmp_path / "memory", retention_days=30) + + +class TestLongTermMemoryInit: + """Tests for initialization.""" + + def test_creates_storage_dir(self, tmp_path): + storage = tmp_path / "new_dir" + assert not storage.exists() + LongTermMemory(storage_dir=storage, retention_days=0) + assert storage.exists() + + def test_default_values(self, memory): + assert memory.max_items == 10000 + assert memory.retention_days == 0 + assert memory._cache == {} + + def test_custom_max_items(self, tmp_path): + mem = LongTermMemory(storage_dir=tmp_path / 
"mem", max_items=50, retention_days=0) + assert mem.max_items == 50 + + def test_custom_retention_days(self, tmp_path): + mem = LongTermMemory(storage_dir=tmp_path / "mem", retention_days=7) + assert mem.retention_days == 7 + + +class TestAdd: + """Tests for add and add_message.""" + + def test_add_returns_memory_item(self, memory): + item = memory.add("Hello world", session_id="s1") + assert isinstance(item, MemoryItem) + assert item.content == "Hello world" + assert item.session_id == "s1" + assert item.memory_type == MemoryType.LONG_TERM + + def test_add_with_tags(self, memory): + item = memory.add("content", session_id="s1", tags=["python", "test"]) + assert item.tags == ["python", "test"] + + def test_add_with_metadata(self, memory): + item = memory.add("content", session_id="s1", metadata={"key": "value"}) + assert item.metadata == {"key": "value"} + + def test_add_default_tags_and_metadata(self, memory): + item = memory.add("content", session_id="s1") + assert item.tags == [] + assert item.metadata == {} + + def test_add_persists_to_file(self, memory, tmp_path): + memory.add("persisted content", session_id="s1") + items_file = tmp_path / "memory" / "s1" / "memory.json" + assert items_file.exists() + data = json.loads(items_file.read_text(encoding="utf-8")) + assert len(data["items"]) == 1 + assert data["items"][0]["content"] == "persisted content" + + def test_add_multiple_items(self, memory): + memory.add("item1", session_id="s1") + memory.add("item2", session_id="s1") + memory.add("item3", session_id="s1") + items = memory.get_session_memories("s1") + assert len(items) == 3 + + def test_add_message(self, memory): + item = memory.add_message("user says hi", role="user", session_id="s1") + assert item.metadata["role"] == "user" + assert item.content == "user says hi" + + def test_add_message_ai_role(self, memory): + item = memory.add_message("ai responds", role="ai", session_id="s1") + assert item.metadata["role"] == "ai" + + def test_add_to_cache(self, 
memory): + item = memory.add("cached", session_id="s1") + assert item.id in memory._cache + + +class TestGet: + """Tests for get.""" + + def test_get_existing_item(self, memory): + added = memory.add("find me", session_id="s1") + found = memory.get(added.id) + assert found is not None + assert found.content == "find me" + + def test_get_nonexistent_item(self, memory): + assert memory.get("nonexistent-id") is None + + def test_get_from_cache(self, memory): + added = memory.add("cached item", session_id="s1") + # Force clear file scan path by clearing and re-adding to cache + found = memory.get(added.id) + assert found is not None + assert found.id == added.id + + def test_get_from_file_not_in_cache(self, memory, tmp_path): + """When cache miss, get should scan session files.""" + added = memory.add("file item", session_id="s1") + # Clear cache to force file lookup + memory._cache.clear() + found = memory.get(added.id) + assert found is not None + assert found.content == "file item" + + def test_get_populates_cache(self, memory): + added = memory.add("cache populate", session_id="s1") + memory._cache.clear() + assert added.id not in memory._cache + memory.get(added.id) + assert added.id in memory._cache + + +class TestGetSessionMemories: + """Tests for get_session_memories.""" + + def test_empty_session(self, memory): + assert memory.get_session_memories("nonexistent") == [] + + def test_returns_all_items(self, memory): + memory.add("a", session_id="s1") + memory.add("b", session_id="s1") + memory.add("c", session_id="s1") + items = memory.get_session_memories("s1") + assert len(items) == 3 + + def test_limit(self, memory): + for i in range(10): + memory.add(f"item-{i}", session_id="s1") + items = memory.get_session_memories("s1", limit=3) + assert len(items) == 3 + + def test_only_returns_matching_session(self, memory): + memory.add("s1 item", session_id="s1") + memory.add("s2 item", session_id="s2") + items = memory.get_session_memories("s1") + assert len(items) == 
1 + assert items[0].content == "s1 item" + + +class TestListSessions: + """Tests for list_sessions.""" + + def test_empty(self, memory): + assert memory.list_sessions() == [] + + def test_returns_sorted_sessions(self, memory): + memory.add("a", session_id="charlie") + memory.add("b", session_id="alpha") + memory.add("c", session_id="bravo") + sessions = memory.list_sessions() + assert sessions == ["alpha", "bravo", "charlie"] + + def test_no_duplicates(self, memory): + memory.add("a", session_id="s1") + memory.add("b", session_id="s1") + sessions = memory.list_sessions() + assert sessions == ["s1"] + + +class TestSearch: + """Tests for search.""" + + def test_search_finds_keyword(self, memory): + memory.add("Python is great", session_id="s1") + memory.add("Java is okay", session_id="s1") + results = memory.search("python") + assert len(results) == 1 + assert results[0].item.content == "Python is great" + + def test_search_case_insensitive(self, memory): + memory.add("PYTHON programming", session_id="s1") + results = memory.search("python") + assert len(results) == 1 + + def test_search_no_results(self, memory): + memory.add("Python", session_id="s1") + results = memory.search("javascript") + assert len(results) == 0 + + def test_search_limit(self, memory): + for i in range(10): + memory.add(f"Python tutorial part {i}", session_id="s1") + results = memory.search("python", limit=3) + assert len(results) <= 3 + + def test_search_across_sessions(self, memory): + memory.add("Python in s1", session_id="s1") + memory.add("Python in s2", session_id="s2") + results = memory.search("python", session_id=None) + assert len(results) == 2 + + def test_search_specific_session(self, memory): + memory.add("Python in s1", session_id="s1") + memory.add("Python in s2", session_id="s2") + results = memory.search("python", session_id="s1") + assert len(results) == 1 + assert results[0].item.session_id == "s1" + + def test_search_returns_search_results(self, memory): + memory.add("Python 
code", session_id="s1") + results = memory.search("python") + assert len(results) == 1 + assert isinstance(results[0], SearchResult) + assert 0 < results[0].score <= 1.0 + + def test_search_matched_content(self, memory): + memory.add("I love Python programming language", session_id="s1") + results = memory.search("python") + assert ( + "Python" in results[0].matched_content or "python" in results[0].matched_content.lower() + ) + + def test_search_sorted_by_score(self, memory): + memory.add("Python Python Python best language", session_id="s1") + memory.add("python once", session_id="s1") + results = memory.search("python") + if len(results) >= 2: + assert results[0].score >= results[1].score + + +class TestDeleteSession: + """Tests for delete_session.""" + + def test_delete_existing_session(self, memory): + memory.add("item", session_id="s1") + assert memory.delete_session("s1") is True + assert memory.get_session_memories("s1") == [] + assert "s1" not in memory.list_sessions() + + def test_delete_nonexistent_session(self, memory): + assert memory.delete_session("nope") is False + + def test_delete_cleans_cache(self, memory): + item = memory.add("cached", session_id="s1") + memory.delete_session("s1") + assert item.id not in memory._cache + + def test_delete_does_not_affect_other_sessions(self, memory): + memory.add("s1 item", session_id="s1") + memory.add("s2 item", session_id="s2") + memory.delete_session("s1") + items = memory.get_session_memories("s2") + assert len(items) == 1 + + +class TestCleanup: + """Tests for cleanup.""" + + def test_cleanup_removes_old_items(self, tmp_path): + mem = LongTermMemory(storage_dir=tmp_path / "mem", retention_days=0) + # Manually create an old memory file + session_dir = tmp_path / "mem" / "old-session" + session_dir.mkdir(parents=True) + old_time = datetime.now() - timedelta(days=100) + data = { + "session_id": "old-session", + "updated_at": old_time.isoformat(), + "items": [ + { + "id": "old-item", + "content": "old 
content", + "memory_type": "long_term", + "session_id": "old-session", + "created_at": old_time.isoformat(), + "metadata": {}, + "tags": [], + } + ], + } + (session_dir / "memory.json").write_text(json.dumps(data), encoding="utf-8") + + # Cleanup with retention_days=30 should remove the 100-day-old item + cleaned = mem.cleanup(retention_days=30) + assert cleaned >= 1 + + def test_cleanup_preserves_recent_items(self, memory): + memory.add("recent item", session_id="s1") + cleaned = memory.cleanup(retention_days=90) + assert cleaned == 0 + items = memory.get_session_memories("s1") + assert len(items) == 1 + + def test_cleanup_custom_retention(self, memory): + # Items added now should not be cleaned with any positive retention + memory.add("fresh", session_id="s1") + cleaned = memory.cleanup(retention_days=1) + assert cleaned == 0 + + +class TestGetStats: + """Tests for get_stats.""" + + def test_stats_empty(self, memory): + stats = memory.get_stats() + assert stats["total_items"] == 0 + assert stats["sessions"] == 0 + assert stats["session_ids"] == [] + + def test_stats_with_data(self, memory): + memory.add("a", session_id="s1") + memory.add("b", session_id="s1") + memory.add("c", session_id="s2") + stats = memory.get_stats() + assert stats["total_items"] == 3 + assert stats["sessions"] == 2 + assert len(stats["session_ids"]) == 2 + + def test_stats_session_ids_limited(self, memory): + for i in range(15): + memory.add(f"item-{i}", session_id=f"session-{i:02d}") + stats = memory.get_stats() + assert len(stats["session_ids"]) <= 10 + + def test_stats_storage_dir(self, memory, tmp_path): + stats = memory.get_stats() + assert "storage_dir" in stats + + +class TestMaxItemsLimit: + """Tests for max_items enforcement.""" + + def test_enforces_max_items(self, tmp_path): + mem = LongTermMemory(storage_dir=tmp_path / "mem", max_items=3, retention_days=0) + for i in range(5): + mem.add(f"item-{i}", session_id="s1") + items = mem.get_session_memories("s1", limit=100) + assert 
len(items) == 3 + # Should keep the most recent 3 + assert items[0].content == "item-2" + assert items[2].content == "item-4" + + +class TestCalculateScore: + """Tests for _calculate_score.""" + + def test_exact_match(self, memory): + score = memory._calculate_score("Python is great", "Python") + assert score >= 0.5 + + def test_no_match(self, memory): + # _calculate_score assumes match exists, but let's test the path + score = memory._calculate_score("Java", "Python") + assert score == 0.0 + + def test_multiple_occurrences_higher_score(self, memory): + score1 = memory._calculate_score("Python is great", "Python") + score2 = memory._calculate_score("Python Python Python", "Python") + assert score2 > score1 + + def test_position_bonus(self, memory): + score_early = memory._calculate_score("Python at start", "Python") + score_late = memory._calculate_score("Start with Python", "Python") + assert score_early >= score_late + + def test_score_capped_at_one(self, memory): + score = memory._calculate_score("Python " * 100, "Python") + assert score <= 1.0 + + +class TestExtractMatch: + """Tests for _extract_match.""" + + def test_extracts_context(self, memory): + content = "This is a long text with Python somewhere in the middle of it all" + snippet = memory._extract_match(content, "Python") + assert "Python" in snippet + + def test_adds_ellipsis_prefix(self, memory): + content = "A" * 100 + "Python" + "B" * 100 + snippet = memory._extract_match(content, "Python") + assert snippet.startswith("...") + + def test_adds_ellipsis_suffix(self, memory): + content = "A" * 100 + "Python" + "B" * 100 + snippet = memory._extract_match(content, "Python") + assert snippet.endswith("...") + + def test_no_match_returns_prefix(self, memory): + content = "x" * 200 + snippet = memory._extract_match(content, "Python") + assert snippet == content[:100] + + def test_match_at_start_no_prefix_ellipsis(self, memory): + content = "Python is at the very start of this content" + snippet = 
memory._extract_match(content, "Python") + assert not snippet.startswith("...") + + def test_match_at_end_no_suffix_ellipsis(self, memory): + content = "At the very end we have Python" + snippet = memory._extract_match(content, "Python") + assert not snippet.endswith("...") + + +class TestSaveItems: + """Tests for _save_items persistence format.""" + + def test_save_creates_valid_json(self, memory, tmp_path): + memory.add("test", session_id="s1") + items_file = tmp_path / "memory" / "s1" / "memory.json" + data = json.loads(items_file.read_text(encoding="utf-8")) + assert "session_id" in data + assert "updated_at" in data + assert "items" in data + assert data["session_id"] == "s1" + + def test_save_multiple_sessions(self, memory, tmp_path): + memory.add("s1 item", session_id="s1") + memory.add("s2 item", session_id="s2") + assert (tmp_path / "memory" / "s1" / "memory.json").exists() + assert (tmp_path / "memory" / "s2" / "memory.json").exists() + + +class TestCreateLongtermMemory: + """Tests for the factory function.""" + + def test_factory_creates_instance(self, tmp_path): + mem = create_longterm_memory(storage_dir=tmp_path / "mem") + assert isinstance(mem, LongTermMemory) + + def test_factory_default_storage(self): + mem = create_longterm_memory() + assert isinstance(mem, LongTermMemory) diff --git a/tests/test_memory/test_retriever.py b/tests/test_memory/test_retriever.py new file mode 100644 index 0000000..6c5a9e0 --- /dev/null +++ b/tests/test_memory/test_retriever.py @@ -0,0 +1,445 @@ +"""Tests for MemoryRetriever and SessionMemory.""" + +import pytest +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage + +from jojo_code.memory.long_term import LongTermMemory +from jojo_code.memory.retriever import MemoryRetriever, SessionMemory, create_session_memory +from jojo_code.memory.short_term import ShortTermMemory +from jojo_code.memory.types import SearchResult + + +@pytest.fixture +def short_term(): + """Create a ShortTermMemory with known 
session id.""" + return ShortTermMemory(session_id="st-session") + + +@pytest.fixture +def long_term(tmp_path): + """Create a LongTermMemory with tmp storage.""" + return LongTermMemory(storage_dir=tmp_path / "memory", retention_days=0) + + +@pytest.fixture +def retriever(short_term, long_term): + """Create a MemoryRetriever with both memories.""" + return MemoryRetriever(short_term=short_term, long_term=long_term) + + +class TestMemoryRetrieverInit: + """Tests for MemoryRetriever initialization.""" + + def test_default_init(self): + retriever = MemoryRetriever() + assert isinstance(retriever.short_term, ShortTermMemory) + assert isinstance(retriever.long_term, LongTermMemory) + + def test_custom_init(self, short_term, long_term): + retriever = MemoryRetriever(short_term=short_term, long_term=long_term) + assert retriever.short_term is short_term + assert retriever.long_term is long_term + + +class TestSearch: + """Tests for MemoryRetriever.search.""" + + def test_search_all_scope(self, retriever, short_term, long_term): + short_term.add_user_message("Python question") + long_term.add("Python history", session_id="old-sess") + results = retriever.search("python", scope="all") + assert len(results["current_session"]) == 1 + assert len(results["history"]) == 1 + + def test_search_current_scope(self, retriever, short_term, long_term): + short_term.add_user_message("Python question") + long_term.add("Python history", session_id="old-sess") + results = retriever.search("python", scope="current") + assert len(results["current_session"]) == 1 + assert len(results["history"]) == 0 + + def test_search_history_scope(self, retriever, short_term, long_term): + short_term.add_user_message("Python question") + long_term.add("Python history", session_id="old-sess") + results = retriever.search("python", scope="history") + assert len(results["current_session"]) == 0 + assert len(results["history"]) == 1 + + def test_search_no_matches(self, retriever, short_term): + 
short_term.add_user_message("Java is cool") + results = retriever.search("python", scope="all") + assert results["current_session"] == [] + assert results["history"] == [] + + def test_search_empty_short_term(self, retriever, long_term): + long_term.add("Python stuff", session_id="s1") + results = retriever.search("python", scope="all") + assert results["current_session"] == [] + assert len(results["history"]) == 1 + + def test_search_limit(self, retriever, short_term, long_term): + for i in range(10): + short_term.add_user_message(f"Python tutorial {i}") + results = retriever.search("python", scope="current", limit=3) + assert len(results["current_session"]) <= 3 + + def test_search_returns_search_result_in_current(self, retriever, short_term): + short_term.add_user_message("Python rocks") + results = retriever.search("python", scope="current") + assert len(results["current_session"]) == 1 + sr = results["current_session"][0] + assert isinstance(sr, SearchResult) + assert sr.score == 1.0 + + def test_search_user_message_role(self, retriever, short_term): + short_term.add_user_message("Python question") + results = retriever.search("python", scope="current") + assert results["current_session"][0].item.metadata["role"] == "user" + + def test_search_ai_message_role(self, retriever, short_term): + short_term.add_ai_message("Python answer") + results = retriever.search("python", scope="current") + assert results["current_session"][0].item.metadata["role"] == "ai" + + +class TestSearchCurrentSession: + """Tests for search_current_session convenience method.""" + + def test_returns_list(self, retriever, short_term): + short_term.add_user_message("Python test") + results = retriever.search_current_session("python") + assert isinstance(results, list) + + def test_returns_search_results(self, retriever, short_term): + short_term.add_user_message("Python test") + results = retriever.search_current_session("python") + assert len(results) == 1 + assert isinstance(results[0], 
SearchResult) + + def test_no_match(self, retriever, short_term): + short_term.add_user_message("Java") + results = retriever.search_current_session("python") + assert results == [] + + def test_limit(self, retriever, short_term): + for i in range(10): + short_term.add_user_message(f"Python {i}") + results = retriever.search_current_session("python", limit=2) + assert len(results) <= 2 + + +class TestSearchHistory: + """Tests for search_history convenience method.""" + + def test_returns_list(self, retriever, long_term): + long_term.add("Python history", session_id="s1") + results = retriever.search_history("python") + assert isinstance(results, list) + + def test_returns_search_results(self, retriever, long_term): + long_term.add("Python history", session_id="s1") + results = retriever.search_history("python") + assert len(results) == 1 + assert isinstance(results[0], SearchResult) + + def test_specific_session(self, retriever, long_term): + long_term.add("Python s1", session_id="s1") + long_term.add("Python s2", session_id="s2") + results = retriever.search_history("python", session_id="s1") + assert len(results) == 1 + assert results[0].item.session_id == "s1" + + def test_limit(self, retriever, long_term): + for i in range(10): + long_term.add(f"Python {i}", session_id="s1") + results = retriever.search_history("python", limit=3) + assert len(results) <= 3 + + +class TestGetRecentMemories: + """Tests for get_recent_memories.""" + + def test_empty(self, retriever): + results = retriever.get_recent_memories() + assert results == [] + + def test_returns_short_term_items(self, retriever, short_term): + short_term.add_user_message("recent msg") + results = retriever.get_recent_memories() + assert len(results) >= 1 + contents = [r.content for r in results] + assert "recent msg" in contents + + def test_returns_long_term_items(self, retriever, long_term): + long_term.add("historical", session_id="old-sess") + results = retriever.get_recent_memories() + contents = 
[r.content for r in results] + assert "historical" in contents + + def test_sorted_by_created_at_desc(self, retriever, short_term): + short_term.add_user_message("first") + short_term.add_ai_message("second") + results = retriever.get_recent_memories(limit=10) + if len(results) >= 2: + assert results[0].created_at >= results[1].created_at + + def test_limit(self, retriever, short_term, long_term): + short_term.add_user_message("m1") + short_term.add_ai_message("m2") + for i in range(5): + long_term.add(f"hist-{i}", session_id=f"s{i}") + results = retriever.get_recent_memories(limit=3) + assert len(results) <= 3 + + +class TestSaveCurrentSession: + """Tests for save_current_session.""" + + def test_saves_to_long_term(self, retriever, short_term, long_term): + short_term.add_user_message("save me") + short_term.add_ai_message("saved too") + retriever.save_current_session() + archived_id = f"archived_{short_term.session_id}" + items = long_term.get_session_memories(archived_id) + assert len(items) == 2 + + def test_preserves_content(self, retriever, short_term, long_term): + short_term.add_user_message("exact content") + retriever.save_current_session() + archived_id = f"archived_{short_term.session_id}" + items = long_term.get_session_memories(archived_id) + assert items[0].content == "exact content" + + def test_preserves_metadata(self, retriever, short_term, long_term): + short_term.add_user_message("user msg") + retriever.save_current_session() + archived_id = f"archived_{short_term.session_id}" + items = long_term.get_session_memories(archived_id) + assert items[0].metadata["role"] == "user" + + def test_empty_session(self, retriever, short_term, long_term): + retriever.save_current_session() + archived_id = f"archived_{short_term.session_id}" + items = long_term.get_session_memories(archived_id) + assert items == [] + + +class TestLoadSession: + """Tests for load_session.""" + + def test_loads_existing_session(self, retriever, long_term): + long_term.add("item", 
session_id="load-me") + results = retriever.load_session("load-me") + assert len(results) == 1 + assert results[0].content == "item" + + def test_loads_nonexistent_session(self, retriever): + results = retriever.load_session("does-not-exist") + assert results == [] + + +class TestGetAllSessions: + """Tests for get_all_sessions.""" + + def test_includes_current_session(self, retriever, short_term): + sessions = retriever.get_all_sessions() + assert short_term.session_id in sessions + + def test_includes_long_term_sessions(self, retriever, long_term): + long_term.add("item", session_id="hist-s1") + sessions = retriever.get_all_sessions() + assert "hist-s1" in sessions + + def test_no_duplicates(self, retriever, short_term, long_term): + # If long-term has a session with same id as short-term + long_term.add("item", session_id=short_term.session_id) + sessions = retriever.get_all_sessions() + assert sessions.count(short_term.session_id) == 1 + + def test_returns_list(self, retriever): + sessions = retriever.get_all_sessions() + assert isinstance(sessions, list) + + +class TestSessionMemoryInit: + """Tests for SessionMemory initialization.""" + + def test_default_init(self, tmp_path): + sm = SessionMemory(storage_dir=str(tmp_path / "mem")) + assert sm.session_id.startswith("session_") + assert isinstance(sm.short_term, ShortTermMemory) + assert isinstance(sm.long_term, LongTermMemory) + assert isinstance(sm.retriever, MemoryRetriever) + + def test_custom_session_id(self, tmp_path): + sm = SessionMemory(session_id="custom-sid", storage_dir=str(tmp_path / "mem")) + assert sm.session_id == "custom-sid" + + def test_custom_max_tokens(self, tmp_path): + sm = SessionMemory(max_tokens=5000, storage_dir=str(tmp_path / "mem")) + assert sm.short_term.max_tokens == 5000 + + +class TestSessionMemoryAddMessage: + """Tests for SessionMemory.add_message.""" + + def test_add_user_message(self, tmp_path): + sm = SessionMemory(session_id="s1", storage_dir=str(tmp_path / "mem")) + 
sm.add_message("hello", role="user") + assert sm.short_term.message_count == 1 + assert isinstance(sm.short_term.messages[0], HumanMessage) + + def test_add_ai_message(self, tmp_path): + sm = SessionMemory(session_id="s1", storage_dir=str(tmp_path / "mem")) + sm.add_message("reply", role="ai") + assert isinstance(sm.short_term.messages[0], AIMessage) + + def test_add_assistant_message(self, tmp_path): + sm = SessionMemory(session_id="s1", storage_dir=str(tmp_path / "mem")) + sm.add_message("reply", role="assistant") + assert isinstance(sm.short_term.messages[0], AIMessage) + + def test_add_system_message(self, tmp_path): + sm = SessionMemory(session_id="s1", storage_dir=str(tmp_path / "mem")) + sm.add_message("system prompt", role="system") + assert isinstance(sm.short_term.messages[0], SystemMessage) + + def test_add_unknown_role_ignored(self, tmp_path): + sm = SessionMemory(session_id="s1", storage_dir=str(tmp_path / "mem")) + sm.add_message("content", role="unknown") + assert sm.short_term.message_count == 0 + + def test_add_message_persists_to_long_term(self, tmp_path): + sm = SessionMemory(session_id="s1", storage_dir=str(tmp_path / "mem")) + sm.add_message("persisted", role="user") + items = sm.long_term.get_session_memories("s1") + assert len(items) == 1 + assert items[0].content == "persisted" + + def test_add_multiple_messages(self, tmp_path): + sm = SessionMemory(session_id="s1", storage_dir=str(tmp_path / "mem")) + sm.add_message("q1", role="user") + sm.add_message("a1", role="ai") + sm.add_message("q2", role="user") + assert sm.short_term.message_count == 3 + items = sm.long_term.get_session_memories("s1") + assert len(items) == 3 + + +class TestSessionMemoryGetContext: + """Tests for SessionMemory.get_context.""" + + def test_get_all_context(self, tmp_path): + sm = SessionMemory(session_id="s1", storage_dir=str(tmp_path / "mem")) + sm.add_message("m1", role="user") + sm.add_message("m2", role="ai") + ctx = sm.get_context() + assert len(ctx) == 2 + + 
def test_get_limited_context(self, tmp_path): + sm = SessionMemory(session_id="s1", storage_dir=str(tmp_path / "mem")) + sm.add_message("m1", role="user") + sm.add_message("m2", role="ai") + sm.add_message("m3", role="user") + ctx = sm.get_context(max_messages=2) + assert len(ctx) == 2 + + def test_get_context_none_limit(self, tmp_path): + sm = SessionMemory(session_id="s1", storage_dir=str(tmp_path / "mem")) + sm.add_message("m1", role="user") + ctx = sm.get_context(max_messages=None) + assert len(ctx) == 1 + + +class TestSessionMemorySearch: + """Tests for SessionMemory.search.""" + + def test_search_all(self, tmp_path): + sm = SessionMemory(session_id="s1", storage_dir=str(tmp_path / "mem")) + sm.add_message("Python question", role="user") + sm.add_message("Python answer", role="ai") + results = sm.search("python", scope="all") + assert "current_session" in results + assert "history" in results + + def test_search_current(self, tmp_path): + sm = SessionMemory(session_id="s1", storage_dir=str(tmp_path / "mem")) + sm.add_message("Python question", role="user") + results = sm.search("python", scope="current") + assert len(results["current_session"]) == 1 + + def test_search_no_match(self, tmp_path): + sm = SessionMemory(session_id="s1", storage_dir=str(tmp_path / "mem")) + sm.add_message("Java", role="user") + results = sm.search("python") + assert results["current_session"] == [] + + +class TestSessionMemoryClear: + """Tests for SessionMemory.clear.""" + + def test_clear_short_term(self, tmp_path): + sm = SessionMemory(session_id="s1", storage_dir=str(tmp_path / "mem")) + sm.add_message("msg", role="user") + sm.clear() + assert sm.short_term.message_count == 0 + + def test_clear_does_not_affect_long_term(self, tmp_path): + sm = SessionMemory(session_id="s1", storage_dir=str(tmp_path / "mem")) + sm.add_message("msg", role="user") + sm.clear() + items = sm.long_term.get_session_memories("s1") + assert len(items) == 1 + + +class TestSessionMemorySave: + """Tests for 
SessionMemory.save (no-op).""" + + def test_save_does_not_raise(self, tmp_path): + sm = SessionMemory(session_id="s1", storage_dir=str(tmp_path / "mem")) + sm.save() # Should not raise + + +class TestSessionMemoryGetStats: + """Tests for SessionMemory.get_stats.""" + + def test_stats_keys(self, tmp_path): + sm = SessionMemory(session_id="s1", storage_dir=str(tmp_path / "mem")) + stats = sm.get_stats() + assert "session_id" in stats + assert "current_messages" in stats + assert "current_tokens" in stats + assert "history_sessions" in stats + + def test_stats_values(self, tmp_path): + sm = SessionMemory(session_id="s1", storage_dir=str(tmp_path / "mem")) + sm.add_message("msg1", role="user") + sm.add_message("msg2", role="ai") + stats = sm.get_stats() + assert stats["session_id"] == "s1" + assert stats["current_messages"] == 2 + assert stats["current_tokens"] > 0 + + def test_stats_history_sessions(self, tmp_path): + sm = SessionMemory(session_id="s1", storage_dir=str(tmp_path / "mem")) + sm.add_message("msg", role="user") + stats = sm.get_stats() + # At least 1 session in long-term (the current session's persisted messages) + assert stats["history_sessions"] >= 1 + + +class TestCreateSessionMemoryFactory: + """Tests for the create_session_memory factory function.""" + + def test_factory_creates_session_memory(self, tmp_path): + sm = create_session_memory(session_id="factory-test") + assert isinstance(sm, SessionMemory) + assert sm.session_id == "factory-test" + + def test_factory_default_session_id(self, tmp_path): + sm = create_session_memory() + assert sm.session_id.startswith("session_") + + def test_factory_custom_max_tokens(self, tmp_path): + sm = create_session_memory(max_tokens=1000) + assert sm.short_term.max_tokens == 1000 diff --git a/tests/test_memory/test_short_term.py b/tests/test_memory/test_short_term.py new file mode 100644 index 0000000..f953275 --- /dev/null +++ b/tests/test_memory/test_short_term.py @@ -0,0 +1,426 @@ +"""Tests for 
ShortTermMemory.""" + +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage + +from jojo_code.memory.short_term import ShortTermMemory, create_session_memory +from jojo_code.memory.types import MemoryType + + +class TestShortTermMemoryInit: + """Tests for ShortTermMemory initialization.""" + + def test_default_init(self): + mem = ShortTermMemory() + assert isinstance(mem.session_id, str) + assert len(mem.session_id) > 0 + assert mem.max_tokens == 100000 + assert mem.messages == [] + assert mem.message_count == 0 + + def test_custom_session_id(self): + mem = ShortTermMemory(session_id="my-session") + assert mem.session_id == "my-session" + + def test_custom_max_tokens(self): + mem = ShortTermMemory(max_tokens=5000) + assert mem.max_tokens == 5000 + + def test_created_at_is_set(self): + mem = ShortTermMemory() + assert mem.created_at is not None + + def test_unique_session_ids_by_default(self): + mem1 = ShortTermMemory() + mem2 = ShortTermMemory() + assert mem1.session_id != mem2.session_id + + +class TestAddMessage: + """Tests for add_message and convenience methods.""" + + def test_add_message(self): + mem = ShortTermMemory() + msg = HumanMessage(content="Hello") + mem.add_message(msg) + assert mem.message_count == 1 + assert mem.messages[0] is msg + + def test_add_messages_batch(self): + mem = ShortTermMemory() + msgs = [ + HumanMessage(content="Q1"), + AIMessage(content="A1"), + HumanMessage(content="Q2"), + ] + mem.add_messages(msgs) + assert mem.message_count == 3 + + def test_add_user_message(self): + mem = ShortTermMemory() + msg = mem.add_user_message("user says hi") + assert isinstance(msg, HumanMessage) + assert msg.content == "user says hi" + assert mem.message_count == 1 + + def test_add_ai_message(self): + mem = ShortTermMemory() + msg = mem.add_ai_message("ai responds") + assert isinstance(msg, AIMessage) + assert msg.content == "ai responds" + assert mem.message_count == 1 + + def test_add_system_message(self): + mem = 
ShortTermMemory() + msg = mem.add_system_message("system prompt") + assert isinstance(msg, SystemMessage) + assert msg.content == "system prompt" + assert mem.message_count == 1 + + def test_add_multiple_mixed(self): + mem = ShortTermMemory() + mem.add_system_message("You are helpful") + mem.add_user_message("Question") + mem.add_ai_message("Answer") + assert mem.message_count == 3 + + +class TestGetMessages: + """Tests for get_messages, get_last_n, get_messages_by_role.""" + + def test_get_messages_returns_copy(self): + mem = ShortTermMemory() + mem.add_user_message("test") + msgs = mem.get_messages() + msgs.clear() + assert mem.message_count == 1 # original unchanged + + def test_get_messages_empty(self): + mem = ShortTermMemory() + assert mem.get_messages() == [] + + def test_get_last_n_less_than_total(self): + mem = ShortTermMemory() + for i in range(5): + mem.add_user_message(f"msg-{i}") + result = mem.get_last_n(2) + assert len(result) == 2 + assert result[0].content == "msg-3" + assert result[1].content == "msg-4" + + def test_get_last_n_equal_to_total(self): + mem = ShortTermMemory() + for i in range(3): + mem.add_user_message(f"msg-{i}") + result = mem.get_last_n(3) + assert len(result) == 3 + + def test_get_last_n_greater_than_total(self): + mem = ShortTermMemory() + mem.add_user_message("only one") + result = mem.get_last_n(10) + assert len(result) == 1 + + def test_get_messages_by_role_user(self): + mem = ShortTermMemory() + mem.add_user_message("u1") + mem.add_ai_message("a1") + mem.add_user_message("u2") + result = mem.get_messages_by_role("user") + assert len(result) == 2 + assert all(isinstance(m, HumanMessage) for m in result) + + def test_get_messages_by_role_ai(self): + mem = ShortTermMemory() + mem.add_user_message("u1") + mem.add_ai_message("a1") + mem.add_ai_message("a2") + result = mem.get_messages_by_role("ai") + assert len(result) == 2 + assert all(isinstance(m, AIMessage) for m in result) + + def 
test_get_messages_by_role_assistant_alias(self): + mem = ShortTermMemory() + mem.add_ai_message("reply") + result = mem.get_messages_by_role("assistant") + assert len(result) == 1 + + def test_get_messages_by_role_system(self): + mem = ShortTermMemory() + mem.add_system_message("sys") + mem.add_user_message("user") + result = mem.get_messages_by_role("system") + assert len(result) == 1 + assert isinstance(result[0], SystemMessage) + + def test_get_messages_by_role_case_insensitive(self): + mem = ShortTermMemory() + mem.add_user_message("test") + assert len(mem.get_messages_by_role("USER")) == 1 + assert len(mem.get_messages_by_role("User")) == 1 + + def test_get_messages_by_role_unknown(self): + mem = ShortTermMemory() + mem.add_user_message("test") + assert mem.get_messages_by_role("unknown") == [] + + def test_get_messages_by_role_empty(self): + mem = ShortTermMemory() + assert mem.get_messages_by_role("user") == [] + + +class TestTokenCount: + """Tests for token_count.""" + + def test_token_count_empty(self): + mem = ShortTermMemory() + assert mem.token_count() == 0 + + def test_token_count_single_message(self): + mem = ShortTermMemory() + mem.add_user_message("Hello World") + count = mem.token_count() + assert count > 0 + + def test_token_count_accumulates(self): + mem = ShortTermMemory() + mem.add_user_message("Hello") + count1 = mem.token_count() + mem.add_ai_message("Hello back") + count2 = mem.token_count() + assert count2 > count1 + + def test_token_count_longer_content_is_more(self): + mem_short = ShortTermMemory() + mem_short.add_user_message("Hi") + mem_long = ShortTermMemory() + mem_long.add_user_message("This is a much longer message with many more words") + assert mem_long.token_count() > mem_short.token_count() + + +class TestClear: + """Tests for clear.""" + + def test_clear_removes_all_messages(self): + mem = ShortTermMemory() + mem.add_user_message("msg1") + mem.add_ai_message("msg2") + mem.clear() + assert mem.message_count == 0 + assert 
mem.get_messages() == [] + + def test_clear_on_empty(self): + mem = ShortTermMemory() + mem.clear() # should not raise + assert mem.message_count == 0 + + def test_can_add_after_clear(self): + mem = ShortTermMemory() + mem.add_user_message("before") + mem.clear() + mem.add_user_message("after") + assert mem.message_count == 1 + assert mem.messages[0].content == "after" + + +class TestSearch: + """Tests for search.""" + + def test_search_finds_match(self): + mem = ShortTermMemory() + mem.add_user_message("Python is great") + mem.add_ai_message("I agree, Python is versatile") + mem.add_user_message("What about Java?") + results = mem.search("python") + assert len(results) == 2 + + def test_search_case_insensitive_by_default(self): + mem = ShortTermMemory() + mem.add_user_message("HELLO world") + results = mem.search("hello") + assert len(results) == 1 + + def test_search_case_sensitive(self): + mem = ShortTermMemory() + mem.add_user_message("HELLO world") + results = mem.search("hello", case_sensitive=True) + assert len(results) == 0 + + def test_search_case_sensitive_match(self): + mem = ShortTermMemory() + mem.add_user_message("HELLO world") + results = mem.search("HELLO", case_sensitive=True) + assert len(results) == 1 + + def test_search_no_match(self): + mem = ShortTermMemory() + mem.add_user_message("Python is great") + results = mem.search("javascript") + assert len(results) == 0 + + def test_search_empty_messages(self): + mem = ShortTermMemory() + results = mem.search("anything") + assert results == [] + + def test_search_partial_match(self): + mem = ShortTermMemory() + mem.add_user_message("The program is running") + results = mem.search("progr") + assert len(results) == 1 + + def test_search_returns_message_instances(self): + mem = ShortTermMemory() + msg = mem.add_user_message("searchable content") + results = mem.search("searchable") + assert results[0] is msg + + +class TestCompress: + """Tests for _compress behavior.""" + + def 
test_compress_preserves_system_messages(self): + mem = ShortTermMemory(max_tokens=1) + mem.add_system_message("You are a helpful assistant") + # Add many messages to exceed token limit + for i in range(30): + mem.add_user_message(f"Message number {i} with some content to fill tokens") + system_msgs = [m for m in mem.messages if isinstance(m, SystemMessage)] + assert len(system_msgs) >= 1 + + def test_compress_preserves_recent_messages(self): + mem = ShortTermMemory(max_tokens=1) + for i in range(30): + mem.add_user_message(f"Message {i} " * 10) + # The last message should still be present + last_msg = mem.messages[-1] + assert "Message 29" in last_msg.content + + def test_compress_creates_summary_message(self): + mem = ShortTermMemory(max_tokens=1) + for i in range(30): + mem.add_user_message(f"Message {i} " * 10) + # Should have a compression summary + summary_msgs = [m for m in mem.messages if "已压缩" in str(m.content)] + assert len(summary_msgs) == 1 + + def test_compress_no_op_when_under_limit(self): + mem = ShortTermMemory(max_tokens=100000) + for i in range(5): + mem.add_user_message(f"msg-{i}") + before = mem.message_count + mem._compress() + assert mem.message_count == before + + def test_compress_no_op_when_at_keep_recent(self): + mem = ShortTermMemory(max_tokens=1) + for i in range(20): + mem.add_user_message(f"msg-{i} " * 5) + # Exactly keep_recent messages, compress should be a no-op on count <= keep_recent + # The compress checks len(messages) <= keep_recent + count_before = len(mem.messages) + mem._compress(keep_recent=count_before) + assert len(mem.messages) == count_before + + +class TestToMemoryItems: + """Tests for to_memory_items.""" + + def test_empty_messages(self): + mem = ShortTermMemory() + items = mem.to_memory_items() + assert items == [] + + def test_user_message_role(self): + mem = ShortTermMemory() + mem.add_user_message("hello") + items = mem.to_memory_items() + assert len(items) == 1 + assert items[0].metadata["role"] == "user" + assert 
items[0].memory_type == MemoryType.SHORT_TERM + + def test_ai_message_role(self): + mem = ShortTermMemory() + mem.add_ai_message("reply") + items = mem.to_memory_items() + assert items[0].metadata["role"] == "ai" + + def test_system_message_role(self): + mem = ShortTermMemory() + mem.add_system_message("prompt") + items = mem.to_memory_items() + assert items[0].metadata["role"] == "system" + + def test_session_id_propagated(self): + mem = ShortTermMemory(session_id="test-sid") + mem.add_user_message("msg") + items = mem.to_memory_items() + assert items[0].session_id == "test-sid" + + def test_content_preserved(self): + mem = ShortTermMemory() + mem.add_user_message("exact content here") + items = mem.to_memory_items() + assert items[0].content == "exact content here" + + def test_unique_ids(self): + mem = ShortTermMemory() + mem.add_user_message("msg1") + mem.add_user_message("msg2") + items = mem.to_memory_items() + assert items[0].id != items[1].id + + +class TestGetContextCompat: + """Tests for backward-compatible get_context and get_last_n_messages.""" + + def test_get_context_returns_all_messages(self): + mem = ShortTermMemory() + mem.add_user_message("m1") + mem.add_ai_message("m2") + ctx = mem.get_context() + assert len(ctx) == 2 + + def test_get_last_n_messages_alias(self): + mem = ShortTermMemory() + for i in range(5): + mem.add_user_message(f"msg-{i}") + result = mem.get_last_n_messages(3) + assert len(result) == 3 + assert result[0].content == "msg-2" + + +class TestCreateSessionMemory: + """Tests for the factory function.""" + + def test_factory_creates_instance(self): + mem = create_session_memory() + assert isinstance(mem, ShortTermMemory) + + def test_factory_with_session_id(self): + mem = create_session_memory(session_id="factory-sid") + assert mem.session_id == "factory-sid" + + def test_factory_with_max_tokens(self): + mem = create_session_memory(max_tokens=500) + assert mem.max_tokens == 500 + + +class TestMessageCountProperty: + """Tests for 
message_count property.""" + + def test_zero_on_init(self): + mem = ShortTermMemory() + assert mem.message_count == 0 + + def test_increments_on_add(self): + mem = ShortTermMemory() + mem.add_user_message("a") + assert mem.message_count == 1 + mem.add_ai_message("b") + assert mem.message_count == 2 + + def test_resets_on_clear(self): + mem = ShortTermMemory() + mem.add_user_message("a") + mem.clear() + assert mem.message_count == 0 diff --git a/tests/test_memory/test_types.py b/tests/test_memory/test_types.py new file mode 100644 index 0000000..cc0d8d0 --- /dev/null +++ b/tests/test_memory/test_types.py @@ -0,0 +1,229 @@ +"""Tests for memory type definitions: MemoryType, MemoryScope, MemoryItem, SearchResult.""" + +from datetime import datetime + +import pytest + +from jojo_code.memory.types import MemoryItem, MemoryScope, MemoryType, SearchResult + + +class TestMemoryType: + """Tests for MemoryType enum.""" + + def test_short_term_value(self): + assert MemoryType.SHORT_TERM.value == "short_term" + + def test_long_term_value(self): + assert MemoryType.LONG_TERM.value == "long_term" + + def test_from_value(self): + assert MemoryType("short_term") == MemoryType.SHORT_TERM + assert MemoryType("long_term") == MemoryType.LONG_TERM + + def test_invalid_value_raises(self): + with pytest.raises(ValueError): + MemoryType("nonexistent") + + +class TestMemoryScope: + """Tests for MemoryScope enum.""" + + def test_current_session_value(self): + assert MemoryScope.CURRENT_SESSION.value == "current_session" + + def test_all_sessions_value(self): + assert MemoryScope.ALL_SESSIONS.value == "all_sessions" + + def test_from_value(self): + assert MemoryScope("current_session") == MemoryScope.CURRENT_SESSION + assert MemoryScope("all_sessions") == MemoryScope.ALL_SESSIONS + + +class TestMemoryItem: + """Tests for MemoryItem dataclass.""" + + def test_create_with_required_fields(self): + item = MemoryItem( + id="item-1", + content="Hello world", + memory_type=MemoryType.SHORT_TERM, + 
session_id="session-1", + ) + assert item.id == "item-1" + assert item.content == "Hello world" + assert item.memory_type == MemoryType.SHORT_TERM + assert item.session_id == "session-1" + assert item.metadata == {} + assert item.tags == [] + assert isinstance(item.created_at, datetime) + + def test_create_with_all_fields(self): + now = datetime(2026, 5, 29, 12, 0, 0) + item = MemoryItem( + id="item-2", + content="Test content", + memory_type=MemoryType.LONG_TERM, + session_id="session-2", + created_at=now, + metadata={"role": "user", "extra": 42}, + tags=["important", "code"], + ) + assert item.created_at == now + assert item.metadata == {"role": "user", "extra": 42} + assert item.tags == ["important", "code"] + + def test_default_created_at_is_recent(self): + before = datetime.now() + item = MemoryItem( + id="item-3", + content="content", + memory_type=MemoryType.SHORT_TERM, + session_id="s1", + ) + after = datetime.now() + assert before <= item.created_at <= after + + def test_default_metadata_is_independent(self): + """Each instance should get its own mutable default dict.""" + item1 = MemoryItem(id="1", content="", memory_type=MemoryType.SHORT_TERM, session_id="s") + item2 = MemoryItem(id="2", content="", memory_type=MemoryType.SHORT_TERM, session_id="s") + item1.metadata["key"] = "value" + assert "key" not in item2.metadata + + def test_default_tags_is_independent(self): + """Each instance should get its own mutable default list.""" + item1 = MemoryItem(id="1", content="", memory_type=MemoryType.SHORT_TERM, session_id="s") + item2 = MemoryItem(id="2", content="", memory_type=MemoryType.SHORT_TERM, session_id="s") + item1.tags.append("tag") + assert "tag" not in item2.tags + + +class TestMemoryItemSerialization: + """Tests for MemoryItem to_dict / from_dict round-trip.""" + + def _make_item(self, **overrides) -> MemoryItem: + defaults = { + "id": "test-id", + "content": "test content", + "memory_type": MemoryType.SHORT_TERM, + "session_id": "session-abc", + 
"created_at": datetime(2026, 5, 29, 10, 30, 0), + "metadata": {"role": "user"}, + "tags": ["python"], + } + defaults.update(overrides) + return MemoryItem(**defaults) + + def test_to_dict_keys(self): + item = self._make_item() + d = item.to_dict() + assert set(d.keys()) == { + "id", + "content", + "memory_type", + "session_id", + "created_at", + "metadata", + "tags", + } + + def test_to_dict_values(self): + item = self._make_item() + d = item.to_dict() + assert d["id"] == "test-id" + assert d["content"] == "test content" + assert d["memory_type"] == "short_term" + assert d["session_id"] == "session-abc" + assert d["created_at"] == "2026-05-29T10:30:00" + assert d["metadata"] == {"role": "user"} + assert d["tags"] == ["python"] + + def test_round_trip(self): + item = self._make_item() + d = item.to_dict() + restored = MemoryItem.from_dict(d) + assert restored.id == item.id + assert restored.content == item.content + assert restored.memory_type == item.memory_type + assert restored.session_id == item.session_id + assert restored.created_at == item.created_at + assert restored.metadata == item.metadata + assert restored.tags == item.tags + + def test_round_trip_long_term(self): + item = self._make_item(memory_type=MemoryType.LONG_TERM) + d = item.to_dict() + restored = MemoryItem.from_dict(d) + assert restored.memory_type == MemoryType.LONG_TERM + + def test_round_trip_empty_metadata_and_tags(self): + item = self._make_item(metadata={}, tags=[]) + d = item.to_dict() + restored = MemoryItem.from_dict(d) + assert restored.metadata == {} + assert restored.tags == [] + + def test_round_trip_with_special_characters(self): + item = self._make_item(content="Line1\nLine2\tTabbed \"quoted\" 'single'") + d = item.to_dict() + restored = MemoryItem.from_dict(d) + assert restored.content == item.content + + def test_from_dict_missing_metadata_and_tags_uses_defaults(self): + """from_dict should handle missing optional fields gracefully.""" + data = { + "id": "x", + "content": "c", 
+ "memory_type": "short_term", + "session_id": "s", + "created_at": "2026-01-01T00:00:00", + } + item = MemoryItem.from_dict(data) + assert item.metadata == {} + assert item.tags == [] + + def test_round_trip_preserves_microseconds(self): + now = datetime(2026, 5, 29, 10, 30, 0, 123456) + item = self._make_item(created_at=now) + d = item.to_dict() + restored = MemoryItem.from_dict(d) + assert restored.created_at == now + + +class TestSearchResult: + """Tests for SearchResult dataclass.""" + + def test_create_search_result(self): + item = MemoryItem( + id="sr-1", + content="Python is great", + memory_type=MemoryType.LONG_TERM, + session_id="s1", + ) + result = SearchResult(item=item, score=0.85, matched_content="Python is great") + assert result.item is item + assert result.score == 0.85 + assert result.matched_content == "Python is great" + + def test_score_boundary_zero(self): + item = MemoryItem(id="x", content="", memory_type=MemoryType.SHORT_TERM, session_id="s") + result = SearchResult(item=item, score=0.0, matched_content="") + assert result.score == 0.0 + + def test_score_boundary_one(self): + item = MemoryItem(id="x", content="", memory_type=MemoryType.SHORT_TERM, session_id="s") + result = SearchResult(item=item, score=1.0, matched_content="") + assert result.score == 1.0 + + def test_search_result_stores_item_reference(self): + item = MemoryItem( + id="ref-test", + content="some content", + memory_type=MemoryType.LONG_TERM, + session_id="s1", + tags=["tag1"], + ) + result = SearchResult(item=item, score=0.5, matched_content="some") + # Modifying the original item should reflect in the result + item.tags.append("tag2") + assert "tag2" in result.item.tags diff --git a/tests/test_model_registry/__init__.py b/tests/test_model_registry/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_model_registry/test_factory.py b/tests/test_model_registry/test_factory.py new file mode 100644 index 0000000..8411089 --- /dev/null +++ 
b/tests/test_model_registry/test_factory.py @@ -0,0 +1,238 @@ +"""Model Factory tests. + +Tests for provider determination logic and model creation helpers. +""" + +from unittest.mock import MagicMock, patch + +from jojo_code.models.factory import _determine_provider, create_model +from jojo_code.models.types import ModelInfo, ModelProvider + + +class TestDetermineProvider: + """Tests for _determine_provider function.""" + + def _make_settings(self, **kwargs): + """Create a mock settings object.""" + settings = MagicMock() + settings.openai_base_url = kwargs.get("openai_base_url", None) + settings.openai_api_key = kwargs.get("openai_api_key", None) + settings.anthropic_api_key = kwargs.get("anthropic_api_key", None) + settings.model = kwargs.get("model", "gpt-4o-mini") + return settings + + def test_model_info_provider_used(self): + """When model_info exists, its provider is returned.""" + info = ModelInfo( + name="my-model", + provider=ModelProvider.ANTHROPIC, + display_name="My Model", + description="test", + ) + settings = self._make_settings() + result = _determine_provider("my-model", info, settings) + assert result == ModelProvider.ANTHROPIC + + def test_claude_name_returns_anthropic(self): + """Model names starting with 'claude' should use Anthropic.""" + settings = self._make_settings() + result = _determine_provider("claude-sonnet-4-20250514", None, settings) + assert result == ModelProvider.ANTHROPIC + + def test_gpt_name_returns_openai(self): + """Model names starting with 'gpt' should use OpenAI.""" + settings = self._make_settings() + result = _determine_provider("gpt-4o", None, settings) + assert result == ModelProvider.OPENAI + + def test_anthropic_env_var_fallback(self): + """Without model_info or name prefix, ANTHROPIC_API_KEY selects Anthropic.""" + settings = self._make_settings() + with patch("jojo_code.models.factory.os.getenv") as mock_getenv: + mock_getenv.side_effect = lambda k: { + "ANTHROPIC_API_KEY": "sk-ant-test", + }.get(k) + result = 
_determine_provider("some-model", None, settings) + assert result == ModelProvider.ANTHROPIC + + def test_custom_base_url_returns_custom(self): + """OPENAI_BASE_URL in settings selects Custom provider.""" + settings = self._make_settings(openai_base_url="https://custom.api/v1") + with patch("jojo_code.models.factory.os.getenv") as mock_getenv: + mock_getenv.return_value = None + result = _determine_provider("some-model", None, settings) + assert result == ModelProvider.CUSTOM + + def test_custom_base_url_env_var(self): + """OPENAI_BASE_URL env var selects Custom provider.""" + settings = self._make_settings() + with patch("jojo_code.models.factory.os.getenv") as mock_getenv: + mock_getenv.side_effect = lambda k: { + "OPENAI_BASE_URL": "https://custom.api/v1", + }.get(k) + result = _determine_provider("some-model", None, settings) + assert result == ModelProvider.CUSTOM + + def test_default_returns_openai(self): + """With no special conditions, defaults to OpenAI.""" + settings = self._make_settings() + with patch("jojo_code.models.factory.os.getenv") as mock_getenv: + mock_getenv.return_value = None + result = _determine_provider("unknown-model", None, settings) + assert result == ModelProvider.OPENAI + + +class TestCreateModel: + """Tests for create_model function (with mocked LLM clients).""" + + @patch("jojo_code.models.factory.ChatOpenAI") + @patch("jojo_code.models.factory.get_model_registry") + @patch("jojo_code.core.config.get_settings") + def test_create_model_default(self, mock_settings, mock_registry, mock_openai): + """Default model creation uses gpt-4o-mini.""" + mock_settings.return_value = MagicMock( + model=None, + openai_base_url=None, + openai_api_key=None, + anthropic_api_key=None, + ) + mock_registry.return_value = MagicMock(get=MagicMock(return_value=None)) + mock_openai.return_value = MagicMock() + + create_model() + mock_openai.assert_called_once() + + @patch("jojo_code.models.factory.ChatOpenAI") + 
@patch("jojo_code.models.factory.get_model_registry") + @patch("jojo_code.core.config.get_settings") + def test_create_model_explicit_name(self, mock_settings, mock_registry, mock_openai): + """Explicit model name is passed through.""" + mock_settings.return_value = MagicMock( + openai_base_url=None, + openai_api_key=None, + anthropic_api_key=None, + ) + mock_registry.return_value = MagicMock(get=MagicMock(return_value=None)) + mock_openai.return_value = MagicMock() + + create_model(model_name="gpt-4o", temperature=0.5) + call_kwargs = mock_openai.call_args + assert call_kwargs[1]["model"] == "gpt-4o" + assert call_kwargs[1]["temperature"] == 0.5 + + @patch("jojo_code.models.factory.ChatAnthropic") + @patch("jojo_code.models.factory.get_model_registry") + @patch("jojo_code.core.config.get_settings") + def test_create_anthropic_model(self, mock_settings, mock_registry, mock_anthropic): + """Claude model name creates ChatAnthropic.""" + mock_settings.return_value = MagicMock( + openai_base_url=None, + openai_api_key=None, + anthropic_api_key="sk-ant-test", + ) + mock_registry.return_value = MagicMock(get=MagicMock(return_value=None)) + mock_anthropic.return_value = MagicMock() + + create_model(model_name="claude-sonnet-4-20250514") + mock_anthropic.assert_called_once() + call_kwargs = mock_anthropic.call_args + assert call_kwargs[1]["model"] == "claude-sonnet-4-20250514" + + +class TestModelCreationHelpers: + """Tests for create_fast_model, create_smart_model, create_cheap_model.""" + + @patch("jojo_code.models.factory.create_model") + @patch("jojo_code.models.factory.get_model_registry") + def test_create_fast_model_with_fast_models(self, mock_registry, mock_create): + """Fast model selects cheapest fast model.""" + fast_model = ModelInfo( + name="gpt-4o-mini", + provider=ModelProvider.OPENAI, + display_name="GPT-4o Mini", + description="fast", + cost_per_1k_input=0.00015, + tags=["fast"], + ) + mock_registry.return_value = 
MagicMock(list_fast=MagicMock(return_value=[fast_model])) + mock_create.return_value = MagicMock() + + from jojo_code.models.factory import create_fast_model + + create_fast_model(temperature=0.3) + mock_create.assert_called_once_with("gpt-4o-mini", 0.3) + + @patch("jojo_code.models.factory.create_model") + @patch("jojo_code.models.factory.get_model_registry") + def test_create_fast_model_defaults_to_mini(self, mock_registry, mock_create): + """Fast model defaults to gpt-4o-mini when no fast models exist.""" + mock_registry.return_value = MagicMock(list_fast=MagicMock(return_value=[])) + mock_create.return_value = MagicMock() + + from jojo_code.models.factory import create_fast_model + + create_fast_model() + mock_create.assert_called_once_with("gpt-4o-mini", 0.3) + + @patch("jojo_code.models.factory.create_model") + @patch("jojo_code.models.factory.get_model_registry") + def test_create_smart_model_with_smart_models(self, mock_registry, mock_create): + """Smart model selects first smart model.""" + smart_model = ModelInfo( + name="claude-opus-4-20250514", + provider=ModelProvider.ANTHROPIC, + display_name="Claude 4 Opus", + description="smart", + tags=["smart"], + ) + mock_registry.return_value = MagicMock(list_smart=MagicMock(return_value=[smart_model])) + mock_create.return_value = MagicMock() + + from jojo_code.models.factory import create_smart_model + + create_smart_model(temperature=0.7) + mock_create.assert_called_once_with("claude-opus-4-20250514", 0.7) + + @patch("jojo_code.models.factory.create_model") + @patch("jojo_code.models.factory.get_model_registry") + def test_create_smart_model_defaults_to_gpt4o(self, mock_registry, mock_create): + """Smart model defaults to gpt-4o when no smart models exist.""" + mock_registry.return_value = MagicMock(list_smart=MagicMock(return_value=[])) + mock_create.return_value = MagicMock() + + from jojo_code.models.factory import create_smart_model + + create_smart_model() + mock_create.assert_called_once_with("gpt-4o", 
0.7) + + @patch("jojo_code.models.factory.create_model") + @patch("jojo_code.models.factory.get_model_registry") + def test_create_cheap_model_with_cheap_models(self, mock_registry, mock_create): + """Cheap model selects cheapest model.""" + cheap_model = ModelInfo( + name="gpt-4o-mini", + provider=ModelProvider.OPENAI, + display_name="GPT-4o Mini", + description="cheap", + cost_per_1k_input=0.00015, + tags=["cheap"], + ) + mock_registry.return_value = MagicMock(list_cheap=MagicMock(return_value=[cheap_model])) + mock_create.return_value = MagicMock() + + from jojo_code.models.factory import create_cheap_model + + create_cheap_model(temperature=0.5) + mock_create.assert_called_once_with("gpt-4o-mini", 0.5) + + @patch("jojo_code.models.factory.create_model") + @patch("jojo_code.models.factory.get_model_registry") + def test_create_cheap_model_defaults_to_mini(self, mock_registry, mock_create): + """Cheap model defaults to gpt-4o-mini when no cheap models exist.""" + mock_registry.return_value = MagicMock(list_cheap=MagicMock(return_value=[])) + mock_create.return_value = MagicMock() + + from jojo_code.models.factory import create_cheap_model + + create_cheap_model() + mock_create.assert_called_once_with("gpt-4o-mini", 0.5) diff --git a/tests/test_model_registry/test_registry.py b/tests/test_model_registry/test_registry.py new file mode 100644 index 0000000..14bcd6e --- /dev/null +++ b/tests/test_model_registry/test_registry.py @@ -0,0 +1,86 @@ +"""ModelRegistry 测试 + +测试模型注册、注销、过滤、预置模型。 +""" + +import pytest + +from jojo_code.models.registry import ModelRegistry +from jojo_code.models.types import ( + PRESET_MODELS, + ModelCapability, + ModelInfo, + ModelProvider, +) + + +@pytest.fixture +def registry(): + return ModelRegistry() + + +class TestModelRegistry: + def test_has_preset_models(self, registry): + assert len(registry._models) >= len(PRESET_MODELS) + + def test_get_existing_model(self, registry): + info = registry.get("gpt-4o") + assert info is not None + 
assert info.name == "gpt-4o" + assert info.provider == ModelProvider.OPENAI + + def test_get_nonexistent_returns_none(self, registry): + assert registry.get("nonexistent") is None + + def test_register_custom_model(self, registry): + custom = ModelInfo( + name="my-model", + provider=ModelProvider.CUSTOM, + display_name="My Model", + description="Custom model", + ) + registry.register(custom) + assert registry.get("my-model") is custom + + def test_unregister_custom_model(self, registry): + custom = ModelInfo( + name="temp-model", + provider=ModelProvider.CUSTOM, + display_name="Temp", + description="Temporary", + ) + registry.register(custom) + assert registry.unregister("temp-model") is True + assert registry.get("temp-model") is None + + def test_cannot_unregister_preset_model(self, registry): + assert registry.unregister("gpt-4o") is False + assert registry.get("gpt-4o") is not None + + def test_unregister_nonexistent_returns_false(self, registry): + assert registry.unregister("nope") is False + + def test_list_by_provider(self, registry): + openai_models = registry.list_by_provider(ModelProvider.OPENAI) + assert len(openai_models) >= 1 + assert all(m.provider == ModelProvider.OPENAI for m in openai_models) + + def test_list_by_capability(self, registry): + vision_models = registry.list_models(capability=ModelCapability.VISION) + assert len(vision_models) >= 1 + assert all(ModelCapability.VISION in m.capabilities for m in vision_models) + + def test_list_fast_models(self, registry): + fast = registry.list_fast() + assert len(fast) >= 1 + assert any("fast" in m.tags for m in fast) + + def test_list_smart_models(self, registry): + smart = registry.list_smart() + assert len(smart) >= 1 + + def test_get_stats(self, registry): + stats = registry.get_stats() + assert "total" in stats + assert "by_provider" in stats + assert stats["total"] >= len(PRESET_MODELS) diff --git a/tests/test_models/test_registry.py b/tests/test_models/test_registry.py index 14bcd6e..7454b1d 
100644 --- a/tests/test_models/test_registry.py +++ b/tests/test_models/test_registry.py @@ -1,11 +1,16 @@ -"""ModelRegistry 测试 +"""Model registry tests. -测试模型注册、注销、过滤、预置模型。 +Tests for ModelRegistry: registration, unregistration, filtering, stats, +custom providers, and global registry management. """ import pytest -from jojo_code.models.registry import ModelRegistry +from jojo_code.models.registry import ( + ModelRegistry, + get_model_registry, + set_model_registry, +) from jojo_code.models.types import ( PRESET_MODELS, ModelCapability, @@ -19,49 +24,101 @@ def registry(): return ModelRegistry() -class TestModelRegistry: +@pytest.fixture +def custom_model(): + return ModelInfo( + name="my-custom-model", + provider=ModelProvider.CUSTOM, + display_name="My Custom Model", + description="A custom test model", + context_length=64000, + capabilities=[ModelCapability.CHAT, ModelCapability.STREAMING], + cost_per_1k_input=0.001, + cost_per_1k_output=0.002, + tags=["test"], + ) + + +class TestModelRegistryInit: def test_has_preset_models(self, registry): assert len(registry._models) >= len(PRESET_MODELS) - def test_get_existing_model(self, registry): + def test_preset_models_copied(self, registry): + # Modifying registry should not affect PRESET_MODELS + custom = ModelInfo( + name="temp", + provider=ModelProvider.CUSTOM, + display_name="Temp", + description="Temp", + ) + registry.register(custom) + assert "temp" not in PRESET_MODELS + + def test_custom_providers_empty(self, registry): + assert registry._custom_providers == {} + + +class TestModelRegistryGet: + def test_get_existing_preset(self, registry): info = registry.get("gpt-4o") assert info is not None assert info.name == "gpt-4o" assert info.provider == ModelProvider.OPENAI def test_get_nonexistent_returns_none(self, registry): - assert registry.get("nonexistent") is None + assert registry.get("nonexistent-model") is None - def test_register_custom_model(self, registry): - custom = ModelInfo( - name="my-model", + def 
test_get_registered_custom(self, registry, custom_model): + registry.register(custom_model) + result = registry.get("my-custom-model") + assert result is custom_model + + +class TestModelRegistryRegister: + def test_register_custom_model(self, registry, custom_model): + registry.register(custom_model) + assert registry.get("my-custom-model") is custom_model + + def test_register_overwrites_existing(self, registry): + m1 = ModelInfo( + name="overwrite-test", provider=ModelProvider.CUSTOM, - display_name="My Model", - description="Custom model", + display_name="V1", + description="Version 1", ) - registry.register(custom) - assert registry.get("my-model") is custom - - def test_unregister_custom_model(self, registry): - custom = ModelInfo( - name="temp-model", + m2 = ModelInfo( + name="overwrite-test", provider=ModelProvider.CUSTOM, - display_name="Temp", - description="Temporary", + display_name="V2", + description="Version 2", ) - registry.register(custom) - assert registry.unregister("temp-model") is True - assert registry.get("temp-model") is None + registry.register(m1) + registry.register(m2) + result = registry.get("overwrite-test") + assert result.display_name == "V2" - def test_cannot_unregister_preset_model(self, registry): + +class TestModelRegistryUnregister: + def test_unregister_custom_model(self, registry, custom_model): + registry.register(custom_model) + assert registry.unregister("my-custom-model") is True + assert registry.get("my-custom-model") is None + + def test_cannot_unregister_preset(self, registry): assert registry.unregister("gpt-4o") is False assert registry.get("gpt-4o") is not None def test_unregister_nonexistent_returns_false(self, registry): assert registry.unregister("nope") is False + +class TestModelRegistryListModels: + def test_list_all(self, registry): + all_models = registry.list_models() + assert len(all_models) >= len(PRESET_MODELS) + def test_list_by_provider(self, registry): - openai_models = 
registry.list_by_provider(ModelProvider.OPENAI) + openai_models = registry.list_models(provider=ModelProvider.OPENAI) assert len(openai_models) >= 1 assert all(m.provider == ModelProvider.OPENAI for m in openai_models) @@ -70,17 +127,160 @@ def test_list_by_capability(self, registry): assert len(vision_models) >= 1 assert all(ModelCapability.VISION in m.capabilities for m in vision_models) - def test_list_fast_models(self, registry): + def test_list_by_tags(self, registry): + fast_models = registry.list_models(tags=["fast"]) + assert len(fast_models) >= 1 + assert all("fast" in m.tags for m in fast_models) + + def test_list_by_multiple_tags_any_match(self, registry): + models = registry.list_models(tags=["smart", "cheap"]) + assert len(models) >= 1 + for m in models: + assert any(tag in m.tags for tag in ["smart", "cheap"]) + + def test_list_combined_filters(self, registry): + anthropic_vision = registry.list_models( + provider=ModelProvider.ANTHROPIC, + capability=ModelCapability.VISION, + ) + assert len(anthropic_vision) >= 1 + for m in anthropic_vision: + assert m.provider == ModelProvider.ANTHROPIC + assert ModelCapability.VISION in m.capabilities + + def test_list_no_match_returns_empty(self, registry): + models = registry.list_models(provider=ModelProvider.LOCAL) + assert models == [] + + def test_list_by_provider_method(self, registry): + anthropic = registry.list_by_provider(ModelProvider.ANTHROPIC) + assert len(anthropic) >= 1 + assert all(m.provider == ModelProvider.ANTHROPIC for m in anthropic) + + def test_list_fast(self, registry): fast = registry.list_fast() assert len(fast) >= 1 - assert any("fast" in m.tags for m in fast) + assert all("fast" in m.tags for m in fast) - def test_list_smart_models(self, registry): + def test_list_cheap(self, registry): + cheap = registry.list_cheap() + assert len(cheap) >= 1 + assert all("cheap" in m.tags for m in cheap) + + def test_list_smart(self, registry): smart = registry.list_smart() assert len(smart) >= 1 + for 
m in smart: + assert any(tag in m.tags for tag in ["smart", "reasoning"]) + + +class TestModelRegistryCustomProvider: + def test_register_custom_provider(self, registry): + factory = lambda: "fake_model" # noqa: E731 + registry.register_custom_provider(ModelProvider.LOCAL, factory) + assert ModelProvider.LOCAL in registry._custom_providers + assert registry._custom_providers[ModelProvider.LOCAL] is factory + def test_register_multiple_providers(self, registry): + f1 = lambda: "local" # noqa: E731 + f2 = lambda: "custom" # noqa: E731 + registry.register_custom_provider(ModelProvider.LOCAL, f1) + registry.register_custom_provider(ModelProvider.CUSTOM, f2) + assert len(registry._custom_providers) == 2 + + +class TestModelRegistryStats: def test_get_stats(self, registry): stats = registry.get_stats() assert "total" in stats assert "by_provider" in stats assert stats["total"] >= len(PRESET_MODELS) + + def test_stats_by_provider_counts(self, registry): + stats = registry.get_stats() + by_provider = stats["by_provider"] + assert "openai" in by_provider + assert "anthropic" in by_provider + assert by_provider["openai"] >= 1 + assert by_provider["anthropic"] >= 1 + + def test_stats_reflect_registration(self, registry, custom_model): + before = registry.get_stats()["total"] + registry.register(custom_model) + after = registry.get_stats()["total"] + assert after == before + 1 + + def test_stats_reflect_unregistration(self, registry, custom_model): + registry.register(custom_model) + before = registry.get_stats()["total"] + registry.unregister("my-custom-model") + after = registry.get_stats()["total"] + assert after == before - 1 + + +class TestGlobalRegistry: + def test_get_model_registry_returns_singleton(self): + r1 = get_model_registry() + r2 = get_model_registry() + assert r1 is r2 + + def test_set_model_registry(self): + original = get_model_registry() + custom = ModelRegistry() + set_model_registry(custom) + assert get_model_registry() is custom + # Restore + 
set_model_registry(original) + + def test_set_model_registry_none(self): + original = get_model_registry() + set_model_registry(None) + # get_model_registry should create a new one + new_registry = get_model_registry() + assert new_registry is not original + # Restore + set_model_registry(original) + + +class TestModelRegistryEdgeCases: + def test_list_models_empty_tags_returns_all(self, registry): + """Passing tags=[] should return all models (no tag filter).""" + all_models = registry.list_models(tags=[]) + assert len(all_models) == len(registry._models) + + def test_list_models_no_matching_tags(self, registry): + models = registry.list_models(tags=["nonexistent_tag_xyz"]) + assert models == [] + + def test_register_overwrite_preset_model(self, registry): + """Registering with a preset model name overwrites in the registry + but does not affect the original PRESET_MODELS dict.""" + original = registry.get("gpt-4o") + assert original is not None + + replacement = ModelInfo( + name="gpt-4o", + provider=ModelProvider.CUSTOM, + display_name="Overwritten GPT-4o", + description="Replaced", + ) + registry.register(replacement) + assert registry.get("gpt-4o").display_name == "Overwritten GPT-4o" + # PRESET_MODELS should be unaffected (already imported at module level) + assert PRESET_MODELS["gpt-4o"].display_name == "GPT-4o" + + def test_unregister_preset_does_not_change_stats(self, registry, custom_model): + """Unregistering a preset should fail and stats should remain unchanged.""" + before = registry.get_stats()["total"] + result = registry.unregister("gpt-4o") + assert result is False + after = registry.get_stats()["total"] + assert before == after + + def test_list_models_combined_provider_and_tags(self, registry): + models = registry.list_models( + provider=ModelProvider.ANTHROPIC, tags=["fast", "cheap"] + ) + for m in models: + assert m.provider == ModelProvider.ANTHROPIC + assert any(tag in m.tags for tag in ["fast", "cheap"]) diff --git 
a/tests/test_models/test_types.py b/tests/test_models/test_types.py new file mode 100644 index 0000000..40679a6 --- /dev/null +++ b/tests/test_models/test_types.py @@ -0,0 +1,215 @@ +"""Model types tests. + +Tests for ModelProvider, ModelCapability, ModelInfo, and PRESET_MODELS. +""" + +from jojo_code.models.types import ( + PRESET_MODELS, + ModelCapability, + ModelInfo, + ModelProvider, +) + + +class TestModelProvider: + def test_enum_values(self): + assert ModelProvider.OPENAI.value == "openai" + assert ModelProvider.ANTHROPIC.value == "anthropic" + assert ModelProvider.CUSTOM.value == "custom" + assert ModelProvider.LOCAL.value == "local" + + def test_enum_members(self): + members = list(ModelProvider) + assert len(members) == 4 + + +class TestModelCapability: + def test_enum_values(self): + assert ModelCapability.CHAT.value == "chat" + assert ModelCapability.FUNCTION_CALLING.value == "function_calling" + assert ModelCapability.VISION.value == "vision" + assert ModelCapability.STREAMING.value == "streaming" + assert ModelCapability.JSON_MODE.value == "json_mode" + assert ModelCapability.REASONING.value == "reasoning" + + def test_enum_members(self): + members = list(ModelCapability) + assert len(members) == 6 + + def test_capability_shortcut_alias(self): + from jojo_code.models.types import C + + assert C is ModelCapability + + +class TestModelInfo: + def test_defaults(self): + info = ModelInfo( + name="test", + provider=ModelProvider.OPENAI, + display_name="Test", + description="A test model", + ) + assert info.context_length == 128000 + assert info.capabilities == [] + assert info.cost_per_1k_input == 0.0 + assert info.cost_per_1k_output == 0.0 + assert info.default_temperature == 0.7 + assert info.max_output_tokens == 16384 + assert info.tags == [] + + def test_custom_values(self): + info = ModelInfo( + name="custom-model", + provider=ModelProvider.CUSTOM, + display_name="Custom", + description="Custom model", + context_length=64000, + 
capabilities=[ModelCapability.CHAT, ModelCapability.STREAMING], + cost_per_1k_input=0.001, + cost_per_1k_output=0.002, + default_temperature=0.5, + max_output_tokens=8192, + tags=["fast"], + ) + assert info.name == "custom-model" + assert info.provider == ModelProvider.CUSTOM + assert info.context_length == 64000 + assert ModelCapability.CHAT in info.capabilities + assert ModelCapability.STREAMING in info.capabilities + assert info.cost_per_1k_input == 0.001 + assert info.cost_per_1k_output == 0.002 + assert info.default_temperature == 0.5 + assert info.max_output_tokens == 8192 + assert info.tags == ["fast"] + + +class TestPresetModels: + def test_preset_models_not_empty(self): + assert len(PRESET_MODELS) > 0 + + def test_openai_models_present(self): + assert "gpt-4o" in PRESET_MODELS + assert "gpt-4o-mini" in PRESET_MODELS + assert "gpt-4-turbo" in PRESET_MODELS + + def test_anthropic_models_present(self): + assert "claude-sonnet-4-20250514" in PRESET_MODELS + assert "claude-opus-4-20250514" in PRESET_MODELS + assert "claude-3-5-sonnet-20240620" in PRESET_MODELS + assert "claude-3-haiku-20240307" in PRESET_MODELS + + def test_longcat_models_present(self): + assert "LongCat-Flash-Chat" in PRESET_MODELS + assert "LongCat-Flash-Thinking-2601" in PRESET_MODELS + + def test_preset_model_types(self): + for name, info in PRESET_MODELS.items(): + assert isinstance(info, ModelInfo), f"{name} is not ModelInfo" + assert info.name == name, f"{name} mismatch" + + def test_preset_model_providers(self): + for _name, info in PRESET_MODELS.items(): + assert isinstance(info.provider, ModelProvider) + + def test_preset_model_capabilities(self): + for _name, info in PRESET_MODELS.items(): + for cap in info.capabilities: + assert isinstance(cap, ModelCapability) + + def test_gpt4o_capabilities(self): + info = PRESET_MODELS["gpt-4o"] + assert ModelCapability.CHAT in info.capabilities + assert ModelCapability.FUNCTION_CALLING in info.capabilities + assert ModelCapability.VISION in 
info.capabilities + assert ModelCapability.STREAMING in info.capabilities + assert ModelCapability.JSON_MODE in info.capabilities + + def test_claude_opus_reasoning(self): + info = PRESET_MODELS["claude-opus-4-20250514"] + assert ModelCapability.REASONING in info.capabilities + assert "smart" in info.tags + assert "reasoning" in info.tags + + def test_fast_models_tagged(self): + gpt4o_mini = PRESET_MODELS["gpt-4o-mini"] + assert "fast" in gpt4o_mini.tags + assert "cheap" in gpt4o_mini.tags + + haiku = PRESET_MODELS["claude-3-haiku-20240307"] + assert "fast" in haiku.tags + assert "cheap" in haiku.tags + + def test_cost_fields_are_floats(self): + for _name, info in PRESET_MODELS.items(): + assert isinstance(info.cost_per_1k_input, float) + assert isinstance(info.cost_per_1k_output, float) + + def test_costs_are_non_negative(self): + for name, info in PRESET_MODELS.items(): + assert info.cost_per_1k_input >= 0, f"{name} has negative input cost" + assert info.cost_per_1k_output >= 0, f"{name} has negative output cost" + + def test_context_length_positive(self): + for name, info in PRESET_MODELS.items(): + assert info.context_length > 0, f"{name} has non-positive context_length" + + def test_preset_display_names_not_empty(self): + for name, info in PRESET_MODELS.items(): + assert info.display_name, f"{name} has empty display_name" + + def test_preset_descriptions_not_empty(self): + for name, info in PRESET_MODELS.items(): + assert info.description, f"{name} has empty description" + + def test_max_output_tokens_positive(self): + for name, info in PRESET_MODELS.items(): + assert info.max_output_tokens > 0, f"{name} has non-positive max_output_tokens" + + +class TestModelInfoEdgeCases: + def test_empty_capabilities_list(self): + info = ModelInfo( + name="bare", + provider=ModelProvider.LOCAL, + display_name="Bare", + description="No capabilities", + capabilities=[], + ) + assert info.capabilities == [] + + def test_multiple_capabilities(self): + caps = [ + 
ModelCapability.CHAT, + ModelCapability.FUNCTION_CALLING, + ModelCapability.VISION, + ModelCapability.STREAMING, + ModelCapability.JSON_MODE, + ModelCapability.REASONING, + ] + info = ModelInfo( + name="full", + provider=ModelProvider.OPENAI, + display_name="Full", + description="All capabilities", + capabilities=caps, + ) + assert len(info.capabilities) == 6 + + def test_default_tags_is_empty_list(self): + info = ModelInfo( + name="no-tags", + provider=ModelProvider.CUSTOM, + display_name="No Tags", + description="Model without tags", + ) + assert info.tags == [] + # Ensure it's a new list, not a shared reference + info.tags.append("test") + info2 = ModelInfo( + name="no-tags-2", + provider=ModelProvider.CUSTOM, + display_name="No Tags 2", + description="Another model", + ) + assert info2.tags == [] diff --git a/tests/test_plugin/test_builtin_plugins.py b/tests/test_plugin/test_builtin_plugins.py new file mode 100644 index 0000000..76da2d8 --- /dev/null +++ b/tests/test_plugin/test_builtin_plugins.py @@ -0,0 +1,228 @@ +"""内置插件测试 - code_review, test_generator, git_plugin""" + +import pytest + +from jojo_code.plugins.code_review import CodeReviewPlugin +from jojo_code.plugins.git_plugin import GitPlugin +from jojo_code.plugins.test_generator import TestGeneratorPlugin + + +@pytest.fixture +def sample_python_file(tmp_path): + """创建示例 Python 文件""" + code = '''"""示例模块""" + +import os + + +def hello(name: str) -> str: + """打招呼""" + return f"Hello, {name}!" 
+ + +def add(a: int, b: int) -> int: + """加法""" + return a + b + + +class Calculator: + """计算器""" + + def multiply(self, a: int, b: int) -> int: + return a * b + + def divide(self, a: float, b: float) -> float: + if b == 0: + raise ValueError("Division by zero") + return a / b +''' + file_path = tmp_path / "sample.py" + file_path.write_text(code, encoding="utf-8") + return file_path + + +@pytest.fixture +def unsafe_python_file(tmp_path): + """创建包含安全问题的 Python 文件""" + code = '''"""不安全的模块""" + +import pickle +import subprocess + +password = "hardcoded_secret_123" + +def unsafe_eval(user_input: str): + return eval(user_input) + +def unsafe_deserialize(data: bytes): + return pickle.loads(data) + +def unsafe_command(cmd: str): + subprocess.call(cmd, shell=True) +''' + file_path = tmp_path / "unsafe.py" + file_path.write_text(code, encoding="utf-8") + return file_path + + +class TestCodeReviewPlugin: + """测试代码审查插件""" + + def test_metadata(self): + """测试插件元数据""" + plugin = CodeReviewPlugin() + assert plugin.metadata.name == "code-review" + assert plugin.metadata.version == "0.1.0" + + def test_get_tools(self): + """测试获取工具列表""" + plugin = CodeReviewPlugin() + tools = plugin.get_tools() + tool_names = [t.name for t in tools] + assert "review_python_security" in tool_names + assert "review_code_quality" in tool_names + assert "review_code_style" in tool_names + + def test_review_security_unsafe_file(self, unsafe_python_file): + """测试安全审查 - 不安全文件""" + plugin = CodeReviewPlugin() + result = plugin._review_python_security(str(unsafe_python_file)) + # 应该检测到安全问题 + assert ( + "hardcoded" in result.lower() + or "secret" in result.lower() + or "unsafe" in result.lower() + or "eval" in result.lower() + or "问题" in result + or "风险" in result + ) + + def test_review_security_safe_file(self, sample_python_file): + """测试安全审查 - 安全文件""" + plugin = CodeReviewPlugin() + result = plugin._review_python_security(str(sample_python_file)) + # 安全文件应该没有严重问题 + assert isinstance(result, str) + + def 
test_review_quality(self, sample_python_file): + """测试代码质量审查""" + plugin = CodeReviewPlugin() + result = plugin._review_code_quality(str(sample_python_file)) + assert isinstance(result, str) + assert len(result) > 0 + + def test_review_style(self, sample_python_file): + """测试代码风格审查""" + plugin = CodeReviewPlugin() + result = plugin._review_code_style(str(sample_python_file)) + assert isinstance(result, str) + + def test_review_nonexistent_file(self): + """测试审查不存在的文件""" + plugin = CodeReviewPlugin() + result = plugin._review_python_security("/nonexistent/file.py") + assert ( + "错误" in result + or "error" in result.lower() + or "不存在" in result + or "未找到" in result + ) + + +class TestTestGeneratorPlugin: + """测试测试生成插件""" + + def test_metadata(self): + """测试插件元数据""" + plugin = TestGeneratorPlugin() + assert plugin.metadata.name == "test-generator" + + def test_get_tools(self): + """测试获取工具列表""" + plugin = TestGeneratorPlugin() + tools = plugin.get_tools() + tool_names = [t.name for t in tools] + assert "generate_unit_tests" in tool_names + assert "generate_test_fixtures" in tool_names + assert "generate_test_mocks" in tool_names + + def test_generate_unit_tests(self, sample_python_file): + """测试生成单元测试""" + plugin = TestGeneratorPlugin() + result = plugin._generate_unit_tests(str(sample_python_file)) + assert isinstance(result, str) + assert len(result) > 0 + # 应该包含 pytest 相关内容 + assert ( + "def test" in result.lower() or "assert" in result.lower() or "test" in result.lower() + ) + + def test_generate_fixtures(self, sample_python_file): + """测试生成 fixtures""" + plugin = TestGeneratorPlugin() + result = plugin._generate_test_fixtures(str(sample_python_file)) + assert isinstance(result, str) + + def test_generate_mocks(self, sample_python_file): + """测试生成 mocks""" + plugin = TestGeneratorPlugin() + result = plugin._generate_test_mocks(str(sample_python_file)) + assert isinstance(result, str) + + +class TestGitPlugin: + """测试 Git 插件""" + + def test_metadata(self): + 
"""测试插件元数据""" + plugin = GitPlugin() + assert plugin.metadata.name == "git" + + def test_get_tools(self): + """测试获取工具列表""" + plugin = GitPlugin() + tools = plugin.get_tools() + tool_names = [t.name for t in tools] + assert "git_status" in tool_names + assert "git_branch_list" in tool_names + assert "git_log" in tool_names + assert "git_diff" in tool_names + assert "git_stash_list" in tool_names + + def test_git_status(self): + """测试 git status""" + plugin = GitPlugin() + result = plugin._git_status(".") + assert isinstance(result, str) + + def test_git_log(self): + """测试 git log""" + plugin = GitPlugin() + result = plugin._git_log(".", limit=3) + assert isinstance(result, str) + + def test_git_branch_list(self): + """测试 git branch list""" + plugin = GitPlugin() + result = plugin._git_branch_list(".") + assert isinstance(result, str) + + def test_git_diff(self): + """测试 git diff""" + plugin = GitPlugin() + result = plugin._git_diff(".") + assert isinstance(result, str) + + def test_find_repo_root(self): + """测试查找仓库根目录""" + plugin = GitPlugin() + root = plugin._find_repo_root(".") + # 当前目录应该在 git 仓库中 + assert root is not None or root is None # 可能不在 git 仓库中 + + def test_run_git_invalid_command(self): + """测试运行无效 git 命令""" + plugin = GitPlugin() + stdout, stderr, code = plugin._run_git("invalid-command-xyz") + # 应该返回非零退出码 + assert code != 0 or "invalid" in stderr.lower() or "not a git" in stderr.lower() diff --git a/tests/test_plugin/test_config.py b/tests/test_plugin/test_config.py new file mode 100644 index 0000000..1edf67b --- /dev/null +++ b/tests/test_plugin/test_config.py @@ -0,0 +1,287 @@ +"""Tests for PluginConfig - configuration loading and management""" + + +import pytest + +from jojo_code.plugin.config import PluginConfig, get_plugin_config + + +@pytest.fixture +def config(): + return PluginConfig() + + +class TestPluginConfigInit: + """Test PluginConfig initialization.""" + + def test_default_state(self, config): + assert config._config == {} + assert 
config._enabled_plugins == set() + assert config._plugin_settings == {} + + def test_get_returns_none_for_missing_key(self, config): + assert config.get("nonexistent") is None + + def test_get_returns_default_for_missing_key(self, config): + assert config.get("missing", "fallback") == "fallback" + + +class TestPluginConfigSet: + """Test setting config values programmatically.""" + + def test_set_and_get(self, config): + config.set("key1", "value1") + assert config.get("key1") == "value1" + + def test_set_overwrites(self, config): + config.set("key", "old") + config.set("key", "new") + assert config.get("key") == "new" + + def test_set_various_types(self, config): + config.set("str_val", "hello") + config.set("int_val", 42) + config.set("list_val", [1, 2, 3]) + config.set("dict_val", {"a": 1}) + + assert config.get("str_val") == "hello" + assert config.get("int_val") == 42 + assert config.get("list_val") == [1, 2, 3] + assert config.get("dict_val") == {"a": 1} + + +class TestPluginConfigYaml: + """Test loading config from YAML files.""" + + def test_load_from_yaml_basic(self, config, tmp_path): + yaml_file = tmp_path / "plugin.yaml" + yaml_file.write_text( + "plugins:\n" + " enabled:\n" + " - plugin-a\n" + " - plugin-b\n" + "plugin_settings:\n" + " plugin-a:\n" + " debug: true\n" + ) + + config.load_from_yaml(yaml_file) + assert config.is_plugin_enabled("plugin-a") + assert config.is_plugin_enabled("plugin-b") + + def test_load_from_yaml_plugin_settings(self, config, tmp_path): + yaml_file = tmp_path / "plugin.yaml" + yaml_file.write_text( + "plugins:\n" + " enabled:\n" + " - my-plugin\n" + "plugin_settings:\n" + " my-plugin:\n" + " api_key: secret123\n" + " timeout: 30\n" + ) + + config.load_from_yaml(yaml_file) + assert config.get_plugin_setting("my-plugin", "api_key") == "secret123" + assert config.get_plugin_setting("my-plugin", "timeout") == 30 + + def test_load_from_yaml_nonexistent_file(self, config, tmp_path): + # Should be a no-op, not raise + 
config.load_from_yaml(tmp_path / "nonexistent.yaml") + + def test_load_from_yaml_empty_file(self, config, tmp_path): + yaml_file = tmp_path / "empty.yaml" + yaml_file.write_text("") + # Should not raise + config.load_from_yaml(yaml_file) + + def test_load_from_yaml_invalid_content(self, config, tmp_path): + yaml_file = tmp_path / "invalid.yaml" + yaml_file.write_text("{{{{not yaml}}}}") + # Should not raise - error is silently handled + config.load_from_yaml(yaml_file) + + +class TestPluginConfigPyproject: + """Test loading config from pyproject.toml.""" + + def test_load_from_pyproject_basic(self, config, tmp_path): + pyproject = tmp_path / "pyproject.toml" + pyproject.write_text( + '[tool.jojo-code.plugins]\n' + 'enabled = ["plugin-x", "plugin-y"]\n' + ) + + config.load_from_pyproject(pyproject) + assert config.is_plugin_enabled("plugin-x") + assert config.is_plugin_enabled("plugin-y") + + def test_load_from_pyproject_nonexistent(self, config, tmp_path): + # Should be a no-op + config.load_from_pyproject(tmp_path / "nonexistent.toml") + + def test_load_from_pyproject_empty(self, config, tmp_path): + pyproject = tmp_path / "pyproject.toml" + pyproject.write_text("") + # Should not raise + config.load_from_pyproject(pyproject) + + def test_load_from_pyproject_fallback_parser(self, config, tmp_path): + """Test the regex fallback parser for simple TOML content.""" + pyproject = tmp_path / "pyproject.toml" + pyproject.write_text( + '[tool.jojo-code.plugins]\n' + 'enabled = ["fallback-plugin"]\n' + 'debug = "true"\n' + ) + + config.load_from_pyproject(pyproject) + # The fallback parser should still populate config + assert isinstance(config._enabled_plugins, set) + + +class TestPluginConfigEnv: + """Test loading config from environment variables.""" + + def test_load_enabled_from_env(self, config, monkeypatch): + monkeypatch.setenv("JOJO_PLUGINS_ENABLED", "env-plugin-a, env-plugin-b") + config.load_from_env() + assert config.is_plugin_enabled("env-plugin-a") + assert 
config.is_plugin_enabled("env-plugin-b") + + def test_load_individual_plugin_enabled(self, config, monkeypatch): + monkeypatch.setenv("JOJO_PLUGIN_MYPLUGIN_ENABLED", "true") + config.load_from_env() + assert config.is_plugin_enabled("myplugin") + + def test_load_individual_plugin_disabled(self, config, monkeypatch): + # Test the discard path: disable one of two enabled plugins + # Note: env var names use underscores, plugin names use hyphens in the list + monkeypatch.setenv("JOJO_PLUGINS_ENABLED", "plugina,pluginb") + monkeypatch.setenv("JOJO_PLUGIN_PLUGINB_ENABLED", "false") + config.load_from_env() + # pluginb is added by JOJO_PLUGINS_ENABLED then discarded by individual env + # After discard, set becomes {"plugina"} (not empty), so is_plugin_enabled + # checks membership + assert config.is_plugin_enabled("plugina") + assert not config.is_plugin_enabled("pluginb") + + def test_load_plugin_config_json(self, config, monkeypatch): + # JOJO_PLUGIN_FOO_CONFIG stores JSON under setting name "config" + monkeypatch.setenv("JOJO_PLUGIN_FOO_CONFIG", '{"key": "value", "num": 42}') + config.load_from_env() + result = config.get_plugin_setting("foo", "config") + assert isinstance(result, dict) + assert result["key"] == "value" + assert result["num"] == 42 + + def test_load_plugin_config_non_json_falls_back_to_string(self, config, monkeypatch): + monkeypatch.setenv("JOJO_PLUGIN_BAR_SETTING", "plain-string") + config.load_from_env() + assert config.get_plugin_setting("bar", "setting") == "plain-string" + + def test_no_env_vars(self, config, monkeypatch): + monkeypatch.delenv("JOJO_PLUGINS_ENABLED", raising=False) + config.load_from_env() + # Should not raise + + def test_empty_env_var(self, config, monkeypatch): + monkeypatch.setenv("JOJO_PLUGINS_ENABLED", "") + config.load_from_env() + # Empty string should not add any plugins + + +class TestPluginConfigIsEnabled: + """Test is_plugin_enabled logic.""" + + def test_all_enabled_when_no_list(self, config): + """When no enabled 
list is set, all plugins are enabled by default.""" + assert config.is_plugin_enabled("anything") is True + + def test_only_listed_plugins_enabled(self, config): + config._enabled_plugins = {"alpha", "beta"} + assert config.is_plugin_enabled("alpha") is True + assert config.is_plugin_enabled("beta") is True + assert config.is_plugin_enabled("gamma") is False + + def test_empty_enabled_list_allows_all(self, config): + # Empty set means no restrictions + assert config.is_plugin_enabled("any") is True + + +class TestPluginConfigSettings: + """Test plugin-specific settings.""" + + def test_get_plugin_setting(self, config): + config._plugin_settings["my-plugin"] = {"timeout": 30, "debug": True} + assert config.get_plugin_setting("my-plugin", "timeout") == 30 + assert config.get_plugin_setting("my-plugin", "debug") is True + + def test_get_plugin_setting_default(self, config): + assert config.get_plugin_setting("unknown", "key") is None + assert config.get_plugin_setting("unknown", "key", "default") == "default" + + def test_get_plugin_setting_missing_key(self, config): + config._plugin_settings["exists"] = {"known": 1} + assert config.get_plugin_setting("exists", "unknown") is None + + +class TestPluginConfigAutoLoad: + """Test auto_load from standard locations.""" + + def test_auto_load_from_yaml(self, config, tmp_path): + yaml_file = tmp_path / "plugin.yaml" + yaml_file.write_text( + "plugins:\n" + " enabled:\n" + " - auto-plugin\n" + ) + + config.auto_load(project_root=tmp_path) + assert config.is_plugin_enabled("auto-plugin") + + def test_auto_load_from_pyproject(self, config, tmp_path): + pyproject = tmp_path / "pyproject.toml" + pyproject.write_text( + '[tool.jojo-code.plugins]\n' + 'enabled = ["toml-plugin"]\n' + ) + + config.auto_load(project_root=tmp_path) + assert config.is_plugin_enabled("toml-plugin") + + def test_auto_load_with_no_config_files(self, config, tmp_path): + # Should not raise even with no config files + config.auto_load(project_root=tmp_path) + 
+ def test_auto_load_default_uses_cwd(self, config, monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + # Should not raise + config.auto_load() + + +class TestGetPluginConfig: + """Test the global get_plugin_config function.""" + + def test_returns_plugin_config_instance(self): + + # Reset global state + import jojo_code.plugin.config + + jojo_code.plugin.config._config = None + result = get_plugin_config() + assert isinstance(result, PluginConfig) + + # Clean up + jojo_code.plugin.config._config = None + + def test_returns_same_instance(self): + import jojo_code.plugin.config + + jojo_code.plugin.config._config = None + r1 = get_plugin_config() + r2 = get_plugin_config() + assert r1 is r2 + + # Clean up + jojo_code.plugin.config._config = None diff --git a/tests/test_plugin/test_discovery.py b/tests/test_plugin/test_discovery.py new file mode 100644 index 0000000..ce09782 --- /dev/null +++ b/tests/test_plugin/test_discovery.py @@ -0,0 +1,365 @@ +"""Tests for plugin discovery system. + +Covers: discovering plugins from files, directories, subdirectories +with plugin.py, non-existent paths, invalid files, and entry points. 
+""" + +from pathlib import Path + +from jojo_code.plugin.discovery import PluginDiscovery + +# --------------------------------------------------------------------------- +# Helper: plugin source strings +# --------------------------------------------------------------------------- + +VALID_PLUGIN_SRC = """\ +from jojo_code.plugin.base import BasePlugin, PluginMetadata + +class MyPlugin(BasePlugin): + metadata = PluginMetadata(name="my-plugin", version="1.0.0", description="A test plugin") + + def on_load(self) -> None: + pass + + def on_unload(self) -> None: + pass +""" + +VALID_PLUGIN_SRC_ALT = """\ +from jojo_code.plugin.base import BasePlugin, PluginMetadata + +class AnotherPlugin(BasePlugin): + metadata = PluginMetadata(name="another", version="0.1.0", description="Another plugin") + + def on_load(self) -> None: + pass + + def on_unload(self) -> None: + pass +""" + +MULTIPLE_PLUGINS_SRC = """\ +from jojo_code.plugin.base import BasePlugin, PluginMetadata + +class PluginOne(BasePlugin): + metadata = PluginMetadata(name="one", version="1.0.0", description="First") + + def on_load(self) -> None: + pass + + def on_unload(self) -> None: + pass + +class PluginTwo(BasePlugin): + metadata = PluginMetadata(name="two", version="1.0.0", description="Second") + + def on_load(self) -> None: + pass + + def on_unload(self) -> None: + pass +""" + +INVALID_MODULE_SRC = """\ +this is not valid python !!! 
+def broken( +""" + +NO_PLUGIN_SRC = """\ +# This module has no BasePlugin subclass +x = 42 + +def hello(): + return "world" +""" + + +class TestDiscoverFromFile: + """Test discovering plugins from a single file.""" + + def test_discover_valid_plugin_file(self, tmp_path): + """Should discover a plugin from a valid Python file.""" + plugin_file = tmp_path / "my_plugin.py" + plugin_file.write_text(VALID_PLUGIN_SRC) + + discovery = PluginDiscovery() + plugins = discovery.discover(plugin_file) + + assert len(plugins) == 1 + assert plugins[0].metadata.name == "my-plugin" + assert plugins[0].metadata.version == "1.0.0" + + def test_discover_file_with_multiple_plugins(self, tmp_path): + """Should discover all BasePlugin subclasses in a single file.""" + plugin_file = tmp_path / "multi.py" + plugin_file.write_text(MULTIPLE_PLUGINS_SRC) + + discovery = PluginDiscovery() + plugins = discovery.discover(plugin_file) + + assert len(plugins) == 2 + names = {p.metadata.name for p in plugins} + assert names == {"one", "two"} + + def test_discover_file_without_plugin_class(self, tmp_path): + """Should return empty list for a file without BasePlugin subclass.""" + plugin_file = tmp_path / "no_plugin.py" + plugin_file.write_text(NO_PLUGIN_SRC) + + discovery = PluginDiscovery() + plugins = discovery.discover(plugin_file) + + assert plugins == [] + + def test_discover_invalid_python_file(self, tmp_path): + """Should return empty list and not crash on invalid Python.""" + plugin_file = tmp_path / "broken.py" + plugin_file.write_text(INVALID_MODULE_SRC) + + discovery = PluginDiscovery() + plugins = discovery.discover(plugin_file) + + assert plugins == [] + + def test_discover_nonexistent_file(self, tmp_path): + """Should return empty list for a non-existent file.""" + discovery = PluginDiscovery() + plugins = discovery.discover(tmp_path / "does_not_exist.py") + + assert plugins == [] + + +class TestDiscoverFromDirectory: + """Test discovering plugins from a directory.""" + + def 
test_discover_from_directory_with_py_files(self, tmp_path): + """Should discover plugins from .py files in a directory.""" + plugin_dir = tmp_path / "plugins" + plugin_dir.mkdir() + + (plugin_dir / "plugin_a.py").write_text(VALID_PLUGIN_SRC) + (plugin_dir / "plugin_b.py").write_text(VALID_PLUGIN_SRC_ALT) + + discovery = PluginDiscovery() + plugins = discovery.discover(plugin_dir) + + assert len(plugins) == 2 + names = {p.metadata.name for p in plugins} + assert "my-plugin" in names + assert "another" in names + + def test_discover_from_directory_with_subdirs(self, tmp_path): + """Should discover plugins from plugin.py in subdirectories.""" + plugin_dir = tmp_path / "plugins" + plugin_dir.mkdir() + + # Subdirectory with plugin.py + sub1 = plugin_dir / "sub1" + sub1.mkdir() + (sub1 / "plugin.py").write_text(VALID_PLUGIN_SRC) + + # Another subdirectory + sub2 = plugin_dir / "sub2" + sub2.mkdir() + (sub2 / "plugin.py").write_text(VALID_PLUGIN_SRC_ALT) + + discovery = PluginDiscovery() + plugins = discovery.discover(plugin_dir) + + assert len(plugins) == 2 + + def test_discover_ignores_init_files(self, tmp_path): + """Should not try to load __init__.py as a plugin.""" + plugin_dir = tmp_path / "plugins" + plugin_dir.mkdir() + + (plugin_dir / "__init__.py").write_text("# init file\n") + (plugin_dir / "real_plugin.py").write_text(VALID_PLUGIN_SRC) + + discovery = PluginDiscovery() + plugins = discovery.discover(plugin_dir) + + assert len(plugins) == 1 + assert plugins[0].metadata.name == "my-plugin" + + def test_discover_ignores_non_py_files(self, tmp_path): + """Should ignore non-Python files in a directory.""" + plugin_dir = tmp_path / "plugins" + plugin_dir.mkdir() + + (plugin_dir / "readme.txt").write_text("This is not a plugin") + (plugin_dir / "config.json").write_text("{}") + (plugin_dir / "real_plugin.py").write_text(VALID_PLUGIN_SRC) + + discovery = PluginDiscovery() + plugins = discovery.discover(plugin_dir) + + assert len(plugins) == 1 + + def 
test_discover_empty_directory(self, tmp_path): + """Should return empty list for an empty directory.""" + plugin_dir = tmp_path / "empty_plugins" + plugin_dir.mkdir() + + discovery = PluginDiscovery() + plugins = discovery.discover(plugin_dir) + + assert plugins == [] + + def test_discover_directory_with_broken_files(self, tmp_path): + """Should skip broken files and still find valid plugins.""" + plugin_dir = tmp_path / "plugins" + plugin_dir.mkdir() + + (plugin_dir / "broken.py").write_text(INVALID_MODULE_SRC) + (plugin_dir / "valid.py").write_text(VALID_PLUGIN_SRC) + + discovery = PluginDiscovery() + plugins = discovery.discover(plugin_dir) + + assert len(plugins) == 1 + assert plugins[0].metadata.name == "my-plugin" + + +class TestDiscoverEdgeCases: + """Test edge cases in plugin discovery.""" + + def test_discover_nonexistent_path(self): + """Should return empty list for a non-existent path.""" + discovery = PluginDiscovery() + plugins = discovery.discover(Path("/nonexistent/path/that/does/not/exist")) + + assert plugins == [] + + def test_discover_with_string_path(self, tmp_path): + """Should accept string paths in addition to Path objects.""" + plugin_file = tmp_path / "str_path.py" + plugin_file.write_text(VALID_PLUGIN_SRC) + + discovery = PluginDiscovery() + plugins = discovery.discover(str(plugin_file)) + + assert len(plugins) == 1 + + def test_discover_mixed_directory(self, tmp_path): + """Directory with mix of valid plugins, broken files, and non-plugin files.""" + plugin_dir = tmp_path / "mixed" + plugin_dir.mkdir() + + (plugin_dir / "valid.py").write_text(VALID_PLUGIN_SRC) + (plugin_dir / "broken.py").write_text(INVALID_MODULE_SRC) + (plugin_dir / "no_plugin.py").write_text(NO_PLUGIN_SRC) + (plugin_dir / "notes.md").write_text("# Notes") + + discovery = PluginDiscovery() + plugins = discovery.discover(plugin_dir) + + assert len(plugins) == 1 + assert plugins[0].metadata.name == "my-plugin" + + def 
test_discover_directory_with_nested_subdirs_no_plugins(self, tmp_path): + """Subdirectories without plugin.py should be skipped.""" + plugin_dir = tmp_path / "plugins" + plugin_dir.mkdir() + + empty_sub = plugin_dir / "empty_sub" + empty_sub.mkdir() + (empty_sub / "readme.txt").write_text("no plugin here") + + discovery = PluginDiscovery() + plugins = discovery.discover(plugin_dir) + + assert plugins == [] + + +class TestDiscoverEntryPoints: + """Test entry point discovery.""" + + def test_discover_entry_points_returns_list(self): + """discover_entry_points should return a list (may be empty).""" + discovery = PluginDiscovery() + plugins = discovery.discover_entry_points() + + assert isinstance(plugins, list) + + +class TestPluginDiscoveryWithRealPlugins: + """Test discovery with plugins that use various features.""" + + def test_discover_plugin_with_hooks(self, tmp_path): + """Should discover a plugin that defines hooks.""" + src = """\ +from jojo_code.plugin.base import BasePlugin, PluginMetadata + +class HookPlugin(BasePlugin): + metadata = PluginMetadata(name="hook-plugin", version="1.0.0", description="Has hooks") + + def on_load(self) -> None: + pass + + def on_unload(self) -> None: + pass + + def get_hooks(self): + return {"before_tool_call": lambda *a: None} +""" + plugin_file = tmp_path / "hook_plugin.py" + plugin_file.write_text(src) + + discovery = PluginDiscovery() + plugins = discovery.discover(plugin_file) + + assert len(plugins) == 1 + hooks = plugins[0].get_hooks() + assert "before_tool_call" in hooks + + def test_discover_plugin_with_permission(self, tmp_path): + """Should discover a plugin with custom permissions.""" + src = """\ +from jojo_code.plugin.base import BasePlugin, PluginMetadata, PluginPermission + +class TrustedPlugin(BasePlugin): + metadata = PluginMetadata(name="trusted", version="1.0.0", description="Trusted") + permission = PluginPermission.TRUSTED + + def on_load(self) -> None: + pass + + def on_unload(self) -> None: + pass +""" 
+ plugin_file = tmp_path / "trusted_plugin.py" + plugin_file.write_text(src) + + discovery = PluginDiscovery() + plugins = discovery.discover(plugin_file) + + assert len(plugins) == 1 + from jojo_code.plugin.base import PluginPermission + + assert plugins[0].permission == PluginPermission.TRUSTED + + def test_discover_plugin_with_sandbox(self, tmp_path): + """Should discover a plugin with sandbox configuration.""" + src = """\ +from jojo_code.plugin.base import BasePlugin, PluginMetadata, PluginSandbox + +class SandboxedPlugin(BasePlugin): + metadata = PluginMetadata(name="sandboxed", version="1.0.0", description="Sandboxed") + sandbox = PluginSandbox(restricted=True, allowed_paths=["/tmp"]) + + def on_load(self) -> None: + pass + + def on_unload(self) -> None: + pass +""" + plugin_file = tmp_path / "sandbox_plugin.py" + plugin_file.write_text(src) + + discovery = PluginDiscovery() + plugins = discovery.discover(plugin_file) + + assert len(plugins) == 1 + assert plugins[0].sandbox.restricted is True + assert "/tmp" in plugins[0].sandbox.allowed_paths diff --git a/tests/test_plugin/test_hooks.py b/tests/test_plugin/test_hooks.py new file mode 100644 index 0000000..af73be1 --- /dev/null +++ b/tests/test_plugin/test_hooks.py @@ -0,0 +1,378 @@ +"""Tests for the plugin hook dispatcher system. + +Covers: hook decorator, HookDispatcher register/unregister/dispatch, +multiple handlers, error handling, async handlers, and edge cases. 
+""" + +import logging +from unittest.mock import MagicMock + +import pytest + +from jojo_code.plugin.hooks import ( + HOOK_AFTER_AGENT_RUN, + HOOK_AFTER_TOOL_CALL, + HOOK_BEFORE_AGENT_RUN, + HOOK_BEFORE_TOOL_CALL, + HOOK_ON_ERROR, + HOOK_ON_LOAD, + HOOK_ON_UNLOAD, + HookDispatcher, + hook, +) + + +class TestHookDecorator: + """Test the @hook decorator.""" + + def test_hook_decorator_sets_attribute(self): + """@hook should set _hook_name on the decorated function.""" + + @hook("before_tool_call") + def my_handler(): + pass + + assert hasattr(my_handler, "_hook_name") + assert my_handler._hook_name == "before_tool_call" + + def test_hook_decorator_preserves_function(self): + """@hook should return the original function unchanged.""" + + @hook("after_tool_call") + def my_handler(x: int) -> int: + return x * 2 + + assert my_handler(5) == 10 + + def test_hook_decorator_with_all_hook_names(self): + """@hook should work with all defined hook name constants.""" + hook_names = [ + HOOK_BEFORE_TOOL_CALL, + HOOK_AFTER_TOOL_CALL, + HOOK_BEFORE_AGENT_RUN, + HOOK_AFTER_AGENT_RUN, + HOOK_ON_ERROR, + HOOK_ON_LOAD, + HOOK_ON_UNLOAD, + ] + + for name in hook_names: + + @hook(name) + def handler(): + pass + + assert handler._hook_name == name + + +class TestHookDispatcherRegister: + """Test handler registration.""" + + def test_register_single_handler(self): + """Should register a handler for a hook name.""" + dispatcher = HookDispatcher() + handler = MagicMock() + + dispatcher.register("test_hook", handler) + assert dispatcher.has_handlers("test_hook") is True + + def test_register_multiple_handlers(self): + """Should support multiple handlers for the same hook.""" + dispatcher = HookDispatcher() + h1 = MagicMock() + h2 = MagicMock() + h3 = MagicMock() + + dispatcher.register("test_hook", h1) + dispatcher.register("test_hook", h2) + dispatcher.register("test_hook", h3) + + assert dispatcher.has_handlers("test_hook") is True + + def test_has_handlers_returns_false_for_unknown(self): 
+ """has_handlers should return False for unregistered hooks.""" + dispatcher = HookDispatcher() + assert dispatcher.has_handlers("nonexistent") is False + + +class TestHookDispatcherUnregister: + """Test handler unregistration.""" + + def test_unregister_handler(self): + """Should remove a specific handler.""" + dispatcher = HookDispatcher() + h1 = MagicMock() + h2 = MagicMock() + + dispatcher.register("test_hook", h1) + dispatcher.register("test_hook", h2) + + dispatcher.unregister("test_hook", h1) + + # Dispatching should only call h2 + dispatcher.dispatch("test_hook") + h1.assert_not_called() + h2.assert_called_once() + + def test_unregister_nonexistent_handler(self): + """Unregistering a handler that isn't registered should not raise.""" + dispatcher = HookDispatcher() + handler = MagicMock() + # Should not raise + dispatcher.unregister("nonexistent", handler) + + def test_unregister_from_empty_hook(self): + """Unregistering from a hook with no handlers should not raise.""" + dispatcher = HookDispatcher() + handler = MagicMock() + dispatcher.register("other_hook", handler) + dispatcher.unregister("empty_hook", handler) + assert dispatcher.has_handlers("other_hook") is True + + +class TestHookDispatcherDispatch: + """Test hook dispatching.""" + + def test_dispatch_calls_all_handlers(self): + """Dispatch should call all registered handlers in order.""" + dispatcher = HookDispatcher() + call_order = [] + + def handler_a(): + call_order.append("a") + + def handler_b(): + call_order.append("b") + + def handler_c(): + call_order.append("c") + + dispatcher.register("test", handler_a) + dispatcher.register("test", handler_b) + dispatcher.register("test", handler_c) + + dispatcher.dispatch("test") + assert call_order == ["a", "b", "c"] + + def test_dispatch_passes_args(self): + """Dispatch should pass positional and keyword args to handlers.""" + dispatcher = HookDispatcher() + received_args = {} + + def handler(tool_name, args, verbose=False): + 
received_args["tool_name"] = tool_name + received_args["args"] = args + received_args["verbose"] = verbose + + dispatcher.register(HOOK_BEFORE_TOOL_CALL, handler) + dispatcher.dispatch( + HOOK_BEFORE_TOOL_CALL, "read_file", {"path": "/tmp"}, verbose=True + ) + + assert received_args["tool_name"] == "read_file" + assert received_args["args"] == {"path": "/tmp"} + assert received_args["verbose"] is True + + def test_dispatch_returns_results(self): + """Dispatch should collect and return handler return values.""" + dispatcher = HookDispatcher() + + def handler_a(): + return "result_a" + + def handler_b(): + return 42 + + dispatcher.register("test", handler_a) + dispatcher.register("test", handler_b) + + results = dispatcher.dispatch("test") + assert results == ["result_a", 42] + + def test_dispatch_empty_hook_returns_empty_list(self): + """Dispatching a hook with no handlers should return empty list.""" + dispatcher = HookDispatcher() + results = dispatcher.dispatch("nonexistent_hook") + assert results == [] + + def test_dispatch_with_no_args(self): + """Dispatch should work with hooks that take no arguments.""" + dispatcher = HookDispatcher() + called = [] + + def handler(): + called.append(True) + + dispatcher.register("ping", handler) + dispatcher.dispatch("ping") + assert called == [True] + + +class TestHookDispatcherErrorHandling: + """Test error handling in dispatch.""" + + def test_handler_exception_does_not_stop_others(self): + """A failing handler should not prevent other handlers from running.""" + dispatcher = HookDispatcher() + call_order = [] + + def handler_a(): + call_order.append("a") + raise ValueError("boom") + + def handler_b(): + call_order.append("b") + + dispatcher.register("test", handler_a) + dispatcher.register("test", handler_b) + + results = dispatcher.dispatch("test") + assert call_order == ["a", "b"] + # handler_b returns None, handler_a raised so not in results + assert results == [None] + + def test_handler_exception_logged(self, 
caplog): + """Handler exceptions should be logged as warnings.""" + dispatcher = HookDispatcher() + + def bad_handler(): + raise RuntimeError("something went wrong") + + dispatcher.register("test", bad_handler) + + with caplog.at_level(logging.WARNING): + dispatcher.dispatch("test") + + assert any("failed" in record.message for record in caplog.records) + + def test_multiple_failing_handlers(self): + """Multiple failing handlers should all be attempted.""" + dispatcher = HookDispatcher() + + def fail_1(): + raise TypeError("error 1") + + def fail_2(): + raise ValueError("error 2") + + def success(): + return "ok" + + dispatcher.register("test", fail_1) + dispatcher.register("test", fail_2) + dispatcher.register("test", success) + + results = dispatcher.dispatch("test") + assert results == ["ok"] + + +class TestHookDispatcherAsyncHandlers: + """Test async handler support.""" + + @pytest.mark.asyncio + async def test_async_handler_in_running_loop(self, caplog): + """Async handler in a running loop should be scheduled as a task.""" + dispatcher = HookDispatcher() + task_created = [] + + async def async_handler(): + task_created.append(True) + + dispatcher.register("test", async_handler) + + with caplog.at_level(logging.WARNING): + dispatcher.dispatch("test") + + # Should log a warning about fire-and-forget + assert any("fire-and-forget" in record.message for record in caplog.records) + + +class TestHookDispatcherClear: + """Test clearing all handlers.""" + + def test_clear_removes_all_handlers(self): + """Clear should remove all registered handlers.""" + dispatcher = HookDispatcher() + + dispatcher.register("hook_a", lambda: None) + dispatcher.register("hook_b", lambda: None) + dispatcher.register("hook_c", lambda: None) + + assert dispatcher.has_handlers("hook_a") is True + + dispatcher.clear() + + assert dispatcher.has_handlers("hook_a") is False + assert dispatcher.has_handlers("hook_b") is False + assert dispatcher.has_handlers("hook_c") is False + + +class 
TestHookDispatcherWithRealHooks: + """Test using the actual hook name constants.""" + + def test_before_tool_call_hook(self): + """Test registering and dispatching HOOK_BEFORE_TOOL_CALL.""" + dispatcher = HookDispatcher() + calls = [] + + def on_before(tool_name: str, args: dict) -> None: + calls.append(("before", tool_name, args)) + + dispatcher.register(HOOK_BEFORE_TOOL_CALL, on_before) + dispatcher.dispatch(HOOK_BEFORE_TOOL_CALL, "run_command", {"command": "ls"}) + + assert len(calls) == 1 + assert calls[0] == ("before", "run_command", {"command": "ls"}) + + def test_after_tool_call_hook(self): + """Test registering and dispatching HOOK_AFTER_TOOL_CALL.""" + dispatcher = HookDispatcher() + calls = [] + + def on_after(tool_name: str, result: str) -> None: + calls.append(("after", tool_name, result)) + + dispatcher.register(HOOK_AFTER_TOOL_CALL, on_after) + dispatcher.dispatch(HOOK_AFTER_TOOL_CALL, "read_file", "file content here") + + assert len(calls) == 1 + assert calls[0] == ("after", "read_file", "file content here") + + def test_on_error_hook(self): + """Test registering and dispatching HOOK_ON_ERROR.""" + dispatcher = HookDispatcher() + errors = [] + + def on_error(error: Exception) -> None: + errors.append(error) + + dispatcher.register(HOOK_ON_ERROR, on_error) + exc = RuntimeError("test error") + dispatcher.dispatch(HOOK_ON_ERROR, exc) + + assert len(errors) == 1 + assert errors[0] is exc + + def test_multiple_hook_types_independently(self): + """Different hook types should be independent.""" + dispatcher = HookDispatcher() + calls = [] + + def before_handler(tool_name, args): + calls.append(f"before:{tool_name}") + + def after_handler(tool_name, result): + calls.append(f"after:{tool_name}") + + def error_handler(error): + calls.append(f"error:{error}") + + dispatcher.register(HOOK_BEFORE_TOOL_CALL, before_handler) + dispatcher.register(HOOK_AFTER_TOOL_CALL, after_handler) + dispatcher.register(HOOK_ON_ERROR, error_handler) + + 
dispatcher.dispatch(HOOK_BEFORE_TOOL_CALL, "tool_a", {}) + dispatcher.dispatch(HOOK_AFTER_TOOL_CALL, "tool_b", "result") + dispatcher.dispatch(HOOK_ON_ERROR, "oops") + + assert calls == ["before:tool_a", "after:tool_b", "error:oops"] diff --git a/tests/test_plugin/test_loader.py b/tests/test_plugin/test_loader.py new file mode 100644 index 0000000..78ee0d8 --- /dev/null +++ b/tests/test_plugin/test_loader.py @@ -0,0 +1,206 @@ +"""Tests for PluginLoader - loading plugins from modules, files, and classes""" + +import pytest + +from jojo_code.plugin.base import BasePlugin, PluginMetadata +from jojo_code.plugin.loader import PluginLoader, PluginLoadError + + +def _make_plugin_class(name: str = "test") -> type[BasePlugin]: + """Helper to create a concrete plugin class.""" + + class _Plugin(BasePlugin): + metadata = PluginMetadata(name=name, version="1.0.0", description=f"{name} plugin") + + def on_load(self) -> None: + pass + + def on_unload(self) -> None: + pass + + _Plugin.__name__ = f"{name.title()}Plugin" + return _Plugin + + +@pytest.fixture +def loader(): + return PluginLoader() + + +class TestPluginLoadError: + """Test PluginLoadError exception.""" + + def test_is_exception(self): + assert issubclass(PluginLoadError, Exception) + + def test_can_raise_and_catch(self): + with pytest.raises(PluginLoadError, match="test error"): + raise PluginLoadError("test error") + + +class TestLoadFromClass: + """Test loading plugins from a class reference.""" + + def test_load_from_valid_class(self, loader): + cls = _make_plugin_class("from-class") + plugin = loader.load_from_class(cls) + assert isinstance(plugin, BasePlugin) + assert plugin.metadata.name == "from-class" + + def test_load_from_class_returns_new_instance(self, loader): + cls = _make_plugin_class("instance-test") + p1 = loader.load_from_class(cls) + p2 = loader.load_from_class(cls) + assert p1 is not p2 + + def test_load_from_non_subclass_raises(self, loader): + class NotAPlugin: + pass + + with 
pytest.raises(PluginLoadError, match="not a BasePlugin subclass"): + loader.load_from_class(NotAPlugin) # type: ignore[arg-type] + + def test_load_from_class_preserves_metadata(self, loader): + cls = _make_plugin_class("meta-check") + plugin = loader.load_from_class(cls) + assert plugin.metadata.version == "1.0.0" + assert plugin.metadata.description == "meta-check plugin" + + +class TestLoadFromFile: + """Test loading plugins from a Python file.""" + + def test_load_from_valid_file(self, loader, tmp_path): + plugin_file = tmp_path / "my_plugin.py" + plugin_file.write_text( + 'from jojo_code.plugin.base import BasePlugin, PluginMetadata\n' + '\n' + 'class FilePlugin(BasePlugin):\n' + ' metadata = PluginMetadata(name="file-plugin", version="0.1.0", description="")\n' + '\n' + ' def on_load(self) -> None:\n' + ' pass\n' + '\n' + ' def on_unload(self) -> None:\n' + ' pass\n' + ) + + plugin = loader.load_from_file(plugin_file) + assert isinstance(plugin, BasePlugin) + assert plugin.metadata.name == "file-plugin" + + def test_load_from_file_with_path_object(self, loader, tmp_path): + plugin_file = tmp_path / "path_plugin.py" + plugin_file.write_text( + 'from jojo_code.plugin.base import BasePlugin, PluginMetadata\n' + '\n' + 'class PathPlugin(BasePlugin):\n' + ' metadata = PluginMetadata(name="path", version="1.0.0", description="")\n' + '\n' + ' def on_load(self) -> None:\n' + ' pass\n' + '\n' + ' def on_unload(self) -> None:\n' + ' pass\n' + ) + + from pathlib import Path + + plugin = loader.load_from_file(Path(str(plugin_file))) + assert plugin.metadata.name == "path" + + def test_load_from_nonexistent_file_raises(self, loader): + with pytest.raises(PluginLoadError): + loader.load_from_file("/nonexistent/path/plugin.py") + + def test_load_from_file_without_plugin_class_raises(self, loader, tmp_path): + no_plugin = tmp_path / "no_plugin.py" + no_plugin.write_text('x = 42\n') + + with pytest.raises(PluginLoadError, match="No BasePlugin subclass"): + 
loader.load_from_file(no_plugin) + + def test_load_from_file_with_syntax_error_raises(self, loader, tmp_path): + bad_file = tmp_path / "bad.py" + bad_file.write_text("def broken(\n") + + with pytest.raises(PluginLoadError): + loader.load_from_file(bad_file) + + def test_load_from_file_with_multiple_classes_returns_first(self, loader, tmp_path): + plugin_file = tmp_path / "multi.py" + plugin_file.write_text( + 'from jojo_code.plugin.base import BasePlugin, PluginMetadata\n' + '\n' + 'class FirstPlugin(BasePlugin):\n' + ' metadata = PluginMetadata(name="first", version="1.0.0", description="")\n' + ' def on_load(self) -> None: pass\n' + ' def on_unload(self) -> None: pass\n' + '\n' + 'class SecondPlugin(BasePlugin):\n' + ' metadata = PluginMetadata(name="second", version="1.0.0", description="")\n' + ' def on_load(self) -> None: pass\n' + ' def on_unload(self) -> None: pass\n' + ) + + plugin = loader.load_from_file(plugin_file) + assert isinstance(plugin, BasePlugin) + + +class TestLoadFromModule: + """Test loading plugins from a dotted module path.""" + + def test_load_from_nonexistent_module_raises(self, loader): + with pytest.raises(PluginLoadError): + loader.load_from_module("nonexistent.module.path") + + def test_load_from_module_without_plugin_raises(self, loader): + with pytest.raises(PluginLoadError, match="No BasePlugin subclass"): + loader.load_from_module("json") + + def test_load_from_module_method_exists(self, loader): + """Verify the method signature exists and is callable.""" + assert callable(loader.load_from_module) + + +class TestLoaderEdgeCases: + """Test edge cases and error handling.""" + + def test_load_from_file_string_path(self, loader, tmp_path): + plugin_file = tmp_path / "str_path.py" + plugin_file.write_text( + 'from jojo_code.plugin.base import BasePlugin, PluginMetadata\n' + '\n' + 'class StrPlugin(BasePlugin):\n' + ' metadata = PluginMetadata(name="str", version="1.0.0", description="")\n' + ' def on_load(self) -> None: pass\n' + ' def 
on_unload(self) -> None: pass\n' + ) + + plugin = loader.load_from_file(str(plugin_file)) + assert plugin.metadata.name == "str" + + def test_load_from_empty_file_raises(self, loader, tmp_path): + empty_file = tmp_path / "empty.py" + empty_file.write_text("") + + with pytest.raises(PluginLoadError, match="No BasePlugin subclass"): + loader.load_from_file(empty_file) + + def test_load_from_file_with_imports(self, loader, tmp_path): + """Plugin file that imports other modules should still load.""" + plugin_file = tmp_path / "with_imports.py" + plugin_file.write_text( + 'import os\n' + 'import sys\n' + 'from pathlib import Path\n' + 'from jojo_code.plugin.base import BasePlugin, PluginMetadata\n' + '\n' + 'class ImportPlugin(BasePlugin):\n' + ' metadata = PluginMetadata(name="imports", version="1.0.0", description="")\n' + ' def on_load(self) -> None: pass\n' + ' def on_unload(self) -> None: pass\n' + ) + + plugin = loader.load_from_file(plugin_file) + assert plugin.metadata.name == "imports" diff --git a/tests/test_plugin/test_registry.py b/tests/test_plugin/test_registry.py new file mode 100644 index 0000000..7f4a63f --- /dev/null +++ b/tests/test_plugin/test_registry.py @@ -0,0 +1,309 @@ +"""Tests for PluginRegistry - plugin registration, lifecycle, and lookup""" + +import pytest + +from jojo_code.plugin.base import BasePlugin, PluginMetadata +from jojo_code.plugin.hooks import HookDispatcher +from jojo_code.plugin.registry import PluginRegistry + + +def _make_plugin(name: str = "test") -> type[BasePlugin]: + """Helper to create a simple plugin class.""" + + class _Plugin(BasePlugin): + metadata = PluginMetadata(name=name, version="1.0.0", description=f"{name} plugin") + loaded = False + unloaded = False + + def on_load(self) -> None: + self.__class__.loaded = True + + def on_unload(self) -> None: + self.__class__.unloaded = True + + _Plugin.__name__ = f"{name.title()}Plugin" + return _Plugin + + +@pytest.fixture(autouse=True) +def _clean_registry(): + """Ensure a 
clean registry before and after each test.""" + registry = PluginRegistry.get_instance() + registry.clear() + yield + registry.clear() + + +class TestPluginRegistrySingleton: + """Test that PluginRegistry is a proper singleton.""" + + def test_get_instance_returns_same_object(self): + r1 = PluginRegistry.get_instance() + r2 = PluginRegistry.get_instance() + assert r1 is r2 + + def test_get_instance_returns_plugin_registry(self): + registry = PluginRegistry.get_instance() + assert isinstance(registry, PluginRegistry) + + +class TestPluginRegistryRegister: + """Test plugin registration.""" + + def test_register_plugin(self): + registry = PluginRegistry.get_instance() + plugin_cls = _make_plugin("alpha") + plugin = plugin_cls() + registry.register("alpha", plugin) + + assert registry.get("alpha") is plugin + + def test_register_calls_on_load(self): + registry = PluginRegistry.get_instance() + plugin_cls = _make_plugin("beta") + plugin = plugin_cls() + plugin_cls.loaded = False + registry.register("beta", plugin) + assert plugin_cls.loaded is True + + def test_register_duplicate_raises(self): + registry = PluginRegistry.get_instance() + plugin_cls = _make_plugin("dup") + registry.register("dup", plugin_cls()) + + with pytest.raises(ValueError, match="already registered"): + registry.register("dup", plugin_cls()) + + def test_register_multiple_plugins(self): + registry = PluginRegistry.get_instance() + for i in range(5): + cls = _make_plugin(f"multi-{i}") + registry.register(f"multi-{i}", cls()) + + assert len(registry.list_plugins()) == 5 + + def test_register_with_dispatcher_registers_hooks(self): + registry = PluginRegistry.get_instance() + dispatcher = HookDispatcher() + registry.set_dispatcher(dispatcher) + + hook_calls = [] + + class HookPlugin(BasePlugin): + metadata = PluginMetadata(name="hooky", version="1.0.0", description="") + + def on_load(self) -> None: + pass + + def on_unload(self) -> None: + pass + + def get_hooks(self) -> dict: + return {"test_hook": 
lambda: hook_calls.append("called")} + + registry.register("hooky", HookPlugin()) + dispatcher.dispatch("test_hook") + assert hook_calls == ["called"] + + def test_register_without_dispatcher_skips_hooks(self): + """Hooks should not be registered when no dispatcher is set.""" + registry = PluginRegistry.get_instance() + registry.set_dispatcher(None) # type: ignore[arg-type] + + class HookPlugin(BasePlugin): + metadata = PluginMetadata(name="no-dispatcher", version="1.0.0", description="") + + def on_load(self) -> None: + pass + + def on_unload(self) -> None: + pass + + def get_hooks(self) -> dict: + return {"test_hook": lambda: None} + + # Should not raise + registry.register("no-dispatcher", HookPlugin()) + + +class TestPluginRegistryUnregister: + """Test plugin unregistration.""" + + def test_unregister_existing_plugin(self): + registry = PluginRegistry.get_instance() + plugin_cls = _make_plugin("removeme") + plugin = plugin_cls() + registry.register("removeme", plugin) + registry.unregister("removeme") + + assert registry.get("removeme") is None + + def test_unregister_calls_on_unload(self): + registry = PluginRegistry.get_instance() + plugin_cls = _make_plugin("unloadable") + plugin_cls.unloaded = False + registry.register("unloadable", plugin_cls()) + registry.unregister("unloadable") + assert plugin_cls.unloaded is True + + def test_unregister_nonexistent_is_noop(self): + registry = PluginRegistry.get_instance() + # Should not raise + registry.unregister("does-not-exist") + + def test_unregister_removes_from_list(self): + registry = PluginRegistry.get_instance() + cls_a = _make_plugin("a") + cls_b = _make_plugin("b") + registry.register("a", cls_a()) + registry.register("b", cls_b()) + + registry.unregister("a") + assert registry.list_plugins() == ["b"] + + +class TestPluginRegistryLookup: + """Test plugin lookup methods.""" + + def test_get_existing_plugin(self): + registry = PluginRegistry.get_instance() + cls = _make_plugin("finder") + plugin = cls() + 
registry.register("finder", plugin) + assert registry.get("finder") is plugin + + def test_get_nonexistent_returns_none(self): + registry = PluginRegistry.get_instance() + assert registry.get("nope") is None + + def test_list_plugins_empty(self): + registry = PluginRegistry.get_instance() + assert registry.list_plugins() == [] + + def test_list_plugins_returns_names(self): + registry = PluginRegistry.get_instance() + for name in ["x", "y", "z"]: + registry.register(name, _make_plugin(name)()) + result = registry.list_plugins() + assert set(result) == {"x", "y", "z"} + + def test_get_all_returns_all_plugins(self): + registry = PluginRegistry.get_instance() + plugins = [] + for name in ["p1", "p2"]: + p = _make_plugin(name)() + registry.register(name, p) + plugins.append(p) + + all_plugins = registry.get_all() + assert len(all_plugins) == 2 + for p in plugins: + assert p in all_plugins + + def test_get_all_empty(self): + registry = PluginRegistry.get_instance() + assert registry.get_all() == [] + + +class TestPluginRegistryClear: + """Test clearing the registry.""" + + def test_clear_removes_all_plugins(self): + registry = PluginRegistry.get_instance() + for name in ["c1", "c2", "c3"]: + registry.register(name, _make_plugin(name)()) + + registry.clear() + assert registry.list_plugins() == [] + assert registry.get_all() == [] + + def test_clear_allows_re_registration(self): + registry = PluginRegistry.get_instance() + cls = _make_plugin("rerun") + registry.register("rerun", cls()) + registry.clear() + + # Should not raise (duplicate check should pass after clear) + registry.register("rerun", cls()) + + +class TestPluginRegistryEnableDisable: + """Test plugin enable/disable lifecycle.""" + + def test_enable_calls_on_enable(self): + registry = PluginRegistry.get_instance() + + class EnablePlugin(BasePlugin): + metadata = PluginMetadata(name="enabler", version="1.0.0", description="") + enabled = False + + def on_load(self) -> None: + pass + + def on_unload(self) -> 
None: + pass + + def on_enable(self) -> None: + self.__class__.enabled = True + + plugin = EnablePlugin() + registry.register("enabler", plugin) + registry.enable("enabler") + assert EnablePlugin.enabled is True + + def test_disable_calls_on_disable(self): + registry = PluginRegistry.get_instance() + + class DisablePlugin(BasePlugin): + metadata = PluginMetadata(name="disabler", version="1.0.0", description="") + disabled = False + + def on_load(self) -> None: + pass + + def on_unload(self) -> None: + pass + + def on_disable(self) -> None: + self.__class__.disabled = True + + plugin = DisablePlugin() + registry.register("disabler", plugin) + registry.disable("disabler") + assert DisablePlugin.disabled is True + + def test_enable_nonexistent_is_noop(self): + registry = PluginRegistry.get_instance() + # Should not raise + registry.enable("ghost") + + def test_disable_nonexistent_is_noop(self): + registry = PluginRegistry.get_instance() + # Should not raise + registry.disable("ghost") + + def test_enable_plugin_without_on_enable_is_noop(self): + """Plugins that don't define on_enable should not cause errors.""" + registry = PluginRegistry.get_instance() + cls = _make_plugin("no-enable") + registry.register("no-enable", cls()) + # Should not raise even though _Plugin doesn't define on_enable + registry.enable("no-enable") + + +class TestPluginRegistryDispatcher: + """Test dispatcher integration.""" + + def test_set_dispatcher(self): + registry = PluginRegistry.get_instance() + dispatcher = HookDispatcher() + registry.set_dispatcher(dispatcher) + assert registry._dispatcher is dispatcher + + def test_replace_dispatcher(self): + registry = PluginRegistry.get_instance() + d1 = HookDispatcher() + d2 = HookDispatcher() + registry.set_dispatcher(d1) + registry.set_dispatcher(d2) + assert registry._dispatcher is d2 diff --git a/tests/test_security/test_denial.py b/tests/test_security/test_denial.py new file mode 100644 index 0000000..eb2107b --- /dev/null +++ 
b/tests/test_security/test_denial.py @@ -0,0 +1,135 @@ +"""拒绝追踪模块测试""" + +from datetime import datetime, timedelta +from unittest.mock import MagicMock + +from jojo_code.security.denial import AdaptivePermissionMixin, DenialTracker +from jojo_code.security.permission import PermissionLevel, PermissionResult + + +class TestDenialTracker: + """测试拒绝追踪器""" + + def test_record_denial(self): + """测试记录拒绝""" + tracker = DenialTracker() + tracker.record("run_command", {"command": "rm -rf /"}, "危险命令") + assert tracker.get_denial_count("run_command", {"command": "rm -rf /"}) == 1 + + def test_increment_count(self): + """测试计数递增""" + tracker = DenialTracker() + args = {"command": "rm -rf /"} + tracker.record("run_command", args, "危险命令") + tracker.record("run_command", args, "危险命令") + tracker.record("run_command", args, "危险命令") + assert tracker.get_denial_count("run_command", args) == 3 + + def test_threshold_not_exceeded(self): + """测试未超过阈值""" + tracker = DenialTracker(threshold=3) + args = {"command": "rm -rf /"} + tracker.record("run_command", args, "危险命令") + tracker.record("run_command", args, "危险命令") + assert tracker.is_threshold_exceeded("run_command", args) is False + + def test_threshold_exceeded(self): + """测试超过阈值""" + tracker = DenialTracker(threshold=3) + args = {"command": "rm -rf /"} + for _ in range(3): + tracker.record("run_command", args, "危险命令") + assert tracker.is_threshold_exceeded("run_command", args) is True + + def test_threshold_not_exceeded_for_unknown(self): + """测试未知工具未超过阈值""" + tracker = DenialTracker() + assert tracker.is_threshold_exceeded("unknown", {}) is False + + def test_get_tool_denials(self): + """测试获取工具拒绝记录""" + tracker = DenialTracker() + tracker.record("run_command", {"command": "ls"}, "原因1") + tracker.record("run_command", {"command": "pwd"}, "原因2") + tracker.record("read_file", {"path": "test.py"}, "原因3") + + cmd_denials = tracker.get_tool_denials("run_command") + assert len(cmd_denials) == 2 + + file_denials = 
tracker.get_tool_denials("read_file") + assert len(file_denials) == 1 + + def test_clear_specific_tool(self): + """测试清除特定工具记录""" + tracker = DenialTracker() + tracker.record("run_command", {"command": "ls"}, "原因") + tracker.record("read_file", {"path": "test.py"}, "原因") + tracker.clear("run_command") + assert tracker.get_denial_count("run_command", {"command": "ls"}) == 0 + assert tracker.get_denial_count("read_file", {"path": "test.py"}) == 1 + + def test_clear_all(self): + """测试清除所有记录""" + tracker = DenialTracker() + tracker.record("run_command", {"command": "ls"}, "原因") + tracker.record("read_file", {"path": "test.py"}, "原因") + tracker.clear() + assert tracker.get_denial_count("run_command", {"command": "ls"}) == 0 + assert tracker.get_denial_count("read_file", {"path": "test.py"}) == 0 + + def test_cleanup_expired(self): + """测试清理过期记录""" + tracker = DenialTracker(expiry_seconds=1) + tracker.record("run_command", {"command": "ls"}, "原因") + # 手动设置时间戳为过去 + key = tracker._make_key("run_command", {"command": "ls"}) + tracker._denials[key].timestamp = datetime.now() - timedelta(seconds=2) + cleaned = tracker.cleanup_expired() + assert cleaned == 1 + assert tracker.get_denial_count("run_command", {"command": "ls"}) == 0 + + def test_get_stats(self): + """测试获取统计信息""" + tracker = DenialTracker(threshold=5, window_seconds=600) + tracker.record("run_command", {"command": "ls"}, "原因") + stats = tracker.get_stats() + assert stats["total_denials"] == 1 + assert stats["tools_tracked"] == 1 + assert stats["threshold"] == 5 + assert stats["window_seconds"] == 600 + + +class TestAdaptivePermissionMixin: + """测试自适应权限混合类""" + + def test_denial_tracking_allows_normal(self): + """测试正常请求通过""" + mixin = AdaptivePermissionMixin() + check_fn = MagicMock(return_value=PermissionResult(PermissionLevel.ALLOW, "test", {})) + allowed, reason = mixin.check_with_denial_tracking("test", {}, check_fn) + assert allowed is True + assert reason == "" + + def 
test_denial_tracking_records_denial(self): + """测试拒绝被记录""" + mixin = AdaptivePermissionMixin() + check_fn = MagicMock( + return_value=PermissionResult(PermissionLevel.DENY, "test", {}, reason="权限不足") + ) + allowed, reason = mixin.check_with_denial_tracking("test", {}, check_fn) + assert allowed is False + assert "权限不足" in reason + assert mixin.denial_tracker.get_denial_count("test", {}) == 1 + + def test_denial_tracking_threshold_message(self): + """测试阈值消息""" + mixin = AdaptivePermissionMixin() + mixin._denial_tracker.threshold = 2 + check_fn = MagicMock( + return_value=PermissionResult(PermissionLevel.DENY, "test", {}, reason="权限不足") + ) + # 多次拒绝达到阈值 + mixin.check_with_denial_tracking("test", {}, check_fn) + allowed, reason = mixin.check_with_denial_tracking("test", {}, check_fn) + assert allowed is False + assert "连续拒绝" in reason diff --git a/tests/test_security/test_enhanced.py b/tests/test_security/test_enhanced.py new file mode 100644 index 0000000..3f81218 --- /dev/null +++ b/tests/test_security/test_enhanced.py @@ -0,0 +1,360 @@ +"""增强版权限管理器测试""" + +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +from jojo_code.security.denial import DenialTracker +from jojo_code.security.enhanced import ( + EnhancedPermissionConfig, + EnhancedPermissionManager, + get_enhanced_permission_manager, + init_enhanced_permission_manager, + set_enhanced_permission_manager, +) +from jojo_code.security.modes import PermissionMode +from jojo_code.security.permission import PermissionLevel +from jojo_code.security.rule import RuleEngine + + +class TestEnhancedPermissionConfig: + """测试增强版权限配置""" + + def test_default_config(self): + """测试默认配置""" + config = EnhancedPermissionConfig() + assert config.enable_rule_engine is True + assert config.default_rule_action == "ask" + assert config.enable_denial_tracking is True + assert config.denial_threshold == 3 + assert config.denial_window_seconds == 300 + assert config.confirm_callback is None + + def 
test_custom_config(self): + """测试自定义配置""" + callback = MagicMock(return_value=True) + config = EnhancedPermissionConfig( + enable_rule_engine=False, + default_rule_action="deny", + enable_denial_tracking=False, + denial_threshold=5, + denial_window_seconds=600, + confirm_callback=callback, + ) + assert config.enable_rule_engine is False + assert config.default_rule_action == "deny" + assert config.enable_denial_tracking is False + assert config.denial_threshold == 5 + assert config.denial_window_seconds == 600 + assert config.confirm_callback is callback + + def test_base_config_default(self): + """测试基础配置默认值""" + config = EnhancedPermissionConfig() + assert config.base is not None + assert config.base.shell_enabled is True + + +class TestEnhancedPermissionManager: + """测试增强版权限管理器""" + + @pytest.fixture + def manager(self, tmp_path: Path) -> EnhancedPermissionManager: + """创建测试管理器""" + config = EnhancedPermissionConfig( + base=EnhancedPermissionConfig().base.__class__( + workspace_root=tmp_path, + audit_log=False, + ), + ) + return EnhancedPermissionManager(config) + + @pytest.fixture + def manager_no_rules(self, tmp_path: Path) -> EnhancedPermissionManager: + """创建禁用规则引擎的管理器""" + config = EnhancedPermissionConfig( + enable_rule_engine=False, + enable_denial_tracking=False, + base=EnhancedPermissionConfig().base.__class__( + workspace_root=tmp_path, + audit_log=False, + ), + ) + return EnhancedPermissionManager(config) + + def test_init_with_default_config(self): + """测试默认配置初始化""" + manager = EnhancedPermissionManager() + assert manager.get_rule_engine() is not None + assert manager.get_denial_tracker() is not None + + def test_init_without_rule_engine(self): + """测试禁用规则引擎""" + config = EnhancedPermissionConfig(enable_rule_engine=False) + manager = EnhancedPermissionManager(config) + assert manager.get_rule_engine() is None + + def test_init_without_denial_tracking(self): + """测试禁用拒绝追踪""" + config = EnhancedPermissionConfig(enable_denial_tracking=False) + manager = 
EnhancedPermissionManager(config) + assert manager.get_denial_tracker() is None + + def test_default_rules_loaded(self): + """测试默认规则加载""" + manager = EnhancedPermissionManager() + rules = manager.list_rules() + # 应包含危险命令拒绝规则 + assert len(rules) >= 3 + rule_names = [r.name for r in rules] + assert "deny_rm_rf" in rule_names + assert "deny_sudo" in rule_names + assert "deny_chmod_777" in rule_names + + def test_mode_property(self, manager: EnhancedPermissionManager): + """测试权限模式属性""" + assert manager.mode == PermissionMode.AUTO + + def test_set_mode(self, manager: EnhancedPermissionManager): + """测试设置权限模式""" + manager.set_mode("manual") + assert manager.mode == PermissionMode.MANUAL + + def test_add_rule(self, manager: EnhancedPermissionManager): + """测试添加规则""" + manager.add_rule( + tool_pattern="test_tool", + action="allow", + name="test_rule", + description="Test rule", + ) + rules = manager.list_rules() + assert any(r.name == "test_rule" for r in rules) + + def test_add_rule_with_args_pattern(self, manager: EnhancedPermissionManager): + """测试添加带参数模式的规则""" + manager.add_rule( + tool_pattern="run_command", + action="deny", + args_pattern={"command": "rm *"}, + name="deny_rm", + priority=100, + ) + result = manager.check("run_command", {"command": "rm -rf /"}) + assert result.denied + + def test_add_rule_when_engine_disabled(self, manager_no_rules: EnhancedPermissionManager): + """测试禁用规则引擎时添加规则""" + # 应该是 no-op,不抛异常 + manager_no_rules.add_rule(tool_pattern="test", action="allow") + assert manager_no_rules.list_rules() == [] + + def test_rule_engine_allow(self, manager: EnhancedPermissionManager): + """测试规则引擎允许""" + manager.add_rule( + tool_pattern="safe_tool", + action="allow", + name="allow_safe", + priority=100, + ) + result = manager.check("safe_tool", {}) + assert result.allowed + + def test_rule_engine_deny(self, manager: EnhancedPermissionManager): + """测试规则引擎拒绝""" + manager.add_rule( + tool_pattern="dangerous_tool", + action="deny", + name="deny_danger", + 
priority=100, + ) + result = manager.check("dangerous_tool", {}) + assert result.denied + assert "规则拒绝" in result.reason + + def test_rule_engine_deny_records_denial(self, manager: EnhancedPermissionManager): + """测试规则引擎拒绝记录拒绝""" + manager.add_rule( + tool_pattern="blocked_tool", + action="deny", + name="block_it", + priority=100, + ) + manager.check("blocked_tool", {}) + tracker = manager.get_denial_tracker() + assert tracker.get_denial_count("blocked_tool", {}) >= 1 + + def test_rule_engine_ask_falls_through(self, manager: EnhancedPermissionManager): + """测试规则引擎 ASK 继续往下走""" + manager.add_rule( + tool_pattern="ask_tool", + action="ask", + name="ask_it", + priority=100, + ) + # ASK 规则不直接返回,继续基础权限检查 + result = manager.check("ask_tool", {}) + # 结果取决于基础权限管理器 + assert result is not None + + def test_dangerous_command_denied_by_default_rules(self, manager: EnhancedPermissionManager): + """测试危险命令被默认规则拒绝""" + result = manager.check("run_command", {"command": "rm -rf /"}) + assert result.denied + + def test_sudo_denied_by_default_rules(self, manager: EnhancedPermissionManager): + """测试 sudo 被默认规则拒绝""" + result = manager.check("run_command", {"command": "sudo ls"}) + assert result.denied + + def test_base_permission_check(self, manager: EnhancedPermissionManager, tmp_path: Path): + """测试基础权限检查""" + (tmp_path / "test.py").write_text("print('hello')") + result = manager.check("read_file", {"path": str(tmp_path / "test.py")}) + assert result.allowed + + def test_denial_tracking_threshold(self, manager: EnhancedPermissionManager): + """测试拒绝追踪阈值""" + # 超出基础权限的操作(访问工作空间外的文件) + args = {"path": "/etc/passwd"} + for _ in range(3): + manager.check("read_file", args) + result = manager.check("read_file", args) + assert result.denied + assert "连续拒绝" in result.reason + + def test_confirm_callback_confirmed(self, tmp_path: Path): + """测试用户确认回调 - 确认""" + callback = MagicMock(return_value=True) + config = EnhancedPermissionConfig( + confirm_callback=callback, + 
base=EnhancedPermissionConfig().base.__class__( + workspace_root=tmp_path, + shell_default=PermissionLevel.CONFIRM, + audit_log=False, + ), + ) + manager = EnhancedPermissionManager(config) + result = manager.check("run_command", {"command": "ls"}) + assert result.allowed + callback.assert_called_once() + + def test_confirm_callback_denied(self, tmp_path: Path): + """测试用户确认回调 - 拒绝""" + callback = MagicMock(return_value=False) + config = EnhancedPermissionConfig( + confirm_callback=callback, + base=EnhancedPermissionConfig().base.__class__( + workspace_root=tmp_path, + shell_default=PermissionLevel.CONFIRM, + audit_log=False, + ), + ) + manager = EnhancedPermissionManager(config) + result = manager.check("run_command", {"command": "ls"}) + assert result.denied + assert "用户拒绝" in result.reason + + def test_set_confirm_callback(self, manager: EnhancedPermissionManager): + """测试设置确认回调""" + callback = MagicMock(return_value=True) + manager.set_confirm_callback(callback) + assert manager._confirm_callback is callback + + def test_clear_rules(self, manager: EnhancedPermissionManager): + """测试清空规则""" + manager.add_rule(tool_pattern="test", action="allow", name="test") + assert len(manager.list_rules()) > 0 + manager.clear_rules() + # 只剩下默认规则(如果规则引擎启用) + # clear_rules 会清空所有规则 + assert len(manager.list_rules()) == 0 + + def test_get_stats(self, manager: EnhancedPermissionManager): + """测试获取统计信息""" + stats = manager.get_stats() + assert "mode" in stats + assert "rules_count" in stats + assert "denial_tracker" in stats + assert stats["mode"] == "auto" + + def test_get_stats_no_tracker(self, tmp_path: Path): + """测试无拒绝追踪器时获取统计信息""" + config = EnhancedPermissionConfig( + enable_denial_tracking=False, + base=EnhancedPermissionConfig().base.__class__( + workspace_root=tmp_path, + audit_log=False, + ), + ) + manager = EnhancedPermissionManager(config) + stats = manager.get_stats() + assert "denial_tracker" not in stats + + def test_get_rule_engine(self, manager: 
EnhancedPermissionManager): + """测试获取规则引擎""" + engine = manager.get_rule_engine() + assert isinstance(engine, RuleEngine) + + def test_get_denial_tracker(self, manager: EnhancedPermissionManager): + """测试获取拒绝追踪器""" + tracker = manager.get_denial_tracker() + assert isinstance(tracker, DenialTracker) + + def test_rule_engine_disabled_check( + self, manager_no_rules: EnhancedPermissionManager, tmp_path: Path + ): + """测试禁用规则引擎时的权限检查""" + (tmp_path / "test.py").write_text("test") + result = manager_no_rules.check("read_file", {"path": str(tmp_path / "test.py")}) + # 应该通过基础权限检查 + assert result.allowed + + +class TestGlobalFunctions: + """测试全局函数""" + + def test_get_enhanced_permission_manager_singleton(self): + """测试单例模式""" + # 重置全局实例 + import jojo_code.security.enhanced as mod + + mod._enhanced_manager = None + + manager1 = get_enhanced_permission_manager() + manager2 = get_enhanced_permission_manager() + assert manager1 is manager2 + + def test_init_enhanced_permission_manager(self): + """测试初始化全局实例""" + config = EnhancedPermissionConfig(enable_rule_engine=False) + manager = init_enhanced_permission_manager(config) + assert manager is not None + assert manager.get_rule_engine() is None + + def test_set_enhanced_permission_manager(self): + """测试设置全局实例""" + import jojo_code.security.enhanced as mod + + custom_manager = EnhancedPermissionManager() + set_enhanced_permission_manager(custom_manager) + assert get_enhanced_permission_manager() is custom_manager + + # 清理 + mod._enhanced_manager = None + + def test_init_creates_new_instance(self): + """测试初始化创建新实例""" + import jojo_code.security.enhanced as mod + + mod._enhanced_manager = None + + manager1 = get_enhanced_permission_manager() + config = EnhancedPermissionConfig(enable_rule_engine=False) + manager2 = init_enhanced_permission_manager(config) + assert manager1 is not manager2 + assert get_enhanced_permission_manager() is manager2 + + # 清理 + mod._enhanced_manager = None diff --git a/tests/test_security/test_guards.py 
b/tests/test_security/test_guards.py new file mode 100644 index 0000000..63e3388 --- /dev/null +++ b/tests/test_security/test_guards.py @@ -0,0 +1,279 @@ +"""权限守卫基类测试""" + +from typing import Any + +import pytest + +from jojo_code.security.guards import BaseGuard +from jojo_code.security.permission import PermissionLevel, PermissionResult + + +class ConcreteGuard(BaseGuard): + """用于测试的具体守卫实现""" + + def __init__( + self, + guard_name: str = "test_guard", + default_level: PermissionLevel = PermissionLevel.ALLOW, + ): + self._name = guard_name + self._default_level = default_level + self._check_calls: list[tuple[str, dict[str, Any]]] = [] + + @property + def name(self) -> str: + return self._name + + def check(self, tool_name: str, args: dict[str, Any]) -> PermissionResult: + self._check_calls.append((tool_name, args)) + return PermissionResult(self._default_level, tool_name, args) + + +class DenyGuard(BaseGuard): + """总是拒绝的守卫""" + + @property + def name(self) -> str: + return "deny_guard" + + def check(self, tool_name: str, args: dict[str, Any]) -> PermissionResult: + return PermissionResult( + PermissionLevel.DENY, + tool_name, + args, + reason="Always denied", + ) + + +class ConfirmGuard(BaseGuard): + """总是需要确认的守卫""" + + @property + def name(self) -> str: + return "confirm_guard" + + def check(self, tool_name: str, args: dict[str, Any]) -> PermissionResult: + return PermissionResult( + PermissionLevel.CONFIRM, + tool_name, + args, + reason="Please confirm", + ) + + +class SelectiveGuard(BaseGuard): + """根据工具名选择性拒绝的守卫""" + + BLOCKED_TOOLS = {"delete_file", "rm_command"} + + @property + def name(self) -> str: + return "selective_guard" + + def check(self, tool_name: str, args: dict[str, Any]) -> PermissionResult: + if tool_name in self.BLOCKED_TOOLS: + return PermissionResult( + PermissionLevel.DENY, + tool_name, + args, + reason=f"Tool {tool_name} is blocked", + ) + return PermissionResult(PermissionLevel.ALLOW, tool_name, args) + + +class TestBaseGuardAbstract: + 
"""测试 BaseGuard 抽象类""" + + def test_cannot_instantiate_directly(self): + """测试不能直接实例化抽象类""" + with pytest.raises(TypeError): + BaseGuard() # type: ignore + + def test_concrete_guard_implements_check(self): + """测试具体守卫实现 check 方法""" + guard = ConcreteGuard() + result = guard.check("test_tool", {"key": "value"}) + assert isinstance(result, PermissionResult) + assert result.tool_name == "test_tool" + assert result.args == {"key": "value"} + + def test_concrete_guard_implements_name(self): + """测试具体守卫实现 name 属性""" + guard = ConcreteGuard(guard_name="my_guard") + assert guard.name == "my_guard" + + def test_name_is_property(self): + """测试 name 是属性而非方法""" + _ = ConcreteGuard() + assert isinstance(BaseGuard.name, property) + + def test_check_is_abstract(self): + """测试 check 是抽象方法""" + assert hasattr(BaseGuard, "check") + # Verify it's abstract by checking if it's in __abstractmethods__ + assert "check" in BaseGuard.__abstractmethods__ + + def test_name_is_abstract(self): + """测试 name 是抽象属性""" + assert "name" in BaseGuard.__abstractmethods__ + + +class TestConcreteGuard: + """测试具体守卫实现""" + + def test_allow_result(self): + """测试返回允许结果""" + guard = ConcreteGuard(default_level=PermissionLevel.ALLOW) + result = guard.check("read_file", {"path": "test.py"}) + assert result.allowed + assert not result.denied + assert not result.needs_confirm + + def test_deny_result(self): + """测试返回拒绝结果""" + guard = ConcreteGuard(default_level=PermissionLevel.DENY) + result = guard.check("write_file", {"path": "test.py"}) + assert result.denied + assert not result.allowed + + def test_confirm_result(self): + """测试返回确认结果""" + guard = ConcreteGuard(default_level=PermissionLevel.CONFIRM) + result = guard.check("run_command", {"command": "ls"}) + assert result.needs_confirm + assert not result.allowed + assert not result.denied + + def test_records_check_calls(self): + """测试记录检查调用""" + guard = ConcreteGuard() + guard.check("tool1", {"a": 1}) + guard.check("tool2", {"b": 2}) + assert 
len(guard._check_calls) == 2 + assert guard._check_calls[0] == ("tool1", {"a": 1}) + assert guard._check_calls[1] == ("tool2", {"b": 2}) + + def test_custom_guard_name(self): + """测试自定义守卫名称""" + guard = ConcreteGuard(guard_name="custom_name") + assert guard.name == "custom_name" + + +class TestDenyGuard: + """测试总是拒绝的守卫""" + + def test_always_denies(self): + """测试总是拒绝""" + guard = DenyGuard() + result = guard.check("any_tool", {}) + assert result.denied + assert result.reason == "Always denied" + + def test_name(self): + """测试守卫名称""" + guard = DenyGuard() + assert guard.name == "deny_guard" + + def test_denies_with_args(self): + """测试带参数时仍然拒绝""" + guard = DenyGuard() + result = guard.check("read_file", {"path": "/etc/passwd"}) + assert result.denied + assert result.tool_name == "read_file" + + +class TestConfirmGuard: + """测试总是需要确认的守卫""" + + def test_always_needs_confirm(self): + """测试总是需要确认""" + guard = ConfirmGuard() + result = guard.check("any_tool", {}) + assert result.needs_confirm + assert result.reason == "Please confirm" + + def test_name(self): + """测试守卫名称""" + guard = ConfirmGuard() + assert guard.name == "confirm_guard" + + +class TestSelectiveGuard: + """测试选择性守卫""" + + def test_allows_normal_tools(self): + """测试允许普通工具""" + guard = SelectiveGuard() + result = guard.check("read_file", {"path": "test.py"}) + assert result.allowed + + def test_denies_blocked_tools(self): + """测试拒绝被阻止的工具""" + guard = SelectiveGuard() + result = guard.check("delete_file", {"path": "test.py"}) + assert result.denied + assert "delete_file" in result.reason + + def test_denies_rm_command(self): + """测试拒绝 rm_command""" + guard = SelectiveGuard() + result = guard.check("rm_command", {"command": "rm -rf /"}) + assert result.denied + + def test_allows_other_tools(self): + """测试允许其他工具""" + guard = SelectiveGuard() + for tool in ["read_file", "write_file", "run_command", "git_status"]: + result = guard.check(tool, {}) + assert result.allowed, f"Expected {tool} to be allowed" + + def 
test_name(self): + """测试守卫名称""" + guard = SelectiveGuard() + assert guard.name == "selective_guard" + + +class TestGuardIntegration: + """测试守卫与其他组件的集成""" + + def test_guard_returns_permission_result(self): + """测试守卫返回 PermissionResult""" + guard = ConcreteGuard() + result = guard.check("test", {}) + assert isinstance(result, PermissionResult) + assert hasattr(result, "level") + assert hasattr(result, "tool_name") + assert hasattr(result, "args") + assert hasattr(result, "reason") + + def test_multiple_guards_independent(self): + """测试多个守卫独立工作""" + guard1 = ConcreteGuard(guard_name="guard1", default_level=PermissionLevel.ALLOW) + guard2 = DenyGuard() + guard3 = ConfirmGuard() + + result1 = guard1.check("tool", {}) + result2 = guard2.check("tool", {}) + result3 = guard3.check("tool", {}) + + assert result1.allowed + assert result2.denied + assert result3.needs_confirm + + def test_guard_with_empty_args(self): + """测试空参数""" + guard = ConcreteGuard() + result = guard.check("test_tool", {}) + assert result.tool_name == "test_tool" + assert result.args == {} + + def test_guard_with_complex_args(self): + """测试复杂参数""" + guard = ConcreteGuard() + complex_args = { + "command": "python -c 'print(1)'", + "timeout": 30, + "env": {"PATH": "/usr/bin"}, + } + result = guard.check("run_command", complex_args) + assert result.args == complex_args diff --git a/tests/test_security/test_manager.py b/tests/test_security/test_manager.py new file mode 100644 index 0000000..527a7b1 --- /dev/null +++ b/tests/test_security/test_manager.py @@ -0,0 +1,439 @@ +"""Tests for the permission manager module.""" + +import tempfile +from pathlib import Path + +import pytest + +from jojo_code.security.manager import ( + PermissionConfig, + PermissionManager, + get_permission_manager, + init_permission_manager, + set_permission_manager, +) +from jojo_code.security.modes import PermissionMode +from jojo_code.security.permission import PermissionLevel + + +class TestPermissionConfig: + """Tests for 
PermissionConfig dataclass.""" + + def test_default_config(self): + """Test default configuration values.""" + config = PermissionConfig() + assert config.workspace_root == Path(".") + assert config.allow_outside is False + assert config.allowed_paths == ["*"] + assert config.denied_paths == [] + assert config.confirm_on_write == [] + assert config.shell_enabled is True + assert config.allowed_commands == [] + assert config.denied_commands == ["rm -rf /", "sudo"] + assert config.shell_default == PermissionLevel.CONFIRM + assert config.max_timeout == 300 + assert config.allow_network is False + assert config.max_tool_calls == 100 + assert config.audit_log is True + assert config.audit_log_path == Path(".jojo-code/audit.log") + assert config.mode == "auto" + + def test_post_init_string_paths(self): + """Test that string paths are converted to Path objects.""" + config = PermissionConfig( + workspace_root="/tmp/workspace", + audit_log_path="/tmp/audit.log", + ) + assert isinstance(config.workspace_root, Path) + assert config.workspace_root == Path("/tmp/workspace") + assert isinstance(config.audit_log_path, Path) + assert config.audit_log_path == Path("/tmp/audit.log") + + def test_development_config(self): + """Test development preset configuration.""" + config = PermissionConfig.development() + assert config.workspace_root == Path(".") + assert config.allow_outside is False + assert ".env" in config.denied_paths + assert "*.pem" in config.denied_paths + assert "*.key" in config.denied_paths + assert config.shell_enabled is True + assert config.shell_default == PermissionLevel.CONFIRM + assert "rm -rf /" in config.denied_commands + assert "rm -rf ~" in config.denied_commands + assert "sudo" in config.denied_commands + + def test_production_config(self): + """Test production preset configuration.""" + config = PermissionConfig.production() + assert config.allow_outside is False + assert "src/**" in config.allowed_paths + assert "tests/**" in config.allowed_paths + 
assert ".env" in config.denied_paths + assert ".git/**" in config.denied_paths + assert config.shell_enabled is True + assert "ls" in config.allowed_commands + assert "pytest" in config.allowed_commands + assert config.allow_network is False + assert config.max_tool_calls == 50 + assert config.audit_log is True + + def test_from_yaml_missing_file(self): + """Test loading from a non-existent YAML file returns default.""" + config = PermissionConfig.from_yaml(Path("/nonexistent/config.yaml")) + assert config.workspace_root == Path(".") + assert config.mode == "auto" + + def test_from_yaml_valid_file(self): + """Test loading from a valid YAML file.""" + yaml_content = """ +workspace: + root: /tmp/myproject + allow_outside: true +file: + allowed_paths: + - "src/**" + - "tests/**" + denied_paths: + - ".env" + - "secrets/**" +shell: + enabled: true + allowed_commands: + - "ls" + - "cat" + denied_commands: + - "rm -rf /" + default: "allow" + max_timeout: 600 + allow_network: true +global: + max_tool_calls: 200 + audit_log: false + audit_log_path: "/tmp/audit.log" +""" + with tempfile.NamedTemporaryFile( + mode="w", suffix=".yaml", delete=False, encoding="utf-8" + ) as f: + f.write(yaml_content) + f.flush() + config = PermissionConfig.from_yaml(Path(f.name)) + + assert config.workspace_root == Path("/tmp/myproject") + assert config.allow_outside is True + assert config.allowed_paths == ["src/**", "tests/**"] + assert config.denied_paths == [".env", "secrets/**"] + assert config.shell_enabled is True + assert config.allowed_commands == ["ls", "cat"] + assert config.shell_default == PermissionLevel.ALLOW + assert config.max_timeout == 600 + assert config.allow_network is True + assert config.max_tool_calls == 200 + assert config.audit_log is False + assert config.audit_log_path == Path("/tmp/audit.log") + + def test_from_yaml_empty_file(self): + """Test loading from an empty YAML file returns default.""" + with tempfile.NamedTemporaryFile( + mode="w", suffix=".yaml", 
delete=False, encoding="utf-8" + ) as f: + f.write("") + f.flush() + config = PermissionConfig.from_yaml(Path(f.name)) + + assert config.workspace_root == Path(".") + assert config.mode == "auto" + + +class TestPermissionManager: + """Tests for PermissionManager.""" + + def _make_manager(self, **kwargs) -> PermissionManager: + """Helper to create a PermissionManager with given config overrides.""" + config = PermissionConfig(**kwargs) + return PermissionManager(config) + + def test_init_default(self): + """Test default initialization.""" + manager = PermissionManager(PermissionConfig()) + assert manager.mode == PermissionMode.AUTO + assert len(manager.guards) == 2 # PathGuard + CommandGuard + assert manager._call_count == 0 + assert manager.get_audit_log() == [] + + def test_mode_property(self): + """Test mode property returns current mode.""" + manager = self._make_manager(mode="auto") + assert manager.mode == PermissionMode.AUTO + + manager.set_mode("manual") + assert manager.mode == PermissionMode.MANUAL + + def test_set_mode_valid(self): + """Test setting valid permission modes.""" + manager = self._make_manager() + + manager.set_mode("auto") + assert manager.mode == PermissionMode.AUTO + assert manager.config.mode == "auto" + + manager.set_mode("manual") + assert manager.mode == PermissionMode.MANUAL + assert manager.config.mode == "manual" + + manager.set_mode("bypass") + assert manager.mode == PermissionMode.BYPASS + assert manager.config.mode == "bypass" + + def test_set_mode_invalid(self): + """Test setting an invalid permission mode raises ValueError.""" + manager = self._make_manager() + with pytest.raises(ValueError, match="无效的权限模式"): + manager.set_mode("invalid_mode") + + def test_bypass_mode_allows_all(self): + """Test that BYPASS mode allows all operations.""" + manager = self._make_manager(mode="bypass") + result = manager.check("read_file", {"path": "/etc/passwd"}) + assert result.level == PermissionLevel.ALLOW + + result = 
manager.check("run_command", {"command": "rm -rf /"}) + assert result.level == PermissionLevel.ALLOW + + def test_max_tool_calls_limit(self): + """Test that max tool calls limit is enforced.""" + manager = self._make_manager(max_tool_calls=2, mode="bypass") + + # First two calls should succeed + result = manager.check("read_file", {"path": "test.txt"}) + assert result.level == PermissionLevel.ALLOW + + result = manager.check("read_file", {"path": "test.txt"}) + assert result.level == PermissionLevel.ALLOW + + # Third call should be denied (call count check happens before bypass) + # Note: In BYPASS mode, the call count check is skipped because + # BYPASS returns immediately. Let's test with AUTO mode instead. + manager = self._make_manager(max_tool_calls=1, mode="auto") + result = manager.check("read_file", {"path": "test.txt"}) + assert result.level == PermissionLevel.ALLOW + + result = manager.check("read_file", {"path": "test.txt"}) + assert result.level == PermissionLevel.DENY + assert "已达到最大调用次数" in (result.reason or "") + + def test_manual_mode_denies_medium_risk(self): + """Test that MANUAL mode denies medium-risk operations.""" + manager = self._make_manager(mode="manual") + # write_file is medium risk + result = manager.check("write_file", {"path": "test.py"}) + assert result.level == PermissionLevel.DENY + assert "Manual 模式拒绝" in (result.reason or "") + + def test_manual_mode_denies_high_risk(self): + """Test that MANUAL mode denies high-risk operations.""" + manager = self._make_manager(mode="manual") + # rm is high risk + result = manager.check("run_command", {"command": "rm file.txt"}) + assert result.level == PermissionLevel.DENY + assert "Manual 模式拒绝" in (result.reason or "") + + def test_manual_mode_denies_critical_risk(self): + """Test that MANUAL mode denies critical-risk operations.""" + manager = self._make_manager(mode="manual") + result = manager.check("run_command", {"command": "sudo apt update"}) + assert result.level == PermissionLevel.DENY 
+ assert "Manual 模式拒绝" in (result.reason or "") + + def test_auto_mode_allows_low_risk(self): + """Test that AUTO mode allows low-risk operations.""" + manager = self._make_manager(mode="auto") + result = manager.check("read_file", {"path": "test.txt"}) + assert result.level == PermissionLevel.ALLOW + + def test_auto_mode_confirms_medium_risk(self): + """Test that AUTO mode requires confirmation for medium-risk operations.""" + manager = self._make_manager(mode="auto") + # write_file is medium risk + result = manager.check("write_file", {"path": "test.py"}) + assert result.level == PermissionLevel.CONFIRM + assert "需要确认" in (result.reason or "") + + def test_auto_mode_confirms_high_risk(self): + """Test that AUTO mode requires confirmation for high-risk operations.""" + manager = self._make_manager(mode="auto") + result = manager.check("run_command", {"command": "rm file.txt"}) + assert result.level == PermissionLevel.CONFIRM + + def test_call_count_increments(self): + """Test that call count increments on allowed operations.""" + manager = self._make_manager(mode="auto") + assert manager._call_count == 0 + + manager.check("read_file", {"path": "test.txt"}) + assert manager._call_count == 1 + + manager.check("git_status", {}) + assert manager._call_count == 2 + + def test_reset_call_count(self): + """Test resetting the call count.""" + manager = self._make_manager(mode="auto", max_tool_calls=5) + + for _ in range(3): + manager.check("read_file", {"path": "test.txt"}) + assert manager._call_count == 3 + + manager.reset_call_count() + assert manager._call_count == 0 + + def test_audit_log_records_calls(self): + """Test that audit log records tool calls.""" + with tempfile.TemporaryDirectory() as tmpdir: + log_path = Path(tmpdir) / "audit.log" + manager = self._make_manager( + mode="auto", + audit_log=True, + audit_log_path=str(log_path), + ) + + manager.check("read_file", {"path": "test.txt"}) + log = manager.get_audit_log() + + assert len(log) >= 1 + assert 
log[0]["tool"] == "read_file" + assert log[0]["args"] == {"path": "test.txt"} + assert "timestamp" in log[0] + + def test_audit_log_disabled(self): + """Test that audit log can be disabled.""" + manager = self._make_manager(mode="auto", audit_log=False) + manager.check("read_file", {"path": "test.txt"}) + assert manager.get_audit_log() == [] + + def test_flush_writes_to_file(self): + """Test that flush writes buffered log entries to file.""" + with tempfile.TemporaryDirectory() as tmpdir: + log_path = Path(tmpdir) / "audit.log" + manager = self._make_manager( + mode="auto", + audit_log=True, + audit_log_path=str(log_path), + ) + + # Make a few calls (less than buffer size) + manager.check("read_file", {"path": "test.txt"}) + manager.check("git_status", {}) + + # File shouldn't exist yet (buffer not full) + # Actually it will exist if buffer is flushed, let's check + manager.flush() + + assert log_path.exists() + lines = log_path.read_text(encoding="utf-8").strip().split("\n") + # May have duplicate entries due to internal logging + assert len(lines) >= 2 + + def test_flush_empty_buffer(self): + """Test that flush on empty buffer does nothing.""" + with tempfile.TemporaryDirectory() as tmpdir: + log_path = Path(tmpdir) / "audit.log" + manager = self._make_manager( + mode="auto", + audit_log=True, + audit_log_path=str(log_path), + ) + + manager.flush() + assert not log_path.exists() + + def test_log_buffer_auto_flush(self): + """Test that log buffer auto-flushes when full.""" + with tempfile.TemporaryDirectory() as tmpdir: + log_path = Path(tmpdir) / "audit.log" + manager = self._make_manager( + mode="auto", + audit_log=True, + audit_log_path=str(log_path), + ) + + # Make enough calls to fill the buffer (LOG_BUFFER_SIZE = 10) + for _ in range(10): + manager.check("read_file", {"path": "test.txt"}) + + # Buffer should have been flushed + assert log_path.exists() + lines = log_path.read_text(encoding="utf-8").strip().split("\n") + # May have duplicate entries due to 
internal logging + assert len(lines) >= 10 + + def test_denied_paths(self): + """Test that denied paths are rejected.""" + manager = self._make_manager( + mode="auto", + denied_paths=[".env", "*.pem"], + ) + result = manager.check("read_file", {"path": ".env"}) + # PathGuard should deny this + assert result.level == PermissionLevel.DENY + + def test_disabled_shell(self): + """Test that shell commands are denied when shell is disabled.""" + manager = self._make_manager(mode="auto", shell_enabled=False) + result = manager.check("run_command", {"command": "ls"}) + assert result.level == PermissionLevel.DENY + + +class TestGlobalPermissionManager: + """Tests for global permission manager functions.""" + + def test_get_permission_manager_none(self): + """Test getting global manager when none is set.""" + # Save and restore + from jojo_code.security import manager as mgr_module + + old = mgr_module._permission_manager + mgr_module._permission_manager = None + try: + assert get_permission_manager() is None + finally: + mgr_module._permission_manager = old + + def test_set_and_get_permission_manager(self): + """Test setting and getting global manager.""" + from jojo_code.security import manager as mgr_module + + old = mgr_module._permission_manager + try: + config = PermissionConfig() + mgr = PermissionManager(config) + set_permission_manager(mgr) + assert get_permission_manager() is mgr + finally: + mgr_module._permission_manager = old + + def test_init_permission_manager_default(self): + """Test initializing with default config.""" + from jojo_code.security import manager as mgr_module + + old = mgr_module._permission_manager + try: + mgr = init_permission_manager() + assert isinstance(mgr, PermissionManager) + assert mgr.mode == PermissionMode.AUTO + assert get_permission_manager() is mgr + finally: + mgr_module._permission_manager = old + + def test_init_permission_manager_custom(self): + """Test initializing with custom config.""" + from jojo_code.security import 
manager as mgr_module + + old = mgr_module._permission_manager + try: + config = PermissionConfig(mode="manual") + mgr = init_permission_manager(config) + assert mgr.mode == PermissionMode.MANUAL + assert get_permission_manager() is mgr + finally: + mgr_module._permission_manager = old diff --git a/tests/test_security/test_modes.py b/tests/test_security/test_modes.py index ee3e3fd..7f62e0d 100644 --- a/tests/test_security/test_modes.py +++ b/tests/test_security/test_modes.py @@ -1,4 +1,4 @@ -"""测试权限模式和风险等级""" +"""Tests for permission modes and risk levels.""" import pytest @@ -6,67 +6,108 @@ class TestPermissionMode: - """测试 PermissionMode 枚举""" + """Tests for PermissionMode enum.""" def test_all_modes_exist(self): - """测试所有模式都存在""" + """Test that all expected modes exist.""" assert PermissionMode.AUTO.value == "auto" assert PermissionMode.MANUAL.value == "manual" assert PermissionMode.BYPASS.value == "bypass" + def test_mode_count(self): + """Test that there are exactly 3 permission modes.""" + assert len(PermissionMode) == 3 + + def test_mode_is_string_enum(self): + """Test that PermissionMode values are strings.""" + for mode in PermissionMode: + assert isinstance(mode.value, str) + def test_allows_write(self): - """测试写操作权限""" - # AUTO 和 MANUAL 允许写操作(需要确认) + """Test write permission for each mode.""" + # AUTO and MANUAL allow writes (with confirmation) assert PermissionMode.AUTO.allows_write() is True assert PermissionMode.MANUAL.allows_write() is True - # BYPASS 当前实现不允许写操作(这可能是实现错误) + # BYPASS does not allow write (current implementation) assert PermissionMode.BYPASS.allows_write() is False def test_requires_confirmation_bypass(self): - """BYPASS 模式永远不需要确认""" + """BYPASS mode never requires confirmation.""" assert PermissionMode.BYPASS.requires_confirmation(RiskLevel.LOW) is False assert PermissionMode.BYPASS.requires_confirmation(RiskLevel.MEDIUM) is False assert PermissionMode.BYPASS.requires_confirmation(RiskLevel.HIGH) is False assert 
PermissionMode.BYPASS.requires_confirmation(RiskLevel.CRITICAL) is False def test_requires_confirmation_auto(self): - """AUTO 模式 MEDIUM 及以上需要确认""" + """AUTO mode requires confirmation for MEDIUM and above.""" assert PermissionMode.AUTO.requires_confirmation(RiskLevel.LOW) is False assert PermissionMode.AUTO.requires_confirmation(RiskLevel.MEDIUM) is True assert PermissionMode.AUTO.requires_confirmation(RiskLevel.HIGH) is True assert PermissionMode.AUTO.requires_confirmation(RiskLevel.CRITICAL) is True def test_requires_confirmation_manual(self): - """MANUAL 模式所有操作都需要确认""" + """MANUAL mode requires confirmation for all operations.""" assert PermissionMode.MANUAL.requires_confirmation(RiskLevel.LOW) is True assert PermissionMode.MANUAL.requires_confirmation(RiskLevel.MEDIUM) is True assert PermissionMode.MANUAL.requires_confirmation(RiskLevel.HIGH) is True assert PermissionMode.MANUAL.requires_confirmation(RiskLevel.CRITICAL) is True def test_from_string_valid(self): - """测试从字符串解析有效模式""" + """Test parsing valid mode strings.""" assert PermissionMode.from_string("auto") == PermissionMode.AUTO assert PermissionMode.from_string("manual") == PermissionMode.MANUAL assert PermissionMode.from_string("bypass") == PermissionMode.BYPASS def test_from_string_invalid(self): - """测试从字符串解析无效模式""" + """Test parsing an invalid mode string raises ValueError.""" with pytest.raises(ValueError, match="无效的权限模式"): PermissionMode.from_string("invalid_mode") + def test_from_string_empty(self): + """Test parsing an empty mode string raises ValueError.""" + with pytest.raises(ValueError, match="无效的权限模式"): + PermissionMode.from_string("") + + def test_str_representation(self): + """Test string representation of modes.""" + assert str(PermissionMode.AUTO) == "auto" + assert str(PermissionMode.MANUAL) == "manual" + assert str(PermissionMode.BYPASS) == "bypass" + + def test_mode_equality(self): + """Test mode equality comparison.""" + assert PermissionMode.AUTO == PermissionMode.AUTO + assert 
PermissionMode.AUTO != PermissionMode.MANUAL + assert PermissionMode.AUTO != PermissionMode.BYPASS + + def test_mode_identity(self): + """Test that from_string returns the same enum instance.""" + assert PermissionMode.from_string("auto") is PermissionMode.AUTO + assert PermissionMode.from_string("manual") is PermissionMode.MANUAL + assert PermissionMode.from_string("bypass") is PermissionMode.BYPASS + class TestRiskLevel: - """测试 RiskLevel 枚举""" + """Tests for RiskLevel enum.""" def test_all_levels_exist(self): - """测试所有风险等级都存在""" + """Test that all expected risk levels exist.""" assert RiskLevel.LOW.value == "low" assert RiskLevel.MEDIUM.value == "medium" assert RiskLevel.HIGH.value == "high" assert RiskLevel.CRITICAL.value == "critical" + def test_level_count(self): + """Test that there are exactly 4 risk levels.""" + assert len(RiskLevel) == 4 + + def test_level_is_string_enum(self): + """Test that RiskLevel values are strings.""" + for level in RiskLevel: + assert isinstance(level.value, str) + def test_comparison_less_than(self): - """测试风险等级小于比较""" + """Test strict less-than comparison.""" assert RiskLevel.LOW < RiskLevel.MEDIUM assert RiskLevel.LOW < RiskLevel.HIGH assert RiskLevel.LOW < RiskLevel.CRITICAL @@ -75,7 +116,7 @@ def test_comparison_less_than(self): assert RiskLevel.HIGH < RiskLevel.CRITICAL def test_comparison_greater_than(self): - """测试风险等级大于比较""" + """Test strict greater-than comparison.""" assert RiskLevel.CRITICAL > RiskLevel.HIGH assert RiskLevel.CRITICAL > RiskLevel.MEDIUM assert RiskLevel.CRITICAL > RiskLevel.LOW @@ -84,31 +125,107 @@ def test_comparison_greater_than(self): assert RiskLevel.MEDIUM > RiskLevel.LOW def test_comparison_equal(self): - """测试风险等级等于比较""" + """Test equality comparison.""" assert RiskLevel.LOW == RiskLevel.LOW - assert RiskLevel.LOW <= RiskLevel.LOW - assert RiskLevel.LOW >= RiskLevel.LOW + assert RiskLevel.MEDIUM == RiskLevel.MEDIUM + assert RiskLevel.HIGH == RiskLevel.HIGH + assert RiskLevel.CRITICAL == 
RiskLevel.CRITICAL + + def test_comparison_not_equal(self): + """Test not-equal comparison between different levels.""" + assert RiskLevel.LOW != RiskLevel.MEDIUM + assert RiskLevel.LOW != RiskLevel.HIGH + assert RiskLevel.LOW != RiskLevel.CRITICAL + assert RiskLevel.MEDIUM != RiskLevel.HIGH def test_comparison_less_equal(self): - """测试风险等级小于等于比较""" + """Test less-than-or-equal comparison.""" assert RiskLevel.LOW <= RiskLevel.MEDIUM assert RiskLevel.LOW <= RiskLevel.LOW assert RiskLevel.MEDIUM <= RiskLevel.HIGH + assert RiskLevel.MEDIUM <= RiskLevel.MEDIUM def test_comparison_greater_equal(self): - """测试风险等级大于等于比较""" + """Test greater-than-or-equal comparison.""" assert RiskLevel.CRITICAL >= RiskLevel.HIGH assert RiskLevel.CRITICAL >= RiskLevel.CRITICAL assert RiskLevel.HIGH >= RiskLevel.MEDIUM + assert RiskLevel.HIGH >= RiskLevel.HIGH + + def test_ordering_consistency(self): + """Test that ordering is consistent across all comparisons.""" + levels = [RiskLevel.LOW, RiskLevel.MEDIUM, RiskLevel.HIGH, RiskLevel.CRITICAL] + for i, a in enumerate(levels): + for j, b in enumerate(levels): + if i < j: + assert a < b + assert a <= b + assert not a > b + assert not a >= b + assert a != b + elif i == j: + assert a == b + assert a <= b + assert a >= b + assert not a < b + assert not a > b + else: + assert a > b + assert a >= b + assert not a < b + assert not a <= b + assert a != b def test_from_string_valid(self): - """测试从字符串解析有效风险等级""" + """Test parsing valid risk level strings.""" assert RiskLevel.from_string("low") == RiskLevel.LOW assert RiskLevel.from_string("medium") == RiskLevel.MEDIUM assert RiskLevel.from_string("high") == RiskLevel.HIGH assert RiskLevel.from_string("critical") == RiskLevel.CRITICAL def test_from_string_invalid(self): - """测试从字符串解析无效风险等级""" + """Test parsing an invalid risk level string raises ValueError.""" with pytest.raises(ValueError, match="无效的风险等级"): RiskLevel.from_string("extreme") + + def test_from_string_empty(self): + """Test parsing an 
empty risk level string raises ValueError.""" + with pytest.raises(ValueError, match="无效的风险等级"): + RiskLevel.from_string("") + + def test_from_string_case_sensitive(self): + """Test that from_string is case-sensitive.""" + with pytest.raises(ValueError, match="无效的风险等级"): + RiskLevel.from_string("LOW") + with pytest.raises(ValueError, match="无效的风险等级"): + RiskLevel.from_string("Medium") + + def test_str_representation(self): + """Test string representation of levels.""" + assert str(RiskLevel.LOW) == "low" + assert str(RiskLevel.MEDIUM) == "medium" + assert str(RiskLevel.HIGH) == "high" + assert str(RiskLevel.CRITICAL) == "critical" + + def test_from_string_identity(self): + """Test that from_string returns the same enum instance.""" + assert RiskLevel.from_string("low") is RiskLevel.LOW + assert RiskLevel.from_string("medium") is RiskLevel.MEDIUM + assert RiskLevel.from_string("high") is RiskLevel.HIGH + assert RiskLevel.from_string("critical") is RiskLevel.CRITICAL + + def test_risk_levels_in_ascending_order(self): + """Test that risk levels can be sorted correctly.""" + levels = [ + RiskLevel.CRITICAL, + RiskLevel.LOW, + RiskLevel.HIGH, + RiskLevel.MEDIUM, + ] + sorted_levels = sorted(levels) + assert sorted_levels == [ + RiskLevel.LOW, + RiskLevel.MEDIUM, + RiskLevel.HIGH, + RiskLevel.CRITICAL, + ] diff --git a/tests/test_security/test_risk.py b/tests/test_security/test_risk.py index 7c7556b..9edbc11 100644 --- a/tests/test_security/test_risk.py +++ b/tests/test_security/test_risk.py @@ -1,178 +1,430 @@ -"""测试风险评估模块""" +"""Tests for the risk assessment module.""" -from jojo_code.security.risk import RISK_PATTERNS, assess_risk, get_risk_description +import re + +from jojo_code.security.risk import ( + RISK_PATTERNS, + _get_compiled_patterns, + assess_risk, + get_risk_description, +) class TestRiskPatterns: - """测试风险模式定义""" + """Tests for risk pattern definitions.""" def test_critical_patterns_exist(self): - """测试 critical 风险模式存在""" + """Test that critical risk 
patterns exist.""" assert "critical" in RISK_PATTERNS assert len(RISK_PATTERNS["critical"]) > 0 def test_high_patterns_exist(self): - """测试 high 风险模式存在""" + """Test that high risk patterns exist.""" assert "high" in RISK_PATTERNS assert len(RISK_PATTERNS["high"]) > 0 def test_medium_patterns_exist(self): - """测试 medium 风险模式存在""" + """Test that medium risk patterns exist.""" assert "medium" in RISK_PATTERNS assert len(RISK_PATTERNS["medium"]) > 0 def test_low_patterns_exist(self): - """测试 low 风险模式存在""" + """Test that low risk patterns exist.""" assert "low" in RISK_PATTERNS assert len(RISK_PATTERNS["low"]) > 0 + def test_all_pattern_keys(self): + """Test that all expected risk levels have patterns.""" + expected_keys = {"critical", "high", "medium", "low"} + assert set(RISK_PATTERNS.keys()) == expected_keys + + def test_patterns_are_valid_regex(self): + """Test that all patterns compile as valid regular expressions.""" + for level, patterns in RISK_PATTERNS.items(): + for pattern in patterns: + try: + re.compile(pattern) + except re.error as err: + raise AssertionError( + f"Invalid regex pattern in {level}: {pattern}" + ) from err + + def test_critical_patterns_include_rm_rf(self): + """Test that critical patterns include rm -rf variants.""" + critical_text = " ".join(RISK_PATTERNS["critical"]) + assert "rm" in critical_text + assert "sudo" in critical_text + + def test_patterns_are_non_empty_strings(self): + """Test that all patterns are non-empty strings.""" + for _level, patterns in RISK_PATTERNS.items(): + for pattern in patterns: + assert isinstance(pattern, str) + assert len(pattern) > 0 + + +class TestCompiledPatterns: + """Tests for the compiled pattern cache.""" + + def test_get_compiled_patterns_returns_list(self): + """Test that _get_compiled_patterns returns a list.""" + result = _get_compiled_patterns("low") + assert isinstance(result, list) + + def test_get_compiled_patterns_are_regex(self): + """Test that compiled patterns are compiled regex 
objects.""" + for level in RISK_PATTERNS: + patterns = _get_compiled_patterns(level) + for p in patterns: + assert isinstance(p, re.Pattern) + + def test_get_compiled_patterns_caching(self): + """Test that compiled patterns are cached.""" + result1 = _get_compiled_patterns("low") + result2 = _get_compiled_patterns("low") + assert result1 is result2 + + def test_get_compiled_patterns_unknown_level(self): + """Test that unknown level returns empty list.""" + result = _get_compiled_patterns("nonexistent") + assert result == [] + class TestAssessRisk: - """测试 assess_risk 函数""" + """Tests for the assess_risk function.""" - # ─── 低风险测试 ─── + # ─── Low risk tests (read-only tools) ─── def test_read_file_low_risk(self): - """读取文件是低风险""" + """Reading a file is low risk.""" assert assess_risk("read_file", {"path": "/tmp/test.txt"}) == "low" def test_list_directory_low_risk(self): - """列出目录是低风险""" + """Listing a directory is low risk.""" assert assess_risk("list_directory", {"path": "."}) == "low" def test_grep_search_low_risk(self): - """grep 搜索是低风险""" + """Grep search is low risk.""" assert assess_risk("grep_search", {"pattern": "test"}) == "low" + def test_glob_search_low_risk(self): + """Glob search is low risk.""" + assert assess_risk("glob_search", {"pattern": "*.py"}) == "low" + def test_git_status_low_risk(self): - """git status 是低风险""" + """git status is low risk.""" assert assess_risk("git_status", {}) == "low" def test_git_log_low_risk(self): - """git log 是低风险""" + """git log is low risk.""" assert assess_risk("git_log", {}) == "low" + def test_git_diff_low_risk(self): + """git diff is low risk.""" + assert assess_risk("git_diff", {}) == "low" + + def test_git_blame_low_risk(self): + """git blame is low risk.""" + assert assess_risk("git_blame", {"path": "test.py"}) == "low" + + def test_git_branch_low_risk(self): + """git branch is low risk.""" + assert assess_risk("git_branch", {}) == "low" + + def test_web_search_low_risk(self): + """Web search is low risk.""" + 
assert assess_risk("web_search", {"query": "python"}) == "low" + def test_ls_command_low_risk(self): - """ls 命令是低风险""" + """ls command is low risk.""" assert assess_risk("run_command", {"command": "ls -la"}) == "low" def test_cat_command_low_risk(self): - """cat 命令是低风险""" + """cat command is low risk.""" assert assess_risk("run_command", {"command": "cat file.txt"}) == "low" - # ─── 中等风险测试 ─── + def test_find_command_low_risk(self): + """find command is low risk.""" + assert assess_risk("run_command", {"command": "find . -name '*.py'"}) == "low" + + def test_git_show_command_low_risk(self): + """git show command is low risk.""" + assert assess_risk("run_command", {"command": "git show HEAD"}) == "low" + + def test_tail_command_low_risk(self): + """tail command is low risk.""" + assert assess_risk("run_command", {"command": "tail -f log.txt"}) == "low" + + # ─── Medium risk tests ─── def test_write_file_medium_risk(self): - """写入普通文件是中等风险""" + """Writing a normal file is medium risk.""" assert assess_risk("write_file", {"path": "src/test.py"}) == "medium" def test_edit_file_medium_risk(self): - """编辑普通文件是中等风险""" + """Editing a normal file is medium risk.""" assert assess_risk("edit_file", {"path": "config.json"}) == "medium" def test_npm_install_medium_risk(self): - """npm install 是中等风险""" + """npm install is medium risk.""" assert assess_risk("run_command", {"command": "npm install"}) == "medium" def test_pip_install_medium_risk(self): - """pip install 是中等风险""" + """pip install is medium risk.""" assert assess_risk("run_command", {"command": "pip install requests"}) == "medium" def test_python_command_medium_risk(self): - """python 命令是中等风险""" + """python command is medium risk.""" assert assess_risk("run_command", {"command": "python script.py"}) == "medium" - # ─── 高风险测试 ─── + def test_node_command_medium_risk(self): + """node command is medium risk.""" + assert assess_risk("run_command", {"command": "node app.js"}) == "medium" + + def 
test_git_commit_medium_risk(self): + """git commit is medium risk.""" + assert assess_risk("git_commit", {}) == "medium" + + def test_git_commit_command_medium_risk(self): + """git commit via run_command is medium risk.""" + assert assess_risk("run_command", {"command": "git commit -m 'test'"}) == "medium" + + def test_git_checkout_medium_risk(self): + """git checkout via run_command is medium risk.""" + assert assess_risk("run_command", {"command": "git checkout main"}) == "medium" + + def test_mv_command_medium_risk(self): + """mv command is medium risk.""" + assert assess_risk("run_command", {"command": "mv old.txt new.txt"}) == "medium" + + def test_cp_command_medium_risk(self): + """cp command is medium risk.""" + assert assess_risk("run_command", {"command": "cp src dst"}) == "medium" + + def test_tar_command_medium_risk(self): + """tar command is medium risk.""" + assert assess_risk("run_command", {"command": "tar -czf archive.tar.gz dir/"}) == "medium" + + def test_uv_install_medium_risk(self): + """uv install is medium risk.""" + assert assess_risk("run_command", {"command": "uv add requests"}) == "medium" + + def test_yarn_install_medium_risk(self): + """yarn install is medium risk.""" + assert assess_risk("run_command", {"command": "yarn add lodash"}) == "medium" + + # ─── High risk tests ─── def test_write_to_etc_high_risk(self): - """写入 /etc 是高风险""" + """Writing to /etc is high risk.""" assert assess_risk("write_file", {"path": "/etc/config"}) == "high" def test_write_to_usr_high_risk(self): - """写入 /usr 是高风险""" + """Writing to /usr is high risk.""" assert assess_risk("write_file", {"path": "/usr/local/bin/script"}) == "high" + def test_write_to_var_high_risk(self): + """Writing to /var is high risk.""" + assert assess_risk("write_file", {"path": "/var/log/app.log"}) == "high" + + def test_write_to_root_high_risk(self): + """Writing to /root is high risk.""" + assert assess_risk("write_file", {"path": "/root/.bashrc"}) == "high" + def 
test_write_env_file_high_risk(self): - """写入 .env 文件是高风险""" + """Writing to .env file is high risk.""" assert assess_risk("write_file", {"path": ".env"}) == "high" def test_write_credentials_high_risk(self): - """写入 credentials 文件是高风险""" + """Writing to credentials file is high risk.""" assert assess_risk("write_file", {"path": "credentials.json"}) == "high" + def test_write_secrets_high_risk(self): + """Writing to secrets file is high risk.""" + assert assess_risk("write_file", {"path": "secrets/config.yaml"}) == "high" + + def test_write_pem_high_risk(self): + """Writing to .pem file is high risk.""" + assert assess_risk("write_file", {"path": "cert.pem"}) == "high" + + def test_write_key_high_risk(self): + """Writing to .key file is high risk.""" + assert assess_risk("write_file", {"path": "private.key"}) == "high" + + def test_write_id_rsa_high_risk(self): + """Writing to id_rsa file is high risk.""" + assert assess_risk("write_file", {"path": "/home/user/.ssh/id_rsa"}) == "high" + def test_rm_command_high_risk(self): - """rm 命令是高风险""" + """rm command is high risk.""" assert assess_risk("run_command", {"command": "rm file.txt"}) == "high" def test_git_push_high_risk(self): - """git push 是高风险""" + """git push is high risk.""" assert assess_risk("run_command", {"command": "git push origin main"}) == "high" + def test_git_push_force_high_risk(self): + """git push --force is high risk.""" + assert assess_risk("run_command", {"command": "git push --force"}) == "high" + + def test_git_push_f_high_risk(self): + """git push -f is high risk.""" + assert assess_risk("run_command", {"command": "git push -f origin main"}) == "high" + + def test_git_reset_hard_high_risk(self): + """git reset --hard is high risk.""" + assert assess_risk("run_command", {"command": "git reset --hard HEAD~1"}) == "high" + def test_docker_run_high_risk(self): - """docker run 是高风险""" + """docker run is high risk.""" assert assess_risk("run_command", {"command": "docker run -it ubuntu"}) == "high" 
- # ─── 极高风险测试 ─── + def test_docker_exec_high_risk(self): + """docker exec is high risk.""" + assert assess_risk("run_command", {"command": "docker exec -it container bash"}) == "high" + + def test_npm_publish_high_risk(self): + """npm publish is high risk.""" + assert assess_risk("run_command", {"command": "npm publish"}) == "high" + + def test_write_to_etc_via_redirect_high_risk(self): + """Writing to /etc via redirect is high risk.""" + assert assess_risk("run_command", {"command": "echo 'test' > /etc/hostname"}) == "high" + + # ─── Critical risk tests ─── def test_rm_rf_root_critical(self): - """rm -rf / 是极高风险""" + """rm -rf / is critical risk.""" assert assess_risk("run_command", {"command": "rm -rf /"}) == "critical" def test_rm_rf_home_critical(self): - """rm -rf ~ 是极高风险""" + """rm -rf ~ is critical risk.""" assert assess_risk("run_command", {"command": "rm -rf ~"}) == "critical" + def test_rm_rf_star_critical(self): + """rm -rf * is critical risk.""" + assert assess_risk("run_command", {"command": "rm -rf *"}) == "critical" + def test_sudo_critical(self): - """sudo 是极高风险""" + """sudo is critical risk.""" assert assess_risk("run_command", {"command": "sudo apt update"}) == "critical" def test_chmod_777_critical(self): - """chmod 777 是极高风险""" + """chmod 777 is critical risk.""" assert assess_risk("run_command", {"command": "chmod 777 /etc/passwd"}) == "critical" def test_mkfs_critical(self): - """mkfs 是极高风险""" + """mkfs is critical risk.""" assert assess_risk("run_command", {"command": "mkfs.ext4 /dev/sda1"}) == "critical" - # ─── 边界测试 ─── + def test_dd_command_critical(self): + """dd command is critical risk.""" + assert assess_risk("run_command", {"command": "dd if=/dev/zero of=/dev/sda"}) == "critical" + + def test_write_to_disk_device_critical(self): + """Writing to disk device is critical risk.""" + assert assess_risk("run_command", {"command": "echo > /dev/sda"}) == "critical" + + def test_sudo_rm_critical(self): + """sudo rm is critical risk (matches 
sudo pattern first).""" + assert assess_risk("run_command", {"command": "sudo rm file.txt"}) == "critical" + + def test_case_insensitive_critical(self): + """Critical patterns match case-insensitively.""" + assert assess_risk("run_command", {"command": "SUDO apt update"}) == "critical" + + # ─── Edge case tests ─── def test_unknown_tool_default_medium(self): - """未知工具默认中等风险""" + """Unknown tools default to medium risk.""" assert assess_risk("unknown_tool", {}) == "medium" + def test_unknown_tool_with_args_medium(self): + """Unknown tool with args defaults to medium risk.""" + assert assess_risk("custom_tool", {"key": "value"}) == "medium" + def test_empty_command(self): - """空命令是低风险""" + """Empty command is low risk.""" assert assess_risk("run_command", {"command": ""}) == "low" + def test_whitespace_only_command(self): + """Whitespace-only command is low risk.""" + assert assess_risk("run_command", {"command": " "}) == "low" + def test_write_file_no_path(self): - """无路径的写入是中等风险""" + """Writing without a path is medium risk.""" assert assess_risk("write_file", {}) == "medium" + def test_edit_file_no_path(self): + """Editing without a path is medium risk.""" + assert assess_risk("edit_file", {}) == "medium" + + def test_run_command_no_command(self): + """Running without a command is low risk.""" + assert assess_risk("run_command", {}) == "low" + + def test_write_file_path_without_sensitive_marker(self): + """Writing to a normal path is medium risk.""" + assert assess_risk("write_file", {"path": "src/main.py"}) == "medium" + + def test_write_file_path_partial_sensitive_match(self): + """Writing to a path containing sensitive substring is high risk.""" + assert assess_risk("write_file", {"path": "config/my.env.local"}) == "high" + + def test_rm_as_substring_in_word(self): + """rm as part of a larger word should still match as high risk.""" + # The pattern uses \b word boundary + assert assess_risk("run_command", {"command": "program --rm-flag"}) == "high" + + def 
test_command_with_extra_spaces(self): + """Command with extra spaces still matches.""" + assert assess_risk("run_command", {"command": "rm -rf /"}) == "critical" + class TestGetRiskDescription: - """测试 get_risk_description 函数""" + """Tests for the get_risk_description function.""" def test_low_description(self): - """测试低风险描述""" + """Test low risk description.""" desc = get_risk_description("low") assert "低风险" in desc + assert "读取" in desc def test_medium_description(self): - """测试中等风险描述""" + """Test medium risk description.""" desc = get_risk_description("medium") assert "中风险" in desc def test_high_description(self): - """测试高风险描述""" + """Test high risk description.""" desc = get_risk_description("high") assert "高风险" in desc def test_critical_description(self): - """测试极高风险描述""" + """Test critical risk description.""" desc = get_risk_description("critical") assert "极高风险" in desc def test_unknown_description(self): - """测试未知风险描述""" + """Test unknown risk description.""" desc = get_risk_description("unknown") assert "未知" in desc + + def test_empty_string_description(self): + """Test empty string risk description.""" + desc = get_risk_description("") + assert "未知" in desc + + def test_all_valid_levels_have_descriptions(self): + """Test that all valid risk levels return non-empty descriptions.""" + for level in ["low", "medium", "high", "critical"]: + desc = get_risk_description(level) + assert len(desc) > 0 + assert "未知" not in desc + + def test_descriptions_are_chinese(self): + """Test that descriptions contain Chinese characters.""" + for level in ["low", "medium", "high", "critical"]: + desc = get_risk_description(level) + # Check for Chinese characters (risk-related) + assert "风险" in desc diff --git a/tests/test_security/test_rule.py b/tests/test_security/test_rule.py new file mode 100644 index 0000000..09b4b15 --- /dev/null +++ b/tests/test_security/test_rule.py @@ -0,0 +1,193 @@ +"""权限规则引擎测试""" + +from jojo_code.security.rule import ( + PermissionRule, + 
RuleAction, + RuleEngine, + RuleFactory, + RuleMatchType, +) + + +class TestPermissionRule: + """测试权限规则""" + + def test_exact_match(self): + """测试精确匹配""" + rule = PermissionRule(tool_pattern="read_file", match_type=RuleMatchType.EXACT) + assert rule.matches_tool("read_file") is True + assert rule.matches_tool("write_file") is False + assert rule.matches_tool("read_file_extra") is False + + def test_glob_match(self): + """测试通配符匹配""" + rule = PermissionRule(tool_pattern="*_file", match_type=RuleMatchType.GLOB) + assert rule.matches_tool("read_file") is True + assert rule.matches_tool("write_file") is True + assert rule.matches_tool("run_command") is False + + def test_regex_match(self): + """测试正则匹配""" + rule = PermissionRule(tool_pattern=r"git_\w+", match_type=RuleMatchType.REGEX) + assert rule.matches_tool("git_status") is True + assert rule.matches_tool("git_diff") is True + assert rule.matches_tool("read_file") is False + + def test_prefix_match(self): + """测试前缀匹配""" + rule = PermissionRule(tool_pattern="git_", match_type=RuleMatchType.PREFIX) + assert rule.matches_tool("git_status") is True + assert rule.matches_tool("git_diff") is True + assert rule.matches_tool("read_file") is False + + def test_disabled_rule_never_matches(self): + """测试禁用规则不匹配""" + rule = PermissionRule(tool_pattern="*", match_type=RuleMatchType.GLOB, enabled=False) + assert rule.matches_tool("anything") is False + + def test_args_pattern_match(self): + """测试参数模式匹配""" + rule = PermissionRule( + tool_pattern="run_command", + args_pattern={"command": "ls *"}, + ) + assert rule.matches("run_command", {"command": "ls -la"}) is True + assert rule.matches("run_command", {"command": "rm -rf /"}) is False + + def test_args_pattern_missing_key(self): + """测试参数模式缺少 key""" + rule = PermissionRule( + tool_pattern="run_command", + args_pattern={"command": "ls *"}, + ) + assert rule.matches("run_command", {"timeout": 30}) is False + + def test_empty_args_pattern_matches_all(self): + """测试空参数模式匹配所有""" + rule 
= PermissionRule(tool_pattern="run_command", args_pattern={}) + assert rule.matches("run_command", {"command": "anything"}) is True + + def test_repr(self): + """测试字符串表示""" + rule = PermissionRule(name="test_rule", tool_pattern="read_*", action=RuleAction.ALLOW) + assert "test_rule" in repr(rule) + assert "allow" in repr(rule) + + +class TestRuleEngine: + """测试规则引擎""" + + def test_add_and_list_rules(self): + """测试添加和列出规则""" + engine = RuleEngine() + rule = PermissionRule(name="test", tool_pattern="*") + engine.add_rule(rule) + assert len(engine.list_rules()) == 1 + + def test_priority_ordering(self): + """测试优先级排序""" + engine = RuleEngine() + low = PermissionRule(name="low", tool_pattern="*", priority=1, action=RuleAction.ALLOW) + high = PermissionRule(name="high", tool_pattern="*", priority=100, action=RuleAction.DENY) + engine.add_rule(low) + engine.add_rule(high) + # 高优先级应该排在前面 + assert engine.list_rules()[0].name == "high" + + def test_remove_rule(self): + """测试移除规则""" + engine = RuleEngine() + rule = PermissionRule(name="test", tool_pattern="*") + engine.add_rule(rule) + assert engine.remove_rule("test") is True + assert len(engine.list_rules()) == 0 + assert engine.remove_rule("nonexistent") is False + + def test_check_returns_first_match(self): + """测试 check 返回第一个匹配""" + engine = RuleEngine() + engine.add_rule( + PermissionRule(name="allow_all", tool_pattern="*", action=RuleAction.ALLOW, priority=1) + ) + engine.add_rule( + PermissionRule( + name="deny_danger", + tool_pattern="run_command", + args_pattern={"command": "rm *"}, + action=RuleAction.DENY, + priority=100, + ) + ) + # 高优先级的 deny 应该先匹配 + assert engine.check("run_command", {"command": "rm -rf /"}) == RuleAction.DENY + # 普通命令匹配 allow_all + assert engine.check("read_file", {"path": "test.py"}) == RuleAction.ALLOW + + def test_default_action(self): + """测试默认动作""" + engine = RuleEngine() + assert engine.check("unknown_tool", {}) == RuleAction.ASK + engine.default_action = RuleAction.ALLOW + assert 
engine.check("unknown_tool", {}) == RuleAction.ALLOW + + def test_clear_rules(self): + """测试清空规则""" + engine = RuleEngine() + engine.add_rule(PermissionRule(name="test", tool_pattern="*")) + engine.clear() + assert len(engine.list_rules()) == 0 + + def test_from_config(self): + """测试从配置创建""" + config = { + "default_action": "deny", + "rules": [ + { + "name": "allow_read", + "tool_pattern": "read_*", + "action": "allow", + "priority": 10, + }, + ], + } + engine = RuleEngine.from_config(config) + assert engine.default_action == RuleAction.DENY + assert len(engine.list_rules()) == 1 + assert engine.check("read_file", {}) == RuleAction.ALLOW + + def test_batch_add_rules(self): + """测试批量添加规则""" + engine = RuleEngine() + rules = [ + PermissionRule(name="r1", tool_pattern="a*"), + PermissionRule(name="r2", tool_pattern="b*"), + ] + engine.add_rules(rules) + assert len(engine.list_rules()) == 2 + + +class TestRuleFactory: + """测试规则工厂""" + + def test_allow_all_tools(self): + """测试允许所有工具规则""" + rule = RuleFactory.allow_all_tools() + assert rule.action == RuleAction.ALLOW + assert rule.matches_tool("anything") + + def test_deny_dangerous_commands(self): + """测试危险命令拒绝规则""" + rules = RuleFactory.deny_dangerous_commands() + assert len(rules) >= 3 + # 应该包含 rm -rf、sudo、chmod 777 + names = [r.name for r in rules] + assert "deny_rm_rf" in names + assert "deny_sudo" in names + assert "deny_chmod_777" in names + + def test_require_confirmation_for_writes(self): + """测试写操作确认规则""" + rules = RuleFactory.require_confirmation_for_writes() + assert len(rules) >= 4 + for rule in rules: + assert rule.action == RuleAction.ASK diff --git a/tests/test_security/test_ssrf.py b/tests/test_security/test_ssrf.py new file mode 100644 index 0000000..6614bcc --- /dev/null +++ b/tests/test_security/test_ssrf.py @@ -0,0 +1,55 @@ +"""SSRF 防护模块测试""" + +from jojo_code.security.ssrf import _is_safe_url + + +class TestIsSafeUrl: + """测试 URL 安全检查""" + + def test_safe_public_url(self): + """测试安全的公网 URL""" + 
assert _is_safe_url("https://example.com") is True + assert _is_safe_url("https://google.com/search?q=test") is True + assert _is_safe_url("http://api.example.com:8080/v1/data") is True + + def test_localhost_blocked(self): + """测试 localhost 被阻止""" + assert _is_safe_url("http://localhost") is False + assert _is_safe_url("http://localhost:8080") is False + assert _is_safe_url("http://127.0.0.1") is False + assert _is_safe_url("http://127.0.0.1:3000") is False + assert _is_safe_url("http://[::1]") is False + assert _is_safe_url("http://0.0.0.0") is False + + def test_private_ip_blocked(self): + """测试私有 IP 被阻止""" + assert _is_safe_url("http://10.0.0.1") is False + assert _is_safe_url("http://172.16.0.1") is False + assert _is_safe_url("http://192.168.1.1") is False + assert _is_safe_url("http://169.254.1.1") is False + + def test_credential_injection_blocked(self): + """测试凭证注入被阻止""" + assert _is_safe_url("http://evil.com@internal-host/") is False + assert _is_safe_url("http://user:pass@example.com") is False + + def test_ipv6_mapped_blocked(self): + """测试 IPv6-mapped 地址被阻止""" + assert _is_safe_url("http://::ffff:127.0.0.1") is False + assert _is_safe_url("http://::ffff:10.0.0.1") is False + + def test_empty_hostname_blocked(self): + """测试空 hostname 被阻止""" + assert _is_safe_url("") is False + assert _is_safe_url("http://") is False + + def test_malformed_url_blocked(self): + """测试畸形 URL 被阻止""" + assert _is_safe_url("not-a-url") is False + assert _is_safe_url("://missing-scheme") is False + + def test_safe_domain_names(self): + """测试安全域名""" + assert _is_safe_url("https://github.com") is True + assert _is_safe_url("https://pypi.org/simple/") is True + assert _is_safe_url("https://api.openai.com/v1/chat") is True diff --git a/tests/test_session/test_manager.py b/tests/test_session/test_manager.py index d717412..2ada59b 100644 --- a/tests/test_session/test_manager.py +++ b/tests/test_session/test_manager.py @@ -1,25 +1,172 @@ +"""Session Manager tests. 
+ +Tests for Session CRUD, corrupted JSON tolerance (Bug #7), +metadata handling, multi-message sessions, and directory auto-creation. +""" + +import json +import os + +import pytest + from jojo_code.session.manager import SessionManager -def test_create_and_save_session(tmp_path): - storage = str(tmp_path) - sm = SessionManager(storage_dir=storage) - s = sm.create_session(user_id="user-1") - sm.add_message(s.id, "user", "Hello world") - # simulate restart by creating a new manager pointing to same storage - sm2 = SessionManager(storage_dir=storage) - s2 = sm2.get_session(s.id) - assert s2 is not None - assert s2.id == s.id - assert len(s2.messages) == 1 - - -def test_recover_session(tmp_path): - storage = str(tmp_path) - sm = SessionManager(storage_dir=storage) - s = sm.create_session() - sm.add_message(s.id, "user", "first message") - recovered = sm.recover_session(s.id) - assert recovered is not None - assert len(recovered.messages) == 1 - assert recovered.messages[0].content == "first message" +@pytest.fixture +def manager(tmp_path): + return SessionManager(storage_dir=str(tmp_path / "sessions")) + + +class TestSessionManager: + def test_create_session(self, manager): + session = manager.create_session(user_id="user1") + assert session.id is not None + assert session.user_id == "user1" + + def test_create_session_without_user_id(self, manager): + session = manager.create_session() + assert session.id is not None + assert session.user_id is None + + def test_create_session_with_metadata(self, manager): + metadata = {"source": "cli", "version": "1.0"} + session = manager.create_session(user_id="u1", metadata=metadata) + assert session.metadata == metadata + + # Verify metadata persists after reload + loaded = manager.get_session(session.id) + assert loaded is not None + assert loaded.metadata == metadata + + def test_create_session_metadata_default(self, manager): + session = manager.create_session() + assert session.metadata == {} + + def test_get_session(self, 
manager): + session = manager.create_session() + loaded = manager.get_session(session.id) + assert loaded is not None + assert loaded.id == session.id + + def test_get_missing_session_returns_none(self, manager): + assert manager.get_session("nonexistent") is None + + def test_add_message(self, manager): + session = manager.create_session() + manager.add_message(session.id, "user", "hello") + loaded = manager.get_session(session.id) + assert len(loaded.messages) == 1 + assert loaded.messages[0].content == "hello" + + def test_add_message_missing_session_raises(self, manager): + with pytest.raises(ValueError, match="not found"): + manager.add_message("nonexistent", "user", "msg") + + def test_recover_session(self, manager): + session = manager.create_session() + recovered = manager.recover_session(session.id) + assert recovered is not None + assert recovered.id == session.id + + def test_recover_missing_session_returns_none(self, manager): + assert manager.recover_session("nonexistent") is None + + def test_add_multiple_messages(self, manager): + session = manager.create_session() + manager.add_message(session.id, "user", "first") + manager.add_message(session.id, "assistant", "second") + manager.add_message(session.id, "user", "third") + loaded = manager.get_session(session.id) + assert len(loaded.messages) == 3 + assert loaded.messages[0].content == "first" + assert loaded.messages[1].content == "second" + assert loaded.messages[2].content == "third" + + def test_add_message_different_roles(self, manager): + session = manager.create_session() + manager.add_message(session.id, "user", "question") + manager.add_message(session.id, "assistant", "answer") + manager.add_message(session.id, "system", "system note") + loaded = manager.get_session(session.id) + assert loaded.messages[0].role == "user" + assert loaded.messages[1].role == "assistant" + assert loaded.messages[2].role == "system" + + def test_storage_directory_auto_created(self, tmp_path): + storage = 
str(tmp_path / "new_dir" / "sessions") + assert not os.path.exists(storage) + SessionManager(storage_dir=storage) + assert os.path.exists(storage) + + def test_session_persists_across_managers(self, tmp_path): + storage = str(tmp_path / "shared") + sm1 = SessionManager(storage_dir=storage) + session = sm1.create_session(user_id="u1") + sm1.add_message(session.id, "user", "persisted message") + + sm2 = SessionManager(storage_dir=storage) + loaded = sm2.get_session(session.id) + assert loaded is not None + assert loaded.user_id == "u1" + assert len(loaded.messages) == 1 + assert loaded.messages[0].content == "persisted message" + + def test_save_session_overwrites(self, manager): + session = manager.create_session(user_id="original") + manager.add_message(session.id, "user", "msg1") + + # Modify and save again + session.user_id = "modified" + manager.save_session(session) + + loaded = manager.get_session(session.id) + assert loaded.user_id == "modified" + + def test_path_method(self, manager, tmp_path): + session = manager.create_session() + expected = os.path.join(str(tmp_path / "sessions"), f"{session.id}.json") + assert manager._path(session.id) == expected + + def test_session_json_format(self, manager): + session = manager.create_session(user_id="u1") + manager.add_message(session.id, "user", "test") + path = manager._path(session.id) + with open(path, encoding="utf-8") as f: + data = json.load(f) + assert data["id"] == session.id + assert data["user_id"] == "u1" + assert len(data["messages"]) == 1 + + +class TestCorruptedJSON: + """Bug #7: corrupted JSON files should return None instead of crashing""" + + def test_corrupted_json_returns_none(self, manager): + """Corrupted JSON file should return None""" + session = manager.create_session() + path = manager._path(session.id) + with open(path, "w") as f: + f.write("not valid json {{{") + + result = manager.get_session(session.id) + assert result is None + + def test_empty_file_returns_none(self, manager): + 
"""Empty file should return None""" + session = manager.create_session() + path = manager._path(session.id) + with open(path, "w") as f: + f.write("") + + result = manager.get_session(session.id) + assert result is None + + def test_truncated_json_returns_none(self, manager): + """Truncated JSON should return None""" + session = manager.create_session() + path = manager._path(session.id) + with open(path, "w") as f: + f.write('{"id": "abc", "user_id": "u1", "messages": [') + + result = manager.get_session(session.id) + assert result is None diff --git a/tests/test_session/test_models.py b/tests/test_session/test_models.py index d46405d..d376d59 100644 --- a/tests/test_session/test_models.py +++ b/tests/test_session/test_models.py @@ -1,21 +1,129 @@ from jojo_code.session.models import Message, Session -def test_message_roundtrip(): - m = Message(role="user", content="Hello") - d = m.to_dict() - m2 = Message.from_dict(d) - assert m2.role == m.role - assert m2.content == m.content - assert isinstance(m2.timestamp, float) - - -def test_session_roundtrip(): - s = Session(id="sess-1", user_id="user-1") - s.add_message("user", "Hello") - s.add_message("assistant", "Hi there!") - d = s.to_dict() - s2 = Session.from_dict(d) - assert s2.id == s.id - assert s2.user_id == s.user_id - assert len(s2.messages) == 2 +class TestMessage: + """Message dataclass tests.""" + + def test_creation(self): + m = Message(role="user", content="Hello") + assert m.role == "user" + assert m.content == "Hello" + assert isinstance(m.timestamp, float) + + def test_to_dict(self): + m = Message(role="assistant", content="Hi", timestamp=1234567890.0) + d = m.to_dict() + assert d["role"] == "assistant" + assert d["content"] == "Hi" + assert d["timestamp"] == 1234567890.0 + + def test_from_dict_with_timestamp(self): + d = {"role": "system", "content": "prompt", "timestamp": 9999.0} + m = Message.from_dict(d) + assert m.role == "system" + assert m.content == "prompt" + assert m.timestamp == 9999.0 + + 
def test_from_dict_without_timestamp(self): + d = {"role": "user", "content": "hello"} + m = Message.from_dict(d) + assert m.role == "user" + assert isinstance(m.timestamp, float) + + def test_roundtrip(self): + m = Message(role="user", content="Hello") + d = m.to_dict() + m2 = Message.from_dict(d) + assert m2.role == m.role + assert m2.content == m.content + assert isinstance(m2.timestamp, float) + + +class TestSession: + """Session dataclass tests.""" + + def test_creation_defaults(self): + s = Session(id="s1") + assert s.id == "s1" + assert s.user_id is None + assert s.messages == [] + assert s.metadata == {} + assert isinstance(s.created_at, float) + assert isinstance(s.last_seen_at, float) + + def test_creation_with_user(self): + s = Session(id="s1", user_id="u1") + assert s.user_id == "u1" + + def test_add_message(self): + s = Session(id="s1") + s.add_message("user", "Hello") + assert len(s.messages) == 1 + assert s.messages[0].role == "user" + assert s.messages[0].content == "Hello" + + def test_add_message_updates_last_seen(self): + s = Session(id="s1") + before = s.last_seen_at + s.add_message("user", "msg") + assert s.last_seen_at >= before + + def test_add_multiple_messages(self): + s = Session(id="s1") + s.add_message("user", "q1") + s.add_message("assistant", "a1") + s.add_message("user", "q2") + assert len(s.messages) == 3 + assert s.messages[2].content == "q2" + + def test_to_dict(self): + s = Session(id="s1", user_id="u1", metadata={"key": "val"}) + s.add_message("user", "hi") + d = s.to_dict() + assert d["id"] == "s1" + assert d["user_id"] == "u1" + assert d["metadata"] == {"key": "val"} + assert len(d["messages"]) == 1 + assert d["messages"][0]["role"] == "user" + + def test_from_dict(self): + d = { + "id": "s2", + "user_id": "u2", + "created_at": 1000.0, + "last_seen_at": 2000.0, + "messages": [{"role": "user", "content": "hello", "timestamp": 1500.0}], + "metadata": {"env": "test"}, + } + s = Session.from_dict(d) + assert s.id == "s2" + assert 
s.user_id == "u2" + assert s.created_at == 1000.0 + assert s.last_seen_at == 2000.0 + assert len(s.messages) == 1 + assert s.messages[0].content == "hello" + assert s.metadata == {"env": "test"} + + def test_from_dict_minimal(self): + d = {"id": "s3"} + s = Session.from_dict(d) + assert s.id == "s3" + assert s.user_id is None + assert s.messages == [] + assert s.metadata == {} + + def test_roundtrip(self): + s = Session(id="sess-1", user_id="user-1") + s.add_message("user", "Hello") + s.add_message("assistant", "Hi there!") + d = s.to_dict() + s2 = Session.from_dict(d) + assert s2.id == s.id + assert s2.user_id == s.user_id + assert len(s2.messages) == 2 + + def test_roundtrip_preserves_metadata(self): + s = Session(id="s1", metadata={"lang": "en", "mode": "build"}) + d = s.to_dict() + s2 = Session.from_dict(d) + assert s2.metadata == {"lang": "en", "mode": "build"} diff --git a/tests/test_session/test_session_manager.py b/tests/test_session/test_session_manager.py index 40c64c7..02139ca 100644 --- a/tests/test_session/test_session_manager.py +++ b/tests/test_session/test_session_manager.py @@ -1,8 +1,12 @@ """Session Manager 测试 -测试 Session CRUD 和 Bug #7 修复(损坏 JSON 容错)。 +测试 Session CRUD、Bug #7 修复(损坏 JSON 容错)、 +元数据处理、多消息会话和目录自动创建。 """ +import json +import os + import pytest from jojo_code.session.manager import SessionManager @@ -45,6 +49,175 @@ def test_recover_session(self, manager): assert recovered is not None assert recovered.id == session.id + def test_recover_missing_session_returns_none(self, manager): + assert manager.recover_session("nonexistent") is None + + def test_create_session_with_metadata(self, manager): + metadata = {"source": "cli", "version": "1.0"} + session = manager.create_session(user_id="u1", metadata=metadata) + assert session.metadata == metadata + + loaded = manager.get_session(session.id) + assert loaded is not None + assert loaded.metadata == metadata + + def test_create_session_metadata_default(self, manager): + session = 
manager.create_session() + assert session.metadata == {} + + def test_add_multiple_messages(self, manager): + session = manager.create_session() + manager.add_message(session.id, "user", "first") + manager.add_message(session.id, "assistant", "second") + manager.add_message(session.id, "user", "third") + loaded = manager.get_session(session.id) + assert len(loaded.messages) == 3 + assert loaded.messages[0].content == "first" + assert loaded.messages[1].content == "second" + assert loaded.messages[2].content == "third" + + def test_add_message_different_roles(self, manager): + session = manager.create_session() + manager.add_message(session.id, "user", "question") + manager.add_message(session.id, "assistant", "answer") + manager.add_message(session.id, "system", "system note") + loaded = manager.get_session(session.id) + assert loaded.messages[0].role == "user" + assert loaded.messages[1].role == "assistant" + assert loaded.messages[2].role == "system" + + def test_storage_directory_auto_created(self, tmp_path): + storage = str(tmp_path / "new_dir" / "sessions") + assert not os.path.exists(storage) + SessionManager(storage_dir=storage) + assert os.path.exists(storage) + + def test_session_persists_across_managers(self, tmp_path): + storage = str(tmp_path / "shared") + sm1 = SessionManager(storage_dir=storage) + session = sm1.create_session(user_id="u1") + sm1.add_message(session.id, "user", "persisted message") + + sm2 = SessionManager(storage_dir=storage) + loaded = sm2.get_session(session.id) + assert loaded is not None + assert loaded.user_id == "u1" + assert len(loaded.messages) == 1 + assert loaded.messages[0].content == "persisted message" + + def test_save_session_overwrites(self, manager): + session = manager.create_session(user_id="original") + manager.add_message(session.id, "user", "msg1") + + session.user_id = "modified" + manager.save_session(session) + + loaded = manager.get_session(session.id) + assert loaded.user_id == "modified" + + def 
test_path_method(self, manager, tmp_path): + session = manager.create_session() + expected = os.path.join(str(tmp_path / "sessions"), f"{session.id}.json") + assert manager._path(session.id) == expected + + def test_session_json_format(self, manager): + session = manager.create_session(user_id="u1") + manager.add_message(session.id, "user", "test") + path = manager._path(session.id) + with open(path, encoding="utf-8") as f: + data = json.load(f) + assert data["id"] == session.id + assert data["user_id"] == "u1" + assert len(data["messages"]) == 1 + + def test_unicode_messages(self, manager): + session = manager.create_session() + manager.add_message(session.id, "user", "你好世界") + manager.add_message(session.id, "assistant", "こんにちは") + loaded = manager.get_session(session.id) + assert loaded.messages[0].content == "你好世界" + assert loaded.messages[1].content == "こんにちは" + + def test_empty_content_message(self, manager): + session = manager.create_session() + manager.add_message(session.id, "user", "") + loaded = manager.get_session(session.id) + assert loaded.messages[0].content == "" + + def test_create_multiple_sessions(self, manager): + s1 = manager.create_session(user_id="a") + s2 = manager.create_session(user_id="b") + assert s1.id != s2.id + assert manager.get_session(s1.id).user_id == "a" + assert manager.get_session(s2.id).user_id == "b" + + def test_recover_session_with_messages(self, manager): + """Recover should return session with all messages intact.""" + session = manager.create_session(user_id="u1") + manager.add_message(session.id, "user", "question 1") + manager.add_message(session.id, "assistant", "answer 1") + manager.add_message(session.id, "user", "question 2") + + recovered = manager.recover_session(session.id) + assert recovered is not None + assert len(recovered.messages) == 3 + assert recovered.messages[0].content == "question 1" + assert recovered.messages[2].content == "question 2" + assert recovered.user_id == "u1" + + def 
test_long_message_content(self, manager): + """Session should handle long message content.""" + session = manager.create_session() + long_content = "x" * 10000 + manager.add_message(session.id, "user", long_content) + loaded = manager.get_session(session.id) + assert loaded.messages[0].content == long_content + + def test_special_characters_in_user_id(self, manager): + """Session should handle special characters in user_id.""" + special_id = "user@example.com/test" + session = manager.create_session(user_id=special_id) + loaded = manager.get_session(session.id) + assert loaded.user_id == special_id + + def test_create_session_with_empty_user_id(self, manager): + session = manager.create_session(user_id="") + assert session.user_id == "" + loaded = manager.get_session(session.id) + assert loaded.user_id == "" + + def test_metadata_persists_after_add_message(self, manager): + """Metadata should survive add_message round-trips.""" + session = manager.create_session( + user_id="u1", metadata={"key": "value", "count": "0"} + ) + manager.add_message(session.id, "user", "hello") + loaded = manager.get_session(session.id) + assert loaded.metadata == {"key": "value", "count": "0"} + + def test_session_file_is_valid_json_with_indent(self, manager): + """Saved JSON should be indented for readability.""" + session = manager.create_session() + path = manager._path(session.id) + with open(path, encoding="utf-8") as f: + content = f.read() + # Indented JSON contains newlines + assert "\n" in content + + def test_add_message_to_multiple_sessions(self, manager): + """Messages in different sessions should be independent.""" + s1 = manager.create_session() + s2 = manager.create_session() + manager.add_message(s1.id, "user", "msg in s1") + manager.add_message(s2.id, "user", "msg in s2") + + loaded1 = manager.get_session(s1.id) + loaded2 = manager.get_session(s2.id) + assert len(loaded1.messages) == 1 + assert len(loaded2.messages) == 1 + assert loaded1.messages[0].content == "msg in 
s1" + assert loaded2.messages[0].content == "msg in s2" + class TestCorruptedJSON: """Bug #7: 损坏的 JSON 文件应返回 None 而不是崩溃""" diff --git a/tests/test_skills/test_builtins.py b/tests/test_skills/test_builtins.py new file mode 100644 index 0000000..e3511c1 --- /dev/null +++ b/tests/test_skills/test_builtins.py @@ -0,0 +1,282 @@ +"""Built-in Skills tests + +Tests for all skills defined in jojo_code.skills.builtins: +web_search, web_fetch, read_file, write_file, run_command, +analyze_code, format_json, validate_json, calculate, translate. +""" + +import json + +from jojo_code.skills.builtins import ( + analyze_code, + calculate, + format_json, + read_file, + run_command, + translate, + validate_json, + web_fetch, + web_search, + write_file, +) + +# --------------------------------------------------------------------------- +# web_search +# --------------------------------------------------------------------------- + + +class TestWebSearch: + def test_returns_string(self): + result = web_search("Python tutorial") + assert isinstance(result, str) + + def test_contains_query(self): + result = web_search("hello world") + assert "hello world" in result + + def test_empty_query(self): + result = web_search("") + assert isinstance(result, str) + + +# --------------------------------------------------------------------------- +# web_fetch +# --------------------------------------------------------------------------- + + +class TestWebFetch: + def test_returns_string(self): + result = web_fetch("https://example.com") + assert isinstance(result, str) + + def test_contains_url(self): + url = "https://example.com/page" + result = web_fetch(url) + assert url in result + + +# --------------------------------------------------------------------------- +# read_file +# --------------------------------------------------------------------------- + + +class TestReadFile: + def test_read_existing_file(self, tmp_path): + f = tmp_path / "test.txt" + f.write_text("hello world", encoding="utf-8") 
+ result = read_file(str(f)) + assert result == "hello world" + + def test_read_nonexistent_file(self): + result = read_file("/nonexistent/path/file.txt") + assert result.startswith("Error:") + + def test_read_with_custom_encoding(self, tmp_path): + f = tmp_path / "test.txt" + f.write_text("content", encoding="utf-8") + result = read_file(str(f), encoding="utf-8") + assert result == "content" + + def test_read_empty_file(self, tmp_path): + f = tmp_path / "empty.txt" + f.write_text("", encoding="utf-8") + result = read_file(str(f)) + assert result == "" + + +# --------------------------------------------------------------------------- +# write_file +# --------------------------------------------------------------------------- + + +class TestWriteFile: + def test_write_creates_file(self, tmp_path): + f = tmp_path / "output.txt" + result = write_file(str(f), "test content") + assert "Successfully wrote" in result + assert f.read_text(encoding="utf-8") == "test content" + + def test_write_overwrites_existing(self, tmp_path): + f = tmp_path / "output.txt" + f.write_text("old", encoding="utf-8") + write_file(str(f), "new") + assert f.read_text(encoding="utf-8") == "new" + + def test_write_to_invalid_path(self): + result = write_file("/nonexistent/dir/file.txt", "data") + assert result.startswith("Error:") + + def test_write_empty_content(self, tmp_path): + f = tmp_path / "empty.txt" + result = write_file(str(f), "") + assert "Successfully wrote" in result + assert f.read_text(encoding="utf-8") == "" + + +# --------------------------------------------------------------------------- +# run_command +# --------------------------------------------------------------------------- + + +class TestRunCommand: + def test_simple_echo(self): + result = run_command("echo hello") + assert "hello" in result + + def test_command_with_error(self): + result = run_command("ls /nonexistent_path_xyz_12345") + # Should return stderr or error, not crash + assert isinstance(result, str) + + def 
test_shell_false(self): + result = run_command("echo test", shell=False) + assert isinstance(result, str) + + +# --------------------------------------------------------------------------- +# analyze_code +# --------------------------------------------------------------------------- + + +class TestAnalyzeCode: + def test_analyze_python_file(self, tmp_path): + f = tmp_path / "sample.py" + content = 'def hello():\n return "hi"\n\n# comment\n\nprint(hello())\n' + f.write_text(content, encoding="utf-8") + result = analyze_code(str(f)) + assert isinstance(result, dict) + assert result["file"] == str(f) + # Implementation uses split("\n") which counts trailing empty line + assert result["lines"] == len(content.split("\n")) + assert result["code_lines"] > 0 + assert result["blank_lines"] >= 1 + assert "error" not in result + + def test_analyze_nonexistent_file(self): + result = analyze_code("/nonexistent/file.py") + assert isinstance(result, dict) + assert "error" in result + + def test_analyze_empty_file(self, tmp_path): + f = tmp_path / "empty.py" + f.write_text("", encoding="utf-8") + result = analyze_code(str(f)) + assert result["lines"] == 1 # split on \n gives [''] + assert result["code_lines"] == 0 + + +# --------------------------------------------------------------------------- +# format_json +# --------------------------------------------------------------------------- + + +class TestFormatJson: + def test_format_valid_json(self): + result = format_json('{"a": 1, "b": 2}') + parsed = json.loads(result) + assert parsed == {"a": 1, "b": 2} + + def test_format_with_custom_indent(self): + result = format_json('{"key": "value"}', indent=4) + assert " " in result + + def test_format_invalid_json(self): + result = format_json("not json") + assert result.startswith("Error:") + + def test_format_nested_json(self): + data = '{"outer": {"inner": [1, 2, 3]}}' + result = format_json(data) + parsed = json.loads(result) + assert parsed["outer"]["inner"] == [1, 2, 3] + + def 
test_format_preserves_unicode(self): + result = format_json('{"name": "jojo"}') + assert "jojo" in result + + +# --------------------------------------------------------------------------- +# validate_json +# --------------------------------------------------------------------------- + + +class TestValidateJson: + def test_valid_json(self): + result = validate_json('{"a": 1}') + assert result["valid"] is True + + def test_invalid_json(self): + result = validate_json("{bad}") + assert result["valid"] is False + assert "error" in result + + def test_empty_object(self): + result = validate_json("{}") + assert result["valid"] is True + + def test_array_json(self): + result = validate_json("[1, 2, 3]") + assert result["valid"] is True + + def test_string_json(self): + result = validate_json('"hello"') + assert result["valid"] is True + + def test_number_json(self): + result = validate_json("42") + assert result["valid"] is True + + +# --------------------------------------------------------------------------- +# calculate +# --------------------------------------------------------------------------- + + +class TestCalculate: + def test_simple_addition(self): + assert calculate("2+2") == "4" + + def test_multiplication(self): + assert calculate("3*7") == "21" + + def test_parentheses(self): + assert calculate("(2+3)*4") == "20" + + def test_decimal(self): + result = calculate("1.5+2.5") + assert result == "4.0" + + def test_division(self): + assert calculate("10/3") == str(10 / 3) + + def test_invalid_chars_rejected(self): + result = calculate("import os") + assert "Error" in result + + def test_letters_rejected(self): + result = calculate("abc") + assert "Error" in result + + def test_complex_expression(self): + assert calculate("(10+20)*(2+3)") == "150" + + +# --------------------------------------------------------------------------- +# translate +# --------------------------------------------------------------------------- + + +class TestTranslate: + def 
test_default_target_lang(self): + result = translate("Hello") + assert "en" in result + assert "Hello" in result + + def test_custom_target_lang(self): + result = translate("Hello", target_lang="zh") + assert "zh" in result + assert "Hello" in result + + def test_returns_string(self): + result = translate("test") + assert isinstance(result, str) diff --git a/tests/test_task/test_types.py b/tests/test_task/test_types.py index 28b7aa0..d95ee3c 100644 --- a/tests/test_task/test_types.py +++ b/tests/test_task/test_types.py @@ -1,13 +1,17 @@ """Task 类型和状态机测试 测试 Task 数据类的状态转换(pending→running→completed/failed/killed), -以及 TaskResult、TaskInput、TaskProgress 等辅助数据类。 +以及 TaskResult、TaskInput、TaskOutput、TaskProgress 等辅助数据类。 """ +from datetime import datetime + from jojo_code.task.types import ( Task, TaskInput, + TaskOutput, TaskPriority, + TaskProgress, TaskResult, TaskStatus, TaskType, @@ -154,3 +158,148 @@ def test_priority_values(self): assert TaskPriority.NORMAL.value == 5 assert TaskPriority.HIGH.value == 10 assert TaskPriority.CRITICAL.value == 20 + + +class TestTaskOutput: + """TaskOutput 数据类测试""" + + def test_creation_defaults(self): + output = TaskOutput() + assert output.data is None + assert output.error is None + assert output.logs == [] + + def test_creation_with_values(self): + output = TaskOutput(data="result", error="some error", logs=["log1", "log2"]) + assert output.data == "result" + assert output.error == "some error" + assert len(output.logs) == 2 + + def test_logs_list_is_independent(self): + out1 = TaskOutput() + out2 = TaskOutput() + out1.logs.append("entry") + assert len(out2.logs) == 0 + + +class TestTaskProgress: + """TaskProgress 数据类测试""" + + def test_creation_defaults(self): + p = TaskProgress(task_id="t1", status=TaskStatus.RUNNING) + assert p.task_id == "t1" + assert p.status == TaskStatus.RUNNING + assert p.progress == 0.0 + assert p.message == "" + assert isinstance(p.timestamp, datetime) + + def test_creation_with_values(self): + now = 
datetime.now() + p = TaskProgress( + task_id="t2", + status=TaskStatus.COMPLETED, + progress=100.0, + message="done", + timestamp=now, + ) + assert p.progress == 100.0 + assert p.message == "done" + assert p.timestamp == now + + +class TestTaskMetadata: + """Task 元数据和父任务测试""" + + def test_default_metadata_empty(self): + task = Task(id="t1", type=TaskType.BASH) + assert task.metadata == {} + + def test_custom_metadata(self): + task = Task( + id="t1", + type=TaskType.AGENT, + metadata={"source": "test", "priority_group": "dev"}, + ) + assert task.metadata["source"] == "test" + + def test_parent_id_default_none(self): + task = Task(id="t1", type=TaskType.BASH) + assert task.parent_id is None + + def test_parent_id_set(self): + task = Task(id="child1", type=TaskType.AGENT, parent_id="parent1") + assert task.parent_id == "parent1" + + def test_default_priority_is_normal(self): + task = Task(id="t1", type=TaskType.BASH) + assert task.priority == TaskPriority.NORMAL + + def test_created_at_is_set(self): + task = Task(id="t1", type=TaskType.BASH) + assert isinstance(task.created_at, datetime) + + +class TestTaskStateTransitions: + """Task 完整状态转换链测试""" + + def test_full_lifecycle_success(self): + task = Task(id="t1", type=TaskType.BASH) + assert task.status == TaskStatus.PENDING + assert task.is_running is False + assert task.is_done is False + + task.start() + assert task.status == TaskStatus.RUNNING + assert task.is_running is True + assert task.is_done is False + + result = TaskResult(success=True, output="ok") + task.complete(result) + assert task.status == TaskStatus.COMPLETED + assert task.is_running is False + assert task.is_done is True + + def test_full_lifecycle_failure(self): + task = Task(id="t1", type=TaskType.BASH) + task.start() + task.fail("oops") + assert task.status == TaskStatus.FAILED + assert task.is_done is True + assert task.output.error == "oops" + + def test_full_lifecycle_killed(self): + task = Task(id="t1", type=TaskType.BASH) + task.start() + 
task.kill() + assert task.status == TaskStatus.KILLED + assert task.is_done is True + + def test_duration_positive_after_start(self): + task = Task(id="t1", type=TaskType.BASH) + task.start() + # Duration should be >= 0 even immediately after start + assert task.duration >= 0.0 + + def test_started_at_set_on_start(self): + task = Task(id="t1", type=TaskType.BASH) + assert task.started_at is None + task.start() + assert task.started_at is not None + + def test_completed_at_set_on_complete(self): + task = Task(id="t1", type=TaskType.BASH) + task.start() + task.complete(TaskResult(success=True)) + assert task.completed_at is not None + + def test_completed_at_set_on_fail(self): + task = Task(id="t1", type=TaskType.BASH) + task.start() + task.fail("err") + assert task.completed_at is not None + + def test_completed_at_set_on_kill(self): + task = Task(id="t1", type=TaskType.BASH) + task.start() + task.kill() + assert task.completed_at is not None