-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain.py
More file actions
93 lines (79 loc) · 3.42 KB
/
main.py
File metadata and controls
93 lines (79 loc) · 3.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# main.py
import argparse
from torch.utils.data import DataLoader
from smokenet.config import load_config
from smokenet.data.collate import smoke_collate_fn
from smokenet.data.loader import load_datasets
from smokenet.train import train
from smokenet.utils.logging import setup_logging
def main():
    """CLI entry point: parse args, load config, build loaders, dispatch train/eval.

    Reads a YAML config (``--config``), applies any CLI overrides on top of it
    (warning when an override differs from the configured value), constructs
    train/val DataLoaders, and runs training. Eval mode is not implemented yet
    and only logs what would have been loaded.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", choices=["train", "eval"], default="train")
    parser.add_argument(
        "--config", default="config/default.yaml", help="Path to config YAML file"
    )
    parser.add_argument(
        "--model", default=None, help="Path to model checkpoint (eval mode only)"
    )
    parser.add_argument("--batch-size", type=int, help="Override training batch size")
    parser.add_argument("--num-epochs", type=int, help="Override training epochs")
    parser.add_argument(
        "--learning-rate", type=float, help="Override optimizer learning rate"
    )
    parser.add_argument("--device", type=str, help="Override training device")
    parser.add_argument("--window-size", type=int, help="Override data window size")
    args = parser.parse_args()

    data_cfg, model_cfg, train_cfg = load_config(args.config)
    logger = setup_logging(train_cfg.output_root)
    logger.info("Loaded config from %s", args.config)

    def warn_override(cfg, attr, value, label: str):
        """Set ``cfg.attr = value`` unless value is None; warn if it differs."""
        if value is None:
            return
        old = getattr(cfg, attr)
        if value != old:
            logger.warning("Override %s: %s -> %s", label, old, value)
        # Always apply (even when equal) so the config reflects the CLI exactly.
        setattr(cfg, attr, value)

    # CLI overrides, applied uniformly: (target config, attribute, CLI value).
    overrides = [
        (train_cfg, "batch_size", args.batch_size),
        (train_cfg, "num_epochs", args.num_epochs),
        (train_cfg, "learning_rate", args.learning_rate),
        (train_cfg, "device", args.device),
        (data_cfg, "window_size", args.window_size),
    ]
    for cfg, attr, value in overrides:
        warn_override(cfg, attr, value, attr)

    # model accepts raw sensor channels; windowing handled by dataset
    model_cfg.in_channels = data_cfg.channels

    train_dataset, val_dataset = load_datasets(data_cfg)
    logger.info(
        "Loaded datasets with %d training samples and %d validation samples",
        len(train_dataset),
        len(val_dataset),
    )

    # Subset wrappers expose the underlying dataset via `.dataset`; unwrap so we
    # can read dataset-level flags like `fuel_available`.
    base_dataset = (
        train_dataset.dataset if hasattr(train_dataset, "dataset") else train_dataset
    )
    fuel_available = getattr(base_dataset, "fuel_available", False)
    # Fuel classification requires both the model option and dataset support.
    fuel_enabled = model_cfg.enable_fuel_classification and fuel_available
    logger.info("Fuel classification enabled: %s", fuel_enabled)

    train_loader = DataLoader(
        train_dataset,
        batch_size=train_cfg.batch_size,
        shuffle=True,
        collate_fn=smoke_collate_fn,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=train_cfg.batch_size,
        shuffle=False,  # deterministic order for validation
        collate_fn=smoke_collate_fn,
    )

    if args.mode == "train":
        if args.model:
            logger.warning("--model only used in eval mode, ignored in train mode.")
        # Return values (model, history) are unused here; train() handles
        # checkpointing/output via train_cfg.
        train(train_loader, val_loader, model_cfg, train_cfg, fuel_enabled)
    else:
        if not args.model:
            logger.warning("No --model provided, cannot load weights in eval mode.")
        else:
            logger.warning(
                "Eval mode not implemented, skip loading model: %s", args.model
            )
# Script entry point: only run when executed directly, not when imported.
if __name__ == "__main__":
    main()