diff --git a/docs/01-Intro.md b/docs/01-Intro.md
index 92b4fe2..635a690 100644
--- a/docs/01-Intro.md
+++ b/docs/01-Intro.md
@@ -4,7 +4,7 @@ slug: /
# Introduction
-CogKit is an open-source project that provides a user-friendly interface for researchers and developers to utilize ZhipuAI's [**CogView**](https://huggingface.co/collections/THUDM/cogview-67ac3f241eefad2af015669b) (image generation) and [**CogVideoX**](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce) (video generation) models. It streamlines multimodal tasks such as **text-to-image (T2I)**, **text-to-video (T2V)**, and **image-to-video (I2V)**. Users must comply with legal and ethical guidelines to ensure responsible implementation.
+CogKit is an open-source project that provides a user-friendly interface for researchers and developers to utilize ZhipuAI's [CogView](https://huggingface.co/collections/THUDM/cogview-67ac3f241eefad2af015669b) (image generation) and [CogVideoX](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce) (video generation) models. It streamlines multimodal tasks such as text-to-image (T2I), text-to-video (T2V), and image-to-video (I2V). Users must comply with legal and ethical guidelines to ensure responsible implementation.
## Supported Models
@@ -12,7 +12,4 @@ Please refer to the [Model Card](./05-Model%20Card.mdx) for more details.
## Environment Testing
-This repository has been tested in environments with `1×A100` and `8×A100` GPUs, using `CUDA 12.4, Python 3.10.16`.
-
-- Cog series models typically do not support `FP16` precision (Only `CogVideoX-2B` support); GPUs like the `V100` cannot be fine-tuned properly (Will cause `loss=nan` for example). At a minimum, an `A100` or other GPUs supporting `BF16` precision should be used.
-- We have not yet systematically tested the minimum GPU memory requirements for each model. For `LORA(bs=1 with offload)`, a single `A100` GPU is sufficient. For `SFT`, our tests have passed in an `8×A100` environment.
+This repository has been tested in environments with 8×A100 GPUs, using CUDA 12.4 and Python 3.10.16.
diff --git a/docs/04-Finetune/01-Prerequisites.mdx b/docs/04-Finetune/01-Prerequisites.mdx
index f004cc9..2a32ceb 100644
--- a/docs/04-Finetune/01-Prerequisites.mdx
+++ b/docs/04-Finetune/01-Prerequisites.mdx
@@ -3,7 +3,7 @@
# Prerequisites
-Before starting fine-tuning, please ensure your machine meets the minimum hardware requirements listed in the tables below. The tables show the minimum VRAM (GPU memory) requirements for different models under various configurations.
+Before starting fine-tuning, please ensure your machine meets the minimum hardware requirements listed in the tables below. The tables show the minimum VRAM requirements for different models under various configurations (tested on 8×A100).
## CogVideo Series
@@ -11,101 +11,61 @@ Before starting fine-tuning, please ensure your machine meets the minimum hardwa
| Model |
- Training Type |
- Distribution Strategy |
- Training Resolution (FxHxW) |
+ Type |
+ Strategy |
+ Resolution (FxHxW) |
Requirement |
- | cogvideox-t2v-2b |
+ cogvideox-t2v-2b |
lora |
DDP |
49x480x720 |
- 16GB VRAM |
+ 1 GPU with 12GB VRAM |
- | sft |
+ sft |
DDP |
49x480x720 |
- 36GB VRAM |
+ 1 GPU with 25GB VRAM |
- | 1-GPU zero-2 + opt offload |
- 49x480x720 |
- 17GB VRAM |
-
-
- | 8-GPU zero-2 |
- 49x480x720 |
- 17GB VRAM |
-
-
- | 8-GPU zero-3 |
- 49x480x720 |
- 19GB VRAM |
-
-
- | 8-GPU zero-3 + opt and param offload |
- 49x480x720 |
- 14GB VRAM |
-
-
- | cogvideox-\{t2v,i2v\}-5b |
+ cogvideox-\{t2v,i2v\}-5b |
lora |
DDP |
49x480x720 |
- 24GB VRAM |
-
-
- | sft |
- 1-GPU zero-2 + opt offload |
- 49x480x720 |
- 42GB VRAM |
+ 1 GPU with 24GB VRAM |
- | 8-GPU zero-2 |
+ sft |
+ FSDP fullshard |
49x480x720 |
- 42GB VRAM |
+ 8 GPUs with 20GB VRAM |
- | 8-GPU zero-3 |
+ FSDP fullshard + offload |
49x480x720 |
- 43GB VRAM |
+ 1 GPU with 16GB VRAM |
- | 8-GPU zero-3 + opt and param offload |
- 49x480x720 |
- 28GB VRAM |
-
-
- | cogvideox1.5-\{t2v,i2v\}-5b |
+ cogvideox1.5-\{t2v,i2v\}-5b |
lora |
DDP |
81x768x1360 |
- 35GB VRAM |
-
-
- | sft |
- 1-GPU zero-2 + opt offload |
- 81x768x1360 |
- 56GB VRAM |
+ 1 GPU with 32GB VRAM |
- | 8-GPU zero-2 |
+ sft |
+ FSDP fullshard |
81x768x1360 |
- 55GB VRAM |
+ 8 GPUs with 31GB VRAM |
- | 8-GPU zero-3 |
+ FSDP fullshard + offload |
81x768x1360 |
- 55GB VRAM |
-
-
- | 8-GPU zero-3 + opt and param offload |
- 81x768x1360 |
- 40GB VRAM |
+ 8 GPUs with 27GB VRAM |
@@ -116,46 +76,36 @@ Before starting fine-tuning, please ensure your machine meets the minimum hardwa
| Model |
- Training Type |
- Distribution Strategy |
- Training Resolution (HxW) |
+ Type |
+ Strategy |
+ Resolution (HxW) |
Requirement |
- | CogView4-6B |
- qlora + param offload (`--low_vram`) |
+ CogView4-6B |
+ qlora + offload (enable `--low_vram`) |
DDP |
1024x1024 |
- 9GB VRAM |
+ 1 GPU with 9GB VRAM |
| lora |
DDP |
1024x1024 |
- 30GB VRAM |
-
-
- | sft |
- 1-GPU zero-2 + opt offload |
- 1024x1024 |
- 42GB VRAM |
-
-
- | 8-GPU zero-2 |
- 1024x1024 |
- 50GB VRAM |
+ 1 GPU with 20GB VRAM |
- | 8-GPU zero-3 |
+ sft |
+ FSDP fullshard |
1024x1024 |
- 47GB VRAM |
+ 8 GPUs with 28GB VRAM |
- | 8-GPU zero-3 + opt and param offload |
+ FSDP fullshard + offload |
1024x1024 |
- 28GB VRAM |
+ 8 GPUs with 22GB VRAM |
diff --git a/docs/04-Finetune/02-Quick Start.md b/docs/04-Finetune/02-Quick Start.md
index ea388a5..af7a79b 100644
--- a/docs/04-Finetune/02-Quick Start.md
+++ b/docs/04-Finetune/02-Quick Start.md
@@ -27,36 +27,35 @@ We recommend that you read the corresponding [model card](../05-Model%20Card.mdx
:::
1. Navigate to the `CogKit/` directory after cloning the repository
+
```bash
cd CogKit/
```
-2. Choose the appropriate training script from the `quickstart/scripts` directory based on your task type and distribution strategy. For example, `train_ddp_t2i.sh` corresponds to DDP strategy + text-to-image task
-
-3. Review and adjust the parameters in the selected training script (e.g., `--data_root`, `--output_dir`, etc.)
+2. Choose the appropriate subdirectory of `quickstart/scripts` based on your task type. For example, `t2i` corresponds to the text-to-image task
-4. [Optional] If you are using ZeRO strategy, refer to `quickstart/configs/accelerate_config.yaml` to confirm your ZeRO config file and number of GPUs.
+3. Review and adjust the parameters in `config.yaml` in the selected training directory
-5. Run the script, for example:
+4. Run the script in the selected directory:
```bash
- cd quickstart/scripts
- bash train_ddp_t2i.sh
+ bash start_train.sh
```
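For step 3, the fields most users need to change are the paths and the training mode. A minimal excerpt adapted from the `t2i/config.yaml` shipped in this change (the path values are placeholders you must replace):

```yaml
model_name: "cogview4-6b"
model_path: "THUDM/CogView4-6B"
output_dir: "/path/to/output"
data_root: "/path/to/t2i/data"
training_type: "lora"           # Options: ["lora", "sft"]
strategy: "DDP"                 # or an FSDP strategy such as "FULL_SHARD"
train_resolution: [1024, 1024]  # [Height, Width]; CogView4 requires multiples of 32
```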
## Load Fine-tuned Model
-### LoRA
-
-After fine-tuning with LoRA, you can load your trained weights during inference using the `--lora_model_id_or_path` option or parameter. For more details, please refer to the inference guide.
+### Merge Checkpoint
-### ZeRO
-
-After fine-tuning with ZeRO strategy, you need to use the `zero_to_fp32.py` script provided in the `quickstart/tools/converters` directory to convert the ZeRO checkpoint weights into Diffusers format. For example:
+After fine-tuning, you need to use the `merge.py` script to merge the distributed checkpoint weights into a single checkpoint (**except for QLoRA fine-tuning**).
+The script can be found in the `quickstart/tools/converters` directory.
+For example:
```bash
cd quickstart/tools/converters
-python zero2diffusers.py checkpoint_dir/ output_dir/ --bfloat16
+python merge.py --checkpoint_dir ckpt/ --output_dir output_dir/
+# Add --lora option if you are using LoRA fine-tuning
```
-During inference, pass the `output_dir/` to the `--transformer_path` option or parameter. For more details, please refer to the inference guide.
+### Load Checkpoint
+
+You can pass the `output_dir` to the `--lora_model_id_or_path` option if you are using LoRA fine-tuning, or to the `--transformer_path` option if you are using FSDP fine-tuning. For more details, please refer to the inference guide.
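+As a sketch, assuming a CLI entry point named `inference.py` (hypothetical; see the inference guide for the actual command), loading the merged weights might look like:
+
+```bash
+# LoRA fine-tuning: point the LoRA option at the merged output
+python inference.py --lora_model_id_or_path output_dir/
+
+# FSDP (sft) fine-tuning: load the merged transformer weights instead
+python inference.py --transformer_path output_dir/
+```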
diff --git a/pyproject.toml b/pyproject.toml
index 0ddad69..3d52425 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,7 +18,6 @@ dependencies = [
"pydantic~=2.10",
"sentencepiece==0.2.0",
"transformers~=4.49",
- "wandb~=0.19.8",
"fastapi[standard]~=0.115.11",
"fastapi_cli~=0.0.7",
"openai~=1.67",
@@ -31,10 +30,10 @@ dependencies = [
[project.optional-dependencies]
finetune = [
"datasets~=3.4",
- "deepspeed~=0.16.4",
+ "wandb~=0.19.8",
"av~=14.2.0",
"bitsandbytes~=0.45.4",
- "tensorboard~=2.19",
+ "pyyaml>=6.0.2",
]
[project.urls]
diff --git a/quickstart/configs/accelerate_config.yaml b/quickstart/configs/accelerate_config.yaml
deleted file mode 100644
index b6032b6..0000000
--- a/quickstart/configs/accelerate_config.yaml
+++ /dev/null
@@ -1,26 +0,0 @@
-compute_environment: LOCAL_MACHINE
-
-gpu_ids: "0,1,2,3,4,5,6,7"
-num_processes: 8 # should be the same as the number of GPUs
-
-# gpu_ids: "0"
-# num_processes: 1
-
-debug: false
-
-distributed_type: DEEPSPEED
-deepspeed_config:
- deepspeed_config_file: /path/to/configs/zero/zero2.yaml # e.g. need use absolute path
- zero3_init_flag: false
-
-downcast_bf16: 'no'
-enable_cpu_affinity: false
-machine_rank: 0
-main_training_function: main
-num_machines: 1
-rdzv_backend: static
-same_network: true
-tpu_env: []
-tpu_use_cluster: false
-tpu_use_sudo: false
-use_cpu: false
diff --git a/quickstart/configs/zero/zero2.yaml b/quickstart/configs/zero/zero2.yaml
deleted file mode 100644
index b056bd4..0000000
--- a/quickstart/configs/zero/zero2.yaml
+++ /dev/null
@@ -1,38 +0,0 @@
-{
- "bf16": {
- "enabled": true
- },
- "optimizer": {
- "type": "AdamW",
- "params": {
- "lr": "auto",
- "weight_decay": "auto",
- "torch_adam": true,
- "adam_w_mode": true
- }
- },
- "scheduler": {
- "type": "WarmupDecayLR",
- "params": {
- "warmup_min_lr": "auto",
- "warmup_max_lr": "auto",
- "warmup_num_steps": "auto",
- "total_num_steps": "auto"
- }
- },
- "zero_optimization": {
- "stage": 2,
- "allgather_partitions": true,
- "allgather_bucket_size": 2e8,
- "overlap_comm": true,
- "reduce_scatter": true,
- "reduce_bucket_size": 5e8,
- "contiguous_gradients": true
- },
- "gradient_accumulation_steps": 1,
- "train_micro_batch_size_per_gpu": 1,
- "train_batch_size": "auto",
- "gradient_clipping": "auto",
- "steps_per_print": 2000,
- "wall_clock_breakdown": false
-}
diff --git a/quickstart/configs/zero/zero2_offload.yaml b/quickstart/configs/zero/zero2_offload.yaml
deleted file mode 100644
index 24fdcb4..0000000
--- a/quickstart/configs/zero/zero2_offload.yaml
+++ /dev/null
@@ -1,42 +0,0 @@
-{
- "bf16": {
- "enabled": true
- },
- "optimizer": {
- "type": "AdamW",
- "params": {
- "lr": "auto",
- "weight_decay": "auto",
- "torch_adam": true,
- "adam_w_mode": true
- }
- },
- "scheduler": {
- "type": "WarmupDecayLR",
- "params": {
- "warmup_min_lr": "auto",
- "warmup_max_lr": "auto",
- "warmup_num_steps": "auto",
- "total_num_steps": "auto"
- }
- },
- "zero_optimization": {
- "stage": 2,
- "allgather_partitions": true,
- "allgather_bucket_size": 2e8,
- "overlap_comm": true,
- "reduce_scatter": true,
- "reduce_bucket_size": 5e8,
- "contiguous_gradients": true,
- "offload_optimizer": {
- "device": "cpu",
- "pin_memory": true
- }
- },
- "gradient_accumulation_steps": 1,
- "train_micro_batch_size_per_gpu": 1,
- "train_batch_size": "auto",
- "gradient_clipping": "auto",
- "steps_per_print": 2000,
- "wall_clock_breakdown": false
-}
diff --git a/quickstart/configs/zero/zero3.yaml b/quickstart/configs/zero/zero3.yaml
deleted file mode 100644
index 18685d0..0000000
--- a/quickstart/configs/zero/zero3.yaml
+++ /dev/null
@@ -1,41 +0,0 @@
-{
- "bf16": {
- "enabled": true
- },
- "optimizer": {
- "type": "AdamW",
- "params": {
- "lr": "auto",
- "weight_decay": "auto",
- "torch_adam": true,
- "adam_w_mode": true
- }
- },
- "scheduler": {
- "type": "WarmupDecayLR",
- "params": {
- "warmup_min_lr": "auto",
- "warmup_max_lr": "auto",
- "warmup_num_steps": "auto",
- "total_num_steps": "auto"
- }
- },
- "zero_optimization": {
- "stage": 3,
- "overlap_comm": true,
- "contiguous_gradients": true,
- "reduce_bucket_size": 5e8,
- "sub_group_size": 1e9,
- "stage3_max_live_parameters": 1e9,
- "stage3_max_reuse_distance": 1e9,
- "stage3_gather_16bit_weights_on_model_save": "auto",
- "stage3_prefetch_bucket_size": 5e8,
- "stage3_param_persistence_threshold": 1e5
- },
- "gradient_accumulation_steps": 1,
- "train_micro_batch_size_per_gpu": 1,
- "train_batch_size": "auto",
- "gradient_clipping": "auto",
- "steps_per_print": 2000,
- "wall_clock_breakdown": false
-}
diff --git a/quickstart/configs/zero/zero3_offload.yaml b/quickstart/configs/zero/zero3_offload.yaml
deleted file mode 100644
index e780e2f..0000000
--- a/quickstart/configs/zero/zero3_offload.yaml
+++ /dev/null
@@ -1,49 +0,0 @@
-{
- "bf16": {
- "enabled": true
- },
- "optimizer": {
- "type": "AdamW",
- "params": {
- "lr": "auto",
- "weight_decay": "auto",
- "torch_adam": true,
- "adam_w_mode": true
- }
- },
- "scheduler": {
- "type": "WarmupDecayLR",
- "params": {
- "warmup_min_lr": "auto",
- "warmup_max_lr": "auto",
- "warmup_num_steps": "auto",
- "total_num_steps": "auto"
- }
- },
- "zero_optimization": {
- "stage": 3,
- "offload_optimizer": {
- "device": "cpu",
- "pin_memory": true
- },
- "offload_param": {
- "device": "cpu",
- "pin_memory": true
- },
- "overlap_comm": true,
- "contiguous_gradients": true,
- "reduce_bucket_size": 5e8,
- "sub_group_size": 1e9,
- "stage3_max_live_parameters": 1e9,
- "stage3_max_reuse_distance": 1e9,
- "stage3_gather_16bit_weights_on_model_save": "auto",
- "stage3_prefetch_bucket_size": 5e8,
- "stage3_param_persistence_threshold": 1e6
- },
- "gradient_accumulation_steps": 1,
- "train_micro_batch_size_per_gpu": 1,
- "train_batch_size": "auto",
- "gradient_clipping": "auto",
- "steps_per_print": 2000,
- "wall_clock_breakdown": false
-}
diff --git a/quickstart/scripts/i2v/config.yaml b/quickstart/scripts/i2v/config.yaml
new file mode 100644
index 0000000..43e1afa
--- /dev/null
+++ b/quickstart/scripts/i2v/config.yaml
@@ -0,0 +1,65 @@
+# ================ Logging ================
+name4train: "i2v-train"
+log_level: "INFO" # Options: ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
+
+# ================ Model ================
+model_name: "cogvideox1.5-i2v" # Options: ["cogvideox-i2v", "cogvideox1.5-i2v"]
+model_path: "THUDM/CogVideoX1.5-5B-I2V"
+
+
+# ================ Output ================
+output_dir: "/path/to/output"
+
+
+# ================ Tracker ================
+report_to: null # Options: ["wandb"]
+
+
+# ================ Data ================
+data_root: "/path/to/i2v/data"
+
+
+# ================ Training ================
+seed: 42
+training_type: "lora" # Options: ["lora", "sft"]
+
+strategy: "DDP" # Options: ["DDP", "SHARD_GRAD_OP", "FULL_SHARD", "HYBRID_SHARD", "_HYBRID_SHARD_ZERO2"]
+
+# This will offload model params and grads to CPU memory to save GPU memory, but will slow down training
+offload_params_grads: false
+
+# Enabling this will increase memory usage, since gradients are normally sharded during accumulation steps.
+# Note: when used with offload_params_grads, model parameters and gradients will only be offloaded
+# to the CPU during the final synchronization (they are still kept on the GPU during gradient accumulation steps),
+# which means offload_params_grads is effectively meaningless when combined with no_grad_sync_when_accumulating
+no_grad_sync_when_accumulating: false
+
+# When enable_packing is true, training will use the native image resolution,
+# otherwise all images will be resized to train_resolution, which may distort the original aspect ratio.
+# IMPORTANT: When changing enable_packing from true to false (or false to true),
+# make sure to clear the `.cache` directories in your `data_root/train` and `data_root/test` folders if they exist.
+enable_packing: false
+
+# Note:
+# for CogVideoX series models, number of training frames should be **8N+1**
+# for CogVideoX1.5 series models, number of training frames should be **16N+1**
+train_resolution: [81, 768, 1360] # [Frames, Height, Width]
+
+train_epochs: 1
+batch_size: 1
+gradient_accumulation_steps: 1
+mixed_precision: "bf16" # Options: ["fp32", "fp16", "bf16"]
+learning_rate: 2.0e-5
+
+num_workers: 8
+pin_memory: true
+
+checkpointing_steps: 10
+checkpointing_limit: 2
+resume_from_checkpoint: null # or "/path/to/checkpoint/dir"
+
+
+# ================ Validation ================
+do_validation: true
+validation_steps: 10 # Must be a multiple of `checkpointing_steps`
+gen_fps: 16
diff --git a/quickstart/scripts/i2v/start_train.sh b/quickstart/scripts/i2v/start_train.sh
new file mode 100644
index 0000000..0dc933f
--- /dev/null
+++ b/quickstart/scripts/i2v/start_train.sh
@@ -0,0 +1,7 @@
+#!/usr/bin/env bash
+
+torchrun \
+ --nproc_per_node=[number of GPUs] \
+ --master_port=29501 \
+ ../train.py \
+ --yaml config.yaml
diff --git a/quickstart/scripts/t2i/config.yaml b/quickstart/scripts/t2i/config.yaml
new file mode 100644
index 0000000..42210ac
--- /dev/null
+++ b/quickstart/scripts/t2i/config.yaml
@@ -0,0 +1,64 @@
+# ================ Logging ================
+name4train: "t2i-train"
+log_level: "INFO" # Options: ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
+
+# ================ Model ================
+model_name: "cogview4-6b" # Options: ["cogview4-6b"]
+model_path: "THUDM/CogView4-6B"
+
+
+# ================ Output ================
+output_dir: "/path/to/output"
+
+
+# ================ Tracker ================
+report_to: null # Options: ["wandb"]
+
+
+# ================ Data ================
+data_root: "/path/to/t2i/data"
+
+# ================ Training ================
+seed: 42
+training_type: "lora" # Options: ["lora", "sft"]
+
+strategy: "DDP" # Options: ["DDP", "SHARD_GRAD_OP", "FULL_SHARD", "HYBRID_SHARD", "_HYBRID_SHARD_ZERO2"]
+
+# This will offload model params and grads to CPU memory to save GPU memory, but will slow down training
+offload_params_grads: false
+
+# Enabling this will increase memory usage, since gradients are normally sharded during accumulation steps.
+# Note: when used with offload_params_grads, model parameters and gradients will only be offloaded
+# to the CPU during the final synchronization (they are still kept on the GPU during gradient accumulation steps),
+# which means offload_params_grads is effectively meaningless when combined with no_grad_sync_when_accumulating
+no_grad_sync_when_accumulating: false
+
+# When enable_packing is true, training will use the native image resolution,
+# otherwise all images will be resized to train_resolution, which may distort the original aspect ratio.
+# IMPORTANT: When changing enable_packing from true to false (or false to true),
+# make sure to clear the `.cache` directories in your `data_root/train` and `data_root/test` folders if they exist.
+enable_packing: false
+
+# This will slow down validation speed and enable quantization during training to save GPU memory
+low_vram: false
+
+# Note: For CogView4 series models, height and width should be **32N** (multiple of 32)
+train_resolution: [1024, 1024] # [Height, Width]
+
+train_epochs: 1
+batch_size: 1
+gradient_accumulation_steps: 1
+mixed_precision: "bf16" # Options: ["fp32", "fp16", "bf16"]
+learning_rate: 2.0e-5
+
+num_workers: 8
+pin_memory: true
+
+checkpointing_steps: 10
+checkpointing_limit: 2
+resume_from_checkpoint: null # or "/path/to/checkpoint/dir"
+
+
+# ================ Validation ================
+do_validation: true
+validation_steps: 10 # Must be a multiple of `checkpointing_steps`
diff --git a/quickstart/scripts/t2i/start_train.sh b/quickstart/scripts/t2i/start_train.sh
new file mode 100644
index 0000000..0dc933f
--- /dev/null
+++ b/quickstart/scripts/t2i/start_train.sh
@@ -0,0 +1,7 @@
+#!/usr/bin/env bash
+
+torchrun \
+ --nproc_per_node=[number of GPUs] \
+ --master_port=29501 \
+ ../train.py \
+ --yaml config.yaml
diff --git a/quickstart/scripts/t2v/config.yaml b/quickstart/scripts/t2v/config.yaml
new file mode 100644
index 0000000..4156f69
--- /dev/null
+++ b/quickstart/scripts/t2v/config.yaml
@@ -0,0 +1,64 @@
+# ================ Logging ================
+name4train: "t2v-train"
+log_level: "INFO" # Options: ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
+
+# ================ Model ================
+model_name: "cogvideox1.5-t2v" # Options: ["cogvideox-t2v", "cogvideox1.5-t2v"]
+model_path: "THUDM/CogVideoX1.5-5B"
+
+# ================ Output ================
+output_dir: "/path/to/output"
+
+
+# ================ Tracker ================
+report_to: null # Options: ["wandb"]
+
+
+# ================ Data ================
+data_root: "/path/to/t2v/data"
+
+
+# ================ Training ================
+seed: 42
+training_type: "lora" # Options: ["lora", "sft"]
+
+strategy: "DDP" # Options: ["DDP", "SHARD_GRAD_OP", "FULL_SHARD", "HYBRID_SHARD", "_HYBRID_SHARD_ZERO2"]
+
+# This will offload model params and grads to CPU memory to save GPU memory, but will slow down training
+offload_params_grads: false
+
+# Enabling this will increase memory usage, since gradients are normally sharded during accumulation steps.
+# Note: when used with offload_params_grads, model parameters and gradients will only be offloaded
+# to the CPU during the final synchronization (they are still kept on the GPU during gradient accumulation steps),
+# which means offload_params_grads is effectively meaningless when combined with no_grad_sync_when_accumulating
+no_grad_sync_when_accumulating: false
+
+# When enable_packing is true, training will use the native image resolution,
+# otherwise all images will be resized to train_resolution, which may distort the original aspect ratio.
+# IMPORTANT: When changing enable_packing from true to false (or false to true),
+# make sure to clear the `.cache` directories in your `data_root/train` and `data_root/test` folders if they exist.
+enable_packing: false
+
+# Note:
+# for CogVideoX series models, number of training frames should be **8N+1**
+# for CogVideoX1.5 series models, number of training frames should be **16N+1**
+train_resolution: [81, 768, 1360] # [Frames, Height, Width]
+
+train_epochs: 1
+batch_size: 1
+gradient_accumulation_steps: 1
+mixed_precision: "bf16" # Options: ["fp32", "fp16", "bf16"]
+learning_rate: 2.0e-5
+
+num_workers: 8
+pin_memory: true
+
+checkpointing_steps: 10
+checkpointing_limit: 2
+resume_from_checkpoint: null # or "/path/to/checkpoint/dir"
+
+
+# ================ Validation ================
+do_validation: true
+validation_steps: 10 # Must be a multiple of `checkpointing_steps`
+gen_fps: 16
diff --git a/quickstart/scripts/t2v/start_train.sh b/quickstart/scripts/t2v/start_train.sh
new file mode 100644
index 0000000..0dc933f
--- /dev/null
+++ b/quickstart/scripts/t2v/start_train.sh
@@ -0,0 +1,7 @@
+#!/usr/bin/env bash
+
+torchrun \
+ --nproc_per_node=[number of GPUs] \
+ --master_port=29501 \
+ ../train.py \
+ --yaml config.yaml
diff --git a/quickstart/scripts/train.py b/quickstart/scripts/train.py
index e899273..d9e901a 100644
--- a/quickstart/scripts/train.py
+++ b/quickstart/scripts/train.py
@@ -1,17 +1,21 @@
import argparse
+import yaml
from cogkit.finetune import get_model_cls
def main():
parser = argparse.ArgumentParser()
- parser.add_argument("--model_name", type=str, required=True)
- parser.add_argument("--training_type", type=str, required=True)
- parser.add_argument("--enable_packing", type=lambda x: x.lower() == "true")
- args, unknown = parser.parse_known_args()
+ parser.add_argument("--yaml", type=str, required=True)
+ args = parser.parse_args()
- trainer_cls = get_model_cls(args.model_name, args.training_type, args.enable_packing)
- trainer = trainer_cls()
+ with open(args.yaml, "r") as f:
+ config = yaml.safe_load(f)
+
+ trainer_cls = get_model_cls(
+ config["model_name"], config["training_type"], config["enable_packing"]
+ )
+ trainer = trainer_cls(args.yaml)
trainer.fit()
diff --git a/quickstart/scripts/train_ddp_i2v.sh b/quickstart/scripts/train_ddp_i2v.sh
deleted file mode 100755
index 57824eb..0000000
--- a/quickstart/scripts/train_ddp_i2v.sh
+++ /dev/null
@@ -1,70 +0,0 @@
-#!/usr/bin/env bash
-# Run by `bash scripts/train_ddp_i2v.sh`
-
-# Prevent tokenizer parallelism issues
-export TOKENIZERS_PARALLELISM=false
-
-# Model Configuration
-MODEL_ARGS=(
- --model_path "THUDM/CogVideoX1.5-5B-I2V"
- --model_name "cogvideox1.5-i2v" # candidate: ["cogvideox-i2v", "cogvideox1.5-i2v"]
- --model_type "i2v"
- --training_type "lora"
-)
-
-# Output Configuration
-OUTPUT_ARGS=(
- --output_dir "/path/to/output"
- --report_to "tensorboard"
-)
-
-# Data Configuration
-DATA_ARGS=(
- --data_root "/path/to/data"
-)
-
-# Training Configuration
-TRAIN_ARGS=(
- --seed 42 # random seed
- --train_epochs 1 # number of training epochs
- --batch_size 1
- --gradient_accumulation_steps 1
- --mixed_precision "bf16" # ["no", "fp16"]
- --learning_rate 5e-5
-
- # Note:
- # for CogVideoX series models, number of training frames should be **8N+1**
- # for CogVideoX1.5 series models, number of training frames should be **16N+1**
- --train_resolution "81x768x1360" # (frames x height x width)
-)
-
-# System Configuration
-SYSTEM_ARGS=(
- --num_workers 8
- --pin_memory true
- --nccl_timeout 1800
-)
-
-# Checkpointing Configuration
-CHECKPOINT_ARGS=(
- --checkpointing_steps 10 # save checkpoint every x steps
- --checkpointing_limit 2 # maximum number of checkpoints to keep, after which the oldest one is deleted
- # --resume_from_checkpoint "/absolute/path/to/checkpoint_dir" # if you want to resume from a checkpoint
-)
-
-# Validation Configuration
-VALIDATION_ARGS=(
- --do_validation true # ["true", "false"]
- --validation_steps 10 # should be multiple of checkpointing_steps
- --gen_fps 16
-)
-
-# Combine all arguments and launch training
-accelerate launch train.py \
- "${MODEL_ARGS[@]}" \
- "${OUTPUT_ARGS[@]}" \
- "${DATA_ARGS[@]}" \
- "${TRAIN_ARGS[@]}" \
- "${SYSTEM_ARGS[@]}" \
- "${CHECKPOINT_ARGS[@]}" \
- "${VALIDATION_ARGS[@]}"
diff --git a/quickstart/scripts/train_ddp_t2i.sh b/quickstart/scripts/train_ddp_t2i.sh
deleted file mode 100755
index 6aae45f..0000000
--- a/quickstart/scripts/train_ddp_t2i.sh
+++ /dev/null
@@ -1,80 +0,0 @@
-#!/usr/bin/env bash
-# Run by `bash scripts/train_ddp_i2v.sh`
-
-# Prevent tokenizer parallelism issues
-export TOKENIZERS_PARALLELISM=false
-
-# Model Configuration
-MODEL_ARGS=(
- --model_path "THUDM/CogView4-6B"
- --model_name "cogview4-6b" # candidate: ["cogview4-6b"]
- --model_type "t2i"
- --training_type "lora"
-)
-
-# Output Configuration
-OUTPUT_ARGS=(
- --output_dir "/path/to/output"
- --report_to "tensorboard"
-)
-
-# Data Configuration
-DATA_ARGS=(
- --data_root "/path/to/data"
-)
-
-# Training Configuration
-TRAIN_ARGS=(
- --seed 42 # random seed
- --train_epochs 1 # number of training epochs
- --batch_size 1
-
- --gradient_accumulation_steps 1
-
- # Note: For CogView4 series models, height and width should be **32N** (multiple of 32)
- --train_resolution "1024x1024" # (height x width)
-
- # When enable_packing is true, training will use the native image resolution
- # (otherwise all images will be resized to train_resolution, which may distort the original aspect ratio).
- #
- # IMPORTANT: When changing enable_packing from true to false (or vice versa),
- # make sure to clear the .cache directories in your data_root/train and data_root/test folders if they exist.
- --enable_packing false
-
- --mixed_precision "bf16" # ["no", "fp16"]
- --learning_rate 5e-5
-
- # enable --low_vram will slow down validation speed and enable quantization during training
- # Note: --low_vram currently does not support multi-GPU training
- --low_vram false
-)
-
-# System Configuration
-SYSTEM_ARGS=(
- --num_workers 8
- --pin_memory true
- --nccl_timeout 1800
-)
-
-# Checkpointing Configuration
-CHECKPOINT_ARGS=(
- --checkpointing_steps 10 # save checkpoint every x steps
- --checkpointing_limit 2 # maximum number of checkpoints to keep, after which the oldest one is deleted
- # --resume_from_checkpoint "/absolute/path/to/checkpoint_dir" # if you want to resume from a checkpoint
-)
-
-# Validation Configuration
-VALIDATION_ARGS=(
- --do_validation true # ["true", "false"]
- --validation_steps 10 # should be multiple of checkpointing_steps
-)
-
-# Combine all arguments and launch training
-accelerate launch train.py \
- "${MODEL_ARGS[@]}" \
- "${OUTPUT_ARGS[@]}" \
- "${DATA_ARGS[@]}" \
- "${TRAIN_ARGS[@]}" \
- "${SYSTEM_ARGS[@]}" \
- "${CHECKPOINT_ARGS[@]}" \
- "${VALIDATION_ARGS[@]}"
diff --git a/quickstart/scripts/train_ddp_t2v.sh b/quickstart/scripts/train_ddp_t2v.sh
deleted file mode 100755
index ca31e49..0000000
--- a/quickstart/scripts/train_ddp_t2v.sh
+++ /dev/null
@@ -1,69 +0,0 @@
-#!/usr/bin/env bash
-
-# Prevent tokenizer parallelism issues
-export TOKENIZERS_PARALLELISM=false
-
-# Model Configuration
-MODEL_ARGS=(
- --model_path "THUDM/CogVideoX1.5-5B"
- --model_name "cogvideox1.5-t2v" # candidate: ["cogvideox-t2v", "cogvideox1.5-t2v"]
- --model_type "t2v"
- --training_type "lora"
-)
-
-# Output Configuration
-OUTPUT_ARGS=(
- --output_dir "/path/to/output"
- --report_to "tensorboard"
-)
-
-# Data Configuration
-DATA_ARGS=(
- --data_root "/path/to/data"
-)
-
-# Training Configuration
-TRAIN_ARGS=(
- --seed 42 # random seed
- --train_epochs 1 # number of training epochs
- --batch_size 1
- --gradient_accumulation_steps 1
- --mixed_precision "bf16" # ["no", "fp16"] Note: CogVideoX-2B only supports fp16 training
- --learning_rate 5e-5
-
- # Note:
- # for CogVideoX series models, number of training frames should be **8N+1**
- # for CogVideoX1.5 series models, number of training frames should be **16N+1**
- --train_resolution "81x768x1360" # (frames x height x width)
-)
-
-# System Configuration
-SYSTEM_ARGS=(
- --num_workers 8
- --pin_memory true
- --nccl_timeout 1800
-)
-
-# Checkpointing Configuration
-CHECKPOINT_ARGS=(
- --checkpointing_steps 10 # save checkpoint every x steps
- --checkpointing_limit 2 # maximum number of checkpoints to keep, after which the oldest one is deleted
- # --resume_from_checkpoint "/absolute/path/to/checkpoint_dir" # if you want to resume from a checkpoint
-)
-
-# Validation Configuration
-VALIDATION_ARGS=(
- --do_validation true # ["true", "false"]
- --validation_steps 10 # should be multiple of checkpointing_steps
- --gen_fps 16
-)
-
-# Combine all arguments and launch training
-accelerate launch train.py \
- "${MODEL_ARGS[@]}" \
- "${OUTPUT_ARGS[@]}" \
- "${DATA_ARGS[@]}" \
- "${TRAIN_ARGS[@]}" \
- "${SYSTEM_ARGS[@]}" \
- "${CHECKPOINT_ARGS[@]}" \
- "${VALIDATION_ARGS[@]}"
diff --git a/quickstart/scripts/train_zero_i2v.sh b/quickstart/scripts/train_zero_i2v.sh
deleted file mode 100755
index bfa07b7..0000000
--- a/quickstart/scripts/train_zero_i2v.sh
+++ /dev/null
@@ -1,74 +0,0 @@
-#!/usr/bin/env bash
-
-# Prevent tokenizer parallelism issues
-export TOKENIZERS_PARALLELISM=false
-
-# Model Configuration
-MODEL_ARGS=(
- --model_path "THUDM/CogVideoX1.5-5B-I2V"
- --model_name "cogvideox1.5-i2v" # candidate: ["cogvideox-i2v", "cogvideox1.5-i2v"]
- --model_type "i2v"
- --training_type "sft"
-)
-
-# Output Configuration
-OUTPUT_ARGS=(
- --output_dir "/path/to/output"
- --report_to "tensorboard"
-)
-
-# Data Configuration
-DATA_ARGS=(
- --data_root "/path/to/data"
-)
-
-# Training Configuration
-TRAIN_ARGS=(
- --seed 42 # random seed
- --train_epochs 1 # number of training epochs
-
- --learning_rate 5e-5
-
- ######### Please keep consistent with deepspeed config file ##########
- --batch_size 1
- --gradient_accumulation_steps 1
- --mixed_precision "bf16" # ["no", "fp16"] Note: CogVideoX-2B only supports fp16 training
- ########################################################################
-
- # Note:
- # for CogVideoX series models, number of training frames should be **8N+1**
- # for CogVideoX1.5 series models, number of training frames should be **16N+1**
- --train_resolution "81x768x1360" # (frames x height x width)
-
-)
-
-# System Configuration
-SYSTEM_ARGS=(
- --num_workers 8
- --pin_memory true
- --nccl_timeout 1800
-)
-
-# Checkpointing Configuration
-CHECKPOINT_ARGS=(
- --checkpointing_steps 10 # save checkpoint every x steps
- --checkpointing_limit 2 # maximum number of checkpoints to keep, after which the oldest one is deleted
- # --resume_from_checkpoint "/absolute/path/to/checkpoint_dir" # if you want to resume from a checkpoint
-)
-
-# Validation Configuration
-VALIDATION_ARGS=(
- --do_validation true # ["true", "false"]
- --validation_steps 10 # should be multiple of checkpointing_steps
- --gen_fps 16
-)
-
-# Combine all arguments and launch training
-accelerate launch --config_file ../configs/accelerate_config.yaml train.py \
- "${MODEL_ARGS[@]}" \
- "${OUTPUT_ARGS[@]}" \
- "${DATA_ARGS[@]}" \
- "${TRAIN_ARGS[@]}" \
- "${SYSTEM_ARGS[@]}" \
- "${CHECKPOINT_ARGS[@]}" \
- "${VALIDATION_ARGS[@]}"
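The frame-count rule in the comments of the script above (8N+1 frames for CogVideoX, 16N+1 for CogVideoX1.5) can be sketched as a quick sanity check. The helper name and the `"FxHxW"` parsing below are illustrative, not part of the removed script:

```python
def valid_frame_count(train_resolution: str, model_generation: str) -> bool:
    """Check the frames component of a "frames x height x width" string.

    CogVideoX expects 8N+1 frames; CogVideoX1.5 expects 16N+1 frames.
    """
    frames = int(train_resolution.split("x")[0])
    step = 16 if model_generation == "cogvideox1.5" else 8
    return frames % step == 1

# 81 = 16*5 + 1, so the default "81x768x1360" satisfies the 16N+1 rule
print(valid_frame_count("81x768x1360", "cogvideox1.5"))  # True
print(valid_frame_count("49x480x720", "cogvideox"))      # True (49 = 8*6 + 1)
print(valid_frame_count("50x480x720", "cogvideox"))      # False
```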
diff --git a/quickstart/scripts/train_zero_t2i.sh b/quickstart/scripts/train_zero_t2i.sh
deleted file mode 100755
index 878cd27..0000000
--- a/quickstart/scripts/train_zero_t2i.sh
+++ /dev/null
@@ -1,78 +0,0 @@
-#!/usr/bin/env bash
-
-# Prevent tokenizer parallelism issues
-export TOKENIZERS_PARALLELISM=false
-
-# Model Configuration
-MODEL_ARGS=(
- --model_path "THUDM/CogView4-6B"
- --model_name "cogview4-6b" # candidate: ["cogview4-6b"]
- --model_type "t2i"
- --training_type "sft"
-)
-
-# Output Configuration
-OUTPUT_ARGS=(
- --output_dir "/path/to/output"
- --report_to "tensorboard"
-)
-
-# Data Configuration
-DATA_ARGS=(
- --data_root "/path/to/data"
-)
-
-# Training Configuration
-TRAIN_ARGS=(
- --seed 42 # random seed
- --train_epochs 1 # number of training epochs
-
- --learning_rate 5e-5
-
- # Note: For CogView4 series models, height and width should be **32N** (multiple of 32)
- --train_resolution "1024x1024" # (height x width)
-
- ######### Please keep consistent with deepspeed config file ##########
- --batch_size 1
- --gradient_accumulation_steps 1
-    --mixed_precision "bf16"  # ["no", "fp16", "bf16"] Note: CogVideoX-2B only supports fp16 training
- ########################################################################
-
- # When enable_packing is true, training will use the native image resolution
- # (otherwise all images will be resized to train_resolution, which may distort the original aspect ratio).
- #
- # IMPORTANT: When changing enable_packing from true to false (or vice versa),
- # make sure to clear the .cache directories in your data_root/train and data_root/test folders if they exist.
- --enable_packing false
-
-)
-
-# System Configuration
-SYSTEM_ARGS=(
- --num_workers 8
- --pin_memory true
- --nccl_timeout 1800
-)
-
-# Checkpointing Configuration
-CHECKPOINT_ARGS=(
- --checkpointing_steps 10 # save checkpoint every x steps
- --checkpointing_limit 2 # maximum number of checkpoints to keep, after which the oldest one is deleted
- # --resume_from_checkpoint "/absolute/path/to/checkpoint_dir" # if you want to resume from a checkpoint
-)
-
-# Validation Configuration
-VALIDATION_ARGS=(
- --do_validation true # ["true", "false"]
- --validation_steps 10 # should be multiple of checkpointing_steps
-)
-
-# Combine all arguments and launch training
-accelerate launch --config_file ../configs/accelerate_config.yaml train.py \
- "${MODEL_ARGS[@]}" \
- "${OUTPUT_ARGS[@]}" \
- "${DATA_ARGS[@]}" \
- "${TRAIN_ARGS[@]}" \
- "${SYSTEM_ARGS[@]}" \
- "${CHECKPOINT_ARGS[@]}" \
- "${VALIDATION_ARGS[@]}"
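The 32-multiple constraint noted in the CogView4 script above can be checked the same way. A small illustrative helper (the function name is an assumption, not project code):

```python
def valid_image_resolution(train_resolution: str) -> bool:
    """CogView4 expects height and width to both be multiples of 32."""
    height, width = (int(v) for v in train_resolution.split("x"))
    return height % 32 == 0 and width % 32 == 0

print(valid_image_resolution("1024x1024"))  # True
print(valid_image_resolution("1000x1024"))  # False: 1000 is not a multiple of 32
```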
diff --git a/quickstart/scripts/train_zero_t2v.sh b/quickstart/scripts/train_zero_t2v.sh
deleted file mode 100755
index 516afc2..0000000
--- a/quickstart/scripts/train_zero_t2v.sh
+++ /dev/null
@@ -1,73 +0,0 @@
-#!/usr/bin/env bash
-
-# Prevent tokenizer parallelism issues
-export TOKENIZERS_PARALLELISM=false
-
-# Model Configuration
-MODEL_ARGS=(
- --model_path "THUDM/CogVideoX1.5-5B"
- --model_name "cogvideox1.5-t2v" # candidate: ["cogvideox-t2v", "cogvideox1.5-t2v"]
- --model_type "t2v"
- --training_type "sft"
-)
-
-# Output Configuration
-OUTPUT_ARGS=(
- --output_dir "/path/to/output"
- --report_to "tensorboard"
-)
-
-# Data Configuration
-DATA_ARGS=(
- --data_root "/path/to/data"
-)
-
-# Training Configuration
-TRAIN_ARGS=(
- --seed 42 # random seed
- --train_epochs 1 # number of training epochs
-
- --learning_rate 5e-5
-
- ######### Please keep consistent with deepspeed config file ##########
- --batch_size 1
- --gradient_accumulation_steps 1
-    --mixed_precision "bf16"  # ["no", "fp16", "bf16"] Note: CogVideoX-2B only supports fp16 training
- ########################################################################
-
- # Note:
- # for CogVideoX series models, number of training frames should be **8N+1**
- # for CogVideoX1.5 series models, number of training frames should be **16N+1**
- --train_resolution "81x768x1360" # (frames x height x width)
-)
-
-# System Configuration
-SYSTEM_ARGS=(
- --num_workers 8
- --pin_memory true
- --nccl_timeout 1800
-)
-
-# Checkpointing Configuration
-CHECKPOINT_ARGS=(
- --checkpointing_steps 10 # save checkpoint every x steps
- --checkpointing_limit 2 # maximum number of checkpoints to keep, after which the oldest one is deleted
- # --resume_from_checkpoint "/absolute/path/to/checkpoint_dir" # if you want to resume from a checkpoint
-)
-
-# Validation Configuration
-VALIDATION_ARGS=(
- --do_validation true # ["true", "false"]
- --validation_steps 10 # should be multiple of checkpointing_steps
- --gen_fps 16
-)
-
-# Combine all arguments and launch training
-accelerate launch --config_file ../configs/accelerate_config.yaml train.py \
- "${MODEL_ARGS[@]}" \
- "${OUTPUT_ARGS[@]}" \
- "${DATA_ARGS[@]}" \
- "${TRAIN_ARGS[@]}" \
- "${SYSTEM_ARGS[@]}" \
- "${CHECKPOINT_ARGS[@]}" \
- "${VALIDATION_ARGS[@]}"
diff --git a/src/cogkit/finetune/__init__.py b/src/cogkit/finetune/__init__.py
index f7878a6..bd92860 100644
--- a/src/cogkit/finetune/__init__.py
+++ b/src/cogkit/finetune/__init__.py
@@ -1,20 +1,21 @@
# -*- coding: utf-8 -*-
-
-from cogkit.finetune.base import BaseTrainer
-
# import register first
-from cogkit.finetune.register import get_model_cls, register, show_supported_models # noqa
+from ._register import get_model_cls, register, show_supported_models # noqa
+
+from .base import BaseTrainer
# import registered models
-from cogkit.finetune.diffusion import models as diffusion_models
-from cogkit.finetune.llm import models as llm_models
+from .diffusion import models as diffusion_models
+from .llm import models as llm_models
+from .logger import get_logger
__all__ = [
"BaseTrainer",
"diffusion_models",
"llm_models",
+ "get_logger",
"get_model_cls",
"register",
"show_supported_models",
diff --git a/src/cogkit/finetune/register.py b/src/cogkit/finetune/_register.py
similarity index 77%
rename from src/cogkit/finetune/register.py
rename to src/cogkit/finetune/_register.py
index c0efe51..1e5a18c 100644
--- a/src/cogkit/finetune/register.py
+++ b/src/cogkit/finetune/_register.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
-from typing import Literal, TYPE_CHECKING
+from typing import TYPE_CHECKING, Literal
# using TYPE_CHECKING to avoid circular import
if TYPE_CHECKING:
@@ -46,26 +46,26 @@ def show_supported_models():
def get_model_cls(
- model_type: str, training_type: Literal["lora", "sft"], use_packing: bool = False
+ model_name: str, training_type: Literal["lora", "sft"], use_packing: bool = False
) -> "BaseTrainer":
"""Get the trainer class for a specific model and training type."""
- if model_type not in SUPPORTED_MODELS:
- print(f"\nModel '{model_type}' is not supported.")
+ if model_name not in SUPPORTED_MODELS:
+ print(f"\nModel '{model_name}' is not supported.")
print("\nSupported models are:")
for supported_model in SUPPORTED_MODELS:
print(f" • {supported_model}")
- raise ValueError(f"Model '{model_type}' is not supported")
+ raise ValueError(f"Model '{model_name}' is not supported")
if use_packing:
training_type = f"{training_type}-packing"
- if training_type not in SUPPORTED_MODELS[model_type]:
- print(f"\nTraining type '{training_type}' is not supported for model '{model_type}'.")
- print(f"\nSupported training types for '{model_type}' are:")
- for supported_type in SUPPORTED_MODELS[model_type]:
+ if training_type not in SUPPORTED_MODELS[model_name]:
+ print(f"\nTraining type '{training_type}' is not supported for model '{model_name}'.")
+ print(f"\nSupported training types for '{model_name}' are:")
+ for supported_type in SUPPORTED_MODELS[model_name]:
print(f" • {supported_type}")
raise ValueError(
- f"Training type '{training_type}' is not supported for model '{model_type}'"
+ f"Training type '{training_type}' is not supported for model '{model_name}'"
)
- return SUPPORTED_MODELS[model_type][training_type]
+ return SUPPORTED_MODELS[model_name][training_type]
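The renamed `get_model_cls` boils down to a two-level dictionary lookup keyed by model name and training type, with an optional `-packing` suffix. A minimal stdlib sketch of the same pattern (the registry contents here are placeholder strings, not the real trainer classes):

```python
# Placeholder registry; the real one is populated via the @register decorator
SUPPORTED_MODELS = {
    "cogview4-6b": {"sft": "CogView4Trainer", "lora": "CogView4LoraTrainer"},
}

def get_model_cls(model_name: str, training_type: str, use_packing: bool = False):
    if model_name not in SUPPORTED_MODELS:
        raise ValueError(f"Model '{model_name}' is not supported")
    if use_packing:
        training_type = f"{training_type}-packing"
    trainers = SUPPORTED_MODELS[model_name]
    if training_type not in trainers:
        raise ValueError(
            f"Training type '{training_type}' is not supported for model '{model_name}'"
        )
    return trainers[training_type]

print(get_model_cls("cogview4-6b", "sft"))  # CogView4Trainer
```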
diff --git a/src/cogkit/finetune/base/base_args.py b/src/cogkit/finetune/base/base_args.py
index 07e048d..d66f170 100644
--- a/src/cogkit/finetune/base/base_args.py
+++ b/src/cogkit/finetune/base/base_args.py
@@ -1,77 +1,143 @@
# -*- coding: utf-8 -*-
-
-import argparse
import datetime
import logging
+from datetime import timedelta
from pathlib import Path
from typing import Literal
+import yaml
from pydantic import BaseModel, ValidationInfo, field_validator
class BaseArgs(BaseModel):
+ model_config = {"frozen": True, "extra": "ignore"}
+
+ ########## Logging ##########
+ name4train: str
+ log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
+
########## Model ##########
model_path: Path
model_name: str
- training_type: Literal["lora", "sft"] = "lora"
########## Output ##########
output_dir: Path = Path(f"train_result/{datetime.datetime.now():%Y-%m-%d-%H-%M-%S}")
- report_to: Literal["tensorboard", "wandb", "all"] | None = None
- tracker_name: str = "base-tracker"
+
+ ########## Tracker ##########
+ report_to: Literal["wandb"] | None = None
########## Data Path ###########
data_root: Path
########## Training #########
+ training_type: Literal["lora", "sft"] = "lora"
+ strategy: Literal[
+ "DDP", "SHARD_GRAD_OP", "FULL_SHARD", "HYBRID_SHARD", "_HYBRID_SHARD_ZERO2"
+ ] = "FULL_SHARD"
+    # This will offload model params and grads to CPU memory to save GPU memory, but will slow down training
+ offload_params_grads: bool = False
+    # This will increase memory usage, since gradients are not sharded during accumulation steps.
+    # Note: when used with offload_params_grads, model parameters and gradients will only be offloaded
+    # to the CPU during the final synchronization (they stay on GPU during gradient accumulation steps),
+    # which means offload_params_grads is meaningless when combined with no_grad_sync_when_accumulating
+ no_grad_sync_when_accumulating: bool = False
+
resume_from_checkpoint: Path | None = None
seed: int | None = None
train_epochs: int
- train_steps: int | None = None
- checkpointing_steps: int = 200
- checkpointing_limit: int = 10
+ checkpointing_steps: int
+ checkpointing_limit: int
batch_size: int
gradient_accumulation_steps: int = 1
- mixed_precision: Literal["no", "fp16", "bf16"]
+ mixed_precision: Literal["fp32", "fp16", "bf16"]
low_vram: bool = False
learning_rate: float = 2e-5
optimizer: str = "adamw"
beta1: float = 0.9
beta2: float = 0.95
- beta3: float = 0.98
epsilon: float = 1e-8
weight_decay: float = 1e-4
max_grad_norm: float = 1.0
- lr_scheduler: str = "linear"
- lr_warmup_ratio: float = 0.01
- lr_num_cycles: int = 1
- lr_power: float = 1.0
+ lr_scheduler: str = "CosineAnnealingLR"
num_workers: int = 8
pin_memory: bool = True
gradient_checkpointing: bool = True
- nccl_timeout: int = 1800
-
- ########## Lora ##########
- rank: int = 128
- lora_alpha: int = 64
- target_modules: list[str] = ["to_q", "to_k", "to_v", "to_out.0"]
+ nccl_timeout: timedelta = timedelta(seconds=1800)
########## Validation ##########
do_validation: bool = False
validation_steps: int | None # if set, should be a multiple of checkpointing_steps
+ @field_validator("log_level")
+    def validate_log_level(cls, v: str) -> int:
+ match v:
+ case "DEBUG":
+ return logging.DEBUG
+ case "INFO":
+ return logging.INFO
+ case "WARNING":
+ return logging.WARNING
+ case "ERROR":
+ return logging.ERROR
+ case "CRITICAL":
+ return logging.CRITICAL
+ case _:
+ raise ValueError("log_level must be one of: DEBUG, INFO, WARNING, ERROR, CRITICAL")
+
+ @field_validator("nccl_timeout")
+ def validate_nccl_timeout(cls, v: timedelta | int) -> timedelta:
+ if isinstance(v, int):
+ return timedelta(seconds=v)
+ return v
+
@field_validator("low_vram")
def validate_low_vram(cls, v: bool, info: ValidationInfo) -> bool:
if v and info.data.get("training_type") != "lora":
raise ValueError("low_vram can only be True when training_type is 'lora'")
+ if v and info.data.get("offload_params_grads"):
+ raise ValueError("low_vram and offload_params_grads cannot be enabled simultaneously")
+ if v and info.data.get("strategy") != "DDP":
+ raise ValueError("low_vram can only be used with strategy='DDP'")
+ if v and info.data.get("resume_from_checkpoint") is not None:
+ raise ValueError("resume_from_checkpoint cannot be used when low_vram is True")
+ return v
+
+ @field_validator("strategy")
+ def validate_strategy(cls, v: str, info: ValidationInfo) -> str:
+ if info.data.get("training_type") == "lora" and v != "DDP":
+ raise ValueError("When using lora training_type, strategy must be 'DDP'")
+ return v
+
+ @field_validator("offload_params_grads")
+ def validate_offload_params_grads(cls, v: bool, info: ValidationInfo) -> bool:
+ if v and info.data.get("low_vram"):
+ raise ValueError("low_vram and offload_params_grads cannot be enabled simultaneously")
+ if v and info.data.get("no_grad_sync_when_accumulating"):
+ raise ValueError(
+ "offload_params_grads and no_grad_sync_when_accumulating cannot be enabled simultaneously"
+ )
+ if v and info.data.get("strategy") == "DDP":
+ raise ValueError("offload_params_grads cannot be enabled when strategy is 'DDP'")
+ return v
+
+ @field_validator("no_grad_sync_when_accumulating")
+ def validate_no_grad_sync_when_accumulating(cls, v: bool, info: ValidationInfo) -> bool:
+ if v and info.data.get("offload_params_grads"):
+ raise ValueError(
+ "offload_params_grads and no_grad_sync_when_accumulating cannot be enabled simultaneously"
+ )
+ if v and info.data.get("strategy") == "DDP":
+ raise ValueError(
+ "no_grad_sync_when_accumulating cannot be enabled when strategy is 'DDP'"
+ )
return v
@field_validator("validation_steps")
@@ -94,77 +160,11 @@ def validate_mixed_precision(cls, v: str, info: ValidationInfo) -> str:
return v
@classmethod
- def get_base_parser(cls):
- """Parse command line arguments and return Args instance"""
- parser = argparse.ArgumentParser()
- # Required arguments
- parser.add_argument("--model_path", type=str, required=True)
- parser.add_argument("--model_name", type=str, required=True)
- parser.add_argument("--training_type", type=str, required=True)
- parser.add_argument("--output_dir", type=str, required=True)
- parser.add_argument("--data_root", type=str, required=True)
- parser.add_argument("--report_to", type=str, required=True)
-
- # Training hyperparameters
- parser.add_argument("--seed", type=int, default=42)
- parser.add_argument("--train_epochs", type=int, default=1)
- parser.add_argument("--train_steps", type=int, default=None)
- parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
- parser.add_argument("--batch_size", type=int, default=1)
- parser.add_argument("--learning_rate", type=float, default=2e-5)
- parser.add_argument("--optimizer", type=str, default="adamw")
- parser.add_argument("--beta1", type=float, default=0.9)
- parser.add_argument("--beta2", type=float, default=0.95)
- parser.add_argument("--beta3", type=float, default=0.98)
- parser.add_argument("--epsilon", type=float, default=1e-8)
- parser.add_argument("--weight_decay", type=float, default=1e-4)
- parser.add_argument("--max_grad_norm", type=float, default=1.0)
-
- # Learning rate scheduler
- parser.add_argument("--lr_scheduler", type=str, default="linear")
- parser.add_argument("--lr_warmup_ratio", type=float, default=0.01)
- parser.add_argument("--lr_num_cycles", type=int, default=1)
- parser.add_argument("--lr_power", type=float, default=1.0)
-
- # Data loading
- parser.add_argument("--num_workers", type=int, default=8)
- parser.add_argument("--pin_memory", type=lambda x: x.lower() == "true", default=True)
-
- # Model configuration
- parser.add_argument("--mixed_precision", type=str, default="no")
- parser.add_argument("--low_vram", type=lambda x: x.lower() == "true", default=False)
- parser.add_argument(
- "--gradient_checkpointing", type=lambda x: x.lower() == "true", default=True
- )
- parser.add_argument("--nccl_timeout", type=int, default=1800)
-
- # LoRA parameters
- parser.add_argument("--rank", type=int, default=128)
- parser.add_argument("--lora_alpha", type=int, default=64)
- parser.add_argument(
- "--target_modules",
- type=str,
- nargs="+",
- default=["to_q", "to_k", "to_v", "to_out.0"],
- )
-
- # Checkpointing
- parser.add_argument("--checkpointing_steps", type=int, default=200)
- parser.add_argument("--checkpointing_limit", type=int, default=10)
- parser.add_argument("--resume_from_checkpoint", type=str, default=None)
-
- # Validation
- parser.add_argument("--do_validation", type=lambda x: x.lower() == "true", default=False)
- parser.add_argument("--validation_steps", type=int, default=None)
-
- return parser
-
- @classmethod
- def parse_args(cls):
- parser = cls.get_base_parser()
+ def parse_from_yaml(cls, fpath: str | Path) -> "BaseArgs":
+ if isinstance(fpath, str):
+ fpath = Path(fpath)
- # parser.add_argument(...)
- # ...
+ with open(fpath, "r") as f:
+ yaml_dict = yaml.safe_load(f)
- args = parser.parse_args()
- return cls(**vars(args))
+ return cls(**yaml_dict)
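With the switch from argparse to `parse_from_yaml`, a single config file now drives the run, and validators coerce loose YAML values into typed fields. A self-contained sketch of that idea, using a plain dict in place of PyYAML/pydantic and reproducing the int-seconds-to-timedelta coercion from `validate_nccl_timeout`:

```python
from datetime import timedelta

def coerce_nccl_timeout(value):
    # Mirrors validate_nccl_timeout: bare ints are interpreted as seconds
    if isinstance(value, int):
        return timedelta(seconds=value)
    return value

# Stand-in for yaml.safe_load(open("config.yaml")); keys follow BaseArgs
yaml_dict = {"name4train": "demo-run", "nccl_timeout": 1800}
args = {**yaml_dict, "nccl_timeout": coerce_nccl_timeout(yaml_dict["nccl_timeout"])}
print(args["nccl_timeout"])  # 0:30:00
```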
diff --git a/src/cogkit/finetune/base/base_state.py b/src/cogkit/finetune/base/base_state.py
index a3307b6..a475b41 100644
--- a/src/cogkit/finetune/base/base_state.py
+++ b/src/cogkit/finetune/base/base_state.py
@@ -6,11 +6,18 @@ class BaseState(BaseModel):
# Allow arbitrary types (for torch dtype)
model_config = {"arbitrary_types_allowed": True}
- weight_dtype: torch.dtype = torch.float32 # dtype for mixed precision training
- num_trainable_parameters: int = 0
- num_update_steps_per_epoch: int = 0
- total_batch_size_count: int = 0
+ world_size: int
+ local_rank: int
+ global_rank: int
- generator: torch.Generator | None = None
+ device: torch.device
+
+ weight_dtype: torch.dtype
- using_deepspeed: bool = False
+ train_steps: int = -1
+ train_epochs: int = -1
+ num_trainable_parameters: int = -1
+ num_update_steps_per_epoch: int = -1
+ total_batch_size_count: int = -1
+
+ generator: torch.Generator | None = None
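The step bookkeeping that now lives in `BaseState` (`train_steps`, `num_update_steps_per_epoch`, `total_batch_size_count`) follows the usual gradient-accumulation arithmetic used later in `train()`. A quick sketch with illustrative numbers:

```python
import math

batches_per_epoch = 100   # len(train_data_loader) on one rank
grad_accum_steps = 4      # gradient_accumulation_steps
train_epochs = 2
batch_size = 1            # per-device batch size
world_size = 8            # number of GPUs

# One optimizer update per grad_accum_steps micro-batches
num_update_steps_per_epoch = math.ceil(batches_per_epoch / grad_accum_steps)
train_steps = train_epochs * num_update_steps_per_epoch
total_batch_size_count = batch_size * world_size * grad_accum_steps

print(num_update_steps_per_epoch)  # 25
print(train_steps)                 # 50
print(total_batch_size_count)      # 32
```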
diff --git a/src/cogkit/finetune/base/base_trainer.py b/src/cogkit/finetune/base/base_trainer.py
index ea32c6a..6fb4099 100644
--- a/src/cogkit/finetune/base/base_trainer.py
+++ b/src/cogkit/finetune/base/base_trainer.py
@@ -1,40 +1,48 @@
# -*- coding: utf-8 -*-
-
-
import json
-import logging
import math
+import os
from abc import ABC, abstractmethod
-from datetime import timedelta
+from contextlib import nullcontext
+from functools import partial
from pathlib import Path
+from typing import Any
-import diffusers
import torch
-import transformers
-from accelerate.accelerator import Accelerator, DistributedType
-from accelerate.logging import get_logger
-from accelerate.utils import (
- DistributedDataParallelKwargs,
- InitProcessGroupKwargs,
- ProjectConfiguration,
- set_seed,
+import torch.distributed as dist
+import torch.distributed.checkpoint as dcp
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp.fully_sharded_data_parallel import (
+ BackwardPrefetch,
+ CPUOffload,
+ MixedPrecision,
+ ShardingStrategy,
)
-from diffusers.optimization import get_scheduler
+from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from cogkit.finetune.base import BaseArgs, BaseComponents, BaseState
-from cogkit.utils.lora import inject_lora, save_lora
+from cogkit.finetune.logger import get_logger
+from cogkit.utils import inject_lora, set_global_seed, save_lora
from ..utils import (
+ AppState,
+ WandbTracker,
cast_training_params,
+ check_distributed,
+ delete_files,
free_memory,
- get_latest_ckpt_path_to_resume_from,
+ get_device,
+ get_global_rank,
+ get_global_step,
+ get_local_rank,
get_memory_statistics,
- get_optimizer,
- unwrap_model,
- find_files,
- delete_files,
+ get_world_size,
+ is_main_process,
+ list_files,
+ mkdir,
)
_DTYPE_MAP = {
@@ -51,16 +59,31 @@ class BaseTrainer(ABC):
    Note: This class assumes that only the `transformer` module needs to be trained.
"""
- LOG_NAME: str = "BaseTrainer"
- LOG_LEVEL: str = "INFO"
-
    # If set, should be a list of components to unload (refer to `Components`)
- # `transformer` is always in UNLOAD_LIST
UNLOAD_LIST: list[str] | None = None
- def __init__(self) -> None:
- self.logger = get_logger(self.LOG_NAME, self.LOG_LEVEL)
- self.accelerator: Accelerator = None
+ MODEL_STATE_DICT_FNAME = "model_state_dict.safetensors"
+ OPTIM_STATE_DICT_FNAME = "optim_state_dict.safetensors"
+
+ def __init__(self, uargs_fpath: str | Path) -> None:
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ if isinstance(uargs_fpath, str):
+ uargs_fpath = Path(uargs_fpath)
+
+ self.uargs = self._init_args(uargs_fpath)
+
+ self._init_distributed()
+ self._init_directories()
+
+ self.logger = get_logger(
+ name=self.uargs.name4train,
+ log_file=self.uargs.output_dir / f"{self.uargs.name4train}.log",
+ level=self.uargs.log_level,
+ )
+
+ if self.uargs.seed is not None:
+ set_global_seed(self.uargs.seed)
+
self.train_dataset: Dataset = None
self.test_dataset: Dataset = None
self.train_data_loader: DataLoader = None
@@ -68,77 +91,62 @@ def __init__(self) -> None:
self.optimizer = None
self.lr_scheduler = None
- self.args = self._init_args()
self.state = self._init_state()
+ self.components = self.load_components()
+ self.tracker = None
+ if self.uargs.report_to is not None:
+ self.tracker = WandbTracker(
+ name=self.uargs.name4train,
+ config=self.uargs.model_dump(),
+ )
+ self.check_setting()
- self._init_distributed()
- self._init_logging()
- self._init_directories()
+ def _init_distributed(self) -> None:
+ dist.init_process_group(backend="nccl", timeout=self.uargs.nccl_timeout)
+ torch.cuda.set_device(get_local_rank())
- self.components = self.load_components()
+ def _init_directories(self) -> None:
+ mkdir(self.uargs.output_dir)
- self.state.using_deepspeed = self.accelerator.state.deepspeed_plugin is not None
+ def _init_args(self, uargs_fpath: Path) -> BaseArgs:
+ return BaseArgs.parse_from_yaml(uargs_fpath)
- def _init_distributed(self):
- logging_dir = Path(self.args.output_dir, "logs")
- project_config = ProjectConfiguration(
- project_dir=self.args.output_dir, logging_dir=logging_dir
- )
- ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
- init_process_group_kwargs = InitProcessGroupKwargs(
- backend="nccl", timeout=timedelta(seconds=self.args.nccl_timeout)
- )
- mixed_precision = "no" if torch.backends.mps.is_available() else self.args.mixed_precision
- report_to = None if self.args.report_to.lower() == "none" else self.args.report_to
-
- accelerator = Accelerator(
- project_config=project_config,
- gradient_accumulation_steps=self.args.gradient_accumulation_steps,
- mixed_precision=mixed_precision,
- log_with=report_to,
- kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
+ def _init_state(self) -> BaseState:
+ return BaseState(
+ world_size=get_world_size(),
+ local_rank=get_local_rank(),
+ global_rank=get_global_rank(),
+ device=get_device(),
+ weight_dtype=_DTYPE_MAP[self.uargs.mixed_precision],
)
- # Disable AMP for MPS.
- if torch.backends.mps.is_available():
- accelerator.native_amp = False
+ def fit(self) -> None:
+ self.logger.info("Checking settings...")
+ self.check_setting()
- self.accelerator = accelerator
+ self.logger.info("Initializing models...")
+ self.prepare_models()
- tracker_name = self.args.tracker_name
- self.accelerator.init_trackers(
- project_name=tracker_name,
- init_kwargs={"wandb": {"name": self.args.output_dir.name}},
- )
+ self.logger.info("Initializing dataset and dataloader...")
+ self.prepare_dataset()
- if self.args.seed is not None:
- set_seed(self.args.seed)
+ self.logger.info("Initializing trainable parameters...")
+ self.prepare_trainable_parameters()
- def _init_logging(self) -> None:
- logging.basicConfig(
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
- datefmt="%m/%d/%Y %H:%M:%S",
- level=self.LOG_LEVEL,
- )
- if self.accelerator.is_local_main_process:
- transformers.utils.logging.set_verbosity_warning()
- diffusers.utils.logging.set_verbosity_info()
- else:
- transformers.utils.logging.set_verbosity_error()
- diffusers.utils.logging.set_verbosity_error()
+ self.logger.info("Preparing model...")
+ self.prepare_model()
- self.logger.info("Initialized Trainer")
- self.logger.info(
- f"Accelerator state: \n{self.accelerator.state}",
- main_process_only=False,
- )
+ self.logger.info("Initializing optimizer and lr scheduler...")
+ self.prepare_optimizer()
- def _init_directories(self) -> None:
- if self.accelerator.is_main_process:
- self.args.output_dir = Path(self.args.output_dir)
- self.args.output_dir.mkdir(parents=True, exist_ok=True)
+ self.logger.info("Starting training...")
+ self.train()
+
+ self.logger.info("Cleaning up...")
+ self.cleanup()
def check_setting(self) -> None:
+ check_distributed()
# Check for `UNLOAD_LIST`
if self.UNLOAD_LIST is None:
self.logger.warning(
@@ -150,37 +158,74 @@ def check_setting(self) -> None:
raise ValueError(f"Invalid component name in unload_list: {name}")
def prepare_trainable_parameters(self) -> None:
- # For mixed precision training we cast all non-trainable weights to half-precision
- # as these weights are only used for inference, keeping weights in full precision is not required.
- weight_dtype = self.state.weight_dtype
-
- if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
- # due to pytorch#99272, MPS does not yet support bfloat16.
- raise ValueError(
- "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
- )
-
# For LoRA, we freeze all the parameters
# For SFT, we train all the parameters in transformer model
for attr_name, component in vars(self.components).items():
if hasattr(component, "requires_grad_"):
- if self.args.training_type == "sft" and attr_name == "transformer":
+ if self.uargs.training_type == "sft" and attr_name == "transformer":
component.requires_grad_(True)
else:
component.requires_grad_(False)
- if self.args.training_type == "lora":
+ if self.uargs.training_type == "lora":
# Initialize LoRA weights
inject_lora(self.components.transformer, lora_dir_or_state_dict=None)
- self.prepare_saving_loading_hooks()
- if self.args.gradient_checkpointing:
+ if self.uargs.gradient_checkpointing:
self.components.transformer.enable_gradient_checkpointing()
- def prepare_optimizer(self) -> None:
- # Make sure the trainable params are in float32
- # cast_training_params([self.components.transformer], dtype=torch.float32)
+ # cast all trainable params to the specified data type (bf16)
+ cast_training_params(self.components.transformer, dtype=self.state.weight_dtype)
+
+ def prepare_model(self) -> None:
+ match self.uargs.strategy:
+            case "_HYBRID_SHARD_ZERO2":
+                sharding_strategy = ShardingStrategy._HYBRID_SHARD_ZERO2
+ case "SHARD_GRAD_OP":
+ sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
+ case "FULL_SHARD":
+ sharding_strategy = ShardingStrategy.FULL_SHARD
+ case "HYBRID_SHARD":
+ sharding_strategy = ShardingStrategy.HYBRID_SHARD
+
+ if self.uargs.strategy != "DDP":
+            wrap_policy = partial(
+                size_based_auto_wrap_policy,
+                min_num_params=int(1e8),
+            )
+
+            self.components.transformer = FSDP(
+                module=self.components.transformer,
+                device_id=self.state.local_rank,
+                sharding_strategy=sharding_strategy,
+                auto_wrap_policy=wrap_policy,
+ cpu_offload=CPUOffload(offload_params=self.uargs.offload_params_grads),
+ mixed_precision=MixedPrecision(
+ param_dtype=self.state.weight_dtype,
+ reduce_dtype=self.state.weight_dtype,
+ ),
+ backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
+                use_orig_params=self.uargs.training_type == "lora",
+ )
+ else:
+            # when using qlora, the model has already been moved to the device
+ if not self.uargs.low_vram:
+ self.components.transformer = self.components.transformer.to(self.state.device)
+
+ self.components.transformer = DDP(
+ module=self.components.transformer,
+ device_ids=[self.state.local_rank],
+ )
+
+ # Load components needed for training to GPU, and cast them to the specified data type
+ ignore_list = self.UNLOAD_LIST
+ self.move_components_to_device(
+ dtype=self.state.weight_dtype,
+ device=self.state.device,
+ ignore_list=ignore_list + ["transformer"],
+ )
+ def prepare_optimizer(self) -> None:
# For LoRA, we only want to train the LoRA weights
# For SFT, we want to train all the parameters
trainable_parameters = list(
@@ -191,296 +236,155 @@ def prepare_optimizer(self) -> None:
)
transformer_parameters_with_lr = {
"params": trainable_parameters,
- "lr": self.args.learning_rate,
+ "lr": self.uargs.learning_rate,
}
params_to_optimize = [transformer_parameters_with_lr]
self.state.num_trainable_parameters = sum(p.numel() for p in trainable_parameters)
- use_deepspeed_opt = (
- self.accelerator.state.deepspeed_plugin is not None
- and "optimizer" in self.accelerator.state.deepspeed_plugin.deepspeed_config
- )
- optimizer = get_optimizer(
- params_to_optimize=params_to_optimize,
- logger=self.logger,
- optimizer_name=self.args.optimizer,
- learning_rate=self.args.learning_rate,
- beta1=self.args.beta1,
- beta2=self.args.beta2,
- beta3=self.args.beta3,
- epsilon=self.args.epsilon,
- weight_decay=self.args.weight_decay,
- use_deepspeed=use_deepspeed_opt,
+ optimizer = torch.optim.AdamW(
+ params=params_to_optimize,
+ lr=self.uargs.learning_rate,
+ betas=(self.uargs.beta1, self.uargs.beta2),
+ eps=self.uargs.epsilon,
+ weight_decay=self.uargs.weight_decay,
)
- # Do not need to divide by num_gpus since acclerate will handle this after prepare lr_scheduler
num_update_steps_per_epoch = math.ceil(
- len(self.train_data_loader) / self.args.gradient_accumulation_steps
+ len(self.train_data_loader) / self.uargs.gradient_accumulation_steps
)
- total_train_steps = self.args.train_epochs * num_update_steps_per_epoch
- total_num_warmup_steps = max(int(total_train_steps * self.args.lr_warmup_ratio), 0)
+ total_train_steps = self.uargs.train_epochs * num_update_steps_per_epoch
- use_deepspeed_lr_scheduler = (
- self.accelerator.state.deepspeed_plugin is not None
- and "scheduler" in self.accelerator.state.deepspeed_plugin.deepspeed_config
+ lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+ optimizer=optimizer,
+ T_max=total_train_steps,
)
- if use_deepspeed_lr_scheduler:
- from accelerate.utils import DummyScheduler
-
- lr_scheduler = DummyScheduler(
- name=self.args.lr_scheduler,
- optimizer=optimizer,
- total_num_steps=total_train_steps,
- num_warmup_steps=total_num_warmup_steps,
- )
- else:
- lr_scheduler = get_scheduler(
- name=self.args.lr_scheduler,
- optimizer=optimizer,
- num_warmup_steps=total_num_warmup_steps,
- num_training_steps=total_train_steps,
- num_cycles=self.args.lr_num_cycles,
- power=self.args.lr_power,
- )
-
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
- def prepare_for_training(self) -> None:
- # cast training params to the specified data type (bf16)
- cast_training_params(self.components.transformer, dtype=self.state.weight_dtype)
-
- (
- self.components.transformer,
- self.optimizer,
- self.train_data_loader,
- self.lr_scheduler,
- ) = self.accelerator.prepare(
- self.components.transformer,
- self.optimizer,
- self.train_data_loader,
- self.lr_scheduler,
- )
-
- # Load components needed for training to GPU (except transformer), and cast them to the specified data type
- ignore_list = self.UNLOAD_LIST
- self.move_components_to_device(
- dtype=self.state.weight_dtype, device=self.accelerator.device, ignore_list=ignore_list
- )
-
- if self.args.do_validation:
- assert self.test_data_loader is not None
- self.test_data_loader = self.accelerator.prepare_data_loader(self.test_data_loader)
-
+ def train(self) -> None:
# We need to recalculate our total training steps as the size of the training dataloader may have changed in distributed training
num_update_steps_per_epoch = math.ceil(
- len(self.train_data_loader) / self.args.gradient_accumulation_steps
+ len(self.train_data_loader) / self.uargs.gradient_accumulation_steps
)
- self.args.train_steps = self.args.train_epochs * num_update_steps_per_epoch
+ self.state.train_steps = self.uargs.train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
- self.args.train_epochs = math.ceil(self.args.train_steps / num_update_steps_per_epoch)
+ self.state.train_epochs = math.ceil(self.state.train_steps / num_update_steps_per_epoch)
self.state.num_update_steps_per_epoch = num_update_steps_per_epoch
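The recomputed step/epoch arithmetic above can be sketched in isolation. This is a dependency-free sketch, not the trainer's actual code; `dataloader_len`, `grad_accum`, and `epochs` are hypothetical stand-ins for the trainer attributes:

```python
import math

def recompute_schedule(dataloader_len: int, grad_accum: int, epochs: int) -> tuple[int, int]:
    # An optimizer update happens once per `grad_accum` batches,
    # with a final partial group still counting as one update.
    updates_per_epoch = math.ceil(dataloader_len / grad_accum)
    total_steps = epochs * updates_per_epoch
    # Recover the epoch count implied by the (possibly resharded) dataloader.
    total_epochs = math.ceil(total_steps / updates_per_epoch)
    return total_steps, total_epochs
```

For example, 10 batches with 4-step accumulation yields 3 updates per epoch, so 3 epochs give 9 total steps.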
- def train(self) -> None:
memory_statistics = get_memory_statistics(self.logger)
self.logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}")
self.state.total_batch_size_count = (
- self.args.batch_size
- * self.accelerator.num_processes
- * self.args.gradient_accumulation_steps
+ self.uargs.batch_size * self.state.world_size * self.uargs.gradient_accumulation_steps
)
info = {
"trainable parameters": self.state.num_trainable_parameters,
"total samples": len(self.train_dataset),
- "train epochs": self.args.train_epochs,
- "train steps": self.args.train_steps,
- "batches per device": self.args.batch_size,
+ "train epochs": self.state.train_epochs,
+ "train steps": self.state.train_steps,
+ "batches per device": self.uargs.batch_size,
"total batches observed per epoch": len(self.train_data_loader),
"train batch size total count": self.state.total_batch_size_count,
- "gradient accumulation steps": self.args.gradient_accumulation_steps,
+ "gradient accumulation steps": self.uargs.gradient_accumulation_steps,
}
self.logger.info(f"Training configuration: {json.dumps(info, indent=4)}")
global_step = 0
- first_epoch = 0
- initial_global_step = 0
-
+ initial_epoch = 0
# Potentially load in the weights and states from a previous save
- (
- resume_from_checkpoint_path,
- initial_global_step,
- global_step,
- first_epoch,
- ) = get_latest_ckpt_path_to_resume_from(
- resume_from_checkpoint=self.args.resume_from_checkpoint,
- num_update_steps_per_epoch=self.state.num_update_steps_per_epoch,
- logger=self.logger,
- )
- if resume_from_checkpoint_path is not None:
- self.accelerator.load_state(resume_from_checkpoint_path)
+ if self.uargs.resume_from_checkpoint is not None:
+ self.logger.info(f"Resuming from checkpoint {self.uargs.resume_from_checkpoint}")
+ global_step = get_global_step(self.uargs.resume_from_checkpoint)
+ for _ in range(global_step):
+ self.lr_scheduler.step()
+ self.resume_from_checkpoint(self.uargs.resume_from_checkpoint)
+ initial_epoch = global_step // num_update_steps_per_epoch
+ for group in self.optimizer.param_groups:
+ group["lr"] = self.lr_scheduler.get_last_lr()[0]
progress_bar = tqdm(
- range(self.args.train_steps),
- initial=initial_global_step,
+ range(self.state.train_steps),
+ initial=global_step,
desc="Training steps",
- disable=not self.accelerator.is_local_main_process,
+ disable=not is_main_process(),
)
- accelerator = self.accelerator
- generator = torch.Generator(device=accelerator.device)
- if self.args.seed is not None:
- generator = generator.manual_seed(self.args.seed)
+ generator = torch.Generator(device=self.state.device)
+ if self.uargs.seed is not None:
+ generator = generator.manual_seed(self.uargs.seed)
self.state.generator = generator
free_memory()
ckpt_path = None
- for epoch in range(first_epoch, self.args.train_epochs):
- self.logger.debug(f"Starting epoch ({epoch + 1}/{self.args.train_epochs})")
+ for epoch in range(initial_epoch, self.uargs.train_epochs):
+ self.logger.debug(f"Starting epoch ({epoch + 1}/{self.uargs.train_epochs})")
self.components.transformer.train()
- models_to_accumulate = [self.components.transformer]
for step, batch in enumerate(self.train_data_loader):
- self.logger.debug(f"Starting step {step + 1}")
- logs = {}
-
- with accelerator.accumulate(models_to_accumulate):
- # These weighting schemes use a uniform timestep sampling and instead post-weight the loss
- loss = self.compute_loss(batch)
- accelerator.backward(loss)
-
- if accelerator.sync_gradients:
- if accelerator.distributed_type == DistributedType.DEEPSPEED:
- grad_norm = self.components.transformer.get_global_grad_norm()
- # In some cases the grad norm may not return a float
- if torch.is_tensor(grad_norm):
- grad_norm = grad_norm.item()
- else:
- grad_norm = accelerator.clip_grad_norm_(
- self.components.transformer.parameters(),
- self.args.max_grad_norm,
- )
- if torch.is_tensor(grad_norm):
- grad_norm = grad_norm.item()
-
- logs["grad_norm"] = grad_norm
-
- self.optimizer.step()
- self.lr_scheduler.step()
- self.optimizer.zero_grad()
-
- # Checks if the accelerator has performed an optimization step behind the scenes
- if accelerator.sync_gradients:
- progress_bar.update(1)
+ self.logger.debug(f"Starting step {step + 1}, global step: {global_step}")
+
+ is_sync_step = (step + 1) % self.uargs.gradient_accumulation_steps == 0
+ is_last_step = (step + 1) == len(self.train_data_loader)
+ sync_grad = is_sync_step or is_last_step
+
+ logs = self.train_step(batch, sync_grad=sync_grad)
+
+ if sync_grad:
global_step += 1
+ progress_bar.update(1)
+
ckpt_path = self.maybe_save_checkpoint(global_step)
- logs["loss"] = loss.detach().item()
- logs["lr"] = self.lr_scheduler.get_last_lr()[0]
progress_bar.set_postfix(logs)
- # Maybe run validation
- should_run_validation = (
- self.args.do_validation
- and global_step % self.args.validation_steps == 0
- and accelerator.sync_gradients
- )
- if should_run_validation:
- del loss
+ if self.tracker is not None:
+ self.tracker.log(logs, step=global_step)
+
+ if self.uargs.do_validation and global_step % self.uargs.validation_steps == 0:
free_memory()
self.validate(global_step, ckpt_path=ckpt_path)
- accelerator.log(logs, step=global_step)
-
- if global_step >= self.args.train_steps:
- break
-
- memory_statistics = get_memory_statistics(self.logger)
+ memory_statistics = get_memory_statistics(self.state.device)
self.logger.info(
f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}"
)
- accelerator.wait_for_everyone()
- ckpt_path = self.maybe_save_checkpoint(global_step, must_save=True)
- if self.args.do_validation:
- free_memory()
- self.validate(global_step, ckpt_path=ckpt_path)
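The `sync_grad` condition in the training loop above fires on every `gradient_accumulation_steps`-th batch and additionally on the last batch of the epoch, so a trailing partial accumulation window still produces an optimizer step. A standalone sketch of which batch indices synchronize (hypothetical helper name):

```python
def sync_steps(num_batches: int, grad_accum: int) -> list[int]:
    """Return the 0-indexed batch indices at which gradients are synchronized."""
    steps = []
    for step in range(num_batches):
        is_sync_step = (step + 1) % grad_accum == 0
        is_last_step = (step + 1) == num_batches
        if is_sync_step or is_last_step:
            steps.append(step)
    return steps
```

With 10 batches and 4-step accumulation, updates happen after batches 3, 7, and 9; the last window covers only two batches.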
+ def train_step(self, batch: dict[str, Any], sync_grad: bool) -> dict[str, Any]:
+ logs = {}
- del self.components
- free_memory()
- memory_statistics = get_memory_statistics(self.logger)
- self.logger.info(f"Memory after training end: {json.dumps(memory_statistics, indent=4)}")
+ sync_context = self.components.transformer.no_sync() if not sync_grad else nullcontext()
- accelerator.end_training()
+ with sync_context:
+ loss = self.compute_loss(batch)
+ loss = loss / self.uargs.gradient_accumulation_steps
+ loss.backward()
- def fit(self) -> None:
- self.logger.info("Checking settings...")
- self.check_setting()
-
- self.logger.info("Initializing models...")
- self.prepare_models()
-
- self.logger.info("Initializing dataset and dataloader...")
- self.prepare_dataset()
-
- self.logger.info("Initializing trainable parameters...")
- self.prepare_trainable_parameters()
-
- self.logger.info("Initializing optimizer and lr scheduler...")
- self.prepare_optimizer()
-
- self.logger.info("Preparing for training...")
- self.prepare_for_training()
-
- self.logger.info("Starting training...")
- self.train()
-
- @abstractmethod
- def _init_args(self) -> BaseArgs:
- raise NotImplementedError
-
- @abstractmethod
- def _init_state(self) -> BaseState:
- raise NotImplementedError
+ if sync_grad:
+ if self.uargs.strategy != "DDP":
+ grad_norm = self.components.transformer.clip_grad_norm_(
+ max_norm=self.uargs.max_grad_norm
+ )
+ else:
+ grad_norm = torch.nn.utils.clip_grad_norm_(
+ self.components.transformer.parameters(),
+ max_norm=self.uargs.max_grad_norm,
+ )
+ self.optimizer.step()
+ self.lr_scheduler.step()
+ self.optimizer.zero_grad()
- @abstractmethod
- def load_components(self) -> BaseComponents:
- # note: `self.components.transformer`(model needs to be trained)
- # and `self.components.pipeline_cls` must be defined
- raise NotImplementedError
+ # all_reduce operates in place, but `.to()` may return a copy whose
+ # reduced value would be discarded -- rebind before reducing
+ loss = loss.detach().to(self.state.device)
+ grad_norm = grad_norm.to(self.state.device)
+ dist.all_reduce(grad_norm, op=dist.ReduceOp.AVG)
+ dist.all_reduce(loss, op=dist.ReduceOp.AVG)
- @abstractmethod
- def prepare_models(self) -> None:
- # Doing something like `self.components.vae.enable_slicing()`
- raise NotImplementedError
+ logs["grad_norm"] = grad_norm.item()
+ logs["loss"] = loss.item()
+ logs["lr"] = self.lr_scheduler.get_last_lr()[0]
+ del loss # release graph
- @abstractmethod
- def prepare_dataset(self) -> None:
- # initialize `self.train_dataset` and `self.train_data_loader`
- # initialize `self.test_dataset` and `self.test_data_loader` if `self.args.do_validation` is True
- raise NotImplementedError
-
- @abstractmethod
- def compute_loss(self, batch) -> torch.Tensor:
- raise NotImplementedError
-
- @abstractmethod
- def validate(self, step: int, ckpt_path: str | None = None) -> None:
- # validation logic defined here
- # during validation, additional modules in the pipeline may need to be moved to GPU memory
- raise NotImplementedError
-
- def get_training_dtype(self) -> torch.dtype:
- if self.args.mixed_precision == "no":
- return _DTYPE_MAP["fp32"]
- elif self.args.mixed_precision == "fp16":
- return _DTYPE_MAP["fp16"]
- elif self.args.mixed_precision == "bf16":
- return _DTYPE_MAP["bf16"]
- else:
- raise ValueError(f"Invalid mixed precision: {self.args.mixed_precision}")
+ return logs
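`train_step` divides each micro-batch loss by `gradient_accumulation_steps` before `backward()`, so the gradients accumulated across a window sum to the gradient of the *mean* loss rather than the sum. A dependency-free sketch of that scaling identity (hypothetical numbers):

```python
def accumulated_loss(micro_losses: list[float]) -> float:
    """Sum of scaled micro-batch losses equals the mean over the accumulation window."""
    g = len(micro_losses)
    return sum(loss / g for loss in micro_losses)
```

Without the division, a window of G micro-batches would contribute a gradient G times larger than a single full batch at the same effective batch size.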
def move_components_to_device(self, dtype, device, ignore_list: list[str] = []):
ignore_list = set(ignore_list)
@@ -497,61 +401,89 @@ def move_components_to_device(self, dtype, device, ignore_list: list[str] = []):
component.to(device, dtype=dtype),
)
- def prepare_saving_loading_hooks(self):
- # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
- def save_model_hook(models, weights, output_dir):
- assert self.accelerator.distributed_type != DistributedType.DEEPSPEED
-
- for model in models:
- original_model = unwrap_model(self.accelerator, model)
- original_transformer = unwrap_model(self.accelerator, self.components.transformer)
- if isinstance(original_model, type(original_transformer)):
- if self.accelerator.is_main_process:
- save_lora(model, output_dir)
- else:
- raise ValueError(f"Unexpected save model: {model.__class__}")
-
- # make sure to pop weight so that corresponding model is not saved again
- if weights:
- weights.pop()
-
- def load_model_hook(models, input_dir):
- assert self.accelerator.distributed_type != DistributedType.DEEPSPEED
-
- for model in models:
- original_model = unwrap_model(self.accelerator, model)
- original_transformer = unwrap_model(self.accelerator, self.components.transformer)
- if isinstance(original_model, type(original_transformer)):
- inject_lora(model, input_dir)
- else:
- raise ValueError(f"Unexpected save model: {model.__class__}")
-
- self.accelerator.register_save_state_pre_hook(save_model_hook)
- self.accelerator.register_load_state_pre_hook(load_model_hook)
-
def maybe_save_checkpoint(self, global_step: int, must_save: bool = False) -> str | None:
- if not (must_save or global_step % self.args.checkpointing_steps == 0):
+ if not (must_save or global_step % self.uargs.checkpointing_steps == 0):
return None
- checkpointing_limit = self.args.checkpointing_limit
- output_dir = Path(self.args.output_dir)
+ checkpointing_limit = self.uargs.checkpointing_limit
+ output_dir = Path(self.uargs.output_dir)
logger = self.logger
if checkpointing_limit is not None:
- checkpoints = find_files(output_dir, prefix="checkpoint")
+ checkpoints = list_files(output_dir, prefix="checkpoint")
+
+ def get_checkpoint_number(path):
+ try:
+ return int(Path(path).name.split("-")[1])
+ except (IndexError, ValueError):
+ raise ValueError(f"Invalid checkpoint path: {path}")
+
+ checkpoints.sort(key=get_checkpoint_number)
# before we save the new checkpoint, we need to have at_most `checkpoints_total_limit - 1` checkpoints
if len(checkpoints) >= checkpointing_limit:
num_to_remove = len(checkpoints) - checkpointing_limit + 1
checkpoints_to_remove = checkpoints[0:num_to_remove]
- if self.accelerator.is_main_process:
- delete_files(checkpoints_to_remove, logger)
+ delete_files(checkpoints_to_remove)
- logger.info(f"Checkpointing at step {global_step}")
- save_path = output_dir / f"checkpoint-{global_step}"
- logger.info(f"Saving state to {save_path}")
+ save_dir = output_dir / f"checkpoint-{global_step}"
+ mkdir(save_dir)
+ logger.info(f"Checkpointing at step {global_step}, saving state to {save_dir} ...")
- self.accelerator.save_state(save_path, safe_serialization=True)
+ saved_model = self.unwrap_model(self.components.transformer)
- self.accelerator.wait_for_everyone()
- return save_path
+ state_dict = {
+ "app": AppState(saved_model, self.optimizer, lora=self.uargs.training_type == "lora")
+ }
+ if not self.uargs.low_vram:
+ dcp.save(state_dict, checkpoint_id=str(save_dir))
+ else:
+ if is_main_process():
+ save_lora(saved_model, save_dir)
+
+ return save_dir
+
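The rotation logic in `maybe_save_checkpoint` keeps at most `checkpointing_limit` directories: it sorts existing `checkpoint-<step>` paths numerically and deletes the oldest so that, after the new save, the limit holds. A sketch of that selection (hypothetical helper name):

```python
from pathlib import Path

def checkpoints_to_delete(paths: list[str], limit: int) -> list[str]:
    """Oldest checkpoints to remove so that, after saving one more, at most `limit` remain."""
    def number(p: str) -> int:
        # Directory names follow the `checkpoint-<step>` pattern.
        return int(Path(p).name.split("-")[1])
    ordered = sorted(paths, key=number)
    excess = len(ordered) - limit + 1
    return ordered[:excess] if excess > 0 else []
```

Sorting numerically (not lexically) matters: as strings, `checkpoint-9` would sort after `checkpoint-30`.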
+ def resume_from_checkpoint(self, ckpt_dir: str | Path) -> None:
+ transformer = self.unwrap_model(self.components.transformer)
+ state_dict = {
+ "app": AppState(transformer, self.optimizer, lora=self.uargs.training_type == "lora")
+ }
+ dcp.load(state_dict, checkpoint_id=str(ckpt_dir))
+
+ def cleanup(self) -> None:
+ dist.destroy_process_group()
+ if self.tracker is not None:
+ self.tracker.finish()
+
+ def unwrap_model(self, model: Any) -> Any:
+ if self.uargs.strategy == "DDP":
+ return model.module
+ else:
+ return model
+
+ @abstractmethod
+ def load_components(self) -> BaseComponents:
+ # note: `self.components.transformer` (the model to be trained)
+ # and `self.components.pipeline_cls` must be defined
+ raise NotImplementedError
+
+ @abstractmethod
+ def prepare_models(self) -> None:
+ # Perform model setup, e.g. `self.components.vae.enable_slicing()`
+ raise NotImplementedError
+
+ @abstractmethod
+ def prepare_dataset(self) -> None:
+ # initialize `self.train_dataset` and `self.train_data_loader`
+ # initialize `self.test_dataset` and `self.test_data_loader` if `self.uargs.do_validation` is True
+ raise NotImplementedError
+
+ @abstractmethod
+ def compute_loss(self, batch: dict[str, Any]) -> torch.Tensor:
+ raise NotImplementedError
+
+ @abstractmethod
+ def validate(self, step: int, ckpt_path: str | None = None) -> None:
+ # validation logic defined here
+ # during validation, additional modules in the pipeline may need to be moved to GPU memory
+ raise NotImplementedError
diff --git a/src/cogkit/datasets/__init__.py b/src/cogkit/finetune/datasets/__init__.py
similarity index 61%
rename from src/cogkit/datasets/__init__.py
rename to src/cogkit/finetune/datasets/__init__.py
index e440d41..d8cc16d 100644
--- a/src/cogkit/datasets/__init__.py
+++ b/src/cogkit/finetune/datasets/__init__.py
@@ -1,9 +1,9 @@
# -*- coding: utf-8 -*-
-from cogkit.datasets.i2v_dataset import BaseI2VDataset, I2VDatasetWithResize
-from cogkit.datasets.t2v_dataset import BaseT2VDataset, T2VDatasetWithResize
-from cogkit.datasets.t2i_dataset import (
+from .i2v_dataset import BaseI2VDataset, I2VDatasetWithResize
+from .t2v_dataset import BaseT2VDataset, T2VDatasetWithResize
+from .t2i_dataset import (
T2IDatasetWithFactorResize,
T2IDatasetWithResize,
T2IDatasetWithPacking,
diff --git a/src/cogkit/datasets/i2v_dataset.py b/src/cogkit/finetune/datasets/i2v_dataset.py
similarity index 94%
rename from src/cogkit/datasets/i2v_dataset.py
rename to src/cogkit/finetune/datasets/i2v_dataset.py
index 001be49..14d4d77 100644
--- a/src/cogkit/datasets/i2v_dataset.py
+++ b/src/cogkit/finetune/datasets/i2v_dataset.py
@@ -5,7 +5,6 @@
from typing import TYPE_CHECKING, Any, Tuple
import torch
-from accelerate.logging import get_logger
from datasets import load_dataset
from PIL import Image
from safetensors.torch import load_file, save_file
@@ -14,7 +13,7 @@
from torchvision.io import VideoReader
from typing_extensions import override
-from cogkit.finetune.diffusion.constants import LOG_LEVEL, LOG_NAME
+from cogkit.finetune.logger import get_logger
from .utils import (
get_prompt_embedding,
@@ -25,7 +24,7 @@
if TYPE_CHECKING:
from cogkit.finetune.diffusion.trainer import DiffusionTrainer
-logger = get_logger(LOG_NAME, LOG_LEVEL)
+_logger = get_logger()
class BaseI2VDataset(Dataset):
@@ -84,7 +83,7 @@ def update_with_image(video_example, idx):
self.data = video_data.map(update_with_image, with_indices=True)
else:
- logger.warning(
+ _logger.warning(
f"No image data found in {self.data_root}, using first frame of video instead"
)
@@ -116,7 +115,7 @@ def __getitem__(self, index: int) -> dict[str, Any]:
##### prompt
prompt = self.data[index]["prompt"]
- prompt_embedding = get_prompt_embedding(self.encode_text, prompt, cache_dir, logger)
+ prompt_embedding = get_prompt_embedding(self.encode_text, prompt, cache_dir)
##### image
image_preprocessed = self.data[index]["image"]
@@ -137,10 +136,10 @@ def __getitem__(self, index: int) -> dict[str, Any]:
##### video
video = self.data[index]["video"]
video_path = Path(video._hf_encoded["path"])
- train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution)
+ train_resolution_str = "x".join(str(x) for x in self.trainer.uargs.train_resolution)
video_latent_dir = (
- cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str
+ cache_dir / "video_latent" / self.trainer.uargs.model_name / train_resolution_str
)
video_latent_dir.mkdir(parents=True, exist_ok=True)
@@ -148,7 +147,7 @@ def __getitem__(self, index: int) -> dict[str, Any]:
if encoded_video_path.exists():
encoded_video = load_file(encoded_video_path)["encoded_video"]
- logger.debug(f"Loaded encoded video from {encoded_video_path}", main_process_only=False)
+ _logger.debug(f"Loaded encoded video from {encoded_video_path}")
else:
frames, _ = self.preprocess(video, None, self.device)
# Current shape of frames: [F, C, H, W]
@@ -162,10 +161,7 @@ def __getitem__(self, index: int) -> dict[str, Any]:
encoded_video = encoded_video[0]
encoded_video = encoded_video.to("cpu")
save_file({"encoded_video": encoded_video}, encoded_video_path)
- logger.info(
- f"Saved encoded video to {encoded_video_path}",
- main_process_only=False,
- )
+ _logger.info(f"Saved encoded video to {encoded_video_path}")
# shape of encoded_video: [C, F, H, W]
# shape of image: [C, H, W]
diff --git a/src/cogkit/datasets/t2i_dataset.py b/src/cogkit/finetune/datasets/t2i_dataset.py
similarity index 97%
rename from src/cogkit/datasets/t2i_dataset.py
rename to src/cogkit/finetune/datasets/t2i_dataset.py
index c67f149..383f522 100644
--- a/src/cogkit/datasets/t2i_dataset.py
+++ b/src/cogkit/finetune/datasets/t2i_dataset.py
@@ -4,26 +4,25 @@
import torch
import torchvision.transforms as transforms
-from accelerate.logging import get_logger
from datasets import load_dataset
from PIL import Image
from torch.utils.data import Dataset
from typing_extensions import override
-from cogkit.finetune.diffusion.constants import LOG_LEVEL, LOG_NAME
+from cogkit.finetune.logger import get_logger
from .utils import (
+ calculate_resize_dimensions,
get_image_embedding,
get_prompt_embedding,
pil2tensor,
preprocess_image_with_resize,
- calculate_resize_dimensions,
)
if TYPE_CHECKING:
from cogkit.finetune.diffusion.trainer import DiffusionTrainer
-logger = get_logger(LOG_NAME, LOG_LEVEL)
+_logger = get_logger()
class BaseT2IDataset(Dataset):
@@ -80,7 +79,7 @@ def __getitem__(self, index: int) -> dict[str, Any]:
##### prompt
prompt = self.data[index]["prompt"]
- prompt_embedding = get_prompt_embedding(self.encode_text, prompt, cache_dir, logger)
+ prompt_embedding = get_prompt_embedding(self.encode_text, prompt, cache_dir)
if not self.using_train:
return {
@@ -100,7 +99,7 @@ def encode_fn(image: Image.Image) -> torch.Tensor:
return encoded_image
# shape of encoded_image: [C, H, W]
- encoded_image = get_image_embedding(encode_fn, image, cache_dir, logger)
+ encoded_image = get_image_embedding(encode_fn, image, cache_dir)
# shape of image: [C, H, W]
return {
diff --git a/src/cogkit/datasets/t2v_dataset.py b/src/cogkit/finetune/datasets/t2v_dataset.py
similarity index 90%
rename from src/cogkit/datasets/t2v_dataset.py
rename to src/cogkit/finetune/datasets/t2v_dataset.py
index 36ba644..57248c0 100644
--- a/src/cogkit/datasets/t2v_dataset.py
+++ b/src/cogkit/finetune/datasets/t2v_dataset.py
@@ -2,7 +2,6 @@
from typing import TYPE_CHECKING, Any
import torch
-from accelerate.logging import get_logger
from datasets import load_dataset
from safetensors.torch import load_file, save_file
from torch.utils.data import Dataset
@@ -10,14 +9,14 @@
from torchvision.io import VideoReader
from typing_extensions import override
-from cogkit.finetune.diffusion.constants import LOG_LEVEL, LOG_NAME
+from cogkit.finetune.logger import get_logger
from .utils import get_prompt_embedding, preprocess_video_with_resize
if TYPE_CHECKING:
from cogkit.finetune.diffusion.trainer import DiffusionTrainer
-logger = get_logger(LOG_NAME, LOG_LEVEL)
+_logger = get_logger()
class BaseT2VDataset(Dataset):
@@ -66,7 +65,7 @@ def __getitem__(self, index: int) -> dict[str, Any]:
##### prompt
prompt = self.data[index]["prompt"]
- prompt_embedding = get_prompt_embedding(self.encode_text, prompt, cache_dir, logger)
+ prompt_embedding = get_prompt_embedding(self.encode_text, prompt, cache_dir)
if not self.using_train:
return {
@@ -78,20 +77,17 @@ def __getitem__(self, index: int) -> dict[str, Any]:
video = self.data[index]["video"]
video_path = Path(video._hf_encoded["path"])
- train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution)
+ train_resolution_str = "x".join(str(x) for x in self.trainer.uargs.train_resolution)
video_latent_dir = (
- cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str
+ cache_dir / "video_latent" / self.trainer.uargs.model_name / train_resolution_str
)
video_latent_dir.mkdir(parents=True, exist_ok=True)
encoded_video_path = video_latent_dir / (video_path.stem + ".safetensors")
if encoded_video_path.exists():
encoded_video = load_file(encoded_video_path)["encoded_video"]
- logger.debug(
- f"Loaded encoded video from {encoded_video_path}",
- main_process_only=False,
- )
+ _logger.debug(f"Loaded encoded video from {encoded_video_path}")
else:
frames = self.preprocess(video, self.device)
# Current shape of frames: [F, C, H, W]
@@ -105,10 +101,7 @@ def __getitem__(self, index: int) -> dict[str, Any]:
encoded_video = encoded_video[0]
encoded_video = encoded_video.to("cpu")
save_file({"encoded_video": encoded_video}, encoded_video_path)
- logger.info(
- f"Saved encoded video to {encoded_video_path}",
- main_process_only=False,
- )
+ _logger.info(f"Saved encoded video to {encoded_video_path}")
return {
"prompt": prompt,
diff --git a/src/cogkit/datasets/utils.py b/src/cogkit/finetune/datasets/utils.py
similarity index 90%
rename from src/cogkit/datasets/utils.py
rename to src/cogkit/finetune/datasets/utils.py
index fb6e87e..588ff64 100644
--- a/src/cogkit/datasets/utils.py
+++ b/src/cogkit/finetune/datasets/utils.py
@@ -1,5 +1,4 @@
import hashlib
-import logging
import math
from pathlib import Path
from typing import Callable
@@ -12,6 +11,10 @@
from safetensors.torch import load_file, save_file
from torchvision.io import VideoReader
+from cogkit.finetune.logger import get_logger
+
+_logger = get_logger()
+
########## loaders ##########
@@ -55,7 +58,7 @@ def load_images_from_videos(videos_path: list[Path]) -> list[Path]:
# Save frame as PNG with same name as video
cv2.imwrite(str(frame_path), frame)
- logging.info(f"Saved first frame to {frame_path}")
+ _logger.info(f"Saved first frame to {frame_path}")
# Release video capture
cap.release()
@@ -176,16 +179,13 @@ def preprocess_video_with_resize(
########## embedding & caching ##########
-def get_prompt_embedding(
- encode_fn: Callable, prompt: str, cache_dir: Path, logger: logging.Logger
-) -> torch.Tensor:
+def get_prompt_embedding(encode_fn: Callable, prompt: str, cache_dir: Path) -> torch.Tensor:
"""Get prompt embedding from cache or create new one if not exists.
Args:
encode_fn: Function to project prompt to embedding.
prompt: Text prompt to be embedded
cache_dir: Base directory for caching embeddings
- logger: Logger instance for logging messages
Returns:
torch.Tensor: Prompt embedding with shape [seq_len, hidden_size]
@@ -200,9 +200,9 @@ def get_prompt_embedding(
with lock:
if prompt_embedding_path.exists():
prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
- logger.debug(
+ _logger.debug(
f"Loaded prompt embedding from {prompt_embedding_path}",
- main_process_only=False,
+ main_only=False,
)
else:
prompt_embedding = encode_fn(prompt)
@@ -211,22 +211,20 @@ def get_prompt_embedding(
prompt_embedding = prompt_embedding.to("cpu")
save_file({"prompt_embedding": prompt_embedding}, prompt_embedding_path)
- logger.info(
+ _logger.info(
f"Saved prompt embedding to {prompt_embedding_path}",
- main_process_only=False,
+ main_only=False,
)
return prompt_embedding
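`get_prompt_embedding` caches each embedding under a filename derived from a hash of the prompt, so identical prompts hit the same cache file across runs. A minimal sketch of that naming scheme; the `md5` digest and directory name here are assumptions for illustration, not necessarily what the real code uses:

```python
import hashlib
from pathlib import Path

def embedding_cache_path(cache_dir: Path, prompt: str) -> Path:
    # Deterministic filename: the same prompt always maps to the same cache file.
    digest = hashlib.md5(prompt.encode("utf-8")).hexdigest()
    return cache_dir / "prompt_embeddings" / f"{digest}.safetensors"
```

Hashing the content rather than using an index keeps the cache valid even when the dataset is reordered or filtered.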
-def get_image_embedding(
- encode_fn: Callable, image: Image.Image, cache_dir: Path, logger: logging.Logger
-) -> torch.Tensor:
+def get_image_embedding(encode_fn: Callable, image: Image.Image, cache_dir: Path) -> torch.Tensor:
encoded_images_dir = cache_dir / "encoded_images"
encoded_images_dir.mkdir(parents=True, exist_ok=True)
if not hasattr(image, "filename"):
- logger.warning("Image object does not have filename attribute, skipping caching.")
+ _logger.warning("Image object does not have filename attribute, skipping caching.")
return encode_fn(image.convert("RGB")).to("cpu")
filename = Path(image.filename).stem
@@ -235,18 +233,12 @@ def get_image_embedding(
if encoded_image_path.exists():
encoded_image = load_file(encoded_image_path)["encoded_image"]
- logger.debug(
- f"Loaded encoded image from {encoded_image_path}",
- main_process_only=False,
- )
+ _logger.debug(f"Loaded encoded image from {encoded_image_path}")
else:
encoded_image = encode_fn(image.convert("RGB"))
encoded_image = encoded_image.to("cpu")
save_file({"encoded_image": encoded_image}, encoded_image_path)
- logger.info(
- f"Saved encoded image to {encoded_image_path}",
- main_process_only=False,
- )
+ _logger.info(f"Saved encoded image to {encoded_image_path}")
return encoded_image
diff --git a/src/cogkit/finetune/diffusion/constants.py b/src/cogkit/finetune/diffusion/constants.py
deleted file mode 100644
index f8c163d..0000000
--- a/src/cogkit/finetune/diffusion/constants.py
+++ /dev/null
@@ -1,2 +0,0 @@
-LOG_NAME = "DiffusionTrainer"
-LOG_LEVEL = "INFO"
diff --git a/src/cogkit/finetune/diffusion/models/cogvideo/cogvideox_i2v/lora_trainer.py b/src/cogkit/finetune/diffusion/models/cogvideo/cogvideox_i2v/lora_trainer.py
index 4311014..6cec3f6 100644
--- a/src/cogkit/finetune/diffusion/models/cogvideo/cogvideox_i2v/lora_trainer.py
+++ b/src/cogkit/finetune/diffusion/models/cogvideo/cogvideox_i2v/lora_trainer.py
@@ -4,22 +4,21 @@
from typing import Any
import torch
-from diffusers import (
- AutoencoderKLCogVideoX,
- CogVideoXDPMScheduler,
- CogVideoXImageToVideoPipeline,
- CogVideoXTransformer3DModel,
-)
-from diffusers.models.embeddings import get_3d_rotary_pos_embed
from PIL import Image
-from transformers import AutoTokenizer, T5EncoderModel, BitsAndBytesConfig
+from transformers import AutoTokenizer, BitsAndBytesConfig, T5EncoderModel
from typing_extensions import override
from cogkit.finetune import register
from cogkit.finetune.diffusion.schemas import DiffusionComponents
from cogkit.finetune.diffusion.trainer import DiffusionTrainer
-from cogkit.finetune.utils import unwrap_model
from cogkit.utils import load_lora_checkpoint, unload_lora_checkpoint
+from diffusers import (
+ AutoencoderKLCogVideoX,
+ CogVideoXDPMScheduler,
+ CogVideoXImageToVideoPipeline,
+ CogVideoXTransformer3DModel,
+)
+from diffusers.models.embeddings import get_3d_rotary_pos_embed
class CogVideoXI2VLoraTrainer(DiffusionTrainer):
@@ -37,7 +36,7 @@ def load_components(self) -> DiffusionComponents:
dtype = self.state.weight_dtype
components = DiffusionComponents()
- model_path = str(self.args.model_path)
+ model_path = str(self.uargs.model_path)
### pipeline
components.pipeline_cls = CogVideoXImageToVideoPipeline
@@ -53,7 +52,7 @@ def load_components(self) -> DiffusionComponents:
)
### transformer
- if not self.args.low_vram:
+ if not self.uargs.low_vram:
components.transformer = CogVideoXTransformer3DModel.from_pretrained(
model_path,
subfolder="transformer",
@@ -64,7 +63,7 @@ def load_components(self) -> DiffusionComponents:
model_path,
subfolder="transformer",
quantization_config=nf4_config,
- device=self.accelerator.device,
+ device=self.state.device,
torch_dtype=dtype,
)
@@ -84,18 +83,18 @@ def load_components(self) -> DiffusionComponents:
@override
def initialize_pipeline(self, ckpt_path: str | None = None) -> CogVideoXImageToVideoPipeline:
- if not self.args.low_vram:
+ if not self.uargs.low_vram:
pipe = CogVideoXImageToVideoPipeline(
tokenizer=self.components.tokenizer,
text_encoder=self.components.text_encoder,
vae=self.components.vae,
- transformer=unwrap_model(self.accelerator, self.components.transformer),
+ transformer=self.unwrap_model(self.components.transformer),
scheduler=self.components.scheduler,
)
else:
- assert self.args.training_type == "lora"
+ assert self.uargs.training_type == "lora"
transformer = CogVideoXTransformer3DModel.from_pretrained(
- str(self.args.model_path),
+ str(self.uargs.model_path),
subfolder="transformer",
torch_dtype=self.state.weight_dtype,
)
@@ -131,7 +130,7 @@ def encode_text(self, prompt: str) -> torch.Tensor:
)
prompt_token_ids = prompt_token_ids.input_ids
prompt_embedding = self.components.text_encoder(
- prompt_token_ids.to(self.accelerator.device)
+ prompt_token_ids.to(self.state.device)
).last_hidden_state[0]
# shape of prompt_embedding: [seq_len, hidden_size]
@@ -176,9 +175,10 @@ def collate_fn(self, samples: list[dict[str, Any]]) -> dict[str, Any]:
@override
def compute_loss(self, batch) -> torch.Tensor:
- prompt_embedding = batch["prompt_embedding"]
- latent = batch["encoded_videos"]
- images = batch["image_preprocessed"]
+ device = self.state.device
+ prompt_embedding = batch["prompt_embedding"].to(device)
+ latent = batch["encoded_videos"].to(device)
+ images = batch["image_preprocessed"].to(device)
# Shape of prompt_embedding: [B, seq_len, hidden_size]
# Shape of latent: [B, C, F, H, W]
@@ -201,9 +201,7 @@ def compute_loss(self, batch) -> torch.Tensor:
# Add frame dimension to images [B,C,H,W] -> [B,C,F,H,W]
images = images.unsqueeze(2)
# Add noise to images
- image_noise_sigma = torch.normal(
- mean=-3.0, std=0.5, size=(1,), device=self.accelerator.device
- )
+ image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=device)
image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=images.dtype)
noisy_images = (
images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None]
@@ -218,7 +216,7 @@ def compute_loss(self, batch) -> torch.Tensor:
0,
self.components.scheduler.config.num_train_timesteps,
(batch_size,),
- device=self.accelerator.device,
+ device=device,
)
timesteps = timesteps.long()
@@ -256,7 +254,7 @@ def compute_loss(self, batch) -> torch.Tensor:
num_frames=num_frames,
transformer_config=transformer_config,
vae_scale_factor_spatial=vae_scale_factor_spatial,
- device=self.accelerator.device,
+ device=device,
)
if transformer_config.use_rotary_positional_embeddings
else None
@@ -310,8 +308,8 @@ def validation_step(
num_frames=self.state.train_resolution[0],
height=self.state.train_resolution[1],
width=self.state.train_resolution[2],
- prompt_embeds=prompt_embedding,
- negative_prompt_embeds=self.get_negtive_prompt_embeds().unsqueeze(0),
+ prompt_embeds=prompt_embedding.to(self.state.device),
+ negative_prompt_embeds=self.state.negative_prompt_embeds.unsqueeze(0),
image=image,
generator=self.state.generator,
).frames[0]
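The I2V hunk above draws the conditioning-image noise scale from a log-normal distribution (`exp` of a normal with mean `-3.0`, std `0.5`). A stdlib sketch of that sampling (function name illustrative), to show the scale of perturbation actually applied:

```python
import math
import random

def sample_image_noise_sigma(rng: random.Random) -> float:
    # Mirrors torch.exp(torch.normal(mean=-3.0, std=0.5, size=(1,))) from
    # the hunk above: a log-normal draw whose median is exp(-3.0) ~= 0.0498,
    # so conditioning images are perturbed only slightly in most steps.
    return math.exp(rng.gauss(-3.0, 0.5))

rng = random.Random(0)
sigmas = [sample_image_noise_sigma(rng) for _ in range(10_000)]
median_sigma = sorted(sigmas)[len(sigmas) // 2]  # close to 0.05
```

In the trainer the sampled sigma multiplies `torch.randn_like(images)`, so the effective per-pixel perturbation has roughly this standard deviation.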
diff --git a/src/cogkit/finetune/diffusion/models/cogvideo/cogvideox_t2v/lora_trainer.py b/src/cogkit/finetune/diffusion/models/cogvideo/cogvideox_t2v/lora_trainer.py
index cfa2594..ac1e3ba 100644
--- a/src/cogkit/finetune/diffusion/models/cogvideo/cogvideox_t2v/lora_trainer.py
+++ b/src/cogkit/finetune/diffusion/models/cogvideo/cogvideox_t2v/lora_trainer.py
@@ -4,22 +4,21 @@
from typing import Any
import torch
-from diffusers import (
- AutoencoderKLCogVideoX,
- CogVideoXDPMScheduler,
- CogVideoXPipeline,
- CogVideoXTransformer3DModel,
-)
-from diffusers.models.embeddings import get_3d_rotary_pos_embed
from PIL import Image
-from transformers import AutoTokenizer, T5EncoderModel, BitsAndBytesConfig
+from transformers import AutoTokenizer, BitsAndBytesConfig, T5EncoderModel
from typing_extensions import override
from cogkit.finetune import register
from cogkit.finetune.diffusion.schemas import DiffusionComponents
from cogkit.finetune.diffusion.trainer import DiffusionTrainer
-from cogkit.finetune.utils import unwrap_model
from cogkit.utils import load_lora_checkpoint, unload_lora_checkpoint
+from diffusers import (
+ AutoencoderKLCogVideoX,
+ CogVideoXDPMScheduler,
+ CogVideoXPipeline,
+ CogVideoXTransformer3DModel,
+)
+from diffusers.models.embeddings import get_3d_rotary_pos_embed
class CogVideoXT2VLoraTrainer(DiffusionTrainer):
@@ -37,7 +36,7 @@ def load_components(self) -> DiffusionComponents:
dtype = self.state.weight_dtype
components = DiffusionComponents()
- model_path = str(self.args.model_path)
+ model_path = str(self.uargs.model_path)
### pipeline
components.pipeline_cls = CogVideoXPipeline
@@ -51,7 +50,7 @@ def load_components(self) -> DiffusionComponents:
)
### transformer
- if not self.args.low_vram:
+ if not self.uargs.low_vram:
components.transformer = CogVideoXTransformer3DModel.from_pretrained(
model_path,
subfolder="transformer",
@@ -62,7 +61,7 @@ def load_components(self) -> DiffusionComponents:
model_path,
subfolder="transformer",
quantization_config=nf4_config,
- device=self.accelerator.device,
+ device=self.state.device,
torch_dtype=dtype,
)
@@ -80,18 +79,18 @@ def load_components(self) -> DiffusionComponents:
@override
def initialize_pipeline(self, ckpt_path: str | None = None) -> CogVideoXPipeline:
- if not self.args.low_vram:
+ if not self.uargs.low_vram:
pipe = CogVideoXPipeline(
tokenizer=self.components.tokenizer,
text_encoder=self.components.text_encoder,
vae=self.components.vae,
- transformer=unwrap_model(self.accelerator, self.components.transformer),
+ transformer=self.unwrap_model(self.components.transformer),
scheduler=self.components.scheduler,
)
else:
- assert self.args.training_type == "lora"
+ assert self.uargs.training_type == "lora"
transformer = CogVideoXTransformer3DModel.from_pretrained(
- str(self.args.model_path),
+ str(self.uargs.model_path),
subfolder="transformer",
torch_dtype=self.state.weight_dtype,
)
@@ -127,7 +126,7 @@ def encode_text(self, prompt: str) -> torch.Tensor:
)
prompt_token_ids = prompt_token_ids.input_ids
prompt_embedding = self.components.text_encoder(
- prompt_token_ids.to(self.accelerator.device)
+ prompt_token_ids.to(self.state.device)
).last_hidden_state[0]
# shape of prompt_embedding: [seq_len, hidden_size]
@@ -161,8 +160,9 @@ def collate_fn(self, samples: list[dict[str, Any]]) -> dict[str, Any]:
@override
def compute_loss(self, batch) -> torch.Tensor:
- prompt_embedding = batch["prompt_embedding"]
- latent = batch["encoded_videos"]
+ device = self.state.device
+ prompt_embedding = batch["prompt_embedding"].to(device)
+ latent = batch["encoded_videos"].to(device)
assert latent is not None and prompt_embedding is not None
@@ -188,7 +188,7 @@ def compute_loss(self, batch) -> torch.Tensor:
0,
self.components.scheduler.config.num_train_timesteps,
(batch_size,),
- device=self.accelerator.device,
+ device=device,
)
timesteps = timesteps.long()
@@ -207,7 +207,7 @@ def compute_loss(self, batch) -> torch.Tensor:
num_frames=num_frames,
transformer_config=transformer_config,
vae_scale_factor_spatial=vae_scale_factor_spatial,
- device=self.accelerator.device,
+ device=device,
)
if transformer_config.use_rotary_positional_embeddings
else None
@@ -251,8 +251,8 @@ def validation_step(
num_frames=self.state.train_resolution[0],
height=self.state.train_resolution[1],
width=self.state.train_resolution[2],
- prompt_embeds=prompt_embedding,
- negative_prompt_embeds=self.get_negtive_prompt_embeds().unsqueeze(0),
+ prompt_embeds=prompt_embedding.to(self.state.device),
+ negative_prompt_embeds=self.state.negative_prompt_embeds.unsqueeze(0),
generator=self.state.generator,
).frames[0]
return {"text": prompt, "video": video_generate}
diff --git a/src/cogkit/finetune/diffusion/models/cogview/cogview4/lora_trainer.py b/src/cogkit/finetune/diffusion/models/cogview/cogview4/lora_trainer.py
index ac91bb2..05a3495 100644
--- a/src/cogkit/finetune/diffusion/models/cogview/cogview4/lora_trainer.py
+++ b/src/cogkit/finetune/diffusion/models/cogview/cogview4/lora_trainer.py
@@ -13,7 +13,6 @@
from cogkit.finetune.diffusion.trainer import DiffusionTrainer
from cogkit.finetune.utils import (
process_prompt_attention_mask,
- unwrap_model,
replace_attn_processor,
)
from cogkit.utils import load_lora_checkpoint, unload_lora_checkpoint
@@ -43,7 +42,7 @@ def load_components(self) -> DiffusionComponents:
dtype = self.state.weight_dtype
components = DiffusionComponents()
- model_path = str(self.args.model_path)
+ model_path = str(self.uargs.model_path)
### pipeline
components.pipeline_cls = CogView4Pipeline
@@ -59,7 +58,7 @@ def load_components(self) -> DiffusionComponents:
)
### transformer
- if not self.args.low_vram:
+ if not self.uargs.low_vram:
components.transformer = CogView4Transformer2DModel.from_pretrained(
model_path,
subfolder="transformer",
@@ -71,7 +70,7 @@ def load_components(self) -> DiffusionComponents:
subfolder="transformer",
torch_dtype=dtype,
quantization_config=nf4_config,
- device=self.accelerator.device,
+ device=self.state.device,
)
replace_attn_processor(components.transformer, CogView4TrainingAttnProcessor())
@@ -88,23 +87,22 @@ def load_components(self) -> DiffusionComponents:
@override
def initialize_pipeline(self, ckpt_path: str | None = None) -> CogView4Pipeline:
- if not self.args.low_vram:
+ # using bf16 model rather than quantized ones
+ if not self.uargs.low_vram:
pipe = CogView4Pipeline(
tokenizer=self.components.tokenizer,
text_encoder=self.components.text_encoder,
vae=self.components.vae,
- transformer=unwrap_model(self.accelerator, self.components.transformer),
+ transformer=self.unwrap_model(self.components.transformer),
scheduler=self.components.scheduler,
)
else:
- assert self.args.training_type == "lora"
- # using bf16 model rather than quantized ones
+ assert self.uargs.training_type == "lora"
transformer = CogView4Transformer2DModel.from_pretrained(
- str(self.args.model_path),
+ str(self.uargs.model_path),
subfolder="transformer",
torch_dtype=self.state.weight_dtype,
)
- replace_attn_processor(transformer, CogView4TrainingAttnProcessor())
pipe = CogView4Pipeline(
tokenizer=self.components.tokenizer,
text_encoder=self.components.text_encoder,
@@ -133,7 +131,7 @@ def encode_text(self, prompt: str) -> torch.Tensor:
).input_ids
prompt_embedding = self.components.text_encoder(
- prompt_token_ids.to(self.accelerator.device), output_hidden_states=True
+ prompt_token_ids.to(self.state.device), output_hidden_states=True
).hidden_states[-2][0]
# shape of prompt_embedding: [sequence length, embedding dimension(4096)]
return prompt_embedding
@@ -145,7 +143,7 @@ def get_negtive_prompt_embeds(self) -> torch.Tensor:
@override
def encode_image(self, image: torch.Tensor) -> torch.Tensor:
vae = self.components.vae
- image = image.to(self.accelerator.device, dtype=vae.dtype)
+ image = image.to(self.state.device, dtype=vae.dtype)
latent_dist = vae.encode(image).latent_dist
latent = latent_dist.sample() * vae.config.scaling_factor
return latent
@@ -225,8 +223,9 @@ def collate_fn(self, samples: list[dict[str, Any]]) -> dict[str, Any]:
@override
def compute_loss(self, batch: dict[str, Any]) -> torch.Tensor:
batch_size, text_seqlen, text_embedding_dim = batch["prompt_embedding"].shape
- prompt_embeds = batch["prompt_embedding"]
- latent = batch["encoded_image"]
+ device = self.state.device
+ prompt_embeds = batch["prompt_embedding"].to(device)
+ latent = batch["encoded_image"].to(device)
batch_size, num_channels, height, width = latent.shape
image_height, image_width = self.state.train_resolution
@@ -234,7 +233,7 @@ def compute_loss(self, batch: dict[str, Any]) -> torch.Tensor:
image_seq_len = (
(image_height // vae_scale_factor) * (image_width // vae_scale_factor)
) // (self.state.transformer_config.patch_size**2)
- image_seq_len = torch.tensor([image_seq_len], device=self.accelerator.device)
+ image_seq_len = torch.tensor([image_seq_len], device=device)
text_attn_mask = batch["text_attn_mask"]
@@ -248,20 +247,20 @@ def compute_loss(self, batch: dict[str, Any]) -> torch.Tensor:
original_size = torch.tensor(
[[image_height, image_width] for _ in range(batch_size)],
dtype=latent.dtype,
- device=self.accelerator.device,
+ device=device,
)
target_size = torch.tensor(
[[image_height, image_width] for _ in range(batch_size)],
dtype=latent.dtype,
- device=self.accelerator.device,
+ device=device,
)
crop_coords = torch.tensor(
- [[0, 0] for _ in range(batch_size)], dtype=latent.dtype, device=self.accelerator.device
+ [[0, 0] for _ in range(batch_size)], dtype=latent.dtype, device=device
)
noise_pred_cond = self.components.transformer(
- hidden_states=model_input,
- encoder_hidden_states=prompt_embeds,
+ hidden_states=model_input.to(dtype=self.state.weight_dtype),
+ encoder_hidden_states=prompt_embeds.to(dtype=self.state.weight_dtype),
timestep=timestep,
original_size=original_size,
target_size=target_size,
@@ -288,11 +287,11 @@ def get_sigmas(self, batch_size: int, vtoken_seq_len: torch.Tensor) -> torch.Ten
scheduler.sigma_min,
scheduler.sigma_max,
scheduler.config.num_train_timesteps,
- device=self.accelerator.device,
+ device=self.state.device,
)
m = (vtoken_seq_len / scheduler.config.base_image_seq_len) ** 0.5
mu = m * scheduler.config.max_shift + scheduler.config.base_shift
- mu = mu.unsqueeze(1)
+ mu = mu.unsqueeze(1).to(sigmas.device)
sigmas = mu / (mu + (1 / sigmas - 1))
sigmas = torch.cat([torch.zeros((batch_size, 1), device=sigmas.device), sigmas], dim=1)
return sigmas
@@ -302,7 +301,7 @@ def get_timestep(self, batch_size: int, num_train_timesteps: int) -> torch.LongT
0,
num_train_timesteps,
(batch_size,),
- device=self.accelerator.device,
+ device=self.state.device,
)
def add_noise(
@@ -335,7 +334,7 @@ def validation_step(
image_generate = pipe(
height=self.state.train_resolution[0],
width=self.state.train_resolution[1],
- prompt_embeds=prompt_embedding,
+ prompt_embeds=prompt_embedding.to(self.state.device),
negative_prompt_embeds=self.state.negative_prompt_embeds.unsqueeze(
0
), # Add batch dimension
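`get_sigmas` above applies a resolution-dependent shift: `mu` grows with the square root of the visual-token sequence length, pushing sigmas toward higher noise levels for larger images. A minimal sketch of that transform (the `base_shift`/`max_shift` defaults here are illustrative, not the scheduler's actual config):

```python
def shift_sigma(sigma: float, vtoken_seq_len: int,
                base_seq_len: int = 256,
                base_shift: float = 0.5, max_shift: float = 0.5) -> float:
    # m = (seq_len / base_seq_len) ** 0.5; mu = m * max_shift + base_shift;
    # sigma' = mu / (mu + (1 / sigma - 1)), matching the hunk above.
    m = (vtoken_seq_len / base_seq_len) ** 0.5
    mu = m * max_shift + base_shift
    return mu / (mu + (1.0 / sigma - 1.0))

# With these illustrative values, mu = 1 at the base sequence length,
# i.e. no shift; at 4x the tokens, sigma is pushed upward.
same = shift_sigma(0.5, 256)      # stays 0.5
shifted = shift_sigma(0.5, 1024)  # shifted above 0.5
```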
diff --git a/src/cogkit/finetune/diffusion/models/cogview/cogview4/lora_trainer_packing.py b/src/cogkit/finetune/diffusion/models/cogview/cogview4/lora_trainer_packing.py
index 76f1aef..9761e34 100644
--- a/src/cogkit/finetune/diffusion/models/cogview/cogview4/lora_trainer_packing.py
+++ b/src/cogkit/finetune/diffusion/models/cogview/cogview4/lora_trainer_packing.py
@@ -41,7 +41,7 @@ def __init__(self, *args, **kwargs) -> None:
self.ROPE_DIM = transformer.config.rope_axes_dim
patch_size = self.PATCH_SIZE
- height, width = self.args.train_resolution
+ height, width = self.uargs.train_resolution
sample_height, sample_width = (
height // self.DOWNSAMPLER_FACTOR,
width // self.DOWNSAMPLER_FACTOR,
@@ -161,9 +161,9 @@ def collate_fn_packing(self, samples: list[dict[str, list[Any]]]) -> dict[str, A
@override
def compute_loss(self, batch: dict[str, Any]) -> torch.Tensor:
- dtype = self.get_training_dtype()
- prompt_embeds = batch["prompt_embedding"]
- latent = batch["encoded_image"]
+ device, dtype = self.state.device, self.state.weight_dtype
+ prompt_embeds = batch["prompt_embedding"].to(device)
+ latent = batch["encoded_image"].to(device)
image_rotary_emb = batch["image_rotary_emb"]
batch_size, text_seqlen, text_embedding_dim = prompt_embeds.shape
batch_size, num_channels, height, width = latent.shape
@@ -182,11 +182,9 @@ def compute_loss(self, batch: dict[str, Any]) -> torch.Tensor:
noise = torch.randn_like(latent, dtype=dtype)
model_input, model_label = self.add_noise(latent, noise, timestep, sigmas)
- original_size = original_size.to(dtype=dtype, device=self.accelerator.device)
- target_size = original_size.clone().to(dtype=dtype, device=self.accelerator.device)
- crop_coords = torch.tensor(
- [[0, 0] for _ in range(batch_size)], dtype=dtype, device=self.accelerator.device
- )
+ original_size = original_size.to(dtype=dtype, device=device)
+ target_size = original_size.clone().to(dtype=dtype, device=device)
+ crop_coords = torch.tensor([[0, 0] for _ in range(batch_size)], dtype=dtype, device=device)
noise_pred_cond = self.components.transformer(
hidden_states=model_input.to(dtype=dtype),
@@ -200,7 +198,7 @@ def compute_loss(self, batch: dict[str, Any]) -> torch.Tensor:
attention_kwargs=attention_kwargs,
)[0]
- pixel_mask = batch["pixel_mask"]
+ pixel_mask = batch["pixel_mask"].to(device)
loss = torch.sum(((noise_pred_cond - model_label) ** 2) * pixel_mask, dim=(1, 2, 3))
loss = loss / torch.sum(pixel_mask, dim=(1, 2, 3))
loss = loss.mean()
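The packing loss above sums squared error only over valid positions (`pixel_mask`), normalizes per sample by the number of valid positions, then takes the batch mean, so regions padded during packing do not dilute the loss. The same arithmetic on flat Python lists (a sketch; the real code operates on `[B, C, H, W]` tensors):

```python
def masked_mse(pred: list[list[float]],
               label: list[list[float]],
               mask: list[list[int]]) -> float:
    # Per-sample: squared error summed over mask==1 positions, divided by
    # the count of valid positions; then the mean over the batch.
    per_sample = []
    for p_s, l_s, m_s in zip(pred, label, mask):
        num = sum((p - l) ** 2 * m for p, l, m in zip(p_s, l_s, m_s))
        per_sample.append(num / sum(m_s))
    return sum(per_sample) / len(per_sample)

# Errors of 1 and 2 on the two valid positions -> (1 + 4) / 2 = 2.5;
# the masked-out error of 5 is ignored.
loss = masked_mse([[1.0, 2.0, 5.0]], [[0.0, 0.0, 0.0]], [[1, 1, 0]])
```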
diff --git a/src/cogkit/finetune/diffusion/schemas/args.py b/src/cogkit/finetune/diffusion/schemas/args.py
index 8dad3d1..59e01f3 100644
--- a/src/cogkit/finetune/diffusion/schemas/args.py
+++ b/src/cogkit/finetune/diffusion/schemas/args.py
@@ -1,4 +1,4 @@
-from typing import Literal
+from pathlib import Path
from pydantic import ValidationInfo, field_validator
from typing_extensions import override
@@ -7,16 +7,10 @@
class DiffusionArgs(BaseArgs):
- ########## Model ##########
- model_type: Literal["i2v", "t2v", "t2i"]
-
- ########## Output ##########
- tracker_name: str = "diffusion-tracker"
-
########## Training #########
- # For cogview models, train_resolution is a tuple of (height, width)
- # For cogvideo models, train_resolution is a tuple of (frames, height, width)
- train_resolution: tuple[int, int] | tuple[int, int, int]
+ # For cogview models, train_resolution is a list of (height, width)
+ # For cogvideo models, train_resolution is a list of (frames, height, width)
+ train_resolution: list[int]
enable_slicing: bool = True
enable_tiling: bool = True
@@ -25,11 +19,11 @@ class DiffusionArgs(BaseArgs):
enable_packing: bool = False
########## Validation ##########
- gen_fps: int = 15
+ gen_fps: int | None = None
@field_validator("train_resolution")
def validate_train_resolution(
- cls, v: tuple[int, int] | tuple[int, int, int], info: ValidationInfo
+ cls, v: list[int], info: ValidationInfo
) -> str:
if len(v) == 2: # cogview models
height, width = v
@@ -49,39 +43,12 @@ def validate_train_resolution(
)
else:
raise ValueError(
- "train_resolution must be a tuple of (height, width) for cogview models or (frames, height, width) for cogvideo models"
+ "train_resolution must be a list of (height, width) for cogview models or (frames, height, width) for cogvideo models"
)
return v
@override
@classmethod
- def parse_args(cls):
- parser = cls.get_base_parser()
-
- # Required arguments
- parser.add_argument("--model_type", type=str, required=True)
- parser.add_argument("--train_resolution", type=str, required=True)
-
- # Model configuration
- parser.add_argument("--enable_slicing", action="store_true")
- parser.add_argument("--enable_tiling", action="store_true")
-
- # Packing
- parser.add_argument("--enable_packing", type=lambda x: x.lower() == "true", default=False)
-
- # Validation
- parser.add_argument("--gen_fps", type=int, default=15)
-
- args = parser.parse_args()
-
- # Convert train_resolution string to tuple
- parts = args.train_resolution.split("x")
- if len(parts) == 2:
- height, width = parts
- args.train_resolution = (int(height), int(width))
- else:
- frames, height, width = parts
- args.train_resolution = (int(frames), int(height), int(width))
-
- return cls(**vars(args))
+ def parse_from_yaml(cls, fpath: str | Path) -> "DiffusionArgs":
+ return super().parse_from_yaml(fpath)
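Two things change in `DiffusionArgs` above: CLI parsing is replaced by `parse_from_yaml` (YAML sequences deserialize as Python lists, which is why `train_resolution` moves from tuples to lists), while the shape check stays in the validator. A stand-alone sketch of that validation logic:

```python
def validate_train_resolution(v: list[int]) -> list[int]:
    # 2 values -> (height, width) for cogview models;
    # 3 values -> (frames, height, width) for cogvideo models.
    if len(v) == 2:
        height, width = v
    elif len(v) == 3:
        frames, height, width = v
    else:
        raise ValueError(
            "train_resolution must be a list of (height, width) for cogview "
            "models or (frames, height, width) for cogvideo models"
        )
    return v
```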
diff --git a/src/cogkit/finetune/diffusion/schemas/state.py b/src/cogkit/finetune/diffusion/schemas/state.py
index fc7bcbe..ad0f86c 100644
--- a/src/cogkit/finetune/diffusion/schemas/state.py
+++ b/src/cogkit/finetune/diffusion/schemas/state.py
@@ -14,13 +14,13 @@ class DiffusionState(BaseState):
# for video input, train_resolution = (frames, height, width)
# for image input, train_resolution = (height, width)
- train_resolution: tuple[int, int, int] | tuple[int, int]
+ train_resolution: tuple[int, int, int] | tuple[int, int] = ()
- # packing realted
- training_seq_length: int | None = None
+ negative_prompt_embeds: torch.Tensor | None = None
validation_prompts: list[str] = []
validation_images: list[Path | None] = []
validation_videos: list[Path | None] = []
- negative_prompt_embeds: torch.Tensor | None = None
+ # packing related
+ training_seq_length: int | None = None
diff --git a/src/cogkit/finetune/diffusion/trainer.py b/src/cogkit/finetune/diffusion/trainer.py
index 679d1ae..c8522ee 100644
--- a/src/cogkit/finetune/diffusion/trainer.py
+++ b/src/cogkit/finetune/diffusion/trainer.py
@@ -1,101 +1,107 @@
import json
+from pathlib import Path
from typing import Any
import torch
+import torch.distributed as dist
import wandb
from accelerate import cpu_offload
-from accelerate.utils import gather_object
+from torch.utils.data import DistributedSampler
from PIL import Image
from typing_extensions import override
from cogkit.finetune.base import BaseTrainer
-from cogkit.samplers import NaivePackingSampler
-from cogkit.utils import expand_list
+from cogkit.finetune.samplers import DistPackingSampler
+from cogkit.utils import expand_list, guess_generation_mode
+from cogkit.types import GenerationMode
from diffusers.pipelines import DiffusionPipeline
from diffusers.utils.export_utils import export_to_video
from ..utils import (
free_memory,
get_memory_statistics,
- unload_model,
+ gather_object,
+ mkdir,
)
-from .constants import LOG_LEVEL, LOG_NAME
from .schemas import DiffusionArgs, DiffusionComponents, DiffusionState
class DiffusionTrainer(BaseTrainer):
- # If set, should be a list of components to unload (refer to `Components``)
- UNLOAD_LIST: list[str] = None
- LOG_NAME: str = LOG_NAME
- LOG_LEVEL: str = LOG_LEVEL
+ @override
+ def __init__(self, uargs_fpath: str | Path) -> None:
+ super().__init__(uargs_fpath)
+ self.uargs: DiffusionArgs
+ self.state: DiffusionState
+ self.components: DiffusionComponents
@override
- def _init_args(self) -> DiffusionArgs:
- return DiffusionArgs.parse_args()
+ def _init_args(self, uargs_fpath: Path) -> DiffusionArgs:
+ return DiffusionArgs.parse_from_yaml(uargs_fpath)
@override
def _init_state(self) -> DiffusionState:
- return DiffusionState(
- weight_dtype=self.get_training_dtype(),
- train_resolution=self.args.train_resolution,
- )
+ state = DiffusionState(**super()._init_state().model_dump())
+ state.train_resolution = self.uargs.train_resolution
+ return state
@override
def prepare_models(self) -> None:
if self.components.vae is not None:
- if self.args.enable_slicing:
+ if self.uargs.enable_slicing:
self.components.vae.enable_slicing()
- if self.args.enable_tiling:
+ if self.uargs.enable_tiling:
self.components.vae.enable_tiling()
self.state.transformer_config = self.components.transformer.config
@override
def prepare_dataset(self) -> None:
- if self.args.model_type == "i2v":
- from cogkit.datasets import BaseI2VDataset, I2VDatasetWithResize
-
- dataset_cls = I2VDatasetWithResize
- if self.args.enable_packing:
- dataset_cls = BaseI2VDataset
- raise NotImplementedError("Packing for I2V is not implemented")
-
- elif self.args.model_type == "t2v":
- from cogkit.datasets import BaseT2VDataset, T2VDatasetWithResize
-
- dataset_cls = T2VDatasetWithResize
- if self.args.enable_packing:
- dataset_cls = BaseT2VDataset
- raise NotImplementedError("Packing for T2V is not implemented")
-
- elif self.args.model_type == "t2i":
- from cogkit.datasets import (
- T2IDatasetWithFactorResize,
- T2IDatasetWithPacking,
- T2IDatasetWithResize,
- )
-
- dataset_cls = T2IDatasetWithResize
- if self.args.enable_packing:
- dataset_cls = T2IDatasetWithFactorResize
- dataset_cls_packing = T2IDatasetWithPacking
-
- else:
- raise ValueError(f"Invalid model type: {self.args.model_type}")
+ generation_mode = guess_generation_mode(self.components.pipeline_cls)
+ match generation_mode:
+ case GenerationMode.TextToImage:
+ from cogkit.finetune.datasets import (
+ T2IDatasetWithFactorResize,
+ T2IDatasetWithPacking,
+ T2IDatasetWithResize,
+ )
+
+ dataset_cls = T2IDatasetWithResize
+ if self.uargs.enable_packing:
+ dataset_cls = T2IDatasetWithFactorResize
+ dataset_cls_packing = T2IDatasetWithPacking
+
+ case GenerationMode.TextToVideo:
+ from cogkit.finetune.datasets import BaseT2VDataset, T2VDatasetWithResize
+
+ dataset_cls = T2VDatasetWithResize
+ if self.uargs.enable_packing:
+ dataset_cls = BaseT2VDataset
+ raise NotImplementedError("Packing for T2V is not implemented")
+
+ case GenerationMode.ImageToVideo:
+ from cogkit.finetune.datasets import BaseI2VDataset, I2VDatasetWithResize
+
+ dataset_cls = I2VDatasetWithResize
+ if self.uargs.enable_packing:
+ dataset_cls = BaseI2VDataset
+ raise NotImplementedError("Packing for I2V is not implemented")
+
+ case _:
+ raise ValueError(f"Invalid generation mode: {generation_mode}")
additional_args = {
- "device": self.accelerator.device,
+ "device": self.state.device,
"trainer": self,
}
self.train_dataset = dataset_cls(
- **(self.args.model_dump()),
+ **(self.uargs.model_dump()),
**additional_args,
using_train=True,
)
- if self.args.do_validation:
+ if self.uargs.do_validation:
self.test_dataset = dataset_cls(
- **(self.args.model_dump()),
+ **(self.uargs.model_dump()),
**additional_args,
using_train=False,
)
@@ -103,15 +109,15 @@ def prepare_dataset(self) -> None:
### Prepare VAE and text encoder for encoding
self.components.vae.requires_grad_(False)
self.components.text_encoder.requires_grad_(False)
- self.components.vae.to(self.accelerator.device, dtype=self.state.weight_dtype)
- if self.args.low_vram: # offload text encoder to CPU
- cpu_offload(self.components.text_encoder, self.accelerator.device)
+ self.components.vae.to(self.state.device, dtype=self.state.weight_dtype)
+ if self.uargs.low_vram: # offload text encoder to CPU
+ cpu_offload(self.components.text_encoder, self.state.device)
else:
- self.components.text_encoder.to(self.accelerator.device, dtype=self.state.weight_dtype)
+ self.components.text_encoder.to(self.state.device, dtype=self.state.weight_dtype)
### Precompute embedding
self.logger.info("Precomputing embedding ...")
- self.state.negative_prompt_embeds = self.get_negtive_prompt_embeds()
+ self.state.negative_prompt_embeds = self.get_negtive_prompt_embeds().to(self.state.device)
for dataset in [self.train_dataset, self.test_dataset]:
if dataset is None:
@@ -121,28 +127,38 @@ def prepare_dataset(self) -> None:
collate_fn=self.collate_fn,
batch_size=1,
num_workers=0,
- pin_memory=self.args.pin_memory,
+ pin_memory=self.uargs.pin_memory,
+ sampler=DistributedSampler(
+ dataset,
+ num_replicas=self.state.world_size,
+ rank=self.state.global_rank,
+ shuffle=False,
+ ),
)
- tmp_data_loader = self.accelerator.prepare_data_loader(tmp_data_loader)
for _ in tmp_data_loader:
...
- self.accelerator.wait_for_everyone()
self.logger.info("Precomputing embedding ... Done")
+ dist.barrier()
- unload_model(self.components.vae)
- if not self.args.low_vram:
- unload_model(self.components.text_encoder)
+ self.components.vae = self.components.vae.to("cpu")
+ if not self.uargs.low_vram:
+ self.components.text_encoder = self.components.text_encoder.to("cpu")
free_memory()
- if not self.args.enable_packing:
+ if not self.uargs.enable_packing:
self.train_data_loader = torch.utils.data.DataLoader(
self.train_dataset,
collate_fn=self.collate_fn,
- batch_size=self.args.batch_size,
- num_workers=self.args.num_workers,
- pin_memory=self.args.pin_memory,
- shuffle=True,
+ batch_size=self.uargs.batch_size,
+ num_workers=self.uargs.num_workers,
+ pin_memory=self.uargs.pin_memory,
+ sampler=DistributedSampler(
+ self.train_dataset,
+ num_replicas=self.state.world_size,
+ rank=self.state.global_rank,
+ shuffle=True,
+ ),
)
else:
length_list = [self.sample_to_length(sample) for sample in self.train_dataset]
@@ -150,24 +166,31 @@ def prepare_dataset(self) -> None:
self.train_data_loader = torch.utils.data.DataLoader(
self.train_dataset,
collate_fn=self.collate_fn_packing,
- batch_size=self.args.batch_size,
- num_workers=self.args.num_workers,
- pin_memory=self.args.pin_memory,
- sampler=NaivePackingSampler(
+ batch_size=self.uargs.batch_size,
+ num_workers=self.uargs.num_workers,
+ pin_memory=self.uargs.pin_memory,
+ sampler=DistPackingSampler(
length_list,
self.state.training_seq_length,
shuffle=True,
+ world_size=self.state.world_size,
+ global_rank=self.state.global_rank,
),
)
- if self.args.do_validation:
+ if self.uargs.do_validation:
self.test_data_loader = torch.utils.data.DataLoader(
self.test_dataset,
collate_fn=self.collate_fn,
batch_size=1,
- num_workers=self.args.num_workers,
- pin_memory=self.args.pin_memory,
- shuffle=False,
+ num_workers=self.uargs.num_workers,
+ pin_memory=self.uargs.pin_memory,
+ sampler=DistributedSampler(
+ self.test_dataset,
+ num_replicas=self.state.world_size,
+ rank=self.state.global_rank,
+ shuffle=False,
+ ),
)
@override
@@ -179,7 +202,7 @@ def validate(self, step: int, ckpt_path: str | None = None) -> None:
self.logger.warning("No validation samples found. Skipping validation.")
return
- self.components.transformer.eval()
+ # self.components.transformer.eval()
torch.set_grad_enabled(False)
memory_statistics = get_memory_statistics(self.logger)
@@ -190,25 +213,16 @@ def validate(self, step: int, ckpt_path: str | None = None) -> None:
##### Initialize pipeline #####
pipe = self.initialize_pipeline(ckpt_path=ckpt_path)
- if self.state.using_deepspeed:
- # Can't using model_cpu_offload in deepspeed,
- # so we need to move all components in pipe to device
- self.move_components_to_device(
- dtype=self.state.weight_dtype,
- device=self.accelerator.device,
- ignore_list=["transformer"],
- )
+ # when low_vram is set, use sequential CPU offload for maximum memory savings;
+ # otherwise fall back to the cheaper model-level CPU offload
+ if self.uargs.low_vram:
+ pipe.enable_sequential_cpu_offload(device=self.state.device)
else:
- # if not using deepspeed, use model_cpu_offload to further reduce memory usage
- # Or use pipe.enable_sequential_cpu_offload() to further reduce memory usage
- if self.args.low_vram:
- pipe.enable_sequential_cpu_offload(device=self.accelerator.device)
- else:
- pipe.enable_model_cpu_offload(device=self.accelerator.device)
+ pipe.enable_model_cpu_offload(device=self.state.device)
- # Convert all model weights to training dtype
- # Note, this will change LoRA weights in self.components.transformer to training dtype, rather than keep them in fp32
- pipe = pipe.to(dtype=self.state.weight_dtype)
+ # Convert all model weights to training dtype
+ # Note, this will change LoRA weights in self.components.transformer to training dtype, rather than keep them in fp32
+ pipe = pipe.to(dtype=self.state.weight_dtype)
#################################
@@ -228,8 +242,7 @@ def validate(self, step: int, ckpt_path: str | None = None) -> None:
encoded_video = batch.get("encoded_video", None)
self.logger.debug(
- f"Validating sample {i + 1}/{num_validation_samples} on process {self.accelerator.process_index}. Prompt: {prompt}",
- main_process_only=False,
+ f"Validating sample {i + 1}/{num_validation_samples} on process {self.state.global_rank}. Prompt: {prompt}",
)
val_res = self.validation_step(
pipe=pipe,
@@ -244,9 +257,9 @@ def validate(self, step: int, ckpt_path: str | None = None) -> None:
)
artifacts = {}
- val_path = self.args.output_dir / "validation_res" / f"validation-{step}"
- val_path.mkdir(parents=True, exist_ok=True)
- filename = f"artifact-process{self.accelerator.process_index}-batch{i}"
+ val_path = self.uargs.output_dir / "validation_res" / f"validation-{step}"
+ mkdir(val_path)
+ filename = f"artifact-process{self.state.global_rank}-batch{i}"
image = val_res.get("image", None)
video = val_res.get("video", None)
@@ -258,48 +271,38 @@ def validate(self, step: int, ckpt_path: str | None = None) -> None:
artifacts["image"] = wandb.Image(fpath, caption=prompt)
if video:
fpath = str(val_path / f"{filename}.mp4")
- export_to_video(video, fpath, fps=self.args.gen_fps)
+ export_to_video(video, fpath, fps=self.uargs.gen_fps)
artifacts["video"] = wandb.Video(fpath, caption=prompt)
all_processes_artifacts.append(artifacts)
- all_artifacts = gather_object(all_processes_artifacts)
- all_artifacts = expand_list(all_artifacts)
-
- if self.accelerator.is_main_process:
- tracker_key = "validation"
- for tracker in self.accelerator.trackers:
- if tracker.name == "wandb":
- tracker.log({tracker_key: all_artifacts}, step=step)
-
- ########## Clean up ##########
- if self.state.using_deepspeed:
- del pipe
- # Unload models except those needed for training
- self.move_components_to_device(
- dtype=self.state.weight_dtype, device="cpu", ignore_list=["transformer"]
- )
- else:
- pipe.remove_all_hooks()
- del pipe
- # Load models except those not needed for training
- self.move_components_to_device(
- dtype=self.state.weight_dtype,
- device=self.accelerator.device,
- ignore_list=self.UNLOAD_LIST,
- )
- self.components.transformer.to(self.accelerator.device, dtype=self.state.weight_dtype)
+ if self.tracker is not None:
+ all_artifacts = gather_object(all_processes_artifacts)
+ all_artifacts = [item for sublist in all_artifacts for item in sublist]
+ all_artifacts = expand_list(all_artifacts)
+ self.tracker.log({"validation": all_artifacts}, step=step)
+
+ # ============= Clean up =============
+ pipe.remove_all_hooks()
+ del pipe
+ # Load models except those not needed for training
+ self.move_components_to_device(
+ dtype=self.state.weight_dtype,
+ device=self.state.device,
+ ignore_list=self.UNLOAD_LIST,
+ )
+ # self.components.transformer.to(self.state.device, dtype=self.state.weight_dtype)
- # Change trainable weights back to fp32 to keep with dtype after prepare the model
- # cast_training_params([self.components.transformer], dtype=torch.float32)
+ # Change trainable weights back to fp32 to keep with dtype after prepare the model
+ # cast_training_params([self.components.transformer], dtype=torch.float32)
free_memory()
- self.accelerator.wait_for_everyone()
- ################################
+ dist.barrier()
+ # =======================================
memory_statistics = get_memory_statistics(self.logger)
self.logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
- torch.cuda.reset_peak_memory_stats(self.accelerator.device)
+ torch.cuda.reset_peak_memory_stats(self.state.device)
torch.set_grad_enabled(True)
self.components.transformer.train()
@@ -309,7 +312,7 @@ def load_components(self) -> DiffusionComponents:
raise NotImplementedError
@override
- def compute_loss(self, batch) -> torch.Tensor:
+ def compute_loss(self, batch: dict[str, Any]) -> torch.Tensor:
raise NotImplementedError
def collate_fn(self, samples: list[dict[str, Any]]):
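The trainer above now builds its own `DistributedSampler` instances instead of calling `accelerator.prepare_data_loader`. Ignoring shuffling and seeding, the sharding a `DistributedSampler` performs is roughly the following (a simplified sketch, not the actual torch implementation):

```python
def shard_indices(num_samples: int, world_size: int, rank: int) -> list[int]:
    # Pad the index list (by wrapping around) so it divides evenly across
    # ranks, then give each rank a strided slice -- every rank ends up
    # with the same number of samples, as DistributedSampler guarantees.
    indices = list(range(num_samples))
    pad = (-num_samples) % world_size
    indices += indices[:pad]
    return indices[rank::world_size]

# 10 samples over 4 ranks: each rank gets 3; samples 0 and 1 repeat.
rank0 = shard_indices(10, 4, 0)
```

This even split is also why the precompute loop can simply iterate its shard and then hit `dist.barrier()`: every rank performs the same number of steps.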
diff --git a/src/cogkit/finetune/logger.py b/src/cogkit/finetune/logger.py
new file mode 100644
index 0000000..8de97b4
--- /dev/null
+++ b/src/cogkit/finetune/logger.py
@@ -0,0 +1,122 @@
+import logging
+import sys
+import os
+import tempfile
+import torch.distributed as dist
+import inspect
+from pathlib import Path
+from filelock import FileLock
+
+
+class ColoredFormatter(logging.Formatter):
+ COLORS = {
+ logging.DEBUG: "\033[36m",
+ logging.INFO: "\033[32m",
+ logging.WARNING: "\033[33m",
+ logging.ERROR: "\033[31m",
+ logging.CRITICAL: "\033[31;1m",
+ }
+ RESET = "\033[0m"
+ GRAY = "\033[97m"
+
+ def format(self, record):
+ level_color = self.COLORS.get(record.levelno, self.RESET)
+
+ original_levelname = record.levelname
+ timestamp_str = self.formatTime(record, self.datefmt) # Get the exact timestamp string
+
+ formatted_message = super().format(record)
+
+ colored_timestamp = f"{self.GRAY}{timestamp_str}{self.RESET}"
+ formatted_message = formatted_message.replace(timestamp_str, colored_timestamp, 1)
+
+ colored_levelname = f"{level_color}{original_levelname}{self.RESET}"
+ formatted_message = formatted_message.replace(original_levelname, colored_levelname, 1)
+
+ return formatted_message
+
+
+class DistributedLogger:
+ def __init__(
+ self, name: str | None = None, log_file: str | Path | None = None, level: int = logging.INFO
+ ):
+ if not dist.is_initialized():
+ raise RuntimeError("Distributed environment is not setup")
+
+ self.rank = dist.get_rank()
+ self.logger = logging.getLogger(name)
+ self.logger.setLevel(level)
+ self.logger.propagate = False
+
+ base_fmt = f"[rank{self.rank}]: %(asctime)s | %(name)s | %(levelname)s | %(message)s"
+ date_fmt = "%Y-%m-%d %H:%M:%S"
+
+ if self.is_main_process() and log_file is not None:
+ log_file = Path(log_file)
+ log_file.write_text("")  # create the file if needed and truncate old contents
+
+ fd, flpath = tempfile.mkstemp()
+ os.close(fd) # Close file descriptor as we don't need it
+ self.lock = FileLock(flpath)
+ self.flpath = flpath
+
+ if not self.logger.handlers:
+ console_handler = logging.StreamHandler(sys.stdout)
+ console_formatter = ColoredFormatter(base_fmt, date_fmt)
+ console_handler.setFormatter(console_formatter)
+ self.logger.addHandler(console_handler)
+
+ if log_file is not None:
+ file_handler = logging.FileHandler(log_file)
+ file_formatter = logging.Formatter(base_fmt, date_fmt)
+ file_handler.setFormatter(file_formatter)
+ self.logger.addHandler(file_handler)
+
+ dist.barrier()
+
+ def __del__(self):
+ # guard against partially-constructed instances (__init__ may raise early)
+ if hasattr(self, "flpath"):
+ Path(self.flpath).unlink(missing_ok=True)
+
+ def is_main_process(self):
+ return self.rank == 0
+
+ def log(self, level, msg, main_only=False, *args, **kwargs) -> None:
+ with self.lock:
+ if not main_only or self.is_main_process():
+ self.logger.log(level, msg, *args, **kwargs)
+
+ def debug(self, msg, main_only=False, *args, **kwargs) -> None:
+ self.log(logging.DEBUG, msg, main_only, *args, **kwargs)
+
+ def info(self, msg, main_only=False, *args, **kwargs) -> None:
+ self.log(logging.INFO, msg, main_only, *args, **kwargs)
+
+ def warning(self, msg, main_only=False, *args, **kwargs) -> None:
+ self.log(logging.WARNING, msg, main_only, *args, **kwargs)
+
+ def error(self, msg, main_only=False, *args, **kwargs) -> None:
+ self.log(logging.ERROR, msg, main_only, *args, **kwargs)
+
+ def critical(self, msg, main_only=False, *args, **kwargs) -> None:
+ self.log(logging.CRITICAL, msg, main_only, *args, **kwargs)
+
+
+def get_logger(
+ name: str | None = None, log_file: str | Path | None = None, level: int = logging.INFO
+) -> DistributedLogger:
+ if name is None:
+ frame = inspect.currentframe().f_back
+ module_name = frame.f_globals["__name__"]
+ name_parts = module_name.split(".")
+ if len(name_parts) > 2:
+ name = ".".join(name_parts[-2:])
+ else:
+ name = module_name
+ if log_file is not None:
+ log_file = Path(log_file).expanduser().resolve()
+ return DistributedLogger(name, log_file, level)
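Review note: the `main_only` flag in `DistributedLogger.log` gates a record on rank — a message is emitted either because it is not main-only, or because the caller is rank 0. A minimal stand-alone sketch of that gate (pure Python, no distributed setup required):

```python
def should_emit(main_only: bool, rank: int) -> bool:
    # Mirrors the DistributedLogger.log gate: non-main-only messages always
    # emit; main-only messages emit on rank 0 only.
    return not main_only or rank == 0

print([should_emit(False, 2), should_emit(True, 2), should_emit(True, 0)])  # → [True, False, True]
```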
diff --git a/src/cogkit/finetune/samplers/__init__.py b/src/cogkit/finetune/samplers/__init__.py
new file mode 100644
index 0000000..570b9e5
--- /dev/null
+++ b/src/cogkit/finetune/samplers/__init__.py
@@ -0,0 +1,3 @@
+from .packing_sampler import DistPackingSampler, NaivePackingSampler
+
+__all__ = ["NaivePackingSampler", "DistPackingSampler"]
diff --git a/src/cogkit/samplers/packing_sampler.py b/src/cogkit/finetune/samplers/packing_sampler.py
similarity index 68%
rename from src/cogkit/samplers/packing_sampler.py
rename to src/cogkit/finetune/samplers/packing_sampler.py
index 430a789..d302467 100644
--- a/src/cogkit/samplers/packing_sampler.py
+++ b/src/cogkit/finetune/samplers/packing_sampler.py
@@ -6,9 +6,13 @@
fixed-size batches while preserving sampling randomness.
"""
-from torch.utils.data import Sampler
-from typing import List, Iterator
import random
+from typing import Iterator, List
+from typing_extensions import override
+
+from torch.utils.data import Sampler
+from cogkit.finetune.utils import get_world_size, get_global_rank
+import torch.distributed as dist
class NaivePackingSampler(Sampler):
@@ -67,3 +71,34 @@ def __iter__(self) -> Iterator[List[int]]:
def __len__(self):
return len(self.idx_buckets)
+
+
+class DistPackingSampler(NaivePackingSampler):
+ @override
+ def __init__(
+ self,
+ length_list: list[int],
+ packed_length: int,
+ shuffle: bool = True,
+ world_size: int | None = None,
+ global_rank: int | None = None,
+ ):
+ super().__init__(length_list, packed_length, shuffle)
+ if not dist.is_initialized():
+ raise ValueError("DistPackingSampler requires distributed training")
+
+ # use `is not None` rather than `or`: an explicit global_rank of 0 is falsy
+ self.world_size = world_size if world_size is not None else get_world_size()
+ self.global_rank = global_rank if global_rank is not None else get_global_rank()
+
+ @override
+ def __iter__(self) -> Iterator[List[int]]:
+ size = len(self.idx_buckets) // self.world_size
+ offset = self.global_rank * size
+ yield from self.idx_buckets[offset : offset + size]
+
+ # NOTE: all ranks must share RNG state (e.g. via set_global_seed) so this
+ # post-epoch shuffle keeps bucket assignment consistent across ranks
+ if self.shuffle:
+ random.shuffle(self.idx_buckets)
+
+ @override
+ def __len__(self):
+ return len(self.idx_buckets) // self.world_size
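Review note: `DistPackingSampler.__iter__` gives each rank a contiguous, equally sized slice of the packed buckets, dropping the `len % world_size` remainder so every rank yields the same number of batches. A hypothetical standalone sketch of that sharding (the real class reads `world_size`/`global_rank` from `torch.distributed`):

```python
def shard_buckets(idx_buckets, world_size, global_rank):
    # Each rank takes a contiguous, equally sized slice of the packed buckets;
    # the len(idx_buckets) % world_size trailing buckets are dropped so that
    # every rank yields the same number of batches.
    size = len(idx_buckets) // world_size
    offset = global_rank * size
    return idx_buckets[offset : offset + size]

buckets = [[0, 1], [2], [3, 4, 5], [6], [7, 8]]  # 5 packed buckets
print(shard_buckets(buckets, 2, 0))  # → [[0, 1], [2]]
print(shard_buckets(buckets, 2, 1))  # → [[3, 4, 5], [6]]
# bucket [7, 8] is dropped so both ranks see the same number of batches
```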
diff --git a/src/cogkit/finetune/utils/__init__.py b/src/cogkit/finetune/utils/__init__.py
index 8eaeafc..eb210cf 100644
--- a/src/cogkit/finetune/utils/__init__.py
+++ b/src/cogkit/finetune/utils/__init__.py
@@ -1,7 +1,8 @@
-from .checkpointing import * # noqa
-from .file_utils import * # noqa
-from .memory_utils import * # noqa
-from .optimizer_utils import * # noqa
-from .torch_utils import * # noqa
+from .ckpt import * # noqa
+from .memory import * # noqa
from .filters import * # noqa
-from .attn_mask import * # noqa
+from .attention import * # noqa
+from .io import * # noqa
+from .dist import * # noqa
+from .misc import * # noqa
+from .tracker import * # noqa
diff --git a/src/cogkit/finetune/utils/attn_mask.py b/src/cogkit/finetune/utils/attention.py
similarity index 100%
rename from src/cogkit/finetune/utils/attn_mask.py
rename to src/cogkit/finetune/utils/attention.py
diff --git a/src/cogkit/finetune/utils/checkpointing.py b/src/cogkit/finetune/utils/checkpointing.py
deleted file mode 100644
index 5a28505..0000000
--- a/src/cogkit/finetune/utils/checkpointing.py
+++ /dev/null
@@ -1,54 +0,0 @@
-import os
-from pathlib import Path
-
-from ..utils.file_utils import delete_files, find_files
-
-
-def get_latest_ckpt_path_to_resume_from(
- resume_from_checkpoint: str | None, num_update_steps_per_epoch: int, logger
-) -> tuple[str | None, int, int, int]:
- if resume_from_checkpoint is None:
- initial_global_step = 0
- global_step = 0
- first_epoch = 0
- resume_from_checkpoint_path = None
- else:
- resume_from_checkpoint_path = Path(resume_from_checkpoint)
- if not resume_from_checkpoint_path.exists():
- logger.info(
- f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run."
- )
- initial_global_step = 0
- global_step = 0
- first_epoch = 0
- resume_from_checkpoint_path = None
- else:
- logger.info(f"Resuming from checkpoint {resume_from_checkpoint}")
- global_step = int(resume_from_checkpoint_path.name.split("-")[1])
-
- initial_global_step = global_step
- first_epoch = global_step // num_update_steps_per_epoch
-
- return (
- resume_from_checkpoint_path,
- initial_global_step,
- global_step,
- first_epoch,
- )
-
-
-def get_intermediate_ckpt_path(checkpointing_limit: int, step: int, output_dir: str, logger) -> str:
- # before saving state, check if this save would set us over the `checkpointing_limit`
- if checkpointing_limit is not None:
- checkpoints = find_files(output_dir, prefix="checkpoint")
-
- # before we save the new checkpoint, we need to have at_most `checkpoints_total_limit - 1` checkpoints
- if len(checkpoints) >= checkpointing_limit:
- num_to_remove = len(checkpoints) - checkpointing_limit + 1
- checkpoints_to_remove = checkpoints[0:num_to_remove]
- delete_files(checkpoints_to_remove, logger)
-
- logger.info(f"Checkpointing at step {step}")
- save_path = os.path.join(output_dir, f"checkpoint-{step}")
- logger.info(f"Saving state to {save_path}")
- return save_path
diff --git a/src/cogkit/finetune/utils/ckpt.py b/src/cogkit/finetune/utils/ckpt.py
new file mode 100644
index 0000000..41ac62c
--- /dev/null
+++ b/src/cogkit/finetune/utils/ckpt.py
@@ -0,0 +1,85 @@
+from pathlib import Path
+
+import torch.distributed as dist
+from safetensors.torch import save_file
+from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict, StateDictOptions
+from torch.distributed.checkpoint.stateful import Stateful
+
+from cogkit.utils.lora import save_lora
+
+from .dist import is_main_process
+from .io import check_path
+
+
+def save_state_dict(
+ state_dict: dict, save_dir: str | Path, fname: str, metadata: dict | None = None, lora: bool = False
+) -> None:
+ if is_main_process():
+ if lora:
+ save_lora(state_dict, save_dir)
+ else:
+ save_file(state_dict, Path(save_dir) / fname, metadata)
+
+ dist.barrier()
+
+
+def get_global_step(ckpt_path: str | Path) -> int:
+ ckpt_path = Path(ckpt_path)
+ check_path(ckpt_path, must_exists=True, must_dir=True)
+
+ try:
+ global_step = int(ckpt_path.name.split("-")[1])
+ except (IndexError, ValueError):
+ raise ValueError(f"Checkpoint path '{ckpt_path}' is not in the 'checkpoint-<step>' format.")
+
+ return global_step
+
+
+class AppState(Stateful):
+ """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
+ with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the
+ dcp.save/load APIs.
+
+ Note: We take advantage of this wrapper to hande calling distributed state dict methods on the model
+ and optimizer.
+
+ For more details, please refer to: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html
+ """
+
+ def __init__(self, model, optimizer=None, lora: bool = False):
+ self.model = model
+ self.optimizer = optimizer
+ self.lora = lora
+
+ def state_dict(self):
+ # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
+ model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
+ if self.lora:
+ from peft import get_peft_model_state_dict
+
+ model_state_dict = get_peft_model_state_dict(self.model)
+
+ return {"model": model_state_dict, "optim": optimizer_state_dict}
+
+ def load_state_dict(self, state_dict):
+ # sets our state dicts on the model and optimizer, now that we've loaded
+ if self.lora:
+ from peft.utils.save_and_load import _insert_adapter_name_into_state_dict
+ from cogkit.utils.lora import _ADAPTER_NAME
+ from peft.utils.constants import PEFT_TYPE_TO_PREFIX_MAPPING
+
+ state_dict["model"] = _insert_adapter_name_into_state_dict(
+ state_dict["model"],
+ adapter_name=_ADAPTER_NAME,
+ parameter_prefix=PEFT_TYPE_TO_PREFIX_MAPPING[
+ self.model.peft_config[_ADAPTER_NAME].peft_type
+ ],
+ )
+
+ set_state_dict(
+ self.model,
+ self.optimizer,
+ model_state_dict=state_dict["model"],
+ optim_state_dict=state_dict["optim"],
+ options=StateDictOptions(strict=False),
+ )
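Review note: `get_global_step` above relies on the `checkpoint-<step>` directory naming convention. A stand-alone sketch of that parsing; note that `int()` raises `ValueError` for non-numeric suffixes, which is worth catching alongside the `IndexError` from a missing `-` separator:

```python
from pathlib import Path

def parse_global_step(ckpt_path: str) -> int:
    # "checkpoint-1500" -> 1500; reject names without a numeric step suffix
    try:
        return int(Path(ckpt_path).name.split("-")[1])
    except (IndexError, ValueError):
        raise ValueError(f"{ckpt_path!r} is not in the 'checkpoint-<step>' format")

print(parse_global_step("/runs/exp1/checkpoint-1500"))  # → 1500
```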
diff --git a/src/cogkit/finetune/utils/dist.py b/src/cogkit/finetune/utils/dist.py
new file mode 100644
index 0000000..562bf84
--- /dev/null
+++ b/src/cogkit/finetune/utils/dist.py
@@ -0,0 +1,36 @@
+import os
+from typing import Any
+
+import torch
+import torch.distributed as dist
+
+
+def check_distributed() -> None:
+ if not dist.is_initialized():
+ raise RuntimeError("Distributed training is not initialized")
+
+
+def is_main_process() -> bool:
+ return dist.get_rank() == 0
+
+
+def get_world_size() -> int:
+ return dist.get_world_size()
+
+
+def get_global_rank() -> int:
+ return dist.get_rank()
+
+
+def get_local_rank() -> int:
+ return int(os.environ["LOCAL_RANK"])
+
+
+def get_device() -> torch.device:
+ return torch.device(f"cuda:{get_local_rank()}")
+
+
+def gather_object(obj: Any) -> list[Any]:
+ # avoid shadowing the builtin `object`
+ output_objects = [None for _ in range(get_world_size())]
+ dist.all_gather_object(output_objects, obj)
+ return output_objects
diff --git a/src/cogkit/finetune/utils/file_utils.py b/src/cogkit/finetune/utils/file_utils.py
deleted file mode 100644
index 93207b6..0000000
--- a/src/cogkit/finetune/utils/file_utils.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import os
-import shutil
-from pathlib import Path
-
-
-def find_files(dir: str | Path, prefix: str = "checkpoint") -> list[str]:
- if not isinstance(dir, Path):
- dir = Path(dir)
- if not dir.exists():
- return []
- checkpoints = os.listdir(dir.as_posix())
- checkpoints = [c for c in checkpoints if c.startswith(prefix)]
- checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
- checkpoints = [dir / c for c in checkpoints]
- return checkpoints
-
-
-def delete_files(dirs: str | list[str] | Path | list[Path], logger) -> None:
- if not isinstance(dirs, list):
- dirs = [dirs]
- dirs = [Path(d) if isinstance(d, str) else d for d in dirs]
- logger.info(f"Deleting files: {dirs}")
- for dir in dirs:
- if not dir.exists():
- continue
- shutil.rmtree(dir, ignore_errors=True)
-
-
-def string_to_filename(s: str) -> str:
- return (
- s.replace(" ", "-")
- .replace("/", "-")
- .replace(":", "-")
- .replace(".", "-")
- .replace(",", "-")
- .replace(";", "-")
- .replace("!", "-")
- .replace("?", "-")
- )
diff --git a/src/cogkit/finetune/utils/io.py b/src/cogkit/finetune/utils/io.py
new file mode 100644
index 0000000..84934f5
--- /dev/null
+++ b/src/cogkit/finetune/utils/io.py
@@ -0,0 +1,107 @@
+from pathlib import Path
+import shutil
+import torch.distributed as dist
+
+from cogkit.finetune.logger import get_logger
+
+from .dist import is_main_process
+
+
+def check_path(
+ path: str | Path | None,
+ must_exists: bool = False,
+ must_dir: bool = False,
+ must_file: bool = False,
+) -> None:
+ if path is None:
+ raise ValueError("Path is None")
+ if isinstance(path, str):
+ path = Path(path)
+ if must_exists and not path.exists():
+ raise FileNotFoundError(f"Path '{path}' does not exist.")
+ if must_dir and not path.is_dir():
+ raise FileNotFoundError(f"Path '{path}' is not a directory.")
+ if must_file and not path.is_file():
+ raise FileNotFoundError(f"Path '{path}' is not a file.")
+
+
+def resolve_path(path: str | Path) -> str:
+ if isinstance(path, str):
+ path = Path(path)
+ check_path(path)
+ return str(path.expanduser().resolve())
+
+
+def mkdir(path: str | Path) -> None:
+ _logger = get_logger()
+ if is_main_process():
+ check_path(path)
+ Path(resolve_path(path)).mkdir(parents=True, exist_ok=True)
+ _logger.debug(f"Creating directory: {resolve_path(path)}")
+
+ dist.barrier()
+
+
+def touch(path: str | Path) -> None:
+ _logger = get_logger()
+ if is_main_process():
+ check_path(path)
+ Path(resolve_path(path)).touch()
+ _logger.debug(f"Touching file: {resolve_path(path)}")
+
+ dist.barrier()
+
+
+def list_files(dir: str | Path | None, prefix: str = "checkpoint") -> list[str]:
+ _logger = get_logger()
+ if dir is None:
+ _logger.warning("Directory is None, returning empty list")
+ return []
+ return [str(p) for p in Path(resolve_path(dir)).glob(f"{prefix}*")]
+
+
+def rmdir(path: str | Path) -> None:
+ _logger = get_logger()
+ if is_main_process():
+ check_path(path, must_exists=True, must_dir=True)
+ Path(resolve_path(path)).rmdir()
+ _logger.debug(f"Deleted empty directory: {resolve_path(path)}")
+
+ dist.barrier()
+
+
+def rmfile(path: str | Path, must_exists: bool = True) -> None:
+ _logger = get_logger()
+ if is_main_process():
+ check_path(path, must_exists=must_exists, must_file=True)
+ Path(resolve_path(path)).unlink()
+ _logger.debug(f"Deleted file: {resolve_path(path)}")
+
+ dist.barrier()
+
+
+def rmtree(path: str | Path) -> None:
+ """Recursively delete a directory tree."""
+ _logger = get_logger()
+ if is_main_process():
+ path = Path(resolve_path(path))
+ check_path(path, must_exists=True, must_dir=True)
+ shutil.rmtree(path)
+ _logger.debug(f"Recursively deleted directory: {path}")
+
+ dist.barrier()
+
+
+def delete_files(files: list[str], recursive: bool = True) -> None:
+ for file in files:
+ check_path(file, must_exists=True)
+ path = Path(file)
+ if path.is_dir():
+ if recursive:
+ rmtree(path)
+ else:
+ rmdir(path)
+ else:
+ rmfile(path)
+
+ dist.barrier()
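Review note: when `list_files` feeds checkpoint pruning, ordering matters. `Path.glob` returns paths in arbitrary order, and plain lexicographic sorting mis-orders numbered checkpoints, so the oldest checkpoint is not necessarily first. A small sketch:

```python
names = ["checkpoint-1000", "checkpoint-200", "checkpoint-30"]
# Lexicographic order is wrong for pruning: "checkpoint-1000" sorts first.
print(sorted(names))  # → ['checkpoint-1000', 'checkpoint-200', 'checkpoint-30']
# Sorting by the numeric step keeps the true oldest checkpoint first.
by_step = sorted(names, key=lambda n: int(n.split("-")[1]))
print(by_step)  # → ['checkpoint-30', 'checkpoint-200', 'checkpoint-1000']
```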
diff --git a/src/cogkit/finetune/utils/memory.py b/src/cogkit/finetune/utils/memory.py
new file mode 100644
index 0000000..9f063db
--- /dev/null
+++ b/src/cogkit/finetune/utils/memory.py
@@ -0,0 +1,35 @@
+import gc
+from typing import Any
+
+import torch
+
+
+def get_memory_statistics(device: torch.device, precision: int = 3) -> dict[str, Any]:
+ # NOTE: the passed device is currently ignored; statistics are always taken
+ # from the current CUDA device
+ device = torch.cuda.current_device()
+ memory_allocated = torch.cuda.memory_allocated(device)
+ memory_reserved = torch.cuda.memory_reserved(device)
+ max_memory_allocated = torch.cuda.max_memory_allocated(device)
+ max_memory_reserved = torch.cuda.max_memory_reserved(device)
+
+ return {
+ "memory_allocated": round(bytes_to_gigabytes(memory_allocated), ndigits=precision),
+ "memory_reserved": round(bytes_to_gigabytes(memory_reserved), ndigits=precision),
+ "max_memory_allocated": round(bytes_to_gigabytes(max_memory_allocated), ndigits=precision),
+ "max_memory_reserved": round(bytes_to_gigabytes(max_memory_reserved), ndigits=precision),
+ }
+
+
+def bytes_to_gigabytes(x: int) -> float:
+ return x / 1024**3
+
+
+def free_memory() -> None:
+ gc.collect()
+ torch.cuda.empty_cache()
+ torch.cuda.ipc_collect()
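Review note: the statistics above are converted to GiB and rounded with `round(..., ndigits=precision)`. A quick sanity check of the conversion (a stand-alone copy of `bytes_to_gigabytes`):

```python
def bytes_to_gigabytes(x: int) -> float:
    # 1 GiB = 1024**3 bytes, matching the helper in memory.py
    return x / 1024**3

# 3 GiB + 512 MiB is exactly 3.5 GiB
print(round(bytes_to_gigabytes(3 * 1024**3 + 512 * 1024**2), ndigits=3))  # → 3.5
```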
diff --git a/src/cogkit/finetune/utils/memory_utils.py b/src/cogkit/finetune/utils/memory_utils.py
deleted file mode 100644
index 3427410..0000000
--- a/src/cogkit/finetune/utils/memory_utils.py
+++ /dev/null
@@ -1,60 +0,0 @@
-import gc
-from typing import Any
-
-import torch
-
-
-def get_memory_statistics(logger, precision: int = 3) -> dict[str, Any]:
- memory_allocated = None
- memory_reserved = None
- max_memory_allocated = None
- max_memory_reserved = None
-
- if torch.cuda.is_available():
- device = torch.cuda.current_device()
- memory_allocated = torch.cuda.memory_allocated(device)
- memory_reserved = torch.cuda.memory_reserved(device)
- max_memory_allocated = torch.cuda.max_memory_allocated(device)
- max_memory_reserved = torch.cuda.max_memory_reserved(device)
-
- elif torch.mps.is_available():
- memory_allocated = torch.mps.current_allocated_memory()
-
- else:
- logger.warning("No CUDA, MPS, or ROCm device found. Memory statistics are not available.")
-
- return {
- "memory_allocated": round(bytes_to_gigabytes(memory_allocated), ndigits=precision),
- "memory_reserved": round(bytes_to_gigabytes(memory_reserved), ndigits=precision),
- "max_memory_allocated": round(bytes_to_gigabytes(max_memory_allocated), ndigits=precision),
- "max_memory_reserved": round(bytes_to_gigabytes(max_memory_reserved), ndigits=precision),
- }
-
-
-def bytes_to_gigabytes(x: int) -> float:
- if x is not None:
- return x / 1024**3
-
-
-def free_memory() -> None:
- if torch.cuda.is_available():
- gc.collect()
- torch.cuda.empty_cache()
- torch.cuda.ipc_collect()
-
- # TODO(aryan): handle non-cuda devices
-
-
-def unload_model(model):
- model.to("cpu")
-
-
-def make_contiguous(
- x: torch.Tensor | dict[str, torch.Tensor],
-) -> torch.Tensor | dict[str, torch.Tensor]:
- if isinstance(x, torch.Tensor):
- return x.contiguous()
- elif isinstance(x, dict):
- return {k: make_contiguous(v) for k, v in x.items()}
- else:
- return x
diff --git a/src/cogkit/finetune/utils/misc.py b/src/cogkit/finetune/utils/misc.py
new file mode 100644
index 0000000..6efeb13
--- /dev/null
+++ b/src/cogkit/finetune/utils/misc.py
@@ -0,0 +1,18 @@
+import torch
+
+
+def cast_training_params(model: torch.nn.Module | list[torch.nn.Module], dtype=torch.float32):
+ """
+ Casts the training parameters of the model to the specified data type.
+
+ Args:
+ model: The PyTorch model whose parameters will be cast.
+ dtype: The data type to which the model parameters will be cast.
+ """
+ if not isinstance(model, list):
+ model = [model]
+ for m in model:
+ for param in m.parameters():
+ # only upcast trainable parameters into fp32
+ if param.requires_grad:
+ param.data = param.to(dtype)
diff --git a/src/cogkit/finetune/utils/optimizer_utils.py b/src/cogkit/finetune/utils/optimizer_utils.py
deleted file mode 100644
index 5b38fe6..0000000
--- a/src/cogkit/finetune/utils/optimizer_utils.py
+++ /dev/null
@@ -1,186 +0,0 @@
-import inspect
-
-import torch
-
-
-def get_optimizer(
- params_to_optimize,
- logger,
- optimizer_name: str = "adam",
- learning_rate: float = 1e-3,
- beta1: float = 0.9,
- beta2: float = 0.95,
- beta3: float = 0.98,
- epsilon: float = 1e-8,
- weight_decay: float = 1e-4,
- prodigy_decouple: bool = False,
- prodigy_use_bias_correction: bool = False,
- prodigy_safeguard_warmup: bool = False,
- use_8bit: bool = False,
- use_4bit: bool = False,
- use_torchao: bool = False,
- use_deepspeed: bool = False,
- use_cpu_offload_optimizer: bool = False,
- offload_gradients: bool = False,
-) -> torch.optim.Optimizer:
- optimizer_name = optimizer_name.lower()
-
- # Use DeepSpeed optimzer
- if use_deepspeed:
- from accelerate.utils import DummyOptim
-
- return DummyOptim(
- params_to_optimize,
- lr=learning_rate,
- betas=(beta1, beta2),
- eps=epsilon,
- weight_decay=weight_decay,
- )
-
- if use_8bit and use_4bit:
- raise ValueError("Cannot set both `use_8bit` and `use_4bit` to True.")
-
- if (use_torchao and (use_8bit or use_4bit)) or use_cpu_offload_optimizer:
- try:
- import torchao
-
- torchao.__version__
- except ImportError:
- raise ImportError(
- "To use optimizers from torchao, please install the torchao library: `USE_CPP=0 pip install torchao`."
- )
-
- if not use_torchao and use_4bit:
- raise ValueError("4-bit Optimizers are only supported with torchao.")
-
- # Optimizer creation
- supported_optimizers = ["adam", "adamw", "prodigy", "came"]
- if optimizer_name not in supported_optimizers:
- logger.warning(
- f"Unsupported choice of optimizer: {optimizer_name}. Supported optimizers include {supported_optimizers}. Defaulting to `AdamW`."
- )
- optimizer_name = "adamw"
-
- if (use_8bit or use_4bit) and optimizer_name not in ["adam", "adamw"]:
- raise ValueError(
- "`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers."
- )
-
- if use_8bit:
- try:
- import bitsandbytes as bnb
- except ImportError:
- raise ImportError(
- "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
- )
-
- if optimizer_name == "adamw":
- if use_torchao:
- from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit
-
- optimizer_class = (
- AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW
- )
- else:
- optimizer_class = bnb.optim.AdamW8bit if use_8bit else torch.optim.AdamW
-
- init_kwargs = {
- "betas": (beta1, beta2),
- "eps": epsilon,
- "weight_decay": weight_decay,
- }
-
- elif optimizer_name == "adam":
- if use_torchao:
- from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit
-
- optimizer_class = Adam8bit if use_8bit else Adam4bit if use_4bit else torch.optim.Adam
- else:
- optimizer_class = bnb.optim.Adam8bit if use_8bit else torch.optim.Adam
-
- init_kwargs = {
- "betas": (beta1, beta2),
- "eps": epsilon,
- "weight_decay": weight_decay,
- }
-
- elif optimizer_name == "prodigy":
- try:
- import prodigyopt
- except ImportError:
- raise ImportError(
- "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`"
- )
-
- optimizer_class = prodigyopt.Prodigy
-
- if learning_rate <= 0.1:
- logger.warning(
- "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
- )
-
- init_kwargs = {
- "lr": learning_rate,
- "betas": (beta1, beta2),
- "beta3": beta3,
- "eps": epsilon,
- "weight_decay": weight_decay,
- "decouple": prodigy_decouple,
- "use_bias_correction": prodigy_use_bias_correction,
- "safeguard_warmup": prodigy_safeguard_warmup,
- }
-
- elif optimizer_name == "came":
- try:
- import came_pytorch
- except ImportError:
- raise ImportError(
- "To use CAME, please install the came-pytorch library: `pip install came-pytorch`"
- )
-
- optimizer_class = came_pytorch.CAME
-
- init_kwargs = {
- "lr": learning_rate,
- "eps": (1e-30, 1e-16),
- "betas": (beta1, beta2, beta3),
- "weight_decay": weight_decay,
- }
-
- if use_cpu_offload_optimizer:
- from torchao.prototype.low_bit_optim import CPUOffloadOptimizer
-
- if "fused" in inspect.signature(optimizer_class.__init__).parameters:
- init_kwargs.update({"fused": True})
-
- optimizer = CPUOffloadOptimizer(
- params_to_optimize,
- optimizer_class=optimizer_class,
- offload_gradients=offload_gradients,
- **init_kwargs,
- )
- else:
- optimizer = optimizer_class(params_to_optimize, **init_kwargs)
-
- return optimizer
-
-
-def gradient_norm(parameters):
- norm = 0
- for param in parameters:
- if param.grad is None:
- continue
- local_norm = param.grad.detach().data.norm(2)
- norm += local_norm.item() ** 2
- norm = norm**0.5
- return norm
-
-
-def max_gradient(parameters):
- max_grad_value = float("-inf")
- for param in parameters:
- if param.grad is None:
- continue
- local_max_grad = param.grad.detach().data.abs().max()
- max_grad_value = max(max_grad_value, local_max_grad.item())
- return max_grad_value
diff --git a/src/cogkit/finetune/utils/torch_utils.py b/src/cogkit/finetune/utils/torch_utils.py
deleted file mode 100644
index 9db6800..0000000
--- a/src/cogkit/finetune/utils/torch_utils.py
+++ /dev/null
@@ -1,50 +0,0 @@
-import torch
-from accelerate import Accelerator
-from diffusers.utils.torch_utils import is_compiled_module
-
-
-def unwrap_model(accelerator: Accelerator, model):
- model = accelerator.unwrap_model(model)
- model = model._orig_mod if is_compiled_module(model) else model
- return model
-
-
-def align_device_and_dtype(
- x: torch.Tensor | dict[str, torch.Tensor],
- device: torch.device | None = None,
- dtype: torch.dtype | None = None,
-):
- if isinstance(x, torch.Tensor):
- if device is not None:
- x = x.to(device)
- if dtype is not None:
- x = x.to(dtype)
- elif isinstance(x, dict):
- if device is not None:
- x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()}
- if dtype is not None:
- x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()}
- return x
-
-
-def expand_tensor_to_dims(tensor, ndim):
- while len(tensor.shape) < ndim:
- tensor = tensor.unsqueeze(-1)
- return tensor
-
-
-def cast_training_params(model: torch.nn.Module | list[torch.nn.Module], dtype=torch.float32):
- """
- Casts the training parameters of the model to the specified data type.
-
- Args:
- model: The PyTorch model whose parameters will be cast.
- dtype: The data type to which the model parameters will be cast.
- """
- if not isinstance(model, list):
- model = [model]
- for m in model:
- for param in m.parameters():
- # only upcast trainable parameters into fp32
- if param.requires_grad:
- param.data = param.to(dtype)
diff --git a/src/cogkit/finetune/utils/tracker.py b/src/cogkit/finetune/utils/tracker.py
new file mode 100644
index 0000000..9bb1dec
--- /dev/null
+++ b/src/cogkit/finetune/utils/tracker.py
@@ -0,0 +1,27 @@
+from typing import Any
+
+import torch.distributed as dist
+import wandb
+
+from .dist import is_main_process
+
+
+class WandbTracker:
+ def __init__(self, name: str, config: dict[str, Any], **kwargs: Any) -> None:
+ if is_main_process():
+ self.tracker = wandb.init(
+ name=name,
+ config=config,
+ **kwargs,
+ )
+ dist.barrier()
+
+ def log(self, *args: Any, **kwargs: Any) -> None:
+ if is_main_process():
+ self.tracker.log(*args, **kwargs)
+ dist.barrier()
+
+ def finish(self) -> None:
+ if is_main_process():
+ self.tracker.finish()
+ dist.barrier()
diff --git a/src/cogkit/samplers/__init__.py b/src/cogkit/samplers/__init__.py
deleted file mode 100644
index 55c3011..0000000
--- a/src/cogkit/samplers/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from cogkit.samplers.packing_sampler import NaivePackingSampler
-
-__all__ = ["NaivePackingSampler"]
diff --git a/src/cogkit/utils/__init__.py b/src/cogkit/utils/__init__.py
index b2ec328..f4c9383 100644
--- a/src/cogkit/utils/__init__.py
+++ b/src/cogkit/utils/__init__.py
@@ -1,20 +1,21 @@
# -*- coding: utf-8 -*-
-from cogkit.utils.diffusion_pipeline import get_pipeline_meta
-from cogkit.utils.dtype import cast_to_torch_dtype
-from cogkit.utils.lora import (
+from .diffusion_pipeline import get_pipeline_meta
+from .dtype import cast_to_torch_dtype
+from .lora import (
load_lora_checkpoint,
unload_lora_checkpoint,
inject_lora,
save_lora,
unload_lora,
)
-from cogkit.utils.misc import guess_generation_mode, flatten_dict, expand_list
-from cogkit.utils.path import mkdir, resolve_path
-from cogkit.utils.prompt import convert_prompt
-from cogkit.utils.random import rand_generator
-from cogkit.utils.load import load_pipeline
+from .misc import guess_generation_mode, flatten_dict, expand_list
+from .path import mkdir, resolve_path
+from .prompt import convert_prompt
+from .random import rand_generator
+from .load import load_pipeline
+from .seed import set_global_seed
__all__ = [
"get_pipeline_meta",
@@ -32,4 +33,5 @@
"convert_prompt",
"flatten_dict",
"expand_list",
+ "set_global_seed",
]
diff --git a/src/cogkit/utils/lora.py b/src/cogkit/utils/lora.py
index db9e142..4eecd71 100644
--- a/src/cogkit/utils/lora.py
+++ b/src/cogkit/utils/lora.py
@@ -26,6 +26,7 @@
# Standard filename for LoRA adapter weights
_LORA_WEIGHT_NAME = "adapter_model.safetensors"
+_ADAPTER_NAME = "default"
def _get_lora_config() -> LoraConfig:
@@ -37,7 +38,9 @@ def _get_lora_config() -> LoraConfig:
)
-def inject_lora(model, lora_dir_or_state_dict: str | Path | None = None) -> None:
+def inject_lora(
+ model, lora_dir_or_state_dict: str | Path | None = None, adapter_name: str = _ADAPTER_NAME
+) -> None:
"""
Inject LoRA adapters into the model.
@@ -49,9 +52,10 @@ def inject_lora(model, lora_dir_or_state_dict: str | Path | None = None) -> None
model: The model to inject LoRA adapters into
lora_dir_or_state_dict: Path to a LoRA checkpoint directory, a state dict,
or None for random initialization
+ adapter_name: The name of the adapter to inject
"""
transformer_lora_config = _get_lora_config()
- inject_adapter_in_model(transformer_lora_config, model)
+ inject_adapter_in_model(transformer_lora_config, model, adapter_name=adapter_name)
if lora_dir_or_state_dict is None:
return
@@ -65,7 +69,7 @@ def inject_lora(model, lora_dir_or_state_dict: str | Path | None = None) -> None
else:
peft_state_dict = lora_dir_or_state_dict
- set_peft_model_state_dict(model, peft_state_dict)
+ set_peft_model_state_dict(model, peft_state_dict, adapter_name=adapter_name)
def save_lora(model, lora_dir: str | Path) -> None:
diff --git a/src/cogkit/utils/misc.py b/src/cogkit/utils/misc.py
index 0af5c98..b7ff2d4 100644
--- a/src/cogkit/utils/misc.py
+++ b/src/cogkit/utils/misc.py
@@ -83,7 +83,10 @@ def guess_generation_mode(
if isinstance(pipeline_or_path, str):
pl_cls_name = get_pipeline_meta(pipeline_or_path)["cls_name"]
else:
- pl_cls_name = pipeline_or_path.__class__.__name__
+ if isinstance(pipeline_or_path, type):
+ pl_cls_name = pipeline_or_path.__name__
+ else:
+ pl_cls_name = pipeline_or_path.__class__.__name__
if pl_cls_name not in _SUPPORTED_PIPELINE:
err_msg = f"The pipeline '{pl_cls_name}' is not supported."
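The new guard in `guess_generation_mode` accepts either a pipeline class or a pipeline instance. A minimal standalone sketch of that dispatch (the `CogView4Pipeline` class below is an illustrative stand-in, not the real diffusers pipeline):

```python
def resolve_cls_name(pipeline_or_cls) -> str:
    # A class object carries its name on __name__;
    # an instance reaches the same name via __class__.
    if isinstance(pipeline_or_cls, type):
        return pipeline_or_cls.__name__
    return pipeline_or_cls.__class__.__name__


class CogView4Pipeline:  # illustrative stand-in
    pass


assert resolve_cls_name(CogView4Pipeline) == "CogView4Pipeline"
assert resolve_cls_name(CogView4Pipeline()) == "CogView4Pipeline"
```

Without the `isinstance(pipeline_or_cls, type)` branch, passing a class would yield `"type"` (the metaclass name) instead of the pipeline name.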
diff --git a/src/cogkit/utils/seed.py b/src/cogkit/utils/seed.py
new file mode 100644
index 0000000..e10cd4f
--- /dev/null
+++ b/src/cogkit/utils/seed.py
@@ -0,0 +1,18 @@
+import random
+
+import numpy as np
+import torch
+
+
+def set_global_seed(seed: int) -> None:
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(seed)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+    if torch.backends.mps.is_available():
+        torch.mps.manual_seed(seed)
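`set_global_seed` can be sanity-checked by re-seeding and comparing draws. A minimal sketch of the CPU-side portion (the CUDA/MPS branches are device-specific and omitted here so it runs anywhere):

```python
import random

import numpy as np
import torch


def seed_cpu(seed: int) -> None:
    # Mirrors the CPU portion of set_global_seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


seed_cpu(42)
first = (random.random(), float(np.random.rand()), torch.rand(1).item())
seed_cpu(42)
second = (random.random(), float(np.random.rand()), torch.rand(1).item())
assert first == second  # re-seeding reproduces the same draws
```

Note that `cudnn.deterministic = True` / `cudnn.benchmark = False`, as set in the helper, trade some GPU throughput for run-to-run reproducibility.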
diff --git a/tests/test_sampler.py b/tests/test_sampler.py
index 2ea5370..81fd4c6 100644
--- a/tests/test_sampler.py
+++ b/tests/test_sampler.py
@@ -4,7 +4,7 @@
import torch
from torch.utils.data import DataLoader, Dataset
-from cogkit.samplers import NaivePackingSampler
+from cogkit.finetune.samplers import NaivePackingSampler
# ==============================================================================
diff --git a/tools/converters/merge.py b/tools/converters/merge.py
new file mode 100755
index 0000000..e6999f0
--- /dev/null
+++ b/tools/converters/merge.py
@@ -0,0 +1,50 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+
+import argparse
+from pathlib import Path
+
+import torch
+from safetensors.torch import save_file
+from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
+
+from cogkit.utils.lora import _LORA_WEIGHT_NAME
+
+TORCH_SAVE_CHECKPOINT_NAME = "diffusion_pytorch_model.bin"
+
+
+def main(checkpoint_dir: str, output_dir: str, is_lora: bool = False):
+    # Convert the distributed (DCP) checkpoint into a single torch.save file
+    checkpoint_dir = Path(checkpoint_dir)
+    output_dir = Path(output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+    ckpt_file = output_dir / TORCH_SAVE_CHECKPOINT_NAME
+
+    print("Converting FSDP checkpoint to torch.save format...")
+    dcp_to_torch_save(checkpoint_dir, ckpt_file)
+    state = torch.load(ckpt_file, map_location="cpu")
+    print("Deleting intermediate torch checkpoint...")
+    ckpt_file.unlink()
+    model_weights = state["app"]["model"]
+
+    print("Saving transformer weights...")
+    if is_lora:
+        ckpt_file = ckpt_file.with_name(_LORA_WEIGHT_NAME)
+        save_file(model_weights, ckpt_file)
+
+    else:
+        ckpt_file = ckpt_file.with_name(TORCH_SAVE_CHECKPOINT_NAME)
+        torch.save(model_weights, ckpt_file)
+
+ print("Done.")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--checkpoint_dir", type=str, required=True)
+ parser.add_argument("--output_dir", type=str, required=True)
+ parser.add_argument("--lora", action="store_true", default=False)
+ args = parser.parse_args()
+
+ main(args.checkpoint_dir, args.output_dir, args.lora)
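The converter is meant to run after training, against a DCP checkpoint directory. A hypothetical invocation (the paths below are placeholders, not paths from this repo):

```shell
# Consolidate a DCP/FSDP checkpoint into a single weights file.
# Pass --lora only when the checkpoint holds LoRA adapter weights;
# omit it to write the full transformer as diffusion_pytorch_model.bin.
python tools/converters/merge.py \
    --checkpoint_dir outputs/checkpoint-1000 \
    --output_dir outputs/merged \
    --lora
```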