|
| 1 | +#!/usr/bin/env python3 |
| 2 | +""" |
| 3 | +Check PR conflicts and manage conflict labels. |
| 4 | +
|
| 5 | +This script checks pull requests for merge conflicts and automatically |
| 6 | +adds/removes a conflict label based on the mergeable state. |
| 7 | +""" |
| 8 | + |
| 9 | +import os |
| 10 | +import sys |
| 11 | +import time |
| 12 | +import json |
| 13 | +from typing import List, Dict, Optional, Any |
| 14 | +from enum import Enum |
| 15 | + |
| 16 | +try: |
| 17 | + import requests |
| 18 | +except ImportError: |
| 19 | + print("Error: requests module not found. Install with: pip install requests") |
| 20 | + sys.exit(1) |
| 21 | + |
# Constants
# Label name used when the CONFLICT_LABEL environment variable is not set.
DEFAULT_CONFLICT_LABEL = "conflicts"
# Maximum number of mergeable-state polls per PR when MAX_RETRIES is not set.
DEFAULT_MAX_RETRIES = 10
DEFAULT_RETRY_DELAY = 5  # seconds
# Page size requested when listing open pull requests.
DEFAULT_PER_PAGE = 100

# GITHUB_EVENT_NAME values that are treated as pull-request triggers.
PULL_REQUEST_EVENTS = ["pull_request", "pull_request_target"]
# Values of a PR's `mergeable_state` field that this script compares against.
MERGEABLE_STATES = {
    "UNKNOWN": "unknown",
    "DIRTY": "dirty",
    "CLEAN": "clean",
    "BLOCKED": "blocked",
    "UNSTABLE": "unstable",
    "BEHIND": "behind"
}
| 37 | + |
| 38 | + |
class PRCheckingStrategy(Enum):
    """Which set of pull requests a run should inspect."""

    # Only the PR that triggered the run (used on "synchronize").
    CHECK_CURRENT_PR_ONLY = "check_current_pr"

    # Every open PR that targets one specific base branch (used after a merge).
    CHECK_BASE_BRANCH_PRS = "check_base_branch_prs"

    # Every open PR in the repository (scheduled runs and other events).
    CHECK_ALL_PRS = "check_all_prs"
