Skip to content

Commit 2c15e66

Browse files
refactor: make compile configurable for logprobs computation
1 parent 716b8be commit 2c15e66

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

src/forge/actors/reference_model.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,10 @@ def __post_init__(self):
9999
self.rank = current_rank().rank
100100
self.size = math.prod(current_size().values())
101101

102+
self.compute_log_probs = compute_logprobs
103+
if self.compile.enable:
104+
self.compute_log_probs = torch.compile(self.compute_log_probs)
105+
102106
env = {
103107
"RANK": str(self.rank),
104108
"LOCAL_RANK": str(self.rank),
@@ -193,11 +197,11 @@ async def forward(
193197
response_tokens = input_ids[:, max_req_tokens:]
194198
if parallel_dims.tp_enabled and isinstance(logits, DTensor):
195199
with loss_parallel():
196-
logprobs = compute_logprobs(logits, response_tokens, align=True)
200+
logprobs = self.compute_log_probs(logits, response_tokens)
197201

198202
logprobs = logprobs.to_local()
199203
else:
200-
logprobs = compute_logprobs(logits, response_tokens)
204+
logprobs = self.compute_log_probs(logits, response_tokens)
201205
t.step("compute_logprobs")
202206
t.stop()
203207
return logprobs

src/forge/util/ops.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from torch.distributed.tensor import DTensor
1010

1111

12-
@torch.compile
1312
def compute_logprobs(
1413
logits: torch.Tensor | DTensor,
1514
input_ids: torch.Tensor,

0 commit comments

Comments
 (0)