-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprada.py
More file actions
343 lines (297 loc) · 16.5 KB
/
prada.py
File metadata and controls
343 lines (297 loc) · 16.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import random
from feature_extractor import extract_raw_log_probs
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Preprocessing for the ImageNet-pretrained content guesser (ResNet50, ConvNeXt).
# The dataloader already yields 256x256 tensors, so neither Resize(256) nor
# ToTensor() is needed here: only the 224 centre crop and the standard ImageNet
# mean/std normalisation remain.
guesser_transform = transforms.Compose(
    [
        transforms.CenterCrop(224),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
print("Loading ImageNet Validation Set...")
# Every image is resized to a fixed 256x256 and converted to a tensor; the
# content guesser later centre-crops and normalises it (guesser_transform).
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ]
)
# torchvision's ImageNet dataset maps the 50,000 validation images to labels 0-999.
imagenet_val = datasets.ImageNet(root='./ImageNet', split='val', transform=transform)

print("Indexing ImageNet classes for fast sampling...")
# class id -> list of dataset indices. Built from the (file_path, label) pairs
# in imagenet_val.samples, so no image has to be decoded for indexing.
class_to_indices = {class_id: [] for class_id in range(1000)}
for sample_idx, (_, sample_label) in enumerate(imagenet_val.samples):
    class_to_indices[sample_label].append(sample_idx)
# Following Section 4 of the PRADA paper
class PRADAScorer(nn.Module):
    """Learned PRADA score computed from per-token conditional/unconditional log-probs.

    For every scale s, the balanced probability ratio (Equation 3)

        delta_alpha = (2 - alpha) * log_p_cond - alpha * log_p_uncond

    is computed token-wise and mapped through a small MLP ``f_theta``. It is then
    averaged over the tokens of that scale. The per-scale means are combined with
    weights ``w`` (Equation 4).

    Args:
        num_scales: Number of per-scale inputs: 1 for RAR, the number of VAR
            scales otherwise. With more than one scale, the weights ``w`` are
            learnable and start uniform. With a single scale, ``w`` is the
            fixed, non-trainable ``[1.0]``.
        noise_ratio: Std of the Gaussian noise added to ``delta_alpha`` in
            training mode, relative to the (detached) std of ``delta_alpha``.
            The paper does not give the value; 0.05 (5%) is the default.
    """

    def __init__(self, num_scales=1, noise_ratio=0.05):
        super().__init__()
        # Probability-ratio balancing coefficient alpha (Equation 3).
        self.alpha = nn.Parameter(torch.tensor(1.0))
        # Token-wise scoring network f_theta: 2 hidden layers of 16 units, ELU.
        self.f_theta = nn.Sequential(
            nn.Linear(1, 16),
            nn.ELU(),
            nn.Linear(16, 16),
            nn.ELU(),
            nn.Linear(16, 1),
        )
        self.num_scales = num_scales
        self.noise_ratio = noise_ratio
        # Scale weighting w_s (Equation 4).
        if num_scales > 1:
            self.w = nn.Parameter(torch.ones(num_scales) / num_scales)
        else:
            # A non-persistent buffer follows .to()/.double() with the rest of the
            # module. It is still excluded from state_dict() and parameters(), as
            # the weight was never saved or trained.
            self.register_buffer("w", torch.ones(1), persistent=False)

    def forward(self, log_p_cond_list, log_p_uncond_list):
        """Return the PRADA score of each sample.

        Args:
            log_p_cond_list: List with one tensor per scale (length 1 for RAR,
                ``num_scales`` for VAR). Each tensor is ``(N, seq_len_for_that_scale)``
                and holds conditional token log-probabilities.
            log_p_uncond_list: Same structure, holding unconditional log-probabilities.

        Returns:
            Tensor of shape ``(N,)``. The calibration loss treats it as a logit.

        Raises:
            ValueError: If fewer than ``num_scales`` tensors are given in either list.
        """
        if len(log_p_cond_list) < self.num_scales or len(log_p_uncond_list) < self.num_scales:
            raise ValueError(
                f"expected {self.num_scales} per-scale tensors, got "
                f"{len(log_p_cond_list)} conditional and {len(log_p_uncond_list)} unconditional"
            )
        batch_size = log_p_cond_list[0].size(0)
        # Accumulate on the module's own device rather than a global one.
        final_scores = torch.zeros(batch_size, device=self.alpha.device)
        for s in range(self.num_scales):
            # Equation (3): balanced probability ratio, (Batch, Seq_Len).
            delta_alpha = (2 - self.alpha) * log_p_cond_list[s] - self.alpha * log_p_uncond_list[s]
            if self.training and self.noise_ratio > 0 and delta_alpha.numel() > 1:
                # Small scale-dependent regularisation noise (scale unspecified in the
                # paper). std() of a single value is NaN, hence the numel() guard.
                noise_std = self.noise_ratio * delta_alpha.detach().std()
                delta_alpha = delta_alpha + torch.randn_like(delta_alpha) * noise_std
            # (Batch, Seq_Len) -> (Batch, Seq_Len, 1) -> f_theta -> (Batch, Seq_Len)
            f_out = self.f_theta(delta_alpha.unsqueeze(-1)).squeeze(-1)
            # Equation (4): weighted sum of the per-scale token means.
            final_scores = final_scores + self.w[s] * f_out.mean(dim=1)
        return final_scores  # (Batch,)
