Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions LLama/Batched/BatchedExecutor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,12 @@ public sealed class BatchedExecutor
/// The <see cref="LLamaWeights"/> this executor is using
/// </summary>
public LLamaWeights Model { get; }


/// <summary>
/// The optional <see cref="MtmdWeights"/> this executor is using
/// </summary>
public MtmdWeights? ClipModel { get; }

/// <summary>
/// Get the number of tokens in the batch, waiting for <see cref="Infer"/> to be called
/// </summary>
Expand Down Expand Up @@ -79,21 +84,15 @@ public int BatchedTokenCount
/// </summary>
/// <param name="model">The model to use</param>
/// <param name="contextParams">Parameters to create a new context</param>
public BatchedExecutor(LLamaWeights model, IContextParams contextParams)
: this(model, contextParams, null)
{
}

public BatchedExecutor(LLamaWeights model, IContextParams contextParams, MtmdWeights? clipModel)
/// <param name="clipModel">Clip model to use for multimodal capabilities</param>
public BatchedExecutor(LLamaWeights model, IContextParams contextParams, MtmdWeights? clipModel = null)
{
Model = model;
// Create the context for this executor from the supplied parameters
Context = model.CreateContext(contextParams);
// Optional multimodal weights; stays null when the caller does not supply a clip model
ClipModel = clipModel;
Epoch = 1;
}

public MtmdWeights? ClipModel { get; }

/// <summary>
/// Start a new <see cref="Conversation"/>
/// </summary>
Expand Down
5 changes: 1 addition & 4 deletions LLama/LLamaEmbedder.cs
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
using LLama.Exceptions;
using LLama.Native;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Logging;
using static System.Net.Mime.MediaTypeNames;

namespace LLama;

Expand Down Expand Up @@ -79,7 +76,7 @@ public async Task<IReadOnlyList<float[]>> GetEmbeddings(string input, Cancellati
Context.Dispose();

Context = _weights.CreateContext(_params, _logger);
NativeApi.llama_set_embeddings(Context.NativeHandle, true);
Context.NativeHandle.SetEmbeddings(true);

// Add all of the tokens to the batch
var tokens = Context.Tokenize(input, special: true);
Expand Down
6 changes: 1 addition & 5 deletions LLama/LLamaReranker.cs
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using System.Xml.Linq;
using LLama.Abstractions;
using LLama.Exceptions;
using LLama.Native;
Expand Down Expand Up @@ -44,7 +40,7 @@ public LLamaReranker(LLamaWeights weights, IContextParams @params, ILogger? logg
if (@params.PoolingType != LLamaPoolingType.Rank)
throw new NotSupportedException("Computing rank score, PoolingType must be equal to LLamaPoolingType.Rank");
Context = weights.CreateContext(@params, logger);
NativeApi.llama_set_embeddings(Context.NativeHandle, true);
Context.NativeHandle.SetEmbeddings(true);
}

/// <inheritdoc />
Expand Down
2 changes: 1 addition & 1 deletion LLama/LLamaSharp.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
</ItemGroup>

<PropertyGroup>
<BinaryReleaseId>ff4affb4c1aa7eb4_v3</BinaryReleaseId>
<BinaryReleaseId>73c9eb8ceda397b</BinaryReleaseId>
</PropertyGroup>

<PropertyGroup>
Expand Down
7 changes: 6 additions & 1 deletion LLama/Native/LLamaFtype.cs
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,12 @@ public enum LLamaFtype
/// except 1d tensors
/// </summary>
LLAMA_FTYPE_MOSTLY_MXFP4_MOE = 38,


/// <summary>
/// Except 1d tensors
/// </summary>
LLAMA_FTYPE_MOSTLY_NVFP4 = 39,

/// <summary>
/// File type was not specified
/// </summary>
Expand Down
10 changes: 10 additions & 0 deletions LLama/Native/LLamaModelQuantizeParams.cs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,16 @@ public bool keep_split
}
private sbyte _keep_split;

/// <summary>
/// When set, only calculate and show the final quantization size; no quantization is performed
/// </summary>
public bool dry_run
{
    get
    {
        // The flag is stored as a signed byte: any non-zero value reads back as true
        return _dry_run != 0;
    }
    set
    {
        _dry_run = value ? (sbyte)1 : (sbyte)0;
    }
}
private sbyte _dry_run;

/// <summary>
/// pointer to importance matrix data
/// </summary>
Expand Down
20 changes: 0 additions & 20 deletions LLama/Native/LoraAdapter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,30 +22,10 @@ public class LoraAdapter
/// </summary>
internal IntPtr Pointer { get; }

/// <summary>
/// Indicates if this adapter is currently loaded (false once it has been unloaded)
/// </summary>
internal bool Loaded { get; private set; }

