Skip to content

Commit 79d676f

Browse files
committed
change effinet, transformations and add num_workers to dataloader
1 parent 8fe3c5c commit 79d676f

2 files changed

Lines changed: 5 additions & 6 deletions

File tree

{{cookiecutter.project_slug}}/MyDataset.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,25 @@
33
from torch.utils.data import DataLoader
44

55
class CifarDataModule(pl.LightningDataModule):
6-
def __init__(self, batch_size: int = 32):
6+
def __init__(self, batch_size: int = 32, num_workers: int = 4):
77
super().__init__()
88
self.batch_size = batch_size
99
self.dataset = load_dataset("cifar100").with_format("torch")
1010
self.dataset = self.dataset.rename_columns({"img": "input", "fine_label": "target"})
1111
# self.dataset = self.dataset.rename_columns({"img": "input", "coarse_label": "target"})
1212
self.dataset = self.dataset.remove_columns(["coarse_label"])
1313
# self.dataset = self.dataset.remove_columns(["fine_label"])
14+
self.num_workers = num_workers
1415

1516
def setup(self, stage: str):
1617
self.train = self.dataset["train"]
1718
self.test = self.dataset["test"]
1819

1920
def train_dataloader(self):
20-
return DataLoader(self.dataset["train"], batch_size=self.batch_size)
21+
return DataLoader(self.dataset["train"], batch_size=self.batch_size, num_workers=self.num_workers)
2122

2223
def val_dataloader(self):
23-
return DataLoader(self.dataset["test"], batch_size=self.batch_size)
24+
return DataLoader(self.dataset["test"], batch_size=self.batch_size, num_workers=self.num_workers)
2425

2526
def log_params(self):
2627
return {

{{cookiecutter.project_slug}}/models/EffiNet.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,12 @@
1919
class EffiNet(ArtModule):
2020
def __init__(self, num_classes=100, lr=1e-3):
2121
super().__init__()
22-
self.model = timm.create_model('tf_efficientnet_b4.in1k', pretrained=True)
22+
self.model = timm.create_model('efficientnet_b2.ra_in1kx ', pretrained=True, num_classes=100)
2323
self.loss = torch.nn.CrossEntropyLoss()
2424
self.lr = lr
25-
self.model.classifier = nn.Linear(self.model.classifier.in_features, num_classes)
2625
self.preprocess = transforms.Compose([
2726
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
2827
transforms.Resize(256),
29-
transforms.CenterCrop(224),
3028
])
3129

3230
def parse_data(self, data):

0 commit comments

Comments
 (0)