-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
46 lines (35 loc) · 1.06 KB
/
main.py
File metadata and controls
46 lines (35 loc) · 1.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import argparse
import io
import torch
import src.serialisation.local as local_storage
from model import build_model
from src.run import run
from src.train import train
def cmd_train(save):
    """Train a new model and pass its serialised weights to ``save``.

    Args:
        save: Callable invoked once with the ``bytes`` that ``torch.save``
            produces for the trained model's ``state_dict()``.
    """
    trained = train()
    # Serialise into an in-memory buffer so the storage backend only ever
    # sees raw bytes, independent of where/how they end up persisted.
    with io.BytesIO() as sink:
        torch.save(trained.state_dict(), sink)
        payload = sink.getvalue()
    save(payload)
def cmd_run(model_data):
    """Rebuild the model from serialised weights and run it in eval mode.

    Args:
        model_data: Bytes of a ``torch.save``-serialised ``state_dict``
            (presumably the payload ``cmd_train`` hands to its ``save``
            callback — confirm against the storage backend).
    """
    model = build_model()
    # map_location="cpu": weights saved from a CUDA model would otherwise
    # fail to deserialize on a machine without CUDA. load_state_dict then
    # copies each tensor onto whatever device the built model's parameters
    # already live on, so this does not pin the model to the CPU.
    # weights_only=True restricts unpickling to tensors and plain containers,
    # since the bytes are read back from storage rather than produced in-process.
    state_dict = torch.load(
        io.BytesIO(model_data),
        map_location="cpu",
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    # eval() returns the module itself, so it can be passed straight through.
    run(model.eval())
def main():
    """Parse the command line and dispatch to the ``train`` or ``run`` subcommand."""
    parser = argparse.ArgumentParser()
    commands = parser.add_subparsers(dest="command", required=True)

    commands.add_parser("train")

    run_cmd = commands.add_parser("run")
    run_cmd.add_argument(
        "model",
        nargs="?",
        default=None,
        help="Path to model weights (default: latest in out/)",
    )

    args = parser.parse_args()

    if args.command == "train":
        cmd_train(save=local_storage.save)
        return

    if args.command == "run":
        # args.model may be None; local_storage.load is presumably responsible
        # for resolving that to the latest weights in out/ (per the help text).
        cmd_run(local_storage.load(args.model))
# Invoke the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()