Skip to content

Commit 35f2ff1

Browse files
Anusha Bhamidipatiquic-abhamidi
authored andcommitted
Added per epoch time in the log files
Signed-off-by: Anusha Bhamidipati <abhamidi@qti.qualcomm.com>
1 parent f2ab8a3 commit 35f2ff1

2 files changed

Lines changed: 9 additions & 9 deletions

File tree

QEfficient/finetune/experimental/core/callbacks.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import time
1212
from typing import Any, Dict, Optional
1313

14+
from logger import Logger
1415
from transformers import (
1516
DefaultFlowCallback,
1617
EarlyStoppingCallback,
@@ -33,26 +34,24 @@
3334
registry.callback("printer")(PrinterCallback)
3435
registry.callback("default_flow")(DefaultFlowCallback)
3536
registry.callback("tensorboard")(TensorBoardCallback)
37+
logger = Logger("epoch_logger", log_file="logs/training_log.log", level=logging.DEBUG)
38+
3639

3740
@registry.callback("epoch_timer")
3841
# Custom callback to reset epoch start time
3942
class EpochTimerCallback(TrainerCallback):
4043
def on_epoch_begin(self, args, state, control, **kwargs):
4144
# Record the start time of the current epoch
4245
self.epoch_start_time = time.time()
43-
#print(f"Epoch {state.epoch} started at {time.ctime(self.epoch_start_time)}")
4446

4547
def on_epoch_end(self, args, state, control, **kwargs):
4648
# Compute time elapsed for the epoch
4749
elapsed = time.time() - self.epoch_start_time
48-
#print(f"Time taken to execute Epoch {state.epoch}: {elapsed:.2f} seconds")
49-
if state.is_world_process_zero:
50-
print(f"[Epoch {state.epoch:.2f}] {elapsed:.2f} sec")
51-
# attach to log history so it goes to TB/W&B/etc.
50+
logger.log_rank_zero(f"[Epoch {state.epoch:.2f}] {elapsed:.2f} sec")
51+
# attach to log history so it goes to TB/W&B/etc.
5252
state.log_history.append({"train/epoch_time_sec": elapsed, "epoch": state.epoch})
5353
control.should_log = True
5454

55-
5655
@registry.callback("enhanced_progressbar")
5756
class EnhancedProgressCallback(ProgressCallback):
5857
"""
@@ -244,10 +243,10 @@ def replace_progress_callback(trainer: Any, callbacks: list[Any], logger: Any =
244243
pass
245244

246245
try:
247-
#Add Epoch Timer
246+
# Add Epoch Timer
248247
epoch_timer = ComponentFactory.create_callback("epoch_timer")
249248
trainer.add_callback(epoch_timer)
250-
#Add EnhancedProgressCallback
249+
# Add EnhancedProgressCallback
251250
enhanced_callback = ComponentFactory.create_callback("enhanced_progressbar")
252251
trainer.add_callback(enhanced_callback)
253252
except Exception as e:

QEfficient/finetune/experimental/core/trainer/base_trainer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# SPDX-License-Identifier: BSD-3-Clause
55
#
66
# -----------------------------------------------------------------------------
7+
78
from typing import Optional
89

910
from peft import get_peft_model
@@ -62,7 +63,7 @@ def __init__(
6263
if peft_config is not None and model is not None:
6364
model = get_peft_model(model, peft_config)
6465
model.print_trainable_parameters()
65-
66+
model.config.use_cache = False
6667
# Initialize the parent Trainer class
6768
super().__init__(
6869
model=model,

0 commit comments

Comments
 (0)