Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions {{cookiecutter.project_slug}}/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,9 @@ instance/
*.bak
*.tmp
# TODO: Add more patterns based on your specific needs

*.png
lightning_logs/

.neptune/
*.zip
542 changes: 542 additions & 0 deletions {{cookiecutter.project_slug}}/Tutorial.ipynb

Large diffs are not rendered by default.

Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"name": "Evaluate Baseline", "model": "AlreadyExistingResNet20Baseline", "runs": [{"scores": {"MulticlassAccuracy-validate": 0.6883000135421753}, "parameters": {}, "timestamp": "2023-12-02 11:10:00.493226", "successful": true, "hash": "189162562416660209dae3454f16a2df", "commit_id": "1b740ecf84c50de0b79fb155ef306ef622d92f08", "log_file_name": "2023-12-02_11-10-00_e6aac212-98e4-4836-9884-ba7e210a7b2a.log"}]}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"name": "Data analysis", "model": "", "runs": [{"scores": {}, "parameters": {}, "timestamp": "2023-12-02 11:08:56.045512", "successful": true, "hash": "d95a750a53f15e554caf978981ca3494", "commit_id": "1b740ecf84c50de0b79fb155ef306ef622d92f08", "number_of_classes": 100, "class_names": ["apple", "aquarium_fish", "baby", "bear", "beaver", "bed", "bee", "beetle", "bicycle", "bottle", "bowl", "boy", "bridge", "bus", "butterfly", "camel", "can", "castle", "caterpillar", "cattle", "chair", "chimpanzee", "clock", "cloud", "cockroach", "couch", "cra", "crocodile", "cup", "dinosaur", "dolphin", "elephant", "flatfish", "forest", "fox", "girl", "hamster", "house", "kangaroo", "keyboard", "lamp", "lawn_mower", "leopard", "lion", "lizard", "lobster", "man", "maple_tree", "motorcycle", "mountain", "mouse", "mushroom", "oak_tree", "orange", "orchid", "otter", "palm_tree", "pear", "pickup_truck", "pine_tree", "plain", "plate", "poppy", "porcupine", "possum", "rabbit", "raccoon", "ray", "road", "rocket", "rose", "sea", "seal", "shark", "shrew", "skunk", "skyscraper", "snail", "snake", "spider", "squirrel", "streetcar", "sunflower", "sweet_pepper", "table", "tank", "telephone", "television", "tiger", "tractor", "train", "trout", "tulip", "turtle", "wardrobe", "whale", "willow_tree", "wolf", "woman", "worm"], "number_of_examples_in_each_class": {"cattle": 500, "dinosaur": 500, "apple": 500, "boy": 500, "aquarium_fish": 500, "telephone": 500, "train": 500, "cup": 500, "cloud": 500, "elephant": 500, "keyboard": 500, "willow_tree": 500, "sunflower": 500, "castle": 500, "sea": 500, "bicycle": 500, "wolf": 500, "squirrel": 500, "shrew": 500, "pine_tree": 500, "rose": 500, "television": 500, "table": 500, "possum": 500, "oak_tree": 500, "leopard": 500, "maple_tree": 500, "rabbit": 500, "chimpanzee": 500, "clock": 500, "streetcar": 500, "cockroach": 500, "snake": 500, "lobster": 500, "mountain": 500, "palm_tree": 500, "skyscraper": 500, "tractor": 500, "shark": 500, "butterfly": 500, 
"bottle": 500, "bee": 500, "chair": 500, "woman": 500, "hamster": 500, "otter": 500, "seal": 500, "lion": 500, "mushroom": 500, "girl": 500, "sweet_pepper": 500, "forest": 500, "crocodile": 500, "orange": 500, "tulip": 500, "mouse": 500, "camel": 500, "caterpillar": 500, "man": 500, "skunk": 500, "kangaroo": 500, "raccoon": 500, "snail": 500, "rocket": 500, "whale": 500, "worm": 500, "turtle": 500, "beaver": 500, "plate": 500, "wardrobe": 500, "road": 500, "fox": 500, "flatfish": 500, "tiger": 500, "ray": 500, "dolphin": 500, "poppy": 500, "porcupine": 500, "lamp": 500, "cra": 500, "motorcycle": 500, "spider": 500, "tank": 500, "orchid": 500, "lizard": 500, "beetle": 500, "bridge": 500, "baby": 500, "lawn_mower": 500, "house": 500, "bus": 500, "couch": 500, "bowl": 500, "pear": 500, "bed": 500, "plain": 500, "trout": 500, "bear": 500, "pickup_truck": 500, "can": 500}, "img_dimensions": [32, 32, 3], "log_file_name": "2023-12-02_11-08-56_119bbd34-9342-49f3-8fd2-7e744b5689b4.log"}]}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"name": "Check Loss On Init", "model": "EfficientNet", "runs": [{"scores": {"MulticlassAccuracy-validate": 0.00675999978557229, "CrossEntropyLoss-validate": 4.893486022949219}, "parameters": {"lr": 0.001, "model_name": "EfficientNet", "n_parameters": 7841894, "batch_size": 32, "train_samples": 50000, "val_samples": 10000}, "timestamp": "2023-12-02 11:14:53.975501", "successful": true, "hash": "eef024742b33fd04192b7c9fdc0c1124", "commit_id": "1b740ecf84c50de0b79fb155ef306ef622d92f08", "log_file_name": "2023-12-02_11-14-53_43361818-6469-4547-8a82-02be2670bd52.log"}]}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"name": "Overfit One Batch", "model": "EfficientNet", "runs": [{"scores": {"MulticlassAccuracy-train": 1.0, "CrossEntropyLoss-train": 2.902682354033459e-05}, "parameters": {"number_of_steps": 40, "lr": 0.001, "model_name": "EfficientNet", "n_parameters": 7841894, "batch_size": 32, "train_samples": 50000, "val_samples": 10000}, "timestamp": "2023-12-02 11:16:49.769266", "successful": true, "hash": "eef024742b33fd04192b7c9fdc0c1124", "commit_id": "1b740ecf84c50de0b79fb155ef306ef622d92f08", "model_path": "/content/ART-Templates/{{cookiecutter.project_slug}}/lightning_logs/version_7/checkpoints/epoch=39-step=40.ckpt", "MulticlassAccuracy-train": 1.0, "CrossEntropyLoss-train": 2.902682354033459e-05, "log_file_name": "2023-12-02_11-16-49_0704e137-3c69-4515-a943-bffe0f645f44.log"}]}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"name": "Overfit", "model": "EfficientNet", "runs": [{"scores": {"MulticlassAccuracy-train": 0.9455999732017517, "CrossEntropyLoss-train": 0.16572201251983643, "MulticlassAccuracy-validate": 0.7613999843597412, "CrossEntropyLoss-validate": 1.1986972093582153}, "parameters": {"max_epochs": 10, "lr": 0.001, "model_name": "EfficientNet", "n_parameters": 7841894, "batch_size": 32, "train_samples": 50000, "val_samples": 10000}, "timestamp": "2023-12-02 11:17:52.515484", "successful": true, "hash": "eef024742b33fd04192b7c9fdc0c1124", "commit_id": "1b740ecf84c50de0b79fb155ef306ef622d92f08", "model_path": "/content/ART-Templates/{{cookiecutter.project_slug}}/lightning_logs/version_8/checkpoints/epoch=9-step=15630.ckpt", "log_file_name": "2023-12-02_11-17-52_fb7fe9c4-02bb-4a5b-8d37-bd257c6d9f85.log"}]}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"name": "TransferLearning", "model": "EfficientNet", "runs": [{"scores": {"MulticlassAccuracy-train": 0.9606000185012817, "CrossEntropyLoss-train": 0.12111776322126389, "MulticlassAccuracy-validate": 0.7610999941825867, "CrossEntropyLoss-validate": 1.254715085029602}, "parameters": {"max_epochs": 2, "check_val_every_n_epoch": 2, "lr": 0.001, "model_name": "EfficientNet", "n_parameters": 7841894, "batch_size": 32, "train_samples": 50000, "val_samples": 10000}, "timestamp": "2023-12-02 10:02:36.630585", "successful": true, "hash": "eef024742b33fd04192b7c9fdc0c1124", "commit_id": "1b740ecf84c50de0b79fb155ef306ef622d92f08", "model_path": "/content/ART-Templates/{{cookiecutter.project_slug}}/.neptune/Untitled/TRAN-21/checkpoints/epoch=13-step=21882.ckpt", "log_file_name": "2023-12-02_10-02-40_686443d1-130e-496b-ae10-7619cf9ea9dc.log"}]}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"name": "Evaluate Baseline", "model": "HeuristicBaseline", "runs": [{"scores": {"MulticlassAccuracy-validate": 0.012299999594688416}, "parameters": {}, "timestamp": "2023-12-02 11:10:00.493179", "successful": true, "hash": "06dbb6cfec9272cc2e03dc1e9f8f1846", "commit_id": "1b740ecf84c50de0b79fb155ef306ef622d92f08", "log_file_name": "2023-12-02_11-10-00_e6aac212-98e4-4836-9884-ba7e210a7b2a.log"}]}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"name": "Evaluate Baseline", "model": "MlBaseline", "runs": [{"scores": {"MulticlassAccuracy-validate": 0.15649999678134918}, "parameters": {}, "timestamp": "2023-12-02 11:10:00.493217", "successful": true, "hash": "49b251eb0a6cbddbbd5b185281399fa3", "commit_id": "1b740ecf84c50de0b79fb155ef306ef622d92f08", "log_file_name": "2023-12-02_11-10-00_e6aac212-98e4-4836-9884-ba7e210a7b2a.log"}]}
25 changes: 25 additions & 0 deletions {{cookiecutter.project_slug}}/checks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from art.checks import Check, CheckResult, ResultOfCheck
from art.steps import Step
from art.utils.savers import MatplotLibSaver


