Svdquant huggingface checkpoint export support #754
base: main
Changes from all commits
```diff
@@ -54,6 +54,7 @@
     QUANTIZATION_NONE,
     QUANTIZATION_NVFP4,
     QUANTIZATION_NVFP4_AWQ,
+    QUANTIZATION_NVFP4_SVDQUANT,
     QUANTIZATION_W4A8_AWQ,
     QUANTIZATION_W4A8_NVFP4_FP8,
 )
```

**Contributor:** Could you add a check to ensure the model is running on a single GPU and not in a distributed setup? I'm not sure that our current SVDQ implementation works correctly with multiple GPUs. We can remove this check later once we verify that SVDQ calibration functions properly in a multi-GPU setting.
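A guard along the lines the reviewer requests could look like the sketch below. The helper name, its call site, and the choice to raise rather than warn are assumptions for illustration, not part of this PR.

```python
# Hypothetical sketch of the requested single-GPU guard (not in the PR).
import torch.distributed as dist

def assert_single_gpu_for_svdquant() -> None:
    """Fail fast if SVDQuant export/calibration is attempted in a multi-GPU process group."""
    if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
        raise NotImplementedError(
            "NVFP4 SVDQuant has not been verified in multi-GPU settings; "
            "please run calibration/export on a single GPU."
        )
```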
|
|
```diff
@@ -109,6 +110,10 @@ def _output_hook(module, input, output):
     fused_linears = {}
     module_names = set()

+    # NVFP4 SVDQuant does not need pre-quant scale fusion (either into previous linear or layernorm) because
+    # 1) its kernel handles pre-quant scale.
+    # 2) fusing into previous linear will need to change the lora_up in up_proj which may cause issue in
+    # the later gate up fusion.
     # Fuse pre_quant_scale to the linear weights if possible
     if quantization_format is not None and "nvfp4_awq" in quantization_format.lower():
         fuse_prequant_to_linear(model)
```
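For context on the fusion that SVDQuant skips here: `fuse_prequant_to_linear` folds the AWQ-style per-channel `pre_quant_scale` into the layer that produces the activation, so the scaling disappears at inference time. The snippet below is an illustrative sketch of that algebra only (the real helper also handles layernorm fusion, biases, and module matching); for SVDQuant the kernel applies the scale itself, and folding it would also perturb the LoRA-up path, which is why the diff leaves `pre_quant_scale` untouched for this format.

```python
# Illustrative sketch only: fold a per-channel activation scale into the
# producing linear layer, so  W_q @ (s * x)  becomes  W_q @ x'  with x'
# emitted directly by the rescaled previous layer.
import torch

def fold_prequant_scale(prev_weight, prev_bias, pre_quant_scale):
    # prev_weight: (out_features, in_features) of the producing layer
    # pre_quant_scale: (out_features,) -- matches the quantized layer's input dim
    fused_weight = prev_weight * pre_quant_scale.unsqueeze(1)
    fused_bias = prev_bias * pre_quant_scale if prev_bias is not None else None
    return fused_weight, fused_bias
```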
|
|
```diff
@@ -117,7 +122,9 @@
             module_names.add(name)

         # For MoE models update pre_quant_scale to average pre_quant_scale amongst experts
-        if is_moe(module) and ("awq" in quantization_format):
+        if is_moe(module) and (
+            ("awq" in quantization_format) or (quantization_format == QUANTIZATION_NVFP4_SVDQUANT)
+        ):
             # update_experts_avg_prequant_scale(module)
             grouped_experts = get_experts_list(module, model_type)
             for modules in grouped_experts:
```
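The averaging that the commented-out `update_experts_avg_prequant_scale` name alludes to can be pictured as below. This is a sketch only: the attribute path `input_quantizer.pre_quant_scale` and the helper name are assumptions for illustration, not the PR's exact code.

```python
# Illustrative sketch: give every expert linear in a group the same (mean)
# pre_quant_scale so the experts can share one scale after export.
import torch

def average_experts_prequant_scale(expert_linears) -> None:
    scales = [m.input_quantizer.pre_quant_scale for m in expert_linears]  # assumed attribute path
    avg = torch.stack(scales).mean(dim=0)
    for m in expert_linears:
        m.input_quantizer.pre_quant_scale = avg
```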
|
|
```diff
@@ -314,6 +321,7 @@ def _export_quantized_weight(

     if quantization_format in [
         QUANTIZATION_NVFP4_AWQ,
+        QUANTIZATION_NVFP4_SVDQUANT,
         QUANTIZATION_NVFP4,
         QUANTIZATION_W4A8_AWQ,
         QUANTIZATION_W4A8_NVFP4_FP8,
```
|
|
```diff
@@ -334,7 +342,11 @@
         for expert_type in ["Llama4TextExperts", "GptOssExperts"]
     )

-    if quantization_format in [QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ]:
+    if quantization_format in [
+        QUANTIZATION_NVFP4,
+        QUANTIZATION_NVFP4_AWQ,
+        QUANTIZATION_NVFP4_SVDQUANT,
+    ]:
         # Transpose weight from (num_experts, input_dim, output_dim) to (num_experts, output_dim, input_dim)
         # for NVFP4 quantization functions that expect input_dim as the last dimension for block quantization
         weight, _ = maybe_transpose_expert_weight_dimensions(
```
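The transpose described in the comment is just a swap of the last two dimensions of the stacked expert weights; a toy illustration with arbitrary shapes follows (the real `maybe_transpose_expert_weight_dimensions` presumably also returns whether the swap happened, since the diff unpacks a second value as `_`).

```python
# Toy illustration of the expert-weight transpose described above.
import torch

weight = torch.randn(8, 256, 512)                 # (num_experts, input_dim, output_dim)
weight_t = weight.transpose(-2, -1).contiguous()  # (num_experts, output_dim, input_dim)
assert weight_t.shape == (8, 512, 256)            # input_dim now last, as NVFP4 block quant expects
```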
|
|
```diff
@@ -1013,6 +1013,18 @@ def _get_awq_quantizer_block_size(tensor: torch.Tensor, quantizer: TensorQuantiz
     return blocksize


+def svd(weight, rank):
+    original_device = weight.device
+    original_dtype = weight.dtype
+    weight_f64 = weight.to(dtype=torch.float64, device=original_device)
+    u, s, vt = torch.linalg.svd(weight_f64, full_matrices=False)
+    us = u[:, :rank] * s[:rank]
+    vt = vt[:rank]
+    return us.to(device=original_device, dtype=original_dtype), vt.to(
+        device=original_device, dtype=original_dtype
+    )
+
+
 @torch.no_grad()
 def svdquant(
     model: nn.Module,
```

**Collaborator** (on the `float64` conversion): do we need f64?

**Contributor (Author):** I am not sure. I kept what @jingyu-ml has originally. This part is just a refactoring so that I can reuse this code during qkv fusion.
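On the f64 question: as far as I can tell, `torch.linalg.svd` does not accept half-precision inputs, so some upcast is needed regardless; `float64` maximizes the accuracy of the factors, while `float32` would likely be faster. Below is a quick, standalone usage check of what the new `svd` helper returns (it assumes the `svd` function from the diff is in scope; shapes and rank are arbitrary): `us @ vt` is a rank-r approximation of the weight, and the residual is what SVDQuant later quantizes.

```python
# Quick check of the svd() helper added in this hunk: it upcasts internally,
# then returns the factors in the weight's original dtype and device.
import torch

weight = torch.randn(512, 2048)
us, vt = svd(weight, rank=32)        # us: (512, 32), vt: (32, 2048)
residual = weight - us @ vt          # low-magnitude remainder handed to the quantizer
print(us.shape, vt.shape, residual.abs().mean())
```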
|
|
```diff
@@ -1034,25 +1046,16 @@ def svdquant(
     def postprocess(module, name):
         print_rank_0(f"SVD {name}")
         weight = module.weight.data
-        original_device = weight.device
-        original_dtype = weight.dtype
-        weight_f64 = weight.to(dtype=torch.float64, device=original_device)
-        u, s, vt = torch.linalg.svd(weight_f64, full_matrices=False)
-        if u.shape[1] < lowrank or vt.shape[0] < lowrank:
+        us, vt = svd(weight, lowrank)
+        if us.shape[1] < lowrank or vt.shape[0] < lowrank:
             warnings.warn(
                 "The low-rank dimensions do not match the layer dimensions. "
                 "Please verify your configuration and model settings. "
                 f"SVD will be skipped for this layer {name}."
             )
             return
-        us = u[:, :lowrank] * s[:lowrank]
-        vt = vt[:lowrank]
-        module.weight_quantizer.svdquant_lora_a = vt.to(
-            dtype=original_dtype, device=original_device
-        )
-        module.weight_quantizer.svdquant_lora_b = us.to(
-            dtype=original_dtype, device=original_device
-        )
+        module.weight_quantizer.svdquant_lora_a = vt
+        module.weight_quantizer.svdquant_lora_b = us
         module.weight.data.sub_(
             module.weight_quantizer.svdquant_lora_b @ module.weight_quantizer.svdquant_lora_a
         )
```
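What this postprocess sets up, in one picture: the quantizer keeps the low-rank factors (`svdquant_lora_b @ svdquant_lora_a`) while `module.weight.data` keeps only the residual, so the full weight is recovered by adding the two back together. A standalone sanity check of that split, using the `svd` helper above with arbitrary shapes and no quantizer involved:

```python
# Standalone check of the decomposition this postprocess performs:
# weight = residual + B @ A, with B = svdquant_lora_b and A = svdquant_lora_a.
import torch

W = torch.randn(1024, 1024)
us, vt = svd(W, rank=64)
residual = W - us @ vt            # what module.weight.data holds after sub_()
recovered = residual + us @ vt    # what export/inference reassembles
assert torch.allclose(recovered, W, atol=1e-5)
```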
|
|
**Reviewer:** Can `svdquant_lora_a` be `None` in any case?

**Author:** Good point. I will skip it when it is None.