feat: support qwen3-next on npu device. #945
liyu119 wants to merge 7 commits into jd-opensource:main
Conversation
Code Review
This pull request adds support for the qwen3-next model on NPU devices, introducing new model architecture files, custom kernels, and updates to KV cache management for linear attention. The changes are extensive and well-structured. However, I've identified a few critical issues related to incorrect memory allocation for the new caches and a constructor signature mismatch that would lead to compilation failure. These issues need to be addressed to ensure correctness and allow the code to compile.
int64_t head_k_dim = args_.linear_value_head_dim();
int64_t head_v_dim = args_.linear_key_head_dim();
There appears to be a variable naming swap here. head_k_dim is being initialized with linear_value_head_dim, and head_v_dim with linear_key_head_dim. This is likely to cause incorrect calculations for linear_ssm_slot_size and linear_conv_slot_size, leading to memory allocation errors or incorrect behavior. Please swap the initializations to match the variable names.
Suggested change:
- int64_t head_k_dim = args_.linear_value_head_dim();
- int64_t head_v_dim = args_.linear_key_head_dim();
+ int64_t head_k_dim = args_.linear_key_head_dim();
+ int64_t head_v_dim = args_.linear_value_head_dim();
      args_.linear_key_head_dim() * n_local_linear_v_heads_, args_.linear_conv_kernel_dim() - 1});
  kv_cache_shape.emplace_back(std::vector<int64_t>{
      kv_cache_cap.n_blocks, n_local_linear_v_heads_, args_.linear_key_head_dim(),
      args_.linear_key_head_dim()});
The shape for the SSM cache appears to be incorrect. Both the third and fourth dimensions are set to args_.linear_key_head_dim(). The SSM state typically has dimensions corresponding to key and value head dimensions (k_dim, v_dim). It should likely be args_.linear_value_head_dim() for the last dimension to correctly represent the state.
Suggested change:
- args_.linear_key_head_dim()});
+ args_.linear_value_head_dim()});
KVCache(torch::Tensor key_cache,
        torch::Tensor value_cache,
        torch::Tensor conv_cache,
        torch::Tensor ssm_cache);
This new constructor for KVCache takes four torch::Tensor arguments. However, it is being called with five arguments (key_cache, value_cache, index_cache, conv_cache, ssm_cache) in xllm/core/runtime/worker_impl.cpp on line 148. This will cause a compilation error. The constructor should be updated to accept all five tensors to correctly initialize all cache types. The implementation in kv_cache.cpp will also need to be updated to initialize all five members.
KVCache(torch::Tensor key_cache,
        torch::Tensor value_cache,
        torch::Tensor index_cache,
        torch::Tensor conv_cache,
        torch::Tensor ssm_cache);

  }
#endif
- kv_caches_.emplace_back(key_cache, value_cache, index_cache);
+ kv_caches_.emplace_back(key_cache, value_cache, index_cache, conv_cache, ssm_cache);
This line attempts to construct a KVCache object with five arguments. However, there is no matching constructor defined for KVCache that accepts five tensors. The newly added constructor in kv_cache.h only takes four arguments. This will result in a compilation error. Please ensure the KVCache class has a constructor that matches this call.
@@ -0,0 +1,44 @@
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
Place this file under models/llm/npu.
Put it under models/llm/ instead; this is generic torch graph-building code, not atb graph-building code.
// qwen3 next
PROPERTY(bool, attn_output_gate) = true;
PROPERTY(int32_t, full_attention_interval) = 4;
The default value of full_attention_interval should be set to 1, so that other models that don't have this config still behave correctly.
  return padded_qkvz;
}
std::vector<torch::Tensor> valid_batches;
int64_t bs = attn_metadata.query_start_loc.size(0);
qwen3_next_gated_delta_net.cpp:418:32: error: ‘const struct xllm::layer::AttentionMetadata’ has no member named ‘query_start_loc’
418 | int64_t bs = attn_metadata.query_start_loc.size(0);
    torch::Tensor& weight,
    bool& weight_is_loaded);

void load_merged_weight_v2(const StateDict& state_dict,
#define DEFINE_MERGED_WEIGHT_V2(name) \
std::vector<torch::Tensor> valid_batches;
int64_t bs = attn_metadata.query_start_loc.size(0);
int64_t max_len = attn_metadata.max_query_len;
const auto& ori_seq_lens = attn_metadata.query_start_loc;
qwen3_next_gated_delta_net.cpp:420:46: error: ‘const struct xllm::layer::AttentionMetadata’ has no member named ‘query_start_loc’
420 | const auto& ori_seq_lens = attn_metadata.query_start_loc;
}

 private:
  layer::Qwen3NextDecoderLayer decoder_layer_{nullptr};
error: ‘Qwen3NextDecoderLayer’ in namespace ‘xllm::layer’ does not name a type
From the previous MoE PR,
  }
#endif
- kv_caches_.emplace_back(key_cache, value_cache, index_cache);
+ kv_caches_.emplace_back(key_cache, value_cache, index_cache, conv_cache, ssm_cache);
Why are there five arguments here?