Skip to content

Commit c4925bd

Browse files
committed
feat: support prefix cache for multi-modal model.
1 parent 5137f4e commit c4925bd

55 files changed

Lines changed: 1059 additions & 325 deletions

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

xllm/core/distributed_runtime/vlm_engine.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,8 @@ bool VLMEngine::allocate_kv_cache(const Engine::KVCacheCapacity& kv_cache_cap) {
288288
.block_size(block_size)
289289
.enable_prefix_cache(options_.enable_prefix_cache())
290290
.enable_disagg_pd(options_.enable_disagg_pd())
291-
.enable_cache_upload(options_.enable_cache_upload());
291+
.enable_cache_upload(options_.enable_cache_upload())
292+
.enable_mm_prefix_cache(options_.enable_prefix_cache());
292293
kv_cache_manager_ = std::make_unique<BlockManagerPool>(options);
293294

294295
// init kv cache for each worker in parallel

xllm/core/distributed_runtime/vlm_master.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,7 @@ std::shared_ptr<Request> VLMMaster::generate_request(
418418
"Image processor process failed.");
419419
return nullptr;
420420
}
421-
421+
input_processor_->hash_mm_items(mm_inputs, mm_data);
422422
auto prompt = chat_template_->apply(messages);
423423
if (!prompt.has_value()) {
424424
CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT,

xllm/core/framework/batch/batch.cpp

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,6 @@ void Batch::add(Sequence* sequence, uint32_t allowed_max_token) {
5353
if (input_embedding.defined())
5454
input_embeddings_vec_.emplace_back(input_embedding);
5555

56-
const auto& mm_data = sequence->get_mm_data();
57-
// if (sequence->is_chunked_prefill_stage() && mm_data.valid())
58-
// TODO:Compatible With Chunked Prefill
59-
if ((sequence->stage() == SequenceStage::PREFILL) && mm_data.valid()) {
60-
mm_data_vec_.emplace_back(mm_data);
61-
}
6256
update_forward_type(sequence);
6357
}
6458

@@ -315,9 +309,8 @@ void Batch::process_sample_output(const RawForwardOutput& raw_output,
315309
}
316310
CHECK_LT(output_idx, num_seqs);
317311

318-
// mm embed task
319-
if (raw_output.mm_embeddings.size() > 0) {
320-
int64_t n_images = seq->get_mm_data().size();
312+
if (raw_output.mm_embeddings.size() > 0) { // mm embed task
313+
int64_t n_images = seq->mm_data().size();
321314
if (n_images > 0) {
322315
std::vector<torch::Tensor> seq_mm_embeddings;
323316
seq_mm_embeddings.reserve(n_images);

xllm/core/framework/batch/batch_input_builder.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ limitations under the License.
2828
#include "framework/model/model_input_params.h"
2929
#include "framework/request/sequence.h"
3030
#include "framework/sampling/sampling_params.h"
31+
#include "request/mm_data_visitor.h"
3132
#include "runtime/params_utils.h"
3233
#include "util/blocking_counter.h"
3334
#include "util/slice.h"
@@ -303,6 +304,8 @@ void BatchInputBuilder::process_single_sequence(
303304
state.seq_lens.push_back(state.seq_lens.back() + seq_len + offset);
304305
state.q_seq_lens.push_back(state.q_seq_lens.back() + q_seq_len);
305306
#endif
307+
// Process multi-modal input
308+
process_multi_modal_inputs(sequence, n_kv_cache_tokens, q_seq_len);
306309
// Process tokens and positions
307310
extract_tokens_and_positions(sequence, n_kv_cache_tokens, seq_len, state_ptr);
308311

@@ -340,7 +343,11 @@ void BatchInputBuilder::extract_tokens_and_positions(Sequence* sequence,
340343
if (use_mrope_) {
341344
const auto& args = *args_;
342345
MPositionHelper helper(*sequence, args);
343-
state.mrope_positions_vec.emplace_back(helper.get_positions());
346+
const auto& whole_positions = helper.get_positions();
347+
auto position = (sequence->stage() == SequenceStage::DECODE)
348+
? whole_positions
349+
: whole_positions.slice(1, n_kv_cache_tokens, seq_len);
350+
state.mrope_positions_vec.push_back(position);
344351
}
345352

346353
// Process each token
@@ -734,4 +741,16 @@ void BatchInputBuilder::process_swap_block_infos(
734741
swap_block_transfer_infos_->end());
735742
}
736743
}
744+
745+
void BatchInputBuilder::process_multi_modal_inputs(Sequence* sequence,
746+
uint32_t n_kv_cache_tokens,
747+
uint32_t q_seq_len) {
748+
MMData& mm_data = sequence->mutable_mm_data();
749+
if ((sequence->stage() != SequenceStage::DECODE) && mm_data.valid()) {
750+
UpdateMMItemScheduleStateVisitor visitor(n_kv_cache_tokens, q_seq_len);
751+
mm_data.foreach (visitor);
752+
MMType ty{static_cast<MMType::Value>(mm_data.type())};
753+
mm_data_vec_.emplace_back(MMData(ty, std::move(visitor.mm_data_items_)));
754+
}
755+
}
737756
} // namespace xllm

xllm/core/framework/batch/batch_input_builder.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ class BatchInputBuilder {
5555
void process_sequences_multithreaded();
5656
void padding_decode_batch_size(uint32_t num_decoding_tokens,
5757
uint32_t min_decoding_batch_size);
58+
void process_multi_modal_inputs(Sequence* sequence,
59+
uint32_t n_kv_cache_tokens,
60+
uint32_t q_seq_len);
5861
ForwardInput state_to_forward_input();
5962
RawForwardInput state_to_raw_forward_input();
6063

@@ -145,7 +148,7 @@ class BatchInputBuilder {
145148
const std::vector<Sequence*>& sequences_;
146149
const std::vector<uint32_t>& allowed_max_tokens_;
147150
const std::vector<torch::Tensor>& input_embeddings_vec_;
148-
const std::vector<MMData>& mm_data_vec_;
151+
std::vector<MMData> mm_data_vec_;
149152
const ModelArgs* args_;
150153

151154
// Builder state

xllm/core/framework/batch/mposition.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ limitations under the License.
1818
#include <absl/strings/match.h>
1919

2020
#include "framework/model/model_args.h"
21+
#include "framework/request/mm_batch_data.h"
2122
#include "framework/request/sequence.h"
2223

2324
namespace xllm {
@@ -46,7 +47,8 @@ std::vector<std::tuple<std::string, int, int>> groupByTokenType(
4647
torch::Tensor MPositionHelper::get_positions() {
4748
// if (seq_.is_chunked_prefill_stage()) {
4849
if (seq_.kv_state().kv_cache_tokens_num() < seq_.num_prompt_tokens()) {
49-
auto& mm_data = seq_.get_mm_data();
50+
auto& data = seq_.mm_data();
51+
MMBatchData mm_data({data});
5052

5153
torch::Tensor image_grid_thw;
5254
if (auto res = mm_data.get<torch::Tensor>("image_grid_thw"))

xllm/core/framework/batch/onerec_batch_input_builder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ ForwardInput OneRecBatchInputBuilder::build_rec_forward_input(
172172
src_ptr + group_encoder_seq_len);
173173
}
174174
// Collect sparse_embedding
175-
auto mm_data = sequence->get_mm_data();
175+
auto mm_data = sequence->mm_data();
176176
auto sparse_embedding_optional =
177177
mm_data.get<torch::Tensor>(Sequence::ENCODER_SPARSE_EMBEDDING_NAME);
178178
if (sparse_embedding_optional.has_value()) {

xllm/core/framework/block/block_manager.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class BlockManager {
4545
PROPERTY(int32_t, block_size) = 0;
4646
PROPERTY(bool, enable_prefix_cache) = true;
4747
PROPERTY(bool, enable_disagg_pd) = false;
48+
PROPERTY(bool, enable_mm_prefix_cache) = false;
4849
PROPERTY(bool, enable_cache_upload) = false;
4950
};
5051

@@ -56,10 +57,12 @@ class BlockManager {
5657
virtual std::vector<Block> allocate(size_t num_blocks) = 0;
5758

5859
virtual std::vector<Block> allocate_shared(
60+
Sequence* sequence,
5961
const Slice<int32_t>& tokens_ids,
6062
const Slice<Block>& existed_shared_blocks = {}) = 0;
6163

62-
virtual void cache(const Slice<int32_t>& token_ids,
64+
virtual void cache(Sequence* sequence,
65+
const Slice<int32_t>& token_ids,
6366
std::vector<Block>& blocks,
6467
size_t existed_shared_blocks_num = 0) = 0;
6568
virtual void cache(const std::vector<Block>& blocks) = 0;

xllm/core/framework/block/block_manager_impl.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,11 @@ BlockManagerImpl::BlockManagerImpl(const Options& options)
2626
CHECK_GT(options.num_blocks(), 0) << "No blocks to allocate";
2727
CHECK_GT(options.block_size(), 0) << "Block size must be positive";
2828
if (options_.enable_prefix_cache()) {
29-
prefix_cache_ = create_prefix_cache(options.block_size(),
30-
options.enable_cache_upload());
29+
PrefixCache::Options prefix_cache_options;
30+
prefix_cache_options.block_size(options.block_size())
31+
.enable_cache_upload(options.enable_cache_upload())
32+
.enable_mm_prefix_cache(options.enable_mm_prefix_cache());
33+
prefix_cache_ = create_prefix_cache(prefix_cache_options);
3134
CHECK(prefix_cache_) << "Failed to create prefix cache!";
3235
}
3336

@@ -122,14 +125,15 @@ bool BlockManagerImpl::has_enough_blocks(uint32_t num_blocks) {
122125
}
123126

124127
std::vector<Block> BlockManagerImpl::allocate_shared(
128+
Sequence* sequence,
125129
const Slice<int32_t>& tokens_ids,
126130
const Slice<Block>& existed_shared_blocks) {
127131
// only allocate shared blocks for prefill sequences
128132
if (options_.enable_prefix_cache()) {
129133
AUTO_COUNTER(prefix_cache_latency_seconds_match);
130134

131135
std::vector<Block> shared_blocks =
132-
prefix_cache_->match(tokens_ids, existed_shared_blocks);
136+
prefix_cache_->match(sequence, tokens_ids, existed_shared_blocks);
133137

134138
const size_t prefix_length =
135139
shared_blocks.empty() ? 0
@@ -148,13 +152,17 @@ std::vector<Block> BlockManagerImpl::allocate_shared(
148152
return {};
149153
}
150154

151-
void BlockManagerImpl::cache(const Slice<int32_t>& token_ids,
155+
void BlockManagerImpl::cache(Sequence* sequence,
156+
const Slice<int32_t>& token_ids,
152157
std::vector<Block>& blocks,
153158
size_t existed_shared_blocks_num) {
154159
if (options_.enable_prefix_cache()) {
155160
AUTO_COUNTER(prefix_cache_latency_seconds_insert);
156161
// Add the kv cache to the prefix cache
157-
prefix_cache_->insert(token_ids, blocks, existed_shared_blocks_num);
162+
prefix_cache_->insert(sequence,
163+
token_ids,
164+
blocks,
165+
existed_shared_blocks_num);
158166
}
159167
}
160168

xllm/core/framework/block/block_manager_impl.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,13 @@ class BlockManagerImpl : public BlockManager {
3737

3838
// allocate shared blocks when enable prefix cache
3939
std::vector<Block> allocate_shared(
40+
Sequence* sequence,
4041
const Slice<int32_t>& tokens_ids,
4142
const Slice<Block>& existed_shared_blocks = {}) override;
4243

4344
// cache blocks when enable prefix cache
44-
void cache(const Slice<int32_t>& token_ids,
45+
void cache(Sequence* sequence,
46+
const Slice<int32_t>& token_ids,
4547
std::vector<Block>& blocks,
4648
size_t existed_shared_blocks_num = 0) override;
4749
void cache(const std::vector<Block>& blocks) override;

0 commit comments

Comments
 (0)