Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 94 additions & 2 deletions src/runpod_flash/cli/commands/build_utils/handler_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,78 @@ def handler(job):
return {{"error": str(e), "traceback": traceback.format_exc()}}


if __name__ == "__main__":
import runpod
runpod.serverless.start({{"handler": handler}})
'''

DEPLOYED_CLASS_HANDLER_TEMPLATE = '''"""
Auto-generated deployed handler for class-based resource: {resource_name}
Generated at: {timestamp}

Deployed class handler: creates a module-level instance, dispatches to
methods by name. One class per endpoint, identified by FLASH_RESOURCE_NAME.

This file is generated by the Flash build process. Do not edit manually.
"""

import asyncio
import importlib
import inspect
import logging
import traceback

_logger = logging.getLogger(__name__)

# Import the class for this endpoint
{import_statement}

# Module-level instance (created once per worker, reused across jobs)
_instance = {class_name}()


def handler(job):
"""Handler for deployed class-based QB endpoint.

Expects job input with optional 'method_name' key for method dispatch.
All other keys are passed as kwargs to the method.
"""
job_input = job.get("input", {{}})
method_name = job_input.get("method_name", "{default_method}")
kwargs = {{k: v for k, v in job_input.items() if k != "method_name"}}
try:
method = getattr(_instance, method_name)
result = method(**kwargs)
if inspect.iscoroutine(result):
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop and loop.is_running():
import concurrent.futures

with concurrent.futures.ThreadPoolExecutor() as pool:
result = pool.submit(asyncio.run, result).result()
else:
result = asyncio.run(result)
return result
except AttributeError:
_logger.error(
"Method '%s' not found on %s",
method_name,
type(_instance).__name__,
)
return {{"error": f"Method '{{method_name}}' not found on {class_name}"}}
except Exception as e:
_logger.error(
"Deployed class handler error for {class_name}.%s: %s",
method_name,
e,
exc_info=True,
)
return {{"error": str(e), "traceback": traceback.format_exc()}}


if __name__ == "__main__":
import runpod
runpod.serverless.start({{"handler": handler}})
Expand Down Expand Up @@ -186,7 +258,9 @@ def _generate_deployed_handler_code(
timestamp: str,
functions: List[Any],
) -> str:
"""Generate deployed handler code for a single-function endpoint.
"""Generate deployed handler code for a single-function or class endpoint.

Selects between function and class templates based on is_class metadata.

Args:
resource_name: Name of the resource config.
Expand All @@ -202,17 +276,35 @@ def _generate_deployed_handler_code(
f"Cannot generate a deployed handler without at least one function."
)

# Use the first function (one function per deployed QB endpoint)
# Use the first function (one function/class per deployed QB endpoint)
func = functions[0]
module = func.module if hasattr(func, "module") else func.get("module")
name = func.name if hasattr(func, "name") else func.get("name")
is_class = (
func.is_class if hasattr(func, "is_class") else func.get("is_class", False)
)

import_statement = (
f"{name} = importlib.import_module('{module}').{name}"
if module and name
else "# No function to import"
)

if is_class:
class_methods = (
func.class_methods
if hasattr(func, "class_methods")
else func.get("class_methods", [])
)
default_method = class_methods[0] if len(class_methods) == 1 else "__call__"
return DEPLOYED_CLASS_HANDLER_TEMPLATE.format(
resource_name=resource_name,
timestamp=timestamp,
import_statement=import_statement,
class_name=name or "None",
default_method=default_method,
)

return DEPLOYED_HANDLER_TEMPLATE.format(
resource_name=resource_name,
timestamp=timestamp,
Expand Down
5 changes: 5 additions & 0 deletions src/runpod_flash/cli/commands/build_utils/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,11 @@ def build(self) -> Dict[str, Any]:
if is_load_balanced
else {}
),
**(
{"class_methods": f.class_methods}
if f.is_class and f.class_methods
else {}
),
}
for f in functions
]
Expand Down
181 changes: 181 additions & 0 deletions tests/unit/cli/commands/build_utils/test_handler_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,3 +528,184 @@ def test_validate_handler_rejects_syntax_errors():

