10 changes: 10 additions & 0 deletions common/arg.cpp
@@ -3448,13 +3448,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.speculative.p_split = std::stof(value);
         }
     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}).set_env("LLAMA_ARG_DRAFT_P_SPLIT"));
+
     add_opt(common_arg(
         {"--draft-p-min"}, "P",
         string_format("minimum speculative decoding probability (greedy) (default: %.2f)", (double)params.speculative.p_min),
         [](common_params & params, const std::string & value) {
             params.speculative.p_min = std::stof(value);
         }
     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_DRAFT_P_MIN"));
+
+    add_opt(common_arg(
+        {"--draft-p-accept"}, "P",
+        string_format("MTP draft acceptance probability threshold - accept non-argmax draft token if main model assigns it at least this probability (default: %.2f, 0.0 = greedy match only)", (double)params.speculative.p_accept),
+        [](common_params & params, const std::string & value) {
+            params.speculative.p_accept = std::stof(value);
+        }
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_DRAFT_P_ACCEPT"));
+
     add_opt(common_arg(
         {"-cd", "--ctx-size-draft"}, "N",
         string_format("size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.speculative.n_ctx),
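Note (not part of the diff): with the new option above, a non-argmax draft token survives verification whenever the main model's own probability for it clears the threshold. For example, with --draft-p-accept 0.3, a drafted token that the main model ranks second at probability 0.35 is still accepted; at the default of 0.0 only exact greedy matches are kept, which preserves the previous behavior. The option can also be set via the LLAMA_ARG_DRAFT_P_ACCEPT environment variable, per the set_env call.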
2 changes: 2 additions & 0 deletions common/common.h
@@ -325,6 +325,8 @@ struct common_params_speculative {
     int32_t draft_block_size = 3;
     float p_split = 0.1f; // speculative decoding split probability
     float p_min = 0.75f; // minimum speculative decoding probability (greedy)
+    float p_accept = 0.0f; // min probability for main model to accept a non-argmax MTP draft token
+
 
     // ngram-based speculative decoding
 
29 changes: 20 additions & 9 deletions common/sampling.cpp
@@ -605,7 +605,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
     return id;
 }
 
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first, float p_accept) {
     GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
 
     std::vector<llama_token> result;
@@ -614,14 +614,26 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
     size_t i = 0;
     for (; i < draft.size(); i++) {
         const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
-
-        common_sampler_accept(gsmpl, id, true);
-
-        result.push_back(id);
-
         if (draft[i] != id) {
+            if (p_accept > 0.0f) {
+                const float * logits = llama_get_logits_ith(ctx, idxs[i]);
+                const int n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(llama_get_model(ctx)));
+                float max_l = *std::max_element(logits, logits + n_vocab);
+                float sum = 0.0f;
+                for (int j = 0; j < n_vocab; j++) sum += expf(logits[j] - max_l);
+                const float p_main = expf(logits[draft[i]] - max_l) / sum;
+                if (p_main >= p_accept) {
+                    common_sampler_accept(gsmpl, draft[i], true);
+                    result.push_back(draft[i]);
+                    continue;
+                }
+            }
+            common_sampler_accept(gsmpl, id, true);
+            result.push_back(id);
             break;
         }
+        common_sampler_accept(gsmpl, id, true);
+        result.push_back(id);
     }
 
     if (i == draft.size()) {
@@ -635,13 +635,12 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
     return result;
 }
 
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first, float p_accept) {
     std::vector<int> idxs(draft.size() + 1);
     for (size_t i = 0; i < idxs.size(); ++i) {
         idxs[i] = i;
     }
-
-    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
+    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first, p_accept);
 }
 
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
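Aside (reviewer sketch, not part of the diff): the acceptance test above computes the main model's probability for the draft token with a numerically stable softmax (shift by the max logit, exponentiate, normalize) and compares it against p_accept. Below is the same computation in isolation, using a plain std::vector in place of the context's logit buffer; the helper names are illustrative, not llama.cpp API.

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Probability that the softmax over `logits` assigns to `token`.
    // Subtracting the max logit before exponentiating avoids overflow.
    static float softmax_prob(const std::vector<float> & logits, int token) {
        const float max_l = *std::max_element(logits.begin(), logits.end());
        float sum = 0.0f;
        for (float l : logits) {
            sum += std::exp(l - max_l);
        }
        return std::exp(logits[token] - max_l) / sum;
    }

    // Mirrors the diff's rule: a mismatching draft token is kept when the
    // main model still assigns it at least p_accept probability.
    static bool accept_mismatch(const std::vector<float> & logits, int draft_token, float p_accept) {
        return p_accept > 0.0f && softmax_prob(logits, draft_token) >= p_accept;
    }

One design note: the full-vocabulary normalization runs only on the mismatch path and only when p_accept > 0.0f, so the default configuration pays no extra cost per verified token.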
4 changes: 2 additions & 2 deletions common/sampling.h
@@ -82,10 +82,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
 //
 // returns at least 1 token, up to idxs.size()
 //
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false, float p_accept = 0.0f);
 
 // assume idxs == [ 0, 1, 2, ..., draft.size() ]
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false, float p_accept = 0.0f);
 
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
 
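Usage note: because the new parameter defaults to 0.0f, existing call sites compile and behave exactly as before; only callers that opt in pass a threshold. A hypothetical call site, assuming a common_sampler * gsmpl, a llama_context * ctx, and a llama_tokens draft already set up as elsewhere in common/:

    // Previous behavior: accept draft tokens only on exact greedy match.
    std::vector<llama_token> ids_strict = common_sampler_sample_and_accept_n(gsmpl, ctx, draft);

    // Relaxed: also accept a mismatching draft token when the main model
    // assigns it at least 30% probability.
    std::vector<llama_token> ids_relaxed = common_sampler_sample_and_accept_n(gsmpl, ctx, draft, /*grammar_first=*/false, /*p_accept=*/0.3f);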
2 changes: 1 addition & 1 deletion tools/server/server-context.cpp
@@ -2938,7 +2938,7 @@ struct server_context_impl {
     const size_t n_draft = slot.drafted.size();
 
     // the accepted tokens from the speculation
-    const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, slot.i_batch_dft, slot.drafted);
+    const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, slot.i_batch_dft, slot.drafted, false, params_base.speculative.p_accept);
 
     // For MTP speculation, h_prev for the next draft must come from the LAST ACCEPTED
     // batch output - not embeddings_ith(-1), which would point at a rejected draft's