@@ -1721,7 +1721,6 @@ def logit_bias_processor(
                 for i, token in enumerate(all_tokens)
             ]
             all_logprobs = Llama.logits_to_logprobs(self._scores)[token_offset:]
-            # TODO: may be able to change this loop to use np.take_along_dim
             for idx, (token, token_str, logprobs_token) in enumerate(
                 zip(all_tokens, all_token_strs, all_logprobs)
             ):
@@ -1736,17 +1735,16 @@ def logit_bias_processor(
                     )
                 )
                 tokens.append(token_str)
-                sorted_logprobs = list(
-                    sorted(
-                        zip(logprobs_token, range(len(logprobs_token))), reverse=True
-                    )
-                )
+                top_k_indices = np.argpartition(logprobs_token, -logprobs)[-logprobs:]
+                top_k_indices = top_k_indices[
+                    np.argsort(logprobs_token[top_k_indices])
+                ][::-1]
                 token_logprobs.append(logprobs_token[int(token)])
                 top_logprob: Optional[Dict[str, float]] = {
-                    self.detokenize([i], prev_tokens=all_tokens[:idx]).decode(
+                    self.detokenize([int(i)], prev_tokens=all_tokens[:idx]).decode(
                         "utf-8", errors="ignore"
-                    ): logprob
-                    for logprob, i in sorted_logprobs[:logprobs]
+                    ): logprobs_token[int(i)]
+                    for i in top_k_indices
                 }
                 top_logprob.update({token_str: logprobs_token[int(token)]})
                 top_logprobs.append(top_logprob)
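The diff replaces a full `sorted()` over the whole vocabulary with a partial selection: `np.argpartition` gathers the indices of the `logprobs` largest entries in O(n), and a small `argsort` over just those entries orders them descending. Below is a minimal standalone sketch of that pattern, checked against the full-sort approach the diff removes; the function name `top_k_logprobs` and the synthetic data are illustrative, not part of the PR.

```python
import numpy as np

def top_k_logprobs(logprobs_token: np.ndarray, k: int) -> dict[int, float]:
    # argpartition moves the indices of the k largest values into the last
    # k slots in O(n) time, without fully sorting the vocabulary.
    top_k_indices = np.argpartition(logprobs_token, -k)[-k:]
    # A follow-up argsort over only those k entries orders them descending,
    # for an overall cost of O(n + k log k) instead of O(n log n).
    top_k_indices = top_k_indices[np.argsort(logprobs_token[top_k_indices])][::-1]
    return {int(i): float(logprobs_token[i]) for i in top_k_indices}

# Sanity check against the full sort that the diff removes (illustrative data).
rng = np.random.default_rng(0)
scores = np.log(rng.dirichlet(np.ones(32_000)))  # vocabulary-sized logprob vector
full_sort = sorted(zip(scores, range(len(scores))), reverse=True)[:5]
assert list(top_k_logprobs(scores, 5)) == [i for _, i in full_sort]
```

One behavioral difference worth noting: `sorted(..., reverse=True)` over (value, index) tuples breaks ties by index, while `argpartition` gives no such guarantee, so tied logprobs may be returned in a different order.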