diff --git a/examples/sandbox_port_expose_demo.py b/examples/sandbox_port_expose_demo.py new file mode 100644 index 00000000..d4eef144 --- /dev/null +++ b/examples/sandbox_port_expose_demo.py @@ -0,0 +1,195 @@ +import socket +import time +import urllib.request + +from prime_sandboxes import APIClient, APIError, CreateSandboxRequest, SandboxClient + + +def verify_http(url: str) -> bool: + """Verify HTTP endpoint is accessible and returns expected response.""" + try: + # Add User-Agent header to avoid 403 from bot protection + req = urllib.request.Request(url, headers={"User-Agent": "curl/8.0"}) + with urllib.request.urlopen(req, timeout=10) as response: + status = response.getcode() + body = response.read().decode("utf-8") + # Python's http.server returns a directory listing HTML + if status == 200 and "Directory listing" in body: + return True + return False + except Exception as e: + print(f" HTTP verification error: {e}") + return False + + +def verify_tcp(endpoint: str, test_message: bytes = b"Hello") -> bool: + """Verify TCP endpoint is accessible and echoes back data.""" + try: + # Parse host:port from endpoint address + host, port_str = endpoint.rsplit(":", 1) + port = int(port_str) + + # Connect with raw TCP + with socket.create_connection((host, port), timeout=10) as sock: + sock.sendall(test_message) + response = sock.recv(1024) + expected = b"Echo: " + test_message + return response == expected + except Exception as e: + print(f" TCP verification error: {e}") + return False + + +def main() -> None: + """Demonstrate HTTP and TCP port exposure""" + try: + client = APIClient() + sandbox_client = SandboxClient(client) + + request = CreateSandboxRequest( + name="port-expose-demo", + docker_image="python:3.11-slim", + start_command="tail -f /dev/null", + cpu_cores=1, + memory_gb=2, + timeout_minutes=30, + ) + + print("Creating sandbox...") + sandbox = sandbox_client.create(request) + print(f"Created: {sandbox.name} ({sandbox.id})") + + print("\nWaiting for sandbox to be running...") + sandbox_client.wait_for_creation(sandbox.id, max_attempts=60) + print("Sandbox is running!") + + print("\n--- HTTP Port Exposure ---") + print("Starting HTTP server on port 8000...") + sandbox_client.execute_command( + sandbox.id, + "nohup python -m http.server 8000 > /tmp/http.log 2>&1 &", + ) + time.sleep(2) # Give server time to start + + # Expose the HTTP port + http_exposure = sandbox_client.expose( + sandbox_id=sandbox.id, + port=8000, + name="web-server", + protocol="HTTP", + ) + print("HTTP port exposed!") + print(f" Exposure ID: {http_exposure.exposure_id}") + print(f" URL: {http_exposure.url}") + print(f" TLS Socket: {http_exposure.tls_socket}") + time.sleep(10) + + # Verify HTTP endpoint is accessible + print(" Verifying HTTP endpoint...") + if verify_http(http_exposure.url): + print(" HTTP verification: SUCCESS") + else: + print(" HTTP verification: FAILED") + + # Start a TCP echo server in the sandbox + print("\n--- TCP Port Exposure ---") + print("Starting TCP echo server on port 9000...") + + # Create a simple TCP echo server + tcp_server_code = """ +import socket +import threading + +def handle_client(conn, addr): + print(f"Connection from {addr}") + while True: + data = conn.recv(1024) + if not data: + break + conn.sendall(b"Echo: " + data) + conn.close() + +server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) +server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) +server.bind(("0.0.0.0", 9000)) +server.listen(5) +print("TCP server listening on port 9000") + +while True: + conn, addr = server.accept() + thread = threading.Thread(target=handle_client, args=(conn, addr)) + thread.daemon = True + thread.start() +""" + # Write and run the TCP server + sandbox_client.execute_command( + sandbox.id, + f"cat > /tmp/tcp_server.py << 'SCRIPT'\n{tcp_server_code}\nSCRIPT", + ) + sandbox_client.execute_command( + sandbox.id, + "nohup python /tmp/tcp_server.py > /tmp/tcp.log 2>&1 &", + ) + time.sleep(2) # Give server time to start + + # Expose the TCP port + tcp_exposure = sandbox_client.expose( + sandbox_id=sandbox.id, + port=9000, + name="echo-server", + protocol="TCP", + ) + print("TCP port exposed!") + print(f" Exposure ID: {tcp_exposure.exposure_id}") + print(f" External Endpoint: {tcp_exposure.external_endpoint}") + if tcp_exposure.external_port: + print(f" External Port: {tcp_exposure.external_port}") + time.sleep(120) + + # Verify TCP endpoint is accessible + print(" Verifying TCP endpoint...") + if verify_tcp(tcp_exposure.external_endpoint): + print(" TCP verification: SUCCESS (echo server responded correctly)") + else: + print(" TCP verification: FAILED") + + # List all exposed ports + print("\n--- All Exposed Ports ---") + ports_response = sandbox_client.list_exposed_ports(sandbox.id) + for port in ports_response.exposures: + print(f" {port.name} (port {port.port}):") + print(f" Protocol: {port.protocol}") + print(f" Exposure ID: {port.exposure_id}") + if port.protocol == "HTTP": + print(f" URL: {port.url}") + else: + print(f" External Endpoint: {port.external_endpoint}") + + # Usage instructions + print("\n--- How to Connect ---") + print(f"HTTP: curl {http_exposure.url}") + print(f"TCP: Connect to {tcp_exposure.external_endpoint} with a TCP client") + + # Clean up exposures + print("\n--- Cleanup ---") + print("Removing port exposures...") + sandbox_client.unexpose(sandbox.id, http_exposure.exposure_id) + print(f" Removed HTTP exposure: {http_exposure.exposure_id}") + sandbox_client.unexpose(sandbox.id, tcp_exposure.exposure_id) + print(f" Removed TCP exposure: {tcp_exposure.exposure_id}") + + # Delete sandbox + print(f"\nDeleting sandbox {sandbox.name}...") + sandbox_client.delete(sandbox.id) + print("Done!") + + except APIError as e: + print(f"API Error: {e}") + print("Make sure you're logged in: run 'prime login' first") + except Exception as e: + print(f"Error: {e}") + raise + + +if __name__ == "__main__": + main() diff --git a/packages/prime-sandboxes/src/prime_sandboxes/__init__.py b/packages/prime-sandboxes/src/prime_sandboxes/__init__.py index f8f16bec..03aedb3d 100644 --- a/packages/prime-sandboxes/src/prime_sandboxes/__init__.py +++ b/packages/prime-sandboxes/src/prime_sandboxes/__init__.py @@ -37,6 +37,7 @@ Sandbox, SandboxListResponse, SandboxStatus, + SSHSession, UpdateSandboxRequest, ) from .sandbox import AsyncSandboxClient, AsyncTemplateClient, SandboxClient, TemplateClient @@ -76,6 +77,7 @@ "ExposePortRequest", "ExposedPort", "ListExposedPortsResponse", + "SSHSession", # Exceptions "APIError", "UnauthorizedError", diff --git a/packages/prime-sandboxes/src/prime_sandboxes/models.py b/packages/prime-sandboxes/src/prime_sandboxes/models.py index b39af540..0739b478 100644 --- a/packages/prime-sandboxes/src/prime_sandboxes/models.py +++ b/packages/prime-sandboxes/src/prime_sandboxes/models.py @@ -178,6 +178,7 @@ class ExposePortRequest(BaseModel): port: int name: Optional[str] = None + protocol: str = "HTTP" # HTTP or TCP/UDP class ExposedPort(BaseModel): @@ -190,6 +191,8 @@ class ExposedPort(BaseModel): url: str tls_socket: str protocol: Optional[str] = None + external_port: Optional[int] = None # For TCP/UDP exposures + external_endpoint: Optional[str] = None # For TCP/UDP: host:port endpoint created_at: Optional[str] = None @@ -199,6 +202,23 @@ class ListExposedPortsResponse(BaseModel): exposures: List[ExposedPort] +class SSHSession(BaseModel): + """SSH session details""" + + session_id: str + exposure_id: str + sandbox_id: str + host: str + port: int + external_endpoint: str + expires_at: datetime + ttl_seconds: int + gateway_url: str + user_ns: str + job_id: str + token: str + + class BackgroundJob(BaseModel): """Background job handle returned when starting a background job""" diff --git a/packages/prime-sandboxes/src/prime_sandboxes/sandbox.py b/packages/prime-sandboxes/src/prime_sandboxes/sandbox.py index 50e8c0c6..5b0dbb07 100644 --- a/packages/prime-sandboxes/src/prime_sandboxes/sandbox.py +++ b/packages/prime-sandboxes/src/prime_sandboxes/sandbox.py @@ -43,6 +43,7 @@ Sandbox, SandboxListResponse, SandboxLogsResponse, + SSHSession, ) # Retry configuration for transient connection errors on gateway requests @@ -713,9 +714,10 @@ def expose( sandbox_id: str, port: int, name: Optional[str] = None, + protocol: str = "HTTP", ) -> ExposedPort: - """Expose an HTTP port from a sandbox.""" - request = ExposePortRequest(port=port, name=name) + """Expose a port from a sandbox.""" + request = ExposePortRequest(port=port, name=name, protocol=protocol) response = self.client.request( "POST", f"/sandbox/{sandbox_id}/expose", @@ -732,6 +734,31 @@ def list_exposed_ports(self, sandbox_id: str) -> ListExposedPortsResponse: response = self.client.request("GET", f"/sandbox/{sandbox_id}/expose") return ListExposedPortsResponse.model_validate(response) + def list_all_exposed_ports(self) -> ListExposedPortsResponse: + """List all exposed ports across all sandboxes for the current user""" + response = self.client.request("GET", "/sandbox/expose/all") + return ListExposedPortsResponse.model_validate(response) + + def create_ssh_session( + self, + sandbox_id: str, + ttl_seconds: Optional[int] = None, + ) -> SSHSession: + """Create an SSH session""" + payload: Dict[str, Any] = {} + if ttl_seconds is not None: + payload["ttl_seconds"] = ttl_seconds + response = self.client.request( + "POST", + f"/sandbox/{sandbox_id}/ssh-session", + json=payload, + ) + return SSHSession.model_validate(response) + + def close_ssh_session(self, sandbox_id: str, session_id: str) -> None: + """Close an SSH session and remove its exposure""" + self.client.request("DELETE", f"/sandbox/{sandbox_id}/ssh-session/{session_id}") + class AsyncSandboxClient: """Async client for sandbox API operations""" @@ -1293,9 +1320,10 @@ async def expose( sandbox_id: str, port: int, name: Optional[str] = None, + protocol: str = "HTTP", ) -> ExposedPort: - """Expose an HTTP port from a sandbox.""" - request = ExposePortRequest(port=port, name=name) + """Expose a port from a sandbox.""" + request = ExposePortRequest(port=port, name=name, protocol=protocol) response = await self.client.request( "POST", f"/sandbox/{sandbox_id}/expose", @@ -1312,6 +1340,31 @@ async def list_exposed_ports(self, sandbox_id: str) -> ListExposedPortsResponse: response = await self.client.request("GET", f"/sandbox/{sandbox_id}/expose") return ListExposedPortsResponse.model_validate(response) + async def list_all_exposed_ports(self) -> ListExposedPortsResponse: + """List all exposed ports across all sandboxes for the current user""" + response = await self.client.request("GET", "/sandbox/expose/all") + return ListExposedPortsResponse.model_validate(response) + + async def create_ssh_session( + self, + sandbox_id: str, + ttl_seconds: Optional[int] = None, + ) -> SSHSession: + """Create an SSH session""" + payload: Dict[str, Any] = {} + if ttl_seconds is not None: + payload["ttl_seconds"] = ttl_seconds + response = await self.client.request( + "POST", + f"/sandbox/{sandbox_id}/ssh-session", + json=payload, + ) + return SSHSession.model_validate(response) + + async def close_ssh_session(self, sandbox_id: str, session_id: str) -> None: + """Close an SSH session and remove its exposure""" + await self.client.request("DELETE", f"/sandbox/{sandbox_id}/ssh-session/{session_id}") + class TemplateClient: """Client for template/registry helper APIs.""" diff --git a/packages/prime/src/prime_cli/commands/sandbox.py b/packages/prime/src/prime_cli/commands/sandbox.py index fd3de03d..7c8c40be 100644 --- a/packages/prime/src/prime_cli/commands/sandbox.py +++ b/packages/prime/src/prime_cli/commands/sandbox.py @@ -1,10 +1,15 @@ import json +import os import random import shlex +import shutil import string +import subprocess +import tempfile import time from typing import Any, Dict, List, Optional +import httpx import typer from prime_sandboxes import ( APIClient, @@ -951,20 +956,29 @@ def expose_port( sandbox_id: str = typer.Argument(..., help="Sandbox ID to expose port from"), port: int = typer.Argument(..., help="Port number to expose"), name: Optional[str] = typer.Option(None, help="Optional name for the exposed port"), + protocol: str = typer.Option( + "HTTP", + "--protocol", + "-p", + help="Protocol: HTTP or TCP/UDP", + ), output: str = typer.Option("table", "--output", "-o", help="Output format: table or json"), ) -> None: - """Expose an HTTP port from a sandbox. - - Currently only HTTP is supported. TCP, UDP, and SSH support coming soon. - """ + """Expose a port from a sandbox.""" validate_output_format(output, console) + # Validate protocol + protocol = protocol.upper() + if protocol not in ("HTTP", "TCP", "UDP"): + console.print(f"[red]Error:[/red] Invalid protocol '{protocol}'. Use HTTP, TCP, or UDP.") + raise typer.Exit(1) + try: base_client = APIClient() sandbox_client = SandboxClient(base_client) with console.status("[bold blue]Exposing port...", spinner="dots"): - exposed = sandbox_client.expose(sandbox_id, port, name) + exposed = sandbox_client.expose(sandbox_id, port, name, protocol) if output == "json": output_data_as_json(exposed.model_dump(), console) @@ -972,10 +986,21 @@ def expose_port( console.print("[green]✓[/green] Port exposed successfully!") console.print(f"[bold green]Exposure ID:[/bold green] {exposed.exposure_id}") console.print(f"[bold green]Port:[/bold green] {exposed.port}") + console.print(f"[bold green]Protocol:[/bold green] {exposed.protocol or protocol}") if exposed.name: console.print(f"[bold green]Name:[/bold green] {exposed.name}") console.print(f"[bold green]URL:[/bold green] {exposed.url}") - console.print(f"[bold green]TLS Socket:[/bold green] {exposed.tls_socket}") + if protocol in ("TCP", "UDP"): + if exposed.external_port: + console.print( + f"[bold green]External Port:[/bold green] {exposed.external_port}" + ) + if exposed.external_endpoint: + console.print( + f"[bold green]External Endpoint:[/bold green] {exposed.external_endpoint}" + ) + else: + console.print(f"[bold green]TLS Socket:[/bold green] {exposed.tls_socket}") except APIError as e: console.print(f"[red]Error:[/red] {str(e)}") @@ -1017,55 +1042,266 @@ def unexpose_port( raise typer.Exit(1) -@app.command("list-ports", no_args_is_help=True) +@app.command("list-ports") def list_ports( - sandbox_id: str = typer.Argument(..., help="Sandbox ID"), + sandbox_id: Optional[str] = typer.Argument( + None, help="Sandbox ID (omit to list all exposed ports across all sandboxes)" + ), output: str = typer.Option("table", "--output", "-o", help="Output format: table or json"), ) -> None: - """List all exposed ports for a sandbox""" + """List exposed ports for a sandbox, or all sandboxes if no ID is provided""" validate_output_format(output, console) try: base_client = APIClient() sandbox_client = SandboxClient(base_client) - with console.status("[bold blue]Fetching exposed ports...", spinner="dots"): - response = sandbox_client.list_exposed_ports(sandbox_id) + if sandbox_id: + # List ports for a specific sandbox + with console.status("[bold blue]Fetching exposed ports...", spinner="dots"): + response = sandbox_client.list_exposed_ports(sandbox_id) - if output == "json": - output_data_as_json( - {"exposures": [exp.model_dump() for exp in response.exposures]}, console - ) - else: - if not response.exposures: - console.print(f"[yellow]No exposed ports for sandbox {sandbox_id}[/yellow]") - else: - table = build_table( - f"Exposed Ports for Sandbox {sandbox_id}", - [ - ("Exposure ID", "cyan"), - ("Port", "blue"), - ("Name", "green"), - ("URL", "magenta"), - ("TLS Socket", "yellow"), - ], + if output == "json": + output_data_as_json( + {"exposures": [exp.model_dump() for exp in response.exposures]}, console ) + else: + if not response.exposures: + console.print(f"[yellow]No exposed ports for sandbox {sandbox_id}[/yellow]") + else: + table = build_table( + f"Exposed Ports for Sandbox {sandbox_id}", + [ + ("Exposure ID", "cyan"), + ("Protocol", "white"), + ("Port", "blue"), + ("External", "blue"), + ("Name", "green"), + ("URL", "magenta"), + ], + ) + + for exp in response.exposures: + external_port = str(exp.external_port) if exp.external_port else "-" + table.add_row( + exp.exposure_id, + exp.protocol or "HTTP", + str(exp.port), + external_port, + exp.name or "-", + exp.url, + ) + + console.print(table) + else: + # List all exposed ports across all sandboxes + with console.status("[bold blue]Fetching all exposed ports...", spinner="dots"): + response = sandbox_client.list_all_exposed_ports() - for exp in response.exposures: - table.add_row( - exp.exposure_id, - str(exp.port), - exp.name or "N/A", - exp.url, - exp.tls_socket, + if output == "json": + output_data_as_json( + {"exposures": [exp.model_dump() for exp in response.exposures]}, console + ) + else: + if not response.exposures: + console.print("[yellow]No exposed ports found[/yellow]") + else: + table = build_table( + "All Exposed Ports", + [ + ("Sandbox ID", "yellow"), + ("Exposure ID", "cyan"), + ("Protocol", "white"), + ("Port", "blue"), + ("External", "blue"), + ("Name", "green"), + ("URL", "magenta"), + ], ) - console.print(table) + for exp in response.exposures: + external_port = str(exp.external_port) if exp.external_port else "-" + table.add_row( + exp.sandbox_id, + exp.exposure_id, + exp.protocol or "HTTP", + str(exp.port), + external_port, + exp.name or "-", + exp.url, + ) + + console.print(table) + + except APIError as e: + console.print(f"[red]Error:[/red] {str(e)}") + raise typer.Exit(1) + except Exception as e: + console.print(f"[red]Unexpected error:[/red] {escape(str(e))}") + console.print_exception(show_locals=True) + raise typer.Exit(1) + + +@app.command("ssh", no_args_is_help=True) +def ssh_connect( + sandbox_id: str = typer.Argument(..., help="Sandbox ID to SSH into"), + identity: Optional[str] = typer.Option( + None, "--identity", "-i", help="Path to SSH private key file (will be authorized)" + ), + ssh_args: Optional[List[str]] = typer.Argument( + None, help="Additional SSH arguments (e.g., -- -v for verbose)" + ), +) -> None: + """Connect to a sandbox via SSH. + + This command creates a SSH session, authorizes your key, and cleans up the exposure. + + Examples:\n + prime sandbox ssh sb_abc123\n + prime sandbox ssh sb_abc123 -i ~/.ssh/my_key\n + prime sandbox ssh sb_abc123 -- -v -L 8080:localhost:8080\n + """ + session_id: Optional[str] = None + sandbox_client: Optional[SandboxClient] = None + temp_dir: Optional[str] = None + key_path: Optional[str] = None + + def cleanup() -> None: + """Clean up the SSH session and temporary keys.""" + if session_id and sandbox_client: + try: + console.print("\n[bold blue]Cleaning up SSH session...[/bold blue]") + sandbox_client.close_ssh_session(sandbox_id, session_id) + console.print("[green]✓[/green] SSH session closed") + except Exception: + pass + if temp_dir and os.path.isdir(temp_dir): + shutil.rmtree(temp_dir, ignore_errors=True) + + try: + # Check if ssh and ssh-keygen commands are available + if not shutil.which("ssh"): + console.print("[red]Error:[/red] SSH client not found. Please install OpenSSH.") + raise typer.Exit(1) + if not shutil.which("ssh-keygen"): + console.print("[red]Error:[/red] ssh-keygen not found. Please install OpenSSH.") + raise typer.Exit(1) + + base_client = APIClient() + sandbox_client = SandboxClient(base_client) + + # Check if sandbox is running + with console.status("[bold blue]Checking sandbox status...", spinner="dots"): + sandbox = sandbox_client.get(sandbox_id) + + if sandbox.status != "RUNNING": + console.print(f"[red]Error:[/red] Sandbox is not running (status: {sandbox.status})") + console.print( + f"[yellow]Tip:[/yellow] Check sandbox status with: prime sandbox get {sandbox_id}" + ) + raise typer.Exit(1) + + # Prepare SSH key (use provided identity or generate ephemeral) + def load_public_key_from_identity(path: str) -> str: + pub_path = f"{path}.pub" + if os.path.exists(pub_path): + with open(pub_path, "r") as f: + return f.read().strip() + result = subprocess.run( + ["ssh-keygen", "-y", "-f", path], + check=True, + capture_output=True, + text=True, + ) + return result.stdout.strip() + + if identity: + key_path = os.path.expanduser(identity) + if not os.path.exists(key_path): + console.print(f"[red]Error:[/red] Identity file not found: {key_path}") + raise typer.Exit(1) + public_key = load_public_key_from_identity(key_path) + else: + temp_dir = tempfile.mkdtemp(prefix="prime-ssh-") + key_path = os.path.join(temp_dir, "id_ed25519") + subprocess.run( + ["ssh-keygen", "-t", "ed25519", "-N", "", "-f", key_path], + check=True, + capture_output=True, + ) + with open(f"{key_path}.pub", "r") as f: + public_key = f.read().strip() + + # Create SSH session + console.print("[bold blue]Creating SSH session...[/bold blue]") + with console.status("[bold blue]Setting up SSH session...", spinner="dots"): + session = sandbox_client.create_ssh_session(sandbox_id) + session_id = session.session_id + + # Authorize the key + authorize_url = ( + f"{session.gateway_url.rstrip('/')}/{session.user_ns}/{session.job_id}/authorize" + ) + headers = {"Authorization": f"Bearer {session.token}"} + payload = { + "session_id": session.session_id, + "public_key": public_key, + "ttl_seconds": session.ttl_seconds, + } + try: + with httpx.Client(timeout=30) as client: + client.post(authorize_url, json=payload, headers=headers).raise_for_status() + except Exception as e: + console.print(f"[red]Error:[/red] Failed to authorize SSH key: {e}") + cleanup() + raise typer.Exit(1) + + ssh_host = session.host + ssh_port = session.port + + console.print("[green]✓[/green] SSH session ready!") + console.print(f"[bold green]Connecting to:[/bold green] {session.session_id}@{ssh_host}") + console.print(f"[bold green]Port:[/bold green] {ssh_port}") + console.print() + console.print("[dim]Press Ctrl+D or type 'exit' to disconnect[/dim]") + console.print() + # Build SSH command + ssh_cmd = ["ssh", f"{session.session_id}@{ssh_host}", "-p", str(ssh_port)] + + # Disable strict host key checking for dynamic hosts + ssh_cmd.extend(["-o", "StrictHostKeyChecking=no"]) + ssh_cmd.extend(["-o", "UserKnownHostsFile=/dev/null"]) + + # Add identity file if specified + if key_path: + ssh_cmd.extend(["-i", key_path]) + + # Add any additional SSH arguments + if ssh_args: + ssh_cmd.extend(ssh_args) + + # Connect via SSH (this will be interactive) + result = subprocess.run(ssh_cmd) + + # Check if SSH connection failed + if result.returncode != 0 and result.returncode != 255: + console.print(f"\n[yellow]SSH connection exited with code {result.returncode}[/yellow]") + + cleanup() + + except KeyboardInterrupt: + console.print("\n[yellow]SSH connection interrupted[/yellow]") + cleanup() + raise typer.Exit(130) except APIError as e: console.print(f"[red]Error:[/red] {str(e)}") + cleanup() raise typer.Exit(1) + except typer.Exit: + raise except Exception as e: console.print(f"[red]Unexpected error:[/red] {escape(str(e))}") console.print_exception(show_locals=True) + cleanup() raise typer.Exit(1)