diff --git a/common/arg.cpp b/common/arg.cpp
index 0f01bb31454..61c1abc3bcc 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -3315,6 +3315,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.speculative.cache_type_k = kv_cache_type_from_str(value);
         }
     ).set_env("LLAMA_ARG_CACHE_TYPE_K_DRAFT"));
+    add_opt(common_arg(
+        {"-mtp", "--multi-token-prediction"},
+        string_format("enable multi-token prediction (MTP) if the model supports it (default: %s)", params.mtp ? "true" : "false"),
+        [](common_params & params) {
+            params.mtp = true;
+        }
+    ));
     add_opt(common_arg(
         {"-ctvd", "--cache-type-v-draft"}, "TYPE",
         string_format(
diff --git a/common/common.h b/common/common.h
index 5eab199af55..eb43003ed83 100644
--- a/common/common.h
+++ b/common/common.h
@@ -362,6 +362,7 @@ struct common_params {
     bool check_tensors   = false; // validate tensor data
     bool no_op_offload   = false; // globally disable offload host tensor operations to device
     bool no_extra_bufts  = false; // disable extra buffer types (used for weight repacking)
+    bool mtp             = false; // use multi-token prediction (MTP) if the model supports it

     bool single_turn     = false; // single turn chat conversation
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 452cefee3b9..8422023a540 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -348,11 +348,6 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
     llama_sampler_apply(chain, &cur_p);

-    /*for (int k = 0; k < (int)cur_p.size; ++k) {
-        LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f)\n",
-            k, 0, cur_p.data[k].id, cur_p.data[k].p);
-    }*/
-
     GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");

     const llama_token id = cur_p.data[cur_p.selected].id;
@@ -583,6 +578,41 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
     return samplers;
 }

-void common_sampler_apply_chain(struct common_sampler * gsmpl, struct llama_token_data_array * cur_p) {
-    llama_sampler_apply(gsmpl->chain, cur_p);
+/**
+ * Specialized sampling for speculative drafting.
+ *
+ * Prioritizes performance by using a direct greedy arg-max over the logits
+ * when no penalties (repetition, frequency, presence, DRY) are configured.
+ * Falls back to the full sampler chain when penalties are active, to avoid
+ * repetitive loops and to honor the configured constraints.
+ */
+llama_token common_sampler_sample_speculative(struct common_sampler * gsmpl, struct llama_context * ctx, int idx) {
+    const auto & params = gsmpl->params;
+
+    bool use_heavy_sampler =
+        (params.penalty_last_n > 0 && (
+            params.penalty_repeat  != 1.0f ||
+            params.penalty_freq    != 0.0f ||
+            params.penalty_present != 0.0f
+        )) ||
+        (params.dry_allowed_length > 0 && params.dry_multiplier != 0.0f);
+
+    if (use_heavy_sampler) {
+        return common_sampler_sample(gsmpl, ctx, idx, false);
+    }
+
+    float * logits = llama_get_logits_ith(ctx, idx);
+    const int n_vocab = llama_n_vocab(llama_model_get_vocab(llama_get_model(ctx)));
+
+    int   best_id = 0;
+    float max_val = logits[0];
+
+    for (int i = 1; i < n_vocab; ++i) {
+        if (logits[i] > max_val) {
+            max_val = logits[i];
+            best_id = i;
+        }
+    }
+
+    return best_id;
 }
\ No newline at end of file
diff --git a/common/sampling.h b/common/sampling.h
index b424d7d6d70..81a89727384 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -106,4 +106,4 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std:
 llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
                                        const char * grammar_kind, const char * grammar_data);

-void common_sampler_apply_chain(struct common_sampler * gsmpl, struct llama_token_data_array * cur_p);
\ No newline at end of file
+llama_token common_sampler_sample_speculative(struct common_sampler * gsmpl, struct llama_context * ctx, int idx);
\ No newline at end of file
diff --git a/common/speculative.cpp b/common/speculative.cpp
index a7a40426821..9d6b7ec1cbb 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -367,42 +367,31 @@ llama_token mtp_speculative_gen_draft(
     struct common_sampler* smpl,
     struct llama_context* ctx,
     llama_token id_last,
-    int32_t n_past,
-    int32_t last_tok_idx) {
+    int32_t n_past) {
+
+    if (!smpl) return -1;

-    if (!smpl) {
-        return -1;
-    }
     llama_batch mtp_batch = llama_batch_init(1, 0, 1);
-    const llama_pos draft_pos = n_past;
     const llama_seq_id draft_seq_id = 0;
+    common_batch_add(mtp_batch, id_last, n_past, {0}, true);
     mtp_batch.mtp_params.op_type = MTP_OP_DRAFT_GEN;

     // Perform the MTP draft generation decode. This writes the MTP layer's
     // KV state for the draft token into the cache.
-    llama_decode(ctx, mtp_batch);
+    if (llama_decode(ctx, mtp_batch) != 0) {
+        llama_batch_free(mtp_batch);
+        return -1;
+    }
     llama_batch_free(mtp_batch);

     // CRITICAL: Purge the metadata for the draft token we just wrote.
     // This makes the physical cell available again for the main model's validation pass,
     // preventing a cache state corruption where two cells map to the same logical position.
-    llama_kv_cache_seq_rm(ctx, draft_seq_id, draft_pos, draft_pos + 1);
-
-    const llama_model * model = llama_get_model(ctx);
-    const llama_vocab * vocab = llama_model_get_vocab(model);
-    const int n_vocab = llama_n_vocab(vocab);
-    llama_token_data_array * cur_p = common_sampler_get_candidates(smpl);
-    cur_p->size = n_vocab;
-    for (int i = 0; i < n_vocab; ++i) {
-        cur_p->data[i].id = i;
-        cur_p->data[i].logit = llama_get_logits_ith(ctx, 0)[i]; // For a single-token batch, logits are always at index 0.
-    }
-    cur_p->sorted = false;
-    common_sampler_apply_chain(smpl, cur_p);
-
-    return cur_p->data[0].id;
+    llama_kv_cache_seq_rm(ctx, draft_seq_id, n_past, n_past + 1);
+
+    return common_sampler_sample_speculative(smpl, ctx, 0);
 }
diff --git a/common/speculative.h b/common/speculative.h
index 8b81f4ac77d..4720c50cfde 100644
--- a/common/speculative.h
+++ b/common/speculative.h
@@ -39,8 +39,7 @@ llama_token mtp_speculative_gen_draft(
     struct common_sampler* smpl,
     struct llama_context* ctx,
     llama_token id_last,
-    int32_t n_past,
-    int32_t last_tok_idx);
+    int32_t n_past);

 // sample up to n_draft tokens and add them to the batch using the draft model
 llama_tokens common_speculative_gen_draft(
diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp
index c01960c55ea..3687058c82e 100644
--- a/src/llama-batch.cpp
+++ b/src/llama-batch.cpp
@@ -275,9 +275,7 @@ bool llama_batch_allocr::init(
                 }
             }

-            // TEMPORARILY DISABLING THIS SANITY CHECK
-            // TODO: UNDO THIS IF IT WORKS
-            /*if (!ok) {
+            if (!ok) {
                 LLAMA_LOG_ERROR(
                     "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
                     " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
@@ -286,7 +284,7 @@ bool llama_batch_allocr::init(
                     __func__, s, s, p0, s, seq_pos_min(s));

                 return false;
-            }*/
+            }
         }

         if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index fb35d6c79de..1a7f148a2c3 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -17,10 +17,24 @@
 //
 // llama_context
 //
+// Key for the graph cache. It contains all parameters that define the graph topology.
+struct llama_graph_cache_key {
+    uint32_t n_tokens;
+    uint32_t n_outputs;
+    llama_mtp_op_type op_type;
+    bool causal_attn;
+
+    bool operator<(const llama_graph_cache_key & other) const {
+        return std::tie(n_tokens, n_outputs, op_type, causal_attn) <
+               std::tie(other.n_tokens, other.n_outputs, other.op_type, other.causal_attn);
+    }
+};
+
 struct llama_context_kv_cache_data {
     llama_kv_cache_unified::slot_info_vec_t last_main_model_sinfos;
     llama_kv_cache_unified::slot_info_vec_t resized_sinfo_for_force;
     const llama_kv_cache_unified::slot_info_vec_t * forced_sinfos = nullptr;
+    std::map<llama_graph_cache_key, std::unique_ptr<llm_graph_result>> graph_cache;
 };

 llama_context::llama_context(
@@ -532,18 +546,6 @@ float * llama_context::get_logits() {
     return logits;
 }

-void llama_context::set_logits_ith(struct ggml_tensor * logit_override, ggml_backend_sched_t sched_override, int32_t i) {
-    output_reorder();
-
-    ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched_override, logit_override);
-    GGML_ASSERT(backend_res != nullptr);
-    GGML_ASSERT(logits != nullptr);
-
-    int64_t j = output_ids[i];
-
-    ggml_backend_tensor_get_async(backend_res, logit_override, logits + j*model.vocab.n_tokens(), 0, model.vocab.n_tokens() * sizeof(float));
-}
-
 float * llama_context::get_logits_ith(int32_t i) {
     int64_t j = -1;

@@ -745,40 +747,78 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
         return nullptr;
     }

-    auto * res = gf_res_prev.get();
-    auto * gf = res->get_gf();
-
-    // the new graph parameters
-    // in order to correctly reuse a graph, it's full topology has to be uniquely determined by these parameters
-    const auto gparams = graph_params(res, ubatch, mctx, gtype, mtp_params);
+    auto * kvd = static_cast<llama_context_kv_cache_data *>(kv_cache_data);
+    llm_graph_result * res;

-    if (!graph_reuse_disable && res->can_reuse(gparams)) {
-        //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__);
+    if (mtp_params.op_type != MTP_OP_NONE) {
+        int32_t n_outputs = 0;
+        for (int i = 0; i < ubatch.n_tokens; ++i) { if (ubatch.output[i]) n_outputs++; }
+        const llama_graph_cache_key key = { ubatch.n_tokens, (uint32_t)n_outputs, mtp_params.op_type, cparams.causal_attn };

-        n_reused++;
-    } else {
-        res->reset();
+        auto & res_ptr = kvd->graph_cache[key];
+        if (!res_ptr) {
+            LLAMA_LOG_DEBUG("%s: creating a new graph container for key (op=%d, tok=%d, out=%d)\n",
+                __func__, (int)key.op_type, key.n_tokens, key.n_outputs);
+            res_ptr = std::make_unique<llm_graph_result>(graph_max_nodes());
+        }
+        res = res_ptr.get();
+        // the new graph parameters
+        // in order to correctly reuse a graph, its full topology has to be uniquely determined by these parameters
+        const auto gparams = graph_params(res, ubatch, mctx, gtype, mtp_params);
+
         ggml_backend_sched_reset(sched.get());
         ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

-        //const auto t_start_us = ggml_time_us();
-
-        gf = model.build_graph(gparams);
-
-        //LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
+        res->reset();
+        res->set_params(gparams);
+        res->gf = model.build_graph(gparams);

-        if (!gf) {
+        if (!res->gf) {
             LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
             ret = GGML_STATUS_FAILED;
             return nullptr;
         }

-        if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
+        if (!ggml_backend_sched_alloc_graph(sched.get(), res->gf)) {
             LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
             ret = GGML_STATUS_ALLOC_FAILED;
             return nullptr;
         }
+
+    } else {
+        res = gf_res_prev.get();
+        const auto gparams = graph_params(res, ubatch, mctx, gtype, mtp_params);
+
+        if (!graph_reuse_disable && res->can_reuse(gparams)) {
+            LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__);
+            n_reused++;
+        } else {
+            LLAMA_LOG_DEBUG("%s: rebuilding graph\n", __func__);
+
+            ggml_backend_sched_reset(sched.get());
+            ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
+
+            res->reset();
+            res->set_params(gparams);
+            //const auto t_start_us = ggml_time_us();
+
+            res->gf = model.build_graph(gparams);
+
+            //LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
+
+            if (!res->gf) {
+                LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
+                ret = GGML_STATUS_FAILED;
+                return nullptr;
+            }
+
+            if (!ggml_backend_sched_alloc_graph(sched.get(), res->gf)) {
+                LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
+                ret = GGML_STATUS_ALLOC_FAILED;
+                return nullptr;
+            }
+        }
     }

     if (mtp_params.op_type != MTP_OP_NONE) { // If it is any MTP operation
@@ -805,9 +845,6 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
     }

     ret = GGML_STATUS_SUCCESS;
-    if (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) {
-        ggml_tensor * sum_tensor = ggml_get_tensor(res->get_ctx(), "mtp_input_sum");
-    }

     return res;
 }
diff --git a/src/llama-context.h b/src/llama-context.h
index 4d77d5d81ae..7297c4c5d16 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -216,8 +216,6 @@ struct llama_context {
     // reserve a graph with a dummy ubatch of the specified size
     ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);

-    void set_logits_ith(struct ggml_tensor * logit_override, ggml_backend_sched_t sched_override, int32_t i);
-
    ggml_backend_sched_t create_temp_scheduler(size_t n_nodes);

    std::unique_ptr mtp_memory_batch(const llama_batch& batch_inp);
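
Note on the graph cache added to llama-context.cpp above: MTP decodes alternate between a small number of distinct graph shapes, so the patch keys a cached llm_graph_result on (n_tokens, n_outputs, op_type, causal_attn) instead of relying only on the single-slot reuse path. The standalone sketch below illustrates just the std::tie-based key ordering and the get-or-create lookup; GraphKey, CachedGraph, and get_graph_for are illustrative stand-ins, not llama.cpp types.

// Standalone illustration of the keyed get-or-create pattern used by the MTP
// graph cache. CachedGraph stands in for llm_graph_result; GraphKey mirrors
// llama_graph_cache_key.
#include <cstdint>
#include <cstdio>
#include <map>
#include <memory>
#include <tuple>

struct GraphKey {
    uint32_t n_tokens;
    uint32_t n_outputs;
    int      op_type;     // placeholder for llama_mtp_op_type
    bool     causal_attn;

    // std::tie gives a lexicographic ordering, so the key can be used in std::map
    bool operator<(const GraphKey & other) const {
        return std::tie(n_tokens, n_outputs, op_type, causal_attn) <
               std::tie(other.n_tokens, other.n_outputs, other.op_type, other.causal_attn);
    }
};

struct CachedGraph {
    int n_uses = 0; // stand-in for the allocated graph + result buffers
};

static std::map<GraphKey, std::unique_ptr<CachedGraph>> graph_cache;

// Get-or-create: operator[] default-constructs an empty unique_ptr for a new
// key, so a null entry means "first time we see this graph shape".
static CachedGraph & get_graph_for(const GraphKey & key) {
    auto & slot = graph_cache[key];
    if (!slot) {
        slot = std::make_unique<CachedGraph>();
    }
    slot->n_uses++;
    return *slot;
}

int main() {
    const GraphKey draft_key  = { 1, 1, /*op_type=*/1, /*causal_attn=*/true };
    const GraphKey update_key = { 8, 8, /*op_type=*/2, /*causal_attn=*/true };

    get_graph_for(draft_key);
    get_graph_for(update_key);
    get_graph_for(draft_key); // hits the cached entry, no rebuild

    printf("cache entries: %zu, draft uses: %d\n",
           graph_cache.size(), graph_cache.at(draft_key)->n_uses);
    return 0;
}
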
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index a24532c6939..e26e1d6fecc 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -1404,14 +1404,10 @@ struct server_slot {
     // if the context does not have a memory module then all embeddings have to be computed within a single ubatch
     // also we cannot split if the pooling would require any past tokens
     bool can_split() const {
-        //fprintf(stderr, "need_embd() %d\n", need_embd());
-        //fprintf(stderr, "llama_get_memory(ctx) %d\n", llama_get_memory(ctx) != nullptr);
-        //fprintf(stderr, "POOLING_TYPE check %d\n", llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
-
         return !need_embd() || (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST) ||
-               (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_NONE); // this seems to save embeddings for whole batch?
+               (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_NONE);
     }

     bool can_batch_with(server_slot & other_slot) const {
@@ -1440,7 +1436,6 @@ struct server_slot {
     bool can_speculate() const {
         return (ctx_dft || has_mtp) && params.speculative.n_max > 0 && params.cache_prompt;
-        // return (ctx_dft) && params.speculative.n_max > 0 && params.cache_prompt;
     }

     void add_token(const completion_token_output & token) {
@@ -2133,7 +2128,7 @@ struct server_context {
                 }
                 // if model has MTP and no draft model is specified...
-                else if (llama_model_n_nextn_layer(model) > 0) {
+                else if (llama_model_n_nextn_layer(model) > 0 && params_base.mtp) {
                     SRV_INF("model has nextn layers = %d\n", llama_model_n_nextn_layer(model));
                     slot.has_mtp = true;
@@ -3637,7 +3632,7 @@ struct server_context {
                     llama_tokens draft;

                     if (slot.has_mtp) {
-                        llama_token draft_id = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past, slot.last_tok_idx);
+                        llama_token draft_id = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past);
                        draft.reserve(1);
                        draft.push_back(draft_id);
                    }
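
For reference, the server-side drafting path above ends in common_sampler_sample_speculative (common/sampling.cpp): it runs the full sampler chain only when penalties could change the result and otherwise takes the arg-max of the raw logits. A dependency-free sketch of that decision logic follows; SamplingParams is a stand-in that mirrors the relevant fields of common_params_sampling, not the real struct.

// Sketch of the drafting sampler's decision: use the full chain only when
// repetition/frequency/presence or DRY penalties are active, else go greedy.
#include <cstdio>
#include <vector>

struct SamplingParams {
    int   penalty_last_n     = 0;
    float penalty_repeat     = 1.0f;
    float penalty_freq       = 0.0f;
    float penalty_present    = 0.0f;
    int   dry_allowed_length = 0;
    float dry_multiplier     = 0.0f;
};

// True when penalties could make the cheap greedy path pick a different token.
static bool needs_full_chain(const SamplingParams & p) {
    const bool rep_penalties = p.penalty_last_n > 0 && (
        p.penalty_repeat  != 1.0f ||
        p.penalty_freq    != 0.0f ||
        p.penalty_present != 0.0f);
    const bool dry_penalty = p.dry_allowed_length > 0 && p.dry_multiplier != 0.0f;
    return rep_penalties || dry_penalty;
}

// Greedy fallback: index of the largest logit.
static int argmax(const std::vector<float> & logits) {
    int best_id = 0;
    for (int i = 1; i < (int) logits.size(); ++i) {
        if (logits[i] > logits[best_id]) {
            best_id = i;
        }
    }
    return best_id;
}

int main() {
    const SamplingParams defaults{};   // no penalties -> greedy path
    SamplingParams with_rep = defaults;
    with_rep.penalty_last_n = 64;
    with_rep.penalty_repeat = 1.1f;    // repeat penalty -> full chain

    const std::vector<float> logits = { -1.0f, 2.5f, 0.3f, 2.4f };

    printf("defaults use full chain: %s\n", needs_full_chain(defaults) ? "yes" : "no");
    printf("with repeat penalty:     %s\n", needs_full_chain(with_rep) ? "yes" : "no");
    printf("greedy draft token id:   %d\n", argmax(logits));
    return 0;
}
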