From d80d038a0edacfddfdada8e30c52a922fed21a04 Mon Sep 17 00:00:00 2001 From: Krisbiradar Date: Fri, 5 Sep 2025 00:40:00 +0530 Subject: [PATCH 01/35] Update LLamaModelParams.cs --- LLama/Native/LLamaModelParams.cs | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/LLama/Native/LLamaModelParams.cs b/LLama/Native/LLamaModelParams.cs index acb024852..4826c96b7 100644 --- a/LLama/Native/LLamaModelParams.cs +++ b/LLama/Native/LLamaModelParams.cs @@ -100,7 +100,16 @@ public bool check_tensors set => _check_tensors = Convert.ToSByte(value); } private sbyte _check_tensors; - + + /// + /// use extra buffer types (used for weight repacking) + /// + public bool use_extra_bufts + { + readonly get => Convert.ToBoolean(_use_extra_bufts); + set => _use_extra_bufts = Convert.ToSByte(value); + } + private sbyte _use_extra_bufts; /// /// Create a LLamaModelParams with default values /// From da017898f26cf8144d31d13b1814e8b05def4b5f Mon Sep 17 00:00:00 2001 From: Krisbiradar Date: Thu, 11 Sep 2025 10:43:27 +0530 Subject: [PATCH 02/35] Add Flash Attention and diffusion model support Introduces LLamaFlashAttentionType enum and integrates flash attention configuration into LLamaContextParams. Adds support for diffusion-based models in SafeLlamaModelHandle. Updates NativeApi and SafeLLamaContextHandle with new adapter metadata and sequence state methods. Syncs llama.cpp submodule. --- LLama/Native/LLamaContextParams.cs | 5 ++ LLama/Native/LLamaFlashAttentionType.cs | 19 +++++ LLama/Native/LLamaFtype.cs | 7 +- LLama/Native/NativeApi.cs | 95 +++++++++++++++++++------ LLama/Native/SafeLLamaContextHandle.cs | 41 +++++++++++ LLama/Native/SafeLLamaSamplerHandle.cs | 2 +- LLama/Native/SafeLlamaModelHandle.cs | 11 ++- llama.cpp | 2 +- 8 files changed, 158 insertions(+), 24 deletions(-) create mode 100644 LLama/Native/LLamaFlashAttentionType.cs diff --git a/LLama/Native/LLamaContextParams.cs b/LLama/Native/LLamaContextParams.cs index 76f5d6c77..6dea4de47 100644 --- a/LLama/Native/LLamaContextParams.cs +++ b/LLama/Native/LLamaContextParams.cs @@ -64,6 +64,11 @@ public struct LLamaContextParams /// Attention type to use for embeddings /// public LLamaAttentionType attention_type; + + /// + /// when to enable Flash Attention + /// + public LLamaFlashAttentionType llama_flash_attn_type; /// /// RoPE base frequency, 0 = from model diff --git a/LLama/Native/LLamaFlashAttentionType.cs b/LLama/Native/LLamaFlashAttentionType.cs new file mode 100644 index 000000000..7138dea93 --- /dev/null +++ b/LLama/Native/LLamaFlashAttentionType.cs @@ -0,0 +1,19 @@ +namespace LLama.Native; +/// +/// flash_attn_type +/// +public enum LLamaFlashAttentionType +{ + /// + /// attention type auto + /// + LLAMA_FLASH_ATTENTION_TYPE_AUTO = -1, + /// + /// attention disabled + /// + LLAMA_FLASH_ATTENTION_TYPE_DISABLED = 0, + /// + /// attention enabled + /// + LLAMA_FLASH_ATTENTION_TYPE_ENABLED = 1, +} \ No newline at end of file diff --git a/LLama/Native/LLamaFtype.cs b/LLama/Native/LLamaFtype.cs index 705f8032e..813bad1ae 100644 --- a/LLama/Native/LLamaFtype.cs +++ b/LLama/Native/LLamaFtype.cs @@ -201,7 +201,12 @@ public enum LLamaFtype /// except 1d tensors /// LLAMA_FTYPE_MOSTLY_TQ2_0 = 37, - + + /// + /// except 1d tensors + /// + LLAMA_FTYPE_MOSTLY_MXFP4_MOE = 38, + /// /// File type was not specified /// diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs index db9e928bd..0a5ad6003 100644 --- a/LLama/Native/NativeApi.cs +++ b/LLama/Native/NativeApi.cs @@ -99,7 +99,8 @@ public static void llama_empty_call() 
/// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [return: MarshalAs(UnmanagedType.U1)] - public static extern bool llama_state_load_file(SafeLLamaContextHandle ctx, string path_session, LLamaToken[] tokens_out, ulong n_token_capacity, out ulong n_token_count_out); + public static extern bool llama_state_load_file(SafeLLamaContextHandle ctx, string path_session, + LLamaToken[] tokens_out, ulong n_token_capacity, out ulong n_token_count_out); /// /// Save session file @@ -111,25 +112,29 @@ public static void llama_empty_call() /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [return: MarshalAs(UnmanagedType.U1)] - public static extern bool llama_state_save_file(SafeLLamaContextHandle ctx, string path_session, LLamaToken[] tokens, ulong n_token_count); + public static extern bool llama_state_save_file(SafeLLamaContextHandle ctx, string path_session, + LLamaToken[] tokens, ulong n_token_count); /// /// Saves the specified sequence as a file on specified filepath. Can later be loaded via /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern unsafe nuint llama_state_seq_save_file(SafeLLamaContextHandle ctx, string filepath, LLamaSeqId seq_id, LLamaToken* tokens, nuint n_token_count); + public static extern unsafe nuint llama_state_seq_save_file(SafeLLamaContextHandle ctx, string filepath, + LLamaSeqId seq_id, LLamaToken* tokens, nuint n_token_count); /// /// Loads a sequence saved as a file via into the specified sequence /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern unsafe nuint llama_state_seq_load_file(SafeLLamaContextHandle ctx, string filepath, LLamaSeqId dest_seq_id, LLamaToken* tokens_out, nuint n_token_capacity, out nuint n_token_count_out); + public static extern unsafe nuint llama_state_seq_load_file(SafeLLamaContextHandle ctx, string filepath, + LLamaSeqId dest_seq_id, LLamaToken* tokens_out, nuint n_token_capacity, out nuint n_token_count_out); /// /// Set whether to use causal attention or not. If set to true, the model will only attend to the past tokens /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern void llama_set_causal_attn(SafeLLamaContextHandle ctx, [MarshalAs(UnmanagedType.U1)] bool causalAttn); + public static extern void llama_set_causal_attn(SafeLLamaContextHandle ctx, + [MarshalAs(UnmanagedType.U1)] bool causalAttn); /// /// Set whether the context outputs embeddings or not @@ -137,13 +142,15 @@ public static void llama_empty_call() /// /// If true, embeddings will be returned but logits will not [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern void llama_set_embeddings(SafeLLamaContextHandle ctx, [MarshalAs(UnmanagedType.U1)] bool embeddings); + public static extern void llama_set_embeddings(SafeLLamaContextHandle ctx, + [MarshalAs(UnmanagedType.U1)] bool embeddings); /// /// Set abort callback /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern void llama_set_abort_callback(SafeLlamaModelHandle ctx, IntPtr /* ggml_abort_callback */ abortCallback, IntPtr abortCallbackData); + public static extern void llama_set_abort_callback(SafeLlamaModelHandle ctx, + IntPtr /* ggml_abort_callback */ abortCallback, IntPtr abortCallbackData); /// /// Get the n_seq_max for this context @@ -175,12 +182,15 @@ public static void llama_empty_call() /// A buffer to hold the output formatted prompt. 
The recommended alloc size is 2 * (total number of characters of all messages) /// The size of the allocated buffer /// The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template. - public static unsafe int llama_chat_apply_template(byte* tmpl, LLamaChatMessage* chat, nuint n_msg, [MarshalAs(UnmanagedType.U1)] bool add_ass, byte* buf, int length) + public static unsafe int llama_chat_apply_template(byte* tmpl, LLamaChatMessage* chat, nuint n_msg, + [MarshalAs(UnmanagedType.U1)] bool add_ass, byte* buf, int length) { return internal_llama_chat_apply_template(tmpl, chat, n_msg, add_ass, buf, length); - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_chat_apply_template")] - static extern int internal_llama_chat_apply_template(byte* tmpl, LLamaChatMessage* chat, nuint n_msg, [MarshalAs(UnmanagedType.U1)] bool add_ass, byte* buf, int length); + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, + EntryPoint = "llama_chat_apply_template")] + static extern int internal_llama_chat_apply_template(byte* tmpl, LLamaChatMessage* chat, nuint n_msg, + [MarshalAs(UnmanagedType.U1)] bool add_ass, byte* buf, int length); } /// @@ -215,7 +225,8 @@ public static unsafe int llama_chat_apply_template(byte* tmpl, LLamaChatMessage* /// User can skip up to 'lstrip' leading spaces before copying (useful when encoding/decoding multiple tokens with 'add_space_prefix') /// If true, special tokens are rendered in the output /// The length written, or if the buffer is too small a negative that indicates the length required - public static int llama_token_to_piece(SafeLlamaModelHandle.Vocabulary vocab, LLamaToken llamaToken, Span buffer, int lstrip, bool special) + public static int llama_token_to_piece(SafeLlamaModelHandle.Vocabulary vocab, LLamaToken llamaToken, + Span buffer, int lstrip, bool special) { // Handle invalid tokens if ((int)llamaToken < 0) @@ -225,12 +236,14 @@ public static int llama_token_to_piece(SafeLlamaModelHandle.Vocabulary vocab, LL { fixed (byte* bufferPtr = buffer) { - return llama_token_to_piece_native(vocab.VocabNative, llamaToken, bufferPtr, buffer.Length, lstrip, special); + return llama_token_to_piece_native(vocab.VocabNative, llamaToken, bufferPtr, buffer.Length, lstrip, + special); } } [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_token_to_piece")] - static extern unsafe int llama_token_to_piece_native(LLamaVocabNative* model, LLamaToken llamaToken, byte* buffer, int length, int lstrip, [MarshalAs(UnmanagedType.U1)] bool special); + static extern unsafe int llama_token_to_piece_native(LLamaVocabNative* model, LLamaToken llamaToken, + byte* buffer, int length, int lstrip, [MarshalAs(UnmanagedType.U1)] bool special); } /// @@ -247,7 +260,9 @@ public static int llama_token_to_piece(SafeLlamaModelHandle.Vocabulary vocab, LL /// Returns a negative number on failure - the number of tokens that would have been returned. 
Returns INT32_MIN on overflow (e.g., tokenization result size exceeds int32_t limit) /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - internal static extern unsafe int llama_tokenize(LLamaVocabNative* model, byte* text, int text_len, LLamaToken* tokens, int n_max_tokens, [MarshalAs(UnmanagedType.U1)] bool add_special, [MarshalAs(UnmanagedType.U1)] bool parse_special); + internal static extern unsafe int llama_tokenize(LLamaVocabNative* model, byte* text, int text_len, + LLamaToken* tokens, int n_max_tokens, [MarshalAs(UnmanagedType.U1)] bool add_special, + [MarshalAs(UnmanagedType.U1)] bool parse_special); /// /// Convert the provided tokens into text (inverse of llama_tokenize()). @@ -261,7 +276,8 @@ public static int llama_token_to_piece(SafeLlamaModelHandle.Vocabulary vocab, LL /// unparse_special If true, special tokens are rendered in the output. /// Returns the number of chars/bytes on success, no more than textLengthMax. Returns a negative number on failure - the number of chars/bytes that would have been returned. [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - internal static extern unsafe int llama_detokenize(LLamaVocabNative* model, LLamaToken* tokens, int nTokens, byte* textOut, int textLengthMax, bool removeSpecial, bool unparseSpecial); + internal static extern unsafe int llama_detokenize(LLamaVocabNative* model, LLamaToken* tokens, int nTokens, + byte* textOut, int textLengthMax, bool removeSpecial, bool unparseSpecial); /// /// Register a callback to receive llama log messages @@ -272,7 +288,7 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback) { NativeLogConfig.llama_log_set(logCallback); } - + /// /// Allocates a batch of tokens on the heap /// Each token can be assigned up to n_seq_max sequence ids @@ -311,7 +327,8 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback) /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern unsafe int llama_apply_adapter_cvec(SafeLLamaContextHandle ctx, float* data, nuint len, int n_embd, int il_start, int il_end); + public static extern unsafe int llama_apply_adapter_cvec(SafeLLamaContextHandle ctx, float* data, nuint len, + int n_embd, int il_start, int il_end); /// /// Build a split GGUF final path for this chunk. @@ -324,7 +341,8 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback) /// /// Returns the split_path length. [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_split_path(string split_path, nuint maxlen, string path_prefix, int split_no, int split_count); + public static extern int llama_split_path(string split_path, nuint maxlen, string path_prefix, int split_no, + int split_count); /// /// Extract the path prefix from the split_path if and only if the split_no and split_count match. @@ -337,7 +355,8 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback) /// /// Returns the split_prefix length. 
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-    public static extern int llama_split_prefix(string split_prefix, nuint maxlen, string split_path, int split_no, int split_count);
+    public static extern int llama_split_prefix(string split_prefix, nuint maxlen, string split_path, int split_no,
+        int split_count);

     //[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
     //todo: public static void llama_attach_threadpool(SafeLLamaContextHandle ctx, ggml_threadpool_t threadpool, ggml_threadpool_t threadpool_batch);
@@ -380,5 +399,41 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
     /// Name of the buffer type
     [DllImport(ggmlBaseLibraryName, CallingConvention = CallingConvention.Cdecl)]
     public static extern IntPtr ggml_backend_buft_name(IntPtr buft);
+
+    /// 
+    /// Get the exact size in bytes needed to copy the state of a single sequence, using the given flags
+    /// 
+    /// 
+    /// 
+    /// 
+    /// The number of bytes required to store the sequence state
+    [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+    public static extern UIntPtr llama_state_seq_get_size_ext(IntPtr ctx, int seq_id, uint flags);
+
+    /// 
+    /// Copy the state of a single sequence into the specified buffer
+    /// 
+    /// 
+    /// 
+    /// 
+    /// 
+    /// 
+    /// The number of bytes copied into the buffer
+    [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+    public static extern UIntPtr llama_state_seq_get_data_ext(IntPtr ctx, [Out] byte[] dst, UIntPtr size,
+        int seq_id, uint flags);
+
+    /// 
+    /// Load the state of a single sequence from the specified buffer into the context
+    /// 
+    /// 
+    /// 
+    /// 
+    /// 
+    /// 
+    /// The number of bytes read from the buffer, or 0 on failure
+    [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+    public static extern UIntPtr llama_state_seq_set_data_ext(IntPtr ctx, byte[] src, UIntPtr size, int dest_seq_id,
+        uint flags);
 }
-}
+}
\ No newline at end of file
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index e26619b26..10e0aa050 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -341,6 +341,47 @@ static SafeLLamaContextHandle()
     [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
     private static extern int llama_set_adapter_lora(SafeLLamaContextHandle context, IntPtr adapter, float scale);

+    /// 
+    /// Get metadata value as a string by key name
+    /// 
+    /// 
+    /// 
+    /// 
+    /// 
+    /// 
+    [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+    private static extern int llama_adapter_meta_val_str(IntPtr adapter, string key, StringBuilder buf, UIntPtr buf_size);
+
+    /// 
+    /// Get the number of metadata key value pairs
+    /// 
+    /// 
+    /// 
+    [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+    private static extern int llama_adapter_meta_count(IntPtr adapter);
+
+    /// 
+    /// Get metadata key name by index
+    /// 
+    /// 
+    /// 
+    /// 
+    /// 
+    /// 
+    [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+    private static extern int llama_adapter_meta_key_by_index(IntPtr adapter, int i, StringBuilder buf, UIntPtr buf_size);
+
+    /// 
+    /// Get metadata key value by index
+    /// 
+    /// 
+    /// 
+    /// 
+    /// 
+    /// 
+    [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+    private static extern int llama_adapter_meta_val_by_index(IntPtr adapter, int i, StringBuilder buf, UIntPtr buf_size);
+
     [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
     private static extern int llama_rm_adapter_lora(SafeLLamaContextHandle context, IntPtr adapter);

diff --git a/LLama/Native/SafeLLamaSamplerHandle.cs b/LLama/Native/SafeLLamaSamplerHandle.cs
index bad1a1974..a113e1694 100644
--- a/LLama/Native/SafeLLamaSamplerHandle.cs
+++ 
b/LLama/Native/SafeLLamaSamplerHandle.cs
@@ -616,7 +616,7 @@ static extern unsafe IntPtr llama_sampler_init_logit_bias(

     // This is a tricky method to work with!
     // It can't return a handle, because that would create a second handle to these resources.
-    // Instead It returns the raw pointer, and that can be looked up in the _samplers dictionary.
+    // Instead, it returns the raw pointer, and that can be looked up in the _samplers dictionary.
     [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
     private static extern IntPtr llama_sampler_chain_get(SafeLLamaSamplerChainHandle chain, int i);
     // ReSharper restore InconsistentNaming
diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs
index d335a1209..196bb1763 100644
--- a/LLama/Native/SafeLlamaModelHandle.cs
+++ b/LLama/Native/SafeLlamaModelHandle.cs
@@ -80,7 +80,12 @@ public sealed class SafeLlamaModelHandle
     /// Returns true if the model is recurrent (like Mamba, RWKV, etc.)
     /// 
     public bool IsRecurrent => llama_model_is_recurrent(this);
-
+
+    /// 
+    /// Returns true if the model is diffusion based (like LLaDA, Dream, etc.)
+    /// 
+    public bool IsDiffusion => llama_model_is_diffusion(this);
+
     /// 
     /// Get a description of this model
     /// 
@@ -424,6 +429,10 @@ private static int llama_model_meta_val_str(SafeLlamaModelHandle model, string k
     [return: MarshalAs(UnmanagedType.U1)]
     private static extern bool llama_model_is_recurrent(SafeLlamaModelHandle model);

+    [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+    [return: MarshalAs(UnmanagedType.U1)]
+    private static extern bool llama_model_is_diffusion(SafeLlamaModelHandle model);
+
     [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
     private static extern unsafe LLamaVocabNative* llama_model_get_vocab(SafeLlamaModelHandle model);

diff --git a/llama.cpp b/llama.cpp
index 11dd5a44e..86587da03 160000
--- a/llama.cpp
+++ b/llama.cpp
@@ -1 +1 @@
-Subproject commit 11dd5a44eb180e1d69fac24d3852b5222d66fb7f
+Subproject commit 86587da03bd78df8f4e7d8b111a0c1d2494d6ed0

From 53c8c56bb1e0112ce341ba760b99ccca08dee5d5 Mon Sep 17 00:00:00 2001
From: Krisbiradar 
Date: Sat, 13 Sep 2025 20:53:57 +0530
Subject: [PATCH 03/35] Update LLamaSharp.csproj

---
 LLama/LLamaSharp.csproj | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/LLama/LLamaSharp.csproj b/LLama/LLamaSharp.csproj
index 629d10447..96f272c14 100644
--- a/LLama/LLamaSharp.csproj
+++ b/LLama/LLamaSharp.csproj
@@ -68,7 +68,7 @@
-
+

From 20bcf74a7019e4ea5d937dad9c0e01411468b87b Mon Sep 17 00:00:00 2001
From: Krisbiradar 
Date: Sat, 13 Sep 2025 21:20:26 +0530
Subject: [PATCH 04/35] Update LLamaSharp.csproj

---
 LLama/LLamaSharp.csproj | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/LLama/LLamaSharp.csproj b/LLama/LLamaSharp.csproj
index 96f272c14..be5d09da3 100644
--- a/LLama/LLamaSharp.csproj
+++ b/LLama/LLamaSharp.csproj
@@ -57,7 +57,7 @@
- 11dd5a44eb180e
+ 86587da
@@ -68,7 +68,7 @@
-
+

From 48f109aafc24aa301b07ed18501c5b27d6c8bf84 Mon Sep 17 00:00:00 2001
From: Krisbiradar 
Date: Mon, 15 Sep 2025 23:49:27 +0530
Subject: [PATCH 05/35] bug fix: remove flash attention parameter from the
 model params

---
 LLama.KernelMemory/BuilderExtensions.cs | 1 -
 LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs | 2 --
 LLama.KernelMemory/LlamaSharpTextGenerator.cs | 2 --
 LLama.Unittest/SamplingTests.cs | 2 +-
 LLama/Abstractions/IContextParams.cs | 5 -----
 LLama/Common/ModelParams.cs | 6 ++----
LLama/Extensions/IContextParamsExtensions.cs | 1 - LLama/Native/SafeLLamaContextHandle.cs | 8 ++++---- 8 files changed, 7 insertions(+), 20 deletions(-) diff --git a/LLama.KernelMemory/BuilderExtensions.cs b/LLama.KernelMemory/BuilderExtensions.cs index 6ab04a8bc..0aae8e69d 100644 --- a/LLama.KernelMemory/BuilderExtensions.cs +++ b/LLama.KernelMemory/BuilderExtensions.cs @@ -77,7 +77,6 @@ public static IKernelMemoryBuilder WithLLamaSharpDefaults(this IKernelMemoryBuil SplitMode = config.SplitMode, BatchSize = 512, UBatchSize = 512, - FlashAttention = true, UseMemorymap = true }; diff --git a/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs b/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs index 0635015df..b5c110194 100644 --- a/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs +++ b/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs @@ -40,7 +40,6 @@ public LLamaSharpTextEmbeddingGenerator(LLamaSharpConfig config) SplitMode = config?.SplitMode ?? LLama.Native.GPUSplitMode.Layer, BatchSize = 512, UBatchSize = 512, - FlashAttention = true, UseMemorymap = true, PoolingType = LLamaPoolingType.Mean, }; @@ -68,7 +67,6 @@ public LLamaSharpTextEmbeddingGenerator(LLamaSharpConfig config, LLamaWeights we SplitMode = config?.SplitMode ?? LLama.Native.GPUSplitMode.Layer, BatchSize = 512, UBatchSize = 512, - FlashAttention = true, UseMemorymap = true, PoolingType = LLamaPoolingType.Mean, }; diff --git a/LLama.KernelMemory/LlamaSharpTextGenerator.cs b/LLama.KernelMemory/LlamaSharpTextGenerator.cs index 5c965b266..166d4ad38 100644 --- a/LLama.KernelMemory/LlamaSharpTextGenerator.cs +++ b/LLama.KernelMemory/LlamaSharpTextGenerator.cs @@ -38,7 +38,6 @@ public LlamaSharpTextGenerator(LLamaSharpConfig config) SplitMode = config?.SplitMode ?? LLama.Native.GPUSplitMode.Layer, BatchSize = 512, UBatchSize = 512, - FlashAttention = true, UseMemorymap = true }; _weights = LLamaWeights.LoadFromFile(@params); @@ -66,7 +65,6 @@ public LlamaSharpTextGenerator(LLamaWeights weights, LLamaSharpConfig config, St SplitMode = config?.SplitMode ?? LLama.Native.GPUSplitMode.Layer, BatchSize = 512, UBatchSize = 512, - FlashAttention = true, UseMemorymap = true }; _executor = executor ?? new StatelessExecutor(_weights, @params); diff --git a/LLama.Unittest/SamplingTests.cs b/LLama.Unittest/SamplingTests.cs index 615a7c79e..297641df3 100644 --- a/LLama.Unittest/SamplingTests.cs +++ b/LLama.Unittest/SamplingTests.cs @@ -104,7 +104,7 @@ public void BatchedSampling() } } - // Add " repeat" and test whether next tokens will be "this phrase forever.". + // Add " repeat" and test whether next tokens will be "this phrase forever." 
for (int i = 0; i < 4; i++)
 {
 for (int b = 0; b < batch_count; b++)
diff --git a/LLama/Abstractions/IContextParams.cs b/LLama/Abstractions/IContextParams.cs
index f80759c8a..e376258bb 100644
--- a/LLama/Abstractions/IContextParams.cs
+++ b/LLama/Abstractions/IContextParams.cs
@@ -103,11 +103,6 @@ public interface IContextParams
     /// 
     bool NoKqvOffload { get; }

-    /// 
-    /// Whether to use flash attention
-    /// 
-    bool FlashAttention { get; }
-
     /// 
     /// defragment the KV cache if holes/size > defrag_threshold, Set to <= 0 to disable (default)
     /// 
diff --git a/LLama/Common/ModelParams.cs b/LLama/Common/ModelParams.cs
index 89737faa7..532dc1a22 100644
--- a/LLama/Common/ModelParams.cs
+++ b/LLama/Common/ModelParams.cs
@@ -1,3 +1,4 @@
+using System;
 using LLama.Abstractions;
 using System.Text;
 using System.Text.Json.Serialization;
@@ -97,10 +98,7 @@ public record ModelParams
     public bool NoKqvOffload { get; set; }

     /// 
-
-    public bool FlashAttention { get; set; }
-
-    /// 
+    [Obsolete]
     public float? DefragThreshold { get; set; }

     /// 
diff --git a/LLama/Extensions/IContextParamsExtensions.cs b/LLama/Extensions/IContextParamsExtensions.cs
index 85e40f7ad..882bf7fd3 100644
--- a/LLama/Extensions/IContextParamsExtensions.cs
+++ b/LLama/Extensions/IContextParamsExtensions.cs
@@ -49,7 +49,6 @@ public static void ToLlamaContextParams(this IContextParams @params, out LLamaCo
     result.type_k = @params.TypeK ?? GGMLType.GGML_TYPE_F16;
     result.type_v = @params.TypeV ?? GGMLType.GGML_TYPE_F16;
     result.offload_kqv = !@params.NoKqvOffload;
-    result.flash_attention = @params.FlashAttention;
     result.llama_pooling_type = @params.PoolingType;
     result.attention_type = @params.AttentionType;

diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index 10e0aa050..f48e818b7 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -348,7 +348,7 @@ static SafeLLamaContextHandle()
     /// 
     /// 
     /// 
-    /// 
+    /// The length of the value string on success, or -1 otherwise
     [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
     private static extern int llama_adapter_meta_val_str(IntPtr adapter, string key, StringBuilder buf, UIntPtr buf_size);

@@ -356,7 +356,7 @@ static SafeLLamaContextHandle()
     /// Get the number of metadata key value pairs
     /// 
     /// 
-    /// 
+    /// The number of metadata key/value pairs
     [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
     private static extern int llama_adapter_meta_count(IntPtr adapter);

@@ -367,7 +367,7 @@ static SafeLLamaContextHandle()
     /// 
     /// 
     /// 
-    /// 
+    /// The length of the key string on success, or -1 otherwise
     [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
     private static extern int llama_adapter_meta_key_by_index(IntPtr adapter, int i, StringBuilder buf, UIntPtr buf_size);

@@ -378,7 +378,7 @@ static SafeLLamaContextHandle()
     /// 
     /// 
     /// 
-    /// 
+    /// The length of the value string on success, or -1 otherwise
     [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
     private static extern int llama_adapter_meta_val_by_index(IntPtr adapter, int i, StringBuilder buf, UIntPtr buf_size);

From 424a7360d86efff0e437514aa705385524e725b3 Mon Sep 17 00:00:00 2001
From: Martin Evans 
Date: Sun, 5 Oct 2025 08:36:01 +0100
Subject: [PATCH 06/35] Fixed some failing tests; it looks like there's a min
 context size, which is why these were failing now.
(#1) --- LLama.Unittest/LLamaContextTests.cs | 4 ++-- LLama.Unittest/LLamaContextWithCustomLoggerTests.cs | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/LLama.Unittest/LLamaContextTests.cs b/LLama.Unittest/LLamaContextTests.cs index e28b55ce0..b04ee5382 100644 --- a/LLama.Unittest/LLamaContextTests.cs +++ b/LLama.Unittest/LLamaContextTests.cs @@ -13,7 +13,7 @@ public LLamaContextTests() { var @params = new ModelParams(Constants.GenerativeModelPath2) { - ContextSize = 128, + ContextSize = 512, BatchSize = 8, UBatchSize = 8, SeqMax = 1, @@ -33,7 +33,7 @@ public void Dispose() [Fact] public void CheckProperties() { - Assert.Equal(128u, _context.ContextSize); + Assert.Equal(512u, _context.ContextSize); Assert.Equal(960, _context.EmbeddingSize); Assert.Equal(49152, _context.Vocab.Count); } diff --git a/LLama.Unittest/LLamaContextWithCustomLoggerTests.cs b/LLama.Unittest/LLamaContextWithCustomLoggerTests.cs index 1d16b0481..871b6b8cd 100644 --- a/LLama.Unittest/LLamaContextWithCustomLoggerTests.cs +++ b/LLama.Unittest/LLamaContextWithCustomLoggerTests.cs @@ -30,7 +30,7 @@ public LLamaContextWithCustomLoggerTests() { var @params = new ModelParams(Constants.GenerativeModelPath2) { - ContextSize = 128, + ContextSize = 512, GpuLayerCount = Constants.CIGpuLayerCount, }; @@ -55,7 +55,7 @@ public void Dispose() [Fact] public void CheckProperties() { - Assert.Equal(128u, _context.ContextSize); + Assert.Equal(512u, _context.ContextSize); Assert.Equal(960, _context.EmbeddingSize); Assert.Equal(49152, _context.Vocab.Count); } From ff6ea954dde4b35a0144165a6670a70a63a82a72 Mon Sep 17 00:00:00 2001 From: Krisbiradar Date: Tue, 14 Oct 2025 02:25:22 +0530 Subject: [PATCH 07/35] Fix Reranker and Sampling Test Failures --- LLama.Unittest/LLamaRerankerTests.cs | 2 +- LLama.Unittest/SamplingTests.cs | 1 + LLama.Web/Common/ModelOptions.cs | 2 +- LLama/Abstractions/IContextParams.cs | 5 +++++ LLama/Common/ModelParams.cs | 3 +++ LLama/Extensions/IContextParamsExtensions.cs | 7 +++++++ 6 files changed, 18 insertions(+), 2 deletions(-) diff --git a/LLama.Unittest/LLamaRerankerTests.cs b/LLama.Unittest/LLamaRerankerTests.cs index b8dfcfa8d..534623a41 100644 --- a/LLama.Unittest/LLamaRerankerTests.cs +++ b/LLama.Unittest/LLamaRerankerTests.cs @@ -18,9 +18,9 @@ public LLamaRerankerTests(ITestOutputHelper testOutputHelper) var @params = new ModelParams(Constants.RerankingModelPath) { ContextSize = 0, + SeqMax = 1, PoolingType = LLamaPoolingType.Rank, GpuLayerCount = Constants.CIGpuLayerCount, - }; using var weights = LLamaWeights.LoadFromFile(@params); _reranker = new LLamaReranker(weights, @params); diff --git a/LLama.Unittest/SamplingTests.cs b/LLama.Unittest/SamplingTests.cs index 297641df3..5dcb7b494 100644 --- a/LLama.Unittest/SamplingTests.cs +++ b/LLama.Unittest/SamplingTests.cs @@ -25,6 +25,7 @@ public SamplingTests(ITestOutputHelper testOutputHelper) _params = new ModelParams(Constants.GenerativeModelPath2) { ContextSize = 200, BatchSize = 200, + SeqMax = 4, GpuLayerCount = Constants.CIGpuLayerCount, }; _model = LLamaWeights.LoadFromFile(_params); diff --git a/LLama.Web/Common/ModelOptions.cs b/LLama.Web/Common/ModelOptions.cs index c453aeddf..586db5611 100644 --- a/LLama.Web/Common/ModelOptions.cs +++ b/LLama.Web/Common/ModelOptions.cs @@ -102,7 +102,7 @@ public class ModelOptions public bool NoKqvOffload { get; set; } /// - public bool FlashAttention { get; set; } + public bool? 
FlashAttention { get; set; } /// public Encoding Encoding { get; set; } = Encoding.UTF8; diff --git a/LLama/Abstractions/IContextParams.cs b/LLama/Abstractions/IContextParams.cs index e376258bb..b7abed5ed 100644 --- a/LLama/Abstractions/IContextParams.cs +++ b/LLama/Abstractions/IContextParams.cs @@ -103,6 +103,11 @@ public interface IContextParams /// bool NoKqvOffload { get; } + /// + /// Whether to use flash attention + /// + bool? FlashAttention { get; } + /// /// defragment the KV cache if holes/size > defrag_threshold, Set to <= 0 to disable (default) /// diff --git a/LLama/Common/ModelParams.cs b/LLama/Common/ModelParams.cs index 532dc1a22..1b7e44308 100644 --- a/LLama/Common/ModelParams.cs +++ b/LLama/Common/ModelParams.cs @@ -96,6 +96,9 @@ public record ModelParams /// public bool NoKqvOffload { get; set; } + + /// + public bool? FlashAttention { get; set; } /// [Obsolete] diff --git a/LLama/Extensions/IContextParamsExtensions.cs b/LLama/Extensions/IContextParamsExtensions.cs index 882bf7fd3..816118524 100644 --- a/LLama/Extensions/IContextParamsExtensions.cs +++ b/LLama/Extensions/IContextParamsExtensions.cs @@ -51,6 +51,13 @@ public static void ToLlamaContextParams(this IContextParams @params, out LLamaCo result.offload_kqv = !@params.NoKqvOffload; result.llama_pooling_type = @params.PoolingType; result.attention_type = @params.AttentionType; + result.llama_flash_attn_type = @params.FlashAttention switch + { + true => LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_ENABLED, + false => LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_DISABLED, + null => LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_AUTO + }; + result.kv_unified = true; result.n_threads = Threads(@params.Threads); result.n_threads_batch = Threads(@params.BatchThreads); From 5a0e7b8af50b773778846b032660fcf5be6f7f33 Mon Sep 17 00:00:00 2001 From: SignalRT Date: Sat, 27 Sep 2025 07:52:40 +0200 Subject: [PATCH 08/35] Mtmd Implementation base --- LLama.Examples/ExampleRunner.cs | 4 +- .../Examples/BatchedExecutorLLava.cs | 91 ----- .../Examples/BatchedExecutorMtmd.cs | 126 +++++++ ...ecute.cs => MtmdInteractiveModeExecute.cs} | 83 +++-- LLama.Examples/LLama.Examples.csproj | 2 +- LLama.Unittest/Constants.cs | 6 +- LLama.Unittest/LLama.Unittest.csproj | 16 +- LLama.Unittest/MtmdExecutorTests.cs | 81 ++++ LLama.Unittest/MtmdWeightsTests.cs | 140 +++++++ .../Native/SafeLlamaModelHandleTests.cs | 32 -- .../SafeLlamaModelHandleVocabularyTests.cs | 42 --- LLama/Abstractions/ILLamaExecutor.cs | 7 +- LLama/Batched/BatchedExecutor.cs | 65 ++++ LLama/Batched/Conversation.cs | 242 ++++++++++-- LLama/Batched/ConversationExtensions.cs | 14 +- LLama/LLamaExecutorBase.cs | 135 ++++--- LLama/LLamaInstructExecutor.cs | 213 ++++++++++- LLama/LLamaInteractExecutor.cs | 301 +++++++++++---- LLama/LLamaSharp.csproj | 4 +- LLama/LLamaStatelessExecutor.cs | 6 +- LLama/LLavaWeights.cs | 137 ------- LLama/Native/LLavaImageEmbed.cs | 19 - LLama/Native/Load/NativeLibraryConfig.cs | 32 +- LLama/Native/Load/NativeLibraryUtils.cs | 2 +- LLama/Native/MtmdContextParams.cs | 148 ++++++++ LLama/Native/MtmdImageEmbed.cs | 20 + LLama/Native/NativeApi.LLava.cs | 63 ---- LLama/Native/NativeApi.Load.cs | 22 +- LLama/Native/NativeApi.Mtmd.cs | 312 ++++++++++++++++ LLama/Native/NativeApi.cs | 119 +++++- LLama/Native/SafeLlavaImageEmbedHandle.cs | 162 -------- LLama/Native/SafeLlavaModelHandle.cs | 137 ------- LLama/Native/SafeMtmdEmbed.cs | 247 +++++++++++++ LLama/Native/SafeMtmdInputChunk.cs | 150 ++++++++ LLama/Native/SafeMtmdInputChunks.cs | 103 
++++++ LLama/Native/SafeMtmdModelHandle.cs | 349 ++++++++++++++++++ LLama/Properties/InternalsVisibleTo.cs | 3 + LLama/SafeMtmdWeights.cs | 80 ++++ docs/Examples/LLavaInteractiveModeExecute.md | 129 ------- docs/Examples/MtmdInteractiveModeExecute.md | 41 ++ mkdocs.yml | 4 +- 41 files changed, 2828 insertions(+), 1061 deletions(-) delete mode 100644 LLama.Examples/Examples/BatchedExecutorLLava.cs create mode 100644 LLama.Examples/Examples/BatchedExecutorMtmd.cs rename LLama.Examples/Examples/{LlavaInteractiveModeExecute.cs => MtmdInteractiveModeExecute.cs} (59%) create mode 100644 LLama.Unittest/MtmdExecutorTests.cs create mode 100644 LLama.Unittest/MtmdWeightsTests.cs delete mode 100644 LLama.Unittest/Native/SafeLlamaModelHandleTests.cs delete mode 100644 LLama.Unittest/Native/SafeLlamaModelHandleVocabularyTests.cs delete mode 100644 LLama/LLavaWeights.cs delete mode 100644 LLama/Native/LLavaImageEmbed.cs create mode 100644 LLama/Native/MtmdContextParams.cs create mode 100644 LLama/Native/MtmdImageEmbed.cs delete mode 100644 LLama/Native/NativeApi.LLava.cs create mode 100644 LLama/Native/NativeApi.Mtmd.cs delete mode 100644 LLama/Native/SafeLlavaImageEmbedHandle.cs delete mode 100644 LLama/Native/SafeLlavaModelHandle.cs create mode 100644 LLama/Native/SafeMtmdEmbed.cs create mode 100644 LLama/Native/SafeMtmdInputChunk.cs create mode 100644 LLama/Native/SafeMtmdInputChunks.cs create mode 100644 LLama/Native/SafeMtmdModelHandle.cs create mode 100644 LLama/Properties/InternalsVisibleTo.cs create mode 100644 LLama/SafeMtmdWeights.cs delete mode 100644 docs/Examples/LLavaInteractiveModeExecute.md create mode 100644 docs/Examples/MtmdInteractiveModeExecute.md diff --git a/LLama.Examples/ExampleRunner.cs b/LLama.Examples/ExampleRunner.cs index c073cd4cd..23f07c6a1 100644 --- a/LLama.Examples/ExampleRunner.cs +++ b/LLama.Examples/ExampleRunner.cs @@ -15,7 +15,7 @@ public class ExampleRunner { "Chat Session: Automatic conversation", TalkToYourself.Run }, { "Chat Session: Chinese characters", ChatChineseGB2312.Run }, { "Executor: Interactive mode chat", InteractiveModeExecute.Run }, - { "Executor: Llava Interactive mode chat", LlavaInteractiveModeExecute.Run }, + { "Executor: Mtmd Interactive mode chat", MtmdInteractiveModeExecute.Run }, { "Executor: Instruct mode chat", InstructModeExecute.Run }, { "Executor: Stateless mode chat", StatelessModeExecute.Run }, { "Save and Load: chat session", SaveAndLoadSession.Run }, @@ -33,7 +33,7 @@ public class ExampleRunner { "Batched Executor: Save/Load", BatchedExecutorSaveAndLoad.Run }, { "Batched Executor: Fork", BatchedExecutorFork.Run }, { "Batched Executor: Rewind", BatchedExecutorRewind.Run }, - { "Batched Executor: LLava", BatchedExecutorLLava.Run }, + { "Batched Executor: Mtmd", BatchedExecutorMtmd.Run }, { "Batched Executor: BoolQ Benchmark", BatchedExecutorBoolQ.Run }, { "Batched Executor: Beam Search", BatchedExecutorBeamSearch.Run }, { "Custom Sampling Pipeline", CustomSampler.Run }, diff --git a/LLama.Examples/Examples/BatchedExecutorLLava.cs b/LLama.Examples/Examples/BatchedExecutorLLava.cs deleted file mode 100644 index a131e994e..000000000 --- a/LLama.Examples/Examples/BatchedExecutorLLava.cs +++ /dev/null @@ -1,91 +0,0 @@ -using System.Text; -using LLama.Batched; -using LLama.Common; -using LLama.Native; -using LLama.Sampling; -using Spectre.Console; - -namespace LLama.Examples.Examples; - -/// -/// Demonstrates using LLava (image embeddings) with the batched executor. 
-/// -public class BatchedExecutorLLava -{ - /// - /// How many tokens of response to generate - /// - public const int TokenCount = 64; - - public static async Task Run() - { - // Load model weights - var parameters = new ModelParams(UserSettings.GetModelPath()); - using var model = await LLamaWeights.LoadFromFileAsync(parameters); - using var llava = await LLavaWeights.LoadFromFileAsync(UserSettings.GetMMProjPath()); - - // Decide on the prompt - var prompt = model.Tokenize(AnsiConsole.Ask("Prompt (or ENTER for default):", "\nUSER: Provide a full description of the image.\nASSISTANT: "), true, false, Encoding.UTF8); - - // Get image and show it - var image = UserSettings.GetImagePath(); - AnsiConsole.Write(new CanvasImage(image)); - - // Create an executor with one conversation - using var executor = new BatchedExecutor(model, parameters); - using var conversation = executor.Create(); - - // Embed the image - SafeLlavaImageEmbedHandle embedding = null!; - await AnsiConsole - .Status() - .StartAsync("[yellow]Embedding image with CLIP[/]", async _ => - { - // ReSharper disable once AccessToDisposedClosure - embedding = llava.CreateImageEmbeddings(await File.ReadAllBytesAsync(image)); - }); - - // Pass in the image and run inference until the entire image has been processed - await AnsiConsole - .Status() - .StartAsync("[yellow]Processing image embedding with language model[/]", async _ => - { - conversation.Prompt(embedding); - while (executor.BatchedTokenCount > 0) - await executor.Infer(); - }); - - // Prompt with the text prompt - conversation.Prompt(prompt); - - // Run inference loop - var decoder = new StreamingTokenDecoder(executor.Context); - var sampler = new DefaultSamplingPipeline(); - await AnsiConsole - .Progress() - .StartAsync(async ctx => - { - var task = ctx.AddTask("Generating Response"); - task.MaxValue = TokenCount; - - // Run a normal inference loop - for (var i = 0; i < TokenCount; i++) - { - task.Increment(1); - - await executor.Infer(); - - var token = sampler.Sample(executor.Context.NativeHandle, conversation.GetSampleIndex()); - if (token.IsEndOfGeneration(executor.Context.Vocab)) - break; - - decoder.Add(token); - conversation.Prompt(token); - } - }); - - // Print final result - var str = decoder.Read(); - AnsiConsole.MarkupInterpolated($"[green]{str}[/]"); - } -} \ No newline at end of file diff --git a/LLama.Examples/Examples/BatchedExecutorMtmd.cs b/LLama.Examples/Examples/BatchedExecutorMtmd.cs new file mode 100644 index 000000000..b62f8b120 --- /dev/null +++ b/LLama.Examples/Examples/BatchedExecutorMtmd.cs @@ -0,0 +1,126 @@ +using System; +using System.Collections.Generic; +using System.IO; +using LLama.Batched; +using LLama.Common; +using LLama.Exceptions; +using LLama.Native; +using LLama.Sampling; +using Spectre.Console; + +namespace LLama.Examples.Examples; + +/// +/// Demonstrates how to evaluate an image with MTMD helpers and continue generation by +/// manually scheduling batches, similar to what the batched executor does internally. +/// +public class BatchedExecutorMtmd +{ + /// + /// Number of completion tokens to generate after sending the image prompt. + /// + public const int TokenCount = 10000; + + public static async Task Run() + { + // Load the base LLM and its clip/mtmd sidecar weights so the executor has everything it needs. 
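+            // MtmdContextParams.Default() mirrors the native mtmd helper defaults; the MediaMarker it
+            // carries is the placeholder string that tokenization later replaces with the encoded image
+            // chunks, so it has to appear somewhere in the prompt text built below.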
+ var parameters = new ModelParams(UserSettings.GetModelPath()); + using var model = await LLamaWeights.LoadFromFileAsync(parameters); + var mtmdParams = MtmdContextParams.Default(); // reuse llama.cpp defaults for helper settings + mtmdParams.UseGpu = false; + var marker = mtmdParams.MediaMarker ?? NativeApi.MtmdDefaultMarker() ?? ""; + + using var mtmd = await SafeMtmdWeights.LoadFromFileAsync(UserSettings.GetMMProjPath(), model, mtmdParams); // multimodal helper weights + + using var executor = new BatchedExecutor(model, parameters, mtmd); // drives batched token + chunk evaluation + + // Prepend the media marker so the helper knows where to inject the encoded image tokens. + var defaultPrompt = "\nUSER: Provide a full description of the image.\nASSISTANT: "; + var promptSuffix = AnsiConsole.Ask("Prompt (or ENTER for default):", defaultPrompt); + var promptText = string.Concat(marker, promptSuffix); + + var imagePath = UserSettings.GetImagePath(); + AnsiConsole.Write(new CanvasImage(imagePath)); + + var vocab = executor.Context.NativeHandle.ModelHandle.Vocab; + + // Simple low-temperature sampler keeps the demo deterministic-ish. + var sampler = new DefaultSamplingPipeline + { + Temperature = 0.1f + }; + + // Stream decoded text to the console as soon as tokens arrive. + var decoder = new StreamingTokenDecoder(executor.Context) + { + DecodeSpecialTokens = false + }; + + try + { + // Each conversation tracks its own KV cache sequence IDs. + var conversation = executor.Create(); + // enqueue the image so MtmdHelper sees it + conversation.QueueMedia(imagePath); + // schedule multimodal prompt + conversation.Prompt(promptText, addBos: true, special: true); + + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("Prompt queued with multimodal chunks. Generating response...\n"); + Console.ResetColor(); + + var remaining = TokenCount; + + // Run one decode/sampling/prompt cycle – mirrors the batched executor inner loop. 
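+            // One cycle is: Infer() decodes whatever is queued (token batches or mtmd chunks),
+            // Sample() reads this conversation's logits once RequiresSampling is true, and
+            // Prompt(token) schedules the sampled token for the next decode. Returning false
+            // stops generation early (KV cache full or an end-of-generation token was sampled).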
+ async Task ProcessNextAsync() + { + var decodeResult = await executor.Infer(); + if (decodeResult == DecodeResult.NoKvSlot) // KV cache exhausted – surface to the user + { + Console.ForegroundColor = ConsoleColor.Red; + Console.WriteLine("Insufficient KV cache space for multimodal evaluation."); + Console.ResetColor(); + return false; + } + + if (decodeResult != DecodeResult.Ok) + throw new RuntimeError($"Failed to evaluate batch: {decodeResult}."); + + if (!conversation.RequiresSampling) // another conversation may still be queued + return true; + + var token = conversation.Sample(sampler); // pull logits (or -1 for mtmd chunk) and sample + if (token.IsEndOfGeneration(vocab)) + return false; + + decoder.Add(token); + var delta = decoder.Read(); + if (!string.IsNullOrEmpty(delta)) + Console.Write(delta); + + sampler.Accept(token); // keep sampler state in sync + conversation.Prompt(token); // feed the accepted token back into the batch + remaining--; + return remaining > 0; + } + + while (remaining > 0 && await ProcessNextAsync()) // continue until EOS or budget is reached + { + } + + Console.WriteLine(); + } + catch (IOException ex) + { + Console.ForegroundColor = ConsoleColor.Red; + Console.WriteLine($"Could not load media '{imagePath}': {ex.Message}"); + Console.ResetColor(); + } + catch (RuntimeError ex) + { + Console.ForegroundColor = ConsoleColor.Red; + Console.WriteLine($"MTMD processing failed: {ex.Message}"); + Console.ResetColor(); + } + } +} diff --git a/LLama.Examples/Examples/LlavaInteractiveModeExecute.cs b/LLama.Examples/Examples/MtmdInteractiveModeExecute.cs similarity index 59% rename from LLama.Examples/Examples/LlavaInteractiveModeExecute.cs rename to LLama.Examples/Examples/MtmdInteractiveModeExecute.cs index 8cbf58dcd..ca0de3b77 100644 --- a/LLama.Examples/Examples/LlavaInteractiveModeExecute.cs +++ b/LLama.Examples/Examples/MtmdInteractiveModeExecute.cs @@ -1,3 +1,5 @@ +using System.Collections.Generic; +using System.IO; using System.Text.RegularExpressions; using LLama.Common; using Spectre.Console; @@ -6,27 +8,32 @@ namespace LLama.Examples.Examples { - // This example shows how to chat with LLaVA model with both image and text as input. + // This example shows how to chat with Mtmd model with both image and text as input. // It uses the interactive executor to inference. - public class LlavaInteractiveModeExecute + public class MtmdInteractiveModeExecute { public static async Task Run() { string multiModalProj = UserSettings.GetMMProjPath(); string modelPath = UserSettings.GetModelPath(); string modelImage = UserSettings.GetImagePath(); - const int maxTokens = 1024; + const int maxTokens = 2048; var prompt = $"{{{modelImage}}}\nUSER:\nProvide a full description of the image.\nASSISTANT:\n"; var parameters = new ModelParams(modelPath); + var mtmdParameters = MtmdContextParams.Default(); + mtmdParameters.UseGpu = false; + using var model = await LLamaWeights.LoadFromFileAsync(parameters); using var context = model.CreateContext(parameters); - - // Llava Init - using var clipModel = await LLavaWeights.LoadFromFileAsync(multiModalProj); - + + // Mtmd Init + using var clipModel = await SafeMtmdWeights.LoadFromFileAsync(multiModalProj, model, mtmdParameters ); + + var mediaMarker = mtmdParameters.MediaMarker ?? NativeApi.MtmdDefaultMarker() ?? 
""; + var ex = new InteractiveExecutor(context, clipModel); Console.ForegroundColor = ConsoleColor.Yellow; @@ -40,7 +47,7 @@ public static async Task Run() Temperature = 0.1f }, - AntiPrompts = new List { "\nUSER:" }, + AntiPrompts = new List { "\nASSISTANT:" }, MaxTokens = maxTokens }; @@ -48,30 +55,53 @@ public static async Task Run() do { - // Evaluate if we have images + // Evaluate if we have media // - var imageMatches = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value); - var imageCount = imageMatches.Count(); - var hasImages = imageCount > 0; + var mediaMatches = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value); + var mediaCount = mediaMatches.Count(); + var hasMedia = mediaCount > 0; - if (hasImages) + if (hasMedia) { - var imagePathsWithCurlyBraces = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value); - var imagePaths = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Groups[1].Value).ToList(); + var mediaPathsWithCurlyBraces = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value); + var mediaPaths = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Groups[1].Value).ToList(); - List imageBytes; + var embeds = new List(); + var imageList = new List(); + var imageExtensions = new HashSet(StringComparer.OrdinalIgnoreCase) + { + ".png", + ".jpg", + ".jpeg", + ".bmp", + ".gif", + ".webp" + }; + try { - imageBytes = imagePaths.Select(File.ReadAllBytes).ToList(); + foreach (var mediaPath in mediaPaths) + { + var extension = Path.GetExtension(mediaPath); + if (!string.IsNullOrEmpty(extension) && imageExtensions.Contains(extension)) + { + // Keep the raw image data so the caller can reuse or inspect the images later. + imageList.Add(File.ReadAllBytes(mediaPath)); + } + + var embed = clipModel.LoadMedia(mediaPath); + embeds.Add(embed); + } } catch (IOException exception) { Console.ForegroundColor = ConsoleColor.Red; Console.Write( - $"Could not load your {(imageCount == 1 ? "image" : "images")}:"); + $"Could not load your {(mediaCount == 1 ? "media" : "medias")}:"); Console.Write($"{exception.Message}"); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("Please try again."); + clipModel.ClearMedia(); break; } @@ -81,19 +111,17 @@ public static async Task Run() // https://github.com/ggerganov/llama.cpp/discussions/3620 ex.Context.NativeHandle.MemorySequenceRemove( LLamaSeqId.Zero, -1, -1 ); - int index = 0; - foreach (var path in imagePathsWithCurlyBraces) + // Replace placeholders with media markers (one marker per image) + foreach (var path in mediaPathsWithCurlyBraces) { - // First image replace to tag " : ""); + prompt = prompt.Replace(path, mediaMarker, StringComparison.Ordinal); } - Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine($"Here are the images, that are sent to the chat model in addition to your message."); Console.WriteLine(); - foreach (var consoleImage in imageBytes?.Select(bytes => new CanvasImage(bytes)) ?? 
Array.Empty()) + foreach (var consoleImage in imageList.Select(image => new CanvasImage(image.ToArray()))) { consoleImage.MaxWidth = 50; AnsiConsole.Write(consoleImage); @@ -108,10 +136,9 @@ public static async Task Run() // Initialize Images in executor // - foreach (var image in imagePaths) - { - ex.Images.Add(await File.ReadAllBytesAsync(image)); - } + ex.Embeds.Clear(); + foreach (var embed in embeds) + ex.Embeds.Add(embed); } Console.ForegroundColor = Color.White; diff --git a/LLama.Examples/LLama.Examples.csproj b/LLama.Examples/LLama.Examples.csproj index 8d70d5637..6d69bc942 100644 --- a/LLama.Examples/LLama.Examples.csproj +++ b/LLama.Examples/LLama.Examples.csproj @@ -9,7 +9,7 @@ true true - 12 + 13 1701;1702;8604;SKEXP0001;SKEXP0050;SKEXP0052;SKEXP0003 diff --git a/LLama.Unittest/Constants.cs b/LLama.Unittest/Constants.cs index d501b189b..f705f1609 100644 --- a/LLama.Unittest/Constants.cs +++ b/LLama.Unittest/Constants.cs @@ -9,9 +9,9 @@ internal static class Constants public static readonly string EmbeddingModelPath = "Models/all-MiniLM-L12-v2.Q8_0.gguf"; public static readonly string RerankingModelPath = "Models/jina-reranker-v1-tiny-en-FP16.gguf"; - public static readonly string LLavaModelPath = "Models/llava-v1.6-mistral-7b.Q3_K_XS.gguf"; - public static readonly string LLavaMmpPath = "Models/mmproj-model-f16.gguf"; - public static readonly string LLavaImage = "Models/extreme-ironing-taxi-610x427.jpg"; + public static readonly string MtmdModelPath = "Models/gemma-3-4b-it-Q4_K_M.gguf"; + public static readonly string MtmdMmpPath = "Models/gemma-mmproj-model-f16.gguf"; + public static readonly string MtmdImage = "Models/extreme-ironing-taxi-610x427.jpg"; /// /// Calculate GpuLayer Count to use in UnitTest diff --git a/LLama.Unittest/LLama.Unittest.csproj b/LLama.Unittest/LLama.Unittest.csproj index 8f9f075d8..ca3ea8854 100644 --- a/LLama.Unittest/LLama.Unittest.csproj +++ b/LLama.Unittest/LLama.Unittest.csproj @@ -52,16 +52,16 @@ jina-reranker-v1-tiny-en-FP16.gguf - - https://huggingface.co/cjpais/llava-1.6-mistral-7b-gguf/resolve/main/llava-v1.6-mistral-7b.Q3_K_XS.gguf + + https://huggingface.co/ggml-org/gemma-3-4b-it-GGUF/resolve/main/gemma-3-4b-it-Q4_K_M.gguf Models - llava-v1.6-mistral-7b.Q3_K_XS.gguf + gemma-3-4b-it-Q4_K_M.gguf - - https://huggingface.co/cjpais/llava-1.6-mistral-7b-gguf/resolve/main/mmproj-model-f16.gguf + + https://huggingface.co/ggml-org/gemma-3-4b-it-GGUF/resolve/main/mmproj-model-f16.gguf Models - mmproj-model-f16.gguf + gemma-mmproj-model-f16.gguf @@ -142,10 +142,10 @@ PreserveNewest - + PreserveNewest - + PreserveNewest diff --git a/LLama.Unittest/MtmdExecutorTests.cs b/LLama.Unittest/MtmdExecutorTests.cs new file mode 100644 index 000000000..75a96b261 --- /dev/null +++ b/LLama.Unittest/MtmdExecutorTests.cs @@ -0,0 +1,81 @@ +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using LLama.Common; +using LLama.Native; +using Microsoft.Extensions.Logging.Abstractions; +using Xunit; + +namespace LLama.Unittest; + +[Trait("Category", "NoCI")] +public class MtmdExecutorTests : IDisposable +{ + private readonly LLamaWeights _weights; + private readonly MtmdContextParams _mtmdParams; + private readonly SafeMtmdWeights _mtmd; + private readonly ModelParams _modelParams; + + public MtmdExecutorTests() + { + _modelParams = new ModelParams(Constants.MtmdModelPath) + { + ContextSize = 1024 * 8, + GpuLayerCount = Constants.CIGpuLayerCount, + }; + + _weights = LLamaWeights.LoadFromFile(_modelParams); + + _mtmdParams = 
MtmdContextParams.Default();
+        _mtmdParams.NThreads = Math.Max(1, Constants.CIGpuLayerCount);
+        _mtmdParams.UseGpu = false;
+
+        _mtmd = SafeMtmdWeights.LoadFromFile(Constants.MtmdMmpPath, _weights, _mtmdParams);
+    }
+
+    public void Dispose()
+    {
+        _mtmd.Dispose();
+        _weights.Dispose();
+    }
+
+    [Fact]
+    public async Task InteractiveExecutor_EvaluateChunks_DoesNotRetokenize()
+    {
+        using var context = _weights.CreateContext(_modelParams, NullLogger.Instance);
+        var executor = new InteractiveExecutor(context, _mtmd, NullLogger.Instance);
+        var marker = _mtmdParams.MediaMarker ?? NativeApi.MtmdDefaultMarker() ?? "";
+        var prompt = $"{marker}\nDescribe the image succinctly.";
+
+        executor.Embeds.Add(_mtmd.LoadMedia(Constants.MtmdImage));
+
+        await foreach (var _ in executor.InferAsync(prompt, new InferenceParams { MaxTokens = 0 }))
+        {
+            Assert.True(false, "Prefill should not emit generated text");
+        }
+
+        var diagnostics = executor.GetDiagnostics();
+        Assert.Equal(diagnostics.EmbedCount, diagnostics.ConsumedCount);
+        Assert.Equal(diagnostics.ConsumedCount, diagnostics.PastCount);
+        Assert.Equal(0, diagnostics.PendingEmbedCount);
+    }
+
+    [Fact]
+    public async Task InstructExecutor_MtmdPromptAdvancesPastTokensOnce()
+    {
+        using var context = _weights.CreateContext(_modelParams, NullLogger.Instance);
+        var executor = new InstructExecutor(context, _mtmd, logger: NullLogger.Instance);
+        executor.Embeds.Add(_mtmd.LoadMedia(Constants.MtmdImage));
+
+        var prompt = $"{_mtmdParams.MediaMarker ?? NativeApi.MtmdDefaultMarker() ?? ""} Provide details.";
+
+        await foreach (var _ in executor.InferAsync(prompt, new InferenceParams { MaxTokens = 0 }))
+        {
+        }
+
+        var diagnostics = executor.GetDiagnostics();
+        Assert.Equal(diagnostics.EmbedCount, diagnostics.ConsumedCount);
+        Assert.Equal(diagnostics.ConsumedCount, diagnostics.PastCount);
+        Assert.Equal(0, diagnostics.PendingEmbedCount);
+    }
+}
diff --git a/LLama.Unittest/MtmdWeightsTests.cs b/LLama.Unittest/MtmdWeightsTests.cs
new file mode 100644
index 000000000..947bbd1ea
--- /dev/null
+++ b/LLama.Unittest/MtmdWeightsTests.cs
@@ -0,0 +1,140 @@
+using System;
+using System.IO;
+using LLama.Common;
+using LLama.Native;
+using Xunit;
+
+namespace LLama.Unittest
+{
+    // Test the same things as llama model + image embeddings
+    //
+    public sealed class MtmdWeightTests
+        : IDisposable
+    {
+        private readonly LLamaWeights _llamaWeights;
+        private readonly SafeMtmdWeights _safeMtmdWeights;
+        private readonly LLamaContext _context;
+        private readonly MtmdContextParams _mtmdParams;
+        private readonly string _mediaMarker;
+
+        public MtmdWeightTests()
+        {
+            var @params = new ModelParams(Constants.MtmdModelPath)
+            {
+                // Mtmd models require a big context
+                ContextSize = 1024 * 32,
+                GpuLayerCount = Constants.CIGpuLayerCount,
+            };
+            _llamaWeights = LLamaWeights.LoadFromFile(@params);
+
+            _mtmdParams = MtmdContextParams.Default();
+            _mtmdParams.NThreads = Constants.CIGpuLayerCount;
+            _mtmdParams.UseGpu = false; // keep tests portable across environments without GPU
+
+            _mediaMarker = _mtmdParams.MediaMarker ?? 
throw new InvalidOperationException("MTMD media marker unavailable."); + + _safeMtmdWeights = SafeMtmdWeights.LoadFromFile(Constants.MtmdMmpPath, _llamaWeights, _mtmdParams); + _context = _llamaWeights.CreateContext(@params); + } + + public void Dispose() + { + _context.Dispose(); + _safeMtmdWeights.Dispose(); + _llamaWeights.Dispose(); + } + + private SafeMtmdInputChunks TokenizeWithEmbed(Func loadEmbed) + { + _safeMtmdWeights.ClearMedia(); + + var embed = loadEmbed(); + Assert.NotNull(embed); + + using (embed) + { + Assert.True(embed.Nx > 0); + Assert.True(embed.Ny > 0); + Assert.False(embed.IsAudio); + Assert.True(embed.GetDataSpan().Length > 0); + + var status = _safeMtmdWeights.Tokenize(_mediaMarker, addSpecial: true, parseSpecial: true, out var chunks); + Assert.Equal(0, status); + Assert.NotNull(chunks); + + return chunks!; + } + } + + private void AssertChunksEvaluate(SafeMtmdInputChunks chunks) + { + long nPast = 0; + var eval = _safeMtmdWeights.EvaluateChunks(chunks, _context.NativeHandle, ref nPast, seqId: 0, nBatch: checked((int)_context.BatchSize), logitsLast: true); + Assert.Equal(0, eval); + Assert.True(nPast > 0); + } + + [Fact,Trait("Category", "NoCI")] + public void EmbedImageAsFileName() + { + using var chunks = TokenizeWithEmbed(() => _safeMtmdWeights.LoadMedia(Constants.MtmdImage)); + AssertChunksEvaluate(chunks); + } + + [Fact,Trait("Category", "NoCI")] + public void EmbedImageAsBinary() + { + var imageBytes = File.ReadAllBytes(Constants.MtmdImage); + using var chunks = TokenizeWithEmbed(() => _safeMtmdWeights.LoadMedia(imageBytes)); + AssertChunksEvaluate(chunks); + } + + [Fact,Trait("Category", "NoCI")] + public void TokenizeProvidesChunkMetadata() + { + using var chunks = TokenizeWithEmbed(() => _safeMtmdWeights.LoadMedia(Constants.MtmdImage)); + + Assert.True(chunks.Size > 0); + + ulong totalTokens = 0; + long totalPositions = 0; + var imageChunks = 0; + + foreach (var chunk in chunks.Enumerate()) + { + totalTokens += chunk.NTokens; + totalPositions += chunk.NPos; + + if (chunk.Type == SafeMtmdInputChunk.SafeMtmdInputChunkType.Image) + { + imageChunks++; + + var copy = chunk.Copy(); + try + { + Assert.NotNull(copy); + if (copy != null) + { + Assert.Equal(chunk.NTokens, copy.NTokens); + Assert.Equal(chunk.NPos, copy.NPos); + } + } + finally + { + copy?.Dispose(); + } + } + } + + Assert.True(imageChunks > 0); + Assert.True(totalTokens > 0); + Assert.Equal(totalTokens, _safeMtmdWeights.CountTokens(chunks)); + Assert.Equal(totalPositions, _safeMtmdWeights.CountPositions(chunks)); + Assert.True(_safeMtmdWeights.SupportsVision); + Assert.False(_safeMtmdWeights.SupportsAudio); + + var audioBitrate = _safeMtmdWeights.AudioBitrate; + Assert.True(audioBitrate <= 0); + } + } +} diff --git a/LLama.Unittest/Native/SafeLlamaModelHandleTests.cs b/LLama.Unittest/Native/SafeLlamaModelHandleTests.cs deleted file mode 100644 index f3e5798f2..000000000 --- a/LLama.Unittest/Native/SafeLlamaModelHandleTests.cs +++ /dev/null @@ -1,32 +0,0 @@ -using System.Runtime.InteropServices; -using System.Text; -using LLama.Common; -using LLama.Extensions; -using Xunit; - -namespace LLama.Unittest.Native; - -public class SafeLlamaModelHandleTests -{ - private readonly LLamaWeights _model; - - public SafeLlamaModelHandleTests() - { - var @params = new ModelParams(Constants.GenerativeModelPath2) - { - ContextSize = 1, - GpuLayerCount = Constants.CIGpuLayerCount - }; - _model = LLamaWeights.LoadFromFile(@params); - } - - // Note: This test is flakey, it appears to often (but not always) fail the first 
time it is run after downloading the model file, but then succeed every time after! - //[SkippableFact] - //public void MetadataValByKey_ReturnsCorrectly() - //{ - // Skip.If(RuntimeInformation.IsOSPlatform(OSPlatform.OSX), "Skipping this test on macOS because for some reason the meta data is incorrect, but the rest of tests work well on mscOS [Check later!]."); - // const string key = "general.name"; - // var template = _model.NativeHandle.MetadataValueByKey(key); - // var name = Encoding.UTF8.GetStringFromSpan(template!.Value.Span); - //} -} diff --git a/LLama.Unittest/Native/SafeLlamaModelHandleVocabularyTests.cs b/LLama.Unittest/Native/SafeLlamaModelHandleVocabularyTests.cs deleted file mode 100644 index 1ce53f395..000000000 --- a/LLama.Unittest/Native/SafeLlamaModelHandleVocabularyTests.cs +++ /dev/null @@ -1,42 +0,0 @@ -using System.Text; -using System.Xml.Linq; -using LLama.Common; -using LLama.Extensions; -using Microsoft.Extensions.Logging; - - -namespace LLama.Unittest.Native; - -public class SafeLlamaModelHandleVocabularyTests: IDisposable -{ - private readonly LLamaWeights _model; - - public SafeLlamaModelHandleVocabularyTests() - { - var @params = new ModelParams(Constants.RerankingModelPath) - { - ContextSize = 0, - PoolingType = LLama.Native.LLamaPoolingType.Rank, - GpuLayerCount = Constants.CIGpuLayerCount - }; - _model = LLamaWeights.LoadFromFile(@params); - } - - public void Dispose() - { - _model.Dispose(); - } - - [Fact] - public void GetLLamaTokenString() - { - var bos = _model.Vocab.BOS; - var eos = _model.Vocab.EOS; - - var bosStr = _model.Vocab.LLamaTokenToString(bos, true); - var eosStr = _model.Vocab.LLamaTokenToString(eos, true); - - Assert.Equal("", bosStr); - Assert.Equal("", eosStr); - } -} diff --git a/LLama/Abstractions/ILLamaExecutor.cs b/LLama/Abstractions/ILLamaExecutor.cs index 9a2233287..92276e4a6 100644 --- a/LLama/Abstractions/ILLamaExecutor.cs +++ b/LLama/Abstractions/ILLamaExecutor.cs @@ -1,5 +1,6 @@ using System.Collections.Generic; using System.Threading; +using LLama.Native; namespace LLama.Abstractions { @@ -22,12 +23,12 @@ public interface ILLamaExecutor /// /// Multi-Modal Projections / Clip Model weights /// - public LLavaWeights? ClipModel { get; } + public SafeMtmdWeights? ClipModel { get; } /// - /// List of images: List of images in byte array format. + /// List of media: List of media for Multi-Modal models. /// - public List Images { get; } + public List Embeds { get; } /// /// Asynchronously infers a response from the model. diff --git a/LLama/Batched/BatchedExecutor.cs b/LLama/Batched/BatchedExecutor.cs index 462e9e555..40468c98d 100644 --- a/LLama/Batched/BatchedExecutor.cs +++ b/LLama/Batched/BatchedExecutor.cs @@ -16,6 +16,7 @@ public sealed class BatchedExecutor { private int _nextSequenceId; private readonly List _batchQueue = []; + private string? _mtmdMarker; private int _batchQueueHead; private int _batchedTokenCount; private bool _batchedTokenCountDirty = true; @@ -79,12 +80,20 @@ public int BatchedTokenCount /// The model to use /// Parameters to create a new context public BatchedExecutor(LLamaWeights model, IContextParams contextParams) + : this(model, contextParams, null) + { + } + + public BatchedExecutor(LLamaWeights model, IContextParams contextParams, SafeMtmdWeights? clipModel) { Model = model; Context = model.CreateContext(contextParams); + ClipModel = clipModel; Epoch = 1; } + public SafeMtmdWeights? 
ClipModel { get; } + /// /// Start a new /// @@ -314,6 +323,23 @@ internal LLamaSeqId GetNextSequenceId() return (end, Epoch + (uint)(_batchQueue.Count - _batchQueueHead) * 2); } + internal ulong QueueMtmdBatch(Conversation conversation, Conversation.MtmdChunkSequence sequence) + { + if (ClipModel is null) + throw new InvalidOperationException("This batched executor is not configured for multimodal inference."); + + var batch = new MtmdChunkBatch(ClipModel, conversation, sequence); + _batchQueue.Add(batch); + return Epoch + (uint)_batchQueue.Count * 2; + } + + internal string GetMtmdMarker() + { + if (ClipModel is null) + throw new InvalidOperationException("This batched executor is not configured for multimodal inference."); + return _mtmdMarker ??= NativeApi.MtmdDefaultMarker() ?? ""; + } + #region batches private interface IBatch { @@ -345,5 +371,44 @@ public Task DecodeAsync(LLamaContext ctx, CancellationToken token) return ctx.DecodeAsync(Batch, token); } } + + private class MtmdChunkBatch : IBatch + { + private readonly SafeMtmdWeights _clipModel; + private readonly Conversation _conversation; + private readonly Conversation.MtmdChunkSequence _sequence; + + public MtmdChunkBatch(SafeMtmdWeights clipModel, Conversation conversation, Conversation.MtmdChunkSequence sequence) + { + _clipModel = clipModel; + _conversation = conversation; + _sequence = sequence; + } + + public int ItemCount => Math.Max(1, _sequence.TotalTokens); + + public Task DecodeAsync(LLamaContext ctx, CancellationToken token) + { + try + { + var nPast = _conversation.GetMtmdPast(); + var status = _clipModel.EvaluateChunks(_sequence.Chunks, ctx.NativeHandle, ref nPast, + (int)_conversation.ConversationId, checked((int)ctx.BatchSize), logitsLast: true); + if (status != 0) + { + _conversation.OnMtmdEvaluationFailed(status); + return Task.FromResult(DecodeResult.DecodeFailed); + } + + _conversation.OnMtmdEvaluationCompleted(nPast, _sequence); + return Task.FromResult(DecodeResult.Ok); + } + catch + { + _conversation.OnMtmdEvaluationFailed(-1); + return Task.FromResult(DecodeResult.DecodeFailed); + } + } + } #endregion } diff --git a/LLama/Batched/Conversation.cs b/LLama/Batched/Conversation.cs index c504ce07a..2311c8a0c 100644 --- a/LLama/Batched/Conversation.cs +++ b/LLama/Batched/Conversation.cs @@ -3,6 +3,7 @@ using System.Linq; using System.Text.Json; using CommunityToolkit.HighPerformance.Buffers; +using LLama.Exceptions; using LLama.Native; namespace LLama.Batched; @@ -21,6 +22,12 @@ public sealed class Conversation /// Indicates if this conversation has been "forked" and may share logits with another conversation. /// private bool _forked; + private readonly List _mtmdEmbeds = new(); + private int? _mtmdLogitsIndex; + private MtmdChunkSequence? _pendingMtmdSequence; + private readonly List _embed_inps = new(); + private readonly List _session_tokens = new(); + private int _consumedTokensCount; /// /// Stores the indices to sample from. Contains valid items. 
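A minimal end-to-end sketch of the multimodal batched API introduced above; the file paths and the sampler are illustrative placeholders, and the MtmdContextParams argument assumes the defaults type added later in this series:

    var parameters = new ModelParams("model.gguf");
    using var weights = LLamaWeights.LoadFromFile(parameters);
    using var clip = SafeMtmdWeights.LoadFromFile("mmproj.gguf", weights, MtmdContextParams.Default());
    using var executor = new BatchedExecutor(weights, parameters, clip);

    var conversation = executor.Create();
    conversation.QueueMedia("photo.jpg");           // queue media before prompting
    conversation.Prompt("Describe <image>.");       // "<image>" is rewritten to the MTMD marker
    await executor.Infer();                         // evaluates the queued MTMD chunk batch
    var token = conversation.Sample(samplerChain);  // MtmdLogitsIndex == -1 samples the final logits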
@@ -65,6 +72,46 @@ internal Conversation(BatchedExecutor batch, LLamaSeqId id) Executor = batch; } + internal sealed class MtmdChunkSequence : IDisposable + { + public SafeMtmdInputChunks Chunks { get; } + public List TextTokens { get; } + public int TotalPositions { get; } + public int TotalTokens => TextTokens.Count; + + private MtmdChunkSequence(SafeMtmdInputChunks chunks, List textTokens, int totalPositions) + { + Chunks = chunks; + TextTokens = textTokens; + TotalPositions = totalPositions; + } + + public static MtmdChunkSequence Create(SafeMtmdInputChunks chunks, SafeMtmdWeights clipModel) + { + var textTokens = new List(); + + foreach (var chunk in chunks.Enumerate()) + { + using (chunk) + { + if (chunk.Type != SafeMtmdInputChunk.SafeMtmdInputChunkType.Text) + continue; + + foreach (var token in chunk.GetTextTokensSpan()) + textTokens.Add((LLamaToken)unchecked((int)token)); + } + } + + var totalPositions = (int)clipModel.CountPositions(chunks); + return new MtmdChunkSequence(chunks, textTokens, totalPositions); + } + + public void Dispose() + { + Chunks.Dispose(); + } + } + /// /// Finalizer for Conversation /// @@ -83,6 +130,11 @@ public void Dispose() return; _disposed = true; + _pendingMtmdSequence?.Dispose(); + _pendingMtmdSequence = null; + + DisposeQueuedMedia(); + // Remove this conversation from the KV cache Executor.Context.NativeHandle.MemorySequenceRemove(ConversationId, -1, -1); @@ -206,6 +258,43 @@ private void AssertCanBePrompted() if (RequiresInference) throw new AlreadyPromptedConversationException(); + + _mtmdLogitsIndex = null; + } + + public void QueueMedia(string path) + { + AssertCanBePrompted(); + + if (Executor.ClipModel is null) + throw new InvalidOperationException("This conversation is not configured for multimodal prompts."); + + var embed = Executor.ClipModel.LoadMedia(path); + _mtmdEmbeds.Add(embed); + _mtmdLogitsIndex = null; + } + + public void QueueMedia(SafeMtmdEmbed embed) + { + AssertCanBePrompted(); + + if (Executor.ClipModel is null) + throw new InvalidOperationException("This conversation is not configured for multimodal prompts."); + + _mtmdEmbeds.Add(embed); + _mtmdLogitsIndex = null; + } + + public void Prompt(string promptText, bool addBos = true, bool special = true) + { + if (Executor.ClipModel != null && _mtmdEmbeds.Count > 0) + { + PromptMultimodal(promptText, addBos); + return; + } + + var tokens = Executor.Context.Tokenize(promptText, addBos, special); + Prompt(tokens); } /// @@ -246,6 +335,7 @@ public void Prompt(List tokens, bool allLogits = false) public void Prompt(ReadOnlySpan tokens, bool allLogits = false) { AssertCanBePrompted(); + _mtmdLogitsIndex = null; // No point doing anything if there is no actual prompt! if (tokens.Length == 0) @@ -289,6 +379,59 @@ public void Prompt(ReadOnlySpan tokens, bool allLogits = false) // Unset the forked flag. Since this conversation has just been prompted it's no longer // sharing anything with any other conversations. 
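        // A fresh token prompt also clears any multimodal logits override left by a
        // previous MTMD evaluation (see OnMtmdEvaluationCompleted below).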
         _forked = false;
+        _mtmdLogitsIndex = null;
+    }
+
+    private void PromptMultimodal(string text, bool addBos)
+    {
+        AssertCanBePrompted();
+
+        if (Executor.ClipModel is null)
+            throw new InvalidOperationException("This conversation is not configured for multimodal prompts.");
+        if (_mtmdEmbeds.Count == 0)
+            throw new InvalidOperationException("Queue media before prompting with multimodal input.");
+
+        var marker = Executor.GetMtmdMarker();
+        var prompt = text;
+
+        if (prompt.Contains("<image>"))
+            prompt = prompt.Replace("<image>", marker);
+
+        if (!prompt.Contains(marker))
+        {
+            var suffix = string.Concat(Enumerable.Repeat(marker, _mtmdEmbeds.Count));
+            prompt = string.Concat(prompt, suffix);
+        }
+
+        SafeMtmdInputChunks? chunks = null;
+        try
+        {
+            _mtmdLogitsIndex = null;
+            var status = Executor.ClipModel.Tokenize(prompt, addBos, parseSpecial: true, out chunks);
+            if (status != 0 || chunks is null)
+            {
+                Executor.ClipModel.ClearMedia();
+                throw new RuntimeError($"Failed to tokenize multimodal prompt. Status: {status}.");
+            }
+
+            var sequence = MtmdChunkSequence.Create(chunks, Executor.ClipModel);
+            _pendingMtmdSequence = sequence;
+
+            var epoch = Executor.QueueMtmdBatch(this, sequence);
+            chunks = null;
+
+            if (_batchSampleIndices.Length == 0)
+                _batchSampleIndices = new int[4];
+
+            _batchSampleCount = 0;
+            _requiredEpoch = epoch;
+            _forked = false;
+        }
+        finally
+        {
+            DisposeQueuedMedia();
+            chunks?.Dispose();
+        }
     }

     ///
     ///
@@ -305,32 +448,7 @@ public void Prompt(LLamaToken token)
         Span<LLamaToken> span = [ token ];
         Prompt(span);
     }
-
-    ///
-    /// Prompt this conversation with an image embedding
-    ///
-    ///
-    public void Prompt(SafeLlavaImageEmbedHandle embedding)
-    {
-        AssertCanBePrompted();
-
-        if (embedding.Model.EmbeddingDimensions != Executor.Model.EmbeddingSize)
-            throw new ArgumentException($"Embedding dimension mismatch between image embedding ({embedding.Model.EmbeddingDimensions}) and model ({Executor.Model.EmbeddingSize})");
-
-        for (var i = 0; i < embedding.Model.PatchCount; i++)
-        {
-            // Get a batch with space
-            (var batch, _requiredEpoch) = Executor.GetEmbeddingBatch();
-
-            batch.Add(
-                (i, embedding),
-                static (Span<float> dest, (int index, SafeLlavaImageEmbedHandle embedding) tup) => tup.embedding.GetEmbedding(dest, tup.index),
-                _end++,
-                ConversationId,
-                i == embedding.Model.PatchCount - 1
-            );
-        }
-    }
+

     ///
     /// Prompt this conversation with embeddings
@@ -339,6 +457,7 @@ public void Prompt(SafeLlavaImageEmbedHandle embedding)
     public void Prompt(ReadOnlySpan<float> embeddings)
     {
         AssertCanBePrompted();
+        _mtmdLogitsIndex = null;

         var dim = Executor.Model.EmbeddingSize;
         var count = embeddings.Length / dim;
@@ -385,6 +504,75 @@ public void Modify(ModifyKvCache modifier)
         _requiredEpoch = 0;
     }

+    internal long GetMtmdPast() => _end.Value;
+
+    internal void OnMtmdEvaluationCompleted(long newPast, MtmdChunkSequence sequence)
+    {
+        _pendingMtmdSequence?.Dispose();
+        _pendingMtmdSequence = null;
+
+        _end = (LLamaPos)checked((int)newPast);
+
+        if (_batchSampleIndices.Length == 0)
+            _batchSampleIndices = new int[4];
+
+        _batchSampleCount = 1;
+        _batchSampleIndices[0] = 0;
+        _mtmdLogitsIndex = -1;
+        _requiredEpoch = Executor.Epoch + 1;
+        _forked = false;
+
+        if (sequence.TextTokens.Count > 0)
+        {
+            _embed_inps.AddRange(sequence.TextTokens);
+            _session_tokens.AddRange(sequence.TextTokens);
+        }
+
+        var fillerToken = GetFillerToken(Executor.GetMtmdMarker());
+        var fillerCount = Math.Max(0, sequence.TotalPositions - sequence.TotalTokens);
+        for (var i = 0; i < fillerCount; i++)
+            _embed_inps.Add(fillerToken);
+
+        _consumedTokensCount
= _embed_inps.Count; + sequence.Dispose(); + } + + internal void OnMtmdEvaluationFailed(int status) + { + _pendingMtmdSequence?.Dispose(); + _pendingMtmdSequence = null; + _mtmdLogitsIndex = null; + _requiredEpoch = Executor.Epoch; + DisposeQueuedMedia(); + } + + internal int? MtmdLogitsIndex => _mtmdLogitsIndex; + + private LLamaToken GetFillerToken(string marker) + { + var markerTokens = Executor.Context.Tokenize(marker, addBos: false, special: true); + if (markerTokens.Length > 0) + return markerTokens[markerTokens.Length - 1]; + + var eos = Executor.Context.Vocab.EOS; + if (eos.HasValue) + return eos.Value; + + return default; + } + + private void DisposeQueuedMedia() + { + if (_mtmdEmbeds.Count == 0) + return; + + foreach (var embed in _mtmdEmbeds) + embed.Dispose(); + + _mtmdEmbeds.Clear(); + Executor.ClipModel?.ClearMedia(); + } + /// /// Provides direct access to the KV cache of a . /// See for how to use this. @@ -629,4 +817,4 @@ internal State() } } #endregion -} \ No newline at end of file +} diff --git a/LLama/Batched/ConversationExtensions.cs b/LLama/Batched/ConversationExtensions.cs index eb0192061..3e25d3f43 100644 --- a/LLama/Batched/ConversationExtensions.cs +++ b/LLama/Batched/ConversationExtensions.cs @@ -18,7 +18,11 @@ public static class ConversationExtensions /// public static LLamaToken Sample(this Conversation conversation, SafeLLamaSamplerChainHandle sampler, int offset = 0) { - return sampler.Sample(conversation.Executor.Context.NativeHandle, conversation.GetSampleIndex(offset)); + var ctx = conversation.Executor.Context.NativeHandle; + if (conversation.MtmdLogitsIndex == -1) + return sampler.Sample(ctx, -1); + + return sampler.Sample(ctx, conversation.GetSampleIndex(offset)); } /// @@ -30,7 +34,11 @@ public static LLamaToken Sample(this Conversation conversation, SafeLLamaSampler /// public static LLamaToken Sample(this Conversation conversation, ISamplingPipeline sampler, int offset = 0) { - return sampler.Sample(conversation.Executor.Context.NativeHandle, conversation.GetSampleIndex(offset)); + var ctx = conversation.Executor.Context.NativeHandle; + if (conversation.MtmdLogitsIndex == -1) + return sampler.Sample(ctx, -1); + + return sampler.Sample(ctx, conversation.GetSampleIndex(offset)); } /// @@ -82,4 +90,4 @@ public static void ShiftLeft(this Conversation conversation, int count, int keep return end.Value - count; }); } -} \ No newline at end of file +} diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index d0829deca..a39ad3836 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -32,11 +32,11 @@ public abstract class StatefulExecutorBase : ILLamaExecutor /// protected int _consumedTokensCount; // n_consume /// - /// + /// Number of tokens consumed from the session cache during the current run. /// protected int _n_session_consumed; /// - /// + /// Number of prompt tokens that match the loaded session cache prefix. /// protected int _n_matching_session_tokens; /// @@ -52,7 +52,7 @@ public abstract class StatefulExecutorBase : ILLamaExecutor /// protected List _embed_inps = new(); /// - /// + /// Tokens recovered from the session file and reused to warm up the KV cache. /// protected List _session_tokens = new(); /// @@ -81,21 +81,21 @@ public bool IsMultiModal } /// - public LLavaWeights? ClipModel { get; } + public SafeMtmdWeights? 
ClipModel { get; } /// - public List Images { get; } + public List Embeds { get; } private readonly StreamingTokenDecoder _decoder; /// - /// + /// Initialize a stateful executor bound to a specific context. /// - /// - /// + /// LLama context used for all native interactions. + /// Optional logger for diagnostic output. protected StatefulExecutorBase(LLamaContext context, ILogger? logger = null) { - Images = new List(); + Embeds = new List(); _logger = logger; Context = context; _pastTokensCount = 0; @@ -107,22 +107,22 @@ protected StatefulExecutorBase(LLamaContext context, ILogger? logger = null) } /// - /// + /// Initialize a multimodal executor with the supplied MTMD weights. /// - /// - /// - /// - public StatefulExecutorBase(LLamaContext context, LLavaWeights lLavaWeights, ILogger? logger = null) : + /// LLama context used for all native interactions. + /// Multimodal weights to associate with this executor. + /// Optional logger for diagnostic output. + public StatefulExecutorBase(LLamaContext context, SafeMtmdWeights safeMtmdWeights, ILogger? logger = null) : this( context, logger ) { - ClipModel = lLavaWeights; + ClipModel = safeMtmdWeights; } /// - /// This API is currently not verified. + /// Attach a session cache file so the executor can reuse previous KV state if compatible. /// - /// - /// + /// Path to the llama.cpp session file. + /// The current executor instance for fluent configuration. /// /// public StatefulExecutorBase WithSessionFile(string filename) @@ -179,9 +179,9 @@ public StatefulExecutorBase WithSessionFile(string filename) } /// - /// This API has not been verified currently. + /// Persist the current session cache to disk. /// - /// + /// Destination path for the llama.cpp session file. public void SaveSessionFile(string filename) { var session_token_array = _session_tokens.ToArray(); @@ -209,7 +209,7 @@ protected virtual void HandleRunOutOfContext(int tokensToKeep) } /// - /// Try to reuse the matching prefix from the session file. + /// Try to reuse the matching prompt prefix from the loaded session cache before evaluating new tokens. /// protected virtual void TryReuseMatchingPrefix() { @@ -243,73 +243,73 @@ protected virtual void TryReuseMatchingPrefix() } /// - /// Decide whether to continue the loop. + /// Determine whether the inference loop should continue processing tokens. /// - /// + /// Mutable state associated with the current inference. /// - /// + /// true to continue generating; otherwise false. protected abstract Task GetLoopCondition(InferStateArgs args, CancellationToken cancellationToken = default); /// - /// Preprocess the inputs before the inference. + /// Prepare the executor for inference by tokenizing input and updating cached state. /// - /// - /// + /// Prompt text to process. + /// Mutable state associated with the current inference. /// protected abstract Task PreprocessInputs(string? text, InferStateArgs args, CancellationToken cancellationToken = default); /// - /// Do some post processing after the inference. + /// Perform any post-processing on the generated tokens. /// - /// - /// + /// Parameters controlling sampling. /// - /// + /// Mutable state associated with the current inference. + /// A tuple indicating whether generation should stop and any extra outputs to emit. protected abstract Task<(bool, IReadOnlyList)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken = default); /// - /// The core inference logic. 
+ /// Core inference loop that advances the model by one step. /// - /// - /// + /// Parameters controlling sampling. + /// Mutable state associated with the current inference. /// protected abstract Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken = default); /// - /// Save the current state to a file. + /// Save the executor state to a serialized snapshot file. /// - /// + /// Destination file for the serialized state. /// public abstract Task SaveState(string filename, CancellationToken cancellationToken = default); /// - /// Get the current state data. + /// Capture the executor state in a serializable object. /// - /// + /// State snapshot suitable for persistence. public abstract ExecutorBaseState GetStateData(); /// - /// Load the state from data. + /// Restore executor state from a previously captured snapshot. /// - /// + /// State snapshot created by . /// public abstract Task LoadState(ExecutorBaseState data, CancellationToken cancellationToken = default); /// - /// Load the state from a file. + /// Restore executor state from a serialized snapshot file. /// - /// + /// Path to the snapshot produced by . /// public abstract Task LoadState(string filename, CancellationToken cancellationToken = default); /// - /// Execute the inference. + /// Execute an asynchronous inference session. /// - /// The prompt. If null, generation will continue where it left off previously. - /// - /// - /// + /// Optional prompt; when null generation resumes from prior state. + /// Sampling parameters to apply; defaults are used when null. + /// Cancellation token for cooperative cancellation. + /// Stream of decoded text segments as they become available. public virtual async IAsyncEnumerable InferAsync(string? text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { cancellationToken.ThrowIfCancellationRequested(); @@ -390,12 +390,12 @@ public virtual async Task PrefillPromptAsync(string prompt, CancellationToken ca } /// - /// State arguments that are used in single inference + /// Mutable state passed between inference callbacks during a single generation pass. /// protected class InferStateArgs { /// - /// + /// Anti-prompts that terminate generation when encountered. /// public IList? Antiprompts { get; set; } /// @@ -403,15 +403,15 @@ protected class InferStateArgs /// public int RemainedTokens { get; set; } /// - /// + /// Indicates whether generated tokens should be returned to the caller. /// public bool ReturnValue { get; set; } /// - /// + /// Signals that the executor should pause and wait for additional user input. /// public bool WaitForInput { get; set; } /// - /// + /// Indicates whether the session cache should be persisted after inference completes. /// public bool NeedToSaveSession { get; set; } @@ -422,6 +422,9 @@ protected class InferStateArgs } #pragma warning disable CS1591, CS8618 // Missing XML and irrelevant nullable warnings + /// + /// Serializable snapshot of executor state used for persistence and restart. + /// [JsonConverter(typeof(PolymorphicJSONConverter))] public class ExecutorBaseState { @@ -459,5 +462,33 @@ public class ExecutorBaseState public float? 
MirostatMu { get; set; } } #pragma warning restore + + internal ExecutorDiagnostics GetDiagnostics() + { + return new ExecutorDiagnostics( + _embed_inps.Count, + _consumedTokensCount, + _pastTokensCount, + _embeds.Count); + } + } +} + +namespace LLama +{ + internal readonly struct ExecutorDiagnostics + { + public ExecutorDiagnostics(int embedCount, int consumedCount, int pastCount, int pendingEmbeds) + { + EmbedCount = embedCount; + ConsumedCount = consumedCount; + PastCount = pastCount; + PendingEmbedCount = pendingEmbeds; + } + + public int EmbedCount { get; } + public int ConsumedCount { get; } + public int PastCount { get; } + public int PendingEmbedCount { get; } } } diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index 517a4e7d0..1bdba035a 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; +using System.Text; using System.Text.Json; using System.Text.Json.Serialization; using System.Threading; @@ -25,6 +26,9 @@ public class InstructExecutor private readonly string _instructionPrefix; private LLamaToken[] _inp_pfx; private LLamaToken[] _inp_sfx; + private SafeMtmdInputChunks? _mtmdChunks; + private string? _mtmdMarker; + private readonly string _instructionSuffix; /// /// @@ -42,6 +46,20 @@ public InstructExecutor(LLamaContext context, _inp_pfx = Context.Tokenize(instructionPrefix, true, true); _inp_sfx = Context.Tokenize(instructionSuffix, false, true); _instructionPrefix = instructionPrefix; + _instructionSuffix = instructionSuffix; + } + + public InstructExecutor(LLamaContext context, + SafeMtmdWeights clipModel, + string instructionPrefix = "\n\n### Instruction:\n\n", + string instructionSuffix = "\n\n### Response:\n\n", + ILogger? logger = null) + : base(context, clipModel, logger) + { + _inp_pfx = Context.Tokenize(instructionPrefix, true, true); + _inp_sfx = Context.Tokenize(instructionSuffix, false, true); + _instructionPrefix = instructionPrefix; + _instructionSuffix = instructionSuffix; } /// @@ -68,7 +86,8 @@ public override ExecutorBaseState GetStateData() /// public override Task LoadState(ExecutorBaseState data, CancellationToken cancellationToken = default) { - if (data is InstructExecutorState state) + DisposeMtmdChunks(); + if(data is InstructExecutorState state) { _n_session_consumed = state.ConsumedSessionCount; _embed_inps = state.EmbedInps!.ToList(); @@ -128,7 +147,14 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args, Canc { // When running the first input (prompt) in interactive mode, we should specially process it. if (text == null) throw new ArgumentException("Prompt cannot be null to trigger continuation if a prompt has not been provided previously."); - _embed_inps = Context.Tokenize(text, true, true).ToList(); + if (!IsMultiModal) + { + _embed_inps = Context.Tokenize(text, true, true).ToList(); + } + else + { + return PreprocessMtmd(text, args, addBos: true, replaceExisting: true); + } } else { @@ -141,20 +167,161 @@ protected override Task PreprocessInputs(string? 
text, InferStateArgs args, Canc
         {
             // When running the first input (prompt) in interactive mode, we should specially process it.
             if (text == null) throw new ArgumentException("Prompt cannot be null to trigger continuation if a prompt has not been provided previously.");
-            _embed_inps = Context.Tokenize(text, true, true).ToList();
+            if (!IsMultiModal)
+            {
+                _embed_inps = Context.Tokenize(text, true, true).ToList();
+            }
+            else
+            {
+                return PreprocessMtmd(text, args, addBos: true, replaceExisting: true);
+            }
         }
         else
         {
@@ -141,20 +167,161 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args, Canc
             {
                 text += "\n";
             }
-            _embed_inps.AddRange(_inp_pfx);
+            if (!IsMultiModal)
+            {
+                _embed_inps.AddRange(_inp_pfx);

-            var line_inp = Context.Tokenize(text, false, true);
-            _embed_inps.AddRange(line_inp);
+                var line_inp = Context.Tokenize(text, false, true);
+                _embed_inps.AddRange(line_inp);

-            _embed_inps.AddRange(_inp_sfx);
+                _embed_inps.AddRange(_inp_sfx);

-            args.RemainedTokens -= line_inp.Length;
+                args.RemainedTokens -= line_inp.Length;
+            }
+            else
+            {
+                var builder = new StringBuilder();
+                builder.Append(_instructionPrefix);
+                builder.Append(text);
+                builder.Append(_instructionSuffix);
+                return PreprocessMtmd(builder.ToString(), args, addBos: false, replaceExisting: false);
+            }
             }
         }

         return Task.CompletedTask;
     }

+    private void DisposeMtmdChunks()
+    {
+        _mtmdChunks?.Dispose();
+        _mtmdChunks = null;
+    }
+
+    private void DisposeEmbeds()
+    {
+        if (Embeds.Count == 0)
+            return;
+
+        foreach (var embed in Embeds)
+            embed.Dispose();
+
+        Embeds.Clear();
+    }
+
+    private string GetMtmdMarker()
+    {
+        if (_mtmdMarker is not null)
+            return _mtmdMarker;
+
+        _mtmdMarker = NativeApi.MtmdDefaultMarker() ?? "";
+        return _mtmdMarker;
+    }
+
+    private static List<LLamaToken> BuildTokensWithFiller(List<LLamaToken> tokens, int totalPositions, LLamaToken fillerToken)
+    {
+        if (totalPositions <= tokens.Count)
+            return new List<LLamaToken>(tokens);
+
+        var result = new List<LLamaToken>(totalPositions);
+        result.AddRange(tokens);
+        result.AddRange(Enumerable.Repeat(fillerToken, totalPositions - tokens.Count));
+        return result;
+    }
+
+    private LLamaToken GetFillerToken(string marker)
+    {
+        var markerTokens = Context.Tokenize(marker, false, true);
+        if (markerTokens.Length > 0)
+            return markerTokens[markerTokens.Length - 1];
+
+        var eos = Context.Vocab.EOS;
+        if (eos.HasValue)
+            return eos.Value;
+
+        return default(LLamaToken);
+    }
+
+    private Task PreprocessMtmd(string text, InferStateArgs args, bool addBos, bool replaceExisting)
+    {
+        if (ClipModel is null)
+            throw new InvalidOperationException("Multimodal execution requires a loaded mtmd clip model.");
+
+        DisposeMtmdChunks();
+
+        var marker = GetMtmdMarker();
+        var prompt = text;
+
+        if (Embeds.Count > 0)
+        {
+            if (prompt.Contains("<image>"))
+                prompt = prompt.Replace("<image>", marker);
+
+            if (!prompt.Contains(marker))
+            {
+                var suffix = string.Concat(Enumerable.Repeat(marker, Embeds.Count));
+                prompt = string.Concat(prompt, suffix);
+            }
+        }
+
+        SafeMtmdInputChunks? chunks = null;
+        try
+        {
+            var status = ClipModel.Tokenize(prompt, addBos, parseSpecial: true, out chunks);
+            if (status != 0 || chunks is null)
+            {
+                ClipModel.ClearMedia();
+                throw new RuntimeError($"Failed to tokenize multimodal prompt. 
Status: {status}."); + } + + _mtmdChunks = chunks; + + var tokens = new List(); + foreach (var chunk in chunks.Enumerate()) + { + using var scopedChunk = chunk; + if (scopedChunk.Type != SafeMtmdInputChunk.SafeMtmdInputChunkType.Text) + continue; + + foreach (var token in scopedChunk.GetTextTokensSpan()) + tokens.Add(unchecked((int)token)); + } + + var totalPositions = (int)ClipModel.CountPositions(chunks); + var fillerToken = GetFillerToken(marker); + + if (replaceExisting) + { + _embed_inps = BuildTokensWithFiller(tokens, totalPositions, fillerToken); + _consumedTokensCount = 0; + } + else + { + if (_embed_inps.Count == 0) + _embed_inps = new List(); + + _embed_inps.AddRange(tokens); + var fillerCount = totalPositions - tokens.Count; + if (fillerCount > 0) + _embed_inps.AddRange(Enumerable.Repeat(fillerToken, fillerCount)); + + args.RemainedTokens -= tokens.Count; + } + } + catch + { + chunks?.Dispose(); + _mtmdChunks = null; + throw; + } + finally + { + DisposeEmbeds(); + } + + return Task.CompletedTask; + } + /// protected override Task<(bool, IReadOnlyList)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken) { @@ -217,11 +384,43 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In _n_session_consumed = _session_tokens.Count; } } + else if (IsMultiModal && _mtmdChunks is not null) + { + _is_prompt_run = false; + var nPast = (long)_pastTokensCount; + var previousConsumed = _consumedTokensCount; + var evalStatus = ClipModel!.EvaluateChunks(_mtmdChunks, Context.NativeHandle, ref nPast, seqId: 0, nBatch: checked((int)Context.BatchSize), logitsLast: true); + if (evalStatus != 0) + { + _logger?.LogError("[InstructExecutor] Failed to evaluate multimodal chunks. Status: {Status}", evalStatus); + DisposeMtmdChunks(); + throw new RuntimeError($"Failed to evaluate multimodal chunks. Status: {evalStatus}."); + } + + _pastTokensCount = checked((int)nPast); + DisposeMtmdChunks(); + + if (!string.IsNullOrEmpty(_pathSession) && _embed_inps.Count > previousConsumed) + { + _session_tokens.AddRange(_embed_inps.Skip(previousConsumed)); + _n_session_consumed = _session_tokens.Count; + } + + _consumedTokensCount = _embed_inps.Count; + _embeds.Clear(); + } _embeds.Clear(); if (_embed_inps.Count <= _consumedTokensCount && !args.WaitForInput) { + if (inferenceParams.MaxTokens == 0) + { + _embeds.Clear(); + args.WaitForInput = true; + args.ReturnValue = false; + return; + } // optionally save the session on first sample (for faster prompt loading next time) if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession) { diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index e7cac4c47..97d49f5de 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -8,6 +8,7 @@ using System.Threading.Tasks; using LLama.Abstractions; using LLama.Common; +using LLama; using LLama.Exceptions; using LLama.Native; using LLama.Sampling; @@ -21,30 +22,31 @@ namespace LLama /// public class InteractiveExecutor : StatefulExecutorBase { + // Indicates whether the executor is currently evaluating the initial prompt or a follow-up turn. private bool _is_prompt_run = true; - // LLava - private int _EmbedImagePosition = -1; - private List _imageEmbedHandles = new List(); - private bool _imageInPrompt = false; + // MTMD multimodal state + private SafeMtmdInputChunks? _mtmdChunks; // Pending chunk collection produced by the multimodal tokenizer. + private string? 
_mtmdMarker; // Cached multimodal marker returned by the native helper. + /// - /// + /// Create an interactive executor for text-only inference. /// - /// - /// + /// LLama context to operate against. + /// Optional logger for diagnostic output. public InteractiveExecutor(LLamaContext context, ILogger? logger = null) : base(context, logger) { } /// - /// + /// Create an interactive multimodal executor that can process text alongside media inputs. /// - /// - /// - /// - public InteractiveExecutor(LLamaContext context, LLavaWeights clipModel, ILogger? logger = null) + /// LLama context to operate against. + /// Multimodal weights (MTMD) to attach to the executor. + /// Optional logger for diagnostic output. + public InteractiveExecutor(LLamaContext context, SafeMtmdWeights clipModel, ILogger? logger = null) : base(context, clipModel, logger) { } @@ -72,6 +74,7 @@ public override ExecutorBaseState GetStateData() /// public override Task LoadState(ExecutorBaseState data, CancellationToken cancellationToken = default) { + DisposeMtmdChunks(); if (data is InteractiveExecutorState state) { _n_session_consumed = state.ConsumedSessionCount; @@ -111,15 +114,20 @@ public override async Task LoadState(string filename, CancellationToken cancella } /// - /// Define whether to continue the loop to generate responses. + /// Decide whether generation should continue for the current iteration. /// - /// + /// Mutable inference state. + /// true to keep generating; otherwise false. protected override Task GetLoopCondition(InferStateArgs args, CancellationToken cancellationToken) { return Task.FromResult(args.RemainedTokens != 0 && !args.WaitForInput || _is_prompt_run); } - /// + /// + /// Preprocess the incoming prompt or continuation text before inference. + /// + /// Prompt text or continuation provided by the caller. + /// Mutable inference state. protected override Task PreprocessInputs(string? text, InferStateArgs args, CancellationToken cancellationToken) { if (_is_prompt_run) @@ -136,7 +144,7 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args, Canc } else { - PreprocessLlava(text, args, true); + PreprocessMtmd(text, args, true); } } else @@ -157,7 +165,7 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args, Canc } else { - PreprocessLlava(text, args, false); + PreprocessMtmd(text, args, false); } } } @@ -165,51 +173,172 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args, Canc return Task.CompletedTask; } - /// - private void PreprocessLlava(string text, InferStateArgs args, bool addBos = true) + /// + /// Release any queued multimodal chunks and reset state. + /// + private void DisposeMtmdChunks() + { + _mtmdChunks?.Dispose(); + _mtmdChunks = null; + } + + /// + /// Dispose and clear any pending multimodal embeddings queued for evaluation. + /// + private void DisposeEmbeds() + { + if (Embeds.Count == 0) + { + return; + } + + foreach (var embed in Embeds) + { + embed.Dispose(); + } + + Embeds.Clear(); + } + + /// + /// Retrieve the marker token used to signal media segments to the tokenizer. + /// + private string GetMtmdMarker() { - // If the prompt contains the tag extract this. - _imageInPrompt = text.Contains(""); - if (_imageInPrompt && IsMultiModal) + if (_mtmdMarker is not null) { - foreach (var image in Images) + return _mtmdMarker; + } + + _mtmdMarker = NativeApi.MtmdDefaultMarker() ?? 
""; + return _mtmdMarker; + } + + private static List BuildTokensWithFiller(List tokens, int totalPositions, LLamaToken fillerToken) + { + if (totalPositions <= tokens.Count) + return new List(tokens); + + var result = new List(totalPositions); + result.AddRange(tokens); + result.AddRange(Enumerable.Repeat(fillerToken, totalPositions - tokens.Count)); + return result; + } + + private LLamaToken GetFillerToken(string marker) + { + var markerTokens = Context.Tokenize(marker, false, true); + if (markerTokens.Length > 0) + return markerTokens[markerTokens.Length - 1]; + + var eos = Context.Vocab.EOS; + if (eos.HasValue) + return eos.Value; + + return default(LLamaToken); + } + + /// + /// Preprocess multimodal prompts by aligning media markers and tokenizing via MTMD helpers. + /// + /// Prompt text containing optional media markers. + /// Mutable inference state. + /// Whether to treat the prompt as a fresh run and add the BOS token. + private Task PreprocessMtmd(string text, InferStateArgs args, bool addBos = true) + { + if (ClipModel is null) + { + throw new InvalidOperationException("Multimodal execution requires a loaded mtmd clip model."); + } + + DisposeMtmdChunks(); + + var marker = GetMtmdMarker(); + var prompt = text; + + if (Embeds.Count > 0) + { + if (prompt.Contains("")) { - _imageEmbedHandles.Add(SafeLlavaImageEmbedHandle.CreateFromMemory(ClipModel!.NativeHandle, Context, image)); + prompt = prompt.Replace("", marker); } - int imageIndex = text.IndexOf(""); - // Tokenize segment 1 (before tag) - string preImagePrompt = text.Substring(0, imageIndex); - var segment1 = Context.Tokenize(preImagePrompt, addBos, true); - // Remember the position to add the image embeddings - _EmbedImagePosition = segment1.Length; - string postImagePrompt = text.Substring(imageIndex + 7); - var segment2 = Context.Tokenize(postImagePrompt, false, true); - _embed_inps.AddRange(segment1); - _embed_inps.AddRange(segment2); + if (!prompt.Contains(marker)) + { + var suffix = string.Concat(Enumerable.Repeat(marker, Embeds.Count)); // Ensure tokenizer sees one marker per embed. + prompt = string.Concat(prompt, suffix); + } } - else + + SafeMtmdInputChunks? chunks = null; + try { + var status = ClipModel.Tokenize(prompt, addBos, parseSpecial: true, out chunks); + if (status != 0 || chunks is null) + { + ClipModel.ClearMedia(); + throw new RuntimeError($"Failed to tokenize multimodal prompt. Status: {status}."); + } + + _mtmdChunks = chunks; // Own the chunk collection until evaluation completes. 
+ + var tokens = new List(); + foreach (var chunk in chunks.Enumerate()) + { + using var scopedChunk = chunk; + if (scopedChunk.Type != SafeMtmdInputChunk.SafeMtmdInputChunkType.Text) + { + continue; + } + + foreach (var token in scopedChunk.GetTextTokensSpan()) + { + tokens.Add(unchecked((int)token)); + } + } + + var totalPositions = (int)ClipModel.CountPositions(chunks); + var fillerToken = GetFillerToken(marker); + if (addBos) { - _embed_inps = Context.Tokenize(text, true, true).ToList(); + _embed_inps = BuildTokensWithFiller(tokens, totalPositions, fillerToken); + _consumedTokensCount = 0; } else { - var line_inp = Context.Tokenize(text, false, true); - _embed_inps.AddRange(line_inp); - args.RemainedTokens -= line_inp.Length; + if (_embed_inps.Count == 0) + _embed_inps = new List(); + + _embed_inps.AddRange(tokens); + var fillerCount = totalPositions - tokens.Count; + if (fillerCount > 0) + _embed_inps.AddRange(Enumerable.Repeat(fillerToken, fillerCount)); + + args.RemainedTokens -= tokens.Count; } } + catch + { + chunks?.Dispose(); + _mtmdChunks = null; + throw; + } + finally + { + DisposeEmbeds(); // Flush any embeds decoded in prior step; MTMD replays them via chunk eval. + } + + return Task.CompletedTask; } /// - /// Return whether to break the generation. + /// Decide whether generation should stop based on antiprompts, token limits, or end-of-generation markers. /// - /// - /// + /// Sampling parameters controlling generation. + /// Mutable inference state. /// - /// + /// Tuple describing whether to stop and any additional outputs to emit. protected override Task<(bool, IReadOnlyList)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken) { if (_embed_inps.Count <= _consumedTokensCount) @@ -264,51 +393,87 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In HandleRunOutOfContext(tokensToKeep); } - TryReuseMatchingPrefix(); - - // Changes to support Multi-Modal LLMs. - // - (DecodeResult, int, int) header, end, result; - if (IsMultiModal && _EmbedImagePosition > 0) + if (_mtmdChunks is null) { - // Tokens previous to the images - header = await Context.DecodeAsync(_embeds.GetRange(0, _EmbedImagePosition), LLamaSeqId.Zero, batch, _pastTokensCount); - _pastTokensCount = header.Item3; + TryReuseMatchingPrefix(); + } - if (header.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(header.Item1); + if (IsMultiModal && _mtmdChunks is not null) + { + var nPast = (long)_pastTokensCount; + var previousConsumed = _consumedTokensCount; + var evalStatus = ClipModel!.EvaluateChunks(_mtmdChunks, Context.NativeHandle, ref nPast, seqId: 0, + nBatch: checked((int)Context.BatchSize), logitsLast: true); + if (evalStatus != 0) + { + _logger?.LogError("[InteractiveExecutor] Failed to evaluate multimodal chunks. Status: {Status}", evalStatus); + DisposeMtmdChunks(); + throw new RuntimeError($"Failed to evaluate multimodal chunks. 
Status: {evalStatus}."); + } - // Images - foreach (var image in _imageEmbedHandles) - ClipModel!.EvalImageEmbed(Context, image, ref _pastTokensCount); + _pastTokensCount = checked((int)nPast); + DisposeMtmdChunks(); - // Post-image Tokens - end = await Context.DecodeAsync(_embeds.GetRange(_EmbedImagePosition, _embeds.Count - _EmbedImagePosition), LLamaSeqId.Zero, batch, _pastTokensCount); - _pastTokensCount = end.Item3; + if (!string.IsNullOrEmpty(_pathSession) && _embed_inps.Count > previousConsumed) + { + _session_tokens.AddRange(_embed_inps.Skip(previousConsumed)); + _n_session_consumed = _session_tokens.Count; + } - _EmbedImagePosition = -1; - _imageEmbedHandles.Clear(); - Images.Clear(); + _consumedTokensCount = _embed_inps.Count; + _embeds.Clear(); } else { - result = await Context.DecodeAsync(_embeds, LLamaSeqId.Zero, batch, _pastTokensCount); + var result = await Context.DecodeAsync(_embeds, LLamaSeqId.Zero, batch, _pastTokensCount); _pastTokensCount = result.Item3; if (result.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(result.Item1); + + if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession)) + { + _session_tokens.AddRange(_embeds); + _n_session_consumed = _session_tokens.Count; + } + } + } + else if (IsMultiModal && _mtmdChunks is not null) + { + _is_prompt_run = false; + var nPast = (long)_pastTokensCount; + var previousConsumed = _consumedTokensCount; + var evalStatus = ClipModel!.EvaluateChunks(_mtmdChunks, Context.NativeHandle, ref nPast, seqId: 0, nBatch: checked((int)Context.BatchSize), logitsLast: true); + if (evalStatus != 0) + { + _logger?.LogError("[InteractiveExecutor] Failed to evaluate multimodal chunks. Status: {Status}", evalStatus); + DisposeMtmdChunks(); + throw new RuntimeError($"Failed to evaluate multimodal chunks. Status: {evalStatus}."); } + _pastTokensCount = checked((int)nPast); + DisposeMtmdChunks(); - if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession)) + if (!string.IsNullOrEmpty(_pathSession) && _embed_inps.Count > previousConsumed) { - _session_tokens.AddRange(_embeds); + _session_tokens.AddRange(_embed_inps.Skip(previousConsumed)); _n_session_consumed = _session_tokens.Count; } - } + _consumedTokensCount = _embed_inps.Count; + _embeds.Clear(); + } + _embeds.Clear(); if (_embed_inps.Count <= _consumedTokensCount && !args.WaitForInput) { + if (inferenceParams.MaxTokens == 0) + { + _embeds.Clear(); + args.WaitForInput = true; + args.ReturnValue = false; + return; + } // optionally save the session on first sample (for faster prompt loading next time) if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession) { @@ -355,10 +520,10 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In } /// - /// The descriptor of the state of the interactive executor. + /// Serializable state specific to the interactive executor. /// public class InteractiveExecutorState - : ExecutorBaseState + : StatefulExecutorBase.ExecutorBaseState { /// /// Whether the executor is running for the first time (running the prompt). diff --git a/LLama/LLamaSharp.csproj b/LLama/LLamaSharp.csproj index e91436b89..63948b596 100644 --- a/LLama/LLamaSharp.csproj +++ b/LLama/LLamaSharp.csproj @@ -3,7 +3,7 @@ netstandard2.0;net8.0 LLama enable - 12 + 13 AnyCPU;x64;Arm64 True @@ -17,7 +17,7 @@ https://avatars3.githubusercontent.com/u/44989469?s=200&v=4 LLama, LLM, GPT, ChatGPT, NLP, AI, Chat Bot, SciSharp - LLamaSharp is a cross-platform library to run 🦙LLaMA/LLaVA model (and others) in your local device. 
+ LLamaSharp is a cross-platform library to run 🦙LLaMA/Mtmd model (and others) in your local device. Based on [llama.cpp](https://github.com/ggerganov/llama.cpp), inference with LLamaSharp is efficient on both CPU and GPU. With the higher-level APIs and RAG support, it's convenient to deploy LLM (Large Language Model) in your application with LLamaSharp. diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs index 8f9b40cc3..94bc60830 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -28,10 +28,10 @@ public class StatelessExecutor public bool IsMultiModal => false; /// - public LLavaWeights? ClipModel => default; + public SafeMtmdWeights? ClipModel => default; /// - public List Images { get; } + public List Embeds { get; } /// /// The context used by the executor when running the inference. @@ -57,7 +57,7 @@ public class StatelessExecutor /// public StatelessExecutor(LLamaWeights weights, IContextParams @params, ILogger? logger = null) { - Images = [ ]; + Embeds = [ ]; _weights = weights; _params = @params; _logger = logger; diff --git a/LLama/LLavaWeights.cs b/LLama/LLavaWeights.cs deleted file mode 100644 index f2f9f6256..000000000 --- a/LLama/LLavaWeights.cs +++ /dev/null @@ -1,137 +0,0 @@ - -using System; -using System.Threading; -using System.Threading.Tasks; -using LLama.Native; - -namespace LLama; - -/// -/// A set of llava model weights (mmproj), loaded into memory. -/// -public sealed class LLavaWeights - : IDisposable -{ - /// - /// The native handle, which is used in the native APIs - /// - /// Be careful how you use this! - public SafeLlavaModelHandle NativeHandle { get; } - - private LLavaWeights(SafeLlavaModelHandle weights) - { - NativeHandle = weights; - } - - #region load - /// - /// Load weights into memory - /// - /// path to the "mmproj" model file - /// - public static LLavaWeights LoadFromFile(string mmProject) - { - var weights = SafeLlavaModelHandle.LoadFromFile(mmProject, 1); - return new LLavaWeights(weights); - } - - /// - /// Load weights into memory - /// - /// path to the "mmproj" model file - /// - /// - public static Task LoadFromFileAsync(string mmProject, CancellationToken token = default) - { - return Task.Run(() => LoadFromFile(mmProject), token); - } - #endregion - - #region embed - /// - /// Create the Image Embeddings from the bytes of an image. - /// - /// - /// Image bytes. Supported formats: - /// - /// JPG - /// PNG - /// BMP - /// TGA - /// - /// - /// - public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, byte[] image) - { - return NativeHandle.CreateImageEmbeddings(ctxLlama, image); - } - - /// - /// Create the Image Embeddings. - /// - /// Image in binary format (it supports jpeg format only) - /// Number of threads to use - /// return the SafeHandle of these embeddings - public SafeLlavaImageEmbedHandle CreateImageEmbeddings(byte[] image, int threads = -1) - { - return NativeHandle.CreateImageEmbeddings(image, threads); - } - - /// - /// Create the Image Embeddings from the bytes of an image. - /// - /// - /// Path to the image file. Supported formats: - /// - /// JPG - /// PNG - /// BMP - /// TGA - /// - /// - /// - /// - public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, string image) - { - return NativeHandle.CreateImageEmbeddings(ctxLlama, image); - } - - /// - /// Create the Image Embeddings from the bytes of an image. - /// - /// Path to the image file. 
Supported formats: - /// - /// JPG - /// PNG - /// BMP - /// TGA - /// - /// - /// - /// - /// - public SafeLlavaImageEmbedHandle CreateImageEmbeddings(string image, int threads = -1) - { - return NativeHandle.CreateImageEmbeddings(image, threads); - } - #endregion - - /// - /// Eval the image embeddings - /// - /// - /// - /// - /// - public bool EvalImageEmbed(LLamaContext ctxLlama, SafeLlavaImageEmbedHandle imageEmbed, ref int n_past) - { - return NativeHandle.EvalImageEmbed( ctxLlama, imageEmbed, ref n_past ); - } - - /// - public void Dispose() - { - NativeHandle.Dispose(); - } - -} \ No newline at end of file diff --git a/LLama/Native/LLavaImageEmbed.cs b/LLama/Native/LLavaImageEmbed.cs deleted file mode 100644 index 65eba230c..000000000 --- a/LLama/Native/LLavaImageEmbed.cs +++ /dev/null @@ -1,19 +0,0 @@ -namespace LLama.Native; - -/// -/// LLaVa Image embeddings -/// -/// llava_image_embed -[StructLayout(LayoutKind.Sequential)] -public unsafe struct LLavaImageEmbed -{ - /// - /// The embeddings of the embedded image. - /// - public float* embed; - - /// - /// The position of the image's tokens. - /// - public int n_image_pos; -} \ No newline at end of file diff --git a/LLama/Native/Load/NativeLibraryConfig.cs b/LLama/Native/Load/NativeLibraryConfig.cs index c20453e27..723717c23 100644 --- a/LLama/Native/Load/NativeLibraryConfig.cs +++ b/LLama/Native/Load/NativeLibraryConfig.cs @@ -299,15 +299,15 @@ public sealed partial class NativeLibraryConfig public static NativeLibraryConfig LLama { get; } /// - /// Configuration for LLava native library + /// Configuration for Mtmd native library /// - public static NativeLibraryConfig LLava { get; } + public static NativeLibraryConfig Mtmd { get; } static NativeLibraryConfig() { LLama = new(NativeLibraryName.LLama); - LLava = new(NativeLibraryName.LLava); - All = new(LLama, LLava); + Mtmd = new(NativeLibraryName.Mtmd); + All = new(LLama, Mtmd); } #if NETSTANDARD2_0 @@ -413,9 +413,9 @@ public void ForEach(Action action) /// When this method is called, all the other configurations will be ignored. /// /// The full path to the llama library to load. - /// The full path to the llava library to load. + /// The full path to the mtmd library to load. /// Thrown if `LibraryHasLoaded` is true. - public NativeLibraryConfigContainer WithLibrary(string? llamaPath, string? llavaPath) + public NativeLibraryConfigContainer WithLibrary(string? llamaPath, string? mtmdPath) { foreach(var config in _configs) { @@ -423,9 +423,9 @@ public NativeLibraryConfigContainer WithLibrary(string? llamaPath, string? llava { config.WithLibrary(llamaPath); } - if(config.NativeLibraryName == NativeLibraryName.LLava && llavaPath is not null) + if(config.NativeLibraryName == NativeLibraryName.Mtmd && mtmdPath is not null) { - config.WithLibrary(llavaPath); + config.WithLibrary(mtmdPath); } } @@ -594,7 +594,7 @@ public NativeLibraryConfigContainer WithLogCallback(ILogger? logger) /// You can still modify the configuration after this calling but only before any call from . /// /// Whether the running is successful. - public bool DryRun(out INativeLibrary? loadedLLamaNativeLibrary, out INativeLibrary? loadedLLavaNativeLibrary) + public bool DryRun(out INativeLibrary? loadedLLamaNativeLibrary, out INativeLibrary? loadedMtmdNativeLibrary) { bool success = true; foreach(var config in _configs) @@ -604,16 +604,16 @@ public bool DryRun(out INativeLibrary? 
loadedLLamaNativeLibrary, out INativeLibr { loadedLLamaNativeLibrary = loadedLibrary; } - else if(config.NativeLibraryName == NativeLibraryName.LLava) + else if(config.NativeLibraryName == NativeLibraryName.Mtmd) { - loadedLLavaNativeLibrary = loadedLibrary; + loadedMtmdNativeLibrary = loadedLibrary; } else { throw new Exception("Unknown native library config during the dry run."); } } - loadedLLamaNativeLibrary = loadedLLavaNativeLibrary = null; + loadedLLamaNativeLibrary = loadedMtmdNativeLibrary = null; return success; } } @@ -628,9 +628,9 @@ public enum NativeLibraryName /// LLama, /// - /// The native library compiled from the LLaVA example of llama.cpp. + /// The native library compiled from the MTMD library of llama.cpp. /// - LLava + Mtmd } internal static class LibraryNameExtensions @@ -641,8 +641,8 @@ public static string GetLibraryName(this NativeLibraryName name) { case NativeLibraryName.LLama: return NativeApi.libraryName; - case NativeLibraryName.LLava: - return NativeApi.llavaLibraryName; + case NativeLibraryName.Mtmd: + return NativeApi.mtmdLibraryName; default: throw new ArgumentOutOfRangeException(nameof(name), name, null); } diff --git a/LLama/Native/Load/NativeLibraryUtils.cs b/LLama/Native/Load/NativeLibraryUtils.cs index 9f6457cd1..84ababc60 100644 --- a/LLama/Native/Load/NativeLibraryUtils.cs +++ b/LLama/Native/Load/NativeLibraryUtils.cs @@ -9,7 +9,7 @@ namespace LLama.Native internal static class NativeLibraryUtils { /// - /// Try to load libllama/llava_shared, using CPU feature detection to try and load a more specialised DLL if possible + /// Try to load libllama/mtmd, using CPU feature detection to try and load a more specialised DLL if possible /// /// The library handle to unload later, or IntPtr.Zero if no library was loaded internal static IntPtr TryLoadLibrary(NativeLibraryConfig config, out INativeLibrary? loadedLibrary) diff --git a/LLama/Native/MtmdContextParams.cs b/LLama/Native/MtmdContextParams.cs new file mode 100644 index 000000000..d83831d85 --- /dev/null +++ b/LLama/Native/MtmdContextParams.cs @@ -0,0 +1,148 @@ +using System; +using System.Runtime.InteropServices; +using System.Text; + +namespace LLama.Native; + +/// +/// Managed representation of the native mtmd_context_params structure used to configure multimodal helpers. +/// +public class MtmdContextParams +{ + /// + /// Whether GPU acceleration should be requested when available. + /// + public bool UseGpu { get; set; } + + /// + /// Whether timing information should be emitted by the native helper. + /// + public bool PrintTimings { get; set; } + + /// + /// Number of worker threads to dedicate to preprocessing and tokenization. + /// + public int NThreads { get; set; } + + /// + /// Verbosity level forwarded to llama.cpp logging (matches ggml_log_level). + /// + public int Verbosity { get; set; } + + /// + /// Marker token inserted into the text stream to reference an image embedding. + /// + public string? ImageMarker { get; set; } + + /// + /// Marker token inserted into the text stream to reference a generic media embedding. + /// + public string? MediaMarker { get; set; } + + /// + /// Create a managed copy of the native defaults returned by . 
+ /// + public static MtmdContextParams Default() + { + var native = NativeApi.mtmd_context_params_default(); + return new MtmdContextParams + { + UseGpu = native.use_gpu, + PrintTimings = native.print_timings, + NThreads = native.n_threads, + Verbosity = native.verbosity, + ImageMarker = PtrToString(native.image_marker), + MediaMarker = PtrToString(native.media_marker) + }; + } + + private static string? PtrToString(IntPtr ptr) + { + if (ptr == IntPtr.Zero) + return null; + +#if NETSTANDARD2_0 + unsafe + { + var length = 0; + var current = (byte*)ptr; + while (current[length] != 0) + length++; + + if (length == 0) + return string.Empty; + + var buffer = new byte[length]; + Marshal.Copy(ptr, buffer, 0, length); + return Encoding.UTF8.GetString(buffer); + } +#else + return Marshal.PtrToStringUTF8(ptr); +#endif + } + + /// + /// Convert the managed representation to a native structure, pinning strings for the duration of the scope. + /// + internal NativeScope ToNativeScope() => new(this); + + internal readonly struct NativeScope : IDisposable + { + public NativeApi.mtmd_context_params Value { get; } + + private readonly PinnedUtf8String? _imageMarker; + private readonly PinnedUtf8String? _mediaMarker; + + public NativeScope(MtmdContextParams managed) + { + _imageMarker = PinnedUtf8String.Create(managed.ImageMarker); + _mediaMarker = PinnedUtf8String.Create(managed.MediaMarker); + + var native = NativeApi.mtmd_context_params_default(); + native.use_gpu = managed.UseGpu; + native.print_timings = managed.PrintTimings; + native.n_threads = managed.NThreads; + native.verbosity = managed.Verbosity; + + if (_imageMarker is not null) + native.image_marker = _imageMarker.Pointer; + if (_mediaMarker is not null) + native.media_marker = _mediaMarker.Pointer; + + Value = native; + } + + public void Dispose() + { + _imageMarker?.Dispose(); + _mediaMarker?.Dispose(); + } + } +} + +/// +/// Helper that pins a managed string as UTF-8 for the lifetime of the instance. +/// +internal sealed class PinnedUtf8String : IDisposable +{ + private readonly byte[]? _buffer; + private readonly GCHandle _handle; + + private PinnedUtf8String(string value) + { + var bytes = Encoding.UTF8.GetBytes(value); + _buffer = new byte[bytes.Length + 1]; + Buffer.BlockCopy(bytes, 0, _buffer, 0, bytes.Length); + _handle = GCHandle.Alloc(_buffer, GCHandleType.Pinned); + } + + public static PinnedUtf8String? Create(string? value) => value is null ? null : new PinnedUtf8String(value); + + public IntPtr Pointer => _buffer is null ? IntPtr.Zero : _handle.AddrOfPinnedObject(); + + public void Dispose() + { + if (_buffer is not null && _handle.IsAllocated) + _handle.Free(); + } +} diff --git a/LLama/Native/MtmdImageEmbed.cs b/LLama/Native/MtmdImageEmbed.cs new file mode 100644 index 000000000..7341b8563 --- /dev/null +++ b/LLama/Native/MtmdImageEmbed.cs @@ -0,0 +1,20 @@ +using System.Runtime.InteropServices; + +namespace LLama.Native; + +/// +/// Representation of the native llava_image_embed structure used to return image embeddings. +/// +[StructLayout(LayoutKind.Sequential)] +public unsafe struct MtmdImageEmbed +{ + /// + /// Pointer to the embedding buffer for the decoded image. + /// + public float* embed; + + /// + /// Number of sequence positions consumed by the image tokens associated with the embedding. 
+ /// + public int n_image_pos; +} diff --git a/LLama/Native/NativeApi.LLava.cs b/LLama/Native/NativeApi.LLava.cs deleted file mode 100644 index 692e3f0ad..000000000 --- a/LLama/Native/NativeApi.LLava.cs +++ /dev/null @@ -1,63 +0,0 @@ -using System; - -namespace LLama.Native; - -public static partial class NativeApi -{ - /// - /// Sanity check for clip <-> llava embed size match - /// - /// LLama Context - /// Llava Model - /// True if validate successfully - [DllImport(llavaLibraryName, EntryPoint = "llava_validate_embed_size", CallingConvention = CallingConvention.Cdecl)] - [return: MarshalAs(UnmanagedType.U1)] - public static extern bool llava_validate_embed_size( SafeLLamaContextHandle ctxLlama, SafeLlavaModelHandle ctxClip); - - /// - /// Build an image embed from image file bytes - /// - /// SafeHandle to the Clip Model - /// Number of threads - /// Binary image in jpeg format - /// Bytes length of the image - /// SafeHandle to the Embeddings - [DllImport(llavaLibraryName, EntryPoint = "llava_image_embed_make_with_bytes", - CallingConvention = CallingConvention.Cdecl)] - public static extern - SafeLlavaImageEmbedHandle llava_image_embed_make_with_bytes(SafeLlavaModelHandle ctx_clip, int n_threads, - byte[] image_bytes, int image_bytes_length); - - /// - /// Build an image embed from a path to an image filename - /// - /// SafeHandle to the Clip Model - /// Number of threads - /// Image filename (jpeg) to generate embeddings - /// SafeHandle to the embeddings - [DllImport(llavaLibraryName, EntryPoint = "llava_image_embed_make_with_filename", CallingConvention = CallingConvention.Cdecl)] - public static extern - SafeLlavaImageEmbedHandle llava_image_embed_make_with_filename(SafeLlavaModelHandle ctx_clip, int n_threads, - [MarshalAs(UnmanagedType.LPStr)] string image_path); - - /// - /// Free an embedding made with llava_image_embed_make_* - /// - /// Embeddings to release - [DllImport(llavaLibraryName, EntryPoint = "llava_image_embed_free", CallingConvention = CallingConvention.Cdecl)] - public static extern void llava_image_embed_free(IntPtr embed); - - /// - /// Write the image represented by embed into the llama context with batch size n_batch, starting at context - /// pos n_past. on completion, n_past points to the next position in the context after the image embed. - /// - /// Llama Context - /// Embedding handle - /// - /// - /// True on success - [DllImport(llavaLibraryName, EntryPoint = "llava_eval_image_embed", CallingConvention = CallingConvention.Cdecl)] - [return: MarshalAs(UnmanagedType.U1)] - public static extern bool llava_eval_image_embed(SafeLLamaContextHandle ctx_llama, SafeLlavaImageEmbedHandle embed, int n_batch, ref int n_past); - -} \ No newline at end of file diff --git a/LLama/Native/NativeApi.Load.cs b/LLama/Native/NativeApi.Load.cs index 4555ed0d2..57bb2d146 100644 --- a/LLama/Native/NativeApi.Load.cs +++ b/LLama/Native/NativeApi.Load.cs @@ -16,7 +16,7 @@ static NativeApi() // Set flag to indicate that this point has been passed. No native library config can be done after this point. NativeLibraryConfig.LLama.LibraryHasLoaded = true; - NativeLibraryConfig.LLava.LibraryHasLoaded = true; + NativeLibraryConfig.Mtmd.LibraryHasLoaded = true; // Immediately make a call which requires loading the llama DLL. This method call // can't fail unless the DLL hasn't been loaded. 
@@ -45,7 +45,7 @@ static NativeApi() #if NET5_0_OR_GREATER private static IntPtr _loadedLlamaHandle; - private static IntPtr _loadedLlavaSharedHandle; + private static IntPtr _loadedMtmdHandle; #endif private static void SetDllImportResolver() @@ -72,15 +72,15 @@ private static void SetDllImportResolver() return _loadedLlamaHandle; } - if (name == "llava_shared") + if (name == "mtmd") { - // If we've already loaded llava return the handle that was loaded last time. - if (_loadedLlavaSharedHandle != IntPtr.Zero) - return _loadedLlavaSharedHandle; + // If we've already loaded Mtmd return the handle that was loaded last time. + if (_loadedMtmdHandle != IntPtr.Zero) + return _loadedMtmdHandle; // Try to load a preferred library, based on CPU feature detection - _loadedLlavaSharedHandle = NativeLibraryUtils.TryLoadLibrary(NativeLibraryConfig.LLava, out _loadedLLavaLibrary); - return _loadedLlavaSharedHandle; + _loadedMtmdHandle = NativeLibraryUtils.TryLoadLibrary(NativeLibraryConfig.Mtmd, out _loadedMtmdLibrary); + return _loadedMtmdHandle; } // Return null pointer to indicate that nothing was loaded. @@ -100,17 +100,17 @@ private static void SetDllImportResolver() return name switch { NativeLibraryName.LLama => _loadedLLamaLibrary, - NativeLibraryName.LLava => _loadedLLavaLibrary, + NativeLibraryName.Mtmd => _loadedMtmdLibrary, _ => throw new ArgumentException($"Library name {name} is not found.") }; } internal const string libraryName = "llama"; - internal const string llavaLibraryName = "llava_shared"; + internal const string mtmdLibraryName = "mtmd"; internal const string ggmlLibraryName = "ggml"; internal const string ggmlBaseLibraryName = "ggml-base"; private static INativeLibrary? _loadedLLamaLibrary = null; - private static INativeLibrary? _loadedLLavaLibrary = null; + private static INativeLibrary? _loadedMtmdLibrary = null; } } diff --git a/LLama/Native/NativeApi.Mtmd.cs b/LLama/Native/NativeApi.Mtmd.cs new file mode 100644 index 000000000..bfd6193c2 --- /dev/null +++ b/LLama/Native/NativeApi.Mtmd.cs @@ -0,0 +1,312 @@ +using System; +using System.Runtime.InteropServices; +using System.Text; + +namespace LLama.Native; + +/// +/// P/Invoke surface for MTMD (multimodal) helpers exposed by llama.cpp. +/// +public static partial class NativeApi +{ + /// + /// Convert a UTF-8 encoded native string pointer into a managed . + /// Returns null when the pointer is zero. + /// + public static string? PtrToStringUtf8(IntPtr ptr) + { + if (ptr == IntPtr.Zero) + return null; + +#if NETSTANDARD2_0 + unsafe + { + var current = (byte*)ptr; + var length = 0; + while (current[length] != 0) + length++; + + if (length == 0) + return string.Empty; + + var buffer = new byte[length]; + Marshal.Copy(ptr, buffer, 0, length); + return Encoding.UTF8.GetString(buffer); + } +#else + return Marshal.PtrToStringUTF8(ptr); +#endif + } + + /// + /// Native context parameters returned by . + /// + [StructLayout(LayoutKind.Sequential)] + internal struct mtmd_context_params + { + [MarshalAs(UnmanagedType.I1)] public bool use_gpu; + [MarshalAs(UnmanagedType.I1)] public bool print_timings; + public int n_threads; + public int verbosity; + public IntPtr image_marker; + public IntPtr media_marker; + } + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_default_marker", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_default_marker(); + + /// + /// Retrieve the default multimodal marker text. + /// + public static string? 
MtmdDefaultMarker() + => PtrToStringUtf8(mtmd_default_marker()); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_context_params_default", CallingConvention = CallingConvention.Cdecl)] + internal static extern mtmd_context_params mtmd_context_params_default(); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_decode_use_non_causal", CallingConvention = CallingConvention.Cdecl)] + [return: MarshalAs(UnmanagedType.I1)] + internal static extern bool mtmd_decode_use_non_causal(IntPtr ctx); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_decode_use_mrope", CallingConvention = CallingConvention.Cdecl)] + [return: MarshalAs(UnmanagedType.I1)] + internal static extern bool mtmd_decode_use_mrope(IntPtr ctx); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_support_vision", CallingConvention = CallingConvention.Cdecl)] + [return: MarshalAs(UnmanagedType.I1)] + internal static extern bool mtmd_support_vision(IntPtr ctx); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_support_audio", CallingConvention = CallingConvention.Cdecl)] + [return: MarshalAs(UnmanagedType.I1)] + internal static extern bool mtmd_support_audio(IntPtr ctx); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_get_audio_bitrate", CallingConvention = CallingConvention.Cdecl)] + internal static extern int mtmd_get_audio_bitrate(IntPtr ctx); + + // bitmap ------------------------------------------------------------ + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_init", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_bitmap_init(uint nx, uint ny, IntPtr data); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_init_from_audio", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_bitmap_init_from_audio(ulong n_samples, IntPtr data); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_get_nx", CallingConvention = CallingConvention.Cdecl)] + internal static extern uint mtmd_bitmap_get_nx(IntPtr bitmap); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_get_ny", CallingConvention = CallingConvention.Cdecl)] + internal static extern uint mtmd_bitmap_get_ny(IntPtr bitmap); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_get_data", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_bitmap_get_data(IntPtr bitmap); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_get_n_bytes", CallingConvention = CallingConvention.Cdecl)] + internal static extern UIntPtr mtmd_bitmap_get_n_bytes(IntPtr bitmap); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_is_audio", CallingConvention = CallingConvention.Cdecl)] + [return: MarshalAs(UnmanagedType.I1)] + internal static extern bool mtmd_bitmap_is_audio(IntPtr bitmap); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_free", CallingConvention = CallingConvention.Cdecl)] + internal static extern void mtmd_bitmap_free(IntPtr bitmap); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_get_id", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_bitmap_get_id(IntPtr bitmap); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_set_id", CallingConvention = CallingConvention.Cdecl)] + private static extern unsafe void mtmd_bitmap_set_id_native(IntPtr bitmap, byte* id); + + /// + /// Assign an identifier to a bitmap using a UTF-8 encoded string. + /// + internal static unsafe void mtmd_bitmap_set_id(IntPtr bitmap, string? 
id) + { + if (bitmap == IntPtr.Zero) + throw new ArgumentNullException(nameof(bitmap)); + + if (id is null) + { + mtmd_bitmap_set_id_native(bitmap, null); + return; + } + + using var pinned = PinnedUtf8String.Create(id) ?? throw new ArgumentNullException(nameof(id)); + mtmd_bitmap_set_id_native(bitmap, (byte*)pinned.Pointer); + } + + // input_chunks ------------------------------------------------------ + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunks_init", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_input_chunks_init(); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunks_size", CallingConvention = CallingConvention.Cdecl)] + internal static extern UIntPtr mtmd_input_chunks_size(IntPtr chunks); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunks_get", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_input_chunks_get(IntPtr chunks, UIntPtr idx); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunks_free", CallingConvention = CallingConvention.Cdecl)] + internal static extern void mtmd_input_chunks_free(IntPtr chunks); + + // input_chunk ------------------------------------------------------- + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunk_get_type", CallingConvention = CallingConvention.Cdecl)] + internal static extern int mtmd_input_chunk_get_type(IntPtr chunk); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunk_get_tokens_text", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_input_chunk_get_tokens_text(IntPtr chunk, out UIntPtr n_tokens); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunk_get_tokens_image", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_input_chunk_get_tokens_image(IntPtr chunk); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunk_get_n_tokens", CallingConvention = CallingConvention.Cdecl)] + internal static extern UIntPtr mtmd_input_chunk_get_n_tokens(IntPtr chunk); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunk_get_id", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_input_chunk_get_id(IntPtr chunk); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunk_get_n_pos", CallingConvention = CallingConvention.Cdecl)] + internal static extern long mtmd_input_chunk_get_n_pos(IntPtr chunk); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunk_copy", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_input_chunk_copy(IntPtr chunk); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunk_free", CallingConvention = CallingConvention.Cdecl)] + internal static extern void mtmd_input_chunk_free(IntPtr chunk); + + // image_tokens ------------------------------------------------------ + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_image_tokens_get_n_tokens", CallingConvention = CallingConvention.Cdecl)] + internal static extern UIntPtr mtmd_image_tokens_get_n_tokens(IntPtr image_tokens); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_image_tokens_get_nx", CallingConvention = CallingConvention.Cdecl)] + internal static extern UIntPtr mtmd_image_tokens_get_nx(IntPtr image_tokens); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_image_tokens_get_ny", CallingConvention = CallingConvention.Cdecl)] + internal static extern UIntPtr mtmd_image_tokens_get_ny(IntPtr image_tokens); + + [DllImport(mtmdLibraryName, EntryPoint = 
"mtmd_image_tokens_get_id", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_image_tokens_get_id(IntPtr image_tokens); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_image_tokens_get_n_pos", CallingConvention = CallingConvention.Cdecl)] + internal static extern long mtmd_image_tokens_get_n_pos(IntPtr image_tokens); + + // tokenize ---------------------------------------------------------- + + /// + /// Native text structure consumed by . + /// + internal unsafe struct mtmd_input_text_native + { + public byte* text; + [MarshalAs(UnmanagedType.I1)] public bool add_special; + [MarshalAs(UnmanagedType.I1)] public bool parse_special; + } + + /// + /// Utility scope that pins managed text while invoking the native tokenizer. + /// + internal readonly unsafe ref struct MtmdInputTextScope + { + public readonly mtmd_input_text_native Value; + private readonly PinnedUtf8String _text; + + public MtmdInputTextScope(string text, bool addSpecial, bool parseSpecial) + { + _text = PinnedUtf8String.Create(text) ?? throw new ArgumentNullException(nameof(text)); + Value = new mtmd_input_text_native + { + text = (byte*)_text.Pointer, + add_special = addSpecial, + parse_special = parseSpecial + }; + } + + public void Dispose() => _text.Dispose(); + } + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_tokenize", CallingConvention = CallingConvention.Cdecl)] + private static extern unsafe int mtmd_tokenize_native( + IntPtr ctx, + IntPtr output, + mtmd_input_text_native* text, + IntPtr[] bitmaps, + UIntPtr n_bitmaps); + + internal static unsafe int mtmd_tokenize(IntPtr ctx, IntPtr output, in mtmd_input_text_native text, IntPtr[] bitmaps, UIntPtr n_bitmaps) + { + var temp = text; + return mtmd_tokenize_native(ctx, output, &temp, bitmaps, n_bitmaps); + } + + internal static unsafe int mtmd_tokenize(IntPtr ctx, IntPtr output, string text, bool addSpecial, bool parseSpecial, IntPtr[] bitmaps, UIntPtr n_bitmaps) + { + using var scope = new MtmdInputTextScope(text, addSpecial, parseSpecial); + return mtmd_tokenize_native(ctx, output, &scope.Value, bitmaps, n_bitmaps); + } + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_encode", CallingConvention = CallingConvention.Cdecl)] + internal static extern int mtmd_encode(IntPtr ctx, IntPtr image_tokens); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_encode_chunk", CallingConvention = CallingConvention.Cdecl)] + internal static extern int mtmd_encode_chunk(IntPtr ctx, IntPtr chunk); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_get_output_embd", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_get_output_embd(IntPtr ctx); + + // helper ------------------------------------------------------------ + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_test_create_input_chunks", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_test_create_input_chunks(); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_helper_bitmap_init_from_file", CallingConvention = CallingConvention.Cdecl)] + private static extern unsafe IntPtr mtmd_helper_bitmap_init_from_file_native(IntPtr ctx, byte* fname); + + internal static unsafe IntPtr mtmd_helper_bitmap_init_from_file(IntPtr ctx, string fname) + { + using var pinned = PinnedUtf8String.Create(fname) ?? 
throw new ArgumentNullException(nameof(fname)); + return mtmd_helper_bitmap_init_from_file_native(ctx, (byte*)pinned.Pointer); + } + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_helper_bitmap_init_from_buf", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_helper_bitmap_init_from_buf(IntPtr ctx, IntPtr buf, UIntPtr len); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_helper_get_n_tokens", CallingConvention = CallingConvention.Cdecl)] + internal static extern UIntPtr mtmd_helper_get_n_tokens(IntPtr chunks); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_helper_get_n_pos", CallingConvention = CallingConvention.Cdecl)] + internal static extern long mtmd_helper_get_n_pos(IntPtr chunks); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_helper_eval_chunks", CallingConvention = CallingConvention.Cdecl)] + internal static extern int mtmd_helper_eval_chunks( + IntPtr ctx, + IntPtr lctx, + IntPtr chunks, + long n_past, + int seq_id, + int n_batch, + [MarshalAs(UnmanagedType.I1)] bool logits_last, + ref long new_n_past); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_helper_eval_chunk_single", CallingConvention = CallingConvention.Cdecl)] + internal static extern int mtmd_helper_eval_chunk_single( + IntPtr ctx, + IntPtr lctx, + IntPtr chunk, + long n_past, + int seq_id, + int n_batch, + [MarshalAs(UnmanagedType.I1)] bool logits_last, + ref long new_n_past); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_helper_decode_image_chunk", CallingConvention = CallingConvention.Cdecl)] + internal static extern int mtmd_helper_decode_image_chunk( + IntPtr ctx, + IntPtr lctx, + IntPtr chunk, + IntPtr encoded_embd, + long n_past, + int seq_id, + int n_batch, + ref long new_n_past); +} diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs index db9e928bd..3123674fc 100644 --- a/LLama/Native/NativeApi.cs +++ b/LLama/Native/NativeApi.cs @@ -1,4 +1,5 @@ using System; +using System.Text; #pragma warning disable IDE1006 // Naming Styles @@ -323,21 +324,115 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback) /// /// /// Returns the split_path length. - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_split_path(string split_path, nuint maxlen, string path_prefix, int split_no, int split_count); + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_split_path")] + private static extern unsafe int llama_split_path_native(byte* split_path, nuint maxlen, byte* path_prefix, int split_no, int split_count); + + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_split_prefix")] + private static extern unsafe int llama_split_prefix_native(byte* split_prefix, nuint maxlen, byte* split_path, int split_no, int split_count); + + private static byte[] EncodeNullTerminatedUtf8(string value, string paramName) + { + if (value is null) + throw new ArgumentNullException(paramName); + + var bytes = Encoding.UTF8.GetBytes(value); + var buffer = new byte[bytes.Length + 1]; + Buffer.BlockCopy(bytes, 0, buffer, 0, bytes.Length); + // buffer[^1] = 0; + return buffer; + } /// - /// Extract the path prefix from the split_path if and only if the split_no and split_count match. - /// llama_split_prefix(split_prefix, 64, "/models/ggml-model-q4_0-00002-of-00004.gguf", 2, 4) => split_prefix = "/models/ggml-model-q4_0" + /// Build the fully-qualified path for a specific split file in a GGUF shard set. 
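+    /// e.g. llama_split_path(buffer, "/models/ggml-model-q4_0", 2, 4) writes "/models/ggml-model-q4_0-00002-of-00004.gguf".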
/// - /// - /// - /// - /// - /// - /// Returns the split_prefix length. - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_split_prefix(string split_prefix, nuint maxlen, string split_path, int split_no, int split_count); + /// Writable buffer that receives the UTF-8 encoded path. + /// Base path (e.g. "/models/ggml-model-q4_0"). + /// Zero-based split index. + /// Total number of splits. + /// Number of bytes written to . + public static int llama_split_path(Span splitPathBuffer, string pathPrefix, int splitNo, int splitCount) + { + if (splitPathBuffer.Length == 0) + throw new ArgumentException("Buffer must not be empty.", nameof(splitPathBuffer)); + + var pathPrefixBytes = EncodeNullTerminatedUtf8(pathPrefix, nameof(pathPrefix)); + + unsafe + { + fixed (byte* splitPtr = splitPathBuffer) + fixed (byte* prefixPtr = pathPrefixBytes) + { + return llama_split_path_native(splitPtr, (nuint)splitPathBuffer.Length, prefixPtr, splitNo, splitCount); + } + } + } + + /// + /// Build the fully-qualified path for a specific split file in a GGUF shard set. + /// + /// Base path (e.g. "/models/ggml-model-q4_0"). + /// Zero-based split index. + /// Total number of splits. + /// Maximum number of bytes to allocate for the resulting UTF-8 string. + /// UTF-8 decoded split path. + public static string llama_split_path(string pathPrefix, int splitNo, int splitCount, int maxLength = 1024) + { + if (maxLength <= 0) + throw new ArgumentOutOfRangeException(nameof(maxLength)); + + var buffer = new byte[maxLength]; + var written = llama_split_path((Span)buffer, pathPrefix, splitNo, splitCount); + if (written <= 0) + throw new InvalidOperationException("Failed to build split path using llama_split_path."); + + return Encoding.UTF8.GetString(buffer, 0, written); + } + + /// + /// Extract the shard prefix from a GGUF split path when the split metadata matches. + /// + /// Writable buffer that receives the UTF-8 encoded prefix. + /// Full path to a shard file. + /// Zero-based split index. + /// Total number of splits. + /// Number of bytes written to . + public static int llama_split_prefix(Span splitPrefixBuffer, string splitPath, int splitNo, int splitCount) + { + if (splitPrefixBuffer.Length == 0) + throw new ArgumentException("Buffer must not be empty.", nameof(splitPrefixBuffer)); + + var splitPathBytes = EncodeNullTerminatedUtf8(splitPath, nameof(splitPath)); + + unsafe + { + fixed (byte* prefixPtr = splitPrefixBuffer) + fixed (byte* pathPtr = splitPathBytes) + { + return llama_split_prefix_native(prefixPtr, (nuint)splitPrefixBuffer.Length, pathPtr, splitNo, splitCount); + } + } + } + + /// + /// Extract the shard prefix from a GGUF split path when the split metadata matches. + /// + /// Full path to a shard file. + /// Zero-based split index. + /// Total number of splits. + /// Maximum number of bytes to allocate for the resulting UTF-8 string. + /// UTF-8 decoded split prefix. 
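+    /// e.g. llama_split_prefix("/models/ggml-model-q4_0-00002-of-00004.gguf", 2, 4) returns "/models/ggml-model-q4_0".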
+ public static string llama_split_prefix(string splitPath, int splitNo, int splitCount, int maxLength = 1024) + { + if (maxLength <= 0) + throw new ArgumentOutOfRangeException(nameof(maxLength)); + + var buffer = new byte[maxLength]; + var written = llama_split_prefix((Span)buffer, splitPath, splitNo, splitCount); + if (written <= 0) + throw new InvalidOperationException("Failed to extract split prefix using llama_split_prefix."); + + return Encoding.UTF8.GetString(buffer, 0, written); + } //[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] //todo: public static void llama_attach_threadpool(SafeLLamaContextHandle ctx, ggml_threadpool_t threadpool, ggml_threadpool_t threadpool_batch); diff --git a/LLama/Native/SafeLlavaImageEmbedHandle.cs b/LLama/Native/SafeLlavaImageEmbedHandle.cs deleted file mode 100644 index 102c4b93f..000000000 --- a/LLama/Native/SafeLlavaImageEmbedHandle.cs +++ /dev/null @@ -1,162 +0,0 @@ -using System; -using System.IO; - - -namespace LLama.Native -{ - /// - /// A Reference to a llava Image Embed handle - /// - public sealed class SafeLlavaImageEmbedHandle - : SafeLLamaHandleBase - { - /// - /// Get the model used to create this image embedding - /// - public SafeLlavaModelHandle Model { get; private set; } = null!; - - /// - /// Get the number of dimensions in an embedding - /// - public int EmbeddingDimensions => Model.EmbeddingDimensions; - - /// - /// Get the number of "patches" in an image embedding - /// - public int PatchCount => Model.PatchCount; - - #region embed - /// - /// Create an image embed from an image file - /// - /// - /// - /// Path to the image file. Supported formats: - /// - /// JPG - /// PNG - /// BMP - /// TGA - /// - /// - /// - /// - public static SafeLlavaImageEmbedHandle CreateFromFileName(SafeLlavaModelHandle clip, LLamaContext ctx, string image) - { - if (!NativeApi.llava_validate_embed_size(ctx.NativeHandle, clip)) - throw new InvalidOperationException($"Cannot create image embed. Embedding dim of the multimodal projector ({clip.EmbeddingDimensions}) is not equal to embedding dim of model ({ctx.EmbeddingSize})"); - - return CreateFromFileName(clip, image, (int)ctx.BatchThreads); - } - - /// - /// Create an image embed from an image file - /// - /// - /// Path to the image file. Supported formats: - /// - /// JPG - /// PNG - /// BMP - /// TGA - /// - /// - /// - /// - /// - public static SafeLlavaImageEmbedHandle CreateFromFileName(SafeLlavaModelHandle clip, string image, int threads = -1) - { - if (threads <= 0) - threads = Environment.ProcessorCount / 2; - - // Try to open the image file, this will check: - // - File exists (automatically throws FileNotFoundException) - // - File is readable (explicit check) - // This provides better error messages that llama.cpp, which would throw an access violation exception in both cases. - using (var fs = new FileStream(image, FileMode.Open)) - if (!fs.CanRead) - throw new InvalidOperationException($"Llava image file '{image}' is not readable"); - - var embed = NativeApi.llava_image_embed_make_with_filename(clip, threads, image); - embed.Model = clip; - return embed; - } - - /// - /// Create an image embed from the bytes of an image. - /// - /// - /// - /// Image bytes. 
Supported formats: - /// - /// JPG - /// PNG - /// BMP - /// TGA - /// - /// - /// - public static SafeLlavaImageEmbedHandle CreateFromMemory(SafeLlavaModelHandle clip, LLamaContext ctx, byte[] image) - { - if (!NativeApi.llava_validate_embed_size(ctx.NativeHandle, clip)) - throw new InvalidOperationException($"Cannot create image embed. Embedding dim of the multimodal projector ({clip.EmbeddingDimensions}) is not equal to embedding dim of model ({ctx.EmbeddingSize})"); - - return CreateFromMemory(clip, image, (int)ctx.BatchThreads); - } - - /// - /// Create an image embed from the bytes of an image. - /// - /// - /// Image bytes. Supported formats: - /// - /// JPG - /// PNG - /// BMP - /// TGA - /// - /// - /// - /// - public static SafeLlavaImageEmbedHandle CreateFromMemory(SafeLlavaModelHandle clip, byte[] image, int threads = -1) - { - if (threads <= 0) - threads = Environment.ProcessorCount / 2; - - var embed = NativeApi.llava_image_embed_make_with_bytes(clip, threads, image, image.Length); - embed.Model = clip; - return embed; - } - #endregion - - /// - protected override bool ReleaseHandle() - { - NativeApi.llava_image_embed_free(DangerousGetHandle()); - SetHandle(IntPtr.Zero); - return true; - } - - /// - /// Copy the embeddings data to the destination span - /// - /// - /// - public void GetEmbedding(Span dest, int index) - { - if (index < 0) - throw new ArgumentOutOfRangeException(nameof(index), "index must be >= 0"); - if (index >= Model.PatchCount) - throw new ArgumentOutOfRangeException(nameof(index), "index must be < Model.PatchCount"); - - unsafe - { - var embed = (LLavaImageEmbed*)DangerousGetHandle(); - new Span( - embed->embed + Model.EmbeddingDimensions * index, - Model.EmbeddingDimensions - ).CopyTo(dest); - } - } - } -} diff --git a/LLama/Native/SafeLlavaModelHandle.cs b/LLama/Native/SafeLlavaModelHandle.cs deleted file mode 100644 index 5b3a910e9..000000000 --- a/LLama/Native/SafeLlavaModelHandle.cs +++ /dev/null @@ -1,137 +0,0 @@ -using System; -using System.IO; -using LLama.Exceptions; - - -namespace LLama.Native -{ - /// - /// A reference to a set of llava model weights. - /// - public sealed class SafeLlavaModelHandle - : SafeLLamaHandleBase - { - /// - /// Get the number of dimensions in an embedding - /// - public int EmbeddingDimensions => clip_n_mmproj_embd(this); - - /// - /// Get the number of "patches" in an image embedding - /// - public int PatchCount => clip_n_patches(this); - - /// - protected override bool ReleaseHandle() - { - clip_free(DangerousGetHandle()); - SetHandle(IntPtr.Zero); - return true; - } - - /// - /// Load a model from the given file path into memory - /// - /// MMP File (Multi-Modal Projections) - /// Verbosity level - /// SafeHandle of the Clip Model - /// - /// - public static SafeLlavaModelHandle LoadFromFile(string modelPath, int verbosity ) - { - // Try to open the model file, this will check: - // - File exists (automatically throws FileNotFoundException) - // - File is readable (explicit check) - // This provides better error messages that llama.cpp, which would throw an access violation exception in both cases. - using (var fs = new FileStream(modelPath, FileMode.Open)) - if (!fs.CanRead) - throw new InvalidOperationException($"Llava MMP Model file '{modelPath}' is not readable"); - - var handle = clip_model_load(modelPath, verbosity); - if (handle.IsInvalid) - throw new LoadWeightsFailedException(modelPath); - - return handle; - } - - /// - /// Create the Image Embeddings. 
- /// - /// LLama Context - /// Image filename (it supports jpeg format only) - /// return the SafeHandle of these embeddings - public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, string image) - { - return SafeLlavaImageEmbedHandle.CreateFromFileName(this, ctxLlama, image); - } - - /// - /// Create the Image Embeddings. - /// - /// Image in binary format (it supports jpeg format only) - /// Number of threads to use - /// return the SafeHandle of these embeddings - public SafeLlavaImageEmbedHandle CreateImageEmbeddings(string image, int threads = -1) - { - return SafeLlavaImageEmbedHandle.CreateFromFileName(this, image, threads); - } - - /// - /// Create the Image Embeddings. - /// - /// LLama Context - /// Image in binary format (it supports jpeg format only) - /// return the SafeHandle of these embeddings - public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, byte[] image) - { - return SafeLlavaImageEmbedHandle.CreateFromMemory(this, ctxLlama, image ); - } - - /// - /// Create the Image Embeddings. - /// - /// Image in binary format (it supports jpeg format only) - /// Number of threads to use - /// return the SafeHandle of these embeddings - public SafeLlavaImageEmbedHandle CreateImageEmbeddings(byte[] image, int threads = -1) - { - return SafeLlavaImageEmbedHandle.CreateFromMemory(this, image, threads); - } - - /// - /// Evaluates the image embeddings. - /// - /// Llama Context - /// The current embeddings to evaluate - /// - /// True on success - public bool EvalImageEmbed(LLamaContext ctxLlama, SafeLlavaImageEmbedHandle imageEmbed, ref int n_past) - { - return NativeApi.llava_eval_image_embed(ctxLlama.NativeHandle, imageEmbed, (int)ctxLlama.BatchSize, ref n_past ); - } - - #region native API - /// - /// Load MULTI MODAL PROJECTIONS model / Clip Model - /// - /// Model path/file - /// Verbosity level - /// SafeLlavaModelHandle - [DllImport(NativeApi.llavaLibraryName, EntryPoint = "clip_model_load", CallingConvention = CallingConvention.Cdecl)] - private static extern SafeLlavaModelHandle clip_model_load(string mmProj, int verbosity); - - /// - /// Frees MULTI MODAL PROJECTIONS model / Clip Model - /// - /// Internal Pointer to the model - [DllImport(NativeApi.llavaLibraryName, EntryPoint = "clip_free", CallingConvention = CallingConvention.Cdecl)] - private static extern void clip_free(IntPtr ctx); - - [DllImport(NativeApi.llavaLibraryName, CallingConvention = CallingConvention.Cdecl)] - private static extern int clip_n_mmproj_embd(SafeLlavaModelHandle ctx); - - [DllImport(NativeApi.llavaLibraryName, CallingConvention = CallingConvention.Cdecl)] - private static extern int clip_n_patches(SafeLlavaModelHandle ctx); - #endregion - } -} diff --git a/LLama/Native/SafeMtmdEmbed.cs b/LLama/Native/SafeMtmdEmbed.cs new file mode 100644 index 000000000..c651db102 --- /dev/null +++ b/LLama/Native/SafeMtmdEmbed.cs @@ -0,0 +1,247 @@ +using System; +using System.IO; +using System.Runtime.InteropServices; + +namespace LLama.Native +{ + /// + /// Managed wrapper around mtmd_bitmap* resources. Instances own the native pointer + /// and ensure proper cleanup when disposed. + /// + public sealed class SafeMtmdEmbed : IDisposable + { + /// + /// Raw pointer to the native bitmap structure. Internal so other wrappers can interop. + /// + internal IntPtr NativePtr { get; private set; } + + private bool _disposed; + + private SafeMtmdEmbed(IntPtr ptr) + { + NativePtr = ptr != IntPtr.Zero + ? 
ptr + : throw new InvalidOperationException("Failed to create MTMD bitmap."); + } + + /// + /// Create an embedding from raw RGB bytes. + /// + /// Width of the bitmap in pixels. + /// Height of the bitmap in pixels. + /// Packed RGB data (3 bytes per pixel). + /// Managed wrapper when initialization succeeds; otherwise null. + /// The RGB buffer is null. + public static SafeMtmdEmbed? FromRgbBytes(uint nx, uint ny, byte[] rgbData) + { + if (rgbData == null) + throw new ArgumentNullException(nameof(rgbData)); + + var handle = GCHandle.Alloc(rgbData, GCHandleType.Pinned); + try + { + var native = NativeApi.mtmd_bitmap_init(nx, ny, handle.AddrOfPinnedObject()); + return native == IntPtr.Zero ? null : new SafeMtmdEmbed(native); + } + finally + { + if (handle.IsAllocated) + handle.Free(); + } + } + + /// + /// Create an embedding from PCM audio samples. + /// + /// Array of mono PCM samples in float format. + /// Managed wrapper when initialization succeeds; otherwise null. + /// The audio buffer is null. + public static SafeMtmdEmbed? FromAudioSamples(float[] samples) + { + if (samples == null) + throw new ArgumentNullException(nameof(samples)); + + var handle = GCHandle.Alloc(samples, GCHandleType.Pinned); + try + { + var native = NativeApi.mtmd_bitmap_init_from_audio((ulong)samples.Length, handle.AddrOfPinnedObject()); + return native == IntPtr.Zero ? null : new SafeMtmdEmbed(native); + } + finally + { + if (handle.IsAllocated) + handle.Free(); + } + } + + /// + /// Create an embedding by decoding a media file using libmtmd helpers. + /// + /// Model context that provides the decoder configuration. + /// Path to the media file on disk. + /// Managed wrapper when decoding succeeds; otherwise null. + /// The context is null. + /// The path is null or whitespace. + /// The supplied file does not exist. + public static SafeMtmdEmbed? FromMediaFile(SafeMtmdModelHandle mtmdContext, string path) + { + if (mtmdContext == null) + throw new ArgumentNullException(nameof(mtmdContext)); + if (string.IsNullOrWhiteSpace(path)) + throw new ArgumentException("Value cannot be null or whitespace.", nameof(path)); + + var fullPath = Path.GetFullPath(path); + if (!File.Exists(fullPath)) + throw new FileNotFoundException("Media file not found.", fullPath); + + bool added = false; + var ctxPtr = IntPtr.Zero; + try + { + // Hold a strong reference to the native context while the helper decodes the media file. + mtmdContext.DangerousAddRef(ref added); + ctxPtr = mtmdContext.DangerousGetHandle(); + var native = NativeApi.mtmd_helper_bitmap_init_from_file(ctxPtr, fullPath); + return native == IntPtr.Zero ? null : new SafeMtmdEmbed(native); + } + finally + { + if (added) + mtmdContext.DangerousRelease(); + } + } + + /// + /// Create an embedding from an in-memory media buffer (image/audio/video). + /// + /// Model context that provides the decoder configuration. + /// Binary buffer containing the encoded media. + /// Managed wrapper when decoding succeeds; otherwise null. + /// The context is null. + /// The buffer is empty. + public static unsafe SafeMtmdEmbed? FromMediaBuffer(SafeMtmdModelHandle mtmdContext, ReadOnlySpan data) + { + if (mtmdContext == null) + throw new ArgumentNullException(nameof(mtmdContext)); + if (data.IsEmpty) + throw new ArgumentException("Buffer must not be empty.", nameof(data)); + + bool added = false; + var ctxPtr = IntPtr.Zero; + try + { + // Keep the context alive while the native helper processes the buffer. 
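+                // DangerousAddRef/DangerousRelease bracket the native call so the SafeHandle's
+                // reference count stays positive and the context cannot be freed mid-call.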
+ mtmdContext.DangerousAddRef(ref added); + ctxPtr = mtmdContext.DangerousGetHandle(); + + fixed (byte* bufferPtr = data) + { + var native = NativeApi.mtmd_helper_bitmap_init_from_buf(ctxPtr, new IntPtr(bufferPtr), (UIntPtr)data.Length); + return native == IntPtr.Zero ? null : new SafeMtmdEmbed(native); + } + } + finally + { + if (added) + mtmdContext.DangerousRelease(); + } + } + + /// + /// Width of the bitmap in pixels (or number of samples for audio embeddings). + /// + public uint Nx + { + get + { + EnsureNotDisposed(); + return NativeApi.mtmd_bitmap_get_nx(NativePtr); + } + } + + /// + /// Height of the bitmap in pixels. For audio embeddings this is typically 1. + /// + public uint Ny + { + get + { + EnsureNotDisposed(); + return NativeApi.mtmd_bitmap_get_ny(NativePtr); + } + } + + /// + /// Indicates whether the embedding stores audio data instead of image pixels. + /// + public bool IsAudio + { + get + { + EnsureNotDisposed(); + return NativeApi.mtmd_bitmap_is_audio(NativePtr); + } + } + + /// + /// Optional identifier assigned to this embedding. + /// + public string? Id + { + get + { + EnsureNotDisposed(); + var ptr = NativeApi.mtmd_bitmap_get_id(NativePtr); + return NativeApi.PtrToStringUtf8(ptr); + } + set + { + EnsureNotDisposed(); + NativeApi.mtmd_bitmap_set_id(NativePtr, value); + } + } + + /// + /// Zero-copy access to the underlying bitmap bytes. The span remains valid while this wrapper is alive. + /// + /// Read-only span exposing the native data buffer. + /// The embedding has been disposed. + public unsafe ReadOnlySpan GetDataSpan() + { + EnsureNotDisposed(); + + var dataPtr = (byte*)NativeApi.mtmd_bitmap_get_data(NativePtr); + var length = checked((int)NativeApi.mtmd_bitmap_get_n_bytes(NativePtr).ToUInt64()); + return dataPtr == null || length == 0 ? ReadOnlySpan.Empty : new ReadOnlySpan(dataPtr, length); + } + + /// + /// Release the underlying native bitmap. + /// + public void Dispose() + { + if (_disposed) + return; + + if (NativePtr != IntPtr.Zero) + { + NativeApi.mtmd_bitmap_free(NativePtr); + NativePtr = IntPtr.Zero; + } + + _disposed = true; + GC.SuppressFinalize(this); + } + + /// + /// Finalizer to ensure native resources are reclaimed when Dispose is not invoked. + /// + ~SafeMtmdEmbed() => Dispose(); + + private void EnsureNotDisposed() + { + if (_disposed || NativePtr == IntPtr.Zero) + throw new ObjectDisposedException(nameof(SafeMtmdEmbed)); + } + } +} diff --git a/LLama/Native/SafeMtmdInputChunk.cs b/LLama/Native/SafeMtmdInputChunk.cs new file mode 100644 index 000000000..59d1897ef --- /dev/null +++ b/LLama/Native/SafeMtmdInputChunk.cs @@ -0,0 +1,150 @@ +using System; +using System.Runtime.InteropServices; + +namespace LLama.Native; + +/// +/// Managed wrapper around a single mtmd_input_chunk. Instances can either own the +/// underlying native pointer (when created via ) or act as non-owning views +/// produced by the tokenizer. +/// +public sealed class SafeMtmdInputChunk : IDisposable +{ + /// + /// Chunk modality returned by the native tokenizer. + /// + public enum SafeMtmdInputChunkType + { + Text = 0, + Image = 1, + Audio = 2 + } + + /// + /// Raw pointer to the native chunk structure. + /// + public IntPtr NativePtr { get; private set; } + + private bool _ownsPtr; + private bool _disposed; + + private SafeMtmdInputChunk(IntPtr ptr, bool owns) + { + NativePtr = ptr; + _ownsPtr = owns; + } + + /// + /// Wrap an existing chunk pointer without taking ownership. + /// + /// Pointer returned by the native tokenizer. 
+ /// Managed wrapper, or null when the pointer is null. + public static SafeMtmdInputChunk Wrap(IntPtr ptr) + => ptr == IntPtr.Zero ? null : new SafeMtmdInputChunk(ptr, false); + + /// + /// Create an owning copy of the current chunk. The caller becomes responsible for disposal. + /// + /// Owning managed wrapper, or null if the native copy failed. + /// Thrown when the current wrapper has been disposed. + public SafeMtmdInputChunk Copy() + { + EnsureNotDisposed(); + + var p = NativeApi.mtmd_input_chunk_copy(NativePtr); + return p == IntPtr.Zero ? null : new SafeMtmdInputChunk(p, true); + } + + /// + /// Chunk modality reported by the native helper. + /// + public SafeMtmdInputChunkType Type + { + get + { + EnsureNotDisposed(); + return (SafeMtmdInputChunkType)NativeApi.mtmd_input_chunk_get_type(NativePtr); + } + } + + /// + /// Number of tokens contained in this chunk. + /// + public ulong NTokens + { + get + { + EnsureNotDisposed(); + return NativeApi.mtmd_input_chunk_get_n_tokens(NativePtr).ToUInt64(); + } + } + + /// + /// Identifier assigned by the tokenizer (if any). + /// + public string Id + { + get + { + EnsureNotDisposed(); + return Marshal.PtrToStringAnsi(NativeApi.mtmd_input_chunk_get_id(NativePtr)) ?? string.Empty; + } + } + + /// + /// Number of positional slots consumed by this chunk. + /// + public long NPos + { + get + { + EnsureNotDisposed(); + return NativeApi.mtmd_input_chunk_get_n_pos(NativePtr); + } + } + + /// + /// Zero-copy view over the chunk's token buffer. The span remains valid only while the native chunk is alive. + /// + /// Read-only span exposing the chunk's tokens. + /// Thrown when the wrapper has been disposed. + public unsafe ReadOnlySpan GetTextTokensSpan() + { + EnsureNotDisposed(); + + UIntPtr n; + var p = (uint*)NativeApi.mtmd_input_chunk_get_tokens_text(NativePtr, out n); + return p == null ? ReadOnlySpan.Empty : new ReadOnlySpan(p, checked((int)n.ToUInt64())); + } + + /// + /// Release the underlying native resources if this instance owns them. + /// + public void Dispose() + { + if (_disposed) + return; + + if (_ownsPtr && NativePtr != IntPtr.Zero) + { + NativeApi.mtmd_input_chunk_free(NativePtr); + } + + NativePtr = IntPtr.Zero; + _ownsPtr = false; + _disposed = true; + + GC.SuppressFinalize(this); + } + + /// + /// Finalizer to ensure native memory is reclaimed when Dispose is not called by owners. + /// + ~SafeMtmdInputChunk() => Dispose(); + + private void EnsureNotDisposed() + { + if (_disposed || NativePtr == IntPtr.Zero) + throw new ObjectDisposedException(nameof(SafeMtmdInputChunk)); + } +} diff --git a/LLama/Native/SafeMtmdInputChunks.cs b/LLama/Native/SafeMtmdInputChunks.cs new file mode 100644 index 000000000..2081cd0a6 --- /dev/null +++ b/LLama/Native/SafeMtmdInputChunks.cs @@ -0,0 +1,103 @@ +using System; +using System.Collections.Generic; + +namespace LLama.Native; + +/// +/// Managed lifetime wrapper around a native mtmd_input_chunks collection returned by the tokenizer. +/// +public sealed class SafeMtmdInputChunks : IDisposable +{ + /// + /// Raw pointer to the native chunk collection. Internal to allow other wrappers to interop safely. + /// + internal IntPtr NativePtr { get; private set; } + + private bool _disposed; + + internal SafeMtmdInputChunks(IntPtr ptr) + { + NativePtr = ptr; + } + + /// + /// Releases the native chunk collection and suppresses finalization. 
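+    /// Any non-owning wrappers obtained from Enumerate become invalid once the collection is freed.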
+    ///
+    public void Dispose()
+    {
+        if (_disposed)
+            return;
+
+        if (NativePtr != IntPtr.Zero)
+        {
+            NativeApi.mtmd_input_chunks_free(NativePtr);
+            NativePtr = IntPtr.Zero;
+        }
+
+        _disposed = true;
+        GC.SuppressFinalize(this);
+    }
+
+    ///
+    /// Finalizer to ensure native memory is reclaimed if Dispose is not called.
+    ///
+    ~SafeMtmdInputChunks()
+    {
+        Dispose();
+    }
+
+    ///
+    /// Number of chunks currently held by the native collection.
+    ///
+    public ulong Size
+    {
+        get
+        {
+            EnsureNotDisposed();
+            return NativeApi.mtmd_input_chunks_size(NativePtr).ToUInt64();
+        }
+    }
+
+    ///
+    /// Get a raw pointer to a chunk. The returned pointer is the native mtmd_input_chunk*.
+    /// Use SafeMtmdInputChunk.Wrap to create a managed wrapper if desired.
+    ///
+    /// Zero-based index of the chunk to retrieve.
+    /// Pointer to the requested chunk.
+    /// The collection has already been disposed.
+    /// The requested index is outside of the valid range.
+    public IntPtr GetChunkPtr(ulong index)
+    {
+        EnsureNotDisposed();
+
+        if (index >= Size)
+            throw new IndexOutOfRangeException();
+        return NativeApi.mtmd_input_chunks_get(NativePtr, (UIntPtr)index);
+    }
+
+    ///
+    /// Enumerate the contained chunks as non-owning wrappers. Callers should dispose a returned chunk
+    /// only if they create an owning copy of it.
+    ///
+    /// Enumeration of chunk wrappers backed by the native collection.
+    /// The collection has already been disposed.
+    public IEnumerable<SafeMtmdInputChunk> Enumerate()
+    {
+        EnsureNotDisposed();
+
+        for (ulong i = 0; i < Size; i++)
+        {
+            var chunk = SafeMtmdInputChunk.Wrap(GetChunkPtr(i));
+            if (chunk != null)
+            {
+                // Yield a lightweight wrapper; ownership remains with the native collection.
+                yield return chunk;
+            }
+        }
+    }
+
+    private void EnsureNotDisposed()
+    {
+        if (_disposed || NativePtr == IntPtr.Zero)
+            throw new ObjectDisposedException(nameof(SafeMtmdInputChunks));
+    }
+}
diff --git a/LLama/Native/SafeMtmdModelHandle.cs b/LLama/Native/SafeMtmdModelHandle.cs
new file mode 100644
index 000000000..236a22011
--- /dev/null
+++ b/LLama/Native/SafeMtmdModelHandle.cs
@@ -0,0 +1,349 @@
+using System;
+using System.Collections.Generic;
+using System.IO;
+using LLama.Exceptions;
+
+namespace LLama.Native
+{
+    ///
+    /// Wrapper around the multimodal (MTMD) weights handle. This wrapper manages the low-level
+    /// native operations.
+    ///
+    public sealed class SafeMtmdModelHandle : SafeLLamaHandleBase
+    {
+        // Pending media embeddings queued for the next call to Tokenize.
+        private readonly List<SafeMtmdEmbed> _pendingMedia = new();
+
+        ///
+        protected override bool ReleaseHandle()
+        {
+            mtmd_free(DangerousGetHandle());
+            SetHandle(IntPtr.Zero);
+            return true;
+        }
+
+        ///
+        /// Load a multimodal projection model from disk and bind it to the supplied text model.
+        ///
+        /// Path to the MMP (Multi-Modal Projections) file.
+        /// Text model that provides tokenizer weights for the multimodal helper.
+        /// Optional context parameters; defaults are used when null.
+        /// Safe handle for the MTMD model.
+        /// The file exists but is not readable by the current process.
+        /// The native loader failed to initialize the MTMD model.
+        public static SafeMtmdModelHandle LoadFromFile(string modelPath, LLamaWeights textModel, MtmdContextParams mtmdCtxParams)
+        {
+            // Try to open the model file, this will check:
+            // - File exists (automatically throws FileNotFoundException)
+            // - File is readable (explicit check)
+            // This provides better error messages than llama.cpp, which would throw an access violation exception in both cases.
+ using (var fs = new FileStream(modelPath, FileMode.Open)) + if (!fs.CanRead) + throw new InvalidOperationException($"Mtmd MMP Model file '{modelPath}' is not readable"); + + using var pathUtf8 = PinnedUtf8String.Create(modelPath) ?? throw new ArgumentNullException(nameof(modelPath)); + + unsafe + { + SafeMtmdModelHandle handle; + if (mtmdCtxParams is null) + { + var nativeParams = NativeApi.mtmd_context_params_default(); + handle = mtmd_init_from_file((byte*)pathUtf8.Pointer, textModel.NativeHandle, nativeParams); + } + else + { + using var nativeParamsScope = mtmdCtxParams.ToNativeScope(); + handle = mtmd_init_from_file((byte*)pathUtf8.Pointer, textModel.NativeHandle, nativeParamsScope.Value); + } + + if (handle.IsInvalid) + throw new LoadWeightsFailedException(modelPath); + + return handle; + } + } + + /// + /// Load media from disk and queue it for the next tokenize call. + /// + /// Absolute or relative path to the media asset. + /// Safe handle to the media embedding. + /// The model handle has been disposed. + /// The native loader failed to ingest the file. + public SafeMtmdEmbed LoadMediaFromFile(string path) + { + EnsureNotDisposed(); + + var embed = SafeMtmdEmbed.FromMediaFile(this, path) + ?? throw new RuntimeError($"Failed to load media '{path}'."); + _pendingMedia.Add(embed); + return embed; + } + + /// + /// Load media from an in-memory buffer and queue it for the next tokenize call. + /// + /// Binary buffer containing the encoded media data. + /// Safe handle to the media embedding. + /// The model handle has been disposed. + /// The native loader failed to ingest the buffer contents. + public SafeMtmdEmbed LoadMediaFromBuffer(ReadOnlySpan buffer) + { + EnsureNotDisposed(); + + var embed = SafeMtmdEmbed.FromMediaBuffer(this, buffer) + ?? throw new RuntimeError("Failed to load media from buffer."); + _pendingMedia.Add(embed); + return embed; + } + + /// + /// Disposes and clears any media buffers currently queued for tokenization. + /// + public void ClearMedia() + { + foreach (var media in _pendingMedia) + media.Dispose(); + _pendingMedia.Clear(); + } + + /// + /// Tokenize a prompt alongside the pending media buffers. Pending media is cleared on success. + /// + /// Prompt text to tokenize. + /// Whether to append special tokens automatically. + /// Whether special tokens should be treated as user-provided text. + /// Receives the native chunk collection when tokenization succeeds. + /// Zero on success; otherwise the native mtmd tokenize error code. + /// The model handle has been disposed. + public int Tokenize(string text, bool addSpecial, bool parseSpecial, out SafeMtmdInputChunks? chunks) + { + EnsureNotDisposed(); + + chunks = null; + // Allocate the chunk container before invoking the native tokenizer. + var output = NativeApi.mtmd_input_chunks_init(); + if (output == IntPtr.Zero) + throw new RuntimeError("Failed to allocate mtmd_input_chunks."); + + // Collect native pointers to the queued media embeddings. 
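+            // The embeds own these pointers, so they remain valid until the pending media is disposed below.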
+ var bitmapHandles = new IntPtr[_pendingMedia.Count]; + for (var i = 0; i < _pendingMedia.Count; i++) + bitmapHandles[i] = _pendingMedia[i].NativePtr; + + var result = NativeApi.mtmd_tokenize(DangerousGetHandle(), output, text, addSpecial, parseSpecial, bitmapHandles, (UIntPtr)bitmapHandles.Length); + + if (result == 0) + { + chunks = new SafeMtmdInputChunks(output); + foreach (var media in _pendingMedia) + media.Dispose(); + _pendingMedia.Clear(); + } + else + { + NativeApi.mtmd_input_chunks_free(output); + } + + if (result != 0) + { + foreach (var media in _pendingMedia) + media.Dispose(); + _pendingMedia.Clear(); + } + + return result; + } + + /// + /// Evaluate a batch of chunks using the helper (mirrors mtmd-helper eval logic). + /// + /// Chunk collection produced by . + /// Context handle that receives the evaluated tokens. + /// Number of past tokens; updated when evaluation succeeds. + /// Sequence identifier used for KV cache management. + /// Maximum number of tokens to evaluate in a single batch. + /// Whether to request logits for the last token only. + /// Zero on success; otherwise the native helper error code. + /// Thrown when required handles are null. + public int EvaluateChunks(SafeMtmdInputChunks chunks, SafeLLamaContextHandle llamaContext, ref long nPast, int seqId, int nBatch, bool logitsLast) + { + EnsureNotDisposed(); + + if (chunks == null) + throw new ArgumentNullException(nameof(chunks)); + if (llamaContext == null) + throw new ArgumentNullException(nameof(llamaContext)); + + var newNPast = nPast; + var result = NativeApi.mtmd_helper_eval_chunks( + DangerousGetHandle(), + llamaContext.DangerousGetHandle(), + chunks.NativePtr, + nPast, + seqId, + nBatch, + logitsLast, + ref newNPast); + + if (result == 0) + nPast = newNPast; + + return result; + } + + /// + /// Evaluate a single chunk helper. + /// + /// Pointer to the chunk to evaluate. + /// Context handle that receives the evaluated tokens. + /// Number of past tokens; updated when evaluation succeeds. + /// Sequence identifier used for KV cache management. + /// Maximum number of tokens to evaluate in a single batch. + /// Whether to request logits for the last token only. + /// Zero on success; otherwise the native helper error code. + /// Thrown when required handles are null. + public int EvaluateChunk(IntPtr chunkPtr, SafeLLamaContextHandle llamaContext, ref long nPast, int seqId, int nBatch, bool logitsLast) + { + EnsureNotDisposed(); + + if (chunkPtr == IntPtr.Zero) + throw new ArgumentNullException(nameof(chunkPtr)); + if (llamaContext == null) + throw new ArgumentNullException(nameof(llamaContext)); + + var newNPast = nPast; + var result = NativeApi.mtmd_helper_eval_chunk_single( + DangerousGetHandle(), + llamaContext.DangerousGetHandle(), + chunkPtr, + nPast, + seqId, + nBatch, + logitsLast, + ref newNPast); + + if (result == 0) + nPast = newNPast; + + return result; + } + + /// + /// Decode a prepared image chunk whose embedding is already computed. + /// + /// Pointer to the chunk whose embedding should be decoded. + /// Context handle used for decoding. + /// Pointer to the pre-computed embedding data. + /// Number of past tokens; updated when evaluation succeeds. + /// Sequence identifier used for KV cache management. + /// Maximum number of tokens to evaluate in a single batch. + /// Zero on success; otherwise the native helper error code. + /// Thrown when required handles are null. 
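+        /// The encoded embedding is typically the buffer produced by mtmd_encode_chunk and fetched
+        /// via mtmd_get_output_embd (an assumption based on the native API surface, not verified here).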
+ public int DecodeImageChunk(IntPtr chunkPtr, SafeLLamaContextHandle llamaContext, IntPtr encodedEmbeddings, ref long nPast, int seqId, int nBatch) + { + EnsureNotDisposed(); + + if (chunkPtr == IntPtr.Zero) + throw new ArgumentNullException(nameof(chunkPtr)); + + var newNPast = nPast; + var result = NativeApi.mtmd_helper_decode_image_chunk( + DangerousGetHandle(), + llamaContext?.DangerousGetHandle() ?? throw new ArgumentNullException(nameof(llamaContext)), + chunkPtr, + encodedEmbeddings, + nPast, + seqId, + nBatch, + ref newNPast); + + if (result == 0) + nPast = newNPast; + + return result; + } + + /// + /// Get the number of tokens contained in the provided chunk collection. + /// + /// Chunk collection produced by . + /// Total token count. + public ulong CountTokens(SafeMtmdInputChunks chunks) + { + if (chunks == null) + throw new ArgumentNullException(nameof(chunks)); + return NativeApi.mtmd_helper_get_n_tokens(chunks.NativePtr).ToUInt64(); + } + + /// + /// Get the number of positions contained in the provided chunk collection. + /// + /// Chunk collection produced by . + /// Total number of positional slots consumed. + public long CountPositions(SafeMtmdInputChunks chunks) + { + if (chunks == null) + throw new ArgumentNullException(nameof(chunks)); + return NativeApi.mtmd_helper_get_n_pos(chunks.NativePtr); + } + + #region native API + + // mtmd_init_from_file(const char * mmproj_fname, const struct llama_model * text_model, const struct mtmd_context_params ctx_params); + // The llama_model layout is opaque; expose it via SafeLlamaModelHandle to match the managed wrapper. + [DllImport(NativeApi.mtmdLibraryName, EntryPoint = "mtmd_init_from_file", CallingConvention = CallingConvention.Cdecl)] + private static extern unsafe SafeMtmdModelHandle mtmd_init_from_file( + byte* mmproj_fname, + SafeLlamaModelHandle text_model, + NativeApi.mtmd_context_params @ctx_params); + + [DllImport(NativeApi.mtmdLibraryName, EntryPoint = "mtmd_free", CallingConvention = CallingConvention.Cdecl)] + internal static extern void mtmd_free(IntPtr ctx); + + #endregion + + + + /// + /// Finalizer to ensure native resources are released if Dispose was not called. + /// + ~SafeMtmdModelHandle() + { + Dispose(); + } + + /// + /// Indicates whether the model decodes using the non-causal path. + /// + public bool DecodeUseNonCausal() => NativeApi.mtmd_decode_use_non_causal(handle); + + /// + /// Indicates whether the model decodes using multi-scale RoPE. + /// + public bool DecodeUseMRope() => NativeApi.mtmd_decode_use_mrope(handle); + + /// + /// Indicates whether the model supports vision inputs. + /// + public bool SupportVision() => NativeApi.mtmd_support_vision(handle); + + /// + /// Indicates whether the model supports audio inputs. + /// + public bool SupportAudio() => NativeApi.mtmd_support_audio(handle); + + /// + /// Gets the audio bitrate advertised by the model. 
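+        /// (Assumed to mirror the native helper, which reports a negative bitrate when audio is unsupported.)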
+ /// + public int GetAudioBitrate() => NativeApi.mtmd_get_audio_bitrate(handle); + + private void EnsureNotDisposed() + { + if (IsInvalid || IsClosed) + throw new ObjectDisposedException(nameof(SafeMtmdModelHandle)); + } + } +} diff --git a/LLama/Properties/InternalsVisibleTo.cs b/LLama/Properties/InternalsVisibleTo.cs new file mode 100644 index 000000000..b0a1ac4be --- /dev/null +++ b/LLama/Properties/InternalsVisibleTo.cs @@ -0,0 +1,3 @@ +using System.Runtime.CompilerServices; + +[assembly: InternalsVisibleTo("LLama.Unittest")] diff --git a/LLama/SafeMtmdWeights.cs b/LLama/SafeMtmdWeights.cs new file mode 100644 index 000000000..e490049b4 --- /dev/null +++ b/LLama/SafeMtmdWeights.cs @@ -0,0 +1,80 @@ + +using System; +using System.Threading; +using System.Threading.Tasks; +using LLama.Native; + +namespace LLama; + +/// +/// Lightweight wrapper around the MTMD native context and its helpers. +/// +public sealed class SafeMtmdWeights : IDisposable +{ + public SafeMtmdModelHandle NativeHandle { get; } + + private SafeMtmdWeights(SafeMtmdModelHandle handle) + { + NativeHandle = handle ?? throw new ArgumentNullException(nameof(handle)); + } + + public static SafeMtmdWeights LoadFromFile(string mmProject, LLamaWeights textModel, MtmdContextParams mtmdCtxParams) + { + if (mmProject == null) throw new ArgumentNullException(nameof(mmProject)); + if (textModel == null) throw new ArgumentNullException(nameof(textModel)); + if (mtmdCtxParams == null) throw new ArgumentNullException(nameof(mtmdCtxParams)); + + var handle = SafeMtmdModelHandle.LoadFromFile(mmProject, textModel, mtmdCtxParams); + return new SafeMtmdWeights(handle); + } + + public static Task LoadFromFileAsync(string mmProject, LLamaWeights textModel, MtmdContextParams mtmdCtxParams, CancellationToken token = default) + { + return Task.Run(() => LoadFromFile(mmProject, textModel, mtmdCtxParams), token); + } + + /// + /// Load media from disk and keep it pending for the next tokenize call. + /// + public SafeMtmdEmbed LoadMedia(string path) => NativeHandle.LoadMediaFromFile(path); + + /// + /// Load media from an in-memory buffer and keep it pending for the next tokenize call. + /// + public SafeMtmdEmbed LoadMedia(ReadOnlySpan data) => NativeHandle.LoadMediaFromBuffer(data); + + /// + /// Clear any pending media buffers before or after tokenization. + /// + public void ClearMedia() => NativeHandle.ClearMedia(); + + /// + /// Tokenize text (with optional special tokens) against the pending media buffers. + /// + public int Tokenize(string text, bool addSpecial, bool parseSpecial, out SafeMtmdInputChunks? chunks) + => NativeHandle.Tokenize(text, addSpecial, parseSpecial, out chunks); + + /// + /// Evaluate a chunk batch using the helper that performs mtmd encode + llama decode. 
+ /// + public int EvaluateChunks(SafeMtmdInputChunks chunks, SafeLLamaContextHandle llamaContext, ref long nPast, int seqId, int nBatch, bool logitsLast) + => NativeHandle.EvaluateChunks(chunks, llamaContext, ref nPast, seqId, nBatch, logitsLast); + + public int EvaluateChunk(IntPtr chunkPtr, SafeLLamaContextHandle llamaContext, ref long nPast, int seqId, int nBatch, bool logitsLast) + => NativeHandle.EvaluateChunk(chunkPtr, llamaContext, ref nPast, seqId, nBatch, logitsLast); + + public int DecodeImageChunk(IntPtr chunkPtr, SafeLLamaContextHandle llamaContext, IntPtr encodedEmbeddings, ref long nPast, int seqId, int nBatch) + => NativeHandle.DecodeImageChunk(chunkPtr, llamaContext, encodedEmbeddings, ref nPast, seqId, nBatch); + + public ulong CountTokens(SafeMtmdInputChunks chunks) => NativeHandle.CountTokens(chunks); + + public long CountPositions(SafeMtmdInputChunks chunks) => NativeHandle.CountPositions(chunks); + + public bool SupportsVision => NativeHandle.SupportVision(); + public bool SupportsAudio => NativeHandle.SupportAudio(); + public bool UsesNonCausalAttention => NativeHandle.DecodeUseNonCausal(); + public bool UsesMRope => NativeHandle.DecodeUseMRope(); + public int AudioBitrate => NativeHandle.GetAudioBitrate(); + + public void Dispose() => NativeHandle.Dispose(); +} diff --git a/docs/Examples/LLavaInteractiveModeExecute.md b/docs/Examples/LLavaInteractiveModeExecute.md deleted file mode 100644 index 2bfbbea1d..000000000 --- a/docs/Examples/LLavaInteractiveModeExecute.md +++ /dev/null @@ -1,129 +0,0 @@ -# LLaVA - basic - -```cs -using System.Text.RegularExpressions; -using LLama.Common; -using Spectre.Console; -using LLama.Native; - -namespace LLama.Examples.Examples -{ - // This example shows how to chat with LLaVA model with both image and text as input. - // It uses the interactive executor to inference. - public class LlavaInteractiveModeExecute - { - public static async Task Run() - { - string multiModalProj = UserSettings.GetMMProjPath(); - string modelPath = UserSettings.GetModelPath(); - string modelImage = UserSettings.GetImagePath(); - const int maxTokens = 1024; - - var prompt = $"{{{modelImage}}}\nUSER:\nProvide a full description of the image.\nASSISTANT:\n"; - - var parameters = new ModelParams(modelPath); - - using var model = LLamaWeights.LoadFromFile(parameters); - using var context = model.CreateContext(parameters); - - // Llava Init - using var clipModel = LLavaWeights.LoadFromFile(multiModalProj); - - var ex = new InteractiveExecutor(context, clipModel ); - - Console.ForegroundColor = ConsoleColor.Yellow; - Console.WriteLine("The executor has been enabled. 
In this example, the prompt is printed, the maximum tokens is set to {0} and the context size is {1}.", maxTokens, parameters.ContextSize ); - Console.WriteLine("To send an image, enter its filename in curly braces, like this {c:/image.jpg}."); - - var inferenceParams = new InferenceParams() { Temperature = 0.1f, AntiPrompts = new List { "\nUSER:" }, MaxTokens = maxTokens }; - - do - { - - // Evaluate if we have images - // - var imageMatches = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value); - var imageCount = imageMatches.Count(); - var hasImages = imageCount > 0; - - if (hasImages) - { - var imagePathsWithCurlyBraces = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value); - var imagePaths = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Groups[1].Value).ToList(); - - List imageBytes; - try - { - imageBytes = imagePaths.Select(File.ReadAllBytes).ToList(); - } - catch (IOException exception) - { - Console.ForegroundColor = ConsoleColor.Red; - Console.Write( - $"Could not load your {(imageCount == 1 ? "image" : "images")}:"); - Console.Write($"{exception.Message}"); - Console.ForegroundColor = ConsoleColor.Yellow; - Console.WriteLine("Please try again."); - break; - } - - // Each prompt with images we clear cache - // When the prompt contains images we clear KV_CACHE to restart conversation - // See: - // https://github.com/ggerganov/llama.cpp/discussions/3620 - ex.Context.NativeHandle.KvCacheRemove( LLamaSeqId.Zero, -1, -1 ); - - int index = 0; - foreach (var path in imagePathsWithCurlyBraces) - { - // First image replace to tag " : ""); - } - - - Console.ForegroundColor = ConsoleColor.Yellow; - Console.WriteLine($"Here are the images, that are sent to the chat model in addition to your message."); - Console.WriteLine(); - - foreach (var consoleImage in imageBytes?.Select(bytes => new CanvasImage(bytes))) - { - consoleImage.MaxWidth = 50; - AnsiConsole.Write(consoleImage); - } - - Console.WriteLine(); - Console.ForegroundColor = ConsoleColor.Yellow; - Console.WriteLine($"The images were scaled down for the console only, the model gets full versions."); - Console.WriteLine($"Write /exit or press Ctrl+c to return to main menu."); - Console.WriteLine(); - - - // Initialize Images in executor - // - foreach (var image in imagePaths) - { - ex.Images.Add(await File.ReadAllBytesAsync(image)); - } - } - - Console.ForegroundColor = Color.White; - await foreach (var text in ex.InferAsync(prompt, inferenceParams)) - { - Console.Write(text); - } - Console.Write(" "); - Console.ForegroundColor = ConsoleColor.Green; - prompt = Console.ReadLine(); - Console.WriteLine(); - - // let the user finish with exit - // - if (prompt != null && prompt.Equals("/exit", StringComparison.OrdinalIgnoreCase)) - break; - - } - while(true); - } - } -} -``` \ No newline at end of file diff --git a/docs/Examples/MtmdInteractiveModeExecute.md b/docs/Examples/MtmdInteractiveModeExecute.md new file mode 100644 index 000000000..378c93a1b --- /dev/null +++ b/docs/Examples/MtmdInteractiveModeExecute.md @@ -0,0 +1,41 @@ +# MTMD interactive mode + +`MtmdInteractiveModeExecute` shows how to pair a multimodal projection with a text model so the chat loop can reason over images supplied at runtime. The sample lives in `LLama.Examples/Examples/MtmdInteractiveModeExecute.cs` and reuses the interactive executor provided by LLamaSharp. + +## Workflow +- Resolve the model, multimodal projection, and sample image paths via `UserSettings`. 
+- Create `ModelParams` for the text model and capture the MTMD defaults with `MtmdContextParams.Default()`.
+- Load the base model and context, then initialize `SafeMtmdWeights` with the multimodal projection file.
+- Resolve the media marker (`mtmdParameters.MediaMarker ?? NativeApi.MtmdDefaultMarker() ?? ""`), then construct an `InteractiveExecutor` over the context and the projection weights.
+
+```cs
+var mtmdParameters = MtmdContextParams.Default();
+
+using var model = await LLamaWeights.LoadFromFileAsync(parameters);
+using var context = model.CreateContext(parameters);
+
+// Mtmd Init
+using var clipModel = await SafeMtmdWeights.LoadFromFileAsync(
+    multiModalProj,
+    model,
+    mtmdParameters);
+
+var mediaMarker = mtmdParameters.MediaMarker
+    ?? NativeApi.MtmdDefaultMarker()
+    ?? "";
+
+var ex = new InteractiveExecutor(context, clipModel);
+```
+
+## Handling user input
+- Prompts can include image paths wrapped in braces (for example `{c:/image.jpg}`); the loop finds those references with a regular expression.
+- Every referenced file is loaded through `SafeMtmdWeights.LoadMedia`, producing `SafeMtmdEmbed` instances that are queued for the next tokenization call.
+- When the user provides images, the executor clears its KV cache (`MemorySequenceRemove`) before replacing each brace-wrapped path in the prompt with the multimodal marker.
+- The embeds collected for the current turn are copied into `ex.Embeds`, so the executor submits both the text prompt and the pending media to the helper before generation; the sketch at the end of this page condenses these steps.
+
+## Running the sample
+1. Ensure the model and projection paths returned by `UserSettings` exist locally.
+2. Start the example (for instance from the examples host application) and observe the initial description printed to the console.
+3. Type text normally, or reference new images by including their path inside braces. Type `/exit` to end the conversation.
+
+This walkthrough mirrors the logic in the sample so you can adapt it for your own multimodal workflows.
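+
+## Per-turn sketch
+
+The fragment below condenses the input-handling steps above into a single turn. It is a minimal sketch rather than a verbatim excerpt from the sample: `ex`, `clipModel`, and `mediaMarker` come from the setup code earlier on this page, while `prompt` and `inferenceParams` are assumed to be the user's input and whatever `InferenceParams` you configure for your model (usings such as `System.Text.RegularExpressions` are omitted).
+
+```cs
+// Queue each `{path}` reference as a pending embed and swap the brace
+// syntax for the media marker the MTMD tokenizer expects.
+foreach (Match match in Regex.Matches(prompt, "{([^}]*)}"))
+{
+    ex.Embeds.Add(clipModel.LoadMedia(match.Groups[1].Value));
+    prompt = prompt.Replace(match.Value, mediaMarker);
+}
+
+// The executor tokenizes the marked-up prompt together with the queued
+// media, evaluates the resulting chunks, and streams the reply.
+await foreach (var text in ex.InferAsync(prompt, inferenceParams))
+    Console.Write(text);
+```
+
+In the full sample the loop also clears the KV cache before any turn that carries images, so each image-bearing prompt effectively restarts the conversation.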
diff --git a/mkdocs.yml b/mkdocs.yml index 09cb3b96b..fbffdbba7 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -38,7 +38,7 @@ nav: - Interactive executor - basic: Examples/InteractiveModeExecute.md - Kernel memory integration - basic: Examples/KernelMemory.md - Kernel-memory - save & load: Examples/KernelMemorySaveAndLoad.md - - LLaVA - basic: Examples/LLavaInteractiveModeExecute.md + - MTMD interactive: Examples/MtmdInteractiveModeExecute.md - ChatSession - load & save: Examples/LoadAndSaveSession.md - Executor - save/load state: Examples/LoadAndSaveState.md - Quantization: Examples/QuantizeModel.md @@ -254,4 +254,4 @@ markdown_extensions: custom_checkbox: true - pymdownx.tilde - pymdownx.tabbed: - alternate_style: true \ No newline at end of file + alternate_style: true From ab0d42c3bcc6110bb692d6e50de9be2e67a58b95 Mon Sep 17 00:00:00 2001 From: jlsantiago Date: Mon, 29 Sep 2025 21:56:58 +0200 Subject: [PATCH 09/35] Update LLama/Native/NativeApi.cs Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- LLama/Native/NativeApi.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs index 3123674fc..0ea46a600 100644 --- a/LLama/Native/NativeApi.cs +++ b/LLama/Native/NativeApi.cs @@ -338,7 +338,6 @@ private static byte[] EncodeNullTerminatedUtf8(string value, string paramName) var bytes = Encoding.UTF8.GetBytes(value); var buffer = new byte[bytes.Length + 1]; Buffer.BlockCopy(bytes, 0, buffer, 0, bytes.Length); - // buffer[^1] = 0; return buffer; } From f59198590cce8a31b2aa53a0755790ed1c739c49 Mon Sep 17 00:00:00 2001 From: SignalRT Date: Mon, 29 Sep 2025 22:57:09 +0200 Subject: [PATCH 10/35] Resolve comment: https://github.com/SciSharp/LLamaSharp/pull/1261#discussion_r2386165308 --- LLama/Native/MtmdContextParams.cs | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/LLama/Native/MtmdContextParams.cs b/LLama/Native/MtmdContextParams.cs index d83831d85..5b282d802 100644 --- a/LLama/Native/MtmdContextParams.cs +++ b/LLama/Native/MtmdContextParams.cs @@ -138,7 +138,16 @@ private PinnedUtf8String(string value) public static PinnedUtf8String? Create(string? value) => value is null ? null : new PinnedUtf8String(value); - public IntPtr Pointer => _buffer is null ? IntPtr.Zero : _handle.AddrOfPinnedObject(); + public IntPtr Pointer + { + get + { + if (_buffer is null || !_handle.IsAllocated) + return IntPtr.Zero; + + return _handle.AddrOfPinnedObject(); + } + } public void Dispose() { From 03d444109e54f86e0d933f79179a92ca0cb395ec Mon Sep 17 00:00:00 2001 From: SignalRT Date: Sun, 5 Oct 2025 13:47:51 +0200 Subject: [PATCH 11/35] Remove duplicate code --- LLama/Native/SafeMtmdModelHandle.cs | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/LLama/Native/SafeMtmdModelHandle.cs b/LLama/Native/SafeMtmdModelHandle.cs index 236a22011..86abf8c6c 100644 --- a/LLama/Native/SafeMtmdModelHandle.cs +++ b/LLama/Native/SafeMtmdModelHandle.cs @@ -40,7 +40,7 @@ public static SafeMtmdModelHandle LoadFromFile(string modelPath, LLamaWeights te // This provides better error messages that llama.cpp, which would throw an access violation exception in both cases. using (var fs = new FileStream(modelPath, FileMode.Open)) if (!fs.CanRead) - throw new InvalidOperationException($"Mtmd MMP Model file '{modelPath}' is not readable"); + throw new InvalidOperationException($"Mtmd Model file '{modelPath}' is not readable"); using var pathUtf8 = PinnedUtf8String.Create(modelPath) ?? 
throw new ArgumentNullException(nameof(modelPath)); @@ -138,21 +138,13 @@ public int Tokenize(string text, bool addSpecial, bool parseSpecial, out SafeMtm if (result == 0) { chunks = new SafeMtmdInputChunks(output); - foreach (var media in _pendingMedia) - media.Dispose(); - _pendingMedia.Clear(); } else { NativeApi.mtmd_input_chunks_free(output); } - if (result != 0) - { - foreach (var media in _pendingMedia) - media.Dispose(); - _pendingMedia.Clear(); - } + ClearMedia(); return result; } From ea4ba82874cfb173d47e852ff392ee3001821aba Mon Sep 17 00:00:00 2001 From: SignalRT Date: Sun, 5 Oct 2025 14:16:49 +0200 Subject: [PATCH 12/35] Move common logic to LlamaExecutorBase --- LLama/LLamaExecutorBase.cs | 195 ++++++++++++++++++++++++++++++ LLama/LLamaInstructExecutor.cs | 154 +----------------------- LLama/LLamaInteractExecutor.cs | 210 ++------------------------------- 3 files changed, 204 insertions(+), 355 deletions(-) diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index a39ad3836..bb1c27a35 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -86,6 +86,13 @@ public bool IsMultiModal /// public List Embeds { get; } + /// + /// Pending multimodal chunks produced by the MTMD tokenizer. + /// + protected SafeMtmdInputChunks? MtmdChunks { get; set; } + + private string? _mtmdMarker; + private readonly StreamingTokenDecoder _decoder; /// @@ -242,6 +249,194 @@ protected virtual void TryReuseMatchingPrefix() } } + /// + /// Dispose and clear any queued multimodal chunk collection. + /// + protected void DisposeMtmdChunks() + { + MtmdChunks?.Dispose(); + MtmdChunks = null; + } + + /// + /// Dispose and clear any pending multimodal embeddings. + /// + protected void DisposeEmbeds() + { + if (Embeds.Count == 0) + return; + + foreach (var embed in Embeds) + embed.Dispose(); + + Embeds.Clear(); + } + + /// + /// Retrieve the marker token used to signal media segments to the tokenizer. + /// + protected string GetMtmdMarker() + { + if (_mtmdMarker is not null) + return _mtmdMarker; + + _mtmdMarker = NativeApi.MtmdDefaultMarker() ?? ""; + return _mtmdMarker; + } + + /// + /// Ensure the token list fills all positional slots reported by the MTMD helper. + /// + protected static List BuildTokensWithFiller(List tokens, int totalPositions, LLamaToken fillerToken) + { + if (totalPositions <= tokens.Count) + return new List(tokens); + + var result = new List(totalPositions); + result.AddRange(tokens); + result.AddRange(Enumerable.Repeat(fillerToken, totalPositions - tokens.Count)); + return result; + } + + /// + /// Resolve the fallback token inserted when the tokenizer emits fewer tokens than positions. + /// + protected LLamaToken GetFillerToken(string marker) + { + var markerTokens = Context.Tokenize(marker, false, true); + if (markerTokens.Length > 0) + return markerTokens[markerTokens.Length - 1]; + + var eos = Context.Vocab.EOS; + if (eos.HasValue) + return eos.Value; + + return default; + } + + /// + /// Prepare multimodal inputs by invoking the MTMD tokenizer and aligning filler tokens. 
+    /// 
+    protected Task PreprocessMtmd(string text, InferStateArgs args, bool addBos, bool replaceExisting)
+    {
+        if (ClipModel is null)
+            throw new InvalidOperationException("Multimodal execution requires a loaded mtmd clip model.");
+
+        DisposeMtmdChunks();
+
+        var marker = GetMtmdMarker();
+        var prompt = text;
+
+        if (Embeds.Count > 0)
+        {
+            if (prompt.Contains("<image>"))
+                prompt = prompt.Replace("<image>", marker);
+
+            if (!prompt.Contains(marker))
+            {
+                var suffix = string.Concat(Enumerable.Repeat(marker, Embeds.Count));
+                prompt = string.Concat(prompt, suffix);
+            }
+        }
+
+        SafeMtmdInputChunks? chunks = null;
+        try
+        {
+            var status = ClipModel.Tokenize(prompt, addBos, parseSpecial: true, out chunks);
+            if (status != 0 || chunks is null)
+            {
+                ClipModel.ClearMedia();
+                throw new RuntimeError($"Failed to tokenize multimodal prompt. Status: {status}.");
+            }
+
+            MtmdChunks = chunks;
+
+            var tokens = new List<LLamaToken>();
+            foreach (var chunk in chunks.Enumerate())
+            {
+                using var scopedChunk = chunk;
+                if (scopedChunk.Type != SafeMtmdInputChunk.SafeMtmdInputChunkType.Text)
+                    continue;
+
+                foreach (var token in scopedChunk.GetTextTokensSpan())
+                    tokens.Add(unchecked((int)token));
+            }
+
+            var totalPositions = (int)ClipModel.CountPositions(chunks);
+            var fillerToken = GetFillerToken(marker);
+
+            if (replaceExisting)
+            {
+                _embed_inps = BuildTokensWithFiller(tokens, totalPositions, fillerToken);
+                _consumedTokensCount = 0;
+            }
+            else
+            {
+                if (_embed_inps.Count == 0)
+                    _embed_inps = new List<LLamaToken>();
+
+                _embed_inps.AddRange(tokens);
+                var fillerCount = totalPositions - tokens.Count;
+                if (fillerCount > 0)
+                    _embed_inps.AddRange(Enumerable.Repeat(fillerToken, fillerCount));
+
+                args.RemainedTokens -= tokens.Count;
+            }
+        }
+        catch
+        {
+            chunks?.Dispose();
+            MtmdChunks = null;
+            throw;
+        }
+        finally
+        {
+            DisposeEmbeds();
+        }
+
+        return Task.CompletedTask;
+    }
+
+    /// 
+    /// Apply bookkeeping after successfully evaluating multimodal chunks.
+    /// 
+    protected void FinalizeMtmdEvaluation(long newNPast, int previousConsumed)
+    {
+        _pastTokensCount = checked((int)newNPast);
+        DisposeMtmdChunks();
+
+        if (!string.IsNullOrEmpty(_pathSession) && _embed_inps.Count > previousConsumed)
+        {
+            _session_tokens.AddRange(_embed_inps.Skip(previousConsumed));
+            _n_session_consumed = _session_tokens.Count;
+        }
+
+        _consumedTokensCount = _embed_inps.Count;
+        _embeds.Clear();
+    }
+
+    /// 
+    /// Evaluate the queued MTMD chunks and update executor state.
+    /// 
+    protected void EvaluateMtmdChunks(ref long nPast, int previousConsumed, string executorName)
+    {
+        if (ClipModel is null)
+            throw new InvalidOperationException("Multimodal execution requires a loaded mtmd clip model.");
+        if (MtmdChunks is null)
+            throw new InvalidOperationException("No MTMD chunks are queued for evaluation.");
+
+        var evalStatus = ClipModel.EvaluateChunks(MtmdChunks, Context.NativeHandle, ref nPast, seqId: 0,
+            nBatch: checked((int)Context.BatchSize), logitsLast: true);
+        if (evalStatus != 0)
+        {
+            _logger?.LogError("[{Executor}] Failed to evaluate multimodal chunks. Status: {Status}", executorName, evalStatus);
+            DisposeMtmdChunks();
+            throw new RuntimeError($"Failed to evaluate multimodal chunks. Status: {evalStatus}.");
+        }
+
+        FinalizeMtmdEvaluation(nPast, previousConsumed);
+    }
+
+    /// 
+    /// Determine whether the inference loop should continue processing tokens.
/// diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index 1bdba035a..b7a0f9ec7 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -26,8 +26,6 @@ public class InstructExecutor private readonly string _instructionPrefix; private LLamaToken[] _inp_pfx; private LLamaToken[] _inp_sfx; - private SafeMtmdInputChunks? _mtmdChunks; - private string? _mtmdMarker; private readonly string _instructionSuffix; /// @@ -192,136 +190,6 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args, Canc return Task.CompletedTask; } - private void DisposeMtmdChunks() - { - _mtmdChunks?.Dispose(); - _mtmdChunks = null; - } - - private void DisposeEmbeds() - { - if (Embeds.Count == 0) - return; - - foreach (var embed in Embeds) - embed.Dispose(); - - Embeds.Clear(); - } - - private string GetMtmdMarker() - { - if (_mtmdMarker is not null) - return _mtmdMarker; - - _mtmdMarker = NativeApi.MtmdDefaultMarker() ?? ""; - return _mtmdMarker; - } - - private static List BuildTokensWithFiller(List tokens, int totalPositions, LLamaToken fillerToken) - { - if (totalPositions <= tokens.Count) - return new List(tokens); - - var result = new List(totalPositions); - result.AddRange(tokens); - result.AddRange(Enumerable.Repeat(fillerToken, totalPositions - tokens.Count)); - return result; - } - - private LLamaToken GetFillerToken(string marker) - { - var markerTokens = Context.Tokenize(marker, false, true); - if (markerTokens.Length > 0) - return markerTokens[markerTokens.Length - 1]; - - var eos = Context.Vocab.EOS; - if (eos.HasValue) - return eos.Value; - - return default(LLamaToken); - } - - private Task PreprocessMtmd(string text, InferStateArgs args, bool addBos, bool replaceExisting) - { - if (ClipModel is null) - throw new InvalidOperationException("Multimodal execution requires a loaded mtmd clip model."); - - DisposeMtmdChunks(); - - var marker = GetMtmdMarker(); - var prompt = text; - - if (Embeds.Count > 0) - { - if (prompt.Contains("")) - prompt = prompt.Replace("", marker); - - if (!prompt.Contains(marker)) - { - var suffix = string.Concat(Enumerable.Repeat(marker, Embeds.Count)); - prompt = string.Concat(prompt, suffix); - } - } - - SafeMtmdInputChunks? chunks = null; - try - { - var status = ClipModel.Tokenize(prompt, addBos, parseSpecial: true, out chunks); - if (status != 0 || chunks is null) - { - ClipModel.ClearMedia(); - throw new RuntimeError($"Failed to tokenize multimodal prompt. 
Status: {status}."); - } - - _mtmdChunks = chunks; - - var tokens = new List(); - foreach (var chunk in chunks.Enumerate()) - { - using var scopedChunk = chunk; - if (scopedChunk.Type != SafeMtmdInputChunk.SafeMtmdInputChunkType.Text) - continue; - - foreach (var token in scopedChunk.GetTextTokensSpan()) - tokens.Add(unchecked((int)token)); - } - - var totalPositions = (int)ClipModel.CountPositions(chunks); - var fillerToken = GetFillerToken(marker); - - if (replaceExisting) - { - _embed_inps = BuildTokensWithFiller(tokens, totalPositions, fillerToken); - _consumedTokensCount = 0; - } - else - { - if (_embed_inps.Count == 0) - _embed_inps = new List(); - - _embed_inps.AddRange(tokens); - var fillerCount = totalPositions - tokens.Count; - if (fillerCount > 0) - _embed_inps.AddRange(Enumerable.Repeat(fillerToken, fillerCount)); - - args.RemainedTokens -= tokens.Count; - } - } - catch - { - chunks?.Dispose(); - _mtmdChunks = null; - throw; - } - finally - { - DisposeEmbeds(); - } - - return Task.CompletedTask; - } - /// protected override Task<(bool, IReadOnlyList)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken) { @@ -384,30 +252,12 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In _n_session_consumed = _session_tokens.Count; } } - else if (IsMultiModal && _mtmdChunks is not null) + else if (IsMultiModal && MtmdChunks is not null) { _is_prompt_run = false; var nPast = (long)_pastTokensCount; var previousConsumed = _consumedTokensCount; - var evalStatus = ClipModel!.EvaluateChunks(_mtmdChunks, Context.NativeHandle, ref nPast, seqId: 0, nBatch: checked((int)Context.BatchSize), logitsLast: true); - if (evalStatus != 0) - { - _logger?.LogError("[InstructExecutor] Failed to evaluate multimodal chunks. Status: {Status}", evalStatus); - DisposeMtmdChunks(); - throw new RuntimeError($"Failed to evaluate multimodal chunks. Status: {evalStatus}."); - } - - _pastTokensCount = checked((int)nPast); - DisposeMtmdChunks(); - - if (!string.IsNullOrEmpty(_pathSession) && _embed_inps.Count > previousConsumed) - { - _session_tokens.AddRange(_embed_inps.Skip(previousConsumed)); - _n_session_consumed = _session_tokens.Count; - } - - _consumedTokensCount = _embed_inps.Count; - _embeds.Clear(); + EvaluateMtmdChunks(ref nPast, previousConsumed, nameof(InstructExecutor)); } _embeds.Clear(); diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index 97d49f5de..da6ed53a9 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -144,7 +144,7 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args, Canc } else { - PreprocessMtmd(text, args, true); + return PreprocessMtmd(text, args, addBos: true, replaceExisting: true); } } else @@ -165,7 +165,7 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args, Canc } else { - PreprocessMtmd(text, args, false); + return PreprocessMtmd(text, args, addBos: false, replaceExisting: false); } } } @@ -173,165 +173,6 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args, Canc return Task.CompletedTask; } - /// - /// Release any queued multimodal chunks and reset state. - /// - private void DisposeMtmdChunks() - { - _mtmdChunks?.Dispose(); - _mtmdChunks = null; - } - - /// - /// Dispose and clear any pending multimodal embeddings queued for evaluation. 
- /// - private void DisposeEmbeds() - { - if (Embeds.Count == 0) - { - return; - } - - foreach (var embed in Embeds) - { - embed.Dispose(); - } - - Embeds.Clear(); - } - - /// - /// Retrieve the marker token used to signal media segments to the tokenizer. - /// - private string GetMtmdMarker() - { - if (_mtmdMarker is not null) - { - return _mtmdMarker; - } - - _mtmdMarker = NativeApi.MtmdDefaultMarker() ?? ""; - return _mtmdMarker; - } - - private static List BuildTokensWithFiller(List tokens, int totalPositions, LLamaToken fillerToken) - { - if (totalPositions <= tokens.Count) - return new List(tokens); - - var result = new List(totalPositions); - result.AddRange(tokens); - result.AddRange(Enumerable.Repeat(fillerToken, totalPositions - tokens.Count)); - return result; - } - - private LLamaToken GetFillerToken(string marker) - { - var markerTokens = Context.Tokenize(marker, false, true); - if (markerTokens.Length > 0) - return markerTokens[markerTokens.Length - 1]; - - var eos = Context.Vocab.EOS; - if (eos.HasValue) - return eos.Value; - - return default(LLamaToken); - } - - /// - /// Preprocess multimodal prompts by aligning media markers and tokenizing via MTMD helpers. - /// - /// Prompt text containing optional media markers. - /// Mutable inference state. - /// Whether to treat the prompt as a fresh run and add the BOS token. - private Task PreprocessMtmd(string text, InferStateArgs args, bool addBos = true) - { - if (ClipModel is null) - { - throw new InvalidOperationException("Multimodal execution requires a loaded mtmd clip model."); - } - - DisposeMtmdChunks(); - - var marker = GetMtmdMarker(); - var prompt = text; - - if (Embeds.Count > 0) - { - if (prompt.Contains("")) - { - prompt = prompt.Replace("", marker); - } - - if (!prompt.Contains(marker)) - { - var suffix = string.Concat(Enumerable.Repeat(marker, Embeds.Count)); // Ensure tokenizer sees one marker per embed. - prompt = string.Concat(prompt, suffix); - } - } - - SafeMtmdInputChunks? chunks = null; - try - { - var status = ClipModel.Tokenize(prompt, addBos, parseSpecial: true, out chunks); - if (status != 0 || chunks is null) - { - ClipModel.ClearMedia(); - throw new RuntimeError($"Failed to tokenize multimodal prompt. Status: {status}."); - } - - _mtmdChunks = chunks; // Own the chunk collection until evaluation completes. - - var tokens = new List(); - foreach (var chunk in chunks.Enumerate()) - { - using var scopedChunk = chunk; - if (scopedChunk.Type != SafeMtmdInputChunk.SafeMtmdInputChunkType.Text) - { - continue; - } - - foreach (var token in scopedChunk.GetTextTokensSpan()) - { - tokens.Add(unchecked((int)token)); - } - } - - var totalPositions = (int)ClipModel.CountPositions(chunks); - var fillerToken = GetFillerToken(marker); - - if (addBos) - { - _embed_inps = BuildTokensWithFiller(tokens, totalPositions, fillerToken); - _consumedTokensCount = 0; - } - else - { - if (_embed_inps.Count == 0) - _embed_inps = new List(); - - _embed_inps.AddRange(tokens); - var fillerCount = totalPositions - tokens.Count; - if (fillerCount > 0) - _embed_inps.AddRange(Enumerable.Repeat(fillerToken, fillerCount)); - - args.RemainedTokens -= tokens.Count; - } - } - catch - { - chunks?.Dispose(); - _mtmdChunks = null; - throw; - } - finally - { - DisposeEmbeds(); // Flush any embeds decoded in prior step; MTMD replays them via chunk eval. - } - - return Task.CompletedTask; - } - /// /// Decide whether generation should stop based on antiprompts, token limits, or end-of-generation markers. 
/// @@ -393,35 +234,16 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In HandleRunOutOfContext(tokensToKeep); } - if (_mtmdChunks is null) + if (MtmdChunks is null) { TryReuseMatchingPrefix(); } - if (IsMultiModal && _mtmdChunks is not null) + if (IsMultiModal && MtmdChunks is not null) { var nPast = (long)_pastTokensCount; var previousConsumed = _consumedTokensCount; - var evalStatus = ClipModel!.EvaluateChunks(_mtmdChunks, Context.NativeHandle, ref nPast, seqId: 0, - nBatch: checked((int)Context.BatchSize), logitsLast: true); - if (evalStatus != 0) - { - _logger?.LogError("[InteractiveExecutor] Failed to evaluate multimodal chunks. Status: {Status}", evalStatus); - DisposeMtmdChunks(); - throw new RuntimeError($"Failed to evaluate multimodal chunks. Status: {evalStatus}."); - } - - _pastTokensCount = checked((int)nPast); - DisposeMtmdChunks(); - - if (!string.IsNullOrEmpty(_pathSession) && _embed_inps.Count > previousConsumed) - { - _session_tokens.AddRange(_embed_inps.Skip(previousConsumed)); - _n_session_consumed = _session_tokens.Count; - } - - _consumedTokensCount = _embed_inps.Count; - _embeds.Clear(); + EvaluateMtmdChunks(ref nPast, previousConsumed, nameof(InteractiveExecutor)); } else { @@ -437,30 +259,12 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In } } } - else if (IsMultiModal && _mtmdChunks is not null) + else if (IsMultiModal && MtmdChunks is not null) { _is_prompt_run = false; var nPast = (long)_pastTokensCount; var previousConsumed = _consumedTokensCount; - var evalStatus = ClipModel!.EvaluateChunks(_mtmdChunks, Context.NativeHandle, ref nPast, seqId: 0, nBatch: checked((int)Context.BatchSize), logitsLast: true); - if (evalStatus != 0) - { - _logger?.LogError("[InteractiveExecutor] Failed to evaluate multimodal chunks. Status: {Status}", evalStatus); - DisposeMtmdChunks(); - throw new RuntimeError($"Failed to evaluate multimodal chunks. 
Status: {evalStatus}."); - } - - _pastTokensCount = checked((int)nPast); - DisposeMtmdChunks(); - - if (!string.IsNullOrEmpty(_pathSession) && _embed_inps.Count > previousConsumed) - { - _session_tokens.AddRange(_embed_inps.Skip(previousConsumed)); - _n_session_consumed = _session_tokens.Count; - } - - _consumedTokensCount = _embed_inps.Count; - _embeds.Clear(); + EvaluateMtmdChunks(ref nPast, previousConsumed, nameof(InteractiveExecutor)); } _embeds.Clear(); From 385c62a6348b56a7b192c36297481d3262676a28 Mon Sep 17 00:00:00 2001 From: SignalRT Date: Sat, 25 Oct 2025 10:31:18 +0200 Subject: [PATCH 13/35] Rename SafeMtmdWeights --- .../Examples/BatchedExecutorMtmd.cs | 2 +- .../Examples/MtmdInteractiveModeExecute.cs | 2 +- LLama.Unittest/MtmdExecutorTests.cs | 4 +-- LLama.Unittest/MtmdWeightsTests.cs | 28 +++++++++---------- LLama/Abstractions/ILLamaExecutor.cs | 2 +- LLama/Batched/BatchedExecutor.cs | 8 +++--- LLama/Batched/Conversation.cs | 2 +- LLama/LLamaExecutorBase.cs | 8 +++--- LLama/LLamaInstructExecutor.cs | 2 +- LLama/LLamaInteractExecutor.cs | 2 +- LLama/LLamaStatelessExecutor.cs | 2 +- LLama/{SafeMtmdWeights.cs => MtmdWeights.cs} | 10 +++---- 12 files changed, 36 insertions(+), 36 deletions(-) rename LLama/{SafeMtmdWeights.cs => MtmdWeights.cs} (87%) diff --git a/LLama.Examples/Examples/BatchedExecutorMtmd.cs b/LLama.Examples/Examples/BatchedExecutorMtmd.cs index b62f8b120..8cdc4dac5 100644 --- a/LLama.Examples/Examples/BatchedExecutorMtmd.cs +++ b/LLama.Examples/Examples/BatchedExecutorMtmd.cs @@ -30,7 +30,7 @@ public static async Task Run() mtmdParams.UseGpu = false; var marker = mtmdParams.MediaMarker ?? NativeApi.MtmdDefaultMarker() ?? ""; - using var mtmd = await SafeMtmdWeights.LoadFromFileAsync(UserSettings.GetMMProjPath(), model, mtmdParams); // multimodal helper weights + using var mtmd = await MtmdWeights.LoadFromFileAsync(UserSettings.GetMMProjPath(), model, mtmdParams); // multimodal helper weights using var executor = new BatchedExecutor(model, parameters, mtmd); // drives batched token + chunk evaluation diff --git a/LLama.Examples/Examples/MtmdInteractiveModeExecute.cs b/LLama.Examples/Examples/MtmdInteractiveModeExecute.cs index ca0de3b77..b6395d0f8 100644 --- a/LLama.Examples/Examples/MtmdInteractiveModeExecute.cs +++ b/LLama.Examples/Examples/MtmdInteractiveModeExecute.cs @@ -30,7 +30,7 @@ public static async Task Run() using var context = model.CreateContext(parameters); // Mtmd Init - using var clipModel = await SafeMtmdWeights.LoadFromFileAsync(multiModalProj, model, mtmdParameters ); + using var clipModel = await MtmdWeights.LoadFromFileAsync(multiModalProj, model, mtmdParameters ); var mediaMarker = mtmdParameters.MediaMarker ?? NativeApi.MtmdDefaultMarker() ?? 
""; diff --git a/LLama.Unittest/MtmdExecutorTests.cs b/LLama.Unittest/MtmdExecutorTests.cs index 75a96b261..2533a75d8 100644 --- a/LLama.Unittest/MtmdExecutorTests.cs +++ b/LLama.Unittest/MtmdExecutorTests.cs @@ -13,7 +13,7 @@ public class MtmdExecutorTests : IDisposable { private readonly LLamaWeights _weights; private readonly MtmdContextParams _mtmdParams; - private readonly SafeMtmdWeights _mtmd; + private readonly MtmdWeights _mtmd; private readonly ModelParams _modelParams; public MtmdExecutorTests() @@ -30,7 +30,7 @@ public MtmdExecutorTests() _mtmdParams.NThreads = Math.Max(1, Constants.CIGpuLayerCount); _mtmdParams.UseGpu = false; - _mtmd = SafeMtmdWeights.LoadFromFile(Constants.MtmdMmpPath, _weights, _mtmdParams); + _mtmd = MtmdWeights.LoadFromFile(Constants.MtmdMmpPath, _weights, _mtmdParams); } public void Dispose() diff --git a/LLama.Unittest/MtmdWeightsTests.cs b/LLama.Unittest/MtmdWeightsTests.cs index 947bbd1ea..9ffffc518 100644 --- a/LLama.Unittest/MtmdWeightsTests.cs +++ b/LLama.Unittest/MtmdWeightsTests.cs @@ -12,7 +12,7 @@ public sealed class MtmdWeightTests : IDisposable { private readonly LLamaWeights _llamaWeights; - private readonly SafeMtmdWeights _safeMtmdWeights; + private readonly MtmdWeights _mtmdWeights; private readonly LLamaContext _context; private readonly MtmdContextParams _mtmdParams; private readonly string _mediaMarker; @@ -33,20 +33,20 @@ public MtmdWeightTests() _mediaMarker = _mtmdParams.MediaMarker ?? throw new InvalidOperationException("MTMD media marker unavailable."); - _safeMtmdWeights = SafeMtmdWeights.LoadFromFile(Constants.MtmdMmpPath, _llamaWeights, _mtmdParams); + _mtmdWeights = MtmdWeights.LoadFromFile(Constants.MtmdMmpPath, _llamaWeights, _mtmdParams); _context = _llamaWeights.CreateContext(@params); } public void Dispose() { _context.Dispose(); - _safeMtmdWeights.Dispose(); + _mtmdWeights.Dispose(); _llamaWeights.Dispose(); } private SafeMtmdInputChunks TokenizeWithEmbed(Func loadEmbed) { - _safeMtmdWeights.ClearMedia(); + _mtmdWeights.ClearMedia(); var embed = loadEmbed(); Assert.NotNull(embed); @@ -58,7 +58,7 @@ private SafeMtmdInputChunks TokenizeWithEmbed(Func loadEmbed) Assert.False(embed.IsAudio); Assert.True(embed.GetDataSpan().Length > 0); - var status = _safeMtmdWeights.Tokenize(_mediaMarker, addSpecial: true, parseSpecial: true, out var chunks); + var status = _mtmdWeights.Tokenize(_mediaMarker, addSpecial: true, parseSpecial: true, out var chunks); Assert.Equal(0, status); Assert.NotNull(chunks); @@ -69,7 +69,7 @@ private SafeMtmdInputChunks TokenizeWithEmbed(Func loadEmbed) private void AssertChunksEvaluate(SafeMtmdInputChunks chunks) { long nPast = 0; - var eval = _safeMtmdWeights.EvaluateChunks(chunks, _context.NativeHandle, ref nPast, seqId: 0, nBatch: checked((int)_context.BatchSize), logitsLast: true); + var eval = _mtmdWeights.EvaluateChunks(chunks, _context.NativeHandle, ref nPast, seqId: 0, nBatch: checked((int)_context.BatchSize), logitsLast: true); Assert.Equal(0, eval); Assert.True(nPast > 0); } @@ -77,7 +77,7 @@ private void AssertChunksEvaluate(SafeMtmdInputChunks chunks) [Fact,Trait("Category", "NoCI")] public void EmbedImageAsFileName() { - using var chunks = TokenizeWithEmbed(() => _safeMtmdWeights.LoadMedia(Constants.MtmdImage)); + using var chunks = TokenizeWithEmbed(() => _mtmdWeights.LoadMedia(Constants.MtmdImage)); AssertChunksEvaluate(chunks); } @@ -85,14 +85,14 @@ public void EmbedImageAsFileName() public void EmbedImageAsBinary() { var imageBytes = File.ReadAllBytes(Constants.MtmdImage); - using var 
chunks = TokenizeWithEmbed(() => _safeMtmdWeights.LoadMedia(imageBytes)); + using var chunks = TokenizeWithEmbed(() => _mtmdWeights.LoadMedia(imageBytes)); AssertChunksEvaluate(chunks); } [Fact,Trait("Category", "NoCI")] public void TokenizeProvidesChunkMetadata() { - using var chunks = TokenizeWithEmbed(() => _safeMtmdWeights.LoadMedia(Constants.MtmdImage)); + using var chunks = TokenizeWithEmbed(() => _mtmdWeights.LoadMedia(Constants.MtmdImage)); Assert.True(chunks.Size > 0); @@ -128,12 +128,12 @@ public void TokenizeProvidesChunkMetadata() Assert.True(imageChunks > 0); Assert.True(totalTokens > 0); - Assert.Equal(totalTokens, _safeMtmdWeights.CountTokens(chunks)); - Assert.Equal(totalPositions, _safeMtmdWeights.CountPositions(chunks)); - Assert.True(_safeMtmdWeights.SupportsVision); - Assert.False(_safeMtmdWeights.SupportsAudio); + Assert.Equal(totalTokens, _mtmdWeights.CountTokens(chunks)); + Assert.Equal(totalPositions, _mtmdWeights.CountPositions(chunks)); + Assert.True(_mtmdWeights.SupportsVision); + Assert.False(_mtmdWeights.SupportsAudio); - var audioBitrate = _safeMtmdWeights.AudioBitrate; + var audioBitrate = _mtmdWeights.AudioBitrate; Assert.True(audioBitrate <= 0); } } diff --git a/LLama/Abstractions/ILLamaExecutor.cs b/LLama/Abstractions/ILLamaExecutor.cs index 92276e4a6..974a4ffbf 100644 --- a/LLama/Abstractions/ILLamaExecutor.cs +++ b/LLama/Abstractions/ILLamaExecutor.cs @@ -23,7 +23,7 @@ public interface ILLamaExecutor /// /// Multi-Modal Projections / Clip Model weights /// - public SafeMtmdWeights? ClipModel { get; } + public MtmdWeights? ClipModel { get; } /// /// List of media: List of media for Multi-Modal models. diff --git a/LLama/Batched/BatchedExecutor.cs b/LLama/Batched/BatchedExecutor.cs index 40468c98d..1a6698b1a 100644 --- a/LLama/Batched/BatchedExecutor.cs +++ b/LLama/Batched/BatchedExecutor.cs @@ -84,7 +84,7 @@ public BatchedExecutor(LLamaWeights model, IContextParams contextParams) { } - public BatchedExecutor(LLamaWeights model, IContextParams contextParams, SafeMtmdWeights? clipModel) + public BatchedExecutor(LLamaWeights model, IContextParams contextParams, MtmdWeights? clipModel) { Model = model; Context = model.CreateContext(contextParams); @@ -92,7 +92,7 @@ public BatchedExecutor(LLamaWeights model, IContextParams contextParams, SafeMtm Epoch = 1; } - public SafeMtmdWeights? ClipModel { get; } + public MtmdWeights? 
ClipModel { get; } /// /// Start a new @@ -374,11 +374,11 @@ public Task DecodeAsync(LLamaContext ctx, CancellationToken token) private class MtmdChunkBatch : IBatch { - private readonly SafeMtmdWeights _clipModel; + private readonly MtmdWeights _clipModel; private readonly Conversation _conversation; private readonly Conversation.MtmdChunkSequence _sequence; - public MtmdChunkBatch(SafeMtmdWeights clipModel, Conversation conversation, Conversation.MtmdChunkSequence sequence) + public MtmdChunkBatch(MtmdWeights clipModel, Conversation conversation, Conversation.MtmdChunkSequence sequence) { _clipModel = clipModel; _conversation = conversation; diff --git a/LLama/Batched/Conversation.cs b/LLama/Batched/Conversation.cs index 2311c8a0c..89d725e97 100644 --- a/LLama/Batched/Conversation.cs +++ b/LLama/Batched/Conversation.cs @@ -86,7 +86,7 @@ private MtmdChunkSequence(SafeMtmdInputChunks chunks, List textToken TotalPositions = totalPositions; } - public static MtmdChunkSequence Create(SafeMtmdInputChunks chunks, SafeMtmdWeights clipModel) + public static MtmdChunkSequence Create(SafeMtmdInputChunks chunks, MtmdWeights clipModel) { var textTokens = new List(); diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index bb1c27a35..1e264ebab 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -81,7 +81,7 @@ public bool IsMultiModal } /// - public SafeMtmdWeights? ClipModel { get; } + public MtmdWeights? ClipModel { get; } /// public List Embeds { get; } @@ -117,12 +117,12 @@ protected StatefulExecutorBase(LLamaContext context, ILogger? logger = null) /// Initialize a multimodal executor with the supplied MTMD weights. /// /// LLama context used for all native interactions. - /// Multimodal weights to associate with this executor. + /// Multimodal weights to associate with this executor. /// Optional logger for diagnostic output. - public StatefulExecutorBase(LLamaContext context, SafeMtmdWeights safeMtmdWeights, ILogger? logger = null) : + public StatefulExecutorBase(LLamaContext context, MtmdWeights mtmdWeights, ILogger? logger = null) : this( context, logger ) { - ClipModel = safeMtmdWeights; + ClipModel = mtmdWeights; } /// diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index b7a0f9ec7..a31cab211 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -48,7 +48,7 @@ public InstructExecutor(LLamaContext context, } public InstructExecutor(LLamaContext context, - SafeMtmdWeights clipModel, + MtmdWeights clipModel, string instructionPrefix = "\n\n### Instruction:\n\n", string instructionSuffix = "\n\n### Response:\n\n", ILogger? logger = null) diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index da6ed53a9..1359447c4 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -46,7 +46,7 @@ public InteractiveExecutor(LLamaContext context, ILogger? logger = null) /// LLama context to operate against. /// Multimodal weights (MTMD) to attach to the executor. /// Optional logger for diagnostic output. - public InteractiveExecutor(LLamaContext context, SafeMtmdWeights clipModel, ILogger? logger = null) + public InteractiveExecutor(LLamaContext context, MtmdWeights clipModel, ILogger? 
logger = null) : base(context, clipModel, logger) { } diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs index 94bc60830..a895054d4 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -28,7 +28,7 @@ public class StatelessExecutor public bool IsMultiModal => false; /// - public SafeMtmdWeights? ClipModel => default; + public MtmdWeights? ClipModel => default; /// public List Embeds { get; } diff --git a/LLama/SafeMtmdWeights.cs b/LLama/MtmdWeights.cs similarity index 87% rename from LLama/SafeMtmdWeights.cs rename to LLama/MtmdWeights.cs index e490049b4..945ccecc2 100644 --- a/LLama/SafeMtmdWeights.cs +++ b/LLama/MtmdWeights.cs @@ -9,26 +9,26 @@ namespace LLama; /// /// Lightweight wrapper around the MTMD native context and its helpers. /// -public sealed class SafeMtmdWeights : IDisposable +public sealed class MtmdWeights : IDisposable { public SafeMtmdModelHandle NativeHandle { get; } - private SafeMtmdWeights(SafeMtmdModelHandle handle) + private MtmdWeights(SafeMtmdModelHandle handle) { NativeHandle = handle ?? throw new ArgumentNullException(nameof(handle)); } - public static SafeMtmdWeights LoadFromFile(string mmProject, LLamaWeights textModel, MtmdContextParams mtmdCtxParams) + public static MtmdWeights LoadFromFile(string mmProject, LLamaWeights textModel, MtmdContextParams mtmdCtxParams) { if (mmProject == null) throw new ArgumentNullException(nameof(mmProject)); if (textModel == null) throw new ArgumentNullException(nameof(textModel)); if (mtmdCtxParams == null) throw new ArgumentNullException(nameof(mtmdCtxParams)); var handle = SafeMtmdModelHandle.LoadFromFile(mmProject, textModel, mtmdCtxParams); - return new SafeMtmdWeights(handle); + return new MtmdWeights(handle); } - public static Task LoadFromFileAsync(string mmProject, LLamaWeights textModel, MtmdContextParams mtmdCtxParams, CancellationToken token = default) + public static Task LoadFromFileAsync(string mmProject, LLamaWeights textModel, MtmdContextParams mtmdCtxParams, CancellationToken token = default) { return Task.Run(() => LoadFromFile(mmProject, textModel, mtmdCtxParams), token); } From 2cdcc5ac2180657b44d318ff9799224c2e76904e Mon Sep 17 00:00:00 2001 From: SignalRT Date: Sat, 25 Oct 2025 10:48:18 +0200 Subject: [PATCH 14/35] Implement SafeHandle --- LLama/Native/SafeMtmdEmbed.cs | 104 +++++++++++++++++--------- LLama/Native/SafeMtmdInputChunk.cs | 112 +++++++++++++++++----------- LLama/Native/SafeMtmdInputChunks.cs | 80 ++++++++++++-------- 3 files changed, 188 insertions(+), 108 deletions(-) diff --git a/LLama/Native/SafeMtmdEmbed.cs b/LLama/Native/SafeMtmdEmbed.cs index c651db102..426b77c79 100644 --- a/LLama/Native/SafeMtmdEmbed.cs +++ b/LLama/Native/SafeMtmdEmbed.cs @@ -8,20 +8,25 @@ namespace LLama.Native /// Managed wrapper around mtmd_bitmap* resources. Instances own the native pointer /// and ensure proper cleanup when disposed. /// - public sealed class SafeMtmdEmbed : IDisposable + public sealed class SafeMtmdEmbed : SafeLLamaHandleBase { /// /// Raw pointer to the native bitmap structure. Internal so other wrappers can interop. /// - internal IntPtr NativePtr { get; private set; } - - private bool _disposed; + internal IntPtr NativePtr + { + get + { + EnsureNotDisposed(); + return DangerousGetHandle(); + } + } private SafeMtmdEmbed(IntPtr ptr) + : base(ptr, ownsHandle: true) { - NativePtr = ptr != IntPtr.Zero - ? 
ptr - : throw new InvalidOperationException("Failed to create MTMD bitmap."); + if (IsInvalid) + throw new InvalidOperationException("Failed to create MTMD bitmap."); } /// @@ -154,8 +159,7 @@ public uint Nx { get { - EnsureNotDisposed(); - return NativeApi.mtmd_bitmap_get_nx(NativePtr); + return WithHandle(ptr => NativeApi.mtmd_bitmap_get_nx(ptr)); } } @@ -166,8 +170,7 @@ public uint Ny { get { - EnsureNotDisposed(); - return NativeApi.mtmd_bitmap_get_ny(NativePtr); + return WithHandle(ptr => NativeApi.mtmd_bitmap_get_ny(ptr)); } } @@ -178,8 +181,7 @@ public bool IsAudio { get { - EnsureNotDisposed(); - return NativeApi.mtmd_bitmap_is_audio(NativePtr); + return WithHandle(ptr => NativeApi.mtmd_bitmap_is_audio(ptr)); } } @@ -190,14 +192,19 @@ public string? Id { get { - EnsureNotDisposed(); - var ptr = NativeApi.mtmd_bitmap_get_id(NativePtr); - return NativeApi.PtrToStringUtf8(ptr); + return WithHandle(ptr => + { + var idPtr = NativeApi.mtmd_bitmap_get_id(ptr); + return NativeApi.PtrToStringUtf8(idPtr); + }); } set { - EnsureNotDisposed(); - NativeApi.mtmd_bitmap_set_id(NativePtr, value); + WithHandle(ptr => + { + NativeApi.mtmd_bitmap_set_id(ptr, value); + return 0; + }); } } @@ -210,38 +217,63 @@ public unsafe ReadOnlySpan GetDataSpan() { EnsureNotDisposed(); - var dataPtr = (byte*)NativeApi.mtmd_bitmap_get_data(NativePtr); - var length = checked((int)NativeApi.mtmd_bitmap_get_n_bytes(NativePtr).ToUInt64()); - return dataPtr == null || length == 0 ? ReadOnlySpan.Empty : new ReadOnlySpan(dataPtr, length); + bool added = false; + try + { + DangerousAddRef(ref added); + var ptr = DangerousGetHandle(); + var dataPtr = (byte*)NativeApi.mtmd_bitmap_get_data(ptr); + var length = checked((int)NativeApi.mtmd_bitmap_get_n_bytes(ptr).ToUInt64()); + return dataPtr == null || length == 0 + ? ReadOnlySpan.Empty + : new ReadOnlySpan(dataPtr, length); + } + finally + { + if (added) + DangerousRelease(); + } } /// /// Release the underlying native bitmap. /// - public void Dispose() + protected override bool ReleaseHandle() { - if (_disposed) - return; - - if (NativePtr != IntPtr.Zero) + if (handle != IntPtr.Zero) { - NativeApi.mtmd_bitmap_free(NativePtr); - NativePtr = IntPtr.Zero; + NativeApi.mtmd_bitmap_free(handle); + SetHandle(IntPtr.Zero); } - _disposed = true; - GC.SuppressFinalize(this); + return true; } - /// - /// Finalizer to ensure native resources are reclaimed when Dispose is not invoked. - /// - ~SafeMtmdEmbed() => Dispose(); - private void EnsureNotDisposed() { - if (_disposed || NativePtr == IntPtr.Zero) + if (IsClosed || IsInvalid) throw new ObjectDisposedException(nameof(SafeMtmdEmbed)); } + + private T WithHandle(Func action) + { + EnsureNotDisposed(); + + bool added = false; + try + { + DangerousAddRef(ref added); + var ptr = DangerousGetHandle(); + if (ptr == IntPtr.Zero) + throw new ObjectDisposedException(nameof(SafeMtmdEmbed)); + + return action(ptr); + } + finally + { + if (added) + DangerousRelease(); + } + } } } diff --git a/LLama/Native/SafeMtmdInputChunk.cs b/LLama/Native/SafeMtmdInputChunk.cs index 59d1897ef..3ab4f7ccc 100644 --- a/LLama/Native/SafeMtmdInputChunk.cs +++ b/LLama/Native/SafeMtmdInputChunk.cs @@ -8,7 +8,7 @@ namespace LLama.Native; /// underlying native pointer (when created via ) or act as non-owning views /// produced by the tokenizer. /// -public sealed class SafeMtmdInputChunk : IDisposable +public sealed class SafeMtmdInputChunk : SafeLLamaHandleBase { /// /// Chunk modality returned by the native tokenizer. 
@@ -23,15 +23,18 @@ public enum SafeMtmdInputChunkType /// /// Raw pointer to the native chunk structure. /// - public IntPtr NativePtr { get; private set; } - - private bool _ownsPtr; - private bool _disposed; + public IntPtr NativePtr + { + get + { + EnsureNotDisposed(); + return DangerousGetHandle(); + } + } - private SafeMtmdInputChunk(IntPtr ptr, bool owns) + private SafeMtmdInputChunk(IntPtr handle, bool ownsHandle) + : base(handle, ownsHandle) { - NativePtr = ptr; - _ownsPtr = owns; } /// @@ -40,7 +43,7 @@ private SafeMtmdInputChunk(IntPtr ptr, bool owns) /// Pointer returned by the native tokenizer. /// Managed wrapper, or null when the pointer is null. public static SafeMtmdInputChunk Wrap(IntPtr ptr) - => ptr == IntPtr.Zero ? null : new SafeMtmdInputChunk(ptr, false); + => ptr == IntPtr.Zero ? null : new SafeMtmdInputChunk(ptr, ownsHandle: false); /// /// Create an owning copy of the current chunk. The caller becomes responsible for disposal. @@ -49,10 +52,11 @@ public static SafeMtmdInputChunk Wrap(IntPtr ptr) /// Thrown when the current wrapper has been disposed. public SafeMtmdInputChunk Copy() { - EnsureNotDisposed(); - - var p = NativeApi.mtmd_input_chunk_copy(NativePtr); - return p == IntPtr.Zero ? null : new SafeMtmdInputChunk(p, true); + return WithHandle(ptr => + { + var clone = NativeApi.mtmd_input_chunk_copy(ptr); + return clone == IntPtr.Zero ? null : new SafeMtmdInputChunk(clone, ownsHandle: true); + }); } /// @@ -62,8 +66,7 @@ public SafeMtmdInputChunkType Type { get { - EnsureNotDisposed(); - return (SafeMtmdInputChunkType)NativeApi.mtmd_input_chunk_get_type(NativePtr); + return WithHandle(ptr => (SafeMtmdInputChunkType)NativeApi.mtmd_input_chunk_get_type(ptr)); } } @@ -74,8 +77,7 @@ public ulong NTokens { get { - EnsureNotDisposed(); - return NativeApi.mtmd_input_chunk_get_n_tokens(NativePtr).ToUInt64(); + return WithHandle(ptr => NativeApi.mtmd_input_chunk_get_n_tokens(ptr).ToUInt64()); } } @@ -86,8 +88,11 @@ public string Id { get { - EnsureNotDisposed(); - return Marshal.PtrToStringAnsi(NativeApi.mtmd_input_chunk_get_id(NativePtr)) ?? string.Empty; + return WithHandle(ptr => + { + var idPtr = NativeApi.mtmd_input_chunk_get_id(ptr); + return Marshal.PtrToStringAnsi(idPtr) ?? string.Empty; + }); } } @@ -98,8 +103,7 @@ public long NPos { get { - EnsureNotDisposed(); - return NativeApi.mtmd_input_chunk_get_n_pos(NativePtr); + return WithHandle(ptr => NativeApi.mtmd_input_chunk_get_n_pos(ptr)); } } @@ -112,39 +116,63 @@ public unsafe ReadOnlySpan GetTextTokensSpan() { EnsureNotDisposed(); - UIntPtr n; - var p = (uint*)NativeApi.mtmd_input_chunk_get_tokens_text(NativePtr, out n); - return p == null ? ReadOnlySpan.Empty : new ReadOnlySpan(p, checked((int)n.ToUInt64())); + bool added = false; + try + { + DangerousAddRef(ref added); + UIntPtr nTokens; + var tokensPtr = (uint*)NativeApi.mtmd_input_chunk_get_tokens_text(DangerousGetHandle(), out nTokens); + if (tokensPtr == null) + return ReadOnlySpan.Empty; + + var length = checked((int)nTokens.ToUInt64()); + return new ReadOnlySpan(tokensPtr, length); + } + finally + { + if (added) + DangerousRelease(); + } } /// - /// Release the underlying native resources if this instance owns them. + /// Releases the native chunk when ownership is held by this instance. 
/// - public void Dispose() + protected override bool ReleaseHandle() { - if (_disposed) - return; - - if (_ownsPtr && NativePtr != IntPtr.Zero) + if (handle != IntPtr.Zero) { - NativeApi.mtmd_input_chunk_free(NativePtr); + NativeApi.mtmd_input_chunk_free(handle); + SetHandle(IntPtr.Zero); } - NativePtr = IntPtr.Zero; - _ownsPtr = false; - _disposed = true; - - GC.SuppressFinalize(this); + return true; } - /// - /// Finalizer to ensure native memory is reclaimed when Dispose is not called by owners. - /// - ~SafeMtmdInputChunk() => Dispose(); - private void EnsureNotDisposed() { - if (_disposed || NativePtr == IntPtr.Zero) + if (IsClosed || IsInvalid) throw new ObjectDisposedException(nameof(SafeMtmdInputChunk)); } + + private T WithHandle(Func action) + { + EnsureNotDisposed(); + + bool added = false; + try + { + DangerousAddRef(ref added); + var ptr = DangerousGetHandle(); + if (ptr == IntPtr.Zero) + throw new ObjectDisposedException(nameof(SafeMtmdInputChunk)); + + return action(ptr); + } + finally + { + if (added) + DangerousRelease(); + } + } } diff --git a/LLama/Native/SafeMtmdInputChunks.cs b/LLama/Native/SafeMtmdInputChunks.cs index 2081cd0a6..bd095f36a 100644 --- a/LLama/Native/SafeMtmdInputChunks.cs +++ b/LLama/Native/SafeMtmdInputChunks.cs @@ -6,44 +6,39 @@ namespace LLama.Native; /// /// Managed lifetime wrapper around a native mtmd_input_chunks collection returned by the tokenizer. /// -public sealed class SafeMtmdInputChunks : IDisposable +public sealed class SafeMtmdInputChunks : SafeLLamaHandleBase { /// /// Raw pointer to the native chunk collection. Internal to allow other wrappers to interop safely. /// - internal IntPtr NativePtr { get; private set; } - - private bool _disposed; + internal IntPtr NativePtr + { + get + { + EnsureNotDisposed(); + return DangerousGetHandle(); + } + } internal SafeMtmdInputChunks(IntPtr ptr) + : base(ptr, ownsHandle: true) { - NativePtr = ptr; + if (IsInvalid) + throw new InvalidOperationException("Native MTMD chunk collection pointer is null."); } /// - /// Releases the native chunk collection and suppresses finalization. + /// Releases the native chunk collection. /// - public void Dispose() + protected override bool ReleaseHandle() { - if (_disposed) - return; - - if (NativePtr != IntPtr.Zero) + if (handle != IntPtr.Zero) { - NativeApi.mtmd_input_chunks_free(NativePtr); - NativePtr = IntPtr.Zero; + NativeApi.mtmd_input_chunks_free(handle); + SetHandle(IntPtr.Zero); } - _disposed = true; - GC.SuppressFinalize(this); - } - - /// - /// Finalizer to ensure native memory is reclaimed if Dispose is not called. - /// - ~SafeMtmdInputChunks() - { - Dispose(); + return true; } /// @@ -53,8 +48,7 @@ public ulong Size { get { - EnsureNotDisposed(); - return NativeApi.mtmd_input_chunks_size(NativePtr).ToUInt64(); + return WithHandle(ptr => NativeApi.mtmd_input_chunks_size(ptr).ToUInt64()); } } @@ -68,10 +62,14 @@ public ulong Size /// The requested index is outside of the valid range. 
public IntPtr GetChunkPtr(ulong index) { - EnsureNotDisposed(); + return WithHandle(ptr => + { + var size = NativeApi.mtmd_input_chunks_size(ptr).ToUInt64(); + if (index >= size) + throw new IndexOutOfRangeException(); - if (index >= Size) throw new IndexOutOfRangeException(); - return NativeApi.mtmd_input_chunks_get(NativePtr, (UIntPtr)index); + return NativeApi.mtmd_input_chunks_get(ptr, (UIntPtr)index); + }); } /// @@ -84,7 +82,8 @@ public IEnumerable Enumerate() { EnsureNotDisposed(); - for (ulong i = 0; i < Size; i++) + var count = Size; + for (ulong i = 0; i < count; i++) { var chunk = SafeMtmdInputChunk.Wrap(GetChunkPtr(i)); if (chunk != null) @@ -97,7 +96,28 @@ public IEnumerable Enumerate() private void EnsureNotDisposed() { - if (_disposed || NativePtr == IntPtr.Zero) + if (IsClosed || IsInvalid) throw new ObjectDisposedException(nameof(SafeMtmdInputChunks)); } + + private T WithHandle(Func action) + { + EnsureNotDisposed(); + + bool added = false; + try + { + DangerousAddRef(ref added); + var ptr = DangerousGetHandle(); + if (ptr == IntPtr.Zero) + throw new ObjectDisposedException(nameof(SafeMtmdInputChunks)); + + return action(ptr); + } + finally + { + if (added) + DangerousRelease(); + } + } } From 63d8ce4489069b5faa11a00259ce0b2d0ea98741 Mon Sep 17 00:00:00 2001 From: SignalRT Date: Sat, 25 Oct 2025 14:38:05 +0200 Subject: [PATCH 15/35] Add IntPtrExtension To manage PtrToString conversions --- LLama/Extensions/IModelParamsExtensions.cs | 4 +- LLama/Extensions/IntPtrExtensions.cs | 51 ++++++++++++++++++++++ LLama/Native/MtmdContextParams.cs | 28 +----------- LLama/Native/NativeApi.Mtmd.cs | 32 +------------- LLama/Native/SafeLLamaSamplerHandle.cs | 6 ++- LLama/Native/SafeMtmdEmbed.cs | 2 +- LLama/Native/SafeMtmdInputChunk.cs | 4 +- 7 files changed, 64 insertions(+), 63 deletions(-) create mode 100644 LLama/Extensions/IntPtrExtensions.cs diff --git a/LLama/Extensions/IModelParamsExtensions.cs b/LLama/Extensions/IModelParamsExtensions.cs index 2939318da..6c307861d 100644 --- a/LLama/Extensions/IModelParamsExtensions.cs +++ b/LLama/Extensions/IModelParamsExtensions.cs @@ -115,7 +115,7 @@ private static IReadOnlyDictionary GetAvailableBufferTypes() var dev = NativeApi.ggml_backend_dev_get(i); var buft = NativeApi.ggml_backend_dev_buffer_type(dev); - var name = Marshal.PtrToStringAnsi(NativeApi.ggml_backend_buft_name(buft)); + var name = NativeApi.ggml_backend_buft_name(buft).PtrToString(); if (string.IsNullOrEmpty(name)) continue; @@ -165,4 +165,4 @@ private static IReadOnlyDictionary GetAvailableBufferTypes() return (LLamaModelTensorBufferOverride*)overrideArrayPin.Pointer; } -} \ No newline at end of file +} diff --git a/LLama/Extensions/IntPtrExtensions.cs b/LLama/Extensions/IntPtrExtensions.cs new file mode 100644 index 000000000..eb5c90850 --- /dev/null +++ b/LLama/Extensions/IntPtrExtensions.cs @@ -0,0 +1,51 @@ +using System; +using System.Runtime.InteropServices; +using System.Text; + +namespace LLama.Extensions; + +public static class IntPtrExtensions +{ + + /// + /// Converts a native UTF-8 string pointer to a managed string, returning a fallback value when no data is available. + /// + /// Pointer to a null-terminated UTF-8 string. + /// Value to return when the pointer is or when the string is empty. + /// Managed string representation of the native data, or when unavailable. + public static string PtrToStringWithDefault(this IntPtr ptr, string defaultValue="") + { + return ptr.PtrToString() ?? 
defaultValue; + } + + /// + /// Converts a pointer to a null-terminated UTF-8 string into a managed string. + /// + /// Pointer to the first byte of a null-terminated UTF-8 string. + /// Managed string representation, or null when the pointer is zero or the string is empty. + public static string? PtrToString(this IntPtr ptr ) + { + if (ptr == IntPtr.Zero) + return null; + +#if NETSTANDARD2_0 + unsafe + { + var length = 0; + var current = (byte*)ptr; + while (current[length] != 0) + length++; + + if (length == 0) + return null; + + var buffer = new byte[length]; + Marshal.Copy(ptr, buffer, 0, length); + return Encoding.UTF8.GetString(buffer); + } +#else + return Marshal.PtrToStringUTF8(ptr); +#endif + } + +} diff --git a/LLama/Native/MtmdContextParams.cs b/LLama/Native/MtmdContextParams.cs index 5b282d802..fc8d6b5f8 100644 --- a/LLama/Native/MtmdContextParams.cs +++ b/LLama/Native/MtmdContextParams.cs @@ -51,35 +51,11 @@ public static MtmdContextParams Default() PrintTimings = native.print_timings, NThreads = native.n_threads, Verbosity = native.verbosity, - ImageMarker = PtrToString(native.image_marker), - MediaMarker = PtrToString(native.media_marker) + ImageMarker = native.image_marker.PtrToString(), + MediaMarker = native.media_marker.PtrToString() }; } - private static string? PtrToString(IntPtr ptr) - { - if (ptr == IntPtr.Zero) - return null; - -#if NETSTANDARD2_0 - unsafe - { - var length = 0; - var current = (byte*)ptr; - while (current[length] != 0) - length++; - - if (length == 0) - return string.Empty; - - var buffer = new byte[length]; - Marshal.Copy(ptr, buffer, 0, length); - return Encoding.UTF8.GetString(buffer); - } -#else - return Marshal.PtrToStringUTF8(ptr); -#endif - } /// /// Convert the managed representation to a native structure, pinning strings for the duration of the scope. diff --git a/LLama/Native/NativeApi.Mtmd.cs b/LLama/Native/NativeApi.Mtmd.cs index bfd6193c2..0aa8a314a 100644 --- a/LLama/Native/NativeApi.Mtmd.cs +++ b/LLama/Native/NativeApi.Mtmd.cs @@ -9,35 +9,7 @@ namespace LLama.Native; /// public static partial class NativeApi { - /// - /// Convert a UTF-8 encoded native string pointer into a managed . - /// Returns null when the pointer is zero. - /// - public static string? PtrToStringUtf8(IntPtr ptr) - { - if (ptr == IntPtr.Zero) - return null; - -#if NETSTANDARD2_0 - unsafe - { - var current = (byte*)ptr; - var length = 0; - while (current[length] != 0) - length++; - - if (length == 0) - return string.Empty; - - var buffer = new byte[length]; - Marshal.Copy(ptr, buffer, 0, length); - return Encoding.UTF8.GetString(buffer); - } -#else - return Marshal.PtrToStringUTF8(ptr); -#endif - } - + /// /// Native context parameters returned by . /// @@ -59,7 +31,7 @@ internal struct mtmd_context_params /// Retrieve the default multimodal marker text. /// public static string? 
MtmdDefaultMarker() - => PtrToStringUtf8(mtmd_default_marker()); + => mtmd_default_marker().PtrToString(); [DllImport(mtmdLibraryName, EntryPoint = "mtmd_context_params_default", CallingConvention = CallingConvention.Cdecl)] internal static extern mtmd_context_params mtmd_context_params_default(); diff --git a/LLama/Native/SafeLLamaSamplerHandle.cs b/LLama/Native/SafeLLamaSamplerHandle.cs index bad1a1974..fe8ccb4e6 100644 --- a/LLama/Native/SafeLLamaSamplerHandle.cs +++ b/LLama/Native/SafeLLamaSamplerHandle.cs @@ -1,6 +1,8 @@ using System; using System.Collections.Generic; +using System.Runtime.InteropServices; using System.Text; +using LLama.Extensions; namespace LLama.Native; @@ -119,7 +121,7 @@ public string GetName(int index) if (index < 0 || index >= Count) throw new ArgumentOutOfRangeException(nameof(index)); - return Marshal.PtrToStringAnsi(llama_sampler_name(llama_sampler_chain_get(this, index))) ?? "Unknown Name"; + return llama_sampler_name(llama_sampler_chain_get(this, index)).PtrToStringWithDefault("Unknown Name"); [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] static extern IntPtr llama_sampler_name(IntPtr smpl); @@ -904,4 +906,4 @@ public interface ICustomSampler /// Create a clone of this sampler /// ICustomSampler Clone(); -} \ No newline at end of file +} diff --git a/LLama/Native/SafeMtmdEmbed.cs b/LLama/Native/SafeMtmdEmbed.cs index 426b77c79..2e92b96a7 100644 --- a/LLama/Native/SafeMtmdEmbed.cs +++ b/LLama/Native/SafeMtmdEmbed.cs @@ -195,7 +195,7 @@ public string? Id return WithHandle(ptr => { var idPtr = NativeApi.mtmd_bitmap_get_id(ptr); - return NativeApi.PtrToStringUtf8(idPtr); + return idPtr.PtrToString(); }); } set diff --git a/LLama/Native/SafeMtmdInputChunk.cs b/LLama/Native/SafeMtmdInputChunk.cs index 3ab4f7ccc..fe08d50f7 100644 --- a/LLama/Native/SafeMtmdInputChunk.cs +++ b/LLama/Native/SafeMtmdInputChunk.cs @@ -1,5 +1,5 @@ using System; -using System.Runtime.InteropServices; +using LLama.Extensions; namespace LLama.Native; @@ -91,7 +91,7 @@ public string Id return WithHandle(ptr => { var idPtr = NativeApi.mtmd_input_chunk_get_id(ptr); - return Marshal.PtrToStringAnsi(idPtr) ?? 
string.Empty; + return idPtr.PtrToStringWithDefault(string.Empty); }); } } From 32edd6f56c91179db692253de2323daec397092d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Luis=20Santiago?= Date: Sun, 26 Oct 2025 17:21:57 +0100 Subject: [PATCH 16/35] Solve bad DLL naming in Windows with MTMD libraries --- LLama/LLamaSharp.Runtime.targets | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/LLama/LLamaSharp.Runtime.targets b/LLama/LLamaSharp.Runtime.targets index 89caa042a..c99c508dd 100644 --- a/LLama/LLamaSharp.Runtime.targets +++ b/LLama/LLamaSharp.Runtime.targets @@ -396,19 +396,19 @@ PreserveNewest - runtimes/win-x64/native/noavx/libmtmd.dll + runtimes/win-x64/native/noavx/mtmd.dll PreserveNewest - runtimes/win-x64/native/avx/libmtmd.dll + runtimes/win-x64/native/avx/mtmd.dll PreserveNewest - runtimes/win-x64/native/avx2/libmtmd.dll + runtimes/win-x64/native/avx2/mtmd.dll PreserveNewest - runtimes/win-x64/native/avx512/libmtmd.dll + runtimes/win-x64/native/avx512/mtmd.dll PreserveNewest @@ -416,7 +416,7 @@ PreserveNewest - runtimes/win-x64/native/vulkan/libmtmd.dll + runtimes/win-x64/native/vulkan/mtmd.dll From 0990be3c096e6b0641120ab18ef468bfea7c18f6 Mon Sep 17 00:00:00 2001 From: Krisbiradar Date: Thu, 30 Oct 2025 01:39:38 +0530 Subject: [PATCH 17/35] Enable FlashAttention and clean up P/Invoke signatures Set FlashAttention to true in BuilderExtensions for improved performance. Refactored NativeApi P/Invoke method signatures to single-line format for better readability and consistency. --- LLama.KernelMemory/BuilderExtensions.cs | 1 + LLama/Native/NativeApi.cs | 21 +++++++-------------- 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/LLama.KernelMemory/BuilderExtensions.cs b/LLama.KernelMemory/BuilderExtensions.cs index 0aae8e69d..6ab04a8bc 100644 --- a/LLama.KernelMemory/BuilderExtensions.cs +++ b/LLama.KernelMemory/BuilderExtensions.cs @@ -77,6 +77,7 @@ public static IKernelMemoryBuilder WithLLamaSharpDefaults(this IKernelMemoryBuil SplitMode = config.SplitMode, BatchSize = 512, UBatchSize = 512, + FlashAttention = true, UseMemorymap = true }; diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs index 0a5ad6003..30aadc25e 100644 --- a/LLama/Native/NativeApi.cs +++ b/LLama/Native/NativeApi.cs @@ -99,8 +99,7 @@ public static void llama_empty_call() /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [return: MarshalAs(UnmanagedType.U1)] - public static extern bool llama_state_load_file(SafeLLamaContextHandle ctx, string path_session, - LLamaToken[] tokens_out, ulong n_token_capacity, out ulong n_token_count_out); + public static extern bool llama_state_load_file(SafeLLamaContextHandle ctx, string path_session, LLamaToken[] tokens_out, ulong n_token_capacity, out ulong n_token_count_out); /// /// Save session file @@ -112,29 +111,25 @@ public static extern bool llama_state_load_file(SafeLLamaContextHandle ctx, stri /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [return: MarshalAs(UnmanagedType.U1)] - public static extern bool llama_state_save_file(SafeLLamaContextHandle ctx, string path_session, - LLamaToken[] tokens, ulong n_token_count); + public static extern bool llama_state_save_file(SafeLLamaContextHandle ctx, string path_session, LLamaToken[] tokens, ulong n_token_count); /// /// Saves the specified sequence as a file on specified filepath. 
Can later be loaded via /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern unsafe nuint llama_state_seq_save_file(SafeLLamaContextHandle ctx, string filepath, - LLamaSeqId seq_id, LLamaToken* tokens, nuint n_token_count); + public static extern unsafe nuint llama_state_seq_save_file(SafeLLamaContextHandle ctx, string filepath, LLamaSeqId seq_id, LLamaToken* tokens, nuint n_token_count); /// /// Loads a sequence saved as a file via into the specified sequence /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern unsafe nuint llama_state_seq_load_file(SafeLLamaContextHandle ctx, string filepath, - LLamaSeqId dest_seq_id, LLamaToken* tokens_out, nuint n_token_capacity, out nuint n_token_count_out); + public static extern unsafe nuint llama_state_seq_load_file(SafeLLamaContextHandle ctx, string filepath, LLamaSeqId dest_seq_id, LLamaToken* tokens_out, nuint n_token_capacity, out nuint n_token_count_out); /// /// Set whether to use causal attention or not. If set to true, the model will only attend to the past tokens /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern void llama_set_causal_attn(SafeLLamaContextHandle ctx, - [MarshalAs(UnmanagedType.U1)] bool causalAttn); + public static extern void llama_set_causal_attn(SafeLLamaContextHandle ctx, [MarshalAs(UnmanagedType.U1)] bool causalAttn); /// /// Set whether the context outputs embeddings or not @@ -142,15 +137,13 @@ public static extern void llama_set_causal_attn(SafeLLamaContextHandle ctx, /// /// If true, embeddings will be returned but logits will not [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern void llama_set_embeddings(SafeLLamaContextHandle ctx, - [MarshalAs(UnmanagedType.U1)] bool embeddings); + public static extern void llama_set_embeddings(SafeLLamaContextHandle ctx, [MarshalAs(UnmanagedType.U1)] bool embeddings); /// /// Set abort callback /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern void llama_set_abort_callback(SafeLlamaModelHandle ctx, - IntPtr /* ggml_abort_callback */ abortCallback, IntPtr abortCallbackData); + public static extern void llama_set_abort_callback(SafeLlamaModelHandle ctx, IntPtr /* ggml_abort_callback */ abortCallback, IntPtr abortCallbackData); /// /// Get the n_seq_max for this context From d6b10a020175662c0987ca0b99e9ae43b90a8de1 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 1 Nov 2025 06:05:21 +0000 Subject: [PATCH 18/35] Bump BenchmarkDotNet and BenchmarkDotNet.Diagnostics.Windows Bumps BenchmarkDotNet from 0.15.4 to 0.15.5 Bumps BenchmarkDotNet.Diagnostics.Windows from 0.15.4 to 0.15.5 --- updated-dependencies: - dependency-name: BenchmarkDotNet dependency-version: 0.15.5 dependency-type: direct:production update-type: version-update:semver-patch - dependency-name: BenchmarkDotNet.Diagnostics.Windows dependency-version: 0.15.5 dependency-type: direct:production update-type: version-update:semver-patch ... 
Signed-off-by: dependabot[bot] --- LLama.Benchmark/LLama.Benchmark.csproj | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/LLama.Benchmark/LLama.Benchmark.csproj b/LLama.Benchmark/LLama.Benchmark.csproj index ec9c3a33e..7396d39d0 100644 --- a/LLama.Benchmark/LLama.Benchmark.csproj +++ b/LLama.Benchmark/LLama.Benchmark.csproj @@ -10,8 +10,8 @@ - - + + From ae845fd8da68a91e50a3c77858398b8b92d0e716 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 1 Nov 2025 06:05:55 +0000 Subject: [PATCH 19/35] Bump Microsoft.AspNetCore.Mvc.Razor.RuntimeCompilation from 8.0.20 to 8.0.21 --- updated-dependencies: - dependency-name: Microsoft.AspNetCore.Mvc.Razor.RuntimeCompilation dependency-version: 8.0.21 dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- LLama.Web/LLama.Web.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/LLama.Web/LLama.Web.csproj b/LLama.Web/LLama.Web.csproj index 52d68ad41..89840361a 100644 --- a/LLama.Web/LLama.Web.csproj +++ b/LLama.Web/LLama.Web.csproj @@ -15,7 +15,7 @@ - + From 17bd6b8c8d20d0e93ac21d8b112b73345038e3fb Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 1 Nov 2025 06:06:25 +0000 Subject: [PATCH 20/35] Bump Microsoft.AspNetCore.OpenApi from 8.0.20 to 8.0.21 --- updated-dependencies: - dependency-name: Microsoft.AspNetCore.OpenApi dependency-version: 8.0.21 dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- LLama.WebAPI/LLama.WebAPI.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/LLama.WebAPI/LLama.WebAPI.csproj b/LLama.WebAPI/LLama.WebAPI.csproj index 6b2d28cc4..054dafdfb 100644 --- a/LLama.WebAPI/LLama.WebAPI.csproj +++ b/LLama.WebAPI/LLama.WebAPI.csproj @@ -9,7 +9,7 @@ - + From 7a2dfd97c65170ce2420a92e1d7896d9aa67f4b4 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 1 Nov 2025 06:06:45 +0000 Subject: [PATCH 21/35] Bump Microsoft.Bcl.AsyncInterfaces from 9.0.9 to 9.0.10 --- updated-dependencies: - dependency-name: Microsoft.Bcl.AsyncInterfaces dependency-version: 9.0.10 dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- LLama/LLamaSharp.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/LLama/LLamaSharp.csproj b/LLama/LLamaSharp.csproj index e91436b89..6b4dcce12 100644 --- a/LLama/LLamaSharp.csproj +++ b/LLama/LLamaSharp.csproj @@ -50,7 +50,7 @@ - + From 78f2848be1e5ccb64100136269967fd231172307 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 1 Nov 2025 11:44:23 +0000 Subject: [PATCH 22/35] Bump Microsoft.Extensions.AI.Abstractions from 9.9.1 to 9.10.1 --- updated-dependencies: - dependency-name: Microsoft.Extensions.AI.Abstractions dependency-version: 9.10.1 dependency-type: direct:production update-type: version-update:semver-minor ... 
Signed-off-by: dependabot[bot] --- LLama/LLamaSharp.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/LLama/LLamaSharp.csproj b/LLama/LLamaSharp.csproj index 6b4dcce12..1416eb503 100644 --- a/LLama/LLamaSharp.csproj +++ b/LLama/LLamaSharp.csproj @@ -51,7 +51,7 @@ - + From 43558099b8f5f10b795a5578951914d28727c5db Mon Sep 17 00:00:00 2001 From: Kris Date: Sat, 1 Nov 2025 22:48:29 +0530 Subject: [PATCH 23/35] Enable FlashAttention and remove SeqMax param FlashAttention is now enabled by default in model parameter initialization for embedding and text generation. The unused SeqMax parameter has been removed from unit tests to simplify configuration. Minor formatting improvements were made in IContextParamsExtensions and NativeApi for consistency. --- LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs | 2 ++ LLama.KernelMemory/LlamaSharpTextGenerator.cs | 2 ++ LLama.Unittest/LLamaContextTests.cs | 1 - LLama.Unittest/LLamaRerankerTests.cs | 1 - LLama.Unittest/SamplingTests.cs | 1 - LLama/Extensions/IContextParamsExtensions.cs | 2 +- LLama/Native/NativeApi.cs | 9 +++------ 7 files changed, 8 insertions(+), 10 deletions(-) diff --git a/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs b/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs index b5c110194..0635015df 100644 --- a/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs +++ b/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs @@ -40,6 +40,7 @@ public LLamaSharpTextEmbeddingGenerator(LLamaSharpConfig config) SplitMode = config?.SplitMode ?? LLama.Native.GPUSplitMode.Layer, BatchSize = 512, UBatchSize = 512, + FlashAttention = true, UseMemorymap = true, PoolingType = LLamaPoolingType.Mean, }; @@ -67,6 +68,7 @@ public LLamaSharpTextEmbeddingGenerator(LLamaSharpConfig config, LLamaWeights we SplitMode = config?.SplitMode ?? LLama.Native.GPUSplitMode.Layer, BatchSize = 512, UBatchSize = 512, + FlashAttention = true, UseMemorymap = true, PoolingType = LLamaPoolingType.Mean, }; diff --git a/LLama.KernelMemory/LlamaSharpTextGenerator.cs b/LLama.KernelMemory/LlamaSharpTextGenerator.cs index 166d4ad38..5c965b266 100644 --- a/LLama.KernelMemory/LlamaSharpTextGenerator.cs +++ b/LLama.KernelMemory/LlamaSharpTextGenerator.cs @@ -38,6 +38,7 @@ public LlamaSharpTextGenerator(LLamaSharpConfig config) SplitMode = config?.SplitMode ?? LLama.Native.GPUSplitMode.Layer, BatchSize = 512, UBatchSize = 512, + FlashAttention = true, UseMemorymap = true }; _weights = LLamaWeights.LoadFromFile(@params); @@ -65,6 +66,7 @@ public LlamaSharpTextGenerator(LLamaWeights weights, LLamaSharpConfig config, St SplitMode = config?.SplitMode ?? LLama.Native.GPUSplitMode.Layer, BatchSize = 512, UBatchSize = 512, + FlashAttention = true, UseMemorymap = true }; _executor = executor ?? 
new StatelessExecutor(_weights, @params); diff --git a/LLama.Unittest/LLamaContextTests.cs b/LLama.Unittest/LLamaContextTests.cs index b04ee5382..7d85589d6 100644 --- a/LLama.Unittest/LLamaContextTests.cs +++ b/LLama.Unittest/LLamaContextTests.cs @@ -16,7 +16,6 @@ public LLamaContextTests() ContextSize = 512, BatchSize = 8, UBatchSize = 8, - SeqMax = 1, VocabOnly = false, GpuLayerCount = Constants.CIGpuLayerCount, }; diff --git a/LLama.Unittest/LLamaRerankerTests.cs b/LLama.Unittest/LLamaRerankerTests.cs index 534623a41..9ba0a31c0 100644 --- a/LLama.Unittest/LLamaRerankerTests.cs +++ b/LLama.Unittest/LLamaRerankerTests.cs @@ -18,7 +18,6 @@ public LLamaRerankerTests(ITestOutputHelper testOutputHelper) var @params = new ModelParams(Constants.RerankingModelPath) { ContextSize = 0, - SeqMax = 1, PoolingType = LLamaPoolingType.Rank, GpuLayerCount = Constants.CIGpuLayerCount, }; diff --git a/LLama.Unittest/SamplingTests.cs b/LLama.Unittest/SamplingTests.cs index 5dcb7b494..297641df3 100644 --- a/LLama.Unittest/SamplingTests.cs +++ b/LLama.Unittest/SamplingTests.cs @@ -25,7 +25,6 @@ public SamplingTests(ITestOutputHelper testOutputHelper) _params = new ModelParams(Constants.GenerativeModelPath2) { ContextSize = 200, BatchSize = 200, - SeqMax = 4, GpuLayerCount = Constants.CIGpuLayerCount, }; _model = LLamaWeights.LoadFromFile(_params); diff --git a/LLama/Extensions/IContextParamsExtensions.cs b/LLama/Extensions/IContextParamsExtensions.cs index 816118524..8bac9f3f6 100644 --- a/LLama/Extensions/IContextParamsExtensions.cs +++ b/LLama/Extensions/IContextParamsExtensions.cs @@ -37,7 +37,7 @@ public static void ToLlamaContextParams(this IContextParams @params, out LLamaCo result.yarn_beta_slow = @params.YarnBetaSlow ?? 1f; result.yarn_orig_ctx = @params.YarnOriginalContext ?? 0; result.rope_scaling_type = @params.YarnScalingType ?? RopeScalingType.Unspecified; - + result.defrag_threshold = @params.DefragThreshold ?? -1; result.cb_eval = IntPtr.Zero; diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs index 30aadc25e..4aefc8810 100644 --- a/LLama/Native/NativeApi.cs +++ b/LLama/Native/NativeApi.cs @@ -175,15 +175,12 @@ public static void llama_empty_call() /// A buffer to hold the output formatted prompt. The recommended alloc size is 2 * (total number of characters of all messages) /// The size of the allocated buffer /// The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template. 
-    public static unsafe int llama_chat_apply_template(byte* tmpl, LLamaChatMessage* chat, nuint n_msg,
-        [MarshalAs(UnmanagedType.U1)] bool add_ass, byte* buf, int length)
+    public static unsafe int llama_chat_apply_template(byte* tmpl, LLamaChatMessage* chat, nuint n_msg, [MarshalAs(UnmanagedType.U1)] bool add_ass, byte* buf, int length)
     {
         return internal_llama_chat_apply_template(tmpl, chat, n_msg, add_ass, buf, length);
 
-        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl,
-            EntryPoint = "llama_chat_apply_template")]
-        static extern int internal_llama_chat_apply_template(byte* tmpl, LLamaChatMessage* chat, nuint n_msg,
-            [MarshalAs(UnmanagedType.U1)] bool add_ass, byte* buf, int length);
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_chat_apply_template")]
+        static extern int internal_llama_chat_apply_template(byte* tmpl, LLamaChatMessage* chat, nuint n_msg, [MarshalAs(UnmanagedType.U1)] bool add_ass, byte* buf, int length);
     }
 
     /// 

From ba5ca337069205f5980dd5693934a2cca2f4d388 Mon Sep 17 00:00:00 2001
From: Kris 
Date: Mon, 3 Nov 2025 02:55:07 +0530
Subject: [PATCH 24/35] Update IContextParamsExtensions.cs

---
 LLama/Extensions/IContextParamsExtensions.cs | 1 +
 1 file changed, 1 insertion(+)

diff --git a/LLama/Extensions/IContextParamsExtensions.cs b/LLama/Extensions/IContextParamsExtensions.cs
index 8bac9f3f6..bfef9b9c1 100644
--- a/LLama/Extensions/IContextParamsExtensions.cs
+++ b/LLama/Extensions/IContextParamsExtensions.cs
@@ -58,6 +58,7 @@ public static void ToLlamaContextParams(this IContextParams @params, out LLamaCo
             null => LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_AUTO
         };
         result.kv_unified = true;
+        result.n_seq_max = (uint)Math.Min(Math.Max(10, result.n_ctx / 8), 256);
 
         result.n_threads = Threads(@params.Threads);
         result.n_threads_batch = Threads(@params.BatchThreads);

From 55a7aeb16e2be38d42439c6e052617bb595ad1e6 Mon Sep 17 00:00:00 2001
From: Kris 
Date: Wed, 12 Nov 2025 01:07:02 +0530
Subject: [PATCH 25/35] fix: check-properties tests

---
 LLama.Unittest/LLamaContextTests.cs                 | 2 +-
 LLama.Unittest/LLamaContextWithCustomLoggerTests.cs | 2 +-
 LLama/Native/SafeLLamaContextHandle.cs              | 5 +++++
 3 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/LLama.Unittest/LLamaContextTests.cs b/LLama.Unittest/LLamaContextTests.cs
index 7d85589d6..6379bcc17 100644
--- a/LLama.Unittest/LLamaContextTests.cs
+++ b/LLama.Unittest/LLamaContextTests.cs
@@ -32,7 +32,7 @@ public void Dispose()
     [Fact]
     public void CheckProperties()
     {
-        Assert.Equal(512u, _context.ContextSize);
+        Assert.Equal(_context.NativeHandle.MaxSeq * 256, _context.ContextSize);
         Assert.Equal(960, _context.EmbeddingSize);
         Assert.Equal(49152, _context.Vocab.Count);
     }
diff --git a/LLama.Unittest/LLamaContextWithCustomLoggerTests.cs b/LLama.Unittest/LLamaContextWithCustomLoggerTests.cs
index 871b6b8cd..b689c4ceb 100644
--- a/LLama.Unittest/LLamaContextWithCustomLoggerTests.cs
+++ b/LLama.Unittest/LLamaContextWithCustomLoggerTests.cs
@@ -55,7 +55,7 @@ public void Dispose()
     [Fact]
     public void CheckProperties()
     {
-        Assert.Equal(512u, _context.ContextSize);
+        Assert.Equal(_context.NativeHandle.MaxSeq * 256, _context.ContextSize);
         Assert.Equal(960, _context.EmbeddingSize);
         Assert.Equal(49152, _context.Vocab.Count);
     }
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index f48e818b7..05b52dcae 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -33,6 +33,11 @@ public sealed class
SafeLLamaContextHandle /// Get the physical maximum batch size for this context /// public uint UBatchSize => llama_n_ubatch(this); + + /// + /// Get the number of maximum sequences allowed + /// + public uint MaxSeq => NativeApi.llama_n_seq_max(this); /// /// Get or set the number of threads used for generation of a single token. From 070ff334514827169b5e9ae7beb712c6d1eab3b0 Mon Sep 17 00:00:00 2001 From: SignalRT Date: Sat, 27 Sep 2025 07:52:40 +0200 Subject: [PATCH 26/35] Mtmd Implementation base # Conflicts: # LLama/Native/NativeApi.cs --- LLama.Examples/ExampleRunner.cs | 4 +- .../Examples/BatchedExecutorLLava.cs | 91 ----- .../Examples/BatchedExecutorMtmd.cs | 126 +++++++ ...ecute.cs => MtmdInteractiveModeExecute.cs} | 83 +++-- LLama.Examples/LLama.Examples.csproj | 2 +- LLama.Unittest/Constants.cs | 6 +- LLama.Unittest/LLama.Unittest.csproj | 16 +- LLama.Unittest/MtmdExecutorTests.cs | 81 ++++ LLama.Unittest/MtmdWeightsTests.cs | 140 +++++++ .../Native/SafeLlamaModelHandleTests.cs | 32 -- .../SafeLlamaModelHandleVocabularyTests.cs | 42 --- LLama/Abstractions/ILLamaExecutor.cs | 7 +- LLama/Batched/BatchedExecutor.cs | 65 ++++ LLama/Batched/Conversation.cs | 242 ++++++++++-- LLama/Batched/ConversationExtensions.cs | 14 +- LLama/LLamaExecutorBase.cs | 135 ++++--- LLama/LLamaInstructExecutor.cs | 213 ++++++++++- LLama/LLamaInteractExecutor.cs | 301 +++++++++++---- LLama/LLamaSharp.csproj | 4 +- LLama/LLamaStatelessExecutor.cs | 6 +- LLama/LLavaWeights.cs | 137 ------- LLama/Native/LLavaImageEmbed.cs | 19 - LLama/Native/Load/NativeLibraryConfig.cs | 32 +- LLama/Native/Load/NativeLibraryUtils.cs | 2 +- LLama/Native/MtmdContextParams.cs | 148 ++++++++ LLama/Native/MtmdImageEmbed.cs | 20 + LLama/Native/NativeApi.LLava.cs | 63 ---- LLama/Native/NativeApi.Load.cs | 22 +- LLama/Native/NativeApi.Mtmd.cs | 312 ++++++++++++++++ LLama/Native/NativeApi.cs | 182 +++++---- LLama/Native/SafeLlavaImageEmbedHandle.cs | 162 -------- LLama/Native/SafeLlavaModelHandle.cs | 137 ------- LLama/Native/SafeMtmdEmbed.cs | 247 +++++++++++++ LLama/Native/SafeMtmdInputChunk.cs | 150 ++++++++ LLama/Native/SafeMtmdInputChunks.cs | 103 ++++++ LLama/Native/SafeMtmdModelHandle.cs | 349 ++++++++++++++++++ LLama/Properties/InternalsVisibleTo.cs | 3 + LLama/SafeMtmdWeights.cs | 80 ++++ docs/Examples/LLavaInteractiveModeExecute.md | 129 ------- docs/Examples/MtmdInteractiveModeExecute.md | 41 ++ mkdocs.yml | 4 +- 41 files changed, 2837 insertions(+), 1115 deletions(-) delete mode 100644 LLama.Examples/Examples/BatchedExecutorLLava.cs create mode 100644 LLama.Examples/Examples/BatchedExecutorMtmd.cs rename LLama.Examples/Examples/{LlavaInteractiveModeExecute.cs => MtmdInteractiveModeExecute.cs} (59%) create mode 100644 LLama.Unittest/MtmdExecutorTests.cs create mode 100644 LLama.Unittest/MtmdWeightsTests.cs delete mode 100644 LLama.Unittest/Native/SafeLlamaModelHandleTests.cs delete mode 100644 LLama.Unittest/Native/SafeLlamaModelHandleVocabularyTests.cs delete mode 100644 LLama/LLavaWeights.cs delete mode 100644 LLama/Native/LLavaImageEmbed.cs create mode 100644 LLama/Native/MtmdContextParams.cs create mode 100644 LLama/Native/MtmdImageEmbed.cs delete mode 100644 LLama/Native/NativeApi.LLava.cs create mode 100644 LLama/Native/NativeApi.Mtmd.cs delete mode 100644 LLama/Native/SafeLlavaImageEmbedHandle.cs delete mode 100644 LLama/Native/SafeLlavaModelHandle.cs create mode 100644 LLama/Native/SafeMtmdEmbed.cs create mode 100644 LLama/Native/SafeMtmdInputChunk.cs create mode 100644 LLama/Native/SafeMtmdInputChunks.cs create 
mode 100644 LLama/Native/SafeMtmdModelHandle.cs create mode 100644 LLama/Properties/InternalsVisibleTo.cs create mode 100644 LLama/SafeMtmdWeights.cs delete mode 100644 docs/Examples/LLavaInteractiveModeExecute.md create mode 100644 docs/Examples/MtmdInteractiveModeExecute.md diff --git a/LLama.Examples/ExampleRunner.cs b/LLama.Examples/ExampleRunner.cs index c073cd4cd..23f07c6a1 100644 --- a/LLama.Examples/ExampleRunner.cs +++ b/LLama.Examples/ExampleRunner.cs @@ -15,7 +15,7 @@ public class ExampleRunner { "Chat Session: Automatic conversation", TalkToYourself.Run }, { "Chat Session: Chinese characters", ChatChineseGB2312.Run }, { "Executor: Interactive mode chat", InteractiveModeExecute.Run }, - { "Executor: Llava Interactive mode chat", LlavaInteractiveModeExecute.Run }, + { "Executor: Mtmd Interactive mode chat", MtmdInteractiveModeExecute.Run }, { "Executor: Instruct mode chat", InstructModeExecute.Run }, { "Executor: Stateless mode chat", StatelessModeExecute.Run }, { "Save and Load: chat session", SaveAndLoadSession.Run }, @@ -33,7 +33,7 @@ public class ExampleRunner { "Batched Executor: Save/Load", BatchedExecutorSaveAndLoad.Run }, { "Batched Executor: Fork", BatchedExecutorFork.Run }, { "Batched Executor: Rewind", BatchedExecutorRewind.Run }, - { "Batched Executor: LLava", BatchedExecutorLLava.Run }, + { "Batched Executor: Mtmd", BatchedExecutorMtmd.Run }, { "Batched Executor: BoolQ Benchmark", BatchedExecutorBoolQ.Run }, { "Batched Executor: Beam Search", BatchedExecutorBeamSearch.Run }, { "Custom Sampling Pipeline", CustomSampler.Run }, diff --git a/LLama.Examples/Examples/BatchedExecutorLLava.cs b/LLama.Examples/Examples/BatchedExecutorLLava.cs deleted file mode 100644 index a131e994e..000000000 --- a/LLama.Examples/Examples/BatchedExecutorLLava.cs +++ /dev/null @@ -1,91 +0,0 @@ -using System.Text; -using LLama.Batched; -using LLama.Common; -using LLama.Native; -using LLama.Sampling; -using Spectre.Console; - -namespace LLama.Examples.Examples; - -/// -/// Demonstrates using LLava (image embeddings) with the batched executor. 
-/// -public class BatchedExecutorLLava -{ - /// - /// How many tokens of response to generate - /// - public const int TokenCount = 64; - - public static async Task Run() - { - // Load model weights - var parameters = new ModelParams(UserSettings.GetModelPath()); - using var model = await LLamaWeights.LoadFromFileAsync(parameters); - using var llava = await LLavaWeights.LoadFromFileAsync(UserSettings.GetMMProjPath()); - - // Decide on the prompt - var prompt = model.Tokenize(AnsiConsole.Ask("Prompt (or ENTER for default):", "\nUSER: Provide a full description of the image.\nASSISTANT: "), true, false, Encoding.UTF8); - - // Get image and show it - var image = UserSettings.GetImagePath(); - AnsiConsole.Write(new CanvasImage(image)); - - // Create an executor with one conversation - using var executor = new BatchedExecutor(model, parameters); - using var conversation = executor.Create(); - - // Embed the image - SafeLlavaImageEmbedHandle embedding = null!; - await AnsiConsole - .Status() - .StartAsync("[yellow]Embedding image with CLIP[/]", async _ => - { - // ReSharper disable once AccessToDisposedClosure - embedding = llava.CreateImageEmbeddings(await File.ReadAllBytesAsync(image)); - }); - - // Pass in the image and run inference until the entire image has been processed - await AnsiConsole - .Status() - .StartAsync("[yellow]Processing image embedding with language model[/]", async _ => - { - conversation.Prompt(embedding); - while (executor.BatchedTokenCount > 0) - await executor.Infer(); - }); - - // Prompt with the text prompt - conversation.Prompt(prompt); - - // Run inference loop - var decoder = new StreamingTokenDecoder(executor.Context); - var sampler = new DefaultSamplingPipeline(); - await AnsiConsole - .Progress() - .StartAsync(async ctx => - { - var task = ctx.AddTask("Generating Response"); - task.MaxValue = TokenCount; - - // Run a normal inference loop - for (var i = 0; i < TokenCount; i++) - { - task.Increment(1); - - await executor.Infer(); - - var token = sampler.Sample(executor.Context.NativeHandle, conversation.GetSampleIndex()); - if (token.IsEndOfGeneration(executor.Context.Vocab)) - break; - - decoder.Add(token); - conversation.Prompt(token); - } - }); - - // Print final result - var str = decoder.Read(); - AnsiConsole.MarkupInterpolated($"[green]{str}[/]"); - } -} \ No newline at end of file diff --git a/LLama.Examples/Examples/BatchedExecutorMtmd.cs b/LLama.Examples/Examples/BatchedExecutorMtmd.cs new file mode 100644 index 000000000..b62f8b120 --- /dev/null +++ b/LLama.Examples/Examples/BatchedExecutorMtmd.cs @@ -0,0 +1,126 @@ +using System; +using System.Collections.Generic; +using System.IO; +using LLama.Batched; +using LLama.Common; +using LLama.Exceptions; +using LLama.Native; +using LLama.Sampling; +using Spectre.Console; + +namespace LLama.Examples.Examples; + +/// +/// Demonstrates how to evaluate an image with MTMD helpers and continue generation by +/// manually scheduling batches, similar to what the batched executor does internally. +/// +public class BatchedExecutorMtmd +{ + /// + /// Number of completion tokens to generate after sending the image prompt. + /// + public const int TokenCount = 10000; + + public static async Task Run() + { + // Load the base LLM and its clip/mtmd sidecar weights so the executor has everything it needs. 
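+        // Only the model path is supplied here, so the remaining ModelParams
+        // fields keep their defaults; the multimodal side is configured through
+        // the MtmdContextParams below.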
+ var parameters = new ModelParams(UserSettings.GetModelPath()); + using var model = await LLamaWeights.LoadFromFileAsync(parameters); + var mtmdParams = MtmdContextParams.Default(); // reuse llama.cpp defaults for helper settings + mtmdParams.UseGpu = false; + var marker = mtmdParams.MediaMarker ?? NativeApi.MtmdDefaultMarker() ?? ""; + + using var mtmd = await SafeMtmdWeights.LoadFromFileAsync(UserSettings.GetMMProjPath(), model, mtmdParams); // multimodal helper weights + + using var executor = new BatchedExecutor(model, parameters, mtmd); // drives batched token + chunk evaluation + + // Prepend the media marker so the helper knows where to inject the encoded image tokens. + var defaultPrompt = "\nUSER: Provide a full description of the image.\nASSISTANT: "; + var promptSuffix = AnsiConsole.Ask("Prompt (or ENTER for default):", defaultPrompt); + var promptText = string.Concat(marker, promptSuffix); + + var imagePath = UserSettings.GetImagePath(); + AnsiConsole.Write(new CanvasImage(imagePath)); + + var vocab = executor.Context.NativeHandle.ModelHandle.Vocab; + + // Simple low-temperature sampler keeps the demo deterministic-ish. + var sampler = new DefaultSamplingPipeline + { + Temperature = 0.1f + }; + + // Stream decoded text to the console as soon as tokens arrive. + var decoder = new StreamingTokenDecoder(executor.Context) + { + DecodeSpecialTokens = false + }; + + try + { + // Each conversation tracks its own KV cache sequence IDs. + var conversation = executor.Create(); + // enqueue the image so MtmdHelper sees it + conversation.QueueMedia(imagePath); + // schedule multimodal prompt + conversation.Prompt(promptText, addBos: true, special: true); + + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("Prompt queued with multimodal chunks. Generating response...\n"); + Console.ResetColor(); + + var remaining = TokenCount; + + // Run one decode/sampling/prompt cycle – mirrors the batched executor inner loop. 
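+            // Each call drains one queued batch via Infer(), samples a single token,
+            // streams it to the console and prompts it back into the conversation.
+            // It reports false when decoding fails, an end-of-generation token is
+            // sampled, or the token budget runs out.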
+            async Task<bool> ProcessNextAsync()
+            {
+                var decodeResult = await executor.Infer();
+                if (decodeResult == DecodeResult.NoKvSlot) // KV cache exhausted – surface to the user
+                {
+                    Console.ForegroundColor = ConsoleColor.Red;
+                    Console.WriteLine("Insufficient KV cache space for multimodal evaluation.");
+                    Console.ResetColor();
+                    return false;
+                }
+
+                if (decodeResult != DecodeResult.Ok)
+                    throw new RuntimeError($"Failed to evaluate batch: {decodeResult}.");
+
+                if (!conversation.RequiresSampling) // another conversation may still be queued
+                    return true;
+
+                var token = conversation.Sample(sampler); // pull logits (or -1 for mtmd chunk) and sample
+                if (token.IsEndOfGeneration(vocab))
+                    return false;
+
+                decoder.Add(token);
+                var delta = decoder.Read();
+                if (!string.IsNullOrEmpty(delta))
+                    Console.Write(delta);
+
+                sampler.Accept(token); // keep sampler state in sync
+                conversation.Prompt(token); // feed the accepted token back into the batch
+                remaining--;
+                return remaining > 0;
+            }
+
+            while (remaining > 0 && await ProcessNextAsync()) // continue until EOS or budget is reached
+            {
+            }
+
+            Console.WriteLine();
+        }
+        catch (IOException ex)
+        {
+            Console.ForegroundColor = ConsoleColor.Red;
+            Console.WriteLine($"Could not load media '{imagePath}': {ex.Message}");
+            Console.ResetColor();
+        }
+        catch (RuntimeError ex)
+        {
+            Console.ForegroundColor = ConsoleColor.Red;
+            Console.WriteLine($"MTMD processing failed: {ex.Message}");
+            Console.ResetColor();
+        }
+    }
+}
diff --git a/LLama.Examples/Examples/LlavaInteractiveModeExecute.cs b/LLama.Examples/Examples/MtmdInteractiveModeExecute.cs
similarity index 59%
rename from LLama.Examples/Examples/LlavaInteractiveModeExecute.cs
rename to LLama.Examples/Examples/MtmdInteractiveModeExecute.cs
index 8cbf58dcd..ca0de3b77 100644
--- a/LLama.Examples/Examples/LlavaInteractiveModeExecute.cs
+++ b/LLama.Examples/Examples/MtmdInteractiveModeExecute.cs
@@ -1,3 +1,5 @@
+using System.Collections.Generic;
+using System.IO;
 using System.Text.RegularExpressions;
 using LLama.Common;
 using Spectre.Console;
@@ -6,27 +8,32 @@
 
 namespace LLama.Examples.Examples
 {
-    // This example shows how to chat with LLaVA model with both image and text as input.
+    // This example shows how to chat with an Mtmd model with both image and text as input.
     // It uses the interactive executor to inference.
-    public class LlavaInteractiveModeExecute
+    public class MtmdInteractiveModeExecute
     {
         public static async Task Run()
         {
             string multiModalProj = UserSettings.GetMMProjPath();
             string modelPath = UserSettings.GetModelPath();
             string modelImage = UserSettings.GetImagePath();
-            const int maxTokens = 1024;
+            const int maxTokens = 2048;
 
             var prompt = $"{{{modelImage}}}\nUSER:\nProvide a full description of the image.\nASSISTANT:\n";
 
             var parameters = new ModelParams(modelPath);
+            var mtmdParameters = MtmdContextParams.Default();
+            mtmdParameters.UseGpu = false;
+
             using var model = await LLamaWeights.LoadFromFileAsync(parameters);
             using var context = model.CreateContext(parameters);
-
-            // Llava Init
-            using var clipModel = await LLavaWeights.LoadFromFileAsync(multiModalProj);
-
+
+            // Mtmd Init
+            using var clipModel = await SafeMtmdWeights.LoadFromFileAsync(multiModalProj, model, mtmdParameters);
+
+            var mediaMarker = mtmdParameters.MediaMarker ??
""; + var ex = new InteractiveExecutor(context, clipModel); Console.ForegroundColor = ConsoleColor.Yellow; @@ -40,7 +47,7 @@ public static async Task Run() Temperature = 0.1f }, - AntiPrompts = new List { "\nUSER:" }, + AntiPrompts = new List { "\nASSISTANT:" }, MaxTokens = maxTokens }; @@ -48,30 +55,53 @@ public static async Task Run() do { - // Evaluate if we have images + // Evaluate if we have media // - var imageMatches = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value); - var imageCount = imageMatches.Count(); - var hasImages = imageCount > 0; + var mediaMatches = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value); + var mediaCount = mediaMatches.Count(); + var hasMedia = mediaCount > 0; - if (hasImages) + if (hasMedia) { - var imagePathsWithCurlyBraces = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value); - var imagePaths = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Groups[1].Value).ToList(); + var mediaPathsWithCurlyBraces = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value); + var mediaPaths = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Groups[1].Value).ToList(); - List imageBytes; + var embeds = new List(); + var imageList = new List(); + var imageExtensions = new HashSet(StringComparer.OrdinalIgnoreCase) + { + ".png", + ".jpg", + ".jpeg", + ".bmp", + ".gif", + ".webp" + }; + try { - imageBytes = imagePaths.Select(File.ReadAllBytes).ToList(); + foreach (var mediaPath in mediaPaths) + { + var extension = Path.GetExtension(mediaPath); + if (!string.IsNullOrEmpty(extension) && imageExtensions.Contains(extension)) + { + // Keep the raw image data so the caller can reuse or inspect the images later. + imageList.Add(File.ReadAllBytes(mediaPath)); + } + + var embed = clipModel.LoadMedia(mediaPath); + embeds.Add(embed); + } } catch (IOException exception) { Console.ForegroundColor = ConsoleColor.Red; Console.Write( - $"Could not load your {(imageCount == 1 ? "image" : "images")}:"); + $"Could not load your {(mediaCount == 1 ? "media" : "medias")}:"); Console.Write($"{exception.Message}"); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("Please try again."); + clipModel.ClearMedia(); break; } @@ -81,19 +111,17 @@ public static async Task Run() // https://github.com/ggerganov/llama.cpp/discussions/3620 ex.Context.NativeHandle.MemorySequenceRemove( LLamaSeqId.Zero, -1, -1 ); - int index = 0; - foreach (var path in imagePathsWithCurlyBraces) + // Replace placeholders with media markers (one marker per image) + foreach (var path in mediaPathsWithCurlyBraces) { - // First image replace to tag " : ""); + prompt = prompt.Replace(path, mediaMarker, StringComparison.Ordinal); } - Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine($"Here are the images, that are sent to the chat model in addition to your message."); Console.WriteLine(); - foreach (var consoleImage in imageBytes?.Select(bytes => new CanvasImage(bytes)) ?? 
Array.Empty()) + foreach (var consoleImage in imageList.Select(image => new CanvasImage(image.ToArray()))) { consoleImage.MaxWidth = 50; AnsiConsole.Write(consoleImage); @@ -108,10 +136,9 @@ public static async Task Run() // Initialize Images in executor // - foreach (var image in imagePaths) - { - ex.Images.Add(await File.ReadAllBytesAsync(image)); - } + ex.Embeds.Clear(); + foreach (var embed in embeds) + ex.Embeds.Add(embed); } Console.ForegroundColor = Color.White; diff --git a/LLama.Examples/LLama.Examples.csproj b/LLama.Examples/LLama.Examples.csproj index 8d70d5637..6d69bc942 100644 --- a/LLama.Examples/LLama.Examples.csproj +++ b/LLama.Examples/LLama.Examples.csproj @@ -9,7 +9,7 @@ true true - 12 + 13 1701;1702;8604;SKEXP0001;SKEXP0050;SKEXP0052;SKEXP0003 diff --git a/LLama.Unittest/Constants.cs b/LLama.Unittest/Constants.cs index d501b189b..f705f1609 100644 --- a/LLama.Unittest/Constants.cs +++ b/LLama.Unittest/Constants.cs @@ -9,9 +9,9 @@ internal static class Constants public static readonly string EmbeddingModelPath = "Models/all-MiniLM-L12-v2.Q8_0.gguf"; public static readonly string RerankingModelPath = "Models/jina-reranker-v1-tiny-en-FP16.gguf"; - public static readonly string LLavaModelPath = "Models/llava-v1.6-mistral-7b.Q3_K_XS.gguf"; - public static readonly string LLavaMmpPath = "Models/mmproj-model-f16.gguf"; - public static readonly string LLavaImage = "Models/extreme-ironing-taxi-610x427.jpg"; + public static readonly string MtmdModelPath = "Models/gemma-3-4b-it-Q4_K_M.gguf"; + public static readonly string MtmdMmpPath = "Models/gemma-mmproj-model-f16.gguf"; + public static readonly string MtmdImage = "Models/extreme-ironing-taxi-610x427.jpg"; /// /// Calculate GpuLayer Count to use in UnitTest diff --git a/LLama.Unittest/LLama.Unittest.csproj b/LLama.Unittest/LLama.Unittest.csproj index 8f9f075d8..ca3ea8854 100644 --- a/LLama.Unittest/LLama.Unittest.csproj +++ b/LLama.Unittest/LLama.Unittest.csproj @@ -52,16 +52,16 @@ jina-reranker-v1-tiny-en-FP16.gguf - - https://huggingface.co/cjpais/llava-1.6-mistral-7b-gguf/resolve/main/llava-v1.6-mistral-7b.Q3_K_XS.gguf + + https://huggingface.co/ggml-org/gemma-3-4b-it-GGUF/resolve/main/gemma-3-4b-it-Q4_K_M.gguf Models - llava-v1.6-mistral-7b.Q3_K_XS.gguf + gemma-3-4b-it-Q4_K_M.gguf - - https://huggingface.co/cjpais/llava-1.6-mistral-7b-gguf/resolve/main/mmproj-model-f16.gguf + + https://huggingface.co/ggml-org/gemma-3-4b-it-GGUF/resolve/main/mmproj-model-f16.gguf Models - mmproj-model-f16.gguf + gemma-mmproj-model-f16.gguf @@ -142,10 +142,10 @@ PreserveNewest - + PreserveNewest - + PreserveNewest diff --git a/LLama.Unittest/MtmdExecutorTests.cs b/LLama.Unittest/MtmdExecutorTests.cs new file mode 100644 index 000000000..75a96b261 --- /dev/null +++ b/LLama.Unittest/MtmdExecutorTests.cs @@ -0,0 +1,81 @@ +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using LLama.Common; +using LLama.Native; +using Microsoft.Extensions.Logging.Abstractions; +using Xunit; + +namespace LLama.Unittest; + +[Trait("Category", "NoCI")] +public class MtmdExecutorTests : IDisposable +{ + private readonly LLamaWeights _weights; + private readonly MtmdContextParams _mtmdParams; + private readonly SafeMtmdWeights _mtmd; + private readonly ModelParams _modelParams; + + public MtmdExecutorTests() + { + _modelParams = new ModelParams(Constants.MtmdModelPath) + { + ContextSize = 1024 * 8, + GpuLayerCount = Constants.CIGpuLayerCount, + }; + + _weights = LLamaWeights.LoadFromFile(_modelParams); + + _mtmdParams = 
MtmdContextParams.Default();
+        _mtmdParams.NThreads = Math.Max(1, Constants.CIGpuLayerCount);
+        _mtmdParams.UseGpu = false;
+
+        _mtmd = SafeMtmdWeights.LoadFromFile(Constants.MtmdMmpPath, _weights, _mtmdParams);
+    }
+
+    public void Dispose()
+    {
+        _mtmd.Dispose();
+        _weights.Dispose();
+    }
+
+    [Fact]
+    public async Task InteractiveExecutor_EvaluateChunks_DoesNotRetokenize()
+    {
+        using var context = _weights.CreateContext(_modelParams, NullLogger.Instance);
+        var executor = new InteractiveExecutor(context, _mtmd, NullLogger.Instance);
+        var marker = _mtmdParams.MediaMarker ?? NativeApi.MtmdDefaultMarker() ?? "";
+        var prompt = $"{marker}\nDescribe the image succinctly.";
+
+        executor.Embeds.Add(_mtmd.LoadMedia(Constants.MtmdImage));
+
+        await foreach (var _ in executor.InferAsync(prompt, new InferenceParams { MaxTokens = 0 }))
+        {
+            Assert.True(false, "Prefill should not emit generated text");
+        }
+
+        var diagnostics = executor.GetDiagnostics();
+        Assert.Equal(diagnostics.EmbedCount, diagnostics.ConsumedCount);
+        Assert.Equal(diagnostics.ConsumedCount, diagnostics.PastCount);
+        Assert.Equal(0, diagnostics.PendingEmbedCount);
+    }
+
+    [Fact]
+    public async Task InstructExecutor_MtmdPromptAdvancesPastTokensOnce()
+    {
+        using var context = _weights.CreateContext(_modelParams, NullLogger.Instance);
+        var executor = new InstructExecutor(context, _mtmd, logger: NullLogger.Instance);
+        executor.Embeds.Add(_mtmd.LoadMedia(Constants.MtmdImage));
+
+        var prompt = $"{_mtmdParams.MediaMarker ?? NativeApi.MtmdDefaultMarker() ?? ""} Provide details.";
+
+        await foreach (var _ in executor.InferAsync(prompt, new InferenceParams { MaxTokens = 0 }))
+        {
+        }
+
+        var diagnostics = executor.GetDiagnostics();
+        Assert.Equal(diagnostics.EmbedCount, diagnostics.ConsumedCount);
+        Assert.Equal(diagnostics.ConsumedCount, diagnostics.PastCount);
+        Assert.Equal(0, diagnostics.PendingEmbedCount);
+    }
+}
diff --git a/LLama.Unittest/MtmdWeightsTests.cs b/LLama.Unittest/MtmdWeightsTests.cs
new file mode 100644
index 000000000..947bbd1ea
--- /dev/null
+++ b/LLama.Unittest/MtmdWeightsTests.cs
@@ -0,0 +1,140 @@
+using System;
+using System.IO;
+using LLama.Common;
+using LLama.Native;
+using Xunit;
+
+namespace LLama.Unittest
+{
+    // Test the same things as llama model + image embeddings
+    //
+    public sealed class MtmdWeightTests
+        : IDisposable
+    {
+        private readonly LLamaWeights _llamaWeights;
+        private readonly SafeMtmdWeights _safeMtmdWeights;
+        private readonly LLamaContext _context;
+        private readonly MtmdContextParams _mtmdParams;
+        private readonly string _mediaMarker;
+
+        public MtmdWeightTests()
+        {
+            var @params = new ModelParams(Constants.MtmdModelPath)
+            {
+                // Mtmd models require a big context
+                ContextSize = 1024 * 32,
+                GpuLayerCount = Constants.CIGpuLayerCount,
+            };
+            _llamaWeights = LLamaWeights.LoadFromFile(@params);
+
+            _mtmdParams = MtmdContextParams.Default();
+            _mtmdParams.NThreads = Constants.CIGpuLayerCount;
+            _mtmdParams.UseGpu = false; // keep tests portable across environments without GPU
+
+            _mediaMarker = _mtmdParams.MediaMarker ??
throw new InvalidOperationException("MTMD media marker unavailable."); + + _safeMtmdWeights = SafeMtmdWeights.LoadFromFile(Constants.MtmdMmpPath, _llamaWeights, _mtmdParams); + _context = _llamaWeights.CreateContext(@params); + } + + public void Dispose() + { + _context.Dispose(); + _safeMtmdWeights.Dispose(); + _llamaWeights.Dispose(); + } + + private SafeMtmdInputChunks TokenizeWithEmbed(Func loadEmbed) + { + _safeMtmdWeights.ClearMedia(); + + var embed = loadEmbed(); + Assert.NotNull(embed); + + using (embed) + { + Assert.True(embed.Nx > 0); + Assert.True(embed.Ny > 0); + Assert.False(embed.IsAudio); + Assert.True(embed.GetDataSpan().Length > 0); + + var status = _safeMtmdWeights.Tokenize(_mediaMarker, addSpecial: true, parseSpecial: true, out var chunks); + Assert.Equal(0, status); + Assert.NotNull(chunks); + + return chunks!; + } + } + + private void AssertChunksEvaluate(SafeMtmdInputChunks chunks) + { + long nPast = 0; + var eval = _safeMtmdWeights.EvaluateChunks(chunks, _context.NativeHandle, ref nPast, seqId: 0, nBatch: checked((int)_context.BatchSize), logitsLast: true); + Assert.Equal(0, eval); + Assert.True(nPast > 0); + } + + [Fact,Trait("Category", "NoCI")] + public void EmbedImageAsFileName() + { + using var chunks = TokenizeWithEmbed(() => _safeMtmdWeights.LoadMedia(Constants.MtmdImage)); + AssertChunksEvaluate(chunks); + } + + [Fact,Trait("Category", "NoCI")] + public void EmbedImageAsBinary() + { + var imageBytes = File.ReadAllBytes(Constants.MtmdImage); + using var chunks = TokenizeWithEmbed(() => _safeMtmdWeights.LoadMedia(imageBytes)); + AssertChunksEvaluate(chunks); + } + + [Fact,Trait("Category", "NoCI")] + public void TokenizeProvidesChunkMetadata() + { + using var chunks = TokenizeWithEmbed(() => _safeMtmdWeights.LoadMedia(Constants.MtmdImage)); + + Assert.True(chunks.Size > 0); + + ulong totalTokens = 0; + long totalPositions = 0; + var imageChunks = 0; + + foreach (var chunk in chunks.Enumerate()) + { + totalTokens += chunk.NTokens; + totalPositions += chunk.NPos; + + if (chunk.Type == SafeMtmdInputChunk.SafeMtmdInputChunkType.Image) + { + imageChunks++; + + var copy = chunk.Copy(); + try + { + Assert.NotNull(copy); + if (copy != null) + { + Assert.Equal(chunk.NTokens, copy.NTokens); + Assert.Equal(chunk.NPos, copy.NPos); + } + } + finally + { + copy?.Dispose(); + } + } + } + + Assert.True(imageChunks > 0); + Assert.True(totalTokens > 0); + Assert.Equal(totalTokens, _safeMtmdWeights.CountTokens(chunks)); + Assert.Equal(totalPositions, _safeMtmdWeights.CountPositions(chunks)); + Assert.True(_safeMtmdWeights.SupportsVision); + Assert.False(_safeMtmdWeights.SupportsAudio); + + var audioBitrate = _safeMtmdWeights.AudioBitrate; + Assert.True(audioBitrate <= 0); + } + } +} diff --git a/LLama.Unittest/Native/SafeLlamaModelHandleTests.cs b/LLama.Unittest/Native/SafeLlamaModelHandleTests.cs deleted file mode 100644 index f3e5798f2..000000000 --- a/LLama.Unittest/Native/SafeLlamaModelHandleTests.cs +++ /dev/null @@ -1,32 +0,0 @@ -using System.Runtime.InteropServices; -using System.Text; -using LLama.Common; -using LLama.Extensions; -using Xunit; - -namespace LLama.Unittest.Native; - -public class SafeLlamaModelHandleTests -{ - private readonly LLamaWeights _model; - - public SafeLlamaModelHandleTests() - { - var @params = new ModelParams(Constants.GenerativeModelPath2) - { - ContextSize = 1, - GpuLayerCount = Constants.CIGpuLayerCount - }; - _model = LLamaWeights.LoadFromFile(@params); - } - - // Note: This test is flakey, it appears to often (but not always) fail the first 
time it is run after downloading the model file, but then succeed every time after! - //[SkippableFact] - //public void MetadataValByKey_ReturnsCorrectly() - //{ - // Skip.If(RuntimeInformation.IsOSPlatform(OSPlatform.OSX), "Skipping this test on macOS because for some reason the meta data is incorrect, but the rest of tests work well on mscOS [Check later!]."); - // const string key = "general.name"; - // var template = _model.NativeHandle.MetadataValueByKey(key); - // var name = Encoding.UTF8.GetStringFromSpan(template!.Value.Span); - //} -} diff --git a/LLama.Unittest/Native/SafeLlamaModelHandleVocabularyTests.cs b/LLama.Unittest/Native/SafeLlamaModelHandleVocabularyTests.cs deleted file mode 100644 index 1ce53f395..000000000 --- a/LLama.Unittest/Native/SafeLlamaModelHandleVocabularyTests.cs +++ /dev/null @@ -1,42 +0,0 @@ -using System.Text; -using System.Xml.Linq; -using LLama.Common; -using LLama.Extensions; -using Microsoft.Extensions.Logging; - - -namespace LLama.Unittest.Native; - -public class SafeLlamaModelHandleVocabularyTests: IDisposable -{ - private readonly LLamaWeights _model; - - public SafeLlamaModelHandleVocabularyTests() - { - var @params = new ModelParams(Constants.RerankingModelPath) - { - ContextSize = 0, - PoolingType = LLama.Native.LLamaPoolingType.Rank, - GpuLayerCount = Constants.CIGpuLayerCount - }; - _model = LLamaWeights.LoadFromFile(@params); - } - - public void Dispose() - { - _model.Dispose(); - } - - [Fact] - public void GetLLamaTokenString() - { - var bos = _model.Vocab.BOS; - var eos = _model.Vocab.EOS; - - var bosStr = _model.Vocab.LLamaTokenToString(bos, true); - var eosStr = _model.Vocab.LLamaTokenToString(eos, true); - - Assert.Equal("", bosStr); - Assert.Equal("", eosStr); - } -} diff --git a/LLama/Abstractions/ILLamaExecutor.cs b/LLama/Abstractions/ILLamaExecutor.cs index 9a2233287..92276e4a6 100644 --- a/LLama/Abstractions/ILLamaExecutor.cs +++ b/LLama/Abstractions/ILLamaExecutor.cs @@ -1,5 +1,6 @@ using System.Collections.Generic; using System.Threading; +using LLama.Native; namespace LLama.Abstractions { @@ -22,12 +23,12 @@ public interface ILLamaExecutor /// /// Multi-Modal Projections / Clip Model weights /// - public LLavaWeights? ClipModel { get; } + public SafeMtmdWeights? ClipModel { get; } /// - /// List of images: List of images in byte array format. + /// List of media: List of media for Multi-Modal models. /// - public List Images { get; } + public List Embeds { get; } /// /// Asynchronously infers a response from the model. diff --git a/LLama/Batched/BatchedExecutor.cs b/LLama/Batched/BatchedExecutor.cs index 462e9e555..40468c98d 100644 --- a/LLama/Batched/BatchedExecutor.cs +++ b/LLama/Batched/BatchedExecutor.cs @@ -16,6 +16,7 @@ public sealed class BatchedExecutor { private int _nextSequenceId; private readonly List _batchQueue = []; + private string? _mtmdMarker; private int _batchQueueHead; private int _batchedTokenCount; private bool _batchedTokenCountDirty = true; @@ -79,12 +80,20 @@ public int BatchedTokenCount /// The model to use /// Parameters to create a new context public BatchedExecutor(LLamaWeights model, IContextParams contextParams) + : this(model, contextParams, null) + { + } + + public BatchedExecutor(LLamaWeights model, IContextParams contextParams, SafeMtmdWeights? clipModel) { Model = model; Context = model.CreateContext(contextParams); + ClipModel = clipModel; Epoch = 1; } + public SafeMtmdWeights? 
ClipModel { get; } + /// /// Start a new /// @@ -314,6 +323,23 @@ internal LLamaSeqId GetNextSequenceId() return (end, Epoch + (uint)(_batchQueue.Count - _batchQueueHead) * 2); } + internal ulong QueueMtmdBatch(Conversation conversation, Conversation.MtmdChunkSequence sequence) + { + if (ClipModel is null) + throw new InvalidOperationException("This batched executor is not configured for multimodal inference."); + + var batch = new MtmdChunkBatch(ClipModel, conversation, sequence); + _batchQueue.Add(batch); + return Epoch + (uint)_batchQueue.Count * 2; + } + + internal string GetMtmdMarker() + { + if (ClipModel is null) + throw new InvalidOperationException("This batched executor is not configured for multimodal inference."); + return _mtmdMarker ??= NativeApi.MtmdDefaultMarker() ?? ""; + } + #region batches private interface IBatch { @@ -345,5 +371,44 @@ public Task DecodeAsync(LLamaContext ctx, CancellationToken token) return ctx.DecodeAsync(Batch, token); } } + + private class MtmdChunkBatch : IBatch + { + private readonly SafeMtmdWeights _clipModel; + private readonly Conversation _conversation; + private readonly Conversation.MtmdChunkSequence _sequence; + + public MtmdChunkBatch(SafeMtmdWeights clipModel, Conversation conversation, Conversation.MtmdChunkSequence sequence) + { + _clipModel = clipModel; + _conversation = conversation; + _sequence = sequence; + } + + public int ItemCount => Math.Max(1, _sequence.TotalTokens); + + public Task DecodeAsync(LLamaContext ctx, CancellationToken token) + { + try + { + var nPast = _conversation.GetMtmdPast(); + var status = _clipModel.EvaluateChunks(_sequence.Chunks, ctx.NativeHandle, ref nPast, + (int)_conversation.ConversationId, checked((int)ctx.BatchSize), logitsLast: true); + if (status != 0) + { + _conversation.OnMtmdEvaluationFailed(status); + return Task.FromResult(DecodeResult.DecodeFailed); + } + + _conversation.OnMtmdEvaluationCompleted(nPast, _sequence); + return Task.FromResult(DecodeResult.Ok); + } + catch + { + _conversation.OnMtmdEvaluationFailed(-1); + return Task.FromResult(DecodeResult.DecodeFailed); + } + } + } #endregion } diff --git a/LLama/Batched/Conversation.cs b/LLama/Batched/Conversation.cs index c504ce07a..2311c8a0c 100644 --- a/LLama/Batched/Conversation.cs +++ b/LLama/Batched/Conversation.cs @@ -3,6 +3,7 @@ using System.Linq; using System.Text.Json; using CommunityToolkit.HighPerformance.Buffers; +using LLama.Exceptions; using LLama.Native; namespace LLama.Batched; @@ -21,6 +22,12 @@ public sealed class Conversation /// Indicates if this conversation has been "forked" and may share logits with another conversation. /// private bool _forked; + private readonly List _mtmdEmbeds = new(); + private int? _mtmdLogitsIndex; + private MtmdChunkSequence? _pendingMtmdSequence; + private readonly List _embed_inps = new(); + private readonly List _session_tokens = new(); + private int _consumedTokensCount; /// /// Stores the indices to sample from. Contains valid items. 
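+    // Illustrative multimodal flow (not part of this patch; the media path is a
+    // placeholder): media is queued first, then the prompt references it. Prompt()
+    // routes to the multimodal path automatically when media is pending:
+    //
+    //     conversation.QueueMedia("photo.jpg");
+    //     conversation.Prompt("Describe <image>");
+    //     await executor.Infer();
+    //     var token = conversation.Sample(sampler);
+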
@@ -65,6 +72,46 @@ internal Conversation(BatchedExecutor batch, LLamaSeqId id) Executor = batch; } + internal sealed class MtmdChunkSequence : IDisposable + { + public SafeMtmdInputChunks Chunks { get; } + public List TextTokens { get; } + public int TotalPositions { get; } + public int TotalTokens => TextTokens.Count; + + private MtmdChunkSequence(SafeMtmdInputChunks chunks, List textTokens, int totalPositions) + { + Chunks = chunks; + TextTokens = textTokens; + TotalPositions = totalPositions; + } + + public static MtmdChunkSequence Create(SafeMtmdInputChunks chunks, SafeMtmdWeights clipModel) + { + var textTokens = new List(); + + foreach (var chunk in chunks.Enumerate()) + { + using (chunk) + { + if (chunk.Type != SafeMtmdInputChunk.SafeMtmdInputChunkType.Text) + continue; + + foreach (var token in chunk.GetTextTokensSpan()) + textTokens.Add((LLamaToken)unchecked((int)token)); + } + } + + var totalPositions = (int)clipModel.CountPositions(chunks); + return new MtmdChunkSequence(chunks, textTokens, totalPositions); + } + + public void Dispose() + { + Chunks.Dispose(); + } + } + /// /// Finalizer for Conversation /// @@ -83,6 +130,11 @@ public void Dispose() return; _disposed = true; + _pendingMtmdSequence?.Dispose(); + _pendingMtmdSequence = null; + + DisposeQueuedMedia(); + // Remove this conversation from the KV cache Executor.Context.NativeHandle.MemorySequenceRemove(ConversationId, -1, -1); @@ -206,6 +258,43 @@ private void AssertCanBePrompted() if (RequiresInference) throw new AlreadyPromptedConversationException(); + + _mtmdLogitsIndex = null; + } + + public void QueueMedia(string path) + { + AssertCanBePrompted(); + + if (Executor.ClipModel is null) + throw new InvalidOperationException("This conversation is not configured for multimodal prompts."); + + var embed = Executor.ClipModel.LoadMedia(path); + _mtmdEmbeds.Add(embed); + _mtmdLogitsIndex = null; + } + + public void QueueMedia(SafeMtmdEmbed embed) + { + AssertCanBePrompted(); + + if (Executor.ClipModel is null) + throw new InvalidOperationException("This conversation is not configured for multimodal prompts."); + + _mtmdEmbeds.Add(embed); + _mtmdLogitsIndex = null; + } + + public void Prompt(string promptText, bool addBos = true, bool special = true) + { + if (Executor.ClipModel != null && _mtmdEmbeds.Count > 0) + { + PromptMultimodal(promptText, addBos); + return; + } + + var tokens = Executor.Context.Tokenize(promptText, addBos, special); + Prompt(tokens); } /// @@ -246,6 +335,7 @@ public void Prompt(List tokens, bool allLogits = false) public void Prompt(ReadOnlySpan tokens, bool allLogits = false) { AssertCanBePrompted(); + _mtmdLogitsIndex = null; // No point doing anything if there is no actual prompt! if (tokens.Length == 0) @@ -289,6 +379,59 @@ public void Prompt(ReadOnlySpan tokens, bool allLogits = false) // Unset the forked flag. Since this conversation has just been prompted it's no longer // sharing anything with any other conversations. 
         _forked = false;
+        _mtmdLogitsIndex = null;
+    }
+
+    /// <summary>
+    /// Tokenize a prompt that references queued media and hand the resulting chunk sequence to the executor.
+    /// </summary>
+    private void PromptMultimodal(string text, bool addBos)
+    {
+        AssertCanBePrompted();
+
+        if (Executor.ClipModel is null)
+            throw new InvalidOperationException("This conversation is not configured for multimodal prompts.");
+        if (_mtmdEmbeds.Count == 0)
+            throw new InvalidOperationException("Queue media before prompting with multimodal input.");
+
+        var marker = Executor.GetMtmdMarker();
+        var prompt = text;
+
+        if (prompt.Contains("<image>"))
+            prompt = prompt.Replace("<image>", marker);
+
+        if (!prompt.Contains(marker))
+        {
+            var suffix = string.Concat(Enumerable.Repeat(marker, _mtmdEmbeds.Count));
+            prompt = string.Concat(prompt, suffix);
+        }
+
+        SafeMtmdInputChunks? chunks = null;
+        try
+        {
+            _mtmdLogitsIndex = null;
+            var status = Executor.ClipModel.Tokenize(prompt, addBos, parseSpecial: true, out chunks);
+            if (status != 0 || chunks is null)
+            {
+                Executor.ClipModel.ClearMedia();
+                throw new RuntimeError($"Failed to tokenize multimodal prompt. Status: {status}.");
+            }
+
+            var sequence = MtmdChunkSequence.Create(chunks, Executor.ClipModel);
+            _pendingMtmdSequence = sequence;
+
+            var epoch = Executor.QueueMtmdBatch(this, sequence);
+            chunks = null;
+
+            if (_batchSampleIndices.Length == 0)
+                _batchSampleIndices = new int[4];
+
+            _batchSampleCount = 0;
+            _requiredEpoch = epoch;
+            _forked = false;
+        }
+        finally
+        {
+            DisposeQueuedMedia();
+            chunks?.Dispose();
+        }
+    }
 
     ///
     ///
@@ -305,32 +448,7 @@ public void Prompt(LLamaToken token)
         Span<LLamaToken> span = [ token ];
         Prompt(span);
     }
-
-    /// <summary>
-    /// Prompt this conversation with an image embedding
-    /// </summary>
-    /// <param name="embedding"></param>
-    public void Prompt(SafeLlavaImageEmbedHandle embedding)
-    {
-        AssertCanBePrompted();
-
-        if (embedding.Model.EmbeddingDimensions != Executor.Model.EmbeddingSize)
-            throw new ArgumentException($"Embedding dimension mismatch between image embedding ({embedding.Model.EmbeddingDimensions}) and model ({Executor.Model.EmbeddingSize})");
-
-        for (var i = 0; i < embedding.Model.PatchCount; i++)
-        {
-            // Get a batch with space
-            (var batch, _requiredEpoch) = Executor.GetEmbeddingBatch();
-
-            batch.Add(
-                (i, embedding),
-                static (Span<float> dest, (int index, SafeLlavaImageEmbedHandle embedding) tup) => tup.embedding.GetEmbedding(dest, tup.index),
-                _end++,
-                ConversationId,
-                i == embedding.Model.PatchCount - 1
-            );
-        }
-    }
+
 
     /// <summary>
     /// Prompt this conversation with embeddings
     /// </summary>
     public void Prompt(ReadOnlySpan<float> embeddings)
     {
         AssertCanBePrompted();
+        _mtmdLogitsIndex = null;
 
         var dim = Executor.Model.EmbeddingSize;
         var count = embeddings.Length / dim;
@@ -385,6 +504,75 @@ public void Modify(ModifyKvCache modifier)
         _requiredEpoch = 0;
     }
 
+    internal long GetMtmdPast() => _end.Value;
+
+    internal void OnMtmdEvaluationCompleted(long newPast, MtmdChunkSequence sequence)
+    {
+        _pendingMtmdSequence?.Dispose();
+        _pendingMtmdSequence = null;
+
+        _end = (LLamaPos)checked((int)newPast);
+
+        if (_batchSampleIndices.Length == 0)
+            _batchSampleIndices = new int[4];
+
+        _batchSampleCount = 1;
+        _batchSampleIndices[0] = 0;
+        _mtmdLogitsIndex = -1;
+        _requiredEpoch = Executor.Epoch + 1;
+        _forked = false;
+
+        if (sequence.TextTokens.Count > 0)
+        {
+            _embed_inps.AddRange(sequence.TextTokens);
+            _session_tokens.AddRange(sequence.TextTokens);
+        }
+
+        var fillerToken = GetFillerToken(Executor.GetMtmdMarker());
+        var fillerCount = Math.Max(0, sequence.TotalPositions - sequence.TotalTokens);
+        for (var i = 0; i < fillerCount; i++)
+            _embed_inps.Add(fillerToken);
+
+        _consumedTokensCount
= _embed_inps.Count; + sequence.Dispose(); + } + + internal void OnMtmdEvaluationFailed(int status) + { + _pendingMtmdSequence?.Dispose(); + _pendingMtmdSequence = null; + _mtmdLogitsIndex = null; + _requiredEpoch = Executor.Epoch; + DisposeQueuedMedia(); + } + + internal int? MtmdLogitsIndex => _mtmdLogitsIndex; + + private LLamaToken GetFillerToken(string marker) + { + var markerTokens = Executor.Context.Tokenize(marker, addBos: false, special: true); + if (markerTokens.Length > 0) + return markerTokens[markerTokens.Length - 1]; + + var eos = Executor.Context.Vocab.EOS; + if (eos.HasValue) + return eos.Value; + + return default; + } + + private void DisposeQueuedMedia() + { + if (_mtmdEmbeds.Count == 0) + return; + + foreach (var embed in _mtmdEmbeds) + embed.Dispose(); + + _mtmdEmbeds.Clear(); + Executor.ClipModel?.ClearMedia(); + } + /// /// Provides direct access to the KV cache of a . /// See for how to use this. @@ -629,4 +817,4 @@ internal State() } } #endregion -} \ No newline at end of file +} diff --git a/LLama/Batched/ConversationExtensions.cs b/LLama/Batched/ConversationExtensions.cs index eb0192061..3e25d3f43 100644 --- a/LLama/Batched/ConversationExtensions.cs +++ b/LLama/Batched/ConversationExtensions.cs @@ -18,7 +18,11 @@ public static class ConversationExtensions /// public static LLamaToken Sample(this Conversation conversation, SafeLLamaSamplerChainHandle sampler, int offset = 0) { - return sampler.Sample(conversation.Executor.Context.NativeHandle, conversation.GetSampleIndex(offset)); + var ctx = conversation.Executor.Context.NativeHandle; + if (conversation.MtmdLogitsIndex == -1) + return sampler.Sample(ctx, -1); + + return sampler.Sample(ctx, conversation.GetSampleIndex(offset)); } /// @@ -30,7 +34,11 @@ public static LLamaToken Sample(this Conversation conversation, SafeLLamaSampler /// public static LLamaToken Sample(this Conversation conversation, ISamplingPipeline sampler, int offset = 0) { - return sampler.Sample(conversation.Executor.Context.NativeHandle, conversation.GetSampleIndex(offset)); + var ctx = conversation.Executor.Context.NativeHandle; + if (conversation.MtmdLogitsIndex == -1) + return sampler.Sample(ctx, -1); + + return sampler.Sample(ctx, conversation.GetSampleIndex(offset)); } /// @@ -82,4 +90,4 @@ public static void ShiftLeft(this Conversation conversation, int count, int keep return end.Value - count; }); } -} \ No newline at end of file +} diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index d0829deca..a39ad3836 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -32,11 +32,11 @@ public abstract class StatefulExecutorBase : ILLamaExecutor /// protected int _consumedTokensCount; // n_consume /// - /// + /// Number of tokens consumed from the session cache during the current run. /// protected int _n_session_consumed; /// - /// + /// Number of prompt tokens that match the loaded session cache prefix. /// protected int _n_matching_session_tokens; /// @@ -52,7 +52,7 @@ public abstract class StatefulExecutorBase : ILLamaExecutor /// protected List _embed_inps = new(); /// - /// + /// Tokens recovered from the session file and reused to warm up the KV cache. /// protected List _session_tokens = new(); /// @@ -81,21 +81,21 @@ public bool IsMultiModal } /// - public LLavaWeights? ClipModel { get; } + public SafeMtmdWeights? 
ClipModel { get; } /// - public List Images { get; } + public List Embeds { get; } private readonly StreamingTokenDecoder _decoder; /// - /// + /// Initialize a stateful executor bound to a specific context. /// - /// - /// + /// LLama context used for all native interactions. + /// Optional logger for diagnostic output. protected StatefulExecutorBase(LLamaContext context, ILogger? logger = null) { - Images = new List(); + Embeds = new List(); _logger = logger; Context = context; _pastTokensCount = 0; @@ -107,22 +107,22 @@ protected StatefulExecutorBase(LLamaContext context, ILogger? logger = null) } /// - /// + /// Initialize a multimodal executor with the supplied MTMD weights. /// - /// - /// - /// - public StatefulExecutorBase(LLamaContext context, LLavaWeights lLavaWeights, ILogger? logger = null) : + /// LLama context used for all native interactions. + /// Multimodal weights to associate with this executor. + /// Optional logger for diagnostic output. + public StatefulExecutorBase(LLamaContext context, SafeMtmdWeights safeMtmdWeights, ILogger? logger = null) : this( context, logger ) { - ClipModel = lLavaWeights; + ClipModel = safeMtmdWeights; } /// - /// This API is currently not verified. + /// Attach a session cache file so the executor can reuse previous KV state if compatible. /// - /// - /// + /// Path to the llama.cpp session file. + /// The current executor instance for fluent configuration. /// /// public StatefulExecutorBase WithSessionFile(string filename) @@ -179,9 +179,9 @@ public StatefulExecutorBase WithSessionFile(string filename) } /// - /// This API has not been verified currently. + /// Persist the current session cache to disk. /// - /// + /// Destination path for the llama.cpp session file. public void SaveSessionFile(string filename) { var session_token_array = _session_tokens.ToArray(); @@ -209,7 +209,7 @@ protected virtual void HandleRunOutOfContext(int tokensToKeep) } /// - /// Try to reuse the matching prefix from the session file. + /// Try to reuse the matching prompt prefix from the loaded session cache before evaluating new tokens. /// protected virtual void TryReuseMatchingPrefix() { @@ -243,73 +243,73 @@ protected virtual void TryReuseMatchingPrefix() } /// - /// Decide whether to continue the loop. + /// Determine whether the inference loop should continue processing tokens. /// - /// + /// Mutable state associated with the current inference. /// - /// + /// true to continue generating; otherwise false. protected abstract Task GetLoopCondition(InferStateArgs args, CancellationToken cancellationToken = default); /// - /// Preprocess the inputs before the inference. + /// Prepare the executor for inference by tokenizing input and updating cached state. /// - /// - /// + /// Prompt text to process. + /// Mutable state associated with the current inference. /// protected abstract Task PreprocessInputs(string? text, InferStateArgs args, CancellationToken cancellationToken = default); /// - /// Do some post processing after the inference. + /// Perform any post-processing on the generated tokens. /// - /// - /// + /// Parameters controlling sampling. /// - /// + /// Mutable state associated with the current inference. + /// A tuple indicating whether generation should stop and any extra outputs to emit. protected abstract Task<(bool, IReadOnlyList)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken = default); /// - /// The core inference logic. 
+ /// Core inference loop that advances the model by one step. /// - /// - /// + /// Parameters controlling sampling. + /// Mutable state associated with the current inference. /// protected abstract Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken = default); /// - /// Save the current state to a file. + /// Save the executor state to a serialized snapshot file. /// - /// + /// Destination file for the serialized state. /// public abstract Task SaveState(string filename, CancellationToken cancellationToken = default); /// - /// Get the current state data. + /// Capture the executor state in a serializable object. /// - /// + /// State snapshot suitable for persistence. public abstract ExecutorBaseState GetStateData(); /// - /// Load the state from data. + /// Restore executor state from a previously captured snapshot. /// - /// + /// State snapshot created by . /// public abstract Task LoadState(ExecutorBaseState data, CancellationToken cancellationToken = default); /// - /// Load the state from a file. + /// Restore executor state from a serialized snapshot file. /// - /// + /// Path to the snapshot produced by . /// public abstract Task LoadState(string filename, CancellationToken cancellationToken = default); /// - /// Execute the inference. + /// Execute an asynchronous inference session. /// - /// The prompt. If null, generation will continue where it left off previously. - /// - /// - /// + /// Optional prompt; when null generation resumes from prior state. + /// Sampling parameters to apply; defaults are used when null. + /// Cancellation token for cooperative cancellation. + /// Stream of decoded text segments as they become available. public virtual async IAsyncEnumerable InferAsync(string? text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { cancellationToken.ThrowIfCancellationRequested(); @@ -390,12 +390,12 @@ public virtual async Task PrefillPromptAsync(string prompt, CancellationToken ca } /// - /// State arguments that are used in single inference + /// Mutable state passed between inference callbacks during a single generation pass. /// protected class InferStateArgs { /// - /// + /// Anti-prompts that terminate generation when encountered. /// public IList? Antiprompts { get; set; } /// @@ -403,15 +403,15 @@ protected class InferStateArgs /// public int RemainedTokens { get; set; } /// - /// + /// Indicates whether generated tokens should be returned to the caller. /// public bool ReturnValue { get; set; } /// - /// + /// Signals that the executor should pause and wait for additional user input. /// public bool WaitForInput { get; set; } /// - /// + /// Indicates whether the session cache should be persisted after inference completes. /// public bool NeedToSaveSession { get; set; } @@ -422,6 +422,9 @@ protected class InferStateArgs } #pragma warning disable CS1591, CS8618 // Missing XML and irrelevant nullable warnings + /// + /// Serializable snapshot of executor state used for persistence and restart. + /// [JsonConverter(typeof(PolymorphicJSONConverter))] public class ExecutorBaseState { @@ -459,5 +462,33 @@ public class ExecutorBaseState public float? 
MirostatMu { get; set; } } #pragma warning restore + + internal ExecutorDiagnostics GetDiagnostics() + { + return new ExecutorDiagnostics( + _embed_inps.Count, + _consumedTokensCount, + _pastTokensCount, + _embeds.Count); + } + } +} + +namespace LLama +{ + internal readonly struct ExecutorDiagnostics + { + public ExecutorDiagnostics(int embedCount, int consumedCount, int pastCount, int pendingEmbeds) + { + EmbedCount = embedCount; + ConsumedCount = consumedCount; + PastCount = pastCount; + PendingEmbedCount = pendingEmbeds; + } + + public int EmbedCount { get; } + public int ConsumedCount { get; } + public int PastCount { get; } + public int PendingEmbedCount { get; } } } diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index 517a4e7d0..1bdba035a 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; +using System.Text; using System.Text.Json; using System.Text.Json.Serialization; using System.Threading; @@ -25,6 +26,9 @@ public class InstructExecutor private readonly string _instructionPrefix; private LLamaToken[] _inp_pfx; private LLamaToken[] _inp_sfx; + private SafeMtmdInputChunks? _mtmdChunks; + private string? _mtmdMarker; + private readonly string _instructionSuffix; /// /// @@ -42,6 +46,20 @@ public InstructExecutor(LLamaContext context, _inp_pfx = Context.Tokenize(instructionPrefix, true, true); _inp_sfx = Context.Tokenize(instructionSuffix, false, true); _instructionPrefix = instructionPrefix; + _instructionSuffix = instructionSuffix; + } + + public InstructExecutor(LLamaContext context, + SafeMtmdWeights clipModel, + string instructionPrefix = "\n\n### Instruction:\n\n", + string instructionSuffix = "\n\n### Response:\n\n", + ILogger? logger = null) + : base(context, clipModel, logger) + { + _inp_pfx = Context.Tokenize(instructionPrefix, true, true); + _inp_sfx = Context.Tokenize(instructionSuffix, false, true); + _instructionPrefix = instructionPrefix; + _instructionSuffix = instructionSuffix; } /// @@ -68,7 +86,8 @@ public override ExecutorBaseState GetStateData() /// public override Task LoadState(ExecutorBaseState data, CancellationToken cancellationToken = default) { - if (data is InstructExecutorState state) + DisposeMtmdChunks(); + if(data is InstructExecutorState state) { _n_session_consumed = state.ConsumedSessionCount; _embed_inps = state.EmbedInps!.ToList(); @@ -128,7 +147,14 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args, Canc { // When running the first input (prompt) in interactive mode, we should specially process it. if (text == null) throw new ArgumentException("Prompt cannot be null to trigger continuation if a prompt has not been provided previously."); - _embed_inps = Context.Tokenize(text, true, true).ToList(); + if (!IsMultiModal) + { + _embed_inps = Context.Tokenize(text, true, true).ToList(); + } + else + { + return PreprocessMtmd(text, args, addBos: true, replaceExisting: true); + } } else { @@ -141,20 +167,161 @@ protected override Task PreprocessInputs(string? 
text, InferStateArgs args, Canc
                 {
                     text += "\n";
                 }
-                _embed_inps.AddRange(_inp_pfx);
+                if (!IsMultiModal)
+                {
+                    _embed_inps.AddRange(_inp_pfx);
 
-                var line_inp = Context.Tokenize(text, false, true);
-                _embed_inps.AddRange(line_inp);
+                    var line_inp = Context.Tokenize(text, false, true);
+                    _embed_inps.AddRange(line_inp);
 
-                _embed_inps.AddRange(_inp_sfx);
+                    _embed_inps.AddRange(_inp_sfx);
 
-                args.RemainedTokens -= line_inp.Length;
+                    args.RemainedTokens -= line_inp.Length;
+                }
+                else
+                {
+                    var builder = new StringBuilder();
+                    builder.Append(_instructionPrefix);
+                    builder.Append(text);
+                    builder.Append(_instructionSuffix);
+                    return PreprocessMtmd(builder.ToString(), args, addBos: false, replaceExisting: false);
+                }
             }
         }
 
         return Task.CompletedTask;
    }
 
+    /// <summary>
+    /// Release the multimodal chunk collection currently queued for evaluation.
+    /// </summary>
+    private void DisposeMtmdChunks()
+    {
+        _mtmdChunks?.Dispose();
+        _mtmdChunks = null;
+    }
+
+    /// <summary>
+    /// Dispose and clear any media embeddings queued on this executor.
+    /// </summary>
+    private void DisposeEmbeds()
+    {
+        if (Embeds.Count == 0)
+            return;
+
+        foreach (var embed in Embeds)
+            embed.Dispose();
+
+        Embeds.Clear();
+    }
+
+    /// <summary>
+    /// Retrieve (and cache) the marker token used to reference media in the prompt.
+    /// </summary>
+    private string GetMtmdMarker()
+    {
+        if (_mtmdMarker is not null)
+            return _mtmdMarker;
+
+        _mtmdMarker = NativeApi.MtmdDefaultMarker() ?? "<__media__>";
+        return _mtmdMarker;
+    }
+
+    private static List<LLamaToken> BuildTokensWithFiller(List<LLamaToken> tokens, int totalPositions, LLamaToken fillerToken)
+    {
+        if (totalPositions <= tokens.Count)
+            return new List<LLamaToken>(tokens);
+
+        var result = new List<LLamaToken>(totalPositions);
+        result.AddRange(tokens);
+        result.AddRange(Enumerable.Repeat(fillerToken, totalPositions - tokens.Count));
+        return result;
+    }
+
+    private LLamaToken GetFillerToken(string marker)
+    {
+        var markerTokens = Context.Tokenize(marker, false, true);
+        if (markerTokens.Length > 0)
+            return markerTokens[markerTokens.Length - 1];
+
+        var eos = Context.Vocab.EOS;
+        if (eos.HasValue)
+            return eos.Value;
+
+        return default(LLamaToken);
+    }
+
+    /// <summary>
+    /// Preprocess multimodal prompts by aligning media markers and tokenizing via MTMD helpers.
+    /// </summary>
+    private Task PreprocessMtmd(string text, InferStateArgs args, bool addBos, bool replaceExisting)
+    {
+        if (ClipModel is null)
+            throw new InvalidOperationException("Multimodal execution requires a loaded mtmd clip model.");
+
+        DisposeMtmdChunks();
+
+        var marker = GetMtmdMarker();
+        var prompt = text;
+
+        if (Embeds.Count > 0)
+        {
+            if (prompt.Contains("<image>"))
+                prompt = prompt.Replace("<image>", marker);
+
+            if (!prompt.Contains(marker))
+            {
+                var suffix = string.Concat(Enumerable.Repeat(marker, Embeds.Count));
+                prompt = string.Concat(prompt, suffix);
+            }
+        }
+
+        SafeMtmdInputChunks? chunks = null;
+        try
+        {
+            var status = ClipModel.Tokenize(prompt, addBos, parseSpecial: true, out chunks);
+            if (status != 0 || chunks is null)
+            {
+                ClipModel.ClearMedia();
+                throw new RuntimeError($"Failed to tokenize multimodal prompt.
Status: {status}."); + } + + _mtmdChunks = chunks; + + var tokens = new List(); + foreach (var chunk in chunks.Enumerate()) + { + using var scopedChunk = chunk; + if (scopedChunk.Type != SafeMtmdInputChunk.SafeMtmdInputChunkType.Text) + continue; + + foreach (var token in scopedChunk.GetTextTokensSpan()) + tokens.Add(unchecked((int)token)); + } + + var totalPositions = (int)ClipModel.CountPositions(chunks); + var fillerToken = GetFillerToken(marker); + + if (replaceExisting) + { + _embed_inps = BuildTokensWithFiller(tokens, totalPositions, fillerToken); + _consumedTokensCount = 0; + } + else + { + if (_embed_inps.Count == 0) + _embed_inps = new List(); + + _embed_inps.AddRange(tokens); + var fillerCount = totalPositions - tokens.Count; + if (fillerCount > 0) + _embed_inps.AddRange(Enumerable.Repeat(fillerToken, fillerCount)); + + args.RemainedTokens -= tokens.Count; + } + } + catch + { + chunks?.Dispose(); + _mtmdChunks = null; + throw; + } + finally + { + DisposeEmbeds(); + } + + return Task.CompletedTask; + } + /// protected override Task<(bool, IReadOnlyList)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken) { @@ -217,11 +384,43 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In _n_session_consumed = _session_tokens.Count; } } + else if (IsMultiModal && _mtmdChunks is not null) + { + _is_prompt_run = false; + var nPast = (long)_pastTokensCount; + var previousConsumed = _consumedTokensCount; + var evalStatus = ClipModel!.EvaluateChunks(_mtmdChunks, Context.NativeHandle, ref nPast, seqId: 0, nBatch: checked((int)Context.BatchSize), logitsLast: true); + if (evalStatus != 0) + { + _logger?.LogError("[InstructExecutor] Failed to evaluate multimodal chunks. Status: {Status}", evalStatus); + DisposeMtmdChunks(); + throw new RuntimeError($"Failed to evaluate multimodal chunks. Status: {evalStatus}."); + } + + _pastTokensCount = checked((int)nPast); + DisposeMtmdChunks(); + + if (!string.IsNullOrEmpty(_pathSession) && _embed_inps.Count > previousConsumed) + { + _session_tokens.AddRange(_embed_inps.Skip(previousConsumed)); + _n_session_consumed = _session_tokens.Count; + } + + _consumedTokensCount = _embed_inps.Count; + _embeds.Clear(); + } _embeds.Clear(); if (_embed_inps.Count <= _consumedTokensCount && !args.WaitForInput) { + if (inferenceParams.MaxTokens == 0) + { + _embeds.Clear(); + args.WaitForInput = true; + args.ReturnValue = false; + return; + } // optionally save the session on first sample (for faster prompt loading next time) if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession) { diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index e7cac4c47..97d49f5de 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -8,6 +8,7 @@ using System.Threading.Tasks; using LLama.Abstractions; using LLama.Common; +using LLama; using LLama.Exceptions; using LLama.Native; using LLama.Sampling; @@ -21,30 +22,31 @@ namespace LLama /// public class InteractiveExecutor : StatefulExecutorBase { + // Indicates whether the executor is currently evaluating the initial prompt or a follow-up turn. private bool _is_prompt_run = true; - // LLava - private int _EmbedImagePosition = -1; - private List _imageEmbedHandles = new List(); - private bool _imageInPrompt = false; + // MTMD multimodal state + private SafeMtmdInputChunks? _mtmdChunks; // Pending chunk collection produced by the multimodal tokenizer. + private string? 
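+        // Illustrative usage (not part of this patch; names and the media path are
+        // placeholders): attach MTMD weights, queue media through Embeds, then prompt.
+        //
+        //     var executor = new InteractiveExecutor(context, mtmdWeights);
+        //     executor.Embeds.Add(mtmdWeights.LoadMedia("photo.jpg"));
+        //     await foreach (var piece in executor.InferAsync("Describe <image>", inferenceParams))
+        //         Console.Write(piece);
+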
_mtmdMarker; // Cached multimodal marker returned by the native helper. + /// - /// + /// Create an interactive executor for text-only inference. /// - /// - /// + /// LLama context to operate against. + /// Optional logger for diagnostic output. public InteractiveExecutor(LLamaContext context, ILogger? logger = null) : base(context, logger) { } /// - /// + /// Create an interactive multimodal executor that can process text alongside media inputs. /// - /// - /// - /// - public InteractiveExecutor(LLamaContext context, LLavaWeights clipModel, ILogger? logger = null) + /// LLama context to operate against. + /// Multimodal weights (MTMD) to attach to the executor. + /// Optional logger for diagnostic output. + public InteractiveExecutor(LLamaContext context, SafeMtmdWeights clipModel, ILogger? logger = null) : base(context, clipModel, logger) { } @@ -72,6 +74,7 @@ public override ExecutorBaseState GetStateData() /// public override Task LoadState(ExecutorBaseState data, CancellationToken cancellationToken = default) { + DisposeMtmdChunks(); if (data is InteractiveExecutorState state) { _n_session_consumed = state.ConsumedSessionCount; @@ -111,15 +114,20 @@ public override async Task LoadState(string filename, CancellationToken cancella } /// - /// Define whether to continue the loop to generate responses. + /// Decide whether generation should continue for the current iteration. /// - /// + /// Mutable inference state. + /// true to keep generating; otherwise false. protected override Task GetLoopCondition(InferStateArgs args, CancellationToken cancellationToken) { return Task.FromResult(args.RemainedTokens != 0 && !args.WaitForInput || _is_prompt_run); } - /// + /// + /// Preprocess the incoming prompt or continuation text before inference. + /// + /// Prompt text or continuation provided by the caller. + /// Mutable inference state. protected override Task PreprocessInputs(string? text, InferStateArgs args, CancellationToken cancellationToken) { if (_is_prompt_run) @@ -136,7 +144,7 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args, Canc } else { - PreprocessLlava(text, args, true); + PreprocessMtmd(text, args, true); } } else @@ -157,7 +165,7 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args, Canc } else { - PreprocessLlava(text, args, false); + PreprocessMtmd(text, args, false); } } } @@ -165,51 +173,172 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args, Canc return Task.CompletedTask; } - /// - private void PreprocessLlava(string text, InferStateArgs args, bool addBos = true) + /// + /// Release any queued multimodal chunks and reset state. + /// + private void DisposeMtmdChunks() + { + _mtmdChunks?.Dispose(); + _mtmdChunks = null; + } + + /// + /// Dispose and clear any pending multimodal embeddings queued for evaluation. + /// + private void DisposeEmbeds() + { + if (Embeds.Count == 0) + { + return; + } + + foreach (var embed in Embeds) + { + embed.Dispose(); + } + + Embeds.Clear(); + } + + /// + /// Retrieve the marker token used to signal media segments to the tokenizer. + /// + private string GetMtmdMarker() { - // If the prompt contains the tag extract this. - _imageInPrompt = text.Contains(""); - if (_imageInPrompt && IsMultiModal) + if (_mtmdMarker is not null) { - foreach (var image in Images) + return _mtmdMarker; + } + + _mtmdMarker = NativeApi.MtmdDefaultMarker() ?? 
""; + return _mtmdMarker; + } + + private static List BuildTokensWithFiller(List tokens, int totalPositions, LLamaToken fillerToken) + { + if (totalPositions <= tokens.Count) + return new List(tokens); + + var result = new List(totalPositions); + result.AddRange(tokens); + result.AddRange(Enumerable.Repeat(fillerToken, totalPositions - tokens.Count)); + return result; + } + + private LLamaToken GetFillerToken(string marker) + { + var markerTokens = Context.Tokenize(marker, false, true); + if (markerTokens.Length > 0) + return markerTokens[markerTokens.Length - 1]; + + var eos = Context.Vocab.EOS; + if (eos.HasValue) + return eos.Value; + + return default(LLamaToken); + } + + /// + /// Preprocess multimodal prompts by aligning media markers and tokenizing via MTMD helpers. + /// + /// Prompt text containing optional media markers. + /// Mutable inference state. + /// Whether to treat the prompt as a fresh run and add the BOS token. + private Task PreprocessMtmd(string text, InferStateArgs args, bool addBos = true) + { + if (ClipModel is null) + { + throw new InvalidOperationException("Multimodal execution requires a loaded mtmd clip model."); + } + + DisposeMtmdChunks(); + + var marker = GetMtmdMarker(); + var prompt = text; + + if (Embeds.Count > 0) + { + if (prompt.Contains("")) { - _imageEmbedHandles.Add(SafeLlavaImageEmbedHandle.CreateFromMemory(ClipModel!.NativeHandle, Context, image)); + prompt = prompt.Replace("", marker); } - int imageIndex = text.IndexOf(""); - // Tokenize segment 1 (before tag) - string preImagePrompt = text.Substring(0, imageIndex); - var segment1 = Context.Tokenize(preImagePrompt, addBos, true); - // Remember the position to add the image embeddings - _EmbedImagePosition = segment1.Length; - string postImagePrompt = text.Substring(imageIndex + 7); - var segment2 = Context.Tokenize(postImagePrompt, false, true); - _embed_inps.AddRange(segment1); - _embed_inps.AddRange(segment2); + if (!prompt.Contains(marker)) + { + var suffix = string.Concat(Enumerable.Repeat(marker, Embeds.Count)); // Ensure tokenizer sees one marker per embed. + prompt = string.Concat(prompt, suffix); + } } - else + + SafeMtmdInputChunks? chunks = null; + try { + var status = ClipModel.Tokenize(prompt, addBos, parseSpecial: true, out chunks); + if (status != 0 || chunks is null) + { + ClipModel.ClearMedia(); + throw new RuntimeError($"Failed to tokenize multimodal prompt. Status: {status}."); + } + + _mtmdChunks = chunks; // Own the chunk collection until evaluation completes. 
+ + var tokens = new List(); + foreach (var chunk in chunks.Enumerate()) + { + using var scopedChunk = chunk; + if (scopedChunk.Type != SafeMtmdInputChunk.SafeMtmdInputChunkType.Text) + { + continue; + } + + foreach (var token in scopedChunk.GetTextTokensSpan()) + { + tokens.Add(unchecked((int)token)); + } + } + + var totalPositions = (int)ClipModel.CountPositions(chunks); + var fillerToken = GetFillerToken(marker); + if (addBos) { - _embed_inps = Context.Tokenize(text, true, true).ToList(); + _embed_inps = BuildTokensWithFiller(tokens, totalPositions, fillerToken); + _consumedTokensCount = 0; } else { - var line_inp = Context.Tokenize(text, false, true); - _embed_inps.AddRange(line_inp); - args.RemainedTokens -= line_inp.Length; + if (_embed_inps.Count == 0) + _embed_inps = new List(); + + _embed_inps.AddRange(tokens); + var fillerCount = totalPositions - tokens.Count; + if (fillerCount > 0) + _embed_inps.AddRange(Enumerable.Repeat(fillerToken, fillerCount)); + + args.RemainedTokens -= tokens.Count; } } + catch + { + chunks?.Dispose(); + _mtmdChunks = null; + throw; + } + finally + { + DisposeEmbeds(); // Flush any embeds decoded in prior step; MTMD replays them via chunk eval. + } + + return Task.CompletedTask; } /// - /// Return whether to break the generation. + /// Decide whether generation should stop based on antiprompts, token limits, or end-of-generation markers. /// - /// - /// + /// Sampling parameters controlling generation. + /// Mutable inference state. /// - /// + /// Tuple describing whether to stop and any additional outputs to emit. protected override Task<(bool, IReadOnlyList)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken) { if (_embed_inps.Count <= _consumedTokensCount) @@ -264,51 +393,87 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In HandleRunOutOfContext(tokensToKeep); } - TryReuseMatchingPrefix(); - - // Changes to support Multi-Modal LLMs. - // - (DecodeResult, int, int) header, end, result; - if (IsMultiModal && _EmbedImagePosition > 0) + if (_mtmdChunks is null) { - // Tokens previous to the images - header = await Context.DecodeAsync(_embeds.GetRange(0, _EmbedImagePosition), LLamaSeqId.Zero, batch, _pastTokensCount); - _pastTokensCount = header.Item3; + TryReuseMatchingPrefix(); + } - if (header.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(header.Item1); + if (IsMultiModal && _mtmdChunks is not null) + { + var nPast = (long)_pastTokensCount; + var previousConsumed = _consumedTokensCount; + var evalStatus = ClipModel!.EvaluateChunks(_mtmdChunks, Context.NativeHandle, ref nPast, seqId: 0, + nBatch: checked((int)Context.BatchSize), logitsLast: true); + if (evalStatus != 0) + { + _logger?.LogError("[InteractiveExecutor] Failed to evaluate multimodal chunks. Status: {Status}", evalStatus); + DisposeMtmdChunks(); + throw new RuntimeError($"Failed to evaluate multimodal chunks. 
Status: {evalStatus}."); + } - // Images - foreach (var image in _imageEmbedHandles) - ClipModel!.EvalImageEmbed(Context, image, ref _pastTokensCount); + _pastTokensCount = checked((int)nPast); + DisposeMtmdChunks(); - // Post-image Tokens - end = await Context.DecodeAsync(_embeds.GetRange(_EmbedImagePosition, _embeds.Count - _EmbedImagePosition), LLamaSeqId.Zero, batch, _pastTokensCount); - _pastTokensCount = end.Item3; + if (!string.IsNullOrEmpty(_pathSession) && _embed_inps.Count > previousConsumed) + { + _session_tokens.AddRange(_embed_inps.Skip(previousConsumed)); + _n_session_consumed = _session_tokens.Count; + } - _EmbedImagePosition = -1; - _imageEmbedHandles.Clear(); - Images.Clear(); + _consumedTokensCount = _embed_inps.Count; + _embeds.Clear(); } else { - result = await Context.DecodeAsync(_embeds, LLamaSeqId.Zero, batch, _pastTokensCount); + var result = await Context.DecodeAsync(_embeds, LLamaSeqId.Zero, batch, _pastTokensCount); _pastTokensCount = result.Item3; if (result.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(result.Item1); + + if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession)) + { + _session_tokens.AddRange(_embeds); + _n_session_consumed = _session_tokens.Count; + } + } + } + else if (IsMultiModal && _mtmdChunks is not null) + { + _is_prompt_run = false; + var nPast = (long)_pastTokensCount; + var previousConsumed = _consumedTokensCount; + var evalStatus = ClipModel!.EvaluateChunks(_mtmdChunks, Context.NativeHandle, ref nPast, seqId: 0, nBatch: checked((int)Context.BatchSize), logitsLast: true); + if (evalStatus != 0) + { + _logger?.LogError("[InteractiveExecutor] Failed to evaluate multimodal chunks. Status: {Status}", evalStatus); + DisposeMtmdChunks(); + throw new RuntimeError($"Failed to evaluate multimodal chunks. Status: {evalStatus}."); } + _pastTokensCount = checked((int)nPast); + DisposeMtmdChunks(); - if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession)) + if (!string.IsNullOrEmpty(_pathSession) && _embed_inps.Count > previousConsumed) { - _session_tokens.AddRange(_embeds); + _session_tokens.AddRange(_embed_inps.Skip(previousConsumed)); _n_session_consumed = _session_tokens.Count; } - } + _consumedTokensCount = _embed_inps.Count; + _embeds.Clear(); + } + _embeds.Clear(); if (_embed_inps.Count <= _consumedTokensCount && !args.WaitForInput) { + if (inferenceParams.MaxTokens == 0) + { + _embeds.Clear(); + args.WaitForInput = true; + args.ReturnValue = false; + return; + } // optionally save the session on first sample (for faster prompt loading next time) if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession) { @@ -355,10 +520,10 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In } /// - /// The descriptor of the state of the interactive executor. + /// Serializable state specific to the interactive executor. /// public class InteractiveExecutorState - : ExecutorBaseState + : StatefulExecutorBase.ExecutorBaseState { /// /// Whether the executor is running for the first time (running the prompt). diff --git a/LLama/LLamaSharp.csproj b/LLama/LLamaSharp.csproj index ff4e87487..be763f419 100644 --- a/LLama/LLamaSharp.csproj +++ b/LLama/LLamaSharp.csproj @@ -3,7 +3,7 @@ netstandard2.0;net8.0 LLama enable - 12 + 13 AnyCPU;x64;Arm64 True @@ -17,7 +17,7 @@ https://avatars3.githubusercontent.com/u/44989469?s=200&v=4 LLama, LLM, GPT, ChatGPT, NLP, AI, Chat Bot, SciSharp - LLamaSharp is a cross-platform library to run 🦙LLaMA/LLaVA model (and others) in your local device. 
+ LLamaSharp is a cross-platform library to run 🦙LLaMA/Mtmd model (and others) in your local device. Based on [llama.cpp](https://github.com/ggerganov/llama.cpp), inference with LLamaSharp is efficient on both CPU and GPU. With the higher-level APIs and RAG support, it's convenient to deploy LLM (Large Language Model) in your application with LLamaSharp. diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs index 8f9b40cc3..94bc60830 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -28,10 +28,10 @@ public class StatelessExecutor public bool IsMultiModal => false; /// - public LLavaWeights? ClipModel => default; + public SafeMtmdWeights? ClipModel => default; /// - public List Images { get; } + public List Embeds { get; } /// /// The context used by the executor when running the inference. @@ -57,7 +57,7 @@ public class StatelessExecutor /// public StatelessExecutor(LLamaWeights weights, IContextParams @params, ILogger? logger = null) { - Images = [ ]; + Embeds = [ ]; _weights = weights; _params = @params; _logger = logger; diff --git a/LLama/LLavaWeights.cs b/LLama/LLavaWeights.cs deleted file mode 100644 index f2f9f6256..000000000 --- a/LLama/LLavaWeights.cs +++ /dev/null @@ -1,137 +0,0 @@ - -using System; -using System.Threading; -using System.Threading.Tasks; -using LLama.Native; - -namespace LLama; - -/// -/// A set of llava model weights (mmproj), loaded into memory. -/// -public sealed class LLavaWeights - : IDisposable -{ - /// - /// The native handle, which is used in the native APIs - /// - /// Be careful how you use this! - public SafeLlavaModelHandle NativeHandle { get; } - - private LLavaWeights(SafeLlavaModelHandle weights) - { - NativeHandle = weights; - } - - #region load - /// - /// Load weights into memory - /// - /// path to the "mmproj" model file - /// - public static LLavaWeights LoadFromFile(string mmProject) - { - var weights = SafeLlavaModelHandle.LoadFromFile(mmProject, 1); - return new LLavaWeights(weights); - } - - /// - /// Load weights into memory - /// - /// path to the "mmproj" model file - /// - /// - public static Task LoadFromFileAsync(string mmProject, CancellationToken token = default) - { - return Task.Run(() => LoadFromFile(mmProject), token); - } - #endregion - - #region embed - /// - /// Create the Image Embeddings from the bytes of an image. - /// - /// - /// Image bytes. Supported formats: - /// - /// JPG - /// PNG - /// BMP - /// TGA - /// - /// - /// - public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, byte[] image) - { - return NativeHandle.CreateImageEmbeddings(ctxLlama, image); - } - - /// - /// Create the Image Embeddings. - /// - /// Image in binary format (it supports jpeg format only) - /// Number of threads to use - /// return the SafeHandle of these embeddings - public SafeLlavaImageEmbedHandle CreateImageEmbeddings(byte[] image, int threads = -1) - { - return NativeHandle.CreateImageEmbeddings(image, threads); - } - - /// - /// Create the Image Embeddings from the bytes of an image. - /// - /// - /// Path to the image file. Supported formats: - /// - /// JPG - /// PNG - /// BMP - /// TGA - /// - /// - /// - /// - public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, string image) - { - return NativeHandle.CreateImageEmbeddings(ctxLlama, image); - } - - /// - /// Create the Image Embeddings from the bytes of an image. - /// - /// Path to the image file. 
Supported formats: - /// - /// JPG - /// PNG - /// BMP - /// TGA - /// - /// - /// - /// - /// - public SafeLlavaImageEmbedHandle CreateImageEmbeddings(string image, int threads = -1) - { - return NativeHandle.CreateImageEmbeddings(image, threads); - } - #endregion - - /// - /// Eval the image embeddings - /// - /// - /// - /// - /// - public bool EvalImageEmbed(LLamaContext ctxLlama, SafeLlavaImageEmbedHandle imageEmbed, ref int n_past) - { - return NativeHandle.EvalImageEmbed( ctxLlama, imageEmbed, ref n_past ); - } - - /// - public void Dispose() - { - NativeHandle.Dispose(); - } - -} \ No newline at end of file diff --git a/LLama/Native/LLavaImageEmbed.cs b/LLama/Native/LLavaImageEmbed.cs deleted file mode 100644 index 65eba230c..000000000 --- a/LLama/Native/LLavaImageEmbed.cs +++ /dev/null @@ -1,19 +0,0 @@ -namespace LLama.Native; - -/// -/// LLaVa Image embeddings -/// -/// llava_image_embed -[StructLayout(LayoutKind.Sequential)] -public unsafe struct LLavaImageEmbed -{ - /// - /// The embeddings of the embedded image. - /// - public float* embed; - - /// - /// The position of the image's tokens. - /// - public int n_image_pos; -} \ No newline at end of file diff --git a/LLama/Native/Load/NativeLibraryConfig.cs b/LLama/Native/Load/NativeLibraryConfig.cs index c20453e27..723717c23 100644 --- a/LLama/Native/Load/NativeLibraryConfig.cs +++ b/LLama/Native/Load/NativeLibraryConfig.cs @@ -299,15 +299,15 @@ public sealed partial class NativeLibraryConfig public static NativeLibraryConfig LLama { get; } /// - /// Configuration for LLava native library + /// Configuration for Mtmd native library /// - public static NativeLibraryConfig LLava { get; } + public static NativeLibraryConfig Mtmd { get; } static NativeLibraryConfig() { LLama = new(NativeLibraryName.LLama); - LLava = new(NativeLibraryName.LLava); - All = new(LLama, LLava); + Mtmd = new(NativeLibraryName.Mtmd); + All = new(LLama, Mtmd); } #if NETSTANDARD2_0 @@ -413,9 +413,9 @@ public void ForEach(Action action) /// When this method is called, all the other configurations will be ignored. /// /// The full path to the llama library to load. - /// The full path to the llava library to load. + /// The full path to the mtmd library to load. /// Thrown if `LibraryHasLoaded` is true. - public NativeLibraryConfigContainer WithLibrary(string? llamaPath, string? llavaPath) + public NativeLibraryConfigContainer WithLibrary(string? llamaPath, string? mtmdPath) { foreach(var config in _configs) { @@ -423,9 +423,9 @@ public NativeLibraryConfigContainer WithLibrary(string? llamaPath, string? llava { config.WithLibrary(llamaPath); } - if(config.NativeLibraryName == NativeLibraryName.LLava && llavaPath is not null) + if(config.NativeLibraryName == NativeLibraryName.Mtmd && mtmdPath is not null) { - config.WithLibrary(llavaPath); + config.WithLibrary(mtmdPath); } } @@ -594,7 +594,7 @@ public NativeLibraryConfigContainer WithLogCallback(ILogger? logger) /// You can still modify the configuration after this calling but only before any call from . /// /// Whether the running is successful. - public bool DryRun(out INativeLibrary? loadedLLamaNativeLibrary, out INativeLibrary? loadedLLavaNativeLibrary) + public bool DryRun(out INativeLibrary? loadedLLamaNativeLibrary, out INativeLibrary? loadedMtmdNativeLibrary) { bool success = true; foreach(var config in _configs) @@ -604,16 +604,16 @@ public bool DryRun(out INativeLibrary? 
loadedLLamaNativeLibrary, out INativeLibr { loadedLLamaNativeLibrary = loadedLibrary; } - else if(config.NativeLibraryName == NativeLibraryName.LLava) + else if(config.NativeLibraryName == NativeLibraryName.Mtmd) { - loadedLLavaNativeLibrary = loadedLibrary; + loadedMtmdNativeLibrary = loadedLibrary; } else { throw new Exception("Unknown native library config during the dry run."); } } - loadedLLamaNativeLibrary = loadedLLavaNativeLibrary = null; + loadedLLamaNativeLibrary = loadedMtmdNativeLibrary = null; return success; } } @@ -628,9 +628,9 @@ public enum NativeLibraryName /// LLama, /// - /// The native library compiled from the LLaVA example of llama.cpp. + /// The native library compiled from the MTMD library of llama.cpp. /// - LLava + Mtmd } internal static class LibraryNameExtensions @@ -641,8 +641,8 @@ public static string GetLibraryName(this NativeLibraryName name) { case NativeLibraryName.LLama: return NativeApi.libraryName; - case NativeLibraryName.LLava: - return NativeApi.llavaLibraryName; + case NativeLibraryName.Mtmd: + return NativeApi.mtmdLibraryName; default: throw new ArgumentOutOfRangeException(nameof(name), name, null); } diff --git a/LLama/Native/Load/NativeLibraryUtils.cs b/LLama/Native/Load/NativeLibraryUtils.cs index 9f6457cd1..84ababc60 100644 --- a/LLama/Native/Load/NativeLibraryUtils.cs +++ b/LLama/Native/Load/NativeLibraryUtils.cs @@ -9,7 +9,7 @@ namespace LLama.Native internal static class NativeLibraryUtils { /// - /// Try to load libllama/llava_shared, using CPU feature detection to try and load a more specialised DLL if possible + /// Try to load libllama/mtmd, using CPU feature detection to try and load a more specialised DLL if possible /// /// The library handle to unload later, or IntPtr.Zero if no library was loaded internal static IntPtr TryLoadLibrary(NativeLibraryConfig config, out INativeLibrary? loadedLibrary) diff --git a/LLama/Native/MtmdContextParams.cs b/LLama/Native/MtmdContextParams.cs new file mode 100644 index 000000000..d83831d85 --- /dev/null +++ b/LLama/Native/MtmdContextParams.cs @@ -0,0 +1,148 @@ +using System; +using System.Runtime.InteropServices; +using System.Text; + +namespace LLama.Native; + +/// +/// Managed representation of the native mtmd_context_params structure used to configure multimodal helpers. +/// +public class MtmdContextParams +{ + /// + /// Whether GPU acceleration should be requested when available. + /// + public bool UseGpu { get; set; } + + /// + /// Whether timing information should be emitted by the native helper. + /// + public bool PrintTimings { get; set; } + + /// + /// Number of worker threads to dedicate to preprocessing and tokenization. + /// + public int NThreads { get; set; } + + /// + /// Verbosity level forwarded to llama.cpp logging (matches ggml_log_level). + /// + public int Verbosity { get; set; } + + /// + /// Marker token inserted into the text stream to reference an image embedding. + /// + public string? ImageMarker { get; set; } + + /// + /// Marker token inserted into the text stream to reference a generic media embedding. + /// + public string? MediaMarker { get; set; } + + /// + /// Create a managed copy of the native defaults returned by . 
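+    /// A typical adjustment (illustrative, not part of this patch) is to raise
+    /// <see cref="NThreads"/> before passing the result to SafeMtmdWeights.LoadFromFile.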
+ /// + public static MtmdContextParams Default() + { + var native = NativeApi.mtmd_context_params_default(); + return new MtmdContextParams + { + UseGpu = native.use_gpu, + PrintTimings = native.print_timings, + NThreads = native.n_threads, + Verbosity = native.verbosity, + ImageMarker = PtrToString(native.image_marker), + MediaMarker = PtrToString(native.media_marker) + }; + } + + private static string? PtrToString(IntPtr ptr) + { + if (ptr == IntPtr.Zero) + return null; + +#if NETSTANDARD2_0 + unsafe + { + var length = 0; + var current = (byte*)ptr; + while (current[length] != 0) + length++; + + if (length == 0) + return string.Empty; + + var buffer = new byte[length]; + Marshal.Copy(ptr, buffer, 0, length); + return Encoding.UTF8.GetString(buffer); + } +#else + return Marshal.PtrToStringUTF8(ptr); +#endif + } + + /// + /// Convert the managed representation to a native structure, pinning strings for the duration of the scope. + /// + internal NativeScope ToNativeScope() => new(this); + + internal readonly struct NativeScope : IDisposable + { + public NativeApi.mtmd_context_params Value { get; } + + private readonly PinnedUtf8String? _imageMarker; + private readonly PinnedUtf8String? _mediaMarker; + + public NativeScope(MtmdContextParams managed) + { + _imageMarker = PinnedUtf8String.Create(managed.ImageMarker); + _mediaMarker = PinnedUtf8String.Create(managed.MediaMarker); + + var native = NativeApi.mtmd_context_params_default(); + native.use_gpu = managed.UseGpu; + native.print_timings = managed.PrintTimings; + native.n_threads = managed.NThreads; + native.verbosity = managed.Verbosity; + + if (_imageMarker is not null) + native.image_marker = _imageMarker.Pointer; + if (_mediaMarker is not null) + native.media_marker = _mediaMarker.Pointer; + + Value = native; + } + + public void Dispose() + { + _imageMarker?.Dispose(); + _mediaMarker?.Dispose(); + } + } +} + +/// +/// Helper that pins a managed string as UTF-8 for the lifetime of the instance. +/// +internal sealed class PinnedUtf8String : IDisposable +{ + private readonly byte[]? _buffer; + private readonly GCHandle _handle; + + private PinnedUtf8String(string value) + { + var bytes = Encoding.UTF8.GetBytes(value); + _buffer = new byte[bytes.Length + 1]; + Buffer.BlockCopy(bytes, 0, _buffer, 0, bytes.Length); + _handle = GCHandle.Alloc(_buffer, GCHandleType.Pinned); + } + + public static PinnedUtf8String? Create(string? value) => value is null ? null : new PinnedUtf8String(value); + + public IntPtr Pointer => _buffer is null ? IntPtr.Zero : _handle.AddrOfPinnedObject(); + + public void Dispose() + { + if (_buffer is not null && _handle.IsAllocated) + _handle.Free(); + } +} diff --git a/LLama/Native/MtmdImageEmbed.cs b/LLama/Native/MtmdImageEmbed.cs new file mode 100644 index 000000000..7341b8563 --- /dev/null +++ b/LLama/Native/MtmdImageEmbed.cs @@ -0,0 +1,20 @@ +using System.Runtime.InteropServices; + +namespace LLama.Native; + +/// +/// Representation of the native llava_image_embed structure used to return image embeddings. +/// +[StructLayout(LayoutKind.Sequential)] +public unsafe struct MtmdImageEmbed +{ + /// + /// Pointer to the embedding buffer for the decoded image. + /// + public float* embed; + + /// + /// Number of sequence positions consumed by the image tokens associated with the embedding. 
+ /// + public int n_image_pos; +} diff --git a/LLama/Native/NativeApi.LLava.cs b/LLama/Native/NativeApi.LLava.cs deleted file mode 100644 index 692e3f0ad..000000000 --- a/LLama/Native/NativeApi.LLava.cs +++ /dev/null @@ -1,63 +0,0 @@ -using System; - -namespace LLama.Native; - -public static partial class NativeApi -{ - /// - /// Sanity check for clip <-> llava embed size match - /// - /// LLama Context - /// Llava Model - /// True if validate successfully - [DllImport(llavaLibraryName, EntryPoint = "llava_validate_embed_size", CallingConvention = CallingConvention.Cdecl)] - [return: MarshalAs(UnmanagedType.U1)] - public static extern bool llava_validate_embed_size( SafeLLamaContextHandle ctxLlama, SafeLlavaModelHandle ctxClip); - - /// - /// Build an image embed from image file bytes - /// - /// SafeHandle to the Clip Model - /// Number of threads - /// Binary image in jpeg format - /// Bytes length of the image - /// SafeHandle to the Embeddings - [DllImport(llavaLibraryName, EntryPoint = "llava_image_embed_make_with_bytes", - CallingConvention = CallingConvention.Cdecl)] - public static extern - SafeLlavaImageEmbedHandle llava_image_embed_make_with_bytes(SafeLlavaModelHandle ctx_clip, int n_threads, - byte[] image_bytes, int image_bytes_length); - - /// - /// Build an image embed from a path to an image filename - /// - /// SafeHandle to the Clip Model - /// Number of threads - /// Image filename (jpeg) to generate embeddings - /// SafeHandle to the embeddings - [DllImport(llavaLibraryName, EntryPoint = "llava_image_embed_make_with_filename", CallingConvention = CallingConvention.Cdecl)] - public static extern - SafeLlavaImageEmbedHandle llava_image_embed_make_with_filename(SafeLlavaModelHandle ctx_clip, int n_threads, - [MarshalAs(UnmanagedType.LPStr)] string image_path); - - /// - /// Free an embedding made with llava_image_embed_make_* - /// - /// Embeddings to release - [DllImport(llavaLibraryName, EntryPoint = "llava_image_embed_free", CallingConvention = CallingConvention.Cdecl)] - public static extern void llava_image_embed_free(IntPtr embed); - - /// - /// Write the image represented by embed into the llama context with batch size n_batch, starting at context - /// pos n_past. on completion, n_past points to the next position in the context after the image embed. - /// - /// Llama Context - /// Embedding handle - /// - /// - /// True on success - [DllImport(llavaLibraryName, EntryPoint = "llava_eval_image_embed", CallingConvention = CallingConvention.Cdecl)] - [return: MarshalAs(UnmanagedType.U1)] - public static extern bool llava_eval_image_embed(SafeLLamaContextHandle ctx_llama, SafeLlavaImageEmbedHandle embed, int n_batch, ref int n_past); - -} \ No newline at end of file diff --git a/LLama/Native/NativeApi.Load.cs b/LLama/Native/NativeApi.Load.cs index 4555ed0d2..57bb2d146 100644 --- a/LLama/Native/NativeApi.Load.cs +++ b/LLama/Native/NativeApi.Load.cs @@ -16,7 +16,7 @@ static NativeApi() // Set flag to indicate that this point has been passed. No native library config can be done after this point. NativeLibraryConfig.LLama.LibraryHasLoaded = true; - NativeLibraryConfig.LLava.LibraryHasLoaded = true; + NativeLibraryConfig.Mtmd.LibraryHasLoaded = true; // Immediately make a call which requires loading the llama DLL. This method call // can't fail unless the DLL hasn't been loaded. 
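+        // Illustrative configuration (not part of this patch; paths are placeholders):
+        // explicit native library paths can be supplied before the first native call,
+        // now covering the mtmd library alongside llama:
+        //
+        //     NativeLibraryConfig.All.WithLibrary("/opt/llama/libllama.so", "/opt/llama/libmtmd.so");
+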
@@ -45,7 +45,7 @@ static NativeApi() #if NET5_0_OR_GREATER private static IntPtr _loadedLlamaHandle; - private static IntPtr _loadedLlavaSharedHandle; + private static IntPtr _loadedMtmdHandle; #endif private static void SetDllImportResolver() @@ -72,15 +72,15 @@ private static void SetDllImportResolver() return _loadedLlamaHandle; } - if (name == "llava_shared") + if (name == "mtmd") { - // If we've already loaded llava return the handle that was loaded last time. - if (_loadedLlavaSharedHandle != IntPtr.Zero) - return _loadedLlavaSharedHandle; + // If we've already loaded Mtmd return the handle that was loaded last time. + if (_loadedMtmdHandle != IntPtr.Zero) + return _loadedMtmdHandle; // Try to load a preferred library, based on CPU feature detection - _loadedLlavaSharedHandle = NativeLibraryUtils.TryLoadLibrary(NativeLibraryConfig.LLava, out _loadedLLavaLibrary); - return _loadedLlavaSharedHandle; + _loadedMtmdHandle = NativeLibraryUtils.TryLoadLibrary(NativeLibraryConfig.Mtmd, out _loadedMtmdLibrary); + return _loadedMtmdHandle; } // Return null pointer to indicate that nothing was loaded. @@ -100,17 +100,17 @@ private static void SetDllImportResolver() return name switch { NativeLibraryName.LLama => _loadedLLamaLibrary, - NativeLibraryName.LLava => _loadedLLavaLibrary, + NativeLibraryName.Mtmd => _loadedMtmdLibrary, _ => throw new ArgumentException($"Library name {name} is not found.") }; } internal const string libraryName = "llama"; - internal const string llavaLibraryName = "llava_shared"; + internal const string mtmdLibraryName = "mtmd"; internal const string ggmlLibraryName = "ggml"; internal const string ggmlBaseLibraryName = "ggml-base"; private static INativeLibrary? _loadedLLamaLibrary = null; - private static INativeLibrary? _loadedLLavaLibrary = null; + private static INativeLibrary? _loadedMtmdLibrary = null; } } diff --git a/LLama/Native/NativeApi.Mtmd.cs b/LLama/Native/NativeApi.Mtmd.cs new file mode 100644 index 000000000..bfd6193c2 --- /dev/null +++ b/LLama/Native/NativeApi.Mtmd.cs @@ -0,0 +1,312 @@ +using System; +using System.Runtime.InteropServices; +using System.Text; + +namespace LLama.Native; + +/// +/// P/Invoke surface for MTMD (multimodal) helpers exposed by llama.cpp. +/// +public static partial class NativeApi +{ + /// + /// Convert a UTF-8 encoded native string pointer into a managed . + /// Returns null when the pointer is zero. + /// + public static string? PtrToStringUtf8(IntPtr ptr) + { + if (ptr == IntPtr.Zero) + return null; + +#if NETSTANDARD2_0 + unsafe + { + var current = (byte*)ptr; + var length = 0; + while (current[length] != 0) + length++; + + if (length == 0) + return string.Empty; + + var buffer = new byte[length]; + Marshal.Copy(ptr, buffer, 0, length); + return Encoding.UTF8.GetString(buffer); + } +#else + return Marshal.PtrToStringUTF8(ptr); +#endif + } + + /// + /// Native context parameters returned by . + /// + [StructLayout(LayoutKind.Sequential)] + internal struct mtmd_context_params + { + [MarshalAs(UnmanagedType.I1)] public bool use_gpu; + [MarshalAs(UnmanagedType.I1)] public bool print_timings; + public int n_threads; + public int verbosity; + public IntPtr image_marker; + public IntPtr media_marker; + } + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_default_marker", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_default_marker(); + + /// + /// Retrieve the default multimodal marker text. + /// + public static string? 
MtmdDefaultMarker() + => PtrToStringUtf8(mtmd_default_marker()); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_context_params_default", CallingConvention = CallingConvention.Cdecl)] + internal static extern mtmd_context_params mtmd_context_params_default(); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_decode_use_non_causal", CallingConvention = CallingConvention.Cdecl)] + [return: MarshalAs(UnmanagedType.I1)] + internal static extern bool mtmd_decode_use_non_causal(IntPtr ctx); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_decode_use_mrope", CallingConvention = CallingConvention.Cdecl)] + [return: MarshalAs(UnmanagedType.I1)] + internal static extern bool mtmd_decode_use_mrope(IntPtr ctx); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_support_vision", CallingConvention = CallingConvention.Cdecl)] + [return: MarshalAs(UnmanagedType.I1)] + internal static extern bool mtmd_support_vision(IntPtr ctx); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_support_audio", CallingConvention = CallingConvention.Cdecl)] + [return: MarshalAs(UnmanagedType.I1)] + internal static extern bool mtmd_support_audio(IntPtr ctx); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_get_audio_bitrate", CallingConvention = CallingConvention.Cdecl)] + internal static extern int mtmd_get_audio_bitrate(IntPtr ctx); + + // bitmap ------------------------------------------------------------ + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_init", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_bitmap_init(uint nx, uint ny, IntPtr data); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_init_from_audio", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_bitmap_init_from_audio(ulong n_samples, IntPtr data); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_get_nx", CallingConvention = CallingConvention.Cdecl)] + internal static extern uint mtmd_bitmap_get_nx(IntPtr bitmap); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_get_ny", CallingConvention = CallingConvention.Cdecl)] + internal static extern uint mtmd_bitmap_get_ny(IntPtr bitmap); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_get_data", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_bitmap_get_data(IntPtr bitmap); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_get_n_bytes", CallingConvention = CallingConvention.Cdecl)] + internal static extern UIntPtr mtmd_bitmap_get_n_bytes(IntPtr bitmap); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_is_audio", CallingConvention = CallingConvention.Cdecl)] + [return: MarshalAs(UnmanagedType.I1)] + internal static extern bool mtmd_bitmap_is_audio(IntPtr bitmap); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_free", CallingConvention = CallingConvention.Cdecl)] + internal static extern void mtmd_bitmap_free(IntPtr bitmap); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_get_id", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_bitmap_get_id(IntPtr bitmap); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_set_id", CallingConvention = CallingConvention.Cdecl)] + private static extern unsafe void mtmd_bitmap_set_id_native(IntPtr bitmap, byte* id); + + /// + /// Assign an identifier to a bitmap using a UTF-8 encoded string. + /// + internal static unsafe void mtmd_bitmap_set_id(IntPtr bitmap, string? 
id) + { + if (bitmap == IntPtr.Zero) + throw new ArgumentNullException(nameof(bitmap)); + + if (id is null) + { + mtmd_bitmap_set_id_native(bitmap, null); + return; + } + + using var pinned = PinnedUtf8String.Create(id) ?? throw new ArgumentNullException(nameof(id)); + mtmd_bitmap_set_id_native(bitmap, (byte*)pinned.Pointer); + } + + // input_chunks ------------------------------------------------------ + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunks_init", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_input_chunks_init(); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunks_size", CallingConvention = CallingConvention.Cdecl)] + internal static extern UIntPtr mtmd_input_chunks_size(IntPtr chunks); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunks_get", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_input_chunks_get(IntPtr chunks, UIntPtr idx); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunks_free", CallingConvention = CallingConvention.Cdecl)] + internal static extern void mtmd_input_chunks_free(IntPtr chunks); + + // input_chunk ------------------------------------------------------- + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunk_get_type", CallingConvention = CallingConvention.Cdecl)] + internal static extern int mtmd_input_chunk_get_type(IntPtr chunk); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunk_get_tokens_text", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_input_chunk_get_tokens_text(IntPtr chunk, out UIntPtr n_tokens); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunk_get_tokens_image", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_input_chunk_get_tokens_image(IntPtr chunk); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunk_get_n_tokens", CallingConvention = CallingConvention.Cdecl)] + internal static extern UIntPtr mtmd_input_chunk_get_n_tokens(IntPtr chunk); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunk_get_id", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_input_chunk_get_id(IntPtr chunk); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunk_get_n_pos", CallingConvention = CallingConvention.Cdecl)] + internal static extern long mtmd_input_chunk_get_n_pos(IntPtr chunk); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunk_copy", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_input_chunk_copy(IntPtr chunk); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunk_free", CallingConvention = CallingConvention.Cdecl)] + internal static extern void mtmd_input_chunk_free(IntPtr chunk); + + // image_tokens ------------------------------------------------------ + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_image_tokens_get_n_tokens", CallingConvention = CallingConvention.Cdecl)] + internal static extern UIntPtr mtmd_image_tokens_get_n_tokens(IntPtr image_tokens); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_image_tokens_get_nx", CallingConvention = CallingConvention.Cdecl)] + internal static extern UIntPtr mtmd_image_tokens_get_nx(IntPtr image_tokens); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_image_tokens_get_ny", CallingConvention = CallingConvention.Cdecl)] + internal static extern UIntPtr mtmd_image_tokens_get_ny(IntPtr image_tokens); + + [DllImport(mtmdLibraryName, EntryPoint = 
"mtmd_image_tokens_get_id", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_image_tokens_get_id(IntPtr image_tokens); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_image_tokens_get_n_pos", CallingConvention = CallingConvention.Cdecl)] + internal static extern long mtmd_image_tokens_get_n_pos(IntPtr image_tokens); + + // tokenize ---------------------------------------------------------- + + /// + /// Native text structure consumed by . + /// + internal unsafe struct mtmd_input_text_native + { + public byte* text; + [MarshalAs(UnmanagedType.I1)] public bool add_special; + [MarshalAs(UnmanagedType.I1)] public bool parse_special; + } + + /// + /// Utility scope that pins managed text while invoking the native tokenizer. + /// + internal readonly unsafe ref struct MtmdInputTextScope + { + public readonly mtmd_input_text_native Value; + private readonly PinnedUtf8String _text; + + public MtmdInputTextScope(string text, bool addSpecial, bool parseSpecial) + { + _text = PinnedUtf8String.Create(text) ?? throw new ArgumentNullException(nameof(text)); + Value = new mtmd_input_text_native + { + text = (byte*)_text.Pointer, + add_special = addSpecial, + parse_special = parseSpecial + }; + } + + public void Dispose() => _text.Dispose(); + } + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_tokenize", CallingConvention = CallingConvention.Cdecl)] + private static extern unsafe int mtmd_tokenize_native( + IntPtr ctx, + IntPtr output, + mtmd_input_text_native* text, + IntPtr[] bitmaps, + UIntPtr n_bitmaps); + + internal static unsafe int mtmd_tokenize(IntPtr ctx, IntPtr output, in mtmd_input_text_native text, IntPtr[] bitmaps, UIntPtr n_bitmaps) + { + var temp = text; + return mtmd_tokenize_native(ctx, output, &temp, bitmaps, n_bitmaps); + } + + internal static unsafe int mtmd_tokenize(IntPtr ctx, IntPtr output, string text, bool addSpecial, bool parseSpecial, IntPtr[] bitmaps, UIntPtr n_bitmaps) + { + using var scope = new MtmdInputTextScope(text, addSpecial, parseSpecial); + return mtmd_tokenize_native(ctx, output, &scope.Value, bitmaps, n_bitmaps); + } + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_encode", CallingConvention = CallingConvention.Cdecl)] + internal static extern int mtmd_encode(IntPtr ctx, IntPtr image_tokens); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_encode_chunk", CallingConvention = CallingConvention.Cdecl)] + internal static extern int mtmd_encode_chunk(IntPtr ctx, IntPtr chunk); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_get_output_embd", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_get_output_embd(IntPtr ctx); + + // helper ------------------------------------------------------------ + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_test_create_input_chunks", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_test_create_input_chunks(); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_helper_bitmap_init_from_file", CallingConvention = CallingConvention.Cdecl)] + private static extern unsafe IntPtr mtmd_helper_bitmap_init_from_file_native(IntPtr ctx, byte* fname); + + internal static unsafe IntPtr mtmd_helper_bitmap_init_from_file(IntPtr ctx, string fname) + { + using var pinned = PinnedUtf8String.Create(fname) ?? 
throw new ArgumentNullException(nameof(fname)); + return mtmd_helper_bitmap_init_from_file_native(ctx, (byte*)pinned.Pointer); + } + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_helper_bitmap_init_from_buf", CallingConvention = CallingConvention.Cdecl)] + internal static extern IntPtr mtmd_helper_bitmap_init_from_buf(IntPtr ctx, IntPtr buf, UIntPtr len); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_helper_get_n_tokens", CallingConvention = CallingConvention.Cdecl)] + internal static extern UIntPtr mtmd_helper_get_n_tokens(IntPtr chunks); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_helper_get_n_pos", CallingConvention = CallingConvention.Cdecl)] + internal static extern long mtmd_helper_get_n_pos(IntPtr chunks); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_helper_eval_chunks", CallingConvention = CallingConvention.Cdecl)] + internal static extern int mtmd_helper_eval_chunks( + IntPtr ctx, + IntPtr lctx, + IntPtr chunks, + long n_past, + int seq_id, + int n_batch, + [MarshalAs(UnmanagedType.I1)] bool logits_last, + ref long new_n_past); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_helper_eval_chunk_single", CallingConvention = CallingConvention.Cdecl)] + internal static extern int mtmd_helper_eval_chunk_single( + IntPtr ctx, + IntPtr lctx, + IntPtr chunk, + long n_past, + int seq_id, + int n_batch, + [MarshalAs(UnmanagedType.I1)] bool logits_last, + ref long new_n_past); + + [DllImport(mtmdLibraryName, EntryPoint = "mtmd_helper_decode_image_chunk", CallingConvention = CallingConvention.Cdecl)] + internal static extern int mtmd_helper_decode_image_chunk( + IntPtr ctx, + IntPtr lctx, + IntPtr chunk, + IntPtr encoded_embd, + long n_past, + int seq_id, + int n_batch, + ref long new_n_past); +} diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs index 4aefc8810..3123674fc 100644 --- a/LLama/Native/NativeApi.cs +++ b/LLama/Native/NativeApi.cs @@ -1,4 +1,5 @@ using System; +using System.Text; #pragma warning disable IDE1006 // Naming Styles @@ -179,7 +180,7 @@ public static unsafe int llama_chat_apply_template(byte* tmpl, LLamaChatMessage* { return internal_llama_chat_apply_template(tmpl, chat, n_msg, add_ass, buf, length); - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl,EntryPoint = "llama_chat_apply_template")] + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_chat_apply_template")] static extern int internal_llama_chat_apply_template(byte* tmpl, LLamaChatMessage* chat, nuint n_msg, [MarshalAs(UnmanagedType.U1)] bool add_ass, byte* buf, int length); } @@ -215,8 +216,7 @@ public static unsafe int llama_chat_apply_template(byte* tmpl, LLamaChatMessage* /// User can skip up to 'lstrip' leading spaces before copying (useful when encoding/decoding multiple tokens with 'add_space_prefix') /// If true, special tokens are rendered in the output /// The length written, or if the buffer is too small a negative that indicates the length required - public static int llama_token_to_piece(SafeLlamaModelHandle.Vocabulary vocab, LLamaToken llamaToken, - Span buffer, int lstrip, bool special) + public static int llama_token_to_piece(SafeLlamaModelHandle.Vocabulary vocab, LLamaToken llamaToken, Span buffer, int lstrip, bool special) { // Handle invalid tokens if ((int)llamaToken < 0) @@ -226,14 +226,12 @@ public static int llama_token_to_piece(SafeLlamaModelHandle.Vocabulary vocab, LL { fixed (byte* bufferPtr = buffer) { - return llama_token_to_piece_native(vocab.VocabNative, llamaToken, bufferPtr, 
buffer.Length, lstrip, - special); + return llama_token_to_piece_native(vocab.VocabNative, llamaToken, bufferPtr, buffer.Length, lstrip, special); } } [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_token_to_piece")] - static extern unsafe int llama_token_to_piece_native(LLamaVocabNative* model, LLamaToken llamaToken, - byte* buffer, int length, int lstrip, [MarshalAs(UnmanagedType.U1)] bool special); + static extern unsafe int llama_token_to_piece_native(LLamaVocabNative* model, LLamaToken llamaToken, byte* buffer, int length, int lstrip, [MarshalAs(UnmanagedType.U1)] bool special); } /// @@ -250,9 +248,7 @@ static extern unsafe int llama_token_to_piece_native(LLamaVocabNative* model, LL /// Returns a negative number on failure - the number of tokens that would have been returned. Returns INT32_MIN on overflow (e.g., tokenization result size exceeds int32_t limit) /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - internal static extern unsafe int llama_tokenize(LLamaVocabNative* model, byte* text, int text_len, - LLamaToken* tokens, int n_max_tokens, [MarshalAs(UnmanagedType.U1)] bool add_special, - [MarshalAs(UnmanagedType.U1)] bool parse_special); + internal static extern unsafe int llama_tokenize(LLamaVocabNative* model, byte* text, int text_len, LLamaToken* tokens, int n_max_tokens, [MarshalAs(UnmanagedType.U1)] bool add_special, [MarshalAs(UnmanagedType.U1)] bool parse_special); /// /// Convert the provided tokens into text (inverse of llama_tokenize()). @@ -266,8 +262,7 @@ internal static extern unsafe int llama_tokenize(LLamaVocabNative* model, byte* /// unparse_special If true, special tokens are rendered in the output. /// Returns the number of chars/bytes on success, no more than textLengthMax. Returns a negative number on failure - the number of chars/bytes that would have been returned. [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - internal static extern unsafe int llama_detokenize(LLamaVocabNative* model, LLamaToken* tokens, int nTokens, - byte* textOut, int textLengthMax, bool removeSpecial, bool unparseSpecial); + internal static extern unsafe int llama_detokenize(LLamaVocabNative* model, LLamaToken* tokens, int nTokens, byte* textOut, int textLengthMax, bool removeSpecial, bool unparseSpecial); /// /// Register a callback to receive llama log messages @@ -278,7 +273,7 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback) { NativeLogConfig.llama_log_set(logCallback); } - + /// /// Allocates a batch of tokens on the heap /// Each token can be assigned up to n_seq_max sequence ids @@ -317,8 +312,7 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback) /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern unsafe int llama_apply_adapter_cvec(SafeLLamaContextHandle ctx, float* data, nuint len, - int n_embd, int il_start, int il_end); + public static extern unsafe int llama_apply_adapter_cvec(SafeLLamaContextHandle ctx, float* data, nuint len, int n_embd, int il_start, int il_end); /// /// Build a split GGUF final path for this chunk. @@ -330,23 +324,115 @@ public static extern unsafe int llama_apply_adapter_cvec(SafeLLamaContextHandle /// /// /// Returns the split_path length. 
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_split_path(string split_path, nuint maxlen, string path_prefix, int split_no, - int split_count); + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_split_path")] + private static extern unsafe int llama_split_path_native(byte* split_path, nuint maxlen, byte* path_prefix, int split_no, int split_count); + + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_split_prefix")] + private static extern unsafe int llama_split_prefix_native(byte* split_prefix, nuint maxlen, byte* split_path, int split_no, int split_count); + + private static byte[] EncodeNullTerminatedUtf8(string value, string paramName) + { + if (value is null) + throw new ArgumentNullException(paramName); + + var bytes = Encoding.UTF8.GetBytes(value); + var buffer = new byte[bytes.Length + 1]; + Buffer.BlockCopy(bytes, 0, buffer, 0, bytes.Length); + // buffer[^1] = 0; + return buffer; + } /// - /// Extract the path prefix from the split_path if and only if the split_no and split_count match. - /// llama_split_prefix(split_prefix, 64, "/models/ggml-model-q4_0-00002-of-00004.gguf", 2, 4) => split_prefix = "/models/ggml-model-q4_0" + /// Build the fully-qualified path for a specific split file in a GGUF shard set. /// - /// - /// - /// - /// - /// - /// Returns the split_prefix length. - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_split_prefix(string split_prefix, nuint maxlen, string split_path, int split_no, - int split_count); + /// Writable buffer that receives the UTF-8 encoded path. + /// Base path (e.g. "/models/ggml-model-q4_0"). + /// Zero-based split index. + /// Total number of splits. + /// Number of bytes written to . + public static int llama_split_path(Span splitPathBuffer, string pathPrefix, int splitNo, int splitCount) + { + if (splitPathBuffer.Length == 0) + throw new ArgumentException("Buffer must not be empty.", nameof(splitPathBuffer)); + + var pathPrefixBytes = EncodeNullTerminatedUtf8(pathPrefix, nameof(pathPrefix)); + + unsafe + { + fixed (byte* splitPtr = splitPathBuffer) + fixed (byte* prefixPtr = pathPrefixBytes) + { + return llama_split_path_native(splitPtr, (nuint)splitPathBuffer.Length, prefixPtr, splitNo, splitCount); + } + } + } + + /// + /// Build the fully-qualified path for a specific split file in a GGUF shard set. + /// + /// Base path (e.g. "/models/ggml-model-q4_0"). + /// Zero-based split index. + /// Total number of splits. + /// Maximum number of bytes to allocate for the resulting UTF-8 string. + /// UTF-8 decoded split path. + public static string llama_split_path(string pathPrefix, int splitNo, int splitCount, int maxLength = 1024) + { + if (maxLength <= 0) + throw new ArgumentOutOfRangeException(nameof(maxLength)); + + var buffer = new byte[maxLength]; + var written = llama_split_path((Span)buffer, pathPrefix, splitNo, splitCount); + if (written <= 0) + throw new InvalidOperationException("Failed to build split path using llama_split_path."); + + return Encoding.UTF8.GetString(buffer, 0, written); + } + + /// + /// Extract the shard prefix from a GGUF split path when the split metadata matches. + /// + /// Writable buffer that receives the UTF-8 encoded prefix. + /// Full path to a shard file. + /// Zero-based split index. + /// Total number of splits. + /// Number of bytes written to . 
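A short usage sketch for the managed split-path wrappers (the model prefix and counts are hypothetical; the exact shard-name pattern is produced by llama.cpp, and `llama_split_prefix` is the string overload defined just below):

```cs
// Hypothetical 4-way shard set rooted at "/models/ggml-model-q4_0".
var path = NativeApi.llama_split_path("/models/ggml-model-q4_0", splitNo: 2, splitCount: 4);
// Produces something like ".../ggml-model-q4_0-0000?-of-00004.gguf" (the numbering convention comes from llama.cpp).

// Round trip: recover the prefix from a full shard path when the split metadata matches.
var prefix = NativeApi.llama_split_prefix(path, splitNo: 2, splitCount: 4);
```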
+ public static int llama_split_prefix(Span splitPrefixBuffer, string splitPath, int splitNo, int splitCount) + { + if (splitPrefixBuffer.Length == 0) + throw new ArgumentException("Buffer must not be empty.", nameof(splitPrefixBuffer)); + + var splitPathBytes = EncodeNullTerminatedUtf8(splitPath, nameof(splitPath)); + + unsafe + { + fixed (byte* prefixPtr = splitPrefixBuffer) + fixed (byte* pathPtr = splitPathBytes) + { + return llama_split_prefix_native(prefixPtr, (nuint)splitPrefixBuffer.Length, pathPtr, splitNo, splitCount); + } + } + } + + /// + /// Extract the shard prefix from a GGUF split path when the split metadata matches. + /// + /// Full path to a shard file. + /// Zero-based split index. + /// Total number of splits. + /// Maximum number of bytes to allocate for the resulting UTF-8 string. + /// UTF-8 decoded split prefix. + public static string llama_split_prefix(string splitPath, int splitNo, int splitCount, int maxLength = 1024) + { + if (maxLength <= 0) + throw new ArgumentOutOfRangeException(nameof(maxLength)); + + var buffer = new byte[maxLength]; + var written = llama_split_prefix((Span)buffer, splitPath, splitNo, splitCount); + if (written <= 0) + throw new InvalidOperationException("Failed to extract split prefix using llama_split_prefix."); + + return Encoding.UTF8.GetString(buffer, 0, written); + } //[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] //todo: public static void llama_attach_threadpool(SafeLLamaContextHandle ctx, ggml_threadpool_t threadpool, ggml_threadpool_t threadpool_batch); @@ -389,41 +475,5 @@ public static extern int llama_split_prefix(string split_prefix, nuint maxlen, s /// Name of the buffer type [DllImport(ggmlBaseLibraryName, CallingConvention = CallingConvention.Cdecl)] public static extern IntPtr ggml_backend_buft_name(IntPtr buft); - - /// - /// - /// - /// - /// - /// - /// - [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern UIntPtr llama_state_seq_get_size_ext(IntPtr ctx, int seq_id, uint flags); - - /// - /// - /// - /// - /// - /// - /// - /// - /// - [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern UIntPtr llama_state_seq_get_data_ext(IntPtr ctx, [Out] byte[] dst, UIntPtr size, - int seq_id, uint flags); - - /// - /// - /// - /// - /// - /// - /// - /// - /// - [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern UIntPtr llama_state_seq_set_data_ext(IntPtr ctx, byte[] src, UIntPtr size, int dest_seq_id, - uint flags); } -} \ No newline at end of file +} diff --git a/LLama/Native/SafeLlavaImageEmbedHandle.cs b/LLama/Native/SafeLlavaImageEmbedHandle.cs deleted file mode 100644 index 102c4b93f..000000000 --- a/LLama/Native/SafeLlavaImageEmbedHandle.cs +++ /dev/null @@ -1,162 +0,0 @@ -using System; -using System.IO; - - -namespace LLama.Native -{ - /// - /// A Reference to a llava Image Embed handle - /// - public sealed class SafeLlavaImageEmbedHandle - : SafeLLamaHandleBase - { - /// - /// Get the model used to create this image embedding - /// - public SafeLlavaModelHandle Model { get; private set; } = null!; - - /// - /// Get the number of dimensions in an embedding - /// - public int EmbeddingDimensions => Model.EmbeddingDimensions; - - /// - /// Get the number of "patches" in an image embedding - /// - public int PatchCount => Model.PatchCount; - - #region embed - /// - /// Create an image embed from an image file - /// - /// - /// - /// Path to the image 
file. Supported formats: - /// - /// JPG - /// PNG - /// BMP - /// TGA - /// - /// - /// - /// - public static SafeLlavaImageEmbedHandle CreateFromFileName(SafeLlavaModelHandle clip, LLamaContext ctx, string image) - { - if (!NativeApi.llava_validate_embed_size(ctx.NativeHandle, clip)) - throw new InvalidOperationException($"Cannot create image embed. Embedding dim of the multimodal projector ({clip.EmbeddingDimensions}) is not equal to embedding dim of model ({ctx.EmbeddingSize})"); - - return CreateFromFileName(clip, image, (int)ctx.BatchThreads); - } - - /// - /// Create an image embed from an image file - /// - /// - /// Path to the image file. Supported formats: - /// - /// JPG - /// PNG - /// BMP - /// TGA - /// - /// - /// - /// - /// - public static SafeLlavaImageEmbedHandle CreateFromFileName(SafeLlavaModelHandle clip, string image, int threads = -1) - { - if (threads <= 0) - threads = Environment.ProcessorCount / 2; - - // Try to open the image file, this will check: - // - File exists (automatically throws FileNotFoundException) - // - File is readable (explicit check) - // This provides better error messages that llama.cpp, which would throw an access violation exception in both cases. - using (var fs = new FileStream(image, FileMode.Open)) - if (!fs.CanRead) - throw new InvalidOperationException($"Llava image file '{image}' is not readable"); - - var embed = NativeApi.llava_image_embed_make_with_filename(clip, threads, image); - embed.Model = clip; - return embed; - } - - /// - /// Create an image embed from the bytes of an image. - /// - /// - /// - /// Image bytes. Supported formats: - /// - /// JPG - /// PNG - /// BMP - /// TGA - /// - /// - /// - public static SafeLlavaImageEmbedHandle CreateFromMemory(SafeLlavaModelHandle clip, LLamaContext ctx, byte[] image) - { - if (!NativeApi.llava_validate_embed_size(ctx.NativeHandle, clip)) - throw new InvalidOperationException($"Cannot create image embed. Embedding dim of the multimodal projector ({clip.EmbeddingDimensions}) is not equal to embedding dim of model ({ctx.EmbeddingSize})"); - - return CreateFromMemory(clip, image, (int)ctx.BatchThreads); - } - - /// - /// Create an image embed from the bytes of an image. - /// - /// - /// Image bytes. 
Supported formats: - /// - /// JPG - /// PNG - /// BMP - /// TGA - /// - /// - /// - /// - public static SafeLlavaImageEmbedHandle CreateFromMemory(SafeLlavaModelHandle clip, byte[] image, int threads = -1) - { - if (threads <= 0) - threads = Environment.ProcessorCount / 2; - - var embed = NativeApi.llava_image_embed_make_with_bytes(clip, threads, image, image.Length); - embed.Model = clip; - return embed; - } - #endregion - - /// - protected override bool ReleaseHandle() - { - NativeApi.llava_image_embed_free(DangerousGetHandle()); - SetHandle(IntPtr.Zero); - return true; - } - - /// - /// Copy the embeddings data to the destination span - /// - /// - /// - public void GetEmbedding(Span dest, int index) - { - if (index < 0) - throw new ArgumentOutOfRangeException(nameof(index), "index must be >= 0"); - if (index >= Model.PatchCount) - throw new ArgumentOutOfRangeException(nameof(index), "index must be < Model.PatchCount"); - - unsafe - { - var embed = (LLavaImageEmbed*)DangerousGetHandle(); - new Span( - embed->embed + Model.EmbeddingDimensions * index, - Model.EmbeddingDimensions - ).CopyTo(dest); - } - } - } -} diff --git a/LLama/Native/SafeLlavaModelHandle.cs b/LLama/Native/SafeLlavaModelHandle.cs deleted file mode 100644 index 5b3a910e9..000000000 --- a/LLama/Native/SafeLlavaModelHandle.cs +++ /dev/null @@ -1,137 +0,0 @@ -using System; -using System.IO; -using LLama.Exceptions; - - -namespace LLama.Native -{ - /// - /// A reference to a set of llava model weights. - /// - public sealed class SafeLlavaModelHandle - : SafeLLamaHandleBase - { - /// - /// Get the number of dimensions in an embedding - /// - public int EmbeddingDimensions => clip_n_mmproj_embd(this); - - /// - /// Get the number of "patches" in an image embedding - /// - public int PatchCount => clip_n_patches(this); - - /// - protected override bool ReleaseHandle() - { - clip_free(DangerousGetHandle()); - SetHandle(IntPtr.Zero); - return true; - } - - /// - /// Load a model from the given file path into memory - /// - /// MMP File (Multi-Modal Projections) - /// Verbosity level - /// SafeHandle of the Clip Model - /// - /// - public static SafeLlavaModelHandle LoadFromFile(string modelPath, int verbosity ) - { - // Try to open the model file, this will check: - // - File exists (automatically throws FileNotFoundException) - // - File is readable (explicit check) - // This provides better error messages that llama.cpp, which would throw an access violation exception in both cases. - using (var fs = new FileStream(modelPath, FileMode.Open)) - if (!fs.CanRead) - throw new InvalidOperationException($"Llava MMP Model file '{modelPath}' is not readable"); - - var handle = clip_model_load(modelPath, verbosity); - if (handle.IsInvalid) - throw new LoadWeightsFailedException(modelPath); - - return handle; - } - - /// - /// Create the Image Embeddings. - /// - /// LLama Context - /// Image filename (it supports jpeg format only) - /// return the SafeHandle of these embeddings - public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, string image) - { - return SafeLlavaImageEmbedHandle.CreateFromFileName(this, ctxLlama, image); - } - - /// - /// Create the Image Embeddings. 
- /// - /// Image in binary format (it supports jpeg format only) - /// Number of threads to use - /// return the SafeHandle of these embeddings - public SafeLlavaImageEmbedHandle CreateImageEmbeddings(string image, int threads = -1) - { - return SafeLlavaImageEmbedHandle.CreateFromFileName(this, image, threads); - } - - /// - /// Create the Image Embeddings. - /// - /// LLama Context - /// Image in binary format (it supports jpeg format only) - /// return the SafeHandle of these embeddings - public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, byte[] image) - { - return SafeLlavaImageEmbedHandle.CreateFromMemory(this, ctxLlama, image ); - } - - /// - /// Create the Image Embeddings. - /// - /// Image in binary format (it supports jpeg format only) - /// Number of threads to use - /// return the SafeHandle of these embeddings - public SafeLlavaImageEmbedHandle CreateImageEmbeddings(byte[] image, int threads = -1) - { - return SafeLlavaImageEmbedHandle.CreateFromMemory(this, image, threads); - } - - /// - /// Evaluates the image embeddings. - /// - /// Llama Context - /// The current embeddings to evaluate - /// - /// True on success - public bool EvalImageEmbed(LLamaContext ctxLlama, SafeLlavaImageEmbedHandle imageEmbed, ref int n_past) - { - return NativeApi.llava_eval_image_embed(ctxLlama.NativeHandle, imageEmbed, (int)ctxLlama.BatchSize, ref n_past ); - } - - #region native API - /// - /// Load MULTI MODAL PROJECTIONS model / Clip Model - /// - /// Model path/file - /// Verbosity level - /// SafeLlavaModelHandle - [DllImport(NativeApi.llavaLibraryName, EntryPoint = "clip_model_load", CallingConvention = CallingConvention.Cdecl)] - private static extern SafeLlavaModelHandle clip_model_load(string mmProj, int verbosity); - - /// - /// Frees MULTI MODAL PROJECTIONS model / Clip Model - /// - /// Internal Pointer to the model - [DllImport(NativeApi.llavaLibraryName, EntryPoint = "clip_free", CallingConvention = CallingConvention.Cdecl)] - private static extern void clip_free(IntPtr ctx); - - [DllImport(NativeApi.llavaLibraryName, CallingConvention = CallingConvention.Cdecl)] - private static extern int clip_n_mmproj_embd(SafeLlavaModelHandle ctx); - - [DllImport(NativeApi.llavaLibraryName, CallingConvention = CallingConvention.Cdecl)] - private static extern int clip_n_patches(SafeLlavaModelHandle ctx); - #endregion - } -} diff --git a/LLama/Native/SafeMtmdEmbed.cs b/LLama/Native/SafeMtmdEmbed.cs new file mode 100644 index 000000000..c651db102 --- /dev/null +++ b/LLama/Native/SafeMtmdEmbed.cs @@ -0,0 +1,247 @@ +using System; +using System.IO; +using System.Runtime.InteropServices; + +namespace LLama.Native +{ + /// + /// Managed wrapper around mtmd_bitmap* resources. Instances own the native pointer + /// and ensure proper cleanup when disposed. + /// + public sealed class SafeMtmdEmbed : IDisposable + { + /// + /// Raw pointer to the native bitmap structure. Internal so other wrappers can interop. + /// + internal IntPtr NativePtr { get; private set; } + + private bool _disposed; + + private SafeMtmdEmbed(IntPtr ptr) + { + NativePtr = ptr != IntPtr.Zero + ? ptr + : throw new InvalidOperationException("Failed to create MTMD bitmap."); + } + + /// + /// Create an embedding from raw RGB bytes. + /// + /// Width of the bitmap in pixels. + /// Height of the bitmap in pixels. + /// Packed RGB data (3 bytes per pixel). + /// Managed wrapper when initialization succeeds; otherwise null. + /// The RGB buffer is null. + public static SafeMtmdEmbed? 
FromRgbBytes(uint nx, uint ny, byte[] rgbData) + { + if (rgbData == null) + throw new ArgumentNullException(nameof(rgbData)); + + var handle = GCHandle.Alloc(rgbData, GCHandleType.Pinned); + try + { + var native = NativeApi.mtmd_bitmap_init(nx, ny, handle.AddrOfPinnedObject()); + return native == IntPtr.Zero ? null : new SafeMtmdEmbed(native); + } + finally + { + if (handle.IsAllocated) + handle.Free(); + } + } + + /// + /// Create an embedding from PCM audio samples. + /// + /// Array of mono PCM samples in float format. + /// Managed wrapper when initialization succeeds; otherwise null. + /// The audio buffer is null. + public static SafeMtmdEmbed? FromAudioSamples(float[] samples) + { + if (samples == null) + throw new ArgumentNullException(nameof(samples)); + + var handle = GCHandle.Alloc(samples, GCHandleType.Pinned); + try + { + var native = NativeApi.mtmd_bitmap_init_from_audio((ulong)samples.Length, handle.AddrOfPinnedObject()); + return native == IntPtr.Zero ? null : new SafeMtmdEmbed(native); + } + finally + { + if (handle.IsAllocated) + handle.Free(); + } + } + + /// + /// Create an embedding by decoding a media file using libmtmd helpers. + /// + /// Model context that provides the decoder configuration. + /// Path to the media file on disk. + /// Managed wrapper when decoding succeeds; otherwise null. + /// The context is null. + /// The path is null or whitespace. + /// The supplied file does not exist. + public static SafeMtmdEmbed? FromMediaFile(SafeMtmdModelHandle mtmdContext, string path) + { + if (mtmdContext == null) + throw new ArgumentNullException(nameof(mtmdContext)); + if (string.IsNullOrWhiteSpace(path)) + throw new ArgumentException("Value cannot be null or whitespace.", nameof(path)); + + var fullPath = Path.GetFullPath(path); + if (!File.Exists(fullPath)) + throw new FileNotFoundException("Media file not found.", fullPath); + + bool added = false; + var ctxPtr = IntPtr.Zero; + try + { + // Hold a strong reference to the native context while the helper decodes the media file. + mtmdContext.DangerousAddRef(ref added); + ctxPtr = mtmdContext.DangerousGetHandle(); + var native = NativeApi.mtmd_helper_bitmap_init_from_file(ctxPtr, fullPath); + return native == IntPtr.Zero ? null : new SafeMtmdEmbed(native); + } + finally + { + if (added) + mtmdContext.DangerousRelease(); + } + } + + /// + /// Create an embedding from an in-memory media buffer (image/audio/video). + /// + /// Model context that provides the decoder configuration. + /// Binary buffer containing the encoded media. + /// Managed wrapper when decoding succeeds; otherwise null. + /// The context is null. + /// The buffer is empty. + public static unsafe SafeMtmdEmbed? FromMediaBuffer(SafeMtmdModelHandle mtmdContext, ReadOnlySpan data) + { + if (mtmdContext == null) + throw new ArgumentNullException(nameof(mtmdContext)); + if (data.IsEmpty) + throw new ArgumentException("Buffer must not be empty.", nameof(data)); + + bool added = false; + var ctxPtr = IntPtr.Zero; + try + { + // Keep the context alive while the native helper processes the buffer. + mtmdContext.DangerousAddRef(ref added); + ctxPtr = mtmdContext.DangerousGetHandle(); + + fixed (byte* bufferPtr = data) + { + var native = NativeApi.mtmd_helper_bitmap_init_from_buf(ctxPtr, new IntPtr(bufferPtr), (UIntPtr)data.Length); + return native == IntPtr.Zero ? 
null : new SafeMtmdEmbed(native); + } + } + finally + { + if (added) + mtmdContext.DangerousRelease(); + } + } + + /// + /// Width of the bitmap in pixels (or number of samples for audio embeddings). + /// + public uint Nx + { + get + { + EnsureNotDisposed(); + return NativeApi.mtmd_bitmap_get_nx(NativePtr); + } + } + + /// + /// Height of the bitmap in pixels. For audio embeddings this is typically 1. + /// + public uint Ny + { + get + { + EnsureNotDisposed(); + return NativeApi.mtmd_bitmap_get_ny(NativePtr); + } + } + + /// + /// Indicates whether the embedding stores audio data instead of image pixels. + /// + public bool IsAudio + { + get + { + EnsureNotDisposed(); + return NativeApi.mtmd_bitmap_is_audio(NativePtr); + } + } + + /// + /// Optional identifier assigned to this embedding. + /// + public string? Id + { + get + { + EnsureNotDisposed(); + var ptr = NativeApi.mtmd_bitmap_get_id(NativePtr); + return NativeApi.PtrToStringUtf8(ptr); + } + set + { + EnsureNotDisposed(); + NativeApi.mtmd_bitmap_set_id(NativePtr, value); + } + } + + /// + /// Zero-copy access to the underlying bitmap bytes. The span remains valid while this wrapper is alive. + /// + /// Read-only span exposing the native data buffer. + /// The embedding has been disposed. + public unsafe ReadOnlySpan GetDataSpan() + { + EnsureNotDisposed(); + + var dataPtr = (byte*)NativeApi.mtmd_bitmap_get_data(NativePtr); + var length = checked((int)NativeApi.mtmd_bitmap_get_n_bytes(NativePtr).ToUInt64()); + return dataPtr == null || length == 0 ? ReadOnlySpan.Empty : new ReadOnlySpan(dataPtr, length); + } + + /// + /// Release the underlying native bitmap. + /// + public void Dispose() + { + if (_disposed) + return; + + if (NativePtr != IntPtr.Zero) + { + NativeApi.mtmd_bitmap_free(NativePtr); + NativePtr = IntPtr.Zero; + } + + _disposed = true; + GC.SuppressFinalize(this); + } + + /// + /// Finalizer to ensure native resources are reclaimed when Dispose is not invoked. + /// + ~SafeMtmdEmbed() => Dispose(); + + private void EnsureNotDisposed() + { + if (_disposed || NativePtr == IntPtr.Zero) + throw new ObjectDisposedException(nameof(SafeMtmdEmbed)); + } + } +} diff --git a/LLama/Native/SafeMtmdInputChunk.cs b/LLama/Native/SafeMtmdInputChunk.cs new file mode 100644 index 000000000..59d1897ef --- /dev/null +++ b/LLama/Native/SafeMtmdInputChunk.cs @@ -0,0 +1,150 @@ +using System; +using System.Runtime.InteropServices; + +namespace LLama.Native; + +/// +/// Managed wrapper around a single mtmd_input_chunk. Instances can either own the +/// underlying native pointer (when created via ) or act as non-owning views +/// produced by the tokenizer. +/// +public sealed class SafeMtmdInputChunk : IDisposable +{ + /// + /// Chunk modality returned by the native tokenizer. + /// + public enum SafeMtmdInputChunkType + { + Text = 0, + Image = 1, + Audio = 2 + } + + /// + /// Raw pointer to the native chunk structure. + /// + public IntPtr NativePtr { get; private set; } + + private bool _ownsPtr; + private bool _disposed; + + private SafeMtmdInputChunk(IntPtr ptr, bool owns) + { + NativePtr = ptr; + _ownsPtr = owns; + } + + /// + /// Wrap an existing chunk pointer without taking ownership. + /// + /// Pointer returned by the native tokenizer. + /// Managed wrapper, or null when the pointer is null. + public static SafeMtmdInputChunk Wrap(IntPtr ptr) + => ptr == IntPtr.Zero ? null : new SafeMtmdInputChunk(ptr, false); + + /// + /// Create an owning copy of the current chunk. The caller becomes responsible for disposal. 
+ /// + /// Owning managed wrapper, or null if the native copy failed. + /// Thrown when the current wrapper has been disposed. + public SafeMtmdInputChunk Copy() + { + EnsureNotDisposed(); + + var p = NativeApi.mtmd_input_chunk_copy(NativePtr); + return p == IntPtr.Zero ? null : new SafeMtmdInputChunk(p, true); + } + + /// + /// Chunk modality reported by the native helper. + /// + public SafeMtmdInputChunkType Type + { + get + { + EnsureNotDisposed(); + return (SafeMtmdInputChunkType)NativeApi.mtmd_input_chunk_get_type(NativePtr); + } + } + + /// + /// Number of tokens contained in this chunk. + /// + public ulong NTokens + { + get + { + EnsureNotDisposed(); + return NativeApi.mtmd_input_chunk_get_n_tokens(NativePtr).ToUInt64(); + } + } + + /// + /// Identifier assigned by the tokenizer (if any). + /// + public string Id + { + get + { + EnsureNotDisposed(); + return Marshal.PtrToStringAnsi(NativeApi.mtmd_input_chunk_get_id(NativePtr)) ?? string.Empty; + } + } + + /// + /// Number of positional slots consumed by this chunk. + /// + public long NPos + { + get + { + EnsureNotDisposed(); + return NativeApi.mtmd_input_chunk_get_n_pos(NativePtr); + } + } + + /// + /// Zero-copy view over the chunk's token buffer. The span remains valid only while the native chunk is alive. + /// + /// Read-only span exposing the chunk's tokens. + /// Thrown when the wrapper has been disposed. + public unsafe ReadOnlySpan GetTextTokensSpan() + { + EnsureNotDisposed(); + + UIntPtr n; + var p = (uint*)NativeApi.mtmd_input_chunk_get_tokens_text(NativePtr, out n); + return p == null ? ReadOnlySpan.Empty : new ReadOnlySpan(p, checked((int)n.ToUInt64())); + } + + /// + /// Release the underlying native resources if this instance owns them. + /// + public void Dispose() + { + if (_disposed) + return; + + if (_ownsPtr && NativePtr != IntPtr.Zero) + { + NativeApi.mtmd_input_chunk_free(NativePtr); + } + + NativePtr = IntPtr.Zero; + _ownsPtr = false; + _disposed = true; + + GC.SuppressFinalize(this); + } + + /// + /// Finalizer to ensure native memory is reclaimed when Dispose is not called by owners. + /// + ~SafeMtmdInputChunk() => Dispose(); + + private void EnsureNotDisposed() + { + if (_disposed || NativePtr == IntPtr.Zero) + throw new ObjectDisposedException(nameof(SafeMtmdInputChunk)); + } +} diff --git a/LLama/Native/SafeMtmdInputChunks.cs b/LLama/Native/SafeMtmdInputChunks.cs new file mode 100644 index 000000000..2081cd0a6 --- /dev/null +++ b/LLama/Native/SafeMtmdInputChunks.cs @@ -0,0 +1,103 @@ +using System; +using System.Collections.Generic; + +namespace LLama.Native; + +/// +/// Managed lifetime wrapper around a native mtmd_input_chunks collection returned by the tokenizer. +/// +public sealed class SafeMtmdInputChunks : IDisposable +{ + /// + /// Raw pointer to the native chunk collection. Internal to allow other wrappers to interop safely. + /// + internal IntPtr NativePtr { get; private set; } + + private bool _disposed; + + internal SafeMtmdInputChunks(IntPtr ptr) + { + NativePtr = ptr; + } + + /// + /// Releases the native chunk collection and suppresses finalization. + /// + public void Dispose() + { + if (_disposed) + return; + + if (NativePtr != IntPtr.Zero) + { + NativeApi.mtmd_input_chunks_free(NativePtr); + NativePtr = IntPtr.Zero; + } + + _disposed = true; + GC.SuppressFinalize(this); + } + + /// + /// Finalizer to ensure native memory is reclaimed if Dispose is not called. + /// + ~SafeMtmdInputChunks() + { + Dispose(); + } + + /// + /// Number of chunks currently held by the native collection. 
+    ///
+    public ulong Size
+    {
+        get
+        {
+            EnsureNotDisposed();
+            return NativeApi.mtmd_input_chunks_size(NativePtr).ToUInt64();
+        }
+    }
+
+    ///
+    /// Get a raw pointer to a chunk. The returned IntPtr is the mtmd_input_chunk*.
+    /// Use SafeMtmdInputChunk.Wrap to create a managed wrapper if desired.
+    ///
+    /// Zero-based index of the chunk to retrieve.
+    /// Pointer to the requested chunk.
+    /// The collection has already been disposed.
+    /// The requested index is outside of the valid range.
+    public IntPtr GetChunkPtr(ulong index)
+    {
+        EnsureNotDisposed();
+
+        if (index >= Size) throw new IndexOutOfRangeException();
+        return NativeApi.mtmd_input_chunks_get(NativePtr, (UIntPtr)index);
+    }
+
+    ///
+    /// Enumerate the contained chunks as non-owning wrappers. Callers should dispose the returned chunk
+    /// if they create a copy.
+    ///
+    /// Enumeration of chunk wrappers backed by the native collection.
+    /// The collection has already been disposed.
+    public IEnumerable<SafeMtmdInputChunk> Enumerate()
+    {
+        EnsureNotDisposed();
+
+        for (ulong i = 0; i < Size; i++)
+        {
+            var chunk = SafeMtmdInputChunk.Wrap(GetChunkPtr(i));
+            if (chunk != null)
+            {
+                // Yield a lightweight wrapper; ownership remains with the native collection.
+                yield return chunk;
+            }
+        }
+    }
+
+    private void EnsureNotDisposed()
+    {
+        if (_disposed || NativePtr == IntPtr.Zero)
+            throw new ObjectDisposedException(nameof(SafeMtmdInputChunks));
+    }
+}
diff --git a/LLama/Native/SafeMtmdModelHandle.cs b/LLama/Native/SafeMtmdModelHandle.cs
new file mode 100644
index 000000000..236a22011
--- /dev/null
+++ b/LLama/Native/SafeMtmdModelHandle.cs
@@ -0,0 +1,349 @@
+using System;
+using System.Collections.Generic;
+using System.IO;
+using LLama.Exceptions;
+
+
+namespace LLama.Native
+{
+    ///
+    /// Wrapper around the multimodal (MTMD) weights handle. This wrapper manages the low-level
+    /// operations.
+    ///
+    public sealed class SafeMtmdModelHandle : SafeLLamaHandleBase
+    {
+        // Pending media embeddings queued for the next call to Tokenize.
+        private readonly List<SafeMtmdEmbed> _pendingMedia = new();
+
+        ///
+        protected override bool ReleaseHandle()
+        {
+            mtmd_free(DangerousGetHandle());
+            SetHandle(IntPtr.Zero);
+            return true;
+        }
+
+        ///
+        /// Load a multimodal projection model from disk and bind it to the supplied text model.
+        ///
+        /// Path to the MMP (Multi-Modal Projections) file.
+        /// Text model that provides tokenizer weights for the multimodal helper.
+        /// Optional context parameters; defaults are used when null.
+        /// Safe handle for the MTMD model.
+        /// The file exists but is not readable by the current process.
+        /// The native loader failed to initialize the MTMD model.
+        public static SafeMtmdModelHandle LoadFromFile(string modelPath, LLamaWeights textModel, MtmdContextParams mtmdCtxParams)
+        {
+            // Try to open the model file, this will check:
+            // - File exists (automatically throws FileNotFoundException)
+            // - File is readable (explicit check)
+            // This provides better error messages than llama.cpp, which would throw an access violation exception in both cases.
+            using (var fs = new FileStream(modelPath, FileMode.Open))
+                if (!fs.CanRead)
+                    throw new InvalidOperationException($"Mtmd MMP Model file '{modelPath}' is not readable");
+
+            using var pathUtf8 = PinnedUtf8String.Create(modelPath) ?? throw new ArgumentNullException(nameof(modelPath));
+
+            unsafe
+            {
+                SafeMtmdModelHandle handle;
+                if (mtmdCtxParams is null)
+                {
+                    var nativeParams = NativeApi.mtmd_context_params_default();
+                    handle = mtmd_init_from_file((byte*)pathUtf8.Pointer, textModel.NativeHandle, nativeParams);
+                }
+                else
+                {
+                    using var nativeParamsScope = mtmdCtxParams.ToNativeScope();
+                    handle = mtmd_init_from_file((byte*)pathUtf8.Pointer, textModel.NativeHandle, nativeParamsScope.Value);
+                }
+
+                if (handle.IsInvalid)
+                    throw new LoadWeightsFailedException(modelPath);
+
+                return handle;
+            }
+        }
+
+        ///
+        /// Load media from disk and queue it for the next tokenize call.
+        ///
+        /// Absolute or relative path to the media asset.
+        /// Safe handle to the media embedding.
+        /// The model handle has been disposed.
+        /// The native loader failed to ingest the file.
+        public SafeMtmdEmbed LoadMediaFromFile(string path)
+        {
+            EnsureNotDisposed();
+
+            var embed = SafeMtmdEmbed.FromMediaFile(this, path)
+                        ?? throw new RuntimeError($"Failed to load media '{path}'.");
+            _pendingMedia.Add(embed);
+            return embed;
+        }
+
+        ///
+        /// Load media from an in-memory buffer and queue it for the next tokenize call.
+        ///
+        /// Binary buffer containing the encoded media data.
+        /// Safe handle to the media embedding.
+        /// The model handle has been disposed.
+        /// The native loader failed to ingest the buffer contents.
+        public SafeMtmdEmbed LoadMediaFromBuffer(ReadOnlySpan<byte> buffer)
+        {
+            EnsureNotDisposed();
+
+            var embed = SafeMtmdEmbed.FromMediaBuffer(this, buffer)
+                        ?? throw new RuntimeError("Failed to load media from buffer.");
+            _pendingMedia.Add(embed);
+            return embed;
+        }
+
+        ///
+        /// Disposes and clears any media buffers currently queued for tokenization.
+        ///
+        public void ClearMedia()
+        {
+            foreach (var media in _pendingMedia)
+                media.Dispose();
+            _pendingMedia.Clear();
+        }
+
+        ///
+        /// Tokenize a prompt alongside the pending media buffers. Pending media is disposed and cleared whether or not tokenization succeeds.
+        ///
+        /// Prompt text to tokenize.
+        /// Whether to append special tokens automatically.
+        /// Whether special tokens should be treated as user-provided text.
+        /// Receives the native chunk collection when tokenization succeeds.
+        /// Zero on success; otherwise the native mtmd tokenize error code.
+        /// The model handle has been disposed.
+        public int Tokenize(string text, bool addSpecial, bool parseSpecial, out SafeMtmdInputChunks? chunks)
+        {
+            EnsureNotDisposed();
+
+            chunks = null;
+            // Allocate the chunk container before invoking the native tokenizer.
+            var output = NativeApi.mtmd_input_chunks_init();
+            if (output == IntPtr.Zero)
+                throw new RuntimeError("Failed to allocate mtmd_input_chunks.");
+
+            // Collect native pointers to the queued media embeddings.
+            var bitmapHandles = new IntPtr[_pendingMedia.Count];
+            for (var i = 0; i < _pendingMedia.Count; i++)
+                bitmapHandles[i] = _pendingMedia[i].NativePtr;
+
+            var result = NativeApi.mtmd_tokenize(DangerousGetHandle(), output, text, addSpecial, parseSpecial, bitmapHandles, (UIntPtr)bitmapHandles.Length);
+
+            if (result == 0)
+                chunks = new SafeMtmdInputChunks(output);
+            else
+                NativeApi.mtmd_input_chunks_free(output);
+
+            // The queued media has been consumed (on success) or is no longer needed (on failure).
+            foreach (var media in _pendingMedia)
+                media.Dispose();
+            _pendingMedia.Clear();
+
+            return result;
+        }
+
+        ///
+        /// Evaluate a batch of chunks using the helper (mirrors mtmd-helper eval logic).
+        ///
+        /// Chunk collection produced by Tokenize.
+        /// Context handle that receives the evaluated tokens.
+        /// Number of past tokens; updated when evaluation succeeds.
+        /// Sequence identifier used for KV cache management.
+        /// Maximum number of tokens to evaluate in a single batch.
+        /// Whether to request logits for the last token only.
+        /// Zero on success; otherwise the native helper error code.
+        /// Thrown when required handles are null.
+        public int EvaluateChunks(SafeMtmdInputChunks chunks, SafeLLamaContextHandle llamaContext, ref long nPast, int seqId, int nBatch, bool logitsLast)
+        {
+            EnsureNotDisposed();
+
+            if (chunks == null)
+                throw new ArgumentNullException(nameof(chunks));
+            if (llamaContext == null)
+                throw new ArgumentNullException(nameof(llamaContext));
+
+            var newNPast = nPast;
+            var result = NativeApi.mtmd_helper_eval_chunks(
+                DangerousGetHandle(),
+                llamaContext.DangerousGetHandle(),
+                chunks.NativePtr,
+                nPast,
+                seqId,
+                nBatch,
+                logitsLast,
+                ref newNPast);
+
+            if (result == 0)
+                nPast = newNPast;
+
+            return result;
+        }
+
+        ///
+        /// Evaluate a single chunk using the helper.
+        ///
+        /// Pointer to the chunk to evaluate.
+        /// Context handle that receives the evaluated tokens.
+        /// Number of past tokens; updated when evaluation succeeds.
+        /// Sequence identifier used for KV cache management.
+        /// Maximum number of tokens to evaluate in a single batch.
+        /// Whether to request logits for the last token only.
+        /// Zero on success; otherwise the native helper error code.
+        /// Thrown when required handles are null.
+        public int EvaluateChunk(IntPtr chunkPtr, SafeLLamaContextHandle llamaContext, ref long nPast, int seqId, int nBatch, bool logitsLast)
+        {
+            EnsureNotDisposed();
+
+            if (chunkPtr == IntPtr.Zero)
+                throw new ArgumentNullException(nameof(chunkPtr));
+            if (llamaContext == null)
+                throw new ArgumentNullException(nameof(llamaContext));
+
+            var newNPast = nPast;
+            var result = NativeApi.mtmd_helper_eval_chunk_single(
+                DangerousGetHandle(),
+                llamaContext.DangerousGetHandle(),
+                chunkPtr,
+                nPast,
+                seqId,
+                nBatch,
+                logitsLast,
+                ref newNPast);
+
+            if (result == 0)
+                nPast = newNPast;
+
+            return result;
+        }
+
+        ///
+        /// Decode a prepared image chunk whose embedding is already computed.
+        ///
+        /// Pointer to the chunk whose embedding should be decoded.
+        /// Context handle used for decoding.
+        /// Pointer to the pre-computed embedding data.
+        /// Number of past tokens; updated when evaluation succeeds.
+        /// Sequence identifier used for KV cache management.
+        /// Maximum number of tokens to evaluate in a single batch.
+        /// Zero on success; otherwise the native helper error code.
+        /// Thrown when required handles are null.
+        public int DecodeImageChunk(IntPtr chunkPtr, SafeLLamaContextHandle llamaContext, IntPtr encodedEmbeddings, ref long nPast, int seqId, int nBatch)
+        {
+            EnsureNotDisposed();
+
+            if (chunkPtr == IntPtr.Zero)
+                throw new ArgumentNullException(nameof(chunkPtr));
+
+            var newNPast = nPast;
+            var result = NativeApi.mtmd_helper_decode_image_chunk(
+                DangerousGetHandle(),
+                llamaContext?.DangerousGetHandle() ?? throw new ArgumentNullException(nameof(llamaContext)),
+                chunkPtr,
+                encodedEmbeddings,
+                nPast,
+                seqId,
+                nBatch,
+                ref newNPast);
+
+            if (result == 0)
+                nPast = newNPast;
+
+            return result;
+        }
+
+        ///
+        /// Get the number of tokens contained in the provided chunk collection.
+        ///
+        /// Chunk collection produced by Tokenize.
+        /// Total token count.
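Putting the pieces together, a hypothetical end-to-end flow for the handle (model and context creation elided; `mtmd` is assumed to be a loaded `SafeMtmdModelHandle`, `ctx` a valid `SafeLLamaContextHandle`, and the media file and prompt are placeholders):

```cs
// Queue an image, tokenize the prompt around the default media marker, then
// evaluate the resulting chunks against the llama context.
mtmd.LoadMediaFromFile("demo.jpg"); // hypothetical file; queued for the next Tokenize call
var marker = NativeApi.MtmdDefaultMarker() ?? "<__media__>"; // fallback string is a guess

if (mtmd.Tokenize($"{marker} What is in this picture?", addSpecial: true, parseSpecial: true, out var chunks) == 0)
{
    using (chunks)
    {
        long nPast = 0;
        var rc = mtmd.EvaluateChunks(chunks!, ctx, ref nPast, seqId: 0, nBatch: 512, logitsLast: true);
        // rc == 0 on success; nPast then points one past the evaluated positions
    }
}
```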
+ public ulong CountTokens(SafeMtmdInputChunks chunks) + { + if (chunks == null) + throw new ArgumentNullException(nameof(chunks)); + return NativeApi.mtmd_helper_get_n_tokens(chunks.NativePtr).ToUInt64(); + } + + /// + /// Get the number of positions contained in the provided chunk collection. + /// + /// Chunk collection produced by . + /// Total number of positional slots consumed. + public long CountPositions(SafeMtmdInputChunks chunks) + { + if (chunks == null) + throw new ArgumentNullException(nameof(chunks)); + return NativeApi.mtmd_helper_get_n_pos(chunks.NativePtr); + } + + #region native API + + // mtmd_init_from_file(const char * mmproj_fname, const struct llama_model * text_model, const struct mtmd_context_params ctx_params); + // The llama_model layout is opaque; expose it via SafeLlamaModelHandle to match the managed wrapper. + [DllImport(NativeApi.mtmdLibraryName, EntryPoint = "mtmd_init_from_file", CallingConvention = CallingConvention.Cdecl)] + private static extern unsafe SafeMtmdModelHandle mtmd_init_from_file( + byte* mmproj_fname, + SafeLlamaModelHandle text_model, + NativeApi.mtmd_context_params @ctx_params); + + [DllImport(NativeApi.mtmdLibraryName, EntryPoint = "mtmd_free", CallingConvention = CallingConvention.Cdecl)] + internal static extern void mtmd_free(IntPtr ctx); + + #endregion + + + + /// + /// Finalizer to ensure native resources are released if Dispose was not called. + /// + ~SafeMtmdModelHandle() + { + Dispose(); + } + + /// + /// Indicates whether the model decodes using the non-causal path. + /// + public bool DecodeUseNonCausal() => NativeApi.mtmd_decode_use_non_causal(handle); + + /// + /// Indicates whether the model decodes using multi-scale RoPE. + /// + public bool DecodeUseMRope() => NativeApi.mtmd_decode_use_mrope(handle); + + /// + /// Indicates whether the model supports vision inputs. + /// + public bool SupportVision() => NativeApi.mtmd_support_vision(handle); + + /// + /// Indicates whether the model supports audio inputs. + /// + public bool SupportAudio() => NativeApi.mtmd_support_audio(handle); + + /// + /// Gets the audio bitrate advertised by the model. + /// + public int GetAudioBitrate() => NativeApi.mtmd_get_audio_bitrate(handle); + + private void EnsureNotDisposed() + { + if (IsInvalid || IsClosed) + throw new ObjectDisposedException(nameof(SafeMtmdModelHandle)); + } + } +} diff --git a/LLama/Properties/InternalsVisibleTo.cs b/LLama/Properties/InternalsVisibleTo.cs new file mode 100644 index 000000000..b0a1ac4be --- /dev/null +++ b/LLama/Properties/InternalsVisibleTo.cs @@ -0,0 +1,3 @@ +using System.Runtime.CompilerServices; + +[assembly: InternalsVisibleTo("LLama.Unittest")] diff --git a/LLama/SafeMtmdWeights.cs b/LLama/SafeMtmdWeights.cs new file mode 100644 index 000000000..e490049b4 --- /dev/null +++ b/LLama/SafeMtmdWeights.cs @@ -0,0 +1,80 @@ + +using System; +using System.Threading; +using System.Threading.Tasks; +using LLama.Native; + +namespace LLama; + +/// +/// Lightweight wrapper around the MTMD native context and its helpers. +/// +public sealed class SafeMtmdWeights : IDisposable +{ + public SafeMtmdModelHandle NativeHandle { get; } + + private SafeMtmdWeights(SafeMtmdModelHandle handle) + { + NativeHandle = handle ?? 
throw new ArgumentNullException(nameof(handle)); + } + + public static SafeMtmdWeights LoadFromFile(string mmProject, LLamaWeights textModel, MtmdContextParams mtmdCtxParams) + { + if (mmProject == null) throw new ArgumentNullException(nameof(mmProject)); + if (textModel == null) throw new ArgumentNullException(nameof(textModel)); + if (mtmdCtxParams == null) throw new ArgumentNullException(nameof(mtmdCtxParams)); + + var handle = SafeMtmdModelHandle.LoadFromFile(mmProject, textModel, mtmdCtxParams); + return new SafeMtmdWeights(handle); + } + + public static Task LoadFromFileAsync(string mmProject, LLamaWeights textModel, MtmdContextParams mtmdCtxParams, CancellationToken token = default) + { + return Task.Run(() => LoadFromFile(mmProject, textModel, mtmdCtxParams), token); + } + + /// + /// Load media from disk and keep it pending for the next tokenize call. + /// + public SafeMtmdEmbed LoadMedia(string path) => NativeHandle.LoadMediaFromFile(path); + + /// + /// Load media from an in-memory buffer and keep it pending for the next tokenize call. + /// + public SafeMtmdEmbed LoadMedia(ReadOnlySpan data) => NativeHandle.LoadMediaFromBuffer(data); + + /// + /// Clear any pending media buffers before or after tokenization. + /// + public void ClearMedia() => NativeHandle.ClearMedia(); + + /// + /// Tokenize text (with optional special tokens) against the pending media buffers. + /// + public int Tokenize(string text, bool addSpecial, bool parseSpecial, out SafeMtmdInputChunks? chunks) + => NativeHandle.Tokenize(text, addSpecial, parseSpecial, out chunks); + + /// + /// Evaluate a chunk batch using the helper that performs mtmd encode + llama decode. + /// + public int EvaluateChunks(SafeMtmdInputChunks chunks, SafeLLamaContextHandle llamaContext, ref long nPast, int seqId, int nBatch, bool logitsLast) + => NativeHandle.EvaluateChunks(chunks, llamaContext, ref nPast, seqId, nBatch, logitsLast); + + public int EvaluateChunk(IntPtr chunkPtr, SafeLLamaContextHandle llamaContext, ref long nPast, int seqId, int nBatch, bool logitsLast) + => NativeHandle.EvaluateChunk(chunkPtr, llamaContext, ref nPast, seqId, nBatch, logitsLast); + + public int DecodeImageChunk(IntPtr chunkPtr, SafeLLamaContextHandle llamaContext, IntPtr encodedEmbeddings, ref long nPast, int seqId, int nBatch) + => NativeHandle.DecodeImageChunk(chunkPtr, llamaContext, encodedEmbeddings, ref nPast, seqId, nBatch); + + public ulong CountTokens(SafeMtmdInputChunks chunks) => NativeHandle.CountTokens(chunks); + + public long CountPositions(SafeMtmdInputChunks chunks) => NativeHandle.CountPositions(chunks); + + public bool SupportsVision => NativeHandle.SupportVision(); + public bool SupportsAudio => NativeHandle.SupportAudio(); + public bool UsesNonCausalAttention => NativeHandle.DecodeUseNonCausal(); + public bool UsesMRope => NativeHandle.DecodeUseMRope(); + public int AudioBitrate => NativeHandle.GetAudioBitrate(); + + public void Dispose() => NativeHandle.Dispose(); +} diff --git a/docs/Examples/LLavaInteractiveModeExecute.md b/docs/Examples/LLavaInteractiveModeExecute.md deleted file mode 100644 index 2bfbbea1d..000000000 --- a/docs/Examples/LLavaInteractiveModeExecute.md +++ /dev/null @@ -1,129 +0,0 @@ -# LLaVA - basic - -```cs -using System.Text.RegularExpressions; -using LLama.Common; -using Spectre.Console; -using LLama.Native; - -namespace LLama.Examples.Examples -{ - // This example shows how to chat with LLaVA model with both image and text as input. - // It uses the interactive executor to inference. 
- public class LlavaInteractiveModeExecute - { - public static async Task Run() - { - string multiModalProj = UserSettings.GetMMProjPath(); - string modelPath = UserSettings.GetModelPath(); - string modelImage = UserSettings.GetImagePath(); - const int maxTokens = 1024; - - var prompt = $"{{{modelImage}}}\nUSER:\nProvide a full description of the image.\nASSISTANT:\n"; - - var parameters = new ModelParams(modelPath); - - using var model = LLamaWeights.LoadFromFile(parameters); - using var context = model.CreateContext(parameters); - - // Llava Init - using var clipModel = LLavaWeights.LoadFromFile(multiModalProj); - - var ex = new InteractiveExecutor(context, clipModel ); - - Console.ForegroundColor = ConsoleColor.Yellow; - Console.WriteLine("The executor has been enabled. In this example, the prompt is printed, the maximum tokens is set to {0} and the context size is {1}.", maxTokens, parameters.ContextSize ); - Console.WriteLine("To send an image, enter its filename in curly braces, like this {c:/image.jpg}."); - - var inferenceParams = new InferenceParams() { Temperature = 0.1f, AntiPrompts = new List { "\nUSER:" }, MaxTokens = maxTokens }; - - do - { - - // Evaluate if we have images - // - var imageMatches = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value); - var imageCount = imageMatches.Count(); - var hasImages = imageCount > 0; - - if (hasImages) - { - var imagePathsWithCurlyBraces = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value); - var imagePaths = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Groups[1].Value).ToList(); - - List imageBytes; - try - { - imageBytes = imagePaths.Select(File.ReadAllBytes).ToList(); - } - catch (IOException exception) - { - Console.ForegroundColor = ConsoleColor.Red; - Console.Write( - $"Could not load your {(imageCount == 1 ? 
"image" : "images")}:"); - Console.Write($"{exception.Message}"); - Console.ForegroundColor = ConsoleColor.Yellow; - Console.WriteLine("Please try again."); - break; - } - - // Each prompt with images we clear cache - // When the prompt contains images we clear KV_CACHE to restart conversation - // See: - // https://github.com/ggerganov/llama.cpp/discussions/3620 - ex.Context.NativeHandle.KvCacheRemove( LLamaSeqId.Zero, -1, -1 ); - - int index = 0; - foreach (var path in imagePathsWithCurlyBraces) - { - // First image replace to tag " : ""); - } - - - Console.ForegroundColor = ConsoleColor.Yellow; - Console.WriteLine($"Here are the images, that are sent to the chat model in addition to your message."); - Console.WriteLine(); - - foreach (var consoleImage in imageBytes?.Select(bytes => new CanvasImage(bytes))) - { - consoleImage.MaxWidth = 50; - AnsiConsole.Write(consoleImage); - } - - Console.WriteLine(); - Console.ForegroundColor = ConsoleColor.Yellow; - Console.WriteLine($"The images were scaled down for the console only, the model gets full versions."); - Console.WriteLine($"Write /exit or press Ctrl+c to return to main menu."); - Console.WriteLine(); - - - // Initialize Images in executor - // - foreach (var image in imagePaths) - { - ex.Images.Add(await File.ReadAllBytesAsync(image)); - } - } - - Console.ForegroundColor = Color.White; - await foreach (var text in ex.InferAsync(prompt, inferenceParams)) - { - Console.Write(text); - } - Console.Write(" "); - Console.ForegroundColor = ConsoleColor.Green; - prompt = Console.ReadLine(); - Console.WriteLine(); - - // let the user finish with exit - // - if (prompt != null && prompt.Equals("/exit", StringComparison.OrdinalIgnoreCase)) - break; - - } - while(true); - } - } -} -``` \ No newline at end of file diff --git a/docs/Examples/MtmdInteractiveModeExecute.md b/docs/Examples/MtmdInteractiveModeExecute.md new file mode 100644 index 000000000..378c93a1b --- /dev/null +++ b/docs/Examples/MtmdInteractiveModeExecute.md @@ -0,0 +1,41 @@ +# MTMD interactive mode + +`MtmdInteractiveModeExecute` shows how to pair a multimodal projection with a text model so the chat loop can reason over images supplied at runtime. The sample lives in `LLama.Examples/Examples/MtmdInteractiveModeExecute.cs` and reuses the interactive executor provided by LLamaSharp. + +## Workflow +- Resolve the model, multimodal projection, and sample image paths via `UserSettings`. +- Create `ModelParams` for the text model and capture the MTMD defaults with `MtmdContextParams.Default()`. +- Load the base model and context, then initialize `SafeMtmdWeights` with the multimodal projection file. +- Ask the helper for a media marker (`mtmdParameters.MediaMarker ?? NativeApi.MtmdDefaultMarker() ?? ""`) and feed it into an `InteractiveExecutor`. + +```cs +var mtmdParameters = MtmdContextParams.Default(); + +using var model = await LLamaWeights.LoadFromFileAsync(parameters); +using var context = model.CreateContext(parameters); + +// Mtmd Init +using var clipModel = await SafeMtmdWeights.LoadFromFileAsync( + multiModalProj, + model, + mtmdParameters); + +var mediaMarker = mtmdParameters.MediaMarker + ?? NativeApi.MtmdDefaultMarker() + ?? ""; + +var ex = new InteractiveExecutor(context, clipModel); +``` + +## Handling user input +- Prompts can include image paths wrapped in braces (for example `{c:/image.jpg}`); the loop searches for those markers with regular expressions. 
+- Every referenced file is loaded through `SafeMtmdWeights.LoadMedia`, producing `SafeMtmdEmbed` instances that are queued for the next tokenization call. +- When the user provides images, the executor clears its KV cache (`MemorySequenceRemove`) before replacing each brace-wrapped path in the prompt with the multimodal marker. +- The embeds collected for the current turn are copied into `ex.Embeds`, so the executor submits both the text prompt and the pending media to the helper before generation. + +## Running the sample +1. Ensure the model and projection paths returned by `UserSettings` exist locally. +2. Start the example (for instance from the examples host application) and observe the initial description printed to the console. +3. Type text normally, or reference new images by including their path inside braces. Type `/exit` to end the conversation. + +This walkthrough mirrors the logic in the sample so you can adapt it for your own multimodal workflows. diff --git a/mkdocs.yml b/mkdocs.yml index 09cb3b96b..fbffdbba7 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -38,7 +38,7 @@ nav: - Interactive executor - basic: Examples/InteractiveModeExecute.md - Kernel memory integration - basic: Examples/KernelMemory.md - Kernel-memory - save & load: Examples/KernelMemorySaveAndLoad.md - - LLaVA - basic: Examples/LLavaInteractiveModeExecute.md + - MTMD interactive: Examples/MtmdInteractiveModeExecute.md - ChatSession - load & save: Examples/LoadAndSaveSession.md - Executor - save/load state: Examples/LoadAndSaveState.md - Quantization: Examples/QuantizeModel.md @@ -254,4 +254,4 @@ markdown_extensions: custom_checkbox: true - pymdownx.tilde - pymdownx.tabbed: - alternate_style: true \ No newline at end of file + alternate_style: true From 82c039c9da20a2f300ce3a4d5737d8c39c6efe0b Mon Sep 17 00:00:00 2001 From: jlsantiago Date: Mon, 29 Sep 2025 21:56:58 +0200 Subject: [PATCH 27/35] Update LLama/Native/NativeApi.cs Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- LLama/Native/NativeApi.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs index 3123674fc..0ea46a600 100644 --- a/LLama/Native/NativeApi.cs +++ b/LLama/Native/NativeApi.cs @@ -338,7 +338,6 @@ private static byte[] EncodeNullTerminatedUtf8(string value, string paramName) var bytes = Encoding.UTF8.GetBytes(value); var buffer = new byte[bytes.Length + 1]; Buffer.BlockCopy(bytes, 0, buffer, 0, bytes.Length); - // buffer[^1] = 0; return buffer; } From 83d31f8b5acf54220af11c966ea44d5d0e8c312d Mon Sep 17 00:00:00 2001 From: SignalRT Date: Mon, 29 Sep 2025 22:57:09 +0200 Subject: [PATCH 28/35] Resolve comment: https://github.com/SciSharp/LLamaSharp/pull/1261#discussion_r2386165308 --- LLama/Native/MtmdContextParams.cs | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/LLama/Native/MtmdContextParams.cs b/LLama/Native/MtmdContextParams.cs index d83831d85..5b282d802 100644 --- a/LLama/Native/MtmdContextParams.cs +++ b/LLama/Native/MtmdContextParams.cs @@ -138,7 +138,16 @@ private PinnedUtf8String(string value) public static PinnedUtf8String? Create(string? value) => value is null ? null : new PinnedUtf8String(value); - public IntPtr Pointer => _buffer is null ? 
IntPtr.Zero : _handle.AddrOfPinnedObject(); + public IntPtr Pointer + { + get + { + if (_buffer is null || !_handle.IsAllocated) + return IntPtr.Zero; + + return _handle.AddrOfPinnedObject(); + } + } public void Dispose() { From b65a6cf59bac698356b40d683c4f5c4ea20801bf Mon Sep 17 00:00:00 2001 From: SignalRT Date: Sun, 5 Oct 2025 13:47:51 +0200 Subject: [PATCH 29/35] Remove duplicate code --- LLama/Native/SafeMtmdModelHandle.cs | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/LLama/Native/SafeMtmdModelHandle.cs b/LLama/Native/SafeMtmdModelHandle.cs index 236a22011..86abf8c6c 100644 --- a/LLama/Native/SafeMtmdModelHandle.cs +++ b/LLama/Native/SafeMtmdModelHandle.cs @@ -40,7 +40,7 @@ public static SafeMtmdModelHandle LoadFromFile(string modelPath, LLamaWeights te // This provides better error messages that llama.cpp, which would throw an access violation exception in both cases. using (var fs = new FileStream(modelPath, FileMode.Open)) if (!fs.CanRead) - throw new InvalidOperationException($"Mtmd MMP Model file '{modelPath}' is not readable"); + throw new InvalidOperationException($"Mtmd Model file '{modelPath}' is not readable"); using var pathUtf8 = PinnedUtf8String.Create(modelPath) ?? throw new ArgumentNullException(nameof(modelPath)); @@ -138,21 +138,13 @@ public int Tokenize(string text, bool addSpecial, bool parseSpecial, out SafeMtm if (result == 0) { chunks = new SafeMtmdInputChunks(output); - foreach (var media in _pendingMedia) - media.Dispose(); - _pendingMedia.Clear(); } else { NativeApi.mtmd_input_chunks_free(output); } - if (result != 0) - { - foreach (var media in _pendingMedia) - media.Dispose(); - _pendingMedia.Clear(); - } + ClearMedia(); return result; } From 3e36bb91fb94c2538af51fa61395a85cc7cd796b Mon Sep 17 00:00:00 2001 From: SignalRT Date: Sun, 5 Oct 2025 14:16:49 +0200 Subject: [PATCH 30/35] Move common logic to LlamaExecutorBase --- LLama/LLamaExecutorBase.cs | 195 ++++++++++++++++++++++++++++++ LLama/LLamaInstructExecutor.cs | 154 +----------------------- LLama/LLamaInteractExecutor.cs | 210 ++------------------------------- 3 files changed, 204 insertions(+), 355 deletions(-) diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index a39ad3836..bb1c27a35 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -86,6 +86,13 @@ public bool IsMultiModal /// public List Embeds { get; } + /// + /// Pending multimodal chunks produced by the MTMD tokenizer. + /// + protected SafeMtmdInputChunks? MtmdChunks { get; set; } + + private string? _mtmdMarker; + private readonly StreamingTokenDecoder _decoder; /// @@ -242,6 +249,194 @@ protected virtual void TryReuseMatchingPrefix() } } + /// + /// Dispose and clear any queued multimodal chunk collection. + /// + protected void DisposeMtmdChunks() + { + MtmdChunks?.Dispose(); + MtmdChunks = null; + } + + /// + /// Dispose and clear any pending multimodal embeddings. + /// + protected void DisposeEmbeds() + { + if (Embeds.Count == 0) + return; + + foreach (var embed in Embeds) + embed.Dispose(); + + Embeds.Clear(); + } + + /// + /// Retrieve the marker token used to signal media segments to the tokenizer. + /// + protected string GetMtmdMarker() + { + if (_mtmdMarker is not null) + return _mtmdMarker; + + _mtmdMarker = NativeApi.MtmdDefaultMarker() ?? ""; + return _mtmdMarker; + } + + /// + /// Ensure the token list fills all positional slots reported by the MTMD helper. 
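+ /// Media chunks occupy sequence positions without contributing text tokens, so the managed list is padded with a filler token until it matches the position count reported by the helper.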
+ /// + protected static List BuildTokensWithFiller(List tokens, int totalPositions, LLamaToken fillerToken) + { + if (totalPositions <= tokens.Count) + return new List(tokens); + + var result = new List(totalPositions); + result.AddRange(tokens); + result.AddRange(Enumerable.Repeat(fillerToken, totalPositions - tokens.Count)); + return result; + } + + /// + /// Resolve the fallback token inserted when the tokenizer emits fewer tokens than positions. + /// + protected LLamaToken GetFillerToken(string marker) + { + var markerTokens = Context.Tokenize(marker, false, true); + if (markerTokens.Length > 0) + return markerTokens[markerTokens.Length - 1]; + + var eos = Context.Vocab.EOS; + if (eos.HasValue) + return eos.Value; + + return default; + } + + /// + /// Prepare multimodal inputs by invoking the MTMD tokenizer and aligning filler tokens. + /// + protected Task PreprocessMtmd(string text, InferStateArgs args, bool addBos, bool replaceExisting) + { + if (ClipModel is null) + throw new InvalidOperationException("Multimodal execution requires a loaded mtmd clip model."); + + DisposeMtmdChunks(); + + var marker = GetMtmdMarker(); + var prompt = text; + + if (Embeds.Count > 0) + { + if (prompt.Contains("")) + prompt = prompt.Replace("", marker); + + if (!prompt.Contains(marker)) + { + var suffix = string.Concat(Enumerable.Repeat(marker, Embeds.Count)); + prompt = string.Concat(prompt, suffix); + } + } + + SafeMtmdInputChunks? chunks = null; + try + { + var status = ClipModel.Tokenize(prompt, addBos, parseSpecial: true, out chunks); + if (status != 0 || chunks is null) + { + ClipModel.ClearMedia(); + throw new RuntimeError($"Failed to tokenize multimodal prompt. Status: {status}."); + } + + MtmdChunks = chunks; + + var tokens = new List(); + foreach (var chunk in chunks.Enumerate()) + { + using var scopedChunk = chunk; + if (scopedChunk.Type != SafeMtmdInputChunk.SafeMtmdInputChunkType.Text) + continue; + + foreach (var token in scopedChunk.GetTextTokensSpan()) + tokens.Add(unchecked((int)token)); + } + + var totalPositions = (int)ClipModel.CountPositions(chunks); + var fillerToken = GetFillerToken(marker); + + if (replaceExisting) + { + _embed_inps = BuildTokensWithFiller(tokens, totalPositions, fillerToken); + _consumedTokensCount = 0; + } + else + { + if (_embed_inps.Count == 0) + _embed_inps = new List(); + + _embed_inps.AddRange(tokens); + var fillerCount = totalPositions - tokens.Count; + if (fillerCount > 0) + _embed_inps.AddRange(Enumerable.Repeat(fillerToken, fillerCount)); + + args.RemainedTokens -= tokens.Count; + } + } + catch + { + chunks?.Dispose(); + MtmdChunks = null; + throw; + } + finally + { + DisposeEmbeds(); + } + + return Task.CompletedTask; + } + + /// + /// Apply bookkeeping after successfully evaluating multimodal chunks. + /// + protected void FinalizeMtmdEvaluation(long newNPast, int previousConsumed) + { + _pastTokensCount = checked((int)newNPast); + DisposeMtmdChunks(); + + if (!string.IsNullOrEmpty(_pathSession) && _embed_inps.Count > previousConsumed) + { + _session_tokens.AddRange(_embed_inps.Skip(previousConsumed)); + _n_session_consumed = _session_tokens.Count; + } + + _consumedTokensCount = _embed_inps.Count; + _embeds.Clear(); + } + + /// + /// Evaluate the queued MTMD chunks and update executor state. 
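+ /// Disposes the queued chunks and throws a RuntimeError when the native helper reports a non-zero status.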
+ /// + protected void EvaluateMtmdChunks(ref long nPast, int previousConsumed, string executorName) + { + if (ClipModel is null) + throw new InvalidOperationException("Multimodal execution requires a loaded mtmd clip model."); + if (MtmdChunks is null) + throw new InvalidOperationException("No MTMD chunks are queued for evaluation."); + + var evalStatus = ClipModel.EvaluateChunks(MtmdChunks, Context.NativeHandle, ref nPast, seqId: 0, + nBatch: checked((int)Context.BatchSize), logitsLast: true); + if (evalStatus != 0) + { + _logger?.LogError("[{Executor}] Failed to evaluate multimodal chunks. Status: {Status}", executorName, evalStatus); + DisposeMtmdChunks(); + throw new RuntimeError($"Failed to evaluate multimodal chunks. Status: {evalStatus}."); + } + + FinalizeMtmdEvaluation(nPast, previousConsumed); + } + /// /// Determine whether the inference loop should continue processing tokens. /// diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index 1bdba035a..b7a0f9ec7 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -26,8 +26,6 @@ public class InstructExecutor private readonly string _instructionPrefix; private LLamaToken[] _inp_pfx; private LLamaToken[] _inp_sfx; - private SafeMtmdInputChunks? _mtmdChunks; - private string? _mtmdMarker; private readonly string _instructionSuffix; /// @@ -192,136 +190,6 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args, Canc return Task.CompletedTask; } - private void DisposeMtmdChunks() - { - _mtmdChunks?.Dispose(); - _mtmdChunks = null; - } - - private void DisposeEmbeds() - { - if (Embeds.Count == 0) - return; - - foreach (var embed in Embeds) - embed.Dispose(); - - Embeds.Clear(); - } - - private string GetMtmdMarker() - { - if (_mtmdMarker is not null) - return _mtmdMarker; - - _mtmdMarker = NativeApi.MtmdDefaultMarker() ?? ""; - return _mtmdMarker; - } - - private static List BuildTokensWithFiller(List tokens, int totalPositions, LLamaToken fillerToken) - { - if (totalPositions <= tokens.Count) - return new List(tokens); - - var result = new List(totalPositions); - result.AddRange(tokens); - result.AddRange(Enumerable.Repeat(fillerToken, totalPositions - tokens.Count)); - return result; - } - - private LLamaToken GetFillerToken(string marker) - { - var markerTokens = Context.Tokenize(marker, false, true); - if (markerTokens.Length > 0) - return markerTokens[markerTokens.Length - 1]; - - var eos = Context.Vocab.EOS; - if (eos.HasValue) - return eos.Value; - - return default(LLamaToken); - } - - private Task PreprocessMtmd(string text, InferStateArgs args, bool addBos, bool replaceExisting) - { - if (ClipModel is null) - throw new InvalidOperationException("Multimodal execution requires a loaded mtmd clip model."); - - DisposeMtmdChunks(); - - var marker = GetMtmdMarker(); - var prompt = text; - - if (Embeds.Count > 0) - { - if (prompt.Contains("")) - prompt = prompt.Replace("", marker); - - if (!prompt.Contains(marker)) - { - var suffix = string.Concat(Enumerable.Repeat(marker, Embeds.Count)); - prompt = string.Concat(prompt, suffix); - } - } - - SafeMtmdInputChunks? chunks = null; - try - { - var status = ClipModel.Tokenize(prompt, addBos, parseSpecial: true, out chunks); - if (status != 0 || chunks is null) - { - ClipModel.ClearMedia(); - throw new RuntimeError($"Failed to tokenize multimodal prompt. 
Status: {status}."); - } - - _mtmdChunks = chunks; - - var tokens = new List(); - foreach (var chunk in chunks.Enumerate()) - { - using var scopedChunk = chunk; - if (scopedChunk.Type != SafeMtmdInputChunk.SafeMtmdInputChunkType.Text) - continue; - - foreach (var token in scopedChunk.GetTextTokensSpan()) - tokens.Add(unchecked((int)token)); - } - - var totalPositions = (int)ClipModel.CountPositions(chunks); - var fillerToken = GetFillerToken(marker); - - if (replaceExisting) - { - _embed_inps = BuildTokensWithFiller(tokens, totalPositions, fillerToken); - _consumedTokensCount = 0; - } - else - { - if (_embed_inps.Count == 0) - _embed_inps = new List(); - - _embed_inps.AddRange(tokens); - var fillerCount = totalPositions - tokens.Count; - if (fillerCount > 0) - _embed_inps.AddRange(Enumerable.Repeat(fillerToken, fillerCount)); - - args.RemainedTokens -= tokens.Count; - } - } - catch - { - chunks?.Dispose(); - _mtmdChunks = null; - throw; - } - finally - { - DisposeEmbeds(); - } - - return Task.CompletedTask; - } - /// protected override Task<(bool, IReadOnlyList)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken) { @@ -384,30 +252,12 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In _n_session_consumed = _session_tokens.Count; } } - else if (IsMultiModal && _mtmdChunks is not null) + else if (IsMultiModal && MtmdChunks is not null) { _is_prompt_run = false; var nPast = (long)_pastTokensCount; var previousConsumed = _consumedTokensCount; - var evalStatus = ClipModel!.EvaluateChunks(_mtmdChunks, Context.NativeHandle, ref nPast, seqId: 0, nBatch: checked((int)Context.BatchSize), logitsLast: true); - if (evalStatus != 0) - { - _logger?.LogError("[InstructExecutor] Failed to evaluate multimodal chunks. Status: {Status}", evalStatus); - DisposeMtmdChunks(); - throw new RuntimeError($"Failed to evaluate multimodal chunks. Status: {evalStatus}."); - } - - _pastTokensCount = checked((int)nPast); - DisposeMtmdChunks(); - - if (!string.IsNullOrEmpty(_pathSession) && _embed_inps.Count > previousConsumed) - { - _session_tokens.AddRange(_embed_inps.Skip(previousConsumed)); - _n_session_consumed = _session_tokens.Count; - } - - _consumedTokensCount = _embed_inps.Count; - _embeds.Clear(); + EvaluateMtmdChunks(ref nPast, previousConsumed, nameof(InstructExecutor)); } _embeds.Clear(); diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index 97d49f5de..da6ed53a9 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -144,7 +144,7 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args, Canc } else { - PreprocessMtmd(text, args, true); + return PreprocessMtmd(text, args, addBos: true, replaceExisting: true); } } else @@ -165,7 +165,7 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args, Canc } else { - PreprocessMtmd(text, args, false); + return PreprocessMtmd(text, args, addBos: false, replaceExisting: false); } } } @@ -173,165 +173,6 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args, Canc return Task.CompletedTask; } - /// - /// Release any queued multimodal chunks and reset state. - /// - private void DisposeMtmdChunks() - { - _mtmdChunks?.Dispose(); - _mtmdChunks = null; - } - - /// - /// Dispose and clear any pending multimodal embeddings queued for evaluation. 
- /// - private void DisposeEmbeds() - { - if (Embeds.Count == 0) - { - return; - } - - foreach (var embed in Embeds) - { - embed.Dispose(); - } - - Embeds.Clear(); - } - - /// - /// Retrieve the marker token used to signal media segments to the tokenizer. - /// - private string GetMtmdMarker() - { - if (_mtmdMarker is not null) - { - return _mtmdMarker; - } - - _mtmdMarker = NativeApi.MtmdDefaultMarker() ?? ""; - return _mtmdMarker; - } - - private static List BuildTokensWithFiller(List tokens, int totalPositions, LLamaToken fillerToken) - { - if (totalPositions <= tokens.Count) - return new List(tokens); - - var result = new List(totalPositions); - result.AddRange(tokens); - result.AddRange(Enumerable.Repeat(fillerToken, totalPositions - tokens.Count)); - return result; - } - - private LLamaToken GetFillerToken(string marker) - { - var markerTokens = Context.Tokenize(marker, false, true); - if (markerTokens.Length > 0) - return markerTokens[markerTokens.Length - 1]; - - var eos = Context.Vocab.EOS; - if (eos.HasValue) - return eos.Value; - - return default(LLamaToken); - } - - /// - /// Preprocess multimodal prompts by aligning media markers and tokenizing via MTMD helpers. - /// - /// Prompt text containing optional media markers. - /// Mutable inference state. - /// Whether to treat the prompt as a fresh run and add the BOS token. - private Task PreprocessMtmd(string text, InferStateArgs args, bool addBos = true) - { - if (ClipModel is null) - { - throw new InvalidOperationException("Multimodal execution requires a loaded mtmd clip model."); - } - - DisposeMtmdChunks(); - - var marker = GetMtmdMarker(); - var prompt = text; - - if (Embeds.Count > 0) - { - if (prompt.Contains("")) - { - prompt = prompt.Replace("", marker); - } - - if (!prompt.Contains(marker)) - { - var suffix = string.Concat(Enumerable.Repeat(marker, Embeds.Count)); // Ensure tokenizer sees one marker per embed. - prompt = string.Concat(prompt, suffix); - } - } - - SafeMtmdInputChunks? chunks = null; - try - { - var status = ClipModel.Tokenize(prompt, addBos, parseSpecial: true, out chunks); - if (status != 0 || chunks is null) - { - ClipModel.ClearMedia(); - throw new RuntimeError($"Failed to tokenize multimodal prompt. Status: {status}."); - } - - _mtmdChunks = chunks; // Own the chunk collection until evaluation completes. - - var tokens = new List(); - foreach (var chunk in chunks.Enumerate()) - { - using var scopedChunk = chunk; - if (scopedChunk.Type != SafeMtmdInputChunk.SafeMtmdInputChunkType.Text) - { - continue; - } - - foreach (var token in scopedChunk.GetTextTokensSpan()) - { - tokens.Add(unchecked((int)token)); - } - } - - var totalPositions = (int)ClipModel.CountPositions(chunks); - var fillerToken = GetFillerToken(marker); - - if (addBos) - { - _embed_inps = BuildTokensWithFiller(tokens, totalPositions, fillerToken); - _consumedTokensCount = 0; - } - else - { - if (_embed_inps.Count == 0) - _embed_inps = new List(); - - _embed_inps.AddRange(tokens); - var fillerCount = totalPositions - tokens.Count; - if (fillerCount > 0) - _embed_inps.AddRange(Enumerable.Repeat(fillerToken, fillerCount)); - - args.RemainedTokens -= tokens.Count; - } - } - catch - { - chunks?.Dispose(); - _mtmdChunks = null; - throw; - } - finally - { - DisposeEmbeds(); // Flush any embeds decoded in prior step; MTMD replays them via chunk eval. - } - - return Task.CompletedTask; - } - /// /// Decide whether generation should stop based on antiprompts, token limits, or end-of-generation markers. 
/// @@ -393,35 +234,16 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In HandleRunOutOfContext(tokensToKeep); } - if (_mtmdChunks is null) + if (MtmdChunks is null) { TryReuseMatchingPrefix(); } - if (IsMultiModal && _mtmdChunks is not null) + if (IsMultiModal && MtmdChunks is not null) { var nPast = (long)_pastTokensCount; var previousConsumed = _consumedTokensCount; - var evalStatus = ClipModel!.EvaluateChunks(_mtmdChunks, Context.NativeHandle, ref nPast, seqId: 0, - nBatch: checked((int)Context.BatchSize), logitsLast: true); - if (evalStatus != 0) - { - _logger?.LogError("[InteractiveExecutor] Failed to evaluate multimodal chunks. Status: {Status}", evalStatus); - DisposeMtmdChunks(); - throw new RuntimeError($"Failed to evaluate multimodal chunks. Status: {evalStatus}."); - } - - _pastTokensCount = checked((int)nPast); - DisposeMtmdChunks(); - - if (!string.IsNullOrEmpty(_pathSession) && _embed_inps.Count > previousConsumed) - { - _session_tokens.AddRange(_embed_inps.Skip(previousConsumed)); - _n_session_consumed = _session_tokens.Count; - } - - _consumedTokensCount = _embed_inps.Count; - _embeds.Clear(); + EvaluateMtmdChunks(ref nPast, previousConsumed, nameof(InteractiveExecutor)); } else { @@ -437,30 +259,12 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In } } } - else if (IsMultiModal && _mtmdChunks is not null) + else if (IsMultiModal && MtmdChunks is not null) { _is_prompt_run = false; var nPast = (long)_pastTokensCount; var previousConsumed = _consumedTokensCount; - var evalStatus = ClipModel!.EvaluateChunks(_mtmdChunks, Context.NativeHandle, ref nPast, seqId: 0, nBatch: checked((int)Context.BatchSize), logitsLast: true); - if (evalStatus != 0) - { - _logger?.LogError("[InteractiveExecutor] Failed to evaluate multimodal chunks. Status: {Status}", evalStatus); - DisposeMtmdChunks(); - throw new RuntimeError($"Failed to evaluate multimodal chunks. 
Status: {evalStatus}."); - } - - _pastTokensCount = checked((int)nPast); - DisposeMtmdChunks(); - - if (!string.IsNullOrEmpty(_pathSession) && _embed_inps.Count > previousConsumed) - { - _session_tokens.AddRange(_embed_inps.Skip(previousConsumed)); - _n_session_consumed = _session_tokens.Count; - } - - _consumedTokensCount = _embed_inps.Count; - _embeds.Clear(); + EvaluateMtmdChunks(ref nPast, previousConsumed, nameof(InteractiveExecutor)); } _embeds.Clear(); From e58618c2cc017db8c706e78b8f36002b3401f493 Mon Sep 17 00:00:00 2001 From: SignalRT Date: Sat, 25 Oct 2025 10:31:18 +0200 Subject: [PATCH 31/35] Rename SafeMtmdWeights --- .../Examples/BatchedExecutorMtmd.cs | 2 +- .../Examples/MtmdInteractiveModeExecute.cs | 2 +- LLama.Unittest/MtmdExecutorTests.cs | 4 +-- LLama.Unittest/MtmdWeightsTests.cs | 28 +++++++++---------- LLama/Abstractions/ILLamaExecutor.cs | 2 +- LLama/Batched/BatchedExecutor.cs | 8 +++--- LLama/Batched/Conversation.cs | 2 +- LLama/LLamaExecutorBase.cs | 8 +++--- LLama/LLamaInstructExecutor.cs | 2 +- LLama/LLamaInteractExecutor.cs | 2 +- LLama/LLamaStatelessExecutor.cs | 2 +- LLama/{SafeMtmdWeights.cs => MtmdWeights.cs} | 10 +++---- 12 files changed, 36 insertions(+), 36 deletions(-) rename LLama/{SafeMtmdWeights.cs => MtmdWeights.cs} (87%) diff --git a/LLama.Examples/Examples/BatchedExecutorMtmd.cs b/LLama.Examples/Examples/BatchedExecutorMtmd.cs index b62f8b120..8cdc4dac5 100644 --- a/LLama.Examples/Examples/BatchedExecutorMtmd.cs +++ b/LLama.Examples/Examples/BatchedExecutorMtmd.cs @@ -30,7 +30,7 @@ public static async Task Run() mtmdParams.UseGpu = false; var marker = mtmdParams.MediaMarker ?? NativeApi.MtmdDefaultMarker() ?? ""; - using var mtmd = await SafeMtmdWeights.LoadFromFileAsync(UserSettings.GetMMProjPath(), model, mtmdParams); // multimodal helper weights + using var mtmd = await MtmdWeights.LoadFromFileAsync(UserSettings.GetMMProjPath(), model, mtmdParams); // multimodal helper weights using var executor = new BatchedExecutor(model, parameters, mtmd); // drives batched token + chunk evaluation diff --git a/LLama.Examples/Examples/MtmdInteractiveModeExecute.cs b/LLama.Examples/Examples/MtmdInteractiveModeExecute.cs index ca0de3b77..b6395d0f8 100644 --- a/LLama.Examples/Examples/MtmdInteractiveModeExecute.cs +++ b/LLama.Examples/Examples/MtmdInteractiveModeExecute.cs @@ -30,7 +30,7 @@ public static async Task Run() using var context = model.CreateContext(parameters); // Mtmd Init - using var clipModel = await SafeMtmdWeights.LoadFromFileAsync(multiModalProj, model, mtmdParameters ); + using var clipModel = await MtmdWeights.LoadFromFileAsync(multiModalProj, model, mtmdParameters ); var mediaMarker = mtmdParameters.MediaMarker ?? NativeApi.MtmdDefaultMarker() ?? 
""; diff --git a/LLama.Unittest/MtmdExecutorTests.cs b/LLama.Unittest/MtmdExecutorTests.cs index 75a96b261..2533a75d8 100644 --- a/LLama.Unittest/MtmdExecutorTests.cs +++ b/LLama.Unittest/MtmdExecutorTests.cs @@ -13,7 +13,7 @@ public class MtmdExecutorTests : IDisposable { private readonly LLamaWeights _weights; private readonly MtmdContextParams _mtmdParams; - private readonly SafeMtmdWeights _mtmd; + private readonly MtmdWeights _mtmd; private readonly ModelParams _modelParams; public MtmdExecutorTests() @@ -30,7 +30,7 @@ public MtmdExecutorTests() _mtmdParams.NThreads = Math.Max(1, Constants.CIGpuLayerCount); _mtmdParams.UseGpu = false; - _mtmd = SafeMtmdWeights.LoadFromFile(Constants.MtmdMmpPath, _weights, _mtmdParams); + _mtmd = MtmdWeights.LoadFromFile(Constants.MtmdMmpPath, _weights, _mtmdParams); } public void Dispose() diff --git a/LLama.Unittest/MtmdWeightsTests.cs b/LLama.Unittest/MtmdWeightsTests.cs index 947bbd1ea..9ffffc518 100644 --- a/LLama.Unittest/MtmdWeightsTests.cs +++ b/LLama.Unittest/MtmdWeightsTests.cs @@ -12,7 +12,7 @@ public sealed class MtmdWeightTests : IDisposable { private readonly LLamaWeights _llamaWeights; - private readonly SafeMtmdWeights _safeMtmdWeights; + private readonly MtmdWeights _mtmdWeights; private readonly LLamaContext _context; private readonly MtmdContextParams _mtmdParams; private readonly string _mediaMarker; @@ -33,20 +33,20 @@ public MtmdWeightTests() _mediaMarker = _mtmdParams.MediaMarker ?? throw new InvalidOperationException("MTMD media marker unavailable."); - _safeMtmdWeights = SafeMtmdWeights.LoadFromFile(Constants.MtmdMmpPath, _llamaWeights, _mtmdParams); + _mtmdWeights = MtmdWeights.LoadFromFile(Constants.MtmdMmpPath, _llamaWeights, _mtmdParams); _context = _llamaWeights.CreateContext(@params); } public void Dispose() { _context.Dispose(); - _safeMtmdWeights.Dispose(); + _mtmdWeights.Dispose(); _llamaWeights.Dispose(); } private SafeMtmdInputChunks TokenizeWithEmbed(Func loadEmbed) { - _safeMtmdWeights.ClearMedia(); + _mtmdWeights.ClearMedia(); var embed = loadEmbed(); Assert.NotNull(embed); @@ -58,7 +58,7 @@ private SafeMtmdInputChunks TokenizeWithEmbed(Func loadEmbed) Assert.False(embed.IsAudio); Assert.True(embed.GetDataSpan().Length > 0); - var status = _safeMtmdWeights.Tokenize(_mediaMarker, addSpecial: true, parseSpecial: true, out var chunks); + var status = _mtmdWeights.Tokenize(_mediaMarker, addSpecial: true, parseSpecial: true, out var chunks); Assert.Equal(0, status); Assert.NotNull(chunks); @@ -69,7 +69,7 @@ private SafeMtmdInputChunks TokenizeWithEmbed(Func loadEmbed) private void AssertChunksEvaluate(SafeMtmdInputChunks chunks) { long nPast = 0; - var eval = _safeMtmdWeights.EvaluateChunks(chunks, _context.NativeHandle, ref nPast, seqId: 0, nBatch: checked((int)_context.BatchSize), logitsLast: true); + var eval = _mtmdWeights.EvaluateChunks(chunks, _context.NativeHandle, ref nPast, seqId: 0, nBatch: checked((int)_context.BatchSize), logitsLast: true); Assert.Equal(0, eval); Assert.True(nPast > 0); } @@ -77,7 +77,7 @@ private void AssertChunksEvaluate(SafeMtmdInputChunks chunks) [Fact,Trait("Category", "NoCI")] public void EmbedImageAsFileName() { - using var chunks = TokenizeWithEmbed(() => _safeMtmdWeights.LoadMedia(Constants.MtmdImage)); + using var chunks = TokenizeWithEmbed(() => _mtmdWeights.LoadMedia(Constants.MtmdImage)); AssertChunksEvaluate(chunks); } @@ -85,14 +85,14 @@ public void EmbedImageAsFileName() public void EmbedImageAsBinary() { var imageBytes = File.ReadAllBytes(Constants.MtmdImage); - using var 
chunks = TokenizeWithEmbed(() => _safeMtmdWeights.LoadMedia(imageBytes)); + using var chunks = TokenizeWithEmbed(() => _mtmdWeights.LoadMedia(imageBytes)); AssertChunksEvaluate(chunks); } [Fact,Trait("Category", "NoCI")] public void TokenizeProvidesChunkMetadata() { - using var chunks = TokenizeWithEmbed(() => _safeMtmdWeights.LoadMedia(Constants.MtmdImage)); + using var chunks = TokenizeWithEmbed(() => _mtmdWeights.LoadMedia(Constants.MtmdImage)); Assert.True(chunks.Size > 0); @@ -128,12 +128,12 @@ public void TokenizeProvidesChunkMetadata() Assert.True(imageChunks > 0); Assert.True(totalTokens > 0); - Assert.Equal(totalTokens, _safeMtmdWeights.CountTokens(chunks)); - Assert.Equal(totalPositions, _safeMtmdWeights.CountPositions(chunks)); - Assert.True(_safeMtmdWeights.SupportsVision); - Assert.False(_safeMtmdWeights.SupportsAudio); + Assert.Equal(totalTokens, _mtmdWeights.CountTokens(chunks)); + Assert.Equal(totalPositions, _mtmdWeights.CountPositions(chunks)); + Assert.True(_mtmdWeights.SupportsVision); + Assert.False(_mtmdWeights.SupportsAudio); - var audioBitrate = _safeMtmdWeights.AudioBitrate; + var audioBitrate = _mtmdWeights.AudioBitrate; Assert.True(audioBitrate <= 0); } } diff --git a/LLama/Abstractions/ILLamaExecutor.cs b/LLama/Abstractions/ILLamaExecutor.cs index 92276e4a6..974a4ffbf 100644 --- a/LLama/Abstractions/ILLamaExecutor.cs +++ b/LLama/Abstractions/ILLamaExecutor.cs @@ -23,7 +23,7 @@ public interface ILLamaExecutor /// /// Multi-Modal Projections / Clip Model weights /// - public SafeMtmdWeights? ClipModel { get; } + public MtmdWeights? ClipModel { get; } /// /// List of media: List of media for Multi-Modal models. diff --git a/LLama/Batched/BatchedExecutor.cs b/LLama/Batched/BatchedExecutor.cs index 40468c98d..1a6698b1a 100644 --- a/LLama/Batched/BatchedExecutor.cs +++ b/LLama/Batched/BatchedExecutor.cs @@ -84,7 +84,7 @@ public BatchedExecutor(LLamaWeights model, IContextParams contextParams) { } - public BatchedExecutor(LLamaWeights model, IContextParams contextParams, SafeMtmdWeights? clipModel) + public BatchedExecutor(LLamaWeights model, IContextParams contextParams, MtmdWeights? clipModel) { Model = model; Context = model.CreateContext(contextParams); @@ -92,7 +92,7 @@ public BatchedExecutor(LLamaWeights model, IContextParams contextParams, SafeMtm Epoch = 1; } - public SafeMtmdWeights? ClipModel { get; } + public MtmdWeights? 
ClipModel { get; } /// /// Start a new @@ -374,11 +374,11 @@ public Task DecodeAsync(LLamaContext ctx, CancellationToken token) private class MtmdChunkBatch : IBatch { - private readonly SafeMtmdWeights _clipModel; + private readonly MtmdWeights _clipModel; private readonly Conversation _conversation; private readonly Conversation.MtmdChunkSequence _sequence; - public MtmdChunkBatch(SafeMtmdWeights clipModel, Conversation conversation, Conversation.MtmdChunkSequence sequence) + public MtmdChunkBatch(MtmdWeights clipModel, Conversation conversation, Conversation.MtmdChunkSequence sequence) { _clipModel = clipModel; _conversation = conversation; diff --git a/LLama/Batched/Conversation.cs b/LLama/Batched/Conversation.cs index 2311c8a0c..89d725e97 100644 --- a/LLama/Batched/Conversation.cs +++ b/LLama/Batched/Conversation.cs @@ -86,7 +86,7 @@ private MtmdChunkSequence(SafeMtmdInputChunks chunks, List textToken TotalPositions = totalPositions; } - public static MtmdChunkSequence Create(SafeMtmdInputChunks chunks, SafeMtmdWeights clipModel) + public static MtmdChunkSequence Create(SafeMtmdInputChunks chunks, MtmdWeights clipModel) { var textTokens = new List(); diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index bb1c27a35..1e264ebab 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -81,7 +81,7 @@ public bool IsMultiModal } /// - public SafeMtmdWeights? ClipModel { get; } + public MtmdWeights? ClipModel { get; } /// public List Embeds { get; } @@ -117,12 +117,12 @@ protected StatefulExecutorBase(LLamaContext context, ILogger? logger = null) /// Initialize a multimodal executor with the supplied MTMD weights. /// /// LLama context used for all native interactions. - /// Multimodal weights to associate with this executor. + /// Multimodal weights to associate with this executor. /// Optional logger for diagnostic output. - public StatefulExecutorBase(LLamaContext context, SafeMtmdWeights safeMtmdWeights, ILogger? logger = null) : + public StatefulExecutorBase(LLamaContext context, MtmdWeights mtmdWeights, ILogger? logger = null) : this( context, logger ) { - ClipModel = safeMtmdWeights; + ClipModel = mtmdWeights; } /// diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index b7a0f9ec7..a31cab211 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -48,7 +48,7 @@ public InstructExecutor(LLamaContext context, } public InstructExecutor(LLamaContext context, - SafeMtmdWeights clipModel, + MtmdWeights clipModel, string instructionPrefix = "\n\n### Instruction:\n\n", string instructionSuffix = "\n\n### Response:\n\n", ILogger? logger = null) diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index da6ed53a9..1359447c4 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -46,7 +46,7 @@ public InteractiveExecutor(LLamaContext context, ILogger? logger = null) /// LLama context to operate against. /// Multimodal weights (MTMD) to attach to the executor. /// Optional logger for diagnostic output. - public InteractiveExecutor(LLamaContext context, SafeMtmdWeights clipModel, ILogger? logger = null) + public InteractiveExecutor(LLamaContext context, MtmdWeights clipModel, ILogger? 
logger = null) : base(context, clipModel, logger) { } diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs index 94bc60830..a895054d4 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -28,7 +28,7 @@ public class StatelessExecutor public bool IsMultiModal => false; /// - public SafeMtmdWeights? ClipModel => default; + public MtmdWeights? ClipModel => default; /// public List Embeds { get; } diff --git a/LLama/SafeMtmdWeights.cs b/LLama/MtmdWeights.cs similarity index 87% rename from LLama/SafeMtmdWeights.cs rename to LLama/MtmdWeights.cs index e490049b4..945ccecc2 100644 --- a/LLama/SafeMtmdWeights.cs +++ b/LLama/MtmdWeights.cs @@ -9,26 +9,26 @@ namespace LLama; /// /// Lightweight wrapper around the MTMD native context and its helpers. /// -public sealed class SafeMtmdWeights : IDisposable +public sealed class MtmdWeights : IDisposable { public SafeMtmdModelHandle NativeHandle { get; } - private SafeMtmdWeights(SafeMtmdModelHandle handle) + private MtmdWeights(SafeMtmdModelHandle handle) { NativeHandle = handle ?? throw new ArgumentNullException(nameof(handle)); } - public static SafeMtmdWeights LoadFromFile(string mmProject, LLamaWeights textModel, MtmdContextParams mtmdCtxParams) + public static MtmdWeights LoadFromFile(string mmProject, LLamaWeights textModel, MtmdContextParams mtmdCtxParams) { if (mmProject == null) throw new ArgumentNullException(nameof(mmProject)); if (textModel == null) throw new ArgumentNullException(nameof(textModel)); if (mtmdCtxParams == null) throw new ArgumentNullException(nameof(mtmdCtxParams)); var handle = SafeMtmdModelHandle.LoadFromFile(mmProject, textModel, mtmdCtxParams); - return new SafeMtmdWeights(handle); + return new MtmdWeights(handle); } - public static Task LoadFromFileAsync(string mmProject, LLamaWeights textModel, MtmdContextParams mtmdCtxParams, CancellationToken token = default) + public static Task LoadFromFileAsync(string mmProject, LLamaWeights textModel, MtmdContextParams mtmdCtxParams, CancellationToken token = default) { return Task.Run(() => LoadFromFile(mmProject, textModel, mtmdCtxParams), token); } From 7e92b31cb418732b974fa227f5d6c6c6af338d2f Mon Sep 17 00:00:00 2001 From: SignalRT Date: Sat, 25 Oct 2025 10:48:18 +0200 Subject: [PATCH 32/35] Implement SafeHandle --- LLama/Native/SafeMtmdEmbed.cs | 104 +++++++++++++++++--------- LLama/Native/SafeMtmdInputChunk.cs | 112 +++++++++++++++++----------- LLama/Native/SafeMtmdInputChunks.cs | 80 ++++++++++++-------- 3 files changed, 188 insertions(+), 108 deletions(-) diff --git a/LLama/Native/SafeMtmdEmbed.cs b/LLama/Native/SafeMtmdEmbed.cs index c651db102..426b77c79 100644 --- a/LLama/Native/SafeMtmdEmbed.cs +++ b/LLama/Native/SafeMtmdEmbed.cs @@ -8,20 +8,25 @@ namespace LLama.Native /// Managed wrapper around mtmd_bitmap* resources. Instances own the native pointer /// and ensure proper cleanup when disposed. /// - public sealed class SafeMtmdEmbed : IDisposable + public sealed class SafeMtmdEmbed : SafeLLamaHandleBase { /// /// Raw pointer to the native bitmap structure. Internal so other wrappers can interop. /// - internal IntPtr NativePtr { get; private set; } - - private bool _disposed; + internal IntPtr NativePtr + { + get + { + EnsureNotDisposed(); + return DangerousGetHandle(); + } + } private SafeMtmdEmbed(IntPtr ptr) + : base(ptr, ownsHandle: true) { - NativePtr = ptr != IntPtr.Zero - ? 
ptr - : throw new InvalidOperationException("Failed to create MTMD bitmap."); + if (IsInvalid) + throw new InvalidOperationException("Failed to create MTMD bitmap."); } /// @@ -154,8 +159,7 @@ public uint Nx { get { - EnsureNotDisposed(); - return NativeApi.mtmd_bitmap_get_nx(NativePtr); + return WithHandle(ptr => NativeApi.mtmd_bitmap_get_nx(ptr)); } } @@ -166,8 +170,7 @@ public uint Ny { get { - EnsureNotDisposed(); - return NativeApi.mtmd_bitmap_get_ny(NativePtr); + return WithHandle(ptr => NativeApi.mtmd_bitmap_get_ny(ptr)); } } @@ -178,8 +181,7 @@ public bool IsAudio { get { - EnsureNotDisposed(); - return NativeApi.mtmd_bitmap_is_audio(NativePtr); + return WithHandle(ptr => NativeApi.mtmd_bitmap_is_audio(ptr)); } } @@ -190,14 +192,19 @@ public string? Id { get { - EnsureNotDisposed(); - var ptr = NativeApi.mtmd_bitmap_get_id(NativePtr); - return NativeApi.PtrToStringUtf8(ptr); + return WithHandle(ptr => + { + var idPtr = NativeApi.mtmd_bitmap_get_id(ptr); + return NativeApi.PtrToStringUtf8(idPtr); + }); } set { - EnsureNotDisposed(); - NativeApi.mtmd_bitmap_set_id(NativePtr, value); + WithHandle(ptr => + { + NativeApi.mtmd_bitmap_set_id(ptr, value); + return 0; + }); } } @@ -210,38 +217,63 @@ public unsafe ReadOnlySpan GetDataSpan() { EnsureNotDisposed(); - var dataPtr = (byte*)NativeApi.mtmd_bitmap_get_data(NativePtr); - var length = checked((int)NativeApi.mtmd_bitmap_get_n_bytes(NativePtr).ToUInt64()); - return dataPtr == null || length == 0 ? ReadOnlySpan.Empty : new ReadOnlySpan(dataPtr, length); + bool added = false; + try + { + DangerousAddRef(ref added); + var ptr = DangerousGetHandle(); + var dataPtr = (byte*)NativeApi.mtmd_bitmap_get_data(ptr); + var length = checked((int)NativeApi.mtmd_bitmap_get_n_bytes(ptr).ToUInt64()); + return dataPtr == null || length == 0 + ? ReadOnlySpan.Empty + : new ReadOnlySpan(dataPtr, length); + } + finally + { + if (added) + DangerousRelease(); + } } /// /// Release the underlying native bitmap. /// - public void Dispose() + protected override bool ReleaseHandle() { - if (_disposed) - return; - - if (NativePtr != IntPtr.Zero) + if (handle != IntPtr.Zero) { - NativeApi.mtmd_bitmap_free(NativePtr); - NativePtr = IntPtr.Zero; + NativeApi.mtmd_bitmap_free(handle); + SetHandle(IntPtr.Zero); } - _disposed = true; - GC.SuppressFinalize(this); + return true; } - /// - /// Finalizer to ensure native resources are reclaimed when Dispose is not invoked. - /// - ~SafeMtmdEmbed() => Dispose(); - private void EnsureNotDisposed() { - if (_disposed || NativePtr == IntPtr.Zero) + if (IsClosed || IsInvalid) throw new ObjectDisposedException(nameof(SafeMtmdEmbed)); } + + private T WithHandle(Func action) + { + EnsureNotDisposed(); + + bool added = false; + try + { + DangerousAddRef(ref added); + var ptr = DangerousGetHandle(); + if (ptr == IntPtr.Zero) + throw new ObjectDisposedException(nameof(SafeMtmdEmbed)); + + return action(ptr); + } + finally + { + if (added) + DangerousRelease(); + } + } } } diff --git a/LLama/Native/SafeMtmdInputChunk.cs b/LLama/Native/SafeMtmdInputChunk.cs index 59d1897ef..3ab4f7ccc 100644 --- a/LLama/Native/SafeMtmdInputChunk.cs +++ b/LLama/Native/SafeMtmdInputChunk.cs @@ -8,7 +8,7 @@ namespace LLama.Native; /// underlying native pointer (when created via ) or act as non-owning views /// produced by the tokenizer. /// -public sealed class SafeMtmdInputChunk : IDisposable +public sealed class SafeMtmdInputChunk : SafeLLamaHandleBase { /// /// Chunk modality returned by the native tokenizer. 
@@ -23,15 +23,18 @@ public enum SafeMtmdInputChunkType /// /// Raw pointer to the native chunk structure. /// - public IntPtr NativePtr { get; private set; } - - private bool _ownsPtr; - private bool _disposed; + public IntPtr NativePtr + { + get + { + EnsureNotDisposed(); + return DangerousGetHandle(); + } + } - private SafeMtmdInputChunk(IntPtr ptr, bool owns) + private SafeMtmdInputChunk(IntPtr handle, bool ownsHandle) + : base(handle, ownsHandle) { - NativePtr = ptr; - _ownsPtr = owns; } /// @@ -40,7 +43,7 @@ private SafeMtmdInputChunk(IntPtr ptr, bool owns) /// Pointer returned by the native tokenizer. /// Managed wrapper, or null when the pointer is null. public static SafeMtmdInputChunk Wrap(IntPtr ptr) - => ptr == IntPtr.Zero ? null : new SafeMtmdInputChunk(ptr, false); + => ptr == IntPtr.Zero ? null : new SafeMtmdInputChunk(ptr, ownsHandle: false); /// /// Create an owning copy of the current chunk. The caller becomes responsible for disposal. @@ -49,10 +52,11 @@ public static SafeMtmdInputChunk Wrap(IntPtr ptr) /// Thrown when the current wrapper has been disposed. public SafeMtmdInputChunk Copy() { - EnsureNotDisposed(); - - var p = NativeApi.mtmd_input_chunk_copy(NativePtr); - return p == IntPtr.Zero ? null : new SafeMtmdInputChunk(p, true); + return WithHandle(ptr => + { + var clone = NativeApi.mtmd_input_chunk_copy(ptr); + return clone == IntPtr.Zero ? null : new SafeMtmdInputChunk(clone, ownsHandle: true); + }); } /// @@ -62,8 +66,7 @@ public SafeMtmdInputChunkType Type { get { - EnsureNotDisposed(); - return (SafeMtmdInputChunkType)NativeApi.mtmd_input_chunk_get_type(NativePtr); + return WithHandle(ptr => (SafeMtmdInputChunkType)NativeApi.mtmd_input_chunk_get_type(ptr)); } } @@ -74,8 +77,7 @@ public ulong NTokens { get { - EnsureNotDisposed(); - return NativeApi.mtmd_input_chunk_get_n_tokens(NativePtr).ToUInt64(); + return WithHandle(ptr => NativeApi.mtmd_input_chunk_get_n_tokens(ptr).ToUInt64()); } } @@ -86,8 +88,11 @@ public string Id { get { - EnsureNotDisposed(); - return Marshal.PtrToStringAnsi(NativeApi.mtmd_input_chunk_get_id(NativePtr)) ?? string.Empty; + return WithHandle(ptr => + { + var idPtr = NativeApi.mtmd_input_chunk_get_id(ptr); + return Marshal.PtrToStringAnsi(idPtr) ?? string.Empty; + }); } } @@ -98,8 +103,7 @@ public long NPos { get { - EnsureNotDisposed(); - return NativeApi.mtmd_input_chunk_get_n_pos(NativePtr); + return WithHandle(ptr => NativeApi.mtmd_input_chunk_get_n_pos(ptr)); } } @@ -112,39 +116,63 @@ public unsafe ReadOnlySpan GetTextTokensSpan() { EnsureNotDisposed(); - UIntPtr n; - var p = (uint*)NativeApi.mtmd_input_chunk_get_tokens_text(NativePtr, out n); - return p == null ? ReadOnlySpan.Empty : new ReadOnlySpan(p, checked((int)n.ToUInt64())); + bool added = false; + try + { + DangerousAddRef(ref added); + UIntPtr nTokens; + var tokensPtr = (uint*)NativeApi.mtmd_input_chunk_get_tokens_text(DangerousGetHandle(), out nTokens); + if (tokensPtr == null) + return ReadOnlySpan.Empty; + + var length = checked((int)nTokens.ToUInt64()); + return new ReadOnlySpan(tokensPtr, length); + } + finally + { + if (added) + DangerousRelease(); + } } /// - /// Release the underlying native resources if this instance owns them. + /// Releases the native chunk when ownership is held by this instance. 
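+ /// Non-owning views created via Wrap never take this path, so memory owned by the tokenizer is left untouched.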
/// - public void Dispose() + protected override bool ReleaseHandle() { - if (_disposed) - return; - - if (_ownsPtr && NativePtr != IntPtr.Zero) + if (handle != IntPtr.Zero) { - NativeApi.mtmd_input_chunk_free(NativePtr); + NativeApi.mtmd_input_chunk_free(handle); + SetHandle(IntPtr.Zero); } - NativePtr = IntPtr.Zero; - _ownsPtr = false; - _disposed = true; - - GC.SuppressFinalize(this); + return true; } - /// - /// Finalizer to ensure native memory is reclaimed when Dispose is not called by owners. - /// - ~SafeMtmdInputChunk() => Dispose(); - private void EnsureNotDisposed() { - if (_disposed || NativePtr == IntPtr.Zero) + if (IsClosed || IsInvalid) throw new ObjectDisposedException(nameof(SafeMtmdInputChunk)); } + + private T WithHandle(Func action) + { + EnsureNotDisposed(); + + bool added = false; + try + { + DangerousAddRef(ref added); + var ptr = DangerousGetHandle(); + if (ptr == IntPtr.Zero) + throw new ObjectDisposedException(nameof(SafeMtmdInputChunk)); + + return action(ptr); + } + finally + { + if (added) + DangerousRelease(); + } + } } diff --git a/LLama/Native/SafeMtmdInputChunks.cs b/LLama/Native/SafeMtmdInputChunks.cs index 2081cd0a6..bd095f36a 100644 --- a/LLama/Native/SafeMtmdInputChunks.cs +++ b/LLama/Native/SafeMtmdInputChunks.cs @@ -6,44 +6,39 @@ namespace LLama.Native; /// /// Managed lifetime wrapper around a native mtmd_input_chunks collection returned by the tokenizer. /// -public sealed class SafeMtmdInputChunks : IDisposable +public sealed class SafeMtmdInputChunks : SafeLLamaHandleBase { /// /// Raw pointer to the native chunk collection. Internal to allow other wrappers to interop safely. /// - internal IntPtr NativePtr { get; private set; } - - private bool _disposed; + internal IntPtr NativePtr + { + get + { + EnsureNotDisposed(); + return DangerousGetHandle(); + } + } internal SafeMtmdInputChunks(IntPtr ptr) + : base(ptr, ownsHandle: true) { - NativePtr = ptr; + if (IsInvalid) + throw new InvalidOperationException("Native MTMD chunk collection pointer is null."); } /// - /// Releases the native chunk collection and suppresses finalization. + /// Releases the native chunk collection. /// - public void Dispose() + protected override bool ReleaseHandle() { - if (_disposed) - return; - - if (NativePtr != IntPtr.Zero) + if (handle != IntPtr.Zero) { - NativeApi.mtmd_input_chunks_free(NativePtr); - NativePtr = IntPtr.Zero; + NativeApi.mtmd_input_chunks_free(handle); + SetHandle(IntPtr.Zero); } - _disposed = true; - GC.SuppressFinalize(this); - } - - /// - /// Finalizer to ensure native memory is reclaimed if Dispose is not called. - /// - ~SafeMtmdInputChunks() - { - Dispose(); + return true; } /// @@ -53,8 +48,7 @@ public ulong Size { get { - EnsureNotDisposed(); - return NativeApi.mtmd_input_chunks_size(NativePtr).ToUInt64(); + return WithHandle(ptr => NativeApi.mtmd_input_chunks_size(ptr).ToUInt64()); } } @@ -68,10 +62,14 @@ public ulong Size /// The requested index is outside of the valid range. 
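/// Callers typically pass the result straight to SafeMtmdInputChunk.Wrap, which yields a non-owning view over the same memory.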
From f13f286ef794589d2ecdee44299c63fbc1c80d90 Mon Sep 17 00:00:00 2001
From: SignalRT
Date: Sat, 25 Oct 2025 14:38:05 +0200
Subject: [PATCH 33/35] Add IntPtrExtensions to manage PtrToString conversions

---
 LLama/Extensions/IModelParamsExtensions.cs |  4 +-
 LLama/Extensions/IntPtrExtensions.cs       | 51 ++++++++++++++++++++++
 LLama/Native/MtmdContextParams.cs          | 28 +-----------
 LLama/Native/NativeApi.Mtmd.cs             | 32 +-------------
 LLama/Native/SafeLLamaSamplerHandle.cs     |  6 ++-
 LLama/Native/SafeMtmdEmbed.cs              |  2 +-
 LLama/Native/SafeMtmdInputChunk.cs         |  4 +-
 7 files changed, 64 insertions(+), 63 deletions(-)
 create mode 100644 LLama/Extensions/IntPtrExtensions.cs

diff --git a/LLama/Extensions/IModelParamsExtensions.cs b/LLama/Extensions/IModelParamsExtensions.cs
index 2939318da..6c307861d 100644
--- a/LLama/Extensions/IModelParamsExtensions.cs
+++ b/LLama/Extensions/IModelParamsExtensions.cs
@@ -115,7 +115,7 @@ private static IReadOnlyDictionary GetAvailableBufferTypes()
             var dev = NativeApi.ggml_backend_dev_get(i);
             var buft = NativeApi.ggml_backend_dev_buffer_type(dev);
 
-            var name = Marshal.PtrToStringAnsi(NativeApi.ggml_backend_buft_name(buft));
+            var name = NativeApi.ggml_backend_buft_name(buft).PtrToString();
             if (string.IsNullOrEmpty(name))
                 continue;
 
@@ -165,4 +165,4 @@ private static IReadOnlyDictionary GetAvailableBufferTypes()
 
         return (LLamaModelTensorBufferOverride*)overrideArrayPin.Pointer;
     }
-}
\ No newline at end of file
+}
diff --git a/LLama/Extensions/IntPtrExtensions.cs b/LLama/Extensions/IntPtrExtensions.cs
new file mode 100644
index 000000000..eb5c90850
--- /dev/null
+++ b/LLama/Extensions/IntPtrExtensions.cs
@@ -0,0 +1,51 @@
+using System;
+using System.Runtime.InteropServices;
+using System.Text;
+
+namespace LLama.Extensions;
+
+public static class IntPtrExtensions
+{
+
+    /// <summary>
+    /// Converts a native UTF-8 string pointer to a managed string, returning a fallback value when no data is available.
+    /// </summary>
+    /// <param name="ptr">Pointer to a null-terminated UTF-8 string.</param>
+    /// <param name="defaultValue">Value to return when the pointer is <see cref="IntPtr.Zero"/>.</param>
+    /// <returns>Managed string representation of the native data, or <paramref name="defaultValue"/> when unavailable.</returns>
+    public static string PtrToStringWithDefault(this IntPtr ptr, string defaultValue = "")
+    {
+        return ptr.PtrToString() ?? defaultValue;
+    }
+
+    /// <summary>
+    /// Converts a pointer to a null-terminated UTF-8 string into a managed string.
+    /// </summary>
+    /// <param name="ptr">Pointer to the first byte of a null-terminated UTF-8 string.</param>
+    /// <returns>Managed string representation, or null when the pointer is zero.</returns>
+    public static string? PtrToString(this IntPtr ptr)
+    {
+        if (ptr == IntPtr.Zero)
+            return null;
+
+#if NETSTANDARD2_0
+        unsafe
+        {
+            var length = 0;
+            var current = (byte*)ptr;
+            while (current[length] != 0)
+                length++;
+
+            if (length == 0)
+                return string.Empty;
+
+            var buffer = new byte[length];
+            Marshal.Copy(ptr, buffer, 0, length);
+            return Encoding.UTF8.GetString(buffer);
+        }
+#else
+        return Marshal.PtrToStringUTF8(ptr);
+#endif
+    }
+
+}
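To see the new extensions in action, here is a small self-contained usage sketch. It assumes a .NET 6+ TFM for `Marshal.StringToCoTaskMemUTF8`; the non-ASCII literal exercises exactly the case the old `Marshal.PtrToStringAnsi` calls could mangle:

```csharp
using System;
using System.Runtime.InteropServices;
using LLama.Extensions;

internal static class IntPtrExtensionsDemo
{
    public static void Main()
    {
        // Allocate a null-terminated UTF-8 copy of a managed string in native memory.
        var native = Marshal.StringToCoTaskMemUTF8("héllo");
        try
        {
            Console.WriteLine(native.PtrToString());                      // "héllo"
            Console.WriteLine(IntPtr.Zero.PtrToStringWithDefault("n/a")); // "n/a"
        }
        finally
        {
            Marshal.FreeCoTaskMem(native);
        }
    }
}
```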
diff --git a/LLama/Native/MtmdContextParams.cs b/LLama/Native/MtmdContextParams.cs
index 5b282d802..fc8d6b5f8 100644
--- a/LLama/Native/MtmdContextParams.cs
+++ b/LLama/Native/MtmdContextParams.cs
@@ -51,35 +51,11 @@ public static MtmdContextParams Default()
             PrintTimings = native.print_timings,
             NThreads = native.n_threads,
             Verbosity = native.verbosity,
-            ImageMarker = PtrToString(native.image_marker),
-            MediaMarker = PtrToString(native.media_marker)
+            ImageMarker = native.image_marker.PtrToString(),
+            MediaMarker = native.media_marker.PtrToString()
         };
     }
 
-    private static string? PtrToString(IntPtr ptr)
-    {
-        if (ptr == IntPtr.Zero)
-            return null;
-
-#if NETSTANDARD2_0
-        unsafe
-        {
-            var length = 0;
-            var current = (byte*)ptr;
-            while (current[length] != 0)
-                length++;
-
-            if (length == 0)
-                return string.Empty;
-
-            var buffer = new byte[length];
-            Marshal.Copy(ptr, buffer, 0, length);
-            return Encoding.UTF8.GetString(buffer);
-        }
-#else
-        return Marshal.PtrToStringUTF8(ptr);
-#endif
-    }
 
     /// <summary>
     /// Convert the managed representation to a native structure, pinning strings for the duration of the scope.
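After this hunk, `MtmdContextParams.Default()` is the simplest consumer of the shared extension. A short sketch of reading the marker strings it converts; this assumes the native mtmd library is resolvable at runtime, and the printed values depend on the native defaults:

```csharp
using System;
using LLama.Native;

internal static class MtmdParamsDemo
{
    public static void Main()
    {
        // Default() now converts the native marker pointers through the shared
        // PtrToString extension instead of the private helper removed above.
        var p = MtmdContextParams.Default();
        Console.WriteLine(p.MediaMarker ?? "(no media marker)");
        Console.WriteLine(p.ImageMarker ?? "(no image marker)");
    }
}
```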
diff --git a/LLama/Native/NativeApi.Mtmd.cs b/LLama/Native/NativeApi.Mtmd.cs
index bfd6193c2..0aa8a314a 100644
--- a/LLama/Native/NativeApi.Mtmd.cs
+++ b/LLama/Native/NativeApi.Mtmd.cs
@@ -9,35 +9,7 @@ namespace LLama.Native;
 /// </summary>
 public static partial class NativeApi
 {
-    /// <summary>
-    /// Convert a UTF-8 encoded native string pointer into a managed <see cref="string"/>.
-    /// Returns null when the pointer is zero.
-    /// </summary>
-    public static string? PtrToStringUtf8(IntPtr ptr)
-    {
-        if (ptr == IntPtr.Zero)
-            return null;
-
-#if NETSTANDARD2_0
-        unsafe
-        {
-            var current = (byte*)ptr;
-            var length = 0;
-            while (current[length] != 0)
-                length++;
-
-            if (length == 0)
-                return string.Empty;
-
-            var buffer = new byte[length];
-            Marshal.Copy(ptr, buffer, 0, length);
-            return Encoding.UTF8.GetString(buffer);
-        }
-#else
-        return Marshal.PtrToStringUTF8(ptr);
-#endif
-    }
-
+
     /// <summary>
     /// Native context parameters returned by <see cref="mtmd_context_params_default"/>.
     /// </summary>
@@ -59,7 +31,7 @@ internal struct mtmd_context_params
     /// <summary>
     /// Retrieve the default multimodal marker text.
     /// </summary>
     public static string? MtmdDefaultMarker()
-        => PtrToStringUtf8(mtmd_default_marker());
+        => mtmd_default_marker().PtrToString();
 
     [DllImport(mtmdLibraryName, EntryPoint = "mtmd_context_params_default", CallingConvention = CallingConvention.Cdecl)]
     internal static extern mtmd_context_params mtmd_context_params_default();
diff --git a/LLama/Native/SafeLLamaSamplerHandle.cs b/LLama/Native/SafeLLamaSamplerHandle.cs
index a113e1694..8219f1988 100644
--- a/LLama/Native/SafeLLamaSamplerHandle.cs
+++ b/LLama/Native/SafeLLamaSamplerHandle.cs
@@ -1,6 +1,8 @@
 using System;
 using System.Collections.Generic;
+using System.Runtime.InteropServices;
 using System.Text;
+using LLama.Extensions;
 
 namespace LLama.Native;
 
@@ -119,7 +121,7 @@ public string GetName(int index)
         if (index < 0 || index >= Count)
             throw new ArgumentOutOfRangeException(nameof(index));
 
-        return Marshal.PtrToStringAnsi(llama_sampler_name(llama_sampler_chain_get(this, index))) ?? "Unknown Name";
+        return llama_sampler_name(llama_sampler_chain_get(this, index)).PtrToStringWithDefault("Unknown Name");
 
         [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
         static extern IntPtr llama_sampler_name(IntPtr smpl);
@@ -904,4 +906,4 @@ public interface ICustomSampler
     /// Create a clone of this sampler
     /// </summary>
     ICustomSampler Clone();
-}
\ No newline at end of file
+}
diff --git a/LLama/Native/SafeMtmdEmbed.cs b/LLama/Native/SafeMtmdEmbed.cs
index 426b77c79..2e92b96a7 100644
--- a/LLama/Native/SafeMtmdEmbed.cs
+++ b/LLama/Native/SafeMtmdEmbed.cs
@@ -195,7 +195,7 @@ public string? Id
             return WithHandle(ptr =>
             {
                 var idPtr = NativeApi.mtmd_bitmap_get_id(ptr);
-                return NativeApi.PtrToStringUtf8(idPtr);
+                return idPtr.PtrToString();
             });
         }
         set
diff --git a/LLama/Native/SafeMtmdInputChunk.cs b/LLama/Native/SafeMtmdInputChunk.cs
index 3ab4f7ccc..fe08d50f7 100644
--- a/LLama/Native/SafeMtmdInputChunk.cs
+++ b/LLama/Native/SafeMtmdInputChunk.cs
@@ -1,5 +1,5 @@
 using System;
-using System.Runtime.InteropServices;
+using LLama.Extensions;
 
 namespace LLama.Native;
 
@@ -91,7 +91,7 @@ public string Id
             return WithHandle(ptr =>
             {
                 var idPtr = NativeApi.mtmd_input_chunk_get_id(ptr);
-                return Marshal.PtrToStringAnsi(idPtr) ?? string.Empty;
+                return idPtr.PtrToStringWithDefault(string.Empty);
             });
         }
     }

From 09fb90dbe7f6a8eb25a0f2cf236d336dcf5c0233 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Luis=20Santiago?=
Date: Sun, 26 Oct 2025 17:21:57 +0100
Subject: [PATCH 34/35] Fix bad DLL naming on Windows for the MTMD libraries

---
 LLama/LLamaSharp.Runtime.targets | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/LLama/LLamaSharp.Runtime.targets b/LLama/LLamaSharp.Runtime.targets
index 89caa042a..c99c508dd 100644
--- a/LLama/LLamaSharp.Runtime.targets
+++ b/LLama/LLamaSharp.Runtime.targets
@@ -396,19 +396,19 @@
       <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
-      <Link>runtimes/win-x64/native/noavx/libmtmd.dll</Link>
+      <Link>runtimes/win-x64/native/noavx/mtmd.dll</Link>
 
       <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
-      <Link>runtimes/win-x64/native/avx/libmtmd.dll</Link>
+      <Link>runtimes/win-x64/native/avx/mtmd.dll</Link>
 
       <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
-      <Link>runtimes/win-x64/native/avx2/libmtmd.dll</Link>
+      <Link>runtimes/win-x64/native/avx2/mtmd.dll</Link>
 
       <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
-      <Link>runtimes/win-x64/native/avx512/libmtmd.dll</Link>
+      <Link>runtimes/win-x64/native/avx512/mtmd.dll</Link>
 
       <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
@@ -416,7 +416,7 @@
       <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
-      <Link>runtimes/win-x64/native/vulkan/libmtmd.dll</Link>
+      <Link>runtimes/win-x64/native/vulkan/mtmd.dll</Link>
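The rename matters because .NET applies each platform's library-name conventions on its own. A quick way to verify the fix from managed code; this sketch assumes .NET 5+ for the NativeLibrary API and that the import name used by the bindings is "mtmd":

```csharp
using System;
using System.Runtime.InteropServices;

internal static class MtmdLoadCheck
{
    public static void Main()
    {
        // "mtmd" probes as mtmd.dll on Windows, libmtmd.so on Linux and
        // libmtmd.dylib on macOS. Windows never adds a "lib" prefix, which
        // is why a file shipped as libmtmd.dll could not be resolved there.
        if (NativeLibrary.TryLoad("mtmd", out var handle))
        {
            Console.WriteLine("mtmd resolved");
            NativeLibrary.Free(handle);
        }
        else
        {
            Console.WriteLine("mtmd not found on the probing path");
        }
    }
}
```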
From 43c5dc773f880e437047542631e30ac27673e2f0 Mon Sep 17 00:00:00 2001
From: SignalRT
Date: Sat, 13 Dec 2025 20:08:35 +0100
Subject: [PATCH 35/35] Change macOS build to ARM only

---
 .github/workflows/compile.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/compile.yml b/.github/workflows/compile.yml
index 5cd1f2097..6b160a101 100644
--- a/.github/workflows/compile.yml
+++ b/.github/workflows/compile.yml
@@ -460,7 +460,7 @@ jobs:
       matrix:
         include:
           - build: 'arm64'
-            defines: '-DCMAKE_OSX_ARCHITECTURES=arm64 -DGGML_METAL_EMBED_LIBRARY=ON'
+            defines: '-DCMAKE_OSX_ARCHITECTURES=arm64 -DGGML_METAL_EMBED_LIBRARY=ON -DGGML_METAL_USE_BF16=ON -DLLAMA_CURL=OFF -DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_TOOLS=OFF -DLLAMA_BUILD_BORINGSSL=ON'
           - build: 'x64'
             defines: '-DCMAKE_OSX_ARCHITECTURES=x86_64 -DGGML_METAL=OFF -DGGML_AVX=ON -DGGML_AVX2=ON'
           - build: 'x64-rosetta2'