Skip to content

Commit ac81c85

Browse files
[PyTorch] Python GroupedTensor (#2654)
* PyTorch-Python GroupedTensor Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Update transformer_engine/pytorch/tensor/storage/grouped_tensor.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Remove mxfp8 gq test Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Fix recipe tests and FP8 weights Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Fix device test Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Disable grouped weights for unsupported recipes Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> --------- Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
1 parent 8ebb47e commit ac81c85

14 files changed

Lines changed: 1714 additions & 35 deletions
Lines changed: 385 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,385 @@
1+
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# See LICENSE for license information.
4+
5+
"""Tests for GroupedTensor class"""
6+
7+
from typing import List, Tuple
8+
import pytest
9+
import torch
10+
import transformer_engine.pytorch as te
11+
from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor
12+
from transformer_engine.pytorch import (
13+
Quantizer,
14+
Float8Quantizer,
15+
Float8CurrentScalingQuantizer,
16+
Float8BlockQuantizer,
17+
MXFP8Quantizer,
18+
NVFP4Quantizer,
19+
)
20+
from transformer_engine.pytorch.constants import TE_DType_To_Torch
21+
import transformer_engine_torch as tex
22+
23+
# Check available recipes.
# Each probe returns an (available, reason) pair; the reason string is reused
# below as the skip message for the matching parametrized test cases.
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available(
    return_reason=True
)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True)
30+
31+
# Parametrization values for the quantized tests. Each entry is a scheme name
# accepted by make_quantizer(), wrapped in pytest.param so that the case is
# skipped (with the reported reason) when the recipe is not available.
_quantization_params = [
    pytest.param(
        "fp8_delayed_scaling",
        marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8),
    ),
    pytest.param(
        "fp8_current_scaling",
        marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8),
    ),
    pytest.param(
        "fp8_blockwise",
        marks=pytest.mark.skipif(
            not fp8_block_scaling_available, reason=reason_for_no_fp8_block_scaling
        ),
    ),
    pytest.param(
        "mxfp8",
        marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8),
    ),
    pytest.param(
        "nvfp4",
        marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4),
    ),
]
55+
56+
57+
def make_quantizer(quantization: str, num_tensors: int, shape: List[Tuple[int, int]]) -> Quantizer:
    """Create a quantizer for the given quantization scheme.

    Args:
        quantization: Scheme name. One of "fp8_delayed_scaling",
            "fp8_current_scaling", "fp8_blockwise", "mxfp8" or "nvfp4".
        num_tensors: Number of tensors in the group. Not read by this helper;
            accepted so call sites can pass the group description as-is.
        shape: Per-tensor 2D shapes. Not read by this helper.

    Returns:
        A single quantizer object with ``internal`` set to ``False``.

    Raises:
        ValueError: If ``quantization`` is not one of the names above.
    """

    if quantization == "fp8_delayed_scaling":
        quantizer = Float8Quantizer(
            scale=torch.ones(1, dtype=torch.float32, device="cuda"),
            amax=torch.zeros(1, dtype=torch.float32, device="cuda"),
            fp8_dtype=tex.DType.kFloat8E4M3,
        )
    elif quantization == "fp8_current_scaling":
        quantizer = Float8CurrentScalingQuantizer(
            fp8_dtype=tex.DType.kFloat8E4M3,
            device="cuda",
        )
        # Only this scheme sets usage after construction; the tests inspect
        # row-wise data only.
        quantizer.set_usage(rowwise=True, columnwise=False)
    elif quantization == "fp8_blockwise":
        quantizer = Float8BlockQuantizer(
            fp8_dtype=tex.DType.kFloat8E4M3,
            rowwise=True,
            columnwise=False,
            force_pow_2_scales=True,
            amax_epsilon=0.0,
            block_scaling_dim=1,
        )
    elif quantization == "mxfp8":
        quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
    elif quantization == "nvfp4":
        quantizer = NVFP4Quantizer(
            with_rht=False,
            with_post_rht_amax=False,
            with_2d_quantization=False,
            stochastic_rounding=False,
            with_random_sign_mask=False,
        )
    else:
        raise ValueError(f"Unknown quantization scheme: {quantization}")

    quantizer.internal = False

    return quantizer
