From b874acf604d068dd71d66424db738650eb6ca1ea Mon Sep 17 00:00:00 2001
From: WangQing <2917021186@qq.com>
Date: Mon, 20 Apr 2026 08:04:34 +0000
Subject: [PATCH 1/2] [ascend] optimize gated rmsnorm

---
Review note: this hunk was edited by hand during review. The fused NPU path is
now guarded and falls back to rmsnorm_fn, because forward() still accepts
z=None and the module still carries bias / group_size / norm_before_gate.
Regenerate the series with git format-patch to refresh the blob hashes on the
"index" lines.

 dlinfer/vendor/ascend/triton_ops/rms_norm_gated.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/dlinfer/vendor/ascend/triton_ops/rms_norm_gated.py b/dlinfer/vendor/ascend/triton_ops/rms_norm_gated.py
index 42301214..0a47acb4 100644
--- a/dlinfer/vendor/ascend/triton_ops/rms_norm_gated.py
+++ b/dlinfer/vendor/ascend/triton_ops/rms_norm_gated.py
@@ -164,12 +164,27 @@ def reset_parameters(self):
         torch.nn.init.ones_(self.weight)
 
     def forward(self, x, z=None):
+        # Fused NPU fast path. It only computes norm(x) * silu(z) as a single
+        # group and without bias; every other configuration (including
+        # z=None) keeps using the generic kernel below.
+        use_fast_path = (
+            z is not None
+            and self.bias is None
+            and self.norm_before_gate
+            and self.group_size in (None, x.shape[-1])
+        )
+        if use_fast_path:
+            input_dtype = x.dtype
+            x = torch.ops.npu.npu_rms_norm(x, self.weight, self.eps)[0]
+            # Apply the gate in float32, then cast back to the caller's dtype.
+            out = x * F.silu(z.to(torch.float32))
+            return out.to(input_dtype)
         return rmsnorm_fn(
             x,
             self.weight,
             self.bias,
             z=z,
             eps=self.eps,
             group_size=self.group_size,
             norm_before_gate=self.norm_before_gate,
         )

From f6ac78b38818d0f2fc16302b19ad3e3e571a27fe Mon Sep 17 00:00:00 2001
From: WangQing <2917021186@qq.com>
Date: Mon, 20 Apr 2026 08:08:50 +0000
Subject: [PATCH 2/2] [ascend] fix import error

---
 dlinfer/vendor/ascend/triton_ops/rms_norm_gated.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/dlinfer/vendor/ascend/triton_ops/rms_norm_gated.py b/dlinfer/vendor/ascend/triton_ops/rms_norm_gated.py
index 0a47acb4..6c18702c 100644
--- a/dlinfer/vendor/ascend/triton_ops/rms_norm_gated.py
+++ b/dlinfer/vendor/ascend/triton_ops/rms_norm_gated.py
@@ -12,6 +12,7 @@
 import torch.nn as nn
 import triton
 import triton.language as tl
+import torch.nn.functional as F
 
 MAX_CORES = 65535
 