-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain.py
More file actions
executable file
·147 lines (120 loc) · 3.75 KB
/
train.py
File metadata and controls
executable file
·147 lines (120 loc) · 3.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
MultiGridDet Training Script
Train object detection models with YAML configuration.
Usage:
python train.py --config configs/train_config.yaml
python train.py --config configs/train_config.yaml --weights weights/model5.h5
python train.py --config configs/train_config.yaml --backbone-weights weights/darknet53.h5
python train.py --config configs/train_config.yaml --resume
"""
import argparse
import sys
from pathlib import Path
# Add project root to Python path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
from multigriddet.config import ConfigLoader
from multigriddet.trainers import MultiGridTrainer
def _positive_int(text):
    """Parse *text* as an integer >= 1; intended as an argparse ``type=`` callable.

    Raises:
        argparse.ArgumentTypeError: if *text* is not an integer or is < 1.
            argparse reports this as a usage error and exits with status 2.
    """
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def parse_args(argv=None):
    """Parse command-line arguments.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, i.e. the original behaviour.

    Returns:
        argparse.Namespace with ``config``, ``weights``, ``backbone_weights``,
        ``resume``, ``epochs`` and ``batch_size``. The optional overrides are
        ``None`` when not given on the command line.
    """
    parser = argparse.ArgumentParser(
        description='Train MultiGridDet model',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        '--config',
        type=str,
        default='configs/train_config.yaml',
        help='Path to training config file'
    )
    parser.add_argument(
        '--weights',
        type=str,
        default=None,
        help='Path to pretrained full model weights (overrides config)'
    )
    parser.add_argument(
        '--backbone-weights',
        type=str,
        default=None,
        help='Path to pretrained backbone weights (e.g., darknet53.h5)'
    )
    parser.add_argument(
        '--resume',
        action='store_true',
        help='Resume training from checkpoint'
    )
    # Epochs and batch size must be positive; reject 0 or negative values here
    # rather than passing them into the training config.
    parser.add_argument(
        '--epochs',
        type=_positive_int,
        default=None,
        help='Number of epochs (overrides config)'
    )
    parser.add_argument(
        '--batch-size',
        type=_positive_int,
        default=None,
        help='Batch size (overrides config)'
    )
    return parser.parse_args(argv)
def _ensure_section(config, key):
    """Return ``config[key]`` as a dict, creating an empty one if it is missing or None.

    A section that is present but ``None`` is treated like a missing one.
    Assigning into ``None`` would otherwise fail with ``TypeError``.
    """
    section = config.get(key)
    if section is None:
        section = {}
        config[key] = section
    return section


def _apply_cli_overrides(config, args):
    """Write the command-line overrides from *args* into *config* in place, echoing each one."""
    # The 'resume' section is always ensured to exist, whether or not any
    # resume-related flag was given.
    resume_cfg = _ensure_section(config, 'resume')
    if args.weights:
        resume_cfg['weights_path'] = args.weights
        print(f" Using full model weights: {args.weights}")
    if args.backbone_weights:
        resume_cfg['backbone_weights_path'] = args.backbone_weights
        print(f" Using backbone weights: {args.backbone_weights}")
    if args.resume:
        resume_cfg['enabled'] = True
        print(" Resume mode enabled")

    # Compare against None (the argparse default) so an explicitly passed value
    # is never silently ignored. 'training' is created on demand; indexing it
    # directly raised KeyError for configs without that section.
    if args.epochs is not None:
        _ensure_section(config, 'training')['epochs'] = args.epochs
        print(f" Epochs: {args.epochs}")
    if args.batch_size is not None:
        _ensure_section(config, 'training')['batch_size'] = args.batch_size
        print(f" Batch size: {args.batch_size}")


def main():
    """Load the config, apply CLI overrides, build the trainer and run training.

    Returns:
        int: process exit status. 0 if ``trainer.train()`` returns normally.
        1 if loading the config, constructing the trainer, or training raises,
        or if training is interrupted with Ctrl-C.
    """
    import traceback

    args = parse_args()

    print("=" * 80)
    print("MultiGridDet Training")
    print("=" * 80)
    print(f"Config file: {args.config}")

    try:
        config = ConfigLoader.load_config(args.config)
    except FileNotFoundError as e:
        print(f"[ERROR] {e}")
        print(" Please check that the config file exists")
        return 1
    except Exception as e:  # top-level CLI boundary: report and exit non-zero
        print(f"[ERROR] Error loading config: {e}")
        return 1

    _apply_cli_overrides(config, args)
    print()

    try:
        trainer = MultiGridTrainer(config)
    except Exception as e:
        print(f"[ERROR] Error creating trainer: {e}")
        traceback.print_exc()
        return 1

    try:
        trainer.train()
    except KeyboardInterrupt:
        print("\n\n[WARNING] Training interrupted by user")
        return 1
    except Exception as e:
        print(f"\n[ERROR] Training error: {e}")
        traceback.print_exc()
        return 1
    return 0
if __name__ == '__main__':
    # Run the CLI and hand main()'s integer status back as the process exit code.
    raise SystemExit(main())