Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 20 additions & 20 deletions python/ray/serve/_private/autoscaling_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def __init__(self, deployment_id: DeploymentID):
# content of the dictionary is determined by the user defined policy
self._policy_state: Optional[Dict[str, Any]] = None
self._running_replicas: List[ReplicaID] = []
self._cached_running_replica_strs: Set[str] = set()
self._target_capacity: Optional[float] = None
self._target_capacity_direction: Optional[TargetCapacityDirection] = None
self._cached_deployment_snapshot: Optional[DeploymentSnapshot] = None
Expand Down Expand Up @@ -190,6 +191,9 @@ def get_num_replicas_upper_bound(self) -> int:
def update_running_replica_ids(self, running_replicas: List[ReplicaID]):
    """Refresh the cached running replicas for this deployment.

    Stores the raw ``ReplicaID`` list and precomputes the set of their
    full-id string forms so later metric lookups avoid repeated
    ``to_full_id_str()`` conversions.
    """
    self._running_replicas = running_replicas
    self._cached_running_replica_strs = {
        replica.to_full_id_str() for replica in running_replicas
    }

def record_scale_up(self):
"""Record a scale up event by updating the timestamp."""
Expand Down Expand Up @@ -454,15 +458,11 @@ def _collect_handle_running_requests(self) -> List[TimeSeries]:
timeseries_list = []

for handle_metric in self._handle_requests.values():
for replica_id in self._running_replicas:
if (
RUNNING_REQUESTS_KEY not in handle_metric.metrics
or replica_id not in handle_metric.metrics[RUNNING_REQUESTS_KEY]
):
running_reqs = handle_metric.metrics.get(RUNNING_REQUESTS_KEY, {})
for replica_str in self._cached_running_replica_strs:
if replica_str not in running_reqs:
continue
timeseries_list.append(
handle_metric.metrics[RUNNING_REQUESTS_KEY][replica_id]
)
timeseries_list.append(running_reqs[replica_str])

return timeseries_list

Expand Down Expand Up @@ -664,11 +664,12 @@ def _calculate_total_requests_simple_mode(self) -> float:
"""
total_requests = 0

for id in self._running_replicas:
if id in self._replica_metrics:
total_requests += self._replica_metrics[id].aggregated_metrics.get(
RUNNING_REQUESTS_KEY, 0
)
# Iterate over _replica_metrics but only count running replicas. Stale metrics from
# stopped replicas can remain until on_replica_stopped runs; filtering avoids inflation.
for report in self._replica_metrics.values():
# TODO(abrar): Store replica_id as string in report to avoid this conversion.
if report.replica_id.to_full_id_str() in self._cached_running_replica_strs:
total_requests += report.aggregated_metrics.get(RUNNING_REQUESTS_KEY, 0)

metrics_collected_on_replicas = total_requests > 0

Expand All @@ -677,13 +678,12 @@ def _calculate_total_requests_simple_mode(self) -> float:
total_requests += handle_metric.aggregated_queued_requests
# Add running requests from handles if not collected on replicas
if not metrics_collected_on_replicas:
for replica_id in self._running_replicas:
if replica_id in handle_metric.aggregated_metrics.get(
RUNNING_REQUESTS_KEY, {}
):
total_requests += handle_metric.aggregated_metrics.get(
RUNNING_REQUESTS_KEY
).get(replica_id)
running_reqs = handle_metric.aggregated_metrics.get(
RUNNING_REQUESTS_KEY, {}
)
for replica_str, count in running_reqs.items():
if replica_str in self._cached_running_replica_strs:
total_requests += count
return total_requests

def _should_aggregate_metrics_at_controller(self) -> bool:
Expand Down
15 changes: 10 additions & 5 deletions python/ray/serve/_private/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1022,10 +1022,11 @@ class HandleMetricReport:
handle over the past look_back_period_s seconds. This is a list because
we take multiple measurements over time.
aggregated_metrics: A map of metric name to the aggregated value over the past
look_back_period_s seconds at the handle for each replica.
look_back_period_s seconds at the handle for each replica. Replica keys
use ReplicaID.to_full_id_str() for efficient controller-side lookups.
metrics: A map of metric name to the list of values running at that handle for each replica
over the past look_back_period_s seconds. This is a list because
we take multiple measurements over time.
over the past look_back_period_s seconds. Replica keys use to_full_id_str().
This is a list because we take multiple measurements over time.
timestamp: The time at which this report was created.
"""

