File tree Expand file tree Collapse file tree 2 files changed +6
-3
lines changed
Expand file tree Collapse file tree 2 files changed +6
-3
lines changed Original file line number Diff line number Diff line change @@ -99,6 +99,10 @@ def __post_init__(self):
9999 self .rank = current_rank ().rank
100100 self .size = math .prod (current_size ().values ())
101101
102+ self .compute_log_probs = compute_logprobs
103+ if self .compile .enable :
104+ self .compute_log_probs = torch .compile (self .compute_log_probs )
105+
102106 env = {
103107 "RANK" : str (self .rank ),
104108 "LOCAL_RANK" : str (self .rank ),
@@ -193,11 +197,11 @@ async def forward(
193197 response_tokens = input_ids [:, max_req_tokens :]
194198 if parallel_dims .tp_enabled and isinstance (logits , DTensor ):
195199 with loss_parallel ():
196- logprobs = compute_logprobs (logits , response_tokens , align = True )
200+ logprobs = self . compute_log_probs (logits , response_tokens )
197201
198202 logprobs = logprobs .to_local ()
199203 else :
200- logprobs = compute_logprobs (logits , response_tokens )
204+ logprobs = self . compute_log_probs (logits , response_tokens )
201205 t .step ("compute_logprobs" )
202206 t .stop ()
203207 return logprobs
Original file line number Diff line number Diff line change 99from torch .distributed .tensor import DTensor
1010
1111
12- @torch .compile
1312def compute_logprobs (
1413 logits : torch .Tensor | DTensor ,
1514 input_ids : torch .Tensor ,
You can’t perform that action at this time.
0 commit comments