| 44 | + |
| 45 | + |
class GitHubAPI:
    """Small client for the GitHub REST endpoints used by this script.

    Every call is scoped to one ``owner/repo`` pair and sent with a bearer
    token in the ``Authorization`` header.
    """

    # Seconds to wait on a single HTTP call. Without an explicit timeout
    # ``requests`` waits forever on a stalled connection, which would hang
    # the whole job.
    REQUEST_TIMEOUT = 30

    def __init__(self, token: str, owner: str, repo: str):
        """Store credentials and prepare the headers sent with every request.

        Args:
            token: GitHub token used as the bearer credential.
            owner: Repository owner (user or organisation).
            repo: Repository name.
        """
        self.token = token
        self.owner = owner
        self.repo = repo
        self.base_url = "https://api.github.com"
        self.headers = {
            "Authorization": f"Bearer {token}",
            "Accept": "application/vnd.github+json",
            "X-GitHub-Api-Version": "2022-11-28",
        }

    def _request(self, method: str, endpoint: str, **kwargs) -> requests.Response:
        """Send one request to the GitHub API and return the response.

        Args:
            method: HTTP verb, e.g. ``"GET"``.
            endpoint: Path appended to ``self.base_url`` (starts with ``/``).
            **kwargs: Passed through to ``requests.request``. A ``timeout``
                supplied by the caller takes precedence over the default.

        Raises:
            requests.exceptions.HTTPError: On a 4xx/5xx response.
            requests.exceptions.RequestException: On timeouts and other
                transport failures.
        """
        kwargs.setdefault("timeout", self.REQUEST_TIMEOUT)
        url = f"{self.base_url}{endpoint}"
        response = requests.request(method, url, headers=self.headers, **kwargs)
        response.raise_for_status()
        return response

    def fetch_open_prs(self) -> List[Dict[str, Any]]:
        """Return every open pull request, following pagination."""
        endpoint = f"/repos/{self.owner}/{self.repo}/pulls"
        all_prs: List[Dict[str, Any]] = []
        page = 1

        while True:
            # Build a fresh params dict per page instead of mutating one
            # shared dict across requests.
            params = {"state": "open", "per_page": DEFAULT_PER_PAGE, "page": page}
            prs = self._request("GET", endpoint, params=params).json()
            all_prs.extend(prs)
            # A page shorter than the requested size is the last one, so the
            # extra request that would only return an empty list is skipped.
            if len(prs) < DEFAULT_PER_PAGE:
                break
            page += 1

        return all_prs

    def get_pr_details(self, pr_number: int) -> Dict[str, Any]:
        """Return the full JSON representation of one pull request."""
        endpoint = f"/repos/{self.owner}/{self.repo}/pulls/{pr_number}"
        response = self._request("GET", endpoint)
        return response.json()

    def add_label(self, pr_number: int, label: str) -> None:
        """Add ``label`` to the pull request via the issues labels endpoint."""
        endpoint = f"/repos/{self.owner}/{self.repo}/issues/{pr_number}/labels"
        self._request("POST", endpoint, json={"labels": [label]})

    def remove_label(self, pr_number: int, label: str) -> None:
        """Remove ``label`` from the pull request.

        The label is percent-encoded because it becomes a URL path segment:
        a raw ``/``, ``?``, ``#`` or ``%`` in the name would otherwise change
        which resource the request addresses.
        """
        from urllib.parse import quote

        encoded_label = quote(label, safe="")
        endpoint = (
            f"/repos/{self.owner}/{self.repo}/issues/{pr_number}/labels/{encoded_label}"
        )
        self._request("DELETE", endpoint)
| 100 | + |
| 101 | + |
def log_info(message: str) -> None:
    """Print *message* to standard output with the info marker in front."""
    line = f"ℹ️ {message}"
    print(line)
| 105 | + |
| 106 | + |
def log_error(message: str) -> None:
    """Print *message* to standard error with the error marker in front."""
    line = f"❌ {message}"
    print(line, file=sys.stderr)
| 110 | + |
| 111 | + |
def log_warning(message: str) -> None:
    """Print *message* to standard output with the warning marker in front."""
    line = f"⚠️ {message}"
    print(line)
| 115 | + |
| 116 | + |
def get_github_context() -> Dict[str, Any]:
    """Load the GitHub Actions event payload.

    Reads the JSON file whose path is given by the ``GITHUB_EVENT_PATH``
    environment variable.

    Returns:
        The parsed payload. An empty dict is returned when the variable is
        unset or empty, when the file cannot be read or parsed (a warning is
        logged), or when the file does not contain a JSON object.
    """
    event_path = os.getenv("GITHUB_EVENT_PATH")
    if not event_path:
        return {}

    try:
        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        with open(event_path, encoding="utf-8") as f:
            payload = json.load(f)
    except (OSError, ValueError) as e:
        # OSError: missing or unreadable file. ValueError: invalid JSON or
        # bytes that are not valid UTF-8 (JSONDecodeError and
        # UnicodeDecodeError are both ValueError subclasses).
        log_warning(f"Could not load GitHub event: {e}")
        return {}

    if not isinstance(payload, dict):
        # Callers use dict methods on the result; anything else is unusable.
        log_warning(
            f"Could not load GitHub event: expected a JSON object, got {type(payload).__name__}"
        )
        return {}

    return payload
| 129 | + |
| 130 | + |
def get_current_pr_number(event_payload: Dict[str, Any]) -> Optional[int]:
    """Return the triggering PR's number from the event payload.

    Args:
        event_payload: Parsed GitHub event payload.

    Returns:
        The value of ``pull_request.number``, or ``None`` when the payload
        has no ``pull_request`` entry, the entry is null, or it has no
        number.
    """
    # `or {}` also covers an explicit JSON null, which `.get(key, {})` would
    # return as None and then fail on.
    pull_request = event_payload.get("pull_request") or {}
    return pull_request.get("number")
