From 145a538626cfead776b3dbe41b071474db9d193f Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Wed, 10 Jun 2026 07:55:27 -0400 Subject: [PATCH] Fix diff_attn_swa plugin int attribute under TRT 10.16 / NumPy 2.x TensorRT serializes plugin attributes as 1-element numpy arrays and calls the annotated type on them when instantiating the plugin. With the `num_heads: int` annotation this hits `int(np.array([H]))`, which NumPy (>=1.25 deprecation, error in 2.x) rejects: tensorrt_bindings/plugin/_lib.py:279, in create_plugin attrs[f.name] = attr_type_annot(f.data) TypeError: only 0-dimensional arrays can be converted to Python scalars so any build or load of an engine using samel::diff_attn_swa (the SAME-L decoder) crashes. Annotate the attribute as npt.NDArray[np.int64] and coerce to a Python int inside the impl. Verified end-to-end on TRT 10.16.1.11 + NumPy 2.4.4: the attr arrives as ndarray, the kernel output is unchanged. --- optimized/tensorRT/scripts/diff_attn_nocast_plugin.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/optimized/tensorRT/scripts/diff_attn_nocast_plugin.py b/optimized/tensorRT/scripts/diff_attn_nocast_plugin.py index 7124edb..40a6e1e 100644 --- a/optimized/tensorRT/scripts/diff_attn_nocast_plugin.py +++ b/optimized/tensorRT/scripts/diff_attn_nocast_plugin.py @@ -5,6 +5,8 @@ """ import torch import tensorrt.plugin as trtp +import numpy as np +import numpy.typing as npt from typing import Tuple _stream_cache = {} @@ -13,7 +15,8 @@ @trtp.register("samel::diff_attn_swa") def diff_attn_swa_desc(q_bat: trtp.TensorDesc, k_bat: trtp.TensorDesc, - v_bat: trtp.TensorDesc, num_heads: int) -> trtp.TensorDesc: + v_bat: trtp.TensorDesc, + num_heads: npt.NDArray[np.int64]) -> trtp.TensorDesc: out = q_bat.like() out.shape_expr[-2] = q_bat.shape_expr[-2] // 2 return out @@ -21,7 +24,8 @@ def diff_attn_swa_desc(q_bat: trtp.TensorDesc, k_bat: trtp.TensorDesc, @trtp.impl("samel::diff_attn_swa") def diff_attn_swa_impl(q_bat: trtp.Tensor, k_bat: trtp.Tensor, v_bat: trtp.Tensor, - num_heads: int, outputs: Tuple[trtp.Tensor], stream: int): + num_heads: npt.NDArray[np.int64], + outputs: Tuple[trtp.Tensor], stream: int): global _triton_fn if stream not in _stream_cache: _stream_cache[stream] = torch.cuda.ExternalStream(stream) @@ -37,5 +41,5 @@ def diff_attn_swa_impl(q_bat: trtp.Tensor, k_bat: trtp.Tensor, v_bat: trtp.Tenso # NO dtype cast — Triton auto-compiles for the input dtype o = _triton_fn(q, k, v, window=17) - H = num_heads + H = int(np.asarray(num_heads).reshape(-1)[0]) out_t.copy_(o[:, :, :H, :] - o[:, :, H:, :])