
TUNIX: Post-Training for Reasoning

A high-performance post-training pipeline that scales logical inference, turning Gemma 3 1B into a specialized reasoning engine using Tunix, JAX, and Cloud TPUs (v3-8 / v5e-1)

Check out my related blog post 👉 Medium

πŸ“Œ Project Overview

This project focuses on the post-training phase of Large Language Models to enhance logical deduction and multi-step reasoning. By leveraging the JAX/Flax ecosystem, we achieve massive throughput on TPU v3-8 hardware, overcoming traditional bottlenecks associated with dynamic shapes in transformer training.

Key Objectives:

  • Reasoning Alignment: Transforming general-purpose knowledge into structured "Chain of Thought" (CoT) logic.
  • JAX Optimization: Solving the XLA "compilation hang" and memory fragmentation on TPUs.
  • Hardware Efficiency: Utilizing a TPU v3-8 mesh with LoRA (Low-Rank Adaptation) via Tunix.

πŸ› οΈ Technical Stack

  • Model: Google Gemma 3 (1B Variant)
  • Framework: Tunix (Post-training framework for JAX/Flax)
  • Compute: Google Cloud TPU v3-8
  • Optimization: Optax (AdamW with Cosine Schedule)
  • Data: Hugging Face Datasets (Multi-domain rebalanced)

πŸš€ Key Technical Hurdles & Solutions

1. The "Inhomogeneous Shape" Fix (Static Padding)

TPUs require fixed-size buffers. Standard tokenization creates "jagged" arrays that crash the XLA compiler.

  • Solution: Implemented a custom data collator that enforces a strict MAX_TARGET_LENGTH. This ensures every batch is a perfect rectangle, allowing JAX to compile the computation graph exactly once.
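
The collator described above can be sketched as follows; the `MAX_TARGET_LENGTH` value and `pad_id=0` are assumptions for illustration, not the project's actual settings.

```python
# Minimal sketch of a fixed-length collator: every batch becomes a
# perfect [batch, MAX_TARGET_LENGTH] rectangle, so XLA compiles once.
import numpy as np

MAX_TARGET_LENGTH = 512  # assumed cap; every sequence is padded/truncated to this

def collate(batch_token_ids, pad_id=0):
    """Pad or truncate each sequence to a fixed length."""
    out = np.full((len(batch_token_ids), MAX_TARGET_LENGTH), pad_id, dtype=np.int32)
    mask = np.zeros_like(out)
    for i, ids in enumerate(batch_token_ids):
        ids = ids[:MAX_TARGET_LENGTH]          # truncate overlong sequences
        out[i, :len(ids)] = ids                # left-aligned tokens
        mask[i, :len(ids)] = 1                 # 1 = real token, 0 = padding
    return {"input_ids": out, "attention_mask": mask}
```

Because the output shape never varies, JAX traces the computation graph at step 0 and reuses the compiled program for every later batch.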

2. Solving the "BTNS" Einsum Error

Gemma 3's Multi-Head Attention expects specific 4D tensor alignments.

  • Solution: Manually expanded 2D attention masks into a 4D broadcastable format [Batch, 1, 1, Sequence]. This aligned the dimensions for the Einstein Summation (einsum) kernels in the attention layers.
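
The mask expansion is a one-line reshape; this sketch uses NumPy for clarity and follows the `[Batch, 1, 1, Sequence]` convention from the text (exact Gemma 3 internals may differ).

```python
# Expand a 2D padding mask [Batch, Sequence] into the 4D broadcastable
# shape [Batch, 1, 1, Sequence] expected by the attention einsum kernels.
import numpy as np

def expand_mask(mask_2d):
    return mask_2d[:, None, None, :]

# The singleton axes broadcast against [Batch, Heads, Seq, Seq] scores:
scores = np.zeros((2, 4, 8, 8))                     # [B, H, S, S]
bias = (1 - expand_mask(np.ones((2, 8)))) * -1e9    # large negative on pad positions
masked_scores = scores + bias                       # shape stays [B, H, S, S]
```

The two size-1 axes line up with the head and query dimensions, which is what lets the einsum kernels consume the mask without a shape mismatch.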

3. JAX Data Tracing

JAX cannot trace Python strings. Residual metadata in datasets often causes _str_abstractify errors.

  • Solution: Developed a pre-processing pipeline that strips all non-numeric columns, leaving only the raw integer input_ids and attention_mask for the TPU.
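
The stripping step reduces to a column filter; this is a minimal sketch in plain Python, where the column names follow the text and `strip_metadata` is a hypothetical helper name.

```python
# Keep only the numeric columns JAX can trace; drop strings and metadata
# that would otherwise trigger abstractify errors on the TPU.
KEEP = {"input_ids", "attention_mask"}

def strip_metadata(columns):
    """Return only the whitelisted numeric columns from a row/batch dict."""
    return {k: v for k, v in columns.items() if k in KEEP}
```

With Hugging Face Datasets, the equivalent would typically be a `remove_columns` call over everything outside the whitelist.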

πŸ“ˆ Performance & Results

  • Compilation Speed: After the initial Step 0 JAX trace, training stabilized at per-step execution times in the millisecond range.
  • Reasoning Delta: Post-trained models showed a marked increase in using "Chain of Thought" markers compared to the base model.

Tunix Reasoning Architecture & Hyperparameter Configs:

[Figure: Final Rebalanced Dataset]

[Figure: Hardware Check]

[Figure: Model Configs]

---

πŸ“Š Data Strategy

We utilized Stratified Post-Training, rebalancing the training data across six critical reasoning domains:

  1. Mathematics: Step-by-step problem solving.
  2. Coding: Logic-heavy algorithm generation.
  3. Science: Deductive reasoning.
  4. Creative: Instruction following.
  5. Summarization: Contextual logic.
  6. General: General-knowledge reasoning.
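
One simple way to realize this stratified rebalancing is equal-share sampling per domain bucket; the sketch below is illustrative only (the domain keys mirror the list above, and `per_domain` is an assumed knob, not the project's actual count).

```python
# Sketch of stratified rebalancing: draw up to `per_domain` examples
# from each domain bucket, then shuffle the combined pool.
import random

def rebalance(buckets, per_domain, seed=0):
    """buckets: dict mapping domain name -> list of examples."""
    rng = random.Random(seed)
    out = []
    for domain, examples in buckets.items():
        k = min(per_domain, len(examples))   # small domains contribute all they have
        out.extend(rng.sample(examples, k))
    rng.shuffle(out)                         # mix domains within the final pool
    return out
```

Capping each domain at the same quota is what prevents a large domain (e.g. general knowledge) from dominating the post-training mix.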

πŸ‘©β€πŸ’» Author

Aditya Katkar
GitHub
LinkedIn