Summary
After loading the best checkpoint in EvaluationHook.after_run, the EMA shadow used by evaluate() is not refreshed. This causes test evaluation to use the EMA state from the last training step instead of the EMA state associated with the loaded "best" checkpoint.
Relevant Code
EvaluationHook.after_run
Semi-supervised-learning/semilearn/core/hooks/evaluation.py
Lines 32 to 37 in 1ef4cbe
```python
if 'test' in algorithm.loader_dict:
    # load the best model and evaluate on test dataset
    best_model_path = os.path.join(algorithm.args.save_dir, algorithm.args.save_name, 'model_best.pth')
    algorithm.load_model(best_model_path)
    test_dict = algorithm.evaluate('test')
    results_dict['test/best_acc'] = test_dict['test/top-1-acc']
```
AlgorithmBase.evaluate
Semi-supervised-learning/semilearn/core/algorithmbase.py
Lines 329 to 334 in 1ef4cbe
```python
def evaluate(self, eval_dest='eval', out_key='logits', return_logits=False):
    """
    evaluation function
    """
    self.model.eval()
    self.ema.apply_shadow()
```
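For context, EMA helpers of this style keep a separate `shadow` buffer that is updated every training step and swapped into the model by `apply_shadow()` at evaluation time. The class below is a minimal generic sketch of that pattern (it is not the repository's `EMA` implementation), included only to show that the shadow lives outside the model's `state_dict` and is therefore untouched when `load_model()` overwrites the model weights:

```python
import copy
import torch

class SimpleEMA:
    """Minimal EMA helper sketch (illustrative only, not semilearn's EMA class)."""

    def __init__(self, model, decay=0.999):
        self.model = model
        self.decay = decay
        # shadow is a separate buffer holding the exponential moving average of the weights
        self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}
        self.backup = {}

    @torch.no_grad()
    def update(self):
        # called once per training step: shadow <- decay * shadow + (1 - decay) * current weights
        for k, v in self.model.state_dict().items():
            if self.shadow[k].dtype.is_floating_point:
                self.shadow[k].mul_(self.decay).add_(v.detach(), alpha=1 - self.decay)
            else:
                self.shadow[k].copy_(v)  # e.g. BatchNorm's num_batches_tracked

    def apply_shadow(self):
        # swap the EMA weights into the model for evaluation
        self.backup = copy.deepcopy(self.model.state_dict())
        self.model.load_state_dict(self.shadow)

    def restore(self):
        # swap the original training weights back after evaluation
        self.model.load_state_dict(self.backup)
        self.backup = {}
```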
Problem
- Inconsistency in Results: there is a discrepancy between the `test/best_acc` reported in the training log and the result obtained by manually evaluating the saved `model_best.pth`. Specifically, the final test evaluation triggered at the end of training produces different metrics than those logged during the best epoch.
- Technical Root Cause:
  - `evaluate()` applies EMA via `algorithm.ema.apply_shadow()`.
  - After calling `load_model()` in the `after_run` hook, `algorithm.ema.shadow` still reflects the EMA weights from the very last training step.
  - Consequently, test metrics are computed with a mismatched EMA state: the model weights are correctly loaded from the "best" checkpoint, but the EMA shadow buffer remains stuck at the "last" training step, as demonstrated by the sketch below.
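The following toy script (hypothetical, it does not call semilearn) reproduces the effect numerically: the "best" weights are loaded into the model, but evaluation still sees the stale last-step shadow:

```python
import torch
import torch.nn as nn

model = nn.Linear(1, 1, bias=False)

# Weights saved in the "best" checkpoint at some earlier epoch ...
best_state = {'weight': torch.tensor([[1.0]])}
# ... while training continued, so the in-memory EMA shadow drifted to the "last" step.
last_step_shadow = {'weight': torch.tensor([[5.0]])}

# after_run: load_model() restores the best weights into the model,
# but nothing refreshes the shadow buffer.
model.load_state_dict(best_state)

# evaluate(): apply_shadow() then overwrites the freshly loaded best weights
# with the stale last-step EMA.
model.load_state_dict(last_step_shadow)

print(model.weight.data)  # tensor([[5.]]) -> metrics reflect the "last" EMA, not model_best.pth
```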
Expected Behavior
After loading the best checkpoint, the evaluation process should use the EMA shadow parameters that correspond specifically to that checkpoint.
Proposed Fix
Option A: Minimal (In EvaluationHook)
Refresh the EMA shadow from the loaded ema_model before calling the final evaluation:
```python
# In EvaluationHook.after_run
best_model_path = os.path.join(algorithm.args.save_dir, algorithm.args.save_name, 'model_best.pth')
algorithm.load_model(best_model_path)
if hasattr(algorithm, 'ema') and algorithm.ema is not None:
    algorithm.ema.load(algorithm.ema_model)  # sync shadow with the loaded EMA weights
test_dict = algorithm.evaluate('test')
```
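Option A assumes the EMA helper exposes a `load()`-style method that copies the parameters of the loaded `ema_model` into its shadow buffer. If no such method is available, an equivalent manual refresh could look like the sketch below (`refresh_ema_shadow` is a hypothetical helper, and it assumes `shadow` is a dict keyed by `state_dict` names):

```python
import torch

@torch.no_grad()
def refresh_ema_shadow(ema, ema_model):
    """Hypothetical helper: rebuild ema.shadow from the loaded ema_model's weights."""
    for name, tensor in ema_model.state_dict().items():
        if name in ema.shadow:
            ema.shadow[name] = tensor.detach().clone()
```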
Option B: Centralized (In AlgorithmBase.load_model)
Ensure all loading paths synchronize the EMA state automatically:
```python
# In AlgorithmBase.load_model
def load_model(self, load_path):
    checkpoint = torch.load(load_path)
    self.model.load_state_dict(checkpoint['model'])
    self.ema_model.load_state_dict(checkpoint['ema_model'])
    if self.ema is not None:
        self.ema.load(self.ema_model)  # automatically update the shadow
```
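With either option in place, a quick sanity check (a hypothetical script, not part of the repository) is to re-evaluate `model_best.pth` manually and compare the result against the `test/best_acc` logged at the end of training; the two numbers should now agree:

```python
import os

def check_best_checkpoint(algorithm):
    """Re-evaluate model_best.pth and return the top-1 test accuracy (hypothetical check)."""
    best_model_path = os.path.join(algorithm.args.save_dir,
                                   algorithm.args.save_name,
                                   'model_best.pth')
    algorithm.load_model(best_model_path)  # with the fix, this leaves ema.shadow consistent
    return algorithm.evaluate('test')['test/top-1-acc']
```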