Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion llama-cpp-2/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ impl<'model> LlamaContext<'model> {

/// Print a breakdown of per-device memory use to the default logger.
///
/// Hands this context's raw pointer to the llama.cpp memory-breakdown printer.
pub fn print_memory_breakdown(&self) {
// NOTE(review): this listing was captured from a diff view, so the removed call
// (`llama_memory_breakdown_print`) and its replacement
// (`llama_rs_memory_breakdown_print`) both appear below. Presumably only the
// second belongs in the final file — confirm against the real source.
unsafe { llama_cpp_sys_2::llama_memory_breakdown_print(self.context.as_ptr()) }
unsafe { llama_cpp_sys_2::llama_rs_memory_breakdown_print(self.context.as_ptr()) }
}
}

Expand Down
33 changes: 33 additions & 0 deletions llama-cpp-2/src/context/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,39 @@ impl From<LlamaAttentionType> for i32 {
}
}

/// A rusty wrapper around `llama_context_type`.
///
/// Converts to and from the raw `llama_cpp_sys_2::llama_context_type` value via
/// the `From` implementations. Raw values that match no known constant are kept
/// in [`LlamaContextType::Unknown`] so converting back yields the same raw value.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum LlamaContextType {
/// Default decoder context (`LLAMA_CONTEXT_TYPE_DEFAULT`).
Default,
/// Multi-token-prediction draft context (`LLAMA_CONTEXT_TYPE_MTP`).
Mtp,
/// Unknown context type from a newer llama.cpp; carries the raw value unchanged.
Unknown(llama_cpp_sys_2::llama_context_type),
}

/// Convert a raw `llama_context_type` into a `LlamaContextType`.
///
/// The two known constants map to their named variants; every other value is
/// preserved inside `LlamaContextType::Unknown`.
impl From<llama_cpp_sys_2::llama_context_type> for LlamaContextType {
    fn from(value: llama_cpp_sys_2::llama_context_type) -> Self {
        if value == llama_cpp_sys_2::LLAMA_CONTEXT_TYPE_DEFAULT {
            Self::Default
        } else if value == llama_cpp_sys_2::LLAMA_CONTEXT_TYPE_MTP {
            Self::Mtp
        } else {
            // Not a constant this wrapper knows about: keep the raw value.
            Self::Unknown(value)
        }
    }
}

/// Create a raw `llama_context_type` from a `LlamaContextType`.
impl From<LlamaContextType> for llama_cpp_sys_2::llama_context_type {
fn from(value: LlamaContextType) -> Self {
match value {
LlamaContextType::Default => llama_cpp_sys_2::LLAMA_CONTEXT_TYPE_DEFAULT,
LlamaContextType::Mtp => llama_cpp_sys_2::LLAMA_CONTEXT_TYPE_MTP,
LlamaContextType::Unknown(raw) => raw,
}
}
}

/// A rusty wrapper around `ggml_type` for KV cache types.
#[allow(non_camel_case_types, missing_docs)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
Expand Down
29 changes: 28 additions & 1 deletion llama-cpp-2/src/context/params/get_set.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use std::num::NonZeroU32;

use super::{
KvCacheType, LlamaAttentionType, LlamaContextParams, LlamaPoolingType, RopeScalingType,
KvCacheType, LlamaAttentionType, LlamaContextParams, LlamaContextType, LlamaPoolingType,
RopeScalingType,
};

