# config.yaml
# Configuration file for training the Align Your Flow model.
# --------------------------------------------------------------------------
# Model Configuration
# --------------------------------------------------------------------------
model:
  # Hugging Face model IDs for the teacher and autoguidance models.
  # FLUX.1-schnell is used here as a relatively small model that is good for demonstration.
  teacher_model_id: "black-forest-labs/FLUX.1-schnell"
  autoguide_model_id: "black-forest-labs/FLUX.1-schnell"  # A different/weaker model is ideal for autoguidance.
# --------------------------------------------------------------------------
# Data Configuration
# --------------------------------------------------------------------------
data:
# Name of the dataset, used for logic switching.
name: "text-to-image-2M"
# List of URLs for the webdataset .tar shards.
# For a full run, this list would contain all shards.
# For a quick test, a single shard URL is sufficient.
urls: "pipe:curl -L -s https://huggingface.co/datasets/jackyhate/text-to-image-2M/resolve/main/data_512_2M/data_{000000..000000}.tar"
# Number of samples in the specified URL list (for one shard of this dataset).
# This is needed to calculate the number of steps per epoch.
num_samples: 42000
# Target resolution for the images.
resolution: 512
# Number of worker processes for the DataLoader.
num_workers: 4
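  # Worked example (illustrative only; num_gpus depends on your launch command):
  # with num_samples = 42000 and the train settings below
  # (batch_size_per_gpu = 8, gradient_accumulation_steps = 4), a single-GPU
  # run makes 42000 / (8 * 1 * 4) ~ 1312 optimizer steps per epoch.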
# --------------------------------------------------------------------------
# Training Configuration
# --------------------------------------------------------------------------
train:
  # Directory to save checkpoints and logs.
  output_dir: "./ayf-training-output"
  # Total number of training epochs.
  num_epochs: 50
  # Batch size per GPU. The global batch size is
  # batch_size_per_gpu * num_gpus * gradient_accumulation_steps.
  batch_size_per_gpu: 8
  # Number of steps to accumulate gradients before an optimizer step.
  gradient_accumulation_steps: 4
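  # Worked example: with batch_size_per_gpu = 8 and
  # gradient_accumulation_steps = 4, the effective batch per GPU is 32;
  # on, say, 4 GPUs (illustrative) the global batch size would be 128.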
  # Learning rates for the student (generator) and the discriminator.
  # Appendix F.2 suggests 2e-5 for finetuning; 1e-4 is a common starting
  # point for distillation.
  lr_student: 1.0e-4
  lr_discriminator: 2.0e-5
  # Mixed-precision setting for training ('no', 'fp16', 'bf16').
  mixed_precision: "bf16"
  # How often to save a model checkpoint (in epochs).
  save_epoch_freq: 1
  # Epoch at which to start the adversarial finetuning phase.
  adversarial_start_epoch: 40  # Start finetuning after the initial distillation phase.
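  # Given num_epochs: 50 above, this setting implies the following schedule
  # (assuming 0-indexed epochs):
  #   epochs  0-39: AYF-EMD distillation only
  #   epochs 40-49: distillation plus adversarial finetuning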
  # Weight for the adversarial loss term for the generator (alpha in Algorithm 2).
  adv_loss_weight: 0.1
  # Regularization weight for the R1/R2 gradient penalty (beta in Algorithm 2).
  r1_reg_weight: 0.1
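  # Sketch of how these two weights are assumed to enter the losses; the
  # authoritative form is Algorithm 2 / the training script, not this file:
  #   L_generator     = L_AYF-EMD + adv_loss_weight * L_adv
  #   L_discriminator = L_adv_D + r1_reg_weight * (gradient penalty)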
# --------------------------------------------------------------------------
# AYF-EMD Loss Parameters (from Appendix F.1)
# --------------------------------------------------------------------------
ayf_loss:
  # Mean and standard deviation of the timestep sampling distribution.
  p_mean: -0.8
  p_std: 1.0
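  # Assumption: following EDM-style practice, timesteps are presumably drawn
  # from a log-normal proposal, e.g. exp(p_mean + p_std * n) with n ~ N(0, 1),
  # then mapped to flow time; the exact mapping is defined in the training script.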
  # Number of iterations for the tangent warmup.
  warmup_iters: 10000
  # Constant for tangent normalization.
  tangent_norm_c: 0.1
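  # A common form for these two stabilizers (an assumption; see the training
  # script for the authoritative version), with g the tangent vector:
  #   warmup scale:  r = min(step / warmup_iters, 1.0)
  #   normalization: g_hat = g / (||g|| + tangent_norm_c)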
  # Autoguidance weight (lambda in the paper). During training, lambda is
  # sampled from the range [1, 3]; the value below is only a placeholder,
  # since the sampling is hardcoded in the script per the paper's description.
  autoguide_weight: 2.0
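  # Assumption about how lambda is applied, following the standard
  # autoguidance formulation (extrapolating the teacher's prediction away
  # from the weaker model's prediction):
  #   v_guided = v_autoguide + lambda * (v_teacher - v_autoguide)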