Skip to content

Commit 7ec211f

Browse files
committed
getting it to level 0
1 parent 890f5af commit 7ec211f

2 files changed

Lines changed: 3 additions & 4 deletions

File tree

algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from algoperf.workloads.criteo1tb.criteo1tb_jax import models
1414
from algoperf.workloads.criteo1tb.workload import \
1515
BaseCriteo1TbDlrmSmallWorkload
16-
from custom_pytorch_jax_converter import use_pytorch_weights_inplace, use_pytorch_weights_inplace_mnist
16+
from custom_pytorch_jax_converter import use_pytorch_weights_inplace
1717

1818

1919

@@ -108,7 +108,6 @@ def init_model_fn(
108108
initial_params = use_pytorch_weights_inplace(initial_params, file_name="/results/pytorch_base_model_criteo1tb_24_june.pth")
109109
self._param_shapes = param_utils.jax_param_shapes(initial_params)
110110
self._param_types = param_utils.jax_param_types(self._param_shapes)
111-
return initial_params, None
112111
return jax_utils.replicate(initial_params), None
113112

114113
def is_output_params(self, param_key: spec.ParameterKey) -> bool:

reference_algorithms/schedule_free/jax/submission.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,8 @@ def update_params(workload: spec.Workload,
172172

173173
# Log the number of parameters.
174174
if global_step % 100 == 0 and workload.metrics_logger is not None:
175-
date_ = "2025=06-14"
176-
file_name = f"/results/schedule_free_test_pytorch_weights/criteo1tb_{date_}_after_{global_step}.pth"
175+
date_ = "2025-06-14"
176+
file_name = f"/results/schedule_free_test_pytorch_weights/criteo1tb_{date_}_after_{global_step}_steps.pth"
177177
params = use_pytorch_weights2(new_params, file_name=file_name, replicate=True)
178178
are_weights_equal(new_params, params)
179179
del params

0 commit comments

Comments
 (0)