Skip to content

Commit fb2834b

Browse files
committed
lightning logs delete + predict fix
1 parent a5f845a commit fb2834b

3 files changed

Lines changed: 19 additions & 8 deletions

File tree

{{cookiecutter.project_slug}}/.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,4 +104,3 @@ instance/
104104
*.png
105105
art_checkpoints/
106106
lightning_logs/
107-

{{cookiecutter.project_slug}}/EDA.ipynb

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,19 @@
129129
"project.run_all()"
130130
]
131131
},
132+
{
133+
"cell_type": "code",
134+
"execution_count": null,
135+
"metadata": {},
136+
"outputs": [],
137+
"source": [
138+
"from art.steps import OverfitOneBatch\n",
139+
"from art.checks import CheckScoreLessThan\n",
140+
"project.add_step(OverfitOneBatch(ResNet18, number_of_steps=40),\n",
141+
" [CheckScoreLessThan(metric=ce_loss, value=0.05)])\n",
142+
"project.run_all()"
143+
]
144+
},
132145
{
133146
"cell_type": "code",
134147
"execution_count": null,

{{cookiecutter.project_slug}}/models/ResNet.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,23 +27,22 @@ def __init__(self, num_classes=100, lr=1e-3):
2727
transforms.Resize(256),
2828
transforms.CenterCrop(224),
2929
])
30+
# for name, para in self.model.named_parameters():
31+
# para.requires_grad = True
3032

3133
def parse_data(self, data):
3234
"""This is first step of your pipeline it always has batch keys inside"""
3335
X = data[BATCH][INPUT]
3436
X = X / 255
3537
X = rearrange(X, "b h w c -> b c h w")
3638
X = self.preprocess(X)
37-
38-
return {INPUT: X, TARGET: data[BATCH][TARGET]}
39+
target = data[BATCH][TARGET].long()
40+
return {INPUT: X, TARGET: target}
3941

4042

4143

42-
def predict(self, data: Dict):
43-
# preds = self.model(data[INPUT]).detach().numpy()
44-
preds = self.model(data[INPUT]).detach()
45-
# perhaps softmax will be needed
46-
return {PREDICTION: preds, TARGET: data[TARGET]}
44+
def predict(self, data: Dict):
45+
return {PREDICTION: self.model(data[INPUT]), TARGET: data[TARGET]}
4746

4847
def compute_loss(self, data):
4948
# Notice that the loss calculation is done in MetricsCalculator!

0 commit comments

Comments
 (0)