internal LoraAdapter(SafeLlamaModelHandle model, string path, IntPtr nativePtr)
{
Model = model;
Path = path;
Pointer = nativePtr;
Loaded = true;
}

/// <summary>
/// Unload this adapter
/// </summary>
public void Unload()
{
Loaded = false;
llama_adapter_lora_free(Pointer);

// Manually free a LoRA adapter. Loaded adapters will be freed when the associated model is deleted
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
[Obsolete("adapters are now freed together with the associated model")]
static extern void llama_adapter_lora_free(IntPtr adapter);
}
}
2 changes: 1 addition & 1 deletion LLama/Native/NativeApi.Mtmd.cs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ internal static unsafe void mtmd_bitmap_set_id(SafeMtmdEmbed bitmap, string? id)
// tokenize ----------------------------------------------------------

/// <summary>
/// Native text structure consumed by <see cref="mtmd_tokenize"/>.
/// Native text structure consumed by <see cref="NativeApi.mtmd_tokenize(LLama.Native.SafeMtmdModelHandle,System.IntPtr,in LLama.Native.NativeApi.mtmd_input_text_native,System.IntPtr[],System.UIntPtr)"/>.
/// </summary>
internal unsafe struct mtmd_input_text_native
{
Expand Down
36 changes: 14 additions & 22 deletions LLama/Native/NativeApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -132,34 +132,14 @@ public static void llama_empty_call()
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern unsafe nuint llama_state_seq_load_file(SafeLLamaContextHandle ctx, string filepath, LLamaSeqId dest_seq_id, LLamaToken* tokens_out, nuint n_token_capacity, out nuint n_token_count_out);

/// <summary>
/// Set whether to use causal attention or not. If set to true, the model will only attend to the past tokens
/// </summary>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_set_causal_attn(SafeLLamaContextHandle ctx, [MarshalAs(UnmanagedType.U1)] bool causalAttn);

/// <summary>
/// Set whether the context outputs embeddings or not
/// </summary>
/// <param name="ctx"></param>
/// <param name="embeddings">If true, embeddings will be returned but logits will not</param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_set_embeddings(SafeLLamaContextHandle ctx, [MarshalAs(UnmanagedType.U1)] bool embeddings);


/// <summary>
/// Set abort callback
/// </summary>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_set_abort_callback(SafeLLamaContextHandle ctx, IntPtr /* ggml_abort_callback */ abortCallback, IntPtr abortCallbackData);

/// <summary>
/// Get the n_seq_max for this context
/// </summary>
/// <param name="ctx"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern uint llama_n_seq_max(SafeLLamaContextHandle ctx);

/// <summary>
/// Get all output token embeddings.
/// When pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model, the embeddings for which
Expand Down Expand Up @@ -515,6 +495,18 @@ public static extern unsafe LLamaParamsFitStatus llama_params_fit(
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern long llama_time_us();


/* Directly exposes `ggml_tensor` and `gguf_context` which LLamaSharp does not currently support!

typedef void (* llama_model_set_tensor_data_t) (struct ggml_tensor * tensor, void* userdata);

// Create a new model from GGUF metadata as well as a function to set the tensor data
// - tensors are created as GGML_TYPE_F32 by default,
override by adding a tensor with the same name but a different type to the context
LLAMA_API struct llama_model * llama_model_init_from_user(
struct gguf_context * metadata,
llama_model_set_tensor_data_t set_tensor_data, // function to initialize tensor data with
void* set_tensor_data_ud, // userdata for function
struct llama_model_params params);
*/
}
}
61 changes: 55 additions & 6 deletions LLama/Native/SafeLLamaContextHandle.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public sealed class SafeLLamaContextHandle
/// <summary>
/// Get the number of maximum sequences allowed
/// </summary>
public uint MaxSeq => NativeApi.llama_n_seq_max(this);
public uint MaxSeq => llama_n_seq_max(this);

/// <summary>
/// Get or set the number of threads used for generation of a single token.
Expand Down Expand Up @@ -355,6 +355,7 @@ static SafeLLamaContextHandle()
/// <param name="buf_size"></param>
/// <returns>The length of the value string (on success) -1 otherwise </returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
// ReSharper disable once InconsistentNaming
private static extern int llama_adapter_meta_val_str(IntPtr adapter, string key, StringBuilder buf, UIntPtr buf_size);

/// <summary>
Expand All @@ -374,6 +375,7 @@ static SafeLLamaContextHandle()
/// <param name="buf_size"></param>
/// <returns>The length of string i.e meta key (on success) -1 otherwise</returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
// ReSharper disable once InconsistentNaming
private static extern int llama_adapter_meta_key_by_index(IntPtr adapter, int i, StringBuilder buf, UIntPtr buf_size);

