System Info / 系統信息
None
Who can help? / 谁可以帮助到您?
No response
Information / 问题信息
Reproduction / 复现过程
Description
In the ChatGLM implementation, the configuration flag `apply_query_key_layer_scaling = True` is ignored when running on PyTorch 2.0 or later.
Background
In the original THUDM ChatGLM3 implementation, layer-wise attention scaling is applied:
```python
if self.apply_query_key_layer_scaling:
    coeff = self.layer_number
    self.norm_factor *= coeff
```
This effectively scales attention scores by 1 / (sqrt(head_dim) * layer_number), stabilizing deep layers.
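The two factors combine into a single effective scale. A minimal numeric sketch with made-up values (`head_dim` and `layer_number` here are illustrative, not taken from the model config):

```python
import math

# Hypothetical toy values for illustration
head_dim = 128
layer_number = 5

norm_factor = math.sqrt(head_dim)  # baseline: scores scaled by 1/sqrt(head_dim)
norm_factor *= layer_number        # layer-wise scaling folds in layer_number

effective_scale = 1.0 / norm_factor
# same as 1 / (sqrt(head_dim) * layer_number)
assert math.isclose(effective_scale,
                    1.0 / (math.sqrt(head_dim) * layer_number))
```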
⸻
Current Behavior
For PyTorch < 2.0, the implementation uses baddbmm with alpha = 1 / norm_factor:
```python
matmul_result = torch.baddbmm(
    ...,
    alpha=1.0 / self.norm_factor
)
```
In this path, layer-wise scaling works as intended.
For PyTorch ≥ 2.0, the implementation uses torch.nn.functional.scaled_dot_product_attention:
```python
context_layer = torch.nn.functional.scaled_dot_product_attention(
    query_layer, key_layer, value_layer, attention_mask, is_causal=True
)
```
When no explicit scale is supplied, this function applies a fixed scale of 1 / sqrt(head_dim) and has no notion of a layer number. As a result, `apply_query_key_layer_scaling=True` is silently ignored on this path.
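The mismatch is easy to see numerically. A small NumPy sketch (not the repo's code; shapes and values are made up) comparing the attention weights produced by the two scalings:

```python
import numpy as np

rng = np.random.default_rng(0)
d, n, layer_number = 8, 4, 10  # hypothetical head_dim, seq len, layer index
q = rng.standard_normal((n, d))
k = rng.standard_normal((n, d))

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# baddbmm path: alpha = 1 / (sqrt(d) * layer_number)
weights_scaled = softmax(q @ k.T / (np.sqrt(d) * layer_number))
# SDPA default: fixed 1 / sqrt(d); layer_number plays no role
weights_default = softmax(q @ k.T / np.sqrt(d))

gap = np.abs(weights_scaled - weights_default).max()  # nonzero: outputs differ
```

Because softmax is not scale-invariant, the two paths produce different attention distributions, so the flag changes model behavior only on the pre-2.0 path.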
⸻
Expected behavior / 期待表现
Suggested Fix
• Expose a `scaling` / `norm_factor` parameter in the PyTorch ≥ 2.0 path (note that `scaled_dot_product_attention` accepts an explicit `scale` keyword argument from PyTorch 2.1 onward)
• Multiply the norm factor by `layer_number` when `apply_query_key_layer_scaling=True`
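A hedged sketch of the fix (not the repo's actual code; `sdpa_scale` and its parameters are hypothetical names): compute the layer-aware scale once, then pass it to SDPA via the `scale` keyword argument available in PyTorch ≥ 2.1.

```python
import math

def sdpa_scale(head_dim: int, layer_number: int,
               apply_query_key_layer_scaling: bool) -> float:
    """Return the scale to pass to scaled_dot_product_attention,
    mirroring the baddbmm path's 1 / norm_factor."""
    norm_factor = math.sqrt(head_dim)
    if apply_query_key_layer_scaling:
        norm_factor *= layer_number  # layer-wise scaling, as in the < 2.0 path
    return 1.0 / norm_factor

# Sketch of the call site (PyTorch >= 2.1):
# context_layer = torch.nn.functional.scaled_dot_product_attention(
#     query_layer, key_layer, value_layer, attention_mask,
#     is_causal=True,
#     scale=sdpa_scale(head_dim, layer_number, apply_query_key_layer_scaling),
# )
```

For PyTorch 2.0 (which lacks the `scale` keyword), an equivalent workaround is to pre-multiply `query_layer` by `sqrt(head_dim) * sdpa_scale(...)` before the call.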