Skip to content

Commit 4d4c571

Browse files
author
Ralf Waldukat
committed
perf: use np.argpartition for top-k logprobs instead of full vocab sort
Replace the O(V log V) Python-level sorted() over (logprob, index) tuples with an O(V) np.argpartition followed by an O(k log k) sort of only the top-k candidates. For a 128K-token vocabulary this avoids both the full sort and the per-element Python tuple overhead, making the top-k selection dramatically faster.
1 parent 488cb3e commit 4d4c571

File tree

1 file changed

+7
-9
lines changed

1 file changed

+7
-9
lines changed

llama_cpp/llama.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1721,7 +1721,6 @@ def logit_bias_processor(
17211721
for i, token in enumerate(all_tokens)
17221722
]
17231723
all_logprobs = Llama.logits_to_logprobs(self._scores)[token_offset:]
1724-
# TODO: may be able to change this loop to use np.take_along_dim
17251724
for idx, (token, token_str, logprobs_token) in enumerate(
17261725
zip(all_tokens, all_token_strs, all_logprobs)
17271726
):
@@ -1736,17 +1735,16 @@ def logit_bias_processor(
17361735
)
17371736
)
17381737
tokens.append(token_str)
1739-
sorted_logprobs = list(
1740-
sorted(
1741-
zip(logprobs_token, range(len(logprobs_token))), reverse=True
1742-
)
1743-
)
1738+
top_k_indices = np.argpartition(logprobs_token, -logprobs)[-logprobs:]
1739+
top_k_indices = top_k_indices[
1740+
np.argsort(logprobs_token[top_k_indices])
1741+
][::-1]
17441742
token_logprobs.append(logprobs_token[int(token)])
17451743
top_logprob: Optional[Dict[str, float]] = {
1746-
self.detokenize([i], prev_tokens=all_tokens[:idx]).decode(
1744+
self.detokenize([int(i)], prev_tokens=all_tokens[:idx]).decode(
17471745
"utf-8", errors="ignore"
1748-
): logprob
1749-
for logprob, i in sorted_logprobs[:logprobs]
1746+
): logprobs_token[int(i)]
1747+
for i in top_k_indices
17501748
}
17511749
top_logprob.update({token_str: logprobs_token[int(token)]})
17521750
top_logprobs.append(top_logprob)

0 commit comments

Comments
 (0)