-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluate.py
More file actions
94 lines (74 loc) · 2.99 KB
/
evaluate.py
File metadata and controls
94 lines (74 loc) · 2.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import dataclasses
import lightning.pytorch as pl
import torch
import tyro
import wandb
from lightning.pytorch.loggers import WandbLogger
from pape.configs import Checkpoint
from pape.configs import Config
from pape.configs import Split
from pape.constants import WANDB_PROJECT
from pape.data import load_datamodule
from pape.models import load_model
from pape.paths import get_experiment_dir
def main(
name: str,
/,
batch_size: int = 1,
checkpoint: Checkpoint = Checkpoint.best,
num_workers: int = 4,
split: Split = Split.val,
detection_max_det: int | None = None,
detection_min_conf: float | None = None,
) -> None:
torch.set_float32_matmul_precision("high")
config, run_id, weights_path = load_checkpoint(name, checkpoint)
overrides = {"batch_size": batch_size, "num_workers": num_workers}
if detection_max_det is not None or detection_min_conf is not None:
detection_overrides = {}
if detection_max_det is not None:
detection_overrides["max_det"] = detection_max_det
if detection_min_conf is not None:
detection_overrides["min_conf"] = detection_min_conf
overrides["detection"] = dataclasses.replace(config.detection, **detection_overrides)
config = dataclasses.replace(config, **overrides)
datamodule = load_datamodule(config)
lightning_model = load_model(config)
logger = WandbLogger(
id=run_id,
prefix="eval",
project=WANDB_PROJECT,
settings=wandb.Settings(_disable_stats=True),
)
trainer = pl.Trainer(
callbacks=[],
logger=logger,
precision=config.train.precision,
)
match split:
case Split.val:
trainer.validate(lightning_model, datamodule=datamodule, ckpt_path=weights_path)
case Split.test:
trainer.test(lightning_model, datamodule=datamodule, ckpt_path=weights_path)
case _:
raise ValueError(f"Unsupported evaluation split: {split}")
def load_checkpoint(name: str, checkpoint: Checkpoint) -> tuple[Config, str, str]:
    """Resolve an experiment's config, W&B run id, and checkpoint weights path.

    Args:
        name: Experiment name used to locate the experiment directory.
        checkpoint: Preferred checkpoint kind. Falls back to the other kind
            if the preferred checkpoint file does not exist on disk.

    Returns:
        A ``(config, run_id, weights_path)`` tuple.

    Raises:
        ValueError: If the experiment directory does not exist or no
            checkpoint files are found.
    """
    experiment_dir = get_experiment_dir(name)
    if not experiment_dir.exists():
        raise ValueError(f"Could not find experiment directory {experiment_dir}.")
    # strip() guards against a trailing newline in run_id.txt, which would
    # otherwise corrupt the "{project}/{run_id}" wandb API path below.
    run_id = (experiment_dir / "run_id.txt").read_text().strip()
    # get model parameters from wandb run with the given name
    api = wandb.Api()
    run = api.run(f"{WANDB_PROJECT}/{run_id}")
    config = Config.from_dict(run.config)
    checkpoints_dir = experiment_dir / "checkpoints"
    best_path = checkpoints_dir / f"{Checkpoint.best.value}.ckpt"
    last_path = checkpoints_dir / f"{Checkpoint.last.value}.ckpt"
    if not best_path.exists() and not last_path.exists():
        raise ValueError(f"Could not find any checkpoints in {checkpoints_dir}.")
    if checkpoint == Checkpoint.best:
        weights_path = best_path if best_path.exists() else last_path
    else:
        weights_path = last_path if last_path.exists() else best_path
    # Convert to str to honor the declared return annotation; the values
    # above are pathlib.Path objects produced by the "/" operator.
    return config, run_id, str(weights_path)
# Guard the CLI entry point so importing this module (e.g. for testing or
# reuse of load_checkpoint) does not immediately parse argv and run.
if __name__ == "__main__":
    tyro.cli(main)