diff --git a/CLAUDE.md b/CLAUDE.md
index a15c45c..0b9f14b 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -1,304 +1,223 @@
-# CLAUDE.md
+# Flash Worker (worker-flash)

-This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
+> Auto-generated by /analyze-repos on 2026-02-22. Manual edits will be overwritten on next analysis.

 ## Project Overview

-This is `worker-flash`, a RunPod Serverless worker template that provides dynamic GPU provisioning for ML workloads with transparent execution and persistent workspace management. The project consists of two main components:
-
-1. **RunPod Worker Handler** (`src/handler.py`) - A serverless function that executes remote Python functions with dependency management and workspace support
-2. **Flash SDK** (pip dependency) - Python library for distributed inference and serving of ML models
-
-## Key Areas of Responsibility
-
-### 1. Remote Function Execution Engine (`src/`)
-- **Core Handler** (`src/handler.py:18`): Main RunPod serverless entry point that orchestrates remote execution
-- **Remote Executor** (`src/remote_executor.py:11`): Central orchestrator that coordinates all execution components using composition pattern
-- **Function Executor** (`src/function_executor.py:12`): Handles individual function execution with full output capture (stdout, stderr, logs)
-- **Class Executor** (`src/class_executor.py:14`): Manages class instantiation and method execution with instance persistence and metadata tracking
-
-### 2. Dependency Management System (`src/dependency_installer.py:14`)
-- **Python Package Installation**: UV-based package management with environment-aware configuration (Docker vs local)
-- **System Package Installation**: APT/Nala-based system dependency handling with acceleration support
-- **Differential Installation**: Optimized package installation that skips already-installed packages
-- **Environment Detection**: Automatic Docker vs local environment detection for appropriate installation methods
-- **System Package Filtering**: Intelligent detection of system-available packages to avoid redundant installation
-- **Universal Subprocess Integration**: All subprocess operations use centralized logging utility
-
-### 3. Universal Subprocess Utility (`src/subprocess_utils.py`)
-- **Centralized Subprocess Operations**: All subprocess calls use `run_logged_subprocess` for consistency
-- **Automatic Logging Integration**: All subprocess output flows through log streamer at DEBUG level
-- **Environment-Aware Execution**: Handles Docker vs local environment differences automatically
-- **Standardized Error Handling**: Consistent FunctionResponse pattern for all subprocess operations
-- **Timeout Management**: Configurable timeouts with proper cleanup on timeout/cancellation
-
-### 4. Serialization & Protocol Management
-- **Protocol Definitions** (`runpod_flash.protos.remote_execution`): Pydantic models for request/response with validation
-- **Serialization Utils** (`src/serialization_utils.py`): CloudPickle-based data serialization for function arguments and results
-- **Base Executor** (`src/base_executor.py`): Common execution interface and environment setup
-
-### 5. Flash SDK Integration (pip dependency)
-- **Installation**: Installed via pip from GitHub repository
-- **Client Interface**: `@remote` decorator for marking functions for remote execution
-- **Resource Management**: GPU/CPU configuration and provisioning through LiveServerless objects
-- **Live Serverless**: Dynamic infrastructure provisioning with auto-scaling
-- **Repository**: https://github.com/runpod/flash
-
-### 6. Testing Infrastructure (`tests/`)
-- **Unit Tests** (`tests/unit/`): Component-level testing for individual modules with mocking
-- **Integration Tests** (`tests/integration/`): End-to-end workflow testing with real execution
-- **Test Fixtures** (`tests/conftest.py:1`): Shared test data, mock objects, and utility functions
-- **Handler Testing**: Local execution validation with JSON test files (`src/tests/`)
-  - **Full Coverage**: All handler tests pass with environment-aware dependency installation
-  - **Cross-Platform**: Works correctly in both Docker containers and local macOS/Linux environments
-
-### 7. Build & Deployment Pipeline
-- **Docker Containerization**: GPU (`Dockerfile`) and CPU (`Dockerfile-cpu`) image builds
-- **CI/CD Pipeline**: Automated testing, linting, and releases (`.github/workflows/`)
-- **Quality Gates** (`Makefile:104`): Format checking, type checking, test coverage requirements
-- **Release Management**: Automated semantic versioning and Docker Hub deployment
-
-### 8. Configuration & Constants
-- **Constants** (`src/constants.py`): System-wide configuration values (NAMESPACE, LARGE_SYSTEM_PACKAGES)
-- **Environment Configuration**: RunPod API integration
+worker-flash (v1.0.1) is a RunPod Serverless worker that executes `@remote` functions and classes inside GPU/CPU containers. It receives a serialized `FunctionRequest` (cloudpickle + base64), installs dependencies on the fly, executes user code, and returns a `FunctionResponse`.
+
+Two modes: Live Serverless (dynamic code per request) and Flash Deployed Apps (pre-deployed artifacts). Dual entry points: RunPod serverless handler (QB) and FastAPI LB handler. Python 3.10-3.14; base images: pytorch 2.9.1 (GPU), python:3.12-slim (CPU).

 ## Architecture

-### Core Components
+### Key Abstractions

-- **`src/handler.py`**: Main RunPod serverless handler implementing composition pattern
-  - Executes arbitrary Python functions remotely with workspace support
-  - Handles dynamic installation of Python and system dependencies with differential updates
-  - Serializes/deserializes function arguments and results using cloudpickle
-  - Captures stdout, stderr, and logs from remote execution
+1. **RemoteExecutor** (`src/remote_executor.py:30`) -- Central orchestrator using composition pattern. Coordinates dependency installation, execution routing (function vs class), log streaming.
+2. **DependencyInstaller** (`src/dependency_installer.py:12`) -- Python (uv/pip) and system (apt/nala) package installation with env-aware config (Docker vs local).
+3. **FunctionExecutor** (`src/function_executor.py:12`) -- Sync/async function execution via `exec()` with stdout/stderr/log capture.
+4. **ClassExecutor** (`src/class_executor.py:14`) -- Class instantiation, method dispatch, instance persistence in unbounded registry.
+5. **CacheSyncManager** (`src/cache_sync_manager.py:12`) -- Bidirectional cache sync between local `/root/.cache` and network volume tarballs.

-- **`runpod_flash.protos.remote_execution`**: Protocol definitions from runpod-flash
-  - `FunctionRequest`: Defines function execution requests with dependencies
-  - `FunctionResponse`: Standardized response format with success/error handling
-  - Imported from installed runpod-flash package via `from runpod_flash.protos.remote_execution import ...`
+### Entry Points

-### Key Patterns
+- **QB Handler** (`src/handler.py`) -- RunPod serverless entry via `runpod.serverless.start()`. Receives jobs from queue, delegates to `RemoteExecutor`.
+- **LB Handler** (`src/lb_handler.py`) -- FastAPI app served by uvicorn. HTTP endpoints for load-balanced requests.

-1. **Remote Function Execution**: Functions decorated with `@remote` are automatically executed on RunPod GPU workers
-2. **Composition Pattern**: RemoteExecutor uses specialized components (DependencyInstaller, Executors)
-3. **Dynamic Dependency Management**: Dependencies specified in decorators are installed at runtime with differential updates
-4. **Universal Subprocess Operations**: All subprocess calls use centralized `run_logged_subprocess` for consistent logging and error handling
-5. **Environment-Aware Configuration**: Automatic Docker vs local environment detection for appropriate installation methods
-6. **Serialization**: Uses cloudpickle + base64 encoding for function arguments and results
-7. **Resource Configuration**: `LiveServerless` objects define GPU requirements, scaling, and worker configuration
+### Module Structure

-## Code Intelligence with MCP
+```
+src/
+  handler.py                 # RunPod serverless entry point (QB mode)
+  lb_handler.py              # FastAPI Load Balancer entry point (LB mode)
+  remote_executor.py         # Central orchestrator (composition: DependencyInstaller + Executors)
+  function_executor.py       # Function execution with stdout/stderr/log capture
+  class_executor.py          # Class instantiation, method dispatch, instance persistence
+  dependency_installer.py    # Python (uv/pip) + system (apt/nala) package installation
+  serialization_utils.py     # CloudPickle + base64 encode/decode utilities
+  subprocess_utils.py        # Centralized subprocess with logging via run_logged_subprocess
+  log_streamer.py            # Thread-safe log buffering for captured output
+  logger.py                  # Logging configuration
+  cache_sync_manager.py      # Network volume <-> local cache bidirectional sync
+  manifest_reconciliation.py # TTL-based flash_manifest.json refresh
+  unpack_volume.py           # Build artifact extraction from network volume
+  constants.py               # Named constants (NAMESPACE, LARGE_SYSTEM_PACKAGES)
+```

-This project has a worker-flash-code-intel MCP server configured for efficient codebase exploration. The code intelligence index includes both project source code and the runpod_flash dependency.
+## Public API Surface

-### Indexed Codebase
+No public Python API. The worker exposes its interface through:

-The following are automatically indexed and searchable via MCP tools:
-- **Project source** (`src/`) - All 83 worker-flash symbols
-- **runpod_flash dependency** - All 552 protocol definitions, resources, and core components
-  - Protocol definitions: `runpod_flash.protos.remote_execution` (`FunctionRequest`, `FunctionResponse`, etc.)
-  - Resources: `runpod_flash.core.resources` (`LiveServerless`, `Serverless`, `NetworkVolume`, etc.)
-  - Stubs: `runpod_flash.stubs` (stub implementations for local development)
+- **QB protocol**: `FunctionRequest` in, `FunctionResponse` out (via `runpod.serverless.start()`)
+- **LB protocol**: HTTP endpoints mapped from `@remote(method=..., path=...)` decorators
+- **Protocol definitions**: imported from `runpod_flash.protos.remote_execution`
+
+### Environment Variables

