File tree Expand file tree Collapse file tree 2 files changed +6
-3
lines changed
Expand file tree Collapse file tree 2 files changed +6
-3
lines changed Original file line number Diff line number Diff line change @@ -99,6 +99,10 @@ def __post_init__(self):
9999 self .rank = current_rank ().rank
100100 self .size = math .prod (current_size ().values ())
101101
102+ self .compute_log_probs = compute_logprobs
103+ if self .compile .enable :
104+ self .compute_log_probs = torch .compile (self .compute_log_probs )
105+
102106 env = {
103107 "RANK" : str (self .rank ),
104108 "LOCAL_RANK" : str (self .rank ),
@@ -187,11 +191,11 @@ async def forward(
187191 response_tokens = input_ids [:, max_req_tokens :]
188192 if parallel_dims .tp_enabled and isinstance (logits , DTensor ):
189193 with loss_parallel ():
190- logprobs = compute_logprobs (logits , response_tokens , align = True )
194+ logprobs = self . compute_log_probs (logits , response_tokens )
191195
192196 logprobs = logprobs .to_local ()
193197 else :
194- logprobs = compute_logprobs (logits , response_tokens )
198+ logprobs = self . compute_log_probs (logits , response_tokens )
195199 t .step ("compute_logprobs" )
196200 t .stop ()
197201 return logprobs
Original file line number Diff line number Diff line change 99from torch .distributed .tensor import DTensor
1010
1111
12- @torch .compile
1312def compute_logprobs (
1413 logits : torch .Tensor | DTensor ,
1514 input_ids : torch .Tensor ,
You can’t perform that action at this time.
0 commit comments