|
11 | 11 | import time |
12 | 12 | from typing import Any, Dict, Optional |
13 | 13 |
|
| 14 | +from logger import Logger |
14 | 15 | from transformers import ( |
15 | 16 | DefaultFlowCallback, |
16 | 17 | EarlyStoppingCallback, |
|
33 | 34 | registry.callback("printer")(PrinterCallback) |
34 | 35 | registry.callback("default_flow")(DefaultFlowCallback) |
35 | 36 | registry.callback("tensorboard")(TensorBoardCallback) |
| 37 | +logger = Logger("epoch_logger", log_file="logs/training_log.log", level=logging.DEBUG) |
| 38 | + |
36 | 39 |
|
37 | 40 | @registry.callback("epoch_timer") |
38 | 41 | # Custom callback to reset epoch start time |
39 | 42 | class EpochTimerCallback(TrainerCallback): |
40 | 43 | def on_epoch_begin(self, args, state, control, **kwargs): |
41 | 44 | # Record the start time of the current epoch |
42 | 45 | self.epoch_start_time = time.time() |
43 | | - #print(f"Epoch {state.epoch} started at {time.ctime(self.epoch_start_time)}") |
44 | 46 |
|
45 | 47 | def on_epoch_end(self, args, state, control, **kwargs): |
46 | 48 | # Compute time elapsed for the epoch |
47 | 49 | elapsed = time.time() - self.epoch_start_time |
48 | | - #print(f"Time taken to execute Epoch {state.epoch}: {elapsed:.2f} seconds") |
49 | | - if state.is_world_process_zero: |
50 | | - print(f"[Epoch {state.epoch:.2f}] {elapsed:.2f} sec") |
51 | | - # attach to log history so it goes to TB/W&B/etc. |
| 50 | + logger.log_rank_zero(f"[Epoch {state.epoch:.2f}] {elapsed:.2f} sec") |
| 51 | + # attach to log history so it goes to TB/W&B/etc. |
52 | 52 | state.log_history.append({"train/epoch_time_sec": elapsed, "epoch": state.epoch}) |
53 | 53 | control.should_log = True |
54 | 54 |
|
55 | | - |
56 | 55 | @registry.callback("enhanced_progressbar") |
57 | 56 | class EnhancedProgressCallback(ProgressCallback): |
58 | 57 | """ |
@@ -244,10 +243,10 @@ def replace_progress_callback(trainer: Any, callbacks: list[Any], logger: Any = |
244 | 243 | pass |
245 | 244 |
|
246 | 245 | try: |
247 | | - #Add Epoch Timer |
| 246 | + # Add Epoch Timer |
248 | 247 | epoch_timer = ComponentFactory.create_callback("epoch_timer") |
249 | 248 | trainer.add_callback(epoch_timer) |
250 | | - #Add EnhancedProgressCallback |
| 249 | + # Add EnhancedProgressCallback |
251 | 250 | enhanced_callback = ComponentFactory.create_callback("enhanced_progressbar") |
252 | 251 | trainer.add_callback(enhanced_callback) |
253 | 252 | except Exception as e: |
|
0 commit comments