diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py index 4ab15b9d50f7..cbb6b31fc30c 100644 --- a/src/transformers/integrations.py +++ b/src/transformers/integrations.py @@ -896,9 +896,10 @@ def on_log(self, args, state, control, logs, model=None, **kwargs): if not self._initialized: self.setup(args, state, model) if state.is_world_process_zero: + metrics = {} for k, v in logs.items(): if isinstance(v, (int, float)): - self._ml_flow.log_metric(k, v, step=state.global_step) + metrics[k] = v else: logger.warning( f"Trainer is attempting to log a value of " @@ -906,6 +907,7 @@ def on_log(self, args, state, control, logs, model=None, **kwargs): f"MLflow's log_metric() only accepts float and " f"int types so we dropped this attribute." ) + self._ml_flow.log_metrics(metrics=metrics, step=state.global_step) def on_train_end(self, args, state, control, **kwargs): if self._initialized and state.is_world_process_zero: