
update EMA #670

Open
yutong-xiang-97 wants to merge 2 commits into main from
yutong-trn-1834-update-momentum-functions

Conversation

@yutong-xiang-97
Contributor

What has changed and why?

Update EMA calculation with _foreach_mul_

How has it been tested?

Unit tests

Did you update CHANGELOG.md?

  • Yes
  • Not needed (internal change)

Did you update the documentation?

  • Yes
  • Not needed (internal change without effects for user)

Copilot AI review requested due to automatic review settings March 26, 2026 18:07
@chatgpt-codex-connector
Copy link
Copy Markdown

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits. Credits must be used to enable repository-wide code reviews.

Contributor

Copilot AI left a comment


Pull request overview

This PR switches exponential moving average (EMA) updates to in-place/foreach tensor operations and routes method implementations to the new internal helpers.

Changes:

  • Added update_ema_tensors (foreach-based) and update_momentum helpers in _torch_helpers.py.
  • Updated object detection ModelEMA and DINO/DINOv2/DenseCL codepaths to use the new EMA/momentum helpers.
  • Added unit tests covering the new EMA/momentum helper behavior.
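The foreach-based helper described above might look roughly like the following; the function name and argument names (update_ema_tensors, ema_tensors, tensors, decay) are taken from the review snippets in this thread, so treat this as an illustrative sketch rather than the exact implementation:

```python
import torch
from torch import Tensor


def update_ema_tensors(
    ema_tensors: list[Tensor],
    tensors: list[Tensor],
    decay: float,
) -> None:
    """In-place EMA update: ema = decay * ema + (1 - decay) * tensor.

    torch._foreach_* ops apply the update to the whole tensor list in
    batched calls instead of a Python loop over individual tensors.
    """
    torch._foreach_mul_(ema_tensors, decay)
    torch._foreach_add_(ema_tensors, tensors, alpha=1.0 - decay)
```

The in-place `_foreach_` variants mutate the EMA tensors directly, which is what replaces the per-tensor `copy_` loop in the old `ModelEMA` code path.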

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 2 comments.

File | Description
src/lightly_train/_torch_helpers.py | Introduces foreach-based EMA tensor updates and a momentum helper.
src/lightly_train/_task_models/object_detection_components/ema.py | Switches ModelEMA update logic to batch-update floating tensors via update_ema_tensors.
src/lightly_train/_methods/dinov2/dinov2_loss.py | Uses in-place ops for the center EMA update to avoid tensor reassignment.
src/lightly_train/_methods/dinov2/dinov2.py | Replaces the external update_momentum import with the internal helper.
src/lightly_train/_methods/dino/dino.py | Replaces the external update_momentum import with the internal helper.
src/lightly_train/_methods/densecl/densecl.py | Uses the internal update_momentum helper for the key/query encoder momentum update.
tests/test__torch_helpers.py | Adds unit tests for update_momentum and dtype-grouped update_ema_tensors.

@yutong-xiang-97 yutong-xiang-97 force-pushed the yutong-trn-1834-update-momentum-functions branch from e648324 to 413f3eb on March 26, 2026 20:08
@yutong-xiang-97
Copy link
Copy Markdown
Contributor Author

/review

Comment on lines +156 to +157

-        self.center = self.center * self.center_momentum + _t * (
-            1 - self.center_momentum
-        )
+        self.center.mul_(self.center_momentum)
+        self.center.add_(_t, alpha=1 - self.center_momentum)
Contributor


Doesn't need a change, but this probably doesn't give a speedup, as it is just a single operation, and it makes the code slightly harder to understand.
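For context, the two variants this comment refers to compute the same value; the in-place form just avoids allocating a new tensor for self.center on each call. A standalone sketch (variable names mirror the diff above):

```python
import torch

center_momentum = 0.9
center = torch.ones(4)
_t = torch.zeros(4)

# Reassignment variant: builds a fresh result tensor on every call.
expected = center * center_momentum + _t * (1 - center_momentum)

# In-place variant from the diff: mutates `center` without reallocation.
center.mul_(center_momentum)
center.add_(_t, alpha=1 - center_momentum)

assert torch.allclose(center, expected)
```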

Comment on lines +87 to +90

    for ema_tensor, tensor in zip(ema_tensors, tensors):
        key = (ema_tensor.device.type, ema_tensor.device.index, ema_tensor.dtype)
        grouped_ema_tensors[key].append(ema_tensor)
        grouped_tensors[key].append(tensor)
Contributor


Why is the grouping needed?

Contributor Author


I checked, and the grouping was due to a false claim by Codex that using tensors with different data types would break things. I have removed it now.
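To my understanding, the torch._foreach_* ops do accept lists with mixed dtypes: they fall back to an unfused per-tensor loop when the fast fused route does not apply, so no manual grouping by device/dtype is needed for correctness. A quick check (hedged, based on current PyTorch behavior):

```python
import torch

ema = [torch.ones(2, dtype=torch.float32), torch.ones(2, dtype=torch.float64)]
src = [torch.zeros(2, dtype=torch.float32), torch.zeros(2, dtype=torch.float64)]

# Mixed-dtype lists take the slow (unfused) path but still update in place.
torch._foreach_mul_(ema, 0.9)
torch._foreach_add_(ema, src, alpha=0.1)
```

Grouping by dtype/device can still matter for performance (only homogeneous lists hit the fused kernels), but that is an optimization, not a correctness requirement.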

Comment on lines +77 to +79

    ema_tensors: list[Tensor],
    tensors: list[Tensor],
    decay: float,
Contributor


I think it would be nice to have the same argument order as in update_momentum, and also the same naming with tensors, tensors_ema, m.

Comment on lines +94 to +95

    torch._foreach_mul_(grouped_ema, decay)
    torch._foreach_add_(grouped_ema, grouped_tensor, alpha=1.0 - decay)
Contributor


Suggested change

-    torch._foreach_mul_(grouped_ema, decay)
-    torch._foreach_add_(grouped_ema, grouped_tensor, alpha=1.0 - decay)
+    torch._foreach_mul_(grouped_ema, m)
+    torch._foreach_add_(grouped_ema, grouped_tensor, alpha=1.0 - m)

Following https://github.com/lightly-ai/lightly/pull/1899/changes

Note that momentum and decay are not the same thing. decay is 1 - momentum.
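A tiny arithmetic sketch of the naming point (hypothetical variable names): with m as the weight on the previous EMA value and decay defined as 1 - m, weighting the old value by m or weighting the new value by decay gives the same update:

```python
m = 0.99            # momentum: weight on the previous EMA value
decay = 1 - m       # decay, under the convention noted in the comment

ema_prev, x = 1.0, 0.0
via_momentum = m * ema_prev + (1 - m) * x
via_decay = (1 - decay) * ema_prev + decay * x

# The two conventions agree, but the *name* tells the reader which
# coefficient weights the old value; mixing them up inverts the update.
assert abs(via_momentum - via_decay) < 1e-12
```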

update_ema_tensors(
    ema_tensors=list(model_ema.parameters()),
    tensors=list(model.parameters()),
    decay=m,
Contributor


Decay is not the same thing as momentum