| 135 | + |
| 136 | + |
def get_base_branch(event_payload: Dict[str, Any]) -> Optional[str]:
    """Return the base branch name of the triggering PR.

    Args:
        event_payload: Parsed GitHub event payload.

    Returns:
        The value of ``pull_request.base.ref``, or ``None`` when any level
        of that path is missing or null.
    """
    # `or {}` at each level tolerates both a missing key and a JSON null.
    pull_request = event_payload.get("pull_request") or {}
    base = pull_request.get("base") or {}
    return base.get("ref")
| 141 | + |
| 142 | + |
def determine_pr_checking_strategy(event_name: str, event_payload: Dict[str, Any]) -> PRCheckingStrategy:
    """
    Determine PR filtering strategy based on trigger event.

    Args:
        event_name: Name of the triggering event (``GITHUB_EVENT_NAME``).
        event_payload: Parsed GitHub event payload.

    Returns:
        PRCheckingStrategy: ``CHECK_CURRENT_PR_ONLY`` for a pull-request
        ``synchronize`` action, ``CHECK_BASE_BRANCH_PRS`` for a pull request
        that was closed by merging, and ``CHECK_ALL_PRS`` for everything
        else (other pull-request actions and non-PR events).
    """
    if event_name in PULL_REQUEST_EVENTS:
        action = event_payload.get("action")

        if action == "synchronize":
            log_info(f"PR {action} - checking only current PR")
            return PRCheckingStrategy.CHECK_CURRENT_PR_ONLY

        # `or {}` tolerates a missing or null "pull_request" entry.
        pull_request = event_payload.get("pull_request") or {}
        if action == "closed" and pull_request.get("merged"):
            base_branch = get_base_branch(event_payload)
            log_info(f"PR merged - checking all PRs targeting base branch: {base_branch}")
            return PRCheckingStrategy.CHECK_BASE_BRANCH_PRS

        # Still a pull-request event, so it must not be reported as a
        # "Non-PR trigger" in the log.
        log_info(f"PR {action} - no targeted strategy, checking all open PRs regardless of base branch")
        return PRCheckingStrategy.CHECK_ALL_PRS

    log_info("Non-PR trigger - checking all open PRs regardless of base branch")
    return PRCheckingStrategy.CHECK_ALL_PRS
| 163 | + |
| 164 | + |
def filter_prs_by_base_branch(prs: List[Dict[str, Any]], base_branch: str) -> List[Dict[str, Any]]:
    """Keep only the PRs whose ``base.ref`` equals ``base_branch``.

    The input order is preserved and the input list is not modified.
    """
    matching = []
    for candidate in prs:
        target_ref = candidate.get("base", {}).get("ref")
        if target_ref == base_branch:
            matching.append(candidate)
    return matching
| 168 | + |
| 169 | + |
def wait_for_mergeable_state(api: GitHubAPI, pr_number: int, max_retries: int, retry_delay: int) -> Dict[str, Any]:
    """Poll a PR until its ``mergeable_state`` is no longer "unknown".

    The PR is fetched once up front, then re-fetched at most ``max_retries``
    times with ``retry_delay`` seconds of sleep before each re-fetch.

    Returns:
        The most recently fetched PR details. The state may still be
        "unknown" if every retry was used up.
    """
    unknown_state = MERGEABLE_STATES["UNKNOWN"]
    details = api.get_pr_details(pr_number)

    for retries in range(max_retries):
        if details.get("mergeable_state") != unknown_state:
            break
        log_info(f" Waiting for mergeable state... (attempt {retries + 1})")
        time.sleep(retry_delay)
        details = api.get_pr_details(pr_number)

    return details