def calibrate_prada_scorer(log_p_cond_list, log_p_uncond_list, is_generated_labels, num_scales,
                           num_steps=3000, lr=1e-3, label_smoothing=0.1, l1_weight=0.01):
    """Train a fresh PRADAScorer with full-batch AdamW, as described in the paper.

    Args:
        log_p_cond_list, log_p_uncond_list: Lists of tensors (one per scale) with
            conditional / unconditional token log-probabilities for all N samples.
        is_generated_labels: (N,) tensor. 1 if the sample was generated by THIS
            model, 0 otherwise (real or other models).
        num_scales: Number of scales (1 for RAR).
        num_steps: Optimisation steps (the paper uses 3000).
        lr: AdamW learning rate. The paper does not specify it; 1e-3 is the default.
        label_smoothing: Targets become ``1 - label_smoothing`` for generated
            samples and ``label_smoothing`` otherwise. The paper uses label
            smoothing without giving the value.
        l1_weight: Strength of the ``| ||w||_1 - 1 |`` penalty that pushes the
            scale weights to sum to 1. Applied only when ``num_scales > 1``,
            since a single-scale scorer has no learnable weights.

    Returns:
        The trained scorer, on DEVICE and in eval mode.
    """
    # The scorer only trains its own parameters, so its inputs are detached.
    # This means backward() never reaches a graph left by the feature extractor,
    # which would fail from the second step on. All inputs are also moved to
    # DEVICE, wherever the caller left them.
    log_p_cond_list = [t.detach().to(DEVICE) for t in log_p_cond_list]
    log_p_uncond_list = [t.detach().to(DEVICE) for t in log_p_uncond_list]
    labels = is_generated_labels.to(DEVICE)

    scorer = PRADAScorer(num_scales=num_scales).to(DEVICE)
    optimizer = optim.AdamW(scorer.parameters(), lr=lr)
    smoothed_labels = torch.where(
        labels == 1,
        torch.tensor(1.0 - label_smoothing, device=DEVICE),
        torch.tensor(label_smoothing, device=DEVICE),
    ).float()
    criterion = nn.BCEWithLogitsLoss()  # The PRADA score is treated as a binary logit.

    scorer.train()
    for _ in range(num_steps):
        optimizer.zero_grad()
        scores = scorer(log_p_cond_list, log_p_uncond_list)  # P(x) in Equation (4)
        loss = criterion(scores, smoothed_labels)
        if num_scales > 1:
            l1_norm = torch.norm(scorer.w, p=1)
            loss = loss + l1_weight * torch.abs(l1_norm - 1.0)
        loss.backward()
        optimizer.step()
    scorer.eval()
    return scorer
