Skip to content

Conversation

@isururanawaka
Copy link
Contributor

Summary:
This diff introduce two objectives for LP Planner considering input_dist in Critical Path.
- BALANCE_ACROSS_ALL_SYNC_POINTS_WITH_INPUT_DIST
max(fwd compute) + max(bwd compute) + sum_{module, shardtype} max(fwd comms for module) + max(bwd comms for module, shardtype) + sum_{module, shardtype} max(bwd comms for module, shardtype}
+ sum_{module} max(input_dist_comms for module)

 -  BALANCE_ACROSS_ALL_SYNC_POINTS_WITH_COMBINED_FWD_COMMS_INPUT_DIST
        max(fwd compute) + max(bwd compute) + sum_{module, shardtype} max(fwd comms + input_dist_comms for module) + max(bwd comms for module, shardtype) + sum_{module, shardtype} max(bwd comms for module, shardtype}

Differential Revision: D87389540

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Nov 24, 2025
@meta-codesync
Copy link
Contributor

meta-codesync bot commented Nov 24, 2025

@isururanawaka has exported this pull request. If you are a Meta employee, you can view the originating Diff in D87389540.

isururanawaka added a commit to isururanawaka/torchrec that referenced this pull request Nov 24, 2025
Summary:
Pull Request resolved: meta-pytorch#3575

This diff introduce two objectives for LP Planner considering  input_dist in Critical Path.
     -  BALANCE_ACROSS_ALL_SYNC_POINTS_WITH_INPUT_DIST
            max(fwd compute) + max(bwd compute) + sum_{module, shardtype} max(fwd comms for module) + max(bwd comms for module, shardtype) + sum_{module, shardtype} max(bwd comms for module, shardtype}
             + sum_{module} max(input_dist_comms for module)

     -  BALANCE_ACROSS_ALL_SYNC_POINTS_WITH_COMBINED_FWD_COMMS_INPUT_DIST
            max(fwd compute) + max(bwd compute) + sum_{module, shardtype} max(fwd comms + input_dist_comms for module) + max(bwd comms for module, shardtype) + sum_{module, shardtype} max(bwd comms for module, shardtype}

Differential Revision: D87389540
Summary:

This introduces input distribution latency estimations. Input distribution is two step communication happens inside SDD pipelines.
 - split exchange:  Exchanges buffer sizes to receive input IDS from KJTs.  The cost does not depend on Input and it meta data exchanging phase. Hence, this diff excludes that from the computations.
- ID exchange:  this exchanges actual IDs to lookup. we estimated the cost by analyzing all-to-all comms

Differential Revision: D87389540
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. fb-exported meta-exported

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant