
apply_query_key_layer_scaling ignored in PyTorch ≥ 2.0 scaled_dot_product_attention path #1363

@Qi-Zhan


System Info

None

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Reproduction

Description

In the ChatGLM implementation, the configuration flag:

apply_query_key_layer_scaling = True

is ignored when running on PyTorch 2.0 or later.

Background

In the original THUDM ChatGLM3 implementation, layer-wise attention scaling is applied:

if self.apply_query_key_layer_scaling:
    coeff = layer_number
    self.norm_factor *= coeff

This effectively scales attention scores by 1 / (sqrt(head_dim) * layer_number), stabilizing deep layers.
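As a quick arithmetic check that the combined coefficient equals 1 / (sqrt(head_dim) * layer_number) (the head_dim and layer_number values below are arbitrary, for illustration only):

```python
import math

head_dim = 128       # arbitrary example value
layer_number = 5     # arbitrary example value

norm_factor = math.sqrt(head_dim)   # base 1/sqrt(head_dim) scaling
coeff = layer_number                # apply_query_key_layer_scaling
norm_factor *= coeff

alpha = 1.0 / norm_factor           # the value the baddbmm path passes as alpha
assert abs(alpha - 1.0 / (math.sqrt(head_dim) * layer_number)) < 1e-12
```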

Current Behavior

For PyTorch < 2.0, the implementation uses baddbmm with alpha = 1 / norm_factor:

matmul_result = torch.baddbmm(
    ...,
    alpha=1.0 / self.norm_factor
)

In this path, layer-wise scaling works as intended.

For PyTorch ≥ 2.0, the implementation uses torch.nn.functional.scaled_dot_product_attention:

context_layer = torch.nn.functional.scaled_dot_product_attention(
    query_layer, key_layer, value_layer, is_causal=True
)

This function applies a fixed scale of 1 / sqrt(head_dim) and does not accept a layer_number parameter.
As a result, apply_query_key_layer_scaling=True is completely ignored.

Expected behavior

Suggested Fix
• Expose a scaling or norm_factor parameter in the PyTorch ≥ 2.0 path (SDPA's scale keyword, available since PyTorch 2.1, would serve)
• Fold layer_number into that scale when apply_query_key_layer_scaling=True, matching the baddbmm path's 1 / (sqrt(head_dim) * layer_number)
