15 changes: 13 additions & 2 deletions lightllm/common/basemodel/cuda_graph.py
@@ -4,6 +4,7 @@
import bisect
import triton
from typing import Optional
from tqdm import tqdm
from lightllm.utils.log_utils import init_logger
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.distributed import dist_group_manager
@@ -197,7 +198,11 @@ def warmup(self, model):
model: TpPartBaseModel = model

# decode cuda graph init
for batch_size in self.cuda_graph_batch_sizes[::-1]:
progress_bar = tqdm(self.cuda_graph_batch_sizes[::-1], desc="Capturing CUDA graphs")
for batch_size in progress_bar:
avail_mem, _ = torch.cuda.mem_get_info()
avail_mem_gb = avail_mem / (1024 ** 3)
progress_bar.set_description(f"Capturing CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB")
seq_len = 2
total_token_num = batch_size * seq_len
max_len_in_batch = self.graph_max_len_in_batch
@@ -252,7 +257,13 @@ def warmup_overlap(self, model):

model: TpPartBaseModel = model

for batch_size in self.cuda_graph_batch_sizes[::-1]:
progress_bar = tqdm(self.cuda_graph_batch_sizes[::-1], desc="Capturing overlap CUDA graphs")
for batch_size in progress_bar:
avail_mem, _ = torch.cuda.mem_get_info()
avail_mem_gb = avail_mem / (1024 ** 3)
progress_bar.set_description(
f"Capturing overlap CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB"
)
decode_batches = []
for micro_batch_index in [0, 1]:
# dummy decoding, capture the cudagraph
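For reference, a minimal standalone sketch of the progress-bar pattern added above: wrap the reversed batch-size list in `tqdm` and refresh the bar description with the free device memory reported by `torch.cuda.mem_get_info()` before each capture. `capture_decode_graph` is a hypothetical stand-in for the real per-batch-size capture step inside `warmup`.

```python
import torch
from tqdm import tqdm

def warmup_with_progress(cuda_graph_batch_sizes, capture_decode_graph):
    # Capture from the largest batch size down, as the warmup loop above does.
    progress_bar = tqdm(cuda_graph_batch_sizes[::-1], desc="Capturing CUDA graphs")
    for batch_size in progress_bar:
        # mem_get_info() returns (free_bytes, total_bytes) for the current device.
        avail_mem, _ = torch.cuda.mem_get_info()
        avail_mem_gb = avail_mem / (1024 ** 3)
        progress_bar.set_description(
            f"Capturing CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB"
        )
        capture_decode_graph(batch_size)
```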
6 changes: 3 additions & 3 deletions lightllm/common/triton_utils/autotuner.py
@@ -215,7 +215,7 @@ def _try_load_cache(self, static_key):

cache_file = os.path.join(self.cache_dir, KernelConfigs.get_config_file_name(static_key))
if os.path.exists(cache_file):
logger.info(f"Loading cached configs for {self.kernel_name} - {static_key}")
logger.info(f"Loading cached configs for {self.kernel_name} - {dict(static_key)}")
with open(cache_file, "rb") as f:
self.cached_configs[static_key] = orjson.loads(f.read())
return True
@@ -353,9 +353,9 @@ def _autotune(self, args, kwargs, static_key, run_key, rank_id, world_size):
option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS | orjson.OPT_NON_STR_KEYS,
)
)
logger.info(f"Saved configs for {self.kernel_name} - {_static_key}")
logger.info(f"Saved configs for {self.kernel_name} - {dict(_static_key)}")

logger.info(f"rank {rank_id} tuning {self.kernel_name} _static_key {static_key} finished")
logger.info(f"rank {rank_id} tuning {self.kernel_name} _static_key {dict(static_key)} finished")

def _mutate_args_clone(self, args, kwargs):
origin_list = []
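The logging tweak above is easiest to see in isolation. Assuming `static_key` is a hashable sequence of `(name, value)` pairs (that shape is my assumption, not confirmed by the diff), `dict(static_key)` renders as a readable mapping instead of a nested tuple:

```python
# Hypothetical static_key shape and kernel name; the real KernelConfigs key may differ.
static_key = (("N", 4096), ("K", 11008), ("out_dtype", "torch.bfloat16"))

print(f"Loading cached configs for my_kernel - {static_key}")
# (('N', 4096), ('K', 11008), ('out_dtype', 'torch.bfloat16'))

print(f"Loading cached configs for my_kernel - {dict(static_key)}")
# {'N': 4096, 'K': 11008, 'out_dtype': 'torch.bfloat16'}
```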
1 change: 0 additions & 1 deletion lightllm/server/api_http.py
@@ -116,7 +116,6 @@ def set_args(self, args: StartArgs):
app = FastAPI()
g_objs.app = app

_ACCESS_LOG_STATUS_COLORS = {2: "\033[32m", 3: "\033[36m", 4: "\033[33m", 5: "\033[31m"}
_ACCESS_LOG_STATUS_COLORS = {2: "\033[32m", 3: "\033[36m", 4: "\033[33m", 5: "\033[31m"}
_ACCESS_LOG_RESET = "\033[0m"

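The deleted line was an exact duplicate of the colour map kept below it. As a side note, here is a hedged sketch of how a status-class colour map like this is typically applied; `format_access_log` is a hypothetical helper, not lightllm's actual access-log code:

```python
_ACCESS_LOG_STATUS_COLORS = {2: "\033[32m", 3: "\033[36m", 4: "\033[33m", 5: "\033[31m"}
_ACCESS_LOG_RESET = "\033[0m"

def format_access_log(method: str, path: str, status_code: int) -> str:
    # Select the colour by the hundreds digit of the status code (2xx green,
    # 3xx cyan, 4xx yellow, 5xx red); unknown classes fall back to no colour.
    color = _ACCESS_LOG_STATUS_COLORS.get(status_code // 100, "")
    return f"{method} {path} {color}{status_code}{_ACCESS_LOG_RESET}"

print(format_access_log("GET", "/health", 200))      # status rendered in green
print(format_access_log("POST", "/generate", 500))   # status rendered in red
```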
6 changes: 0 additions & 6 deletions lightllm/server/api_start.py
@@ -487,8 +487,6 @@ def normal_or_p_d_start(args):
f"{args.host}:{args.port}",
"--log-level",
"info",
"--access-logfile",
"-",
"--error-logfile",
"-",
"lightllm.server.api_http:app",
@@ -556,8 +554,6 @@ def pd_master_start(args):
f"{args.host}:{args.port}",
"--log-level",
"info",
"--access-logfile",
"-",
"--error-logfile",
"-",
"lightllm.server.api_http:app",
@@ -652,8 +648,6 @@ def config_server_start(args):
f"{args.config_server_host}:{args.config_server_port}",
"--log-level",
"info",
"--access-logfile",
"-",
"--error-logfile",
"-",
"lightllm.server.config_server.api_http:app",
4 changes: 2 additions & 2 deletions lightllm/server/detokenization/manager.py
@@ -52,7 +52,7 @@ def _add_new_group_req_index(self, recv_obj: GroupReqIndexes):
req.link_prompt_ids_shm_array()
req.link_logprobs_shm_array()

logger.info(
logger.debug(
f"detokenization recv req id {req.request_id} " f"cost time {time.time() - recv_obj.time_mark} s"
)

@@ -160,7 +160,7 @@ def remove_finished_reqs(self):

for decode_req in finished_reqs:
decode_req.req.can_released_mark = True
logger.info(f"detoken release req id {decode_req.req.request_id}")
logger.debug(f"detoken release req id {decode_req.req.request_id}")
self.shm_req_manager.put_back_req_obj(decode_req.req)
self.req_id_to_out.pop(decode_req.request_id, None)
return
25 changes: 21 additions & 4 deletions lightllm/server/httpserver/manager.py
@@ -298,6 +298,17 @@ async def generate(
# Used to wait for the exchange info dispatched by pd_master
nixl_pd_event: asyncio.Event = None,
) -> AsyncGenerator[Tuple[int, str, dict, FinishStatus], None]:
group_request_id = None
if isinstance(prompt, str):
# Guard against extremely long string prompts that might stall the tokenizer
# or cause excessive memory usage before tokenization.
# 8 characters per token is a conservative heuristic (avg is ~4).
max_prompt_chars = self.max_req_total_len * 8
if len(prompt) > max_prompt_chars:
raise ValueError(
f"prompt text length {len(prompt)} exceeds the character limit {max_prompt_chars}, "
f"the request is rejected before tokenization."
)

start_time = time.time()
request_headers = request.headers if request is not None else {}
@@ -445,6 +456,12 @@ async def generate(

yield sub_req_id, request_output, metadata, finish_status

except ValueError as e:
logger.warning(f"group_request_id: {group_request_id} request invalid: {str(e)}")
if group_request_id not in self.req_id_to_out_inf:
await self._release_multimodal_resources(multimodal_params)
await self.abort(group_request_id)
raise e
except (ClientDisconnected, Exception) as e:
logger.error(f"group_request_id: {group_request_id} has exception {str(e)}")

@@ -481,7 +498,7 @@ async def _log_req_header(self, request_headers, group_request_id: int):
x_session_id = request_headers.get("X-Session-Id", "")

format_in_time = datetime.datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d %H:%M:%S")
logger.info(
logger.debug(
f"received req X-Request-Id:{x_request_id} "
f"X-Session-Id:{x_session_id} start_time:{format_in_time} "
f"lightllm_req_id:{group_request_id} "
@@ -719,7 +736,7 @@ async def _wait_to_token_package(
(out_token_counter - sum(sub_req_id_to_mtp_accepted_token_num.values())), 1
)
format_start_time = datetime.datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S")
logger.info(
logger.debug(
f"X-Request-Id:{x_request_id} "
f"X-Session-Id:{x_session_id} start_time:{format_start_time} "
f"lightllm_req_id:{group_request_id} first_token_cost:{first_token_cost_ms}ms "
@@ -812,8 +829,8 @@ async def recycle_resource_loop(self):
if req_status is None:
continue

logger.info(
f"left req id {req_status.group_req_objs.group_req_id}"
logger.debug(
f"left req id {req_status.group_req_objs.group_req_id} "
f"can release {req_status.group_req_objs.shm_req_objs[0].can_released_mark} "
f"refcount {req_status.group_req_objs.shm_req_objs[0].ref_count}"
)
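The main functional change in this file is the pre-tokenization guard on string prompts. Below is a minimal sketch of that check, assuming a `max_req_total_len` expressed in tokens and the 8-characters-per-token ceiling used above; the real handler additionally aborts the request and releases multimodal resources when the resulting ValueError propagates.

```python
def check_prompt_chars(prompt, max_req_total_len: int, chars_per_token: int = 8) -> None:
    # Only string prompts are guarded; token-id prompts are length-checked later.
    if not isinstance(prompt, str):
        return
    # 8 chars/token is deliberately loose (English text averages ~4 chars/token),
    # so legitimate prompts pass while runaway strings are rejected before they
    # reach the tokenizer.
    max_prompt_chars = max_req_total_len * chars_per_token
    if len(prompt) > max_prompt_chars:
        raise ValueError(
            f"prompt text length {len(prompt)} exceeds the character limit {max_prompt_chars}, "
            f"the request is rejected before tokenization."
        )

check_prompt_chars("hello world", max_req_total_len=16384)      # passes
# check_prompt_chars("x" * 200_000, max_req_total_len=16384)    # raises ValueError
```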
2 changes: 1 addition & 1 deletion lightllm/server/httpserver_for_pd_master/manager.py
@@ -183,7 +183,7 @@ async def _log_req_header(self, request: Request, group_request_id: int):
x_request_id = request.headers.get("X-Request-Id", "")
x_session_id = request.headers.get("X-Session-Id", "")
format_in_time = datetime.datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d %H:%M:%S")
logger.info(
logger.debug(
f"received req X-Request-Id:{x_request_id} "
f"X-Session-Id:{x_session_id} start_time:{format_in_time} "
f"lightllm_req_id:{group_request_id} "
8 changes: 2 additions & 6 deletions lightllm/server/router/batch.py
@@ -3,7 +3,6 @@
from typing import Dict, List, Optional, Tuple, Union
from lightllm.server.core.objs import ShmReqManager, Req
from lightllm.utils.log_utils import init_logger
from .stats import RouterStatics

logger = init_logger(__name__)

@@ -50,14 +49,11 @@ def get_all_dp_req_num(self) -> List[int]:
all_dp_req_num[req.sample_params.suggested_dp_index] += 1
return all_dp_req_num

def filter_out_finished_req(self, shm_req_manager: ShmReqManager, router_statics: RouterStatics):
def filter_out_finished_req(self, shm_req_manager: ShmReqManager):
unfinished_req_ids = []
for req in self.reqs:
if req.shm_infer_released:
logger.info(f"router release req id {req.request_id}")
if not req.is_aborted:
router_statics.update(req.candetoken_out_len)

logger.debug(f"router release req id {req.request_id}")
shm_req_manager.put_back_req_obj(req)
req = None
else:
73 changes: 42 additions & 31 deletions lightllm/server/router/manager.py
@@ -16,6 +16,7 @@
from .batch import Batch, Req
from .model_infer.model_rpc import start_model_process, ModelRpcClient
from .req_queue import build_req_queue
from .stats import SystemStatusReporter
from lightllm.server.core.objs.io_objs import (
GroupReqIndexes,
AbortedReqCmd,
@@ -25,7 +26,7 @@
from .dynamic_prompt.radix_cache import RadixCacheReadOnlyClient
from lightllm.server.multi_level_kv_cache.cpu_cache_client import CpuKvCacheClient
from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer
from lightllm.utils.log_utils import init_logger, log_time_ready
from lightllm.utils.log_utils import init_logger
from lightllm.server.router.token_load import TokenLoad
from lightllm.server.metrics.manager import MetricClient
from lightllm.common.basemodel.infer_lock import g_router_lock
@@ -65,6 +66,7 @@ def __init__(self, args: StartArgs):
self.read_only_statics_mem_manager = ReadOnlyStaticsMemoryManager()
# Initialize radix_cache_client, used to read prompt cache management info
self.radix_cache_client = None
self.status_reporter = None

# Shared variable that stores the machine load info derived from the router's scheduling analysis
self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", self.dp_size_in_node)
@@ -194,6 +196,11 @@ async def wait_to_model_ready(self):
)
self.req_queue = build_req_queue(self.args, self, self.dp_size_in_node)
logger.info(f"use req queue {self.req_queue.__class__.__name__}")
self.status_reporter = SystemStatusReporter(
args=self.args,
max_total_token_num=self.max_total_token_num,
dp_size_in_node=self.dp_size_in_node,
)

if self.args.run_mode == "prefill":
# Start the prefill kv move management process
@@ -239,27 +246,10 @@ async def loop_for_fwd(
await self._step()
counter_count += 1
if self.running_batch is not None:
# Output-token counting is done in bulk at the print-window boundary
# inside SystemStatusReporter.maybe_print, so the router tick stays cheap.
if counter_count % 100 == 0:
for dp_index in range(self.dp_size_in_node):
token_ratio1 = self.get_used_tokens(dp_index) / self.max_total_token_num
token_ratio2 = (
self.max_total_token_num
- self.read_only_statics_mem_manager.get_unrefed_token_num(dp_index)
) / self.max_total_token_num
d_i = dp_index
frozen_token_num = self.shared_token_load.get_frozened_token_count(d_i)
estimated_peak_token_count = self.shared_token_load.get_estimated_peak_token_count(d_i)
paused_req_num = self._get_paused_req_num_in_dp_index(dp_index=d_i)
logger.debug(
f"dp_i {d_i} current batch size: {len(self.running_batch.reqs)} \n"
f"dp_i {d_i} paused req num: {paused_req_num} \n"
f"dp_i {d_i} frozen token num: {frozen_token_num} \n"
f"dp_i {d_i} estimated_peak_token_count: {estimated_peak_token_count} \n"
f"dp_i {d_i} token used ratio: {token_ratio1} not contain prompt cache tree unrefed token\n"
f"dp_i {d_i} token used ratio: {token_ratio2} contain prompt cache tree unrefed token"
)
logger.debug(self.router_statics.log_str())
self.metric_client.gauge_set("lightllm_batch_pause_size", paused_req_num)
self.metric_client.gauge_set("lightllm_batch_pause_size", self._get_paused_req_num())
# pd decode mode need to update token_load more frequently
self.req_queue.update_token_load(self.running_batch, force_update=self.is_pd_decode_mode)
self.metric_client.gauge_set("lightllm_batch_current_size", len(self.running_batch.reqs))
@@ -278,13 +268,15 @@ async def loop_for_fwd(
self.metric_client.gauge_set("lightllm_batch_pause_size", 0.0)
self.metric_client.gauge_set("lightllm_queue_size", 0.0)
self.metric_client.gauge_set("lightllm_batch_current_max_tokens", 0.0)
# 60s print once
if log_time_ready("frozen_info", 60):
for dp_i in range(self.dp_size_in_node):
frozen_token_num = self.shared_token_load.get_frozened_token_count(dp_i)
estimated_peak_token_count = self.shared_token_load.get_estimated_peak_token_count(dp_i)
logger.debug(f"dp_i {dp_i} frozen token num: {frozen_token_num} \n")
logger.debug(f"dp_i {dp_i} estimated_peak_token_count: {estimated_peak_token_count} \n")

self.status_reporter.maybe_print(
running_batch=self.running_batch,
req_queue=self.req_queue,
read_only_statics_mem_manager=self.read_only_statics_mem_manager,
paused_req_num=self._get_paused_req_num(),
radix_cache_client=self.radix_cache_client,
disable_dynamic_prompt_cache=self.args.disable_dynamic_prompt_cache,
)

await asyncio.sleep(self._get_schedule_time_interval())

@@ -314,6 +306,7 @@ async def _step(self):

async def _add_batch(self, batch: Batch):
# Add new requests
self.status_reporter.count_prompt_tokens(batch.input_tokens())
reqs = [r.to_router_rpc_obj() for r in batch.reqs]
while not self.shm_reqs_io_buffer.is_empty():
await asyncio.sleep(0.001)
@@ -347,7 +340,25 @@ def _add_new_batch_to_running_batch(self, new_batch: Batch):

def _filter_reqs_from_running_batch(self):
if self.running_batch is not None:
self.running_batch.filter_out_finished_req(self.shm_req_manager, self.router_statics)
# Capture finished req stats before filtering
for req in self.running_batch.reqs:
if not req.shm_infer_released:
continue
# Settle any output-token tail produced after the last window boundary,
# so windowed TPS does not lose the req's last tokens.
self.status_reporter.discard_req(req)
self.status_reporter.on_request_completed(
input_len=req.input_len,
output_len=req.shm_cur_output_len,
cache_len=req.prompt_cache_len,
mtp_accepted=req.mtp_accepted_token_num,
)
# Aborted/disconnected requests can leave a partial output_len that
# would bias the EMA toward shorter generations; skip them.
if req.is_aborted:
continue
self.router_statics.update(req.candetoken_out_len)
self.running_batch.filter_out_finished_req(self.shm_req_manager)
if self.running_batch.is_clear():
self.running_batch = None
return
@@ -419,7 +430,7 @@ def _add_req(self, group_req_indexes: GroupReqIndexes):
req._router_stop_str_matched = False
req_group.append(req)

logger.info(f"router recive req id {req.request_id} cost time {time.time() - req.start_time} s")
logger.debug(f"router receive req id {req.request_id} cost time {time.time() - req.start_time} s")
self.req_queue.extend(req_group)
self.send_to_detokenization.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL)
return
@@ -431,7 +442,7 @@ def _generate_new_batch(self):
)
self.schedule_new_batch = Batch.merge_two_batch(self.schedule_new_batch, new_batch)
if self.schedule_new_batch is not None:
logger.info(f"gen new batch, {self.schedule_new_batch.simple_log()}")
logger.debug(f"gen new batch, {self.schedule_new_batch.simple_log()}")
return

def _multinode_tp_generate_new_batch(self):
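Taken together, the router changes replace the ad-hoc per-tick logging with a SystemStatusReporter that is fed prompt tokens when a batch is admitted (_add_batch), per-request stats when requests finish (_filter_reqs_from_running_batch), and prints on its own schedule from loop_for_fwd. Below is a simplified sketch of that windowed-reporter pattern; the class name, fields, and print interval are assumptions of mine, and only the call sites (count_prompt_tokens, on_request_completed, maybe_print) mirror the diff.

```python
import time

class WindowedStatusReporter:
    """Accumulates router stats and prints them at most once per window."""

    def __init__(self, window_s: float = 10.0, ema_alpha: float = 0.1):
        self.window_s = window_s
        self.ema_alpha = ema_alpha
        self.last_print = time.time()
        self.prompt_tokens_in_window = 0
        self.finished_reqs_in_window = 0
        self.output_len_ema = 0.0

    def count_prompt_tokens(self, n: int) -> None:
        # Called when a new batch is admitted (see _add_batch above).
        self.prompt_tokens_in_window += n

    def on_request_completed(self, output_len: int, aborted: bool = False) -> None:
        self.finished_reqs_in_window += 1
        # Aborted requests leave a partial output_len that would bias the EMA
        # toward shorter generations, so they are excluded (as the diff notes).
        if not aborted:
            self.output_len_ema = (
                self.ema_alpha * output_len + (1.0 - self.ema_alpha) * self.output_len_ema
            )

    def maybe_print(self) -> None:
        now = time.time()
        elapsed = now - self.last_print
        if elapsed < self.window_s:
            return
        print(
            f"prompt TPS: {self.prompt_tokens_in_window / elapsed:.1f}, "
            f"finished reqs: {self.finished_reqs_in_window}, "
            f"output-len EMA: {self.output_len_ema:.1f}"
        )
        self.prompt_tokens_in_window = 0
        self.finished_reqs_in_window = 0
        self.last_print = now
```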