Commit 176331f

posegnn LoRA added to all layers
1 parent fcf3578 commit 176331f

4 files changed

Lines changed: 181 additions & 121 deletions

models/pos_egnn/posegnn/adapter/README.md

Lines changed: 0 additions & 28 deletions
````diff
@@ -2,34 +2,6 @@
 
 This adapter injects LoRA into mergeable linear layers of **PosEGNN** and exports merged weights that load into a plain `PosEGNN` with `strict=True`.
 
-## Skipped layers
-
-These layers have a built-in activation inside their Dense block, which makes algebraic merging incorrect. They are always skipped so that merged exports match adapter-enabled outputs exactly.
-
-- `encoder.neighbor_embedding.combine.dense_layers.0`
-- `encoder.edge_embedding.edge_up.dense_layers.0`
-- `encoder.gata.0.gamma_s.0`
-- `encoder.gata.0.gamma_v.0`
-- `encoder.gata.0.phik_w_ra`
-- `encoder.gata.0.edge_attr_up.dense_layers.0`
-- `encoder.gata.1.gamma_s.0`
-- `encoder.gata.1.gamma_v.0`
-- `encoder.gata.1.phik_w_ra`
-- `encoder.gata.1.edge_attr_up.dense_layers.0`
-- `encoder.gata.2.gamma_s.0`
-- `encoder.gata.2.gamma_v.0`
-- `encoder.gata.2.phik_w_ra`
-- `encoder.gata.2.edge_attr_up.dense_layers.0`
-- `encoder.gata.3.gamma_s.0`
-- `encoder.gata.3.gamma_v.0`
-- `encoder.gata.3.phik_w_ra`
-- `encoder.eqff.0.gamma_m.0`
-- `encoder.eqff.1.gamma_m.0`
-- `encoder.eqff.2.gamma_m.0`
-- `encoder.eqff.3.gamma_m.0`
-
-Skipping only affects where LoRA is attached. The base model behavior is unchanged.
-
 ## Usage
 
 ```python
````
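With the skip list removed, the merged-export guarantee in the README's first paragraph now applies to every wrapped layer. A hedged sketch of that round trip follows; the import paths, the `PosEGNN` constructor, and the exact `LoRAConfig` fields are assumptions for illustration, not the repository's verified API.

```python
# Hedged sketch of the merged-export round trip (names below are assumptions).
from posegnn import PosEGNN                     # assumed import path
from posegnn.adapter.config import LoRAConfig   # assumed location of LoRAConfig
from posegnn.adapter.inject import apply_lora

base_cfg = ...  # hypothetical model configuration
model = PosEGNN(base_cfg)
apply_lora(model, LoRAConfig(include_names=[r"encoder\."], rank=8, alpha=16.0,
                             dropout=0.0, merge_on_save=True, freeze_base=True))

# ... fine-tune only the LoRA parameters ...

# With merge_on_save=True, the state_dict hook exports plain PosEGNN weights
# (each base weight folded with scaling * B @ A), so a vanilla model loads them.
plain = PosEGNN(base_cfg)
plain.load_state_dict(model.state_dict(), strict=True)
```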

models/pos_egnn/posegnn/adapter/inject.py

Lines changed: 2 additions & 15 deletions
```diff
@@ -1,4 +1,3 @@
-# inject.py
 import re
 import torch
 import torch.nn as nn
@@ -8,7 +7,7 @@
 def apply_lora(model: nn.Module, cfg: LoRAConfig) -> tuple[int, int]:
     """
     Replace leaf linear-like layers under include patterns with LoRA.
-    Skips any module that has a non-identity .activation to guarantee mergeability.
+    Safely wraps linears with internal norm/activation since LoRA is pre-activation.
     Returns (num_scalar_wrapped, 0).
     """
     include_patterns = list(cfg.include_names or [])
@@ -32,25 +31,18 @@ def is_linear_like(m: nn.Module) -> bool:
             return False
         return isinstance(w, torch.Tensor) and w.ndim == 2
 
-    def has_post_act(m: nn.Module) -> bool:
-        act = getattr(m, "activation", None)
-        return (act is not None) and (not isinstance(act, nn.Identity))
-
     n_scalar = 0
-    skipped = []  # <- track skipped post-activation linears
 
     for full_name, module in list(model.named_modules()):
         if not is_linear_like(module):
             continue
         if not wants(full_name):
             continue
-        if has_post_act(module):
-            skipped.append(full_name)  # <- record and skip
-            continue
 
         parent_name, _, child = full_name.rpartition(".")
         parent = model.get_submodule(parent_name) if parent_name else model
 
+        # already wrapped guard
         if hasattr(module, "base") and hasattr(module, "lora_A") and hasattr(module, "lora_B"):
             continue
 
@@ -60,9 +52,4 @@ def has_post_act(m: nn.Module) -> bool:
         setattr(parent, child, wrapped)
         n_scalar += 1
 
-    if getattr(cfg, "log_skipped", False) and skipped:
-        print("[lora] skipped post-activation linears:")
-        for n in skipped:
-            print(" -", n)
-
     return n_scalar, 0
```
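For callers, the only visible change is that layers which used to be skipped are now wrapped too. A minimal sketch of driving `apply_lora`, assuming `LoRAConfig` exposes `include_names` plus the fields passed through to `LoRALinear` (`rank`, `alpha`, `dropout`, `merge_on_save`, `freeze_base`); the field names are inferred from this diff, not confirmed by `config.py`.

```python
# Minimal sketch; LoRAConfig field names are inferred, not verified.
import torch.nn as nn
from posegnn.adapter.config import LoRAConfig   # assumed module path
from posegnn.adapter.inject import apply_lora

# Toy stand-in: two linear leaves, one followed by an activation module.
toy = nn.Sequential(nn.Linear(16, 32), nn.SiLU(), nn.Linear(32, 8))

cfg = LoRAConfig(
    include_names=[r".*"],   # regex patterns matched against module names
    rank=4,
    alpha=8.0,
    dropout=0.0,
    merge_on_save=True,
    freeze_base=True,
)

n_wrapped, _ = apply_lora(toy, cfg)   # wraps both Linear leaves -> (2, 0)
print(n_wrapped)
```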

models/pos_egnn/posegnn/adapter/layers.py

Lines changed: 35 additions & 10 deletions
```diff
@@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from typing import Optional
 
 def _init_lora(linear: nn.Linear, freeze_base: bool):
@@ -10,8 +11,8 @@ def _init_lora(linear: nn.Linear, freeze_base: bool):
 
 class LoRALinear(nn.Module):
     """
-    LoRA for linear layers:
-        y = base(x) + scaling * B(A(dropout(x)))
+    LoRA for linear layers applied pre-activation:
+        y = act( norm( (W x + b) + scaling * B(A(dropout(x))) ) )
     """
     def __init__(self, base_linear: nn.Linear, rank: int, alpha: Optional[float],
                  dropout: float, merge_on_save: bool, freeze_base: bool):
@@ -28,13 +29,21 @@ def __init__(self, base_linear: nn.Linear, rank: int, alpha: Optional[float],
         self.enable_lora = True
         self.merged = False
 
+        # Optional submodules carried by custom Dense
+        self._norm = getattr(base_linear, "norm", None)
+        if not isinstance(self._norm, nn.Module):
+            self._norm = None
+
         self._post_act = getattr(base_linear, "activation", None)
         self._has_post_act = self._post_act is not None and not isinstance(self._post_act, nn.Identity)
-        self.merge_on_save = bool(merge_on_save and not self._has_post_act)
 
+        # Always allow merge on save now that we inject pre-activation
+        self.merge_on_save = bool(merge_on_save)
+
+        # LoRA adapters
         self.lora_dropout = nn.Dropout(dropout) if dropout and dropout > 0 else nn.Identity()
-        self.lora_A = nn.Linear(self.in_features, self.r, bias=False)
-        self.lora_B = nn.Linear(self.r, self.out_features, bias=False)
+        self.lora_A = nn.Linear(self.in_features, self.r, bias=False)   # down
+        self.lora_B = nn.Linear(self.r, self.out_features, bias=False)  # up
 
         nn.init.kaiming_uniform_(self.lora_A.weight, a=5**0.5)
         nn.init.zeros_(self.lora_B.weight)
@@ -43,21 +52,37 @@ def __init__(self, base_linear: nn.Linear, rank: int, alpha: Optional[float],
         self._register_state_dict_hook(self._merge_on_state_dict)
         self._register_load_state_dict_pre_hook(self._strict_fill_on_load, with_module=True)
 
+    def _apply_activation(self, y):
+        if not self._has_post_act:
+            return y
+        act = self._post_act
+        # support nn.Module or callable (e.g. torch.nn.functional.silu)
+        if isinstance(act, nn.Module):
+            return act(y)
+        if callable(act):
+            return act(y)
+        return y
+
     def forward(self, x):
-        y = self.base(x)
-        if self._has_post_act:
-            y = self._post_act(y)
+        # linear pre-activation
+        y = F.linear(x, self.base.weight, self.base.bias)
+
+        # add LoRA delta pre-activation
         if self.enable_lora and self.r > 0:
             z = self.lora_dropout(x)
             z = self.lora_A(z)
             z = self.lora_B(z)
             y = y + self.scaling * z
+
+        # optional norm then activation
+        if self._norm is not None:
+            y = self._norm(y)
+        y = self._apply_activation(y)
         return y
 
     @torch.no_grad()
     def merged_weight(self):
-        if self._has_post_act:
-            return self.base.weight
+        # Always valid since injected pre-activation
         return self.base.weight + self.scaling * (self.lora_B.weight @ self.lora_A.weight)
 
     def _merge_on_state_dict(self, module, state_dict, prefix, local_metadata):
```
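Why the per-layer skip list and the post-activation guards can be dropped: once the LoRA delta is added before the norm/activation, it folds algebraically into the base weight, so `merged_weight()` matches the adapter-enabled forward pass for every layer. A small self-contained check of that identity (plain PyTorch, independent of `LoRALinear`; the layer sizes and the LayerNorm/SiLU stand-ins are arbitrary):

```python
# Sketch: pre-activation LoRA merges exactly, even with a post-norm/activation.
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
in_f, out_f, r, scaling = 8, 6, 2, 0.5

W = torch.randn(out_f, in_f)   # base weight
b = torch.randn(out_f)         # base bias
A = torch.randn(r, in_f)       # lora_A.weight (down-projection)
B = torch.randn(out_f, r)      # lora_B.weight (up-projection)
norm, act = nn.LayerNorm(out_f), nn.SiLU()

x = torch.randn(4, in_f)

# Adapter-enabled path: LoRA delta added before norm/activation (new forward()).
y_adapter = act(norm(F.linear(x, W, b) + scaling * (x @ A.T) @ B.T))

# Merged path: fold the delta into the weight, exactly as merged_weight() does.
W_merged = W + scaling * (B @ A)
y_merged = act(norm(F.linear(x, W_merged, b)))

assert torch.allclose(y_adapter, y_merged, atol=1e-5)

# The old post-activation form, act(W x + b) + scaling * B A x, admits no such
# factorization, which is why Dense blocks with built-in activations previously
# had to be skipped.
```

This is the same factorization the `_merge_on_state_dict` hook relies on when `merge_on_save` is set.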
