diff --git a/model-engine/model_engine_server/common/dtos/endpoint_builder.py b/model-engine/model_engine_server/common/dtos/endpoint_builder.py index 64ea43d0d..43f92d661 100644 --- a/model-engine/model_engine_server/common/dtos/endpoint_builder.py +++ b/model-engine/model_engine_server/common/dtos/endpoint_builder.py @@ -33,6 +33,7 @@ class BuildEndpointRequest(BaseModel): high_priority: Optional[bool] = None default_callback_url: Optional[str] = None default_callback_auth: Optional[CallbackAuth] = None + queue_message_timeout_duration: Optional[int] = None class BuildEndpointStatus(str, Enum): diff --git a/model-engine/model_engine_server/common/dtos/llms.py b/model-engine/model_engine_server/common/dtos/llms.py index 232b8ba6c..5991e79d7 100644 --- a/model-engine/model_engine_server/common/dtos/llms.py +++ b/model-engine/model_engine_server/common/dtos/llms.py @@ -70,6 +70,7 @@ class CreateLLMModelEndpointV1Request(BaseModel): default_callback_url: Optional[HttpUrlStr] = None default_callback_auth: Optional[CallbackAuth] = None public_inference: Optional[bool] = True # LLM endpoints are public by default. + queue_message_timeout_duration: Optional[int] = Field(default=None, ge=1) class CreateLLMModelEndpointV1Response(BaseModel): @@ -137,6 +138,7 @@ class UpdateLLMModelEndpointV1Request(BaseModel): default_callback_url: Optional[HttpUrlStr] = None default_callback_auth: Optional[CallbackAuth] = None public_inference: Optional[bool] = None + queue_message_timeout_duration: Optional[int] = Field(default=None, ge=1) class UpdateLLMModelEndpointV1Response(BaseModel): diff --git a/model-engine/model_engine_server/common/dtos/model_endpoints.py b/model-engine/model_engine_server/common/dtos/model_endpoints.py index e86208904..0f5908a85 100644 --- a/model-engine/model_engine_server/common/dtos/model_endpoints.py +++ b/model-engine/model_engine_server/common/dtos/model_endpoints.py @@ -69,6 +69,7 @@ class CreateModelEndpointV1Request(BaseModel): default_callback_url: Optional[HttpUrlStr] = None default_callback_auth: Optional[CallbackAuth] = None public_inference: Optional[bool] = Field(default=False) + queue_message_timeout_duration: Optional[int] = Field(default=None, ge=1) class CreateModelEndpointV1Response(BaseModel): @@ -95,6 +96,7 @@ class UpdateModelEndpointV1Request(BaseModel): default_callback_url: Optional[HttpUrlStr] = None default_callback_auth: Optional[CallbackAuth] = None public_inference: Optional[bool] = None + queue_message_timeout_duration: Optional[int] = Field(default=None, ge=1) class UpdateModelEndpointV1Response(BaseModel): diff --git a/model-engine/model_engine_server/domain/services/model_endpoint_service.py b/model-engine/model_engine_server/domain/services/model_endpoint_service.py index 4c3471b4f..bc3737004 100644 --- a/model-engine/model_engine_server/domain/services/model_endpoint_service.py +++ b/model-engine/model_engine_server/domain/services/model_endpoint_service.py @@ -90,6 +90,7 @@ async def create_model_endpoint( default_callback_url: Optional[str], default_callback_auth: Optional[CallbackAuth], public_inference: Optional[bool] = False, + queue_message_timeout_duration: Optional[int] = None, ) -> ModelEndpointRecord: """ Creates a model endpoint. diff --git a/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py index 9d3553076..91349514b 100644 --- a/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py @@ -313,6 +313,7 @@ async def execute( default_callback_url=request.default_callback_url, default_callback_auth=request.default_callback_auth, public_inference=request.public_inference, + queue_message_timeout_duration=request.queue_message_timeout_duration, ) _handle_post_inference_hooks( created_by=user.user_id, diff --git a/model-engine/model_engine_server/infra/gateways/live_model_endpoint_infra_gateway.py b/model-engine/model_engine_server/infra/gateways/live_model_endpoint_infra_gateway.py index 4b73a3860..71bfd3bd7 100644 --- a/model-engine/model_engine_server/infra/gateways/live_model_endpoint_infra_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_model_endpoint_infra_gateway.py @@ -73,6 +73,7 @@ def create_model_endpoint_infra( billing_tags: Optional[Dict[str, Any]] = None, default_callback_url: Optional[str], default_callback_auth: Optional[CallbackAuth], + queue_message_timeout_duration: Optional[int] = None, ) -> str: deployment_name = generate_deployment_name( model_endpoint_record.created_by, model_endpoint_record.name @@ -99,6 +100,7 @@ def create_model_endpoint_infra( billing_tags=billing_tags, default_callback_url=default_callback_url, default_callback_auth=default_callback_auth, + queue_message_timeout_duration=queue_message_timeout_duration, ) response = self.task_queue_gateway.send_task( task_name=BUILD_TASK_NAME, diff --git a/model-engine/model_engine_server/infra/gateways/resources/asb_queue_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/asb_queue_endpoint_resource_delegate.py index 3799ed654..8a224c5fa 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/asb_queue_endpoint_resource_delegate.py +++ b/model-engine/model_engine_server/infra/gateways/resources/asb_queue_endpoint_resource_delegate.py @@ -1,5 +1,6 @@ import os -from typing import Any, Dict +from datetime import timedelta +from typing import Any, Dict, Optional from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError from azure.identity import DefaultAzureCredential @@ -32,13 +33,36 @@ async def create_queue_if_not_exists( endpoint_name: str, endpoint_created_by: str, endpoint_labels: Dict[str, Any], + queue_message_timeout_duration: Optional[int] = None, ) -> QueueInfo: queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id) + timeout_duration = queue_message_timeout_duration or 60 # Default to 60 seconds + + # Validation: Azure Service Bus lock duration must be <= 5 minutes (300s) + if timeout_duration > 300: + raise ValueError(f"queue_message_timeout_duration ({timeout_duration}s) exceeds Azure Service Bus maximum of 300 seconds") + with _get_servicebus_administration_client() as client: try: + # First, try to create the queue with default properties client.create_queue(queue_name=queue_name) + + # Then update the queue properties to set custom lock duration + queue_properties = client.get_queue(queue_name) + queue_properties.lock_duration = timedelta(seconds=timeout_duration) + client.update_queue(queue_properties) + except ResourceExistsError: - pass + # Queue already exists, update its properties if needed + try: + queue_properties = client.get_queue(queue_name) + # Only update if the lock duration is different + if queue_properties.lock_duration != timedelta(seconds=timeout_duration): + queue_properties.lock_duration = timedelta(seconds=timeout_duration) + client.update_queue(queue_properties) + except Exception as e: + # If we can't update properties, log but don't fail + logger.warning(f"Could not update queue properties for {queue_name}: {e}") return QueueInfo(queue_name, None) diff --git a/model-engine/model_engine_server/infra/gateways/resources/fake_queue_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/fake_queue_endpoint_resource_delegate.py index 9ded2d6e5..b43e5c4cc 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/fake_queue_endpoint_resource_delegate.py +++ b/model-engine/model_engine_server/infra/gateways/resources/fake_queue_endpoint_resource_delegate.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Sequence +from typing import Any, Dict, Optional, Sequence from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import ( QueueEndpointResourceDelegate, @@ -15,6 +15,7 @@ async def create_queue_if_not_exists( endpoint_name: str, endpoint_created_by: str, endpoint_labels: Dict[str, Any], + queue_message_timeout_duration: Optional[int] = None, ) -> QueueInfo: queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id) queue_url = f"http://foobar.com/{queue_name}" diff --git a/model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py b/model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py index fb637c10f..61089a60b 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py @@ -38,6 +38,7 @@ async def create_queue( self, endpoint_record: ModelEndpointRecord, labels: Dict[str, str], + queue_message_timeout_duration: Optional[int] = None, ) -> QueueInfo: """Creates a new queue, returning its unique name and queue URL.""" queue_name, queue_url = await self.queue_delegate.create_queue_if_not_exists( @@ -45,6 +46,7 @@ async def create_queue( endpoint_name=endpoint_record.name, endpoint_created_by=endpoint_record.created_by, endpoint_labels=labels, + queue_message_timeout_duration=queue_message_timeout_duration, ) return QueueInfo(queue_name, queue_url) @@ -56,7 +58,11 @@ async def create_or_update_resources( request.build_endpoint_request.model_endpoint_record.endpoint_type == ModelEndpointType.ASYNC ): - q = await self.create_queue(endpoint_record, request.build_endpoint_request.labels) + q = await self.create_queue( + endpoint_record, + request.build_endpoint_request.labels, + request.build_endpoint_request.queue_message_timeout_duration + ) queue_name: Optional[str] = q.queue_name queue_url: Optional[str] = q.queue_url destination: str = q.queue_name diff --git a/model-engine/model_engine_server/infra/gateways/resources/queue_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/queue_endpoint_resource_delegate.py index 76c77e64b..4998b2959 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/queue_endpoint_resource_delegate.py +++ b/model-engine/model_engine_server/infra/gateways/resources/queue_endpoint_resource_delegate.py @@ -24,9 +24,15 @@ async def create_queue_if_not_exists( endpoint_name: str, endpoint_created_by: str, endpoint_labels: Dict[str, Any], + queue_message_timeout_duration: Optional[int] = None, ) -> QueueInfo: """ Creates a queue associated with the given endpoint_id. Other fields are set as tags on the queue. + + Args: + queue_message_timeout_duration: Optional timeout duration in seconds for queue messages. + For SQS, this sets the VisibilityTimeout. + For Azure Service Bus, this sets the lock_duration (max 300 seconds). """ @abstractmethod diff --git a/model-engine/model_engine_server/infra/gateways/resources/sqs_queue_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/sqs_queue_endpoint_resource_delegate.py index 748c3f699..8953d4ae4 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/sqs_queue_endpoint_resource_delegate.py +++ b/model-engine/model_engine_server/infra/gateways/resources/sqs_queue_endpoint_resource_delegate.py @@ -55,7 +55,11 @@ async def create_queue_if_not_exists( endpoint_name: str, endpoint_created_by: str, endpoint_labels: Dict[str, Any], + queue_message_timeout_duration: Optional[int] = None, ) -> QueueInfo: + # Use provided timeout or default to 43200 (12 hours, max SQS visibility) + timeout_duration = queue_message_timeout_duration or 43200 + async with _create_async_sqs_client(sqs_profile=self.sqs_profile) as sqs_client: queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id) @@ -73,9 +77,7 @@ async def create_queue_if_not_exists( create_response = await sqs_client.create_queue( QueueName=queue_name, Attributes=dict( - VisibilityTimeout="43200", - # To match current hardcoded Celery timeout of 24hr - # However, the max SQS visibility is 12hrs. + VisibilityTimeout=str(timeout_duration), Policy=_get_queue_policy(queue_name=queue_name), ), tags=_get_queue_tags( diff --git a/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py b/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py index 475fbca86..8fcc83c68 100644 --- a/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py +++ b/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py @@ -160,6 +160,7 @@ async def create_model_endpoint( default_callback_url: Optional[str] = None, default_callback_auth: Optional[CallbackAuth], public_inference: Optional[bool] = False, + queue_message_timeout_duration: Optional[int] = None, ) -> ModelEndpointRecord: existing_endpoints = ( await self.model_endpoint_record_repository.list_model_endpoint_records( @@ -203,6 +204,7 @@ async def create_model_endpoint( high_priority=high_priority, default_callback_url=default_callback_url, default_callback_auth=default_callback_auth, + queue_message_timeout_duration=queue_message_timeout_duration, ) await self.model_endpoint_record_repository.update_model_endpoint_record( model_endpoint_id=model_endpoint_record.id,