feat: Reduce reference model memory with parallel logprob computation #608
Conversation
joecummings left a comment
I like this idea!
Could I ask for a few things?
- WandB logs that show the memory saved. This is always helpful as part of verifying correctness.
- Combine the parallel_logprobs and regular logprobs in the same file. No need to split that out just yet.
- Look for ways that this code could be simplified and/or factored out. Claude can be very verbose :)
Looking forward to getting this landed!
Thanks for the review @joecummings! I refactored the code per the feedback. Less Claude footprint now :). Let me know if the code needs to be simplified or refactored further. I attached the WandB chart images in the description; also attaching them here:
Old state usage:
New state (Parallel logprobs based) usage:
Force-pushed 53ddb5b to 20f59bf
Hi @felipemello1, the unit tests were failing. I ran the tests locally and they all pass. Can you trigger the tests again, please? Thanks!
@gitlost-murali, thanks for opening the PR. Great results! Can you try to run the non-sharded version but compile F.cross_entropy (see the sketch below)? I think that simply compiling it greatly reduces the memory, since it never materializes the intermediate activations. Maybe something to do in addition to your work rather than in place of it. I am skeptical about using the log-sum-exp directly instead of F.cross_entropy, since the compiled version is highly optimized. Also, you might be interested in checking Nathan's old PRs in torchtune: meta-pytorch/torchtune#2782
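A minimal sketch of the compile suggestion, assuming full (non-sharded) logits; the function name and shapes are illustrative, not from the PR:

```python
# Sketch: compile F.cross_entropy so the full log-softmax intermediate
# is never materialized. Names and shapes are illustrative.
import torch
import torch.nn.functional as F

@torch.compile
def compiled_logprobs(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # F.cross_entropy returns -log p(target); negate to recover log-probs.
    return -F.cross_entropy(
        logits.flatten(0, 1),  # [B*S, V]
        targets.flatten(),     # [B*S]
        reduction="none",
    ).view_as(targets)
```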
Force-pushed 78b4913 to da54044
Hi @felipemello1, thanks for the suggestion! The non-sharded version's usage went down from 58 GB to 27 GB.
Since the compiled version doesn't spike the memory, I agree with the skepticism. Currently, the reference model handles around ~9k seq-len. In a multi-turn setup, the seq-len would increase further. This is where we can still benefit from the sharded version, since it avoids the all-gather.
Oh wow, better than I expected! Thanks for doing it.
Here is what I am thinking, let me know if you agree: Would you like to take this one? (2) Regarding loss parallel, I am not super familiar with it, but it seems that distributed already has APIs for it for TP and context parallelism. I think that using those would be more robust. Perhaps check if TorchTitan already does it. I will ask internally and get back to you.
felipemello1 left a comment
Comments above.
This is what I was told, so it seems that it's easier than we thought :). Just don't call `.full_tensor()` before the `F.cross_entropy`.
Another reply I got: "I think we just need to call the loss under the context parallel context."
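Putting the two replies together, a hedged sketch, assuming the logits stay a vocab-sharded DTensor and that PyTorch's `loss_parallel` context applies here; this is not the PR's actual code:

```python
# Sketch: keep the logits as a vocab-sharded DTensor (no .full_tensor())
# and compute the loss under PyTorch's loss-parallel context, which
# performs the softmax over the sharded class dimension.
# Assumes loss_parallel supports reduction="none" for this use case.
import torch.nn.functional as F
from torch.distributed.tensor.parallel import loss_parallel

def sharded_ce_logprobs(sharded_logits, targets):
    # sharded_logits: DTensor sharded on the vocab (last) dim, [B, S, V]
    # targets:        [B, S] token ids, replicated across ranks
    with loss_parallel():
        return -F.cross_entropy(
            sharded_logits.flatten(0, 1), targets.flatten(), reduction="none"
        ).view(targets.shape)
```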
Force-pushed 2c15e66 to 5fa2fc6
Changes addressed
Awesome, thanks @gitlost-murali. I will try to get to it between today and tomorrow.
…reference model. This update introduces the function to compute log probabilities without gathering the full vocabulary tensor across GPUs.
…ge in non-sharded and sharded computation
Force-pushed 5fa2fc6 to f92b503




Summary
When tensor parallelism is enabled, the reference model's logits are sharded across GPUs on the vocabulary dimension. Previously, we called `full_tensor()` to gather the complete vocab on each GPU before computing log probabilities. This PR adds `compute_logprobs_parallel()`, which computes log probabilities in a distributed fashion using the log-sum-exp trick across shards.
Memory savings (measured)
Old state usage:
New state (Parallel logprobs based) usage:
Tested with batch=4, seq_len=9k (1024 prompt tokens + 8192 response tokens), vocab=150k, TP=2
Changes
- `src/forge/util/parallel_logprobs.py` - distributed log-prob computation for vocab-sharded DTensors
- `tests/unit_tests/util/test_parallel_logprobs.py` - correctness tests against the sequential implementation
- `src/forge/actors/reference_model.py` - uses the parallel version when TP is enabled
Implementation
Uses a distributed log-softmax without gathering the full vocabulary:
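A minimal sketch of the approach using raw collectives over the local vocab shard; the function name, signature, and `vocab_start` bookkeeping are illustrative, not the PR's actual `compute_logprobs_parallel()`:

```python
# Sketch of a vocab-sharded log-prob computation. Each rank holds a
# [B, S, V/TP] slice of the logits; nothing vocab-sized is ever gathered.
import torch
import torch.distributed as dist


def sharded_logprobs(logits, target_ids, vocab_start, process_group=None):
    """Log-probabilities of target_ids without gathering the full vocab.

    logits:      [B, S, V_local] local vocab shard
    target_ids:  [B, S] global token ids
    vocab_start: first global vocab index owned by this rank
    """
    # 1) Numerically stable global log-sum-exp across shards.
    local_max = logits.max(dim=-1, keepdim=True).values           # [B, S, 1]
    global_max = local_max.clone()
    dist.all_reduce(global_max, op=dist.ReduceOp.MAX, group=process_group)
    sum_exp = (logits - global_max).exp().sum(dim=-1, keepdim=True)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=process_group)
    logsumexp = global_max + sum_exp.log()                        # [B, S, 1]

    # 2) Pick out the target logit; zero it on ranks that don't own the id.
    vocab_local = logits.size(-1)
    local_ids = target_ids - vocab_start
    in_shard = (local_ids >= 0) & (local_ids < vocab_local)
    safe_ids = local_ids.clamp(0, vocab_local - 1)
    target_logit = logits.gather(-1, safe_ids.unsqueeze(-1)).squeeze(-1)
    target_logit = torch.where(in_shard, target_logit, torch.zeros_like(target_logit))
    dist.all_reduce(target_logit, op=dist.ReduceOp.SUM, group=process_group)

    # 3) log p(target) = target_logit - logsumexp.
    return target_logit - logsumexp.squeeze(-1)
```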
Testing
- Matches the sequential `compute_logprobs()` within 1e-5 tolerance
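A hypothetical single-process check of the log-sum-exp decomposition the parallel path relies on (not the PR's actual test file):

```python
# Hypothetical check: split the vocab into "shards", merge per-shard
# log-sum-exp statistics, and compare against the full-vocab reference.
import torch

def test_sharded_logsumexp_matches_reference():
    torch.manual_seed(0)
    logits = torch.randn(4, 9, 150)                 # [B, S, V]
    shards = logits.chunk(2, dim=-1)                # simulate TP=2
    per_shard = torch.stack([s.logsumexp(dim=-1) for s in shards])
    combined = per_shard.logsumexp(dim=0)           # merge shard statistics
    reference = logits.logsumexp(dim=-1)
    torch.testing.assert_close(combined, reference, atol=1e-5, rtol=0)
```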