diff --git a/lectures/jax_nn.md b/lectures/jax_nn.md
index 389092a..58658c7 100644
--- a/lectures/jax_nn.md
+++ b/lectures/jax_nn.md
@@ -40,7 +40,6 @@ The lecture proceeds in three stages:
 We begin with imports and installs.
 
 ```{code-cell} ipython3
-import numpy as np
 import jax
 import jax.numpy as jnp
 import matplotlib.pyplot as plt
@@ -48,7 +47,6 @@ import os
 from time import time
 from typing import NamedTuple
 from functools import partial
-
 ```
 
 ```{code-cell} ipython3
@@ -78,7 +76,7 @@ from keras.layers import Dense
 import optax
 ```
 
-## Set Up
+## Set up
 
 Here we briefly describe the problem and generate synthetic data.
 
@@ -103,9 +101,8 @@ Our default value of $k$ will be 10.
 
 ```{code-cell} ipython3
 class Config(NamedTuple):
     epochs: int = 4000            # Number of passes through the data set
-    output_dim: int = 10          # Output dimension of input and hidden layers
     learning_rate: float = 0.001  # Learning rate for gradient descent
-    layer_sizes: tuple = (1, 10, 10, 10, 1)  # Sizes of each layer in the network
+    layer_sizes: tuple = (1, 10, 10, 10, 1)  # Layer sizes
     seed: int = 14                # Random seed for data generation
 ```
 
@@ -140,9 +137,10 @@ key = jax.random.PRNGKey(config.seed)
 key_train, key_validate = jax.random.split(key)
 x_train, y_train = generate_data(key_train)
 x_validate, y_validate = generate_data(key_validate)
+keras.utils.set_random_seed(config.seed)
 fig, ax = plt.subplots()
 ax.scatter(x_train, y_train, alpha=0.5)
-ax.scatter(x_validate, y_validate, color='red', alpha=0.5)
+ax.scatter(x_validate, y_validate, alpha=0.5)
 ax.set_xlabel('x')
 ax.set_ylabel('y')
 plt.show()
@@ -165,16 +163,31 @@ def build_keras_model(
         activation_function: str = 'tanh'  # activation with default
     ):
     model = Sequential()
+    layer_sizes = config.layer_sizes
+
     # Add layers to the network sequentially, from inputs towards outputs
-    for i in range(len(config.layer_sizes) - 1):
+    for out_dim in layer_sizes[1:-1]:
         model.add(
-            Dense(units=config.output_dim, activation=activation_function)
+            Dense(
+                units=out_dim,
+                activation=activation_function,
+                kernel_initializer=keras.initializers.HeNormal(),
+                bias_initializer='ones',
+            )
         )
     # Add a final layer that maps to a scalar value, for regression.
-    model.add(Dense(units=1))
+    model.add(
+        Dense(
+            units=layer_sizes[-1],
+            kernel_initializer=keras.initializers.HeNormal(),
+            bias_initializer='ones',
+        )
+    )
     # Embed training configurations
     model.compile(
-        optimizer=keras.optimizers.SGD(),
+        optimizer=keras.optimizers.SGD(
+            learning_rate=config.learning_rate
+        ),
        loss='mean_squared_error'
     )
     return model
@@ -215,7 +228,7 @@ The next function extracts and visualizes a prediction from the trained model.
 def plot_keras_output(model, x, y, x_validate, y_validate):
     y_predict = model.predict(x_validate, verbose=2)
     fig, ax = plt.subplots()
-    ax.scatter(x_validate, y_validate, color='red', alpha=0.5)
+    ax.scatter(x_validate, y_validate, alpha=0.5)
     ax.plot(x_validate, y_predict, label="fitted model", color='black')
     ax.set_xlabel('x')
     ax.set_ylabel('y')
@@ -293,7 +306,6 @@ def initialize_layer(in_dim, out_dim, key):
     """
-    Initialize weights and biases for a single layer of a the network.
+    Initialize weights and biases for a single layer of the network.
     Use He initialization for weights and ones for biases.
-
     """
     W = jax.random.normal(key, shape=(in_dim, out_dim)) * jnp.sqrt(2 / in_dim)
     b = jnp.ones((1, out_dim))
@@ -312,7 +324,6 @@ def initialize_network(
     Build a network by initializing all of the parameters.  A network is a list
     of LayerParams instances, each containing a weight-bias pair (W, b).
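+
+    For example, with the default layer_sizes = (1, 10, 10, 10, 1), the
+    network is a list of four LayerParams pairs whose weight matrices
+    have shapes (1, 10), (10, 10), (10, 10) and (10, 1).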
- """ layer_sizes = config.layer_sizes params = [] @@ -329,7 +340,7 @@ def initialize_network( Wait, you say! -Shouldn’t we concatenate the elements of $ \theta $ into some kind of big array, so that we can do autodiff with respect to this array? +Shouldn’t we concatenate the elements of $\theta$ into some kind of big array, so that we can do autodiff with respect to this array? Actually we don’t need to --- we use the JAX PyTree approach discussed below. @@ -354,7 +365,7 @@ def f( return x ``` -The function $ f $ is appropriately vectorized, so that we can pass in the entire +The function $f$ is appropriately vectorized, so that we can pass in the entire set of input observations as `x` and return the predicted vector of outputs `y_hat = f(θ, x)` corresponding to each data point. @@ -372,21 +383,21 @@ def loss_fn( We’ll use its gradient to do stochastic gradient descent. (Technically, we will be doing gradient descent, rather than stochastic -gradient descent, since will not randomize over sample points when we +gradient descent, since we will not randomize over sample points when we evaluate the gradient.) ```{code-cell} ipython3 -loss_gradient = jax.jit(jax.grad(loss_fn)) +loss_gradient = jax.grad(loss_fn) ``` The gradient of `loss_fn` is with respect to the first argument `θ`. The code above seems kind of magical, since we are differentiating with respect -to a parameter “vector” stored as a list of dictionaries containing arrays. +to a parameter “vector” stored as a list of namedtuples containing arrays. How can we differentiate with respect to such a complex object? -The answer is that the list of dictionaries is treated internally as a +The answer is that the list of namedtuples is treated internally as a [pytree](https://docs.jax.dev/en/latest/pytrees.html). The JAX function `grad` understands how to @@ -395,7 +406,6 @@ The JAX function `grad` understands how to 1. compute derivatives with respect to each one, and 1. pack the resulting derivatives into a pytree with the same structure as the parameter vector. -+++ ### Gradient descent @@ -416,6 +426,7 @@ def update_parameters( """ λ = config.learning_rate + # Specify the update rule def gradient_descent_step(p, g): """ @@ -423,7 +434,9 @@ def update_parameters( It will be applied to each leaf of the pytree of parameters. """ return p - λ * g + gradient = loss_gradient(θ, x, y) + # Use tree.map to apply the update rule to the parameter vectors θ_new = jax.tree.map(gradient_descent_step, θ, gradient) return θ_new @@ -448,7 +461,6 @@ def train_jax_model( ): """ Train model using gradient descent. - """ def update(_, θ): θ_new = update_parameters(θ, x, y, config) @@ -492,7 +504,7 @@ Here's a visualization of the quality of our fit. 
 ```{code-cell} ipython3
 fig, ax = plt.subplots()
-ax.scatter(x_validate, y_validate, color='red', alpha=0.5)
+ax.scatter(x_validate, y_validate, alpha=0.5)
 ax.plot(x_validate.flatten(), f(θ, x_validate).flatten(),
         label="fitted model", color='black')
 ax.set_xlabel('x')
@@ -534,7 +546,8 @@ def train_jax_optax(
         return new_loop_state
 
     initial_loop_state = θ, opt_state
-    final_loop_state = jax.lax.fori_loop(0, epochs, update, initial_loop_state)
+    final_loop_state = jax.lax.fori_loop(
+        0, epochs, update, initial_loop_state
+    )
     θ_final, _ = final_loop_state
     return θ_final
 ```
@@ -557,13 +570,16 @@ optax_sgd_runtime = time() - start_time
 optax_sgd_mse = loss_fn(θ, x_validate, y_validate)
 optax_sgd_train_mse = loss_fn(θ, x_train, y_train)
 
-print(f"Trained model with JAX and Optax SGD in {optax_sgd_runtime:.2f} seconds.")
+print(
+    "Trained model with JAX and Optax SGD "
+    f"in {optax_sgd_runtime:.2f} seconds."
+)
 print(f"Final MSE on validation data = {optax_sgd_mse:.6f}")
 ```
 
 ```{code-cell} ipython3
 fig, ax = plt.subplots()
-ax.scatter(x_validate, y_validate, color='red', alpha=0.5)
+ax.scatter(x_validate, y_validate, alpha=0.5)
 ax.plot(x_validate.flatten(), f(θ, x_validate).flatten(),
         label="fitted model", color='black')
 ax.set_xlabel('x')
@@ -601,11 +617,11 @@ def train_jax_optax_adam(
         return (θ_new, new_opt_state)
 
     initial_loop_state = θ, opt_state
-    θ_final, _ = jax.lax.fori_loop(0, epochs, update, initial_loop_state)
+    θ_final, _ = jax.lax.fori_loop(
+        0, epochs, update, initial_loop_state
+    )
     return θ_final
 ```
 
-
 ```{code-cell} ipython3
 # Reset parameter vector
 θ = initialize_network(param_key, config)
@@ -622,7 +638,8 @@ optax_adam_runtime = time() - start_time
 optax_adam_mse = loss_fn(θ, x_validate, y_validate)
 optax_adam_train_mse = loss_fn(θ, x_train, y_train)
 
-print(f"Trained model with JAX and Optax ADAM in {optax_adam_runtime:.2f} seconds.")
+print("Trained model with JAX and Optax ADAM "
+      f"in {optax_adam_runtime:.2f} seconds.")
 print(f"Final MSE on validation data = {optax_adam_mse:.6f}")
 ```
 
@@ -630,7 +647,7 @@ Here's a visualization of the result.
 
 ```{code-cell} ipython3
 fig, ax = plt.subplots()
-ax.scatter(x_validate, y_validate, color='red', alpha=0.5)
+ax.scatter(x_validate, y_validate, alpha=0.5)
 ax.plot(x_validate.flatten(), f(θ, x_validate).flatten(),
         label="fitted model", color='black')
 ax.set_xlabel('x')
@@ -648,14 +665,8 @@ Here we compare the performance of the four different training approaches we exp
 import pandas as pd
 
 # Compute training MSEs for each method
-# Need to retrieve the trained models and compute training MSE
-# For Keras, we already have the model
 keras_train_mse = model.evaluate(x_train, y_train, verbose=0)
 
-# For JAX methods, we need to compute using loss_fn with the final θ from each method
-# We need to re-train or save the θ from each method
-# For now, let's add these calculations after each training section
-
 # Create summary table
 results = {
     'Method': [
@@ -692,8 +703,7 @@ print("\nSummary of Training Methods:")
 print(df.to_string(index=False))
 ```
 
-
-All methods achieve similar validation MSE values (around 0.043-0.045).
+All methods achieve similar validation MSE values (around 0.040-0.046).
 
 At the time of writing, the MSEs from plain vanilla Optax and our own
 hand-coded SGD routine are identical.
@@ -713,22 +723,21 @@ Not surprisingly, Keras has more overhead from its abstraction layers.
 
 ```{exercise}
 :label: jax_nn_ex1
 
-Try to reduce the MSE on the validation data without significantly increasing the computational load.
+Try to reduce the MSE on the validation data without significantly increasing
+the computational load.
 
-You should hold constant both the number of epochs and the total number of parameters in the network.
+You should hold constant both the number of epochs and the total number of
+parameters in the network.
 
-Currently, the network has 4 layers with output dimension $k=10$, giving a total of:
-- Layer 0: $1 \times 10 + 10 = 20$ parameters (weights + biases)
-- Layer 1: $10 \times 10 + 10 = 110$ parameters
-- Layer 2: $10 \times 10 + 10 = 110$ parameters
-- Layer 3: $10 \times 1 + 1 = 11$ parameters
-- Total: $251$ parameters
+Currently, the network has 4 layers with output dimension $k=10$, giving a total
+of $251$ parameters.
 
 You can experiment with:
 - Changing the network architecture
-- Trying different activation functions (e.g., `jax.nn.relu`, `jax.nn.gelu`, `jax.nn.sigmoid`, `jax.nn.elu`)
-- Modifying the optimizer (e.g., different learning rates, learning rate schedules, momentum, other Optax optimizers)
+- Trying different activation functions
+- Modifying the optimizer (e.g., learning rates, learning rate schedules, momentum, etc.)
 - Experimenting with different weight initialization strategies
+- Modifying the loss function (e.g., adding regularization)
 
 Which combination gives you the lowest validation MSE?
 
@@ -741,116 +750,42 @@ Which combination gives you the lowest validation MSE?
 
 Let's implement and test several strategies.
 
-**Strategy 1: Deeper Network Architecture**
-
-Let's try a deeper network with 6 layers instead of 4, keeping total parameters ≤ 251:
-
-```{code-cell} ipython3
-# Strategy 1: Deeper network (6 layers with k=6)
-# Layer sizes: 1→6→6→6→6→6→1
-# Parameters: (1×6+6) + 4×(6×6+6) + (6×1+1) = 12 + 4×42 + 7 = 187 < 251
-θ = initialize_network(param_key, config)
-
-def initialize_deep_params(
-    key: jax.Array,      # JAX random key
-    k: int = 6,          # Layer width
-    num_hidden: int = 5  # Number of hidden layers
-    ):
-    " Initialize parameters for deeper network with k=6. "
-    layer_sizes = tuple([1] + [k] * num_hidden + [1])
-    config_deep = Config(layer_sizes=layer_sizes)
-    return initialize_network(key, config_deep)
-
-θ_deep = initialize_deep_params(param_key)
-config_deep = Config(layer_sizes=(1, 6, 6, 6, 6, 6, 1))
-
-# Warmup
-train_jax_optax_adam(θ_deep, x_train, y_train, config_deep)
-
-# Actual run
-θ_deep = initialize_deep_params(param_key)
-start_time = time()
-θ_deep = train_jax_optax_adam(θ_deep, x_train, y_train, config_deep)
-θ_deep[0].W.block_until_ready()
-deep_runtime = time() - start_time
-
-deep_mse = loss_fn(θ_deep, x_validate, y_validate)
-print(f"Strategy 1 - Deeper network (6 layers, k=6)")
-print(f"  Total parameters: 187")
-print(f"  Runtime: {deep_runtime:.2f}s")
-print(f"  Validation MSE: {deep_mse:.6f}")
-print(f"  Improvement over ADAM: {optax_adam_mse - deep_mse:.6f}")
-```
-
-**Strategy 2: Deeper Network + Learning Rate Schedule**
+**Strategy 1: LR schedule and L2 regularization**
 
-Since the deeper network performed best, let's combine it with the learning rate schedule:
+Let's keep the baseline network architecture and add a learning rate schedule with L2 regularization:
 
 ```{code-cell} ipython3
-# Strategy 2: Deeper network + LR schedule
-θ_deep = initialize_deep_params(param_key)
+# Strategy 1: LR schedule and L2 regularization
 
-def train_deep_with_schedule(
+def loss_fn_l2(
     θ: list,
     x: jnp.ndarray,
     y: jnp.ndarray,
-    config: Config
+    λ_l2: float
     ):
-    " Train deeper network with learning rate schedule. "
-    epochs = config.epochs
-    schedule = optax.exponential_decay(
-        init_value=0.003,
-        transition_steps=1000,
-        decay_rate=0.5
-    )
-
-    solver = optax.adam(schedule)
-    opt_state = solver.init(θ)
-
-    def update(_, loop_state):
-        θ, opt_state = loop_state
-        grad = loss_gradient(θ, x, y)
-        updates, new_opt_state = solver.update(grad, opt_state, θ)
-        θ_new = optax.apply_updates(θ, updates)
-        return (θ_new, new_opt_state)
-
-    initial_loop_state = θ, opt_state
-    θ_final, _ = jax.lax.fori_loop(0, epochs, update, initial_loop_state)
-    return θ_final
-
-# Warmup
-train_deep_with_schedule(θ_deep, x_train, y_train, config_deep)
+    " L2-regularized MSE loss. "
+    mse = jnp.mean((f(θ, x) - y)**2)
+    l2_penalty = 0.0
+    # Penalize weights only, not biases
+    for W, _ in θ:
+        l2_penalty += jnp.sum(W**2)
+    return mse + λ_l2 * l2_penalty
 
-# Actual run
-θ_deep = initialize_deep_params(param_key)
-start_time = time()
-θ_deep_schedule = train_deep_with_schedule(θ_deep, x_train, y_train, config_deep)
-θ_deep_schedule[0].W.block_until_ready()
-deep_schedule_runtime = time() - start_time
-
-deep_schedule_mse = loss_fn(θ_deep_schedule, x_validate, y_validate)
-print(f"Strategy 2 - Deeper network + LR schedule")
-print(f"  Runtime: {deep_schedule_runtime:.2f}s")
-print(f"  Validation MSE: {deep_schedule_mse:.6f}")
-print(f"  Improvement over ADAM: {optax_adam_mse - deep_schedule_mse:.6f}")
-```
+
+loss_gradient_l2 = jax.grad(loss_fn_l2)
 
-**Strategy 3: Deeper Network + LR Schedule + L2 Regularization**
-
-Let's add L2 regularization (similar to ridge regression) to penalize complexity:
-
-```{code-cell} ipython3
-# Strategy 3: Deeper network + LR schedule + L2 regularization
-θ_deep = initialize_deep_params(param_key)
-
-def train_deep_with_schedule_and_l2(
+@partial(jax.jit, static_argnames=['config'])
+def train_with_schedule_and_l2(
     θ: list,
     x: jnp.ndarray,
     y: jnp.ndarray,
     config: Config,
-    lambda_l2: float = 0.001
+    λ_l2: float = 0.001
     ):
-    " Train deeper network with learning rate schedule and L2 regularization. "
+    """
+    Train baseline network with learning rate schedule
+    and L2 regularization.
+ """ epochs = config.epochs schedule = optax.exponential_decay( init_value=0.003, @@ -858,111 +793,166 @@ def train_deep_with_schedule_and_l2( decay_rate=0.5 ) - # Define regularized loss function - @jax.jit - def loss_fn_l2(θ, x, y): - # Standard MSE loss - mse = jnp.mean((f(θ, x) - y)**2) - # L2 penalty on weights (not biases) - l2_penalty = 0.0 - for W, b in θ: - l2_penalty += jnp.sum(W**2) - return mse + lambda_l2 * l2_penalty - - loss_gradient_l2 = jax.jit(jax.grad(loss_fn_l2)) - solver = optax.adam(schedule) opt_state = solver.init(θ) def update(_, loop_state): - θ, opt_state = loop_state - grad = loss_gradient_l2(θ, x, y) - updates, new_opt_state = solver.update(grad, opt_state, θ) - θ_new = optax.apply_updates(θ, updates) - return (θ_new, new_opt_state) + θ_curr, opt_state_curr = loop_state + grad = loss_gradient_l2(θ_curr, x, y, λ_l2) + updates, opt_state_new = solver.update( + grad, opt_state_curr, θ_curr) + θ_new = optax.apply_updates(θ_curr, updates) + return (θ_new, opt_state_new) initial_loop_state = θ, opt_state - θ_final, _ = jax.lax.fori_loop(0, epochs, update, initial_loop_state) + θ_final, _ = jax.lax.fori_loop( + 0, epochs, update, initial_loop_state + ) return θ_final # Warmup -train_deep_with_schedule_and_l2(θ_deep, x_train, y_train, config_deep) +θ_l2 = initialize_network(param_key, config) +train_with_schedule_and_l2(θ_l2, x_train, y_train, config) # Actual run -θ_deep = initialize_deep_params(param_key) +θ_l2 = initialize_network(param_key, config) start_time = time() -θ_deep_l2 = train_deep_with_schedule_and_l2(θ_deep, x_train, y_train, config_deep) -θ_deep_l2[0].W.block_until_ready() +θ_l2 = train_with_schedule_and_l2(θ_l2, x_train, y_train, config) +θ_l2[0].W.block_until_ready() deep_l2_runtime = time() - start_time -deep_l2_mse = loss_fn(θ_deep_l2, x_validate, y_validate) -print(f"Strategy 3 - Deeper network + LR schedule + L2 regularization") +deep_l2_mse = loss_fn(θ_l2, x_validate, y_validate) +print(f"Strategy 1 - LR schedule and L2 regularization") print(f" Runtime: {deep_l2_runtime:.2f}s") print(f" Validation MSE: {deep_l2_mse:.6f}") print(f" Improvement over ADAM: {optax_adam_mse - deep_l2_mse:.6f}") ``` -**Strategy 4: Baseline + L2 Regularization** +**Strategy 2: Baseline + Armijo Line Search** -Let's see if L2 regularization helps the baseline architecture: +Let's implement gradient descent with [Armijo line search](https://en.wikipedia.org/wiki/Backtracking_line_search) for adaptive step size selection: ```{code-cell} ipython3 -# Strategy 4: Baseline architecture + L2 regularization -θ = initialize_network(param_key, config) +# Strategy 2: Baseline architecture + Armijo line search +# Line search parameters +line_search_init_value = 0.01 +line_search_backtrack_factor = 0.5 +line_search_armijo_constant = 0.001 +max_backtrack_steps = 20 -def train_baseline_with_l2( - θ: list, - x: jnp.ndarray, - y: jnp.ndarray, - config: Config, - lambda_l2: float = 0.001 +@partial(jax.jit, static_argnames=['config']) +def train_jax_armijo_ls( + θ: list, # Initial parameters (pytree) + x: jnp.ndarray, # Training input data + y: jnp.ndarray, # Training target data + config: Config # contains configuration data ): - " Train baseline model with L2 regularization. " + """ + Train model using gradient descent with Armijo line search. + + The Armijo line search adaptively finds a suitable step size at each + iteration by ensuring sufficient decrease in the loss function. 
+ """ epochs = config.epochs - learning_rate = config.learning_rate - # Define regularized loss function - @jax.jit - def loss_fn_l2(θ, x, y): - # Standard MSE loss - mse = jnp.mean((f(θ, x) - y)**2) - # L2 penalty on weights (not biases) - l2_penalty = 0.0 - for W, b in θ: - l2_penalty += jnp.sum(W**2) - return mse + lambda_l2 * l2_penalty + # Line search parameters + α_init = line_search_init_value + backtrack_factor = line_search_backtrack_factor + armijo_c = line_search_armijo_constant - loss_gradient_l2 = jax.jit(jax.grad(loss_fn_l2)) + def update_step(θ_current, x_data, y_data): + current_loss = loss_fn(θ_current, x_data, y_data) + g = loss_gradient(θ_current, x_data, y_data) - solver = optax.adam(learning_rate) - opt_state = solver.init(θ) + # Squared Euclidean norm of gradient for Armijo condition + g_norm_sq = jax.tree.reduce( + lambda a, b: a + jnp.sum(b**2), + g, + initializer=0.0, + ) - def update(_, loop_state): - θ, opt_state = loop_state - grad = loss_gradient_l2(θ, x, y) - updates, new_opt_state = solver.update(grad, opt_state, θ) - θ_new = optax.apply_updates(θ, updates) - return (θ_new, new_opt_state) + # Define the condition for the while_loop + def cond_fn(loop_args): + (α_val, + current_loss_val, + g_sq_sum, + θ_orig, + x_in, + y_in, + step_count) = loop_args + + loss_threshold = ( + current_loss_val - armijo_c * α_val * g_sq_sum + ) + + θ_candidate = jax.tree.map( + lambda p, g_leaf: p - α_val * g_leaf, + θ_orig, + g, + ) + loss_candidate = loss_fn(θ_candidate, x_in, y_in) + return ((loss_candidate > loss_threshold) + & (step_count < max_backtrack_steps)) + + # Define the body for the while_loop + def body_fn(loop_args): + (α_val, + current_loss_val, + g_sq_sum, + θ_orig, + x_in, + y_in, + step_count) = loop_args + α_new = α_val * backtrack_factor + step_count_new = step_count + 1 + return (α_new, + current_loss_val, + g_sq_sum, + θ_orig, + x_in, + y_in, + step_count_new) + + # Execute the Armijo line search using jax.lax.while_loop + loop_state = jax.lax.while_loop( + cond_fn, + body_fn, + (α_init, current_loss, + g_norm_sq, θ_current, x_data, y_data, 0), + ) + α_final = loop_state[0] - initial_loop_state = θ, opt_state - θ_final, _ = jax.lax.fori_loop(0, epochs, update, initial_loop_state) + # Update parameters with the chosen step size + θ_new = jax.tree.map( + lambda p, g_leaf: p - α_final * g_leaf, + θ_current, + g, + ) + return θ_new + + # Main training loop (epochs) + def outer_update(_, θ_curr): + return update_step(θ_curr, x, y) + + θ_final = jax.lax.fori_loop(0, epochs, outer_update, θ) return θ_final # Warmup -train_baseline_with_l2(θ, x_train, y_train, config) +θ = initialize_network(param_key, config) +train_jax_armijo_ls(θ, x_train, y_train, config) # Actual run θ = initialize_network(param_key, config) start_time = time() -θ_baseline_l2 = train_baseline_with_l2(θ, x_train, y_train, config) -θ_baseline_l2[0].W.block_until_ready() -baseline_l2_runtime = time() - start_time - -baseline_l2_mse = loss_fn(θ_baseline_l2, x_validate, y_validate) -print(f"Strategy 4 - Baseline + L2 regularization") -print(f" Runtime: {baseline_l2_runtime:.2f}s") -print(f" Validation MSE: {baseline_l2_mse:.6f}") -print(f" Improvement over ADAM: {optax_adam_mse - baseline_l2_mse:.6f}") +θ_armijo = train_jax_armijo_ls(θ, x_train, y_train, config) +θ_armijo[0].W.block_until_ready() +armijo_runtime = time() - start_time + +armijo_mse = loss_fn(θ_armijo, x_validate, y_validate) +print(f"Strategy 2 - Baseline + Armijo Line Search") +print(f" Runtime: {armijo_runtime:.2f}s") +print(f" 
Validation MSE: {armijo_mse:.6f}") +print(f" Improvement over ADAM: {optax_adam_mse - armijo_mse:.6f}") ``` **Results Summary** @@ -976,31 +966,23 @@ Let's compare all strategies: strategies_results = { 'Strategy': [ 'Baseline (ADAM + tanh)', - '1. Deeper network (6 layers)', - '2. Deeper network + LR schedule', - '3. Strategy 2 + L2 regularization', - '4. Baseline + L2 regularization' + '1. LR schedule and L2 regularization', + '2. Baseline + Armijo Line Search' ], 'Runtime (s)': [ optax_adam_runtime, - deep_runtime, - deep_schedule_runtime, deep_l2_runtime, - baseline_l2_runtime + armijo_runtime ], 'Validation MSE': [ optax_adam_mse, - deep_mse, - deep_schedule_mse, deep_l2_mse, - baseline_l2_mse + armijo_mse ], 'Improvement': [ 0.0, - float(optax_adam_mse - deep_mse), - float(optax_adam_mse - deep_schedule_mse), float(optax_adam_mse - deep_l2_mse), - float(optax_adam_mse - baseline_l2_mse) + float(optax_adam_mse - armijo_mse) ] } @@ -1009,23 +991,17 @@ print("\nSummary of Exercise Strategies:") print(df_strategies.to_string(index=False)) ``` +In terms of reducing loss on the validation data, the current winner is the +Armijo line search strategy. -The experimental results reveal several lessons: - -1. Architecture matters: A deeper, narrower network outperformed the - baseline network, despite using fewer parameters (187 vs 251). - -2. Combining strategies: Combining the deeper architecture with a learning - rate schedule showed that synergistic improvements are possible. - -3. Regularization helps: Adding L2 regularization (ridge penalty) can - improve performance by penalizing model complexity and reducing overfitting. - -4. Regularization vs architecture: Comparing strategies 3 and 4 shows whether - regularization is more effective with deeper architectures or simpler ones. +The Armijo backtracking line search is an adaptive step size method that +dynamically adjusts the learning rate at each iteration to ensure sufficient +decrease in the loss function. +Unlike fixed learning rates or predetermined schedules, it adapts to the local +geometry of the loss landscape. +This strategy and its code were contributed by [Matyas Farkas](https://www.matyasfarkas.eu/). ```{solution-end} ``` -