File tree Expand file tree Collapse file tree 3 files changed +11
-5
lines changed
Expand file tree Collapse file tree 3 files changed +11
-5
lines changed Original file line number Diff line number Diff line change @@ -41,6 +41,8 @@ limitations under the License.
4141
4242namespace xllm {
4343
44+ constexpr int32_t NZ_ALIGNMENT = 16 ;
45+
4446namespace {
4547uint32_t determine_micro_batches_num (const std::vector<Batch>& batch) {
4648 bool not_all_in_decode =
@@ -258,8 +260,10 @@ Engine::KVCacheCapacity LLMEngine::estimate_kv_cache_capacity() {
258260 int64_t slot_size = 0 ;
259261 if (FLAGS_enable_mla) {
260262 if (FLAGS_enable_prefix_cache) {
261- slot_size = dtype_size * ((args_.kv_lora_rank () + 15 ) / 16 * 16 +
262- (args_.qk_rope_head_dim () + 15 ) / 16 * 16 );
263+ slot_size =
264+ dtype_size *
265+ ((args_.kv_lora_rank () + NZ_ALIGNMENT - 1 ) / NZ_ALIGNMENT +
266+ (args_.qk_rope_head_dim () + NZ_ALIGNMENT - 1 ) / NZ_ALIGNMENT);
263267 } else {
264268 slot_size =
265269 dtype_size * (args_.kv_lora_rank () + args_.qk_rope_head_dim ());
Original file line number Diff line number Diff line change @@ -48,6 +48,8 @@ limitations under the License.
4848namespace xllm {
4949
5050constexpr uint64_t MBUF_SIZE = 128 * 1024 * 1024 ;
51+ constexpr int32_t FORMAT_ND = 2 ;
52+ constexpr int32_t FORMAT_NZ = 29 ;
5153
5254WorkerImpl::WorkerImpl (const ParallelArgs& parallel_args,
5355 const torch::Device& device,
@@ -88,7 +90,7 @@ bool WorkerImpl::allocate_kv_cache(
8890 torch::Tensor key_cache, value_cache;
8991#if defined(USE_NPU)
9092 int32_t npu_format_type =
91- FLAGS_enable_mla && FLAGS_enable_prefix_cache ? 29 : 2 ;
93+ FLAGS_enable_mla && FLAGS_enable_prefix_cache ? FORMAT_NZ : FORMAT_ND ;
9294 key_cache = at_npu::native::npu_format_cast (
9395 torch::empty (kv_cache_shape[0 ], torch::dtype (dtype_).device (device_)),
9496 npu_format_type);
Original file line number Diff line number Diff line change @@ -80,7 +80,7 @@ torch::ScalarType parse_dtype(const std::string& dtype_str,
8080 }
8181 if ((boost::iequals (dtype_str, " float" ) ||
8282 boost::iequals (dtype_str, " float32" ))) {
83- return torch::kFloat16 ;
83+ return torch::kFloat32 ;
8484 }
8585
8686 if (dtype_str.empty () || boost::iequals (dtype_str, " auto" )) {
@@ -99,7 +99,7 @@ torch::ScalarType parse_dtype(const std::string& dtype_str) {
9999 }
100100 if ((boost::iequals (dtype_str, " float" ) ||
101101 boost::iequals (dtype_str, " float32" ))) {
102- return torch::kFloat16 ;
102+ return torch::kFloat32 ;
103103 }
104104
105105 if (dtype_str.empty () || boost::iequals (dtype_str, " auto" )) {
You can’t perform that action at this time.
0 commit comments