-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathevaluate_exp.py
More file actions
90 lines (77 loc) · 2.37 KB
/
Copy pathevaluate_exp.py
File metadata and controls
90 lines (77 loc) · 2.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
from argparse import ArgumentParser
import torch
from nb_utils.clip_eval import ExpEvaluator
from nb_utils.experiments_viewer import ExpsViewer
from nb_utils.cache import Cache, DistributedCache
def parse_args():
    """Parse the command-line options of the evaluation script.

    Options are registered in the order listed in ``option_table`` and
    parsed from ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` produced by the parser, with the
        attributes ``base_path``, ``exp_names``, ``checkpoints_idxs``,
        ``cache_files_template``, ``processes`` and ``gs``.
    """
    # (flag, add_argument keyword arguments) pairs, in registration order.
    option_table = (
        ("--base_path", dict(
            type=str,
            required=True,
            help='Path to the folder with experiments',
        )),
        ('--exp_names', dict(
            type=str,
            nargs='+',
            required=True,
            help='Target experiments',
        )),
        ("--checkpoints_idxs", dict(
            type=int,
            nargs='+',
            required=True,
            help='Target checkpoint idxs',
        )),
        ("--cache_files_template", dict(
            type=str,
            default='experiments/dreambooth/*/eval.json',
            required=False,
            help='Template for existing cache files',
        )),
        ("--processes", dict(
            type=int,
            default=4,
            help='Number of parallel threads to perform evaluation',
        )),
        ('--gs', dict(
            type=float,
            default=7.5,
            help='Guidance scale used in generation',
        )),
    )

    cli = ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def main(args):
    """Evaluate the selected experiments at each requested checkpoint.

    For every checkpoint index the experiments viewer is asked for stats;
    each returned entry that carries a ``'config'`` key is then passed to a
    ``Cache`` located at ``<config.output_dir>/eval.json``.

    Args:
        args: Parsed command-line namespace. This function reads
            ``base_path``, ``exp_names``, ``checkpoints_idxs``,
            ``cache_files_template``, ``processes`` and ``gs`` from it.
    """
    # Use the GPU when torch reports one, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    print(f'Evaluating for {args.base_path}')
    print(f'Run evaluation for names: {args.exp_names} for checkpoints: {args.checkpoints_idxs}')

    clip_evaluator = ExpEvaluator(device)

    # NOTE(review): presumably gathers results already stored in files
    # matching the template -- confirm against nb_utils.cache.
    distributed = DistributedCache(args.cache_files_template)
    known_results = distributed.get()

    viewer = ExpsViewer(
        base_path=args.base_path,
        exp_filter_fn=lambda x: x in args.exp_names,
        lazy_load=False,
        evaluator=clip_evaluator
    )

    for idx in args.checkpoints_idxs:
        # Second element is the guidance scale rendered with one decimal.
        specs = ('50', f'{args.gs:.1f}')
        results = viewer.evaluate(
            exps_names=args.exp_names,
            checkpoint_idx=str(idx),
            inference_specs=specs,
            cache=known_results,
            processes=args.processes,
        )
        for name, entry in results.items():
            # Entries without a 'config' cannot be mapped to an output
            # directory, so they are skipped.
            if 'config' in entry:
                target = os.path.join(entry['config']['output_dir'], 'eval.json')
                Cache(target, indent=2).update({name: entry})
if __name__ == '__main__':
    # Script entry point: parse the command line, then run the evaluation.
    main(parse_args())