-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtimed_proxy.py
More file actions
61 lines (53 loc) · 2.81 KB
/
timed_proxy.py
File metadata and controls
61 lines (53 loc) · 2.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
"""
TimedPDProxyServer — measures prefill + KV transfer latencies,
pushes metrics to MetricsSidecar actor.
"""
import logging
import time
from typing import AsyncGenerator, Optional, Union
import ray
from ray.llm._internal.serve.serving_patterns.prefill_decode.pd_server import PDProxyServer
from ray.llm._internal.serve.core.configs.openai_api_models import (
ChatCompletionRequest, CompletionRequest,
ChatCompletionResponse, CompletionResponse, ErrorResponse,
)
# Raw HTTP request info wrapper
from ray.llm._internal.serve.core.protocol import RawRequestInfo
# Module-level logger; "ray.serve" so messages flow into Ray Serve's log stream.
logger = logging.getLogger("ray.serve")
# Either OpenAI-style request type this proxy routes (chat or plain completion).
RequestType = Union[ChatCompletionRequest, CompletionRequest]
class TimedPDProxyServer(PDProxyServer):
    """PDProxyServer that measures prefill and KV-transfer/decode latencies.

    Records two wall-clock intervals per request and pushes them (in
    milliseconds) to the ``metrics_sidecar`` Ray actor as a best-effort,
    fire-and-forget call:

    * prefill start -> first decoded token
    * decode start  -> first decoded token
    """

    async def _handle_request(
        self,
        request: RequestType,
        raw_request_info: Optional[RawRequestInfo] = None,
    ) -> AsyncGenerator[Union[str, ChatCompletionResponse, CompletionResponse, ErrorResponse], None]:
        """Proxy one request through the prefill then decode engines with timing.

        Args:
            request: Incoming chat or completion request.
            raw_request_info: Optional raw HTTP request wrapper, forwarded
                unchanged to both engines.

        Yields:
            Streaming chunks from the decode engine, or a single
            ``ErrorResponse`` if the prefill stage fails.
        """
        self._maybe_add_request_id_to_request(request)  # Attach unique req ID
        # Pick which engine method to invoke based on the request type.
        method = "chat" if isinstance(request, ChatCompletionRequest) else "completions"

        # --- Prefill phase ---
        prefill_request = self._prepare_prefill_request(request)
        t0 = time.perf_counter()  # prefill start
        prefill_gen = getattr(self.prefill_server, method).remote(
            prefill_request, raw_request_info
        )
        prefill_chunk = await prefill_gen.__anext__()  # wait for first prefill chunk
        if isinstance(prefill_chunk, ErrorResponse):
            yield prefill_chunk
            return

        # --- Decode phase ---
        # Decode engine will pull the KV cache from prefill via NixlConnector.
        t_kv = time.perf_counter()  # decode / KV-transfer start
        decode_request = self._prepare_decode_request(request, prefill_chunk)
        decode_gen = getattr(self.decode_server, method).remote(
            decode_request, raw_request_info
        )

        # Stream decoded tokens back to the caller, noting first-token time.
        t_first = None
        async for chunk in decode_gen:
            if t_first is None:
                t_first = time.perf_counter()  # first decoded token observed
            yield chunk

        # (t_first - t0):  prefill start -> first decoded token
        #                  (prefill + handoff + decode start)
        # (t_first - t_kv): decode start -> first decoded token
        #                  (KV pull + decode scheduling/compute)
        if t_first is not None:
            try:
                sidecar = ray.get_actor("metrics_sidecar", namespace="serve")
                # Fire-and-forget: metrics must never block the response path.
                sidecar.record.remote((t_first - t0) * 1000, (t_first - t_kv) * 1000)
            except Exception:
                # Best-effort: keep serving even if the sidecar is missing or
                # the push fails, but log it instead of silently swallowing —
                # a bare `pass` here would hide a dead metrics pipeline.
                logger.warning(
                    "Failed to push timing metrics to metrics_sidecar",
                    exc_info=True,
                )