Skip to content

Commit 1b740ec

Browse files
committed
1st version of tutorial + new model
1 parent 4c10a65 commit 1b740ec

5 files changed

Lines changed: 477 additions & 62 deletions

File tree

{{cookiecutter.project_slug}}/MyDataset.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,25 @@
33
from torch.utils.data import DataLoader
44

55
class CifarDataModule(pl.LightningDataModule):
6-
def __init__(self, batch_size: int = 32):
6+
def __init__(self, batch_size: int = 32, num_workers: int = 4):
    """Load CIFAR-100 and store the settings shared by the data loaders.

    The ``img`` column is renamed to ``input`` and ``fine_label`` to
    ``target``; the ``coarse_label`` column is dropped. To train on the
    coarse labels instead, rename ``coarse_label`` to ``target`` and drop
    ``fine_label``.

    Args:
        batch_size: Batch size handed to every ``DataLoader``.
        num_workers: Worker count handed to every ``DataLoader``.
    """
    super().__init__()
    self.batch_size = batch_size
    self.num_workers = num_workers

    cifar = load_dataset("cifar100").with_format("torch")
    cifar = cifar.rename_columns({"img": "input", "fine_label": "target"})
    self.dataset = cifar.remove_columns(["coarse_label"])
1415

1516
def setup(self, stage: str):
    """Expose the ``"train"`` and ``"test"`` splits as attributes.

    Args:
        stage: Stage name supplied by the caller; not used here.
    """
    splits = self.dataset
    self.train, self.test = splits["train"], splits["test"]
1819

1920
def train_dataloader(self):
    """Return a ``DataLoader`` over the ``"train"`` split."""
    train_split = self.dataset["train"]
    return DataLoader(
        train_split,
        batch_size=self.batch_size,
        num_workers=self.num_workers,
    )
2122

2223
def val_dataloader(self):
    """Return a ``DataLoader`` over the ``"test"`` split, used for validation."""
    held_out_split = self.dataset["test"]
    return DataLoader(
        held_out_split,
        batch_size=self.batch_size,
        num_workers=self.num_workers,
    )
2425

2526
def log_params(self):
2627
return {

{{cookiecutter.project_slug}}/Tutorial.ipynb

Lines changed: 411 additions & 47 deletions
Large diffs are not rendered by default.
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from typing import Dict
2+
from art.core import ArtModule
3+
import torch
4+
import timm
5+
import torch.nn as nn
6+
from torchvision import transforms
7+
import numpy as np
8+
from einops import rearrange
9+
from art.utils.enums import (
10+
BATCH,
11+
INPUT,
12+
LOSS,
13+
PREDICTION,
14+
TARGET,
15+
TRAIN_LOSS,
16+
VALIDATION_LOSS,
17+
)
18+
19+
class EffiNet(ArtModule):
    """EfficientNet-B2 (timm, pretrained) classifier wrapped as an ``ArtModule``."""

    def __init__(self, num_classes=100, lr=1e-3):
        """Build the backbone, loss and preprocessing pipeline.

        Args:
            num_classes: Size of the classification head.
            lr: Learning rate used by the Adam optimizer.
        """
        super().__init__()
        # Forward num_classes instead of hard-coding 100, so the head
        # actually follows the constructor argument.
        self.model = timm.create_model(
            "efficientnet_b2.ra_in1k", pretrained=True, num_classes=num_classes
        )
        self.loss = torch.nn.CrossEntropyLoss()
        self.lr = lr
        # Per-channel normalization followed by a resize to 256.
        # NOTE(review): the mean/std look like the usual ImageNet statistics,
        # presumably chosen to match the pretrained weights - confirm.
        self.preprocess = transforms.Compose([
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            transforms.Resize(256),
        ])

    def parse_data(self, data):
        """First step of the pipeline; ``data`` always contains the BATCH key.

        Returns:
            A dict with the preprocessed images under ``INPUT`` and the
            labels, cast to ``long``, under ``TARGET``.
        """
        X = data[BATCH][INPUT]
        # NOTE(review): dividing by 255 presumably maps 8-bit pixel values
        # into [0, 1] - confirm against the datamodule output.
        X = X / 255
        # Move channels from the last axis to the second one.
        X = rearrange(X, "b h w c -> b c h w")
        X = self.preprocess(X)
        target = data[BATCH][TARGET].long()
        return {INPUT: X, TARGET: target}

    def predict(self, data: Dict):
        """Run the model on ``data[INPUT]`` and pass the target through."""
        return {PREDICTION: self.model(data[INPUT]), TARGET: data[TARGET]}

    def compute_loss(self, data):
        """Pick the loss value used for optimization out of ``data``."""
        # Notice that the loss calculation is done in MetricsCalculator!
        # We only need to specify which loss (metric) we want to use
        loss = data["CrossEntropyLoss"]
        return {LOSS: loss}

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def log_params(self):
        """Return the parameters worth logging for this model."""
        return {
            "lr": self.lr,
            "model_name": self.model.__class__.__name__,
            # Count only parameters that will receive gradients.
            "n_parameters": sum(p.numel() for p in self.parameters() if p.requires_grad),
        }

{{cookiecutter.project_slug}}/models/ResNet.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
VALIDATION_LOSS,
1616
)
1717

18-
class ResNet18(ArtModule):
18+
class ResNet(ArtModule):
1919
def __init__(self, num_classes=100, lr=1e-3):
2020
super().__init__()
2121
self.model = torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models', 'resnet18_swsl')
@@ -27,8 +27,6 @@ def __init__(self, num_classes=100, lr=1e-3):
2727
transforms.Resize(256),
2828
transforms.CenterCrop(224),
2929
])
30-
# for name, para in self.model.named_parameters():
31-
# para.requires_grad = True
3230

3331
def parse_data(self, data):
3432
"""This is first step of your pipeline it always has batch keys inside"""

{{cookiecutter.project_slug}}/steps.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,11 @@ def do(self, previous_states):
2020
# Now tell me what are the names of these classes
2121
class_names = list(self.datamodule.dataset["train"].features[TARGET].names)
2222

23-
class_counts = Counter(targets)
24-
2523
# Now calculate number of images in each class
26-
number_of_examples_in_each_class = [
27-
class_counts[i] for i in range(number_of_classes)
28-
]
24+
class_counts = Counter(targets)
2925

3026
# Now tell me dimensions of each image
3127
img_dimensions = self.datamodule.train_dataloader().dataset[0][INPUT].shape
32-
figures = []
3328
for cls in class_names:
3429
class_indices = [i for i, label in enumerate(targets) if label == cls]
3530
class_samples = np.random.choice(class_indices, 5, replace=False).tolist()
@@ -47,15 +42,13 @@ def do(self, previous_states):
4742
MatplotLibSaver().save(
4843
fig, self.get_full_step_name(), self.get_class_image_path(cls)
4944
)
50-
figures.append(fig)
5145

5246
self.results.update(
5347
{
5448
"number_of_classes": number_of_classes,
5549
"class_names": class_names,
56-
"number_of_examples_in_each_class": number_of_examples_in_each_class,
50+
"number_of_examples_in_each_class": class_counts,
5751
"img_dimensions": img_dimensions,
58-
"images": figures,
5952
}
6053
)
6154
def log_params(self):

0 commit comments

Comments
 (0)