
refactor(kv-cache): embed KvCacheAllocator in MemoryManager as allocator #1301

Open
hiworldwzj wants to merge 6 commits into main from wzj_dev

Conversation

@hiworldwzj (Collaborator)

No description provided.

@gemini-code-assist (Bot, Contributor) left a comment


Code Review

This pull request refactors memory management by introducing the KvCacheAllocator class and delegating allocation, freeing, and resizing responsibilities from MemoryManager to this new component. The changes include updating various modules to access memory state through the allocator attribute and removing redundant manual state management. Feedback focuses on improving type hint consistency, ensuring internal state like HOLD_TOKEN_MEMINDEX is correctly updated during resizing, and refining log message formatting by removing unnecessary trailing newlines.
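For orientation, here is a minimal sketch of the delegation pattern this PR describes. The class and method names (KvCacheAllocator, alloc/free/free_all/resize, and the allocator attribute) follow the diff below; the simplified internals are illustrative assumptions, not the PR's actual implementation.

    from typing import List, Union

    import torch


    class KvCacheAllocator:
        """Owns the free-index bookkeeping that previously lived in MemoryManager."""

        def __init__(self, size: int):
            self.resize(size)

        def resize(self, new_size: int) -> None:
            self.size = new_size
            # Pool of free token indices; a pinned CPU tensor in the real code.
            self.mem_state = torch.arange(0, new_size, dtype=torch.int32)
            self.mark_start = 0
            self.mark_end = new_size
            self.can_use_mem_size = new_size

        def alloc(self, need_size: int) -> torch.Tensor:
            # Hand out the next need_size indices stack-style.
            assert need_size <= self.mark_end - self.mark_start, "error alloc state"
            start, self.mark_start = self.mark_start, self.mark_start + need_size
            self.can_use_mem_size -= need_size
            return self.mem_state[start : start + need_size].clone()

        def free(self, free_index: Union[torch.Tensor, List[int]]) -> None:
            # Push the returned indices back onto the pool.
            end = self.mark_start
            self.mark_start -= len(free_index)
            assert self.mark_start >= 0, "error free state"
            self.mem_state[self.mark_start : end] = torch.as_tensor(free_index, dtype=torch.int32)
            self.can_use_mem_size += len(free_index)

        def free_all(self) -> None:
            self.resize(self.size)


    class MemoryManager:
        """After the refactor, MemoryManager only forwards to its allocator."""

        def __init__(self, size: int):
            self.size = size
            self.allocator = KvCacheAllocator(size)

        def alloc(self, need_size: int) -> torch.Tensor:
            return self.allocator.alloc(need_size)

        def free(self, free_index: Union[torch.Tensor, List[int]]) -> None:
            self.allocator.free(free_index)

        def free_all(self) -> None:
            self.allocator.free_all()

        def resize_mem(self, new_size: int) -> None:
            self.size = new_size
            self.allocator.resize(new_size)

The design point is that callers keep the existing MemoryManager API while all allocation state lives in one component.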

Comment on lines 322 to 331

-    def alloc(self, need_size) -> torch.Tensor:
-        if need_size > self.mark_end - self.mark_start:
-            logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}")
-            assert False, "error alloc state"
-
-        start = self.mark_start
-        end = self.mark_start + need_size
-        self.mark_start += need_size
-
-        self.can_use_mem_size -= need_size
-        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
-
-        # return via a buffer to avoid memory races in asynchronous scenarios
-        if self._return_start + need_size > self._mem_state_return.shape[0]:
-            self._return_start = 0
-        ans = self._mem_state_return[self._return_start : self._return_start + need_size]
-        ans.copy_(self.mem_state[start:end])
-        self._return_start += need_size
-        return ans
-
-    def free(self, free_index: Union[torch.Tensor, List[int]]):
-        """_summary_
-
-        Args:
-            free_index (torch.Tensor): _description_
-        """
-        end = self.mark_start
-        start = self.mark_start - len(free_index)
-        assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}"
-
-        if isinstance(free_index, list):
-            self.mem_state.numpy()[start:end] = free_index
-        else:
-            # copies from GPU to CPU are stream-blocking operations
-            self.mem_state[start:end] = free_index
-
-        self.mark_start -= len(free_index)
-
-        self.can_use_mem_size += len(free_index)
-        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
-
-        if self.can_use_mem_size == len(self.mem_state):
-            logger.debug(f"freed all gpu mem size {self.can_use_mem_size}")
-        return
+    def alloc(self, need_size) -> torch.Tensor:
+        return self.allocator.alloc(need_size)
+
+    def free(self, free_index: Union[torch.Tensor, List[int]]) -> None:
+        self.allocator.free(free_index)

     def free_all(self):
-        self.can_use_mem_size = len(self.mem_state)
-        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
-        self.mem_state.numpy()[:] = list(range(0, len(self.mem_state)))
-        self.mark_start = 0
-        self.mark_end = len(self.mem_state)
+        self.allocator.free_all()

     def resize_mem(self, new_size):

medium

For consistency with the free method, please add type hints for the parameters and return values of alloc, free_all, and resize_mem.

Suggested change

-    def alloc(self, need_size) -> torch.Tensor:
-        return self.allocator.alloc(need_size)
-
-    def free(self, free_index: Union[torch.Tensor, List[int]]) -> None:
-        self.allocator.free(free_index)
-
-    def free_all(self):
-        self.allocator.free_all()
-
-    def resize_mem(self, new_size):
+    def alloc(self, need_size: int) -> torch.Tensor:
+        return self.allocator.alloc(need_size)
+
+    def free(self, free_index: Union[torch.Tensor, List[int]]) -> None:
+        self.allocator.free(free_index)
+
+    def free_all(self) -> None:
+        self.allocator.free_all()
+
+    def resize_mem(self, new_size: int) -> None:

Comment on lines 341 to +342

         self.size = new_size
-        self.mem_state = torch.arange(
-            0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
-        )
-        self.mark_start = 0
-        self.mark_end = self.size
-        self.can_use_mem_size = self.size
-        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
+        self.allocator.resize(new_size)

medium

The HOLD_TOKEN_MEMINDEX attribute is initialized to self.size in __init__. When resizing the memory, this attribute should also be updated to reflect the new size, ensuring consistency for any components relying on this marker.

Suggested change

-        self.size = new_size
-        self.allocator.resize(new_size)
+        self.size = new_size
+        self.allocator.resize(new_size)
+        self.HOLD_TOKEN_MEMINDEX = self.size
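An alternative that avoids this kind of drift altogether would be to derive the marker from size instead of storing it. A minimal sketch, assuming HOLD_TOKEN_MEMINDEX is only ever read and never assigned elsewhere (hypothetical, not part of this PR):

    @property
    def HOLD_TOKEN_MEMINDEX(self) -> int:
        # Hypothetical: always tracks the current size, so resize_mem cannot
        # leave it stale. Only valid if no code assigns to this attribute.
        return self.size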

f"mem manager can alloc token num {self.req_manager.mem_manager.can_use_mem_size}\n"
f"mem manager total size {self.req_manager.mem_manager.size}"
f"mem manager can alloc token num {self.req_manager.mem_manager.allocator.can_use_mem_size}\n"
f"mem manager total size {self.req_manager.mem_manager.allocator.size}\n"

medium

The addition of a trailing newline \n at the end of the log message string will result in an extra empty line in the output. It is better to keep the log message concise without the trailing newline, consistent with the original implementation.

Suggested change

-            f"mem manager total size {self.req_manager.mem_manager.allocator.size}\n"
+            f"mem manager total size {self.req_manager.mem_manager.allocator.size}"

f"mem manager can alloc token num {self.backend.model.mem_manager.can_use_mem_size}\n"
f"mem manager total size {self.backend.model.mem_manager.size}"
f"mem manager can alloc token num {self.backend.model.mem_manager.allocator.can_use_mem_size}\n"
f"mem manager total size {self.backend.model.mem_manager.allocator.size}\n"

medium

Similar to the change in infer_batch.py, the trailing newline \n here adds an unnecessary empty line to the debug log output. Please remove it for cleaner log formatting.

Suggested change

-            f"mem manager total size {self.backend.model.mem_manager.allocator.size}\n"
+            f"mem manager total size {self.backend.model.mem_manager.allocator.size}"

