Skip to content

Commit 68e1b71

Browse files
author
Wang, Yiting
authored
[XPU] Add deepseek_scaling_rope fused kernel (vllm-project#36612)
Signed-off-by: yitingw1 <yiting.wang@intel.com>
1 parent 0024f39 commit 68e1b71

2 files changed

Lines changed: 67 additions & 0 deletions

File tree

vllm/_xpu_ops.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from vllm.logger import init_logger
1010
from vllm.platforms import current_platform
11+
from vllm.utils.torch_utils import direct_register_custom_op
1112

1213
logger = init_logger(__name__)
1314

@@ -54,6 +55,37 @@ def _int4_gemm_w4a16_fake(
5455
return torch.empty((M, N), dtype=input.dtype, device=input.device)
5556

5657

58+
def _xpu_ops_deepseek_scaling_rope_impl(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor | None,
    offsets: torch.Tensor | None,
    cos_sin_cache: torch.Tensor | None,
    rotary_dim: int,
    is_neox_style: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Real implementation: forward to the fused ``_xpu_C`` DeepSeek RoPE kernel.

    ``key`` and the other Optional parameters are typed to match the registered
    op schema, but the fused kernel requires a key tensor — hence the assert.
    Returns the (query, key) pair produced by the kernel.
    """
    assert key is not None
    return torch.ops._xpu_C.deepseek_scaling_rope(
        positions, query, key, offsets, cos_sin_cache, rotary_dim, is_neox_style
    )
71+
72+
73+
def _xpu_ops_deepseek_scaling_rope_fake(
74+
positions: torch.Tensor,
75+
query: torch.Tensor,
76+
key: torch.Tensor | None,
77+
offsets: torch.Tensor | None,
78+
cos_sin_cache: torch.Tensor | None,
79+
rotary_dim: int,
80+
is_neox_style: bool,
81+
) -> tuple[torch.Tensor, torch.Tensor]:
82+
return query, key
83+
84+
85+
# Global flag to ensure ops are registered only once
# (checked and set by xpu_ops.register_ops_once() to guard against
# duplicate custom-op registration in the same process).
_OPS_REGISTERED = False
87+
88+
5789
class xpu_ops:
5890
@staticmethod
5991
def flash_attn_varlen_func(
@@ -402,3 +434,21 @@ def top_k_per_row_decode(
402434
raw_topk_indices[: topk_indices.shape[0], : topk_indices.shape[1]] = (
403435
topk_indices
404436
)
437+
438+
@staticmethod
439+
def register_ops_once() -> None:
440+
global _OPS_REGISTERED
441+
if not _OPS_REGISTERED:
442+
# register all the custom ops here
443+
direct_register_custom_op(
444+
op_name="xpu_ops_deepseek_scaling_rope",
445+
op_func=_xpu_ops_deepseek_scaling_rope_impl,
446+
mutates_args=[],
447+
fake_impl=_xpu_ops_deepseek_scaling_rope_fake,
448+
dispatch_key=current_platform.dispatch_key,
449+
)
450+
451+
_OPS_REGISTERED = True
452+
453+
454+
# Register the XPU custom ops at import time so they are available to callers
# (e.g. via torch.ops) before any model code runs.
xpu_ops.register_ops_once()

vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,23 @@ def forward_native(
152152
key = key_rot
153153
return query, key
154154

155+
def forward_xpu(
156+
self,
157+
positions: torch.Tensor,
158+
query: torch.Tensor,
159+
key: torch.Tensor | None = None,
160+
offsets: torch.Tensor | None = None,
161+
) -> tuple[torch.Tensor, torch.Tensor | None]:
162+
return torch.ops.vllm.xpu_ops_deepseek_scaling_rope(
163+
positions,
164+
query,
165+
key,
166+
offsets,
167+
self._match_cos_sin_cache_dtype(query),
168+
self.rotary_dim,
169+
self.is_neox_style,
170+
)
171+
155172
def forward_hip(
156173
self,
157174
positions: torch.Tensor,

0 commit comments

Comments
 (0)