with pytest.raises(ValueError, match="Handler has syntax errors"):
generator._validate_handler_imports(handler_path)


# --- Tests for deployed class-based handler (is_class=True) ---


def _make_class_manifest(
*, is_live_resource=False, class_methods=None, extra_resource_fields=None
):
"""Helper to create a manifest with a class-based @remote entry."""
func_entry = {
"name": "MyModel",
"module": "workers.model",
"is_async": False,
"is_class": True,
}
if class_methods is not None:
func_entry["class_methods"] = class_methods
resource = {
"resource_type": "Serverless",
"is_live_resource": is_live_resource,
"functions": [func_entry],
}
if extra_resource_fields:
resource.update(extra_resource_fields)
return {
"version": "1.0",
"generated_at": "2026-01-02T10:00:00Z",
"project_name": "test_app",
"resources": {"my_model_config": resource},
}


def test_deployed_class_handler_creates_module_level_instance():
"""Deployed class handler instantiates the class at module level."""
with tempfile.TemporaryDirectory() as tmpdir:
build_dir = Path(tmpdir)
manifest = _make_class_manifest()

generator = HandlerGenerator(manifest, build_dir)
handler_paths = generator.generate_handlers()
content = handler_paths[0].read_text()

assert "_instance = MyModel()" in content


def test_deployed_class_handler_dispatches_to_method():
"""Deployed class handler reads method_name from job input and dispatches."""
with tempfile.TemporaryDirectory() as tmpdir:
build_dir = Path(tmpdir)
manifest = _make_class_manifest()

generator = HandlerGenerator(manifest, build_dir)
handler_paths = generator.generate_handlers()
content = handler_paths[0].read_text()

# Handler should get method_name from input and call it on the instance
assert "method_name" in content
assert "getattr(_instance, method_name)" in content


def test_deployed_class_handler_does_not_call_class_directly():
"""Deployed class handler must NOT do MyModel(**job_input) — that's the bug."""
with tempfile.TemporaryDirectory() as tmpdir:
build_dir = Path(tmpdir)
manifest = _make_class_manifest()

generator = HandlerGenerator(manifest, build_dir)
handler_paths = generator.generate_handlers()
content = handler_paths[0].read_text()

# The broken pattern: calling the class as a function with job input
assert "MyModel(**job_input)" not in content


def test_deployed_class_handler_excludes_method_name_from_kwargs():
"""method_name must be stripped from kwargs before passing to the method."""
with tempfile.TemporaryDirectory() as tmpdir:
build_dir = Path(tmpdir)
manifest = _make_class_manifest()

generator = HandlerGenerator(manifest, build_dir)
handler_paths = generator.generate_handlers()
content = handler_paths[0].read_text()

# Should filter out method_name from the kwargs passed to the method
assert 'k != "method_name"' in content


def test_deployed_class_handler_handles_async_methods():
"""Deployed class handler handles coroutines from async methods."""
with tempfile.TemporaryDirectory() as tmpdir:
build_dir = Path(tmpdir)
manifest = _make_class_manifest()

generator = HandlerGenerator(manifest, build_dir)
handler_paths = generator.generate_handlers()
content = handler_paths[0].read_text()

assert "inspect.iscoroutine(result)" in content
assert "asyncio.run(result)" in content


def test_deployed_class_handler_has_valid_syntax():
"""Generated class handler must be valid Python (passes ast.parse)."""
with tempfile.TemporaryDirectory() as tmpdir:
build_dir = Path(tmpdir)
manifest = _make_class_manifest()

generator = HandlerGenerator(manifest, build_dir)
handler_paths = generator.generate_handlers()

# Should not raise — _validate_handler_imports uses ast.parse
assert handler_paths[0].exists()
import ast

ast.parse(handler_paths[0].read_text())


def test_deployed_class_handler_has_runpod_start():
"""Deployed class handler includes runpod.serverless.start."""
with tempfile.TemporaryDirectory() as tmpdir:
build_dir = Path(tmpdir)
manifest = _make_class_manifest()

generator = HandlerGenerator(manifest, build_dir)
handler_paths = generator.generate_handlers()
content = handler_paths[0].read_text()

assert 'runpod.serverless.start({"handler": handler})' in content


