Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class BuildEndpointRequest(BaseModel):
high_priority: Optional[bool] = None
default_callback_url: Optional[str] = None
default_callback_auth: Optional[CallbackAuth] = None
queue_message_timeout_duration: Optional[int] = None


class BuildEndpointStatus(str, Enum):
Expand Down
2 changes: 2 additions & 0 deletions model-engine/model_engine_server/common/dtos/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class CreateLLMModelEndpointV1Request(BaseModel):
default_callback_url: Optional[HttpUrlStr] = None
default_callback_auth: Optional[CallbackAuth] = None
public_inference: Optional[bool] = True # LLM endpoints are public by default.
queue_message_timeout_duration: Optional[int] = Field(default=None, ge=1)


class CreateLLMModelEndpointV1Response(BaseModel):
Expand Down Expand Up @@ -137,6 +138,7 @@ class UpdateLLMModelEndpointV1Request(BaseModel):
default_callback_url: Optional[HttpUrlStr] = None
default_callback_auth: Optional[CallbackAuth] = None
public_inference: Optional[bool] = None
queue_message_timeout_duration: Optional[int] = Field(default=None, ge=1)


class UpdateLLMModelEndpointV1Response(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class CreateModelEndpointV1Request(BaseModel):
default_callback_url: Optional[HttpUrlStr] = None
default_callback_auth: Optional[CallbackAuth] = None
public_inference: Optional[bool] = Field(default=False)
queue_message_timeout_duration: Optional[int] = Field(default=None, ge=1)


class CreateModelEndpointV1Response(BaseModel):
Expand All @@ -95,6 +96,7 @@ class UpdateModelEndpointV1Request(BaseModel):
default_callback_url: Optional[HttpUrlStr] = None
default_callback_auth: Optional[CallbackAuth] = None
public_inference: Optional[bool] = None
queue_message_timeout_duration: Optional[int] = Field(default=None, ge=1)


class UpdateModelEndpointV1Response(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ async def create_model_endpoint(
default_callback_url: Optional[str],
default_callback_auth: Optional[CallbackAuth],
public_inference: Optional[bool] = False,
queue_message_timeout_duration: Optional[int] = None,
) -> ModelEndpointRecord:
"""
Creates a model endpoint.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ async def execute(
default_callback_url=request.default_callback_url,
default_callback_auth=request.default_callback_auth,
public_inference=request.public_inference,
queue_message_timeout_duration=request.queue_message_timeout_duration,
)
_handle_post_inference_hooks(
created_by=user.user_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def create_model_endpoint_infra(
billing_tags: Optional[Dict[str, Any]] = None,
default_callback_url: Optional[str],
default_callback_auth: Optional[CallbackAuth],
queue_message_timeout_duration: Optional[int] = None,
) -> str:
deployment_name = generate_deployment_name(
model_endpoint_record.created_by, model_endpoint_record.name
Expand All @@ -99,6 +100,7 @@ def create_model_endpoint_infra(
billing_tags=billing_tags,
default_callback_url=default_callback_url,
default_callback_auth=default_callback_auth,
queue_message_timeout_duration=queue_message_timeout_duration,
)
response = self.task_queue_gateway.send_task(
task_name=BUILD_TASK_NAME,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from typing import Any, Dict
from datetime import timedelta
from typing import Any, Dict, Optional

from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
from azure.identity import DefaultAzureCredential
Expand Down Expand Up @@ -32,13 +33,36 @@ async def create_queue_if_not_exists(
endpoint_name: str,
endpoint_created_by: str,
endpoint_labels: Dict[str, Any],
queue_message_timeout_duration: Optional[int] = None,
) -> QueueInfo:
queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id)
timeout_duration = queue_message_timeout_duration or 60 # Default to 60 seconds

# Validation: Azure Service Bus lock duration must be <= 5 minutes (300s)
if timeout_duration > 300:
raise ValueError(f"queue_message_timeout_duration ({timeout_duration}s) exceeds Azure Service Bus maximum of 300 seconds")

with _get_servicebus_administration_client() as client:
try:
# First, try to create the queue with default properties
client.create_queue(queue_name=queue_name)

# Then update the queue properties to set custom lock duration
queue_properties = client.get_queue(queue_name)
queue_properties.lock_duration = timedelta(seconds=timeout_duration)
client.update_queue(queue_properties)

except ResourceExistsError:
pass
# Queue already exists, update its properties if needed
try:
queue_properties = client.get_queue(queue_name)
# Only update if the lock duration is different
if queue_properties.lock_duration != timedelta(seconds=timeout_duration):
queue_properties.lock_duration = timedelta(seconds=timeout_duration)
client.update_queue(queue_properties)
except Exception as e:
# If we can't update properties, log but don't fail
logger.warning(f"Could not update queue properties for {queue_name}: {e}")

return QueueInfo(queue_name, None)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Sequence
from typing import Any, Dict, Optional, Sequence

from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import (
QueueEndpointResourceDelegate,
Expand All @@ -15,6 +15,7 @@ async def create_queue_if_not_exists(
endpoint_name: str,
endpoint_created_by: str,
endpoint_labels: Dict[str, Any],
queue_message_timeout_duration: Optional[int] = None,
) -> QueueInfo:
queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id)
queue_url = f"http://foobar.com/{queue_name}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,15 @@ async def create_queue(
self,
endpoint_record: ModelEndpointRecord,
labels: Dict[str, str],
queue_message_timeout_duration: Optional[int] = None,
) -> QueueInfo:
"""Creates a new queue, returning its unique name and queue URL."""
queue_name, queue_url = await self.queue_delegate.create_queue_if_not_exists(
endpoint_id=endpoint_record.id,
endpoint_name=endpoint_record.name,
endpoint_created_by=endpoint_record.created_by,
endpoint_labels=labels,
queue_message_timeout_duration=queue_message_timeout_duration,
)
return QueueInfo(queue_name, queue_url)

Expand All @@ -56,7 +58,11 @@ async def create_or_update_resources(
request.build_endpoint_request.model_endpoint_record.endpoint_type
== ModelEndpointType.ASYNC
):
q = await self.create_queue(endpoint_record, request.build_endpoint_request.labels)
q = await self.create_queue(
endpoint_record,
request.build_endpoint_request.labels,
request.build_endpoint_request.queue_message_timeout_duration
)
queue_name: Optional[str] = q.queue_name
queue_url: Optional[str] = q.queue_url
destination: str = q.queue_name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,15 @@ async def create_queue_if_not_exists(
endpoint_name: str,
endpoint_created_by: str,
endpoint_labels: Dict[str, Any],
queue_message_timeout_duration: Optional[int] = None,
) -> QueueInfo:
"""
Creates a queue associated with the given endpoint_id. Other fields are set as tags on the queue.

Args:
queue_message_timeout_duration: Optional timeout duration in seconds for queue messages.
For SQS, this sets the VisibilityTimeout.
For Azure Service Bus, this sets the lock_duration (max 300 seconds).
"""

@abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,11 @@ async def create_queue_if_not_exists(
endpoint_name: str,
endpoint_created_by: str,
endpoint_labels: Dict[str, Any],
queue_message_timeout_duration: Optional[int] = None,
) -> QueueInfo:
# Use provided timeout or default to 43200 (12 hours, max SQS visibility)
timeout_duration = queue_message_timeout_duration or 43200

async with _create_async_sqs_client(sqs_profile=self.sqs_profile) as sqs_client:
queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id)

Expand All @@ -73,9 +77,7 @@ async def create_queue_if_not_exists(
create_response = await sqs_client.create_queue(
QueueName=queue_name,
Attributes=dict(
VisibilityTimeout="43200",
# To match current hardcoded Celery timeout of 24hr
# However, the max SQS visibility is 12hrs.
VisibilityTimeout=str(timeout_duration),
Policy=_get_queue_policy(queue_name=queue_name),
),
tags=_get_queue_tags(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ async def create_model_endpoint(
default_callback_url: Optional[str] = None,
default_callback_auth: Optional[CallbackAuth],
public_inference: Optional[bool] = False,
queue_message_timeout_duration: Optional[int] = None,
) -> ModelEndpointRecord:
existing_endpoints = (
await self.model_endpoint_record_repository.list_model_endpoint_records(
Expand Down Expand Up @@ -203,6 +204,7 @@ async def create_model_endpoint(
high_priority=high_priority,
default_callback_url=default_callback_url,
default_callback_auth=default_callback_auth,
queue_message_timeout_duration=queue_message_timeout_duration,
)
await self.model_endpoint_record_repository.update_model_endpoint_record(
model_endpoint_id=model_endpoint_record.id,
Expand Down