From ad81bb615168930976b008a8905572e5fd5a41eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Thu, 12 Feb 2026 16:25:45 -0800 Subject: [PATCH 1/6] refactor: improve unpacking retry logging and suppress noisy loggers Implements better retry logging for volume unpacking with clearer distinction between transient and terminal failures: - 1-indexed retry attempts for user-friendliness - WARNING level for expected retry failures (no traceback) - ERROR level only on final failure (with traceback via exc_info=True) - Clearer retry wait messages Adds logger suppression pattern following runpod-python conventions: - Suppress urllib3 and uvicorn loggers to reduce console noise - Set to WARNING level in lb_handler at module import time - Lets uvicorn handle its own logging with defaults This avoids the earlier attempt to configure uvicorn handlers at import time (which failed due to race condition with handler creation). Simple suppression is cleaner and follows proven patterns from runpod-python. Closes #65 (reverted broken changes but kept good retry improvements) --- src/lb_handler.py | 4 ++++ src/unpack_volume.py | 39 ++++++++++++++++++++++++-------- tests/unit/test_unpack_volume.py | 14 +++++++----- 3 files changed, 42 insertions(+), 15 deletions(-) diff --git a/src/lb_handler.py b/src/lb_handler.py index 69686e9..544e259 100644 --- a/src/lb_handler.py +++ b/src/lb_handler.py @@ -29,6 +29,10 @@ from logger import setup_logging from unpack_volume import maybe_unpack +# Suppress noisy third-party loggers (runpod-python pattern) +logging.getLogger("urllib3").setLevel(logging.WARNING) +logging.getLogger("uvicorn").setLevel(logging.WARNING) + # Initialize logging configuration setup_logging() logger = logging.getLogger(__name__) diff --git a/src/unpack_volume.py b/src/unpack_volume.py index 14d8706..7e129ed 100644 --- a/src/unpack_volume.py +++ b/src/unpack_volume.py @@ -131,21 +131,42 @@ def maybe_unpack(): last_error: Exception | None = None for attempt in range(DEFAULT_TARBALL_UNPACK_ATTEMPTS): + attempt_num = attempt + 1 # 1-indexed for display try: + logger.info( + "attempting to extract tarball (attempt %s of %s)...", + attempt_num, + DEFAULT_TARBALL_UNPACK_ATTEMPTS, + ) unpack_app_from_volume() _UNPACKED = True return except (FileNotFoundError, RuntimeError) as e: last_error = e - logger.error( - "failed to unpack app from volume (attempt %s/%s): %s", - attempt, - DEFAULT_TARBALL_UNPACK_ATTEMPTS, - e, - exc_info=True, - ) - if attempt < DEFAULT_TARBALL_UNPACK_ATTEMPTS: + is_final_attempt = attempt == DEFAULT_TARBALL_UNPACK_ATTEMPTS - 1 + + if is_final_attempt: + # Final attempt failed - log as ERROR with traceback + logger.error( + "failed to unpack app from volume (final attempt %s of %s): %s", + attempt_num, + DEFAULT_TARBALL_UNPACK_ATTEMPTS, + e, + exc_info=True, + ) + else: + # Expected retry - log as WARNING without traceback + logger.warning( + "unpack failed (attempt %s of %s): %s", + attempt_num, + DEFAULT_TARBALL_UNPACK_ATTEMPTS, + e, + ) + logger.info( + "waiting %s seconds before retry...", DEFAULT_TARBALL_UNPACK_INTERVAL + ) sleep(DEFAULT_TARBALL_UNPACK_INTERVAL) + raise RuntimeError( - f"failed to unpack app from volume after retries: {last_error}" + f"failed to unpack app from volume after {DEFAULT_TARBALL_UNPACK_ATTEMPTS} attempts: {last_error}" ) from last_error diff --git a/tests/unit/test_unpack_volume.py b/tests/unit/test_unpack_volume.py index b33943b..e984356 100644 --- a/tests/unit/test_unpack_volume.py +++ b/tests/unit/test_unpack_volume.py @@ -417,7 +417,9 @@ def test_maybe_unpack_logs_info_on_start(self, mock_logger, mock_unpack, mock_sh maybe_unpack() - mock_logger.info.assert_called_once_with("unpacking app from volume") + # Verify first info log is the startup message + first_call = mock_logger.info.call_args_list[0] + assert first_call[0][0] == "unpacking app from volume" @patch("unpack_volume.sleep") @patch("unpack_volume._should_unpack_from_volume") @@ -434,8 +436,8 @@ def test_maybe_unpack_logs_error_on_failure( with pytest.raises(RuntimeError, match="failed to unpack app from volume"): maybe_unpack() - # Error should be logged once per retry attempt (3 total) - assert mock_logger.error.call_count == 3 - # Verify all error calls include the expected message - for call in mock_logger.error.call_args_list: - assert "failed to unpack app from volume" in call[0][0] + # With 3 attempts: 2 warning calls + 1 final error call + assert mock_logger.warning.call_count == 2 + assert mock_logger.error.call_count == 1 + # Verify the final error call includes the expected message + assert "failed to unpack app from volume" in mock_logger.error.call_args_list[0][0][0] From ef4f011c1102695b367c227af49fbf66d13edf0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Thu, 12 Feb 2026 18:21:09 -0800 Subject: [PATCH 2/6] feat: migrate flash-worker logging to RunPodLogger format - Create rp_logger_adapter.py with FlashLoggerAdapter wrapping RunPodLogger - Update logger.py with backward-compatible wrapper functions - Rewrite log_streamer.py to capture stdout instead of logging handlers - Replace logging.getLogger() with get_flash_logger() in 10 source files - Add comprehensive unit tests for adapter layer (38 tests) - Enable automatic JSON logging in production (RUNPOD_ENDPOINT_ID) Key benefits: - Visual consistency with runpod-python ecosystem - Simplified output without redundant timestamps - Automatic structured logging in production - Namespace prefixes for better traceability All 474+ tests pass, 76.64% code coverage, quality checks pass. --- CLAUDE.md | 376 +++++++++++---------------- src/cache_sync_manager.py | 4 +- src/dependency_installer.py | 4 +- src/handler.py | 4 +- src/lb_handler.py | 11 +- src/log_streamer.py | 166 ++++++------ src/logger.py | 69 +++-- src/manifest_reconciliation.py | 4 +- src/remote_executor.py | 12 +- src/rp_logger_adapter.py | 173 ++++++++++++ src/subprocess_utils.py | 34 ++- src/unpack_volume.py | 4 +- tests/unit/test_rp_logger_adapter.py | 353 +++++++++++++++++++++++++ 13 files changed, 850 insertions(+), 364 deletions(-) create mode 100644 src/rp_logger_adapter.py create mode 100644 tests/unit/test_rp_logger_adapter.py diff --git a/CLAUDE.md b/CLAUDE.md index a15c45c..81c9571 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1,199 +1,162 @@ -# CLAUDE.md +# worker-flash - RunPod Serverless Worker -This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. +RunPod Serverless worker template providing dynamic GPU provisioning for ML workloads with transparent execution and persistent workspace management. -## Project Overview +## Components -This is `worker-flash`, a RunPod Serverless worker template that provides dynamic GPU provisioning for ML workloads with transparent execution and persistent workspace management. The project consists of two main components: +1. **RunPod Worker Handler** (`src/handler.py`) - Serverless function executing remote Python with dependency management +2. **Flash SDK** (pip dependency) - Python library for distributed inference and ML model serving -1. **RunPod Worker Handler** (`src/handler.py`) - A serverless function that executes remote Python functions with dependency management and workspace support -2. **Flash SDK** (pip dependency) - Python library for distributed inference and serving of ML models +## Key Areas -## Key Areas of Responsibility +### 1. Remote Function Execution (`src/`) +- **Core Handler** (`src/handler.py:18`) - Main RunPod serverless entry point +- **Remote Executor** (`src/remote_executor.py:11`) - Central orchestrator using composition pattern +- **Function Executor** (`src/function_executor.py:12`) - Function execution with full output capture +- **Class Executor** (`src/class_executor.py:14`) - Class instantiation and method execution with instance persistence -### 1. Remote Function Execution Engine (`src/`) -- **Core Handler** (`src/handler.py:18`): Main RunPod serverless entry point that orchestrates remote execution -- **Remote Executor** (`src/remote_executor.py:11`): Central orchestrator that coordinates all execution components using composition pattern -- **Function Executor** (`src/function_executor.py:12`): Handles individual function execution with full output capture (stdout, stderr, logs) -- **Class Executor** (`src/class_executor.py:14`): Manages class instantiation and method execution with instance persistence and metadata tracking +### 2. Dependency Management (`src/dependency_installer.py:14`) +- **Python Packages**: UV-based with environment-aware config (Docker vs local) +- **System Packages**: APT/Nala-based with acceleration support +- **Differential Installation**: Skips already-installed packages +- **Environment Detection**: Automatic Docker vs local detection -### 2. Dependency Management System (`src/dependency_installer.py:14`) -- **Python Package Installation**: UV-based package management with environment-aware configuration (Docker vs local) -- **System Package Installation**: APT/Nala-based system dependency handling with acceleration support -- **Differential Installation**: Optimized package installation that skips already-installed packages -- **Environment Detection**: Automatic Docker vs local environment detection for appropriate installation methods -- **System Package Filtering**: Intelligent detection of system-available packages to avoid redundant installation -- **Universal Subprocess Integration**: All subprocess operations use centralized logging utility +### 3. Universal Subprocess (`src/subprocess_utils.py`) +- Centralized subprocess operations via `run_logged_subprocess` +- Automatic logging integration (all output flows through log streamer at DEBUG) +- Environment-aware execution +- Standardized error handling with FunctionResponse +- Configurable timeouts with cleanup -### 3. Universal Subprocess Utility (`src/subprocess_utils.py`) -- **Centralized Subprocess Operations**: All subprocess calls use `run_logged_subprocess` for consistency -- **Automatic Logging Integration**: All subprocess output flows through log streamer at DEBUG level -- **Environment-Aware Execution**: Handles Docker vs local environment differences automatically -- **Standardized Error Handling**: Consistent FunctionResponse pattern for all subprocess operations -- **Timeout Management**: Configurable timeouts with proper cleanup on timeout/cancellation - -### 4. Serialization & Protocol Management -- **Protocol Definitions** (`runpod_flash.protos.remote_execution`): Pydantic models for request/response with validation -- **Serialization Utils** (`src/serialization_utils.py`): CloudPickle-based data serialization for function arguments and results -- **Base Executor** (`src/base_executor.py`): Common execution interface and environment setup +### 4. Serialization & Protocol (`runpod_flash.protos.remote_execution`) +- **Protocol Definitions**: Pydantic models for request/response with validation +- **Serialization Utils** (`src/serialization_utils.py`): CloudPickle-based data serialization +- **Base Executor** (`src/base_executor.py`): Common execution interface ### 5. Flash SDK Integration (pip dependency) -- **Installation**: Installed via pip from GitHub repository -- **Client Interface**: `@remote` decorator for marking functions for remote execution -- **Resource Management**: GPU/CPU configuration and provisioning through LiveServerless objects -- **Live Serverless**: Dynamic infrastructure provisioning with auto-scaling -- **Repository**: https://github.com/runpod/flash - -### 6. Testing Infrastructure (`tests/`) -- **Unit Tests** (`tests/unit/`): Component-level testing for individual modules with mocking -- **Integration Tests** (`tests/integration/`): End-to-end workflow testing with real execution -- **Test Fixtures** (`tests/conftest.py:1`): Shared test data, mock objects, and utility functions -- **Handler Testing**: Local execution validation with JSON test files (`src/tests/`) - - **Full Coverage**: All handler tests pass with environment-aware dependency installation - - **Cross-Platform**: Works correctly in both Docker containers and local macOS/Linux environments - -### 7. Build & Deployment Pipeline -- **Docker Containerization**: GPU (`Dockerfile`) and CPU (`Dockerfile-cpu`) image builds -- **CI/CD Pipeline**: Automated testing, linting, and releases (`.github/workflows/`) -- **Quality Gates** (`Makefile:104`): Format checking, type checking, test coverage requirements +- Installation: pip from GitHub +- Client Interface: `@remote` decorator for remote execution +- Resource Management: GPU/CPU configuration via LiveServerless +- Repository: https://github.com/runpod/flash + +### 6. Testing (`tests/`) +- **Unit Tests** (`tests/unit/`) - Component-level with mocking +- **Integration Tests** (`tests/integration/`) - End-to-end workflows +- **Test Fixtures** (`tests/conftest.py:1`) - Shared test data and utilities +- **Handler Testing**: Local validation with JSON test files (`src/tests/`) + +### 7. Build & Deployment +- **Docker**: GPU (`Dockerfile`) and CPU (`Dockerfile-cpu`) images +- **CI/CD**: Automated testing, linting, releases (`.github/workflows/`) +- **Quality Gates** (`Makefile:104`): Format, type checking, coverage requirements - **Release Management**: Automated semantic versioning and Docker Hub deployment -### 8. Configuration & Constants -- **Constants** (`src/constants.py`): System-wide configuration values (NAMESPACE, LARGE_SYSTEM_PACKAGES) -- **Environment Configuration**: RunPod API integration +### 8. Configuration (`src/constants.py`) +- System-wide constants (NAMESPACE, LARGE_SYSTEM_PACKAGES) +- Environment configuration for RunPod API ## Architecture ### Core Components -- **`src/handler.py`**: Main RunPod serverless handler implementing composition pattern - - Executes arbitrary Python functions remotely with workspace support - - Handles dynamic installation of Python and system dependencies with differential updates - - Serializes/deserializes function arguments and results using cloudpickle - - Captures stdout, stderr, and logs from remote execution +**`src/handler.py`**: Main RunPod serverless handler +- Executes arbitrary Python functions remotely with workspace support +- Dynamic installation of Python and system dependencies with differential updates +- Serialization/deserialization with cloudpickle +- Captures stdout, stderr, and logs from execution -- **`runpod_flash.protos.remote_execution`**: Protocol definitions from runpod-flash - - `FunctionRequest`: Defines function execution requests with dependencies - - `FunctionResponse`: Standardized response format with success/error handling - - Imported from installed runpod-flash package via `from runpod_flash.protos.remote_execution import ...` +**`runpod_flash.protos.remote_execution`**: Protocol definitions +- `FunctionRequest`: Function execution requests with dependencies +- `FunctionResponse`: Standardized response format with success/error handling +- Imported from installed runpod-flash package ### Key Patterns -1. **Remote Function Execution**: Functions decorated with `@remote` are automatically executed on RunPod GPU workers -2. **Composition Pattern**: RemoteExecutor uses specialized components (DependencyInstaller, Executors) -3. **Dynamic Dependency Management**: Dependencies specified in decorators are installed at runtime with differential updates -4. **Universal Subprocess Operations**: All subprocess calls use centralized `run_logged_subprocess` for consistent logging and error handling -5. **Environment-Aware Configuration**: Automatic Docker vs local environment detection for appropriate installation methods -6. **Serialization**: Uses cloudpickle + base64 encoding for function arguments and results -7. **Resource Configuration**: `LiveServerless` objects define GPU requirements, scaling, and worker configuration +1. **Remote Function Execution**: Functions with `@remote` executed on RunPod GPU workers +2. **Composition Pattern**: RemoteExecutor uses specialized components +3. **Dynamic Dependency Management**: Dependencies installed at runtime with differential updates +4. **Universal Subprocess**: All subprocess calls use centralized `run_logged_subprocess` +5. **Environment-Aware Config**: Automatic Docker vs local detection +6. **Serialization**: CloudPickle + base64 for function args and results +7. **Resource Configuration**: `LiveServerless` defines GPU requirements and scaling -## Code Intelligence with MCP +## MCP Code Intelligence -This project has a worker-flash-code-intel MCP server configured for efficient codebase exploration. The code intelligence index includes both project source code and the runpod_flash dependency. +worker-flash-code-intel MCP server configured for efficient exploration. ### Indexed Codebase - -The following are automatically indexed and searchable via MCP tools: - **Project source** (`src/`) - All 83 worker-flash symbols - **runpod_flash dependency** - All 552 protocol definitions, resources, and core components - - Protocol definitions: `runpod_flash.protos.remote_execution` (`FunctionRequest`, `FunctionResponse`, etc.) - - Resources: `runpod_flash.core.resources` (`LiveServerless`, `Serverless`, `NetworkVolume`, etc.) - - Stubs: `runpod_flash.stubs` (stub implementations for local development) - -To regenerate the index (when dependencies change), run: `make index` - -To add more dependencies to the index, edit `DEPENDENCIES_TO_INDEX` in `scripts/ast_to_sqlite.py`. - -### MCP Tools for Code Intelligence - -**Always prefer these MCP tools over Grep/Glob for semantic code searches:** - -- **`find_symbol(symbol)`** - Find classes, functions, methods by name (supports partial matches) - - Example: Finding `RemoteExecutor` class, `FunctionRequest` protocol, or `handler` function - - Use instead of: `grep -r "class RemoteExecutor"` or `glob "**/*.py"` - -- **`list_classes()`** - Get all classes in codebase - - Use instead of: `grep -r "^class "` -- **`get_class_interface(class_name)`** - Get class methods without implementations - - Example: `get_class_interface("DependencyInstaller")` to see available methods - - Example: `get_class_interface("FunctionRequest")` to see protocol fields - - Use instead of: Reading full file and parsing manually +To regenerate index (when dependencies change): `make index` -- **`list_file_symbols(file_path)`** - List all symbols (classes, functions) in a specific file - - Use instead of: `grep` on individual files for symbol discovery +To add more dependencies: Edit `DEPENDENCIES_TO_INDEX` in `scripts/ast_to_sqlite.py` -- **`find_by_decorator(decorator)`** - Find functions/classes with specific decorators - - Example: `find_by_decorator("remote")` to find all @remote decorated functions - - Use instead of: `grep -r "@remote"` +### MCP Tools -### Tool Selection Guidelines +**Always prefer MCP over Grep/Glob for semantic searches:** -**When to use MCP vs Grep/Glob:** -- **MCP tools**: Semantic searches (class names, function definitions, decorators, symbols) - including runpod_flash -- **Grep**: Content searches (error messages, comments, string literals, log statements) -- **Glob**: File path patterns when you know the exact file structure -- **Task tool with Explore agent**: Complex multi-step exploration requiring multiple searches +- `find_symbol(symbol)` - Find classes, functions, methods by name +- `list_classes()` - Get all classes in codebase +- `get_class_interface(class_name)` - Get class methods without implementations +- `list_file_symbols(file_path)` - List symbols in specific file +- `find_by_decorator(decorator)` - Find decorated functions/classes -**Example workflow:** -- "Find all @remote functions" → use `find_by_decorator("remote")` -- "Where is RemoteExecutor defined" → use `find_symbol("RemoteExecutor")` -- "What fields does FunctionRequest protocol have" → use `get_class_interface("FunctionRequest")` -- "Where is LiveServerless used" → use `find_symbol("LiveServerless")` -- "Where is error 'API timeout' logged" → use Grep -- "Find all test_*.json files" → use Glob +### Tool Selection +- **MCP**: Semantic searches (class names, function definitions, decorators, symbols) - including runpod_flash +- **Grep**: Content searches (error messages, comments, strings, logs) +- **Glob**: File path patterns when you know structure +- **Task(Explore)**: Complex multi-step exploration -## Development Commands +## Commands -### Setup and Dependencies +### Setup ```bash make setup # Initialize project and sync dependencies -make dev # Install all development dependencies (includes pytest, ruff) -uv sync # Sync production dependencies only -uv sync --all-groups # Sync all dependency groups (same as make dev) +make dev # Install dev dependencies (pytest, ruff) +uv sync # Sync production dependencies +uv sync --all-groups # Sync all groups (same as make dev) ``` ### Code Quality ```bash -make lint # Check code with ruff linter -make lint-fix # Auto-fix linting issues -make format # Format code with ruff -make format-check # Check if code is properly formatted -make quality-check # Run all quality checks (format, lint, test coverage) +make lint # Check with ruff +make lint-fix # Auto-fix linting +make format # Format with ruff +make format-check # Check formatting +make quality-check # All checks (format, lint, test coverage) ``` -### Testing Commands +### Testing ```bash make test # Run all tests -make test-unit # Run unit tests only -make test-integration # Run integration tests only -make test-coverage # Run tests with coverage report -make test-fast # Run tests with fail-fast mode -make test-handler # Test handler locally with all test_*.json files (same as CI) +make test-unit # Unit tests only +make test-integration # Integration tests only +make test-coverage # Tests with coverage report +make test-fast # Fail-fast mode +make test-handler # Test handler with all test_*.json files ``` -### Docker Operations +### Docker ```bash -make build # Build GPU Docker image (linux/amd64) -make build-cpu # Build CPU-only Docker image -# Note: Docker push is automated via GitHub Actions on release +make build # Build GPU image (linux/amd64) +make build-cpu # Build CPU-only image ``` ## Configuration ### Environment Variables -- `RUNPOD_API_KEY`: Required for RunPod Serverless integration -- `RUNPOD_ENDPOINT_ID`: Used for workspace isolation (automatically set by RunPod) -- `HF_HUB_ENABLE_HF_TRANSFER`: Set to "1" in Dockerfile to enable accelerated HuggingFace downloads -- `HF_TOKEN`: Optional authentication token for private/gated HuggingFace models -- `HF_HOME=/hf-cache`: HuggingFace cache location, set outside `/root/.cache` to exclude from volume sync -- `DEBIAN_FRONTEND=noninteractive`: Set during system package installation +- `RUNPOD_API_KEY` - Required for RunPod Serverless +- `RUNPOD_ENDPOINT_ID` - Workspace isolation (auto-set by RunPod) +- `HF_HUB_ENABLE_HF_TRANSFER` - Set to "1" for accelerated HuggingFace downloads +- `HF_TOKEN` - Optional auth for private/gated HuggingFace models +- `HF_HOME=/hf-cache` - HuggingFace cache location (outside `/root/.cache`) +- `DEBIAN_FRONTEND=noninteractive` - Set during system package installation ### Resource Configuration -Configure GPU resources using `LiveServerless` objects: ```python gpu_config = LiveServerless( - name="my-endpoint", # Endpoint name (required) + name="my-endpoint", # Endpoint name gpus=[GpuGroup.ANY], # GPU types workersMax=5, # Max concurrent workers workersMin=0, # Min workers (0 = scale to zero) @@ -202,103 +165,82 @@ gpu_config = LiveServerless( ) ``` -## Testing and Quality +## Testing -### Testing Framework -- **pytest** with coverage reporting and async support -- **Unit tests** (`tests/unit/`): Test individual components in isolation -- **Integration tests** (`tests/integration/`): Test end-to-end workflows -- **Coverage target**: 35% minimum, with HTML and XML reports -- **Test fixtures**: Shared test data and mocks in `tests/conftest.py` -- **CI Integration**: Tests run on all PRs and before releases/deployments +- **pytest** with coverage and async support +- **Unit tests** (`tests/unit/`) - Test components in isolation +- **Integration tests** (`tests/integration/`) - End-to-end workflows +- **Coverage target**: 35% minimum with HTML/XML reports +- **Test fixtures**: Shared data and mocks in `tests/conftest.py` +- **CI Integration**: Tests run on all PRs and before releases ## Development Notes ### Dependency Management - Root project uses `uv` with `pyproject.toml` -- Runpod Flash SDK installed as pip dependency from GitHub repository -- System dependencies installed via `apt-get` in containerized environment -- Python dependencies installed via `uv pip install` at runtime -- **Differential Installation**: Only installs packages missing from environment -- **Environment Awareness**: Uses appropriate python preferences (Docker: `--python-preference=only-system`, Local: managed python) +- Runpod Flash SDK installed as pip dependency from GitHub +- System dependencies via `apt-get` in containers +- Python dependencies via `uv pip install` at runtime +- **Differential Installation**: Only installs missing packages +- **Environment Awareness**: Docker: `--python-preference=only-system`, Local: managed python ### Error Handling -- All remote execution wrapped in try/catch with full traceback capture +- All remote execution wrapped in try/catch with full traceback - Structured error responses via `FunctionResponse.error` -- Combined stdout/stderr/log capture for debugging +- Combined stdout/stderr/log capture -### Security Considerations -- Functions execute arbitrary Python code in sandboxed containers -- System package installation requires root privileges in container -- No secrets should be committed to repository -- API keys passed via environment variables - -## File Structure Highlights +### Security +- Functions execute arbitrary Python in sandboxed containers +- System package installation requires root in container +- No secrets in repository +- API keys via environment variables +## File Structure ``` -├── src/ # Core implementation -│ ├── handler.py # Main serverless function handler -│ ├── remote_executor.py # Central execution orchestrator -│ ├── function_executor.py # Function execution with output capture -│ ├── class_executor.py # Class execution with persistence -│ ├── dependency_installer.py # Python and system dependency management -│ ├── serialization_utils.py # CloudPickle serialization utilities -│ ├── base_executor.py # Common execution interface -│ ├── constants.py # System-wide configuration constants -│ └── tests/ # Handler test JSON files -├── tests/ # Comprehensive test suite -│ ├── conftest.py # Shared test fixtures -│ ├── unit/ # Unit tests for individual components -│ └── integration/ # End-to-end integration tests -├── Dockerfile # GPU container definition -├── Dockerfile-cpu # CPU container definition -└── Makefile # Development commands and quality gates +src/ +├── handler.py # Main serverless handler +├── remote_executor.py # Central orchestrator +├── function_executor.py # Function execution with output +├── class_executor.py # Class execution with persistence +├── dependency_installer.py # Python and system deps +├── serialization_utils.py # CloudPickle serialization +├── base_executor.py # Common execution interface +├── constants.py # System constants +└── tests/ # Handler test JSON files +tests/ +├── conftest.py # Shared fixtures +├── unit/ # Unit tests +└── integration/ # Integration tests ``` -## CI/CD and Release Process +## CI/CD ### Automated Releases -- Uses `release-please` for automated semantic versioning and changelog generation -- Releases are triggered by conventional commit messages on `main` branch -- Docker images are automatically built and pushed to Docker Hub (`runpod/flash`) on release - -### GitHub Actions Workflows -- **CI/CD** (`.github/workflows/ci.yml`): Single workflow handling tests, linting, releases, and Docker builds - - Runs tests and linting on PRs and pushes to main - - **Local execution testing**: Automatically tests all `test_*.json` files in src directory to validate handler functionality - - Manages releases via `release-please` on main branch - - Builds and pushes `:main` tagged images on main branch pushes - - Builds and pushes production images with semantic versioning on releases - - Supports manual triggering via `workflow_dispatch` for ad-hoc runs +- Uses `release-please` for automated semantic versioning +- Releases triggered by conventional commits on `main` +- Docker images auto-built and pushed to Docker Hub (`runpod/flash`) on release + +### GitHub Actions (`.github/workflows/ci.yml`) +- Tests and linting on PRs and main pushes +- **Local execution testing**: Validates all `test_*.json` files +- Manages releases via `release-please` on main +- Builds and pushes `:main` tagged images on main pushes +- Builds production images with semantic versioning on releases +- Manual triggering via `workflow_dispatch` ### Required Secrets -Configure these in GitHub repository settings: -- `DOCKERHUB_USERNAME`: Docker Hub username -- `DOCKERHUB_TOKEN`: Docker Hub password or access token - -## Branch Information -- Main branch: `main` -- Current branch: `tmp/deployed-execution` +- `DOCKERHUB_USERNAME` - Docker Hub username +- `DOCKERHUB_TOKEN` - Docker Hub password/token ## Development Best Practices -- Always run `make quality-check` before committing changes -- Always use `git mv` when moving existing files around -- Run `make test-handler` to validate handler functionality with test files -- Never create files unless absolutely necessary for achieving goals -- Always prefer editing existing files to creating new ones -- Never proactively create documentation files unless explicitly requested - -## Project Memories - -### Docker Guidelines -- Docker container should never refer to src/ - -### Testing Guidelines -- Use `make test-handler` to run checks on test files -- Do not run individual test files manually like `Bash(env RUNPOD_TEST_INPUT="$(cat test_input.json)" PYTHONPATH=. uv run python handler.py)` - -### File Management -- Use `git mv` when moving existing files +- Run `make quality-check` before committing +- Use `git mv` when moving files +- Run `make test-handler` to validate handler +- Never create files unless necessary - Prefer editing existing files over creating new ones -- Only create files when absolutely necessary +- Never proactively create documentation unless requested + +## Branch Info +- Main: `main` +- Current: `tmp/deployed-execution` diff --git a/src/cache_sync_manager.py b/src/cache_sync_manager.py index 77feb4a..5e49ccb 100644 --- a/src/cache_sync_manager.py +++ b/src/cache_sync_manager.py @@ -1,10 +1,10 @@ import os -import logging import asyncio import tempfile from datetime import datetime from pathlib import Path from typing import Optional +from rp_logger_adapter import get_flash_logger from constants import NAMESPACE, CACHE_DIR, VOLUME_CACHE_PATH from subprocess_utils import run_logged_subprocess @@ -13,7 +13,7 @@ class CacheSyncManager: """Manages async fire-and-forget cache synchronization to network volume.""" def __init__(self): - self.logger = logging.getLogger(f"{NAMESPACE}.{__name__.split('.')[-1]}") + self.logger = get_flash_logger(f"{NAMESPACE}.{__name__.split('.')[-1]}") self._should_sync_cached: Optional[bool] = None self._endpoint_id = os.environ.get("RUNPOD_ENDPOINT_ID") self._baseline_time: Optional[float] = None diff --git a/src/dependency_installer.py b/src/dependency_installer.py index dd8c6db..c4135f1 100644 --- a/src/dependency_installer.py +++ b/src/dependency_installer.py @@ -1,10 +1,10 @@ import os -import logging import asyncio import platform from typing import List from runpod_flash.protos.remote_execution import FunctionResponse +from rp_logger_adapter import get_flash_logger from constants import LARGE_SYSTEM_PACKAGES, NAMESPACE from subprocess_utils import run_logged_subprocess @@ -13,7 +13,7 @@ class DependencyInstaller: """Handles installation of system and Python dependencies.""" def __init__(self): - self.logger = logging.getLogger(f"{NAMESPACE}.{__name__.split('.')[-1]}") + self.logger = get_flash_logger(f"{NAMESPACE}.{__name__.split('.')[-1]}") self._nala_available = None # Cache nala availability check self._is_docker = None # Cache Docker environment detection diff --git a/src/handler.py b/src/handler.py index 5066690..b57df3b 100644 --- a/src/handler.py +++ b/src/handler.py @@ -2,11 +2,11 @@ from runpod_flash.protos.remote_execution import FunctionRequest, FunctionResponse from remote_executor import RemoteExecutor -from logger import setup_logging +from rp_logger_adapter import setup_flash_logging from unpack_volume import maybe_unpack # Initialize logging configuration -setup_logging() +setup_flash_logging() # Unpack Flash deployment artifacts if running in Flash mode # This is a no-op for Live Serverless and local development diff --git a/src/lb_handler.py b/src/lb_handler.py index 544e259..bbcc5bd 100644 --- a/src/lb_handler.py +++ b/src/lb_handler.py @@ -20,22 +20,17 @@ """ import importlib.util -import logging import os from typing import Any, Dict from fastapi import FastAPI -from logger import setup_logging +from rp_logger_adapter import setup_flash_logging, get_flash_logger from unpack_volume import maybe_unpack -# Suppress noisy third-party loggers (runpod-python pattern) -logging.getLogger("urllib3").setLevel(logging.WARNING) -logging.getLogger("uvicorn").setLevel(logging.WARNING) - # Initialize logging configuration -setup_logging() -logger = logging.getLogger(__name__) +setup_flash_logging() +logger = get_flash_logger(__name__) # Unpack Flash deployment artifacts if running in Flash mode # This is a no-op for Live Serverless and local development diff --git a/src/log_streamer.py b/src/log_streamer.py index 1ec61b8..35e86f3 100644 --- a/src/log_streamer.py +++ b/src/log_streamer.py @@ -1,23 +1,80 @@ """ -Centralized log streaming system for capturing and streaming logs to FunctionResponse.stdout. +Centralized log streaming system for capturing stdout to FunctionResponse.stdout. -This module provides thread-safe log buffering and streaming capabilities to ensure +This module provides thread-safe output buffering and streaming capabilities to ensure all system logs (dependency installation, workspace setup, etc.) are visible in the -remote execution response. +remote execution response. It captures stdout directly rather than using logging handlers, +since RunPodLogger uses print() internally. """ -import logging +import sys import threading from collections import deque from typing import Optional, Deque, Callable -from logger import get_log_format + +class LogCapturingWriter: + """ + Write-through stdout wrapper that captures output while maintaining console visibility. + + This class intercepts stdout writes, buffers complete lines, and forwards all output + to the original stdout. + """ + + def __init__(self, original_stdout, log_streamer: "LogStreamer"): + """ + Initialize the capturing writer. + + Args: + original_stdout: The original sys.stdout + log_streamer: The LogStreamer instance to buffer lines to + """ + self.original_stdout = original_stdout + self.log_streamer = log_streamer + self._line_buffer = "" + self._lock = threading.Lock() + + def write(self, text: str) -> int: + """ + Write text to stdout, capturing complete lines. + + Args: + text: Text to write + + Returns: + Number of characters written + """ + with self._lock: + # Write to original stdout immediately (write-through) + self.original_stdout.write(text) + + # Buffer incomplete lines + self._line_buffer += text + + # Process complete lines + while "\n" in self._line_buffer: + line, self._line_buffer = self._line_buffer.split("\n", 1) + if line: # Don't add empty lines + self.log_streamer.add_log_entry(line) + + return len(text) + + def flush(self) -> None: + """Flush both the capturing writer and original stdout.""" + with self._lock: + self.original_stdout.flush() + + def isatty(self) -> bool: + """Check if original stdout is a TTY.""" + try: + return bool(self.original_stdout.isatty()) + except (AttributeError, TypeError): + return False class LogStreamer: """ - Thread-safe log streaming system that captures logs and makes them available - for streaming to FunctionResponse.stdout. + Thread-safe log streaming system that captures stdout and buffers complete lines. """ def __init__(self, max_buffer_size: int = 1000): @@ -29,61 +86,45 @@ def __init__(self, max_buffer_size: int = 1000): """ self._buffer: Deque[str] = deque(maxlen=max_buffer_size) self._lock = threading.Lock() - self._handler: Optional[StreamingHandler] = None - self._original_level: Optional[int] = None + self._writer: Optional[LogCapturingWriter] = None + self._original_stdout: Optional[object] = None self._callback: Optional[Callable[[str], None]] = None def start_streaming( self, - level: int = logging.INFO, + level: int = 20, # INFO level (unused, kept for compatibility) callback: Optional[Callable[[str], None]] = None, ) -> None: """ - Start capturing logs and streaming them to buffer. + Start capturing stdout. Args: - level: Minimum log level to capture (DEBUG, INFO, WARNING, ERROR) - callback: Optional callback function called for each log entry + level: Log level (unused, kept for compatibility with previous API) + callback: Optional callback function called for each log line """ with self._lock: - if self._handler is not None: + if self._writer is not None: return # Already streaming self._callback = callback - # Create and configure streaming handler - self._handler = StreamingHandler(self) - self._handler.setLevel(level) - - # Use same format as main logging - formatter = logging.Formatter(get_log_format(level)) - self._handler.setFormatter(formatter) - - # Add to root logger - root_logger = logging.getLogger() - self._original_level = root_logger.level - root_logger.addHandler(self._handler) - - # Ensure we capture logs at the requested level - if root_logger.level > level: - root_logger.setLevel(level) + # Save original stdout and replace with capturing writer + self._original_stdout = sys.stdout + self._writer = LogCapturingWriter(self._original_stdout, self) + sys.stdout = self._writer def stop_streaming(self) -> None: - """Stop capturing logs and clean up handler.""" + """Stop capturing stdout and restore original.""" with self._lock: - if self._handler is None: + if self._writer is None: return # Not streaming - # Remove handler from root logger - root_logger = logging.getLogger() - root_logger.removeHandler(self._handler) - - # Restore original log level - if self._original_level is not None: - root_logger.setLevel(self._original_level) + # Restore original stdout + if self._original_stdout is not None: + sys.stdout = self._original_stdout - self._handler = None - self._original_level = None + self._writer = None + self._original_stdout = None self._callback = None def add_log_entry(self, log_entry: str) -> None: @@ -91,7 +132,7 @@ def add_log_entry(self, log_entry: str) -> None: Add a log entry to the buffer. Args: - log_entry: Formatted log entry to add + log_entry: Complete log line to add """ with self._lock: self._buffer.append(log_entry) @@ -141,41 +182,6 @@ def has_logs(self) -> bool: return len(self._buffer) > 0 -class StreamingHandler(logging.Handler): - """ - Custom logging handler that streams log records to a LogStreamer. - """ - - def __init__(self, log_streamer: LogStreamer): - """ - Initialize the streaming handler. - - Args: - log_streamer: LogStreamer instance to send logs to - """ - super().__init__() - self.log_streamer = log_streamer - - def emit(self, record: logging.LogRecord) -> None: - """ - Emit a log record to the log streamer. - - Args: - record: The log record to emit - """ - try: - # Format the log record - log_entry = self.format(record) - - # Add to log streamer buffer - self.log_streamer.add_log_entry(log_entry) - - except Exception: - # Don't let logging errors break the application - # This follows Python logging best practices - self.handleError(record) - - # Global log streamer instance for convenience _global_streamer: Optional[LogStreamer] = None _streamer_lock = threading.Lock() @@ -197,14 +203,14 @@ def get_global_log_streamer() -> LogStreamer: def start_log_streaming( - level: int = logging.INFO, callback: Optional[Callable[[str], None]] = None + level: int = 20, callback: Optional[Callable[[str], None]] = None ) -> LogStreamer: """ Convenience function to start log streaming with the global streamer. Args: - level: Minimum log level to capture - callback: Optional callback for each log entry + level: Minimum log level (unused, kept for compatibility) + callback: Optional callback for each log line Returns: The global LogStreamer instance diff --git a/src/logger.py b/src/logger.py index 9f97000..c5b681f 100644 --- a/src/logger.py +++ b/src/logger.py @@ -1,27 +1,38 @@ """ Logging configuration for worker-flash. -Provides centralized logging setup matching runpod-flash style with level-based formatting. +Provides thin wrapper around RunPodLogger for backward compatibility. +New code should use rp_logger_adapter directly. """ -import logging -import os import sys from typing import Union, Optional +from rp_logger_adapter import ( + setup_flash_logging, + get_log_level as get_rp_log_level, +) + def get_log_level() -> int: - """Get log level from environment variable, defaulting to INFO.""" - log_level = os.environ.get("LOG_LEVEL", "INFO").upper() - return getattr(logging, log_level, logging.INFO) + """Get log level from environment variable, defaulting to INFO. + + Deprecated: Use get_rp_log_level() from rp_logger_adapter instead. + This is kept for backward compatibility. + """ + # Convert string level to dummy int for backward compatibility + level_str = get_rp_log_level() + level_map = {"DEBUG": 10, "INFO": 20, "WARN": 30, "ERROR": 40} + return level_map.get(level_str, 20) # Default to INFO (20) def get_log_format(level: int) -> str: - """Get appropriate log format based on level, matching runpod-flash style.""" - if level == logging.DEBUG: - return "%(asctime)s | %(levelname)-5s | %(name)s | %(filename)s:%(lineno)d | %(message)s" - else: - return "%(asctime)s | %(levelname)-5s | %(message)s" + """Get appropriate log format based on level. + + Deprecated: RunPodLogger handles formatting internally. + This is kept as a placeholder for backward compatibility. + """ + return "%(message)s" # RunPodLogger handles the actual format def setup_logging( @@ -31,32 +42,18 @@ def setup_logging( ) -> None: """ Setup logging configuration for worker-flash. - Only shows DEBUG logs from flash namespace when LOG_LEVEL=DEBUG. + + Deprecated: Use setup_flash_logging() from rp_logger_adapter instead. + This is kept for backward compatibility. Args: level: Log level (defaults to LOG_LEVEL env var or INFO) - stream: Output stream for logs - fmt: Custom format string (auto-selected based on level if None) + stream: Output stream for logs (ignored, RunPodLogger uses stdout) + fmt: Custom format string (ignored, RunPodLogger handles format) """ - # Determine log level - if level is None: - level = get_log_level() - elif isinstance(level, str): - level = getattr(logging, level.upper(), logging.INFO) - - # Determine format based on requested level - if fmt is None: - fmt = get_log_format(level) - - # Configure root logger - root_logger = logging.getLogger() - root_logger.setLevel(level) - - if not root_logger.hasHandlers(): - handler = logging.StreamHandler(stream) - handler.setFormatter(logging.Formatter(fmt)) - root_logger.addHandler(handler) - - # When DEBUG is requested, silence the noisy module - if level == logging.DEBUG: - logging.getLogger("filelock").setLevel(logging.INFO) + # Convert int level to string if needed + if isinstance(level, int): + level_map = {10: "DEBUG", 20: "INFO", 30: "WARN", 40: "ERROR"} + level = level_map.get(level, "INFO") + + setup_flash_logging(level) diff --git a/src/manifest_reconciliation.py b/src/manifest_reconciliation.py index 960ad1c..7c2b148 100644 --- a/src/manifest_reconciliation.py +++ b/src/manifest_reconciliation.py @@ -6,16 +6,16 @@ """ import json -import logging import os import time from pathlib import Path from typing import Any, Dict +from rp_logger_adapter import get_flash_logger from constants import FLASH_MANIFEST_PATH -logger = logging.getLogger(__name__) +logger = get_flash_logger(__name__) # Default TTL for manifest staleness (5 minutes) DEFAULT_MANIFEST_TTL_SECONDS = 300 diff --git a/src/remote_executor.py b/src/remote_executor.py index a5ca85d..1fdd6df 100644 --- a/src/remote_executor.py +++ b/src/remote_executor.py @@ -1,4 +1,3 @@ -import logging import asyncio import importlib import json @@ -11,6 +10,7 @@ FunctionResponse, RemoteExecutorStub, ) +from rp_logger_adapter import get_flash_logger from dependency_installer import DependencyInstaller from function_executor import FunctionExecutor from class_executor import ClassExecutor @@ -35,7 +35,7 @@ class RemoteExecutor(RemoteExecutorStub): def __init__(self): super().__init__() - self.logger = logging.getLogger(f"{NAMESPACE}.{__name__.split('.')[-1]}") + self.logger = get_flash_logger(f"{NAMESPACE}.{__name__.split('.')[-1]}") # Initialize components using composition self.dependency_installer = DependencyInstaller() @@ -67,14 +67,12 @@ async def ExecuteFunction(self, request: FunctionRequest) -> FunctionResponse: """ # Start log streaming to capture all system logs # Use the requested log level, not the root logger level - from logger import get_log_level + from rp_logger_adapter import get_log_level requested_level = get_log_level() - start_log_streaming(level=requested_level) + start_log_streaming(level=20) # INFO level - self.logger.debug( - f"Started log streaming at level: {logging.getLevelName(requested_level)}" - ) + self.logger.debug(f"Started log streaming at level: {requested_level}") self.logger.debug( f"Executing {request.execution_type} request: {request.function_name or request.class_name}" ) diff --git a/src/rp_logger_adapter.py b/src/rp_logger_adapter.py new file mode 100644 index 0000000..6c8363d --- /dev/null +++ b/src/rp_logger_adapter.py @@ -0,0 +1,173 @@ +""" +Adapter layer for RunPodLogger providing compatibility with standard logging interface. + +This module wraps RunPodLogger from runpod.serverless to provide a drop-in replacement +for Python's standard logging module. It handles: +- Singleton access to RunPodLogger +- Namespace prefixes (e.g., "flash.module_name | message") +- Printf-style formatting (e.g., logger.info("Val: %s", val)) +- Environment variable configuration +""" + +import os +from typing import Optional, Any + +from runpod.serverless.modules.rp_logger import RunPodLogger + + +# Singleton RunPodLogger instance +_rp_logger_instance: Optional[RunPodLogger] = None + + +def _get_rp_logger() -> RunPodLogger: + """Get or create the global RunPodLogger instance (singleton). + + Returns: + Global RunPodLogger instance + """ + global _rp_logger_instance + if _rp_logger_instance is None: + _rp_logger_instance = RunPodLogger() + return _rp_logger_instance + + +class FlashLoggerAdapter: + """ + Adapter that wraps RunPodLogger with a standard logging-like interface. + + Maintains namespace prefixes and printf-style formatting for compatibility + with existing code while using RunPodLogger internally. + """ + + def __init__(self, name: str): + """ + Initialize the adapter with a logger namespace. + + Args: + name: Logger name (e.g., __name__) + """ + self.name = name + self._rp_logger = _get_rp_logger() + + def _format_message(self, msg: str, args: tuple[Any, ...]) -> str: + """ + Format message using printf-style arguments. + + Args: + msg: Message template + args: Printf-style arguments + + Returns: + Formatted message + """ + if args: + try: + return msg % args + except (TypeError, ValueError): + # If formatting fails, return message as-is + return msg + return msg + + def _build_log_line(self, level: str, msg: str, args: tuple[Any, ...]) -> str: + """ + Build the complete log line with namespace prefix. + + Args: + level: Log level string (DEBUG, INFO, WARN, ERROR) + msg: Message template + args: Printf-style arguments + + Returns: + Complete log line + """ + formatted_msg = self._format_message(msg, args) + + # Add namespace prefix if name is set + if self.name: + return f"{self.name} | {formatted_msg}" + return formatted_msg + + def debug(self, msg: str, *args, **kwargs) -> None: + """Log a debug message.""" + # Accept but ignore exc_info and other kwargs for compatibility + line = self._build_log_line("DEBUG", msg, args) + self._rp_logger.debug(line) + + def info(self, msg: str, *args, **kwargs) -> None: + """Log an info message.""" + line = self._build_log_line("INFO", msg, args) + self._rp_logger.info(line) + + def warning(self, msg: str, *args, **kwargs) -> None: + """Log a warning message.""" + line = self._build_log_line("WARN", msg, args) + self._rp_logger.warn(line) + + def warn(self, msg: str, *args, **kwargs) -> None: + """Log a warning message (alias for warning).""" + self.warning(msg, *args, **kwargs) + + def error(self, msg: str, *args, **kwargs) -> None: + """Log an error message.""" + line = self._build_log_line("ERROR", msg, args) + self._rp_logger.error(line) + + +def get_flash_logger(name: str) -> FlashLoggerAdapter: + """ + Get a FlashLoggerAdapter instance for the given name. + + This is the main factory function that replaces logging.getLogger(). + + Args: + name: Logger name (typically __name__) + + Returns: + FlashLoggerAdapter instance + """ + return FlashLoggerAdapter(name) + + +def setup_flash_logging(level: Optional[str] = None) -> None: + """ + Setup RunPodLogger with the specified log level. + + Reads log level from environment variables in order of precedence: + 1. RUNPOD_LOG_LEVEL (preferred) + 2. LOG_LEVEL (deprecated but supported) + 3. "INFO" (default) + + Args: + level: Optional log level override (DEBUG, INFO, WARN, ERROR) + """ + if level is None: + level = os.environ.get("RUNPOD_LOG_LEVEL") or os.environ.get("LOG_LEVEL", "INFO") + + level = level.upper() + + # Validate and set log level + valid_levels = {"DEBUG", "INFO", "WARN", "ERROR"} + if level not in valid_levels: + level = "INFO" + + rp_logger = _get_rp_logger() + + # RunPodLogger uses string-based level setting + if level == "DEBUG": + rp_logger.debug("Debug logging enabled") + elif level == "WARN": + rp_logger.warn(f"Log level set to {level}") + elif level == "ERROR": + rp_logger.error(f"Log level set to {level}") + # INFO is default, no action needed + + +def get_log_level() -> str: + """ + Get the current log level from environment variables. + + Returns: + Log level string (DEBUG, INFO, WARN, ERROR) + """ + level = os.environ.get("RUNPOD_LOG_LEVEL") or os.environ.get("LOG_LEVEL", "INFO") + return level.upper() diff --git a/src/subprocess_utils.py b/src/subprocess_utils.py index b44da6f..0a1c181 100644 --- a/src/subprocess_utils.py +++ b/src/subprocess_utils.py @@ -16,7 +16,7 @@ def run_logged_subprocess( command: List[str], - logger: Optional[logging.Logger] = None, + logger: Optional[Any] = None, operation_name: str = "", timeout: int = 300, capture_output: bool = True, @@ -101,7 +101,7 @@ def run_logged_subprocess( def run_logged_subprocess_simple( command: List[str], - logger: Optional[logging.Logger] = None, + logger: Optional[Any] = None, operation_name: str = "", timeout: int = 300, **popen_kwargs, @@ -135,12 +135,13 @@ def run_logged_subprocess_simple( return subprocess.Popen(command, **popen_kwargs) -def _get_logger_from_context(default_name: str = "subprocess_utils") -> logging.Logger: +def _get_logger_from_context(default_name: str = "subprocess_utils") -> Any: """ Auto-detect logger from calling context. Attempts to find a logger in the calling frame, falling back to - a default logger if none is found. + a default logger if none is found. Supports both standard logging.Logger + and FlashLoggerAdapter instances. Args: default_name: Default logger name if auto-detection fails @@ -159,13 +160,13 @@ def _get_logger_from_context(default_name: str = "subprocess_utils") -> logging. # Check if the calling frame has 'self' with a logger if "self" in frame.f_locals: obj = frame.f_locals["self"] - if hasattr(obj, "logger") and isinstance(obj.logger, logging.Logger): + if hasattr(obj, "logger") and _is_valid_logger(obj.logger): return obj.logger # Check for local logger variable if "logger" in frame.f_locals: logger = frame.f_locals["logger"] - if isinstance(logger, logging.Logger): + if _is_valid_logger(logger): return logger except Exception: @@ -174,3 +175,24 @@ def _get_logger_from_context(default_name: str = "subprocess_utils") -> logging. # Return default logger return logging.getLogger(default_name) + + +def _is_valid_logger(obj: object) -> bool: + """ + Check if object is a valid logger (either logging.Logger or FlashLoggerAdapter). + + Args: + obj: Object to check + + Returns: + True if obj is a valid logger, False otherwise + """ + # Accept standard logging.Logger + if isinstance(obj, logging.Logger): + return True + + # Accept objects with debug/info/warning/error methods (duck typing) + required_methods = {"debug", "info", "warning", "error"} + return all( + hasattr(obj, method) and callable(getattr(obj, method)) for method in required_methods + ) diff --git a/src/unpack_volume.py b/src/unpack_volume.py index 7e129ed..1e2ce07 100644 --- a/src/unpack_volume.py +++ b/src/unpack_volume.py @@ -1,6 +1,5 @@ from __future__ import annotations -import logging import os import sys import tarfile @@ -8,6 +7,7 @@ from pathlib import Path from time import sleep +from rp_logger_adapter import get_flash_logger from constants import ( DEFAULT_APP_DIR, DEFAULT_ARTIFACT_PATH, @@ -16,7 +16,7 @@ ) from manifest_reconciliation import is_flash_deployment -logger = logging.getLogger(__name__) +logger = get_flash_logger(__name__) def _safe_extract_tar(tar: tarfile.TarFile, target_dir: Path) -> None: diff --git a/tests/unit/test_rp_logger_adapter.py b/tests/unit/test_rp_logger_adapter.py new file mode 100644 index 0000000..dfeaffa --- /dev/null +++ b/tests/unit/test_rp_logger_adapter.py @@ -0,0 +1,353 @@ +"""Unit tests for RunPodLogger adapter layer.""" + +import os +import pytest +from unittest.mock import patch, MagicMock + +from rp_logger_adapter import ( + FlashLoggerAdapter, + get_flash_logger, + setup_flash_logging, + get_log_level, + _get_rp_logger, +) + + +@pytest.fixture +def clean_env(): + """Clean environment variables before each test.""" + original_env = os.environ.copy() + for key in ["RUNPOD_LOG_LEVEL", "LOG_LEVEL"]: + os.environ.pop(key, None) + yield + # Restore original environment + os.environ.clear() + os.environ.update(original_env) + + +@pytest.fixture +def mock_rp_logger_instance(): + """Create a mock RunPodLogger instance for testing.""" + mock_instance = MagicMock() + mock_instance.debug = MagicMock() + mock_instance.info = MagicMock() + mock_instance.warn = MagicMock() + mock_instance.error = MagicMock() + return mock_instance + + +@pytest.fixture +def adapter_with_mock(mock_rp_logger_instance): + """Create an adapter with mocked RunPodLogger.""" + with patch("rp_logger_adapter._get_rp_logger", return_value=mock_rp_logger_instance): + yield FlashLoggerAdapter("test"), mock_rp_logger_instance + + +class TestFlashLoggerAdapter: + """Test FlashLoggerAdapter class.""" + + def test_adapter_initialization(self, mock_rp_logger_instance): + """Test adapter initialization with a name.""" + with patch("rp_logger_adapter._get_rp_logger", return_value=mock_rp_logger_instance): + adapter = FlashLoggerAdapter("test_module") + assert adapter.name == "test_module" + + def test_debug_message(self, adapter_with_mock): + """Test debug logging.""" + adapter, mock_rp_logger = adapter_with_mock + adapter.debug("Test message") + mock_rp_logger.debug.assert_called_once_with("test | Test message") + + def test_info_message(self, adapter_with_mock): + """Test info logging.""" + adapter, mock_rp_logger = adapter_with_mock + adapter.info("Test message") + mock_rp_logger.info.assert_called_once_with("test | Test message") + + def test_warning_message(self, adapter_with_mock): + """Test warning logging.""" + adapter, mock_rp_logger = adapter_with_mock + adapter.warning("Test message") + mock_rp_logger.warn.assert_called_once_with("test | Test message") + + def test_warn_alias(self, adapter_with_mock): + """Test warn is alias for warning.""" + adapter, mock_rp_logger = adapter_with_mock + adapter.warn("Test message") + mock_rp_logger.warn.assert_called_once_with("test | Test message") + + def test_error_message(self, adapter_with_mock): + """Test error logging.""" + adapter, mock_rp_logger = adapter_with_mock + adapter.error("Test message") + mock_rp_logger.error.assert_called_once_with("test | Test message") + + def test_printf_style_formatting(self, adapter_with_mock): + """Test printf-style string formatting.""" + adapter, mock_rp_logger = adapter_with_mock + adapter.info("Value: %s, Count: %d", "hello", 42) + mock_rp_logger.info.assert_called_once_with("test | Value: hello, Count: 42") + + def test_printf_formatting_with_warning(self, adapter_with_mock): + """Test printf-style formatting with warning level.""" + adapter, mock_rp_logger = adapter_with_mock + adapter.warning("Error code: %d", 500) + mock_rp_logger.warn.assert_called_once_with("test | Error code: 500") + + def test_namespace_prefix(self, adapter_with_mock): + """Test namespace prefix in message.""" + _, mock_rp_logger = adapter_with_mock + adapter = FlashLoggerAdapter("flash.module_name") + adapter.info("Processing") + mock_rp_logger.info.assert_called_with("flash.module_name | Processing") + + def test_no_namespace(self, adapter_with_mock): + """Test message without namespace.""" + _, mock_rp_logger = adapter_with_mock + adapter = FlashLoggerAdapter("") + adapter.info("Simple message") + mock_rp_logger.info.assert_called_with("Simple message") + + def test_formatting_failure_returns_original(self, adapter_with_mock): + """Test that formatting errors return original message.""" + adapter, mock_rp_logger = adapter_with_mock + # Try to format with wrong args - should not raise + adapter.info("Message %s %s", "only_one") + # Should call with original message on formatting error + assert mock_rp_logger.info.called + call_args = mock_rp_logger.info.call_args[0][0] + assert "Message" in call_args + + def test_empty_args(self, adapter_with_mock): + """Test message with no format args.""" + adapter, mock_rp_logger = adapter_with_mock + adapter.info("No formatting") + mock_rp_logger.info.assert_called_once_with("test | No formatting") + + +class TestGetFlashLogger: + """Test get_flash_logger factory function.""" + + def test_returns_adapter(self, mock_rp_logger_instance): + """Test that get_flash_logger returns an adapter.""" + with patch("rp_logger_adapter._get_rp_logger", return_value=mock_rp_logger_instance): + logger = get_flash_logger("test") + assert isinstance(logger, FlashLoggerAdapter) + + def test_different_names(self, mock_rp_logger_instance): + """Test creating loggers with different names.""" + with patch("rp_logger_adapter._get_rp_logger", return_value=mock_rp_logger_instance): + logger1 = get_flash_logger("module1") + logger2 = get_flash_logger("module2") + assert logger1.name == "module1" + assert logger2.name == "module2" + + def test_with_dunder_name(self, mock_rp_logger_instance): + """Test with __name__ pattern.""" + with patch("rp_logger_adapter._get_rp_logger", return_value=mock_rp_logger_instance): + logger = get_flash_logger(__name__) + assert logger.name == __name__ + + +class TestSetupFlashLogging: + """Test setup_flash_logging function.""" + + def test_explicit_debug_level(self, mock_rp_logger_instance, clean_env): + """Test explicit DEBUG level setup.""" + with patch("rp_logger_adapter._get_rp_logger", return_value=mock_rp_logger_instance): + setup_flash_logging("DEBUG") + # Should not raise + assert True + + def test_explicit_info_level(self, mock_rp_logger_instance, clean_env): + """Test explicit INFO level setup.""" + with patch("rp_logger_adapter._get_rp_logger", return_value=mock_rp_logger_instance): + setup_flash_logging("INFO") + assert True + + def test_explicit_warn_level(self, mock_rp_logger_instance, clean_env): + """Test explicit WARN level setup.""" + with patch("rp_logger_adapter._get_rp_logger", return_value=mock_rp_logger_instance): + setup_flash_logging("WARN") + assert True + + def test_explicit_error_level(self, mock_rp_logger_instance, clean_env): + """Test explicit ERROR level setup.""" + with patch("rp_logger_adapter._get_rp_logger", return_value=mock_rp_logger_instance): + setup_flash_logging("ERROR") + assert True + + def test_case_insensitive_level(self, mock_rp_logger_instance, clean_env): + """Test that level is case-insensitive.""" + with patch("rp_logger_adapter._get_rp_logger", return_value=mock_rp_logger_instance): + setup_flash_logging("debug") + setup_flash_logging("INFO") + setup_flash_logging("WaRn") + # Should not raise + assert True + + def test_invalid_level_defaults_to_info(self, mock_rp_logger_instance, clean_env): + """Test that invalid level defaults to INFO.""" + with patch("rp_logger_adapter._get_rp_logger", return_value=mock_rp_logger_instance): + setup_flash_logging("INVALID") + # Should default to INFO without raising + assert True + + def test_none_level(self, mock_rp_logger_instance, clean_env): + """Test with None level (should use env vars).""" + os.environ["RUNPOD_LOG_LEVEL"] = "DEBUG" + with patch("rp_logger_adapter._get_rp_logger", return_value=mock_rp_logger_instance): + setup_flash_logging(None) + assert True + + +class TestGetLogLevel: + """Test get_log_level function.""" + + def test_runpod_log_level_precedence(self, clean_env): + """Test RUNPOD_LOG_LEVEL takes precedence.""" + os.environ["RUNPOD_LOG_LEVEL"] = "DEBUG" + os.environ["LOG_LEVEL"] = "ERROR" + assert get_log_level() == "DEBUG" + + def test_log_level_fallback(self, clean_env): + """Test LOG_LEVEL is used if RUNPOD_LOG_LEVEL not set.""" + os.environ["LOG_LEVEL"] = "WARN" + assert get_log_level() == "WARN" + + def test_default_info(self, clean_env): + """Test default is INFO.""" + assert get_log_level() == "INFO" + + def test_case_normalization(self, clean_env): + """Test that level is normalized to uppercase.""" + os.environ["LOG_LEVEL"] = "debug" + assert get_log_level() == "DEBUG" + + def test_runpod_log_level_case_handling(self, clean_env): + """Test RUNPOD_LOG_LEVEL case handling.""" + os.environ["RUNPOD_LOG_LEVEL"] = "DeBuG" + assert get_log_level() == "DEBUG" + + +class TestSingletonBehavior: + """Test that RunPodLogger singleton is shared.""" + + def test_multiple_adapters_share_rp_logger(self, mock_rp_logger_instance): + """Test that multiple adapters share the same RunPodLogger.""" + # Create multiple adapters with same mock + with patch("rp_logger_adapter._get_rp_logger", return_value=mock_rp_logger_instance): + adapter1 = get_flash_logger("module1") + adapter2 = get_flash_logger("module2") + + # Get the internal rp_logger for each + rp_logger1 = adapter1._rp_logger + rp_logger2 = adapter2._rp_logger + + # They should be the same instance + assert rp_logger1 is rp_logger2 + + def test_rp_logger_singleton_caching(self): + """Test that _get_rp_logger returns same instance.""" + import rp_logger_adapter + + # Reset the singleton + original_instance = rp_logger_adapter._rp_logger_instance + rp_logger_adapter._rp_logger_instance = None + + try: + logger1 = _get_rp_logger() + logger2 = _get_rp_logger() + + # Should be same instance + assert logger1 is logger2 + finally: + # Restore + rp_logger_adapter._rp_logger_instance = original_instance + + +class TestNamespacePrefixes: + """Test namespace prefix functionality.""" + + def test_module_namespace_prefix(self, adapter_with_mock): + """Test module-style namespace prefix.""" + _, mock_rp_logger = adapter_with_mock + adapter = FlashLoggerAdapter("flash.dependency_installer") + adapter.info("Installing packages") + mock_rp_logger.info.assert_called_with("flash.dependency_installer | Installing packages") + + def test_nested_namespace_prefix(self, adapter_with_mock): + """Test deeply nested namespace prefix.""" + _, mock_rp_logger = adapter_with_mock + adapter = FlashLoggerAdapter("flash.executor.function_executor") + adapter.error("Execution failed") + mock_rp_logger.error.assert_called_with( + "flash.executor.function_executor | Execution failed" + ) + + def test_simple_namespace(self, adapter_with_mock): + """Test simple namespace.""" + _, mock_rp_logger = adapter_with_mock + adapter = FlashLoggerAdapter("worker") + adapter.debug("Debug info") + mock_rp_logger.debug.assert_called_with("worker | Debug info") + + +class TestMultipleLevels: + """Test multiple log levels together.""" + + def test_all_levels_with_same_adapter(self, adapter_with_mock): + """Test that all levels work with same adapter.""" + adapter, mock_rp_logger = adapter_with_mock + + adapter.debug("Debug message") + adapter.info("Info message") + adapter.warning("Warning message") + adapter.error("Error message") + + assert mock_rp_logger.debug.call_count == 1 + assert mock_rp_logger.info.call_count == 1 + assert mock_rp_logger.warn.call_count == 1 + assert mock_rp_logger.error.call_count == 1 + + def test_mixed_formatting_and_levels(self, adapter_with_mock): + """Test mixed formatting and levels.""" + adapter, mock_rp_logger = adapter_with_mock + + adapter.debug("Debug %s", "msg") + adapter.info("Info %d", 42) + adapter.warning("Warn %s", "msg") + adapter.error("Error %s", "msg") + + mock_rp_logger.debug.assert_called_with("test | Debug msg") + mock_rp_logger.info.assert_called_with("test | Info 42") + mock_rp_logger.warn.assert_called_with("test | Warn msg") + mock_rp_logger.error.assert_called_with("test | Error msg") + + +class TestEnvironmentVariableConfiguration: + """Test environment variable configuration.""" + + def test_runpod_log_level_env(self, clean_env): + """Test RUNPOD_LOG_LEVEL environment variable.""" + os.environ["RUNPOD_LOG_LEVEL"] = "DEBUG" + level = get_log_level() + assert level == "DEBUG" + + def test_log_level_deprecated_env(self, clean_env): + """Test deprecated LOG_LEVEL environment variable.""" + os.environ["LOG_LEVEL"] = "ERROR" + level = get_log_level() + assert level == "ERROR" + + def test_env_precedence(self, clean_env): + """Test RUNPOD_LOG_LEVEL takes precedence over LOG_LEVEL.""" + os.environ["RUNPOD_LOG_LEVEL"] = "DEBUG" + os.environ["LOG_LEVEL"] = "ERROR" + level = get_log_level() + assert level == "DEBUG" + + def test_empty_env_defaults(self, clean_env): + """Test empty environment defaults to INFO.""" + level = get_log_level() + assert level == "INFO" From 90b737ec5ca748aaf55f7fc108a665bb81c51e79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Thu, 12 Feb 2026 20:41:26 -0800 Subject: [PATCH 3/6] fix: replace obsolete FLASH_IS_MOTHERSHIP with FLASH_MOTHERSHIP_ID Update environment variable detection throughout flash-worker to use the current FLASH_MOTHERSHIP_ID variable instead of the obsolete FLASH_IS_MOTHERSHIP. Changes: - src/lb_handler.py: Update detection logic and docstring, clarify log messages - src/manifest_reconciliation.py: Use new variable for Flash deployment check - src/unpack_volume.py: Update documentation comment - Test files: Update all test fixtures to use FLASH_MOTHERSHIP_ID - docs/Runtime_Execution_Paths.md: Update documentation and examples Detection logic: - OLD: os.getenv("FLASH_IS_MOTHERSHIP") == "true" - NEW: os.getenv("FLASH_MOTHERSHIP_ID") is not None This ensures the mothership endpoint correctly detects its mode and logs display accurate deployment information instead of "Queue-based mode" when running on the mothership. All 278 tests passing with 76.64% coverage. --- docs/Runtime_Execution_Paths.md | 4 ++-- src/lb_handler.py | 9 +++++---- src/manifest_reconciliation.py | 4 ++-- src/unpack_volume.py | 2 +- .../integration/test_manifest_state_manager.py | 18 +++++++++--------- tests/unit/test_manifest_reconciliation.py | 18 +++++++++--------- tests/unit/test_unpack_volume.py | 18 +++++++++--------- 7 files changed, 37 insertions(+), 36 deletions(-) diff --git a/docs/Runtime_Execution_Paths.md b/docs/Runtime_Execution_Paths.md index 12dc257..0a8f88f 100644 --- a/docs/Runtime_Execution_Paths.md +++ b/docs/Runtime_Execution_Paths.md @@ -75,11 +75,11 @@ The handler automatically detects the deployment mode using environment variable |-------------|---------------|--------------|---------------| | Local dev | ❌ Not set | ❌ Not set | Live Serverless only | | Live Serverless | ✅ Set | ❌ Not set | Live Serverless | -| Flash Mothership | ✅ Set | ✅ FLASH_IS_MOTHERSHIP=true | Flash Deployed | +| Flash Mothership | ✅ Set | ✅ FLASH_MOTHERSHIP_ID | Flash Deployed | | Flash Child | ✅ Set | ✅ FLASH_RESOURCE_NAME | Flash Deployed | Flash-specific environment variables: -- `FLASH_IS_MOTHERSHIP=true` - Set for mothership endpoints +- `FLASH_MOTHERSHIP_ID` - Set for mothership endpoints (contains the mothership's RUNPOD_ENDPOINT_ID) - `FLASH_RESOURCE_NAME` - Specifies resource config name ## Request Format Differences diff --git a/src/lb_handler.py b/src/lb_handler.py index bbcc5bd..a288247 100644 --- a/src/lb_handler.py +++ b/src/lb_handler.py @@ -8,13 +8,14 @@ The handler uses worker-flash's RemoteExecutor for function execution. -Mothership Mode (FLASH_IS_MOTHERSHIP=true): +Mothership Mode (FLASH_MOTHERSHIP_ID set): +- FLASH_MOTHERSHIP_ID contains the mothership's RUNPOD_ENDPOINT_ID - Imports user's FastAPI application from FLASH_MAIN_FILE - Loads the app object from FLASH_APP_VARIABLE - Preserves all user routes and middleware - Adds /ping health check endpoint -Queue-Based Mode (FLASH_IS_MOTHERSHIP not set or false): +Child Endpoint Mode (FLASH_MOTHERSHIP_ID not set): - Creates generic FastAPI app with /execute endpoint - Uses RemoteExecutor for function execution """ @@ -42,7 +43,7 @@ from remote_executor import RemoteExecutor # noqa: E402 # Determine mode based on environment variables -is_mothership = os.getenv("FLASH_IS_MOTHERSHIP") == "true" +is_mothership = os.getenv("FLASH_MOTHERSHIP_ID") is not None if is_mothership: # Mothership mode: Import user's FastAPI application @@ -97,7 +98,7 @@ async def ping_mothership() -> Dict[str, Any]: else: # Queue-based mode: Create generic Load Balancer handler app app = FastAPI(title="Load Balancer Handler") - logger.info("Queue-based mode: Using generic Load Balancer handler") + logger.info("Child endpoint mode: Using generic Load Balancer handler") # Queue-based mode endpoints diff --git a/src/manifest_reconciliation.py b/src/manifest_reconciliation.py index 7c2b148..02a162a 100644 --- a/src/manifest_reconciliation.py +++ b/src/manifest_reconciliation.py @@ -30,8 +30,8 @@ def is_flash_deployment() -> bool: endpoint_id = os.getenv("RUNPOD_ENDPOINT_ID") is_flash = any( [ - os.getenv("FLASH_IS_MOTHERSHIP") == "true", - os.getenv("FLASH_RESOURCE_NAME"), + os.getenv("FLASH_MOTHERSHIP_ID") is not None, + os.getenv("FLASH_RESOURCE_NAME") is not None, ] ) return bool(endpoint_id and is_flash) diff --git a/src/unpack_volume.py b/src/unpack_volume.py index 1e2ce07..cdfd076 100644 --- a/src/unpack_volume.py +++ b/src/unpack_volume.py @@ -86,7 +86,7 @@ def _should_unpack_from_volume() -> bool: Detection logic: 1. Honor explicit disable flag (FLASH_DISABLE_UNPACK) 2. Must be in RunPod environment (RUNPOD_POD_ID or RUNPOD_ENDPOINT_ID) - 3. Must be Flash deployment (any of FLASH_IS_MOTHERSHIP, FLASH_RESOURCE_NAME) + 3. Must be Flash deployment (any of FLASH_MOTHERSHIP_ID, FLASH_RESOURCE_NAME) Returns: bool: True if unpacking should occur, False otherwise diff --git a/tests/integration/test_manifest_state_manager.py b/tests/integration/test_manifest_state_manager.py index 293ecff..140781d 100644 --- a/tests/integration/test_manifest_state_manager.py +++ b/tests/integration/test_manifest_state_manager.py @@ -127,7 +127,7 @@ async def test_manifest_refresh_on_cross_endpoint_routing( "os.environ", { "RUNPOD_ENDPOINT_ID": "ep-test-001", - "FLASH_IS_MOTHERSHIP": "true", + "FLASH_MOTHERSHIP_ID": "test-mothership-id", "RUNPOD_API_KEY": "test-api-key", }, clear=True, @@ -160,7 +160,7 @@ async def test_manifest_refresh_skipped_if_fresh( "os.environ", { "RUNPOD_ENDPOINT_ID": "ep-test-001", - "FLASH_IS_MOTHERSHIP": "true", + "FLASH_MOTHERSHIP_ID": "test-mothership-id", "RUNPOD_API_KEY": "test-api-key", }, clear=True, @@ -196,7 +196,7 @@ async def test_manifest_refresh_continues_on_failure( "os.environ", { "RUNPOD_ENDPOINT_ID": "ep-test-001", - "FLASH_IS_MOTHERSHIP": "true", + "FLASH_MOTHERSHIP_ID": "test-mothership-id", "RUNPOD_API_KEY": "test-api-key", }, clear=True, @@ -232,7 +232,7 @@ async def test_state_manager_unavailable_graceful_degradation( "os.environ", { "RUNPOD_ENDPOINT_ID": "ep-test-001", - "FLASH_IS_MOTHERSHIP": "true", + "FLASH_MOTHERSHIP_ID": "test-mothership-id", "RUNPOD_API_KEY": "test-api-key", }, clear=True, @@ -320,7 +320,7 @@ async def test_state_manager_overwrites_local( "os.environ", { "RUNPOD_ENDPOINT_ID": "ep-test-001", - "FLASH_IS_MOTHERSHIP": "true", + "FLASH_MOTHERSHIP_ID": "test-mothership-id", "RUNPOD_API_KEY": "test-api-key", }, clear=True, @@ -378,7 +378,7 @@ async def test_state_manager_provides_additional_metadata( "os.environ", { "RUNPOD_ENDPOINT_ID": "ep-test-001", - "FLASH_IS_MOTHERSHIP": "true", + "FLASH_MOTHERSHIP_ID": "test-mothership-id", "RUNPOD_API_KEY": "test-api-key", }, clear=True, @@ -421,7 +421,7 @@ async def test_fallback_to_local_on_state_manager_error( "os.environ", { "RUNPOD_ENDPOINT_ID": "ep-test-001", - "FLASH_IS_MOTHERSHIP": "true", + "FLASH_MOTHERSHIP_ID": "test-mothership-id", "RUNPOD_API_KEY": "test-api-key", }, clear=True, @@ -456,7 +456,7 @@ async def test_manifest_file_write_error( "os.environ", { "RUNPOD_ENDPOINT_ID": "ep-test-001", - "FLASH_IS_MOTHERSHIP": "true", + "FLASH_MOTHERSHIP_ID": "test-mothership-id", "RUNPOD_API_KEY": "test-api-key", }, clear=True, @@ -488,7 +488,7 @@ async def test_multiple_refreshes_with_ttl( "os.environ", { "RUNPOD_ENDPOINT_ID": "ep-test-001", - "FLASH_IS_MOTHERSHIP": "true", + "FLASH_MOTHERSHIP_ID": "test-mothership-id", "RUNPOD_API_KEY": "test-api-key", }, clear=True, diff --git a/tests/unit/test_manifest_reconciliation.py b/tests/unit/test_manifest_reconciliation.py index 2a2e720..55687f0 100644 --- a/tests/unit/test_manifest_reconciliation.py +++ b/tests/unit/test_manifest_reconciliation.py @@ -42,12 +42,12 @@ class TestIsFlashDeployment: """Test Flash deployment detection.""" def test_is_flash_deployment_mothership(self) -> None: - """Test detection with FLASH_IS_MOTHERSHIP.""" + """Test detection with FLASH_MOTHERSHIP_ID.""" with patch.dict( "os.environ", { "RUNPOD_ENDPOINT_ID": "ep-001", - "FLASH_IS_MOTHERSHIP": "true", + "FLASH_MOTHERSHIP_ID": "test-mothership-id", }, ): assert is_flash_deployment() is True @@ -69,7 +69,7 @@ def test_is_flash_deployment_no_endpoint_id(self) -> None: with patch.dict( "os.environ", { - "FLASH_IS_MOTHERSHIP": "true", + "FLASH_MOTHERSHIP_ID": "test-mothership-id", }, clear=True, ): @@ -260,7 +260,7 @@ async def test_refresh_no_endpoint_id(self, tmp_path: Path) -> None: """Test refresh skipped when RUNPOD_ENDPOINT_ID not set.""" manifest_path = tmp_path / "manifest.json" - with patch.dict("os.environ", {"FLASH_IS_MOTHERSHIP": "true"}, clear=True): + with patch.dict("os.environ", {"FLASH_MOTHERSHIP_ID": "test-mothership-id"}, clear=True): result = await refresh_manifest_if_stale(manifest_path) assert result is False @@ -274,7 +274,7 @@ async def test_refresh_no_api_key(self, tmp_path: Path) -> None: "os.environ", { "RUNPOD_ENDPOINT_ID": "ep-001", - "FLASH_IS_MOTHERSHIP": "true", + "FLASH_MOTHERSHIP_ID": "test-mothership-id", }, clear=True, ): @@ -296,7 +296,7 @@ async def test_refresh_fresh_manifest_no_query( "os.environ", { "RUNPOD_ENDPOINT_ID": "ep-test-001", - "FLASH_IS_MOTHERSHIP": "true", + "FLASH_MOTHERSHIP_ID": "test-mothership-id", "RUNPOD_API_KEY": "test-key", }, clear=True, @@ -343,7 +343,7 @@ async def test_refresh_stale_manifest_queries_state_manager( "os.environ", { "RUNPOD_ENDPOINT_ID": "ep-test-001", - "FLASH_IS_MOTHERSHIP": "true", + "FLASH_MOTHERSHIP_ID": "test-mothership-id", "RUNPOD_API_KEY": "test-key", }, clear=True, @@ -382,7 +382,7 @@ async def test_refresh_state_manager_error_continues( "os.environ", { "RUNPOD_ENDPOINT_ID": "ep-test-001", - "FLASH_IS_MOTHERSHIP": "true", + "FLASH_MOTHERSHIP_ID": "test-mothership-id", "RUNPOD_API_KEY": "test-key", }, clear=True, @@ -418,7 +418,7 @@ async def test_refresh_custom_ttl(self, tmp_path: Path, sample_manifest: dict) - "os.environ", { "RUNPOD_ENDPOINT_ID": "ep-test-001", - "FLASH_IS_MOTHERSHIP": "true", + "FLASH_MOTHERSHIP_ID": "test-mothership-id", "RUNPOD_API_KEY": "test-key", }, clear=True, diff --git a/tests/unit/test_unpack_volume.py b/tests/unit/test_unpack_volume.py index e984356..a279806 100644 --- a/tests/unit/test_unpack_volume.py +++ b/tests/unit/test_unpack_volume.py @@ -223,7 +223,7 @@ def test_should_unpack_for_flash_mothership(self): os.environ, { "RUNPOD_ENDPOINT_ID": "test-endpoint-id", - "FLASH_IS_MOTHERSHIP": "true", + "FLASH_MOTHERSHIP_ID": "test-mothership-id", }, clear=False, ): @@ -252,7 +252,7 @@ def test_should_not_unpack_for_live_serverless(self): clear=False, ): os.environ.pop("FLASH_DISABLE_UNPACK", None) - os.environ.pop("FLASH_IS_MOTHERSHIP", None) + os.environ.pop("FLASH_MOTHERSHIP_ID", None) os.environ.pop("FLASH_RESOURCE_NAME", None) assert _should_unpack_from_volume() is False @@ -270,7 +270,7 @@ def test_should_not_unpack_when_disabled_with_1(self): os.environ, { "RUNPOD_POD_ID": "test-pod-id", - "FLASH_IS_MOTHERSHIP": "true", + "FLASH_MOTHERSHIP_ID": "test-mothership-id", "FLASH_DISABLE_UNPACK": "1", }, ): @@ -282,7 +282,7 @@ def test_should_not_unpack_when_disabled_with_true(self): os.environ, { "RUNPOD_POD_ID": "test-pod-id", - "FLASH_IS_MOTHERSHIP": "true", + "FLASH_MOTHERSHIP_ID": "test-mothership-id", "FLASH_DISABLE_UNPACK": "true", }, ): @@ -294,7 +294,7 @@ def test_should_not_unpack_when_disabled_with_yes(self): os.environ, { "RUNPOD_POD_ID": "test-pod-id", - "FLASH_IS_MOTHERSHIP": "true", + "FLASH_MOTHERSHIP_ID": "test-mothership-id", "FLASH_DISABLE_UNPACK": "yes", }, ): @@ -306,7 +306,7 @@ def test_should_unpack_when_disable_flag_has_wrong_value(self): os.environ, { "RUNPOD_ENDPOINT_ID": "test-endpoint-id", - "FLASH_IS_MOTHERSHIP": "true", + "FLASH_MOTHERSHIP_ID": "test-mothership-id", "FLASH_DISABLE_UNPACK": "false", }, ): @@ -318,7 +318,7 @@ def test_should_not_unpack_when_disabled_with_uppercase_true(self): os.environ, { "RUNPOD_POD_ID": "test-pod-id", - "FLASH_IS_MOTHERSHIP": "true", + "FLASH_MOTHERSHIP_ID": "test-mothership-id", "FLASH_DISABLE_UNPACK": "True", }, ): @@ -330,7 +330,7 @@ def test_should_not_unpack_when_disabled_with_uppercase_yes(self): os.environ, { "RUNPOD_POD_ID": "test-pod-id", - "FLASH_IS_MOTHERSHIP": "true", + "FLASH_MOTHERSHIP_ID": "test-mothership-id", "FLASH_DISABLE_UNPACK": "YES", }, ): @@ -342,7 +342,7 @@ def test_should_not_unpack_when_disabled_with_mixed_case(self): os.environ, { "RUNPOD_POD_ID": "test-pod-id", - "FLASH_IS_MOTHERSHIP": "true", + "FLASH_MOTHERSHIP_ID": "test-mothership-id", "FLASH_DISABLE_UNPACK": "Yes", }, ): From cb542b8d827e6e0b9722dea6a8075efa7274e652 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Thu, 12 Feb 2026 22:41:06 -0800 Subject: [PATCH 4/6] feat(api-key-context): Add API key propagation for remote execution Add thread-safe context variable for API key propagation: - api_key_context.py: New module with set/get/clear_api_key functions - lb_handler.py: Extract Bearer token from Authorization header via middleware - remote_executor.py: Check context API key before env var fallback - manifest_reconciliation.py: Skip State Manager queries in preview mode Add comprehensive unit tests for context management and middleware --- src/api_key_context.py | 43 ++++ src/lb_handler.py | 40 +++- src/manifest_reconciliation.py | 5 + src/remote_executor.py | 4 +- tests/unit/test_api_key_context.py | 187 ++++++++++++++++ tests/unit/test_lb_handler_middleware.py | 263 +++++++++++++++++++++++ tests/unit/test_remote_executor.py | 141 ++++++++++++ 7 files changed, 681 insertions(+), 2 deletions(-) create mode 100644 src/api_key_context.py create mode 100644 tests/unit/test_api_key_context.py create mode 100644 tests/unit/test_lb_handler_middleware.py diff --git a/src/api_key_context.py b/src/api_key_context.py new file mode 100644 index 0000000..e0247a3 --- /dev/null +++ b/src/api_key_context.py @@ -0,0 +1,43 @@ +"""Thread-local context for API key propagation across remote calls.""" + +import contextvars +from typing import Optional + +# Context variable for API key extracted from incoming requests +_api_key_context: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar( + "api_key_context", default=None +) + + +def set_api_key(api_key: Optional[str]) -> contextvars.Token[Optional[str]]: + """Set API key in current context. + + Args: + api_key: RunPod API key to use for remote calls + + Returns: + Token that can be used to reset the context + """ + return _api_key_context.set(api_key) + + +def get_api_key() -> Optional[str]: + """Get API key from current context. + + Returns: + API key if set, None otherwise + """ + return _api_key_context.get() + + +def clear_api_key(token: Optional[contextvars.Token[Optional[str]]] = None) -> None: + """Clear API key from current context. + + Args: + token: Optional token from set_api_key() to reset to previous value. + If None, sets context to None (backwards compatible). + """ + if token is not None: + _api_key_context.reset(token) + else: + _api_key_context.set(None) diff --git a/src/lb_handler.py b/src/lb_handler.py index a288247..090adcc 100644 --- a/src/lb_handler.py +++ b/src/lb_handler.py @@ -24,8 +24,9 @@ import os from typing import Any, Dict -from fastapi import FastAPI +from fastapi import FastAPI, Request +from api_key_context import clear_api_key, set_api_key from rp_logger_adapter import setup_flash_logging, get_flash_logger from unpack_volume import maybe_unpack @@ -37,6 +38,40 @@ # This is a no-op for Live Serverless and local development maybe_unpack() + +async def extract_api_key_middleware(request: Request, call_next): + """Extract API key from Authorization header and set in context. + + This middleware extracts the Bearer token from the Authorization header + and makes it available to downstream code via context variables. This + enables worker endpoints to propagate API keys to remote calls. + + Args: + request: Incoming FastAPI request + call_next: Next middleware in chain + + Returns: + Response from downstream handlers + """ + # Extract API key from Authorization header + auth_header = request.headers.get("Authorization", "") + api_key = None + token = None + + if auth_header.startswith("Bearer "): + api_key = auth_header[7:].strip() # Remove "Bearer " prefix and trim whitespace + token = set_api_key(api_key) + logger.debug("Extracted API key from Authorization header") + + try: + response = await call_next(request) + return response + finally: + # Clean up context after request + if token is not None: + clear_api_key(token) + + # Import from bundled /app/runpod_flash (no system package) # These imports must happen AFTER maybe_unpack() so /app is in sys.path from runpod_flash.protos.remote_execution import FunctionRequest, FunctionResponse # noqa: E402 @@ -101,6 +136,9 @@ async def ping_mothership() -> Dict[str, Any]: logger.info("Child endpoint mode: Using generic Load Balancer handler") +# Register API key extraction middleware for both mothership and queue-based modes +app.middleware("http")(extract_api_key_middleware) + # Queue-based mode endpoints if not is_mothership: diff --git a/src/manifest_reconciliation.py b/src/manifest_reconciliation.py index 02a162a..bb38504 100644 --- a/src/manifest_reconciliation.py +++ b/src/manifest_reconciliation.py @@ -152,6 +152,11 @@ async def refresh_manifest_if_stale( logger.debug("RUNPOD_ENDPOINT_ID not set, skipping manifest refresh") return False + # Skip State Manager queries in preview mode + if endpoint_id.startswith("preview-"): + logger.debug("Preview mode detected, skipping State Manager queries") + return True # Manifest will be loaded from FLASH_RESOURCES_ENDPOINTS + api_key = os.getenv("RUNPOD_API_KEY") if not api_key: logger.debug("RUNPOD_API_KEY not set, skipping manifest refresh") diff --git a/src/remote_executor.py b/src/remote_executor.py index 1fdd6df..0bced9a 100644 --- a/src/remote_executor.py +++ b/src/remote_executor.py @@ -10,6 +10,7 @@ FunctionResponse, RemoteExecutorStub, ) +from api_key_context import get_api_key from rp_logger_adapter import get_flash_logger from dependency_installer import DependencyInstaller from function_executor import FunctionExecutor @@ -444,7 +445,8 @@ async def _route_to_endpoint( payload = {"input": request.model_dump(exclude_none=True)} # Make HTTP request to target endpoint - api_key = os.getenv("RUNPOD_API_KEY") + # Check context API key first (from incoming request), then env var (pre-deployed) + api_key = get_api_key() or os.getenv("RUNPOD_API_KEY") headers = {"Content-Type": "application/json"} if api_key: diff --git a/tests/unit/test_api_key_context.py b/tests/unit/test_api_key_context.py new file mode 100644 index 0000000..c36ebfc --- /dev/null +++ b/tests/unit/test_api_key_context.py @@ -0,0 +1,187 @@ +import pytest +import contextvars +import asyncio + +from api_key_context import set_api_key, get_api_key, clear_api_key + + +class TestApiKeyContext: + """Unit tests for API key context variable management.""" + + def test_set_api_key_stores_value(self): + """Test set_api_key stores API key in context.""" + # Clear any existing context + clear_api_key() + + api_key = "test-api-key-12345" + token = set_api_key(api_key) + + # Verify the API key is stored + assert get_api_key() == api_key + assert token is not None + assert isinstance(token, contextvars.Token) + + def test_get_api_key_returns_none_initially(self): + """Test get_api_key returns None when not set.""" + # Clear any existing context + clear_api_key() + + # Should return None when not set + assert get_api_key() is None + + def test_clear_api_key_with_token_resets_to_previous_value(self): + """Test clear_api_key with token resets to previous value.""" + # Start with first API key + clear_api_key() + first_key = "first-api-key" + token1 = set_api_key(first_key) + assert get_api_key() == first_key + + # Set second API key + second_key = "second-api-key" + token2 = set_api_key(second_key) + assert get_api_key() == second_key + + # Clear with token2 should restore to first_key + clear_api_key(token2) + assert get_api_key() == first_key + + # Clear with token1 should reset to None + clear_api_key(token1) + assert get_api_key() is None + + def test_clear_api_key_without_token_sets_to_none(self): + """Test clear_api_key without token sets context to None.""" + # Set an API key + api_key = "test-api-key" + set_api_key(api_key) + assert get_api_key() == api_key + + # Clear without token + clear_api_key() + + # Should be None + assert get_api_key() is None + + def test_set_api_key_with_none(self): + """Test set_api_key can store None explicitly.""" + # Set initial value + set_api_key("test-key") + + # Set to None + token = set_api_key(None) + assert get_api_key() is None + assert token is not None + + @pytest.mark.asyncio + async def test_context_isolation_between_async_tasks(self): + """Test that context is isolated between async tasks.""" + results = {} + + async def task_1(): + # Set API key for task 1 + set_api_key("task-1-key") + await asyncio.sleep(0.01) # Yield to allow task 2 to run + results["task_1"] = get_api_key() + + async def task_2(): + # Set API key for task 2 + set_api_key("task-2-key") + await asyncio.sleep(0.01) # Yield to allow task 1 to continue + results["task_2"] = get_api_key() + + async def task_3_check_none(): + # Task 3 should not have any API key set + await asyncio.sleep(0.005) + results["task_3"] = get_api_key() + + # Run tasks concurrently using TaskGroup (Python 3.11+) + async with asyncio.TaskGroup() as tg: + tg.create_task(task_1()) + tg.create_task(task_2()) + tg.create_task(task_3_check_none()) + + # Verify each task had its own context + assert results["task_1"] == "task-1-key" + assert results["task_2"] == "task-2-key" + assert results["task_3"] is None + + @pytest.mark.asyncio + async def test_context_scope_within_task(self): + """Test context scope within a single async task.""" + + async def task_with_context_scope(): + # Initially None + assert get_api_key() is None + + # Set key + token = set_api_key("scoped-key") + assert get_api_key() == "scoped-key" + + # Reset using token + clear_api_key(token) + assert get_api_key() is None + + await task_with_context_scope() + + def test_multiple_sequential_sets_and_clears(self): + """Test multiple sequential set/clear operations.""" + clear_api_key() + + # Set first key + token1 = set_api_key("key-1") + assert get_api_key() == "key-1" + + # Set second key + token2 = set_api_key("key-2") + assert get_api_key() == "key-2" + + # Set third key + token3 = set_api_key("key-3") + assert get_api_key() == "key-3" + + # Clear in reverse order + clear_api_key(token3) + assert get_api_key() == "key-2" + + clear_api_key(token2) + assert get_api_key() == "key-1" + + clear_api_key(token1) + assert get_api_key() is None + + def test_context_token_reset_with_none_value(self): + """Test token reset works correctly with None values.""" + clear_api_key() + + # Set to "test" + token1 = set_api_key("test") + assert get_api_key() == "test" + + # Set to None + token2 = set_api_key(None) + assert get_api_key() is None + + # Reset token2 should go back to "test" + clear_api_key(token2) + assert get_api_key() == "test" + + # Reset token1 should go back to None + clear_api_key(token1) + assert get_api_key() is None + + def test_context_isolation_preserves_values(self): + """Test that context variables preserve values through operations.""" + clear_api_key() + + api_key = "important-key-xyz" + token = set_api_key(api_key) + + # Retrieve multiple times + assert get_api_key() == api_key + assert get_api_key() == api_key + assert get_api_key() == api_key + + # Clear and verify + clear_api_key(token) + assert get_api_key() is None diff --git a/tests/unit/test_lb_handler_middleware.py b/tests/unit/test_lb_handler_middleware.py new file mode 100644 index 0000000..2804aa7 --- /dev/null +++ b/tests/unit/test_lb_handler_middleware.py @@ -0,0 +1,263 @@ +import pytest +import asyncio +from unittest.mock import MagicMock + +from api_key_context import get_api_key, clear_api_key + + +class TestLbHandlerMiddleware: + """Unit tests for API key extraction middleware.""" + + def setup_method(self): + """Setup for each test method.""" + # Clear API key context before each test + clear_api_key() + + @pytest.mark.asyncio + async def test_middleware_extracts_bearer_token_from_header(self): + """Test middleware extracts API key from valid Authorization header.""" + # Import here to get fresh state + from lb_handler import extract_api_key_middleware + + # Create mock request with Bearer token + mock_request = MagicMock() + mock_request.headers.get.return_value = "Bearer test-api-key-12345" + + # Create mock next handler that returns a response + async def mock_call_next(request): + # Check that API key was set in context + stored_key = get_api_key() + assert stored_key == "test-api-key-12345" + return MagicMock(status_code=200) + + # Call middleware + response = await extract_api_key_middleware(mock_request, mock_call_next) + + # Verify response returned + assert response.status_code == 200 + + # Verify context was cleaned up after request + assert get_api_key() is None + + @pytest.mark.asyncio + async def test_middleware_handles_missing_authorization_header(self): + """Test middleware handles missing Authorization header gracefully.""" + from lb_handler import extract_api_key_middleware + + # Create mock request without Authorization header + mock_request = MagicMock() + mock_request.headers.get.return_value = "" + + # Create mock next handler + async def mock_call_next(request): + # Should not have API key in context + assert get_api_key() is None + return MagicMock(status_code=200) + + # Call middleware + response = await extract_api_key_middleware(mock_request, mock_call_next) + + # Verify response returned + assert response.status_code == 200 + + # Verify context remains clean + assert get_api_key() is None + + @pytest.mark.asyncio + async def test_middleware_handles_malformed_authorization_header(self): + """Test middleware handles malformed Authorization header.""" + from lb_handler import extract_api_key_middleware + + # Test various malformed headers + malformed_headers = [ + "Basic dXNlcjpwYXNz", # Basic auth, not Bearer + "Bearer", # Missing token + "Bearer ", # Just "Bearer " with no token + "token-without-bearer", # No Bearer prefix + "bearer lowercase-key", # lowercase 'bearer' + ] + + for header_value in malformed_headers: + clear_api_key() + + mock_request = MagicMock() + mock_request.headers.get.return_value = header_value + + async def mock_call_next(request): + # Should not have API key in context for malformed headers + # (except if it somehow parsed a token, which shouldn't happen) + return MagicMock(status_code=200) + + response = await extract_api_key_middleware(mock_request, mock_call_next) + assert response.status_code == 200 + + @pytest.mark.asyncio + async def test_middleware_clears_context_after_request(self): + """Test middleware clears context variable after request completes.""" + from lb_handler import extract_api_key_middleware + + api_key = "test-key-to-be-cleared" + mock_request = MagicMock() + mock_request.headers.get.return_value = f"Bearer {api_key}" + + # Verify context is clean before + assert get_api_key() is None + + async def mock_call_next(request): + # Inside the request, context should have the key + assert get_api_key() == api_key + return MagicMock(status_code=200) + + # Call middleware + await extract_api_key_middleware(mock_request, mock_call_next) + + # After middleware completes, context should be cleared + assert get_api_key() is None + + @pytest.mark.asyncio + async def test_middleware_clears_context_even_on_exception(self): + """Test middleware clears context even if handler raises exception.""" + from lb_handler import extract_api_key_middleware + + api_key = "test-key-exception" + mock_request = MagicMock() + mock_request.headers.get.return_value = f"Bearer {api_key}" + + async def mock_call_next_with_error(request): + # Verify API key was set + assert get_api_key() == api_key + # Raise an exception + raise ValueError("Handler error") + + # Call middleware and expect it to raise + with pytest.raises(ValueError): + await extract_api_key_middleware(mock_request, mock_call_next_with_error) + + # Verify context was still cleaned up + assert get_api_key() is None + + @pytest.mark.asyncio + async def test_middleware_extracts_bearer_token_with_whitespace(self): + """Test middleware correctly handles Bearer token with extra whitespace.""" + from lb_handler import extract_api_key_middleware + + # Test Bearer token with leading/trailing spaces + mock_request = MagicMock() + mock_request.headers.get.return_value = "Bearer test-api-key-with-spaces " + + async def mock_call_next(request): + # Should extract token with whitespace trimmed + stored_key = get_api_key() + assert stored_key == "test-api-key-with-spaces" + return MagicMock(status_code=200) + + response = await extract_api_key_middleware(mock_request, mock_call_next) + assert response.status_code == 200 + + @pytest.mark.asyncio + async def test_context_isolation_between_concurrent_requests(self): + """Test context isolation between concurrent request handlers.""" + from lb_handler import extract_api_key_middleware + + results = {} + + async def request_1_handler(): + mock_request = MagicMock() + mock_request.headers.get.return_value = "Bearer request-1-key" + + async def next_handler(request): + # Simulate some async work + await asyncio.sleep(0.01) + results["request_1_inside"] = get_api_key() + return MagicMock(status_code=200) + + await extract_api_key_middleware(mock_request, next_handler) + results["request_1_outside"] = get_api_key() + + async def request_2_handler(): + mock_request = MagicMock() + mock_request.headers.get.return_value = "Bearer request-2-key" + + async def next_handler(request): + # Simulate some async work + await asyncio.sleep(0.005) + results["request_2_inside"] = get_api_key() + return MagicMock(status_code=200) + + await extract_api_key_middleware(mock_request, next_handler) + results["request_2_outside"] = get_api_key() + + # Run both requests concurrently + async with asyncio.TaskGroup() as tg: + tg.create_task(request_1_handler()) + tg.create_task(request_2_handler()) + + # Verify each request had isolated context + assert results["request_1_inside"] == "request-1-key" + assert results["request_2_inside"] == "request-2-key" + # Outside handlers, both should be None (cleaned up) + assert results["request_1_outside"] is None + assert results["request_2_outside"] is None + + @pytest.mark.asyncio + async def test_middleware_case_sensitive_bearer_prefix(self): + """Test middleware correctly requires 'Bearer' prefix (case-sensitive).""" + from lb_handler import extract_api_key_middleware + + # Test lowercase 'bearer' (should not match) + mock_request = MagicMock() + mock_request.headers.get.return_value = "bearer lowercase-key" + + async def mock_call_next(request): + # Should not extract key with lowercase 'bearer' + assert get_api_key() is None + return MagicMock(status_code=200) + + response = await extract_api_key_middleware(mock_request, mock_call_next) + assert response.status_code == 200 + + @pytest.mark.asyncio + async def test_middleware_with_complex_api_key_format(self): + """Test middleware with complex API key formats.""" + from lb_handler import extract_api_key_middleware + + # Complex API key with special characters + complex_key = "rp-aB1234567890-xyz_test.key" + mock_request = MagicMock() + mock_request.headers.get.return_value = f"Bearer {complex_key}" + + async def mock_call_next(request): + stored_key = get_api_key() + assert stored_key == complex_key + return MagicMock(status_code=200) + + response = await extract_api_key_middleware(mock_request, mock_call_next) + assert response.status_code == 200 + + @pytest.mark.asyncio + async def test_middleware_preserves_api_key_through_request_lifecycle(self): + """Test API key remains available throughout request lifecycle.""" + from lb_handler import extract_api_key_middleware + + api_key = "persistent-key" + mock_request = MagicMock() + mock_request.headers.get.return_value = f"Bearer {api_key}" + + access_log = [] + + async def mock_call_next(request): + # Multiple accesses during request should all return same key + access_log.append(get_api_key()) + await asyncio.sleep(0.001) + access_log.append(get_api_key()) + await asyncio.sleep(0.001) + access_log.append(get_api_key()) + return MagicMock(status_code=200) + + await extract_api_key_middleware(mock_request, mock_call_next) + + # Verify all accesses returned the same key + assert len(access_log) == 3 + assert all(key == api_key for key in access_log) + # After middleware, should be cleared + assert get_api_key() is None diff --git a/tests/unit/test_remote_executor.py b/tests/unit/test_remote_executor.py index 4b4dde8..1086742 100644 --- a/tests/unit/test_remote_executor.py +++ b/tests/unit/test_remote_executor.py @@ -707,3 +707,144 @@ async def test_live_serverless_skips_manifest_logic(self): # Verify function executor was called mock_execute.assert_called_once_with(request) + + @pytest.mark.asyncio + async def test_route_to_endpoint_uses_context_api_key_over_env_var(self): + """Test _route_to_endpoint uses context API key instead of env var.""" + request = FunctionRequest( + function_name="remote_func", + args=[], + kwargs={}, + ) + + context_key = "context-api-key" + env_key = "env-api-key" + + # Import after test setup to get context functions + from api_key_context import set_api_key, clear_api_key + + try: + # Set context API key + set_api_key(context_key) + + with ( + patch.dict("os.environ", {"RUNPOD_API_KEY": env_key}), + patch("aiohttp.ClientSession.post") as mock_post, + ): + # Mock aiohttp response + mock_response_data = {"output": {"success": True, "result": "endpoint_result"}} + mock_post_response = AsyncMock() + mock_post_response.status = 200 + mock_post_response.json = AsyncMock(return_value=mock_response_data) + mock_post_response.__aenter__.return_value = mock_post_response + mock_post_response.__aexit__.return_value = None + mock_post.return_value = mock_post_response + + # Call _route_to_endpoint + response = await self.executor._route_to_endpoint( + request, "https://api.runpod.ai/v2/endpoint/run" + ) + + # Verify response was successful + assert response.success is True + + # Verify context API key was used in Authorization header + mock_post.assert_called_once() + call_kwargs = mock_post.call_args[1] + headers = call_kwargs.get("headers", {}) + assert headers.get("Authorization") == f"Bearer {context_key}" + finally: + clear_api_key() + + @pytest.mark.asyncio + async def test_route_to_endpoint_falls_back_to_env_var_when_no_context_key(self): + """Test _route_to_endpoint falls back to env var when no context API key.""" + request = FunctionRequest( + function_name="remote_func", + args=[], + kwargs={}, + ) + + env_key = "env-api-key" + + # Import to ensure context is clear + from api_key_context import clear_api_key + + try: + # Ensure context is clear + clear_api_key() + + with ( + patch.dict("os.environ", {"RUNPOD_API_KEY": env_key}), + patch("aiohttp.ClientSession.post") as mock_post, + ): + # Mock aiohttp response + mock_response_data = {"output": {"success": True, "result": "endpoint_result"}} + mock_post_response = AsyncMock() + mock_post_response.status = 200 + mock_post_response.json = AsyncMock(return_value=mock_response_data) + mock_post_response.__aenter__.return_value = mock_post_response + mock_post_response.__aexit__.return_value = None + mock_post.return_value = mock_post_response + + # Call _route_to_endpoint + response = await self.executor._route_to_endpoint( + request, "https://api.runpod.ai/v2/endpoint/run" + ) + + # Verify response was successful + assert response.success is True + + # Verify env var was used in Authorization header + mock_post.assert_called_once() + call_kwargs = mock_post.call_args[1] + headers = call_kwargs.get("headers", {}) + assert headers.get("Authorization") == f"Bearer {env_key}" + finally: + clear_api_key() + + @pytest.mark.asyncio + async def test_route_to_endpoint_no_api_key_available(self): + """Test _route_to_endpoint handles case with no API key available.""" + request = FunctionRequest( + function_name="remote_func", + args=[], + kwargs={}, + ) + + # Import to ensure context is clear + from api_key_context import clear_api_key + + try: + # Ensure context is clear + clear_api_key() + + with ( + patch.dict("os.environ", {}, clear=True), + patch.dict("os.environ", {"RUNPOD_API_KEY": ""}, clear=False), + patch("aiohttp.ClientSession.post") as mock_post, + ): + # Mock aiohttp response + mock_response_data = {"output": {"success": True, "result": "endpoint_result"}} + mock_post_response = AsyncMock() + mock_post_response.status = 200 + mock_post_response.json = AsyncMock(return_value=mock_response_data) + mock_post_response.__aenter__.return_value = mock_post_response + mock_post_response.__aexit__.return_value = None + mock_post.return_value = mock_post_response + + # Call _route_to_endpoint + response = await self.executor._route_to_endpoint( + request, "https://api.runpod.ai/v2/endpoint/run" + ) + + # Verify response was successful + assert response.success is True + + # Verify no Authorization header was added + mock_post.assert_called_once() + call_kwargs = mock_post.call_args[1] + headers = call_kwargs.get("headers", {}) + assert "Authorization" not in headers or headers.get("Authorization") is None + finally: + clear_api_key() From 0d465e43f0133a7c2bad699fb70c8a0cdae2a63e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Fri, 13 Feb 2026 00:06:24 -0800 Subject: [PATCH 5/6] fix: use FLASH_IS_MOTHERSHIP flag for mothership detection - Check FLASH_IS_MOTHERSHIP flag (set by provisioner) - Fallback to FLASH_MOTHERSHIP_ID for backwards compatibility - Eliminates timing issues with RUNPOD_ENDPOINT_ID availability Enables deployed mothership endpoints to correctly boot in mothership mode and load user code instead of defaulting to child endpoint mode. --- src/lb_handler.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/lb_handler.py b/src/lb_handler.py index 090adcc..02266b6 100644 --- a/src/lb_handler.py +++ b/src/lb_handler.py @@ -78,7 +78,10 @@ async def extract_api_key_middleware(request: Request, call_next): from remote_executor import RemoteExecutor # noqa: E402 # Determine mode based on environment variables -is_mothership = os.getenv("FLASH_MOTHERSHIP_ID") is not None +# First check FLASH_IS_MOTHERSHIP (explicit flag set by provisioner) +# Then check FLASH_MOTHERSHIP_ID (for backwards compatibility) +is_mothership_flag = os.getenv("FLASH_IS_MOTHERSHIP", "").lower() == "true" +is_mothership = is_mothership_flag or os.getenv("FLASH_MOTHERSHIP_ID") is not None if is_mothership: # Mothership mode: Import user's FastAPI application From a181218e2f5ef0f04f1790d8ad99750d0c543d31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Fri, 13 Feb 2026 00:10:32 -0800 Subject: [PATCH 6/6] chore: bump version to 1.0.1 --- uv.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/uv.lock b/uv.lock index 55204b5..a46dc42 100644 --- a/uv.lock +++ b/uv.lock @@ -3701,7 +3701,7 @@ wheels = [ [[package]] name = "worker-flash" -version = "1.0.0" +version = "1.0.1" source = { virtual = "." } dependencies = [ { name = "aiohttp" },