From 664520cf0185fc63a41f8d6b1177dcf57018bee6 Mon Sep 17 00:00:00 2001 From: Arthur Konovalov Date: Tue, 19 May 2026 16:16:22 -0600 Subject: [PATCH 1/8] chore: update to ac7680 --- llama-cpp-sys-2/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama-cpp-sys-2/llama.cpp b/llama-cpp-sys-2/llama.cpp index e21cdc11..ac76808e 160000 --- a/llama-cpp-sys-2/llama.cpp +++ b/llama-cpp-sys-2/llama.cpp @@ -1 +1 @@ -Subproject commit e21cdc11a0461d8b0cbd28cc356d993bf6be7282 +Subproject commit ac76808e4db7bbb4082b86e7fbd615934b44ac6e From 43e8ab9d1fdc3ed475f9c7be1ff574276527584b Mon Sep 17 00:00:00 2001 From: Arthur Konovalov Date: Tue, 19 May 2026 16:18:42 -0600 Subject: [PATCH 2/8] fix: adapt common wrappers to llama.cpp API --- llama-cpp-2/src/context.rs | 2 +- llama-cpp-2/src/model/params.rs | 6 ++--- llama-cpp-sys-2/wrapper_common.cpp | 35 ++++++++++++++++++++++++++++++ llama-cpp-sys-2/wrapper_common.h | 18 +++++++++++++++ llama-cpp-sys-2/wrapper_oai.cpp | 32 ++++++++++++++++++++++++++- 5 files changed, 88 insertions(+), 5 deletions(-) diff --git a/llama-cpp-2/src/context.rs b/llama-cpp-2/src/context.rs index 27a65a1b..ea8e73d9 100644 --- a/llama-cpp-2/src/context.rs +++ b/llama-cpp-2/src/context.rs @@ -365,7 +365,7 @@ impl<'model> LlamaContext<'model> { /// Print a breakdown of per-device memory use to the default logger. pub fn print_memory_breakdown(&self) { - unsafe { llama_cpp_sys_2::llama_memory_breakdown_print(self.context.as_ptr()) } + unsafe { llama_cpp_sys_2::llama_rs_memory_breakdown_print(self.context.as_ptr()) } } } diff --git a/llama-cpp-2/src/model/params.rs b/llama-cpp-2/src/model/params.rs index 9f4240a5..277c7bfd 100644 --- a/llama-cpp-2/src/model/params.rs +++ b/llama-cpp-2/src/model/params.rs @@ -353,7 +353,7 @@ impl LlamaModelParams { self.params.tensor_buft_overrides = null(); let status = unsafe { - llama_cpp_sys_2::llama_params_fit( + llama_cpp_sys_2::llama_rs_params_fit( model_path.as_ptr(), &raw mut self.params, &raw mut cparams.context_params, @@ -366,8 +366,8 @@ impl LlamaModelParams { }; match status { - llama_cpp_sys_2::LLAMA_PARAMS_FIT_STATUS_SUCCESS => {} - llama_cpp_sys_2::LLAMA_PARAMS_FIT_STATUS_FAILURE => return Err(FitError::Failure), + llama_cpp_sys_2::LLAMA_RS_PARAMS_FIT_STATUS_SUCCESS => {} + llama_cpp_sys_2::LLAMA_RS_PARAMS_FIT_STATUS_FAILURE => return Err(FitError::Failure), _ => return Err(FitError::Error), } diff --git a/llama-cpp-sys-2/wrapper_common.cpp b/llama-cpp-sys-2/wrapper_common.cpp index 09bbeede..848b1e19 100644 --- a/llama-cpp-sys-2/wrapper_common.cpp +++ b/llama-cpp-sys-2/wrapper_common.cpp @@ -6,6 +6,7 @@ #include #include +#include "llama.cpp/common/fit.h" #include "llama.cpp/common/json-schema-to-grammar.h" #include "llama.cpp/include/llama.h" #include "wrapper_utils.h" @@ -85,6 +86,40 @@ extern "C" void llama_rs_string_free(char * ptr) { } } +extern "C" enum llama_rs_params_fit_status llama_rs_params_fit( + const char * path_model, + struct llama_model_params * mparams, + struct llama_context_params * cparams, + float * tensor_split, + struct llama_model_tensor_buft_override * tensor_buft_overrides, + size_t * margins, + uint32_t n_ctx_min, + enum ggml_log_level log_level) { + const auto status = common_fit_params( + path_model, + mparams, + cparams, + tensor_split, + tensor_buft_overrides, + margins, + n_ctx_min, + log_level); + + switch (status) { + case COMMON_PARAMS_FIT_STATUS_SUCCESS: + return LLAMA_RS_PARAMS_FIT_STATUS_SUCCESS; + case COMMON_PARAMS_FIT_STATUS_FAILURE: + return LLAMA_RS_PARAMS_FIT_STATUS_FAILURE; + case COMMON_PARAMS_FIT_STATUS_ERROR: + default: + return LLAMA_RS_PARAMS_FIT_STATUS_ERROR; + } +} + +extern "C" void llama_rs_memory_breakdown_print(const struct llama_context * ctx) { + common_memory_breakdown_print(ctx); +} + extern "C" struct llama_sampler * llama_rs_sampler_init_grammar( const struct llama_vocab * vocab, const char * grammar_str, diff --git a/llama-cpp-sys-2/wrapper_common.h b/llama-cpp-sys-2/wrapper_common.h index 0c22df96..3b57f62d 100644 --- a/llama-cpp-sys-2/wrapper_common.h +++ b/llama-cpp-sys-2/wrapper_common.h @@ -30,6 +30,12 @@ struct llama_rs_chat_template_result { size_t additional_stops_count; }; +enum llama_rs_params_fit_status { + LLAMA_RS_PARAMS_FIT_STATUS_SUCCESS = 0, + LLAMA_RS_PARAMS_FIT_STATUS_FAILURE = 1, + LLAMA_RS_PARAMS_FIT_STATUS_ERROR = 2, +}; + #include "wrapper_utils.h" #ifdef __cplusplus @@ -66,6 +72,18 @@ struct llama_sampler * llama_rs_sampler_init_grammar_lazy_patterns( llama_rs_status llama_rs_sampler_accept(struct llama_sampler * sampler, llama_token token); +enum llama_rs_params_fit_status llama_rs_params_fit( + const char * path_model, + struct llama_model_params * mparams, + struct llama_context_params * cparams, + float * tensor_split, + struct llama_model_tensor_buft_override * tensor_buft_overrides, + size_t * margins, + uint32_t n_ctx_min, + enum ggml_log_level log_level); + +void llama_rs_memory_breakdown_print(const struct llama_context * ctx); + void llama_rs_chat_template_result_free(struct llama_rs_chat_template_result * result); void llama_rs_string_free(char * ptr); diff --git a/llama-cpp-sys-2/wrapper_oai.cpp b/llama-cpp-sys-2/wrapper_oai.cpp index 66d1b18f..7d052928 100644 --- a/llama-cpp-sys-2/wrapper_oai.cpp +++ b/llama-cpp-sys-2/wrapper_oai.cpp @@ -61,6 +61,36 @@ static bool detect_thinking_forced_open(const common_chat_params & params) { && ends_with(params.generation_prompt, params.thinking_start_tag); } +static json chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) { + json delta = json::object(); + if (!diff.reasoning_content_delta.empty()) { + delta["reasoning_content"] = diff.reasoning_content_delta; + } + if (!diff.content_delta.empty()) { + delta["content"] = diff.content_delta; + } + if (diff.tool_call_index != std::string::npos) { + json tool_call; + tool_call["index"] = diff.tool_call_index; + if (!diff.tool_call_delta.id.empty()) { + tool_call["id"] = diff.tool_call_delta.id; + tool_call["type"] = "function"; + } + if (!diff.tool_call_delta.name.empty() || !diff.tool_call_delta.arguments.empty()) { + json function = json::object(); + if (!diff.tool_call_delta.name.empty()) { + function["name"] = diff.tool_call_delta.name; + } + if (!diff.tool_call_delta.arguments.empty()) { + function["arguments"] = diff.tool_call_delta.arguments; + } + tool_call["function"] = function; + } + delta["tool_calls"] = json::array({ tool_call }); + } + return delta; +} + static void init_chat_msg(struct llama_rs_chat_msg_oaicompat * out_msg) { if (!out_msg) { return; @@ -854,7 +884,7 @@ extern "C" llama_rs_status llama_rs_chat_msg_diff_to_oaicompat_json( msg_diff.tool_call_delta.id = diff->tool_call_delta.id ? diff->tool_call_delta.id : ""; } - auto json_delta = common_chat_msg_diff_to_json_oaicompat(msg_diff).dump(); + auto json_delta = chat_msg_diff_to_json_oaicompat(msg_diff).dump(); *out_json = llama_rs_dup_string(json_delta); return *out_json ? LLAMA_RS_STATUS_OK : LLAMA_RS_STATUS_ALLOCATION_FAILED; } catch (const std::exception &) { From 13b110eadc6e0bf45e2e4699d998b1173b108626 Mon Sep 17 00:00:00 2001 From: Arthur Konovalov Date: Tue, 19 May 2026 16:18:56 -0600 Subject: [PATCH 3/8] feat: expose MTP context parameters --- llama-cpp-2/src/context/params.rs | 33 +++++++++++++++++++++++ llama-cpp-2/src/context/params/get_set.rs | 29 +++++++++++++++++++- 2 files changed, 61 insertions(+), 1 deletion(-) diff --git a/llama-cpp-2/src/context/params.rs b/llama-cpp-2/src/context/params.rs index 027f55a3..441f0e4f 100644 --- a/llama-cpp-2/src/context/params.rs +++ b/llama-cpp-2/src/context/params.rs @@ -122,6 +122,39 @@ impl From for i32 { } } +/// A rusty wrapper around `llama_context_type`. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum LlamaContextType { + /// Default decoder context. + Default, + /// Multi-token-prediction draft context. + Mtp, + /// Unknown context type from a newer llama.cpp. + Unknown(llama_cpp_sys_2::llama_context_type), +} + +/// Create a `LlamaContextType` from a raw `llama_context_type`. +impl From for LlamaContextType { + fn from(value: llama_cpp_sys_2::llama_context_type) -> Self { + match value { + x if x == llama_cpp_sys_2::LLAMA_CONTEXT_TYPE_DEFAULT => Self::Default, + x if x == llama_cpp_sys_2::LLAMA_CONTEXT_TYPE_MTP => Self::Mtp, + x => Self::Unknown(x), + } + } +} + +/// Create a raw `llama_context_type` from a `LlamaContextType`. +impl From for llama_cpp_sys_2::llama_context_type { + fn from(value: LlamaContextType) -> Self { + match value { + LlamaContextType::Default => llama_cpp_sys_2::LLAMA_CONTEXT_TYPE_DEFAULT, + LlamaContextType::Mtp => llama_cpp_sys_2::LLAMA_CONTEXT_TYPE_MTP, + LlamaContextType::Unknown(raw) => raw, + } + } +} + /// A rusty wrapper around `ggml_type` for KV cache types. #[allow(non_camel_case_types, missing_docs)] #[derive(Copy, Clone, Debug, PartialEq, Eq)] diff --git a/llama-cpp-2/src/context/params/get_set.rs b/llama-cpp-2/src/context/params/get_set.rs index d9bf1753..51d409aa 100644 --- a/llama-cpp-2/src/context/params/get_set.rs +++ b/llama-cpp-2/src/context/params/get_set.rs @@ -1,7 +1,8 @@ use std::num::NonZeroU32; use super::{ - KvCacheType, LlamaAttentionType, LlamaContextParams, LlamaPoolingType, RopeScalingType, + KvCacheType, LlamaAttentionType, LlamaContextParams, LlamaContextType, LlamaPoolingType, + RopeScalingType, }; impl LlamaContextParams { @@ -128,6 +129,32 @@ impl LlamaContextParams { self.context_params.n_seq_max } + /// Set the number of recurrent-state rollback snapshots per sequence. + #[must_use] + pub fn with_n_rs_seq(mut self, n_rs_seq: u32) -> Self { + self.context_params.n_rs_seq = n_rs_seq; + self + } + + /// Get the number of recurrent-state rollback snapshots per sequence. + #[must_use] + pub fn n_rs_seq(&self) -> u32 { + self.context_params.n_rs_seq + } + + /// Set the llama.cpp context type. + #[must_use] + pub fn with_context_type(mut self, context_type: LlamaContextType) -> Self { + self.context_params.ctx_type = context_type.into(); + self + } + + /// Get the llama.cpp context type. + #[must_use] + pub fn context_type(&self) -> LlamaContextType { + self.context_params.ctx_type.into() + } + /// Set the number of threads /// /// # Examples From bc4bfd12ba402135749c1c6156c44a2d8b93a3da Mon Sep 17 00:00:00 2001 From: Arthur Konovalov Date: Tue, 19 May 2026 16:19:11 -0600 Subject: [PATCH 4/8] feat: expose MTP speculative decoding --- llama-cpp-2/src/lib.rs | 1 + llama-cpp-2/src/speculative.rs | 213 +++++++++++++++++++++++++++++ llama-cpp-sys-2/wrapper_common.cpp | 187 +++++++++++++++++++++++++ llama-cpp-sys-2/wrapper_common.h | 33 +++++ 4 files changed, 434 insertions(+) create mode 100644 llama-cpp-2/src/speculative.rs diff --git a/llama-cpp-2/src/lib.rs b/llama-cpp-2/src/lib.rs index 453f91ff..fd698ebb 100644 --- a/llama-cpp-2/src/lib.rs +++ b/llama-cpp-2/src/lib.rs @@ -35,6 +35,7 @@ pub mod model; pub mod mtmd; pub mod openai; pub mod sampling; +pub mod speculative; pub mod timing; pub mod token; pub mod token_type; diff --git a/llama-cpp-2/src/speculative.rs b/llama-cpp-2/src/speculative.rs new file mode 100644 index 00000000..83052824 --- /dev/null +++ b/llama-cpp-2/src/speculative.rs @@ -0,0 +1,213 @@ +//! Experimental wrappers for llama.cpp speculative decoding helpers. + +use std::ptr::NonNull; + +use crate::context::LlamaContext; +use crate::llama_batch::LlamaBatch; +use crate::status_is_ok; +use crate::token::LlamaToken; + +/// Parameters for same-model MTP speculative decoding. +#[derive(Clone, Copy, Debug, PartialEq)] +pub struct MtpSpeculativeParams { + /// Maximum number of draft tokens to propose. + pub n_max: i32, + /// Minimum number of draft tokens required before returning a draft. + pub n_min: i32, + /// Minimum draft probability accepted by llama.cpp's MTP drafter. + pub p_min: f32, +} + +impl Default for MtpSpeculativeParams { + fn default() -> Self { + Self { + n_max: 3, + n_min: 0, + p_min: 0.0, + } + } +} + +/// Errors returned by the MTP speculative wrapper. +#[derive(Debug, Eq, PartialEq, thiserror::Error)] +pub enum MtpSpeculativeError { + /// Invalid parameters were provided. + #[error("invalid MTP speculative parameters")] + InvalidParams, + /// llama.cpp returned a null speculative handle. + #[error("llama.cpp failed to initialize MTP speculative decoding")] + InitFailed, + /// llama.cpp rejected a wrapper call. + #[error("llama.cpp MTP speculative call failed with status {0}")] + Status(i32), + /// The draft output exceeded the caller-provided bound. + #[error("llama.cpp MTP draft exceeded configured maximum")] + DraftOverflow, +} + +/// RAII owner for a same-model MTP speculative context. +#[derive(Debug)] +pub struct MtpSpeculative<'model> { + raw: NonNull, + target_context: LlamaContext<'model>, + draft_context: LlamaContext<'model>, + n_max: usize, +} + +impl<'model> MtpSpeculative<'model> { + /// Create a new MTP speculative helper from a target context and an MTP + /// draft context. + /// + /// # Errors + /// + /// Returns an error if parameters are invalid or llama.cpp cannot + /// initialize the speculative implementation for the loaded model. + pub fn new( + target_context: LlamaContext<'model>, + draft_context: LlamaContext<'model>, + params: MtpSpeculativeParams, + ) -> Result { + if params.n_max <= 0 || params.n_min < 0 || params.n_min > params.n_max { + return Err(MtpSpeculativeError::InvalidParams); + } + let n_max = + usize::try_from(params.n_max).map_err(|_| MtpSpeculativeError::InvalidParams)?; + + let raw = unsafe { + llama_cpp_sys_2::llama_rs_mtp_speculative_init( + target_context.context.as_ptr(), + draft_context.context.as_ptr(), + params.n_max, + params.n_min, + params.p_min, + ) + }; + let raw = NonNull::new(raw).ok_or(MtpSpeculativeError::InitFailed)?; + + Ok(Self { + raw, + target_context, + draft_context, + n_max, + }) + } + + /// Access the target context. + #[must_use] + pub fn target_context(&self) -> &LlamaContext<'model> { + &self.target_context + } + + /// Access the target context for decode and cache rollback operations. + pub fn target_context_mut(&mut self) -> &mut LlamaContext<'model> { + &mut self.target_context + } + + /// Access the draft context for cache rollback operations. + pub fn draft_context_mut(&mut self) -> &mut LlamaContext<'model> { + &mut self.draft_context + } + + /// Begin a new generation from the given prompt tokens. + /// + /// # Errors + /// + /// Returns an error if llama.cpp rejects the call. + pub fn begin(&mut self, prompt_tokens: &[LlamaToken]) -> Result<(), MtpSpeculativeError> { + let prompt = tokens_to_raw(prompt_tokens); + let status = unsafe { + llama_cpp_sys_2::llama_rs_mtp_speculative_begin( + self.raw.as_ptr(), + prompt.as_ptr(), + prompt.len(), + ) + }; + status_to_result(status) + } + + /// Process a batch that was just decoded by the target context. + /// + /// # Errors + /// + /// Returns an error if llama.cpp cannot update the MTP draft context. + pub fn process(&mut self, batch: &LlamaBatch<'_>) -> Result<(), MtpSpeculativeError> { + let status = unsafe { + llama_cpp_sys_2::llama_rs_mtp_speculative_process( + self.raw.as_ptr(), + std::ptr::from_ref(&batch.llama_batch), + ) + }; + status_to_result(status) + } + + /// Generate draft tokens after `id_last`. + /// + /// # Errors + /// + /// Returns an error if llama.cpp rejects the draft operation or emits more + /// draft tokens than requested. + pub fn draft( + &mut self, + n_past: i32, + id_last: LlamaToken, + prompt_tokens: &[LlamaToken], + ) -> Result, MtpSpeculativeError> { + if n_past < 0 { + return Err(MtpSpeculativeError::InvalidParams); + } + + let prompt = tokens_to_raw(prompt_tokens); + let mut raw_out = vec![0; self.n_max]; + let mut out_len = 0_usize; + let status = unsafe { + llama_cpp_sys_2::llama_rs_mtp_speculative_draft( + self.raw.as_ptr(), + n_past, + id_last.0, + prompt.as_ptr(), + prompt.len(), + raw_out.as_mut_ptr(), + raw_out.len(), + &raw mut out_len, + ) + }; + if status == llama_cpp_sys_2::LLAMA_RS_STATUS_ALLOCATION_FAILED { + return Err(MtpSpeculativeError::DraftOverflow); + } + status_to_result(status)?; + raw_out.truncate(out_len); + Ok(raw_out.into_iter().map(LlamaToken).collect()) + } + + /// Notify llama.cpp how many draft tokens the target context accepted. + /// + /// # Errors + /// + /// Returns an error if llama.cpp rejects the call. + pub fn accept(&mut self, n_accepted: u16) -> Result<(), MtpSpeculativeError> { + let status = unsafe { + llama_cpp_sys_2::llama_rs_mtp_speculative_accept(self.raw.as_ptr(), n_accepted) + }; + status_to_result(status) + } +} + +impl Drop for MtpSpeculative<'_> { + fn drop(&mut self) { + unsafe { + llama_cpp_sys_2::llama_rs_mtp_speculative_free(self.raw.as_ptr()); + } + } +} + +fn tokens_to_raw(tokens: &[LlamaToken]) -> Vec { + tokens.iter().map(|token| token.0).collect() +} + +fn status_to_result(status: llama_cpp_sys_2::llama_rs_status) -> Result<(), MtpSpeculativeError> { + if status_is_ok(status) { + Ok(()) + } else { + Err(MtpSpeculativeError::Status(status)) + } +} diff --git a/llama-cpp-sys-2/wrapper_common.cpp b/llama-cpp-sys-2/wrapper_common.cpp index 848b1e19..c67eab7d 100644 --- a/llama-cpp-sys-2/wrapper_common.cpp +++ b/llama-cpp-sys-2/wrapper_common.cpp @@ -3,11 +3,14 @@ #include #include #include +#include #include #include +#include #include "llama.cpp/common/fit.h" #include "llama.cpp/common/json-schema-to-grammar.h" +#include "llama.cpp/common/speculative.h" #include "llama.cpp/include/llama.h" #include "wrapper_utils.h" @@ -201,3 +204,187 @@ extern "C" llama_rs_status llama_rs_sampler_accept(struct llama_sampler * sample return LLAMA_RS_STATUS_EXCEPTION; } } + +struct llama_rs_mtp_speculative { + common_params_speculative params; + common_speculative * spec = nullptr; + std::vector prompt; + std::vector draft; + size_t last_draft_len = 0; + bool draft_pending = false; +}; + +static bool llama_rs_mtp_batch_compatible(const struct llama_batch & batch) { + if (batch.n_tokens <= 0 || !batch.token || batch.embd || !batch.pos || !batch.n_seq_id || + !batch.seq_id) { + return false; + } + for (int32_t k = 0; k < batch.n_tokens; ++k) { + if (batch.n_seq_id[k] != 1 || !batch.seq_id[k]) { + return false; + } + } + return true; +} + +static void llama_rs_assign_tokens( + std::vector & dst, + const llama_token * tokens, + size_t count) { + if (count == 0) { + dst.clear(); + return; + } + dst.assign(tokens, tokens + count); +} + +extern "C" struct llama_rs_mtp_speculative * llama_rs_mtp_speculative_init( + struct llama_context * ctx_tgt, + struct llama_context * ctx_dft, + int32_t n_max, + int32_t n_min, + float p_min) { + if (!ctx_tgt || !ctx_dft || n_max <= 0 || n_min < 0 || n_min > n_max) { + return nullptr; + } + + try { + auto wrapper = std::make_unique(); + wrapper->params.types = { COMMON_SPECULATIVE_TYPE_DRAFT_MTP }; + wrapper->params.draft.ctx_tgt = ctx_tgt; + wrapper->params.draft.ctx_dft = ctx_dft; + wrapper->params.draft.n_max = n_max; + wrapper->params.draft.n_min = n_min; + wrapper->params.draft.p_min = p_min; + + wrapper->spec = common_speculative_init(wrapper->params, 1); + if (!wrapper->spec) { + return nullptr; + } + + return wrapper.release(); + } catch (...) { + return nullptr; + } +} + +extern "C" void llama_rs_mtp_speculative_free(struct llama_rs_mtp_speculative * spec) { + if (!spec) { + return; + } + if (spec->spec) { + common_speculative_free(spec->spec); + spec->spec = nullptr; + } + delete spec; +} + +extern "C" llama_rs_status llama_rs_mtp_speculative_begin( + struct llama_rs_mtp_speculative * spec, + const llama_token * prompt_tokens, + size_t prompt_tokens_count) { + if (!spec || !spec->spec || (!prompt_tokens && prompt_tokens_count > 0)) { + return LLAMA_RS_STATUS_INVALID_ARGUMENT; + } + + try { + llama_rs_assign_tokens(spec->prompt, prompt_tokens, prompt_tokens_count); + spec->last_draft_len = 0; + spec->draft_pending = false; + common_speculative_begin(spec->spec, 0, spec->prompt); + return LLAMA_RS_STATUS_OK; + } catch (...) { + return LLAMA_RS_STATUS_EXCEPTION; + } +} + +extern "C" llama_rs_status llama_rs_mtp_speculative_process( + struct llama_rs_mtp_speculative * spec, + const struct llama_batch * batch) { + if (!spec || !spec->spec || !batch) { + return LLAMA_RS_STATUS_INVALID_ARGUMENT; + } + if (!llama_rs_mtp_batch_compatible(*batch)) { + return LLAMA_RS_STATUS_INVALID_ARGUMENT; + } + + try { + return common_speculative_process(spec->spec, *batch) + ? LLAMA_RS_STATUS_OK + : LLAMA_RS_STATUS_EXCEPTION; + } catch (...) { + return LLAMA_RS_STATUS_EXCEPTION; + } +} + +extern "C" llama_rs_status llama_rs_mtp_speculative_draft( + struct llama_rs_mtp_speculative * spec, + llama_pos n_past, + llama_token id_last, + const llama_token * prompt_tokens, + size_t prompt_tokens_count, + llama_token * out_tokens, + size_t out_tokens_capacity, + size_t * out_tokens_count) { + if (!spec || !spec->spec || (!prompt_tokens && prompt_tokens_count > 0) || + !out_tokens_count || n_past < 0) { + return LLAMA_RS_STATUS_INVALID_ARGUMENT; + } + + try { + if (spec->draft_pending) { + return LLAMA_RS_STATUS_INVALID_ARGUMENT; + } + llama_rs_assign_tokens(spec->prompt, prompt_tokens, prompt_tokens_count); + spec->draft.clear(); + spec->last_draft_len = 0; + + auto & params = common_speculative_get_draft_params(spec->spec, 0); + params = { + true, + spec->params.draft.n_max, + n_past, + id_last, + &spec->prompt, + &spec->draft, + }; + + common_speculative_draft(spec->spec); + + *out_tokens_count = spec->draft.size(); + if (spec->draft.size() > out_tokens_capacity) { + return LLAMA_RS_STATUS_ALLOCATION_FAILED; + } + if (!spec->draft.empty() && !out_tokens) { + return LLAMA_RS_STATUS_INVALID_ARGUMENT; + } + if (!spec->draft.empty()) { + std::memcpy(out_tokens, spec->draft.data(), spec->draft.size() * sizeof(llama_token)); + } + spec->last_draft_len = spec->draft.size(); + spec->draft_pending = !spec->draft.empty(); + return LLAMA_RS_STATUS_OK; + } catch (...) { + return LLAMA_RS_STATUS_EXCEPTION; + } +} + +extern "C" llama_rs_status llama_rs_mtp_speculative_accept( + struct llama_rs_mtp_speculative * spec, + uint16_t n_accepted) { + if (!spec || !spec->spec) { + return LLAMA_RS_STATUS_INVALID_ARGUMENT; + } + if (!spec->draft_pending || n_accepted > spec->last_draft_len) { + return LLAMA_RS_STATUS_INVALID_ARGUMENT; + } + + try { + common_speculative_accept(spec->spec, 0, n_accepted); + spec->last_draft_len = 0; + spec->draft_pending = false; + return LLAMA_RS_STATUS_OK; + } catch (...) { + return LLAMA_RS_STATUS_EXCEPTION; + } +} diff --git a/llama-cpp-sys-2/wrapper_common.h b/llama-cpp-sys-2/wrapper_common.h index 3b57f62d..b3a21091 100644 --- a/llama-cpp-sys-2/wrapper_common.h +++ b/llama-cpp-sys-2/wrapper_common.h @@ -7,6 +7,7 @@ struct llama_model; struct llama_sampler; +struct llama_rs_mtp_speculative; struct llama_vocab; struct llama_rs_grammar_trigger { @@ -84,6 +85,38 @@ enum llama_rs_params_fit_status llama_rs_params_fit( void llama_rs_memory_breakdown_print(const struct llama_context * ctx); +struct llama_rs_mtp_speculative * llama_rs_mtp_speculative_init( + struct llama_context * ctx_tgt, + struct llama_context * ctx_dft, + int32_t n_max, + int32_t n_min, + float p_min); + +void llama_rs_mtp_speculative_free(struct llama_rs_mtp_speculative * spec); + +llama_rs_status llama_rs_mtp_speculative_begin( + struct llama_rs_mtp_speculative * spec, + const llama_token * prompt_tokens, + size_t prompt_tokens_count); + +llama_rs_status llama_rs_mtp_speculative_process( + struct llama_rs_mtp_speculative * spec, + const struct llama_batch * batch); + +llama_rs_status llama_rs_mtp_speculative_draft( + struct llama_rs_mtp_speculative * spec, + llama_pos n_past, + llama_token id_last, + const llama_token * prompt_tokens, + size_t prompt_tokens_count, + llama_token * out_tokens, + size_t out_tokens_capacity, + size_t * out_tokens_count); + +llama_rs_status llama_rs_mtp_speculative_accept( + struct llama_rs_mtp_speculative * spec, + uint16_t n_accepted); + void llama_rs_chat_template_result_free(struct llama_rs_chat_template_result * result); void llama_rs_string_free(char * ptr); From f361cc1609f0b696f9ad64cfd1cc6fcd2360f614 Mon Sep 17 00:00:00 2001 From: Arthur Konovalov Date: Tue, 19 May 2026 13:00:29 -0600 Subject: [PATCH 5/8] Fix llama-common static link detection Newer llama.cpp builds publish libllama-common and libllama-common-base instead of libcommon. Detect the produced common archives before emitting cargo link directives so full release builds can resolve the MTP wrapper dependencies. --- llama-cpp-sys-2/build.rs | 62 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 61 insertions(+), 1 deletion(-) diff --git a/llama-cpp-sys-2/build.rs b/llama-cpp-sys-2/build.rs index 38ac1097..c2578eb8 100644 --- a/llama-cpp-sys-2/build.rs +++ b/llama-cpp-sys-2/build.rs @@ -149,6 +149,40 @@ fn extract_lib_assets(out_dir: &Path, target_os: &TargetOs) -> Vec { files } +fn library_file_exists( + search_dirs: &[PathBuf], + lib_name: &str, + build_shared_libs: bool, + target_os: &TargetOs, +) -> bool { + let (prefixes, extensions): (&[&str], &[&str]) = match target_os { + TargetOs::Windows(_) => (&["", "lib"], &["lib"]), + TargetOs::Apple(_) => { + if build_shared_libs { + (&["lib"], &["dylib"]) + } else { + (&["lib"], &["a"]) + } + } + TargetOs::Linux | TargetOs::Android => { + if build_shared_libs { + (&["lib"], &["so"]) + } else { + (&["lib"], &["a"]) + } + } + }; + + search_dirs.iter().any(|dir| { + prefixes.iter().any(|prefix| { + extensions.iter().any(|extension| { + dir.join(format!("{prefix}{lib_name}.{extension}")) + .is_file() + }) + }) + }) +} + fn macos_link_search_path() -> Option { let output = Command::new("clang") .arg("--print-search-dirs") @@ -1020,14 +1054,40 @@ fn main() { "cargo:rustc-link-search=native={}", common_lib_dir.display() ); + let mut common_search_dirs = vec![common_lib_dir.clone()]; let common_profile_dir = common_lib_dir.join(&profile); if common_profile_dir.is_dir() { println!( "cargo:rustc-link-search=native={}", common_profile_dir.display() ); + common_search_dirs.push(common_profile_dir); + } + + if library_file_exists( + &common_search_dirs, + "llama-common", + build_shared_libs, + &target_os, + ) { + println!("cargo:rustc-link-lib={llama_libs_kind}=llama-common"); + if library_file_exists( + &common_search_dirs, + "llama-common-base", + build_shared_libs, + &target_os, + ) { + println!("cargo:rustc-link-lib={llama_libs_kind}=llama-common-base"); + } + } else if library_file_exists(&common_search_dirs, "common", build_shared_libs, &target_os) + { + println!("cargo:rustc-link-lib={llama_libs_kind}=common"); + } else { + println!( + "cargo:warning=LLAMA_BUILD_COMMON was enabled, but no common library was found in {}", + common_lib_dir.display() + ); } - println!("cargo:rustc-link-lib=static=common"); } if cfg!(feature = "system-ggml") { From be3eb5b5d2bc88a01c73e688ba6d6abeb506ad1c Mon Sep 17 00:00:00 2001 From: Arthur Konovalov Date: Tue, 19 May 2026 16:48:41 -0600 Subject: [PATCH 6/8] Bump version to 0.1.148 --- Cargo.lock | 8 ++++---- examples/embeddings/Cargo.toml | 2 +- examples/simple/Cargo.toml | 2 +- llama-cpp-2/Cargo.toml | 6 +++--- llama-cpp-sys-2/Cargo.toml | 2 +- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 820f1b47..8e7a1431 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -776,7 +776,7 @@ checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" [[package]] name = "embeddings" -version = "0.1.147" +version = "0.1.148" dependencies = [ "anyhow", "clap", @@ -1538,7 +1538,7 @@ checksum = "11d3d7f243d5c5a8b9bb5d6dd2b1602c0cb0b9db1621bafc7ed66e35ff9fe092" [[package]] name = "llama-cpp-2" -version = "0.1.147" +version = "0.1.148" dependencies = [ "encoding_rs", "enumflags2", @@ -1555,7 +1555,7 @@ dependencies = [ [[package]] name = "llama-cpp-sys-2" -version = "0.1.147" +version = "0.1.148" dependencies = [ "bindgen", "cc", @@ -2360,7 +2360,7 @@ checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" [[package]] name = "simple" -version = "0.1.147" +version = "0.1.148" dependencies = [ "anyhow", "clap", diff --git a/examples/embeddings/Cargo.toml b/examples/embeddings/Cargo.toml index 20719948..7ab7bf43 100644 --- a/examples/embeddings/Cargo.toml +++ b/examples/embeddings/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "embeddings" -version = "0.1.147" +version = "0.1.148" edition = "2021" publish = false diff --git a/examples/simple/Cargo.toml b/examples/simple/Cargo.toml index def301c9..a5354c32 100644 --- a/examples/simple/Cargo.toml +++ b/examples/simple/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "simple" -version = "0.1.147" +version = "0.1.148" edition = "2021" publish = false diff --git a/llama-cpp-2/Cargo.toml b/llama-cpp-2/Cargo.toml index 24e88ddf..8938b979 100644 --- a/llama-cpp-2/Cargo.toml +++ b/llama-cpp-2/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "llama-cpp-2" description = "llama.cpp bindings for Rust" -version = "0.1.147" +version = "0.1.148" edition = "2021" license = "MIT OR Apache-2.0" repository = "https://github.com/utilityai/llama-cpp-rs" @@ -10,7 +10,7 @@ repository = "https://github.com/utilityai/llama-cpp-rs" [dependencies] enumflags2 = "0.7.12" -llama-cpp-sys-2 = { path = "../llama-cpp-sys-2", version = "0.1.147" } +llama-cpp-sys-2 = { path = "../llama-cpp-sys-2", version = "0.1.148" } thiserror = { workspace = true } tracing = { workspace = true } tracing-core = { workspace = true } @@ -45,7 +45,7 @@ dynamic-backends = ["llama-cpp-sys-2/dynamic-backends"] [target.'cfg(all(target_os = "macos", any(target_arch = "aarch64", target_arch = "arm64")))'.dependencies] -llama-cpp-sys-2 = { path = "../llama-cpp-sys-2", version = "0.1.147", features = [ +llama-cpp-sys-2 = { path = "../llama-cpp-sys-2", version = "0.1.148", features = [ "metal", ] } diff --git a/llama-cpp-sys-2/Cargo.toml b/llama-cpp-sys-2/Cargo.toml index 92fd1d91..45f352cf 100644 --- a/llama-cpp-sys-2/Cargo.toml +++ b/llama-cpp-sys-2/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "llama-cpp-sys-2" description = "Low Level Bindings to llama.cpp" -version = "0.1.147" +version = "0.1.148" edition = "2021" license = "MIT OR Apache-2.0" repository = "https://github.com/utilityai/llama-cpp-rs" From 7329550da811c42fb33873bd039225feae49ab8f Mon Sep 17 00:00:00 2001 From: Arthur Konovalov Date: Tue, 26 May 2026 16:24:16 -0600 Subject: [PATCH 7/8] Revert "Bump version to 0.1.148" This reverts commit be3eb5b5d2bc88a01c73e688ba6d6abeb506ad1c. --- Cargo.lock | 8 ++++---- examples/embeddings/Cargo.toml | 2 +- examples/simple/Cargo.toml | 2 +- llama-cpp-2/Cargo.toml | 6 +++--- llama-cpp-sys-2/Cargo.toml | 2 +- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8e7a1431..820f1b47 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -776,7 +776,7 @@ checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" [[package]] name = "embeddings" -version = "0.1.148" +version = "0.1.147" dependencies = [ "anyhow", "clap", @@ -1538,7 +1538,7 @@ checksum = "11d3d7f243d5c5a8b9bb5d6dd2b1602c0cb0b9db1621bafc7ed66e35ff9fe092" [[package]] name = "llama-cpp-2" -version = "0.1.148" +version = "0.1.147" dependencies = [ "encoding_rs", "enumflags2", @@ -1555,7 +1555,7 @@ dependencies = [ [[package]] name = "llama-cpp-sys-2" -version = "0.1.148" +version = "0.1.147" dependencies = [ "bindgen", "cc", @@ -2360,7 +2360,7 @@ checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" [[package]] name = "simple" -version = "0.1.148" +version = "0.1.147" dependencies = [ "anyhow", "clap", diff --git a/examples/embeddings/Cargo.toml b/examples/embeddings/Cargo.toml index 7ab7bf43..20719948 100644 --- a/examples/embeddings/Cargo.toml +++ b/examples/embeddings/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "embeddings" -version = "0.1.148" +version = "0.1.147" edition = "2021" publish = false diff --git a/examples/simple/Cargo.toml b/examples/simple/Cargo.toml index a5354c32..def301c9 100644 --- a/examples/simple/Cargo.toml +++ b/examples/simple/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "simple" -version = "0.1.148" +version = "0.1.147" edition = "2021" publish = false diff --git a/llama-cpp-2/Cargo.toml b/llama-cpp-2/Cargo.toml index 8938b979..24e88ddf 100644 --- a/llama-cpp-2/Cargo.toml +++ b/llama-cpp-2/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "llama-cpp-2" description = "llama.cpp bindings for Rust" -version = "0.1.148" +version = "0.1.147" edition = "2021" license = "MIT OR Apache-2.0" repository = "https://github.com/utilityai/llama-cpp-rs" @@ -10,7 +10,7 @@ repository = "https://github.com/utilityai/llama-cpp-rs" [dependencies] enumflags2 = "0.7.12" -llama-cpp-sys-2 = { path = "../llama-cpp-sys-2", version = "0.1.148" } +llama-cpp-sys-2 = { path = "../llama-cpp-sys-2", version = "0.1.147" } thiserror = { workspace = true } tracing = { workspace = true } tracing-core = { workspace = true } @@ -45,7 +45,7 @@ dynamic-backends = ["llama-cpp-sys-2/dynamic-backends"] [target.'cfg(all(target_os = "macos", any(target_arch = "aarch64", target_arch = "arm64")))'.dependencies] -llama-cpp-sys-2 = { path = "../llama-cpp-sys-2", version = "0.1.148", features = [ +llama-cpp-sys-2 = { path = "../llama-cpp-sys-2", version = "0.1.147", features = [ "metal", ] } diff --git a/llama-cpp-sys-2/Cargo.toml b/llama-cpp-sys-2/Cargo.toml index 45f352cf..92fd1d91 100644 --- a/llama-cpp-sys-2/Cargo.toml +++ b/llama-cpp-sys-2/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "llama-cpp-sys-2" description = "Low Level Bindings to llama.cpp" -version = "0.1.148" +version = "0.1.147" edition = "2021" license = "MIT OR Apache-2.0" repository = "https://github.com/utilityai/llama-cpp-rs" From 7534cce9baf10cccd6f19eac4e242c2a8a9e626b Mon Sep 17 00:00:00 2001 From: Arthur Konovalov Date: Tue, 26 May 2026 19:30:30 -0600 Subject: [PATCH 8/8] Disable CUDA NCCL in crate builds --- llama-cpp-sys-2/build.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/llama-cpp-sys-2/build.rs b/llama-cpp-sys-2/build.rs index c2578eb8..fb5995a9 100644 --- a/llama-cpp-sys-2/build.rs +++ b/llama-cpp-sys-2/build.rs @@ -831,6 +831,7 @@ fn main() { if cfg!(feature = "cuda") { config.define("GGML_CUDA", "ON"); + config.define("GGML_CUDA_NCCL", "OFF"); if cfg!(feature = "cuda-no-vmm") { config.define("GGML_CUDA_NO_VMM", "ON");