3 changes: 1 addition & 2 deletions Custom_dataset.py
@@ -20,7 +20,7 @@ def __getitem__(self, index):
img_path = os.path.join(self.root_dir, "images", self.annotations.iloc[index, 0])
with Image.open(img_path).convert('RGB') as image:
y_wild_label = self.annotations.iloc[index, 2]
if y_wild_label is not None:
if y_wild_label == y_wild_label:
y_label = int(y_wild_label)
else:
y_label = 0
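The `y_wild_label == y_wild_label` test is the usual NaN self-comparison idiom: pandas reads empty CSV cells as `float('nan')`, which is not `None`, so the previous check let missing labels through to `int()` and crashed. A minimal sketch of the idiom (the helper name is illustrative, not from this repo):

```python
import math

def safe_label(y_wild_label, default=0):
    # NaN is the only value that compares unequal to itself, so this test
    # is False exactly when the CSV cell was empty.
    if y_wild_label == y_wild_label:
        return int(y_wild_label)
    return default

assert safe_label(7.0) == 7
assert safe_label(math.nan) == 0
```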
@@ -51,7 +51,6 @@ def __len__(self):
def __getitem__(self, index):
original_index = math.floor(index / self.samples4category / self.split) * self.samples4category
samples_missing = len(self.original_dataset) - original_index

if not self.from_bottom:
if samples_missing < self.samples4category:
# we are at the end of the dataset
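A quick numeric illustration of the index remapping above, with illustrative values for `samples4category` and `split` (they are whatever the split dataset was constructed with):

```python
import math

samples4category, split = 1000, 0.5  # illustrative values only
for index in (0, 499, 500, 999):
    original_index = math.floor(index / samples4category / split) * samples4category
    print(index, "->", original_index)
# 0 -> 0, 499 -> 0, 500 -> 1000, 999 -> 1000
```

Each consecutive block of `samples4category * split` indices maps to the start of a new category block in the original dataset.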
4 changes: 2 additions & 2 deletions datasets.py
@@ -12,7 +12,7 @@
# The mean and variance used for the normalization
KNOWN_NORMALIZATION = {'CIFAR10': ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
'CIFAR100': ((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
'CDON': ((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))} # todo: tune the values for CDON
'CDON': ((0.0036, 0.0038, 0.0040), (0.0043, 0.0046, 0.0044))}
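The CIFAR-100 placeholder statistics for CDON are replaced with measured per-channel values here. A minimal sketch of how such normalization statistics are typically computed, assuming the images are available as one float tensor scaled to [0, 1] (names are illustrative, not this repo's code):

```python
import torch

def channel_stats(images: torch.Tensor):
    # images: (N, 3, H, W) float tensor scaled to [0, 1]
    mean = images.mean(dim=(0, 2, 3))
    std = images.std(dim=(0, 2, 3))
    return tuple(mean.tolist()), tuple(std.tolist())

# KNOWN_NORMALIZATION['CDON'] would then be set to the returned (mean, std) pair.
```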


class FastTensorDataLoader:
@@ -127,7 +127,7 @@ def load_cifar_dataset(dataset_name, batch_size=128, noise_rate=0.0, is_symmetri
transforms.ToTensor(),
transforms.Normalize(*KNOWN_NORMALIZATION[dataset_name]),
])

if dataset_name == "CIFAR10":
train_data = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_data = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
29 changes: 14 additions & 15 deletions main.py
@@ -227,7 +227,6 @@ def record_results(filepath, dataset, noise_rate, is_symmetric_noise, enable_amp
'correct': correct[-1], 'memorized': memorized[-1], 'incorrect': incorrect[-1]
})


def model_pipeline(config, trainer_config, loadExistingWeights=False):
# Start wandb
wandb_project = 'resnet-ce-cdon'
@@ -285,21 +284,21 @@ def main():
config = dict(
n_epochs=120,
batch_size=128,
classes=10,
noise_rate=0.4,
classes=64,  # 157 categories for clothing; total subcategories is 3516
noise_rate=0.0,
is_symmetric_noise=True,
fraction=1.0,
compute_memorization=True,
dataset_name='CIFAR10', # opt: 'CIFAR10', 'CIFAR100', 'CDON' (not implemented)
model_path='./models/CIFAR10_20.mdl',
plot_path='./results/CIFAR10_20',
compute_memorization=False,
dataset_name='CDON', # opt: 'CIFAR10', 'CIFAR100', 'CDON'
model_path='./models/CDON_CE.mdl',
plot_path='./results/CDON_CE',
learning_rate=0.02,
momentum=0.9,
weight_decay=1e-3,
milestones=[40, 80],
gamma=0.01,
enable_amp=True,
use_ELR=True,
use_ELR=False,
elr_lambda=3.0,
elr_beta=0.7
)
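The config keeps the ELR hyper-parameters (`elr_lambda`, `elr_beta`) even though `use_ELR` is now False. For reference, a minimal sketch of the standard early-learning regularization term these control (Liu et al., 2020); an illustration only, not this repository's implementation:

```python
import torch
import torch.nn.functional as F

class ELRPenalty(torch.nn.Module):
    """Cross-entropy plus the ELR term; `indices` identify each sample so its
    running target estimate can be updated with momentum elr_beta."""
    def __init__(self, num_examples, num_classes, elr_lambda=3.0, elr_beta=0.7):
        super().__init__()
        self.register_buffer("targets", torch.zeros(num_examples, num_classes))
        self.elr_lambda, self.elr_beta = elr_lambda, elr_beta

    def forward(self, logits, labels, indices):
        probs = F.softmax(logits, dim=1).clamp(1e-4, 1.0 - 1e-4)
        probs = probs / probs.sum(dim=1, keepdim=True)
        with torch.no_grad():
            # temporal ensembling of the per-sample targets
            self.targets[indices] = (self.elr_beta * self.targets[indices]
                                     + (1.0 - self.elr_beta) * probs)
        elr_reg = torch.log(1.0 - (self.targets[indices] * probs).sum(dim=1)).mean()
        return F.cross_entropy(logits, labels) + self.elr_lambda * elr_reg
```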
@@ -338,12 +337,12 @@ def main():
'criterion_params': {}
}

# use_CosAnneal = {
# 'scheduler': optim.lr_scheduler.CosineAnnealingWarmRestarts,
# 'scheduler_params': {"T_0": 10, "eta_min": 0.001},
# # 'scheduler_params': {'T_max': 200, 'eta_min': 0.001}
# }
# trainer_config.update(use_CosAnneal)
use_CosAnneal = {
'scheduler': optim.lr_scheduler.CosineAnnealingWarmRestarts,
'scheduler_params': {"T_0": 10, "eta_min": 0.001},
# 'scheduler_params': {'T_max': 200, 'eta_min': 0.001}
}
trainer_config.update(use_CosAnneal)
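With `T_0=10` and `eta_min=0.001`, this scheduler anneals the learning rate from its initial value down to 0.001 and restarts every ten epochs. A minimal usage sketch, assuming one `scheduler.step()` per epoch (the model and optimizer here are illustrative):

```python
import torch
import torch.optim as optim

model = torch.nn.Linear(8, 2)
optimizer = optim.SGD(model.parameters(), lr=0.02, momentum=0.9)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, eta_min=0.001)

for epoch in range(30):
    # ... one training epoch ...
    scheduler.step()  # cosine decay toward eta_min, restarting every T_0 epochs
```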

if config['use_ELR']:
use_ELR = {
@@ -356,4 +355,4 @@


if __name__ == '__main__':
main()
main()
121 changes: 63 additions & 58 deletions scripts/image_downloader.py
@@ -1,28 +1,30 @@
"""
This is a script that takes data from various CSV files, then downloads and stores the images.
It assumes the CSV files have the following columns: name, thumbnail link, date, img link, subcategory id, subcategory.
The saved dataset will now have an equal distribution among the subcategories.

The final dataset can be stored according to two different conventions:
one that is more pytorch friendly: saves the images as individual files inside a folder and the labels in a CSV
the other is more numpy friendly: the images are stored as a numpy array inside a "pickled" file, together with the labels

the other is more numpy friendly: the images are stored as a numpy array inside a "pickled" file, together with the labels
Details of the numpy version:
The final dataset is stored in the final_filename{i} files according to CIFAR semantics:
The archive contains the files dataset1, dataset2, ... Each of these files is a Python "pickled" object produced with Pickle.
The archive contains the files dataset1, dataset2, ... Each of these files is a Python "pickled" object produced with Pickle.
Here is a python3 routine which will open such a file and return a dictionary:

def unpickle(file):
import pickle
with open(file, 'rb') as fo:
dict = pickle.load(fo, encoding='bytes')
return dict

Loaded in this way, each of the batch files contains a dictionary with the following elements:
data -- a 10000x3072 numpy array of uint8s. Each row of the array stores a 32x32 colour image.
data -- a 1000x3072 numpy array of uint8s. Each row of the array stores a 32x32 colour image.
The first 1024 entries contain the red channel values, the next 1024 the green, and the final 1024 the blue.
The image is stored in row-major order, so that the first 32 entries of the array are the red channel values of the first row of the image.
labels -- a list of 10000 numbers in the range 1-21. The number at index i indicates the label of the ith image in the array data.
sublabels -- a list of 10000 numbers in the range 1-?. The number at index i indicates the sublabel (detailed category) of the ith image in the array data.
labels -- a list of 1000 numbers in the range 1-21. The number at index i indicates the label of the ith image in the array data.
sublabels -- a list of 1000 numbers in the range 1-?. The number at index i indicates the sublabel (detailed category) of the ith image in the array data.
The sublabel is an ID. It therefore needs enumeration before use.

Details of the pytorch version:
The images will all be stored inside the folder 'images' (that will be created if non-existent) with their original name (assumed to be unique)
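A minimal sketch of consuming this PyTorch-friendly layout, reading the [image_name, label, sublabel] rows the downloader writes (the labels file name and folder are illustrative, not taken from this repo):

```python
import os
import pandas as pd
from PIL import Image

dataset_folder = "./dataset/"                      # illustrative path
annotations = pd.read_csv(os.path.join(dataset_folder, "labels.csv"), header=None)

img_name, label, sublabel = annotations.iloc[0]    # one row per stored image
image = Image.open(os.path.join(dataset_folder, "images", img_name)).convert("RGB")
```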
@@ -68,43 +70,45 @@ class DownloaderConfig():
dataset_folder: str = "./dataset/" # the folder in which to find the csv files and in which to store the dataset
date_accepted: int = 2010 # date from which to save a product
final_filename: str = "dataset" # the base name for the dataset that will be stored
percentage2download: int = 10 # percentage of dataset to download (not guaranteed)
samples4file: int = 10000 # save data every...
percentage2download: int = 100 # percentage of dataset to download (not guaranteed)
samples4category: int = 1000 # the number of samples to store for each category (if a category has fewer samples, it will not be stored)
labels_writer: Any = None # used in PYTORCH_FRIENDLY mode; it is a CSV writer for the labels file
final_dataset_path: Callable = None # used in NUMPY_FRIENDLY mode to generate the name of the next dataset file

GET_CATEGORY = lambda x: x.split("/")[-1].split("_")[0] # get the category from the filename
collected_samples = 0 # the number of samples correctly stored
categories_written = 0 # the number of categories written

def store_data(config, data, labels, sublabels):
def store_data(config, data):
"""Store the given data, based on the config object

Args:
config (DownloaderConfig): an object with all the useful configs
data (list): the images stored as bytes
labels (list): the labels relative to each image corresponding to the category of the product
sublabels (list): the sublabels relative to each image corresponding to the subcategory
data (list): a list of tuples ((image, image_name), label, sublabel)
"""
global categories_written
if config.store_format == STORE_FORMAT.NUMPY_FRIENDLY:
# store as numpy array
images2write = np.array([np.array(ImageOps.pad(Image.open(io.BytesIO(x)).convert('RGB'), (32, 32))) for x, _ in data])
labels = []
sublabels = []
images2write = []
for ((image, image_name), label, sublabel) in data:
images2write.append(np.array(ImageOps.pad(Image.open(io.BytesIO(image)).convert('RGB'), (32, 32))))
labels.append(label)
sublabels.append(sublabel)

with open(config.final_dataset_path(), "wb") as data_file:
dict2write = {"data": images2write, "labels": np.array(labels), "sublabels": np.array(sublabels)}
pickle.dump(dict2write, data_file, protocol=pickle.HIGHEST_PROTOCOL)
data.clear()
labels.clear()
sublabels.clear()
else:
for (image, image_name), label, sublabel in zip(data, labels, sublabels):
for ((image, image_name), label, sublabel) in data:
with open(config.dataset_folder + "images/" + image_name, "wb") as file:
file.write(image)
config.labels_writer.writerow([image_name, label, sublabel])
categories_written += 1





def download_file_images(file, data, labels, sublabels, config, categories, lock):
global collected_samples
def download_file_images(file, config, categories, lock, subcat):
headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:59.0) Gecko/20100101 Firefox/59.0'
} # needed, otherwise the request hangs
@@ -120,36 +124,41 @@ def download_file_images(file, data, labels, sublabels, config, categories, lock
# check date
if not date or int(date.split("-")[0]) > config.date_accepted:
thumb_url = row[Tags.THUMBNAIL.value]
try:
answer = requests.get(thumb_url, headers=headers) # download thumbnail
if answer.status_code <= 400:
image = answer.content
with lock:
data.append((image, thumb_url.split("/")[-1]))
labels.append(categories[GET_CATEGORY(file)])
sublabels.append(row[Tags.SUB_ID.value])
collected_samples += 1

# check if collected enough samples
if collected_samples % config.samples4file == 0:
store_data(config, data, labels, sublabels) # todo: could be improved as operation can be done without lock

if collected_samples % 100 == 0:
print(f'Collected {collected_samples} samples')
else:
print(f"Received a non-success code {answer.status_code} when crawling:")
sublabel = row[Tags.SUB_ID.value]
download_image = True
with lock:
if not sublabel in subcat:
subcat[sublabel] = []
download_image = len(subcat[sublabel]) < config.samples4category
if download_image:
try:
answer = requests.get(thumb_url, headers=headers) # download thumbnail
if answer.status_code <= 400:
image = answer.content
data = (image, thumb_url.split("/")[-1])
label = categories[GET_CATEGORY(file)]

with lock:
if len(subcat[sublabel]) < config.samples4category:
subcat[sublabel].append((data, label, sublabel))
if len(subcat[sublabel]) == config.samples4category:
store_data(config, subcat[sublabel]) # todo: could be improved as operation can be done without lock
print(f'Storing {sublabel} subcat samples')
else:
print(f"Received a non-success code {answer.status_code} when crawling:")
print(thumb_url)
sleep(10)
except Exception as e: # todo: bad as it catches all the exceptions
print('Caught the following exception when crawling: ', e)
print(thumb_url)
sleep(10)
except Exception as e: # todo: bad as it catches all the exceptions
print('Caught the following exception when crawling: ', e)
print(thumb_url)
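Note on the locking pattern above: the per-subcategory quota is checked under the lock before the HTTP request and re-checked under the lock before appending, because the download itself happens outside the lock and another thread may have filled the quota in the meantime.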


def multithread_image_download(config, max_threads):
all_files = listdir(config.dataset_folder)
all_files.sort()
csv_files = []
categories = {}
subcat = {}

for file in all_files:
if not file.endswith(".csv"):
@@ -158,12 +167,8 @@ def multithread_image_download(config, max_threads):
if not GET_CATEGORY(file) in categories:
categories[GET_CATEGORY(file)] = len(categories)

data = []
labels = []
sublabels = []

common_lock = RLock()

labels_csv_file = None
if config.store_format == STORE_FORMAT.PYTORCH_FRIENDLY:
if not exists(config.dataset_folder + "images"):
@@ -174,9 +179,9 @@ def multithread_image_download(config, max_threads):
# files are in the dataset folder
with ThreadPoolExecutor(max_workers=max_threads) as executor:
for file in reversed(csv_files):
executor.submit(download_file_images, file, data, labels, sublabels, config, categories, common_lock,)
executor.submit(download_file_images, file, config, categories, common_lock, subcat)

store_data(config, data, labels, sublabels)
# store_data(config, data, labels, sublabels)
if labels_csv_file:
labels_csv_file.close()

@@ -191,17 +196,17 @@ def multithread_image_download(config, max_threads):
parser.add_argument('--threads', help="the max number of threads running at the same time, default: uncapped")
parser.add_argument('--dataset-percentage', help="the percentage of the dataset to download, default is 100")
args = vars(parser.parse_args())

# set arguments based on the parsed ones
if args['format'] == "pytorch":
downloader_config.store_format = STORE_FORMAT.PYTORCH_FRIENDLY
elif args['format'] and args['format'] != "numpy":
print('The format you provided is not valid.')
parser.print_help()
exit()

if args['folder']:
downloader_config.dataset_folder = args['csv-folder']
downloader_config.dataset_folder = args['folder']
if downloader_config.dataset_folder[-1] != "/":
downloader_config.dataset_folder += "/"
max_threads = None
@@ -210,9 +215,9 @@ def multithread_image_download(config, max_threads):

if args['dataset_percentage']:
downloader_config.percentage2download = int(args['dataset_percentage'])

# define the lambda used in NUMPY_FRIENDLY to name the dataset files
downloader_config.final_dataset_path = lambda: f'{downloader_config.dataset_folder}{downloader_config.final_filename}' +\
f'{int(np.ceil(collected_samples / downloader_config.samples4file))}'
f'{categories_written}'

multithread_image_download(downloader_config, max_threads)
multithread_image_download(downloader_config, max_threads)
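An example invocation of the downloader, using only the flags handled above (values are illustrative): python scripts/image_downloader.py --format pytorch --folder ./dataset --threads 8 --dataset-percentage 100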
57 changes: 57 additions & 0 deletions scripts/subcat_id2subcat_name.py
@@ -0,0 +1,57 @@
"""
A simple script that generates a CSV file linking the subcategory IDs to their names.
This information is gathered from the various CSV files.
It is assumed that the CSVs have the following columns: name, thumbnail link, date, img link, subcategory id, subcategory.
"""

import csv
from os import listdir
from os.path import join
from enum import Enum
import argparse

class Tags(Enum):
NAME = 0
THUMBNAIL = 1
DATE = 2
IMG = 3
SUB_ID = 4
SUB = 5


def generate_subcat2names(dataset_folder = '.'):
all_files = listdir(dataset_folder)
all_files.sort()
csv_files = []

for file in all_files:
if not file.endswith(".csv"):
continue
csv_files.append(file)

subcat2names = {}

for file in csv_files:
with open(join(dataset_folder, file), encoding='utf-8', errors='ignore') as csvfile: # ignore problems with strange characters
reader = csv.reader((x.replace('\0', '') for x in csvfile), delimiter=",")
for row in reader:
if len(row) > Tags.SUB.value and row[Tags.SUB_ID.value]:
if not int(row[Tags.SUB_ID.value]) in subcat2names:
subcat2names[int(row[Tags.SUB_ID.value])] = row[Tags.SUB.value]

with open('subcat2names.csv', 'w', newline='') as csvfile:
spamwriter = csv.writer(csvfile, delimiter=',')
for (a, b) in subcat2names.items():
spamwriter.writerow([a, b])


if __name__ == "__main__":
# accept key parameter as args
parser = argparse.ArgumentParser(description='Creates a csv file with the link between the subcategories IDs and their names')
parser.add_argument('--folder', help='the folder in which to find the csv files from which to extract the subcategories, default is "."')
args = vars(parser.parse_args())

if args['folder']:
folder = args['folder']
generate_subcat2names(folder)
else:
generate_subcat2names()
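An example invocation, assuming the exported CSV files live in ./dataset: python scripts/subcat_id2subcat_name.py --folder ./dataset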