forked from 99ffx/CLAM_Dress
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: zoom_classifier.py
More file actions
42 lines (34 loc) · 1.24 KB
/
zoom_classifier.py
File metadata and controls
42 lines (34 loc) · 1.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
import torch.nn as nn
from models.model_clam import CLAM_SB
class ZoomFusionClassifier(nn.Module):
    """Fuse patch features from two magnifications (10x and 20x) and classify.

    Per-patch features from the two zoom levels are combined element-wise
    ('avg' or 'sum'), global-average-pooled over the patch dimension, and
    passed through a single linear head.

    Args:
        feature_dim: dimensionality of each patch feature vector.
        n_classes: number of output classes.
        fusion: 'avg' or 'sum' — how the two zoom levels are combined.

    Raises:
        ValueError: if ``fusion`` is not one of the supported modes.
    """

    VALID_FUSIONS = ('avg', 'sum')

    def __init__(self, feature_dim=1536, n_classes=2, fusion='avg'):
        super().__init__()
        # Fail fast at construction instead of on the first forward() call.
        if fusion not in self.VALID_FUSIONS:
            raise ValueError(
                f"Invalid fusion method {fusion!r}. Choose 'avg', 'sum'")
        self.fusion = fusion
        self.classifier = nn.Linear(feature_dim, n_classes)

    def forward(self, feats_10x, feats_20x):
        """Classify a bag of dual-zoom features.

        Args:
            feats_10x: tensor of 10x patch features; assumed shape
                [n_patches, feature_dim] (patch axis is dim 0) — the two
                inputs must be broadcast-compatible element-wise.
            feats_20x: tensor of 20x patch features, same shape.

        Returns:
            (probs, logits): both of shape [1, n_classes]; ``probs`` is the
            softmax of ``logits``.
        """
        if self.fusion == 'avg':
            fused = (feats_10x + feats_20x) / 2
        elif self.fusion == 'sum':
            fused = feats_10x + feats_20x
        else:
            # Defensive: unreachable unless self.fusion was mutated after init.
            raise ValueError(
                "Invalid fusion method. Choose 'avg', 'sum'")
        pooled = fused.mean(dim=0, keepdim=True)  # [N, D] -> [1, D] global average pooling
        logits = self.classifier(pooled)          # Shape [1, n_classes]
        probs = torch.softmax(logits, dim=1)
        return probs, logits
# --- Build a CLAM_SB model, load a pretrained checkpoint, and print it ---
model = CLAM_SB(
    gate=True,
    size_arg="small",
    dropout=0.25,
    k_sample=50,
    n_classes=2,
    subtyping=False,
    embed_dim=1536,
)

checkpoint_path = "checkpoints/uni/s_0_checkpoint_Gigapath.pt"
checkpoint = torch.load(checkpoint_path, map_location="cpu")

# Some training scripts save the weights nested under a 'model' key;
# unwrap it when present, otherwise treat the object as the state dict itself.
state_dict = checkpoint.get("model", checkpoint) if isinstance(checkpoint, dict) else checkpoint

# strict=False tolerates head/key mismatches, but report what did not match
# instead of silently ignoring it.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if missing:
    print(f"Warning: keys missing from checkpoint: {missing}")
if unexpected:
    print(f"Warning: unexpected keys ignored: {unexpected}")

# Step 3: View architecture
print(model)