feat: support prefix cache for multi-modal model.#997
wly-115 wants to merge 1 commit into jd-opensource:main from wly-115's branch
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces significant support for prefix caching in multi-modal models. It involves extensive changes to block management, the prefix cache implementation, input batching, and model-specific logic — most notably the new MMPrefixCache and the refactored MMItemState. However, several security and stability issues were identified: unchecked variant access and potential out-of-bounds array access that could crash the process (Denial of Service), and an exception-safety issue in the prefix-cache insertion logic that could leave dangling pointers and cause use-after-free vulnerabilities. In addition, a high-severity code-quality issue — an unused member variable — needs to be addressed.
| void hash_mm_items(MMInput& mm_input, MMData& mm_data) { | ||
| std::vector<Murmur3Key> mm_hashes; | ||
| const auto& mm_input_items = mm_input.items(); | ||
| auto& mm_items = mm_data.items<MMItemVec>(); | ||
| mm_hashes.reserve(mm_input_items.size()); | ||
| int size = mm_input_items.size(); | ||
| for (int idx = 0; idx < size; ++idx) { | ||
| auto data = mm_input_items[idx].raw_data; | ||
| if (!data.empty()) { | ||
| auto mm_hash = hash_string(data); | ||
| auto& schedule_data = | ||
| mm_items[idx].mutable_state().mutable_schedule_data(); | ||
| schedule_data.key = mm_hash; | ||
| } | ||
| } | ||
| } |
There was a problem hiding this comment.
The hash_mm_items function iterates over mm_input_items and accesses mm_items using the same index without verifying that mm_items has sufficient size. If the number of items in mm_data (populated by the image processor) is less than the number of items in mm_input, an out-of-bounds access will occur, leading to a crash of the Master process. Additionally, the code calls mm_data.items<MMItemVec>() without verifying that the variant actually holds a MMItemVec, which can throw a std::bad_variant_access exception.
| auto& mm_items = mm_data.items<MMItemVec>(); | ||
|
|
||
| for (auto& mm_item : mm_items) { | ||
| const auto& pos = mm_item.state().token_pos(); | ||
| if (start_index >= pos.offset + pos.length) | ||
| cur_mm_idx++; | ||
| else | ||
| break; | ||
| } | ||
|
|
||
| size_t end_token_index; | ||
| size_t cur_token_index; | ||
|
|
||
| for (cur_token_index = start_index; cur_token_index < n_tokens; | ||
| cur_token_index += block_size_) { | ||
| end_token_index = cur_token_index + block_size_; | ||
| std::vector<const uint8_t*> mm_hash_values = get_block_mm_hash_values( | ||
| mm_data, cur_token_index, end_token_index, cur_mm_idx); | ||
| if (cur_token_index == 0) { | ||
| mm_murmur_hash3(mm_hash_values, | ||
| nullptr, | ||
| token_ids.slice(cur_token_index, end_token_index), | ||
| token_hash_key.data); | ||
| } else { | ||
| mm_murmur_hash3(mm_hash_values, | ||
| token_hash_key.data, | ||
| token_ids.slice(cur_token_index, end_token_index), | ||
| token_hash_key.data); | ||
| } | ||
|
|
||
| auto iter = cached_blocks_.find(token_hash_key); | ||
| if (iter != cached_blocks_.end()) { | ||
| blocks.push_back(iter->second->block); | ||
| lru_lst_.remove_node(iter->second); | ||
| node_list.push_front(iter->second); | ||
| } else { | ||
| break; | ||
| } | ||
| } | ||
|
|
||
| // update LRU list | ||
| while (!node_list.is_empty()) { | ||
| Node* node = node_list.pop_front(); | ||
| lru_lst_.push_back(node); | ||
| } | ||
|
|
||
| matched_blocks_.fetch_add(blocks.size()); | ||
|
|
||
| int64_t int_rate_percent = static_cast<int64_t>( | ||
| static_cast<double>(blocks.size()) * 100.0 / n_blocks); | ||
| HISTOGRAM_OBSERVE(prefix_cache_block_matched_rate, int_rate_percent); | ||
| HISTOGRAM_OBSERVE(prefix_cache_block_matched_num, blocks.size()); | ||
|
|
||
| return blocks; | ||
| } | ||
|
|
||
| size_t MMPrefixCache::insert(Sequence* sequence, | ||
| const Slice<int32_t>& token_ids, | ||
| std::vector<Block>& blocks, | ||
| size_t existed_shared_blocks_num, | ||
| std::vector<Murmur3Key>* insert_keys) { | ||
| const int64_t now = absl::ToUnixMicros(absl::Now()); | ||
|
|
||
| // allign tokens to block boundary | ||
| const size_t n_blocks = | ||
| std::min(token_ids.size() / block_size_, blocks.size()); | ||
| const size_t n_tokens = n_blocks * block_size_; | ||
|
|
||
| if (n_blocks == 0) { | ||
| return 0; | ||
| } | ||
|
|
||
| // truncate the token ids and blocks to boundary | ||
| DNodeList node_list; | ||
| CHECK_GE(n_blocks, existed_shared_blocks_num); | ||
| Murmur3Key token_hash_key = | ||
| existed_shared_blocks_num == 0 | ||
| ? Murmur3Key{} | ||
| : Murmur3Key{blocks[existed_shared_blocks_num - 1] | ||
| .get_immutable_hash_value()}; | ||
| int32_t cur_mm_idx = 0; | ||
| uint32_t block_idx = existed_shared_blocks_num; | ||
| insert_keys->reserve(n_blocks); | ||
| const auto& mm_data = sequence->mm_data(); | ||
| const auto& mm_items = mm_data.items<MMItemVec>(); | ||
| for (auto& mm_item : mm_items) { | ||
| const auto& pos = mm_item.state().token_pos(); | ||
| if (existed_shared_blocks_num * block_size_ >= pos.offset + pos.length) { | ||
| cur_mm_idx++; | ||
| } else { | ||
| break; | ||
| } | ||
| } | ||
| for (size_t i = existed_shared_blocks_num * block_size_; i < n_tokens; | ||
| i += block_size_) { | ||
| std::vector<const uint8_t*> mm_hash_values = | ||
| get_block_mm_hash_values(mm_data, i, i + block_size_, cur_mm_idx); | ||
| if (i == 0) { | ||
| mm_murmur_hash3(mm_hash_values, | ||
| nullptr, | ||
| token_ids.slice(i, i + block_size_), | ||
| token_hash_key.data); | ||
| } else { | ||
| mm_murmur_hash3(mm_hash_values, | ||
| token_hash_key.data, | ||
| token_ids.slice(i, i + block_size_), | ||
| token_hash_key.data); | ||
| } | ||
| blocks[block_idx].set_hash_value(token_hash_key.data); | ||
|
|
||
| auto iter = cached_blocks_.find(token_hash_key); | ||
| if (iter != cached_blocks_.end()) { | ||
| iter->second->last_access_time = now; | ||
|
|
||
| lru_lst_.remove_node(iter->second); | ||
| node_list.push_front(iter->second); | ||
| } else { | ||
| Node* new_node = new Node(); | ||
|
|
||
| new_node->block = blocks[block_idx]; | ||
| new_node->last_access_time = now; | ||
|
|
||
| node_list.push_front(new_node); | ||
|
|
||
| cached_blocks_.emplace(std::make_pair(token_hash_key, new_node)); | ||
|
|
||
| num_blocks_++; | ||
|
|
||
| insert_keys->emplace_back(token_hash_key.data); | ||
| } | ||
|
|
||
| ++block_idx; | ||
| } | ||
|
|
||
| while (!node_list.is_empty()) { | ||
| Node* node = node_list.pop_front(); | ||
| lru_lst_.push_back(node); | ||
| } | ||
|
|
||
| return n_tokens; | ||
| } | ||
| std::vector<const uint8_t*> MMPrefixCache::get_block_mm_hash_values( | ||
| const MMData& mm_data, | ||
| int32_t start_token_idx, | ||
| int32_t end_token_idx, | ||
| int32_t& start_mm_idx) { | ||
| auto& mm_items = mm_data.items<MMItemVec>(); |
There was a problem hiding this comment.
The code frequently calls mm_data.items<MMItemVec>() (e.g., lines 94, 178, 240) without first verifying that the mm_data variant actually holds a MMItemVec (using hold<MMItemVec>()). If mm_data holds a different type (such as MMDict), these calls will throw a std::bad_variant_access exception, causing the engine process to crash.
| Node* new_node = new Node(); | ||
|
|
||
| new_node->block = blocks[block_idx]; | ||
| new_node->last_access_time = now; | ||
|
|
||
| node_list.push_front(new_node); | ||
|
|
||
| cached_blocks_.emplace(std::make_pair(token_hash_key, new_node)); | ||
|
|
||
| num_blocks_++; | ||
|
|
||
| insert_keys->emplace_back(token_hash_key.data); | ||
| } |
There was a problem hiding this comment.
In MMPrefixCache::insert, a new Node is allocated and added to a local DNodeList and the cached_blocks_ map. If insert_keys->emplace_back on line 222 throws an exception (e.g., std::bad_alloc), the node_list destructor will delete the Node, but the cached_blocks_ map will still contain a pointer to the deleted memory. Subsequent access to this key in the cache will result in a use-after-free vulnerability.
| int32_t& start_mm_idx); | ||
|
|
||
| private: | ||
| ThreadPool threadpool_; |
There was a problem hiding this comment.
The threadpool_ member in MMPrefixCache is declared but appears to be unused throughout the implementation in mm_prefix_cache.cpp. If there are no plans for its immediate use, it should be removed to avoid confusion and unnecessary overhead. Unused members can make the code harder to understand and maintain, and in performance-critical code, it might imply an incomplete implementation of parallelization.
c4925bd to
0717c5f
Compare
No description provided.