Expand All @@ -1035,8 +1036,12 @@ class HandleMetricReport:
handle_source: DeploymentHandleSource
aggregated_queued_requests: float
queued_requests: TimeSeries
aggregated_metrics: Dict[str, Dict[ReplicaID, float]]
metrics: Dict[str, Dict[ReplicaID, TimeSeries]]
aggregated_metrics: Dict[
str, Dict[str, float]
] # replica key = ReplicaID.to_full_id_str()
metrics: Dict[
str, Dict[str, TimeSeries]
] # replica key = ReplicaID.to_full_id_str()
timestamp: float

@property
Expand Down
17 changes: 4 additions & 13 deletions python/ray/serve/_private/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,21 +366,17 @@
"RAY_SERVE_RECORD_AUTOSCALING_STATS_TIMEOUT_S", 10.0
)

# Factor of look_back_period_s used to derive the autoscaling metric record
# interval: record_interval_s = look_back_period_s * factor. Shared by both
# the router (handle) and the replica metric recording loops.
# NOTE(review): with the default 30s look-back this yields a 6s record
# interval, versus the previous fixed 0.5s default — confirm this coarser
# sampling is acceptable for small deployments.
RAY_SERVE_AUTOSCALING_METRIC_RECORD_INTERVAL_FACTOR = get_env_float(
    "RAY_SERVE_AUTOSCALING_METRIC_RECORD_INTERVAL_FACTOR", 0.2
)

# Replica autoscaling metrics push interval.
RAY_SERVE_REPLICA_AUTOSCALING_METRIC_PUSH_INTERVAL_S = get_env_float(
"RAY_SERVE_REPLICA_AUTOSCALING_METRIC_PUSH_INTERVAL_S", 10.0
)

# How often autoscaling metrics are recorded on Serve handles.
RAY_SERVE_HANDLE_AUTOSCALING_METRIC_RECORD_INTERVAL_S = get_env_float(
"RAY_SERVE_HANDLE_AUTOSCALING_METRIC_RECORD_INTERVAL_S", 0.5
)

