@@ -444,8 +444,6 @@ def train_jax_model(
444444 θ: list, # Initial parameters (pytree)
445445 x: jnp.ndarray, # Training input data
446446 y: jnp.ndarray, # Training target data
447- x_validate: jnp.ndarray, # Validation input data
448- y_validate: jnp.ndarray, # Validation target data
449447 config: Config # contains configuration data
450448 ):
451449 """
@@ -473,12 +471,12 @@ param_key = jax.random.PRNGKey(1234)
473471θ = initialize_network(param_key, config)
474472
475473# Warmup run to trigger JIT compilation
476- train_jax_model(θ, x_train, y_train, x_validate, y_validate, config)
474+ train_jax_model(θ, x_train, y_train, config)
477475
478476# Reset and time the actual run
479477θ = initialize_network(param_key, config)
480478start_time = time()
481- θ = train_jax_model(θ, x_train, y_train, x_validate, y_validate, config)
479+ θ = train_jax_model(θ, x_train, y_train, config)
482480θ[0].W.block_until_ready() # Ensure computation completes
483481jax_runtime = time() - start_time
484482
0 commit comments