From fa63931258cc5f9da34e76397c80642fd391e7a3 Mon Sep 17 00:00:00 2001 From: Zuyi Zhao Date: Mon, 10 Nov 2025 17:53:58 +0000 Subject: [PATCH 01/25] feat(sagemaker/sessions): add config to enable stateful sessions using env variable --- .../sagemaker/config.py | 13 +++++++++++++ .../sagemaker/sessions/manager.py | 16 ++++++++++++++-- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/python/model_hosting_container_standards/sagemaker/config.py b/python/model_hosting_container_standards/sagemaker/config.py index 58f8cb5..b823da4 100644 --- a/python/model_hosting_container_standards/sagemaker/config.py +++ b/python/model_hosting_container_standards/sagemaker/config.py @@ -1,5 +1,18 @@ """SageMaker-specific configuration constants.""" +import os + +SAGEMAKER_ENV_VAR_PREFIX = "OPTION_" + + +def get_configs_from_env_vars(): + sagemaker_args = { + key[len(SAGEMAKER_ENV_VAR_PREFIX) :].lower(): val + for key, val in os.environ.items() + if key.startswith(SAGEMAKER_ENV_VAR_PREFIX) + } + return sagemaker_args + class SageMakerEnvVars: """SageMaker environment variable names.""" diff --git a/python/model_hosting_container_standards/sagemaker/sessions/manager.py b/python/model_hosting_container_standards/sagemaker/sessions/manager.py index 59897e5..acb82d9 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/manager.py +++ b/python/model_hosting_container_standards/sagemaker/sessions/manager.py @@ -9,7 +9,9 @@ import time import uuid from threading import RLock -from typing import Optional +from typing import Dict, Optional + +from ..config import get_configs_from_env_vars class Session: @@ -248,5 +250,15 @@ def _clean_expired_session(self): self.close_session(session_id) +def _init_session_manager(sessions_configs: Dict[str, str]) -> SessionManager | None: + enable_stateful_sessions = sessions_configs.get( + "enable_stateful_sessions", "false" + ).lower() + if enable_stateful_sessions == "true": + return SessionManager(sessions_configs) + return None + + +_sessions_configs = get_configs_from_env_vars() # Global SessionManager instance -session_manager = SessionManager({}) +session_manager = _init_session_manager(_sessions_configs) From 0a99c84afb050c292465ef62c0c6f0514fbf242c Mon Sep 17 00:00:00 2001 From: Zuyi Zhao Date: Mon, 10 Nov 2025 18:15:03 +0000 Subject: [PATCH 02/25] feat(sagemaker/sessions): add utility functions for getting session_manager --- .../sagemaker/sessions/manager.py | 34 ++++++++++++++++++- .../sagemaker/sessions/transform.py | 4 +-- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/python/model_hosting_container_standards/sagemaker/sessions/manager.py b/python/model_hosting_container_standards/sagemaker/sessions/manager.py index acb82d9..3271c8b 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/manager.py +++ b/python/model_hosting_container_standards/sagemaker/sessions/manager.py @@ -251,6 +251,14 @@ def _clean_expired_session(self): def _init_session_manager(sessions_configs: Dict[str, str]) -> SessionManager | None: + """Initialize a SessionManager if stateful sessions are enabled. + + Args: + sessions_configs: Configuration dictionary with session settings + + Returns: + SessionManager instance if enabled, None otherwise + """ enable_stateful_sessions = sessions_configs.get( "enable_stateful_sessions", "false" ).lower() @@ -259,6 +267,30 @@ def _init_session_manager(sessions_configs: Dict[str, str]) -> SessionManager | return None +def get_session_manager() -> SessionManager | None: + """Get the global session manager instance. + + Returns: + The global SessionManager instance, or None if not initialized + """ + return session_manager + + +def init_session_manager_from_env() -> SessionManager | None: + """Initialize the global session manager from environment variables. + + This can be called to reinitialize the session manager after environment + variables have been set. + + Returns: + The initialized SessionManager instance, or None if disabled + """ + global session_manager + sessions_configs = get_configs_from_env_vars() + session_manager = _init_session_manager(sessions_configs) + return session_manager + + +# Global SessionManager instance - initialized from environment variables _sessions_configs = get_configs_from_env_vars() -# Global SessionManager instance session_manager = _init_session_manager(_sessions_configs) diff --git a/python/model_hosting_container_standards/sagemaker/sessions/transform.py b/python/model_hosting_container_standards/sagemaker/sessions/transform.py index 8290b78..12c493d 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/transform.py +++ b/python/model_hosting_container_standards/sagemaker/sessions/transform.py @@ -9,7 +9,7 @@ from ...common import BaseApiTransform, BaseTransformRequestOutput from .handlers import get_handler_for_request_type -from .manager import SessionManager, session_manager +from .manager import get_session_manager, SessionManager from .models import SessionRequest from .utils import get_session, get_session_id_from_request @@ -124,7 +124,7 @@ def __init__(self, request_shape, response_shape={}): The request/response shapes are passed to the parent class but not used for validation in this transform, as session requests use their own validation. """ - self._session_manager = session_manager + self._session_manager = get_session_manager() super().__init__(request_shape, response_shape) async def transform_request(self, raw_request): From 2e57cd5d4866dd9c34211ddb18b14c712c071ef2 Mon Sep 17 00:00:00 2001 From: Zuyi Zhao Date: Mon, 10 Nov 2025 18:15:25 +0000 Subject: [PATCH 03/25] chore(sagemaker/sessions): update tests --- .../sagemaker/sessions/transform.py | 2 +- .../test_sagemaker_sessions_integration.py | 29 +++++++++++++++++++ python/tests/sagemaker/sessions/conftest.py | 24 +++++++++++++++ .../sagemaker/sessions/test_transform.py | 8 +++-- 4 files changed, 59 insertions(+), 4 deletions(-) diff --git a/python/model_hosting_container_standards/sagemaker/sessions/transform.py b/python/model_hosting_container_standards/sagemaker/sessions/transform.py index 12c493d..3c90175 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/transform.py +++ b/python/model_hosting_container_standards/sagemaker/sessions/transform.py @@ -9,7 +9,7 @@ from ...common import BaseApiTransform, BaseTransformRequestOutput from .handlers import get_handler_for_request_type -from .manager import get_session_manager, SessionManager +from .manager import SessionManager, get_session_manager from .models import SessionRequest from .utils import get_session, get_session_id_from_request diff --git a/python/tests/integration/test_sagemaker_sessions_integration.py b/python/tests/integration/test_sagemaker_sessions_integration.py index 7f127fe..416f362 100644 --- a/python/tests/integration/test_sagemaker_sessions_integration.py +++ b/python/tests/integration/test_sagemaker_sessions_integration.py @@ -18,6 +18,9 @@ """ import json +import os +import shutil +import tempfile from typing import Optional import pytest @@ -26,11 +29,37 @@ from fastapi.testclient import TestClient import model_hosting_container_standards.sagemaker as sagemaker_standards +from model_hosting_container_standards.sagemaker.sessions.manager import ( + init_session_manager_from_env, +) from model_hosting_container_standards.sagemaker.sessions.models import ( SageMakerSessionHeader, ) +@pytest.fixture(autouse=True) +def enable_sessions_for_integration(monkeypatch): + """Automatically enable sessions for all integration tests in this module.""" + temp_dir = tempfile.mkdtemp() + + monkeypatch.setenv("OPTION_ENABLE_STATEFUL_SESSIONS", "true") + monkeypatch.setenv("OPTION_SESSIONS_PATH", temp_dir) + monkeypatch.setenv("OPTION_SESSIONS_EXPIRATION", "600") + + # Reinitialize the global session manager + init_session_manager_from_env() + + yield + + # Clean up + if os.path.exists(temp_dir): + shutil.rmtree(temp_dir) + monkeypatch.delenv("OPTION_ENABLE_STATEFUL_SESSIONS", raising=False) + monkeypatch.delenv("OPTION_SESSIONS_PATH", raising=False) + monkeypatch.delenv("OPTION_SESSIONS_EXPIRATION", raising=False) + init_session_manager_from_env() + + def extract_session_id_from_header(header_value: str) -> str: """Extract session ID from SageMaker session header. diff --git a/python/tests/sagemaker/sessions/conftest.py b/python/tests/sagemaker/sessions/conftest.py index c340c90..a9f7670 100644 --- a/python/tests/sagemaker/sessions/conftest.py +++ b/python/tests/sagemaker/sessions/conftest.py @@ -12,6 +12,7 @@ from model_hosting_container_standards.sagemaker.sessions.manager import ( Session, SessionManager, + init_session_manager_from_env, ) from model_hosting_container_standards.sagemaker.sessions.models import ( SageMakerSessionHeader, @@ -63,3 +64,26 @@ def session_manager(temp_session_storage): """Create a real session manager with temporary storage for integration tests.""" properties = {"sessions_path": temp_session_storage, "sessions_expiration": "600"} return SessionManager(properties) + + +@pytest.fixture +def enable_sessions_env(monkeypatch, temp_session_storage): + """Set environment variables to enable stateful sessions for tests. + + This fixture sets the necessary environment variables and reinitializes + the global session manager from those variables. + """ + monkeypatch.setenv("OPTION_ENABLE_STATEFUL_SESSIONS", "true") + monkeypatch.setenv("OPTION_SESSIONS_PATH", temp_session_storage) + monkeypatch.setenv("OPTION_SESSIONS_EXPIRATION", "600") + + # Reinitialize the global session manager with the new environment variables + init_session_manager_from_env() + + yield + + # Clean up - reinitialize with empty env (will set session_manager to None) + monkeypatch.delenv("OPTION_ENABLE_STATEFUL_SESSIONS", raising=False) + monkeypatch.delenv("OPTION_SESSIONS_PATH", raising=False) + monkeypatch.delenv("OPTION_SESSIONS_EXPIRATION", raising=False) + init_session_manager_from_env() diff --git a/python/tests/sagemaker/sessions/test_transform.py b/python/tests/sagemaker/sessions/test_transform.py index fc9da59..56962b5 100644 --- a/python/tests/sagemaker/sessions/test_transform.py +++ b/python/tests/sagemaker/sessions/test_transform.py @@ -250,18 +250,20 @@ class TestSessionApiTransform: """Test SessionApiTransform class.""" @pytest.fixture - def transform(self): + def transform(self, enable_sessions_env): """Create SessionApiTransform instance.""" return SessionApiTransform(request_shape={}, response_shape={}) - def test_initialization_creates_session_manager(self): + def test_initialization_creates_session_manager(self, enable_sessions_env): """Test initialization creates internal session manager.""" transform = SessionApiTransform(request_shape={}, response_shape={}) assert hasattr(transform, "_session_manager") assert isinstance(transform._session_manager, SessionManager) - def test_initialization_accepts_request_and_response_shapes(self): + def test_initialization_accepts_request_and_response_shapes( + self, enable_sessions_env + ): """Test initialization accepts request and response shapes.""" request_shape = {"field": "value"} response_shape = {"output": "format"} From 85e1687fbdc9b3a048cc8ddea2d101c4a05ec07d Mon Sep 17 00:00:00 2001 From: Zuyi Zhao Date: Mon, 10 Nov 2025 21:22:43 +0000 Subject: [PATCH 04/25] feat: update env config to use pydantic model SageMakerConfig and use SAGEMAKER_ as prefix. --- .../sagemaker/config.py | 97 +++++++++++++++++-- .../sagemaker/sessions/manager.py | 28 +++--- .../test_sagemaker_sessions_integration.py | 12 +-- python/tests/sagemaker/sessions/conftest.py | 12 +-- 4 files changed, 116 insertions(+), 33 deletions(-) diff --git a/python/model_hosting_container_standards/sagemaker/config.py b/python/model_hosting_container_standards/sagemaker/config.py index b823da4..4c85540 100644 --- a/python/model_hosting_container_standards/sagemaker/config.py +++ b/python/model_hosting_container_standards/sagemaker/config.py @@ -1,17 +1,98 @@ """SageMaker-specific configuration constants.""" import os +from typing import Any, Dict, Optional -SAGEMAKER_ENV_VAR_PREFIX = "OPTION_" +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +SAGEMAKER_ENV_VAR_PREFIX = "SAGEMAKER_" -def get_configs_from_env_vars(): - sagemaker_args = { - key[len(SAGEMAKER_ENV_VAR_PREFIX) :].lower(): val - for key, val in os.environ.items() - if key.startswith(SAGEMAKER_ENV_VAR_PREFIX) - } - return sagemaker_args + +class SageMakerConfig(BaseModel): + """Pydantic model for SageMaker configuration. + + Automatically loads configuration from environment variables with SAGEMAKER_ prefix. + Example: SAGEMAKER_ENABLE_STATEFUL_SESSIONS=true -> enable_stateful_sessions=True + + Only fields defined in this model are loaded. Other SAGEMAKER_* env vars + (like SAGEMAKER_MODEL_PATH) are ignored. + + Usage: + # Create from environment variables + config = SagemakerConfig.from_env() + + # Or just instantiate (automatically loads from env) + config = SagemakerConfig() + + # Override specific values + config = SagemakerConfig(enable_stateful_sessions=True) + """ + + model_config = ConfigDict(extra="ignore") + + # Stateful sessions configuration + enable_stateful_sessions: bool = Field( + default=False, description="Enable stateful sessions for the application" + ) + sessions_expiration: int = Field( + default=1200, # 20 minutes + description="Session expiration time in seconds", + gt=0, + ) + sessions_path: Optional[str] = Field( + default=None, + description="Custom path for session storage (defaults to /dev/shm or temp)", + ) + + @classmethod + def from_env(cls) -> "SageMakerConfig": + """Create SagemakerConfig from environment variables. + + Returns: + SagemakerConfig instance with values loaded from SAGEMAKER_* env vars + """ + return cls() + + @model_validator(mode="before") + @classmethod + def load_from_env_vars(cls, data: Any) -> Dict[str, Any]: + """Load configuration from environment variables. + + Extracts SAGEMAKER_* environment variables and merges with any provided data. + Provided data takes precedence over environment variables. + Unknown SAGEMAKER_* variables are ignored (only defined fields are loaded). + """ + # Extract env vars with SAGEMAKER_ prefix + env_config = { + key[len(SAGEMAKER_ENV_VAR_PREFIX) :].lower(): val + for key, val in os.environ.items() + if key.startswith(SAGEMAKER_ENV_VAR_PREFIX) + } + + # If data is provided, merge with env config (data takes precedence) + if isinstance(data, dict): + return {**env_config, **data} + return env_config + + @field_validator("enable_stateful_sessions", mode="before") + @classmethod + def parse_bool_string(cls, v: Any) -> bool: + """Convert string values from env vars to boolean.""" + if isinstance(v, bool): + return v + if isinstance(v, str): + return v.lower() in ("true", "1") + return bool(v) + + @field_validator("sessions_expiration", mode="before") + @classmethod + def parse_int_string(cls, v: Any) -> int: + """Convert string values from env vars to integer.""" + if isinstance(v, int): + return v + if isinstance(v, str): + return int(v) + return int(v) class SageMakerEnvVars: diff --git a/python/model_hosting_container_standards/sagemaker/sessions/manager.py b/python/model_hosting_container_standards/sagemaker/sessions/manager.py index 3271c8b..d1b96e0 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/manager.py +++ b/python/model_hosting_container_standards/sagemaker/sessions/manager.py @@ -9,9 +9,9 @@ import time import uuid from threading import RLock -from typing import Dict, Optional +from typing import Optional -from ..config import get_configs_from_env_vars +from ..config import SageMakerConfig class Session: @@ -250,20 +250,22 @@ def _clean_expired_session(self): self.close_session(session_id) -def _init_session_manager(sessions_configs: Dict[str, str]) -> SessionManager | None: +def _init_session_manager(config: SageMakerConfig) -> SessionManager | None: """Initialize a SessionManager if stateful sessions are enabled. Args: - sessions_configs: Configuration dictionary with session settings + config: SagemakerConfig instance with session settings Returns: SessionManager instance if enabled, None otherwise """ - enable_stateful_sessions = sessions_configs.get( - "enable_stateful_sessions", "false" - ).lower() - if enable_stateful_sessions == "true": - return SessionManager(sessions_configs) + if config.enable_stateful_sessions: + # Convert config to dict for SessionManager + config_dict = { + "sessions_expiration": str(config.sessions_expiration), + "sessions_path": config.sessions_path, + } + return SessionManager(config_dict) return None @@ -286,11 +288,11 @@ def init_session_manager_from_env() -> SessionManager | None: The initialized SessionManager instance, or None if disabled """ global session_manager - sessions_configs = get_configs_from_env_vars() - session_manager = _init_session_manager(sessions_configs) + config = SageMakerConfig.from_env() + session_manager = _init_session_manager(config) return session_manager # Global SessionManager instance - initialized from environment variables -_sessions_configs = get_configs_from_env_vars() -session_manager = _init_session_manager(_sessions_configs) +_config = SageMakerConfig.from_env() +session_manager = _init_session_manager(_config) diff --git a/python/tests/integration/test_sagemaker_sessions_integration.py b/python/tests/integration/test_sagemaker_sessions_integration.py index 416f362..dca9871 100644 --- a/python/tests/integration/test_sagemaker_sessions_integration.py +++ b/python/tests/integration/test_sagemaker_sessions_integration.py @@ -42,9 +42,9 @@ def enable_sessions_for_integration(monkeypatch): """Automatically enable sessions for all integration tests in this module.""" temp_dir = tempfile.mkdtemp() - monkeypatch.setenv("OPTION_ENABLE_STATEFUL_SESSIONS", "true") - monkeypatch.setenv("OPTION_SESSIONS_PATH", temp_dir) - monkeypatch.setenv("OPTION_SESSIONS_EXPIRATION", "600") + monkeypatch.setenv("SAGEMAKER_ENABLE_STATEFUL_SESSIONS", "true") + monkeypatch.setenv("SAGEMAKER_SESSIONS_PATH", temp_dir) + monkeypatch.setenv("SAGEMAKER_SESSIONS_EXPIRATION", "600") # Reinitialize the global session manager init_session_manager_from_env() @@ -54,9 +54,9 @@ def enable_sessions_for_integration(monkeypatch): # Clean up if os.path.exists(temp_dir): shutil.rmtree(temp_dir) - monkeypatch.delenv("OPTION_ENABLE_STATEFUL_SESSIONS", raising=False) - monkeypatch.delenv("OPTION_SESSIONS_PATH", raising=False) - monkeypatch.delenv("OPTION_SESSIONS_EXPIRATION", raising=False) + monkeypatch.delenv("SAGEMAKER_ENABLE_STATEFUL_SESSIONS", raising=False) + monkeypatch.delenv("SAGEMAKER_SESSIONS_PATH", raising=False) + monkeypatch.delenv("SAGEMAKER_SESSIONS_EXPIRATION", raising=False) init_session_manager_from_env() diff --git a/python/tests/sagemaker/sessions/conftest.py b/python/tests/sagemaker/sessions/conftest.py index a9f7670..c1d193c 100644 --- a/python/tests/sagemaker/sessions/conftest.py +++ b/python/tests/sagemaker/sessions/conftest.py @@ -73,9 +73,9 @@ def enable_sessions_env(monkeypatch, temp_session_storage): This fixture sets the necessary environment variables and reinitializes the global session manager from those variables. """ - monkeypatch.setenv("OPTION_ENABLE_STATEFUL_SESSIONS", "true") - monkeypatch.setenv("OPTION_SESSIONS_PATH", temp_session_storage) - monkeypatch.setenv("OPTION_SESSIONS_EXPIRATION", "600") + monkeypatch.setenv("SAGEMAKER_ENABLE_STATEFUL_SESSIONS", "true") + monkeypatch.setenv("SAGEMAKER_SESSIONS_PATH", temp_session_storage) + monkeypatch.setenv("SAGEMAKER_SESSIONS_EXPIRATION", "600") # Reinitialize the global session manager with the new environment variables init_session_manager_from_env() @@ -83,7 +83,7 @@ def enable_sessions_env(monkeypatch, temp_session_storage): yield # Clean up - reinitialize with empty env (will set session_manager to None) - monkeypatch.delenv("OPTION_ENABLE_STATEFUL_SESSIONS", raising=False) - monkeypatch.delenv("OPTION_SESSIONS_PATH", raising=False) - monkeypatch.delenv("OPTION_SESSIONS_EXPIRATION", raising=False) + monkeypatch.delenv("SAGEMAKER_ENABLE_STATEFUL_SESSIONS", raising=False) + monkeypatch.delenv("SAGEMAKER_SESSIONS_PATH", raising=False) + monkeypatch.delenv("SAGEMAKER_SESSIONS_EXPIRATION", raising=False) init_session_manager_from_env() From e040f2d1290c151b618a09b10e8cf81a07bb1d01 Mon Sep 17 00:00:00 2001 From: Zuyi Zhao Date: Tue, 11 Nov 2025 01:23:16 +0000 Subject: [PATCH 05/25] feat(sagemaker/sessions): add validation layer so if session_manager is None but request is a session request, raise 400 error - change session api transform's transform_request to never return request field in output (previously used to pass session_manager - change handler code to only take raw_request as a parameter and use new utility function get_session_manager - add new integration tests for expected errors when sessions is disabled --- .../sagemaker/sessions/handlers.py | 29 ++++-- .../sagemaker/sessions/models.py | 8 ++ .../sagemaker/sessions/transform.py | 18 +++- .../test_sagemaker_sessions_integration.py | 93 +++++++++++++++++++ .../tests/sagemaker/sessions/test_handlers.py | 36 +++++-- .../sagemaker/sessions/test_transform.py | 8 +- python/tests/sagemaker/sessions/test_utils.py | 3 +- 7 files changed, 172 insertions(+), 23 deletions(-) diff --git a/python/model_hosting_container_standards/sagemaker/sessions/handlers.py b/python/model_hosting_container_standards/sagemaker/sessions/handlers.py index c074468..e883bcb 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/handlers.py +++ b/python/model_hosting_container_standards/sagemaker/sessions/handlers.py @@ -5,8 +5,13 @@ from fastapi import Request, Response from fastapi.exceptions import HTTPException -from .manager import SessionManager -from .models import SageMakerSessionHeader, SessionRequestType +from .manager import get_session_manager +from .models import ( + SESSION_DISABLED_ERROR_DETAIL, + SESSION_DISABLED_LOG_MESSAGE, + SageMakerSessionHeader, + SessionRequestType, +) from .utils import get_session_id_from_request logger = logging.getLogger(__name__) @@ -29,11 +34,10 @@ def get_handler_for_request_type(request_type: SessionRequestType): return None -async def close_session(session_manager: SessionManager, raw_request: Request): +async def close_session(raw_request: Request): """Close an existing session and clean up its resources. Args: - session_manager: SessionManager instance to manage the session lifecycle raw_request: FastAPI Request object containing session ID in headers Returns: @@ -43,6 +47,13 @@ async def close_session(session_manager: SessionManager, raw_request: Request): HTTPException: If session closure fails with 424 FAILED_DEPENDENCY status """ session_id = get_session_id_from_request(raw_request) + session_manager = get_session_manager() + if session_manager is None: + logger.error(SESSION_DISABLED_LOG_MESSAGE) + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail=SESSION_DISABLED_ERROR_DETAIL, + ) try: session_manager.close_session(session_id) logger.info(f"Session {session_id} closed") @@ -59,11 +70,10 @@ async def close_session(session_manager: SessionManager, raw_request: Request): ) -async def create_session(session_manager: SessionManager, raw_request: Request): +async def create_session(raw_request: Request): """Create a new stateful session with expiration tracking. Args: - session_manager: SessionManager instance to manage the session lifecycle raw_request: FastAPI Request object (unused but part of handler signature) Returns: @@ -72,6 +82,13 @@ async def create_session(session_manager: SessionManager, raw_request: Request): Raises: HTTPException: If session creation fails with 424 FAILED_DEPENDENCY status """ + session_manager = get_session_manager() + if session_manager is None: + logger.error(SESSION_DISABLED_LOG_MESSAGE) + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail=SESSION_DISABLED_ERROR_DETAIL, + ) try: session = session_manager.create_session() # expiration_ts is guaranteed to be set for newly created sessions diff --git a/python/model_hosting_container_standards/sagemaker/sessions/models.py b/python/model_hosting_container_standards/sagemaker/sessions/models.py index cbbadd7..a87df8d 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/models.py +++ b/python/model_hosting_container_standards/sagemaker/sessions/models.py @@ -37,3 +37,11 @@ class SageMakerSessionHeader: SESSION_ID = "X-Amzn-SageMaker-Session-Id" NEW_SESSION_ID = "X-Amzn-SageMaker-New-Session-Id" CLOSED_SESSION_ID = "X-Amzn-SageMaker-Closed-Session-Id" + + +# Error messages for session management +SESSION_DISABLED_ERROR_DETAIL = "Invalid payload. stateful sessions not enabled" +SESSION_DISABLED_LOG_MESSAGE = ( + f"Invalid payload. stateful sessions not enabled, " + f"{SageMakerSessionHeader.SESSION_ID} header not supported" +) diff --git a/python/model_hosting_container_standards/sagemaker/sessions/transform.py b/python/model_hosting_container_standards/sagemaker/sessions/transform.py index 3c90175..b893c19 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/transform.py +++ b/python/model_hosting_container_standards/sagemaker/sessions/transform.py @@ -10,7 +10,11 @@ from ...common import BaseApiTransform, BaseTransformRequestOutput from .handlers import get_handler_for_request_type from .manager import SessionManager, get_session_manager -from .models import SessionRequest +from .models import ( + SESSION_DISABLED_ERROR_DETAIL, + SESSION_DISABLED_LOG_MESSAGE, + SessionRequest, +) from .utils import get_session, get_session_id_from_request logger = logging.getLogger(__name__) @@ -63,7 +67,7 @@ def _validate_session_if_present(raw_request: Request, session_manager: SessionM def process_session_request( - request_data: dict, raw_request: Request, session_manager: SessionManager + request_data: dict, raw_request: Request, session_manager: Optional[SessionManager] ): """Process a potential session management request. @@ -92,16 +96,22 @@ def process_session_request( # Not a session request - pass through for normal processing if session_request is None: return BaseTransformRequestOutput( - request=None, raw_request=raw_request, intercept_func=None, ) + if session_manager is None: + logger.error(SESSION_DISABLED_LOG_MESSAGE) + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail=SESSION_DISABLED_ERROR_DETAIL, + ) + # Route to appropriate session management handler intercept_func = get_handler_for_request_type(session_request.requestType) return BaseTransformRequestOutput( - request=session_manager, raw_request=raw_request, intercept_func=intercept_func + raw_request=raw_request, intercept_func=intercept_func ) diff --git a/python/tests/integration/test_sagemaker_sessions_integration.py b/python/tests/integration/test_sagemaker_sessions_integration.py index dca9871..baddf77 100644 --- a/python/tests/integration/test_sagemaker_sessions_integration.py +++ b/python/tests/integration/test_sagemaker_sessions_integration.py @@ -33,6 +33,7 @@ init_session_manager_from_env, ) from model_hosting_container_standards.sagemaker.sessions.models import ( + SESSION_DISABLED_ERROR_DETAIL, SageMakerSessionHeader, ) @@ -542,5 +543,97 @@ def test_interleaved_session_operations(self): assert response.status_code == 200 +class TestSessionsDisabled: + """Test behavior when stateful sessions are disabled. + + These tests verify that session management requests fail gracefully + when the SAGEMAKER_ENABLE_STATEFUL_SESSIONS flag is not set. + """ + + @pytest.fixture + def app_with_sessions_disabled(self, monkeypatch): + """Create app with sessions disabled.""" + # Explicitly disable sessions + monkeypatch.delenv("SAGEMAKER_ENABLE_STATEFUL_SESSIONS", raising=False) + monkeypatch.delenv("SAGEMAKER_SESSIONS_PATH", raising=False) + monkeypatch.delenv("SAGEMAKER_SESSIONS_EXPIRATION", raising=False) + + # Reinitialize the global session manager (should be None) + init_session_manager_from_env() + + # Now create the app with sessions disabled + app = FastAPI() + router = APIRouter() + + @router.post("/invocations") + @sagemaker_standards.stateful_session_manager() + async def invocations(request: Request): + """Stateful invocation handler.""" + body_bytes = await request.body() + body = json.loads(body_bytes.decode()) + + return Response( + status_code=200, + content=json.dumps({"message": "success", "echo": body}), + ) + + app.include_router(router) + sagemaker_standards.bootstrap(app) + + return TestClient(app) + + def test_new_session_request_fails_when_disabled(self, app_with_sessions_disabled): + """Test that NEW_SESSION request fails when sessions are disabled.""" + response = app_with_sessions_disabled.post( + "/invocations", json={"requestType": "NEW_SESSION"} + ) + + # Should fail with 400 BAD_REQUEST since sessions are not enabled + assert response.status_code == 400 + assert SESSION_DISABLED_ERROR_DETAIL in response.text + + def test_close_session_request_fails_when_disabled( + self, app_with_sessions_disabled + ): + """Test that CLOSE request fails when sessions are disabled.""" + response = app_with_sessions_disabled.post( + "/invocations", + json={"requestType": "CLOSE"}, + headers={SageMakerSessionHeader.SESSION_ID: "some-session-id"}, + ) + + # Should fail with 400 BAD_REQUEST due to session header when sessions disabled + assert response.status_code == 400 + assert SESSION_DISABLED_ERROR_DETAIL in response.text + + def test_regular_requests_work_when_sessions_disabled( + self, app_with_sessions_disabled + ): + """Test that regular requests still work when sessions are disabled.""" + response = app_with_sessions_disabled.post( + "/invocations", json={"prompt": "test request"} + ) + + # Regular requests should still work + assert response.status_code == 200 + data = json.loads(response.text) + assert data["message"] == "success" + assert data["echo"]["prompt"] == "test request" + + def test_regular_requests_with_session_header_when_disabled( + self, app_with_sessions_disabled + ): + """Test that requests with session headers fail validation when sessions disabled.""" + response = app_with_sessions_disabled.post( + "/invocations", + json={"prompt": "test"}, + headers={SageMakerSessionHeader.SESSION_ID: "invalid-session"}, + ) + + # Should fail with 400 BAD_REQUEST since sessions are not enabled + assert response.status_code == 400 + assert SESSION_DISABLED_ERROR_DETAIL in response.text + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/python/tests/sagemaker/sessions/test_handlers.py b/python/tests/sagemaker/sessions/test_handlers.py index 79a56f0..5e99371 100644 --- a/python/tests/sagemaker/sessions/test_handlers.py +++ b/python/tests/sagemaker/sessions/test_handlers.py @@ -2,7 +2,7 @@ import time from http import HTTPStatus -from unittest.mock import Mock +from unittest.mock import Mock, patch import pytest from fastapi import Response @@ -61,7 +61,11 @@ async def test_creates_session_successfully( """Test successfully creates a session and returns response.""" mock_session_manager.create_session.return_value = mock_session_with_expiration - response = await create_session(mock_session_manager, mock_request) + with patch( + "model_hosting_container_standards.sagemaker.sessions.handlers.get_session_manager", + return_value=mock_session_manager, + ): + response = await create_session(mock_request) assert isinstance(response, Response) assert response.status_code == HTTPStatus.OK.value @@ -80,7 +84,11 @@ async def test_calls_session_manager_create_session( """Test calls session_manager.create_session method.""" mock_session_manager.create_session.return_value = mock_session_with_expiration - await create_session(mock_session_manager, mock_request) + with patch( + "model_hosting_container_standards.sagemaker.sessions.handlers.get_session_manager", + return_value=mock_session_manager, + ): + await create_session(mock_request) mock_session_manager.create_session.assert_called_once() @@ -91,8 +99,12 @@ async def test_raises_http_exception_on_session_creation_failure( """Test raises HTTPException when session creation fails.""" mock_session_manager.create_session.side_effect = Exception("Creation failed") - with pytest.raises(HTTPException) as exc_info: - await create_session(mock_session_manager, mock_request) + with patch( + "model_hosting_container_standards.sagemaker.sessions.handlers.get_session_manager", + return_value=mock_session_manager, + ): + with pytest.raises(HTTPException) as exc_info: + await create_session(mock_request) assert exc_info.value.status_code == HTTPStatus.FAILED_DEPENDENCY.value assert "Failed to create session" in exc_info.value.detail @@ -109,7 +121,11 @@ async def test_closes_session_successfully( session_id = "test-session-123" mock_session_manager.close_session.return_value = None - response = await close_session(mock_session_manager, mock_request_with_session) + with patch( + "model_hosting_container_standards.sagemaker.sessions.handlers.get_session_manager", + return_value=mock_session_manager, + ): + response = await close_session(mock_request_with_session) assert isinstance(response, Response) assert response.status_code == HTTPStatus.OK.value @@ -124,8 +140,12 @@ async def test_raises_http_exception_on_close_failure( """Test raises HTTPException when session close fails.""" mock_session_manager.close_session.side_effect = ValueError("Session not found") - with pytest.raises(HTTPException) as exc_info: - await close_session(mock_session_manager, mock_request_with_session) + with patch( + "model_hosting_container_standards.sagemaker.sessions.handlers.get_session_manager", + return_value=mock_session_manager, + ): + with pytest.raises(HTTPException) as exc_info: + await close_session(mock_request_with_session) assert exc_info.value.status_code == HTTPStatus.FAILED_DEPENDENCY.value assert "Failed to close session" in exc_info.value.detail diff --git a/python/tests/sagemaker/sessions/test_transform.py b/python/tests/sagemaker/sessions/test_transform.py index 56962b5..70f0e53 100644 --- a/python/tests/sagemaker/sessions/test_transform.py +++ b/python/tests/sagemaker/sessions/test_transform.py @@ -183,7 +183,7 @@ def test_returns_create_handler_for_new_session_request( ) assert isinstance(result, BaseTransformRequestOutput) - assert result.request == mock_session_manager + assert result.request is None assert result.raw_request == mock_request assert result.intercept_func == create_session @@ -198,7 +198,7 @@ def test_returns_close_handler_for_close_request( ) assert isinstance(result, BaseTransformRequestOutput) - assert result.request == mock_session_manager + assert result.request is None assert result.raw_request == mock_request assert result.intercept_func == close_session @@ -335,10 +335,10 @@ async def test_end_to_end_new_session_flow(self, transform): # Verify we get an intercept function assert result.intercept_func == create_session - assert result.request == transform._session_manager + assert result.request is None # Verify we can call the handler - response = await result.intercept_func(result.request, mock_request) + response = await result.intercept_func(mock_request) assert response.status_code == HTTPStatus.OK.value assert SageMakerSessionHeader.NEW_SESSION_ID in response.headers diff --git a/python/tests/sagemaker/sessions/test_utils.py b/python/tests/sagemaker/sessions/test_utils.py index 925dea3..62ee6e3 100644 --- a/python/tests/sagemaker/sessions/test_utils.py +++ b/python/tests/sagemaker/sessions/test_utils.py @@ -8,6 +8,7 @@ from fastapi.exceptions import HTTPException from model_hosting_container_standards.sagemaker.sessions.models import ( + SESSION_DISABLED_ERROR_DETAIL, SageMakerSessionHeader, ) from model_hosting_container_standards.sagemaker.sessions.utils import ( @@ -92,7 +93,7 @@ def test_raises_http_exception_when_sessions_not_enabled_but_header_present(self get_session(None, raw_request) assert exc_info.value.status_code == HTTPStatus.BAD_REQUEST.value - assert "stateful sessions not enabled" in exc_info.value.detail + assert SESSION_DISABLED_ERROR_DETAIL in exc_info.value.detail assert SageMakerSessionHeader.SESSION_ID in exc_info.value.detail def test_returns_none_when_sessions_not_enabled_and_no_header(self): From 9c9791abb67a8e10df21ea11d4a25ad0136edb2f Mon Sep 17 00:00:00 2001 From: Zuyi Zhao Date: Tue, 11 Nov 2025 18:21:48 +0000 Subject: [PATCH 06/25] Update way of setting sessions_path. --- .../sagemaker/sessions/manager.py | 2 +- .../sagemaker/sessions/transform.py | 4 +++- .../sagemaker/sessions/utils.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/python/model_hosting_container_standards/sagemaker/sessions/manager.py b/python/model_hosting_container_standards/sagemaker/sessions/manager.py index d1b96e0..c169261 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/manager.py +++ b/python/model_hosting_container_standards/sagemaker/sessions/manager.py @@ -139,7 +139,7 @@ def __init__(self, properties: dict): else: session_dir = os.path.join(tempfile.gettempdir(), "sagemaker_sessions") - self.sessions_path = properties.get("sessions_path", session_dir) + self.sessions_path = properties.get("sessions_path") or session_dir self.sessions: dict[str, Session] = {} self._lock = RLock() # Thread safety for concurrent session access diff --git a/python/model_hosting_container_standards/sagemaker/sessions/transform.py b/python/model_hosting_container_standards/sagemaker/sessions/transform.py index b893c19..0b3f167 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/transform.py +++ b/python/model_hosting_container_standards/sagemaker/sessions/transform.py @@ -45,7 +45,9 @@ def _parse_session_request(request_data: dict) -> Optional[SessionRequest]: return None -def _validate_session_if_present(raw_request: Request, session_manager: SessionManager): +def _validate_session_if_present( + raw_request: Request, session_manager: Optional[SessionManager] +): """Validate that the session ID in the request exists and is not expired. Args: diff --git a/python/model_hosting_container_standards/sagemaker/sessions/utils.py b/python/model_hosting_container_standards/sagemaker/sessions/utils.py index 8926ae0..e0970d7 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/utils.py +++ b/python/model_hosting_container_standards/sagemaker/sessions/utils.py @@ -1,5 +1,6 @@ import logging from http import HTTPStatus +from typing import Optional from fastapi import Request from fastapi.exceptions import HTTPException @@ -24,7 +25,7 @@ def get_session_id_from_request(raw_request: Request): return raw_request.headers.get(SageMakerSessionHeader.SESSION_ID) -def get_session(session_manager: SessionManager, raw_request: Request): +def get_session(session_manager: Optional[SessionManager], raw_request: Request): """Retrieve the session associated with the request. Args: From 413dfd27d70d49a3e3b5b42bf5ccbe4b18f4a4b2 Mon Sep 17 00:00:00 2001 From: Zuyi Zhao Date: Mon, 17 Nov 2025 22:30:09 +0000 Subject: [PATCH 07/25] feat(initial - sagemaker/sessions): support engines with their own create/close session apis --- .../common/fastapi/utils.py | 34 +++++++- .../common/transforms/base_api_transform.py | 18 +++- .../sagemaker/__init__.py | 22 ++++- .../sagemaker/sagemaker_router.py | 5 ++ .../sagemaker/sessions/__init__.py | 35 ++++++++ .../sagemaker/sessions/handlers.py | 16 +++- .../sagemaker/sessions/transform.py | 17 +++- .../sagemaker/sessions/transforms/__init__.py | 0 .../sessions/transforms/close_session.py | 82 +++++++++++++++++ .../sessions/transforms/create_session.py | 87 +++++++++++++++++++ 10 files changed, 307 insertions(+), 9 deletions(-) create mode 100644 python/model_hosting_container_standards/sagemaker/sessions/transforms/__init__.py create mode 100644 python/model_hosting_container_standards/sagemaker/sessions/transforms/close_session.py create mode 100644 python/model_hosting_container_standards/sagemaker/sessions/transforms/create_session.py diff --git a/python/model_hosting_container_standards/common/fastapi/utils.py b/python/model_hosting_container_standards/common/fastapi/utils.py index 3c01253..dd3cd18 100644 --- a/python/model_hosting_container_standards/common/fastapi/utils.py +++ b/python/model_hosting_container_standards/common/fastapi/utils.py @@ -1,8 +1,13 @@ +import json +from logging import getLogger from typing import Any, Dict, Optional, Union -from fastapi import Request +from fastapi import Request, Response +from fastapi.responses import JSONResponse from pydantic import BaseModel +logger = getLogger(__name__) + def serialize_request( request: Optional[Union[BaseModel, Dict[str, Any]]], raw_request: Request @@ -33,3 +38,30 @@ def serialize_request( "query_params": raw_request.query_params, "path_params": raw_request.path_params, } + + +def serialize_response(response: Union[Response, JSONResponse]): + """Create a structured data dictionary for JMESPath transformations. + + Extracts and organizes response data into a standardized format that can be used + with JMESPath expressions to transform and extract specific data elements. + + :param Union[Response, JSONResponse] response: Response body data - can be: + - FastAPI Response object + - JSONResponse object + :return Dict[str, Any]: Structured data with body, headers, status_code, and media_type + """ + # Process response body based on type + body = response.body.decode(response.charset) + try: + body = json.loads(body) + except json.JSONDecodeError as e: + # If body is not JSON, leave it as a string + logger.warning(f"Response body is not JSON: {e}") + pass + + logger.info(body) + return { + "body": body, + "headers": response.headers, + } diff --git a/python/model_hosting_container_standards/common/transforms/base_api_transform.py b/python/model_hosting_container_standards/common/transforms/base_api_transform.py index f2fc35b..57e8e22 100644 --- a/python/model_hosting_container_standards/common/transforms/base_api_transform.py +++ b/python/model_hosting_container_standards/common/transforms/base_api_transform.py @@ -7,7 +7,7 @@ from pydantic import BaseModel, Field from ...logging_config import logger -from ..fastapi.utils import serialize_request +from ..fastapi.utils import serialize_request, serialize_response from .utils import _compile_jmespath_expressions @@ -54,7 +54,8 @@ def _transform( if isinstance(nested_or_compiled, jmespath.parser.ParsedResult): # Apply compiled JMESPath expression to extract value value = nested_or_compiled.search(source_data) - transformed_request[target_key] = value + if value: + transformed_request[target_key] = value elif isinstance(nested_or_compiled, dict): # Recursively transform nested structures transformed_request[target_key] = self._transform( @@ -103,6 +104,19 @@ async def transform_request(self, raw_request: Request): """ raise NotImplementedError() + def _transform_response(self, response: Response): + """Transform the response based on the request processing results. + + Subclasses must implement this method to handle request parsing, validation, + and transformation according to their specific operation requirements. + + :param Response response: The response to transform + :param transform_request_output: Output from the request transformation + :raises NotImplementedError: Must be implemented by subclasses + """ + response_data = serialize_response(response) + return self._transform(response_data, self._response_shape) + def _transform_request( self, request: Optional[BaseModel], raw_request: Request ) -> Dict[str, Any]: diff --git a/python/model_hosting_container_standards/sagemaker/__init__.py b/python/model_hosting_container_standards/sagemaker/__init__.py index 54b30fc..a15795c 100644 --- a/python/model_hosting_container_standards/sagemaker/__init__.py +++ b/python/model_hosting_container_standards/sagemaker/__init__.py @@ -21,7 +21,10 @@ from .lora.models import AppendOperation from .sagemaker_loader import SageMakerFunctionLoader from .sagemaker_router import create_sagemaker_router -from .sessions import create_session_transform_decorator +from .sessions import ( + create_session_transform_decorator, + register_engine_session_handler, +) # SageMaker decorator instances - created using utility functions @@ -131,6 +134,23 @@ def stateful_session_manager(): return create_session_transform_decorator()(request_shape={}, response_shape={}) +def register_create_session_handler( + request_shape, session_id_path: str, content_path: Optional[str] = None +): + return register_engine_session_handler( + "create_session", + request_shape=request_shape, + session_id_path=session_id_path, + content_path=content_path, + ) + + +def register_close_session_handler(request_shape, content_path: Optional[str] = None): + return register_engine_session_handler( + "close_session", request_shape=request_shape, content_path=content_path + ) + + def bootstrap(app: FastAPI) -> FastAPI: """Configure a FastAPI application with SageMaker functionality. diff --git a/python/model_hosting_container_standards/sagemaker/sagemaker_router.py b/python/model_hosting_container_standards/sagemaker/sagemaker_router.py index 45b66ec..ccb003e 100644 --- a/python/model_hosting_container_standards/sagemaker/sagemaker_router.py +++ b/python/model_hosting_container_standards/sagemaker/sagemaker_router.py @@ -43,6 +43,11 @@ def get_sagemaker_route_config(handler_type: str) -> Optional[RouteConfig]: summary="Model inference endpoint", ) + if handler_type in ["create_session", "close_session"]: + # It's a request transformer, not a standalone API endpoint + # It modifies requests in-flight but doesn't expose its own route + return None + # Delegate to LoRA route resolver for LoRA-specific handlers return get_lora_route_config(handler_type) diff --git a/python/model_hosting_container_standards/sagemaker/sessions/__init__.py b/python/model_hosting_container_standards/sagemaker/sessions/__init__.py index 9d7978d..7f98bfa 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/__init__.py +++ b/python/model_hosting_container_standards/sagemaker/sessions/__init__.py @@ -1,5 +1,12 @@ +from typing import Optional + from ...common.transforms.base_factory import create_transform_decorator +from .models import SageMakerSessionHeader from .transform import SessionApiTransform +from .transforms.create_session import ( + RESPONSE_CONTENT_KEY, + resolve_engine_session_transform, +) def resolve_session_transform(handler_type: str) -> type: @@ -20,3 +27,31 @@ def create_session_transform_decorator(): return create_transform_decorator( "stateful_session_manager", resolve_session_transform ) + + +def _create_engine_session_transform_decorator(handler_type: str): + return create_transform_decorator(handler_type, resolve_engine_session_transform) + + +def register_engine_session_handler( + handler_type: str, + request_shape, + session_id_path: Optional[str] = None, + content_path: Optional[str] = None, +): + """Register a handler for creating a new session. + + Args: + session_id_path: JMESPath expression for session ID + content_path: JMESPath expression for session content + """ + response_shape = { + RESPONSE_CONTENT_KEY: content_path, + } + if handler_type == "create_session": + if not session_id_path: + raise ValueError("session_id_path is required for create_session") + response_shape[SageMakerSessionHeader.NEW_SESSION_ID] = session_id_path + return _create_engine_session_transform_decorator(handler_type)( + request_shape, response_shape + ) diff --git a/python/model_hosting_container_standards/sagemaker/sessions/handlers.py b/python/model_hosting_container_standards/sagemaker/sessions/handlers.py index e883bcb..01fb955 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/handlers.py +++ b/python/model_hosting_container_standards/sagemaker/sessions/handlers.py @@ -5,6 +5,7 @@ from fastapi import Request, Response from fastapi.exceptions import HTTPException +from ...common.handler import handler_registry from .manager import get_session_manager from .models import ( SESSION_DISABLED_ERROR_DETAIL, @@ -27,10 +28,21 @@ def get_handler_for_request_type(request_type: SessionRequestType): Handler function for the request type, or None if no handler """ if request_type == SessionRequestType.NEW_SESSION: - return create_session + registered_handler = handler_registry.get_handler("create_session") + logger.info(f"Handler for {request_type} request: {registered_handler}") + if not registered_handler: + logger.debug(f"No handler found for {request_type} request, using default") + registered_handler = create_session # Default use SageMaker system + return registered_handler elif request_type == SessionRequestType.CLOSE: - return close_session + registered_handler = handler_registry.get_handler("close_session") + logger.info(f"Handler for {request_type} request: {registered_handler}") + if not registered_handler: + logger.debug(f"No handler found for {request_type} request, using default") + registered_handler = close_session # Default use SageMaker system + return registered_handler else: + logger.warning(f"No handler found for {request_type} request") return None diff --git a/python/model_hosting_container_standards/sagemaker/sessions/transform.py b/python/model_hosting_container_standards/sagemaker/sessions/transform.py index 0b3f167..01a42d0 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/transform.py +++ b/python/model_hosting_container_standards/sagemaker/sessions/transform.py @@ -8,12 +8,14 @@ from pydantic import ValidationError from ...common import BaseApiTransform, BaseTransformRequestOutput +from ...common.handler import handler_registry from .handlers import get_handler_for_request_type from .manager import SessionManager, get_session_manager from .models import ( SESSION_DISABLED_ERROR_DETAIL, SESSION_DISABLED_LOG_MESSAGE, SessionRequest, + SessionRequestType, ) from .utils import get_session, get_session_id_from_request @@ -91,9 +93,18 @@ def process_session_request( """ session_request = _parse_session_request(request_data) - # Validate session if session ID is present in headers - # and raise error if session ID is invalid - _validate_session_if_present(raw_request, session_manager) + if ( + session_request + and session_request.requestType == SessionRequestType.NEW_SESSION + and not handler_registry.has_handler("create_session") + ) or ( + session_request + and session_request.requestType == SessionRequestType.CLOSE + and not handler_registry.has_handler("close_session") + ): + # Validate session if session ID is present in headers + # and raise error if session ID is invalid + _validate_session_if_present(raw_request, session_manager) # Not a session request - pass through for normal processing if session_request is None: diff --git a/python/model_hosting_container_standards/sagemaker/sessions/transforms/__init__.py b/python/model_hosting_container_standards/sagemaker/sessions/transforms/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/model_hosting_container_standards/sagemaker/sessions/transforms/close_session.py b/python/model_hosting_container_standards/sagemaker/sessions/transforms/close_session.py new file mode 100644 index 0000000..4d650d9 --- /dev/null +++ b/python/model_hosting_container_standards/sagemaker/sessions/transforms/close_session.py @@ -0,0 +1,82 @@ +import json +from http import HTTPStatus +from typing import Any, Dict + +from fastapi import Request, Response +from fastapi.exceptions import HTTPException + +from ....common import BaseApiTransform, BaseTransformRequestOutput +from ..models import SageMakerSessionHeader +from ..utils import get_session_id_from_request + + +from pydantic import BaseModel +from logging import getLogger + +RESPONSE_CONTENT_KEY = "content" + +logger = getLogger(__name__) + + +class CloseSessionApiTransform(BaseApiTransform): + def __init__( + self, request_shape: Dict[str, Any], response_shape: Dict[str, Any] = {} + ): + try: + assert RESPONSE_CONTENT_KEY in response_shape.keys() + except AssertionError as e: + raise ValueError( + f"Response shape must contain {SageMakerSessionHeader.CLOSED_SESSION_ID} and {RESPONSE_CONTENT_KEY} keys" + ) from e + + super().__init__(request_shape, response_shape) + + async def transform_request(self, raw_request: Request): + try: + request_data = await raw_request.json() + except json.JSONDecodeError as e: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail=f"JSON decode error: {e}", + ) from e + transformed_request = self._transform_request(None, raw_request) + logger.info(transformed_request) + raw_request._body = json.dumps(transformed_request).encode("utf-8") + return BaseTransformRequestOutput( + request=transformed_request, + raw_request=raw_request, + intercept_func=None, + ) + + def transform_response(self, response: Response, transform_request_output): + session_id = get_session_id_from_request( + transform_request_output.raw_request + ) + if not hasattr(response, 'status_code'): + # Handle the case where the response is not a Response object + if isinstance(response, BaseModel): + response = response.model_dump_json() + elif not isinstance(response, str): + response = json.dumps(response) + response = Response( + status_code=HTTPStatus.OK.value, + content=response, + ) + if response.status_code == HTTPStatus.OK.value: + return self._transform_ok_response(response, session_id=session_id) + else: + return self._transform_error_response(response) + + def _transform_error_response(self, response: Response, **kwargs): + return response + + def _transform_ok_response(self, response: Response, **kwargs): + session_id = kwargs.get("session_id") + transformed_response_data = self._transform_response(response) + content = transformed_response_data.get(RESPONSE_CONTENT_KEY) + logger.info(f"Session {session_id}: {content}") + return Response( + status_code=HTTPStatus.OK.value, + content=f"Session {session_id}: {content}", + headers={SageMakerSessionHeader.CLOSED_SESSION_ID: session_id}, + ) diff --git a/python/model_hosting_container_standards/sagemaker/sessions/transforms/create_session.py b/python/model_hosting_container_standards/sagemaker/sessions/transforms/create_session.py new file mode 100644 index 0000000..3a54b42 --- /dev/null +++ b/python/model_hosting_container_standards/sagemaker/sessions/transforms/create_session.py @@ -0,0 +1,87 @@ +import json +from http import HTTPStatus +from logging import getLogger +from typing import Any, Dict + +from fastapi import Request, Response +from fastapi.exceptions import HTTPException +from pydantic import BaseModel + +from ....common import BaseApiTransform, BaseTransformRequestOutput +from ..models import SageMakerSessionHeader +from .close_session import CloseSessionApiTransform + +RESPONSE_CONTENT_KEY = "content" + +logger = getLogger(__name__) + + +class CreateSessionApiTransform(BaseApiTransform): + def __init__( + self, request_shape: Dict[str, Any], response_shape: Dict[str, Any] = {} + ): + try: + assert SageMakerSessionHeader.NEW_SESSION_ID in response_shape.keys() + assert RESPONSE_CONTENT_KEY in response_shape.keys() + except AssertionError as e: + raise ValueError( + f"Response shape must contain {SageMakerSessionHeader.NEW_SESSION_ID} and {RESPONSE_CONTENT_KEY} keys" + ) from e + + super().__init__(request_shape, response_shape) + + async def transform_request(self, raw_request: Request): + try: + _ = await raw_request.json() + except json.JSONDecodeError as e: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail=f"JSON decode error: {e}", + ) from e + transformed_request = self._transform_request(None, raw_request) + raw_request._body = json.dumps(transformed_request).encode("utf-8") + return BaseTransformRequestOutput( + request=transformed_request, + raw_request=raw_request, + intercept_func=None, + ) + + def transform_response(self, response: Response, transform_request_output): + if not hasattr(response, "status_code"): + # Handle the case where the response is not a Response object + if isinstance(response, BaseModel): + response = response.model_dump_json() + elif not isinstance(response, str): + response = json.dumps(response) + response = Response( + status_code=HTTPStatus.OK.value, + content=response, + ) + if response.status_code == HTTPStatus.OK.value: + return self._transform_ok_response(response) + else: + return self._transform_error_response(response) + + def _transform_error_response(self, response: Response, **kwargs): + return response + + def _transform_ok_response(self, response: Response, **kwargs): + transformed_response_data = self._transform_response(response) + content = transformed_response_data.get(RESPONSE_CONTENT_KEY) + session_id = transformed_response_data.get( + SageMakerSessionHeader.NEW_SESSION_ID + ) + logger.info(f"Session {session_id}: {content}") + return Response( + status_code=HTTPStatus.OK.value, + content=f"Session {session_id}: {content}", + headers={SageMakerSessionHeader.NEW_SESSION_ID: session_id}, + ) + + +def resolve_engine_session_transform(handler_type: str): + if handler_type == "create_session": + return CreateSessionApiTransform + elif handler_type == "close_session": + return CloseSessionApiTransform + return None From aaf777426dc477df2ff67b48071d22c51f7003d8 Mon Sep 17 00:00:00 2001 From: Zuyi Zhao Date: Tue, 18 Nov 2025 00:34:53 +0000 Subject: [PATCH 08/25] feat(initial - sagemaker/sessions): refactor create/close api transform methods --- .../common/fastapi/utils.py | 6 +- .../common/transforms/base_api_transform.py | 3 +- .../sagemaker/sessions/__init__.py | 30 +++- .../sagemaker/sessions/transform.py | 39 +++-- .../sagemaker/sessions/transforms/__init__.py | 15 ++ .../base_engine_session_api_transform.py | 142 ++++++++++++++++++ .../sessions/transforms/close_session.py | 84 +++++------ .../sessions/transforms/constants.py | 4 + .../sessions/transforms/create_session.py | 80 ++++------ 9 files changed, 283 insertions(+), 120 deletions(-) create mode 100644 python/model_hosting_container_standards/sagemaker/sessions/transforms/base_engine_session_api_transform.py create mode 100644 python/model_hosting_container_standards/sagemaker/sessions/transforms/constants.py diff --git a/python/model_hosting_container_standards/common/fastapi/utils.py b/python/model_hosting_container_standards/common/fastapi/utils.py index dd3cd18..dbcbf47 100644 --- a/python/model_hosting_container_standards/common/fastapi/utils.py +++ b/python/model_hosting_container_standards/common/fastapi/utils.py @@ -56,9 +56,9 @@ def serialize_response(response: Union[Response, JSONResponse]): try: body = json.loads(body) except json.JSONDecodeError as e: - # If body is not JSON, leave it as a string - logger.warning(f"Response body is not JSON: {e}") - pass + # If body is not JSON, keep it as a string + logger.warning(f"Response body is not JSON, keeping as string: {e}") + # body remains as string - no action needed logger.info(body) return { diff --git a/python/model_hosting_container_standards/common/transforms/base_api_transform.py b/python/model_hosting_container_standards/common/transforms/base_api_transform.py index 57e8e22..feeec81 100644 --- a/python/model_hosting_container_standards/common/transforms/base_api_transform.py +++ b/python/model_hosting_container_standards/common/transforms/base_api_transform.py @@ -54,8 +54,7 @@ def _transform( if isinstance(nested_or_compiled, jmespath.parser.ParsedResult): # Apply compiled JMESPath expression to extract value value = nested_or_compiled.search(source_data) - if value: - transformed_request[target_key] = value + transformed_request[target_key] = value elif isinstance(nested_or_compiled, dict): # Recursively transform nested structures transformed_request[target_key] = self._transform( diff --git a/python/model_hosting_container_standards/sagemaker/sessions/__init__.py b/python/model_hosting_container_standards/sagemaker/sessions/__init__.py index 7f98bfa..9e1af91 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/__init__.py +++ b/python/model_hosting_container_standards/sagemaker/sessions/__init__.py @@ -3,10 +3,8 @@ from ...common.transforms.base_factory import create_transform_decorator from .models import SageMakerSessionHeader from .transform import SessionApiTransform -from .transforms.create_session import ( - RESPONSE_CONTENT_KEY, - resolve_engine_session_transform, -) +from .transforms import resolve_engine_session_transform +from .transforms.constants import RESPONSE_CONTENT_KEY def resolve_session_transform(handler_type: str) -> type: @@ -39,19 +37,37 @@ def register_engine_session_handler( session_id_path: Optional[str] = None, content_path: Optional[str] = None, ): - """Register a handler for creating a new session. + """Register a handler for engine-specific session management. Args: - session_id_path: JMESPath expression for session ID - content_path: JMESPath expression for session content + handler_type: Type of session handler ('create_session' or 'close_session') + request_shape: JMESPath expressions for transforming request data + session_id_path: JMESPath expression for extracting session ID from response + (required for 'create_session', ignored for 'close_session') + content_path: JMESPath expression for extracting content from response + + Returns: + Decorator function for registering the session handler + + Raises: + ValueError: If handler_type is invalid or required parameters are missing """ + # Validate handler_type + if handler_type not in ("create_session", "close_session"): + raise ValueError( + f"Invalid handler_type '{handler_type}'. " + f"Must be 'create_session' or 'close_session'" + ) + response_shape = { RESPONSE_CONTENT_KEY: content_path, } + if handler_type == "create_session": if not session_id_path: raise ValueError("session_id_path is required for create_session") response_shape[SageMakerSessionHeader.NEW_SESSION_ID] = session_id_path + return _create_engine_session_transform_decorator(handler_type)( request_shape, response_shape ) diff --git a/python/model_hosting_container_standards/sagemaker/sessions/transform.py b/python/model_hosting_container_standards/sagemaker/sessions/transform.py index 01a42d0..1c89579 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/transform.py +++ b/python/model_hosting_container_standards/sagemaker/sessions/transform.py @@ -47,6 +47,31 @@ def _parse_session_request(request_data: dict) -> Optional[SessionRequest]: return None +def _should_validate_session(session_request: Optional[SessionRequest]) -> bool: + """Determine if session validation is needed for the given request. + + Session validation is required when: + - Request is a NEW_SESSION without a custom create_session handler + - Request is a CLOSE without a custom close_session handler + + Args: + session_request: Parsed session request, or None if not a session request + + Returns: + True if session validation should be performed, False otherwise + """ + if not session_request: + return False + + if session_request.requestType == SessionRequestType.NEW_SESSION: + return not handler_registry.has_handler("create_session") + + if session_request.requestType == SessionRequestType.CLOSE: + return not handler_registry.has_handler("close_session") + + return False + + def _validate_session_if_present( raw_request: Request, session_manager: Optional[SessionManager] ): @@ -92,16 +117,8 @@ def process_session_request( HTTPException: If request is malformed or session validation fails """ session_request = _parse_session_request(request_data) - - if ( - session_request - and session_request.requestType == SessionRequestType.NEW_SESSION - and not handler_registry.has_handler("create_session") - ) or ( - session_request - and session_request.requestType == SessionRequestType.CLOSE - and not handler_registry.has_handler("close_session") - ): + should_validate = _should_validate_session(session_request) + if should_validate: # Validate session if session ID is present in headers # and raise error if session ID is invalid _validate_session_if_present(raw_request, session_manager) @@ -113,7 +130,7 @@ def process_session_request( intercept_func=None, ) - if session_manager is None: + if should_validate and session_manager is None: logger.error(SESSION_DISABLED_LOG_MESSAGE) raise HTTPException( status_code=HTTPStatus.BAD_REQUEST.value, diff --git a/python/model_hosting_container_standards/sagemaker/sessions/transforms/__init__.py b/python/model_hosting_container_standards/sagemaker/sessions/transforms/__init__.py index e69de29..c0b27d5 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/transforms/__init__.py +++ b/python/model_hosting_container_standards/sagemaker/sessions/transforms/__init__.py @@ -0,0 +1,15 @@ +from .close_session import CloseSessionApiTransform +from .create_session import CreateSessionApiTransform + + +def resolve_engine_session_transform(handler_type: str): + """Resolve the appropriate transform class for engine session handlers. + + :param str handler_type: Type of session handler ('create_session' or 'close_session') + :return: Transform class or None if handler type is not recognized + """ + if handler_type == "create_session": + return CreateSessionApiTransform + elif handler_type == "close_session": + return CloseSessionApiTransform + return None diff --git a/python/model_hosting_container_standards/sagemaker/sessions/transforms/base_engine_session_api_transform.py b/python/model_hosting_container_standards/sagemaker/sessions/transforms/base_engine_session_api_transform.py new file mode 100644 index 0000000..9204f8a --- /dev/null +++ b/python/model_hosting_container_standards/sagemaker/sessions/transforms/base_engine_session_api_transform.py @@ -0,0 +1,142 @@ +import abc +import json +from http import HTTPStatus +from logging import getLogger + +from fastapi import Request, Response +from fastapi.exceptions import HTTPException +from pydantic import BaseModel + +from ....common import BaseApiTransform, BaseTransformRequestOutput + +logger = getLogger(__name__) + + +class BaseEngineSessionApiTransform(BaseApiTransform): + """Base abstract class for engine-specific session API transformations. + + This class provides the foundation for transforming HTTP requests and responses + for engines that implement their own session management APIs. It handles common + response normalization and routing logic, while subclasses implement specific + transformation behavior for create/close session operations. + """ + + async def transform_request( + self, raw_request: Request + ) -> BaseTransformRequestOutput: + """Transform an incoming HTTP request for engine session operations. + + Parses JSON request body, applies JMESPath transformations, and validates + any session-specific requirements. Subclasses can override to add custom + validation logic before or after the base transformation. + + :param Request raw_request: The incoming FastAPI request object + :return BaseTransformRequestOutput: Transformed request data and metadata + :raises HTTPException: If JSON parsing fails or validation errors occur + """ + # Subclasses can override _validate_request_preconditions for early validation + self._validate_request_preconditions(raw_request) + + try: + request_data = await raw_request.json() + except json.JSONDecodeError as e: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail=f"JSON decode error: {e}", + ) from e + + transformed_request = self._transform_request(request_data, raw_request) + raw_request._body = json.dumps(transformed_request).encode("utf-8") + + return BaseTransformRequestOutput( + request=transformed_request, + raw_request=raw_request, + intercept_func=None, + ) + + def _validate_request_preconditions(self, raw_request: Request) -> None: + """Validate request preconditions before transformation. + + Subclasses can override this method to perform early validation + (e.g., checking for required headers). Default implementation does nothing. + + :param Request raw_request: The incoming request to validate + :raises HTTPException: If validation fails + """ + pass + + def transform_response( + self, response: Response, transform_request_output: BaseTransformRequestOutput + ) -> Response: + """Transform the response based on the request processing results. + + Normalizes various response types to FastAPI Response objects and routes + to appropriate transformation method based on HTTP status code. + + :param Response response: The response to transform (may be Response, BaseModel, dict, or str) + :param BaseTransformRequestOutput transform_request_output: Output from the request transformation + :return Response: Transformed response + """ + # Normalize response to Response object + response = self._normalize_response(response) + + # Route based on status code + if response.status_code == HTTPStatus.OK.value: + return self._transform_ok_response( + response, transform_request_output=transform_request_output + ) + else: + return self._transform_error_response(response) + + def _normalize_response(self, response): + """Convert various response types to FastAPI Response object. + + Handles responses that may be BaseModel instances, dictionaries, strings, + or already Response objects. If the response doesn't have a status_code, + it's assumed to be a successful response (200 OK) from the engine handler. + + Note: This method only normalizes the response format. Validation of required + fields (like session IDs) should be done in _transform_ok_response() to provide + appropriate error responses if the engine returns invalid data. + + :param response: Response in various formats + :return Response: Normalized FastAPI Response object + """ + if not hasattr(response, "status_code"): + # Handle the case where the response is not a Response object + # Assume success if the handler returned data without explicit status + if isinstance(response, BaseModel): + response = response.model_dump_json() + elif not isinstance(response, str): + response = json.dumps(response) + response = Response( + status_code=HTTPStatus.OK.value, + content=response, + ) + return response + + @abc.abstractmethod + def _transform_ok_response(self, response: Response, **kwargs) -> Response: + """Transform successful (200 OK) responses. + + Subclasses must implement this method to handle session-specific response + formatting and header management. + + :param Response response: The successful response to transform + :param BaseTransformRequestOutput transform_request_output: Output from the request transformation + :return Response: Transformed response + :raises NotImplementedError: Must be implemented by subclasses + """ + raise NotImplementedError() + + def _transform_error_response(self, response: Response, **kwargs) -> Response: + """Transform error responses. + + Default implementation passes through error responses unchanged. + Subclasses can override to add custom error handling. + + :param Response response: The error response to transform + :param BaseTransformRequestOutput transform_request_output: Output from the request transformation + :return Response: Transformed response (default: unchanged) + """ + return response diff --git a/python/model_hosting_container_standards/sagemaker/sessions/transforms/close_session.py b/python/model_hosting_container_standards/sagemaker/sessions/transforms/close_session.py index 4d650d9..ef5aa5d 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/transforms/close_session.py +++ b/python/model_hosting_container_standards/sagemaker/sessions/transforms/close_session.py @@ -1,24 +1,20 @@ -import json from http import HTTPStatus +from logging import getLogger from typing import Any, Dict from fastapi import Request, Response from fastapi.exceptions import HTTPException -from ....common import BaseApiTransform, BaseTransformRequestOutput +from ....common import BaseTransformRequestOutput from ..models import SageMakerSessionHeader from ..utils import get_session_id_from_request - - -from pydantic import BaseModel -from logging import getLogger - -RESPONSE_CONTENT_KEY = "content" +from .base_engine_session_api_transform import BaseEngineSessionApiTransform +from .constants import RESPONSE_CONTENT_KEY logger = getLogger(__name__) -class CloseSessionApiTransform(BaseApiTransform): +class CloseSessionApiTransform(BaseEngineSessionApiTransform): def __init__( self, request_shape: Dict[str, Any], response_shape: Dict[str, Any] = {} ): @@ -26,54 +22,48 @@ def __init__( assert RESPONSE_CONTENT_KEY in response_shape.keys() except AssertionError as e: raise ValueError( - f"Response shape must contain {SageMakerSessionHeader.CLOSED_SESSION_ID} and {RESPONSE_CONTENT_KEY} keys" + f"Response shape must contain {RESPONSE_CONTENT_KEY} key" ) from e super().__init__(request_shape, response_shape) - async def transform_request(self, raw_request: Request): - try: - request_data = await raw_request.json() - except json.JSONDecodeError as e: + def _validate_request_preconditions(self, raw_request: Request) -> None: + """Validate that session ID exists in request headers before processing. + + :param Request raw_request: The incoming request to validate + :raises HTTPException: If session ID is missing from headers + """ + session_id = get_session_id_from_request(raw_request) + if not session_id: + logger.error("No session ID found in request headers for close session") raise HTTPException( status_code=HTTPStatus.BAD_REQUEST.value, - detail=f"JSON decode error: {e}", - ) from e - transformed_request = self._transform_request(None, raw_request) - logger.info(transformed_request) - raw_request._body = json.dumps(transformed_request).encode("utf-8") - return BaseTransformRequestOutput( - request=transformed_request, - raw_request=raw_request, - intercept_func=None, - ) - - def transform_response(self, response: Response, transform_request_output): - session_id = get_session_id_from_request( - transform_request_output.raw_request - ) - if not hasattr(response, 'status_code'): - # Handle the case where the response is not a Response object - if isinstance(response, BaseModel): - response = response.model_dump_json() - elif not isinstance(response, str): - response = json.dumps(response) - response = Response( - status_code=HTTPStatus.OK.value, - content=response, + detail="Session ID is required in request headers to close a session", ) - if response.status_code == HTTPStatus.OK.value: - return self._transform_ok_response(response, session_id=session_id) - else: - return self._transform_error_response(response) - - def _transform_error_response(self, response: Response, **kwargs): - return response - def _transform_ok_response(self, response: Response, **kwargs): - session_id = kwargs.get("session_id") + def _transform_ok_response(self, response: Response, **kwargs) -> Response: + """Transform successful close session response. + + Extracts session ID from request headers and content from engine response, + validates them, and returns formatted response with CLOSED_SESSION_ID header. + + :param Response response: The successful response to transform + :param BaseTransformRequestOutput transform_request_output: Output from the request transformation + :return Response: Transformed response with session headers + """ + transform_request_output: BaseTransformRequestOutput = kwargs.get("transform_request_output") # type: ignore + # Session ID already validated in transform_request, safe to extract + session_id = get_session_id_from_request(transform_request_output.raw_request) + transformed_response_data = self._transform_response(response) content = transformed_response_data.get(RESPONSE_CONTENT_KEY) + + # Validate that content was extracted from the response + if not content: + logger.debug( + f"No content extracted from close session response for session {session_id}" + ) + logger.info(f"Session {session_id}: {content}") return Response( status_code=HTTPStatus.OK.value, diff --git a/python/model_hosting_container_standards/sagemaker/sessions/transforms/constants.py b/python/model_hosting_container_standards/sagemaker/sessions/transforms/constants.py new file mode 100644 index 0000000..3c696bd --- /dev/null +++ b/python/model_hosting_container_standards/sagemaker/sessions/transforms/constants.py @@ -0,0 +1,4 @@ +"""Constants for engine session transforms.""" + +# Key used in response_shape to specify where to extract content from engine response +RESPONSE_CONTENT_KEY = "content" diff --git a/python/model_hosting_container_standards/sagemaker/sessions/transforms/create_session.py b/python/model_hosting_container_standards/sagemaker/sessions/transforms/create_session.py index 3a54b42..c63c6f9 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/transforms/create_session.py +++ b/python/model_hosting_container_standards/sagemaker/sessions/transforms/create_session.py @@ -1,22 +1,18 @@ -import json from http import HTTPStatus from logging import getLogger from typing import Any, Dict -from fastapi import Request, Response +from fastapi import Response from fastapi.exceptions import HTTPException -from pydantic import BaseModel -from ....common import BaseApiTransform, BaseTransformRequestOutput from ..models import SageMakerSessionHeader -from .close_session import CloseSessionApiTransform - -RESPONSE_CONTENT_KEY = "content" +from .base_engine_session_api_transform import BaseEngineSessionApiTransform +from .constants import RESPONSE_CONTENT_KEY logger = getLogger(__name__) -class CreateSessionApiTransform(BaseApiTransform): +class CreateSessionApiTransform(BaseEngineSessionApiTransform): def __init__( self, request_shape: Dict[str, Any], response_shape: Dict[str, Any] = {} ): @@ -30,58 +26,42 @@ def __init__( super().__init__(request_shape, response_shape) - async def transform_request(self, raw_request: Request): - try: - _ = await raw_request.json() - except json.JSONDecodeError as e: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST.value, - detail=f"JSON decode error: {e}", - ) from e - transformed_request = self._transform_request(None, raw_request) - raw_request._body = json.dumps(transformed_request).encode("utf-8") - return BaseTransformRequestOutput( - request=transformed_request, - raw_request=raw_request, - intercept_func=None, - ) - - def transform_response(self, response: Response, transform_request_output): - if not hasattr(response, "status_code"): - # Handle the case where the response is not a Response object - if isinstance(response, BaseModel): - response = response.model_dump_json() - elif not isinstance(response, str): - response = json.dumps(response) - response = Response( - status_code=HTTPStatus.OK.value, - content=response, - ) - if response.status_code == HTTPStatus.OK.value: - return self._transform_ok_response(response) - else: - return self._transform_error_response(response) + def _transform_ok_response(self, response: Response, **kwargs) -> Response: + """Transform successful create session response. - def _transform_error_response(self, response: Response, **kwargs): - return response + Extracts session ID and content from engine response, validates them, + and returns formatted response with NEW_SESSION_ID header. - def _transform_ok_response(self, response: Response, **kwargs): + :param Response response: The successful response to transform + :return Response: Transformed response with session headers + :raises HTTPException: If session ID cannot be extracted from response + """ transformed_response_data = self._transform_response(response) content = transformed_response_data.get(RESPONSE_CONTENT_KEY) session_id = transformed_response_data.get( SageMakerSessionHeader.NEW_SESSION_ID ) + + # Validate that session_id was extracted from the response + if not session_id: + logger.error( + f"Failed to extract session ID from engine response. " + f"Response data: {transformed_response_data}" + ) + raise HTTPException( + status_code=HTTPStatus.BAD_GATEWAY.value, + detail="Engine failed to return a valid session ID in the response", + ) + + # Validate that content was extracted from the response + if not content: + logger.debug( + f"No content extracted from create session response for session {session_id}" + ) + logger.info(f"Session {session_id}: {content}") return Response( status_code=HTTPStatus.OK.value, content=f"Session {session_id}: {content}", headers={SageMakerSessionHeader.NEW_SESSION_ID: session_id}, ) - - -def resolve_engine_session_transform(handler_type: str): - if handler_type == "create_session": - return CreateSessionApiTransform - elif handler_type == "close_session": - return CloseSessionApiTransform - return None From 1a31cf2c18ed82a8ba20311d455b872824b521d3 Mon Sep 17 00:00:00 2001 From: Zuyi Zhao Date: Mon, 1 Dec 2025 20:28:09 +0000 Subject: [PATCH 09/25] import logger to sessions/transform.py --- .../sagemaker/sessions/transform.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/model_hosting_container_standards/sagemaker/sessions/transform.py b/python/model_hosting_container_standards/sagemaker/sessions/transform.py index 252a8fc..c890e13 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/transform.py +++ b/python/model_hosting_container_standards/sagemaker/sessions/transform.py @@ -7,6 +7,7 @@ from pydantic import ValidationError from ...common import BaseApiTransform, BaseTransformRequestOutput +from ...logging_config import logger from .handlers import get_handler_for_request_type from .manager import SessionManager, get_session_manager from .models import ( From ce4ab4000819e1eb55ca91306ffc15f4795d2636 Mon Sep 17 00:00:00 2001 From: Zuyi Zhao Date: Mon, 1 Dec 2025 22:46:02 +0000 Subject: [PATCH 10/25] Remove manual logger setups. --- .../model_hosting_container_standards/common/fastapi/utils.py | 4 ---- .../sessions/transforms/base_engine_session_api_transform.py | 4 +--- .../sagemaker/sessions/transforms/close_session.py | 4 +--- .../sagemaker/sessions/transforms/create_session.py | 4 +--- 4 files changed, 3 insertions(+), 13 deletions(-) diff --git a/python/model_hosting_container_standards/common/fastapi/utils.py b/python/model_hosting_container_standards/common/fastapi/utils.py index dbcbf47..324aafe 100644 --- a/python/model_hosting_container_standards/common/fastapi/utils.py +++ b/python/model_hosting_container_standards/common/fastapi/utils.py @@ -1,13 +1,10 @@ import json -from logging import getLogger from typing import Any, Dict, Optional, Union from fastapi import Request, Response from fastapi.responses import JSONResponse from pydantic import BaseModel -logger = getLogger(__name__) - def serialize_request( request: Optional[Union[BaseModel, Dict[str, Any]]], raw_request: Request @@ -60,7 +57,6 @@ def serialize_response(response: Union[Response, JSONResponse]): logger.warning(f"Response body is not JSON, keeping as string: {e}") # body remains as string - no action needed - logger.info(body) return { "body": body, "headers": response.headers, diff --git a/python/model_hosting_container_standards/sagemaker/sessions/transforms/base_engine_session_api_transform.py b/python/model_hosting_container_standards/sagemaker/sessions/transforms/base_engine_session_api_transform.py index 9204f8a..29820d2 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/transforms/base_engine_session_api_transform.py +++ b/python/model_hosting_container_standards/sagemaker/sessions/transforms/base_engine_session_api_transform.py @@ -1,15 +1,13 @@ import abc import json from http import HTTPStatus -from logging import getLogger from fastapi import Request, Response from fastapi.exceptions import HTTPException from pydantic import BaseModel from ....common import BaseApiTransform, BaseTransformRequestOutput - -logger = getLogger(__name__) +from ....logging_config import logger class BaseEngineSessionApiTransform(BaseApiTransform): diff --git a/python/model_hosting_container_standards/sagemaker/sessions/transforms/close_session.py b/python/model_hosting_container_standards/sagemaker/sessions/transforms/close_session.py index ef5aa5d..1878892 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/transforms/close_session.py +++ b/python/model_hosting_container_standards/sagemaker/sessions/transforms/close_session.py @@ -1,18 +1,16 @@ from http import HTTPStatus -from logging import getLogger from typing import Any, Dict from fastapi import Request, Response from fastapi.exceptions import HTTPException from ....common import BaseTransformRequestOutput +from ....logging_config import logger from ..models import SageMakerSessionHeader from ..utils import get_session_id_from_request from .base_engine_session_api_transform import BaseEngineSessionApiTransform from .constants import RESPONSE_CONTENT_KEY -logger = getLogger(__name__) - class CloseSessionApiTransform(BaseEngineSessionApiTransform): def __init__( diff --git a/python/model_hosting_container_standards/sagemaker/sessions/transforms/create_session.py b/python/model_hosting_container_standards/sagemaker/sessions/transforms/create_session.py index c63c6f9..f92762a 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/transforms/create_session.py +++ b/python/model_hosting_container_standards/sagemaker/sessions/transforms/create_session.py @@ -1,16 +1,14 @@ from http import HTTPStatus -from logging import getLogger from typing import Any, Dict from fastapi import Response from fastapi.exceptions import HTTPException +from ....logging_config import logger from ..models import SageMakerSessionHeader from .base_engine_session_api_transform import BaseEngineSessionApiTransform from .constants import RESPONSE_CONTENT_KEY -logger = getLogger(__name__) - class CreateSessionApiTransform(BaseEngineSessionApiTransform): def __init__( From 4f62b3fe3d2d3de203e32330ebcdbb595900981d Mon Sep 17 00:00:00 2001 From: Zuyi Zhao Date: Tue, 2 Dec 2025 00:05:02 +0000 Subject: [PATCH 11/25] Update README.md --- .../sagemaker/sessions/README.md | 59 ++++++++++++++----- 1 file changed, 45 insertions(+), 14 deletions(-) diff --git a/python/model_hosting_container_standards/sagemaker/sessions/README.md b/python/model_hosting_container_standards/sagemaker/sessions/README.md index 4720d91..9e0309b 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/README.md +++ b/python/model_hosting_container_standards/sagemaker/sessions/README.md @@ -2,6 +2,17 @@ This module provides stateful session management for SageMaker model hosting containers, enabling multi-turn conversations and persistent state across requests. +## Table of Contents + +- [Overview](#overview) +- [Architecture](#architecture) +- [Quick Start](#quick-start) +- [Configuration](#configuration) +- [Session Storage](#session-storage) +- [Expiration and Cleanup](#expiration-and-cleanup) +- [Advanced Usage](#advanced-usage) + - [Custom Session Handlers](./CUSTOM_HANDLERS.md) + ## Overview Stateful sessions allow clients to maintain context across multiple inference requests without passing all state in every request. Each session has: @@ -10,6 +21,23 @@ Stateful sessions allow clients to maintain context across multiple inference re - **Automatic expiration**: Configurable TTL (default: 20 minutes) - **Thread-safe access**: Concurrent request handling +### Session Management Modes + +The framework supports two modes of session management: + +1. **SageMaker-Managed Sessions** (Default) + - Sessions managed by the built-in `SessionManager` + - File-based key-value storage in `/dev/shm` + - Automatic expiration and cleanup + - Best for general-purpose session state + +2. **Engine-Managed Sessions** (Custom Handlers) + - Sessions delegated to your inference engine's native API + - Leverages engine-specific session features + - Requires custom handler registration + - Best when engine has built-in session support (e.g., vLLM, TGI) + - See [CUSTOM_HANDLERS.md](./CUSTOM_HANDLERS.md) for details + ## Architecture ``` @@ -31,7 +59,7 @@ SessionApiTransform (transform.py) - **Session Handlers** (`handlers.py`): Functions to create and close sessions - **Utilities** (`utils.py`): Helper functions for session ID extraction and retrieval -## Usage +## Quick Start ### Enabling Sessions in Your Handler @@ -102,12 +130,23 @@ session_manager = SessionManager({ ### Storage Location Sessions are stored in memory-backed filesystem when available: -- **Preferred**: `/dev/shm/sagemaker_sessions` (tmpfs - fast) +- **Preferred**: `/dev/shm/sagemaker_sessions` (tmpfs - fast, in-memory) - **Fallback**: `{tempdir}/sagemaker_sessions` (disk-backed) +**Note**: Session data is not persistent across container restarts. + ## Session Storage -Each session maintains its own directory with JSON files for key-value pairs. +Each session maintains its own directory with JSON files for key-value pairs: + +``` +/dev/shm/sagemaker_sessions/ +├── / +│ ├── key1.json +│ └── key2.json +└── / + └── key1.json +``` ## Expiration and Cleanup @@ -119,16 +158,8 @@ Each session maintains its own directory with JSON files for key-value pairs. ## Advanced Usage -For more control, use `create_session_transform_decorator()` directly: - -```python -from model_hosting_container_standards.sagemaker.sessions import create_session_transform_decorator +### Custom Session Handlers -session_transform = create_session_transform_decorator() - -@session_transform(request_shape={}, response_shape={}) -def my_handler(request, context): - pass -``` +If your inference engine has its own session management API, you can register custom handlers to delegate session creation and closure to the engine instead of using SageMaker's built-in session management. -**Note**: `SessionApiTransform` ignores the `request_shape` and `response_shape` parameters. These are passed to the parent `BaseApiTransform` class for interface compatibility, but session requests use their own validation via `SessionRequest` model instead of JMESPath transformations. +See [CUSTOM_HANDLERS.md](./CUSTOM_HANDLERS.md) for detailed documentation on implementing custom create/close session handlers. From 2975b104155302d6b0e941870ae3564b45bfe669 Mon Sep 17 00:00:00 2001 From: Zuyi Zhao Date: Tue, 2 Dec 2025 00:09:32 +0000 Subject: [PATCH 12/25] Fix linting. --- .../common/fastapi/utils.py | 6 +- .../sagemaker/sessions/CUSTOM_HANDLERS.md | 106 ++++++++++++++++++ .../base_engine_session_api_transform.py | 1 - 3 files changed, 109 insertions(+), 4 deletions(-) create mode 100644 python/model_hosting_container_standards/sagemaker/sessions/CUSTOM_HANDLERS.md diff --git a/python/model_hosting_container_standards/common/fastapi/utils.py b/python/model_hosting_container_standards/common/fastapi/utils.py index 324aafe..886bd10 100644 --- a/python/model_hosting_container_standards/common/fastapi/utils.py +++ b/python/model_hosting_container_standards/common/fastapi/utils.py @@ -52,10 +52,10 @@ def serialize_response(response: Union[Response, JSONResponse]): body = response.body.decode(response.charset) try: body = json.loads(body) - except json.JSONDecodeError as e: + except json.JSONDecodeError: # If body is not JSON, keep it as a string - logger.warning(f"Response body is not JSON, keeping as string: {e}") - # body remains as string - no action needed + # logger.warning(f"Response body is not JSON, keeping as string: {e}") + pass return { "body": body, diff --git a/python/model_hosting_container_standards/sagemaker/sessions/CUSTOM_HANDLERS.md b/python/model_hosting_container_standards/sagemaker/sessions/CUSTOM_HANDLERS.md new file mode 100644 index 0000000..3c06300 --- /dev/null +++ b/python/model_hosting_container_standards/sagemaker/sessions/CUSTOM_HANDLERS.md @@ -0,0 +1,106 @@ +# Custom Session Handlers + +This guide explains how to implement custom create and close session handlers when your inference engine has its own session management API. + +## Overview + +By default, SageMaker's session management uses the built-in `SessionManager` to handle session lifecycle. However, if your inference engine (like vLLM, TGI, or a custom engine) provides its own session API, you can register custom handlers to delegate session operations to the engine. + +### When to Use Custom Handlers + +Use custom handlers when: +- Your engine has native session management capabilities +- You want to leverage engine-specific session features +- Session state needs to be managed within the engine's memory space +- You need custom session initialization or cleanup logic + +### Architecture + +``` +Client Request + ↓ +SessionApiTransform (detects session request) + ↓ +get_handler_for_request_type() + ↓ + ├─→ Custom Handler (if registered) + │ └─→ Engine's Session API + │ + └─→ Default Handler (if not registered) + └─→ SageMaker SessionManager +``` + +## Handler Signatures + +Both handlers must be async functions that accept a FastAPI `Request` object: + +```python +from fastapi import Request, Response + +async def my_create_session_handler(raw_request: Request) -> Response: + """Create a new session via the engine's API.""" + pass + +async def my_close_session_handler(raw_request: Request) -> Response: + """Close an existing session via the engine's API.""" + pass +``` + +## Using Transform Classes + + +```python +from model_hosting_container_standards.sagemaker.sessions.transforms import ( + CreateSessionApiTransform, + CloseSessionApiTransform +) + +# Define request/response shapes using JMESPath +create_transform = CreateSessionApiTransform( + request_shape={}, # Transform incoming request + response_shape={ + "X-Amzn-SageMaker-New-Session-Id": "session_id", + "content": "message" + } +) + +close_transform = CloseSessionApiTransform( + request_shape={ + "session_id": "headers.'X-Amzn-SageMaker-Session-Id'" + }, + response_shape={ + "content": "message" + } +) +``` + +## Best Practices + +1. **Validate session IDs**: Always validate that the engine returns valid session IDs +2. **Handle timeouts**: Set appropriate timeouts when calling engine APIs +3. **Log operations**: Log session creation/closure for debugging +4. **Error propagation**: Provide clear error messages when engine operations fail +5. **Cleanup**: Ensure sessions are properly cleaned up even on errors +6. **Testing**: Test both success and failure scenarios +7. **Idempotency**: Handle duplicate close requests gracefully + +## Utilities + +The framework provides utility functions for working with sessions: + +```python +from model_hosting_container_standards.sagemaker.sessions.utils import ( + get_session_id_from_request, # Extract session ID from headers + get_session, # Get session from manager +) +from model_hosting_container_standards.sagemaker.sessions.models import ( + SageMakerSessionHeader, # Header name constants + SessionRequestType, # Request type enum +) +``` + +## See Also + +- [README.md](./README.md) - Main sessions documentation +- [handlers.py](./handlers.py) - Default handler implementations +- [transforms/](./transforms/) - Transform classes for engine integration diff --git a/python/model_hosting_container_standards/sagemaker/sessions/transforms/base_engine_session_api_transform.py b/python/model_hosting_container_standards/sagemaker/sessions/transforms/base_engine_session_api_transform.py index 29820d2..62ab3ae 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/transforms/base_engine_session_api_transform.py +++ b/python/model_hosting_container_standards/sagemaker/sessions/transforms/base_engine_session_api_transform.py @@ -7,7 +7,6 @@ from pydantic import BaseModel from ....common import BaseApiTransform, BaseTransformRequestOutput -from ....logging_config import logger class BaseEngineSessionApiTransform(BaseApiTransform): From 39a1af29332f4ab7ef59663559d7764277dcd707 Mon Sep 17 00:00:00 2001 From: Zuyi Zhao Date: Wed, 3 Dec 2025 19:45:58 +0000 Subject: [PATCH 13/25] wip - update stateful sessions manager to move sm id header to target --- .../sagemaker/sessions/transform.py | 116 ++++++++++-------- 1 file changed, 63 insertions(+), 53 deletions(-) diff --git a/python/model_hosting_container_standards/sagemaker/sessions/transform.py b/python/model_hosting_container_standards/sagemaker/sessions/transform.py index b30a4fc..e17d7c5 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/transform.py +++ b/python/model_hosting_container_standards/sagemaker/sessions/transform.py @@ -86,6 +86,7 @@ def _validate_session_if_present( if session_id: try: get_session(session_manager, raw_request) + return session_id except ValueError as e: raise HTTPException( status_code=HTTPStatus.BAD_REQUEST.value, @@ -93,56 +94,6 @@ def _validate_session_if_present( ) -def process_session_request( - request_data: dict, raw_request: Request, session_manager: Optional[SessionManager] -): - """Process a potential session management request. - - Determines if the request is a session management operation (NEW_SESSION or CLOSE) - and routes it to the appropriate handler, or passes through for normal processing. - - Args: - request_data: Parsed JSON request body - raw_request: FastAPI Request object - session_manager: SessionManager instance - - Returns: - BaseTransformRequestOutput with either: - - intercept_func set if this is a session management request - - None/passthrough if this is a regular request - - Raises: - HTTPException: If request is malformed or session validation fails - """ - session_request = _parse_session_request(request_data) - should_validate = _should_validate_session(session_request) - if should_validate: - # Validate session if session ID is present in headers - # and raise error if session ID is invalid - _validate_session_if_present(raw_request, session_manager) - - # Not a session request - pass through for normal processing - if session_request is None: - return BaseTransformRequestOutput( - raw_request=raw_request, - intercept_func=None, - ) - - if should_validate and session_manager is None: - logger.error(SESSION_DISABLED_LOG_MESSAGE) - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST.value, - detail=SESSION_DISABLED_ERROR_DETAIL, - ) - - # Route to appropriate session management handler - intercept_func = get_handler_for_request_type(session_request.requestType) - - return BaseTransformRequestOutput( - raw_request=raw_request, intercept_func=intercept_func - ) - - class SessionApiTransform(BaseApiTransform): """API transform that intercepts and processes stateful session management requests. @@ -155,8 +106,8 @@ def __init__(self, request_shape, response_shape={}): """Initialize the SessionApiTransform. Args: - request_shape: Passed to parent BaseApiTransform (unused in session logic) - response_shape: Passed to parent BaseApiTransform (unused in session logic) + request_shape: Passed to parent BaseApiTransform + response_shape: Passed to parent BaseApiTransform Note: The request/response shapes are passed to the parent class but not used @@ -183,7 +134,7 @@ async def transform_request(self, raw_request): """ try: request_data = await raw_request.json() - return process_session_request( + return self._process_session_request( request_data, raw_request, self._session_manager ) except json.JSONDecodeError as e: @@ -203,3 +154,62 @@ def transform_response(self, response, transform_request_output): The unmodified response object """ return response + + + def _process_invocations_request( + self, session_id: Optional[str], request_data: dict, raw_request: Request + ): + # if session_id: + # TODO: move session id to location based on request shape + return BaseTransformRequestOutput( + raw_request=raw_request, + intercept_func=None, + ) + + + def _process_session_request( + self, request_data: dict, raw_request: Request, session_manager: Optional[SessionManager] + ): + """Process a potential session management request. + + Determines if the request is a session management operation (NEW_SESSION or CLOSE) + and routes it to the appropriate handler, or passes through for normal processing. + + Args: + request_data: Parsed JSON request body + raw_request: FastAPI Request object + session_manager: SessionManager instance + + Returns: + BaseTransformRequestOutput with either: + - intercept_func set if this is a session management request + - None/passthrough if this is a regular request + + Raises: + HTTPException: If request is malformed or session validation fails + """ + session_request = _parse_session_request(request_data) + should_validate = _should_validate_session(session_request) + session_id = None + if should_validate: + # Validate session if session ID is present in headers + # and raise error if session ID is invalid + session_id = _validate_session_if_present(raw_request, session_manager) + + # Not a session request - pass through for normal processing + if session_request is None: + self._process_invocations_request(request_data, raw_request) + + if should_validate and session_manager is None: + logger.error(SESSION_DISABLED_LOG_MESSAGE) + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail=SESSION_DISABLED_ERROR_DETAIL, + ) + + # Route to appropriate session management handler + intercept_func = get_handler_for_request_type(session_request.requestType) + + return BaseTransformRequestOutput( + raw_request=raw_request, intercept_func=intercept_func + ) \ No newline at end of file From c77013e727094af33c077281b80dbe4f96f24542 Mon Sep 17 00:00:00 2001 From: Zuyi Zhao Date: Thu, 4 Dec 2025 19:32:05 +0000 Subject: [PATCH 14/25] fix(sessions): Fix session ID injection and update tests - Fix session ID injection logic to work with default session manager - Changed conditional from elif to separate if statements - Session ID now properly injected into request body when session_id_path is specified - Validation and injection can now happen independently - Update unit tests to match refactored transform API - Remove import of non-existent _validate_session_if_present function - Update test signatures to match _process_session_request parameters - Add missing SessionRequest import - Replace outdated test logic with current implementation - Add integration tests for session_id_path injection feature - Test session ID injection into flat and nested request body paths - Test injection with multiple requests and different sessions - Verify existing body fields are preserved during injection --- .../sagemaker/__init__.py | 12 +- .../sagemaker/sessions/transform.py | 170 ++++++-------- .../test_sagemaker_sessions_integration.py | 222 +++++++++++++++++- .../sagemaker/sessions/test_transform.py | 192 +++++++-------- 4 files changed, 394 insertions(+), 202 deletions(-) diff --git a/python/model_hosting_container_standards/sagemaker/__init__.py b/python/model_hosting_container_standards/sagemaker/__init__.py index a15795c..2d38fba 100644 --- a/python/model_hosting_container_standards/sagemaker/__init__.py +++ b/python/model_hosting_container_standards/sagemaker/__init__.py @@ -25,6 +25,7 @@ create_session_transform_decorator, register_engine_session_handler, ) +from .sessions.models import SageMakerSessionHeader # SageMaker decorator instances - created using utility functions @@ -121,7 +122,7 @@ def inject_adapter_id( ) -def stateful_session_manager(): +def stateful_session_manager(session_id_path: Optional[str] = None): """Create a decorator for session-based sticky routing. This decorator enables stateful session management without JMESPath transformations. @@ -131,7 +132,14 @@ def stateful_session_manager(): Returns: A decorator that can be applied to route handlers to enable session management """ - return create_session_transform_decorator()(request_shape={}, response_shape={}) + request_shape = {} + if session_id_path: + request_shape[session_id_path] = ( + f'headers."{SageMakerSessionHeader.SESSION_ID}"' + ) + return create_session_transform_decorator()( + request_shape=request_shape, response_shape={} + ) def register_create_session_handler( diff --git a/python/model_hosting_container_standards/sagemaker/sessions/transform.py b/python/model_hosting_container_standards/sagemaker/sessions/transform.py index e17d7c5..a31d721 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/transform.py +++ b/python/model_hosting_container_standards/sagemaker/sessions/transform.py @@ -8,12 +8,14 @@ from ...common import BaseApiTransform, BaseTransformRequestOutput from ...common.handler import handler_registry +from ...common.transforms.utils import set_value from ...logging_config import logger from .handlers import get_handler_for_request_type from .manager import SessionManager, get_session_manager from .models import ( SESSION_DISABLED_ERROR_DETAIL, SESSION_DISABLED_LOG_MESSAGE, + SageMakerSessionHeader, SessionRequest, SessionRequestType, ) @@ -45,63 +47,7 @@ def _parse_session_request(request_data: dict) -> Optional[SessionRequest]: return None -def _should_validate_session(session_request: Optional[SessionRequest]) -> bool: - """Determine if session validation is needed for the given request. - - Session validation is required when: - - Request is a NEW_SESSION without a custom create_session handler - - Request is a CLOSE without a custom close_session handler - - Args: - session_request: Parsed session request, or None if not a session request - - Returns: - True if session validation should be performed, False otherwise - """ - if not session_request: - return False - - if session_request.requestType == SessionRequestType.NEW_SESSION: - return not handler_registry.has_handler("create_session") - - if session_request.requestType == SessionRequestType.CLOSE: - return not handler_registry.has_handler("close_session") - - return False - - -def _validate_session_if_present( - raw_request: Request, session_manager: Optional[SessionManager] -): - """Validate that the session ID in the request exists and is not expired. - - Args: - raw_request: FastAPI Request object - session_manager: Optional SessionManager instance - - Raises: - HTTPException: If session validation fails - """ - session_id = get_session_id_from_request(raw_request) - if session_id: - try: - get_session(session_manager, raw_request) - return session_id - except ValueError as e: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST.value, - detail=f"Bad request: {str(e)}", - ) - - class SessionApiTransform(BaseApiTransform): - """API transform that intercepts and processes stateful session management requests. - - This transform extends BaseApiTransform to add session management capabilities. - It parses incoming requests to detect session management operations (NEW_SESSION, CLOSE) - and routes them to appropriate handlers, while passing through regular API requests. - """ - def __init__(self, request_shape, response_shape={}): """Initialize the SessionApiTransform. @@ -113,9 +59,24 @@ def __init__(self, request_shape, response_shape={}): The request/response shapes are passed to the parent class but not used for validation in this transform, as session requests use their own validation. """ + # Use default session manager if no custom create / close session handlers are registered + self._use_default_manager = not handler_registry.has_handler( + "create_session" + ) and not handler_registry.has_handler("close_session") self._session_manager = get_session_manager() + + # Extract session_id_target_key before compiling JMESPath expressions + self._session_id_target_key = self._get_session_id_target_key(request_shape) super().__init__(request_shape, response_shape) + def _get_session_id_target_key(self, request_shape: dict) -> Optional[str]: + if not request_shape: + return None + for target_key, source_path in request_shape.items(): + if source_path == f'headers."{SageMakerSessionHeader.SESSION_ID}"': + return target_key + return None + async def transform_request(self, raw_request): """Transform incoming request, intercepting session management operations. @@ -134,7 +95,7 @@ async def transform_request(self, raw_request): """ try: request_data = await raw_request.json() - return self._process_session_request( + return self._process_request( request_data, raw_request, self._session_manager ) except json.JSONDecodeError as e: @@ -155,61 +116,78 @@ def transform_response(self, response, transform_request_output): """ return response - + def _validate_session_id(self, session_id: Optional[str], raw_request: Request): + """Validate that the session ID in the request exists and is not expired. + + Raises: + HTTPException: If session validation fails + """ + try: + get_session(self._session_manager, raw_request) + return session_id + except ValueError as e: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail=f"Bad request: {str(e)}", + ) + def _process_invocations_request( self, session_id: Optional[str], request_data: dict, raw_request: Request ): - # if session_id: - # TODO: move session id to location based on request shape + # If not a session request + if session_id and self._use_default_manager: + # but it has a session id header and we are using the default session manager, + # then we need to validate that the session id exists in the session manager + self._validate_session_id(session_id, raw_request) + + # Inject session ID into request body if target key is specified + if session_id and self._session_id_target_key: + request_data = set_value( + request_data, + self._session_id_target_key, + session_id, + ) + logger.debug(f"Updated request body: {request_data}") + raw_request._body = json.dumps(request_data).encode("utf-8") + return BaseTransformRequestOutput( raw_request=raw_request, intercept_func=None, ) - - def _process_session_request( - self, request_data: dict, raw_request: Request, session_manager: Optional[SessionManager] - ): - """Process a potential session management request. - - Determines if the request is a session management operation (NEW_SESSION or CLOSE) - and routes it to the appropriate handler, or passes through for normal processing. - - Args: - request_data: Parsed JSON request body - raw_request: FastAPI Request object - session_manager: SessionManager instance - - Returns: - BaseTransformRequestOutput with either: - - intercept_func set if this is a session management request - - None/passthrough if this is a regular request - - Raises: - HTTPException: If request is malformed or session validation fails - """ - session_request = _parse_session_request(request_data) - should_validate = _should_validate_session(session_request) - session_id = None - if should_validate: - # Validate session if session ID is present in headers - # and raise error if session ID is invalid - session_id = _validate_session_if_present(raw_request, session_manager) - - # Not a session request - pass through for normal processing - if session_request is None: - self._process_invocations_request(request_data, raw_request) - - if should_validate and session_manager is None: + def _process_session_request(self, session_request, session_id, raw_request): + # Validation + if self._use_default_manager and not self._session_manager: + # if no custom handlers are registered, but default session manager + # does not exist -> then raise error that session management is disabled logger.error(SESSION_DISABLED_LOG_MESSAGE) raise HTTPException( status_code=HTTPStatus.BAD_REQUEST.value, detail=SESSION_DISABLED_ERROR_DETAIL, ) + elif self._use_default_manager and self._session_manager: + if session_request.requestType == SessionRequestType.NEW_SESSION: + # Ignores any session id header in create session request + session_id = SessionRequestType.NEW_SESSION + session_id = self._validate_session_id(session_id, raw_request) # Route to appropriate session management handler intercept_func = get_handler_for_request_type(session_request.requestType) return BaseTransformRequestOutput( raw_request=raw_request, intercept_func=intercept_func - ) \ No newline at end of file + ) + + def _process_request( + self, request_data, raw_request, session_manager: Optional[SessionManager] + ): + session_request = _parse_session_request(request_data) + session_id = get_session_id_from_request(raw_request) + if not session_request: + return self._process_invocations_request( + session_id, request_data, raw_request + ) + else: + return self._process_session_request( + session_request, session_id, raw_request + ) diff --git a/python/tests/integration/test_sagemaker_sessions_integration.py b/python/tests/integration/test_sagemaker_sessions_integration.py index baddf77..ce35bcc 100644 --- a/python/tests/integration/test_sagemaker_sessions_integration.py +++ b/python/tests/integration/test_sagemaker_sessions_integration.py @@ -635,5 +635,223 @@ def test_regular_requests_with_session_header_when_disabled( assert SESSION_DISABLED_ERROR_DETAIL in response.text -if __name__ == "__main__": - pytest.main([__file__, "-v"]) +class TestSessionIdPathInjection(BaseSessionIntegrationTest): + """Test session_id_path parameter for injecting session ID into request body.""" + + def setup_handlers(self): + """Define handlers with session_id_path parameter.""" + + @self.router.post("/invocations-with-path") + @sagemaker_standards.stateful_session_manager(session_id_path="session_id") + async def invocations_with_path(request: Request): + """Handler that injects session ID into request body at 'session_id' key.""" + body_bytes = await request.body() + body = json.loads(body_bytes.decode()) + + # Capture for test verification + self.capture.capture( + "invocation_with_path", body.get("session_id"), {"body": body} + ) + + return Response( + status_code=200, + content=json.dumps( + { + "message": "success", + "session_id_from_body": body.get("session_id"), + "echo": body, + } + ), + ) + + @self.router.post("/invocations-nested-path") + @sagemaker_standards.stateful_session_manager( + session_id_path="metadata.session_id" + ) + async def invocations_nested_path(request: Request): + """Handler that injects session ID into nested path in request body.""" + body_bytes = await request.body() + body = json.loads(body_bytes.decode()) + + # Capture for test verification + session_id = ( + body.get("metadata", {}).get("session_id") + if isinstance(body.get("metadata"), dict) + else None + ) + self.capture.capture("invocation_nested_path", session_id, {"body": body}) + + return Response( + status_code=200, + content=json.dumps( + { + "message": "success", + "session_id_from_body": session_id, + "echo": body, + } + ), + ) + + def test_session_id_injected_into_body(self): + """Test that session ID from header is injected into request body.""" + # Create a session + create_response = self.client.post( + "/invocations-with-path", json={"requestType": "NEW_SESSION"} + ) + session_id = extract_session_id_from_header( + create_response.headers[SageMakerSessionHeader.NEW_SESSION_ID] + ) + + # Make request with session ID in header + self.capture.clear() + response = self.client.post( + "/invocations-with-path", + json={"prompt": "test request"}, + headers={SageMakerSessionHeader.SESSION_ID: session_id}, + ) + + assert response.status_code == 200 + data = json.loads(response.text) + + # Verify session ID was injected into body + assert data["session_id_from_body"] == session_id + assert data["echo"]["session_id"] == session_id + assert data["echo"]["prompt"] == "test request" + + def test_session_id_injected_into_nested_path(self): + """Test that session ID is injected into nested path in request body.""" + # Create a session + create_response = self.client.post( + "/invocations-nested-path", json={"requestType": "NEW_SESSION"} + ) + session_id = extract_session_id_from_header( + create_response.headers[SageMakerSessionHeader.NEW_SESSION_ID] + ) + + # Make request with session ID in header + self.capture.clear() + response = self.client.post( + "/invocations-nested-path", + json={"prompt": "test request", "metadata": {"user": "test"}}, + headers={SageMakerSessionHeader.SESSION_ID: session_id}, + ) + + assert response.status_code == 200 + data = json.loads(response.text) + + # Verify session ID was injected into nested path + assert data["session_id_from_body"] == session_id + assert data["echo"]["metadata"]["session_id"] == session_id + assert data["echo"]["metadata"]["user"] == "test" + assert data["echo"]["prompt"] == "test request" + + def test_session_id_not_injected_without_header(self): + """Test that session ID is not injected when header is not present.""" + response = self.client.post( + "/invocations-with-path", + json={"prompt": "test request"}, + # No session header + ) + + assert response.status_code == 200 + data = json.loads(response.text) + + # Verify session ID was not injected + assert data["session_id_from_body"] is None + assert "session_id" not in data["echo"] or data["echo"]["session_id"] is None + + def test_session_id_injection_with_multiple_requests(self): + """Test that session ID injection works across multiple requests.""" + # Create a session + create_response = self.client.post( + "/invocations-with-path", json={"requestType": "NEW_SESSION"} + ) + session_id = extract_session_id_from_header( + create_response.headers[SageMakerSessionHeader.NEW_SESSION_ID] + ) + + # Make multiple requests with the same session ID + for i in range(3): + response = self.client.post( + "/invocations-with-path", + json={"prompt": f"request {i+1}"}, + headers={SageMakerSessionHeader.SESSION_ID: session_id}, + ) + + assert response.status_code == 200 + data = json.loads(response.text) + assert data["session_id_from_body"] == session_id + + def test_different_sessions_inject_different_ids(self): + """Test that different sessions inject their respective IDs.""" + # Create two sessions + create1 = self.client.post( + "/invocations-with-path", json={"requestType": "NEW_SESSION"} + ) + session1_id = extract_session_id_from_header( + create1.headers[SageMakerSessionHeader.NEW_SESSION_ID] + ) + + create2 = self.client.post( + "/invocations-with-path", json={"requestType": "NEW_SESSION"} + ) + session2_id = extract_session_id_from_header( + create2.headers[SageMakerSessionHeader.NEW_SESSION_ID] + ) + + # Make requests with each session + response1 = self.client.post( + "/invocations-with-path", + json={"prompt": "session 1"}, + headers={SageMakerSessionHeader.SESSION_ID: session1_id}, + ) + response2 = self.client.post( + "/invocations-with-path", + json={"prompt": "session 2"}, + headers={SageMakerSessionHeader.SESSION_ID: session2_id}, + ) + + # Verify each request got the correct session ID + data1 = json.loads(response1.text) + data2 = json.loads(response2.text) + + assert data1["session_id_from_body"] == session1_id + assert data2["session_id_from_body"] == session2_id + assert session1_id != session2_id + + def test_session_id_injection_preserves_existing_body_fields(self): + """Test that session ID injection doesn't overwrite other body fields.""" + # Create a session + create_response = self.client.post( + "/invocations-with-path", json={"requestType": "NEW_SESSION"} + ) + session_id = extract_session_id_from_header( + create_response.headers[SageMakerSessionHeader.NEW_SESSION_ID] + ) + + # Make request with multiple body fields + original_body = { + "prompt": "test", + "temperature": 0.7, + "max_tokens": 100, + "metadata": {"user": "test_user", "request_id": "123"}, + } + + response = self.client.post( + "/invocations-with-path", + json=original_body, + headers={SageMakerSessionHeader.SESSION_ID: session_id}, + ) + + assert response.status_code == 200 + data = json.loads(response.text) + + # Verify session ID was added + assert data["echo"]["session_id"] == session_id + + # Verify all original fields are preserved + assert data["echo"]["prompt"] == "test" + assert data["echo"]["temperature"] == 0.7 + assert data["echo"]["max_tokens"] == 100 + assert data["echo"]["metadata"]["user"] == "test_user" + assert data["echo"]["metadata"]["request_id"] == "123" diff --git a/python/tests/sagemaker/sessions/test_transform.py b/python/tests/sagemaker/sessions/test_transform.py index 70f0e53..080b08c 100644 --- a/python/tests/sagemaker/sessions/test_transform.py +++ b/python/tests/sagemaker/sessions/test_transform.py @@ -16,13 +16,12 @@ from model_hosting_container_standards.sagemaker.sessions.manager import SessionManager from model_hosting_container_standards.sagemaker.sessions.models import ( SageMakerSessionHeader, + SessionRequest, SessionRequestType, ) from model_hosting_container_standards.sagemaker.sessions.transform import ( SessionApiTransform, _parse_session_request, - _validate_session_if_present, - process_session_request, ) @@ -82,165 +81,154 @@ def test_raises_http_exception_for_extra_fields(self): assert exc_info.value.status_code == HTTPStatus.BAD_REQUEST.value -class TestValidateSessionIfPresent: - """Test _validate_session_if_present function.""" +class TestValidateSessionId: + """Test _validate_session_id method.""" - def test_does_not_raise_when_no_session_id_present( - self, mock_request, mock_session_manager - ): - """Test does not raise exception when no session ID in request.""" - # Should not raise any exception - _validate_session_if_present(mock_request, mock_session_manager) - - def test_does_not_raise_when_session_id_valid(self, mock_session_manager): + def test_does_not_raise_when_session_id_valid(self, enable_sessions_env): """Test does not raise exception when session ID is valid.""" + transform = SessionApiTransform(request_shape={}, response_shape={}) mock_request = Mock(spec=Request) mock_request.headers = {SageMakerSessionHeader.SESSION_ID: "valid-session"} with patch( - "model_hosting_container_standards.sagemaker.sessions.transform.get_session_id_from_request" - ) as mock_get_id: - mock_get_id.return_value = "valid-session" - - with patch( - "model_hosting_container_standards.sagemaker.sessions.transform.get_session" - ) as mock_get_session: - mock_session = Mock() - mock_get_session.return_value = mock_session + "model_hosting_container_standards.sagemaker.sessions.transform.get_session" + ) as mock_get_session: + mock_session = Mock() + mock_get_session.return_value = mock_session - # Should not raise any exception - _validate_session_if_present(mock_request, mock_session_manager) + # Should not raise any exception + result = transform._validate_session_id("valid-session", mock_request) + assert result == "valid-session" - def test_raises_http_exception_when_session_not_found(self, mock_session_manager): + def test_raises_http_exception_when_session_not_found(self, enable_sessions_env): """Test raises HTTPException when session ID not found.""" + transform = SessionApiTransform(request_shape={}, response_shape={}) mock_request = Mock(spec=Request) mock_request.headers = { SageMakerSessionHeader.SESSION_ID: "nonexistent-session" } with patch( - "model_hosting_container_standards.sagemaker.sessions.transform.get_session_id_from_request" - ) as mock_get_id: - mock_get_id.return_value = "nonexistent-session" - - with patch( - "model_hosting_container_standards.sagemaker.sessions.transform.get_session" - ) as mock_get_session: - mock_get_session.side_effect = ValueError("session not found") + "model_hosting_container_standards.sagemaker.sessions.transform.get_session" + ) as mock_get_session: + mock_get_session.side_effect = ValueError("session not found") - with pytest.raises(HTTPException) as exc_info: - _validate_session_if_present(mock_request, mock_session_manager) + with pytest.raises(HTTPException) as exc_info: + transform._validate_session_id("nonexistent-session", mock_request) - assert exc_info.value.status_code == HTTPStatus.BAD_REQUEST.value + assert exc_info.value.status_code == HTTPStatus.BAD_REQUEST.value - def test_error_message_includes_original_error(self, mock_session_manager): + def test_error_message_includes_original_error(self, enable_sessions_env): """Test error message includes the original error message.""" + transform = SessionApiTransform(request_shape={}, response_shape={}) mock_request = Mock(spec=Request) mock_request.headers = {SageMakerSessionHeader.SESSION_ID: "bad-session"} with patch( - "model_hosting_container_standards.sagemaker.sessions.transform.get_session_id_from_request" - ) as mock_get_id: - mock_get_id.return_value = "bad-session" - - with patch( - "model_hosting_container_standards.sagemaker.sessions.transform.get_session" - ) as mock_get_session: - mock_get_session.side_effect = ValueError("custom error message") + "model_hosting_container_standards.sagemaker.sessions.transform.get_session" + ) as mock_get_session: + mock_get_session.side_effect = ValueError("custom error message") - with pytest.raises(HTTPException) as exc_info: - _validate_session_if_present(mock_request, mock_session_manager) + with pytest.raises(HTTPException) as exc_info: + transform._validate_session_id("bad-session", mock_request) - assert "custom error message" in exc_info.value.detail + assert "custom error message" in exc_info.value.detail class TestProcessSessionRequest: - """Test process_session_request function.""" - - def test_returns_passthrough_for_non_session_request( - self, mock_request, mock_session_manager - ): - """Test returns passthrough output for non-session request.""" - request_data = {"data": "regular_data"} + """Test _process_session_request method.""" - result = process_session_request( - request_data, mock_request, mock_session_manager - ) - - assert isinstance(result, BaseTransformRequestOutput) - assert result.request is None - assert result.raw_request == mock_request - assert result.intercept_func is None + @pytest.fixture + def transform(self, enable_sessions_env): + """Create SessionApiTransform instance.""" + return SessionApiTransform(request_shape={}, response_shape={}) def test_returns_create_handler_for_new_session_request( - self, mock_request, mock_session_manager + self, transform, mock_request ): """Test returns create_session handler for NEW_SESSION request.""" - request_data = {"requestType": "NEW_SESSION"} + session_request = SessionRequest(requestType=SessionRequestType.NEW_SESSION) - result = process_session_request( - request_data, mock_request, mock_session_manager - ) + result = transform._process_session_request(session_request, None, mock_request) assert isinstance(result, BaseTransformRequestOutput) - assert result.request is None assert result.raw_request == mock_request assert result.intercept_func == create_session - def test_returns_close_handler_for_close_request( - self, mock_request, mock_session_manager - ): + def test_returns_close_handler_for_close_request(self, transform, mock_request): """Test returns close_session handler for CLOSE request.""" - request_data = {"requestType": "CLOSE"} + session_request = SessionRequest(requestType=SessionRequestType.CLOSE) + mock_request.headers = {SageMakerSessionHeader.SESSION_ID: "test-session"} - result = process_session_request( - request_data, mock_request, mock_session_manager - ) + with patch( + "model_hosting_container_standards.sagemaker.sessions.transform.get_session" + ) as mock_get_session: + mock_session = Mock() + mock_get_session.return_value = mock_session - assert isinstance(result, BaseTransformRequestOutput) - assert result.request is None - assert result.raw_request == mock_request - assert result.intercept_func == close_session + result = transform._process_session_request( + session_request, "test-session", mock_request + ) + + assert isinstance(result, BaseTransformRequestOutput) + assert result.raw_request == mock_request + assert result.intercept_func == close_session - def test_validates_session_if_session_id_present(self, mock_session_manager): + def test_validates_session_if_session_id_present(self, transform): """Test validates session when session ID is present in headers.""" - request_data = {"data": "regular_data"} mock_request = Mock(spec=Request) mock_request.headers = {SageMakerSessionHeader.SESSION_ID: "test-session"} with patch( - "model_hosting_container_standards.sagemaker.sessions.transform._validate_session_if_present" - ) as mock_validate: - process_session_request(request_data, mock_request, mock_session_manager) + "model_hosting_container_standards.sagemaker.sessions.transform.get_session" + ) as mock_get_session: + mock_session = Mock() + mock_get_session.return_value = mock_session + + transform._process_session_request( + SessionRequest(requestType=SessionRequestType.CLOSE), + "test-session", + mock_request, + ) - mock_validate.assert_called_once_with(mock_request, mock_session_manager) + # Should validate the session + mock_get_session.assert_called_once() - def test_raises_exception_for_invalid_session_request( - self, mock_request, mock_session_manager + def test_raises_exception_when_sessions_disabled( + self, mock_request, monkeypatch, temp_session_storage ): - """Test raises HTTPException for invalid session request.""" - request_data = {"requestType": "INVALID_TYPE"} + """Test raises HTTPException when sessions are disabled.""" + # Disable sessions + monkeypatch.delenv("SAGEMAKER_ENABLE_STATEFUL_SESSIONS", raising=False) + from model_hosting_container_standards.sagemaker.sessions.manager import ( + init_session_manager_from_env, + ) - with pytest.raises(HTTPException): - process_session_request(request_data, mock_request, mock_session_manager) + init_session_manager_from_env() + + transform = SessionApiTransform(request_shape={}, response_shape={}) + session_request = SessionRequest(requestType=SessionRequestType.NEW_SESSION) - def test_propagates_validation_errors(self, mock_session_manager): - """Test propagates validation errors from _validate_session_if_present.""" - request_data = {"data": "regular_data"} + with pytest.raises(HTTPException) as exc_info: + transform._process_session_request(session_request, None, mock_request) + + assert exc_info.value.status_code == HTTPStatus.BAD_REQUEST.value + + def test_propagates_validation_errors(self, transform): + """Test propagates validation errors from session validation.""" mock_request = Mock(spec=Request) mock_request.headers = {SageMakerSessionHeader.SESSION_ID: "invalid-session"} with patch( - "model_hosting_container_standards.sagemaker.sessions.transform._validate_session_if_present" - ) as mock_validate: - mock_validate.side_effect = HTTPException( - status_code=HTTPStatus.BAD_REQUEST.value, - detail="Session validation failed", - ) + "model_hosting_container_standards.sagemaker.sessions.transform.get_session" + ) as mock_get_session: + mock_get_session.side_effect = ValueError("Session not found") with pytest.raises(HTTPException) as exc_info: - process_session_request( - request_data, mock_request, mock_session_manager + transform._process_session_request( + SessionRequest(requestType=SessionRequestType.CLOSE), + "invalid-session", + mock_request, ) assert exc_info.value.status_code == HTTPStatus.BAD_REQUEST.value From 454ad34bdd423af9a070a0cfe10c176a60f80fa6 Mon Sep 17 00:00:00 2001 From: Zuyi Zhao Date: Thu, 4 Dec 2025 20:09:07 +0000 Subject: [PATCH 15/25] Add unit tests --- .../sagemaker/sessions/test_registration.py | 127 ++++++++++ .../sagemaker/sessions/transforms/__init__.py | 1 + .../test_close_session_transform.py | 178 ++++++++++++++ .../test_create_session_transform.py | 226 ++++++++++++++++++ 4 files changed, 532 insertions(+) create mode 100644 python/tests/sagemaker/sessions/test_registration.py create mode 100644 python/tests/sagemaker/sessions/transforms/__init__.py create mode 100644 python/tests/sagemaker/sessions/transforms/test_close_session_transform.py create mode 100644 python/tests/sagemaker/sessions/transforms/test_create_session_transform.py diff --git a/python/tests/sagemaker/sessions/test_registration.py b/python/tests/sagemaker/sessions/test_registration.py new file mode 100644 index 0000000..d1b053f --- /dev/null +++ b/python/tests/sagemaker/sessions/test_registration.py @@ -0,0 +1,127 @@ +"""Unit tests for session handler registration functions.""" + +import pytest + +from model_hosting_container_standards.sagemaker.sessions import ( + register_engine_session_handler, +) +from model_hosting_container_standards.sagemaker.sessions.models import ( + SageMakerSessionHeader, +) +from model_hosting_container_standards.sagemaker.sessions.transforms.constants import ( + RESPONSE_CONTENT_KEY, +) + + +class TestRegisterEngineSessionHandler: + """Test register_engine_session_handler function.""" + + def test_create_session_requires_session_id_path(self): + """Test that create_session requires session_id_path parameter.""" + with pytest.raises(ValueError) as exc_info: + register_engine_session_handler( + handler_type="create_session", + request_shape={}, + session_id_path=None, + content_path="message", + ) + + assert "session_id_path is required" in str(exc_info.value) + + def test_create_session_with_valid_params(self): + """Test successful create_session registration.""" + decorator = register_engine_session_handler( + handler_type="create_session", + request_shape={"model": "body.model"}, + session_id_path="session_id", + content_path="message", + ) + + assert decorator is not None + assert callable(decorator) + + def test_close_session_without_session_id_path(self): + """Test that close_session doesn't require session_id_path.""" + decorator = register_engine_session_handler( + handler_type="close_session", + request_shape={}, + content_path="message", + ) + + assert decorator is not None + assert callable(decorator) + + def test_invalid_handler_type(self): + """Test that invalid handler_type raises ValueError.""" + with pytest.raises(ValueError) as exc_info: + register_engine_session_handler( + handler_type="invalid_type", + request_shape={}, + ) + + assert "Invalid handler_type" in str(exc_info.value) + assert "create_session" in str(exc_info.value) + assert "close_session" in str(exc_info.value) + + def test_adds_body_prefix_to_paths(self): + """Test that body. prefix is automatically added to response paths.""" + # This is tested indirectly - the decorator should work with paths + # relative to the handler's return value, not the serialized response + decorator = register_engine_session_handler( + handler_type="create_session", + request_shape={}, + session_id_path="id", # Should become body.id internally + content_path="message", # Should become body.message internally + ) + + assert decorator is not None + + def test_preserves_body_prefix_if_present(self): + """Test that existing body. prefix is not duplicated.""" + decorator = register_engine_session_handler( + handler_type="create_session", + request_shape={}, + session_id_path="body.id", # Already has body. prefix + content_path="body.message", + ) + + assert decorator is not None + + +class TestResponseShapeConstruction: + """Test that response_shape is constructed correctly.""" + + def test_create_session_response_shape_has_required_keys(self): + """Test that create_session response_shape includes session ID and content.""" + # We can't directly inspect the response_shape, but we can verify + # the decorator is created successfully with the right parameters + decorator = register_engine_session_handler( + handler_type="create_session", + request_shape={}, + session_id_path="session.id", + content_path="session.message", + ) + + # If this doesn't raise, the response_shape was constructed correctly + assert decorator is not None + + def test_close_session_response_shape_has_content_key(self): + """Test that close_session response_shape includes content.""" + decorator = register_engine_session_handler( + handler_type="close_session", + request_shape={}, + content_path="result.message", + ) + + assert decorator is not None + + def test_none_content_path_is_handled(self): + """Test that None content_path is handled correctly.""" + decorator = register_engine_session_handler( + handler_type="close_session", + request_shape={}, + content_path=None, + ) + + # Should still create decorator, content extraction will just return None + assert decorator is not None diff --git a/python/tests/sagemaker/sessions/transforms/__init__.py b/python/tests/sagemaker/sessions/transforms/__init__.py new file mode 100644 index 0000000..98a045f --- /dev/null +++ b/python/tests/sagemaker/sessions/transforms/__init__.py @@ -0,0 +1 @@ +"""Tests for session transforms.""" diff --git a/python/tests/sagemaker/sessions/transforms/test_close_session_transform.py b/python/tests/sagemaker/sessions/transforms/test_close_session_transform.py new file mode 100644 index 0000000..08c4e5c --- /dev/null +++ b/python/tests/sagemaker/sessions/transforms/test_close_session_transform.py @@ -0,0 +1,178 @@ +"""Unit tests for CloseSessionApiTransform.""" + +import json +from http import HTTPStatus +from unittest.mock import AsyncMock, Mock + +import pytest +from fastapi import Request, Response +from fastapi.exceptions import HTTPException + +from model_hosting_container_standards.common import BaseTransformRequestOutput +from model_hosting_container_standards.sagemaker.sessions.models import ( + SageMakerSessionHeader, +) +from model_hosting_container_standards.sagemaker.sessions.transforms.close_session import ( + CloseSessionApiTransform, +) +from model_hosting_container_standards.sagemaker.sessions.transforms.constants import ( + RESPONSE_CONTENT_KEY, +) + + +class TestCloseSessionInitialization: + """Test CloseSessionApiTransform initialization.""" + + def test_requires_content_in_response_shape(self): + """Test that initialization requires RESPONSE_CONTENT_KEY in response_shape.""" + with pytest.raises(ValueError) as exc_info: + CloseSessionApiTransform(request_shape={}, response_shape={}) + + assert RESPONSE_CONTENT_KEY in str(exc_info.value) + + def test_successful_initialization(self): + """Test successful initialization with valid response_shape.""" + transform = CloseSessionApiTransform( + request_shape={}, + response_shape={RESPONSE_CONTENT_KEY: "body.message"}, + ) + assert transform is not None + + +class TestCloseSessionValidation: + """Test request validation.""" + + @pytest.fixture + def transform(self): + """Create transform.""" + return CloseSessionApiTransform( + request_shape={}, + response_shape={RESPONSE_CONTENT_KEY: "body.message"}, + ) + + @pytest.mark.asyncio + async def test_requires_session_id_header(self, transform): + """Test that session ID header is required.""" + mock_request = AsyncMock(spec=Request) + mock_request.json.return_value = {} + mock_request.headers = {} # No session ID + + with pytest.raises(HTTPException) as exc_info: + await transform.transform_request(mock_request) + + assert exc_info.value.status_code == HTTPStatus.BAD_REQUEST.value + assert "Session ID is required" in exc_info.value.detail + + @pytest.mark.asyncio + async def test_succeeds_with_session_id_header(self, transform): + """Test that request succeeds with session ID header.""" + mock_request = AsyncMock(spec=Request) + mock_request.json.return_value = {} + mock_request.headers = {SageMakerSessionHeader.SESSION_ID: "sess-123"} + + result = await transform.transform_request(mock_request) + + assert result is not None + + +class TestCloseSessionTransformRequest: + """Test request transformation.""" + + @pytest.fixture + def transform(self): + """Create transform with request shape.""" + return CloseSessionApiTransform( + request_shape={"reason": "body.reason"}, + response_shape={RESPONSE_CONTENT_KEY: "body.message"}, + ) + + @pytest.mark.asyncio + async def test_transforms_request_body(self, transform): + """Test that request body is transformed using JMESPath.""" + mock_request = AsyncMock(spec=Request) + mock_request.json.return_value = {"reason": "timeout"} + mock_request.headers = {SageMakerSessionHeader.SESSION_ID: "sess-123"} + + result = await transform.transform_request(mock_request) + + assert result.request["reason"] == "timeout" + + @pytest.mark.asyncio + async def test_updates_raw_request_body(self, transform): + """Test that raw request body is updated.""" + mock_request = AsyncMock(spec=Request) + mock_request.json.return_value = {"reason": "timeout"} + mock_request.headers = {SageMakerSessionHeader.SESSION_ID: "sess-123"} + + await transform.transform_request(mock_request) + + updated_body = json.loads(mock_request._body.decode()) + assert updated_body == {"reason": "timeout"} + + +class TestCloseSessionTransformResponse: + """Test response transformation.""" + + @pytest.fixture + def transform(self): + """Create transform.""" + return CloseSessionApiTransform( + request_shape={}, + response_shape={RESPONSE_CONTENT_KEY: "body.message"}, + ) + + def test_extracts_content_and_adds_header(self, transform): + """Test that content is extracted and session ID added to headers.""" + response = Response( + status_code=HTTPStatus.OK.value, + content=json.dumps({"message": "Session closed"}), + ) + + mock_request = Mock(spec=Request) + mock_request.headers = {SageMakerSessionHeader.SESSION_ID: "sess-123"} + transform_output = BaseTransformRequestOutput( + raw_request=mock_request, intercept_func=None + ) + + result = transform.transform_response(response, transform_output) + + assert result.status_code == HTTPStatus.OK.value + assert result.headers[SageMakerSessionHeader.CLOSED_SESSION_ID] == "sess-123" + assert b"sess-123" in result.body + assert b"Session closed" in result.body + + def test_handles_missing_content(self, transform): + """Test that missing content is handled gracefully.""" + response = Response( + status_code=HTTPStatus.OK.value, + content=json.dumps({}), + ) + + mock_request = Mock(spec=Request) + mock_request.headers = {SageMakerSessionHeader.SESSION_ID: "sess-123"} + transform_output = BaseTransformRequestOutput( + raw_request=mock_request, intercept_func=None + ) + + result = transform.transform_response(response, transform_output) + + assert result.status_code == HTTPStatus.OK.value + assert result.headers[SageMakerSessionHeader.CLOSED_SESSION_ID] == "sess-123" + + def test_passes_through_error_responses(self, transform): + """Test that error responses pass through unchanged.""" + response = Response( + status_code=HTTPStatus.NOT_FOUND.value, + content=b"Session not found", + ) + + mock_request = Mock(spec=Request) + mock_request.headers = {SageMakerSessionHeader.SESSION_ID: "sess-123"} + transform_output = BaseTransformRequestOutput( + raw_request=mock_request, intercept_func=None + ) + + result = transform.transform_response(response, transform_output) + + assert result.status_code == HTTPStatus.NOT_FOUND.value + assert result.body == b"Session not found" diff --git a/python/tests/sagemaker/sessions/transforms/test_create_session_transform.py b/python/tests/sagemaker/sessions/transforms/test_create_session_transform.py new file mode 100644 index 0000000..93492f3 --- /dev/null +++ b/python/tests/sagemaker/sessions/transforms/test_create_session_transform.py @@ -0,0 +1,226 @@ +"""Unit tests for CreateSessionApiTransform.""" + +import json +from http import HTTPStatus +from unittest.mock import AsyncMock, Mock + +import pytest +from fastapi import Request, Response +from fastapi.exceptions import HTTPException +from pydantic import BaseModel + +from model_hosting_container_standards.sagemaker.sessions.models import ( + SageMakerSessionHeader, +) +from model_hosting_container_standards.sagemaker.sessions.transforms.constants import ( + RESPONSE_CONTENT_KEY, +) +from model_hosting_container_standards.sagemaker.sessions.transforms.create_session import ( + CreateSessionApiTransform, +) + + +class TestCreateSessionInitialization: + """Test CreateSessionApiTransform initialization.""" + + def test_requires_session_id_in_response_shape(self): + """Test that initialization requires NEW_SESSION_ID in response_shape.""" + with pytest.raises(ValueError) as exc_info: + CreateSessionApiTransform( + request_shape={}, + response_shape={RESPONSE_CONTENT_KEY: "body.message"}, + ) + assert SageMakerSessionHeader.NEW_SESSION_ID in str(exc_info.value) + + def test_requires_content_in_response_shape(self): + """Test that initialization requires RESPONSE_CONTENT_KEY in response_shape.""" + with pytest.raises(ValueError) as exc_info: + CreateSessionApiTransform( + request_shape={}, + response_shape={SageMakerSessionHeader.NEW_SESSION_ID: "body.id"}, + ) + assert RESPONSE_CONTENT_KEY in str(exc_info.value) + + def test_successful_initialization(self): + """Test successful initialization with valid response_shape.""" + transform = CreateSessionApiTransform( + request_shape={"model": "body.model"}, + response_shape={ + SageMakerSessionHeader.NEW_SESSION_ID: "body.session_id", + RESPONSE_CONTENT_KEY: "body.message", + }, + ) + assert transform is not None + + +class TestCreateSessionTransformRequest: + """Test request transformation.""" + + @pytest.fixture + def transform(self): + """Create transform with request shape.""" + return CreateSessionApiTransform( + request_shape={"model": "body.model"}, + response_shape={ + SageMakerSessionHeader.NEW_SESSION_ID: "body.session_id", + RESPONSE_CONTENT_KEY: "body.message", + }, + ) + + @pytest.mark.asyncio + async def test_transforms_request_body(self, transform): + """Test that request body is transformed using JMESPath.""" + mock_request = AsyncMock(spec=Request) + mock_request.json.return_value = {"model": "llama-3"} + mock_request.headers = {} + + result = await transform.transform_request(mock_request) + + assert result.request["model"] == "llama-3" + + @pytest.mark.asyncio + async def test_updates_raw_request_body(self, transform): + """Test that raw request body is updated with transformed data.""" + mock_request = AsyncMock(spec=Request) + mock_request.json.return_value = {"model": "llama-3"} + mock_request.headers = {} + + await transform.transform_request(mock_request) + + updated_body = json.loads(mock_request._body.decode()) + assert updated_body == {"model": "llama-3"} + + @pytest.mark.asyncio + async def test_handles_invalid_json(self, transform): + """Test that invalid JSON raises HTTPException.""" + mock_request = AsyncMock(spec=Request) + mock_request.json.side_effect = json.JSONDecodeError("Invalid", "doc", 0) + + with pytest.raises(HTTPException) as exc_info: + await transform.transform_request(mock_request) + + assert exc_info.value.status_code == HTTPStatus.BAD_REQUEST.value + + +class TestCreateSessionTransformResponse: + """Test response transformation.""" + + @pytest.fixture + def transform(self): + """Create transform.""" + return CreateSessionApiTransform( + request_shape={}, + response_shape={ + SageMakerSessionHeader.NEW_SESSION_ID: "body.session_id", + RESPONSE_CONTENT_KEY: "body.message", + }, + ) + + def test_extracts_session_id_from_response(self, transform): + """Test that session ID is extracted and added to headers.""" + response = Response( + status_code=HTTPStatus.OK.value, + content=json.dumps({"session_id": "sess-123", "message": "created"}), + ) + + result = transform.transform_response(response, Mock()) + + assert result.status_code == HTTPStatus.OK.value + assert result.headers[SageMakerSessionHeader.NEW_SESSION_ID] == "sess-123" + assert b"sess-123" in result.body + assert b"created" in result.body + + def test_fails_when_session_id_missing(self, transform): + """Test that missing session ID raises HTTPException.""" + response = Response( + status_code=HTTPStatus.OK.value, + content=json.dumps({"message": "created"}), + ) + + with pytest.raises(HTTPException) as exc_info: + transform.transform_response(response, Mock()) + + assert exc_info.value.status_code == HTTPStatus.BAD_GATEWAY.value + assert "session ID" in exc_info.value.detail + + def test_fails_when_session_id_empty(self, transform): + """Test that empty session ID raises HTTPException.""" + response = Response( + status_code=HTTPStatus.OK.value, + content=json.dumps({"session_id": "", "message": "created"}), + ) + + with pytest.raises(HTTPException) as exc_info: + transform.transform_response(response, Mock()) + + assert exc_info.value.status_code == HTTPStatus.BAD_GATEWAY.value + + def test_passes_through_error_responses(self, transform): + """Test that error responses pass through unchanged.""" + response = Response( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + content=b"Engine error", + ) + + result = transform.transform_response(response, Mock()) + + assert result.status_code == HTTPStatus.INTERNAL_SERVER_ERROR.value + assert result.body == b"Engine error" + + +class TestCreateSessionNormalizeResponse: + """ normalization.""" + + @pytest.fixture + def transform(self): + """Create transform.""" + return CreateSessionApiTransform( + request_shape={}, + response_shape={ + SageMakerSessionHeader.NEW_SESSION_ID: "body.id", + RESPONSE_CONTENT_KEY: "body.msg", + }, + ) + + def test_normalizes_dict_response(self, transform): + """Test normalization of dict response.""" + response_dict = {"id": "sess-123", "msg": "created"} + + normalized = transform._normalize_response(response_dict) + + assert isinstance(normalized, Response) + assert normalized.status_code == HTTPStatus.OK.value + body = json.loads(normalized.body) + assert body["id"] == "sess-123" + + def test_normalizes_string_response(self, transform): + """Test normalizatistring response.""" + response_str = "Session created" + + normalized = transform._normalize_response(response_str) + + assert isinstance(normalized, Response) + assert normalized.body == b"Session created" + + def test_normalizes_pydantic_response(self, transform): + """Test normalization of Pydantic model response.""" + + class SessionResponse(BaseModel): + id: str + msg: str + + response_model = SessionResponse(id="sess-123", msg="created") + + normalized = transform._normalize_response(response_model) + + assert isinstance(normalized, Response) + body = json.loads(normalized.body) + assert body["id"] == "sess-123" + + def test_passes_through_response_object(self, transform): + """Test that Response objects pass through unchanged.""" + response = Response(status_code=HTTPStatus.OK.value, content=b"test") + + normalized = transform._normalize_response(response) + + assert normalized is response From 259ea2ff0491baa21fb667225c3843ca262e91e7 Mon Sep 17 00:00:00 2001 From: Zuyi Zhao Date: Thu, 4 Dec 2025 21:23:47 +0000 Subject: [PATCH 16/25] Update tests, improve how check for use default is done. --- .../sagemaker/sessions/transform.py | 38 +- ...est_custom_session_handlers_integration.py | 736 ++++++++++++++++++ .../sagemaker/sessions/test_registration.py | 6 - .../test_base_engine_session_api_transform.py | 417 ++++++++++ .../test_close_session_transform.py | 126 +++ .../test_create_session_transform.py | 115 ++- 6 files changed, 1425 insertions(+), 13 deletions(-) create mode 100644 python/tests/integration/test_custom_session_handlers_integration.py create mode 100644 python/tests/sagemaker/sessions/transforms/test_base_engine_session_api_transform.py diff --git a/python/model_hosting_container_standards/sagemaker/sessions/transform.py b/python/model_hosting_container_standards/sagemaker/sessions/transform.py index a31d721..4cf84cc 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/transform.py +++ b/python/model_hosting_container_standards/sagemaker/sessions/transform.py @@ -59,16 +59,42 @@ def __init__(self, request_shape, response_shape={}): The request/response shapes are passed to the parent class but not used for validation in this transform, as session requests use their own validation. """ - # Use default session manager if no custom create / close session handlers are registered - self._use_default_manager = not handler_registry.has_handler( + self._session_manager = get_session_manager() + + # Hybrid caching strategy for _use_default_manager: + # - If custom handlers exist at init → cache False (fast path on every request) + # - If no custom handlers at init → cache True but check dynamically (allows late registration) + # This optimizes the common case while maintaining flexibility + self._use_default_manager_cached = not handler_registry.has_handler( "create_session" ) and not handler_registry.has_handler("close_session") - self._session_manager = get_session_manager() # Extract session_id_target_key before compiling JMESPath expressions self._session_id_target_key = self._get_session_id_target_key(request_shape) super().__init__(request_shape, response_shape) + def _use_default_manager(self) -> bool: + """Check if default session manager should be used. + + Hybrid approach for performance: + - If custom handlers existed at init time (cached=False), return False immediately + - If no custom handlers at init (cached=True), check dynamically in case they were registered later + + This optimizes the common case (custom handlers registered before transform creation) + while still supporting late registration for flexibility. + + Returns: + bool: True if default manager should be used, False if custom handlers exist + """ + # Fast path: if custom handlers existed at init, they still exist + if not self._use_default_manager_cached: + return False + + # Slow path: no custom handlers at init, check if any were registered since + return not handler_registry.has_handler( + "create_session" + ) and not handler_registry.has_handler("close_session") + def _get_session_id_target_key(self, request_shape: dict) -> Optional[str]: if not request_shape: return None @@ -135,7 +161,7 @@ def _process_invocations_request( self, session_id: Optional[str], request_data: dict, raw_request: Request ): # If not a session request - if session_id and self._use_default_manager: + if session_id and self._use_default_manager(): # but it has a session id header and we are using the default session manager, # then we need to validate that the session id exists in the session manager self._validate_session_id(session_id, raw_request) @@ -157,7 +183,7 @@ def _process_invocations_request( def _process_session_request(self, session_request, session_id, raw_request): # Validation - if self._use_default_manager and not self._session_manager: + if self._use_default_manager() and not self._session_manager: # if no custom handlers are registered, but default session manager # does not exist -> then raise error that session management is disabled logger.error(SESSION_DISABLED_LOG_MESSAGE) @@ -165,7 +191,7 @@ def _process_session_request(self, session_request, session_id, raw_request): status_code=HTTPStatus.BAD_REQUEST.value, detail=SESSION_DISABLED_ERROR_DETAIL, ) - elif self._use_default_manager and self._session_manager: + elif self._use_default_manager() and self._session_manager: if session_request.requestType == SessionRequestType.NEW_SESSION: # Ignores any session id header in create session request session_id = SessionRequestType.NEW_SESSION diff --git a/python/tests/integration/test_custom_session_handlers_integration.py b/python/tests/integration/test_custom_session_handlers_integration.py new file mode 100644 index 0000000..7dd0b4f --- /dev/null +++ b/python/tests/integration/test_custom_session_handlers_integration.py @@ -0,0 +1,736 @@ +"""Integration tests for custom session handlers. + +Tests the integration of custom engine-specific session handlers with the +SageMaker session management system. These tests verify that: +- Custom handlers can be registered and invoked +- Custom handlers take precedence over default handlers +- Request/response transformations work end-to-end +- Error handling propagates correctly from custom handlers +""" + +import json +import os +import shutil +import tempfile +from http import HTTPStatus +from typing import Optional + +import pytest +from fastapi import APIRouter, FastAPI, Request, Response +from fastapi.testclient import TestClient + +import model_hosting_container_standards.sagemaker as sagemaker_standards +from model_hosting_container_standards.common.handler.registry import handler_registry +from model_hosting_container_standards.sagemaker.sessions import ( + register_engine_session_handler, +) +from model_hosting_container_standards.sagemaker.sessions.manager import ( + init_session_manager_from_env, +) +from model_hosting_container_standards.sagemaker.sessions.models import ( + SageMakerSessionHeader, +) + + +@pytest.fixture(autouse=True) +def enable_sessions_for_integration(monkeypatch): + """Automatically enable sessions for all integration tests in this module.""" + temp_dir = tempfile.mkdtemp() + + monkeypatch.setenv("SAGEMAKER_ENABLE_STATEFUL_SESSIONS", "true") + monkeypatch.setenv("SAGEMAKER_SESSIONS_PATH", temp_dir) + monkeypatch.setenv("SAGEMAKER_SESSIONS_EXPIRATION", "600") + + # Reinitialize the global session manager + init_session_manager_from_env() + + yield + + # Clean up + if os.path.exists(temp_dir): + shutil.rmtree(temp_dir) + monkeypatch.delenv("SAGEMAKER_ENABLE_STATEFUL_SESSIONS", raising=False) + monkeypatch.delenv("SAGEMAKER_SESSIONS_PATH", raising=False) + monkeypatch.delenv("SAGEMAKER_SESSIONS_EXPIRATION", raising=False) + init_session_manager_from_env() + + +@pytest.fixture(autouse=True) +def clear_handler_registry(): + """Clear handler registry before and after each test.""" + handler_registry.clear() + yield + handler_registry.clear() + + +def extract_session_id_from_header(header_value: str) -> str: + """Extract session ID from SageMaker session header. + + Header format: "; Expires=" + """ + if ";" in header_value: + return header_value.split(";")[0].strip() + return header_value.strip() + + +class MockEngineAPI: + """Mock engine API that simulates an inference engine's session management. + + This simulates engines like vLLM or TGI that have their own session APIs. + """ + + def __init__(self): + self.sessions = {} + self.call_count = {"create": 0, "close": 0} + + def create_session(self, model: Optional[str] = None): + """Simulate engine's create session API.""" + self.call_count["create"] += 1 + session_id = f"engine-session-{self.call_count['create']}" + self.sessions[session_id] = {"model": model, "active": True} + return { + "session": {"id": session_id, "status": "active"}, + "message": f"Engine session created with model {model}", + } + + def close_session(self, session_id: str): + """Simulate engine's close session API.""" + self.call_count["close"] += 1 + if session_id not in self.sessions: + raise ValueError(f"Session {session_id} not found in engine") + self.sessions[session_id]["active"] = False + return {"result": {"message": f"Engine session {session_id} closed"}} + + def reset(self): + """Reset the mock engine state.""" + self.sessions.clear() + self.call_count = {"create": 0, "close": 0} + + +class TestCustomHandlerRegistration: + """Test that custom handlers can be registered and invoked.""" + + def setup_method(self): + """Set up test fixtures.""" + self.app = FastAPI() + self.router = APIRouter() + self.mock_engine = MockEngineAPI() + self.setup_handlers() + self.app.include_router(self.router) + sagemaker_standards.bootstrap(self.app) + self.client = TestClient(self.app) + + def setup_handlers(self): + """Set up handlers and register custom session handlers. + + Note: With the hybrid caching strategy, handlers can be registered + in any order - no timing dependency! + """ + + @self.router.post("/invocations") + @sagemaker_standards.stateful_session_manager() + async def invocations(request: Request): + """Handler with session management.""" + body_bytes = await request.body() + body = json.loads(body_bytes.decode()) + return Response( + status_code=200, + content=json.dumps({"message": "success", "echo": body}), + ) + + # Register custom create session handler + @register_engine_session_handler( + handler_type="create_session", + request_shape={}, # Session requests don't have extra fields + session_id_path="body.session.id", # Path to session ID in response body + content_path="body.message", # Path to message in response body + ) + async def custom_create_session(raw_request: Request): + """Custom handler that delegates to engine API.""" + # Call mock engine API with default model + result = self.mock_engine.create_session("default-model") + return result + + # Register custom close session handler + @register_engine_session_handler( + handler_type="close_session", + request_shape={}, + content_path="body.result.message", # Path to message in response body + ) + async def custom_close_session(raw_request: Request): + """Custom handler that delegates to engine API.""" + session_id = raw_request.headers.get(SageMakerSessionHeader.SESSION_ID) + + # Call mock engine API + result = self.mock_engine.close_session(session_id) + return result + + def test_custom_create_handler_is_invoked(self): + """Test that custom create handler is called instead of default.""" + response = self.client.post("/invocations", json={"requestType": "NEW_SESSION"}) + + assert response.status_code == 200 + assert SageMakerSessionHeader.NEW_SESSION_ID in response.headers + + # Verify custom handler was called + assert self.mock_engine.call_count["create"] == 1 + + def test_custom_close_handler_is_invoked(self): + """Test that custom close handler is called instead of default.""" + # First create a session + create_response = self.client.post( + "/invocations", json={"requestType": "NEW_SESSION"} + ) + session_id = extract_session_id_from_header( + create_response.headers[SageMakerSessionHeader.NEW_SESSION_ID] + ) + + # Now close it + close_response = self.client.post( + "/invocations", + json={"requestType": "CLOSE"}, + headers={SageMakerSessionHeader.SESSION_ID: session_id}, + ) + + if close_response.status_code != 200: + print(f"Close response status: {close_response.status_code}") + print(f"Close response body: {close_response.text}") + + assert close_response.status_code == 200 + assert SageMakerSessionHeader.CLOSED_SESSION_ID in close_response.headers + + # Verify custom handler was called + assert self.mock_engine.call_count["close"] == 1 + + def test_custom_handler_response_transformation(self): + """Test that response is properly transformed from engine format to SageMaker format.""" + response = self.client.post("/invocations", json={"requestType": "NEW_SESSION"}) + + # Verify response has SageMaker format + assert response.status_code == 200 + session_id = extract_session_id_from_header( + response.headers[SageMakerSessionHeader.NEW_SESSION_ID] + ) + + # Session ID should be from engine (engine-session-X format) + assert session_id.startswith("engine-session-") + + # Response body should contain engine message + assert b"Engine session created" in response.content + + def test_custom_handler_request_transformation(self): + """Test that custom handler is invoked for session creation.""" + self.client.post( + "/invocations", + json={"requestType": "NEW_SESSION"}, + ) + + # Verify the engine received the request + assert self.mock_engine.call_count["create"] == 1 + # Check that session was created in engine + sessions = list(self.mock_engine.sessions.values()) + assert len(sessions) == 1 + assert sessions[0]["model"] == "default-model" + + +class TestCustomHandlerErrorHandling: + """Test error handling in custom handlers.""" + + def setup_method(self): + """Set up test fixtures.""" + self.app = FastAPI() + self.router = APIRouter() + self.mock_engine = MockEngineAPI() + self.setup_handlers() + self.app.include_router(self.router) + sagemaker_standards.bootstrap(self.app) + self.client = TestClient(self.app) + + def setup_handlers(self): + """Set up handlers and register custom session handlers.""" + + @self.router.post("/invocations") + @sagemaker_standards.stateful_session_manager() + async def invocations(request: Request): + """Handler with session management.""" + body_bytes = await request.body() + body = json.loads(body_bytes.decode()) + return Response( + status_code=200, + content=json.dumps({"message": "success", "echo": body}), + ) + + @register_engine_session_handler( + handler_type="create_session", + request_shape={}, + session_id_path="body.session.id", + content_path="body.message", + ) + async def custom_create_session(raw_request: Request): + """Custom handler that can fail.""" + result = self.mock_engine.create_session() + return result + + @register_engine_session_handler( + handler_type="close_session", + request_shape={}, + content_path="body.result.message", + ) + async def custom_close_session(raw_request: Request): + """Custom handler that validates session exists.""" + session_id = raw_request.headers.get(SageMakerSessionHeader.SESSION_ID) + # This will raise ValueError if session not found + result = self.mock_engine.close_session(session_id) + return result + + def test_custom_handler_error_propagates(self): + """Test that errors from custom handlers propagate correctly. + + Note: Unhandled exceptions in custom handlers will bubble up through FastAPI. + In production, these should be caught by FastAPI's exception handlers. + For testing with TestClient, the exception is raised directly. + """ + # Try to close a non-existent session - this will raise ValueError + with pytest.raises(ValueError, match="Session nonexistent-session not found"): + self.client.post( + "/invocations", + json={"requestType": "CLOSE"}, + headers={SageMakerSessionHeader.SESSION_ID: "nonexistent-session"}, + ) + + def test_custom_handler_missing_session_id_in_response(self): + """Test handling when custom handler doesn't return session ID.""" + + # Create a handler that returns invalid response + @register_engine_session_handler( + handler_type="create_session", + request_shape={}, + session_id_path="body.session.id", # Path that won't exist + content_path="body.message", + ) + async def broken_create_session(raw_request: Request): + """Handler that returns response without session.id.""" + return {"message": "created but no session id"} + + response = self.client.post("/invocations", json={"requestType": "NEW_SESSION"}) + + # Should fail with BAD_GATEWAY since session ID is missing + assert response.status_code == HTTPStatus.BAD_GATEWAY.value + + +class TestCustomHandlerWithSessionIdPath: + """Test custom handlers work with session_id_path parameter.""" + + def setup_method(self): + """Set up test fixtures.""" + self.app = FastAPI() + self.router = APIRouter() + self.mock_engine = MockEngineAPI() + self.captured_requests = [] + self.setup_handlers_with_session_id_path() + self.app.include_router(self.router) + sagemaker_standards.bootstrap(self.app) + self.client = TestClient(self.app) + + def setup_handlers_with_session_id_path(self): + """Set up handlers that use session_id_path.""" + + @self.router.post("/invocations") + @sagemaker_standards.stateful_session_manager(session_id_path="session_id") + async def invocations(request: Request): + """Handler that injects session ID into request body.""" + body_bytes = await request.body() + body = json.loads(body_bytes.decode()) + + # Capture for verification + self.captured_requests.append(body) + + return Response( + status_code=200, + content=json.dumps( + { + "message": "success", + "session_id_from_body": body.get("session_id"), + } + ), + ) + + @register_engine_session_handler( + handler_type="create_session", + request_shape={}, + session_id_path="body.session.id", + content_path="body.message", + ) + async def custom_create_session(raw_request: Request): + """Custom create handler.""" + result = self.mock_engine.create_session() + return result + + @register_engine_session_handler( + handler_type="close_session", + request_shape={}, + content_path="body.result.message", + ) + async def custom_close_session(raw_request: Request): + """Custom close handler.""" + session_id = raw_request.headers.get(SageMakerSessionHeader.SESSION_ID) + result = self.mock_engine.close_session(session_id) + return result + + def test_session_id_injected_with_custom_handler(self): + """Test that session_id_path works with custom handlers.""" + # Create session with custom handler + create_response = self.client.post( + "/invocations", json={"requestType": "NEW_SESSION"} + ) + session_id = extract_session_id_from_header( + create_response.headers[SageMakerSessionHeader.NEW_SESSION_ID] + ) + + # Make request with session ID - should be injected into body + self.captured_requests.clear() + response = self.client.post( + "/invocations", + json={"prompt": "test"}, + headers={SageMakerSessionHeader.SESSION_ID: session_id}, + ) + + assert response.status_code == 200 + data = json.loads(response.text) + + # Verify session ID was injected + assert data["session_id_from_body"] == session_id + assert len(self.captured_requests) == 1 + assert self.captured_requests[0]["session_id"] == session_id + + +class TestCustomHandlerConcurrency: + """Test concurrent operations with custom handlers.""" + + def setup_method(self): + """Set up test fixtures.""" + self.app = FastAPI() + self.router = APIRouter() + self.mock_engine = MockEngineAPI() + self.setup_custom_handlers() + self.app.include_router(self.router) + sagemaker_standards.bootstrap(self.app) + self.client = TestClient(self.app) + + def setup_custom_handlers(self): + """Register custom handlers.""" + + @self.router.post("/invocations") + @sagemaker_standards.stateful_session_manager() + async def invocations(request: Request): + """Handler with session management.""" + body_bytes = await request.body() + body = json.loads(body_bytes.decode()) + return Response( + status_code=200, + content=json.dumps({"message": "success", "echo": body}), + ) + + @register_engine_session_handler( + handler_type="create_session", + request_shape={}, + session_id_path="body.session.id", + content_path="body.message", + ) + async def custom_create_session(raw_request: Request): + """Custom create handler.""" + result = self.mock_engine.create_session() + return result + + @register_engine_session_handler( + handler_type="close_session", + request_shape={}, + content_path="body.result.message", + ) + async def custom_close_session(raw_request: Request): + """Custom close handler.""" + session_id = raw_request.headers.get(SageMakerSessionHeader.SESSION_ID) + result = self.mock_engine.close_session(session_id) + return result + + def test_multiple_sessions_with_custom_handlers(self): + """Test creating and managing multiple sessions with custom handlers.""" + # Create multiple sessions + session_ids = [] + for i in range(3): + response = self.client.post( + "/invocations", json={"requestType": "NEW_SESSION"} + ) + assert response.status_code == 200 + session_id = extract_session_id_from_header( + response.headers[SageMakerSessionHeader.NEW_SESSION_ID] + ) + session_ids.append(session_id) + + # Verify all sessions were created in engine + assert self.mock_engine.call_count["create"] == 3 + assert len(self.mock_engine.sessions) == 3 + + # Close all sessions + for session_id in session_ids: + response = self.client.post( + "/invocations", + json={"requestType": "CLOSE"}, + headers={SageMakerSessionHeader.SESSION_ID: session_id}, + ) + assert response.status_code == 200 + + # Verify all sessions were closed in engine + assert self.mock_engine.call_count["close"] == 3 + + def test_interleaved_operations_with_custom_handlers(self): + """Test interleaved create/use/close operations.""" + # Create session 1 + response1 = self.client.post( + "/invocations", json={"requestType": "NEW_SESSION"} + ) + session1_id = extract_session_id_from_header( + response1.headers[SageMakerSessionHeader.NEW_SESSION_ID] + ) + + # Create session 2 + response2 = self.client.post( + "/invocations", json={"requestType": "NEW_SESSION"} + ) + session2_id = extract_session_id_from_header( + response2.headers[SageMakerSessionHeader.NEW_SESSION_ID] + ) + + # Close session 1 + close1 = self.client.post( + "/invocations", + json={"requestType": "CLOSE"}, + headers={SageMakerSessionHeader.SESSION_ID: session1_id}, + ) + assert close1.status_code == 200 + + # Session 2 should still work + # (Note: In this test we're just verifying the close worked) + assert self.mock_engine.sessions[session1_id]["active"] is False + assert self.mock_engine.sessions[session2_id]["active"] is True + + +class TestCustomHandlerComplexTransformations: + """Test custom handlers with complex request/response transformations.""" + + def setup_method(self): + """Set up test fixtures.""" + self.app = FastAPI() + self.router = APIRouter() + self.mock_engine = MockEngineAPI() + self.setup_complex_handlers() + self.app.include_router(self.router) + sagemaker_standards.bootstrap(self.app) + self.client = TestClient(self.app) + + def setup_complex_handlers(self): + """Register handlers with complex transformations.""" + + @self.router.post("/invocations") + @sagemaker_standards.stateful_session_manager() + async def invocations(request: Request): + """Handler with session management.""" + body_bytes = await request.body() + body = json.loads(body_bytes.decode()) + return Response( + status_code=200, + content=json.dumps({"message": "success", "echo": body}), + ) + + @register_engine_session_handler( + handler_type="create_session", + request_shape={}, + session_id_path="body.session.id", + content_path="body.message", + ) + async def custom_create_session(raw_request: Request): + """Handler for session creation.""" + result = self.mock_engine.create_session("default") + return result + + @register_engine_session_handler( + handler_type="close_session", + request_shape={}, + content_path="body.result.message", + ) + async def custom_close_session(raw_request: Request): + """Handler for session closure.""" + session_id = raw_request.headers.get(SageMakerSessionHeader.SESSION_ID) + + # Engine API closes the session + result = self.mock_engine.close_session(session_id) + return result + + def test_complex_request_transformation(self): + """Test that custom handlers work with session creation.""" + response = self.client.post( + "/invocations", + json={"requestType": "NEW_SESSION"}, + ) + + assert response.status_code == 200 + assert SageMakerSessionHeader.NEW_SESSION_ID in response.headers + + # Verify engine created session + assert self.mock_engine.call_count["create"] == 1 + + def test_close_with_custom_handler(self): + """Test close handler with custom implementation.""" + # Create session first + create_response = self.client.post( + "/invocations", json={"requestType": "NEW_SESSION"} + ) + session_id = extract_session_id_from_header( + create_response.headers[SageMakerSessionHeader.NEW_SESSION_ID] + ) + + # Close session + close_response = self.client.post( + "/invocations", + json={"requestType": "CLOSE"}, + headers={SageMakerSessionHeader.SESSION_ID: session_id}, + ) + + assert close_response.status_code == 200 + assert SageMakerSessionHeader.CLOSED_SESSION_ID in close_response.headers + assert self.mock_engine.call_count["close"] == 1 + + +class TestCustomHandlerEndToEnd: + """End-to-end tests simulating real engine integration patterns.""" + + def setup_method(self): + """Set up test fixtures.""" + self.app = FastAPI() + self.router = APIRouter() + self.mock_engine = MockEngineAPI() + self.setup_realistic_handlers() + self.app.include_router(self.router) + sagemaker_standards.bootstrap(self.app) + self.client = TestClient(self.app) + + def setup_realistic_handlers(self): + """Set up handlers that simulate real vLLM/TGI integration.""" + + @self.router.post("/invocations") + @sagemaker_standards.stateful_session_manager( + session_id_path="metadata.session_id" + ) + async def invocations(request: Request): + """Realistic inference handler.""" + body_bytes = await request.body() + body = json.loads(body_bytes.decode()) + + # Simulate inference with session context + session_id = body.get("metadata", {}).get("session_id") + prompt = body.get("prompt", "") + + return Response( + status_code=200, + content=json.dumps( + { + "generated_text": f"Response to: {prompt}", + "session_id": session_id, + "metadata": {"tokens": 42}, + } + ), + ) + + @register_engine_session_handler( + handler_type="create_session", + request_shape={}, + session_id_path="body.session.id", + content_path="body.message", + ) + async def vllm_create_session(raw_request: Request): + """Simulate vLLM session creation.""" + result = self.mock_engine.create_session("default-model") + return result + + @register_engine_session_handler( + handler_type="close_session", + request_shape={}, + content_path="body.result.message", + ) + async def vllm_close_session(raw_request: Request): + """Simulate vLLM session closure.""" + session_id = raw_request.headers.get(SageMakerSessionHeader.SESSION_ID) + result = self.mock_engine.close_session(session_id) + return result + + def test_full_lifecycle_with_custom_handlers(self): + """Test complete lifecycle: create -> use -> close with custom handlers.""" + # 1. Create session via custom handler + create_response = self.client.post( + "/invocations", + json={"requestType": "NEW_SESSION"}, + ) + assert create_response.status_code == 200 + session_id = extract_session_id_from_header( + create_response.headers[SageMakerSessionHeader.NEW_SESSION_ID] + ) + assert session_id.startswith("engine-session-") + + # 2. Use session for inference + inference_response = self.client.post( + "/invocations", + json={"prompt": "Hello, world!", "metadata": {}}, + headers={SageMakerSessionHeader.SESSION_ID: session_id}, + ) + assert inference_response.status_code == 200 + data = json.loads(inference_response.text) + assert data["session_id"] == session_id + assert "Response to: Hello, world!" in data["generated_text"] + + # 3. Close session via custom handler + close_response = self.client.post( + "/invocations", + json={"requestType": "CLOSE"}, + headers={SageMakerSessionHeader.SESSION_ID: session_id}, + ) + assert close_response.status_code == 200 + assert ( + close_response.headers[SageMakerSessionHeader.CLOSED_SESSION_ID] + == session_id + ) + + # Verify engine state + assert self.mock_engine.call_count["create"] == 1 + assert self.mock_engine.call_count["close"] == 1 + assert self.mock_engine.sessions[session_id]["active"] is False + + def test_multiple_inference_calls_with_custom_session(self): + """Test multiple inference calls using custom session.""" + # Create session + create_response = self.client.post( + "/invocations", json={"requestType": "NEW_SESSION"} + ) + session_id = extract_session_id_from_header( + create_response.headers[SageMakerSessionHeader.NEW_SESSION_ID] + ) + + # Make multiple inference calls + prompts = ["First prompt", "Second prompt", "Third prompt"] + for prompt in prompts: + response = self.client.post( + "/invocations", + json={"prompt": prompt, "metadata": {}}, + headers={SageMakerSessionHeader.SESSION_ID: session_id}, + ) + assert response.status_code == 200 + data = json.loads(response.text) + assert data["session_id"] == session_id + assert prompt in data["generated_text"] + + # Close session + close_response = self.client.post( + "/invocations", + json={"requestType": "CLOSE"}, + headers={SageMakerSessionHeader.SESSION_ID: session_id}, + ) + assert close_response.status_code == 200 diff --git a/python/tests/sagemaker/sessions/test_registration.py b/python/tests/sagemaker/sessions/test_registration.py index d1b053f..da30a1c 100644 --- a/python/tests/sagemaker/sessions/test_registration.py +++ b/python/tests/sagemaker/sessions/test_registration.py @@ -5,12 +5,6 @@ from model_hosting_container_standards.sagemaker.sessions import ( register_engine_session_handler, ) -from model_hosting_container_standards.sagemaker.sessions.models import ( - SageMakerSessionHeader, -) -from model_hosting_container_standards.sagemaker.sessions.transforms.constants import ( - RESPONSE_CONTENT_KEY, -) class TestRegisterEngineSessionHandler: diff --git a/python/tests/sagemaker/sessions/transforms/test_base_engine_session_api_transform.py b/python/tests/sagemaker/sessions/transforms/test_base_engine_session_api_transform.py new file mode 100644 index 0000000..0b7407f --- /dev/null +++ b/python/tests/sagemaker/sessions/transforms/test_base_engine_session_api_transform.py @@ -0,0 +1,417 @@ +"""Unit tests for BaseEngineSessionApiTransform.""" + +import json +from http import HTTPStatus +from unittest.mock import AsyncMock, Mock + +import pytest +from fastapi import Request, Response +from fastapi.exceptions import HTTPException +from pydantic import BaseModel + +from model_hosting_container_standards.common import BaseTransformRequestOutput +from model_hosting_container_standards.sagemaker.sessions.transforms.base_engine_session_api_transform import ( + BaseEngineSessionApiTransform, +) + + +class ConcreteTransform(BaseEngineSessionApiTransform): + """Concrete implementation for testing the abstract base class.""" + + def _transform_ok_response(self, response: Response, **kwargs) -> Response: + """Simple implementation that just returns the response.""" + return response + + +class TestBaseEngineSessionApiTransformRequest: + """Test transform_request method.""" + + @pytest.fixture + def transform(self): + """Create concrete transform instance.""" + return ConcreteTransform( + request_shape={"field": "body.field"}, response_shape={} + ) + + @pytest.mark.asyncio + async def test_transforms_request_body(self, transform): + """Test that request body is transformed using JMESPath.""" + mock_request = AsyncMock(spec=Request) + mock_request.json.return_value = {"field": "value"} + mock_request.headers = {} + + result = await transform.transform_request(mock_request) + + assert result.request["field"] == "value" + assert isinstance(result, BaseTransformRequestOutput) + + @pytest.mark.asyncio + async def test_updates_raw_request_body(self, transform): + """Test that raw request _body is updated with transformed data.""" + mock_request = AsyncMock(spec=Request) + mock_request.json.return_value = {"field": "value"} + mock_request.headers = {} + + await transform.transform_request(mock_request) + + updated_body = json.loads(mock_request._body.decode()) + assert updated_body == {"field": "value"} + + @pytest.mark.asyncio + async def test_handles_json_decode_error(self, transform): + """Test that JSON decode errors raise HTTPException.""" + mock_request = AsyncMock(spec=Request) + mock_request.json.side_effect = json.JSONDecodeError("Invalid", "doc", 0) + + with pytest.raises(HTTPException) as exc_info: + await transform.transform_request(mock_request) + + assert exc_info.value.status_code == HTTPStatus.BAD_REQUEST.value + assert "JSON decode error" in exc_info.value.detail + + @pytest.mark.asyncio + async def test_calls_validate_request_preconditions(self, transform): + """Test that _validate_request_preconditions is called.""" + mock_request = AsyncMock(spec=Request) + mock_request.json.return_value = {} + mock_request.headers = {} + + # Mock the validation method + transform._validate_request_preconditions = Mock() + + await transform.transform_request(mock_request) + + transform._validate_request_preconditions.assert_called_once_with(mock_request) + + @pytest.mark.asyncio + async def test_validation_errors_propagate(self, transform): + """Test that validation errors from preconditions propagate.""" + mock_request = AsyncMock(spec=Request) + mock_request.headers = {} + + # Make validation raise an exception + def raise_validation_error(req): + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, detail="Validation failed" + ) + + transform._validate_request_preconditions = raise_validation_error + + with pytest.raises(HTTPException) as exc_info: + await transform.transform_request(mock_request) + + assert exc_info.value.status_code == HTTPStatus.BAD_REQUEST.value + assert "Validation failed" in exc_info.value.detail + + +class TestBaseEngineSessionApiTransformResponse: + """Test transform_response method.""" + + @pytest.fixture + def transform(self): + """Create concrete transform instance.""" + return ConcreteTransform(request_shape={}, response_shape={}) + + def test_routes_ok_response_to_transform_ok_response(self, transform): + """Test that 200 OK responses are routed to _transform_ok_response.""" + response = Response(status_code=HTTPStatus.OK.value, content=b"success") + transform_output = Mock(spec=BaseTransformRequestOutput) + + # Mock the _transform_ok_response method + transform._transform_ok_response = Mock(return_value=response) + + result = transform.transform_response(response, transform_output) + + transform._transform_ok_response.assert_called_once() + assert result == response + + def test_routes_error_response_to_transform_error_response(self, transform): + """Test that error responses are routed to _transform_error_response.""" + response = Response( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, content=b"error" + ) + transform_output = Mock(spec=BaseTransformRequestOutput) + + # Mock the _transform_error_response method + transform._transform_error_response = Mock(return_value=response) + + result = transform.transform_response(response, transform_output) + + transform._transform_error_response.assert_called_once_with(response) + assert result == response + + def test_normalizes_response_before_routing(self, transform): + """Test that response is normalized before routing.""" + # Pass a dict instead of Response object + response_dict = {"status": "success"} + transform_output = Mock(spec=BaseTransformRequestOutput) + + result = transform.transform_response(response_dict, transform_output) + + # Should be normalized to Response object + assert isinstance(result, Response) + assert result.status_code == HTTPStatus.OK.value + + +class TestNormalizeResponse: + """Test _normalize_response method.""" + + @pytest.fixture + def transform(self): + """Create concrete transform instance.""" + return ConcreteTransform(request_shape={}, response_shape={}) + + def test_passes_through_response_object(self, transform): + """Test that Response objects pass through unchanged.""" + response = Response(status_code=HTTPStatus.OK.value, content=b"test") + + normalized = transform._normalize_response(response) + + assert normalized is response + + def test_normalizes_dict_to_response(self, transform): + """Test that dict is normalized to Response.""" + response_dict = {"key": "value"} + + normalized = transform._normalize_response(response_dict) + + assert isinstance(normalized, Response) + assert normalized.status_code == HTTPStatus.OK.value + body = json.loads(normalized.body) + assert body["key"] == "value" + + def test_normalizes_string_to_response(self, transform): + """Test that string is normalized to Response.""" + response_str = "success message" + + normalized = transform._normalize_response(response_str) + + assert isinstance(normalized, Response) + assert normalized.status_code == HTTPStatus.OK.value + assert normalized.body == b"success message" + + def test_normalizes_pydantic_model_to_response(self, transform): + """Test that Pydantic model is normalized to Response.""" + + class TestModel(BaseModel): + field1: str + field2: int + + model = TestModel(field1="test", field2=42) + + normalized = transform._normalize_response(model) + + assert isinstance(normalized, Response) + assert normalized.status_code == HTTPStatus.OK.value + body = json.loads(normalized.body) + assert body["field1"] == "test" + assert body["field2"] == 42 + + def test_normalizes_none_to_response(self, transform): + """Test that None is normalized to Response.""" + normalized = transform._normalize_response(None) + + assert isinstance(normalized, Response) + assert normalized.status_code == HTTPStatus.OK.value + assert normalized.body == b"null" + + def test_normalizes_list_to_response(self, transform): + """Test that list is normalized to Response.""" + response_list = [{"id": 1}, {"id": 2}] + + normalized = transform._normalize_response(response_list) + + assert isinstance(normalized, Response) + assert normalized.status_code == HTTPStatus.OK.value + body = json.loads(normalized.body) + assert len(body) == 2 + assert body[0]["id"] == 1 + + def test_normalizes_int_to_response(self, transform): + """Test that integer is normalized to Response.""" + normalized = transform._normalize_response(42) + + assert isinstance(normalized, Response) + assert normalized.status_code == HTTPStatus.OK.value + assert normalized.body == b"42" + + def test_normalizes_bool_to_response(self, transform): + """Test that boolean is normalized to Response.""" + normalized = transform._normalize_response(True) + + assert isinstance(normalized, Response) + assert normalized.status_code == HTTPStatus.OK.value + assert normalized.body == b"true" + + def test_normalizes_empty_dict_to_response(self, transform): + """Test that empty dict is normalized to Response.""" + normalized = transform._normalize_response({}) + + assert isinstance(normalized, Response) + assert normalized.status_code == HTTPStatus.OK.value + assert normalized.body == b"{}" + + def test_normalizes_nested_structure_to_response(self, transform): + """Test that nested structure is normalized to Response.""" + response_data = { + "session": {"id": "sess-123", "metadata": {"user": "test"}}, + "status": "active", + } + + normalized = transform._normalize_response(response_data) + + assert isinstance(normalized, Response) + body = json.loads(normalized.body) + assert body["session"]["id"] == "sess-123" + assert body["session"]["metadata"]["user"] == "test" + + def test_preserves_response_with_error_status_code(self, transform): + """Test that Response with error status code is preserved.""" + response = Response( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, content=b"error" + ) + + normalized = transform._normalize_response(response) + + assert normalized is response + assert normalized.status_code == HTTPStatus.INTERNAL_SERVER_ERROR.value + + +class TestTransformErrorResponse: + """Test _transform_error_response method.""" + + @pytest.fixture + def transform(self): + """Create concrete transform instance.""" + return ConcreteTransform(request_shape={}, response_shape={}) + + def test_passes_through_error_response_unchanged(self, transform): + """Test that error responses pass through unchanged by default.""" + response = Response( + status_code=HTTPStatus.NOT_FOUND.value, content=b"Not found" + ) + + result = transform._transform_error_response(response) + + assert result is response + assert result.status_code == HTTPStatus.NOT_FOUND.value + assert result.body == b"Not found" + + def test_handles_various_error_status_codes(self, transform): + """Test that various error status codes are handled.""" + error_codes = [ + HTTPStatus.BAD_REQUEST.value, + HTTPStatus.UNAUTHORIZED.value, + HTTPStatus.FORBIDDEN.value, + HTTPStatus.NOT_FOUND.value, + HTTPStatus.INTERNAL_SERVER_ERROR.value, + HTTPStatus.BAD_GATEWAY.value, + HTTPStatus.SERVICE_UNAVAILABLE.value, + ] + + for status_code in error_codes: + response = Response(status_code=status_code, content=b"error") + result = transform._transform_error_response(response) + assert result.status_code == status_code + + +class TestValidateRequestPreconditions: + """Test _validate_request_preconditions method.""" + + @pytest.fixture + def transform(self): + """Create concrete transform instance.""" + return ConcreteTransform(request_shape={}, response_shape={}) + + def test_default_implementation_does_nothing(self, transform): + """Test that default implementation doesn't raise exceptions.""" + mock_request = Mock(spec=Request) + mock_request.headers = {} + + # Should not raise any exception + transform._validate_request_preconditions(mock_request) + + def test_can_be_overridden_in_subclass(self): + """Test that subclasses can override validation.""" + + class CustomTransform(BaseEngineSessionApiTransform): + def _validate_request_preconditions(self, raw_request: Request) -> None: + if not raw_request.headers.get("X-Custom-Header"): + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail="Missing custom header", + ) + + def _transform_ok_response(self, response: Response, **kwargs) -> Response: + return response + + transform = CustomTransform(request_shape={}, response_shape={}) + mock_request = Mock(spec=Request) + mock_request.headers = {} + + with pytest.raises(HTTPException) as exc_info: + transform._validate_request_preconditions(mock_request) + + assert exc_info.value.status_code == HTTPStatus.BAD_REQUEST.value + assert "Missing custom header" in exc_info.value.detail + + +class TestAbstractMethods: + """Test that abstract methods must be implemented.""" + + def test_transform_ok_response_must_be_implemented(self): + """Test that _transform_ok_response must be implemented by subclasses.""" + + # Try to instantiate without implementing abstract method + with pytest.raises(TypeError) as exc_info: + + class IncompleteTransform(BaseEngineSessionApiTransform): + pass + + IncompleteTransform(request_shape={}, response_shape={}) + + assert "_transform_ok_response" in str(exc_info.value) + + +class TestTransformRequestOutputStructure: + """Test the structure of transform_request output.""" + + @pytest.fixture + def transform(self): + """Create concrete transform instance.""" + return ConcreteTransform( + request_shape={"param": "body.param"}, response_shape={} + ) + + @pytest.mark.asyncio + async def test_output_contains_transformed_request(self, transform): + """Test that output contains the transformed request data.""" + mock_request = AsyncMock(spec=Request) + mock_request.json.return_value = {"param": "value"} + mock_request.headers = {} + + result = await transform.transform_request(mock_request) + + assert result.request == {"param": "value"} + + @pytest.mark.asyncio + async def test_output_contains_raw_request(self, transform): + """Test that output contains the raw request object.""" + mock_request = AsyncMock(spec=Request) + mock_request.json.return_value = {"param": "value"} + mock_request.headers = {} + + result = await transform.transform_request(mock_request) + + assert result.raw_request is mock_request + + @pytest.mark.asyncio + async def test_output_intercept_func_is_none(self, transform): + """Test that intercept_func is None for base transform.""" + mock_request = AsyncMock(spec=Request) + mock_request.json.return_value = {"param": "value"} + mock_request.headers = {} + + result = await transform.transform_request(mock_request) + + assert result.intercept_func is None diff --git a/python/tests/sagemaker/sessions/transforms/test_close_session_transform.py b/python/tests/sagemaker/sessions/transforms/test_close_session_transform.py index 08c4e5c..11a0ec8 100644 --- a/python/tests/sagemaker/sessions/transforms/test_close_session_transform.py +++ b/python/tests/sagemaker/sessions/transforms/test_close_session_transform.py @@ -176,3 +176,129 @@ def test_passes_through_error_responses(self, transform): assert result.status_code == HTTPStatus.NOT_FOUND.value assert result.body == b"Session not found" + + +class TestCloseSessionEdgeCases: + """Test edge cases for CloseSessionApiTransform.""" + + @pytest.fixture + def transform(self): + """Create transform.""" + return CloseSessionApiTransform( + request_shape={}, + response_shape={RESPONSE_CONTENT_KEY: "body.message"}, + ) + + def test_handles_none_content(self, transform): + """Test that None content is handled gracefully.""" + response = Response( + status_code=HTTPStatus.OK.value, + content=json.dumps({"message": None}), + ) + + mock_request = Mock(spec=Request) + mock_request.headers = {SageMakerSessionHeader.SESSION_ID: "sess-123"} + transform_output = BaseTransformRequestOutput( + raw_request=mock_request, intercept_func=None + ) + + result = transform.transform_response(response, transform_output) + + assert result.status_code == HTTPStatus.OK.value + assert result.headers[SageMakerSessionHeader.CLOSED_SESSION_ID] == "sess-123" + + def test_handles_empty_string_content(self, transform): + """Test that empty string content is handled gracefully.""" + response = Response( + status_code=HTTPStatus.OK.value, + content=json.dumps({"message": ""}), + ) + + mock_request = Mock(spec=Request) + mock_request.headers = {SageMakerSessionHeader.SESSION_ID: "sess-123"} + transform_output = BaseTransformRequestOutput( + raw_request=mock_request, intercept_func=None + ) + + result = transform.transform_response(response, transform_output) + + assert result.status_code == HTTPStatus.OK.value + assert result.headers[SageMakerSessionHeader.CLOSED_SESSION_ID] == "sess-123" + + def test_extracts_content_from_nested_path(self, transform): + """Test extraction of content from nested response structure.""" + transform_nested = CloseSessionApiTransform( + request_shape={}, + response_shape={RESPONSE_CONTENT_KEY: "body.result.message"}, + ) + + response = Response( + status_code=HTTPStatus.OK.value, + content=json.dumps({"result": {"message": "Session closed successfully"}}), + ) + + mock_request = Mock(spec=Request) + mock_request.headers = {SageMakerSessionHeader.SESSION_ID: "sess-nested-123"} + transform_output = BaseTransformRequestOutput( + raw_request=mock_request, intercept_func=None + ) + + result = transform_nested.transform_response(response, transform_output) + + assert result.status_code == HTTPStatus.OK.value + assert ( + result.headers[SageMakerSessionHeader.CLOSED_SESSION_ID] + == "sess-nested-123" + ) + assert b"Session closed successfully" in result.body + + def test_handles_malformed_json_in_response(self, transform): + """Test that malformed JSON in response is handled gracefully. + + The serialize_response function catches JSONDecodeError and keeps the body as a string, + so malformed JSON doesn't cause the transform to fail. + """ + response = Response( + status_code=HTTPStatus.OK.value, + content=b"not valid json {{{", + ) + + mock_request = Mock(spec=Request) + mock_request.headers = {SageMakerSessionHeader.SESSION_ID: "sess-123"} + transform_output = BaseTransformRequestOutput( + raw_request=mock_request, intercept_func=None + ) + + # Should handle gracefully - malformed JSON is kept as string + result = transform.transform_response(response, transform_output) + + # Should still return a response with the session ID header + assert result.status_code == HTTPStatus.OK.value + assert result.headers[SageMakerSessionHeader.CLOSED_SESSION_ID] == "sess-123" + + @pytest.mark.asyncio + async def test_validates_session_id_before_transformation(self, transform): + """Test that session ID validation happens before request transformation.""" + mock_request = AsyncMock(spec=Request) + mock_request.json.return_value = {"reason": "timeout"} + mock_request.headers = {} # Missing session ID + + # Should fail validation before even attempting transformation + with pytest.raises(HTTPException) as exc_info: + await transform.transform_request(mock_request) + + assert exc_info.value.status_code == HTTPStatus.BAD_REQUEST.value + # json() should not have been called since validation failed first + mock_request.json.assert_not_called() + + @pytest.mark.asyncio + async def test_validates_empty_session_id(self, transform): + """Test that empty session ID is rejected.""" + mock_request = AsyncMock(spec=Request) + mock_request.json.return_value = {} + mock_request.headers = {SageMakerSessionHeader.SESSION_ID: ""} + + with pytest.raises(HTTPException) as exc_info: + await transform.transform_request(mock_request) + + assert exc_info.value.status_code == HTTPStatus.BAD_REQUEST.value diff --git a/python/tests/sagemaker/sessions/transforms/test_create_session_transform.py b/python/tests/sagemaker/sessions/transforms/test_create_session_transform.py index 93492f3..baa5b98 100644 --- a/python/tests/sagemaker/sessions/transforms/test_create_session_transform.py +++ b/python/tests/sagemaker/sessions/transforms/test_create_session_transform.py @@ -169,7 +169,7 @@ def test_passes_through_error_responses(self, transform): class TestCreateSessionNormalizeResponse: - """ normalization.""" + """normalization.""" @pytest.fixture def transform(self): @@ -224,3 +224,116 @@ def test_passes_through_response_object(self, transform): normalized = transform._normalize_response(response) assert normalized is response + + def test_normalizes_none_response(self, transform): + """Test normalization of None response.""" + normalized = transform._normalize_response(None) + + assert isinstance(normalized, Response) + assert normalized.status_code == HTTPStatus.OK.value + assert normalized.body == b"null" + + def test_normalizes_list_response(self, transform): + """Test normalization of list response.""" + response_list = [{"id": "sess-1"}, {"id": "sess-2"}] + + normalized = transform._normalize_response(response_list) + + assert isinstance(normalized, Response) + assert normalized.status_code == HTTPStatus.OK.value + body = json.loads(normalized.body) + assert len(body) == 2 + assert body[0]["id"] == "sess-1" + + +class TestCreateSessionEdgeCases: + """Test edge cases for CreateSessionApiTransform.""" + + @pytest.fixture + def transform(self): + """Create transform.""" + return CreateSessionApiTransform( + request_shape={}, + response_shape={ + SageMakerSessionHeader.NEW_SESSION_ID: "body.session_id", + RESPONSE_CONTENT_KEY: "body.message", + }, + ) + + def test_fails_when_session_id_is_none(self, transform): + """Test that None session ID raises HTTPException.""" + response = Response( + status_code=HTTPStatus.OK.value, + content=json.dumps({"session_id": None, "message": "created"}), + ) + + with pytest.raises(HTTPException) as exc_info: + transform.transform_response(response, Mock()) + + assert exc_info.value.status_code == HTTPStatus.BAD_GATEWAY.value + + def test_handles_none_content(self, transform): + """Test that None content is handled gracefully.""" + response = Response( + status_code=HTTPStatus.OK.value, + content=json.dumps({"session_id": "sess-123", "message": None}), + ) + + result = transform.transform_response(response, Mock()) + + assert result.status_code == HTTPStatus.OK.value + assert result.headers[SageMakerSessionHeader.NEW_SESSION_ID] == "sess-123" + + def test_handles_empty_string_content(self, transform): + """Test that empty string content is handled gracefully.""" + response = Response( + status_code=HTTPStatus.OK.value, + content=json.dumps({"session_id": "sess-123", "message": ""}), + ) + + result = transform.transform_response(response, Mock()) + + assert result.status_code == HTTPStatus.OK.value + assert result.headers[SageMakerSessionHeader.NEW_SESSION_ID] == "sess-123" + + def test_extracts_session_id_from_nested_path(self, transform): + """Test extraction of session ID from nested response structure.""" + transform_nested = CreateSessionApiTransform( + request_shape={}, + response_shape={ + SageMakerSessionHeader.NEW_SESSION_ID: "body.data.session.id", + RESPONSE_CONTENT_KEY: "body.data.message", + }, + ) + + response = Response( + status_code=HTTPStatus.OK.value, + content=json.dumps( + {"data": {"session": {"id": "sess-nested-123"}, "message": "created"}} + ), + ) + + result = transform_nested.transform_response(response, Mock()) + + assert result.status_code == HTTPStatus.OK.value + assert ( + result.headers[SageMakerSessionHeader.NEW_SESSION_ID] == "sess-nested-123" + ) + + def test_handles_malformed_json_in_response(self, transform): + """Test that malformed JSON in response is handled gracefully. + + The serialize_response function catches JSONDecodeError and keeps the body as a string, + but since we can't extract a session_id from a string, this should fail validation. + """ + response = Response( + status_code=HTTPStatus.OK.value, + content=b"not valid json {{{", + ) + + # Should fail because session_id cannot be extracted from malformed JSON + with pytest.raises(HTTPException) as exc_info: + transform.transform_response(response, Mock()) + + assert exc_info.value.status_code == HTTPStatus.BAD_GATEWAY.value + assert "session ID" in exc_info.value.detail From 5c01442a7c1e504f3d8b0d51a993b0b2fecfbdf8 Mon Sep 17 00:00:00 2001 From: Zuyi Zhao Date: Fri, 5 Dec 2025 17:34:10 +0000 Subject: [PATCH 17/25] Remove unnecessary bootstrap in integ tests. --- .../test_custom_session_handlers_integration.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/tests/integration/test_custom_session_handlers_integration.py b/python/tests/integration/test_custom_session_handlers_integration.py index 7dd0b4f..6a6a233 100644 --- a/python/tests/integration/test_custom_session_handlers_integration.py +++ b/python/tests/integration/test_custom_session_handlers_integration.py @@ -117,7 +117,7 @@ def setup_method(self): self.mock_engine = MockEngineAPI() self.setup_handlers() self.app.include_router(self.router) - sagemaker_standards.bootstrap(self.app) + # bootstrap() not needed - tests define their own routes self.client = TestClient(self.app) def setup_handlers(self): @@ -243,7 +243,7 @@ def setup_method(self): self.mock_engine = MockEngineAPI() self.setup_handlers() self.app.include_router(self.router) - sagemaker_standards.bootstrap(self.app) + # bootstrap() not needed - tests define their own routes self.client = TestClient(self.app) def setup_handlers(self): @@ -329,7 +329,7 @@ def setup_method(self): self.captured_requests = [] self.setup_handlers_with_session_id_path() self.app.include_router(self.router) - sagemaker_standards.bootstrap(self.app) + # bootstrap() not needed - tests define their own routes self.client = TestClient(self.app) def setup_handlers_with_session_id_path(self): @@ -414,7 +414,7 @@ def setup_method(self): self.mock_engine = MockEngineAPI() self.setup_custom_handlers() self.app.include_router(self.router) - sagemaker_standards.bootstrap(self.app) + # bootstrap() not needed - tests define their own routes self.client = TestClient(self.app) def setup_custom_handlers(self): @@ -525,7 +525,7 @@ def setup_method(self): self.mock_engine = MockEngineAPI() self.setup_complex_handlers() self.app.include_router(self.router) - sagemaker_standards.bootstrap(self.app) + # bootstrap() not needed - tests define their own routes self.client = TestClient(self.app) def setup_complex_handlers(self): @@ -611,7 +611,7 @@ def setup_method(self): self.mock_engine = MockEngineAPI() self.setup_realistic_handlers() self.app.include_router(self.router) - sagemaker_standards.bootstrap(self.app) + # bootstrap() not needed - tests define their own routes self.client = TestClient(self.app) def setup_realistic_handlers(self): From c94f0e3bdd9700c8680c96ea22076f65d5e07d23 Mon Sep 17 00:00:00 2001 From: Zuyi Zhao Date: Sat, 6 Dec 2025 00:15:06 +0000 Subject: [PATCH 18/25] chore(sessions): clarify parameter naming and improve documentation - Rename session_id_path to request_session_id_path in stateful_session_manager - Rename session_id_path to response_session_id_path in register handlers - Add docstring clarifying request_session_id_path usage - Provide default content messages for create/close session handlers - Remove specific engine references (vLLM, TGI) from documentation - Delete obsolete test_registration.py - Improve integration tests --- .../sagemaker/__init__.py | 20 +- .../sagemaker/sessions/CUSTOM_HANDLERS.md | 2 +- .../sagemaker/sessions/README.md | 2 +- .../sagemaker/sessions/__init__.py | 13 +- ...est_custom_session_handlers_integration.py | 1268 +++++++++-------- .../test_sagemaker_sessions_integration.py | 10 +- .../sagemaker/sessions/test_registration.py | 121 -- 7 files changed, 734 insertions(+), 702 deletions(-) delete mode 100644 python/tests/sagemaker/sessions/test_registration.py diff --git a/python/model_hosting_container_standards/sagemaker/__init__.py b/python/model_hosting_container_standards/sagemaker/__init__.py index 2d38fba..ac861e4 100644 --- a/python/model_hosting_container_standards/sagemaker/__init__.py +++ b/python/model_hosting_container_standards/sagemaker/__init__.py @@ -122,19 +122,23 @@ def inject_adapter_id( ) -def stateful_session_manager(session_id_path: Optional[str] = None): +def stateful_session_manager(request_session_id_path: Optional[str] = None): """Create a decorator for session-based sticky routing. This decorator enables stateful session management without JMESPath transformations. Pass empty dicts to enable transform infrastructure (for intercept functionality) without requiring JMESPath expressions. + Args: + request_session_id_path: JMESPath target path where session ID should be + injected INTO the request body from the session header + Returns: A decorator that can be applied to route handlers to enable session management """ request_shape = {} - if session_id_path: - request_shape[session_id_path] = ( + if request_session_id_path: + request_shape[request_session_id_path] = ( f'headers."{SageMakerSessionHeader.SESSION_ID}"' ) return create_session_transform_decorator()( @@ -143,19 +147,21 @@ def stateful_session_manager(session_id_path: Optional[str] = None): def register_create_session_handler( - request_shape, session_id_path: str, content_path: Optional[str] = None + request_shape, response_session_id_path: str, content_path: Optional[str] = None ): return register_engine_session_handler( "create_session", request_shape=request_shape, - session_id_path=session_id_path, - content_path=content_path, + response_session_id_path=response_session_id_path, + content_path=content_path or "`successfully created session.`", ) def register_close_session_handler(request_shape, content_path: Optional[str] = None): return register_engine_session_handler( - "close_session", request_shape=request_shape, content_path=content_path + "close_session", + request_shape=request_shape, + content_path=content_path or "`successfully closed session.`", ) diff --git a/python/model_hosting_container_standards/sagemaker/sessions/CUSTOM_HANDLERS.md b/python/model_hosting_container_standards/sagemaker/sessions/CUSTOM_HANDLERS.md index 3c06300..fa65de3 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/CUSTOM_HANDLERS.md +++ b/python/model_hosting_container_standards/sagemaker/sessions/CUSTOM_HANDLERS.md @@ -4,7 +4,7 @@ This guide explains how to implement custom create and close session handlers wh ## Overview -By default, SageMaker's session management uses the built-in `SessionManager` to handle session lifecycle. However, if your inference engine (like vLLM, TGI, or a custom engine) provides its own session API, you can register custom handlers to delegate session operations to the engine. +By default, SageMaker's session management uses the built-in `SessionManager` to handle session lifecycle. However, if your inference engine provides its own session API, you can register custom handlers to delegate session operations to the engine. ### When to Use Custom Handlers diff --git a/python/model_hosting_container_standards/sagemaker/sessions/README.md b/python/model_hosting_container_standards/sagemaker/sessions/README.md index 9e0309b..3896dd9 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/README.md +++ b/python/model_hosting_container_standards/sagemaker/sessions/README.md @@ -35,7 +35,7 @@ The framework supports two modes of session management: - Sessions delegated to your inference engine's native API - Leverages engine-specific session features - Requires custom handler registration - - Best when engine has built-in session support (e.g., vLLM, TGI) + - Best when engine has built-in session support - See [CUSTOM_HANDLERS.md](./CUSTOM_HANDLERS.md) for details ## Architecture diff --git a/python/model_hosting_container_standards/sagemaker/sessions/__init__.py b/python/model_hosting_container_standards/sagemaker/sessions/__init__.py index 9e1af91..37219d0 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/__init__.py +++ b/python/model_hosting_container_standards/sagemaker/sessions/__init__.py @@ -34,7 +34,7 @@ def _create_engine_session_transform_decorator(handler_type: str): def register_engine_session_handler( handler_type: str, request_shape, - session_id_path: Optional[str] = None, + response_session_id_path: Optional[str] = None, content_path: Optional[str] = None, ): """Register a handler for engine-specific session management. @@ -42,8 +42,9 @@ def register_engine_session_handler( Args: handler_type: Type of session handler ('create_session' or 'close_session') request_shape: JMESPath expressions for transforming request data - session_id_path: JMESPath expression for extracting session ID from response - (required for 'create_session', ignored for 'close_session') + response_session_id_path: JMESPath expression for extracting session ID FROM + the engine's response (required for 'create_session', + ignored for 'close_session') content_path: JMESPath expression for extracting content from response Returns: @@ -64,9 +65,9 @@ def register_engine_session_handler( } if handler_type == "create_session": - if not session_id_path: - raise ValueError("session_id_path is required for create_session") - response_shape[SageMakerSessionHeader.NEW_SESSION_ID] = session_id_path + if not response_session_id_path: + raise ValueError("response_session_id_path is required for create_session") + response_shape[SageMakerSessionHeader.NEW_SESSION_ID] = response_session_id_path return _create_engine_session_transform_decorator(handler_type)( request_shape, response_shape diff --git a/python/tests/integration/test_custom_session_handlers_integration.py b/python/tests/integration/test_custom_session_handlers_integration.py index 6a6a233..d3f872e 100644 --- a/python/tests/integration/test_custom_session_handlers_integration.py +++ b/python/tests/integration/test_custom_session_handlers_integration.py @@ -1,29 +1,40 @@ -"""Integration tests for custom session handlers. - -Tests the integration of custom engine-specific session handlers with the -SageMaker session management system. These tests verify that: -- Custom handlers can be registered and invoked -- Custom handlers take precedence over default handlers -- Request/response transformations work end-to-end -- Error handling propagates correctly from custom handlers +"""Integration tests for custom session handlers functionality. + +Tests the integration of custom session handlers with engine-specific session APIs using +the proper decorator-based registration pattern: +- @register_create_session_handler decorator +- @register_close_session_handler decorator +- Mixed scenarios (custom + default handlers) +- Transform request/response shape mapping via decorators +- Error handling in custom handlers +- Handler registration and resolution + +Key Testing Pattern: + These tests simulate real-world scenarios where an inference engine + has its own session management API. We use the + proper decorators to register handlers and verify that: + 1. Decorators properly register and invoke custom handlers + 2. Transforms correctly map between SageMaker and engine formats + 3. Session lifecycle works end-to-end with custom handlers + 4. Error cases are handled gracefully """ import json import os import shutil import tempfile -from http import HTTPStatus +import uuid from typing import Optional import pytest -from fastapi import APIRouter, FastAPI, Request, Response +from fastapi import APIRouter, FastAPI, Request +from fastapi.exceptions import HTTPException +from fastapi.responses import Response from fastapi.testclient import TestClient +from pydantic import BaseModel import model_hosting_container_standards.sagemaker as sagemaker_standards from model_hosting_container_standards.common.handler.registry import handler_registry -from model_hosting_container_standards.sagemaker.sessions import ( - register_engine_session_handler, -) from model_hosting_container_standards.sagemaker.sessions.manager import ( init_session_manager_from_env, ) @@ -31,6 +42,17 @@ SageMakerSessionHeader, ) +DEFAULT_SESSION_ID = "default-session" + + +class CreateSessionRequest(BaseModel): + capacity_of_str_len: int + session_id: Optional[str] = None + + +class CloseSessionRequest(BaseModel): + session_id: str + @pytest.fixture(autouse=True) def enable_sessions_for_integration(monkeypatch): @@ -41,12 +63,10 @@ def enable_sessions_for_integration(monkeypatch): monkeypatch.setenv("SAGEMAKER_SESSIONS_PATH", temp_dir) monkeypatch.setenv("SAGEMAKER_SESSIONS_EXPIRATION", "600") - # Reinitialize the global session manager init_session_manager_from_env() yield - # Clean up if os.path.exists(temp_dir): shutil.rmtree(temp_dir) monkeypatch.delenv("SAGEMAKER_ENABLE_STATEFUL_SESSIONS", raising=False) @@ -56,681 +76,805 @@ def enable_sessions_for_integration(monkeypatch): @pytest.fixture(autouse=True) -def clear_handler_registry(): - """Clear handler registry before and after each test.""" - handler_registry.clear() +def cleanup_handler_registry(): + """Clean up handler registry after each test.""" yield - handler_registry.clear() + handler_registry.remove_handler("create_session") + handler_registry.remove_handler("close_session") def extract_session_id_from_header(header_value: str) -> str: - """Extract session ID from SageMaker session header. - - Header format: "; Expires=" - """ + """Extract session ID from SageMaker session header.""" if ";" in header_value: return header_value.split(";")[0].strip() return header_value.strip() -class MockEngineAPI: - """Mock engine API that simulates an inference engine's session management. - - This simulates engines like vLLM or TGI that have their own session APIs. - """ - - def __init__(self): - self.sessions = {} - self.call_count = {"create": 0, "close": 0} - - def create_session(self, model: Optional[str] = None): - """Simulate engine's create session API.""" - self.call_count["create"] += 1 - session_id = f"engine-session-{self.call_count['create']}" - self.sessions[session_id] = {"model": model, "active": True} - return { - "session": {"id": session_id, "status": "active"}, - "message": f"Engine session created with model {model}", - } - - def close_session(self, session_id: str): - """Simulate engine's close session API.""" - self.call_count["close"] += 1 - if session_id not in self.sessions: - raise ValueError(f"Session {session_id} not found in engine") - self.sessions[session_id]["active"] = False - return {"result": {"message": f"Engine session {session_id} closed"}} - - def reset(self): - """Reset the mock engine state.""" - self.sessions.clear() - self.call_count = {"create": 0, "close": 0} +class BaseCustomHandlerIntegrationTest: + """Base class for custom handler integration tests with common setup. + Provides: + - FastAPI app and router setup + - Mock engine client for simulating engine APIs + - Handler call tracking + - TestClient for making requests + - Common setup/teardown patterns -class TestCustomHandlerRegistration: - """Test that custom handlers can be registered and invoked.""" + Subclasses should override setup_handlers() to register their specific + custom handlers using the appropriate decorators. + """ def setup_method(self): - """Set up test fixtures.""" + """Common setup for all custom handler integration tests.""" self.app = FastAPI() self.router = APIRouter() - self.mock_engine = MockEngineAPI() + + # Track handler invocations for verification + self.handler_calls = {"create": 0, "close": 0} + + # Setup handlers (to be overridden by subclasses) self.setup_handlers() + + # Bootstrap the app with SageMaker standards self.app.include_router(self.router) - # bootstrap() not needed - tests define their own routes + sagemaker_standards.bootstrap(self.app) self.client = TestClient(self.app) def setup_handlers(self): - """Set up handlers and register custom session handlers. + """Override in subclasses to register custom handlers. - Note: With the hybrid caching strategy, handlers can be registered - in any order - no timing dependency! + This method should: + 1. Define custom handler functions + 2. Register them using @register_create_session_handler or @register_close_session_handler + 3. Set up the /invocations endpoint with @stateful_session_manager """ - + self.setup_common_handlers() + self.setup_invocation_handler() + + def custom_create_session(self, obj: CreateSessionRequest, request: Request): + # Implement in child classes + pass + + def custom_close_session(self, obj: CloseSessionRequest, request: Request): + # Implement in child classes + pass + + def setup_common_handlers(self): + @sagemaker_standards.register_create_session_handler( + request_shape={ + "capacity_of_str_len": "`1024`", + }, + response_session_id_path="body", + content_path="`successfully created session.`", + ) + @self.app.api_route("/open_session", methods=["GET", "POST"]) + async def create_session(obj: CreateSessionRequest, request: Request): + return self.custom_create_session(obj, request) + + @sagemaker_standards.register_close_session_handler( + request_shape={ + "session_id": f'headers."{SageMakerSessionHeader.SESSION_ID}"' + }, + content_path="`successfully closed session.`", + ) + @self.app.api_route("/close_session", methods=["GET", "POST"]) + async def close_session(obj: CloseSessionRequest, request: Request): + return self.custom_close_session(obj, request) + + async def custom_invocations(self, request: Request): + body_bytes = await request.body() + body = json.loads(body_bytes.decode()) + # Extract session ID from request headers if present + session_id = body.get("session_id") or request.headers.get( + SageMakerSessionHeader.SESSION_ID + ) + return Response( + status_code=200, + content=json.dumps( + { + "message": "Request in session", + "session_id": session_id or "no-session", + "echo": body, + } + ), + ) + + def setup_invocation_handler(self): @self.router.post("/invocations") @sagemaker_standards.stateful_session_manager() async def invocations(request: Request): - """Handler with session management.""" - body_bytes = await request.body() - body = json.loads(body_bytes.decode()) - return Response( - status_code=200, - content=json.dumps({"message": "success", "echo": body}), - ) + return await self.custom_invocations(request) - # Register custom create session handler - @register_engine_session_handler( - handler_type="create_session", - request_shape={}, # Session requests don't have extra fields - session_id_path="body.session.id", # Path to session ID in response body - content_path="body.message", # Path to message in response body - ) - async def custom_create_session(raw_request: Request): - """Custom handler that delegates to engine API.""" - # Call mock engine API with default model - result = self.mock_engine.create_session("default-model") - return result - - # Register custom close session handler - @register_engine_session_handler( - handler_type="close_session", - request_shape={}, - content_path="body.result.message", # Path to message in response body - ) - async def custom_close_session(raw_request: Request): - """Custom handler that delegates to engine API.""" - session_id = raw_request.headers.get(SageMakerSessionHeader.SESSION_ID) - - # Call mock engine API - result = self.mock_engine.close_session(session_id) - return result - - def test_custom_create_handler_is_invoked(self): - """Test that custom create handler is called instead of default.""" + # Helper methods for common test operations + def create_session(self) -> str: + """Helper to create a session and return the session ID.""" response = self.client.post("/invocations", json={"requestType": "NEW_SESSION"}) - assert response.status_code == 200 assert SageMakerSessionHeader.NEW_SESSION_ID in response.headers - - # Verify custom handler was called - assert self.mock_engine.call_count["create"] == 1 - - def test_custom_close_handler_is_invoked(self): - """Test that custom close handler is called instead of default.""" - # First create a session - create_response = self.client.post( - "/invocations", json={"requestType": "NEW_SESSION"} + return extract_session_id_from_header( + response.headers[SageMakerSessionHeader.NEW_SESSION_ID] ) - session_id = extract_session_id_from_header( - create_response.headers[SageMakerSessionHeader.NEW_SESSION_ID] + + def create_session_with_id(self, session_id: str) -> Response: + """Helper to create a session with a specific ID.""" + return self.client.post( + "/invocations", + json={"requestType": "NEW_SESSION"}, + headers={SageMakerSessionHeader.SESSION_ID: session_id}, ) - # Now close it - close_response = self.client.post( + def close_session(self, session_id: str) -> Response: + """Helper to close a session.""" + return self.client.post( "/invocations", json={"requestType": "CLOSE"}, headers={SageMakerSessionHeader.SESSION_ID: session_id}, ) - if close_response.status_code != 200: - print(f"Close response status: {close_response.status_code}") - print(f"Close response body: {close_response.text}") + def invoke_with_session(self, session_id: str, body: dict) -> Response: + """Helper to make an invocation request with a session.""" + return self.client.post( + "/invocations", + json=body, + headers={SageMakerSessionHeader.SESSION_ID: session_id}, + ) - assert close_response.status_code == 200 - assert SageMakerSessionHeader.CLOSED_SESSION_ID in close_response.headers - # Verify custom handler was called - assert self.mock_engine.call_count["close"] == 1 +class TestSimpleCreateSessionCustomHandler(BaseCustomHandlerIntegrationTest): + """Test basic custom create session handler with simple string return.""" + + def custom_create_session(self, obj: CreateSessionRequest, request: Request): + return DEFAULT_SESSION_ID + + def test_create_new_session(self): + """Test that custom handler returning a simple string works correctly. - def test_custom_handler_response_transformation(self): - """Test that response is properly transformed from engine format to SageMaker format.""" + This validates the simplest case where a custom handler returns just a string + (the session ID) rather than a complex object. This is useful when the engine's + session API returns a simple session identifier. + """ + # Send NEW_SESSION request to trigger custom create handler response = self.client.post("/invocations", json={"requestType": "NEW_SESSION"}) - # Verify response has SageMaker format + # Verify successful session creation assert response.status_code == 200 + assert SageMakerSessionHeader.NEW_SESSION_ID in response.headers + + # Extract session ID from response header session_id = extract_session_id_from_header( response.headers[SageMakerSessionHeader.NEW_SESSION_ID] ) - # Session ID should be from engine (engine-session-X format) - assert session_id.startswith("engine-session-") + # Verify the custom handler's return value (DEFAULT_SESSION_ID) is used as session ID + # This confirms the transform correctly extracted the session ID from the string response + assert session_id == DEFAULT_SESSION_ID - # Response body should contain engine message - assert b"Engine session created" in response.content - def test_custom_handler_request_transformation(self): - """Test that custom handler is invoked for session creation.""" - self.client.post( - "/invocations", - json={"requestType": "NEW_SESSION"}, - ) +class TestErrorCreateSessionCustomHandler(BaseCustomHandlerIntegrationTest): + """Test error handling when custom create session handler fails.""" - # Verify the engine received the request - assert self.mock_engine.call_count["create"] == 1 - # Check that session was created in engine - sessions = list(self.mock_engine.sessions.values()) - assert len(sessions) == 1 - assert sessions[0]["model"] == "default-model" + def custom_create_session(self, obj: CreateSessionRequest, request: Request): + raise HTTPException(status_code=400, detail="Engine failed to create session") + def test_create_new_session_error(self): + """Test that errors from custom create handler are properly propagated. -class TestCustomHandlerErrorHandling: - """Test error handling in custom handlers.""" + When the underlying engine fails to create a session (e.g., resource exhaustion, + invalid parameters), the error should be propagated to the client with appropriate + status code and error message. + """ + # Attempt to create session - custom handler will raise HTTPException + response = self.client.post("/invocations", json={"requestType": "NEW_SESSION"}) + + # Verify error status code is returned + assert response.status_code == 400 + + # Verify error message from custom handler is included in response + assert "Engine failed to create session" in response.text + + # Verify no session header is present on error (session was not created) + assert SageMakerSessionHeader.NEW_SESSION_ID not in response.headers + + +class TestErrorCloseSessionCustomHandler(BaseCustomHandlerIntegrationTest): + """Test error handling when custom close session handler fails.""" def setup_method(self): - """Set up test fixtures.""" - self.app = FastAPI() - self.router = APIRouter() - self.mock_engine = MockEngineAPI() - self.setup_handlers() - self.app.include_router(self.router) - # bootstrap() not needed - tests define their own routes - self.client = TestClient(self.app) + self.sessions = {} + super().setup_method() - def setup_handlers(self): - """Set up handlers and register custom session handlers.""" + def custom_create_session(self, obj: CreateSessionRequest, request: Request): + session_id = str(uuid.uuid4()) + self.sessions[session_id] = session_id + return session_id - @self.router.post("/invocations") - @sagemaker_standards.stateful_session_manager() - async def invocations(request: Request): - """Handler with session management.""" - body_bytes = await request.body() - body = json.loads(body_bytes.decode()) + def custom_close_session(self, obj: CloseSessionRequest, request: Request): + if obj.session_id in self.sessions: + self.sessions.pop(obj.session_id) return Response( - status_code=200, - content=json.dumps({"message": "success", "echo": body}), + status_code=200, content=f"Session {obj.session_id} closed." ) - - @register_engine_session_handler( - handler_type="create_session", - request_shape={}, - session_id_path="body.session.id", - content_path="body.message", + raise HTTPException( + status_code=404, detail=f"Session {obj.session_id} does not exist." ) - async def custom_create_session(raw_request: Request): - """Custom handler that can fail.""" - result = self.mock_engine.create_session() - return result - - @register_engine_session_handler( - handler_type="close_session", - request_shape={}, - content_path="body.result.message", - ) - async def custom_close_session(raw_request: Request): - """Custom handler that validates session exists.""" - session_id = raw_request.headers.get(SageMakerSessionHeader.SESSION_ID) - # This will raise ValueError if session not found - result = self.mock_engine.close_session(session_id) - return result - - def test_custom_handler_error_propagates(self): - """Test that errors from custom handlers propagate correctly. - - Note: Unhandled exceptions in custom handlers will bubble up through FastAPI. - In production, these should be caught by FastAPI's exception handlers. - For testing with TestClient, the exception is raised directly. - """ - # Try to close a non-existent session - this will raise ValueError - with pytest.raises(ValueError, match="Session nonexistent-session not found"): - self.client.post( - "/invocations", - json={"requestType": "CLOSE"}, - headers={SageMakerSessionHeader.SESSION_ID: "nonexistent-session"}, - ) - def test_custom_handler_missing_session_id_in_response(self): - """Test handling when custom handler doesn't return session ID.""" + def test_duplicate_close_session(self): + """Test that closing an already-closed session returns 404. - # Create a handler that returns invalid response - @register_engine_session_handler( - handler_type="create_session", - request_shape={}, - session_id_path="body.session.id", # Path that won't exist - content_path="body.message", - ) - async def broken_create_session(raw_request: Request): - """Handler that returns response without session.id.""" - return {"message": "created but no session id"} + This validates idempotency handling - attempting to close a session that's + already been closed should return a 404 error rather than succeeding silently. + This is important for detecting client-side bugs or race conditions. + """ + # Create a new session for testing + session_id = self.create_session() - response = self.client.post("/invocations", json={"requestType": "NEW_SESSION"}) + # First close should succeed - session exists in custom handler's storage + success_response = self.close_session(session_id) + assert success_response.status_code == 200 + assert SageMakerSessionHeader.CLOSED_SESSION_ID in success_response.headers - # Should fail with BAD_GATEWAY since session ID is missing - assert response.status_code == HTTPStatus.BAD_GATEWAY.value + # Second close should fail - session no longer exists (was removed on first close) + # Custom handler raises HTTPException(404) when session not found + duplicate_response = self.close_session(session_id) + assert duplicate_response.status_code == 404 -class TestCustomHandlerWithSessionIdPath: - """Test custom handlers work with session_id_path parameter.""" +class TestCustomSessionEndToEndFlow(BaseCustomHandlerIntegrationTest): + """Test complete end-to-end flows with custom session handlers.""" def setup_method(self): - """Set up test fixtures.""" - self.app = FastAPI() - self.router = APIRouter() - self.mock_engine = MockEngineAPI() - self.captured_requests = [] - self.setup_handlers_with_session_id_path() - self.app.include_router(self.router) - # bootstrap() not needed - tests define their own routes - self.client = TestClient(self.app) - - def setup_handlers_with_session_id_path(self): - """Set up handlers that use session_id_path.""" - + self.sessions = {} + super().setup_method() + + def custom_create_session(self, obj: CreateSessionRequest, request: Request): + self.handler_calls["create"] += 1 + if not obj.session_id: + obj.session_id = str(uuid.uuid4()) + if obj.session_id in self.sessions: + return Response(status_code=400) + self.sessions[obj.session_id] = obj.session_id + return {"session_id": obj.session_id} + + def custom_close_session(self, obj: CloseSessionRequest, request: Request): + self.handler_calls["close"] += 1 + if obj.session_id not in self.sessions: + raise HTTPException( + status_code=404, detail=f"Session {obj.session_id} does not exist." + ) + self.sessions.pop(obj.session_id) + return Response(status_code=200, content=f"Session {obj.session_id} closed.") + + def setup_common_handlers(self): + @sagemaker_standards.register_create_session_handler( + request_shape={ + "capacity_of_str_len": "`1024`", + "session_id": f'headers."{SageMakerSessionHeader.SESSION_ID}"', + }, + response_session_id_path="body.session_id", # Nested + content_path="`successfully created session.`", + ) + @self.app.api_route("/open_session", methods=["GET", "POST"]) + async def create_session(obj: CreateSessionRequest, request: Request): + return self.custom_create_session(obj, request) + + @sagemaker_standards.register_close_session_handler( + request_shape={ + "session_id": f'headers."{SageMakerSessionHeader.SESSION_ID}"' + }, + content_path="`successfully closed session.`", + ) + @self.app.api_route("/close_session", methods=["GET", "POST"]) + async def close_session(obj: CloseSessionRequest, request: Request): + return self.custom_close_session(obj, request) + + def setup_invocation_handler(self): @self.router.post("/invocations") - @sagemaker_standards.stateful_session_manager(session_id_path="session_id") + @sagemaker_standards.stateful_session_manager( + request_session_id_path="session_id" + ) async def invocations(request: Request): - """Handler that injects session ID into request body.""" - body_bytes = await request.body() - body = json.loads(body_bytes.decode()) + return await self.custom_invocations(request) - # Capture for verification - self.captured_requests.append(body) + def test_create_existing_session_error_handling(self): + """Test that attempting to create a session with existing ID fails. - return Response( - status_code=200, - content=json.dumps( - { - "message": "success", - "session_id_from_body": body.get("session_id"), - } - ), - ) + This validates that the custom handler properly rejects attempts to create + a session with a duplicate ID. This prevents session ID collisions and ensures + session uniqueness. + """ + # Create initial session + session_id = self.create_session() - @register_engine_session_handler( - handler_type="create_session", - request_shape={}, - session_id_path="body.session.id", - content_path="body.message", - ) - async def custom_create_session(raw_request: Request): - """Custom create handler.""" - result = self.mock_engine.create_session() - return result - - @register_engine_session_handler( - handler_type="close_session", - request_shape={}, - content_path="body.result.message", - ) - async def custom_close_session(raw_request: Request): - """Custom close handler.""" - session_id = raw_request.headers.get(SageMakerSessionHeader.SESSION_ID) - result = self.mock_engine.close_session(session_id) - return result - - def test_session_id_injected_with_custom_handler(self): - """Test that session_id_path works with custom handlers.""" - # Create session with custom handler - create_response = self.client.post( - "/invocations", json={"requestType": "NEW_SESSION"} - ) - session_id = extract_session_id_from_header( - create_response.headers[SageMakerSessionHeader.NEW_SESSION_ID] - ) + # Try to create another session with the same ID by passing it in the header + # Custom handler checks if session_id already exists and returns 400 if it does + header_response = self.create_session_with_id(session_id) + assert header_response.status_code == 400 - # Make request with session ID - should be injected into body - self.captured_requests.clear() - response = self.client.post( - "/invocations", - json={"prompt": "test"}, - headers={SageMakerSessionHeader.SESSION_ID: session_id}, - ) + def test_end_to_end_simple(self): + """Test complete session lifecycle: create -> use -> close. - assert response.status_code == 200 - data = json.loads(response.text) + This is the primary happy path test that validates the full session workflow + works correctly with custom handlers. This simulates a typical client interaction + pattern for stateful ML inference (e.g., multi-turn conversation with an LLM). + """ + # Step 1: Create session via custom handler + session_id = self.create_session() + + # Step 2: Use session for inference request + # Session ID is passed in header and should be available to the handler + invoke_response = self.invoke_with_session(session_id, {"prompt": "hello"}) + assert invoke_response.status_code == 200 + # Verify session ID is echoed back, confirming session context was maintained + assert session_id in invoke_response.text + + # Step 3: Close session via custom handler + close_response = self.close_session(session_id) + assert close_response.status_code == 200 + # Verify closed session header is returned + assert SageMakerSessionHeader.CLOSED_SESSION_ID in close_response.headers - # Verify session ID was injected - assert data["session_id_from_body"] == session_id - assert len(self.captured_requests) == 1 - assert self.captured_requests[0]["session_id"] == session_id + def test_handler_call_tracking(self): + """Test that custom handlers are actually being invoked. + This validates that the decorator registration system correctly routes session + requests to the custom handlers rather than using default handlers. The counters + prove the custom handler code is executing. + """ + # Reset counters to ensure clean state + self.handler_calls = {"create": 0, "close": 0} -class TestCustomHandlerConcurrency: - """Test concurrent operations with custom handlers.""" + # Create session - should increment create counter + session_id = self.create_session() + assert self.handler_calls["create"] == 1 # Custom create handler was called + assert self.handler_calls["close"] == 0 # Close handler not called yet - def setup_method(self): - """Set up test fixtures.""" - self.app = FastAPI() - self.router = APIRouter() - self.mock_engine = MockEngineAPI() - self.setup_custom_handlers() - self.app.include_router(self.router) - # bootstrap() not needed - tests define their own routes - self.client = TestClient(self.app) + # Close session - should increment close counter + close_response = self.close_session(session_id) + assert close_response.status_code == 200 + assert self.handler_calls["create"] == 1 # Create counter unchanged + assert self.handler_calls["close"] == 1 # Custom close handler was called - def setup_custom_handlers(self): - """Register custom handlers.""" + def test_multiple_sessions_independent_state(self): + """Test that multiple sessions maintain independent state in custom handlers. - @self.router.post("/invocations") - @sagemaker_standards.stateful_session_manager() - async def invocations(request: Request): - """Handler with session management.""" - body_bytes = await request.body() - body = json.loads(body_bytes.decode()) + This validates session isolation - multiple concurrent sessions should not + interfere with each other. This is critical for multi-tenant scenarios where + different users/clients have active sessions simultaneously. + """ + # Create two independent sessions + session1_id = self.create_session() + session2_id = self.create_session() + + # Verify both sessions exist in custom handler's storage + assert session1_id in self.sessions + assert session2_id in self.sessions + # Verify sessions have unique IDs + assert session1_id != session2_id + + # Close first session only + self.close_session(session1_id) + + # Verify only first session was removed from storage + assert session1_id not in self.sessions + # Verify second session still exists and is unaffected + assert session2_id in self.sessions + + # Verify second session is still functional after first session closed + response = self.invoke_with_session(session2_id, {"prompt": "test"}) + assert response.status_code == 200 + + +class TestCustomHandlerResponseFormats(BaseCustomHandlerIntegrationTest): + """Test that custom handlers can return different response formats.""" + + def setup_method(self): + self.sessions = {} + self.response_format = "dict" # Can be "dict", "string", or "response_object" + super().setup_method() + + def custom_create_session(self, obj: CreateSessionRequest, request: Request): + session_id = str(uuid.uuid4()) + self.sessions[session_id] = True + + if self.response_format == "dict": + return {"session_id": session_id, "metadata": {"engine": "custom"}} + elif self.response_format == "string": + return session_id + elif self.response_format == "response_object": return Response( - status_code=200, - content=json.dumps({"message": "success", "echo": body}), + status_code=201, + content=json.dumps({"session_id": session_id}), + media_type="application/json", ) - @register_engine_session_handler( - handler_type="create_session", - request_shape={}, - session_id_path="body.session.id", - content_path="body.message", - ) - async def custom_create_session(raw_request: Request): - """Custom create handler.""" - result = self.mock_engine.create_session() - return result - - @register_engine_session_handler( - handler_type="close_session", - request_shape={}, - content_path="body.result.message", - ) - async def custom_close_session(raw_request: Request): - """Custom close handler.""" - session_id = raw_request.headers.get(SageMakerSessionHeader.SESSION_ID) - result = self.mock_engine.close_session(session_id) - return result - - def test_multiple_sessions_with_custom_handlers(self): - """Test creating and managing multiple sessions with custom handlers.""" - # Create multiple sessions - session_ids = [] - for i in range(3): - response = self.client.post( - "/invocations", json={"requestType": "NEW_SESSION"} - ) - assert response.status_code == 200 - session_id = extract_session_id_from_header( - response.headers[SageMakerSessionHeader.NEW_SESSION_ID] - ) - session_ids.append(session_id) - - # Verify all sessions were created in engine - assert self.mock_engine.call_count["create"] == 3 - assert len(self.mock_engine.sessions) == 3 - - # Close all sessions - for session_id in session_ids: - response = self.client.post( - "/invocations", - json={"requestType": "CLOSE"}, - headers={SageMakerSessionHeader.SESSION_ID: session_id}, - ) - assert response.status_code == 200 + def custom_close_session(self, obj: CloseSessionRequest, request: Request): + if obj.session_id in self.sessions: + del self.sessions[obj.session_id] + return Response(status_code=200, content="Closed") - # Verify all sessions were closed in engine - assert self.mock_engine.call_count["close"] == 3 + def setup_common_handlers(self): + # Use different response_session_id_path based on format + response_path = "body.session_id" if self.response_format == "dict" else "body" - def test_interleaved_operations_with_custom_handlers(self): - """Test interleaved create/use/close operations.""" - # Create session 1 - response1 = self.client.post( - "/invocations", json={"requestType": "NEW_SESSION"} - ) - session1_id = extract_session_id_from_header( - response1.headers[SageMakerSessionHeader.NEW_SESSION_ID] + @sagemaker_standards.register_create_session_handler( + request_shape={"capacity_of_str_len": "`1024`"}, + response_session_id_path=response_path, + content_path="`successfully created session.`", ) + @self.app.api_route("/open_session", methods=["GET", "POST"]) + async def create_session(obj: CreateSessionRequest, request: Request): + return self.custom_create_session(obj, request) - # Create session 2 - response2 = self.client.post( - "/invocations", json={"requestType": "NEW_SESSION"} - ) - session2_id = extract_session_id_from_header( - response2.headers[SageMakerSessionHeader.NEW_SESSION_ID] + @sagemaker_standards.register_close_session_handler( + request_shape={ + "session_id": f'headers."{SageMakerSessionHeader.SESSION_ID}"' + }, + content_path="`successfully closed session.`", ) + @self.app.api_route("/close_session", methods=["GET", "POST"]) + async def close_session(obj: CloseSessionRequest, request: Request): + return self.custom_close_session(obj, request) - # Close session 1 - close1 = self.client.post( - "/invocations", - json={"requestType": "CLOSE"}, - headers={SageMakerSessionHeader.SESSION_ID: session1_id}, - ) - assert close1.status_code == 200 + def test_dict_response_with_metadata(self): + """Test custom handler returning dict with additional metadata. - # Session 2 should still work - # (Note: In this test we're just verifying the close worked) - assert self.mock_engine.sessions[session1_id]["active"] is False - assert self.mock_engine.sessions[session2_id]["active"] is True + Many engine APIs return rich response objects with metadata alongside the + session ID. This validates that the transform can extract the session ID + from a nested path while preserving other response data. + """ + self.response_format = "dict" + # Create session - handler returns {"session_id": "...", "metadata": {...}} + session_id = self.create_session() + # Verify session was created successfully + assert session_id in self.sessions + # Verify session ID is in UUID format (36 characters with hyphens) + assert len(session_id) == 36 -class TestCustomHandlerComplexTransformations: - """Test custom handlers with complex request/response transformations.""" + def test_dict_response_with_nested_session_id(self): + """Test custom handler returning dict with nested session ID path. - def setup_method(self): - """Set up test fixtures.""" - self.app = FastAPI() - self.router = APIRouter() - self.mock_engine = MockEngineAPI() - self.setup_complex_handlers() - self.app.include_router(self.router) - # bootstrap() not needed - tests define their own routes - self.client = TestClient(self.app) + This validates that the response_session_id_path configuration correctly + extracts the session ID from nested response structures (e.g., body.session_id). + """ + self.response_format = "dict" + # Create session with nested response structure + session_id = self.create_session() - def setup_complex_handlers(self): - """Register handlers with complex transformations.""" + # Verify session was created and can be used for subsequent requests + response = self.invoke_with_session(session_id, {"test": "data"}) + assert response.status_code == 200 - @self.router.post("/invocations") - @sagemaker_standards.stateful_session_manager() - async def invocations(request: Request): - """Handler with session management.""" - body_bytes = await request.body() - body = json.loads(body_bytes.decode()) - return Response( - status_code=200, - content=json.dumps({"message": "success", "echo": body}), - ) - @register_engine_session_handler( - handler_type="create_session", - request_shape={}, - session_id_path="body.session.id", - content_path="body.message", - ) - async def custom_create_session(raw_request: Request): - """Handler for session creation.""" - result = self.mock_engine.create_session("default") - return result - - @register_engine_session_handler( - handler_type="close_session", - request_shape={}, - content_path="body.result.message", - ) - async def custom_close_session(raw_request: Request): - """Handler for session closure.""" - session_id = raw_request.headers.get(SageMakerSessionHeader.SESSION_ID) +class TestCustomHandlerMultipleInvocations(BaseCustomHandlerIntegrationTest): + """Test multiple invocations within the same session with custom handlers.""" - # Engine API closes the session - result = self.mock_engine.close_session(session_id) - return result + def setup_method(self): + self.sessions = {} + self.invocation_counts = {} + super().setup_method() + + def custom_create_session(self, obj: CreateSessionRequest, request: Request): + session_id = str(uuid.uuid4()) + self.sessions[session_id] = {"created": True} + self.invocation_counts[session_id] = 0 + return {"session_id": session_id} + + def custom_close_session(self, obj: CloseSessionRequest, request: Request): + if obj.session_id in self.sessions: + del self.sessions[obj.session_id] + if obj.session_id in self.invocation_counts: + del self.invocation_counts[obj.session_id] + return Response(status_code=200) + + def setup_common_handlers(self): + @sagemaker_standards.register_create_session_handler( + request_shape={"capacity_of_str_len": "`1024`"}, + response_session_id_path="body.session_id", + content_path="`successfully created session.`", + ) + @self.app.api_route("/open_session", methods=["GET", "POST"]) + async def create_session(obj: CreateSessionRequest, request: Request): + return self.custom_create_session(obj, request) + + @sagemaker_standards.register_close_session_handler( + request_shape={ + "session_id": f'headers."{SageMakerSessionHeader.SESSION_ID}"' + }, + content_path="`successfully closed session.`", + ) + @self.app.api_route("/close_session", methods=["GET", "POST"]) + async def close_session(obj: CloseSessionRequest, request: Request): + return self.custom_close_session(obj, request) + + async def custom_invocations(self, request: Request): + body_bytes = await request.body() + body = json.loads(body_bytes.decode()) + session_id = request.headers.get(SageMakerSessionHeader.SESSION_ID) + + # Track invocation count per session + if session_id and session_id in self.invocation_counts: + self.invocation_counts[session_id] += 1 + + return Response( + status_code=200, + content=json.dumps( + { + "message": "success", + "session_id": session_id, + "invocation_count": self.invocation_counts.get(session_id, 0), + "echo": body, + } + ), + ) + + def test_multiple_invocations_same_session(self): + """Test that multiple invocations work correctly within the same session. + + This validates that session state (invocation count) accumulates correctly + across multiple requests. This is essential for stateful ML scenarios like + maintaining conversation context or tracking request history. + """ + session_id = self.create_session() - def test_complex_request_transformation(self): - """Test that custom handlers work with session creation.""" - response = self.client.post( - "/invocations", - json={"requestType": "NEW_SESSION"}, - ) + # Make 5 sequential invocations to the same session + for i in range(5): + response = self.invoke_with_session(session_id, {"request_num": i + 1}) + assert response.status_code == 200 + data = json.loads(response.text) + # Verify invocation count increments with each request + assert data["invocation_count"] == i + 1 + # Verify session ID remains consistent + assert data["session_id"] == session_id - assert response.status_code == 200 - assert SageMakerSessionHeader.NEW_SESSION_ID in response.headers + def test_invocation_counts_independent_across_sessions(self): + """Test that invocation counts are independent across different sessions. - # Verify engine created session - assert self.mock_engine.call_count["create"] == 1 + This validates session isolation at the invocation level - each session + maintains its own independent counter. Critical for ensuring one user's + session activity doesn't affect another user's session. + """ + # Create two separate sessions + session1_id = self.create_session() + session2_id = self.create_session() - def test_close_with_custom_handler(self): - """Test close handler with custom implementation.""" - # Create session first - create_response = self.client.post( - "/invocations", json={"requestType": "NEW_SESSION"} - ) - session_id = extract_session_id_from_header( - create_response.headers[SageMakerSessionHeader.NEW_SESSION_ID] - ) + # Make 3 invocations to session 1 + for i in range(3): + self.invoke_with_session(session1_id, {"msg": "session1"}) - # Close session - close_response = self.client.post( - "/invocations", - json={"requestType": "CLOSE"}, - headers={SageMakerSessionHeader.SESSION_ID: session_id}, - ) + # Make 5 invocations to session 2 + for i in range(5): + self.invoke_with_session(session2_id, {"msg": "session2"}) - assert close_response.status_code == 200 - assert SageMakerSessionHeader.CLOSED_SESSION_ID in close_response.headers - assert self.mock_engine.call_count["close"] == 1 + # Verify each session has its own independent count + assert self.invocation_counts[session1_id] == 3 + assert self.invocation_counts[session2_id] == 5 -class TestCustomHandlerEndToEnd: - """End-to-end tests simulating real engine integration patterns.""" +class TestCustomHandlerWithSessionIdInjection(BaseCustomHandlerIntegrationTest): + """Test custom handlers with request_session_id_path parameter.""" def setup_method(self): - """Set up test fixtures.""" - self.app = FastAPI() - self.router = APIRouter() - self.mock_engine = MockEngineAPI() - self.setup_realistic_handlers() - self.app.include_router(self.router) - # bootstrap() not needed - tests define their own routes - self.client = TestClient(self.app) - - def setup_realistic_handlers(self): - """Set up handlers that simulate real vLLM/TGI integration.""" - + self.sessions = {} + super().setup_method() + + def custom_create_session(self, obj: CreateSessionRequest, request: Request): + session_id = str(uuid.uuid4()) + self.sessions[session_id] = {"created": True} + return {"session_id": session_id} + + def custom_close_session(self, obj: CloseSessionRequest, request: Request): + if obj.session_id in self.sessions: + del self.sessions[obj.session_id] + return Response(status_code=200, content="Session closed") + raise HTTPException(status_code=404, detail="Session not found") + + def setup_common_handlers(self): + @sagemaker_standards.register_create_session_handler( + request_shape={"capacity_of_str_len": "`1024`"}, + response_session_id_path="body.session_id", + content_path="`successfully created session.`", + ) + @self.app.api_route("/open_session", methods=["GET", "POST"]) + async def create_session(obj: CreateSessionRequest, request: Request): + return self.custom_create_session(obj, request) + + @sagemaker_standards.register_close_session_handler( + request_shape={ + "session_id": f'headers."{SageMakerSessionHeader.SESSION_ID}"' + }, + content_path="`successfully closed session.`", + ) + @self.app.api_route("/close_session", methods=["GET", "POST"]) + async def close_session(obj: CloseSessionRequest, request: Request): + return self.custom_close_session(obj, request) + + def setup_invocation_handler(self): @self.router.post("/invocations") @sagemaker_standards.stateful_session_manager( - session_id_path="metadata.session_id" + request_session_id_path="metadata.session_id" ) async def invocations(request: Request): - """Realistic inference handler.""" body_bytes = await request.body() body = json.loads(body_bytes.decode()) - # Simulate inference with session context + # Extract session ID from nested path session_id = body.get("metadata", {}).get("session_id") - prompt = body.get("prompt", "") return Response( status_code=200, content=json.dumps( - { - "generated_text": f"Response to: {prompt}", - "session_id": session_id, - "metadata": {"tokens": 42}, - } + {"message": "success", "session_id": session_id, "body": body} ), ) - @register_engine_session_handler( - handler_type="create_session", - request_shape={}, - session_id_path="body.session.id", - content_path="body.message", - ) - async def vllm_create_session(raw_request: Request): - """Simulate vLLM session creation.""" - result = self.mock_engine.create_session("default-model") - return result - - @register_engine_session_handler( - handler_type="close_session", - request_shape={}, - content_path="body.result.message", - ) - async def vllm_close_session(raw_request: Request): - """Simulate vLLM session closure.""" - session_id = raw_request.headers.get(SageMakerSessionHeader.SESSION_ID) - result = self.mock_engine.close_session(session_id) - return result - - def test_full_lifecycle_with_custom_handlers(self): - """Test complete lifecycle: create -> use -> close with custom handlers.""" - # 1. Create session via custom handler - create_response = self.client.post( - "/invocations", - json={"requestType": "NEW_SESSION"}, - ) - assert create_response.status_code == 200 - session_id = extract_session_id_from_header( - create_response.headers[SageMakerSessionHeader.NEW_SESSION_ID] - ) - assert session_id.startswith("engine-session-") + def test_session_id_injected_into_nested_path(self): + """Test that session ID is injected into nested path in request body. - # 2. Use session for inference - inference_response = self.client.post( - "/invocations", - json={"prompt": "Hello, world!", "metadata": {}}, - headers={SageMakerSessionHeader.SESSION_ID: session_id}, + Some ML engines expect the session ID to be in the request body rather than + just in headers. The request_session_id_path parameter allows automatic + injection of the session ID into a specified path in the request body + (e.g., metadata.session_id). This test validates that injection works correctly. + """ + # Create session + session_id = self.create_session() + + # Make request with session - note we don't include session_id in the body + # The framework should inject it automatically at metadata.session_id + response = self.invoke_with_session( + session_id, {"prompt": "test", "metadata": {"user": "test_user"}} ) - assert inference_response.status_code == 200 - data = json.loads(inference_response.text) + + assert response.status_code == 200 + data = json.loads(response.text) + + # Verify session ID was automatically injected into the nested path assert data["session_id"] == session_id - assert "Response to: Hello, world!" in data["generated_text"] + assert data["body"]["metadata"]["session_id"] == session_id + # Verify original metadata fields are preserved + assert data["body"]["metadata"]["user"] == "test_user" - # 3. Close session via custom handler - close_response = self.client.post( - "/invocations", - json={"requestType": "CLOSE"}, - headers={SageMakerSessionHeader.SESSION_ID: session_id}, - ) - assert close_response.status_code == 200 - assert ( - close_response.headers[SageMakerSessionHeader.CLOSED_SESSION_ID] - == session_id - ) - # Verify engine state - assert self.mock_engine.call_count["create"] == 1 - assert self.mock_engine.call_count["close"] == 1 - assert self.mock_engine.sessions[session_id]["active"] is False +class TestCustomHandlerSessionPersistence(BaseCustomHandlerIntegrationTest): + """Test that session state persists correctly across invocations with custom handlers.""" - def test_multiple_inference_calls_with_custom_session(self): - """Test multiple inference calls using custom session.""" - # Create session - create_response = self.client.post( - "/invocations", json={"requestType": "NEW_SESSION"} - ) - session_id = extract_session_id_from_header( - create_response.headers[SageMakerSessionHeader.NEW_SESSION_ID] - ) + def setup_method(self): + self.sessions = {} + super().setup_method() + + def custom_create_session(self, obj: CreateSessionRequest, request: Request): + session_id = str(uuid.uuid4()) + # Store session with initial state for ML inference + self.sessions[session_id] = { + "conversation_history": [], + "inference_params": {}, + "created_at": "2024-01-01", + } + return {"session_id": session_id} + + def custom_close_session(self, obj: CloseSessionRequest, request: Request): + if obj.session_id in self.sessions: + del self.sessions[obj.session_id] + return Response(status_code=200) + + def setup_common_handlers(self): + @sagemaker_standards.register_create_session_handler( + request_shape={"capacity_of_str_len": "`1024`"}, + response_session_id_path="body.session_id", + content_path="`successfully created session.`", + ) + @self.app.api_route("/open_session", methods=["GET", "POST"]) + async def create_session(obj: CreateSessionRequest, request: Request): + return self.custom_create_session(obj, request) + + @sagemaker_standards.register_close_session_handler( + request_shape={ + "session_id": f'headers."{SageMakerSessionHeader.SESSION_ID}"' + }, + content_path="`successfully closed session.`", + ) + @self.app.api_route("/close_session", methods=["GET", "POST"]) + async def close_session(obj: CloseSessionRequest, request: Request): + return self.custom_close_session(obj, request) + + async def custom_invocations(self, request: Request): + body_bytes = await request.body() + body = json.loads(body_bytes.decode()) + session_id = request.headers.get(SageMakerSessionHeader.SESSION_ID) + + # Simulate updating session state for ML inference + if session_id and session_id in self.sessions: + if "message" in body: + self.sessions[session_id]["conversation_history"].append( + body["message"] + ) + if "inference_params" in body: + self.sessions[session_id]["inference_params"].update( + body["inference_params"] + ) + + session_data = self.sessions.get(session_id, {}) + + return Response( + status_code=200, + content=json.dumps( + { + "session_id": session_id, + "conversation_history": session_data.get( + "conversation_history", [] + ), + "inference_params": session_data.get("inference_params", {}), + } + ), + ) + + def test_conversation_history_persists(self): + """Test that conversation history accumulates across invocations. + + This simulates a multi-turn conversation with an LLM where each message + is added to the session's conversation history. This is a common pattern + for chatbots and conversational AI where context from previous turns + needs to be maintained. + """ + session_id = self.create_session() - # Make multiple inference calls - prompts = ["First prompt", "Second prompt", "Third prompt"] - for prompt in prompts: - response = self.client.post( - "/invocations", - json={"prompt": prompt, "metadata": {}}, - headers={SageMakerSessionHeader.SESSION_ID: session_id}, - ) + # Send multiple messages in sequence (simulating a conversation) + messages = ["Hello", "How are you?", "Tell me a joke"] + for msg in messages: + # Each message is added to the session's conversation history + response = self.invoke_with_session(session_id, {"message": msg}) assert response.status_code == 200 - data = json.loads(response.text) - assert data["session_id"] == session_id - assert prompt in data["generated_text"] - # Close session - close_response = self.client.post( - "/invocations", - json={"requestType": "CLOSE"}, - headers={SageMakerSessionHeader.SESSION_ID: session_id}, - ) - assert close_response.status_code == 200 + # Make a final request to retrieve the accumulated history + final_response = self.invoke_with_session(session_id, {}) + data = json.loads(final_response.text) + # Verify all messages were stored in order + assert data["conversation_history"] == messages + + def test_inference_parameters_persist(self): + """Test that ML inference parameters are maintained across invocations. + + This validates that ML-specific inference parameters (temperature, max_tokens, top_p) + can be set incrementally and persist across the session. This is useful for: + - LLM inference where users want consistent generation parameters + - A/B testing different parameter combinations within a session + - Gradual parameter tuning based on user feedback + """ + session_id = self.create_session() + + # Set inference parameters incrementally across multiple requests + # Temperature: controls randomness in text generation (0.0 = deterministic, 1.0 = creative) + self.invoke_with_session(session_id, {"inference_params": {"temperature": 0.7}}) + # Max tokens: limits the length of generated output + self.invoke_with_session(session_id, {"inference_params": {"max_tokens": 512}}) + # Top-p (nucleus sampling): controls diversity of token selection + self.invoke_with_session(session_id, {"inference_params": {"top_p": 0.9}}) + + # Retrieve accumulated parameters + response = self.invoke_with_session(session_id, {}) + data = json.loads(response.text) + # Verify all parameters were stored and are accessible + assert data["inference_params"]["temperature"] == 0.7 + assert data["inference_params"]["max_tokens"] == 512 + assert data["inference_params"]["top_p"] == 0.9 + + def test_session_state_cleared_after_close(self): + """Test that session state is properly cleared when session is closed. + + This validates proper cleanup of session resources. When a session is closed, + all associated state (conversation history, parameters, etc.) should be + removed to prevent memory leaks and ensure data privacy. + """ + session_id = self.create_session() + + # Add some state to the session + self.invoke_with_session(session_id, {"message": "test"}) + # Verify state was stored + assert len(self.sessions[session_id]["conversation_history"]) == 1 + + # Close the session - should trigger cleanup in custom handler + self.close_session(session_id) + + # Verify session and all its state was completely removed from storage + # This is important for memory management and data privacy + assert session_id not in self.sessions diff --git a/python/tests/integration/test_sagemaker_sessions_integration.py b/python/tests/integration/test_sagemaker_sessions_integration.py index ce35bcc..0903216 100644 --- a/python/tests/integration/test_sagemaker_sessions_integration.py +++ b/python/tests/integration/test_sagemaker_sessions_integration.py @@ -636,13 +636,15 @@ def test_regular_requests_with_session_header_when_disabled( class TestSessionIdPathInjection(BaseSessionIntegrationTest): - """Test session_id_path parameter for injecting session ID into request body.""" + """Test request_session_id_path parameter for injecting session ID into request body.""" def setup_handlers(self): - """Define handlers with session_id_path parameter.""" + """Define handlers with request_session_id_path parameter.""" @self.router.post("/invocations-with-path") - @sagemaker_standards.stateful_session_manager(session_id_path="session_id") + @sagemaker_standards.stateful_session_manager( + request_session_id_path="session_id" + ) async def invocations_with_path(request: Request): """Handler that injects session ID into request body at 'session_id' key.""" body_bytes = await request.body() @@ -666,7 +668,7 @@ async def invocations_with_path(request: Request): @self.router.post("/invocations-nested-path") @sagemaker_standards.stateful_session_manager( - session_id_path="metadata.session_id" + request_session_id_path="metadata.session_id" ) async def invocations_nested_path(request: Request): """Handler that injects session ID into nested path in request body.""" diff --git a/python/tests/sagemaker/sessions/test_registration.py b/python/tests/sagemaker/sessions/test_registration.py deleted file mode 100644 index da30a1c..0000000 --- a/python/tests/sagemaker/sessions/test_registration.py +++ /dev/null @@ -1,121 +0,0 @@ -"""Unit tests for session handler registration functions.""" - -import pytest - -from model_hosting_container_standards.sagemaker.sessions import ( - register_engine_session_handler, -) - - -class TestRegisterEngineSessionHandler: - """Test register_engine_session_handler function.""" - - def test_create_session_requires_session_id_path(self): - """Test that create_session requires session_id_path parameter.""" - with pytest.raises(ValueError) as exc_info: - register_engine_session_handler( - handler_type="create_session", - request_shape={}, - session_id_path=None, - content_path="message", - ) - - assert "session_id_path is required" in str(exc_info.value) - - def test_create_session_with_valid_params(self): - """Test successful create_session registration.""" - decorator = register_engine_session_handler( - handler_type="create_session", - request_shape={"model": "body.model"}, - session_id_path="session_id", - content_path="message", - ) - - assert decorator is not None - assert callable(decorator) - - def test_close_session_without_session_id_path(self): - """Test that close_session doesn't require session_id_path.""" - decorator = register_engine_session_handler( - handler_type="close_session", - request_shape={}, - content_path="message", - ) - - assert decorator is not None - assert callable(decorator) - - def test_invalid_handler_type(self): - """Test that invalid handler_type raises ValueError.""" - with pytest.raises(ValueError) as exc_info: - register_engine_session_handler( - handler_type="invalid_type", - request_shape={}, - ) - - assert "Invalid handler_type" in str(exc_info.value) - assert "create_session" in str(exc_info.value) - assert "close_session" in str(exc_info.value) - - def test_adds_body_prefix_to_paths(self): - """Test that body. prefix is automatically added to response paths.""" - # This is tested indirectly - the decorator should work with paths - # relative to the handler's return value, not the serialized response - decorator = register_engine_session_handler( - handler_type="create_session", - request_shape={}, - session_id_path="id", # Should become body.id internally - content_path="message", # Should become body.message internally - ) - - assert decorator is not None - - def test_preserves_body_prefix_if_present(self): - """Test that existing body. prefix is not duplicated.""" - decorator = register_engine_session_handler( - handler_type="create_session", - request_shape={}, - session_id_path="body.id", # Already has body. prefix - content_path="body.message", - ) - - assert decorator is not None - - -class TestResponseShapeConstruction: - """Test that response_shape is constructed correctly.""" - - def test_create_session_response_shape_has_required_keys(self): - """Test that create_session response_shape includes session ID and content.""" - # We can't directly inspect the response_shape, but we can verify - # the decorator is created successfully with the right parameters - decorator = register_engine_session_handler( - handler_type="create_session", - request_shape={}, - session_id_path="session.id", - content_path="session.message", - ) - - # If this doesn't raise, the response_shape was constructed correctly - assert decorator is not None - - def test_close_session_response_shape_has_content_key(self): - """Test that close_session response_shape includes content.""" - decorator = register_engine_session_handler( - handler_type="close_session", - request_shape={}, - content_path="result.message", - ) - - assert decorator is not None - - def test_none_content_path_is_handled(self): - """Test that None content_path is handled correctly.""" - decorator = register_engine_session_handler( - handler_type="close_session", - request_shape={}, - content_path=None, - ) - - # Should still create decorator, content extraction will just return None - assert decorator is not None From ecb6184852a8409c69f1ed50910f1b4a9556c15f Mon Sep 17 00:00:00 2001 From: Zuyi Zhao Date: Sat, 6 Dec 2025 00:23:35 +0000 Subject: [PATCH 19/25] Update docs. --- .../sagemaker/sessions/CUSTOM_HANDLERS.md | 346 +++++++++++++++--- .../sagemaker/sessions/README.md | 77 +++- 2 files changed, 362 insertions(+), 61 deletions(-) diff --git a/python/model_hosting_container_standards/sagemaker/sessions/CUSTOM_HANDLERS.md b/python/model_hosting_container_standards/sagemaker/sessions/CUSTOM_HANDLERS.md index fa65de3..3aaf546 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/CUSTOM_HANDLERS.md +++ b/python/model_hosting_container_standards/sagemaker/sessions/CUSTOM_HANDLERS.md @@ -17,90 +17,348 @@ Use custom handlers when: ### Architecture ``` -Client Request +Client Request (NEW_SESSION or CLOSE) ↓ SessionApiTransform (detects session request) ↓ -get_handler_for_request_type() +Handler Registry Check ↓ ├─→ Custom Handler (if registered) - │ └─→ Engine's Session API + │ └─→ Your Engine's Session API │ └─→ Default Handler (if not registered) └─→ SageMaker SessionManager ``` -## Handler Signatures +## Registration API -Both handlers must be async functions that accept a FastAPI `Request` object: +Use the `@register_create_session_handler` and `@register_close_session_handler` decorators to register custom handlers: ```python -from fastapi import Request, Response +from fastapi import FastAPI, Request +from pydantic import BaseModel +from model_hosting_container_standards.sagemaker import ( + register_create_session_handler, + register_close_session_handler, + stateful_session_manager, + bootstrap +) -async def my_create_session_handler(raw_request: Request) -> Response: - """Create a new session via the engine's API.""" - pass +app = FastAPI() + +# Define your engine's request/response models +class CreateSessionRequest(BaseModel): + capacity: int + +class CreateSessionResponse(BaseModel): + session_id: str + message: str + +# Register custom create session handler +@register_create_session_handler( + request_shape={ + "capacity": "`1024`" # JMESPath literal value + }, + response_session_id_path="body.session_id", # Extract session ID from response + content_path="body.message" # Extract content for logging +) +@app.post("/engine/create_session") +async def create_session(obj: CreateSessionRequest, request: Request): + # Call your engine's session creation API + session_id = await my_engine.create_session(capacity=obj.capacity) + return CreateSessionResponse(session_id=session_id, message="Session created") + +# Register custom close session handler +@register_close_session_handler( + request_shape={ + "session_id": 'headers."X-Amzn-SageMaker-Session-Id"' # Extract from header + }, + content_path="`Session closed successfully`" # Static message +) +@app.post("/engine/close_session") +async def close_session(session_id: str, request: Request): + # Call your engine's session closure API + await my_engine.close_session(session_id) + return {"status": "closed"} -async def my_close_session_handler(raw_request: Request) -> Response: - """Close an existing session via the engine's API.""" +# Your main invocations endpoint +@app.post("/invocations") +@stateful_session_manager() +async def invocations(request: Request): + # Handle regular inference requests pass + +bootstrap(app) ``` -## Using Transform Classes +## Decorator Parameters +### `@register_create_session_handler` ```python -from model_hosting_container_standards.sagemaker.sessions.transforms import ( - CreateSessionApiTransform, - CloseSessionApiTransform +@register_create_session_handler( + request_shape: dict, # Required: JMESPath mappings for request transformation + response_session_id_path: str, # Required: JMESPath to extract session ID from response + content_path: str = None # Optional: JMESPath to extract content for logging ) +``` -# Define request/response shapes using JMESPath -create_transform = CreateSessionApiTransform( - request_shape={}, # Transform incoming request - response_shape={ - "X-Amzn-SageMaker-New-Session-Id": "session_id", - "content": "message" - } +- **`request_shape`**: Maps target keys to source JMESPath expressions. Transforms the incoming SageMaker request into your engine's expected format. +- **`response_session_id_path`**: JMESPath expression to extract the session ID from your engine's response. This is **required** because the framework needs to return the session ID in the response header. +- **`content_path`**: Optional JMESPath expression to extract a message for logging. Defaults to a generic success message. + +### `@register_close_session_handler` + +```python +@register_close_session_handler( + request_shape: dict, # Required: JMESPath mappings for request transformation + content_path: str = None # Optional: JMESPath to extract content for logging +) +``` + +- **`request_shape`**: Maps target keys to source JMESPath expressions. Typically extracts the session ID from the request header. +- **`content_path`**: Optional JMESPath expression to extract a message for logging. Defaults to a generic success message. + +**Note**: `response_session_id_path` is not needed for close handlers because the session ID comes from the request header, not the response. + +## How It Works + +When you register custom handlers: + +1. **Client sends session request** to `/invocations` with `{"requestType": "NEW_SESSION"}` or `{"requestType": "CLOSE"}` +2. **SessionApiTransform intercepts** the request and checks the handler registry +3. **If custom handler registered**: Request is routed to your custom endpoint (e.g., `/engine/create_session`) +4. **Transform applies**: Request/response shapes are transformed using JMESPath +5. **Response returned**: With appropriate SageMaker session headers (`X-Amzn-SageMaker-New-Session-Id` or `X-Amzn-SageMaker-Closed-Session-Id`) + +The key benefit: Your `/invocations` endpoint stays clean, and session management is handled transparently. + +## JMESPath Expressions + +The `request_shape` and `response_shape` parameters use JMESPath expressions to transform data: + +### Request Shape + +Maps target keys to source expressions: + +```python +request_shape={ + "capacity": "`1024`", # Literal value + "session_id": 'headers."X-Amzn-SageMaker-Session-Id"', # From header + "user_id": "body.metadata.user" # From request body +} +``` + +### Response Shape + +For **create session**, you must specify: +- `response_session_id_path`: Where to extract the session ID from the engine's response +- `content_path`: Where to extract content for logging (optional) + +```python +response_session_id_path="body.session_id" # Extract from {"session_id": "..."} +response_session_id_path="body" # If response is just the session ID string +content_path="body.message" # Extract message from response +content_path="`Session created`" # Use literal string +``` + +For **close session**, you only need: +- `content_path`: Where to extract content for logging (optional) + +## Response Formats + +Your custom handlers can return different response formats: + +### Dictionary Response +```python +async def create_session(obj: CreateSessionRequest, request: Request): + session_id = str(uuid.uuid4()) + return {"session_id": session_id, "metadata": {"engine": "custom"}} +``` + +### String Response +```python +async def create_session(obj: CreateSessionRequest, request: Request): + session_id = str(uuid.uuid4()) + return session_id # Just return the session ID +``` + +### FastAPI Response Object +```python +from fastapi import Response + +async def create_session(obj: CreateSessionRequest, request: Request): + session_id = str(uuid.uuid4()) + return Response( + status_code=201, + content=json.dumps({"session_id": session_id}), + media_type="application/json" + ) +``` + +## Error Handling + +Raise `HTTPException` to return errors to the client: + +```python +from fastapi.exceptions import HTTPException + +@register_create_session_handler(...) +async def create_session(obj: CreateSessionRequest, request: Request): + try: + session_id = await my_engine.create_session() + return {"session_id": session_id} + except EngineError as e: + raise HTTPException(status_code=500, detail=f"Engine error: {e}") +``` + +## Complete Example + +Here's a complete example with error handling and session tracking: + +```python +from fastapi import FastAPI, Request, HTTPException, Response +from fastapi.responses import JSONResponse +from pydantic import BaseModel +from typing import Optional +import uuid +import json + +from model_hosting_container_standards.sagemaker import ( + register_create_session_handler, + register_close_session_handler, + stateful_session_manager, + bootstrap +) +from model_hosting_container_standards.sagemaker.sessions.models import ( + SageMakerSessionHeader ) -close_transform = CloseSessionApiTransform( +app = FastAPI() + +# Track sessions in memory (for demo purposes) +active_sessions = {} + +class CreateSessionRequest(BaseModel): + capacity: int + session_id: Optional[str] = None + +@register_create_session_handler( request_shape={ - "session_id": "headers.'X-Amzn-SageMaker-Session-Id'" + "capacity": "`1024`", + "session_id": f'headers."{SageMakerSessionHeader.SESSION_ID}"' }, - response_shape={ - "content": "message" + response_session_id_path="body.session_id", + content_path="body.message" +) +@app.post("/engine/create_session") +async def create_session(obj: CreateSessionRequest, request: Request): + # Generate or use provided session ID + session_id = obj.session_id or str(uuid.uuid4()) + + # Check if session already exists + if session_id in active_sessions: + raise HTTPException(status_code=400, detail="Session already exists") + + # Create session in your engine + active_sessions[session_id] = {"capacity": obj.capacity} + + return { + "session_id": session_id, + "message": f"Session created with capacity {obj.capacity}" } + +@register_close_session_handler( + request_shape={"session_id": f'headers."{SageMakerSessionHeader.SESSION_ID}"'}, + content_path="`Session closed successfully`" ) +@app.post("/engine/close_session") +async def close_session(session_id: str, request: Request): + if session_id not in active_sessions: + raise HTTPException(status_code=404, detail="Session not found") + + # Close session in your engine + del active_sessions[session_id] + + return Response(status_code=200, content="Session closed") + +@app.post("/invocations") +@stateful_session_manager(request_session_id_path="session_id") +async def invocations(request: Request): + body_bytes = await request.body() + body = json.loads(body_bytes.decode()) + session_id = body.get("session_id") + + if session_id and session_id not in active_sessions: + raise HTTPException(status_code=400, detail="Invalid session") + + # Process inference request with session context + return JSONResponse({ + "result": "success", + "session_id": session_id or "no-session", + "echo": body + }) + +bootstrap(app) ``` +## Session Validation Behavior + +When custom handlers are registered, the framework **does not** validate session IDs against the default `SessionManager`. This means: + +- **With custom handlers**: Session validation is your responsibility. The framework only routes requests to your handlers. +- **Without custom handlers** (default mode): The framework validates session IDs against the `SessionManager` automatically. + +This design allows your engine to manage sessions independently without interference from the default session manager. + ## Best Practices -1. **Validate session IDs**: Always validate that the engine returns valid session IDs -2. **Handle timeouts**: Set appropriate timeouts when calling engine APIs +1. **Validate session IDs**: Check that the engine returns valid session IDs in create handlers +2. **Handle errors gracefully**: Use HTTPException for clear error messages 3. **Log operations**: Log session creation/closure for debugging -4. **Error propagation**: Provide clear error messages when engine operations fail -5. **Cleanup**: Ensure sessions are properly cleaned up even on errors -6. **Testing**: Test both success and failure scenarios -7. **Idempotency**: Handle duplicate close requests gracefully +4. **Test thoroughly**: Test both success and failure scenarios +5. **Idempotency**: Handle duplicate close requests gracefully (return 404 or succeed silently) +6. **Session isolation**: Ensure different sessions maintain independent state +7. **Thread safety**: If your engine stores session state, ensure thread-safe access for concurrent requests -## Utilities +## Troubleshooting -The framework provides utility functions for working with sessions: +### Session ID not extracted from response +**Problem**: Getting "Engine failed to return a valid session ID" error. + +**Solution**: Check that your `response_session_id_path` matches your response structure: ```python -from model_hosting_container_standards.sagemaker.sessions.utils import ( - get_session_id_from_request, # Extract session ID from headers - get_session, # Get session from manager -) -from model_hosting_container_standards.sagemaker.sessions.models import ( - SageMakerSessionHeader, # Header name constants - SessionRequestType, # Request type enum -) +# If your handler returns: {"session_id": "abc123"} +response_session_id_path="body.session_id" + +# If your handler returns: "abc123" +response_session_id_path="body" +``` + +### Request not reaching custom handler + +**Problem**: Custom handler not being called. + +**Solution**: Ensure you call `bootstrap(app)` **after** registering your handlers: +```python +@register_create_session_handler(...) +async def create_session(...): + pass + +bootstrap(app) # Must be after handler registration +``` + +### Session header not found in close handler + +**Problem**: Getting "Session ID is required in request headers" error. + +**Solution**: Ensure your `request_shape` extracts the session ID from the header: +```python +request_shape={"session_id": 'headers."X-Amzn-SageMaker-Session-Id"'} ``` ## See Also - [README.md](./README.md) - Main sessions documentation -- [handlers.py](./handlers.py) - Default handler implementations -- [transforms/](./transforms/) - Transform classes for engine integration +- [Integration tests](../../../tests/integration/test_custom_session_handlers_integration.py) - Complete working examples diff --git a/python/model_hosting_container_standards/sagemaker/sessions/README.md b/python/model_hosting_container_standards/sagemaker/sessions/README.md index 3896dd9..c067e08 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/README.md +++ b/python/model_hosting_container_standards/sagemaker/sessions/README.md @@ -41,37 +41,53 @@ The framework supports two modes of session management: ## Architecture ``` -SessionApiTransform (transform.py) +Client Request to /invocations ↓ - ├─→ Session Management Request - │ ├─→ create_session (handlers.py) - │ └─→ close_session (handlers.py) +SessionApiTransform (intercepts and inspects) + ↓ + ├─→ Session Management Request (NEW_SESSION or CLOSE) + │ ├─→ Check Handler Registry + │ │ ├─→ Custom Handler (if registered) + │ │ │ └─→ Your engine's session API + │ │ └─→ Default Handler (if not registered) + │ │ └─→ SageMaker SessionManager + │ └─→ Return with session headers │ └─→ Regular Inference Request - └─→ Pass through with session context + ├─→ Validate session ID (if present) + ├─→ Inject session ID into body (if configured) + └─→ Pass to your handler ``` ### Key Components -- **`SessionManager`** (`manager.py`): Manages session lifecycle, expiration, and cleanup +- **`SessionManager`** (`manager.py`): Manages session lifecycle, expiration, and cleanup (default mode) - **`Session`** (`manager.py`): Individual session with file-based key-value storage -- **`SessionApiTransform`** (`transform.py`): API transform that intercepts session requests -- **Session Handlers** (`handlers.py`): Functions to create and close sessions +- **`SessionApiTransform`** (`transform.py`): API transform that intercepts and routes session requests +- **Handler Registry**: Routes session requests to custom or default handlers +- **Session Handlers** (`handlers.py`): Default functions to create and close sessions +- **Engine Session Transforms** (`transforms/`): Transform classes for custom engine integration - **Utilities** (`utils.py`): Helper functions for session ID extraction and retrieval ## Quick Start ### Enabling Sessions in Your Handler -Use the `stateful_session_manager()` convenience decorator: +Use the `stateful_session_manager()` decorator on your `/invocations` endpoint: ```python -from model_hosting_container_standards.sagemaker import stateful_session_manager +from fastapi import FastAPI, Request +from model_hosting_container_standards.sagemaker import stateful_session_manager, bootstrap + +app = FastAPI() +@app.post("/invocations") @stateful_session_manager() -def my_handler(request): +async def invocations(request: Request): # Handler logic with session support pass + +bootstrap(app) ``` ### Creating a Session @@ -118,15 +134,18 @@ X-Amzn-SageMaker-Closed-Session-Id: ## Configuration -Configure via `SessionManager` properties: +Configure via environment variables: -```python -session_manager = SessionManager({ - "sessions_expiration": "1200", # TTL in seconds (default: 1200) - "sessions_path": "/dev/shm/sagemaker_sessions" # Storage path -}) +```bash +export SAGEMAKER_ENABLE_STATEFUL_SESSIONS=true +export SAGEMAKER_SESSIONS_EXPIRATION=1200 # TTL in seconds (default: 1200) +export SAGEMAKER_SESSIONS_PATH=/dev/shm/sagemaker_sessions # Storage path (optional) ``` +The session manager is automatically initialized from these environment variables when you call `bootstrap(app)`. + +**Important**: If `SAGEMAKER_ENABLE_STATEFUL_SESSIONS` is not set to `true`, session management requests will fail with a 400 error. Regular inference requests without session headers will continue to work normally. + ### Storage Location Sessions are stored in memory-backed filesystem when available: @@ -158,6 +177,30 @@ Each session maintains its own directory with JSON files for key-value pairs: ## Advanced Usage +### Injecting Session ID into Request Body + +If your handler needs the session ID in the request body (not just headers), use the `request_session_id_path` parameter: + +```python +@app.post("/invocations") +@stateful_session_manager(request_session_id_path="session_id") +async def invocations(request: Request): + body = await request.json() + session_id = body.get("session_id") # Automatically injected from header + # Handler logic +``` + +For nested paths, use dot notation: + +```python +@stateful_session_manager(request_session_id_path="metadata.session_id") +async def invocations(request: Request): + body = await request.json() + session_id = body["metadata"]["session_id"] # Injected at nested path +``` + +**Note**: The session ID is only injected when the `X-Amzn-SageMaker-Session-Id` header is present in the request. + ### Custom Session Handlers If your inference engine has its own session management API, you can register custom handlers to delegate session creation and closure to the engine instead of using SageMaker's built-in session management. From ec0ea84092a61a6acb3c6b9710ee3fccd373ffcc Mon Sep 17 00:00:00 2001 From: Zuyi Zhao Date: Tue, 9 Dec 2025 20:44:54 +0000 Subject: [PATCH 20/25] refactor(sagemaker/sessions): simplify custom session handler registration API Improve the developer experience for registering custom session handlers by introducing a cleaner API with better parameter naming and adding a build_session_request_shape helper to consolidate request shape logic. This makes it easier to understand how session IDs are injected and provides better conflict detection through logging. --- docs/INTEGRATION_RUNBOOK.md | 180 +++++++++--------- .../sagemaker/__init__.py | 68 ++++++- .../sagemaker/sessions/CUSTOM_HANDLERS.md | 14 +- .../sagemaker/sessions/__init__.py | 39 +++- ...est_custom_session_handlers_integration.py | 47 ++--- python/tests/sagemaker/sessions/test_init.py | 149 +++++++++++++++ 6 files changed, 368 insertions(+), 129 deletions(-) create mode 100644 python/tests/sagemaker/sessions/test_init.py diff --git a/docs/INTEGRATION_RUNBOOK.md b/docs/INTEGRATION_RUNBOOK.md index 60fe52d..bbaf926 100644 --- a/docs/INTEGRATION_RUNBOOK.md +++ b/docs/INTEGRATION_RUNBOOK.md @@ -1,7 +1,7 @@ # MHCS Integration Runbook **Version**: 1.0 container -**Last Updated**: November 16, 2025 +**Last Updated**: November 16, 2025 **Target Audience**: ML framework developers integrating with Amazon SageMaker --- @@ -94,7 +94,7 @@ ## 1. Introduction -### 1.1 What is MHCS? +### 1.1 What is MHCS? Model Hosting Container Standards (MHCS) is a Python library that acts as a bridge between model hosting platforms and ML inference engines with rapidly evolving APIs. It standardizes how ML frameworks integrate with hosting platforms while maintaining backwards compatibility and adapting to changing engine interfaces. @@ -276,10 +276,10 @@ async def invocations(request: Request) -> dict: """Model inference endpoint for SageMaker.""" body = await request.json() prompt = body.get("prompt", "") - + # Your framework's inference logic here result = f"Processed: {prompt}" - + return {"predictions": [result]} # Bootstrap MHCS - must be called after handler definitions @@ -354,10 +354,10 @@ async def invocations(request: Request) -> dict: body = await request.json() prompt = body.get("prompt", "") adapter_id = body.get("model", "base-model") # Injected by decorator - + # Your framework's inference logic with adapter result = f"[{adapter_id}] Processed: {prompt}" - + return {"predictions": [result], "adapter_used": adapter_id} @sagemaker_standards.register_load_adapter_handler( @@ -369,10 +369,10 @@ async def load_adapter(request: Request) -> dict: body = await request.json() adapter_name = body["adapter_name"] adapter_path = body.get("adapter_path", "") - + # Your framework's adapter loading logic loaded_adapters[adapter_name] = {"path": adapter_path, "loaded": True} - + return {"status": "success", "adapter_name": adapter_name} @sagemaker_standards.register_unload_adapter_handler( @@ -382,12 +382,12 @@ async def load_adapter(request: Request) -> dict: async def unload_adapter(request: Request) -> dict: """Unload a LoRA adapter.""" adapter_name = request.path_params.get("adapter_name") - + # Your framework's adapter unloading logic if adapter_name in loaded_adapters: del loaded_adapters[adapter_name] return {"status": "success", "adapter_name": adapter_name} - + return {"status": "not_found", "adapter_name": adapter_name} sagemaker_standards.bootstrap(app) @@ -399,7 +399,7 @@ if __name__ == "__main__": **How it works**: - The LoRA decorators use the transform decorator system under the hood (see [Section 4: Transform Decorators](#4-transform-decorators) for details). -- `@inject_adapter_id("model")` - Extracts adapter ID from `X-Amzn-SageMaker-Adapter-Identifier` header and injects it into the `model` field of the request body. +- `@inject_adapter_id("model")` - Extracts adapter ID from `X-Amzn-SageMaker-Adapter-Identifier` header and injects it into the `model` field of the request body. - `@register_load_adapter_handler` - Creates `POST /adapters` endpoint for loading adapters - `@register_unload_adapter_handler` - Creates `DELETE /adapters/{adapter_name}` endpoint for unloading adapters @@ -471,16 +471,16 @@ async def invocations(request: Request) -> dict: """Inference with session management.""" body = await request.json() prompt = body.get("prompt", "") - + # Access session data if available session_id = request.headers.get("X-Amzn-SageMaker-Session-Id") - + # Your framework's inference logic with session context if session_id: result = f"[Session {session_id}] Processed: {prompt}" else: result = f"Processed: {prompt}" - + return {"predictions": [result]} sagemaker_standards.bootstrap(app) @@ -726,10 +726,10 @@ The `bootstrap(app)` function is the central integration point that connects you ```python def bootstrap(app: FastAPI) -> FastAPI: """Configure a FastAPI application with SageMaker functionality. - + Args: app: The FastAPI application instance to configure - + Returns: The configured FastAPI app """ @@ -832,7 +832,7 @@ sequenceDiagram participant SageMaker Router participant Handler Registry participant Your Handler - + Client->>FastAPI App: GET /ping FastAPI App->>Middleware: Process request Middleware->>SageMaker Router: Route to /ping @@ -893,10 +893,10 @@ from fastapi import Request, Response @sagemaker_standards.register_ping_handler async def ping(request: Request) -> Response: """Health check handler for SageMaker. - + Args: request: FastAPI Request object containing headers, body, etc. - + Returns: Response: FastAPI Response object with status code and content """ @@ -912,19 +912,19 @@ from typing import Dict, Any @sagemaker_standards.register_invocation_handler async def invocations(request: Request) -> Dict[str, Any]: """Model inference handler for SageMaker. - + Args: request: FastAPI Request object containing the inference request - + Returns: Dict: JSON-serializable dictionary with predictions """ body = await request.json() prompt = body.get("prompt", "") - + # Your framework's inference logic result = your_model.generate(prompt) - + return {"predictions": [result]} ``` @@ -948,11 +948,11 @@ sequenceDiagram participant MHCS Router participant Handler Registry participant Your Handler - + Note over Client,Your Handler: Registration Phase (at startup) Your Handler->>Handler Registry: @register_ping_handler decorator Handler Registry->>Handler Registry: Store handler as "ping" type - + Note over Client,Your Handler: Request Phase (at runtime) Client->>FastAPI: GET /ping FastAPI->>MHCS Router: Route to /ping endpoint @@ -1028,13 +1028,13 @@ graph TD F -->|No| H{Register Decorator?} H -->|Yes| I[Use Framework Default Handler] H -->|No| J[No Handler Found] - + C --> K[Handler Resolved] E --> K G --> K I --> K J --> L[Skip Route Creation] - + style C fill:#ff6b6b style E fill:#ffa500 style G fill:#ffd93d @@ -1271,12 +1271,12 @@ This step-by-step checklist guides you through a complete MHCS integration. Foll @sagemaker_standards.register_ping_handler async def ping(request: Request) -> Response: return Response(status_code=200, content="Healthy") - + @sagemaker_standards.register_invocation_handler async def invocations(request: Request) -> dict: # Your inference logic here ... - + # Call bootstrap() last sagemaker_standards.bootstrap(app) ``` @@ -1364,13 +1364,13 @@ If your framework supports LoRA adapters, add these handlers: curl -X POST http://localhost:8000/adapters \ -H "Content-Type: application/json" \ -d '{"name": "my-adapter", "src": "/path/to/adapter"}' - + # Use adapter curl -X POST http://localhost:8000/invocations \ -H "Content-Type: application/json" \ -H "X-Amzn-SageMaker-Adapter-Identifier: my-adapter" \ -d '{"prompt": "test"}' - + # Unload adapter curl -X DELETE http://localhost:8000/adapters/my-adapter ``` @@ -1411,13 +1411,13 @@ If your framework needs stateful sessions: curl -X POST http://localhost:8000/invocations \ -H "Content-Type: application/json" \ -d '{"requestType": "NEW_SESSION"}' - + # Use session (replace with actual ID) curl -X POST http://localhost:8000/invocations \ -H "Content-Type: application/json" \ -H "X-Amzn-SageMaker-Session-Id: " \ -d '{"prompt": "test"}' - + # Close session curl -X POST http://localhost:8000/invocations \ -H "Content-Type: application/json" \ @@ -1568,15 +1568,15 @@ graph TD B -->|used by| C[create_transform_decorator Factory] C -->|creates| D[Decorator Functions] D -->|applied to| E[Your Handler Functions] - + B1[RegisterLoRAApiTransform] -.->|example| B B2[InjectToBodyApiTransform] -.->|example| B B3[SessionApiTransform] -.->|example| B - + D1[inject_adapter_id decorator] -.->|example| D D2[register_load_adapter_handler decorator] -.->|example| D D3[stateful_session_manager decorator] -.->|example| D - + style A fill:#e1f5ff style C fill:#fff4e1 style D fill:#e8f5e9 @@ -1592,15 +1592,15 @@ class BaseApiTransform: # Compiles JMESPath expressions for efficient execution self._request_shape = _compile_jmespath_expressions(request_shape) self._response_shape = _compile_jmespath_expressions(response_shape) - + def _transform(self, source_data: Dict, target_shape: Dict) -> Dict: # Applies JMESPath expressions to extract and transform data pass - + async def transform_request(self, raw_request: Request): # Subclasses implement specific request transformation logic raise NotImplementedError() - + def transform_response(self, response: Response, transform_request_output): # Subclasses implement specific response transformation logic raise NotImplementedError() @@ -1624,26 +1624,26 @@ A factory function that creates decorators dynamically: ```python def create_transform_decorator(handler_type: str, transform_resolver: Callable): """Creates a decorator factory for a specific handler type.""" - + def decorator_with_params(request_shape: Dict = None, response_shape: Dict = None): """Configures the transformation shapes.""" - + def decorator(func: Callable): """The actual decorator that wraps your handler.""" # Resolves the appropriate transform class - transformer = _resolve_transforms(handler_type, transform_resolver, + transformer = _resolve_transforms(handler_type, transform_resolver, request_shape, response_shape) - + async def decorated_func(raw_request: Request): # Apply request transformation transform_output = await transformer.transform_request(raw_request) - + # Call your handler with transformed data response = await transformer.intercept(func, transform_output) - + # Apply response transformation return transformer.transform_response(response, transform_output) - + return decorated_func return decorator return decorator_with_params @@ -1705,7 +1705,7 @@ sequenceDiagram participant FastAPI participant Transform participant Handler - + Client->>FastAPI: POST /invocations
Header: X-Amzn-SageMaker-Adapter-Identifier: my-adapter
Body: {"prompt": "..."} FastAPI->>Transform: Raw Request Transform->>Transform: Extract adapter ID from header @@ -1738,7 +1738,7 @@ sequenceDiagram } ``` -4. **Transform vs Passthrough**: +4. **Transform vs Passthrough**: - Pass `request_shape=None` for no transformation (passthrough mode) - Pass `request_shape={}` for transform infrastructure without JMESPath - Pass `request_shape={...}` for full transformation @@ -1937,7 +1937,7 @@ def create_transform_decorator( Args: handler_type: Identifier for the handler (e.g., 'register_adapter') transform_resolver: Function that maps handler_type to transform class - + Returns: Decorator factory that accepts request_shape and response_shape """ @@ -1957,15 +1957,15 @@ def decorator(func): async def wrapped_func(raw_request: Request): # 1. Transform request transform_output = await transformer.transform_request(raw_request) - + # 2. Call your handler response = await transformer.intercept(func, transform_output) - + # 3. Transform response final_response = transformer.transform_response(response, transform_output) - + return final_response - + return wrapped_func ``` @@ -1997,7 +1997,7 @@ def get_sagemaker_route_config(handler_type: str) -> Optional[RouteConfig]: return RouteConfig(path="/ping", method="GET", ...) elif handler_type == "invoke": return RouteConfig(path="/invocations", method="POST", ...) - + # Delegate to LoRA routes for adapter handlers return get_lora_route_config(handler_type) ``` @@ -2056,28 +2056,28 @@ from model_hosting_container_standards.common import BaseApiTransform, BaseTrans class MyCustomTransform(BaseApiTransform): """Custom transform for my specific use case.""" - + def __init__(self, request_shape, response_shape={}): """Initialize with request and response shapes. - + Args: request_shape: JMESPath expressions for extracting request data response_shape: JMESPath expressions for transforming responses """ super().__init__(request_shape, response_shape) # Add any custom initialization here - + async def transform_request(self, raw_request: Request): """Transform incoming request. - + This method is called before your handler executes. Extract and validate data, then return a BaseTransformRequestOutput. """ raise NotImplementedError() - + def transform_response(self, response: Response, transform_request_output): """Transform outgoing response. - + This method is called after your handler executes. Modify the response based on the request transformation output. """ @@ -2117,7 +2117,7 @@ class MyCustomTransform(BaseApiTransform): status_code=HTTPStatus.BAD_REQUEST.value, detail=f"JSON decode error: {e}", ) - + # Step 2: Validate using Pydantic model try: validated_request = MyRequestModel.model_validate(request_data) @@ -2126,10 +2126,10 @@ class MyCustomTransform(BaseApiTransform): status_code=HTTPStatus.BAD_REQUEST.value, detail=e.json(include_url=False), ) - + # Step 3: Apply JMESPath transformations (if request_shape provided) transformed_data = self._transform_request(validated_request, raw_request) - + # Step 4: Return BaseTransformRequestOutput return BaseTransformRequestOutput( request=transformed_data, # Transformed data passed to handler @@ -2163,24 +2163,24 @@ class MyCustomTransform(BaseApiTransform): """Transform the response based on request processing.""" # Option 1: Simple passthrough (no transformation) return response - + # Option 2: Route based on status code if response.status_code == HTTPStatus.OK: return self._transform_ok_response(response, transform_request_output) else: return self._transform_error_response(response, transform_request_output) - + def _transform_ok_response(self, response: Response, transform_request_output): """Transform successful responses.""" # Extract data from request transformation adapter_name = transform_request_output.request.get("adapter_name") - + # Create custom response return Response( status_code=HTTPStatus.OK.value, content=f"Success: Processed {adapter_name}", ) - + def _transform_error_response(self, response: Response, transform_request_output): """Transform error responses.""" # Pass through or customize error responses @@ -2249,7 +2249,7 @@ from fastapi import Request ) async def my_handler(transformed_data, raw_request: Request): """Handler receives transformed data as first argument. - + Args: transformed_data: SimpleNamespace with attributes from request_shape raw_request: Original FastAPI Request for additional context @@ -2258,10 +2258,10 @@ async def my_handler(transformed_data, raw_request: Request): adapter_name = transformed_data.adapter_name adapter_path = transformed_data.adapter_path custom_header = transformed_data.custom_header - + # Your handler logic here result = f"Processing {adapter_name} from {adapter_path}" - + return {"status": "success", "message": result} ``` @@ -2295,13 +2295,13 @@ If your custom transform needs its own HTTP endpoint (not just transforming exis ```python def get_sagemaker_route_config(handler_type: str) -> Optional[RouteConfig]: """Map handler types to their route configurations.""" - + # Core SageMaker routes if handler_type == "ping": return RouteConfig(path="/ping", method="GET", ...) elif handler_type == "invoke": return RouteConfig(path="/invocations", method="POST", ...) - + # Your custom route elif handler_type == "my_custom_operation": return RouteConfig( @@ -2310,7 +2310,7 @@ def get_sagemaker_route_config(handler_type: str) -> Optional[RouteConfig]: response_model=None, status_code=200 ) - + # Delegate to LoRA routes for adapter handlers return get_lora_route_config(handler_type) ``` @@ -2418,14 +2418,14 @@ async def invocations(request: Request): body = await request.json() adapter_id = body.get("model") # Automatically injected from header # Your inference logic with adapter_id - + # 2. Implement adapter loading @register_load_adapter_handler( request_shape={"adapter_name": "body.name", "adapter_path": "body.src"} ) async def load_adapter(data, request): # Your framework's adapter loading logic - + # 3. Implement adapter unloading @register_unload_adapter_handler( request_shape={"adapter_name": "path_params.adapter_name"} @@ -2515,7 +2515,7 @@ import model_hosting_container_standards.sagemaker as sagemaker_standards async def invocations(request: Request) -> dict: body = await request.json() adapter_id = body.get("model") # Adapter ID is now in body["model"] - + # Your framework's inference logic result = f"Using adapter: {adapter_id}" return {"predictions": [result]} @@ -2557,10 +2557,10 @@ Append mode concatenates the adapter ID to an existing value using a separator. async def invocations(request: Request) -> dict: body = await request.json() model_with_adapter = body.get("model") # "base-model:my-adapter" - + # Parse the composite identifier base_model, adapter_id = model_with_adapter.split(":", 1) - + return {"predictions": [f"Base: {base_model}, Adapter: {adapter_id}"]} ``` @@ -2570,7 +2570,7 @@ async def invocations(request: Request) -> dict: # Incoming request body: {"prompt": "Hello", "model": "base-model"} -# After @inject_adapter_id("model", append=True, separator=":") +# After @inject_adapter_id("model", append=True, separator=":") # with header X-Amzn-SageMaker-Adapter-Identifier: my-adapter # Transformed request body: {"prompt": "Hello", "model": "base-model:my-adapter"} @@ -2832,11 +2832,11 @@ class MyFramework: def load_adapter(self, name: str, path: str, **kwargs): """Load adapter from path with given name.""" pass - + def unload_adapter(self, name: str): """Unload adapter by name.""" pass - + def has_adapter(self, name: str) -> bool: """Check if adapter is loaded.""" pass @@ -2962,7 +2962,7 @@ graph LR D -->|Response| A A -->|CLOSE + Session ID| E[Close Session] E -->|Delete Data| F[Session Removed] - + C -->|TTL Expired| G[Auto Cleanup] G -->|Delete Data| F ``` @@ -3302,13 +3302,13 @@ async def invocations(request: Request) -> dict: """Inference handler with session support.""" body = await request.json() prompt = body.get("prompt", "") - + # Access session ID if present session_id = request.headers.get("X-Amzn-SageMaker-Session-Id") - + # Your inference logic here result = f"Processed: {prompt}" - + return {"predictions": [result]} ``` @@ -3333,7 +3333,7 @@ The decorator order follows Python's decorator application rules: decorators are **What the Decorator Does:** 1. **Intercepts Session Requests**: Detects `requestType` field in request body -2. **Routes to Default Session Handlers**: +2. **Routes to Default Session Handlers**: - `NEW_SESSION` → `create_session()` handler - `CLOSE` → `close_session()` handler 3. **Validates Session IDs**: Checks session existence and expiration @@ -3549,15 +3549,15 @@ def get_session(self, session_id: str) -> Optional[Session]: with self._lock: if session_id not in self.sessions: raise ValueError(f"session not found: {session_id}") - + session = self.sessions[session_id] - + # Check expiration if session.expiration_ts is not None and time.time() > session.expiration_ts: logging.info(f"Session expired: {session_id}") self.close_session(session_id) # Automatic cleanup return None - + return session ``` @@ -3896,7 +3896,7 @@ graph TD F -->|No| H{Register decorator?} H -->|Yes| I[Use register decorator handler] H -->|No| J[Error: No handler found] - + C --> K[Create route with handler] E --> K G --> K diff --git a/python/model_hosting_container_standards/sagemaker/__init__.py b/python/model_hosting_container_standards/sagemaker/__init__.py index ac861e4..25a194a 100644 --- a/python/model_hosting_container_standards/sagemaker/__init__.py +++ b/python/model_hosting_container_standards/sagemaker/__init__.py @@ -22,6 +22,7 @@ from .sagemaker_loader import SageMakerFunctionLoader from .sagemaker_router import create_sagemaker_router from .sessions import ( + build_session_request_shape, create_session_transform_decorator, register_engine_session_handler, ) @@ -147,21 +148,80 @@ def stateful_session_manager(request_session_id_path: Optional[str] = None): def register_create_session_handler( - request_shape, response_session_id_path: str, content_path: Optional[str] = None + request_session_id_path: str, + response_session_id_path: str, + additional_request_shape: Optional[Dict[str, str]] = None, + content_path: str = "`successfully created session.`", ): + """Register a handler for session creation with custom request/response transformations. + + This decorator creates a session handler that transforms incoming requests to include + the session ID and extracts the session ID from the engine's response. + + Args: + request_session_id_path: JMESPath target where the session ID should be injected + into the request body sent to the engine. + response_session_id_path: JMESPath expression to extract the session ID from + the engine's response body. + additional_request_shape: Optional dict of additional JMESPath transformations + to apply to the request. Keys are target paths, values + are source expressions. Defaults to None. + content_path: JMESPath expression for the success message in the response. + Defaults to a literal success message. + + Returns: + A decorator that can be applied to engine-specific session creation handlers. + + Note: + If request_session_id_path appears in additional_request_shape, it will be + overwritten to ensure the session ID is properly injected. + """ + request_shape = build_session_request_shape( + request_session_id_path, additional_request_shape + ) + return register_engine_session_handler( "create_session", request_shape=request_shape, response_session_id_path=response_session_id_path, - content_path=content_path or "`successfully created session.`", + content_path=content_path, ) -def register_close_session_handler(request_shape, content_path: Optional[str] = None): +def register_close_session_handler( + request_session_id_path: str, + additional_request_shape: Optional[Dict[str, str]] = None, + content_path: str = "`successfully closed session.`", +): + """Register a handler for session closure with custom request transformations. + + This decorator creates a session handler that transforms incoming requests to include + the session ID for proper session cleanup. + + Args: + request_session_id_path: JMESPath target where the session ID should be injected + into the request body sent to the engine. + additional_request_shape: Optional dict of additional JMESPath transformations + to apply to the request. Keys are target paths, values + are source expressions. Defaults to None. + content_path: JMESPath expression for the success message in the response. + Defaults to a literal success message. + + Returns: + A decorator that can be applied to engine-specific session closure handlers. + + Note: + If request_session_id_path appears in additional_request_shape, it will be + overwritten to ensure the session ID is properly injected. + """ + request_shape = build_session_request_shape( + request_session_id_path, additional_request_shape + ) + return register_engine_session_handler( "close_session", request_shape=request_shape, - content_path=content_path or "`successfully closed session.`", + content_path=content_path, ) diff --git a/python/model_hosting_container_standards/sagemaker/sessions/CUSTOM_HANDLERS.md b/python/model_hosting_container_standards/sagemaker/sessions/CUSTOM_HANDLERS.md index 3aaf546..37acd77 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/CUSTOM_HANDLERS.md +++ b/python/model_hosting_container_standards/sagemaker/sessions/CUSTOM_HANDLERS.md @@ -255,14 +255,14 @@ class CreateSessionRequest(BaseModel): async def create_session(obj: CreateSessionRequest, request: Request): # Generate or use provided session ID session_id = obj.session_id or str(uuid.uuid4()) - + # Check if session already exists if session_id in active_sessions: raise HTTPException(status_code=400, detail="Session already exists") - + # Create session in your engine active_sessions[session_id] = {"capacity": obj.capacity} - + return { "session_id": session_id, "message": f"Session created with capacity {obj.capacity}" @@ -276,10 +276,10 @@ async def create_session(obj: CreateSessionRequest, request: Request): async def close_session(session_id: str, request: Request): if session_id not in active_sessions: raise HTTPException(status_code=404, detail="Session not found") - + # Close session in your engine del active_sessions[session_id] - + return Response(status_code=200, content="Session closed") @app.post("/invocations") @@ -288,10 +288,10 @@ async def invocations(request: Request): body_bytes = await request.body() body = json.loads(body_bytes.decode()) session_id = body.get("session_id") - + if session_id and session_id not in active_sessions: raise HTTPException(status_code=400, detail="Invalid session") - + # Process inference request with session context return JSONResponse({ "result": "success", diff --git a/python/model_hosting_container_standards/sagemaker/sessions/__init__.py b/python/model_hosting_container_standards/sagemaker/sessions/__init__.py index 37219d0..fb6ac92 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/__init__.py +++ b/python/model_hosting_container_standards/sagemaker/sessions/__init__.py @@ -1,6 +1,7 @@ -from typing import Optional +from typing import Dict, Optional from ...common.transforms.base_factory import create_transform_decorator +from ...logging_config import logger from .models import SageMakerSessionHeader from .transform import SessionApiTransform from .transforms import resolve_engine_session_transform @@ -72,3 +73,39 @@ def register_engine_session_handler( return _create_engine_session_transform_decorator(handler_type)( request_shape, response_shape ) + + +def build_session_request_shape( + session_id_path: str, + additional_shape: Optional[Dict[str, str]] = None, +) -> Dict[str, str]: + """Build the request shape for session handlers with proper session ID injection. + + This helper consolidates the logic for constructing request shapes, ensuring + the session ID is always properly mapped and warning about any conflicts. + + Args: + session_id_path: The target path for the session ID in the request. + additional_shape: Optional additional transformations to merge. + + Returns: + A complete request shape dict with session ID and any additional mappings. + """ + request_shape: Dict[str, str] = {} + + if additional_shape: + # Warn if session_id_path would be overwritten + if session_id_path in additional_shape: + existing_value = additional_shape[session_id_path] + logger.warning( + f"Session ID path '{session_id_path}' found in additional_request_shape " + f"with value '{existing_value}'. This will be overwritten with the " + f"SageMaker session header value." + ) + + # Merge additional shape, ensuring session ID takes precedence + request_shape.update(additional_shape) + + request_shape[session_id_path] = f'headers."{SageMakerSessionHeader.SESSION_ID}"' + + return request_shape diff --git a/python/tests/integration/test_custom_session_handlers_integration.py b/python/tests/integration/test_custom_session_handlers_integration.py index d3f872e..ef72735 100644 --- a/python/tests/integration/test_custom_session_handlers_integration.py +++ b/python/tests/integration/test_custom_session_handlers_integration.py @@ -141,10 +141,11 @@ def custom_close_session(self, obj: CloseSessionRequest, request: Request): def setup_common_handlers(self): @sagemaker_standards.register_create_session_handler( - request_shape={ + request_session_id_path="session_id", + response_session_id_path="body", + additional_request_shape={ "capacity_of_str_len": "`1024`", }, - response_session_id_path="body", content_path="`successfully created session.`", ) @self.app.api_route("/open_session", methods=["GET", "POST"]) @@ -152,9 +153,7 @@ async def create_session(obj: CreateSessionRequest, request: Request): return self.custom_create_session(obj, request) @sagemaker_standards.register_close_session_handler( - request_shape={ - "session_id": f'headers."{SageMakerSessionHeader.SESSION_ID}"' - }, + request_session_id_path="session_id", content_path="`successfully closed session.`", ) @self.app.api_route("/close_session", methods=["GET", "POST"]) @@ -346,11 +345,11 @@ def custom_close_session(self, obj: CloseSessionRequest, request: Request): def setup_common_handlers(self): @sagemaker_standards.register_create_session_handler( - request_shape={ + request_session_id_path="session_id", + response_session_id_path="body.session_id", # Nested + additional_request_shape={ "capacity_of_str_len": "`1024`", - "session_id": f'headers."{SageMakerSessionHeader.SESSION_ID}"', }, - response_session_id_path="body.session_id", # Nested content_path="`successfully created session.`", ) @self.app.api_route("/open_session", methods=["GET", "POST"]) @@ -358,9 +357,7 @@ async def create_session(obj: CreateSessionRequest, request: Request): return self.custom_create_session(obj, request) @sagemaker_standards.register_close_session_handler( - request_shape={ - "session_id": f'headers."{SageMakerSessionHeader.SESSION_ID}"' - }, + request_session_id_path="session_id", content_path="`successfully closed session.`", ) @self.app.api_route("/close_session", methods=["GET", "POST"]) @@ -497,8 +494,9 @@ def setup_common_handlers(self): response_path = "body.session_id" if self.response_format == "dict" else "body" @sagemaker_standards.register_create_session_handler( - request_shape={"capacity_of_str_len": "`1024`"}, + request_session_id_path="session_id", response_session_id_path=response_path, + additional_request_shape={"capacity_of_str_len": "`1024`"}, content_path="`successfully created session.`", ) @self.app.api_route("/open_session", methods=["GET", "POST"]) @@ -506,9 +504,7 @@ async def create_session(obj: CreateSessionRequest, request: Request): return self.custom_create_session(obj, request) @sagemaker_standards.register_close_session_handler( - request_shape={ - "session_id": f'headers."{SageMakerSessionHeader.SESSION_ID}"' - }, + request_session_id_path="session_id", content_path="`successfully closed session.`", ) @self.app.api_route("/close_session", methods=["GET", "POST"]) @@ -569,8 +565,9 @@ def custom_close_session(self, obj: CloseSessionRequest, request: Request): def setup_common_handlers(self): @sagemaker_standards.register_create_session_handler( - request_shape={"capacity_of_str_len": "`1024`"}, + request_session_id_path="session_id", response_session_id_path="body.session_id", + additional_request_shape={"capacity_of_str_len": "`1024`"}, content_path="`successfully created session.`", ) @self.app.api_route("/open_session", methods=["GET", "POST"]) @@ -578,9 +575,7 @@ async def create_session(obj: CreateSessionRequest, request: Request): return self.custom_create_session(obj, request) @sagemaker_standards.register_close_session_handler( - request_shape={ - "session_id": f'headers."{SageMakerSessionHeader.SESSION_ID}"' - }, + request_session_id_path="session_id", content_path="`successfully closed session.`", ) @self.app.api_route("/close_session", methods=["GET", "POST"]) @@ -671,8 +666,9 @@ def custom_close_session(self, obj: CloseSessionRequest, request: Request): def setup_common_handlers(self): @sagemaker_standards.register_create_session_handler( - request_shape={"capacity_of_str_len": "`1024`"}, + request_session_id_path="session_id", response_session_id_path="body.session_id", + additional_request_shape={"capacity_of_str_len": "`1024`"}, content_path="`successfully created session.`", ) @self.app.api_route("/open_session", methods=["GET", "POST"]) @@ -680,9 +676,7 @@ async def create_session(obj: CreateSessionRequest, request: Request): return self.custom_create_session(obj, request) @sagemaker_standards.register_close_session_handler( - request_shape={ - "session_id": f'headers."{SageMakerSessionHeader.SESSION_ID}"' - }, + request_session_id_path="session_id", content_path="`successfully closed session.`", ) @self.app.api_route("/close_session", methods=["GET", "POST"]) @@ -759,8 +753,9 @@ def custom_close_session(self, obj: CloseSessionRequest, request: Request): def setup_common_handlers(self): @sagemaker_standards.register_create_session_handler( - request_shape={"capacity_of_str_len": "`1024`"}, + request_session_id_path="session_id", response_session_id_path="body.session_id", + additional_request_shape={"capacity_of_str_len": "`1024`"}, content_path="`successfully created session.`", ) @self.app.api_route("/open_session", methods=["GET", "POST"]) @@ -768,9 +763,7 @@ async def create_session(obj: CreateSessionRequest, request: Request): return self.custom_create_session(obj, request) @sagemaker_standards.register_close_session_handler( - request_shape={ - "session_id": f'headers."{SageMakerSessionHeader.SESSION_ID}"' - }, + request_session_id_path="session_id", content_path="`successfully closed session.`", ) @self.app.api_route("/close_session", methods=["GET", "POST"]) diff --git a/python/tests/sagemaker/sessions/test_init.py b/python/tests/sagemaker/sessions/test_init.py new file mode 100644 index 0000000..c444505 --- /dev/null +++ b/python/tests/sagemaker/sessions/test_init.py @@ -0,0 +1,149 @@ +"""Unit tests for sessions module public API.""" + +from unittest.mock import patch + +from model_hosting_container_standards.sagemaker.sessions import ( + build_session_request_shape, +) +from model_hosting_container_standards.sagemaker.sessions.models import ( + SageMakerSessionHeader, +) + + +class TestBuildSessionRequestShape: + """Test build_session_request_shape function.""" + + def test_creates_basic_request_shape_with_session_id_only(self): + """Test creates request shape with only session ID path.""" + result = build_session_request_shape("session_id") + + assert result == { + "session_id": f'headers."{SageMakerSessionHeader.SESSION_ID}"' + } + + def test_creates_request_shape_with_nested_session_id_path(self): + """Test creates request shape with nested session ID path.""" + result = build_session_request_shape("metadata.session_id") + + assert result == { + "metadata.session_id": f'headers."{SageMakerSessionHeader.SESSION_ID}"' + } + + def test_merges_additional_shape_without_conflicts(self): + """Test merges additional shape when no conflicts exist.""" + additional = { + "capacity": "`1024`", + "model_name": "`gpt-4`", + } + + result = build_session_request_shape("session_id", additional) + + assert result == { + "session_id": f'headers."{SageMakerSessionHeader.SESSION_ID}"', + "capacity": "`1024`", + "model_name": "`gpt-4`", + } + + @patch("model_hosting_container_standards.sagemaker.sessions.logger") + def test_overwrites_conflicting_session_id_path_and_warns(self, mock_logger): + """Test overwrites session ID path in additional shape and logs warning.""" + additional = { + "session_id": "some_other_value", + "capacity": "`1024`", + } + + result = build_session_request_shape("session_id", additional) + + # Session ID should be overwritten with the correct value + assert result == { + "session_id": f'headers."{SageMakerSessionHeader.SESSION_ID}"', + "capacity": "`1024`", + } + + # Should have logged a warning + mock_logger.warning.assert_called_once() + warning_message = mock_logger.warning.call_args[0][0] + assert "session_id" in warning_message + assert "some_other_value" in warning_message + assert "overwritten" in warning_message.lower() + + def test_handles_none_additional_shape(self): + """Test handles None as additional shape gracefully.""" + result = build_session_request_shape("session_id", None) + + assert result == { + "session_id": f'headers."{SageMakerSessionHeader.SESSION_ID}"' + } + + def test_handles_empty_additional_shape(self): + """Test handles empty dict as additional shape.""" + result = build_session_request_shape("session_id", {}) + + assert result == { + "session_id": f'headers."{SageMakerSessionHeader.SESSION_ID}"' + } + + def test_preserves_all_additional_fields(self): + """Test preserves all fields from additional shape.""" + additional = { + "field1": "value1", + "field2": "value2", + "field3": "value3", + "nested.field": "nested_value", + } + + result = build_session_request_shape("session_id", additional) + + assert result == { + "session_id": f'headers."{SageMakerSessionHeader.SESSION_ID}"', + "field1": "value1", + "field2": "value2", + "field3": "value3", + "nested.field": "nested_value", + } + + @patch("model_hosting_container_standards.sagemaker.sessions.logger") + def test_session_id_always_takes_precedence(self, mock_logger): + """Test session ID value always takes precedence even after merge.""" + additional = { + "session_id": "wrong_value", + "other_field": "other_value", + } + + result = build_session_request_shape("session_id", additional) + + # Verify session_id has the correct value, not the one from additional + assert result["session_id"] == f'headers."{SageMakerSessionHeader.SESSION_ID}"' + assert result["session_id"] != "wrong_value" + assert result["other_field"] == "other_value" + + def test_works_with_complex_jmespath_expressions(self): + """Test works with complex JMESPath expressions in additional shape.""" + additional = { + "model": 'headers."X-Model-Name"', + "temperature": "`0.7`", + "max_tokens": "body.parameters.max_tokens", + } + + result = build_session_request_shape("request.session_id", additional) + + assert result == { + "request.session_id": f'headers."{SageMakerSessionHeader.SESSION_ID}"', + "model": 'headers."X-Model-Name"', + "temperature": "`0.7`", + "max_tokens": "body.parameters.max_tokens", + } + + @patch("model_hosting_container_standards.sagemaker.sessions.logger") + def test_no_warning_when_no_conflict(self, mock_logger): + """Test no warning is logged when there's no conflict.""" + additional = { + "capacity": "`1024`", + "model": "`gpt-4`", + } + + result = build_session_request_shape("session_id", additional) + + # Should not have logged any warning + mock_logger.warning.assert_not_called() + assert result["session_id"] == f'headers."{SageMakerSessionHeader.SESSION_ID}"' From b7fb991952b44a7906165110c1a937d052e3b827 Mon Sep 17 00:00:00 2001 From: Zuyi Zhao Date: Tue, 9 Dec 2025 21:08:44 +0000 Subject: [PATCH 21/25] refactor(sagemaker/sessions): improve parameter naming clarity for custom session handlers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Rename parameters to better distinguish between engine request and response paths, making it clearer what data flows where. Make engine_request_session_id_path optional for create_session while keeping it required for close_session (must know which session to close). Changes: - Rename request_session_id_path → engine_request_session_id_path - Rename response_session_id_path → engine_response_session_id_path - Make engine_request_session_id_path optional in create handler - Add validation to require engine_request_session_id_path in close handler - Expand docstrings with detailed explanations, examples, and limitations - Update all tests and documentation to use new parameter names --- .../sagemaker/__init__.py | 103 +++++++++++++----- .../sagemaker/sessions/CUSTOM_HANDLERS.md | 102 ++++++++++------- .../sagemaker/sessions/__init__.py | 20 ++-- ...est_custom_session_handlers_integration.py | 40 +++---- .../test_sagemaker_sessions_integration.py | 4 +- python/tests/sagemaker/sessions/test_init.py | 59 ++++++++++ 6 files changed, 234 insertions(+), 94 deletions(-) diff --git a/python/model_hosting_container_standards/sagemaker/__init__.py b/python/model_hosting_container_standards/sagemaker/__init__.py index 25a194a..4d56c41 100644 --- a/python/model_hosting_container_standards/sagemaker/__init__.py +++ b/python/model_hosting_container_standards/sagemaker/__init__.py @@ -123,23 +123,29 @@ def inject_adapter_id( ) -def stateful_session_manager(request_session_id_path: Optional[str] = None): +def stateful_session_manager(engine_request_session_id_path: Optional[str] = None): """Create a decorator for session-based sticky routing. - This decorator enables stateful session management without JMESPath transformations. - Pass empty dicts to enable transform infrastructure (for intercept functionality) - without requiring JMESPath expressions. + This decorator enables stateful session management for regular invocation requests, + allowing the session ID to be injected into the request body for stateful inference. Args: - request_session_id_path: JMESPath target path where session ID should be - injected INTO the request body from the session header + engine_request_session_id_path: Optional target path in the request body where + the session ID will be injected. The session ID + is extracted from the SageMaker session header and + placed at this path in the request sent to the engine. + + Examples: "session_id", "metadata.session_id" + + If None, session management is enabled but the + session ID is not injected into the request body. Returns: A decorator that can be applied to route handlers to enable session management """ request_shape = {} - if request_session_id_path: - request_shape[request_session_id_path] = ( + if engine_request_session_id_path: + request_shape[engine_request_session_id_path] = ( f'headers."{SageMakerSessionHeader.SESSION_ID}"' ) return create_session_transform_decorator()( @@ -148,8 +154,8 @@ def stateful_session_manager(request_session_id_path: Optional[str] = None): def register_create_session_handler( - request_session_id_path: str, - response_session_id_path: str, + engine_response_session_id_path: str, + engine_request_session_id_path: Optional[str] = None, additional_request_shape: Optional[Dict[str, str]] = None, content_path: str = "`successfully created session.`", ): @@ -159,13 +165,36 @@ def register_create_session_handler( the session ID and extracts the session ID from the engine's response. Args: - request_session_id_path: JMESPath target where the session ID should be injected - into the request body sent to the engine. - response_session_id_path: JMESPath expression to extract the session ID from - the engine's response body. + engine_response_session_id_path: JMESPath expression specifying where to extract + the session ID from the engine's response. Must + include a prefix indicating the source location: + + - "body.session_id" - extract from response body + - "headers.X-Session-Id" - extract from response headers + + The extracted session ID is placed in the SageMaker + response body for the client. + + engine_request_session_id_path: Optional target path in the engine request body + where the session ID will be injected. The session + ID is extracted from the SageMaker session header + and placed at this path in the request sent to the + engine. + + Examples: "session_id", "metadata.session_id" + + If None, the session ID is not injected into the + engine request body. This is useful when the engine + manages session IDs internally and doesn't need them + in the request. + + Limitation: Currently only supports injection into + the request body, not headers. + additional_request_shape: Optional dict of additional JMESPath transformations - to apply to the request. Keys are target paths, values - are source expressions. Defaults to None. + to apply to the request. Keys are target paths in the + request body, values are source expressions. Defaults to None. + content_path: JMESPath expression for the success message in the response. Defaults to a literal success message. @@ -173,23 +202,23 @@ def register_create_session_handler( A decorator that can be applied to engine-specific session creation handlers. Note: - If request_session_id_path appears in additional_request_shape, it will be + If engine_request_session_id_path appears in additional_request_shape, it will be overwritten to ensure the session ID is properly injected. """ request_shape = build_session_request_shape( - request_session_id_path, additional_request_shape + engine_request_session_id_path, additional_request_shape ) return register_engine_session_handler( "create_session", request_shape=request_shape, - response_session_id_path=response_session_id_path, + response_session_id_path=engine_response_session_id_path, content_path=content_path, ) def register_close_session_handler( - request_session_id_path: str, + engine_request_session_id_path: str, additional_request_shape: Optional[Dict[str, str]] = None, content_path: str = "`successfully closed session.`", ): @@ -199,23 +228,45 @@ def register_close_session_handler( the session ID for proper session cleanup. Args: - request_session_id_path: JMESPath target where the session ID should be injected - into the request body sent to the engine. + engine_request_session_id_path: Required. Target path in the engine request body + where the session ID will be injected. The session + ID is extracted from the SageMaker session header + and placed at this path in the request sent to the + engine. + + Examples: "session_id", "metadata.session_id" + + This parameter is required because the engine needs + to know which session to close. + + Limitation: Currently only supports injection into + the request body, not headers. + additional_request_shape: Optional dict of additional JMESPath transformations - to apply to the request. Keys are target paths, values - are source expressions. Defaults to None. + to apply to the request. Keys are target paths in the + request body, values are source expressions. Defaults to None. + content_path: JMESPath expression for the success message in the response. Defaults to a literal success message. Returns: A decorator that can be applied to engine-specific session closure handlers. + Raises: + ValueError: If engine_request_session_id_path is None or empty. + Note: - If request_session_id_path appears in additional_request_shape, it will be + If engine_request_session_id_path appears in additional_request_shape, it will be overwritten to ensure the session ID is properly injected. """ + if not engine_request_session_id_path: + raise ValueError( + "engine_request_session_id_path is required for close_session handler. " + "The engine needs to know which session to close." + ) + request_shape = build_session_request_shape( - request_session_id_path, additional_request_shape + engine_request_session_id_path, additional_request_shape ) return register_engine_session_handler( diff --git a/python/model_hosting_container_standards/sagemaker/sessions/CUSTOM_HANDLERS.md b/python/model_hosting_container_standards/sagemaker/sessions/CUSTOM_HANDLERS.md index 37acd77..082622c 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/CUSTOM_HANDLERS.md +++ b/python/model_hosting_container_standards/sagemaker/sessions/CUSTOM_HANDLERS.md @@ -56,10 +56,11 @@ class CreateSessionResponse(BaseModel): # Register custom create session handler @register_create_session_handler( - request_shape={ - "capacity": "`1024`" # JMESPath literal value + engine_response_session_id_path="body.session_id", # Extract session ID from response + engine_request_session_id_path="session_id", # Where to inject session ID in engine request + additional_request_shape={ + "capacity": "`1024`" # Additional fields to include }, - response_session_id_path="body.session_id", # Extract session ID from response content_path="body.message" # Extract content for logging ) @app.post("/engine/create_session") @@ -68,11 +69,23 @@ async def create_session(obj: CreateSessionRequest, request: Request): session_id = await my_engine.create_session(capacity=obj.capacity) return CreateSessionResponse(session_id=session_id, message="Session created") +# Alternative: If your engine manages session IDs internally +@register_create_session_handler( + engine_response_session_id_path="body.session_id", # Extract session ID from response + # No engine_request_session_id_path - engine generates its own session ID + additional_request_shape={ + "capacity": "`1024`" + } +) +@app.post("/engine/create_session") +async def create_session_auto(obj: CreateSessionRequest, request: Request): + # Engine generates and returns its own session ID + session_id = await my_engine.create_session_auto(capacity=obj.capacity) + return CreateSessionResponse(session_id=session_id, message="Session created") + # Register custom close session handler @register_close_session_handler( - request_shape={ - "session_id": 'headers."X-Amzn-SageMaker-Session-Id"' # Extract from header - }, + engine_request_session_id_path="session_id", # Where to inject session ID in engine request content_path="`Session closed successfully`" # Static message ) @app.post("/engine/close_session") @@ -97,29 +110,33 @@ bootstrap(app) ```python @register_create_session_handler( - request_shape: dict, # Required: JMESPath mappings for request transformation - response_session_id_path: str, # Required: JMESPath to extract session ID from response - content_path: str = None # Optional: JMESPath to extract content for logging + engine_response_session_id_path: str, # Required: Where to extract session ID from engine response + engine_request_session_id_path: str = None, # Optional: Where to inject session ID in engine request + additional_request_shape: dict = None, # Optional: Additional JMESPath mappings + content_path: str = None # Optional: JMESPath to extract content for logging ) ``` -- **`request_shape`**: Maps target keys to source JMESPath expressions. Transforms the incoming SageMaker request into your engine's expected format. -- **`response_session_id_path`**: JMESPath expression to extract the session ID from your engine's response. This is **required** because the framework needs to return the session ID in the response header. +- **`engine_response_session_id_path`**: JMESPath expression to extract the session ID from your engine's response. Must include prefix (`body.` or `headers.`). This is **required** because the framework needs to return the session ID to the client. +- **`engine_request_session_id_path`**: Optional target path in the engine request body where the session ID will be injected. The session ID is extracted from the SageMaker session header and placed at this path. Example: `"session_id"` or `"metadata.session_id"`. If None, the session ID is not injected (useful when the engine manages sessions internally) +- **`additional_request_shape`**: Optional dict mapping target keys to source JMESPath expressions for additional fields to include in the engine request. - **`content_path`**: Optional JMESPath expression to extract a message for logging. Defaults to a generic success message. ### `@register_close_session_handler` ```python @register_close_session_handler( - request_shape: dict, # Required: JMESPath mappings for request transformation - content_path: str = None # Optional: JMESPath to extract content for logging + engine_request_session_id_path: str, # Required: Where to inject session ID in engine request + additional_request_shape: dict = None, # Optional: Additional JMESPath mappings + content_path: str = None # Optional: JMESPath to extract content for logging ) ``` -- **`request_shape`**: Maps target keys to source JMESPath expressions. Typically extracts the session ID from the request header. +- **`engine_request_session_id_path`**: **Required.** Target path in the engine request body where the session ID will be injected. The session ID is extracted from the SageMaker session header and placed at this path. This is required because the engine needs to know which session to close. Example: `"session_id"` or `"metadata.session_id"` +- **`additional_request_shape`**: Optional dict mapping target keys to source JMESPath expressions for additional fields to include in the engine request. - **`content_path`**: Optional JMESPath expression to extract a message for logging. Defaults to a generic success message. -**Note**: `response_session_id_path` is not needed for close handlers because the session ID comes from the request header, not the response. +**Note**: `engine_response_session_id_path` is not needed for close handlers because the session ID comes from the request header, not the response. ## How It Works @@ -135,29 +152,35 @@ The key benefit: Your `/invocations` endpoint stays clean, and session managemen ## JMESPath Expressions -The `request_shape` and `response_shape` parameters use JMESPath expressions to transform data: +The parameters use JMESPath expressions to transform data: + +### Request Transformation + +The `engine_request_session_id_path` specifies where to inject the session ID (always relative to request body): -### Request Shape +```python +engine_request_session_id_path="session_id" # Inject at root level +engine_request_session_id_path="metadata.session_id" # Inject in nested path +``` -Maps target keys to source expressions: +The `additional_request_shape` maps target keys to source expressions: ```python -request_shape={ +additional_request_shape={ "capacity": "`1024`", # Literal value - "session_id": 'headers."X-Amzn-SageMaker-Session-Id"', # From header - "user_id": "body.metadata.user" # From request body } ``` -### Response Shape +### Response Extraction For **create session**, you must specify: -- `response_session_id_path`: Where to extract the session ID from the engine's response +- `engine_response_session_id_path`: Where to extract the session ID from the engine's response - `content_path`: Where to extract content for logging (optional) ```python -response_session_id_path="body.session_id" # Extract from {"session_id": "..."} -response_session_id_path="body" # If response is just the session ID string +engine_response_session_id_path="body.session_id" # Extract from {"session_id": "..."} +engine_response_session_id_path="body" # If response is just the session ID string +engine_response_session_id_path="headers.X-Session-Id" # Extract from response header content_path="body.message" # Extract message from response content_path="`Session created`" # Use literal string ``` @@ -244,11 +267,11 @@ class CreateSessionRequest(BaseModel): session_id: Optional[str] = None @register_create_session_handler( - request_shape={ - "capacity": "`1024`", - "session_id": f'headers."{SageMakerSessionHeader.SESSION_ID}"' + engine_request_session_id_path="session_id", + engine_response_session_id_path="body.session_id", + additional_request_shape={ + "capacity": "`1024`" }, - response_session_id_path="body.session_id", content_path="body.message" ) @app.post("/engine/create_session") @@ -269,7 +292,7 @@ async def create_session(obj: CreateSessionRequest, request: Request): } @register_close_session_handler( - request_shape={"session_id": f'headers."{SageMakerSessionHeader.SESSION_ID}"'}, + engine_request_session_id_path="session_id", content_path="`Session closed successfully`" ) @app.post("/engine/close_session") @@ -283,7 +306,7 @@ async def close_session(session_id: str, request: Request): return Response(status_code=200, content="Session closed") @app.post("/invocations") -@stateful_session_manager(request_session_id_path="session_id") +@stateful_session_manager(engine_request_session_id_path="session_id") async def invocations(request: Request): body_bytes = await request.body() body = json.loads(body_bytes.decode()) @@ -327,13 +350,16 @@ This design allows your engine to manage sessions independently without interfer **Problem**: Getting "Engine failed to return a valid session ID" error. -**Solution**: Check that your `response_session_id_path` matches your response structure: +**Solution**: Check that your `engine_response_session_id_path` matches your response structure: ```python # If your handler returns: {"session_id": "abc123"} -response_session_id_path="body.session_id" +engine_response_session_id_path="body.session_id" # If your handler returns: "abc123" -response_session_id_path="body" +engine_response_session_id_path="body" + +# If session ID is in response header +engine_response_session_id_path="headers.X-Session-Id" ``` ### Request not reaching custom handler @@ -349,13 +375,13 @@ async def create_session(...): bootstrap(app) # Must be after handler registration ``` -### Session header not found in close handler +### Session ID not injected into engine request -**Problem**: Getting "Session ID is required in request headers" error. +**Problem**: Engine receives request without session ID. -**Solution**: Ensure your `request_shape` extracts the session ID from the header: +**Solution**: Ensure your `engine_request_session_id_path` specifies where to inject the session ID: ```python -request_shape={"session_id": 'headers."X-Amzn-SageMaker-Session-Id"'} +engine_request_session_id_path="session_id" # Injects at root level of request body ``` ## See Also diff --git a/python/model_hosting_container_standards/sagemaker/sessions/__init__.py b/python/model_hosting_container_standards/sagemaker/sessions/__init__.py index fb6ac92..2f01312 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/__init__.py +++ b/python/model_hosting_container_standards/sagemaker/sessions/__init__.py @@ -76,7 +76,7 @@ def register_engine_session_handler( def build_session_request_shape( - session_id_path: str, + session_id_path: Optional[str], additional_shape: Optional[Dict[str, str]] = None, ) -> Dict[str, str]: """Build the request shape for session handlers with proper session ID injection. @@ -85,7 +85,8 @@ def build_session_request_shape( the session ID is always properly mapped and warning about any conflicts. Args: - session_id_path: The target path for the session ID in the request. + session_id_path: Optional target path for the session ID in the request. + If None, session ID is not injected into the request. additional_shape: Optional additional transformations to merge. Returns: @@ -94,18 +95,21 @@ def build_session_request_shape( request_shape: Dict[str, str] = {} if additional_shape: + request_shape.update(additional_shape) + + # Only inject session ID if a path is specified + if session_id_path: # Warn if session_id_path would be overwritten - if session_id_path in additional_shape: - existing_value = additional_shape[session_id_path] + if session_id_path in request_shape: + existing_value = request_shape[session_id_path] logger.warning( f"Session ID path '{session_id_path}' found in additional_request_shape " f"with value '{existing_value}'. This will be overwritten with the " f"SageMaker session header value." ) - # Merge additional shape, ensuring session ID takes precedence - request_shape.update(additional_shape) - - request_shape[session_id_path] = f'headers."{SageMakerSessionHeader.SESSION_ID}"' + request_shape[session_id_path] = ( + f'headers."{SageMakerSessionHeader.SESSION_ID}"' + ) return request_shape diff --git a/python/tests/integration/test_custom_session_handlers_integration.py b/python/tests/integration/test_custom_session_handlers_integration.py index ef72735..65ef2d6 100644 --- a/python/tests/integration/test_custom_session_handlers_integration.py +++ b/python/tests/integration/test_custom_session_handlers_integration.py @@ -141,8 +141,8 @@ def custom_close_session(self, obj: CloseSessionRequest, request: Request): def setup_common_handlers(self): @sagemaker_standards.register_create_session_handler( - request_session_id_path="session_id", - response_session_id_path="body", + engine_request_session_id_path="session_id", + engine_response_session_id_path="body", additional_request_shape={ "capacity_of_str_len": "`1024`", }, @@ -153,7 +153,7 @@ async def create_session(obj: CreateSessionRequest, request: Request): return self.custom_create_session(obj, request) @sagemaker_standards.register_close_session_handler( - request_session_id_path="session_id", + engine_request_session_id_path="session_id", content_path="`successfully closed session.`", ) @self.app.api_route("/close_session", methods=["GET", "POST"]) @@ -345,8 +345,8 @@ def custom_close_session(self, obj: CloseSessionRequest, request: Request): def setup_common_handlers(self): @sagemaker_standards.register_create_session_handler( - request_session_id_path="session_id", - response_session_id_path="body.session_id", # Nested + engine_request_session_id_path="session_id", + engine_response_session_id_path="body.session_id", # Nested additional_request_shape={ "capacity_of_str_len": "`1024`", }, @@ -357,7 +357,7 @@ async def create_session(obj: CreateSessionRequest, request: Request): return self.custom_create_session(obj, request) @sagemaker_standards.register_close_session_handler( - request_session_id_path="session_id", + engine_request_session_id_path="session_id", content_path="`successfully closed session.`", ) @self.app.api_route("/close_session", methods=["GET", "POST"]) @@ -367,7 +367,7 @@ async def close_session(obj: CloseSessionRequest, request: Request): def setup_invocation_handler(self): @self.router.post("/invocations") @sagemaker_standards.stateful_session_manager( - request_session_id_path="session_id" + engine_request_session_id_path="session_id" ) async def invocations(request: Request): return await self.custom_invocations(request) @@ -494,8 +494,8 @@ def setup_common_handlers(self): response_path = "body.session_id" if self.response_format == "dict" else "body" @sagemaker_standards.register_create_session_handler( - request_session_id_path="session_id", - response_session_id_path=response_path, + engine_request_session_id_path="session_id", + engine_response_session_id_path=response_path, additional_request_shape={"capacity_of_str_len": "`1024`"}, content_path="`successfully created session.`", ) @@ -504,7 +504,7 @@ async def create_session(obj: CreateSessionRequest, request: Request): return self.custom_create_session(obj, request) @sagemaker_standards.register_close_session_handler( - request_session_id_path="session_id", + engine_request_session_id_path="session_id", content_path="`successfully closed session.`", ) @self.app.api_route("/close_session", methods=["GET", "POST"]) @@ -565,8 +565,8 @@ def custom_close_session(self, obj: CloseSessionRequest, request: Request): def setup_common_handlers(self): @sagemaker_standards.register_create_session_handler( - request_session_id_path="session_id", - response_session_id_path="body.session_id", + engine_request_session_id_path="session_id", + engine_response_session_id_path="body.session_id", additional_request_shape={"capacity_of_str_len": "`1024`"}, content_path="`successfully created session.`", ) @@ -575,7 +575,7 @@ async def create_session(obj: CreateSessionRequest, request: Request): return self.custom_create_session(obj, request) @sagemaker_standards.register_close_session_handler( - request_session_id_path="session_id", + engine_request_session_id_path="session_id", content_path="`successfully closed session.`", ) @self.app.api_route("/close_session", methods=["GET", "POST"]) @@ -666,8 +666,8 @@ def custom_close_session(self, obj: CloseSessionRequest, request: Request): def setup_common_handlers(self): @sagemaker_standards.register_create_session_handler( - request_session_id_path="session_id", - response_session_id_path="body.session_id", + engine_request_session_id_path="session_id", + engine_response_session_id_path="body.session_id", additional_request_shape={"capacity_of_str_len": "`1024`"}, content_path="`successfully created session.`", ) @@ -676,7 +676,7 @@ async def create_session(obj: CreateSessionRequest, request: Request): return self.custom_create_session(obj, request) @sagemaker_standards.register_close_session_handler( - request_session_id_path="session_id", + engine_request_session_id_path="session_id", content_path="`successfully closed session.`", ) @self.app.api_route("/close_session", methods=["GET", "POST"]) @@ -686,7 +686,7 @@ async def close_session(obj: CloseSessionRequest, request: Request): def setup_invocation_handler(self): @self.router.post("/invocations") @sagemaker_standards.stateful_session_manager( - request_session_id_path="metadata.session_id" + engine_request_session_id_path="metadata.session_id" ) async def invocations(request: Request): body_bytes = await request.body() @@ -753,8 +753,8 @@ def custom_close_session(self, obj: CloseSessionRequest, request: Request): def setup_common_handlers(self): @sagemaker_standards.register_create_session_handler( - request_session_id_path="session_id", - response_session_id_path="body.session_id", + engine_request_session_id_path="session_id", + engine_response_session_id_path="body.session_id", additional_request_shape={"capacity_of_str_len": "`1024`"}, content_path="`successfully created session.`", ) @@ -763,7 +763,7 @@ async def create_session(obj: CreateSessionRequest, request: Request): return self.custom_create_session(obj, request) @sagemaker_standards.register_close_session_handler( - request_session_id_path="session_id", + engine_request_session_id_path="session_id", content_path="`successfully closed session.`", ) @self.app.api_route("/close_session", methods=["GET", "POST"]) diff --git a/python/tests/integration/test_sagemaker_sessions_integration.py b/python/tests/integration/test_sagemaker_sessions_integration.py index 0903216..716d023 100644 --- a/python/tests/integration/test_sagemaker_sessions_integration.py +++ b/python/tests/integration/test_sagemaker_sessions_integration.py @@ -643,7 +643,7 @@ def setup_handlers(self): @self.router.post("/invocations-with-path") @sagemaker_standards.stateful_session_manager( - request_session_id_path="session_id" + engine_request_session_id_path="session_id" ) async def invocations_with_path(request: Request): """Handler that injects session ID into request body at 'session_id' key.""" @@ -668,7 +668,7 @@ async def invocations_with_path(request: Request): @self.router.post("/invocations-nested-path") @sagemaker_standards.stateful_session_manager( - request_session_id_path="metadata.session_id" + engine_request_session_id_path="metadata.session_id" ) async def invocations_nested_path(request: Request): """Handler that injects session ID into nested path in request body.""" diff --git a/python/tests/sagemaker/sessions/test_init.py b/python/tests/sagemaker/sessions/test_init.py index c444505..fccbd13 100644 --- a/python/tests/sagemaker/sessions/test_init.py +++ b/python/tests/sagemaker/sessions/test_init.py @@ -2,6 +2,8 @@ from unittest.mock import patch +import pytest + from model_hosting_container_standards.sagemaker.sessions import ( build_session_request_shape, ) @@ -147,3 +149,60 @@ def test_no_warning_when_no_conflict(self, mock_logger): # Should not have logged any warning mock_logger.warning.assert_not_called() assert result["session_id"] == f'headers."{SageMakerSessionHeader.SESSION_ID}"' + + def test_none_session_id_path_returns_only_additional_shape(self): + """Test that None session_id_path returns only additional shape.""" + additional = { + "capacity": "`1024`", + "model": "`gpt-4`", + } + + result = build_session_request_shape(None, additional) + + # Should only have additional fields, no session ID + assert result == additional + assert f'headers."{SageMakerSessionHeader.SESSION_ID}"' not in result.values() + + def test_none_session_id_path_with_no_additional_shape(self): + """Test that None session_id_path with no additional shape returns empty dict.""" + result = build_session_request_shape(None, None) + + assert result == {} + + def test_none_session_id_path_with_empty_additional_shape(self): + """Test that None session_id_path with empty additional shape returns empty dict.""" + result = build_session_request_shape(None, {}) + + assert result == {} + + +class TestRegisterCloseSessionHandler: + """Test register_close_session_handler validation.""" + + def test_raises_error_when_engine_request_session_id_path_is_none(self): + """Test raises ValueError when engine_request_session_id_path is None.""" + from model_hosting_container_standards.sagemaker import ( + register_close_session_handler, + ) + + with pytest.raises( + ValueError, match="engine_request_session_id_path is required" + ): + register_close_session_handler( + engine_request_session_id_path=None, + content_path="`Session closed`", + ) + + def test_raises_error_when_engine_request_session_id_path_is_empty(self): + """Test raises ValueError when engine_request_session_id_path is empty string.""" + from model_hosting_container_standards.sagemaker import ( + register_close_session_handler, + ) + + with pytest.raises( + ValueError, match="engine_request_session_id_path is required" + ): + register_close_session_handler( + engine_request_session_id_path="", + content_path="`Session closed`", + ) From 4036e22e3c80d1dd3a3196e97194099eeef06ae6 Mon Sep 17 00:00:00 2001 From: Zuyi Zhao Date: Wed, 10 Dec 2025 20:21:24 +0000 Subject: [PATCH 22/25] fix docstring, add status_code to serialize_response dict output. --- .../model_hosting_container_standards/common/fastapi/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/model_hosting_container_standards/common/fastapi/utils.py b/python/model_hosting_container_standards/common/fastapi/utils.py index 886bd10..723a19f 100644 --- a/python/model_hosting_container_standards/common/fastapi/utils.py +++ b/python/model_hosting_container_standards/common/fastapi/utils.py @@ -46,7 +46,7 @@ def serialize_response(response: Union[Response, JSONResponse]): :param Union[Response, JSONResponse] response: Response body data - can be: - FastAPI Response object - JSONResponse object - :return Dict[str, Any]: Structured data with body, headers, status_code, and media_type + :return Dict[str, Any]: Structured data with body, headers, and status_code """ # Process response body based on type body = response.body.decode(response.charset) @@ -54,10 +54,10 @@ def serialize_response(response: Union[Response, JSONResponse]): body = json.loads(body) except json.JSONDecodeError: # If body is not JSON, keep it as a string - # logger.warning(f"Response body is not JSON, keeping as string: {e}") pass return { "body": body, "headers": response.headers, + "status_code": response.status_code, } From 4bc21c7544a7911e0ab75bafbcb68e3dc902b926 Mon Sep 17 00:00:00 2001 From: Zuyi Zhao Date: Wed, 10 Dec 2025 21:10:04 +0000 Subject: [PATCH 23/25] update set_value call in SessionApiTransform _process_invocations_request when target key is specified for session id injection to support create_parent --- .../sagemaker/sessions/transform.py | 7 +++-- ...est_custom_session_handlers_integration.py | 30 ++++++++++++++++++- 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/python/model_hosting_container_standards/sagemaker/sessions/transform.py b/python/model_hosting_container_standards/sagemaker/sessions/transform.py index 4cf84cc..aff9d7e 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/transform.py +++ b/python/model_hosting_container_standards/sagemaker/sessions/transform.py @@ -169,9 +169,10 @@ def _process_invocations_request( # Inject session ID into request body if target key is specified if session_id and self._session_id_target_key: request_data = set_value( - request_data, - self._session_id_target_key, - session_id, + obj=request_data, + path=self._session_id_target_key, + value=session_id, + create_parent=True, ) logger.debug(f"Updated request body: {request_data}") raw_request._body = json.dumps(request_data).encode("utf-8") diff --git a/python/tests/integration/test_custom_session_handlers_integration.py b/python/tests/integration/test_custom_session_handlers_integration.py index 65ef2d6..a0c0914 100644 --- a/python/tests/integration/test_custom_session_handlers_integration.py +++ b/python/tests/integration/test_custom_session_handlers_integration.py @@ -708,7 +708,8 @@ def test_session_id_injected_into_nested_path(self): Some ML engines expect the session ID to be in the request body rather than just in headers. The request_session_id_path parameter allows automatic injection of the session ID into a specified path in the request body - (e.g., metadata.session_id). This test validates that injection works correctly. + (e.g., metadata.session_id). This test validates that the session ID is + correctly injected when the metadata dict already exists. """ # Create session session_id = self.create_session() @@ -728,6 +729,33 @@ def test_session_id_injected_into_nested_path(self): # Verify original metadata fields are preserved assert data["body"]["metadata"]["user"] == "test_user" + def test_session_id_injected_creates_missing_metadata_dict(self): + """Test that session ID injection creates missing parent structures. + + When the request path expects metadata.session_id but the request doesn't + include a "metadata" dict, the set_value function should create the missing + parent structure and inject the session ID. This tests the create_parent=True + functionality in set_value. + """ + # Create session + session_id = self.create_session() + + # Make request with session - note we don't include "metadata" dict at all + # The framework should create the missing "metadata" dict and inject session_id + response = self.invoke_with_session( + session_id, {"prompt": "test", "user": "test_user"} + ) + + assert response.status_code == 200 + data = json.loads(response.text) + + # Verify session ID was automatically injected and metadata dict was created + assert data["session_id"] == session_id + assert data["body"]["metadata"]["session_id"] == session_id + # Verify original fields at root level are preserved + assert data["body"]["prompt"] == "test" + assert data["body"]["user"] == "test_user" + class TestCustomHandlerSessionPersistence(BaseCustomHandlerIntegrationTest): """Test that session state persists correctly across invocations with custom handlers.""" From 3e880beb351396b76922a2eefb0fc5dc96cf2e05 Mon Sep 17 00:00:00 2001 From: Zuyi Zhao Date: Fri, 12 Dec 2025 18:34:39 +0000 Subject: [PATCH 24/25] Improve how SessionApiTransform tracks whether to use default manager. --- .../sagemaker/sessions/transform.py | 54 +++++++------------ 1 file changed, 18 insertions(+), 36 deletions(-) diff --git a/python/model_hosting_container_standards/sagemaker/sessions/transform.py b/python/model_hosting_container_standards/sagemaker/sessions/transform.py index aff9d7e..f2cd307 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/transform.py +++ b/python/model_hosting_container_standards/sagemaker/sessions/transform.py @@ -11,7 +11,7 @@ from ...common.transforms.utils import set_value from ...logging_config import logger from .handlers import get_handler_for_request_type -from .manager import SessionManager, get_session_manager +from .manager import get_session_manager from .models import ( SESSION_DISABLED_ERROR_DETAIL, SESSION_DISABLED_LOG_MESSAGE, @@ -60,40 +60,26 @@ def __init__(self, request_shape, response_shape={}): for validation in this transform, as session requests use their own validation. """ self._session_manager = get_session_manager() - - # Hybrid caching strategy for _use_default_manager: - # - If custom handlers exist at init → cache False (fast path on every request) - # - If no custom handlers at init → cache True but check dynamically (allows late registration) - # This optimizes the common case while maintaining flexibility - self._use_default_manager_cached = not handler_registry.has_handler( - "create_session" - ) and not handler_registry.has_handler("close_session") + self._use_default_manager = None # Extract session_id_target_key before compiling JMESPath expressions self._session_id_target_key = self._get_session_id_target_key(request_shape) super().__init__(request_shape, response_shape) - def _use_default_manager(self) -> bool: - """Check if default session manager should be used. - - Hybrid approach for performance: - - If custom handlers existed at init time (cached=False), return False immediately - - If no custom handlers at init (cached=True), check dynamically in case they were registered later - - This optimizes the common case (custom handlers registered before transform creation) - while still supporting late registration for flexibility. + def _check_use_default_manager(self): + """Check if the default session manager should be used. Returns: - bool: True if default manager should be used, False if custom handlers exist + bool: True if the default session manager should be used, False otherwise """ - # Fast path: if custom handlers existed at init, they still exist - if not self._use_default_manager_cached: - return False - - # Slow path: no custom handlers at init, check if any were registered since - return not handler_registry.has_handler( - "create_session" - ) and not handler_registry.has_handler("close_session") + if self._use_default_manager is None: + # If unset, first call -> set cached value + logger.info("Checking if default session manager should be used.") + self._use_default_manager = not handler_registry.has_handler( + "create_session" + ) and not handler_registry.has_handler("close_session") + logger.info(f"Using default session manager: {self._use_default_manager}") + return self._use_default_manager def _get_session_id_target_key(self, request_shape: dict) -> Optional[str]: if not request_shape: @@ -121,9 +107,7 @@ async def transform_request(self, raw_request): """ try: request_data = await raw_request.json() - return self._process_request( - request_data, raw_request, self._session_manager - ) + return self._process_request(request_data, raw_request) except json.JSONDecodeError as e: raise HTTPException( status_code=HTTPStatus.BAD_REQUEST.value, @@ -161,7 +145,7 @@ def _process_invocations_request( self, session_id: Optional[str], request_data: dict, raw_request: Request ): # If not a session request - if session_id and self._use_default_manager(): + if session_id and self._check_use_default_manager(): # but it has a session id header and we are using the default session manager, # then we need to validate that the session id exists in the session manager self._validate_session_id(session_id, raw_request) @@ -184,7 +168,7 @@ def _process_invocations_request( def _process_session_request(self, session_request, session_id, raw_request): # Validation - if self._use_default_manager() and not self._session_manager: + if self._check_use_default_manager() and not self._session_manager: # if no custom handlers are registered, but default session manager # does not exist -> then raise error that session management is disabled logger.error(SESSION_DISABLED_LOG_MESSAGE) @@ -192,7 +176,7 @@ def _process_session_request(self, session_request, session_id, raw_request): status_code=HTTPStatus.BAD_REQUEST.value, detail=SESSION_DISABLED_ERROR_DETAIL, ) - elif self._use_default_manager() and self._session_manager: + elif self._check_use_default_manager() and self._session_manager: if session_request.requestType == SessionRequestType.NEW_SESSION: # Ignores any session id header in create session request session_id = SessionRequestType.NEW_SESSION @@ -205,9 +189,7 @@ def _process_session_request(self, session_request, session_id, raw_request): raw_request=raw_request, intercept_func=intercept_func ) - def _process_request( - self, request_data, raw_request, session_manager: Optional[SessionManager] - ): + def _process_request(self, request_data, raw_request): session_request = _parse_session_request(request_data) session_id = get_session_id_from_request(raw_request) if not session_request: From 55292371fe119905e8ab852d1dcf29d98eed8da9 Mon Sep 17 00:00:00 2001 From: Zuyi Zhao Date: Tue, 16 Dec 2025 01:41:26 +0000 Subject: [PATCH 25/25] update register_create_session_handler to no longer take path for request sagemaker session id since it'll always be NEW_SESSION value, update docs and tests --- .../sagemaker/__init__.py | 33 +++-------------- .../sagemaker/sessions/CUSTOM_HANDLERS.md | 33 +++++++++++------ ...est_custom_session_handlers_integration.py | 35 ++++--------------- 3 files changed, 34 insertions(+), 67 deletions(-) diff --git a/python/model_hosting_container_standards/sagemaker/__init__.py b/python/model_hosting_container_standards/sagemaker/__init__.py index 4d56c41..4db35b1 100644 --- a/python/model_hosting_container_standards/sagemaker/__init__.py +++ b/python/model_hosting_container_standards/sagemaker/__init__.py @@ -155,8 +155,7 @@ def stateful_session_manager(engine_request_session_id_path: Optional[str] = Non def register_create_session_handler( engine_response_session_id_path: str, - engine_request_session_id_path: Optional[str] = None, - additional_request_shape: Optional[Dict[str, str]] = None, + request_shape: Optional[Dict[str, str]] = None, content_path: str = "`successfully created session.`", ): """Register a handler for session creation with custom request/response transformations. @@ -175,39 +174,17 @@ def register_create_session_handler( The extracted session ID is placed in the SageMaker response body for the client. - engine_request_session_id_path: Optional target path in the engine request body - where the session ID will be injected. The session - ID is extracted from the SageMaker session header - and placed at this path in the request sent to the - engine. - - Examples: "session_id", "metadata.session_id" - - If None, the session ID is not injected into the - engine request body. This is useful when the engine - manages session IDs internally and doesn't need them - in the request. - - Limitation: Currently only supports injection into - the request body, not headers. - - additional_request_shape: Optional dict of additional JMESPath transformations - to apply to the request. Keys are target paths in the - request body, values are source expressions. Defaults to None. + request_shape: Optional dict of JMESPath transformations + to apply to the request. Keys are target paths in the + request body, values are source expressions. Defaults to None. content_path: JMESPath expression for the success message in the response. Defaults to a literal success message. Returns: A decorator that can be applied to engine-specific session creation handlers. - - Note: - If engine_request_session_id_path appears in additional_request_shape, it will be - overwritten to ensure the session ID is properly injected. """ - request_shape = build_session_request_shape( - engine_request_session_id_path, additional_request_shape - ) + request_shape = build_session_request_shape(None, request_shape) return register_engine_session_handler( "create_session", diff --git a/python/model_hosting_container_standards/sagemaker/sessions/CUSTOM_HANDLERS.md b/python/model_hosting_container_standards/sagemaker/sessions/CUSTOM_HANDLERS.md index 082622c..2641f66 100644 --- a/python/model_hosting_container_standards/sagemaker/sessions/CUSTOM_HANDLERS.md +++ b/python/model_hosting_container_standards/sagemaker/sessions/CUSTOM_HANDLERS.md @@ -57,8 +57,7 @@ class CreateSessionResponse(BaseModel): # Register custom create session handler @register_create_session_handler( engine_response_session_id_path="body.session_id", # Extract session ID from response - engine_request_session_id_path="session_id", # Where to inject session ID in engine request - additional_request_shape={ + request_shape={ "capacity": "`1024`" # Additional fields to include }, content_path="body.message" # Extract content for logging @@ -72,8 +71,7 @@ async def create_session(obj: CreateSessionRequest, request: Request): # Alternative: If your engine manages session IDs internally @register_create_session_handler( engine_response_session_id_path="body.session_id", # Extract session ID from response - # No engine_request_session_id_path - engine generates its own session ID - additional_request_shape={ + request_shape={ "capacity": "`1024`" } ) @@ -111,15 +109,14 @@ bootstrap(app) ```python @register_create_session_handler( engine_response_session_id_path: str, # Required: Where to extract session ID from engine response - engine_request_session_id_path: str = None, # Optional: Where to inject session ID in engine request - additional_request_shape: dict = None, # Optional: Additional JMESPath mappings + request_shape: dict = None, # Optional: Additional JMESPath mappings content_path: str = None # Optional: JMESPath to extract content for logging ) ``` - **`engine_response_session_id_path`**: JMESPath expression to extract the session ID from your engine's response. Must include prefix (`body.` or `headers.`). This is **required** because the framework needs to return the session ID to the client. - **`engine_request_session_id_path`**: Optional target path in the engine request body where the session ID will be injected. The session ID is extracted from the SageMaker session header and placed at this path. Example: `"session_id"` or `"metadata.session_id"`. If None, the session ID is not injected (useful when the engine manages sessions internally) -- **`additional_request_shape`**: Optional dict mapping target keys to source JMESPath expressions for additional fields to include in the engine request. +- **`request_shape`**: Optional dict mapping target keys to source JMESPath expressions for additional fields to include in the engine request. - **`content_path`**: Optional JMESPath expression to extract a message for logging. Defaults to a generic success message. ### `@register_close_session_handler` @@ -156,18 +153,33 @@ The parameters use JMESPath expressions to transform data: ### Request Transformation +#### For Create Session Handlers + +The `request_shape` parameter maps target keys to source expressions: + +```python +# register_create_session_handler only +request_shape={ + "capacity": "`1024`", # Literal value +} +``` + +#### For Close Session Handlers + The `engine_request_session_id_path` specifies where to inject the session ID (always relative to request body): ```python +# register_close_session_handler only engine_request_session_id_path="session_id" # Inject at root level engine_request_session_id_path="metadata.session_id" # Inject in nested path ``` -The `additional_request_shape` maps target keys to source expressions: +The `additional_request_shape` parameter maps target keys to source expressions: ```python +# register_close_session_handler only additional_request_shape={ - "capacity": "`1024`", # Literal value + "timeout": "`30`", # Literal value } ``` @@ -267,9 +279,8 @@ class CreateSessionRequest(BaseModel): session_id: Optional[str] = None @register_create_session_handler( - engine_request_session_id_path="session_id", engine_response_session_id_path="body.session_id", - additional_request_shape={ + request_shape={ "capacity": "`1024`" }, content_path="body.message" diff --git a/python/tests/integration/test_custom_session_handlers_integration.py b/python/tests/integration/test_custom_session_handlers_integration.py index a0c0914..92d4bd7 100644 --- a/python/tests/integration/test_custom_session_handlers_integration.py +++ b/python/tests/integration/test_custom_session_handlers_integration.py @@ -141,9 +141,8 @@ def custom_close_session(self, obj: CloseSessionRequest, request: Request): def setup_common_handlers(self): @sagemaker_standards.register_create_session_handler( - engine_request_session_id_path="session_id", engine_response_session_id_path="body", - additional_request_shape={ + request_shape={ "capacity_of_str_len": "`1024`", }, content_path="`successfully created session.`", @@ -327,7 +326,7 @@ def setup_method(self): def custom_create_session(self, obj: CreateSessionRequest, request: Request): self.handler_calls["create"] += 1 - if not obj.session_id: + if not getattr(obj, "session_id", None): obj.session_id = str(uuid.uuid4()) if obj.session_id in self.sessions: return Response(status_code=400) @@ -345,9 +344,8 @@ def custom_close_session(self, obj: CloseSessionRequest, request: Request): def setup_common_handlers(self): @sagemaker_standards.register_create_session_handler( - engine_request_session_id_path="session_id", engine_response_session_id_path="body.session_id", # Nested - additional_request_shape={ + request_shape={ "capacity_of_str_len": "`1024`", }, content_path="`successfully created session.`", @@ -372,21 +370,6 @@ def setup_invocation_handler(self): async def invocations(request: Request): return await self.custom_invocations(request) - def test_create_existing_session_error_handling(self): - """Test that attempting to create a session with existing ID fails. - - This validates that the custom handler properly rejects attempts to create - a session with a duplicate ID. This prevents session ID collisions and ensures - session uniqueness. - """ - # Create initial session - session_id = self.create_session() - - # Try to create another session with the same ID by passing it in the header - # Custom handler checks if session_id already exists and returns 400 if it does - header_response = self.create_session_with_id(session_id) - assert header_response.status_code == 400 - def test_end_to_end_simple(self): """Test complete session lifecycle: create -> use -> close. @@ -494,9 +477,8 @@ def setup_common_handlers(self): response_path = "body.session_id" if self.response_format == "dict" else "body" @sagemaker_standards.register_create_session_handler( - engine_request_session_id_path="session_id", engine_response_session_id_path=response_path, - additional_request_shape={"capacity_of_str_len": "`1024`"}, + request_shape={"capacity_of_str_len": "`1024`"}, content_path="`successfully created session.`", ) @self.app.api_route("/open_session", methods=["GET", "POST"]) @@ -565,9 +547,8 @@ def custom_close_session(self, obj: CloseSessionRequest, request: Request): def setup_common_handlers(self): @sagemaker_standards.register_create_session_handler( - engine_request_session_id_path="session_id", engine_response_session_id_path="body.session_id", - additional_request_shape={"capacity_of_str_len": "`1024`"}, + request_shape={"capacity_of_str_len": "`1024`"}, content_path="`successfully created session.`", ) @self.app.api_route("/open_session", methods=["GET", "POST"]) @@ -666,9 +647,8 @@ def custom_close_session(self, obj: CloseSessionRequest, request: Request): def setup_common_handlers(self): @sagemaker_standards.register_create_session_handler( - engine_request_session_id_path="session_id", engine_response_session_id_path="body.session_id", - additional_request_shape={"capacity_of_str_len": "`1024`"}, + request_shape={"capacity_of_str_len": "`1024`"}, content_path="`successfully created session.`", ) @self.app.api_route("/open_session", methods=["GET", "POST"]) @@ -781,9 +761,8 @@ def custom_close_session(self, obj: CloseSessionRequest, request: Request): def setup_common_handlers(self): @sagemaker_standards.register_create_session_handler( - engine_request_session_id_path="session_id", engine_response_session_id_path="body.session_id", - additional_request_shape={"capacity_of_str_len": "`1024`"}, + request_shape={"capacity_of_str_len": "`1024`"}, content_path="`successfully created session.`", ) @self.app.api_route("/open_session", methods=["GET", "POST"])