impl LlamaContextParams {
Expand Down Expand Up @@ -128,6 +129,32 @@ impl LlamaContextParams {
self.context_params.n_seq_max
}

/// Set the number of recurrent-state rollback snapshots per sequence.
///
/// Consumes the parameters and returns them with `n_rs_seq` updated.
#[must_use]
pub fn with_n_rs_seq(mut self, n_rs_seq: u32) -> Self {
    let raw_params = &mut self.context_params;
    raw_params.n_rs_seq = n_rs_seq;
    self
}

/// Get the number of recurrent-state rollback snapshots per sequence.
///
/// Reads `n_rs_seq` from the underlying context parameters.
#[must_use]
pub fn n_rs_seq(&self) -> u32 {
    let raw_params = &self.context_params;
    raw_params.n_rs_seq
}

/// Set the llama.cpp context type.
#[must_use]
pub fn with_context_type(mut self, context_type: LlamaContextType) -> Self {
self.context_params.ctx_type = context_type.into();
self
}

/// Get the llama.cpp context type.
#[must_use]
pub fn context_type(&self) -> LlamaContextType {
self.context_params.ctx_type.into()
}

/// Set the number of threads
///
/// # Examples
Expand Down
1 change: 1 addition & 0 deletions llama-cpp-2/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ pub mod model;
pub mod mtmd;
pub mod openai;
pub mod sampling;
pub mod speculative;
pub mod timing;
pub mod token;
pub mod token_type;
Expand Down
6 changes: 3 additions & 3 deletions llama-cpp-2/src/model/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ impl LlamaModelParams {
self.params.tensor_buft_overrides = null();

let status = unsafe {
llama_cpp_sys_2::llama_params_fit(
llama_cpp_sys_2::llama_rs_params_fit(
model_path.as_ptr(),
&raw mut self.params,
&raw mut cparams.context_params,
Expand All @@ -366,8 +366,8 @@ impl LlamaModelParams {
};

match status {
llama_cpp_sys_2::LLAMA_PARAMS_FIT_STATUS_SUCCESS => {}
llama_cpp_sys_2::LLAMA_PARAMS_FIT_STATUS_FAILURE => return Err(FitError::Failure),
llama_cpp_sys_2::LLAMA_RS_PARAMS_FIT_STATUS_SUCCESS => {}
llama_cpp_sys_2::LLAMA_RS_PARAMS_FIT_STATUS_FAILURE => return Err(FitError::Failure),
_ => return Err(FitError::Error),
}

Expand Down
213 changes: 213 additions & 0 deletions llama-cpp-2/src/speculative.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
//! Experimental wrappers for llama.cpp speculative decoding helpers.

use std::ptr::NonNull;

use crate::context::LlamaContext;
use crate::llama_batch::LlamaBatch;
use crate::status_is_ok;
use crate::token::LlamaToken;

/// Tuning knobs for same-model MTP speculative decoding.
///
/// The values are passed to llama.cpp when an `MtpSpeculative` helper is created.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct MtpSpeculativeParams {
    /// Upper bound on the number of draft tokens proposed per draft call.
    pub n_max: i32,
    /// Smallest number of draft tokens required before a draft is returned.
    pub n_min: i32,
    /// Lowest draft probability llama.cpp's MTP drafter will accept.
    pub p_min: f32,
}

impl Default for MtpSpeculativeParams {
    /// Defaults: up to three draft tokens, no minimum count, no probability floor.
    fn default() -> Self {
        const DEFAULT_N_MAX: i32 = 3;
        const DEFAULT_N_MIN: i32 = 0;
        const DEFAULT_P_MIN: f32 = 0.0;

        Self {
            n_max: DEFAULT_N_MAX,
            n_min: DEFAULT_N_MIN,
            p_min: DEFAULT_P_MIN,
        }
    }
}

/// Errors returned by the MTP speculative wrapper.
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum MtpSpeculativeError {
/// Invalid parameters were provided.
///
/// Returned by `MtpSpeculative::new` for an invalid `n_max`/`n_min`
/// combination and by `MtpSpeculative::draft` for a negative `n_past`.
#[error("invalid MTP speculative parameters")]
InvalidParams,
/// llama.cpp returned a null speculative handle.
#[error("llama.cpp failed to initialize MTP speculative decoding")]
InitFailed,
/// llama.cpp rejected a wrapper call; carries the raw status code.
#[error("llama.cpp MTP speculative call failed with status {0}")]
Status(i32),
/// The draft output exceeded the caller-provided bound.
///
/// `MtpSpeculative::draft` reports this when llama.cpp returns
/// `LLAMA_RS_STATUS_ALLOCATION_FAILED`.
#[error("llama.cpp MTP draft exceeded configured maximum")]
DraftOverflow,
}

/// RAII owner for a same-model MTP speculative context.
///
/// Holds the raw llama.cpp speculative handle together with the target and
/// draft contexts whose raw pointers were passed to llama.cpp when the handle
/// was created. The handle is released in `Drop::drop`, which Rust runs before
/// the fields themselves are dropped, so both contexts are still alive when the
/// handle is freed.
#[derive(Debug)]
pub struct MtpSpeculative<'model> {
// Non-null handle returned by `llama_rs_mtp_speculative_init`; freed in `Drop`.
raw: NonNull<llama_cpp_sys_2::llama_rs_mtp_speculative>,
// Target context, owned here so it lives at least as long as the handle.
target_context: LlamaContext<'model>,
// Draft context, owned here so it lives at least as long as the handle.
draft_context: LlamaContext<'model>,
// `MtpSpeculativeParams::n_max` as a `usize`; sizes the output buffer in `draft`.
n_max: usize,
}

impl<'model> MtpSpeculative<'model> {
    /// Create a new MTP speculative helper from a target context and an MTP
    /// draft context.
    ///
    /// Both contexts are moved into the returned value so that they stay alive
    /// for as long as the llama.cpp handle created from their raw pointers.
    /// Because the contexts are taken by value, they are dropped if this
    /// function returns an error.
    ///
    /// # Errors
    ///
    /// Returns [`MtpSpeculativeError::InvalidParams`] if `n_max <= 0`,
    /// `n_min < 0`, `n_min > n_max`, or `p_min` is NaN, and
    /// [`MtpSpeculativeError::InitFailed`] if llama.cpp returns a null handle
    /// (it cannot initialize the speculative implementation for the loaded
    /// model).
    pub fn new(
        target_context: LlamaContext<'model>,
        draft_context: LlamaContext<'model>,
        params: MtpSpeculativeParams,
    ) -> Result<Self, MtpSpeculativeError> {
        // A NaN threshold is rejected up front: every comparison against NaN is
        // false, so it cannot act as a meaningful minimum probability.
        if params.n_max <= 0
            || params.n_min < 0
            || params.n_min > params.n_max
            || params.p_min.is_nan()
        {
            return Err(MtpSpeculativeError::InvalidParams);
        }
        let n_max =
            usize::try_from(params.n_max).map_err(|_| MtpSpeculativeError::InvalidParams)?;

        // SAFETY: both pointers come from live `LlamaContext` values that are
        // moved into `Self` below, so they remain valid for the handle's lifetime.
        let raw = unsafe {
            llama_cpp_sys_2::llama_rs_mtp_speculative_init(
                target_context.context.as_ptr(),
                draft_context.context.as_ptr(),
                params.n_max,
                params.n_min,
                params.p_min,
            )
        };
        let raw = NonNull::new(raw).ok_or(MtpSpeculativeError::InitFailed)?;

        Ok(Self {
            raw,
            target_context,
            draft_context,
            n_max,
        })
    }

    /// Access the target context.
    #[must_use]
    pub fn target_context(&self) -> &LlamaContext<'model> {
        &self.target_context
    }

    /// Access the target context for decode and cache rollback operations.
    ///
    /// NOTE(review): the llama.cpp handle was created from this context's raw
    /// pointer. Replacing the context through this reference (for example with
    /// `std::mem::swap`) would presumably leave the handle pointing at a freed
    /// context — confirm whether the handle retains the pointer.
    pub fn target_context_mut(&mut self) -> &mut LlamaContext<'model> {
        &mut self.target_context
    }

    /// Access the draft context.
    #[must_use]
    pub fn draft_context(&self) -> &LlamaContext<'model> {
        &self.draft_context
    }

