-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
48 lines (34 loc) · 1.15 KB
/
train.py
File metadata and controls
48 lines (34 loc) · 1.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# %%
# for data preprocessing
# import src.config.config
# from src.datapreprocessing import DataPreProcessor
# for training
from src.config.config import ConfigExperiment, get_torch_dtype, setup_experiment
from src.data.dataloader import DataloaderStableDiffusionWithCustomNetwork
from src.model.customSD import StableDiffusionWithCustomNetwork
from src.model.customSDtrainer import TrainerStableDiffusionWithCustomNetwork
from omegaconf import OmegaConf
# %%
def main(config_path: str = "./default.yaml") -> None:
    """Load the experiment config, build dataloader/model/trainer, and train.

    Args:
        config_path: Path to the OmegaConf YAML experiment config. Defaults to
            ``"./default.yaml"``, which is resolved relative to the current
            working directory (not this script's directory).
    """
    # Run data preprocessing separately first if your dataset needs it
    # (see the commented-out DataPreProcessor import at the top of the file).

    # --- Config ---
    # OmegaConf.load returns a DictConfig. Keep it under its own name so the
    # ConfigExperiment annotation is declared only once, on the value that
    # setup_experiment returns.
    raw_cfg = OmegaConf.load(config_path)
    # NOTE(review): ConfigExperiment is a type hint only; confirm that
    # setup_experiment actually returns an object matching that schema.
    cfg: ConfigExperiment = setup_experiment(raw_cfg)

    # --- Data ---
    dataset = DataloaderStableDiffusionWithCustomNetwork(
        path=cfg.data.path_dataset,
        dtype=get_torch_dtype(cfg.train.dtype),
    )
    dataloader = dataset.get_dataloader(num_batch=cfg.train.num_batch)

    # --- Model ---
    models = StableDiffusionWithCustomNetwork(cfg.model_custom)

    # --- Trainer ---
    trainer = TrainerStableDiffusionWithCustomNetwork(
        dataloader=dataloader,
        models=models,
        cfg=cfg,
    )
    # Print the trainer's repr so the run log records what was built.
    print(trainer)
    trainer.train()
# %%
if __name__ == "__main__":
    # Start training only when this file is executed (as a script or as an
    # interactive "# %%" cell), never as a side effect of importing it.
    main()
# %%