-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathutil.py
More file actions
executable file
·33 lines (24 loc) · 762 Bytes
/
util.py
File metadata and controls
executable file
·33 lines (24 loc) · 762 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import os, time, gc, json, pickle, argparse, math
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.distributed as dist
import torch.multiprocessing as mp
import numpy as np
from data.util import *
def num_params(model):
    """Count the trainable parameters of *model*.

    Only parameters with ``requires_grad=True`` are counted.

    Args:
        model: a ``torch.nn.Module`` (anything exposing ``.parameters()``).

    Returns:
        Total number of scalar elements across trainable parameters, as a
        Python ``int`` (0 for a module with no trainable parameters).
    """
    # ``Tensor.numel()`` returns an int even for 0-d tensors, whereas
    # ``np.prod(())`` returns the float 1.0 and would make the total a float.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
def switch_schedule(schedule, mult, switch):
    """Wrap *schedule* so its value is scaled by *mult* before step *switch*.

    Args:
        schedule: callable mapping a step index ``e`` to a value
            (e.g. a learning-rate multiplier).
        mult: factor applied to ``schedule(e)`` while ``e < switch``.
        switch: first step index at which the unscaled value is returned.

    Returns:
        A callable ``(e) -> schedule(e) * mult`` for ``e < switch``, and
        ``schedule(e)`` otherwise.
    """
    def scaled(e):
        value = schedule(e)
        return value * mult if e < switch else value
    return scaled
def linear_schedule(args):
    """Build a linear-warmup / linear-decay multiplier schedule.

    For non-negative step indices the returned ``f(e)``:
      * rises linearly from 0 to 1 over the first ``args.warmup`` steps,
      * then falls linearly from 1 (at ``e == args.warmup``) to 0 (at
        ``e == args.iterations``),
      * and stays clamped at 0 afterwards.

    Args:
        args: object exposing numeric ``warmup`` and ``iterations``
            attributes (presumably an ``argparse.Namespace`` -- confirm at
            call site). They are read on every call, not captured once.

    Returns:
        Callable mapping a step index ``e`` to a multiplier.
    """
    def f(e):
        # Skip the warmup ramp entirely when warmup is 0; otherwise e=0
        # would evaluate 0 / 0.
        if args.warmup > 0 and e <= args.warmup:
            return e / args.warmup
        span = args.iterations - args.warmup
        if span <= 0:
            # No decay interval left (warmup >= iterations): avoid a zero
            # division and treat the schedule as finished.
            return 0.0
        # Equivalent to (e - iterations) / (warmup - iterations), clamped at 0.
        return max((args.iterations - e) / span, 0)
    return f