Skip to content

Commit c057f10

Browse files
committed
Comparison of weights initialized from a PyTorch saved state dict
1 parent b6768b2 commit c057f10

6 files changed

Lines changed: 49 additions & 13 deletions

File tree

algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import jax
88
import jax.numpy as jnp
99
import numpy as np
10-
10+
import logging
1111
from algoperf import param_utils
1212
from algoperf import spec
1313
from algoperf.workloads.criteo1tb.criteo1tb_jax import models
@@ -104,7 +104,8 @@ def init_model_fn(
104104
{'params': params_rng, 'dropout': dropout_rng},
105105
jnp.ones(input_shape, jnp.float32))
106106
initial_params = initial_variables['params']
107-
initial_params = use_pytorch_weights(initial_params)
107+
logging.info('\n\nInitializing with Pytorch weights\n\n')
108+
initial_params = use_pytorch_weights(initial_params, file_name="/results/pytorch_base_model_criteo1tb_8_june.pth")
108109
self._param_shapes = param_utils.jax_param_shapes(initial_params)
109110
self._param_types = param_utils.jax_param_types(self._param_shapes)
110111
return jax_utils.replicate(initial_params), None

algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def init_model_fn(
8888
dropout_rate=dropout_rate,
8989
use_layer_norm=self.use_layer_norm,
9090
embedding_init_multiplier=self.embedding_init_multiplier)
91-
torch.save(model.state_dict(), "/results/pytorch_base_model_criteo1tb_22_may.pth")
91+
torch.save(model.state_dict(), "/results/pytorch_base_model_criteo1tb_8_june.pth")
9292
self._param_shapes = param_utils.pytorch_param_shapes(model)
9393
self._param_types = param_utils.pytorch_param_types(self._param_shapes)
9494
model.to(DEVICE)

custom_pytorch_jax_converter.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,7 @@
11
import torch
22
import numpy as np
3-
from flax.core import freeze, unfreeze
4-
5-
# Load PyTorch state_dict
6-
state_dict = torch.load("/results/pytorch_base_model_criteo1tb_22_may.pth")
7-
8-
# Convert PyTorch tensors to NumPy arrays
9-
numpy_weights = {k: v.numpy() for k, v in state_dict.items()}
10-
11-
3+
import jax
4+
import jax.numpy as jnp
125
"""
136
Jax default parameter structure:
147
dict_keys(['Dense_0', 'Dense_1', 'Dense_2', 'Dense_3', 'Dense_4', 'Dense_5', 'Dense_6', 'Dense_7', 'embedding_table'])
@@ -23,7 +16,13 @@
2316
The function assumes that the Jax model parameters are already initialized
2417
and that the PyTorch weights are in the correct format.
2518
"""
26-
def use_pytorch_weights(jax_params):
19+
def use_pytorch_weights(jax_params, file_name=None):
20+
# Load PyTorch state_dict
21+
state_dict = torch.load(file_name)
22+
print(state_dict.keys())
23+
# Convert PyTorch tensors to NumPy arrays
24+
numpy_weights = {k: v.cpu().numpy() for k, v in state_dict.items()}
25+
2726
# --- Embedding Table ---
2827
embedding_table = np.concatenate([
2928
numpy_weights[f'embedding_chunk_{i}'] for i in range(4)
@@ -42,3 +41,28 @@ def use_pytorch_weights(jax_params):
4241
jax_params[f'Dense_{j}']['bias'] = numpy_weights[f'top_mlp.{i}.bias']
4342

4443
return jax_params
44+
45+
46+
def are_weights_equal(params1, params2, atol=1e-6, rtol=1e-6):
47+
"""Compares two JAX PyTrees of weights and prints where they differ."""
48+
all_equal = True
49+
50+
def compare_fn(p1, p2):
51+
nonlocal all_equal
52+
#if not jnp.allclose(p1, p2):
53+
if not jnp.allclose(p1, p2, atol=atol, rtol=rtol):
54+
print("❌ Mismatch found:")
55+
print(f"Shape 1: {p1.shape}, Shape 2: {p2.shape}")
56+
print(f"Max diff: {jnp.max(jnp.abs(p1 - p2))}")
57+
all_equal = False
58+
return jnp.allclose(p1, p2, atol=atol, rtol=rtol)
59+
60+
try:
61+
_ = jax.tree_util.tree_map(compare_fn, params1, params2)
62+
except Exception as e:
63+
print("❌ Structure mismatch or error during comparison:", e)
64+
return False
65+
66+
if all_equal:
67+
print("✅ All weights are equal (within tolerance)")
68+
return all_equal

reference_algorithms/schedule_free/jax/submission.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
import jax
99
from jax import lax
1010
import jax.numpy as jnp
11+
import logging
1112
from optax.contrib import schedule_free_adamw
1213
from algoperf import spec
14+
from custom_pytorch_jax_converter import use_pytorch_weights, are_weights_equal
1315

1416
_GRAD_CLIP_EPS = 1e-6
1517

@@ -168,6 +170,9 @@ def update_params(workload: spec.Workload,
168170
'loss': loss[0],
169171
'grad_norm': grad_norm[0],
170172
}, global_step)
173+
logging.info('\n\nUsing the PyTorch weights of first update\n\n')
174+
params = use_pytorch_weights(new_params, file_name="/results/pytorch_base_model_criteo1tb_8_june_after_first_update.pth")
175+
are_weights_equal(new_params, params)
171176
return (new_optimizer_state, opt_update_fn), new_params, new_model_state
172177

173178

reference_algorithms/schedule_free/pytorch/submission.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,8 @@ def closure():
280280
global_step,
281281
loss.item())
282282

283+
torch.save(current_param_container.module.state_dict(), "/results/pytorch_base_model_criteo1tb_8_june_after_first_update.pth")
284+
283285
return (optimizer_state, current_param_container, new_model_state)
284286

285287

submission_runner.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,8 +376,12 @@ def train_once(
376376
train_state['training_complete'] = True
377377
global_step += 1
378378
logging.info(f'Global step: {global_step}, batch size: {len(batch)}')
379+
379380
if (max_global_steps is not None) and (global_step == max_global_steps):
380381
train_state['training_complete'] = True
382+
if dist.is_available() and dist.is_initialized():
383+
dist.destroy_process_group()
384+
exit(0)
381385

382386
train_step_end_time = get_time()
383387

0 commit comments

Comments (0)