Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 185 additions & 3 deletions .github/scripts/pr_labeler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
"""PR labeler: compute size, risk, and template-field labels for one or more PRs.
"""PR labeler: compute size, risk, template-field, and domain labels for PRs.

Inputs come from environment variables set by the calling workflow:
GITHUB_REPOSITORY e.g. "owner/repo" (always set on Actions)
Expand All @@ -18,20 +18,23 @@
The current format is preferred; the legacy format is matched as a
fallback so PRs opened before the template change keep working until
the queue rolls over.
5. Reconciling with current labels and applying adds/removes via `gh pr edit`.
5. Matching changed files against CODEOWNERS to apply domain/* labels.
6. Reconciling with current labels and applying adds/removes via `gh pr edit`.

For backfill mode (PR_NUMBER == "all"), per-PR failures are logged but do not
abort the run, unless more than 10% of PRs fail.
"""

from __future__ import annotations

import base64
import json
import os
import re
import subprocess
import sys
from dataclasses import dataclass, field
from fnmatch import fnmatch

# Names of the size/* labels handled by this script.
SIZE_LABELS = ["size/XS", "size/S", "size/M", "size/L", "size/XL"]
# Names of the risk/* labels handled by this script.
RISK_LABELS = ["risk/low", "risk/medium", "risk/high"]
Expand All @@ -51,6 +54,156 @@
# Conservative fallback for unmapped Bugbot levels (e.g., "Critical", "Minimal"):
# a level with no explicit mapping is labeled with the highest risk label.
RISK_FALLBACK = "risk/high"

# Every domain label is this prefix followed by a team slug.
DOMAIN_LABEL_PREFIX = "domain/"

# Team slugs that are allowed to become domain/* labels.
KNOWN_DOMAIN_SLUGS = frozenset(
    {
        "scanning",
        "findings",
        "integrations",
        "platform",
        "frontend",
        "infra",
        "database",
    }
)


# ---------------------------------------------------------------------------
# CODEOWNERS parsing (last-match-wins per file, union across all files)
# ---------------------------------------------------------------------------

# One parsed CODEOWNERS line: the path pattern (first token on the line) and
# the lowercased owner slugs that followed it. The slug list may be empty.
CodeownersRule = tuple[str, list[str]]  # (pattern, [team_slugs])


def parse_codeowners(text: str) -> list[CodeownersRule]:
    """Turn CODEOWNERS file content into ordered ``(pattern, slugs)`` rules.

    Everything from the first ``#`` on a line is discarded; lines that are
    empty afterwards are skipped. The first whitespace-separated token is the
    pattern and every remaining token is an owner. Rules are returned in file
    order, and a pattern with no owners yields an empty slug list.
    """
    parsed: list[CodeownersRule] = []
    for raw_line in text.splitlines():
        content, _, _ = raw_line.partition("#")
        fields = content.split()
        if not fields:
            continue
        pattern, *owners = fields
        slugs = [
            # "@org/team" -> "team"; anything without "/" just loses leading "@".
            owner.rsplit("/", 1)[1].lower() if "/" in owner else owner.lstrip("@").lower()
            for owner in owners
        ]
        parsed.append((pattern, slugs))
    return parsed


def _codeowners_match(pattern: str, filepath: str) -> bool:
"""Test whether a CODEOWNERS pattern matches a file path.

Implements GitHub's CODEOWNERS matching rules:
- ``*`` alone matches everything.
- A pattern starting with ``/`` is anchored to the repo root; the
leading ``/`` is stripped before matching.
- A pattern ending with ``/`` matches everything under that directory.
- A pattern containing an internal ``/`` (after stripping leading ``/``)
is implicitly anchored to the repo root.
- A pattern with no ``/`` at all matches by basename at any depth.
- A single ``*`` does not cross directory boundaries (unlike fnmatch);
use ``**`` to match across directories.
"""
if pattern == "*":
return True

anchored = pattern.startswith("/")
p = pattern.lstrip("/")

# Check for internal slash *before* appending ** for trailing-slash dirs.
# A trailing-only slash does not anchor; only leading or internal slashes do.
has_internal_slash = "/" in p.rstrip("/")

if p.endswith("/"):
p += "**"

if anchored or has_internal_slash:
return _gitignore_match(p, filepath)

if "/" not in p:
# No slash at all: match against basename at any depth.
basename = filepath.rsplit("/", 1)[-1]
return fnmatch(basename, p)

# Trailing-slash-only dir (e.g. "vendor/") with no anchoring:
# match at any depth by prepending **/.
return _gitignore_match("**/" + p, filepath)


def _gitignore_match(pattern: str, filepath: str) -> bool:
"""Match a pattern against a filepath where ``*`` does not cross ``/``.

Splits both pattern and path on ``/`` and matches segment-by-segment.
``**`` matches zero or more directory segments.
"""
return _segments_match(pattern.split("/"), filepath.split("/"))


def _segments_match(pat_parts: list[str], path_parts: list[str]) -> bool:
if not pat_parts:
return not path_parts
if pat_parts[0] == "**":
rest = pat_parts[1:]
# ** matches zero or more segments
for i in range(len(path_parts) + 1):
if _segments_match(rest, path_parts[i:]):
return True
return False
if not path_parts:
return False
if fnmatch(path_parts[0], pat_parts[0]):
return _segments_match(pat_parts[1:], path_parts[1:])
return False


