We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent d58e684 commit 716b8beCopy full SHA for 716b8be
src/forge/actors/reference_model.py
@@ -194,6 +194,8 @@ async def forward(
194
if parallel_dims.tp_enabled and isinstance(logits, DTensor):
195
with loss_parallel():
196
logprobs = compute_logprobs(logits, response_tokens, align=True)
197
+
198
+ logprobs = logprobs.to_local()
199
else:
200
logprobs = compute_logprobs(logits, response_tokens)
201
t.step("compute_logprobs")
0 commit comments