class CheckClassImagesExist(Check):
    """Verify that a saved sample-image figure exists for every class of the step."""

    def check(self, step: Step) -> ResultOfCheck:
        """Return a negative result naming the first class whose image file is missing."""
        saver = MatplotLibSaver()
        step_name = step.get_full_step_name()
        for class_name in step.get_latest_run()["class_names"]:
            image_path = step.get_class_image_path(class_name)
            if saver.exists(step_name, image_path):
                continue
            return ResultOfCheck(
                is_positive=False,
                error=f"Image for class: {class_name} does not exist. it should have been here: {saver.get_path(step_name, image_path)}",
            )
        return ResultOfCheck(is_positive=True)


class CheckLenClassNamesEqualToNumClasses(CheckResult):
    """Ensure the recorded class-name list matches the reported class count."""

    def _check_method(self, result) -> ResultOfCheck:
        """Compare len(result['class_names']) against result['number_of_classes']."""
        counts_agree = len(result["class_names"]) == result["number_of_classes"]
        if counts_agree:
            return ResultOfCheck(is_positive=True)
        return ResultOfCheck(
            is_positive=False,
            error="Number of class names is different than number of classes",
        )
40 changes: 33 additions & 7 deletions {{cookiecutter.project_slug}}/dataset.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,40 @@
from lightning import LightningDataModule
import lightning as pl
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Import lightning as L, not pl; the pl alias invites confusion with the older pytorch_lightning package.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from datasets import load_dataset
from torch.utils.data import DataLoader


