From 12706d051a8e9f319f3b5c6edd39a206b8b04382 Mon Sep 17 00:00:00 2001 From: Dr Sujit Vasanth Date: Mon, 11 May 2026 02:35:38 +0100 Subject: [PATCH 1/2] feat: one-sided target probability acceptance for MTP drafts increases acceptance rate and throughput compared to argmax alone MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit MTP drafters use greedy argmax internally — they do not expose a full logit distribution, by design, for speed. This change adds a further tok/s improvement by allowing users to tune the acceptance threshold, achieving ~20% throughput gains by accepting more draft tokens while retaining the ability to manually verify the threshold at which semantic breakdown occurs for their specific model/task combination. When the drafter and target model disagree on a token, rather than immediately rejecting (standard argmax behaviour), --draft-p-accept triggers a one-sided softmax check over the target model's logits for the draft token. If the target assigns p >= draft-p-accept to that token, it is accepted in place of the target's own argmax prediction and decoding continues. No drafter logits are required, keeping the drafter inference path unchanged and preserving the speed advantage of argmax-only drafting. This is intentionally lighter than the full ratio test in the MTP paper. Changes: - common/sampling.cpp: add p_accept parameter to sample_and_accept_n; on drafter/target disagreement compute softmax over target logits and accept draft token if p_target(draft_token) >= p_accept - common/sampling.h: update both overloads of sample_and_accept_n signature - common/arg.cpp: register --draft-p-accept CLI argument - common/common.h: add p_accept field to common_params_speculative struct - tools/server/server-context.cpp: wire p_accept into speculative config Usage: --draft-p-accept 0.005 # accept draft token if p_target >= 0.005 --draft-p-accept 0.0 # standard argmax-only behaviour (default) --- common/arg.cpp | 10 ++++++++++ common/common.h | 2 ++ common/sampling.cpp | 22 +++++++++++++++------- common/sampling.h | 4 ++-- tools/server/server-context.cpp | 2 +- 5 files changed, 30 insertions(+), 10 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index cb36c554128..e98cd1a70a6 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -3448,6 +3448,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.speculative.p_split = std::stof(value); } ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}).set_env("LLAMA_ARG_DRAFT_P_SPLIT")); + add_opt(common_arg( {"--draft-p-min"}, "P", string_format("minimum speculative decoding probability (greedy) (default: %.2f)", (double)params.speculative.p_min), @@ -3455,6 +3456,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.speculative.p_min = std::stof(value); } ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_DRAFT_P_MIN")); + + add_opt(common_arg( + {"--draft-p-accept"}, "P", + string_format("MTP draft acceptance probability threshold - accept non-argmax draft token if main model assigns it at least this probability (default: %.2f, 0.0 = greedy match only)", (double)params.speculative.p_accept), + [](common_params & params, const std::string & value) { + params.speculative.p_accept = std::stof(value); + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_DRAFT_P_ACCEPT")); + add_opt(common_arg( {"-cd", "--ctx-size-draft"}, "N", string_format("size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.speculative.n_ctx), diff --git a/common/common.h b/common/common.h index 674b2488557..73e85c30694 100644 --- a/common/common.h +++ b/common/common.h @@ -325,6 +325,8 @@ struct common_params_speculative { int32_t draft_block_size = 3; float p_split = 0.1f; // speculative decoding split probability float p_min = 0.75f; // minimum speculative decoding probability (greedy) + float p_accept = 0.0f; // min probability for main model to accept a non-argmax MTP draft token + // ngram-based speculative decoding diff --git a/common/sampling.cpp b/common/sampling.cpp index 526f036ff98..e7bc6503f69 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -605,7 +605,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co return id; } -std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const llama_tokens & draft, bool grammar_first) { +std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const llama_tokens & draft, bool grammar_first, float p_accept) { GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1"); std::vector result; @@ -614,12 +614,21 @@ std::vector common_sampler_sample_and_accept_n(struct common_sample size_t i = 0; for (; i < draft.size(); i++) { const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first); - common_sampler_accept(gsmpl, id, true); - result.push_back(id); - if (draft[i] != id) { + if (p_accept > 0.0f) { + const float * logits = llama_get_logits_ith(ctx, idxs[i]); + const int n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(llama_get_model(ctx))); + float max_l = *std::max_element(logits, logits + n_vocab); + float sum = 0.0f; + for (int j = 0; j < n_vocab; j++) sum += expf(logits[j] - max_l); + const float p_main = expf(logits[draft[i]] - max_l) / sum; + if (p_main >= p_accept) { + result.back() = draft[i]; + continue; + } + } break; } } @@ -635,13 +644,12 @@ std::vector common_sampler_sample_and_accept_n(struct common_sample return result; } -std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) { +std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first, float p_accept) { std::vector idxs(draft.size() + 1); for (size_t i = 0; i < idxs.size(); ++i) { idxs[i] = i; } - - return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first); + return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first, p_accept); } uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) { diff --git a/common/sampling.h b/common/sampling.h index 2426126b346..7f99136b6dc 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -82,10 +82,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co // // returns at least 1 token, up to idxs.size() // -std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const llama_tokens & draft, bool grammar_first = false); +std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const llama_tokens & draft, bool grammar_first = false, float p_accept = 0.0f); // assume idxs == [ 0, 1, 2, ..., draft.size() ] -std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false); +std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false, float p_accept = 0.0f); uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index fea3ed87502..515a08628c8 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -2938,7 +2938,7 @@ struct server_context_impl { const size_t n_draft = slot.drafted.size(); // the accepted tokens from the speculation - const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, slot.i_batch_dft, slot.drafted); + const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, slot.i_batch_dft, slot.drafted, false, params_base.speculative.p_accept); // For MTP speculation, h_prev for the next draft must come from the LAST ACCEPTED // batch output - not embeddings_ith(-1), which would point at a rejected draft's From 1c5d20851f8cd6aa95ede1265f4a235cde7a0448 Mon Sep 17 00:00:00 2001 From: Dr Sujit Vasanth Date: Wed, 13 May 2026 03:20:47 +0100 Subject: [PATCH 2/2] fix: defer common_sampler_accept until after p_accept resolution Fixes sampler state bug identified by Ooooze - previously common_sampler_accept was called with target id before p_accept check, leaving grammar FSM and gsmpl->prev tracking wrong token when draft token was substituted. --- common/sampling.cpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index e7bc6503f69..02ac93f55bc 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -614,8 +614,6 @@ std::vector common_sampler_sample_and_accept_n(struct common_sample size_t i = 0; for (; i < draft.size(); i++) { const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first); - common_sampler_accept(gsmpl, id, true); - result.push_back(id); if (draft[i] != id) { if (p_accept > 0.0f) { const float * logits = llama_get_logits_ith(ctx, idxs[i]); @@ -625,12 +623,17 @@ std::vector common_sampler_sample_and_accept_n(struct common_sample for (int j = 0; j < n_vocab; j++) sum += expf(logits[j] - max_l); const float p_main = expf(logits[draft[i]] - max_l) / sum; if (p_main >= p_accept) { - result.back() = draft[i]; + common_sampler_accept(gsmpl, draft[i], true); + result.push_back(draft[i]); continue; } } + common_sampler_accept(gsmpl, id, true); + result.push_back(id); break; } + common_sampler_accept(gsmpl, id, true); + result.push_back(id); } if (i == draft.size()) {