Deprecate optax.global_norm in favor of optax.tree.norm.#1368
Merged
copybara-service[bot] merged 1 commit intogoogle-deepmind:mainfrom Nov 10, 2025
Merged
Conversation
Collaborator
|
Did you check that the hlos match? |
Contributor
Author
Diff: module @jit_global_norm attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<f32>) -> (tensor<f32> {jax.result_info = "result"}) {
%0 = stablehlo.multiply %arg0, %arg0 : tensor<f32>
- %1 = stablehlo.convert %0 : tensor<f32>
%cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
- %2 = stablehlo.reduce(%1 init: %cst) applies stablehlo.add across dimensions = [] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ %1 = stablehlo.multiply %cst, %cst : tensor<f32>
+ %2 = stablehlo.add %0, %1 : tensor<f32>
+ %3 = stablehlo.convert %2 : tensor<f32>
%cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
- %3 = stablehlo.add %cst_0, %2 : tensor<f32>
- %4 = stablehlo.sqrt %3 : tensor<f32>
- return %4 : tensor<f32>
+ %4 = stablehlo.reduce(%3 init: %cst_0) applies stablehlo.add across dimensions = [] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
+ %5 = stablehlo.add %cst_1, %4 : tensor<f32>
+ %6 = stablehlo.sqrt %5 : tensor<f32>
+ return %6 : tensor<f32>
}
}Separately: OldNewFor reference, here are the current implementations of A difference is that the current implementation of def abs_sq(x):
return (x.conj() * x).realwhereas def _square(leaf):
return jnp.square(leaf.real) + jnp.square(leaf.imag)We should probably pick the most efficient of these two for both functions. Based on HLO size, it looks like |
Collaborator
|
On a pytree input, HLO is as follows: For For These are identical except |
emilyfertig
approved these changes
Nov 10, 2025
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Deprecates the redundant function
optax.global_normin favor ofoptax.tree.norm.Split into a new PR from #1365.