-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathutils.py
More file actions
74 lines (68 loc) · 2.46 KB
/
utils.py
File metadata and controls
74 lines (68 loc) · 2.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
import re
import random
import numpy as np
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
import torchvision
from torchvision import transforms
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score
def set_seed(s):
    """Seed every RNG in play (Python, NumPy, PyTorch) for reproducible runs."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(s)
    # CUDA keeps a generator per device; seed all of them when a GPU exists.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(s)
def get_best_model_paths(save_dir: str, n_splits: int) -> List[str]:
    """Return, for each fold 0..n_splits-1, the checkpoint path with the best accuracy.

    Checkpoints in ``save_dir`` are expected to be named
    ``fold{i}_epoch{e}_acc{a}.pth``; files not matching are ignored.

    Raises:
        FileNotFoundError: if some fold has no matching checkpoint at all.
    """
    name_re = re.compile(r'fold(\d+)_epoch(\d+)_acc([0-9.]+)\.pth')
    best = {}  # fold index -> (accuracy, absolute checkpoint path)
    for entry in os.listdir(save_dir):
        match = name_re.match(entry)
        if match is None:
            continue
        fold_no = int(match.group(1))
        accuracy = float(match.group(3))
        if fold_no not in best or accuracy > best[fold_no][0]:
            best[fold_no] = (accuracy, os.path.join(save_dir, entry))
    # Fail loudly on the first fold that never produced a checkpoint.
    for fold_no in range(n_splits):
        if fold_no not in best:
            raise FileNotFoundError(f"No saved model for fold {fold_no} in {save_dir}")
    return [best[fold_no][1] for fold_no in range(n_splits)]
# helper to get dataset with certain transform (train=True)
# This function can be adapted freely: for a custom dataset, write your own
# Dataset class and read it with the built-in DataLoader (transform is important).
# NOTE: download=False — assumes the CIFAR-10 data already exists under `root`.
def cifar10_with_transform(transform, root):
    return torchvision.datasets.CIFAR10(root, train=True, download=False, transform=transform)
def train_one_epoch(model, dl, optimizer, device):
    """Run one optimization pass over `dl`; return the sample-weighted mean loss."""
    model.train()
    seen = 0
    running_loss = 0.0
    for inputs, targets in dl:
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        loss = F.cross_entropy(model(inputs), targets)
        loss.backward()
        optimizer.step()
        batch = inputs.size(0)
        seen += batch
        # Weight each batch by its size so the epoch mean is per-sample.
        running_loss += loss.item() * batch
    return running_loss / seen
def eval_on_dl(model, dl, device):
    """Evaluate `model` over `dl`; return (mean loss, accuracy), sample-weighted."""
    model.eval()
    seen, hits, running_loss = 0, 0, 0.0
    with torch.no_grad():  # inference only — no graph needed
        for inputs, targets in dl:
            inputs, targets = inputs.to(device), targets.to(device)
            logits = model(inputs)
            running_loss += F.cross_entropy(logits, targets).item() * inputs.size(0)
            hits += (logits.argmax(dim=1) == targets).sum().item()
            seen += inputs.size(0)
    return running_loss / seen, hits / seen