Fix: #1509 Merge duplicate DoG implementations and add layer-wise support #1518
abheeyeee wants to merge 19 commits into google-deepmind:main
Conversation
Can you keep _dog.py and _dog_test.py files, and keep them in the contrib folder?
Thanks for the feedback, I will make the changes right now.
@vroulet I moved the files as you asked. This should fix the issue.
@emilyfertig I made the changes that you preferred and added the new scale_by_l_dog function. 137d6a7
emilyfertig left a comment
Thanks! scale_by_l_dog should be the same as scale_by_dog with layer_wise = True, right? Can we remove the layer_wise arg, and make scale_by_l_dog the same as scale_by_dog with layer_wise=True?
Thanks for your feedback @emilyfertig, I did as you asked and made scale_by_l_dog the same as scale_by_dog with layer_wise=True.
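The delegation pattern agreed on here can be sketched as follows. This is an illustrative stand-in, not the PR's actual code: the real optax functions return a GradientTransformation, while these stubs return their resolved configuration so the forwarding is visible.

```python
def scale_by_dog(reps_rel=1e-6, eps=1e-8, layer_wise=False):
  # Stand-in for the real transformation: return the resolved configuration
  # so the delegation pattern below is easy to check.
  return {"reps_rel": reps_rel, "eps": eps, "layer_wise": layer_wise}


def scale_by_l_dog(reps_rel=1e-6, eps=1e-8):
  # Layer-wise DoG as a thin wrapper: by construction identical to
  # scale_by_dog with layer_wise=True, so the two cannot drift apart.
  return scale_by_dog(reps_rel=reps_rel, eps=eps, layer_wise=True)
```

The design choice under discussion: a wrapper guarantees the two entry points never diverge, at the cost of keeping the layer_wise flag in the shared implementation.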
Force-pushed 06259a2 to 0894aeb
emilyfertig left a comment
Thanks! I think the pytest failure is unrelated and should clear up if you rebase.
return _scale_by_dog(
    init_step=("heuristic", reps_rel),
    eps=eps,
    layer_wise=True,
Sorry, what I meant is to please get rid of the layer_wise arg everywhere, and make separate implementations of scale_by_dog and scale_by_l_dog. Does that make sense?
@emilyfertig Refactored the DoG optimizer implementation in optax/contrib/_dog.py to separate the global and layer-wise variants:
Removed the internal _scale_by_dog helper function.
Implemented scale_by_dog (global DoG) and scale_by_l_dog (layer-wise DoG) as distinct, standalone functions.
Removed the layer_wise argument from scale_by_dog to enforce clear separation of concerns.
Updated optax/contrib/_dog_test.py:
Renamed test_dog_layer_wise to test_l_dog_vs_dog to reflect the API changes.
Updated comments to remove outdated references to the layer_wise argument.
Verified that all tests pass with pytest optax/contrib/_dog_test.py.
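For reference, the difference between the two variants being separated can be sketched in a toy flat-list model (plain Python, not the PR's jax/optax pytree code). DoG sets its step size parameter-free as eta_t = max_{s<=t} ||x_s - x_0|| / sqrt(sum_{s<=t} ||g_s||^2 + eps); the layer-wise variant keeps those accumulators per layer (per coordinate in this sketch) instead of globally. The init helpers use the paper's "heuristic" initialization r_epsilon = reps_rel * (1 + ||x_0||); all names here are illustrative.

```python
import math


def dog_init(params, reps_rel=1e-6):
  # "heuristic" init: seed max_dist with r_epsilon = reps_rel * (1 + ||x_0||).
  r_eps = reps_rel * (1 + math.sqrt(sum(p * p for p in params)))
  return (r_eps, 0.0, list(params))


def l_dog_init(params, reps_rel=1e-6):
  # Per-layer (here per-coordinate) state: (max_dist, sum_sq, init value).
  return [(reps_rel * (1 + abs(p)), 0.0, p) for p in params]


def global_dog_step(params, grads, state, eps=1e-8):
  # Global DoG: one scalar step size shared by all parameters.
  max_dist, sum_sq, init_params = state
  dist = math.sqrt(sum((p - p0) ** 2 for p, p0 in zip(params, init_params)))
  max_dist = max(max_dist, dist)
  sum_sq += sum(g * g for g in grads)
  eta = max_dist / math.sqrt(sum_sq + eps)
  new_params = [p - eta * g for p, g in zip(params, grads)]
  return new_params, (max_dist, sum_sq, init_params)


def layer_dog_step(params, grads, state, eps=1e-8):
  # Layer-wise DoG (L-DoG): each layer keeps its own distance and
  # gradient-norm accumulators, hence its own step size.
  new_params, new_state = [], []
  for p, g, (max_dist, sum_sq, p0) in zip(params, grads, state):
    max_dist = max(max_dist, abs(p - p0))
    sum_sq += g * g
    eta = max_dist / math.sqrt(sum_sq + eps)
    new_params.append(p - eta * g)
    new_state.append((max_dist, sum_sq, p0))
  return new_params, new_state
```

On a simple quadratic the two variants produce different trajectories, which is exactly what the test_l_dog_vs_dog test above is checking for the real implementations.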
def scale_by_l_dog(
    reps_rel: jax.typing.ArrayLike = 1e-6,
    eps: jax.typing.ArrayLike = 1e-8,
    param_dtype: Optional[jax.typing.DTypeLike] = None,
Please remove the unused param_dtype arg.
def init_fn(params: base.Params) -> DoGState:
  params_dtype = optax.tree.dtype(params, "lowest")
  if param_dtype is not None:
Why is this done here but not in scale_by_dog?
max_dist=jnp.asarray(r_epsilon, dtype=params_dtype),
sum_sq_norm_grads=jnp.asarray(0.0, dtype=params_dtype),
init_params=optax.tree.cast(params, params_dtype),
max_dist=max_dist,
Please revert this so it's inlined.
with self.assertRaises(AssertionError):
  test_utils.assert_trees_all_close(updates_global, updates_layer)

def test_legacy_compatibility(self):
This test doesn't have much point if the scale_by_distance_over_gradients implementation is changed to call scale_by_l_dog. Can you revert scale_by_distance_over_gradients to its former implementation and still deprecate it?
@emilyfertig Understood, this is what I am going to do. (optax/contrib/_dog.py, optax/contrib/_dog_test.py)
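A deprecation along the lines the reviewer suggests, keeping the former implementation intact while steering callers to the new API, might look like this. This is a hypothetical sketch with a stubbed legacy body, not the PR's actual code.

```python
import warnings


def _legacy_scale_by_distance_over_gradients(reps_rel, eps):
  # Stand-in for the unchanged original implementation.
  return ("legacy", reps_rel, eps)


def scale_by_distance_over_gradients(reps_rel=1e-6, eps=1e-8):
  # Keep the former implementation's behavior exactly, but warn callers
  # that the function is deprecated in favor of the consolidated DoG API.
  warnings.warn(
      "scale_by_distance_over_gradients is deprecated; use scale_by_dog or "
      "scale_by_l_dog instead.",
      DeprecationWarning,
      stacklevel=2,
  )
  return _legacy_scale_by_distance_over_gradients(reps_rel, eps)
```

Reverting the body while only adding the warning keeps a legacy-compatibility test meaningful, since the deprecated path no longer routes through the new code it is supposed to be compared against.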
Ivgi et al, `DoG is SGD's Best Friend: A Parameter-Free Dynamic Step Size
Schedule <https://arxiv.org/pdf/2302.12022.pdf>`_, 2023
"""
reps_rel = 1e-6 if reps_rel is None else reps_rel
This is not the original implementation. You appear to be using an LLM. Please check what it outputs before you request a review.
max_dist=jnp.asarray(r_epsilon, dtype=params_dtype),
sum_sq_norm_grads=jnp.asarray(0.0, dtype=params_dtype),
init_params=optax.tree.cast(params, params_dtype),
max_dist=max_dist,
Please revert this so it's inlined.
def init_fn(params: base.Params) -> DoGState:
  params_dtype = optax.tree.dtype(params, "lowest")

  # r_epsilon is already a tree of scalars
Please remove or clarify comment
Fixes issue #1509.
This PR merges the duplicate Distance over Gradients (DoG) implementations found in optax/contrib/_dog.py and optax/_src/transform.py into a single, unified implementation in optax/_src/dog.py.
Created optax/_src/dog.py which consolidates DoG and DoWG.
The new scale_by_dog now supports a layer_wise argument.
Re-implemented scale_by_distance_over_gradients in optax/_src/transform.py to use the new scale_by_dog with layer_wise=True.
Deprecated scale_by_distance_over_gradients in favor of scale_by_dog.
Updated optax/contrib/_dog.py to be a compatibility shim importing from optax/_src/dog.py.
Added optax/_src/dog_test.py to verify both global and layer-wise behaviors, as well as legacy compatibility.