Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/lightly_train/_methods/densecl/densecl.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from lightly_train._optim.sgd_args import SGDArgs
from lightly_train._optim.trainable_modules import TrainableModules
from lightly_train._scaling import ScalingInfo
from lightly_train._torch_helpers import update_momentum
from lightly_train._transforms.transform import (
MethodTransform,
)
Expand Down Expand Up @@ -193,8 +194,10 @@ def training_step_impl(self, batch: Batch, batch_idx: int) -> TrainingStepResult
start_value=self.method_args.momentum_start,
end_value=self.method_args.momentum_end,
)
utils.update_momentum(
model=self.query_encoder, model_ema=self.key_encoder, m=momentum
update_momentum(
model=self.query_encoder,
model_ema=self.key_encoder,
m=momentum,
)
views = batch["views"]
query_features, query_global, query_local, _ = self.query_encoder(views[0])
Expand Down
3 changes: 2 additions & 1 deletion src/lightly_train/_methods/dino/dino.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import torch
from lightly.loss import DINOLoss
from lightly.models.modules.heads import DINOProjectionHead
from lightly.models.utils import get_weight_decay_parameters, update_momentum
from lightly.models.utils import get_weight_decay_parameters
from lightly.utils import optim
from lightly.utils.scheduler import CosineWarmupScheduler, cosine_schedule
from pytorch_lightning.utilities.types import OptimizerLRScheduler
Expand All @@ -37,6 +37,7 @@
from lightly_train._optim.sgd_args import SGDArgs
from lightly_train._optim.trainable_modules import TrainableModules
from lightly_train._scaling import IMAGENET_SIZE, ScalingInfo
from lightly_train._torch_helpers import update_momentum
from lightly_train._transforms.transform import (
MethodTransform,
)
Expand Down
2 changes: 1 addition & 1 deletion src/lightly_train/_methods/dinov2/dinov2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from lightly.loss import (
KoLeoLoss,
) # we use LightlySSL's KoLeoLoss for better numerical stability
from lightly.models.utils import update_momentum
from lightly.utils.optim import update_param_groups
from lightly.utils.scheduler import CosineWarmupScheduler, cosine_schedule
from pydantic import Field
Expand Down Expand Up @@ -55,6 +54,7 @@
from lightly_train._optim.optimizer_type import OptimizerType
from lightly_train._optim.trainable_modules import TrainableModules
from lightly_train._scaling import ScalingInfo
from lightly_train._torch_helpers import update_momentum
from lightly_train.types import Batch

logger = logging.getLogger(__name__)
Expand Down
17 changes: 13 additions & 4 deletions src/lightly_train/_task_models/object_detection_components/ema.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import torch.nn as nn
from torch.nn import Module

from lightly_train._torch_helpers import update_ema_tensors


class ModelEMA(Module):
"""
Expand Down Expand Up @@ -59,10 +61,17 @@ def update(self, model: nn.Module):
decay=self.decay, warmup_steps=self.warmups, step=self.updates
)
msd = model.state_dict()
for k, v in self.model.state_dict().items():
if v.dtype.is_floating_point:
v *= d
v += (1 - d) * msd[k].detach()
ema_tensors = []
model_tensors = []
for key, value in self.model.state_dict().items():
if value.dtype.is_floating_point:
ema_tensors.append(value)
model_tensors.append(msd[key].detach())
update_ema_tensors(
tensors=model_tensors,
tensors_ema=ema_tensors,
m=d,
)

def forward(
self,
Expand Down
25 changes: 25 additions & 0 deletions src/lightly_train/_torch_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import Any, Callable, Generator

import torch
from torch import Tensor
from torch.nn import Module


Expand Down Expand Up @@ -68,3 +69,27 @@ def set_warn_on_accumulate_grad_stream_mismatch(value: bool) -> None:
# suppress this warning.
if hasattr(torch.autograd.graph, "set_warn_on_accumulate_grad_stream_mismatch"):
torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(value) # type: ignore


@torch.no_grad()
def update_ema_tensors(
    tensors: list[Tensor],
    tensors_ema: list[Tensor],
    m: float,
) -> None:
    """Update ``tensors_ema`` in place with an exponential moving average.

    Each EMA tensor becomes ``m * ema + (1 - m) * source``, computed with
    fused ``torch._foreach_*`` kernels instead of a Python-level loop over
    the tensor lists.

    Args:
        tensors: Source tensors providing the new values.
        tensors_ema: Tensors holding the running average; mutated in place.
        m: Momentum coefficient; larger values make the average change slower.
    """
    # Guard: the foreach kernels do not accept empty tensor lists.
    if tensors_ema:
        torch._foreach_mul_(tensors_ema, m)
        torch._foreach_add_(tensors_ema, tensors, alpha=1.0 - m)


@torch.no_grad()
def update_momentum(model: Module, model_ema: Module, m: float) -> None:
    """Update the parameters of ``model_ema`` with an EMA of ``model``.

    Every EMA parameter is set in place to ``m * ema + (1 - m) * param``
    using fused ``torch._foreach_*`` kernels. Only parameters are updated;
    buffers are left untouched.

    Args:
        model: Model providing the new parameter values.
        model_ema: Model holding the moving average; mutated in place.
        m: Momentum coefficient; larger values make the average change slower.
    """
    params_ema = list(model_ema.parameters())
    # The foreach kernels do not accept empty tensor lists.
    if not params_ema:
        return
    torch._foreach_mul_(params_ema, m)
    torch._foreach_add_(params_ema, list(model.parameters()), alpha=1.0 - m)
38 changes: 38 additions & 0 deletions tests/test__torch_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import torch
from lightning_fabric import Fabric
from torch import nn

from lightly_train import _torch_helpers

Expand All @@ -31,3 +32,40 @@ def test__torch_weights_only_false(tmp_path: Path) -> None:
torch.load(ckpt_path)
fabric.load(ckpt_path)
assert os.environ.get("TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD") is None


def test_update_momentum() -> None:
    """EMA weights are updated in place to ``m * ema + (1 - m) * model``."""
    query = nn.Linear(2, 2, bias=False)
    ema = nn.Linear(2, 2, bias=False)
    with torch.no_grad():
        query.weight.copy_(torch.tensor([[3.0, 4.0], [5.0, 6.0]]))
        ema.weight.copy_(torch.tensor([[1.0, 2.0], [3.0, 4.0]]))

    # Remember the storage pointer so we can verify the update is in place.
    weight_ptr_before = ema.weight.data_ptr()
    _torch_helpers.update_momentum(model=query, model_ema=ema, m=0.25)

    # Expected: 0.25 * ema + 0.75 * query, element-wise.
    assert ema.weight.data_ptr() == weight_ptr_before
    assert torch.equal(ema.weight, torch.tensor([[2.5, 3.5], [4.5, 5.5]]))


def test_update_ema_tensors() -> None:
    """EMA tensors are updated in place; for m=0.5 the result is the mean."""
    averages = [
        torch.tensor([1.0, 2.0], dtype=torch.float32),
        torch.tensor([3.0, 4.0], dtype=torch.float64),
    ]
    sources = [
        torch.tensor([5.0, 6.0], dtype=torch.float32),
        torch.tensor([7.0, 8.0], dtype=torch.float64),
    ]
    # Remember the storage pointers so we can verify the update is in place.
    ptrs_before = [t.data_ptr() for t in averages]

    _torch_helpers.update_ema_tensors(
        tensors=sources,
        tensors_ema=averages,
        m=0.5,
    )

    assert [t.data_ptr() for t in averages] == ptrs_before
    expected = [
        torch.tensor([3.0, 4.0], dtype=torch.float32),
        torch.tensor([5.0, 6.0], dtype=torch.float64),
    ]
    for average, want in zip(averages, expected):
        assert torch.equal(average, want)
Loading