1+ from typing import Dict
2+ from art .core import ArtModule
3+ import torch
4+ import timm
5+ import torch .nn as nn
6+ from torchvision import transforms
7+ import numpy as np
8+ from einops import rearrange
9+ from art .utils .enums import (
10+ BATCH ,
11+ INPUT ,
12+ LOSS ,
13+ PREDICTION ,
14+ TARGET ,
15+ TRAIN_LOSS ,
16+ VALIDATION_LOSS ,
17+ )
18+
class EffiNet(ArtModule):
    """EfficientNet-B2 image classifier wrapped as an ART pipeline module.

    Wraps a pretrained ``efficientnet_b2.ra_in1k`` from timm and implements the
    ART module hooks: ``parse_data`` -> ``predict`` -> ``compute_loss``, plus
    optimizer configuration and hyper-parameter logging.

    Args:
        num_classes: Size of the classification head (default 100).
        lr: Learning rate for the Adam optimizer.
    """

    def __init__(self, num_classes=100, lr=1e-3):
        super().__init__()
        # BUG FIX: ``num_classes`` was previously hard-coded to 100 in the
        # create_model call, silently ignoring the constructor argument.
        self.model = timm.create_model(
            'efficientnet_b2.ra_in1k', pretrained=True, num_classes=num_classes
        )
        # Kept for backward compatibility; the actual loss value consumed in
        # compute_loss() comes from the framework's MetricsCalculator.
        self.loss = torch.nn.CrossEntropyLoss()
        self.lr = lr
        # NOTE(review): Normalize runs *before* Resize here; standard ImageNet
        # pipelines resize first. Order kept as-is to preserve behavior —
        # confirm this is intentional.
        self.preprocess = transforms.Compose([
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
            transforms.Resize(256),
        ])

    def parse_data(self, data):
        """First pipeline step: scale, reorder, and normalize the raw batch.

        Assumes ``data[BATCH][INPUT]`` is an image tensor laid out as
        (batch, height, width, channels) with values in [0, 255] — TODO confirm
        against the dataloader.
        """
        x = data[BATCH][INPUT]
        x = x / 255  # scale pixel values to [0, 1] before ImageNet normalization
        x = rearrange(x, "b h w c -> b c h w")  # channels-first for timm/torchvision
        x = self.preprocess(x)
        # CrossEntropyLoss expects int64 class indices.
        target = data[BATCH][TARGET].long()
        return {INPUT: x, TARGET: target}

    def predict(self, data: Dict):
        """Forward pass; targets are passed through for metric computation."""
        return {PREDICTION: self.model(data[INPUT]), TARGET: data[TARGET]}

    def compute_loss(self, data):
        """Select the loss to optimize.

        The loss value is computed upstream by the MetricsCalculator; this hook
        only picks the desired metric out of ``data`` by name.
        """
        return {LOSS: data["CrossEntropyLoss"]}

    def configure_optimizers(self):
        """Adam over all trainable parameters at the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def log_params(self):
        """Return the hyper-parameters and model stats to record for the run."""
        return {
            "lr": self.lr,
            "model_name": self.model.__class__.__name__,
            "n_parameters": sum(
                p.numel() for p in self.parameters() if p.requires_grad
            ),
        }