diff --git a/github_rest_api/scripts/github/create_pull_request.py b/github_rest_api/scripts/github/create_pull_request.py index 3d1d5e4..72ddf20 100644 --- a/github_rest_api/scripts/github/create_pull_request.py +++ b/github_rest_api/scripts/github/create_pull_request.py @@ -6,6 +6,7 @@ import os import sys from github_rest_api import Repository +from github_rest_api.utils import compile_patterns def parse_args(args=None, namespace=None) -> Namespace: @@ -34,6 +35,19 @@ def parse_args(args=None, namespace=None) -> Namespace: required=True, help="The base branch to merge changes into.", ) + parser.add_argument( + "--ignore-patterns", + dest="ignore_patterns", + nargs="*", + default=["^_"], + help="A list of regular expression patterns. Branches matching any of these patterns will be ignored.", + ) + parser.add_argument( + "--update", + dest="update", + action="store_true", + help="Update the head branch using the base branch before creating the pull request.", + ) return parser.parse_args(args=args, namespace=namespace) @@ -43,10 +57,23 @@ def main() -> int: The branch is updated (using dev) before creating the PR. """ args = parse_args() - # skip branches with the pattern _* - if args.head_branch.startswith("_"): - return 0 + # skip branches matching any of the ignore patterns + try: + compiled = compile_patterns(args.ignore_patterns) + except ValueError as e: + print(e, file=sys.stderr) + return 1 + for pattern in compiled: + if pattern.search(args.head_branch): + print( + f"Branch '{args.head_branch}' matches ignore pattern '{ + pattern.pattern + }', skipping." + ) + return 0 repo = Repository(args.token, os.environ["GITHUB_REPOSITORY"]) + if args.update: + repo.update_branch(update=args.head_branch, upstream=args.base_branch) repo.create_pull_request( { "base": args.base_branch, diff --git a/github_rest_api/utils.py b/github_rest_api/utils.py index 41c0f4b..6bd4c27 100644 --- a/github_rest_api/utils.py +++ b/github_rest_api/utils.py @@ -1,6 +1,8 @@ """Some generally useful util functions.""" +from collections.abc import Sequence from itertools import tee, filterfalse +import re def partition(pred, iterable): @@ -12,6 +14,27 @@ def partition(pred, iterable): return filter(pred, it1), filterfalse(pred, it2) +def compile_patterns(patterns: str | Sequence[str] | None) -> list[re.Pattern[str]]: + """Compile a list of regular expression patterns. + + :param patterns: A list of regular expression patterns to compile. + :return: A list of compiled regular expression patterns. + """ + if not patterns: + return [] + if isinstance(patterns, str): + patterns = [patterns] + compiled = [] + for pattern in patterns: + try: + compiled.append(re.compile(pattern)) + except re.error as e: + raise ValueError( + f"Invalid regular expression pattern '{pattern}': {e}" + ) from e + return compiled + + def strip_patch_version(version: str) -> str: """Strip the patch version.