From db5a6eb91fc8d4b05a7126e3b3df5b94d7ed0c91 Mon Sep 17 00:00:00 2001 From: Brian Sam-Bodden Date: Sat, 16 May 2026 13:33:52 -0700 Subject: [PATCH] fix(scripts): avoid shell in dependency checker --- scripts/check_deps.py | 132 ++++++++++++++++++++------------ tests/test_check_deps_script.py | 10 +++ 2 files changed, 94 insertions(+), 48 deletions(-) create mode 100644 tests/test_check_deps_script.py diff --git a/scripts/check_deps.py b/scripts/check_deps.py index 7e8cd30..adb4913 100755 --- a/scripts/check_deps.py +++ b/scripts/check_deps.py @@ -1,61 +1,97 @@ #!/usr/bin/env python3 """Check for dependency conflicts in pyproject.toml minimum versions.""" -import subprocess import json import re import sys +from urllib.error import HTTPError, URLError +from urllib.request import urlopen try: import tomllib except ModuleNotFoundError: import tomli as tomllib -with open("pyproject.toml", "rb") as f: - data = tomllib.load(f) - -deps = {} -for dep in data["project"]["dependencies"]: - match = re.match(r"([a-zA-Z0-9_-]+)>=([0-9.]+)", dep) - if match: - pkg, ver = match.groups() - deps[pkg] = ver - -conflicts = [] -for pkg, ver in deps.items(): - result = subprocess.run( - f"curl -s https://pypi.org/pypi/{pkg}/{ver}/json", - shell=True, - capture_output=True, - text=True, - ) - if result.returncode == 0: + +def load_dependencies() -> dict[str, str]: + """Load minimum dependency versions from pyproject.toml.""" + with open("pyproject.toml", "rb") as f: + data = tomllib.load(f) + + deps = {} + for dep in data["project"]["dependencies"]: + match = re.match(r"([a-zA-Z0-9_-]+)>=([0-9.]+)", dep) + if match: + pkg, ver = match.groups() + deps[pkg] = ver + return deps + + +def fetch_package_metadata(pkg: str, ver: str) -> dict: + """Fetch PyPI metadata for a package version.""" + url = f"https://pypi.org/pypi/{pkg}/{ver}/json" + with urlopen(url, timeout=15) as response: + return json.loads(response.read().decode("utf-8")) + + +def version_parts(version: str) -> list[int]: + """Convert a dotted version string to comparable integer parts.""" + return [int(x) for x in version.split(".")] + + +def requires_higher_version(requirement: str, package: str, current_version: str) -> str | None: + """Return the required version if a requirement exceeds the current minimum.""" + if not (requirement.startswith(package + ">") or requirement.startswith(package + " ")): + return None + + match = re.search(r">=(\d+\.\d+(?:\.\d+)?)", requirement) + if not match: + return None + + required_version = match.group(1) + required_parts = version_parts(required_version) + current_parts = version_parts(current_version) + + while len(required_parts) < len(current_parts): + required_parts.append(0) + while len(current_parts) < len(required_parts): + current_parts.append(0) + + return required_version if required_parts > current_parts else None + + +def find_conflicts(deps: dict[str, str]) -> list[str]: + """Find dependency minimum-version conflicts.""" + conflicts = [] + for pkg, ver in deps.items(): try: - pkg_data = json.loads(result.stdout) - requires = pkg_data.get("info", {}).get("requires_dist", []) - for req in requires: - if "extra ==" in req: - continue - for our_pkg, our_ver in deps.items(): - if req.startswith(our_pkg + ">") or req.startswith(our_pkg + " "): - match = re.search(r">=(\d+\.\d+(?:\.\d+)?)", req) - if match: - required_ver = match.group(1) - req_parts = [int(x) for x in required_ver.split(".")] - our_parts = [int(x) for x in our_ver.split(".")] - while len(req_parts) < len(our_parts): - req_parts.append(0) - while len(our_parts) < len(req_parts): - our_parts.append(0) - if req_parts > our_parts: - conflicts.append( - f" ❌ {pkg}=={ver} requires {our_pkg}>={required_ver}, but we have >={our_ver}" - ) - except Exception: - pass - -if conflicts: - print("\n⚠️ Dependency conflicts found:") - for c in conflicts: - print(c) - sys.exit(1) + pkg_data = fetch_package_metadata(pkg, ver) + except (HTTPError, URLError, TimeoutError, json.JSONDecodeError): + continue + + requires = pkg_data.get("info", {}).get("requires_dist", []) + for req in requires: + if "extra ==" in req: + continue + for our_pkg, our_ver in deps.items(): + required_ver = requires_higher_version(req, our_pkg, our_ver) + if required_ver: + conflicts.append( + f" ❌ {pkg}=={ver} requires {our_pkg}>={required_ver}, but we have >={our_ver}" + ) + return conflicts + + +def main() -> int: + """Check dependency minimum-version conflicts.""" + conflicts = find_conflicts(load_dependencies()) + if conflicts: + print("\n⚠️ Dependency conflicts found:") + for c in conflicts: + print(c) + return 1 + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/test_check_deps_script.py b/tests/test_check_deps_script.py new file mode 100644 index 0000000..9e1abd8 --- /dev/null +++ b/tests/test_check_deps_script.py @@ -0,0 +1,10 @@ +"""Tests for the dependency conflict checker script.""" + +from pathlib import Path + + +def test_check_deps_script_does_not_use_shell_true(): + """The dependency checker should not invoke network calls through a shell.""" + source = Path("scripts/check_deps.py").read_text(encoding="utf-8") + + assert "shell=True" not in source