-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcpap_sample.py
More file actions
98 lines (83 loc) · 4.6 KB
/
cpap_sample.py
File metadata and controls
98 lines (83 loc) · 4.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import torch
import numpy as np
import torch.nn as nn
import os
import click
import glob
from tqdm import tqdm
def weight_init(shape, mode, fan_in, fan_out):
    """Draw a randomly initialised tensor of the given shape.

    Args:
        shape: Sequence of ints, unpacked into ``torch.rand`` / ``torch.randn``.
        mode: One of ``'xavier_uniform'``, ``'xavier_normal'``,
            ``'kaiming_uniform'`` or ``'kaiming_normal'``.
        fan_in: Number of input units (used by every mode).
        fan_out: Number of output units (used only by the xavier modes).

    Returns:
        A tensor of shape ``shape``.
        * Uniform modes sample U(-b, b), with b = sqrt(6 / (fan_in + fan_out))
          (xavier) or b = sqrt(3 / fan_in) (kaiming).
        * Normal modes sample N(0, s^2), with s = sqrt(2 / (fan_in + fan_out))
          (xavier) or s = sqrt(1 / fan_in) (kaiming).

    Raises:
        ValueError: If ``mode`` is not one of the four names above.
    """
    if mode == 'xavier_uniform' or mode == 'kaiming_uniform':
        numerator, denominator = (6, fan_in + fan_out) if mode == 'xavier_uniform' else (3, fan_in)
        # torch.rand is U[0, 1); "* 2 - 1" maps it onto [-1, 1) before scaling.
        return np.sqrt(numerator / denominator) * (torch.rand(*shape) * 2 - 1)
    if mode == 'xavier_normal' or mode == 'kaiming_normal':
        numerator, denominator = (2, fan_in + fan_out) if mode == 'xavier_normal' else (1, fan_in)
        return np.sqrt(numerator / denominator) * torch.randn(*shape)
    raise ValueError(f'Invalid init mode "{mode}"')
