-
Notifications
You must be signed in to change notification settings - Fork 0
Cv transfer learning tutorial #34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
mmaecki
wants to merge
10
commits into
main
Choose a base branch
from
cv_transfer_learning_tutorial
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
9385b83
data analysis
mmaecki 8b63fbe
loss on init
mmaecki a5f845a
cuda fix
mmaecki fb2834b
lightning logs delete + predict fix
mmaecki 5644898
transfer learning step
mmaecki c041a10
update on transfer learning + normalization fix
mmaecki 4c10a65
Adding thext for the tutorial + small cleanup
mmaecki 1b740ec
1st version of tutorial + new model
mmaecki aed727d
Tutorial almost done. Waiting for some change suggestions.
mmaecki be883fa
Final version
mmaecki File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
Binary file removed
BIN
-145 Bytes
{{cookiecutter.project_slug}}/__pycache__/__init__.cpython-310.pyc
Binary file not shown.
1 change: 1 addition & 0 deletions
1
...ect_slug}}/art_checkpoints/AlreadyExistingResNet20Baseline_Evaluate Baseline/results.json
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| {"name": "Evaluate Baseline", "model": "AlreadyExistingResNet20Baseline", "runs": [{"scores": {"MulticlassAccuracy-validate": 0.6883000135421753}, "parameters": {}, "timestamp": "2023-12-02 11:10:00.493226", "successful": true, "hash": "189162562416660209dae3454f16a2df", "commit_id": "1b740ecf84c50de0b79fb155ef306ef622d92f08", "log_file_name": "2023-12-02_11-10-00_e6aac212-98e4-4836-9884-ba7e210a7b2a.log"}]} |
1 change: 1 addition & 0 deletions
1
{{cookiecutter.project_slug}}/art_checkpoints/Data analysis/results.json
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| {"name": "Data analysis", "model": "", "runs": [{"scores": {}, "parameters": {}, "timestamp": "2023-12-02 11:08:56.045512", "successful": true, "hash": "d95a750a53f15e554caf978981ca3494", "commit_id": "1b740ecf84c50de0b79fb155ef306ef622d92f08", "number_of_classes": 100, "class_names": ["apple", "aquarium_fish", "baby", "bear", "beaver", "bed", "bee", "beetle", "bicycle", "bottle", "bowl", "boy", "bridge", "bus", "butterfly", "camel", "can", "castle", "caterpillar", "cattle", "chair", "chimpanzee", "clock", "cloud", "cockroach", "couch", "cra", "crocodile", "cup", "dinosaur", "dolphin", "elephant", "flatfish", "forest", "fox", "girl", "hamster", "house", "kangaroo", "keyboard", "lamp", "lawn_mower", "leopard", "lion", "lizard", "lobster", "man", "maple_tree", "motorcycle", "mountain", "mouse", "mushroom", "oak_tree", "orange", "orchid", "otter", "palm_tree", "pear", "pickup_truck", "pine_tree", "plain", "plate", "poppy", "porcupine", "possum", "rabbit", "raccoon", "ray", "road", "rocket", "rose", "sea", "seal", "shark", "shrew", "skunk", "skyscraper", "snail", "snake", "spider", "squirrel", "streetcar", "sunflower", "sweet_pepper", "table", "tank", "telephone", "television", "tiger", "tractor", "train", "trout", "tulip", "turtle", "wardrobe", "whale", "willow_tree", "wolf", "woman", "worm"], "number_of_examples_in_each_class": {"cattle": 500, "dinosaur": 500, "apple": 500, "boy": 500, "aquarium_fish": 500, "telephone": 500, "train": 500, "cup": 500, "cloud": 500, "elephant": 500, "keyboard": 500, "willow_tree": 500, "sunflower": 500, "castle": 500, "sea": 500, "bicycle": 500, "wolf": 500, "squirrel": 500, "shrew": 500, "pine_tree": 500, "rose": 500, "television": 500, "table": 500, "possum": 500, "oak_tree": 500, "leopard": 500, "maple_tree": 500, "rabbit": 500, "chimpanzee": 500, "clock": 500, "streetcar": 500, "cockroach": 500, "snake": 500, "lobster": 500, "mountain": 500, "palm_tree": 500, "skyscraper": 500, "tractor": 500, "shark": 500, "butterfly": 500, "bottle": 500, "bee": 500, "chair": 500, "woman": 500, "hamster": 500, "otter": 500, "seal": 500, "lion": 500, "mushroom": 500, "girl": 500, "sweet_pepper": 500, "forest": 500, "crocodile": 500, "orange": 500, "tulip": 500, "mouse": 500, "camel": 500, "caterpillar": 500, "man": 500, "skunk": 500, "kangaroo": 500, "raccoon": 500, "snail": 500, "rocket": 500, "whale": 500, "worm": 500, "turtle": 500, "beaver": 500, "plate": 500, "wardrobe": 500, "road": 500, "fox": 500, "flatfish": 500, "tiger": 500, "ray": 500, "dolphin": 500, "poppy": 500, "porcupine": 500, "lamp": 500, "cra": 500, "motorcycle": 500, "spider": 500, "tank": 500, "orchid": 500, "lizard": 500, "beetle": 500, "bridge": 500, "baby": 500, "lawn_mower": 500, "house": 500, "bus": 500, "couch": 500, "bowl": 500, "pear": 500, "bed": 500, "plain": 500, "trout": 500, "bear": 500, "pickup_truck": 500, "can": 500}, "img_dimensions": [32, 32, 3], "log_file_name": "2023-12-02_11-08-56_119bbd34-9342-49f3-8fd2-7e744b5689b4.log"}]} |
1 change: 1 addition & 0 deletions
1
{{cookiecutter.project_slug}}/art_checkpoints/EfficientNet_Check Loss On Init/results.json
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| {"name": "Check Loss On Init", "model": "EfficientNet", "runs": [{"scores": {"MulticlassAccuracy-validate": 0.00675999978557229, "CrossEntropyLoss-validate": 4.893486022949219}, "parameters": {"lr": 0.001, "model_name": "EfficientNet", "n_parameters": 7841894, "batch_size": 32, "train_samples": 50000, "val_samples": 10000}, "timestamp": "2023-12-02 11:14:53.975501", "successful": true, "hash": "eef024742b33fd04192b7c9fdc0c1124", "commit_id": "1b740ecf84c50de0b79fb155ef306ef622d92f08", "log_file_name": "2023-12-02_11-14-53_43361818-6469-4547-8a82-02be2670bd52.log"}]} |
1 change: 1 addition & 0 deletions
1
{{cookiecutter.project_slug}}/art_checkpoints/EfficientNet_Overfit One Batch/results.json
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| {"name": "Overfit One Batch", "model": "EfficientNet", "runs": [{"scores": {"MulticlassAccuracy-train": 1.0, "CrossEntropyLoss-train": 2.902682354033459e-05}, "parameters": {"number_of_steps": 40, "lr": 0.001, "model_name": "EfficientNet", "n_parameters": 7841894, "batch_size": 32, "train_samples": 50000, "val_samples": 10000}, "timestamp": "2023-12-02 11:16:49.769266", "successful": true, "hash": "eef024742b33fd04192b7c9fdc0c1124", "commit_id": "1b740ecf84c50de0b79fb155ef306ef622d92f08", "model_path": "/content/ART-Templates/{{cookiecutter.project_slug}}/lightning_logs/version_7/checkpoints/epoch=39-step=40.ckpt", "MulticlassAccuracy-train": 1.0, "CrossEntropyLoss-train": 2.902682354033459e-05, "log_file_name": "2023-12-02_11-16-49_0704e137-3c69-4515-a943-bffe0f645f44.log"}]} |
1 change: 1 addition & 0 deletions
1
{{cookiecutter.project_slug}}/art_checkpoints/EfficientNet_Overfit/results.json
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| {"name": "Overfit", "model": "EfficientNet", "runs": [{"scores": {"MulticlassAccuracy-train": 0.9455999732017517, "CrossEntropyLoss-train": 0.16572201251983643, "MulticlassAccuracy-validate": 0.7613999843597412, "CrossEntropyLoss-validate": 1.1986972093582153}, "parameters": {"max_epochs": 10, "lr": 0.001, "model_name": "EfficientNet", "n_parameters": 7841894, "batch_size": 32, "train_samples": 50000, "val_samples": 10000}, "timestamp": "2023-12-02 11:17:52.515484", "successful": true, "hash": "eef024742b33fd04192b7c9fdc0c1124", "commit_id": "1b740ecf84c50de0b79fb155ef306ef622d92f08", "model_path": "/content/ART-Templates/{{cookiecutter.project_slug}}/lightning_logs/version_8/checkpoints/epoch=9-step=15630.ckpt", "log_file_name": "2023-12-02_11-17-52_fb7fe9c4-02bb-4a5b-8d37-bd257c6d9f85.log"}]} |
1 change: 1 addition & 0 deletions
1
{{cookiecutter.project_slug}}/art_checkpoints/EfficientNet_TransferLearning/results.json
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| {"name": "TransferLearning", "model": "EfficientNet", "runs": [{"scores": {"MulticlassAccuracy-train": 0.9606000185012817, "CrossEntropyLoss-train": 0.12111776322126389, "MulticlassAccuracy-validate": 0.7610999941825867, "CrossEntropyLoss-validate": 1.254715085029602}, "parameters": {"max_epochs": 2, "check_val_every_n_epoch": 2, "lr": 0.001, "model_name": "EfficientNet", "n_parameters": 7841894, "batch_size": 32, "train_samples": 50000, "val_samples": 10000}, "timestamp": "2023-12-02 10:02:36.630585", "successful": true, "hash": "eef024742b33fd04192b7c9fdc0c1124", "commit_id": "1b740ecf84c50de0b79fb155ef306ef622d92f08", "model_path": "/content/ART-Templates/{{cookiecutter.project_slug}}/.neptune/Untitled/TRAN-21/checkpoints/epoch=13-step=21882.ckpt", "log_file_name": "2023-12-02_10-02-40_686443d1-130e-496b-ae10-7619cf9ea9dc.log"}]} |
1 change: 1 addition & 0 deletions
1
...kiecutter.project_slug}}/art_checkpoints/HeuristicBaseline_Evaluate Baseline/results.json
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| {"name": "Evaluate Baseline", "model": "HeuristicBaseline", "runs": [{"scores": {"MulticlassAccuracy-validate": 0.012299999594688416}, "parameters": {}, "timestamp": "2023-12-02 11:10:00.493179", "successful": true, "hash": "06dbb6cfec9272cc2e03dc1e9f8f1846", "commit_id": "1b740ecf84c50de0b79fb155ef306ef622d92f08", "log_file_name": "2023-12-02_11-10-00_e6aac212-98e4-4836-9884-ba7e210a7b2a.log"}]} |
1 change: 1 addition & 0 deletions
1
{{cookiecutter.project_slug}}/art_checkpoints/MlBaseline_Evaluate Baseline/results.json
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| {"name": "Evaluate Baseline", "model": "MlBaseline", "runs": [{"scores": {"MulticlassAccuracy-validate": 0.15649999678134918}, "parameters": {}, "timestamp": "2023-12-02 11:10:00.493217", "successful": true, "hash": "49b251eb0a6cbddbbd5b185281399fa3", "commit_id": "1b740ecf84c50de0b79fb155ef306ef622d92f08", "log_file_name": "2023-12-02_11-10-00_e6aac212-98e4-4836-9884-ba7e210a7b2a.log"}]} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,25 @@ | ||
| from art.checks import Check, CheckResult, ResultOfCheck | ||
| from art.steps import Step | ||
| from art.utils.savers import MatplotLibSaver | ||
|
|
||
|
|
||
| class CheckClassImagesExist(Check): | ||
| def check(self, step: Step) -> ResultOfCheck: | ||
| for class_name in step.get_latest_run()["class_names"]: | ||
| image_path = step.get_class_image_path(class_name) | ||
| if not MatplotLibSaver().exists(step.get_full_step_name(), image_path): | ||
| return ResultOfCheck( | ||
| is_positive=False, | ||
| error=f"Image for class: {class_name} does not exist. it should have been here: {MatplotLibSaver().get_path(step.get_full_step_name(), image_path)}", | ||
| ) | ||
| return ResultOfCheck(is_positive=True) | ||
|
|
||
|
|
||
| class CheckLenClassNamesEqualToNumClasses(CheckResult): | ||
| def _check_method(self, result) -> ResultOfCheck: | ||
| if len(result["class_names"]) != result["number_of_classes"]: | ||
| return ResultOfCheck( | ||
| is_positive=False, | ||
| error="Number of class names is different than number of classes", | ||
| ) | ||
| return ResultOfCheck(is_positive=True) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,14 +1,40 @@ | ||
| from lightning import LightningDataModule | ||
| import lightning as pl | ||
| from datasets import load_dataset | ||
| from torch.utils.data import DataLoader | ||
|
|
||
|
|
||
| class MyDataModule(LightningDataModule): | ||
| def __init__(self): | ||
| super().__init__(...) | ||
| self.dataset = ... | ||
| class CifarDataModule(pl.LightningDataModule): | ||
| def __init__(self, batch_size: int = 32, num_workers: int = 4): | ||
| super().__init__() | ||
| self.batch_size = batch_size | ||
| self.dataset = load_dataset("cifar100").with_format("torch") | ||
| self.dataset = self.dataset.rename_columns( | ||
| {"img": "input", "fine_label": "target"} | ||
| ) | ||
| self.dataset = self.dataset.remove_columns(["coarse_label"]) | ||
| self.num_workers = num_workers | ||
|
|
||
| def setup(self, stage: str): | ||
| self.train = self.dataset["train"] | ||
| self.test = self.dataset["test"] | ||
|
|
||
| def train_dataloader(self): | ||
| return DataLoader(...) | ||
| return DataLoader( | ||
| self.dataset["train"], | ||
| batch_size=self.batch_size, | ||
| num_workers=self.num_workers, | ||
| ) | ||
|
|
||
| def val_dataloader(self): | ||
| return DataLoader(...) | ||
| return DataLoader( | ||
| self.dataset["test"], | ||
| batch_size=self.batch_size, | ||
| num_workers=self.num_workers, | ||
| ) | ||
|
|
||
| def log_params(self): | ||
| return { | ||
| "batch_size": self.batch_size, | ||
| "train_samples": len(self.dataset["train"]), | ||
| "val_samples": len(self.dataset["test"]), | ||
| } | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,78 @@ | ||
| from typing import Dict, Any | ||
| from art.core import ArtModule | ||
| import torch | ||
| import timm | ||
| from torchvision import transforms | ||
| from einops import rearrange | ||
| from art.utils.enums import ( | ||
| BATCH, | ||
| INPUT, | ||
| LOSS, | ||
| PREDICTION, | ||
| TARGET, | ||
| ) | ||
|
|
||
|
|
||
| class EfficientNet(ArtModule): | ||
| def __init__(self, num_classes: int = 100, lr: float = 1e-3): | ||
| super().__init__() | ||
| self.model = timm.create_model( | ||
| "efficientnet_b2.ra_in1k", pretrained=True, num_classes=num_classes | ||
| ) | ||
| self.lr = lr | ||
| self.preprocess = transforms.Compose( | ||
| [ | ||
| transforms.Normalize( | ||
| mean=[0.485, 0.456, 0.406], | ||
| std=[0.229, 0.224, 0.225], # statistics of ImageNet dataset | ||
| ), | ||
| transforms.Resize(256), # Size desired by this particular model | ||
| ] | ||
| ) | ||
|
|
||
| def parse_data(self, data: Dict[str, Any]) -> Dict[str, Any]: | ||
| """ | ||
| This is first step of your pipeline it always has batch keys inside | ||
| The result of this step is passed to the next step in the pipeline which is predict | ||
| """ | ||
| X = data[BATCH][INPUT] | ||
| X = X / 255 | ||
| X = rearrange(X, "b h w c -> b c h w") | ||
| X = self.preprocess(X) | ||
| target = data[BATCH][TARGET].long() | ||
| return {INPUT: X, TARGET: target} | ||
|
|
||
| def predict(self, data: Dict[str, Any]) -> Dict[str, Any]: | ||
| """ | ||
| This is the second step of your pipeline. The input of this step is the output of the previous step. | ||
| You should return a dictionary with PREDICTION and TARGET keys. | ||
| """ | ||
| return {PREDICTION: self.model(data[INPUT]), TARGET: data[TARGET]} | ||
|
|
||
| def compute_loss(self, data: Dict[str, Any]) -> Dict[str, Any]: | ||
| """ | ||
| This is the last step of your pipeline. The input of this step is the output of the previous step. | ||
| You should return a dictionary with LOSS key. | ||
| You only need to specify which loss (metric) we want to use. | ||
| """ | ||
| loss = data["CrossEntropyLoss"] | ||
| return {LOSS: loss} | ||
|
|
||
| def configure_optimizers(self) -> torch.optim.Optimizer: | ||
| """ | ||
| Set up your optimizer. | ||
| """ | ||
| return torch.optim.Adam(self.parameters(), lr=self.lr) | ||
|
|
||
| def log_params(self) -> Dict[str, Any]: | ||
| """ | ||
| This is a method for logging relevant parameters. | ||
| It has to be implemented, however, it can be empty. | ||
| """ | ||
| return { | ||
| "lr": self.lr, | ||
| "model_name": self.model.__class__.__name__, | ||
| "n_parameters": sum( | ||
| p.numel() for p in self.parameters() if p.requires_grad | ||
| ), | ||
| } |
This file was deleted.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,106 @@ | ||
| from typing import Dict, Any | ||
|
|
||
| import numpy as np | ||
| from einops import rearrange | ||
| from sklearn.linear_model import LogisticRegression | ||
|
|
||
| from art.core import ArtModule | ||
| from art.utils.enums import BATCH, INPUT, PREDICTION, TARGET | ||
|
|
||
|
|
||
| class MlBaseline(ArtModule): | ||
| name = "ML Baseline" | ||
|
|
||
| def __init__(self, model: Any = LogisticRegression()): | ||
| super().__init__() | ||
| self.model = model | ||
|
|
||
| def ml_parse_data(self, data: Dict): | ||
| X = [] | ||
| y = [] | ||
| for batch in data["dataloader"]: | ||
| X.append(batch[INPUT].flatten(start_dim=1).numpy() / 255) | ||
| y.append(batch[TARGET].numpy()) | ||
|
|
||
| return {INPUT: np.concatenate(X), TARGET: np.concatenate(y)} | ||
|
|
||
| def baseline_train(self, data: Dict): | ||
| self.model = self.model.fit(data[INPUT], data[TARGET]) | ||
| return {"model": self.model} | ||
|
|
||
| def parse_data(self, data: Dict): | ||
| """This is first step of your pipeline it always has batch keys inside""" | ||
| batch = data[BATCH] | ||
| return { | ||
| INPUT: batch[INPUT].flatten(start_dim=1).numpy(), | ||
| TARGET: batch[TARGET].numpy(), | ||
| } | ||
|
|
||
| def predict(self, data: Dict): | ||
| return {PREDICTION: self.model.predict(data[INPUT]), TARGET: data[TARGET]} | ||
|
|
||
| def log_params(self): | ||
| return {"model": self.model.__class__.__name__} | ||
|
|
||
|
|
||
| class HeuristicBaseline(ArtModule): | ||
| name = "Heuristic Baseline" | ||
| n_classes = 100 | ||
| img_shape = (32, 32, 3) | ||
|
|
||
| def __init__(self): | ||
| super().__init__() | ||
|
|
||
| def parse_data(self, data: Dict): | ||
| """This is first step of your pipeline it always has batch keys inside""" | ||
| batch = data[BATCH] | ||
| return { | ||
| INPUT: batch[INPUT].flatten(start_dim=1).numpy(), | ||
| TARGET: batch[TARGET].numpy(), | ||
| } | ||
|
|
||
| def baseline_train(self, data: Dict): | ||
| self.prototypes = np.zeros( | ||
| (self.n_classes, self.img_shape[0] * self.img_shape[1] * self.img_shape[2]) | ||
| ) | ||
| self.counts = np.zeros(self.n_classes) | ||
| for batch in data["dataloader"]: | ||
| for img, label in zip(batch[INPUT], batch[TARGET]): | ||
| self.prototypes[label.item()] += img.flatten().numpy() / 255 | ||
| self.counts[label.item()] += 1 | ||
|
|
||
| self.prototypes = self.prototypes / self.counts[:, None] | ||
|
|
||
| def predict(self, data: Dict): | ||
| y_hat = np.argmax((data[INPUT] @ self.prototypes.T), axis=1) | ||
| return {PREDICTION: y_hat, TARGET: data[TARGET]} | ||
|
|
||
| def log_params(self): | ||
| return {"model": "Heuristic"} | ||
|
|
||
|
|
||
| class AlreadyExistingResNet20Baseline(ArtModule): | ||
| name = "Already Existing ResNet20 Baseline" | ||
|
|
||
| def __init__(self): | ||
| super().__init__() | ||
| import torch | ||
|
|
||
| self.model = torch.hub.load( | ||
| "chenyaofo/pytorch-cifar-models", "cifar100_resnet20", pretrained=True | ||
| ) | ||
|
|
||
| def parse_data(self, data: Dict): | ||
| mean = np.asarray([0.5071, 0.4867, 0.4408], dtype=np.float32) | ||
| std = np.asarray([0.2675, 0.2565, 0.2761], dtype=np.float32) | ||
| X = data[BATCH][INPUT] | ||
| X = (X / 255 - mean) / std | ||
| X = rearrange(X, "b h w c -> b c h w") | ||
| return {INPUT: X, TARGET: data[BATCH][TARGET]} | ||
|
|
||
| def predict(self, data: Dict): | ||
| preds = self.model(data[INPUT]).detach().numpy() | ||
| return {PREDICTION: preds, TARGET: data[TARGET]} | ||
|
|
||
| def log_params(self): | ||
| return {"model": self.model.__class__.__name__} |
This file was deleted.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,57 @@ | ||
| from collections import Counter | ||
| import numpy as np | ||
| import matplotlib.pyplot as plt | ||
| from art.utils.savers import MatplotLibSaver | ||
| from art.steps import ExploreData | ||
| from art.utils.enums import INPUT, TARGET | ||
|
|
||
|
|
||
| class DataAnalysis(ExploreData): | ||
| def do(self, previous_states): | ||
| targets = [] | ||
| index2label = ( | ||
| lambda x: self.datamodule.dataset["train"].features[TARGET].int2str(x) | ||
| ) | ||
| # Loop through batches in the cifar_datamodule train dataloader | ||
| for batch in self.datamodule.train_dataloader(): | ||
| targets.extend(batch[TARGET]) | ||
| targets = [index2label(int(x)) for x in targets] | ||
| # Calculate the number of unique classes in the targets | ||
| number_of_classes = len(np.unique(targets)) | ||
| # Now tell me what are the names of these classes | ||
| class_names = list(self.datamodule.dataset["train"].features[TARGET].names) | ||
|
|
||
| # Now calculate number of images in each class | ||
| class_counts = Counter(targets) | ||
|
|
||
| # Now tell me dimensions of each image | ||
| img_dimensions = self.datamodule.train_dataloader().dataset[0][INPUT].shape | ||
| for cls in class_names: | ||
| class_indices = [i for i, label in enumerate(targets) if label == cls] | ||
| class_samples = np.random.choice(class_indices, 5, replace=False).tolist() | ||
|
|
||
| fig, axes = plt.subplots(1, 5, figsize=(15, 5)) | ||
| for i, sample_idx in enumerate(class_samples): | ||
| img = self.datamodule.train_dataloader().dataset[sample_idx][INPUT] | ||
| axes[i].imshow(img, cmap="gray") | ||
| axes[i].set_title(f"Class: {cls}") | ||
| axes[i].axis("off") | ||
|
|
||
| MatplotLibSaver().save( | ||
| fig, self.get_full_step_name(), self.get_class_image_path(cls) | ||
| ) | ||
|
|
||
| self.results.update( | ||
| { | ||
| "number_of_classes": number_of_classes, | ||
| "class_names": class_names, | ||
| "number_of_examples_in_each_class": class_counts, | ||
| "img_dimensions": img_dimensions, | ||
| } | ||
| ) | ||
|
|
||
| def log_params(self): | ||
| return {} | ||
|
|
||
| def get_class_image_path(self, class_name: str): | ||
| return f"class_images/class_{class_name}.png" |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Import lightning as L, not pl; it confuses it with the older version
PyTorch LightningThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lightning-AI/pytorch-lightning#16688