Skip to content

Commit bd473fc

Browse files
committed
feat: support deepseek prefixcache.
1 parent fdec742 commit bd473fc

File tree

3 files changed

+11
-5
lines changed

3 files changed

+11
-5
lines changed

xllm/core/runtime/llm_engine.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ limitations under the License.
4141

4242
namespace xllm {
4343

44+
constexpr int32_t NZ_ALIGNMENT = 16;
45+
4446
namespace {
4547
uint32_t determine_micro_batches_num(const std::vector<Batch>& batch) {
4648
bool not_all_in_decode =
@@ -258,8 +260,10 @@ Engine::KVCacheCapacity LLMEngine::estimate_kv_cache_capacity() {
258260
int64_t slot_size = 0;
259261
if (FLAGS_enable_mla) {
260262
if (FLAGS_enable_prefix_cache) {
261-
slot_size = dtype_size * ((args_.kv_lora_rank() + 15) / 16 * 16 +
262-
(args_.qk_rope_head_dim() + 15) / 16 * 16);
263+
slot_size =
264+
dtype_size *
265+
((args_.kv_lora_rank() + NZ_ALIGNMENT - 1) / NZ_ALIGNMENT +
266+
(args_.qk_rope_head_dim() + NZ_ALIGNMENT - 1) / NZ_ALIGNMENT);
263267
} else {
264268
slot_size =
265269
dtype_size * (args_.kv_lora_rank() + args_.qk_rope_head_dim());

xllm/core/runtime/worker_impl.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ limitations under the License.
4848
namespace xllm {
4949

5050
constexpr uint64_t MBUF_SIZE = 128 * 1024 * 1024;
51+
constexpr int32_t FORMAT_ND = 2;
52+
constexpr int32_t FORMAT_NZ = 29;
5153

5254
WorkerImpl::WorkerImpl(const ParallelArgs& parallel_args,
5355
const torch::Device& device,
@@ -88,7 +90,7 @@ bool WorkerImpl::allocate_kv_cache(
8890
torch::Tensor key_cache, value_cache;
8991
#if defined(USE_NPU)
9092
int32_t npu_format_type =
91-
FLAGS_enable_mla && FLAGS_enable_prefix_cache ? 29 : 2;
93+
FLAGS_enable_mla && FLAGS_enable_prefix_cache ? FORMAT_NZ : FORMAT_ND;
9294
key_cache = at_npu::native::npu_format_cast(
9395
torch::empty(kv_cache_shape[0], torch::dtype(dtype_).device(device_)),
9496
npu_format_type);

xllm/core/util/utils.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ torch::ScalarType parse_dtype(const std::string& dtype_str,
8080
}
8181
if ((boost::iequals(dtype_str, "float") ||
8282
boost::iequals(dtype_str, "float32"))) {
83-
return torch::kFloat16;
83+
return torch::kFloat32;
8484
}
8585

8686
if (dtype_str.empty() || boost::iequals(dtype_str, "auto")) {
@@ -99,7 +99,7 @@ torch::ScalarType parse_dtype(const std::string& dtype_str) {
9999
}
100100
if ((boost::iequals(dtype_str, "float") ||
101101
boost::iequals(dtype_str, "float32"))) {
102-
return torch::kFloat16;
102+
return torch::kFloat32;
103103
}
104104

105105
if (dtype_str.empty() || boost::iequals(dtype_str, "auto")) {

0 commit comments

Comments
 (0)