def prada_variant_distinguisher(images, content_guesser, prada_scorers_dict, models, vae, is_var=False):
    """Attribute each image to the variant whose PRADA scorer rates it highest.

    Args:
        images: Batch of images, passed unchanged to ``extract_raw_log_probs``.
        content_guesser: Passed to ``extract_raw_log_probs`` through its
            ``content_guesser`` keyword.
        prada_scorers_dict: Dict mapping model names to their calibrated PRADA scorers,
            {'rarb': scorer_rarb, 'rarl': scorer_rarl, ..., 'var16': scorer_var16, ...}
        models: Dict mapping the same model names to the generator models.
        vae: Passed to ``extract_raw_log_probs`` together with each model.
        is_var: If True, compare the VAR variants; otherwise compare the RAR variants.

    Returns:
        Integer tensor (Batch,) of indices into ["var16", "var20", "var24", "var30"]
        when ``is_var`` is True, otherwise into ["rarb", "rarl", "rarxl", "rarxxl"].
    """
    model_names = ["var16", "var20", "var24", "var30"] if is_var else ["rarb", "rarl", "rarxl", "rarxxl"]
    scores_per_model = []
    for model_name in model_names:
        log_p_cond, log_p_uncond = extract_raw_log_probs(
            images, models[model_name], vae, is_var=is_var, content_guesser=content_guesser)
        scorer = prada_scorers_dict[model_name]
        # The training path moves the extracted log-probs explicitly before scoring.
        # Do the same here and align them with wherever this scorer lives.
        scorer_device = next(scorer.parameters()).device
        with torch.no_grad():
            score = scorer(
                [t.to(scorer_device) for t in log_p_cond],
                [t.to(scorer_device) for t in log_p_uncond],
            )  # (Batch,)
        scores_per_model.append(score)
    scores_tensor = torch.stack(scores_per_model, dim=1)  # (Batch, 4)
    return torch.argmax(scores_tensor, dim=1)
def create_prada_calibration_data(fake_images, content_guesser, imagenet_val, class_to_indices, batch_size=16):
    """Build a balanced real/fake PRADA calibration set for one target model.

    The content guesser predicts a class for every fake image. For each prediction,
    one real image of that class is drawn at random from ``imagenet_val``, so the
    real half has exactly the same class distribution as the fake half.

    Args:
        fake_images: (F, 3, H, W) images generated by the target model.
            ``guesser_transform`` is applied to them batch-wise before classification.
        content_guesser: Classifier whose argmax gives the class id; it is put
            into eval mode here.
        imagenet_val: Dataset where ``imagenet_val[idx]`` returns ``(image, label)``.
            Its images must match the fakes' shape so the two halves can be concatenated.
        class_to_indices: Dict mapping class id -> list of dataset indices of that class.
        batch_size: Images per classifier forward pass (bounds GPU memory).

    Returns:
        combined_images: (2F, 3, H, W). The F fakes come first, then the F matched reals.
        combined_conditions: (2F,) predicted class ids; the same list is used for both halves.
        combined_labels: (2F,) float32, 1 for fake and 0 for real.
        All three live on ``fake_images.device``.

    Raises:
        ValueError: If ``fake_images`` is empty.
    """
    if len(fake_images) == 0:
        raise ValueError("fake_images must contain at least one image")
    out_device = fake_images.device

    content_guesser.eval()
    predicted_classes = []
    with torch.no_grad():
        # Process in batches to avoid OOM.
        for start in range(0, len(fake_images), batch_size):
            batch = guesser_transform(fake_images[start:start + batch_size].to(DEVICE))
            logits = content_guesser(batch)
            predicted_classes.extend(torch.argmax(logits, dim=1).cpu().tolist())
            del batch, logits  # Free memory after each batch.
            torch.cuda.empty_cache()

    # For every fake image's predicted class, sample one real image of the same class.
    real_images_list = []
    for predicted_class in predicted_classes:
        chosen_idx = random.choice(class_to_indices[predicted_class])
        real_img, _ = imagenet_val[chosen_idx]
        real_images_list.append(real_img)
    real_images = torch.stack(real_images_list).to(out_device)

    num_fake = len(fake_images)
    combined_images = torch.cat([fake_images, real_images], dim=0)  # (2F, 3, H, W)
    combined_conditions = torch.tensor(predicted_classes * 2, device=out_device)
    # Binary labels: 1 for Fake, 0 for Real.
    combined_labels = torch.cat(
        [torch.ones(num_fake, dtype=torch.float32), torch.zeros(len(real_images), dtype=torch.float32)],
        dim=0,
    ).to(out_device)
    return combined_images, combined_conditions, combined_labels