def domains_for_pr(rules: list[CodeownersRule], changed_files: list[str]) -> set[str]:
    """Collect the owner slugs responsible for the given changed files.

    For each file only the *last* rule whose pattern matches contributes its
    slugs (a file matching no rule contributes nothing); the result is the
    union of those slugs over all files.
    """
    owners: set[str] = set()
    for path in changed_files:
        # Scanning backwards, the first hit is the last matching rule.
        winning_slugs = next(
            (
                slugs
                for pattern, slugs in reversed(rules)
                if _codeowners_match(pattern, path)
            ),
            [],
        )
        owners.update(winning_slugs)
    return owners


# Repo-relative locations probed for a CODEOWNERS file, in the order they are tried.
CODEOWNERS_PATHS = [".github/CODEOWNERS", "CODEOWNERS", "docs/CODEOWNERS"]


def fetch_codeowners(repo: str) -> str | None:
    """Fetch CODEOWNERS from the repo via the Contents API.

    Checks the three locations GitHub supports, in priority order:
    ``.github/CODEOWNERS``, ``CODEOWNERS``, ``docs/CODEOWNERS``. No ``ref`` is
    passed, so the API's default ref is used.

    Args:
        repo: Repository in ``owner/repo`` form.

    Returns:
        The decoded text of the first location that can be fetched and
        decoded, or ``None`` if none of them can.
    """
    for path in CODEOWNERS_PATHS:
        result = gh(
            ["api", f"repos/{repo}/contents/{path}", "--jq", ".content"],
            check=False,
        )
        if result.returncode != 0:
            # Location could not be fetched; try the next one.
            continue
        try:
            # b64decode ignores the line breaks embedded in the payload by default.
            return base64.b64decode(result.stdout.strip()).decode()
        except ValueError as exc:
            # binascii.Error (bad base64) and UnicodeDecodeError (not UTF-8)
            # are both ValueError subclasses. Report and fall through to the
            # next location instead of hiding unrelated failures.
            print(f"warning: could not decode {path}: {exc}", file=sys.stderr)
    return None


def fetch_pr_files(repo: str, pr_number: int) -> list[str]:
    """List the paths of the files changed by a pull request.

    Runs ``pr view <number> --repo <repo> --json files`` through the ``gh``
    helper and returns the ``path`` of every entry under the ``files`` key
    (an empty list when that key is absent).
    """
    command = ["pr", "view", str(pr_number), "--repo", repo, "--json", "files"]
    payload = json.loads(gh(command).stdout)
    return [entry["path"] for entry in payload.get("files", [])]


def yesno_regex(keyword: str) -> re.Pattern[str]:
"""Match the current template format and capture ``yes`` or ``no``.
Expand Down Expand Up @@ -222,6 +375,7 @@ def reconcile(
pr: dict,
*,
plan: LabelPlan,
domain_slugs: set[str] | None = None,
) -> None:
current_labels = {label["name"] for label in pr.get("labels", [])}
body = pr.get("body") or ""
Expand Down Expand Up @@ -259,6 +413,19 @@ def reconcile(
elif state == "off" and label in current_labels:
plan.remove.append(label)

# Domain labels: add for matched teams, remove stale ones.
if domain_slugs is not None:
desired_domain = {
f"{DOMAIN_LABEL_PREFIX}{s}" for s in domain_slugs if s in KNOWN_DOMAIN_SLUGS
}
for slug in KNOWN_DOMAIN_SLUGS:
label = f"{DOMAIN_LABEL_PREFIX}{slug}"
if label in desired_domain:
if label not in current_labels:
plan.add.append(label)
elif label in current_labels:
plan.remove.append(label)


def apply(repo: str, plan: LabelPlan, dry_run: bool) -> None:
if dry_run or (not plan.add and not plan.remove):
Expand Down Expand Up @@ -294,6 +461,15 @@ def main() -> int:

print(f"Processing {len(targets)} PR(s) in {repo} (dry_run={dry_run})")

# Fetch CODEOWNERS once per run (same for all PRs in this repo).
codeowners_text = fetch_codeowners(repo)
codeowners_rules: list[CodeownersRule] | None = None
if codeowners_text is not None:
codeowners_rules = parse_codeowners(codeowners_text)
print(f"Loaded {len(codeowners_rules)} CODEOWNERS rule(s) for domain labeling")
else:
print("No CODEOWNERS found; skipping domain labeling")

failures = 0
for pr_number in targets:
plan = LabelPlan(pr_number=pr_number)
Expand All @@ -302,7 +478,13 @@ def main() -> int:
if pr.get("state") != "OPEN":
print(f"PR #{pr_number} (skip: not open)")
continue
reconcile(pr, plan=plan)

domain_slugs: set[str] | None = None
if codeowners_rules is not None:
files = fetch_pr_files(repo, pr_number)
domain_slugs = domains_for_pr(codeowners_rules, files)

reconcile(pr, plan=plan, domain_slugs=domain_slugs)
apply(repo, plan, dry_run)
print(plan.summary())
except subprocess.CalledProcessError as exc:
Expand Down
Loading