class Linear(torch.nn.Module):
    """Fully connected layer computing ``x @ W^T + b`` with configurable init.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        bias: If False, no bias parameter is created and ``self.bias`` is None.
        init_mode: Initialisation scheme forwarded to ``weight_init``.
        init_weight: Multiplier applied to the initial weight tensor.
        init_bias: Multiplier applied to the initial bias tensor; the default
            of 0 yields an all-zero bias.
    """

    def __init__(self, in_features, out_features, bias=True, init_mode='kaiming_normal', init_weight=1, init_bias=0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        init_kwargs = dict(mode=init_mode, fan_in=in_features, fan_out=out_features)
        # Weight is stored as (out_features, in_features), as in torch.nn.Linear.
        self.weight = torch.nn.Parameter(weight_init([out_features, in_features], **init_kwargs) * init_weight)
        # The bias is drawn with the same scheme, then scaled by init_bias.
        self.bias = torch.nn.Parameter(weight_init([out_features], **init_kwargs) * init_bias) if bias else None

    def forward(self, x):
        """Apply the affine map to the last dimension of ``x``."""
        # Cast parameters to the input's dtype so the matmul operands agree.
        x = x @ self.weight.to(x.dtype).t()
        if self.bias is not None:
            # In-place add is safe: x is the fresh matmul output, not the caller's tensor.
            x = x.add_(self.bias.to(x.dtype))
        return x

    def quan_dequan(self, input, cal=False, scale=1 / 32.):
        """Fake-quantise ``input`` through symmetric per-tensor int8.

        Args:
            input: Tensor to quantise; it is converted to float32 first.
            cal: When truthy, quantisation is skipped and ``input`` is returned
                unchanged. Previously this path fell through and returned None.
            scale: Quantisation step. Defaults to 1/32, the value previously
                hard-coded.

        Returns:
            A float32 tensor whose values are rounded to multiples of ``scale``
            and clamped to the qint8 range [-128 * scale, 127 * scale]. When
            ``cal`` is truthy, ``input`` itself is returned.
        """
        if cal:
            return input
        input_q = torch.quantize_per_tensor(input.float(), scale, 0, dtype=torch.qint8)
        return input_q.dequantize()
class ProteinEncoder(torch.nn.Module):
    """Map per-position label features of width ``label_dim`` to ``noise_channels``.

    The forward pass applies two stages in order:

    * ``map_label_0``: three ``nn.TransformerEncoderLayer`` blocks
      (``nhead=8``, ``dim_feedforward=768``) at width ``label_dim``.
    * ``map_label_1``: an MLP ``label_dim -> 2*label_dim -> 2*label_dim ->
      label_dim -> noise_channels``. It has tanh between consecutive linear
      layers and no activation after the last one. Its linear layers use
      xavier-uniform initialisation.

    Args:
        label_dim: Input feature width. ``nn.MultiheadAttention`` requires it
            to be divisible by the head count (8).
        noise_channels: Output feature width.

    NOTE(review): ``nn.TransformerEncoderLayer`` defaults to
    ``batch_first=False``, i.e. it treats dim 0 as the sequence axis.
    ``main()`` feeds ``(1, L, D)`` tensors built with ``unsqueeze(dim=0)``.
    If that layout is (batch, seq, feature), attention runs across dim 1 as
    the batch, not the sequence. Confirm this matches training before
    changing it, because existing checkpoints were produced with the current
    behaviour.
    """

    def __init__(self, label_dim, noise_channels):
        super().__init__()
        # Layers are constructed in the same order as before, so parameter
        # initialisation consumes the RNG identically and state_dict keys
        # (map_label_0.0..2, map_label_1.0..6) are unchanged.
        self.map_label_0 = nn.Sequential(
            *[nn.TransformerEncoderLayer(d_model=label_dim, nhead=8, dim_feedforward=768) for _ in range(3)]
        )

        hidden = label_dim * 2
        widths = [label_dim, hidden, hidden, label_dim, noise_channels]
        mlp = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            if mlp:
                # Tanh only *between* linear layers, never after the last one.
                mlp.append(nn.Tanh())
            mlp.append(Linear(in_features=fan_in, out_features=fan_out, init_mode='xavier_uniform'))
        self.map_label_1 = nn.Sequential(*mlp)

    def forward(self, x):
        """Run the transformer stack, then the MLP head, on ``x``."""
        encoded = self.map_label_0(x)
        return self.map_label_1(encoded)
def _pad_or_truncate(label, target_len):
    """Return ``label`` cut or zero-padded along axis 0 to exactly ``target_len`` rows.

    Padding rows are zeros with the same trailing shape and dtype as
    ``label``. The feature width is therefore taken from the data instead of
    being hard-coded (it was previously fixed at 512).
    """
    label = label[:target_len]
    missing = target_len - label.shape[0]
    if missing > 0:
        padding = np.zeros((missing,) + label.shape[1:], dtype=label.dtype)
        label = np.concatenate([label, padding], axis=0)
    return label


@click.command()
@click.option('--input_dir', help='Input directory containing .npy files', default='datasets/MASSA_labels', show_default=True)
@click.option('--output_dir', help='Output directory for clustered labels', default='datasets/CPAP_MASSA_labels_clustered', show_default=True)
@click.option('--ckpt', help='Path to the checkpoint file', default='training-runs/cpap_cluster_train_0605/training-state-010003.pt', show_default=True)
@click.option('--device', help='Device to use', default='cuda' if torch.cuda.is_available() else 'cpu', show_default=True)
@click.option('--max_len', help='Number of rows each label is zero-padded or truncated to', type=int, default=908, show_default=True)
def main(input_dir, output_dir, ckpt, device, max_len=908):
    """Encode every .npy label in INPUT_DIR with the checkpoint's protein encoder.

    Each label is divided by 10 and padded/truncated to ``max_len`` rows.
    It is then passed through ``ckp['ema'].encode_protein`` as a batch of one.
    The output is L2-normalised along its last axis and saved under the same
    file name in OUTPUT_DIR.
    """
    device = torch.device(device)
    print(f"Loading checkpoint from {ckpt}...")
    # The checkpoint stores whole pickled modules (ckp['ema'] is used as an
    # nn.Module), so weights_only=False is required on PyTorch >= 2.6, where
    # the default became True. This unpickles arbitrary objects: only load
    # trusted checkpoints.
    ckp = torch.load(ckpt, map_location=torch.device('cpu'), weights_only=False)
    net = ckp['ema'].encode_protein.eval().to(device)

    if not os.path.isdir(output_dir):
        os.makedirs(output_dir, exist_ok=True)
        print(f"Created output directory: {output_dir}")

    # Sorted so files are processed in a reproducible order.
    files = sorted(glob.glob(os.path.join(input_dir, "*.npy")))
    print(f"Found {len(files)} .npy files in {input_dir}")

    with torch.no_grad():
        for file_path in tqdm(files, desc="Processing"):
            file_name = os.path.basename(file_path)
            # NOTE(review): the /10.0 rescale presumably mirrors training preprocessing — confirm.
            label = np.load(file_path) / 10.0
            label_padded = _pad_or_truncate(label, max_len)
            class_label = torch.tensor(label_padded).float().unsqueeze(dim=0).to(device)
            result = net(class_label)
            # Unit-normalise along the last axis; the clamp avoids 0/0 -> NaN
            # for an all-zero output row.
            result = result / result.norm(dim=-1, keepdim=True).clamp_min(1e-12)
            np.save(os.path.join(output_dir, file_name), result.cpu().numpy())


if __name__ == "__main__":
    main()