def _prada_move(device, *modules):
    """Move every module in ``modules`` to ``device`` (used for CPU offloading)."""
    for module in modules:
        module.to(device)


def _prada_group_by_class(dataset):
    """Load every sample of ``dataset`` and group the images by class name."""
    images_per_class = {}
    for idx in range(len(dataset)):
        img, label_idx = dataset[idx]
        images_per_class.setdefault(dataset.classes[label_idx], []).append(img)
    return images_per_class


def _prada_split_fold(images_per_class, fold, n_splits):
    """Split each class's image list into (train, held-out) dicts for one fold."""
    fold_train, fold_eval = {}, {}
    for class_name, images in images_per_class.items():
        fold_size = len(images) // n_splits
        val_start = fold * fold_size
        val_end = val_start + fold_size
        fold_eval[class_name] = images[val_start:val_end]
        fold_train[class_name] = images[:val_start] + images[val_end:]
    return fold_train, fold_eval


def _prada_calibrate_one(fake_list, model, vae, content_guesser, is_var, num_scales):
    """Calibrate one PRADA scorer with only this model, its VAE and the guesser on DEVICE."""
    _prada_move(DEVICE, vae, model, content_guesser)
    torch.cuda.empty_cache()
    try:
        fake_images = torch.stack(fake_list).to(DEVICE)  # (F, 3, 256, 256)
        combined_images, combined_conditions, combined_labels = create_prada_calibration_data(
            fake_images, content_guesser, imagenet_val, class_to_indices
        )  # (2F, 3, 256, 256), (2F,), (2F,)
        log_p_cond, log_p_uncond = extract_raw_log_probs(
            combined_images, model, vae, is_var=is_var, condition=combined_conditions)
        # Move log probs to DEVICE before passing them to the scorer.
        log_p_cond = [t.to(DEVICE) for t in log_p_cond]
        log_p_uncond = [t.to(DEVICE) for t in log_p_uncond]
        return calibrate_prada_scorer(log_p_cond, log_p_uncond, combined_labels, num_scales)
    finally:
        _prada_move('cpu', model, vae, content_guesser)
        torch.cuda.empty_cache()


def _prada_fold_accuracy(fold_eval, names, models, vae, content_guesser, scorers, is_var):
    """Variant-attribution accuracy on the held-out images of the models in ``names``.

    The decision rule is the one in prada_variant_distinguisher: argmax over the
    per-variant PRADA scores. Here each model is moved to DEVICE only while its
    log-probabilities are extracted, rather than all variants at once.
    """
    eval_images = torch.stack([img for name in names for img in fold_eval[name]]).to(DEVICE)
    true_idx = torch.tensor([i for i, name in enumerate(names) for _ in fold_eval[name]])
    scores_per_model = []
    for name in names:
        _prada_move(DEVICE, vae, models[name], content_guesser)
        torch.cuda.empty_cache()
        try:
            log_p_cond, log_p_uncond = extract_raw_log_probs(
                eval_images, models[name], vae, is_var=is_var, content_guesser=content_guesser)
            scorer = scorers[name]
            scorer_device = next(scorer.parameters()).device
            with torch.no_grad():
                score = scorer(
                    [t.to(scorer_device) for t in log_p_cond],
                    [t.to(scorer_device) for t in log_p_uncond],
                )
            scores_per_model.append(score.cpu())
        finally:
            _prada_move('cpu', models[name], vae, content_guesser)
            torch.cuda.empty_cache()
    predicted = torch.stack(scores_per_model, dim=1).argmax(dim=1)  # indices into names
    return int((predicted == true_idx).sum()) / len(true_idx)


