
[Tracking] WS1 Execution Roadmap #202

Description

@Flink-ddd

Context & Objective

This tracking issue lays out the final execution path for closing Workstream 1 (WS1, Operator-Level Train-Inference Consistency).

Our goal is a fully batch-invariant forward and backward chain. Reaching it requires a strict dependency order: first merge the PyTorch ground-truth references, then validate the downstream candidate kernels against them, then audit backward batch-invariance, and finally assemble and validate the end-to-end chain.
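An operator is batch-invariant when each row of its output is bitwise identical whether that row is computed alone or inside a larger batch. The pure-Python sketch below is a toy. It is not one of the WS1 kernels or `check_operator.py`, and every name in it is hypothetical. The batch-size-dependent chunking only mimics how real kernels may switch tiling or split-K strategies with the batch shape.

```python
# Toy, pure-Python illustration of a forward batch-invariance ("slice") check.
# Hypothetical: real GPU kernels may choose tiling / split-K from the batch shape;
# `row_sums_batch_dependent` mimics that by picking its reduction chunk from the batch size.
from typing import Callable, List

Row = List[float]


def chunked_sum(xs: Row, chunk: int) -> float:
    # Sum each chunk, then sum the partials: float association depends on `chunk`.
    partials = [sum(xs[i:i + chunk]) for i in range(0, len(xs), chunk)]
    return sum(partials)


def row_sums_batch_dependent(batch: List[Row]) -> List[float]:
    # Hypothetical heuristic: small batches split each row to expose more parallelism.
    chunk = 2 if len(batch) < 4 else len(batch[0])
    return [chunked_sum(row, chunk) for row in batch]


def row_sums_fixed_order(batch: List[Row]) -> List[float]:
    # Batch-invariant variant: the reduction order never depends on the batch size.
    return [chunked_sum(row, 2) for row in batch]


def is_batch_invariant(op: Callable[[List[Row]], List[float]], batch: List[Row]) -> bool:
    # Each row of the batched result must equal the same row computed alone (bs=1), bitwise.
    full = op(batch)
    return all(full[i] == op([row])[0] for i, row in enumerate(batch))


batch = [[1e20, 1.0, -1e20, 1.0], [1.0, 2.0, 3.0, 4.0],
         [0.5, 0.5, 0.5, 0.5], [3.0, 1e20, -1e20, 2.0]]
print(is_batch_invariant(row_sums_batch_dependent, batch))  # False: 1.0 in the batch, 0.0 alone
print(is_batch_invariant(row_sums_fixed_order, batch))      # True
```

Neither variant is accurate here, since the exact sum of the first row is 2.0. The fixed-order version is simply consistent, and consistency is the property this workstream targets.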

Please refer to the execution roadmap below and coordinate your PRs accordingly.


WS1 Execution Roadmap

================================================================================================================
                                              WS1 EXECUTION ROADMAP
================================================================================================================

[ PHASE 1: Convergence & Merge ]
                                                       |
                                                       v
                                              +-----------------+
                                              | Clear Blockers  | (Dismiss Reviews / Update Branch)
                                              +--------+--------+
                                                       |
                              +------------------------+-------------------------+
                              |                                                  |
                              v                                                  v
                +--------------------------+                           +--------------------------+
                | Merge Ground-Truth Ops   |                           | Merge gtest Framework    |
                | (#160, #166, #167, #168, |                           | (#197)                   |
                |  #169, #170, #188)       |                           +-------------+------------+
                +-------------+------------+                                         |
                              |                                                      |
                              +------------------------+-----------------------------+
                                                       |
                                                       v
                                             [ main Branch Ready ]

================================================================================================================
[ PHASE 2: Baseline Alignment & Kernel Validation ]
                                                       |
                                                       v
                                          +--------------------------+
                                          | Run check_operator.py    |
                                          +------------+-------------+
                                                       |
                                                       v
+--------------------------------------------------------------------------------------------------------------+
| ESTABLISH CONTRACT & VALIDATE CANDIDATE KERNELS                                                              |
|                                                                                                              |
|  * Dtype Coverage Contract              -> issue #154 @Flink-ddd @maxiaosong1124 @EthanZero2Hero             |
|  * RMSNorm (fwd+bwd)                    -> issue #145 @EthanZero2Hero             PR #201 @EthanZero2Hero    |
|  * Deterministic GEMM                   -> issue #146 @Flink-ddd                  PR #180 @Flink-ddd         |
|  * Embedding & LM-Head Routing          -> issue #151 @inaniloquentee             PR #190 @inaniloquentee    |
|  * Attention (Standard Softmax)         -> issue #147 @maxiaosong1124 @EthanZero2Hero                        |
|  * RoPE / Elementwise Audit             -> issue #149 @a-kaa                                                 |
|  * Logprob (Selected, Locked Reduction) -> issue #148 @KJLdefeated @hihaluemen    PR #199 @hihaluemen        |
|  * KV-Cache Path Consistency            -> issue #152 @zhangj1an                  PR #178 @zhangj1an         |
+------------------------------------------------------+-------------------------------------------------------+
                                                       |
                                                       v
                                      (Ensure all pass tolerance_contract)

================================================================================================================
[ PHASE 3: The #153 Audit ]
                                                       |
                                                       v
                                          +--------------------------+
                                          | Comprehensive Code Audit | (Add Backward Slice tests
                                          | (#153)                   |  for all operators)
                                          +------------+-------------+
                                                       |
                                                       v
                                          +--------------------------+
                                          | Fix Batch-Invariance     | (Resolve underlying backward
                                          | Defects                  |  drift across all ops)
                                          +------------+-------------+

================================================================================================================
[ PHASE 4: The #150 Exit Gate ]
                                                       |
                                                       v
                                          +--------------------------+
                                          | Connect End-to-End Chain | (RMSNorm -> Matmul -> Attn -> 
                                          |                          |  LMHead -> Logprob)
                                          +------------+-------------+
                                                       |
                                                       v
                                          +--------------------------+
                                          | Full-Chain Validation    | (Sweep Batch Size & Padding 
                                          | (Forward & Backward)     |  variables)
                                          +------------+-------------+
                                                       |
                                                       v
                                         ******************************
                                         *          WS1 EXIT          * --> Advance to WS2 Distributed
                                         ******************************
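The Phase 4 gate can be pictured as a sweep harness. It runs every sequence at several batch sizes and padding amounts and requires each output to match, bitwise, the same sequence run alone at bs=1 without padding. The sketch below is a hypothetical pure-Python toy. `chain_fixed` and `chain_pad_dependent` stand in for the real RMSNorm -> Matmul -> Attn -> LMHead -> Logprob chain, and left padding is an assumption. Only forward outputs are checked, whereas the actual gate also covers the backward pass.

```python
# Hypothetical sketch of the Phase 4 sweep over batch size and padding (pure Python toy).
# `chain_*` stand in for RMSNorm -> Matmul -> Attn -> LMHead -> Logprob; here each is
# only a per-sequence reduction of per-token scores so the harness logic stays visible.
from typing import Callable, List, Sequence, Tuple

Chain = Callable[[List[List[float]], List[int]], List[float]]


def seq_sum(tokens: Sequence[float], chunk: int) -> float:
    # Chunked reduction: the float association order depends on chunk boundaries.
    partials = [sum(tokens[i:i + chunk]) for i in range(0, len(tokens), chunk)]
    return sum(partials)


def left_pad(seqs: List[List[float]], width: int) -> List[List[float]]:
    # Left-pad with 0.0 (assumed padding side for this sketch).
    return [[0.0] * (width - len(s)) + list(s) for s in seqs]


def chain_fixed(batch: List[List[float]], lengths: List[int]) -> List[float]:
    # Invariant: strip the pads first, so chunk boundaries never move.
    return [seq_sum(row[len(row) - n:], 2) for row, n in zip(batch, lengths)]


def chain_pad_dependent(batch: List[List[float]], lengths: List[int]) -> List[float]:
    # Defect: reduces the whole padded row. Pads add 0.0, so the math is unchanged,
    # but left padding shifts the chunk boundaries and therefore the rounding.
    return [seq_sum(row, 2) for row in batch]


def sweep(chain: Chain, seqs: List[List[float]], batch_sizes: List[int],
          extra_pads: List[int]) -> List[Tuple[int, int, int]]:
    # Return (batch_size, extra_pad, seq_index) for every output that differs
    # bitwise from the same sequence run alone at bs=1 without padding.
    reference = [chain([s], [len(s)])[0] for s in seqs]
    failures = []
    for bs in batch_sizes:
        for extra in extra_pads:
            for start in range(0, len(seqs), bs):
                group = seqs[start:start + bs]
                width = max(len(s) for s in group) + extra
                out = chain(left_pad(group, width), [len(s) for s in group])
                for j, value in enumerate(out):
                    if value != reference[start + j]:
                        failures.append((bs, extra, start + j))
    return failures


seqs = [[1e20, 1.0, -1e20, 1.0], [0.5, 0.25], [1.0, 2.0, 3.0]]
print(sweep(chain_fixed, seqs, [1, 2, 3], [0, 1, 4]))  # []
print((1, 1, 0) in sweep(chain_pad_dependent, seqs, [1, 2, 3], [0, 1, 4]))  # True
```

The defective variant is mathematically equivalent, because pads contribute 0.0. It still fails because shifting chunk boundaries changes floating-point rounding. This is why the sweep varies padding as well as batch size.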

Labels

platform: cuda (specific optimizations or bugs in NVIDIA graphics cards, such as FlashInfer and TMA optimizations)
priority: high (severe congestion issues require the highest priority for resolution)
sprint-0615