Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions tools/server/server-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2396,7 +2396,7 @@ struct server_context_impl {
SLT_WRN(slot, "%s\n", st1.str().c_str());
}

- if (pos_min >= pos_min_thold) {
+ if (pos_min > pos_min_thold) {
SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);

// search for a context checkpoint
Expand All @@ -2407,7 +2407,7 @@ struct server_context_impl {
// guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
LOG_INF("slot %12.*s: id %2d | task %d | Checking checkpoint with [%d, %d] against %d...\n", 12,
func_name, (slot).id, ((slot).task ? (slot).task->id : -1), cur.pos_min, cur.pos_max, pos_min_thold);
- return cur.pos_min < pos_min_thold || cur.pos_min == 0;
+ return cur.pos_min <= pos_min_thold;
}
);

Expand All @@ -2421,7 +2421,6 @@ struct server_context_impl {
if (n != checkpoint_size) {
SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, (float) checkpoint_size / 1024 / 1024);
do_reset = true;
- //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
} else {
pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max));
n_past = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens);
Expand Down
17 changes: 13 additions & 4 deletions tools/server/server-task.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2036,12 +2036,12 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
}

bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) {
- const int lcp_best = prompt.tokens.get_common_prefix(tokens_new);
+ int lcp_best = prompt.tokens.get_common_prefix(tokens_new);

float f_keep_best = prompt.tokens.size() > 0 ? float(lcp_best) / prompt.tokens.size() : -1.0f; // empty slot: any cache entry wins
float sim_best = float(lcp_best) / tokens_new.size();

- SRV_WRN(" - looking for better prompt, base f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
+ SRV_WRN(" - looking for better prompt, base lcp = %d, f_keep = %.3f, sim = %.3f\n", lcp_best, f_keep_best, sim_best);

auto it_best = states.end();

Expand All @@ -2057,7 +2057,16 @@ bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tok
continue;
}

- if (f_keep_best < f_keep_cur && sim_best < sim_cur) {
+ // Prioritize absolute prefix reuse length. This helps promote a newly
+ // joined static prefix (e.g. after truncate-middle) over older cache entries
+ // that keep more of their own state but match fewer request tokens.
+ const bool better_match =
+ lcp_cur > lcp_best ||
+ (lcp_cur == lcp_best && sim_cur > sim_best) ||
+ (lcp_cur == lcp_best && sim_cur == sim_best && f_keep_cur > f_keep_best);
+
+ if (better_match) {
+ lcp_best = lcp_cur;
f_keep_best = f_keep_cur;
sim_best = sim_cur;

Expand All @@ -2066,7 +2075,7 @@ bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tok
}

if (it_best != states.end()) {
- SRV_WRN(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
+ SRV_WRN(" - found better prompt with lcp = %d, f_keep = %.3f, sim = %.3f\n", lcp_best, f_keep_best, sim_best);

const size_t size = it_best->data.size();
const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0);
Expand Down