# Handle autoscaling metrics push interval. (This interval will affect the cold start time period)
RAY_SERVE_HANDLE_AUTOSCALING_METRIC_PUSH_INTERVAL_S = get_env_float(
"RAY_SERVE_HANDLE_AUTOSCALING_METRIC_PUSH_INTERVAL_S",
Expand Down Expand Up @@ -780,11 +776,6 @@
if RAY_SERVE_ENABLE_HA_PROXY:
RAY_SERVE_ENABLE_DIRECT_INGRESS = True

# The maximum allowed RPC latency in milliseconds.
# This is used to detect and warn about long RPC latencies
# between the controller and the replicas.
RAY_SERVE_RPC_LATENCY_WARNING_THRESHOLD_MS = 2000

# Feature flag to aggregate metrics at the controller instead of the replicas or handles.
RAY_SERVE_AGGREGATE_METRICS_AT_CONTROLLER = get_env_bool(
"RAY_SERVE_AGGREGATE_METRICS_AT_CONTROLLER", "0"
Expand Down
33 changes: 7 additions & 26 deletions python/ray/serve/_private/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
RAY_SERVE_CONTROLLER_CALLBACK_IMPORT_PATH,
RAY_SERVE_ENABLE_DIRECT_INGRESS,
RAY_SERVE_ENABLE_HA_PROXY,
RAY_SERVE_RPC_LATENCY_WARNING_THRESHOLD_MS,
RECOVERING_LONG_POLL_BROADCAST_TIMEOUT_S,
SERVE_CONTROLLER_NAME,
SERVE_DEFAULT_APP_NAME,
Expand Down Expand Up @@ -78,6 +77,7 @@
from ray.serve._private.usage import ServeUsageTag
from ray.serve._private.utils import (
call_function_from_import_path,
decompress_metric_report,
get_all_live_placement_group_names,
get_head_node_id,
is_grpc_enabled,
Expand Down Expand Up @@ -339,8 +339,10 @@ def get_pid(self) -> int:
return os.getpid()

def record_autoscaling_metrics_from_replica(
self, replica_metric_report: ReplicaMetricReport
self, replica_metric_report: Union[ReplicaMetricReport, bytes]
):
if isinstance(replica_metric_report, bytes):
replica_metric_report = decompress_metric_report(replica_metric_report)
latency = time.time() - replica_metric_report.timestamp
latency_ms = latency * 1000
# Record the metrics delay for observability
Expand All @@ -354,20 +356,15 @@ def record_autoscaling_metrics_from_replica(
)
# Track in health metrics
self._health_metrics_tracker.record_replica_metrics_delay(latency_ms)
if latency_ms > RAY_SERVE_RPC_LATENCY_WARNING_THRESHOLD_MS:
logger.warning(
f"Received autoscaling metrics from replica {replica_metric_report.replica_id} with timestamp {replica_metric_report.timestamp} "
f"which is {latency_ms}ms ago. "
f"This is greater than the warning threshold RPC latency of {RAY_SERVE_RPC_LATENCY_WARNING_THRESHOLD_MS}ms. "
"This may indicate a performance issue with the controller try increasing the RAY_SERVE_RPC_LATENCY_WARNING_THRESHOLD_MS environment variable."
)
self.autoscaling_state_manager.record_request_metrics_for_replica(
replica_metric_report
)

def record_autoscaling_metrics_from_handle(
self, handle_metric_report: HandleMetricReport
self, handle_metric_report: Union[HandleMetricReport, bytes]
):
if isinstance(handle_metric_report, bytes):
handle_metric_report = decompress_metric_report(handle_metric_report)
latency = time.time() - handle_metric_report.timestamp
latency_ms = latency * 1000
# Record the metrics delay for observability
Expand All @@ -381,13 +378,6 @@ def record_autoscaling_metrics_from_handle(
)
# Track in health metrics
self._health_metrics_tracker.record_handle_metrics_delay(latency_ms)
if latency_ms > RAY_SERVE_RPC_LATENCY_WARNING_THRESHOLD_MS:
logger.warning(
f"Received autoscaling metrics from handle {handle_metric_report.handle_id} for deployment {handle_metric_report.deployment_id} with timestamp {handle_metric_report.timestamp} "
f"which is {latency_ms}ms ago. "
f"This is greater than the warning threshold RPC latency of {RAY_SERVE_RPC_LATENCY_WARNING_THRESHOLD_MS}ms. "
"This may indicate a performance issue with the controller try increasing the RAY_SERVE_RPC_LATENCY_WARNING_THRESHOLD_MS environment variable."
)
self.autoscaling_state_manager.record_request_metrics_for_handle(
handle_metric_report
)
Expand All @@ -406,15 +396,6 @@ def record_autoscaling_metrics_from_async_inference_task_queue(
"application": report.deployment_id.app_name,
},
)
if latency_ms > RAY_SERVE_RPC_LATENCY_WARNING_THRESHOLD_MS:
logger.warning(
f"Received async inference task queue metrics for deployment "
f"{report.deployment_id} with timestamp {report.timestamp_s} "
f"which is {latency_ms}ms ago. "
f"This is greater than the warning threshold RPC latency of "
f"{RAY_SERVE_RPC_LATENCY_WARNING_THRESHOLD_MS}ms. "
"This may indicate a performance issue with the controller."
)
self.autoscaling_state_manager.record_async_inference_task_queue_metrics(report)

def _get_total_num_requests_for_deployment_for_testing(
Expand Down
30 changes: 22 additions & 8 deletions python/ray/serve/_private/replica.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,13 @@
GRPC_CONTEXT_ARG_NAME,
HEALTH_CHECK_METHOD,
HEALTHY_MESSAGE,
RAY_SERVE_AUTOSCALING_METRIC_RECORD_INTERVAL_FACTOR,
RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE,
RAY_SERVE_DIRECT_INGRESS_MIN_DRAINING_PERIOD_S,
RAY_SERVE_DIRECT_INGRESS_PORT_RETRY_COUNT,
RAY_SERVE_ENABLE_DIRECT_INGRESS,
RAY_SERVE_METRICS_EXPORT_INTERVAL_MS,
RAY_SERVE_RECORD_AUTOSCALING_STATS_TIMEOUT_S,
RAY_SERVE_REPLICA_AUTOSCALING_METRIC_RECORD_INTERVAL_S,
RAY_SERVE_REPLICA_GRPC_MAX_MESSAGE_LENGTH,
RAY_SERVE_REPLICA_UTILIZATION_NUM_BUCKETS,
RAY_SERVE_REPLICA_UTILIZATION_REPORT_INTERVAL_S,
Expand Down Expand Up @@ -147,6 +147,8 @@
from ray.serve._private.utils import (
Semaphore,
asyncio_grpc_exception_handler,
check_obj_ref_ready_nowait,
compress_metric_report,
generate_request_id,
get_component_file_name, # noqa: F401
is_grpc_enabled,
Expand All @@ -173,6 +175,7 @@
from ray.serve.grpc_util import RayServegRPCContext, gRPCInputStream
from ray.serve.handle import DeploymentHandle
from ray.serve.schema import EncodingType, LoggingConfig, ReplicaRank
from ray.types import ObjectRef
from ray.util import metrics as ray_metrics

logger = logging.getLogger(SERVE_LOGGER_NAME)
Expand Down Expand Up @@ -339,6 +342,10 @@ def __init__(
self._checked_custom_metrics = False
self._record_autoscaling_stats_fn = None

# Tracks in-flight metrics push to controller. Skip if new one is sent.
self._pending_metrics_push_ref: Optional[ObjectRef] = None
self._metrics_push_lock = threading.Lock()

# If the interval is set to 0, eagerly sets all metrics.
self._cached_metrics_enabled = RAY_SERVE_METRICS_EXPORT_INTERVAL_MS != 0
self._cached_metrics_interval_s = RAY_SERVE_METRICS_EXPORT_INTERVAL_MS / 1000
Expand Down Expand Up @@ -638,13 +645,14 @@ def start_metrics_pusher(self):
self._autoscaling_config.metrics_interval_s,
)
# Collect autoscaling metrics locally periodically.
record_interval_s = (
self._autoscaling_config.look_back_period_s
* RAY_SERVE_AUTOSCALING_METRIC_RECORD_INTERVAL_FACTOR
)
self._metrics_pusher.register_or_update_task(
self.RECORD_METRICS_TASK_NAME,
self._add_autoscaling_metrics_point_async,
min(
RAY_SERVE_REPLICA_AUTOSCALING_METRIC_RECORD_INTERVAL_S,
self._autoscaling_config.metrics_interval_s,
),
min(record_interval_s, self._autoscaling_config.metrics_interval_s),
)

def should_collect_ongoing_requests(self) -> bool:
Expand Down Expand Up @@ -905,9 +913,15 @@ def _push_autoscaling_metrics(self) -> Dict[str, Any]:
aggregated_metrics=new_aggregated_metrics,
metrics=new_metrics,
)
self._controller_handle.record_autoscaling_metrics_from_replica.remote(
replica_metric_report
)
with self._metrics_push_lock:
if self._pending_metrics_push_ref is not None:
if not check_obj_ref_ready_nowait(self._pending_metrics_push_ref):
return # Previous push still in flight, skip and try again later
self._pending_metrics_push_ref = (
self._controller_handle.record_autoscaling_metrics_from_replica.remote(
compress_metric_report(replica_metric_report)
)
)

async def _fetch_custom_autoscaling_metrics(
self,
Expand Down
36 changes: 26 additions & 10 deletions python/ray/serve/_private/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@
)
from ray.serve._private.config import DeploymentConfig
from ray.serve._private.constants import (
RAY_SERVE_AUTOSCALING_METRIC_RECORD_INTERVAL_FACTOR,
RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE,
RAY_SERVE_HANDLE_AUTOSCALING_METRIC_RECORD_INTERVAL_S,
RAY_SERVE_METRICS_EXPORT_INTERVAL_MS,
RAY_SERVE_PROXY_PREFER_LOCAL_AZ_ROUTING,
SERVE_LOGGER_NAME,
Expand Down Expand Up @@ -68,11 +68,14 @@
)
from ray.serve._private.usage import ServeUsageTag
from ray.serve._private.utils import (
check_obj_ref_ready_nowait,
compress_metric_report,
generate_request_id,
resolve_deployment_response,
)
from ray.serve.config import AutoscalingConfig
from ray.serve.exceptions import BackPressureError, DeploymentUnavailableError
from ray.types import ObjectRef
from ray.util import metrics

logger = logging.getLogger(SERVE_LOGGER_NAME)
Expand Down Expand Up @@ -154,6 +157,10 @@ def __init__(
# Track whether the metrics manager has been shutdown
self._shutdown: bool = False

# Tracks in-flight metrics push to controller. Skip if new one is sent.
self._pending_metrics_push_ref: Optional[ObjectRef] = None
self._metrics_push_lock = threading.Lock()

# If the interval is set to 0, eagerly sets all metrics.
self._cached_metrics_enabled = RAY_SERVE_METRICS_EXPORT_INTERVAL_MS != 0
self._cached_metrics_interval_s = RAY_SERVE_METRICS_EXPORT_INTERVAL_MS / 1000
Expand Down Expand Up @@ -275,13 +282,14 @@ def update_deployment_config(

# Record number of queued + ongoing requests at regular
# intervals into the in-memory metrics store
record_interval_s = (
autoscaling_config.look_back_period_s
* RAY_SERVE_AUTOSCALING_METRIC_RECORD_INTERVAL_FACTOR
)
self.metrics_pusher.register_or_update_task(
self.RECORD_METRICS_TASK_NAME,
self._add_autoscaling_metrics_point,
min(
RAY_SERVE_HANDLE_AUTOSCALING_METRIC_RECORD_INTERVAL_S,
autoscaling_config.metrics_interval_s,
),
min(record_interval_s, autoscaling_config.metrics_interval_s),
)
# Push metrics to the controller periodically.
self.metrics_pusher.register_or_update_task(
Expand Down Expand Up @@ -365,10 +373,17 @@ def push_autoscaling_metrics_to_controller(self):
"""Pushes queued and running request metrics to the controller.

These metrics are used by the controller for autoscaling.
If a previous push is already in flight, skips this push (will try again next interval).
"""
self._controller_handle.record_autoscaling_metrics_from_handle.remote(
self._get_metrics_report()
)
with self._metrics_push_lock:
if self._pending_metrics_push_ref is not None:
if not check_obj_ref_ready_nowait(self._pending_metrics_push_ref):
return # Previous push still in flight, skip and try again later
self._pending_metrics_push_ref = (
self._controller_handle.record_autoscaling_metrics_from_handle.remote(
compress_metric_report(self._get_metrics_report())
)
)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Scaled-to-zero cold-start push silently throttled

Medium Severity

push_autoscaling_metrics_to_controller is used both for periodic pushes and for the critical scaled-to-zero cold-start optimization (called from wrap_queued_request and update_deployment_config). The new throttling logic doesn't distinguish between these use cases. If a periodic push or a update_deployment_config push is still in flight when the first request arrives at a zero-replica deployment, the cold-start push is silently skipped. The controller won't learn about the queued request until the next periodic push (metrics_interval_s later, default 10s), defeating the cold-start optimization.

Additional Locations (1)

Fix in Cursor Fix in Web


def _add_autoscaling_metrics_point(self):
"""Adds metrics point for queued and running requests at replicas.
Expand Down Expand Up @@ -446,11 +461,12 @@ def _get_metrics_report(self) -> HandleMetricReport:
# If the running requests timeseries is empty, we set the sum
# to the current number of requests.
running_requests_sum = num_requests
avg_running_requests[replica_id] = (
replica_str = replica_id.to_full_id_str()
avg_running_requests[replica_str] = (
running_requests_sum / num_data_points
)
# Get running requests data
running_requests[replica_id] = self.metrics_store.data.get(
running_requests[replica_str] = self.metrics_store.data.get(
replica_id, [TimeStampedValue(timestamp, num_requests)]
)
handle_metric_report = HandleMetricReport(
Expand Down
Loading