class CifarDataModule(pl.LightningDataModule):
    """Lightning data module wrapping the HuggingFace CIFAR-100 dataset.

    Columns are renamed to the generic "input"/"target" names the rest of the
    pipeline expects; the coarse labels are dropped (only fine labels are used).
    """

    def __init__(self, batch_size: int = 32, num_workers: int = 4):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Download/load once at construction; torch format yields tensors per item.
        dataset = load_dataset("cifar100").with_format("torch")
        dataset = dataset.rename_columns({"img": "input", "fine_label": "target"})
        self.dataset = dataset.remove_columns(["coarse_label"])

    def setup(self, stage: str):
        # Expose the splits as attributes for convenience.
        self.train = self.dataset["train"]
        self.test = self.dataset["test"]

    def train_dataloader(self):
        return DataLoader(
            self.dataset["train"],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        # NOTE(review): the test split doubles as the validation set — confirm intended.
        return DataLoader(
            self.dataset["test"],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

    def log_params(self):
        """Parameters recorded with each experiment run."""
        return {
            "batch_size": self.batch_size,
            "train_samples": len(self.dataset["train"]),
            "val_samples": len(self.dataset["test"]),
        }
78 changes: 78 additions & 0 deletions {{cookiecutter.project_slug}}/models/EfficientNet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from typing import Dict, Any
from art.core import ArtModule
import torch
import timm
from torchvision import transforms
from einops import rearrange
from art.utils.enums import (
BATCH,
INPUT,
LOSS,
PREDICTION,
TARGET,
)


class EfficientNet(ArtModule):
    """ImageNet-pretrained EfficientNet-B2 with a `num_classes`-way classification head."""

    def __init__(self, num_classes: int = 100, lr: float = 1e-3):
        super().__init__()
        self.model = timm.create_model(
            "efficientnet_b2.ra_in1k", pretrained=True, num_classes=num_classes
        )
        self.lr = lr
        # Normalize with ImageNet statistics (the backbone was pretrained with
        # them), then upscale to the input size this particular model expects.
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )
        self.preprocess = transforms.Compose([normalize, transforms.Resize(256)])

    def parse_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """First pipeline stage: turn the raw batch into normalized model inputs.

        The result is handed to the next stage, `predict`.
        """
        batch = data[BATCH]
        # Scale to [0, 1] and convert channels-last -> channels-first.
        images = rearrange(batch[INPUT] / 255, "b h w c -> b c h w")
        return {
            INPUT: self.preprocess(images),
            TARGET: batch[TARGET].long(),
        }

    def predict(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Second stage: run the backbone on the parsed inputs."""
        logits = self.model(data[INPUT])
        return {PREDICTION: logits, TARGET: data[TARGET]}

    def compute_loss(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Last stage: select which precomputed metric serves as the loss."""
        return {LOSS: data["CrossEntropyLoss"]}

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Adam over all model parameters."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def log_params(self) -> Dict[str, Any]:
        """Hyper-parameters and model stats recorded with each run."""
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return {
            "lr": self.lr,
            "model_name": self.model.__class__.__name__,
            "n_parameters": trainable,
        }
5 changes: 0 additions & 5 deletions {{cookiecutter.project_slug}}/models/base_model.py

This file was deleted.

106 changes: 106 additions & 0 deletions {{cookiecutter.project_slug}}/models/baselines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from typing import Dict, Any

import numpy as np
from einops import rearrange
from sklearn.linear_model import LogisticRegression

from art.core import ArtModule
from art.utils.enums import BATCH, INPUT, PREDICTION, TARGET


class MlBaseline(ArtModule):
    """Classical-ML baseline: fits a scikit-learn-style classifier on flattened pixels."""

    name = "ML Baseline"

    def __init__(self, model: Any = None):
        """Wrap *model* (any sklearn-style estimator).

        Defaults to a fresh ``LogisticRegression()`` per instance. Using
        ``LogisticRegression()`` as the default argument value would create one
        shared estimator at import time, so every instance (and every fit)
        would silently mutate the same object.
        """
        super().__init__()
        self.model = LogisticRegression() if model is None else model

    def ml_parse_data(self, data: Dict):
        """Flatten the whole dataloader into (X, y) numpy arrays, scaled to [0, 1]."""
        X = []
        y = []
        for batch in data["dataloader"]:
            X.append(batch[INPUT].flatten(start_dim=1).numpy() / 255)
            y.append(batch[TARGET].numpy())

        return {INPUT: np.concatenate(X), TARGET: np.concatenate(y)}

    def baseline_train(self, data: Dict):
        """Fit the wrapped estimator on the parsed training data."""
        self.model = self.model.fit(data[INPUT], data[TARGET])
        return {"model": self.model}

    def parse_data(self, data: Dict):
        """This is first step of your pipeline it always has batch keys inside.

        Inputs are scaled by 1/255 to match the features used at training time
        (``ml_parse_data``); without this the model predicts on a different scale
        than it was fitted on.
        """
        batch = data[BATCH]
        return {
            INPUT: batch[INPUT].flatten(start_dim=1).numpy() / 255,
            TARGET: batch[TARGET].numpy(),
        }

    def predict(self, data: Dict):
        """Run the fitted estimator on a parsed batch."""
        return {PREDICTION: self.model.predict(data[INPUT]), TARGET: data[TARGET]}

    def log_params(self):
        return {"model": self.model.__class__.__name__}


class HeuristicBaseline(ArtModule):
    """Nearest-prototype heuristic: scores each class by the dot product of the
    input with that class's mean training image."""

    name = "Heuristic Baseline"
    n_classes = 100
    img_shape = (32, 32, 3)  # H x W x C of a CIFAR image

    def __init__(self):
        super().__init__()

    def parse_data(self, data: Dict):
        """This is first step of your pipeline it always has batch keys inside"""
        batch = data[BATCH]
        return {
            INPUT: batch[INPUT].flatten(start_dim=1).numpy(),
            TARGET: batch[TARGET].numpy(),
        }

    def baseline_train(self, data: Dict):
        """Accumulate one mean (prototype) image per class from the dataloader."""
        n_features = self.img_shape[0] * self.img_shape[1] * self.img_shape[2]
        self.prototypes = np.zeros((self.n_classes, n_features))
        self.counts = np.zeros(self.n_classes)
        for batch in data["dataloader"]:
            for img, label in zip(batch[INPUT], batch[TARGET]):
                self.prototypes[label.item()] += img.flatten().numpy() / 255
                self.counts[label.item()] += 1

        # Guard against classes absent from the loader: dividing by a zero count
        # would fill that prototype row with NaN and poison every argmax below.
        self.prototypes = self.prototypes / np.maximum(self.counts, 1)[:, None]

    def predict(self, data: Dict):
        """Pick the class whose prototype has the largest dot product with the input.

        Inputs here are unscaled (0-255) while prototypes were built from scaled
        images; this only multiplies every class score by the same factor, so the
        argmax is unaffected.
        """
        y_hat = np.argmax(data[INPUT] @ self.prototypes.T, axis=1)
        return {PREDICTION: y_hat, TARGET: data[TARGET]}

    def log_params(self):
        return {"model": "Heuristic"}


class AlreadyExistingResNet20Baseline(ArtModule):
    """Off-the-shelf CIFAR-100 ResNet-20 fetched from torch.hub."""

    name = "Already Existing ResNet20 Baseline"

    def __init__(self):
        super().__init__()
        import torch  # local import: torch is only needed here to fetch the hub model

        self.model = torch.hub.load(
            "chenyaofo/pytorch-cifar-models", "cifar100_resnet20", pretrained=True
        )

    def parse_data(self, data: Dict):
        """Scale to [0, 1], normalize with CIFAR-100 statistics, go channels-first."""
        batch = data[BATCH]
        cifar_mean = np.asarray([0.5071, 0.4867, 0.4408], dtype=np.float32)
        cifar_std = np.asarray([0.2675, 0.2565, 0.2761], dtype=np.float32)
        normalized = (batch[INPUT] / 255 - cifar_mean) / cifar_std
        return {
            INPUT: rearrange(normalized, "b h w c -> b c h w"),
            TARGET: batch[TARGET],
        }

    def predict(self, data: Dict):
        """Forward pass; predictions are detached to plain numpy arrays."""
        logits = self.model(data[INPUT])
        return {PREDICTION: logits.detach().numpy(), TARGET: data[TARGET]}

    def log_params(self):
        return {"model": self.model.__class__.__name__}
16 changes: 0 additions & 16 deletions {{cookiecutter.project_slug}}/run.py

This file was deleted.

57 changes: 57 additions & 0 deletions {{cookiecutter.project_slug}}/steps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from collections import Counter
import numpy as np
import matplotlib.pyplot as plt
from art.utils.savers import MatplotLibSaver
from art.steps import ExploreData
from art.utils.enums import INPUT, TARGET


class DataAnalysis(ExploreData):
    """Explore the datamodule: class inventory, per-class counts, image
    dimensions, and a saved 5-sample figure for every class."""

    def do(self, previous_states):
        # Build the dataloader once instead of re-creating it on every access.
        dataloader = self.datamodule.train_dataloader()
        dataset = dataloader.dataset
        label_feature = self.datamodule.dataset["train"].features[TARGET]

        # Collect every target of the train split as a human-readable name.
        targets = []
        for batch in dataloader:
            targets.extend(batch[TARGET])
        targets = [label_feature.int2str(int(x)) for x in targets]

        number_of_classes = len(np.unique(targets))
        class_names = list(label_feature.names)
        class_counts = Counter(targets)
        # Dimensions of a single image (as stored: H x W x C).
        img_dimensions = dataset[0][INPUT].shape

        for cls in class_names:
            class_indices = [i for i, label in enumerate(targets) if label == cls]
            # Sample with replacement only when a class has fewer than 5 examples;
            # otherwise np.random.choice(..., replace=False) would raise.
            class_samples = np.random.choice(
                class_indices, 5, replace=len(class_indices) < 5
            ).tolist()

            fig, axes = plt.subplots(1, 5, figsize=(15, 5))
            for i, sample_idx in enumerate(class_samples):
                img = dataset[sample_idx][INPUT]
                axes[i].imshow(img, cmap="gray")
                axes[i].set_title(f"Class: {cls}")
                axes[i].axis("off")

            MatplotLibSaver().save(
                fig, self.get_full_step_name(), self.get_class_image_path(cls)
            )
            # Release the figure — otherwise ~100 open figures pile up in memory.
            plt.close(fig)

        self.results.update(
            {
                "number_of_classes": number_of_classes,
                "class_names": class_names,
                "number_of_examples_in_each_class": class_counts,
                "img_dimensions": img_dimensions,
            }
        )

    def log_params(self):
        return {}

    def get_class_image_path(self, class_name: str):
        """Relative path under which the sample figure for *class_name* is saved."""
        return f"class_images/class_{class_name}.png"