| 182 | + |
| 183 | + |
def has_conflicts(mergeable: Optional[bool], mergeable_state: str) -> bool:
    """
    Tell whether a PR has real merge conflicts.

    A PR counts as conflicted only when ``mergeable`` is exactly ``False``
    and ``mergeable_state`` is "dirty". Other states are deliberately not
    treated as conflicts:
    - "behind": behind the base branch but still mergeable
    - "blocked": held back by branch protection
    - "unstable": failing checks
    """
    # Anything other than an explicit False (True or None) is not a conflict.
    if mergeable is not False:
        return False
    return mergeable_state == MERGEABLE_STATES["DIRTY"]
| 197 | + |
| 198 | + |
def process_pr(api: GitHubAPI, pr: Dict[str, Any], conflict_label: str, max_retries: int, retry_delay: int) -> None:
    """Check one PR for merge conflicts and sync its conflict label.

    Args:
        api: Client used to fetch PR details and change labels.
        pr: PR object; only its ``number`` and ``title`` keys are read.
        conflict_label: Name of the label that marks a conflicted PR.
        max_retries: Maximum number of polls for the mergeable state.
        retry_delay: Seconds to sleep between polls.

    The label is added when the PR has conflicts and lacks it, and removed
    when the PR has no conflicts but carries it. HTTP errors from the label
    calls are logged as warnings rather than raised. If mergeability is
    still undetermined after polling, labels are left untouched.
    """
    pr_number = pr["number"]
    pr_title = pr["title"]

    log_info(f"\nChecking PR #{pr_number}: {pr_title}")

    pr_details = wait_for_mergeable_state(api, pr_number, max_retries, retry_delay)

    mergeable = pr_details.get("mergeable")
    mergeable_state = pr_details.get("mergeable_state", "")
    current_labels = [label["name"] for label in pr_details.get("labels", [])]
    has_conflict_label = conflict_label in current_labels

    log_info(f" Mergeable: {mergeable}, State: {mergeable_state}")

    # "Not computed yet" is not the same as "no conflicts". Acting on it
    # would strip the label from a PR that may still be conflicted, so
    # nothing is changed until a definite answer is available.
    if mergeable is None or mergeable_state == MERGEABLE_STATES["UNKNOWN"]:
        log_warning(f" Mergeable state still undetermined for PR #{pr_number}; leaving labels unchanged")
        return

    pr_has_conflicts = has_conflicts(mergeable, mergeable_state)
    log_info(f" Has conflicts: {pr_has_conflicts}, Has label: {has_conflict_label}")

    if pr_has_conflicts and not has_conflict_label:
        try:
            api.add_label(pr_number, conflict_label)
            log_info(f" ✅ Added {conflict_label} label to PR #{pr_number}")
        except requests.exceptions.HTTPError as e:
            log_warning(f" Could not add label: {e}")
    elif not pr_has_conflicts and has_conflict_label:
        try:
            api.remove_label(pr_number, conflict_label)
            log_info(f" ✅ Removed {conflict_label} label from PR #{pr_number}")
        except requests.exceptions.HTTPError as e:
            log_warning(f" Could not remove label: {e}")
    else:
        log_info(f" ℹ️ No label changes needed for PR #{pr_number}")
| 231 | + |
| 232 | + |
def _read_int_env(name: str, default: int) -> int:
    """Return env var ``name`` as a non-negative int, or ``default`` if unset or blank.

    Raises:
        ValueError: If the value is not an integer or is negative.
    """
    raw = (os.getenv(name) or "").strip()
    if not raw:
        return default
    try:
        value = int(raw)
    except ValueError:
        raise ValueError(f"{name} must be an integer, got {raw!r}") from None
    if value < 0:
        raise ValueError(f"{name} must not be negative, got {value}")
    return value


