Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions aux_bn.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,19 @@ def to_status(m, status):
"""
change the status of batch norm layer
status can be 'clean', 'adv' or 'mix'

Three statuses, meaning the training samples in this batch are:
- clean: all clean samples
- adv: all adversarial samples
- mix: *1st* half are *adversarial* samples, and the *2nd* half are *clean* samples
"""
if hasattr(m, 'batch_type'):
m.batch_type = status


to_clean_status = partial(to_status, status='clean')
'''all clean examples'''
to_adv_status = partial(to_status, status='adv')
'''all adversarial samples'''
to_mix_status = partial(to_status, status='mix')
'''*1st* half are *adversarial* samples, and the *2nd* half are *clean* samples'''
61 changes: 51 additions & 10 deletions imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@
import models.imagenet as customized_models
from models.AdaIN import StyleTransfer

from utils import Bar, Logger, AverageMeter, accuracy, mkdir_p, savefig
from progress.bar import Bar
from utils import Logger, AverageMeter, accuracy, mkdir_p, savefig
from utils.eval import accuracy_and_perclass
from utils.imagenet_a import indices_in_1k
from tensorboardX import SummaryWriter

Expand Down Expand Up @@ -118,6 +120,8 @@
parser.add_argument('--manualSeed', type=int, help='manual seed')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
help='evaluate model on validation set')
parser.add_argument('-et', '--evaluate-train', dest='evaluate_train', action='store_true',
help='evaluate model on train set')
# Device options
parser.add_argument('--gpu-id', default='7', type=str,
help='id(s) for CUDA_VISIBLE_DEVICES')
Expand Down Expand Up @@ -208,6 +212,7 @@ def main():
num_workers=args.workers, pin_memory=True) if not args.evaluate_imagenet_c else None

# create model
# only works for resnet or resnext
if args.arch.startswith('resnext'):
norm_layer = MixBatchNorm2d if args.mixbn else None
model = models.__dict__[args.arch](
Expand Down Expand Up @@ -264,6 +269,7 @@ def main():
break

if args.mixbn and not already_mixbn:
# update the model checkpoint with mixbn
to_merge = {}
for key in checkpoint['state_dict']:
if 'bn' in key:
Expand Down Expand Up @@ -301,9 +307,15 @@ def main():

if args.evaluate:
print('\nEvaluation only')
test_loss, test_acc = test(val_loader, model, criterion, start_epoch, use_cuda, args.FGSM)
test_loss, test_acc = test(val_loader, model, criterion, start_epoch, use_cuda, args.FGSM, args.num_classes)
print(' Test Loss: %.8f, Test Acc: %.2f' % (test_loss, test_acc))
return

if args.evaluate_train:
print('\nEvaluation on training set only')
test_loss, test_acc = test(train_loader, model, criterion, start_epoch, use_cuda, args.FGSM, args.num_classes)
print(' Train Loss: %.8f, Train Acc: %.2f' % (test_loss, test_acc))
return

if args.evaluate_imagenet_c:
print("Evaluate ImageNet C")
Expand Down Expand Up @@ -331,15 +343,15 @@ def main():
start_lr=args.warm_lr) if args.warm > 0 else None
for epoch in range(start_epoch, args.epochs):
if epoch >= args.warm and args.lr_schedule == 'step':
adjust_learning_rate(optimizer, epoch, args)
adjust_learning_rate(optimizer, epoch, args, state)

print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, optimizer.param_groups[-1]['lr']))

style_transfer = partial(StyleTransfer(), alpha=args.alpha,
label_mix_alpha=1 - args.label_gamma) if args.style else None
train_func = partial(train, train_loader=train_loader, model=model, criterion=criterion,
optimizer=optimizer, epoch=epoch, use_cuda=use_cuda,
warmup_scheduler=warmup_scheduler, mixbn=args.mixbn,
warmup_scheduler=warmup_scheduler, state=state, mixbn=args.mixbn,
style_transfer=style_transfer, writer=writer)
if args.mixbn:
model.apply(to_mix_status)
Expand Down Expand Up @@ -398,8 +410,13 @@ def img_size_scheduler(batch_idx, epoch, schedule):
return ret_size, ret_size


def train(train_loader, model, criterion, optimizer, epoch, use_cuda, warmup_scheduler, mixbn=False,
def train(train_loader, model, criterion, optimizer, epoch, use_cuda, warmup_scheduler, state, mixbn=False,
style_transfer=None, writer=None):
'''
Train the model for a single epoch

Core of shape-texture debiased training happens here
'''
# switch to train mode
model.train()

Expand All @@ -422,7 +439,7 @@ def train(train_loader, model, criterion, optimizer, epoch, use_cuda, warmup_sch
if epoch < args.warm:
warmup_scheduler.step()
elif args.lr_schedule == 'cos':
adjust_learning_rate(optimizer, epoch, args, batch=batch_idx, nBatch=len(train_loader))
adjust_learning_rate(optimizer, epoch, args, state, batch=batch_idx, nBatch=len(train_loader))

# measure data loading time
data_time.update(time.time() - end)
Expand All @@ -433,10 +450,16 @@ def train(train_loader, model, criterion, optimizer, epoch, use_cuda, warmup_sch

if style_transfer is not None:
if args.multi_grid:
'''
Multigrid training
https://arxiv.org/pdf/1912.00998.pdf
Help with faster convergence. Might improve performance for small models
'''
img_size = img_size_scheduler(batch_idx, epoch, args.schedule)
resized_inputs = torch.nn.functional.interpolate(inputs, size=img_size)
inputs_aux, targets_aux = style_transfer(resized_inputs, targets, replace=True)
inputs = (inputs, inputs_aux)
                # get the new set of targets that includes the labels of the style-transferred images
if len(targets_aux) == 3:
n = targets.size(0)
targets = (torch.cat([targets, targets_aux[0]]),
Expand All @@ -462,16 +485,20 @@ def train(train_loader, model, criterion, optimizer, epoch, use_cuda, warmup_sch
elif args.cutmix:
inputs, targets = cutmix_data(inputs, targets, beta=args.cutmix, half=False)

# normalize AFTER style transfer
if not args.multi_grid:
inputs = (inputs - MEAN[:, None, None]) / STD[:, None, None]
# If using mixbn, model should be in mixed status here because inputs contain both the original and stylized images
outputs = model(inputs)
else:
inputs = ((inputs[0] - MEAN[:, None, None]) / STD[:, None, None],
(inputs[1] - MEAN[:, None, None]) / STD[:, None, None])
if args.mixbn:

            # Run the batch on the model. Since the original and stylized images are in two separate variables in this case, the first one is considered all clean samples, and the second one is considered all adversarial samples
if mixbn:
model.apply(to_clean_status)
outputs1 = model(inputs[0])
if args.mixbn:
if mixbn:
model.apply(to_adv_status)
outputs2 = model(inputs[1])
outputs = torch.cat([outputs1, outputs2])
Expand All @@ -488,6 +515,7 @@ def train(train_loader, model, criterion, optimizer, epoch, use_cuda, warmup_sch
top1.update(prec1.item(), outputs.size(0))
top5.update(prec5.item(), outputs.size(0))

# Compute main and aux loss/metrics separately when using mixbn
if mixbn:
with torch.no_grad():
batch_size = outputs.size(0)
Expand Down Expand Up @@ -537,14 +565,17 @@ def train(train_loader, model, criterion, optimizer, epoch, use_cuda, warmup_sch
return losses.avg, top1.avg


def test(val_loader, model, criterion, epoch, use_cuda, FGSM=False):
def test(val_loader, model, criterion, epoch, use_cuda, FGSM=False, num_classes=None):
global best_acc

batch_time = AverageMeter()
data_time = AverageMeter()
losses = AverageMeter()
top1 = AverageMeter()
top5 = AverageMeter()
if num_classes is not None:
total_num_per_class = torch.zeros(num_classes).int()
total_correct_per_class = torch.zeros(num_classes).int()

# switch to evaluate mode
model.eval()
Expand Down Expand Up @@ -580,7 +611,12 @@ def test(val_loader, model, criterion, epoch, use_cuda, FGSM=False):
loss = criterion(outputs, targets).mean()

# measure accuracy and record loss
prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
if num_classes is not None:
prec1, prec5, num_per_class, correct_per_class = accuracy_and_perclass(outputs.data, targets.data, topk=(1, 5), numclasses=num_classes)
total_num_per_class += num_per_class
total_correct_per_class += correct_per_class
else:
prec1, prec5= accuracy(outputs.data, targets.data, topk=(1, 5))
losses.update(loss.item(), inputs.size(0))
top1.update(prec1.item(), inputs.size(0))
top5.update(prec5.item(), inputs.size(0))
Expand All @@ -603,6 +639,11 @@ def test(val_loader, model, criterion, epoch, use_cuda, FGSM=False):
)
bar.next()
bar.finish()
if num_classes is not None:
accuracy_per_class = total_correct_per_class / total_num_per_class
print(f"class\tacc")
for i, acc in enumerate(accuracy_per_class.tolist()):
print(f"{i}\t{acc}")
return (losses.avg, top1.avg)


Expand Down
7 changes: 5 additions & 2 deletions models/AdaIN.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,14 +225,17 @@ def __call__(self, image, label, alpha, replace=True, label_mix_alpha=0):
n, c, h, w = image.shape
content = image.detach()
random_index = torch.randperm(n)
style = image.detach()[random_index]
label_style = label.detach()[random_index]
style = image.detach()[random_index] # Style is created from randomly permuting a batch of image. Thus each (content[i], style[i]) pair is from the same batch and essentially could be the same image
label_style = label.detach()[random_index] # need to also interpolate the label
with torch.no_grad():
# Run AdaIN and get style-transferred image
stylized_image = self.style_transfer(content, style, alpha)

if replace:
            # In replace mode, the original image is not kept.
return stylized_image, (label, label_style, torch.ones(n).cuda() * label_mix_alpha)
else:
            # Return both the original image and the stylized image. Thus, label and label_style are copied twice
label1 = torch.cat([label, label])
label2 = torch.cat([label_style, label_style])
label_weight = torch.cat([torch.zeros(n), torch.ones(n) * label_mix_alpha]).cuda()
Expand Down
31 changes: 28 additions & 3 deletions utils/eval.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import print_function, absolute_import

import torch
__all__ = ['accuracy']

def accuracy(output, target, topk=(1,)):
Expand All @@ -13,6 +13,31 @@ def accuracy(output, target, topk=(1,)):

res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
correct_k = correct[:k].reshape(-1).float().sum(0)
res.append(correct_k.mul_(100.0 / batch_size))
return res
return res

def accuracy_and_perclass(output, target, topk=(1,), numclasses=200):
    """Compute precision@k for each k in ``topk`` plus top-1 per-class counts.

    Args:
        output: (batch, classes) score tensor (e.g. logits).
        target: (batch,) tensor of ground-truth class indices.
        topk: tuple of k values for which to report precision@k.
        numclasses: fixed length of the returned per-class tensors.

    Returns:
        *precisions: one scalar tensor per k in ``topk`` (percentage).
        num_per_class: int tensor of length ``numclasses`` — samples per class.
        correct_per_class: int tensor of length ``numclasses`` — top-1 hits per class.
    """
    maxk = max(topk)
    n_samples = target.size(0)

    # Top-k predicted class indices, transposed to (maxk, batch) so that
    # row 0 holds the top-1 prediction for every sample.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    hits = pred.eq(target.view(1, -1).expand_as(pred))

    # precision@k as a percentage, for each requested k
    res = [
        hits[:k].reshape(-1).float().sum(0).mul_(100.0 / n_samples)
        for k in topk
    ]

    # Per-class top-1 bookkeeping: bincount over targets gives the sample
    # count per class; the same bincount weighted by the top-1 hit mask
    # gives the correct count per class.
    per_class_total = target.bincount()
    top1_hit = (pred[0] == target).int()
    per_class_hit = target.bincount(top1_hit)

    # bincount output only extends to max(target)+1, so copy into
    # fixed-size tensors padded with zeros up to numclasses.
    num_per_class = torch.zeros(numclasses).int()
    correct_per_class = torch.zeros(numclasses).int()
    num_per_class[:len(per_class_total)] = per_class_total
    correct_per_class[:len(per_class_hit)] = per_class_hit

    return *res, num_per_class, correct_per_class
5 changes: 2 additions & 3 deletions utils/lr_scheduler.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from torch.optim.lr_scheduler import _LRScheduler
import math


def adjust_learning_rate(optimizer, epoch, args, batch=None, nBatch=None):
global state
def adjust_learning_rate(optimizer, epoch, args, state, batch=None, nBatch=None):
if args.lr_schedule == 'cos':
T_total = args.epochs * nBatch
T_cur = (epoch % args.epochs) * nBatch + batch
Expand Down