Commit 7a0f8db

fix: cleaning up the code
1 parent 9837910

3 files changed: 17 additions & 103 deletions

algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py

Lines changed: 2 additions & 2 deletions

@@ -13,7 +13,7 @@
 from algoperf.workloads.criteo1tb.criteo1tb_jax import models
 from algoperf.workloads.criteo1tb.workload import \
     BaseCriteo1TbDlrmSmallWorkload
-from custom_pytorch_jax_converter import use_pytorch_weights_inplace
+from custom_pytorch_jax_converter import use_pytorch_weights



@@ -105,7 +105,7 @@ def init_model_fn(
         {'params': params_rng, 'dropout': dropout_rng},
         jnp.ones(input_shape, jnp.float32))
     initial_params = initial_variables['params']
-    initial_params = use_pytorch_weights_inplace(initial_params, file_name="/results/pytorch_base_model_criteo1tb_1_july.pth")
+    initial_params = use_pytorch_weights(file_name="/results/pytorch_base_model_criteo1tb_1_july.pth")
     self._param_shapes = param_utils.jax_param_shapes(initial_params)
     self._param_types = param_utils.jax_param_types(self._param_shapes)
     return jax_utils.replicate(initial_params), None
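
For orientation, a minimal sketch of the init path this hunk produces, assuming the Flax setup the workload already uses; the checkpoint path below is illustrative, not the commit's:

import jax
from flax import jax_utils

from custom_pytorch_jax_converter import use_pytorch_weights

# The converter now rebuilds the entire parameter tree from the checkpoint
# alone, so the freshly initialized Flax params are discarded rather than
# mutated in place.
initial_params = use_pytorch_weights(
    file_name='/results/example_checkpoint.pth')  # illustrative path
# Replicate across local devices for pmap, as init_model_fn does.
replicated_params = jax_utils.replicate(initial_params)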

custom_pytorch_jax_converter.py

Lines changed: 13 additions & 97 deletions

@@ -21,100 +21,40 @@
 The function assumes that the Jax model parameters are already initialized
 and that the PyTorch weights are in the correct format.
 """
-def use_pytorch_weights_inplace(jax_params, file_name=None, replicate=False):
 
-  # Load PyTorch state_dict
-  state_dict = torch.load(file_name)
-  print(state_dict.keys())
-
-  # Convert PyTorch tensors to NumPy arrays
-  numpy_weights = {k: v.cpu().numpy() for k, v in state_dict.items()}
-
-  # --- Embedding Table ---
-  embedding_table = np.concatenate([
-      numpy_weights[f'embedding_chunk_{i}'] for i in range(4)
-  ], axis=0)  # adjust axis depending on chunking direction
-
-  jax_params['embedding_table'] = jnp.array(embedding_table)
-
-  # --- Bot MLP: Dense_0 to Dense_2 ---
-  for i, j in zip([0, 2, 4], range(3)):
-    jax_params[f'Dense_{j}']['kernel'] = jnp.array(numpy_weights[f'bot_mlp.{i}.weight'].T)
-    jax_params[f'Dense_{j}']['bias'] = jnp.array(numpy_weights[f'bot_mlp.{i}.bias'])
-
-  # --- Top MLP: Dense_3 to Dense_7 ---
-  for i, j in zip([0, 2, 4, 6, 8], range(3, 8)):
-    jax_params[f'Dense_{j}']['kernel'] = jnp.array(numpy_weights[f'top_mlp.{i}.weight'].T)
-    jax_params[f'Dense_{j}']['bias'] = jnp.array(numpy_weights[f'top_mlp.{i}.bias'])
-
-  del state_dict
-  return jax_params
+def use_pytorch_weights(file_name: str):
+  jax_copy = {}
 
-
-def use_pytorch_weights_cpu_copy(jax_params, file_name=None, replicate=False):
-
-  def deep_copy_to_cpu(pytree):
-    return tree_map(lambda x: jax.device_put(jnp.array(copy.deepcopy(x)), device=jax.devices("cpu")[0]), pytree)
-
-  jax_copy = deep_copy_to_cpu(jax_params)
   # Load PyTorch state_dict lazily to CPU
   state_dict = torch.load(file_name, map_location='cpu')
   print(state_dict.keys())
-
+
   # Convert PyTorch tensors to NumPy arrays
-  numpy_weights = {k: v.cpu().numpy() for k, v in state_dict.items()}
+  numpy_weights = {k: v.cpu().numpy() for k, v in state_dict.items()}
 
   # --- Embedding Table ---
   embedding_table = np.concatenate([
       numpy_weights[f'embedding_chunk_{i}'] for i in range(4)
-  ], axis=0)  # adjust axis depending on chunking direction
+  ], axis=0)  # adjust axis if chunking is not vertical
 
   jax_copy['embedding_table'] = jnp.array(embedding_table)
 
   # --- Bot MLP: Dense_0 to Dense_2 ---
   for i, j in zip([0, 2, 4], range(3)):
+    jax_copy[f'Dense_{j}'] = {}
     jax_copy[f'Dense_{j}']['kernel'] = jnp.array(numpy_weights[f'bot_mlp.{i}.weight'].T)
     jax_copy[f'Dense_{j}']['bias'] = jnp.array(numpy_weights[f'bot_mlp.{i}.bias'])
 
   # --- Top MLP: Dense_3 to Dense_7 ---
   for i, j in zip([0, 2, 4, 6, 8], range(3, 8)):
+    jax_copy[f'Dense_{j}'] = {}
     jax_copy[f'Dense_{j}']['kernel'] = jnp.array(numpy_weights[f'top_mlp.{i}.weight'].T)
     jax_copy[f'Dense_{j}']['bias'] = jnp.array(numpy_weights[f'top_mlp.{i}.bias'])
-  #jax_copy = tree_map(lambda x: jnp.array(x), jax_copy)
-  del state_dict
 
+  del state_dict
   return jax_copy
 
 
-def use_pytorch_weights_inplace_mnist(jax_params, file_name=None, replicate=False):
-  # Load the PyTorch checkpoint
-  ckpt = torch.load(file_name)
-  state_dict = ckpt['state_dict'] if 'state_dict' in ckpt else ckpt
-
-  print("Loaded PyTorch keys:", state_dict.keys())
-
-  # Convert to numpy
-  numpy_weights = {k: v.cpu().numpy() for k, v in state_dict.items()}
-
-  # Mapping PyTorch keys → JAX Dense layers
-  layer_map = {
-      'net.layer1': 'Dense_0',
-      'net.layer2': 'Dense_1',
-  }
-
-  for pt_name, jax_name in layer_map.items():
-    weight_key = f"{pt_name}.weight"
-    bias_key = f"{pt_name}.bias"
-
-    if weight_key not in numpy_weights or bias_key not in numpy_weights:
-      raise KeyError(f"Missing keys: {weight_key} or {bias_key} in PyTorch weights")
-
-    jax_params[jax_name]['kernel'] = jnp.array(numpy_weights[weight_key].T)  # Transpose!
-    jax_params[jax_name]['bias'] = jnp.array(numpy_weights[bias_key])
-
-  return jax_params
-
-
 def maybe_unreplicate(pytree):
   """If leading axis matches device count, strip it assuming it's pmap replication."""
   num_devices = jax.device_count()
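
The .T on every weight above encodes a real layout difference: torch.nn.Linear stores its weight as (out_features, in_features), while a Flax Dense kernel is (in_features, out_features). A small self-contained check of that convention (the layer sizes here are illustrative):

import torch

# PyTorch stores Linear weights as (out_features, in_features).
linear = torch.nn.Linear(in_features=13, out_features=256)
w = linear.weight.detach().cpu().numpy()
assert w.shape == (256, 13)

# A Flax Dense kernel is (in_features, out_features), hence the transpose.
kernel = w.T
assert kernel.shape == (13, 256)

# Biases carry over unchanged.
bias = linear.bias.detach().cpu().numpy()
assert bias.shape == (256,)
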
@@ -123,6 +63,7 @@ def maybe_unreplicate(pytree):
       pytree
   )
 
+
 def move_to_cpu(tree):
   return jax.tree_util.tree_map(lambda x: jax.device_put(x, device=jax.devices("cpu")[0]), tree)
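
Only the tail of maybe_unreplicate is visible in this hunk. A plausible self-contained sketch that matches its docstring, assuming a leaf-by-leaf guard (the exact body in the file may differ):

import jax

def maybe_unreplicate(pytree):
  """If leading axis matches device count, strip it assuming it's pmap replication."""
  num_devices = jax.device_count()
  return jax.tree_util.tree_map(
      # Take device 0's copy when the leading axis looks like pmap replication.
      lambda x: x[0] if getattr(x, 'ndim', 0) >= 1 and x.shape[0] == num_devices
      else x,
      pytree
  )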

@@ -143,7 +84,7 @@ def compare_fn(p1, p2):
     nonlocal all_equal
     if not jnp.allclose(p1, p2, atol=atol, rtol=rtol):
       logging.info("❌ Mismatch found:")
-      logging.info(f"Shape 1: {p1.shape}, Shape 2: {p2.shape}")
+      logging.info(f"Shape : {p1.shape}, Shape 2: {p2.shape}")
       logging.info(f"Max diff: {jnp.max(jnp.abs(p1 - p2))}")
       all_equal = False
     return jnp.allclose(p1, p2, atol=atol, rtol=rtol)
@@ -156,31 +97,6 @@ def compare_fn(p1, p2):
 
   if all_equal:
     logging.info("✅ All weights are equal (within tolerance)")
-  return all_equal
-
-
-
-# def are_weights_equal(params1, params2, atol=1e-6, rtol=1e-6):
-#   """Compares two JAX PyTrees of weights and prints where they differ."""
-#   all_equal = True
-
-#   def compare_fn(p1, p2):
-#     nonlocal all_equal
-#     #if not jnp.allclose(p1, p2):
-#     if not jnp.allclose(p1, p2, atol=atol, rtol=rtol):
-#       logging.info("❌ Mismatch found:")
-#       logging.info(f"Shape 1: {p1.shape}, Shape 2: {p2.shape}")
-#       logging.info(f"Max diff: {jnp.max(jnp.abs(p1 - p2))}")
-#       all_equal = False
-#     return jnp.allclose(p1, p2, atol=atol, rtol=rtol)
-
-#   try:
-#     _ = jax.tree_util.tree_map(compare_fn, params1, params2)
-#   except Exception as e:
-#     logging.info("❌ Structure mismatch or error during comparison:", e)
-#     return False
-
-#   if all_equal:
-#     logging.info("✅ All weights are equal (within tolerance)")
-#   return all_equal
-
+  del params1
+  del params2
+  return all_equal
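
Taken together, the keys that use_pytorch_weights reads define the checkpoint contract. A sketch of a compatible state_dict, where only the key names come from the converter; every shape is a placeholder, and the even Sequential indices presumably skip interleaved activations:

import torch

state_dict = {}
# Four embedding chunks, concatenated along axis 0 by the converter.
for i in range(4):
  state_dict[f'embedding_chunk_{i}'] = torch.zeros(1000, 128)
# Bottom MLP: three Linear layers at indices 0, 2, 4.
for i in (0, 2, 4):
  state_dict[f'bot_mlp.{i}.weight'] = torch.zeros(8, 8)
  state_dict[f'bot_mlp.{i}.bias'] = torch.zeros(8)
# Top MLP: five Linear layers at indices 0, 2, 4, 6, 8.
for i in (0, 2, 4, 6, 8):
  state_dict[f'top_mlp.{i}.weight'] = torch.zeros(8, 8)
  state_dict[f'top_mlp.{i}.bias'] = torch.zeros(8)
torch.save(state_dict, '/tmp/example_criteo1tb_weights.pth')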

reference_algorithms/schedule_free/jax/submission.py

Lines changed: 2 additions & 4 deletions

@@ -10,7 +10,7 @@
 import jax.numpy as jnp
 from optax.contrib import schedule_free_adamw
 from algoperf import spec
-from custom_pytorch_jax_converter import use_pytorch_weights_cpu_copy, are_weights_equal
+from custom_pytorch_jax_converter import use_pytorch_weights, are_weights_equal
 
 _GRAD_CLIP_EPS = 1e-6
 
@@ -174,11 +174,9 @@ def update_params(workload: spec.Workload,
   if global_step % 100 == 0:
     date_ = "2025-07-01"
     file_name = f"/results/schedule_free_pytorch_weights/criteo1tb_{date_}_after_{global_step}_steps.pth"
-    params = use_pytorch_weights_cpu_copy(new_params, file_name=file_name, replicate=True)
+    params = use_pytorch_weights(file_name=file_name)
     are_weights_equal(new_params, params)
     del params
-
-    breakpoint()
 
   return (new_optimizer_state, opt_update_fn), new_params, new_model_state
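
The are_weights_equal call above relies on a tree_map-with-nonlocal-flag pattern. A self-contained sketch of that pattern; the tolerances mirror the converter's defaults, the rest is illustrative:

import jax
import jax.numpy as jnp

def trees_allclose(tree1, tree2, atol=1e-6, rtol=1e-6):
  """Returns True when every pair of leaves matches within tolerance."""
  all_equal = True

  def compare_fn(p1, p2):
    nonlocal all_equal
    if not jnp.allclose(p1, p2, atol=atol, rtol=rtol):
      all_equal = False  # record the mismatch but keep traversing
    return all_equal

  jax.tree_util.tree_map(compare_fn, tree1, tree2)
  return all_equal

print(trees_allclose({'w': jnp.ones(3)}, {'w': jnp.ones(3)}))   # True
print(trees_allclose({'w': jnp.ones(3)}, {'w': jnp.zeros(3)}))  # False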
