Pull request overview
This PR changes the exponential moving average (EMA) updates to use in-place/foreach tensor operations and routes the method implementations through the new internal helpers.
Changes:
- Added update_ema_tensors (foreach-based) and update_momentum helpers in _torch_helpers.py.
- Updated object detection ModelEMA and DINO/DINOv2/DenseCL codepaths to use the new EMA/momentum helpers.
- Added unit tests covering the new EMA/momentum helper behavior.
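A minimal sketch of what a foreach-based EMA update can look like, assuming the update rule visible in the review snippets further down (scale the EMA tensors in place, then add the weighted new tensors). The function name `ema_update_sketch` and the example values are hypothetical; this is not the helper added in this PR.

```python
from __future__ import annotations

import torch
from torch import Tensor


@torch.no_grad()
def ema_update_sketch(tensors_ema: list[Tensor], tensors: list[Tensor], m: float) -> None:
    """Hypothetical helper: ema <- m * ema + (1 - m) * tensor, in place."""
    # Scale every EMA tensor by m with one batched call.
    torch._foreach_mul_(tensors_ema, m)
    # Add (1 - m) * tensor to the matching EMA tensor.
    torch._foreach_add_(tensors_ema, tensors, alpha=1.0 - m)


# Illustrative values only.
ema = [torch.ones(2), torch.full((3,), 2.0)]
new = [torch.zeros(2), torch.full((3,), 4.0)]
ema_update_sketch(ema, new, m=0.5)
```

Because both calls are in place, the tensor objects in `ema` keep their identity, so any references held elsewhere see the updated values.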
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| src/lightly_train/_torch_helpers.py | Introduces foreach-based EMA tensor updates and a momentum helper. |
| src/lightly_train/_task_models/object_detection_components/ema.py | Switches ModelEMA update logic to batch-update floating tensors via update_ema_tensors. |
| src/lightly_train/_methods/dinov2/dinov2_loss.py | Uses in-place ops for center EMA update to avoid tensor reassignment. |
| src/lightly_train/_methods/dinov2/dinov2.py | Replaces external update_momentum import with internal helper. |
| src/lightly_train/_methods/dino/dino.py | Replaces external update_momentum import with internal helper. |
| src/lightly_train/_methods/densecl/densecl.py | Uses internal update_momentum helper for key/query encoder momentum update. |
| tests/test__torch_helpers.py | Adds unit tests for update_momentum and dtype-grouped update_ema_tensors. |
e648324 to 413f3eb
/review
    - self.center = self.center * self.center_momentum + _t * (
    -     1 - self.center_momentum
    - )
    + self.center.mul_(self.center_momentum)
    + self.center.add_(_t, alpha=1 - self.center_momentum)
This doesn't need to change, but it probably gives no speedup since it is just a single operation, and it makes the code slightly harder to understand.
src/lightly_train/_torch_helpers.py
Outdated
    for ema_tensor, tensor in zip(ema_tensors, tensors):
        key = (ema_tensor.device.type, ema_tensor.device.index, ema_tensor.dtype)
        grouped_ema_tensors[key].append(ema_tensor)
        grouped_tensors[key].append(tensor)
Why is the grouping needed?
I checked: the grouping came from an incorrect claim by Codex that using tensors with different data types would break things. I have removed it now.
src/lightly_train/_torch_helpers.py
Outdated
    ema_tensors: list[Tensor],
    tensors: list[Tensor],
    decay: float,
I think it would be nice to use the same argument order as in update_momentum, and also the same naming: tensors, tensors_ema, m.
src/lightly_train/_torch_helpers.py
Outdated
    torch._foreach_mul_(grouped_ema, decay)
    torch._foreach_add_(grouped_ema, grouped_tensor, alpha=1.0 - decay)
    - torch._foreach_mul_(grouped_ema, decay)
    - torch._foreach_add_(grouped_ema, grouped_tensor, alpha=1.0 - decay)
    + torch._foreach_mul_(grouped_ema, m)
    + torch._foreach_add_(grouped_ema, grouped_tensor, alpha=1.0 - m)
Following https://github.com/lightly-ai/lightly/pull/1899/changes
Note that momentum and decay are not the same thing. decay is 1 - momentum.
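To make the naming concern concrete, here is a small plain-Python sketch of the two conventions an EMA coefficient can follow. The function names are hypothetical and the numbers are illustrative; the second form mirrors the convention documented for `torch.nn.BatchNorm` running statistics, where the coefficient weights the new value.

```python
def ema_weight_on_old(ema: float, value: float, coeff: float) -> float:
    # The coefficient weights the OLD (running) value.
    return coeff * ema + (1.0 - coeff) * value


def ema_weight_on_new(running: float, value: float, coeff: float) -> float:
    # The coefficient weights the NEW value.
    return (1.0 - coeff) * running + coeff * value


# The two forms agree only when the coefficients sum to 1.
a = ema_weight_on_old(10.0, 20.0, 0.75)
b = ema_weight_on_new(10.0, 20.0, 0.25)
# Passing the same number to both forms gives a different result.
c = ema_weight_on_new(10.0, 20.0, 0.75)
```

Whichever name a helper uses, the argument should state which of the two values the coefficient multiplies, since passing a value meant for one form into the other silently swaps the weights.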
src/lightly_train/_torch_helpers.py
Outdated
    update_ema_tensors(
        ema_tensors=list(model_ema.parameters()),
        tensors=list(model.parameters()),
        decay=m,
Decay is not the same thing as momentum
What has changed and why?
Update EMA calculation with _foreach_mul_
How has it been tested?
Unit tests
Did you update CHANGELOG.md?
Did you update the documentation?