def test_deployed_class_handler_imports_class():
"""Deployed class handler imports the class via importlib."""
with tempfile.TemporaryDirectory() as tmpdir:
build_dir = Path(tmpdir)
manifest = _make_class_manifest()

generator = HandlerGenerator(manifest, build_dir)
handler_paths = generator.generate_handlers()
content = handler_paths[0].read_text()

assert "MyModel = importlib.import_module('workers.model').MyModel" in content


def test_deployed_class_handler_single_method_defaults_to_that_method():
"""Single public method class defaults method_name to that method (no method_name needed)."""
with tempfile.TemporaryDirectory() as tmpdir:
build_dir = Path(tmpdir)
manifest = _make_class_manifest(class_methods=["predict"])

generator = HandlerGenerator(manifest, build_dir)
handler_paths = generator.generate_handlers()
content = handler_paths[0].read_text()

assert 'method_name = job_input.get("method_name", "predict")' in content


def test_deployed_class_handler_multi_method_defaults_to_call():
"""Multi-method class defaults method_name to __call__ (requires explicit method_name)."""
with tempfile.TemporaryDirectory() as tmpdir:
build_dir = Path(tmpdir)
manifest = _make_class_manifest(class_methods=["predict", "embed"])

generator = HandlerGenerator(manifest, build_dir)
handler_paths = generator.generate_handlers()
content = handler_paths[0].read_text()

assert 'method_name = job_input.get("method_name", "__call__")' in content


def test_deployed_class_handler_no_class_methods_defaults_to_call():
"""Class with no class_methods metadata defaults to __call__."""
with tempfile.TemporaryDirectory() as tmpdir:
build_dir = Path(tmpdir)
manifest = _make_class_manifest() # No class_methods

generator = HandlerGenerator(manifest, build_dir)
handler_paths = generator.generate_handlers()
content = handler_paths[0].read_text()

assert 'method_name = job_input.get("method_name", "__call__")' in content
69 changes: 69 additions & 0 deletions tests/unit/cli/commands/build_utils/test_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,3 +619,72 @@ def test_extract_deployment_config_network_volume_minimal():
# Default size and dataCenterId should still be present
assert config["networkVolume"]["size"] == 100
assert config["networkVolume"]["dataCenterId"] == "EU-RO-1"


# --- Tests for class_methods in manifest ---


def test_manifest_includes_class_methods_for_class_entries():
"""Manifest function entries include class_methods when is_class=True and methods exist."""
functions = [
RemoteFunctionMetadata(
function_name="MyModel",
module_path="workers.model",
resource_config_name="model_config",
resource_type="LiveServerless",
is_async=False,
is_class=True,
file_path=Path("workers/model.py"),
class_methods=["predict", "embed"],
)
]

builder = ManifestBuilder("test_app", functions)
manifest = builder.build()

func_entry = manifest["resources"]["model_config"]["functions"][0]
assert func_entry["is_class"] is True
assert func_entry["class_methods"] == ["predict", "embed"]


def test_manifest_excludes_class_methods_for_non_class_entries():
"""Manifest function entries do not include class_methods for regular functions."""
functions = [
RemoteFunctionMetadata(
function_name="gpu_task",
module_path="workers.gpu",
resource_config_name="gpu_config",
resource_type="LiveServerless",
is_async=True,
is_class=False,
file_path=Path("workers/gpu.py"),
)
]

builder = ManifestBuilder("test_app", functions)
manifest = builder.build()

func_entry = manifest["resources"]["gpu_config"]["functions"][0]
assert "class_methods" not in func_entry


def test_manifest_excludes_class_methods_when_empty():
"""Manifest function entries do not include class_methods when list is empty."""
functions = [
RemoteFunctionMetadata(
function_name="MyModel",
module_path="workers.model",
resource_config_name="model_config",
resource_type="LiveServerless",
is_async=False,
is_class=True,
file_path=Path("workers/model.py"),
class_methods=[],
)
]

builder = ManifestBuilder("test_app", functions)
manifest = builder.build()

func_entry = manifest["resources"]["model_config"]["functions"][0]
assert "class_methods" not in func_entry
Loading