From 608fc8d742da5de95d9cdb5d9baa6511fe0ce998 Mon Sep 17 00:00:00 2001 From: OleehyO Date: Wed, 30 Apr 2025 07:54:54 +0000 Subject: [PATCH 01/19] [feat] Add DistributedLogger --- src/cogkit/finetune/logger.py | 140 ++++++++++++++++++++++++++++++++++ 1 file changed, 140 insertions(+) create mode 100644 src/cogkit/finetune/logger.py diff --git a/src/cogkit/finetune/logger.py b/src/cogkit/finetune/logger.py new file mode 100644 index 0000000..9b1c655 --- /dev/null +++ b/src/cogkit/finetune/logger.py @@ -0,0 +1,140 @@ +import logging +import sys +import os +import tempfile +import torch.distributed as dist +import inspect +from pathlib import Path +from filelock import FileLock + + +class ColoredFormatter(logging.Formatter): + COLORS = { + logging.DEBUG: "\033[36m", + logging.INFO: "\033[32m", + logging.WARNING: "\033[33m", + logging.ERROR: "\033[31m", + logging.CRITICAL: "\033[31;1m", + } + RESET = "\033[0m" + GRAY = "\033[97m" + + def format(self, record): + level_color = self.COLORS.get(record.levelno, self.RESET) + + original_levelname = record.levelname + timestamp_str = self.formatTime(record, self.datefmt) # Get the exact timestamp string + + formatted_message = super().format(record) + + colored_timestamp = f"{self.GRAY}{timestamp_str}{self.RESET}" + formatted_message = formatted_message.replace(timestamp_str, colored_timestamp, 1) + + colored_levelname = f"{level_color}{original_levelname}{self.RESET}" + formatted_message = formatted_message.replace(original_levelname, colored_levelname, 1) + + return formatted_message + + +class DistributedLogger: + def __init__(self, name=None, log_file=None, level=logging.INFO): + if not dist.is_initialized(): + raise RuntimeError("Distributed environment is not setup") + + self.rank = dist.get_rank() + self.logger = logging.getLogger(name) + self.logger.setLevel(level) + self.logger.propagate = False + + base_fmt = f"[rank {self.rank}] | %(asctime)s | %(name)s | %(levelname)s | %(message)s" + date_fmt = "%Y-%m-%d %H:%M:%S" + + if self.is_main_process() and log_file is not None: + log_file = Path(log_file) + if log_file.exists(): + log_file.write_text("") + else: + log_file.touch(exist_ok=True) + + fd, flpath = tempfile.mkstemp() + os.close(fd) # Close file descriptor as we don't need it + self.lock = FileLock(flpath) + self.flpath = flpath + + if not self.logger.handlers: + console_handler = logging.StreamHandler(sys.stdout) + console_formatter = ColoredFormatter(base_fmt, date_fmt) + console_handler.setFormatter(console_formatter) + self.logger.addHandler(console_handler) + + file_handler = logging.FileHandler(log_file) + file_formatter = logging.Formatter(base_fmt, date_fmt) + file_handler.setFormatter(file_formatter) + self.logger.addHandler(file_handler) + + dist.barrier() + + def __del__(self): + if self.is_main_process(): + Path(self.flpath).unlink() + + self.info("Logger destroyed on all processes...") + dist.barrier() + self.info("Logger destroyed on all processes... 
done") + + def is_main_process(self): + return self.rank == 0 + + def log(self, level, msg, main_only=False, *args, **kwargs) -> None: + # with self.lock: + if not main_only: + self.logger.log(level, msg, *args, **kwargs) + elif main_only and self.is_main_process(): + self.logger.log(level, msg, *args, **kwargs) + + def debug(self, msg, main_only=False, *args, **kwargs) -> None: + self.log(logging.DEBUG, msg, main_only, *args, **kwargs) + + def info(self, msg, main_only=False, *args, **kwargs) -> None: + self.log(logging.INFO, msg, main_only, *args, **kwargs) + + def warning(self, msg, main_only=False, *args, **kwargs) -> None: + self.log(logging.WARNING, msg, main_only, *args, **kwargs) + + def error(self, msg, main_only=False, *args, **kwargs) -> None: + self.log(logging.ERROR, msg, main_only, *args, **kwargs) + + def critical(self, msg, main_only=False, *args, **kwargs) -> None: + self.log(logging.CRITICAL, msg, main_only, *args, **kwargs) + + +def get_logger(name=None, log_file=None, level=logging.INFO) -> DistributedLogger: + if name is None: + frame = inspect.currentframe().f_back + module_name = frame.f_globals["__name__"] + name_parts = module_name.split(".") + if len(name_parts) > 2: + name = ".".join(name_parts[-2:]) + else: + name = module_name + return DistributedLogger(name, log_file, level) + + +if __name__ == "__main__": + dist.init_process_group(backend="nccl") + + logger = get_logger(name="testfile", log_file="test.log") + + logger.debug("Debug message") + logger.info("Info message") + logger.warning("Warning message") + logger.error("Error message") + logger.critical("Critical message") + + logger.debug("Debug message", main_only=True) + logger.info("Info message", main_only=True) + logger.warning("Warning message", main_only=True) + logger.error("Error message", main_only=True) + logger.critical("Critical message", main_only=True) + + dist.destroy_process_group() From 09d44da761cd40005631012c1e7c347fef6e7e6f Mon Sep 17 00:00:00 2001 From: OleehyO Date: Wed, 30 Apr 2025 08:00:44 +0000 Subject: [PATCH 02/19] [refactor] Move sampler and datasets into finetune --- src/cogkit/{ => finetune}/datasets/__init__.py | 0 src/cogkit/{ => finetune}/datasets/i2v_dataset.py | 0 src/cogkit/{ => finetune}/datasets/t2i_dataset.py | 0 src/cogkit/{ => finetune}/datasets/t2v_dataset.py | 0 src/cogkit/{ => finetune}/datasets/utils.py | 0 src/cogkit/finetune/samplers/__init__.py | 3 +++ src/cogkit/{ => finetune}/samplers/packing_sampler.py | 0 src/cogkit/samplers/__init__.py | 3 --- tests/test_sampler.py | 2 +- 9 files changed, 4 insertions(+), 4 deletions(-) rename src/cogkit/{ => finetune}/datasets/__init__.py (100%) rename src/cogkit/{ => finetune}/datasets/i2v_dataset.py (100%) rename src/cogkit/{ => finetune}/datasets/t2i_dataset.py (100%) rename src/cogkit/{ => finetune}/datasets/t2v_dataset.py (100%) rename src/cogkit/{ => finetune}/datasets/utils.py (100%) create mode 100644 src/cogkit/finetune/samplers/__init__.py rename src/cogkit/{ => finetune}/samplers/packing_sampler.py (100%) delete mode 100644 src/cogkit/samplers/__init__.py diff --git a/src/cogkit/datasets/__init__.py b/src/cogkit/finetune/datasets/__init__.py similarity index 100% rename from src/cogkit/datasets/__init__.py rename to src/cogkit/finetune/datasets/__init__.py diff --git a/src/cogkit/datasets/i2v_dataset.py b/src/cogkit/finetune/datasets/i2v_dataset.py similarity index 100% rename from src/cogkit/datasets/i2v_dataset.py rename to src/cogkit/finetune/datasets/i2v_dataset.py diff --git 
a/src/cogkit/datasets/t2i_dataset.py b/src/cogkit/finetune/datasets/t2i_dataset.py similarity index 100% rename from src/cogkit/datasets/t2i_dataset.py rename to src/cogkit/finetune/datasets/t2i_dataset.py diff --git a/src/cogkit/datasets/t2v_dataset.py b/src/cogkit/finetune/datasets/t2v_dataset.py similarity index 100% rename from src/cogkit/datasets/t2v_dataset.py rename to src/cogkit/finetune/datasets/t2v_dataset.py diff --git a/src/cogkit/datasets/utils.py b/src/cogkit/finetune/datasets/utils.py similarity index 100% rename from src/cogkit/datasets/utils.py rename to src/cogkit/finetune/datasets/utils.py diff --git a/src/cogkit/finetune/samplers/__init__.py b/src/cogkit/finetune/samplers/__init__.py new file mode 100644 index 0000000..cead346 --- /dev/null +++ b/src/cogkit/finetune/samplers/__init__.py @@ -0,0 +1,3 @@ +from .packing_sampler import NaivePackingSampler + +__all__ = ["NaivePackingSampler"] diff --git a/src/cogkit/samplers/packing_sampler.py b/src/cogkit/finetune/samplers/packing_sampler.py similarity index 100% rename from src/cogkit/samplers/packing_sampler.py rename to src/cogkit/finetune/samplers/packing_sampler.py diff --git a/src/cogkit/samplers/__init__.py b/src/cogkit/samplers/__init__.py deleted file mode 100644 index 55c3011..0000000 --- a/src/cogkit/samplers/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from cogkit.samplers.packing_sampler import NaivePackingSampler - -__all__ = ["NaivePackingSampler"] diff --git a/tests/test_sampler.py b/tests/test_sampler.py index 2ea5370..81fd4c6 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -4,7 +4,7 @@ import torch from torch.utils.data import DataLoader, Dataset -from cogkit.samplers import NaivePackingSampler +from cogkit.finetune.samplers import NaivePackingSampler # ============================================================================== From 701e98a23e8aacf33b6b9487078ad0f1e6f97c8d Mon Sep 17 00:00:00 2001 From: OleehyO Date: Fri, 2 May 2025 10:18:30 +0000 Subject: [PATCH 03/19] [sampler] Add distributed packing sampler --- src/cogkit/finetune/samplers/__init__.py | 4 +- .../finetune/samplers/packing_sampler.py | 39 ++++++++++++++++++- 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/src/cogkit/finetune/samplers/__init__.py b/src/cogkit/finetune/samplers/__init__.py index cead346..570b9e5 100644 --- a/src/cogkit/finetune/samplers/__init__.py +++ b/src/cogkit/finetune/samplers/__init__.py @@ -1,3 +1,3 @@ -from .packing_sampler import NaivePackingSampler +from .packing_sampler import DistPackingSampler, NaivePackingSampler -__all__ = ["NaivePackingSampler"] +__all__ = ["NaivePackingSampler", "DistPackingSampler"] diff --git a/src/cogkit/finetune/samplers/packing_sampler.py b/src/cogkit/finetune/samplers/packing_sampler.py index 430a789..d302467 100644 --- a/src/cogkit/finetune/samplers/packing_sampler.py +++ b/src/cogkit/finetune/samplers/packing_sampler.py @@ -6,9 +6,13 @@ fixed-size batches while preserving sampling randomness. 
""" -from torch.utils.data import Sampler -from typing import List, Iterator import random +from typing import Iterator, List +from typing_extensions import override + +from torch.utils.data import Sampler +from cogkit.finetune.utils import get_world_size, get_global_rank +import torch.distributed as dist class NaivePackingSampler(Sampler): @@ -67,3 +71,34 @@ def __iter__(self) -> Iterator[List[int]]: def __len__(self): return len(self.idx_buckets) + + +class DistPackingSampler(NaivePackingSampler): + @override + def __init__( + self, + length_list: list[int], + packed_length: int, + shuffle: bool = True, + world_size: int | None = None, + global_rank: int | None = None, + ): + super().__init__(length_list, packed_length, shuffle) + if not dist.is_initialized(): + raise ValueError("DistPackingSampler requires distributed training") + + self.world_size = world_size or get_world_size() + self.global_rank = global_rank or get_global_rank() + + @override + def __iter__(self) -> Iterator[List[int]]: + size = len(self.idx_buckets) // self.world_size + offset = self.global_rank * size + yield from self.idx_buckets[offset : offset + size] + + if self.shuffle: + random.shuffle(self.idx_buckets) + + @override + def __len__(self): + return len(self.idx_buckets) // self.world_size From 4792ecfcf386fd56c884817b4d2816a5eee9446f Mon Sep 17 00:00:00 2001 From: OleehyO Date: Fri, 2 May 2025 10:19:51 +0000 Subject: [PATCH 04/19] [logger] Use filelock to prevent write conflicts --- src/cogkit/finetune/logger.py | 58 ++++++++++++----------------------- 1 file changed, 20 insertions(+), 38 deletions(-) diff --git a/src/cogkit/finetune/logger.py b/src/cogkit/finetune/logger.py index 9b1c655..8de97b4 100644 --- a/src/cogkit/finetune/logger.py +++ b/src/cogkit/finetune/logger.py @@ -37,7 +37,9 @@ def format(self, record): class DistributedLogger: - def __init__(self, name=None, log_file=None, level=logging.INFO): + def __init__( + self, name: str | None = None, log_file: str | Path | None = None, level: int = logging.INFO + ): if not dist.is_initialized(): raise RuntimeError("Distributed environment is not setup") @@ -46,7 +48,7 @@ def __init__(self, name=None, log_file=None, level=logging.INFO): self.logger.setLevel(level) self.logger.propagate = False - base_fmt = f"[rank {self.rank}] | %(asctime)s | %(name)s | %(levelname)s | %(message)s" + base_fmt = f"[rank{self.rank}]: %(asctime)s | %(name)s | %(levelname)s | %(message)s" date_fmt = "%Y-%m-%d %H:%M:%S" if self.is_main_process() and log_file is not None: @@ -67,30 +69,26 @@ def __init__(self, name=None, log_file=None, level=logging.INFO): console_handler.setFormatter(console_formatter) self.logger.addHandler(console_handler) - file_handler = logging.FileHandler(log_file) - file_formatter = logging.Formatter(base_fmt, date_fmt) - file_handler.setFormatter(file_formatter) - self.logger.addHandler(file_handler) + if log_file is not None: + file_handler = logging.FileHandler(log_file) + file_formatter = logging.Formatter(base_fmt, date_fmt) + file_handler.setFormatter(file_formatter) + self.logger.addHandler(file_handler) dist.barrier() def __del__(self): - if self.is_main_process(): - Path(self.flpath).unlink() - - self.info("Logger destroyed on all processes...") - dist.barrier() - self.info("Logger destroyed on all processes... 
done") + Path(self.flpath).unlink(missing_ok=True) def is_main_process(self): return self.rank == 0 def log(self, level, msg, main_only=False, *args, **kwargs) -> None: - # with self.lock: - if not main_only: - self.logger.log(level, msg, *args, **kwargs) - elif main_only and self.is_main_process(): - self.logger.log(level, msg, *args, **kwargs) + with self.lock: + if not main_only: + self.logger.log(level, msg, *args, **kwargs) + elif main_only and self.is_main_process(): + self.logger.log(level, msg, *args, **kwargs) def debug(self, msg, main_only=False, *args, **kwargs) -> None: self.log(logging.DEBUG, msg, main_only, *args, **kwargs) @@ -108,7 +106,9 @@ def critical(self, msg, main_only=False, *args, **kwargs) -> None: self.log(logging.CRITICAL, msg, main_only, *args, **kwargs) -def get_logger(name=None, log_file=None, level=logging.INFO) -> DistributedLogger: +def get_logger( + name: str | None = None, log_file: str | Path | None = None, level: int = logging.INFO +) -> DistributedLogger: if name is None: frame = inspect.currentframe().f_back module_name = frame.f_globals["__name__"] @@ -117,24 +117,6 @@ def get_logger(name=None, log_file=None, level=logging.INFO) -> DistributedLogge name = ".".join(name_parts[-2:]) else: name = module_name + if log_file is not None: + log_file = Path(log_file).expanduser().resolve() return DistributedLogger(name, log_file, level) - - -if __name__ == "__main__": - dist.init_process_group(backend="nccl") - - logger = get_logger(name="testfile", log_file="test.log") - - logger.debug("Debug message") - logger.info("Info message") - logger.warning("Warning message") - logger.error("Error message") - logger.critical("Critical message") - - logger.debug("Debug message", main_only=True) - logger.info("Info message", main_only=True) - logger.warning("Warning message", main_only=True) - logger.error("Error message", main_only=True) - logger.critical("Critical message", main_only=True) - - dist.destroy_process_group() From d678898b88f49c2f9aa383582d5d0272fe12d6b3 Mon Sep 17 00:00:00 2001 From: OleehyO Date: Fri, 2 May 2025 10:23:00 +0000 Subject: [PATCH 05/19] [register] Rename file to avoid import error --- .../finetune/{register.py => _register.py} | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) rename src/cogkit/finetune/{register.py => _register.py} (77%) diff --git a/src/cogkit/finetune/register.py b/src/cogkit/finetune/_register.py similarity index 77% rename from src/cogkit/finetune/register.py rename to src/cogkit/finetune/_register.py index c0efe51..1e5a18c 100644 --- a/src/cogkit/finetune/register.py +++ b/src/cogkit/finetune/_register.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- -from typing import Literal, TYPE_CHECKING +from typing import TYPE_CHECKING, Literal # using TYPE_CHECKING to avoid circular import if TYPE_CHECKING: @@ -46,26 +46,26 @@ def show_supported_models(): def get_model_cls( - model_type: str, training_type: Literal["lora", "sft"], use_packing: bool = False + model_name: str, training_type: Literal["lora", "sft"], use_packing: bool = False ) -> "BaseTrainer": """Get the trainer class for a specific model and training type.""" - if model_type not in SUPPORTED_MODELS: - print(f"\nModel '{model_type}' is not supported.") + if model_name not in SUPPORTED_MODELS: + print(f"\nModel '{model_name}' is not supported.") print("\nSupported models are:") for supported_model in SUPPORTED_MODELS: print(f" • {supported_model}") - raise ValueError(f"Model '{model_type}' is not supported") + raise ValueError(f"Model 
'{model_name}' is not supported") if use_packing: training_type = f"{training_type}-packing" - if training_type not in SUPPORTED_MODELS[model_type]: - print(f"\nTraining type '{training_type}' is not supported for model '{model_type}'.") - print(f"\nSupported training types for '{model_type}' are:") - for supported_type in SUPPORTED_MODELS[model_type]: + if training_type not in SUPPORTED_MODELS[model_name]: + print(f"\nTraining type '{training_type}' is not supported for model '{model_name}'.") + print(f"\nSupported training types for '{model_name}' are:") + for supported_type in SUPPORTED_MODELS[model_name]: print(f" • {supported_type}") raise ValueError( - f"Training type '{training_type}' is not supported for model '{model_type}'" + f"Training type '{training_type}' is not supported for model '{model_name}'" ) - return SUPPORTED_MODELS[model_type][training_type] + return SUPPORTED_MODELS[model_name][training_type] From 2444d09057c8c41fcd56e430f47e3fb83f5316e1 Mon Sep 17 00:00:00 2001 From: OleehyO Date: Sat, 3 May 2025 10:24:23 +0000 Subject: [PATCH 06/19] [utils] Add seed utility for deterministic randomness --- src/cogkit/utils/seed.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 src/cogkit/utils/seed.py diff --git a/src/cogkit/utils/seed.py b/src/cogkit/utils/seed.py new file mode 100644 index 0000000..e10cd4f --- /dev/null +++ b/src/cogkit/utils/seed.py @@ -0,0 +1,18 @@ +import random + +import numpy as np +import torch + + +def set_global_seed(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + if torch.backends.mps.is_available(): + torch.backends.mps.manual_seed(seed) From 11256f105742a82b5a47457300dcb000e89423ce Mon Sep 17 00:00:00 2001 From: OleehyO Date: Sat, 3 May 2025 10:26:01 +0000 Subject: [PATCH 07/19] [lora] Extract adapter name configuration --- src/cogkit/utils/lora.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/cogkit/utils/lora.py b/src/cogkit/utils/lora.py index db9e142..4eecd71 100644 --- a/src/cogkit/utils/lora.py +++ b/src/cogkit/utils/lora.py @@ -26,6 +26,7 @@ # Standard filename for LoRA adapter weights _LORA_WEIGHT_NAME = "adapter_model.safetensors" +_ADAPTER_NAME = "default" def _get_lora_config() -> LoraConfig: @@ -37,7 +38,9 @@ def _get_lora_config() -> LoraConfig: ) -def inject_lora(model, lora_dir_or_state_dict: str | Path | None = None) -> None: +def inject_lora( + model, lora_dir_or_state_dict: str | Path | None = None, adapter_name: str = _ADAPTER_NAME +) -> None: """ Inject LoRA adapters into the model. 
@@ -49,9 +52,10 @@ def inject_lora(model, lora_dir_or_state_dict: str | Path | None = None) -> None model: The model to inject LoRA adapters into lora_dir_or_state_dict: Path to a LoRA checkpoint directory, a state dict, or None for random initialization + adapter_name: The name of the adapter to inject """ transformer_lora_config = _get_lora_config() - inject_adapter_in_model(transformer_lora_config, model) + inject_adapter_in_model(transformer_lora_config, model, adapter_name=adapter_name) if lora_dir_or_state_dict is None: return @@ -65,7 +69,7 @@ def inject_lora(model, lora_dir_or_state_dict: str | Path | None = None) -> None else: peft_state_dict = lora_dir_or_state_dict - set_peft_model_state_dict(model, peft_state_dict) + set_peft_model_state_dict(model, peft_state_dict, adapter_name=adapter_name) def save_lora(model, lora_dir: str | Path) -> None: From f411e2ee9b96020d9f8940307cbf2b548cbd1a47 Mon Sep 17 00:00:00 2001 From: OleehyO Date: Sat, 3 May 2025 10:27:04 +0000 Subject: [PATCH 08/19] [feat] Support generation mode guess based on object --- src/cogkit/utils/misc.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/cogkit/utils/misc.py b/src/cogkit/utils/misc.py index 0af5c98..b7ff2d4 100644 --- a/src/cogkit/utils/misc.py +++ b/src/cogkit/utils/misc.py @@ -83,7 +83,10 @@ def guess_generation_mode( if isinstance(pipeline_or_path, str): pl_cls_name = get_pipeline_meta(pipeline_or_path)["cls_name"] else: - pl_cls_name = pipeline_or_path.__class__.__name__ + if isinstance(pipeline_or_path, type): + pl_cls_name = pipeline_or_path.__name__ + else: + pl_cls_name = pipeline_or_path.__class__.__name__ if pl_cls_name not in _SUPPORTED_PIPELINE: err_msg = f"The pipeline '{pl_cls_name}' is not supported." From cedf0c3afd944b046ed7e99734d7c41c4d362d0f Mon Sep 17 00:00:00 2001 From: OleehyO Date: Sat, 3 May 2025 10:28:26 +0000 Subject: [PATCH 09/19] Update __init__ file --- src/cogkit/utils/__init__.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/cogkit/utils/__init__.py b/src/cogkit/utils/__init__.py index b2ec328..f4c9383 100644 --- a/src/cogkit/utils/__init__.py +++ b/src/cogkit/utils/__init__.py @@ -1,20 +1,21 @@ # -*- coding: utf-8 -*- -from cogkit.utils.diffusion_pipeline import get_pipeline_meta -from cogkit.utils.dtype import cast_to_torch_dtype -from cogkit.utils.lora import ( +from .diffusion_pipeline import get_pipeline_meta +from .dtype import cast_to_torch_dtype +from .lora import ( load_lora_checkpoint, unload_lora_checkpoint, inject_lora, save_lora, unload_lora, ) -from cogkit.utils.misc import guess_generation_mode, flatten_dict, expand_list -from cogkit.utils.path import mkdir, resolve_path -from cogkit.utils.prompt import convert_prompt -from cogkit.utils.random import rand_generator -from cogkit.utils.load import load_pipeline +from .misc import guess_generation_mode, flatten_dict, expand_list +from .path import mkdir, resolve_path +from .prompt import convert_prompt +from .random import rand_generator +from .load import load_pipeline +from .seed import set_global_seed __all__ = [ "get_pipeline_meta", @@ -32,4 +33,5 @@ "convert_prompt", "flatten_dict", "expand_list", + "set_global_seed", ] From 9a546fd8f6912bd666dc2d4c3586a0e4ffcf5f73 Mon Sep 17 00:00:00 2001 From: OleehyO Date: Sun, 4 May 2025 10:29:30 +0000 Subject: [PATCH 10/19] [feat] Add utils for distributed training --- src/cogkit/finetune/utils/__init__.py | 13 +- .../utils/{attn_mask.py => attention.py} | 0 
src/cogkit/finetune/utils/checkpointing.py | 54 ----- src/cogkit/finetune/utils/ckpt.py | 85 ++++++++ src/cogkit/finetune/utils/dist.py | 36 ++++ src/cogkit/finetune/utils/file_utils.py | 39 ---- src/cogkit/finetune/utils/io.py | 107 ++++++++++ src/cogkit/finetune/utils/memory.py | 35 ++++ src/cogkit/finetune/utils/memory_utils.py | 60 ------ src/cogkit/finetune/utils/misc.py | 18 ++ src/cogkit/finetune/utils/optimizer_utils.py | 186 ------------------ src/cogkit/finetune/utils/torch_utils.py | 50 ----- src/cogkit/finetune/utils/tracker.py | 27 +++ 13 files changed, 315 insertions(+), 395 deletions(-) rename src/cogkit/finetune/utils/{attn_mask.py => attention.py} (100%) delete mode 100644 src/cogkit/finetune/utils/checkpointing.py create mode 100644 src/cogkit/finetune/utils/ckpt.py create mode 100644 src/cogkit/finetune/utils/dist.py delete mode 100644 src/cogkit/finetune/utils/file_utils.py create mode 100644 src/cogkit/finetune/utils/io.py create mode 100644 src/cogkit/finetune/utils/memory.py delete mode 100644 src/cogkit/finetune/utils/memory_utils.py create mode 100644 src/cogkit/finetune/utils/misc.py delete mode 100644 src/cogkit/finetune/utils/optimizer_utils.py delete mode 100644 src/cogkit/finetune/utils/torch_utils.py create mode 100644 src/cogkit/finetune/utils/tracker.py diff --git a/src/cogkit/finetune/utils/__init__.py b/src/cogkit/finetune/utils/__init__.py index 8eaeafc..eb210cf 100644 --- a/src/cogkit/finetune/utils/__init__.py +++ b/src/cogkit/finetune/utils/__init__.py @@ -1,7 +1,8 @@ -from .checkpointing import * # noqa -from .file_utils import * # noqa -from .memory_utils import * # noqa -from .optimizer_utils import * # noqa -from .torch_utils import * # noqa +from .ckpt import * # noqa +from .memory import * # noqa from .filters import * # noqa -from .attn_mask import * # noqa +from .attention import * # noqa +from .io import * # noqa +from .dist import * # noqa +from .misc import * # noqa +from .tracker import * # noqa diff --git a/src/cogkit/finetune/utils/attn_mask.py b/src/cogkit/finetune/utils/attention.py similarity index 100% rename from src/cogkit/finetune/utils/attn_mask.py rename to src/cogkit/finetune/utils/attention.py diff --git a/src/cogkit/finetune/utils/checkpointing.py b/src/cogkit/finetune/utils/checkpointing.py deleted file mode 100644 index 5a28505..0000000 --- a/src/cogkit/finetune/utils/checkpointing.py +++ /dev/null @@ -1,54 +0,0 @@ -import os -from pathlib import Path - -from ..utils.file_utils import delete_files, find_files - - -def get_latest_ckpt_path_to_resume_from( - resume_from_checkpoint: str | None, num_update_steps_per_epoch: int, logger -) -> tuple[str | None, int, int, int]: - if resume_from_checkpoint is None: - initial_global_step = 0 - global_step = 0 - first_epoch = 0 - resume_from_checkpoint_path = None - else: - resume_from_checkpoint_path = Path(resume_from_checkpoint) - if not resume_from_checkpoint_path.exists(): - logger.info( - f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run." 
-            )
-            initial_global_step = 0
-            global_step = 0
-            first_epoch = 0
-            resume_from_checkpoint_path = None
-        else:
-            logger.info(f"Resuming from checkpoint {resume_from_checkpoint}")
-            global_step = int(resume_from_checkpoint_path.name.split("-")[1])
-
-            initial_global_step = global_step
-            first_epoch = global_step // num_update_steps_per_epoch
-
-    return (
-        resume_from_checkpoint_path,
-        initial_global_step,
-        global_step,
-        first_epoch,
-    )
-
-
-def get_intermediate_ckpt_path(checkpointing_limit: int, step: int, output_dir: str, logger) -> str:
-    # before saving state, check if this save would set us over the `checkpointing_limit`
-    if checkpointing_limit is not None:
-        checkpoints = find_files(output_dir, prefix="checkpoint")
-
-        # before we save the new checkpoint, we need to have at_most `checkpoints_total_limit - 1` checkpoints
-        if len(checkpoints) >= checkpointing_limit:
-            num_to_remove = len(checkpoints) - checkpointing_limit + 1
-            checkpoints_to_remove = checkpoints[0:num_to_remove]
-            delete_files(checkpoints_to_remove, logger)
-
-    logger.info(f"Checkpointing at step {step}")
-    save_path = os.path.join(output_dir, f"checkpoint-{step}")
-    logger.info(f"Saving state to {save_path}")
-    return save_path
diff --git a/src/cogkit/finetune/utils/ckpt.py b/src/cogkit/finetune/utils/ckpt.py
new file mode 100644
index 0000000..41ac62c
--- /dev/null
+++ b/src/cogkit/finetune/utils/ckpt.py
@@ -0,0 +1,85 @@
+from pathlib import Path
+
+import torch.distributed as dist
+from safetensors.torch import save_file
+from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict, StateDictOptions
+from torch.distributed.checkpoint.stateful import Stateful
+
+from cogkit.utils.lora import save_lora
+
+from .dist import is_main_process
+from .io import check_path
+
+
+def save_state_dict(
+    state_dict: dict, save_dir: Path, fname: str, metadata: dict | None = None, lora: bool = False
+) -> None:
+    if is_main_process():
+        if lora:
+            save_lora(state_dict, save_dir)
+        else:
+            save_file(state_dict, save_dir / fname, metadata)
+
+    dist.barrier()
+
+
+def get_global_step(ckpt_path: str | Path) -> int:
+    ckpt_path = Path(ckpt_path)
+    check_path(ckpt_path, must_exists=True, must_dir=True)
+
+    try:
+        global_step = int(ckpt_path.name.split("-")[1])
+    except (IndexError, ValueError):
+        raise ValueError(f"Checkpoint path '{ckpt_path}' is not in the correct format.")
+
+    return global_step
+
+
+class AppState(Stateful):
+    """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
+    with the Stateful protocol, DCP will automatically call state_dict/load_state_dict as needed in the
+    dcp.save/load APIs.
+
+    Note: We take advantage of this wrapper to handle calling distributed state dict methods on the model
+    and optimizer.
+ + For more details, please refer to: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html + """ + + def __init__(self, model, optimizer=None, lora: bool = False): + self.model = model + self.optimizer = optimizer + self.lora = lora + + def state_dict(self): + # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT + model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer) + if self.lora: + from peft import get_peft_model_state_dict + + model_state_dict = get_peft_model_state_dict(self.model) + + return {"model": model_state_dict, "optim": optimizer_state_dict} + + def load_state_dict(self, state_dict): + # sets our state dicts on the model and optimizer, now that we've loaded + if self.lora: + from peft.utils.save_and_load import _insert_adapter_name_into_state_dict + from cogkit.utils.lora import _ADAPTER_NAME + from peft.utils.constants import PEFT_TYPE_TO_PREFIX_MAPPING + + state_dict["model"] = _insert_adapter_name_into_state_dict( + state_dict["model"], + adapter_name=_ADAPTER_NAME, + parameter_prefix=PEFT_TYPE_TO_PREFIX_MAPPING[ + self.model.peft_config[_ADAPTER_NAME].peft_type + ], + ) + + set_state_dict( + self.model, + self.optimizer, + model_state_dict=state_dict["model"], + optim_state_dict=state_dict["optim"], + options=StateDictOptions(strict=False), + ) diff --git a/src/cogkit/finetune/utils/dist.py b/src/cogkit/finetune/utils/dist.py new file mode 100644 index 0000000..562bf84 --- /dev/null +++ b/src/cogkit/finetune/utils/dist.py @@ -0,0 +1,36 @@ +import os +from typing import Any + +import torch +import torch.distributed as dist + + +def check_distributed() -> None: + if not dist.is_initialized(): + raise RuntimeError("Distributed training is not initialized") + + +def is_main_process() -> bool: + return dist.get_rank() == 0 + + +def get_world_size() -> int: + return dist.get_world_size() + + +def get_global_rank() -> int: + return dist.get_rank() + + +def get_local_rank() -> int: + return int(os.environ["LOCAL_RANK"]) + + +def get_device() -> torch.device: + return torch.device(f"cuda:{get_local_rank()}") + + +def gather_object(object: Any) -> list[Any]: + output_objects = [None for _ in range(get_world_size())] + dist.all_gather_object(output_objects, object) + return output_objects diff --git a/src/cogkit/finetune/utils/file_utils.py b/src/cogkit/finetune/utils/file_utils.py deleted file mode 100644 index 93207b6..0000000 --- a/src/cogkit/finetune/utils/file_utils.py +++ /dev/null @@ -1,39 +0,0 @@ -import os -import shutil -from pathlib import Path - - -def find_files(dir: str | Path, prefix: str = "checkpoint") -> list[str]: - if not isinstance(dir, Path): - dir = Path(dir) - if not dir.exists(): - return [] - checkpoints = os.listdir(dir.as_posix()) - checkpoints = [c for c in checkpoints if c.startswith(prefix)] - checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) - checkpoints = [dir / c for c in checkpoints] - return checkpoints - - -def delete_files(dirs: str | list[str] | Path | list[Path], logger) -> None: - if not isinstance(dirs, list): - dirs = [dirs] - dirs = [Path(d) if isinstance(d, str) else d for d in dirs] - logger.info(f"Deleting files: {dirs}") - for dir in dirs: - if not dir.exists(): - continue - shutil.rmtree(dir, ignore_errors=True) - - -def string_to_filename(s: str) -> str: - return ( - s.replace(" ", "-") - .replace("/", "-") - .replace(":", "-") - .replace(".", "-") - .replace(",", "-") - .replace(";", "-") - 
.replace("!", "-") - .replace("?", "-") - ) diff --git a/src/cogkit/finetune/utils/io.py b/src/cogkit/finetune/utils/io.py new file mode 100644 index 0000000..84934f5 --- /dev/null +++ b/src/cogkit/finetune/utils/io.py @@ -0,0 +1,107 @@ +from pathlib import Path +import shutil +import torch.distributed as dist + +from cogkit.finetune.logger import get_logger + +from .dist import is_main_process + + +def check_path( + path: str | Path | None, + must_exists: bool = False, + must_dir: bool = False, + must_file: bool = False, +) -> None: + if path is None: + raise ValueError("Path is None") + if isinstance(path, str): + path = Path(path) + if must_exists and not path.exists(): + raise FileNotFoundError(f"Path '{path}' does not exist.") + if must_dir and not path.is_dir(): + raise FileNotFoundError(f"Path '{path}' is not a directory.") + if must_file and not path.is_file(): + raise FileNotFoundError(f"Path '{path}' is not a file.") + + +def resolve_path(path: str | Path) -> str: + if isinstance(path, str): + path = Path(path) + check_path(path) + return str(path.expanduser().resolve()) + + +def mkdir(path: str | Path) -> None: + _logger = get_logger() + if is_main_process(): + check_path(path) + Path(resolve_path(path)).mkdir(parents=True, exist_ok=True) + _logger.debug(f"Creating directory: {resolve_path(path)}") + + dist.barrier() + + +def touch(path: str | Path) -> None: + _logger = get_logger() + if is_main_process(): + check_path(path) + Path(resolve_path(path)).touch() + _logger.debug(f"Touching file: {resolve_path(path)}") + + dist.barrier() + + +def list_files(dir: str | Path | None, prefix: str = "checkpoint") -> list[str]: + _logger = get_logger() + if dir is None: + _logger.warning("Directory is None, returning empty list") + return [] + return [str(p) for p in Path(resolve_path(dir)).glob(f"{prefix}*")] + + +def rmdir(path: str | Path) -> None: + _logger = get_logger() + if is_main_process(): + check_path(path, must_exists=True, must_dir=True) + Path(resolve_path(path)).rmdir() + _logger.debug(f"Deleted empty directory: {resolve_path(path)}") + + dist.barrier() + + +def rmfile(path: str | Path, must_exists: bool = True) -> None: + _logger = get_logger() + if is_main_process(): + check_path(path, must_exists=must_exists, must_file=True) + Path(resolve_path(path)).unlink() + _logger.debug(f"Deleted file: {resolve_path(path)}") + + dist.barrier() + + +def rmtree(path: str | Path) -> None: + """Recursively delete a directory tree.""" + _logger = get_logger() + if is_main_process(): + path = Path(resolve_path(path)) + check_path(path, must_exists=True, must_dir=True) + shutil.rmtree(path) + _logger.debug(f"Recursively deleted directory: {path}") + + dist.barrier() + + +def delete_files(files: list[str], recursive: bool = True) -> None: + for file in files: + check_path(file, must_exists=True) + path = Path(file) + if path.is_dir(): + if recursive: + rmtree(path) + else: + rmdir(path) + else: + rmfile(path) + + dist.barrier() diff --git a/src/cogkit/finetune/utils/memory.py b/src/cogkit/finetune/utils/memory.py new file mode 100644 index 0000000..9f063db --- /dev/null +++ b/src/cogkit/finetune/utils/memory.py @@ -0,0 +1,35 @@ +import gc +from typing import Any + +import torch + + +def get_memory_statistics(device: torch.device, precision: int = 3) -> dict[str, Any]: + memory_allocated = None + memory_reserved = None + max_memory_allocated = None + max_memory_reserved = None + + device = torch.cuda.current_device() + memory_allocated = torch.cuda.memory_allocated(device) + memory_reserved 
= torch.cuda.memory_reserved(device) + max_memory_allocated = torch.cuda.max_memory_allocated(device) + max_memory_reserved = torch.cuda.max_memory_reserved(device) + + return { + "memory_allocated": round(bytes_to_gigabytes(memory_allocated), ndigits=precision), + "memory_reserved": round(bytes_to_gigabytes(memory_reserved), ndigits=precision), + "max_memory_allocated": round(bytes_to_gigabytes(max_memory_allocated), ndigits=precision), + "max_memory_reserved": round(bytes_to_gigabytes(max_memory_reserved), ndigits=precision), + } + + +def bytes_to_gigabytes(x: int) -> float: + if x is not None: + return x / 1024**3 + + +def free_memory() -> None: + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() diff --git a/src/cogkit/finetune/utils/memory_utils.py b/src/cogkit/finetune/utils/memory_utils.py deleted file mode 100644 index 3427410..0000000 --- a/src/cogkit/finetune/utils/memory_utils.py +++ /dev/null @@ -1,60 +0,0 @@ -import gc -from typing import Any - -import torch - - -def get_memory_statistics(logger, precision: int = 3) -> dict[str, Any]: - memory_allocated = None - memory_reserved = None - max_memory_allocated = None - max_memory_reserved = None - - if torch.cuda.is_available(): - device = torch.cuda.current_device() - memory_allocated = torch.cuda.memory_allocated(device) - memory_reserved = torch.cuda.memory_reserved(device) - max_memory_allocated = torch.cuda.max_memory_allocated(device) - max_memory_reserved = torch.cuda.max_memory_reserved(device) - - elif torch.mps.is_available(): - memory_allocated = torch.mps.current_allocated_memory() - - else: - logger.warning("No CUDA, MPS, or ROCm device found. Memory statistics are not available.") - - return { - "memory_allocated": round(bytes_to_gigabytes(memory_allocated), ndigits=precision), - "memory_reserved": round(bytes_to_gigabytes(memory_reserved), ndigits=precision), - "max_memory_allocated": round(bytes_to_gigabytes(max_memory_allocated), ndigits=precision), - "max_memory_reserved": round(bytes_to_gigabytes(max_memory_reserved), ndigits=precision), - } - - -def bytes_to_gigabytes(x: int) -> float: - if x is not None: - return x / 1024**3 - - -def free_memory() -> None: - if torch.cuda.is_available(): - gc.collect() - torch.cuda.empty_cache() - torch.cuda.ipc_collect() - - # TODO(aryan): handle non-cuda devices - - -def unload_model(model): - model.to("cpu") - - -def make_contiguous( - x: torch.Tensor | dict[str, torch.Tensor], -) -> torch.Tensor | dict[str, torch.Tensor]: - if isinstance(x, torch.Tensor): - return x.contiguous() - elif isinstance(x, dict): - return {k: make_contiguous(v) for k, v in x.items()} - else: - return x diff --git a/src/cogkit/finetune/utils/misc.py b/src/cogkit/finetune/utils/misc.py new file mode 100644 index 0000000..6efeb13 --- /dev/null +++ b/src/cogkit/finetune/utils/misc.py @@ -0,0 +1,18 @@ +import torch + + +def cast_training_params(model: torch.nn.Module | list[torch.nn.Module], dtype=torch.float32): + """ + Casts the training parameters of the model to the specified data type. + + Args: + model: The PyTorch model whose parameters will be cast. + dtype: The data type to which the model parameters will be cast. 
+ """ + if not isinstance(model, list): + model = [model] + for m in model: + for param in m.parameters(): + # only upcast trainable parameters into fp32 + if param.requires_grad: + param.data = param.to(dtype) diff --git a/src/cogkit/finetune/utils/optimizer_utils.py b/src/cogkit/finetune/utils/optimizer_utils.py deleted file mode 100644 index 5b38fe6..0000000 --- a/src/cogkit/finetune/utils/optimizer_utils.py +++ /dev/null @@ -1,186 +0,0 @@ -import inspect - -import torch - - -def get_optimizer( - params_to_optimize, - logger, - optimizer_name: str = "adam", - learning_rate: float = 1e-3, - beta1: float = 0.9, - beta2: float = 0.95, - beta3: float = 0.98, - epsilon: float = 1e-8, - weight_decay: float = 1e-4, - prodigy_decouple: bool = False, - prodigy_use_bias_correction: bool = False, - prodigy_safeguard_warmup: bool = False, - use_8bit: bool = False, - use_4bit: bool = False, - use_torchao: bool = False, - use_deepspeed: bool = False, - use_cpu_offload_optimizer: bool = False, - offload_gradients: bool = False, -) -> torch.optim.Optimizer: - optimizer_name = optimizer_name.lower() - - # Use DeepSpeed optimzer - if use_deepspeed: - from accelerate.utils import DummyOptim - - return DummyOptim( - params_to_optimize, - lr=learning_rate, - betas=(beta1, beta2), - eps=epsilon, - weight_decay=weight_decay, - ) - - if use_8bit and use_4bit: - raise ValueError("Cannot set both `use_8bit` and `use_4bit` to True.") - - if (use_torchao and (use_8bit or use_4bit)) or use_cpu_offload_optimizer: - try: - import torchao - - torchao.__version__ - except ImportError: - raise ImportError( - "To use optimizers from torchao, please install the torchao library: `USE_CPP=0 pip install torchao`." - ) - - if not use_torchao and use_4bit: - raise ValueError("4-bit Optimizers are only supported with torchao.") - - # Optimizer creation - supported_optimizers = ["adam", "adamw", "prodigy", "came"] - if optimizer_name not in supported_optimizers: - logger.warning( - f"Unsupported choice of optimizer: {optimizer_name}. Supported optimizers include {supported_optimizers}. Defaulting to `AdamW`." - ) - optimizer_name = "adamw" - - if (use_8bit or use_4bit) and optimizer_name not in ["adam", "adamw"]: - raise ValueError( - "`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers." - ) - - if use_8bit: - try: - import bitsandbytes as bnb - except ImportError: - raise ImportError( - "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 
- ) - - if optimizer_name == "adamw": - if use_torchao: - from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit - - optimizer_class = ( - AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW - ) - else: - optimizer_class = bnb.optim.AdamW8bit if use_8bit else torch.optim.AdamW - - init_kwargs = { - "betas": (beta1, beta2), - "eps": epsilon, - "weight_decay": weight_decay, - } - - elif optimizer_name == "adam": - if use_torchao: - from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit - - optimizer_class = Adam8bit if use_8bit else Adam4bit if use_4bit else torch.optim.Adam - else: - optimizer_class = bnb.optim.Adam8bit if use_8bit else torch.optim.Adam - - init_kwargs = { - "betas": (beta1, beta2), - "eps": epsilon, - "weight_decay": weight_decay, - } - - elif optimizer_name == "prodigy": - try: - import prodigyopt - except ImportError: - raise ImportError( - "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`" - ) - - optimizer_class = prodigyopt.Prodigy - - if learning_rate <= 0.1: - logger.warning( - "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" - ) - - init_kwargs = { - "lr": learning_rate, - "betas": (beta1, beta2), - "beta3": beta3, - "eps": epsilon, - "weight_decay": weight_decay, - "decouple": prodigy_decouple, - "use_bias_correction": prodigy_use_bias_correction, - "safeguard_warmup": prodigy_safeguard_warmup, - } - - elif optimizer_name == "came": - try: - import came_pytorch - except ImportError: - raise ImportError( - "To use CAME, please install the came-pytorch library: `pip install came-pytorch`" - ) - - optimizer_class = came_pytorch.CAME - - init_kwargs = { - "lr": learning_rate, - "eps": (1e-30, 1e-16), - "betas": (beta1, beta2, beta3), - "weight_decay": weight_decay, - } - - if use_cpu_offload_optimizer: - from torchao.prototype.low_bit_optim import CPUOffloadOptimizer - - if "fused" in inspect.signature(optimizer_class.__init__).parameters: - init_kwargs.update({"fused": True}) - - optimizer = CPUOffloadOptimizer( - params_to_optimize, - optimizer_class=optimizer_class, - offload_gradients=offload_gradients, - **init_kwargs, - ) - else: - optimizer = optimizer_class(params_to_optimize, **init_kwargs) - - return optimizer - - -def gradient_norm(parameters): - norm = 0 - for param in parameters: - if param.grad is None: - continue - local_norm = param.grad.detach().data.norm(2) - norm += local_norm.item() ** 2 - norm = norm**0.5 - return norm - - -def max_gradient(parameters): - max_grad_value = float("-inf") - for param in parameters: - if param.grad is None: - continue - local_max_grad = param.grad.detach().data.abs().max() - max_grad_value = max(max_grad_value, local_max_grad.item()) - return max_grad_value diff --git a/src/cogkit/finetune/utils/torch_utils.py b/src/cogkit/finetune/utils/torch_utils.py deleted file mode 100644 index 9db6800..0000000 --- a/src/cogkit/finetune/utils/torch_utils.py +++ /dev/null @@ -1,50 +0,0 @@ -import torch -from accelerate import Accelerator -from diffusers.utils.torch_utils import is_compiled_module - - -def unwrap_model(accelerator: Accelerator, model): - model = accelerator.unwrap_model(model) - model = model._orig_mod if is_compiled_module(model) else model - return model - - -def align_device_and_dtype( - x: torch.Tensor | dict[str, torch.Tensor], - device: torch.device | None = None, - dtype: torch.dtype | None = None, -): - if isinstance(x, torch.Tensor): - if device is not None: - x = x.to(device) - if 
dtype is not None:
-            x = x.to(dtype)
-    elif isinstance(x, dict):
-        if device is not None:
-            x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()}
-        if dtype is not None:
-            x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()}
-    return x
-
-
-def expand_tensor_to_dims(tensor, ndim):
-    while len(tensor.shape) < ndim:
-        tensor = tensor.unsqueeze(-1)
-    return tensor
-
-
-def cast_training_params(model: torch.nn.Module | list[torch.nn.Module], dtype=torch.float32):
-    """
-    Casts the training parameters of the model to the specified data type.
-
-    Args:
-        model: The PyTorch model whose parameters will be cast.
-        dtype: The data type to which the model parameters will be cast.
-    """
-    if not isinstance(model, list):
-        model = [model]
-    for m in model:
-        for param in m.parameters():
-            # only upcast trainable parameters into fp32
-            if param.requires_grad:
-                param.data = param.to(dtype)
diff --git a/src/cogkit/finetune/utils/tracker.py b/src/cogkit/finetune/utils/tracker.py
new file mode 100644
index 0000000..9bb1dec
--- /dev/null
+++ b/src/cogkit/finetune/utils/tracker.py
@@ -0,0 +1,27 @@
+from typing import Any
+
+import torch.distributed as dist
+import wandb
+
+from .dist import is_main_process
+
+
+class WandbTracker:
+    def __init__(self, name: str, config: dict[str, Any], **kwargs: Any) -> None:
+        if is_main_process():
+            self.tracker = wandb.init(
+                name=name,
+                config=config,
+                **kwargs,
+            )
+        dist.barrier()
+
+    def log(self, *args: Any, **kwargs: Any) -> None:
+        if is_main_process():
+            self.tracker.log(*args, **kwargs)
+        dist.barrier()
+
+    def finish(self) -> None:
+        if is_main_process():
+            self.tracker.finish()
+        dist.barrier()

From f7d773bbc11cf6bbdca0db8ea9a0d97e0e6ffe71 Mon Sep 17 00:00:00 2001
From: OleehyO
Date: Mon, 5 May 2025 10:30:45 +0000
Subject: [PATCH 11/19] Refactor base component for FSDP training

---
 src/cogkit/finetune/base/base_args.py    | 184 +++----
 src/cogkit/finetune/base/base_state.py   |  19 +-
 src/cogkit/finetune/base/base_trainer.py | 668 ++++++++++------------
 3 files changed, 405 insertions(+), 466 deletions(-)

diff --git a/src/cogkit/finetune/base/base_args.py b/src/cogkit/finetune/base/base_args.py
index 07e048d..d66f170 100644
--- a/src/cogkit/finetune/base/base_args.py
+++ b/src/cogkit/finetune/base/base_args.py
@@ -1,77 +1,143 @@
 # -*- coding: utf-8 -*-
-
-import argparse
 import datetime
 import logging
+from datetime import timedelta
 from pathlib import Path
 from typing import Literal
 
+import yaml
 from pydantic import BaseModel, ValidationInfo, field_validator
 
 
 class BaseArgs(BaseModel):
+    model_config = {"frozen": True, "extra": "ignore"}
+
+    ########## Logging ##########
+    name4train: str
+    log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
+
     ########## Model ##########
     model_path: Path
     model_name: str
-    training_type: Literal["lora", "sft"] = "lora"
 
     ########## Output ##########
     output_dir: Path = Path(f"train_result/{datetime.datetime.now():%Y-%m-%d-%H-%M-%S}")
-    report_to: Literal["tensorboard", "wandb", "all"] | None = None
-    tracker_name: str = "base-tracker"
+
+    ########## Tracker ##########
+    report_to: Literal["wandb"] | None = None
 
     ########## Data Path ###########
     data_root: Path
 
     ########## Training #########
+    training_type: Literal["lora", "sft"] = "lora"
+    strategy: Literal[
+        "DDP", "SHARD_GRAD_OP", "FULL_SHARD", "HYBRID_SHARD", "_HYBRID_SHARD_ZERO2"
+    ] = "FULL_SHARD"
+    # This will offload model params and grads to CPU memory to save GPU memory, but will slow down training
+    offload_params_grads: bool = False
+
+    # This will increase memory usage since gradients are not resharded during accumulation steps.
+    # Note: when used with offload_params_grads, model parameters and gradients will only be offloaded
+    # to the CPU during the final synchronization (still retained on GPU during gradient accumulation steps),
+    # which makes offload_params_grads meaningless when combined with no_grad_sync_when_accumulating
+    no_grad_sync_when_accumulating: bool = False
+
     resume_from_checkpoint: Path | None = None
 
     seed: int | None = None
     train_epochs: int
-    train_steps: int | None = None
-    checkpointing_steps: int = 200
-    checkpointing_limit: int = 10
+    checkpointing_steps: int
+    checkpointing_limit: int
 
     batch_size: int
     gradient_accumulation_steps: int = 1
-    mixed_precision: Literal["no", "fp16", "bf16"]
+    mixed_precision: Literal["fp32", "fp16", "bf16"]
     low_vram: bool = False
 
     learning_rate: float = 2e-5
     optimizer: str = "adamw"
    beta1: float = 0.9
     beta2: float = 0.95
-    beta3: float = 0.98
     epsilon: float = 1e-8
     weight_decay: float = 1e-4
     max_grad_norm: float = 1.0
 
-    lr_scheduler: str = "linear"
-    lr_warmup_ratio: float = 0.01
-    lr_num_cycles: int = 1
-    lr_power: float = 1.0
+    lr_scheduler: str = "CosineAnnealingLR"
 
     num_workers: int = 8
     pin_memory: bool = True
 
     gradient_checkpointing: bool = True
-    nccl_timeout: int = 1800
-
-    ########## Lora ##########
-    rank: int = 128
-    lora_alpha: int = 64
-    target_modules: list[str] = ["to_q", "to_k", "to_v", "to_out.0"]
+    nccl_timeout: timedelta = timedelta(seconds=1800)
 
     ########## Validation ##########
     do_validation: bool = False
     validation_steps: int | None  # if set, should be a multiple of checkpointing_steps
 
+    @field_validator("log_level")
+    def validate_log_level(cls, v: str) -> int:
+        match v:
+            case "DEBUG":
+                return logging.DEBUG
+            case "INFO":
+                return logging.INFO
+            case "WARNING":
+                return logging.WARNING
+            case "ERROR":
+                return logging.ERROR
+            case "CRITICAL":
+                return logging.CRITICAL
+            case _:
+                raise ValueError("log_level must be one of: DEBUG, INFO, WARNING, ERROR, CRITICAL")
+
+    @field_validator("nccl_timeout")
+    def validate_nccl_timeout(cls, v: timedelta | int) -> timedelta:
+        if isinstance(v, int):
+            return timedelta(seconds=v)
+        return v
+
     @field_validator("low_vram")
     def validate_low_vram(cls, v: bool, info: ValidationInfo) -> bool:
         if v and info.data.get("training_type") != "lora":
             raise ValueError("low_vram can only be True when training_type is 'lora'")
+        if v and info.data.get("offload_params_grads"):
+            raise ValueError("low_vram and offload_params_grads cannot be enabled simultaneously")
+        if v and info.data.get("strategy") != "DDP":
+            raise ValueError("low_vram can only be used with strategy='DDP'")
+        if v and info.data.get("resume_from_checkpoint") is not None:
+            raise ValueError("resume_from_checkpoint cannot be used when low_vram is True")
+        return v
+
+    @field_validator("strategy")
+    def validate_strategy(cls, v: str, info: ValidationInfo) -> str:
+        if info.data.get("training_type") == "lora" and v != "DDP":
+            raise ValueError("When using lora training_type, strategy must be 'DDP'")
+        return v
+
+    @field_validator("offload_params_grads")
+    def validate_offload_params_grads(cls, v: bool, info: ValidationInfo) -> bool:
+        if v and info.data.get("low_vram"):
+            raise ValueError("low_vram and offload_params_grads cannot be enabled simultaneously")
+        if v and info.data.get("no_grad_sync_when_accumulating"):
+            raise ValueError(
+                "offload_params_grads and no_grad_sync_when_accumulating cannot be enabled simultaneously"
+            )
+        if v and info.data.get("strategy") == "DDP":
+            raise ValueError("offload_params_grads cannot be enabled when strategy is 'DDP'")
+        return v
+
+    @field_validator("no_grad_sync_when_accumulating")
+    def validate_no_grad_sync_when_accumulating(cls, v: bool, info: ValidationInfo) -> bool:
+        if v and info.data.get("offload_params_grads"):
+            raise ValueError(
+                "offload_params_grads and no_grad_sync_when_accumulating cannot be enabled simultaneously"
+            )
+        if v and info.data.get("strategy") == "DDP":
+            raise ValueError(
+                "no_grad_sync_when_accumulating cannot be enabled when strategy is 'DDP'"
+            )
         return v
 
     @field_validator("validation_steps")
@@ -94,77 +160,11 @@ def validate_mixed_precision(cls, v: str, info: ValidationInfo) -> str:
         return v
 
     @classmethod
-    def get_base_parser(cls):
-        """Parse command line arguments and return Args instance"""
-        parser = argparse.ArgumentParser()
-        # Required arguments
-        parser.add_argument("--model_path", type=str, required=True)
-        parser.add_argument("--model_name", type=str, required=True)
-        parser.add_argument("--training_type", type=str, required=True)
-        parser.add_argument("--output_dir", type=str, required=True)
-        parser.add_argument("--data_root", type=str, required=True)
-        parser.add_argument("--report_to", type=str, required=True)
-
-        # Training hyperparameters
-        parser.add_argument("--seed", type=int, default=42)
-        parser.add_argument("--train_epochs", type=int, default=1)
-        parser.add_argument("--train_steps", type=int, default=None)
-        parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
-        parser.add_argument("--batch_size", type=int, default=1)
-        parser.add_argument("--learning_rate", type=float, default=2e-5)
-        parser.add_argument("--optimizer", type=str, default="adamw")
-        parser.add_argument("--beta1", type=float, default=0.9)
-        parser.add_argument("--beta2", type=float, default=0.95)
-        parser.add_argument("--beta3", type=float, default=0.98)
-        parser.add_argument("--epsilon", type=float, default=1e-8)
-        parser.add_argument("--weight_decay", type=float, default=1e-4)
-        parser.add_argument("--max_grad_norm", type=float, default=1.0)
-
-        # Learning rate scheduler
-        parser.add_argument("--lr_scheduler", type=str, default="linear")
-        parser.add_argument("--lr_warmup_ratio", type=float, default=0.01)
-        parser.add_argument("--lr_num_cycles", type=int, default=1)
-        parser.add_argument("--lr_power", type=float, default=1.0)
-
-        # Data loading
-        parser.add_argument("--num_workers", type=int, default=8)
-        parser.add_argument("--pin_memory", type=lambda x: x.lower() == "true", default=True)
-
-        # Model configuration
-        parser.add_argument("--mixed_precision", type=str, default="no")
-        parser.add_argument("--low_vram", type=lambda x: x.lower() == "true", default=False)
-        parser.add_argument(
-            "--gradient_checkpointing", type=lambda x: x.lower() == "true", default=True
-        )
-        parser.add_argument("--nccl_timeout", type=int, default=1800)
-
-        # LoRA parameters
-        parser.add_argument("--rank", type=int, default=128)
-        parser.add_argument("--lora_alpha", type=int, default=64)
-        parser.add_argument(
-            "--target_modules",
-            type=str,
-            nargs="+",
-            default=["to_q", "to_k", "to_v", "to_out.0"],
-        )
-
-        # Checkpointing
-        parser.add_argument("--checkpointing_steps", type=int, default=200)
-        parser.add_argument("--checkpointing_limit", type=int, default=10)
-        parser.add_argument("--resume_from_checkpoint", type=str, default=None)
-
-        # Validation
-        parser.add_argument("--do_validation", type=lambda x: x.lower() == "true", default=False)
-
parser.add_argument("--validation_steps", type=int, default=None) - - return parser - - @classmethod - def parse_args(cls): - parser = cls.get_base_parser() + def parse_from_yaml(cls, fpath: str | Path) -> "BaseArgs": + if isinstance(fpath, str): + fpath = Path(fpath) - # parser.add_argument(...) - # ... + with open(fpath, "r") as f: + yaml_dict = yaml.safe_load(f) - args = parser.parse_args() - return cls(**vars(args)) + return cls(**yaml_dict) diff --git a/src/cogkit/finetune/base/base_state.py b/src/cogkit/finetune/base/base_state.py index a3307b6..a475b41 100644 --- a/src/cogkit/finetune/base/base_state.py +++ b/src/cogkit/finetune/base/base_state.py @@ -6,11 +6,18 @@ class BaseState(BaseModel): # Allow arbitrary types (for torch dtype) model_config = {"arbitrary_types_allowed": True} - weight_dtype: torch.dtype = torch.float32 # dtype for mixed precision training - num_trainable_parameters: int = 0 - num_update_steps_per_epoch: int = 0 - total_batch_size_count: int = 0 + world_size: int + local_rank: int + global_rank: int - generator: torch.Generator | None = None + device: torch.device + + weight_dtype: torch.dtype - using_deepspeed: bool = False + train_steps: int = -1 + train_epochs: int = -1 + num_trainable_parameters: int = -1 + num_update_steps_per_epoch: int = -1 + total_batch_size_count: int = -1 + + generator: torch.Generator | None = None diff --git a/src/cogkit/finetune/base/base_trainer.py b/src/cogkit/finetune/base/base_trainer.py index ea32c6a..6fb4099 100644 --- a/src/cogkit/finetune/base/base_trainer.py +++ b/src/cogkit/finetune/base/base_trainer.py @@ -1,40 +1,48 @@ # -*- coding: utf-8 -*- - - import json -import logging import math +import os from abc import ABC, abstractmethod -from datetime import timedelta +from contextlib import nullcontext +from functools import partial from pathlib import Path +from typing import Any -import diffusers import torch -import transformers -from accelerate.accelerator import Accelerator, DistributedType -from accelerate.logging import get_logger -from accelerate.utils import ( - DistributedDataParallelKwargs, - InitProcessGroupKwargs, - ProjectConfiguration, - set_seed, +import torch.distributed as dist +import torch.distributed.checkpoint as dcp +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.fully_sharded_data_parallel import ( + BackwardPrefetch, + CPUOffload, + MixedPrecision, + ShardingStrategy, ) -from diffusers.optimization import get_scheduler +from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy from torch.utils.data import DataLoader, Dataset from tqdm import tqdm from cogkit.finetune.base import BaseArgs, BaseComponents, BaseState -from cogkit.utils.lora import inject_lora, save_lora +from cogkit.finetune.logger import get_logger +from cogkit.utils import inject_lora, set_global_seed, save_lora from ..utils import ( + AppState, + WandbTracker, cast_training_params, + check_distributed, + delete_files, free_memory, - get_latest_ckpt_path_to_resume_from, + get_device, + get_global_rank, + get_global_step, + get_local_rank, get_memory_statistics, - get_optimizer, - unwrap_model, - find_files, - delete_files, + get_world_size, + is_main_process, + list_files, + mkdir, ) _DTYPE_MAP = { @@ -51,16 +59,31 @@ class BaseTrainer(ABC): Note: This class assumes that only `transformer` module is needed to be trained. 
""" - LOG_NAME: str = "BaseTrainer" - LOG_LEVEL: str = "INFO" - # If set, should be a list of components to unload (refer to `Components``) - # `transformer` is always in UNLOAD_LIST UNLOAD_LIST: list[str] | None = None - def __init__(self) -> None: - self.logger = get_logger(self.LOG_NAME, self.LOG_LEVEL) - self.accelerator: Accelerator = None + MODEL_STATE_DICT_FNAME = "model_state_dict.safetensors" + OPTIM_STATE_DICT_FNAME = "optim_state_dict.safetensors" + + def __init__(self, uargs_fpath: str | Path) -> None: + os.environ["TOKENIZERS_PARALLELISM"] = "false" + if isinstance(uargs_fpath, str): + uargs_fpath = Path(uargs_fpath) + + self.uargs = self._init_args(uargs_fpath) + + self._init_distributed() + self._init_directories() + + self.logger = get_logger( + name=self.uargs.name4train, + log_file=self.uargs.output_dir / f"{self.uargs.name4train}.log", + level=self.uargs.log_level, + ) + + if self.uargs.seed is not None: + set_global_seed(self.uargs.seed) + self.train_dataset: Dataset = None self.test_dataset: Dataset = None self.train_data_loader: DataLoader = None @@ -68,77 +91,62 @@ def __init__(self) -> None: self.optimizer = None self.lr_scheduler = None - self.args = self._init_args() self.state = self._init_state() + self.components = self.load_components() + self.tracker = None + if self.uargs.report_to is not None: + self.tracker = WandbTracker( + name=self.uargs.name4train, + config=self.uargs.model_dump(), + ) + self.check_setting() - self._init_distributed() - self._init_logging() - self._init_directories() + def _init_distributed(self) -> None: + dist.init_process_group(backend="nccl", timeout=self.uargs.nccl_timeout) + torch.cuda.set_device(get_local_rank()) - self.components = self.load_components() + def _init_directories(self) -> None: + mkdir(self.uargs.output_dir) - self.state.using_deepspeed = self.accelerator.state.deepspeed_plugin is not None + def _init_args(self, uargs_fpath: Path) -> BaseArgs: + return BaseArgs.parse_from_yaml(uargs_fpath) - def _init_distributed(self): - logging_dir = Path(self.args.output_dir, "logs") - project_config = ProjectConfiguration( - project_dir=self.args.output_dir, logging_dir=logging_dir - ) - ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) - init_process_group_kwargs = InitProcessGroupKwargs( - backend="nccl", timeout=timedelta(seconds=self.args.nccl_timeout) - ) - mixed_precision = "no" if torch.backends.mps.is_available() else self.args.mixed_precision - report_to = None if self.args.report_to.lower() == "none" else self.args.report_to - - accelerator = Accelerator( - project_config=project_config, - gradient_accumulation_steps=self.args.gradient_accumulation_steps, - mixed_precision=mixed_precision, - log_with=report_to, - kwargs_handlers=[ddp_kwargs, init_process_group_kwargs], + def _init_state(self) -> BaseState: + return BaseState( + world_size=get_world_size(), + local_rank=get_local_rank(), + global_rank=get_global_rank(), + device=get_device(), + weight_dtype=_DTYPE_MAP[self.uargs.mixed_precision], ) - # Disable AMP for MPS. 
- if torch.backends.mps.is_available(): - accelerator.native_amp = False + def fit(self) -> None: + self.logger.info("Checking settings...") + self.check_setting() - self.accelerator = accelerator + self.logger.info("Initializing models...") + self.prepare_models() - tracker_name = self.args.tracker_name - self.accelerator.init_trackers( - project_name=tracker_name, - init_kwargs={"wandb": {"name": self.args.output_dir.name}}, - ) + self.logger.info("Initializing dataset and dataloader...") + self.prepare_dataset() - if self.args.seed is not None: - set_seed(self.args.seed) + self.logger.info("Initializing trainable parameters...") + self.prepare_trainable_parameters() - def _init_logging(self) -> None: - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=self.LOG_LEVEL, - ) - if self.accelerator.is_local_main_process: - transformers.utils.logging.set_verbosity_warning() - diffusers.utils.logging.set_verbosity_info() - else: - transformers.utils.logging.set_verbosity_error() - diffusers.utils.logging.set_verbosity_error() + self.logger.info("Preparing model...") + self.prepare_model() - self.logger.info("Initialized Trainer") - self.logger.info( - f"Accelerator state: \n{self.accelerator.state}", - main_process_only=False, - ) + self.logger.info("Initializing optimizer and lr scheduler...") + self.prepare_optimizer() - def _init_directories(self) -> None: - if self.accelerator.is_main_process: - self.args.output_dir = Path(self.args.output_dir) - self.args.output_dir.mkdir(parents=True, exist_ok=True) + self.logger.info("Starting training...") + self.train() + + self.logger.info("Cleaning up...") + self.cleanup() def check_setting(self) -> None: + check_distributed() # Check for `UNLOAD_LIST` if self.UNLOAD_LIST is None: self.logger.warning( @@ -150,37 +158,74 @@ def check_setting(self) -> None: raise ValueError(f"Invalid component name in unload_list: {name}") def prepare_trainable_parameters(self) -> None: - # For mixed precision training we cast all non-trainable weights to half-precision - # as these weights are only used for inference, keeping weights in full precision is not required. - weight_dtype = self.state.weight_dtype - - if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: - # due to pytorch#99272, MPS does not yet support bfloat16. - raise ValueError( - "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." 
-        )
-
         # For LoRA, we freeze all the parameters
         # For SFT, we train all the parameters in transformer model
         for attr_name, component in vars(self.components).items():
             if hasattr(component, "requires_grad_"):
-                if self.args.training_type == "sft" and attr_name == "transformer":
+                if self.uargs.training_type == "sft" and attr_name == "transformer":
                     component.requires_grad_(True)
                 else:
                     component.requires_grad_(False)

-        if self.args.training_type == "lora":
+        if self.uargs.training_type == "lora":
             # Initialize LoRA weights
             inject_lora(self.components.transformer, lora_dir_or_state_dict=None)

-        self.prepare_saving_loading_hooks()

-        if self.args.gradient_checkpointing:
+        if self.uargs.gradient_checkpointing:
             self.components.transformer.enable_gradient_checkpointing()

-    def prepare_optimizer(self) -> None:
-        # Make sure the trainable params are in float32
-        # cast_training_params([self.components.transformer], dtype=torch.float32)
+        # cast all trainable params to the specified data type (bf16)
+        cast_training_params(self.components.transformer, dtype=self.state.weight_dtype)
+
+    def prepare_model(self) -> None:
+        match self.uargs.strategy:
+            case "NO_SHARD":
+                sharding_strategy = ShardingStrategy.NO_SHARD
+            case "SHARD_GRAD_OP":
+                sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
+            case "FULL_SHARD":
+                sharding_strategy = ShardingStrategy.FULL_SHARD
+            case "HYBRID_SHARD":
+                sharding_strategy = ShardingStrategy.HYBRID_SHARD
+
+        if self.uargs.strategy != "DDP":
+            wrap_policy = partial(
+                size_based_auto_wrap_policy,
+                min_num_params=int(1e8),
+            )
+
+            self.components.transformer = FSDP(
+                module=self.components.transformer,
+                device_id=self.state.local_rank,
+                sharding_strategy=sharding_strategy,
+                auto_wrap_policy=wrap_policy,
+                cpu_offload=CPUOffload(offload_params=self.uargs.offload_params_grads),
+                mixed_precision=MixedPrecision(
+                    param_dtype=self.state.weight_dtype,
+                    reduce_dtype=self.state.weight_dtype,
+                ),
+                backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
+                use_orig_params=True if self.uargs.training_type == "lora" else False,
+            )
+        else:
+            # When using QLoRA (low_vram), the quantized model has already been moved to the device
+            if not self.uargs.low_vram:
+                self.components.transformer = self.components.transformer.to(self.state.device)
+
+            self.components.transformer = DDP(
+                module=self.components.transformer,
+                device_ids=[self.state.local_rank],
+            )
+
+        # Load components needed for training to GPU, and cast them to the specified data type
+        ignore_list = self.UNLOAD_LIST
+        self.move_components_to_device(
+            dtype=self.state.weight_dtype,
+            device=self.state.device,
+            ignore_list=ignore_list + ["transformer"],
+        )
+
+    def prepare_optimizer(self) -> None:
         # For LoRA, we only want to train the LoRA weights
         # For SFT, we want to train all the parameters
         trainable_parameters = list(
             filter(
@@ -191,296 +236,155 @@ def prepare_optimizer(self) -> None:
             )
         )
         transformer_parameters_with_lr = {
             "params": trainable_parameters,
-            "lr": self.args.learning_rate,
+            "lr": self.uargs.learning_rate,
         }
         params_to_optimize = [transformer_parameters_with_lr]
         self.state.num_trainable_parameters = sum(p.numel() for p in trainable_parameters)

-        use_deepspeed_opt = (
-            self.accelerator.state.deepspeed_plugin is not None
-            and "optimizer" in self.accelerator.state.deepspeed_plugin.deepspeed_config
-        )
-        optimizer = get_optimizer(
-            params_to_optimize=params_to_optimize,
-            logger=self.logger,
-            optimizer_name=self.args.optimizer,
-            learning_rate=self.args.learning_rate,
-            beta1=self.args.beta1,
-            beta2=self.args.beta2,
-            beta3=self.args.beta3,
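Condensed, the new `prepare_model` path picks an FSDP sharding strategy (or falls back to DDP) and shards the transformer with a size-based auto-wrap policy. A distilled sketch of that hunk, with the defaults made explicit; `use_orig_params=True` is what lets FSDP manage a module in which only the injected LoRA parameters require gradients:

    from functools import partial

    import torch
    from torch.distributed.fsdp import CPUOffload, MixedPrecision, ShardingStrategy
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

    def shard_transformer(transformer, local_rank: int, lora: bool,
                          offload: bool = False, dtype: torch.dtype = torch.bfloat16):
        # Submodules above ~100M parameters become their own FSDP units.
        policy = partial(size_based_auto_wrap_policy, min_num_params=int(1e8))
        return FSDP(
            transformer,
            device_id=local_rank,
            sharding_strategy=ShardingStrategy.FULL_SHARD,
            auto_wrap_policy=policy,
            cpu_offload=CPUOffload(offload_params=offload),
            mixed_precision=MixedPrecision(param_dtype=dtype, reduce_dtype=dtype),
            use_orig_params=lora,  # frozen base + trainable LoRA params coexist
        )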
epsilon=self.args.epsilon, - weight_decay=self.args.weight_decay, - use_deepspeed=use_deepspeed_opt, + optimizer = torch.optim.AdamW( + params=params_to_optimize, + lr=self.uargs.learning_rate, + betas=(self.uargs.beta1, self.uargs.beta2), + eps=self.uargs.epsilon, + weight_decay=self.uargs.weight_decay, ) - # Do not need to divide by num_gpus since acclerate will handle this after prepare lr_scheduler num_update_steps_per_epoch = math.ceil( - len(self.train_data_loader) / self.args.gradient_accumulation_steps + len(self.train_data_loader) / self.uargs.gradient_accumulation_steps ) - total_train_steps = self.args.train_epochs * num_update_steps_per_epoch - total_num_warmup_steps = max(int(total_train_steps * self.args.lr_warmup_ratio), 0) + total_train_steps = self.uargs.train_epochs * num_update_steps_per_epoch - use_deepspeed_lr_scheduler = ( - self.accelerator.state.deepspeed_plugin is not None - and "scheduler" in self.accelerator.state.deepspeed_plugin.deepspeed_config + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer=optimizer, + T_max=total_train_steps, ) - if use_deepspeed_lr_scheduler: - from accelerate.utils import DummyScheduler - - lr_scheduler = DummyScheduler( - name=self.args.lr_scheduler, - optimizer=optimizer, - total_num_steps=total_train_steps, - num_warmup_steps=total_num_warmup_steps, - ) - else: - lr_scheduler = get_scheduler( - name=self.args.lr_scheduler, - optimizer=optimizer, - num_warmup_steps=total_num_warmup_steps, - num_training_steps=total_train_steps, - num_cycles=self.args.lr_num_cycles, - power=self.args.lr_power, - ) - self.optimizer = optimizer self.lr_scheduler = lr_scheduler - def prepare_for_training(self) -> None: - # cast training params to the specified data type (bf16) - cast_training_params(self.components.transformer, dtype=self.state.weight_dtype) - - ( - self.components.transformer, - self.optimizer, - self.train_data_loader, - self.lr_scheduler, - ) = self.accelerator.prepare( - self.components.transformer, - self.optimizer, - self.train_data_loader, - self.lr_scheduler, - ) - - # Load components needed for training to GPU (except transformer), and cast them to the specified data type - ignore_list = self.UNLOAD_LIST - self.move_components_to_device( - dtype=self.state.weight_dtype, device=self.accelerator.device, ignore_list=ignore_list - ) - - if self.args.do_validation: - assert self.test_data_loader is not None - self.test_data_loader = self.accelerator.prepare_data_loader(self.test_data_loader) - + def train(self) -> None: # We need to recalculate our total training steps as the size of the training dataloader may have changed in distributed training num_update_steps_per_epoch = math.ceil( - len(self.train_data_loader) / self.args.gradient_accumulation_steps + len(self.train_data_loader) / self.uargs.gradient_accumulation_steps ) - self.args.train_steps = self.args.train_epochs * num_update_steps_per_epoch + self.state.train_steps = self.uargs.train_epochs * num_update_steps_per_epoch # Afterwards we recalculate our number of training epochs - self.args.train_epochs = math.ceil(self.args.train_steps / num_update_steps_per_epoch) + self.state.train_epochs = math.ceil(self.state.train_steps / num_update_steps_per_epoch) self.state.num_update_steps_per_epoch = num_update_steps_per_epoch - def train(self) -> None: memory_statistics = get_memory_statistics(self.logger) self.logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}") self.state.total_batch_size_count = ( - self.args.batch_size 
- * self.accelerator.num_processes - * self.args.gradient_accumulation_steps + self.uargs.batch_size * self.state.world_size * self.uargs.gradient_accumulation_steps ) info = { "trainable parameters": self.state.num_trainable_parameters, "total samples": len(self.train_dataset), - "train epochs": self.args.train_epochs, - "train steps": self.args.train_steps, - "batches per device": self.args.batch_size, + "train epochs": self.state.train_epochs, + "train steps": self.state.train_steps, + "batches per device": self.uargs.batch_size, "total batches observed per epoch": len(self.train_data_loader), "train batch size total count": self.state.total_batch_size_count, - "gradient accumulation steps": self.args.gradient_accumulation_steps, + "gradient accumulation steps": self.uargs.gradient_accumulation_steps, } self.logger.info(f"Training configuration: {json.dumps(info, indent=4)}") global_step = 0 - first_epoch = 0 - initial_global_step = 0 - + initial_epoch = 0 # Potentially load in the weights and states from a previous save - ( - resume_from_checkpoint_path, - initial_global_step, - global_step, - first_epoch, - ) = get_latest_ckpt_path_to_resume_from( - resume_from_checkpoint=self.args.resume_from_checkpoint, - num_update_steps_per_epoch=self.state.num_update_steps_per_epoch, - logger=self.logger, - ) - if resume_from_checkpoint_path is not None: - self.accelerator.load_state(resume_from_checkpoint_path) + if self.uargs.resume_from_checkpoint is not None: + self.logger.info(f"Resuming from checkpoint {self.uargs.resume_from_checkpoint}") + global_step = get_global_step(self.uargs.resume_from_checkpoint) + for _ in range(global_step): + self.lr_scheduler.step() + self.resume_from_checkpoint(self.uargs.resume_from_checkpoint) + initial_epoch = global_step // num_update_steps_per_epoch + for group in self.optimizer.param_groups: + group["lr"] = self.lr_scheduler.get_last_lr()[0] progress_bar = tqdm( - range(self.args.train_steps), - initial=initial_global_step, + range(self.state.train_steps), + initial=global_step, desc="Training steps", - disable=not self.accelerator.is_local_main_process, + disable=not is_main_process(), ) - accelerator = self.accelerator - generator = torch.Generator(device=accelerator.device) - if self.args.seed is not None: - generator = generator.manual_seed(self.args.seed) + generator = torch.Generator(device=self.state.device) + if self.uargs.seed is not None: + generator = generator.manual_seed(self.uargs.seed) self.state.generator = generator free_memory() ckpt_path = None - for epoch in range(first_epoch, self.args.train_epochs): - self.logger.debug(f"Starting epoch ({epoch + 1}/{self.args.train_epochs})") + for epoch in range(initial_epoch, self.uargs.train_epochs): + self.logger.debug(f"Starting epoch ({epoch + 1}/{self.uargs.train_epochs})") self.components.transformer.train() - models_to_accumulate = [self.components.transformer] for step, batch in enumerate(self.train_data_loader): - self.logger.debug(f"Starting step {step + 1}") - logs = {} - - with accelerator.accumulate(models_to_accumulate): - # These weighting schemes use a uniform timestep sampling and instead post-weight the loss - loss = self.compute_loss(batch) - accelerator.backward(loss) - - if accelerator.sync_gradients: - if accelerator.distributed_type == DistributedType.DEEPSPEED: - grad_norm = self.components.transformer.get_global_grad_norm() - # In some cases the grad norm may not return a float - if torch.is_tensor(grad_norm): - grad_norm = grad_norm.item() - else: - grad_norm = 
accelerator.clip_grad_norm_( - self.components.transformer.parameters(), - self.args.max_grad_norm, - ) - if torch.is_tensor(grad_norm): - grad_norm = grad_norm.item() - - logs["grad_norm"] = grad_norm - - self.optimizer.step() - self.lr_scheduler.step() - self.optimizer.zero_grad() - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) + self.logger.debug(f"Starting step {step + 1}, global step: {global_step}") + + is_sync_step = (step + 1) % self.uargs.gradient_accumulation_steps == 0 + is_last_step = (step + 1) == len(self.train_data_loader) + sync_grad = is_sync_step or is_last_step + + logs = self.train_step(batch, sync_grad=sync_grad) + + if sync_grad: global_step += 1 + progress_bar.update(1) + ckpt_path = self.maybe_save_checkpoint(global_step) - logs["loss"] = loss.detach().item() - logs["lr"] = self.lr_scheduler.get_last_lr()[0] progress_bar.set_postfix(logs) - # Maybe run validation - should_run_validation = ( - self.args.do_validation - and global_step % self.args.validation_steps == 0 - and accelerator.sync_gradients - ) - if should_run_validation: - del loss + if self.tracker is not None: + self.tracker.log(logs, step=global_step) + + if self.uargs.do_validation and global_step % self.uargs.validation_steps == 0: free_memory() self.validate(global_step, ckpt_path=ckpt_path) - accelerator.log(logs, step=global_step) - - if global_step >= self.args.train_steps: - break - - memory_statistics = get_memory_statistics(self.logger) + memory_statistics = get_memory_statistics(self.state.device) self.logger.info( f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}" ) - accelerator.wait_for_everyone() - ckpt_path = self.maybe_save_checkpoint(global_step, must_save=True) - if self.args.do_validation: - free_memory() - self.validate(global_step, ckpt_path=ckpt_path) + def train_step(self, batch: dict[str, Any], sync_grad: bool) -> dict[str, Any]: + logs = {} - del self.components - free_memory() - memory_statistics = get_memory_statistics(self.logger) - self.logger.info(f"Memory after training end: {json.dumps(memory_statistics, indent=4)}") + sync_context = self.components.transformer.no_sync() if not sync_grad else nullcontext() - accelerator.end_training() + with sync_context: + loss = self.compute_loss(batch) + loss = loss / self.uargs.gradient_accumulation_steps + loss.backward() - def fit(self) -> None: - self.logger.info("Checking settings...") - self.check_setting() - - self.logger.info("Initializing models...") - self.prepare_models() - - self.logger.info("Initializing dataset and dataloader...") - self.prepare_dataset() - - self.logger.info("Initializing trainable parameters...") - self.prepare_trainable_parameters() - - self.logger.info("Initializing optimizer and lr scheduler...") - self.prepare_optimizer() - - self.logger.info("Preparing for training...") - self.prepare_for_training() - - self.logger.info("Starting training...") - self.train() - - @abstractmethod - def _init_args(self) -> BaseArgs: - raise NotImplementedError - - @abstractmethod - def _init_state(self) -> BaseState: - raise NotImplementedError + if sync_grad: + if self.uargs.strategy != "DDP": + grad_norm = self.components.transformer.clip_grad_norm_( + max_norm=self.uargs.max_grad_norm + ) + else: + grad_norm = torch.nn.utils.clip_grad_norm_( + self.components.transformer.parameters(), + max_norm=self.uargs.max_grad_norm, + ) + self.optimizer.step() + self.lr_scheduler.step() + 
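As an aside, the accumulation logic in `train_step` is the standard `no_sync()` pattern shared by DDP and FSDP: gradients are all-reduced only on the micro-batch whose `backward()` runs outside `no_sync()`. Distilled into a standalone sketch (every name below is a placeholder, not from this patch):

    from contextlib import nullcontext

    def run_epoch(model, loader, optimizer, accum_steps: int) -> None:
        for i, batch in enumerate(loader):
            sync = (i + 1) % accum_steps == 0 or (i + 1) == len(loader)
            with nullcontext() if sync else model.no_sync():
                loss = model(**batch).loss / accum_steps  # pre-scale for accumulation
                loss.backward()  # grads are synchronized only when sync is True
            if sync:
                optimizer.step()
                optimizer.zero_grad()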
self.optimizer.zero_grad() - @abstractmethod - def load_components(self) -> BaseComponents: - # note: `self.components.transformer`(model needs to be trained) - # and `self.components.pipeline_cls` must be defined - raise NotImplementedError + loss = loss.detach() + dist.all_reduce(grad_norm.to(self.state.device), op=dist.ReduceOp.AVG) + dist.all_reduce(loss.to(self.state.device), op=dist.ReduceOp.AVG) - @abstractmethod - def prepare_models(self) -> None: - # Doing something like `self.components.vae.enable_slicing()` - raise NotImplementedError + logs["grad_norm"] = grad_norm.item() + logs["loss"] = loss.item() + logs["lr"] = self.lr_scheduler.get_last_lr()[0] + del loss # release graph - @abstractmethod - def prepare_dataset(self) -> None: - # initialize `self.train_dataset` and `self.train_data_loader` - # initialize `self.test_dataset` and `self.test_data_loader` if `self.args.do_validation` is True - raise NotImplementedError - - @abstractmethod - def compute_loss(self, batch) -> torch.Tensor: - raise NotImplementedError - - @abstractmethod - def validate(self, step: int, ckpt_path: str | None = None) -> None: - # validation logic defined here - # during validation, additional modules in the pipeline may need to be moved to GPU memory - raise NotImplementedError - - def get_training_dtype(self) -> torch.dtype: - if self.args.mixed_precision == "no": - return _DTYPE_MAP["fp32"] - elif self.args.mixed_precision == "fp16": - return _DTYPE_MAP["fp16"] - elif self.args.mixed_precision == "bf16": - return _DTYPE_MAP["bf16"] - else: - raise ValueError(f"Invalid mixed precision: {self.args.mixed_precision}") + return logs def move_components_to_device(self, dtype, device, ignore_list: list[str] = []): ignore_list = set(ignore_list) @@ -497,61 +401,89 @@ def move_components_to_device(self, dtype, device, ignore_list: list[str] = []): component.to(device, dtype=dtype), ) - def prepare_saving_loading_hooks(self): - # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format - def save_model_hook(models, weights, output_dir): - assert self.accelerator.distributed_type != DistributedType.DEEPSPEED - - for model in models: - original_model = unwrap_model(self.accelerator, model) - original_transformer = unwrap_model(self.accelerator, self.components.transformer) - if isinstance(original_model, type(original_transformer)): - if self.accelerator.is_main_process: - save_lora(model, output_dir) - else: - raise ValueError(f"Unexpected save model: {model.__class__}") - - # make sure to pop weight so that corresponding model is not saved again - if weights: - weights.pop() - - def load_model_hook(models, input_dir): - assert self.accelerator.distributed_type != DistributedType.DEEPSPEED - - for model in models: - original_model = unwrap_model(self.accelerator, model) - original_transformer = unwrap_model(self.accelerator, self.components.transformer) - if isinstance(original_model, type(original_transformer)): - inject_lora(model, input_dir) - else: - raise ValueError(f"Unexpected save model: {model.__class__}") - - self.accelerator.register_save_state_pre_hook(save_model_hook) - self.accelerator.register_load_state_pre_hook(load_model_hook) - def maybe_save_checkpoint(self, global_step: int, must_save: bool = False) -> str | None: - if not (must_save or global_step % self.args.checkpointing_steps == 0): + if not (must_save or global_step % self.uargs.checkpointing_steps == 0): return None - checkpointing_limit = self.args.checkpointing_limit - output_dir = 
Path(self.args.output_dir) + checkpointing_limit = self.uargs.checkpointing_limit + output_dir = Path(self.uargs.output_dir) logger = self.logger if checkpointing_limit is not None: - checkpoints = find_files(output_dir, prefix="checkpoint") + checkpoints = list_files(output_dir, prefix="checkpoint") + + def get_checkpoint_number(path): + try: + return int(Path(path).name.split("-")[1]) + except (IndexError, ValueError): + raise ValueError(f"Invalid checkpoint path: {path}") + + checkpoints.sort(key=get_checkpoint_number) # before we save the new checkpoint, we need to have at_most `checkpoints_total_limit - 1` checkpoints if len(checkpoints) >= checkpointing_limit: num_to_remove = len(checkpoints) - checkpointing_limit + 1 checkpoints_to_remove = checkpoints[0:num_to_remove] - if self.accelerator.is_main_process: - delete_files(checkpoints_to_remove, logger) + delete_files(checkpoints_to_remove) - logger.info(f"Checkpointing at step {global_step}") - save_path = output_dir / f"checkpoint-{global_step}" - logger.info(f"Saving state to {save_path}") + save_dir = output_dir / f"checkpoint-{global_step}" + mkdir(save_dir) + logger.info(f"Checkpointing at step {global_step}, saving state to {save_dir} ...") - self.accelerator.save_state(save_path, safe_serialization=True) + saved_model = self.unwrap_model(self.components.transformer) - self.accelerator.wait_for_everyone() - return save_path + state_dict = { + "app": AppState(saved_model, self.optimizer, lora=self.uargs.training_type == "lora") + } + if not self.uargs.low_vram: + dcp.save(state_dict, checkpoint_id=str(save_dir)) + else: + if is_main_process(): + save_lora(saved_model, save_dir) + + return save_dir + + def resume_from_checkpoint(self, ckpt_dir: str | Path) -> None: + transformer = self.unwrap_model(self.components.transformer) + state_dict = { + "app": AppState(transformer, self.optimizer, lora=self.uargs.training_type == "lora") + } + dcp.load(state_dict, checkpoint_id=str(ckpt_dir)) + + def cleanup(self) -> None: + dist.destroy_process_group() + if self.tracker is not None: + self.tracker.finish() + + def unwrap_model(self, model: Any) -> Any: + if self.uargs.strategy == "DDP": + return model.module + else: + return model + + @abstractmethod + def load_components(self) -> BaseComponents: + # note: `self.components.transformer`(model needs to be trained) + # and `self.components.pipeline_cls` must be defined + raise NotImplementedError + + @abstractmethod + def prepare_models(self) -> None: + # Doing something like `self.components.vae.enable_slicing()` + raise NotImplementedError + + @abstractmethod + def prepare_dataset(self) -> None: + # initialize `self.train_dataset` and `self.train_data_loader` + # initialize `self.test_dataset` and `self.test_data_loader` if `self.uargs.do_validation` is True + raise NotImplementedError + + @abstractmethod + def compute_loss(self, batch: dict[str, Any]) -> torch.Tensor: + raise NotImplementedError + + @abstractmethod + def validate(self, step: int, ckpt_path: str | None = None) -> None: + # validation logic defined here + # during validation, additional modules in the pipeline may need to be moved to GPU memory + raise NotImplementedError From 62202ea13bdd63eb68a9edbf63d3b6542ebc02dc Mon Sep 17 00:00:00 2001 From: OleehyO Date: Tue, 6 May 2025 10:33:06 +0000 Subject: [PATCH 12/19] [dataset] Refactor --- src/cogkit/finetune/datasets/__init__.py | 6 ++-- src/cogkit/finetune/datasets/i2v_dataset.py | 20 +++++------- src/cogkit/finetune/datasets/t2i_dataset.py | 11 +++---- 
src/cogkit/finetune/datasets/t2v_dataset.py | 21 ++++-------- src/cogkit/finetune/datasets/utils.py | 36 ++++++++------------- 5 files changed, 37 insertions(+), 57 deletions(-) diff --git a/src/cogkit/finetune/datasets/__init__.py b/src/cogkit/finetune/datasets/__init__.py index e440d41..d8cc16d 100644 --- a/src/cogkit/finetune/datasets/__init__.py +++ b/src/cogkit/finetune/datasets/__init__.py @@ -1,9 +1,9 @@ # -*- coding: utf-8 -*- -from cogkit.datasets.i2v_dataset import BaseI2VDataset, I2VDatasetWithResize -from cogkit.datasets.t2v_dataset import BaseT2VDataset, T2VDatasetWithResize -from cogkit.datasets.t2i_dataset import ( +from .i2v_dataset import BaseI2VDataset, I2VDatasetWithResize +from .t2v_dataset import BaseT2VDataset, T2VDatasetWithResize +from .t2i_dataset import ( T2IDatasetWithFactorResize, T2IDatasetWithResize, T2IDatasetWithPacking, diff --git a/src/cogkit/finetune/datasets/i2v_dataset.py b/src/cogkit/finetune/datasets/i2v_dataset.py index 001be49..14d4d77 100644 --- a/src/cogkit/finetune/datasets/i2v_dataset.py +++ b/src/cogkit/finetune/datasets/i2v_dataset.py @@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any, Tuple import torch -from accelerate.logging import get_logger from datasets import load_dataset from PIL import Image from safetensors.torch import load_file, save_file @@ -14,7 +13,7 @@ from torchvision.io import VideoReader from typing_extensions import override -from cogkit.finetune.diffusion.constants import LOG_LEVEL, LOG_NAME +from cogkit.finetune.logger import get_logger from .utils import ( get_prompt_embedding, @@ -25,7 +24,7 @@ if TYPE_CHECKING: from cogkit.finetune.diffusion.trainer import DiffusionTrainer -logger = get_logger(LOG_NAME, LOG_LEVEL) +_logger = get_logger() class BaseI2VDataset(Dataset): @@ -84,7 +83,7 @@ def update_with_image(video_example, idx): self.data = video_data.map(update_with_image, with_indices=True) else: - logger.warning( + _logger.warning( f"No image data found in {self.data_root}, using first frame of video instead" ) @@ -116,7 +115,7 @@ def __getitem__(self, index: int) -> dict[str, Any]: ##### prompt prompt = self.data[index]["prompt"] - prompt_embedding = get_prompt_embedding(self.encode_text, prompt, cache_dir, logger) + prompt_embedding = get_prompt_embedding(self.encode_text, prompt, cache_dir) ##### image image_preprocessed = self.data[index]["image"] @@ -137,10 +136,10 @@ def __getitem__(self, index: int) -> dict[str, Any]: ##### video video = self.data[index]["video"] video_path = Path(video._hf_encoded["path"]) - train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution) + train_resolution_str = "x".join(str(x) for x in self.trainer.uargs.train_resolution) video_latent_dir = ( - cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str + cache_dir / "video_latent" / self.trainer.uargs.model_name / train_resolution_str ) video_latent_dir.mkdir(parents=True, exist_ok=True) @@ -148,7 +147,7 @@ def __getitem__(self, index: int) -> dict[str, Any]: if encoded_video_path.exists(): encoded_video = load_file(encoded_video_path)["encoded_video"] - logger.debug(f"Loaded encoded video from {encoded_video_path}", main_process_only=False) + _logger.debug(f"Loaded encoded video from {encoded_video_path}") else: frames, _ = self.preprocess(video, None, self.device) # Current shape of frames: [F, C, H, W] @@ -162,10 +161,7 @@ def __getitem__(self, index: int) -> dict[str, Any]: encoded_video = encoded_video[0] encoded_video = encoded_video.to("cpu") save_file({"encoded_video": 
encoded_video}, encoded_video_path) - logger.info( - f"Saved encoded video to {encoded_video_path}", - main_process_only=False, - ) + _logger.info(f"Saved encoded video to {encoded_video_path}") # shape of encoded_video: [C, F, H, W] # shape of image: [C, H, W] diff --git a/src/cogkit/finetune/datasets/t2i_dataset.py b/src/cogkit/finetune/datasets/t2i_dataset.py index c67f149..383f522 100644 --- a/src/cogkit/finetune/datasets/t2i_dataset.py +++ b/src/cogkit/finetune/datasets/t2i_dataset.py @@ -4,26 +4,25 @@ import torch import torchvision.transforms as transforms -from accelerate.logging import get_logger from datasets import load_dataset from PIL import Image from torch.utils.data import Dataset from typing_extensions import override -from cogkit.finetune.diffusion.constants import LOG_LEVEL, LOG_NAME +from cogkit.finetune.logger import get_logger from .utils import ( + calculate_resize_dimensions, get_image_embedding, get_prompt_embedding, pil2tensor, preprocess_image_with_resize, - calculate_resize_dimensions, ) if TYPE_CHECKING: from cogkit.finetune.diffusion.trainer import DiffusionTrainer -logger = get_logger(LOG_NAME, LOG_LEVEL) +_logger = get_logger() class BaseT2IDataset(Dataset): @@ -80,7 +79,7 @@ def __getitem__(self, index: int) -> dict[str, Any]: ##### prompt prompt = self.data[index]["prompt"] - prompt_embedding = get_prompt_embedding(self.encode_text, prompt, cache_dir, logger) + prompt_embedding = get_prompt_embedding(self.encode_text, prompt, cache_dir) if not self.using_train: return { @@ -100,7 +99,7 @@ def encode_fn(image: Image.Image) -> torch.Tensor: return encoded_image # shape of encoded_image: [C, H, W] - encoded_image = get_image_embedding(encode_fn, image, cache_dir, logger) + encoded_image = get_image_embedding(encode_fn, image, cache_dir) # shape of image: [C, H, W] return { diff --git a/src/cogkit/finetune/datasets/t2v_dataset.py b/src/cogkit/finetune/datasets/t2v_dataset.py index 36ba644..57248c0 100644 --- a/src/cogkit/finetune/datasets/t2v_dataset.py +++ b/src/cogkit/finetune/datasets/t2v_dataset.py @@ -2,7 +2,6 @@ from typing import TYPE_CHECKING, Any import torch -from accelerate.logging import get_logger from datasets import load_dataset from safetensors.torch import load_file, save_file from torch.utils.data import Dataset @@ -10,14 +9,14 @@ from torchvision.io import VideoReader from typing_extensions import override -from cogkit.finetune.diffusion.constants import LOG_LEVEL, LOG_NAME +from cogkit.finetune.logger import get_logger from .utils import get_prompt_embedding, preprocess_video_with_resize if TYPE_CHECKING: from cogkit.finetune.diffusion.trainer import DiffusionTrainer -logger = get_logger(LOG_NAME, LOG_LEVEL) +_logger = get_logger() class BaseT2VDataset(Dataset): @@ -66,7 +65,7 @@ def __getitem__(self, index: int) -> dict[str, Any]: ##### prompt prompt = self.data[index]["prompt"] - prompt_embedding = get_prompt_embedding(self.encode_text, prompt, cache_dir, logger) + prompt_embedding = get_prompt_embedding(self.encode_text, prompt, cache_dir) if not self.using_train: return { @@ -78,20 +77,17 @@ def __getitem__(self, index: int) -> dict[str, Any]: video = self.data[index]["video"] video_path = Path(video._hf_encoded["path"]) - train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution) + train_resolution_str = "x".join(str(x) for x in self.trainer.uargs.train_resolution) video_latent_dir = ( - cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str + cache_dir / "video_latent" / 
self.trainer.uargs.model_name / train_resolution_str ) video_latent_dir.mkdir(parents=True, exist_ok=True) encoded_video_path = video_latent_dir / (video_path.stem + ".safetensors") if encoded_video_path.exists(): encoded_video = load_file(encoded_video_path)["encoded_video"] - logger.debug( - f"Loaded encoded video from {encoded_video_path}", - main_process_only=False, - ) + _logger.debug(f"Loaded encoded video from {encoded_video_path}") else: frames = self.preprocess(video, self.device) # Current shape of frames: [F, C, H, W] @@ -105,10 +101,7 @@ def __getitem__(self, index: int) -> dict[str, Any]: encoded_video = encoded_video[0] encoded_video = encoded_video.to("cpu") save_file({"encoded_video": encoded_video}, encoded_video_path) - logger.info( - f"Saved encoded video to {encoded_video_path}", - main_process_only=False, - ) + _logger.info(f"Saved encoded video to {encoded_video_path}") return { "prompt": prompt, diff --git a/src/cogkit/finetune/datasets/utils.py b/src/cogkit/finetune/datasets/utils.py index fb6e87e..588ff64 100644 --- a/src/cogkit/finetune/datasets/utils.py +++ b/src/cogkit/finetune/datasets/utils.py @@ -1,5 +1,4 @@ import hashlib -import logging import math from pathlib import Path from typing import Callable @@ -12,6 +11,10 @@ from safetensors.torch import load_file, save_file from torchvision.io import VideoReader +from cogkit.finetune.logger import get_logger + +_logger = get_logger() + ########## loaders ########## @@ -55,7 +58,7 @@ def load_images_from_videos(videos_path: list[Path]) -> list[Path]: # Save frame as PNG with same name as video cv2.imwrite(str(frame_path), frame) - logging.info(f"Saved first frame to {frame_path}") + _logger.info(f"Saved first frame to {frame_path}") # Release video capture cap.release() @@ -176,16 +179,13 @@ def preprocess_video_with_resize( ########## embedding & caching ########## -def get_prompt_embedding( - encode_fn: Callable, prompt: str, cache_dir: Path, logger: logging.Logger -) -> torch.Tensor: +def get_prompt_embedding(encode_fn: Callable, prompt: str, cache_dir: Path) -> torch.Tensor: """Get prompt embedding from cache or create new one if not exists. Args: encode_fn: Function to project prompt to embedding. 
prompt: Text prompt to be embedded cache_dir: Base directory for caching embeddings - logger: Logger instance for logging messages Returns: torch.Tensor: Prompt embedding with shape [seq_len, hidden_size] @@ -200,9 +200,9 @@ def get_prompt_embedding( with lock: if prompt_embedding_path.exists(): prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"] - logger.debug( + _logger.debug( f"Loaded prompt embedding from {prompt_embedding_path}", - main_process_only=False, + main_only=False, ) else: prompt_embedding = encode_fn(prompt) @@ -211,22 +211,20 @@ def get_prompt_embedding( prompt_embedding = prompt_embedding.to("cpu") save_file({"prompt_embedding": prompt_embedding}, prompt_embedding_path) - logger.info( + _logger.info( f"Saved prompt embedding to {prompt_embedding_path}", - main_process_only=False, + main_only=False, ) return prompt_embedding -def get_image_embedding( - encode_fn: Callable, image: Image.Image, cache_dir: Path, logger: logging.Logger -) -> torch.Tensor: +def get_image_embedding(encode_fn: Callable, image: Image.Image, cache_dir: Path) -> torch.Tensor: encoded_images_dir = cache_dir / "encoded_images" encoded_images_dir.mkdir(parents=True, exist_ok=True) if not hasattr(image, "filename"): - logger.warning("Image object does not have filename attribute, skipping caching.") + _logger.warning("Image object does not have filename attribute, skipping caching.") return encode_fn(image.convert("RGB")).to("cpu") filename = Path(image.filename).stem @@ -235,18 +233,12 @@ def get_image_embedding( if encoded_image_path.exists(): encoded_image = load_file(encoded_image_path)["encoded_image"] - logger.debug( - f"Loaded encoded image from {encoded_image_path}", - main_process_only=False, - ) + _logger.debug(f"Loaded encoded image from {encoded_image_path}") else: encoded_image = encode_fn(image.convert("RGB")) encoded_image = encoded_image.to("cpu") save_file({"encoded_image": encoded_image}, encoded_image_path) - logger.info( - f"Saved encoded image to {encoded_image_path}", - main_process_only=False, - ) + _logger.info(f"Saved encoded image to {encoded_image_path}") return encoded_image From 4faeb4d0cf53364435c6cb8546936939a50f3afd Mon Sep 17 00:00:00 2001 From: OleehyO Date: Tue, 6 May 2025 10:34:57 +0000 Subject: [PATCH 13/19] [trainer] Refactor for FSDP training --- src/cogkit/finetune/diffusion/constants.py | 2 - .../cogvideo/cogvideox_i2v/lora_trainer.py | 52 ++-- .../cogvideo/cogvideox_t2v/lora_trainer.py | 46 +-- .../models/cogview/cogview4/lora_trainer.py | 47 ++-- .../cogview/cogview4/lora_trainer_packing.py | 18 +- src/cogkit/finetune/diffusion/schemas/args.py | 51 +--- .../finetune/diffusion/schemas/state.py | 8 +- src/cogkit/finetune/diffusion/trainer.py | 265 +++++++++--------- 8 files changed, 226 insertions(+), 263 deletions(-) delete mode 100644 src/cogkit/finetune/diffusion/constants.py diff --git a/src/cogkit/finetune/diffusion/constants.py b/src/cogkit/finetune/diffusion/constants.py deleted file mode 100644 index f8c163d..0000000 --- a/src/cogkit/finetune/diffusion/constants.py +++ /dev/null @@ -1,2 +0,0 @@ -LOG_NAME = "DiffusionTrainer" -LOG_LEVEL = "INFO" diff --git a/src/cogkit/finetune/diffusion/models/cogvideo/cogvideox_i2v/lora_trainer.py b/src/cogkit/finetune/diffusion/models/cogvideo/cogvideox_i2v/lora_trainer.py index 4311014..6cec3f6 100644 --- a/src/cogkit/finetune/diffusion/models/cogvideo/cogvideox_i2v/lora_trainer.py +++ b/src/cogkit/finetune/diffusion/models/cogvideo/cogvideox_i2v/lora_trainer.py @@ -4,22 +4,21 @@ from typing import 
Any import torch -from diffusers import ( - AutoencoderKLCogVideoX, - CogVideoXDPMScheduler, - CogVideoXImageToVideoPipeline, - CogVideoXTransformer3DModel, -) -from diffusers.models.embeddings import get_3d_rotary_pos_embed from PIL import Image -from transformers import AutoTokenizer, T5EncoderModel, BitsAndBytesConfig +from transformers import AutoTokenizer, BitsAndBytesConfig, T5EncoderModel from typing_extensions import override from cogkit.finetune import register from cogkit.finetune.diffusion.schemas import DiffusionComponents from cogkit.finetune.diffusion.trainer import DiffusionTrainer -from cogkit.finetune.utils import unwrap_model from cogkit.utils import load_lora_checkpoint, unload_lora_checkpoint +from diffusers import ( + AutoencoderKLCogVideoX, + CogVideoXDPMScheduler, + CogVideoXImageToVideoPipeline, + CogVideoXTransformer3DModel, +) +from diffusers.models.embeddings import get_3d_rotary_pos_embed class CogVideoXI2VLoraTrainer(DiffusionTrainer): @@ -37,7 +36,7 @@ def load_components(self) -> DiffusionComponents: dtype = self.state.weight_dtype components = DiffusionComponents() - model_path = str(self.args.model_path) + model_path = str(self.uargs.model_path) ### pipeline components.pipeline_cls = CogVideoXImageToVideoPipeline @@ -53,7 +52,7 @@ def load_components(self) -> DiffusionComponents: ) ### transformer - if not self.args.low_vram: + if not self.uargs.low_vram: components.transformer = CogVideoXTransformer3DModel.from_pretrained( model_path, subfolder="transformer", @@ -64,7 +63,7 @@ def load_components(self) -> DiffusionComponents: model_path, subfolder="transformer", quantization_config=nf4_config, - device=self.accelerator.device, + device=self.state.device, torch_dtype=dtype, ) @@ -84,18 +83,18 @@ def load_components(self) -> DiffusionComponents: @override def initialize_pipeline(self, ckpt_path: str | None = None) -> CogVideoXImageToVideoPipeline: - if not self.args.low_vram: + if not self.uargs.low_vram: pipe = CogVideoXImageToVideoPipeline( tokenizer=self.components.tokenizer, text_encoder=self.components.text_encoder, vae=self.components.vae, - transformer=unwrap_model(self.accelerator, self.components.transformer), + transformer=self.unwrap_model(self.components.transformer), scheduler=self.components.scheduler, ) else: - assert self.args.training_type == "lora" + assert self.uargs.training_type == "lora" transformer = CogVideoXTransformer3DModel.from_pretrained( - str(self.args.model_path), + str(self.uargs.model_path), subfolder="transformer", torch_dtype=self.state.weight_dtype, ) @@ -131,7 +130,7 @@ def encode_text(self, prompt: str) -> torch.Tensor: ) prompt_token_ids = prompt_token_ids.input_ids prompt_embedding = self.components.text_encoder( - prompt_token_ids.to(self.accelerator.device) + prompt_token_ids.to(self.state.device) ).last_hidden_state[0] # shape of prompt_embedding: [seq_len, hidden_size] @@ -176,9 +175,10 @@ def collate_fn(self, samples: list[dict[str, Any]]) -> dict[str, Any]: @override def compute_loss(self, batch) -> torch.Tensor: - prompt_embedding = batch["prompt_embedding"] - latent = batch["encoded_videos"] - images = batch["image_preprocessed"] + device = self.state.device + prompt_embedding = batch["prompt_embedding"].to(device) + latent = batch["encoded_videos"].to(device) + images = batch["image_preprocessed"].to(device) # Shape of prompt_embedding: [B, seq_len, hidden_size] # Shape of latent: [B, C, F, H, W] @@ -201,9 +201,7 @@ def compute_loss(self, batch) -> torch.Tensor: # Add frame dimension to images [B,C,H,W] -> 
[B,C,F,H,W] images = images.unsqueeze(2) # Add noise to images - image_noise_sigma = torch.normal( - mean=-3.0, std=0.5, size=(1,), device=self.accelerator.device - ) + image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=device) image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=images.dtype) noisy_images = ( images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None] @@ -218,7 +216,7 @@ def compute_loss(self, batch) -> torch.Tensor: 0, self.components.scheduler.config.num_train_timesteps, (batch_size,), - device=self.accelerator.device, + device=device, ) timesteps = timesteps.long() @@ -256,7 +254,7 @@ def compute_loss(self, batch) -> torch.Tensor: num_frames=num_frames, transformer_config=transformer_config, vae_scale_factor_spatial=vae_scale_factor_spatial, - device=self.accelerator.device, + device=device, ) if transformer_config.use_rotary_positional_embeddings else None @@ -310,8 +308,8 @@ def validation_step( num_frames=self.state.train_resolution[0], height=self.state.train_resolution[1], width=self.state.train_resolution[2], - prompt_embeds=prompt_embedding, - negative_prompt_embeds=self.get_negtive_prompt_embeds().unsqueeze(0), + prompt_embeds=prompt_embedding.to(self.state.device), + negative_prompt_embeds=self.state.negative_prompt_embeds.unsqueeze(0), image=image, generator=self.state.generator, ).frames[0] diff --git a/src/cogkit/finetune/diffusion/models/cogvideo/cogvideox_t2v/lora_trainer.py b/src/cogkit/finetune/diffusion/models/cogvideo/cogvideox_t2v/lora_trainer.py index cfa2594..ac1e3ba 100644 --- a/src/cogkit/finetune/diffusion/models/cogvideo/cogvideox_t2v/lora_trainer.py +++ b/src/cogkit/finetune/diffusion/models/cogvideo/cogvideox_t2v/lora_trainer.py @@ -4,22 +4,21 @@ from typing import Any import torch -from diffusers import ( - AutoencoderKLCogVideoX, - CogVideoXDPMScheduler, - CogVideoXPipeline, - CogVideoXTransformer3DModel, -) -from diffusers.models.embeddings import get_3d_rotary_pos_embed from PIL import Image -from transformers import AutoTokenizer, T5EncoderModel, BitsAndBytesConfig +from transformers import AutoTokenizer, BitsAndBytesConfig, T5EncoderModel from typing_extensions import override from cogkit.finetune import register from cogkit.finetune.diffusion.schemas import DiffusionComponents from cogkit.finetune.diffusion.trainer import DiffusionTrainer -from cogkit.finetune.utils import unwrap_model from cogkit.utils import load_lora_checkpoint, unload_lora_checkpoint +from diffusers import ( + AutoencoderKLCogVideoX, + CogVideoXDPMScheduler, + CogVideoXPipeline, + CogVideoXTransformer3DModel, +) +from diffusers.models.embeddings import get_3d_rotary_pos_embed class CogVideoXT2VLoraTrainer(DiffusionTrainer): @@ -37,7 +36,7 @@ def load_components(self) -> DiffusionComponents: dtype = self.state.weight_dtype components = DiffusionComponents() - model_path = str(self.args.model_path) + model_path = str(self.uargs.model_path) ### pipeline components.pipeline_cls = CogVideoXPipeline @@ -51,7 +50,7 @@ def load_components(self) -> DiffusionComponents: ) ### transformer - if not self.args.low_vram: + if not self.uargs.low_vram: components.transformer = CogVideoXTransformer3DModel.from_pretrained( model_path, subfolder="transformer", @@ -62,7 +61,7 @@ def load_components(self) -> DiffusionComponents: model_path, subfolder="transformer", quantization_config=nf4_config, - device=self.accelerator.device, + device=self.state.device, torch_dtype=dtype, ) @@ -80,18 +79,18 @@ def load_components(self) -> 
DiffusionComponents: @override def initialize_pipeline(self, ckpt_path: str | None = None) -> CogVideoXPipeline: - if not self.args.low_vram: + if not self.uargs.low_vram: pipe = CogVideoXPipeline( tokenizer=self.components.tokenizer, text_encoder=self.components.text_encoder, vae=self.components.vae, - transformer=unwrap_model(self.accelerator, self.components.transformer), + transformer=self.unwrap_model(self.components.transformer), scheduler=self.components.scheduler, ) else: - assert self.args.training_type == "lora" + assert self.uargs.training_type == "lora" transformer = CogVideoXTransformer3DModel.from_pretrained( - str(self.args.model_path), + str(self.uargs.model_path), subfolder="transformer", torch_dtype=self.state.weight_dtype, ) @@ -127,7 +126,7 @@ def encode_text(self, prompt: str) -> torch.Tensor: ) prompt_token_ids = prompt_token_ids.input_ids prompt_embedding = self.components.text_encoder( - prompt_token_ids.to(self.accelerator.device) + prompt_token_ids.to(self.state.device) ).last_hidden_state[0] # shape of prompt_embedding: [seq_len, hidden_size] @@ -161,8 +160,9 @@ def collate_fn(self, samples: list[dict[str, Any]]) -> dict[str, Any]: @override def compute_loss(self, batch) -> torch.Tensor: - prompt_embedding = batch["prompt_embedding"] - latent = batch["encoded_videos"] + device = self.state.device + prompt_embedding = batch["prompt_embedding"].to(device) + latent = batch["encoded_videos"].to(device) assert latent is not None and prompt_embedding is not None @@ -188,7 +188,7 @@ def compute_loss(self, batch) -> torch.Tensor: 0, self.components.scheduler.config.num_train_timesteps, (batch_size,), - device=self.accelerator.device, + device=device, ) timesteps = timesteps.long() @@ -207,7 +207,7 @@ def compute_loss(self, batch) -> torch.Tensor: num_frames=num_frames, transformer_config=transformer_config, vae_scale_factor_spatial=vae_scale_factor_spatial, - device=self.accelerator.device, + device=device, ) if transformer_config.use_rotary_positional_embeddings else None @@ -251,8 +251,8 @@ def validation_step( num_frames=self.state.train_resolution[0], height=self.state.train_resolution[1], width=self.state.train_resolution[2], - prompt_embeds=prompt_embedding, - negative_prompt_embeds=self.get_negtive_prompt_embeds().unsqueeze(0), + prompt_embeds=prompt_embedding.to(self.state.device), + negative_prompt_embeds=self.state.negative_prompt_embeds.unsqueeze(0), generator=self.state.generator, ).frames[0] return {"text": prompt, "video": video_generate} diff --git a/src/cogkit/finetune/diffusion/models/cogview/cogview4/lora_trainer.py b/src/cogkit/finetune/diffusion/models/cogview/cogview4/lora_trainer.py index ac91bb2..05a3495 100644 --- a/src/cogkit/finetune/diffusion/models/cogview/cogview4/lora_trainer.py +++ b/src/cogkit/finetune/diffusion/models/cogview/cogview4/lora_trainer.py @@ -13,7 +13,6 @@ from cogkit.finetune.diffusion.trainer import DiffusionTrainer from cogkit.finetune.utils import ( process_prompt_attention_mask, - unwrap_model, replace_attn_processor, ) from cogkit.utils import load_lora_checkpoint, unload_lora_checkpoint @@ -43,7 +42,7 @@ def load_components(self) -> DiffusionComponents: dtype = self.state.weight_dtype components = DiffusionComponents() - model_path = str(self.args.model_path) + model_path = str(self.uargs.model_path) ### pipeline components.pipeline_cls = CogView4Pipeline @@ -59,7 +58,7 @@ def load_components(self) -> DiffusionComponents: ) ### transformer - if not self.args.low_vram: + if not self.uargs.low_vram: components.transformer 
= CogView4Transformer2DModel.from_pretrained( model_path, subfolder="transformer", @@ -71,7 +70,7 @@ def load_components(self) -> DiffusionComponents: subfolder="transformer", torch_dtype=dtype, quantization_config=nf4_config, - device=self.accelerator.device, + device=self.state.device, ) replace_attn_processor(components.transformer, CogView4TrainingAttnProcessor()) @@ -88,23 +87,22 @@ def load_components(self) -> DiffusionComponents: @override def initialize_pipeline(self, ckpt_path: str | None = None) -> CogView4Pipeline: - if not self.args.low_vram: + # using bf16 model rather than quantized ones + if not self.uargs.low_vram: pipe = CogView4Pipeline( tokenizer=self.components.tokenizer, text_encoder=self.components.text_encoder, vae=self.components.vae, - transformer=unwrap_model(self.accelerator, self.components.transformer), + transformer=self.unwrap_model(self.components.transformer), scheduler=self.components.scheduler, ) else: - assert self.args.training_type == "lora" - # using bf16 model rather than quantized ones + assert self.uargs.training_type == "lora" transformer = CogView4Transformer2DModel.from_pretrained( - str(self.args.model_path), + str(self.uargs.model_path), subfolder="transformer", torch_dtype=self.state.weight_dtype, ) - replace_attn_processor(transformer, CogView4TrainingAttnProcessor()) pipe = CogView4Pipeline( tokenizer=self.components.tokenizer, text_encoder=self.components.text_encoder, @@ -133,7 +131,7 @@ def encode_text(self, prompt: str) -> torch.Tensor: ).input_ids prompt_embedding = self.components.text_encoder( - prompt_token_ids.to(self.accelerator.device), output_hidden_states=True + prompt_token_ids.to(self.state.device), output_hidden_states=True ).hidden_states[-2][0] # shape of prompt_embedding: [sequence length, embedding dimension(4096)] return prompt_embedding @@ -145,7 +143,7 @@ def get_negtive_prompt_embeds(self) -> torch.Tensor: @override def encode_image(self, image: torch.Tensor) -> torch.Tensor: vae = self.components.vae - image = image.to(self.accelerator.device, dtype=vae.dtype) + image = image.to(self.state.device, dtype=vae.dtype) latent_dist = vae.encode(image).latent_dist latent = latent_dist.sample() * vae.config.scaling_factor return latent @@ -225,8 +223,9 @@ def collate_fn(self, samples: list[dict[str, Any]]) -> dict[str, Any]: @override def compute_loss(self, batch: dict[str, Any]) -> torch.Tensor: batch_size, text_seqlen, text_embedding_dim = batch["prompt_embedding"].shape - prompt_embeds = batch["prompt_embedding"] - latent = batch["encoded_image"] + device = self.state.device + prompt_embeds = batch["prompt_embedding"].to(device) + latent = batch["encoded_image"].to(device) batch_size, num_channels, height, width = latent.shape image_height, image_width = self.state.train_resolution @@ -234,7 +233,7 @@ def compute_loss(self, batch: dict[str, Any]) -> torch.Tensor: image_seq_len = ( (image_height // vae_scale_factor) * (image_width // vae_scale_factor) ) // (self.state.transformer_config.patch_size**2) - image_seq_len = torch.tensor([image_seq_len], device=self.accelerator.device) + image_seq_len = torch.tensor([image_seq_len], device=device) text_attn_mask = batch["text_attn_mask"] @@ -248,20 +247,20 @@ def compute_loss(self, batch: dict[str, Any]) -> torch.Tensor: original_size = torch.tensor( [[image_height, image_width] for _ in range(batch_size)], dtype=latent.dtype, - device=self.accelerator.device, + device=device, ) target_size = torch.tensor( [[image_height, image_width] for _ in range(batch_size)], 
dtype=latent.dtype, - device=self.accelerator.device, + device=device, ) crop_coords = torch.tensor( - [[0, 0] for _ in range(batch_size)], dtype=latent.dtype, device=self.accelerator.device + [[0, 0] for _ in range(batch_size)], dtype=latent.dtype, device=device ) noise_pred_cond = self.components.transformer( - hidden_states=model_input, - encoder_hidden_states=prompt_embeds, + hidden_states=model_input.to(dtype=self.state.weight_dtype), + encoder_hidden_states=prompt_embeds.to(dtype=self.state.weight_dtype), timestep=timestep, original_size=original_size, target_size=target_size, @@ -288,11 +287,11 @@ def get_sigmas(self, batch_size: int, vtoken_seq_len: torch.Tensor) -> torch.Ten scheduler.sigma_min, scheduler.sigma_max, scheduler.config.num_train_timesteps, - device=self.accelerator.device, + device=self.state.device, ) m = (vtoken_seq_len / scheduler.config.base_image_seq_len) ** 0.5 mu = m * scheduler.config.max_shift + scheduler.config.base_shift - mu = mu.unsqueeze(1) + mu = mu.unsqueeze(1).to(sigmas.device) sigmas = mu / (mu + (1 / sigmas - 1)) sigmas = torch.cat([torch.zeros((batch_size, 1), device=sigmas.device), sigmas], dim=1) return sigmas @@ -302,7 +301,7 @@ def get_timestep(self, batch_size: int, num_train_timesteps: int) -> torch.LongT 0, num_train_timesteps, (batch_size,), - device=self.accelerator.device, + device=self.state.device, ) def add_noise( @@ -335,7 +334,7 @@ def validation_step( image_generate = pipe( height=self.state.train_resolution[0], width=self.state.train_resolution[1], - prompt_embeds=prompt_embedding, + prompt_embeds=prompt_embedding.to(self.state.device), negative_prompt_embeds=self.state.negative_prompt_embeds.unsqueeze( 0 ), # Add batch dimension diff --git a/src/cogkit/finetune/diffusion/models/cogview/cogview4/lora_trainer_packing.py b/src/cogkit/finetune/diffusion/models/cogview/cogview4/lora_trainer_packing.py index 76f1aef..9761e34 100644 --- a/src/cogkit/finetune/diffusion/models/cogview/cogview4/lora_trainer_packing.py +++ b/src/cogkit/finetune/diffusion/models/cogview/cogview4/lora_trainer_packing.py @@ -41,7 +41,7 @@ def __init__(self, *args, **kwargs) -> None: self.ROPE_DIM = transformer.config.rope_axes_dim patch_size = self.PATCH_SIZE - height, width = self.args.train_resolution + height, width = self.uargs.train_resolution sample_height, sample_width = ( height // self.DOWNSAMPLER_FACTOR, width // self.DOWNSAMPLER_FACTOR, @@ -161,9 +161,9 @@ def collate_fn_packing(self, samples: list[dict[str, list[Any]]]) -> dict[str, A @override def compute_loss(self, batch: dict[str, Any]) -> torch.Tensor: - dtype = self.get_training_dtype() - prompt_embeds = batch["prompt_embedding"] - latent = batch["encoded_image"] + device, dtype = self.state.device, self.state.weight_dtype + prompt_embeds = batch["prompt_embedding"].to(device) + latent = batch["encoded_image"].to(device) image_rotary_emb = batch["image_rotary_emb"] batch_size, text_seqlen, text_embedding_dim = prompt_embeds.shape batch_size, num_channels, height, width = latent.shape @@ -182,11 +182,9 @@ def compute_loss(self, batch: dict[str, Any]) -> torch.Tensor: noise = torch.randn_like(latent, dtype=dtype) model_input, model_label = self.add_noise(latent, noise, timestep, sigmas) - original_size = original_size.to(dtype=dtype, device=self.accelerator.device) - target_size = original_size.clone().to(dtype=dtype, device=self.accelerator.device) - crop_coords = torch.tensor( - [[0, 0] for _ in range(batch_size)], dtype=dtype, device=self.accelerator.device - ) + original_size = 
original_size.to(dtype=dtype, device=device) + target_size = original_size.clone().to(dtype=dtype, device=device) + crop_coords = torch.tensor([[0, 0] for _ in range(batch_size)], dtype=dtype, device=device) noise_pred_cond = self.components.transformer( hidden_states=model_input.to(dtype=dtype), @@ -200,7 +198,7 @@ def compute_loss(self, batch: dict[str, Any]) -> torch.Tensor: attention_kwargs=attention_kwargs, )[0] - pixel_mask = batch["pixel_mask"] + pixel_mask = batch["pixel_mask"].to(device) loss = torch.sum(((noise_pred_cond - model_label) ** 2) * pixel_mask, dim=(1, 2, 3)) loss = loss / torch.sum(pixel_mask, dim=(1, 2, 3)) loss = loss.mean() diff --git a/src/cogkit/finetune/diffusion/schemas/args.py b/src/cogkit/finetune/diffusion/schemas/args.py index 8dad3d1..59e01f3 100644 --- a/src/cogkit/finetune/diffusion/schemas/args.py +++ b/src/cogkit/finetune/diffusion/schemas/args.py @@ -1,4 +1,4 @@ -from typing import Literal +from pathlib import Path from pydantic import ValidationInfo, field_validator from typing_extensions import override @@ -7,16 +7,10 @@ class DiffusionArgs(BaseArgs): - ########## Model ########## - model_type: Literal["i2v", "t2v", "t2i"] - - ########## Output ########## - tracker_name: str = "diffusion-tracker" - ########## Training ######### - # For cogview models, train_resolution is a tuple of (height, width) - # For cogvideo models, train_resolution is a tuple of (frames, height, width) - train_resolution: tuple[int, int] | tuple[int, int, int] + # For cogview models, train_resolution is a list of (height, width) + # For cogvideo models, train_resolution is a list of (frames, height, width) + train_resolution: list[int, int] | list[int, int, int] enable_slicing: bool = True enable_tiling: bool = True @@ -25,11 +19,11 @@ class DiffusionArgs(BaseArgs): enable_packing: bool = False ########## Validation ########## - gen_fps: int = 15 + gen_fps: int | None = None @field_validator("train_resolution") def validate_train_resolution( - cls, v: tuple[int, int] | tuple[int, int, int], info: ValidationInfo + cls, v: list[int, int] | list[int, int, int], info: ValidationInfo ) -> str: if len(v) == 2: # cogview models height, width = v @@ -49,39 +43,12 @@ def validate_train_resolution( ) else: raise ValueError( - "train_resolution must be a tuple of (height, width) for cogview models or (frames, height, width) for cogvideo models" + "train_resolution must be a list of (height, width) for cogview models or (frames, height, width) for cogvideo models" ) return v @override @classmethod - def parse_args(cls): - parser = cls.get_base_parser() - - # Required arguments - parser.add_argument("--model_type", type=str, required=True) - parser.add_argument("--train_resolution", type=str, required=True) - - # Model configuration - parser.add_argument("--enable_slicing", action="store_true") - parser.add_argument("--enable_tiling", action="store_true") - - # Packing - parser.add_argument("--enable_packing", type=lambda x: x.lower() == "true", default=False) - - # Validation - parser.add_argument("--gen_fps", type=int, default=15) - - args = parser.parse_args() - - # Convert train_resolution string to tuple - parts = args.train_resolution.split("x") - if len(parts) == 2: - height, width = parts - args.train_resolution = (int(height), int(width)) - else: - frames, height, width = parts - args.train_resolution = (int(frames), int(height), int(width)) - - return cls(**vars(args)) + def parse_from_yaml(cls, fpath: str | Path) -> "DiffusionArgs": + return super().parse_from_yaml(fpath) diff 
--git a/src/cogkit/finetune/diffusion/schemas/state.py b/src/cogkit/finetune/diffusion/schemas/state.py index fc7bcbe..ad0f86c 100644 --- a/src/cogkit/finetune/diffusion/schemas/state.py +++ b/src/cogkit/finetune/diffusion/schemas/state.py @@ -14,13 +14,13 @@ class DiffusionState(BaseState): # for video input, train_resolution = (frames, height, width) # for image input, train_resolution = (height, width) - train_resolution: tuple[int, int, int] | tuple[int, int] + train_resolution: tuple[int, int, int] | tuple[int, int] = () - # packing realted - training_seq_length: int | None = None + negative_prompt_embeds: torch.Tensor | None = None validation_prompts: list[str] = [] validation_images: list[Path | None] = [] validation_videos: list[Path | None] = [] - negative_prompt_embeds: torch.Tensor | None = None + # packing realted + training_seq_length: int | None = None diff --git a/src/cogkit/finetune/diffusion/trainer.py b/src/cogkit/finetune/diffusion/trainer.py index 679d1ae..c8522ee 100644 --- a/src/cogkit/finetune/diffusion/trainer.py +++ b/src/cogkit/finetune/diffusion/trainer.py @@ -1,101 +1,107 @@ import json +from pathlib import Path from typing import Any import torch +import torch.distributed as dist import wandb from accelerate import cpu_offload -from accelerate.utils import gather_object +from torch.utils.data import DistributedSampler from PIL import Image from typing_extensions import override from cogkit.finetune.base import BaseTrainer -from cogkit.samplers import NaivePackingSampler -from cogkit.utils import expand_list +from cogkit.finetune.samplers import DistPackingSampler +from cogkit.utils import expand_list, guess_generation_mode +from cogkit.types import GenerationMode from diffusers.pipelines import DiffusionPipeline from diffusers.utils.export_utils import export_to_video from ..utils import ( free_memory, get_memory_statistics, - unload_model, + gather_object, + mkdir, ) -from .constants import LOG_LEVEL, LOG_NAME from .schemas import DiffusionArgs, DiffusionComponents, DiffusionState class DiffusionTrainer(BaseTrainer): - # If set, should be a list of components to unload (refer to `Components``) - UNLOAD_LIST: list[str] = None - LOG_NAME: str = LOG_NAME - LOG_LEVEL: str = LOG_LEVEL + @override + def __init__(self, uargs_fpath: str | Path) -> None: + super().__init__(uargs_fpath) + self.uargs: DiffusionArgs + self.state: DiffusionState + self.components: DiffusionComponents @override - def _init_args(self) -> DiffusionArgs: - return DiffusionArgs.parse_args() + def _init_args(self, uargs_fpath: Path) -> DiffusionArgs: + return DiffusionArgs.parse_from_yaml(uargs_fpath) @override def _init_state(self) -> DiffusionState: - return DiffusionState( - weight_dtype=self.get_training_dtype(), - train_resolution=self.args.train_resolution, - ) + state = DiffusionState(**super()._init_state().model_dump()) + state.train_resolution = self.uargs.train_resolution + return state @override def prepare_models(self) -> None: if self.components.vae is not None: - if self.args.enable_slicing: + if self.uargs.enable_slicing: self.components.vae.enable_slicing() - if self.args.enable_tiling: + if self.uargs.enable_tiling: self.components.vae.enable_tiling() self.state.transformer_config = self.components.transformer.config @override def prepare_dataset(self) -> None: - if self.args.model_type == "i2v": - from cogkit.datasets import BaseI2VDataset, I2VDatasetWithResize - - dataset_cls = I2VDatasetWithResize - if self.args.enable_packing: - dataset_cls = BaseI2VDataset - raise 
NotImplementedError("Packing for I2V is not implemented") - - elif self.args.model_type == "t2v": - from cogkit.datasets import BaseT2VDataset, T2VDatasetWithResize - - dataset_cls = T2VDatasetWithResize - if self.args.enable_packing: - dataset_cls = BaseT2VDataset - raise NotImplementedError("Packing for T2V is not implemented") - - elif self.args.model_type == "t2i": - from cogkit.datasets import ( - T2IDatasetWithFactorResize, - T2IDatasetWithPacking, - T2IDatasetWithResize, - ) - - dataset_cls = T2IDatasetWithResize - if self.args.enable_packing: - dataset_cls = T2IDatasetWithFactorResize - dataset_cls_packing = T2IDatasetWithPacking - - else: - raise ValueError(f"Invalid model type: {self.args.model_type}") + generation_mode = guess_generation_mode(self.components.pipeline_cls) + match generation_mode: + case GenerationMode.TextToImage: + from cogkit.finetune.datasets import ( + T2IDatasetWithFactorResize, + T2IDatasetWithPacking, + T2IDatasetWithResize, + ) + + dataset_cls = T2IDatasetWithResize + if self.uargs.enable_packing: + dataset_cls = T2IDatasetWithFactorResize + dataset_cls_packing = T2IDatasetWithPacking + + case GenerationMode.TextToVideo: + from cogkit.finetune.datasets import BaseT2VDataset, T2VDatasetWithResize + + dataset_cls = T2VDatasetWithResize + if self.uargs.enable_packing: + dataset_cls = BaseT2VDataset + raise NotImplementedError("Packing for T2V is not implemented") + + case GenerationMode.ImageToVideo: + from cogkit.finetune.datasets import BaseI2VDataset, I2VDatasetWithResize + + dataset_cls = I2VDatasetWithResize + if self.uargs.enable_packing: + dataset_cls = BaseI2VDataset + raise NotImplementedError("Packing for I2V is not implemented") + + case _: + raise ValueError(f"Invalid generation mode: {generation_mode}") additional_args = { - "device": self.accelerator.device, + "device": self.state.device, "trainer": self, } self.train_dataset = dataset_cls( - **(self.args.model_dump()), + **(self.uargs.model_dump()), **additional_args, using_train=True, ) - if self.args.do_validation: + if self.uargs.do_validation: self.test_dataset = dataset_cls( - **(self.args.model_dump()), + **(self.uargs.model_dump()), **additional_args, using_train=False, ) @@ -103,15 +109,15 @@ def prepare_dataset(self) -> None: ### Prepare VAE and text encoder for encoding self.components.vae.requires_grad_(False) self.components.text_encoder.requires_grad_(False) - self.components.vae.to(self.accelerator.device, dtype=self.state.weight_dtype) - if self.args.low_vram: # offload text encoder to CPU - cpu_offload(self.components.text_encoder, self.accelerator.device) + self.components.vae.to(self.state.device, dtype=self.state.weight_dtype) + if self.uargs.low_vram: # offload text encoder to CPU + cpu_offload(self.components.text_encoder, self.state.device) else: - self.components.text_encoder.to(self.accelerator.device, dtype=self.state.weight_dtype) + self.components.text_encoder.to(self.state.device, dtype=self.state.weight_dtype) ### Precompute embedding self.logger.info("Precomputing embedding ...") - self.state.negative_prompt_embeds = self.get_negtive_prompt_embeds() + self.state.negative_prompt_embeds = self.get_negtive_prompt_embeds().to(self.state.device) for dataset in [self.train_dataset, self.test_dataset]: if dataset is None: @@ -121,28 +127,38 @@ def prepare_dataset(self) -> None: collate_fn=self.collate_fn, batch_size=1, num_workers=0, - pin_memory=self.args.pin_memory, + pin_memory=self.uargs.pin_memory, + sampler=DistributedSampler( + dataset, + 
num_replicas=self.state.world_size, + rank=self.state.global_rank, + shuffle=False, + ), ) - tmp_data_loader = self.accelerator.prepare_data_loader(tmp_data_loader) for _ in tmp_data_loader: ... - self.accelerator.wait_for_everyone() self.logger.info("Precomputing embedding ... Done") + dist.barrier() - unload_model(self.components.vae) - if not self.args.low_vram: - unload_model(self.components.text_encoder) + self.components.vae = self.components.vae.to("cpu") + if not self.uargs.low_vram: + self.components.text_encoder = self.components.text_encoder.to("cpu") free_memory() - if not self.args.enable_packing: + if not self.uargs.enable_packing: self.train_data_loader = torch.utils.data.DataLoader( self.train_dataset, collate_fn=self.collate_fn, - batch_size=self.args.batch_size, - num_workers=self.args.num_workers, - pin_memory=self.args.pin_memory, - shuffle=True, + batch_size=self.uargs.batch_size, + num_workers=self.uargs.num_workers, + pin_memory=self.uargs.pin_memory, + sampler=DistributedSampler( + self.train_dataset, + num_replicas=self.state.world_size, + rank=self.state.global_rank, + shuffle=True, + ), ) else: length_list = [self.sample_to_length(sample) for sample in self.train_dataset] @@ -150,24 +166,31 @@ def prepare_dataset(self) -> None: self.train_data_loader = torch.utils.data.DataLoader( self.train_dataset, collate_fn=self.collate_fn_packing, - batch_size=self.args.batch_size, - num_workers=self.args.num_workers, - pin_memory=self.args.pin_memory, - sampler=NaivePackingSampler( + batch_size=self.uargs.batch_size, + num_workers=self.uargs.num_workers, + pin_memory=self.uargs.pin_memory, + sampler=DistPackingSampler( length_list, self.state.training_seq_length, shuffle=True, + world_size=self.state.world_size, + global_rank=self.state.global_rank, ), ) - if self.args.do_validation: + if self.uargs.do_validation: self.test_data_loader = torch.utils.data.DataLoader( self.test_dataset, collate_fn=self.collate_fn, batch_size=1, - num_workers=self.args.num_workers, - pin_memory=self.args.pin_memory, - shuffle=False, + num_workers=self.uargs.num_workers, + pin_memory=self.uargs.pin_memory, + sampler=DistributedSampler( + self.test_dataset, + num_replicas=self.state.world_size, + rank=self.state.global_rank, + shuffle=False, + ), ) @override @@ -179,7 +202,7 @@ def validate(self, step: int, ckpt_path: str | None = None) -> None: self.logger.warning("No validation samples found. 
Skipping validation.") return - self.components.transformer.eval() + # self.components.transformer.eval() torch.set_grad_enabled(False) memory_statistics = get_memory_statistics(self.logger) @@ -190,25 +213,16 @@ def validate(self, step: int, ckpt_path: str | None = None) -> None: ##### Initialize pipeline ##### pipe = self.initialize_pipeline(ckpt_path=ckpt_path) - if self.state.using_deepspeed: - # Can't using model_cpu_offload in deepspeed, - # so we need to move all components in pipe to device - self.move_components_to_device( - dtype=self.state.weight_dtype, - device=self.accelerator.device, - ignore_list=["transformer"], - ) + # if not using deepspeed, use model_cpu_offload to further reduce memory usage + # Or use pipe.enable_sequential_cpu_offload() to further reduce memory usage + if self.uargs.low_vram: + pipe.enable_sequential_cpu_offload(device=self.state.device) else: - # if not using deepspeed, use model_cpu_offload to further reduce memory usage - # Or use pipe.enable_sequential_cpu_offload() to further reduce memory usage - if self.args.low_vram: - pipe.enable_sequential_cpu_offload(device=self.accelerator.device) - else: - pipe.enable_model_cpu_offload(device=self.accelerator.device) + pipe.enable_model_cpu_offload(device=self.state.device) - # Convert all model weights to training dtype - # Note, this will change LoRA weights in self.components.transformer to training dtype, rather than keep them in fp32 - pipe = pipe.to(dtype=self.state.weight_dtype) + # Convert all model weights to training dtype + # Note, this will change LoRA weights in self.components.transformer to training dtype, rather than keep them in fp32 + pipe = pipe.to(dtype=self.state.weight_dtype) ################################# @@ -228,8 +242,7 @@ def validate(self, step: int, ckpt_path: str | None = None) -> None: encoded_video = batch.get("encoded_video", None) self.logger.debug( - f"Validating sample {i + 1}/{num_validation_samples} on process {self.accelerator.process_index}. Prompt: {prompt}", - main_process_only=False, + f"Validating sample {i + 1}/{num_validation_samples} on process {self.state.global_rank}. 
Prompt: {prompt}", ) val_res = self.validation_step( pipe=pipe, @@ -244,9 +257,9 @@ def validate(self, step: int, ckpt_path: str | None = None) -> None: ) artifacts = {} - val_path = self.args.output_dir / "validation_res" / f"validation-{step}" - val_path.mkdir(parents=True, exist_ok=True) - filename = f"artifact-process{self.accelerator.process_index}-batch{i}" + val_path = self.uargs.output_dir / "validation_res" / f"validation-{step}" + mkdir(val_path) + filename = f"artifact-process{self.state.global_rank}-batch{i}" image = val_res.get("image", None) video = val_res.get("video", None) @@ -258,48 +271,38 @@ def validate(self, step: int, ckpt_path: str | None = None) -> None: artifacts["image"] = wandb.Image(fpath, caption=prompt) if video: fpath = str(val_path / f"{filename}.mp4") - export_to_video(video, fpath, fps=self.args.gen_fps) + export_to_video(video, fpath, fps=self.uargs.gen_fps) artifacts["video"] = wandb.Video(fpath, caption=prompt) all_processes_artifacts.append(artifacts) - all_artifacts = gather_object(all_processes_artifacts) - all_artifacts = expand_list(all_artifacts) - - if self.accelerator.is_main_process: - tracker_key = "validation" - for tracker in self.accelerator.trackers: - if tracker.name == "wandb": - tracker.log({tracker_key: all_artifacts}, step=step) - - ########## Clean up ########## - if self.state.using_deepspeed: - del pipe - # Unload models except those needed for training - self.move_components_to_device( - dtype=self.state.weight_dtype, device="cpu", ignore_list=["transformer"] - ) - else: - pipe.remove_all_hooks() - del pipe - # Load models except those not needed for training - self.move_components_to_device( - dtype=self.state.weight_dtype, - device=self.accelerator.device, - ignore_list=self.UNLOAD_LIST, - ) - self.components.transformer.to(self.accelerator.device, dtype=self.state.weight_dtype) + if self.tracker is not None: + all_artifacts = gather_object(all_processes_artifacts) + all_artifacts = [item for sublist in all_artifacts for item in sublist] + all_artifacts = expand_list(all_artifacts) + self.tracker.log({"validation": all_artifacts}, step=step) + + # ============= Clean up ============= + pipe.remove_all_hooks() + del pipe + # Load models except those not needed for training + self.move_components_to_device( + dtype=self.state.weight_dtype, + device=self.state.device, + ignore_list=self.UNLOAD_LIST, + ) + # self.components.transformer.to(self.state.device, dtype=self.state.weight_dtype) - # Change trainable weights back to fp32 to keep with dtype after prepare the model - # cast_training_params([self.components.transformer], dtype=torch.float32) + # Change trainable weights back to fp32 to keep with dtype after prepare the model + # cast_training_params([self.components.transformer], dtype=torch.float32) free_memory() - self.accelerator.wait_for_everyone() - ################################ + dist.barrier() + # ======================================= memory_statistics = get_memory_statistics(self.logger) self.logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}") - torch.cuda.reset_peak_memory_stats(self.accelerator.device) + torch.cuda.reset_peak_memory_stats(self.state.device) torch.set_grad_enabled(True) self.components.transformer.train() @@ -309,7 +312,7 @@ def load_components(self) -> DiffusionComponents: raise NotImplementedError @override - def compute_loss(self, batch) -> torch.Tensor: + def compute_loss(self, batch: dict[str, Any]) -> torch.Tensor: raise NotImplementedError def 
collate_fn(self, samples: list[dict[str, Any]]): From c8f993ae02631e624f4a900602ebbc4652232e1f Mon Sep 17 00:00:00 2001 From: OleehyO Date: Tue, 6 May 2025 10:36:52 +0000 Subject: [PATCH 14/19] Update --- src/cogkit/finetune/__init__.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/cogkit/finetune/__init__.py b/src/cogkit/finetune/__init__.py index f7878a6..bd92860 100644 --- a/src/cogkit/finetune/__init__.py +++ b/src/cogkit/finetune/__init__.py @@ -1,20 +1,21 @@ # -*- coding: utf-8 -*- - -from cogkit.finetune.base import BaseTrainer - # import register first -from cogkit.finetune.register import get_model_cls, register, show_supported_models # noqa +from ._register import get_model_cls, register, show_supported_models # noqa + +from .base import BaseTrainer # import resgistered models -from cogkit.finetune.diffusion import models as diffusion_models -from cogkit.finetune.llm import models as llm_models +from .diffusion import models as diffusion_models +from .llm import models as llm_models +from .logger import get_logger __all__ = [ "BaseTrainer", "diffusion_models", "llm_models", + "get_logger", "get_model_cls", "register", "show_supported_models", From 7070f8bf34532517fb8e90fdc0956fc3a53c56ae Mon Sep 17 00:00:00 2001 From: OleehyO Date: Fri, 9 May 2025 10:39:53 +0000 Subject: [PATCH 15/19] [scripts] Refactor for FSDP training --- quickstart/configs/accelerate_config.yaml | 26 ------- quickstart/configs/zero/zero2.yaml | 38 ---------- quickstart/configs/zero/zero2_offload.yaml | 42 ------------ quickstart/configs/zero/zero3.yaml | 41 ----------- quickstart/configs/zero/zero3_offload.yaml | 49 ------------- quickstart/scripts/i2v/config.yaml | 65 ++++++++++++++++++ quickstart/scripts/i2v/start_train.sh | 7 ++ quickstart/scripts/t2i/config.yaml | 64 +++++++++++++++++ quickstart/scripts/t2i/start_train.sh | 7 ++ quickstart/scripts/t2v/config.yaml | 64 +++++++++++++++++ quickstart/scripts/t2v/start_train.sh | 7 ++ quickstart/scripts/train.py | 16 +++-- quickstart/scripts/train_ddp_i2v.sh | 70 ------------------- quickstart/scripts/train_ddp_t2i.sh | 80 ---------------------- quickstart/scripts/train_ddp_t2v.sh | 69 ------------------- quickstart/scripts/train_zero_i2v.sh | 74 -------------------- quickstart/scripts/train_zero_t2i.sh | 78 --------------------- quickstart/scripts/train_zero_t2v.sh | 73 -------------------- 18 files changed, 224 insertions(+), 646 deletions(-) delete mode 100644 quickstart/configs/accelerate_config.yaml delete mode 100644 quickstart/configs/zero/zero2.yaml delete mode 100644 quickstart/configs/zero/zero2_offload.yaml delete mode 100644 quickstart/configs/zero/zero3.yaml delete mode 100644 quickstart/configs/zero/zero3_offload.yaml create mode 100644 quickstart/scripts/i2v/config.yaml create mode 100644 quickstart/scripts/i2v/start_train.sh create mode 100644 quickstart/scripts/t2i/config.yaml create mode 100644 quickstart/scripts/t2i/start_train.sh create mode 100644 quickstart/scripts/t2v/config.yaml create mode 100644 quickstart/scripts/t2v/start_train.sh delete mode 100755 quickstart/scripts/train_ddp_i2v.sh delete mode 100755 quickstart/scripts/train_ddp_t2i.sh delete mode 100755 quickstart/scripts/train_ddp_t2v.sh delete mode 100755 quickstart/scripts/train_zero_i2v.sh delete mode 100755 quickstart/scripts/train_zero_t2i.sh delete mode 100755 quickstart/scripts/train_zero_t2v.sh diff --git a/quickstart/configs/accelerate_config.yaml b/quickstart/configs/accelerate_config.yaml deleted file mode 100644 index 
b6032b6..0000000 --- a/quickstart/configs/accelerate_config.yaml +++ /dev/null @@ -1,26 +0,0 @@ -compute_environment: LOCAL_MACHINE - -gpu_ids: "0,1,2,3,4,5,6,7" -num_processes: 8 # should be the same as the number of GPUs - -# gpu_ids: "0" -# num_processes: 1 - -debug: false - -distributed_type: DEEPSPEED -deepspeed_config: - deepspeed_config_file: /path/to/configs/zero/zero2.yaml # e.g. need use absolute path - zero3_init_flag: false - -downcast_bf16: 'no' -enable_cpu_affinity: false -machine_rank: 0 -main_training_function: main -num_machines: 1 -rdzv_backend: static -same_network: true -tpu_env: [] -tpu_use_cluster: false -tpu_use_sudo: false -use_cpu: false diff --git a/quickstart/configs/zero/zero2.yaml b/quickstart/configs/zero/zero2.yaml deleted file mode 100644 index b056bd4..0000000 --- a/quickstart/configs/zero/zero2.yaml +++ /dev/null @@ -1,38 +0,0 @@ -{ - "bf16": { - "enabled": true - }, - "optimizer": { - "type": "AdamW", - "params": { - "lr": "auto", - "weight_decay": "auto", - "torch_adam": true, - "adam_w_mode": true - } - }, - "scheduler": { - "type": "WarmupDecayLR", - "params": { - "warmup_min_lr": "auto", - "warmup_max_lr": "auto", - "warmup_num_steps": "auto", - "total_num_steps": "auto" - } - }, - "zero_optimization": { - "stage": 2, - "allgather_partitions": true, - "allgather_bucket_size": 2e8, - "overlap_comm": true, - "reduce_scatter": true, - "reduce_bucket_size": 5e8, - "contiguous_gradients": true - }, - "gradient_accumulation_steps": 1, - "train_micro_batch_size_per_gpu": 1, - "train_batch_size": "auto", - "gradient_clipping": "auto", - "steps_per_print": 2000, - "wall_clock_breakdown": false -} diff --git a/quickstart/configs/zero/zero2_offload.yaml b/quickstart/configs/zero/zero2_offload.yaml deleted file mode 100644 index 24fdcb4..0000000 --- a/quickstart/configs/zero/zero2_offload.yaml +++ /dev/null @@ -1,42 +0,0 @@ -{ - "bf16": { - "enabled": true - }, - "optimizer": { - "type": "AdamW", - "params": { - "lr": "auto", - "weight_decay": "auto", - "torch_adam": true, - "adam_w_mode": true - } - }, - "scheduler": { - "type": "WarmupDecayLR", - "params": { - "warmup_min_lr": "auto", - "warmup_max_lr": "auto", - "warmup_num_steps": "auto", - "total_num_steps": "auto" - } - }, - "zero_optimization": { - "stage": 2, - "allgather_partitions": true, - "allgather_bucket_size": 2e8, - "overlap_comm": true, - "reduce_scatter": true, - "reduce_bucket_size": 5e8, - "contiguous_gradients": true, - "offload_optimizer": { - "device": "cpu", - "pin_memory": true - } - }, - "gradient_accumulation_steps": 1, - "train_micro_batch_size_per_gpu": 1, - "train_batch_size": "auto", - "gradient_clipping": "auto", - "steps_per_print": 2000, - "wall_clock_breakdown": false -} diff --git a/quickstart/configs/zero/zero3.yaml b/quickstart/configs/zero/zero3.yaml deleted file mode 100644 index 18685d0..0000000 --- a/quickstart/configs/zero/zero3.yaml +++ /dev/null @@ -1,41 +0,0 @@ -{ - "bf16": { - "enabled": true - }, - "optimizer": { - "type": "AdamW", - "params": { - "lr": "auto", - "weight_decay": "auto", - "torch_adam": true, - "adam_w_mode": true - } - }, - "scheduler": { - "type": "WarmupDecayLR", - "params": { - "warmup_min_lr": "auto", - "warmup_max_lr": "auto", - "warmup_num_steps": "auto", - "total_num_steps": "auto" - } - }, - "zero_optimization": { - "stage": 3, - "overlap_comm": true, - "contiguous_gradients": true, - "reduce_bucket_size": 5e8, - "sub_group_size": 1e9, - "stage3_max_live_parameters": 1e9, - "stage3_max_reuse_distance": 1e9, - 
"stage3_gather_16bit_weights_on_model_save": "auto", - "stage3_prefetch_bucket_size": 5e8, - "stage3_param_persistence_threshold": 1e5 - }, - "gradient_accumulation_steps": 1, - "train_micro_batch_size_per_gpu": 1, - "train_batch_size": "auto", - "gradient_clipping": "auto", - "steps_per_print": 2000, - "wall_clock_breakdown": false -} diff --git a/quickstart/configs/zero/zero3_offload.yaml b/quickstart/configs/zero/zero3_offload.yaml deleted file mode 100644 index e780e2f..0000000 --- a/quickstart/configs/zero/zero3_offload.yaml +++ /dev/null @@ -1,49 +0,0 @@ -{ - "bf16": { - "enabled": true - }, - "optimizer": { - "type": "AdamW", - "params": { - "lr": "auto", - "weight_decay": "auto", - "torch_adam": true, - "adam_w_mode": true - } - }, - "scheduler": { - "type": "WarmupDecayLR", - "params": { - "warmup_min_lr": "auto", - "warmup_max_lr": "auto", - "warmup_num_steps": "auto", - "total_num_steps": "auto" - } - }, - "zero_optimization": { - "stage": 3, - "offload_optimizer": { - "device": "cpu", - "pin_memory": true - }, - "offload_param": { - "device": "cpu", - "pin_memory": true - }, - "overlap_comm": true, - "contiguous_gradients": true, - "reduce_bucket_size": 5e8, - "sub_group_size": 1e9, - "stage3_max_live_parameters": 1e9, - "stage3_max_reuse_distance": 1e9, - "stage3_gather_16bit_weights_on_model_save": "auto", - "stage3_prefetch_bucket_size": 5e8, - "stage3_param_persistence_threshold": 1e6 - }, - "gradient_accumulation_steps": 1, - "train_micro_batch_size_per_gpu": 1, - "train_batch_size": "auto", - "gradient_clipping": "auto", - "steps_per_print": 2000, - "wall_clock_breakdown": false -} diff --git a/quickstart/scripts/i2v/config.yaml b/quickstart/scripts/i2v/config.yaml new file mode 100644 index 0000000..43e1afa --- /dev/null +++ b/quickstart/scripts/i2v/config.yaml @@ -0,0 +1,65 @@ +# ================ Logging ================ +name4train: "i2v-train" +log_level: "INFO" # Options: ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] + +# ================ Model ================ +model_name: "cogvideox1.5-i2v" # Options: ["cogvideox-i2v", "cogvideox1.5-i2v"] +model_path: "THUDM/CogVideoX1.5-5B-I2V" + + +# ================ Output ================ +output_dir: "/path/to/output" + + +# ================ Tracker ================ +report_to: null # Options: ["wandb"] + + +# ================ Data ================ +data_root: "/path/to/i2v/data" + + +# ================ Training ================ +seed: 42 +training_type: "lora" # Options: ["lora", "sft"] + +strategy: "DDP" # Options: ["DDP", "SHARD_GRAD_OP", "FULL_SHARD", "HYBRID_SHARD", "_HYBRID_SHARD_ZERO2"] + +# This will offload model param and grads to CPU memory to save GPU memory, but will slow down training +offload_params_grads: false + +# This will increase memory usage since gradients are sharded during accumulation step. +# Note: When used with offload_params_grads, model parameters and gradients will only be offloaded +# to the CPU during the final synchronization (still retained on GPU in gradient accumulation steps) +# which means offload_params_grads is meaningless when used with no_grad_sync_when_accumulating +no_grad_sync_when_accumulating: false + +# When enable_packing is true, training will use the native image resolution, +# otherwise all images will be resized to train_resolution, which may distort the original aspect ratio. 
+# IMPORTANT: When changing enable_packing from true to false (or false to true), +# make sure to clear the `.cache` directories in your `data_root/train` and `data_root/test` folders if they exist. +enable_packing: false + +# Note: +# for CogVideoX series models, number of training frames should be **8N+1** +# for CogVideoX1.5 series models, number of training frames should be **16N+1** +train_resolution: [81, 768, 1360] # [Frames, Height, Width] + +train_epochs: 1 +batch_size: 1 +gradient_accumulation_steps: 1 +mixed_precision: "bf16" # Options: ["fp32", "fp16", "bf16"] +learning_rate: 2.0e-5 + +num_workers: 8 +pin_memory: true + +checkpointing_steps: 10 +checkpointing_limit: 2 +resume_from_checkpoint: null # or "/path/to/checkpoint/dir" + + +# ================ Validation ================ +do_validation: true +validation_steps: 10 # Must be a multiple of `checkpointing_steps` +gen_fps: 16 diff --git a/quickstart/scripts/i2v/start_train.sh b/quickstart/scripts/i2v/start_train.sh new file mode 100644 index 0000000..0dc933f --- /dev/null +++ b/quickstart/scripts/i2v/start_train.sh @@ -0,0 +1,7 @@ +#! /usr/bin/env bash + +torchrun \ + --nproc_per_node=[number of GPUs] \ + --master_port=29501 \ + ../train.py \ + --yaml config.yaml diff --git a/quickstart/scripts/t2i/config.yaml b/quickstart/scripts/t2i/config.yaml new file mode 100644 index 0000000..42210ac --- /dev/null +++ b/quickstart/scripts/t2i/config.yaml @@ -0,0 +1,64 @@ +# ================ Logging ================ +name4train: "t2i-train" +log_level: "INFO" # Options: ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] + +# ================ Model ================ +model_name: "cogview4-6b" # Options: ["cogview4-6b"] +model_path: "THUDM/CogView4-6B" + + +# ================ Output ================ +output_dir: "/path/to/output" + + +# ================ Tracker ================ +report_to: null # Options: ["wandb"] + + +# ================ Data ================ +data_root: "/path/to/t2i/data" + +# ================ Training ================ +seed: 42 +training_type: "lora" # Options: ["lora", "sft"] + +strategy: "DDP" # Options: ["DDP", "SHARD_GRAD_OP", "FULL_SHARD", "HYBRID_SHARD", "_HYBRID_SHARD_ZERO2"] + +# This will offload model param and grads to CPU memory to save GPU memory, but will slow down training +offload_params_grads: false + +# This will increase memory usage since gradients are sharded during accumulation step. +# Note: When used with offload_params_grads, model parameters and gradients will only be offloaded +# to the CPU during the final synchronization (still retained on GPU in gradient accumulation steps) +# which means offload_params_grads is meaningless when used with no_grad_sync_when_accumulating +no_grad_sync_when_accumulating: false + +# When enable_packing is true, training will use the native image resolution, +# otherwise all images will be resized to train_resolution, which may distort the original aspect ratio. +# IMPORTANT: When changing enable_packing from true to false (or false to true), +# make sure to clear the `.cache` directories in your `data_root/train` and `data_root/test` folders if they exist. 
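+# A minimal cleanup sketch, assuming the default dataset layout (`<data_root>`
+# is a placeholder for your actual path):
+#   rm -rf <data_root>/train/.cache <data_root>/test/.cache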
+enable_packing: false + +# This will slow down validation speed and enable quantization during training to save GPU memory +low_vram: false + +# Note: For CogView4 series models, height and width should be **32N** (multiple of 32) +train_resolution: [1024, 1024] # [Height, Width] + +train_epochs: 1 +batch_size: 1 +gradient_accumulation_steps: 1 +mixed_precision: "bf16" # Options: ["fp32", "fp16", "bf16"] +learning_rate: 2.0e-5 + +num_workers: 8 +pin_memory: true + +checkpointing_steps: 10 +checkpointing_limit: 2 +resume_from_checkpoint: null # or "/path/to/checkpoint/dir" + + +# ================ Validation ================ +do_validation: true +validation_steps: 10 # Must be a multiple of `checkpointing_steps` diff --git a/quickstart/scripts/t2i/start_train.sh b/quickstart/scripts/t2i/start_train.sh new file mode 100644 index 0000000..0dc933f --- /dev/null +++ b/quickstart/scripts/t2i/start_train.sh @@ -0,0 +1,7 @@ +#! /usr/bin/env bash + +torchrun \ + --nproc_per_node=[number of GPUs] \ + --master_port=29501 \ + ../train.py \ + --yaml config.yaml diff --git a/quickstart/scripts/t2v/config.yaml b/quickstart/scripts/t2v/config.yaml new file mode 100644 index 0000000..4156f69 --- /dev/null +++ b/quickstart/scripts/t2v/config.yaml @@ -0,0 +1,64 @@ +# ================ Logging ================ +name4train: "t2v-train" +log_level: "INFO" # Options: ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] + +# ================ Model ================ +model_name: "cogvideox1.5-t2v" # Options: ["cogvideox-t2v", "cogvideox1.5-t2v"] +model_path: "THUDM/CogVideoX1.5-5B" + +# ================ Output ================ +output_dir: "/path/to/output" + + +# ================ Tracker ================ +report_to: null # Options: ["wandb"] + + +# ================ Data ================ +data_root: "/path/to/t2v/data" + + +# ================ Training ================ +seed: 42 +training_type: "lora" # Options: ["lora", "sft"] + +strategy: "DDP" # Options: ["DDP", "SHARD_GRAD_OP", "FULL_SHARD", "HYBRID_SHARD", "_HYBRID_SHARD_ZERO2"] + +# This will offload model param and grads to CPU memory to save GPU memory, but will slow down training +offload_params_grads: false + +# This will increase memory usage since gradients are sharded during accumulation step. +# Note: When used with offload_params_grads, model parameters and gradients will only be offloaded +# to the CPU during the final synchronization (still retained on GPU in gradient accumulation steps) +# which means offload_params_grads is meaningless when used with no_grad_sync_when_accumulating +no_grad_sync_when_accumulating: false + +# When enable_packing is true, training will use the native image resolution, +# otherwise all images will be resized to train_resolution, which may distort the original aspect ratio. +# IMPORTANT: When changing enable_packing from true to false (or false to true), +# make sure to clear the `.cache` directories in your `data_root/train` and `data_root/test` folders if they exist. 
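+# Worked example of the frame rule noted below: 49 frames (8*6 + 1) fits the
+# CogVideoX series, while 81 frames (16*5 + 1) fits CogVideoX1.5.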
+enable_packing: false + +# Note: +# for CogVideoX series models, number of training frames should be **8N+1** +# for CogVideoX1.5 series models, number of training frames should be **16N+1** +train_resolution: [81, 768, 1360] # [Frames, Height, Width] + +train_epochs: 1 +batch_size: 1 +gradient_accumulation_steps: 1 +mixed_precision: "bf16" # Options: ["fp32", "fp16", "bf16"] +learning_rate: 2.0e-5 + +num_workers: 8 +pin_memory: true + +checkpointing_steps: 10 +checkpointing_limit: 2 +resume_from_checkpoint: null # or "/path/to/checkpoint/dir" + + +# ================ Validation ================ +do_validation: true +validation_steps: 10 # Must be a multiple of `checkpointing_steps` +gen_fps: 16 diff --git a/quickstart/scripts/t2v/start_train.sh b/quickstart/scripts/t2v/start_train.sh new file mode 100644 index 0000000..0dc933f --- /dev/null +++ b/quickstart/scripts/t2v/start_train.sh @@ -0,0 +1,7 @@ +#! /usr/bin/env bash + +torchrun \ + --nproc_per_node=[number of GPUs] \ + --master_port=29501 \ + ../train.py \ + --yaml config.yaml diff --git a/quickstart/scripts/train.py b/quickstart/scripts/train.py index e899273..d9e901a 100644 --- a/quickstart/scripts/train.py +++ b/quickstart/scripts/train.py @@ -1,17 +1,21 @@ import argparse +import yaml from cogkit.finetune import get_model_cls def main(): parser = argparse.ArgumentParser() - parser.add_argument("--model_name", type=str, required=True) - parser.add_argument("--training_type", type=str, required=True) - parser.add_argument("--enable_packing", type=lambda x: x.lower() == "true") - args, unknown = parser.parse_known_args() + parser.add_argument("--yaml", type=str, required=True) + args = parser.parse_args() - trainer_cls = get_model_cls(args.model_name, args.training_type, args.enable_packing) - trainer = trainer_cls() + with open(args.yaml, "r") as f: + config = yaml.safe_load(f) + + trainer_cls = get_model_cls( + config["model_name"], config["training_type"], config["enable_packing"] + ) + trainer = trainer_cls(args.yaml) trainer.fit() diff --git a/quickstart/scripts/train_ddp_i2v.sh b/quickstart/scripts/train_ddp_i2v.sh deleted file mode 100755 index 57824eb..0000000 --- a/quickstart/scripts/train_ddp_i2v.sh +++ /dev/null @@ -1,70 +0,0 @@ -#!/usr/bin/env bash -# Run by `bash scripts/train_ddp_i2v.sh` - -# Prevent tokenizer parallelism issues -export TOKENIZERS_PARALLELISM=false - -# Model Configuration -MODEL_ARGS=( - --model_path "THUDM/CogVideoX1.5-5B-I2V" - --model_name "cogvideox1.5-i2v" # candidate: ["cogvideox-i2v", "cogvideox1.5-i2v"] - --model_type "i2v" - --training_type "lora" -) - -# Output Configuration -OUTPUT_ARGS=( - --output_dir "/path/to/output" - --report_to "tensorboard" -) - -# Data Configuration -DATA_ARGS=( - --data_root "/path/to/data" -) - -# Training Configuration -TRAIN_ARGS=( - --seed 42 # random seed - --train_epochs 1 # number of training epochs - --batch_size 1 - --gradient_accumulation_steps 1 - --mixed_precision "bf16" # ["no", "fp16"] - --learning_rate 5e-5 - - # Note: - # for CogVideoX series models, number of training frames should be **8N+1** - # for CogVideoX1.5 series models, number of training frames should be **16N+1** - --train_resolution "81x768x1360" # (frames x height x width) -) - -# System Configuration -SYSTEM_ARGS=( - --num_workers 8 - --pin_memory true - --nccl_timeout 1800 -) - -# Checkpointing Configuration -CHECKPOINT_ARGS=( - --checkpointing_steps 10 # save checkpoint every x steps - --checkpointing_limit 2 # maximum number of checkpoints to keep, after which the oldest one is 
deleted - # --resume_from_checkpoint "/absolute/path/to/checkpoint_dir" # if you want to resume from a checkpoint -) - -# Validation Configuration -VALIDATION_ARGS=( - --do_validation true # ["true", "false"] - --validation_steps 10 # should be multiple of checkpointing_steps - --gen_fps 16 -) - -# Combine all arguments and launch training -accelerate launch train.py \ - "${MODEL_ARGS[@]}" \ - "${OUTPUT_ARGS[@]}" \ - "${DATA_ARGS[@]}" \ - "${TRAIN_ARGS[@]}" \ - "${SYSTEM_ARGS[@]}" \ - "${CHECKPOINT_ARGS[@]}" \ - "${VALIDATION_ARGS[@]}" diff --git a/quickstart/scripts/train_ddp_t2i.sh b/quickstart/scripts/train_ddp_t2i.sh deleted file mode 100755 index 6aae45f..0000000 --- a/quickstart/scripts/train_ddp_t2i.sh +++ /dev/null @@ -1,80 +0,0 @@ -#!/usr/bin/env bash -# Run by `bash scripts/train_ddp_i2v.sh` - -# Prevent tokenizer parallelism issues -export TOKENIZERS_PARALLELISM=false - -# Model Configuration -MODEL_ARGS=( - --model_path "THUDM/CogView4-6B" - --model_name "cogview4-6b" # candidate: ["cogview4-6b"] - --model_type "t2i" - --training_type "lora" -) - -# Output Configuration -OUTPUT_ARGS=( - --output_dir "/path/to/output" - --report_to "tensorboard" -) - -# Data Configuration -DATA_ARGS=( - --data_root "/path/to/data" -) - -# Training Configuration -TRAIN_ARGS=( - --seed 42 # random seed - --train_epochs 1 # number of training epochs - --batch_size 1 - - --gradient_accumulation_steps 1 - - # Note: For CogView4 series models, height and width should be **32N** (multiple of 32) - --train_resolution "1024x1024" # (height x width) - - # When enable_packing is true, training will use the native image resolution - # (otherwise all images will be resized to train_resolution, which may distort the original aspect ratio). - # - # IMPORTANT: When changing enable_packing from true to false (or vice versa), - # make sure to clear the .cache directories in your data_root/train and data_root/test folders if they exist. 
- --enable_packing false - - --mixed_precision "bf16" # ["no", "fp16"] - --learning_rate 5e-5 - - # enable --low_vram will slow down validation speed and enable quantization during training - # Note: --low_vram currently does not support multi-GPU training - --low_vram false -) - -# System Configuration -SYSTEM_ARGS=( - --num_workers 8 - --pin_memory true - --nccl_timeout 1800 -) - -# Checkpointing Configuration -CHECKPOINT_ARGS=( - --checkpointing_steps 10 # save checkpoint every x steps - --checkpointing_limit 2 # maximum number of checkpoints to keep, after which the oldest one is deleted - # --resume_from_checkpoint "/absolute/path/to/checkpoint_dir" # if you want to resume from a checkpoint -) - -# Validation Configuration -VALIDATION_ARGS=( - --do_validation true # ["true", "false"] - --validation_steps 10 # should be multiple of checkpointing_steps -) - -# Combine all arguments and launch training -accelerate launch train.py \ - "${MODEL_ARGS[@]}" \ - "${OUTPUT_ARGS[@]}" \ - "${DATA_ARGS[@]}" \ - "${TRAIN_ARGS[@]}" \ - "${SYSTEM_ARGS[@]}" \ - "${CHECKPOINT_ARGS[@]}" \ - "${VALIDATION_ARGS[@]}" diff --git a/quickstart/scripts/train_ddp_t2v.sh b/quickstart/scripts/train_ddp_t2v.sh deleted file mode 100755 index ca31e49..0000000 --- a/quickstart/scripts/train_ddp_t2v.sh +++ /dev/null @@ -1,69 +0,0 @@ -#!/usr/bin/env bash - -# Prevent tokenizer parallelism issues -export TOKENIZERS_PARALLELISM=false - -# Model Configuration -MODEL_ARGS=( - --model_path "THUDM/CogVideoX1.5-5B" - --model_name "cogvideox1.5-t2v" # candidate: ["cogvideox-t2v", "cogvideox1.5-t2v"] - --model_type "t2v" - --training_type "lora" -) - -# Output Configuration -OUTPUT_ARGS=( - --output_dir "/path/to/output" - --report_to "tensorboard" -) - -# Data Configuration -DATA_ARGS=( - --data_root "/path/to/data" -) - -# Training Configuration -TRAIN_ARGS=( - --seed 42 # random seed - --train_epochs 1 # number of training epochs - --batch_size 1 - --gradient_accumulation_steps 1 - --mixed_precision "bf16" # ["no", "fp16"] Note: CogVideoX-2B only supports fp16 training - --learning_rate 5e-5 - - # Note: - # for CogVideoX series models, number of training frames should be **8N+1** - # for CogVideoX1.5 series models, number of training frames should be **16N+1** - --train_resolution "81x768x1360" # (frames x height x width) -) - -# System Configuration -SYSTEM_ARGS=( - --num_workers 8 - --pin_memory true - --nccl_timeout 1800 -) - -# Checkpointing Configuration -CHECKPOINT_ARGS=( - --checkpointing_steps 10 # save checkpoint every x steps - --checkpointing_limit 2 # maximum number of checkpoints to keep, after which the oldest one is deleted - # --resume_from_checkpoint "/absolute/path/to/checkpoint_dir" # if you want to resume from a checkpoint -) - -# Validation Configuration -VALIDATION_ARGS=( - --do_validation true # ["true", "false"] - --validation_steps 10 # should be multiple of checkpointing_steps - --gen_fps 16 -) - -# Combine all arguments and launch training -accelerate launch train.py \ - "${MODEL_ARGS[@]}" \ - "${OUTPUT_ARGS[@]}" \ - "${DATA_ARGS[@]}" \ - "${TRAIN_ARGS[@]}" \ - "${SYSTEM_ARGS[@]}" \ - "${CHECKPOINT_ARGS[@]}" \ - "${VALIDATION_ARGS[@]}" diff --git a/quickstart/scripts/train_zero_i2v.sh b/quickstart/scripts/train_zero_i2v.sh deleted file mode 100755 index bfa07b7..0000000 --- a/quickstart/scripts/train_zero_i2v.sh +++ /dev/null @@ -1,74 +0,0 @@ -#!/usr/bin/env bash - -# Prevent tokenizer parallelism issues -export TOKENIZERS_PARALLELISM=false - -# Model Configuration -MODEL_ARGS=( - --model_path 
"THUDM/CogVideoX1.5-5B-I2V" - --model_name "cogvideox1.5-i2v" # candidate: ["cogvideox-i2v", "cogvideox1.5-i2v"] - --model_type "i2v" - --training_type "sft" -) - -# Output Configuration -OUTPUT_ARGS=( - --output_dir "/path/to/output" - --report_to "tensorboard" -) - -# Data Configuration -DATA_ARGS=( - --data_root "/path/to/data" -) - -# Training Configuration -TRAIN_ARGS=( - --seed 42 # random seed - --train_epochs 1 # number of training epochs - - --learning_rate 5e-5 - - ######### Please keep consistent with deepspeed config file ########## - --batch_size 1 - --gradient_accumulation_steps 1 - --mixed_precision "bf16" # ["no", "fp16"] Note: CogVideoX-2B only supports fp16 training - ######################################################################## - - # Note: - # for CogVideoX series models, number of training frames should be **8N+1** - # for CogVideoX1.5 series models, number of training frames should be **16N+1** - --train_resolution "81x768x1360" # (frames x height x width) - -) - -# System Configuration -SYSTEM_ARGS=( - --num_workers 8 - --pin_memory true - --nccl_timeout 1800 -) - -# Checkpointing Configuration -CHECKPOINT_ARGS=( - --checkpointing_steps 10 # save checkpoint every x steps - --checkpointing_limit 2 # maximum number of checkpoints to keep, after which the oldest one is deleted - # --resume_from_checkpoint "/absolute/path/to/checkpoint_dir" # if you want to resume from a checkpoint -) - -# Validation Configuration -VALIDATION_ARGS=( - --do_validation true # ["true", "false"] - --validation_steps 10 # should be multiple of checkpointing_steps - --gen_fps 16 -) - -# Combine all arguments and launch training -accelerate launch --config_file ../configs/accelerate_config.yaml train.py \ - "${MODEL_ARGS[@]}" \ - "${OUTPUT_ARGS[@]}" \ - "${DATA_ARGS[@]}" \ - "${TRAIN_ARGS[@]}" \ - "${SYSTEM_ARGS[@]}" \ - "${CHECKPOINT_ARGS[@]}" \ - "${VALIDATION_ARGS[@]}" diff --git a/quickstart/scripts/train_zero_t2i.sh b/quickstart/scripts/train_zero_t2i.sh deleted file mode 100755 index 878cd27..0000000 --- a/quickstart/scripts/train_zero_t2i.sh +++ /dev/null @@ -1,78 +0,0 @@ -#!/usr/bin/env bash - -# Prevent tokenizer parallelism issues -export TOKENIZERS_PARALLELISM=false - -# Model Configuration -MODEL_ARGS=( - --model_path "THUDM/CogView4-6B" - --model_name "cogview4-6b" # candidate: ["cogview4-6b"] - --model_type "t2i" - --training_type "sft" -) - -# Output Configuration -OUTPUT_ARGS=( - --output_dir "/path/to/output" - --report_to "tensorboard" -) - -# Data Configuration -DATA_ARGS=( - --data_root "/path/to/data" -) - -# Training Configuration -TRAIN_ARGS=( - --seed 42 # random seed - --train_epochs 1 # number of training epochs - - --learning_rate 5e-5 - - # Note: For CogView4 series models, height and width should be **32N** (multiple of 32) - --train_resolution "1024x1024" # (height x width) - - ######### Please keep consistent with deepspeed config file ########## - --batch_size 1 - --gradient_accumulation_steps 1 - --mixed_precision "bf16" # ["no", "fp16"] Note: CogVideoX-2B only supports fp16 training - ######################################################################## - - # When enable_packing is true, training will use the native image resolution - # (otherwise all images will be resized to train_resolution, which may distort the original aspect ratio). - # - # IMPORTANT: When changing enable_packing from true to false (or vice versa), - # make sure to clear the .cache directories in your data_root/train and data_root/test folders if they exist. 
- --enable_packing false - -) - -# System Configuration -SYSTEM_ARGS=( - --num_workers 8 - --pin_memory true - --nccl_timeout 1800 -) - -# Checkpointing Configuration -CHECKPOINT_ARGS=( - --checkpointing_steps 10 # save checkpoint every x steps - --checkpointing_limit 2 # maximum number of checkpoints to keep, after which the oldest one is deleted - # --resume_from_checkpoint "/absolute/path/to/checkpoint_dir" # if you want to resume from a checkpoint -) - -# Validation Configuration -VALIDATION_ARGS=( - --do_validation true # ["true", "false"] - --validation_steps 10 # should be multiple of checkpointing_steps -) - -# Combine all arguments and launch training -accelerate launch --config_file ../configs/accelerate_config.yaml train.py\ - "${MODEL_ARGS[@]}" \ - "${OUTPUT_ARGS[@]}" \ - "${DATA_ARGS[@]}" \ - "${TRAIN_ARGS[@]}" \ - "${SYSTEM_ARGS[@]}" \ - "${CHECKPOINT_ARGS[@]}" \ - "${VALIDATION_ARGS[@]}" diff --git a/quickstart/scripts/train_zero_t2v.sh b/quickstart/scripts/train_zero_t2v.sh deleted file mode 100755 index 516afc2..0000000 --- a/quickstart/scripts/train_zero_t2v.sh +++ /dev/null @@ -1,73 +0,0 @@ -#!/usr/bin/env bash - -# Prevent tokenizer parallelism issues -export TOKENIZERS_PARALLELISM=false - -# Model Configuration -MODEL_ARGS=( - --model_path "THUDM/CogVideoX1.5-5B" - --model_name "cogvideox1.5-t2v" # candidate: ["cogvideox-t2v", "cogvideox1.5-t2v"] - --model_type "t2v" - --training_type "sft" -) - -# Output Configuration -OUTPUT_ARGS=( - --output_dir "/path/to/output" - --report_to "tensorboard" -) - -# Data Configuration -DATA_ARGS=( - --data_root "/path/to/data" -) - -# Training Configuration -TRAIN_ARGS=( - --seed 42 # random seed - --train_epochs 1 # number of training epochs - - --learning_rate 5e-5 - - ######### Please keep consistent with deepspeed config file ########## - --batch_size 1 - --gradient_accumulation_steps 1 - --mixed_precision "bf16" # ["no", "fp16"] Note: CogVideoX-2B only supports fp16 training - ######################################################################## - - # Note: - # for CogVideoX series models, number of training frames should be **8N+1** - # for CogVideoX1.5 series models, number of training frames should be **16N+1** - --train_resolution "81x768x1360" # (frames x height x width) -) - -# System Configuration -SYSTEM_ARGS=( - --num_workers 8 - --pin_memory true - --nccl_timeout 1800 -) - -# Checkpointing Configuration -CHECKPOINT_ARGS=( - --checkpointing_steps 10 # save checkpoint every x steps - --checkpointing_limit 2 # maximum number of checkpoints to keep, after which the oldest one is deleted - # --resume_from_checkpoint "/absolute/path/to/checkpoint_dir" # if you want to resume from a checkpoint -) - -# Validation Configuration -VALIDATION_ARGS=( - --do_validation true # ["true", "false"] - --validation_steps 10 # should be multiple of checkpointing_steps - --gen_fps 16 -) - -# Combine all arguments and launch training -accelerate launch --config_file ../configs/accelerate_config.yaml train.py \ - "${MODEL_ARGS[@]}" \ - "${OUTPUT_ARGS[@]}" \ - "${DATA_ARGS[@]}" \ - "${TRAIN_ARGS[@]}" \ - "${SYSTEM_ARGS[@]}" \ - "${CHECKPOINT_ARGS[@]}" \ - "${VALIDATION_ARGS[@]}" From 41ec0d9c3b17048aa900c58f47d3b39c473ea74d Mon Sep 17 00:00:00 2001 From: OleehyO Date: Fri, 9 May 2025 10:41:35 +0000 Subject: [PATCH 16/19] [tool] Add merge tool for dist checkpoint --- tools/converters/merge.py | 47 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100755 tools/converters/merge.py diff --git 
a/tools/converters/merge.py b/tools/converters/merge.py new file mode 100755 index 0000000..5f5c547 --- /dev/null +++ b/tools/converters/merge.py @@ -0,0 +1,47 @@ +#! /usr/bin/env python + +import argparse + +import torch +from torch.distributed.checkpoint.format_utils import dcp_to_torch_save +from pathlib import Path +from cogkit.utils.lora import _LORA_WEIGHT_NAME +from safetensors.torch import save_file + +TORCH_SAVE_CHECKPOINT_DIR = "diffusion_pytorch_model.bin" + + +def main(checkpoint_dir: str, output_dir: str, is_lora: bool = False): + # convert dcp model to torch.save (assumes checkpoint was generated as above) + checkpoint_dir = Path(checkpoint_dir) + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + ckpt_file = output_dir / TORCH_SAVE_CHECKPOINT_DIR + + print("Converting FSDP checkpoint to torch.save format...") + dcp_to_torch_save(checkpoint_dir, ckpt_file) + state = torch.load(ckpt_file, map_location="cpu") + print("Deleting torch checkpoint...") + ckpt_file.unlink() + model_weights = state["app"]["model"] + + print("Saving transformer weights...") + if is_lora: + ckpt_file = ckpt_file.with_name(_LORA_WEIGHT_NAME) + save_file(model_weights, ckpt_file) + + else: + ckpt_file = ckpt_file.with_name(TORCH_SAVE_CHECKPOINT_DIR) + torch.save(model_weights, ckpt_file) + + print("Done.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--checkpoint_dir", type=str, required=True) + parser.add_argument("--output_dir", type=str, required=True) + parser.add_argument("--lora", action="store_true", default=False) + args = parser.parse_args() + + main(args.checkpoint_dir, args.output_dir, args.lora) From d66ce0fd8ed0c90a790edd21fe8c38afce2b6c91 Mon Sep 17 00:00:00 2001 From: OleehyO Date: Fri, 9 May 2025 10:42:23 +0000 Subject: [PATCH 17/19] [deps] Update dependencies --- pyproject.toml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0ddad69..3d52425 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,6 @@ dependencies = [ "pydantic~=2.10", "sentencepiece==0.2.0", "transformers~=4.49", - "wandb~=0.19.8", "fastapi[standard]~=0.115.11", "fastapi_cli~=0.0.7", "openai~=1.67", @@ -31,10 +30,10 @@ dependencies = [ [project.optional-dependencies] finetune = [ "datasets~=3.4", - "deepspeed~=0.16.4", + "wandb~=0.19.8", "av~=14.2.0", "bitsandbytes~=0.45.4", - "tensorboard~=2.19", + "pyyaml>=6.0.2", ] [project.urls] From 72445364510686c51c9b2fac1f394ac20bc9e6cc Mon Sep 17 00:00:00 2001 From: OleehyO Date: Fri, 9 May 2025 10:45:18 +0000 Subject: [PATCH 18/19] [docs] Update --- docs/01-Intro.md | 7 +- docs/04-Finetune/01-Prerequisites.mdx | 118 ++++++++------------------ docs/04-Finetune/02-Quick Start.md | 26 +++--- 3 files changed, 47 insertions(+), 104 deletions(-) diff --git a/docs/01-Intro.md b/docs/01-Intro.md index 92b4fe2..635a690 100644 --- a/docs/01-Intro.md +++ b/docs/01-Intro.md @@ -4,7 +4,7 @@ slug: / # Introduction -CogKit is an open-source project that provides a user-friendly interface for researchers and developers to utilize ZhipuAI's [**CogView**](https://huggingface.co/collections/THUDM/cogview-67ac3f241eefad2af015669b) (image generation) and [**CogVideoX**](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce) (video generation) models. It streamlines multimodal tasks such as **text-to-image (T2I)**, **text-to-video (T2V)**, and **image-to-video (I2V)**. 
Users must comply with legal and ethical guidelines to ensure responsible implementation.
+CogKit is an open-source project that provides a user-friendly interface for researchers and developers to utilize ZhipuAI's [CogView](https://huggingface.co/collections/THUDM/cogview-67ac3f241eefad2af015669b) (image generation) and [CogVideoX](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce) (video generation) models. It streamlines multimodal tasks such as text-to-image (T2I), text-to-video (T2V), and image-to-video (I2V). Users must comply with legal and ethical guidelines to ensure responsible implementation.

 ## Supported Models

@@ -12,7 +12,4 @@ Please refer to the [Model Card](./05-Model%20Card.mdx) for more details.

 ## Environment Testing

-This repository has been tested in environments with `1×A100` and `8×A100` GPUs, using `CUDA 12.4, Python 3.10.16`.
-
-- Cog series models typically do not support `FP16` precision (Only `CogVideoX-2B` support); GPUs like the `V100` cannot be fine-tuned properly (Will cause `loss=nan` for example). At a minimum, an `A100` or other GPUs supporting `BF16` precision should be used.
-- We have not yet systematically tested the minimum GPU memory requirements for each model. For `LORA(bs=1 with offload)`, a single `A100` GPU is sufficient. For `SFT`, our tests have passed in an `8×A100` environment.
+This repository has been tested in environments with 8×A100 GPUs, using CUDA 12.4, Python 3.10.16.
diff --git a/docs/04-Finetune/01-Prerequisites.mdx b/docs/04-Finetune/01-Prerequisites.mdx
index f004cc9..2a32ceb 100644
--- a/docs/04-Finetune/01-Prerequisites.mdx
+++ b/docs/04-Finetune/01-Prerequisites.mdx
@@ -3,7 +3,7 @@

 # Prerequisites

-Before starting fine-tuning, please ensure your machine meets the minimum hardware requirements listed in the tables below. The tables show the minimum VRAM (GPU memory) requirements for different models under various configurations.
+Before starting fine-tuning, please ensure your machine meets the minimum hardware requirements listed in the tables below. The tables show the minimum VRAM requirements for different models under various configurations (tested on 8×A100).

 ## CogVideo Series

 <table>
   <tr>
     <td>Model</td>
-    <td>Training Type</td>
-    <td>Distribution Strategy</td>
-    <td>Training Resolution (FxHxW)</td>
+    <td>Type</td>
+    <td>Strategy</td>
+    <td>Resolution
From 72445364510686c51c9b2fac1f394ac20bc9e6cc Mon Sep 17 00:00:00 2001
From: OleehyO
Date: Fri, 9 May 2025 10:45:18 +0000
Subject: [PATCH 18/19] [docs] Update

---
 docs/01-Intro.md                      |   7 +-
 docs/04-Finetune/01-Prerequisites.mdx | 118 ++++++++------------------
 docs/04-Finetune/02-Quick Start.md    |  26 +++---
 3 files changed, 47 insertions(+), 104 deletions(-)

diff --git a/docs/01-Intro.md b/docs/01-Intro.md
index 92b4fe2..635a690 100644
--- a/docs/01-Intro.md
+++ b/docs/01-Intro.md
@@ -4,7 +4,7 @@ slug: /
 
 # Introduction
 
-CogKit is an open-source project that provides a user-friendly interface for researchers and developers to utilize ZhipuAI's [**CogView**](https://huggingface.co/collections/THUDM/cogview-67ac3f241eefad2af015669b) (image generation) and [**CogVideoX**](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce) (video generation) models. It streamlines multimodal tasks such as **text-to-image (T2I)**, **text-to-video (T2V)**, and **image-to-video (I2V)**. Users must comply with legal and ethical guidelines to ensure responsible implementation.
+CogKit is an open-source project that provides a user-friendly interface for researchers and developers to utilize ZhipuAI's [CogView](https://huggingface.co/collections/THUDM/cogview-67ac3f241eefad2af015669b) (image generation) and [CogVideoX](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce) (video generation) models. It streamlines multimodal tasks such as text-to-image (T2I), text-to-video (T2V), and image-to-video (I2V). Users must comply with legal and ethical guidelines to ensure responsible implementation.
 
 ## Supported Models
 
@@ -12,7 +12,4 @@ Please refer to the [Model Card](./05-Model%20Card.mdx) for more details.
 
 ## Environment Testing
 
-This repository has been tested in environments with `1×A100` and `8×A100` GPUs, using `CUDA 12.4, Python 3.10.16`.
-
-- Cog series models typically do not support `FP16` precision (Only `CogVideoX-2B` support); GPUs like the `V100` cannot be fine-tuned properly (Will cause `loss=nan` for example). At a minimum, an `A100` or other GPUs supporting `BF16` precision should be used.
-- We have not yet systematically tested the minimum GPU memory requirements for each model. For `LORA(bs=1 with offload)`, a single `A100` GPU is sufficient. For `SFT`, our tests have passed in an `8×A100` environment.
+This repository has been tested in environments with 8×A100 GPUs, using CUDA 12.4, Python 3.10.16.
diff --git a/docs/04-Finetune/01-Prerequisites.mdx b/docs/04-Finetune/01-Prerequisites.mdx
index f004cc9..2a32ceb 100644
--- a/docs/04-Finetune/01-Prerequisites.mdx
+++ b/docs/04-Finetune/01-Prerequisites.mdx
@@ -3,7 +3,7 @@
 ---
 
 # Prerequisites
 
-Before starting fine-tuning, please ensure your machine meets the minimum hardware requirements listed in the tables below. The tables show the minimum VRAM (GPU memory) requirements for different models under various configurations.
+Before starting fine-tuning, please ensure your machine meets the minimum hardware requirements listed in the tables below. The tables show the minimum VRAM requirements for different models under various configurations (tested on 8×A100).
 
 ## CogVideo Series
@@ -11,101 +11,61 @@ Before starting fine-tuning, please ensure your machine meets the minimum hardwa
 <table>
   <thead>
     <tr>
       <th>Model</th>
-      <th>Training Type</th>
-      <th>Distribution Strategy</th>
-      <th>Training Resolution (FxHxW)</th>
+      <th>Type</th>
+      <th>Strategy</th>
+      <th>Resolution<br/>(FxHxW)</th>
       <th>Requirement</th>
     </tr>
   </thead>
   <tbody>
     <tr>
-      <td rowspan="6">cogvideox-t2v-2b</td>
+      <td rowspan="2">cogvideox-t2v-2b</td>
       <td>lora</td>
       <td>DDP</td>
       <td>49x480x720</td>
-      <td>16GB VRAM</td>
+      <td>1 GPU with<br/>12GB VRAM</td>
     </tr>
     <tr>
-      <td rowspan="5">sft</td>
+      <td>sft</td>
       <td>DDP</td>
       <td>49x480x720</td>
-      <td>36GB VRAM</td>
+      <td>1 GPU with<br/>25GB VRAM</td>
     </tr>
-    <tr>
-      <td>1-GPU zero-2 + opt offload</td>
-      <td>49x480x720</td>
-      <td>17GB VRAM</td>
-    </tr>
-    <tr>
-      <td>8-GPU zero-2</td>
-      <td>49x480x720</td>
-      <td>17GB VRAM</td>
-    </tr>
-    <tr>
-      <td>8-GPU zero-3</td>
-      <td>49x480x720</td>
-      <td>19GB VRAM</td>
-    </tr>
-    <tr>
-      <td>8-GPU zero-3 + opt and param offload</td>
-      <td>49x480x720</td>
-      <td>14GB VRAM</td>
-    </tr>
     <tr>
-      <td rowspan="5">cogvideox-\{t2v,i2v\}-5b</td>
+      <td rowspan="3">cogvideox-\{t2v,i2v\}-5b</td>
       <td>lora</td>
       <td>DDP</td>
       <td>49x480x720</td>
-      <td>24GB VRAM</td>
+      <td>1 GPU with<br/>24GB VRAM</td>
     </tr>
     <tr>
-      <td rowspan="4">sft</td>
-      <td>1-GPU zero-2 + opt offload</td>
+      <td rowspan="2">sft</td>
+      <td>FSDP fullshard</td>
       <td>49x480x720</td>
-      <td>42GB VRAM</td>
+      <td>8 GPUs with<br/>20GB VRAM</td>
     </tr>
     <tr>
-      <td>8-GPU zero-2</td>
+      <td>FSDP fullshard + offload</td>
       <td>49x480x720</td>
-      <td>42GB VRAM</td>
+      <td>1 GPU with<br/>16GB VRAM</td>
     </tr>
-    <tr>
-      <td>8-GPU zero-3</td>
-      <td>49x480x720</td>
-      <td>43GB VRAM</td>
-    </tr>
-    <tr>
-      <td>8-GPU zero-3 + opt and param offload</td>
-      <td>49x480x720</td>
-      <td>28GB VRAM</td>
-    </tr>
     <tr>
-      <td rowspan="5">cogvideox1.5-\{t2v,i2v\}-5b</td>
+      <td rowspan="3">cogvideox1.5-\{t2v,i2v\}-5b</td>
      <td>lora</td>
       <td>DDP</td>
       <td>81x768x1360</td>
-      <td>35GB VRAM</td>
+      <td>1 GPU with<br/>32GB VRAM</td>
     </tr>
     <tr>
-      <td rowspan="4">sft</td>
-      <td>1-GPU zero-2 + opt offload</td>
+      <td rowspan="2">sft</td>
+      <td>FSDP fullshard</td>
       <td>81x768x1360</td>
-      <td>56GB VRAM</td>
+      <td>8 GPUs with<br/>31GB VRAM</td>
     </tr>
     <tr>
-      <td>8-GPU zero-2</td>
+      <td>FSDP fullshard + offload</td>
       <td>81x768x1360</td>
-      <td>55GB VRAM</td>
+      <td>8 GPUs with<br/>27GB VRAM</td>
     </tr>
-    <tr>
-      <td>8-GPU zero-3</td>
-      <td>81x768x1360</td>
-      <td>55GB VRAM</td>
-    </tr>
-    <tr>
-      <td>8-GPU zero-3 + opt and param offload</td>
-      <td>81x768x1360</td>
-      <td>40GB VRAM</td>
-    </tr>
   </tbody>
 </table>
@@ -116,46 +76,36 @@ Before starting fine-tuning, please ensure your machine meets the minimum hardwa
 <table>
   <thead>
     <tr>
       <th>Model</th>
-      <th>Training Type</th>
-      <th>Distribution Strategy</th>
-      <th>Training Resolution (HxW)</th>
+      <th>Type</th>
+      <th>Strategy</th>
+      <th>Resolution<br/>(HxW)</th>
       <th>Requirement</th>
     </tr>
   </thead>
   <tbody>
     <tr>
-      <td rowspan="6">CogView4-6B</td>
-      <td>qlora + param offload<br/>(`--low_vram`)</td>
+      <td rowspan="4">CogView4-6B</td>
+      <td>qlora + offload<br/>(enable --low_vram)</td>
       <td>DDP</td>
       <td>1024x1024</td>
-      <td>9GB VRAM</td>
+      <td>1 GPU with<br/>9GB VRAM</td>
     </tr>
     <tr>
       <td>lora</td>
       <td>DDP</td>
       <td>1024x1024</td>
-      <td>30GB VRAM</td>
+      <td>1 GPU with<br/>20GB VRAM</td>
     </tr>
     <tr>
-      <td rowspan="4">sft</td>
-      <td>1-GPU zero-2 + opt offload</td>
-      <td>1024x1024</td>
-      <td>42GB VRAM</td>
-    </tr>
-    <tr>
-      <td>8-GPU zero-2</td>
-      <td>1024x1024</td>
-      <td>50GB VRAM</td>
+      <td rowspan="2">sft</td>
+      <td>FSDP fullshard</td>
+      <td>1024x1024</td>
+      <td>8 GPUs with<br/>28GB VRAM</td>
     </tr>
     <tr>
-      <td>8-GPU zero-3</td>
+      <td>FSDP fullshard + offload</td>
       <td>1024x1024</td>
-      <td>47GB VRAM</td>
+      <td>8 GPUs with<br/>22GB VRAM</td>
     </tr>
-    <tr>
-      <td>8-GPU zero-3 + opt and param offload</td>
-      <td>1024x1024</td>
-      <td>28GB VRAM</td>
-    </tr>
   </tbody>
 </table>
diff --git a/docs/04-Finetune/02-Quick Start.md b/docs/04-Finetune/02-Quick Start.md
index ea388a5..29a04b5 100644
--- a/docs/04-Finetune/02-Quick Start.md
+++ b/docs/04-Finetune/02-Quick Start.md
@@ -31,32 +31,28 @@ We recommend that you read the corresponding [model card](../05-Model%20Card.mdx
    cd CogKit/
    ```
 
-2. Choose the appropriate training script from the `quickstart/scripts` directory based on your task type and distribution strategy. For example, `train_ddp_t2i.sh` corresponds to DDP strategy + text-to-image task
+2. Choose the appropriate sub directory from the `quickstart/scripts` based on your task type and distribution strategy. For example, `t2i` corresponds to text-to-image task
 
-3. Review and adjust the parameters in the selected training script (e.g., `--data_root`, `--output_dir`, etc.)
+3. Review and adjust the parameters in `config.yaml` in the selected training directory
 
-4. [Optional] If you are using ZeRO strategy, refer to `quickstart/configs/accelerate_config.yaml` to confirm your ZeRO config file and number of GPUs.
-
-5. Run the script, for example:
+4. Run the script in selected directory:
 
    ```bash
-   cd quickstart/scripts
-   bash train_ddp_t2i.sh
+   bash start_train.sh
    ```
 
 ## Load Fine-tuned Model
 
-### LoRA
-
-After fine-tuning with LoRA, you can load your trained weights during inference using the `--lora_model_id_or_path` option or parameter. For more details, please refer to the inference guide.
+### Merge Checkpoint
 
-### ZeRO
-
-After fine-tuning with ZeRO strategy, you need to use the `zero_to_fp32.py` script provided in the `quickstart/tools/converters` directory to convert the ZeRO checkpoint weights into Diffusers format. For example:
+After fine-tuning, you need to use the `merge.py` script provided in the `quickstart/tools/converters` directory to merge the distributed checkpoint weights into a single checkpoint (**except for QLoRA fine-tuning**). For example:
 
 ```bash
 cd quickstart/tools/converters
-python zero2diffusers.py checkpoint_dir/ output_dir/ --bfloat16
+python merge.py --checkpoint_dir ckpt/ --output_dir output_dir/
+# Add --lora option if you are using LoRA fine-tuning
 ```
 
-During inference, pass the `output_dir/` to the `--transformer_path` option or parameter. For more details, please refer to the inference guide.
+### Load Checkpoint
+
+You can pass the `output_dir` to the `--lora_model_id_or_path` option if you are using LoRA fine-tuning, or to the `--transformer_path` option if you are using FSDP fine-tuning. For more details, please refer to the inference guide.
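Since `merge.py` keeps its logic in `main()`, the Quick Start's CLI invocation can also be driven programmatically, for instance from a notebook or a post-training hook. A minimal sketch, with placeholder paths standing in for the real checkpoint and output directories:

```python
# Programmatic equivalent of:
#   python merge.py --checkpoint_dir ckpt/ --output_dir output_dir/ --lora
# Run from tools/converters/ (or with that directory on sys.path).
from merge import main

main("ckpt/", "output_dir/", is_lora=True)  # drop is_lora for full (sft) checkpoints
```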
From 5e42b8e2aae5eded9b9cb09f03c696c8575ebb54 Mon Sep 17 00:00:00 2001
From: Chenhui Zhang
Date: Wed, 14 May 2025 19:20:46 +0800
Subject: [PATCH 19/19] nit: resolves format issues

---
 docs/04-Finetune/02-Quick Start.md | 9 ++++++---
 tools/converters/merge.py          | 9 ++++++---
 2 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/docs/04-Finetune/02-Quick Start.md b/docs/04-Finetune/02-Quick Start.md
index 29a04b5..af7a79b 100644
--- a/docs/04-Finetune/02-Quick Start.md
+++ b/docs/04-Finetune/02-Quick Start.md
@@ -27,15 +27,16 @@ We recommend that you read the corresponding [model card](../05-Model%20Card.mdx
 :::
 
 1. Navigate to the `CogKit/` directory after cloning the repository
+
    ```bash
    cd CogKit/
    ```
 
-2. Choose the appropriate sub directory from the `quickstart/scripts` based on your task type and distribution strategy. For example, `t2i` corresponds to text-to-image task
+2. Choose the appropriate subdirectory from the `quickstart/scripts` based on your task type and distribution strategy. For example, `t2i` corresponds to text-to-image task
 
 3. Review and adjust the parameters in `config.yaml` in the selected training directory
 
-4. Run the script in selected directory:
+4. Run the script in the selected directory:
 
    ```bash
    bash start_train.sh
@@ -45,7 +46,9 @@
 
 ### Merge Checkpoint
 
-After fine-tuning, you need to use the `merge.py` script provided in the `quickstart/tools/converters` directory to merge the distributed checkpoint weights into a single checkpoint (**except for QLoRA fine-tuning**). For example:
+After fine-tuning, you need to use the `merge.py` script to merge the distributed checkpoint weights into a single checkpoint (**except for QLoRA fine-tuning**).
+The script can be found in the `quickstart/tools/converters` directory.
+For example:
 
 ```bash
 cd quickstart/tools/converters
diff --git a/tools/converters/merge.py b/tools/converters/merge.py
index 5f5c547..e6999f0 100755
--- a/tools/converters/merge.py
+++ b/tools/converters/merge.py
@@ -1,12 +1,15 @@
-#! /usr/bin/env python
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
 
 import argparse
+from pathlib import Path
 
 import torch
+from safetensors.torch import save_file
 from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
-from pathlib import Path
+
 from cogkit.utils.lora import _LORA_WEIGHT_NAME
-from safetensors.torch import save_file
 
 TORCH_SAVE_CHECKPOINT_DIR = "diffusion_pytorch_model.bin"
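With the imports tidied as above, the converter's LoRA output can be inspected the same way as the full checkpoint. A short sketch of loading the safetensors file written by the `--lora` path; the actual filename comes from `cogkit.utils.lora._LORA_WEIGHT_NAME`, and `output_dir/` is a placeholder:

```python
# Inspect the safetensors file that merge.py writes when --lora is passed.
from pathlib import Path

from safetensors.torch import load_file

from cogkit.utils.lora import _LORA_WEIGHT_NAME

lora_state = load_file(Path("output_dir") / _LORA_WEIGHT_NAME)
print(f"{len(lora_state)} LoRA tensors")
print(sorted(lora_state)[:5])  # first few adapter tensor names
```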