From d6cb996760cf8e4910c3e14e3f68c15bd50fa4cb Mon Sep 17 00:00:00 2001 From: redreceipt Date: Thu, 4 Jun 2026 12:09:34 -0400 Subject: [PATCH] Harden stale review post GitHub fetches --- github.py | 93 ++++++++++++++++++++++--------- tests/test_gql_client_requests.py | 69 +++++++++++++++++++++++ 2 files changed, 136 insertions(+), 26 deletions(-) diff --git a/github.py b/github.py index 7526fa9..cac9f37 100644 --- a/github.py +++ b/github.py @@ -1,3 +1,4 @@ +import logging import os import threading from concurrent.futures import ThreadPoolExecutor @@ -18,6 +19,20 @@ token = os.getenv("GITHUB_TOKEN") headers = {"Authorization": f"bearer {token}"} +TRACKED_REPOSITORIES = ( + "apollosproject/apollos-platforms", + "apollosproject/apollos-cluster", + "apollosproject/apollos-admin", + "apollosproject/admin-transcriptions", + "apollosproject/apollos-shovel", + "apollosproject/apollos-embeds", + "differential/crossroads-anywhere", +) + + +class GitHubDataError(RuntimeError): + """Raised when GitHub data would otherwise be silently incomplete.""" + _thread_local = threading.local() @@ -51,21 +66,21 @@ def _execute(query, variable_values=None): } +def _format_failure(name: str, exc: Exception) -> str: + message = str(exc) + if message: + return f"{name}: {message}" + return f"{name}: {type(exc).__name__}" + + @lru_cache(maxsize=1) -def get_repo_ids(): +def get_repo_ids_by_name(): if not token: - return [] + return {} # List of repositories to track in the format "owner/name". - repos = [ - "apollosproject/apollos-platforms", - "apollosproject/apollos-cluster", - "apollosproject/apollos-admin", - "apollosproject/admin-transcriptions", - "apollosproject/apollos-shovel", - "apollosproject/apollos-embeds", - "differential/crossroads-anywhere", - ] - ids = [] + repos = TRACKED_REPOSITORIES + ids_by_name = {} + failures = [] # GraphQL query for fetching a repository ID by owner and name. repo_id_query = gql( """ @@ -85,15 +100,27 @@ def get_repo_ids(): params = {"owner": owner, "name": name} try: data = _execute(repo_id_query, variable_values=params) - except Exception: + except Exception as exc: + logging.exception("Failed to look up GitHub repository ID for %s", full_name) + failures.append(_format_failure(full_name, exc)) continue - ids.append(data["repository"]["id"]) - return ids + repository = data.get("repository") if data else None + repo_id = repository.get("id") if repository else None + if not repo_id: + failures.append(f"{full_name}: repository was not returned by GitHub") + continue + ids_by_name[full_name] = repo_id + if failures: + raise GitHubDataError( + "Failed to look up all tracked GitHub repositories: " + "; ".join(failures) + ) + return ids_by_name -def get_prs(repo_id, pr_states): +def get_prs(repo_id, pr_states, repo_name=None): if not token: return [] + repo_context = repo_name or repo_id query = gql( """ query PRs ($repo_id: ID!, $pr_states: [PullRequestState!], $cursor: String) { @@ -173,9 +200,16 @@ def get_prs(repo_id, pr_states): params = {"repo_id": repo_id, "pr_states": pr_states, "cursor": cursor} try: data = _execute(query, variable_values=params) - except Exception: - return [] - payload = data["node"]["pullRequests"] + except Exception as exc: + raise GitHubDataError( + f"Failed to fetch GitHub pull requests for {repo_context}" + ) from exc + node = data.get("node") if data else None + payload = node.get("pullRequests") if node else None + if payload is None: + raise GitHubDataError( + f"GitHub pull request response was missing data for {repo_context}" + ) all_prs.extend(payload["nodes"]) page_info = payload["pageInfo"] if not page_info["hasNextPage"]: @@ -255,17 +289,24 @@ def get_active_change_request_reviewers(pr): def _get_all_prs(pr_states: List[str]) -> List[Dict[str, Any]]: """Fetch PRs for all tracked repositories concurrently.""" - repo_ids = get_repo_ids() - if not repo_ids: + repo_ids_by_name = get_repo_ids_by_name() + if not repo_ids_by_name: return [] - with ThreadPoolExecutor(max_workers=len(repo_ids)) as executor: - futures = [executor.submit(get_prs, repo_id, pr_states) for repo_id in repo_ids] + with ThreadPoolExecutor(max_workers=len(repo_ids_by_name)) as executor: + futures = { + executor.submit(get_prs, repo_id, pr_states, repo_name): repo_name + for repo_name, repo_id in repo_ids_by_name.items() + } all_prs: List[Dict[str, Any]] = [] - for future in futures: + failures = [] + for future, repo_name in futures.items(): try: all_prs.extend(future.result()) - except Exception: - continue + except Exception as exc: + logging.exception("Failed to fetch GitHub PRs for %s", repo_name) + failures.append(_format_failure(repo_name, exc)) + if failures: + raise GitHubDataError("Failed to fetch complete GitHub PR data: " + "; ".join(failures)) return all_prs diff --git a/tests/test_gql_client_requests.py b/tests/test_gql_client_requests.py index 1abe66b..bc34fc3 100644 --- a/tests/test_gql_client_requests.py +++ b/tests/test_gql_client_requests.py @@ -59,6 +59,75 @@ def test_execute_without_variables_uses_original_request(self): self.assertIs(request, query) self.assertEqual(kwargs, {}) + def test_repo_id_lookup_raises_instead_of_caching_partial_tracking_set(self): + def fake_execute(query, variable_values=None): + if variable_values["name"] == "apollos-cluster": + raise RuntimeError("missing access") + return {"repository": {"id": variable_values["name"]}} + + github.get_repo_ids_by_name.cache_clear() + try: + with patch.object(github, "token", "token"): + with patch.object(github, "_execute", side_effect=fake_execute): + with patch.object(github.logging, "exception"): + with self.assertRaisesRegex( + github.GitHubDataError, + "apollosproject/apollos-cluster", + ): + github.get_repo_ids_by_name() + + with patch.object( + github, + "_execute", + return_value={"repository": {"id": "ok"}}, + ): + repos = github.get_repo_ids_by_name() + + self.assertEqual(set(repos), set(github.TRACKED_REPOSITORIES)) + finally: + github.get_repo_ids_by_name.cache_clear() + + def test_repo_id_lookup_raises_when_github_omits_tracked_repository(self): + github.get_repo_ids_by_name.cache_clear() + try: + with patch.object(github, "token", "token"): + with patch.object(github, "_execute", return_value={"repository": None}): + with self.assertRaisesRegex( + github.GitHubDataError, + "repository was not returned", + ): + github.get_repo_ids_by_name() + finally: + github.get_repo_ids_by_name.cache_clear() + + def test_get_prs_raises_when_repo_fetch_fails(self): + with patch.object(github, "token", "token"): + with patch.object(github, "_execute", side_effect=RuntimeError("rate limited")): + with self.assertRaisesRegex( + github.GitHubDataError, + "apollosproject/apollos-cluster", + ): + github.get_prs("repo-id", ["OPEN"], "apollosproject/apollos-cluster") + + def test_get_all_prs_raises_when_any_repo_fetch_fails(self): + def fake_get_prs(repo_id, pr_states, repo_name=None): + if repo_name == "apollosproject/apollos-cluster": + raise RuntimeError("rate limited") + return [{"number": 1, "repo": repo_name}] + + repo_ids = { + "apollosproject/apollos-platforms": "platforms-id", + "apollosproject/apollos-cluster": "cluster-id", + } + with patch.object(github, "get_repo_ids_by_name", return_value=repo_ids): + with patch.object(github, "get_prs", side_effect=fake_get_prs): + with patch.object(github.logging, "exception"): + with self.assertRaisesRegex( + github.GitHubDataError, + "apollosproject/apollos-cluster", + ): + github._get_all_prs(["OPEN"]) + def test_waiting_for_review_uses_utc_timestamps(self): class FixedDateTime(datetime): @classmethod