We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent a5bd773 commit ce51a2fCopy full SHA for ce51a2f
src/forge/actors/reference_model.py
@@ -194,6 +194,8 @@ async def forward(
194
if parallel_dims.tp_enabled and isinstance(logits, DTensor):
195
with loss_parallel():
196
logprobs = compute_logprobs(logits, response_tokens, align=True)
197
+
198
+ logprobs = logprobs.to_local()
199
else:
200
logprobs = compute_logprobs(logits, response_tokens)
201
t.step("compute_logprobs")
0 commit comments