Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions QEfficient/finetune/experimental/core/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
# -----------------------------------------------------------------------------

import json
import logging
import os
import time
from typing import Any, Dict, Optional

from transformers import (
Expand All @@ -30,6 +32,23 @@
registry.callback("printer")(PrinterCallback)
registry.callback("default_flow")(DefaultFlowCallback)
registry.callback("tensorboard")(TensorBoardCallback)
logger = logging.getLogger(__name__)

@registry.callback("epoch_timer")
# Custom callback to reset epoch start time
class EpochTimerCallback(TrainerCallback):
def on_epoch_begin(self, args, state, control, **kwargs):
# Record the start time of the current epoch
self.epoch_start_time = time.time()

def on_epoch_end(self, args, state, control, **kwargs):
# Compute time elapsed for the epoch
elapsed = time.time() - self.epoch_start_time
if state.is_world_process_zero:
logger.info(f"[Epoch {state.epoch:.2f}] {elapsed:.2f} sec")
# attach to log history so it goes to TB/W&B/etc.
state.log_history.append({"train/epoch_time_sec": elapsed, "epoch": state.epoch})
control.should_log = True


@registry.callback("enhanced_progressbar")
Expand Down Expand Up @@ -223,6 +242,9 @@ def replace_progress_callback(trainer: Any, callbacks: list[Any], logger: Any =
pass

try:
# Add Epoch Timer
epoch_timer = ComponentFactory.create_callback("epoch_timer")
trainer.add_callback(epoch_timer)
# Add EnhancedProgressCallback
enhanced_callback = ComponentFactory.create_callback("enhanced_progressbar")
trainer.add_callback(enhanced_callback)
Expand Down
Loading