
Conversation

@jstac
Contributor

@jstac jstac commented Dec 7, 2025

Summary

This PR streamlines the exercise section of the JAX neural network lecture by removing underperforming strategies and highlighting the most effective approaches.

Changes

  • Removed 3 strategies with negative improvements:

    • Strategy 1: Deeper network (6 layers) - improvement: -0.000829
    • Strategy 2: Deeper network + LR schedule - improvement: -0.000764
    • Strategy 4: Baseline + L2 regularization - improvement: -0.000090
  • Kept and renumbered the 2 best-performing strategies:

    • Strategy 1: Deeper network + LR schedule + L2 regularization (improvement: +0.000024)
    • Strategy 2: Baseline + Armijo line search (improvement: +0.000058, best performer)
  • Added Armijo line search implementation: Complete implementation of gradient descent with backtracking line search for adaptive step size selection (a simplified sketch follows this list)

  • Added technical explanation: Detailed explanation of how Armijo backtracking works and why it performs well

  • Attribution: Credited Matyas Farkas for contributing the winning Armijo line search strategy
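
For readers skimming the diff, here is a minimal, self-contained sketch of the idea behind the winning strategy (not the lecture's actual implementation; loss_fn, loss_gradient, and the hyperparameter defaults below are placeholders):

```python
import jax
import jax.numpy as jnp

def armijo_update(θ, x, y, loss_fn, loss_gradient,
                  η_init=1.0, β=0.5, c=1e-4, max_backtracks=30):
    """
    One gradient-descent step with Armijo backtracking line search.

    Starting from a large trial step η_init, shrink η by the factor β until the
    sufficient-decrease condition

        loss(θ - η g) <= loss(θ) - c * η * ||g||^2

    holds, then take that step.
    """
    g = loss_gradient(θ, x, y)
    current_loss = loss_fn(θ, x, y)
    g_norm_sq = sum(jnp.vdot(leaf, leaf) for leaf in jax.tree.leaves(g))

    η = η_init
    for _ in range(max_backtracks):
        θ_trial = jax.tree.map(lambda p, dp: p - η * dp, θ, g)
        if loss_fn(θ_trial, x, y) <= current_loss - c * η * g_norm_sq:
            return θ_trial
        η = β * η
    return θ  # no acceptable step size found; leave parameters unchanged
```

A compiled version would typically replace the Python loop and conditional with jax.lax.while_loop; the plain-Python form above is only meant to show the logic.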

Results

The Armijo line search strategy achieved the best validation MSE (0.040810) with competitive runtime (0.41s), demonstrating that adaptive step size selection provides meaningful improvements for neural network training.

🤖 Generated with Claude Code

Removed three strategies with negative improvements and kept only the
two best-performing approaches:
- Strategy 1: Deeper network + LR schedule + L2 regularization
- Strategy 2: Baseline + Armijo line search (best validation MSE)

Added technical explanation of Armijo backtracking line search and
credited Matyas Farkas for contributing the winning strategy.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
@netlify

netlify bot commented Dec 7, 2025

Deploy Preview for incomparable-parfait-2417f8 ready!

| Name | Link |
|------|------|
| 🔨 Latest commit | d395036 |
| 🔍 Latest deploy log | https://app.netlify.com/projects/incomparable-parfait-2417f8/deploys/6936ac03d980ef0008612539 |
| 😎 Deploy Preview | https://deploy-preview-268--incomparable-parfait-2417f8.netlify.app |

To edit notification comments on pull requests, go to your Netlify project configuration.

@github-actions

github-actions bot commented Dec 7, 2025

@github-actions github-actions bot temporarily deployed to pull request December 7, 2025 21:50 Inactive
@github-actions github-actions bot temporarily deployed to pull request December 7, 2025 21:54 Inactive
@mmcky
Contributor

mmcky commented Dec 8, 2025

  • @mmcky to do a full proof-read and style check
  • @HumphreyYang you may be interested in this new update. Thoughts are most welcome, let me know if you see any style issues etc. (@jstac mentioned an 80-character line violation).

@HumphreyYang I am just working on a few things at the moment, but I hope to take a look at this later this afternoon and get this merged.

Copilot AI left a comment

Pull request overview

This PR streamlines the neural network exercise solution section by removing underperforming optimization strategies and highlighting the two best-performing approaches: a deeper network with learning rate scheduling and L2 regularization, and a baseline network with Armijo backtracking line search.

Key Changes:

  • Removed three optimization strategies with negative performance improvements (-0.000829, -0.000764, -0.000090)
  • Retained and renumbered two best-performing strategies (improvements: +0.000024 and +0.000058)
  • Added complete Armijo backtracking line search implementation with adaptive step size selection



4. Regularization vs architecture: Comparing strategies 3 and 4 shows whether
regularization is more effective with deeper architectures or simpler ones.
This strategy and its code was contributed by [Matyas Farkas](https://www.matyasfarkas.eu/).
Copilot AI Dec 8, 2025

Subject-verb agreement error: "This strategy and its code" is a plural subject requiring the plural verb "were" instead of "was".

Suggested change
This strategy and its code was contributed by [Matyas Farkas](https://www.matyasfarkas.eu/).
This strategy and its code were contributed by [Matyas Farkas](https://www.matyasfarkas.eu/).

@HumphreyYang
Member

HumphreyYang commented Dec 8, 2025

Many thanks @jstac and @mmcky — the exercise looks great! It reminds me of the convex optimization classes I took.

I read through the lecture carefully and noted a few small questions and suggestions that might help:

  • ## Set Up -> ## Set up
  • Would removing color='red' better follow the stylesheet color convention?
  • The Keras model actually has four hidden layers:

$$ \mathbb R \to \mathbb R^k \to \mathbb R^k \to \mathbb R^k \to \mathbb R^k \to \mathbb R $$

based on the output:

Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ dense (Dense)                   │ (400, 10)              │            20 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_1 (Dense)                 │ (400, 10)              │           110 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_2 (Dense)                 │ (400, 10)              │           110 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_3 (Dense)                 │ (400, 10)              │           110 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_4 (Dense)                 │ (400, 1)               │            11 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 363 (1.42 KB)
 Trainable params: 361 (1.41 KB)
 Non-trainable params: 0 (0.00 B)
 Optimizer params: 2 (8.00 B)

with $361$ parameters instead of $251$ (which is $361 - 110$ after removing one of the hidden layers!)

To align with the JAX version, we could change
for i in range(len(config.layer_sizes) - 1):
to
for i in range(len(config.layer_sizes) - 2):
since the output layer is added separately via model.add(Dense(units=1)).

Also, the Keras model does not use the dimensions in layer_sizes or learning_rate (the optimizer uses its default learning rate). This means changing Config will not affect the Keras model (though we did not do that in the lecture!)

  • In the function
@partial(jax.jit, static_argnames=['config'])
def train_jax_model(
        θ: list,                    # Initial parameters (pytree)
        x: jnp.ndarray,             # Training input data
        y: jnp.ndarray,             # Training target data
        x_validate: jnp.ndarray,    # Validation input data
        y_validate: jnp.ndarray,    # Validation target data
        config: Config              # contains configuration data
    ):
    """
    Train model using gradient descent.

    """
    def update(_, θ):
        θ_new = update_parameters(θ, x, y, config)
        return θ_new

    θ_final = jax.lax.fori_loop(0, config.epochs, update, θ)
    return θ_final

it looks like x_validate and y_validate are unused and could be removed.

  • The sentence ... a list of dictionaries containing arrays. could be updated to ... a list of namedtuples containing arrays. since we are storing them in namedtuples.

  • loss_gradient is already jitted via loss_gradient = jax.jit(jax.grad(loss_fn)) and is later called inside other jitted functions train_jax_optax, train_jax_optax_adam, and train_jax_armijo_ls. To follow the "no nesting" rule, we could remove either the inner or the outer jax.jit.

  • I might be wrong on this, but in the Summary section, the following comments seem to be outdated:

# For JAX methods, we need to compute using loss_fn with the final θ from each method
# We need to re-train or save the θ from each method
# For now, let's add these calculations after each training section
  • In the exercise, it is noted that "You should hold constant both the number of epochs and the total number of parameters in the network", but the solution uses a deeper network that has $187 < 251$ parameters. Here is my calculation for cross-checking (a small counting sketch follows this list):

input: 1 -> 6 so we have $6 + 6 = 12$ parameters.
hidden layers: 6 -> 6 so each has $6 \times 6 + 6 = 42$ parameters and for 4 of them we have $168$ parameters
output layer: 6 -> 1 so we have $6 + 1 = 7$ parameters, giving $12 + 168 + 7 = 187$ parameters in total.

  • I think import numpy as np could be removed because it is not used.
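
To make the parameter arithmetic above easy to re-check, here is a small counting helper (a sketch only; the layer sizes below are the ones implied by my calculation, not values read from the lecture's Config):

```python
def count_params(layer_sizes):
    # Each dense transition n -> m contributes n*m weights plus m biases
    return sum(n * m + m for n, m in zip(layer_sizes[:-1], layer_sizes[1:]))

print(count_params([1, 10, 10, 10, 1]))     # 251 (baseline: three hidden layers of width 10)
print(count_params([1, 6, 6, 6, 6, 6, 1]))  # 187 (deeper network: five hidden layers of width 6)
```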

Please let me know if I’ve misunderstood anything — I can push a commit to update the lecture if these suggestions seem useful!

@mmcky
Contributor

mmcky commented Dec 8, 2025

@HumphreyYang I just merged #266 so I will check off the unused import comment.

@HumphreyYang
Member

HumphreyYang commented Dec 8, 2025

Many thanks @mmcky, I realized that the Colab button is not working because it is a preview link, so I just removed that item!

@mmcky
Contributor

mmcky commented Dec 8, 2025

> Many thanks @mmcky, I realized that the Colab button is not working because it is a preview link, so I just removed that item!

Thanks @HumphreyYang. I wonder where the preview build is getting master from? I'll investigate. The tricky part with Colab buttons on previews is that they pull from the .notebooks repo, so they are never representative (in terms of the notebook itself -- they will show you the current live notebook until the publish has finished and the .notebooks repo is updated).

@github-actions github-actions bot temporarily deployed to pull request December 8, 2025 04:42 Inactive
@HumphreyYang
Member

Hi @mmcky, it is coming from "Fetch for https://api.github.com/repos/QuantEcon/lecture-jax.notebooks/contents/lectures?per_page=100&ref=master failed".

@github-actions github-actions bot temporarily deployed to pull request December 8, 2025 04:46 Inactive
@mmcky
Contributor

mmcky commented Dec 8, 2025

  • rerun this with a new cache build @mmcky

@github-actions github-actions bot temporarily deployed to pull request December 8, 2025 05:17 Inactive
@mmcky
Contributor

mmcky commented Dec 8, 2025

  • @HumphreyYang the cache.yml was stale, so it looks like it was using an old cached JavaScript file that we fixed last week. The new run looks like it is working nicely.

@HumphreyYang
Member

HumphreyYang commented Dec 8, 2025

Many thanks @mmcky, the Colab button works perfectly now!

@jstac
Contributor Author

jstac commented Dec 8, 2025

Great review @HumphreyYang, many thanks!

Good catch regarding the Keras layers! We need to cut one, as you say.

Please check the runtime and MSE after the change, on a GPU build. The discussion of relative times and the table before the exercises might need to change.

As for the layer numbering convention, I follow this:

Some sources do not count the input layer, so 4 layers can also be correct

PS: Feel free to enter the competition to reduce the validation MSE!

@HumphreyYang
Member

HumphreyYang commented Dec 8, 2025

Many thanks @jstac for your confirmation! I’ll leave the layer count out.

Is there anything else I should keep constant, and do I have your permission to make changes for the other items on the list in this PR? I can open a new branch after this is merged as well.

I’ll join the race after this busy month of paper revisions!

@jstac
Contributor Author

jstac commented Dec 8, 2025

Thanks @HumphreyYang . All your comments are spot on and you are free to make those changes.

> I’ll join the race after this busy month of paper revisions!

Sure. Please don't feel the need to invest too much time in this --- I know you are busy.

@HumphreyYang
Member

Hi guys, I just pushed a commit to fix the listed items and the code line lengths.

I’ll update the discussions on MSE once I’m back from dinner and have checked the preview run!

@github-actions github-actions bot temporarily deployed to pull request December 8, 2025 07:51 Inactive
@github-actions github-actions bot temporarily deployed to pull request December 8, 2025 07:55 Inactive
@HumphreyYang
Member

HumphreyYang commented Dec 8, 2025

Hi @jstac,

I checked the MSE in the preview and the results are consistent with the previous discussions.

I made another commit to address some minor points:

  1. I removed
output_dim: int = 10           # Output dimension of input and hidden layers

since the layer output dimensions are controlled by layer_sizes, so it is not used in the code
2. jax.tree_util.tree_reduce $\to$ jax.tree.reduce.
3. I jitted train_with_schedule_and_l2 rather than only the inner function loss_fn_l2, so the warm-up run covers compilation and the subsequent timed run is fully compiled.
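
For anyone reading along, here is a minimal sketch of the pattern in item 3 (the loss, shapes, and hyperparameters below are toy placeholders, not the lecture's actual code):

```python
import jax
import jax.numpy as jnp

# Toy loss with an L2 penalty, just to keep the sketch self-contained
def loss_fn_l2(θ, x, y, λ=1e-4):
    pred = x @ θ
    return jnp.mean((pred - y) ** 2) + λ * jnp.sum(θ ** 2)

@jax.jit  # compile the whole training loop, not only the inner loss
def train_with_schedule_and_l2(θ, x, y):
    def step(i, θ):
        lr = 0.1 / (1.0 + 0.01 * i)   # simple decaying learning-rate schedule
        return θ - lr * jax.grad(loss_fn_l2)(θ, x, y)
    return jax.lax.fori_loop(0, 200, step, θ)

x, y, θ0 = jnp.ones((400, 3)), jnp.zeros(400), jnp.zeros(3)
_ = train_with_schedule_and_l2(θ0, x, y)        # warm-up call pays the compilation cost
θ_final = train_with_schedule_and_l2(θ0, x, y)  # timed run executes fully compiled code
```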

Please let me know if you spot any more improvements!

@github-actions github-actions bot temporarily deployed to pull request December 8, 2025 10:50 Inactive
@github-actions github-actions bot temporarily deployed to pull request December 8, 2025 10:55 Inactive
@jstac
Contributor Author

jstac commented Dec 8, 2025

Great work @HumphreyYang, many thanks! This all looks great.

@mmcky It would be helpful to have the nice lecture-specific deployment messages here as well (if it's not a big job, since we'll dismantle this series before too long).

@jstac
Contributor Author

jstac commented Dec 8, 2025

@mmcky Please review, merge and make live when you are ready.
