Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
376 changes: 159 additions & 217 deletions CLAUDE.md

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions docs/Runtime_Execution_Paths.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,11 @@ The handler automatically detects the deployment mode using environment variable
|-------------|---------------|--------------|---------------|
| Local dev | ❌ Not set | ❌ Not set | Live Serverless only |
| Live Serverless | ✅ Set | ❌ Not set | Live Serverless |
| Flash Mothership | ✅ Set | ✅ FLASH_IS_MOTHERSHIP=true | Flash Deployed |
| Flash Mothership | ✅ Set | ✅ FLASH_MOTHERSHIP_ID | Flash Deployed |
| Flash Child | ✅ Set | ✅ FLASH_RESOURCE_NAME | Flash Deployed |

Flash-specific environment variables:
- `FLASH_IS_MOTHERSHIP=true` - Set for mothership endpoints
- `FLASH_MOTHERSHIP_ID` - Set for mothership endpoints (contains the mothership's RUNPOD_ENDPOINT_ID)
- `FLASH_RESOURCE_NAME` - Specifies resource config name

## Request Format Differences
Expand Down
43 changes: 43 additions & 0 deletions src/api_key_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Thread-local context for API key propagation across remote calls."""

import contextvars
from typing import Optional

# Context variable for API key extracted from incoming requests
_api_key_context: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
"api_key_context", default=None
)


def set_api_key(api_key: Optional[str]) -> contextvars.Token[Optional[str]]:
"""Set API key in current context.

Args:
api_key: RunPod API key to use for remote calls

Returns:
Token that can be used to reset the context
"""
return _api_key_context.set(api_key)


def get_api_key() -> Optional[str]:
"""Get API key from current context.

Returns:
API key if set, None otherwise
"""
return _api_key_context.get()


def clear_api_key(token: Optional[contextvars.Token[Optional[str]]] = None) -> None:
"""Clear API key from current context.

Args:
token: Optional token from set_api_key() to reset to previous value.
If None, sets context to None (backwards compatible).
"""
if token is not None:
_api_key_context.reset(token)
else:
_api_key_context.set(None)
4 changes: 2 additions & 2 deletions src/cache_sync_manager.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import os
import logging
import asyncio
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Optional
from rp_logger_adapter import get_flash_logger
from constants import NAMESPACE, CACHE_DIR, VOLUME_CACHE_PATH
from subprocess_utils import run_logged_subprocess

Expand All @@ -13,7 +13,7 @@ class CacheSyncManager:
"""Manages async fire-and-forget cache synchronization to network volume."""

def __init__(self):
self.logger = logging.getLogger(f"{NAMESPACE}.{__name__.split('.')[-1]}")
self.logger = get_flash_logger(f"{NAMESPACE}.{__name__.split('.')[-1]}")
self._should_sync_cached: Optional[bool] = None
self._endpoint_id = os.environ.get("RUNPOD_ENDPOINT_ID")
self._baseline_time: Optional[float] = None
Expand Down
4 changes: 2 additions & 2 deletions src/dependency_installer.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import os
import logging
import asyncio
import platform
from typing import List

from runpod_flash.protos.remote_execution import FunctionResponse
from rp_logger_adapter import get_flash_logger
from constants import LARGE_SYSTEM_PACKAGES, NAMESPACE
from subprocess_utils import run_logged_subprocess

Expand All @@ -13,7 +13,7 @@ class DependencyInstaller:
"""Handles installation of system and Python dependencies."""

def __init__(self):
self.logger = logging.getLogger(f"{NAMESPACE}.{__name__.split('.')[-1]}")
self.logger = get_flash_logger(f"{NAMESPACE}.{__name__.split('.')[-1]}")
self._nala_available = None # Cache nala availability check
self._is_docker = None # Cache Docker environment detection

Expand Down
4 changes: 2 additions & 2 deletions src/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

from runpod_flash.protos.remote_execution import FunctionRequest, FunctionResponse
from remote_executor import RemoteExecutor
from logger import setup_logging
from rp_logger_adapter import setup_flash_logging
from unpack_volume import maybe_unpack

# Initialize logging configuration
setup_logging()
setup_flash_logging()

# Unpack Flash deployment artifacts if running in Flash mode
# This is a no-op for Live Serverless and local development
Expand Down
59 changes: 50 additions & 9 deletions src/lb_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,42 +8,80 @@

The handler uses worker-flash's RemoteExecutor for function execution.

Mothership Mode (FLASH_IS_MOTHERSHIP=true):
Mothership Mode (FLASH_MOTHERSHIP_ID set):
- FLASH_MOTHERSHIP_ID contains the mothership's RUNPOD_ENDPOINT_ID
- Imports user's FastAPI application from FLASH_MAIN_FILE
- Loads the app object from FLASH_APP_VARIABLE
- Preserves all user routes and middleware
- Adds /ping health check endpoint

Queue-Based Mode (FLASH_IS_MOTHERSHIP not set or false):
Child Endpoint Mode (FLASH_MOTHERSHIP_ID not set):
- Creates generic FastAPI app with /execute endpoint
- Uses RemoteExecutor for function execution
"""

import importlib.util
import logging
import os
from typing import Any, Dict

from fastapi import FastAPI
from fastapi import FastAPI, Request

from logger import setup_logging
from api_key_context import clear_api_key, set_api_key
from rp_logger_adapter import setup_flash_logging, get_flash_logger
from unpack_volume import maybe_unpack

# Initialize logging configuration
setup_logging()
logger = logging.getLogger(__name__)
setup_flash_logging()
logger = get_flash_logger(__name__)

# Unpack Flash deployment artifacts if running in Flash mode
# This is a no-op for Live Serverless and local development
maybe_unpack()


async def extract_api_key_middleware(request: Request, call_next):
"""Extract API key from Authorization header and set in context.

This middleware extracts the Bearer token from the Authorization header
and makes it available to downstream code via context variables. This
enables worker endpoints to propagate API keys to remote calls.

Args:
request: Incoming FastAPI request
call_next: Next middleware in chain

Returns:
Response from downstream handlers
"""
# Extract API key from Authorization header
auth_header = request.headers.get("Authorization", "")
api_key = None
token = None

if auth_header.startswith("Bearer "):
api_key = auth_header[7:].strip() # Remove "Bearer " prefix and trim whitespace
token = set_api_key(api_key)
logger.debug("Extracted API key from Authorization header")

try:
response = await call_next(request)
return response
finally:
# Clean up context after request
if token is not None:
clear_api_key(token)


# Import from bundled /app/runpod_flash (no system package)
# These imports must happen AFTER maybe_unpack() so /app is in sys.path
from runpod_flash.protos.remote_execution import FunctionRequest, FunctionResponse # noqa: E402
from remote_executor import RemoteExecutor # noqa: E402

# Determine mode based on environment variables
is_mothership = os.getenv("FLASH_IS_MOTHERSHIP") == "true"
# First check FLASH_IS_MOTHERSHIP (explicit flag set by provisioner)
# Then check FLASH_MOTHERSHIP_ID (for backwards compatibility)
is_mothership_flag = os.getenv("FLASH_IS_MOTHERSHIP", "").lower() == "true"
is_mothership = is_mothership_flag or os.getenv("FLASH_MOTHERSHIP_ID") is not None

if is_mothership:
# Mothership mode: Import user's FastAPI application
Expand Down Expand Up @@ -98,8 +136,11 @@ async def ping_mothership() -> Dict[str, Any]:
else:
# Queue-based mode: Create generic Load Balancer handler app
app = FastAPI(title="Load Balancer Handler")
logger.info("Queue-based mode: Using generic Load Balancer handler")
logger.info("Child endpoint mode: Using generic Load Balancer handler")


# Register API key extraction middleware for both mothership and queue-based modes
app.middleware("http")(extract_api_key_middleware)

# Queue-based mode endpoints
if not is_mothership:
Expand Down
Loading
Loading