
Commit 3853676

wz337 authored and meta-codesync[bot] committed

Replace lambda functions in matrix_functions_types.py and shampoo_types.py with private static methods for pickling concerns

Summary: This is in the same spirit as D82836543, which changed the default value of `scale_fn` in `SignDescentPreconditionerConfig` to a private static method for pickling concerns. We found a few other occurrences of `lambda` functions in `shampoo_types.py` and `matrix_functions_types.py`. For the Shampoo state_dict (in particular, `param_groups`) to be compatible with PyTorch, we need to replace all lambda functions with regular named functions, because `torch.save()` uses pickle for serialization, and pickle can serialize regular functions but not lambdas.

Reviewed By: anana10c, hjmshi

Differential Revision: D85902488

fbshipit-source-id: ead3637fe2202ed10681312bf0a0652036be32a9

1 parent 78f629e · commit 3853676
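As a quick, self-contained illustration of the pickling constraint described in the summary (a sketch, not code from this commit; `unit_scale` is a hypothetical name): pickle serializes functions by reference, i.e., by module and qualified name, so a named function round-trips while a lambda, whose qualified name is `<lambda>`, cannot be looked up again on load.

import pickle

# Pickling a lambda fails: its __qualname__ is "<lambda>", which cannot be
# resolved by name when unpickling. Depending on the Python version this
# surfaces as pickle.PicklingError or AttributeError.
try:
    pickle.dumps(lambda d_in, d_out: 1.0)
except (pickle.PicklingError, AttributeError) as exc:
    print(f"lambda cannot be pickled: {exc}")

# A regular module-level function is pickled by reference and round-trips.
def unit_scale(d_in: int, d_out: int) -> float:
    return 1.0

restored = pickle.loads(pickle.dumps(unit_scale))
assert restored(2, 3) == 1.0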

4 files changed: 103 additions & 19 deletions

distributed_shampoo/distributed_shampoo.py

Lines changed: 12 additions & 3 deletions
@@ -1466,6 +1466,17 @@ def _post_state_dict_hook(optimizer: Optimizer, state_dict: StateDict) -> None:
     Returns:
         None: The state_dict is modified in-place.
     """
+
+    def _has_lambda_recursively(obj: Any) -> bool:
+        """Recursively check if an object contains lambda functions."""
+        if isinstance(obj, LambdaType):
+            return True
+        if is_dataclass(obj):
+            return any(
+                _has_lambda_recursively(getattr(obj, f.name)) for f in fields(obj)
+            )
+        return False
+
     # for state exist on the ranks
     state_dict["state"] = {
         k: extract_state_dict_content(v) for k, v in state_dict["state"].items()
@@ -1477,9 +1488,7 @@ def _post_state_dict_hook(optimizer: Optimizer, state_dict: StateDict) -> None:
     for group in state_dict["param_groups"]:
         param_ids.extend(group["params"])
         for v in group.values():
-            if is_dataclass(v) and any(
-                isinstance(getattr(v, f.name), LambdaType) for f in fields(v)
-            ):
+            if _has_lambda_recursively(v):
                 logger.warning(
                     f"Found {v=}. Note that lambda function cannot be pickled. "
                     "torch.save() cannot serialize lambda functions, because it "

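A minimal sketch of how the new recursive check walks nested dataclass configs; `Inner` and `Outer` are hypothetical stand-ins, not the real Shampoo config classes. One caveat worth knowing: in CPython, `types.LambdaType` is an alias of `types.FunctionType`, so the `isinstance` check matches any plain Python function, not only lambdas.

from dataclasses import dataclass, field, fields, is_dataclass
from types import LambdaType
from typing import Any, Callable

def _has_lambda_recursively(obj: Any) -> bool:
    """Recursively check if an object (or any nested dataclass field) is a lambda."""
    if isinstance(obj, LambdaType):
        return True
    if is_dataclass(obj):
        return any(_has_lambda_recursively(getattr(obj, f.name)) for f in fields(obj))
    return False

@dataclass
class Inner:  # hypothetical stand-in for an inner config
    scale_fn: Callable[[int], float] = lambda x: 1.0  # the offending lambda value

@dataclass
class Outer:  # hypothetical stand-in for a preconditioner config
    inner: Inner = field(default_factory=Inner)

# The old flat check only inspected Outer's direct fields and would miss the
# lambda buried one level down; the recursive version descends into Inner.
assert _has_lambda_recursively(Outer())
assert not _has_lambda_recursively(Outer(inner=Inner(scale_fn=abs)))

Note that the recursion stops at non-dataclass values, so a lambda hidden inside, say, a list field would still slip through; for the dataclass-nested configs touched by this commit, that appears to be sufficient.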
distributed_shampoo/preconditioner/matrix_functions_types.py

Lines changed: 26 additions & 3 deletions
@@ -84,6 +84,10 @@ class EigendecompositionConfig(MatrixFunctionConfig):
     Moreover, we have ||B||_F = ||Q^T A Q||_F = ||A||_F.
     Hence, the two relative errors are also equivalent: ||A - A'||_F / ||A||_F = ||B - diag(B)||_F / ||B||_F.
 
+    Note: When using custom rank_deficient_stability_config, avoid lambda functions as they may cause
+        pickling issues during serialization/deserialization. Use regular named functions
+        instead for better compatibility with distributed training and checkpointing.
+
     Attributes:
         rank_deficient_stability_config (RankDeficientStabilityConfig): Configuration for handling/stabilizing rank-deficient matrices. (Default: DefaultPerturbationConfig)
             TODO: generalize this to MatrixFunctionConfig
@@ -92,8 +96,12 @@ class EigendecompositionConfig(MatrixFunctionConfig):
 
     """
 
+    @staticmethod
+    def _get_default_rank_deficient_stability_config() -> RankDeficientStabilityConfig:
+        return DefaultPerturbationConfig
+
     rank_deficient_stability_config: RankDeficientStabilityConfig = field(
-        default_factory=lambda: DefaultPerturbationConfig
+        default_factory=_get_default_rank_deficient_stability_config
     )
     tolerance: float = 0.0
 
@@ -238,13 +246,28 @@ class OrthogonalizationConfig(MatrixFunctionConfig):
 
     If the reduced SVD of the matrix A is given by A = U S V^T, then the orthogonalized/closest orthogonal matrix is U V^T.
 
+    Note: When using custom scale_by_dims_fn, avoid lambda functions as they may cause
+        pickling issues during serialization/deserialization. Use regular named functions
+        instead for better compatibility with distributed training and checkpointing.
+
     Attributes:
         scale_by_dims_fn (Callable[[int, int], float]): Function to scale the orthogonalized matrix by some function of the dimensions of the matrix.
-            (Default: lambda d_in, d_out: 1.0)
+            (Default: _default_scale_by_dims_fn)
 
     """
 
-    scale_by_dims_fn: Callable[[int, int], float] = lambda d_in, d_out: 1.0
+    @staticmethod
+    def _default_scale_by_dims_fn(d_in: int, d_out: int) -> float:
+        """Default scaling function that returns 1.0 (no scaling)."""
+        return 1.0
+
+    @staticmethod
+    def _get_default_scale_by_dims_fn() -> Callable[[int, int], float]:
+        return OrthogonalizationConfig._default_scale_by_dims_fn
+
+    scale_by_dims_fn: Callable[[int, int], float] = field(
+        default_factory=_get_default_scale_by_dims_fn
+    )
 
 
 @dataclass(kw_only=True)
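The change above is the general idiom used throughout this commit: hoist the default into a named @staticmethod and point default_factory at it. Here is a hedged sketch of that pattern with a toy `PerturbationConfig` stand-in (the real `RankDeficientStabilityConfig` hierarchy is not reproduced). Note that passing the staticmethod object itself as default_factory relies on staticmethod objects being directly callable, which requires Python 3.10 or newer.

import pickle
from dataclasses import dataclass, field

@dataclass(kw_only=True)
class PerturbationConfig:  # toy stand-in for the real stability config
    perturbation: float = 1e-6

@dataclass(kw_only=True)
class EigendecompositionConfigSketch:
    @staticmethod
    def _get_default_rank_deficient_stability_config() -> PerturbationConfig:
        return PerturbationConfig()

    # The factory is a named, picklable function rather than an anonymous
    # lambda; referencing the staticmethod object directly works on
    # Python >= 3.10, where staticmethod objects became callable.
    rank_deficient_stability_config: PerturbationConfig = field(
        default_factory=_get_default_rank_deficient_stability_config
    )

cfg = EigendecompositionConfigSketch()
# The instance (and anything embedding it, e.g. param_groups) now pickles.
assert pickle.loads(pickle.dumps(cfg)) == cfg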

distributed_shampoo/shampoo_types.py

Lines changed: 56 additions & 6 deletions
@@ -259,6 +259,10 @@ def __post_init__(self) -> None:
 class RootInvShampooPreconditionerConfig(ShampooPreconditionerConfig):
     """Configuration for Shampoo preconditioner computation with caching of the root inverse factor matrices.
 
+    Note: When using custom amortized_computation_config, avoid lambda functions as they may cause
+        pickling issues during serialization/deserialization. Use regular named functions
+        instead for better compatibility with distributed training and checkpointing.
+
     Attributes:
         amortized_computation_config (RootInvConfig): Configuration for the inverse-root computation. (Default: DefaultEigenConfig)
         num_tolerated_failed_amortized_computations (int): Number of failed amortized computations to tolerate before raising an error. (Default: 3)
@@ -303,8 +307,12 @@ class RootInvShampooPreconditionerConfig(ShampooPreconditionerConfig):
 
     """
 
+    @staticmethod
+    def _get_default_amortized_computation_config() -> RootInvConfig:
+        return DefaultEigenConfig
+
     amortized_computation_config: RootInvConfig = field(
-        default_factory=lambda: DefaultEigenConfig
+        default_factory=_get_default_amortized_computation_config
     )
     inv_factor_matrix_dtype: torch.dtype = torch.float32
 
@@ -316,6 +324,10 @@ class RootInvShampooPreconditionerConfig(ShampooPreconditionerConfig):
 class EigendecomposedShampooPreconditionerConfig(ShampooPreconditionerConfig):
     """Configuration for Shampoo preconditioner computation with caching of the eigendecomposed factor matrices.
 
+    Note: When using custom amortized_computation_config, avoid lambda functions as they may cause
+        pickling issues during serialization/deserialization. Use regular named functions
+        instead for better compatibility with distributed training and checkpointing.
+
     Attributes:
         amortized_computation_config (EigendecompositionConfig): Configuration for the eigendecomposition computation. (Default: DefaultEigendecompositionConfig)
         num_tolerated_failed_amortized_computations (int): Number of failed amortized computations to tolerate before raising an error. (Default: 3)
@@ -361,8 +373,12 @@ class EigendecomposedShampooPreconditionerConfig(ShampooPreconditionerConfig):
 
     """
 
+    @staticmethod
+    def _get_default_amortized_computation_config() -> EigendecompositionConfig:
+        return DefaultEigendecompositionConfig
+
     amortized_computation_config: EigendecompositionConfig = field(
-        default_factory=lambda: DefaultEigendecompositionConfig
+        default_factory=_get_default_amortized_computation_config
     )
     factor_matrix_eigenvectors_dtype: torch.dtype = torch.float32
     factor_matrix_eigenvalues_dtype: torch.dtype = torch.float32
@@ -375,6 +391,10 @@ class EigenvalueCorrectedShampooPreconditionerConfig(AmortizedPreconditionerConf
     Recall that in eigenvalue-corrected Shampoo, the eigenvectors and eigenvalues of the factor matrices are computed separately and stored in place of the full inverted preconditioner, as opposed to the single inverse-root computation of the factor matrices in Shampoo.
     In eigenvalue-corrected Shampoo, the eigenvectors are updated periodically like the inverted preconditioners in Shampoo, but the eigenvalues are updated every iteration.
 
+    Note: When using custom amortized_computation_config, avoid lambda functions as they may cause
+        pickling issues during serialization/deserialization. Use regular named functions
+        instead for better compatibility with distributed training and checkpointing.
+
     Attributes:
         amortized_computation_config (EigendecompositionConfig): Configuration for the eigenvector computation.
             (Default: DefaultEigendecompositionConfig)
@@ -421,8 +441,12 @@ class EigenvalueCorrectedShampooPreconditionerConfig(AmortizedPreconditionerConf
 
     """
 
+    @staticmethod
+    def _get_default_amortized_computation_config() -> EigendecompositionConfig:
+        return DefaultEigendecompositionConfig
+
     amortized_computation_config: EigendecompositionConfig = field(
-        default_factory=lambda: DefaultEigendecompositionConfig
+        default_factory=_get_default_amortized_computation_config
     )
     ignored_basis_change_dims: dict[int, list[int]] = field(default_factory=dict)
     inverse_exponent_override: dict[int, float] = field(default_factory=dict)
@@ -497,14 +521,23 @@ class SpectralDescentPreconditionerConfig(PreconditionerConfig):
     Which parameters are reshaped to 2D is determined by the max_preconditioner_dim argument in DistributedShampoo.
     If all >2D parameters should be guaranteed to be reshaped to 2D, then max_preconditioner_dim=math.inf and distributed_config.target_parameter_dimensionality=2 has to be used.
 
+
+    Note: When using custom orthogonalization config, avoid lambda functions as they may cause
+        pickling issues during serialization/deserialization. Use regular named functions
+        instead for better compatibility with distributed training and checkpointing.
+
     Attributes:
         orthogonalization_config (OrthogonalizationConfig): Configuration for orthogonalization of the search direction.
             (Default: DefaultNewtonSchulzOrthogonalizationConfig)
 
     """
 
+    @staticmethod
+    def _default_orthogonalization_config() -> OrthogonalizationConfig:
+        return DefaultNewtonSchulzOrthogonalizationConfig
+
     orthogonalization_config: OrthogonalizationConfig = field(
-        default_factory=lambda: DefaultNewtonSchulzOrthogonalizationConfig
+        default_factory=_default_orthogonalization_config
     )
 
 
@@ -595,12 +628,20 @@ class LoadBalancingConfig:
     The `cost_model` defines how the cost of a tensor is computed, and the distributor uses this cost to partition workloads.
     By default, it uses `AlignedMemoryCostModel`, other options include `PolynomialComputationalCostModel`.
 
+    Note: When using custom cost_model, avoid lambda functions as they may cause
+        pickling issues during serialization/deserialization. Use regular named functions
+        instead for better compatibility with distributed training and checkpointing.
+
     Args:
         cost_model (CostModel): The cost model used for load balancing. (Default: DefaultCostModel)
 
     """
 
-    cost_model: CostModel = field(default_factory=lambda: DefaultCostModel)
+    @staticmethod
+    def _get_default_cost_model() -> CostModel:
+        return DefaultCostModel
+
+    cost_model: CostModel = field(default_factory=_get_default_cost_model)
 
 
 @dataclass(init=False)
@@ -659,6 +700,10 @@ class DDPDistributedConfig(DistributedConfig):
 
     Enables distributed computation and optimizer states (like ZeRO-1) via DTensor for Shampoo.
 
+    Note: When using custom load_balancing_config, avoid lambda functions as they may cause
+        pickling issues during serialization/deserialization. Use regular named functions
+        instead for better compatibility with distributed training and checkpointing.
+
     Attributes:
         target_parameter_dimensionality (int | float): The idealized parameter dimensionality for a given algorithm.
             The dimensions of parameters and gradients will be merged (after squeezing dimensions of size 1) while respecting max_preconditioner_dim until the tensor has target_parameter_dimensionality dimensions left.
@@ -679,8 +724,13 @@ class DDPDistributedConfig(DistributedConfig):
     communication_dtype: torch.dtype = torch.float32
     num_trainers_per_group: int = -1
     communicate_params: bool = False
+
+    @staticmethod
+    def _get_default_load_balancing_config() -> LoadBalancingConfig:
+        return LoadBalancingConfig()
+
     load_balancing_config: LoadBalancingConfig = field(
-        default_factory=lambda: LoadBalancingConfig()
+        default_factory=_get_default_load_balancing_config
     )
 
 
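To see why these replacements matter end to end, here is a hedged sketch (with hypothetical `BrokenConfig` and `FixedConfig` toys, not the real classes above) of torch.save() failing on a param_groups entry that stores a lambda-valued field, and succeeding once a named function is stored instead.

import io
import pickle
import torch
from dataclasses import dataclass
from typing import Callable

def _unit_scale(d_in: int, d_out: int) -> float:  # named, hence picklable
    return 1.0

@dataclass
class BrokenConfig:  # hypothetical: stores a lambda as the field value
    scale_by_dims_fn: Callable[[int, int], float] = lambda d_in, d_out: 1.0

@dataclass
class FixedConfig:  # hypothetical: stores a named function instead
    scale_by_dims_fn: Callable[[int, int], float] = _unit_scale

# torch.save() pickles param_groups verbatim, so the lambda inside
# BrokenConfig aborts serialization of the whole optimizer state dict.
try:
    torch.save({"param_groups": [{"preconditioner_config": BrokenConfig()}]}, io.BytesIO())
except (pickle.PicklingError, AttributeError) as exc:
    print(f"torch.save failed on a lambda: {exc}")

# With a named function the same state dict serializes without complaint.
torch.save({"param_groups": [{"preconditioner_config": FixedConfig()}]}, io.BytesIO())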

distributed_shampoo/tests/distributed_shampoo_test.py

Lines changed: 9 additions & 7 deletions
@@ -15,13 +15,15 @@
 import re
 import unittest
 from collections.abc import Callable
-from dataclasses import dataclass, replace
+from dataclasses import dataclass, field, replace
 from typing import Any, cast
 
 import torch
 from distributed_shampoo.distributed_shampoo import DistributedShampoo
 from distributed_shampoo.preconditioner.matrix_functions_types import (
+    DefaultNewtonSchulzOrthogonalizationConfig,
     EigenConfig,
+    OrthogonalizationConfig,
     PseudoInverseConfig,
 )
 from distributed_shampoo.shampoo_types import (
@@ -1159,17 +1161,17 @@ def test_state_dict_warning(self) -> None:
         self.assertCountEqual(osd.keys(), ["state", "param_groups"])
 
         @dataclass(kw_only=True)
-        class SpectralDescentPreconditionerConfigWithLambda(
-            SpectralDescentPreconditionerConfig
-        ):
+        class SpectralDescentPreconditionerConfigWithLambda(PreconditionerConfig):
             """
-            Creating a preconditioner config with a dummy lambda function to make sure the
+            Creating an orthogonalization config with a dummy lambda function to make sure the
             warning from `_post_state_dict_hook` emits.
             """
 
-            scale_fn: Callable[[Tensor], float | Tensor] = lambda grad: 1.0
+            orthogonalization_config: OrthogonalizationConfig = field(
+                default_factory=lambda: DefaultNewtonSchulzOrthogonalizationConfig
+            )
 
-        self._optimizer.param_groups[0]["preconditioner_config"] = (
+        self._optimizer.param_groups[0]["orthogonalization_config"] = (
             SpectralDescentPreconditionerConfigWithLambda()
         )
         logger = logging.getLogger("distributed_shampoo.distributed_shampoo")
