diff --git a/src/runpod_flash/cli/commands/update.py b/src/runpod_flash/cli/commands/update.py new file mode 100644 index 00000000..a5c9c482 --- /dev/null +++ b/src/runpod_flash/cli/commands/update.py @@ -0,0 +1,182 @@ +"""CLI command for updating runpod-flash to latest or a specific version.""" + +import json +import shutil +import subprocess +import sys +import urllib.error +import urllib.request +from importlib import metadata +from typing import Optional + +import typer +from rich.console import Console + +console = Console() + +PYPI_URL = "https://pypi.org/pypi/runpod-flash/json" +INSTALL_TIMEOUT_SECONDS = 120 + + +def _get_current_version() -> str: + """Return installed runpod-flash version, or 'unknown' if not found.""" + try: + return metadata.version("runpod-flash") + except metadata.PackageNotFoundError: + return "unknown" + + +def _parse_version(version: str) -> tuple[int, ...]: + """Parse a version string like '1.5.0' into a comparable tuple (1, 5, 0). + + Tuples are NOT padded here -- callers comparing two parsed versions should + use ``_compare_versions()`` to handle differing component counts. + """ + return tuple(int(part) for part in version.split(".")) + + +def _compare_versions(a: tuple[int, ...], b: tuple[int, ...]) -> int: + """Compare two parsed version tuples, padding shorter one with zeros. + + Returns negative if a < b, zero if equal, positive if a > b. + Handles differing component counts: (2, 0) and (2, 0, 0) are equal. + """ + max_len = max(len(a), len(b)) + a_padded = a + (0,) * (max_len - len(a)) + b_padded = b + (0,) * (max_len - len(b)) + if a_padded < b_padded: + return -1 + if a_padded > b_padded: + return 1 + return 0 + + +def _fetch_pypi_metadata() -> tuple[str, set[str]]: + """Fetch latest version and available releases from PyPI. + + Returns: + Tuple of (latest_version, set_of_all_version_strings). + + Raises: + ConnectionError: Network unreachable or DNS failure. + RuntimeError: HTTP error from PyPI. + """ + try: + with urllib.request.urlopen(PYPI_URL, timeout=15) as resp: + data = json.loads(resp.read().decode()) + except urllib.error.URLError as exc: + if isinstance(exc, urllib.error.HTTPError): + raise RuntimeError( + f"PyPI returned HTTP {exc.code}. Try again later." + ) from exc + raise ConnectionError( + "Could not reach PyPI. Check your network connection." + ) from exc + except (json.JSONDecodeError, UnicodeDecodeError) as exc: + raise RuntimeError( + "PyPI returned an unexpected response. Try again later." + ) from exc + + try: + latest = data["info"]["version"] + except (KeyError, TypeError) as exc: + raise RuntimeError( + "PyPI response missing version info. Try again later." + ) from exc + + releases = set(data.get("releases", {}).keys()) + return latest, releases + + +def _build_install_command(version: str) -> list[str]: + """Build the install command, preferring uv over pip. + + Returns the command as a list of strings suitable for subprocess.run. + Uses ``uv pip install`` when uv is on PATH, otherwise falls back to + ``python -m pip install``. + """ + package_spec = f"runpod-flash=={version}" + if shutil.which("uv"): + return ["uv", "pip", "install", package_spec, "--quiet"] + return [sys.executable, "-m", "pip", "install", package_spec, "--quiet"] + + +def _run_install(version: str) -> subprocess.CompletedProcess[str]: + """Install the given version of runpod-flash. + + Raises: + subprocess.TimeoutExpired: Install took longer than INSTALL_TIMEOUT_SECONDS. + RuntimeError: Installer exited with non-zero code. + """ + cmd = _build_install_command(version) + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=INSTALL_TIMEOUT_SECONDS, + ) + if result.returncode != 0: + installer = "uv" if cmd[0] == "uv" else "pip" + stderr = result.stderr.strip() + raise RuntimeError( + f"{installer} install failed (exit {result.returncode}): {stderr}" + ) + return result + + +def update_command( + version: Optional[str] = typer.Option( + None, "--version", "-V", help="Target version to install (default: latest)" + ), +) -> None: + """Update runpod-flash to the latest version or a specific version.""" + current = _get_current_version() + console.print(f"Current version: [bold]{current}[/bold]") + + # Fetch PyPI metadata + with console.status("Checking PyPI for available versions..."): + try: + latest, releases = _fetch_pypi_metadata() + except (ConnectionError, RuntimeError) as exc: + console.print(f"[red]error:[/red] {exc}") + raise typer.Exit(code=1) + + target = version or latest + + # Validate target version exists on PyPI + if target not in releases: + console.print( + f"[red]error:[/red] version [bold]{target}[/bold] not found on PyPI" + ) + raise typer.Exit(code=1) + + # Already on target + if current == target: + console.print(f"Already on version [bold]{target}[/bold]. Nothing to do.") + raise typer.Exit(code=0) + + # Downgrade warning + if current != "unknown": + try: + if _compare_versions(_parse_version(target), _parse_version(current)) < 0: + console.print( + f"[yellow]note:[/yellow] {target} is older than {current} (downgrade)" + ) + except ValueError: + pass # non-standard version string, skip comparison + + # Install + console.print(f"Installing runpod-flash [bold]{target}[/bold]...") + with console.status("Installing..."): + try: + _run_install(target) + except subprocess.TimeoutExpired: + console.print( + f"[red]error:[/red] install timed out after {INSTALL_TIMEOUT_SECONDS}s" + ) + raise typer.Exit(code=1) + except RuntimeError as exc: + console.print(f"[red]error:[/red] {exc}") + raise typer.Exit(code=1) + + console.print(f"[green]Updated runpod-flash {current} -> {target}[/green]") diff --git a/src/runpod_flash/cli/main.py b/src/runpod_flash/cli/main.py index e02dbeb9..5d5ac6b5 100644 --- a/src/runpod_flash/cli/main.py +++ b/src/runpod_flash/cli/main.py @@ -14,7 +14,9 @@ apps, undeploy, login, + update, ) +from .update_checker import start_background_check def get_version() -> str: @@ -41,6 +43,7 @@ def get_version() -> str: app.command("build")(build.build_command) app.command("login")(login.login_command) app.command("deploy")(deploy.deploy_command) +app.command("update")(update.update_command) # app.command("report")(resource.report_command) @@ -66,6 +69,9 @@ def get_version() -> str: app.command("undeploy")(undeploy.undeploy_command) +_UPDATE_CHECK_EXCLUDED = frozenset({"run", "update"}) + + @app.callback(invoke_without_command=True) def main( ctx: typer.Context, @@ -87,6 +93,9 @@ def main( ) ) + if ctx.invoked_subcommand and ctx.invoked_subcommand not in _UPDATE_CHECK_EXCLUDED: + start_background_check() + if __name__ == "__main__": app() diff --git a/src/runpod_flash/cli/update_checker.py b/src/runpod_flash/cli/update_checker.py new file mode 100644 index 00000000..fc046a22 --- /dev/null +++ b/src/runpod_flash/cli/update_checker.py @@ -0,0 +1,181 @@ +"""Passive background update check for the flash CLI. + +Spawns a daemon thread that checks PyPI (at most once per 24h, cached to disk). +An atexit handler prints a one-line notice to stderr if a newer version exists. +The thread never blocks the command -- if the network is slow, the notice is +silently skipped. +""" + +from __future__ import annotations + +import atexit +import json +import os +import sys +import threading +from datetime import datetime, timezone +from pathlib import Path + +from .commands.update import ( + _compare_versions, + _fetch_pypi_metadata, + _get_current_version, + _parse_version, +) + +CACHE_FILENAME = "update_check.json" +CHECK_INTERVAL_HOURS = 24 + +_newer_version: str | None = None +_result_lock = threading.Lock() +_check_done = threading.Event() +_started = False +_start_lock = threading.Lock() + + +def _get_cache_path() -> Path: + """Return path to the update check cache file. + + Follows XDG_CONFIG_HOME convention, same directory as credentials.toml. + """ + config_home = os.getenv("XDG_CONFIG_HOME") + base_dir = ( + Path(config_home).expanduser() if config_home else Path.home() / ".config" + ) + return base_dir / "runpod" / CACHE_FILENAME + + +def _read_cache(path: Path) -> dict | None: + """Read the cache JSON file. Return None on any error.""" + try: + return json.loads(path.read_text(encoding="utf-8")) + except (OSError, json.JSONDecodeError, ValueError): + return None + + +def _write_cache(path: Path, latest_version: str) -> None: + """Write cache with current UTC timestamp and latest version. + + Creates parent directories if needed. Silently ignores write failures. + """ + try: + path.parent.mkdir(parents=True, exist_ok=True) + data = { + "last_checked_utc": datetime.now(timezone.utc).isoformat(), + "latest_version": latest_version, + } + path.write_text(json.dumps(data), encoding="utf-8") + except OSError: + pass + + +def _is_cache_fresh(cache: dict) -> bool: + """Return True if the cache was written within CHECK_INTERVAL_HOURS.""" + try: + last_checked = datetime.fromisoformat(cache["last_checked_utc"]) + elapsed_hours = ( + datetime.now(timezone.utc) - last_checked + ).total_seconds() / 3600 + return elapsed_hours < CHECK_INTERVAL_HOURS + except (KeyError, ValueError, TypeError): + return False + + +def _run_check() -> None: + """Thread body: check PyPI for a newer version. + + Reads cache first. If fresh, uses cached latest_version. Otherwise fetches + from PyPI and updates cache. Compares against current installed version. + All exceptions are swallowed -- this must never crash the CLI. + """ + global _newer_version # noqa: PLW0603 + try: + current = _get_current_version() + if current == "unknown": + return + + cache_path = _get_cache_path() + cache = _read_cache(cache_path) + + latest = None + if cache and _is_cache_fresh(cache): + latest = cache.get("latest_version") or None + + if not latest: + latest, _ = _fetch_pypi_metadata() + _write_cache(cache_path, latest) + + if not latest: + return + + current_tuple = _parse_version(current) + latest_tuple = _parse_version(latest) + + if _compare_versions(latest_tuple, current_tuple) > 0: + with _result_lock: + _newer_version = latest + except Exception: # noqa: BLE001 + pass + finally: + _check_done.set() + + +def _print_update_notice() -> None: + """atexit handler: print update notice to stderr if a newer version was found. + + If the background thread hasn't finished yet, skip silently. + Uses plain text (no Rich markup) since atexit runs after Rich teardown. + """ + if not _check_done.is_set(): + return + + with _result_lock: + version = _newer_version + + if version: + print( + f"\nA new version of runpod-flash is available: {version}\n" + " Run 'flash update' to upgrade.", + file=sys.stderr, + ) + + +def _is_interactive() -> bool: + """Return True if at least one of stdout/stderr is a TTY.""" + try: + if sys.stderr is not None and sys.stderr.isatty(): + return True + except Exception: # noqa: BLE001 + pass + try: + if sys.stdout is not None and sys.stdout.isatty(): + return True + except Exception: # noqa: BLE001 + pass + return False + + +def start_background_check() -> None: + """Start the passive update check. + + Skips if FLASH_NO_UPDATE_CHECK or CI environment variables are set, or when + neither stdout nor stderr is attached to a TTY. Idempotent — only starts + once per process. The guard flag is set only after passing all skip checks, + so a skipped first call does not prevent future calls from starting. + """ + global _started # noqa: PLW0603 + with _start_lock: + if _started: + return + + if os.getenv("FLASH_NO_UPDATE_CHECK"): + return + if os.getenv("CI"): + return + if not _is_interactive(): + return + + _started = True + thread = threading.Thread(target=_run_check, daemon=True) + thread.start() + atexit.register(_print_update_notice) diff --git a/tests/unit/cli/commands/test_update.py b/tests/unit/cli/commands/test_update.py new file mode 100644 index 00000000..26a58657 --- /dev/null +++ b/tests/unit/cli/commands/test_update.py @@ -0,0 +1,387 @@ +"""Tests for flash update command.""" + +import subprocess +import sys +from unittest.mock import MagicMock, Mock, patch + +import pytest +import typer + +from runpod_flash.cli.commands.update import ( + _build_install_command, + _compare_versions, + _fetch_pypi_metadata, + _get_current_version, + _parse_version, + _run_install, + update_command, +) + + +# --------------------------------------------------------------------------- +# Unit tests for helpers +# --------------------------------------------------------------------------- + + +class TestGetCurrentVersion: + def test_returns_version(self): + with patch( + "runpod_flash.cli.commands.update.metadata.version", + return_value="1.3.0", + ): + assert _get_current_version() == "1.3.0" + + def test_returns_unknown_on_not_found(self): + from importlib.metadata import PackageNotFoundError + + with patch( + "runpod_flash.cli.commands.update.metadata.version", + side_effect=PackageNotFoundError("runpod-flash"), + ): + assert _get_current_version() == "unknown" + + +class TestParseVersion: + def test_standard_version(self): + assert _parse_version("1.5.0") == (1, 5, 0) + + def test_two_part_version(self): + assert _parse_version("2.0") == (2, 0) + + def test_comparison(self): + assert _parse_version("1.4.0") < _parse_version("1.5.0") + assert _parse_version("2.0.0") > _parse_version("1.9.9") + assert _parse_version("1.0.0") == _parse_version("1.0.0") + + def test_invalid_raises(self): + with pytest.raises(ValueError): + _parse_version("not.a.version") + + +class TestCompareVersions: + def test_equal_same_length(self): + assert _compare_versions((1, 5, 0), (1, 5, 0)) == 0 + + def test_equal_different_length(self): + """Core edge case: (2, 0) and (2, 0, 0) are semantically equal.""" + assert _compare_versions((2, 0), (2, 0, 0)) == 0 + + def test_less_than(self): + assert _compare_versions((1, 4, 0), (1, 5, 0)) < 0 + + def test_greater_than(self): + assert _compare_versions((2, 0, 0), (1, 9, 9)) > 0 + + def test_shorter_tuple_less(self): + assert _compare_versions((1, 9), (1, 9, 1)) < 0 + + def test_shorter_tuple_greater(self): + assert _compare_versions((2, 1), (2, 0, 0)) > 0 + + def test_empty_tuples(self): + assert _compare_versions((), ()) == 0 + + def test_one_empty(self): + assert _compare_versions((), (1,)) < 0 + + +class TestFetchPypiMetadata: + def _make_response(self, latest: str, releases: list[str]) -> MagicMock: + import json + + data = { + "info": {"version": latest}, + "releases": {v: [] for v in releases}, + } + resp = MagicMock() + resp.read.return_value = json.dumps(data).encode() + resp.__enter__ = Mock(return_value=resp) + resp.__exit__ = Mock(return_value=False) + return resp + + def test_returns_latest_and_releases(self): + resp = self._make_response("1.5.0", ["1.3.0", "1.4.0", "1.5.0"]) + with patch( + "runpod_flash.cli.commands.update.urllib.request.urlopen", return_value=resp + ): + latest, releases = _fetch_pypi_metadata() + assert latest == "1.5.0" + assert releases == {"1.3.0", "1.4.0", "1.5.0"} + + def test_connection_error_on_url_error(self): + import urllib.error + + with patch( + "runpod_flash.cli.commands.update.urllib.request.urlopen", + side_effect=urllib.error.URLError("DNS failure"), + ): + with pytest.raises(ConnectionError, match="Could not reach PyPI"): + _fetch_pypi_metadata() + + def test_runtime_error_on_http_error(self): + import urllib.error + + with patch( + "runpod_flash.cli.commands.update.urllib.request.urlopen", + side_effect=urllib.error.HTTPError( + url="https://pypi.org", + code=503, + msg="Service Unavailable", + hdrs={}, + fp=None, + ), + ): + with pytest.raises(RuntimeError, match="PyPI returned HTTP 503"): + _fetch_pypi_metadata() + + def test_runtime_error_on_malformed_json(self): + resp = MagicMock() + resp.read.return_value = b"not json {{" + resp.__enter__ = Mock(return_value=resp) + resp.__exit__ = Mock(return_value=False) + + with patch( + "runpod_flash.cli.commands.update.urllib.request.urlopen", + return_value=resp, + ): + with pytest.raises(RuntimeError, match="unexpected response"): + _fetch_pypi_metadata() + + def test_runtime_error_on_missing_version_key(self): + import json as _json + + resp = MagicMock() + resp.read.return_value = _json.dumps({"info": {}}).encode() + resp.__enter__ = Mock(return_value=resp) + resp.__exit__ = Mock(return_value=False) + + with patch( + "runpod_flash.cli.commands.update.urllib.request.urlopen", + return_value=resp, + ): + with pytest.raises(RuntimeError, match="missing version info"): + _fetch_pypi_metadata() + + +class TestBuildInstallCommand: + def test_uses_uv_when_available(self): + with patch( + "runpod_flash.cli.commands.update.shutil.which", return_value="/usr/bin/uv" + ): + cmd = _build_install_command("1.5.0") + assert cmd == ["uv", "pip", "install", "runpod-flash==1.5.0", "--quiet"] + + def test_falls_back_to_pip(self): + with patch("runpod_flash.cli.commands.update.shutil.which", return_value=None): + cmd = _build_install_command("1.5.0") + assert cmd[0:3] == [sys.executable, "-m", "pip"] + assert "runpod-flash==1.5.0" in cmd + + +class TestRunInstall: + def test_success(self): + result = MagicMock(returncode=0, stderr="", stdout="") + with ( + patch( + "runpod_flash.cli.commands.update.subprocess.run", return_value=result + ), + patch( + "runpod_flash.cli.commands.update._build_install_command", + return_value=["uv", "pip", "install", "runpod-flash==1.5.0", "--quiet"], + ), + ): + assert _run_install("1.5.0") is result + + def test_failure_raises_runtime_error_uv(self): + result = MagicMock(returncode=1, stderr="No matching distribution") + with ( + patch( + "runpod_flash.cli.commands.update.subprocess.run", return_value=result + ), + patch( + "runpod_flash.cli.commands.update._build_install_command", + return_value=[ + "uv", + "pip", + "install", + "runpod-flash==99.99.99", + "--quiet", + ], + ), + ): + with pytest.raises(RuntimeError, match="uv install failed"): + _run_install("99.99.99") + + def test_failure_raises_runtime_error_pip(self): + result = MagicMock(returncode=1, stderr="No matching distribution") + with ( + patch( + "runpod_flash.cli.commands.update.subprocess.run", return_value=result + ), + patch( + "runpod_flash.cli.commands.update._build_install_command", + return_value=[ + sys.executable, + "-m", + "pip", + "install", + "runpod-flash==99.99.99", + ], + ), + ): + with pytest.raises(RuntimeError, match="pip install failed"): + _run_install("99.99.99") + + def test_timeout_propagates(self): + with ( + patch( + "runpod_flash.cli.commands.update.subprocess.run", + side_effect=subprocess.TimeoutExpired(cmd="uv", timeout=120), + ), + patch( + "runpod_flash.cli.commands.update._build_install_command", + return_value=["uv", "pip", "install", "runpod-flash==1.5.0", "--quiet"], + ), + ): + with pytest.raises(subprocess.TimeoutExpired): + _run_install("1.5.0") + + +# --------------------------------------------------------------------------- +# Integration tests for update_command +# --------------------------------------------------------------------------- + + +@pytest.fixture +def mock_update_env(): + """Provide mocks for console, PyPI fetch, pip install, and current version.""" + mock_console = MagicMock() + mock_console.status.return_value.__enter__ = Mock(return_value=None) + mock_console.status.return_value.__exit__ = Mock(return_value=False) + + mocks = { + "console": mock_console, + "get_version": MagicMock(return_value="1.3.0"), + "fetch_pypi": MagicMock(return_value=("1.5.0", {"1.3.0", "1.4.0", "1.5.0"})), + "run_install": MagicMock(return_value=MagicMock(returncode=0)), + } + + patches = [ + patch("runpod_flash.cli.commands.update.console", mocks["console"]), + patch( + "runpod_flash.cli.commands.update._get_current_version", + mocks["get_version"], + ), + patch( + "runpod_flash.cli.commands.update._fetch_pypi_metadata", mocks["fetch_pypi"] + ), + patch("runpod_flash.cli.commands.update._run_install", mocks["run_install"]), + ] + + for p in patches: + p.start() + + yield mocks + + for p in patches: + p.stop() + + +class TestUpdateCommandHappyPath: + def test_update_to_latest(self, mock_update_env): + update_command(version=None) + + mock_update_env["run_install"].assert_called_once_with("1.5.0") + # Verify success message printed + calls = [str(c) for c in mock_update_env["console"].print.call_args_list] + assert any("1.3.0 -> 1.5.0" in c for c in calls) + + def test_update_to_specific_version(self, mock_update_env): + update_command(version="1.4.0") + + mock_update_env["run_install"].assert_called_once_with("1.4.0") + + def test_downgrade_prints_warning(self, mock_update_env): + mock_update_env["get_version"].return_value = "1.5.0" + mock_update_env["fetch_pypi"].return_value = ( + "1.5.0", + {"1.3.0", "1.4.0", "1.5.0"}, + ) + + update_command(version="1.3.0") + + mock_update_env["run_install"].assert_called_once_with("1.3.0") + calls = [str(c) for c in mock_update_env["console"].print.call_args_list] + assert any("downgrade" in c for c in calls) + + +class TestUpdateCommandAlreadyOnTarget: + def test_already_on_latest(self, mock_update_env): + mock_update_env["get_version"].return_value = "1.5.0" + + with pytest.raises(typer.Exit) as exc_info: + update_command(version=None) + + assert exc_info.value.exit_code == 0 + mock_update_env["run_install"].assert_not_called() + + def test_already_on_specific_version(self, mock_update_env): + mock_update_env["get_version"].return_value = "1.4.0" + + with pytest.raises(typer.Exit) as exc_info: + update_command(version="1.4.0") + + assert exc_info.value.exit_code == 0 + + +class TestUpdateCommandErrors: + def test_version_not_found(self, mock_update_env): + with pytest.raises(typer.Exit) as exc_info: + update_command(version="99.0.0") + + assert exc_info.value.exit_code == 1 + mock_update_env["run_install"].assert_not_called() + calls = [str(c) for c in mock_update_env["console"].print.call_args_list] + assert any("not found on PyPI" in c for c in calls) + + def test_network_error(self, mock_update_env): + mock_update_env["fetch_pypi"].side_effect = ConnectionError("no network") + + with pytest.raises(typer.Exit) as exc_info: + update_command(version=None) + + assert exc_info.value.exit_code == 1 + + def test_http_error(self, mock_update_env): + mock_update_env["fetch_pypi"].side_effect = RuntimeError( + "PyPI returned HTTP 503" + ) + + with pytest.raises(typer.Exit) as exc_info: + update_command(version=None) + + assert exc_info.value.exit_code == 1 + + def test_install_failure(self, mock_update_env): + mock_update_env["run_install"].side_effect = RuntimeError( + "uv install failed (exit 1): No matching distribution" + ) + + with pytest.raises(typer.Exit) as exc_info: + update_command(version=None) + + assert exc_info.value.exit_code == 1 + calls = [str(c) for c in mock_update_env["console"].print.call_args_list] + assert any("install failed" in c for c in calls) + + def test_install_timeout(self, mock_update_env): + mock_update_env["run_install"].side_effect = subprocess.TimeoutExpired( + cmd="uv", timeout=120 + ) + + with pytest.raises(typer.Exit) as exc_info: + update_command(version=None) + + assert exc_info.value.exit_code == 1 + calls = [str(c) for c in mock_update_env["console"].print.call_args_list] + assert any("timed out" in c for c in calls) diff --git a/tests/unit/cli/test_update_checker.py b/tests/unit/cli/test_update_checker.py new file mode 100644 index 00000000..a814c8b6 --- /dev/null +++ b/tests/unit/cli/test_update_checker.py @@ -0,0 +1,419 @@ +"""Unit tests for passive background update checker.""" + +from __future__ import annotations + +import json +import threading +from datetime import datetime, timezone, timedelta +from pathlib import Path +from unittest.mock import patch, MagicMock + +import pytest + +from runpod_flash.cli import update_checker +from runpod_flash.cli.update_checker import ( + _get_cache_path, + _is_cache_fresh, + _print_update_notice, + _read_cache, + _run_check, + _write_cache, + start_background_check, + CHECK_INTERVAL_HOURS, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _reset_module_state() -> None: + """Reset module-level state between tests.""" + update_checker._newer_version = None + update_checker._check_done = threading.Event() + update_checker._started = False + + +@pytest.fixture(autouse=True) +def _clean_state(): + """Reset module state before and after each test.""" + _reset_module_state() + yield + _reset_module_state() + + +# --------------------------------------------------------------------------- +# _get_cache_path +# --------------------------------------------------------------------------- + + +class TestGetCachePath: + def test_default_path(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.delenv("XDG_CONFIG_HOME", raising=False) + path = _get_cache_path() + assert path == Path.home() / ".config" / "runpod" / "update_check.json" + + def test_custom_xdg(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path): + monkeypatch.setenv("XDG_CONFIG_HOME", str(tmp_path)) + path = _get_cache_path() + assert path == tmp_path / "runpod" / "update_check.json" + + +# --------------------------------------------------------------------------- +# _read_cache +# --------------------------------------------------------------------------- + + +class TestReadCache: + def test_missing_file(self, tmp_path: Path): + result = _read_cache(tmp_path / "nonexistent.json") + assert result is None + + def test_valid_json(self, tmp_path: Path): + cache_file = tmp_path / "cache.json" + data = { + "last_checked_utc": "2026-01-01T00:00:00+00:00", + "latest_version": "2.0.0", + } + cache_file.write_text(json.dumps(data)) + + result = _read_cache(cache_file) + assert result == data + + def test_malformed_json(self, tmp_path: Path): + cache_file = tmp_path / "cache.json" + cache_file.write_text("not valid json {{{") + + result = _read_cache(cache_file) + assert result is None + + +# --------------------------------------------------------------------------- +# _write_cache +# --------------------------------------------------------------------------- + + +class TestWriteCache: + def test_writes_correct_json(self, tmp_path: Path): + cache_file = tmp_path / "runpod" / "update_check.json" + + _write_cache(cache_file, "1.6.0") + + data = json.loads(cache_file.read_text()) + assert data["latest_version"] == "1.6.0" + assert "last_checked_utc" in data + # Verify timestamp is parseable and recent + ts = datetime.fromisoformat(data["last_checked_utc"]) + assert (datetime.now(timezone.utc) - ts).total_seconds() < 10 + + def test_creates_parent_dirs(self, tmp_path: Path): + cache_file = tmp_path / "a" / "b" / "c" / "cache.json" + _write_cache(cache_file, "1.0.0") + assert cache_file.exists() + + def test_silent_on_oserror(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path): + """Writing to an unwritable path does not raise.""" + + def _raise_oserror(*args, **kwargs): + raise OSError("simulated write failure") + + monkeypatch.setattr(Path, "write_text", _raise_oserror) + + cache_file = tmp_path / "cache.json" + _write_cache(cache_file, "1.0.0") + + +# --------------------------------------------------------------------------- +# _is_cache_fresh +# --------------------------------------------------------------------------- + + +class TestIsCacheFresh: + def test_fresh_cache(self): + one_hour_ago = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() + assert _is_cache_fresh({"last_checked_utc": one_hour_ago}) is True + + def test_stale_cache(self): + old = ( + datetime.now(timezone.utc) - timedelta(hours=CHECK_INTERVAL_HOURS + 1) + ).isoformat() + assert _is_cache_fresh({"last_checked_utc": old}) is False + + def test_missing_key(self): + assert _is_cache_fresh({}) is False + + def test_invalid_timestamp(self): + assert _is_cache_fresh({"last_checked_utc": "not-a-date"}) is False + + +# --------------------------------------------------------------------------- +# _run_check +# --------------------------------------------------------------------------- + + +class TestRunCheck: + @patch("runpod_flash.cli.update_checker._get_current_version", return_value="1.5.0") + @patch( + "runpod_flash.cli.update_checker._fetch_pypi_metadata", + return_value=("1.6.0", {"1.5.0", "1.6.0"}), + ) + @patch("runpod_flash.cli.update_checker._get_cache_path") + def test_fetches_when_stale( + self, + mock_cache_path: MagicMock, + mock_fetch: MagicMock, + mock_version: MagicMock, + tmp_path: Path, + ): + cache_file = tmp_path / "update_check.json" + mock_cache_path.return_value = cache_file + + _run_check() + + mock_fetch.assert_called_once() + assert update_checker._check_done.is_set() + assert update_checker._newer_version == "1.6.0" + # Cache should be written + assert cache_file.exists() + + @patch("runpod_flash.cli.update_checker._get_current_version", return_value="1.5.0") + @patch("runpod_flash.cli.update_checker._fetch_pypi_metadata") + @patch("runpod_flash.cli.update_checker._get_cache_path") + def test_uses_cache_when_fresh( + self, + mock_cache_path: MagicMock, + mock_fetch: MagicMock, + mock_version: MagicMock, + tmp_path: Path, + ): + cache_file = tmp_path / "update_check.json" + fresh_data = { + "last_checked_utc": datetime.now(timezone.utc).isoformat(), + "latest_version": "1.6.0", + } + cache_file.write_text(json.dumps(fresh_data)) + mock_cache_path.return_value = cache_file + + _run_check() + + mock_fetch.assert_not_called() + assert update_checker._newer_version == "1.6.0" + + @patch("runpod_flash.cli.update_checker._get_current_version", return_value="1.6.0") + @patch( + "runpod_flash.cli.update_checker._fetch_pypi_metadata", + return_value=("1.6.0", {"1.6.0"}), + ) + @patch("runpod_flash.cli.update_checker._get_cache_path") + def test_no_update_when_current( + self, + mock_cache_path: MagicMock, + mock_fetch: MagicMock, + mock_version: MagicMock, + tmp_path: Path, + ): + mock_cache_path.return_value = tmp_path / "update_check.json" + + _run_check() + + assert update_checker._newer_version is None + assert update_checker._check_done.is_set() + + @patch("runpod_flash.cli.update_checker._get_current_version", return_value="1.5.0") + @patch( + "runpod_flash.cli.update_checker._fetch_pypi_metadata", + side_effect=ConnectionError("network down"), + ) + @patch("runpod_flash.cli.update_checker._get_cache_path") + def test_sets_done_on_network_failure( + self, + mock_cache_path: MagicMock, + mock_fetch: MagicMock, + mock_version: MagicMock, + tmp_path: Path, + ): + mock_cache_path.return_value = tmp_path / "update_check.json" + + _run_check() + + assert update_checker._check_done.is_set() + assert update_checker._newer_version is None + + @patch("runpod_flash.cli.update_checker._get_current_version", return_value="1.5.0") + @patch( + "runpod_flash.cli.update_checker._fetch_pypi_metadata", + return_value=("1.6.0", {"1.5.0", "1.6.0"}), + ) + @patch("runpod_flash.cli.update_checker._get_cache_path") + def test_fetches_when_cache_fresh_but_missing_latest_version( + self, + mock_cache_path: MagicMock, + mock_fetch: MagicMock, + mock_version: MagicMock, + tmp_path: Path, + ): + cache_file = tmp_path / "update_check.json" + # Fresh cache but missing latest_version key + fresh_data = { + "last_checked_utc": datetime.now(timezone.utc).isoformat(), + } + cache_file.write_text(json.dumps(fresh_data)) + mock_cache_path.return_value = cache_file + + _run_check() + + mock_fetch.assert_called_once() + assert update_checker._newer_version == "1.6.0" + + @patch( + "runpod_flash.cli.update_checker._get_current_version", return_value="unknown" + ) + def test_skips_when_version_unknown(self, mock_version: MagicMock): + _run_check() + + assert update_checker._check_done.is_set() + assert update_checker._newer_version is None + + +# --------------------------------------------------------------------------- +# _print_update_notice +# --------------------------------------------------------------------------- + + +class TestPrintUpdateNotice: + def test_prints_when_newer_available(self, capsys: pytest.CaptureFixture[str]): + update_checker._newer_version = "2.0.0" + update_checker._check_done.set() + + _print_update_notice() + + captured = capsys.readouterr() + assert "2.0.0" in captured.err + assert "flash update" in captured.err + # Version and instruction should be on separate lines + assert "\n" in captured.err.strip() + + def test_silent_when_no_update(self, capsys: pytest.CaptureFixture[str]): + update_checker._newer_version = None + update_checker._check_done.set() + + _print_update_notice() + + captured = capsys.readouterr() + assert captured.err == "" + + def test_silent_when_thread_not_done(self, capsys: pytest.CaptureFixture[str]): + update_checker._newer_version = "2.0.0" + # _check_done is NOT set + + _print_update_notice() + + captured = capsys.readouterr() + assert captured.err == "" + + +# --------------------------------------------------------------------------- +# start_background_check +# --------------------------------------------------------------------------- + + +class TestStartBackgroundCheck: + @patch("runpod_flash.cli.update_checker._is_interactive", return_value=True) + @patch("runpod_flash.cli.update_checker.atexit.register") + @patch("runpod_flash.cli.update_checker.threading.Thread") + def test_spawns_daemon_thread( + self, + mock_thread_cls: MagicMock, + mock_register: MagicMock, + mock_interactive: MagicMock, + monkeypatch: pytest.MonkeyPatch, + ): + monkeypatch.delenv("FLASH_NO_UPDATE_CHECK", raising=False) + monkeypatch.delenv("CI", raising=False) + + mock_thread = MagicMock() + mock_thread_cls.return_value = mock_thread + + start_background_check() + + mock_thread_cls.assert_called_once_with(target=_run_check, daemon=True) + mock_thread.start.assert_called_once() + mock_register.assert_called_once_with(_print_update_notice) + + @patch("runpod_flash.cli.update_checker._is_interactive", return_value=True) + @patch("runpod_flash.cli.update_checker.atexit.register") + @patch("runpod_flash.cli.update_checker.threading.Thread") + def test_skips_on_flash_no_update_check( + self, + mock_thread_cls: MagicMock, + mock_register: MagicMock, + mock_interactive: MagicMock, + monkeypatch: pytest.MonkeyPatch, + ): + monkeypatch.setenv("FLASH_NO_UPDATE_CHECK", "1") + monkeypatch.delenv("CI", raising=False) + + start_background_check() + + mock_thread_cls.assert_not_called() + mock_register.assert_not_called() + + @patch("runpod_flash.cli.update_checker._is_interactive", return_value=True) + @patch("runpod_flash.cli.update_checker.atexit.register") + @patch("runpod_flash.cli.update_checker.threading.Thread") + def test_skips_on_ci( + self, + mock_thread_cls: MagicMock, + mock_register: MagicMock, + mock_interactive: MagicMock, + monkeypatch: pytest.MonkeyPatch, + ): + monkeypatch.delenv("FLASH_NO_UPDATE_CHECK", raising=False) + monkeypatch.setenv("CI", "true") + + start_background_check() + + mock_thread_cls.assert_not_called() + mock_register.assert_not_called() + + @patch("runpod_flash.cli.update_checker._is_interactive", return_value=False) + @patch("runpod_flash.cli.update_checker.atexit.register") + @patch("runpod_flash.cli.update_checker.threading.Thread") + def test_skips_when_not_interactive( + self, + mock_thread_cls: MagicMock, + mock_register: MagicMock, + mock_interactive: MagicMock, + monkeypatch: pytest.MonkeyPatch, + ): + monkeypatch.delenv("FLASH_NO_UPDATE_CHECK", raising=False) + monkeypatch.delenv("CI", raising=False) + + start_background_check() + + mock_thread_cls.assert_not_called() + mock_register.assert_not_called() + + @patch("runpod_flash.cli.update_checker._is_interactive", return_value=True) + @patch("runpod_flash.cli.update_checker.atexit.register") + @patch("runpod_flash.cli.update_checker.threading.Thread") + def test_idempotent_only_starts_once( + self, + mock_thread_cls: MagicMock, + mock_register: MagicMock, + mock_interactive: MagicMock, + monkeypatch: pytest.MonkeyPatch, + ): + monkeypatch.delenv("FLASH_NO_UPDATE_CHECK", raising=False) + monkeypatch.delenv("CI", raising=False) + + mock_thread = MagicMock() + mock_thread_cls.return_value = mock_thread + + start_background_check() + start_background_check() + start_background_check() + + mock_thread_cls.assert_called_once() + mock_register.assert_called_once()