diff --git a/common/arg.cpp b/common/arg.cpp
index cb36c554128..e98cd1a70a6 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -3448,6 +3448,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.speculative.p_split = std::stof(value);
         }
     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}).set_env("LLAMA_ARG_DRAFT_P_SPLIT"));
+
     add_opt(common_arg(
         {"--draft-p-min"}, "P",
         string_format("minimum speculative decoding probability (greedy) (default: %.2f)", (double)params.speculative.p_min),
@@ -3455,6 +3456,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.speculative.p_min = std::stof(value);
         }
     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_DRAFT_P_MIN"));
+
+    add_opt(common_arg(
+        {"--draft-p-accept"}, "P",
+        string_format("MTP draft acceptance probability threshold - accept non-argmax draft token if main model assigns it at least this probability (default: %.2f, 0.0 = greedy match only)", (double)params.speculative.p_accept),
+        [](common_params & params, const std::string & value) {
+            params.speculative.p_accept = std::stof(value);
+        }
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_DRAFT_P_ACCEPT"));
+
     add_opt(common_arg(
         {"-cd", "--ctx-size-draft"}, "N",
         string_format("size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.speculative.n_ctx),
diff --git a/common/common.h b/common/common.h
index 674b2488557..73e85c30694 100644
--- a/common/common.h
+++ b/common/common.h
@@ -325,6 +325,8 @@ struct common_params_speculative {
     int32_t draft_block_size = 3;
     float p_split = 0.1f; // speculative decoding split probability
     float p_min = 0.75f; // minimum speculative decoding probability (greedy)
+    float p_accept = 0.0f; // min probability for main model to accept a non-argmax MTP draft token
+
     // ngram-based speculative decoding
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 526f036ff98..02ac93f55bc 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -605,7 +605,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
     return id;
 }
 
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first, float p_accept) {
     GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
 
     std::vector<llama_token> result;
@@ -614,14 +614,26 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
     size_t i = 0;
     for (; i < draft.size(); i++) {
         const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
-
-        common_sampler_accept(gsmpl, id, true);
-
-        result.push_back(id);
-
         if (draft[i] != id) {
+            if (p_accept > 0.0f) {
+                const float * logits = llama_get_logits_ith(ctx, idxs[i]);
+                const int n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(llama_get_model(ctx)));
+                float max_l = *std::max_element(logits, logits + n_vocab);
+                float sum = 0.0f;
+                for (int j = 0; j < n_vocab; j++) sum += expf(logits[j] - max_l);
+                const float p_main = expf(logits[draft[i]] - max_l) / sum;
+                if (p_main >= p_accept) {
+                    common_sampler_accept(gsmpl, draft[i], true);
+                    result.push_back(draft[i]);
+                    continue;
+                }
+            }
+            common_sampler_accept(gsmpl, id, true);
+            result.push_back(id);
             break;
         }
+        common_sampler_accept(gsmpl, id, true);
+        result.push_back(id);
     }
 
     if (i == draft.size()) {
@@ -635,13 +647,12 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
     return result;
 }
 
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first, float p_accept) {
     std::vector<int> idxs(draft.size() + 1);
     for (size_t i = 0; i < idxs.size(); ++i) {
         idxs[i] = i;
     }
-
-    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
+    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first, p_accept);
 }
 
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
diff --git a/common/sampling.h b/common/sampling.h
index 2426126b346..7f99136b6dc 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -82,10 +82,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
 //
 // returns at least 1 token, up to idxs.size()
 //
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false, float p_accept = 0.0f);
 
 // assume idxs == [ 0, 1, 2, ..., draft.size() ]
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false, float p_accept = 0.0f);
 
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index fea3ed87502..515a08628c8 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -2938,7 +2938,7 @@ struct server_context_impl {
             const size_t n_draft = slot.drafted.size();
 
             // the accepted tokens from the speculation
-            const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, slot.i_batch_dft, slot.drafted);
+            const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, slot.i_batch_dft, slot.drafted, false, params_base.speculative.p_accept);
 
             // For MTP speculation, h_prev for the next draft must come from the LAST ACCEPTED
             // batch output - not embeddings_ith(-1), which would point at a rejected draft's
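
The acceptance test added to common_sampler_sample_and_accept_n is a numerically stable softmax over the main model's logits, evaluated at the draft token. Below is a minimal standalone sketch of the same computation; the helper name accepts_draft and the toy logits are illustrative only, not part of the patch:

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Sketch of the new check: softmax with max-subtraction for numerical
    // stability, then compare the draft token's probability under the main
    // model against the --draft-p-accept threshold.
    static bool accepts_draft(const std::vector<float> & logits, int draft_id, float p_accept) {
        const float max_l = *std::max_element(logits.begin(), logits.end());
        float sum = 0.0f;
        for (float l : logits) {
            sum += std::exp(l - max_l);
        }
        const float p_main = std::exp(logits[draft_id] - max_l) / sum;
        return p_main >= p_accept;
    }

    int main() {
        // Toy 4-token vocabulary: token 2 is the argmax, token 1 is a close
        // runner-up (p ~= 0.34) that a strict greedy match would always reject.
        const std::vector<float> logits = { 0.0f, 1.5f, 2.0f, -1.0f };
        std::printf("p_accept = 0.30 -> %s\n", accepts_draft(logits, 1, 0.30f) ? "accept" : "reject");
        std::printf("p_accept = 0.50 -> %s\n", accepts_draft(logits, 1, 0.50f) ? "accept" : "reject");
        return 0;
    }

Because the whole block is guarded by p_accept > 0.0f, the default of 0.0 keeps the original exact-match behavior and avoids the O(n_vocab) softmax on the hot path.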
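
For an end-to-end check, the new threshold can be set either through the flag or through the environment variable registered in arg.cpp above; the model path below is a placeholder:

    llama-server -m model.gguf --draft-p-accept 0.30
    LLAMA_ARG_DRAFT_P_ACCEPT=0.30 llama-server -m model.gguf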