 The function assumes that the JAX model parameters are already initialized
 and that the PyTorch weights are in the correct format.
 """
-def use_pytorch_weights_inplace(jax_params, file_name=None, replicate=False):
 
-    # Load PyTorch state_dict
-    state_dict = torch.load(file_name)
-    print(state_dict.keys())
-
-    # Convert PyTorch tensors to NumPy arrays
-    numpy_weights = {k: v.cpu().numpy() for k, v in state_dict.items()}
-
-    # --- Embedding Table ---
-    embedding_table = np.concatenate([
-        numpy_weights[f'embedding_chunk_{i}'] for i in range(4)
-    ], axis=0)  # adjust axis depending on chunking direction
-
-    jax_params['embedding_table'] = jnp.array(embedding_table)
-
-    # --- Bot MLP: Dense_0 to Dense_2 ---
-    for i, j in zip([0, 2, 4], range(3)):
-        jax_params[f'Dense_{j}']['kernel'] = jnp.array(numpy_weights[f'bot_mlp.{i}.weight'].T)
-        jax_params[f'Dense_{j}']['bias'] = jnp.array(numpy_weights[f'bot_mlp.{i}.bias'])
-
-    # --- Top MLP: Dense_3 to Dense_7 ---
-    for i, j in zip([0, 2, 4, 6, 8], range(3, 8)):
-        jax_params[f'Dense_{j}']['kernel'] = jnp.array(numpy_weights[f'top_mlp.{i}.weight'].T)
-        jax_params[f'Dense_{j}']['bias'] = jnp.array(numpy_weights[f'top_mlp.{i}.bias'])
-
-    del state_dict
-    return jax_params
+def use_pytorch_weights(file_name: str):
+    jax_copy = {}
 
-
-def use_pytorch_weights_cpu_copy(jax_params, file_name=None, replicate=False):
-
-    def deep_copy_to_cpu(pytree):
-        return tree_map(lambda x: jax.device_put(jnp.array(copy.deepcopy(x)), device=jax.devices("cpu")[0]), pytree)
-
-    jax_copy = deep_copy_to_cpu(jax_params)
     # Load PyTorch state_dict lazily to CPU
     state_dict = torch.load(file_name, map_location='cpu')
     print(state_dict.keys())
-
+
     # Convert PyTorch tensors to NumPy arrays
-    numpy_weights = {k: v.cpu().numpy() for k, v in state_dict.items()}
+    numpy_weights = {k: v.cpu().numpy() for k, v in state_dict.items()}
 
     # --- Embedding Table ---
     embedding_table = np.concatenate([
         numpy_weights[f'embedding_chunk_{i}'] for i in range(4)
-    ], axis=0)  # adjust axis depending on chunking direction
+    ], axis=0)  # adjust axis if chunking is not vertical
 
     jax_copy['embedding_table'] = jnp.array(embedding_table)
 
     # --- Bot MLP: Dense_0 to Dense_2 ---
     for i, j in zip([0, 2, 4], range(3)):
+        jax_copy[f'Dense_{j}'] = {}
         jax_copy[f'Dense_{j}']['kernel'] = jnp.array(numpy_weights[f'bot_mlp.{i}.weight'].T)
         jax_copy[f'Dense_{j}']['bias'] = jnp.array(numpy_weights[f'bot_mlp.{i}.bias'])
 
     # --- Top MLP: Dense_3 to Dense_7 ---
     for i, j in zip([0, 2, 4, 6, 8], range(3, 8)):
+        jax_copy[f'Dense_{j}'] = {}
         jax_copy[f'Dense_{j}']['kernel'] = jnp.array(numpy_weights[f'top_mlp.{i}.weight'].T)
         jax_copy[f'Dense_{j}']['bias'] = jnp.array(numpy_weights[f'top_mlp.{i}.bias'])
-    # jax_copy = tree_map(lambda x: jnp.array(x), jax_copy)
-    del state_dict
 
+    del state_dict
     return jax_copy
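
The two layout conversions above are easy to get wrong, so here is a minimal sketch (hypothetical shapes, not from this repo) of what they do: torch.nn.Linear stores its weight as (out_features, in_features) while a flax.linen.Dense kernel is (in_features, out_features), hence the .T, and the embedding chunks are assumed to be row-wise vocabulary shards, hence axis=0.

import numpy as np

# torch.nn.Linear weight is (out_features, in_features); a Flax Dense kernel
# is (in_features, out_features), so the conversion transposes.
pt_weight = np.zeros((16, 8))       # hypothetical (out=16, in=8) Linear weight
flax_kernel = pt_weight.T           # (8, 16) kernel for Dense(features=16)
assert flax_kernel.shape == (8, 16)

# Assumed: the embedding is sharded row-wise (over the vocabulary), so
# concatenating the chunks on axis 0 rebuilds the full table.
chunks = [np.zeros((100, 32)) for _ in range(4)]
table = np.concatenate(chunks, axis=0)
assert table.shape == (400, 32)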
 
 
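A hedged usage sketch: the checkpoint path and the surrounding Flax model below are assumptions, not part of this diff. The dict returned by use_pytorch_weights mirrors a flax.linen parameter tree, so it can be passed to Module.apply under the 'params' collection.

# Hypothetical call site; 'checkpoint.pt' and `model` are assumed to exist.
params = use_pytorch_weights('checkpoint.pt')
params = move_to_cpu(params)  # optional: keep the converted weights on host
# logits = model.apply({'params': params}, batch)
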
-def use_pytorch_weights_inplace_mnist(jax_params, file_name=None, replicate=False):
-    # Load the PyTorch checkpoint
-    ckpt = torch.load(file_name)
-    state_dict = ckpt['state_dict'] if 'state_dict' in ckpt else ckpt
-
-    print("Loaded PyTorch keys:", state_dict.keys())
-
-    # Convert to numpy
-    numpy_weights = {k: v.cpu().numpy() for k, v in state_dict.items()}
-
-    # Mapping PyTorch keys → JAX Dense layers
-    layer_map = {
-        'net.layer1': 'Dense_0',
-        'net.layer2': 'Dense_1',
-    }
-
-    for pt_name, jax_name in layer_map.items():
-        weight_key = f"{pt_name}.weight"
-        bias_key = f"{pt_name}.bias"
-
-        if weight_key not in numpy_weights or bias_key not in numpy_weights:
-            raise KeyError(f"Missing keys: {weight_key} or {bias_key} in PyTorch weights")
-
-        jax_params[jax_name]['kernel'] = jnp.array(numpy_weights[weight_key].T)  # Transpose!
-        jax_params[jax_name]['bias'] = jnp.array(numpy_weights[bias_key])
-
-    return jax_params
-
-
 def maybe_unreplicate(pytree):
     """If leading axis matches device count, strip it assuming it's pmap replication."""
     num_devices = jax.device_count()
@@ -123,6 +63,7 @@ def maybe_unreplicate(pytree):
         pytree
     )
 
+
 def move_to_cpu(tree):
     return jax.tree_util.tree_map(lambda x: jax.device_put(x, device=jax.devices("cpu")[0]), tree)
 
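A small sketch of how the two helpers compose, assuming (per the docstring) that maybe_unreplicate keeps one replica when the leading axis equals the device count:

import jax
import jax.numpy as jnp

n = jax.device_count()
replicated = {'w': jnp.ones((n, 3, 3))}  # leading axis equals device count
params = maybe_unreplicate(replicated)   # presumably -> {'w': (3, 3)}
params = move_to_cpu(params)             # same tree, now on the CPU device
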
@@ -143,7 +84,7 @@ def compare_fn(p1, p2):
         nonlocal all_equal
         if not jnp.allclose(p1, p2, atol=atol, rtol=rtol):
             logging.info("❌ Mismatch found:")
-            logging.info(f"Shape 1 : {p1.shape}, Shape 2: {p2.shape}")
+            logging.info(f"Shape 1: {p1.shape}, Shape 2: {p2.shape}")
             logging.info(f"Max diff: {jnp.max(jnp.abs(p1 - p2))}")
             all_equal = False
         return jnp.allclose(p1, p2, atol=atol, rtol=rtol)
@@ -156,31 +97,6 @@ def compare_fn(p1, p2):
 
     if all_equal:
         logging.info("✅ All weights are equal (within tolerance)")
-    return all_equal
-
-
-
-# def are_weights_equal(params1, params2, atol=1e-6, rtol=1e-6):
-#     """Compares two JAX PyTrees of weights and prints where they differ."""
-#     all_equal = True
-
-#     def compare_fn(p1, p2):
-#         nonlocal all_equal
-#         #if not jnp.allclose(p1, p2):
-#         if not jnp.allclose(p1, p2, atol=atol, rtol=rtol):
-#             logging.info("❌ Mismatch found:")
-#             logging.info(f"Shape 1: {p1.shape}, Shape 2: {p2.shape}")
-#             logging.info(f"Max diff: {jnp.max(jnp.abs(p1 - p2))}")
-#             all_equal = False
-#         return jnp.allclose(p1, p2, atol=atol, rtol=rtol)
-
-#     try:
-#         _ = jax.tree_util.tree_map(compare_fn, params1, params2)
-#     except Exception as e:
-#         logging.info("❌ Structure mismatch or error during comparison:", e)
-#         return False
-
-#     if all_equal:
-#         logging.info("✅ All weights are equal (within tolerance)")
-#     return all_equal
-
+    del params1
+    del params2
+    return all_equal
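
Assuming the live comparison function kept the are_weights_equal(params1, params2, atol=1e-6, rtol=1e-6) signature of the commented-out copy removed above, a quick self-check might look like:

import jax
import jax.numpy as jnp

p1 = {'Dense_0': {'kernel': jnp.ones((8, 16)), 'bias': jnp.zeros(16)}}
p2 = jax.tree_util.tree_map(lambda x: x + 1e-8, p1)  # negligible perturbation
assert are_weights_equal(p1, p2)  # equal within the default tolerances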