-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
113 lines (86 loc) · 2.35 KB
/
main.py
File metadata and controls
113 lines (86 loc) · 2.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from util.logger import logger
from typing import Optional, Tuple, Union, Dict, List, Set, Callable, TypeVar, Generic, NewType, Protocol
import hydra
from omegaconf import DictConfig, OmegaConf
import yaml
from util.basic_util import pause, set_global_variable_dict, get_global_variable, set_global_variable
def sample_flux_1_dev(cfg: DictConfig):
    """Run the task implemented by ``task.sample_flux_1_dev.sample_flux_1_dev``.

    The task module is imported inside the function body, so it is loaded
    only when this wrapper is actually called.

    Args:
        cfg: Experiment config; passed through unchanged to the task function.
    """
    from task.sample_flux_1_dev import sample_flux_1_dev as task_entry
    task_entry(cfg)
def sample_sd_2_1(cfg: DictConfig):
    """Run the task implemented by ``task.sample_sd_2_1.sample_sd_2_1``.

    The task module is imported inside the function body, so it is loaded
    only when this wrapper is actually called.

    Args:
        cfg: Experiment config; passed through unchanged to the task function.
    """
    from task.sample_sd_2_1 import sample_sd_2_1 as task_entry
    task_entry(cfg)
def sample_sd_turbo(cfg: DictConfig):
    """Run the task implemented by ``task.sample_sd_turbo.sample_sd_turbo``.

    The task module is imported inside the function body, so it is loaded
    only when this wrapper is actually called.

    Args:
        cfg: Experiment config; passed through unchanged to the task function.
    """
    from task.sample_sd_turbo import sample_sd_turbo as task_entry
    task_entry(cfg)
def sample_sdxl(cfg: DictConfig):
    """Run the task implemented by ``task.sample_sdxl.sample_sdxl``.

    The task module is imported inside the function body, so it is loaded
    only when this wrapper is actually called.

    Args:
        cfg: Experiment config; passed through unchanged to the task function.
    """
    from task.sample_sdxl import sample_sdxl as task_entry
    task_entry(cfg)
def sample_sdxl_turbo(cfg: DictConfig):
    """Run the task implemented by ``task.sample_sdxl_turbo.sample_sdxl_turbo``.

    The task module is imported inside the function body, so it is loaded
    only when this wrapper is actually called.

    Args:
        cfg: Experiment config; passed through unchanged to the task function.
    """
    from task.sample_sdxl_turbo import sample_sdxl_turbo as task_entry
    task_entry(cfg)
def sample_freeu(cfg: DictConfig):
    """Run the task implemented by ``task.sample_freeu.sample_freeu``.

    The task module is imported inside the function body, so it is loaded
    only when this wrapper is actually called.

    Args:
        cfg: Experiment config; passed through unchanged to the task function.
    """
    from task.sample_freeu import sample_freeu as task_entry
    task_entry(cfg)
def run_task(
    cfg: Dict
):
    """Dispatch to the sampling task selected by ``cfg["task"]["name"]``.

    Tasks are selected by name prefix. A more specific prefix must be tested
    before any shorter prefix it extends. ``"sample_sdxl_turbo"`` starts with
    ``"sample_sdxl"``, so it is checked first; otherwise the SDXL-Turbo branch
    could never be reached.

    Args:
        cfg: Resolved experiment config. ``main()`` passes the plain ``dict``
            produced by ``OmegaConf.to_container``. Only ``cfg["task"]["name"]``
            is read here, so a ``DictConfig`` works as well.

    Raises:
        NotImplementedError: If the task name matches none of the known prefixes.
    """
    task_name = cfg["task"]["name"]
    if task_name.startswith("sample_flux.1-dev"):
        sample_flux_1_dev(cfg)
    elif task_name.startswith("sample_sd_2.1"):
        sample_sd_2_1(cfg)
    elif task_name.startswith("sample_sd-turbo"):
        sample_sd_turbo(cfg)
    # Must precede the plain "sample_sdxl" check, which is a prefix of it.
    # NOTE(review): if configs spell this "sample_sdxl-turbo" (hyphen, like
    # "sample_sd-turbo" above), it still routes to sample_sdxl; confirm the
    # task naming used in the config files.
    elif task_name.startswith("sample_sdxl_turbo"):
        sample_sdxl_turbo(cfg)
    elif task_name.startswith("sample_sdxl"):
        sample_sdxl(cfg)
    elif task_name.startswith("sample_freeu"):
        sample_freeu(cfg)
    else:
        raise NotImplementedError(
            f"Unsupported task: `{task_name}`. "
        )
@hydra.main(version_base=None, config_path="config", config_name="cfg")
def main(cfg: DictConfig):
    """Hydra entry point: resolve the config, register it, and run the task.

    Hydra composes the config named ``cfg`` from the ``config`` directory and
    passes it in as ``cfg``. The config is copied with ``OmegaConf.create``.
    It is then converted to plain containers with interpolations resolved.
    The resulting dict is handed to ``set_global_variable_dict`` and then to
    ``run_task``.

    Args:
        cfg: Config object supplied by the ``hydra.main`` decorator.
    """
    resolved_cfg = OmegaConf.to_container(OmegaConf.create(cfg), resolve=True)
    set_global_variable_dict(resolved_cfg)

    experiment = get_global_variable("exp_name")
    logger(f"Start experiment `{experiment}`. ")
    run_task(resolved_cfg)
    logger(f"Experiment `{experiment}` finished. ")


if __name__ == "__main__":
    # The hydra.main decorator builds and supplies `cfg`; no arguments are passed here.
    main()