diff --git a/src/ghstack/cli.py b/src/ghstack/cli.py index 1a7f4f7..029a1bd 100644 --- a/src/ghstack/cli.py +++ b/src/ghstack/cli.py @@ -164,8 +164,29 @@ def cherry_pick(stack: bool, pull_request: str) -> None: @main.command("land") @click.option("--force", is_flag=True, help="force land even if the PR is closed") +@click.option( + "--validate-rules", + is_flag=True, + help="validate against merge rules defined in .github/merge_rules.yaml", +) +@click.option( + "--dry-run", + is_flag=True, + help="validate merge rules but don't actually merge", +) +@click.option( + "--comment-on-failure", + is_flag=True, + help="post validation errors as a PR comment", +) @click.argument("pull_request", metavar="PR") -def land(force: bool, pull_request: str) -> None: +def land( + force: bool, + validate_rules: bool, + dry_run: bool, + comment_on_failure: bool, + pull_request: str, +) -> None: """ Land a PR stack """ @@ -177,6 +198,9 @@ def land(force: bool, pull_request: str) -> None: github_url=config.github_url, remote_name=config.remote_name, force=force, + validate_rules=validate_rules, + dry_run=dry_run, + comment_on_failure=comment_on_failure, ) diff --git a/src/ghstack/land.py b/src/ghstack/land.py index 08adfb6..f1cb20c 100644 --- a/src/ghstack/land.py +++ b/src/ghstack/land.py @@ -2,11 +2,12 @@ import logging import re -from typing import List, Tuple +from typing import List, Optional, Tuple import ghstack.git import ghstack.github import ghstack.github_utils +import ghstack.merge_rules import ghstack.shell from ghstack.diff import PullRequestResolved from ghstack.types import GitCommitHash @@ -50,6 +51,10 @@ def main( github_url: str, *, force: bool = False, + validate_rules: bool = False, + dry_run: bool = False, + comment_on_failure: bool = False, + rules_file: Optional[str] = None, ) -> None: # We land the entire stack pointed to by a URL. @@ -60,11 +65,14 @@ def main( params = ghstack.github_utils.parse_pull_request( pull_request, sh=sh, remote_name=remote_name ) + owner = params["owner"] + name = params["name"] + default_branch = ghstack.github_utils.get_github_repo_info( github=github, sh=sh, - repo_owner=params["owner"], - repo_name=params["name"], + repo_owner=owner, + repo_name=name, github_url=github_url, remote_name=remote_name, )["default_branch"] @@ -72,7 +80,7 @@ def main( needs_force = False try: protection = github.get( - f"repos/{params['owner']}/{params['name']}/branches/{default_branch}/protection" + f"repos/{owner}/{name}/branches/{default_branch}/protection" ) if not protection["allow_force_pushes"]["enabled"]: raise RuntimeError( @@ -91,12 +99,12 @@ def main( orig_ref, closed = lookup_pr_to_orig_ref_and_closed( github, - owner=params["owner"], - name=params["name"], + owner=owner, + name=name, number=params["number"], ) - if closed: + if closed and not force: raise RuntimeError("PR is already closed, cannot land it!") if sh is None: @@ -117,6 +125,64 @@ def main( github_url=github_url, ) + # Compute the metadata for each commit + stack_orig_refs: List[Tuple[str, PullRequestResolved]] = [] + for s in stack: + pr_resolved = s.pull_request_resolved + # We got this from GitHub, this better not be corrupted + assert pr_resolved is not None + + ref, stack_closed = lookup_pr_to_orig_ref_and_closed( + github, + owner=pr_resolved.owner, + name=pr_resolved.repo, + number=pr_resolved.number, + ) + if stack_closed and not force: + continue + stack_orig_refs.append((ref, pr_resolved)) + + # Validate merge rules if requested + if validate_rules: + logging.info("Validating merge rules for PR stack...") + + # Load merge rules + loader = ghstack.merge_rules.MergeRulesLoader(github, owner, name) + if rules_file: + rules = loader.load_from_file(rules_file) + else: + rules = loader.load_from_repo() + + if not rules: + logging.warning("No merge rules found, skipping validation") + else: + # Validate each PR in the stack + validator = ghstack.merge_rules.MergeValidator(github, owner, name) + for _, pr_resolved in stack_orig_refs: + result = validator.validate_pr(pr_resolved.number, rules) + if not result.valid: + logging.error( + f"Validation failed for PR #{pr_resolved.number}: " + f"{', '.join(result.errors)}" + ) + if comment_on_failure: + comment_body = ( + ghstack.merge_rules.format_validation_error_comment(result) + ) + github.post_issue_comment( + owner, name, pr_resolved.number, comment_body + ) + raise ghstack.merge_rules.MergeValidationError(result) + logging.info( + f"PR #{pr_resolved.number} passed validation (rule: {result.rule_name})" + ) + + logging.info("All PRs in stack passed merge rules validation") + + if dry_run: + logging.info("Dry run complete - no changes made") + return + # Switch working copy try: prev_ref = sh.git("symbolic-ref", "--short", "HEAD") @@ -127,23 +193,6 @@ def main( sh.git("checkout", f"{remote_name}/{default_branch}") try: - # Compute the metadata for each commit - stack_orig_refs: List[Tuple[str, PullRequestResolved]] = [] - for s in stack: - pr_resolved = s.pull_request_resolved - # We got this from GitHub, this better not be corrupted - assert pr_resolved is not None - - ref, closed = lookup_pr_to_orig_ref_and_closed( - github, - owner=pr_resolved.owner, - name=pr_resolved.repo, - number=pr_resolved.number, - ) - if closed and not force: - continue - stack_orig_refs.append((ref, pr_resolved)) - # OK, actually do the land now for orig_ref, pr_resolved in stack_orig_refs: try: