diff --git a/dev-suite/src/orchestrator.py b/dev-suite/src/orchestrator.py index 88ad87c..f9f222a 100644 --- a/dev-suite/src/orchestrator.py +++ b/dev-suite/src/orchestrator.py @@ -63,6 +63,7 @@ parse_generated_code, validate_paths_for_workspace, ) +from .tools.github_fetch import fetch_refs_as_context_items from .tracing import add_trace_event, create_trace_config load_dotenv() @@ -757,21 +758,42 @@ def _resolve_candidate(raw: str) -> Path | None: seen.add(key) ordered_files.append(f) - if not ordered_files: - trace.append("gather_context: no relevant files found") - logger.info("[CONTEXT] No files to gather for task") - return {"gathered_context": [], "trace": trace} + gathered: list[dict] = [] + if ordered_files: + gathered = _read_context_files( + ordered_files, workspace_root, allowed_root=repo_root + ) - gathered = _read_context_files( - ordered_files, workspace_root, allowed_root=repo_root + # Source 4: GitHub issue/PR pre-fetch (issue #193). + # Scans the task description for refs like "issue #113", + # "fixes #42", or "owner/repo#99" and fetches their summaries so + # the Architect has the context without needing tools. Best-effort: + # missing token, network errors, and 404s are silently skipped. 
+ github_items = await fetch_refs_as_context_items( + task_description, + default_owner=os.getenv("GITHUB_OWNER", ""), + default_repo=os.getenv("GITHUB_REPO", ""), + token=os.getenv("GITHUB_TOKEN", ""), + max_refs=5, + max_chars=2000, ) + if github_items: + gathered.extend(github_items) + trace.append( + f"gather_context: pre-fetched {len(github_items)} GitHub ref(s)" + ) + + if not gathered: + trace.append("gather_context: no relevant files found") + logger.info("[CONTEXT] No files or GitHub refs to gather for task") + return {"gathered_context": [], "trace": trace} total_tokens = sum(_estimate_tokens(f["content"]) for f in gathered) trace.append( - f"gather_context: gathered {len(gathered)} files (~{total_tokens} tokens)" + f"gather_context: gathered {len(gathered)} items (~{total_tokens} tokens)" ) logger.info( - "[CONTEXT] Gathered %d files (~%d tokens) for Architect", + "[CONTEXT] Gathered %d items (~%d tokens) for Architect", len(gathered), total_tokens, ) diff --git a/dev-suite/src/tools/github_fetch.py b/dev-suite/src/tools/github_fetch.py new file mode 100644 index 0000000..f97e13e --- /dev/null +++ b/dev-suite/src/tools/github_fetch.py @@ -0,0 +1,307 @@ +"""GitHub issue/PR pre-fetch helpers (issue #193). + +Scans a task description for issue/PR references and fetches their +titles + bodies via the GitHub REST API so the Architect has the +context without needing to call tools. Kept intentionally small so +it can be used from `gather_context_node` without spinning up a full +ToolProvider. + +Public surface: + extract_github_refs(text, default_owner, default_repo, max_refs) + -> list[GitHubRef] + fetch_issue_or_pr(owner, repo, number, token, max_chars, timeout) + -> dict | None + fetch_refs_as_context_items(text, default_owner, default_repo, + token, max_refs, max_chars) + -> list[dict] # shape matches gathered_context entries + +Design notes: +- Uses httpx (already a dependency) to match the existing pattern in + LocalToolProvider._github_read_diff. 
+- Best-effort: any failure (missing token, network error, 404) returns + None for that ref. Callers treat pre-fetch as optional context. +- Issue and PR refs both hit /issues/{n}; the GitHub REST API returns + PR data from this endpoint as an issue with a pull_request field. +""" + +from __future__ import annotations + +import logging +import re +from dataclasses import dataclass + +import httpx + +logger = logging.getLogger(__name__) + +# GitHub reference patterns (issue #193). +# +# Matches are produced by two complementary patterns: +# +# 1. Cross-repo: "Abernaughty/agent-dev#113" +# No qualifying word needed — `owner/repo#N` is GitHub's native +# auto-link syntax and is already unambiguous. +# +# 2. Same-repo with qualifier: "issue #113", "fixes #42", +# "closes #7", "refs #99", "see #1", "review #12", "address #8", +# "gh #5", "PR #113", "pull request #113", "pull #113", "pulls #5". +# +# Intentionally excludes bare `#N` (no qualifying word, no owner/repo +# prefix) to avoid false matches on markdown headings, CSS colors, +# anchor fragments, etc. +# Note: we group the optional suffix explicitly (e.g. `fix(?:e[sd])?`) +# because `fixes?` would expand to "fixe"/"fixes", missing the bare +# "fix". Accepts GitHub's full closing-keyword set: close[sd], fix[e[sd]], +# resolve[sd]. 
+_QUALIFIER = ( + r"issues?|fix(?:e[sd])?|close[sd]?|resolve[sd]?|refs?|" + r"see|review|address|gh|pulls?|pull\s*request|pr" +) + +_CROSS_REPO_PATTERN = re.compile( + r"(?<![\w/])(?P<owner>[\w.-]+)/(?P<repo>[\w.-]+)" + r"#(?P<number>\d+)\b", +) + +_SAME_REPO_PATTERN = re.compile( + rf"(?<![\w/])(?:{_QUALIFIER})\s*#(?P<number>\d+)\b", + re.IGNORECASE, +) + + +@dataclass(frozen=True) +class GitHubRef: + """A parsed GitHub issue/PR reference.""" + + owner: str + repo: str + number: int + + @property + def key(self) -> tuple[str, str, int]: + return (self.owner.lower(), self.repo.lower(), self.number) + + @property + def synthetic_path(self) -> str: + """A path-like identifier used as the `path` in gathered_context.""" + return f"github://{self.owner}/{self.repo}/issues/{self.number}" + + +def extract_github_refs( + text: str, + default_owner: str, + default_repo: str, + max_refs: int = 5, +) -> list[GitHubRef]: + """Extract unique issue/PR references from free-form text. + + References without an explicit `owner/repo` prefix use the default. + If default_owner/default_repo are empty, same-repo refs are dropped. + + Cross-repo refs (e.g. "Abernaughty/agent-dev#113") are always kept. + + Returns at most `max_refs` unique refs, preserving first-seen order + across both patterns (cross-repo and same-repo matches are merged + by starting offset in `text`). + """ + if not text: + return [] + + # Gather candidates from both patterns with their start offsets so + # we can merge in source order. + candidates: list[tuple[int, GitHubRef]] = [] + + for match in _CROSS_REPO_PATTERN.finditer(text): + try: + number = int(match.group("number")) + except ValueError: + continue + if number <= 0: + continue + candidates.append(( + match.start(), + GitHubRef( + owner=match.group("owner"), + repo=match.group("repo"), + number=number, + ), + )) + + if default_owner and default_repo: + # Track spans already claimed by cross-repo matches so we don't + # double-count a ref like "foo/bar#1" as also matching the + # same-repo pattern via "bar#1".
+ cross_spans = [ + (m.start("number"), m.end("number")) + for m in _CROSS_REPO_PATTERN.finditer(text) + ] + for match in _SAME_REPO_PATTERN.finditer(text): + span = (match.start("number"), match.end("number")) + if any(s <= span[0] and span[1] <= e for s, e in cross_spans): + continue + try: + number = int(match.group("number")) + except ValueError: + continue + if number <= 0: + continue + candidates.append(( + match.start(), + GitHubRef( + owner=default_owner, + repo=default_repo, + number=number, + ), + )) + + # Sort by start offset for deterministic first-seen order. + candidates.sort(key=lambda pair: pair[0]) + + seen: set[tuple[str, str, int]] = set() + refs: list[GitHubRef] = [] + for _start, ref in candidates: + if ref.key in seen: + continue + seen.add(ref.key) + refs.append(ref) + if len(refs) >= max_refs: + break + + return refs + + +def _summarize_issue_payload(data: dict, max_chars: int) -> tuple[str, bool]: + """Build a compact text summary from the GitHub issue/PR JSON. + + Returns (summary_text, truncated_flag). + """ + number = data.get("number", "?") + title = (data.get("title") or "").strip() + state = (data.get("state") or "").strip() + is_pr = "pull_request" in data + kind = "PR" if is_pr else "Issue" + labels = [ + label.get("name", "") + for label in (data.get("labels") or []) + if isinstance(label, dict) and label.get("name") + ] + body = (data.get("body") or "").strip() + + header_parts = [f"{kind} #{number}: {title}"] + if state: + header_parts.append(f"State: {state}") + if labels: + header_parts.append(f"Labels: {', '.join(labels)}") + header = "\n".join(header_parts) + + if not body: + return header, False + + # Leave room for header + a blank line + body_budget = max(0, max_chars - len(header) - 2) + truncated = False + if len(body) > body_budget: + body = body[:body_budget].rstrip() + "\n... 
[truncated]" + truncated = True + + return f"{header}\n\n{body}", truncated + + +async def fetch_issue_or_pr( + owner: str, + repo: str, + number: int, + token: str, + max_chars: int = 2000, + timeout: float = 10.0, +) -> dict | None: + """Fetch a single GitHub issue or PR as a gathered_context-shaped dict. + + Returns a dict matching the `gathered_context` entry shape: + {"path": "github://owner/repo/issues/N", + "content": "<issue/PR summary text>", + "truncated": bool, + "source": "github_issue"} + + Returns None on any failure (missing token, network error, non-200). + """ + if not token: + logger.debug("[GH-FETCH] No GITHUB_TOKEN; skipping %s/%s#%d", owner, repo, number) + return None + if not owner or not repo: + return None + + url = f"https://api.github.com/repos/{owner}/{repo}/issues/{number}" + headers = { + "Accept": "application/vnd.github.v3+json", + "Authorization": f"Bearer {token}", + "X-GitHub-Api-Version": "2022-11-28", + } + + try: + async with httpx.AsyncClient(timeout=timeout) as client: + response = await client.get(url, headers=headers) + except httpx.HTTPError as exc: + logger.debug( + "[GH-FETCH] Network error fetching %s/%s#%d: %s", + owner, repo, number, exc, + ) + return None + + if response.status_code != 200: + logger.debug( + "[GH-FETCH] %s/%s#%d -> HTTP %d", + owner, repo, number, response.status_code, + ) + return None + + try: + data = response.json() + except ValueError: + return None + + if not isinstance(data, dict): + return None + + summary, truncated = _summarize_issue_payload(data, max_chars=max_chars) + return { + "path": f"github://{owner}/{repo}/issues/{number}", + "content": summary, + "truncated": truncated, + "source": "github_issue", + } + + +async def fetch_refs_as_context_items( + text: str, + default_owner: str, + default_repo: str, + token: str, + max_refs: int = 5, + max_chars: int = 2000, +) -> list[dict]: + """Extract and fetch issue/PR refs from text as context entries.
+ + Returns a (possibly empty) list of gathered_context-shaped dicts, + skipping any refs that failed to fetch. Best-effort: never raises + for network/auth errors — the caller can continue without them. + """ + refs = extract_github_refs(text, default_owner, default_repo, max_refs=max_refs) + if not refs: + return [] + + items: list[dict] = [] + for ref in refs: + item = await fetch_issue_or_pr( + ref.owner, ref.repo, ref.number, + token=token, max_chars=max_chars, + ) + if item is not None: + items.append(item) + + if items: + logger.info( + "[GH-FETCH] Pre-fetched %d/%d GitHub ref(s) for gather_context", + len(items), len(refs), + ) + return items diff --git a/dev-suite/src/tools/mcp_bridge.py b/dev-suite/src/tools/mcp_bridge.py index eb3fcdc..95aa3f2 100644 --- a/dev-suite/src/tools/mcp_bridge.py +++ b/dev-suite/src/tools/mcp_bridge.py @@ -32,6 +32,20 @@ logger = logging.getLogger(__name__) +# Read-only tool allowlist (issue #193). +# +# Used by agents that should be able to explore the workspace and +# GitHub without making changes — currently the Planner and the +# Architect's optional Phase 2. Keep this in sync with the tool +# registry in provider.py: any tool whose handler performs only +# reads (no filesystem writes, no GitHub mutations) belongs here. +READONLY_TOOLS: frozenset[str] = frozenset({ + "filesystem_read", + "filesystem_list", + "github_read_diff", +}) + + class MCPConfigError(Exception): """Raised when mcp-config.json is invalid or missing.""" @@ -228,7 +242,10 @@ def _run_async(coro): # -- Tool factories -- -def get_tools(provider: ToolProvider) -> list[Tool]: +def get_tools( + provider: ToolProvider, + tool_filter: set[str] | frozenset[str] | None = None, +) -> list[Tool]: """Create LangChain Tool objects from an async ToolProvider (sync). 
Dynamically generates Tools from the provider's list_tools() @@ -242,16 +259,25 @@ def get_tools(provider: ToolProvider) -> list[Tool]: Args: provider: Any async ToolProvider + tool_filter: Optional allowlist of tool names. When provided, + only tools whose name is in the set are returned. Typically + set to READONLY_TOOLS for read-only agents (Planner, + Architect Phase 2) — see issue #193. Returns: List of LangChain Tool objects the agents can use. """ # list_tools() is async, so we need to bridge here definitions = _run_async(provider.list_tools()) + if tool_filter is not None: + definitions = [d for d in definitions if d.name in tool_filter] return _build_langchain_tools(provider, definitions) -async def aget_tools(provider: ToolProvider) -> list[Tool]: +async def aget_tools( + provider: ToolProvider, + tool_filter: set[str] | frozenset[str] | None = None, +) -> list[Tool]: """Create LangChain Tool objects from an async ToolProvider (async). ARCH-3: Async variant of get_tools(). Avoids the _run_async bridge @@ -259,11 +285,17 @@ async def aget_tools(provider: ToolProvider) -> list[Tool]: Args: provider: Any async ToolProvider + tool_filter: Optional allowlist of tool names. When provided, + only tools whose name is in the set are returned. Typically + set to READONLY_TOOLS for read-only agents (Planner, + Architect Phase 2) — see issue #193. Returns: List of LangChain Tool objects the agents can use. 
""" definitions = await provider.list_tools() + if tool_filter is not None: + definitions = [d for d in definitions if d.name in tool_filter] return _build_langchain_tools(provider, definitions) diff --git a/dev-suite/tests/test_gather_context.py b/dev-suite/tests/test_gather_context.py index 1a7dd7d..ee7ec14 100644 --- a/dev-suite/tests/test_gather_context.py +++ b/dev-suite/tests/test_gather_context.py @@ -244,7 +244,7 @@ async def test_trace_reports_file_count(self, tmp_path): "status": WorkflowStatus.PLANNING, } result = await gather_context_node(state) - assert any("gathered 1 files" in t for t in result["trace"]) + assert any("gathered 1 items" in t for t in result["trace"]) @pytest.mark.asyncio async def test_no_workspace_root_uses_default(self, tmp_path, monkeypatch): @@ -334,6 +334,111 @@ async def test_finds_sibling_file_via_repo_root(self, tmp_path): paths = [f["path"] for f in result["gathered_context"]] assert any("BottomPanel.svelte" in p for p in paths), paths + # -- GitHub pre-fetch (issue #193) -- + + @pytest.mark.asyncio + async def test_github_ref_pre_fetch_injects_items(self, tmp_path, monkeypatch): + """gather_context_node pre-fetches GitHub issues referenced in task.""" + from unittest.mock import AsyncMock, MagicMock, patch + + (tmp_path / ".git").mkdir() + monkeypatch.setenv("GITHUB_TOKEN", "t") + monkeypatch.setenv("GITHUB_OWNER", "Abernaughty") + monkeypatch.setenv("GITHUB_REPO", "agent-dev") + + payload = { + "number": 113, + "title": "Gate test issue", + "state": "open", + "body": "Run the end-to-end gate test", + } + response = MagicMock() + response.status_code = 200 + response.json = MagicMock(return_value=payload) + + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + mock_client.get.return_value = response + + state: GraphState = { + "task_description": "Please fix issue #113 — it's blocking us", + "workspace_root": str(tmp_path), + "trace": [], + "status": 
WorkflowStatus.PLANNING, + } + with patch("src.tools.github_fetch.httpx.AsyncClient", return_value=mock_client): + result = await gather_context_node(state) + + gh_items = [ + c for c in result["gathered_context"] + if c.get("source") == "github_issue" + ] + assert len(gh_items) == 1 + assert gh_items[0]["path"] == "github://Abernaughty/agent-dev/issues/113" + assert "Gate test issue" in gh_items[0]["content"] + + @pytest.mark.asyncio + async def test_no_github_refs_no_fetches(self, tmp_path, monkeypatch): + """Tasks with no issue/PR refs don't hit the network.""" + from unittest.mock import AsyncMock, patch + + (tmp_path / ".git").mkdir() + monkeypatch.setenv("GITHUB_TOKEN", "t") + monkeypatch.setenv("GITHUB_OWNER", "o") + monkeypatch.setenv("GITHUB_REPO", "r") + + mock_client = AsyncMock() + state: GraphState = { + "task_description": "Add a greet function to greet.py", + "workspace_root": str(tmp_path), + "trace": [], + "status": WorkflowStatus.PLANNING, + } + with patch("src.tools.github_fetch.httpx.AsyncClient", return_value=mock_client): + result = await gather_context_node(state) + + # No network calls made + mock_client.__aenter__.assert_not_called() + gh_items = [ + c for c in result.get("gathered_context", []) + if c.get("source") == "github_issue" + ] + assert gh_items == [] + + @pytest.mark.asyncio + async def test_github_fetch_failure_degrades_gracefully( + self, tmp_path, monkeypatch, + ): + """404 / missing token / network error don't break gather_context.""" + from unittest.mock import AsyncMock, MagicMock, patch + + (tmp_path / ".git").mkdir() + monkeypatch.setenv("GITHUB_TOKEN", "t") + monkeypatch.setenv("GITHUB_OWNER", "o") + monkeypatch.setenv("GITHUB_REPO", "r") + + response = MagicMock() + response.status_code = 404 + response.json = MagicMock(return_value={}) + + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + mock_client.get.return_value = response + + state: 
GraphState = { + "task_description": "fixes #999", + "workspace_root": str(tmp_path), + "trace": [], + "status": WorkflowStatus.PLANNING, + } + with patch("src.tools.github_fetch.httpx.AsyncClient", return_value=mock_client): + result = await gather_context_node(state) + + # No exception; empty gathered_context + assert result["gathered_context"] == [] + @pytest.mark.asyncio async def test_rejects_path_outside_repo_root(self, tmp_path): (tmp_path / ".git").mkdir() diff --git a/dev-suite/tests/test_github_fetch.py b/dev-suite/tests/test_github_fetch.py new file mode 100644 index 0000000..2170299 --- /dev/null +++ b/dev-suite/tests/test_github_fetch.py @@ -0,0 +1,356 @@ +"""Tests for GitHub issue/PR pre-fetch (issue #193). + +Covers: +- Ref pattern extraction (same-repo, cross-repo, dedupe, max_refs cap) +- Rejection of bare `#N` without qualifier +- fetch_issue_or_pr: success (issue body), truncation, PR detection, + missing token, non-200, network error, malformed JSON +- fetch_refs_as_context_items end-to-end with mocked httpx +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from src.tools.github_fetch import ( + GitHubRef, + extract_github_refs, + fetch_issue_or_pr, + fetch_refs_as_context_items, +) + +# --- extract_github_refs --------------------------------------------------- + + +class TestExtractGithubRefs: + def test_empty_text(self): + assert extract_github_refs("", "o", "r") == [] + assert extract_github_refs(None, "o", "r") == [] # type: ignore[arg-type] + + def test_simple_issue_reference(self): + refs = extract_github_refs("fix issue #113", "owner", "repo") + assert refs == [GitHubRef(owner="owner", repo="repo", number=113)] + + def test_all_qualifiers_accepted(self): + qualifiers = [ + # singular, plural, and past-tense closing keywords + "issue", "issues", + "fix", "fixes", "fixed", + "close", "closes", "closed", + "resolve", "resolves", "resolved", + "ref", "refs", 
"see", "review", + "address", "gh", "pr", "pull", "pulls", + ] + for q in qualifiers: + refs = extract_github_refs(f"{q} #42", "o", "r") + assert len(refs) == 1, f"qualifier '{q}' did not match" + assert refs[0].number == 42 + + def test_pull_request_multiword_qualifier(self): + refs = extract_github_refs("see pull request #77", "o", "r") + assert refs == [GitHubRef(owner="o", repo="r", number=77)] + + def test_case_insensitive_qualifier(self): + refs = extract_github_refs("FIXES #5 and Closes #6", "o", "r") + assert len(refs) == 2 + assert {r.number for r in refs} == {5, 6} + + def test_cross_repo_reference(self): + refs = extract_github_refs( + "see Abernaughty/agent-dev#113", "default", "default" + ) + assert refs == [ + GitHubRef(owner="Abernaughty", repo="agent-dev", number=113), + ] + + def test_cross_repo_without_default(self): + # Cross-repo still works even without default owner/repo + refs = extract_github_refs("foo/bar#1", "", "") + assert refs == [GitHubRef(owner="foo", repo="bar", number=1)] + + def test_bare_hash_number_rejected(self): + # No qualifier, no cross-repo prefix — must not match + refs = extract_github_refs( + "Heading\n# 113\n\nSome #456 random text", + "owner", "repo", + ) + assert refs == [] + + def test_markdown_heading_not_matched(self): + refs = extract_github_refs( + "# Main heading\n## Subheading", + "o", "r", + ) + assert refs == [] + + def test_hex_color_not_matched(self): + # #abc123 starts with letter, regex requires digits only after # + refs = extract_github_refs("color #abc123 and #deadbeef", "o", "r") + assert refs == [] + + def test_issue_prefix_on_longer_word_not_matched(self): + # "issue123" (no space) should not be treated as qualifier + refs = extract_github_refs("see issue123\n#42 here", "o", "r") + assert refs == [] + + def test_default_owner_repo_used(self): + refs = extract_github_refs("fixes #42", "myowner", "myrepo") + assert refs[0].owner == "myowner" + assert refs[0].repo == "myrepo" + + def 
test_same_repo_dropped_when_no_default(self): + # Same-repo refs require a default owner/repo to resolve + refs = extract_github_refs("fixes #42", "", "") + assert refs == [] + + def test_deduplication(self): + refs = extract_github_refs( + "fixes #10, closes #10, refs #10", "o", "r", + ) + assert len(refs) == 1 + assert refs[0].number == 10 + + def test_dedup_across_same_and_cross_repo(self): + # Same number from different repos is NOT a duplicate + refs = extract_github_refs( + "fixes #10 and foo/bar#10", "o", "r", + ) + assert len(refs) == 2 + keys = {r.key for r in refs} + assert ("o", "r", 10) in keys + assert ("foo", "bar", 10) in keys + + def test_cross_repo_not_double_counted(self): + # "foo/bar#1" must not also match "bar#1" via same-repo + refs = extract_github_refs("see foo/bar#1", "o", "r") + assert len(refs) == 1 + assert refs[0].owner == "foo" + + def test_max_refs_cap(self): + text = " ".join(f"fixes #{i}" for i in range(1, 20)) + refs = extract_github_refs(text, "o", "r", max_refs=5) + assert len(refs) == 5 + # First-seen order preserved + assert [r.number for r in refs] == [1, 2, 3, 4, 5] + + def test_preserves_first_seen_order(self): + refs = extract_github_refs( + "see foo/bar#99 and fixes #1 then closes #2", + "o", "r", + ) + assert [r.number for r in refs] == [99, 1, 2] + + def test_zero_and_negative_rejected(self): + refs = extract_github_refs("fixes #0", "o", "r") + assert refs == [] + + def test_synthetic_path(self): + ref = GitHubRef(owner="Abernaughty", repo="agent-dev", number=113) + assert ref.synthetic_path == "github://Abernaughty/agent-dev/issues/113" + + +# --- fetch_issue_or_pr ----------------------------------------------------- + + +def _make_response(status_code=200, json_data=None): + """Build a mock httpx.Response-like object.""" + response = MagicMock() + response.status_code = status_code + response.json = MagicMock(return_value=json_data or {}) + return response + + +class TestFetchIssueOrPr: + @pytest.mark.asyncio + async 
def test_returns_none_without_token(self): + result = await fetch_issue_or_pr("o", "r", 1, token="") + assert result is None + + @pytest.mark.asyncio + async def test_returns_none_without_owner_repo(self): + result = await fetch_issue_or_pr("", "r", 1, token="t") + assert result is None + result = await fetch_issue_or_pr("o", "", 1, token="t") + assert result is None + + @pytest.mark.asyncio + async def test_success_issue(self): + payload = { + "number": 113, + "title": "Gate test", + "state": "open", + "body": "The gate test body.", + "labels": [{"name": "phase/2"}, {"name": "P0"}], + } + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + mock_client.get.return_value = _make_response(200, payload) + + with patch("src.tools.github_fetch.httpx.AsyncClient", return_value=mock_client): + result = await fetch_issue_or_pr("owner", "repo", 113, token="t") + + assert result is not None + assert result["path"] == "github://owner/repo/issues/113" + assert result["source"] == "github_issue" + assert result["truncated"] is False + assert "Issue #113: Gate test" in result["content"] + assert "State: open" in result["content"] + assert "Labels: phase/2, P0" in result["content"] + assert "The gate test body." 
in result["content"] + + @pytest.mark.asyncio + async def test_success_pr_marked(self): + payload = { + "number": 200, + "title": "A PR", + "state": "open", + "body": "diff coming", + "pull_request": {"url": "..."}, # presence indicates PR + } + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + mock_client.get.return_value = _make_response(200, payload) + + with patch("src.tools.github_fetch.httpx.AsyncClient", return_value=mock_client): + result = await fetch_issue_or_pr("o", "r", 200, token="t") + + assert result is not None + assert "PR #200" in result["content"] + + @pytest.mark.asyncio + async def test_body_truncated(self): + big_body = "X" * 5000 + payload = {"number": 1, "title": "t", "state": "open", "body": big_body} + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + mock_client.get.return_value = _make_response(200, payload) + + with patch("src.tools.github_fetch.httpx.AsyncClient", return_value=mock_client): + result = await fetch_issue_or_pr("o", "r", 1, token="t", max_chars=500) + + assert result is not None + assert result["truncated"] is True + assert "[truncated]" in result["content"] + # Overall content respects budget (roughly) + assert len(result["content"]) <= 600 + + @pytest.mark.asyncio + async def test_non_200_returns_none(self): + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + mock_client.get.return_value = _make_response(404, {}) + + with patch("src.tools.github_fetch.httpx.AsyncClient", return_value=mock_client): + result = await fetch_issue_or_pr("o", "r", 999, token="t") + assert result is None + + @pytest.mark.asyncio + async def test_network_error_returns_none(self): + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + 
mock_client.get.side_effect = httpx.ConnectError("boom") + + with patch("src.tools.github_fetch.httpx.AsyncClient", return_value=mock_client): + result = await fetch_issue_or_pr("o", "r", 1, token="t") + assert result is None + + @pytest.mark.asyncio + async def test_malformed_json_returns_none(self): + response = MagicMock() + response.status_code = 200 + response.json = MagicMock(side_effect=ValueError("bad json")) + + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + mock_client.get.return_value = response + + with patch("src.tools.github_fetch.httpx.AsyncClient", return_value=mock_client): + result = await fetch_issue_or_pr("o", "r", 1, token="t") + assert result is None + + @pytest.mark.asyncio + async def test_empty_body_ok(self): + payload = {"number": 1, "title": "Title", "state": "closed", "body": None} + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + mock_client.get.return_value = _make_response(200, payload) + + with patch("src.tools.github_fetch.httpx.AsyncClient", return_value=mock_client): + result = await fetch_issue_or_pr("o", "r", 1, token="t") + + assert result is not None + assert result["truncated"] is False + assert "Issue #1: Title" in result["content"] + + +# --- fetch_refs_as_context_items ------------------------------------------ + + +class TestFetchRefsAsContextItems: + @pytest.mark.asyncio + async def test_no_refs_returns_empty(self): + result = await fetch_refs_as_context_items( + "just some text", "o", "r", token="t", + ) + assert result == [] + + @pytest.mark.asyncio + async def test_best_effort_skips_failed_fetches(self): + """When one fetch fails, others still succeed.""" + payload_ok = { + "number": 10, "title": "OK", "state": "open", "body": "body" + } + + call_count = {"n": 0} + + def side_effect(*args, **kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + return 
_make_response(404, {}) # first fails + return _make_response(200, payload_ok) + + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + mock_client.get.side_effect = side_effect + + with patch("src.tools.github_fetch.httpx.AsyncClient", return_value=mock_client): + items = await fetch_refs_as_context_items( + "fixes #9, closes #10", "o", "r", token="t", + ) + + assert len(items) == 1 + assert items[0]["path"] == "github://o/r/issues/10" + + @pytest.mark.asyncio + async def test_no_token_returns_empty(self): + items = await fetch_refs_as_context_items( + "fixes #1", "o", "r", token="", + ) + assert items == [] + + @pytest.mark.asyncio + async def test_respects_max_refs(self): + payload = {"number": 1, "title": "t", "state": "open", "body": "b"} + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + mock_client.get.return_value = _make_response(200, payload) + + text = " ".join(f"fixes #{i}" for i in range(1, 20)) + with patch("src.tools.github_fetch.httpx.AsyncClient", return_value=mock_client): + items = await fetch_refs_as_context_items( + text, "o", "r", token="t", max_refs=3, + ) + + assert len(items) == 3 diff --git a/dev-suite/tests/test_mcp_tools.py b/dev-suite/tests/test_mcp_tools.py index a9344b4..8e1efcb 100644 --- a/dev-suite/tests/test_mcp_tools.py +++ b/dev-suite/tests/test_mcp_tools.py @@ -657,6 +657,64 @@ def test_tools_have_coroutine(self, provider): assert tool.coroutine is not None +# ============================================================ +# Read-only tool filtering (issue #193) +# ============================================================ + + +class TestToolFilter: + """Verify READONLY_TOOLS allowlist + tool_filter parameter.""" + + @pytest.fixture + def provider(self, tmp_path): + (tmp_path / "a.txt").write_text("a", encoding="utf-8") + return LocalToolProvider(workspace_root=tmp_path) + + def 
test_readonly_tools_is_frozenset(self): + from src.tools.mcp_bridge import READONLY_TOOLS + assert isinstance(READONLY_TOOLS, frozenset) + + def test_readonly_tools_excludes_writes(self): + from src.tools.mcp_bridge import READONLY_TOOLS + for name in ("filesystem_write", "filesystem_patch", "github_create_pr"): + assert name not in READONLY_TOOLS + + def test_readonly_tools_includes_reads(self): + from src.tools.mcp_bridge import READONLY_TOOLS + assert "filesystem_read" in READONLY_TOOLS + assert "filesystem_list" in READONLY_TOOLS + assert "github_read_diff" in READONLY_TOOLS + + def test_get_tools_no_filter_returns_all(self, provider): + assert len(get_tools(provider)) == 6 + + def test_get_tools_with_readonly_filter(self, provider): + from src.tools.mcp_bridge import READONLY_TOOLS + tools = get_tools(provider, tool_filter=READONLY_TOOLS) + names = {t.name for t in tools} + assert "filesystem_write" not in names + assert "filesystem_patch" not in names + assert "github_create_pr" not in names + # Read-only tools that exist in this provider + assert "filesystem_read" in names + assert "filesystem_list" in names + + def test_get_tools_with_empty_filter(self, provider): + assert get_tools(provider, tool_filter=set()) == [] + + def test_get_tools_filter_drops_unknown_names(self, provider): + tools = get_tools(provider, tool_filter={"nonexistent_tool"}) + assert tools == [] + + @pytest.mark.asyncio + async def test_aget_tools_with_readonly_filter(self, provider): + from src.tools.mcp_bridge import READONLY_TOOLS, aget_tools + tools = await aget_tools(provider, tool_filter=READONLY_TOOLS) + names = {t.name for t in tools} + assert "filesystem_write" not in names + assert "filesystem_read" in names + + # ============================================================ # Integration: full filesystem round-trip # ============================================================