15 changes: 13 additions & 2 deletions lightllm/common/basemodel/cuda_graph.py
@@ -3,6 +3,7 @@
import copy
import bisect
from typing import Optional
from tqdm import tqdm
from lightllm.utils.log_utils import init_logger
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.distributed import dist_group_manager, lightllm_capture_graph, CustomProcessGroup
@@ -196,7 +197,11 @@ def warmup(self, model):
model: TpPartBaseModel = model

# decode cuda graph init
for batch_size in self.cuda_graph_batch_sizes[::-1]:
progress_bar = tqdm(self.cuda_graph_batch_sizes[::-1], desc="Capturing CUDA graphs")
for batch_size in progress_bar:
avail_mem, _ = torch.cuda.mem_get_info()
avail_mem_gb = avail_mem / (1024 ** 3)
progress_bar.set_description(f"Capturing CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB")
seq_len = 2
total_token_num = batch_size * seq_len
max_len_in_batch = self.graph_max_len_in_batch
@@ -251,7 +256,13 @@ def warmup_overlap(self, model):

model: TpPartBaseModel = model

for batch_size in self.cuda_graph_batch_sizes[::-1]:
progress_bar = tqdm(self.cuda_graph_batch_sizes[::-1], desc="Capturing overlap CUDA graphs")
for batch_size in progress_bar:
avail_mem, _ = torch.cuda.mem_get_info()
avail_mem_gb = avail_mem / (1024 ** 3)
progress_bar.set_description(
f"Capturing overlap CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB"
)
decode_batches = []
for micro_batch_index in [0, 1]:
# dummy decoding, capture the cudagraph
6 changes: 3 additions & 3 deletions lightllm/common/triton_utils/autotuner.py
@@ -215,7 +215,7 @@ def _try_load_cache(self, static_key):

cache_file = os.path.join(self.cache_dir, KernelConfigs.get_config_file_name(static_key))
if os.path.exists(cache_file):
logger.info(f"Loading cached configs for {self.kernel_name} - {static_key}")
logger.info(f"Loading cached configs for {self.kernel_name} - {dict(static_key)}")
with open(cache_file, "rb") as f:
self.cached_configs[static_key] = orjson.loads(f.read())
return True
@@ -353,9 +353,9 @@ def _autotune(self, args, kwargs, static_key, run_key, rank_id, world_size):
option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS | orjson.OPT_NON_STR_KEYS,
)
)
logger.info(f"Saved configs for {self.kernel_name} - {_static_key}")
logger.info(f"Saved configs for {self.kernel_name} - {dict(_static_key)}")

logger.info(f"rank {rank_id} tuning {self.kernel_name} _static_key {static_key} finished")
logger.info(f"rank {rank_id} tuning {self.kernel_name} _static_key {dict(static_key)} finished")

def _mutate_args_clone(self, args, kwargs):
origin_list = []
4 changes: 2 additions & 2 deletions lightllm/server/detokenization/manager.py
@@ -52,7 +52,7 @@ def _add_new_group_req_index(self, recv_obj: GroupReqIndexes):
req.link_prompt_ids_shm_array()
req.link_logprobs_shm_array()

logger.info(
logger.debug(
f"detokenization recv req id {req.request_id} " f"cost time {time.time() - recv_obj.time_mark} s"
)

@@ -160,7 +160,7 @@ def remove_finished_reqs(self):

for decode_req in finished_reqs:
decode_req.req.can_released_mark = True
logger.info(f"detoken release req id {decode_req.req.request_id}")
logger.debug(f"detoken release req id {decode_req.req.request_id}")
self.shm_req_manager.put_back_req_obj(decode_req.req)
self.req_id_to_out.pop(decode_req.request_id, None)
return
8 changes: 4 additions & 4 deletions lightllm/server/httpserver/manager.py
@@ -414,7 +414,7 @@ async def _log_req_header(self, request_headers, group_request_id: int):
x_session_id = request_headers.get("X-Session-Id", "")

format_in_time = datetime.datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d %H:%M:%S")
logger.info(
logger.debug(
f"recieved req X-Request-Id:{x_request_id} "
f"X-Session-Id:{x_session_id} start_time:{format_in_time} "
f"lightllm_req_id:{group_request_id} "
@@ -611,7 +611,7 @@ async def _wait_to_token_package(
(out_token_counter - sum(sub_req_id_to_mtp_accepted_token_num.values())), 1
)
format_start_time = datetime.datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S")
logger.info(
logger.debug(
f"X-Request-Id:{x_request_id} "
f"X-Session-Id:{x_session_id} start_time:{format_start_time} "
f"lightllm_req_id:{group_request_id} first_token_cost:{first_token_cost_ms}ms "
@@ -698,8 +698,8 @@ async def recycle_resource_loop(self):
if req_status is None:
continue

logger.info(
f"left req id {req_status.group_req_objs.group_req_id}"
logger.debug(
f"left req id {req_status.group_req_objs.group_req_id} "
f"can release {req_status.group_req_objs.shm_req_objs[0].can_released_mark} "
f"refcount {req_status.group_req_objs.shm_req_objs[0].ref_count}"
)
2 changes: 1 addition & 1 deletion lightllm/server/router/batch.py
@@ -53,7 +53,7 @@ def filter_out_finished_req(self, shm_req_manager: ShmReqManager):
unfinished_req_ids = []
for req in self.reqs:
if req.shm_infer_released:
logger.info(f"router release req id {req.request_id}")
logger.debug(f"router release req id {req.request_id}")
shm_req_manager.put_back_req_obj(req)
req = None
else:
59 changes: 31 additions & 28 deletions lightllm/server/router/manager.py
@@ -16,6 +16,7 @@
from .batch import Batch, Req
from .model_infer.model_rpc import start_model_process, ModelRpcClient
from .req_queue import build_req_queue
from .stats import SystemStatusReporter
from lightllm.server.core.objs.io_objs import (
GroupReqIndexes,
AbortedReqCmd,
@@ -25,7 +26,7 @@
from .dynamic_prompt.radix_cache import RadixCacheReadOnlyClient
from lightllm.server.multi_level_kv_cache.cpu_cache_client import CpuKvCacheClient
from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer
from lightllm.utils.log_utils import init_logger, log_time_ready
from lightllm.utils.log_utils import init_logger
from lightllm.server.router.token_load import TokenLoad
from lightllm.server.metrics.manager import MetricClient
from lightllm.common.basemodel.infer_lock import g_router_lock
@@ -64,6 +65,7 @@ def __init__(self, args: StartArgs):
self.read_only_statics_mem_manager = ReadOnlyStaticsMemoryManager()
# init radix_cache_client, used to read the prompt cache management info
self.radix_cache_client = None
self.status_reporter = None

# shared variable that stores the machine load info produced by router-side scheduling analysis
self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", self.dp_size_in_node)
@@ -192,6 +194,11 @@ async def wait_to_model_ready(self):
)
self.req_queue = build_req_queue(self.args, self, self.dp_size_in_node)
logger.info(f"use req queue {self.req_queue.__class__.__name__}")
self.status_reporter = SystemStatusReporter(
args=self.args,
max_total_token_num=self.max_total_token_num,
dp_size_in_node=self.dp_size_in_node,
)

if self.args.run_mode == "prefill":
# start the prefill kv move management process
@@ -237,26 +244,10 @@ async def loop_for_fwd(
await self._step()
counter_count += 1
if self.running_batch is not None:
# Count output tokens (each running req produces ~1 token per decode step)
self.status_reporter.count_output_tokens(len(self.running_batch.reqs))
if counter_count % 100 == 0:
for dp_index in range(self.dp_size_in_node):
token_ratio1 = self.get_used_tokens(dp_index) / self.max_total_token_num
token_ratio2 = (
self.max_total_token_num
- self.read_only_statics_mem_manager.get_unrefed_token_num(dp_index)
) / self.max_total_token_num
d_i = dp_index
frozen_token_num = self.shared_token_load.get_frozened_token_count(d_i)
estimated_peak_token_count = self.shared_token_load.get_estimated_peak_token_count(d_i)
paused_req_num = self._get_paused_req_num_in_dp_index(dp_index=d_i)
logger.debug(
f"dp_i {d_i} current batch size: {len(self.running_batch.reqs)} \n"
f"dp_i {d_i} paused req num: {paused_req_num} \n"
f"dp_i {d_i} frozen token num: {frozen_token_num} \n"
f"dp_i {d_i} estimated_peak_token_count: {estimated_peak_token_count} \n"
f"dp_i {d_i} token used ratio: {token_ratio1} not contain prompt cache tree unrefed token\n"
f"dp_i {d_i} token used ratio: {token_ratio2} contain prompt cache tree unrefed token"
)
self.metric_client.gauge_set("lightllm_batch_pause_size", paused_req_num)
self.metric_client.gauge_set("lightllm_batch_pause_size", self._get_paused_req_num())
# pd decode mode need to update token_load more frequently
self.req_queue.update_token_load(self.running_batch, force_update=self.is_pd_decode_mode)
self.metric_client.gauge_set("lightllm_batch_current_size", len(self.running_batch.reqs))
@@ -275,13 +266,15 @@
self.metric_client.gauge_set("lightllm_batch_pause_size", 0.0)
self.metric_client.gauge_set("lightllm_queue_size", 0.0)
self.metric_client.gauge_set("lightllm_batch_current_max_tokens", 0.0)
# 60s print once
if log_time_ready("frozen_info", 60):
for dp_i in range(self.dp_size_in_node):
frozen_token_num = self.shared_token_load.get_frozened_token_count(dp_i)
estimated_peak_token_count = self.shared_token_load.get_estimated_peak_token_count(dp_i)
logger.debug(f"dp_i {dp_i} frozen token num: {frozen_token_num} \n")
logger.debug(f"dp_i {dp_i} estimated_peak_token_count: {estimated_peak_token_count} \n")

self.status_reporter.maybe_print(
running_batch=self.running_batch,
req_queue=self.req_queue,
read_only_statics_mem_manager=self.read_only_statics_mem_manager,
paused_req_num=self._get_paused_req_num(),
radix_cache_client=self.radix_cache_client,
disable_dynamic_prompt_cache=self.args.disable_dynamic_prompt_cache,
)

await asyncio.sleep(self._get_schedule_time_interval())

@@ -311,6 +304,7 @@

async def _add_batch(self, batch: Batch):
# add new requests
self.status_reporter.count_prompt_tokens(batch.input_tokens())
reqs = [r.to_router_rpc_obj() for r in batch.reqs]
while not self.shm_reqs_io_buffer.is_empty():
await asyncio.sleep(0.02)
@@ -347,6 +341,15 @@ def _add_new_batch_to_running_batch(self, new_batch: Batch):

def _filter_reqs_from_running_batch(self):
if self.running_batch is not None:
# Capture finished req stats before filtering
for req in self.running_batch.reqs:
if req.shm_infer_released:
self.status_reporter.on_request_completed(
input_len=req.input_len,
output_len=req.shm_cur_output_len,
cache_len=req.prompt_cache_len,
mtp_accepted=req.mtp_accepted_token_num,
)
self.running_batch.filter_out_finished_req(self.shm_req_manager)
if self.running_batch.is_clear():
self.running_batch = None
@@ -419,7 +422,7 @@ def _add_req(self, group_req_indexes: GroupReqIndexes):
req._router_stop_str_matched = False
req_group.append(req)

logger.info(f"router recive req id {req.request_id} cost time {time.time() - req.start_time} s")
logger.debug(f"router receive req id {req.request_id} cost time {time.time() - req.start_time} s")
self.req_queue.extend(req_group)
self.send_to_detokenization.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL)
return
143 changes: 103 additions & 40 deletions lightllm/server/router/stats.py
@@ -1,47 +1,110 @@
import time
from lightllm.utils.log_utils import init_logger
from .batch import Batch
import logging
from lightllm.utils.log_utils import init_system_status_logger

logger = init_logger(__name__)
logger = logging.getLogger(__name__)


class Stats:
def __init__(self, log_status, log_stats_interval) -> None:
self.log_stats = log_status
self.log_stats_interval = log_stats_interval
self.last_log_time = time.time()
self.all_tokens = 0
self.output_tokens = 0
class SystemStatusReporter:
def __init__(self, args, max_total_token_num, dp_size_in_node):
self.enabled = not args.disable_log_stats
self.interval = max(5, args.log_stats_interval)
if args.log_stats_interval < 5:
logger.warning(f"log_stats_interval={args.log_stats_interval}s is below minimum, using 5s")
self.max_total_token_num = max_total_token_num
self.dp_size_in_node = dp_size_in_node
self.status_logger = init_system_status_logger("router")

# Accumulation counters (reset each interval)
self.last_print_time = time.time()
self.prompt_tokens = 0
return

def count_prompt_tokens(self, run_batch: Batch):
if self.log_stats and run_batch is not None:
tokens = run_batch.input_tokens()
self.prompt_tokens += tokens
self.all_tokens += tokens
return

def count_output_tokens(self, run_batch: Batch):
if self.log_stats and run_batch is not None:
tokens = len(run_batch.reqs)
self.output_tokens += tokens
self.all_tokens += tokens
return

def print_stats(self):
if not self.log_stats:
return
self.output_tokens = 0

# Global counters (never reset, for lifetime stats)
self.global_input_total = 0
self.global_cache_total = 0
self.global_mtp_output_total = 0
self.global_mtp_accepted_total = 0

def count_prompt_tokens(self, num_tokens: int):
if self.enabled:
self.prompt_tokens += num_tokens

def count_output_tokens(self, num_tokens: int):
if self.enabled:
self.output_tokens += num_tokens

def on_request_completed(self, input_len: int, output_len: int, cache_len: int, mtp_accepted: int):
if self.enabled:
self.global_input_total += input_len
self.global_cache_total += cache_len
self.global_mtp_output_total += output_len
self.global_mtp_accepted_total += mtp_accepted

def maybe_print(
self,
running_batch,
req_queue,
read_only_statics_mem_manager,
paused_req_num=0,
radix_cache_client=None,
disable_dynamic_prompt_cache=False,
):
Review comment on lines +44 to +52 (Contributor, severity: medium):
The maybe_print method has a large number of parameters. Many of these, such as req_queue, read_only_statics_mem_manager, radix_cache_client, and disable_dynamic_prompt_cache, are available when SystemStatusReporter is initialized and seem to be constant throughout its lifetime.

To improve code clarity and maintainability, consider moving these stable dependencies to the __init__ method. This simplifies the maybe_print signature and makes the dependencies of SystemStatusReporter more explicit.

For example, you could modify __init__ to accept these objects and store them as instance attributes. Then maybe_print would only need the parameters that change on each call, like running_batch and paused_req_num.
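
A minimal sketch of that refactor (hypothetical, not part of this PR; the argument and attribute names are assumed from the diff above):

class SystemStatusReporter:
    def __init__(self, args, max_total_token_num, dp_size_in_node,
                 req_queue, read_only_statics_mem_manager, radix_cache_client=None):
        self.enabled = not args.disable_log_stats
        self.interval = max(5, args.log_stats_interval)
        self.max_total_token_num = max_total_token_num
        self.dp_size_in_node = dp_size_in_node
        # stable dependencies captured once at construction time
        self.req_queue = req_queue
        self.read_only_statics_mem_manager = read_only_statics_mem_manager
        self.radix_cache_client = radix_cache_client
        self.disable_dynamic_prompt_cache = args.disable_dynamic_prompt_cache

    def maybe_print(self, running_batch, paused_req_num=0):
        # only per-step state stays in the signature; everything else is read from self
        ...

If any of these objects are not yet constructed when the reporter is created, they could instead be attached later through a setter, keeping maybe_print itself unchanged.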

if not self.enabled:
return
now = time.time()
if now - self.last_log_time > self.log_stats_interval:
logger.debug(
f"Avg tokens(prompt+generate) throughput: {self.all_tokens/(now-self.last_log_time):8.3f} tokens/s\n"
f"Avg prompt tokens throughput: {self.prompt_tokens/(now-self.last_log_time):8.3f} tokens/s\n"
f"Avg generate tokens throughput: {self.output_tokens/(now-self.last_log_time):8.3f} tokens/s"
)
self.all_tokens = 0
self.output_tokens = 0
self.prompt_tokens = 0
self.last_log_time = now
return
elapsed = now - self.last_print_time
if elapsed < self.interval:
return

total_tps = (self.prompt_tokens + self.output_tokens) / elapsed
input_tps = self.prompt_tokens / elapsed
output_tps = self.output_tokens / elapsed

running = len(running_batch.reqs) if running_batch else 0
queued = req_queue.get_wait_req_num()

# Memory utilization (average across dp)
# kv_used: physical KV memory usage (includes prefix cache tree occupancy)
# kv_used_no_cache: effective usage excluding unrefed prefix cache tokens
kv_used_list = []
kv_used_no_cache_list = []
for dp_i in range(self.dp_size_in_node):
unrefed = read_only_statics_mem_manager.get_unrefed_token_num(dp_i)
used = self.max_total_token_num - unrefed
kv_used_list.append(used / self.max_total_token_num)
if not disable_dynamic_prompt_cache and radix_cache_client is not None:
cache_unrefed = radix_cache_client.get_unrefed_tokens_num(dp_i)
kv_used_no_cache_list.append((used - cache_unrefed) / self.max_total_token_num)
else:
kv_used_no_cache_list.append(used / self.max_total_token_num)
avg_kv_used = sum(kv_used_list) / len(kv_used_list)
avg_kv_used_no_cache = sum(kv_used_no_cache_list) / len(kv_used_no_cache_list)

# Global prefix cache hit rate
cache_hit_rate = (
(self.global_cache_total / self.global_input_total * 100) if self.global_input_total > 0 else 0.0
)

kv_pct = avg_kv_used * 100
kv_pct_no_cache = avg_kv_used_no_cache * 100

# Avg MTP accepted length (only shown when MTP is active)
mtp_suffix = ""
if self.global_mtp_accepted_total > 0:
decode_steps = self.global_mtp_output_total - self.global_mtp_accepted_total
avg_mtp_len = self.global_mtp_output_total / max(decode_steps, 1)
mtp_suffix = f" | MTP {avg_mtp_len:.2f}"

self.status_logger.info(
f"Throughput {total_tps:>7.1f} tok/s (in {input_tps:.1f}, out {output_tps:.1f}) | "
f"Reqs {running} run, {queued} wait, {paused_req_num} pause | "
f"KV Cache {kv_pct:.1f}% (active {kv_pct_no_cache:.1f}%) | "
f"Prefix Hit {cache_hit_rate:.1f}%"
f"{mtp_suffix}"
)

# Reset windowed counters
self.prompt_tokens = 0
self.output_tokens = 0
self.last_print_time = now