train.py
import argparse
import json
import os

import torch
import wandb
import omegaconf
from torch.utils.data import DataLoader
from safetensors import safe_open
from tqdm import tqdm

from wan.configs.wan_t2v_1_3B import t2v_1_3B
from pipeline import WanMeanFlowXPipeline
from dataset import OpenVid1MDataset


def load_wan_state_dict(base_path, do_print):
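    """Load the sharded Wan safetensors checkpoint into one state dict.

    The index JSON maps each weight name to its shard file; every shard is
    opened once and all of its tensors are copied onto the CPU.
    """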
index_path = os.path.join(base_path, 'wan', 'diffusion_pytorch_model.safetensors.index.json')
with open(index_path, 'r') as f:
index = json.load(f)
state_dict = {}
    # Collect the unique shard files referenced by the index, preserving order.
    shard_files = dict.fromkeys(index['weight_map'].values())
    for shard_file in shard_files:
if do_print:
print(f"Loading {shard_file}")
shard_path = os.path.join(base_path, 'wan', shard_file)
with safe_open(shard_path, framework="pt", device="cpu") as f:
for key in f.keys():
state_dict[key] = f.get_tensor(key)
return state_dict
class WanMeanFlowXTrainer:
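    """Trainer for Wan 1.3B text-to-video with a MeanFlow-style objective.

    Loads the pretrained Wan weights, splits the trainable parameters between
    a Muon optimizer (matrix-shaped block weights) and AdamW (everything
    else), and logs losses to wandb.
    """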
def __init__(self, config):
self.device = torch.device('cuda')
self.pipeline = WanMeanFlowXPipeline(
checkpoint_dir=config.wan_checkpoint_dir,
device=self.device,
config=t2v_1_3B,
)
wan_sd = load_wan_state_dict(config.wan_checkpoint_dir, do_print=True)
self.pipeline.model.load_state_dict(wan_sd, strict=False)
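        # Freeze every parameter unless its name contains one of the allowed
        # substrings; a None allow-list means everything stays trainable.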
for name, param in self.pipeline.model.named_parameters():
if config.allowed_param_names is None or any(allowed_name in name for allowed_name in config.allowed_param_names):
param.requires_grad = True
else:
param.requires_grad = False
        # Split the transformer block parameters: Muon handles matrix-shaped
        # weights, AdamW handles vectors/scalars, modulation parameters, and
        # everything outside the blocks. Frozen parameters are excluded.
        muon_body_params = [
            (name, param) for name, param in self.pipeline.model.blocks.named_parameters()
            if param.requires_grad and param.ndim >= 2 and "modulation" not in name
        ]
        adamw_body_params = [
            (name, param) for name, param in self.pipeline.model.blocks.named_parameters()
            if param.requires_grad and (param.ndim < 2 or "modulation" in name)
        ]
        adamw_other_params = [
            (name, param) for name, param in self.pipeline.model.named_parameters()
            if param.requires_grad and not name.startswith("blocks")
        ]
        # Optimizers take tensors, not (name, tensor) pairs.
        self.muon_optimizer = torch.optim.Muon([param for _, param in muon_body_params], lr=2e-5)
        self.adamw_optimizer = torch.optim.AdamW(
            [param for _, param in adamw_body_params + adamw_other_params],
            lr=1e-5,
            weight_decay=0.01,
        )
        for name, param in muon_body_params + adamw_body_params + adamw_other_params:
            assert param.requires_grad
            print(f"{name} requires grad")
self.dataset = OpenVid1MDataset(config.data_dir, config.csv_path)
self.dataloader = DataLoader(self.dataset, batch_size=config.batch_size, shuffle=True)
self.config = config
wandb.init(project="wan_mean_flow_x", name="wan1.3B")
def sample_r_t(self):
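        """Sample one (r, t) time pair with r <= t.

        Draws come from a uniform (10%) or logit-normal (90%) distribution,
        and with probability 0.5 the pair collapses to r == t, the
        instantaneous-velocity case.
        """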
        if torch.rand(1).item() > 0.9:
            # 10% of the time, sample r and t from the uniform distribution.
            r = torch.rand((), device=self.device)
            t = torch.rand((), device=self.device)
        else:
            # 90% of the time, sample r and t from a logit-normal distribution.
            r = sample_logit_normal(std=0.8, mean=0.0).to(self.device)
            t = sample_logit_normal(std=0.8, mean=0.0).to(self.device)
        if r > t:
            r, t = t, r
        if torch.rand(1).item() < 0.5:
            # 50% of the time, set r equal to t so the target reduces to the
            # instantaneous velocity (only t is used).
            r = t.clone()
        return r, t
def train_step(self, batch):
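        """Compute the loss for one batch, sampling an (r, t) pair per example."""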
B = len(batch['prompt'])
rs, ts = [], []
for _ in range(B):
r, t = self.sample_r_t()
rs.append(r)
ts.append(t)
        r = torch.stack(rs)  # shape (B,); samples already live on self.device
        t = torch.stack(ts)  # shape (B,)
loss = self.pipeline.run_one_step(batch, t, r)
return loss
def train(self):
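        """Run the training loop with gradient accumulation.

        An optimizer step is taken every `accumulation_steps` micro-batches;
        `step` counts optimizer steps and training stops after `num_steps`.
        """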
accumulation_steps = self.config.accumulation_steps
num_steps = self.config.num_steps
step = 0
_step_in_accum = 0
pbar = tqdm(total=num_steps, desc="Training")
loss_ema = None
for batch in self.dataloader:
loss = self.train_step(batch)
(loss / accumulation_steps).backward()
_step_in_accum += 1
if _step_in_accum == accumulation_steps:
torch.nn.utils.clip_grad_norm_(self.pipeline.model.parameters(), max_norm=1.0)
self.muon_optimizer.step()
self.adamw_optimizer.step()
self.muon_optimizer.zero_grad()
self.adamw_optimizer.zero_grad()
if loss_ema is None:
loss_ema = loss.item()
else:
loss_ema = 0.95 * loss_ema + 0.05 * loss.item()
wandb.log({"loss": loss.item()}, step=step)
pbar.set_postfix({"loss_ema95": loss_ema})
step += 1
pbar.update(1)
_step_in_accum = 0
if step >= num_steps:
break
def sample_logit_normal(std=0.8, mean=0.0):
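    """Sample from a logit-normal: the sigmoid of a Gaussian draw (0-dim tensor)."""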
    s = torch.randn(()) * std + mean
return torch.sigmoid(s)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Train Wan Mean Flow X")
    parser.add_argument('config', type=str, help='Path to config file')
    args = parser.parse_args()
    config = omegaconf.OmegaConf.load(args.config)
trainer = WanMeanFlowXTrainer(config)
trainer.train()