-To regenerate the index (when dependencies change), run: `make index`
+| Variable | Required | Purpose |
+|----------|----------|---------|
+| `RUNPOD_API_KEY` | Yes | RunPod API authentication |
+| `RUNPOD_ENDPOINT_ID` | Auto | Workspace isolation (set by RunPod) |
+| `FLASH_ENDPOINT_TYPE` | Auto | QB vs LB mode selection |
+| `FLASH_RESOURCE_NAME` | Auto | Resource identification |
+| `FLASH_MAIN_FILE` | Deploy | Entry file for deployed apps |
+| `FLASH_APP_VARIABLE` | Deploy | App variable name for deployed apps |
+| `FLASH_BUILD_ARTIFACT_PATH` | Deploy | Path to build artifacts |
+| `FLASH_DISABLE_UNPACK` | Deploy | Skip artifact extraction |
+| `LOG_LEVEL` | No | Logging verbosity (default: INFO) |
+| `HF_HUB_ENABLE_HF_TRANSFER` | No | Accelerated HuggingFace downloads |
+| `HF_HOME` | No | HuggingFace cache location (default: `/hf-cache`) |
+| `HF_TOKEN` | No | Auth for private/gated HF models |

-To add more dependencies to the index, edit `DEPENDENCIES_TO_INDEX` in `scripts/ast_to_sqlite.py`.
+## Cross-Repo Dependencies

-### MCP Tools for Code Intelligence
+### Depends On

-**Always prefer these MCP tools over Grep/Glob for semantic code searches:**
+- **flash** (`runpod_flash` package) -- imports `FunctionRequest`, `FunctionResponse`, `RemoteExecutorStub`, `ServiceRegistry`, `StateManagerClient` from `runpod_flash.protos` and `runpod_flash.runtime`.
+- **runpod-python** (`runpod` package) -- `runpod.serverless.start()` for QB mode handler registration.

-- **`find_symbol(symbol)`** - Find classes, functions, methods by name (supports partial matches)
-  - Example: Finding `RemoteExecutor` class, `FunctionRequest` protocol, or `handler` function
-  - Use instead of: `grep -r "class RemoteExecutor"` or `glob "**/*.py"`
+### Depended On By

-- **`list_classes()`** - Get all classes in codebase
-  - Use instead of: `grep -r "^class "`
+- **flash** -- builds Docker images that run this worker. Docker image names hardcoded in flash's `core/resources/constants.py`.
+- **flash-examples** -- indirectly; user code runs inside this worker.

-- **`get_class_interface(class_name)`** - Get class methods without implementations
-  - Example: `get_class_interface("DependencyInstaller")` to see available methods
-  - Example: `get_class_interface("FunctionRequest")` to see protocol fields
-  - Use instead of: Reading full file and parsing manually
+### Interface Contracts

-- **`list_file_symbols(file_path)`** - List all symbols (classes, functions) in a specific file
-  - Use instead of: `grep` on individual files for symbol discovery
+- **FunctionRequest/FunctionResponse protocol** -- the primary contract between flash and flash-worker. Any field changes require coordinated releases across both repos.
+- **`FunctionResponse` as generic envelope** -- `subprocess_utils.py` reuses `FunctionResponse` for subprocess results (coupling to protocol schema).
+- **Docker image tags** -- flash deploys specific image tags; worker image names/tags must match flash's `constants.py`.
+- **Manifest schema** -- `manifest_reconciliation.py` parses `flash_manifest.json` generated by flash's build step.

-- **`find_by_decorator(decorator)`** - Find functions/classes with specific decorators
-  - Example: `find_by_decorator("remote")` to find all @remote decorated functions
-  - Use instead of: `grep -r "@remote"`
+### Dependency Chain

-### Tool Selection Guidelines
+```
+flash-examples --> flash (runpod_flash) --> runpod-python (runpod)
+flash-worker --> flash (protocols) --> runpod-python (serverless.start)
+```

-**When to use MCP vs Grep/Glob:**
-- **MCP tools**: Semantic searches (class names, function definitions, decorators, symbols) - including runpod_flash
-- **Grep**: Content searches (error messages, comments, string literals, log statements)
-- **Glob**: File path patterns when you know the exact file structure
-- **Task tool with Explore agent**: Complex multi-step exploration requiring multiple searches
+### Known Drift

-**Example workflow:**
-- "Find all @remote functions" → use `find_by_decorator("remote")`
-- "Where is RemoteExecutor defined" → use `find_symbol("RemoteExecutor")`
-- "What fields does FunctionRequest protocol have" → use `get_class_interface("FunctionRequest")`
-- "Where is LiveServerless used" → use `find_symbol("LiveServerless")`
-- "Where is error 'API timeout' logged" → use Grep
-- "Find all test_*.json files" → use Glob
+- Python version: runpod-python supports 3.8+, flash-worker requires 3.10+
+- Coverage thresholds: runpod-python 90%, flash 35%, flash-worker 35%

 ## Development Commands

-### Setup and Dependencies
+### Setup
+
 ```bash
 make setup              # Initialize project and sync dependencies
-make dev                # Install all development dependencies (includes pytest, ruff)
-uv sync                 # Sync production dependencies only
-uv sync --all-groups    # Sync all dependency groups (same as make dev)
+make dev                # Install all development dependencies
+uv sync --all-groups    # Alternative: sync all dependency groups
 ```

-### Code Quality
+### Testing
+
 ```bash
-make lint               # Check code with ruff linter
-make lint-fix           # Auto-fix linting issues
-make format             # Format code with ruff
-make format-check       # Check if code is properly formatted
-make quality-check      # Run all quality checks (format, lint, test coverage)
+make test               # Run all tests
+make test-unit          # Unit tests only
+make test-integration   # Integration tests only
+make test-coverage      # Tests with coverage report
+make test-fast          # Tests with fail-fast mode
+make test-handler       # Test handler locally with all test_*.json files (matches CI)
 ```

-### Testing Commands
+### Quality
+
 ```bash
-make test               # Run all tests
-make test-unit          # Run unit tests only
-make test-integration   # Run integration tests only
-make test-coverage      # Run tests with coverage report
-make test-fast          # Run tests with fail-fast mode
-make test-handler       # Test handler locally with all test_*.json files (same as CI)
+make quality-check      # REQUIRED BEFORE ALL COMMITS (format + lint + tests + coverage)
+make lint               # Ruff linter
+make lint-fix           # Auto-fix lint issues
+make format             # Ruff formatter
+make format-check       # Check formatting
+make typecheck          # mypy type checking
 ```

-### Docker Operations
+### Build and Deploy
+
 ```bash
-make build              # Build GPU Docker image (linux/amd64)
-make build-cpu          # Build CPU-only Docker image
-# Note: Docker push is automated via GitHub Actions on release
+make build              # Build GPU Docker image (single-platform, loads locally)
+make build-cpu          # Build CPU-only Docker image
+make build-lb           # Build Load Balancer image
+make build-wip          # Multi-platform build, pushes to Docker Hub (NOT visible in docker images)
+make smoketest          # Test built images locally
+make smoketest-lb       # Test LB images locally
 ```

-## Configuration
+### Code Intelligence