97+
98+
99+
def _get_rowwise_data_tensor(qtensor, quantization: str) -> torch.Tensor:
100+
if quantization in ("fp8_delayed_scaling", "fp8_current_scaling"):
101+
return qtensor._data
102+
if quantization in ("fp8_blockwise", "mxfp8", "nvfp4"):
103+
return qtensor._rowwise_data
104+
raise ValueError(f"Unknown quantization scheme: {quantization}")
105+
106+
107+
def _rowwise_offset_bytes(numel: int, quantization: str) -> int:
108+
if quantization == "nvfp4":
109+
return numel // 2
110+
return numel
111+
112+
113+
class TestGroupedTensor:
    """Tests for GroupedTensor construction, splitting, quantization and clearing.

    The storage-sharing tests compare ``data_ptr()`` values: each member tensor
    is expected to start at the cumulative byte offset of the members before
    it, measured from the start of ``grouped_tensor.data``.
    """

    @classmethod
    def setup_class(cls) -> None:
        """Seed the CPU and CUDA random generators once for the whole class."""
        # Configure RNG
        seed = 1234
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    def test_basic_construction_all_same_shape(self) -> None:
        """Test GroupedTensor construction with all tensors having same shape"""
        num_tensors = 4
        shape = [(256, 512) for _ in range(num_tensors)]

        grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
            num_tensors=num_tensors,
            shape=shape,
            quantizer=None,
            device="cuda",
            dtype=torch.float32,
        )

        assert grouped_tensor.num_tensors == num_tensors
        assert grouped_tensor.all_same_shape()
        assert grouped_tensor.all_same_first_dim()
        assert grouped_tensor.all_same_last_dim()
        # Members are stacked along the first dimension.
        assert grouped_tensor.logical_shape == (num_tensors * 256, 512)
        assert grouped_tensor.get_common_first_dim() == 256
        assert grouped_tensor.get_common_last_dim() == 512
        assert grouped_tensor.has_data()

    def test_basic_construction_varying_first_dim(self) -> None:
        """Test GroupedTensor construction with varying first dimension"""
        num_tensors = 3
        shape = [(128, 512), (256, 512), (384, 512)]

        grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
            num_tensors=num_tensors,
            shape=shape,
            quantizer=None,
            device="cuda",
            dtype=torch.float32,
        )

        assert grouped_tensor.num_tensors == num_tensors
        assert not grouped_tensor.all_same_shape()
        assert not grouped_tensor.all_same_first_dim()
        assert grouped_tensor.all_same_last_dim()
        assert grouped_tensor.get_common_last_dim() == shape[0][1]
        assert grouped_tensor.logical_shape == (
            sum(v for v, _ in shape),
            shape[0][1],
        )  # sum of first dims

    def test_split_into_quantized_tensors_no_quantization(self) -> None:
        """Test split_into_quantized_tensors for unquantized tensors"""
        num_tensors = 3
        shape = [(256, 512) for _ in range(num_tensors)]

        grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
            num_tensors=num_tensors,
            shape=shape,
            quantizer=None,
            device="cuda",
            dtype=torch.float32,
        )

        # Get the original data pointer
        original_data_ptr = grouped_tensor.data.data_ptr()

        # Split into tensors
        tensors = grouped_tensor.split_into_quantized_tensors()

        assert len(tensors) == num_tensors

        # Verify each tensor has correct shape and shares storage
        for i, tensor in enumerate(tensors):
            assert tensor.shape == shape[i]
            assert isinstance(tensor, torch.Tensor)
            assert not hasattr(tensor, "_data")  # Not a quantized tensor

            # Verify data pointer is within the original grouped tensor storage
            # The tensor should be a view of the original data
            assert tensor.data_ptr() >= original_data_ptr

            # Calculate expected offset (all members have the same size here,
            # so the i-th member starts i member-sizes into the buffer)
            expected_offset = i * (shape[i][0] * shape[i][1]) * tensor.element_size()
            assert tensor.data_ptr() == original_data_ptr + expected_offset

    @pytest.mark.parametrize("quantization", _quantization_params)
    def test_split_into_quantized_tensors_quantized(self, quantization: str) -> None:
        """Test split_into_quantized_tensors for quantized tensors"""
        num_tensors = 3
        shape = [(512, 512) for _ in range(num_tensors)]
        quantizer = make_quantizer(quantization, num_tensors, shape)

        grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
            num_tensors=num_tensors,
            shape=shape,
            quantizer=quantizer,
            device="cuda",
        )

        # Get the original data pointer
        original_data_ptr = grouped_tensor.data.data_ptr()

        # Split into tensors
        tensors = grouped_tensor.split_into_quantized_tensors()

        assert len(tensors) == num_tensors

        # Verify each tensor shares storage with the grouped tensor
        for i, tensor in enumerate(tensors):
            rowwise_data = _get_rowwise_data_tensor(tensor, quantization)
            assert rowwise_data is not None
            assert rowwise_data.data_ptr() >= original_data_ptr
            numel = shape[i][0] * shape[i][1]
            expected_offset = _rowwise_offset_bytes(i * numel, quantization)
            assert rowwise_data.data_ptr() == original_data_ptr + expected_offset

    def test_split_varying_shapes(self) -> None:
        """Test split_into_quantized_tensors with varying shapes"""
        num_tensors = 3
        shape = [(128, 512), (256, 512), (384, 512)]

        grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
            num_tensors=num_tensors,
            shape=shape,
            quantizer=None,
            device="cuda",
            dtype=torch.float32,
        )

        original_data_ptr = grouped_tensor.data.data_ptr()
        tensors = grouped_tensor.split_into_quantized_tensors()

        assert len(tensors) == num_tensors

        # Verify shapes and storage; offsets accumulate because sizes differ
        cumulative_offset = 0
        for i, tensor in enumerate(tensors):
            assert tensor.shape == shape[i]
            expected_offset = cumulative_offset * tensor.element_size()
            assert tensor.data_ptr() == original_data_ptr + expected_offset
            cumulative_offset += shape[i][0] * shape[i][1]

    @pytest.mark.parametrize("quantization", _quantization_params)
    def test_quantize_inplace(self, quantization: str) -> None:
        """Test that quantize is done in-place for all recipes"""
        num_tensors = 3
        shape = [(512, 512) for _ in range(num_tensors)]
        quantizer = make_quantizer(quantization, num_tensors, shape)

        grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
            num_tensors=num_tensors,
            shape=shape,
            quantizer=quantizer,
            device="cuda",
        )

        # Get original data pointers before quantization
        original_data_ptr = grouped_tensor.data.data_ptr()
        original_scale_inv_ptr = grouped_tensor.scale_inv.data_ptr()
        # ``scale`` may be absent, so it is only compared when present
        original_scale_ptr = (
            grouped_tensor.scale.data_ptr() if grouped_tensor.scale is not None else None
        )

        # Create input tensors
        input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape]

        # Quantize in place
        quantized_tensors = grouped_tensor.quantize(input_tensors)

        # Verify data pointers haven't changed (in-place operation)
        assert grouped_tensor.data.data_ptr() == original_data_ptr
        assert grouped_tensor.scale_inv.data_ptr() == original_scale_inv_ptr
        if original_scale_ptr is not None:
            assert grouped_tensor.scale.data_ptr() == original_scale_ptr

        # Verify returned tensors point to the same storage
        for i, qtensor in enumerate(quantized_tensors):
            rowwise_data = _get_rowwise_data_tensor(qtensor, quantization)
            numel = shape[i][0] * shape[i][1]
            expected_offset = _rowwise_offset_bytes(i * numel, quantization)
            assert rowwise_data.data_ptr() == original_data_ptr + expected_offset

    @pytest.mark.parametrize("quantization", _quantization_params)
    def test_quantize_varying_shapes(self, quantization: str) -> None:
        """Test quantize with varying shapes"""
        num_tensors = 3
        shape = [(256, 512), (512, 512), (768, 512)]
        quantizer = make_quantizer(quantization, num_tensors, shape)

        grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
            num_tensors=num_tensors,
            shape=shape,
            quantizer=quantizer,
            device="cuda",
        )

        # Get original data pointers
        original_data_ptr = grouped_tensor.data.data_ptr()

        # Create input tensors with varying shapes
        input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape]

        # Quantize in place
        quantized_tensors = grouped_tensor.quantize(input_tensors)

        # Verify data pointer hasn't changed
        assert grouped_tensor.data.data_ptr() == original_data_ptr

        # Verify each tensor points to correct location
        cumulative_numel = 0
        for qtensor, tensor_shape in zip(quantized_tensors, shape):
            rowwise_data = _get_rowwise_data_tensor(qtensor, quantization)
            expected_offset = _rowwise_offset_bytes(cumulative_numel, quantization)
            assert rowwise_data.data_ptr() == original_data_ptr + expected_offset
            cumulative_numel += tensor_shape[0] * tensor_shape[1]

    @pytest.mark.parametrize("quantization", _quantization_params)
    def test_static_quantize_method(self, quantization: str) -> None:
        """Test the static quantize method"""
        num_tensors = 3
        shape = [(512, 512) for _ in range(num_tensors)]
        quantizer = make_quantizer(quantization, num_tensors, shape)

        # Create input tensors
        input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape]

        # Use static quantize method
        grouped_tensor = GroupedTensor.create_and_quantize(
            tensors=input_tensors,
            quantizer=quantizer,
            device="cuda",
        )

        # Verify the grouped tensor was created correctly
        assert grouped_tensor.num_tensors == num_tensors
        assert grouped_tensor.has_data()

        # Verify quantized_tensors were created and point to same storage
        assert grouped_tensor.quantized_tensors is not None
        assert len(grouped_tensor.quantized_tensors) == num_tensors

        original_data_ptr = grouped_tensor.data.data_ptr()
        for i, qtensor in enumerate(grouped_tensor.quantized_tensors):
            rowwise_data = _get_rowwise_data_tensor(qtensor, quantization)
            numel = shape[i][0] * shape[i][1]
            expected_offset = _rowwise_offset_bytes(i * numel, quantization)
            assert rowwise_data.data_ptr() == original_data_ptr + expected_offset

    def test_clear(self) -> None:
        """Test clear method"""
        num_tensors = 3
        shape = [(256, 512) for _ in range(num_tensors)]

        grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
            num_tensors=num_tensors,
            shape=shape,
            quantizer=None,
            device="cuda",
            dtype=torch.float32,
        )

        # Sanity-check the populated state before clearing
        assert grouped_tensor.has_data()
        assert grouped_tensor.num_tensors == num_tensors

        grouped_tensor.clear()

        assert not grouped_tensor.has_data()
        assert grouped_tensor.num_tensors == 0
        assert grouped_tensor.data is None
        assert grouped_tensor.logical_shape == (0, 0)

0 commit comments

Comments
 (0)