Skip to content

Commit 9dd8db0

Browse files
BUG: Remove x_validate and y_validate as they're unnecessary (#266)
1 parent 102113e commit 9dd8db0

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

lectures/jax_nn.md

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -444,8 +444,6 @@ def train_jax_model(
444444
θ: list, # Initial parameters (pytree)
445445
x: jnp.ndarray, # Training input data
446446
y: jnp.ndarray, # Training target data
447-
x_validate: jnp.ndarray, # Validation input data
448-
y_validate: jnp.ndarray, # Validation target data
449447
config: Config # contains configuration data
450448
):
451449
"""
@@ -473,12 +471,12 @@ param_key = jax.random.PRNGKey(1234)
473471
θ = initialize_network(param_key, config)
474472
475473
# Warmup run to trigger JIT compilation
476-
train_jax_model(θ, x_train, y_train, x_validate, y_validate, config)
474+
train_jax_model(θ, x_train, y_train, config)
477475
478476
# Reset and time the actual run
479477
θ = initialize_network(param_key, config)
480478
start_time = time()
481-
θ = train_jax_model(θ, x_train, y_train, x_validate, y_validate, config)
479+
θ = train_jax_model(θ, x_train, y_train, config)
482480
θ[0].W.block_until_ready() # Ensure computation completes
483481
jax_runtime = time() - start_time
484482

0 commit comments

Comments
 (0)