Skip to content

Commit ce51a2f

Browse files
fix: convert DTensor output to regular tensor after loss_parallel
1 parent a5bd773 commit ce51a2f

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

src/forge/actors/reference_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,8 @@ async def forward(
194194
if parallel_dims.tp_enabled and isinstance(logits, DTensor):
195195
with loss_parallel():
196196
logprobs = compute_logprobs(logits, response_tokens, align=True)
197+
198+
logprobs = logprobs.to_local()
197199
else:
198200
logprobs = compute_logprobs(logits, response_tokens)
199201
t.step("compute_logprobs")

0 commit comments

Comments
 (0)