Use Protocols to type-check linear_proj submodules of Attention#3434

Open
nschank wants to merge 2 commits into NVIDIA:main from nschank:linearproj

Conversation

@nschank
Contributor

@nschank nschank commented Feb 15, 2026

What does this PR do?

Defines Protocols representing linear_proj submodules, and uses them instead of ModuleSpec to enable typechecking of its construction in SelfAttention, CrossAttention, and MLA.
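As a rough illustration of the idea (not the code in this PR), the sketch below types a submodule field with a callback Protocol instead of an opaque spec. `LinearProjBuilder` and `LinearProjInterface` are the names used later in this thread, but the signatures shown here, `ToyRowParallelLinear`, and `ToyAttentionSubmodules` are assumptions made up for the example.

```python
from dataclasses import dataclass
from typing import Protocol


class LinearProjInterface(Protocol):
    """What the attention block expects from a built layer (assumed shape)."""

    def forward(self, hidden_states: list[float]) -> list[float]: ...


class LinearProjBuilder(Protocol):
    """Any callable with this signature that returns a LinearProjInterface."""

    def __call__(
        self, input_size: int, output_size: int, *, bias: bool
    ) -> LinearProjInterface: ...


class ToyRowParallelLinear:
    """Hypothetical stand-in for a layer class; not Megatron's RowParallelLinear."""

    def __init__(self, input_size: int, output_size: int, *, bias: bool) -> None:
        self.input_size = input_size
        self.output_size = output_size
        self.bias = bias

    def forward(self, hidden_states: list[float]) -> list[float]:
        # Toy behaviour: zero-pad, then truncate to output_size.
        padded = hidden_states + [0.0] * self.output_size
        return padded[: self.output_size]


@dataclass
class ToyAttentionSubmodules:
    # Typed as a Protocol, so a type checker can compare the constructor
    # signature of whatever class or callable is passed in.
    linear_proj: LinearProjBuilder


submodules = ToyAttentionSubmodules(linear_proj=ToyRowParallelLinear)
layer = submodules.linear_proj(4, 2, bias=False)
```

With this shape, passing a class whose constructor does not accept `input_size`, `output_size`, and `bias` should be reported by a type checker at the `ToyAttentionSubmodules(...)` call, rather than failing later at build time.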

I also updated Backend to return linear_proj specifically, allowing type-checking of RowParallelLinear types as instances of linear_proj directly (otherwise Backend "hides" the type and makes no type-checking occur).

While I was in attention, I also updated the names of the existing interfaces to match the conventions we've settled on.

Associated design doc: Typed ModuleSpec.pdf

Contribution process

flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]

Pre-checks

  • I want this PR in a versioned release and have added the appropriate Milestone (e.g., Core 0.8)
  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code (see Typing guidelines)
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch

Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

(Step 1): Add PR label Expert Review

(Step 2): Collect the expert reviewers' reviews

  1. Attach the Expert Review label when your PR is ready for review.
  2. GitHub auto-assigns expert reviewers based on your changes. They will get notified and pick up your PR soon.

⚠️ Only proceed to the next step once all reviewers have approved, merge conflicts are resolved, and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

(Step 3): Final Review

  1. Add Final Review label
  2. GitHub auto-assigns final reviewers based on your changes. They will get notified and pick up your PR soon.

(Optional Step 4): Cherry-pick into release branch

If this PR also needs to be merged into core_r* release branches, after this PR has been merged, select Cherry-pick to open a new PR into the release branch.

For MRs into `dev` branch

The proposed review process for `dev` branch is under active discussion.

MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

Merging your PR

Any member of core-adlr and core-nemo will be able to merge your PR.

@nschank nschank requested review from a team as code owners February 15, 2026 16:42
@copy-pr-bot

copy-pr-bot bot commented Feb 15, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@ko3n1g ko3n1g requested a review from a team February 15, 2026 16:42
@Phlip79 Phlip79 added the Expert Review ("[deprecated] Apply this label to indicate that your PR is ready for expert review.") and complexity: medium labels Feb 17, 2026
@Phlip79
Member

Phlip79 commented Feb 17, 2026

/ok to test 9db13d6

@chtruong814 chtruong814 added the needs-follow-up Issue needs follow-up label Mar 2, 2026
@nschank
Contributor Author

nschank commented Mar 7, 2026

Resynced after coming back from travel, sorry for delay!

@chtruong814 chtruong814 added needs-follow-up Issue needs follow-up and removed needs-follow-up Issue needs follow-up labels Mar 7, 2026
Contributor

@yashaswikarnati yashaswikarnati left a comment

synced offline, just had a minor comment, overall lgtm!

@chtruong814 chtruong814 added the needs-follow-up Issue needs follow-up label Mar 14, 2026
@jaredcasper
Contributor

> I also updated Backend to return linear_proj specifically, allowing type-checking of RowParallelLinear types as instances of linear_proj directly (otherwise Backend "hides" the type and makes no type-checking occur).

Can you expand on this a bit? I'm guessing this is adding the row_parallel_linear_proj() function in addition to the "row_parallel_linear()" function? Don't those have the same inputs/outputs so same types? Why the need for a special one for "_proj"?

@nschank
Contributor Author

nschank commented Mar 19, 2026

@jaredcasper Sure! Fair criticism, this is sorta in a partial state so maybe I should update with a TODO for clarity or something. I'm trying to solve the following problem:

backend: BackendSpecProvider = ...
submodules = SelfAttentionSubmodules(..., linear_proj=backend.get_type(), ...)

SelfAttentionSubmodules.linear_proj has a specific interface it wants to require - it knows the exact signature that a LinearProjBuilder is supposed to satisfy, and same for the LinearProjInterface it must return. So whenever you provide something via linear_proj=, the type checker is given the opportunity to check that the interface actually matches.

It can only do so if the thing being passed to linear_proj= actually has a type which can be tested against that interface. This is true of specific classes (so if I pass something of type type[RowParallelLinear]), unions of classes, Callables, functools.partial, etc.
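To make the list above concrete, here is a minimal sketch of two kinds of value that carry enough type information to be checked against a builder Protocol: a class and a `functools.partial`. The Protocol signatures and `ToyLinear` are assumptions for illustration only, and how precisely `functools.partial` is checked depends on the type checker in use.

```python
import functools
from typing import Protocol


class LinearProjInterface(Protocol):
    def forward(self, x: list[float]) -> list[float]: ...


class LinearProjBuilder(Protocol):
    # Assumed signature for illustration; the real protocol may differ.
    def __call__(self, input_size: int, output_size: int) -> LinearProjInterface: ...


class ToyLinear:
    """Hypothetical layer with one extra, optional constructor argument."""

    def __init__(self, input_size: int, output_size: int, scale: float = 1.0) -> None:
        self.input_size = input_size
        self.output_size = output_size
        self.scale = scale

    def forward(self, x: list[float]) -> list[float]:
        return [self.scale * value for value in x[: self.output_size]]


def build_output_projection(linear_proj: LinearProjBuilder, size: int) -> LinearProjInterface:
    # The caller only relies on the signature promised by LinearProjBuilder.
    return linear_proj(size, size)


# A class: its constructor signature is what gets compared to __call__.
from_class = build_output_projection(ToyLinear, 2)

# A functools.partial: pre-binds `scale`, leaving a compatible signature.
from_partial = build_output_projection(functools.partial(ToyLinear, scale=0.5), 2)
```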

But the return type of BackendSpecProvider.row_parallel_linear() is just type. type is basically equivalent to 🤷 as far as the type-checker is concerned, so doing linear_proj=backend.row_parallel_linear() will not catch a type error. Individual subclasses of BackendSpecProvider can provide a narrower return type for row_parallel_linear, which helps somewhat (if callers are using a subclass directly), but any time a caller is using something which the type-checker only knows is a BackendSpecProvider (but not which kind) then it will not type-check row_parallel_linear.
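A small sketch of that failure mode, using made-up classes (`LooseBackend`, `NarrowBackend`, `NotALinearLayer`) rather than the real `BackendSpecProvider`: a method annotated as returning bare `type` is typically accepted wherever a builder is expected, so the mismatch only shows up at runtime.

```python
from typing import Optional, Protocol


class LinearProjBuilder(Protocol):
    # Assumed signature for illustration.
    def __call__(self, input_size: int, output_size: int) -> object: ...


class ToyRowParallelLinear:
    def __init__(self, input_size: int, output_size: int) -> None:
        self.shape = (input_size, output_size)


class NotALinearLayer:
    """Has an incompatible constructor on purpose."""

    def __init__(self, name: str) -> None:
        self.name = name


class LooseBackend:
    def row_parallel_linear(self) -> type:
        # Bare `type`: the checker cannot see which constructor comes back,
        # so it has nothing to compare against LinearProjBuilder.
        return NotALinearLayer


class NarrowBackend:
    def row_parallel_linear(self) -> type[ToyRowParallelLinear]:
        # A specific class type: the constructor signature is visible.
        return ToyRowParallelLinear


def use_linear_proj(linear_proj: LinearProjBuilder) -> object:
    return linear_proj(4, 2)


built = use_linear_proj(NarrowBackend().row_parallel_linear())

loose_error: Optional[TypeError] = None
try:
    # Usually passes static checking, but fails when actually called.
    use_linear_proj(LooseBackend().row_parallel_linear())
except TypeError as exc:
    loose_error = exc
```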

I don't have a great Protocol to use here for what generically a method named row_parallel_linear() should actually return - there are at least two distinct Protocols that row_parallel_linear() needs to satisfy (LinearProjBuilder and LinearFc2Builder), and it's not entirely obvious those two things are required to have identical interfaces. The ideal world would be if I could just say the return type is LinearProjBuilder & LinearFc2Builder (i.e. it must satisfy both at once) but Python doesn't support that.

Thus, my proposed solution here is effectively to have BackendSpecProvider offer individual methods for each particular Builder protocol that we end up introducing. If we later merge LinearProjBuilder and LinearFc2Builder into a single LinearLayerBuilder then both column_parallel_linear and row_parallel_linear could use it; but in the meantime I think we should have row_parallel_linear_proj (returning LinearProjBuilder) and I will rename row_parallel_linear() to row_parallel_linear_fc2() -> LinearFc2Builder. This basically means BackendSpecProvider might return the same class from multiple separate methods, but each one is enforcing that that class satisfies a different interface.
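Roughly, the shape I have in mind looks like the sketch below. The method names match the ones proposed above, but the Protocol signatures, `ToyBackendSpecProvider`, `ToyLocalBackend`, and `ToyRowParallelLinear` are placeholders invented for the example, not the actual Megatron classes.

```python
from abc import ABC, abstractmethod
from typing import Protocol


class LinearProjBuilder(Protocol):
    # Assumed: the attention output projection requires skip_bias_add.
    def __call__(
        self, input_size: int, output_size: int, *, skip_bias_add: bool
    ) -> object: ...


class LinearFc2Builder(Protocol):
    # Assumed: the MLP fc2 layer optionally takes is_expert.
    def __call__(
        self, input_size: int, output_size: int, *, is_expert: bool = False
    ) -> object: ...


class ToyRowParallelLinear:
    """One class whose constructor is compatible with both protocols."""

    def __init__(
        self,
        input_size: int,
        output_size: int,
        *,
        skip_bias_add: bool = False,
        is_expert: bool = False,
    ) -> None:
        self.shape = (input_size, output_size)
        self.skip_bias_add = skip_bias_add
        self.is_expert = is_expert


class ToyBackendSpecProvider(ABC):
    @abstractmethod
    def row_parallel_linear_proj(self) -> LinearProjBuilder: ...

    @abstractmethod
    def row_parallel_linear_fc2(self) -> LinearFc2Builder: ...


class ToyLocalBackend(ToyBackendSpecProvider):
    # The same class is returned twice; each return annotation makes the
    # checker verify it against a different protocol.
    def row_parallel_linear_proj(self) -> LinearProjBuilder:
        return ToyRowParallelLinear

    def row_parallel_linear_fc2(self) -> LinearFc2Builder:
        return ToyRowParallelLinear


backend: ToyBackendSpecProvider = ToyLocalBackend()
proj = backend.row_parallel_linear_proj()(8, 8, skip_bias_add=True)
fc2 = backend.row_parallel_linear_fc2()(32, 8)
```

Callers that only know they hold a `ToyBackendSpecProvider` still get a checked builder type from each method, which is the property the bare `type` return loses.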

@svcnvidia-nemo-ci svcnvidia-nemo-ci added the Final Review PR is in the "final review" stage label Mar 19, 2026
@chtruong814 chtruong814 removed the needs-follow-up Issue needs follow-up label Mar 19, 2026
@chtruong814 chtruong814 added the needs-follow-up Issue needs follow-up label Mar 22, 2026
@jaredcasper
Contributor

> it's not entirely obvious those two things are required to have identical interfaces.

Why not? Both are row parallel linear layers. Do we have cases where the fc2 row_parallel_linear has a different interface than the linear_proj row_parallel_linear? If so that should be fixed. I don't want a backend to say "this is what I want you to use for linear proj" and "this is what I want you to use for fc2"... That's not a backend, that's just an additional layer to specs (which is already too confusing as it is). I want a backend to say "when you need a row_parallel_linear layer, wherever it is, this is the thing to use."

@nschank
Contributor Author

nschank commented Mar 24, 2026

> Why not? Both are row parallel linear layers. Do we have cases where the fc2 row_parallel_linear has a different interface than the linear_proj row_parallel_linear?

This is fair! I was trying to be less restrictive than that, and let the two APIs evolve independently. I don't really have an opinion between the two - if we want to basically merge the Linear protocols into one, and enforce that for all callers, that would be totally fine with me.

I would personally recommend letting me get this in first, and then I can merge them as an immediate followup - merging the interfaces for just row_linear will be a nontrivial task, and similarly it will be a bit of work to do the same for column_linear. So if you're fine with the somewhat gross intermediate step, I can follow up. But if not, I'm happy to spend some time getting together the full thing. Which makes more sense to you? @jaredcasper

@jaredcasper
Contributor

I think it makes sense to have the backend define types for general layer types, not specific layers (i.e. define "row parallel linear" instead of "type specifically for the linear proj"). Otherwise, as I said, it's doing the same thing as the spec in general, just hidden behind yet another layer of abstraction. Putting this in for the meantime adds an API to the backend that would then need to be changed again. Let's just go straight to row_parallel_linear().
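For concreteness, a sketch of what that could look like, assuming a single merged protocol (called `LinearLayerBuilder` here, as suggested earlier in the thread). The signature and all `Toy*` names are hypothetical.

```python
from dataclasses import dataclass
from typing import Protocol


class LinearLayerBuilder(Protocol):
    # Assumed common signature for every row-parallel linear layer.
    def __call__(
        self, input_size: int, output_size: int, *, bias: bool = True
    ) -> object: ...


class ToyRowParallelLinear:
    def __init__(self, input_size: int, output_size: int, *, bias: bool = True) -> None:
        self.shape = (input_size, output_size)
        self.bias = bias


class ToyBackend:
    def row_parallel_linear(self) -> LinearLayerBuilder:
        # One general answer: "when you need a row-parallel linear, use this".
        return ToyRowParallelLinear


@dataclass
class ToyAttentionSubmodules:
    linear_proj: LinearLayerBuilder


@dataclass
class ToyMLPSubmodules:
    linear_fc2: LinearLayerBuilder


backend = ToyBackend()
attention = ToyAttentionSubmodules(linear_proj=backend.row_parallel_linear())
mlp = ToyMLPSubmodules(linear_fc2=backend.row_parallel_linear())

proj_layer = attention.linear_proj(8, 8, bias=False)
fc2_layer = mlp.linear_fc2(32, 8)
```

The spec decides where the layer goes (linear_proj vs. fc2); the backend only decides which row-parallel linear implementation to hand out.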

@nschank
Contributor Author

nschank commented Mar 31, 2026

SG, will get back in a day or so with the update!

@nschank
Contributor Author

nschank commented Apr 1, 2026

@jaredcasper I realized the relevant work is actually somewhat independent, so I'm opening a separate PR for it here: #4087. I simply reverted the Backend changes here, so now we're just focusing on linear_proj, and the issue I noted with type-checking will be fixed by that PR.

@Phlip79 Phlip79 removed the Expert Review ("[deprecated] Apply this label to indicate that your PR is ready for expert review.") label Apr 3, 2026
@Phlip79
Member

Phlip79 commented Apr 3, 2026

/ok to test 6585350


Labels

community-request complexity: medium Final Review PR is in the "final review" stage needs-follow-up Issue needs follow-up

Projects

None yet


9 participants