11import torch
22import numpy as np
3- from flax .core import freeze , unfreeze
4-
5- # Load PyTorch state_dict
6- state_dict = torch .load ("/results/pytorch_base_model_criteo1tb_22_may.pth" )
7-
8- # Convert PyTorch tensors to NumPy arrays
9- numpy_weights = {k : v .numpy () for k , v in state_dict .items ()}
10-
11-
3+ import jax
4+ import jax .numpy as jnp
125"""
136Jax default parameter structure:
147dict_keys(['Dense_0', 'Dense_1', 'Dense_2', 'Dense_3', 'Dense_4', 'Dense_5', 'Dense_6', 'Dense_7', 'embedding_table'])
2316The function assumes that the Jax model parameters are already initialized
2417and that the PyTorch weights are in the correct format.
2518"""
26- def use_pytorch_weights (jax_params ):
19+ def use_pytorch_weights (jax_params , file_name = None ):
20+ # Load PyTorch state_dict
21+ state_dict = torch .load (file_name )
22+ print (state_dict .keys ())
23+ # Convert PyTorch tensors to NumPy arrays
24+ numpy_weights = {k : v .cpu ().numpy () for k , v in state_dict .items ()}
25+
2726 # --- Embedding Table ---
2827 embedding_table = np .concatenate ([
2928 numpy_weights [f'embedding_chunk_{ i } ' ] for i in range (4 )
@@ -42,3 +41,28 @@ def use_pytorch_weights(jax_params):
4241 jax_params [f'Dense_{ j } ' ]['bias' ] = numpy_weights [f'top_mlp.{ i } .bias' ]
4342
4443 return jax_params
44+
45+
def are_weights_equal(params1, params2, atol=1e-6, rtol=1e-6):
    """Compare two JAX PyTrees of weights leaf-by-leaf and report differences.

    Walks both trees in lockstep with ``jax.tree_util.tree_map``. For every
    leaf pair that is not close (within ``atol``/``rtol``), prints the two
    shapes and the maximum absolute difference. A tree-structure mismatch
    (different keys/nesting) raises inside ``tree_map`` and is reported as a
    structure mismatch.

    Args:
        params1: First PyTree of array leaves.
        params2: Second PyTree of array leaves (must share params1's structure).
        atol: Absolute tolerance passed to ``jnp.allclose``.
        rtol: Relative tolerance passed to ``jnp.allclose``.

    Returns:
        True if every leaf pair is close; False on any value or structure
        mismatch.
    """
    all_equal = True

    def compare_fn(p1, p2):
        nonlocal all_equal
        # Compute the comparison once and reuse it (the original evaluated
        # jnp.allclose twice per leaf, doubling the work).
        is_close = bool(jnp.allclose(p1, p2, atol=atol, rtol=rtol))
        if not is_close:
            print("❌ Mismatch found:")
            print(f"Shape 1: {p1.shape}, Shape 2: {p2.shape}")
            print(f"Max diff: {jnp.max(jnp.abs(p1 - p2))}")
            all_equal = False
        return is_close

    try:
        # tree_map raises (e.g. ValueError) when the two trees do not share
        # the same structure; we treat that as "not equal".
        _ = jax.tree_util.tree_map(compare_fn, params1, params2)
    except Exception as e:
        print("❌ Structure mismatch or error during comparison:", e)
        return False

    if all_equal:
        print("✅ All weights are equal (within tolerance)")
    return all_equal
0 commit comments