-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
41 lines (30 loc) · 1020 Bytes
/
utils.py
File metadata and controls
41 lines (30 loc) · 1020 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import traceback
import torch
# Floating-point type used when this module builds tensors (see unit_bounds);
# torch.double is 64-bit floating point.
dtype = torch.double
# Use CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Top-level statement: the selected device is printed whenever this module is loaded.
print(f"Using device: {device}")
def unit_bounds(d):
    """Return the bounds of the ``d``-dimensional unit box as a tensor.

    Args:
        d: Number of dimensions (passed to ``range``).

    Returns:
        A tensor with two rows of ``d`` entries each: the first row is all
        zeros and the second row is all ones. It is created with the
        module-level ``dtype`` and ``device``.
    """
    lower = [0 for _ in range(d)]
    upper = [1 for _ in range(d)]
    return torch.tensor([lower, upper], dtype=dtype, device=device)
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
    """Print a warning followed by the current call stack.

    The parameter list matches that of ``warnings.showwarning`` (presumably so
    this function can be installed as a replacement hook -- confirm against
    the caller). ``filename``, ``lineno``, ``file`` and ``line`` are accepted
    but not used.

    Args:
        message: The warning message; formatted with ``str()``.
        category: The warning class; its ``__name__`` is printed.
        filename: Unused.
        lineno: Unused.
        file: Unused.
        line: Unused.
    """
    blank = "\n"
    print(blank)
    kind = category.__name__
    print(f"{kind}: \n {message} \n")
    # print_stack() writes to sys.stderr by default, unlike the print() calls.
    traceback.print_stack()
    print(blank)
def gpu_warmup(iterations=100, size=5000):
    """Run throwaway matrix multiplications on the GPU before real work starts.

    Args:
        iterations: Number of random matrices to generate and multiply.
            Defaults to 100.
        size: Side length of each square random matrix. Defaults to 5000.

    Returns:
        None. When CUDA is not available a notice is printed and no work is
        done.
    """
    if not torch.cuda.is_available():
        # Nothing to warm up on a CPU-only machine; return instead of letting
        # the CUDA allocation below raise.
        print("CUDA is not available; skipping GPU warm-up.")
        return

    # Local name chosen so it does not shadow the module-level ``device``.
    cuda_device = torch.device("cuda")
    print(f"Warming up GPU: {cuda_device}")
    for _ in range(iterations):
        x = torch.randn(size, size, device=cuda_device)
        # The product is discarded on purpose; only the GPU work matters.
        torch.matmul(x, x)
    # CUDA work is queued asynchronously, so block until all of it has finished
    # before reporting that the GPU is ready.
    torch.cuda.synchronize()
    print("GPU warmed up and ready.")
def str2bool(v):
    """Convert a textual flag such as ``"yes"`` or ``"0"`` to a bool.

    Args:
        v: A bool (returned unchanged) or a value whose string form is one of
            ``yes/true/t/1`` or ``no/false/f/0``. Matching ignores case and
            surrounding whitespace.

    Returns:
        True or False.

    Raises:
        ValueError: If ``v`` is not a recognised boolean spelling. Previously
            such input silently produced ``None``, which is falsy and hid
            typos. ``argparse`` reports a ``ValueError`` raised by a ``type=``
            converter as a normal usage error.
    """
    if isinstance(v, bool):
        return v
    # Normalise once instead of lower-casing for every comparison.
    text = str(v).strip().lower()
    if text in ('yes', 'true', 't', '1'):
        return True
    if text in ('no', 'false', 'f', '0'):
        return False
    raise ValueError(f"Boolean value expected, got {v!r}")