diff --git a/optimized/tensorRT/scripts/diff_attn_nocast_plugin.py b/optimized/tensorRT/scripts/diff_attn_nocast_plugin.py index 7124edb..40a6e1e 100644 --- a/optimized/tensorRT/scripts/diff_attn_nocast_plugin.py +++ b/optimized/tensorRT/scripts/diff_attn_nocast_plugin.py @@ -5,6 +5,8 @@ """ import torch import tensorrt.plugin as trtp +import numpy as np +import numpy.typing as npt from typing import Tuple _stream_cache = {} @@ -13,7 +15,8 @@ @trtp.register("samel::diff_attn_swa") def diff_attn_swa_desc(q_bat: trtp.TensorDesc, k_bat: trtp.TensorDesc, - v_bat: trtp.TensorDesc, num_heads: int) -> trtp.TensorDesc: + v_bat: trtp.TensorDesc, + num_heads: npt.NDArray[np.int64]) -> trtp.TensorDesc: out = q_bat.like() out.shape_expr[-2] = q_bat.shape_expr[-2] // 2 return out @@ -21,7 +24,8 @@ def diff_attn_swa_desc(q_bat: trtp.TensorDesc, k_bat: trtp.TensorDesc, @trtp.impl("samel::diff_attn_swa") def diff_attn_swa_impl(q_bat: trtp.Tensor, k_bat: trtp.Tensor, v_bat: trtp.Tensor, - num_heads: int, outputs: Tuple[trtp.Tensor], stream: int): + num_heads: npt.NDArray[np.int64], + outputs: Tuple[trtp.Tensor], stream: int): global _triton_fn if stream not in _stream_cache: _stream_cache[stream] = torch.cuda.ExternalStream(stream) @@ -37,5 +41,5 @@ def diff_attn_swa_impl(q_bat: trtp.Tensor, k_bat: trtp.Tensor, v_bat: trtp.Tenso # NO dtype cast — Triton auto-compiles for the input dtype o = _triton_fn(q, k, v, window=17) - H = num_heads + H = int(np.asarray(num_heads).reshape(-1)[0]) out_t.copy_(o[:, :, :H, :] - o[:, :, H:, :])