Commit 1da759f

feat: add wrappers for ATB and ACLNN fused operators.
1 parent 7ec13b0

38 files changed: +1874 −3107 lines
Lines changed: 20 additions & 8 deletions
@@ -1,17 +1,29 @@
 include(cc_library)
 
-add_subdirectory(impl)
 add_subdirectory(xllm_ops)
 
+file(GLOB_RECURSE XLLM_CORE_KERNELS_NPU_HEADER
+  "${CMAKE_CURRENT_LIST_DIR}/custom_functions_npu/*.h"
+  "${CMAKE_CURRENT_LIST_DIR}/ops_npu/*.h"
+  "${CMAKE_CURRENT_LIST_DIR}/*.h"
+)
+
+file(GLOB_RECURSE XLLM_CORE_KERNELS_NPU_SRCS
+  "${CMAKE_CURRENT_LIST_DIR}/custom_functions_npu/*.cpp"
+  "${CMAKE_CURRENT_LIST_DIR}/ops_npu/*.cpp"
+  "${CMAKE_CURRENT_LIST_DIR}/*.cpp"
+)
+
 cc_library(
   NAME
     npu_kernels
   HDRS
-    linear.h
-    split.h
-    rms_norm.h
-    rope.h
+    ${XLLM_CORE_KERNELS_NPU_HEADER}
+  SRCS
+    ${XLLM_CORE_KERNELS_NPU_SRCS}
   DEPS
-    :model_context
-    # spdlog::spdlog
-)
+    :model_context
+    glog::glog
+    torch
+    torch_npu
+)
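
With the glob rules in place, the npu_kernels target now picks up every header and source under custom_functions_npu/, ops_npu/, and the directory root automatically, so the operator wrappers added in this commit build without further CMakeLists edits; the hand-maintained HDRS list and the :npu_kernels_impl dependency give way to direct :model_context, glog, torch, and torch_npu dependencies.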

xllm/core/kernels/npu/rms_norm.h renamed to xllm/core/kernels/npu/active.cpp

Lines changed: 12 additions & 14 deletions
@@ -13,20 +13,18 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#pragma once
-#include "impl/npu_rms_norm_impl.h"
+#include <torch_npu/csrc/aten/CustomFunctions.h>
 
-namespace xllm {
-namespace kernel {
+#include "npu_ops_api.h"
+#include "ops_npu/npu_ops.h"
 
-class RmsNorm : public torch::nn::ModuleHolder<NpuRmsNormImpl> {
- public:
-  using torch::nn::ModuleHolder<NpuRmsNormImpl>::ModuleHolder;
-  using Impl __attribute__((__unused__)) = NpuRmsNormImpl;
+namespace xllm::kernel::npu {
 
-  RmsNorm(const ModelContext& context)
-      : ModuleHolder(std::make_shared<NpuRmsNormImpl>(context)) {}
-};
-
-} // namespace kernel
-} // namespace xllm
+torch::Tensor active(const torch::Tensor& input, const std::string& act_mode) {
+  if (act_mode != "silu" && act_mode != "swiglu") {
+    throw std::runtime_error(
+        "Only silu and swiglu activations are supported in NPU active");
+  }
+  return at_npu::native::custom_ops::npu_swiglu(input);
+}
+} // namespace xllm::kernel::npu
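
For orientation, a minimal call-site sketch for the new wrapper (hypothetical helper; assumes a torch_npu build and an input whose gate and up projections are packed along the last dimension, the layout npu_swiglu consumes):

#include <torch/torch.h>

#include "npu_ops_api.h"  // declares xllm::kernel::npu::active

// Hypothetical MLP step: gate_up is [num_tokens, 2 * intermediate_size];
// the fused swiglu halves the last dimension to [num_tokens, intermediate_size].
torch::Tensor mlp_activate(const torch::Tensor& gate_up) {
  return xllm::kernel::npu::active(gate_up, "swiglu");
}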
Lines changed: 65 additions & 0 deletions (new file)

/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "npu_ops_api.h"
#include "ops_npu/npu_ops.h"

namespace xllm::kernel::npu {

void reshape_paged_cache(torch::Tensor& key,
                         std::optional<torch::Tensor>& value,
                         torch::Tensor& k_cache,
                         std::optional<torch::Tensor>& v_cache,
                         const torch::Tensor& slot_mapping) {
  atb::_npu_reshape_and_cache(
      key, value.value(), k_cache, v_cache.value(), slot_mapping);
}

void batch_prefill(const torch::Tensor& query,
                   const torch::Tensor& key,
                   const torch::Tensor& value,
                   const torch::Tensor& mask,
                   const torch::Tensor& seq_len,
                   float scale,
                   torch::Tensor& output) {
  auto num_heads = query.size(-2);
  auto num_kv_heads = key.size(-2);
  atb::_npu_flash_attention(
      query, key, value, mask, seq_len, scale, num_heads, num_kv_heads, output);
}

void batch_decode(const torch::Tensor& query,
                  const torch::Tensor& k_cache,
                  const torch::Tensor& v_cache,
                  float scale,
                  const torch::Tensor& block_table,
                  const torch::Tensor& seq_lens,
                  torch::Tensor& output) {
  auto head_size = query.size(-1);
  auto num_heads = query.size(-2);
  auto num_kv_heads = k_cache.size(-2);
  auto q = query.view({-1, num_heads, head_size});
  auto o = output.view({-1, num_heads, head_size});
  atb::_npu_paged_attention(q,
                            k_cache,
                            v_cache,
                            num_kv_heads,
                            num_heads,
                            scale,
                            block_table,
                            seq_lens,
                            o);
}

} // namespace xllm::kernel::npu
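
Two details in this wrapper layer are worth noting: reshape_paged_cache dereferences value and v_cache unconditionally, so callers must supply both optionals; and batch_decode collapses any leading batch layout into a flat token dimension before calling the ATB paged-attention kernel. A standalone shape sketch of that collapse (hypothetical sizes, plain libtorch, runs on CPU):

#include <iostream>

#include <torch/torch.h>

int main() {
  // Hypothetical decode batch: 4 sequences of 1 token each,
  // 8 query heads, head_size 128.
  auto query = torch::randn({4, 1, 8, 128});
  auto head_size = query.size(-1);
  auto num_heads = query.size(-2);

  // The same view batch_decode applies to query and output:
  // [4, 1, 8, 128] -> [4, 8, 128], one row per token.
  auto q = query.view({-1, num_heads, head_size});
  std::cout << q.sizes() << std::endl;  // prints [4, 8, 128]
  return 0;
}
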
Lines changed: 175 additions & 0 deletions (new file)

/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "atb_common.h"

namespace atb {

atb::Tensor at_tensor_to_atb_tensor(const at::Tensor at_tensor) {
  static std::map<at::ScalarType, aclDataType> dtype_map = {
      {at::ScalarType::Bool, ACL_BOOL},
      {at::ScalarType::Byte, ACL_UINT8},
      {at::ScalarType::Char, ACL_INT8},
      {at::ScalarType::Half, ACL_FLOAT16},
      {at::ScalarType::Float, ACL_FLOAT},
      {at::ScalarType::Int, ACL_INT32},
      {at::ScalarType::Long, ACL_INT64},
      {at::ScalarType::BFloat16, ACL_BF16},
      {at::ScalarType::Double, ACL_DOUBLE},
      {at::ScalarType::Short, ACL_INT16},
      {at::ScalarType::ComplexHalf, ACL_COMPLEX32},
      {at::ScalarType::ComplexFloat, ACL_COMPLEX64},
      {at::ScalarType::ComplexDouble, ACL_COMPLEX128},
  };

  TORCH_CHECK(at_tensor.is_contiguous(), "at_tensor is not contiguous");
  atb::Tensor tensor;
  tensor.desc.format = atb::utils::get_format_for_atb(at_tensor);
  if (at_tensor.device().type() == at::kCPU) {
    tensor.hostData = at_tensor.data_ptr();
  } else {
    tensor.deviceData = at_tensor.data_ptr();
  }

  tensor.desc.shape.dimNum = at_tensor.sizes().size();
  for (uint64_t i = 0; i < at_tensor.sizes().size(); i++) {
    tensor.desc.shape.dims[i] = at_tensor.sizes()[i];
  }

  auto dtype_iterator = dtype_map.find(at_tensor.scalar_type());
  TORCH_CHECK(dtype_iterator != dtype_map.end(),
              "unsupported dtype: ",
              at_tensor.scalar_type());
  tensor.desc.dtype = dtype_iterator->second;

  tensor.dataSize = atb::Utils::GetTensorSize(tensor);

  return tensor;
}

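// Two execution paths follow. run_atb_cmd_v1 runs Setup and allocates the
// workspace eagerly on the calling thread, then destroys the op after
// Execute; run_atb_cmd_v2 defers that work into the dispatched task.
// run_atb_cmd below chooses between them based on stream-capture status.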
void run_atb_cmd_v1(atb::Operation* op,
                    const ParamSetter& paramsetter,
                    const std::string& name) {
  aclrtStream stream = c10_npu::getCurrentNPUStream().stream(false);
  auto context_ptr = atb::utils::get_context(stream);
  atb::VariantPack variant_pack = paramsetter.variant_pack_;
  uint64_t workspace_size = operation_setup(variant_pack, op, context_ptr);
  at::Tensor workspace_tensor;
  void* workspace_ptr = nullptr;
  if (workspace_size != 0) {
    at::TensorOptions options = at::TensorOptions(c10::DeviceType::PrivateUse1);
    workspace_tensor = at::empty({workspace_size}, options.dtype(at::kByte));
    workspace_ptr = const_cast<void*>(workspace_tensor.storage().data());
  }
  const c10::SmallVector<at::Tensor, N>& cpu_tensors =
      paramsetter.tensor_maintainer_.cpu_tensors;
  auto acl_call = [variant_pack,
                   workspace_ptr,
                   workspace_size,
                   context_ptr,
                   op,
                   cpu_tensors]() -> int {
    auto st = op->Execute(
        variant_pack, (uint8_t*)workspace_ptr, workspace_size, context_ptr);
    DestroyOperation(op);
    return st;
  };
  at_npu::native::OpCommand::RunOpApiV2(name, acl_call);
}

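// Deferred variant: the VariantPack and stream are captured by value, and
// Setup plus workspace allocation happen inside the dispatched task, using
// the torch_npu stream-tied workspace allocator; the op is not destroyed here.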
void run_atb_cmd_v2(atb::Operation* op,
                    const ParamSetter& paramsetter,
                    const std::string& name) {
  aclrtStream stream = c10_npu::getCurrentNPUStream().stream(false);
  atb::VariantPack variant_pack = paramsetter.variant_pack_;
  const c10::SmallVector<at::Tensor, N>& cpu_tensors =
      paramsetter.tensor_maintainer_.cpu_tensors;
  auto acl_call = [op, variant_pack, stream, cpu_tensors]() -> int {
    auto context_ptr = atb::utils::get_context(stream);
    uint64_t workspace_size = operation_setup(variant_pack, op, context_ptr);
    at::Tensor workspace_tensor;
    void* workspace_ptr = nullptr;
    if (workspace_size != 0) {
      workspace_tensor =
          at_npu::native::allocate_workspace(workspace_size, stream);
      workspace_ptr = const_cast<void*>(workspace_tensor.storage().data());
    }
    auto st = op->Execute(
        variant_pack, (uint8_t*)workspace_ptr, workspace_size, context_ptr);
    return st;
  };
  at_npu::native::OpCommand::RunOpApiV2(name, acl_call);
}

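// Dispatch on capture status: during NPU graph capture the eager path (v1)
// is taken, presumably so the workspace already exists when the captured
// task replays; otherwise the deferred path (v2) is used. (Rationale
// inferred from the capture check, not stated in the commit.)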
void run_atb_cmd(atb::Operation* op,
                 const ParamSetter& paramsetter,
                 const std::string& name) {
  const auto is_capturing =
      static_cast<int>(c10_npu::currentStreamCaptureStatusMayInitCtx());
  if (is_capturing) {
    run_atb_cmd_v1(op, paramsetter, name);
  } else {
    run_atb_cmd_v2(op, paramsetter, name);
  }
}

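// ParamSetter packs at::Tensors into the ATB VariantPack. CPU inputs are
// cloned and the clones kept in tensor_maintainer_.cpu_tensors, which the
// dispatch lambdas above capture so host buffers outlive the async call;
// device inputs are made contiguous (and format-converted on request) and
// held in tensor_maintainer_.contiguous_tensors for the setter's lifetime.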
ParamSetter& ParamSetter::Input(const at::Tensor& tensor,
                                const bool& format_trans) {
  if (!tensor.defined()) {
    variant_pack_.inTensors.push_back(atb::Tensor());
    return *this;
  }
  at::Tensor new_tensor = tensor.contiguous();
  if (format_trans) {
    new_tensor = atb::utils::format_trans(new_tensor);
  }
  atb::Tensor atb_tensor;
  if (new_tensor.device().type() == at::kCPU) {
    auto tensor_clone = new_tensor.clone();
    atb_tensor = at_tensor_to_atb_tensor(tensor_clone);
    tensor_maintainer_.cpu_tensors.emplace_back(std::move(tensor_clone));
  } else {
    atb_tensor = at_tensor_to_atb_tensor(new_tensor);
    tensor_maintainer_.contiguous_tensors.emplace_back(std::move(new_tensor));
  }
  variant_pack_.inTensors.push_back(atb_tensor);
  return *this;
}

ParamSetter& ParamSetter::Input(const c10::optional<at::Tensor>& tensor,
                                const bool& format_trans) {
  if (!tensor.has_value()) {
    variant_pack_.inTensors.push_back(atb::Tensor());
    return *this;
  }
  return Input(tensor.value(), format_trans);
}

ParamSetter& ParamSetter::Output(at::Tensor& output) {
  auto atb_tensor = at_tensor_to_atb_tensor(output);
  variant_pack_.outTensors.push_back(atb_tensor);
  return *this;
}

uint64_t operation_setup(atb::VariantPack variant_pack,
                         atb::Operation* operation,
                         atb::Context* context_ptr) {
  uint64_t workspace_size = 0;
  atb::Status status =
      operation->Setup(variant_pack, workspace_size, context_ptr);
  TORCH_CHECK(status == 0, operation->GetName(), " setup failed!");
  return workspace_size;
}

} // namespace atb
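
To show how these pieces compose, here is a minimal call-site sketch. It is illustrative only: it assumes ATB's ElewiseParam add operator from atb/infer_op_params.h, and that ParamSetter is default-constructible as the definitions above suggest.

#include <atb/infer_op_params.h>

#include "atb_common.h"

// Sketch: run a fused elementwise add through the helpers above.
void atb_add(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
  atb::infer::ElewiseParam param;
  param.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD;

  atb::Operation* op = nullptr;
  TORCH_CHECK(atb::CreateOperation(param, &op) == 0, "CreateOperation failed");

  // Chain inputs and outputs into the VariantPack, then dispatch; the
  // capture-aware run_atb_cmd picks the eager or deferred path.
  atb::ParamSetter setter;
  setter.Input(a, false).Input(b, false).Output(out);
  atb::run_atb_cmd(op, setter, "ElewiseAdd");
}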
