Skip to content

Commit a422795

Browse files
committed
schedule free comparisons with criteo
1 parent 593cf49 commit a422795

7 files changed

Lines changed: 753 additions & 1 deletion

File tree

algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from algoperf.workloads.criteo1tb.criteo1tb_jax import models
1414
from algoperf.workloads.criteo1tb.workload import \
1515
BaseCriteo1TbDlrmSmallWorkload
16+
from custom_pytorch_jax_converter import use_pytorch_weights_inplace, use_pytorch_weights_inplace_mnist
17+
1618

1719

1820
class Criteo1TbDlrmSmallWorkload(BaseCriteo1TbDlrmSmallWorkload):
@@ -103,6 +105,7 @@ def init_model_fn(
103105
{'params': params_rng, 'dropout': dropout_rng},
104106
jnp.ones(input_shape, jnp.float32))
105107
initial_params = initial_variables['params']
108+
initial_params = use_pytorch_weights_inplace(initial_params, file_name="/results/pytorch_base_model_criteo1tb_24_june.pth")
106109
self._param_shapes = param_utils.jax_param_shapes(initial_params)
107110
self._param_types = param_utils.jax_param_types(self._param_shapes)
108111
return jax_utils.replicate(initial_params), None

algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def init_model_fn(
8888
dropout_rate=dropout_rate,
8989
use_layer_norm=self.use_layer_norm,
9090
embedding_init_multiplier=self.embedding_init_multiplier)
91+
torch.save(model.state_dict(), "/results/pytorch_base_model_criteo1tb_24_june.pth")
9192
self._param_shapes = param_utils.pytorch_param_shapes(model)
9293
self._param_types = param_utils.pytorch_param_types(self._param_shapes)
9394
model.to(DEVICE)

algoperf/workloads/mnist/mnist_jax/workload.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from algoperf import param_utils
1414
from algoperf import spec
1515
from algoperf.workloads.mnist.workload import BaseMnistWorkload
16-
16+
from custom_pytorch_jax_converter import use_pytorch_weights_inplace, use_pytorch_weights_inplace_mnist
1717

1818
class _Model(nn.Module):
1919

