
EMA not synced with loaded best model during test evaluation in EvaluationHook.after_run #258

Description

@jwc-rad

Summary

After loading the best checkpoint in EvaluationHook.after_run, the EMA shadow used by evaluate() is not refreshed. This causes test evaluation to use the EMA state from the last training step instead of the EMA state associated with the loaded "best" checkpoint.

Relevant Code

  • EvaluationHook.after_run

        if 'test' in algorithm.loader_dict:
            # load the best model and evaluate on test dataset
            best_model_path = os.path.join(algorithm.args.save_dir, algorithm.args.save_name, 'model_best.pth')
            algorithm.load_model(best_model_path)
            test_dict = algorithm.evaluate('test')
            results_dict['test/best_acc'] = test_dict['test/top-1-acc']

  • AlgorithmBase.evaluate

        def evaluate(self, eval_dest='eval', out_key='logits', return_logits=False):
            """
            evaluation function
            """
            self.model.eval()
            self.ema.apply_shadow()
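
For context, the EMA helper keeps its averaged weights in a separate shadow dictionary: update() moves it towards the current weights each training step, load() overwrites it from an EMA model, and apply_shadow() copies it into the live model for evaluation. Loading a checkpoint into self.model never touches this dictionary. Below is a minimal sketch of such a helper, simplified and not the exact Semilearn implementation (the method and attribute names are taken from the snippets above):

class EMASketch:
    """Simplified EMA helper; model is expected to be a torch.nn.Module."""

    def __init__(self, model, decay=0.999):
        self.model = model
        self.decay = decay
        # shadow is an independent copy of the parameters; nothing that writes
        # into model.state_dict() directly (e.g. load_model) ever updates it
        self.shadow = {k: p.detach().clone() for k, p in model.named_parameters()}

    def update(self):
        # called once per training step: move shadow towards the current weights
        for k, p in self.model.named_parameters():
            self.shadow[k].mul_(self.decay).add_(p.detach(), alpha=1 - self.decay)

    def load(self, ema_model):
        # overwrite shadow with the parameters of an already-loaded EMA model
        self.shadow = {k: p.detach().clone() for k, p in ema_model.named_parameters()}

    def apply_shadow(self):
        # evaluation path: copy shadow into the live model, overwriting whatever
        # weights are currently loaded there
        self.model.load_state_dict(self.shadow, strict=False)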

Problem

  1. Inconsistency in Results: The test/best_acc reported in the training log does not match the result obtained by manually evaluating the saved model_best.pth; the final test evaluation triggered at the end of training produces different metrics than those logged during the best epoch.
  2. Technical Root Cause: evaluate() applies EMA via algorithm.ema.apply_shadow(), but after load_model() runs in the after_run hook, algorithm.ema.shadow still reflects the EMA weights from the very last training step.
  3. Consequently, test metrics are computed with a mismatched EMA state: the model weights are correctly loaded from the "best" checkpoint, while the EMA shadow buffer remains stuck at the "last" training step (see the toy reproduction below).
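
To make the mismatch concrete, here is a small self-contained toy reproduction in plain PyTorch; the model, shadow, and best_ckpt objects are hypothetical stand-ins for algorithm.model, algorithm.ema.shadow, and the model_best.pth checkpoint:

import torch
import torch.nn as nn

model = nn.Linear(2, 2)

# "last-step" EMA shadow captured at the end of training (stand-in for algorithm.ema.shadow)
shadow = {k: p.detach().clone() for k, p in model.named_parameters()}

# pretend "best" checkpoint with clearly different weights
best_ckpt = {k: torch.zeros_like(v) for k, v in model.state_dict().items()}

model.load_state_dict(best_ckpt)             # after_run: load_model(best_model_path)
model.load_state_dict(shadow, strict=False)  # evaluate(): ema.apply_shadow() with the stale shadow

# The best weights are gone; evaluation runs on the last-step EMA weights instead.
print(torch.equal(model.weight.data, best_ckpt['weight']))  # False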

Expected Behavior

After loading the best checkpoint, the evaluation process should use the EMA shadow parameters that correspond specifically to that checkpoint.


Proposed Fix

Option A: Minimal (In EvaluationHook)

Refresh the EMA shadow from the loaded ema_model before calling the final evaluation:

# In EvaluationHook.after_run
algorithm.load_model(best_model_path)
if hasattr(algorithm, 'ema') and algorithm.ema is not None:
    algorithm.ema.load(algorithm.ema_model)  # sync shadow with the loaded EMA weights
test_dict = algorithm.evaluate('test')
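
This keeps the fix local to the hook that triggers the final test evaluation; the hasattr/None guard simply skips the sync for algorithms that run without an EMA helper. Any other code path that calls load_model() and then evaluates would still have to remember to re-sync the shadow by hand, which is what Option B avoids.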

Option B: Centralized (In AlgorithmBase.load_model)

Ensure all loading paths synchronize the EMA state automatically:

# In AlgorithmBase.load_model
def load_model(self, load_path):
    checkpoint = torch.load(load_path)
    self.model.load_state_dict(checkpoint['model'])
    self.ema_model.load_state_dict(checkpoint['ema_model'])
    if self.ema is not None:
        self.ema.load(self.ema_model) # Automatically update shadow
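
Centralizing the sync in load_model() keeps the model weights and the EMA shadow consistent for every loading path, whether that is the final test evaluation in EvaluationHook.after_run, resuming training from a checkpoint, or an external script that restores a model, without each caller needing to know about the shadow buffer.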
