File tree Expand file tree Collapse file tree
algoperf/workloads/criteo1tb/criteo1tb_jax
reference_algorithms/schedule_free/jax Expand file tree Collapse file tree Original file line number Diff line number Diff line change 1313from algoperf .workloads .criteo1tb .criteo1tb_jax import models
1414from algoperf .workloads .criteo1tb .workload import \
1515 BaseCriteo1TbDlrmSmallWorkload
16- from custom_pytorch_jax_converter import use_pytorch_weights_inplace , use_pytorch_weights_inplace_mnist
16+ from custom_pytorch_jax_converter import use_pytorch_weights_inplace
1717
1818
1919
@@ -108,7 +108,6 @@ def init_model_fn(
108108 initial_params = use_pytorch_weights_inplace (initial_params , file_name = "/results/pytorch_base_model_criteo1tb_24_june.pth" )
109109 self ._param_shapes = param_utils .jax_param_shapes (initial_params )
110110 self ._param_types = param_utils .jax_param_types (self ._param_shapes )
111- return initial_params , None
112111 return jax_utils .replicate (initial_params ), None
113112
114113 def is_output_params (self , param_key : spec .ParameterKey ) -> bool :
Original file line number Diff line number Diff line change @@ -172,8 +172,8 @@ def update_params(workload: spec.Workload,
172172
173173 # Log the number of parameters.
174174 if global_step % 100 == 0 and workload .metrics_logger is not None :
175- date_ = "2025= 06-14"
176- file_name = f"/results/schedule_free_test_pytorch_weights/criteo1tb_{ date_ } _after_{ global_step } .pth"
175+ date_ = "2025- 06-14"
176+ file_name = f"/results/schedule_free_test_pytorch_weights/criteo1tb_{ date_ } _after_{ global_step } _steps .pth"
177177 params = use_pytorch_weights2 (new_params , file_name = file_name , replicate = True )
178178 are_weights_equal (new_params , params )
179179 del params
You can’t perform that action at this time.
0 commit comments