def train_prada_scorers(validation_loader, content_guesser, rar_models, var_models, rar_vae, var_vae, num_scales, num_train=50, n_splits=5):
    """Cross-validate, then fit, one PRADA scorer per RAR and VAR variant.

    ``validation_loader.dataset`` must return ``(image, label)`` items and expose
    ``classes``. Each class name is a model name ("rarb", ..., "var30"), and its
    images are that model's generations.

    For each of ``n_splits`` folds, every class is split into train / held-out parts.
    One scorer per model is calibrated on its training fakes plus class-matched
    real ImageNet images. The RAR and VAR variant-attribution accuracies are then
    printed for the held-out fakes. Finally, every scorer is recalibrated on up to
    ``num_train`` randomly sampled fakes (plus as many real images) and returned.
    Models, VAEs and the content guesser stay on the CPU and are moved to DEVICE
    only while needed, to bound GPU memory.

    Args:
        validation_loader: Loader whose ``.dataset`` holds the generated images.
        content_guesser: ImageNet classifier used to pick class conditions.
        rar_models, var_models: Dicts mapping model names to generator models.
        rar_vae, var_vae: Tokenizers/VAEs passed to ``extract_raw_log_probs``.
        num_scales: Number of VAR scales (RAR always uses 1).
        num_train: Number of samples per model used for the final scorers,
            e.g. 50 fake (generated by the target model) and 50 real (ImageNet).
        n_splits: Number of cross-validation folds.

    Returns:
        Dict mapping each model name to its final calibrated PRADA scorer.

    Raises:
        ValueError: If ``n_splits < 2``, ``num_train < 1``, or a model has fewer
            images than ``n_splits``.
    """
    if n_splits < 2:
        raise ValueError(f"n_splits must be at least 2, got {n_splits}")
    if num_train < 1:
        raise ValueError(f"num_train must be at least 1, got {num_train}")

    rar_names = ["rarb", "rarl", "rarxl", "rarxxl"]
    var_names = ["var16", "var20", "var24", "var30"]
    # (family, model names, models, vae, is_var, num_scales); RAR has only 1 scale.
    families = [
        ("RAR", rar_names, rar_models, rar_vae, False, 1),
        ("VAR", var_names, var_models, var_vae, True, num_scales),
    ]

    # Start with ALL models on the CPU; each step moves only what it needs.
    _prada_move('cpu', *rar_models.values(), *var_models.values(), rar_vae, var_vae, content_guesser)
    torch.cuda.empty_cache()

    # Access the dataset directly instead of iterating over the loader in batches.
    all_images_per_class = _prada_group_by_class(validation_loader.dataset)
    smallest = min(len(all_images_per_class[name]) for name in rar_names + var_names)
    if smallest < n_splits:
        raise ValueError(f"every model needs at least n_splits={n_splits} images, smallest has {smallest}")

    fold_acc = {"RAR": [], "VAR": []}
    for fold in range(n_splits):
        print(f"\n--- Training PRADA scorers for Fold {fold+1}/{n_splits} ---")
        fold_train, fold_eval = _prada_split_fold(all_images_per_class, fold, n_splits)

        # One PRADA scorer per model, trained on this fold's fakes + matched reals.
        fold_scorers = {}
        for _, names, models, vae, is_var, scales in families:
            for model_name in names:
                fold_scorers[model_name] = _prada_calibrate_one(
                    fold_train[model_name], models[model_name], vae, content_guesser, is_var, scales)

        # Evaluate each family on its own held-out images only.
        for family, names, models, vae, is_var, _ in families:
            acc = _prada_fold_accuracy(fold_eval, names, models, vae, content_guesser, fold_scorers, is_var)
            fold_acc[family].append(acc)
            print(f"Fold {fold+1} {family} accuracy: {acc:.4f}")

    # Overall accuracy for RAR and VAR across all folds.
    for family in ("RAR", "VAR"):
        accs = fold_acc[family]
        print(f"Overall {family} accuracy: {sum(accs) / len(accs):.4f}")

    # Final scorers: recalibrate on up to num_train randomly sampled fakes per model.
    prada_scorers_dict = {}
    for _, names, models, vae, is_var, scales in families:
        for model_name in names:
            all_fake = all_images_per_class[model_name]
            sampled_fake = random.sample(all_fake, min(num_train, len(all_fake)))
            prada_scorers_dict[model_name] = _prada_calibrate_one(
                sampled_fake, models[model_name], vae, content_guesser, is_var, scales)
    return prada_scorers_dict