From c9572a116e388a9908498fbb9c03e5f4d6eb525b Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Wed, 2 Oct 2024 14:01:18 -0700 Subject: [PATCH] Stackless yashful PiperOrigin-RevId: 681582933 --- learned_optimization/jax_utils.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/learned_optimization/jax_utils.py b/learned_optimization/jax_utils.py index 301d090..144a217 100644 --- a/learned_optimization/jax_utils.py +++ b/learned_optimization/jax_utils.py @@ -46,9 +46,12 @@ def body_fn(_, operand): def in_jit() -> bool: """Returns true if tracing jit.""" - return "DynamicJaxprTrace" in str( - jax.core.thread_local_state.trace_state.trace_stack - ) + if jax.__version_info__ <= (0, 4, 33): + return "DynamicJaxprTrace" in str( + jax.core.thread_local_state.trace_state.trace_stack # type: ignore + ) + + return jax.core.unsafe_am_i_under_a_jit_DO_NOT_USE() Carry = TypeVar("Carry")