Commit 28e6e79 (1 parent: baf3c90)

refactor: redesign wrapper for NPU fused_layernorm operator.

3 files changed: 6 additions, 10 deletions

3 files changed

+6
-10
lines changed

xllm/core/kernels/ops_api.cpp

Lines changed: 2 additions & 8 deletions
@@ -229,14 +229,8 @@ void fused_layernorm(FusedLayerNormParams& params) {
       params.dynamic_quant);
 #elif defined(USE_CUDA)
   cuda::rmsnorm(params.output, params.input, params.weight, params.eps);
-#else
-  LOG(FATAL) << "fused_layernorm not implemented";
-#endif
-}
-
-torch::Tensor fused_layernorm_tensor(FusedLayerNormParams& params) {
-#if defined(USE_NPU)
-  return npu::fused_layernorm(
+#elif defined(USE_NPU)
+  params.output = npu::fused_layernorm(
       params.input, params.weight, params.eps, params.mode);
 #else
   LOG(FATAL) << "fused_layernorm not implemented";
xllm/core/kernels/ops_api.h

Lines changed: 0 additions & 2 deletions
@@ -36,8 +36,6 @@ void batch_decode(AttentionParams& params);
 
 void fused_layernorm(FusedLayerNormParams& params);
 
-torch::Tensor fused_layernorm_tensor(FusedLayerNormParams& params);
-
 torch::Tensor matmul(MatmulParams& params);
 
 torch::Tensor fused_moe(FusedMoEParams& params);
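With fused_layernorm_tensor removed from the public header, former callers switch to the void entry point and read the result from the params struct. A hypothetical migration sketch; variable names are illustrative, while the field names match those visible in the diff:

    FusedLayerNormParams params;
    params.input = hidden_states;   // illustrative input tensor
    params.weight = norm_weight;    // illustrative weight tensor
    params.eps = 1e-6f;
    params.mode = mode;             // NPU-specific mode flag from the diff
    // Before: torch::Tensor out = fused_layernorm_tensor(params);
    fused_layernorm(params);
    torch::Tensor out = params.output;  // filled by the NPU branch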

xllm/core/layers/common/fuse_norm.cpp

Lines changed: 4 additions & 0 deletions
@@ -35,7 +35,11 @@ FusedRMSNormImpl::FusedRMSNormImpl(int64_t dim,
 }
 
 torch::Tensor FusedRMSNormImpl::forward(torch::Tensor& input) {
+#if defined(USE_NPU)
+  torch::Tensor output;
+#else
   auto output = torch::empty_like(input);
+#endif
   return forward_output(input, output);
 }
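The NPU branch leaves output default-constructed because npu::fused_layernorm returns a freshly allocated tensor that the dispatcher assigns into params.output, so pre-allocating with torch::empty_like would be wasted work. A hypothetical sketch of how forward_output (not shown in this diff) could thread that tensor through the dispatcher; the weight_ and eps_ members are assumptions:

    torch::Tensor FusedRMSNormImpl::forward_output(torch::Tensor& input,
                                                   torch::Tensor& output) {
      FusedLayerNormParams params;
      params.input = input;
      params.weight = weight_;   // assumed member holding the norm weight
      params.eps = eps_;         // assumed member holding epsilon
      params.output = output;    // default-constructed (empty) on NPU
      fused_layernorm(params);   // NPU branch overwrites params.output
      return params.output;
    }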