/// <summary>
Expand All @@ -385,6 +387,7 @@ static SafeLLamaContextHandle()
/// <param name="buf_size"></param>
/// <returns>The length of value string (on success) -1 otherwise</returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
// ReSharper disable once InconsistentNaming
private static extern int llama_adapter_meta_val_by_index(IntPtr adapter, int i, StringBuilder buf, UIntPtr buf_size);

/// <summary>
Expand Down Expand Up @@ -424,6 +427,56 @@ static SafeLLamaContextHandle()
/// <param name="warmup"></param>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern void llama_set_warmup(SafeLLamaContextHandle ctx, [MarshalAs(UnmanagedType.U1)] bool warmup);

/// <summary>
/// Set whether to use causal attention or not. If set to true, the model will only attend to the past tokens
/// </summary>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern void llama_set_causal_attn(SafeLLamaContextHandle ctx, [MarshalAs(UnmanagedType.U1)] bool causalAttn);

/// <summary>
/// Set whether the context outputs embeddings or not
/// </summary>
/// <param name="ctx"></param>
/// <param name="embeddings">If true, embeddings will be returned but logits will not</param>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern void llama_set_embeddings(SafeLLamaContextHandle ctx, [MarshalAs(UnmanagedType.U1)] bool embeddings);

/// <summary>
/// Get the n_seq_max for this context
/// </summary>
/// <param name="ctx"></param>
/// <returns></returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern uint llama_n_seq_max(SafeLLamaContextHandle ctx);
#endregion

#region Setters
/// <summary>
/// Switch warmup mode on or off for this context.
/// While it is on, all model tensors are activated during <see cref="Decode(LLamaBatch)"/> to load and cache their weights.
/// </summary>
/// <param name="value">Whether the model should be in warmup mode</param>
public void SetWarmup(bool value) => llama_set_warmup(this, value);

/// <summary>
/// Choose whether this context uses causal attention.
/// With causal attention enabled the model will only attend to the past tokens.
/// </summary>
/// <param name="value">Whether causal attention should be used</param>
public void SetCausalAttention(bool value) => llama_set_causal_attn(this, value);

/// <summary>
/// Choose whether this context outputs embeddings
/// </summary>
/// <param name="value">If true, embeddings will be returned but logits will not</param>
public void SetEmbeddings(bool value) => llama_set_embeddings(this, value);
#endregion

#region LoRA
Expand All @@ -434,14 +487,10 @@ static SafeLLamaContextHandle()
/// <exception cref="ArgumentException"></exception>
public void SetLoraAdapters(params Span<(LoraAdapter Adapter, float Scale)> adapters)
{
// Check adapters are all valid
// Check adapters are all valid and attached to this model
foreach (var adapter in adapters)
{
if (adapter.Adapter.Model != ModelHandle)
throw new ArgumentException("Cannot add LoRA adapter which was loaded for a different model");
if (!adapter.Adapter.Loaded)
throw new ArgumentException("Cannot add LoRA adapter which has been unloaded");
}

// Copy data into buffers
Span<IntPtr> adapterPtrs = stackalloc IntPtr[adapters.Length];
Expand Down
6 changes: 3 additions & 3 deletions LLama/Native/SafeLlamaModelHandle.cs
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ static SafeLlamaModelHandle()
private static extern unsafe byte* llama_model_chat_template(SafeLlamaModelHandle model, string? name);

/// <summary>
/// Load the model from a file
/// Load a model from a file
/// If the file is split into multiple parts, the file name must follow this pattern: {name}-%05d-of-%05d.gguf
/// If the split file name does not follow this pattern, use llama_model_load_from_splits
/// </summary>
Expand All @@ -186,7 +186,7 @@ static SafeLlamaModelHandle()
private static extern SafeLlamaModelHandle llama_model_load_from_file(string path, LLamaModelParams @params);

/// <summary>
/// Load the model from multiple splits (support custom naming scheme)
/// Load a model from multiple splits (support custom naming scheme)
/// The paths must be in the correct order
/// </summary>
/// <returns></returns>
Expand Down Expand Up @@ -460,7 +460,7 @@ private static int llama_model_meta_val_str(SafeLlamaModelHandle model, string k
/// <param name="i"></param>
/// <returns></returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern string? llama_model_cls_label(SafeLlamaModelHandle model, uint i);
private static extern IntPtr /* char* */ llama_model_cls_label(SafeLlamaModelHandle model, uint i);
#endregion

#region LoRA
Expand Down
2 changes: 1 addition & 1 deletion llama.cpp
Loading