-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathload_model_for_interactive.py
More file actions
64 lines (51 loc) · 2.33 KB
/
load_model_for_interactive.py
File metadata and controls
64 lines (51 loc) · 2.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#!/usr/bin/env python3
from argparse import Namespace
import torch
from classifiers.FFBase import FFBase
from common import build_model, EXTRACTORS, CLASSIFIERS
from torchaudio import load
def load_model_for_interactive():
    """Build the XLSR_300M + MHFA + FFAttn1 model for interactive use.

    Returns:
        FFBase: the constructed model switched to eval() mode.

    NOTE(review): checkpoint loading is currently disabled, so the returned
    model has freshly initialized (untrained) weights. Re-enable a
    ``load_state_dict`` call such as::

        model_mhfa.load_state_dict(
            torch.load("FFConcat3_MHFA_finetune_7.pt",
                       map_location=torch.device("cpu"), weights_only=True))

    before relying on its predictions.
    """
    args = Namespace()
    args.extractor = "XLSR_300M"
    args.classifier = "FFAttn1"
    args.processor = "MHFA"
    model_mhfa, _ = build_model(args)
    # build_model can construct several classifier families; interactive use
    # assumes the feed-forward base interface.
    assert isinstance(model_mhfa, FFBase)
    return model_mhfa.eval()
def model_params():
    """Count model parameters for every extractor/pooling combination.

    Builds an ``FF`` classifier for each extractor in ``EXTRACTORS`` crossed
    with the MHFA/AASIST/SLS pooling processors and records the total number
    of parameter elements of each resulting model.

    Returns:
        dict[str, int]: mapping ``"<extractor>_<pooling>_scores.txt"`` to the
        model's total parameter count. (The dict is also printed for
        convenience; previously this function returned ``None``, so returning
        the dict is backward-compatible.)
    """
    params = {}
    args = Namespace()
    for extractor in EXTRACTORS:
        for pooling in ("MHFA", "AASIST", "SLS"):
            args.extractor = extractor
            args.processor = pooling
            args.classifier = "FF"
            model, _ = build_model(args)
            assert isinstance(model, FFBase)
            # Total element count across all parameter tensors.
            num_params = sum(p.numel() for p in model.parameters())
            # Key matches the naming scheme of the per-model score files.
            params[f"{extractor}_{pooling}_scores.txt"] = num_params
    print(params)
    return params
if __name__ == "__main__":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
f_wf, sr1 = load("fake.flac")
r_wf, sr2 = load("fake.flac")
# wf, sr = load("babis-zeman.mp3")
model = load_model_for_interactive()
print(model(torch.vstack([r_wf, r_wf]).to(device), torch.vstack([f_wf, f_wf]).to(device)))
# model_params()