File tree Expand file tree Collapse file tree 10 files changed +37
-15
lines changed
template-text-classification
template-vision-classification
template-vision-segmentation Expand file tree Collapse file tree 10 files changed +37
-15
lines changed Original file line number Diff line number Diff line change @@ -2,9 +2,15 @@ def test_save_config():
22 with open ("./config.yaml" , "r" ) as f :
33 config = OmegaConf .load (f )
44
5- save_config (config , "./" )
5+ # Add backend to config (similar to setup_config)
6+ config .backend = None
67
7- with open ( "./config-lock.yaml" , "r" ) as f :
8- test_config = OmegaConf . load ( f )
8+ with tempfile . TemporaryDirectory ( ) as output_dir :
9+ output_dir = Path ( output_dir )
910
10- assert config == test_config
11+ save_config (config , output_dir )
12+
13+ with open (output_dir / "config-lock.yaml" , "r" ) as f :
14+ test_config = OmegaConf .load (f )
15+
16+ assert config == test_config
Original file line number Diff line number Diff line change @@ -149,14 +149,14 @@ def resume_from(
149149
150150def setup_output_dir (config : Any , rank : int ) -> Path :
151151 """Create output folder."""
152+ output_dir = config .output_dir
152153 if rank == 0 :
153154 now = datetime .now ().strftime ("%Y%m%d-%H%M%S" )
154155 name = f"{ now } -backend-{ config .backend } -lr-{ config .lr } "
155156 path = Path (config .output_dir , name )
156157 path .mkdir (parents = True , exist_ok = True )
157- config .output_dir = path .as_posix ()
158-
159- return Path (idist .broadcast (config .output_dir , src = 0 ))
158+ output_dir = path .as_posix ()
159+ return Path (idist .broadcast (output_dir , src = 0 ))
160160
161161
162162def save_config (config , output_dir ):
Original file line number Diff line number Diff line change @@ -27,9 +27,11 @@ def run(local_rank: int, config: Any):
2727 manual_seed (config .seed + rank )
2828
2929 # create output folder and copy config file to output dir
30- config . output_dir = setup_output_dir (config , rank )
30+ output_dir = setup_output_dir (config , rank )
3131 if rank == 0 :
32- save_config (config , config .output_dir )
32+ save_config (config , output_dir )
33+
34+ config .output_dir = output_dir
3335
3436 # donwload datasets and create dataloaders
3537 dataloader_train , dataloader_eval = setup_data (config )
Original file line number Diff line number Diff line change 11import os
2+ import tempfile
23from argparse import Namespace
4+ from pathlib import Path
35from typing import Iterable
46
57import ignite .distributed as idist
Original file line number Diff line number Diff line change @@ -24,9 +24,11 @@ def run(local_rank: int, config: Any):
2424 manual_seed (config .seed + rank )
2525
2626 # create output folder and copy config file to output dir
27- config . output_dir = setup_output_dir (config , rank )
27+ output_dir = setup_output_dir (config , rank )
2828 if rank == 0 :
29- save_config (config , config .output_dir )
29+ save_config (config , output_dir )
30+
31+ config .output_dir = output_dir
3032
3133 # donwload datasets and create dataloaders
3234 dataloader_train , dataloader_eval = setup_data (config )
Original file line number Diff line number Diff line change 11import os
2+ import tempfile
23from argparse import Namespace
4+ from pathlib import Path
35from typing import Iterable
46
57import ignite .distributed as idist
Original file line number Diff line number Diff line change @@ -28,9 +28,11 @@ def run(local_rank: int, config: Any):
2828 manual_seed (config .seed + rank )
2929
3030 # create output folder and copy config file to output dir
31- config . output_dir = setup_output_dir (config , rank )
31+ output_dir = setup_output_dir (config , rank )
3232 if rank == 0 :
33- save_config (config , config .output_dir )
33+ save_config (config , output_dir )
34+
35+ config .output_dir = output_dir
3436
3537 # donwload datasets and create dataloaders
3638 dataloader_train , dataloader_eval , num_channels = setup_data (config )
Original file line number Diff line number Diff line change 11import os
2+ import tempfile
23from argparse import Namespace
4+ from pathlib import Path
35from typing import Iterable
46
57import ignite .distributed as idist
Original file line number Diff line number Diff line change @@ -34,9 +34,11 @@ def run(local_rank: int, config: Any):
3434 manual_seed (config .seed + rank )
3535
3636 # create output folder and copy config file to output dir
37- config . output_dir = setup_output_dir (config , rank )
37+ output_dir = setup_output_dir (config , rank )
3838 if rank == 0 :
39- save_config (config , config .output_dir )
39+ save_config (config , output_dir )
40+
41+ config .output_dir = output_dir
4042
4143 # donwload datasets and create dataloaders
4244 dataloader_train , dataloader_eval = setup_data (config )
Original file line number Diff line number Diff line change 11import os
2+ import tempfile
23from argparse import Namespace
4+ from pathlib import Path
35
46import pytest
57from data import setup_data
You can’t perform that action at this time.
0 commit comments