Skip to content

Move loss_fn into model definition in PP example#324

Merged
xmfan merged 1 commit intomainfrom
xmfan/stack/25
Feb 24, 2026
Merged

Move loss_fn into model definition in PP example#324
xmfan merged 1 commit intomainfrom
xmfan/stack/25

Conversation

@xmfan
Copy link
Member

@xmfan xmfan commented Feb 24, 2026

Stacked PRs:


Move loss_fn into model definition in PP example

Instead of passing loss_fn to AutoParallelPP, wrap the model in a
ModelWithLoss module that bakes cross-entropy loss into forward().
This makes the example compatible with tracing self.model directly.

xmfan added a commit that referenced this pull request Feb 24, 2026
Instead of passing loss_fn to AutoParallelPP, wrap the model in a
ModelWithLoss module that bakes cross-entropy loss into forward().
This makes the example compatible with tracing self.model directly.

stack-info: PR: #324, branch: xmfan/stack/25
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Feb 24, 2026
Copy link
Contributor

@fmassa fmassa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks! Then I suppose we can remove the custom export that we have?

Instead of passing loss_fn to AutoParallelPP, wrap the model in a
ModelWithLoss module that bakes cross-entropy loss into forward().
This makes the example compatible with tracing self.model directly.

stack-info: PR: #324, branch: xmfan/stack/25
@xmfan
Copy link
Member Author

xmfan commented Feb 24, 2026

@fmassa you mean clean up AutoParallelPP? or something else?

@xmfan xmfan merged commit 454780d into main Feb 24, 2026
4 of 6 checks passed
xmfan added a commit that referenced this pull request Mar 2, 2026
Same approach as #324 for example_pp_graph_passes.py: wrap the last
stage in a ModelWithLoss module instead of passing loss_fn to
AutoParallelPP, which no longer supports it.

Authored with Claude.
xmfan added a commit that referenced this pull request Mar 2, 2026
Same approach as #324 for example_pp_graph_passes.py: wrap the last
stage in a ModelWithLoss module instead of passing loss_fn to
AutoParallelPP, which no longer supports it.

Authored with Claude.

stack-info: PR: #329, branch: xmfan/stack/27
xmfan added a commit that referenced this pull request Mar 4, 2026
Same approach as #324 for example_pp_graph_passes.py: wrap the last
stage in a ModelWithLoss module instead of passing loss_fn to
AutoParallelPP, which no longer supports it.

Authored with Claude.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants