From 74b87a9c9f8e134dbe18b38403e1916922d74767 Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 6 May 2026 11:26:32 +0800 Subject: [PATCH 01/10] refactor(logging): colored levels, windowed cache stats, quieter per-request logs - Add ANSI color codes to log level names (TTY only; plain in files) - Introduce SystemStatusReporter with windowed prefix-cache hit rate alongside the global rate, plus a more compact status line - Drop gunicorn --access-logfile flags (FastAPI middleware now handles it) - Remove duplicate _ACCESS_LOG_STATUS_COLORS declaration in api_http.py - Downgrade noisy per-request / per-batch progress logs from INFO to DEBUG - Fix flake8 F841 (unused exception variable) in detokenization manager --- lightllm/common/basemodel/cuda_graph.py | 15 ++- lightllm/common/triton_utils/autotuner.py | 6 +- lightllm/server/api_http.py | 1 - lightllm/server/api_start.py | 6 - lightllm/server/detokenization/manager.py | 9 +- lightllm/server/httpserver/manager.py | 8 +- .../httpserver_for_pd_master/manager.py | 2 +- lightllm/server/router/batch.py | 8 +- lightllm/server/router/manager.py | 64 ++++----- lightllm/server/router/stats.py | 123 +++++++++++++++++- lightllm/utils/log_utils.py | 72 ++++++---- 11 files changed, 228 insertions(+), 86 deletions(-) diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index 782150661e..7a24fd17fb 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -4,6 +4,7 @@ import bisect import triton from typing import Optional +from tqdm import tqdm from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_env_start_args from lightllm.distributed import dist_group_manager @@ -197,7 +198,11 @@ def warmup(self, model): model: TpPartBaseModel = model # decode cuda graph init - for batch_size in self.cuda_graph_batch_sizes[::-1]: + progress_bar = tqdm(self.cuda_graph_batch_sizes[::-1], desc="Capturing CUDA graphs") + for batch_size in progress_bar: + avail_mem, _ = torch.cuda.mem_get_info() + avail_mem_gb = avail_mem / (1024 ** 3) + progress_bar.set_description(f"Capturing CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB") seq_len = 2 total_token_num = batch_size * seq_len max_len_in_batch = self.graph_max_len_in_batch @@ -252,7 +257,13 @@ def warmup_overlap(self, model): model: TpPartBaseModel = model - for batch_size in self.cuda_graph_batch_sizes[::-1]: + progress_bar = tqdm(self.cuda_graph_batch_sizes[::-1], desc="Capturing overlap CUDA graphs") + for batch_size in progress_bar: + avail_mem, _ = torch.cuda.mem_get_info() + avail_mem_gb = avail_mem / (1024 ** 3) + progress_bar.set_description( + f"Capturing overlap CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB" + ) decode_batches = [] for micro_batch_index in [0, 1]: # dummy decoding, capture the cudagraph diff --git a/lightllm/common/triton_utils/autotuner.py b/lightllm/common/triton_utils/autotuner.py index c62a2572ff..3cbc5dc0f3 100644 --- a/lightllm/common/triton_utils/autotuner.py +++ b/lightllm/common/triton_utils/autotuner.py @@ -215,7 +215,7 @@ def _try_load_cache(self, static_key): cache_file = os.path.join(self.cache_dir, KernelConfigs.get_config_file_name(static_key)) if os.path.exists(cache_file): - logger.info(f"Loading cached configs for {self.kernel_name} - {static_key}") + logger.info(f"Loading cached configs for {self.kernel_name} - {dict(static_key)}") with open(cache_file, "rb") as f: self.cached_configs[static_key] = 
orjson.loads(f.read()) return True @@ -353,9 +353,9 @@ def _autotune(self, args, kwargs, static_key, run_key, rank_id, world_size): option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS | orjson.OPT_NON_STR_KEYS, ) ) - logger.info(f"Saved configs for {self.kernel_name} - {_static_key}") + logger.info(f"Saved configs for {self.kernel_name} - {dict(_static_key)}") - logger.info(f"rank {rank_id} tuning {self.kernel_name} _static_key {static_key} finished") + logger.info(f"rank {rank_id} tuning {self.kernel_name} _static_key {dict(static_key)} finished") def _mutate_args_clone(self, args, kwargs): origin_list = [] diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 6c3f8b3fe9..2ba917117a 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -116,7 +116,6 @@ def set_args(self, args: StartArgs): app = FastAPI() g_objs.app = app -_ACCESS_LOG_STATUS_COLORS = {2: "\033[32m", 3: "\033[36m", 4: "\033[33m", 5: "\033[31m"} _ACCESS_LOG_STATUS_COLORS = {2: "\033[32m", 3: "\033[36m", 4: "\033[33m", 5: "\033[31m"} _ACCESS_LOG_RESET = "\033[0m" diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index a687971edf..ad170f5a18 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -487,8 +487,6 @@ def normal_or_p_d_start(args): f"{args.host}:{args.port}", "--log-level", "info", - "--access-logfile", - "-", "--error-logfile", "-", "lightllm.server.api_http:app", @@ -556,8 +554,6 @@ def pd_master_start(args): f"{args.host}:{args.port}", "--log-level", "info", - "--access-logfile", - "-", "--error-logfile", "-", "lightllm.server.api_http:app", @@ -652,8 +648,6 @@ def config_server_start(args): f"{args.config_server_host}:{args.config_server_port}", "--log-level", "info", - "--access-logfile", - "-", "--error-logfile", "-", "lightllm.server.config_server.api_http:app", diff --git a/lightllm/server/detokenization/manager.py b/lightllm/server/detokenization/manager.py index 389171ba8a..b27c1c95bc 100644 --- a/lightllm/server/detokenization/manager.py +++ b/lightllm/server/detokenization/manager.py @@ -52,7 +52,7 @@ def _add_new_group_req_index(self, recv_obj: GroupReqIndexes): req.link_prompt_ids_shm_array() req.link_logprobs_shm_array() - logger.info( + logger.debug( f"detokenization recv req id {req.request_id} " f"cost time {time.time() - recv_obj.time_mark} s" ) @@ -76,7 +76,10 @@ def handle_loop(self): for _ in range(recv_max_count): recv_obj: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) assert isinstance(recv_obj, GroupReqIndexes) - self._add_new_group_req_index(recv_obj=recv_obj) + try: + self._add_new_group_req_index(recv_obj=recv_obj) + except Exception: + logger.exception("add new group req index has exception") # 当队列中存在较多的请求时,将一次接受的数量上调 recv_max_count = min(int(recv_max_count * 1.3), 256) @@ -160,7 +163,7 @@ def remove_finished_reqs(self): for decode_req in finished_reqs: decode_req.req.can_released_mark = True - logger.info(f"detoken release req id {decode_req.req.request_id}") + logger.debug(f"detoken release req id {decode_req.req.request_id}") self.shm_req_manager.put_back_req_obj(decode_req.req) self.req_id_to_out.pop(decode_req.request_id, None) return diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index c772e97d19..78f4dd7e5c 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -481,7 +481,7 @@ async def _log_req_header(self, request_headers, group_request_id: int): x_session_id = 
request_headers.get("X-Session-Id", "") format_in_time = datetime.datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d %H:%M:%S") - logger.info( + logger.debug( f"received req X-Request-Id:{x_request_id} " f"X-Session-Id:{x_session_id} start_time:{format_in_time} " f"lightllm_req_id:{group_request_id} " @@ -719,7 +719,7 @@ async def _wait_to_token_package( (out_token_counter - sum(sub_req_id_to_mtp_accepted_token_num.values())), 1 ) format_start_time = datetime.datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S") - logger.info( + logger.debug( f"X-Request-Id:{x_request_id} " f"X-Session-Id:{x_session_id} start_time:{format_start_time} " f"lightllm_req_id:{group_request_id} first_token_cost:{first_token_cost_ms}ms " @@ -812,8 +812,8 @@ async def recycle_resource_loop(self): if req_status is None: continue - logger.info( - f"left req id {req_status.group_req_objs.group_req_id}" + logger.debug( + f"left req id {req_status.group_req_objs.group_req_id} " f"can release {req_status.group_req_objs.shm_req_objs[0].can_released_mark} " f"refcount {req_status.group_req_objs.shm_req_objs[0].ref_count}" ) diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index ae3c3d8960..94a30947bb 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -183,7 +183,7 @@ async def _log_req_header(self, request: Request, group_request_id: int): x_request_id = request.headers.get("X-Request-Id", "") x_session_id = request.headers.get("X-Session-Id", "") format_in_time = datetime.datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d %H:%M:%S") - logger.info( + logger.debug( f"received req X-Request-Id:{x_request_id} " f"X-Session-Id:{x_session_id} start_time:{format_in_time} " f"lightllm_req_id:{group_request_id} " diff --git a/lightllm/server/router/batch.py b/lightllm/server/router/batch.py index 24d0b9b824..34902d812a 100644 --- a/lightllm/server/router/batch.py +++ b/lightllm/server/router/batch.py @@ -3,7 +3,6 @@ from typing import Dict, List, Optional, Tuple, Union from lightllm.server.core.objs import ShmReqManager, Req from lightllm.utils.log_utils import init_logger -from .stats import RouterStatics logger = init_logger(__name__) @@ -50,14 +49,11 @@ def get_all_dp_req_num(self) -> List[int]: all_dp_req_num[req.sample_params.suggested_dp_index] += 1 return all_dp_req_num - def filter_out_finished_req(self, shm_req_manager: ShmReqManager, router_statics: RouterStatics): + def filter_out_finished_req(self, shm_req_manager: ShmReqManager): unfinished_req_ids = [] for req in self.reqs: if req.shm_infer_released: - logger.info(f"router release req id {req.request_id}") - if not req.is_aborted: - router_statics.update(req.candetoken_out_len) - + logger.debug(f"router release req id {req.request_id}") shm_req_manager.put_back_req_obj(req) req = None else: diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index f5e0b8df9a..3ee4de8848 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -16,6 +16,7 @@ from .batch import Batch, Req from .model_infer.model_rpc import start_model_process, ModelRpcClient from .req_queue import build_req_queue +from .stats import SystemStatusReporter from lightllm.server.core.objs.io_objs import ( GroupReqIndexes, AbortedReqCmd, @@ -25,7 +26,7 @@ from .dynamic_prompt.radix_cache import RadixCacheReadOnlyClient from lightllm.server.multi_level_kv_cache.cpu_cache_client import 
CpuKvCacheClient from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer -from lightllm.utils.log_utils import init_logger, log_time_ready +from lightllm.utils.log_utils import init_logger from lightllm.server.router.token_load import TokenLoad from lightllm.server.metrics.manager import MetricClient from lightllm.common.basemodel.infer_lock import g_router_lock @@ -65,6 +66,7 @@ def __init__(self, args: StartArgs): self.read_only_statics_mem_manager = ReadOnlyStaticsMemoryManager() # 初始化 radix_cache_client 用于读取 prompt cache 的管理信息 self.radix_cache_client = None + self.status_reporter = None # 共享变量,用于存储router端调度分析得到的机器负载信息 self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", self.dp_size_in_node) @@ -194,6 +196,11 @@ async def wait_to_model_ready(self): ) self.req_queue = build_req_queue(self.args, self, self.dp_size_in_node) logger.info(f"use req queue {self.req_queue.__class__.__name__}") + self.status_reporter = SystemStatusReporter( + args=self.args, + max_total_token_num=self.max_total_token_num, + dp_size_in_node=self.dp_size_in_node, + ) if self.args.run_mode == "prefill": # 启动 prefill kv move 管理进程 @@ -239,27 +246,10 @@ async def loop_for_fwd( await self._step() counter_count += 1 if self.running_batch is not None: + # Count output tokens (each running req produces ~1 token per decode step) + self.status_reporter.count_output_tokens(len(self.running_batch.reqs)) if counter_count % 100 == 0: - for dp_index in range(self.dp_size_in_node): - token_ratio1 = self.get_used_tokens(dp_index) / self.max_total_token_num - token_ratio2 = ( - self.max_total_token_num - - self.read_only_statics_mem_manager.get_unrefed_token_num(dp_index) - ) / self.max_total_token_num - d_i = dp_index - frozen_token_num = self.shared_token_load.get_frozened_token_count(d_i) - estimated_peak_token_count = self.shared_token_load.get_estimated_peak_token_count(d_i) - paused_req_num = self._get_paused_req_num_in_dp_index(dp_index=d_i) - logger.debug( - f"dp_i {d_i} current batch size: {len(self.running_batch.reqs)} \n" - f"dp_i {d_i} paused req num: {paused_req_num} \n" - f"dp_i {d_i} frozen token num: {frozen_token_num} \n" - f"dp_i {d_i} estimated_peak_token_count: {estimated_peak_token_count} \n" - f"dp_i {d_i} token used ratio: {token_ratio1} not contain prompt cache tree unrefed token\n" - f"dp_i {d_i} token used ratio: {token_ratio2} contain prompt cache tree unrefed token" - ) - logger.debug(self.router_statics.log_str()) - self.metric_client.gauge_set("lightllm_batch_pause_size", paused_req_num) + self.metric_client.gauge_set("lightllm_batch_pause_size", self._get_paused_req_num()) # pd decode mode need to update token_load more frequently self.req_queue.update_token_load(self.running_batch, force_update=self.is_pd_decode_mode) self.metric_client.gauge_set("lightllm_batch_current_size", len(self.running_batch.reqs)) @@ -278,13 +268,15 @@ async def loop_for_fwd( self.metric_client.gauge_set("lightllm_batch_pause_size", 0.0) self.metric_client.gauge_set("lightllm_queue_size", 0.0) self.metric_client.gauge_set("lightllm_batch_current_max_tokens", 0.0) - # 60s print once - if log_time_ready("frozen_info", 60): - for dp_i in range(self.dp_size_in_node): - frozen_token_num = self.shared_token_load.get_frozened_token_count(dp_i) - estimated_peak_token_count = self.shared_token_load.get_estimated_peak_token_count(dp_i) - logger.debug(f"dp_i {dp_i} frozen token num: {frozen_token_num} \n") - logger.debug(f"dp_i {dp_i} estimated_peak_token_count: 
{estimated_peak_token_count} \n") + + self.status_reporter.maybe_print( + running_batch=self.running_batch, + req_queue=self.req_queue, + read_only_statics_mem_manager=self.read_only_statics_mem_manager, + paused_req_num=self._get_paused_req_num(), + radix_cache_client=self.radix_cache_client, + disable_dynamic_prompt_cache=self.args.disable_dynamic_prompt_cache, + ) await asyncio.sleep(self._get_schedule_time_interval()) @@ -314,6 +306,7 @@ async def _step(self): async def _add_batch(self, batch: Batch): # 添加新请求 + self.status_reporter.count_prompt_tokens(batch.input_tokens()) reqs = [r.to_router_rpc_obj() for r in batch.reqs] while not self.shm_reqs_io_buffer.is_empty(): await asyncio.sleep(0.001) @@ -347,7 +340,16 @@ def _add_new_batch_to_running_batch(self, new_batch: Batch): def _filter_reqs_from_running_batch(self): if self.running_batch is not None: - self.running_batch.filter_out_finished_req(self.shm_req_manager, self.router_statics) + # Capture finished req stats before filtering + for req in self.running_batch.reqs: + if req.shm_infer_released: + self.status_reporter.on_request_completed( + input_len=req.input_len, + output_len=req.shm_cur_output_len, + cache_len=req.prompt_cache_len, + mtp_accepted=req.mtp_accepted_token_num, + ) + self.running_batch.filter_out_finished_req(self.shm_req_manager) if self.running_batch.is_clear(): self.running_batch = None return @@ -419,7 +421,7 @@ def _add_req(self, group_req_indexes: GroupReqIndexes): req._router_stop_str_matched = False req_group.append(req) - logger.info(f"router recive req id {req.request_id} cost time {time.time() - req.start_time} s") + logger.debug(f"router receive req id {req.request_id} cost time {time.time() - req.start_time} s") self.req_queue.extend(req_group) self.send_to_detokenization.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) return @@ -431,7 +433,7 @@ def _generate_new_batch(self): ) self.schedule_new_batch = Batch.merge_two_batch(self.schedule_new_batch, new_batch) if self.schedule_new_batch is not None: - logger.info(f"gen new batch, {self.schedule_new_batch.simple_log()}") + logger.debug(f"gen new batch, {self.schedule_new_batch.simple_log()}") return def _multinode_tp_generate_new_batch(self): diff --git a/lightllm/server/router/stats.py b/lightllm/server/router/stats.py index b715c5bcb3..85548d9138 100644 --- a/lightllm/server/router/stats.py +++ b/lightllm/server/router/stats.py @@ -1,7 +1,126 @@ -from lightllm.utils.log_utils import init_logger +import time +import logging from lightllm.server.core.objs import StartArgs +from lightllm.utils.log_utils import init_system_status_logger -logger = init_logger(__name__) +logger = logging.getLogger(__name__) + + +class SystemStatusReporter: + def __init__(self, args, max_total_token_num, dp_size_in_node): + self.enabled = not args.disable_log_stats + self.interval = max(5, args.log_stats_interval) + if args.log_stats_interval < 5: + logger.warning(f"log_stats_interval={args.log_stats_interval}s is below minimum, using 5s") + self.max_total_token_num = max_total_token_num + self.dp_size_in_node = dp_size_in_node + self.status_logger = init_system_status_logger("router") + + # Accumulation counters (reset each interval) + self.last_print_time = time.time() + self.prompt_tokens = 0 + self.output_tokens = 0 + + # Windowed counters for cache hit (reset each interval) + self.window_input_total = 0 + self.window_cache_total = 0 + + # Global counters (never reset, for lifetime stats) + self.global_input_total = 0 + self.global_cache_total = 0 + 
self.global_mtp_output_total = 0 + self.global_mtp_accepted_total = 0 + + def count_prompt_tokens(self, num_tokens: int): + if self.enabled: + self.prompt_tokens += num_tokens + + def count_output_tokens(self, num_tokens: int): + if self.enabled: + self.output_tokens += num_tokens + + def on_request_completed(self, input_len: int, output_len: int, cache_len: int, mtp_accepted: int): + if self.enabled: + self.window_input_total += input_len + self.window_cache_total += cache_len + self.global_input_total += input_len + self.global_cache_total += cache_len + self.global_mtp_output_total += output_len + self.global_mtp_accepted_total += mtp_accepted + + def maybe_print( + self, + running_batch, + req_queue, + read_only_statics_mem_manager, + paused_req_num=0, + radix_cache_client=None, + disable_dynamic_prompt_cache=False, + ): + if not self.enabled: + return + now = time.time() + elapsed = now - self.last_print_time + if elapsed < self.interval: + return + + total_tps = (self.prompt_tokens + self.output_tokens) / elapsed + input_tps = self.prompt_tokens / elapsed + output_tps = self.output_tokens / elapsed + + running = len(running_batch.reqs) if running_batch else 0 + queued = req_queue.get_wait_req_num() + + # Memory utilization (average across dp) + # kv_used: physical KV memory usage (includes prefix cache tree occupancy) + # kv_used_no_cache: effective usage excluding unrefed prefix cache tokens + kv_used_list = [] + kv_used_no_cache_list = [] + for dp_i in range(self.dp_size_in_node): + unrefed = read_only_statics_mem_manager.get_unrefed_token_num(dp_i) + used = self.max_total_token_num - unrefed + kv_used_list.append(used / self.max_total_token_num) + if not disable_dynamic_prompt_cache and radix_cache_client is not None: + cache_unrefed = radix_cache_client.get_unrefed_tokens_num(dp_i) + kv_used_no_cache_list.append((used - cache_unrefed) / self.max_total_token_num) + else: + kv_used_no_cache_list.append(used / self.max_total_token_num) + avg_kv_used = sum(kv_used_list) / len(kv_used_list) + avg_kv_used_no_cache = sum(kv_used_no_cache_list) / len(kv_used_no_cache_list) + + # Windowed prefix cache hit rate (this interval only) + window_cache_hit_rate = ( + (self.window_cache_total / self.window_input_total * 100) if self.window_input_total > 0 else 0.0 + ) + # Global prefix cache hit rate (lifetime) + global_cache_hit_rate = ( + (self.global_cache_total / self.global_input_total * 100) if self.global_input_total > 0 else 0.0 + ) + + kv_pct = avg_kv_used * 100 + kv_pct_no_cache = avg_kv_used_no_cache * 100 + + # Avg MTP accepted length (only shown when MTP is active) + mtp_suffix = "" + if self.global_mtp_accepted_total > 0: + decode_steps = self.global_mtp_output_total - self.global_mtp_accepted_total + avg_mtp_len = self.global_mtp_output_total / max(decode_steps, 1) + mtp_suffix = f", MTP {avg_mtp_len:.2f}" + + self.status_logger.info( + f"TPS {total_tps:.1f} (in {input_tps:.1f}, out {output_tps:.1f}), " + f"REQ {running}run, {queued}wait, {paused_req_num}pause, " + f"KV CACHE {kv_pct:.1f}% (active {kv_pct_no_cache:.1f}%), " + f"CACHE HIT {window_cache_hit_rate:.1f}% (global {global_cache_hit_rate:.1f}%)" + f"{mtp_suffix}" + ) + + # Reset windowed counters + self.prompt_tokens = 0 + self.output_tokens = 0 + self.window_input_total = 0 + self.window_cache_total = 0 + self.last_print_time = now class RouterStatics: diff --git a/lightllm/utils/log_utils.py b/lightllm/utils/log_utils.py index f15309d5cf..c3057d18f4 100644 --- a/lightllm/utils/log_utils.py +++ 
b/lightllm/utils/log_utils.py @@ -4,24 +4,41 @@ import logging import sys import os -import time from typing import Optional _FORMAT = "%(levelname)s %(asctime)s [%(filename)s:%(lineno)d] %(message)s" _DATE_FORMAT = "%m-%d %H:%M:%S" -_LOG_LEVEL = os.environ.get("LIGHTLLM_LOG_LEVEL", "debug") +_STATUS_FORMAT = "%(levelname)s [%(asctime)s] %(message)s" + +_LOG_LEVEL = os.environ.get("LIGHTLLM_LOG_LEVEL", "info") _LOG_LEVEL = getattr(logging, _LOG_LEVEL.upper(), 0) _LOG_DIR = os.environ.get("LIGHTLLM_LOG_DIR", None) +# ANSI color codes +_RESET = "\033[0m" +_LEVEL_COLORS = { + logging.DEBUG: "\033[36m", # cyan + logging.INFO: "\033[32m", # green + logging.WARNING: "\033[33m", # yellow + logging.ERROR: "\033[31m", # red + logging.CRITICAL: "\033[1;31m", # bold red +} + class NewLineFormatter(logging.Formatter): - """Adds logging prefix to newlines to align multi-line messages.""" + """Adds logging prefix to newlines to align multi-line messages, with optional color on levelname.""" - def __init__(self, fmt, datefmt=None): + def __init__(self, fmt, datefmt=None, use_color=False): logging.Formatter.__init__(self, fmt, datefmt) + self.use_color = use_color def format(self, record): + if self.use_color: + color = _LEVEL_COLORS.get(record.levelno, "") + if color: + record = logging.makeLogRecord(record.__dict__) + record.levelname = color + record.levelname + _RESET msg = logging.Formatter.format(self, record) if record.message != "": parts = msg.split(record.message) @@ -39,7 +56,9 @@ def _setup_logger(): _root_logger.setLevel(_LOG_LEVEL) global _default_handler global _default_file_handler - fmt = NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT) + _use_color = hasattr(sys.stdout, "isatty") and sys.stdout.isatty() + color_fmt = NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT, use_color=_use_color) + plain_fmt = NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT, use_color=False) if _default_handler is None: _default_handler = logging.StreamHandler(sys.stdout) @@ -55,10 +74,10 @@ def _setup_logger(): _root_logger.warn(f"Error creating directory {_LOG_DIR} : {e}") _default_file_handler = logging.FileHandler(_LOG_DIR + "/default.log") _default_file_handler.setLevel(_LOG_LEVEL) - _default_file_handler.setFormatter(fmt) + _default_file_handler.setFormatter(plain_fmt) _root_logger.addHandler(_default_file_handler) - _default_handler.setFormatter(fmt) + _default_handler.setFormatter(color_fmt) # Setting this will avoid the message # being propagated to the parent logger. 
_root_logger.propagate = False @@ -89,29 +108,28 @@ def init_logger(name: str): _root_logger.warn(f"Error creating directory {_LOG_DIR} : {e}") _inference_log_file_handler[pid] = logging.FileHandler(_LOG_DIR + f"/process.{pid}.log") _inference_log_file_handler[pid].setLevel(_LOG_LEVEL) - _inference_log_file_handler[pid].setFormatter(NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT)) + _inference_log_file_handler[pid].setFormatter( + NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT, use_color=False) + ) _root_logger.addHandler(_inference_log_file_handler[pid]) logger.addHandler(_inference_log_file_handler[pid]) logger.propagate = False return logger -_log_time_mark_dict = {} - - -def log_time_ready(mark_name, time_count: int): - """ - time_count 间隔时间超过多少s调用该函数会返回True,否则返回False - 用于控制一些日志输出的频率 - """ - global _log_time_mark_dict - - if mark_name not in _log_time_mark_dict: - _log_time_mark_dict[mark_name] = time.time() - return False - cur_time_mark = time.time() - if cur_time_mark - _log_time_mark_dict[mark_name] >= time_count: - _log_time_mark_dict[mark_name] = cur_time_mark - return True - else: - return False +def init_system_status_logger(name: str): + logger = logging.getLogger(f"lightllm.status.{name}") + if not logger.handlers: + logger.setLevel(logging.INFO) + fmt = logging.Formatter(_STATUS_FORMAT, datefmt=_DATE_FORMAT) + handler = logging.StreamHandler(sys.stdout) + handler.flush = sys.stdout.flush + handler.setFormatter(fmt) + logger.addHandler(handler) + if _LOG_DIR is not None: + file_handler = logging.FileHandler(os.path.join(_LOG_DIR, f"status.{name}.log")) + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(fmt) + logger.addHandler(file_handler) + logger.propagate = False + return logger From 1de38dac3a116f5510f77bdd5efaef009c804282 Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 6 May 2026 11:26:51 +0800 Subject: [PATCH 02/10] fix(httpserver): reject oversized prompts and translate ValueError to 400 Reject prompts whose character length exceeds max_req_total_len * 8 before tokenization, so a long string can no longer reach the tokenizer and stall the loop. The raised ValueError is caught one level up: log it at WARNING, release any held multimodal resources, abort the in-flight group request, and re-raise so the API layer (which already maps ValueError to HTTP 400) returns a graceful error to the client instead of a 500. --- lightllm/server/httpserver/manager.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 78f4dd7e5c..4667f996b5 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -298,6 +298,13 @@ async def generate( # 用于等待 pd_master 下发的交换信息 nixl_pd_event: asyncio.Event = None, ) -> AsyncGenerator[Tuple[int, str, dict, FinishStatus], None]: + if isinstance(prompt, str): + max_prompt_chars = self.max_req_total_len * 8 + if len(prompt) > max_prompt_chars: + raise ValueError( + f"prompt text length {len(prompt)} exceeds the character limit {max_prompt_chars}, " + f"the request is rejected before tokenization." 
+ ) start_time = time.time() request_headers = request.headers if request is not None else {} @@ -445,6 +452,12 @@ async def generate( yield sub_req_id, request_output, metadata, finish_status + except ValueError as e: + logger.warning(f"group_request_id: {group_request_id} request invalid: {str(e)}") + if group_request_id not in self.req_id_to_out_inf: + await self._release_multimodal_resources(multimodal_params) + await self.abort(group_request_id) + raise e except (ClientDisconnected, Exception) as e: logger.error(f"group_request_id: {group_request_id} has exception {str(e)}") From a38235b1032fc34344d44efdcf3e7e82b74182c0 Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 6 May 2026 13:26:35 +0800 Subject: [PATCH 03/10] fix(router): restore router_statics.update() on req completion The earlier refactor that moved the finished-req loop out of Batch.filter_out_finished_req into Router._filter_reqs_from_running_batch forgot to keep the router_statics.update(candetoken_out_len) call, freezing ema_req_out_len at its initial value. Multiple schedulers (chunked_prefill, beam, pd_decode, nixl_pd) read that EMA for KV-budget estimation, so leaving it stale degraded scheduling accuracy. --- lightllm/server/router/manager.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 3ee4de8848..77af8a2e5a 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -349,6 +349,7 @@ def _filter_reqs_from_running_batch(self): cache_len=req.prompt_cache_len, mtp_accepted=req.mtp_accepted_token_num, ) + self.router_statics.update(req.candetoken_out_len) self.running_batch.filter_out_finished_req(self.shm_req_manager) if self.running_batch.is_clear(): self.running_batch = None From fa207242798c24a782c127017ba89cde0db1f76c Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 6 May 2026 15:28:54 +0800 Subject: [PATCH 04/10] fix(router): output TPS via per-req deltas, skip aborted reqs in stats MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two correctness fixes flagged in PR review: 1. count_output_tokens(len(running_batch.reqs)) once per router loop is wrong — the router loop polls on schedule_time_interval, decoupled from inference, so this overcounts when the loop is faster than decode and undercounts when slower, and includes paused/prefill-only reqs. Track shm_cur_output_len per request and accumulate the delta each tick (with a tail settlement when the req is filtered out so we don't lose its last tokens to the post-final-tick window). 2. on_request_completed() and router_statics.update() now both run for aborted requests, whose candetoken_out_len is a short partial value. Restore the prior `if not req.is_aborted` guard so disconnects don't bias the output-length EMA used by KV-budget estimators. --- lightllm/server/router/manager.py | 44 ++++++++++++++++++++++++------- 1 file changed, 34 insertions(+), 10 deletions(-) diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 77af8a2e5a..6d312c55ed 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -67,6 +67,9 @@ def __init__(self, args: StartArgs): # 初始化 radix_cache_client 用于读取 prompt cache 的管理信息 self.radix_cache_client = None self.status_reporter = None + # Track shm_cur_output_len per running request to compute per-tick deltas + # for accurate output TPS regardless of router schedule interval. 
+ self._req_last_output_len: Dict[int, int] = {} # 共享变量,用于存储router端调度分析得到的机器负载信息 self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", self.dp_size_in_node) @@ -246,8 +249,18 @@ async def loop_for_fwd( await self._step() counter_count += 1 if self.running_batch is not None: - # Count output tokens (each running req produces ~1 token per decode step) - self.status_reporter.count_output_tokens(len(self.running_batch.reqs)) + # Count output tokens via per-request shm_cur_output_len deltas, since the + # router loop runs on schedule_time_interval and len(reqs) is not a per-step + # token count. + new_output_tokens = 0 + for req in self.running_batch.reqs: + cur_out_len = req.shm_cur_output_len + prev_out_len = self._req_last_output_len.get(req.request_id, 0) + if cur_out_len > prev_out_len: + new_output_tokens += cur_out_len - prev_out_len + self._req_last_output_len[req.request_id] = cur_out_len + if new_output_tokens: + self.status_reporter.count_output_tokens(new_output_tokens) if counter_count % 100 == 0: self.metric_client.gauge_set("lightllm_batch_pause_size", self._get_paused_req_num()) # pd decode mode need to update token_load more frequently @@ -342,14 +355,25 @@ def _filter_reqs_from_running_batch(self): if self.running_batch is not None: # Capture finished req stats before filtering for req in self.running_batch.reqs: - if req.shm_infer_released: - self.status_reporter.on_request_completed( - input_len=req.input_len, - output_len=req.shm_cur_output_len, - cache_len=req.prompt_cache_len, - mtp_accepted=req.mtp_accepted_token_num, - ) - self.router_statics.update(req.candetoken_out_len) + if not req.shm_infer_released: + continue + # Settle any output-token delta produced after the last router tick + # so windowed TPS does not lose the request's tail tokens. + cur_out_len = req.shm_cur_output_len + prev_out_len = self._req_last_output_len.pop(req.request_id, 0) + if cur_out_len > prev_out_len: + self.status_reporter.count_output_tokens(cur_out_len - prev_out_len) + # Aborted/disconnected requests can leave a partial output_len that + # would bias the EMA toward shorter generations; skip them. + if req.is_aborted: + continue + self.status_reporter.on_request_completed( + input_len=req.input_len, + output_len=cur_out_len, + cache_len=req.prompt_cache_len, + mtp_accepted=req.mtp_accepted_token_num, + ) + self.router_statics.update(req.candetoken_out_len) self.running_batch.filter_out_finished_req(self.shm_req_manager) if self.running_batch.is_clear(): self.running_batch = None From 9b867615d5cc48602273e52511b31b1c074dcc74 Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 6 May 2026 16:05:59 +0800 Subject: [PATCH 05/10] perf(router): sweep output-token deltas once per print interval Move the per-running-req shm_cur_output_len delta tracking from the router tick (~33 Hz) into SystemStatusReporter.maybe_print, which only runs once per log_stats_interval (>= 5s). The reporter now owns the per-req snapshot dict and exposes discard_req(req) for tail settlement when a req leaves the running batch, so the router loop's hot path no longer walks the batch every schedule cycle. Output TPS accuracy is unchanged: still based on real shm_cur_output_len deltas, with tail tokens settled at completion. 
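For reference, a minimal standalone sketch of the window-boundary sweep this
commit describes (the FakeReq/OutputTokenWindow names and fields below are
illustrative stand-ins, not the real Req or SystemStatusReporter API):

    from dataclasses import dataclass
    from typing import Dict, List

    @dataclass
    class FakeReq:
        request_id: int
        shm_cur_output_len: int = 0

    class OutputTokenWindow:
        def __init__(self):
            # snapshot of each req's output length at the previous window boundary
            self._last_len: Dict[int, int] = {}
            self.output_tokens = 0

        def sweep(self, running_reqs: List[FakeReq]):
            # once per print interval: charge each running req's delta since the
            # previous boundary, then refresh its snapshot
            for req in running_reqs:
                prev = self._last_len.get(req.request_id, 0)
                if req.shm_cur_output_len > prev:
                    self.output_tokens += req.shm_cur_output_len - prev
                self._last_len[req.request_id] = req.shm_cur_output_len

        def discard_req(self, req: FakeReq):
            # settle the tail produced after the last boundary before the req
            # leaves the running batch, so its final tokens are not dropped
            prev = self._last_len.pop(req.request_id, 0)
            if req.shm_cur_output_len > prev:
                self.output_tokens += req.shm_cur_output_len - prev

        def close_window(self, elapsed_s: float) -> float:
            tps = self.output_tokens / elapsed_s
            self.output_tokens = 0
            return tps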
--- lightllm/server/router/manager.py | 28 ++++++---------------------- lightllm/server/router/stats.py | 28 +++++++++++++++++++++++++--- 2 files changed, 31 insertions(+), 25 deletions(-) diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 6d312c55ed..8009063674 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -67,9 +67,6 @@ def __init__(self, args: StartArgs): # 初始化 radix_cache_client 用于读取 prompt cache 的管理信息 self.radix_cache_client = None self.status_reporter = None - # Track shm_cur_output_len per running request to compute per-tick deltas - # for accurate output TPS regardless of router schedule interval. - self._req_last_output_len: Dict[int, int] = {} # 共享变量,用于存储router端调度分析得到的机器负载信息 self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", self.dp_size_in_node) @@ -249,18 +246,8 @@ async def loop_for_fwd( await self._step() counter_count += 1 if self.running_batch is not None: - # Count output tokens via per-request shm_cur_output_len deltas, since the - # router loop runs on schedule_time_interval and len(reqs) is not a per-step - # token count. - new_output_tokens = 0 - for req in self.running_batch.reqs: - cur_out_len = req.shm_cur_output_len - prev_out_len = self._req_last_output_len.get(req.request_id, 0) - if cur_out_len > prev_out_len: - new_output_tokens += cur_out_len - prev_out_len - self._req_last_output_len[req.request_id] = cur_out_len - if new_output_tokens: - self.status_reporter.count_output_tokens(new_output_tokens) + # Output-token counting is done in bulk at the print-window boundary + # inside SystemStatusReporter.maybe_print, so the router tick stays cheap. if counter_count % 100 == 0: self.metric_client.gauge_set("lightllm_batch_pause_size", self._get_paused_req_num()) # pd decode mode need to update token_load more frequently @@ -357,19 +344,16 @@ def _filter_reqs_from_running_batch(self): for req in self.running_batch.reqs: if not req.shm_infer_released: continue - # Settle any output-token delta produced after the last router tick - # so windowed TPS does not lose the request's tail tokens. - cur_out_len = req.shm_cur_output_len - prev_out_len = self._req_last_output_len.pop(req.request_id, 0) - if cur_out_len > prev_out_len: - self.status_reporter.count_output_tokens(cur_out_len - prev_out_len) + # Settle any output-token tail produced after the last window boundary, + # so windowed TPS does not lose the req's last tokens. + self.status_reporter.discard_req(req) # Aborted/disconnected requests can leave a partial output_len that # would bias the EMA toward shorter generations; skip them. 
if req.is_aborted: continue self.status_reporter.on_request_completed( input_len=req.input_len, - output_len=cur_out_len, + output_len=req.shm_cur_output_len, cache_len=req.prompt_cache_len, mtp_accepted=req.mtp_accepted_token_num, ) diff --git a/lightllm/server/router/stats.py b/lightllm/server/router/stats.py index 85548d9138..f6db924b53 100644 --- a/lightllm/server/router/stats.py +++ b/lightllm/server/router/stats.py @@ -1,5 +1,6 @@ import time import logging +from typing import Dict from lightllm.server.core.objs import StartArgs from lightllm.utils.log_utils import init_system_status_logger @@ -31,13 +32,23 @@ def __init__(self, args, max_total_token_num, dp_size_in_node): self.global_mtp_output_total = 0 self.global_mtp_accepted_total = 0 + # Per-req shm_cur_output_len snapshot at the previous window boundary, + # used to compute the windowed output-token count without per-tick scans. + self._req_last_output_len: Dict[int, int] = {} + def count_prompt_tokens(self, num_tokens: int): if self.enabled: self.prompt_tokens += num_tokens - def count_output_tokens(self, num_tokens: int): - if self.enabled: - self.output_tokens += num_tokens + def discard_req(self, req): + """Settle a finished/aborted req's tail output tokens (those produced after the last + window-boundary sweep) and drop its tracking entry.""" + if not self.enabled: + return + cur_out_len = req.shm_cur_output_len + prev_out_len = self._req_last_output_len.pop(req.request_id, 0) + if cur_out_len > prev_out_len: + self.output_tokens += cur_out_len - prev_out_len def on_request_completed(self, input_len: int, output_len: int, cache_len: int, mtp_accepted: int): if self.enabled: @@ -64,6 +75,17 @@ def maybe_print( if elapsed < self.interval: return + # Single bulk sweep at the window boundary: account for output tokens produced + # by every still-running req since the previous boundary, and refresh their + # snapshots. Reqs that finished in this window already settled via discard_req. + if running_batch is not None: + for req in running_batch.reqs: + cur_out_len = req.shm_cur_output_len + prev_out_len = self._req_last_output_len.get(req.request_id, 0) + if cur_out_len > prev_out_len: + self.output_tokens += cur_out_len - prev_out_len + self._req_last_output_len[req.request_id] = cur_out_len + total_tps = (self.prompt_tokens + self.output_tokens) / elapsed input_tps = self.prompt_tokens / elapsed output_tps = self.output_tokens / elapsed From 09a488b46f39c8139ad0a75d460326c88bcc1573 Mon Sep 17 00:00:00 2001 From: sufubao Date: Sat, 9 May 2026 11:46:40 +0800 Subject: [PATCH 06/10] fix(httpserver,router): defensive group_request_id init; reorder is_aborted skip - httpserver: initialize group_request_id=None so the ValueError except handler does not hit UnboundLocalError when the oversized-prompt guard raises before alloc_req_id. - router: move the is_aborted skip after on_request_completed so aborted reqs still update completion stats, but do not pollute the router_statics EMA with their truncated output_len. 
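As a minimal illustration of the UnboundLocalError hazard the first point
fixes (the names and the length guard below are simplified placeholders, not
the real httpserver code path):

    def generate(prompt: str, max_prompt_chars: int = 131072):
        group_request_id = None  # defensive init: the guard below may raise first
        try:
            if len(prompt) > max_prompt_chars:
                raise ValueError(
                    f"prompt text length {len(prompt)} exceeds {max_prompt_chars}"
                )
            group_request_id = 42  # normally allocated only after validation
            return group_request_id
        except ValueError:
            # Without the init above, this reference would itself raise
            # UnboundLocalError and mask the real rejection reason.
            print(f"rejecting request, group_request_id={group_request_id}")
            raise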
--- lightllm/server/httpserver/manager.py | 4 ++++ lightllm/server/router/manager.py | 8 ++++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 4667f996b5..e068425982 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -298,7 +298,11 @@ async def generate( # 用于等待 pd_master 下发的交换信息 nixl_pd_event: asyncio.Event = None, ) -> AsyncGenerator[Tuple[int, str, dict, FinishStatus], None]: + group_request_id = None if isinstance(prompt, str): + # Guard against extremely long string prompts that might stall the tokenizer + # or cause excessive memory usage before tokenization. + # 8 characters per token is a conservative heuristic (avg is ~4). max_prompt_chars = self.max_req_total_len * 8 if len(prompt) > max_prompt_chars: raise ValueError( diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 8009063674..af892d618e 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -347,16 +347,16 @@ def _filter_reqs_from_running_batch(self): # Settle any output-token tail produced after the last window boundary, # so windowed TPS does not lose the req's last tokens. self.status_reporter.discard_req(req) - # Aborted/disconnected requests can leave a partial output_len that - # would bias the EMA toward shorter generations; skip them. - if req.is_aborted: - continue self.status_reporter.on_request_completed( input_len=req.input_len, output_len=req.shm_cur_output_len, cache_len=req.prompt_cache_len, mtp_accepted=req.mtp_accepted_token_num, ) + # Aborted/disconnected requests can leave a partial output_len that + # would bias the EMA toward shorter generations; skip them. 
+ if req.is_aborted: + continue self.router_statics.update(req.candetoken_out_len) self.running_batch.filter_out_finished_req(self.shm_req_manager) if self.running_batch.is_clear(): From ee5fa88110592b6be27dcb80fb28071f5ee64c1d Mon Sep 17 00:00:00 2001 From: sufubao Date: Sat, 9 May 2026 13:40:57 +0800 Subject: [PATCH 07/10] fix(detokenization): keep req registration failures fatal --- lightllm/server/detokenization/manager.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/lightllm/server/detokenization/manager.py b/lightllm/server/detokenization/manager.py index b27c1c95bc..1cb5357562 100644 --- a/lightllm/server/detokenization/manager.py +++ b/lightllm/server/detokenization/manager.py @@ -76,10 +76,7 @@ def handle_loop(self): for _ in range(recv_max_count): recv_obj: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) assert isinstance(recv_obj, GroupReqIndexes) - try: - self._add_new_group_req_index(recv_obj=recv_obj) - except Exception: - logger.exception("add new group req index has exception") + self._add_new_group_req_index(recv_obj=recv_obj) # 当队列中存在较多的请求时,将一次接受的数量上调 recv_max_count = min(int(recv_max_count * 1.3), 256) From da6fc3eedcc35cb25f99586069a690fe783c1fd1 Mon Sep 17 00:00:00 2001 From: sufubao Date: Sat, 9 May 2026 13:58:36 +0800 Subject: [PATCH 08/10] refactor(router): improve status log format --- lightllm/server/router/stats.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/lightllm/server/router/stats.py b/lightllm/server/router/stats.py index f6db924b53..dac37b67b2 100644 --- a/lightllm/server/router/stats.py +++ b/lightllm/server/router/stats.py @@ -122,20 +122,24 @@ def maybe_print( kv_pct = avg_kv_used * 100 kv_pct_no_cache = avg_kv_used_no_cache * 100 + log_parts = [ + f"router_status(window={elapsed:.1f}s)", + f"throughput(total={total_tps:.1f},input={input_tps:.1f},output={output_tps:.1f})", + f"req(running={running},waiting={queued},paused={paused_req_num})", + f"kv(used={kv_pct_no_cache:.1f}%)", + f"gpu_cache_hit(window={window_cache_hit_rate:.1f}%,global={global_cache_hit_rate:.1f}%)", + ] + # Avg MTP accepted length (only shown when MTP is active) - mtp_suffix = "" if self.global_mtp_accepted_total > 0: decode_steps = self.global_mtp_output_total - self.global_mtp_accepted_total avg_mtp_len = self.global_mtp_output_total / max(decode_steps, 1) - mtp_suffix = f", MTP {avg_mtp_len:.2f}" - - self.status_logger.info( - f"TPS {total_tps:.1f} (in {input_tps:.1f}, out {output_tps:.1f}), " - f"REQ {running}run, {queued}wait, {paused_req_num}pause, " - f"KV CACHE {kv_pct:.1f}% (active {kv_pct_no_cache:.1f}%), " - f"CACHE HIT {window_cache_hit_rate:.1f}% (global {global_cache_hit_rate:.1f}%)" - f"{mtp_suffix}" - ) + log_parts.append( + f"mtp(avg_tokens_per_step={avg_mtp_len:.2f}," + f"accepted={self.global_mtp_accepted_total},output={self.global_mtp_output_total})" + ) + + self.status_logger.info(" | ".join(log_parts)) # Reset windowed counters self.prompt_tokens = 0 From 1b8274c85b5d5e8fbec3f7ce866ce7684575b8b3 Mon Sep 17 00:00:00 2001 From: sufubao Date: Sat, 9 May 2026 14:09:17 +0800 Subject: [PATCH 09/10] fix(router): remove unused status log variable --- lightllm/server/router/stats.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lightllm/server/router/stats.py b/lightllm/server/router/stats.py index dac37b67b2..8aac2b0360 100644 --- a/lightllm/server/router/stats.py +++ b/lightllm/server/router/stats.py @@ -119,7 +119,6 @@ def maybe_print( (self.global_cache_total / 
self.global_input_total * 100) if self.global_input_total > 0 else 0.0 ) - kv_pct = avg_kv_used * 100 kv_pct_no_cache = avg_kv_used_no_cache * 100 log_parts = [ From 213eac3e51fb3bdbc396572f7b5b48c21bc3c719 Mon Sep 17 00:00:00 2001 From: sufubao Date: Sat, 9 May 2026 14:47:17 +0800 Subject: [PATCH 10/10] feat(router): add debug status details --- lightllm/server/router/stats.py | 53 +++++++++++++++++++++++++ lightllm/server/visualserver/manager.py | 2 +- 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/lightllm/server/router/stats.py b/lightllm/server/router/stats.py index 8aac2b0360..c414a43861 100644 --- a/lightllm/server/router/stats.py +++ b/lightllm/server/router/stats.py @@ -1,5 +1,6 @@ import time import logging +import subprocess from typing import Dict from lightllm.server.core.objs import StartArgs from lightllm.utils.log_utils import init_system_status_logger @@ -59,6 +60,44 @@ def on_request_completed(self, input_len: int, output_len: int, cache_len: int, self.global_mtp_output_total += output_len self.global_mtp_accepted_total += mtp_accepted + def _get_gpu_status_for_debug(self) -> str: + try: + result = subprocess.run( + [ + "nvidia-smi", + "--query-gpu=index,utilization.gpu,memory.used,memory.total", + "--format=csv,noheader,nounits", + ], + check=True, + capture_output=True, + text=True, + timeout=2, + ) + except (OSError, subprocess.SubprocessError) as e: + return f"gpu=unavailable({e.__class__.__name__})" + + gpu_infos = [] + for line in result.stdout.splitlines(): + parts = [part.strip() for part in line.split(",")] + if len(parts) != 4: + continue + gpu_index, util, mem_used, mem_total = parts + try: + mem_used_mb = float(mem_used) + mem_total_mb = float(mem_total) + mem_ratio = mem_used_mb / mem_total_mb * 100 if mem_total_mb > 0 else 0.0 + mem_used_gb = mem_used_mb / 1024 + mem_total_gb = mem_total_mb / 1024 + gpu_infos.append( + f"{gpu_index}(util={float(util):.0f}%,mem={mem_ratio:.1f}%," + f"used={mem_used_gb:.1f}GiB/{mem_total_gb:.1f}GiB)" + ) + except ValueError: + continue + if not gpu_infos: + return "gpu=unavailable(empty)" + return "gpu=[" + ";".join(gpu_infos) + "]" + def maybe_print( self, running_batch, @@ -119,6 +158,7 @@ def maybe_print( (self.global_cache_total / self.global_input_total * 100) if self.global_input_total > 0 else 0.0 ) + kv_pct = avg_kv_used * 100 kv_pct_no_cache = avg_kv_used_no_cache * 100 log_parts = [ @@ -139,6 +179,19 @@ def maybe_print( ) self.status_logger.info(" | ".join(log_parts)) + if logger.isEnabledFor(logging.DEBUG): + kv_unrefed_prefix_cache_pct = max(0.0, kv_pct - kv_pct_no_cache) + debug_parts = [ + "router_status_debug", + f"kv_physical={kv_pct:.1f}%", + f"kv_unrefed_prefix_cache={kv_unrefed_prefix_cache_pct:.1f}%", + f"throughput_tokens(input={self.prompt_tokens},output={self.output_tokens})", + f"gpu_cache_tokens(window={self.window_cache_total}/{self.window_input_total}," + f"global={self.global_cache_total}/{self.global_input_total})", + f"tracked_output_reqs={len(self._req_last_output_len)}", + self._get_gpu_status_for_debug(), + ] + logger.debug(" | ".join(debug_parts)) # Reset windowed counters self.prompt_tokens = 0 diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index 1dffdaf681..ef86378533 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -176,7 +176,7 @@ async def loop_for_netio_req(self): while True: recv_req: GroupReqIndexes = await asyncio.to_thread(self.zmq_recv_socket.recv_pyobj) if 
isinstance(recv_req, GroupReqIndexes): - logger.info( + logger.debug( f"visual recv req id {recv_req.group_req_id} " f"img count {len(recv_req.multimodal_params.images)}" )
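
As a closing reference, a toy calculation showing why the new status line
reports a windowed prefix-cache hit rate next to the global one (the numbers
below are made up, not measurements):

    # two consecutive print windows; the second sees much better cache reuse
    windows = [
        {"input": 10_000, "cached": 1_000},  # window 1: 10% hit
        {"input": 8_000, "cached": 6_000},   # window 2: 75% hit
    ]
    global_input = 0
    global_cached = 0
    for i, w in enumerate(windows, start=1):
        global_input += w["input"]
        global_cached += w["cached"]
        window_rate = w["cached"] / w["input"] * 100
        global_rate = global_cached / global_input * 100
        print(f"window {i}: gpu_cache_hit(window={window_rate:.1f}%,global={global_rate:.1f}%)")

    # window 1: gpu_cache_hit(window=10.0%,global=10.0%)
    # window 2: gpu_cache_hit(window=75.0%,global=38.9%)

The global rate lags behind shifts in traffic, while the windowed rate tracks
the current workload, which is what the SystemStatusReporter log line is
meant to surface.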