Feature/issue 275 gemma3 tpu v5e8 #283
Conversation
…evert unrelated README/docs changes
- Delete `examples/lookahead_usage.py` (redundant and not notebook-style)
- Revert unrelated note additions in `README.md` and `docs/development.md`
- Improve `examples/lookahead_mnist.ipynb` with a detailed explanation, initialization steps, an annotated training loop, and a summary usage pattern for the Lookahead optimizer
…sm notebook
- Implements Gemma 3 (270M) with the Keras 3 JAX backend
- Uses the Keras Distribution API for modern data parallelism
- Targets Kaggle TPU v5e-8 (8-core) for accessible multi-core training
- Replaces the outdated Flax/HuggingFace approach with a future-proof stack
- Includes comprehensive examples, benchmarks, and best practices
- Addresses deprecated legacy classes from the original notebook

Features:
- TPU mesh configuration and device detection
- Data parallel inference with performance comparison
- Batch size scaling experiments
- Advanced mesh topology examples
- Memory monitoring and troubleshooting guide
- Kaggle-specific optimizations and setup instructions
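The data-parallel pattern the notebook describes can be sketched at the JAX level (the Keras Distribution API builds on the same device-mesh machinery). This is a minimal illustration, not code from the notebook: `model_fn` and all shapes are hypothetical, and on a CPU host it simply runs on however many local devices exist.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a model forward pass; any per-shard function works.
def model_fn(x):
    return jnp.tanh(x @ jnp.ones((4, 2)))

n_dev = jax.local_device_count()   # 8 on a TPU v5e-8, 1 on a plain CPU host
batch = jnp.ones((n_dev, 16, 4))   # leading axis is the device axis
run = jax.pmap(model_fn)           # replicate the function across local devices
out = run(batch)                   # each device processes its own batch shard
print(out.shape)                   # (n_dev, 16, 2)
```

On a v5e-8 the leading axis is 8, so each core sees a per-device batch of 16; the same code runs unchanged on one device.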
Summary of Changes
Hello @Solventerritory, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request aims to introduce a new, modernized Jupyter notebook for data parallel inference with Gemma 3 on Kaggle's TPU v5e-8. It uses a contemporary stack, Keras 3 with JAX and the Keras Distribution API, to provide a scalable solution for multi-core TPU environments. The notebook is designed to replace outdated implementations and offers comprehensive guidance on TPU usage and optimization. However, the provided patch content does not align with this description: instead, it adds an Optax lookahead optimizer example notebook and an Optax subproject.
Code Review
This pull request adds a new example notebook for the Optax lookahead optimizer. However, the PR description seems to be for a different change related to Gemma 3 and TPUs, which is confusing.
My review of the new notebook lookahead_mnist.ipynb has found a few critical issues:
- The notebook will fail to run due to a shape mismatch error in the `loss_fn`.
- The dummy data is initialized in a way that results in zero loss and zero gradients, meaning the model will not learn and the optimizer's effect cannot be demonstrated.
- The core purpose of the notebook, demonstrating a bug and its fix, is not achieved: the code presented to 'reproduce the bug' is actually correct, and it is identical to the code in the 'fix' section.
I've left specific comments with suggestions to fix these issues. Addressing them will make the notebook functional and a valuable example for users.
| "\n", | ||
| "# Dummy data for demonstration\n", | ||
| "x = jnp.ones((32, 784))\n", | ||
| "y = jnp.zeros((32,), dtype=jnp.int32)\n", |
The dummy data setup with y as all zeros, combined with zero-initialized parameters, results in an initial loss of 0 and zero gradients. Consequently, the optimizer will not update the parameters, and the loss will not decrease. This prevents the notebook from demonstrating that the optimizer is working. To fix this, initialize y with non-zero values to ensure there is a non-zero loss and gradient at the start of training.
y = jnp.ones((32,), dtype=jnp.int32)
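The reviewer's claim is easy to verify in isolation: with zero-initialized weights and all-zero targets, both the MSE loss and its gradient vanish, so no optimizer can make progress. A small self-contained check, using a plain linear model as a stand-in for the notebook's model:

```python
import jax
import jax.numpy as jnp

# Zero-initialized weights and all-zero targets: logits are all zero,
# so the MSE loss and its gradient are exactly zero and nothing is learned.
W = jnp.zeros((784, 10))
x = jnp.ones((32, 784))
y = jnp.zeros((32,), dtype=jnp.int32)

def mse(W):
    logits = x @ W                                  # (32, 10), all zeros
    return jnp.mean((logits - y[:, None]) ** 2)

grad = jax.grad(mse)(W)
print(float(mse(W)), float(jnp.abs(grad).max()))    # 0.0 0.0
```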
| "\n", | ||
| "def loss_fn(params, x, y):\n", | ||
| " logits = model(params, x)\n", | ||
| " return jnp.mean((logits - y) ** 2)\n", |
The loss_fn will raise a ValueError because of a shape mismatch during subtraction. logits has a shape of (32, 10), while y has a shape of (32,). These shapes are not compatible for broadcasting. To fix this, you should reshape y to (32, 1) to make it a column vector, which can then be broadcast correctly across the logits matrix.
return jnp.mean((logits - y[:, None]) ** 2)
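NumPy-style broadcasting aligns trailing dimensions, so `(32, 10)` against `(32,)` compares 10 with 32 and fails, while reshaping to `(32, 1)` broadcasts across the class axis. A quick standalone check, with shapes taken from the review comment:

```python
import jax.numpy as jnp

logits = jnp.ones((32, 10))
y = jnp.ones((32,))

# Trailing dimensions 10 vs 32 are incompatible, so this subtraction raises.
try:
    _ = logits - y
    broadcast_ok = True
except (TypeError, ValueError):
    broadcast_ok = False

diff = logits - y[:, None]    # (32, 10) - (32, 1) broadcasts cleanly
print(broadcast_ok, diff.shape)
```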
| "\n", | ||
| "# Incorrect usage: not updating lookahead state properly in a loop\n", | ||
| "for step in range(5):\n", | ||
| " params, opt_state = update(params, opt_state, x, y)\n", |
This line correctly updates the opt_state, which contradicts the section's goal of demonstrating a bug. To properly illustrate the 'incorrect usage', you should simulate a common error, such as failing to update the optimizer state. For example, you could discard the new state returned from the update function.
params, _ = update(params, opt_state, x, y) # Bug: opt_state is not updated for the next iteration
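Putting the three fixes together, the corrected loop threads the optimizer state through every iteration. The sketch below substitutes a hand-rolled momentum optimizer and a hypothetical linear model for `optax.lookahead` and the notebook's model, purely to stay self-contained; the state-threading pattern is identical for any stateful optimizer.

```python
import jax
import jax.numpy as jnp

def init_opt(params):
    return jnp.zeros_like(params)             # momentum buffer as the state

def update(params, opt_state, x, y, lr=1e-4, mu=0.9):
    def loss_fn(p):
        logits = x @ p                        # hypothetical linear model
        return jnp.mean((logits - y[:, None]) ** 2)   # broadcastable target
    grads = jax.grad(loss_fn)(params)
    opt_state = mu * opt_state + grads        # state changes every step
    return params - lr * opt_state, opt_state

x = jnp.ones((32, 784))
y = jnp.ones((32,), dtype=jnp.int32)          # non-zero labels -> non-zero loss
params = jnp.zeros((784, 10))
opt_state = init_opt(params)

loss_before = float(jnp.mean((x @ params - y[:, None]) ** 2))
for step in range(5):
    # Re-bind BOTH params and opt_state each iteration.
    params, opt_state = update(params, opt_state, x, y)
loss_after = float(jnp.mean((x @ params - y[:, None]) ** 2))
print(loss_before, loss_after)
```

With the state threaded correctly and non-zero targets, the loss visibly decreases over the five steps.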
Your CL has an unrelated notebook.
This PR addresses issue #275 by adding a modernized notebook for data parallel inference with Gemma 3 on TPU v5e-8.
Changes
New Notebook:
`[Gemma_3]Data_Parallel_Inference_JAX_TPU_v5e8.ipynb`
This notebook replaces the outdated `[Gemma_1]data_parallel_inference_in_jax_tpu.ipynb` with a modern, future-proof implementation.
Key Improvements
Modern Stack
Accessible Hardware
Comprehensive Content
Why This Matters
The original notebook relied on:
This new notebook provides:
Testing
Related Issues
Fixes #275
Additional Notes
This is a new notebook addition, not a modification of the existing one. The old notebook can be deprecated separately to maintain backward compatibility.
Ready for testing on Kaggle TPU v5e-8 environment.
How to Test:
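As a first sanity check before running the notebook (an assumed step, not the PR author's official instructions), confirm that JAX sees the accelerator:

```python
import jax

devices = jax.devices()
print(len(devices), devices[0].platform)   # expect 8 and 'tpu' on a v5e-8 VM
```

On a plain CPU host this reports a single `cpu` device, which is still enough to smoke-test the notebook's code paths.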