Skip to content

Commit f92b503

Browse files
refactor: make compile configurable for logprobs computation
1 parent 90120bb commit f92b503

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

src/forge/actors/reference_model.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,10 @@ def __post_init__(self):
9999
self.rank = current_rank().rank
100100
self.size = math.prod(current_size().values())
101101

102+
self.compute_log_probs = compute_logprobs
103+
if self.compile.enable:
104+
self.compute_log_probs = torch.compile(self.compute_log_probs)
105+
102106
env = {
103107
"RANK": str(self.rank),
104108
"LOCAL_RANK": str(self.rank),
@@ -187,11 +191,11 @@ async def forward(
187191
response_tokens = input_ids[:, max_req_tokens:]
188192
if parallel_dims.tp_enabled and isinstance(logits, DTensor):
189193
with loss_parallel():
190-
logprobs = compute_logprobs(logits, response_tokens, align=True)
194+
logprobs = self.compute_log_probs(logits, response_tokens)
191195

192196
logprobs = logprobs.to_local()
193197
else:
194-
logprobs = compute_logprobs(logits, response_tokens)
198+
logprobs = self.compute_log_probs(logits, response_tokens)
195199
t.step("compute_logprobs")
196200
t.stop()
197201
return logprobs

src/forge/util/ops.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from torch.distributed.tensor import DTensor
1010

1111

12-
@torch.compile
1312
def compute_logprobs(
1413
logits: torch.Tensor | DTensor,
1514
input_ids: torch.Tensor,

0 commit comments

Comments
 (0)