diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py
index dd29c9a833..40e3ad9e63 100644
--- a/lightllm/common/basemodel/cuda_graph.py
+++ b/lightllm/common/basemodel/cuda_graph.py
@@ -3,6 +3,7 @@
 import copy
 import bisect
 from typing import Optional
+from tqdm import tqdm
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.distributed import dist_group_manager, lightllm_capture_graph, CustomProcessGroup
@@ -196,7 +197,11 @@ def warmup(self, model):
         model: TpPartBaseModel = model

         # decode cuda graph init
-        for batch_size in self.cuda_graph_batch_sizes[::-1]:
+        progress_bar = tqdm(self.cuda_graph_batch_sizes[::-1], desc="Capturing CUDA graphs")
+        for batch_size in progress_bar:
+            avail_mem, _ = torch.cuda.mem_get_info()
+            avail_mem_gb = avail_mem / (1024 ** 3)
+            progress_bar.set_description(f"Capturing CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB")
             seq_len = 2
             total_token_num = batch_size * seq_len
             max_len_in_batch = self.graph_max_len_in_batch
@@ -251,7 +256,13 @@ def warmup_overlap(self, model):

         model: TpPartBaseModel = model

-        for batch_size in self.cuda_graph_batch_sizes[::-1]:
+        progress_bar = tqdm(self.cuda_graph_batch_sizes[::-1], desc="Capturing overlap CUDA graphs")
+        for batch_size in progress_bar:
+            avail_mem, _ = torch.cuda.mem_get_info()
+            avail_mem_gb = avail_mem / (1024 ** 3)
+            progress_bar.set_description(
+                f"Capturing overlap CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB"
+            )
             decode_batches = []
             for micro_batch_index in [0, 1]:
                 # dummy decoding, capture the cudagraph
diff --git a/lightllm/common/triton_utils/autotuner.py b/lightllm/common/triton_utils/autotuner.py
index c62a2572ff..3cbc5dc0f3 100644
--- a/lightllm/common/triton_utils/autotuner.py
+++ b/lightllm/common/triton_utils/autotuner.py
@@ -215,7 +215,7 @@ def _try_load_cache(self, static_key):
         cache_file = os.path.join(self.cache_dir, KernelConfigs.get_config_file_name(static_key))

         if os.path.exists(cache_file):
-            logger.info(f"Loading cached configs for {self.kernel_name} - {static_key}")
+            logger.info(f"Loading cached configs for {self.kernel_name} - {dict(static_key)}")
             with open(cache_file, "rb") as f:
                 self.cached_configs[static_key] = orjson.loads(f.read())
             return True
@@ -353,9 +353,9 @@ def _autotune(self, args, kwargs, static_key, run_key, rank_id, world_size):
                        option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS | orjson.OPT_NON_STR_KEYS,
                    )
                )
-            logger.info(f"Saved configs for {self.kernel_name} - {_static_key}")
+            logger.info(f"Saved configs for {self.kernel_name} - {dict(_static_key)}")

-        logger.info(f"rank {rank_id} tuning {self.kernel_name} _static_key {static_key} finished")
+        logger.info(f"rank {rank_id} tuning {self.kernel_name} _static_key {dict(static_key)} finished")

     def _mutate_args_clone(self, args, kwargs):
         origin_list = []
diff --git a/lightllm/server/detokenization/manager.py b/lightllm/server/detokenization/manager.py
index 389171ba8a..1cb5357562 100644
--- a/lightllm/server/detokenization/manager.py
+++ b/lightllm/server/detokenization/manager.py
@@ -52,7 +52,7 @@ def _add_new_group_req_index(self, recv_obj: GroupReqIndexes):
             req.link_prompt_ids_shm_array()
             req.link_logprobs_shm_array()

-            logger.info(
+            logger.debug(
                 f"detokenization recv req id {req.request_id} " f"cost time {time.time() - recv_obj.time_mark} s"
             )

@@ -160,7 +160,7 @@ def remove_finished_reqs(self):

         for decode_req in finished_reqs:
             decode_req.req.can_released_mark = True
-            logger.info(f"detoken release req id {decode_req.req.request_id}")
logger.info(f"detoken release req id {decode_req.req.request_id}") + logger.debug(f"detoken release req id {decode_req.req.request_id}") self.shm_req_manager.put_back_req_obj(decode_req.req) self.req_id_to_out.pop(decode_req.request_id, None) return diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 212e037e90..b0a755090a 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -414,7 +414,7 @@ async def _log_req_header(self, request_headers, group_request_id: int): x_session_id = request_headers.get("X-Session-Id", "") format_in_time = datetime.datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d %H:%M:%S") - logger.info( + logger.debug( f"recieved req X-Request-Id:{x_request_id} " f"X-Session-Id:{x_session_id} start_time:{format_in_time} " f"lightllm_req_id:{group_request_id} " @@ -611,7 +611,7 @@ async def _wait_to_token_package( (out_token_counter - sum(sub_req_id_to_mtp_accepted_token_num.values())), 1 ) format_start_time = datetime.datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S") - logger.info( + logger.debug( f"X-Request-Id:{x_request_id} " f"X-Session-Id:{x_session_id} start_time:{format_start_time} " f"lightllm_req_id:{group_request_id} first_token_cost:{first_token_cost_ms}ms " @@ -698,8 +698,8 @@ async def recycle_resource_loop(self): if req_status is None: continue - logger.info( - f"left req id {req_status.group_req_objs.group_req_id}" + logger.debug( + f"left req id {req_status.group_req_objs.group_req_id} " f"can release {req_status.group_req_objs.shm_req_objs[0].can_released_mark} " f"refcount {req_status.group_req_objs.shm_req_objs[0].ref_count}" ) diff --git a/lightllm/server/router/batch.py b/lightllm/server/router/batch.py index 40529a3f5d..f0f9576930 100644 --- a/lightllm/server/router/batch.py +++ b/lightllm/server/router/batch.py @@ -53,7 +53,7 @@ def filter_out_finished_req(self, shm_req_manager: ShmReqManager): unfinished_req_ids = [] for req in self.reqs: if req.shm_infer_released: - logger.info(f"router release req id {req.request_id}") + logger.debug(f"router release req id {req.request_id}") shm_req_manager.put_back_req_obj(req) req = None else: diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index ac5c1abee3..c94a6cf6da 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -16,6 +16,7 @@ from .batch import Batch, Req from .model_infer.model_rpc import start_model_process, ModelRpcClient from .req_queue import build_req_queue +from .stats import SystemStatusReporter from lightllm.server.core.objs.io_objs import ( GroupReqIndexes, AbortedReqCmd, @@ -25,7 +26,7 @@ from .dynamic_prompt.radix_cache import RadixCacheReadOnlyClient from lightllm.server.multi_level_kv_cache.cpu_cache_client import CpuKvCacheClient from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer -from lightllm.utils.log_utils import init_logger, log_time_ready +from lightllm.utils.log_utils import init_logger from lightllm.server.router.token_load import TokenLoad from lightllm.server.metrics.manager import MetricClient from lightllm.common.basemodel.infer_lock import g_router_lock @@ -64,6 +65,7 @@ def __init__(self, args: StartArgs): self.read_only_statics_mem_manager = ReadOnlyStaticsMemoryManager() # 初始化 radix_cache_client 用于读取 prompt cache 的管理信息 self.radix_cache_client = None + self.status_reporter = None # 共享变量,用于存储router端调度分析得到的机器负载信息 self.shared_token_load = 
TokenLoad(f"{get_unique_server_name()}_shared_token_load", self.dp_size_in_node) @@ -192,6 +194,11 @@ async def wait_to_model_ready(self): ) self.req_queue = build_req_queue(self.args, self, self.dp_size_in_node) logger.info(f"use req queue {self.req_queue.__class__.__name__}") + self.status_reporter = SystemStatusReporter( + args=self.args, + max_total_token_num=self.max_total_token_num, + dp_size_in_node=self.dp_size_in_node, + ) if self.args.run_mode == "prefill": # 启动 prefill kv move 管理进程 @@ -237,26 +244,10 @@ async def loop_for_fwd( await self._step() counter_count += 1 if self.running_batch is not None: + # Count output tokens (each running req produces ~1 token per decode step) + self.status_reporter.count_output_tokens(len(self.running_batch.reqs)) if counter_count % 100 == 0: - for dp_index in range(self.dp_size_in_node): - token_ratio1 = self.get_used_tokens(dp_index) / self.max_total_token_num - token_ratio2 = ( - self.max_total_token_num - - self.read_only_statics_mem_manager.get_unrefed_token_num(dp_index) - ) / self.max_total_token_num - d_i = dp_index - frozen_token_num = self.shared_token_load.get_frozened_token_count(d_i) - estimated_peak_token_count = self.shared_token_load.get_estimated_peak_token_count(d_i) - paused_req_num = self._get_paused_req_num_in_dp_index(dp_index=d_i) - logger.debug( - f"dp_i {d_i} current batch size: {len(self.running_batch.reqs)} \n" - f"dp_i {d_i} paused req num: {paused_req_num} \n" - f"dp_i {d_i} frozen token num: {frozen_token_num} \n" - f"dp_i {d_i} estimated_peak_token_count: {estimated_peak_token_count} \n" - f"dp_i {d_i} token used ratio: {token_ratio1} not contain prompt cache tree unrefed token\n" - f"dp_i {d_i} token used ratio: {token_ratio2} contain prompt cache tree unrefed token" - ) - self.metric_client.gauge_set("lightllm_batch_pause_size", paused_req_num) + self.metric_client.gauge_set("lightllm_batch_pause_size", self._get_paused_req_num()) # pd decode mode need to update token_load more frequently self.req_queue.update_token_load(self.running_batch, force_update=self.is_pd_decode_mode) self.metric_client.gauge_set("lightllm_batch_current_size", len(self.running_batch.reqs)) @@ -275,13 +266,15 @@ async def loop_for_fwd( self.metric_client.gauge_set("lightllm_batch_pause_size", 0.0) self.metric_client.gauge_set("lightllm_queue_size", 0.0) self.metric_client.gauge_set("lightllm_batch_current_max_tokens", 0.0) - # 60s print once - if log_time_ready("frozen_info", 60): - for dp_i in range(self.dp_size_in_node): - frozen_token_num = self.shared_token_load.get_frozened_token_count(dp_i) - estimated_peak_token_count = self.shared_token_load.get_estimated_peak_token_count(dp_i) - logger.debug(f"dp_i {dp_i} frozen token num: {frozen_token_num} \n") - logger.debug(f"dp_i {dp_i} estimated_peak_token_count: {estimated_peak_token_count} \n") + + self.status_reporter.maybe_print( + running_batch=self.running_batch, + req_queue=self.req_queue, + read_only_statics_mem_manager=self.read_only_statics_mem_manager, + paused_req_num=self._get_paused_req_num(), + radix_cache_client=self.radix_cache_client, + disable_dynamic_prompt_cache=self.args.disable_dynamic_prompt_cache, + ) await asyncio.sleep(self._get_schedule_time_interval()) @@ -311,6 +304,7 @@ async def _step(self): async def _add_batch(self, batch: Batch): # 添加新请求 + self.status_reporter.count_prompt_tokens(batch.input_tokens()) reqs = [r.to_router_rpc_obj() for r in batch.reqs] while not self.shm_reqs_io_buffer.is_empty(): await asyncio.sleep(0.02) @@ -347,6 +341,15 @@ def 
@@ -347,6 +341,15 @@ def _add_new_batch_to_running_batch(self, new_batch: Batch):

     def _filter_reqs_from_running_batch(self):
         if self.running_batch is not None:
+            # Capture finished req stats before filtering
+            for req in self.running_batch.reqs:
+                if req.shm_infer_released:
+                    self.status_reporter.on_request_completed(
+                        input_len=req.input_len,
+                        output_len=req.shm_cur_output_len,
+                        cache_len=req.prompt_cache_len,
+                        mtp_accepted=req.mtp_accepted_token_num,
+                    )
             self.running_batch.filter_out_finished_req(self.shm_req_manager)
             if self.running_batch.is_clear():
                 self.running_batch = None
@@ -419,7 +422,7 @@ def _add_req(self, group_req_indexes: GroupReqIndexes):
             req._router_stop_str_matched = False
             req_group.append(req)

-            logger.info(f"router recive req id {req.request_id} cost time {time.time() - req.start_time} s")
+            logger.debug(f"router receive req id {req.request_id} cost time {time.time() - req.start_time} s")
         self.req_queue.extend(req_group)
         self.send_to_detokenization.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL)
         return
diff --git a/lightllm/server/router/stats.py b/lightllm/server/router/stats.py
index d50c4e7ca5..ce003d7a4e 100644
--- a/lightllm/server/router/stats.py
+++ b/lightllm/server/router/stats.py
@@ -1,47 +1,110 @@
 import time
-from lightllm.utils.log_utils import init_logger
-from .batch import Batch
+import logging
+from lightllm.utils.log_utils import init_system_status_logger

-logger = init_logger(__name__)
+logger = logging.getLogger(__name__)


-class Stats:
-    def __init__(self, log_status, log_stats_interval) -> None:
-        self.log_stats = log_status
-        self.log_stats_interval = log_stats_interval
-        self.last_log_time = time.time()
-        self.all_tokens = 0
-        self.output_tokens = 0
+class SystemStatusReporter:
+    def __init__(self, args, max_total_token_num, dp_size_in_node):
+        self.enabled = not args.disable_log_stats
+        self.interval = max(5, args.log_stats_interval)
+        if args.log_stats_interval < 5:
+            logger.warning(f"log_stats_interval={args.log_stats_interval}s is below minimum, using 5s")
+        self.max_total_token_num = max_total_token_num
+        self.dp_size_in_node = dp_size_in_node
+        self.status_logger = init_system_status_logger("router")
+
+        # Accumulation counters (reset each interval)
+        self.last_print_time = time.time()
         self.prompt_tokens = 0
-        return
-
-    def count_prompt_tokens(self, run_batch: Batch):
-        if self.log_stats and run_batch is not None:
-            tokens = run_batch.input_tokens()
-            self.prompt_tokens += tokens
-            self.all_tokens += tokens
-        return
-
-    def count_output_tokens(self, run_batch: Batch):
-        if self.log_stats and run_batch is not None:
-            tokens = len(run_batch.reqs)
-            self.output_tokens += tokens
-            self.all_tokens += tokens
-        return
-
-    def print_stats(self):
-        if not self.log_stats:
-            return
+        self.output_tokens = 0
+
+        # Global counters (never reset, for lifetime stats)
+        self.global_input_total = 0
+        self.global_cache_total = 0
+        self.global_mtp_output_total = 0
+        self.global_mtp_accepted_total = 0

+    def count_prompt_tokens(self, num_tokens: int):
+        if self.enabled:
+            self.prompt_tokens += num_tokens
+
+    def count_output_tokens(self, num_tokens: int):
+        if self.enabled:
+            self.output_tokens += num_tokens
+
+    def on_request_completed(self, input_len: int, output_len: int, cache_len: int, mtp_accepted: int):
+        if self.enabled:
+            self.global_input_total += input_len
+            self.global_cache_total += cache_len
+            self.global_mtp_output_total += output_len
+            self.global_mtp_accepted_total += mtp_accepted
+
+    def maybe_print(
+        self,
+        running_batch,
+        req_queue,
+        read_only_statics_mem_manager,
+        paused_req_num=0,
+        radix_cache_client=None,
+        disable_dynamic_prompt_cache=False,
+    ):
+        if not self.enabled:
+            return
         now = time.time()
-        if now - self.last_log_time > self.log_stats_interval:
-            logger.debug(
-                f"Avg tokens(prompt+generate) throughput: {self.all_tokens/(now-self.last_log_time):8.3f} tokens/s\n"
-                f"Avg prompt tokens throughput: {self.prompt_tokens/(now-self.last_log_time):8.3f} tokens/s\n"
-                f"Avg generate tokens throughput: {self.output_tokens/(now-self.last_log_time):8.3f} tokens/s"
-            )
-            self.all_tokens = 0
-            self.output_tokens = 0
-            self.prompt_tokens = 0
-            self.last_log_time = now
-        return
+        elapsed = now - self.last_print_time
+        if elapsed < self.interval:
+            return
+
+        total_tps = (self.prompt_tokens + self.output_tokens) / elapsed
+        input_tps = self.prompt_tokens / elapsed
+        output_tps = self.output_tokens / elapsed
+
+        running = len(running_batch.reqs) if running_batch else 0
+        queued = req_queue.get_wait_req_num()
+
+        # Memory utilization (average across dp)
+        # kv_used: physical KV memory usage (includes prefix cache tree occupancy)
+        # kv_used_no_cache: effective usage excluding unrefed prefix cache tokens
+        kv_used_list = []
+        kv_used_no_cache_list = []
+        for dp_i in range(self.dp_size_in_node):
+            unrefed = read_only_statics_mem_manager.get_unrefed_token_num(dp_i)
+            used = self.max_total_token_num - unrefed
+            kv_used_list.append(used / self.max_total_token_num)
+            if not disable_dynamic_prompt_cache and radix_cache_client is not None:
+                cache_unrefed = radix_cache_client.get_unrefed_tokens_num(dp_i)
+                kv_used_no_cache_list.append((used - cache_unrefed) / self.max_total_token_num)
+            else:
+                kv_used_no_cache_list.append(used / self.max_total_token_num)
+        avg_kv_used = sum(kv_used_list) / len(kv_used_list)
+        avg_kv_used_no_cache = sum(kv_used_no_cache_list) / len(kv_used_no_cache_list)
+
+        # Global prefix cache hit rate
+        cache_hit_rate = (
+            (self.global_cache_total / self.global_input_total * 100) if self.global_input_total > 0 else 0.0
+        )
+
+        kv_pct = avg_kv_used * 100
+        kv_pct_no_cache = avg_kv_used_no_cache * 100
+
+        # Avg MTP accepted length (only shown when MTP is active)
+        mtp_suffix = ""
+        if self.global_mtp_accepted_total > 0:
+            decode_steps = self.global_mtp_output_total - self.global_mtp_accepted_total
+            avg_mtp_len = self.global_mtp_output_total / max(decode_steps, 1)
+            mtp_suffix = f" | MTP {avg_mtp_len:.2f}"
+
+        self.status_logger.info(
+            f"Throughput {total_tps:>7.1f} tok/s (in {input_tps:.1f}, out {output_tps:.1f}) | "
+            f"Reqs {running} run, {queued} wait, {paused_req_num} pause | "
+            f"KV Cache {kv_pct:.1f}% (active {kv_pct_no_cache:.1f}%) | "
+            f"Prefix Hit {cache_hit_rate:.1f}%"
+            f"{mtp_suffix}"
+        )
+
+        # Reset windowed counters
+        self.prompt_tokens = 0
+        self.output_tokens = 0
+        self.last_print_time = now
diff --git a/lightllm/utils/log_utils.py b/lightllm/utils/log_utils.py
index f15309d5cf..1cb7aaecee 100644
--- a/lightllm/utils/log_utils.py
+++ b/lightllm/utils/log_utils.py
@@ -4,13 +4,14 @@
 import logging
 import sys
 import os
-import time
 from typing import Optional

 _FORMAT = "%(levelname)s %(asctime)s [%(filename)s:%(lineno)d] %(message)s"
 _DATE_FORMAT = "%m-%d %H:%M:%S"
-_LOG_LEVEL = os.environ.get("LIGHTLLM_LOG_LEVEL", "debug")
+_STATUS_FORMAT = "%(levelname)s [%(asctime)s] %(message)s"
+
+_LOG_LEVEL = os.environ.get("LIGHTLLM_LOG_LEVEL", "info")
 _LOG_LEVEL = getattr(logging, _LOG_LEVEL.upper(), 0)
 _LOG_DIR = os.environ.get("LIGHTLLM_LOG_DIR", None)
@@ -96,22 +97,19 @@ def init_logger(name: str):
     return logger


-_log_time_mark_dict = {}
-
-
-def log_time_ready(mark_name, time_count: int):
-    """
-    time_count 间隔时间超过多少s调用该函数会返回True,否则返回False
-    用于控制一些日志输出的频率
-    """
-    global _log_time_mark_dict
-
-    if mark_name not in _log_time_mark_dict:
-        _log_time_mark_dict[mark_name] = time.time()
-        return False
-    cur_time_mark = time.time()
-    if cur_time_mark - _log_time_mark_dict[mark_name] >= time_count:
-        _log_time_mark_dict[mark_name] = cur_time_mark
-        return True
-    else:
-        return False
+def init_system_status_logger(name: str):
+    logger = logging.getLogger(f"lightllm.status.{name}")
+    if not logger.handlers:
+        logger.setLevel(logging.INFO)
+        fmt = logging.Formatter(_STATUS_FORMAT, datefmt=_DATE_FORMAT)
+        handler = logging.StreamHandler(sys.stdout)
+        handler.flush = sys.stdout.flush
+        handler.setFormatter(fmt)
+        logger.addHandler(handler)
+        if _LOG_DIR is not None:
+            file_handler = logging.FileHandler(os.path.join(_LOG_DIR, f"status.{name}.log"))
+            file_handler.setLevel(logging.INFO)
+            file_handler.setFormatter(fmt)
+            logger.addHandler(file_handler)
+    logger.propagate = False
+    return logger
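
Note: the snippet below is a minimal usage sketch and is not part of the patch. It shows how the init_system_status_logger helper added above can be exercised once this patch is applied; the sample message and all numbers in it are illustrative only, mirroring the shape of the line that SystemStatusReporter.maybe_print builds from its counters roughly every max(5, log_stats_interval) seconds.

    from lightllm.utils.log_utils import init_system_status_logger

    # Dedicated status logger: writes to stdout, plus status.<name>.log when LIGHTLLM_LOG_DIR is set.
    status_logger = init_system_status_logger("router")

    # Illustrative example of the periodic status line emitted by the router's reporter.
    status_logger.info(
        "Throughput  1234.5 tok/s (in 1000.0, out 234.5) | "
        "Reqs 8 run, 2 wait, 0 pause | "
        "KV Cache 41.3% (active 37.9%) | "
        "Prefix Hit 62.0%"
    )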