diff --git a/src/torchjd/aggregation/_cr_mogm.py b/src/torchjd/aggregation/_cr_mogm.py index 28900bc6..d6ca42cc 100644 --- a/src/torchjd/aggregation/_cr_mogm.py +++ b/src/torchjd/aggregation/_cr_mogm.py @@ -105,7 +105,7 @@ def __init__( self.weighting = weighting self.alpha = alpha self._initial_weights = initial_weights - self._lambda: Tensor | None = None + self.register_buffer("_lambda", None) @property def alpha(self) -> float: diff --git a/src/torchjd/aggregation/_gradvac.py b/src/torchjd/aggregation/_gradvac.py index 54b6d0c3..7e5191c0 100644 --- a/src/torchjd/aggregation/_gradvac.py +++ b/src/torchjd/aggregation/_gradvac.py @@ -44,7 +44,7 @@ def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None: super().__init__() self.beta = beta self.eps = eps - self._phi_t: Tensor | None = None + self.register_buffer("_phi_t", None) self._state_key: tuple[int, torch.dtype] | None = None @property diff --git a/src/torchjd/aggregation/_modo.py b/src/torchjd/aggregation/_modo.py index a7b6fc0e..84d64e10 100644 --- a/src/torchjd/aggregation/_modo.py +++ b/src/torchjd/aggregation/_modo.py @@ -133,7 +133,7 @@ def __init__(self, gamma: float = 0.1, rho: float = 0.1) -> None: super().__init__() self.gamma = gamma self.rho = rho - self._lambda: Tensor | None = None + self.register_buffer("_lambda", None) self._state_key: tuple[int, torch.dtype, torch.device] | None = None @property diff --git a/src/torchjd/aggregation/_sdmgrad.py b/src/torchjd/aggregation/_sdmgrad.py index e4eb5ab4..0431f296 100644 --- a/src/torchjd/aggregation/_sdmgrad.py +++ b/src/torchjd/aggregation/_sdmgrad.py @@ -107,7 +107,7 @@ def __init__( self.n_iter = n_iter self.lambda_ = lambda_ self.pref_vector = pref_vector - self._w: Tensor | None = None + self.register_buffer("_w", None) self._state_key: tuple[int, torch.dtype, torch.device] | None = None @property