Replace lambda functions in matrix_function_types.py and shampoo_types.py with private static methods for picklability
Summary:
This is in the same spirit as D82836543, which changed the default value of `scale_fn` of `SignDescentPreconditionerConfig` to a private static method for pickling concerns.
We found a few other occurrences of `lambda` functions in `shampoo_types.py` and `matrix_function_types.py`. In order for the Shampoo state_dict (in particular, `param_groups`) to be compatible with PyTorch, we need to replace all lambda functions with regular named functions, because `torch.save()` uses pickle for serialization, and pickle cannot serialize lambda functions, only regular named functions.
Reviewed By: anana10c, hjmshi
Differential Revision: D85902488
fbshipit-source-id: ead3637fe2202ed10681312bf0a0652036be32a9
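As context for the change, here is a minimal, self-contained sketch (illustrative only; `_identity`, `Config`, and `_default_fn` are invented names, not code from this diff) of why pickle rejects lambdas but accepts named functions and static methods: pickle serializes functions *by reference* (module name plus qualified name), and a lambda's qualname `<lambda>` cannot be looked up again at load time.

```python
import pickle


def _identity(x):
    # Regular named function: pickle stores only "module.qualname" and
    # re-imports it at load time.
    return x


class Config:
    @staticmethod
    def _default_fn(x):
        # Private static method: also resolvable by its qualified name
        # ("Config._default_fn") with pickle protocol 4+.
        return x


pickle.dumps(_identity)           # OK
pickle.dumps(Config._default_fn)  # OK

try:
    pickle.dumps(lambda x: x)     # fails: "<lambda>" cannot be looked up again
except (pickle.PicklingError, AttributeError) as e:
    print(f"lambda is not picklable: {e}")
```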
"""Configuration for Shampoo preconditioner computation with caching of the eigendecomposed factor matrices.
318
326
327
+
Note: When using custom amortized_computation_config, avoid lambda functions as they may cause
328
+
pickling issues during serialization/deserialization. Use regular named functions
329
+
instead for better compatibility with distributed training and checkpointing.
330
+
319
331
Attributes:
320
332
amortized_computation_config (EigendecompositionConfig): Configuration for the eigendecomposition computation. (Default: DefaultEigendecompositionConfig)
321
333
num_tolerated_failed_amortized_computations (int): Number of failed amortized computations to tolerate before raising an error. (Default: 3)
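The Note above can be made concrete with a hypothetical config. In this sketch, `ExampleConfig` and `_default_scale_fn` are invented stand-ins for the real classes in `shampoo_types.py`; the point is only that a default callable defined as a private static method survives a pickle round trip, which a lambda default would not.

```python
import dataclasses
import pickle
from collections.abc import Callable


@dataclasses.dataclass
class ExampleConfig:
    # A lambda default (e.g. `scale_fn = lambda x: x`) would make every
    # instance unpicklable; None is used as a sentinel and replaced below.
    scale_fn: Callable[[float], float] | None = None

    @staticmethod
    def _default_scale_fn(x: float) -> float:
        return x

    def __post_init__(self) -> None:
        # The private static method resolves to a plain function with a
        # dotted qualname, which pickle can serialize by reference.
        if self.scale_fn is None:
            self.scale_fn = ExampleConfig._default_scale_fn


restored = pickle.loads(pickle.dumps(ExampleConfig()))
assert restored.scale_fn(3.0) == 3.0  # the default round-trips cleanly
```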
@@ -361,8 +373,12 @@ class EigendecomposedShampooPreconditionerConfig(ShampooPreconditionerConfig):
@@ -375,6 +391,10 @@ class EigenvalueCorrectedShampooPreconditionerConfig(AmortizedPreconditionerConfig):
     Recall that in eigenvalue-corrected Shampoo, the eigenvectors and eigenvalues of the factor matrices are computed separately and stored in place of the full inverted preconditioner, as opposed to the single inverse-root computation of the factor matrices in Shampoo.
     In eigenvalue-corrected Shampoo, the eigenvectors are updated periodically like the inverted preconditioners in Shampoo, but the eigenvalues are updated every iteration.
 
+    Note: When using custom amortized_computation_config, avoid lambda functions as they may cause
+    pickling issues during serialization/deserialization. Use regular named functions
+    instead for better compatibility with distributed training and checkpointing.
+
     Attributes:
         amortized_computation_config (EigendecompositionConfig): Configuration for the eigenvector computation.
             (Default: DefaultEigendecompositionConfig)
@@ -421,8 +441,12 @@ class EigenvalueCorrectedShampooPreconditionerConfig(AmortizedPreconditionerConfig):
@@ -497,14 +521,23 @@ class SpectralDescentPreconditionerConfig(PreconditionerConfig):
     Which parameters are reshaped to 2D is determined by the max_preconditioner_dim argument in DistributedShampoo.
     If all >2D parameters should be guaranteed to be reshaped to 2D, then max_preconditioner_dim=math.inf and distributed_config.target_parameter_dimensionality=2 has to be used.
 
+
+    Note: When using custom orthogonalization config, avoid lambda functions as they may cause
+    pickling issues during serialization/deserialization. Use regular named functions
+    instead for better compatibility with distributed training and checkpointing.
+
     Attributes:
         orthogonalization_config (OrthogonalizationConfig): Configuration for orthogonalization of the search direction.
@@ -659,6 +700,10 @@ class DDPDistributedConfig(DistributedConfig):
 
     Enables distributed computation and optimizer states (like ZeRO-1) via DTensor for Shampoo.
 
+    Note: When using custom load_balancing_config, avoid lambda functions as they may cause
+    pickling issues during serialization/deserialization. Use regular named functions
+    instead for better compatibility with distributed training and checkpointing.
+
     Attributes:
         target_parameter_dimensionality (int | float): The idealized parameter dimensionality for a given algorithm.
             The dimensions of parameters and gradients will be merged (after squeezing dimensions of size 1) while respecting max_preconditioner_dim until the tensor has target_parameter_dimensionality dimensions left.
@@ -679,8 +724,13 @@ class DDPDistributedConfig(DistributedConfig):
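Finally, a hedged sketch of why the docstring Notes emphasize checkpointing: `torch.save()` pickles the optimizer `state_dict`, so a lambda stored anywhere in `param_groups` breaks the save. Plain `torch.optim.SGD` stands in for DistributedShampoo here, and the `scale_fn` entry is hypothetical.

```python
import io
import pickle

import torch

# One parameter and vanilla SGD as a stand-in for DistributedShampoo.
param = torch.nn.Parameter(torch.zeros(3))
opt = torch.optim.SGD([param], lr=0.1)

# Hypothetical lambda-valued entry in param_groups, mimicking a config field.
opt.param_groups[0]["scale_fn"] = lambda x: x
try:
    torch.save(opt.state_dict(), io.BytesIO())  # state_dict keeps custom keys
except (pickle.PicklingError, AttributeError) as e:
    print(f"checkpoint failed: {e}")


def _scale_fn(x):
    # Module-level named function: pickle serializes it by reference.
    return x


opt.param_groups[0]["scale_fn"] = _scale_fn
torch.save(opt.state_dict(), io.BytesIO())  # now succeeds
```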