@registry.callback("epoch_timer")
class EpochTimerCallback(TrainerCallback):
    """Measure how long each training epoch takes and record the result.

    ``on_epoch_begin`` stamps the start of the epoch; ``on_epoch_end`` computes
    the elapsed seconds, logs them on the world-zero process, appends an entry
    to ``state.log_history`` and requests a logging step via
    ``control.should_log``.
    """

    # Wall-clock timestamp (``time.time()``) of the most recent epoch start.
    # Stays ``None`` until ``on_epoch_begin`` has run at least once.
    epoch_start_time: Optional[float] = None
    # ``time.perf_counter()`` reading taken at the same moment. Durations are
    # derived from this value because the wall clock may be adjusted while an
    # epoch is running, which would corrupt the measurement.
    _epoch_start_perf: Optional[float] = None

    def on_epoch_begin(self, args, state, control, **kwargs):
        """Record the start of the current epoch."""
        self.epoch_start_time = time.time()
        self._epoch_start_perf = time.perf_counter()

    def on_epoch_end(self, args, state, control, **kwargs):
        """Report the time elapsed since the matching ``on_epoch_begin``.

        If no start was recorded on this instance, there is nothing to
        measure and the hook returns without touching ``state`` or
        ``control`` instead of raising ``AttributeError``.
        """
        started = self._epoch_start_perf
        if started is None:
            return

        elapsed = time.perf_counter() - started
        # Clear the marker so a stale start can never be reused for a later epoch.
        self._epoch_start_perf = None

        if state.is_world_process_zero:
            # Lazy %-formatting: the message is only built when INFO is enabled.
            logger.info("[Epoch %.2f] %.2f sec", state.epoch, elapsed)

        # Keep the measurement in the trainer state's log history.
        # NOTE(review): reporting integrations (TensorBoard, W&B, ...) presumably
        # consume the dict passed to `on_log`, not direct appends to
        # `log_history` - confirm this entry actually reaches them.
        state.log_history.append({"train/epoch_time_sec": elapsed, "epoch": state.epoch})
        # Ask the trainer to run a logging step after this epoch.
        control.should_log = True