refactor(kv-cache): embed KvCacheAllocator in MemoryManager as allocator #1301

hiworldwzj wants to merge 6 commits into main from
Conversation
Code Review
This pull request refactors memory management by introducing the `KvCacheAllocator` class and delegating allocation, freeing, and resizing responsibilities from `MemoryManager` to this new component. The changes update various modules to access memory state through the `allocator` attribute and remove redundant manual state management. Feedback focuses on improving type hint consistency, ensuring internal state such as `HOLD_TOKEN_MEMINDEX` is correctly updated during resizing, and refining log message formatting by removing unnecessary trailing newlines.
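For orientation, the shape of the refactor is roughly as follows (a minimal sketch reconstructed from the diffs below; constructor signatures and any state not visible in the diffs are assumptions):

```python
from typing import List, Union

import torch


class KvCacheAllocator:
    """Owns the free-slot bookkeeping that previously lived on MemoryManager."""

    def __init__(self, size: int):
        self.size = size
        self.can_use_mem_size = size
        # pinned CPU tensor holding the indices of free KV-cache slots
        self.mem_state = torch.arange(0, size, dtype=torch.int32, device="cpu", pin_memory=True)

    def alloc(self, need_size: int) -> torch.Tensor: ...
    def free(self, free_index: Union[torch.Tensor, List[int]]) -> None: ...
    def free_all(self) -> None: ...
    def resize(self, new_size: int) -> None: ...


class MemoryManager:
    def __init__(self, size: int):
        self.size = size
        self.HOLD_TOKEN_MEMINDEX = size
        self.allocator = KvCacheAllocator(size)

    # the public API is unchanged; each call is forwarded to the allocator
    def alloc(self, need_size: int) -> torch.Tensor:
        return self.allocator.alloc(need_size)
```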
```diff
     def alloc(self, need_size) -> torch.Tensor:
-        if need_size > self.mark_end - self.mark_start:
-            logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}")
-            assert False, "error alloc state"
-
-        start = self.mark_start
-        end = self.mark_start + need_size
-        self.mark_start += need_size
-
-        self.can_use_mem_size -= need_size
-        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
-
-        # return via a staging buffer to avoid memory races in async scenarios
-        if self._return_start + need_size > self._mem_state_return.shape[0]:
-            self._return_start = 0
-        ans = self._mem_state_return[self._return_start : self._return_start + need_size]
-        ans.copy_(self.mem_state[start:end])
-        self._return_start += need_size
-        return ans
-
-    def free(self, free_index: Union[torch.Tensor, List[int]]):
-        """_summary_
-
-        Args:
-            free_index (torch.Tensor): _description_
-        """
-        end = self.mark_start
-        start = self.mark_start - len(free_index)
-        assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}"
-
-        if isinstance(free_index, list):
-            self.mem_state.numpy()[start:end] = free_index
-        else:
-            # the GPU-to-CPU copy is a stream-blocking operation
-            self.mem_state[start:end] = free_index
-
-        self.mark_start -= len(free_index)
-
-        self.can_use_mem_size += len(free_index)
-        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
-
-        if self.can_use_mem_size == len(self.mem_state):
-            logger.debug(f"freed all gpu mem size {self.can_use_mem_size}")
-        return
+        return self.allocator.alloc(need_size)
+
+    def free(self, free_index: Union[torch.Tensor, List[int]]) -> None:
+        self.allocator.free(free_index)

     def free_all(self):
-        self.can_use_mem_size = len(self.mem_state)
-        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
-        self.mem_state.numpy()[:] = list(range(0, len(self.mem_state)))
-        self.mark_start = 0
-        self.mark_end = len(self.mem_state)
+        self.allocator.free_all()

     def resize_mem(self, new_size):
```
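One detail worth noting in the removed `alloc` body is the rotating return buffer: indices are copied into `_mem_state_return` instead of being returned as a view of `mem_state`, so a subsequent `free`/`alloc` cannot overwrite a slice a caller is still reading asynchronously. Presumably this logic now lives inside `KvCacheAllocator`; in isolation the pattern looks like the following standalone sketch (class name and buffer sizing are illustrative):

```python
import torch


class RingReturnBuffer:
    """Hand out slices from a rotating pinned buffer, so a slice returned
    earlier is not overwritten until the write cursor wraps around."""

    def __init__(self, capacity: int):
        self._buf = torch.empty(capacity, dtype=torch.int32, pin_memory=True)
        self._start = 0

    def copy_out(self, src: torch.Tensor) -> torch.Tensor:
        n = src.shape[0]
        # wrap to the front when the remaining tail is too small
        if self._start + n > self._buf.shape[0]:
            self._start = 0
        out = self._buf[self._start : self._start + n]
        out.copy_(src)
        self._start += n
        return out
```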
For consistency with the `free` method, please add type hints for the parameters and return values of `alloc`, `free_all`, and `resize_mem`.
Suggested change:

```python
    def alloc(self, need_size: int) -> torch.Tensor:
        return self.allocator.alloc(need_size)

    def free(self, free_index: Union[torch.Tensor, List[int]]) -> None:
        self.allocator.free(free_index)

    def free_all(self) -> None:
        self.allocator.free_all()

    def resize_mem(self, new_size: int) -> None:
```
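With the annotations in place, a static checker can flag misuse of the facade before it hits the runtime assertion (hypothetical caller, assuming `mypy` or a similar checker runs over the repo):

```python
# hypothetical usage, not from the PR
indices = mem_manager.alloc(128)  # OK: returns torch.Tensor
mem_manager.free(indices)         # OK: a Tensor is an accepted free_index
mem_manager.alloc("128")          # mypy: incompatible type "str"; expected "int"
```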
```diff
         self.size = new_size
-        self.mem_state = torch.arange(
-            0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
-        )
-        self.mark_start = 0
-        self.mark_end = self.size
-        self.can_use_mem_size = self.size
-        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
+        self.allocator.resize(new_size)
```
The `HOLD_TOKEN_MEMINDEX` attribute is initialized to `self.size` in `__init__`. When resizing the memory, this attribute should also be updated to reflect the new size, ensuring consistency for any components relying on this marker.
Suggested change:

```python
        self.size = new_size
        self.allocator.resize(new_size)
        self.HOLD_TOKEN_MEMINDEX = self.size
```
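To make the hazard concrete: if `HOLD_TOKEN_MEMINDEX` serves as a reserved sentinel index one past the end of the pool (an assumption based on it being initialized to `self.size`), a grow-only resize would leave the stale sentinel aliasing a real slot:

```python
# illustrative only; assumes HOLD_TOKEN_MEMINDEX is a sentinel equal to size
mgr = MemoryManager(size=4096)  # HOLD_TOKEN_MEMINDEX == 4096
mgr.resize_mem(8192)            # pool now has valid slot indices 0..8191
# without the suggested line, HOLD_TOKEN_MEMINDEX stays 4096 and now
# collides with an allocatable slot; with it, the invariant holds:
assert mgr.HOLD_TOKEN_MEMINDEX == mgr.size
```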
| f"mem manager can alloc token num {self.req_manager.mem_manager.can_use_mem_size}\n" | ||
| f"mem manager total size {self.req_manager.mem_manager.size}" | ||
| f"mem manager can alloc token num {self.req_manager.mem_manager.allocator.can_use_mem_size}\n" | ||
| f"mem manager total size {self.req_manager.mem_manager.allocator.size}\n" |
Adding a trailing newline `\n` at the end of the log message string results in an extra empty line in the output. It is better to keep the log message without the trailing newline, consistent with the original implementation.
| f"mem manager total size {self.req_manager.mem_manager.allocator.size}\n" | |
| f"mem manager total size {self.req_manager.mem_manager.allocator.size}" |
| f"mem manager can alloc token num {self.backend.model.mem_manager.can_use_mem_size}\n" | ||
| f"mem manager total size {self.backend.model.mem_manager.size}" | ||
| f"mem manager can alloc token num {self.backend.model.mem_manager.allocator.can_use_mem_size}\n" | ||
| f"mem manager total size {self.backend.model.mem_manager.allocator.size}\n" |
Similar to the change in infer_batch.py, the trailing newline `\n` here adds an unnecessary empty line to the debug log output. Please remove it for cleaner log formatting.
| f"mem manager total size {self.backend.model.mem_manager.allocator.size}\n" | |
| f"mem manager total size {self.backend.model.mem_manager.allocator.size}" |
No description provided.