-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_sweeps.py
More file actions
31 lines (23 loc) · 821 Bytes
/
run_sweeps.py
File metadata and controls
31 lines (23 loc) · 821 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import yaml
import wandb
import os.path as path
from train_reddit import wandb_sweep as reddit_sweep
from train_twitter import wandb_sweep as twitter_sweep
# from sweeptest import main
import pprint
# Number of runs the W&B agent executes for the sweep.
COUNT = 30
# Which sweep to launch; must be a key of WANDB_PROJECTS.
PROJECT = 'twitter'
# Directory containing this script; the ``<project>_sweep.yaml`` files live here.
RUNNING_DIR = path.dirname(path.realpath(__file__))

# Local sweep name -> W&B project the sweep (and its runs) is created under.
WANDB_PROJECTS = {
    'reddit': "BERT Implicitly Labeled Reddit v2",
    'twitter': "TwitterSI Classification",
}


def sweep_config_path(project, running_dir=RUNNING_DIR):
    """Return the path of the ``<project>_sweep.yaml`` file inside *running_dir*."""
    return path.join(running_dir, '{}_sweep.yaml'.format(project))


def run_sweep(project=PROJECT, count=COUNT):
    """Create a W&B sweep for *project* and run an agent for *count* runs.

    The sweep configuration is read from ``<project>_sweep.yaml`` next to this
    script, pretty-printed, registered with ``wandb.sweep`` under the project
    named in ``WANDB_PROJECTS``, and then driven by ``wandb.agent`` using the
    project's training entry point.

    Args:
        project: Key of ``WANDB_PROJECTS`` selecting which sweep to run.
        count: Number of runs the agent should execute.

    Raises:
        SystemExit: If *project* is not a known sweep. The check happens
            before the config file is opened, so an unknown project gets the
            intended message (and a non-zero exit status) rather than a
            ``FileNotFoundError`` for a YAML file that does not exist.
    """
    if project not in WANDB_PROJECTS:
        raise SystemExit('specified project is not available')

    # Resolved here rather than at module level so the mapping only needs the
    # training modules' entry points when a sweep is actually launched.
    train_fns = {'reddit': reddit_sweep, 'twitter': twitter_sweep}

    with open(sweep_config_path(project), encoding='utf-8') as f:
        # safe_load: the sweep file is plain config and needs no custom tags.
        sweep_config = yaml.safe_load(f)
    pprint.pprint(sweep_config)

    sweep_id = wandb.sweep(sweep_config, project=WANDB_PROJECTS[project])
    wandb.agent(sweep_id, train_fns[project], count=count)


if __name__ == "__main__":
    run_sweep()