Conversation

@gitlost-murali (Contributor) commented Nov 30, 2025

Summary

When tensor parallelism is enabled, the reference model's logits are sharded across GPUs on the vocabulary dimension. Previously, we called full_tensor() to gather the complete vocab on each GPU before computing log probabilities.

This PR adds compute_logprobs_parallel(), which computes log probabilities in a distributed fashion, using the log-sum-exp trick across the vocab shards.

Memory savings (measured)

| Scenario | Memory per GPU |
| --- | --- |
| Old (full_tensor + compute_logprobs) | 58 GB |
| New (parallel logprobs) | 34 GB |
| Saved | 24 GB (~41%) |

Old state usage:

[WandB memory chart: current]

New state (parallel logprobs based) usage:

[WandB memory chart: optimized]

Tested with batch=4, seq_len=9k (1024 prompt tokens + 8192 response tokens), vocab=150k, TP=2

Changes

  • New: src/forge/util/parallel_logprobs.py - distributed log-prob computation for vocab-sharded DTensors
  • New: tests/unit_tests/util/test_parallel_logprobs.py - correctness tests against sequential implementation
  • Modified: src/forge/actors/reference_model.py - uses parallel version when TP is enabled

Implementation

Uses a distributed log-softmax without gathering the full vocabulary (see the sketch after this list):

  1. All-reduce MAX for numerical stability
  2. All-reduce SUM of local exp(x - max)
  3. Each rank gathers logits only for tokens in its shard
  4. All-reduce SUM to combine (only owning rank contributes)
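
For illustration, here is a minimal sketch of these four steps using raw collectives. It assumes local_logits is this rank's vocab shard of shape [batch, seq, vocab_local], with equal-sized shards laid out in rank order, and target_ids holds global token ids replicated on every rank; the names and process-group handling are illustrative, not the exact code in parallel_logprobs.py.

import torch
import torch.distributed as dist

def compute_logprobs_parallel_sketch(local_logits, target_ids, tp_group=None):
    rank = dist.get_rank(tp_group)
    vocab_local = local_logits.size(-1)
    vocab_start = rank * vocab_local  # assumes equal shards in rank order

    # 1. All-reduce MAX for numerical stability
    global_max = local_logits.amax(dim=-1)
    dist.all_reduce(global_max, op=dist.ReduceOp.MAX, group=tp_group)

    # 2. All-reduce SUM of local exp(x - max) -> softmax denominator
    sum_exp = torch.exp(local_logits - global_max.unsqueeze(-1)).sum(dim=-1)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=tp_group)

    # 3. Each rank picks out logits only for target tokens that fall in its shard
    in_shard = (target_ids >= vocab_start) & (target_ids < vocab_start + vocab_local)
    local_ids = (target_ids - vocab_start).clamp(0, vocab_local - 1)
    target_logits = torch.gather(local_logits, -1, local_ids.unsqueeze(-1)).squeeze(-1)
    target_logits = torch.where(in_shard, target_logits, torch.zeros_like(target_logits))

    # 4. All-reduce SUM so every rank ends up with the target logit
    #    (only the owning rank contributed a non-zero value)
    dist.all_reduce(target_logits, op=dist.ReduceOp.SUM, group=tp_group)

    # log p(target) = x_target - max - log(sum exp(x - max))
    return target_logits - global_max - torch.log(sum_exp)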

Testing

  • Verified results match compute_logprobs() within 1e-5 tolerance
  • Tested temperature scaling, alignment modes, numerical stability with extreme values
  • Tested 2-way vocab sharded config

@meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) on Nov 30, 2025.
@gitlost-murali changed the title from "feat: Distributed log-prob computation for vocab-sharded reference model" to "feat: Optimize reference model GPU usage by distributed log-prob computation on vocab-sharded logits" on Nov 30, 2025.
@gitlost-murali changed the title to "feat: Reduce reference model memory usage with distributed log-probs comp" on Nov 30, 2025.
@gitlost-murali changed the title to "feat: Reduce reference model memory with parallel logprob computation" on Nov 30, 2025.
@joecummings (Member) left a comment:

I like this idea!

Could I ask for a few things?

  1. WandB logs that show the memory saved. This is always helpful as a part of verifying the correctness.
  2. Combine the parallel_logprobs and regular logprobs in the same file. No need to split that out just yet.
  3. Look for ways that this code could be simplified and/or factored out. Claude can be very verbose :)

Looking forward to getting this landed!

@gitlost-murali (Contributor, Author) commented:

Thanks for the review, @joecummings!

I refactored the code as per feedback. Less Claude footprint now :). Let me know if the code needs to be further simplified/refactored.

I attached wandb chart images in the description. Also attaching them here:

Old state usage:

[WandB memory chart: current]

New state (parallel logprobs based) usage:

[WandB memory chart: optimized]

@gitlost-murali (Contributor, Author) commented:

Hi @felipemello1,

The unit tests were failing because pytz was missing from the CI env. I rebased on main now; it looks like #618 (easy - remove pytz) takes care of this.

Ran the tests locally. All pass. Can you trigger the tests again please?

Thanks!

@felipemello1 (Contributor) commented Dec 3, 2025

@gitlost-murali, thanks for opening the PR. Great results!

Can you try to run the non-sharded version but compile F.cross_entropy? e.g.

@torch.compile()
def compute_logprobs(...):
    ...

I think that simply compiling it greatly reduces the memory, since it never materializes the intermediate activations. Maybe something to do in addition to your work and not in place of your work. I am skeptical about using the log-sum-exp directly and not F.cross_entropy, since the compiled version is highly optimized.
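
For reference, here is a minimal sketch of the compiled non-sharded path, assuming the logprob helper boils down to a cross-entropy over the full [batch, seq, vocab] logits (the actual forge helper may differ); as suggested above, compiling lets torch.compile fuse the log-softmax and gather so the full-size intermediates are not materialized the way they are in eager mode.

import torch
import torch.nn.functional as F

@torch.compile()
def compute_logprobs_sketch(logits, target_ids, temperature=1.0):
    # logits: full (non-sharded) [batch, seq, vocab] tensor; target_ids: [batch, seq]
    ce = F.cross_entropy(
        (logits / temperature).flatten(0, 1),  # [batch*seq, vocab]
        target_ids.flatten(),                  # [batch*seq]
        reduction="none",
    )
    # log p(target) = -cross_entropy, reshaped back to [batch, seq]
    return (-ce).view_as(target_ids)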

Also, you might be interested in checking Nathan's old PRs in torchtune: meta-pytorch/torchtune#2782

@gitlost-murali force-pushed the optimize-ref-model-usage branch from 78b4913 to da54044 on December 5, 2025, 11:25.
@gitlost-murali (Contributor, Author) commented Dec 5, 2025

Hi @felipemello1,

Thanks for the suggestion! torch.compile greatly reduced the memory usage.

  • Non-sharded version: 58 GB → 27 GB
  • Sharded version: 34 GB → 7 GB

[Screenshot: WandB memory comparison]

> I am skeptical about using the log-sum-exp directly and not F.cross_entropy, since the compiled version is highly optimized.

As the compiled version doesn't spike the memory, I agree with the skepticism.

Currently, the reference model handles around 9k seq-len. For a multi-turn setup, the seq-len would increase further. This is where we can benefit from the sharded version, as it avoids the all-gather (.full_tensor()). But if you think the sharded version (log-sum-exp) is overkill or mixes levels of abstraction, I am happy to reduce this PR to just adding the decorator on the current non-sharded version.

@felipemello1 (Contributor) commented Dec 5, 2025

Oh wow, better than I expected! Thanks for doing it.

> I am happy to reduce this PR to just adding the decorator on the current non-sharded version.

Here is what I am thinking; let me know if you agree:
(1) Yes, we should have a PR where we enable torch.compile on this function. A decorator is easy, but there is no way to disable it if for some reason we need to. The reference model already has a flag for compile. We can do this in post_init:

self.compute_log_probs = compute_logprobs
if compile:
    self.compute_log_probs = torch.compile(self.compute_log_probs)

Would you like to take this one?

(2) Regarding loss parallel, I am not super familiar with it, but it seems that distributed already has APIs for it for TP and context parallelism.

TP: https://github.com/meta-pytorch/torchtune/blob/67ab86b94de9e7ac7dd9850113ebe69e2bbd307c/torchtune/training/_distributed.py#L894

CP: https://github.com/meta-pytorch/torchtune/blob/67ab86b94de9e7ac7dd9850113ebe69e2bbd307c/torchtune/training/_distributed.py#L844

I think that using those would be more robust. Perhaps check if TorchTitan already does it. I will ask internally and get back to you.

@felipemello1 (Contributor) left a review comment:

comments above

@felipemello1 (Contributor) commented:

This is what I was told:

> Now in torchtitan, when TP is enabled, we keep the output sharded and use loss_parallel context to calculate loss using sharded logits: https://docs.pytorch.org/docs/stable/distributed.tensor.parallel.html#torch.distributed.tensor.parallel.loss_parallel

So it seems that it's easier than we thought :). Just don't call .full_tensor() before the F.cross_entropy.
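
A minimal sketch of that approach, assuming logits stays a DTensor sharded on the vocab (last) dimension and target_ids is a plain tensor of token ids; names are illustrative, not the exact reference_model.py code.

import torch
import torch.nn.functional as F
from torch.distributed.tensor.parallel import loss_parallel

def compute_logprobs_loss_parallel_sketch(logits, target_ids, temperature=1.0):
    # logits: DTensor [batch, seq, vocab] sharded on the vocab dim -- no .full_tensor()
    # target_ids: [batch, seq] token ids
    with loss_parallel():
        # under loss_parallel, cross_entropy consumes the vocab-sharded DTensor
        # directly, so the full vocabulary is never gathered onto a single GPU
        ce = F.cross_entropy(
            (logits / temperature).flatten(0, 1),  # [batch*seq, vocab], class dim sharded
            target_ids.flatten(),                  # [batch*seq]
            reduction="none",
        )
    # log p(target) = -cross_entropy, reshaped back to [batch, seq]
    return (-ce).view(target_ids.shape)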

@felipemello1 (Contributor) commented Dec 5, 2025

Another reply I got:

> CP loss computation is naturally parallelized, along seq dim. Similar for DP.
>
> For TP you can parallelize along vocab dimension, using loss_parallel context https://github.com/pytorch/torchtitan/blob/1168f9e4d58bbd91c07b08c382d1ca3ae4b2e02c/torchtitan/distributed/utils.py#L

I think we just need to call the loss under the loss_parallel context.

@gitlost-murali (Contributor, Author) commented Dec 6, 2025

Amazing! Thanks a lot for the pointers! This simplifies the code a lot, and I like that compile is configurable.

I updated the code to use the loss_parallel context manager. Here are the numbers for sharded computation with loss_parallel():

  • 39 GB with no-compile vs 7 GB with compile

[Screenshot: WandB memory comparison]

Btw, should we leave the compile.enable flag as false in the yml files, or change it to true?

  compile:
    enable: false

Overall summary (no-compile vs compile):

  • Non-sharded version: 58 GB vs 27 GB
  • Sharded version (custom sharding logic): 34 GB vs 7 GB
  • Sharded version (current version, with the loss_parallel() ctx manager): 39 GB vs 7 GB

@gitlost-murali (Contributor, Author) left a review comment:

Changes addressed

@felipemello1 (Contributor) commented:

Awesome, thanks @gitlost-murali. I will try to get to it between today and tomorrow.

@gitlost-murali force-pushed the optimize-ref-model-usage branch from 5fa2fc6 to f92b503 on December 8, 2025, 22:00.