Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 123 additions & 0 deletions bin/git-landpr
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
#!/usr/bin/env python3

import argparse
import collections
import json
import subprocess


class _GitError(Exception):
pass


def _build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("commit", help="The commit to land")
return parser


def _run_git_command(args: list[str]) -> str:
try:
output: bytes = subprocess.check_output(["git"] + args)
return output.strip().decode("utf-8")
except subprocess.CalledProcessError:
raise _GitError()
except UnicodeDecodeError:
print(f"error: output was {output}")
raise


def _get_possible_shas() -> list[str]:
return _run_git_command(
["log", "--no-show-signature", "--pretty=%H", "@{upstream}..HEAD"]
).splitlines()


def _get_branch_name(commit: str) -> str:
return _run_git_command(["pilebranchname", commit])


def _get_pr_url(commit: str) -> str:
"""Get the PR URL for a commit by finding the PR with a matching head branch."""
branch_name = _get_branch_name(commit)
result = subprocess.run(
["gh", "pr", "list", "--head", branch_name, "--json", "url", "--limit", "1"],
capture_output=True,
text=True,
)
if result.returncode != 0:
raise SystemExit(
f"error: failed to find PR for branch {branch_name}, {result.stderr.strip()}"
)

prs = json.loads(result.stdout)
if not prs:
raise SystemExit(f"error: no PR found for branch {branch_name}")

return prs[0]["url"]


def _get_prs_targeting_base(base_branch: str) -> list[str]:
"""Get all PR head branches that target the given base branch."""
output = subprocess.check_output(
["gh", "pr", "list", "--base", base_branch, "--json", "headRefName"],
)
prs = json.loads(output)
return [pr["headRefName"] for pr in prs]


def _find_commit_for_branch(branch_name: str, possible_shas: list[str]) -> str | None:
"""Find the local commit that corresponds to a branch name."""
for sha in possible_shas:
if _get_branch_name(sha) == branch_name:
return sha
return None


def _main(commit: str) -> None:
# Normalize HEAD -> commit, etc
commit = _run_git_command(["rev-parse", commit])

pr_url = _get_pr_url(commit)
branch_being_merged = _get_branch_name(commit)

subprocess.check_call(["gh", "pr", "merge", pr_url, "--squash"])
try:
_run_git_command(["pull"])
except _GitError:
raise SystemExit("error: failed to pull, dependent PRs have not been rebased")

# Build a map of branch names to their local commits
possible_shas = _get_possible_shas()

queue = collections.deque([branch_being_merged])
processed = set()

while queue:
base = queue.popleft()
if base in processed:
continue
processed.add(base)

# Find all PRs that target this base branch
dependent_branches = _get_prs_targeting_base(base)

for dep_branch in dependent_branches:
# Find the local commit for this dependent branch
dep_commit = _find_commit_for_branch(dep_branch, possible_shas)
if not dep_commit:
continue

# Rebase the dependent PR
if base == branch_being_merged:
_run_git_command(["rebasepr", dep_commit])
else:
_run_git_command(["rebasepr", dep_commit, base])

# Add this branch to the queue to process its dependents
queue.append(dep_branch)


if __name__ == "__main__":
args = _build_parser().parse_args()
_main(args.commit)
6 changes: 3 additions & 3 deletions bin/git-rebasepr
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ do
;;
*)
if [[ -n "$commit" ]]; then
echo "error: multiple commit args passed, '$commit' and '$arg'" >&2
exit 1
rebase_args+=("$arg")
else
commit="$arg"
fi
commit="$arg"
;;
esac
done
Expand Down