-
Notifications
You must be signed in to change notification settings - Fork 30
Expand file tree
/
Copy pathtrainer.py
More file actions
65 lines (46 loc) · 2.18 KB
/
trainer.py
File metadata and controls
65 lines (46 loc) · 2.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from runtime.rpc import trainer_client
import argparse
import utils
import time
import threading
import os
class Trainer(object):
    """Training-side bookkeeping for one job.

    Counts finished iterations, optionally appends per-iteration throughput
    samples to the file named by the ``TGS_LOG_FILE_PATH`` environment
    variable, and periodically sends the number of iterations finished since
    the previous report to the scheduler via
    ``trainer_client.TrainerClientForScheduler.report_stats``.
    """

    # Minimum number of seconds between two periodic reports in record().
    _REPORT_INTERVAL_S = 10

    def __init__(self, worker_ip, worker_port, trainer_ip, trainer_port, job_id, batch_size) -> None:
        """Create the scheduler client, open the optional log file, log the start.

        Args:
            worker_ip: address passed to ``TrainerClientForScheduler``.
            worker_port: port passed to ``TrainerClientForScheduler``.
            trainer_ip: this trainer's host address (stored only).
            trainer_port: this trainer's port (stored only).
            job_id: identifier sent with every stats report.
            batch_size: numerator of the logged throughput
                (``batch_size / iteration_time``).
        """
        super().__init__()
        self._trainer_ip = trainer_ip
        self._trainer_port = trainer_port
        self._job_id = job_id
        self._batch_size = batch_size
        self._logger = utils.make_logger(__name__)
        self._start_time = time.time()
        # Iterations finished since the last report. The attribute name is kept
        # unchanged for compatibility with any external readers.
        self._finished_iteraions = 0
        self._client_for_scheduler = trainer_client.TrainerClientForScheduler(self._logger, worker_ip, worker_port)
        self.init_stats()
        self._logger.info(f'job {self._job_id}, trainer, start, {self._start_time}')

    def _close_log_file(self):
        """Close the throughput log file if one is open; safe to call repeatedly."""
        # getattr: on the first init_stats() call, _fd does not exist yet.
        fd = getattr(self, '_fd', None)
        if fd is not None:
            fd.close()
        self._fd = None

    def init_stats(self):
        """Reset the report timer and (re)open the optional throughput log file.

        The path comes from the ``TGS_LOG_FILE_PATH`` environment variable.
        When it is unset, no per-iteration samples are written.
        """
        # Use a monotonic clock so wall-clock adjustments cannot distort the
        # report interval.
        self._last_report_time = time.monotonic()
        # Don't leak a previously opened file if this is called again.
        self._close_log_file()
        log_file_path = os.getenv('TGS_LOG_FILE_PATH')
        self._fd = open(log_file_path, 'w', encoding='utf-8') if log_file_path is not None else None
        print(f'LOG_FILE_PATH: {log_file_path}')

    def update_stats(self, iteration_time):
        """Count one finished iteration and log its throughput sample.

        When a log file is open, writes ``"<unix time> <batch_size / iteration_time>"``
        as one line. A non-positive ``iteration_time`` is still counted but
        produces no sample, instead of raising ZeroDivisionError.
        """
        self._finished_iteraions += 1
        if self._fd is not None and iteration_time > 0:
            print('%lf %lf' % (time.time(), self._batch_size / iteration_time), file=self._fd)

    def record(self, iteration_time):
        """Record one iteration and report to the scheduler at most once per interval."""
        self.update_stats(iteration_time)
        if time.monotonic() - self._last_report_time >= self._REPORT_INTERVAL_S:
            self._client_for_scheduler.report_stats(self._job_id, self._finished_iteraions)
            self._finished_iteraions = 0
            # Restart the interval after the report call returns.
            self._last_report_time = time.monotonic()

    def close(self):
        """Report the remaining iteration count and close the log file.

        The log file is closed (and therefore flushed) even if the report raises.
        """
        try:
            self._client_for_scheduler.report_stats(self._job_id, self._finished_iteraions)
            self._finished_iteraions = 0
        finally:
            self._close_log_file()
if __name__ == '__main__':
    # Script entry point: parse the CLI flags and build a Trainer for this host.
    cli = argparse.ArgumentParser()
    for flag, options in (
        ('--worker_ip', {'type': str, 'required': True}),
        ('--worker_port', {'type': int, 'default': 6889}),
        ('--trainer_port', {'type': int}),
        ('--job_id', {'type': int, 'default': -1}),
        ('--batch_size', {'type': int, 'default': 8}),
    ):
        cli.add_argument(flag, **options)
    opts = cli.parse_args()
    trainer = Trainer(
        worker_ip=opts.worker_ip,
        worker_port=opts.worker_port,
        trainer_ip=utils.get_host_ip(),
        trainer_port=opts.trainer_port,
        job_id=opts.job_id,
        batch_size=opts.batch_size,
    )