26 changes: 11 additions & 15 deletions lightllm/common/req_manager.py
@@ -7,6 +7,7 @@
from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter
from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args
from lightllm.utils.config_utils import get_vocab_size
from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager

logger = init_logger(__name__)

@@ -155,7 +156,11 @@ def init_req_sampling_params(self, req):
else:
self.req_to_out_token_id_counter[req.req_idx].fill_(0)
if req.sampling_param.shm_param.input_penalty and req.need_out_token_id_statistics:
prompt_ids = torch.from_numpy(req.shm_req.get_prompt_ids_numpy()).pin_memory().cuda(non_blocking=True)
prompt_ids = g_pin_mem_manager.gen_from_list(
key="prompt_ids_for_penalty",
data=req.shm_req.get_prompt_ids_numpy(),
dtype=torch.int32,
).cuda(non_blocking=True)
token_id_counter(
prompt_ids=prompt_ids, out_token_id_counter=self.req_to_out_token_id_counter[req.req_idx]
)
@@ -214,22 +219,13 @@ def gen_cpu_out_token_counter_sampling_params(self, req_objs: List):
cum_sum_len += len(id_to_count)
p_cumsum_seq_len.append(cum_sum_len)

from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager

p_token_ids_tensor = g_pin_mem_manager.alloc_pin_tensor(
key="p_token_ids", size=len(p_token_ids), dtype=torch.int32
)
p_token_ids_tensor.numpy()[:] = p_token_ids

p_token_counts_tensor = g_pin_mem_manager.alloc_pin_tensor(
key="p_token_counts", size=len(p_token_counts), dtype=torch.int32
)
p_token_counts_tensor.numpy()[:] = p_token_counts

p_cumsum_seq_len_tensor = g_pin_mem_manager.alloc_pin_tensor(
key="p_cumsum_seq_len", size=len(p_cumsum_seq_len), dtype=torch.int32
)
p_cumsum_seq_len_tensor.numpy()[:] = p_cumsum_seq_len
p_token_ids_tensor = g_pin_mem_manager.gen_from_list(key="p_token_ids", data=p_token_ids, dtype=torch.int32)
p_token_counts_tensor = g_pin_mem_manager.gen_from_list(
key="p_token_counts", data=p_token_counts, dtype=torch.int32
)
p_cumsum_seq_len_tensor = g_pin_mem_manager.gen_from_list(
key="p_cumsum_seq_len", data=p_cumsum_seq_len, dtype=torch.int32
)

return (
p_token_ids_tensor.cuda(non_blocking=True),
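Context for this change: g_pin_mem_manager.gen_from_list is used throughout the PR but its implementation is not part of this diff. The sketch below shows roughly what such a keyed pinned-buffer helper could look like; the class name, caching policy, and resize behavior are assumptions for illustration, not the actual pin_mem_manager code.

    # Hypothetical sketch of a keyed pinned-buffer helper. Assumes the real
    # g_pin_mem_manager caches one pinned CPU tensor per key and reuses it
    # across calls instead of allocating a fresh pin_memory() tensor each time.
    import torch


    class PinMemManagerSketch:
        def __init__(self):
            self._buffers = {}  # key -> pinned CPU tensor

        def alloc_pin_tensor(self, key: str, size: int, dtype: torch.dtype) -> torch.Tensor:
            buf = self._buffers.get(key)
            if buf is None or buf.numel() < size or buf.dtype != dtype:
                # Grow (or create) the cached pinned buffer for this key.
                buf = torch.empty(size, dtype=dtype, pin_memory=True)
                self._buffers[key] = buf
            return buf[:size]

        def gen_from_list(self, key: str, data, dtype: torch.dtype) -> torch.Tensor:
            # Copy a Python list / numpy array into the reused pinned buffer.
            pinned = self.alloc_pin_tensor(key, len(data), dtype)
            pinned.numpy()[:] = data
            return pinned  # caller does .cuda(non_blocking=True)

Under that assumption, every call site in this diff becomes gen_from_list(key=..., data=..., dtype=...).cuda(non_blocking=True), reusing one pinned host buffer per key instead of allocating and pinning fresh memory on every request.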
10 changes: 10 additions & 0 deletions lightllm/server/core/objs/req.py
@@ -227,6 +227,16 @@ def link_logprobs_shm_array(self):
self.shm_logprobs.link_shm()
return

def release_shm_arrays(self):
"""释放共享内存连接,防止内存泄露(仅 detach,不 unlink)"""
if hasattr(self, "shm_prompt_ids") and self.shm_prompt_ids is not None:
self.shm_prompt_ids.detach_shm()
self.shm_prompt_ids = None
if hasattr(self, "shm_logprobs") and self.shm_logprobs is not None:
self.shm_logprobs.detach_shm()
self.shm_logprobs = None
return

def get_prompt_ids(self):
return self.shm_prompt_ids.arr[: self.input_len].tolist()

6 changes: 6 additions & 0 deletions lightllm/server/core/objs/shm_array.py
@@ -32,3 +32,9 @@ def close_shm(self):
self.shm.unlink()
self.shm = None
self.arr = None

def detach_shm(self):
if self.shm is not None:
self.shm.close()
self.shm = None
self.arr = None
Comment on lines +36 to +40

medium

The new detach_shm method has logic that is very similar to the existing close_shm method. To improve maintainability and reduce code duplication, you could extract the common logic into a private helper method. The suggestion below introduces _cleanup and refactors detach_shm to use it. You could then also update close_shm to call self._cleanup(unlink=True).

Suggested change
def detach_shm(self):
if self.shm is not None:
self.shm.close()
self.shm = None
self.arr = None
def detach_shm(self):
self._cleanup(unlink=False)
def _cleanup(self, unlink: bool):
if self.shm is not None:
self.shm.close()
if unlink:
self.shm.unlink()
self.shm = None
self.arr = None

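Following the reviewer's last sentence, close_shm could also delegate to the shared helper. A minimal sketch, assuming self.shm is a multiprocessing.shared_memory.SharedMemory as in shm_array.py: close() only detaches the calling process, while unlink() additionally destroys the segment.

    # Illustrative refactor of both public methods onto _cleanup (not part of the PR).
    def close_shm(self):
        # Detach and destroy the segment (owner-side release).
        self._cleanup(unlink=True)

    def detach_shm(self):
        # Detach only; other processes keep their mapping (reader-side release).
        self._cleanup(unlink=False)

    def _cleanup(self, unlink: bool):
        if self.shm is not None:
            self.shm.close()
            if unlink:
                self.shm.unlink()
        self.shm = None
        self.arr = None

That split is consistent with how the PR uses the two methods: the detokenization and inference workers call release_shm_arrays() (detach only), while the HTTP server's recycle loop calls close_shm() before releasing the request slot.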
1 change: 1 addition & 0 deletions lightllm/server/detokenization/manager.py
@@ -161,6 +161,7 @@ def remove_finished_reqs(self):
for decode_req in finished_reqs:
decode_req.req.can_released_mark = True
logger.info(f"detoken release req id {decode_req.req.request_id}")
decode_req.req.release_shm_arrays()
self.shm_req_manager.put_back_req_obj(decode_req.req)
self.req_id_to_out.pop(decode_req.request_id, None)
return
2 changes: 2 additions & 0 deletions lightllm/server/httpserver/manager.py
@@ -683,6 +683,8 @@ async def recycle_resource_loop(self):
for req_status in release_req_status:
self.req_id_to_out_inf.pop(req_status.group_req_objs.group_req_id, None)
for req in req_status.group_req_objs.shm_req_objs:
req.shm_prompt_ids.close_shm()
req.shm_logprobs.close_shm()
await self.shm_req_manager.async_put_back_req_obj(req)
await self.shm_req_manager.async_release_req_index(req.index_in_shm_mem)
await self._release_multimodal_resources(req_status.group_req_objs.multimodal_params)
11 changes: 6 additions & 5 deletions lightllm/server/router/dynamic_prompt/radix_cache.py
@@ -38,8 +38,8 @@ def split_node(self, prefix_len):
split_parent_node = TreeNode()
split_parent_node.parent = self.parent
split_parent_node.parent.children[self.token_id_key[0].item()] = split_parent_node
split_parent_node.token_id_key = self.token_id_key[0:prefix_len]
split_parent_node.token_mem_index_value = self.token_mem_index_value[0:prefix_len]
split_parent_node.token_id_key = self.token_id_key[0:prefix_len].clone()
split_parent_node.token_mem_index_value = self.token_mem_index_value[0:prefix_len].clone()
split_parent_node.children = {}
split_parent_node.children[self.token_id_key[prefix_len].item()] = self
split_parent_node.ref_counter = self.ref_counter
@@ -58,8 +58,8 @@ def add_and_return_new_child(self, token_id_key, token_mem_index_value):

def add_and_return_new_child(self, token_id_key, token_mem_index_value):
child = TreeNode()
child.token_id_key = token_id_key
child.token_mem_index_value = token_mem_index_value
child.token_id_key = token_id_key.clone()
child.token_mem_index_value = token_mem_index_value.clone()
first_token_key = child.token_id_key[0].item()
assert first_token_key not in self.children.keys()
self.children[first_token_key] = child
@@ -241,7 +241,8 @@ def match_prefix(self, key, update_refs=False):
value = torch.zeros((0,), device="cpu", dtype=self._value_dtype)
return tree_node, len(value), value
else:
self.dec_node_ref_counter(self.root_node)
if update_refs:
self.dec_node_ref_counter(self.root_node)
return None, 0, None

def _match_prefix_helper(
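Why the .clone() calls: slicing a tensor returns a view that keeps the whole parent storage alive, so storing key slices directly in long-lived radix-tree nodes can retain the full source buffer. A small standalone illustration (assumes PyTorch 2.x for untyped_storage(); not LightLLM code):

    import torch

    # Stand-in for a large per-request token buffer.
    prompt_ids = torch.arange(1_000_000, dtype=torch.int64)

    view = prompt_ids[0:8]            # view shares the full 8 MB storage
    copy = prompt_ids[0:8].clone()    # independent 64-byte tensor

    assert view.untyped_storage().nbytes() == prompt_ids.untyped_storage().nbytes()
    assert copy.untyped_storage().nbytes() == 8 * copy.element_size()

The second fix in this file makes match_prefix decrement the root ref-counter only when update_refs is set, so a read-only lookup no longer decrements a counter it never incremented.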
1 change: 1 addition & 0 deletions lightllm/server/router/model_infer/infer_batch.py
@@ -160,6 +160,7 @@ def _filter(self, finished_request_ids: List[int]):
free_req_index.append(req.req_idx)
# logger.info(f"infer release req id {req.shm_req.request_id}")
req.shm_req.shm_infer_released = True
req.shm_req.release_shm_arrays()
self.shm_req_manager.put_back_req_obj(req.shm_req)

free_token_index = custom_cat(free_token_index)
@@ -3,6 +3,7 @@
from lightllm.common.basemodel.triton_kernel.apply_penalty import apply_penalty
from lightllm.common.basemodel.triton_kernel.apply_penalty_gpu_cache import apply_penalty_gpu_cache
from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context
from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager
from lightllm.utils.envs_utils import get_env_start_args


@@ -16,7 +17,7 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]):
b_mask_eos_reqs,
is_all_greedy,
) = _get_post_sample_tensors(reqs)
eos_ids = torch.tensor(eos_id, dtype=torch.int32, device="cpu", pin_memory=True).cuda(non_blocking=True)
eos_ids = g_pin_mem_manager.gen_from_list(key="eos_ids", data=eos_id, dtype=torch.int32).cuda(non_blocking=True)

sampling_params_manager = g_infer_context.req_manager.req_sampling_params_manager

@@ -128,12 +129,14 @@ def _get_post_sample_tensors(reqs: List[InferReq]):
is_all_greedy = False
req_idxes.append(req_obj.req_idx)

req_idxes_cpu = torch.tensor(req_idxes, dtype=torch.int32, device="cpu", pin_memory=True)
temperatures_cpu = torch.tensor(temperatures, dtype=torch.float, device="cpu", pin_memory=True)
top_ps_cpu = torch.tensor(top_ps, dtype=torch.float, device="cpu", pin_memory=True)
top_ks_cpu = torch.tensor(top_ks, dtype=torch.int32, device="cpu", pin_memory=True)
length_penalty_param_cpu = torch.tensor(length_penalty_param, dtype=torch.int32, device="cpu", pin_memory=True)
mask_eos_reqs_cpu = torch.tensor(mask_eos_reqs, dtype=torch.bool, device="cpu", pin_memory=True)
req_idxes_cpu = g_pin_mem_manager.gen_from_list(key="req_idxes", data=req_idxes, dtype=torch.int32)
temperatures_cpu = g_pin_mem_manager.gen_from_list(key="temperatures", data=temperatures, dtype=torch.float32)
top_ps_cpu = g_pin_mem_manager.gen_from_list(key="top_ps", data=top_ps, dtype=torch.float32)
top_ks_cpu = g_pin_mem_manager.gen_from_list(key="top_ks", data=top_ks, dtype=torch.int32)
length_penalty_param_cpu = g_pin_mem_manager.gen_from_list(
key="length_penalty_param", data=length_penalty_param, dtype=torch.int32
)
mask_eos_reqs_cpu = g_pin_mem_manager.gen_from_list(key="mask_eos_reqs", data=mask_eos_reqs, dtype=torch.bool)

return (
req_idxes_cpu.cuda(non_blocking=True),
136 changes: 99 additions & 37 deletions test/benchmark/service/benchmark_qps.py
@@ -103,6 +103,10 @@ def get_custom_input_data(data_path, output_len, tokenizer, range_ratio):
model_name = []


# Minimal fix: one retry on transient network errors.
_DEFAULT_RETRY = 1


async def async_post_stream_openai(url, prompt, max_new_tokens, session):
try:
text_input, input_len = prompt
@@ -116,21 +120,34 @@ async def async_post_stream_openai(url, prompt, max_new_tokens, session):
"best_of": 1,
}
headers = {"Content-Type": "application/json"}
used_time = []
start_time = time.time()
last_time = start_time
async with session.post(url, headers=headers, json=data) as response:
if response.status != 200:
return []

async for line in response.content:
line = line.strip()
if line:
current_time = time.time()
elapsed_time = current_time - last_time
used_time.append(elapsed_time)
last_time = current_time
return used_time, input_len

for attempt in range(_DEFAULT_RETRY + 1):
used_time = []
start_time = time.time()
last_time = start_time
try:
async with session.post(url, headers=headers, json=data) as response:
if response.status != 200:
return []

try:
async for line in response.content:
line = line.strip()
if line:
current_time = time.time()
elapsed_time = current_time - last_time
used_time.append(elapsed_time)
last_time = current_time
except Exception:
# server may disconnect mid-stream; keep partial timings if any.
pass

if used_time or attempt >= _DEFAULT_RETRY:
return used_time, input_len
except Exception as e:
if attempt >= _DEFAULT_RETRY:
print(e)
return []
except Exception as e:
print(e)
pass
@@ -149,21 +166,33 @@ async def async_post_stream_lightllm(url, prompt, max_new_tokens, session):
},
}
headers = {"Content-Type": "application/json"}
used_time = []
start_time = time.time()
last_time = start_time
async with session.post(url, headers=headers, json=data) as response:
if response.status != 200:
return []

async for line in response.content:
if line and line.startswith(b"data:"):
# print(line)
current_time = time.time()
elapsed_time = current_time - last_time
used_time.append(elapsed_time)
last_time = current_time
return used_time, input_len

for attempt in range(_DEFAULT_RETRY + 1):
used_time = []
start_time = time.time()
last_time = start_time
try:
async with session.post(url, headers=headers, json=data) as response:
if response.status != 200:
return []

try:
async for line in response.content:
if line and line.startswith(b"data:"):
current_time = time.time()
elapsed_time = current_time - last_time
used_time.append(elapsed_time)
last_time = current_time
except Exception:
# server may disconnect mid-stream; keep partial timings if any.
pass

if used_time or attempt >= _DEFAULT_RETRY:
return used_time, input_len
except Exception as e:
if attempt >= _DEFAULT_RETRY:
print(e)
return []
except Exception as e:
print(e)
pass
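The one-retry loop is now duplicated verbatim in async_post_stream_openai and async_post_stream_lightllm. If it needs to grow further, it could be factored into a shared helper along these lines (a hypothetical refactor, not part of this PR; post_stream_with_retry and make_request are illustrative names):

    from typing import Awaitable, Callable, Optional, Tuple

    async def post_stream_with_retry(
        make_request: Callable[[], Awaitable[Optional[Tuple[list, int]]]],
        retries: int = 1,
    ):
        # Run the request factory up to retries + 1 times; keep the first result with timings.
        last_result = []
        for attempt in range(retries + 1):
            try:
                result = await make_request()
            except Exception as exc:
                # Transient network error: retry unless attempts are exhausted.
                if attempt >= retries:
                    print(exc)
                    return []
                continue
            if result and result[0]:
                return result
            last_result = result if result is not None else []
        return last_result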
@@ -187,6 +216,7 @@ async def continuous_sender(
while not stop_send.is_set():
if not continuous_send and sent_count[0] >= max_count:
break

prompt = prompts[prompt_index % len(prompts)]
max_tokens = max_new_tokens[prompt_index % len(max_new_tokens)]

@@ -212,18 +242,42 @@ async def response_collector(
force_terminate,
pending_tasks,
):
# Max time the collector waits on a single request, so a network failure cannot hang it forever
task_timeout_s = 600
try:
while True:
try:
task = await asyncio.wait_for(request_queue.get(), timeout=1.0)
result, input_len = await task
request_queue.task_done()
assert result is not None
if len(result) >= 1 and not stop_send.is_set():
results.append((result, input_len))
result = None
input_len = 0
try:
try:
result_tuple = await asyncio.wait_for(task, timeout=task_timeout_s)
except asyncio.TimeoutError:
print("\nError collecting response: task timeout")
if not task.done():
task.cancel()
result_tuple = None

if isinstance(result_tuple, tuple) and len(result_tuple) == 2:
result, input_len = result_tuple
else:
result = None
input_len = 0
except Exception as e:
print(f"\nError collecting response: {e}")
finally:
# Always mark the item done so the queue never backs up after a continue/exception
request_queue.task_done()

# Advance the counter on success or failure, so waiting for the remaining responses never stalls
current_count = counter[0] + 1
counter[0] = current_count
print(f"\rfinished_reqs:{current_count} / target_reqs:{reqs_num} / sent_reqs:{sent_count[0]}", end="")

if result is not None:
if len(result) >= 1 and not stop_send.is_set():
results.append((result, input_len))
if len(results) >= reqs_num and not stop_send.is_set():
end_time[0] = time.time()
print("\nReached target number of responses")
@@ -245,6 +299,7 @@ async def response_collector(
continue
except Exception as e:
print(f"\nError collecting response: {e}")
continue
finally:
if force_terminate:
for task in pending_tasks:
@@ -253,7 +308,15 @@


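The collector changes above combine three patterns: bound each awaited response task with asyncio.wait_for, cancel it on timeout, and call request_queue.task_done() in a finally block so the queue count always advances. A stripped-down sketch of just that skeleton (illustrative names, not the benchmark's exact control flow):

    import asyncio

    async def drain_results(queue: asyncio.Queue, task_timeout_s: float = 600.0):
        # Consume tasks until a None sentinel is received.
        while True:
            task = await queue.get()
            if task is None:
                queue.task_done()
                break
            result = None
            try:
                try:
                    result = await asyncio.wait_for(task, timeout=task_timeout_s)
                except asyncio.TimeoutError:
                    if not task.done():
                        task.cancel()  # don't leave a timed-out request running
            finally:
                queue.task_done()  # always balance queue.get(), even on failure
            # record `result`, advance counters, decide whether to stop, etc.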
async def run_continuous_benchmark(
async_task, url, prompts, max_new_tokens, reqs_num, num_clients, input_qps, force_terminate, continuous_send
async_task,
url,
prompts,
max_new_tokens,
reqs_num,
num_clients,
input_qps,
force_terminate,
continuous_send,
):
request_queue = asyncio.Queue()
stop_event = asyncio.Event()
@@ -414,7 +477,6 @@ def main():
)
)
loop.close()
print(len(results))
first_token_time = []
decode_token_time = []
request_time = []