-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathtrain.py
More file actions
executable file
·120 lines (107 loc) · 4.93 KB
/
train.py
File metadata and controls
executable file
·120 lines (107 loc) · 4.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch
from torch import nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torch.amp import GradScaler, autocast
import numpy as np
from tqdm import tqdm
from typing import Dict, Tuple, Union
from copy import deepcopy
from utils import barrier, reduce_mean, update_loss_info
from evaluate import evaluate
def train(
model: nn.Module,
data_loader: DataLoader,
loss_fn: nn.Module,
optimizer: Optimizer,
grad_scaler: Union[GradScaler, None],
device: torch.device = torch.device("cuda"),
rank: int = 0,
nprocs: int = 1,
**kwargs,
) -> Tuple[nn.Module, Optimizer, GradScaler, Dict[str, float]]:
info = None
data_iter = tqdm(data_loader) if rank == 0 else data_loader
ddp = nprocs > 1
if "eval_data_loader" in kwargs: # we are evaluting the model withing one training epoch
assert "eval_freq" in kwargs and 0 < kwargs["eval_freq"] < 1, f"eval_freq should be a float between 0 and 1, but got {kwargs['eval_freq']}"
assert "sliding_window" in kwargs, "sliding_window should be provided in kwargs"
assert "max_input_size" in kwargs, "max_input_size should be provided in kwargs"
assert "window_size" in kwargs, "window_size should be provided in kwargs"
assert "stride" in kwargs, "stride should be provided in kwargs"
assert "max_num_windows" in kwargs, "max_num_windows should be provided in kwargs"
eval_within_epoch = True
eval_data_loader = kwargs["eval_data_loader"]
eval_freq = int(kwargs["eval_freq"] * len(data_loader))
sliding_window = kwargs["sliding_window"]
max_input_size = kwargs["max_input_size"]
window_size = kwargs["window_size"]
stride = kwargs["stride"]
max_num_windows = kwargs["max_num_windows"]
best_scores = {}
best_weights = {}
else:
eval_within_epoch = False
best_scores = None
best_weights = None
for batch_idx, (image, gt_points, gt_den_map) in enumerate(data_iter):
image = image.to(device)
gt_points = [p.to(device) for p in gt_points]
gt_den_map = gt_den_map.to(device)
model.train()
with torch.set_grad_enabled(True):
with autocast(device_type="cuda", enabled=grad_scaler is not None and grad_scaler.is_enabled()):
if (model.module.zero_inflated if ddp else model.zero_inflated):
pred_logit_pi_map, pred_logit_map, pred_lambda_map, pred_den_map = model(image)
total_loss, total_loss_info = loss_fn(
pred_logit_pi_map=pred_logit_pi_map,
pred_logit_map=pred_logit_map,
pred_lambda_map=pred_lambda_map,
pred_den_map=pred_den_map,
gt_den_map=gt_den_map,
gt_points=gt_points,
)
else:
pred_logit_map, pred_den_map = model(image)
total_loss, total_loss_info = loss_fn(
pred_logit_map=pred_logit_map,
pred_den_map=pred_den_map,
gt_den_map=gt_den_map,
gt_points=gt_points,
)
optimizer.zero_grad()
if grad_scaler is not None:
grad_scaler.scale(total_loss).backward()
grad_scaler.step(optimizer)
grad_scaler.update()
else:
total_loss.backward()
optimizer.step()
total_loss_info = {k: reduce_mean(v.detach(), nprocs).item() if ddp else v.detach().item() for k, v in total_loss_info.items()}
info = update_loss_info(info, total_loss_info)
barrier(ddp)
if eval_within_epoch and ((batch_idx + 1) % eval_freq == 0 or batch_idx == len(data_loader) - 1):
batch_scores = evaluate(
model=model,
data_loader=eval_data_loader,
sliding_window=sliding_window,
max_input_size=max_input_size,
window_size=window_size,
stride=stride,
max_num_windows=max_num_windows,
device=device,
amp=grad_scaler is not None and grad_scaler.is_enabled(),
local_rank=rank,
nprocs=nprocs,
progress_bar=False,
)
for k, v in batch_scores.items():
if k not in best_scores:
best_scores[k] = v
best_weights[k] = deepcopy(model.module.state_dict() if ddp else model.state_dict())
elif v < best_scores[k]: # smaller is better
best_scores[k] = v
best_weights[k] = deepcopy(model.module.state_dict() if ddp else model.state_dict())
barrier(ddp)
torch.cuda.empty_cache()
return model, optimizer, grad_scaler, {k: np.mean(v) for k, v in info.items()}, best_scores, best_weights