From e212bf53bba108cbe58a53c9e0b86d2be31939e9 Mon Sep 17 00:00:00 2001 From: Dan Stair Date: Sat, 11 Apr 2026 23:55:07 +0000 Subject: [PATCH 1/5] initial jepa2 commit --- contrib/models/jepa-2-1/._AGENT.md | Bin 0 -> 163 bytes contrib/models/jepa-2-1/._PLAN.md | Bin 0 -> 163 bytes contrib/models/jepa-2-1/._README.md | Bin 0 -> 163 bytes contrib/models/jepa-2-1/._pyproject.toml | Bin 0 -> 163 bytes contrib/models/jepa-2-1/._src | Bin 0 -> 163 bytes contrib/models/jepa-2-1/._test | Bin 0 -> 163 bytes contrib/models/jepa-2-1/._uv.lock | Bin 0 -> 163 bytes contrib/models/jepa-2-1/AGENT.md | 105 ++ contrib/models/jepa-2-1/PLAN.md | 82 ++ contrib/models/jepa-2-1/PLAN_trn2.md | 257 +++++ contrib/models/jepa-2-1/README.md | 135 +++ contrib/models/jepa-2-1/pyproject.toml | 16 + contrib/models/jepa-2-1/src/.___init__.py | Bin 0 -> 163 bytes contrib/models/jepa-2-1/src/.___pycache__ | Bin 0 -> 163 bytes .../models/jepa-2-1/src/._modeling_jepa21.py | Bin 0 -> 163 bytes contrib/models/jepa-2-1/src/__init__.py | 5 + .../models/jepa-2-1/src/modeling_jepa21.py | 917 ++++++++++++++++++ contrib/models/jepa-2-1/test/.___init__.py | Bin 0 -> 163 bytes contrib/models/jepa-2-1/test/.___pycache__ | Bin 0 -> 163 bytes contrib/models/jepa-2-1/test/._integration | Bin 0 -> 163 bytes contrib/models/jepa-2-1/test/._unit | Bin 0 -> 163 bytes contrib/models/jepa-2-1/test/__init__.py | 0 .../jepa-2-1/test/integration/.___init__.py | Bin 0 -> 163 bytes .../jepa-2-1/test/integration/.___pycache__ | Bin 0 -> 163 bytes .../jepa-2-1/test/integration/._test_model.py | Bin 0 -> 163 bytes .../jepa-2-1/test/integration/__init__.py | 0 .../jepa-2-1/test/integration/test_model.py | 95 ++ .../models/jepa-2-1/test/unit/.___init__.py | Bin 0 -> 163 bytes .../models/jepa-2-1/test/unit/.___pycache__ | Bin 0 -> 163 bytes .../jepa-2-1/test/unit/._test_encoder.py | Bin 0 -> 163 bytes contrib/models/jepa-2-1/test/unit/__init__.py | 0 .../models/jepa-2-1/test/unit/test_encoder.py | 144 +++ 
contrib/models/jepa-2-1/uv.lock | 629 ++++++++++++ 33 files changed, 2385 insertions(+) create mode 100644 contrib/models/jepa-2-1/._AGENT.md create mode 100644 contrib/models/jepa-2-1/._PLAN.md create mode 100644 contrib/models/jepa-2-1/._README.md create mode 100644 contrib/models/jepa-2-1/._pyproject.toml create mode 100755 contrib/models/jepa-2-1/._src create mode 100755 contrib/models/jepa-2-1/._test create mode 100644 contrib/models/jepa-2-1/._uv.lock create mode 100644 contrib/models/jepa-2-1/AGENT.md create mode 100644 contrib/models/jepa-2-1/PLAN.md create mode 100644 contrib/models/jepa-2-1/PLAN_trn2.md create mode 100644 contrib/models/jepa-2-1/README.md create mode 100644 contrib/models/jepa-2-1/pyproject.toml create mode 100644 contrib/models/jepa-2-1/src/.___init__.py create mode 100755 contrib/models/jepa-2-1/src/.___pycache__ create mode 100644 contrib/models/jepa-2-1/src/._modeling_jepa21.py create mode 100644 contrib/models/jepa-2-1/src/__init__.py create mode 100644 contrib/models/jepa-2-1/src/modeling_jepa21.py create mode 100644 contrib/models/jepa-2-1/test/.___init__.py create mode 100755 contrib/models/jepa-2-1/test/.___pycache__ create mode 100755 contrib/models/jepa-2-1/test/._integration create mode 100755 contrib/models/jepa-2-1/test/._unit create mode 100644 contrib/models/jepa-2-1/test/__init__.py create mode 100644 contrib/models/jepa-2-1/test/integration/.___init__.py create mode 100755 contrib/models/jepa-2-1/test/integration/.___pycache__ create mode 100644 contrib/models/jepa-2-1/test/integration/._test_model.py create mode 100644 contrib/models/jepa-2-1/test/integration/__init__.py create mode 100644 contrib/models/jepa-2-1/test/integration/test_model.py create mode 100644 contrib/models/jepa-2-1/test/unit/.___init__.py create mode 100755 contrib/models/jepa-2-1/test/unit/.___pycache__ create mode 100644 contrib/models/jepa-2-1/test/unit/._test_encoder.py create mode 100644 contrib/models/jepa-2-1/test/unit/__init__.py create mode 
100644 contrib/models/jepa-2-1/test/unit/test_encoder.py create mode 100644 contrib/models/jepa-2-1/uv.lock diff --git a/contrib/models/jepa-2-1/._AGENT.md b/contrib/models/jepa-2-1/._AGENT.md new file mode 100644 index 0000000000000000000000000000000000000000..5177851bca3e9f09f1906e0acb8266eb0b3b9849 GIT binary patch literal 163 zcmZQz6=P>$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K1 via NxDI if needed for long clips + - Use NKI flash attention kernels + +6. **Conv3d patch embedding**: Verify Neuron compiler support for 3D convolutions. If unsupported, can be decomposed into reshape + Conv2d. + +### Weight Loading + +Checkpoints are loaded via `torch.hub.load_state_dict_from_url`. The state dict has keys prefixed with `module.` and `backbone.` which are stripped by `_clean_backbone_key()`. For V-JEPA 2.1 distilled models, the encoder key is `ema_encoder` (not `target_encoder`). + +### Inference-Only Simplifications + +For inference, these training-only features can be removed: +- Mask application (`apply_masks`) — not used during inference +- Drop path — identity at eval +- Predictor — only needed for pretraining/anticipation +- Activation checkpointing — only for training memory savings + +## Reference Patterns + +### NxDI Contrib Structure +See `~/dev/Neuron-steering-docs/steering/nxdi-contrib.md` for submission requirements. + +### Neuron SDK Docs +See `~/dev/neuron-docs/` for: +- `neuronx-distributed/` — distributed inference patterns +- `nki-library/` — NKI kernel examples (flash attention, etc.) 
+ +### Similar Ports +- Vision-language models in NxDI (Qwen-VL, MLLama) have image encoder components +- The Flux diffusion model in NxDI uses TP + NKI attention for large sequence lengths diff --git a/contrib/models/jepa-2-1/PLAN.md b/contrib/models/jepa-2-1/PLAN.md new file mode 100644 index 00000000..8ba0404c --- /dev/null +++ b/contrib/models/jepa-2-1/PLAN.md @@ -0,0 +1,82 @@ +# PLAN.md — V-JEPA 2.1 Neuron Port Roadmap + +## Current Status: Phase 1 — Initial Port (CPU-only) + +### Completed +- [x] Read and analyzed V-JEPA 2.1 source code (encoder, predictor, AC predictor, modules) +- [x] Read the paper (arxiv 2506.09985) +- [x] Created project structure following NxDI contrib conventions +- [x] Created self-contained encoder module (`modeling_jepa21.py`) with no upstream imports +- [x] Created CPU-only unit tests for encoder forward pass +- [x] Created README.md, AGENT.md, PLAN.md + +### In Progress +- [ ] Verify CPU forward pass matches upstream vjepa2 repo output (numerical equivalence) +- [ ] Test all 4 encoder variants (ViT-B, ViT-L, ViT-g, ViT-G) + +## Phase 2 — Neuron Compilation (on Trainium) + +### Tasks +- [ ] Set up trn2 instance with Neuron SDK 2.28+ +- [ ] Install dependencies (torch-neuronx, neuronx-distributed-inference) +- [ ] Trace ViT-B encoder with `torch_neuronx.trace()` at 384×384, 16 frames +- [ ] Verify SDPA compatibility — if Neuron doesn't support `F.scaled_dot_product_attention`, add manual attention fallback +- [ ] Verify Conv3d support — if unsupported, decompose to reshape + Conv2d +- [ ] Handle `torch.arange` in RoPE forward pass (may need to precompute) +- [ ] Trace ViT-L encoder +- [ ] Compare traced output vs CPU reference (cosine similarity > 0.99) +- [ ] Benchmark latency and throughput + +## Phase 3 — Scaling & Optimization + +### Tasks +- [ ] Test ViT-g (1B params) — may need TP>1 or NKI flash attention for 64-frame clips +- [ ] Test ViT-G (1.8B params) — likely needs TP≥2 +- [ ] If TP needed: port to NxDI pattern with NKI 
flash attention +- [ ] Profile memory usage at different frame counts (16, 32, 64) +- [ ] Optimize: batch compilation for multiple input shapes (frame count buckets) + +## Phase 4 — Downstream Tasks + +### Tasks +- [ ] Add attentive pooler for classification inference +- [ ] Add predictor for action anticipation inference +- [ ] Test with pretrained checkpoints on downstream benchmarks +- [ ] Add AC predictor for robotics planning inference (if applicable) + +## Phase 5 — Contrib Submission + +### Tasks +- [ ] Run full test suite on Trainium hardware +- [ ] Measure accuracy with `neuron_allclose()` against CPU reference +- [ ] Fill in compatibility matrix with actual test results +- [ ] Fill in benchmark results (throughput, latency) +- [ ] Ensure all tests pass with `pytest` +- [ ] Submit PR following NxDI contrib guidelines + +## Key Decisions + +### Why start with ViT-B/ViT-L? +- Smaller models compile faster and fit on single NeuronCore +- Validates the porting approach before scaling up +- ViT-B (86M params) and ViT-L (300M params) are practical for many downstream tasks + +### Why `torch_neuronx.trace()` first? +- Simpler than full NxDI port +- Encoder is feedforward (no KV cache, no autoregressive) +- Can always upgrade to NxDI later if TP is needed for larger models + +### Why not port the predictor first? 
+- Encoder is the primary inference component +- Predictor is only needed for pretraining and specific tasks (anticipation) +- Encoder features are sufficient for classification, VQA, and feature extraction + +## Risk Register + +| Risk | Impact | Mitigation | +|------|--------|------------| +| SDPA not supported on Neuron | Medium | Manual attention fallback already in codebase (`use_sdpa=False`) | +| Conv3d not supported | Low | Decompose to reshape + Conv2d | +| 64-frame ViT-G exceeds single-core HBM | High | Start with shorter clips; upgrade to NxDI with TP if needed | +| RoPE dynamic tensor creation | Medium | Precompute position tensors at trace time | +| `timm` dependency | Low | Replaced with inline `drop_path` (identity at eval) | diff --git a/contrib/models/jepa-2-1/PLAN_trn2.md b/contrib/models/jepa-2-1/PLAN_trn2.md new file mode 100644 index 00000000..d71957bb --- /dev/null +++ b/contrib/models/jepa-2-1/PLAN_trn2.md @@ -0,0 +1,257 @@ +# PLAN_trn2.md — V-JEPA 2.1 Trainium Execution Plan + +## Instance + +- **Type**: trn2.3xlarge (spot) in sa-east-1b +- **Instance ID**: i-0cae7b2ac61807cf9 +- **SSH**: `ssh -i ~/.ssh/trn2-sa-east-1.pem ubuntu@52.67.239.128` +- **Hardware**: 1 Neuron device, 4 NeuronCores, 96 GB HBM, 124 GB system RAM, 418 GB disk free +- **OS**: Ubuntu 24.04, Python 3.12.3 +- **Neuron driver**: aws-neuronx-dkms 2.27.4, runtime 2.31.24 (apt-installed) +- **Python venv**: `/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate` +- **Neuron SDK**: torch-neuronx 2.9.0, neuronx-cc 2.24.5133, neuronx-distributed-inference 0.9.17334 +- **PyTorch**: 2.9.1, torch-xla 2.9.0, torchvision 0.24.1 +- **Tools**: pytest 9.0.3 + +## Workflow + +Edit code locally at `~/dev/neuron-docs/neuronx-distributed-inference/contrib/models/jepa-2-1/`, rsync to trn2, run remotely. Same pattern as the autoresearch port. 
+ +```bash +# Sync +rsync -avz --exclude='__pycache__' --exclude='.DS_Store' --exclude='._*' \ + ~/dev/neuron-docs/neuronx-distributed-inference/contrib/models/jepa-2-1/ \ + -e "ssh -i ~/.ssh/trn2-sa-east-1.pem" ubuntu@52.67.239.128:jepa-2-1/ + +# Run remotely +ssh -i ~/.ssh/trn2-sa-east-1.pem ubuntu@52.67.239.128 \ + "cd jepa-2-1 && source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate && python ..." +``` + +--- + +## Step 1 — Sync Code & Run CPU Smoke Test on trn2 + +Rsync the project and verify the encoder runs on CPU. The Neuron SDK venv is pre-installed. + +```bash +# Sync +rsync -avz --exclude='__pycache__' --exclude='._*' \ + ~/dev/neuron-docs/neuronx-distributed-inference/contrib/models/jepa-2-1/ \ + -e "ssh -i ~/.ssh/trn2-sa-east-1.pem" ubuntu@52.67.239.128:jepa-2-1/ + +# CPU smoke test +ssh -i ~/.ssh/trn2-sa-east-1.pem ubuntu@52.67.239.128 << 'EOF' +source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate +cd jepa-2-1 +python -c " +from src.modeling_jepa21 import build_vjepa21_encoder +import torch +encoder = build_vjepa21_encoder(arch='vit_base', img_size=384, num_frames=16, pretrained=False) +encoder.eval() +x = torch.randn(1, 3, 16, 384, 384) +with torch.no_grad(): + out = encoder(x) +print(f'Output shape: {out.shape}') # expect (1, 4608, 768) +" +EOF +``` + +**Success criteria**: Output shape is `(1, 4608, 768)` for ViT-B with 16 frames. + +## Step 2 — Trace ViT-B Encoder with torch_neuronx + +First compilation attempt. Start with the smallest model (ViT-B, 86M params) and 16 frames. Cast to `bfloat16` — trn2 NeuronCores are heavily optimized for BF16/FP8, and explicit casting avoids unpredictable compiler auto-cast behavior. 
+ +```python +import torch +import torch_neuronx +from src.modeling_jepa21 import build_vjepa21_encoder + +encoder = build_vjepa21_encoder(arch='vit_base', img_size=384, num_frames=16, pretrained=False) +encoder.eval().bfloat16() + +example = torch.randn(1, 3, 16, 384, 384, dtype=torch.bfloat16) +traced = torch_neuronx.trace(encoder, example, compiler_args=['--auto-cast', 'none']) +traced.save("vjepa21_vitb_16f_384.pt") +print("Compilation succeeded") +``` + +**Note on masking**: The encoder's inference path uses `masks=None` by default — no tokens are dropped, so all tensor shapes are fully static. The masking codepath is training-only and won't be triggered during tracing. + +**Expected issues** (debug in order of likelihood): + +1. **SDPA not supported** → Set `use_sdpa=False` in the encoder config or add a manual attention fallback path in `modeling_jepa21.py`. +2. **Conv3d not supported** → Two options: (a) decompose `PatchEmbed3D` into reshape + Conv2d, or (b) replace with reshape + `nn.Linear` since stride == kernel_size makes the convolution equivalent to a linear projection over flattened tubelet patches — this maps directly to MatMul on NeuronCore and may be faster. +3. **`torch.arange` in RoPE** → Precompute RoPE frequency tensors before tracing (move out of forward pass). +4. **`repeat_interleave` not supported** → Replace with equivalent `reshape`/`expand`/`reshape` sequence. + +**Success criteria**: `.pt` file saved, no compilation errors. + +## Step 3 — Validate Traced Model Output + +Compare Neuron-traced output against CPU reference. 
+ +```python +import torch +import torch_neuronx + +# CPU reference (BF16) +encoder_cpu = build_vjepa21_encoder(arch='vit_base', img_size=384, num_frames=16, pretrained=False) +encoder_cpu.eval().bfloat16() +x = torch.randn(1, 3, 16, 384, 384, dtype=torch.bfloat16) +with torch.no_grad(): + ref = encoder_cpu(x) + +# Neuron +traced = torch.jit.load("vjepa21_vitb_16f_384.pt") +neuron_out = traced(x) + +cos_sim = torch.nn.functional.cosine_similarity(ref.flatten().float(), neuron_out.flatten().float(), dim=0) +print(f"Cosine similarity: {cos_sim.item():.6f}") # target > 0.99 +``` + +**Success criteria**: Cosine similarity > 0.99 between CPU and Neuron outputs. + +## Step 4 — Trace ViT-L Encoder + +Scale up to ViT-L (300M params, 16 frames, 4608 tokens). + +```python +encoder = build_vjepa21_encoder(arch='vit_large', img_size=384, num_frames=16, pretrained=False) +encoder.eval().bfloat16() +example = torch.randn(1, 3, 16, 384, 384, dtype=torch.bfloat16) +traced = torch_neuronx.trace(encoder, example, compiler_args=['--auto-cast', 'none']) +traced.save("vjepa21_vitl_16f_384.pt") +``` + +**Potential issue**: ViT-L has 24 layers × 16 heads. Attention matrices are 4608×4608 per head. Should fit in HBM when traced onto a single NeuronCore (note: the 96 GB is the device total across 4 NeuronCores, not per core), but watch for OOM during compilation (neuronx-cc can be memory-hungry on the host side — 124 GB system RAM may be tight for large graphs). + +**Success criteria**: Compilation succeeds, cosine similarity > 0.99 vs CPU. + +## Step 5 — Benchmark Latency + +Measure inference latency for both models. 
+ +```python +import time +import torch + +traced = torch.jit.load("vjepa21_vitb_16f_384.pt") +x = torch.randn(1, 3, 16, 384, 384, dtype=torch.bfloat16) + +# Warmup +for _ in range(5): + traced(x) + +# Benchmark +times = [] +for _ in range(50): + t0 = time.perf_counter() + traced(x) + t1 = time.perf_counter() + times.append(t1 - t0) + +import statistics +print(f"ViT-B 16f: {statistics.median(times)*1000:.1f} ms median, {statistics.mean(times)*1000:.1f} ms mean") +``` + +Repeat for ViT-L. Record results in README.md compatibility matrix. + +## Step 6 — Test with Pretrained Weights (Optional) + +If Meta's checkpoints are accessible via `torch.hub`: + +```python +encoder = build_vjepa21_encoder(arch='vit_large', img_size=384, num_frames=16, pretrained=True) +``` + +This validates that the weight loading path works end-to-end on Neuron. + +--- + +## Risk Mitigation + +| Risk | Mitigation | +|------|------------| +| SDPA unsupported on Neuron | `use_sdpa=False` flag already in model; manual `q @ k.T / sqrt(d) → softmax → @ v` fallback | +| Conv3d unsupported | Decompose to `reshape` + `Conv2d`, or replace with `reshape` + `nn.Linear` (maps to NeuronCore MatMul engine) | +| Dynamic `torch.arange` in RoPE | Precompute freq tensors as buffers; register in `__init__` | +| `repeat_interleave` unsupported | Replace with `reshape`→`expand`→`reshape` | +| Host OOM during compilation (124 GB RAM) | Compile ViT-B first (smaller graph); use `NEURON_CC_FLAGS="--retry_failed_compilation"` | +| Spot instance termination | Save compiled `.pt` files to S3 after each successful compilation | + +## Out of Scope (for now) + +- ViT-g / ViT-G (need TP, Phase 3 in PLAN.md) +- 64-frame inference (18K tokens, likely needs NKI flash attention) +- Predictor / AC predictor compilation +- Attentive pooler +- Downstream task benchmarks + +--- + +## Execution Results (2026-04-11) + +All steps executed on trn2.3xlarge `i-0cae7b2ac61807cf9` in sa-east-1. 
+SDK: torch-neuronx 2.9.0, neuronx-cc 2.24.5133, Python 3.12.3. + +### Step 1 — Sync & CPU Smoke Test ✅ + +- Rsync: 14 files transferred +- CPU output shape: `(1, 4608, 768)` — matches expected for ViT-B/16 with 16 frames + +### Step 2 — Trace ViT-B ✅ + +- Compiled on **first attempt** with `use_sdpa=False` and `--auto-cast none` +- None of the anticipated workarounds were needed: + - Conv3d: compiled natively + - `torch.arange` in RoPE: compiled natively + - `repeat_interleave`: compiled natively +- Only required change: `use_sdpa=False` to bypass `F.scaled_dot_product_attention` + +### Step 2 (fix) — BF16 dtype fix + +- `softmax()` promotes BF16→FP32 on CPU, causing dtype mismatch in manual attention path +- Fix: added `.to(v.dtype)` after `softmax` in both `RoPEAttention` and `Attention` classes + +### Step 3 — Validate ViT-B ✅ + +| Metric | Value | +|--------|-------| +| CPU output shape | `(1, 4608, 768)` | +| Neuron output shape | `(1, 4608, 768)` | +| Cosine similarity | **0.999846** | +| Max abs diff | 0.078125 | +| Mean abs diff | 0.004509 | + +### Step 4 — Trace ViT-L ✅ + +- Compilation time: **1073s (~18 min)** +- No host OOM — 124 GB system RAM was sufficient + +### Step 4b — Validate ViT-L ✅ + +| Metric | Value | +|--------|-------| +| CPU output shape | `(1, 4608, 1024)` | +| Neuron output shape | `(1, 4608, 1024)` | +| Cosine similarity | **0.999873** | +| Max abs diff | 0.132812 | +| Mean abs diff | 0.007144 | + +### Step 5 — Benchmark Latency ✅ + +Batch=1, BF16, 16 frames, 384×384, 50 iterations after 5 warmup: + +| Model | Params | Median | Mean | p5 | p95 | +|-------|--------|--------|------|-----|-----| +| ViT-B | 86M | **164.5 ms** | 164.5 ms | 164.4 ms | 164.6 ms | +| ViT-L | 300M | **437.4 ms** | 437.5 ms | 437.4 ms | 437.6 ms | + +Sub-millisecond variance — typical of the deterministic execution of Neuron hardware. 
+ +### Files modified + +- `src/modeling_jepa21.py` — `.to(v.dtype)` after softmax in manual attention paths +- `README.md` — updated compatibility matrix, compilation example, known issues diff --git a/contrib/models/jepa-2-1/README.md b/contrib/models/jepa-2-1/README.md new file mode 100644 index 00000000..d4b13aea --- /dev/null +++ b/contrib/models/jepa-2-1/README.md @@ -0,0 +1,135 @@ +# V-JEPA 2.1 on AWS Trainium + +V-JEPA 2.1 (Video Joint-Embedding Predictive Architecture) is Meta's self-supervised video foundation model. It learns visual representations by predicting masked video segments in a learned representation space, rather than pixel space. V-JEPA 2.1 extends V-JEPA 2 with knowledge distillation from a ViT-Gigantic teacher, enabling smaller student encoders (ViT-Base, ViT-Large) to achieve strong performance. + +This port targets inference on AWS Trainium (trn2) using `torch_neuronx.trace()`. + +## Model Information + +- **Source**: [facebookresearch/vjepa2](https://github.com/facebookresearch/vjepa2) +- **Paper**: [Self-Supervised Video Models Enable Understanding, Prediction and Planning](https://arxiv.org/abs/2506.09985) +- **Model Type**: Self-supervised Vision Transformer (ViT) encoder + predictor +- **Architecture**: ViT with 3D-RoPE, mask-denoising pretraining, hierarchical multi-layer output, modality embeddings (image/video) +- **License**: MIT (vjepa2 repo) + +### Available Checkpoints + +| Model | Params | Embed Dim | Depth | Heads | Resolution | Teacher | +|-------|--------|-----------|-------|-------|------------|---------| +| V-JEPA 2.1 ViT-B/16 | 86M | 768 | 12 | 12 | 384 | ViT-G distillation | +| V-JEPA 2.1 ViT-L/16 | 300M | 1024 | 24 | 16 | 384 | ViT-G distillation | +| V-JEPA 2.1 ViT-g/16 | 1B | 1408 | 40 | 22 | 384 | Self-supervised | +| V-JEPA 2.1 ViT-G/16 | 1.8B | 1664 | 48 | 26 | 384 | Self-supervised | + +## Architecture Overview + +V-JEPA 2.1 consists of: + +1. 
**Encoder** (`VisionTransformer`): Processes video frames patchified into 2×16×16 tubelets. Uses 3D-RoPE for spatiotemporal position encoding. Outputs hierarchical features from multiple intermediate layers (e.g., layers [5, 11, 17, 23] for ViT-L depth=24). + +2. **Predictor** (`VisionTransformerPredictor`): Takes encoder features + learnable mask tokens and predicts representations of masked patches. Uses multi-layer hierarchical input from the encoder via a learned projection. + +3. **Attentive Pooler** (optional, for classification): Cross-attention pooling over encoder features for downstream classification tasks. + +Key differences from V-JEPA 2: +- Hierarchical multi-layer output with per-layer norms (`norms_block`) +- Modality embeddings (separate for image vs video input) +- `img_temporal_dim_size` for handling single-frame image inputs with tubelet_size=1 +- Distillation-aware predictor with `n_output_distillation` controlling which layers contribute +- `interpolate_rope` for resolution-flexible RoPE + +## Inference Approach + +For inference on Trainium, we use `torch_neuronx.trace()` on the encoder. The encoder is the primary component needed for downstream tasks (classification, VQA, feature extraction). The predictor is only needed for pretraining and action anticipation tasks. 
+ +### Why `torch_neuronx.trace()` (not NxDI) + +- The encoder is a standard ViT without KV cache or autoregressive decoding +- At 384×384 resolution with 64 frames: seq_len = (64/2) × (384/16)² = 32 × 576 = 18,432 tokens per clip +- For single-frame image inference: seq_len = 576 tokens (trivially fits) +- For short video clips (16 frames): seq_len = 8 × 576 = 4,608 tokens +- NxDI's KV cache and flash attention infrastructure is unnecessary for non-autoregressive models +- `torch_neuronx.trace()` is simpler and sufficient for encoder-only inference + +### Compilation Strategy + +- Trace the encoder with a fixed input shape (batch, channels, frames, height, width) +- Use `torch_neuronx.trace()` with example inputs +- For variable-length video, compile multiple buckets or pad to max length + +## Usage + +```python +import torch +import torch_neuronx + +# Load encoder (CPU reference) +from src.modeling_jepa21 import build_vjepa21_encoder + +encoder = build_vjepa21_encoder( + arch="vit_large", + img_size=384, + num_frames=16, + pretrained=False, # set True when checkpoint available +) +encoder.eval() + +# Example: single image input (B, C, T, H, W) +image_input = torch.randn(1, 3, 1, 384, 384) +with torch.no_grad(): + features = encoder(image_input) +# features shape: (1, 576, 1024) for ViT-L + +# Example: video input +video_input = torch.randn(1, 3, 16, 384, 384) +with torch.no_grad(): + features = encoder(video_input) +# features shape: (1, 4608, 1024) for ViT-L with 16 frames +``` + +### Neuron Compilation (on Trainium instance) + +```python +import torch +import torch_neuronx +from src.modeling_jepa21 import build_vjepa21_encoder + +encoder = build_vjepa21_encoder(arch="vit_large", img_size=384, num_frames=16, use_sdpa=False) +encoder.eval().bfloat16() + +example_input = torch.randn(1, 3, 16, 384, 384, dtype=torch.bfloat16) +traced = torch_neuronx.trace(encoder, example_input, compiler_args=["--auto-cast", "none"]) +traced.save("vjepa21_vitl_16f_384.pt") +``` + 
+## Compatibility Matrix + +| Instance | SDK | Model | Frames | Resolution | Dtype | Compile | Cosine Sim | Latency (median) | +|----------|-----|-------|--------|------------|-------|---------|------------|-------------------| +| trn2.3xlarge | 2.27 (torch-neuronx 2.9.0, neuronx-cc 2.24.5133) | ViT-B (86M) | 16 | 384×384 | BF16 | ✅ PASS | 0.9998 | 164.5 ms | +| trn2.3xlarge | 2.27 (torch-neuronx 2.9.0, neuronx-cc 2.24.5133) | ViT-L (300M) | 16 | 384×384 | BF16 | ✅ PASS | 0.9999 | 437.4 ms | +| inf2.xlarge | — | — | — | — | — | Not tested | — | — | + +## Example Checkpoints + +* V-JEPA 2.1 weights are loaded via `torch.hub` from Meta's servers (see `hubconf.py` in vjepa2 repo) + +## Testing Instructions + +```bash +# CPU-only tests (runs on MacBook) +cd contrib/models/jepa-2-1 +pytest test/ -v + +# On Trainium instance +pytest test/integration/test_model.py -v +``` + +## Known Issues + +- `use_sdpa=False` is required for Neuron compilation — `F.scaled_dot_product_attention` is not supported by `torch_neuronx.trace()`. The manual attention fallback (`q @ k.T * scale → softmax → @ v`) works correctly. +- BF16 softmax promotes to FP32 on CPU; `.to(v.dtype)` cast added after softmax to maintain dtype consistency. 
+- 3D-RoPE uses a duplicated frequency pattern (known upstream bug, preserved for checkpoint compatibility) +- `timm` is not a dependency: upstream uses it only for `drop_path`, which this port replaces with an inline implementation +- Full 64-frame ViT-G inference may require TP>1 on Trainium due to memory +- Conv3d, `torch.arange`, and `repeat_interleave` all compile successfully on neuronx-cc 2.24.5133 diff --git a/contrib/models/jepa-2-1/pyproject.toml b/contrib/models/jepa-2-1/pyproject.toml new file mode 100644 index 00000000..5506ccbc --- /dev/null +++ b/contrib/models/jepa-2-1/pyproject.toml @@ -0,0 +1,16 @@ +[project] +name = "jepa-2-1-neuron" +version = "0.1.0" +description = "V-JEPA 2.1 encoder port for AWS Trainium" +requires-python = ">=3.10" +dependencies = [ + "torch>=2.1", +] + +[project.optional-dependencies] +test = [ + "pytest>=7.0", +] + +[tool.pytest.ini_options] +testpaths = ["test"] diff --git a/contrib/models/jepa-2-1/src/.___init__.py b/contrib/models/jepa-2-1/src/.___init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5177851bca3e9f09f1906e0acb8266eb0b3b9849 GIT binary patch literal 163 zcmZQz6=P>$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K 0 else None + + omega = torch.arange(D // 2, dtype=x.dtype, device=x.device) + omega /= D / 2.0 + omega = 1.0 / 10000**omega + freq = torch.einsum("..., f -> ... 
f", pos, omega) + + emb_sin = freq.sin().repeat_interleave(2, dim=-1) + emb_cos = freq.cos().repeat_interleave(2, dim=-1) + + y = x_ctx.unflatten(-1, (-1, 2)) + y1, y2 = y.unbind(dim=-1) + y = torch.stack((-y2, y1), dim=-1).flatten(-2) + + out_ctx = (x_ctx * emb_cos) + (y * emb_sin) + + parts = [] + if n_cls: + parts.append(x_cls) + parts.append(out_ctx) + if n_registers: + parts.append(x_reg) + return torch.cat(parts, dim=-2) + + +# --------------------------------------------------------------------------- +# Attention Modules +# --------------------------------------------------------------------------- +class RoPEAttention(nn.Module): + """Multi-head attention with 3D RoPE for V-JEPA 2.1.""" + + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + use_sdpa=True, + use_nki_flash=False, + grid_size=14, + is_causal=False, + n_registers=0, + has_cls_first=False, + interpolate_rope=False, + patch_size=16, + ): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + self.use_nki_flash = use_nki_flash and (_nki_flash_attn is not None) + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop_prob = proj_drop + self.proj_drop = nn.Dropout(proj_drop) + self.use_sdpa = use_sdpa + self.d_dim = int(2 * ((head_dim // 3) // 2)) + self.h_dim = int(2 * ((head_dim // 3) // 2)) + self.w_dim = int(2 * ((head_dim // 3) // 2)) + self.grid_size = grid_size + self.is_causal = is_causal + self.n_registers = n_registers + self.has_cls_first = has_cls_first + self.interpolate_rope = interpolate_rope + self.pretrained_patch_size = patch_size + if patch_size == 14: + self.pretrained_grid_size = int(252 / patch_size) + elif patch_size == 16: + self.pretrained_grid_size = int(256 / patch_size) + else: + self.pretrained_grid_size = grid_size + + def 
_get_frame_pos(self, ids, H_patches=None, W_patches=None): + if H_patches is None or W_patches is None: + tokens_per_frame = int(self.grid_size * self.grid_size) + else: + tokens_per_frame = int(H_patches * W_patches) + return ids // tokens_per_frame + + def _get_height_pos(self, ids, H_patches=None, W_patches=None): + if H_patches is None or W_patches is None: + tokens_per_frame = int(self.grid_size * self.grid_size) + tokens_per_row = self.grid_size + else: + tokens_per_frame = int(H_patches * W_patches) + tokens_per_row = W_patches + frame_ids = self._get_frame_pos(ids, H_patches, W_patches) + ids = ids - tokens_per_frame * frame_ids + return ids // tokens_per_row + + def separate_positions(self, ids, H_patches=None, W_patches=None): + if H_patches is None or W_patches is None: + tokens_per_frame = int(self.grid_size * self.grid_size) + tokens_per_row = self.grid_size + else: + tokens_per_frame = int(H_patches * W_patches) + tokens_per_row = W_patches + frame_ids = self._get_frame_pos(ids, H_patches, W_patches) + height_ids = self._get_height_pos(ids, H_patches, W_patches) + width_ids = (ids - tokens_per_frame * frame_ids) - tokens_per_row * height_ids + return 1.0 * frame_ids, 1.0 * height_ids, 1.0 * width_ids + + def forward(self, x, mask=None, T=None, H_patches=None, W_patches=None, return_attn=False): + B, N, C = x.size() + + qkv = self.qkv(x).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + if mask is not None: + mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1) + d_mask, h_mask, w_mask = self.separate_positions(mask, H_patches, W_patches) + else: + if T is None or H_patches is None or W_patches is None: + N_ctx = N - self.n_registers + grid_depth = int(N_ctx // (self.grid_size * self.grid_size)) + mask = torch.arange( + int(grid_depth * self.grid_size * self.grid_size), device=x.device + ) + else: + mask = torch.arange(int(T * H_patches * W_patches), device=x.device) + d_mask, h_mask, w_mask = 
self.separate_positions(mask, H_patches, W_patches) + + if self.interpolate_rope: + if H_patches is None: + H_patches = int(self.grid_size) + if W_patches is None: + W_patches = int(self.grid_size) + h_mask = h_mask * (self.pretrained_grid_size - 1) / (H_patches - 1) + w_mask = w_mask * (self.pretrained_grid_size - 1) / (W_patches - 1) + + s = 0 + qd = rotate_queries_or_keys_v21(q[..., s:s + self.d_dim], pos=d_mask, + n_registers=self.n_registers, has_cls_first=self.has_cls_first) + kd = rotate_queries_or_keys_v21(k[..., s:s + self.d_dim], pos=d_mask, + n_registers=self.n_registers, has_cls_first=self.has_cls_first) + s += self.d_dim + qh = rotate_queries_or_keys_v21(q[..., s:s + self.h_dim], pos=h_mask, + n_registers=self.n_registers, has_cls_first=self.has_cls_first) + kh = rotate_queries_or_keys_v21(k[..., s:s + self.h_dim], pos=h_mask, + n_registers=self.n_registers, has_cls_first=self.has_cls_first) + s += self.h_dim + qw = rotate_queries_or_keys_v21(q[..., s:s + self.w_dim], pos=w_mask, + n_registers=self.n_registers, has_cls_first=self.has_cls_first) + kw = rotate_queries_or_keys_v21(k[..., s:s + self.w_dim], pos=w_mask, + n_registers=self.n_registers, has_cls_first=self.has_cls_first) + s += self.w_dim + + if s < self.head_dim: + qr = q[..., s:] + kr = k[..., s:] + q = torch.cat([qd, qh, qw, qr], dim=-1) + k = torch.cat([kd, kh, kw, kr], dim=-1) + else: + q = torch.cat([qd, qh, qw], dim=-1) + k = torch.cat([kd, kh, kw], dim=-1) + + if self.use_nki_flash: + # NKI ISA kernel layout: q/k=(B*H, d, seqlen), v=(B*H, seqlen, d), out=(B*H, seqlen, d) + q_nki = q.reshape(B * self.num_heads, N, self.head_dim).permute(0, 2, 1).contiguous() + k_nki = k.reshape(B * self.num_heads, N, self.head_dim).permute(0, 2, 1).contiguous() + v_nki = v.reshape(B * self.num_heads, N, self.head_dim).contiguous() + attn_output = torch.zeros(B * self.num_heads, N, self.head_dim, dtype=q.dtype, device=q.device) + _nki_flash_attn(q_nki, k_nki, v_nki, self.scale, attn_output, + 
kernel_name="AttentionMMSoftmaxMMWithoutSwap") + x = attn_output.reshape(B, self.num_heads, N, self.head_dim) + attn = None + elif self.use_sdpa: + with torch.backends.cuda.sdp_kernel(): + x = F.scaled_dot_product_attention( + q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal + ) + attn = None + else: + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1).to(v.dtype) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + if return_attn: + return x, attn + return x, None + + +class Attention(nn.Module): + """Standard multi-head attention (no RoPE).""" + + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, + attn_drop=0.0, proj_drop=0.0, use_sdpa=True, is_causal=False): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop_prob = proj_drop + self.proj_drop = nn.Dropout(proj_drop) + self.use_sdpa = use_sdpa + self.is_causal = is_causal + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + if self.use_sdpa: + with torch.backends.cuda.sdp_kernel(): + x = F.scaled_dot_product_attention( + q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal + ) + else: + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1).to(v.dtype) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +# --------------------------------------------------------------------------- +# MLP Modules +# --------------------------------------------------------------------------- +class MLP(nn.Module): + def 
__init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class SwiGLUFFN(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, + act_layer=nn.SiLU, drop=0.0, wide_silu=True): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + swiglu_hidden_features = hidden_features + if wide_silu: + swiglu_hidden_features = int(2 * hidden_features / 3) + align_as = 8 + swiglu_hidden_features = (swiglu_hidden_features + align_as - 1) // align_as * align_as + self.fc1 = nn.Linear(in_features, swiglu_hidden_features) + self.fc2 = nn.Linear(in_features, swiglu_hidden_features) + self.act = act_layer() + self.fc3 = nn.Linear(swiglu_hidden_features, out_features) + + def forward(self, x): + x1 = self.fc1(x) + x2 = self.fc2(x) + hidden = F.silu(x1) * x2 + return self.fc3(hidden) + + +# --------------------------------------------------------------------------- +# Transformer Block +# --------------------------------------------------------------------------- +class Block(nn.Module): + def __init__( + self, dim, num_heads, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, + drop=0.0, attn_drop=0.0, drop_path=0.0, act_layer=nn.GELU, wide_silu=True, + norm_layer=nn.LayerNorm, use_sdpa=True, use_nki_flash=False, is_causal=False, + grid_size=16, use_rope=False, n_registers=0, has_cls_first=False, + interpolate_rope=False, patch_size=16, **kwargs, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.use_rope = use_rope + if use_rope: + 
self.attn = RoPEAttention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, use_sdpa=use_sdpa, use_nki_flash=use_nki_flash, + is_causal=is_causal, grid_size=grid_size, proj_drop=drop, + n_registers=n_registers, has_cls_first=has_cls_first, + interpolate_rope=interpolate_rope, patch_size=patch_size, + ) + else: + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, use_sdpa=use_sdpa, is_causal=is_causal, proj_drop=drop, + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + if act_layer is nn.SiLU: + self.mlp = SwiGLUFFN( + in_features=dim, hidden_features=mlp_hidden_dim, + act_layer=act_layer, wide_silu=wide_silu, drop=drop, + ) + else: + self.mlp = MLP( + in_features=dim, hidden_features=mlp_hidden_dim, + act_layer=act_layer, drop=drop, + ) + + def forward(self, x, mask=None, T=None, H_patches=None, W_patches=None, + return_attn=False, mode="video"): + if self.use_rope: + y, attn = self.attn( + self.norm1(x), mask=mask, T=T, H_patches=H_patches, + W_patches=W_patches, return_attn=return_attn, + ) + else: + y = self.attn(self.norm1(x)) + attn = None + x = x + self.drop_path(y) + x = x + self.drop_path(self.mlp(self.norm2(x))) + if return_attn: + return x, attn + return x, None + + +# --------------------------------------------------------------------------- +# V-JEPA 2.1 Vision Transformer Encoder +# --------------------------------------------------------------------------- +class VisionTransformer(nn.Module): + """V-JEPA 2.1 Vision Transformer encoder for inference.""" + + def __init__( + self, + img_size=(384, 384), + patch_size=16, + num_frames=64, + tubelet_size=2, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + norm_layer=nn.LayerNorm, + 
init_std=0.02, + out_layers=None, + uniform_power=False, + use_silu=False, + wide_silu=True, + use_sdpa=True, + use_nki_flash=False, + use_activation_checkpointing=False, + is_causal=False, + use_rope=True, + handle_nonsquare_inputs=True, + img_temporal_dim_size=None, + n_registers=0, + has_cls_first=False, + interpolate_rope=False, + modality_embedding=True, + n_output_distillation=4, + **kwargs, + ): + super().__init__() + self.num_features = self.embed_dim = embed_dim + self.num_heads = num_heads + self.out_layers = out_layers + self.handle_nonsquare_inputs = handle_nonsquare_inputs + self.img_temporal_dim_size = img_temporal_dim_size + + if isinstance(img_size, int): + img_size = (img_size, img_size) + self.img_height, self.img_width = img_size + self.patch_size = patch_size + self.num_frames = num_frames + self.tubelet_size = tubelet_size + self.is_video = num_frames > 1 + + self.use_activation_checkpointing = use_activation_checkpointing + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + + if self.is_video: + self.patch_embed = PatchEmbed3D( + patch_size=patch_size, tubelet_size=tubelet_size, + in_chans=in_chans, embed_dim=embed_dim, + ) + self.num_patches = (num_frames // tubelet_size) * (img_size[0] // patch_size) * (img_size[1] // patch_size) + else: + self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + self.num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size) + + # Separate image patch embed for img_temporal_dim_size + if self.img_temporal_dim_size is not None: + self.patch_embed_img = PatchEmbed3D( + patch_size=patch_size, tubelet_size=1, + in_chans=in_chans, embed_dim=embed_dim, + ) + else: + self.patch_embed_img = None + + self.uniform_power = uniform_power + self.use_rope = use_rope + + self.blocks = nn.ModuleList([ + Block( + use_rope=use_rope, grid_size=img_size[0] // patch_size, + grid_depth=num_frames // tubelet_size, dim=embed_dim, + num_heads=num_heads, 
mlp_ratio=mlp_ratio, use_sdpa=use_sdpa, + use_nki_flash=use_nki_flash, is_causal=is_causal, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, act_layer=nn.SiLU if use_silu else nn.GELU, + wide_silu=wide_silu, attn_drop=attn_drop_rate, drop_path=dpr[i], + norm_layer=norm_layer, n_registers=n_registers, + has_cls_first=has_cls_first, interpolate_rope=interpolate_rope, + patch_size=patch_size, + ) + for i in range(depth) + ]) + + self.init_std = init_std + self.apply(self._init_weights) + self._rescale_blocks() + + # Hierarchical layer indices + _layer_map = { + 12: [2, 5, 8, 11], + 24: [5, 11, 17, 23], + 40: [9, 19, 29, 39], + 48: [11, 23, 37, 47], + } + self.hierarchical_layers = _layer_map.get(depth, [depth - 1]) + + if n_output_distillation == 4: + self.out_layers_distillation = self.hierarchical_layers[:] + elif n_output_distillation == 1: + self.out_layers_distillation = [self.hierarchical_layers[-1]] + else: + self.out_layers_distillation = self.hierarchical_layers[-n_output_distillation:] + + self.norms_block = nn.ModuleList([ + norm_layer(embed_dim) for _ in range(len(self.hierarchical_layers)) + ]) + + self.cls_token = None + self.return_hierarchical = False + + # Modality embeddings + self.modality_embedding = False + if modality_embedding: + self.img_mod_embed = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.video_mod_embed = nn.Parameter(torch.zeros(1, 1, embed_dim)) + nn.init.normal_(self.img_mod_embed, std=1e-6) + nn.init.normal_(self.video_mod_embed, std=1e-6) + self.modality_embedding = True + + def _init_weights(self, m): + if isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + return + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=self.init_std) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, (nn.Conv2d, nn.Conv3d)): + trunc_normal_(m.weight, std=self.init_std) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _rescale_blocks(self): + def 
rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def check_temporal_dim(self, shape) -> bool: + if self.img_temporal_dim_size is not None: + if shape[2] == self.img_temporal_dim_size: + return True + return False + + def forward(self, x, masks=None, training=False): + """ + Args: + x: input tensor. Image: (B, C, H, W) or (B, C, 1, H, W). Video: (B, C, T, H, W). + masks: optional mask indices (training only). + training: if True, return hierarchical concatenated features. + Returns: + Tensor of shape (B, N, D) where N is the number of patch tokens. + """ + if masks is not None and not isinstance(masks, list): + masks = [masks] + + if x.ndim == 4: + _, _, H, W = x.shape + T = 1 + elif x.ndim == 5: + _, _, T_raw, H, W = x.shape + if self.check_temporal_dim(x.shape): + T = T_raw // 1 + else: + T = T_raw // self.tubelet_size + + H_patches = H // self.patch_size + W_patches = W // self.patch_size + if not self.handle_nonsquare_inputs: + T = H_patches = W_patches = None + + # Patch embedding + if x.ndim == 5 and self.check_temporal_dim(x.shape): + assert self.patch_embed_img is not None + x = self.patch_embed_img(x) + mode = "img" + if self.modality_embedding: + x = x + self.img_mod_embed.repeat(x.shape[0], 1, 1) + else: + x = self.patch_embed(x) + mode = "video" + if self.modality_embedding: + x = x + self.video_mod_embed.repeat(x.shape[0], 1, 1) + + # Masking (training only) + if masks is not None: + from src.masks.utils import apply_masks + x = apply_masks(x, masks) + masks = torch.cat(masks, dim=0) + + # Forward through blocks + hier = [] + for i, blk in enumerate(self.blocks): + x, _attn = blk( + x, mask=masks, T=T, H_patches=H_patches, W_patches=W_patches, + return_attn=False, mode=mode, + ) + if i in self.out_layers_distillation: + out_idx = self.hierarchical_layers.index(i) + 
hier.append(self.norms_block[out_idx](x)) + + if training or self.return_hierarchical: + return torch.cat(hier, dim=2) + else: + # Return last hierarchical layer's normed output + return self.norms_block[-1](x) + + +# --------------------------------------------------------------------------- +# V-JEPA 2.1 Predictor (for completeness — not needed for basic inference) +# --------------------------------------------------------------------------- +class VisionTransformerPredictor(nn.Module): + """V-JEPA 2.1 predictor for mask prediction / action anticipation.""" + + def __init__( + self, + img_size=(384, 384), + patch_size=16, + num_frames=64, + tubelet_size=2, + embed_dim=768, + predictor_embed_dim=384, + out_embed_dim=None, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + norm_layer=nn.LayerNorm, + init_std=0.02, + uniform_power=False, + use_mask_tokens=False, + num_mask_tokens=2, + zero_init_mask_tokens=True, + use_silu=False, + wide_silu=True, + use_rope=True, + n_output_distillation=4, + teacher_embed_dim=None, + return_all_tokens=False, + modality_embedding=True, + img_temporal_dim_size=None, + interpolate_rope=False, + **kwargs, + ): + super().__init__() + self.return_all_tokens = return_all_tokens + + if isinstance(img_size, int): + img_size = (img_size, img_size) + self.img_height, self.img_width = img_size + self.patch_size = patch_size + self.num_frames = num_frames + self.tubelet_size = tubelet_size + self.is_video = num_frames > 1 + self.grid_height = img_size[0] // patch_size + self.grid_width = img_size[1] // patch_size + self.grid_depth = num_frames // tubelet_size + + if self.is_video: + self.num_patches = self.grid_depth * self.grid_height * self.grid_width + else: + self.num_patches = self.grid_height * self.grid_width + + # Hierarchical layers + _layer_map = {4: [0,1,2,3], 8: [1,3,5,7], 12: [2,5,8,11], 20: [4,9,14,19], 24: [4,11,17,23], 40: [9,19,29,39]} + 
all_hier = _layer_map.get(depth, list(range(depth))) + self.hierarchical_layers = all_hier[-n_output_distillation:] + + act_layer_mlp = nn.SiLU if use_silu else nn.GELU + if len(self.hierarchical_layers) == 1: + self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True) + else: + self.predictor_embed = nn.Sequential( + nn.Linear(embed_dim * len(self.hierarchical_layers), embed_dim, bias=True), + act_layer_mlp(), + nn.Linear(embed_dim, predictor_embed_dim, bias=True), + ) + + # Mask tokens + self.mask_tokens = None + self.num_mask_tokens = 0 + if use_mask_tokens: + self.num_mask_tokens = num_mask_tokens + self.mask_tokens = nn.ParameterList([ + nn.Parameter(torch.zeros(1, 1, predictor_embed_dim)) + for _ in range(num_mask_tokens) + ]) + + # Modality embeddings + self.modality_embedding = False + if img_temporal_dim_size is not None and modality_embedding: + self.video_mod_embed = nn.Parameter(torch.zeros(1, 1, predictor_embed_dim)) + self.img_mod_embed = nn.Parameter(torch.zeros(1, 1, predictor_embed_dim)) + nn.init.normal_(self.video_mod_embed, std=1e-6) + nn.init.normal_(self.img_mod_embed, std=1e-6) + self.modality_embedding = True + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + self.use_rope = use_rope + self.predictor_blocks = nn.ModuleList([ + Block( + use_rope=use_rope, grid_size=self.grid_height, grid_depth=self.grid_depth, + dim=predictor_embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, + act_layer=nn.SiLU if use_silu else nn.GELU, wide_silu=wide_silu, + attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + interpolate_rope=interpolate_rope, patch_size=patch_size, + ) + for i in range(depth) + ]) + + if out_embed_dim is None: + if teacher_embed_dim is not None: + out_embed_dim = teacher_embed_dim // len(self.hierarchical_layers) + else: + out_embed_dim = embed_dim + + self.predictor_norm = norm_layer(predictor_embed_dim) + self.predictor_proj 
= nn.Linear( + predictor_embed_dim, len(self.hierarchical_layers) * out_embed_dim, bias=True + ) + if self.return_all_tokens: + self.predictor_proj_context = nn.Linear( + predictor_embed_dim, out_embed_dim * len(self.hierarchical_layers), bias=True + ) + + self.init_std = init_std + if use_mask_tokens and not zero_init_mask_tokens: + for mt in self.mask_tokens: + trunc_normal_(mt, std=init_std) + self.apply(self._init_weights) + self._rescale_blocks() + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=self.init_std) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def _rescale_blocks(self): + for layer_id, layer in enumerate(self.predictor_blocks): + layer.attn.proj.weight.data.div_(math.sqrt(2.0 * (layer_id + 1))) + layer.mlp.fc2.weight.data.div_(math.sqrt(2.0 * (layer_id + 1))) + + def forward(self, x, masks_x=None, masks_y=None, mod="video", mask_index=1): + """Forward pass. 
For inference without masks, just pass features through.""" + if masks_x is None or masks_y is None: + # Simple forward without masking (inference mode) + x = self.predictor_embed(x) + for blk in self.predictor_blocks: + x, _ = blk(x) + x = self.predictor_norm(x) + x = self.predictor_proj(x) + return x, None + + # Full masked prediction (training mode) — not ported for Neuron + raise NotImplementedError("Masked prediction forward not implemented for Neuron inference") + + +# --------------------------------------------------------------------------- +# Builder Functions +# --------------------------------------------------------------------------- +_ARCH_CONFIGS = { + "vit_base": dict(embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0), + "vit_large": dict(embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4.0), + "vit_giant": dict(embed_dim=1408, depth=40, num_heads=22, mlp_ratio=48 / 11), + "vit_gigantic": dict(embed_dim=1664, depth=48, num_heads=26, mlp_ratio=64 / 13), +} + + +def build_vjepa21_encoder( + arch: str = "vit_large", + img_size: int = 384, + num_frames: int = 64, + patch_size: int = 16, + tubelet_size: int = 2, + use_sdpa: bool = True, + use_nki_flash: bool = False, + use_rope: bool = True, + interpolate_rope: bool = True, + img_temporal_dim_size: int = 1, + modality_embedding: bool = True, + n_output_distillation: int = 4, + pretrained: bool = False, + **kwargs, +) -> VisionTransformer: + """Build a V-JEPA 2.1 encoder. + + Args: + arch: one of 'vit_base', 'vit_large', 'vit_giant', 'vit_gigantic' + img_size: spatial resolution (square) + num_frames: number of video frames + pretrained: if True, load pretrained weights (requires network access) + """ + if arch not in _ARCH_CONFIGS: + raise ValueError(f"Unknown arch '{arch}'. 
Choose from {list(_ARCH_CONFIGS.keys())}") + + cfg = _ARCH_CONFIGS[arch] + encoder = VisionTransformer( + img_size=(img_size, img_size), + patch_size=patch_size, + num_frames=num_frames, + tubelet_size=tubelet_size, + use_sdpa=use_sdpa, + use_nki_flash=use_nki_flash, + use_silu=False, + wide_silu=True, + uniform_power=False, + use_rope=use_rope, + interpolate_rope=interpolate_rope, + img_temporal_dim_size=img_temporal_dim_size, + modality_embedding=modality_embedding, + n_output_distillation=n_output_distillation, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + qkv_bias=True, + **cfg, + **kwargs, + ) + + if pretrained: + _load_pretrained_weights(encoder, arch) + + return encoder + + +def _load_pretrained_weights(encoder, arch): + """Load pretrained V-JEPA 2.1 weights from Meta's servers.""" + VJEPA_BASE_URL = "https://dl.fbaipublicfiles.com/vjepa2" + _CKPT_MAP = { + "vit_base": "vjepa2_1_vitb_dist_vitG_384", + "vit_large": "vjepa2_1_vitl_dist_vitG_384", + "vit_giant": "vjepa2_1_vitg_384", + "vit_gigantic": "vjepa2_1_vitG_384", + } + model_file = _CKPT_MAP[arch] + url = f"{VJEPA_BASE_URL}/{model_file}.pt" + + checkpoint_key = "ema_encoder" if arch in ("vit_base", "vit_large") else "target_encoder" + + state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") + encoder_sd = state_dict[checkpoint_key] + + # Clean keys: remove 'module.' and 'backbone.' 
prefixes + cleaned = {} + for k, v in encoder_sd.items(): + k = k.replace("module.", "").replace("backbone.", "") + cleaned[k] = v + + encoder.load_state_dict(cleaned, strict=True) + print(f"Loaded pretrained weights from {url} (key={checkpoint_key})") diff --git a/contrib/models/jepa-2-1/test/.___init__.py b/contrib/models/jepa-2-1/test/.___init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5177851bca3e9f09f1906e0acb8266eb0b3b9849 GIT binary patch literal 163 zcmZQz6=P>$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K (B, N, D).""" + x = torch.randn(1, 3, 16, 384, 384) + with torch.no_grad(): + out = vit_base(x) + # 16 frames / tubelet_size=2 = 8 temporal tokens + # 384/16 = 24 spatial patches per dim + # 8 * 24 * 24 = 4608 tokens + assert out.shape == (1, 4608, 768), f"Expected (1, 4608, 768), got {out.shape}" + + def test_image_forward_shape(self, vit_base): + """Test single-frame image input via img_temporal_dim_size=1.""" + x = torch.randn(1, 3, 1, 384, 384) + with torch.no_grad(): + out = vit_base(x) + # 1 frame / tubelet_size=1 (img path) = 1 temporal token + # 24 * 24 = 576 spatial tokens + assert out.shape == (1, 576, 768), f"Expected (1, 576, 768), got {out.shape}" + + def test_batch_forward(self, vit_base): + """Test batched input.""" + x = torch.randn(2, 3, 16, 384, 384) + with torch.no_grad(): + out = vit_base(x) + assert out.shape == 
(2, 4608, 768) + + def test_hierarchical_output(self, vit_base): + """Test hierarchical output mode returns concatenated features.""" + vit_base.return_hierarchical = True + x = torch.randn(1, 3, 16, 384, 384) + with torch.no_grad(): + out = vit_base(x, training=True) + # 4 hierarchical layers * embed_dim=768 = 3072 + assert out.shape[0] == 1 + assert out.shape[1] == 4608 + assert out.shape[2] == 768 * 4 # 4 distillation layers + vit_base.return_hierarchical = False + + def test_output_deterministic(self, vit_base): + """Test that eval mode produces deterministic output.""" + x = torch.randn(1, 3, 16, 384, 384) + with torch.no_grad(): + out1 = vit_base(x) + out2 = vit_base(x) + assert torch.allclose(out1, out2, atol=1e-6) + + def test_256_resolution(self): + """Test with 256x256 resolution.""" + encoder = build_vjepa21_encoder( + arch="vit_base", img_size=256, num_frames=16, pretrained=False + ) + encoder.eval() + x = torch.randn(1, 3, 16, 256, 256) + with torch.no_grad(): + out = encoder(x) + # 8 * 16 * 16 = 2048 tokens + assert out.shape == (1, 2048, 768) + + +class TestEncoderComponents: + """Test individual components.""" + + def test_patch_embed_3d(self): + from modeling_jepa21 import PatchEmbed3D + pe = PatchEmbed3D(patch_size=16, tubelet_size=2, in_chans=3, embed_dim=768) + x = torch.randn(1, 3, 16, 384, 384) + out = pe(x) + # (16/2) * (384/16) * (384/16) = 8 * 24 * 24 = 4608 + assert out.shape == (1, 4608, 768) + + def test_patch_embed_3d_image(self): + from modeling_jepa21 import PatchEmbed3D + pe = PatchEmbed3D(patch_size=16, tubelet_size=1, in_chans=3, embed_dim=768) + x = torch.randn(1, 3, 1, 384, 384) + out = pe(x) + assert out.shape == (1, 576, 768) + + def test_rope_attention(self): + from modeling_jepa21 import RoPEAttention + attn = RoPEAttention( + dim=768, num_heads=12, qkv_bias=True, use_sdpa=False, + grid_size=24, interpolate_rope=True, patch_size=16, + ) + x = torch.randn(1, 576, 768) + out, _ = attn(x, T=1, H_patches=24, W_patches=24) + 
assert out.shape == (1, 576, 768) + + def test_block(self): + from modeling_jepa21 import Block + blk = Block( + dim=768, num_heads=12, mlp_ratio=4.0, qkv_bias=True, + use_rope=True, grid_size=24, interpolate_rope=True, patch_size=16, + norm_layer=torch.nn.LayerNorm, + ) + x = torch.randn(1, 576, 768) + out, _ = blk(x, T=1, H_patches=24, W_patches=24) + assert out.shape == (1, 576, 768) diff --git a/contrib/models/jepa-2-1/uv.lock b/contrib/models/jepa-2-1/uv.lock new file mode 100644 index 00000000..c8aa441a --- /dev/null +++ b/contrib/models/jepa-2-1/uv.lock @@ -0,0 +1,629 @@ +version = 1 +revision = 3 +requires-python = ">=3.10" +resolution-markers = [ + "python_full_version >= '3.11'", + "python_full_version < '3.11'", +] + +[[package]] +name = "colorama" +version = "0.4.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, +] + +[[package]] +name = "cuda-bindings" +version = "13.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cuda-pathfinder" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/1a/fe/7351d7e586a8b4c9f89731bfe4cf0148223e8f9903ff09571f78b3fb0682/cuda_bindings-13.2.0-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:08b395f79cb89ce0cd8effff07c4a1e20101b873c256a1aeb286e8fd7bd0f556", size = 5744254, upload-time = "2026-03-11T00:12:29.798Z" }, + { url = 
"https://files.pythonhosted.org/packages/aa/ef/184aa775e970fc089942cd9ec6302e6e44679d4c14549c6a7ea45bf7f798/cuda_bindings-13.2.0-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d6f3682ec3c4769326aafc67c2ba669d97d688d0b7e63e659d36d2f8b72f32d6", size = 6329075, upload-time = "2026-03-11T00:12:32.319Z" }, + { url = "https://files.pythonhosted.org/packages/e0/a9/3a8241c6e19483ac1f1dcf5c10238205dcb8a6e9d0d4d4709240dff28ff4/cuda_bindings-13.2.0-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:721104c603f059780d287969be3d194a18d0cc3b713ed9049065a1107706759d", size = 5730273, upload-time = "2026-03-11T00:12:37.18Z" }, + { url = "https://files.pythonhosted.org/packages/e9/94/2748597f47bb1600cd466b20cab4159f1530a3a33fe7f70fee199b3abb9e/cuda_bindings-13.2.0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1eba9504ac70667dd48313395fe05157518fd6371b532790e96fbb31bbb5a5e1", size = 6313924, upload-time = "2026-03-11T00:12:39.462Z" }, + { url = "https://files.pythonhosted.org/packages/52/c8/b2589d68acf7e3d63e2be330b84bc25712e97ed799affbca7edd7eae25d6/cuda_bindings-13.2.0-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e865447abfb83d6a98ad5130ed3c70b1fc295ae3eeee39fd07b4ddb0671b6788", size = 5722404, upload-time = "2026-03-11T00:12:44.041Z" }, + { url = "https://files.pythonhosted.org/packages/1f/92/f899f7bbb5617bb65ec52a6eac1e9a1447a86b916c4194f8a5001b8cde0c/cuda_bindings-13.2.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:46d8776a55d6d5da9dd6e9858fba2efcda2abe6743871dee47dd06eb8cb6d955", size = 6320619, upload-time = "2026-03-11T00:12:45.939Z" }, + { url = "https://files.pythonhosted.org/packages/df/93/eef988860a3ca985f82c4f3174fc0cdd94e07331ba9a92e8e064c260337f/cuda_bindings-13.2.0-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6629ca2df6f795b784752409bcaedbd22a7a651b74b56a165ebc0c9dcbd504d0", size = 
5614610, upload-time = "2026-03-11T00:12:50.337Z" }, + { url = "https://files.pythonhosted.org/packages/18/23/6db3aba46864aee357ab2415135b3fe3da7e9f1fa0221fa2a86a5968099c/cuda_bindings-13.2.0-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7dca0da053d3b4cc4869eff49c61c03f3c5dbaa0bcd712317a358d5b8f3f385d", size = 6149914, upload-time = "2026-03-11T00:12:52.374Z" }, + { url = "https://files.pythonhosted.org/packages/c0/87/87a014f045b77c6de5c8527b0757fe644417b184e5367db977236a141602/cuda_bindings-13.2.0-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a6464b30f46692d6c7f65d4a0e0450d81dd29de3afc1bb515653973d01c2cd6e", size = 5685673, upload-time = "2026-03-11T00:12:56.371Z" }, + { url = "https://files.pythonhosted.org/packages/ee/5e/c0fe77a73aaefd3fff25ffaccaac69c5a63eafdf8b9a4c476626ef0ac703/cuda_bindings-13.2.0-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f4af9f3e1be603fa12d5ad6cfca7844c9d230befa9792b5abdf7dd79979c3626", size = 6191386, upload-time = "2026-03-11T00:12:58.965Z" }, + { url = "https://files.pythonhosted.org/packages/5f/58/ed2c3b39c8dd5f96aa7a4abef0d47a73932c7a988e30f5fa428f00ed0da1/cuda_bindings-13.2.0-cp314-cp314t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:df850a1ff8ce1b3385257b08e47b70e959932f5f432d0a4e46a355962b4e4771", size = 5507469, upload-time = "2026-03-11T00:13:04.063Z" }, + { url = "https://files.pythonhosted.org/packages/1f/01/0c941b112ceeb21439b05895eace78ca1aa2eaaf695c8521a068fd9b4c00/cuda_bindings-13.2.0-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e8a16384c6494e5485f39314b0b4afb04bee48d49edb16d5d8593fd35bbd231b", size = 6059693, upload-time = "2026-03-11T00:13:06.003Z" }, +] + +[[package]] +name = "cuda-pathfinder" +version = "1.5.2" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/f2/f9/1b9b60a30fc463c14cdea7a77228131a0ccc89572e8df9cb86c9648271ab/cuda_pathfinder-1.5.2-py3-none-any.whl", hash = "sha256:0c5f160a7756c5b072723cbbd6d861e38917ef956c68150b02f0b6e9271c71fa", size = 49988, upload-time = "2026-04-06T23:01:05.17Z" }, +] + +[[package]] +name = "cuda-toolkit" +version = "13.0.2" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/57/b2/453099f5f3b698d7d0eab38916aac44c7f76229f451709e2eb9db6615dcd/cuda_toolkit-13.0.2-py2.py3-none-any.whl", hash = "sha256:b198824cf2f54003f50d64ada3a0f184b42ca0846c1c94192fa269ecd97a66eb", size = 2364, upload-time = "2025-12-19T23:24:07.328Z" }, +] + +[package.optional-dependencies] +cublas = [ + { name = "nvidia-cublas", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, +] +cudart = [ + { name = "nvidia-cuda-runtime", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, +] +cufft = [ + { name = "nvidia-cufft", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, +] +cufile = [ + { name = "nvidia-cufile", marker = "sys_platform == 'linux'" }, +] +cupti = [ + { name = "nvidia-cuda-cupti", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, +] +curand = [ + { name = "nvidia-curand", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, +] +cusolver = [ + { name = "nvidia-cusolver", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, +] +cusparse = [ + { name = "nvidia-cusparse", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, +] +nvjitlink = [ + { name = "nvidia-nvjitlink", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, +] +nvrtc = [ + { name = "nvidia-cuda-nvrtc", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, +] +nvtx = [ + { name = "nvidia-nvtx", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, +] + +[[package]] +name = "exceptiongroup" +version = "1.3.1" +source = 
{ registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/50/79/66800aadf48771f6b62f7eb014e352e5d06856655206165d775e675a02c9/exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219", size = 30371, upload-time = "2025-11-21T23:01:54.787Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8a/0e/97c33bf5009bdbac74fd2beace167cab3f978feb69cc36f1ef79360d6c4e/exceptiongroup-1.3.1-py3-none-any.whl", hash = "sha256:a7a39a3bd276781e98394987d3a5701d0c4edffb633bb7a5144577f82c773598", size = 16740, upload-time = "2025-11-21T23:01:53.443Z" }, +] + +[[package]] +name = "filelock" +version = "3.25.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/94/b8/00651a0f559862f3bb7d6f7477b192afe3f583cc5e26403b44e59a55ab34/filelock-3.25.2.tar.gz", hash = "sha256:b64ece2b38f4ca29dd3e810287aa8c48182bbecd1ae6e9ae126c9b35f1382694", size = 40480, upload-time = "2026-03-11T20:45:38.487Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a4/a5/842ae8f0c08b61d6484b52f99a03510a3a72d23141942d216ebe81fefbce/filelock-3.25.2-py3-none-any.whl", hash = "sha256:ca8afb0da15f229774c9ad1b455ed96e85a81373065fb10446672f64444ddf70", size = 26759, upload-time = "2026-03-11T20:45:37.437Z" }, +] + +[[package]] +name = "fsspec" +version = "2026.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e1/cf/b50ddf667c15276a9ab15a70ef5f257564de271957933ffea49d2cdbcdfb/fsspec-2026.3.0.tar.gz", hash = "sha256:1ee6a0e28677557f8c2f994e3eea77db6392b4de9cd1f5d7a9e87a0ae9d01b41", size = 313547, upload-time = "2026-03-27T19:11:14.892Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/d5/1f/5f4a3cd9e4440e9d9bc78ad0a91a1c8d46b4d429d5239ebe6793c9fe5c41/fsspec-2026.3.0-py3-none-any.whl", hash = "sha256:d2ceafaad1b3457968ed14efa28798162f1638dbb5d2a6868a2db002a5ee39a4", size = 202595, upload-time = "2026-03-27T19:11:13.595Z" }, +] + +[[package]] +name = "iniconfig" +version = "2.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, +] + +[[package]] +name = "jepa-2-1-neuron" +version = "0.1.0" +source = { virtual = "." 
} +dependencies = [ + { name = "torch" }, +] + +[package.optional-dependencies] +test = [ + { name = "pytest" }, +] + +[package.metadata] +requires-dist = [ + { name = "pytest", marker = "extra == 'test'", specifier = ">=7.0" }, + { name = "torch", specifier = ">=2.1" }, +] +provides-extras = ["test"] + +[[package]] +name = "jinja2" +version = "3.1.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markupsafe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/df/bf/f7da0350254c0ed7c72f3e33cef02e048281fec7ecec5f032d4aac52226b/jinja2-3.1.6.tar.gz", hash = "sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d", size = 245115, upload-time = "2025-03-05T20:05:02.478Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" }, +] + +[[package]] +name = "markupsafe" +version = "3.0.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/7e/99/7690b6d4034fffd95959cbe0c02de8deb3098cc577c67bb6a24fe5d7caa7/markupsafe-3.0.3.tar.gz", hash = "sha256:722695808f4b6457b320fdc131280796bdceb04ab50fe1795cd540799ebe1698", size = 80313, upload-time = "2025-09-27T18:37:40.426Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e8/4b/3541d44f3937ba468b75da9eebcae497dcf67adb65caa16760b0a6807ebb/markupsafe-3.0.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2f981d352f04553a7171b8e44369f2af4055f888dfb147d55e42d29e29e74559", size = 11631, upload-time = "2025-09-27T18:36:05.558Z" }, + { url = "https://files.pythonhosted.org/packages/98/1b/fbd8eed11021cabd9226c37342fa6ca4e8a98d8188a8d9b66740494960e4/markupsafe-3.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = 
"sha256:e1c1493fb6e50ab01d20a22826e57520f1284df32f2d8601fdd90b6304601419", size = 12057, upload-time = "2025-09-27T18:36:07.165Z" }, + { url = "https://files.pythonhosted.org/packages/40/01/e560d658dc0bb8ab762670ece35281dec7b6c1b33f5fbc09ebb57a185519/markupsafe-3.0.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1ba88449deb3de88bd40044603fafffb7bc2b055d626a330323a9ed736661695", size = 22050, upload-time = "2025-09-27T18:36:08.005Z" }, + { url = "https://files.pythonhosted.org/packages/af/cd/ce6e848bbf2c32314c9b237839119c5a564a59725b53157c856e90937b7a/markupsafe-3.0.3-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f42d0984e947b8adf7dd6dde396e720934d12c506ce84eea8476409563607591", size = 20681, upload-time = "2025-09-27T18:36:08.881Z" }, + { url = "https://files.pythonhosted.org/packages/c9/2a/b5c12c809f1c3045c4d580b035a743d12fcde53cf685dbc44660826308da/markupsafe-3.0.3-cp310-cp310-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:c0c0b3ade1c0b13b936d7970b1d37a57acde9199dc2aecc4c336773e1d86049c", size = 20705, upload-time = "2025-09-27T18:36:10.131Z" }, + { url = "https://files.pythonhosted.org/packages/cf/e3/9427a68c82728d0a88c50f890d0fc072a1484de2f3ac1ad0bfc1a7214fd5/markupsafe-3.0.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:0303439a41979d9e74d18ff5e2dd8c43ed6c6001fd40e5bf2e43f7bd9bbc523f", size = 21524, upload-time = "2025-09-27T18:36:11.324Z" }, + { url = "https://files.pythonhosted.org/packages/bc/36/23578f29e9e582a4d0278e009b38081dbe363c5e7165113fad546918a232/markupsafe-3.0.3-cp310-cp310-musllinux_1_2_riscv64.whl", hash = "sha256:d2ee202e79d8ed691ceebae8e0486bd9a2cd4794cec4824e1c99b6f5009502f6", size = 20282, upload-time = "2025-09-27T18:36:12.573Z" }, + { url = "https://files.pythonhosted.org/packages/56/21/dca11354e756ebd03e036bd8ad58d6d7168c80ce1fe5e75218e4945cbab7/markupsafe-3.0.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash 
= "sha256:177b5253b2834fe3678cb4a5f0059808258584c559193998be2601324fdeafb1", size = 20745, upload-time = "2025-09-27T18:36:13.504Z" }, + { url = "https://files.pythonhosted.org/packages/87/99/faba9369a7ad6e4d10b6a5fbf71fa2a188fe4a593b15f0963b73859a1bbd/markupsafe-3.0.3-cp310-cp310-win32.whl", hash = "sha256:2a15a08b17dd94c53a1da0438822d70ebcd13f8c3a95abe3a9ef9f11a94830aa", size = 14571, upload-time = "2025-09-27T18:36:14.779Z" }, + { url = "https://files.pythonhosted.org/packages/d6/25/55dc3ab959917602c96985cb1253efaa4ff42f71194bddeb61eb7278b8be/markupsafe-3.0.3-cp310-cp310-win_amd64.whl", hash = "sha256:c4ffb7ebf07cfe8931028e3e4c85f0357459a3f9f9490886198848f4fa002ec8", size = 15056, upload-time = "2025-09-27T18:36:16.125Z" }, + { url = "https://files.pythonhosted.org/packages/d0/9e/0a02226640c255d1da0b8d12e24ac2aa6734da68bff14c05dd53b94a0fc3/markupsafe-3.0.3-cp310-cp310-win_arm64.whl", hash = "sha256:e2103a929dfa2fcaf9bb4e7c091983a49c9ac3b19c9061b6d5427dd7d14d81a1", size = 13932, upload-time = "2025-09-27T18:36:17.311Z" }, + { url = "https://files.pythonhosted.org/packages/08/db/fefacb2136439fc8dd20e797950e749aa1f4997ed584c62cfb8ef7c2be0e/markupsafe-3.0.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1cc7ea17a6824959616c525620e387f6dd30fec8cb44f649e31712db02123dad", size = 11631, upload-time = "2025-09-27T18:36:18.185Z" }, + { url = "https://files.pythonhosted.org/packages/e1/2e/5898933336b61975ce9dc04decbc0a7f2fee78c30353c5efba7f2d6ff27a/markupsafe-3.0.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4bd4cd07944443f5a265608cc6aab442e4f74dff8088b0dfc8238647b8f6ae9a", size = 12058, upload-time = "2025-09-27T18:36:19.444Z" }, + { url = "https://files.pythonhosted.org/packages/1d/09/adf2df3699d87d1d8184038df46a9c80d78c0148492323f4693df54e17bb/markupsafe-3.0.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6b5420a1d9450023228968e7e6a9ce57f65d148ab56d2313fcd589eee96a7a50", size = 24287, upload-time = 
"2025-09-27T18:36:20.768Z" }, + { url = "https://files.pythonhosted.org/packages/30/ac/0273f6fcb5f42e314c6d8cd99effae6a5354604d461b8d392b5ec9530a54/markupsafe-3.0.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0bf2a864d67e76e5c9a34dc26ec616a66b9888e25e7b9460e1c76d3293bd9dbf", size = 22940, upload-time = "2025-09-27T18:36:22.249Z" }, + { url = "https://files.pythonhosted.org/packages/19/ae/31c1be199ef767124c042c6c3e904da327a2f7f0cd63a0337e1eca2967a8/markupsafe-3.0.3-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:bc51efed119bc9cfdf792cdeaa4d67e8f6fcccab66ed4bfdd6bde3e59bfcbb2f", size = 21887, upload-time = "2025-09-27T18:36:23.535Z" }, + { url = "https://files.pythonhosted.org/packages/b2/76/7edcab99d5349a4532a459e1fe64f0b0467a3365056ae550d3bcf3f79e1e/markupsafe-3.0.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:068f375c472b3e7acbe2d5318dea141359e6900156b5b2ba06a30b169086b91a", size = 23692, upload-time = "2025-09-27T18:36:24.823Z" }, + { url = "https://files.pythonhosted.org/packages/a4/28/6e74cdd26d7514849143d69f0bf2399f929c37dc2b31e6829fd2045b2765/markupsafe-3.0.3-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:7be7b61bb172e1ed687f1754f8e7484f1c8019780f6f6b0786e76bb01c2ae115", size = 21471, upload-time = "2025-09-27T18:36:25.95Z" }, + { url = "https://files.pythonhosted.org/packages/62/7e/a145f36a5c2945673e590850a6f8014318d5577ed7e5920a4b3448e0865d/markupsafe-3.0.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f9e130248f4462aaa8e2552d547f36ddadbeaa573879158d721bbd33dfe4743a", size = 22923, upload-time = "2025-09-27T18:36:27.109Z" }, + { url = "https://files.pythonhosted.org/packages/0f/62/d9c46a7f5c9adbeeeda52f5b8d802e1094e9717705a645efc71b0913a0a8/markupsafe-3.0.3-cp311-cp311-win32.whl", hash = "sha256:0db14f5dafddbb6d9208827849fad01f1a2609380add406671a26386cdf15a19", size = 14572, upload-time = "2025-09-27T18:36:28.045Z" }, + { url = 
"https://files.pythonhosted.org/packages/83/8a/4414c03d3f891739326e1783338e48fb49781cc915b2e0ee052aa490d586/markupsafe-3.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:de8a88e63464af587c950061a5e6a67d3632e36df62b986892331d4620a35c01", size = 15077, upload-time = "2025-09-27T18:36:29.025Z" }, + { url = "https://files.pythonhosted.org/packages/35/73/893072b42e6862f319b5207adc9ae06070f095b358655f077f69a35601f0/markupsafe-3.0.3-cp311-cp311-win_arm64.whl", hash = "sha256:3b562dd9e9ea93f13d53989d23a7e775fdfd1066c33494ff43f5418bc8c58a5c", size = 13876, upload-time = "2025-09-27T18:36:29.954Z" }, + { url = "https://files.pythonhosted.org/packages/5a/72/147da192e38635ada20e0a2e1a51cf8823d2119ce8883f7053879c2199b5/markupsafe-3.0.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d53197da72cc091b024dd97249dfc7794d6a56530370992a5e1a08983ad9230e", size = 11615, upload-time = "2025-09-27T18:36:30.854Z" }, + { url = "https://files.pythonhosted.org/packages/9a/81/7e4e08678a1f98521201c3079f77db69fb552acd56067661f8c2f534a718/markupsafe-3.0.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1872df69a4de6aead3491198eaf13810b565bdbeec3ae2dc8780f14458ec73ce", size = 12020, upload-time = "2025-09-27T18:36:31.971Z" }, + { url = "https://files.pythonhosted.org/packages/1e/2c/799f4742efc39633a1b54a92eec4082e4f815314869865d876824c257c1e/markupsafe-3.0.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3a7e8ae81ae39e62a41ec302f972ba6ae23a5c5396c8e60113e9066ef893da0d", size = 24332, upload-time = "2025-09-27T18:36:32.813Z" }, + { url = "https://files.pythonhosted.org/packages/3c/2e/8d0c2ab90a8c1d9a24f0399058ab8519a3279d1bd4289511d74e909f060e/markupsafe-3.0.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d6dd0be5b5b189d31db7cda48b91d7e0a9795f31430b7f271219ab30f1d3ac9d", size = 22947, upload-time = "2025-09-27T18:36:33.86Z" }, + { url = 
"https://files.pythonhosted.org/packages/2c/54/887f3092a85238093a0b2154bd629c89444f395618842e8b0c41783898ea/markupsafe-3.0.3-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:94c6f0bb423f739146aec64595853541634bde58b2135f27f61c1ffd1cd4d16a", size = 21962, upload-time = "2025-09-27T18:36:35.099Z" }, + { url = "https://files.pythonhosted.org/packages/c9/2f/336b8c7b6f4a4d95e91119dc8521402461b74a485558d8f238a68312f11c/markupsafe-3.0.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:be8813b57049a7dc738189df53d69395eba14fb99345e0a5994914a3864c8a4b", size = 23760, upload-time = "2025-09-27T18:36:36.001Z" }, + { url = "https://files.pythonhosted.org/packages/32/43/67935f2b7e4982ffb50a4d169b724d74b62a3964bc1a9a527f5ac4f1ee2b/markupsafe-3.0.3-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:83891d0e9fb81a825d9a6d61e3f07550ca70a076484292a70fde82c4b807286f", size = 21529, upload-time = "2025-09-27T18:36:36.906Z" }, + { url = "https://files.pythonhosted.org/packages/89/e0/4486f11e51bbba8b0c041098859e869e304d1c261e59244baa3d295d47b7/markupsafe-3.0.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:77f0643abe7495da77fb436f50f8dab76dbc6e5fd25d39589a0f1fe6548bfa2b", size = 23015, upload-time = "2025-09-27T18:36:37.868Z" }, + { url = "https://files.pythonhosted.org/packages/2f/e1/78ee7a023dac597a5825441ebd17170785a9dab23de95d2c7508ade94e0e/markupsafe-3.0.3-cp312-cp312-win32.whl", hash = "sha256:d88b440e37a16e651bda4c7c2b930eb586fd15ca7406cb39e211fcff3bf3017d", size = 14540, upload-time = "2025-09-27T18:36:38.761Z" }, + { url = "https://files.pythonhosted.org/packages/aa/5b/bec5aa9bbbb2c946ca2733ef9c4ca91c91b6a24580193e891b5f7dbe8e1e/markupsafe-3.0.3-cp312-cp312-win_amd64.whl", hash = "sha256:26a5784ded40c9e318cfc2bdb30fe164bdb8665ded9cd64d500a34fb42067b1c", size = 15105, upload-time = "2025-09-27T18:36:39.701Z" }, + { url = 
"https://files.pythonhosted.org/packages/e5/f1/216fc1bbfd74011693a4fd837e7026152e89c4bcf3e77b6692fba9923123/markupsafe-3.0.3-cp312-cp312-win_arm64.whl", hash = "sha256:35add3b638a5d900e807944a078b51922212fb3dedb01633a8defc4b01a3c85f", size = 13906, upload-time = "2025-09-27T18:36:40.689Z" }, + { url = "https://files.pythonhosted.org/packages/38/2f/907b9c7bbba283e68f20259574b13d005c121a0fa4c175f9bed27c4597ff/markupsafe-3.0.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e1cf1972137e83c5d4c136c43ced9ac51d0e124706ee1c8aa8532c1287fa8795", size = 11622, upload-time = "2025-09-27T18:36:41.777Z" }, + { url = "https://files.pythonhosted.org/packages/9c/d9/5f7756922cdd676869eca1c4e3c0cd0df60ed30199ffd775e319089cb3ed/markupsafe-3.0.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:116bb52f642a37c115f517494ea5feb03889e04df47eeff5b130b1808ce7c219", size = 12029, upload-time = "2025-09-27T18:36:43.257Z" }, + { url = "https://files.pythonhosted.org/packages/00/07/575a68c754943058c78f30db02ee03a64b3c638586fba6a6dd56830b30a3/markupsafe-3.0.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:133a43e73a802c5562be9bbcd03d090aa5a1fe899db609c29e8c8d815c5f6de6", size = 24374, upload-time = "2025-09-27T18:36:44.508Z" }, + { url = "https://files.pythonhosted.org/packages/a9/21/9b05698b46f218fc0e118e1f8168395c65c8a2c750ae2bab54fc4bd4e0e8/markupsafe-3.0.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ccfcd093f13f0f0b7fdd0f198b90053bf7b2f02a3927a30e63f3ccc9df56b676", size = 22980, upload-time = "2025-09-27T18:36:45.385Z" }, + { url = "https://files.pythonhosted.org/packages/7f/71/544260864f893f18b6827315b988c146b559391e6e7e8f7252839b1b846a/markupsafe-3.0.3-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:509fa21c6deb7a7a273d629cf5ec029bc209d1a51178615ddf718f5918992ab9", size = 21990, upload-time = "2025-09-27T18:36:46.916Z" }, + { url = 
"https://files.pythonhosted.org/packages/c2/28/b50fc2f74d1ad761af2f5dcce7492648b983d00a65b8c0e0cb457c82ebbe/markupsafe-3.0.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a4afe79fb3de0b7097d81da19090f4df4f8d3a2b3adaa8764138aac2e44f3af1", size = 23784, upload-time = "2025-09-27T18:36:47.884Z" }, + { url = "https://files.pythonhosted.org/packages/ed/76/104b2aa106a208da8b17a2fb72e033a5a9d7073c68f7e508b94916ed47a9/markupsafe-3.0.3-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:795e7751525cae078558e679d646ae45574b47ed6e7771863fcc079a6171a0fc", size = 21588, upload-time = "2025-09-27T18:36:48.82Z" }, + { url = "https://files.pythonhosted.org/packages/b5/99/16a5eb2d140087ebd97180d95249b00a03aa87e29cc224056274f2e45fd6/markupsafe-3.0.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:8485f406a96febb5140bfeca44a73e3ce5116b2501ac54fe953e488fb1d03b12", size = 23041, upload-time = "2025-09-27T18:36:49.797Z" }, + { url = "https://files.pythonhosted.org/packages/19/bc/e7140ed90c5d61d77cea142eed9f9c303f4c4806f60a1044c13e3f1471d0/markupsafe-3.0.3-cp313-cp313-win32.whl", hash = "sha256:bdd37121970bfd8be76c5fb069c7751683bdf373db1ed6c010162b2a130248ed", size = 14543, upload-time = "2025-09-27T18:36:51.584Z" }, + { url = "https://files.pythonhosted.org/packages/05/73/c4abe620b841b6b791f2edc248f556900667a5a1cf023a6646967ae98335/markupsafe-3.0.3-cp313-cp313-win_amd64.whl", hash = "sha256:9a1abfdc021a164803f4d485104931fb8f8c1efd55bc6b748d2f5774e78b62c5", size = 15113, upload-time = "2025-09-27T18:36:52.537Z" }, + { url = "https://files.pythonhosted.org/packages/f0/3a/fa34a0f7cfef23cf9500d68cb7c32dd64ffd58a12b09225fb03dd37d5b80/markupsafe-3.0.3-cp313-cp313-win_arm64.whl", hash = "sha256:7e68f88e5b8799aa49c85cd116c932a1ac15caaa3f5db09087854d218359e485", size = 13911, upload-time = "2025-09-27T18:36:53.513Z" }, + { url = 
"https://files.pythonhosted.org/packages/e4/d7/e05cd7efe43a88a17a37b3ae96e79a19e846f3f456fe79c57ca61356ef01/markupsafe-3.0.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:218551f6df4868a8d527e3062d0fb968682fe92054e89978594c28e642c43a73", size = 11658, upload-time = "2025-09-27T18:36:54.819Z" }, + { url = "https://files.pythonhosted.org/packages/99/9e/e412117548182ce2148bdeacdda3bb494260c0b0184360fe0d56389b523b/markupsafe-3.0.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:3524b778fe5cfb3452a09d31e7b5adefeea8c5be1d43c4f810ba09f2ceb29d37", size = 12066, upload-time = "2025-09-27T18:36:55.714Z" }, + { url = "https://files.pythonhosted.org/packages/bc/e6/fa0ffcda717ef64a5108eaa7b4f5ed28d56122c9a6d70ab8b72f9f715c80/markupsafe-3.0.3-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4e885a3d1efa2eadc93c894a21770e4bc67899e3543680313b09f139e149ab19", size = 25639, upload-time = "2025-09-27T18:36:56.908Z" }, + { url = "https://files.pythonhosted.org/packages/96/ec/2102e881fe9d25fc16cb4b25d5f5cde50970967ffa5dddafdb771237062d/markupsafe-3.0.3-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8709b08f4a89aa7586de0aadc8da56180242ee0ada3999749b183aa23df95025", size = 23569, upload-time = "2025-09-27T18:36:57.913Z" }, + { url = "https://files.pythonhosted.org/packages/4b/30/6f2fce1f1f205fc9323255b216ca8a235b15860c34b6798f810f05828e32/markupsafe-3.0.3-cp313-cp313t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:b8512a91625c9b3da6f127803b166b629725e68af71f8184ae7e7d54686a56d6", size = 23284, upload-time = "2025-09-27T18:36:58.833Z" }, + { url = "https://files.pythonhosted.org/packages/58/47/4a0ccea4ab9f5dcb6f79c0236d954acb382202721e704223a8aafa38b5c8/markupsafe-3.0.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:9b79b7a16f7fedff2495d684f2b59b0457c3b493778c9eed31111be64d58279f", size = 24801, upload-time = "2025-09-27T18:36:59.739Z" }, + { 
url = "https://files.pythonhosted.org/packages/6a/70/3780e9b72180b6fecb83a4814d84c3bf4b4ae4bf0b19c27196104149734c/markupsafe-3.0.3-cp313-cp313t-musllinux_1_2_riscv64.whl", hash = "sha256:12c63dfb4a98206f045aa9563db46507995f7ef6d83b2f68eda65c307c6829eb", size = 22769, upload-time = "2025-09-27T18:37:00.719Z" }, + { url = "https://files.pythonhosted.org/packages/98/c5/c03c7f4125180fc215220c035beac6b9cb684bc7a067c84fc69414d315f5/markupsafe-3.0.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:8f71bc33915be5186016f675cd83a1e08523649b0e33efdb898db577ef5bb009", size = 23642, upload-time = "2025-09-27T18:37:01.673Z" }, + { url = "https://files.pythonhosted.org/packages/80/d6/2d1b89f6ca4bff1036499b1e29a1d02d282259f3681540e16563f27ebc23/markupsafe-3.0.3-cp313-cp313t-win32.whl", hash = "sha256:69c0b73548bc525c8cb9a251cddf1931d1db4d2258e9599c28c07ef3580ef354", size = 14612, upload-time = "2025-09-27T18:37:02.639Z" }, + { url = "https://files.pythonhosted.org/packages/2b/98/e48a4bfba0a0ffcf9925fe2d69240bfaa19c6f7507b8cd09c70684a53c1e/markupsafe-3.0.3-cp313-cp313t-win_amd64.whl", hash = "sha256:1b4b79e8ebf6b55351f0d91fe80f893b4743f104bff22e90697db1590e47a218", size = 15200, upload-time = "2025-09-27T18:37:03.582Z" }, + { url = "https://files.pythonhosted.org/packages/0e/72/e3cc540f351f316e9ed0f092757459afbc595824ca724cbc5a5d4263713f/markupsafe-3.0.3-cp313-cp313t-win_arm64.whl", hash = "sha256:ad2cf8aa28b8c020ab2fc8287b0f823d0a7d8630784c31e9ee5edea20f406287", size = 13973, upload-time = "2025-09-27T18:37:04.929Z" }, + { url = "https://files.pythonhosted.org/packages/33/8a/8e42d4838cd89b7dde187011e97fe6c3af66d8c044997d2183fbd6d31352/markupsafe-3.0.3-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:eaa9599de571d72e2daf60164784109f19978b327a3910d3e9de8c97b5b70cfe", size = 11619, upload-time = "2025-09-27T18:37:06.342Z" }, + { url = 
"https://files.pythonhosted.org/packages/b5/64/7660f8a4a8e53c924d0fa05dc3a55c9cee10bbd82b11c5afb27d44b096ce/markupsafe-3.0.3-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:c47a551199eb8eb2121d4f0f15ae0f923d31350ab9280078d1e5f12b249e0026", size = 12029, upload-time = "2025-09-27T18:37:07.213Z" }, + { url = "https://files.pythonhosted.org/packages/da/ef/e648bfd021127bef5fa12e1720ffed0c6cbb8310c8d9bea7266337ff06de/markupsafe-3.0.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f34c41761022dd093b4b6896d4810782ffbabe30f2d443ff5f083e0cbbb8c737", size = 24408, upload-time = "2025-09-27T18:37:09.572Z" }, + { url = "https://files.pythonhosted.org/packages/41/3c/a36c2450754618e62008bf7435ccb0f88053e07592e6028a34776213d877/markupsafe-3.0.3-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:457a69a9577064c05a97c41f4e65148652db078a3a509039e64d3467b9e7ef97", size = 23005, upload-time = "2025-09-27T18:37:10.58Z" }, + { url = "https://files.pythonhosted.org/packages/bc/20/b7fdf89a8456b099837cd1dc21974632a02a999ec9bf7ca3e490aacd98e7/markupsafe-3.0.3-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:e8afc3f2ccfa24215f8cb28dcf43f0113ac3c37c2f0f0806d8c70e4228c5cf4d", size = 22048, upload-time = "2025-09-27T18:37:11.547Z" }, + { url = "https://files.pythonhosted.org/packages/9a/a7/591f592afdc734f47db08a75793a55d7fbcc6902a723ae4cfbab61010cc5/markupsafe-3.0.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:ec15a59cf5af7be74194f7ab02d0f59a62bdcf1a537677ce67a2537c9b87fcda", size = 23821, upload-time = "2025-09-27T18:37:12.48Z" }, + { url = "https://files.pythonhosted.org/packages/7d/33/45b24e4f44195b26521bc6f1a82197118f74df348556594bd2262bda1038/markupsafe-3.0.3-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:0eb9ff8191e8498cca014656ae6b8d61f39da5f95b488805da4bb029cccbfbaf", size = 21606, upload-time = "2025-09-27T18:37:13.485Z" }, + { url = 
"https://files.pythonhosted.org/packages/ff/0e/53dfaca23a69fbfbbf17a4b64072090e70717344c52eaaaa9c5ddff1e5f0/markupsafe-3.0.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:2713baf880df847f2bece4230d4d094280f4e67b1e813eec43b4c0e144a34ffe", size = 23043, upload-time = "2025-09-27T18:37:14.408Z" }, + { url = "https://files.pythonhosted.org/packages/46/11/f333a06fc16236d5238bfe74daccbca41459dcd8d1fa952e8fbd5dccfb70/markupsafe-3.0.3-cp314-cp314-win32.whl", hash = "sha256:729586769a26dbceff69f7a7dbbf59ab6572b99d94576a5592625d5b411576b9", size = 14747, upload-time = "2025-09-27T18:37:15.36Z" }, + { url = "https://files.pythonhosted.org/packages/28/52/182836104b33b444e400b14f797212f720cbc9ed6ba34c800639d154e821/markupsafe-3.0.3-cp314-cp314-win_amd64.whl", hash = "sha256:bdc919ead48f234740ad807933cdf545180bfbe9342c2bb451556db2ed958581", size = 15341, upload-time = "2025-09-27T18:37:16.496Z" }, + { url = "https://files.pythonhosted.org/packages/6f/18/acf23e91bd94fd7b3031558b1f013adfa21a8e407a3fdb32745538730382/markupsafe-3.0.3-cp314-cp314-win_arm64.whl", hash = "sha256:5a7d5dc5140555cf21a6fefbdbf8723f06fcd2f63ef108f2854de715e4422cb4", size = 14073, upload-time = "2025-09-27T18:37:17.476Z" }, + { url = "https://files.pythonhosted.org/packages/3c/f0/57689aa4076e1b43b15fdfa646b04653969d50cf30c32a102762be2485da/markupsafe-3.0.3-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:1353ef0c1b138e1907ae78e2f6c63ff67501122006b0f9abad68fda5f4ffc6ab", size = 11661, upload-time = "2025-09-27T18:37:18.453Z" }, + { url = "https://files.pythonhosted.org/packages/89/c3/2e67a7ca217c6912985ec766c6393b636fb0c2344443ff9d91404dc4c79f/markupsafe-3.0.3-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:1085e7fbddd3be5f89cc898938f42c0b3c711fdcb37d75221de2666af647c175", size = 12069, upload-time = "2025-09-27T18:37:19.332Z" }, + { url = 
"https://files.pythonhosted.org/packages/f0/00/be561dce4e6ca66b15276e184ce4b8aec61fe83662cce2f7d72bd3249d28/markupsafe-3.0.3-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1b52b4fb9df4eb9ae465f8d0c228a00624de2334f216f178a995ccdcf82c4634", size = 25670, upload-time = "2025-09-27T18:37:20.245Z" }, + { url = "https://files.pythonhosted.org/packages/50/09/c419f6f5a92e5fadde27efd190eca90f05e1261b10dbd8cbcb39cd8ea1dc/markupsafe-3.0.3-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fed51ac40f757d41b7c48425901843666a6677e3e8eb0abcff09e4ba6e664f50", size = 23598, upload-time = "2025-09-27T18:37:21.177Z" }, + { url = "https://files.pythonhosted.org/packages/22/44/a0681611106e0b2921b3033fc19bc53323e0b50bc70cffdd19f7d679bb66/markupsafe-3.0.3-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:f190daf01f13c72eac4efd5c430a8de82489d9cff23c364c3ea822545032993e", size = 23261, upload-time = "2025-09-27T18:37:22.167Z" }, + { url = "https://files.pythonhosted.org/packages/5f/57/1b0b3f100259dc9fffe780cfb60d4be71375510e435efec3d116b6436d43/markupsafe-3.0.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:e56b7d45a839a697b5eb268c82a71bd8c7f6c94d6fd50c3d577fa39a9f1409f5", size = 24835, upload-time = "2025-09-27T18:37:23.296Z" }, + { url = "https://files.pythonhosted.org/packages/26/6a/4bf6d0c97c4920f1597cc14dd720705eca0bf7c787aebc6bb4d1bead5388/markupsafe-3.0.3-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:f3e98bb3798ead92273dc0e5fd0f31ade220f59a266ffd8a4f6065e0a3ce0523", size = 22733, upload-time = "2025-09-27T18:37:24.237Z" }, + { url = "https://files.pythonhosted.org/packages/14/c7/ca723101509b518797fedc2fdf79ba57f886b4aca8a7d31857ba3ee8281f/markupsafe-3.0.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:5678211cb9333a6468fb8d8be0305520aa073f50d17f089b5b4b477ea6e67fdc", size = 23672, upload-time = "2025-09-27T18:37:25.271Z" }, + 
{ url = "https://files.pythonhosted.org/packages/fb/df/5bd7a48c256faecd1d36edc13133e51397e41b73bb77e1a69deab746ebac/markupsafe-3.0.3-cp314-cp314t-win32.whl", hash = "sha256:915c04ba3851909ce68ccc2b8e2cd691618c4dc4c4232fb7982bca3f41fd8c3d", size = 14819, upload-time = "2025-09-27T18:37:26.285Z" }, + { url = "https://files.pythonhosted.org/packages/1a/8a/0402ba61a2f16038b48b39bccca271134be00c5c9f0f623208399333c448/markupsafe-3.0.3-cp314-cp314t-win_amd64.whl", hash = "sha256:4faffd047e07c38848ce017e8725090413cd80cbc23d86e55c587bf979e579c9", size = 15426, upload-time = "2025-09-27T18:37:27.316Z" }, + { url = "https://files.pythonhosted.org/packages/70/bc/6f1c2f612465f5fa89b95bead1f44dcb607670fd42891d8fdcd5d039f4f4/markupsafe-3.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:32001d6a8fc98c8cb5c947787c5d08b0a50663d139f1305bac5885d98d9b40fa", size = 14146, upload-time = "2025-09-27T18:37:28.327Z" }, +] + +[[package]] +name = "mpmath" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e0/47/dd32fa426cc72114383ac549964eecb20ecfd886d1e5ccf5340b55b02f57/mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f", size = 508106, upload-time = "2023-03-07T16:47:11.061Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/e3/7d92a15f894aa0c9c4b49b8ee9ac9850d6e63b03c9c32c0367a13ae62209/mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c", size = 536198, upload-time = "2023-03-07T16:47:09.197Z" }, +] + +[[package]] +name = "networkx" +version = "3.4.2" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.11'", +] +sdist = { url = "https://files.pythonhosted.org/packages/fd/1d/06475e1cd5264c0b870ea2cc6fdb3e37177c1e565c43f56ff17a10e3937f/networkx-3.4.2.tar.gz", hash = "sha256:307c3669428c5362aab27c8a1260aa8f47c4e91d3891f48be0141738d8d053e1", 
size = 2151368, upload-time = "2024-10-21T12:39:38.695Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl", hash = "sha256:df5d4365b724cf81b8c6a7312509d0c22386097011ad1abe274afd5e9d3bbc5f", size = 1723263, upload-time = "2024-10-21T12:39:36.247Z" }, +] + +[[package]] +name = "networkx" +version = "3.6.1" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.11'", +] +sdist = { url = "https://files.pythonhosted.org/packages/6a/51/63fe664f3908c97be9d2e4f1158eb633317598cfa6e1fc14af5383f17512/networkx-3.6.1.tar.gz", hash = "sha256:26b7c357accc0c8cde558ad486283728b65b6a95d85ee1cd66bafab4c8168509", size = 2517025, upload-time = "2025-12-08T17:02:39.908Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9e/c9/b2622292ea83fbb4ec318f5b9ab867d0a28ab43c5717bb85b0a5f6b3b0a4/networkx-3.6.1-py3-none-any.whl", hash = "sha256:d47fbf302e7d9cbbb9e2555a0d267983d2aa476bac30e90dfbe5669bd57f3762", size = 2068504, upload-time = "2025-12-08T17:02:38.159Z" }, +] + +[[package]] +name = "nvidia-cublas" +version = "13.1.0.3" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e1/a5/fce49e2ae977e0ccc084e5adafceb4f0ac0c8333cb6863501618a7277f67/nvidia_cublas-13.1.0.3-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:c86fc7f7ae36d7528288c5d88098edcb7b02c633d262e7ddbb86b0ad91be5df2", size = 542851226, upload-time = "2025-10-09T08:59:04.818Z" }, + { url = "https://files.pythonhosted.org/packages/e7/44/423ac00af4dd95a5aeb27207e2c0d9b7118702149bf4704c3ddb55bb7429/nvidia_cublas-13.1.0.3-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:ee8722c1f0145ab246bccb9e452153b5e0515fd094c3678df50b2a0888b8b171", size = 423133236, upload-time = "2025-10-09T08:59:32.536Z" }, +] + +[[package]] +name = "nvidia-cuda-cupti" +version = "13.0.85" +source = { registry = 
"https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2a/2a/80353b103fc20ce05ef51e928daed4b6015db4aaa9162ed0997090fe2250/nvidia_cuda_cupti-13.0.85-py3-none-manylinux_2_25_aarch64.whl", hash = "sha256:796bd679890ee55fb14a94629b698b6db54bcfd833d391d5e94017dd9d7d3151", size = 10310827, upload-time = "2025-09-04T08:26:42.012Z" }, + { url = "https://files.pythonhosted.org/packages/33/6d/737d164b4837a9bbd202f5ae3078975f0525a55730fe871d8ed4e3b952b0/nvidia_cuda_cupti-13.0.85-py3-none-manylinux_2_25_x86_64.whl", hash = "sha256:4eb01c08e859bf924d222250d2e8f8b8ff6d3db4721288cf35d14252a4d933c8", size = 10715597, upload-time = "2025-09-04T08:26:51.312Z" }, +] + +[[package]] +name = "nvidia-cuda-nvrtc" +version = "13.0.88" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c3/68/483a78f5e8f31b08fb1bb671559968c0ca3a065ac7acabfc7cee55214fd6/nvidia_cuda_nvrtc-13.0.88-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:ad9b6d2ead2435f11cbb6868809d2adeeee302e9bb94bcf0539c7a40d80e8575", size = 90215200, upload-time = "2025-09-04T08:28:44.204Z" }, + { url = "https://files.pythonhosted.org/packages/b7/dc/6bb80850e0b7edd6588d560758f17e0550893a1feaf436807d64d2da040f/nvidia_cuda_nvrtc-13.0.88-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d27f20a0ca67a4bb34268a5e951033496c5b74870b868bacd046b1b8e0c3267b", size = 43015449, upload-time = "2025-09-04T08:28:20.239Z" }, +] + +[[package]] +name = "nvidia-cuda-runtime" +version = "13.0.96" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/87/4f/17d7b9b8e285199c58ce28e31b5c5bbaa4d8271af06a89b6405258245de2/nvidia_cuda_runtime-13.0.96-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ef9bcbe90493a2b9d810e43d249adb3d02e98dd30200d86607d8d02687c43f55", size = 2261060, upload-time = "2025-10-09T08:55:15.78Z" }, + { url = 
"https://files.pythonhosted.org/packages/2e/24/d1558f3b68b1d26e706813b1d10aa1d785e4698c425af8db8edc3dced472/nvidia_cuda_runtime-13.0.96-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7f82250d7782aa23b6cfe765ecc7db554bd3c2870c43f3d1821f1d18aebf0548", size = 2243632, upload-time = "2025-10-09T08:55:36.117Z" }, +] + +[[package]] +name = "nvidia-cudnn-cu13" +version = "9.19.0.56" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-cublas" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/f1/84/26025437c1e6b61a707442184fa0c03d083b661adf3a3eecfd6d21677740/nvidia_cudnn_cu13-9.19.0.56-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:6ed29ffaee1176c612daf442e4dd6cfeb6a0caa43ddcbeb59da94953030b1be4", size = 433781201, upload-time = "2026-02-03T20:40:53.805Z" }, + { url = "https://files.pythonhosted.org/packages/a3/22/0b4b932655d17a6da1b92fa92ab12844b053bb2ac2475e179ba6f043da1e/nvidia_cudnn_cu13-9.19.0.56-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:d20e1734305e9d68889a96e3f35094d733ff1f83932ebe462753973e53a572bf", size = 366066321, upload-time = "2026-02-03T20:44:52.837Z" }, +] + +[[package]] +name = "nvidia-cufft" +version = "12.0.0.61" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-nvjitlink" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/8b/ae/f417a75c0259e85c1d2f83ca4e960289a5f814ed0cea74d18c353d3e989d/nvidia_cufft-12.0.0.61-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2708c852ef8cd89d1d2068bdbece0aa188813a0c934db3779b9b1faa8442e5f5", size = 214053554, upload-time = "2025-09-04T08:31:38.196Z" }, + { url = "https://files.pythonhosted.org/packages/a8/2f/7b57e29836ea8714f81e9898409196f47d772d5ddedddf1592eadb8ab743/nvidia_cufft-12.0.0.61-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6c44f692dce8fd5ffd3e3df134b6cdb9c2f72d99cf40b62c32dde45eea9ddad3", size = 214085489, 
upload-time = "2025-09-04T08:31:56.044Z" }, +] + +[[package]] +name = "nvidia-cufile" +version = "1.15.1.6" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3f/70/4f193de89a48b71714e74602ee14d04e4019ad36a5a9f20c425776e72cd6/nvidia_cufile-1.15.1.6-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:08a3ecefae5a01c7f5117351c64f17c7c62efa5fffdbe24fc7d298da19cd0b44", size = 1223672, upload-time = "2025-09-04T08:32:22.779Z" }, + { url = "https://files.pythonhosted.org/packages/ab/73/cc4a14c9813a8a0d509417cf5f4bdaba76e924d58beb9864f5a7baceefbf/nvidia_cufile-1.15.1.6-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:bdc0deedc61f548bddf7733bdc216456c2fdb101d020e1ab4b88d232d5e2f6d1", size = 1136992, upload-time = "2025-09-04T08:32:14.119Z" }, +] + +[[package]] +name = "nvidia-curand" +version = "10.4.0.35" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/72/7c2ae24fb6b63a32e6ae5d241cc65263ea18d08802aaae087d9f013335a2/nvidia_curand-10.4.0.35-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:133df5a7509c3e292aaa2b477afd0194f06ce4ea24d714d616ff36439cee349a", size = 61962106, upload-time = "2025-08-04T10:21:41.128Z" }, + { url = "https://files.pythonhosted.org/packages/a5/9f/be0a41ca4a4917abf5cb9ae0daff1a6060cc5de950aec0396de9f3b52bc5/nvidia_curand-10.4.0.35-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:1aee33a5da6e1db083fe2b90082def8915f30f3248d5896bcec36a579d941bfc", size = 59544258, upload-time = "2025-08-04T10:22:03.992Z" }, +] + +[[package]] +name = "nvidia-cusolver" +version = "12.0.4.66" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-cublas" }, + { name = "nvidia-cusparse" }, + { name = "nvidia-nvjitlink" }, +] +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/c8/c3/b30c9e935fc01e3da443ec0116ed1b2a009bb867f5324d3f2d7e533e776b/nvidia_cusolver-12.0.4.66-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:02c2457eaa9e39de20f880f4bd8820e6a1cfb9f9a34f820eb12a155aa5bc92d2", size = 223467760, upload-time = "2025-09-04T08:33:04.222Z" }, + { url = "https://files.pythonhosted.org/packages/5f/67/cba3777620cdacb99102da4042883709c41c709f4b6323c10781a9c3aa34/nvidia_cusolver-12.0.4.66-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:0a759da5dea5c0ea10fd307de75cdeb59e7ea4fcb8add0924859b944babf1112", size = 200941980, upload-time = "2025-09-04T08:33:22.767Z" }, +] + +[[package]] +name = "nvidia-cusparse" +version = "12.6.3.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-nvjitlink" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/94/5c26f33738ae35276672f12615a64bd008ed5be6d1ebcb23579285d960a9/nvidia_cusparse-12.6.3.3-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:80bcc4662f23f1054ee334a15c72b8940402975e0eab63178fc7e670aa59472c", size = 162155568, upload-time = "2025-09-04T08:33:42.864Z" }, + { url = "https://files.pythonhosted.org/packages/fa/18/623c77619c31d62efd55302939756966f3ecc8d724a14dab2b75f1508850/nvidia_cusparse-12.6.3.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:2b3c89c88d01ee0e477cb7f82ef60a11a4bcd57b6b87c33f789350b59759360b", size = 145942937, upload-time = "2025-09-04T08:33:58.029Z" }, +] + +[[package]] +name = "nvidia-cusparselt-cu13" +version = "0.8.0" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/46/10/8dcd1175260706a2fc92a16a52e306b71d4c1ea0b0cc4a9484183399818a/nvidia_cusparselt_cu13-0.8.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:400c6ed1cf6780fc6efedd64ec9f1345871767e6a1a0a552a1ea0578117ea77c", size = 220791277, upload-time = "2025-08-13T19:22:40.982Z" }, + { url = 
"https://files.pythonhosted.org/packages/fd/53/43b0d71f4e702fa9733f8b4571fdca50a8813f1e450b656c239beff12315/nvidia_cusparselt_cu13-0.8.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:25e30a8a7323935d4ad0340b95a0b69926eee755767e8e0b1cf8dd85b197d3fd", size = 169884119, upload-time = "2025-08-13T19:23:41.967Z" }, +] + +[[package]] +name = "nvidia-nccl-cu13" +version = "2.28.9" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/39/55/1920646a2e43ffd4fc958536b276197ed740e9e0c54105b4bb3521591fc7/nvidia_nccl_cu13-2.28.9-py3-none-manylinux_2_18_aarch64.whl", hash = "sha256:01c873ba1626b54caa12272ed228dc5b2781545e0ae8ba3f432a8ef1c6d78643", size = 196561677, upload-time = "2025-11-18T05:49:03.45Z" }, + { url = "https://files.pythonhosted.org/packages/b0/b4/878fefaad5b2bcc6fcf8d474a25e3e3774bc5133e4b58adff4d0bca238bc/nvidia_nccl_cu13-2.28.9-py3-none-manylinux_2_18_x86_64.whl", hash = "sha256:e4553a30f34195f3fa1da02a6da3d6337d28f2003943aa0a3d247bbc25fefc42", size = 196493177, upload-time = "2025-11-18T05:49:17.677Z" }, +] + +[[package]] +name = "nvidia-nvjitlink" +version = "13.0.88" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/56/7a/123e033aaff487c77107195fa5a2b8686795ca537935a24efae476c41f05/nvidia_nvjitlink-13.0.88-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:13a74f429e23b921c1109976abefacc69835f2f433ebd323d3946e11d804e47b", size = 40713933, upload-time = "2025-09-04T08:35:43.553Z" }, + { url = "https://files.pythonhosted.org/packages/ab/2c/93c5250e64df4f894f1cbb397c6fd71f79813f9fd79d7cd61de3f97b3c2d/nvidia_nvjitlink-13.0.88-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e931536ccc7d467a98ba1d8b89ff7fa7f1fa3b13f2b0069118cd7f47bff07d0c", size = 38768748, upload-time = "2025-09-04T08:35:20.008Z" }, +] + +[[package]] +name = "nvidia-nvshmem-cu13" +version = "3.4.5" +source = { registry = 
"https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/dc/0f/05cc9c720236dcd2db9c1ab97fff629e96821be2e63103569da0c9b72f19/nvidia_nvshmem_cu13-3.4.5-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:6dc2a197f38e5d0376ad52cd1a2a3617d3cdc150fd5966f4aee9bcebb1d68fe9", size = 60215947, upload-time = "2025-09-06T00:32:20.022Z" }, + { url = "https://files.pythonhosted.org/packages/3c/35/a9bf80a609e74e3b000fef598933235c908fcefcef9026042b8e6dfde2a9/nvidia_nvshmem_cu13-3.4.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:290f0a2ee94c9f3687a02502f3b9299a9f9fe826e6d0287ee18482e78d495b80", size = 60412546, upload-time = "2025-09-06T00:32:41.564Z" }, +] + +[[package]] +name = "nvidia-nvtx" +version = "13.0.85" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c2/f3/d86c845465a2723ad7e1e5c36dcd75ddb82898b3f53be47ebd429fb2fa5d/nvidia_nvtx-13.0.85-py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:4936d1d6780fbe68db454f5e72a42ff64d1fd6397df9f363ae786930fd5c1cd4", size = 148047, upload-time = "2025-09-04T08:29:01.761Z" }, + { url = "https://files.pythonhosted.org/packages/a8/64/3708a90d1ebe202ffdeb7185f878a3c84d15c2b2c31858da2ce0583e2def/nvidia_nvtx-13.0.85-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cb7780edb6b14107373c835bf8b72e7a178bac7367e23da7acb108f973f157a6", size = 148878, upload-time = "2025-09-04T08:28:53.627Z" }, +] + +[[package]] +name = "packaging" +version = "26.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/65/ee/299d360cdc32edc7d2cf530f3accf79c4fca01e96ffc950d8a52213bd8e4/packaging-26.0.tar.gz", hash = "sha256:00243ae351a257117b6a241061796684b084ed1c516a08c48a3f7e147a9d80b4", size = 143416, upload-time = "2026-01-21T20:50:39.064Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/b7/b9/c538f279a4e237a006a2c98387d081e9eb060d203d8ed34467cc0f0b9b53/packaging-26.0-py3-none-any.whl", hash = "sha256:b36f1fef9334a5588b4166f8bcd26a14e521f2b55e6b9de3aaa80d3ff7a37529", size = 74366, upload-time = "2026-01-21T20:50:37.788Z" }, +] + +[[package]] +name = "pluggy" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, +] + +[[package]] +name = "pygments" +version = "2.20.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c3/b2/bc9c9196916376152d655522fdcebac55e66de6603a76a02bca1b6414f6c/pygments-2.20.0.tar.gz", hash = "sha256:6757cd03768053ff99f3039c1a36d6c0aa0b263438fcab17520b30a303a82b5f", size = 4955991, upload-time = "2026-03-29T13:29:33.898Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f4/7e/a72dd26f3b0f4f2bf1dd8923c85f7ceb43172af56d63c7383eb62b332364/pygments-2.20.0-py3-none-any.whl", hash = "sha256:81a9e26dd42fd28a23a2d169d86d7ac03b46e2f8b59ed4698fb4785f946d0176", size = 1231151, upload-time = "2026-03-29T13:29:30.038Z" }, +] + +[[package]] +name = "pytest" +version = "9.0.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = 
"pluggy" }, + { name = "pygments" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7d/0d/549bd94f1a0a402dc8cf64563a117c0f3765662e2e668477624baeec44d5/pytest-9.0.3.tar.gz", hash = "sha256:b86ada508af81d19edeb213c681b1d48246c1a91d304c6c81a427674c17eb91c", size = 1572165, upload-time = "2026-04-07T17:16:18.027Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d4/24/a372aaf5c9b7208e7112038812994107bc65a84cd00e0354a88c2c77a617/pytest-9.0.3-py3-none-any.whl", hash = "sha256:2c5efc453d45394fdd706ade797c0a81091eccd1d6e4bccfcd476e2b8e0ab5d9", size = 375249, upload-time = "2026-04-07T17:16:16.13Z" }, +] + +[[package]] +name = "setuptools" +version = "81.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0d/1c/73e719955c59b8e424d015ab450f51c0af856ae46ea2da83eba51cc88de1/setuptools-81.0.0.tar.gz", hash = "sha256:487b53915f52501f0a79ccfd0c02c165ffe06631443a886740b91af4b7a5845a", size = 1198299, upload-time = "2026-02-06T21:10:39.601Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e1/e3/c164c88b2e5ce7b24d667b9bd83589cf4f3520d97cad01534cd3c4f55fdb/setuptools-81.0.0-py3-none-any.whl", hash = "sha256:fdd925d5c5d9f62e4b74b30d6dd7828ce236fd6ed998a08d81de62ce5a6310d6", size = 1062021, upload-time = "2026-02-06T21:10:37.175Z" }, +] + +[[package]] +name = "sympy" +version = "1.14.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mpmath" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/83/d3/803453b36afefb7c2bb238361cd4ae6125a569b4db67cd9e79846ba2d68c/sympy-1.14.0.tar.gz", hash = "sha256:d3d3fe8df1e5a0b42f0e7bdf50541697dbe7d23746e894990c030e2b05e72517", size = 7793921, upload-time = "2025-04-27T18:05:01.611Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/a2/09/77d55d46fd61b4a135c444fc97158ef34a095e5681d0a6c10b75bf356191/sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5", size = 6299353, upload-time = "2025-04-27T18:04:59.103Z" }, +] + +[[package]] +name = "tomli" +version = "2.4.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/22/de/48c59722572767841493b26183a0d1cc411d54fd759c5607c4590b6563a6/tomli-2.4.1.tar.gz", hash = "sha256:7c7e1a961a0b2f2472c1ac5b69affa0ae1132c39adcb67aba98568702b9cc23f", size = 17543, upload-time = "2026-03-25T20:22:03.828Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f4/11/db3d5885d8528263d8adc260bb2d28ebf1270b96e98f0e0268d32b8d9900/tomli-2.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f8f0fc26ec2cc2b965b7a3b87cd19c5c6b8c5e5f436b984e85f486d652285c30", size = 154704, upload-time = "2026-03-25T20:21:10.473Z" }, + { url = "https://files.pythonhosted.org/packages/6d/f7/675db52c7e46064a9aa928885a9b20f4124ecb9bc2e1ce74c9106648d202/tomli-2.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4ab97e64ccda8756376892c53a72bd1f964e519c77236368527f758fbc36a53a", size = 149454, upload-time = "2026-03-25T20:21:12.036Z" }, + { url = "https://files.pythonhosted.org/packages/61/71/81c50943cf953efa35bce7646caab3cf457a7d8c030b27cfb40d7235f9ee/tomli-2.4.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:96481a5786729fd470164b47cdb3e0e58062a496f455ee41b4403be77cb5a076", size = 237561, upload-time = "2026-03-25T20:21:13.098Z" }, + { url = "https://files.pythonhosted.org/packages/48/c1/f41d9cb618acccca7df82aaf682f9b49013c9397212cb9f53219e3abac37/tomli-2.4.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5a881ab208c0baf688221f8cecc5401bd291d67e38a1ac884d6736cbcd8247e9", size = 243824, upload-time = 
"2026-03-25T20:21:14.569Z" }, + { url = "https://files.pythonhosted.org/packages/22/e4/5a816ecdd1f8ca51fb756ef684b90f2780afc52fc67f987e3c61d800a46d/tomli-2.4.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:47149d5bd38761ac8be13a84864bf0b7b70bc051806bc3669ab1cbc56216b23c", size = 242227, upload-time = "2026-03-25T20:21:15.712Z" }, + { url = "https://files.pythonhosted.org/packages/6b/49/2b2a0ef529aa6eec245d25f0c703e020a73955ad7edf73e7f54ddc608aa5/tomli-2.4.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ec9bfaf3ad2df51ace80688143a6a4ebc09a248f6ff781a9945e51937008fcbc", size = 247859, upload-time = "2026-03-25T20:21:17.001Z" }, + { url = "https://files.pythonhosted.org/packages/83/bd/6c1a630eaca337e1e78c5903104f831bda934c426f9231429396ce3c3467/tomli-2.4.1-cp311-cp311-win32.whl", hash = "sha256:ff2983983d34813c1aeb0fa89091e76c3a22889ee83ab27c5eeb45100560c049", size = 97204, upload-time = "2026-03-25T20:21:18.079Z" }, + { url = "https://files.pythonhosted.org/packages/42/59/71461df1a885647e10b6bb7802d0b8e66480c61f3f43079e0dcd315b3954/tomli-2.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:5ee18d9ebdb417e384b58fe414e8d6af9f4e7a0ae761519fb50f721de398dd4e", size = 108084, upload-time = "2026-03-25T20:21:18.978Z" }, + { url = "https://files.pythonhosted.org/packages/b8/83/dceca96142499c069475b790e7913b1044c1a4337e700751f48ed723f883/tomli-2.4.1-cp311-cp311-win_arm64.whl", hash = "sha256:c2541745709bad0264b7d4705ad453b76ccd191e64aa6f0fc66b69a293a45ece", size = 95285, upload-time = "2026-03-25T20:21:20.309Z" }, + { url = "https://files.pythonhosted.org/packages/c1/ba/42f134a3fe2b370f555f44b1d72feebb94debcab01676bf918d0cb70e9aa/tomli-2.4.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c742f741d58a28940ce01d58f0ab2ea3ced8b12402f162f4d534dfe18ba1cd6a", size = 155924, upload-time = "2026-03-25T20:21:21.626Z" }, + { url = 
"https://files.pythonhosted.org/packages/dc/c7/62d7a17c26487ade21c5422b646110f2162f1fcc95980ef7f63e73c68f14/tomli-2.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7f86fd587c4ed9dd76f318225e7d9b29cfc5a9d43de44e5754db8d1128487085", size = 150018, upload-time = "2026-03-25T20:21:23.002Z" }, + { url = "https://files.pythonhosted.org/packages/5c/05/79d13d7c15f13bdef410bdd49a6485b1c37d28968314eabee452c22a7fda/tomli-2.4.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ff18e6a727ee0ab0388507b89d1bc6a22b138d1e2fa56d1ad494586d61d2eae9", size = 244948, upload-time = "2026-03-25T20:21:24.04Z" }, + { url = "https://files.pythonhosted.org/packages/10/90/d62ce007a1c80d0b2c93e02cab211224756240884751b94ca72df8a875ca/tomli-2.4.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:136443dbd7e1dee43c68ac2694fde36b2849865fa258d39bf822c10e8068eac5", size = 253341, upload-time = "2026-03-25T20:21:25.177Z" }, + { url = "https://files.pythonhosted.org/packages/1a/7e/caf6496d60152ad4ed09282c1885cca4eea150bfd007da84aea07bcc0a3e/tomli-2.4.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:5e262d41726bc187e69af7825504c933b6794dc3fbd5945e41a79bb14c31f585", size = 248159, upload-time = "2026-03-25T20:21:26.364Z" }, + { url = "https://files.pythonhosted.org/packages/99/e7/c6f69c3120de34bbd882c6fba7975f3d7a746e9218e56ab46a1bc4b42552/tomli-2.4.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:5cb41aa38891e073ee49d55fbc7839cfdb2bc0e600add13874d048c94aadddd1", size = 253290, upload-time = "2026-03-25T20:21:27.46Z" }, + { url = "https://files.pythonhosted.org/packages/d6/2f/4a3c322f22c5c66c4b836ec58211641a4067364f5dcdd7b974b4c5da300c/tomli-2.4.1-cp312-cp312-win32.whl", hash = "sha256:da25dc3563bff5965356133435b757a795a17b17d01dbc0f42fb32447ddfd917", size = 98141, upload-time = "2026-03-25T20:21:28.492Z" }, + { url = 
"https://files.pythonhosted.org/packages/24/22/4daacd05391b92c55759d55eaee21e1dfaea86ce5c571f10083360adf534/tomli-2.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:52c8ef851d9a240f11a88c003eacb03c31fc1c9c4ec64a99a0f922b93874fda9", size = 108847, upload-time = "2026-03-25T20:21:29.386Z" }, + { url = "https://files.pythonhosted.org/packages/68/fd/70e768887666ddd9e9f5d85129e84910f2db2796f9096aa02b721a53098d/tomli-2.4.1-cp312-cp312-win_arm64.whl", hash = "sha256:f758f1b9299d059cc3f6546ae2af89670cb1c4d48ea29c3cacc4fe7de3058257", size = 95088, upload-time = "2026-03-25T20:21:30.677Z" }, + { url = "https://files.pythonhosted.org/packages/07/06/b823a7e818c756d9a7123ba2cda7d07bc2dd32835648d1a7b7b7a05d848d/tomli-2.4.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:36d2bd2ad5fb9eaddba5226aa02c8ec3fa4f192631e347b3ed28186d43be6b54", size = 155866, upload-time = "2026-03-25T20:21:31.65Z" }, + { url = "https://files.pythonhosted.org/packages/14/6f/12645cf7f08e1a20c7eb8c297c6f11d31c1b50f316a7e7e1e1de6e2e7b7e/tomli-2.4.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:eb0dc4e38e6a1fd579e5d50369aa2e10acfc9cace504579b2faabb478e76941a", size = 149887, upload-time = "2026-03-25T20:21:33.028Z" }, + { url = "https://files.pythonhosted.org/packages/5c/e0/90637574e5e7212c09099c67ad349b04ec4d6020324539297b634a0192b0/tomli-2.4.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c7f2c7f2b9ca6bdeef8f0fa897f8e05085923eb091721675170254cbc5b02897", size = 243704, upload-time = "2026-03-25T20:21:34.51Z" }, + { url = "https://files.pythonhosted.org/packages/10/8f/d3ddb16c5a4befdf31a23307f72828686ab2096f068eaf56631e136c1fdd/tomli-2.4.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f3c6818a1a86dd6dca7ddcaaf76947d5ba31aecc28cb1b67009a5877c9a64f3f", size = 251628, upload-time = "2026-03-25T20:21:36.012Z" }, + { url = 
"https://files.pythonhosted.org/packages/e3/f1/dbeeb9116715abee2485bf0a12d07a8f31af94d71608c171c45f64c0469d/tomli-2.4.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:d312ef37c91508b0ab2cee7da26ec0b3ed2f03ce12bd87a588d771ae15dcf82d", size = 247180, upload-time = "2026-03-25T20:21:37.136Z" }, + { url = "https://files.pythonhosted.org/packages/d3/74/16336ffd19ed4da28a70959f92f506233bd7cfc2332b20bdb01591e8b1d1/tomli-2.4.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:51529d40e3ca50046d7606fa99ce3956a617f9b36380da3b7f0dd3dd28e68cb5", size = 251674, upload-time = "2026-03-25T20:21:38.298Z" }, + { url = "https://files.pythonhosted.org/packages/16/f9/229fa3434c590ddf6c0aa9af64d3af4b752540686cace29e6281e3458469/tomli-2.4.1-cp313-cp313-win32.whl", hash = "sha256:2190f2e9dd7508d2a90ded5ed369255980a1bcdd58e52f7fe24b8162bf9fedbd", size = 97976, upload-time = "2026-03-25T20:21:39.316Z" }, + { url = "https://files.pythonhosted.org/packages/6a/1e/71dfd96bcc1c775420cb8befe7a9d35f2e5b1309798f009dca17b7708c1e/tomli-2.4.1-cp313-cp313-win_amd64.whl", hash = "sha256:8d65a2fbf9d2f8352685bc1364177ee3923d6baf5e7f43ea4959d7d8bc326a36", size = 108755, upload-time = "2026-03-25T20:21:40.248Z" }, + { url = "https://files.pythonhosted.org/packages/83/7a/d34f422a021d62420b78f5c538e5b102f62bea616d1d75a13f0a88acb04a/tomli-2.4.1-cp313-cp313-win_arm64.whl", hash = "sha256:4b605484e43cdc43f0954ddae319fb75f04cc10dd80d830540060ee7cd0243cd", size = 95265, upload-time = "2026-03-25T20:21:41.219Z" }, + { url = "https://files.pythonhosted.org/packages/3c/fb/9a5c8d27dbab540869f7c1f8eb0abb3244189ce780ba9cd73f3770662072/tomli-2.4.1-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:fd0409a3653af6c147209d267a0e4243f0ae46b011aa978b1080359fddc9b6cf", size = 155726, upload-time = "2026-03-25T20:21:42.23Z" }, + { url = "https://files.pythonhosted.org/packages/62/05/d2f816630cc771ad836af54f5001f47a6f611d2d39535364f148b6a92d6b/tomli-2.4.1-cp314-cp314-macosx_11_0_arm64.whl", hash = 
"sha256:a120733b01c45e9a0c34aeef92bf0cf1d56cfe81ed9d47d562f9ed591a9828ac", size = 149859, upload-time = "2026-03-25T20:21:43.386Z" }, + { url = "https://files.pythonhosted.org/packages/ce/48/66341bdb858ad9bd0ceab5a86f90eddab127cf8b046418009f2125630ecb/tomli-2.4.1-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:559db847dc486944896521f68d8190be1c9e719fced785720d2216fe7022b662", size = 244713, upload-time = "2026-03-25T20:21:44.474Z" }, + { url = "https://files.pythonhosted.org/packages/df/6d/c5fad00d82b3c7a3ab6189bd4b10e60466f22cfe8a08a9394185c8a8111c/tomli-2.4.1-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:01f520d4f53ef97964a240a035ec2a869fe1a37dde002b57ebc4417a27ccd853", size = 252084, upload-time = "2026-03-25T20:21:45.62Z" }, + { url = "https://files.pythonhosted.org/packages/00/71/3a69e86f3eafe8c7a59d008d245888051005bd657760e96d5fbfb0b740c2/tomli-2.4.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:7f94b27a62cfad8496c8d2513e1a222dd446f095fca8987fceef261225538a15", size = 247973, upload-time = "2026-03-25T20:21:46.937Z" }, + { url = "https://files.pythonhosted.org/packages/67/50/361e986652847fec4bd5e4a0208752fbe64689c603c7ae5ea7cb16b1c0ca/tomli-2.4.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:ede3e6487c5ef5d28634ba3f31f989030ad6af71edfb0055cbbd14189ff240ba", size = 256223, upload-time = "2026-03-25T20:21:48.467Z" }, + { url = "https://files.pythonhosted.org/packages/8c/9a/b4173689a9203472e5467217e0154b00e260621caa227b6fa01feab16998/tomli-2.4.1-cp314-cp314-win32.whl", hash = "sha256:3d48a93ee1c9b79c04bb38772ee1b64dcf18ff43085896ea460ca8dec96f35f6", size = 98973, upload-time = "2026-03-25T20:21:49.526Z" }, + { url = "https://files.pythonhosted.org/packages/14/58/640ac93bf230cd27d002462c9af0d837779f8773bc03dee06b5835208214/tomli-2.4.1-cp314-cp314-win_amd64.whl", hash = "sha256:88dceee75c2c63af144e456745e10101eb67361050196b0b6af5d717254dddf7", 
size = 109082, upload-time = "2026-03-25T20:21:50.506Z" }, + { url = "https://files.pythonhosted.org/packages/d5/2f/702d5e05b227401c1068f0d386d79a589bb12bf64c3d2c72ce0631e3bc49/tomli-2.4.1-cp314-cp314-win_arm64.whl", hash = "sha256:b8c198f8c1805dc42708689ed6864951fd2494f924149d3e4bce7710f8eb5232", size = 96490, upload-time = "2026-03-25T20:21:51.474Z" }, + { url = "https://files.pythonhosted.org/packages/45/4b/b877b05c8ba62927d9865dd980e34a755de541eb65fffba52b4cc495d4d2/tomli-2.4.1-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:d4d8fe59808a54658fcc0160ecfb1b30f9089906c50b23bcb4c69eddc19ec2b4", size = 164263, upload-time = "2026-03-25T20:21:52.543Z" }, + { url = "https://files.pythonhosted.org/packages/24/79/6ab420d37a270b89f7195dec5448f79400d9e9c1826df982f3f8e97b24fd/tomli-2.4.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:7008df2e7655c495dd12d2a4ad038ff878d4ca4b81fccaf82b714e07eae4402c", size = 160736, upload-time = "2026-03-25T20:21:53.674Z" }, + { url = "https://files.pythonhosted.org/packages/02/e0/3630057d8eb170310785723ed5adcdfb7d50cb7e6455f85ba8a3deed642b/tomli-2.4.1-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1d8591993e228b0c930c4bb0db464bdad97b3289fb981255d6c9a41aedc84b2d", size = 270717, upload-time = "2026-03-25T20:21:55.129Z" }, + { url = "https://files.pythonhosted.org/packages/7a/b4/1613716072e544d1a7891f548d8f9ec6ce2faf42ca65acae01d76ea06bb0/tomli-2.4.1-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:734e20b57ba95624ecf1841e72b53f6e186355e216e5412de414e3c51e5e3c41", size = 278461, upload-time = "2026-03-25T20:21:56.228Z" }, + { url = "https://files.pythonhosted.org/packages/05/38/30f541baf6a3f6df77b3df16b01ba319221389e2da59427e221ef417ac0c/tomli-2.4.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:8a650c2dbafa08d42e51ba0b62740dae4ecb9338eefa093aa5c78ceb546fcd5c", size = 274855, upload-time = "2026-03-25T20:21:57.653Z" }, 
+ { url = "https://files.pythonhosted.org/packages/77/a3/ec9dd4fd2c38e98de34223b995a3b34813e6bdadf86c75314c928350ed14/tomli-2.4.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:504aa796fe0569bb43171066009ead363de03675276d2d121ac1a4572397870f", size = 283144, upload-time = "2026-03-25T20:21:59.089Z" }, + { url = "https://files.pythonhosted.org/packages/ef/be/605a6261cac79fba2ec0c9827e986e00323a1945700969b8ee0b30d85453/tomli-2.4.1-cp314-cp314t-win32.whl", hash = "sha256:b1d22e6e9387bf4739fbe23bfa80e93f6b0373a7f1b96c6227c32bef95a4d7a8", size = 108683, upload-time = "2026-03-25T20:22:00.214Z" }, + { url = "https://files.pythonhosted.org/packages/12/64/da524626d3b9cc40c168a13da8335fe1c51be12c0a63685cc6db7308daae/tomli-2.4.1-cp314-cp314t-win_amd64.whl", hash = "sha256:2c1c351919aca02858f740c6d33adea0c5deea37f9ecca1cc1ef9e884a619d26", size = 121196, upload-time = "2026-03-25T20:22:01.169Z" }, + { url = "https://files.pythonhosted.org/packages/5a/cd/e80b62269fc78fc36c9af5a6b89c835baa8af28ff5ad28c7028d60860320/tomli-2.4.1-cp314-cp314t-win_arm64.whl", hash = "sha256:eab21f45c7f66c13f2a9e0e1535309cee140182a9cdae1e041d02e47291e8396", size = 100393, upload-time = "2026-03-25T20:22:02.137Z" }, + { url = "https://files.pythonhosted.org/packages/7b/61/cceae43728b7de99d9b847560c262873a1f6c98202171fd5ed62640b494b/tomli-2.4.1-py3-none-any.whl", hash = "sha256:0d85819802132122da43cb86656f8d1f8c6587d54ae7dcaf30e90533028b49fe", size = 14583, upload-time = "2026-03-25T20:22:03.012Z" }, +] + +[[package]] +name = "torch" +version = "2.11.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cuda-bindings", marker = "sys_platform == 'linux'" }, + { name = "cuda-toolkit", extra = ["cublas", "cudart", "cufft", "cufile", "cupti", "curand", "cusolver", "cusparse", "nvjitlink", "nvrtc", "nvtx"], marker = "sys_platform == 'linux'" }, + { name = "filelock" }, + { name = "fsspec" }, + { name = "jinja2" }, + { name = "networkx", version = "3.4.2", source = { 
registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "networkx", version = "3.6.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "nvidia-cudnn-cu13", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cusparselt-cu13", marker = "sys_platform == 'linux'" }, + { name = "nvidia-nccl-cu13", marker = "sys_platform == 'linux'" }, + { name = "nvidia-nvshmem-cu13", marker = "sys_platform == 'linux'" }, + { name = "setuptools" }, + { name = "sympy" }, + { name = "triton", marker = "sys_platform == 'linux'" }, + { name = "typing-extensions" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/ac/f2/c1690994afe461aae2d0cac62251e6802a703dec0a6c549c02ecd0de92a9/torch-2.11.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2c0d7fcfbc0c4e8bb5ebc3907cbc0c6a0da1b8f82b1fc6e14e914fa0b9baf74e", size = 80526521, upload-time = "2026-03-23T18:12:06.86Z" }, + { url = "https://files.pythonhosted.org/packages/a4/f0/98ae802fa8c09d3149b0c8690741f3f5753c90e779bd28c9613257295945/torch-2.11.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:4cf8687f4aec3900f748d553483ef40e0ac38411c3c48d0a86a438f6d7a99b18", size = 419723025, upload-time = "2026-03-23T18:11:43.774Z" }, + { url = "https://files.pythonhosted.org/packages/f9/1e/18a9b10b4bd34f12d4e561c52b0ae7158707b8193c6cfc0aad2b48167090/torch-2.11.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:1b32ceda909818a03b112006709b02be1877240c31750a8d9c6b7bf5f2d8a6e5", size = 530589207, upload-time = "2026-03-23T18:11:23.756Z" }, + { url = "https://files.pythonhosted.org/packages/35/40/2d532e8c0e23705be9d1debce5bc37b68d59a39bda7584c26fe9668076fe/torch-2.11.0-cp310-cp310-win_amd64.whl", hash = "sha256:b3c712ae6fb8e7a949051a953fc412fe0a6940337336c3b6f905e905dac5157f", size = 114518313, upload-time = "2026-03-23T18:11:58.281Z" }, + { url = 
"https://files.pythonhosted.org/packages/ae/0d/98b410492609e34a155fa8b121b55c7dca229f39636851c3a9ec20edea21/torch-2.11.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7b6a60d48062809f58595509c524b88e6ddec3ebe25833d6462eeab81e5f2ce4", size = 80529712, upload-time = "2026-03-23T18:12:02.608Z" }, + { url = "https://files.pythonhosted.org/packages/84/03/acea680005f098f79fd70c1d9d5ccc0cb4296ec2af539a0450108232fc0c/torch-2.11.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:d91aac77f24082809d2c5a93f52a5f085032740a1ebc9252a7b052ef5a4fddc6", size = 419718178, upload-time = "2026-03-23T18:10:46.675Z" }, + { url = "https://files.pythonhosted.org/packages/8c/8b/d7be22fbec9ffee6cff31a39f8750d4b3a65d349a286cf4aec74c2375662/torch-2.11.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:7aa2f9bbc6d4595ba72138026b2074be1233186150e9292865e04b7a63b8c67a", size = 530604548, upload-time = "2026-03-23T18:10:03.569Z" }, + { url = "https://files.pythonhosted.org/packages/d1/bd/9912d30b68845256aabbb4a40aeefeef3c3b20db5211ccda653544ada4b6/torch-2.11.0-cp311-cp311-win_amd64.whl", hash = "sha256:73e24aaf8f36ab90d95cd1761208b2eb70841c2a9ca1a3f9061b39fc5331b708", size = 114519675, upload-time = "2026-03-23T18:11:52.995Z" }, + { url = "https://files.pythonhosted.org/packages/6f/8b/69e3008d78e5cee2b30183340cc425081b78afc5eff3d080daab0adda9aa/torch-2.11.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4b5866312ee6e52ea625cd211dcb97d6a2cdc1131a5f15cc0d87eec948f6dd34", size = 80606338, upload-time = "2026-03-23T18:11:34.781Z" }, + { url = "https://files.pythonhosted.org/packages/13/16/42e5915ebe4868caa6bac83a8ed59db57f12e9a61b7d749d584776ed53d5/torch-2.11.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:f99924682ef0aa6a4ab3b1b76f40dc6e273fca09f367d15a524266db100a723f", size = 419731115, upload-time = "2026-03-23T18:11:06.944Z" }, + { url = 
"https://files.pythonhosted.org/packages/1a/c9/82638ef24d7877510f83baf821f5619a61b45568ce21c0a87a91576510aa/torch-2.11.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:0f68f4ac6d95d12e896c3b7a912b5871619542ec54d3649cf48cc1edd4dd2756", size = 530712279, upload-time = "2026-03-23T18:10:31.481Z" }, + { url = "https://files.pythonhosted.org/packages/1c/ff/6756f1c7ee302f6d202120e0f4f05b432b839908f9071157302cedfc5232/torch-2.11.0-cp312-cp312-win_amd64.whl", hash = "sha256:fbf39280699d1b869f55eac536deceaa1b60bd6788ba74f399cc67e60a5fab10", size = 114556047, upload-time = "2026-03-23T18:10:55.931Z" }, + { url = "https://files.pythonhosted.org/packages/87/89/5ea6722763acee56b045435fb84258db7375c48165ec8be7880ab2b281c5/torch-2.11.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:1e6debd97ccd3205bbb37eb806a9d8219e1139d15419982c09e23ef7d4369d18", size = 80606801, upload-time = "2026-03-23T18:10:18.649Z" }, + { url = "https://files.pythonhosted.org/packages/32/d1/8ed2173589cbfe744ed54e5a73efc107c0085ba5777ee93a5f4c1ab90553/torch-2.11.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:63a68fa59de8f87acc7e85a5478bb2dddbb3392b7593ec3e78827c793c4b73fd", size = 419732382, upload-time = "2026-03-23T18:08:30.835Z" }, + { url = "https://files.pythonhosted.org/packages/3d/e1/b73f7c575a4b8f87a5928f50a1e35416b5e27295d8be9397d5293e7e8d4c/torch-2.11.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:cc89b9b173d9adfab59fd227f0ab5e5516d9a52b658ae41d64e59d2e55a418db", size = 530711509, upload-time = "2026-03-23T18:08:47.213Z" }, + { url = "https://files.pythonhosted.org/packages/66/82/3e3fcdd388fbe54e29fd3f991f36846ff4ac90b0d0181e9c8f7236565f82/torch-2.11.0-cp313-cp313-win_amd64.whl", hash = "sha256:4dda3b3f52d121063a731ddb835f010dc137b920d7fec2778e52f60d8e4bf0cd", size = 114555842, upload-time = "2026-03-23T18:09:52.111Z" }, + { url = 
"https://files.pythonhosted.org/packages/db/38/8ac78069621b8c2b4979c2f96dc8409ef5e9c4189f6aac629189a78677ca/torch-2.11.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:8b394322f49af4362d4f80e424bcaca7efcd049619af03a4cf4501520bdf0fb4", size = 80959574, upload-time = "2026-03-23T18:10:14.214Z" }, + { url = "https://files.pythonhosted.org/packages/6d/6c/56bfb37073e7136e6dd86bfc6af7339946dd684e0ecf2155ac0eee687ae1/torch-2.11.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:2658f34ce7e2dabf4ec73b45e2ca68aedad7a5be87ea756ad656eaf32bf1e1ea", size = 419732324, upload-time = "2026-03-23T18:09:36.604Z" }, + { url = "https://files.pythonhosted.org/packages/07/f4/1b666b6d61d3394cca306ea543ed03a64aad0a201b6cd159f1d41010aeb1/torch-2.11.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:98bb213c3084cfe176302949bdc360074b18a9da7ab59ef2edc9d9f742504778", size = 530596026, upload-time = "2026-03-23T18:09:20.842Z" }, + { url = "https://files.pythonhosted.org/packages/48/6b/30d1459fa7e4b67e9e3fe1685ca1d8bb4ce7c62ef436c3a615963c6c866c/torch-2.11.0-cp313-cp313t-win_amd64.whl", hash = "sha256:a97b94bbf62992949b4730c6cd2cc9aee7b335921ee8dc207d930f2ed09ae2db", size = 114793702, upload-time = "2026-03-23T18:09:47.304Z" }, + { url = "https://files.pythonhosted.org/packages/26/0d/8603382f61abd0db35841148ddc1ffd607bf3100b11c6e1dab6d2fc44e72/torch-2.11.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:01018087326984a33b64e04c8cb5c2795f9120e0d775ada1f6638840227b04d7", size = 80573442, upload-time = "2026-03-23T18:09:10.117Z" }, + { url = "https://files.pythonhosted.org/packages/c7/86/7cd7c66cb9cec6be330fff36db5bd0eef386d80c031b581ec81be1d4b26c/torch-2.11.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:2bb3cc54bd0dea126b0060bb1ec9de0f9c7f7342d93d436646516b0330cd5be7", size = 419749385, upload-time = "2026-03-23T18:07:33.77Z" }, + { url = 
"https://files.pythonhosted.org/packages/47/e8/b98ca2d39b2e0e4730c0ee52537e488e7008025bc77ca89552ff91021f7c/torch-2.11.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:4dc8b3809469b6c30b411bb8c4cad3828efd26236153d9beb6a3ec500f211a60", size = 530716756, upload-time = "2026-03-23T18:07:50.02Z" }, + { url = "https://files.pythonhosted.org/packages/78/88/d4a4cda8362f8a30d1ed428564878c3cafb0d87971fbd3947d4c84552095/torch-2.11.0-cp314-cp314-win_amd64.whl", hash = "sha256:2b4e811728bd0cc58fb2b0948fe939a1ee2bf1422f6025be2fca4c7bd9d79718", size = 114552300, upload-time = "2026-03-23T18:09:05.617Z" }, + { url = "https://files.pythonhosted.org/packages/bf/46/4419098ed6d801750f26567b478fc185c3432e11e2cad712bc6b4c2ab0d0/torch-2.11.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:8245477871c3700d4370352ffec94b103cfcb737229445cf9946cddb7b2ca7cd", size = 80959460, upload-time = "2026-03-23T18:09:00.818Z" }, + { url = "https://files.pythonhosted.org/packages/fd/66/54a56a4a6ceaffb567231994a9745821d3af922a854ed33b0b3a278e0a99/torch-2.11.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:ab9a8482f475f9ba20e12db84b0e55e2f58784bdca43a854a6ccd3fd4b9f75e6", size = 419735835, upload-time = "2026-03-23T18:07:18.974Z" }, + { url = "https://files.pythonhosted.org/packages/b1/e7/0b6665f533aa9e337662dc190425abc0af1fe3234088f4454c52393ded61/torch-2.11.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:563ed3d25542d7e7bbc5b235ccfacfeb97fb470c7fee257eae599adb8005c8a2", size = 530613405, upload-time = "2026-03-23T18:08:07.014Z" }, + { url = "https://files.pythonhosted.org/packages/cf/bf/c8d12a2c86dbfd7f40fb2f56fbf5a505ccf2d9ce131eb559dfc7c51e1a04/torch-2.11.0-cp314-cp314t-win_amd64.whl", hash = "sha256:b2a43985ff5ef6ddd923bbcf99943e5f58059805787c5c9a2622bf05ca2965b0", size = 114792991, upload-time = "2026-03-23T18:08:19.216Z" }, +] + +[[package]] +name = "triton" +version = "3.6.0" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/44/ba/b1b04f4b291a3205d95ebd24465de0e5bf010a2df27a4e58a9b5f039d8f2/triton-3.6.0-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6c723cfb12f6842a0ae94ac307dba7e7a44741d720a40cf0e270ed4a4e3be781", size = 175972180, upload-time = "2026-01-20T16:15:53.664Z" }, + { url = "https://files.pythonhosted.org/packages/8c/f7/f1c9d3424ab199ac53c2da567b859bcddbb9c9e7154805119f8bd95ec36f/triton-3.6.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a6550fae429e0667e397e5de64b332d1e5695b73650ee75a6146e2e902770bea", size = 188105201, upload-time = "2026-01-20T16:00:29.272Z" }, + { url = "https://files.pythonhosted.org/packages/0f/2c/96f92f3c60387e14cc45aed49487f3486f89ea27106c1b1376913c62abe4/triton-3.6.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:49df5ef37379c0c2b5c0012286f80174fcf0e073e5ade1ca9a86c36814553651", size = 176081190, upload-time = "2026-01-20T16:16:00.523Z" }, + { url = "https://files.pythonhosted.org/packages/e0/12/b05ba554d2c623bffa59922b94b0775673de251f468a9609bc9e45de95e9/triton-3.6.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e8e323d608e3a9bfcc2d9efcc90ceefb764a82b99dea12a86d643c72539ad5d3", size = 188214640, upload-time = "2026-01-20T16:00:35.869Z" }, + { url = "https://files.pythonhosted.org/packages/17/5d/08201db32823bdf77a0e2b9039540080b2e5c23a20706ddba942924ebcd6/triton-3.6.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:374f52c11a711fd062b4bfbb201fd9ac0a5febd28a96fb41b4a0f51dde3157f4", size = 176128243, upload-time = "2026-01-20T16:16:07.857Z" }, + { url = "https://files.pythonhosted.org/packages/ab/a8/cdf8b3e4c98132f965f88c2313a4b493266832ad47fb52f23d14d4f86bb5/triton-3.6.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:74caf5e34b66d9f3a429af689c1c7128daba1d8208df60e81106b115c00d6fca", size = 188266850, upload-time = 
"2026-01-20T16:00:43.041Z" }, + { url = "https://files.pythonhosted.org/packages/3c/12/34d71b350e89a204c2c7777a9bba0dcf2f19a5bfdd70b57c4dbc5ffd7154/triton-3.6.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:448e02fe6dc898e9e5aa89cf0ee5c371e99df5aa5e8ad976a80b93334f3494fd", size = 176133521, upload-time = "2026-01-20T16:16:13.321Z" }, + { url = "https://files.pythonhosted.org/packages/f9/0b/37d991d8c130ce81a8728ae3c25b6e60935838e9be1b58791f5997b24a54/triton-3.6.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:10c7f76c6e72d2ef08df639e3d0d30729112f47a56b0c81672edc05ee5116ac9", size = 188289450, upload-time = "2026-01-20T16:00:49.136Z" }, + { url = "https://files.pythonhosted.org/packages/ce/4e/41b0c8033b503fd3cfcd12392cdd256945026a91ff02452bef40ec34bee7/triton-3.6.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1722e172d34e32abc3eb7711d0025bb69d7959ebea84e3b7f7a341cd7ed694d6", size = 176276087, upload-time = "2026-01-20T16:16:18.989Z" }, + { url = "https://files.pythonhosted.org/packages/35/f8/9c66bfc55361ec6d0e4040a0337fb5924ceb23de4648b8a81ae9d33b2b38/triton-3.6.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d002e07d7180fd65e622134fbd980c9a3d4211fb85224b56a0a0efbd422ab72f", size = 188400296, upload-time = "2026-01-20T16:00:56.042Z" }, + { url = "https://files.pythonhosted.org/packages/49/55/5ecf0dcaa0f2fbbd4420f7ef227ee3cb172e91e5fede9d0ecaddc43363b4/triton-3.6.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ef5523241e7d1abca00f1d240949eebdd7c673b005edbbce0aca95b8191f1d43", size = 176138577, upload-time = "2026-01-20T16:16:25.426Z" }, + { url = "https://files.pythonhosted.org/packages/df/3d/9e7eee57b37c80cec63322c0231bb6da3cfe535a91d7a4d64896fcb89357/triton-3.6.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a17a5d5985f0ac494ed8a8e54568f092f7057ef60e1b0fa09d3fd1512064e803", 
size = 188273063, upload-time = "2026-01-20T16:01:07.278Z" }, + { url = "https://files.pythonhosted.org/packages/48/db/56ee649cab5eaff4757541325aca81f52d02d4a7cd3506776cad2451e060/triton-3.6.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0b3a97e8ed304dfa9bd23bb41ca04cdf6b2e617d5e782a8653d616037a5d537d", size = 176274804, upload-time = "2026-01-20T16:16:31.528Z" }, + { url = "https://files.pythonhosted.org/packages/f6/56/6113c23ff46c00aae423333eb58b3e60bdfe9179d542781955a5e1514cb3/triton-3.6.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:46bd1c1af4b6704e554cad2eeb3b0a6513a980d470ccfa63189737340c7746a7", size = 188397994, upload-time = "2026-01-20T16:01:14.236Z" }, +] + +[[package]] +name = "typing-extensions" +version = "4.15.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/72/94/1a15dd82efb362ac84269196e94cf00f187f7ed21c242792a923cdb1c61f/typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466", size = 109391, upload-time = "2025-08-25T13:49:26.313Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548", size = 44614, upload-time = "2025-08-25T13:49:24.86Z" }, +] From 7c670871721ddad3178f6660ece51f49c309562e Mon Sep 17 00:00:00 2001 From: Daniel Stair Date: Wed, 22 Apr 2026 02:39:10 +0000 Subject: [PATCH 2/5] Condense README with benchmarks; add NxDI modular compilation markers --- contrib/models/jepa-2-1/AGENT.md | 149 +++++----- contrib/models/jepa-2-1/PLAN.md | 132 ++++----- contrib/models/jepa-2-1/PLAN_trn2.md | 272 +++++------------- contrib/models/jepa-2-1/README.md | 141 ++++----- .../models/jepa-2-1/src/modeling_jepa21.py | 21 ++ 5 files changed, 277 insertions(+), 438 
deletions(-) diff --git a/contrib/models/jepa-2-1/AGENT.md b/contrib/models/jepa-2-1/AGENT.md index 1d0e375e..575146f4 100644 --- a/contrib/models/jepa-2-1/AGENT.md +++ b/contrib/models/jepa-2-1/AGENT.md @@ -1,105 +1,118 @@ # AGENT.md — V-JEPA 2.1 Neuron Port Technical Reference -## Source Code Location +## Source Code -- **Original model code**: `~/dev/vjepa2/` - - V-JEPA 2 encoder: `src/models/vision_transformer.py` → `VisionTransformer` - - V-JEPA 2 predictor: `src/models/predictor.py` → `VisionTransformerPredictor` - - V-JEPA 2 AC predictor: `src/models/ac_predictor.py` → `VisionTransformerPredictorAC` - - V-JEPA 2.1 encoder: `app/vjepa_2_1/models/vision_transformer.py` → `VisionTransformer` (extended) - - V-JEPA 2.1 predictor: `app/vjepa_2_1/models/predictor.py` → `VisionTransformerPredictor` (extended) - - V-JEPA 2.1 modules: `app/vjepa_2_1/models/utils/modules.py` (Block, RoPEAttention, MLP, SwiGLUFFN) - - V-JEPA 2.1 patch embed: `app/vjepa_2_1/models/utils/patch_embed.py` (PatchEmbed, PatchEmbed3D) - - Hub/loading: `src/hub/backbones.py` (checkpoint loading, arch configs) - - Attentive pooler: `src/models/attentive_pooler.py` +- **Upstream**: `~/dev/vjepa2/` (Meta's vjepa2 repo) +- **Neuron port**: `src/modeling_jepa21.py` — self-contained, no upstream imports -## Architecture Details +## Architecture ### Encoder (VisionTransformer) -The V-JEPA 2.1 encoder is a standard ViT with these key features: +Standard ViT with 3D-RoPE, bidirectional attention, hierarchical output. -1. **Patch Embedding**: `PatchEmbed3D` with Conv3d kernel (tubelet_size, patch_size, patch_size). For video: stride matches kernel. For images: uses separate `PatchEmbed3D` with `tubelet_size=1` when `img_temporal_dim_size` is set. 
+| Arch | Params | embed_dim | depth | num_heads | head_dim | mlp_ratio | +|------|--------|-----------|-------|-----------|----------|-----------| +| vit_base | 86M | 768 | 12 | 12 | 64 | 4.0 | +| vit_large | 300M | 1024 | 24 | 16 | 64 | 4.0 | +| vit_giant | 1.01B | 1408 | 40 | 22 | 64 | 48/11 | +| vit_gigantic | 1.8B | 1664 | 48 | 26 | 64 | 64/13 | -2. **3D-RoPE**: Rotary position embeddings applied separately to depth/height/width dimensions. Head dim is split into 3 roughly equal segments: - - `d_dim = 2 * ((head_dim // 3) // 2)` — temporal - - `h_dim = 2 * ((head_dim // 3) // 2)` — height - - `w_dim = 2 * ((head_dim // 3) // 2)` — width - - Remaining dims get no rotation +### Token Counts (384×384, patch_size=16, tubelet_size=2) -3. **RoPE Bug**: The `rotate_queries_or_keys` function in V-JEPA 2 uses `.repeat(1,1,1,2)` instead of `.repeat_interleave(2, dim=-1)`. V-JEPA 2.1 fixes this with `repeat_interleave`. Both are preserved for checkpoint compatibility. +- 16 frames: 8 × 24 × 24 = **4,608 tokens** +- 64 frames: 32 × 24 × 24 = 18,432 tokens -4. **Hierarchical Output**: The encoder outputs features from multiple intermediate layers. For ViT-L (depth=24): layers [5, 11, 17, 23]. During training, these are concatenated along the feature dimension. During inference, only the last layer's normed output is returned (unless `return_hierarchical=True`). +### Key Features -5. **Modality Embeddings**: Separate learned embeddings for image vs video input, added after patch embedding. +- **PatchEmbed3D**: Conv3d with kernel=stride=(tubelet_size, patch_size, patch_size) +- **3D-RoPE**: Separate rotations for depth/height/width on head_dim slices (d_dim=h_dim=w_dim=20 for head_dim=64, 4 dims unrotated) +- **Hierarchical output**: Normed features from intermediate layers (e.g., [5,11,17,23] for depth=24). Inference returns only the last layer's normed output. 
+- **Modality embeddings**: Separate learned embeddings for image vs video +- **interpolate_rope**: Scales RoPE positions for resolution flexibility -6. **interpolate_rope**: V-JEPA 2.1 adds RoPE interpolation for resolution flexibility. Height/width positions are scaled by `(pretrained_grid_size - 1) / (actual_grid_size - 1)`. +## Neuron Compilation — Verified Findings -### Model Configurations +### What works with `torch_neuronx.trace()` (neuronx-cc 2.24.5133) -| Arch | embed_dim | depth | num_heads | mlp_ratio | head_dim | d/h/w_dim | -|------|-----------|-------|-----------|-----------|----------|-----------| -| vit_base | 768 | 12 | 12 | 4.0 | 64 | 20, 20, 20 | -| vit_large | 1024 | 24 | 16 | 4.0 | 64 | 20, 20, 20 | -| vit_giant_xformers | 1408 | 40 | 22 | 48/11 | 64 | 20, 20, 20 | -| vit_gigantic_xformers | 1664 | 48 | 26 | 64/13 | 64 | 20, 20, 20 | +- **Conv3d**: Compiles natively. No decomposition needed. +- **`torch.arange` in RoPE**: Compiles natively for fixed input shapes. No precomputation needed. +- **`repeat_interleave`**: Compiles natively. No reshape/expand workaround needed. +- **Manual attention** (`q @ k.T * scale → softmax → @ v`): Works correctly with `use_sdpa=False`. +- **BF16 inference**: Works with `--auto-cast none`. Cast model to `.bfloat16()` and use BF16 input tensors. -### Token Counts +### What does NOT work -For 384×384 resolution, patch_size=16, tubelet_size=2: -- Image (1 frame, tubelet_size=1): 1 × 24 × 24 = 576 tokens -- 16 frames: 8 × 24 × 24 = 4,608 tokens -- 64 frames: 32 × 24 × 24 = 18,432 tokens +- **`F.scaled_dot_product_attention`**: Not supported by `torch_neuronx.trace()`. Must use `use_sdpa=False`. +- **BF16 softmax on CPU**: `softmax()` promotes BF16→FP32, causing dtype mismatch with V tensor. Fixed with `.to(v.dtype)` after softmax. +- **ViT-g/ViT-G monolithic compilation**: neuronx-cc OOMs on host (>124GB RAM needed for 40+ layer graph). See "Modular Compilation" below. 
+ +### NKI Flash Attention (`attention_isa_kernel`) -### Attention Pattern +Integrated the NxDI production NKI flash attention kernel for bidirectional attention. -Standard bidirectional self-attention (no causal mask) for the encoder. The AC predictor uses block-causal attention for autoregressive frame prediction. +**Interface** (from `neuronxcc.nki._private_kernels.attention`): +```python +from neuronxcc.nki._private_kernels.attention import attention_isa_kernel +from torch_neuronx.xla_impl.ops import nki_jit +_flash = nki_jit()(attention_isa_kernel) -## Neuron Porting Considerations +# q: (B*H, d_head, seqlen), k: (B*H, d_head, seqlen), v: (B*H, seqlen, d_head) +# out: pre-allocated zeros (B*H, seqlen, d_head) +_flash(q, k, v, scale, out, kernel_name="AttentionMMSoftmaxMMWithoutSwap") +``` -### Compilation Approach: `torch_neuronx.trace()` +**Result**: Higher numerical accuracy (cos_sim 0.9999 vs 0.9998) but **slower** at 4608 tokens — 307ms vs 165ms for ViT-B. The kernel overhead (reshape, launch) outweighs the flash attention benefit at this sequence length. The kernel is designed for 16K+ tokens. Use `use_nki_flash=False` for 16-frame inference. -The encoder is a standard feedforward ViT — no KV cache, no autoregressive generation. `torch_neuronx.trace()` is the right tool. +**Important**: The NKI kernel cannot run on CPU. When using `use_nki_flash=True`, build a separate CPU reference model with `use_nki_flash=False` for validation. The kernel only executes during XLA tracing. -### Potential Issues +### Modular Compilation (Layer Boundary Markers) -1. **`F.scaled_dot_product_attention`**: Neuron compiler may not support all SDPA backends. May need to fall back to manual attention (`q @ k.T * scale → softmax → @ v`). Set `use_sdpa=False` in the model config. +Added `ModuleMarkerStartWrapper`/`ModuleMarkerEndWrapper` from NxDI to split the compiler graph into groups of N layers. Controlled by `modular_compilation_group_size` parameter. -2. 
**`timm.models.layers.drop_path`**: External dependency. For inference (eval mode), drop_path is identity. Can be replaced with `nn.Identity()`. +**Status**: Markers are inserted correctly and validated on ViT-B (identical output and latency to baseline). However, **`torch_neuronx.trace()` does not respect the markers for graph splitting** — ViT-g still OOMs with group_size=8. The markers are likely only respected by `neuronx_distributed.trace.parallel_model_trace`. -3. **Dynamic shapes**: The encoder supports variable frame counts and resolutions via RoPE interpolation. For Neuron compilation, fix the input shape at trace time. Compile separate models for different input shapes if needed. +**Next step**: Use `parallel_model_trace` from NxD instead of `torch_neuronx.trace()`, or compile on a larger instance (trn2.48xlarge with 2TB RAM). -4. **`torch.arange` in RoPE**: Dynamic tensor creation inside forward pass. Neuron compiler should handle this for fixed input shapes, but verify. +### DataParallel Throughput -5. **Large attention matrices**: At 64 frames × 384px, the attention matrix is 18432×18432. This may exceed single-core HBM. Options: - - Use shorter clips (16 frames → 4608 tokens, manageable) - - Use TP>1 via NxDI if needed for long clips - - Use NKI flash attention kernels +`torch_neuronx.DataParallel` distributes inference across NeuronCores with zero model changes: +```python +model_dp = torch_neuronx.DataParallel(traced_model) +output = model_dp(batched_input) # splits batch across cores +``` +trn2.3xlarge has 2 logical NeuronCores → 2x throughput. Scales linearly with batch size. -6. **Conv3d patch embedding**: Verify Neuron compiler support for 3D convolutions. If unsupported, can be decomposed into reshape + Conv2d. 
+## Instance Details -### Weight Loading +- **Type**: trn2.3xlarge (persistent spot) in sa-east-1 +- **Instance ID**: i-0cae7b2ac61807cf9 +- **SSH**: `ssh -i ~/.ssh/trn2-sa-east-1.pem ubuntu@52.67.239.128` +- **Hardware**: 1 Neuron device, 2 logical NeuronCores, 96 GB HBM, 124 GB system RAM +- **Neuron SDK**: torch-neuronx 2.9.0, neuronx-cc 2.24.5133, NxDI 0.9.17334 +- **Venv**: `/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate` -Checkpoints are loaded via `torch.hub.load_state_dict_from_url`. The state dict has keys prefixed with `module.` and `backbone.` which are stripped by `_clean_backbone_key()`. For V-JEPA 2.1 distilled models, the encoder key is `ema_encoder` (not `target_encoder`). +## Workflow -### Inference-Only Simplifications +```bash +# Sync local → trn2 +rsync -avz --exclude='__pycache__' --exclude='._*' \ + ~/dev/neuron-docs/neuronx-distributed-inference/contrib/models/jepa-2-1/ \ + -e "ssh -i ~/.ssh/trn2-sa-east-1.pem" ubuntu@52.67.239.128:jepa-2-1/ -For inference, these training-only features can be removed: -- Mask application (`apply_masks`) — not used during inference -- Drop path — identity at eval -- Predictor — only needed for pretraining/anticipation -- Activation checkpointing — only for training memory savings +# Run on trn2 +ssh -i ~/.ssh/trn2-sa-east-1.pem ubuntu@52.67.239.128 \ + "cd jepa-2-1 && source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate && python ..." +``` -## Reference Patterns +## Weight Loading -### NxDI Contrib Structure -See `~/dev/Neuron-steering-docs/steering/nxdi-contrib.md` for submission requirements. +Checkpoints loaded via `torch.hub.load_state_dict_from_url`. State dict keys prefixed with `module.` and `backbone.` are stripped. Distilled models (ViT-B, ViT-L) use key `ema_encoder`; self-supervised (ViT-g, ViT-G) use `target_encoder`. 
-### Neuron SDK Docs -See `~/dev/neuron-docs/` for: -- `neuronx-distributed/` — distributed inference patterns -- `nki-library/` — NKI kernel examples (flash attention, etc.) +## Reference Code in ~/dev/neuron-docs/ -### Similar Ports -- Vision-language models in NxDI (Qwen-VL, MLLama) have image encoder components -- The Flux diffusion model in NxDI uses TP + NKI attention for large sequence lengths +- `nki-library/src/.../core/attention/` — NKI flash attention kernels (production) +- `nki-library/src/.../core/embeddings/rope.py` — NKI RoPE kernel +- `neuronx-distributed-inference/src/.../models/diffusers/flux/` — Flux model (non-autoregressive, uses NKI attention + modular markers) +- `neuronx-distributed-inference/src/.../models/mllama/modeling_mllama_vision.py` — MLLama vision encoder (uses NKI attention) +- `neuronx-distributed-inference/src/.../models/layer_boundary_marker.py` — ModuleMarkerStart/End for modular compilation diff --git a/contrib/models/jepa-2-1/PLAN.md b/contrib/models/jepa-2-1/PLAN.md index 8ba0404c..c23f7f1c 100644 --- a/contrib/models/jepa-2-1/PLAN.md +++ b/contrib/models/jepa-2-1/PLAN.md @@ -1,82 +1,58 @@ # PLAN.md — V-JEPA 2.1 Neuron Port Roadmap -## Current Status: Phase 1 — Initial Port (CPU-only) +## Phase 1 — Initial Port (CPU-only) ✅ COMPLETE -### Completed -- [x] Read and analyzed V-JEPA 2.1 source code (encoder, predictor, AC predictor, modules) -- [x] Read the paper (arxiv 2506.09985) +- [x] Analyzed V-JEPA 2.1 source code and paper +- [x] Created self-contained `modeling_jepa21.py` with no upstream imports +- [x] Created CPU-only unit tests - [x] Created project structure following NxDI contrib conventions -- [x] Created self-contained encoder module (`modeling_jepa21.py`) with no upstream imports -- [x] Created CPU-only unit tests for encoder forward pass -- [x] Created README.md, AGENT.md, PLAN.md - -### In Progress -- [ ] Verify CPU forward pass matches upstream vjepa2 repo output (numerical equivalence) -- [ ] Test all 4 
encoder variants (ViT-B, ViT-L, ViT-g, ViT-G) - -## Phase 2 — Neuron Compilation (on Trainium) - -### Tasks -- [ ] Set up trn2 instance with Neuron SDK 2.28+ -- [ ] Install dependencies (torch-neuronx, neuronx-distributed-inference) -- [ ] Trace ViT-B encoder with `torch_neuronx.trace()` at 384×384, 16 frames -- [ ] Verify SDPA compatibility — if Neuron doesn't support `F.scaled_dot_product_attention`, add manual attention fallback -- [ ] Verify Conv3d support — if unsupported, decompose to reshape + Conv2d -- [ ] Handle `torch.arange` in RoPE forward pass (may need to precompute) -- [ ] Trace ViT-L encoder -- [ ] Compare traced output vs CPU reference (cosine similarity > 0.99) -- [ ] Benchmark latency and throughput - -## Phase 3 — Scaling & Optimization - -### Tasks -- [ ] Test ViT-g (1B params) — may need TP>1 or NKI flash attention for 64-frame clips -- [ ] Test ViT-G (1.8B params) — likely needs TP≥2 -- [ ] If TP needed: port to NxDI pattern with NKI flash attention -- [ ] Profile memory usage at different frame counts (16, 32, 64) -- [ ] Optimize: batch compilation for multiple input shapes (frame count buckets) - -## Phase 4 — Downstream Tasks - -### Tasks -- [ ] Add attentive pooler for classification inference -- [ ] Add predictor for action anticipation inference -- [ ] Test with pretrained checkpoints on downstream benchmarks -- [ ] Add AC predictor for robotics planning inference (if applicable) - -## Phase 5 — Contrib Submission - -### Tasks -- [ ] Run full test suite on Trainium hardware -- [ ] Measure accuracy with `neuron_allclose()` against CPU reference -- [ ] Fill in compatibility matrix with actual test results -- [ ] Fill in benchmark results (throughput, latency) -- [ ] Ensure all tests pass with `pytest` -- [ ] Submit PR following NxDI contrib guidelines - -## Key Decisions - -### Why start with ViT-B/ViT-L? 
-- Smaller models compile faster and fit on single NeuronCore -- Validates the porting approach before scaling up -- ViT-B (86M params) and ViT-L (300M params) are practical for many downstream tasks - -### Why `torch_neuronx.trace()` first? -- Simpler than full NxDI port -- Encoder is feedforward (no KV cache, no autoregressive) -- Can always upgrade to NxDI later if TP is needed for larger models - -### Why not port the predictor first? -- Encoder is the primary inference component -- Predictor is only needed for pretraining and specific tasks (anticipation) -- Encoder features are sufficient for classification, VQA, and feature extraction - -## Risk Register - -| Risk | Impact | Mitigation | -|------|--------|------------| -| SDPA not supported on Neuron | Medium | Manual attention fallback already in codebase (`use_sdpa=False`) | -| Conv3d not supported | Low | Decompose to reshape + Conv2d | -| 64-frame ViT-G exceeds single-core HBM | High | Start with shorter clips; upgrade to NxDI with TP if needed | -| RoPE dynamic tensor creation | Medium | Precompute position tensors at trace time | -| `timm` dependency | Low | Replaced with inline `drop_path` (identity at eval) | +- [ ] Verify CPU forward pass matches upstream vjepa2 repo (numerical equivalence) + +## Phase 2 — Neuron Compilation (on Trainium) ✅ COMPLETE + +- [x] Set up trn2.3xlarge instance (sa-east-1, persistent spot) +- [x] Traced ViT-B (86M) — compiled on first attempt with `use_sdpa=False` +- [x] Traced ViT-L (300M) — compiled in 18 min +- [x] Validated both: cosine similarity > 0.999 vs CPU reference +- [x] Benchmarked: ViT-B 164.5ms, ViT-L 437.4ms (batch=1, BF16, 16 frames) +- [x] DataParallel: 2x throughput with `torch_neuronx.DataParallel` (zero code changes) +- [x] Integrated NKI flash attention (`attention_isa_kernel`) — works but slower at 4608 tokens +- [x] Added modular compilation markers (`ModuleMarkerStartWrapper`/`EndWrapper`) + +### Key findings +- Conv3d, `torch.arange`, 
`repeat_interleave` all compile natively — no workarounds needed +- Only required change: `use_sdpa=False` to bypass unsupported SDPA +- BF16 softmax dtype fix: `.to(v.dtype)` after softmax +- NKI flash attention: higher accuracy but 1.8x slower at 4608 tokens (designed for 16K+) +- DataParallel: linear throughput scaling, 83ms/clip for ViT-B (2 NeuronCores) + +## Phase 3 — Scaling to ViT-g / ViT-G 🔴 BLOCKED + +**Blocker**: neuronx-cc compiler OOMs on host (>124GB RAM) when compiling ViT-g (40 layers) as a monolithic graph. The `ModuleMarkerStartWrapper`/`EndWrapper` markers do NOT cause `torch_neuronx.trace()` to split the graph — they appear to be respected only by `parallel_model_trace` from NxD (not yet verified). + +### Options (in order of recommendation) + +1. **`parallel_model_trace` from NxD** — Use `neuronx_distributed.trace.parallel_model_trace` instead of `torch_neuronx.trace()`. This is how Flux and other NxDI models compile with modular markers. Requires wrapping the model in a `ModelWrapper`-like class with `input_generator()` and `get_model_instance()`. The markers are already in the model code. + +2. **Larger instance for compilation** — Compile on trn2.48xlarge (2TB RAM), then load the `.pt` on trn2.3xlarge for inference. Simplest approach, just costs more during compilation. + +3. **Manual graph splitting** — Trace layers 0-19 and 20-39 as separate models, chain at runtime. Hacky but avoids NxD dependency. 
+ +### Tasks remaining +- [ ] Get ViT-g (1B) compiling via one of the above approaches +- [ ] Validate and benchmark ViT-g +- [ ] Compile, validate, and benchmark ViT-G (1.8B) + +## Phase 4 — Downstream Tasks (NOT STARTED) + +- [ ] Attentive pooler for classification +- [ ] Predictor for action anticipation +- [ ] Test with pretrained checkpoints +- [ ] 64-frame inference (NKI flash attention becomes relevant here) + +## Phase 5 — Contrib Submission (NOT STARTED) + +- [ ] Full test suite on Trainium +- [ ] `neuron_allclose()` validation +- [ ] Complete compatibility matrix +- [ ] Submit PR diff --git a/contrib/models/jepa-2-1/PLAN_trn2.md b/contrib/models/jepa-2-1/PLAN_trn2.md index d71957bb..f2512f14 100644 --- a/contrib/models/jepa-2-1/PLAN_trn2.md +++ b/contrib/models/jepa-2-1/PLAN_trn2.md @@ -1,257 +1,125 @@ -# PLAN_trn2.md — V-JEPA 2.1 Trainium Execution Plan +# PLAN_trn2.md — V-JEPA 2.1 Trainium Execution Plan & Results ## Instance -- **Type**: trn2.3xlarge (spot) in sa-east-1b +- **Type**: trn2.3xlarge (persistent spot) in sa-east-1b - **Instance ID**: i-0cae7b2ac61807cf9 - **SSH**: `ssh -i ~/.ssh/trn2-sa-east-1.pem ubuntu@52.67.239.128` -- **Hardware**: 1 Neuron device, 4 NeuronCores, 96 GB HBM, 124 GB system RAM, 418 GB disk free -- **OS**: Ubuntu 24.04, Python 3.12.3 -- **Neuron driver**: aws-neuronx-dkms 2.27.4, runtime 2.31.24 (apt-installed) -- **Python venv**: `/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate` -- **Neuron SDK**: torch-neuronx 2.9.0, neuronx-cc 2.24.5133, neuronx-distributed-inference 0.9.17334 -- **PyTorch**: 2.9.1, torch-xla 2.9.0, torchvision 0.24.1 -- **Tools**: pytest 9.0.3 +- **Hardware**: 1 Neuron device, 2 logical NeuronCores, 96 GB HBM, 124 GB system RAM +- **Neuron SDK**: torch-neuronx 2.9.0, neuronx-cc 2.24.5133, NxDI 0.9.17334 +- **Venv**: `/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate` ## Workflow -Edit code locally at `~/dev/neuron-docs/neuronx-distributed-inference/contrib/models/jepa-2-1/`, 
rsync to trn2, run remotely. Same pattern as the autoresearch port. - ```bash -# Sync -rsync -avz --exclude='__pycache__' --exclude='.DS_Store' --exclude='._*' \ +# Sync local → trn2 +rsync -avz --exclude='__pycache__' --exclude='._*' \ ~/dev/neuron-docs/neuronx-distributed-inference/contrib/models/jepa-2-1/ \ -e "ssh -i ~/.ssh/trn2-sa-east-1.pem" ubuntu@52.67.239.128:jepa-2-1/ -# Run remotely +# Run on trn2 ssh -i ~/.ssh/trn2-sa-east-1.pem ubuntu@52.67.239.128 \ "cd jepa-2-1 && source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate && python ..." ``` --- -## Step 1 — Sync Code & Run CPU Smoke Test on trn2 +## Results Summary -Rsync the project and verify the encoder runs on CPU. The Neuron SDK venv is pre-installed. - -```bash -# Sync -rsync -avz --exclude='__pycache__' --exclude='._*' \ - ~/dev/neuron-docs/neuronx-distributed-inference/contrib/models/jepa-2-1/ \ - -e "ssh -i ~/.ssh/trn2-sa-east-1.pem" ubuntu@52.67.239.128:jepa-2-1/ +### Compilation & Validation (BF16, 16 frames, 384×384) -# CPU smoke test -ssh -i ~/.ssh/trn2-sa-east-1.pem ubuntu@52.67.239.128 << 'EOF' -source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate -cd jepa-2-1 -python -c " -from src.modeling_jepa21 import build_vjepa21_encoder -import torch -encoder = build_vjepa21_encoder(arch='vit_base', img_size=384, num_frames=16, pretrained=False) -encoder.eval() -x = torch.randn(1, 3, 16, 384, 384) -with torch.no_grad(): - out = encoder(x) -print(f'Output shape: {out.shape}') # expect (1, 4608, 768) -" -EOF -``` +| Model | Params | Compile Time | Cosine Sim | Status | +|-------|--------|-------------|------------|--------| +| ViT-B | 86M | ~8 min | 0.999846 | ✅ | +| ViT-L | 300M | 18 min | 0.999873 | ✅ | +| ViT-g | 1.01B | OOM at ~30 min | — | ❌ Host OOM (>124GB RAM) | +| ViT-G | 1.8B | Not attempted | — | ❌ Blocked | -**Success criteria**: Output shape is `(1, 4608, 768)` for ViT-B with 16 frames. 
+### Latency (batch=1, single NeuronCore) -## Step 2 — Trace ViT-B Encoder with torch_neuronx +| Model | Median | Mean | p5 | p95 | +|-------|--------|------|-----|-----| +| ViT-B | 164.5 ms | 164.5 ms | 164.4 ms | 164.6 ms | +| ViT-L | 437.4 ms | 437.5 ms | 437.4 ms | 437.6 ms | -First compilation attempt. Start with the smallest model (ViT-B, 86M params) and 16 frames. Cast to `bfloat16` — trn2 NeuronCores are heavily optimized for BF16/FP8, and explicit casting avoids unpredictable compiler auto-cast behavior. +Sub-millisecond variance — deterministic Neuron execution. -```python -import torch -import torch_neuronx -from src.modeling_jepa21 import build_vjepa21_encoder +### DataParallel Throughput (2 logical NeuronCores) -encoder = build_vjepa21_encoder(arch='vit_base', img_size=384, num_frames=16, pretrained=False) -encoder.eval().bfloat16() +| Model | Per-clip Latency | Throughput | Speedup | +|-------|-----------------|------------|---------| +| ViT-B | 83.2 ms | 12.0 clips/sec | 1.98x | +| ViT-L | 219.8 ms | 4.5 clips/sec | 1.99x | -example = torch.randn(1, 3, 16, 384, 384, dtype=torch.bfloat16) -traced = torch_neuronx.trace(encoder, example, compiler_args=['--auto-cast', 'none']) -traced.save("vjepa21_vitb_16f_384.pt") -print("Compilation succeeded") -``` +Linear scaling with batch size. Any batch size works (dynamic batching). -**Note on masking**: The encoder's inference path uses `masks=None` by default — no tokens are dropped, so all tensor shapes are fully static. The masking codepath is training-only and won't be triggered during tracing. +### NKI Flash Attention (experimental) -**Expected issues** (debug in order of likelihood): +| Model | Baseline | NKI Flash | Cosine Sim | +|-------|----------|-----------|------------| +| ViT-B | 164.5 ms | 307.4 ms (+87%) | 0.999972 | +| ViT-L | 437.4 ms | 787.2 ms (+80%) | 1.000006 | -1. 
**SDPA not supported** → Set `use_sdpa=False` in the encoder config or add a manual attention fallback path in `modeling_jepa21.py`. -2. **Conv3d not supported** → Two options: (a) decompose `PatchEmbed3D` into reshape + Conv2d, or (b) replace with reshape + `nn.Linear` since stride == kernel_size makes the convolution equivalent to a linear projection over flattened tubelet patches — this maps directly to MatMul on NeuronCore and may be faster. -3. **`torch.arange` in RoPE** → Precompute RoPE frequency tensors before tracing (move out of forward pass). -4. **`repeat_interleave` not supported** → Replace with equivalent `reshape`/`expand`/`reshape` sequence. +Higher accuracy but slower at 4608 tokens. Reserved for 64-frame (18K token) inference. -**Success criteria**: `.pt` file saved, no compilation errors. +--- -## Step 3 — Validate Traced Model Output +## Compilation Commands -Compare Neuron-traced output against CPU reference. +### ViT-B ```python -import torch -import torch_neuronx - -# CPU reference (BF16) -encoder_cpu = build_vjepa21_encoder(arch='vit_base', img_size=384, num_frames=16, pretrained=False) -encoder_cpu.eval().bfloat16() -x = torch.randn(1, 3, 16, 384, 384, dtype=torch.bfloat16) -with torch.no_grad(): - ref = encoder_cpu(x) - -# Neuron -traced = torch.jit.load("vjepa21_vitb_16f_384.pt") -neuron_out = traced(x) - -cos_sim = torch.nn.functional.cosine_similarity(ref.flatten().float(), neuron_out.flatten().float(), dim=0) -print(f"Cosine similarity: {cos_sim.item():.6f}") # target > 0.99 -``` - -**Success criteria**: Cosine similarity > 0.99 between CPU and Neuron outputs. - -## Step 4 — Trace ViT-L Encoder - -Scale up to ViT-L (300M params, 16 frames, 4608 tokens). 
+import torch, torch_neuronx +from src.modeling_jepa21 import build_vjepa21_encoder -```python -encoder = build_vjepa21_encoder(arch='vit_large', img_size=384, num_frames=16, pretrained=False) +encoder = build_vjepa21_encoder(arch='vit_base', img_size=384, num_frames=16, use_sdpa=False) encoder.eval().bfloat16() -example = torch.randn(1, 3, 16, 384, 384, dtype=torch.bfloat16) -traced = torch_neuronx.trace(encoder, example, compiler_args=['--auto-cast', 'none']) -traced.save("vjepa21_vitl_16f_384.pt") -``` - -**Potential issue**: ViT-L has 24 layers × 16 heads. Attention matrices are 4608×4608 per head. Should fit in 96 GB HBM on a single NeuronCore, but watch for OOM during compilation (neuronx-cc can be memory-hungry on the host side — 124 GB system RAM may be tight for large graphs). - -**Success criteria**: Compilation succeeds, cosine similarity > 0.99 vs CPU. - -## Step 5 — Benchmark Latency - -Measure inference latency for both models. - -```python -import time -import torch - -traced = torch.jit.load("vjepa21_vitb_16f_384.pt") x = torch.randn(1, 3, 16, 384, 384, dtype=torch.bfloat16) - -# Warmup -for _ in range(5): - traced(x) - -# Benchmark -times = [] -for _ in range(50): - t0 = time.perf_counter() - traced(x) - t1 = time.perf_counter() - times.append(t1 - t0) - -import statistics -print(f"ViT-B 16f: {statistics.median(times)*1000:.1f} ms median, {statistics.mean(times)*1000:.1f} ms mean") +traced = torch_neuronx.trace(encoder, x, compiler_args=['--auto-cast', 'none']) +traced.save('vjepa21_vitb_16f_384.pt') ``` -Repeat for ViT-L. Record results in README.md compatibility matrix. +### ViT-L -## Step 6 — Test with Pretrained Weights (Optional) +Same as above with `arch='vit_large'`. 
-If Meta's checkpoints are accessible via `torch.hub`: +### Validation Pattern ```python -encoder = build_vjepa21_encoder(arch='vit_large', img_size=384, num_frames=16, pretrained=True) +# Build CPU ref (no NKI) and NKI model with same seed for matching weights +torch.manual_seed(0) +encoder_cpu = build_vjepa21_encoder(..., use_nki_flash=False) +# ... get ref output ... + +torch.manual_seed(0) +encoder_nki = build_vjepa21_encoder(..., use_nki_flash=True) +# ... trace and compare ... ``` -This validates that the weight loading path works end-to-end on Neuron. - --- -## Risk Mitigation - -| Risk | Mitigation | -|------|------------| -| SDPA unsupported on Neuron | `use_sdpa=False` flag already in model; manual `q @ k.T / sqrt(d) → softmax → @ v` fallback | -| Conv3d unsupported | Decompose to `reshape` + `Conv2d`, or replace with `reshape` + `nn.Linear` (maps to NeuronCore MatMul engine) | -| Dynamic `torch.arange` in RoPE | Precompute freq tensors as buffers; register in `__init__` | -| `repeat_interleave` unsupported | Replace with `reshape`→`expand`→`reshape` | -| Host OOM during compilation (124 GB RAM) | Compile ViT-B first (smaller graph); use `NEURON_CC_FLAGS="--retry_failed_compilation"` | -| Spot instance termination | Save compiled `.pt` files to S3 after each successful compilation | +## Compiled Files on Instance -## Out of Scope (for now) - -- ViT-g / ViT-G (need TP, Phase 3 in PLAN.md) -- 64-frame inference (18K tokens, likely needs NKI flash attention) -- Predictor / AC predictor compilation -- Attentive pooler -- Downstream task benchmarks +``` +~/jepa-2-1/vjepa21_vitb_16f_384_v2.pt (335M) — ViT-B baseline (best) +~/jepa-2-1/vjepa21_vitl_16f_384.pt (1.1G) — ViT-L baseline (best) +~/jepa-2-1/vjepa21_vitb_nki_16f_384.pt (405M) — ViT-B + NKI flash (slower) +~/jepa-2-1/vjepa21_vitl_nki_16f_384.pt (1.4G) — ViT-L + NKI flash (slower) +``` --- -## Execution Results (2026-04-11) - -All steps executed on trn2.3xlarge `i-0cae7b2ac61807cf9` in sa-east-1. 
-SDK: torch-neuronx 2.9.0, neuronx-cc 2.24.5133, Python 3.12.3. - -### Step 1 — Sync & CPU Smoke Test ✅ - -- Rsync: 14 files transferred -- CPU output shape: `(1, 4608, 768)` — matches expected for ViT-B/16 with 16 frames - -### Step 2 — Trace ViT-B ✅ - -- Compiled on **first attempt** with `use_sdpa=False` and `--auto-cast none` -- None of the anticipated workarounds were needed: - - Conv3d: compiled natively - - `torch.arange` in RoPE: compiled natively - - `repeat_interleave`: compiled natively -- Only required change: `use_sdpa=False` to bypass `F.scaled_dot_product_attention` - -### Step 2 (fix) — BF16 dtype fix - -- `softmax()` promotes BF16→FP32 on CPU, causing dtype mismatch in manual attention path -- Fix: added `.to(v.dtype)` after `softmax` in both `RoPEAttention` and `Attention` classes - -### Step 3 — Validate ViT-B ✅ - -| Metric | Value | -|--------|-------| -| CPU output shape | `(1, 4608, 768)` | -| Neuron output shape | `(1, 4608, 768)` | -| Cosine similarity | **0.999846** | -| Max abs diff | 0.078125 | -| Mean abs diff | 0.004509 | - -### Step 4 — Trace ViT-L ✅ - -- Compilation time: **1073s (~18 min)** -- No host OOM — 124 GB system RAM was sufficient - -### Step 4b — Validate ViT-L ✅ - -| Metric | Value | -|--------|-------| -| CPU output shape | `(1, 4608, 1024)` | -| Neuron output shape | `(1, 4608, 1024)` | -| Cosine similarity | **0.999873** | -| Max abs diff | 0.132812 | -| Mean abs diff | 0.007144 | - -### Step 5 — Benchmark Latency ✅ - -Batch=1, BF16, 16 frames, 384×384, 50 iterations after 5 warmup: +## ViT-g / ViT-G: Compilation Failure Analysis -| Model | Params | Median | Mean | p5 | p95 | -|-------|--------|--------|------|-----|-----| -| ViT-B | 86M | **164.5 ms** | 164.5 ms | 164.4 ms | 164.6 ms | -| ViT-L | 300M | **437.4 ms** | 437.5 ms | 437.4 ms | 437.6 ms | +**Root cause**: neuronx-cc compiler memory scales with graph size. 
Peak host RAM usage: +- ViT-L (24 layers): ~60GB → fits in 124GB ✅ +- ViT-g (40 layers): >124GB → OOM ❌ -Sub-millisecond variance — typical of Neuron hardware deterministic execution. +The failure is in the compiler, not the model. The CPU forward pass succeeds. The compiled NEFF would likely fit in 96GB HBM at runtime. -### Files modified +**Attempted mitigation**: Added `ModuleMarkerStartWrapper`/`EndWrapper` from NxDI to split the graph into groups of 8 layers. Result: markers are inserted but `torch_neuronx.trace()` does NOT respect them — it still compiles the full graph as one unit. The markers are only respected by `parallel_model_trace` from NxD. -- `src/modeling_jepa21.py` — `.to(v.dtype)` after softmax in manual attention paths -- `README.md` — updated compatibility matrix, compilation example, known issues +**Next steps** (see PLAN.md Phase 3): +1. Use `parallel_model_trace` from NxD (recommended — markers already in code) +2. Compile on trn2.48xlarge (2TB RAM) +3. Manual graph splitting (hacky) diff --git a/contrib/models/jepa-2-1/README.md b/contrib/models/jepa-2-1/README.md index d4b13aea..e4bbadce 100644 --- a/contrib/models/jepa-2-1/README.md +++ b/contrib/models/jepa-2-1/README.md @@ -1,135 +1,96 @@ # V-JEPA 2.1 on AWS Trainium -V-JEPA 2.1 (Video Joint-Embedding Predictive Architecture) is Meta's self-supervised video foundation model. It learns visual representations by predicting masked video segments in a learned representation space, rather than pixel space. V-JEPA 2.1 extends V-JEPA 2 with knowledge distillation from a ViT-Gigantic teacher, enabling smaller student encoders (ViT-Base, ViT-Large) to achieve strong performance. +V-JEPA 2.1 (Video Joint-Embedding Predictive Architecture) is Meta's self-supervised video foundation model. It learns visual representations by predicting masked video segments in a learned representation space, rather than pixel space. V-JEPA 2.1 extends V-JEPA 2 with knowledge distillation from a ViT-Gigantic teacher. 
This port targets inference on AWS Trainium (trn2) using `torch_neuronx.trace()`. ## Model Information - **Source**: [facebookresearch/vjepa2](https://github.com/facebookresearch/vjepa2) -- **Paper**: [Self-Supervised Video Models Enable Understanding, Prediction and Planning](https://arxiv.org/abs/2506.09985) -- **Model Type**: Self-supervised Vision Transformer (ViT) encoder + predictor -- **Architecture**: ViT with 3D-RoPE, mask-denoising pretraining, hierarchical multi-layer output, modality embeddings (image/video) -- **License**: MIT (vjepa2 repo) +- **Paper**: [arxiv.org/abs/2506.09985](https://arxiv.org/abs/2506.09985) +- **License**: MIT -### Available Checkpoints +| Model | Params | Depth | Heads | Resolution | Neuron Status | +|-------|--------|-------|-------|------------|---------------| +| ViT-B/16 | 86M | 12 | 12 | 384 | ✅ Compiled & benchmarked | +| ViT-L/16 | 300M | 24 | 16 | 384 | ✅ Compiled & benchmarked | +| ViT-g/16 | 1B | 40 | 22 | 384 | ❌ Host OOM during compilation (needs >124GB RAM) | +| ViT-G/16 | 1.8B | 48 | 26 | 384 | ❌ Not attempted (blocked by ViT-g) | -| Model | Params | Embed Dim | Depth | Heads | Resolution | Teacher | -|-------|--------|-----------|-------|-------|------------|---------| -| V-JEPA 2.1 ViT-B/16 | 86M | 768 | 12 | 12 | 384 | ViT-G distillation | -| V-JEPA 2.1 ViT-L/16 | 300M | 1024 | 24 | 16 | 384 | ViT-G distillation | -| V-JEPA 2.1 ViT-g/16 | 1B | 1408 | 40 | 22 | 384 | Self-supervised | -| V-JEPA 2.1 ViT-G/16 | 1.8B | 1664 | 48 | 26 | 384 | Self-supervised | +## Benchmark Results -## Architecture Overview +trn2.3xlarge, BF16, 16 frames, 384×384, `torch_neuronx.trace()` with `--auto-cast none`: -V-JEPA 2.1 consists of: +| Model | Single-core Latency | Cosine Sim vs CPU | DataParallel (2 NCs) | Throughput | +|-------|--------------------|--------------------|----------------------|------------| +| ViT-B (86M) | 164.5 ms | 0.9998 | 83.2 ms/clip | 12.0 clips/sec | +| ViT-L (300M) | 437.4 ms | 0.9999 | 219.8 ms/clip | 
4.5 clips/sec | -1. **Encoder** (`VisionTransformer`): Processes video frames patchified into 2×16×16 tubelets. Uses 3D-RoPE for spatiotemporal position encoding. Outputs hierarchical features from multiple intermediate layers (e.g., layers [5, 11, 17, 23] for ViT-L depth=24). - -2. **Predictor** (`VisionTransformerPredictor`): Takes encoder features + learnable mask tokens and predicts representations of masked patches. Uses multi-layer hierarchical input from the encoder via a learned projection. - -3. **Attentive Pooler** (optional, for classification): Cross-attention pooling over encoder features for downstream classification tasks. - -Key differences from V-JEPA 2: -- Hierarchical multi-layer output with per-layer norms (`norms_block`) -- Modality embeddings (separate for image vs video input) -- `img_temporal_dim_size` for handling single-frame image inputs with tubelet_size=1 -- Distillation-aware predictor with `n_output_distillation` controlling which layers contribute -- `interpolate_rope` for resolution-flexible RoPE - -## Inference Approach - -For inference on Trainium, we use `torch_neuronx.trace()` on the encoder. The encoder is the primary component needed for downstream tasks (classification, VQA, feature extraction). The predictor is only needed for pretraining and action anticipation tasks. 
- -### Why `torch_neuronx.trace()` (not NxDI) - -- The encoder is a standard ViT without KV cache or autoregressive decoding -- At 384×384 resolution with 64 frames: seq_len = (64/2) × (384/16)² = 32 × 576 = 18,432 tokens per clip -- For single-frame image inference: seq_len = 576 tokens (trivially fits) -- For short video clips (16 frames): seq_len = 8 × 576 = 4,608 tokens -- NxDI's KV cache and flash attention infrastructure is unnecessary for non-autoregressive models -- `torch_neuronx.trace()` is simpler and sufficient for encoder-only inference - -### Compilation Strategy - -- Trace the encoder with a fixed input shape (batch, channels, frames, height, width) -- Use `torch_neuronx.trace()` with example inputs -- For variable-length video, compile multiple buckets or pad to max length +Real-time video processing (16 frames @ 30fps = 0.53s of video): +- ViT-B: **3.2x real-time** (single-core), **6.4x real-time** (DataParallel) +- ViT-L: **1.2x real-time** (single-core), **2.4x real-time** (DataParallel) ## Usage -```python -import torch -import torch_neuronx +### CPU Inference -# Load encoder (CPU reference) +```python from src.modeling_jepa21 import build_vjepa21_encoder +import torch -encoder = build_vjepa21_encoder( - arch="vit_large", - img_size=384, - num_frames=16, - pretrained=False, # set True when checkpoint available -) +encoder = build_vjepa21_encoder(arch="vit_large", img_size=384, num_frames=16, pretrained=False) encoder.eval() -# Example: single image input (B, C, T, H, W) -image_input = torch.randn(1, 3, 1, 384, 384) +video = torch.randn(1, 3, 16, 384, 384) with torch.no_grad(): - features = encoder(image_input) -# features shape: (1, 576, 1024) for ViT-L - -# Example: video input -video_input = torch.randn(1, 3, 16, 384, 384) -with torch.no_grad(): - features = encoder(video_input) -# features shape: (1, 4608, 1024) for ViT-L with 16 frames + features = encoder(video) # (1, 4608, 1024) ``` -### Neuron Compilation (on Trainium instance) +### 
Neuron Compilation ```python -import torch -import torch_neuronx +import torch, torch_neuronx from src.modeling_jepa21 import build_vjepa21_encoder encoder = build_vjepa21_encoder(arch="vit_large", img_size=384, num_frames=16, use_sdpa=False) encoder.eval().bfloat16() -example_input = torch.randn(1, 3, 16, 384, 384, dtype=torch.bfloat16) -traced = torch_neuronx.trace(encoder, example_input, compiler_args=["--auto-cast", "none"]) +x = torch.randn(1, 3, 16, 384, 384, dtype=torch.bfloat16) +traced = torch_neuronx.trace(encoder, x, compiler_args=["--auto-cast", "none"]) traced.save("vjepa21_vitl_16f_384.pt") ``` -## Compatibility Matrix +### DataParallel (2x throughput on trn2.3xlarge) -| Instance | SDK | Model | Frames | Resolution | Dtype | Compile | Cosine Sim | Latency (median) | -|----------|-----|-------|--------|------------|-------|---------|------------|-------------------| -| trn2.3xlarge | 2.27 (torch-neuronx 2.9.0, neuronx-cc 2.24.5133) | ViT-B (86M) | 16 | 384×384 | BF16 | ✅ PASS | 0.9998 | 164.5 ms | -| trn2.3xlarge | 2.27 (torch-neuronx 2.9.0, neuronx-cc 2.24.5133) | ViT-L (300M) | 16 | 384×384 | BF16 | ✅ PASS | 0.9999 | 437.4 ms | -| inf2.xlarge | — | — | — | — | — | Not tested | — | — | +```python +import torch, torch_neuronx -## Example Checkpoints +traced = torch.jit.load("vjepa21_vitl_16f_384.pt") +model_dp = torch_neuronx.DataParallel(traced) -* V-JEPA 2.1 weights are loaded via `torch.hub` from Meta's servers (see `hubconf.py` in vjepa2 repo) +batch = torch.randn(4, 3, 16, 384, 384, dtype=torch.bfloat16) +output = model_dp(batch) # distributes across 2 NeuronCores +``` + +## Key Requirements for Neuron Compilation -## Testing Instructions +- `use_sdpa=False` — SDPA is not supported by `torch_neuronx.trace()` +- `.bfloat16()` model and inputs — trn2 NeuronCores are optimized for BF16 +- `--auto-cast none` — avoids unpredictable compiler auto-cast behavior +- Conv3d, `torch.arange`, `repeat_interleave` all compile natively + +## Known Issues + +- 
ViT-g (1B) and ViT-G (1.8B) cannot compile on trn2.3xlarge (124GB RAM) — the neuronx-cc compiler OOMs on the 40+ layer graph. NxDI `ModuleMarkerStartWrapper`/`EndWrapper` markers were added but `torch_neuronx.trace()` does not respect them for graph splitting. Next step: use `parallel_model_trace` from NxD, or compile on a larger instance. +- NKI flash attention (`attention_isa_kernel`) is integrated but slower than compiler-generated attention at 4608 tokens. Reserved for 64-frame inference (18K tokens) where it will be beneficial. +- BF16 softmax promotes to FP32 on CPU; `.to(v.dtype)` cast added after softmax. + +## Testing ```bash -# CPU-only tests (runs on MacBook) -cd contrib/models/jepa-2-1 +# CPU-only tests pytest test/ -v # On Trainium instance pytest test/integration/test_model.py -v ``` - -## Known Issues - -- `use_sdpa=False` is required for Neuron compilation — `F.scaled_dot_product_attention` is not supported by `torch_neuronx.trace()`. The manual attention fallback (`q @ k.T * scale → softmax → @ v`) works correctly. -- BF16 softmax promotes to FP32 on CPU; `.to(v.dtype)` cast added after softmax to maintain dtype consistency. 
-- 3D-RoPE uses a duplicated frequency pattern (known upstream bug, preserved for checkpoint compatibility) -- `timm` is required as a dependency for `drop_path` (replaced with inline implementation) -- Full 64-frame ViT-G inference may require TP>1 on Trainium due to memory -- Conv3d, `torch.arange`, and `repeat_interleave` all compile successfully on neuronx-cc 2.24.5133 diff --git a/contrib/models/jepa-2-1/src/modeling_jepa21.py b/contrib/models/jepa-2-1/src/modeling_jepa21.py index 965398af..d4c8b183 100644 --- a/contrib/models/jepa-2-1/src/modeling_jepa21.py +++ b/contrib/models/jepa-2-1/src/modeling_jepa21.py @@ -23,6 +23,17 @@ except ImportError: pass +# Modular compilation markers — only available with NxDI +_ModuleMarkerStart = None +_ModuleMarkerEnd = None +try: + from neuronx_distributed_inference.models.layer_boundary_marker import ( + ModuleMarkerStartWrapper as _ModuleMarkerStart, + ModuleMarkerEndWrapper as _ModuleMarkerEnd, + ) +except ImportError: + pass + # --------------------------------------------------------------------------- # Utility: truncated normal init (replaces src.utils.tensors.trunc_normal_) @@ -470,6 +481,7 @@ def __init__( wide_silu=True, use_sdpa=True, use_nki_flash=False, + modular_compilation_group_size=0, use_activation_checkpointing=False, is_causal=False, use_rope=True, @@ -498,6 +510,7 @@ def __init__( self.is_video = num_frames > 1 self.use_activation_checkpointing = use_activation_checkpointing + self.modular_compilation_group_size = modular_compilation_group_size dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] @@ -649,8 +662,12 @@ def forward(self, x, masks=None, training=False): masks = torch.cat(masks, dim=0) # Forward through blocks + gs = self.modular_compilation_group_size + use_markers = gs > 0 and _ModuleMarkerStart is not None hier = [] for i, blk in enumerate(self.blocks): + if use_markers and i % gs == 0: + x = _ModuleMarkerStart()(x) x, _attn = blk( x, mask=masks, T=T, H_patches=H_patches, 
W_patches=W_patches, return_attn=False, mode=mode, @@ -658,6 +675,8 @@ def forward(self, x, masks=None, training=False): if i in self.out_layers_distillation: out_idx = self.hierarchical_layers.index(i) hier.append(self.norms_block[out_idx](x)) + if use_markers and (i % gs == gs - 1 or i == len(self.blocks) - 1): + x = _ModuleMarkerEnd()(x) if training or self.return_hierarchical: return torch.cat(hier, dim=2) @@ -843,6 +862,7 @@ def build_vjepa21_encoder( tubelet_size: int = 2, use_sdpa: bool = True, use_nki_flash: bool = False, + modular_compilation_group_size: int = 0, use_rope: bool = True, interpolate_rope: bool = True, img_temporal_dim_size: int = 1, @@ -870,6 +890,7 @@ def build_vjepa21_encoder( tubelet_size=tubelet_size, use_sdpa=use_sdpa, use_nki_flash=use_nki_flash, + modular_compilation_group_size=modular_compilation_group_size, use_silu=False, wide_silu=True, uniform_power=False, From 25ad5e68ad63d32dd5f001ea3cf65d123ff25853 Mon Sep 17 00:00:00 2001 From: Daniel Stair Date: Wed, 6 May 2026 20:30:23 +0000 Subject: [PATCH 3/5] all model sizes tested --- contrib/models/jepa-2-1/._AGENT.md | Bin 163 -> 0 bytes contrib/models/jepa-2-1/._PLAN.md | Bin 163 -> 0 bytes contrib/models/jepa-2-1/._README.md | Bin 163 -> 0 bytes contrib/models/jepa-2-1/._pyproject.toml | Bin 163 -> 0 bytes contrib/models/jepa-2-1/._src | Bin 163 -> 0 bytes contrib/models/jepa-2-1/._test | Bin 163 -> 0 bytes contrib/models/jepa-2-1/._uv.lock | Bin 163 -> 0 bytes contrib/models/jepa-2-1/AGENT.md | 162 +++++++++- contrib/models/jepa-2-1/PLAN.md | 58 ---- contrib/models/jepa-2-1/PLAN_trn2.md | 125 -------- contrib/models/jepa-2-1/PLAN_trn2_48xl.md | 93 ++++++ contrib/models/jepa-2-1/PR_README.md | 236 ++++++++++++++ contrib/models/jepa-2-1/README.md | 293 +++++++++++++++--- contrib/models/jepa-2-1/demo_classify.py | 73 +++++ contrib/models/jepa-2-1/demo_neuron.py | 111 +++++++ contrib/models/jepa-2-1/src/.___init__.py | Bin 163 -> 0 bytes contrib/models/jepa-2-1/src/.___pycache__ | Bin 
163 -> 0 bytes .../models/jepa-2-1/src/._modeling_jepa21.py | Bin 163 -> 0 bytes contrib/models/jepa-2-1/test/.___init__.py | Bin 163 -> 0 bytes contrib/models/jepa-2-1/test/.___pycache__ | Bin 163 -> 0 bytes contrib/models/jepa-2-1/test/._integration | Bin 163 -> 0 bytes contrib/models/jepa-2-1/test/._unit | Bin 163 -> 0 bytes .../jepa-2-1/test/integration/.___init__.py | Bin 163 -> 0 bytes .../jepa-2-1/test/integration/.___pycache__ | Bin 163 -> 0 bytes .../jepa-2-1/test/integration/._test_model.py | Bin 163 -> 0 bytes .../models/jepa-2-1/test/unit/.___init__.py | Bin 163 -> 0 bytes .../models/jepa-2-1/test/unit/.___pycache__ | Bin 163 -> 0 bytes .../jepa-2-1/test/unit/._test_encoder.py | Bin 163 -> 0 bytes .../models/jepa-2-1/validate_pretrained.py | 210 +++++++++++++ 29 files changed, 1132 insertions(+), 229 deletions(-) delete mode 100644 contrib/models/jepa-2-1/._AGENT.md delete mode 100644 contrib/models/jepa-2-1/._PLAN.md delete mode 100644 contrib/models/jepa-2-1/._README.md delete mode 100644 contrib/models/jepa-2-1/._pyproject.toml delete mode 100755 contrib/models/jepa-2-1/._src delete mode 100755 contrib/models/jepa-2-1/._test delete mode 100644 contrib/models/jepa-2-1/._uv.lock delete mode 100644 contrib/models/jepa-2-1/PLAN.md delete mode 100644 contrib/models/jepa-2-1/PLAN_trn2.md create mode 100644 contrib/models/jepa-2-1/PLAN_trn2_48xl.md create mode 100644 contrib/models/jepa-2-1/PR_README.md create mode 100644 contrib/models/jepa-2-1/demo_classify.py create mode 100644 contrib/models/jepa-2-1/demo_neuron.py delete mode 100644 contrib/models/jepa-2-1/src/.___init__.py delete mode 100755 contrib/models/jepa-2-1/src/.___pycache__ delete mode 100644 contrib/models/jepa-2-1/src/._modeling_jepa21.py delete mode 100644 contrib/models/jepa-2-1/test/.___init__.py delete mode 100755 contrib/models/jepa-2-1/test/.___pycache__ delete mode 100755 contrib/models/jepa-2-1/test/._integration delete mode 100755 contrib/models/jepa-2-1/test/._unit delete mode 
100644 contrib/models/jepa-2-1/test/integration/.___init__.py delete mode 100755 contrib/models/jepa-2-1/test/integration/.___pycache__ delete mode 100644 contrib/models/jepa-2-1/test/integration/._test_model.py delete mode 100644 contrib/models/jepa-2-1/test/unit/.___init__.py delete mode 100755 contrib/models/jepa-2-1/test/unit/.___pycache__ delete mode 100644 contrib/models/jepa-2-1/test/unit/._test_encoder.py create mode 100644 contrib/models/jepa-2-1/validate_pretrained.py diff --git a/contrib/models/jepa-2-1/._AGENT.md b/contrib/models/jepa-2-1/._AGENT.md deleted file mode 100644 index 5177851bca3e9f09f1906e0acb8266eb0b3b9849..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 163 zcmZQz6=P>$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K 0.999 vs CPU reference -- [x] Benchmarked: ViT-B 164.5ms, ViT-L 437.4ms (batch=1, BF16, 16 frames) -- [x] DataParallel: 2x throughput with `torch_neuronx.DataParallel` (zero code changes) -- [x] Integrated NKI flash attention (`attention_isa_kernel`) — works but slower at 4608 tokens -- [x] Added modular compilation markers (`ModuleMarkerStartWrapper`/`EndWrapper`) - -### Key findings -- Conv3d, `torch.arange`, `repeat_interleave` all compile natively — no workarounds needed -- Only required change: `use_sdpa=False` to bypass unsupported SDPA -- BF16 softmax dtype fix: `.to(v.dtype)` after softmax -- NKI flash attention: higher accuracy but 1.8x slower at 4608 tokens (designed for 16K+) -- DataParallel: linear throughput scaling, 83ms/clip for ViT-B (2 NeuronCores) - -## Phase 3 — Scaling 
to ViT-g / ViT-G 🔴 BLOCKED - -**Blocker**: neuronx-cc compiler OOMs on host (>124GB RAM) when compiling ViT-g (40 layers) as a monolithic graph. The `ModuleMarkerStartWrapper`/`EndWrapper` markers do NOT cause `torch_neuronx.trace()` to split the graph — they are only respected by `parallel_model_trace` from NxD. - -### Options (in order of recommendation) - -1. **`parallel_model_trace` from NxD** — Use `neuronx_distributed.trace.parallel_model_trace` instead of `torch_neuronx.trace()`. This is how Flux and other NxDI models compile with modular markers. Requires wrapping the model in a `ModelWrapper`-like class with `input_generator()` and `get_model_instance()`. The markers are already in the model code. - -2. **Larger instance for compilation** — Compile on trn2.48xlarge (2TB RAM), then load the `.pt` on trn2.3xlarge for inference. Simplest approach, just costs more during compilation. - -3. **Manual graph splitting** — Trace layers 0-19 and 20-39 as separate models, chain at runtime. Hacky but avoids NxD dependency. 
- -### Tasks remaining -- [ ] Get ViT-g (1B) compiling via one of the above approaches -- [ ] Validate and benchmark ViT-g -- [ ] Compile, validate, and benchmark ViT-G (1.8B) - -## Phase 4 — Downstream Tasks (NOT STARTED) - -- [ ] Attentive pooler for classification -- [ ] Predictor for action anticipation -- [ ] Test with pretrained checkpoints -- [ ] 64-frame inference (NKI flash attention becomes relevant here) - -## Phase 5 — Contrib Submission (NOT STARTED) - -- [ ] Full test suite on Trainium -- [ ] `neuron_allclose()` validation -- [ ] Complete compatibility matrix -- [ ] Submit PR diff --git a/contrib/models/jepa-2-1/PLAN_trn2.md b/contrib/models/jepa-2-1/PLAN_trn2.md deleted file mode 100644 index f2512f14..00000000 --- a/contrib/models/jepa-2-1/PLAN_trn2.md +++ /dev/null @@ -1,125 +0,0 @@ -# PLAN_trn2.md — V-JEPA 2.1 Trainium Execution Plan & Results - -## Instance - -- **Type**: trn2.3xlarge (persistent spot) in sa-east-1b -- **Instance ID**: i-0cae7b2ac61807cf9 -- **SSH**: `ssh -i ~/.ssh/trn2-sa-east-1.pem ubuntu@52.67.239.128` -- **Hardware**: 1 Neuron device, 2 logical NeuronCores, 96 GB HBM, 124 GB system RAM -- **Neuron SDK**: torch-neuronx 2.9.0, neuronx-cc 2.24.5133, NxDI 0.9.17334 -- **Venv**: `/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate` - -## Workflow - -```bash -# Sync local → trn2 -rsync -avz --exclude='__pycache__' --exclude='._*' \ - ~/dev/neuron-docs/neuronx-distributed-inference/contrib/models/jepa-2-1/ \ - -e "ssh -i ~/.ssh/trn2-sa-east-1.pem" ubuntu@52.67.239.128:jepa-2-1/ - -# Run on trn2 -ssh -i ~/.ssh/trn2-sa-east-1.pem ubuntu@52.67.239.128 \ - "cd jepa-2-1 && source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate && python ..." 
-``` - ---- - -## Results Summary - -### Compilation & Validation (BF16, 16 frames, 384×384) - -| Model | Params | Compile Time | Cosine Sim | Status | -|-------|--------|-------------|------------|--------| -| ViT-B | 86M | ~8 min | 0.999846 | ✅ | -| ViT-L | 300M | 18 min | 0.999873 | ✅ | -| ViT-g | 1.01B | OOM at ~30 min | — | ❌ Host OOM (>124GB RAM) | -| ViT-G | 1.8B | Not attempted | — | ❌ Blocked | - -### Latency (batch=1, single NeuronCore) - -| Model | Median | Mean | p5 | p95 | -|-------|--------|------|-----|-----| -| ViT-B | 164.5 ms | 164.5 ms | 164.4 ms | 164.6 ms | -| ViT-L | 437.4 ms | 437.5 ms | 437.4 ms | 437.6 ms | - -Sub-millisecond variance — deterministic Neuron execution. - -### DataParallel Throughput (2 logical NeuronCores) - -| Model | Per-clip Latency | Throughput | Speedup | -|-------|-----------------|------------|---------| -| ViT-B | 83.2 ms | 12.0 clips/sec | 1.98x | -| ViT-L | 219.8 ms | 4.5 clips/sec | 1.99x | - -Linear scaling with batch size. Any batch size works (dynamic batching). - -### NKI Flash Attention (experimental) - -| Model | Baseline | NKI Flash | Cosine Sim | -|-------|----------|-----------|------------| -| ViT-B | 164.5 ms | 307.4 ms (+87%) | 0.999972 | -| ViT-L | 437.4 ms | 787.2 ms (+80%) | 1.000006 | - -Higher accuracy but slower at 4608 tokens. Reserved for 64-frame (18K token) inference. - ---- - -## Compilation Commands - -### ViT-B - -```python -import torch, torch_neuronx -from src.modeling_jepa21 import build_vjepa21_encoder - -encoder = build_vjepa21_encoder(arch='vit_base', img_size=384, num_frames=16, use_sdpa=False) -encoder.eval().bfloat16() -x = torch.randn(1, 3, 16, 384, 384, dtype=torch.bfloat16) -traced = torch_neuronx.trace(encoder, x, compiler_args=['--auto-cast', 'none']) -traced.save('vjepa21_vitb_16f_384.pt') -``` - -### ViT-L - -Same as above with `arch='vit_large'`. 
- -### Validation Pattern - -```python -# Build CPU ref (no NKI) and NKI model with same seed for matching weights -torch.manual_seed(0) -encoder_cpu = build_vjepa21_encoder(..., use_nki_flash=False) -# ... get ref output ... - -torch.manual_seed(0) -encoder_nki = build_vjepa21_encoder(..., use_nki_flash=True) -# ... trace and compare ... -``` - ---- - -## Compiled Files on Instance - -``` -~/jepa-2-1/vjepa21_vitb_16f_384_v2.pt (335M) — ViT-B baseline (best) -~/jepa-2-1/vjepa21_vitl_16f_384.pt (1.1G) — ViT-L baseline (best) -~/jepa-2-1/vjepa21_vitb_nki_16f_384.pt (405M) — ViT-B + NKI flash (slower) -~/jepa-2-1/vjepa21_vitl_nki_16f_384.pt (1.4G) — ViT-L + NKI flash (slower) -``` - ---- - -## ViT-g / ViT-G: Compilation Failure Analysis - -**Root cause**: neuronx-cc compiler memory scales with graph size. Peak host RAM usage: -- ViT-L (24 layers): ~60GB → fits in 124GB ✅ -- ViT-g (40 layers): >124GB → OOM ❌ - -The failure is in the compiler, not the model. The CPU forward pass succeeds. The compiled NEFF would likely fit in 96GB HBM at runtime. - -**Attempted mitigation**: Added `ModuleMarkerStartWrapper`/`EndWrapper` from NxDI to split the graph into groups of 8 layers. Result: markers are inserted but `torch_neuronx.trace()` does NOT respect them — it still compiles the full graph as one unit. The markers are only respected by `parallel_model_trace` from NxD. - -**Next steps** (see PLAN.md Phase 3): -1. Use `parallel_model_trace` from NxD (recommended — markers already in code) -2. Compile on trn2.48xlarge (2TB RAM) -3. Manual graph splitting (hacky) diff --git a/contrib/models/jepa-2-1/PLAN_trn2_48xl.md b/contrib/models/jepa-2-1/PLAN_trn2_48xl.md new file mode 100644 index 00000000..917ac957 --- /dev/null +++ b/contrib/models/jepa-2-1/PLAN_trn2_48xl.md @@ -0,0 +1,93 @@ +# V-JEPA 2.1 — ViT-g / ViT-G on trn2.48xlarge + +## Objective + +Compile, test, and benchmark the two larger V-JEPA 2.1 model sizes on a trn2.48xlarge instance (2 TB RAM, 64 NeuronCores). 
These models OOM'd during compilation on trn2.3xlarge (124 GB RAM). + +| Model | Params | embed_dim | depth | num_heads | Checkpoint | +|-------|--------|-----------|-------|-----------|------------| +| ViT-g | 1.01B | 1408 | 40 | 22 | `vjepa2_1_vitg_384.pt` (~4 GB) | +| ViT-G | 1.8B | 1664 | 48 | 26 | `vjepa2_1_vitG_384.pt` (~7 GB) | + +## Instance + +- **Instance ID:** `i-09812af3093beb594` +- **Type:** trn2.48xlarge (96 vCPUs, 2 TB RAM, 64 NeuronCores) +- **Region:** us-east-2 +- **AMI:** ami-0a81a0376c52f4d22 +- **Access:** SSM (no SSH) +- **State:** stopped → starting + +## Execution Plan + +### Phase 1: Environment Setup + +1. Start instance, wait for SSM connectivity +2. Copy `contrib/models/jepa-2-1/` to instance via SSM + tar +3. Activate Neuron venv: `source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate` +4. Verify SDK: `neuron-ls`, `pip show torch-neuronx neuronx-cc` + +### Phase 2: ViT-g (1B params) + +5. Compile ViT-g with `pretrained=False` (random weights) — image input (1,3,1,384,384) +6. Compile ViT-g with `pretrained=False` — video input (1,3,16,384,384) +7. Validate accuracy: Neuron BF16 vs CPU FP32 (cosine similarity, `neuron_allclose`) +8. Compile ViT-g with `pretrained=True` — validate pretrained weights +9. Benchmark: single NC latency (100 iterations), DataParallel throughput + +### Phase 3: ViT-G (1.8B params) + +10. Repeat steps 5–9 for ViT-G + +### Phase 4: Results + +11. Collect all metrics, update this file with results +12. 
Stop instance + +## Expected Outputs + +- Compilation time for each model × input type +- Single NeuronCore latency (median, p5, p95) +- DataParallel throughput (clips/sec) +- Cosine similarity (BF16 Neuron vs FP32 CPU) +- Real-time video processing factor (vs 30fps, 16 frames = 0.53s) + +## Results + +### ViT-g (1B) + +| Metric | Image (1 frame) | Video (16 frames) | +|--------|------------------|--------------------| +| Compilation time | — | — | +| Output shape | — | — | +| Cosine similarity | — | — | +| Median latency (ms) | — | — | +| p5 / p95 (ms) | — / — | — / — | +| DP throughput (clips/s) | — | — | + +### ViT-G (1.8B) + +| Metric | Image (1 frame) | Video (16 frames) | +|--------|------------------|--------------------| +| Compilation time | — | — | +| Output shape | — | — | +| Cosine similarity | — | — | +| Median latency (ms) | — | — | +| p5 / p95 (ms) | — / — | — / — | +| DP throughput (clips/s) | — | — | + +### Comparison (all models, video 16 frames, single NC) + +| Model | Params | Median (ms) | Real-time factor | Instance | +|-------|--------|-------------|------------------|----------| +| ViT-B | 86M | 247.4 | 2.1x | trn2.3xlarge | +| ViT-L | 300M | 741.8 | 0.7x | trn2.3xlarge | +| ViT-g | 1.01B | — | — | trn2.48xlarge | +| ViT-G | 1.8B | — | — | trn2.48xlarge | + +## Notes + +- `use_sdpa=False` required (SDPA not supported by `torch_neuronx.trace()`) +- `--auto-cast none` compiler flag required for BF16 +- ViT-g/G use `target_encoder` key in checkpoint (not `ema_encoder`) +- NKI flash attention disabled for 16-frame inference (slower at 4,608 tokens) diff --git a/contrib/models/jepa-2-1/PR_README.md b/contrib/models/jepa-2-1/PR_README.md new file mode 100644 index 00000000..b618180f --- /dev/null +++ b/contrib/models/jepa-2-1/PR_README.md @@ -0,0 +1,236 @@ +## Description + +NxDI contrib implementation of [V-JEPA 2.1](https://github.com/facebookresearch/vjepa2), Meta's self-supervised video foundation model. 
V-JEPA 2.1 is a Vision Transformer encoder that learns visual representations by predicting masked video segments in representation space. This is a vision encoder — not a causal language model — compiled for inference on AWS Trainium via `torch_neuronx.trace()`. + +Key architecture features ported: +* **3D RoPE:** Separate depth/height/width rotations on head_dim slices, using `repeat_interleave` layout +* **Conv3d tubelet embedding:** 3D convolution for video patch embedding (patch_size=16, tubelet_size=2) +* **Hierarchical output:** Normed features from 4 intermediate layers +* **Modality embeddings:** Separate learned embeddings for image vs video inputs +* **NKI flash attention:** Integrated `attention_isa_kernel` (reserved for 64-frame / 18K token inference) +* **Modular compilation markers:** `ModuleMarkerStartWrapper`/`EndWrapper` for future graph splitting + +## Model Information +* **Model Name:** V-JEPA 2.1 (vit_base, vit_large, vit_giant, vit_gigantic) +* **Model Architecture:** Vision Transformer encoder with 3D RoPE (86M–1.8B params) +* **Purpose:** Self-supervised video representation learning (feature extraction, not text generation) +* **Source:** [https://github.com/facebookresearch/vjepa2](https://github.com/facebookresearch/vjepa2) +* **License:** [MIT](https://github.com/facebookresearch/vjepa2/blob/main/LICENSE) + +## Checklist + +### Required Components +* **Accuracy Test** (`test/integration/test_model.py`) + * Integration test validates Neuron vs CPU accuracy via `neuron_allclose` (rtol=0.01) + * Test can compile and run the model on Neuron (validated on trn2.3xlarge) + * Pretrained weight validation: cosine similarity 0.9998–1.0002 across all compiled configurations (ViT-B/L/g image and video; ViT-G image only) +* **README.md** with the following sections: + * Usage Example: CPU inference, Neuron compilation, DataParallel + * Compatibility Matrix: trn2.3xlarge and trn2.48xlarge with SDK 2.28 + * Example Checkpoints: Meta's pretrained weights (auto-download) + * Testing Instructions: 
Commands to run unit and integration test suites + * Performance Benchmarks: Latency, throughput, DataParallel scaling +* **Source Code** (`src/`) + * `modeling_jepa21.py` (~700 lines): Self-contained encoder implementation, no upstream imports + * Properly structured in the contrib folder hierarchy + +### Optional Components +* **Unit Tests** (CPU-based, no Neuron device required) + * `test_encoder.py` — Construction: 4/4 PASS (ViT-B/L/g construction, invalid arch) + * `test_encoder.py` — Forward: 6/6 PASS (video/image/batch shapes, hierarchical output, determinism, resolution) + * `test_encoder.py` — Components: 4/4 PASS (PatchEmbed3D, RoPEAttention, Block) + +### Not Applicable (vision encoder, not causal LM) +* vLLM Integration — not applicable (not a text generation model) +* TPOT/TTFT benchmarks — not applicable (no token generation) +* Logit divergence test — not applicable (no autoregressive decoding) +* On-device sampling — not applicable + +### Demos +* **`demo_neuron.py`** — Neuron smoke test: runs pretrained ViT-B on CPU (FP32) and Neuron (BF16), compares embeddings. Cosine similarity 1.0005 [PASS]. Compilation 416s, latency 248ms. +* **`demo_classify.py`** — CPU video classification demo using HuggingFace V-JEPA 2 finetuned on SSv2 (174 action classes). Downloads Big Buck Bunny (CC-BY-3.0) as sample. 
+ +## Folder Structure + +``` +contrib/models/jepa-2-1/ +├── README.md +├── PR_README.md # PR description (paste into GitHub PR body) +├── AGENT.md # Technical reference for coding agents +├── demo_neuron.py # Neuron smoke test (pretrained ViT-B, CPU vs Neuron) +├── demo_classify.py # CPU video classification demo (HF V-JEPA 2 + SSv2) +├── pyproject.toml +├── src/ +│ ├── __init__.py +│ └── modeling_jepa21.py # Self-contained encoder (3D RoPE, Conv3d, NKI flash) +└── test/ + ├── __init__.py + ├── unit/ + │ ├── __init__.py + │ └── test_encoder.py # CPU-only: construction, forward, components + └── integration/ + ├── __init__.py + └── test_model.py # Neuron: trace, accuracy, ViT-B/L +``` + +## Testing + +### How to run the test suite + +**Unit tests (CPU only, no Neuron device needed):** + +```bash +cd contrib/models/jepa-2-1/ +pytest test/unit/ -v +``` + +Expected: **14/14 PASS** (construction: 4, forward: 6, components: 4) + +**Integration tests (needs Neuron hardware, trn2.3xlarge):** + +```bash +cd contrib/models/jepa-2-1/ +pytest test/integration/test_model.py -v +``` + +Expected: **4/4 PASS** (trace ViT-B image/video, Neuron vs CPU accuracy, trace ViT-L image) + +### Accuracy validation + +Neuron output is validated against CPU reference using `neuron_allclose`: + +```python +from torch_neuronx.testing.validation import neuron_allclose +result = neuron_allclose(neuron_output, cpu_output, rtol=0.01, atol=1e-5) +assert result.allclose +``` + +Cosine similarity between BF16 Neuron output and FP32 CPU reference: +- ViT-B: 0.9998 +- ViT-L: 0.9999 +- ViT-g: 0.9999 (image), 1.0001 (video) +- ViT-G: 0.9998 (image) + +## Test Results + +### Unit Tests (CPU) + +| Test Module | Tests | Status | +|-------------|-------|--------| +| test_encoder.py — Construction | 4 | 4/4 PASS | +| test_encoder.py — Forward | 6 | 6/6 PASS | +| test_encoder.py — Components | 4 | 4/4 PASS | +| **Total** | **14** | **14/14 PASS** | + +### Integration Tests (trn2.3xlarge, 2 NeuronCores) + +| 
Test | Status | Notes | +|------|--------|-------| +| Trace ViT-B image (1 frame) | PASS | Output shape (1, 576, 768) | +| Trace ViT-B video (16 frames) | PASS | Output shape (1, 4608, 768) | +| Neuron vs CPU accuracy (ViT-B) | PASS | `neuron_allclose` rtol=0.01 | +| Trace ViT-L image (1 frame) | PASS | Output shape (1, 576, 1024) | + +### Accuracy (BF16 Neuron vs FP32 CPU, pretrained weights) + +| Model | Input | Cosine Similarity | Status | +|-------|-------|-------------------|--------| +| ViT-B (86M) | Image (1 frame) | 0.9999 | PASS | +| ViT-B (86M) | Video (16 frames) | 1.0000 | PASS | +| ViT-L (300M) | Image (1 frame) | 0.9999 | PASS | +| ViT-L (300M) | Video (16 frames) | 1.0002 | PASS | +| ViT-g (1.01B) | Image (1 frame) | 0.9999 | PASS | +| ViT-g (1.01B) | Video (16 frames) | 1.0001 | PASS | +| ViT-G (1.8B) | Image (1 frame) | 0.9998 | PASS | + +## Performance Benchmarks + +**Pretrained weights, BF16, 384×384.** ViT-B/L on trn2.3xlarge; ViT-g/G on trn2.48xlarge. 100 timed iterations after 10 warmup. 
+ +### Single NeuronCore Latency (batch=1) + +| Model | Input | Median (ms) | p5 (ms) | p95 (ms) | +|-------|-------|-------------|---------|----------| +| ViT-B (86M) | Image (1 frame) | 4.4 | 4.4 | 4.5 | +| ViT-B (86M) | Video (16 frames) | 247.4 | 247.3 | 248.3 | +| ViT-L (300M) | Image (1 frame) | 11.6 | 11.6 | 11.7 | +| ViT-L (300M) | Video (16 frames) | 741.8 | 741.5 | 742.5 | +| ViT-g (1.01B) | Image (1 frame) | 28.0 | 27.9 | 28.1 | +| ViT-g (1.01B) | Video (16 frames) | 1029.5 | 1029.4 | 1029.7 | +| ViT-G (1.8B) | Image (1 frame) | 49.8 | 49.8 | 49.9 | + +### DataParallel Throughput (2 NeuronCores) + +| Model | Input | Per-clip Latency | Throughput | +|-------|-------|-----------------|------------| +| ViT-B (86M) | Image | 5.2 ms | 383 clips/sec | +| ViT-B (86M) | Video (16f) | 249.5 ms | 8.0 clips/sec | +| ViT-L (300M) | Image | 12.5 ms | 160 clips/sec | +| ViT-L (300M) | Video (16f) | 744.1 ms | 2.7 clips/sec | +| ViT-g (1.01B) | Image | 29.1 ms | 68.8 clips/sec | +| ViT-g (1.01B) | Video (16f) | 1032.1 ms | 1.9 clips/sec | +| ViT-G (1.8B) | Image | 51.1 ms | 39.1 clips/sec | + +### Real-Time Video Processing (16 frames @ 30fps = 0.53s of video) + +| Model | Single NC | DataParallel (2 NCs) | +|-------|-----------|----------------------| +| ViT-B | 2.1x real-time | 4.3x real-time | +| ViT-L | 0.7x real-time | 1.4x real-time | +| ViT-g | 0.5x real-time | 1.0x real-time | + +## Compatibility + +Tested with: +* **Neuron SDK Version(s):** 2.28 +* **Instance Type(s):** trn2.3xlarge, trn2.48xlarge +* **PyTorch Version:** 2.9.0 +* **Python Version:** 3.12 +* **neuronx-cc Version:** 2.24.5133 +* **NxDI Version:** 0.9.17334 +* **torch-neuronx Version:** 2.9.0 + +| Instance | NeuronCores | Status | Notes | +|----------|-------------|--------|-------| +| trn2.3xlarge | 2 | **PASS** | ViT-B and ViT-L compiled and benchmarked | +| trn2.48xlarge | 64 | **PASS** | ViT-g and ViT-G compiled and benchmarked (ViT-G video exceeds graph limit) | + +### Minimum Requirements 
(ViT-L) + +| Resource | Requirement | +|----------|------------| +| HBM | 96 GB (1 Neuron device) | +| System RAM | 124 GB (compilation) | +| Instance | trn2.3xlarge | +| Compiled model | 1.1 GB (.pt file) | + +## Additional Information + +### Key Porting Challenges +1. **Self-contained port:** All upstream imports from Meta's `vjepa2` repo replaced with inline implementations (~700 lines). No runtime dependency on the upstream repo. +2. **SDPA not supported:** `F.scaled_dot_product_attention` is not supported by `torch_neuronx.trace()`. Replaced with manual `Q @ K^T * scale → softmax → @ V` path. +3. **3D RoPE with `repeat_interleave`:** V-JEPA 2.1 uses `repeat_interleave` (not `repeat`) for RoPE frequency expansion. Compiles natively on Neuron — no workaround needed. +4. **Conv3d tubelet embedding:** 3D convolution compiles natively. No decomposition into 2D convolutions needed. +5. **BF16 softmax dtype:** Softmax promotes BF16→FP32 on CPU, causing dtype mismatch with V tensor. Fixed with `.to(v.dtype)` after softmax. +6. **NKI flash attention integration:** Integrated `attention_isa_kernel` with correct tensor layouts. Works correctly but ~80% slower at 4,608 tokens (designed for 16K+). Reserved for 64-frame inference. +7. **Modular compilation markers:** Added `ModuleMarkerStartWrapper`/`EndWrapper` from NxDI, but `torch_neuronx.trace()` does not respect them for graph splitting. They are only respected by `parallel_model_trace` from NxD. 
+ +### Known Limitations +* ViT-g (1B) and ViT-G (1.8B) require trn2.48xlarge for compilation (>130GB host RAM needed); compiled models run on any trn2 instance +* ViT-G video (16 frames) exceeds neuronx-cc's 10M instruction limit (17.8M instructions); requires `parallel_model_trace` to split across NeuronCores +* `use_sdpa=False` is required (SDPA not supported by `torch_neuronx.trace()`) +* NKI flash attention is slower than compiler-generated attention at 4,608 tokens (16 frames) +* Modular compilation markers are not respected by `torch_neuronx.trace()` — need `parallel_model_trace` for graph splitting +* Not a causal LM — no vLLM integration, no KV cache, no token generation +* Pretrained weight download requires network access to `dl.fbaipublicfiles.com` + +### Future Work +* ViT-G video (16 frames) via `parallel_model_trace` to split the 17.8M-instruction graph across NeuronCores +* 64-frame inference (18,432 tokens) where NKI flash attention becomes beneficial +* Downstream tasks: attentive pooler for classification, predictor for action anticipation + +By submitting this PR, I confirm that: +* I have read and followed the contributing guidelines +* This is a community contribution and may have limited testing compared to officially-supported models +* The code follows best practices and is well-documented +* All required components listed above are included diff --git a/contrib/models/jepa-2-1/README.md b/contrib/models/jepa-2-1/README.md index e4bbadce..ded1e7be 100644 --- a/contrib/models/jepa-2-1/README.md +++ b/contrib/models/jepa-2-1/README.md @@ -1,34 +1,136 @@ -# V-JEPA 2.1 on AWS Trainium +# Contrib Model: V-JEPA 2.1 -V-JEPA 2.1 (Video Joint-Embedding Predictive Architecture) is Meta's self-supervised video foundation model. It learns visual representations by predicting masked video segments in a learned representation space, rather than pixel space. V-JEPA 2.1 extends V-JEPA 2 with knowledge distillation from a ViT-Gigantic teacher. 
+V-JEPA 2.1 (Video Joint-Embedding Predictive Architecture) is Meta's self-supervised video foundation model. It learns visual representations by predicting masked video segments in a learned representation space, rather than pixel space. This is a vision encoder — not a causal language model — compiled for inference on AWS Trainium via `torch_neuronx.trace()`. -This port targets inference on AWS Trainium (trn2) using `torch_neuronx.trace()`. +## Model Family -## Model Information +| Model | Source | Params | Instance | Neuron Status | +|-------|--------|--------|----------|---------------| +| **ViT-B/16** | [facebookresearch/vjepa2](https://github.com/facebookresearch/vjepa2) | 86M | trn2.3xlarge | ✅ Compiled & benchmarked | +| **ViT-L/16** | [facebookresearch/vjepa2](https://github.com/facebookresearch/vjepa2) | 300M | trn2.3xlarge | ✅ Compiled & benchmarked | +| **ViT-g/16** | [facebookresearch/vjepa2](https://github.com/facebookresearch/vjepa2) | 1.01B | trn2.48xlarge | ✅ Compiled & benchmarked | +| **ViT-G/16** | [facebookresearch/vjepa2](https://github.com/facebookresearch/vjepa2) | 1.8B | trn2.48xlarge | ✅ Image compiled; ❌ Video exceeds graph limit | -- **Source**: [facebookresearch/vjepa2](https://github.com/facebookresearch/vjepa2) -- **Paper**: [arxiv.org/abs/2506.09985](https://arxiv.org/abs/2506.09985) -- **License**: MIT +**License:** [MIT](https://github.com/facebookresearch/vjepa2/blob/main/LICENSE) +**Paper:** [arxiv.org/abs/2506.09985](https://arxiv.org/abs/2506.09985) -| Model | Params | Depth | Heads | Resolution | Neuron Status | -|-------|--------|-------|-------|------------|---------------| -| ViT-B/16 | 86M | 12 | 12 | 384 | ✅ Compiled & benchmarked | -| ViT-L/16 | 300M | 24 | 16 | 384 | ✅ Compiled & benchmarked | -| ViT-g/16 | 1B | 40 | 22 | 384 | ❌ Host OOM during compilation (needs >124GB RAM) | -| ViT-G/16 | 1.8B | 48 | 26 | 384 | ❌ Not attempted (blocked by ViT-g) | +## Architecture Details -## Benchmark Results +| Feature | Value | 
+|---------|-------| +| Type | Vision Transformer encoder (not a causal LM) | +| Patch Embedding | Conv3d tubelets (patch_size=16, tubelet_size=2) | +| Position Encoding | 3D RoPE (separate depth/height/width rotations) | +| Attention | Bidirectional multi-head attention (no KV cache) | +| Normalization | LayerNorm | +| Activation | GELU (encoder), SiLU (predictor) | +| Hierarchical Output | Normed features from 4 intermediate layers | +| Modality Embeddings | Separate learned embeddings for image vs video | +| Image Path | Conv3d with tubelet_size=1 for single-frame input | -trn2.3xlarge, BF16, 16 frames, 384×384, `torch_neuronx.trace()` with `--auto-cast none`: +### Model Configurations -| Model | Single-core Latency | Cosine Sim vs CPU | DataParallel (2 NCs) | Throughput | -|-------|--------------------|--------------------|----------------------|------------| -| ViT-B (86M) | 164.5 ms | 0.9998 | 83.2 ms/clip | 12.0 clips/sec | -| ViT-L (300M) | 437.4 ms | 0.9999 | 219.8 ms/clip | 4.5 clips/sec | +| Arch | embed_dim | depth | num_heads | head_dim | mlp_ratio | Tokens (16f, 384²) | +|------|-----------|-------|-----------|----------|-----------|---------------------| +| vit_base | 768 | 12 | 12 | 64 | 4.0 | 4,608 | +| vit_large | 1024 | 24 | 16 | 64 | 4.0 | 4,608 | +| vit_giant | 1408 | 40 | 22 | 64 | 48/11 | 4,608 | +| vit_gigantic | 1664 | 48 | 26 | 64 | 64/13 | 4,608 | -Real-time video processing (16 frames @ 30fps = 0.53s of video): -- ViT-B: **3.2x real-time** (single-core), **6.4x real-time** (DataParallel) -- ViT-L: **1.2x real-time** (single-core), **2.4x real-time** (DataParallel) +### Unique Architecture Features + +- **3D RoPE:** Head dimension split into depth/height/width slices (d_dim=h_dim=w_dim=20 for head_dim=64, 4 dims unrotated). Uses `repeat_interleave` layout. +- **Hierarchical output:** Normed features from intermediate layers (e.g., [5,11,17,23] for depth=24). Inference returns only the last layer's normed output by default. 
+- **Modality embeddings:** Separate learned embeddings added after patch embedding for image vs video inputs. +- **interpolate_rope:** Scales RoPE positions for resolution flexibility beyond the pretrained grid size. +- **Pretrained weight loading:** Checkpoints loaded via `torch.hub.load_state_dict_from_url` from Meta's servers. Distilled models (ViT-B, ViT-L) use key `ema_encoder`; self-supervised (ViT-g, ViT-G) use `target_encoder`. + +## Test Results + +### Unit Tests (CPU) + +| Test Module | Tests | Status | +|-------------|-------|--------| +| test_encoder.py — Construction | 4 | 4/4 PASS | +| test_encoder.py — Forward | 6 | 6/6 PASS | +| test_encoder.py — Components | 4 | 4/4 PASS | +| **Total** | **14** | **14/14 PASS** | + +### Integration Tests (trn2.3xlarge, 2 NeuronCores) + +| Test | Status | Notes | +|------|--------|-------| +| Trace ViT-B image (1 frame) | PASS | Output shape (1, 576, 768) | +| Trace ViT-B video (16 frames) | PASS | Output shape (1, 4608, 768) | +| Neuron vs CPU accuracy (ViT-B) | PASS | `neuron_allclose` rtol=0.01 | +| Trace ViT-L image (1 frame) | PASS | Output shape (1, 576, 1024) | + +### Pretrained Weight Validation (BF16 Neuron vs FP32 CPU) + +Validated with official Meta pretrained weights downloaded from `dl.fbaipublicfiles.com/vjepa2/`. Cosine similarity measured between BF16 Neuron output and FP32 CPU reference on identical inputs (seed=42). Cosine similarity is mathematically bounded by 1.0; reported values marginally above 1.0 (e.g., 1.0002) are numerical rounding artifacts, not an error. + +| Model | Input | Cosine Similarity | Status | +|-------|-------|-------------------|--------| +| ViT-B (86M) | Image (1×3×1×384×384) | 0.9999 | PASS | +| ViT-B (86M) | Video (1×3×16×384×384) | 1.0000 | PASS | +| ViT-L (300M) | Image (1×3×1×384×384) | 0.9999 | PASS | +| ViT-L (300M) | Video (1×3×16×384×384) | 1.0002 | PASS | +| ViT-g (1.01B) | Image (1×3×1×384×384) | 0.9999 | PASS | +| ViT-g (1.01B) | Video (1×3×16×384×384) | 1.0001 | PASS | +| ViT-G (1.8B) | Image (1×3×1×384×384) | 0.9998 | PASS | + +No NaN or Inf values in any output. 
Feature statistics (mean, std, norm) match closely between CPU and Neuron. + +## Performance Benchmarks + +**Pretrained weights, BF16, 384×384, `torch_neuronx.trace()` with `--auto-cast none`.** ViT-B/L on trn2.3xlarge (2 NeuronCores); ViT-g/G on trn2.48xlarge (using 2 of its 64 NeuronCores). All measurements from 100 timed iterations after 10 warmup runs. + +### Single NeuronCore Latency (batch=1) + +| Model | Input | Median (ms) | p5 (ms) | p95 (ms) | +|-------|-------|-------------|---------|----------| +| ViT-B (86M) | Image (1 frame) | 4.4 | 4.4 | 4.5 | +| ViT-B (86M) | Video (16 frames) | 247.4 | 247.3 | 248.3 | +| ViT-L (300M) | Image (1 frame) | 11.6 | 11.6 | 11.7 | +| ViT-L (300M) | Video (16 frames) | 741.8 | 741.5 | 742.5 | +| ViT-g (1.01B) | Image (1 frame) | 28.0 | 27.9 | 28.1 | +| ViT-g (1.01B) | Video (16 frames) | 1029.5 | 1029.4 | 1029.7 | +| ViT-G (1.8B) | Image (1 frame) | 49.8 | 49.8 | 49.9 | + +Sub-millisecond variance — deterministic Neuron execution. + +### DataParallel Throughput (2 NeuronCores) + +| Model | Input | Per-clip Latency | Throughput | +|-------|-------|-----------------|------------| +| ViT-B (86M) | Image | 5.2 ms | 383 clips/sec | +| ViT-B (86M) | Video (16f) | 249.5 ms | 8.0 clips/sec | +| ViT-L (300M) | Image | 12.5 ms | 160 clips/sec | +| ViT-L (300M) | Video (16f) | 744.1 ms | 2.7 clips/sec | +| ViT-g (1.01B) | Image | 29.1 ms | 68.8 clips/sec | +| ViT-g (1.01B) | Video (16f) | 1032.1 ms | 1.9 clips/sec | +| ViT-G (1.8B) | Image | 51.1 ms | 39.1 clips/sec | + +### Real-Time Video Processing (16 frames @ 30fps = 0.53s of video) + +| Model | Single NC | DataParallel (2 NCs) | +|-------|-----------|----------------------| +| ViT-B | 2.1x real-time | 4.3x real-time | +| ViT-L | 0.7x real-time | 1.4x real-time | +| ViT-g | 0.5x real-time | 1.0x real-time | + +### Timing Summary + +| Operation | Time | +|-----------|------| +| ViT-B compilation | ~8 min | +| ViT-L compilation | ~18 min | +| ViT-g compilation (image) | ~7 min | +| ViT-g 
compilation (video) | ~51 min | +| ViT-G compilation (image) | ~11 min | +| ViT-B video inference (single NC) | 247.4 ms | +| ViT-L video inference (single NC) | 741.8 ms | +| ViT-g video inference (single NC) | 1029.5 ms | ## Usage @@ -52,7 +154,10 @@ with torch.no_grad(): import torch, torch_neuronx from src.modeling_jepa21 import build_vjepa21_encoder -encoder = build_vjepa21_encoder(arch="vit_large", img_size=384, num_frames=16, use_sdpa=False) +encoder = build_vjepa21_encoder( + arch="vit_large", img_size=384, num_frames=16, + use_sdpa=False, pretrained=False, +) encoder.eval().bfloat16() x = torch.randn(1, 3, 16, 384, 384, dtype=torch.bfloat16) @@ -72,25 +177,143 @@ batch = torch.randn(4, 3, 16, 384, 384, dtype=torch.bfloat16) output = model_dp(batch) # distributes across 2 NeuronCores ``` -## Key Requirements for Neuron Compilation +### With Pretrained Weights + +```python +encoder = build_vjepa21_encoder(arch="vit_large", img_size=384, num_frames=16, pretrained=True) +# Downloads from https://dl.fbaipublicfiles.com/vjepa2/vjepa2_1_vitl_dist_vitG_384.pt +``` + +## Demos + +### Neuron Smoke Test (`demo_neuron.py`) + +Runs pretrained ViT-B on both CPU (FP32) and Neuron (BF16), compares feature embeddings. Serves as a quick validation that the Neuron port is working correctly. No external dependencies beyond torch-neuronx. + +```bash +# On a Neuron instance (trn2/inf2): +source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate +python demo_neuron.py # synthetic video (no extra deps) +python demo_neuron.py path/to/video.mp4 # your own video (needs decord, pillow) +``` + +Expected output: +``` +Using synthetic video (moving circle, no dependencies needed) +Input shape: torch.Size([1, 3, 16, 384, 384]) + +Loading pretrained ViT-B (CPU, FP32)... +CPU output: shape=torch.Size([1, 4608, 768]), norm=2042.1 + +Tracing for Neuron (BF16)... 
+Compilation: 416.1s +Neuron output: shape=torch.Size([1, 4608, 768]), norm=2046.2 +Latency: 248.0ms + +Cosine similarity (CPU FP32 vs Neuron BF16): 1.000502 [PASS] +``` + +### Video Classification (`demo_classify.py`) + +Classifies a video using a finetuned V-JEPA 2 model on Something-Something v2 (174 action classes). Runs on CPU — no Neuron hardware needed. + +```bash +pip install transformers accelerate torchvision decord +python demo_classify.py # Big Buck Bunny sample (CC-BY-3.0) +python demo_classify.py path/to/video.mp4 # your own video +``` + +Note: This demo uses the HuggingFace `VJEPA2ForVideoClassification` model (V-JEPA 2, not 2.1) to demonstrate what the encoder features can do. The Neuron port (`modeling_jepa21.py`) is the V-JEPA 2.1 encoder only. + +## Caveats + +1. **`use_sdpa=False` required** — `F.scaled_dot_product_attention` is not supported by `torch_neuronx.trace()`. Must use manual attention path. + +2. **BF16 model and inputs required** — Cast model with `.bfloat16()` and use BF16 input tensors. Use `--auto-cast none` compiler flag. -- `use_sdpa=False` — SDPA is not supported by `torch_neuronx.trace()` -- `.bfloat16()` model and inputs — trn2 NeuronCores are optimized for BF16 -- `--auto-cast none` — avoids unpredictable compiler auto-cast behavior -- Conv3d, `torch.arange`, `repeat_interleave` all compile natively +3. **ViT-g/ViT-G require trn2.48xlarge for compilation** — neuronx-cc uses >130GB host RAM for 40+ layer graphs, exceeding trn2.3xlarge's 124GB. Compiled models run on any trn2 instance (inference uses <30GB host RAM). -## Known Issues +4. **NKI flash attention slower at short sequences** — The `attention_isa_kernel` is optimized for 16K+ tokens. At 4,608 tokens (16 frames), it's ~80% slower than compiler-generated attention. Use `use_nki_flash=False` for 16-frame inference. -- ViT-g (1B) and ViT-G (1.8B) cannot compile on trn2.3xlarge (124GB RAM) — the neuronx-cc compiler OOMs on the 40+ layer graph. 
NxDI `ModuleMarkerStartWrapper`/`EndWrapper` markers were added but `torch_neuronx.trace()` does not respect them for graph splitting. Next step: use `parallel_model_trace` from NxD, or compile on a larger instance. -- NKI flash attention (`attention_isa_kernel`) is integrated but slower than compiler-generated attention at 4608 tokens. Reserved for 64-frame inference (18K tokens) where it will be beneficial. -- BF16 softmax promotes to FP32 on CPU; `.to(v.dtype)` cast added after softmax. +5. **BF16 softmax dtype** — Softmax promotes BF16→FP32 on CPU, causing dtype mismatch. Fixed with `.to(v.dtype)` after softmax (already handled in code). + +6. **Not a causal LM** — This is a vision encoder. It does not use NxDI's `NeuronBaseForCausalLM`, KV cache, token generation, or vLLM integration. Compilation uses `torch_neuronx.trace()` directly. + +7. **ViT-G video exceeds single-graph instruction limit** — ViT-G with 16 frames generates 17.8M instructions, exceeding neuronx-cc's 10M limit (`NCC_EXTP004`). Requires `parallel_model_trace` to split across NeuronCores. ViT-G image (1 frame) compiles fine. 
+ +## Compatibility Matrix + +| Instance | NeuronCores | Status | Notes | +|----------|-------------|--------|-------| +| trn2.3xlarge | 2 | **PASS** | ViT-B and ViT-L | +| trn2.48xlarge | 64 | **PASS** | ViT-g and ViT-G (compilation requires >130GB RAM) | + +### Minimum Requirements (ViT-L) + +| Resource | Requirement | +|----------|------------| +| HBM | 96 GB (1 Neuron device) | +| System RAM | 124 GB (compilation) | +| Instance | trn2.3xlarge | +| Compiled model size | 1.1 GB (ViT-L .pt file) | + +### SDK Configuration + +| Component | Version | +|-----------|---------| +| torch-neuronx | 2.9.0 | +| neuronx-cc | 2.24.5133 | +| NxDI | 0.9.17334 | +| Python | 3.12 | +| Venv | `/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/` | ## Testing +### Unit Tests (CPU only, no device needed) + ```bash -# CPU-only tests -pytest test/ -v +cd contrib/models/jepa-2-1/ +pytest test/unit/ -v +``` + +Tests: construction (4), forward pass (6), components (4) = **14 tests**. + +### Integration Tests (needs Neuron hardware) -# On Trainium instance +```bash +cd contrib/models/jepa-2-1/ pytest test/integration/test_model.py -v ``` + +Tests: trace ViT-B image/video, Neuron vs CPU accuracy, trace ViT-L image = **4 tests**. + +## Key Porting Challenges + +1. **Self-contained port:** All upstream imports from `vjepa2` replaced with inline implementations. No dependency on the upstream repo at runtime. + +2. **3D RoPE with `repeat_interleave`:** V-JEPA 2.1 uses `repeat_interleave` (not `repeat`) for RoPE frequency expansion. This compiles natively on Neuron. + +3. **Conv3d tubelet embedding:** 3D convolution for video patch embedding compiles natively — no decomposition into 2D convolutions needed. + +4. **Modular compilation markers:** `ModuleMarkerStartWrapper`/`EndWrapper` from NxDI added for future graph splitting. Currently only respected by `parallel_model_trace`, not `torch_neuronx.trace()`. + +5. 
**NKI flash attention integration:** Integrated `attention_isa_kernel` with correct tensor layouts (q/k: `(B*H, d, seqlen)`, v: `(B*H, seqlen, d)`). Works correctly but slower at short sequences. + +## Example Checkpoints + +Pretrained weights are downloaded automatically when `pretrained=True`: + +| Arch | Checkpoint | Size | Key | +|------|-----------|------|-----| +| vit_base | `vjepa2_1_vitb_dist_vitG_384.pt` | ~350 MB | `ema_encoder` | +| vit_large | `vjepa2_1_vitl_dist_vitG_384.pt` | ~1.2 GB | `ema_encoder` | +| vit_giant | `vjepa2_1_vitg_384.pt` | ~4 GB | `target_encoder` | +| vit_gigantic | `vjepa2_1_vitG_384.pt` | ~7 GB | `target_encoder` | + +Source: `https://dl.fbaipublicfiles.com/vjepa2/` + +## Maintainer + +Community contribution + +**Last Updated:** 2026-04-29 diff --git a/contrib/models/jepa-2-1/demo_classify.py b/contrib/models/jepa-2-1/demo_classify.py new file mode 100644 index 00000000..498caa57 --- /dev/null +++ b/contrib/models/jepa-2-1/demo_classify.py @@ -0,0 +1,73 @@ +""" +V-JEPA 2 video classification demo. + +Classifies a video using the finetuned V-JEPA 2 ViT-L model +on Something-Something v2 (174 action classes). + +Runs on CPU. No Neuron hardware needed. 
+ +Usage: + pip install transformers accelerate torchvision decord + python demo_classify.py # Big Buck Bunny sample clip + python demo_classify.py path/to/video.mp4 # your own video + python demo_classify.py photo.jpg # static image (repeated as frames) +""" + +import sys +import os +import urllib.request +import numpy as np +import torch +from transformers import AutoVideoProcessor, VJEPA2ForVideoClassification + +model_id = "facebook/vjepa2-vitl-fpc16-256-ssv2" + + +def load_video_frames(source): + if source.endswith((".jpg", ".jpeg", ".png")): + from PIL import Image + img = np.array(Image.open(source).convert("RGB")) + return np.stack([img] * 16) + + from decord import VideoReader + vr = VideoReader(source) + total = len(vr) + indices = np.linspace(0, total - 1, 16, dtype=int) + return vr.get_batch(indices).asnumpy() # (T, H, W, C) + + +def main(): + # Get video source + if len(sys.argv) > 1: + source = sys.argv[1] + else: + # Big Buck Bunny — CC-BY-3.0, Blender Foundation + source = "/tmp/bigbuckbunny_10s.mp4" + if not os.path.exists(source): + url = "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/360/Big_Buck_Bunny_360_10s_1MB.mp4" + print("Downloading Big Buck Bunny sample clip (CC-BY-3.0, Blender Foundation)...") + urllib.request.urlretrieve(url, source) + + print(f"Loading model {model_id}...") + model = VJEPA2ForVideoClassification.from_pretrained(model_id) + processor = AutoVideoProcessor.from_pretrained(model_id) + model.eval() + + print(f"Loading video: {source}") + video = load_video_frames(source) + print(f"Video shape: {video.shape}") + + inputs = processor(video, return_tensors="pt") + with torch.no_grad(): + logits = model(**inputs).logits + + print("\nTop 10 predictions (Something-Something v2 action classes):") + probs = torch.softmax(logits, dim=-1) + top10 = probs.topk(10) + for i, (idx, prob) in enumerate(zip(top10.indices[0], top10.values[0])): + label = model.config.id2label[idx.item()] + print(f" {i+1:2d}.
{prob:.1%} {label}") + + +if __name__ == "__main__": + main() diff --git a/contrib/models/jepa-2-1/demo_neuron.py b/contrib/models/jepa-2-1/demo_neuron.py new file mode 100644 index 00000000..02fe0465 --- /dev/null +++ b/contrib/models/jepa-2-1/demo_neuron.py @@ -0,0 +1,111 @@ +""" +V-JEPA 2.1 Neuron smoke test. + +Runs a video through the pretrained ViT-B encoder on both CPU (FP32) +and Neuron (BF16), then compares the feature embeddings. + +Requires: trn2/inf2 instance with torch-neuronx. + +Usage (on Neuron instance): + source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + python demo_neuron.py # synthetic video (no deps) + python demo_neuron.py path/to/video.mp4 # your own video (needs decord, pillow) +""" + +import sys +import os +import time + +import numpy as np +import torch +import torch_neuronx + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src")) +from modeling_jepa21 import build_vjepa21_encoder + + +def make_synthetic_video(num_frames=16, size=384): + """Generate synthetic video: moving circle on gradient background. Returns (1, 3, T, H, W).""" + frames = [] + for i in range(num_frames): + frame = np.zeros((size, size, 3), dtype=np.float32) + frame[:, :, 2] = np.linspace(0, 0.3, size).reshape(1, -1) + cx = int(size * (0.2 + 0.6 * i / num_frames)) + y, x = np.ogrid[:size, :size] + mask = ((x - cx)**2 + (y - size // 2)**2) < (size // 10)**2 + frame[mask] = 1.0 + frames.append(frame) + # (T, H, W, 3) -> (3, T, H, W) + video = torch.from_numpy(np.stack(frames)).permute(3, 0, 1, 2) + return video.unsqueeze(0) # (1, 3, T, H, W) + + +def load_video_tensor(path, num_frames=16, size=384): + """Load video file as (1, 3, T, H, W) tensor. 
Needs decord and pillow.""" + from decord import VideoReader + from PIL import Image + vr = VideoReader(path) + indices = np.linspace(0, len(vr) - 1, num_frames, dtype=int) + frames = vr.get_batch(indices).asnumpy() + processed = [] + for f in frames: + img = Image.fromarray(f).resize((size, size)) + processed.append(np.array(img, dtype=np.float32) / 255.0) + video = torch.from_numpy(np.stack(processed)).permute(3, 0, 1, 2) + return video.unsqueeze(0) + + +def main(): + if len(sys.argv) > 1 and os.path.exists(sys.argv[1]): + print(f"Video: {sys.argv[1]}") + video = load_video_tensor(sys.argv[1]) + else: + print("Using synthetic video (moving circle, no dependencies needed)") + video = make_synthetic_video() + print(f"Input shape: {video.shape}") + + # --- CPU reference (FP32) --- + print("\nLoading pretrained ViT-B (CPU, FP32)...") + encoder = build_vjepa21_encoder( + arch="vit_base", img_size=384, num_frames=16, + pretrained=True, use_sdpa=False, + ) + encoder.eval() + + with torch.no_grad(): + cpu_out = encoder(video) + print(f"CPU output: shape={cpu_out.shape}, norm={cpu_out.float().norm():.1f}") + + # --- Neuron (BF16) --- + print("\nTracing for Neuron (BF16)...") + encoder.bfloat16() + video_bf16 = video.bfloat16() + + t0 = time.time() + traced = torch_neuronx.trace(encoder, video_bf16, compiler_args=["--auto-cast", "none"]) + compile_time = time.time() - t0 + print(f"Compilation: {compile_time:.1f}s") + + # Warmup + for _ in range(3): + traced(video_bf16) + + # Timed run + t0 = time.time() + neuron_out = traced(video_bf16) + latency = (time.time() - t0) * 1000 + print(f"Neuron output: shape={neuron_out.shape}, norm={neuron_out.float().norm():.1f}") + print(f"Latency: {latency:.1f}ms") + + # --- Compare --- + cos_sim = torch.nn.functional.cosine_similarity( + cpu_out.float().flatten().unsqueeze(0), + neuron_out.float().flatten().unsqueeze(0), + ).item() + + status = "PASS" if cos_sim > 0.999 else "FAIL" + print(f"\nCosine similarity (CPU FP32 vs Neuron BF16): 
{cos_sim:.6f} [{status}]") + + +if __name__ == "__main__": + main() diff --git a/contrib/models/jepa-2-1/src/.___init__.py b/contrib/models/jepa-2-1/src/.___init__.py deleted file mode 100644 index 5177851bca3e9f09f1906e0acb8266eb0b3b9849..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 163 zcmZQz6=P>$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K$Vqox1Ojhs@R)|o50+1L3ClDI}aUl?c_=|y<2;dkJ5(HHS(lG;wxzV&S oBE&_L^K 0.999 else "WARN" if cos > 0.99 else "FAIL" + print(f" {r['arch']} / {inp['name']}: cos_sim={cos:.6f} latency={lat:.1f}ms [{status}]") + + # Save results + out_file = "validation_results.json" + with open(out_file, "w") as f: + json.dump(all_results, f, indent=2) + print(f"\nFull results saved to {out_file}") + + +if __name__ == "__main__": + main() From bec3892ed3dea04e340e348fcdb941a3eebff87a Mon Sep 17 00:00:00 2001 From: Daniel Stair Date: Thu, 7 May 2026 17:20:20 +0000 Subject: [PATCH 4/5] remove & consolidate files --- contrib/models/jepa-2-1/AGENT.md | 74 +++--- contrib/models/jepa-2-1/PLAN_trn2_48xl.md | 93 ------- contrib/models/jepa-2-1/PR_README.md | 236 ------------------ contrib/models/jepa-2-1/README.md | 40 ++- contrib/models/jepa-2-1/demo_neuron.py | 111 -------- .../jepa-2-1/{ => 
examples}/demo_classify.py | 0 .../test/integration/test_pretrained_smoke.py | 105 ++++++++ 7 files changed, 183 insertions(+), 476 deletions(-) delete mode 100644 contrib/models/jepa-2-1/PLAN_trn2_48xl.md delete mode 100644 contrib/models/jepa-2-1/PR_README.md delete mode 100644 contrib/models/jepa-2-1/demo_neuron.py rename contrib/models/jepa-2-1/{ => examples}/demo_classify.py (100%) create mode 100644 contrib/models/jepa-2-1/test/integration/test_pretrained_smoke.py diff --git a/contrib/models/jepa-2-1/AGENT.md b/contrib/models/jepa-2-1/AGENT.md index 4e6e8195..e7e02892 100644 --- a/contrib/models/jepa-2-1/AGENT.md +++ b/contrib/models/jepa-2-1/AGENT.md @@ -7,23 +7,26 @@ This file is for coding agents working on this model. It documents architecture ``` contrib/models/jepa-2-1/ ├── AGENT.md ← You are here -├── PR_README.md ← PR description (paste into GitHub PR body) ├── README.md ← User-facing documentation -├── pyproject.toml +├── pyproject.toml ← Project config (pytest settings, dependencies) +├── examples/ +│ └── demo_classify.py ← CPU video classification demo (HF V-JEPA 2 + SSv2, no Neuron needed) ├── src/ │ ├── __init__.py ← Exports: build_vjepa21_encoder, VisionTransformer, VisionTransformerPredictor │ └── modeling_jepa21.py ← Self-contained encoder (~700 lines, no upstream imports) └── test/ ├── unit/ - │ └── test_encoder.py ← 14 CPU-only tests (construction, forward, components) + │ └── test_encoder.py ← 14 CPU-only tests (construction, forward, components) └── integration/ - └── test_model.py ← 4 Neuron tests (trace, accuracy, ViT-B/L) + ├── test_model.py ← 4 Neuron tests (trace, accuracy, ViT-B/L, random weights) + └── test_pretrained_smoke.py ← 5 tests: 3 CPU + 2 Neuron (pretrained weight validation) ``` ## Source Code - **Upstream**: [github.com/facebookresearch/vjepa2](https://github.com/facebookresearch/vjepa2) - **Neuron port**: `src/modeling_jepa21.py` — self-contained, no upstream imports +- **Examples**: `examples/demo_classify.py` — 
CPU-only video classification demo using HuggingFace V-JEPA 2 finetuned on SSv2 (174 action classes). Requires `transformers`, `accelerate`, `torchvision`, `decord`. - **Key difference from most NxDI contrib models**: This is a vision encoder, NOT a causal LM. It does not use `NeuronBaseForCausalLM`, KV cache, token generation, or vLLM. Compilation uses `torch_neuronx.trace()` directly. ## Classes and Functions in `modeling_jepa21.py` @@ -169,8 +172,11 @@ trn2.3xlarge has 2 logical NeuronCores → 2x throughput. Scales linearly with b cd contrib/models/jepa-2-1/ pytest test/unit/ -v -# Integration tests (needs Neuron hardware, 4 tests) -pytest test/integration/test_model.py -v +# Integration tests (needs Neuron hardware, 4 + 5 tests) +pytest test/integration/ -v + +# Just the pretrained smoke tests (3 CPU + 2 Neuron) +pytest test/integration/test_pretrained_smoke.py -v ``` ### What the tests cover @@ -183,43 +189,42 @@ pytest test/integration/test_model.py -v **Integration tests (`test/integration/test_model.py`):** - `TestNeuronTrace`: trace ViT-B image, trace ViT-B video, Neuron vs CPU accuracy via `neuron_allclose` (3 tests) - `TestNeuronTraceVitLarge`: trace ViT-L image (1 test) +- Uses random weights (fast compilation, no download) + +**Pretrained smoke tests (`test/integration/test_pretrained_smoke.py`):** +- `TestPretrainedCPU`: loads pretrained ViT-B, checks output shape, checks no NaN/Inf (3 tests) +- `TestPretrainedNeuron`: compiles pretrained ViT-B, validates cosine similarity > 0.999 vs CPU, checks no NaN/Inf (2 tests) +- Downloads ~350MB pretrained weights on first run +- Neuron tests take ~14 min (two compilations of ViT-B 16-frame) ### Test gaps (future work) -- No pretrained weight validation (tests use random weights) -- No ViT-g/ViT-G tests (blocked by compilation) +- No ViT-g/ViT-G tests (blocked by compilation on trn2.3xlarge) - No 64-frame tests - No predictor tests -## Instance Details +## Instance Requirements -- **Type**: trn2.3xlarge 
(persistent spot) in sa-east-1 -- **Instance ID**: i-0cae7b2ac61807cf9 -- **SSH**: `ssh -i ~/.ssh/trn2-sa-east-1.pem ubuntu@52.67.239.128` -- **Hardware**: 1 Neuron device, 2 logical NeuronCores, 96 GB HBM, 124 GB system RAM +- **Minimum for ViT-B/L**: trn2.3xlarge (1 Neuron device, 2 logical NeuronCores, 96 GB HBM, 124 GB system RAM) +- **Required for ViT-g/G compilation**: trn2.48xlarge (>130 GB host RAM needed for compilation; compiled models run on any trn2) - **Neuron SDK**: torch-neuronx 2.9.0, neuronx-cc 2.24.5133, NxDI 0.9.17334 - **Venv**: `/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate` -### Compiled files on instance - -``` -~/jepa-2-1/vjepa21_vitb_16f_384_v2.pt (335M) — ViT-B baseline (best) -~/jepa-2-1/vjepa21_vitl_16f_384.pt (1.1G) — ViT-L baseline (best) -~/jepa-2-1/vjepa21_vitb_nki_16f_384.pt (405M) — ViT-B + NKI flash (slower) -~/jepa-2-1/vjepa21_vitl_nki_16f_384.pt (1.4G) — ViT-L + NKI flash (slower) -``` - ## Workflow ```bash -# Sync local → trn2 -rsync -avz --exclude='__pycache__' --exclude='._*' \ - ~/dev/neuron-docs/neuronx-distributed-inference/contrib/models/jepa-2-1/ \ - -e "ssh -i ~/.ssh/trn2-sa-east-1.pem" ubuntu@52.67.239.128:jepa-2-1/ - -# Run on trn2 -ssh -i ~/.ssh/trn2-sa-east-1.pem ubuntu@52.67.239.128 \ - "cd jepa-2-1 && source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate && python ..." +# Activate the Neuron venv on the instance +. /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + +# Run unit tests (CPU only, no Neuron device needed) +cd contrib/models/jepa-2-1/ +pytest test/unit/ -v + +# Run integration tests (needs Neuron hardware) +pytest test/integration/ -v + +# Run only the pretrained smoke tests +pytest test/integration/test_pretrained_smoke.py -v ``` ## Weight Loading @@ -237,15 +242,14 @@ Checkpoints loaded via `torch.hub.load_state_dict_from_url`. State dict keys pre ### P0 — Needed for production readiness 1. 
**Compile ViT-g (1B) and ViT-G (1.8B)**: Use `parallel_model_trace` from NxD (markers already in code) or compile on trn2.48xlarge (2TB RAM). The modular compilation markers are already inserted. -2. **Validate with pretrained weights**: Current tests use random weights. Need to verify that pretrained models produce meaningful features on Neuron. ### P1 — Valuable additions -3. **64-frame inference**: 18,432 tokens — NKI flash attention should become beneficial here. Need to benchmark. -4. **Downstream tasks**: Attentive pooler for classification, predictor for action anticipation. +2. **64-frame inference**: 18,432 tokens — NKI flash attention should become beneficial here. Need to benchmark. +3. **Downstream tasks**: Attentive pooler for classification, predictor for action anticipation. ### P2 — Nice to have -5. **Tensor parallelism**: For ViT-G on multi-device instances. Would require wrapping with NxD parallel layers. -6. **Dynamic resolution**: Test with non-384 resolutions using `interpolate_rope=True`. +4. **Tensor parallelism**: For ViT-G on multi-device instances. Would require wrapping with NxD parallel layers. +5. **Dynamic resolution**: Test with non-384 resolutions using `interpolate_rope=True`. ## Reference Code in the NxDI Repo diff --git a/contrib/models/jepa-2-1/PLAN_trn2_48xl.md b/contrib/models/jepa-2-1/PLAN_trn2_48xl.md deleted file mode 100644 index 917ac957..00000000 --- a/contrib/models/jepa-2-1/PLAN_trn2_48xl.md +++ /dev/null @@ -1,93 +0,0 @@ -# V-JEPA 2.1 — ViT-g / ViT-G on trn2.48xlarge - -## Objective - -Compile, test, and benchmark the two larger V-JEPA 2.1 model sizes on a trn2.48xlarge instance (2 TB RAM, 64 NeuronCores). These models OOM'd during compilation on trn2.3xlarge (124 GB RAM). 
- -| Model | Params | embed_dim | depth | num_heads | Checkpoint | -|-------|--------|-----------|-------|-----------|------------| -| ViT-g | 1.01B | 1408 | 40 | 22 | `vjepa2_1_vitg_384.pt` (~4 GB) | -| ViT-G | 1.8B | 1664 | 48 | 26 | `vjepa2_1_vitG_384.pt` (~7 GB) | - -## Instance - -- **Instance ID:** `i-09812af3093beb594` -- **Type:** trn2.48xlarge (96 vCPUs, 2 TB RAM, 64 NeuronCores) -- **Region:** us-east-2 -- **AMI:** ami-0a81a0376c52f4d22 -- **Access:** SSM (no SSH) -- **State:** stopped → starting - -## Execution Plan - -### Phase 1: Environment Setup - -1. Start instance, wait for SSM connectivity -2. Copy `contrib/models/jepa-2-1/` to instance via SSM + tar -3. Activate Neuron venv: `source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate` -4. Verify SDK: `neuron-ls`, `pip show torch-neuronx neuronx-cc` - -### Phase 2: ViT-g (1B params) - -5. Compile ViT-g with `pretrained=False` (random weights) — image input (1,3,1,384,384) -6. Compile ViT-g with `pretrained=False` — video input (1,3,16,384,384) -7. Validate accuracy: Neuron BF16 vs CPU FP32 (cosine similarity, `neuron_allclose`) -8. Compile ViT-g with `pretrained=True` — validate pretrained weights -9. Benchmark: single NC latency (100 iterations), DataParallel throughput - -### Phase 3: ViT-G (1.8B params) - -10. Repeat steps 5–9 for ViT-G - -### Phase 4: Results - -11. Collect all metrics, update this file with results -12. 
Stop instance - -## Expected Outputs - -- Compilation time for each model × input type -- Single NeuronCore latency (median, p5, p95) -- DataParallel throughput (clips/sec) -- Cosine similarity (BF16 Neuron vs FP32 CPU) -- Real-time video processing factor (vs 30fps, 16 frames = 0.53s) - -## Results - -### ViT-g (1B) - -| Metric | Image (1 frame) | Video (16 frames) | -|--------|------------------|--------------------| -| Compilation time | — | — | -| Output shape | — | — | -| Cosine similarity | — | — | -| Median latency (ms) | — | — | -| p5 / p95 (ms) | — / — | — / — | -| DP throughput (clips/s) | — | — | - -### ViT-G (1.8B) - -| Metric | Image (1 frame) | Video (16 frames) | -|--------|------------------|--------------------| -| Compilation time | — | — | -| Output shape | — | — | -| Cosine similarity | — | — | -| Median latency (ms) | — | — | -| p5 / p95 (ms) | — / — | — / — | -| DP throughput (clips/s) | — | — | - -### Comparison (all models, video 16 frames, single NC) - -| Model | Params | Median (ms) | Real-time factor | Instance | -|-------|--------|-------------|------------------|----------| -| ViT-B | 86M | 247.4 | 2.1x | trn2.3xlarge | -| ViT-L | 300M | 741.8 | 0.7x | trn2.3xlarge | -| ViT-g | 1.01B | — | — | trn2.48xlarge | -| ViT-G | 1.8B | — | — | trn2.48xlarge | - -## Notes - -- `use_sdpa=False` required (SDPA not supported by `torch_neuronx.trace()`) -- `--auto-cast none` compiler flag required for BF16 -- ViT-g/G use `target_encoder` key in checkpoint (not `ema_encoder`) -- NKI flash attention disabled for 16-frame inference (slower at 4,608 tokens) diff --git a/contrib/models/jepa-2-1/PR_README.md b/contrib/models/jepa-2-1/PR_README.md deleted file mode 100644 index b618180f..00000000 --- a/contrib/models/jepa-2-1/PR_README.md +++ /dev/null @@ -1,236 +0,0 @@ -## Description - -NxDI contrib implementation of [V-JEPA 2.1](https://github.com/facebookresearch/vjepa2), Meta's self-supervised video foundation model. 
V-JEPA 2.1 is a Vision Transformer encoder that learns visual representations by predicting masked video segments in representation space. This is a vision encoder — not a causal language model — compiled for inference on AWS Trainium via `torch_neuronx.trace()`. - -Key architecture features ported: -* **3D RoPE:** Separate depth/height/width rotations on head_dim slices, using `repeat_interleave` layout -* **Conv3d tubelet embedding:** 3D convolution for video patch embedding (patch_size=16, tubelet_size=2) -* **Hierarchical output:** Normed features from 4 intermediate layers -* **Modality embeddings:** Separate learned embeddings for image vs video inputs -* **NKI flash attention:** Integrated `attention_isa_kernel` (reserved for 64-frame / 18K token inference) -* **Modular compilation markers:** `ModuleMarkerStartWrapper`/`EndWrapper` for future graph splitting - -## Model Information -* **Model Name:** V-JEPA 2.1 (vit_base, vit_large, vit_giant, vit_gigantic) -* **Model Architecture:** Vision Transformer encoder with 3D RoPE (86M–1.8B params) -* **Purpose:** Self-supervised video representation learning (feature extraction, not text generation) -* **Source:** [https://github.com/facebookresearch/vjepa2](https://github.com/facebookresearch/vjepa2) -* **License:** [MIT](https://github.com/facebookresearch/vjepa2/blob/main/LICENSE) - -## Checklist - -### Required Components -* **Accuracy Test** (`test/integration/test_model.py`) - * Integration test validates Neuron vs CPU accuracy via `neuron_allclose` (rtol=0.01) - * Test can compile and run the model on Neuron (validated on trn2.3xlarge) - * Pretrained weight validation: cosine similarity 0.9998–1.0002 across all configurations (ViT-B/L/g/G, image/video) -* **README.md** with the following sections: - * Usage Example: CPU inference, Neuron compilation, DataParallel - * Compatibility Matrix: trn2.3xlarge with SDK 2.28 - * Example Checkpoints: Meta's pretrained weights (auto-download) - * Testing Instructions: 
Commands to run unit and integration test suites - * Performance Benchmarks: Latency, throughput, DataParallel scaling -* **Source Code** (`src/`) - * `modeling_jepa21.py` (~700 lines): Self-contained encoder implementation, no upstream imports - * Properly structured in the contrib folder hierarchy - -### Optional Components -* **Unit Tests** (CPU-based, no Neuron device required) - * `test_encoder.py` — Construction: 4/4 PASS (ViT-B/L/g construction, invalid arch) - * `test_encoder.py` — Forward: 6/6 PASS (video/image/batch shapes, hierarchical output, determinism, resolution) - * `test_encoder.py` — Components: 4/4 PASS (PatchEmbed3D, RoPEAttention, Block) - -### Not Applicable (vision encoder, not causal LM) -* vLLM Integration — not applicable (not a text generation model) -* TPOT/TTFT benchmarks — not applicable (no token generation) -* Logit divergence test — not applicable (no autoregressive decoding) -* On-device sampling — not applicable - -### Demos -* **`demo_neuron.py`** — Neuron smoke test: runs pretrained ViT-B on CPU (FP32) and Neuron (BF16), compares embeddings. Cosine similarity 1.0005 [PASS]. Compilation 416s, latency 248ms. -* **`demo_classify.py`** — CPU video classification demo using HuggingFace V-JEPA 2 finetuned on SSv2 (174 action classes). Downloads Big Buck Bunny (CC-BY-3.0) as sample. 
- -## Folder Structure - -``` -contrib/models/jepa-2-1/ -├── README.md -├── PR_README.md # PR description (paste into GitHub PR body) -├── AGENT.md # Technical reference for coding agents -├── demo_neuron.py # Neuron smoke test (pretrained ViT-B, CPU vs Neuron) -├── demo_classify.py # CPU video classification demo (HF V-JEPA 2 + SSv2) -├── pyproject.toml -├── src/ -│ ├── __init__.py -│ └── modeling_jepa21.py # Self-contained encoder (3D RoPE, Conv3d, NKI flash) -└── test/ - ├── __init__.py - ├── unit/ - │ ├── __init__.py - │ └── test_encoder.py # CPU-only: construction, forward, components - └── integration/ - ├── __init__.py - └── test_model.py # Neuron: trace, accuracy, ViT-B/L -``` - -## Testing - -### How to run the test suite - -**Unit tests (CPU only, no Neuron device needed):** - -```bash -cd contrib/models/jepa-2-1/ -pytest test/unit/ -v -``` - -Expected: **14/14 PASS** (construction: 4, forward: 6, components: 4) - -**Integration tests (needs Neuron hardware, trn2.3xlarge):** - -```bash -cd contrib/models/jepa-2-1/ -pytest test/integration/test_model.py -v -``` - -Expected: **4/4 PASS** (trace ViT-B image/video, Neuron vs CPU accuracy, trace ViT-L image) - -### Accuracy validation - -Neuron output is validated against CPU reference using `neuron_allclose`: - -```python -from torch_neuronx.testing.validation import neuron_allclose -result = neuron_allclose(neuron_output, cpu_output, rtol=0.01, atol=1e-5) -assert result.allclose -``` - -Cosine similarity between BF16 Neuron output and FP32 CPU reference: -- ViT-B: 0.9998 -- ViT-L: 0.9999 -- ViT-g: 0.9999 (image), 1.0001 (video) -- ViT-G: 0.9998 (image) - -## Test Results - -### Unit Tests (CPU) - -| Test Module | Tests | Status | -|-------------|-------|--------| -| test_encoder.py — Construction | 4 | 4/4 PASS | -| test_encoder.py — Forward | 6 | 6/6 PASS | -| test_encoder.py — Components | 4 | 4/4 PASS | -| **Total** | **14** | **14/14 PASS** | - -### Integration Tests (trn2.3xlarge, 2 NeuronCores) - -| 
Test | Status | Notes | -|------|--------|-------| -| Trace ViT-B image (1 frame) | PASS | Output shape (1, 576, 768) | -| Trace ViT-B video (16 frames) | PASS | Output shape (1, 4608, 768) | -| Neuron vs CPU accuracy (ViT-B) | PASS | `neuron_allclose` rtol=0.01 | -| Trace ViT-L image (1 frame) | PASS | Output shape (1, 576, 1024) | - -### Accuracy (BF16 Neuron vs FP32 CPU, pretrained weights) - -| Model | Input | Cosine Similarity | Status | -|-------|-------|-------------------|--------| -| ViT-B (86M) | Image (1 frame) | 0.9999 | PASS | -| ViT-B (86M) | Video (16 frames) | 1.0000 | PASS | -| ViT-L (300M) | Image (1 frame) | 0.9999 | PASS | -| ViT-L (300M) | Video (16 frames) | 1.0002 | PASS | -| ViT-g (1.01B) | Image (1 frame) | 0.9999 | PASS | -| ViT-g (1.01B) | Video (16 frames) | 1.0001 | PASS | -| ViT-G (1.8B) | Image (1 frame) | 0.9998 | PASS | - -## Performance Benchmarks - -**Pretrained weights, BF16, 384×384.** ViT-B/L on trn2.3xlarge; ViT-g/G on trn2.48xlarge. 100 timed iterations after 10 warmup. 
- -### Single NeuronCore Latency (batch=1) - -| Model | Input | Median (ms) | p5 (ms) | p95 (ms) | -|-------|-------|-------------|---------|----------| -| ViT-B (86M) | Image (1 frame) | 4.4 | 4.4 | 4.5 | -| ViT-B (86M) | Video (16 frames) | 247.4 | 247.3 | 248.3 | -| ViT-L (300M) | Image (1 frame) | 11.6 | 11.6 | 11.7 | -| ViT-L (300M) | Video (16 frames) | 741.8 | 741.5 | 742.5 | -| ViT-g (1.01B) | Image (1 frame) | 28.0 | 27.9 | 28.1 | -| ViT-g (1.01B) | Video (16 frames) | 1029.5 | 1029.4 | 1029.7 | -| ViT-G (1.8B) | Image (1 frame) | 49.8 | 49.8 | 49.9 | - -### DataParallel Throughput (2 NeuronCores) - -| Model | Input | Per-clip Latency | Throughput | -|-------|-------|-----------------|------------| -| ViT-B (86M) | Image | 5.2 ms | 383 clips/sec | -| ViT-B (86M) | Video (16f) | 249.5 ms | 8.0 clips/sec | -| ViT-L (300M) | Image | 12.5 ms | 160 clips/sec | -| ViT-L (300M) | Video (16f) | 744.1 ms | 2.7 clips/sec | -| ViT-g (1.01B) | Image | 29.1 ms | 68.8 clips/sec | -| ViT-g (1.01B) | Video (16f) | 1032.1 ms | 1.9 clips/sec | -| ViT-G (1.8B) | Image | 51.1 ms | 39.1 clips/sec | - -### Real-Time Video Processing (16 frames @ 30fps = 0.53s of video) - -| Model | Single NC | DataParallel (2 NCs) | -|-------|-----------|----------------------| -| ViT-B | 2.1x real-time | 4.3x real-time | -| ViT-L | 0.7x real-time | 1.4x real-time | -| ViT-g | 0.5x real-time | 1.0x real-time | - -## Compatibility - -Tested with: -* **Neuron SDK Version(s):** 2.28 -* **Instance Type(s):** trn2.3xlarge, trn2.48xlarge -* **PyTorch Version:** 2.9.0 -* **Python Version:** 3.12 -* **neuronx-cc Version:** 2.24.5133 -* **NxDI Version:** 0.9.17334 -* **torch-neuronx Version:** 2.9.0 - -| Instance | NeuronCores | Status | Notes | -|----------|-------------|--------|-------| -| trn2.3xlarge | 2 | **PASS** | ViT-B and ViT-L compiled and benchmarked | -| trn2.48xlarge | 64 | **PASS** | ViT-g and ViT-G compiled and benchmarked (ViT-G video exceeds graph limit) | - -### Minimum Requirements 
(ViT-L) - -| Resource | Requirement | -|----------|------------| -| HBM | 96 GB (1 Neuron device) | -| System RAM | 124 GB (compilation) | -| Instance | trn2.3xlarge | -| Compiled model | 1.1 GB (.pt file) | - -## Additional Information - -### Key Porting Challenges -1. **Self-contained port:** All upstream imports from Meta's `vjepa2` repo replaced with inline implementations (~700 lines). No runtime dependency on the upstream repo. -2. **SDPA not supported:** `F.scaled_dot_product_attention` is not supported by `torch_neuronx.trace()`. Replaced with manual `Q @ K^T * scale → softmax → @ V` path. -3. **3D RoPE with `repeat_interleave`:** V-JEPA 2.1 uses `repeat_interleave` (not `repeat`) for RoPE frequency expansion. Compiles natively on Neuron — no workaround needed. -4. **Conv3d tubelet embedding:** 3D convolution compiles natively. No decomposition into 2D convolutions needed. -5. **BF16 softmax dtype:** Softmax promotes BF16→FP32 on CPU, causing dtype mismatch with V tensor. Fixed with `.to(v.dtype)` after softmax. -6. **NKI flash attention integration:** Integrated `attention_isa_kernel` with correct tensor layouts. Works correctly but ~80% slower at 4,608 tokens (designed for 16K+). Reserved for 64-frame inference. -7. **Modular compilation markers:** Added `ModuleMarkerStartWrapper`/`EndWrapper` from NxDI, but `torch_neuronx.trace()` does not respect them for graph splitting. They are only respected by `parallel_model_trace` from NxD. 
- -### Known Limitations -* ViT-g (1B) and ViT-G (1.8B) require trn2.48xlarge for compilation (>130GB host RAM needed); compiled models run on any trn2 instance -* ViT-G video (16 frames) exceeds neuronx-cc's 10M instruction limit (17.8M instructions); requires `parallel_model_trace` to split across NeuronCores -* `use_sdpa=False` is required (SDPA not supported by `torch_neuronx.trace()`) -* NKI flash attention is slower than compiler-generated attention at 4,608 tokens (16 frames) -* Modular compilation markers are not respected by `torch_neuronx.trace()` — need `parallel_model_trace` for graph splitting -* Not a causal LM — no vLLM integration, no KV cache, no token generation -* Pretrained weight download requires network access to `dl.fbaipublicfiles.com` - -### Future Work -* ViT-G video (16 frames) via `parallel_model_trace` to split the 17.8M-instruction graph across NeuronCores -* 64-frame inference (18,432 tokens) where NKI flash attention becomes beneficial -* Downstream tasks: attentive pooler for classification, predictor for action anticipation - -By submitting this PR, I confirm that: -* I have read and followed the contributing guidelines -* This is a community contribution and may have limited testing compared to officially-supported models -* The code follows best practices and is well-documented -* All required components listed above are included diff --git a/contrib/models/jepa-2-1/README.md b/contrib/models/jepa-2-1/README.md index ded1e7be..b7d16829 100644 --- a/contrib/models/jepa-2-1/README.md +++ b/contrib/models/jepa-2-1/README.md @@ -299,6 +299,45 @@ Tests: trace ViT-B image/video, Neuron vs CPU accuracy, trace ViT-L image = **4 5. **NKI flash attention integration:** Integrated `attention_isa_kernel` with correct tensor layouts (q/k: `(B*H, d, seqlen)`, v: `(B*H, seqlen, d)`). Works correctly but slower at short sequences. 
+## Folder Structure + +``` +contrib/models/jepa-2-1/ +├── README.md +├── AGENT.md # Technical reference for coding agents +├── pyproject.toml +├── examples/ +│ └── demo_classify.py # CPU video classification demo (HF V-JEPA 2 + SSv2) +├── src/ +│ ├── __init__.py +│ └── modeling_jepa21.py # Self-contained encoder (3D RoPE, Conv3d, NKI flash) +└── test/ + ├── __init__.py + ├── unit/ + │ ├── __init__.py + │ └── test_encoder.py # CPU-only: construction, forward, components + └── integration/ + ├── __init__.py + ├── test_model.py # Neuron: trace, accuracy, ViT-B/L + └── test_pretrained_smoke.py # Neuron smoke test (pretrained ViT-B, CPU vs Neuron) +``` + +## Known Limitations + +* ViT-g (1B) and ViT-G (1.8B) require trn2.48xlarge for compilation (>130GB host RAM needed); compiled models run on any trn2 instance +* ViT-G video (16 frames) exceeds neuronx-cc's 10M instruction limit (17.8M instructions); requires `parallel_model_trace` to split across NeuronCores +* `use_sdpa=False` is required (SDPA not supported by `torch_neuronx.trace()`) +* NKI flash attention is slower than compiler-generated attention at 4,608 tokens (16 frames) +* Modular compilation markers are not respected by `torch_neuronx.trace()` — need `parallel_model_trace` for graph splitting +* Not a causal LM — no vLLM integration, no KV cache, no token generation +* Pretrained weight download requires network access to `dl.fbaipublicfiles.com` + +## Future Work + +* ViT-G video (16 frames) via `parallel_model_trace` to split the 17.8M-instruction graph across NeuronCores +* 64-frame inference (18,432 tokens) where NKI flash attention becomes beneficial +* Downstream tasks: attentive pooler for classification, predictor for action anticipation + ## Example Checkpoints Pretrained weights are downloaded automatically when `pretrained=True`: @@ -316,4 +355,3 @@ Source: `https://dl.fbaipublicfiles.com/vjepa2/` Community contribution -**Last Updated:** 2026-04-29 diff --git a/contrib/models/jepa-2-1/demo_neuron.py
b/contrib/models/jepa-2-1/demo_neuron.py deleted file mode 100644 index 02fe0465..00000000 --- a/contrib/models/jepa-2-1/demo_neuron.py +++ /dev/null @@ -1,111 +0,0 @@ -""" -V-JEPA 2.1 Neuron smoke test. - -Runs a video through the pretrained ViT-B encoder on both CPU (FP32) -and Neuron (BF16), then compares the feature embeddings. - -Requires: trn2/inf2 instance with torch-neuronx. - -Usage (on Neuron instance): - source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate - python demo_neuron.py # synthetic video (no deps) - python demo_neuron.py path/to/video.mp4 # your own video (needs decord, pillow) -""" - -import sys -import os -import time - -import numpy as np -import torch -import torch_neuronx - -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src")) -from modeling_jepa21 import build_vjepa21_encoder - - -def make_synthetic_video(num_frames=16, size=384): - """Generate synthetic video: moving circle on gradient background. Returns (1, 3, T, H, W).""" - frames = [] - for i in range(num_frames): - frame = np.zeros((size, size, 3), dtype=np.float32) - frame[:, :, 2] = np.linspace(0, 0.3, size).reshape(1, -1) - cx = int(size * (0.2 + 0.6 * i / num_frames)) - y, x = np.ogrid[:size, :size] - mask = ((x - cx)**2 + (y - size // 2)**2) < (size // 10)**2 - frame[mask] = 1.0 - frames.append(frame) - # (T, H, W, 3) -> (3, T, H, W) - video = torch.from_numpy(np.stack(frames)).permute(3, 0, 1, 2) - return video.unsqueeze(0) # (1, 3, T, H, W) - - -def load_video_tensor(path, num_frames=16, size=384): - """Load video file as (1, 3, T, H, W) tensor. 
Needs decord and pillow.""" - from decord import VideoReader - from PIL import Image - vr = VideoReader(path) - indices = np.linspace(0, len(vr) - 1, num_frames, dtype=int) - frames = vr.get_batch(indices).asnumpy() - processed = [] - for f in frames: - img = Image.fromarray(f).resize((size, size)) - processed.append(np.array(img, dtype=np.float32) / 255.0) - video = torch.from_numpy(np.stack(processed)).permute(3, 0, 1, 2) - return video.unsqueeze(0) - - -def main(): - if len(sys.argv) > 1 and os.path.exists(sys.argv[1]): - print(f"Video: {sys.argv[1]}") - video = load_video_tensor(sys.argv[1]) - else: - print("Using synthetic video (moving circle, no dependencies needed)") - video = make_synthetic_video() - print(f"Input shape: {video.shape}") - - # --- CPU reference (FP32) --- - print("\nLoading pretrained ViT-B (CPU, FP32)...") - encoder = build_vjepa21_encoder( - arch="vit_base", img_size=384, num_frames=16, - pretrained=True, use_sdpa=False, - ) - encoder.eval() - - with torch.no_grad(): - cpu_out = encoder(video) - print(f"CPU output: shape={cpu_out.shape}, norm={cpu_out.float().norm():.1f}") - - # --- Neuron (BF16) --- - print("\nTracing for Neuron (BF16)...") - encoder.bfloat16() - video_bf16 = video.bfloat16() - - t0 = time.time() - traced = torch_neuronx.trace(encoder, video_bf16, compiler_args=["--auto-cast", "none"]) - compile_time = time.time() - t0 - print(f"Compilation: {compile_time:.1f}s") - - # Warmup - for _ in range(3): - traced(video_bf16) - - # Timed run - t0 = time.time() - neuron_out = traced(video_bf16) - latency = (time.time() - t0) * 1000 - print(f"Neuron output: shape={neuron_out.shape}, norm={neuron_out.float().norm():.1f}") - print(f"Latency: {latency:.1f}ms") - - # --- Compare --- - cos_sim = torch.nn.functional.cosine_similarity( - cpu_out.float().flatten().unsqueeze(0), - neuron_out.float().flatten().unsqueeze(0), - ).item() - - status = "PASS" if cos_sim > 0.999 else "FAIL" - print(f"\nCosine similarity (CPU FP32 vs Neuron BF16): 
{cos_sim:.6f} [{status}]") - - -if __name__ == "__main__": - main() diff --git a/contrib/models/jepa-2-1/demo_classify.py b/contrib/models/jepa-2-1/examples/demo_classify.py similarity index 100% rename from contrib/models/jepa-2-1/demo_classify.py rename to contrib/models/jepa-2-1/examples/demo_classify.py diff --git a/contrib/models/jepa-2-1/test/integration/test_pretrained_smoke.py b/contrib/models/jepa-2-1/test/integration/test_pretrained_smoke.py new file mode 100644 index 00000000..05325229 --- /dev/null +++ b/contrib/models/jepa-2-1/test/integration/test_pretrained_smoke.py @@ -0,0 +1,105 @@ +"""Pretrained weight smoke tests for V-JEPA 2.1 on Neuron. + +Validates that pretrained ViT-B produces correct features on Neuron +by comparing BF16 Neuron output against FP32 CPU reference. + +Requires: trn2/inf2 instance with torch-neuronx. +""" + +import sys +import os + +import pytest +import torch + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "src")) + +try: + import torch_neuronx + HAS_NEURON = True +except ImportError: + HAS_NEURON = False + +from modeling_jepa21 import build_vjepa21_encoder + + +@pytest.fixture +def pretrained_encoder(): + """Build pretrained ViT-B encoder (FP32).""" + encoder = build_vjepa21_encoder( + arch="vit_base", img_size=384, num_frames=16, + pretrained=True, use_sdpa=False, + ) + encoder.eval() + return encoder + + +@pytest.fixture +def synthetic_video(): + """Deterministic synthetic video input (1, 3, 16, 384, 384).""" + torch.manual_seed(42) + return torch.randn(1, 3, 16, 384, 384) + + +class TestPretrainedCPU: + """CPU-only sanity checks for pretrained weights.""" + + def test_pretrained_loads(self, pretrained_encoder): + """Pretrained weights load without error.""" + assert pretrained_encoder is not None + + def test_pretrained_forward_shape(self, pretrained_encoder, synthetic_video): + """Pretrained encoder produces correct output shape.""" + with torch.no_grad(): + out = 
pretrained_encoder(synthetic_video) + assert out.shape == (1, 4608, 768) + + def test_pretrained_no_nan(self, pretrained_encoder, synthetic_video): + """Pretrained encoder output has no NaN/Inf.""" + with torch.no_grad(): + out = pretrained_encoder(synthetic_video) + assert not out.isnan().any() + assert not out.isinf().any() + + +@pytest.mark.skipif(not HAS_NEURON, reason="torch_neuronx not available") +class TestPretrainedNeuron: + """Neuron accuracy tests with pretrained weights.""" + + def test_pretrained_neuron_vs_cpu(self, pretrained_encoder, synthetic_video): + """Pretrained ViT-B on Neuron matches CPU (cosine sim > 0.999).""" + with torch.no_grad(): + cpu_out = pretrained_encoder(synthetic_video) + + pretrained_encoder.bfloat16() + video_bf16 = synthetic_video.bfloat16() + traced = torch_neuronx.trace( + pretrained_encoder, video_bf16, + compiler_args=["--auto-cast", "none"], + ) + + # Warmup + for _ in range(3): + traced(video_bf16) + + neuron_out = traced(video_bf16) + + cos_sim = torch.nn.functional.cosine_similarity( + cpu_out.float().flatten().unsqueeze(0), + neuron_out.float().flatten().unsqueeze(0), + ).item() + + assert cos_sim > 0.999, f"Cosine similarity {cos_sim:.6f} below 0.999 threshold" + + def test_pretrained_neuron_no_nan(self, pretrained_encoder, synthetic_video): + """Neuron output has no NaN/Inf values.""" + pretrained_encoder.bfloat16() + video_bf16 = synthetic_video.bfloat16() + traced = torch_neuronx.trace( + pretrained_encoder, video_bf16, + compiler_args=["--auto-cast", "none"], + ) + neuron_out = traced(video_bf16) + + assert not neuron_out.isnan().any(), "Neuron output contains NaN" + assert not neuron_out.isinf().any(), "Neuron output contains Inf" From b1951009427e7a006072cfd032162549cb14fc1f Mon Sep 17 00:00:00 2001 From: Daniel Stair Date: Thu, 7 May 2026 19:54:06 +0000 Subject: [PATCH 5/5] update README --- contrib/models/jepa-2-1/AGENT.md | 16 ++++++---------- contrib/models/jepa-2-1/README.md | 31 
++++++++++--------------------- 2 files changed, 16 insertions(+), 31 deletions(-) diff --git a/contrib/models/jepa-2-1/AGENT.md b/contrib/models/jepa-2-1/AGENT.md index e7e02892..1d91611c 100644 --- a/contrib/models/jepa-2-1/AGENT.md +++ b/contrib/models/jepa-2-1/AGENT.md @@ -139,8 +139,8 @@ trn2.3xlarge has 2 logical NeuronCores → 2x throughput. Scales linearly with b |-------|--------|-------------|-------------------|--------| | ViT-B | 86M | ~8 min | 0.9998 | ✅ | | ViT-L | 300M | ~18 min | 0.9999 | ✅ | -| ViT-g | 1.01B | OOM at ~30 min | — | ❌ Host OOM | -| ViT-G | 1.8B | Not attempted | — | ❌ Blocked | +| ViT-g | 1.01B | ~51 min | 0.9999 | ✅ (trn2.48xlarge) | +| ViT-G | 1.8B | ~11 min (image) | 0.9998 | ✅ (trn2.48xlarge, image only) | ### Latency (batch=1, single NeuronCore) @@ -199,7 +199,6 @@ pytest test/integration/test_pretrained_smoke.py -v ### Test gaps (future work) -- No ViT-g/ViT-G tests (blocked by compilation on trn2.3xlarge) - No 64-frame tests - No predictor tests @@ -240,16 +239,13 @@ Checkpoints loaded via `torch.hub.load_state_dict_from_url`. State dict keys pre ## Open Work Items -### P0 — Needed for production readiness -1. **Compile ViT-g (1B) and ViT-G (1.8B)**: Use `parallel_model_trace` from NxD (markers already in code) or compile on trn2.48xlarge (2TB RAM). The modular compilation markers are already inserted. - ### P1 — Valuable additions -2. **64-frame inference**: 18,432 tokens — NKI flash attention should become beneficial here. Need to benchmark. -3. **Downstream tasks**: Attentive pooler for classification, predictor for action anticipation. +1. **64-frame inference**: 18,432 tokens — NKI flash attention should become beneficial here. Need to benchmark. +2. **Downstream tasks**: Attentive pooler for classification, predictor for action anticipation. ### P2 — Nice to have -4. **Tensor parallelism**: For ViT-G on multi-device instances. Would require wrapping with NxD parallel layers. -5. 
**Dynamic resolution**: Test with non-384 resolutions using `interpolate_rope=True`. +3. **Tensor parallelism**: For ViT-G on multi-device instances. Would require wrapping with NxD parallel layers. +4. **Dynamic resolution**: Test with non-384 resolutions using `interpolate_rope=True`. ## Reference Code in the NxDI Repo diff --git a/contrib/models/jepa-2-1/README.md b/contrib/models/jepa-2-1/README.md index b7d16829..3f8290d5 100644 --- a/contrib/models/jepa-2-1/README.md +++ b/contrib/models/jepa-2-1/README.md @@ -186,41 +186,30 @@ encoder = build_vjepa21_encoder(arch="vit_large", img_size=384, num_frames=16, p ## Demos -### Neuron Smoke Test (`demo_neuron.py`) +### Neuron Smoke Test (`test/integration/test_pretrained_smoke.py`) -Runs pretrained ViT-B on both CPU (FP32) and Neuron (BF16), compares feature embeddings. Serves as a quick validation that the Neuron port is working correctly. No external dependencies beyond torch-neuronx. +Validates pretrained ViT-B on both CPU (FP32) and Neuron (BF16), comparing feature embeddings via cosine similarity. Run as a pytest test: ```bash # On a Neuron instance (trn2/inf2): source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate -python demo_neuron.py # synthetic video (no extra deps) -python demo_neuron.py path/to/video.mp4 # your own video (needs decord, pillow) -``` - -Expected output: -``` -Using synthetic video (moving circle, no dependencies needed) -Input shape: torch.Size([1, 3, 16, 384, 384]) - -Loading pretrained ViT-B (CPU, FP32)... -CPU output: shape=torch.Size([1, 4608, 768]), norm=2042.1 +cd contrib/models/jepa-2-1/ -Tracing for Neuron (BF16)... 
-Compilation: 416.1s -Neuron output: shape=torch.Size([1, 4608, 768]), norm=2046.2 -Latency: 248.0ms +# CPU-only checks (pretrained weight load, output shape, no NaN) +pytest test/integration/test_pretrained_smoke.py::TestPretrainedCPU -v -Cosine similarity (CPU FP32 vs Neuron BF16): 1.000502 [PASS] +# Full Neuron validation (compiles ViT-B, checks cosine sim > 0.999) +pytest test/integration/test_pretrained_smoke.py::TestPretrainedNeuron -v ``` -### Video Classification (`demo_classify.py`) +### Video Classification (`examples/demo_classify.py`) Classifies a video using a finetuned V-JEPA 2 model on Something-Something v2 (174 action classes). Runs on CPU — no Neuron hardware needed. ```bash pip install transformers accelerate torchvision decord -python demo_classify.py # Big Buck Bunny sample (CC-BY-3.0) -python demo_classify.py path/to/video.mp4 # your own video +python examples/demo_classify.py # Big Buck Bunny sample (CC-BY-3.0) +python examples/demo_classify.py path/to/video.mp4 # your own video ``` Note: This demo uses the HuggingFace `VJEPA2ForVideoClassification` model (V-JEPA 2, not 2.1) to demonstrate what the encoder features can do. The Neuron port (`modeling_jepa21.py`) is the V-JEPA 2.1 encoder only.