@@ -57,15 +57,18 @@ def init_optimizer_state(workload: spec.Workload,
5757 workload .param_shapes )
5858 optimizer_state = opt_init_fn (params_zeros_like )
5959
60- return jax_utils .replicate (optimizer_state ), opt_update_fn
60+ #return jax_utils.replicate(optimizer_state), opt_update_fn
61+ return optimizer_state , opt_update_fn
6162
6263
6364@functools .partial (
64- jax .pmap ,
65+ jax .vmap ,
6566 axis_name = 'batch' ,
66- in_axes = (None , None , 0 , 0 , 0 , 0 , 0 , None , None ),
67- static_broadcasted_argnums = (0 , 1 ),
68- donate_argnums = (2 , 3 , 4 ))
67+ #in_axes=(None, None, None, 0, 0, 0, 0, None, None))
68+ in_axes = (None , None , None , None , None , 0 , 0 , None , None ))
69+ #in_axes=(None, None, None, 0, None, 0, 0, None, None))
70+ # static_broadcasted_argnums=(0, 1),
71+ # donate_argnums=(2, 3, 4))
6972def pmapped_train_step (workload ,
7073 opt_update_fn ,
7174 model_state ,
@@ -143,7 +146,7 @@ def update_params(workload: spec.Workload,
143146 del eval_results
144147
145148 optimizer_state , opt_update_fn = optimizer_state
146- per_device_rngs = jax .random .split (rng , jax .local_device_count ())
149+ # per_device_rngs = jax.random.split(rng, jax.local_device_count())
147150 if hasattr (hyperparameters , 'label_smoothing' ):
148151 label_smoothing = hyperparameters .label_smoothing
149152 else :
@@ -152,14 +155,22 @@ def update_params(workload: spec.Workload,
152155 grad_clip = hyperparameters .grad_clip
153156 else :
154157 grad_clip = None
158+
159+ per_example_rngs = jax .random .split (rng , batch ['inputs' ].shape [0 ])
160+
161+ print ("Optimizer state: " , jax .tree_map (lambda x : x .shape , optimizer_state ))
162+ print ("Current param container: " , jax .tree_map (lambda x : x .shape , current_param_container ))
163+ print ("model state: " , jax .tree_map (lambda x : x .shape , model_state ))
164+ print ("batch: " , jax .tree_map (lambda x : x .shape , batch ))
165+ print ("rng: " , jax .tree_map (lambda x : x .shape , per_example_rngs ))
155166
156167 outputs = pmapped_train_step (workload ,
157168 opt_update_fn ,
158169 model_state ,
159170 optimizer_state ,
160171 current_param_container ,
161172 batch ,
162- per_device_rngs ,
173+ per_example_rngs ,
163174 grad_clip ,
164175 label_smoothing )
165176 breakpoint ()
@@ -186,6 +197,7 @@ def update_params(workload: spec.Workload,
186197def get_batch_size (workload_name ):
187198 # Return the global batch size.
188199 if workload_name == 'criteo1tb' :
200+ return 8
189201 return 262_144
190202 elif workload_name == 'fastmri' :
191203 return 32
0 commit comments