-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinv_decoder.py
More file actions
257 lines (216 loc) · 11.4 KB
/
inv_decoder.py
File metadata and controls
257 lines (216 loc) · 11.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
import copy
import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from augmentation import get_rar_transforms, DualTransform
import torch.distributed as dist
dist.init_process_group(backend="nccl")
def normalize_01_into_pm1(x):
    """Rescale values from the [0, 1] range into [-1, 1].

    Used for the VAR VQVAE, whose encoder expects inputs in [-1, 1],
    while torchvision's ToTensor yields images in [0, 1].
    """
    return x * 2 - 1
import sys
from pathlib import Path
# Add RAR tokenizer to path
# NOTE(review): assumes sibling "RAR/tokenizer_1d" and "VAR" checkouts next to
# this file; the prints below exist to debug a wrong working tree.
rar_tokenizer_path = Path(__file__).parent / "RAR" / "tokenizer_1d"
print(f"Adding to sys.path: {rar_tokenizer_path}")
print(f"Path exists: {rar_tokenizer_path.exists()}")
sys.path.insert(0, str(rar_tokenizer_path))
# Add VAR tokenizer to path
var_tokenizer_path = Path(__file__).parent / "VAR"
print(f"Adding to sys.path: {var_tokenizer_path}")
print(f"Path exists: {var_tokenizer_path.exists()}")
sys.path.insert(0, str(var_tokenizer_path))
# Import RAR modules
# from modeling.rar import RAR
from modeling.titok import PretrainedTokenizer  # resolved via rar_tokenizer_path
# Import VAR modules
from models.vqvae import VQVAE  # resolved via var_tokenizer_path
# Single-device target; the training code moves deep copies of models here.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Root folder containing "train"/"val" ImageFolder trees of generated images.
DATASET_DIR = Path("Dataset")
# According to Table A1 in Appendix C, RAR and VAR both have 50000 finetuning data.
# Recall the statement in the paper: "Given that the finetuning exclusively relies on images produced by the target IAR, it does not require the costly curation of additional training data".
def train_inverse_decoder(family_name, vae_model, epochs=20):
    """Finetune the VAE encoder to map generated images back to their EXACT tokens.

    The decoder and quantizer stay frozen; only the encoder is optimized with
    an MSE loss between the predicted (pre-quantization) feature map and the
    ground-truth quantized feature map f_Z (Equation (6) in the paper).

    Args:
        family_name (str): "var" or "rar" — selects hyperparameters (Table A1,
            Appendix C), dataset class folders, input preprocessing, and the
            output checkpoint name.
        vae_model: The pre-trained VAE/tokenizer. It is deep-copied and never
            mutated by this function.
        epochs (int): Number of finetuning epochs.

    Raises:
        ValueError: If `family_name` is neither "var" nor "rar".

    Side effects:
        Writes the finetuned encoder weights to
        f"{family_name}_inverse_decoder.pth".
    """
    # Fail fast on a bad family name, before any expensive deep copies.
    # Table A1 in Appendix C: family-specific finetuning hyperparameters.
    if family_name == "var":
        lr = 5e-5
        batch_size = 16
    elif family_name == "rar":
        lr = 5e-4
        batch_size = 8
    else:
        raise ValueError("family_name must be 'var' or 'rar'")

    print(f"\n[Training] Starting Inverse Decoder for {family_name}...")

    # Use `copy.deepcopy` to avoid accidentally modifying the original model.
    gt_vae = copy.deepcopy(vae_model).to(DEVICE)
    pred_vae = copy.deepcopy(vae_model).to(DEVICE)

    # Freeze the GT model (used only for generating targets).
    gt_vae.eval()
    for p in gt_vae.parameters():
        p.requires_grad = False

    # Freeze decoder & quantizer in the prediction model; train only the encoder.
    pred_vae.encoder.train()
    pred_vae.decoder.eval()
    pred_vae.quantize.eval()
    for p in pred_vae.decoder.parameters():
        p.requires_grad = False
    for p in pred_vae.quantize.parameters():
        p.requires_grad = False
    if hasattr(pred_vae, "quant_conv"):
        # quant_conv sits in the VAR forward path but is NOT in the optimizer;
        # freeze it so its gradients do not accumulate unused every step.
        for p in pred_vae.quant_conv.parameters():
            p.requires_grad = False
    for p in pred_vae.encoder.parameters():
        p.requires_grad = True

    optimizer = Adam(pred_vae.encoder.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9)

    # Setup data: load ALL images from this family.
    # VAR combines the var16/var20/var24/var30 folders;
    # RAR combines the rarb/rarl/rarxl/rarxxl folders.
    base_transform = transforms.Compose(
        [transforms.Resize((256, 256)), transforms.ToTensor()]
    )
    train_dataset = datasets.ImageFolder(
        root=DATASET_DIR / "train", transform=base_transform
    )
    val_dataset = datasets.ImageFolder(
        root=DATASET_DIR / "val", transform=base_transform
    )

    # Filter sample indices down to the requested family's class folders.
    if family_name == "rar":
        target_classes = ["rarb", "rarl", "rarxl", "rarxxl"]
    else:
        target_classes = ["var16", "var20", "var24", "var30"]
    train_class_indices = [
        train_dataset.class_to_idx[cls]
        for cls in target_classes
        if cls in train_dataset.class_to_idx
    ]
    val_class_indices = [
        val_dataset.class_to_idx[cls]
        for cls in target_classes
        if cls in val_dataset.class_to_idx
    ]
    train_indices = [
        i for i, (_, label) in enumerate(train_dataset.samples)
        if label in train_class_indices
    ]
    val_indices = [
        i for i, (_, label) in enumerate(val_dataset.samples)
        if label in val_class_indices
    ]
    train_subset = torch.utils.data.Subset(train_dataset, train_indices)
    val_subset = torch.utils.data.Subset(val_dataset, val_indices)
    train_loader = DataLoader(
        train_subset, batch_size=batch_size, shuffle=True, num_workers=4
    )
    val_loader = DataLoader(
        val_subset, batch_size=batch_size, shuffle=False, num_workers=4
    )

    for epoch in range(epochs):
        if family_name == "rar":
            # RAR uses epoch-dependent augmentations; swap them into the
            # underlying dataset (the Subset wraps the real ImageFolder).
            aug_transform = get_rar_transforms(epoch)
            if isinstance(train_loader.dataset, torch.utils.data.Subset):
                train_loader.dataset.dataset.transform = aug_transform
            else:
                train_loader.dataset.transform = aug_transform

        total_loss = 0.0
        for batch_idx, (images, _) in enumerate(train_loader):
            images = images.to(DEVICE)

            # Ground-truth quantized feature map f_Z from the frozen model.
            # NOTE(review): the paper's Eq. (6) targets the feature map of the
            # *un-augmented* image, but the loader already augmented `images`;
            # both GT and prediction paths therefore see the augmented image.
            # This is a deliberate simplification — returning (clean, aug)
            # pairs from the loader would be needed to follow Eq. (6) strictly.
            with torch.no_grad():
                if family_name == "rar":
                    f_raw_gt = gt_vae.encoder(images)
                    f_Z_gt, _, _ = gt_vae.quantize(f_raw_gt)
                else:  # "var": VQVAE expects inputs in [-1, 1]
                    f_raw_gt = gt_vae.encoder(normalize_01_into_pm1(images))
                    # Apply `quant_conv` before `quantize` (VAR pipeline).
                    f_gt = gt_vae.quant_conv(f_raw_gt)
                    f_Z_gt, _, _ = gt_vae.quantize(f_gt)

            if epoch == 0 and batch_idx == 0:
                # Sanity check: expect a feature map shape like [B, C, H, W],
                # not a scalar (torch.Size([])).
                print(f"DEBUG: f_Z_gt shape: {f_Z_gt.shape}")

            # Trainable encoder prediction (pre-quantization features).
            if family_name == "rar":
                f_Z_pred = pred_vae.encoder(images)
            else:
                f_raw_pred = pred_vae.encoder(normalize_01_into_pm1(images))
                f_Z_pred = pred_vae.quant_conv(f_raw_pred)

            loss = F.mse_loss(f_Z_pred, f_Z_gt.detach())  # Equation (6)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # Scheduler steps once per epoch, not per batch.
        scheduler.step()
        print(f" Epoch {epoch+1}: Loss {total_loss/len(train_loader):.6f}")

        # Validation — `val_loader` keeps the un-augmented `base_transform`.
        # No backward pass happens here, so run everything without gradients.
        val_total_loss = 0.0
        val_batches = 0
        with torch.no_grad():
            for images, _ in val_loader:
                images = images.to(DEVICE)
                if family_name == "rar":
                    f_raw_gt = gt_vae.encoder(images)
                    f_Z_gt, _, _ = gt_vae.quantize(f_raw_gt)
                    f_Z_pred = pred_vae.encoder(images)
                else:
                    x = normalize_01_into_pm1(images)
                    f_gt = gt_vae.quant_conv(gt_vae.encoder(x))
                    f_Z_gt, _, _ = gt_vae.quantize(f_gt)
                    f_Z_pred = pred_vae.quant_conv(pred_vae.encoder(x))
                val_loss = F.mse_loss(f_Z_pred, f_Z_gt)
                val_total_loss += val_loss.item()
                val_batches += 1
        if val_batches:
            print(f" Validation Loss: {val_total_loss / val_batches:.6f}")

    torch.save(pred_vae.encoder.state_dict(), f"{family_name}_inverse_decoder.pth")
    print(f"[Saved] {family_name}_inverse_decoder.pth")
if __name__ == "__main__":
rar_vae = PretrainedTokenizer(
"RAR/tokenizer_1d/checkpoints/maskgit-vqgan-imagenet-f16-256.bin"
)
var_vae = VQVAE(
vocab_size=4096,
z_channels=32,
ch=160,
share_quant_resi=4,
test_mode=True,
v_patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),
).to(DEVICE)
var_vae.load_state_dict(
torch.load("VAR/checkpoints/vae_ch160v4096z32.pth", map_location="cpu"),
strict=True,
)
print(f"Using device: {DEVICE}")
print(f"CUDA available: {torch.cuda.is_available()}")
train_inverse_decoder("rar", rar_vae, epochs=100)
train_inverse_decoder("var", var_vae, epochs=20)