Skip to content

Commit 453913f

Browse files
committed
trying vmap for debugging
1 parent a422795 commit 453913f

2 files changed

Lines changed: 20 additions & 7 deletions

File tree

algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -108,6 +108,7 @@ def init_model_fn(
108108
initial_params = use_pytorch_weights_inplace(initial_params, file_name="/results/pytorch_base_model_criteo1tb_24_june.pth")
109109
self._param_shapes = param_utils.jax_param_shapes(initial_params)
110110
self._param_types = param_utils.jax_param_types(self._param_shapes)
111+
return initial_params, None
111112
return jax_utils.replicate(initial_params), None
112113

113114
def is_output_params(self, param_key: spec.ParameterKey) -> bool:

reference_algorithms/schedule_free/jax/submission.py

Lines changed: 19 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -57,15 +57,18 @@ def init_optimizer_state(workload: spec.Workload,
5757
workload.param_shapes)
5858
optimizer_state = opt_init_fn(params_zeros_like)
5959

60-
return jax_utils.replicate(optimizer_state), opt_update_fn
60+
#return jax_utils.replicate(optimizer_state), opt_update_fn
61+
return optimizer_state, opt_update_fn
6162

6263

6364
@functools.partial(
64-
jax.pmap,
65+
jax.vmap,
6566
axis_name='batch',
66-
in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
67-
static_broadcasted_argnums=(0, 1),
68-
donate_argnums=(2, 3, 4))
67+
#in_axes=(None, None, None, 0, 0, 0, 0, None, None))
68+
in_axes=(None, None, None, None, None, 0, 0, None, None))
69+
#in_axes=(None, None, None, 0, None, 0, 0, None, None))
70+
# static_broadcasted_argnums=(0, 1),
71+
# donate_argnums=(2, 3, 4))
6972
def pmapped_train_step(workload,
7073
opt_update_fn,
7174
model_state,
@@ -143,7 +146,7 @@ def update_params(workload: spec.Workload,
143146
del eval_results
144147

145148
optimizer_state, opt_update_fn = optimizer_state
146-
per_device_rngs = jax.random.split(rng, jax.local_device_count())
149+
#per_device_rngs = jax.random.split(rng, jax.local_device_count())
147150
if hasattr(hyperparameters, 'label_smoothing'):
148151
label_smoothing = hyperparameters.label_smoothing
149152
else:
@@ -152,14 +155,22 @@ def update_params(workload: spec.Workload,
152155
grad_clip = hyperparameters.grad_clip
153156
else:
154157
grad_clip = None
158+
159+
per_example_rngs = jax.random.split(rng, batch['inputs'].shape[0])
160+
161+
print("Optimizer state: ", jax.tree_map(lambda x: x.shape, optimizer_state))
162+
print("Current param container: ", jax.tree_map(lambda x: x.shape, current_param_container))
163+
print("model state: ", jax.tree_map(lambda x: x.shape, model_state))
164+
print("batch: ", jax.tree_map(lambda x: x.shape, batch))
165+
print("rng: ", jax.tree_map(lambda x: x.shape, per_example_rngs))
155166

156167
outputs = pmapped_train_step(workload,
157168
opt_update_fn,
158169
model_state,
159170
optimizer_state,
160171
current_param_container,
161172
batch,
162-
per_device_rngs,
173+
per_example_rngs,
163174
grad_clip,
164175
label_smoothing)
165176
breakpoint()
@@ -186,6 +197,7 @@ def update_params(workload: spec.Workload,
186197
def get_batch_size(workload_name):
187198
# Return the global batch size.
188199
if workload_name == 'criteo1tb':
200+
return 8
189201
return 262_144
190202
elif workload_name == 'fastmri':
191203
return 32

0 commit comments

Comments
 (0)