-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathtrain.py
More file actions
49 lines (35 loc) · 1.04 KB
/
train.py
File metadata and controls
49 lines (35 loc) · 1.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import hydra
from omegaconf import DictConfig, OmegaConf
import os
import sys
import shutil
import pyrootutils
'''
Adapted from SPURS
https://github.com/luo-group/SPURS/blob/9cf686eb8304740775c4cfdd2437732/spurs/train.py
'''
# Make the directory that contains this script importable.
sys.path.append(os.path.abspath(os.path.dirname(__file__)))

# Locate the project root by searching from this file for one of the marker
# entries listed in `indicator`; the result is kept in `root` and is used
# below to build the Hydra config path.
root = pyrootutils.setup_root(
    indicator=[".git", "pyproject.toml"],
    search_from=__file__,
    # load environment variables from a `.env` file if it exists
    # (searched for recursively in all folders starting from the work dir)
    dotenv=True,
    # NOTE(review): presumably also puts the detected root on the import
    # path -- confirm against the pyrootutils documentation.
    pythonpath=True,
)
@hydra.main(config_path=f"{root}/configs", config_name="train.yaml")
def main(cfg: DictConfig):
    """Run ProStab training from a Hydra configuration.

    The configuration is first passed through
    ``utils.resolve_experiment_config`` and then ``utils.extras`` before
    being handed to the training pipeline.

    Args:
        cfg: Hydra configuration object composed from ``train.yaml``.

    Returns:
        Whatever ``prostab.training_pipeline.train`` returns for the
        prepared configuration.
    """
    # Project imports are deferred to call time, as in the original script
    # (presumably so they happen after the path setup at module level).
    from prostab import utils
    from prostab.training_pipeline import train

    prepared_cfg = utils.extras(utils.resolve_experiment_config(cfg))
    return train(prepared_cfg)
# Script entry point: only launch training when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()