Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 182 additions & 0 deletions src/runpod_flash/cli/commands/update.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
"""CLI command for updating runpod-flash to latest or a specific version."""

import json
import shutil
import subprocess
import sys
import urllib.error
import urllib.request
from importlib import metadata
from typing import Optional

import typer
from rich.console import Console

console = Console()

PYPI_URL = "https://pypi.org/pypi/runpod-flash/json"
INSTALL_TIMEOUT_SECONDS = 120


def _get_current_version() -> str:
"""Return installed runpod-flash version, or 'unknown' if not found."""
try:
return metadata.version("runpod-flash")
except metadata.PackageNotFoundError:
return "unknown"


def _parse_version(version: str) -> tuple[int, ...]:
"""Parse a version string like '1.5.0' into a comparable tuple (1, 5, 0).

Tuples are NOT padded here -- callers comparing two parsed versions should
use ``_compare_versions()`` to handle differing component counts.
"""
return tuple(int(part) for part in version.split("."))


def _compare_versions(a: tuple[int, ...], b: tuple[int, ...]) -> int:
"""Compare two parsed version tuples, padding shorter one with zeros.

Returns negative if a < b, zero if equal, positive if a > b.
Handles differing component counts: (2, 0) and (2, 0, 0) are equal.
"""
max_len = max(len(a), len(b))
a_padded = a + (0,) * (max_len - len(a))
b_padded = b + (0,) * (max_len - len(b))
if a_padded < b_padded:
return -1
if a_padded > b_padded:
return 1
return 0


def _fetch_pypi_metadata() -> tuple[str, set[str]]:
"""Fetch latest version and available releases from PyPI.

Returns:
Tuple of (latest_version, set_of_all_version_strings).

Raises:
ConnectionError: Network unreachable or DNS failure.
RuntimeError: HTTP error from PyPI.
"""
try:
with urllib.request.urlopen(PYPI_URL, timeout=15) as resp:
data = json.loads(resp.read().decode())
except urllib.error.URLError as exc:
if isinstance(exc, urllib.error.HTTPError):
raise RuntimeError(
f"PyPI returned HTTP {exc.code}. Try again later."
) from exc
raise ConnectionError(
"Could not reach PyPI. Check your network connection."
) from exc
except (json.JSONDecodeError, UnicodeDecodeError) as exc:
raise RuntimeError(
"PyPI returned an unexpected response. Try again later."
) from exc

try:
latest = data["info"]["version"]
except (KeyError, TypeError) as exc:
raise RuntimeError(
"PyPI response missing version info. Try again later."
) from exc

releases = set(data.get("releases", {}).keys())
return latest, releases


def _build_install_command(version: str) -> list[str]:
"""Build the install command, preferring uv over pip.

Returns the command as a list of strings suitable for subprocess.run.
Uses ``uv pip install`` when uv is on PATH, otherwise falls back to
``python -m pip install``.
"""
package_spec = f"runpod-flash=={version}"
if shutil.which("uv"):
return ["uv", "pip", "install", package_spec, "--quiet"]
return [sys.executable, "-m", "pip", "install", package_spec, "--quiet"]


def _run_install(version: str) -> subprocess.CompletedProcess[str]:
"""Install the given version of runpod-flash.

Raises:
subprocess.TimeoutExpired: Install took longer than INSTALL_TIMEOUT_SECONDS.
RuntimeError: Installer exited with non-zero code.
"""
cmd = _build_install_command(version)
result = subprocess.run(
cmd,
capture_output=True,
text=True,
timeout=INSTALL_TIMEOUT_SECONDS,
)
if result.returncode != 0:
installer = "uv" if cmd[0] == "uv" else "pip"
stderr = result.stderr.strip()
raise RuntimeError(
f"{installer} install failed (exit {result.returncode}): {stderr}"
)
return result


def update_command(
version: Optional[str] = typer.Option(
None, "--version", "-V", help="Target version to install (default: latest)"
),
) -> None:
"""Update runpod-flash to the latest version or a specific version."""
current = _get_current_version()
console.print(f"Current version: [bold]{current}[/bold]")

# Fetch PyPI metadata
with console.status("Checking PyPI for available versions..."):
try:
latest, releases = _fetch_pypi_metadata()
except (ConnectionError, RuntimeError) as exc:
console.print(f"[red]error:[/red] {exc}")
raise typer.Exit(code=1)

target = version or latest

# Validate target version exists on PyPI
if target not in releases:
console.print(
f"[red]error:[/red] version [bold]{target}[/bold] not found on PyPI"
)
raise typer.Exit(code=1)

# Already on target
if current == target:
console.print(f"Already on version [bold]{target}[/bold]. Nothing to do.")
raise typer.Exit(code=0)

# Downgrade warning
if current != "unknown":
try:
if _compare_versions(_parse_version(target), _parse_version(current)) < 0:
console.print(
f"[yellow]note:[/yellow] {target} is older than {current} (downgrade)"
)
except ValueError:
pass # non-standard version string, skip comparison

# Install
console.print(f"Installing runpod-flash [bold]{target}[/bold]...")
with console.status("Installing..."):
try:
_run_install(target)
except subprocess.TimeoutExpired:
console.print(
f"[red]error:[/red] install timed out after {INSTALL_TIMEOUT_SECONDS}s"
)
raise typer.Exit(code=1)
except RuntimeError as exc:
console.print(f"[red]error:[/red] {exc}")
raise typer.Exit(code=1)

console.print(f"[green]Updated runpod-flash {current} -> {target}[/green]")
9 changes: 9 additions & 0 deletions src/runpod_flash/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
apps,
undeploy,
login,
update,
)
from .update_checker import start_background_check


def get_version() -> str:
Expand All @@ -41,6 +43,7 @@ def get_version() -> str:
app.command("build")(build.build_command)
app.command("login")(login.login_command)
app.command("deploy")(deploy.deploy_command)
app.command("update")(update.update_command)
# app.command("report")(resource.report_command)


Expand All @@ -66,6 +69,9 @@ def get_version() -> str:
app.command("undeploy")(undeploy.undeploy_command)


_UPDATE_CHECK_EXCLUDED = frozenset({"run", "update"})


@app.callback(invoke_without_command=True)
def main(
ctx: typer.Context,
Expand All @@ -87,6 +93,9 @@ def main(
)
)

if ctx.invoked_subcommand and ctx.invoked_subcommand not in _UPDATE_CHECK_EXCLUDED:
start_background_check()


if __name__ == "__main__":
app()
Loading
Loading