Single parameter for GroupedLinear module #2727
Conversation
Greptile Summary: This PR follows up on prior grouped-linear work (#2600, #2654, #2678) with three main changes. Key changes and issues found:
Confidence Score: 3/5
Important Files Changed
Class Diagram

```mermaid
%%{init: {'theme': 'neutral'}}%%
classDiagram
    class `torch.Tensor` {
        <<built-in>>
    }
    class GroupedTensorStorage {
        +num_tensors: int
        +quantizers: List~Quantizer~
        +tensor_shapes: List~Tuple~
        +rowwise_data: Tensor
        +columnwise_data: Tensor
        +logical_shape: Tuple
        +fake_dtype: dtype
        +quantized_tensors: List
        +make_grouped_tensor()$
        +make_grouped_tensor_with_shapes()$
        +split_into_quantized_tensors()
        +quantize()
        +has_data()
        +all_same_shape()
    }
    class GroupedTensor {
        +__new__()
        +__torch_dispatch__()$
        +__torch_function__()$
    }
    class GroupedLinear {
        +single_grouped_parameter: bool
        +make_grouped_weights()
        +reset_parameters()
        +_get_weight_tensors()
        +set_tensor_parallel_attributes()
    }
    class Quantizer {
        <<abstract>>
        +internal: bool
        +rowwise_usage: bool
        +columnwise_usage: bool
        +create_grouped_tensor()
    }
    GroupedTensorStorage <|-- GroupedTensor
    `torch.Tensor` <|-- GroupedTensor
    GroupedLinear --> GroupedTensor : weight (single_grouped_parameter=True)
    GroupedLinear --> GroupedTensorStorage : internal weights
    Quantizer --> GroupedTensorStorage : creates via create_grouped_tensor()
    GroupedTensorStorage o-- Quantizer : quantizers[]
```

Last reviewed commit: f413a93
```diff
 else:
     grouped_weights.quantized_tensors[i].copy_(weights[i])

-# Re-register the grouped weights as parameters.
+# Re-register as a single grouped weight parameter.
 self.register_parameter(
     "weight",
     torch.nn.Parameter(grouped_weights),
     init_fn=self.init_method,
     get_rng_state_tracker=self.get_rng_state_tracker,
     fp8_meta_index=self._offsets["weight"],
```
Single fp8_meta_index for all GEMM sub-weights
The old per-GEMM registration assigned each `weight{i}` its own offset:

```python
fp8_meta_index=self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"]
```

The new registration uses a single offset for all sub-weights:

```python
fp8_meta_index=self._offsets["weight"],
```

For FP8 recipes that use per-tensor metadata (e.g., MXFP8, Float8BlockScaling), this means all `num_gemms` sub-weights share the same `fp8_meta_index`. During FP8 metadata updates (e.g., amax tracking), only the first GEMM's metadata will be updated correctly; the remaining GEMMs will silently use stale or zero metadata. Note that `make_grouped_weights` does guard against the `delayed()` and `float8_current_scaling()` recipes, but not against `mxfp8()` or `float8_block_scaling()`, so those paths could be affected.
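The collision is easy to see with a toy index calculation. All names and values below are illustrative stand-ins, not the actual TransformerEngine internals:

```python
# Illustrative sketch of the two fp8_meta_index schemes discussed above.
# num_gemms, tensors_per_gemm, and weight_offset are made-up values.
num_gemms = 4
tensors_per_gemm = 3   # FP8 metadata slots consumed per GEMM in the "fwd" pass
weight_offset = 1      # stand-in for self._offsets["weight"]

# Old per-GEMM registration: each sub-weight gets a distinct metadata slot.
old_indices = [weight_offset + i * tensors_per_gemm for i in range(num_gemms)]

# New registration: every sub-weight shares one slot.
new_indices = [weight_offset] * num_gemms

print(old_indices)  # [1, 4, 7, 10] -> amax tracking stays per-weight
print(new_indices)  # [1, 1, 1, 1]  -> GEMMs 1..3 collide on slot 1
```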
```python
quantizer=None,
shapes=shape,
quantizers=None,
device="cuda",
```
Each per-tensor quantizer constructed with full-group num_tensors
Each list entry calls make_quantizer(quantization, num_tensors, shape) with num_tensors=3, meaning each quantizer's internal buffers (e.g., FP8 amax/scale tensors) are sized for the entire group of 3 tensors, not for a single tensor. While this doesn't break correctness today (only index 0 of the per-quantizer buffers is used), it inflates memory usage and diverges from production use, where each per-tensor quantizer should be sized for one tensor.
Consider constructing each quantizer for `num_tensors=1`:

```python
quantizers = [make_quantizer(quantization, 1, shape) for _ in range(num_tensors)]
```

```diff
@@ -450,7 +445,7 @@ def make_grouped_tensor(
 total_scale_elements = 0
 scale_inv_offsets = [0]
 for i, s in enumerate(shape):
-    scale_inv_shape = quantizer.get_scale_shape(s, False)
+    scale_inv_shape = reference_quantizer.get_scale_shape(s, False)
     scale_elements = math.prod(scale_inv_shape)
     total_scale_elements += scale_elements
```
reference_quantizer used for all per-tensor scale shape calculations
Scale buffer sizes for each tensor are computed using reference_quantizer.get_scale_shape(s, ...) for every index i. If the quantizers in the list share the same recipe type but differ in a parameter that affects block size (e.g., different block_dim in a block-quantizer), scale buffers could be mis-sized for non-reference tensors.
The existing check only validates recipe type equality:

```python
if any(type(q._get_compatible_recipe()) is not type(reference_quantizer._get_compatible_recipe()) ...):
```

It would be safer to also validate that any shape-determining attributes (e.g., block scaling dim) match across all quantizers.
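One way to tighten the check is sketched below, using a hypothetical `block_dim` attribute as the shape-determining field (the real quantizer attributes may differ):

```python
# Sketch: validate recipe type plus shape-determining attributes across
# all quantizers in the group. `block_dim` is a hypothetical attribute name,
# not necessarily the real quantizer field.
def check_quantizers_compatible(quantizers, reference_quantizer,
                                attrs=("block_dim",)):
    ref_recipe = type(reference_quantizer._get_compatible_recipe())
    for q in quantizers:
        if type(q._get_compatible_recipe()) is not ref_recipe:
            raise ValueError("mismatched recipe types in grouped quantizers")
        for attr in attrs:
            if getattr(q, attr, None) != getattr(reference_quantizer, attr, None):
                raise ValueError(
                    f"quantizers differ in {attr}; scale buffers would be mis-sized"
                )
```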
Additional Comments (1)
For in-place (…): this can silently break parameter initialization when (…).
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Force-pushed from a32a640 to 5213efb.
/te-ci L0
```cuda
constexpr size_t kCumsumThreadsPerBlock = 256;

__global__ void __launch_bounds__(kCumsumThreadsPerBlock)
cumsum_with_leading_zero_kernel(const int64_t *__restrict__ input, int64_t *__restrict__ output,
```
We shouldn't need so many `__syncthreads()` calls in the kernel.
I might prefer a single-thread kernel; if we are targeting a performant cumsum kernel, this isn't implemented properly.
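For reference, the kernel's intended semantics are just a cumulative sum with a leading zero. A minimal single-threaded sketch in plain Python (not the CUDA implementation) looks like:

```python
def cumsum_with_leading_zero(values):
    """Return [0, v0, v0+v1, ...]: one output element per input, plus a leading 0."""
    out = [0]
    for v in values:
        out.append(out[-1] + v)
    return out

print(cumsum_with_leading_zero([3, 5, 2]))  # [0, 3, 8, 10]
```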
```diff
 const int64_t logical_last_dim_i64 = static_cast<int64_t>(logical_last_dim);
-auto scaled_first_dims = first_dims_tensor * logical_last_dim_i64;
+auto scaled_first_dims = (first_dims_tensor * logical_last_dim_i64).contiguous();
```
This should be fused with the cumsum. The cumsum kernel can then be renamed to something like `nvte_compute_tensor_offsets`.
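The suggested fusion amounts to folding the `logical_last_dim` multiply into the offset scan, so the scaled intermediate tensor never needs to be materialized. A hypothetical Python sketch (function name borrowed from the suggested `nvte_compute_tensor_offsets`; this is not the CUDA implementation):

```python
def compute_tensor_offsets(first_dims, logical_last_dim):
    # Fused: scale each first-dim by logical_last_dim while scanning,
    # instead of building scaled_first_dims and then running cumsum.
    offsets = [0]
    for d in first_dims:
        offsets.append(offsets[-1] + d * logical_last_dim)
    return offsets

print(compute_tensor_offsets([3, 5, 2], 4))  # [0, 12, 32, 40]
```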
```cpp
  return tensor ? py::cast(*tensor) : py::none();
}

py::object make_grouped_quantizers(const py::object& quantizer, const size_t num_tensors) {
```
```cpp
"first_dims"_a = first_dims.has_value() ? py::cast(*first_dims) : py::none(),
"last_dims"_a = py::none(),
"tensor_offsets"_a = tensor_offsets.has_value() ? py::cast(*tensor_offsets) : py::none(),
"logical_shape"_a = std::vector<int64_t>{static_cast<int64_t>(logical_first_dim),
```
```python
weight_tensors = grouped_weight.quantized_tensors
if weight_tensors is None:
    # TODO(ksivaman): Remove this after GEMM integration.
    weight_tensors = grouped_weight.split_into_quantized_tensors()
```
```python
rowwise_usage = quantizer.rowwise_usage if not no_quantization else True
columnwise_usage = quantizer.columnwise_usage if not no_quantization else False
no_quantization = quantizers is None or all(q is None for q in quantizers)
reference_quantizer = None
```
What is `reference_quantizer`?
```cpp
size_t get_cudnn_version() { return cudnnGetVersion(); }

at::Tensor cumsum(at::Tensor input, std::optional<at::Tensor> out) {
```
Need to clean this up and replace it with an `nvte_compute_tensor_offsets` call.
Description
Follow-ups and miscellaneous fixes from #2600, #2654, and #2678.
Type of change
Changes
Checklist: