-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathaugmentation.py
More file actions
164 lines (145 loc) · 5.72 KB
/
augmentation.py
File metadata and controls
164 lines (145 loc) · 5.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
# Following Table A2 in Appendix C, which defines a schedule of augmentations for RAR finetuning.
import torch
import random
import io
from torchvision import transforms
from PIL import Image, ImageFilter, ImageEnhance
class JPEGCompression(object):
    """Re-encode a PIL image as JPEG at a randomly chosen quality.

    Args:
        qualities: sequence of JPEG quality values sampled uniformly per call,
            e.g. ``[90, 85, 80]``. An empty (or ``None``) sequence disables the
            transform and the input is returned unchanged.

    Returns:
        The decoded JPEG image. Inputs whose mode the JPEG encoder cannot write
        (e.g. RGBA, P, LA) are converted to RGB before encoding.
    """

    def __init__(self, qualities):
        self.qualities = qualities  # Expects a list, e.g., [90, 85, 80]

    def __call__(self, img):
        if not self.qualities:
            return img
        q = random.choice(self.qualities)
        # Pillow's JPEG writer only accepts a few modes; saving e.g. an RGBA or
        # palette image raises, so normalize those to RGB first.
        if img.mode not in ("L", "RGB", "CMYK"):
            img = img.convert("RGB")
        output_buffer = io.BytesIO()
        img.save(output_buffer, format="JPEG", quality=q)
        output_buffer.seek(0)
        decoded = Image.open(output_buffer)
        # Image.open is lazy; decode now so any decoding error surfaces here
        # instead of inside a later transform.
        decoded.load()
        return decoded
class GaussianBlurDiscrete(object):
    """Blur a PIL image with a Gaussian whose size is drawn from a discrete set.

    Args:
        kernel_sizes: sequence of kernel sizes sampled uniformly per call,
            e.g. ``[3, 5]``. An empty (or ``None``) sequence disables the blur.

    A kernel size ``k`` is mapped to Pillow's ``GaussianBlur`` radius as
    ``(k - 1) / 2``, so ``k == 1`` yields radius 0.
    """

    def __init__(self, kernel_sizes):
        self.kernel_sizes = kernel_sizes  # e.g. [3, 5]

    def __call__(self, img):
        if self.kernel_sizes:
            kernel = random.choice(self.kernel_sizes)
            # NOTE(review): Pillow documents `radius` as the Gaussian's standard
            # deviation, so this treats the kernel half-width as sigma. That blurs
            # more strongly than a k x k kernel with an auto-derived sigma
            # (e.g. OpenCV). Confirm this matches the paper's intent.
            img = img.filter(ImageFilter.GaussianBlur((kernel - 1) / 2.0))
        return img
class AddGaussianNoise(object):
    """Add zero-mean Gaussian noise with a randomly chosen standard deviation.

    Args:
        std_list: sequence of standard deviations sampled uniformly per call,
            e.g. ``[0.005, 0.01, 0.02]``. An empty (or ``None``) sequence
            disables the transform and the input tensor is returned unchanged.

    The result is not clamped, so values may leave the input's range
    (e.g. [0, 1] after ``ToTensor``).
    """

    def __init__(self, std_list):
        self.std_list = std_list  # Expects list, e.g., [0.005, 0.01, 0.02]

    def __call__(self, tensor):
        if not self.std_list:
            return tensor
        std = random.choice(self.std_list)
        # Draw the noise on the input's device, and in its dtype when that is
        # floating point, so CUDA or half-precision inputs are never combined
        # with a CPU float32 tensor. Integer inputs receive default-dtype noise,
        # so the result is promoted to floating point.
        noise_dtype = (
            tensor.dtype if tensor.is_floating_point() else torch.get_default_dtype()
        )
        noise = torch.randn(tensor.shape, dtype=noise_dtype, device=tensor.device)
        return tensor + noise * std
class RandomResize(object):
    """Degrade resolution by downscaling with a random ratio, then upscaling back.

    Output size equals input size, so tensors stay batchable.

    Args:
        ratios: sequence of scale factors sampled uniformly per call,
            e.g. ``[0.9, 0.85, 0.8]``. An empty (or ``None``) sequence disables
            the transform and the input is returned unchanged.

    Accepts either a PIL image (bilinear resampling via ``Image.resize``) or a
    ``torch.Tensor`` with at least two dims whose last two are (H, W). The
    tensor path uses bilinear ``interpolate`` and needs a floating-point
    dtype. The scaled size is clamped to at least 1 pixel per side.
    """

    def __init__(self, ratios):
        self.ratios = ratios  # Expects list, e.g., [0.9, 0.85, 0.8]

    def __call__(self, img):
        if not self.ratios:
            return img
        ratio = random.choice(self.ratios)
        if isinstance(img, torch.Tensor):
            return self._resize_tensor(img, ratio)
        w, h = img.size
        # Resize down
        img = img.resize((self._scaled(w, ratio), self._scaled(h, ratio)), Image.BILINEAR)
        # Resize back up to original (to keep tensor shape consistent for batching)
        return img.resize((w, h), Image.BILINEAR)

    @staticmethod
    def _scaled(length, ratio):
        """Return ``int(length * ratio)`` clamped to at least one pixel."""
        return max(1, int(length * ratio))

    def _resize_tensor(self, tensor, ratio):
        """Down/up-scale the last two dims of ``tensor`` bilinearly."""
        h, w = tensor.shape[-2:]
        # Bilinear resampling is per-plane, so fold all leading dims into the
        # batch axis to accept (H, W), (C, H, W) or (N, C, H, W) inputs alike.
        planes = tensor.reshape(-1, 1, h, w)
        small = torch.nn.functional.interpolate(
            planes,
            size=(self._scaled(h, ratio), self._scaled(w, ratio)),
            mode="bilinear",
            align_corners=False,
        )
        restored = torch.nn.functional.interpolate(
            small, size=(h, w), mode="bilinear", align_corners=False
        )
        return restored.reshape(tensor.shape)
class RandomColorTransform(object):
    """Randomly rescale brightness, saturation and contrast of a PIL image.

    Each argument is a sequence of ``ImageEnhance`` factors (1.0 = unchanged).
    One factor is drawn uniformly per call from every non-empty sequence. An
    empty (or ``None``) sequence skips that adjustment. Adjustments run in the
    fixed order brightness -> saturation (``ImageEnhance.Color``) -> contrast.
    """

    def __init__(self, brightness, saturation, contrast):
        self.brightness = brightness  # e.g. [1.0, 1.1, 1.2]
        self.saturation = saturation  # e.g. [1.0, 1.2, 1.5]
        self.contrast = contrast  # e.g. [1.0, 1.2, 1.5]

    def __call__(self, img):
        # (factor sequence, ImageEnhance class name) in application order.
        # Names are resolved lazily so disabled steps never touch ImageEnhance.
        steps = (
            (self.brightness, "Brightness"),
            (self.saturation, "Color"),
            (self.contrast, "Contrast"),
        )
        for factors, enhancer_name in steps:
            if not factors:
                continue
            chosen = random.choice(factors)
            img = getattr(ImageEnhance, enhancer_name)(img).enhance(chosen)
        return img
class DualTransform(object):
    """Run two transforms on the same input and return both results as a pair.

    Calling the instance yields ``(base_transform(img), aug_transform(img))``,
    e.g. a clean target alongside a degraded input. The base transform is
    invoked first, then the augmentation. Each is called separately on the same
    ``img`` object, so any randomness inside them is drawn independently.
    """

    def __init__(self, base_transform, aug_transform):
        self.base_transform = base_transform
        self.aug_transform = aug_transform

    def __call__(self, img):
        # Generator is consumed left to right: clean target first, then augmented.
        return tuple(fn(img) for fn in (self.base_transform, self.aug_transform))
def _rar_augmented_pipeline(resize_ratios, jpeg_qualities, blur_kernels,
                            brightness, saturation, contrast, noise_stds):
    """Build the shared Table A2 degradation pipeline for one strength level."""
    return transforms.Compose([
        transforms.Resize((256, 256)),
        RandomResize(resize_ratios),
        JPEGCompression(jpeg_qualities),
        GaussianBlurDiscrete(blur_kernels),
        RandomColorTransform(
            brightness=brightness,
            saturation=saturation,
            contrast=contrast,
        ),
        transforms.ToTensor(),
        AddGaussianNoise(noise_stds),
    ])


def get_rar_transforms(epoch, total_epochs=50):
    """Return the RAR finetuning transform for ``epoch`` (Table A2, Appendix C).

    The stages are fractions of the run. ``epoch`` is 0-indexed and is shifted
    to the paper's 1-based numbering before comparison:

    - first 10% of epochs: no augmentation (resize + to-tensor only)
    - up to 20%: weak augmentation
    - up to 60%: medium augmentation
    - remainder: strong augmentation

    With the default ``total_epochs=50`` this gives:

    - Code Epoch 0-4   = Paper Epoch 1-5   (None)
    - Code Epoch 5-9   = Paper Epoch 6-10  (Weak)
    - Code Epoch 10-29 = Paper Epoch 11-30 (Medium)
    - Code Epoch 30-49 = Paper Epoch 31-50 (Strong)

    Args:
        epoch: 0-indexed current epoch.
        total_epochs: number of epochs in the finetuning run (default 50).

    Returns:
        A ``transforms.Compose`` mapping a PIL image to a tensor.

    Raises:
        ValueError: if ``total_epochs`` is not positive.
    """
    if total_epochs <= 0:
        raise ValueError(f"total_epochs must be positive, got {total_epochs}")

    epoch_num = epoch + 1  # Align with the paper's 1-based indexing
    # Stage boundaries are fractions of the whole run (total_epochs). Comparing
    # 10 * epoch_num against multiples of total_epochs keeps the 10% / 20% / 60%
    # cut-offs exact instead of relying on float products like 0.1 * n.
    tenths = 10 * epoch_num

    # 1. No Augmentation (first 10%)
    if tenths <= total_epochs:
        return transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
        ])

    # 2. Weak Augmentation (up to 20%) -- Table A2: Weak Params
    if tenths <= 2 * total_epochs:
        return _rar_augmented_pipeline(
            resize_ratios=[0.9, 0.85, 0.8],
            jpeg_qualities=[90, 85, 80],
            blur_kernels=[1, 3],
            brightness=[1.0, 1.1, 1.2],
            saturation=[1.0, 1.2, 1.5],
            contrast=[1.0, 1.2, 1.5],
            noise_stds=[0.005, 0.01, 0.02],
        )

    # 3. Medium Augmentation (up to 60%) -- Table A2: Medium Params
    if tenths <= 6 * total_epochs:
        return _rar_augmented_pipeline(
            resize_ratios=[0.8, 0.75, 0.7],
            jpeg_qualities=[80, 75, 70, 65],
            blur_kernels=[3, 5],
            brightness=[1.3, 1.4, 1.5],
            saturation=[1.5, 1.7, 2.0],
            contrast=[1.5, 1.7, 2.0],
            noise_stds=[0.02, 0.03, 0.04],
        )

    # 4. Strong Augmentation (remainder) -- Table A2: Strong Params
    return _rar_augmented_pipeline(
        resize_ratios=[0.7, 0.6, 0.5],
        jpeg_qualities=[60, 55, 50],
        blur_kernels=[5, 7, 9],
        brightness=[1.5, 1.7, 2.0],
        saturation=[2.0, 2.2, 2.5],
        contrast=[2.0, 2.2, 2.4],
        noise_stds=[0.03, 0.04, 0.05],
    )