    /// Access the draft context for cache rollback operations.
    ///
    /// NOTE(review): the same caveat as for `target_context_mut` applies — do
    /// not replace the context while this helper is alive.
    pub fn draft_context_mut(&mut self) -> &mut LlamaContext<'model> {
        &mut self.draft_context
    }

    /// Begin a new generation from the given prompt tokens.
    ///
    /// # Errors
    ///
    /// Returns [`MtpSpeculativeError::Status`] if llama.cpp rejects the call.
    pub fn begin(&mut self, prompt_tokens: &[LlamaToken]) -> Result<(), MtpSpeculativeError> {
        let prompt = tokens_to_raw(prompt_tokens);
        // SAFETY: `self.raw` is the non-null handle created in `new` and is only
        // freed in `Drop`; `prompt` is a live buffer of exactly `prompt.len()`
        // tokens for the duration of the call.
        let status = unsafe {
            llama_cpp_sys_2::llama_rs_mtp_speculative_begin(
                self.raw.as_ptr(),
                prompt.as_ptr(),
                prompt.len(),
            )
        };
        status_to_result(status)
    }

    /// Process a batch that was just decoded by the target context.
    ///
    /// # Errors
    ///
    /// Returns [`MtpSpeculativeError::Status`] if llama.cpp cannot update the
    /// MTP draft context.
    pub fn process(&mut self, batch: &LlamaBatch<'_>) -> Result<(), MtpSpeculativeError> {
        // SAFETY: `self.raw` is the live handle from `new`; the batch pointer is
        // derived from a shared reference that outlives the call.
        let status = unsafe {
            llama_cpp_sys_2::llama_rs_mtp_speculative_process(
                self.raw.as_ptr(),
                std::ptr::from_ref(&batch.llama_batch),
            )
        };
        status_to_result(status)
    }

    /// Generate draft tokens after `id_last`.
    ///
    /// At most `n_max` tokens (as configured in `new`) are returned.
    ///
    /// # Errors
    ///
    /// Returns [`MtpSpeculativeError::InvalidParams`] if `n_past` is negative,
    /// [`MtpSpeculativeError::DraftOverflow`] if llama.cpp reports
    /// `LLAMA_RS_STATUS_ALLOCATION_FAILED` or a draft length larger than the
    /// output buffer, and [`MtpSpeculativeError::Status`] for any other
    /// non-ok status.
    pub fn draft(
        &mut self,
        n_past: i32,
        id_last: LlamaToken,
        prompt_tokens: &[LlamaToken],
    ) -> Result<Vec<LlamaToken>, MtpSpeculativeError> {
        if n_past < 0 {
            return Err(MtpSpeculativeError::InvalidParams);
        }

        let prompt = tokens_to_raw(prompt_tokens);
        let mut raw_out = vec![0; self.n_max];
        let mut out_len = 0_usize;
        // SAFETY: `self.raw` is the live handle from `new`; `prompt` and
        // `raw_out` are live buffers whose exact lengths are passed alongside
        // their pointers; `out_len` is a valid, writable local.
        let status = unsafe {
            llama_cpp_sys_2::llama_rs_mtp_speculative_draft(
                self.raw.as_ptr(),
                n_past,
                id_last.0,
                prompt.as_ptr(),
                prompt.len(),
                raw_out.as_mut_ptr(),
                raw_out.len(),
                &raw mut out_len,
            )
        };
        if status == llama_cpp_sys_2::LLAMA_RS_STATUS_ALLOCATION_FAILED {
            return Err(MtpSpeculativeError::DraftOverflow);
        }
        status_to_result(status)?;

        // `Vec::truncate` silently does nothing when asked for more elements
        // than the vector holds, so an oversize length must be rejected here
        // rather than returned as a full buffer of partly meaningless tokens.
        if out_len > raw_out.len() {
            return Err(MtpSpeculativeError::DraftOverflow);
        }
        raw_out.truncate(out_len);
        Ok(raw_out.into_iter().map(LlamaToken).collect())
    }

    /// Notify llama.cpp how many draft tokens the target context accepted.
    ///
    /// # Errors
    ///
    /// Returns [`MtpSpeculativeError::Status`] if llama.cpp rejects the call.
    pub fn accept(&mut self, n_accepted: u16) -> Result<(), MtpSpeculativeError> {
        // SAFETY: `self.raw` is the live handle from `new`.
        let status = unsafe {
            llama_cpp_sys_2::llama_rs_mtp_speculative_accept(self.raw.as_ptr(), n_accepted)
        };
        status_to_result(status)
    }
}

impl Drop for MtpSpeculative<'_> {
    /// Release the llama.cpp speculative handle.
    ///
    /// This runs before the owned contexts are dropped, since Rust drops a
    /// value's fields only after its `Drop::drop` returns.
    fn drop(&mut self) {
        let handle = self.raw.as_ptr();
        // SAFETY: `handle` is the non-null pointer stored at construction and
        // this is the only place it is freed.
        unsafe {
            llama_cpp_sys_2::llama_rs_mtp_speculative_free(handle);
        }
    }
}

/// Copy the raw token ids out of a slice of `LlamaToken` wrappers.
fn tokens_to_raw(tokens: &[LlamaToken]) -> Vec<llama_cpp_sys_2::llama_token> {
    let mut raw_ids = Vec::with_capacity(tokens.len());
    for token in tokens {
        raw_ids.push(token.0);
    }
    raw_ids
}

/// Turn a raw wrapper status into a `Result`, keeping the code on failure.
fn status_to_result(status: llama_cpp_sys_2::llama_rs_status) -> Result<(), MtpSpeculativeError> {
    if !status_is_ok(status) {
        return Err(MtpSpeculativeError::Status(status));
    }
    Ok(())
}
Loading
Loading