def main():
    """Run the conflict check for the current GitHub Actions event.

    Configuration is read from the environment: ``GITHUB_TOKEN`` and
    ``GITHUB_REPOSITORY`` are required; ``CONFLICT_LABEL``, ``MAX_RETRIES``
    and ``RETRY_DELAY`` are optional and fall back to the module defaults
    when unset or empty. Exits with status 1 on invalid configuration or
    when the event payload lacks data required by the chosen strategy.
    Failures on individual PRs are logged and do not stop the run.
    """
    # Get inputs from environment. `or` (rather than a getenv default) is
    # used so that an empty string also falls back to the default.
    github_token = os.getenv("GITHUB_TOKEN")
    conflict_label = os.getenv("CONFLICT_LABEL") or DEFAULT_CONFLICT_LABEL
    try:
        max_retries = _read_int_env("MAX_RETRIES", DEFAULT_MAX_RETRIES)
        retry_delay = _read_int_env("RETRY_DELAY", DEFAULT_RETRY_DELAY)
    except ValueError as e:
        log_error(str(e))
        sys.exit(1)

    # Get GitHub context
    github_repository = os.getenv("GITHUB_REPOSITORY")
    event_name = os.getenv("GITHUB_EVENT_NAME", "")

    if not github_token:
        log_error("GITHUB_TOKEN environment variable is required")
        sys.exit(1)

    if not github_repository:
        log_error("GITHUB_REPOSITORY environment variable is required")
        sys.exit(1)

    try:
        owner, repo = github_repository.split("/")
    except ValueError:
        log_error(f"Invalid GITHUB_REPOSITORY format: {github_repository}")
        sys.exit(1)

    log_info(f"Checking repository: {owner}/{repo}")
    log_info(f"Event: {event_name}")

    # Initialize GitHub API client
    api = GitHubAPI(github_token, owner, repo)

    # Get event payload and determine strategy
    event_payload = get_github_context()
    log_info(f"Action: {event_payload.get('action', 'N/A')}")

    strategy = determine_pr_checking_strategy(event_name, event_payload)

    # Get PRs to check based on strategy
    if strategy == PRCheckingStrategy.CHECK_CURRENT_PR_ONLY:
        current_pr_number = get_current_pr_number(event_payload)
        if not current_pr_number:
            log_error("Cannot get current PR number from event payload")
            sys.exit(1)
        log_info(f"Checking only current PR #{current_pr_number}")
        prs_to_check = [api.get_pr_details(current_pr_number)]

    elif strategy == PRCheckingStrategy.CHECK_BASE_BRANCH_PRS:
        base_branch = get_base_branch(event_payload)
        if not base_branch:
            log_error("Cannot get base branch from event payload")
            sys.exit(1)
        all_prs = api.fetch_open_prs()
        prs_to_check = filter_prs_by_base_branch(all_prs, base_branch)
        log_info(f"Filtered to {len(prs_to_check)} PRs targeting base branch '{base_branch}' "
                 f"(out of {len(all_prs)} total open PRs)")

    else:  # CHECK_ALL_PRS
        all_prs = api.fetch_open_prs()
        prs_to_check = all_prs
        log_info(f"Found {len(all_prs)} open pull requests (checking all branches)")

    # Process each PR. One failing PR must not prevent the others from
    # being checked, hence the deliberately broad per-PR handler.
    failed = 0
    for pr in prs_to_check:
        try:
            process_pr(api, pr, conflict_label, max_retries, retry_delay)
        except Exception as e:
            failed += 1
            log_error(f" Error processing PR #{pr.get('number', '?')}: {e}")

    # Only claim success when every PR was actually processed.
    if failed:
        log_warning(f"\nConflict checking completed with errors: "
                    f"{failed} of {len(prs_to_check)} PR(s) could not be processed")
    else:
        log_info("\n✅ Conflict checking completed successfully")
| 305 | + |
| 306 | + |
if __name__ == "__main__":
    # Last-resort handler: any uncaught exception from main() is reported
    # through log_error and turned into exit status 1. The sys.exit() calls
    # inside main() raise SystemExit, which `except Exception` does not
    # intercept, so those exits pass through unchanged.
    try:
        main()
    except Exception as e:
        log_error(f"Script failed: {e}")
        sys.exit(1)
0 commit comments