/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "atb_common.h"

namespace atb {
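// Wraps an at::Tensor as an atb::Tensor that shares the same storage: the
// data pointer goes into hostData for CPU tensors and deviceData otherwise,
// and dtype, shape, and format are copied into the ATB tensor descriptor.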
atb::Tensor at_tensor_to_atb_tensor(const at::Tensor at_tensor) {
  static std::map<at::ScalarType, aclDataType> dtype_map = {
      {at::ScalarType::Bool, ACL_BOOL},
      {at::ScalarType::Byte, ACL_UINT8},
      {at::ScalarType::Char, ACL_INT8},
      {at::ScalarType::Half, ACL_FLOAT16},
      {at::ScalarType::Float, ACL_FLOAT},
      {at::ScalarType::Int, ACL_INT32},
      {at::ScalarType::Long, ACL_INT64},
      {at::ScalarType::BFloat16, ACL_BF16},
      {at::ScalarType::Double, ACL_DOUBLE},
      {at::ScalarType::Short, ACL_INT16},
      {at::ScalarType::ComplexHalf, ACL_COMPLEX32},
      {at::ScalarType::ComplexFloat, ACL_COMPLEX64},
      {at::ScalarType::ComplexDouble, ACL_COMPLEX128},
  };

  TORCH_CHECK(at_tensor.is_contiguous(), "at_tensor is not contiguous");
  atb::Tensor tensor;
  tensor.desc.format = atb::utils::get_format_for_atb(at_tensor);
  if (at_tensor.device().type() == at::kCPU) {
    tensor.hostData = at_tensor.data_ptr();
  } else {
    tensor.deviceData = at_tensor.data_ptr();
  }

  tensor.desc.shape.dimNum = at_tensor.sizes().size();
  for (uint64_t i = 0; i < at_tensor.sizes().size(); i++) {
    tensor.desc.shape.dims[i] = at_tensor.sizes()[i];
  }

  auto dtype_iterator = dtype_map.find(at_tensor.scalar_type());
  TORCH_CHECK(dtype_iterator != dtype_map.end(),
              "unsupported dtype: ",
              at_tensor.scalar_type());
  tensor.desc.dtype = dtype_iterator->second;

  tensor.dataSize = atb::Utils::GetTensorSize(tensor);

  return tensor;
}

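// Variant used while the current NPU stream is capturing (see run_atb_cmd):
// Setup and workspace allocation happen eagerly on the calling thread and
// only Execute is deferred into the queued lambda. cpu_tensors is captured
// by value so host-side inputs stay alive until the lambda runs, and the
// operation is destroyed right after execution.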
void run_atb_cmd_v1(atb::Operation* op,
                    const ParamSetter& paramsetter,
                    const std::string& name) {
  aclrtStream stream = c10_npu::getCurrentNPUStream().stream(false);
  auto context_ptr = atb::utils::get_context(stream);
  atb::VariantPack variant_pack = paramsetter.variant_pack_;
  uint64_t workspace_size = operation_setup(variant_pack, op, context_ptr);
  at::Tensor workspace_tensor;
  void* workspace_ptr = nullptr;
  if (workspace_size != 0) {
    at::TensorOptions options = at::TensorOptions(c10::DeviceType::PrivateUse1);
    workspace_tensor = at::empty({workspace_size}, options.dtype(at::kByte));
    workspace_ptr = const_cast<void*>(workspace_tensor.storage().data());
  }
  const c10::SmallVector<at::Tensor, N>& cpu_tensors =
      paramsetter.tensor_maintainer_.cpu_tensors;
  auto acl_call = [variant_pack,
                   workspace_ptr,
                   workspace_size,
                   context_ptr,
                   op,
                   cpu_tensors]() -> int {
    auto st = op->Execute(
        variant_pack, (uint8_t*)workspace_ptr, workspace_size, context_ptr);
    DestroyOperation(op);
    return st;
  };
  at_npu::native::OpCommand::RunOpApiV2(name, acl_call);
}

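// Default (non-capturing) variant: Setup and workspace allocation are
// deferred into the queued lambda so they run right before Execute, with the
// workspace obtained from allocate_workspace on the current stream.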
void run_atb_cmd_v2(atb::Operation* op,
                    const ParamSetter& paramsetter,
                    const std::string& name) {
  aclrtStream stream = c10_npu::getCurrentNPUStream().stream(false);
  atb::VariantPack variant_pack = paramsetter.variant_pack_;
  const c10::SmallVector<at::Tensor, N>& cpu_tensors =
      paramsetter.tensor_maintainer_.cpu_tensors;
  auto acl_call = [op, variant_pack, stream, cpu_tensors]() -> int {
    auto context_ptr = atb::utils::get_context(stream);
    uint64_t workspace_size = operation_setup(variant_pack, op, context_ptr);
    at::Tensor workspace_tensor;
    void* workspace_ptr = nullptr;
    if (workspace_size != 0) {
      workspace_tensor =
          at_npu::native::allocate_workspace(workspace_size, stream);
      workspace_ptr = const_cast<void*>(workspace_tensor.storage().data());
    }
    auto st = op->Execute(
        variant_pack, (uint8_t*)workspace_ptr, workspace_size, context_ptr);
    return st;
  };
  at_npu::native::OpCommand::RunOpApiV2(name, acl_call);
}

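// Dispatches between the two variants: while the current NPU stream is being
// captured, the eager-setup path run_atb_cmd_v1 is used; otherwise the
// deferred-setup path run_atb_cmd_v2.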
void run_atb_cmd(atb::Operation* op,
                 const ParamSetter& paramsetter,
                 const std::string& name) {
  const auto is_capturing =
      static_cast<int>(c10_npu::currentStreamCaptureStatusMayInitCtx());
  if (is_capturing) {
    run_atb_cmd_v1(op, paramsetter, name);
  } else {
    run_atb_cmd_v2(op, paramsetter, name);
  }
}

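// Minimal usage sketch. The op type, params, and tensor names below are
// illustrative assumptions, not part of this file: create an ATB operation,
// register its inputs/outputs through ParamSetter, then hand it to
// run_atb_cmd together with a name used for profiling/dispatch.
//
//   atb::infer::ElewiseParam param;
//   param.elewiseType = atb::infer::ElewiseParam::ELEWISE_ADD;
//   atb::Operation* op = nullptr;
//   atb::CreateOperation(param, &op);
//   atb::ParamSetter setter;
//   setter.Input(a, /*format_trans=*/true)
//       .Input(b, /*format_trans=*/true)
//       .Output(out);
//   atb::run_atb_cmd(op, setter, "ElewiseAdd");

// Appends an input tensor to the variant pack. Undefined tensors become
// empty placeholders; inputs are made contiguous (and optionally
// format-converted), CPU inputs are cloned, and the resulting tensors are
// stored in tensor_maintainer_ so their storage outlives the queued call.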
ParamSetter& ParamSetter::Input(const at::Tensor& tensor,
                                const bool& format_trans) {
  if (!tensor.defined()) {
    variant_pack_.inTensors.push_back(atb::Tensor());
    return *this;
  }
  at::Tensor new_tensor = tensor.contiguous();
  if (format_trans) {
    new_tensor = atb::utils::format_trans(new_tensor);
  }
  atb::Tensor atb_tensor;
  if (new_tensor.device().type() == at::kCPU) {
    auto tensor_clone = new_tensor.clone();
    atb_tensor = at_tensor_to_atb_tensor(tensor_clone);
    tensor_maintainer_.cpu_tensors.emplace_back(std::move(tensor_clone));
  } else {
    atb_tensor = at_tensor_to_atb_tensor(new_tensor);
    tensor_maintainer_.contiguous_tensors.emplace_back(std::move(new_tensor));
  }
  variant_pack_.inTensors.push_back(atb_tensor);
  return *this;
}

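// Optional-tensor overload: a missing value is recorded as an empty
// placeholder tensor, otherwise it forwards to the overload above.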
ParamSetter& ParamSetter::Input(const c10::optional<at::Tensor>& tensor,
                                const bool& format_trans) {
  if (!tensor.has_value()) {
    variant_pack_.inTensors.push_back(atb::Tensor());
    return *this;
  }
  return Input(tensor.value(), format_trans);
}

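// Appends an output tensor to the variant pack; the caller is expected to
// pass an already allocated, contiguous tensor.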
ParamSetter& ParamSetter::Output(at::Tensor& output) {
  auto atb_tensor = at_tensor_to_atb_tensor(output);
  variant_pack_.outTensors.push_back(atb_tensor);
  return *this;
}

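// Runs Operation::Setup on the variant pack and returns the workspace size
// the operation reports; a non-zero status aborts via TORCH_CHECK.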
uint64_t operation_setup(atb::VariantPack variant_pack,
                         atb::Operation* operation,
                         atb::Context* context_ptr) {
  uint64_t workspace_size = 0;
  atb::Status status =
      operation->Setup(variant_pack, workspace_size, context_ptr);
  TORCH_CHECK(status == 0, operation->GetName(), " setup failed!");
  return workspace_size;
}

}  // namespace atb