@@ -42,8 +42,10 @@ def init_model_fn(
4242
del aux_dropout_rate
4343
init_val = jnp.ones((1, 28, 28, 1), jnp.float32)
4444
self._model = _Model()
45+
4546
initial_params = self._model.init({'params': rng}, init_val,
4647
train=True)['params']
48+
initial_params = use_pytorch_weights_inplace_mnist(initial_params, file_name="/results/pytorch_base_model_mnist_24june.pth")
4749
self._param_shapes = param_utils.jax_param_shapes(initial_params)
4850
self._param_types = param_utils.jax_param_types(self._param_shapes)
4951
return jax_utils.replicate(initial_params), None

algoperf/workloads/mnist/mnist_pytorch/workload.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ def init_model_fn(
135135

136136
torch.random.manual_seed(rng[0])
137137
self._model = _Model()
138+
torch.save(self._model.state_dict(), "/results/pytorch_base_model_mnist_24june.pth")
138139
self._param_shapes = param_utils.pytorch_param_shapes(self._model)
139140
self._param_types = param_utils.pytorch_param_types(self._param_shapes)
140141
self._model.to(DEVICE)

custom_pytorch_jax_converter.py

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
import torch
2+
import numpy as np
3+
import jax
4+
import jax.numpy as jnp
5+
import logging
6+
import copy
7+
import copy
8+
from jax.tree_util import tree_map
9+
"""
10+
Jax default parameter structure:
11+
dict_keys(['Dense_0', 'Dense_1', 'Dense_2', 'Dense_3', 'Dense_4', 'Dense_5', 'Dense_6', 'Dense_7', 'embedding_table'])
12+
13+
Pytorch stateduct structure:
14+
dict_keys(['embedding_chunk_0', 'embedding_chunk_1', 'embedding_chunk_2', 'embedding_chunk_3', 'bot_mlp.0.weight', 'bot_mlp.0.bias', 'bot_mlp.2.weight', 'bot_mlp.2.bias', 'bot_mlp.4.weight', 'bot_mlp.4.bias', 'top_mlp.0.weight', 'top_mlp.0.bias', 'top_mlp.2.weight', 'top_mlp.2.bias', 'top_mlp.4.weight', 'top_mlp.4.bias', 'top_mlp.6.weight', 'top_mlp.6.bias', 'top_mlp.8.weight', 'top_mlp.8.bias'])
15+
16+
17+
18+
The following function converts the PyTorch weights to the Jax format
19+
and assigns them to the Jax model parameters.
20+
The function assumes that the Jax model parameters are already initialized
21+
and that the PyTorch weights are in the correct format.
22+
"""
23+
def use_pytorch_weights_inplace(jax_params, file_name=None, replicate=False):
24+
25+
# Load PyTorch state_dict
26+
state_dict = torch.load(file_name)
27+
print(state_dict.keys())
28+
# Convert PyTorch tensors to NumPy arrays
29+
numpy_weights = {k: v.cpu().numpy() for k, v in state_dict.items()}
30+
31+
# --- Embedding Table ---
32+
embedding_table = np.concatenate([
33+
numpy_weights[f'embedding_chunk_{i}'] for i in range(4)
34+
], axis=0) # adjust axis depending on chunking direction
35+
36+
jax_params['embedding_table'] = jnp.array(embedding_table)
37+
38+
# --- Bot MLP: Dense_0 to Dense_2 ---
39+
for i, j in zip([0, 2, 4], range(3)):
40+
jax_params[f'Dense_{j}']['kernel'] = jnp.array(numpy_weights[f'bot_mlp.{i}.weight'].T)
41+
jax_params[f'Dense_{j}']['bias'] = jnp.array(numpy_weights[f'bot_mlp.{i}.bias'])
42+
43+
# --- Top MLP: Dense_3 to Dense_7 ---
44+
for i, j in zip([0, 2, 4, 6, 8], range(3, 8)):
45+
jax_params[f'Dense_{j}']['kernel'] = jnp.array(numpy_weights[f'top_mlp.{i}.weight'].T)
46+
jax_params[f'Dense_{j}']['bias'] = jnp.array(numpy_weights[f'top_mlp.{i}.bias'])
47+
#jax_params = tree_map(lambda x: jnp.array(x), jax_params)
48+
del state_dict
49+
return jax_params
50+
51+
52+
def use_pytorch_weights_inplace_mnist(jax_params, file_name=None, replicate=False):
53+
# Load the PyTorch checkpoint
54+
ckpt = torch.load(file_name)
55+
state_dict = ckpt['state_dict'] if 'state_dict' in ckpt else ckpt
56+
57+
print("Loaded PyTorch keys:", state_dict.keys())
58+
59+
# Convert to numpy
60+
numpy_weights = {k: v.cpu().numpy() for k, v in state_dict.items()}
61+
62+
# Mapping PyTorch keys → JAX Dense layers
63+
layer_map = {
64+
'net.layer1': 'Dense_0',
65+
'net.layer2': 'Dense_1',
66+
}
67+
68+
for pt_name, jax_name in layer_map.items():
69+
weight_key = f"{pt_name}.weight"
70+
bias_key = f"{pt_name}.bias"
71+
72+
if weight_key not in numpy_weights or bias_key not in numpy_weights:
73+
raise KeyError(f"Missing keys: {weight_key} or {bias_key} in PyTorch weights")
74+
75+
jax_params[jax_name]['kernel'] = jnp.array(numpy_weights[weight_key].T) # Transpose!
76+
jax_params[jax_name]['bias'] = jnp.array(numpy_weights[bias_key])
77+
78+
return jax_params
79+
80+
81+
# def are_weights_equal(params1, params2, atol=1e-6, rtol=1e-6):
82+
# """Compares two JAX PyTrees of weights and prints where they differ."""
83+
# all_equal = True
84+
85+
# def compare_fn(p1, p2):
86+
# nonlocal all_equal
87+
# #if not jnp.allclose(p1, p2):
88+
# if not jnp.allclose(p1, p2, atol=atol, rtol=rtol):
89+
# logging.info("❌ Mismatch found:")
90+
# logging.info(f"Shape 1: {p1.shape}, Shape 2: {p2.shape}")
91+
# logging.info(f"Max diff: {jnp.max(jnp.abs(p1 - p2))}")
92+
# all_equal = False
93+
# return jnp.allclose(p1, p2, atol=atol, rtol=rtol)
94+
95+
# try:
96+
# _ = jax.tree_util.tree_map(compare_fn, params1, params2)
97+
# except Exception as e:
98+
# logging.info("❌ Structure mismatch or error during comparison:", e)
99+
# return False
100+
101+
# if all_equal:
102+
# logging.info("✅ All weights are equal (within tolerance)")
103+
# return all_equal
104+
105+
import jax
106+
import jax.numpy as jnp
107+
import logging
108+
109+
def maybe_unreplicate(pytree):
110+
"""If leading axis matches device count, strip it assuming it's pmap replication."""
111+
num_devices = jax.device_count()
112+
return jax.tree_util.tree_map(
113+
lambda x: x[0] if isinstance(x, jax.Array) and x.shape[0] == num_devices else x,
114+
pytree
115+
)
116+
117+
def move_to_cpu(tree):
118+
return jax.tree_util.tree_map(lambda x: jax.device_put(x, device=jax.devices("cpu")[0]), tree)
119+
120+
121+
def are_weights_equal(params1, params2, atol=1e-6, rtol=1e-6):
122+
"""Compares two JAX PyTrees of weights and logs where they differ, safely handling PMAP replication."""
123+
# Attempt to unreplicate if needed
124+
params1 = maybe_unreplicate(params1)
125+
params2 = maybe_unreplicate(params2)
126+
127+
params1 = move_to_cpu(params1)
128+
params2 = move_to_cpu(params2)
129+
130+
all_equal = True
131+
132+
def compare_fn(p1, p2):
133+
nonlocal all_equal
134+
if not jnp.allclose(p1, p2, atol=atol, rtol=rtol):
135+
logging.info("❌ Mismatch found:")
136+
logging.info(f"Shape 1: {p1.shape}, Shape 2: {p2.shape}")
137+
logging.info(f"Max diff: {jnp.max(jnp.abs(p1 - p2))}")
138+
all_equal = False
139+
return jnp.allclose(p1, p2, atol=atol, rtol=rtol)
140+
141+
try:
142+
jax.tree_util.tree_map(compare_fn, params1, params2)
143+
except Exception as e:
144+
logging.info("❌ Structure mismatch or error during comparison:", exc_info=True)
145+
return False
146+
147+
if all_equal:
148+
logging.info("✅ All weights are equal (within tolerance)")
149+
return all_equal
150+
151+
152+
153+
def use_pytorch_weights2(jax_params, file_name=None, replicate=False):
154+
155+
def deep_copy_to_cpu(pytree):
156+
return tree_map(lambda x: jax.device_put(jnp.array(copy.deepcopy(x)), device=jax.devices("cpu")[0]), pytree)
157+
158+
breakpoint()
159+
jax_copy = deep_copy_to_cpu(jax_params)
160+
# Load PyTorch state_dict lazily to CPU
161+
state_dict = torch.load(file_name, map_location='cpu')
162+
print(state_dict.keys())
163+
# Convert PyTorch tensors to NumPy arrays
164+
numpy_weights = {k: v.cpu().numpy() for k, v in state_dict.items()}
165+
166+
# --- Embedding Table ---
167+
embedding_table = np.concatenate([
168+
numpy_weights[f'embedding_chunk_{i}'] for i in range(4)
169+
], axis=0) # adjust axis depending on chunking direction
170+
171+
jax_copy['embedding_table'] = jnp.array(embedding_table)
172+
173+
# --- Bot MLP: Dense_0 to Dense_2 ---
174+
for i, j in zip([0, 2, 4], range(3)):
175+
jax_copy[f'Dense_{j}']['kernel'] = jnp.array(numpy_weights[f'bot_mlp.{i}.weight'].T)
176+
jax_copy[f'Dense_{j}']['bias'] = jnp.array(numpy_weights[f'bot_mlp.{i}.bias'])
177+
178+
# --- Top MLP: Dense_3 to Dense_7 ---
179+
for i, j in zip([0, 2, 4, 6, 8], range(3, 8)):
180+
jax_copy[f'Dense_{j}']['kernel'] = jnp.array(numpy_weights[f'top_mlp.{i}.weight'].T)
181+
jax_copy[f'Dense_{j}']['bias'] = jnp.array(numpy_weights[f'top_mlp.{i}.bias'])
182+
#jax_copy = tree_map(lambda x: jnp.array(x), jax_copy)
183+
del state_dict
184+
185+
return jax_copy
186+

0 commit comments

Comments
 (0)