
# Task 2.5: LM-JEPA Foundation Model for Squared Amplitudes

This directory contains the Transformer architecture that processes tokenized prefix sequences (from Task 1.2) and performs unsupervised representation learning followed by supervised fine-tuning for Squared Amplitude prediction.

## Architecture Overview

I leveraged a Joint-Embedding Predictive Architecture (JEPA). This framework learns representations fully in the latent space, avoiding the collapse problems of conventional masked autoencoders.
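To make the latent-space training idea concrete, here is a minimal NumPy sketch of the JEPA objective: a trained context encoder, an EMA target encoder, and a predictor, with the loss computed entirely between latents rather than reconstructed tokens. All names (`encoder`, `jepa_loss`, `ema_update`) and shapes are illustrative assumptions, not the actual model code.

```python
# Toy JEPA loop in NumPy (illustrative only; the real model is a Transformer).
import numpy as np

rng = np.random.default_rng(0)
DIM, LATENT = 16, 8

def encoder(x, W):
    """Toy encoder: a single linear map into latent space."""
    return x @ W

# Context encoder is trained; target encoder tracks it via EMA.
W_ctx = rng.normal(scale=0.1, size=(DIM, LATENT))
W_tgt = W_ctx.copy()
W_pred = np.eye(LATENT)  # predictor maps context latents to target latents

def jepa_loss(x_context, x_target):
    """Loss lives entirely in latent space -- no token reconstruction."""
    z_ctx = encoder(x_context, W_ctx)
    z_tgt = encoder(x_target, W_tgt)   # stop-gradient branch in practice
    z_hat = z_ctx @ W_pred             # predict target latents from context
    return float(np.mean((z_hat - z_tgt) ** 2))

def ema_update(w_tgt, w_ctx, tau=0.99):
    """Slow-moving target weights help prevent representation collapse."""
    return tau * w_tgt + (1 - tau) * w_ctx

x = rng.normal(size=(4, DIM))          # a batch of token embeddings
loss = jepa_loss(x[:2], x[2:])         # masked halves act as context/target
W_tgt = ema_update(W_tgt, W_ctx)
```

The key design choice mirrored here is that both branches embed their inputs and the predictor is penalized only in latent space, which is what distinguishes JEPA from reconstruction-based masked autoencoders.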

## Linear FastAttention vs Standard Attention

Resolving the computational bottleneck related to nested SymPy graphs was a crucial turning point in this stage.

```mermaid
graph LR
    A["Standard Attention"] -->|Softmax bottleneck| B("O(N²) Complexity")
    C["FastAttention"] -->|Linear Matrix Decomposition| D("O(N) Complexity")

    style B fill:#ffe6e6,stroke:#ff0000
    style D fill:#e6ffe6,stroke:#008000
```
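The diagram's two paths can be sketched directly. The snippet below contrasts standard softmax attention, which materializes an N×N score matrix, with a kernelized O(N) variant in the style of linear transformers (using the elu(x)+1 feature map). Function names and the choice of feature map are assumptions for illustration, not the repository's actual FastAttention implementation.

```python
# Softmax attention (O(N^2)) vs kernelized linear attention (O(N)) in NumPy.
import numpy as np

def softmax_attention(Q, K, V):
    """Standard attention: builds the full N x N score matrix."""
    scores = Q @ K.T / np.sqrt(Q.shape[-1])              # O(N^2 d)
    scores = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (scores / scores.sum(axis=-1, keepdims=True)) @ V

def linear_attention(Q, K, V, eps=1e-6):
    """Kernel trick: associativity of matmul avoids the N x N matrix."""
    phi = lambda x: np.where(x > 0, x + 1, np.exp(x))    # elu(x) + 1 > 0
    Qp, Kp = phi(Q), phi(K)
    KV = Kp.T @ V                                        # d x d summary, O(N d^2)
    Z = Qp @ Kp.sum(axis=0) + eps                        # per-row normalizer, O(N d)
    return (Qp @ KV) / Z[:, None]

rng = np.random.default_rng(1)
N, d = 64, 8
Q, K, V = (rng.normal(size=(N, d)) for _ in range(3))
out_std = softmax_attention(Q, K, V)
out_lin = linear_attention(Q, K, V)
```

The linear variant trades the exact softmax kernel for a positive feature map, which is what collapses the quadratic term to a fixed d×d running summary.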

## Performance Benchmarking

### Representation Learning Progress

During the self-supervised phase, the JEPA model established a strong foundational prior for Feynman diagram kinematics, achieving robust convergence before fine-tuning.

JEPA Loss Curve

**Key Observation:** The JEPA component successfully learns the physics prior, as evidenced by the self-supervised objective's steady decline from 0.8 to a stabilized floor of 0.125.

### Inference Scaling

Scaling Plot

**Key Observation:** The complexity scaling plot demonstrates the advantage of FastAttention over standard $O(N^2)$ transformers. Standard transformers hit significant memory bottlenecks beyond 32 tokens, while the linear pipeline handles longer sequences with ease.
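A quick back-of-envelope calculation illustrates why the gap widens with sequence length: the standard score matrix grows as N², while linear attention keeps only a fixed d×d summary. The parameters (fp32, d=64) are illustrative assumptions, not measured values from the benchmark.

```python
# Memory of the attention intermediate: N x N scores vs a d x d KV summary.
def attention_memory_bytes(n_tokens, d_model=64, bytes_per_float=4):
    quadratic = n_tokens * n_tokens * bytes_per_float  # standard score matrix
    linear = d_model * d_model * bytes_per_float       # linear-attention state
    return quadratic, linear

for n in (32, 256, 2048):
    quad, lin = attention_memory_bytes(n)
    print(f"N={n:5d}: standard {quad:>12,} B vs linear {lin:,} B")
```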

### Prediction Precision

To validate the model's regression capability for Squared Amplitude modeling (0.098 validation MSE), the prediction distributions are mapped below:

#### Parity Plot


**Observation:** The predicted values closely follow the 45° ground-truth line, indicating an unbiased linear mapping without structural bias.

#### Residual Histogram


**Observation:** The tightly constrained error distribution, sharply peaked at zero, corroborates the robust convergence observed in the JEPA phase.
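The statistics behind both diagnostic plots reduce to a few lines. The sketch below computes validation MSE, residual bias (the histogram's center), and the parity correlation; the synthetic `y_true`/`y_pred` arrays are stand-ins for real model outputs, and the noise scale is an arbitrary assumption.

```python
# Regression diagnostics underlying the parity plot and residual histogram.
import numpy as np

rng = np.random.default_rng(2)
y_true = rng.normal(size=500)
y_pred = y_true + rng.normal(scale=0.3, size=500)     # near the 45-degree line

residuals = y_pred - y_true
mse = float(np.mean(residuals ** 2))                  # reported as Validation MSE
bias = float(np.mean(residuals))                      # ~0 => histogram centered at zero
parity_r = float(np.corrcoef(y_true, y_pred)[0, 1])   # ~1 => points hug y = x

print(f"MSE={mse:.3f}  bias={bias:+.3f}  parity r={parity_r:.3f}")
```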

## Final Result Metrics

| Phase | Task Focus | Validation MSE | Weights |
| --- | --- | --- | --- |
| Pre-Training | LM-JEPA Unsupervised | 0.125 | Local Weights |
| Fine-Tuning | QCD Squared Amplitudes | 0.098 | Local Weights |

## Extended Deliverables

**View Full Methodology:** See the detailed Task2.5_Solution.pdf report for the ablation study, hyperparameter protocols, and mathematical proofs.