|
8 | 8 |
|
9 | 9 | from vllm.logger import init_logger |
10 | 10 | from vllm.platforms import current_platform |
| 11 | +from vllm.utils.torch_utils import direct_register_custom_op |
11 | 12 |
|
12 | 13 | logger = init_logger(__name__) |
13 | 14 |
|
@@ -54,6 +55,37 @@ def _int4_gemm_w4a16_fake( |
54 | 55 | return torch.empty((M, N), dtype=input.dtype, device=input.device) |
55 | 56 |
|
56 | 57 |
|
def _xpu_ops_deepseek_scaling_rope_impl(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor | None,
    offsets: torch.Tensor | None,
    cos_sin_cache: torch.Tensor | None,
    rotary_dim: int,
    is_neox_style: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Forward DeepSeek scaling RoPE to the native ``_xpu_C`` kernel.

    Thin wrapper so the native kernel can be exposed as a registered
    custom op. ``key`` is required by the underlying kernel even though
    the registered signature allows ``None``.
    """
    assert key is not None
    rotated = torch.ops._xpu_C.deepseek_scaling_rope(
        positions,
        query,
        key,
        offsets,
        cos_sin_cache,
        rotary_dim,
        is_neox_style,
    )
    return rotated
| 72 | + |
| 73 | +def _xpu_ops_deepseek_scaling_rope_fake( |
| 74 | + positions: torch.Tensor, |
| 75 | + query: torch.Tensor, |
| 76 | + key: torch.Tensor | None, |
| 77 | + offsets: torch.Tensor | None, |
| 78 | + cos_sin_cache: torch.Tensor | None, |
| 79 | + rotary_dim: int, |
| 80 | + is_neox_style: bool, |
| 81 | +) -> tuple[torch.Tensor, torch.Tensor]: |
| 82 | + return query, key |
| 83 | + |
| 84 | + |
# Global flag to ensure ops are registered only once
# (flipped to True by xpu_ops.register_ops_once after a successful registration).
_OPS_REGISTERED: bool = False
| 88 | + |
57 | 89 | class xpu_ops: |
58 | 90 | @staticmethod |
59 | 91 | def flash_attn_varlen_func( |
@@ -402,3 +434,21 @@ def top_k_per_row_decode( |
402 | 434 | raw_topk_indices[: topk_indices.shape[0], : topk_indices.shape[1]] = ( |
403 | 435 | topk_indices |
404 | 436 | ) |
| 437 | + |
| 438 | + @staticmethod |
| 439 | + def register_ops_once() -> None: |
| 440 | + global _OPS_REGISTERED |
| 441 | + if not _OPS_REGISTERED: |
| 442 | + # register all the custom ops here |
| 443 | + direct_register_custom_op( |
| 444 | + op_name="xpu_ops_deepseek_scaling_rope", |
| 445 | + op_func=_xpu_ops_deepseek_scaling_rope_impl, |
| 446 | + mutates_args=[], |
| 447 | + fake_impl=_xpu_ops_deepseek_scaling_rope_fake, |
| 448 | + dispatch_key=current_platform.dispatch_key, |
| 449 | + ) |
| 450 | + |
| 451 | + _OPS_REGISTERED = True |
| 452 | + |
| 453 | + |
| 454 | +xpu_ops.register_ops_once() |
0 commit comments