
Commit a0382bb

refactor: redesign wrapper for NPU fused_layernorm operator.
1 parent 51ab721 commit a0382bb

File tree

xllm/core/kernels/ops_api.cpp
xllm/core/kernels/ops_api.h
xllm/core/layers/common/fuse_norm.cpp

3 files changed: +14 -10 lines changed


xllm/core/kernels/ops_api.cpp

Lines changed: 10 additions & 8 deletions
@@ -216,6 +216,10 @@ void batch_decode(AttentionParams& params) {
 
 void fused_layernorm(FusedLayerNormParams& params) {
 #if defined(USE_MLU)
+  params.output = torch::empty(
+      {params.input.sizes()[0], params.intermediate_size / params.world_size},
+      params.input.options());
+
   mlu::fused_layernorm(params.input,
                        params.output,
                        params.residual,
@@ -238,16 +242,14 @@ void fused_layernorm(FusedLayerNormParams& params) {
     params.output = params.input;
     params.residual_out = params.residual;
   } else {
+    params.output = torch::empty(
+        {params.input.sizes()[0], params.intermediate_size / params.world_size},
+        params.input.options());
+
     cuda::rms_norm(params.output, params.input, params.weight, params.eps);
   }
-#else
-  LOG(FATAL) << "fused_layernorm not implemented";
-#endif
-}
-
-torch::Tensor fused_layernorm_tensor(FusedLayerNormParams& params) {
-#if defined(USE_NPU)
-  return npu::fused_layernorm(
+#elif defined(USE_NPU)
+  params.output = npu::fused_layernorm(
       params.input, params.weight, params.eps, params.mode);
 #else
   LOG(FATAL) << "fused_layernorm not implemented";
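
The allocation added on both the MLU and CUDA paths sizes the output per tensor-parallel rank: one row per input row, intermediate_size / world_size columns, with dtype and device copied from the input. A minimal sketch of that expression, assuming a hypothetical helper name and example sizes (nothing below beyond the torch::empty call itself comes from the repository):

#include <torch/torch.h>

// Hypothetical helper mirroring the allocation introduced in this commit.
torch::Tensor allocate_sharded_output(const torch::Tensor& input,
                                      int64_t intermediate_size,
                                      int64_t world_size) {
  // One row per input row, one column per element of this rank's shard of the
  // intermediate dimension; dtype and device are taken from the input tensor.
  return torch::empty({input.sizes()[0], intermediate_size / world_size},
                      input.options());
}

// Example: input of shape [8, 4096], intermediate_size = 11008, world_size = 4
// gives an output of shape [8, 2752] on the input's device.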

xllm/core/kernels/ops_api.h

Lines changed: 0 additions & 2 deletions
@@ -36,8 +36,6 @@ void batch_decode(AttentionParams& params);
 
 void fused_layernorm(FusedLayerNormParams& params);
 
-torch::Tensor fused_layernorm_tensor(FusedLayerNormParams& params);
-
 torch::Tensor matmul(MatmulParams& params);
 
 torch::Tensor group_gemm(GroupGemmParams& params);
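
With the declaration removed, any remaining caller of fused_layernorm_tensor has to migrate to the single void entry point and read the result back from the params struct. A hedged sketch of that migration; the wrapper name, the eps type, and the default-constructibility of FusedLayerNormParams are assumptions, while the field names (input, weight, eps, output) come from the diff:

// Hypothetical caller-side wrapper; on the NPU build this previously would
// have been `return fused_layernorm_tensor(params);`.
torch::Tensor rms_norm_via_ops_api(const torch::Tensor& hidden_states,
                                   const torch::Tensor& norm_weight,
                                   double eps) {
  FusedLayerNormParams params;
  params.input = hidden_states;
  params.weight = norm_weight;
  params.eps = eps;
  fused_layernorm(params);  // every backend now writes its result into params.output
  return params.output;
}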

xllm/core/layers/common/fuse_norm.cpp

Lines changed: 4 additions & 0 deletions
@@ -35,7 +35,11 @@ FusedRMSNormImpl::FusedRMSNormImpl(int64_t dim,
 }
 
 torch::Tensor FusedRMSNormImpl::forward(torch::Tensor& input) {
+#if defined(USE_NPU)
+  torch::Tensor output;
+#else
   auto output = torch::empty_like(input);
+#endif
   return forward_output(input, output);
 }
 
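
The NPU branch can start from a default-constructed tensor because, per the ops_api.cpp change above, npu::fused_layernorm returns a freshly allocated tensor that is assigned into params.output; pre-allocating with torch::empty_like would be wasted work there (assuming forward_output ultimately dispatches to that wrapper), whereas the MLU and CUDA kernels write into a caller-provided buffer. A small self-contained sketch of the two allocation styles; the kernel stand-ins below are illustrative, not the repository's:

#include <torch/torch.h>

// NPU-style kernel: allocates and returns its own result tensor.
torch::Tensor returning_kernel(const torch::Tensor& x) { return x.clone(); }

// MLU/CUDA-style kernel: writes into a buffer supplied by the caller.
void in_place_kernel(torch::Tensor& out, const torch::Tensor& x) { out.copy_(x); }

torch::Tensor forward_sketch(const torch::Tensor& input) {
#if defined(USE_NPU)
  torch::Tensor output;                    // placeholder only; the kernel allocates
  output = returning_kernel(input);
#else
  auto output = torch::empty_like(input);  // buffer the kernel fills in place
  in_place_kernel(output, input);
#endif
  return output;
}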