
feat: support prefix cache for multi-modal model. #997

Draft
wly-115 wants to merge 1 commit into jd-opensource:main from wly-115:feat/mm_prefix_cache

Conversation

@wly-115 (Collaborator) commented Mar 4, 2026

No description provided.

@gemini-code-assist bot left a comment


Code Review

This pull request introduces prefix-cache support for multi-modal models, with extensive changes to block management, the prefix cache implementation, input batching, and model-specific logic, most notably the new MMPrefixCache and the refactored MMItemState. Several security and stability issues were identified: unchecked variant access and potential out-of-bounds array access that could crash the process (a denial-of-service risk), and an exception-safety issue in the prefix cache insertion logic that can leave a dangling pointer in the cache map, leading to use-after-free. A high-severity code quality issue, an unused member variable, should also be addressed.

Comment on lines +34 to +49
void hash_mm_items(MMInput& mm_input, MMData& mm_data) {
  std::vector<Murmur3Key> mm_hashes;
  const auto& mm_input_items = mm_input.items();
  auto& mm_items = mm_data.items<MMItemVec>();
  mm_hashes.reserve(mm_input_items.size());
  int size = mm_input_items.size();
  for (int idx = 0; idx < size; ++idx) {
    auto data = mm_input_items[idx].raw_data;
    if (!data.empty()) {
      auto mm_hash = hash_string(data);
      auto& schedule_data =
          mm_items[idx].mutable_state().mutable_schedule_data();
      schedule_data.key = mm_hash;
    }
  }
}


security-high

The hash_mm_items function iterates over mm_input_items and accesses mm_items using the same index without verifying that mm_items has sufficient size. If the number of items in mm_data (populated by the image processor) is less than the number of items in mm_input, an out-of-bounds access will occur, leading to a crash of the Master process. Additionally, the code calls mm_data.items<MMItemVec>() without verifying that the variant actually holds a MMItemVec, which can throw a std::bad_variant_access exception.
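A hardened version of this pattern would guard both the variant alternative and the index range before writing. The sketch below is illustrative only, using simplified stand-in types (ItemStub, ItemVecStub, hash_mm_items_safe are all hypothetical names) rather than the real MMInput/MMData API:

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <string>
#include <variant>
#include <vector>

// Simplified stand-ins for the PR's multi-modal types; the real MMInput,
// MMData, and Murmur3Key interfaces differ.
struct ItemStub { std::string raw_data; size_t key = 0; };
using ItemVecStub = std::vector<ItemStub>;
using ItemsStub = std::variant<std::monostate, ItemVecStub>;

// Hash each input item's raw bytes into the matching item slot, guarding
// both the variant alternative and the index range so a size mismatch or a
// wrong alternative degrades gracefully instead of crashing.
size_t hash_mm_items_safe(const std::vector<std::string>& raw_inputs,
                          ItemsStub& items) {
  auto* vec = std::get_if<ItemVecStub>(&items);
  if (vec == nullptr) return 0;  // wrong alternative: no-op, no throw
  const size_t n = std::min(raw_inputs.size(), vec->size());  // bounds guard
  size_t hashed = 0;
  for (size_t i = 0; i < n; ++i) {
    if (!raw_inputs[i].empty()) {
      (*vec)[i].key = std::hash<std::string>{}(raw_inputs[i]);
      ++hashed;
    }
  }
  return hashed;
}
```

With three inputs but only two items, the loop stops at two instead of reading past the end; with a non-vector alternative, the function returns without touching anything.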

Comment on lines +94 to +240
  auto& mm_items = mm_data.items<MMItemVec>();

  for (auto& mm_item : mm_items) {
    const auto& pos = mm_item.state().token_pos();
    if (start_index >= pos.offset + pos.length)
      cur_mm_idx++;
    else
      break;
  }

  size_t end_token_index;
  size_t cur_token_index;

  for (cur_token_index = start_index; cur_token_index < n_tokens;
       cur_token_index += block_size_) {
    end_token_index = cur_token_index + block_size_;
    std::vector<const uint8_t*> mm_hash_values = get_block_mm_hash_values(
        mm_data, cur_token_index, end_token_index, cur_mm_idx);
    if (cur_token_index == 0) {
      mm_murmur_hash3(mm_hash_values,
                      nullptr,
                      token_ids.slice(cur_token_index, end_token_index),
                      token_hash_key.data);
    } else {
      mm_murmur_hash3(mm_hash_values,
                      token_hash_key.data,
                      token_ids.slice(cur_token_index, end_token_index),
                      token_hash_key.data);
    }

    auto iter = cached_blocks_.find(token_hash_key);
    if (iter != cached_blocks_.end()) {
      blocks.push_back(iter->second->block);
      lru_lst_.remove_node(iter->second);
      node_list.push_front(iter->second);
    } else {
      break;
    }
  }

  // update LRU list
  while (!node_list.is_empty()) {
    Node* node = node_list.pop_front();
    lru_lst_.push_back(node);
  }

  matched_blocks_.fetch_add(blocks.size());

  int64_t int_rate_percent = static_cast<int64_t>(
      static_cast<double>(blocks.size()) * 100.0 / n_blocks);
  HISTOGRAM_OBSERVE(prefix_cache_block_matched_rate, int_rate_percent);
  HISTOGRAM_OBSERVE(prefix_cache_block_matched_num, blocks.size());

  return blocks;
}

size_t MMPrefixCache::insert(Sequence* sequence,
                             const Slice<int32_t>& token_ids,
                             std::vector<Block>& blocks,
                             size_t existed_shared_blocks_num,
                             std::vector<Murmur3Key>* insert_keys) {
  const int64_t now = absl::ToUnixMicros(absl::Now());

  // allign tokens to block boundary
  const size_t n_blocks =
      std::min(token_ids.size() / block_size_, blocks.size());
  const size_t n_tokens = n_blocks * block_size_;

  if (n_blocks == 0) {
    return 0;
  }

  // truncate the token ids and blocks to boundary
  DNodeList node_list;
  CHECK_GE(n_blocks, existed_shared_blocks_num);
  Murmur3Key token_hash_key =
      existed_shared_blocks_num == 0
          ? Murmur3Key{}
          : Murmur3Key{blocks[existed_shared_blocks_num - 1]
                           .get_immutable_hash_value()};
  int32_t cur_mm_idx = 0;
  uint32_t block_idx = existed_shared_blocks_num;
  insert_keys->reserve(n_blocks);
  const auto& mm_data = sequence->mm_data();
  const auto& mm_items = mm_data.items<MMItemVec>();
  for (auto& mm_item : mm_items) {
    const auto& pos = mm_item.state().token_pos();
    if (existed_shared_blocks_num * block_size_ >= pos.offset + pos.length) {
      cur_mm_idx++;
    } else {
      break;
    }
  }
  for (size_t i = existed_shared_blocks_num * block_size_; i < n_tokens;
       i += block_size_) {
    std::vector<const uint8_t*> mm_hash_values =
        get_block_mm_hash_values(mm_data, i, i + block_size_, cur_mm_idx);
    if (i == 0) {
      mm_murmur_hash3(mm_hash_values,
                      nullptr,
                      token_ids.slice(i, i + block_size_),
                      token_hash_key.data);
    } else {
      mm_murmur_hash3(mm_hash_values,
                      token_hash_key.data,
                      token_ids.slice(i, i + block_size_),
                      token_hash_key.data);
    }
    blocks[block_idx].set_hash_value(token_hash_key.data);

    auto iter = cached_blocks_.find(token_hash_key);
    if (iter != cached_blocks_.end()) {
      iter->second->last_access_time = now;

      lru_lst_.remove_node(iter->second);
      node_list.push_front(iter->second);
    } else {
      Node* new_node = new Node();

      new_node->block = blocks[block_idx];
      new_node->last_access_time = now;

      node_list.push_front(new_node);

      cached_blocks_.emplace(std::make_pair(token_hash_key, new_node));

      num_blocks_++;

      insert_keys->emplace_back(token_hash_key.data);
    }

    ++block_idx;
  }

  while (!node_list.is_empty()) {
    Node* node = node_list.pop_front();
    lru_lst_.push_back(node);
  }

  return n_tokens;
}

std::vector<const uint8_t*> MMPrefixCache::get_block_mm_hash_values(
    const MMData& mm_data,
    int32_t start_token_idx,
    int32_t end_token_idx,
    int32_t& start_mm_idx) {
  auto& mm_items = mm_data.items<MMItemVec>();


security-high

The code frequently calls mm_data.items<MMItemVec>() (e.g., lines 94, 178, 240) without first verifying that the mm_data variant actually holds a MMItemVec (using hold<MMItemVec>()). If mm_data holds a different type (such as MMDict), these calls will throw a std::bad_variant_access exception, causing the engine process to crash.
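The failure mode and the guard can be demonstrated with plain std::variant. The stub types and the safe_item_count helper below are illustrative assumptions, not the project's MMData API:

```cpp
#include <cassert>
#include <map>
#include <string>
#include <variant>
#include <vector>

// Illustrative stand-ins: the real MMData wraps its items in a variant of
// an item vector (MMItemVec) and a dictionary (MMDict).
using MMItemVecStub = std::vector<int>;
using MMDictStub = std::map<std::string, int>;
using MMDataStub = std::variant<MMItemVecStub, MMDictStub>;

// std::get_if returns nullptr on a mismatched alternative, so the caller
// can bail out cleanly instead of letting std::bad_variant_access propagate.
int safe_item_count(const MMDataStub& mm_data) {
  if (const auto* vec = std::get_if<MMItemVecStub>(&mm_data)) {
    return static_cast<int>(vec->size());
  }
  return -1;  // signal "no item vector" instead of throwing
}
```

An unguarded std::get<MMItemVecStub>(mm_data) on the dictionary alternative throws std::bad_variant_access, which is the crash the comment warns about; the equivalent fix in the PR would be checking hold<MMItemVec>() before calling items<MMItemVec>().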

Comment on lines +211 to +223
      Node* new_node = new Node();

      new_node->block = blocks[block_idx];
      new_node->last_access_time = now;

      node_list.push_front(new_node);

      cached_blocks_.emplace(std::make_pair(token_hash_key, new_node));

      num_blocks_++;

      insert_keys->emplace_back(token_hash_key.data);
    }


security-high

In MMPrefixCache::insert, a new Node is allocated and added to a local DNodeList and the cached_blocks_ map. If insert_keys->emplace_back on line 222 throws an exception (e.g., std::bad_alloc), the node_list destructor will delete the Node, but the cached_blocks_ map will still contain a pointer to the deleted memory. Subsequent access to this key in the cache will result in a use-after-free vulnerability.
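One way to close this window is to keep the new node owned by a std::unique_ptr and publish the raw pointer into the map only after every throwing operation has completed. The sketch below uses simplified stand-in types (NodeStub and insert_block_safe are hypothetical; the real Node also carries a Block and LRU-list hooks):

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <unordered_map>
#include <vector>

// Simplified stand-in for the cache node.
struct NodeStub { int block = 0; };

// Exception-safe insertion order: run every operation that can throw (here
// the insert_keys push) while the node is still owned by the unique_ptr,
// publish into the map last, and only then release ownership. If any step
// throws, the node is destroyed and the map never holds a dangling pointer.
void insert_block_safe(std::unordered_map<uint64_t, NodeStub*>& cached_blocks,
                       std::vector<uint64_t>* insert_keys,
                       uint64_t key, int block) {
  auto node = std::make_unique<NodeStub>();
  node->block = block;
  insert_keys->emplace_back(key);          // may throw: node still owned
  cached_blocks.emplace(key, node.get());  // may throw: node still owned
  node.release();                          // success: hand off ownership
}
```

The same ordering applies to the local DNodeList in the PR: pushing the node onto the list should be the last step, after the map and key-vector updates can no longer throw.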

                                 int32_t& start_mm_idx);

 private:
  ThreadPool threadpool_;


high

The threadpool_ member in MMPrefixCache is declared but appears to be unused throughout the implementation in mm_prefix_cache.cpp. If there are no plans for its immediate use, it should be removed to avoid confusion and unnecessary overhead. Unused members can make the code harder to understand and maintain, and in performance-critical code, it might imply an incomplete implementation of parallelization.

wly-115 force-pushed the feat/mm_prefix_cache branch from c4925bd to 0717c5f on March 5, 2026 at 08:34
wly-115 marked this pull request as draft on March 5, 2026 at 08:39