-### Environment Variables
-- `RUNPOD_API_KEY`: Required for RunPod Serverless integration
-- `RUNPOD_ENDPOINT_ID`: Used for workspace isolation (automatically set by RunPod)
-- `HF_HUB_ENABLE_HF_TRANSFER`: Set to "1" in Dockerfile to enable accelerated HuggingFace downloads
-- `HF_TOKEN`: Optional authentication token for private/gated HuggingFace models
-- `HF_HOME=/hf-cache`: HuggingFace cache location, set outside `/root/.cache` to exclude from volume sync
-- `DEBIAN_FRONTEND=noninteractive`: Set during system package installation
-
-### Resource Configuration
-Configure GPU resources using `LiveServerless` objects:
-```python
-gpu_config = LiveServerless(
-    name="my-endpoint",           # Endpoint name (required)
-    gpus=[GpuGroup.ANY],          # GPU types
-    workersMax=5,                 # Max concurrent workers
-    workersMin=0,                 # Min workers (0 = scale to zero)
-    idleTimeout=5,                # Minutes before scaling down
-    executionTimeoutMs=600000,    # Max execution time
-)
+```bash
+make index              # Rebuild MCP code intelligence index
 ```

-## Testing and Quality
+## Code Health

-### Testing Framework
-- **pytest** with coverage reporting and async support
-- **Unit tests** (`tests/unit/`): Test individual components in isolation
-- **Integration tests** (`tests/integration/`): Test end-to-end workflows
-- **Coverage target**: 35% minimum, with HTML and XML reports
-- **Test fixtures**: Shared test data and mocks in `tests/conftest.py`
-- **CI Integration**: Tests run on all PRs and before releases/deployments
+### High Severity

-## Development Notes
+- None critical. Architecture is clean with composition pattern.

-### Dependency Management
-- Root project uses `uv` with `pyproject.toml`
-- Runpod Flash SDK installed as pip dependency from GitHub repository
-- System dependencies installed via `apt-get` in containerized environment
-- Python dependencies installed via `uv pip install` at runtime
-- **Differential Installation**: Only installs packages missing from environment
-- **Environment Awareness**: Uses appropriate python preferences (Docker: `--python-preference=only-system`, Local: managed python)
+### Medium Severity

-### Error Handling
-- All remote execution wrapped in try/catch with full traceback capture
-- Structured error responses via `FunctionResponse.error`
-- Combined stdout/stderr/log capture for debugging
+- `ExecuteFunction` is 150 lines (`remote_executor.py:58`) -- main execution path, consider extracting sub-methods
+- `sync_to_volume` is 170 lines (`cache_sync_manager.py:103`) -- complex tarball sync logic
+- `RemoteExecutor` only tested indirectly -- direct unit tests needed

-### Security Considerations
-- Functions execute arbitrary Python code in sandboxed containers
-- System package installation requires root privileges in container
-- No secrets should be committed to repository
-- API keys passed via environment variables
+### Low Severity

-## File Structure Highlights
+- `_UNPACKED` flag set before extraction completes (`unpack_volume.py:130`) -- race condition if extraction fails
+- Off-by-one retry sleep in `unpack_volume.py:147` -- first retry has no backoff
+- Unbounded instance registry in `class_executor.py:19` -- memory leak for long-running workers with many unique classes
+- No mutable defaults, no bare except, no `print()`, no TODOs, no commented-out code -- clean

-```
-├── src/                          # Core implementation
-│   ├── handler.py                # Main serverless function handler
-│   ├── remote_executor.py        # Central execution orchestrator
-│   ├── function_executor.py      # Function execution with output capture
-│   ├── class_executor.py         # Class execution with persistence
-│   ├── dependency_installer.py   # Python and system dependency management
-│   ├── serialization_utils.py    # CloudPickle serialization utilities
-│   ├── base_executor.py          # Common execution interface
-│   ├── constants.py              # System-wide configuration constants
-│   └── tests/                    # Handler test JSON files
-├── tests/                        # Comprehensive test suite
-│   ├── conftest.py               # Shared test fixtures
-│   ├── unit/                     # Unit tests for individual components
-│   └── integration/              # End-to-end integration tests
-├── Dockerfile                    # GPU container definition
-├── Dockerfile-cpu                # CPU container definition
-└── Makefile                      # Development commands and quality gates
-```
+## Testing
+
+### Structure
+
+- `tests/unit/` -- component-level testing with mocking
+- `tests/integration/` -- end-to-end workflow testing
+- `tests/conftest.py` -- shared fixtures and mock objects
+- `src/tests/` -- handler test JSON files for `make test-handler`
+- Coverage threshold: 35% minimum
+
+### Coverage Gaps
+
+| File | Coverage | Risk |
+|------|----------|------|
+| `log_streamer.py` | None | MEDIUM -- thread-safe buffering untested |
+| `subprocess_utils.py` | None | HIGH -- all subprocess calls flow through here |
+| `logger.py` | None | LOW |
+| `remote_executor.py` | Indirect only | HIGH -- central orchestrator needs direct tests |
+
+### Patterns
+
+- Arrange-Act-Assert in all tests
+- Mock external services (RunPod API, file system for volumes)
+- Use `make test-handler` for handler validation (matches CI behavior)
+- Do NOT run individual test files manually with `RUNPOD_TEST_INPUT` -- use `make test-handler`
+
+### Docker Testing
+
+- Docker containers should never reference `src/` paths directly
+- Use `make build` for local testing (visible in `docker images`)
+- Use `make build-wip` only for pushing to Docker Hub (NOT visible locally)
+
+## Code Intelligence (MCP)
+
+**Server:** `worker-flash-code-intel`
+
+**Always prefer MCP tools over Grep/Glob for semantic code searches.**
+
+| Tool | Use Case | Example |
+|------|----------|---------|
+| `find_symbol(symbol)` | Find classes, functions, methods by name | `find_symbol("RemoteExecutor")` |
+| `list_classes()` | Get all classes in codebase | Exploring class hierarchy |
+| `get_class_interface(class_name)` | Inspect class methods/properties | `get_class_interface("DependencyInstaller")` |
+| `list_file_symbols(file_path)` | View file structure | `list_file_symbols("src/handler.py")` |
+| `find_by_decorator(decorator)` | Find decorated items | `find_by_decorator("remote")` |
+
+**Indexed codebase includes:**
+- Project source (`src/`) -- all 83 worker-flash symbols
+- `runpod_flash` dependency -- all 552 protocol definitions, resources, and core components
+
+**When to use Grep instead:** Content searches (error messages, string literals, log statements, env var usage).
+
+**Rebuild index when dependencies change:** `make index`

-## CI/CD and Release Process
-
-### Automated Releases
-- Uses `release-please` for automated semantic versioning and changelog generation
-- Releases are triggered by conventional commit messages on `main` branch
-- Docker images are automatically built and pushed to Docker Hub (`runpod/flash`) on release
-
-### GitHub Actions Workflows
-- **CI/CD** (`.github/workflows/ci.yml`): Single workflow handling tests, linting, releases, and Docker builds
-  - Runs tests and linting on PRs and pushes to main
-  - **Local execution testing**: Automatically tests all `test_*.json` files in src directory to validate handler functionality
-  - Manages releases via `release-please` on main branch
-  - Builds and pushes `:main` tagged images on main branch pushes
-  - Builds and pushes production images with semantic versioning on releases
-  - Supports manual triggering via `workflow_dispatch` for ad-hoc runs
-
-### Required Secrets
-Configure these in GitHub repository settings:
-- `DOCKERHUB_USERNAME`: Docker Hub username
-- `DOCKERHUB_TOKEN`: Docker Hub password or access token
-
-## Branch Information
-- Main branch: `main`
-- Current branch: `tmp/deployed-execution`
-
-## Development Best Practices
-
-- Always run `make quality-check` before committing changes
-- Always use `git mv` when moving existing files around
-- Run `make test-handler` to validate handler functionality with test files
-- Never create files unless absolutely necessary for achieving goals
-- Always prefer editing existing files to creating new ones
-- Never proactively create documentation files unless explicitly requested
-
-## Project Memories
-
-### Docker Guidelines
-- Docker container should never refer to src/
-
-### Testing Guidelines
-- Use `make test-handler` to run checks on test files
-- Do not run individual test files manually like `Bash(env RUNPOD_TEST_INPUT="$(cat test_input.json)" PYTHONPATH=. uv run python handler.py)`
-
-### File Management
-- Use `git mv` when moving existing files
-- Prefer editing existing files over creating new ones
-- Only create files when absolutely necessary
+---
+*Last analyzed: 2026-02-22*
diff --git a/docs/Runtime_Execution_Paths.md b/docs/Runtime_Execution_Paths.md
index 12dc257..d5ad26d 100644
--- a/docs/Runtime_Execution_Paths.md
+++ b/docs/Runtime_Execution_Paths.md
@@ -71,15 +71,17 @@ graph TB

 The handler automatically detects the deployment mode using environment variables:

-| Environment | RUNPOD_POD_ID | FLASH_* vars | Mode Detected |
-|-------------|---------------|--------------|---------------|
+| Environment | RUNPOD_ENDPOINT_ID | FLASH_* vars | Mode Detected |
+|-------------|-------------------|--------------|---------------|
 | Local dev | ❌ Not set | ❌ Not set | Live Serverless only |
 | Live Serverless | ✅ Set | ❌ Not set | Live Serverless |
-| Flash Mothership | ✅ Set | ✅ FLASH_IS_MOTHERSHIP=true | Flash Deployed |
+| Flash LB Endpoint | ✅ Set | ✅ FLASH_ENDPOINT_TYPE=lb | Flash Deployed |
+| Flash QB Endpoint | ✅ Set | ✅ FLASH_ENDPOINT_TYPE=qb | Flash Deployed |
 | Flash Child | ✅ Set | ✅ FLASH_RESOURCE_NAME | Flash Deployed |
 Flash-specific environment variables:
-- `FLASH_IS_MOTHERSHIP=true` - Set for mothership endpoints
+- `FLASH_ENDPOINT_TYPE=lb` - Set for load-balanced endpoints
+- `FLASH_ENDPOINT_TYPE=qb` - Set for queue-based endpoints
 - `FLASH_RESOURCE_NAME` - Specifies resource config name

 ## Request Format Differences
diff --git a/pyproject.toml b/pyproject.toml
index 7fecb2c..4d8dae9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,7 +14,7 @@ dependencies = [
     "huggingface_hub>=0.32.0",
     "fastapi>=0.115.0",
     "uvicorn[standard]>=0.34.0",
-    "runpod-flash",
+    "runpod-flash>=1.4.0",
 ]

 [dependency-groups]
diff --git a/src/constants.py b/src/constants.py
index 459d690..2a7c7bc 100644
--- a/src/constants.py
+++ b/src/constants.py
@@ -47,6 +47,6 @@
 """Default timeout in seconds for cross-endpoint HTTP requests."""

 DEFAULT_TARBALL_UNPACK_ATTEMPTS = 3
-"""Number of times the mothership CPU will attempt to unpack the worker-flash tarball from mounted volume"""
+"""Number of times the Flash-deployed endpoint will attempt to unpack the worker-flash tarball from mounted volume."""

 DEFAULT_TARBALL_UNPACK_INTERVAL = 30
-"""Time in seconds mothership CPU endpoint will wait between tarball unpack attempts"""
+"""Time in seconds the Flash-deployed endpoint will wait between tarball unpack attempts."""
diff --git a/src/handler.py b/src/handler.py
index 5066690..9f6276a 100644
--- a/src/handler.py
+++ b/src/handler.py
@@ -1,7 +1,9 @@
-from typing import Dict, Any
+import importlib.util
+import logging
+import os
+from pathlib import Path
+from typing import Any, Dict, Optional

-from runpod_flash.protos.remote_execution import FunctionRequest, FunctionResponse
-from remote_executor import RemoteExecutor
 from logger import setup_logging
 from unpack_volume import maybe_unpack

@@ -12,25 +14,129 @@
 # This is a no-op for Live Serverless and local development
 maybe_unpack()

+logger = logging.getLogger(__name__)

-async def handler(event: Dict[str, Any]) -> Dict[str, Any]:
-    """
-    RunPod serverless function handler with dependency installation.
+
+def _load_generated_handler() -> Optional[Any]:
+    """Load Flash-generated handler if available (deployed QB mode).
+
+    Checks for a handler_{resource_name}.py file generated by the flash
+    build pipeline. These handlers accept plain JSON input without
+    FunctionRequest/cloudpickle serialization.
+
+    Returns:
+        Handler function if generated handler found, None otherwise.
     """
-    output: FunctionResponse
+    resource_name = os.getenv("FLASH_RESOURCE_NAME")
+    if not resource_name:
+        return None
+
+    handler_file = Path(f"/app/handler_{resource_name}.py")
+
+    if not handler_file.resolve().is_relative_to(Path("/app").resolve()):
+        logger.warning(
+            "FLASH_RESOURCE_NAME '%s' resolves outside /app. "
+            "Falling back to FunctionRequest handler.",
+            resource_name,
+        )
+        return None
+
+    if not handler_file.exists():
+        logger.warning(
+            "Generated handler file %s not found for resource '%s'. "
+            "The build artifact may be incomplete. "
+            "Falling back to FunctionRequest handler.",
+            handler_file,
+            resource_name,
+        )
+        return None
+
+    spec = importlib.util.spec_from_file_location(f"handler_{resource_name}", handler_file)
+    if not spec or not spec.loader:
+        logger.warning("Failed to create module spec for %s", handler_file)
+        return None
+
+    mod = importlib.util.module_from_spec(spec)
     try:
-        executor = RemoteExecutor()
-        input_data = FunctionRequest(**event.get("input", {}))
-        output = await executor.ExecuteFunction(input_data)
-
-    except Exception as error:
-        output = FunctionResponse(
-            success=False,
-            error=f"Error in handler: {str(error)}",
+        spec.loader.exec_module(mod)
+    except ImportError as e:
+        logger.warning(
+            "Generated handler %s failed to import (missing dependency: %s). "
+            "Deploy with --use-local-flash to include latest runpod_flash. "
+            "Falling back to FunctionRequest handler.",
+            handler_file,
+            e,
+        )
+        return None
+    except SyntaxError as e:
+        logger.error(
+            "Generated handler %s has a syntax error: %s. "
+            "This indicates a bug in the flash build pipeline. "
+            "Falling back to FunctionRequest handler.",
+            handler_file,
+            e,
+        )
+        return None
+    except Exception as e:
+        logger.error(
+            "Generated handler %s failed to load unexpectedly: %s (%s). "
+            "Falling back to FunctionRequest handler.",
+            handler_file,
+            e,
+            type(e).__name__,
+            exc_info=True,
         )
+        return None
+
+    generated = getattr(mod, "handler", None)
+    if generated is None:
+        logger.warning(
+            "Generated handler %s loaded but has no 'handler' attribute. "
+            "Ensure the flash build pipeline generates a 'handler' function. "
+            "Falling back to FunctionRequest handler.",
+            handler_file,
+        )
+        return None
+
+    if not callable(generated):
+        logger.warning(
+            "Generated handler %s has a 'handler' attribute but it is not callable (%s). "
+            "Falling back to FunctionRequest handler.",
+            handler_file,
+            type(generated).__name__,
+        )
+        return None
+
+    logger.info("Loaded generated handler from %s", handler_file)
+    return generated
+
+
+# Try generated handler first (plain JSON mode for deployed QB endpoints)
+_generated = _load_generated_handler()
+
+if _generated:
+    handler = _generated
+else:
+    # Fallback: original FunctionRequest handler (backward compatible)
+    from runpod_flash.protos.remote_execution import FunctionRequest, FunctionResponse
+    from remote_executor import RemoteExecutor
+
+    async def handler(event: Dict[str, Any]) -> Dict[str, Any]:
+        """RunPod serverless function handler with dependency installation."""
+        output: FunctionResponse
+
+        try:
+            executor = RemoteExecutor()
+            input_data = FunctionRequest(**event.get("input", {}))
+            output = await executor.ExecuteFunction(input_data)
+
+        except Exception as error:
+            output = FunctionResponse(
+                success=False,
+                error=f"Error in handler: {str(error)}",
+            )

-    return output.model_dump()  # type: ignore[no-any-return]
+        return output.model_dump()  # type: ignore[no-any-return]


 # Start the RunPod serverless handler (only available on RunPod platform)
diff --git a/src/lb_handler.py b/src/lb_handler.py
index 69686e9..de2f497 100644
--- a/src/lb_handler.py
+++ b/src/lb_handler.py
@@ -3,18 +3,18 @@
 This handler provides a FastAPI application for the Load Balancer runtime.
 It supports:
 - /ping: Health check endpoint (required by RunPod Load Balancer)
-- /execute: Remote function execution via HTTP POST (queue-based mode)
-- User's FastAPI app routes (mothership mode)
+- /execute: Remote function execution via HTTP POST (QB endpoint mode)
+- User's FastAPI app routes (LB endpoint mode)

 The handler uses worker-flash's RemoteExecutor for function execution.

-Mothership Mode (FLASH_IS_MOTHERSHIP=true):
-- Imports user's FastAPI application from FLASH_MAIN_FILE
-- Loads the app object from FLASH_APP_VARIABLE
+LB Endpoint Mode (FLASH_ENDPOINT_TYPE=lb):
+- Auto-discovers generated handler from FLASH_RESOURCE_NAME
+- Loads handler_{resource_name}.py with FastAPI app
 - Preserves all user routes and middleware
 - Adds /ping health check endpoint

-Queue-Based Mode (FLASH_IS_MOTHERSHIP not set or false):
+QB Endpoint Mode (FLASH_ENDPOINT_TYPE not set or not "lb"):
 - Creates generic FastAPI app with /execute endpoint
 - Uses RemoteExecutor for function execution
 """
@@ -22,6 +22,7 @@
 import importlib.util
 import logging
 import os
+from pathlib import Path
 from typing import Any, Dict

 from fastapi import FastAPI
@@ -42,67 +43,98 @@
 from runpod_flash.protos.remote_execution import FunctionRequest, FunctionResponse  # noqa: E402
 from remote_executor import RemoteExecutor  # noqa: E402

-# Determine mode based on environment variables
-is_mothership = os.getenv("FLASH_IS_MOTHERSHIP") == "true"

-if is_mothership:
-    # Mothership mode: Import user's FastAPI application
-    try:
-        main_file = os.getenv("FLASH_MAIN_FILE", "main.py")
-        app_variable = os.getenv("FLASH_APP_VARIABLE", "app")
+def _is_lb_endpoint() -> bool:
+    """Determine if this endpoint runs in LB mode (serves user FastAPI routes)."""
+    return os.getenv("FLASH_ENDPOINT_TYPE") == "lb"

-        logger.info(f"Mothership mode: Importing {app_variable} from {main_file}")

-        # Dynamic import of user's module
-        spec = importlib.util.spec_from_file_location("user_main", main_file)
-        if spec is None or spec.loader is None:
-            raise ImportError(f"Cannot find or load {main_file}")
+def _discover_lb_app(handler_dir: str = "/app") -> FastAPI:
+    """Auto-discover and load the generated LB handler's FastAPI app.

-        user_module = importlib.util.module_from_spec(spec)
-        spec.loader.exec_module(user_module)
+    Derives handler path from FLASH_RESOURCE_NAME and imports the module.

-        # Get the FastAPI app from user's module
-        if not hasattr(user_module, app_variable):
-            raise AttributeError(f"Module {main_file} does not have '{app_variable}' attribute")
+    Args:
+        handler_dir: Base directory for handler files (default /app).

-        app = getattr(user_module, app_variable)
+    Returns:
+        FastAPI app from the generated handler.

-        if not isinstance(app, FastAPI):
-            raise TypeError(
-                f"Expected FastAPI instance, got {type(app).__name__} for {app_variable}"
-            )
+    Raises:
+        RuntimeError: If FLASH_RESOURCE_NAME is not set or resolves outside handler_dir.
+        FileNotFoundError: If the handler file does not exist.
+        ImportError: If the handler module cannot produce a valid spec.
+        AttributeError: If the handler module lacks an 'app' attribute.
+        TypeError: If the 'app' attribute is not a FastAPI instance.
+    """
+    resource_name = os.getenv("FLASH_RESOURCE_NAME")
+    if not resource_name:
+        raise RuntimeError("FLASH_RESOURCE_NAME not set. 
Cannot discover generated LB handler.") + + handler_file = f"{handler_dir}/handler_{resource_name}.py" + + handler_path = Path(handler_file) + if not handler_path.resolve().is_relative_to(Path(handler_dir).resolve()): + raise RuntimeError(f"FLASH_RESOURCE_NAME '{resource_name}' resolves outside {handler_dir}") + + app_variable = "app" + + spec = importlib.util.spec_from_file_location("user_main", handler_file) + if spec is None or spec.loader is None: + raise ImportError(f"Cannot find or load {handler_file}") + + user_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(user_module) - logger.info(f"Successfully imported FastAPI app '{app_variable}' from {main_file}") + if not hasattr(user_module, app_variable): + raise AttributeError(f"Module {handler_file} does not have '{app_variable}' attribute") + + discovered_app = getattr(user_module, app_variable) + + if not isinstance(discovered_app, FastAPI): + raise TypeError( + f"Expected FastAPI instance, got {type(discovered_app).__name__} for {app_variable}" + ) + + return discovered_app + + +is_lb_endpoint = _is_lb_endpoint() + +if is_lb_endpoint: + # LB endpoint mode: Auto-discover generated handler from FLASH_RESOURCE_NAME + try: + app = _discover_lb_app() + logger.info("Successfully imported FastAPI app for LB endpoint") # Add /ping endpoint if not already present - # Check if /ping route already exists to avoid adding a duplicate health check endpoint ping_exists = any(getattr(route, "path", None) == "/ping" for route in app.routes) if not ping_exists: @app.get("/ping") - async def ping_mothership() -> Dict[str, Any]: - """Health check endpoint for mothership (added by framework).""" + async def ping_lb() -> Dict[str, Any]: + """Health check endpoint for LB (added by framework).""" return { "status": "healthy", - "endpoint": "mothership", + "endpoint": "lb", "id": os.getenv("RUNPOD_ENDPOINT_ID", "unknown"), } logger.info("Added /ping endpoint to user's FastAPI app") except Exception as error: 
-        logger.error(f"Failed to initialize mothership mode: {error}", exc_info=True)
+        logger.error("Failed to initialize LB endpoint mode: %s", error, exc_info=True)
         raise
 else:
     # Queue-based mode: Create generic Load Balancer handler app
     app = FastAPI(title="Load Balancer Handler")
-    logger.info("Queue-based mode: Using generic Load Balancer handler")
+    logger.info("QB endpoint mode: Using generic Load Balancer handler")
 
 
 # Queue-based mode endpoints
-if not is_mothership:
+if not is_lb_endpoint:
 
     @app.get("/ping")
     async def ping() -> Dict[str, Any]:
diff --git a/src/manifest_reconciliation.py b/src/manifest_reconciliation.py
index 960ad1c..0445253 100644
--- a/src/manifest_reconciliation.py
+++ b/src/manifest_reconciliation.py
@@ -30,7 +30,7 @@ def is_flash_deployment() -> bool:
     endpoint_id = os.getenv("RUNPOD_ENDPOINT_ID")
     is_flash = any(
         [
-            os.getenv("FLASH_IS_MOTHERSHIP") == "true",
+            os.getenv("FLASH_ENDPOINT_TYPE") in ("lb", "qb"),
             os.getenv("FLASH_RESOURCE_NAME"),
         ]
     )
@@ -78,8 +79,9 @@
         if is_stale:
             logger.debug(f"Manifest is stale: {age_seconds:.0f}s old (TTL: {ttl_seconds}s)")
         return is_stale
-    except OSError:
-        return True  # Error reading file, consider stale
+    except OSError as e:
+        logger.debug("Cannot stat manifest file %s: %s. Treating as stale.", manifest_path, e)
+        return True
 
 
 async def _fetch_and_save_manifest(
@@ -112,8 +113,16 @@
         logger.info("Manifest refreshed from State Manager")
         return True
 
+    except (OSError, ConnectionError, TimeoutError) as e:
+        logger.warning("Failed to refresh manifest from State Manager: %s", e)
+        return False
     except Exception as e:
-        logger.warning(f"Failed to refresh manifest from State Manager: {e}")
+        logger.error(
+            "Unexpected error refreshing manifest from State Manager: %s (%s)",
+            e,
+            type(e).__name__,
+            exc_info=True,
+        )
         return False
diff --git a/src/remote_executor.py b/src/remote_executor.py
index a5ca85d..5b50581 100644
--- a/src/remote_executor.py
+++ b/src/remote_executor.py
@@ -43,17 +43,22 @@ def __init__(self):
         self.class_executor = ClassExecutor()
         self.cache_sync = CacheSyncManager()
 
-        # Service discovery for cross-endpoint routing (peer-to-peer model)
+        # Service discovery for cross-endpoint routing (Flash Deployed only).
+        # Only init when manifest exists on disk — Live Serverless workers
+        # never have flash_manifest.json; they use @remote stacking instead.
         self.service_registry: Optional[ServiceRegistry] = None
-        if ServiceRegistry is not None:
+        manifest_path = Path(FLASH_MANIFEST_PATH)
+        if ServiceRegistry is not None and manifest_path.exists():
             try:
-                self.service_registry = ServiceRegistry(manifest_path=Path(FLASH_MANIFEST_PATH))
+                self.service_registry = ServiceRegistry(manifest_path=manifest_path)
                 self.logger.debug("Service registry initialized for cross-endpoint routing")
             except Exception as e:
                 self.logger.debug(f"Failed to initialize service registry: {e}")
                 self.service_registry = None
-        else:
+        elif ServiceRegistry is None:
             self.logger.debug("ServiceRegistry not available (runpod-flash not installed)")
+        else:
+            self.logger.debug("No flash_manifest.json, skipping service registry")
 
     async def ExecuteFunction(self, request: FunctionRequest) -> FunctionResponse:
         """
@@ -380,17 +385,15 @@
         # function_name is guaranteed to be non-None by FunctionRequest validation
         func = getattr(module, function_name)
 
-        # Deserialize args/kwargs (same as Live Serverless)
+        # Deserialize args/kwargs from cloudpickle-encoded strings
         args = SerializationUtils.deserialize_args(request.args)
         kwargs = SerializationUtils.deserialize_kwargs(request.kwargs)
 
         # Execute function
-        # Check if async or sync
         if func_details["is_async"]:
             if asyncio.iscoroutinefunction(func):
                 result = await func(*args, **kwargs)
             else:
-                # Run in executor for blocking calls
                 result = await asyncio.to_thread(func, *args, **kwargs)
         else:
             result = await asyncio.to_thread(func, *args, **kwargs)
diff --git a/src/unpack_volume.py b/src/unpack_volume.py
index 14d8706..139d08e 100644
--- a/src/unpack_volume.py
+++ b/src/unpack_volume.py
@@ -85,8 +85,8 @@ def _should_unpack_from_volume() -> bool:
     Detection logic:
     1. Honor explicit disable flag (FLASH_DISABLE_UNPACK)
-    2. Must be in RunPod environment (RUNPOD_POD_ID or RUNPOD_ENDPOINT_ID)
-    3. Must be Flash deployment (any of FLASH_IS_MOTHERSHIP, FLASH_RESOURCE_NAME)
+    2. Must be in RunPod environment (RUNPOD_ENDPOINT_ID)
+    3. Must be Flash deployment (FLASH_ENDPOINT_TYPE or FLASH_RESOURCE_NAME)
 
     Returns:
         bool: True if unpacking should occur, False otherwise
@@ -126,7 +126,6 @@ def maybe_unpack():
     if _UNPACKED:
         return
 
-    _UNPACKED = True
     logger.info("unpacking app from volume")
 
     last_error: Exception | None = None
@@ -139,12 +138,12 @@ def maybe_unpack():
             last_error = e
             logger.error(
                 "failed to unpack app from volume (attempt %s/%s): %s",
-                attempt,
+                attempt + 1,
                 DEFAULT_TARBALL_UNPACK_ATTEMPTS,
                 e,
                 exc_info=True,
             )
-            if attempt < DEFAULT_TARBALL_UNPACK_ATTEMPTS:
+            if attempt < DEFAULT_TARBALL_UNPACK_ATTEMPTS - 1:
                 sleep(DEFAULT_TARBALL_UNPACK_INTERVAL)
     raise RuntimeError(
         f"failed to unpack app from volume after retries: {last_error}"
diff --git a/tests/integration/test_manifest_state_manager.py b/tests/integration/test_manifest_state_manager.py
index 293ecff..db28433 100644
--- a/tests/integration/test_manifest_state_manager.py
+++ b/tests/integration/test_manifest_state_manager.py
@@ -127,7 +127,7 @@ async def test_manifest_refresh_on_cross_endpoint_routing(
         "os.environ",
         {
             "RUNPOD_ENDPOINT_ID": "ep-test-001",
-            "FLASH_IS_MOTHERSHIP": "true",
+            "FLASH_ENDPOINT_TYPE": "lb",
             "RUNPOD_API_KEY": "test-api-key",
         },
         clear=True,
@@ -160,7 +160,7 @@ async def test_manifest_refresh_skipped_if_fresh(
         "os.environ",
         {
             "RUNPOD_ENDPOINT_ID": "ep-test-001",
-            "FLASH_IS_MOTHERSHIP": "true",
+            "FLASH_ENDPOINT_TYPE": "lb",
             "RUNPOD_API_KEY": "test-api-key",
         },
         clear=True,
@@ -196,7 +196,7 @@ async def test_manifest_refresh_continues_on_failure(
         "os.environ",
         {
             "RUNPOD_ENDPOINT_ID": "ep-test-001",
-            "FLASH_IS_MOTHERSHIP": "true",
+            "FLASH_ENDPOINT_TYPE": "lb",
             "RUNPOD_API_KEY": "test-api-key",
         },
         clear=True,
@@ -232,7 +232,7 @@ async def test_state_manager_unavailable_graceful_degradation(
         "os.environ",
         {
             "RUNPOD_ENDPOINT_ID": "ep-test-001",
-            "FLASH_IS_MOTHERSHIP": "true",
+            "FLASH_ENDPOINT_TYPE": "lb",
             "RUNPOD_API_KEY": "test-api-key",
         },
         clear=True,
@@ -320,7 +320,7 @@ async def test_state_manager_overwrites_local(
         "os.environ",
         {
             "RUNPOD_ENDPOINT_ID": "ep-test-001",
-            "FLASH_IS_MOTHERSHIP": "true",
+            "FLASH_ENDPOINT_TYPE": "lb",
             "RUNPOD_API_KEY": "test-api-key",
         },
         clear=True,
@@ -378,7 +378,7 @@ async def test_state_manager_provides_additional_metadata(
         "os.environ",
         {
             "RUNPOD_ENDPOINT_ID": "ep-test-001",
-            "FLASH_IS_MOTHERSHIP": "true",
+            "FLASH_ENDPOINT_TYPE": "lb",
             "RUNPOD_API_KEY": "test-api-key",
         },
         clear=True,
@@ -421,7 +421,7 @@ async def test_fallback_to_local_on_state_manager_error(
         "os.environ",
         {
             "RUNPOD_ENDPOINT_ID": "ep-test-001",
-            "FLASH_IS_MOTHERSHIP": "true",
+            "FLASH_ENDPOINT_TYPE": "lb",
             "RUNPOD_API_KEY": "test-api-key",
         },
         clear=True,
@@ -456,7 +456,7 @@ async def test_manifest_file_write_error(
         "os.environ",
         {
             "RUNPOD_ENDPOINT_ID": "ep-test-001",
-            "FLASH_IS_MOTHERSHIP": "true",
+            "FLASH_ENDPOINT_TYPE": "lb",
             "RUNPOD_API_KEY": "test-api-key",
         },
         clear=True,
@@ -488,7 +488,7 @@ async def test_multiple_refreshes_with_ttl(
         "os.environ",
         {
             "RUNPOD_ENDPOINT_ID": "ep-test-001",
-            "FLASH_IS_MOTHERSHIP": "true",
+            "FLASH_ENDPOINT_TYPE": "lb",
             "RUNPOD_API_KEY": "test-api-key",
         },
         clear=True,
diff --git a/tests/unit/test_handler.py b/tests/unit/test_handler.py
index 2fd319e..ab63135 100644
--- a/tests/unit/test_handler.py
+++ b/tests/unit/test_handler.py
@@ -4,7 +4,7 @@
 import base64
 
 import cloudpickle
 from unittest.mock import patch, AsyncMock
 
-from handler import handler
+from handler import handler, _load_generated_handler
 from runpod_flash.protos.remote_execution import FunctionResponse
@@ -143,3 +143,114 @@
         assert result["success"] is True
         assert "instance_id" in result
         assert "instance_info" in result
+
+
+class TestLoadGeneratedHandler:
+    """Test cases for _load_generated_handler delegation logic."""
+
+    def test_returns_none_when_no_resource_name(self):
+        """Without FLASH_RESOURCE_NAME, returns None (fallback to FunctionRequest)."""
+        with patch.dict("os.environ", {}, clear=True):
+            result = _load_generated_handler()
+        assert result is None
+
+    def test_logs_warning_when_handler_file_missing(self, tmp_path):
+        """With FLASH_RESOURCE_NAME but no handler file, logs warning and returns None."""
+        with patch.dict("os.environ", {"FLASH_RESOURCE_NAME": "gpu_config"}):
+            with patch("handler.Path") as mock_path_cls:
+                mock_path = mock_path_cls.return_value
+                mock_path.exists.return_value = False
+                result = _load_generated_handler()
+        assert result is None
+
+    def test_loads_generated_handler_from_file(self, tmp_path):
+        """With valid generated handler file, loads and returns handler function."""
+        handler_file = tmp_path / "handler_gpu_config.py"
+        handler_file.write_text(
+            "async def handler(event):\n"
+            "    return {'result': event.get('input', {}).get('prompt', 'default')}\n"
+        )
+
+        with patch.dict("os.environ", {"FLASH_RESOURCE_NAME": "gpu_config"}):
+            with patch("handler.Path", return_value=handler_file):
+                result = _load_generated_handler()
+
+        assert result is not None
+        assert callable(result)
+
+    def test_returns_none_when_spec_creation_fails(self):
+        """If importlib cannot create spec, returns None."""
+        with patch.dict("os.environ", {"FLASH_RESOURCE_NAME": "gpu_config"}):
+            with patch("handler.Path") as mock_path_cls:
+                mock_path = mock_path_cls.return_value
+                mock_path.exists.return_value = True
+                with patch(
+                    "handler.importlib.util.spec_from_file_location",
+                    return_value=None,
+                ):
+                    result = _load_generated_handler()
+
+        assert result is None
+
+    def test_returns_none_on_import_error(self, tmp_path):
+        """If generated handler has ImportError, falls back gracefully."""
+        handler_file = tmp_path / "handler_gpu_config.py"
+        handler_file.write_text(
+            "from nonexistent_package import missing_function\ndef handler(event): pass\n"
+        )
+
+        with patch.dict("os.environ", {"FLASH_RESOURCE_NAME": "gpu_config"}):
+            with patch("handler.Path", return_value=handler_file):
+                result = _load_generated_handler()
+
+        assert result is None
+
+    def test_returns_none_on_syntax_error(self, tmp_path):
+        """SyntaxError in generated handler logs error and returns None."""
+        handler_file = tmp_path / "handler_gpu_config.py"
+        handler_file.write_text("def handler(event)\n")  # Missing colon
+
+        with patch.dict("os.environ", {"FLASH_RESOURCE_NAME": "gpu_config"}):
+            with patch("handler.Path", return_value=handler_file):
+                result = _load_generated_handler()
+
+        assert result is None
+
+    def test_returns_none_on_generic_exception(self, tmp_path):
+        """Generic exception during module load falls back gracefully."""
+        handler_file = tmp_path / "handler_gpu_config.py"
+        handler_file.write_text("raise RuntimeError('init failed')\n")
+
+        with patch.dict("os.environ", {"FLASH_RESOURCE_NAME": "gpu_config"}):
+            with patch("handler.Path", return_value=handler_file):
+                result = _load_generated_handler()
+
+        assert result is None
+
+    def test_warns_when_handler_attr_missing(self, tmp_path):
+        """Module without 'handler' attribute logs warning and returns None."""
+        handler_file = tmp_path / "handler_gpu_config.py"
+        handler_file.write_text("def not_a_handler(): pass\n")
+
+        with patch.dict("os.environ", {"FLASH_RESOURCE_NAME": "gpu_config"}):
+            with patch("handler.Path", return_value=handler_file):
+                result = _load_generated_handler()
+
+        assert result is None
+
+    def test_returns_none_when_resource_name_has_path_traversal(self):
+        """Path traversal in FLASH_RESOURCE_NAME returns None."""
+        with patch.dict("os.environ", {"FLASH_RESOURCE_NAME": "../../../etc/passwd"}):
+            result = _load_generated_handler()
+        assert result is None
+
+    def test_returns_none_when_handler_not_callable(self, tmp_path):
+        """Non-callable 'handler' attribute returns None."""
+        handler_file = tmp_path / "handler_gpu_config.py"
+        handler_file.write_text("handler = 42\n")
+
+        with patch.dict("os.environ", {"FLASH_RESOURCE_NAME": "gpu_config"}):
+            with patch("handler.Path", return_value=handler_file):
+                result = _load_generated_handler()
+
+        assert result is None
diff --git a/tests/unit/test_lb_handler.py b/tests/unit/test_lb_handler.py
new file mode 100644
index 0000000..e4369fd
--- /dev/null
+++ b/tests/unit/test_lb_handler.py
@@ -0,0 +1,134 @@
+"""Tests for lb_handler mode detection and LB auto-discovery logic.
+
+Tests import the production functions (_is_lb_endpoint, _discover_lb_app)
+directly from lb_handler by mocking module-level side effects (maybe_unpack,
+RemoteExecutor, etc.) before the import.
+"""
+
+import os
+import sys
+from unittest.mock import MagicMock, patch
+
+import pytest
+from fastapi import FastAPI
+
+
+# Mock heavy dependencies before importing lb_handler to prevent side effects
+_MOCK_MODULES = {
+    "logger": MagicMock(),
+    "unpack_volume": MagicMock(),
+    "remote_executor": MagicMock(),
+    "runpod_flash": MagicMock(),
+    "runpod_flash.protos": MagicMock(),
+    "runpod_flash.protos.remote_execution": MagicMock(),
+}
+
+
+@pytest.fixture(autouse=True)
+def _import_lb_handler():
+    """Import lb_handler with side effects mocked out.
+
+    Patches sys.modules to prevent heavy imports (unpack_volume, RemoteExecutor,
+    runpod_flash) from executing, then imports lb_handler fresh for each test.
+    """
+    # Remove any cached lb_handler import so we get a fresh one
+    sys.modules.pop("lb_handler", None)
+
+    with patch.dict("sys.modules", _MOCK_MODULES):
+        # Prevent module-level _is_lb_endpoint() from triggering LB discovery
+        with patch.dict("os.environ", {}, clear=False):
+            os.environ.pop("FLASH_ENDPOINT_TYPE", None)
+            import lb_handler  # noqa: F811
+
+        yield lb_handler
+
+    sys.modules.pop("lb_handler", None)
+
+
+class TestIsLbEndpoint:
+    """Tests for the _is_lb_endpoint mode detection function."""
+
+    def test_flash_endpoint_type_lb_returns_true(self, _import_lb_handler) -> None:
+        """FLASH_ENDPOINT_TYPE=lb triggers LB mode."""
+        with patch.dict("os.environ", {"FLASH_ENDPOINT_TYPE": "lb"}, clear=False):
+            assert _import_lb_handler._is_lb_endpoint() is True
+
+    def test_no_env_vars_returns_false(self, _import_lb_handler) -> None:
+        """Neither env var set results in QB mode (returns False)."""
+        with patch.dict("os.environ", {}, clear=False):
+            os.environ.pop("FLASH_ENDPOINT_TYPE", None)
+            assert _import_lb_handler._is_lb_endpoint() is False
+
+    def test_flash_endpoint_type_non_lb_value_returns_false(self, _import_lb_handler) -> None:
+        """FLASH_ENDPOINT_TYPE with non-lb value does not trigger LB mode."""
+        with patch.dict("os.environ", {"FLASH_ENDPOINT_TYPE": "qb"}, clear=False):
+            assert _import_lb_handler._is_lb_endpoint() is False
+
+
+class TestDiscoverLbApp:
+    """Tests for the _discover_lb_app auto-discovery function (production code)."""
+
+    def test_raises_when_resource_name_not_set(self, _import_lb_handler) -> None:
+        """Missing FLASH_RESOURCE_NAME raises RuntimeError with clear message."""
+        with patch.dict("os.environ", {}, clear=False):
+            os.environ.pop("FLASH_RESOURCE_NAME", None)
+            with pytest.raises(RuntimeError, match="FLASH_RESOURCE_NAME not set"):
+                _import_lb_handler._discover_lb_app()
+
+    def test_derives_handler_path_from_resource_name(self, _import_lb_handler, tmp_path) -> None:
+        """Handler file path is {handler_dir}/handler_{resource_name}.py."""
+        handler_file = tmp_path / "handler_my_gpu_endpoint.py"
+        handler_file.write_text("from fastapi import FastAPI\napp = FastAPI()\n")
+
+        with patch.dict("os.environ", {"FLASH_RESOURCE_NAME": "my_gpu_endpoint"}, clear=False):
+            app = _import_lb_handler._discover_lb_app(handler_dir=str(tmp_path))
+
+        assert isinstance(app, FastAPI)
+
+    def test_loads_fastapi_app_variable(self, _import_lb_handler, tmp_path) -> None:
+        """Loads the 'app' variable from the generated handler module."""
+        handler_file = tmp_path / "handler_inference.py"
+        handler_file.write_text(
+            "from fastapi import FastAPI\n"
+            "app = FastAPI(title='Test LB Handler')\n"
+            "@app.get('/health')\n"
+            "def health(): return {'ok': True}\n"
+        )
+
+        with patch.dict("os.environ", {"FLASH_RESOURCE_NAME": "inference"}, clear=False):
+            app = _import_lb_handler._discover_lb_app(handler_dir=str(tmp_path))
+
+        assert isinstance(app, FastAPI)
+        assert app.title == "Test LB Handler"
+
+    def test_raises_when_handler_file_missing(self, _import_lb_handler, tmp_path) -> None:
+        """Missing handler file raises FileNotFoundError."""
+        with patch.dict("os.environ", {"FLASH_RESOURCE_NAME": "nonexistent"}, clear=False):
+            with pytest.raises(FileNotFoundError):
+                _import_lb_handler._discover_lb_app(handler_dir=str(tmp_path))
+
+    def test_raises_attribute_error_when_app_missing(self, _import_lb_handler, tmp_path) -> None:
+        """Handler module without 'app' attribute raises AttributeError."""
+        handler_file = tmp_path / "handler_broken.py"
+        handler_file.write_text("x = 42\n")
+
+        with patch.dict("os.environ", {"FLASH_RESOURCE_NAME": "broken"}, clear=False):
+            with pytest.raises(AttributeError, match="does not have 'app' attribute"):
+                _import_lb_handler._discover_lb_app(handler_dir=str(tmp_path))
+
+    def test_raises_type_error_when_app_not_fastapi(self, _import_lb_handler, tmp_path) -> None:
+        """Handler module with non-FastAPI 'app' raises TypeError."""
+        handler_file = tmp_path / "handler_wrong_type.py"
+        handler_file.write_text("app = 'not a FastAPI instance'\n")
+
+        with patch.dict("os.environ", {"FLASH_RESOURCE_NAME": "wrong_type"}, clear=False):
+            with pytest.raises(TypeError, match="Expected FastAPI instance"):
+                _import_lb_handler._discover_lb_app(handler_dir=str(tmp_path))
+
+    def test_raises_when_resource_name_has_path_traversal(
+        self, _import_lb_handler, tmp_path
+    ) -> None:
+        """Path traversal in FLASH_RESOURCE_NAME raises RuntimeError."""
+        with patch.dict("os.environ", {"FLASH_RESOURCE_NAME": "../../../etc/passwd"}, clear=False):
+            with pytest.raises(RuntimeError, match="resolves outside"):
+                _import_lb_handler._discover_lb_app(handler_dir=str(tmp_path))
diff --git a/tests/unit/test_manifest_reconciliation.py b/tests/unit/test_manifest_reconciliation.py
index 2a2e720..824486b 100644
--- a/tests/unit/test_manifest_reconciliation.py
+++ b/tests/unit/test_manifest_reconciliation.py
@@ -41,14 +41,15 @@ def sample_manifest() -> dict:
 class TestIsFlashDeployment:
     """Test Flash deployment detection."""
 
-    def test_is_flash_deployment_mothership(self) -> None:
-        """Test detection with FLASH_IS_MOTHERSHIP."""
+    def test_is_flash_deployment_endpoint_type_lb(self) -> None:
+        """Test detection with FLASH_ENDPOINT_TYPE=lb."""
         with patch.dict(
             "os.environ",
             {
                 "RUNPOD_ENDPOINT_ID": "ep-001",
-                "FLASH_IS_MOTHERSHIP": "true",
+                "FLASH_ENDPOINT_TYPE": "lb",
             },
+            clear=True,
         ):
             assert is_flash_deployment() is True
 
@@ -69,7 +70,7 @@ def test_is_flash_deployment_no_endpoint_id(self) -> None:
         with patch.dict(
             "os.environ",
             {
-                "FLASH_IS_MOTHERSHIP": "true",
+                "FLASH_ENDPOINT_TYPE": "lb",
             },
             clear=True,
         ):
@@ -86,6 +87,29 @@ def test_is_flash_deployment_not_flash(self) -> None:
         ):
             assert is_flash_deployment() is False
 
+    def test_is_flash_deployment_endpoint_type_qb(self) -> None:
+        """Test detection with FLASH_ENDPOINT_TYPE=qb."""
+        with patch.dict(
+            "os.environ",
+            {
+                "RUNPOD_ENDPOINT_ID": "ep-001",
+                "FLASH_ENDPOINT_TYPE": "qb",
+            },
+            clear=True,
+        ):
+            assert is_flash_deployment() is True
+
+    def test_is_flash_deployment_endpoint_type_without_endpoint_id(self) -> None:
+        """Test FLASH_ENDPOINT_TYPE without RUNPOD_ENDPOINT_ID returns False."""
+        with patch.dict(
+            "os.environ",
+            {
+                "FLASH_ENDPOINT_TYPE": "lb",
+            },
+            clear=True,
+        ):
+            assert is_flash_deployment() is False
+
 
 class TestSaveManifest:
     """Test manifest saving."""
@@ -260,7 +284,7 @@ async def test_refresh_no_endpoint_id(self, tmp_path: Path) -> None:
         """Test refresh skipped when RUNPOD_ENDPOINT_ID not set."""
         manifest_path = tmp_path / "manifest.json"
 
-        with patch.dict("os.environ", {"FLASH_IS_MOTHERSHIP": "true"}, clear=True):
+        with patch.dict("os.environ", {"FLASH_ENDPOINT_TYPE": "lb"}, clear=True):
             result = await refresh_manifest_if_stale(manifest_path)
 
         assert result is False
@@ -274,7 +298,7 @@ async def test_refresh_no_api_key(self, tmp_path: Path) -> None:
             "os.environ",
             {
                 "RUNPOD_ENDPOINT_ID": "ep-001",
-                "FLASH_IS_MOTHERSHIP": "true",
+                "FLASH_ENDPOINT_TYPE": "lb",
             },
             clear=True,
         ):
@@ -296,7 +320,7 @@ async def test_refresh_fresh_manifest_no_query(
             "os.environ",
             {
                 "RUNPOD_ENDPOINT_ID": "ep-test-001",
-                "FLASH_IS_MOTHERSHIP": "true",
+                "FLASH_ENDPOINT_TYPE": "lb",
                 "RUNPOD_API_KEY": "test-key",
             },
             clear=True,
@@ -343,7 +367,7 @@ async def test_refresh_stale_manifest_queries_state_manager(
             "os.environ",
             {
                 "RUNPOD_ENDPOINT_ID": "ep-test-001",
-                "FLASH_IS_MOTHERSHIP": "true",
+                "FLASH_ENDPOINT_TYPE": "lb",
                 "RUNPOD_API_KEY": "test-key",
             },
             clear=True,
@@ -382,7 +406,7 @@ async def test_refresh_state_manager_error_continues(
             "os.environ",
             {
                 "RUNPOD_ENDPOINT_ID": "ep-test-001",
-                "FLASH_IS_MOTHERSHIP": "true",
+                "FLASH_ENDPOINT_TYPE": "lb",
                 "RUNPOD_API_KEY": "test-key",
             },
             clear=True,
@@ -418,7 +442,7 @@ async def test_refresh_custom_ttl(self, tmp_path: Path, sample_manifest: dict) -> None:
             "os.environ",
             {
                 "RUNPOD_ENDPOINT_ID": "ep-test-001",
-                "FLASH_IS_MOTHERSHIP": "true",
+                "FLASH_ENDPOINT_TYPE": "lb",
                 "RUNPOD_API_KEY": "test-key",
             },
             clear=True,
diff --git a/tests/unit/test_remote_executor.py b/tests/unit/test_remote_executor.py
index 4b4dde8..746b23e 100644
--- a/tests/unit/test_remote_executor.py
+++ b/tests/unit/test_remote_executor.py
@@ -7,6 +7,63 @@
 from runpod_flash.protos.remote_execution import FunctionRequest
 
 
+class TestServiceRegistryInit:
+    """Test conditional ServiceRegistry initialization based on manifest existence."""
+
+    @patch("remote_executor.Path")
+    @patch("remote_executor.ServiceRegistry")
+    def test_init_creates_registry_when_manifest_exists(self, mock_registry_cls, mock_path_cls):
+        """ServiceRegistry is initialized when flash_manifest.json exists on disk."""
+        mock_path = Mock()
+        mock_path.exists.return_value = True
+        mock_path_cls.return_value = mock_path
+        mock_instance = Mock()
+        mock_registry_cls.return_value = mock_instance
+
+        executor = RemoteExecutor()
+
+        assert executor.service_registry is mock_instance
+        mock_registry_cls.assert_called_once_with(manifest_path=mock_path)
+
+    @patch("remote_executor.Path")
+    @patch("remote_executor.ServiceRegistry")
+    def test_init_skips_registry_when_no_manifest(self, mock_registry_cls, mock_path_cls):
+        """ServiceRegistry is NOT initialized when flash_manifest.json is absent."""
+        mock_path = Mock()
+        mock_path.exists.return_value = False
+        mock_path_cls.return_value = mock_path
+
+        executor = RemoteExecutor()
+
+        assert executor.service_registry is None
+        mock_registry_cls.assert_not_called()
+
+    @patch("remote_executor.Path")
+    @patch("remote_executor.ServiceRegistry")
+    def test_init_handles_registry_exception(self, mock_registry_cls, mock_path_cls):
+        """ServiceRegistry init failure is handled gracefully."""
+        mock_path = Mock()
+        mock_path.exists.return_value = True
+        mock_path_cls.return_value = mock_path
+        mock_registry_cls.side_effect = RuntimeError("corrupt manifest")
+
+        executor = RemoteExecutor()
+
+        assert executor.service_registry is None
+
+    @patch("remote_executor.Path")
+    @patch("remote_executor.ServiceRegistry", None)
+    def test_init_handles_missing_service_registry_import(self, mock_path_cls):
+        """Handles case where runpod-flash is not installed (ServiceRegistry is None)."""
+        mock_path = Mock()
+        mock_path.exists.return_value = True
+        mock_path_cls.return_value = mock_path
+
+        executor = RemoteExecutor()
+
+        assert executor.service_registry is None
+
+
 class TestRemoteExecutor:
     """Unit tests for the RemoteExecutor orchestration class."""
diff --git a/tests/unit/test_unpack_volume.py b/tests/unit/test_unpack_volume.py
index b33943b..f37c7aa 100644
--- a/tests/unit/test_unpack_volume.py
+++ b/tests/unit/test_unpack_volume.py
@@ -217,13 +217,13 @@ def test_unpack_app_from_volume_creates_app_dir(self, tmp_path):
 class TestShouldUnpackFromVolume:
     """Test environment variable detection logic."""
 
-    def test_should_unpack_for_flash_mothership(self):
-        """Test unpacking is enabled for Flash Mothership deployment."""
+    def test_should_unpack_for_flash_lb_endpoint(self):
+        """Test unpacking is enabled for Flash LB endpoint deployment."""
         with patch.dict(
             os.environ,
             {
                 "RUNPOD_ENDPOINT_ID": "test-endpoint-id",
-                "FLASH_IS_MOTHERSHIP": "true",
+                "FLASH_ENDPOINT_TYPE": "lb",
             },
             clear=False,
         ):
@@ -252,7 +252,7 @@ def test_should_not_unpack_for_live_serverless(self):
             clear=False,
         ):
             os.environ.pop("FLASH_DISABLE_UNPACK", None)
-            os.environ.pop("FLASH_IS_MOTHERSHIP", None)
+            os.environ.pop("FLASH_ENDPOINT_TYPE", None)
             os.environ.pop("FLASH_RESOURCE_NAME", None)
 
             assert _should_unpack_from_volume() is False
@@ -270,7 +270,7 @@ def test_should_not_unpack_when_disabled_with_1(self):
             os.environ,
             {
                 "RUNPOD_POD_ID": "test-pod-id",
-                "FLASH_IS_MOTHERSHIP": "true",
+                "FLASH_ENDPOINT_TYPE": "lb",
                 "FLASH_DISABLE_UNPACK": "1",
             },
         ):
@@ -282,7 +282,7 @@ def test_should_not_unpack_when_disabled_with_true(self):
             os.environ,
             {
                 "RUNPOD_POD_ID": "test-pod-id",
-                "FLASH_IS_MOTHERSHIP": "true",
+                "FLASH_ENDPOINT_TYPE": "lb",
                 "FLASH_DISABLE_UNPACK": "true",
             },
         ):
@@ -294,7 +294,7 @@ def test_should_not_unpack_when_disabled_with_yes(self):
             os.environ,
             {
                 "RUNPOD_POD_ID": "test-pod-id",
-                "FLASH_IS_MOTHERSHIP": "true",
+                "FLASH_ENDPOINT_TYPE": "lb",
                 "FLASH_DISABLE_UNPACK": "yes",
             },
         ):
@@ -306,7 +306,7 @@ def test_should_unpack_when_disable_flag_has_wrong_value(self):
             os.environ,
             {
                 "RUNPOD_ENDPOINT_ID": "test-endpoint-id",
-                "FLASH_IS_MOTHERSHIP": "true",
+                "FLASH_ENDPOINT_TYPE": "lb",
                 "FLASH_DISABLE_UNPACK": "false",
             },
         ):
@@ -318,7 +318,7 @@ def test_should_not_unpack_when_disabled_with_uppercase_true(self):
             os.environ,
             {
                 "RUNPOD_POD_ID": "test-pod-id",
-                "FLASH_IS_MOTHERSHIP": "true",
+                "FLASH_ENDPOINT_TYPE": "lb",
                 "FLASH_DISABLE_UNPACK": "True",
             },
         ):
@@ -330,7 +330,7 @@ def test_should_not_unpack_when_disabled_with_uppercase_yes(self):
             os.environ,
             {
                 "RUNPOD_POD_ID": "test-pod-id",
-                "FLASH_IS_MOTHERSHIP": "true",
+                "FLASH_ENDPOINT_TYPE": "lb",
                 "FLASH_DISABLE_UNPACK": "YES",
             },
         ):
@@ -342,7 +342,7 @@ def test_should_not_unpack_when_disabled_with_mixed_case(self):
             os.environ,
             {
                 "RUNPOD_POD_ID": "test-pod-id",
-                "FLASH_IS_MOTHERSHIP": "true",
+                "FLASH_ENDPOINT_TYPE": "lb",
                 "FLASH_DISABLE_UNPACK": "Yes",
            },
         ):
diff --git a/uv.lock b/uv.lock
index 55204b5..8ce206c 100644
--- a/uv.lock
+++ b/uv.lock
@@ -3094,7 +3094,7 @@ wheels = [
 
 [[package]]
 name = "runpod-flash"
-version = "1.1.1"
+version = "1.4.0"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "cloudpickle" },
@@ -3107,9 +3107,9 @@ dependencies = [
     { name = "tomli", marker = "python_full_version < '3.11'" },
     { name = "typer" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/80/41/65c69cabfcad4eb2b9a3400bc137d0a23244299e61cb976fcd241b421b3d/runpod_flash-1.1.1.tar.gz", hash = "sha256:679113955e4b739ca0ca24cbbef0edb698915b365dc69ad0f75d8adc38f7cb73", size = 184094 }
+sdist = { url = "https://files.pythonhosted.org/packages/50/1c/af089b457dfdf2aed0a9cc6fae6c02c4e0183cbe2d66c7bff2307b84084a/runpod_flash-1.4.0.tar.gz", hash = "sha256:b0bbee1a8ce5a1668307f63a0c23de022d5b18f3b2f874a4173321cb8fb1326e", size = 200703 }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/1e/e1/a6a49318108f363730cb2103be46628219be268d2c8239879892fef6d0eb/runpod_flash-1.1.1-py3-none-any.whl", hash = "sha256:9c7592a640ba10c22c06610b479d475b8f0e710d5ca448472aee732c7c5036bb", size = 198321 },
+    { url = "https://files.pythonhosted.org/packages/cb/b8/21b64d49928f6c220658eb39cb31ddc0d9b6667f018ea7f0b99b650c3b1a/runpod_flash-1.4.0-py3-none-any.whl", hash = "sha256:ad9082dab5c177f60c33d152611b6f6d90d3aad83f0ca476c6ce61a817077153", size = 212485 },
 ]
@@ -3701,7 +3701,7 @@ wheels = [
 
 [[package]]
 name = "worker-flash"
-version = "1.0.0"
+version = "1.0.1"
 source = { virtual = "." }
 dependencies = [
     { name = "aiohttp" },
@@ -3741,7 +3741,7 @@ requires-dist = [
     { name = "pydantic", specifier = ">=2.11.4" },
     { name = "requests", specifier = ">=2.25.0" },
     { name = "runpod" },
-    { name = "runpod-flash" },
+    { name = "runpod-flash", specifier = ">=1.4.0" },
     { name = "uvicorn", extras = ["standard"], specifier = ">=0.34.0" },
 ]