Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/noether/core/schemas/modules/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class PerceiverBlockConfig(TransformerBlockConfig):
@model_validator(mode="after")
def set_kv_dim(self) -> "PerceiverBlockConfig":
    """Default ``kv_dim`` to ``hidden_dim`` when not explicitly provided.

    Runs after field validation, so ``hidden_dim`` is already populated.
    Returns ``self`` so pydantic keeps the (possibly mutated) model.
    """
    # The fallback now applies regardless of conditioning: the modulation
    # projection sizes its output from the effective kv_dim, so there is no
    # need to special-case condition_dim here. (The stripped diff had left
    # both the old and new condition lines in place, which is invalid
    # Python; this is the merged, post-change version.)
    if self.kv_dim is None:
        self.kv_dim = self.hidden_dim
    return self

def modulation_linear_projection_config(self) -> LinearProjectionConfig | None:
    """Build the projection config that maps the condition vector to modulation params.

    Returns ``None`` when no ``condition_dim`` is configured (unconditioned
    block). Otherwise the projection output is consumed as eight modulation
    tensors: six of width ``hidden_dim`` (q scale/shift, attention gate, mlp
    scale/shift/gate) plus two of width ``kv_dim`` (kv scale/shift), so the
    total output width accounts for ``kv_dim != hidden_dim``.
    """
    if self.condition_dim is not None:
        return LinearProjectionConfig(
            input_dim=self.condition_dim,
            # 6 hidden_dim-sized chunks + 2 kv_dim-sized chunks; fall back
            # to hidden_dim when kv_dim was not set (see set_kv_dim).
            output_dim=self.hidden_dim * 6 + (self.kv_dim or self.hidden_dim) * 2,
            # Zero-init so modulation starts as identity (standard DiT-style init).
            init_weights="zeros",
        )
    return None
8 changes: 6 additions & 2 deletions src/noether/modeling/modules/blocks/perceiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def __init__(
self.modulation = None
elementwise_affine = True
else:
assert config.kv_dim is None
assert config.bias
self._kv_dim = config.kv_dim or config.hidden_dim
if config.modulation_linear_projection_config is not None:
self.modulation = LinearProjection(config=config.modulation_linear_projection_config) # type: ignore[arg-type]
elementwise_affine = False
Expand Down Expand Up @@ -93,7 +93,11 @@ def forward(
if condition is None:
raise ValueError("No conditioning vector provided, but modulation is configured.")
mod = self.modulation(condition)
q_scale, q_shift, kv_scale, kv_shift, attn_gate, mlp_scale, mlp_shift, mlp_gate = mod.chunk(8, dim=-1)
hd = self.norm1q.normalized_shape[0]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we use config.hidden_dim here instead?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, we cannot as this is in the forward pass, and in the current PerceiverBlock implementation, config is not stored in the class instance in the constructor. So we have to infer it

kd = self._kv_dim
q_scale, q_shift, kv_scale, kv_shift, attn_gate, mlp_scale, mlp_shift, mlp_gate = mod.split(
[hd, hd, kd, kd, hd, hd, hd, hd], dim=-1
)
q = q + self.drop_path1(
modulate_gate(
self.ls1(
Expand Down
24 changes: 24 additions & 0 deletions tests/unit/noether/modeling/modules/blocks/test_perceiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,30 @@ def test_perceiver_block_conditioned():
assert torch.allclose(DIT_PERCEIVER_BLOCK, dit_output, atol=1e-4)


def test_perceiver_block_conditioned_with_kv_dim():
    """Conditioned forward pass must work when kv_dim differs from hidden_dim."""
    # Seed before building the block so weight init is reproducible.
    torch.manual_seed(0)
    # Deliberately choose kv_dim != hidden_dim to exercise the asymmetric
    # modulation split in the forward pass.
    config = PerceiverBlockConfig(
        hidden_dim=8,
        num_heads=2,
        kv_dim=4,
        condition_dim=32,
        mlp_expansion_factor=4,
    )
    block = PerceiverBlock(config=config)
    # Shapes: (batch=2, seq=5, feature); condition is per-sample.
    q = torch.randn(2, 5, 8)
    kv = torch.randn(2, 5, 4)
    condition = torch.randn(2, 32)
    output = block(q=q, kv=kv, condition=condition)
    assert output.shape == q.shape, "Output shape mismatch"
    assert not torch.isnan(output).any(), "Output contains NaN"


def test_no_bias():
config = PerceiverBlockConfig(hidden_dim=8, num_heads=2, bias=False, mlp_expansion_factor=4)
block = PerceiverBlock(config=config)
Expand Down
Loading