19 changes: 10 additions & 9 deletions monai/networks/blocks/transformerblock.py
@@ -80,15 +80,16 @@ def __init__(
         self.norm2 = nn.LayerNorm(hidden_size)
         self.with_cross_attention = with_cross_attention

-        self.norm_cross_attn = nn.LayerNorm(hidden_size)
-        self.cross_attn = CrossAttentionBlock(
-            hidden_size=hidden_size,
-            num_heads=num_heads,
-            dropout_rate=dropout_rate,
-            qkv_bias=qkv_bias,
-            causal=False,
-            use_flash_attention=use_flash_attention,
-        )
+        if with_cross_attention:
+            self.norm_cross_attn = nn.LayerNorm(hidden_size)
+            self.cross_attn = CrossAttentionBlock(
+                hidden_size=hidden_size,
+                num_heads=num_heads,
+                dropout_rate=dropout_rate,
+                qkv_bias=qkv_bias,
+                causal=False,
+                use_flash_attention=use_flash_attention,
+            )
Comment on lines +83 to +92
Contributor
🛠️ Refactor suggestion

Preserve legacy checkpoints when disabling cross-attention.
Old checkpoints created with with_cross_attention=False still contain cross_attn.* and norm_cross_attn.* entries. Now that these submodules are no longer registered, load_state_dict(..., strict=True) raises unexpected-key errors, so this is a breaking change. Please strip those keys during load (or otherwise ensure they are ignored) if the modules are dropped at init.
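The failure mode is easy to reproduce with a pair of toy modules (a sketch only — `OldBlock`, `NewBlock`, and the `nn.Linear` stand-in for `CrossAttentionBlock` are illustrative names, not MONAI code):

```python
import torch.nn as nn

class OldBlock(nn.Module):
    """Mimics the previous __init__: cross-attention submodules created unconditionally."""
    def __init__(self):
        super().__init__()
        self.norm2 = nn.LayerNorm(4)
        self.norm_cross_attn = nn.LayerNorm(4)
        self.cross_attn = nn.Linear(4, 4)  # stand-in for CrossAttentionBlock

class NewBlock(nn.Module):
    """Mimics the new __init__ with with_cross_attention=False: submodules absent."""
    def __init__(self):
        super().__init__()
        self.norm2 = nn.LayerNorm(4)

legacy_sd = OldBlock().state_dict()  # contains cross_attn.* and norm_cross_attn.* keys

try:
    NewBlock().load_state_dict(legacy_sd, strict=True)
except RuntimeError as err:
    # PyTorch reports the orphaned keys, e.g.
    # 'Unexpected key(s) in state_dict: "norm_cross_attn.weight", ...'
    print(type(err).__name__)
```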

@@
         return x
+
+    def _load_from_state_dict(
+        self,
+        state_dict,
+        prefix,
+        local_metadata,
+        strict,
+        missing_keys,
+        unexpected_keys,
+        error_msgs,
+    ):
+        if not self.with_cross_attention:
+            keys = [
+                key
+                for key in list(state_dict.keys())
+                if key.startswith(f"{prefix}cross_attn.") or key.startswith(f"{prefix}norm_cross_attn.")
+            ]
+            for key in keys:
+                state_dict.pop(key)
+        super()._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+        )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
        if with_cross_attention:
            self.norm_cross_attn = nn.LayerNorm(hidden_size)
            self.cross_attn = CrossAttentionBlock(
                hidden_size=hidden_size,
                num_heads=num_heads,
                dropout_rate=dropout_rate,
                qkv_bias=qkv_bias,
                causal=False,
                use_flash_attention=use_flash_attention,
            )
        return x

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        # Strip out any cross-attention params when with_cross_attention=False
        if not self.with_cross_attention:
            keys = [
                key
                for key in list(state_dict.keys())
                if key.startswith(f"{prefix}cross_attn.") or key.startswith(f"{prefix}norm_cross_attn.")
            ]
            for key in keys:
                state_dict.pop(key)
        # Delegate actual loading to the parent implementation
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )
🤖 Prompt for AI Agents
In monai/networks/blocks/transformerblock.py around lines 83-92, legacy
checkpoints include parameters under cross_attn.* and norm_cross_attn.* even
when with_cross_attention is False, causing unexpected-key errors; update the
class to either preserve dummy attributes or strip those keys when loading.
Implement a small fix: if you choose to drop the modules at init (keep them
absent), override load_state_dict to detect when with_cross_attention is False
and remove any keys that start with "cross_attn." or "norm_cross_attn." from the
incoming state_dict (also handle optimizer/state dict nested structures if
applicable) before delegating to the parent load_state_dict; alternatively, when
with_cross_attention is False, assign lightweight placeholders (e.g.,
nn.Identity or empty submodules) for self.cross_attn and self.norm_cross_attn so
the parameter names remain present and strict loading succeeds.
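Of the two options, the key-stripping one can be exercised end to end in a standalone sketch (again with toy names — `Block` stands in for `TransformerBlock`, `nn.Linear` for `CrossAttentionBlock`):

```python
import torch.nn as nn

class Block(nn.Module):
    """Toy block carrying the suggested _load_from_state_dict override."""
    def __init__(self, with_cross_attention: bool = False):
        super().__init__()
        self.with_cross_attention = with_cross_attention
        self.norm2 = nn.LayerNorm(4)
        if with_cross_attention:
            self.norm_cross_attn = nn.LayerNorm(4)
            self.cross_attn = nn.Linear(4, 4)  # stand-in for CrossAttentionBlock

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        # Drop legacy cross-attention entries before the parent flags them as unexpected.
        if not self.with_cross_attention:
            for key in [
                k for k in list(state_dict)
                if k.startswith(f"{prefix}cross_attn.") or k.startswith(f"{prefix}norm_cross_attn.")
            ]:
                state_dict.pop(key)
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

legacy_sd = Block(with_cross_attention=True).state_dict()  # plays the role of an old checkpoint
result = Block(with_cross_attention=False).load_state_dict(legacy_sd, strict=True)
print(result.unexpected_keys)  # → []
```

Note that recent PyTorch versions shallow-copy the incoming mapping inside load_state_dict, so popping keys in the hook should not mutate the caller's checkpoint dict; that behavior is worth confirming against the minimum PyTorch version MONAI supports.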


    def forward(
        self, x: torch.Tensor, context: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None