7 changes: 7 additions & 0 deletions common/arg.cpp
@@ -3315,6 +3315,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.speculative.cache_type_k = kv_cache_type_from_str(value);
}
).set_env("LLAMA_ARG_CACHE_TYPE_K_DRAFT"));
add_opt(common_arg(
{"-mtp", "--multi-token-prediction"},
string_format("Activate multi-token-prediction (if supported) (default: %s)", params.mtp ? "true" : "false"),
[](common_params & params) {
params.mtp = true;
}
));
add_opt(common_arg(
{"-ctvd", "--cache-type-v-draft"}, "TYPE",
string_format(
1 change: 1 addition & 0 deletions common/common.h
@@ -362,6 +362,7 @@ struct common_params {
bool check_tensors = false; // validate tensor data
bool no_op_offload = false; // globally disable offload host tensor operations to device
bool no_extra_bufts = false; // disable extra buffer types (used for weight repacking)
bool mtp = false; // use multi-token prediction (MTP) if supported by the model

bool single_turn = false; // single turn chat conversation

44 changes: 37 additions & 7 deletions common/sampling.cpp
@@ -348,11 +348,6 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co

llama_sampler_apply(chain, &cur_p);

/*for (int k = 0; k < (int)cur_p.size; ++k) {
LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f)\n",
k, 0, cur_p.data[k].id, cur_p.data[k].p);
}*/

GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");

const llama_token id = cur_p.data[cur_p.selected].id;
@@ -583,6 +578,41 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
return samplers;
}

void common_sampler_apply_chain(struct common_sampler * gsmpl, struct llama_token_data_array * cur_p) {
llama_sampler_apply(gsmpl->chain, cur_p);
/**
* Specialized sampling for speculative drafting.
*
* Prioritizes performance by using a direct ArgMax loop (Greedy) when no
* penalties (repetition, frequency, presence, DRY) are configured.
* Falls back to the full sampler chain if penalties are active to prevent
* generative loops or adhere to constraints.
*/
llama_token common_sampler_sample_speculative(struct common_sampler * gsmpl, struct llama_context * ctx, int idx) {
const auto & params = gsmpl->params;

bool use_heavy_sampler =
(params.penalty_last_n > 0 && (
params.penalty_repeat != 1.0f ||
params.penalty_freq != 0.0f ||
params.penalty_present != 0.0f
)) ||
(params.dry_allowed_length > 0 && params.dry_multiplier != 0.0f);

if (use_heavy_sampler) {
return common_sampler_sample(gsmpl, ctx, idx, false);
}

float * logits = llama_get_logits_ith(ctx, idx);
const int n_vocab = llama_n_vocab(llama_model_get_vocab(llama_get_model(ctx)));

int best_id = 0;
float max_val = logits[0];

for (int i = 1; i < n_vocab; ++i) {
if (logits[i] > max_val) {
max_val = logits[i];
best_id = i;
}
}

return best_id;
}
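A hedged usage sketch for the fast path above; the wrapper draft_one and its surroundings are illustrative and not part of this change, and it assumes a common_sampler plus a llama_context whose most recent decode left logits for the token of interest at index idx:

// Illustrative only: draft one token as cheaply as possible.
// common_sampler_sample_speculative() falls back to the full sampler chain on
// its own when repetition/frequency/presence/DRY penalties are configured.
static llama_token draft_one(common_sampler * smpl, llama_context * ctx, int idx) {
    const llama_token id = common_sampler_sample_speculative(smpl, ctx, idx);
    // the fast path bypasses the sampler chain, so stateful samplers and the
    // grammar only learn about the token once it survives validation, e.g.
    //     common_sampler_accept(smpl, id, /*accept_grammar=*/true);
    return id;
}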
2 changes: 1 addition & 1 deletion common/sampling.h
@@ -106,4 +106,4 @@ std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std:
llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
const char * grammar_kind, const char * grammar_data);

void common_sampler_apply_chain(struct common_sampler * gsmpl, struct llama_token_data_array * cur_p);
llama_token common_sampler_sample_speculative(struct common_sampler * gsmpl, struct llama_context * ctx, int idx);
33 changes: 11 additions & 22 deletions common/speculative.cpp
@@ -367,42 +367,31 @@ llama_token mtp_speculative_gen_draft(
struct common_sampler* smpl,
struct llama_context* ctx,
llama_token id_last,
int32_t n_past,
int32_t last_tok_idx) {
int32_t n_past) {

if (!smpl) return -1;

if (!smpl) {
return -1;
}
llama_batch mtp_batch = llama_batch_init(1, 0, 1);
const llama_pos draft_pos = n_past;
const llama_seq_id draft_seq_id = 0;

common_batch_add(mtp_batch, id_last, n_past, {0}, true);

mtp_batch.mtp_params.op_type = MTP_OP_DRAFT_GEN;

// Perform the MTP draft generation decode. This writes the MTP layer's
// KV state for the draft token into the cache.
llama_decode(ctx, mtp_batch);
if (llama_decode(ctx, mtp_batch) != 0) {
llama_batch_free(mtp_batch);
return -1;
}
llama_batch_free(mtp_batch);

// CRITICAL: Purge the metadata for the draft token we just wrote.
// This makes the physical cell available again for the main model's validation pass,
// preventing a cache state corruption where two cells map to the same logical position.
llama_kv_cache_seq_rm(ctx, draft_seq_id, draft_pos, draft_pos + 1);

const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
const int n_vocab = llama_n_vocab(vocab);
llama_token_data_array * cur_p = common_sampler_get_candidates(smpl);
cur_p->size = n_vocab;
for (int i = 0; i < n_vocab; ++i) {
cur_p->data[i].id = i;
cur_p->data[i].logit = llama_get_logits_ith(ctx, 0)[i]; // For a single-token batch, logits are always at index 0.
}
cur_p->sorted = false;
common_sampler_apply_chain(smpl, cur_p);

return cur_p->data[0].id;
llama_kv_cache_seq_rm(ctx, draft_seq_id, n_past, n_past + 1);

return common_sampler_sample_speculative(smpl, ctx, 0);
}
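For context, a hedged sketch of how a caller can pair this with the main model's validation decode; identifiers such as smpl, id_last, n_past and validation_batch are illustrative, and the real call site is in tools/server/server.cpp further down:

// Illustrative only, error handling omitted.
const llama_token draft_id = mtp_speculative_gen_draft(smpl, ctx, id_last, n_past);

// validate id_last and the drafted token with the main model in one decode;
// mtp_speculative_gen_draft() already purged its KV metadata at position
// n_past, so this batch can claim that logical position without colliding.
llama_batch validation_batch = llama_batch_init(2, 0, 1);
common_batch_add(validation_batch, id_last,  n_past,     {0}, true);
common_batch_add(validation_batch, draft_id, n_past + 1, {0}, true);
llama_decode(ctx, validation_batch);
llama_batch_free(validation_batch);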


3 changes: 1 addition & 2 deletions common/speculative.h
@@ -39,8 +39,7 @@ llama_token mtp_speculative_gen_draft(
struct common_sampler* smpl,
struct llama_context* ctx,
llama_token id_last,
int32_t n_past,
int32_t last_tok_idx);
int32_t n_past);

// sample up to n_draft tokens and add them to the batch using the draft model
llama_tokens common_speculative_gen_draft(
6 changes: 2 additions & 4 deletions src/llama-batch.cpp
@@ -275,9 +275,7 @@ bool llama_batch_allocr::init(
}
}

// TEMPORARILY DISABLING THIS SANITY CHECK
// TODO: UNDO THIS IF IT WORKS
/*if (!ok) {
if (!ok) {
LLAMA_LOG_ERROR(
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
@@ -286,7 +284,7 @@ bool llama_batch_allocr::init(
__func__, s, s, p0, s, seq_pos_min(s));

return false;
}*/
}
}

if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
103 changes: 70 additions & 33 deletions src/llama-context.cpp
@@ -17,10 +17,24 @@
//
// llama_context
//
// Key for the graph cache. It contains all parameters that define the graph topology.
struct llama_graph_cache_key {
uint32_t n_tokens;
uint32_t n_outputs;
llama_mtp_op_type op_type;
bool causal_attn;

bool operator<(const llama_graph_cache_key& other) const {
return std::tie(n_tokens, n_outputs, op_type, causal_attn) <
std::tie(other.n_tokens, other.n_outputs, other.op_type, other.causal_attn);
}
};
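A brief illustration of the ordering (not part of the patch): std::tie compares the members lexicographically, which is exactly the strict weak ordering std::map requires, so two keys select the same cached graph only when every field matches.

// Illustrative only: keys that differ in any field land in different cache slots.
llama_graph_cache_key a{ 1, 1, MTP_OP_DRAFT_GEN,       true };
llama_graph_cache_key b{ 1, 1, MTP_OP_UPDATE_ACCEPTED, true };
// a < b or b < a holds here, so graph_cache[a] and graph_cache[b] refer to two
// distinct llm_graph_result containers instead of aliasing a single graph.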

struct llama_context_kv_cache_data {
llama_kv_cache_unified::slot_info_vec_t last_main_model_sinfos;
llama_kv_cache_unified::slot_info_vec_t resized_sinfo_for_force;
const llama_kv_cache_unified::slot_info_vec_t * forced_sinfos = nullptr;
std::map<llama_graph_cache_key, llm_graph_result_ptr> graph_cache;
};

llama_context::llama_context(
@@ -532,18 +546,6 @@ float * llama_context::get_logits() {
return logits;
}

void llama_context::set_logits_ith(struct ggml_tensor * logit_override, ggml_backend_sched_t sched_override, int32_t i) {
output_reorder();

ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched_override, logit_override);
GGML_ASSERT(backend_res != nullptr);
GGML_ASSERT(logits != nullptr);

int64_t j = output_ids[i];

ggml_backend_tensor_get_async(backend_res, logit_override, logits + j*model.vocab.n_tokens(), 0, model.vocab.n_tokens() * sizeof(float));
}

float * llama_context::get_logits_ith(int32_t i) {
int64_t j = -1;

@@ -745,40 +747,78 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
return nullptr;
}

auto * res = gf_res_prev.get();
auto * gf = res->get_gf();

// the new graph parameters
// in order to correctly reuse a graph, it's full topology has to be uniquely determined by these parameters
const auto gparams = graph_params(res, ubatch, mctx, gtype, mtp_params);
auto * kvd = static_cast<llama_context_kv_cache_data *>(kv_cache_data);
llm_graph_result * res;

if (!graph_reuse_disable && res->can_reuse(gparams)) {
//LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__);
if (mtp_params.op_type != MTP_OP_NONE) {
int32_t n_outputs = 0;
for (int i = 0; i < ubatch.n_tokens; ++i) { if (ubatch.output[i]) n_outputs++; }
const llama_graph_cache_key key = { ubatch.n_tokens, (uint32_t)n_outputs, mtp_params.op_type, cparams.causal_attn };

n_reused++;
} else {
res->reset();
auto & res_ptr = kvd->graph_cache[key];
if (!res_ptr) {
LLAMA_LOG_DEBUG("%s: Creating a new graph container for key (op=%d, tok=%d, out=%d)\n",
__func__, (int)key.op_type, key.n_tokens, key.n_outputs);
res_ptr = std::make_unique<llm_graph_result>(graph_max_nodes());
}
res = res_ptr.get();

// the new graph parameters
// in order to correctly reuse a graph, its full topology has to be uniquely determined by these parameters
const auto gparams = graph_params(res, ubatch, mctx, gtype, mtp_params);

ggml_backend_sched_reset(sched.get());
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

//const auto t_start_us = ggml_time_us();

gf = model.build_graph(gparams);

//LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
res->reset();
res->set_params(gparams);
res->gf = model.build_graph(gparams);

if (!gf) {
if (!res->gf) {
LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
ret = GGML_STATUS_FAILED;
return nullptr;
}

if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
if (!ggml_backend_sched_alloc_graph(sched.get(), res->gf)) {
LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
ret = GGML_STATUS_ALLOC_FAILED;
return nullptr;
}

} else {
res = gf_res_prev.get();
const auto gparams = graph_params(res, ubatch, mctx, gtype, mtp_params);

if (!graph_reuse_disable && res->can_reuse(gparams)) {
LLAMA_LOG_DEBUG("%s: Reusing previous graph\n", __func__);
n_reused++;
} else {
LLAMA_LOG_DEBUG("%s: Reconstructed graph...\n", __func__);

ggml_backend_sched_reset(sched.get());
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

res->reset();
res->set_params(gparams);
//const auto t_start_us = ggml_time_us();

res->gf = model.build_graph(gparams);

//LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);

if (!res->gf) {
LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
ret = GGML_STATUS_FAILED;
return nullptr;
}

if (!ggml_backend_sched_alloc_graph(sched.get(), res->gf)) {
LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
ret = GGML_STATUS_ALLOC_FAILED;
return nullptr;
}
}
}

if (mtp_params.op_type != MTP_OP_NONE) { // If it is any MTP operation
Expand All @@ -805,9 +845,6 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
}

ret = GGML_STATUS_SUCCESS;
if (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) {
ggml_tensor * sum_tensor = ggml_get_tensor(res->get_ctx(), "mtp_input_sum");
}
return res;
}

2 changes: 0 additions & 2 deletions src/llama-context.h
@@ -216,8 +216,6 @@ struct llama_context {
// reserve a graph with a dummy ubatch of the specified size
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);

void set_logits_ith(struct ggml_tensor * logit_override, ggml_backend_sched_t sched_override, int32_t i);

ggml_backend_sched_t create_temp_scheduler(size_t n_nodes);

std::unique_ptr<llama_memory_context_i> mtp_memory_batch(const llama_batch& batch_inp);
11 changes: 3 additions & 8 deletions tools/server/server.cpp
@@ -1404,14 +1404,10 @@ struct server_slot {
// if the context does not have a memory module then all embeddings have to be computed within a single ubatch
// also we cannot split if the pooling would require any past tokens
bool can_split() const {
//fprintf(stderr, "need_embd() %d\n", need_embd());
//fprintf(stderr, "llama_get_memory(ctx) %d\n", llama_get_memory(ctx) != nullptr);
//fprintf(stderr, "POOLING_TYPE check %d\n", llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);

return
!need_embd() ||
(llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST) ||
(llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_NONE); // this seems to save embeddings for whole batch?
(llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_NONE);
}

bool can_batch_with(server_slot & other_slot) const {
@@ -1440,7 +1436,6 @@ struct server_slot {

bool can_speculate() const {
return (ctx_dft || has_mtp) && params.speculative.n_max > 0 && params.cache_prompt;
// return (ctx_dft) && params.speculative.n_max > 0 && params.cache_prompt;
}

void add_token(const completion_token_output & token) {
@@ -2133,7 +2128,7 @@ struct server_context {
}

// if model has MTP and no draft model is specified...
else if (llama_model_n_nextn_layer(model) > 0) {
else if (llama_model_n_nextn_layer(model) > 0 && params_base.mtp) {
SRV_INF("model has nextn layers = %d\n", llama_model_n_nextn_layer(model));
slot.has_mtp = true;

@@ -3637,7 +3632,7 @@ struct server_context {

llama_tokens draft;
if (slot.has_mtp) {
llama_token draft_id = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past, slot.last_tok_idx);
llama_token draft_id = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past);
draft.reserve(1);
draft.push_back(draft_id);
}