From 6859e5795764b51d948f955df7058093056c09d2 Mon Sep 17 00:00:00 2001 From: SignalRT Date: Mon, 19 Jan 2026 23:45:17 +0100 Subject: [PATCH 1/8] Improve LLama.Web Initial version --- .gitignore | 2 + LLama.Web/App.razor | 13 + LLama.Web/Async/AsyncGuard.cs | 6 +- LLama.Web/Async/AsyncLock.cs | 6 +- LLama.Web/Common/ModelLoadType.cs | 10 +- LLama.Web/Common/ModelOptions.cs | 21 +- LLama.Web/Controllers/AttachmentController.cs | 58 ++ LLama.Web/Extensions.cs | 8 +- LLama.Web/Hubs/ISessionClient.cs | 3 + LLama.Web/Hubs/SessionConnectionHub.cs | 46 +- LLama.Web/LLama.Web.csproj | 1 + LLama.Web/Models/LLamaModel.cs | 28 +- LLama.Web/Models/ModelSession.cs | 187 +++- LLama.Web/Pages/Chat.razor | 941 ++++++++++++++++++ LLama.Web/Pages/Index.cshtml | 110 -- LLama.Web/Pages/Index.cshtml.cs | 45 - LLama.Web/Pages/Shared/_ChatTemplates.cshtml | 12 +- LLama.Web/Pages/Shared/_Layout.cshtml | 82 +- LLama.Web/Pages/Shared/_Parameters.cshtml | 61 +- LLama.Web/Pages/_Host.cshtml | 8 + LLama.Web/Program.cs | 16 +- LLama.Web/README.md | 11 + LLama.Web/Services/AttachmentService.cs | 359 +++++++ LLama.Web/Services/IAttachmentService.cs | 14 + LLama.Web/Services/IModelDownloadService.cs | 13 + LLama.Web/Services/IModelService.cs | 12 +- LLama.Web/Services/IModelSessionService.cs | 39 +- LLama.Web/Services/ModelDownloadService.cs | 301 ++++++ LLama.Web/Services/ModelLoaderService.cs | 17 +- LLama.Web/Services/ModelService.cs | 31 +- LLama.Web/Services/ModelSessionService.cs | 215 +++- LLama.Web/Shared/MainLayout.razor | 5 + LLama.Web/_Imports.razor | 13 + LLama.Web/appsettings.json | 74 +- LLama.Web/libman.json | 67 +- LLama.Web/wwwroot/css/site.css | 586 +++++++++-- LLama.Web/wwwroot/js/blazorChat.js | 193 ++++ LLama.Web/wwwroot/js/sessionConnectionChat.js | 398 +++++++- 38 files changed, 3548 insertions(+), 464 deletions(-) create mode 100644 LLama.Web/App.razor create mode 100644 LLama.Web/Controllers/AttachmentController.cs create mode 100644 LLama.Web/Pages/Chat.razor delete mode 100644 LLama.Web/Pages/Index.cshtml delete mode 100644 LLama.Web/Pages/Index.cshtml.cs create mode 100644 LLama.Web/Pages/_Host.cshtml create mode 100644 LLama.Web/Services/AttachmentService.cs create mode 100644 LLama.Web/Services/IAttachmentService.cs create mode 100644 LLama.Web/Services/IModelDownloadService.cs create mode 100644 LLama.Web/Services/ModelDownloadService.cs create mode 100644 LLama.Web/Shared/MainLayout.razor create mode 100644 LLama.Web/_Imports.razor create mode 100644 LLama.Web/wwwroot/js/blazorChat.js diff --git a/.gitignore b/.gitignore index 935538441..8bd6b31fc 100644 --- a/.gitignore +++ b/.gitignore @@ -355,3 +355,5 @@ site/ /Llama.Mobile/Resources/Raw /nuget_pack.bat /temp +LLama.Web/wwwroot/lib/ +LLama.Web/Models/ diff --git a/LLama.Web/App.razor b/LLama.Web/App.razor new file mode 100644 index 000000000..a50a5f9d1 --- /dev/null +++ b/LLama.Web/App.razor @@ -0,0 +1,13 @@ + + + + + + +
+

Page not found

+

The page you requested does not exist.

+
+
+
+
diff --git a/LLama.Web/Async/AsyncGuard.cs b/LLama.Web/Async/AsyncGuard.cs index ff6b6c43a..9dd0d425d 100644 --- a/LLama.Web/Async/AsyncGuard.cs +++ b/LLama.Web/Async/AsyncGuard.cs @@ -4,7 +4,7 @@ namespace LLama.Web.Async { /// - /// Creates a async/thread-safe guard helper + /// Creates an async/thread-safe guard helper. /// /// public class AsyncGuard : AsyncGuard @@ -26,7 +26,7 @@ public AsyncGuard() /// /// Guards this instance. /// - /// true if able to enter an guard, false if already guarded + /// true if able to enter a guard; false if already guarded. public bool Guard() { return _lockData.TryAdd(_key, true); @@ -74,7 +74,7 @@ public AsyncGuard() /// Guards the specified value. /// /// The value. - /// true if able to enter a guard for this value, false if this value is already guarded + /// true if able to enter a guard for this value; false if this value is already guarded. public bool Guard(T value) { return _lockData.TryAdd(value, true); diff --git a/LLama.Web/Async/AsyncLock.cs b/LLama.Web/Async/AsyncLock.cs index df294bf8b..6d839c1c1 100644 --- a/LLama.Web/Async/AsyncLock.cs +++ b/LLama.Web/Async/AsyncLock.cs @@ -1,7 +1,7 @@ namespace LLama.Web.Async { /// - /// Create an Async locking using statement + /// Creates an async lock for a using block. /// public sealed class AsyncLock { @@ -20,7 +20,7 @@ public AsyncLock() /// - /// Locks the using statement asynchronously. + /// Acquires the lock asynchronously. /// /// public Task LockAsync() @@ -34,7 +34,7 @@ public Task LockAsync() /// - /// IDisposable wrapper class to release the lock on dispose + /// IDisposable wrapper to release the lock on dispose. /// /// private sealed class Releaser : IDisposable diff --git a/LLama.Web/Common/ModelLoadType.cs b/LLama.Web/Common/ModelLoadType.cs index 9e1c77b7b..d4664637c 100644 --- a/LLama.Web/Common/ModelLoadType.cs +++ b/LLama.Web/Common/ModelLoadType.cs @@ -1,29 +1,29 @@ namespace LLama.Web.Common { /// - /// The type of model load caching to use + /// The type of model load caching to use. /// public enum ModelLoadType { /// - /// Only one model will be loaded into memory at a time, any other models will be unloaded before the new one is loaded + /// Only one model is loaded into memory at a time; any other models are unloaded before the new one is loaded. /// Single = 0, /// - /// Multiple models will be loaded into memory, ensure you use the ModelConfigs to split the hardware resources + /// Multiple models are loaded into memory; ensure you use the ModelConfigs to split hardware resources. /// Multiple = 1, /// - /// The first model in the appsettings.json list will be preloaded into memory at app startup + /// The first model in the appsettings.json list is preloaded into memory at app startup. /// PreloadSingle = 2, /// - /// All models in the appsettings.json list will be preloaded into memory at app startup, ensure you use the ModelConfigs to split the hardware resources + /// All models in the appsettings.json list are preloaded into memory at app startup; ensure you use the ModelConfigs to split hardware resources. /// PreloadMultiple = 3, } diff --git a/LLama.Web/Common/ModelOptions.cs b/LLama.Web/Common/ModelOptions.cs index 586db5611..da12b06f7 100644 --- a/LLama.Web/Common/ModelOptions.cs +++ b/LLama.Web/Common/ModelOptions.cs @@ -8,12 +8,12 @@ public class ModelOptions : ILLamaParams { /// - /// Model friendly name + /// Model-friendly name. /// public string Name { get; set; } /// - /// Max context insta=nces allowed per model + /// Max context instances allowed per model. /// public int MaxInstances { get; set; } @@ -47,6 +47,21 @@ public class ModelOptions /// public string ModelPath { get; set; } + /// + /// Download URL for the model file (GGUF). + /// + public string ModelDownloadUrl { get; set; } + + /// + /// Optional local path to the mmproj (multimodal projection) file. + /// + public string MmprojPath { get; set; } + + /// + /// Download URL for the mmproj file, if required. + /// + public string MmprojDownloadUrl { get; set; } + /// public int? Threads { get; set; } @@ -128,4 +143,4 @@ public class ModelOptions /// public LLamaAttentionType AttentionType { get; set; } = LLamaAttentionType.Unspecified; } -} \ No newline at end of file +} diff --git a/LLama.Web/Controllers/AttachmentController.cs b/LLama.Web/Controllers/AttachmentController.cs new file mode 100644 index 000000000..fe57dc6a2 --- /dev/null +++ b/LLama.Web/Controllers/AttachmentController.cs @@ -0,0 +1,58 @@ +using LLama.Web.Models; +using LLama.Web.Services; +using Microsoft.AspNetCore.Mvc; + +namespace LLama.Web.Controllers; + +[ApiController] +[Route("api/attachments")] +public class AttachmentController : ControllerBase +{ + private readonly IAttachmentService _attachmentService; + + public AttachmentController(IAttachmentService attachmentService) + { + _attachmentService = attachmentService; + } + + [HttpPost] + [RequestSizeLimit(256_000_000)] + public async Task> Upload([FromForm] string connectionId, [FromForm] List files, CancellationToken cancellationToken) + { + if (string.IsNullOrWhiteSpace(connectionId)) + return BadRequest("Missing connectionId."); + + if (files is null || files.Count == 0) + return BadRequest("No files provided."); + + try + { + var result = await _attachmentService.SaveAsync(connectionId, files, cancellationToken); + return Ok(result); + } + catch (InvalidOperationException ex) + { + return BadRequest(ex.Message); + } + } + + [HttpGet("{connectionId}/{id}")] + public ActionResult Download(string connectionId, string id) + { + if (string.IsNullOrWhiteSpace(connectionId) || string.IsNullOrWhiteSpace(id)) + return BadRequest("Missing connectionId or attachment id."); + + var attachment = _attachmentService.GetAttachment(connectionId, id); + if (attachment == null || string.IsNullOrWhiteSpace(attachment.FilePath)) + return NotFound(); + + if (!System.IO.File.Exists(attachment.FilePath)) + return NotFound(); + + var contentType = string.IsNullOrWhiteSpace(attachment.ContentType) + ? "application/octet-stream" + : attachment.ContentType; + + return PhysicalFile(attachment.FilePath, contentType, attachment.FileName); + } +} diff --git a/LLama.Web/Extensions.cs b/LLama.Web/Extensions.cs index 1fe08f6eb..d8995f082 100644 --- a/LLama.Web/Extensions.cs +++ b/LLama.Web/Extensions.cs @@ -2,10 +2,10 @@ namespace LLama.Web; -public static class Extensions +public static class Extensions { /// - /// Combines the AntiPrompts list and AntiPrompt csv + /// Combines the AntiPrompts list and AntiPrompt CSV. /// /// The session configuration. /// Combined AntiPrompts with duplicates removed @@ -15,7 +15,7 @@ public static List GetAntiPrompts(this ISessionConfig sessionConfig) } /// - /// Combines the OutputFilters list and OutputFilter csv + /// Combines the OutputFilters list and OutputFilter CSV. /// /// The session configuration. /// Combined OutputFilters with duplicates removed @@ -25,7 +25,7 @@ public static List GetOutputFilters(this ISessionConfig sessionConfig) } /// - /// Combines a string list and a csv and removes duplicates + /// Combines a string list and a CSV, and removes duplicates. /// /// The list. /// The CSV. diff --git a/LLama.Web/Hubs/ISessionClient.cs b/LLama.Web/Hubs/ISessionClient.cs index 92302b219..552fa6133 100644 --- a/LLama.Web/Hubs/ISessionClient.cs +++ b/LLama.Web/Hubs/ISessionClient.cs @@ -7,5 +7,8 @@ public interface ISessionClient { Task OnStatus(string connectionId, SessionConnectionStatus status); Task OnError(string error); + Task OnModelDownloadSnapshot(IReadOnlyList snapshots); + Task OnModelDownloadProgress(ModelDownloadProgress progress); + Task OnStorageInfo(StorageInfo info); } } diff --git a/LLama.Web/Hubs/SessionConnectionHub.cs b/LLama.Web/Hubs/SessionConnectionHub.cs index b6a977e36..425aaf8e1 100644 --- a/LLama.Web/Hubs/SessionConnectionHub.cs +++ b/LLama.Web/Hubs/SessionConnectionHub.cs @@ -9,19 +9,30 @@ public class SessionConnectionHub : Hub { private readonly ILogger _logger; private readonly IModelSessionService _modelSessionService; + private readonly IModelDownloadService _modelDownloadService; + private readonly IAttachmentService _attachmentService; - public SessionConnectionHub(ILogger logger, IModelSessionService modelSessionService) + public SessionConnectionHub(ILogger logger, IModelSessionService modelSessionService, IModelDownloadService modelDownloadService, IAttachmentService attachmentService) { _logger = logger; _modelSessionService = modelSessionService; + _modelDownloadService = modelDownloadService; + _attachmentService = attachmentService; } public override async Task OnConnectedAsync() { _logger.Log(LogLevel.Information, "[OnConnectedAsync], Id: {0}", Context.ConnectionId); - // Notify client of successful connection + // Notify the client of a successful connection. await Clients.Caller.OnStatus(Context.ConnectionId, SessionConnectionStatus.Connected); + await Clients.Caller.OnModelDownloadSnapshot(_modelDownloadService.GetSnapshots()); + await Clients.Caller.OnStorageInfo(new StorageInfo + { + ModelsPath = _modelDownloadService.ModelsRoot, + DownloadsPath = _modelDownloadService.DownloadsRoot, + UploadsPath = _attachmentService.UploadsRoot + }); await base.OnConnectedAsync(); } @@ -29,8 +40,9 @@ public override async Task OnDisconnectedAsync(Exception exception) { _logger.Log(LogLevel.Information, "[OnDisconnectedAsync], Id: {0}", Context.ConnectionId); - // Remove connections session on disconnect + // Remove the connection session on disconnect. await _modelSessionService.CloseAsync(Context.ConnectionId); + await _attachmentService.CleanupAsync(Context.ConnectionId); await base.OnDisconnectedAsync(exception); } @@ -40,24 +52,38 @@ public async Task OnLoadModel(SessionConfig sessionConfig, InferenceOptions infe _logger.Log(LogLevel.Information, "[OnLoadModel] - Load new model, Connection: {0}", Context.ConnectionId); await _modelSessionService.CloseAsync(Context.ConnectionId); - // Create model session - var modelSession = await _modelSessionService.CreateAsync(Context.ConnectionId, sessionConfig, inferenceConfig); - if (modelSession is null) + try + { + // Create model session. + var modelSession = await _modelSessionService.CreateAsync(Context.ConnectionId, sessionConfig, inferenceConfig); + if (modelSession is null) + { + await Clients.Caller.OnError("Failed to create model session"); + return; + } + } + catch (Exception ex) { - await Clients.Caller.OnError("Failed to create model session"); + await Clients.Caller.OnError(ex.Message); return; } - // Notify client + // Notify client. await Clients.Caller.OnStatus(Context.ConnectionId, SessionConnectionStatus.Loaded); } [HubMethodName("SendPrompt")] - public IAsyncEnumerable OnSendPrompt(string prompt, InferenceOptions inferConfig, CancellationToken cancellationToken) + public IAsyncEnumerable OnSendPrompt(PromptRequest request, InferenceOptions inferConfig, CancellationToken cancellationToken) { _logger.Log(LogLevel.Information, "[OnSendPrompt] - New prompt received, Connection: {0}", Context.ConnectionId); var linkedCancelationToken = CancellationTokenSource.CreateLinkedTokenSource(Context.ConnectionAborted, cancellationToken); - return _modelSessionService.InferAsync(Context.ConnectionId, prompt, inferConfig, linkedCancelationToken.Token); + return _modelSessionService.InferAsync(Context.ConnectionId, request, inferConfig, linkedCancelationToken.Token); + } + + [HubMethodName("GetModelStatuses")] + public Task> GetModelStatuses() + { + return Task.FromResult(_modelDownloadService.GetSnapshots()); } } diff --git a/LLama.Web/LLama.Web.csproj b/LLama.Web/LLama.Web.csproj index 06311bd07..eb9354d40 100644 --- a/LLama.Web/LLama.Web.csproj +++ b/LLama.Web/LLama.Web.csproj @@ -17,6 +17,7 @@ + diff --git a/LLama.Web/Models/LLamaModel.cs b/LLama.Web/Models/LLamaModel.cs index af228bd4c..ba9aa4c8d 100644 --- a/LLama.Web/Models/LLamaModel.cs +++ b/LLama.Web/Models/LLamaModel.cs @@ -1,4 +1,6 @@ +using LLama; using LLama.Abstractions; +using LLama.Native; using LLama.Web.Common; using System.Collections.Concurrent; @@ -13,6 +15,7 @@ public class LLamaModel : IDisposable private readonly ILogger _llamaLogger; private readonly ModelOptions _config; private readonly LLamaWeights _weights; + private readonly MtmdWeights _mtmdWeights; private readonly ConcurrentDictionary _contexts; /// @@ -21,11 +24,12 @@ public class LLamaModel : IDisposable /// The model parameters. /// A logger class. /// Model weights. - private LLamaModel(ModelOptions modelParams, ILogger llamaLogger, LLamaWeights weights) + private LLamaModel(ModelOptions modelParams, ILogger llamaLogger, LLamaWeights weights, MtmdWeights mtmdWeights) { _config = modelParams; _llamaLogger = llamaLogger; _weights = weights; + _mtmdWeights = mtmdWeights; _contexts = new ConcurrentDictionary(); } @@ -37,7 +41,21 @@ private LLamaModel(ModelOptions modelParams, ILogger llamaLogger, LLamaWeights w public static async Task CreateAsync(ModelOptions modelParams, ILogger llamaLogger) { var weights = await LLamaWeights.LoadFromFileAsync(modelParams); - return new LLamaModel(modelParams, llamaLogger, weights); + + MtmdWeights mtmdWeights = null; + if (!string.IsNullOrWhiteSpace(modelParams.MmprojPath)) + { + if (!File.Exists(modelParams.MmprojPath)) + throw new FileNotFoundException($"Mmproj file not found: {modelParams.MmprojPath}"); + + var mtmdParams = MtmdContextParams.Default(); + mtmdParams.UseGpu = false; + if (modelParams.Threads.HasValue) + mtmdParams.NThreads = modelParams.Threads.Value; + mtmdWeights = await MtmdWeights.LoadFromFileAsync(modelParams.MmprojPath, weights, mtmdParams); + } + + return new LLamaModel(modelParams, llamaLogger, weights, mtmdWeights); } /// @@ -50,6 +68,11 @@ public static async Task CreateAsync(ModelOptions modelParams, ILogg /// public LLamaWeights LLamaWeights => _weights; + /// + /// Gets the multimodal weights, if configured. + /// + public MtmdWeights MtmdWeights => _mtmdWeights; + /// /// Gets the context count. /// @@ -119,5 +142,6 @@ public void Dispose() context?.Dispose(); } _weights.Dispose(); + _mtmdWeights?.Dispose(); } } diff --git a/LLama.Web/Models/ModelSession.cs b/LLama.Web/Models/ModelSession.cs index 71e799ca5..53f76d8c7 100644 --- a/LLama.Web/Models/ModelSession.cs +++ b/LLama.Web/Models/ModelSession.cs @@ -1,4 +1,7 @@ +using LLama; using LLama.Abstractions; +using LLama.Common; +using LLama.Native; using LLama.Sampling; using LLama.Web.Common; @@ -13,6 +16,10 @@ public class ModelSession private readonly ISessionConfig _sessionConfig; private readonly ITextStreamTransform _outputTransform; private readonly InferenceOptions _defaultInferenceConfig; + private readonly ChatHistory _chatHistory = new(); + private readonly string? _systemPrompt; + private readonly string? _templateUserMarker; + private readonly object _historyLock = new(); private CancellationTokenSource _cancellationTokenSource; @@ -28,6 +35,11 @@ public ModelSession(LLamaModel model, LLamaContext context, string sessionId, IS }; _outputTransform = CreateOutputFilter(); _executor = CreateExecutor(); + _systemPrompt = _sessionConfig.Prompt; + _templateUserMarker = ResolveTemplateUserMarker(); + + if (_sessionConfig.ExecutorType != LLamaExecutorType.Stateless && !string.IsNullOrWhiteSpace(_systemPrompt)) + _chatHistory.AddMessage(AuthorRole.System, _systemPrompt); } /// @@ -55,6 +67,43 @@ public ModelSession(LLamaModel model, LLamaContext context, string sessionId, IS /// public InferenceOptions InferenceParams => _defaultInferenceConfig; + /// + /// Returns true if the executor has multimodal support enabled. + /// + public bool IsMultiModal => _executor.IsMultiModal; + + /// + /// Returns true if the multimodal model supports audio inputs. + /// + public bool SupportsAudio => _executor.ClipModel?.SupportsAudio ?? false; + + /// + /// Returns true if the multimodal model supports vision inputs. + /// + public bool SupportsVision => _executor.ClipModel?.SupportsVision ?? false; + + /// + /// Gets the MTMD clip model, if available. + /// + public MtmdWeights ClipModel => _executor.ClipModel; + + /// + /// Returns true if the session should preserve chat history. + /// + public bool StoresHistory => _sessionConfig.ExecutorType != LLamaExecutorType.Stateless; + + /// + /// Queue media embeddings for the next inference call. + /// + public void QueueEmbeds(IEnumerable embeds) + { + if (embeds is null) + return; + + foreach (var embed in embeds) + _executor.Embeds.Add(embed); + } + /// /// Initializes the prompt. /// @@ -65,17 +114,15 @@ internal async Task InitializePrompt(InferenceOptions inferenceConfig = null, Ca if (_sessionConfig.ExecutorType == LLamaExecutorType.Stateless) return; - if (string.IsNullOrEmpty(_sessionConfig.Prompt)) + if (_chatHistory.Messages.Count == 0) return; - // Run Initial prompt - var inferenceParams = ConfigureInferenceParams(inferenceConfig); - _cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - await foreach (var _ in _executor.InferAsync(_sessionConfig.Prompt, inferenceParams, _cancellationTokenSource.Token)) + if (_executor is StatefulExecutorBase statefulExecutor) { - // We dont really need the response of the initial prompt, so exit on first token - break; - }; + var prompt = FormatChatHistory(_chatHistory, addAssistant: false); + if (!string.IsNullOrWhiteSpace(prompt)) + await statefulExecutor.PrefillPromptAsync(prompt, cancellationToken); + } } /// @@ -107,6 +154,36 @@ public bool IsInferCanceled() return _cancellationTokenSource.IsCancellationRequested; } + /// + /// Builds a prompt for the current chat session using the model template. + /// + /// The user content to include in the prompt. + /// The formatted prompt to send to the model. + internal string BuildPrompt(string userContent) + { + userContent ??= string.Empty; + + if (_sessionConfig.ExecutorType == LLamaExecutorType.Stateless) + return BuildStatelessPrompt(userContent); + + return BuildChatDelta(userContent); + } + + /// + /// Adds the assistant response to the chat history. + /// + /// Assistant response content. + internal void AddAssistantMessage(string assistantContent) + { + if (_sessionConfig.ExecutorType == LLamaExecutorType.Stateless) + return; + + lock (_historyLock) + { + _chatHistory.AddMessage(AuthorRole.Assistant, assistantContent ?? string.Empty); + } + } + /// /// Configures the inference parameters. /// @@ -114,7 +191,18 @@ public bool IsInferCanceled() private IInferenceParams ConfigureInferenceParams(InferenceOptions inferenceConfig) { var inferenceParams = inferenceConfig ?? _defaultInferenceConfig; - inferenceParams.AntiPrompts = _sessionConfig.GetAntiPrompts(); + var antiPrompts = _sessionConfig.GetAntiPrompts(); + if (!string.IsNullOrWhiteSpace(_templateUserMarker)) + { + if (!antiPrompts.Contains(_templateUserMarker)) + antiPrompts.Add(_templateUserMarker); + + var trimmedMarker = _templateUserMarker.Trim(); + if (!string.IsNullOrWhiteSpace(trimmedMarker) && !antiPrompts.Contains(trimmedMarker)) + antiPrompts.Add(trimmedMarker); + } + + inferenceParams.AntiPrompts = antiPrompts; return inferenceParams; } @@ -131,10 +219,87 @@ private ILLamaExecutor CreateExecutor() { return _sessionConfig.ExecutorType switch { - LLamaExecutorType.Interactive => new InteractiveExecutor(_context), - LLamaExecutorType.Instruct => new InstructExecutor(_context), + LLamaExecutorType.Interactive => _model.MtmdWeights is null + ? new InteractiveExecutor(_context) + : new InteractiveExecutor(_context, _model.MtmdWeights), + LLamaExecutorType.Instruct => _model.MtmdWeights is null + ? new InstructExecutor(_context) + : new InstructExecutor(_context, _model.MtmdWeights), LLamaExecutorType.Stateless => new StatelessExecutor(_model.LLamaWeights, _context.Params), _ => default }; } + + private string BuildChatDelta(string userContent) + { + lock (_historyLock) + { + var pastPrompt = FormatChatHistory(_chatHistory, addAssistant: false); + _chatHistory.AddMessage(AuthorRole.User, userContent); + var fullPrompt = FormatChatHistory(_chatHistory, addAssistant: true); + + if (!fullPrompt.StartsWith(pastPrompt, StringComparison.Ordinal)) + return fullPrompt; + + var delta = fullPrompt[pastPrompt.Length..]; + if (pastPrompt.Length > 0 && pastPrompt.EndsWith('\n') && (delta.Length == 0 || delta[0] != '\n')) + delta = "\n" + delta; + + return delta; + } + } + + private string BuildStatelessPrompt(string userContent) + { + var history = new ChatHistory(); + if (!string.IsNullOrWhiteSpace(_systemPrompt)) + history.AddMessage(AuthorRole.System, _systemPrompt); + history.AddMessage(AuthorRole.User, userContent); + + return FormatChatHistory(history, addAssistant: true); + } + + private string FormatChatHistory(ChatHistory history, bool addAssistant) + { + var template = new LLamaTemplate(_model.LLamaWeights.NativeHandle) + { + AddAssistant = addAssistant + }; + + foreach (var message in history.Messages) + template.Add(message.AuthorRole.ToString().ToLowerInvariant(), message.Content); + + return LLamaTemplate.Encoding.GetString(template.Apply()); + } + + private string? ResolveTemplateUserMarker() + { + const string userMarkerA = "__LLAMA_USER_A__"; + const string assistantMarkerA = "__LLAMA_ASSISTANT_A__"; + const string userMarkerB = "__LLAMA_USER_B__"; + + try + { + var template = new LLamaTemplate(_model.LLamaWeights.NativeHandle) + { + AddAssistant = false + }; + template.Add("user", userMarkerA); + template.Add("assistant", assistantMarkerA); + template.Add("user", userMarkerB); + + var rendered = LLamaTemplate.Encoding.GetString(template.Apply()); + var assistantIndex = rendered.IndexOf(assistantMarkerA, StringComparison.Ordinal); + var userIndex = rendered.IndexOf(userMarkerB, StringComparison.Ordinal); + if (assistantIndex < 0 || userIndex <= assistantIndex) + return null; + + var between = rendered.Substring(assistantIndex + assistantMarkerA.Length, userIndex - (assistantIndex + assistantMarkerA.Length)); + return string.IsNullOrWhiteSpace(between) ? null : between; + } + catch + { + return null; + } + } } diff --git a/LLama.Web/Pages/Chat.razor b/LLama.Web/Pages/Chat.razor new file mode 100644 index 000000000..ec41f2a39 --- /dev/null +++ b/LLama.Web/Pages/Chat.razor @@ -0,0 +1,941 @@ +@page "/" +@implements IAsyncDisposable +@inject IModelDownloadService ModelDownloadService +@inject IModelSessionService ModelSessionService +@inject IAttachmentService AttachmentService +@inject IOptions OptionsAccessor +@inject IJSRuntime JS + +
+ + +
+
+
+

Chat Playground

+ Streaming, multimodal, and markdown-ready. +
+
+ @_statusText +
+
+ +
+ @foreach (var message in _messages) + { +
+
@message.Avatar
+
+
+ @if (message.Attachments.Any(attachment => attachment.Kind == AttachmentKind.Image)) + { +
+ @foreach (var attachment in message.Attachments.Where(attachment => attachment.Kind == AttachmentKind.Image)) + { + @attachment.FileName + } +
+ } + @if (message.Attachments.Any(attachment => attachment.Kind == AttachmentKind.Audio)) + { +
+ @foreach (var attachment in message.Attachments.Where(attachment => attachment.Kind == AttachmentKind.Audio)) + { +
+
@attachment.FileName
+ +
+ } +
+ } + @if (message.Attachments.Any()) + { +
+ @foreach (var attachment in message.Attachments) + { + @attachment.FileName + } +
+ } + @if (!string.IsNullOrWhiteSpace(message.Meta)) + { +
@message.Meta
+ } +
+
+ } +
+ + @if (!string.IsNullOrWhiteSpace(_inputText)) + { +
+
+
U
+
+
@_inputText
+
+
+
+ } + +
+
+
+ + @if (_modelCapabilities.SupportsAudio && _audioRecordingSupported) + { + + } + +
+ +
+ @if (_pendingFiles.Count > 0) + { +
+ @foreach (var file in _pendingFiles) + { + @file.Name + } +
+ } + @if (_hasRecordedAudio) + { +
+
+
@_recordedAudioName
+ +
+
+
+
+ + +
+
+ } + + @if (!string.IsNullOrWhiteSpace(_composerError)) + { +
@_composerError
+ } +
+
+ + +
+
+
+
+
+ +@code { + private readonly List _messages = new(); + private readonly HashSet _pendingRenders = new(); + private readonly List _pendingFiles = new(); + private readonly List _downloadSnapshots = new(); + + private ElementReference _threadRef; + private bool _initialized; + private bool _sessionLoaded; + private bool _isSending; + private bool _scrollToBottom; + private string _inputText = string.Empty; + private string _statusText = "Model not loaded"; + private string _composerHint = "Select a model and start a session."; + private string _composerError = string.Empty; + private string _acceptedUploadTypes = "application/pdf,application/vnd.openxmlformats-officedocument.wordprocessingml.document,image/*"; + private SessionConfig _sessionConfig = new(); + private InferenceOptions _inferenceOptions = new(); + private SamplingSettings _samplingSettings = new(); + private string _sessionId = Guid.NewGuid().ToString("N"); + private IReadOnlyList _models = Array.Empty(); + private StorageInfo _storageInfo; + private PeriodicTimer _downloadTimer; + private CancellationTokenSource _downloadCts; + private SidebarTab _activeTab = SidebarTab.Models; + private ModelCapabilities _modelCapabilities = new ModelCapabilities { SupportsText = true }; + private string _capabilitiesHint = "Load a model session to inspect capabilities."; + private bool _audioRecordingSupported; + private bool _isRecordingAudio; + private bool _hasRecordedAudio; + private string _recordedAudioName = string.Empty; + private string _recordedAudioPreviewUrl = string.Empty; + private MemoryBrowserFile _recordedAudioFile; + + private bool _canLoad => !_sessionLoaded && IsModelReady(_sessionConfig.Model); + + protected override Task OnInitializedAsync() + { + _models = OptionsAccessor.Value.Models ?? new List(); + _sessionConfig = new SessionConfig + { + Model = _models.FirstOrDefault()?.Name, + ExecutorType = LLamaExecutorType.Interactive, + Prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.", + AntiPrompt = "User:,<|im_end|>,<|eot_id|>,<|endoftext|>", + OutputFilter = "User:, Assistant:,<|im_end|>,<|eot_id|>,<|endoftext|>" + }; + + _samplingSettings = SamplingSettings.CreateDefault(); + _inferenceOptions = BuildInferenceOptions(); + + _storageInfo = new StorageInfo + { + ModelsPath = ModelDownloadService.ModelsRoot, + DownloadsPath = ModelDownloadService.DownloadsRoot, + UploadsPath = AttachmentService.UploadsRoot + }; + + _downloadCts = new CancellationTokenSource(); + _downloadTimer = new PeriodicTimer(TimeSpan.FromMilliseconds(800)); + _ = PollDownloadsAsync(_downloadCts.Token); + + _acceptedUploadTypes = BuildAcceptedUploadTypes(); + return Task.CompletedTask; + } + + protected override async Task OnAfterRenderAsync(bool firstRender) + { + if (firstRender) + { + await JS.InvokeVoidAsync("llamaChat.initialize"); + _audioRecordingSupported = await JS.InvokeAsync("llamaChat.isAudioRecordingSupported"); + _initialized = true; + var storedTab = await JS.InvokeAsync("llamaChat.getSidebarTab"); + if (!string.IsNullOrWhiteSpace(storedTab) && Enum.TryParse(storedTab, out var tab)) + { + _activeTab = tab; + } + } + + if (_initialized && _pendingRenders.Count > 0) + { + var toRender = _messages.Where(message => _pendingRenders.Contains(message.Id)).ToList(); + foreach (var message in toRender) + { + await JS.InvokeVoidAsync("llamaChat.renderMarkdown", message.ContentRef, message.Content); + _pendingRenders.Remove(message.Id); + } + } + + if (_scrollToBottom) + { + await JS.InvokeVoidAsync("llamaChat.scrollToBottom", _threadRef); + _scrollToBottom = false; + } + } + + private async Task PollDownloadsAsync(CancellationToken cancellationToken) + { + while (await _downloadTimer.WaitForNextTickAsync(cancellationToken)) + { + _downloadSnapshots.Clear(); + _downloadSnapshots.AddRange(ModelDownloadService.GetSnapshots()); + _composerHint = IsModelReady(_sessionConfig.Model) + ? "Ready to chat." + : "Model is downloading..."; + await InvokeAsync(StateHasChanged); + } + } + + private async Task LoadSessionAsync() + { + if (!IsModelReady(_sessionConfig.Model)) + { + AddSystemMessage("Selected model is still downloading.", MessageKind.Error); + return; + } + + await UnloadSessionAsync(); + + try + { + _inferenceOptions = BuildInferenceOptions(); + await ModelSessionService.CreateAsync(_sessionId, _sessionConfig, _inferenceOptions); + _sessionLoaded = true; + _statusText = $"Model loaded: {_sessionConfig.Model}"; + _composerHint = "Ask about images, PDFs, or Word documents."; + await RefreshCapabilitiesAsync(); + } + catch (Exception ex) + { + AddSystemMessage($"Failed to start session: {ex.Message}", MessageKind.Error); + _sessionLoaded = false; + _statusText = "Model not loaded"; + } + } + + private async Task UnloadSessionAsync() + { + if (!_sessionLoaded) + return; + + await ModelSessionService.CloseAsync(_sessionId); + await AttachmentService.CleanupAsync(_sessionId); + _sessionLoaded = false; + _statusText = "Model not loaded"; + _modelCapabilities = new ModelCapabilities { SupportsText = true }; + _capabilitiesHint = "Load a model session to inspect capabilities."; + } + + private async Task SendAsync() + { + if (!_sessionLoaded || _isSending) + return; + + if (string.IsNullOrWhiteSpace(_inputText)) + return; + + if (HasImageAttachments() && !_modelCapabilities.SupportsVision) + { + _composerError = "This model does not support images. Remove image attachments or choose a multimodal model."; + return; + } + + if (HasAudioAttachments() && !_modelCapabilities.SupportsAudio) + { + _composerError = "This model does not support audio. Remove audio attachments or choose an audio-capable model."; + return; + } + + var prompt = _inputText.Trim(); + _inputText = string.Empty; + _isSending = true; + _composerError = string.Empty; + + var attachments = await UploadAttachmentsAsync(); + + var userMessage = new ChatMessageView(MessageKind.User, prompt) + { + Attachments = attachments + }; + AddMessage(userMessage); + + var assistantMessage = new ChatMessageView(MessageKind.Assistant, string.Empty); + AddMessage(assistantMessage); + + var request = new PromptRequest + { + Prompt = prompt, + AttachmentIds = attachments.Select(a => a.Id).ToList() + }; + + try + { + _inferenceOptions = BuildInferenceOptions(); + await foreach (var token in ModelSessionService.InferAsync(_sessionId, request, _inferenceOptions)) + { + if (token.TokenType == TokenType.Content) + { + assistantMessage.Content += token.Content; + MarkForRender(assistantMessage); + } + else if (token.TokenType == TokenType.End || token.TokenType == TokenType.Cancel) + { + assistantMessage.Meta = token.Content; + } + + _scrollToBottom = true; + await InvokeAsync(StateHasChanged); + } + } + catch (Exception ex) + { + AddSystemMessage($"Failed to invoke prompt: {ex.Message}", MessageKind.Error); + } + finally + { + _isSending = false; + } + } + + private async Task CancelAsync() + { + if (!_isSending) + return; + + await ModelSessionService.CancelAsync(_sessionId); + _isSending = false; + AddSystemMessage("Generation cancelled.", MessageKind.Info); + } + + private async Task> UploadAttachmentsAsync() + { + if (_pendingFiles.Count == 0) + return new List(); + + try + { + AddSystemMessage("Uploading attachments...", MessageKind.Info); + var result = await AttachmentService.SaveAsync(_sessionId, _pendingFiles, CancellationToken.None); + _pendingFiles.Clear(); + return result.Attachments.ToList(); + } + catch (Exception ex) + { + AddSystemMessage($"Attachment upload failed: {ex.Message}", MessageKind.Error); + return new List(); + } + } + + private void OnFilesSelected(InputFileChangeEventArgs args) + { + _pendingFiles.Clear(); + var invalidFiles = new List(); + foreach (var file in args.GetMultipleFiles()) + { + if (IsAllowedUpload(file)) + { + _pendingFiles.Add(file); + } + else + { + invalidFiles.Add(file.Name); + } + } + + if (invalidFiles.Count > 0) + { + var capabilityNote = BuildCapabilitiesHint(); + AddSystemMessage($"Unsupported files ignored: {string.Join(", ", invalidFiles)}. {capabilityNote}", MessageKind.Info); + } + _composerError = string.Empty; + } + + private void ClearAttachments() + { + _pendingFiles.Clear(); + ClearRecordedAudio(); + _composerError = string.Empty; + } + + private bool IsAllowedUpload(IBrowserFile file) + { + var extension = Path.GetExtension(file.Name).ToLowerInvariant(); + var contentType = file.ContentType?.ToLowerInvariant() ?? string.Empty; + + if (extension == ".doc") + return false; + + if (extension == ".docx" || contentType.Contains("officedocument.wordprocessingml")) + return true; + + if (extension == ".pdf" || contentType.Contains("pdf")) + return true; + + if (contentType.StartsWith("image/")) + return true; + + if (contentType.StartsWith("audio/") || extension is ".wav" or ".mp3" or ".m4a" or ".ogg" or ".flac" or ".webm") + return true; + + return false; + } + + private bool HasImageAttachments() + { + foreach (var file in _pendingFiles) + { + var contentType = file.ContentType?.ToLowerInvariant() ?? string.Empty; + var extension = Path.GetExtension(file.Name).ToLowerInvariant(); + if (contentType.StartsWith("image/") || extension is ".png" or ".jpg" or ".jpeg" or ".webp" or ".bmp" or ".gif") + return true; + } + + return false; + } + + private bool HasAudioAttachments() + { + foreach (var file in _pendingFiles) + { + var contentType = file.ContentType?.ToLowerInvariant() ?? string.Empty; + var extension = Path.GetExtension(file.Name).ToLowerInvariant(); + if (contentType.StartsWith("audio/") || extension is ".wav" or ".mp3" or ".m4a" or ".ogg" or ".flac" or ".webm") + return true; + } + + return false; + } + + private async Task RefreshCapabilitiesAsync() + { + if (!_sessionLoaded) + return; + + _modelCapabilities = await ModelSessionService.GetCapabilitiesAsync(_sessionId); + _capabilitiesHint = _modelCapabilities.SupportsVision || _modelCapabilities.SupportsAudio + ? "Multimodal inputs enabled for this model." + : "Text-only model loaded."; + _acceptedUploadTypes = BuildAcceptedUploadTypes(); + } + + private string BuildAcceptedUploadTypes() + { + var types = new List + { + "application/pdf", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document" + }; + + if (_modelCapabilities.SupportsVision) + types.Add("image/*"); + + types.Add("audio/*"); + + return string.Join(",", types); + } + + private string BuildCapabilitiesHint() + { + var parts = new List { "Use PDF or DOCX" }; + + if (_modelCapabilities.SupportsVision) + parts.Add("images"); + + if (_modelCapabilities.SupportsAudio) + parts.Add("audio"); + + return $"{string.Join(", ", parts)}."; + } + + private async Task HandleKeyPress(KeyboardEventArgs args) + { + if (args.Key == "Enter" && !args.ShiftKey) + { + await SendAsync(); + } + } + + private async Task ToggleAudioRecordingAsync() + { + if (_isRecordingAudio) + { + await StopAudioRecordingAsync(); + return; + } + + if (!_modelCapabilities.SupportsAudio || !_audioRecordingSupported) + return; + + var result = await JS.InvokeAsync("llamaChat.startAudioRecording"); + if (!result.Started) + { + AddSystemMessage(result.Error ?? "Unable to start audio recording.", MessageKind.Error); + return; + } + + _isRecordingAudio = true; + _composerError = string.Empty; + } + + private async Task StopAudioRecordingAsync() + { + if (!_isRecordingAudio) + return; + + var result = await JS.InvokeAsync("llamaChat.stopAudioRecording"); + _isRecordingAudio = false; + + if (result == null || string.IsNullOrWhiteSpace(result.Base64)) + { + AddSystemMessage("No audio captured.", MessageKind.Info); + return; + } + + var bytes = Convert.FromBase64String(result.Base64); + var mimeType = string.IsNullOrWhiteSpace(result.MimeType) ? "audio/webm" : result.MimeType; + var extension = GetAudioExtension(mimeType); + var fileName = $"recording-{DateTimeOffset.Now:yyyyMMdd-HHmmss}{extension}"; + + _recordedAudioFile = new MemoryBrowserFile(fileName, mimeType, bytes, DateTimeOffset.Now); + _recordedAudioName = fileName; + _recordedAudioPreviewUrl = $"data:{mimeType};base64,{result.Base64}"; + _hasRecordedAudio = true; + _composerError = string.Empty; + } + + private void AcceptRecordedAudio() + { + if (_recordedAudioFile == null) + return; + + _pendingFiles.Add(_recordedAudioFile); + ClearRecordedAudio(); + _composerError = string.Empty; + } + + private void DiscardRecordedAudio() + { + ClearRecordedAudio(); + } + + private void ClearRecordedAudio() + { + _recordedAudioFile = null; + _recordedAudioName = string.Empty; + _recordedAudioPreviewUrl = string.Empty; + _hasRecordedAudio = false; + } + + private static string GetAudioExtension(string contentType) + { + var normalized = contentType.ToLowerInvariant(); + if (normalized.Contains("wav")) + return ".wav"; + if (normalized.Contains("mpeg") || normalized.Contains("mp3")) + return ".mp3"; + if (normalized.Contains("ogg")) + return ".ogg"; + if (normalized.Contains("mp4") || normalized.Contains("m4a")) + return ".m4a"; + if (normalized.Contains("webm")) + return ".webm"; + + return ".audio"; + } + + private bool IsModelReady(string modelName) + { + if (string.IsNullOrWhiteSpace(modelName)) + return false; + + return ModelDownloadService.IsModelReady(modelName); + } + + private void AddSystemMessage(string message, MessageKind kind) + { + AddMessage(new ChatMessageView(kind, message)); + } + + private void AddMessage(ChatMessageView message) + { + _messages.Add(message); + MarkForRender(message); + _scrollToBottom = true; + } + + private void MarkForRender(ChatMessageView message) + { + _pendingRenders.Add(message.Id); + } + + private static string FormatBytes(long? value) + { + if (!value.HasValue) + return "-"; + + var bytes = value.Value; + if (bytes <= 0) + return "0 B"; + + var sizes = new[] { "B", "KB", "MB", "GB", "TB" }; + var order = (int)Math.Floor(Math.Log(bytes, 1024)); + var adjusted = bytes / Math.Pow(1024, order); + return $"{adjusted:0.0} {sizes[order]}"; + } + + private InferenceOptions BuildInferenceOptions() + { + return new InferenceOptions + { + MaxTokens = _samplingSettings.MaxTokens, + DecodeSpecialTokens = true, + SamplingPipeline = new DefaultSamplingPipeline + { + Temperature = _samplingSettings.Temperature, + TopP = _samplingSettings.TopP, + TopK = _samplingSettings.TopK, + RepeatPenalty = _samplingSettings.RepeatPenalty, + PresencePenalty = _samplingSettings.PresencePenalty, + FrequencyPenalty = _samplingSettings.FrequencyPenalty, + PenaltyCount = _samplingSettings.PenaltyCount, + TypicalP = _samplingSettings.TypicalP, + MinP = _samplingSettings.MinP, + PreventEOS = _samplingSettings.PreventEos, + PenalizeNewline = _samplingSettings.PenalizeNewline + } + }; + } + + public async ValueTask DisposeAsync() + { + _downloadCts?.Cancel(); + _downloadTimer?.Dispose(); + await ModelSessionService.CloseAsync(_sessionId); + await AttachmentService.CleanupAsync(_sessionId); + } + + private sealed class ChatMessageView + { + public ChatMessageView(MessageKind kind, string content) + { + Id = Guid.NewGuid().ToString("N"); + Kind = kind; + Content = content; + Attachments = new List(); + } + + public string Id { get; } + public MessageKind Kind { get; } + public string Content { get; set; } + public string Meta { get; set; } + public ElementReference ContentRef { get; set; } + public List Attachments { get; set; } + + public string RoleClass => Kind switch + { + MessageKind.User => "user", + MessageKind.Assistant => "assistant", + MessageKind.Info => "info", + MessageKind.Error => "error", + _ => "assistant" + }; + + public string Avatar => Kind switch + { + MessageKind.User => "U", + MessageKind.Assistant => "AI", + MessageKind.Info => "i", + MessageKind.Error => "!", + _ => "•" + }; + } + + private sealed class SamplingSettings + { + public int MaxTokens { get; set; } + public float Temperature { get; set; } + public float TopP { get; set; } + public int TopK { get; set; } + public float RepeatPenalty { get; set; } + public float PresencePenalty { get; set; } + public float FrequencyPenalty { get; set; } + public int PenaltyCount { get; set; } + public float TypicalP { get; set; } + public float MinP { get; set; } + public bool PreventEos { get; set; } + public bool PenalizeNewline { get; set; } + + public static SamplingSettings CreateDefault() + { + return new SamplingSettings + { + MaxTokens = 512, + Temperature = 0.7f, + TopP = 0.9f, + TopK = 40, + RepeatPenalty = 1.0f, + PresencePenalty = 0.0f, + FrequencyPenalty = 0.0f, + PenaltyCount = 64, + TypicalP = 1.0f, + MinP = 0.1f, + PreventEos = false, + PenalizeNewline = false + }; + } + } + + private enum MessageKind + { + User, + Assistant, + Info, + Error + } + + private enum SidebarTab + { + Models, + Settings + } + + private sealed class AudioRecordingStartResult + { + public bool Started { get; set; } + public string Error { get; set; } + public string MimeType { get; set; } + } + + private sealed class AudioRecordingStopResult + { + public string Base64 { get; set; } + public string MimeType { get; set; } + public long Size { get; set; } + } + + private void SetTab(SidebarTab tab) + { + _activeTab = tab; + _ = JS.InvokeVoidAsync("llamaChat.setSidebarTab", tab.ToString()); + } + + private string GetTabClass(SidebarTab tab) + { + return _activeTab == tab ? "tab-button active" : "tab-button"; + } + + private string GetCapabilityClass(bool enabled) + { + return enabled ? "capability-pill enabled" : "capability-pill disabled"; + } +} diff --git a/LLama.Web/Pages/Index.cshtml b/LLama.Web/Pages/Index.cshtml deleted file mode 100644 index 52f87b62c..000000000 --- a/LLama.Web/Pages/Index.cshtml +++ /dev/null @@ -1,110 +0,0 @@ -@page -@using LLama.Web.Common; - -@model IndexModel -@{ - ViewData["Title"] = "Home"; -} - -@Html.AntiForgeryToken() -
- - - -
-
-
- -
-
-
-
- -
- - -
- -
-
- -@{ - await Html.RenderPartialAsync("_ChatTemplates"); -} - -@section Scripts { - - -} \ No newline at end of file diff --git a/LLama.Web/Pages/Index.cshtml.cs b/LLama.Web/Pages/Index.cshtml.cs deleted file mode 100644 index b2c65aa90..000000000 --- a/LLama.Web/Pages/Index.cshtml.cs +++ /dev/null @@ -1,45 +0,0 @@ -using LLama.Sampling; -using LLama.Web.Common; -using Microsoft.AspNetCore.Mvc; -using Microsoft.AspNetCore.Mvc.RazorPages; -using Microsoft.Extensions.Options; - -namespace LLama.Web.Pages -{ - public class IndexModel : PageModel - { - private readonly ILogger _logger; - - public IndexModel(ILogger logger, IOptions options) - { - _logger = logger; - Options = options.Value; - } - - public LLamaOptions Options { get; set; } - - [BindProperty] - public ISessionConfig SessionConfig { get; set; } - - [BindProperty] - public InferenceOptions InferenceOptions { get; set; } - - public void OnGet() - { - SessionConfig = new SessionConfig - { - Prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.", - AntiPrompt = "User:", - OutputFilter = "User:, Assistant: " - }; - - InferenceOptions = new InferenceOptions - { - SamplingPipeline = new DefaultSamplingPipeline - { - Temperature = 0.8f - }, - }; - } - } -} \ No newline at end of file diff --git a/LLama.Web/Pages/Shared/_ChatTemplates.cshtml b/LLama.Web/Pages/Shared/_ChatTemplates.cshtml index 843c58607..f70b7230b 100644 --- a/LLama.Web/Pages/Shared/_ChatTemplates.cshtml +++ b/LLama.Web/Pages/Shared/_ChatTemplates.cshtml @@ -17,7 +17,13 @@ -
{{text}}
+
{{text}}
+ {{#attachments}} +
+ + {{fileName}} +
+ {{/attachments}} @@ -34,9 +40,9 @@ -
+
- \ No newline at end of file + diff --git a/LLama.Web/Pages/Shared/_Layout.cshtml b/LLama.Web/Pages/Shared/_Layout.cshtml index bba9c6831..80585255f 100644 --- a/LLama.Web/Pages/Shared/_Layout.cshtml +++ b/LLama.Web/Pages/Shared/_Layout.cshtml @@ -5,83 +5,33 @@ @ViewData["Title"] - LLamaSharp Web + -
- -
- -
+
@RenderBody()
- - - - - + + + + + + + + + + + + + @await RenderSectionAsync("Scripts", required: false) - \ No newline at end of file + diff --git a/LLama.Web/Pages/Shared/_Parameters.cshtml b/LLama.Web/Pages/Shared/_Parameters.cshtml index 009f1fe83..780313f96 100644 --- a/LLama.Web/Pages/Shared/_Parameters.cshtml +++ b/LLama.Web/Pages/Shared/_Parameters.cshtml @@ -1,6 +1,10 @@ @page -@using LLama.Common; -@model LLama.Abstractions.IInferenceParams +@using LLama.Sampling; +@model LLama.Web.Common.InferenceOptions + +@{ + var pipeline = Model?.SamplingPipeline as DefaultSamplingPipeline ?? new DefaultSamplingPipeline(); +}
@@ -20,12 +24,11 @@
-@* todo: update for new sampling system
TopK
- @Html.TextBoxFor(m => m.TopK, new { @type = "range", @class = "slider", min = "-1", max = "100", step = "1" }) + @Html.TextBox("SamplingPipeline.TopK", pipeline.TopK, new { @type = "range", @class = "slider", min = "1", max = "100", step = "1" })
@@ -33,7 +36,7 @@
TopP
- @Html.TextBoxFor(m => m.TopP, new { @type = "range", @class = "slider", min = "0.0", max = "1.0", step = "0.01" }) + @Html.TextBox("SamplingPipeline.TopP", pipeline.TopP, new { @type = "range", @class = "slider", min = "0.0", max = "1.0", step = "0.01" })
@@ -43,7 +46,7 @@
TypicalP
- @Html.TextBoxFor(m => m.TypicalP, new { @type = "range", @class = "slider", min = "0.0", max = "1.0", step = "0.01" }) + @Html.TextBox("SamplingPipeline.TypicalP", pipeline.TypicalP, new { @type = "range", @class = "slider", min = "0.0", max = "1.0", step = "0.01" })
@@ -51,7 +54,7 @@
Temperature
- @Html.TextBoxFor(m => m.Temperature, new { @type = "range", @class = "slider", min = "0.0", max = "1.5", step = "0.01" }) + @Html.TextBox("SamplingPipeline.Temperature", pipeline.Temperature, new { @type = "range", @class = "slider", min = "0.0", max = "1.5", step = "0.01" })
@@ -61,15 +64,15 @@
RepeatPenalty
- @Html.TextBoxFor(m => m.RepeatPenalty, new { @type = "range", @class = "slider", min = "0.0", max = "2.0", step = "0.01" }) + @Html.TextBox("SamplingPipeline.RepeatPenalty", pipeline.RepeatPenalty, new { @type = "range", @class = "slider", min = "0.0", max = "2.0", step = "0.01" })
- RepeatLastTokensCount + PenaltyCount
- @Html.TextBoxFor(m => m.RepeatLastTokensCount, new { @type = "range", @class = "slider", min = "0", max = "2048", step = "1" }) + @Html.TextBox("SamplingPipeline.PenaltyCount", pipeline.PenaltyCount, new { @type = "range", @class = "slider", min = "0", max = "2048", step = "1" })
@@ -79,7 +82,7 @@
FrequencyPenalty
- @Html.TextBoxFor(m => m.FrequencyPenalty, new { @type = "range", @class = "slider", min = "0.0", max = "1.0", step = "0.01" }) + @Html.TextBox("SamplingPipeline.FrequencyPenalty", pipeline.FrequencyPenalty, new { @type = "range", @class = "slider", min = "-2.0", max = "2.0", step = "0.01" })
@@ -87,7 +90,7 @@
PresencePenalty
- @Html.TextBoxFor(m => m.PresencePenalty, new { @type = "range", @class = "slider", min = "0.0", max = "1.0", step = "0.01" }) + @Html.TextBox("SamplingPipeline.PresencePenalty", pipeline.PresencePenalty, new { @type = "range", @class = "slider", min = "-2.0", max = "2.0", step = "0.01" })
@@ -95,45 +98,33 @@
- TfsZ + MinP
- @Html.TextBoxFor(m => m.TfsZ, new { @type = "range", @class = "slider", min = "0.0", max = "1.0", step = "0.01" }) + @Html.TextBox("SamplingPipeline.MinP", pipeline.MinP, new { @type = "range", @class = "slider", min = "0.0", max = "1.0", step = "0.01" })
- - + MinKeep
- - + @Html.TextBox("SamplingPipeline.MinKeep", pipeline.MinKeep, new { @type = "range", @class = "slider", min = "1", max = "64", step = "1" }) +
- -
-
- @Html.DropDownListFor(m => m.Mirostat, Html.GetEnumSelectList(), new { @class = "form-control form-select" }) - @Html.LabelFor(m => m.Mirostat) -
-
-
- MirostatTau -
- @Html.TextBoxFor(m => m.MirostatTau, new { @type = "range", @class = "slider", min = "0.0", max = "10.0", step = "0.01" }) - + PreventEOS +
+ @Html.CheckBox("SamplingPipeline.PreventEOS", pipeline.PreventEOS, new { @class = "form-check-input" })
- MirostatEta -
- @Html.TextBoxFor(m => m.MirostatEta, new { @type = "range", @class = "slider", min = "0.0", max = "1.0", step = "0.01" }) - + PenalizeNewline +
+ @Html.CheckBox("SamplingPipeline.PenalizeNewline", pipeline.PenalizeNewline, new { @class = "form-check-input" })
- -*@ \ No newline at end of file diff --git a/LLama.Web/Pages/_Host.cshtml b/LLama.Web/Pages/_Host.cshtml new file mode 100644 index 000000000..906a6c180 --- /dev/null +++ b/LLama.Web/Pages/_Host.cshtml @@ -0,0 +1,8 @@ +@page "/" +@namespace LLama.Web.Pages +@addTagHelper *, Microsoft.AspNetCore.Mvc.TagHelpers +@{ + Layout = "_Layout"; +} + + diff --git a/LLama.Web/Program.cs b/LLama.Web/Program.cs index 7ab301507..3cd4cb4d7 100644 --- a/LLama.Web/Program.cs +++ b/LLama.Web/Program.cs @@ -6,6 +6,8 @@ // Add services to the container. var mvcBuilder = builder.Services.AddRazorPages(); +builder.Services.AddServerSideBlazor(); +builder.Services.AddControllers(); if (builder.Environment.IsDevelopment()) { @@ -16,18 +18,25 @@ builder.Services.AddSignalR(); builder.Logging.ClearProviders(); builder.Services.AddLogging((loggingBuilder) => loggingBuilder.SetMinimumLevel(LogLevel.Trace).AddConsole()); +builder.Services.AddHttpClient(nameof(ModelDownloadService), client => client.Timeout = Timeout.InfiniteTimeSpan); -// Load InteractiveOptions +// Load interactive options. builder.Services.AddOptions() .BindConfiguration(nameof(LLamaOptions)); -// Services DI +// Register DI services. builder.Services.AddHostedService(); +builder.Services.AddSingleton(); +builder.Services.AddSingleton(); builder.Services.AddSingleton(); builder.Services.AddSingleton(); var app = builder.Build(); +app.Logger.LogInformation("Model storage path: {Path}", app.Services.GetRequiredService().ModelsRoot); +app.Logger.LogInformation("Download staging path: {Path}", app.Services.GetRequiredService().DownloadsRoot); +app.Logger.LogInformation("Uploads path: {Path}", app.Services.GetRequiredService().UploadsRoot); + // Configure the HTTP request pipeline. if (!app.Environment.IsDevelopment()) { @@ -44,6 +53,9 @@ app.UseAuthorization(); app.MapRazorPages(); +app.MapControllers(); +app.MapBlazorHub(); +app.MapFallbackToPage("/_Host"); app.MapHub(nameof(SessionConnectionHub)); diff --git a/LLama.Web/README.md b/LLama.Web/README.md index e14562599..101b7fd35 100644 --- a/LLama.Web/README.md +++ b/LLama.Web/README.md @@ -9,6 +9,17 @@ Models, Prompts and Inference parameters can be added to `appsettings.json`. If you would like to add your own local model files then it's best to create an `appSettings.Local.json` file and add them there. The `appSettings.Local.json` file will be ignored by Git. +Local assets downloaded at runtime (models in your AppData `LLama.Web/Models`, uploaded files in your AppData `LLama.Web/Uploads`, and offline web libs in `LLama.Web/wwwroot/lib`) are intentionally git-ignored. + +To restore offline web assets (markdown/diagram libraries), install LibMan if needed (`dotnet tool install -g Microsoft.Web.LibraryManager.Cli`), ensure the dotnet tools folder is on your PATH (typically `~/.dotnet/tools`), and run LibMan from the `LLama.Web` folder: +```sh +libman restore +``` + +Attachments uploaded through the web UI are stored under your AppData `LLama.Web/Uploads/` and are cleaned up when the SignalR connection closes. + +Model downloads are staged under your AppData `LLama.Web/Downloads` and finalized in AppData `LLama.Web/Models`. + **Models** You can add multiple models to the options for quick selection in the UI, options are based on ModelParams so its fully configurable. diff --git a/LLama.Web/Services/AttachmentService.cs b/LLama.Web/Services/AttachmentService.cs new file mode 100644 index 000000000..4990bf845 --- /dev/null +++ b/LLama.Web/Services/AttachmentService.cs @@ -0,0 +1,359 @@ +using System.Collections.Concurrent; +using System.IO.Compression; +using System.Linq; +using System.Text; +using System.Xml.Linq; +using LLama.Web.Models; +using Microsoft.AspNetCore.Components.Forms; +using UglyToad.PdfPig; + +namespace LLama.Web.Services; + +public class AttachmentService : IAttachmentService +{ + private const int MaxExtractedCharacters = 12000; + private const long MaxUploadSize = 512L * 1024 * 1024; + private readonly IWebHostEnvironment _environment; + private readonly string _uploadsRoot; + private readonly ConcurrentDictionary> _attachments = new(); + + public AttachmentService(IWebHostEnvironment environment) + { + _environment = environment; + var appDataRoot = Environment.GetFolderPath(Environment.SpecialFolder.ApplicationData); + _uploadsRoot = Path.Combine(appDataRoot, "LLama.Web", "Uploads"); + Directory.CreateDirectory(_uploadsRoot); + } + + public string UploadsRoot => _uploadsRoot; + + public async Task SaveAsync(string connectionId, IEnumerable files, CancellationToken cancellationToken) + { + if (string.IsNullOrWhiteSpace(connectionId)) + throw new ArgumentException("Connection id is required.", nameof(connectionId)); + + ValidateUploads(files); + + var result = new AttachmentUploadResult(); + var storage = _attachments.GetOrAdd(connectionId, _ => new ConcurrentDictionary()); + var root = Path.Combine(_uploadsRoot, connectionId); + Directory.CreateDirectory(root); + + foreach (var file in files) + { + if (file == null || file.Length == 0) + continue; + + var id = Guid.NewGuid().ToString("N"); + var safeName = Path.GetFileName(file.FileName); + var filePath = Path.Combine(root, $"{id}-{safeName}"); + + await using (var stream = new FileStream(filePath, FileMode.Create, FileAccess.Write, FileShare.None, 81920, useAsync: true)) + { + await file.CopyToAsync(stream, cancellationToken); + } + + var info = new AttachmentInfo + { + Id = id, + ConnectionId = connectionId, + FileName = safeName, + FilePath = filePath, + ContentType = file.ContentType, + SizeBytes = file.Length, + Kind = GetKind(file) + }; + + if (info.Kind == AttachmentKind.Pdf) + ExtractPdfText(info); + if (info.Kind == AttachmentKind.Word) + ExtractWordText(info); + + storage[id] = info; + result.Attachments.Add(info); + } + + return result; + } + + public async Task SaveAsync(string connectionId, IEnumerable files, CancellationToken cancellationToken) + { + if (string.IsNullOrWhiteSpace(connectionId)) + throw new ArgumentException("Connection id is required.", nameof(connectionId)); + + ValidateUploads(files); + + var result = new AttachmentUploadResult(); + var storage = _attachments.GetOrAdd(connectionId, _ => new ConcurrentDictionary()); + var root = Path.Combine(_uploadsRoot, connectionId); + Directory.CreateDirectory(root); + + foreach (var file in files) + { + if (file == null || file.Size == 0) + continue; + + var id = Guid.NewGuid().ToString("N"); + var safeName = Path.GetFileName(file.Name); + var filePath = Path.Combine(root, $"{id}-{safeName}"); + + await using (var input = file.OpenReadStream(maxAllowedSize: MaxUploadSize, cancellationToken)) + await using (var output = new FileStream(filePath, FileMode.Create, FileAccess.Write, FileShare.None, 81920, useAsync: true)) + { + await input.CopyToAsync(output, cancellationToken); + } + + var info = new AttachmentInfo + { + Id = id, + ConnectionId = connectionId, + FileName = safeName, + FilePath = filePath, + ContentType = file.ContentType, + SizeBytes = file.Size, + Kind = GetKind(file) + }; + + if (info.Kind == AttachmentKind.Pdf) + ExtractPdfText(info); + if (info.Kind == AttachmentKind.Word) + ExtractWordText(info); + + storage[id] = info; + result.Attachments.Add(info); + } + + return result; + } + + public IReadOnlyList GetAttachments(string connectionId, IEnumerable ids) + { + if (string.IsNullOrWhiteSpace(connectionId) || ids is null) + return Array.Empty(); + + if (!_attachments.TryGetValue(connectionId, out var storage)) + return Array.Empty(); + + var list = new List(); + foreach (var id in ids) + { + if (storage.TryGetValue(id, out var info)) + list.Add(info); + } + + return list; + } + + public AttachmentInfo GetAttachment(string connectionId, string id) + { + if (string.IsNullOrWhiteSpace(connectionId) || string.IsNullOrWhiteSpace(id)) + return null; + + if (!_attachments.TryGetValue(connectionId, out var storage)) + return null; + + return storage.TryGetValue(id, out var info) ? info : null; + } + + public Task CleanupAsync(string connectionId) + { + if (string.IsNullOrWhiteSpace(connectionId)) + return Task.CompletedTask; + + if (_attachments.TryRemove(connectionId, out _)) + { + var root = Path.Combine(_uploadsRoot, connectionId); + if (Directory.Exists(root)) + Directory.Delete(root, recursive: true); + } + + return Task.CompletedTask; + } + + private static AttachmentKind GetKind(IFormFile file) + { + var contentType = file.ContentType?.ToLowerInvariant() ?? string.Empty; + var extension = Path.GetExtension(file.FileName).ToLowerInvariant(); + + if (contentType.Contains("pdf") || extension == ".pdf") + return AttachmentKind.Pdf; + if (IsWordDocument(contentType, extension)) + return AttachmentKind.Word; + if (contentType.StartsWith("image/")) + return AttachmentKind.Image; + if (contentType.StartsWith("audio/") || extension is ".wav" or ".mp3" or ".m4a" or ".ogg" or ".flac" or ".webm") + return AttachmentKind.Audio; + + return AttachmentKind.Unknown; + } + + private static AttachmentKind GetKind(IBrowserFile file) + { + var contentType = file.ContentType?.ToLowerInvariant() ?? string.Empty; + var extension = Path.GetExtension(file.Name).ToLowerInvariant(); + + if (contentType.Contains("pdf") || extension == ".pdf") + return AttachmentKind.Pdf; + if (IsWordDocument(contentType, extension)) + return AttachmentKind.Word; + if (contentType.StartsWith("image/")) + return AttachmentKind.Image; + if (contentType.StartsWith("audio/") || extension is ".wav" or ".mp3" or ".m4a" or ".ogg" or ".flac" or ".webm") + return AttachmentKind.Audio; + + return AttachmentKind.Unknown; + } + + private static bool IsWordDocument(string contentType, string extension) + { + if (extension == ".docx") + return true; + + if (string.IsNullOrWhiteSpace(contentType)) + return false; + + return contentType.Contains("word") || contentType.Contains("officedocument.wordprocessingml"); + } + + private static bool IsAllowedUpload(string contentType, string extension) + { + if (extension == ".doc") + return false; + + if (extension == ".docx" || contentType.Contains("officedocument.wordprocessingml") || contentType.Contains("word")) + return true; + + if (extension == ".pdf" || contentType.Contains("pdf")) + return true; + + if (contentType.StartsWith("image/")) + return true; + + if (contentType.StartsWith("audio/")) + return true; + + if (extension is ".wav" or ".mp3" or ".m4a" or ".ogg" or ".flac" or ".webm") + return true; + + return false; + } + + private static void ValidateUploads(IEnumerable files) + { + var invalid = files + .Where(file => file != null) + .Where(file => !IsAllowedUpload(file.ContentType?.ToLowerInvariant() ?? string.Empty, Path.GetExtension(file.FileName).ToLowerInvariant())) + .Select(file => file.FileName) + .ToList(); + + if (invalid.Count == 0) + return; + + throw new InvalidOperationException($"Unsupported files: {string.Join(", ", invalid)}. Use PDF, DOCX, or images."); + } + + private static void ValidateUploads(IEnumerable files) + { + var invalid = files + .Where(file => file != null) + .Where(file => !IsAllowedUpload(file.ContentType?.ToLowerInvariant() ?? string.Empty, Path.GetExtension(file.Name).ToLowerInvariant())) + .Select(file => file.Name) + .ToList(); + + if (invalid.Count == 0) + return; + + throw new InvalidOperationException($"Unsupported files: {string.Join(", ", invalid)}. Use PDF, DOCX, or images."); + } + + private static void ExtractPdfText(AttachmentInfo info) + { + if (!File.Exists(info.FilePath)) + return; + + var builder = new StringBuilder(); + try + { + using var document = PdfDocument.Open(info.FilePath); + foreach (var page in document.GetPages()) + { + if (builder.Length >= MaxExtractedCharacters) + break; + + builder.AppendLine(page.Text); + } + + var text = builder.ToString(); + if (text.Length > MaxExtractedCharacters) + { + info.ExtractedText = text.Substring(0, MaxExtractedCharacters); + info.ExtractedTextTruncated = true; + } + else + { + info.ExtractedText = text; + info.ExtractedTextTruncated = false; + } + } + catch + { + info.ExtractedText = string.Empty; + info.ExtractedTextTruncated = false; + } + } + + private static void ExtractWordText(AttachmentInfo info) + { + if (!File.Exists(info.FilePath)) + return; + + var extension = Path.GetExtension(info.FilePath).ToLowerInvariant(); + if (extension != ".docx") + { + info.ExtractedText = string.Empty; + info.ExtractedTextTruncated = false; + return; + } + + try + { + using var archive = ZipFile.OpenRead(info.FilePath); + var entry = archive.GetEntry("word/document.xml"); + if (entry == null) + return; + + using var stream = entry.Open(); + var document = XDocument.Load(stream); + var textNodes = document.Descendants() + .Where(node => node.Name.LocalName == "t") + .Select(node => node.Value); + + var builder = new StringBuilder(); + foreach (var text in textNodes) + { + if (builder.Length >= MaxExtractedCharacters) + break; + + builder.Append(text); + builder.AppendLine(); + } + + var resultText = builder.ToString(); + if (resultText.Length > MaxExtractedCharacters) + { + info.ExtractedText = resultText.Substring(0, MaxExtractedCharacters); + info.ExtractedTextTruncated = true; + } + else + { + info.ExtractedText = resultText; + info.ExtractedTextTruncated = false; + } + } + catch + { + info.ExtractedText = string.Empty; + info.ExtractedTextTruncated = false; + } + } +} diff --git a/LLama.Web/Services/IAttachmentService.cs b/LLama.Web/Services/IAttachmentService.cs new file mode 100644 index 000000000..31a2789da --- /dev/null +++ b/LLama.Web/Services/IAttachmentService.cs @@ -0,0 +1,14 @@ +using LLama.Web.Models; +using Microsoft.AspNetCore.Components.Forms; + +namespace LLama.Web.Services; + +public interface IAttachmentService +{ + string UploadsRoot { get; } + Task SaveAsync(string connectionId, IEnumerable files, CancellationToken cancellationToken); + Task SaveAsync(string connectionId, IEnumerable files, CancellationToken cancellationToken); + IReadOnlyList GetAttachments(string connectionId, IEnumerable ids); + AttachmentInfo GetAttachment(string connectionId, string id); + Task CleanupAsync(string connectionId); +} diff --git a/LLama.Web/Services/IModelDownloadService.cs b/LLama.Web/Services/IModelDownloadService.cs new file mode 100644 index 000000000..904ac13ba --- /dev/null +++ b/LLama.Web/Services/IModelDownloadService.cs @@ -0,0 +1,13 @@ +using LLama.Web.Models; + +namespace LLama.Web.Services; + +public interface IModelDownloadService +{ + string ModelsRoot { get; } + string DownloadsRoot { get; } + void StartDownloads(CancellationToken cancellationToken); + Task WaitForDownloadsAsync(CancellationToken cancellationToken); + IReadOnlyList GetSnapshots(); + bool IsModelReady(string modelName); +} diff --git a/LLama.Web/Services/IModelService.cs b/LLama.Web/Services/IModelService.cs index 0500a0b29..57796417b 100644 --- a/LLama.Web/Services/IModelService.cs +++ b/LLama.Web/Services/IModelService.cs @@ -4,7 +4,7 @@ namespace LLama.Web.Services; /// -/// Service for managing language Models +/// Service for managing language models. /// public interface IModelService { @@ -37,7 +37,7 @@ public interface IModelService Task UnloadModels(); /// - /// Gets a context with the specified identifier + /// Gets a context with the specified identifier. /// /// Name of the model. /// The context identifier. @@ -58,11 +58,11 @@ public interface IModelService Task CreateContext(string modelName, string contextName); /// - /// Gets the or create model and context. - /// This will load a model from disk if not already loaded, and also create the context + /// Gets or creates a model and context. + /// This loads a model from disk if it is not already loaded and creates the context. /// /// Name of the model. /// The context identifier. - /// Both loaded Model and Context + /// The loaded model and context. Task<(LLamaModel, LLamaContext)> GetOrCreateModelAndContext(string modelName, string contextName); -} \ No newline at end of file +} diff --git a/LLama.Web/Services/IModelSessionService.cs b/LLama.Web/Services/IModelSessionService.cs index 974aebd2f..392403379 100644 --- a/LLama.Web/Services/IModelSessionService.cs +++ b/LLama.Web/Services/IModelSessionService.cs @@ -6,67 +6,67 @@ namespace LLama.Web.Services; public interface IModelSessionService { /// - /// Gets the ModelSession with the specified Id. + /// Gets the ModelSession with the specified ID. /// /// The session identifier. - /// The ModelSession if exists, otherwise null + /// The ModelSession if it exists; otherwise, null. Task GetAsync(string sessionId); /// - /// Gets all ModelSessions + /// Gets all model sessions. /// - /// A collection oa all Model instances + /// A collection of all model sessions. Task> GetAllAsync(); /// - /// Creates a new ModelSession + /// Creates a new model session. /// /// The session identifier. /// The session configuration. - /// The default inference configuration, will be used for all inference where no infer configuration is supplied. + /// The default inference configuration, used when no inference configuration is supplied. /// The cancellation token. /// /// - /// Session with id {sessionId} already exists + /// Session with ID {sessionId} already exists /// or /// Failed to create model session /// Task CreateAsync(string sessionId, ISessionConfig sessionConfig, InferenceOptions inferenceOptions = null, CancellationToken cancellationToken = default); /// - /// Closes the session + /// Closes the session. /// /// The session identifier. /// Task CloseAsync(string sessionId); /// - /// Runs inference on the current ModelSession + /// Runs inference on the current model session. /// /// The session identifier. - /// The prompt. - /// The inference configuration, if null session default is used + /// The prompt request. + /// The inference configuration; if null, the session default is used. /// The cancellation token. /// Inference is already running for this session - IAsyncEnumerable InferAsync(string sessionId, string prompt, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default); + IAsyncEnumerable InferAsync(string sessionId, PromptRequest request, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default); /// - /// Runs inference on the current ModelSession + /// Runs inference on the current model session. /// /// The session identifier. /// The prompt. - /// The inference configuration, if null session default is used + /// The inference configuration; if null, the session default is used. /// The cancellation token. /// Streaming async result of /// Inference is already running for this session IAsyncEnumerable InferTextAsync(string sessionId, string prompt, InferenceOptions inferenceOptions = null, CancellationToken cancellationToken = default); /// - /// Queues inference on the current ModelSession + /// Queues inference on the current model session. /// /// The session identifier. /// The prompt. - /// The inference configuration, if null session default is used + /// The inference configuration; if null, the session default is used. /// The cancellation token. /// Completed inference result as string /// Inference is already running for this session @@ -78,4 +78,11 @@ public interface IModelSessionService /// The session identifier. /// Task CancelAsync(string sessionId); + + /// + /// Gets the model capabilities for the specified session. + /// + /// The session identifier. + /// The model capabilities. + Task GetCapabilitiesAsync(string sessionId); } diff --git a/LLama.Web/Services/ModelDownloadService.cs b/LLama.Web/Services/ModelDownloadService.cs new file mode 100644 index 000000000..df2c3bc7a --- /dev/null +++ b/LLama.Web/Services/ModelDownloadService.cs @@ -0,0 +1,301 @@ +using System.Collections.Concurrent; +using LLama.Web.Common; +using LLama.Web.Hubs; +using LLama.Web.Models; +using Microsoft.AspNetCore.SignalR; +using Microsoft.Extensions.Options; + +namespace LLama.Web.Services; + +public class ModelDownloadService : IModelDownloadService +{ + private readonly IWebHostEnvironment _environment; + private readonly HttpClient _httpClient; + private readonly LLamaOptions _options; + private readonly IHubContext _hubContext; + private readonly string _modelsRoot; + private readonly string _downloadsRoot; + private readonly object _syncLock = new(); + private readonly ConcurrentDictionary> _assetsByModel = new(); + + private Task _downloadTask; + private bool _started; + private bool _initialized; + + public ModelDownloadService(IWebHostEnvironment environment, IHttpClientFactory httpClientFactory, IOptions options, IHubContext hubContext) + { + _environment = environment; + _httpClient = httpClientFactory.CreateClient(nameof(ModelDownloadService)); + _options = options.Value; + _hubContext = hubContext; + + var appDataRoot = Environment.GetFolderPath(Environment.SpecialFolder.ApplicationData); + var appRoot = Path.Combine(appDataRoot, "LLama.Web"); + _modelsRoot = Path.Combine(appRoot, "Models"); + _downloadsRoot = Path.Combine(appRoot, "Downloads"); + } + + public string ModelsRoot => _modelsRoot; + public string DownloadsRoot => _downloadsRoot; + + public void StartDownloads(CancellationToken cancellationToken) + { + lock (_syncLock) + { + if (_started) + return; + + _started = true; + EnsureInitialized(); + _downloadTask = Task.Run(() => DownloadAllAsync(cancellationToken), cancellationToken); + } + } + + public Task WaitForDownloadsAsync(CancellationToken cancellationToken) + { + lock (_syncLock) + { + return _downloadTask ?? Task.CompletedTask; + } + } + + public bool IsModelReady(string modelName) + { + if (string.IsNullOrWhiteSpace(modelName)) + return false; + + if (!_assetsByModel.TryGetValue(modelName, out var assets)) + return false; + + return assets.All(asset => asset.State == ModelDownloadState.Completed); + } + + public IReadOnlyList GetSnapshots() + { + EnsureInitialized(); + var snapshots = new List(); + foreach (var kvp in _assetsByModel) + { + var assets = kvp.Value; + var snapshot = new ModelDownloadSnapshot + { + ModelName = kvp.Key, + Ready = assets.All(asset => asset.State == ModelDownloadState.Completed) + }; + + foreach (var asset in assets) + snapshot.Assets.Add(ToSnapshot(asset)); + + snapshots.Add(snapshot); + } + + return snapshots; + } + + private void EnsureInitialized() + { + lock (_syncLock) + { + if (_initialized) + return; + + InitializeAssets(); + _initialized = true; + } + + _ = BroadcastSnapshotAsync(); + } + + private void InitializeAssets() + { + Directory.CreateDirectory(_modelsRoot); + Directory.CreateDirectory(_downloadsRoot); + foreach (var model in _options.Models ?? new List()) + { + var assets = new List(); + + var modelPath = ResolvePath(_modelsRoot, model.ModelPath); + model.ModelPath = modelPath; + assets.Add(CreateAssetStatus(model.Name, ModelAssetKind.Model, modelPath, model.ModelDownloadUrl)); + + if (!string.IsNullOrWhiteSpace(model.MmprojPath)) + { + var mmprojPath = ResolvePath(_modelsRoot, model.MmprojPath); + model.MmprojPath = mmprojPath; + assets.Add(CreateAssetStatus(model.Name, ModelAssetKind.Mmproj, mmprojPath, model.MmprojDownloadUrl)); + } + + _assetsByModel[model.Name] = assets; + } + } + + private ModelAssetStatus CreateAssetStatus(string modelName, ModelAssetKind kind, string localPath, string downloadUrl) + { + var status = new ModelAssetStatus + { + ModelName = modelName, + AssetKind = kind, + LocalPath = localPath, + DownloadUrl = downloadUrl, + FileName = Path.GetFileName(localPath), + State = ModelDownloadState.Queued + }; + + if (File.Exists(localPath)) + { + var info = new FileInfo(localPath); + status.TotalBytes = info.Length; + status.BytesReceived = info.Length; + status.State = ModelDownloadState.Completed; + } + else if (string.IsNullOrWhiteSpace(downloadUrl)) + { + status.State = ModelDownloadState.MissingUrl; + status.Error = "Download URL is missing."; + } + + return status; + } + + private async Task DownloadAllAsync(CancellationToken cancellationToken) + { + foreach (var assets in _assetsByModel.Values) + { + foreach (var asset in assets) + { + if (asset.State == ModelDownloadState.Completed) + continue; + + if (asset.State == ModelDownloadState.MissingUrl) + { + await BroadcastProgressAsync(asset); + continue; + } + + await DownloadAssetAsync(asset, cancellationToken); + } + } + } + + private async Task DownloadAssetAsync(ModelAssetStatus asset, CancellationToken cancellationToken) + { + asset.State = ModelDownloadState.Downloading; + asset.BytesReceived = 0; + asset.Error = null; + + await BroadcastProgressAsync(asset); + + var directory = Path.GetDirectoryName(asset.LocalPath); + if (string.IsNullOrWhiteSpace(directory)) + directory = _environment.ContentRootPath; + Directory.CreateDirectory(directory); + + var tempFileName = $"{asset.ModelName}-{asset.AssetKind}-{asset.FileName}.download"; + var tempPath = Path.Combine(_downloadsRoot, tempFileName); + try + { + using var request = new HttpRequestMessage(HttpMethod.Get, asset.DownloadUrl); + using var response = await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken); + response.EnsureSuccessStatusCode(); + + asset.TotalBytes = response.Content.Headers.ContentLength; + + var lastReport = DateTimeOffset.UtcNow; + + await using var fileStream = new FileStream(tempPath, FileMode.Create, FileAccess.Write, FileShare.None, 1024 * 64, useAsync: true); + await using var stream = await response.Content.ReadAsStreamAsync(cancellationToken); + + var buffer = new byte[1024 * 64]; + int read; + while ((read = await stream.ReadAsync(buffer.AsMemory(0, buffer.Length), cancellationToken)) > 0) + { + await fileStream.WriteAsync(buffer.AsMemory(0, read), cancellationToken); + asset.BytesReceived += read; + + if (DateTimeOffset.UtcNow - lastReport > TimeSpan.FromMilliseconds(250)) + { + lastReport = DateTimeOffset.UtcNow; + await BroadcastProgressAsync(asset); + } + } + + if (File.Exists(asset.LocalPath)) + File.Delete(asset.LocalPath); + + File.Move(tempPath, asset.LocalPath); + + asset.State = ModelDownloadState.Completed; + asset.BytesReceived = asset.TotalBytes ?? asset.BytesReceived; + asset.Error = null; + } + catch (Exception ex) + { + asset.State = ModelDownloadState.Failed; + asset.Error = ex.Message; + if (File.Exists(tempPath)) + File.Delete(tempPath); + } + finally + { + await BroadcastProgressAsync(asset); + } + } + + private static string ResolvePath(string contentRoot, string path) + { + if (string.IsNullOrWhiteSpace(path)) + return string.Empty; + + return Path.IsPathRooted(path) + ? path + : Path.GetFullPath(Path.Combine(contentRoot, path)); + } + + private Task BroadcastProgressAsync(ModelAssetStatus asset) + { + var progress = new ModelDownloadProgress + { + ModelName = asset.ModelName, + FileName = asset.FileName, + AssetKind = asset.AssetKind, + State = asset.State, + TotalBytes = asset.TotalBytes, + BytesReceived = asset.BytesReceived, + Error = asset.Error + }; + + return _hubContext.Clients.All.OnModelDownloadProgress(progress); + } + + private Task BroadcastSnapshotAsync() + { + return _hubContext.Clients.All.OnModelDownloadSnapshot(GetSnapshots()); + } + + private static ModelAssetSnapshot ToSnapshot(ModelAssetStatus asset) + { + return new ModelAssetSnapshot + { + ModelName = asset.ModelName, + FileName = asset.FileName, + AssetKind = asset.AssetKind, + State = asset.State, + TotalBytes = asset.TotalBytes, + BytesReceived = asset.BytesReceived, + Error = asset.Error + }; + } + + private sealed class ModelAssetStatus + { + public string ModelName { get; set; } + public string FileName { get; set; } + public string LocalPath { get; set; } + public string DownloadUrl { get; set; } + public ModelAssetKind AssetKind { get; set; } + public ModelDownloadState State { get; set; } + public long? TotalBytes { get; set; } + public long BytesReceived { get; set; } + public string Error { get; set; } + } +} diff --git a/LLama.Web/Services/ModelLoaderService.cs b/LLama.Web/Services/ModelLoaderService.cs index 1ff3e9250..69ffd8f02 100644 --- a/LLama.Web/Services/ModelLoaderService.cs +++ b/LLama.Web/Services/ModelLoaderService.cs @@ -3,28 +3,37 @@ namespace LLama.Web.Services; /// /// Service for managing loading/preloading of models at app startup /// -/// Type used to identify contexts /// public class ModelLoaderService : IHostedService { private readonly IModelService _modelService; + private readonly IModelDownloadService _modelDownloadService; /// /// Initializes a new instance of the class. /// /// The model service. - public ModelLoaderService(IModelService modelService) + public ModelLoaderService(IModelService modelService, IModelDownloadService modelDownloadService) { _modelService = modelService; + _modelDownloadService = modelDownloadService; } /// /// Triggered when the application host is ready to start the service. /// /// Indicates that the start process has been aborted. - public async Task StartAsync(CancellationToken cancellationToken) + public Task StartAsync(CancellationToken cancellationToken) { - await _modelService.LoadModels(); + _modelDownloadService.StartDownloads(cancellationToken); + + _ = Task.Run(async () => + { + await _modelDownloadService.WaitForDownloadsAsync(cancellationToken); + await _modelService.LoadModels(); + }, cancellationToken); + + return Task.CompletedTask; } /// diff --git a/LLama.Web/Services/ModelService.cs b/LLama.Web/Services/ModelService.cs index db95beca7..f5351387a 100644 --- a/LLama.Web/Services/ModelService.cs +++ b/LLama.Web/Services/ModelService.cs @@ -7,7 +7,7 @@ namespace LLama.Web.Services; /// -/// Service for handling Models,Weights & Contexts +/// Service for handling models, weights, and contexts. /// public class ModelService : IModelService { @@ -15,19 +15,22 @@ public class ModelService : IModelService private readonly AsyncLock _contextLock; private readonly LLamaOptions _configuration; private readonly ILogger _llamaLogger; + private readonly IModelDownloadService _modelDownloadService; private readonly ConcurrentDictionary _modelInstances; /// /// Initializes a new instance of the class. /// - /// The logger. - /// The options. - public ModelService(IOptions configuration, ILogger llamaLogger) + /// The configuration. + /// The logger. + /// The model download service. + public ModelService(IOptions configuration, ILogger llamaLogger, IModelDownloadService modelDownloadService) { _llamaLogger = llamaLogger; _modelLock = new AsyncLock(); _contextLock = new AsyncLock(); _configuration = configuration.Value; + _modelDownloadService = modelDownloadService; _modelInstances = new ConcurrentDictionary(); } @@ -46,11 +49,17 @@ public async Task LoadModel(ModelOptions modelOptions) if (_modelInstances.TryGetValue(modelOptions.Name, out var model)) return model; - // If in single mode unload any other models + if (!_modelDownloadService.IsModelReady(modelOptions.Name)) + throw new Exception($"Model '{modelOptions.Name}' is still downloading."); + + // If in single mode, unload any other models. if (_configuration.ModelLoadType == ModelLoadType.Single || _configuration.ModelLoadType == ModelLoadType.PreloadSingle) await UnloadModels(); + if (modelOptions.UBatchSize < modelOptions.BatchSize) + modelOptions.UBatchSize = modelOptions.BatchSize; + model = await LLamaModel.CreateAsync(modelOptions, _llamaLogger); _modelInstances.TryAdd(modelOptions.Name, model); return model; @@ -70,7 +79,7 @@ public async Task LoadModels() { await LoadModel(modelConfig); - //Only preload first model if in SinglePreload mode + // Only preload the first model in PreloadSingle mode. if (_configuration.ModelLoadType == ModelLoadType.PreloadSingle) break; } @@ -103,7 +112,7 @@ public async Task UnloadModels() } /// - /// Gets a model ny name. + /// Gets a model by name. /// /// Name of the model. /// @@ -165,7 +174,7 @@ public async Task RemoveContext(string modelName, string contextName) } /// - /// Loads, Gets,Creates a Model and a Context + /// Loads, gets, or creates a model and a context. /// /// Name of the model. /// The contextName. @@ -176,15 +185,15 @@ public async Task RemoveContext(string modelName, string contextName) if (_modelInstances.TryGetValue(modelName, out var model)) return (model, await model.GetContext(contextName) ?? await model.CreateContext(contextName)); - // Get model configuration + // Get model configuration. var modelConfig = _configuration.Models.FirstOrDefault(x => x.Name == modelName); if (modelConfig is null) throw new Exception($"Model option '{modelName}' not found"); - // Load Model + // Load model. model = await LoadModel(modelConfig); - // Get or Create Context + // Get or create context. return (model, await model.GetContext(contextName) ?? await model.CreateContext(contextName)); } } diff --git a/LLama.Web/Services/ModelSessionService.cs b/LLama.Web/Services/ModelSessionService.cs index 3fc95328a..a316c8299 100644 --- a/LLama.Web/Services/ModelSessionService.cs +++ b/LLama.Web/Services/ModelSessionService.cs @@ -1,63 +1,67 @@ using LLama.Web.Async; using LLama.Web.Common; using LLama.Web.Models; +using LLama.Native; using System.Collections.Concurrent; using System.Diagnostics; using System.Runtime.CompilerServices; +using System.Text; namespace LLama.Web.Services; /// -/// Example Service for handling a model session for a websockets connection lifetime -/// Each websocket connection will create its own unique session and context allowing you to use multiple tabs to compare prompts etc +/// Example service for handling a model session for the lifetime of a WebSocket connection. +/// Each WebSocket connection creates its own session and context, allowing you to use multiple tabs to compare prompts and results. /// public class ModelSessionService : IModelSessionService { private readonly AsyncGuard _sessionGuard; private readonly IModelService _modelService; + private readonly IAttachmentService _attachmentService; private readonly ConcurrentDictionary _modelSessions; /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// /// The model service. - /// The model session state service. - public ModelSessionService(IModelService modelService) + /// The attachment service. + public ModelSessionService(IModelService modelService, IAttachmentService attachmentService) { _modelService = modelService; + _attachmentService = attachmentService; _sessionGuard = new AsyncGuard(); _modelSessions = new ConcurrentDictionary(); } /// - /// Gets the ModelSession with the specified Id. + /// Gets the ModelSession with the specified ID. /// /// The session identifier. - /// The ModelSession if exists, otherwise null + /// The ModelSession if it exists; otherwise, null. public Task GetAsync(string sessionId) { return Task.FromResult(_modelSessions.TryGetValue(sessionId, out var session) ? session : null); } /// - /// Gets all ModelSessions + /// Gets all model sessions. /// - /// A collection oa all Model instances + /// A collection of all model sessions. public Task> GetAllAsync() { return Task.FromResult>(_modelSessions.Values); } /// - /// Creates a new ModelSession + /// Creates a new model session. /// /// The session identifier. /// The session configuration. - /// The default inference configuration, will be used for all inference where no infer configuration is supplied. + /// The default inference configuration, used when no inference configuration is supplied. /// The cancellation token. /// /// - /// Session with id {sessionId} already exists + /// Session with ID {sessionId} already exists /// or /// Failed to create model session /// @@ -66,22 +70,22 @@ public async Task CreateAsync(string sessionId, ISessionConfig ses if (_modelSessions.TryGetValue(sessionId, out _)) throw new Exception($"Session with id {sessionId} already exists"); - // Create context + // Create context. var (model, context) = await _modelService.GetOrCreateModelAndContext(sessionConfig.Model, sessionId); - // Create session + // Create session. var modelSession = new ModelSession(model, context, sessionId, sessionConfig, inferenceConfig); if (!_modelSessions.TryAdd(sessionId, modelSession)) throw new Exception($"Failed to create model session"); - // Run initial Prompt + // Run initial prompt. await modelSession.InitializePrompt(inferenceConfig, cancellationToken); return modelSession; } /// - /// Closes the session + /// Closes the session. /// /// The session identifier. /// @@ -96,14 +100,14 @@ public async Task CloseAsync(string sessionId) } /// - /// Runs inference on the current ModelSession + /// Runs inference on the current model session. /// /// The session identifier. - /// The prompt. - /// The inference configuration, if null session default is used + /// The prompt request. + /// The inference configuration; if null, the session default is used. /// The cancellation token. /// Inference is already running for this session - public async IAsyncEnumerable InferAsync(string sessionId, string prompt, InferenceOptions inferenceConfig = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + public async IAsyncEnumerable InferAsync(string sessionId, PromptRequest request, InferenceOptions inferenceConfig = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { if (!_sessionGuard.Guard(sessionId)) throw new Exception($"Inference is already running for this session"); @@ -113,17 +117,26 @@ public async IAsyncEnumerable InferAsync(string sessionId, string pr if (!_modelSessions.TryGetValue(sessionId, out var modelSession)) yield break; - // Send begin of response + var preparedPrompt = await PreparePromptAsync(sessionId, modelSession, request, cancellationToken); + var formattedPrompt = modelSession.BuildPrompt(preparedPrompt.Prompt); + + // Send start of response. var stopwatch = Stopwatch.GetTimestamp(); yield return new TokenModel(default, default, TokenType.Begin); - // Send content of response - await foreach (var token in modelSession.InferAsync(prompt, inferenceConfig, cancellationToken).ConfigureAwait(false)) + var responseBuilder = new StringBuilder(); + + // Send response content. + await foreach (var token in modelSession.InferAsync(formattedPrompt, inferenceConfig, cancellationToken).ConfigureAwait(false)) { + responseBuilder.Append(token); yield return new TokenModel(default, token); } - // Send end of response + if (!modelSession.IsInferCanceled() && modelSession.StoresHistory) + modelSession.AddAssistantMessage(responseBuilder.ToString()); + + // Send end of response. var elapsedTime = GetElapsed(stopwatch); var endTokenType = modelSession.IsInferCanceled() ? TokenType.Cancel : TokenType.End; var signature = endTokenType == TokenType.Cancel @@ -138,11 +151,11 @@ public async IAsyncEnumerable InferAsync(string sessionId, string pr } /// - /// Runs inference on the current ModelSession + /// Runs inference on the current model session. /// /// The session identifier. /// The prompt. - /// The inference configuration, if null session default is used + /// The inference configuration; if null, the session default is used. /// The cancellation token. /// Streaming async result of /// Inference is already running for this session @@ -150,7 +163,7 @@ public IAsyncEnumerable InferTextAsync(string sessionId, string prompt, { async IAsyncEnumerable InferTextInternal() { - await foreach (var token in InferAsync(sessionId, prompt, inferenceConfig, cancellationToken).ConfigureAwait(false)) + await foreach (var token in InferAsync(sessionId, new PromptRequest { Prompt = prompt }, inferenceConfig, cancellationToken).ConfigureAwait(false)) { if (token.TokenType == TokenType.Content) yield return token.Content; @@ -160,17 +173,17 @@ async IAsyncEnumerable InferTextInternal() } /// - /// Runs inference on the current ModelSession + /// Runs inference on the current model session. /// /// The session identifier. /// The prompt. - /// The inference configuration, if null session default is used + /// The inference configuration; if null, the session default is used. /// The cancellation token. /// Completed inference result as string /// Inference is already running for this session public async Task InferTextCompleteAsync(string sessionId, string prompt, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default) { - var inferResult = await InferAsync(sessionId, prompt, inferenceConfig, cancellationToken) + var inferResult = await InferAsync(sessionId, new PromptRequest { Prompt = prompt }, inferenceConfig, cancellationToken) .Where(x => x.TokenType == TokenType.Content) .Select(x => x.Content) .ToListAsync(cancellationToken: cancellationToken); @@ -193,6 +206,31 @@ public Task CancelAsync(string sessionId) return Task.FromResult(false); } + /// + /// Gets the model capabilities for the specified session. + /// + /// The session identifier. + /// The model capabilities. + public Task GetCapabilitiesAsync(string sessionId) + { + if (_modelSessions.TryGetValue(sessionId, out var modelSession)) + { + return Task.FromResult(new ModelCapabilities + { + SupportsText = true, + SupportsVision = modelSession.SupportsVision, + SupportsAudio = modelSession.SupportsAudio + }); + } + + return Task.FromResult(new ModelCapabilities + { + SupportsText = true, + SupportsVision = false, + SupportsAudio = false + }); + } + /// /// Gets the elapsed time in milliseconds. /// @@ -202,4 +240,121 @@ private static int GetElapsed(long timestamp) { return (int)Stopwatch.GetElapsedTime(timestamp).TotalMilliseconds; } + + private async Task PreparePromptAsync(string sessionId, ModelSession modelSession, PromptRequest request, CancellationToken cancellationToken) + { + if (request == null) + return new PreparedPrompt(string.Empty); + + var promptBuilder = new StringBuilder(request.Prompt ?? string.Empty); + var embeds = new List(); + + try + { + if (request.AttachmentIds != null && request.AttachmentIds.Count > 0) + { + var mediaMarker = GetMediaMarker(); + var attachments = _attachmentService.GetAttachments(sessionId, request.AttachmentIds); + + foreach (var attachment in attachments) + { + switch (attachment.Kind) + { + case AttachmentKind.Pdf: + AppendPdf(promptBuilder, attachment); + break; + case AttachmentKind.Word: + AppendWord(promptBuilder, attachment); + break; + case AttachmentKind.Image: + if (!modelSession.IsMultiModal) + throw new Exception("This model does not support multimodal inputs."); + + embeds.Add(await LoadEmbedAsync(modelSession, attachment, cancellationToken)); + AppendMedia(promptBuilder, attachment, "Image", mediaMarker); + break; + case AttachmentKind.Audio: + if (!modelSession.IsMultiModal || !modelSession.SupportsAudio) + throw new Exception("This model does not support audio inputs."); + + embeds.Add(await LoadEmbedAsync(modelSession, attachment, cancellationToken)); + AppendMedia(promptBuilder, attachment, "Audio", mediaMarker); + break; + } + } + } + + if (embeds.Count > 0) + modelSession.QueueEmbeds(embeds); + + return new PreparedPrompt(promptBuilder.ToString()); + } + catch + { + foreach (var embed in embeds) + embed.Dispose(); + + throw; + } + } + + private static void AppendPdf(StringBuilder builder, AttachmentInfo attachment) + { + if (string.IsNullOrWhiteSpace(attachment.ExtractedText)) + return; + + builder.AppendLine(); + builder.AppendLine($"[PDF: {attachment.FileName}]"); + builder.AppendLine(attachment.ExtractedText.Trim()); + if (attachment.ExtractedTextTruncated) + builder.AppendLine("[PDF text truncated]"); + } + + private static void AppendMedia(StringBuilder builder, AttachmentInfo attachment, string mediaLabel, string mediaMarker) + { + builder.AppendLine(); + builder.AppendLine($"[{mediaLabel}: {attachment.FileName}]"); + builder.AppendLine(mediaMarker); + } + + private static string GetMediaMarker() + { + var mtmdParameters = MtmdContextParams.Default(); + return mtmdParameters.MediaMarker ?? NativeApi.MtmdDefaultMarker() ?? ""; + } + + private static void AppendWord(StringBuilder builder, AttachmentInfo attachment) + { + builder.AppendLine(); + builder.AppendLine($"[Word: {attachment.FileName}]"); + + if (string.IsNullOrWhiteSpace(attachment.ExtractedText)) + { + builder.AppendLine("[Word text could not be extracted]"); + return; + } + + builder.AppendLine(attachment.ExtractedText.Trim()); + if (attachment.ExtractedTextTruncated) + builder.AppendLine("[Word text truncated]"); + } + + private static async Task LoadEmbedAsync(ModelSession modelSession, AttachmentInfo attachment, CancellationToken cancellationToken) + { + if (!File.Exists(attachment.FilePath)) + throw new FileNotFoundException("Attachment not found.", attachment.FilePath); + + var data = await File.ReadAllBytesAsync(attachment.FilePath, cancellationToken); + var clipModel = modelSession.ClipModel; + if (clipModel is null) + throw new Exception("Multimodal model is not available for this session."); + + var embed = clipModel.LoadMedia(data); + if (embed == null) + throw new Exception("Failed to prepare multimodal embedding."); + + return embed; + } + + private sealed record PreparedPrompt(string Prompt); } diff --git a/LLama.Web/Shared/MainLayout.razor b/LLama.Web/Shared/MainLayout.razor new file mode 100644 index 000000000..224e0204e --- /dev/null +++ b/LLama.Web/Shared/MainLayout.razor @@ -0,0 +1,5 @@ +@inherits LayoutComponentBase + +
+ @Body +
diff --git a/LLama.Web/_Imports.razor b/LLama.Web/_Imports.razor new file mode 100644 index 000000000..143d6c514 --- /dev/null +++ b/LLama.Web/_Imports.razor @@ -0,0 +1,13 @@ +@using System.Net.Http +@using System.Net.Http.Json +@using System.IO +@using Microsoft.AspNetCore.Components +@using Microsoft.AspNetCore.Components.Forms +@using Microsoft.AspNetCore.Components.Routing +@using Microsoft.AspNetCore.Components.Web +@using Microsoft.Extensions.Options +@using Microsoft.JSInterop +@using LLama.Web.Common +@using LLama.Web.Models +@using LLama.Web.Services +@using LLama.Sampling diff --git a/LLama.Web/appsettings.json b/LLama.Web/appsettings.json index caa27cc64..41465acff 100644 --- a/LLama.Web/appsettings.json +++ b/LLama.Web/appsettings.json @@ -10,11 +10,15 @@ "ModelLoadType": 0, "Models": [ { - "Name": "Example LLava-v1.6-mistral", + "Name": "Qwen3-VL-2B-Instruct Q4_K_M", "MaxInstances": 20, - "ModelPath": "..\\LLama.Unittest\\Models\\llava-v1.6-mistral-7b.Q3_K_XS.gguf", - "ContextSize": 2048, - "BatchSize": 2048, + "ModelPath": "Models/Qwen3VL-2B-Instruct-Q4_K_M.gguf", + "ModelDownloadUrl": "https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF/resolve/main/Qwen3VL-2B-Instruct-Q4_K_M.gguf", + "MmprojPath": "Models/mmproj-Qwen3VL-2B-Instruct-Q8_0.gguf", + "MmprojDownloadUrl": "https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF/resolve/main/mmproj-Qwen3VL-2B-Instruct-Q8_0.gguf", + "ContextSize": 4096, + "BatchSize": 512, + "UBatchSize": 512, "Threads": 4, "GpuLayerCount": 32, "UseMemorymap": true, @@ -22,7 +26,67 @@ "MainGpu": 0, "LowVram": false, "Seed": 1686349486, - "UseFp16Memory": true, + "UseFp16Memory": false, + "Perplexity": false, + "LoraAdapter": "", + "LoraBase": "", + "EmbeddingMode": false, + "TensorSplits": null, + "GroupedQueryAttention": 1, + "RmsNormEpsilon": 0.000005, + "RopeFrequencyBase": 10000.0, + "RopeFrequencyScale": 1.0, + "MulMatQ": false, + "Encoding": "UTF-8" + }, + { + "Name": "Gemma-3-4B-IT Q4_K_M", + "MaxInstances": 20, + "ModelPath": "Models/google_gemma-3-4b-it-q4_k_m.gguf", + "ModelDownloadUrl": "https://huggingface.co/ggml-org/gemma-3-4b-it-GGUF/resolve/main/gemma-3-4b-it-Q4_K_M.gguf", + "MmprojPath": "Models/google_gemma-3-4b-it-mmproj-f16.gguf", + "MmprojDownloadUrl": "https://huggingface.co/ggml-org/gemma-3-4b-it-GGUF/resolve/main/mmproj-model-f16.gguf", + "ContextSize": 4096, + "BatchSize": 512, + "UBatchSize": 512, + "Threads": 4, + "GpuLayerCount": 32, + "UseMemorymap": true, + "UseMemoryLock": false, + "MainGpu": 0, + "LowVram": false, + "Seed": 1686349486, + "UseFp16Memory": false, + "Perplexity": false, + "LoraAdapter": "", + "LoraBase": "", + "EmbeddingMode": false, + "TensorSplits": null, + "GroupedQueryAttention": 1, + "RmsNormEpsilon": 0.000005, + "RopeFrequencyBase": 10000.0, + "RopeFrequencyScale": 1.0, + "MulMatQ": false, + "Encoding": "UTF-8" + }, + { + "Name": "Qwen2.5-Omni-3B Q4_K_M (Audio)", + "MaxInstances": 20, + "ModelPath": "Models/Qwen2.5-Omni-3B-Q4_K_M.gguf", + "ModelDownloadUrl": "https://huggingface.co/ggml-org/Qwen2.5-Omni-3B-GGUF/resolve/main/Qwen2.5-Omni-3B-Q4_K_M.gguf", + "MmprojPath": "Models/mmproj-Qwen2.5-Omni-3B-Q8_0.gguf", + "MmprojDownloadUrl": "https://huggingface.co/ggml-org/Qwen2.5-Omni-3B-GGUF/resolve/main/mmproj-Qwen2.5-Omni-3B-Q8_0.gguf", + "ContextSize": 4096, + "BatchSize": 512, + "UBatchSize": 512, + "Threads": 4, + "GpuLayerCount": 32, + "UseMemorymap": true, + "UseMemoryLock": false, + "MainGpu": 0, + "LowVram": false, + "Seed": 1686349486, + "UseFp16Memory": false, "Perplexity": false, "LoraAdapter": "", "LoraBase": "", diff --git a/LLama.Web/libman.json b/LLama.Web/libman.json index bb737e097..8a7b77c9b 100644 --- a/LLama.Web/libman.json +++ b/LLama.Web/libman.json @@ -5,6 +5,71 @@ { "library": "jquery@3.7.1", "destination": "wwwroot/lib/jquery/" + }, + { + "library": "katex@0.16.11", + "provider": "unpkg", + "destination": "wwwroot/lib/katex/", + "files": [ + "dist/katex.min.css", + "dist/katex.min.js" + ] + }, + { + "library": "markdown-it@14.1.0", + "provider": "unpkg", + "destination": "wwwroot/lib/markdown-it/", + "files": [ + "dist/markdown-it.min.js" + ] + }, + { + "library": "markdown-it-katex@2.0.3", + "provider": "unpkg", + "destination": "wwwroot/lib/markdown-it-katex/" + }, + { + "library": "markdown-it-task-lists@2.1.1", + "provider": "unpkg", + "destination": "wwwroot/lib/markdown-it-task-lists/" + }, + { + "library": "markdown-it-footnote@4.0.0", + "provider": "unpkg", + "destination": "wwwroot/lib/markdown-it-footnote/" + }, + { + "library": "markdown-it-deflist@2.1.0", + "provider": "unpkg", + "destination": "wwwroot/lib/markdown-it-deflist/" + }, + { + "library": "markdown-it-sub@2.0.0", + "provider": "unpkg", + "destination": "wwwroot/lib/markdown-it-sub/" + }, + { + "library": "markdown-it-sup@2.0.0", + "provider": "unpkg", + "destination": "wwwroot/lib/markdown-it-sup/" + }, + { + "library": "markdown-it-mark@4.0.0", + "provider": "unpkg", + "destination": "wwwroot/lib/markdown-it-mark/" + }, + { + "library": "markdown-it-emoji@3.0.0", + "provider": "unpkg", + "destination": "wwwroot/lib/markdown-it-emoji/" + }, + { + "library": "mermaid@10.9.1", + "provider": "unpkg", + "destination": "wwwroot/lib/mermaid/", + "files": [ + "dist/mermaid.min.js" + ] } ] -} \ No newline at end of file +} diff --git a/LLama.Web/wwwroot/css/site.css b/LLama.Web/wwwroot/css/site.css index f10d5e78c..217617950 100644 --- a/LLama.Web/wwwroot/css/site.css +++ b/LLama.Web/wwwroot/css/site.css @@ -1,106 +1,570 @@ @font-face { - font-family: 'Roboto'; + font-family: "Roboto"; src: url('/fonts/Roboto-Regular.ttf'); font-display: swap; } +:root { + color-scheme: dark; + --bg-0: #0b0d10; + --bg-1: #0f1218; + --bg-2: #141926; + --panel: #111722; + --panel-2: #161c2a; + --panel-3: #1b2233; + --border: rgba(255, 255, 255, 0.08); + --border-strong: rgba(255, 255, 255, 0.16); + --text-primary: #e7ecf3; + --text-muted: #9aa6b2; + --accent: #6ae3ff; + --accent-2: #66d1a3; + --warning: #ffb86b; + --danger: #ff7a7a; + --shadow: 0 18px 40px rgba(0, 0, 0, 0.35); +} + +* { + box-sizing: border-box; +} + html, body { - font-size: 14px; height: 100%; + margin: 0; + font-family: "Roboto", "Segoe UI", sans-serif; + background: radial-gradient(circle at top, #121826 0%, #0b0d10 45%, #090b0f 100%); + color: var(--text-primary); +} + +main { + min-height: 100%; +} + +.app-shell { + min-height: 100vh; + padding: 16px; +} + +.app-grid { + display: grid; + grid-template-columns: minmax(260px, 320px) 1fr; + gap: 16px; + min-height: calc(100vh - 32px); +} + +.sidebar { + background: linear-gradient(180deg, #0f1420 0%, #0b0d12 100%); + border: 1px solid var(--border); + border-radius: 18px; + padding: 18px; display: flex; flex-direction: column; - font-family: 'Roboto'; + gap: 18px; + box-shadow: var(--shadow); } -main { +.sidebar-header { + display: flex; + align-items: center; + justify-content: space-between; + gap: 12px; +} + +.brand { + display: flex; + align-items: center; + gap: 10px; +} + +.brand img { + width: 32px; + height: 32px; +} + +.brand h1 { + font-size: 1.1rem; + margin: 0; + letter-spacing: 0.02em; +} + +.sidebar-section { + background: var(--panel); + border: 1px solid var(--border); + border-radius: 14px; + padding: 12px; + display: flex; + flex-direction: column; + gap: 10px; +} + +.capability-row { + display: flex; + gap: 8px; + flex-wrap: wrap; +} + +.capability-pill { + padding: 6px 10px; + border-radius: 999px; + font-size: 0.75rem; + font-weight: 600; + border: 1px solid var(--border); +} + +.capability-pill.enabled { + background: rgba(106, 227, 255, 0.18); + color: var(--text-primary); + border-color: rgba(106, 227, 255, 0.3); +} + +.capability-pill.disabled { + background: rgba(255, 255, 255, 0.04); + color: var(--text-muted); +} + +.sidebar-tabs { + display: flex; + gap: 8px; + background: var(--panel); + border: 1px solid var(--border); + border-radius: 999px; + padding: 4px; +} + +.tab-button { flex: 1; - overflow: auto; + border: none; + background: transparent; + color: var(--text-muted); + padding: 8px 10px; + border-radius: 999px; + font-weight: 600; + cursor: pointer; } -footer { - bottom: 0; +.tab-button.active { + background: rgba(106, 227, 255, 0.18); + color: var(--text-primary); +} + +.section-title { + font-size: 0.9rem; + color: var(--text-muted); + letter-spacing: 0.08em; + text-transform: uppercase; +} + +.select, .text-input, .textarea { width: 100%; - white-space: nowrap; - line-height: 60px; + background: var(--panel-2); + border: 1px solid var(--border-strong); + color: var(--text-primary); + border-radius: 10px; + padding: 10px 12px; } +.textarea { + min-height: 80px; + resize: vertical; +} -@media (min-width: 768px) { - html { - font-size: 16px; - } +.button-row { + display: flex; + gap: 8px; +} + +.btn-primary, .btn-secondary, .btn-danger { + border: none; + padding: 10px 14px; + border-radius: 10px; + font-weight: 600; + cursor: pointer; +} + +.btn-primary { + background: linear-gradient(135deg, #4f8cff 0%, #6ae3ff 100%); + color: #0b0d10; +} + +.btn-secondary { + background: var(--panel-3); + color: var(--text-primary); + border: 1px solid var(--border); +} + +.btn-danger { + background: linear-gradient(135deg, #ff7a7a 0%, #ffb86b 100%); + color: #1b0c0c; +} + +.btn-primary:disabled, +.btn-secondary:disabled, +.btn-danger:disabled { + opacity: 0.5; + cursor: not-allowed; +} + +.sidebar-meta { + font-size: 0.75rem; + color: var(--text-muted); + line-height: 1.4; +} + +.download-card { + border-radius: 12px; + border: 1px solid var(--border); + padding: 10px; + background: var(--panel-2); + display: flex; + flex-direction: column; + gap: 6px; +} + +.download-card header { + display: flex; + justify-content: space-between; + font-weight: 600; +} + +.progress-track { + width: 100%; + height: 6px; + border-radius: 999px; + background: rgba(255, 255, 255, 0.06); + overflow: hidden; +} + +.progress-fill { + height: 100%; + background: linear-gradient(90deg, var(--accent) 0%, var(--accent-2) 100%); + border-radius: 999px; +} + +.download-meta { + font-size: 0.75rem; + color: var(--text-muted); +} + +.download-asset { + border-radius: 10px; + padding: 8px; + background: rgba(255, 255, 255, 0.03); + border: 1px solid rgba(255, 255, 255, 0.06); + display: flex; + flex-direction: column; + gap: 4px; +} + +.download-asset-title { + display: flex; + justify-content: space-between; + font-size: 0.8rem; +} + +.chat-column { + display: flex; + flex-direction: column; + gap: 12px; + background: linear-gradient(180deg, rgba(20, 25, 38, 0.92) 0%, rgba(12, 15, 22, 0.96) 100%); + border-radius: 22px; + border: 1px solid var(--border); + padding: 18px; + box-shadow: var(--shadow); + min-height: 100%; + height: calc(100vh - 32px); +} + +.chat-header { + display: flex; + align-items: center; + justify-content: space-between; + gap: 12px; +} + +.chat-title { + display: flex; + flex-direction: column; + gap: 4px; +} + +.chat-title h2 { + margin: 0; + font-size: 1.3rem; +} + +.chat-title span { + color: var(--text-muted); + font-size: 0.85rem; +} + +.status-pill { + padding: 6px 10px; + border-radius: 999px; + font-size: 0.75rem; + border: 1px solid var(--border); + background: var(--panel-2); } -#scroll-container { +.status-pill.live { + border-color: rgba(106, 227, 255, 0.5); + color: var(--accent); +} + +.status-pill.warn { + border-color: rgba(255, 184, 107, 0.5); + color: var(--warning); +} + +.chat-thread { flex: 1; overflow-y: auto; - border-radius: var(--bs-border-radius); + padding-right: 6px; + display: flex; + flex-direction: column; + gap: 18px; + min-height: 0; +} + +.chat-message { + display: flex; + gap: 12px; + align-items: flex-start; +} + +.chat-message .avatar { + width: 36px; + height: 36px; + border-radius: 12px; + background: var(--panel-3); + display: grid; + place-items: center; + font-weight: 700; + color: var(--text-muted); +} + +.chat-message.user .avatar { + background: rgba(106, 227, 255, 0.16); + color: var(--accent); +} + +.chat-message.assistant .avatar { + background: rgba(102, 209, 163, 0.16); + color: var(--accent-2); +} + +.chat-message.error .avatar { + background: rgba(255, 122, 122, 0.2); + color: var(--danger); +} + +.chat-message.info .avatar { + background: rgba(106, 227, 255, 0.12); + color: var(--accent); +} + +.chat-bubble { + background: var(--panel); + border-radius: 16px; + padding: 12px 14px; + border: 1px solid var(--border); + width: 100%; +} + +.chat-message.user .chat-bubble { + background: linear-gradient(135deg, rgba(79, 140, 255, 0.16), rgba(106, 227, 255, 0.12)); +} + +.chat-message.assistant .chat-bubble { + background: rgba(22, 28, 42, 0.85); +} + +.chat-message.info .chat-bubble { + border-color: rgba(106, 227, 255, 0.2); + background: rgba(20, 30, 45, 0.7); +} + +.message-meta { + margin-top: 8px; + font-size: 0.75rem; + color: var(--text-muted); +} + +.attachment-tags { + display: flex; + flex-wrap: wrap; + gap: 6px; + margin-top: 8px; +} + +.attachment-images { + display: flex; + flex-wrap: wrap; + gap: 10px; + margin-top: 10px; } -.slider-container > .slider { +.attachment-image { + max-width: min(360px, 100%); + border-radius: 12px; + border: 1px solid rgba(255, 255, 255, 0.12); + box-shadow: 0 8px 18px rgba(0, 0, 0, 0.25); +} + +.attachment-audio { + display: flex; + flex-direction: column; + gap: 8px; + margin-top: 10px; +} + +.attachment-audio-item { + display: flex; + flex-direction: column; + gap: 6px; +} + +.attachment-audio-name { + font-size: 0.8rem; + color: var(--text-muted); +} + +.attachment-audio audio { width: 100%; } -.slider-container > label { - width: 50px; - text-align: center; +.attachment-tag { + padding: 4px 8px; + border-radius: 999px; + background: rgba(255, 255, 255, 0.08); + font-size: 0.75rem; + color: var(--text-muted); } -.form-floating > textarea.form-control { - height: 100px; +.composer { + border-radius: 16px; + border: 1px solid var(--border); + padding: 12px; + background: var(--panel); + display: flex; + flex-direction: column; + gap: 10px; + position: sticky; + bottom: 0; } -.navbar-image { - height: 30px; - margin-left: 5px; - margin-right: 10px; +.draft-preview { + padding: 0 4px; } -.nav-item.theme { - margin-top: 2px; +.draft-bubble { + border-color: rgba(106, 227, 255, 0.35); + background: rgba(79, 140, 255, 0.12); } -/* Dark theme values */ -[data-bs-theme=dark] #scroll-container { - background-color: #202020; +.draft-text { + white-space: pre-wrap; + color: var(--text-primary); } -[data-bs-theme=dark] html, -[data-bs-theme=dark] body, -[data-bs-theme=dark] #sidebar, -[data-bs-theme=dark] header { - background-color: #191919; +.composer-error { + color: var(--danger); + font-size: 0.85rem; + padding: 4px 6px; + border-radius: 8px; + background: rgba(255, 122, 122, 0.08); + border: 1px solid rgba(255, 122, 122, 0.3); +} + +.composer textarea { + width: 100%; + background: transparent; + border: none; + resize: none; + color: var(--text-primary); + min-height: 72px; } -.entry { - padding-top: 10px; - padding-bottom: 20px; - padding-right: 10px; +.composer textarea:focus { + outline: none; } - .entry .signature, .entry .entry-date { - color: #808080; - font-size: 0.9em; - font-style: italic; +.composer-row { + display: flex; + align-items: center; + justify-content: space-between; + gap: 10px; +} + +.composer-actions { + display: flex; + gap: 8px; +} + +.file-button { + position: relative; + overflow: hidden; +} + +.file-input { + position: absolute; + inset: 0; + opacity: 0; + cursor: pointer; +} + +.file-pill { + padding: 6px 10px; + border-radius: 999px; + font-size: 0.75rem; + background: rgba(255, 255, 255, 0.08); + color: var(--text-muted); +} + +.markdown-body pre { + padding: 12px; + border-radius: 10px; + background-color: rgba(0, 0, 0, 0.25); + overflow-x: auto; +} + +.markdown-body code { + padding: 2px 4px; + border-radius: 4px; + background-color: rgba(0, 0, 0, 0.25); +} + +.markdown-body table { + width: 100%; + border-collapse: collapse; + margin-top: 8px; +} + +.markdown-body th, +.markdown-body td { + padding: 6px 8px; + border: 1px solid rgba(255, 255, 255, 0.1); +} + +.markdown-body img { + max-width: 100%; + border-radius: 10px; +} + +.markdown-body blockquote { + border-left: 3px solid var(--accent); + padding-left: 12px; + color: var(--text-muted); +} + +@media (max-width: 960px) { + .app-grid { + grid-template-columns: 1fr; } - .entry .content, .entry .signature { - padding-left: 5px; - margin-left: 50px; + .sidebar { + order: 2; } - .entry .entry-header .participant { - font-weight: bold; - font-size: 1.4em; - padding-bottom: 5px; - } - - .entry .entry-header .participant .participant-icon { - display: inline-block; - width: 50px; - text-align: center - } - - .entry .entry-header .participant .participant-name { - display: inline-block; - } + .chat-column { + order: 1; + } +} diff --git a/LLama.Web/wwwroot/js/blazorChat.js b/LLama.Web/wwwroot/js/blazorChat.js new file mode 100644 index 000000000..5099e04b0 --- /dev/null +++ b/LLama.Web/wwwroot/js/blazorChat.js @@ -0,0 +1,193 @@ +window.llamaChat = (() => { + let markdown = null; + let audioRecorder = null; + let audioChunks = []; + let audioStream = null; + + const ensureMarkdown = () => { + if (markdown) { + return markdown; + } + + markdown = window.markdownit({ + html: false, + linkify: true, + breaks: true + }); + + if (window.markdownitKatex) { + markdown.use(window.markdownitKatex); + } + if (window.markdownitTaskLists) { + markdown.use(window.markdownitTaskLists); + } + if (window.markdownitFootnote) { + markdown.use(window.markdownitFootnote); + } + if (window.markdownitDeflist) { + markdown.use(window.markdownitDeflist); + } + if (window.markdownitSub) { + markdown.use(window.markdownitSub); + } + if (window.markdownitSup) { + markdown.use(window.markdownitSup); + } + if (window.markdownitMark) { + markdown.use(window.markdownitMark); + } + if (window.markdownitEmoji) { + markdown.use(window.markdownitEmoji); + } + + if (window.mermaid) { + window.mermaid.initialize({ startOnLoad: false, theme: "dark" }); + } + + return markdown; + }; + + const renderMermaid = (container) => { + if (!window.mermaid || !container) { + return; + } + + const codeBlocks = container.querySelectorAll("pre > code.language-mermaid, pre > code.lang-mermaid, code.language-mermaid"); + if (!codeBlocks.length) { + return; + } + + codeBlocks.forEach((code) => { + const parent = code.parentElement; + const mermaidDiv = document.createElement("div"); + mermaidDiv.className = "mermaid"; + mermaidDiv.textContent = code.textContent || ""; + parent.replaceWith(mermaidDiv); + }); + + window.mermaid.run({ nodes: container.querySelectorAll('.mermaid') }); + }; + + const isAudioRecordingSupported = () => { + return !!(navigator.mediaDevices && navigator.mediaDevices.getUserMedia && window.MediaRecorder); + }; + + const getSupportedAudioMimeType = () => { + if (!window.MediaRecorder || !MediaRecorder.isTypeSupported) { + return ""; + } + + const candidates = [ + "audio/webm;codecs=opus", + "audio/webm", + "audio/ogg;codecs=opus", + "audio/ogg", + "audio/mp4", + "audio/mpeg" + ]; + + return candidates.find((candidate) => MediaRecorder.isTypeSupported(candidate)) || ""; + }; + + const cleanupAudioStream = () => { + if (audioStream) { + audioStream.getTracks().forEach((track) => track.stop()); + audioStream = null; + } + }; + + const bufferToBase64 = (buffer) => { + const bytes = new Uint8Array(buffer); + let binary = ""; + const chunkSize = 0x8000; + for (let i = 0; i < bytes.length; i += chunkSize) { + binary += String.fromCharCode.apply(null, bytes.subarray(i, i + chunkSize)); + } + return btoa(binary); + }; + + return { + initialize: () => { + ensureMarkdown(); + }, + renderMarkdown: (element, text) => { + if (!element) { + return; + } + + const md = ensureMarkdown(); + element.innerHTML = md.render(text || ""); + renderMermaid(element); + }, + scrollToBottom: (element) => { + if (!element) { + return; + } + element.scrollTop = element.scrollHeight; + }, + getSidebarTab: () => { + return window.localStorage.getItem("llama.sidebar.tab"); + }, + setSidebarTab: (value) => { + if (!value) { + return; + } + window.localStorage.setItem("llama.sidebar.tab", value); + }, + isAudioRecordingSupported: () => { + return isAudioRecordingSupported(); + }, + startAudioRecording: async () => { + if (!isAudioRecordingSupported()) { + return { started: false, error: "Audio recording is not supported in this browser." }; + } + + if (audioRecorder && audioRecorder.state === "recording") { + return { started: false, error: "Audio recording is already in progress." }; + } + + try { + audioStream = await navigator.mediaDevices.getUserMedia({ audio: true }); + const mimeType = getSupportedAudioMimeType(); + const options = mimeType ? { mimeType } : undefined; + audioRecorder = new MediaRecorder(audioStream, options); + audioChunks = []; + + audioRecorder.ondataavailable = (event) => { + if (event.data && event.data.size > 0) { + audioChunks.push(event.data); + } + }; + + audioRecorder.start(); + return { started: true, mimeType: audioRecorder.mimeType || mimeType || "" }; + } catch (err) { + cleanupAudioStream(); + return { started: false, error: err?.message || "Unable to access the microphone." }; + } + }, + stopAudioRecording: () => { + if (!audioRecorder || audioRecorder.state !== "recording") { + cleanupAudioStream(); + audioRecorder = null; + audioChunks = []; + return Promise.resolve({ base64: "", mimeType: "", size: 0 }); + } + + return new Promise((resolve) => { + audioRecorder.onstop = async () => { + const blobType = audioRecorder.mimeType || "audio/webm"; + const blob = new Blob(audioChunks, { type: blobType }); + const buffer = await blob.arrayBuffer(); + const base64 = bufferToBase64(buffer); + cleanupAudioStream(); + audioRecorder = null; + audioChunks = []; + resolve({ base64, mimeType: blob.type, size: blob.size }); + }; + + audioRecorder.stop(); + }); + } + }; +})(); diff --git a/LLama.Web/wwwroot/js/sessionConnectionChat.js b/LLama.Web/wwwroot/js/sessionConnectionChat.js index 8ecf0177a..0afd1f35b 100644 --- a/LLama.Web/wwwroot/js/sessionConnectionChat.js +++ b/LLama.Web/wwwroot/js/sessionConnectionChat.js @@ -6,13 +6,30 @@ const createConnectionSessionChat = () => { const signatureTemplate = $("#signatureTemplate").html(); let inferenceSession; + let connectionId; const connection = new signalR.HubConnectionBuilder().withUrl("/SessionConnectionHub").build(); const scrollContainer = $("#scroll-container"); const outputContainer = $("#output-container"); const chatInput = $("#input"); - - const onStatus = (connection, status) => { + const modelSelect = $("#model-select"); + const downloadList = $("#download-list"); + const downloadEmpty = $("#download-empty"); + const attachmentInput = $("#attachments"); + const attachmentList = $("#attachment-list"); + const clearAttachmentsButton = $("#clear-attachments"); + const modelsPath = $("#models-path"); + const downloadsPath = $("#downloads-path"); + const uploadsPath = $("#uploads-path"); + + const downloadState = {}; + let pendingAttachments = []; + + const renderQueue = new Map(); + let renderScheduled = false; + + const onStatus = (id, status) => { + connectionId = id; if (status == Enums.SessionConnectionStatus.Disconnected) { onError("Socket not connected") } @@ -24,7 +41,6 @@ const createConnectionSessionChat = () => { enableControls(); $("#load").hide(); $("#unload").show(); - onInfo(`New model session successfully started`) } } @@ -37,9 +53,18 @@ const createConnectionSessionChat = () => { outputContainer.append(Mustache.render(outputInfoTemplate, { text: message, date: getDateTime() })); } + const onStorageInfo = (info) => { + if (!info) + return; + + modelsPath.text(info.modelsPath || "-"); + downloadsPath.text(info.downloadsPath || "-"); + uploadsPath.text(info.uploadsPath || "-"); + } + let responseContent; let responseContainer; - let responseFirstToken; + let responseRaw = ""; const onResponse = (response) => { if (!response) @@ -50,7 +75,8 @@ const createConnectionSessionChat = () => { outputContainer.append(Mustache.render(outputBotTemplate, { uniqueId: uniqueId, ...response })); responseContainer = $(`#${uniqueId}`); responseContent = responseContainer.find(".content"); - responseFirstToken = true; + responseRaw = ""; + responseContent.text("..."); scrollToBottom(true); return; } @@ -61,34 +87,53 @@ const createConnectionSessionChat = () => { scrollToBottom(); } else { - if (responseFirstToken) { - responseContent.empty(); - responseFirstToken = false; - responseContainer.find(".date").append(getDateTime()); - responseContent.append(response.content.trim()); - } - else { - responseContent.append(response.content); + if (responseContent) { + if (responseRaw.length === 0) { + responseContainer.find(".date").append(getDateTime()); + } + responseRaw += response.content; + scheduleRender(responseContent, responseRaw); + scrollToBottom(); } - scrollToBottom(); } } const sendPrompt = async () => { const text = chatInput.val(); - if (text) { - chatInput.val(null); - disableControls(); - outputContainer.append(Mustache.render(outputUserTemplate, { text: text, date: getDateTime() })); - inferenceSession = await connection - .stream("SendPrompt", text, serializeFormToJson('SessionParameters')) - .subscribe({ - next: onResponse, - complete: onResponse, - error: onError, - }); - scrollToBottom(true); + if (!text) + return; + + if (!isSelectedModelReady()) { + onError("Selected model is still downloading."); + return; } + + chatInput.val(null); + disableControls(); + + let attachments = []; + if (pendingAttachments.length > 0) { + try { + attachments = await uploadAttachments(); + } catch (error) { + onError(error.message || "Failed to upload attachments."); + enableControls(); + return; + } + } + + outputContainer.append(Mustache.render(outputUserTemplate, { text: text, date: getDateTime(), attachments: attachments })); + renderMarkdownAsync(outputContainer.find(".entry:last .content"), text); + + const request = { prompt: text, attachmentIds: attachments.map(a => a.id) }; + inferenceSession = await connection + .stream("SendPrompt", request, serializeFormToJson('SessionParameters')) + .subscribe({ + next: onResponse, + complete: onResponse, + error: onError, + }); + scrollToBottom(true); } const cancelPrompt = async () => { @@ -97,6 +142,11 @@ const createConnectionSessionChat = () => { } const loadModel = async () => { + if (!isSelectedModelReady()) { + onError("Selected model is still downloading."); + return; + } + const sessionParams = serializeFormToJson('SessionParameters'); loaderShow(); disableControls(); @@ -114,6 +164,31 @@ const createConnectionSessionChat = () => { $("#load").removeAttr("disabled"); } + const uploadAttachments = async () => { + if (pendingAttachments.length === 0) + return []; + + if (!connectionId) + throw new Error("Connection not ready for uploads."); + + const formData = new FormData(); + formData.append("connectionId", connectionId); + pendingAttachments.forEach(file => formData.append("files", file)); + + onInfo("Uploading attachments..."); + const response = await fetch("/api/attachments", { method: "POST", body: formData }); + if (!response.ok) { + const errorText = await response.text(); + throw new Error(errorText || "Attachment upload failed."); + } + + const result = await response.json(); + pendingAttachments = []; + attachmentInput.val(""); + renderAttachmentList(); + return result.attachments || []; + } + const serializeFormToJson = (form) => { const formDataJson = {}; const formData = new FormData(document.getElementById(form)); @@ -150,6 +225,7 @@ const createConnectionSessionChat = () => { $("#unload").hide(); $(".prompt-control").removeAttr("disabled"); activatePromptTab(); + updateLoadButtonState(); } const disablePromptControls = () => { @@ -205,6 +281,254 @@ const createConnectionSessionChat = () => { $(".spinner").hide(); } + const createMarkdownRenderer = () => { + if (typeof window.markdownit === "undefined") + return null; + + const md = window.markdownit({ + html: false, + linkify: true, + breaks: true, + typographer: true + }); + + if (window.markdownitKatex) + md.use(window.markdownitKatex); + if (window.markdownitTaskLists) + md.use(window.markdownitTaskLists, { enabled: true, label: true }); + if (window.markdownitFootnote) + md.use(window.markdownitFootnote); + if (window.markdownitDeflist) + md.use(window.markdownitDeflist); + if (window.markdownitSub) + md.use(window.markdownitSub); + if (window.markdownitSup) + md.use(window.markdownitSup); + if (window.markdownitMark) + md.use(window.markdownitMark); + if (window.markdownitEmoji) + md.use(window.markdownitEmoji); + + if (window.mermaid) { + window.mermaid.initialize({ startOnLoad: false }); + } + + return md; + } + + const markdown = createMarkdownRenderer(); + + const scheduleRender = (container, raw) => { + if (!markdown) { + container.text(raw); + return; + } + + renderQueue.set(container[0], { container, raw }); + if (renderScheduled) + return; + + renderScheduled = true; + requestAnimationFrame(() => { + renderScheduled = false; + renderQueue.forEach((value) => { + renderMarkdown(value.container, value.raw); + }); + renderQueue.clear(); + }); + } + + const renderMarkdownAsync = (container, raw) => { + scheduleRender(container, raw); + } + + const renderMarkdown = (container, raw) => { + try { + container.html(markdown.render(raw || "")); + renderMermaid(container); + } catch (error) { + container.text(raw); + } + } + + const renderMermaid = (container) => { + if (!window.mermaid) + return; + + const blocks = container.find("pre code.language-mermaid"); + if (!blocks.length) + return; + + blocks.each((index, block) => { + const code = $(block); + const parent = code.parent(); + const mermaidDiv = $("
"); + mermaidDiv.text(code.text()); + parent.replaceWith(mermaidDiv); + }); + + window.mermaid.run({ nodes: container[0].querySelectorAll('.mermaid') }); + } + + const renderAttachmentList = () => { + attachmentList.empty(); + if (pendingAttachments.length === 0) + return; + + pendingAttachments.forEach(file => { + attachmentList.append(`
${file.name}
`); + }); + } + + const formatBytes = (value) => { + if (value === null || value === undefined) + return ""; + if (value === 0) + return "0 B"; + + const k = 1024; + const sizes = ["B", "KB", "MB", "GB", "TB"]; + const i = Math.floor(Math.log(value) / Math.log(k)); + return `${(value / Math.pow(k, i)).toFixed(1)} ${sizes[i]}`; + } + + const updateDownloadState = (progress) => { + if (!progress || !progress.modelName) + return; + + if (!downloadState[progress.modelName]) + downloadState[progress.modelName] = { assets: {}, ready: false }; + + const assetKey = `${progress.assetKind}-${progress.fileName}`; + downloadState[progress.modelName].assets[assetKey] = progress; + refreshDownloadUI(); + } + + const applyDownloadSnapshot = (snapshots) => { + if (!snapshots) + return; + + snapshots.forEach(snapshot => { + downloadState[snapshot.modelName] = { assets: {}, ready: snapshot.ready }; + snapshot.assets.forEach(asset => { + const assetKey = `${asset.assetKind}-${asset.fileName}`; + downloadState[snapshot.modelName].assets[assetKey] = asset; + }); + }); + + refreshDownloadUI(); + } + + const refreshDownloadUI = () => { + downloadList.empty(); + let anyActive = false; + + Object.keys(downloadState).forEach(modelName => { + const modelEntry = downloadState[modelName]; + const assets = Object.values(modelEntry.assets); + if (assets.length === 0) + return; + + const ready = assets.every(asset => asset.state === 2); + modelEntry.ready = ready; + + const totalBytes = assets.reduce((sum, asset) => sum + (asset.totalBytes || 0), 0); + const receivedBytes = assets.reduce((sum, asset) => sum + (asset.bytesReceived || 0), 0); + const hasUnknownTotal = assets.some(asset => asset.totalBytes === null || asset.totalBytes === undefined); + const hasFailure = assets.some(asset => asset.state === 3 || asset.state === 4); + const isDownloading = assets.some(asset => asset.state === 1); + + const statusText = ready + ? "Ready" + : hasFailure + ? "Error" + : isDownloading + ? "Downloading" + : "Queued"; + + const progressPercent = totalBytes > 0 ? Math.min(100, Math.round((receivedBytes / totalBytes) * 100)) : 0; + const remainingText = hasUnknownTotal + ? "Calculating size..." + : `${formatBytes(Math.max(0, totalBytes - receivedBytes))} remaining`; + + if (!ready) + anyActive = true; + + const assetRows = assets.map(asset => { + const assetTotal = asset.totalBytes || 0; + const assetPercent = assetTotal > 0 ? Math.min(100, Math.round((asset.bytesReceived / assetTotal) * 100)) : 0; + const assetStatus = asset.state === 2 + ? "Ready" + : asset.state === 3 + ? "Failed" + : asset.state === 4 + ? "Missing URL" + : asset.state === 1 + ? "Downloading" + : "Queued"; + const assetMeta = assetTotal > 0 + ? `${formatBytes(asset.bytesReceived)} / ${formatBytes(assetTotal)}` + : asset.bytesReceived > 0 + ? formatBytes(asset.bytesReceived) + : "-"; + + return `
+
+
${asset.fileName}
+
${assetStatus}
+
+
+
+
+
${assetMeta}
+
`; + }).join(""); + + const item = $( + `
+
+
${modelName}
+
${statusText}
+
+
+
+
+
${remainingText}
+
${assetRows}
+
` + ); + + downloadList.append(item); + updateModelOptionState(modelName, ready); + }); + + downloadEmpty.toggle(!anyActive); + updateLoadButtonState(); + } + + const updateModelOptionState = (modelName, ready) => { + const option = modelSelect.find(`option[value='${modelName}']`); + option.prop("disabled", !ready); + } + + const isSelectedModelReady = () => { + const selected = modelSelect.val(); + if (!selected) + return false; + + const entry = downloadState[selected]; + return entry ? entry.ready : false; + } + + const updateLoadButtonState = () => { + const ready = isSelectedModelReady(); + if (ready) { + $("#load").removeAttr("disabled"); + } else { + $("#load").attr("disabled", "disabled"); + } + } + // Map UI functions $("#load").on("click", loadModel); $("#unload").on("click", unloadModel); @@ -212,6 +536,16 @@ const createConnectionSessionChat = () => { $("#clear").on("click", clearOutput); $("#cancel").on("click", cancelPrompt); $("#Prompt").on("change", updatePrompt); + modelSelect.on("change", updateLoadButtonState); + attachmentInput.on("change", () => { + pendingAttachments = Array.from(attachmentInput[0].files || []); + renderAttachmentList(); + }); + clearAttachmentsButton.on("click", () => { + pendingAttachments = []; + attachmentInput.val(""); + renderAttachmentList(); + }); chatInput.on('keydown', function (event) { if (event.key === 'Enter' && !event.shiftKey) { event.preventDefault(); @@ -228,5 +562,13 @@ const createConnectionSessionChat = () => { connection.on("OnStatus", onStatus); connection.on("OnError", onError); connection.on("OnResponse", onResponse); - connection.start(); -} \ No newline at end of file + connection.on("OnModelDownloadSnapshot", applyDownloadSnapshot); + connection.on("OnModelDownloadProgress", updateDownloadState); + connection.on("OnStorageInfo", onStorageInfo); + connection.start().then(() => { + connection.invoke("GetModelStatuses").then(applyDownloadSnapshot); + setInterval(() => { + connection.invoke("GetModelStatuses").then(applyDownloadSnapshot); + }, 2000); + }); +} From 466a8cb3e334bf72ad73607eaf0dfa47d45413dc Mon Sep 17 00:00:00 2001 From: SignalRT Date: Thu, 22 Jan 2026 22:44:08 +0100 Subject: [PATCH 2/8] Add Missing Files --- LLama.Web/Models/AttachmentInfo.cs | 32 ++++++++++++++++++ LLama.Web/Models/MemoryBrowserFile.cs | 34 +++++++++++++++++++ LLama.Web/Models/ModelCapabilities.cs | 8 +++++ LLama.Web/Models/ModelDownloadStatus.cs | 45 +++++++++++++++++++++++++ LLama.Web/Models/PromptRequest.cs | 7 ++++ LLama.Web/Models/StorageInfo.cs | 8 +++++ 6 files changed, 134 insertions(+) create mode 100644 LLama.Web/Models/AttachmentInfo.cs create mode 100644 LLama.Web/Models/MemoryBrowserFile.cs create mode 100644 LLama.Web/Models/ModelCapabilities.cs create mode 100644 LLama.Web/Models/ModelDownloadStatus.cs create mode 100644 LLama.Web/Models/PromptRequest.cs create mode 100644 LLama.Web/Models/StorageInfo.cs diff --git a/LLama.Web/Models/AttachmentInfo.cs b/LLama.Web/Models/AttachmentInfo.cs new file mode 100644 index 000000000..65dbd436b --- /dev/null +++ b/LLama.Web/Models/AttachmentInfo.cs @@ -0,0 +1,32 @@ +namespace LLama.Web.Models; + +public enum AttachmentKind +{ + Unknown = 0, + Pdf = 1, + Image = 2, + Audio = 3, + Word = 4 +} + +public class AttachmentInfo +{ + public string Id { get; set; } + [System.Text.Json.Serialization.JsonIgnore] + public string ConnectionId { get; set; } + public string FileName { get; set; } + public string ContentType { get; set; } + [System.Text.Json.Serialization.JsonIgnore] + public string FilePath { get; set; } + public AttachmentKind Kind { get; set; } + public long SizeBytes { get; set; } + [System.Text.Json.Serialization.JsonIgnore] + public string ExtractedText { get; set; } + [System.Text.Json.Serialization.JsonIgnore] + public bool ExtractedTextTruncated { get; set; } +} + +public class AttachmentUploadResult +{ + public List Attachments { get; set; } = new(); +} diff --git a/LLama.Web/Models/MemoryBrowserFile.cs b/LLama.Web/Models/MemoryBrowserFile.cs new file mode 100644 index 000000000..a79938acc --- /dev/null +++ b/LLama.Web/Models/MemoryBrowserFile.cs @@ -0,0 +1,34 @@ +using Microsoft.AspNetCore.Components.Forms; + +namespace LLama.Web.Models; + +public sealed class MemoryBrowserFile : IBrowserFile +{ + private readonly byte[] _data; + + public MemoryBrowserFile(string name, string contentType, byte[] data, DateTimeOffset? lastModified = null) + { + if (string.IsNullOrWhiteSpace(name)) + throw new ArgumentException("File name is required.", nameof(name)); + if (string.IsNullOrWhiteSpace(contentType)) + throw new ArgumentException("Content type is required.", nameof(contentType)); + _data = data ?? throw new ArgumentNullException(nameof(data)); + + Name = name; + ContentType = contentType; + LastModified = lastModified ?? DateTimeOffset.UtcNow; + } + + public string Name { get; } + public DateTimeOffset LastModified { get; } + public long Size => _data.LongLength; + public string ContentType { get; } + + public Stream OpenReadStream(long maxAllowedSize = 512L * 1024 * 1024, CancellationToken cancellationToken = default) + { + if (Size > maxAllowedSize) + throw new IOException($"File {Name} exceeds the maximum allowed size of {maxAllowedSize} bytes."); + + return new MemoryStream(_data, writable: false); + } +} diff --git a/LLama.Web/Models/ModelCapabilities.cs b/LLama.Web/Models/ModelCapabilities.cs new file mode 100644 index 000000000..86a449b98 --- /dev/null +++ b/LLama.Web/Models/ModelCapabilities.cs @@ -0,0 +1,8 @@ +namespace LLama.Web.Models; + +public class ModelCapabilities +{ + public bool SupportsText { get; set; } + public bool SupportsVision { get; set; } + public bool SupportsAudio { get; set; } +} diff --git a/LLama.Web/Models/ModelDownloadStatus.cs b/LLama.Web/Models/ModelDownloadStatus.cs new file mode 100644 index 000000000..785178f4f --- /dev/null +++ b/LLama.Web/Models/ModelDownloadStatus.cs @@ -0,0 +1,45 @@ +namespace LLama.Web.Models; + +public enum ModelAssetKind +{ + Model = 0, + Mmproj = 1 +} + +public enum ModelDownloadState +{ + Queued = 0, + Downloading = 1, + Completed = 2, + Failed = 3, + MissingUrl = 4 +} + +public class ModelAssetSnapshot +{ + public string ModelName { get; set; } + public string FileName { get; set; } + public ModelAssetKind AssetKind { get; set; } + public ModelDownloadState State { get; set; } + public long? TotalBytes { get; set; } + public long BytesReceived { get; set; } + public string Error { get; set; } +} + +public class ModelDownloadSnapshot +{ + public string ModelName { get; set; } + public bool Ready { get; set; } + public List Assets { get; set; } = new(); +} + +public class ModelDownloadProgress +{ + public string ModelName { get; set; } + public string FileName { get; set; } + public ModelAssetKind AssetKind { get; set; } + public ModelDownloadState State { get; set; } + public long? TotalBytes { get; set; } + public long BytesReceived { get; set; } + public string Error { get; set; } +} diff --git a/LLama.Web/Models/PromptRequest.cs b/LLama.Web/Models/PromptRequest.cs new file mode 100644 index 000000000..cc807eea1 --- /dev/null +++ b/LLama.Web/Models/PromptRequest.cs @@ -0,0 +1,7 @@ +namespace LLama.Web.Models; + +public class PromptRequest +{ + public string Prompt { get; set; } + public List AttachmentIds { get; set; } = new(); +} diff --git a/LLama.Web/Models/StorageInfo.cs b/LLama.Web/Models/StorageInfo.cs new file mode 100644 index 000000000..60d9e36ce --- /dev/null +++ b/LLama.Web/Models/StorageInfo.cs @@ -0,0 +1,8 @@ +namespace LLama.Web.Models; + +public class StorageInfo +{ + public string ModelsPath { get; set; } + public string DownloadsPath { get; set; } + public string UploadsPath { get; set; } +} From 85e9b6d5e2d158e3f4ad338fa3a1c1452a76d33f Mon Sep 17 00:00:00 2001 From: SignalRT Date: Sun, 15 Mar 2026 20:41:52 +0100 Subject: [PATCH 3/8] Improve MTMD execution and web audio/chat handling - Reworked MTMD prompt handling to preserve text/media ordering and evaluate multimodal input incrementally. - Disabled unsupported multimodal features such as session persistence and context shifting. - Added standalone MTMD media loading and synchronized MTMD weight operations. - Updated MTMD example and tests to cover prompt ordering, guards, and opt-in NoCI execution. - Fixed web model/session defaults for multimodal models, including template-derived stop markers and unspecified pooling. - Improved LLama.Web audio attachment/recording flow, Qwen audio prompt handling, and chat composer UX. - Removed the broken browser script include and added a safe markdown fallback. --- .../Examples/MtmdInteractiveModeExecute.cs | 147 +++++-- LLama.Unittest/MtmdContextGuardTests.cs | 265 ++++++++++++ LLama.Unittest/MtmdExecutorTests.cs | 167 ++++++-- LLama.Unittest/MtmdNoCiCollection.cs | 71 ++++ LLama.Unittest/MtmdWeightsTests.cs | 79 ++-- LLama.Web/Common/ModelOptions.cs | 2 +- LLama.Web/LLama.Web.csproj | 1 + LLama.Web/Models/ModelSession.cs | 252 ++++++++++-- LLama.Web/Pages/Chat.razor | 121 ++++-- LLama.Web/Pages/Shared/_Layout.cshtml | 1 - LLama.Web/Services/ModelService.cs | 2 +- LLama.Web/Services/ModelSessionService.cs | 59 ++- LLama.Web/appsettings.json | 70 +--- LLama.Web/wwwroot/js/blazorChat.js | 385 +++++++++++++++++- LLama/AntipromptProcessor.cs | 2 +- LLama/ChatSession.cs | 22 +- LLama/LLamaExecutorBase.cs | 284 +++++++++---- LLama/LLamaInstructExecutor.cs | 77 ++-- LLama/LLamaInteractExecutor.cs | 136 ++++--- LLama/MtmdWeights.cs | 63 ++- LLama/Native/SafeMtmdModelHandle.cs | 36 +- 21 files changed, 1792 insertions(+), 450 deletions(-) create mode 100644 LLama.Unittest/MtmdContextGuardTests.cs create mode 100644 LLama.Unittest/MtmdNoCiCollection.cs diff --git a/LLama.Examples/Examples/MtmdInteractiveModeExecute.cs b/LLama.Examples/Examples/MtmdInteractiveModeExecute.cs index 571a89c1a..5c61854ab 100644 --- a/LLama.Examples/Examples/MtmdInteractiveModeExecute.cs +++ b/LLama.Examples/Examples/MtmdInteractiveModeExecute.cs @@ -15,14 +15,14 @@ namespace LLama.Examples.Examples // It uses the interactive executor to inference. public class MtmdInteractiveModeExecute { + private sealed record TemplateMarkers(string? AssistantEndMarker, string? AssistantToUserMarker); + public static async Task Run() { string multiModalProj = UserSettings.GetMMProjPath(); string modelPath = UserSettings.GetModelPath(); const int maxTokens = 4096; - string? prompt = await File.ReadAllTextAsync("Assets/chat-with-bob.json"); - var parameters = new ModelParams(modelPath); var mtmdParameters = MtmdContextParams.Default(); @@ -40,8 +40,8 @@ public static async Task Run() var ex = new InteractiveExecutor(context, clipModel); var chatHistory = new ChatHistory(); - var isFirstTurn = true; - + var templateMarkers = ResolveTemplateMarkers(model); + var antiPrompts = GetEffectiveAntiPrompts(templateMarkers); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The executor has been enabled. In this example the maximum tokens is set to {0} and the context size is {1}.", maxTokens, parameters.ContextSize ); Console.WriteLine("Model: {0}", modelPath); @@ -54,6 +54,7 @@ public static async Task Run() Console.WriteLine("Video inputs are not supported (format would be {{c:/video.mp4}})."); Console.WriteLine("Commands: /exit (return to main menu) | /clear (reset the chat history and KV cache)."); Console.WriteLine("Press Ctrl+c to return to main menu."); + Console.Write("User: "); void ResetConversation() { @@ -64,7 +65,7 @@ void ResetConversation() clipModel.ClearMedia(); chatHistory.Messages.Clear(); ex = new InteractiveExecutor(context, clipModel); - Console.WriteLine("User:"); + Console.Write("User: "); } var inferenceParams = new InferenceParams @@ -74,34 +75,34 @@ void ResetConversation() Temperature = 0.1f }, - AntiPrompts = new List { "User:" }, - MaxTokens = maxTokens + AntiPrompts = antiPrompts, + MaxTokens = maxTokens, + DecodeSpecialTokens = ShouldDecodeSpecialTokens(antiPrompts) }; do { - if (!isFirstTurn) - { - Console.ForegroundColor = ConsoleColor.Green; - prompt = Console.ReadLine(); - Console.WriteLine(); + Console.ForegroundColor = ConsoleColor.Green; + var prompt = Console.ReadLine(); + Console.WriteLine(); - if (prompt == null || prompt.Equals("/exit", StringComparison.OrdinalIgnoreCase)) - break; + if (prompt == null || prompt.Equals("/exit", StringComparison.OrdinalIgnoreCase)) + break; - if (prompt.Equals("/clear", StringComparison.OrdinalIgnoreCase)) - { - ResetConversation(); - continue; - } + if (prompt.Equals("/clear", StringComparison.OrdinalIgnoreCase)) + { + ResetConversation(); + continue; } - else + + if (string.IsNullOrWhiteSpace(prompt)) { - isFirstTurn = false; + Console.Write("User: "); + continue; } - var userPrompt = prompt ?? string.Empty; + var userPrompt = prompt; // Evaluate if we have media // @@ -171,7 +172,7 @@ void ResetConversation() audioList.Add(mediaPath); } - var embed = clipModel.LoadMedia(mediaPath); + var embed = clipModel.LoadMediaStandalone(mediaPath); embeds.Add(embed); } } @@ -183,7 +184,8 @@ void ResetConversation() Console.Write($"{exception.Message}"); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("Please try again."); - clipModel.ClearMedia(); + foreach (var embed in embeds) + embed.Dispose(); break; } @@ -236,8 +238,9 @@ void ResetConversation() Console.Write(text); responseBuilder.Append(text); } - Console.Write(" "); + Console.WriteLine(); chatHistory.AddMessage(AuthorRole.Assistant, responseBuilder.ToString()); + Console.Write("User: "); } while(true); @@ -259,6 +262,45 @@ private static string BuildChatDelta(LLamaWeights model, ChatHistory history, st return delta; } + private static List GetEffectiveAntiPrompts(TemplateMarkers templateMarkers) + { + var antiPrompts = new List(); + AddMarker(antiPrompts, templateMarkers.AssistantEndMarker); + AddMarker(antiPrompts, templateMarkers.AssistantToUserMarker); + + if (antiPrompts.Count == 0) + antiPrompts.Add("User:"); + + return antiPrompts; + } + + private static void AddMarker(List values, string? marker) + { + if (string.IsNullOrWhiteSpace(marker)) + return; + + if (!values.Contains(marker)) + values.Add(marker); + + var trimmedMarker = marker.Trim(); + if (!string.IsNullOrWhiteSpace(trimmedMarker) && !values.Contains(trimmedMarker)) + values.Add(trimmedMarker); + } + + private static bool ShouldDecodeSpecialTokens(IReadOnlyList antiPrompts) + { + foreach (var antiPrompt in antiPrompts) + { + if (string.IsNullOrWhiteSpace(antiPrompt)) + continue; + + if (antiPrompt.Contains('<', StringComparison.Ordinal) || antiPrompt.Contains('>', StringComparison.Ordinal)) + return true; + } + + return false; + } + private static string FormatChatHistory(LLamaWeights model, ChatHistory history, bool addAssistant) { var template = new LLamaTemplate(model.NativeHandle) @@ -271,5 +313,60 @@ private static string FormatChatHistory(LLamaWeights model, ChatHistory history, return LLamaTemplate.Encoding.GetString(template.Apply()); } + + private static TemplateMarkers ResolveTemplateMarkers(LLamaWeights model) + { + const string userMarkerA = "__LLAMA_USER_A__"; + const string assistantMarkerA = "__LLAMA_ASSISTANT_A__"; + const string userMarkerB = "__LLAMA_USER_B__"; + + try + { + var assistantTemplate = new LLamaTemplate(model.NativeHandle) + { + AddAssistant = false + }; + assistantTemplate.Add("user", userMarkerA); + assistantTemplate.Add("assistant", assistantMarkerA); + + var assistantRendered = LLamaTemplate.Encoding.GetString(assistantTemplate.Apply()); + var assistantIndex = assistantRendered.IndexOf(assistantMarkerA, StringComparison.Ordinal); + if (assistantIndex < 0) + return new TemplateMarkers(null, null); + + var assistantEndMarker = assistantRendered[(assistantIndex + assistantMarkerA.Length)..]; + + var conversationTemplate = new LLamaTemplate(model.NativeHandle) + { + AddAssistant = false + }; + conversationTemplate.Add("user", userMarkerA); + conversationTemplate.Add("assistant", assistantMarkerA); + conversationTemplate.Add("user", userMarkerB); + + var conversationRendered = LLamaTemplate.Encoding.GetString(conversationTemplate.Apply()); + var assistantConversationIndex = conversationRendered.IndexOf(assistantMarkerA, StringComparison.Ordinal); + var userIndex = conversationRendered.IndexOf(userMarkerB, StringComparison.Ordinal); + if (assistantConversationIndex < 0 || userIndex <= assistantConversationIndex) + return new TemplateMarkers(NormalizeMarker(assistantEndMarker), null); + + var assistantToUserMarker = conversationRendered.Substring( + assistantConversationIndex + assistantMarkerA.Length, + userIndex - (assistantConversationIndex + assistantMarkerA.Length)); + + return new TemplateMarkers( + NormalizeMarker(assistantEndMarker), + NormalizeMarker(assistantToUserMarker)); + } + catch + { + return new TemplateMarkers(null, null); + } + } + + private static string? NormalizeMarker(string? marker) + { + return string.IsNullOrWhiteSpace(marker) ? null : marker; + } } } diff --git a/LLama.Unittest/MtmdContextGuardTests.cs b/LLama.Unittest/MtmdContextGuardTests.cs new file mode 100644 index 000000000..d2ee7b156 --- /dev/null +++ b/LLama.Unittest/MtmdContextGuardTests.cs @@ -0,0 +1,265 @@ +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.IO; +using System.Linq; +using System.Reflection; +using System.Threading.Tasks; +using LLama.Common; +using LLama.Exceptions; +using LLama.Native; +using Microsoft.Extensions.Logging.Abstractions; + +namespace LLama.Unittest; + +public sealed class MtmdContextGuardTests : IDisposable +{ + private const string EnableEnvVar = "LLAMASHARP_RUN_MTMD_TESTS"; + private const string SkipReason = "MTMD tests are opt-in. Set LLAMASHARP_RUN_MTMD_TESTS=1 to run them."; + + private readonly bool _isEnabled; + private readonly LLamaWeights? _weights; + private readonly ModelParams? _modelParams; + + public MtmdContextGuardTests() + { + _isEnabled = string.Equals(Environment.GetEnvironmentVariable(EnableEnvVar), "1", StringComparison.OrdinalIgnoreCase) + || string.Equals(Environment.GetEnvironmentVariable(EnableEnvVar), "true", StringComparison.OrdinalIgnoreCase); + + if (!_isEnabled) + return; + + NativeLogConfig.llama_log_set(static (_, _) => { }); + + _modelParams = new ModelParams(Constants.GenerativeModelPath2) + { + ContextSize = 128, + BatchSize = 8, + UBatchSize = 8, + VocabOnly = false, + GpuLayerCount = Constants.CIGpuLayerCount, + }; + _weights = LLamaWeights.LoadFromFile(_modelParams); + } + + public void Dispose() + { + if (!_isEnabled) + return; + + _weights.Dispose(); + NativeLogConfig.llama_log_set((NativeLogConfig.LLamaLogCallback?)null); + } + + [SkippableFact] + public void InteractiveExecutor_LoadMtmdPromptSegments_PreservesOrderedPromptAndMediaBoundary() + { + Skip.IfNot(_isEnabled, SkipReason); + + using var context = _weights!.CreateContext(_modelParams!, NullLogger.Instance); + using var chunks = new SafeMtmdInputChunks(NativeApi.mtmd_test_create_input_chunks()); + var executor = new TrackingInteractiveExecutor(context); + var fillerToken = (LLamaToken)99; + + var (imageTokenCount, imagePositionCount) = GetSyntheticImageChunkMetadata(chunks); + executor.LoadPendingMtmd(chunks, fillerToken); + + Assert.Equal( + new[] { (LLamaToken)1, 2, 3, 4, 5 }.Concat(Enumerable.Repeat(fillerToken, imageTokenCount)), + executor.EmbedInputs); + Assert.Equal(5 + imagePositionCount, executor.PendingPositions); + + executor.QueuePromptText(); + + Assert.Equal(new[] { (LLamaToken)1, 2, 3, 4, 5 }, executor.PendingEmbeds); + Assert.True(executor.PendingMediaSegment); + + executor.ClearPendingEmbeds(); + + Assert.Equal(imagePositionCount, executor.PendingPositions); + } + + [SkippableFact] + public void InteractiveExecutor_MtmdOverflowThrowsWithoutContextShift() + { + Skip.IfNot(_isEnabled, SkipReason); + + using var context = _weights!.CreateContext(_modelParams!, NullLogger.Instance); + using var chunks = new SafeMtmdInputChunks(NativeApi.mtmd_test_create_input_chunks()); + var executor = new TrackingInteractiveExecutor(context); + + executor.LoadPendingMtmd(chunks, fillerToken: 99, pastTokensCount: (int)context.ContextSize - 1, pendingEmbeds: [(LLamaToken)1]); + + Assert.True(executor.PendingPositions > 1); + + var ex = Assert.Throws(() => executor.CheckContextGuard(new InferenceParams())); + + Assert.Contains("Context shifting is not supported", ex.Message); + } + + [SkippableFact] + public void InstructExecutor_MtmdOverflowThrowsWithoutContextShift() + { + Skip.IfNot(_isEnabled, SkipReason); + + using var context = _weights!.CreateContext(_modelParams!, NullLogger.Instance); + using var chunks = new SafeMtmdInputChunks(NativeApi.mtmd_test_create_input_chunks()); + var executor = new TrackingInstructExecutor(context); + + executor.LoadPendingMtmd(chunks, fillerToken: 99, pastTokensCount: (int)context.ContextSize - 1, pendingEmbeds: [(LLamaToken)1]); + + Assert.True(executor.PendingPositions > 1); + + var ex = Assert.Throws(() => executor.CheckContextGuard()); + + Assert.Contains("Context shifting is not supported", ex.Message); + } + + [SkippableFact] + public async Task InteractiveExecutor_MultimodalSessionAndStateApisAreRejected() + { + Skip.IfNot(_isEnabled, SkipReason); + + using var context = _weights!.CreateContext(_modelParams!, NullLogger.Instance); + var executor = new TrackingInteractiveExecutor(context); + var statePath = Path.Combine(Path.GetTempPath(), $"{Guid.NewGuid():N}.json"); + + Assert.Throws(() => executor.WithSessionFile(statePath)); + Assert.Throws(() => executor.SaveSessionFile(statePath)); + Assert.Throws(() => executor.GetStateData()); + await Assert.ThrowsAsync(() => executor.SaveState(statePath)); + await Assert.ThrowsAsync(() => executor.LoadState(statePath)); + } + + [SkippableFact] + public async Task InstructExecutor_MultimodalSessionAndStateApisAreRejected() + { + Skip.IfNot(_isEnabled, SkipReason); + + using var context = _weights!.CreateContext(_modelParams!, NullLogger.Instance); + var executor = new TrackingInstructExecutor(context); + var statePath = Path.Combine(Path.GetTempPath(), $"{Guid.NewGuid():N}.json"); + + Assert.Throws(() => executor.WithSessionFile(statePath)); + Assert.Throws(() => executor.SaveSessionFile(statePath)); + Assert.Throws(() => executor.GetStateData()); + await Assert.ThrowsAsync(() => executor.SaveState(statePath)); + await Assert.ThrowsAsync(() => executor.LoadState(statePath)); + } + + [SkippableFact] + public void ChatSession_MultimodalSessionPersistenceApisAreRejected() + { + Skip.IfNot(_isEnabled, SkipReason); + + using var context = _weights!.CreateContext(_modelParams!, NullLogger.Instance); + var executor = new TrackingInteractiveExecutor(context); + var chatSession = new ChatSession(executor); + var sessionPath = Path.Combine(Path.GetTempPath(), Guid.NewGuid().ToString("N")); + var state = new SessionState( + contextState: null, + executorState: new InteractiveExecutor.InteractiveExecutorState(), + history: new ChatHistory(), + inputTransformPipeline: [], + outputTransform: new LLamaTransforms.EmptyTextOutputStreamTransform(), + historyTransform: new LLamaTransforms.DefaultHistoryTransform()); + + Assert.Throws(() => chatSession.GetSessionState()); + Assert.Throws(() => chatSession.SaveSession(sessionPath)); + Assert.Throws(() => chatSession.LoadSession(state)); + Assert.Throws(() => chatSession.LoadSession(sessionPath)); + } + + private static (int ImageTokenCount, int ImagePositionCount) GetSyntheticImageChunkMetadata(SafeMtmdInputChunks chunks) + { + foreach (var chunk in chunks.Enumerate()) + { + using var scopedChunk = chunk; + if (scopedChunk.Type == SafeMtmdInputChunk.SafeMtmdInputChunkType.Image) + return (checked((int)scopedChunk.NTokens), checked((int)scopedChunk.NPos)); + } + + throw new InvalidOperationException("Synthetic MTMD test chunks do not contain an image segment."); + } + + private static MtmdWeights CreateFakeMtmdWeights() + => (MtmdWeights)RuntimeHelpers.GetUninitializedObject(typeof(MtmdWeights)); + + private static class TrackingMultimodalExecutorBase + { + private static readonly FieldInfo ClipModelField = typeof(StatefulExecutorBase) + .GetField("k__BackingField", BindingFlags.Instance | BindingFlags.NonPublic) + ?? throw new InvalidOperationException("Unable to find ClipModel backing field."); + + public static void MarkAsMultimodal(StatefulExecutorBase executor) + { + ClipModelField.SetValue(executor, CreateFakeMtmdWeights()); + } + } + + private sealed class TrackingInteractiveExecutor : InteractiveExecutor + { + public TrackingInteractiveExecutor(LLamaContext context) + : base(context, NullLogger.Instance) + { + TrackingMultimodalExecutorBase.MarkAsMultimodal(this); + } + + public int PendingPositions => GetPendingInputPositionCount(); + + public IReadOnlyList EmbedInputs => _embed_inps; + + public IReadOnlyList PendingEmbeds => _embeds; + + public bool PendingMediaSegment => HasPendingMtmdMediaSegment(); + + public void LoadPendingMtmd(SafeMtmdInputChunks chunks, LLamaToken fillerToken, int pastTokensCount = 0, IReadOnlyList? pendingEmbeds = null) + { + LoadMtmdPromptSegments(chunks, fillerToken, replaceExisting: true); + _pastTokensCount = pastTokensCount; + _embeds = pendingEmbeds?.ToList() ?? []; + } + + public void QueuePromptText() + => QueueNextMtmdPromptInput(); + + public void ClearPendingEmbeds() + => _embeds.Clear(); + + public void CheckContextGuard(InferenceParams inferenceParams) + { + var tokensToKeep = inferenceParams.TokensKeep; + if (tokensToKeep < 0 || tokensToKeep > _embed_inps.Count) + { + tokensToKeep = _embed_inps.Count; + } + else + { + tokensToKeep += Convert.ToInt32(Context.Vocab.ShouldAddBOS); + } + + EnsurePendingInputFitsContext(tokensToKeep); + } + } + + private sealed class TrackingInstructExecutor : InstructExecutor + { + public TrackingInstructExecutor(LLamaContext context) + : base(context, logger: NullLogger.Instance) + { + TrackingMultimodalExecutorBase.MarkAsMultimodal(this); + } + + public int PendingPositions => GetPendingInputPositionCount(); + + public void LoadPendingMtmd(SafeMtmdInputChunks chunks, LLamaToken fillerToken, int pastTokensCount = 0, IReadOnlyList? pendingEmbeds = null) + { + LoadMtmdPromptSegments(chunks, fillerToken, replaceExisting: true); + _pastTokensCount = pastTokensCount; + _embeds = pendingEmbeds?.ToList() ?? []; + } + + public void CheckContextGuard() + => EnsurePendingInputFitsContext(_embed_inps.Count); + } +} diff --git a/LLama.Unittest/MtmdExecutorTests.cs b/LLama.Unittest/MtmdExecutorTests.cs index 7aa89e2f9..041d10c9f 100644 --- a/LLama.Unittest/MtmdExecutorTests.cs +++ b/LLama.Unittest/MtmdExecutorTests.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Threading.Tasks; using LLama.Common; using LLama.Native; @@ -8,74 +9,158 @@ namespace LLama.Unittest; +[Collection("MtmdNoCi")] [Trait("Category", "NoCI")] -public class MtmdExecutorTests : IDisposable +public class MtmdExecutorTests { - private readonly LLamaWeights _weights; - private readonly MtmdContextParams _mtmdParams; - private readonly MtmdWeights _mtmd; - private readonly ModelParams _modelParams; + private readonly MtmdNoCiFixture _fixture; - public MtmdExecutorTests() + public MtmdExecutorTests(MtmdNoCiFixture fixture) { - _modelParams = new ModelParams(Constants.MtmdModelPath) - { - ContextSize = 1024 * 32, - GpuLayerCount = Constants.CIGpuLayerCount, - }; - - _weights = LLamaWeights.LoadFromFile(_modelParams); - - _mtmdParams = MtmdContextParams.Default(); - _mtmdParams.NThreads = Math.Max(1, Constants.CIGpuLayerCount); - _mtmdParams.UseGpu = false; - - _mtmd = MtmdWeights.LoadFromFile(Constants.MtmdMmpPath, _weights, _mtmdParams); - } - - public void Dispose() - { - _mtmd.Dispose(); - _weights.Dispose(); + _fixture = fixture; } - [Fact] + [SkippableFact] public async Task InteractiveExecutor_EvaluateChunks_DoesNotRetokenize() { - using var context = _weights.CreateContext(_modelParams, NullLogger.Instance); - var executor = new InteractiveExecutor(context, _mtmd, NullLogger.Instance); - var marker = _mtmdParams.MediaMarker ?? NativeApi.MtmdDefaultMarker() ?? ""; - var prompt = $"{marker}\nDescribe the image succinctly."; + Skip.IfNot(_fixture.IsEnabled, MtmdNoCiFixture.SkipReason); - executor.Embeds.Add(_mtmd.LoadMedia(Constants.MtmdImage)); + using var context = _fixture.CreateContext(); + var executor = new InteractiveExecutor(context, _fixture.Mtmd, NullLogger.Instance); + var marker = _fixture.MediaMarker; + var prompt = $"Before {marker} after."; + + executor.Embeds.Add(_fixture.Mtmd.LoadMedia(Constants.MtmdImage)); await foreach (var _ in executor.InferAsync(prompt, new InferenceParams { MaxTokens = 0 })) { - Assert.True(false, "Prefill should not emit generated text"); + Assert.Fail("Prefill should not emit generated text"); } var diagnostics = executor.GetDiagnostics(); - Assert.Equal(diagnostics.EmbedCount, diagnostics.ConsumedCount); - Assert.Equal(diagnostics.ConsumedCount, diagnostics.PastCount); + Assert.True(diagnostics.ConsumedCount > 0); Assert.Equal(0, diagnostics.PendingEmbedCount); } - [Fact] + [SkippableFact] public async Task InstructExecutor_MtmdPromptAdvancesPastTokensOnce() { - using var context = _weights.CreateContext(_modelParams, NullLogger.Instance); - var executor = new InstructExecutor(context, _mtmd, logger: NullLogger.Instance); - executor.Embeds.Add(_mtmd.LoadMedia(Constants.MtmdImage)); + Skip.IfNot(_fixture.IsEnabled, MtmdNoCiFixture.SkipReason); + + using var context = _fixture.CreateContext(); + var executor = new InstructExecutor(context, _fixture.Mtmd, logger: NullLogger.Instance); + executor.Embeds.Add(_fixture.Mtmd.LoadMedia(Constants.MtmdImage)); - var prompt = $"{_mtmdParams.MediaMarker ?? NativeApi.MtmdDefaultMarker() ?? ""} Provide details."; + var prompt = $"{_fixture.MediaMarker} Provide details."; await foreach (var _ in executor.InferAsync(prompt, new InferenceParams { MaxTokens = 0 })) { } var diagnostics = executor.GetDiagnostics(); - Assert.Equal(diagnostics.EmbedCount, diagnostics.ConsumedCount); - Assert.Equal(diagnostics.ConsumedCount, diagnostics.PastCount); + Assert.True(diagnostics.ConsumedCount > 0); Assert.Equal(0, diagnostics.PendingEmbedCount); } + + [SkippableFact] + public async Task InteractiveExecutor_PreprocessMtmd_PreservesInterleavedPromptOrder() + { + Skip.IfNot(_fixture.IsEnabled, MtmdNoCiFixture.SkipReason); + + using var context = _fixture.CreateContext(); + var executor = new InspectableInteractiveExecutor(context, _fixture.Mtmd); + var marker = _fixture.MediaMarker; + var prompt = $"Before {marker} after."; + using var expectedChunks = TokenizePromptWithMedia(prompt); + var fillerToken = GetFillerToken(context, marker); + var expectedPromptTokens = BuildLogicalPromptTokens(expectedChunks, fillerToken); + var expectedLeadingTextTokens = GetLeadingTextTokens(expectedChunks); + + executor.Embeds.Add(_fixture.Mtmd.LoadMedia(Constants.MtmdImage)); + await executor.PreparePromptAsync(prompt, addBos: true, replaceExisting: true); + + Assert.Equal(expectedPromptTokens, executor.EmbedInputs); + Assert.Equal((int)expectedChunks.CountPositions(), executor.PendingPositions); + + executor.QueuePromptText(); + + Assert.Equal(expectedLeadingTextTokens, executor.PendingEmbeds); + Assert.True(executor.PendingMediaSegment); + } + + private SafeMtmdInputChunks TokenizePromptWithMedia(string prompt) + { + _fixture.Mtmd.ClearMedia(); + using var embed = _fixture.Mtmd.LoadMedia(Constants.MtmdImage); + + var status = _fixture.Mtmd.Tokenize(prompt, addSpecial: true, parseSpecial: true, out var chunks); + Assert.Equal(0, status); + Assert.NotNull(chunks); + + return chunks!; + } + + private static LLamaToken GetFillerToken(LLamaContext context, string marker) + { + var markerTokens = context.Tokenize(marker, false, true); + return markerTokens.Length > 0 + ? markerTokens[^1] + : context.Vocab.EOS ?? default; + } + + private static IReadOnlyList BuildLogicalPromptTokens(SafeMtmdInputChunks chunks, LLamaToken fillerToken) + { + var tokens = new List(); + + foreach (var chunk in chunks.Enumerate()) + { + using var scopedChunk = chunk; + if (scopedChunk.Type == SafeMtmdInputChunk.SafeMtmdInputChunkType.Text) + { + tokens.AddRange(scopedChunk.GetTextTokensSpan().ToArray().Select(static token => (LLamaToken)token)); + } + else + { + tokens.AddRange(Enumerable.Repeat(fillerToken, checked((int)scopedChunk.NTokens))); + } + } + + return tokens; + } + + private static IReadOnlyList GetLeadingTextTokens(SafeMtmdInputChunks chunks) + { + foreach (var chunk in chunks.Enumerate()) + { + using var scopedChunk = chunk; + if (scopedChunk.Type == SafeMtmdInputChunk.SafeMtmdInputChunkType.Text) + return scopedChunk.GetTextTokensSpan().ToArray().Select(static token => (LLamaToken)token).ToArray(); + + break; + } + + throw new InvalidOperationException("Expected the tokenized multimodal prompt to begin with a text chunk."); + } + + private sealed class InspectableInteractiveExecutor : InteractiveExecutor + { + public InspectableInteractiveExecutor(LLamaContext context, MtmdWeights clipModel) + : base(context, clipModel, NullLogger.Instance) + { + } + + public IReadOnlyList EmbedInputs => _embed_inps; + + public IReadOnlyList PendingEmbeds => _embeds; + + public int PendingPositions => GetPendingInputPositionCount(); + + public bool PendingMediaSegment => HasPendingMtmdMediaSegment(); + + public Task PreparePromptAsync(string prompt, bool addBos, bool replaceExisting) + => PreprocessMtmd(prompt, new InferStateArgs(), addBos, replaceExisting); + + public void QueuePromptText() + => QueueNextMtmdPromptInput(); + } } diff --git a/LLama.Unittest/MtmdNoCiCollection.cs b/LLama.Unittest/MtmdNoCiCollection.cs new file mode 100644 index 000000000..38733ffc7 --- /dev/null +++ b/LLama.Unittest/MtmdNoCiCollection.cs @@ -0,0 +1,71 @@ +using System; +using LLama.Common; +using LLama.Native; +using Microsoft.Extensions.Logging.Abstractions; +using Xunit; + +namespace LLama.Unittest; + +[CollectionDefinition("MtmdNoCi", DisableParallelization = true)] +public sealed class MtmdNoCiCollection : ICollectionFixture +{ +} + +public sealed class MtmdNoCiFixture : IDisposable +{ + public const string EnableEnvVar = "LLAMASHARP_RUN_NOCI_MTMD"; + public const string SkipReason = "MTMD NoCI tests are opt-in. Set LLAMASHARP_RUN_NOCI_MTMD=1 to run them."; + + public MtmdNoCiFixture() + { + IsEnabled = string.Equals(Environment.GetEnvironmentVariable(EnableEnvVar), "1", StringComparison.OrdinalIgnoreCase) + || string.Equals(Environment.GetEnvironmentVariable(EnableEnvVar), "true", StringComparison.OrdinalIgnoreCase); + + if (!IsEnabled) + return; + + NativeLogConfig.llama_log_set(static (_, _) => { }); + + ModelParams = new ModelParams(Constants.MtmdModelPath) + { + ContextSize = 1024 * 32, + GpuLayerCount = Constants.CIGpuLayerCount, + }; + + Weights = LLamaWeights.LoadFromFile(ModelParams); + + MtmdParams = MtmdContextParams.Default(); + MtmdParams.NThreads = Math.Max(1, Constants.CIGpuLayerCount); + MtmdParams.UseGpu = false; + + Mtmd = MtmdWeights.LoadFromFile(Constants.MtmdMmpPath, Weights, MtmdParams); + MediaMarker = MtmdParams.MediaMarker ?? NativeApi.MtmdDefaultMarker() ?? ""; + } + + public bool IsEnabled { get; } + + public ModelParams ModelParams { get; } = null!; + + public LLamaWeights Weights { get; } = null!; + + public MtmdContextParams MtmdParams { get; } = null!; + + public MtmdWeights Mtmd { get; } = null!; + + public string MediaMarker { get; } = string.Empty; + + public LLamaContext CreateContext() + => IsEnabled + ? Weights.CreateContext(ModelParams, NullLogger.Instance) + : throw new InvalidOperationException(SkipReason); + + public void Dispose() + { + if (!IsEnabled) + return; + + Mtmd.Dispose(); + Weights.Dispose(); + NativeLogConfig.llama_log_set((NativeLogConfig.LLamaLogCallback?)null); + } +} diff --git a/LLama.Unittest/MtmdWeightsTests.cs b/LLama.Unittest/MtmdWeightsTests.cs index 4a1a59c75..7dc6cc737 100644 --- a/LLama.Unittest/MtmdWeightsTests.cs +++ b/LLama.Unittest/MtmdWeightsTests.cs @@ -5,45 +5,21 @@ namespace LLama.Unittest { // Test the same things as llama model + image embedings // + [Collection("MtmdNoCi")] public sealed class MtmdWeightTests - : IDisposable { - private readonly LLamaWeights _llamaWeights; - private readonly MtmdWeights _mtmdWeights; - private readonly LLamaContext _context; - private readonly MtmdContextParams _mtmdParams; + private readonly MtmdNoCiFixture _fixture; private readonly string _mediaMarker; - public MtmdWeightTests() + public MtmdWeightTests(MtmdNoCiFixture fixture) { - var @params = new ModelParams(Constants.MtmdModelPath) - { - // Mtmd models requires big context - ContextSize = 1024 * 32, - GpuLayerCount = Constants.CIGpuLayerCount, - }; - _llamaWeights = LLamaWeights.LoadFromFile(@params); - - _mtmdParams = MtmdContextParams.Default(); - _mtmdParams.NThreads = Constants.CIGpuLayerCount; - _mtmdParams.UseGpu = false; // keep tests portable across environments without GPU - - _mediaMarker = _mtmdParams.MediaMarker ?? throw new InvalidOperationException("MTMD media marker unavailable."); - - _mtmdWeights = Task.Run(async () => await MtmdWeights.LoadFromFileAsync(Constants.MtmdMmpPath, _llamaWeights, _mtmdParams)).Result; - _context = _llamaWeights.CreateContext(@params); - } - - public void Dispose() - { - _context.Dispose(); - _mtmdWeights.Dispose(); - _llamaWeights.Dispose(); + _fixture = fixture; + _mediaMarker = _fixture.MediaMarker; } private SafeMtmdInputChunks TokenizeWithEmbed(Func loadEmbed) { - _mtmdWeights.ClearMedia(); + _fixture.Mtmd.ClearMedia(); var embed = loadEmbed(); Assert.NotNull(embed); @@ -58,7 +34,7 @@ private SafeMtmdInputChunks TokenizeWithEmbed(Func loadEmbed) using var mem = embed.GetData(); Assert.True(mem.Data.Length > 0); - var status = _mtmdWeights.Tokenize(_mediaMarker, addSpecial: true, parseSpecial: true, out var chunks); + var status = _fixture.Mtmd.Tokenize(_mediaMarker, addSpecial: true, parseSpecial: true, out var chunks); Assert.Equal(0, status); Assert.NotNull(chunks); @@ -68,41 +44,50 @@ private SafeMtmdInputChunks TokenizeWithEmbed(Func loadEmbed) private void AssertChunksEvaluate(SafeMtmdInputChunks chunks) { + using var context = _fixture.CreateContext(); int nPast = 0; - var eval = _mtmdWeights.EvaluateChunks(chunks, _context.NativeHandle, ref nPast, seqId: 0, nBatch: checked((int)_context.BatchSize), logitsLast: true); + var eval = _fixture.Mtmd.EvaluateChunks(chunks, context.NativeHandle, ref nPast, seqId: 0, nBatch: checked((int)context.BatchSize), logitsLast: true); Assert.Equal(0, eval); Assert.True(nPast > 0); } - [Fact, Trait("Category", "NoCI")] + [SkippableFact, Trait("Category", "NoCI")] public void BasicPropertyChecks() { - Assert.False(_mtmdWeights.SupportsAudio); - Assert.True(_mtmdWeights.SupportsVision); - Assert.False(_mtmdWeights.UsesMRope); - Assert.True(_mtmdWeights.UsesNonCausalAttention); - Assert.Equal(-1, _mtmdWeights.AudioBitrate); + Skip.IfNot(_fixture.IsEnabled, MtmdNoCiFixture.SkipReason); + + Assert.False(_fixture.Mtmd.SupportsAudio); + Assert.True(_fixture.Mtmd.SupportsVision); + Assert.False(_fixture.Mtmd.UsesMRope); + Assert.True(_fixture.Mtmd.UsesNonCausalAttention); + Assert.Equal(-1, _fixture.Mtmd.AudioBitrate); } - [Fact,Trait("Category", "NoCI")] + [SkippableFact, Trait("Category", "NoCI")] public void EmbedImageAsFileName() { - using var chunks = TokenizeWithEmbed(() => _mtmdWeights.LoadMedia(Constants.MtmdImage)); + Skip.IfNot(_fixture.IsEnabled, MtmdNoCiFixture.SkipReason); + + using var chunks = TokenizeWithEmbed(() => _fixture.Mtmd.LoadMedia(Constants.MtmdImage)); AssertChunksEvaluate(chunks); } - [Fact,Trait("Category", "NoCI")] + [SkippableFact, Trait("Category", "NoCI")] public void EmbedImageAsBinary() { + Skip.IfNot(_fixture.IsEnabled, MtmdNoCiFixture.SkipReason); + var imageBytes = File.ReadAllBytes(Constants.MtmdImage); - using var chunks = TokenizeWithEmbed(() => _mtmdWeights.LoadMedia(imageBytes)); + using var chunks = TokenizeWithEmbed(() => _fixture.Mtmd.LoadMedia(imageBytes)); AssertChunksEvaluate(chunks); } - [Fact,Trait("Category", "NoCI")] + [SkippableFact, Trait("Category", "NoCI")] public void TokenizeProvidesChunkMetadata() { - using var chunks = TokenizeWithEmbed(() => _mtmdWeights.LoadMedia(Constants.MtmdImage)); + Skip.IfNot(_fixture.IsEnabled, MtmdNoCiFixture.SkipReason); + + using var chunks = TokenizeWithEmbed(() => _fixture.Mtmd.LoadMedia(Constants.MtmdImage)); Assert.True(chunks.Size > 0); @@ -140,10 +125,10 @@ public void TokenizeProvidesChunkMetadata() Assert.True(totalTokens > 0); Assert.Equal(totalTokens, chunks.CountTokens()); Assert.Equal(totalPositions, chunks.CountPositions()); - Assert.True(_mtmdWeights.SupportsVision); - Assert.False(_mtmdWeights.SupportsAudio); + Assert.True(_fixture.Mtmd.SupportsVision); + Assert.False(_fixture.Mtmd.SupportsAudio); - var audioBitrate = _mtmdWeights.AudioBitrate; + var audioBitrate = _fixture.Mtmd.AudioBitrate; Assert.True(audioBitrate <= 0); } } diff --git a/LLama.Web/Common/ModelOptions.cs b/LLama.Web/Common/ModelOptions.cs index 6668abbab..b7342f392 100644 --- a/LLama.Web/Common/ModelOptions.cs +++ b/LLama.Web/Common/ModelOptions.cs @@ -141,7 +141,7 @@ public class ModelOptions public float? DefragThreshold { get; set; } /// - public LLamaPoolingType PoolingType { get; set; } + public LLamaPoolingType PoolingType { get; set; } = LLamaPoolingType.Unspecified; /// public LLamaAttentionType AttentionType { get; set; } = LLamaAttentionType.Unspecified; diff --git a/LLama.Web/LLama.Web.csproj b/LLama.Web/LLama.Web.csproj index 5569b0f95..35e7d2c02 100644 --- a/LLama.Web/LLama.Web.csproj +++ b/LLama.Web/LLama.Web.csproj @@ -15,6 +15,7 @@ + diff --git a/LLama.Web/Models/ModelSession.cs b/LLama.Web/Models/ModelSession.cs index 53f76d8c7..8c3938d68 100644 --- a/LLama.Web/Models/ModelSession.cs +++ b/LLama.Web/Models/ModelSession.cs @@ -4,39 +4,72 @@ using LLama.Native; using LLama.Sampling; using LLama.Web.Common; +using Microsoft.Extensions.Logging; namespace LLama.Web.Models; public class ModelSession { + private sealed record TemplateMarkers(string? AssistantEndMarker, string? AssistantToUserMarker); + + private static readonly HashSet LegacyDefaultAntiPrompts = new(StringComparer.Ordinal) + { + "User:", + "<|im_end|>", + "<|eot_id|>", + "<|endoftext|>" + }; + + private static readonly HashSet LegacyDefaultOutputFilters = new(StringComparer.Ordinal) + { + "User:", + "Assistant:", + "<|im_end|>", + "<|eot_id|>", + "<|endoftext|>" + }; + private readonly string _sessionId; private readonly LLamaModel _model; private readonly LLamaContext _context; private readonly ILLamaExecutor _executor; private readonly ISessionConfig _sessionConfig; + private readonly ILogger? _logger; + private readonly List _configuredAntiPrompts; + private readonly List _configuredOutputFilters; private readonly ITextStreamTransform _outputTransform; private readonly InferenceOptions _defaultInferenceConfig; private readonly ChatHistory _chatHistory = new(); private readonly string? _systemPrompt; - private readonly string? _templateUserMarker; + private readonly TemplateMarkers _templateMarkers; private readonly object _historyLock = new(); private CancellationTokenSource _cancellationTokenSource; - public ModelSession(LLamaModel model, LLamaContext context, string sessionId, ISessionConfig sessionConfig, InferenceOptions inferenceOptions = null) + public ModelSession(LLamaModel model, LLamaContext context, string sessionId, ISessionConfig sessionConfig, InferenceOptions inferenceOptions = null, ILogger? logger = null) { _model = model; _context = context; _sessionId = sessionId; _sessionConfig = sessionConfig; + _logger = logger; _defaultInferenceConfig = inferenceOptions ?? new InferenceOptions { SamplingPipeline = new DefaultSamplingPipeline() }; - _outputTransform = CreateOutputFilter(); - _executor = CreateExecutor(); _systemPrompt = _sessionConfig.Prompt; - _templateUserMarker = ResolveTemplateUserMarker(); + _configuredAntiPrompts = _sessionConfig.GetAntiPrompts(); + _configuredOutputFilters = _sessionConfig.GetOutputFilters(); + _executor = CreateExecutor(); + _templateMarkers = ResolveTemplateMarkers(); + _outputTransform = CreateOutputFilter(); + _logger?.LogInformation( + "Session {SessionId} created for model {ModelName} using {ExecutorType}. Template markers: assistant-end={AssistantEndMarker}; assistant-to-user={AssistantToUserMarker}.", + _sessionId, + _sessionConfig.Model, + _sessionConfig.ExecutorType, + EscapeForLog(_templateMarkers.AssistantEndMarker), + EscapeForLog(_templateMarkers.AssistantToUserMarker)); if (_sessionConfig.ExecutorType != LLamaExecutorType.Stateless && !string.IsNullOrWhiteSpace(_systemPrompt)) _chatHistory.AddMessage(AuthorRole.System, _systemPrompt); @@ -136,10 +169,20 @@ internal IAsyncEnumerable InferAsync(string message, InferenceOptions in { var inferenceParams = ConfigureInferenceParams(inferenceConfig); _cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - - var inferenceStream = _executor.InferAsync(message, inferenceParams, _cancellationTokenSource.Token); + _logger?.LogInformation( + "Session {SessionId} starting inference. Prompt={Prompt}. DecodeSpecialTokens={DecodeSpecialTokens}. AntiPrompts=[{AntiPrompts}]. OutputFilters=[{OutputFilters}].", + _sessionId, + EscapeForLog(message, 320), + inferenceParams.DecodeSpecialTokens, + JoinForLog(inferenceParams.AntiPrompts), + JoinForLog(GetEffectiveOutputFilters())); + + var inferenceStream = LogStreamAsync( + _executor.InferAsync(message, inferenceParams, _cancellationTokenSource.Token), + "raw", + _cancellationTokenSource.Token); if (_outputTransform is not null) - return _outputTransform.TransformAsync(inferenceStream); + return LogStreamAsync(_outputTransform.TransformAsync(inferenceStream), "filtered", _cancellationTokenSource.Token); return inferenceStream; } @@ -191,41 +234,106 @@ internal void AddAssistantMessage(string assistantContent) private IInferenceParams ConfigureInferenceParams(InferenceOptions inferenceConfig) { var inferenceParams = inferenceConfig ?? _defaultInferenceConfig; - var antiPrompts = _sessionConfig.GetAntiPrompts(); - if (!string.IsNullOrWhiteSpace(_templateUserMarker)) - { - if (!antiPrompts.Contains(_templateUserMarker)) - antiPrompts.Add(_templateUserMarker); - - var trimmedMarker = _templateUserMarker.Trim(); - if (!string.IsNullOrWhiteSpace(trimmedMarker) && !antiPrompts.Contains(trimmedMarker)) - antiPrompts.Add(trimmedMarker); - } + var antiPrompts = GetEffectiveAntiPrompts(); + AddTemplateMarkers(antiPrompts); inferenceParams.AntiPrompts = antiPrompts; + inferenceParams.DecodeSpecialTokens = inferenceParams.DecodeSpecialTokens || ShouldDecodeSpecialTokens(antiPrompts); return inferenceParams; } private ITextStreamTransform CreateOutputFilter() { - var outputFilters = _sessionConfig.GetOutputFilters(); + var outputFilters = GetEffectiveOutputFilters(); if (outputFilters.Count > 0) return new LLamaTransforms.KeywordTextOutputStreamTransform(outputFilters); return null; } + private List GetEffectiveAntiPrompts() + { + var antiPrompts = new List(_configuredAntiPrompts); + if (HasTemplateMarkers() && ContainsOnlyLegacyDefaults(antiPrompts, LegacyDefaultAntiPrompts)) + antiPrompts.Clear(); + + return antiPrompts; + } + + private List GetEffectiveOutputFilters() + { + var outputFilters = new List(_configuredOutputFilters); + if (HasTemplateMarkers() && ContainsOnlyLegacyDefaults(outputFilters, LegacyDefaultOutputFilters)) + outputFilters.Clear(); + + AddTemplateMarkers(outputFilters); + return outputFilters; + } + + private bool HasTemplateMarkers() + { + return !string.IsNullOrWhiteSpace(_templateMarkers.AssistantEndMarker) + || !string.IsNullOrWhiteSpace(_templateMarkers.AssistantToUserMarker); + } + + private void AddTemplateMarkers(List values) + { + AddMarker(values, _templateMarkers.AssistantEndMarker); + AddMarker(values, _templateMarkers.AssistantToUserMarker); + } + + private static void AddMarker(List values, string? marker) + { + if (string.IsNullOrWhiteSpace(marker)) + return; + + if (!values.Contains(marker)) + values.Add(marker); + + var trimmedMarker = marker.Trim(); + if (!string.IsNullOrWhiteSpace(trimmedMarker) && !values.Contains(trimmedMarker)) + values.Add(trimmedMarker); + } + + private static bool ContainsOnlyLegacyDefaults(List values, HashSet legacyDefaults) + { + if (values.Count == 0) + return false; + + foreach (var value in values) + { + if (!legacyDefaults.Contains(value)) + return false; + } + + return true; + } + + private static bool ShouldDecodeSpecialTokens(IReadOnlyList antiPrompts) + { + foreach (var antiPrompt in antiPrompts) + { + if (string.IsNullOrWhiteSpace(antiPrompt)) + continue; + + if (antiPrompt.Contains('<', StringComparison.Ordinal) || antiPrompt.Contains('>', StringComparison.Ordinal)) + return true; + } + + return false; + } + private ILLamaExecutor CreateExecutor() { return _sessionConfig.ExecutorType switch { LLamaExecutorType.Interactive => _model.MtmdWeights is null - ? new InteractiveExecutor(_context) - : new InteractiveExecutor(_context, _model.MtmdWeights), + ? new InteractiveExecutor(_context, _logger) + : new InteractiveExecutor(_context, _model.MtmdWeights, _logger), LLamaExecutorType.Instruct => _model.MtmdWeights is null - ? new InstructExecutor(_context) - : new InstructExecutor(_context, _model.MtmdWeights), - LLamaExecutorType.Stateless => new StatelessExecutor(_model.LLamaWeights, _context.Params), + ? new InstructExecutor(_context, logger: _logger) + : new InstructExecutor(_context, _model.MtmdWeights, logger: _logger), + LLamaExecutorType.Stateless => new StatelessExecutor(_model.LLamaWeights, _context.Params, _logger), _ => default }; } @@ -272,7 +380,7 @@ private string FormatChatHistory(ChatHistory history, bool addAssistant) return LLamaTemplate.Encoding.GetString(template.Apply()); } - private string? ResolveTemplateUserMarker() + private TemplateMarkers ResolveTemplateMarkers() { const string userMarkerA = "__LLAMA_USER_A__"; const string assistantMarkerA = "__LLAMA_ASSISTANT_A__"; @@ -280,26 +388,92 @@ private string FormatChatHistory(ChatHistory history, bool addAssistant) try { - var template = new LLamaTemplate(_model.LLamaWeights.NativeHandle) + var assistantTemplate = new LLamaTemplate(_model.LLamaWeights.NativeHandle) { AddAssistant = false }; - template.Add("user", userMarkerA); - template.Add("assistant", assistantMarkerA); - template.Add("user", userMarkerB); - - var rendered = LLamaTemplate.Encoding.GetString(template.Apply()); - var assistantIndex = rendered.IndexOf(assistantMarkerA, StringComparison.Ordinal); - var userIndex = rendered.IndexOf(userMarkerB, StringComparison.Ordinal); - if (assistantIndex < 0 || userIndex <= assistantIndex) - return null; - - var between = rendered.Substring(assistantIndex + assistantMarkerA.Length, userIndex - (assistantIndex + assistantMarkerA.Length)); - return string.IsNullOrWhiteSpace(between) ? null : between; + assistantTemplate.Add("user", userMarkerA); + assistantTemplate.Add("assistant", assistantMarkerA); + + var assistantRendered = LLamaTemplate.Encoding.GetString(assistantTemplate.Apply()); + var assistantIndex = assistantRendered.IndexOf(assistantMarkerA, StringComparison.Ordinal); + if (assistantIndex < 0) + return new TemplateMarkers(null, null); + + var assistantEndMarker = assistantRendered[(assistantIndex + assistantMarkerA.Length)..]; + + var conversationTemplate = new LLamaTemplate(_model.LLamaWeights.NativeHandle) + { + AddAssistant = false + }; + conversationTemplate.Add("user", userMarkerA); + conversationTemplate.Add("assistant", assistantMarkerA); + conversationTemplate.Add("user", userMarkerB); + + var conversationRendered = LLamaTemplate.Encoding.GetString(conversationTemplate.Apply()); + var assistantConversationIndex = conversationRendered.IndexOf(assistantMarkerA, StringComparison.Ordinal); + var userIndex = conversationRendered.IndexOf(userMarkerB, StringComparison.Ordinal); + if (assistantConversationIndex < 0 || userIndex <= assistantConversationIndex) + return new TemplateMarkers(NormalizeMarker(assistantEndMarker), null); + + var assistantToUserMarker = conversationRendered.Substring( + assistantConversationIndex + assistantMarkerA.Length, + userIndex - (assistantConversationIndex + assistantMarkerA.Length)); + + return new TemplateMarkers( + NormalizeMarker(assistantEndMarker), + NormalizeMarker(assistantToUserMarker)); } catch { - return null; + _logger?.LogWarning("Session {SessionId} failed to resolve template markers for model {ModelName}.", _sessionId, _sessionConfig.Model); + return new TemplateMarkers(null, null); } } + + private static string? NormalizeMarker(string? marker) + { + return string.IsNullOrWhiteSpace(marker) ? null : marker; + } + + private async IAsyncEnumerable LogStreamAsync(IAsyncEnumerable stream, string stage, [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken) + { + var chunkCount = 0; + await foreach (var chunk in stream.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + chunkCount++; + _logger?.LogInformation( + "Session {SessionId} {Stage} chunk {ChunkCount}: {Chunk}", + _sessionId, + stage, + chunkCount, + EscapeForLog(chunk, 240)); + yield return chunk; + } + } + + private static string JoinForLog(IReadOnlyList? values) + { + if (values is null || values.Count == 0) + return string.Empty; + + return string.Join(", ", values.Select(value => EscapeForLog(value))); + } + + private static string EscapeForLog(string? value, int maxLength = 120) + { + if (string.IsNullOrEmpty(value)) + return ""; + + var escaped = value + .Replace("\\", "\\\\", StringComparison.Ordinal) + .Replace("\r", "\\r", StringComparison.Ordinal) + .Replace("\n", "\\n", StringComparison.Ordinal) + .Replace("\t", "\\t", StringComparison.Ordinal); + + if (escaped.Length <= maxLength) + return escaped; + + return escaped[..maxLength] + "..."; + } } diff --git a/LLama.Web/Pages/Chat.razor b/LLama.Web/Pages/Chat.razor index ec41f2a39..642814686 100644 --- a/LLama.Web/Pages/Chat.razor +++ b/LLama.Web/Pages/Chat.razor @@ -241,17 +241,21 @@ @(_isRecordingAudio ? "Stop recording" : "Record audio") } - +
- @if (_pendingFiles.Count > 0) + @if (_pendingFiles.Count > 0 || _recordedAudioAccepted) {
@foreach (var file in _pendingFiles) { @file.Name } + @if (_recordedAudioAccepted) + { + @_recordedAudioName + }
} @if (_hasRecordedAudio) @@ -271,7 +275,8 @@
} - + + @if (!string.IsNullOrWhiteSpace(_composerError)) {
@_composerError
@@ -279,7 +284,7 @@
- +
@@ -293,6 +298,8 @@ private readonly List _downloadSnapshots = new(); private ElementReference _threadRef; + private ElementReference _composerInputRef; + private ElementReference _sendButtonRef; private bool _initialized; private bool _sessionLoaded; private bool _isSending; @@ -318,9 +325,10 @@ private bool _hasRecordedAudio; private string _recordedAudioName = string.Empty; private string _recordedAudioPreviewUrl = string.Empty; - private MemoryBrowserFile _recordedAudioFile; + private bool _recordedAudioAccepted; private bool _canLoad => !_sessionLoaded && IsModelReady(_sessionConfig.Model); + private bool _canSend => _sessionLoaded && !_isSending && (!string.IsNullOrWhiteSpace(_inputText) || _pendingFiles.Count > 0 || _recordedAudioAccepted); protected override Task OnInitializedAsync() { @@ -329,9 +337,9 @@ { Model = _models.FirstOrDefault()?.Name, ExecutorType = LLamaExecutorType.Interactive, - Prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.", - AntiPrompt = "User:,<|im_end|>,<|eot_id|>,<|endoftext|>", - OutputFilter = "User:, Assistant:,<|im_end|>,<|eot_id|>,<|endoftext|>" + Prompt = string.Empty, + AntiPrompt = string.Empty, + OutputFilter = string.Empty }; _samplingSettings = SamplingSettings.CreateDefault(); @@ -366,6 +374,11 @@ } } + if (_initialized) + { + await JS.InvokeVoidAsync("llamaChat.bindComposerShortcuts", _composerInputRef, _sendButtonRef); + } + if (_initialized && _pendingRenders.Count > 0) { var toRender = _messages.Where(message => _pendingRenders.Contains(message.Id)).ToList(); @@ -428,7 +441,15 @@ if (!_sessionLoaded) return; - await ModelSessionService.CloseAsync(_sessionId); + try + { + await ModelSessionService.CloseAsync(_sessionId); + } + catch (Exception ex) + { + AddSystemMessage($"Session cleanup warning: {ex.Message}", MessageKind.Info); + } + await AttachmentService.CleanupAsync(_sessionId); _sessionLoaded = false; _statusText = "Model not loaded"; @@ -441,7 +462,7 @@ if (!_sessionLoaded || _isSending) return; - if (string.IsNullOrWhiteSpace(_inputText)) + if (string.IsNullOrWhiteSpace(_inputText) && _pendingFiles.Count == 0 && !_recordedAudioAccepted) return; if (HasImageAttachments() && !_modelCapabilities.SupportsVision) @@ -450,7 +471,7 @@ return; } - if (HasAudioAttachments() && !_modelCapabilities.SupportsAudio) + if ((HasAudioAttachments() || _recordedAudioAccepted) && !_modelCapabilities.SupportsAudio) { _composerError = "This model does not support audio. Remove audio attachments or choose an audio-capable model."; return; @@ -462,6 +483,12 @@ _composerError = string.Empty; var attachments = await UploadAttachmentsAsync(); + if (string.IsNullOrWhiteSpace(prompt) && attachments.Count == 0) + { + _composerError = "Add text or attach a file before sending."; + _isSending = false; + return; + } var userMessage = new ChatMessageView(MessageKind.User, prompt) { @@ -480,7 +507,7 @@ try { - _inferenceOptions = BuildInferenceOptions(); + _inferenceOptions = BuildInferenceOptions(HasAudioAttachments(attachments)); await foreach (var token in ModelSessionService.InferAsync(_sessionId, request, _inferenceOptions)) { if (token.TokenType == TokenType.Content) @@ -519,15 +546,31 @@ private async Task> UploadAttachmentsAsync() { - if (_pendingFiles.Count == 0) + if (_pendingFiles.Count == 0 && !_recordedAudioAccepted) return new List(); try { AddSystemMessage("Uploading attachments...", MessageKind.Info); - var result = await AttachmentService.SaveAsync(_sessionId, _pendingFiles, CancellationToken.None); - _pendingFiles.Clear(); - return result.Attachments.ToList(); + var attachments = new List(); + + if (_pendingFiles.Count > 0) + { + var result = await AttachmentService.SaveAsync(_sessionId, _pendingFiles, CancellationToken.None); + attachments.AddRange(result.Attachments); + _pendingFiles.Clear(); + } + + if (_recordedAudioAccepted) + { + var recordedResult = await JS.InvokeAsync("llamaChat.uploadRecordedAudio", _sessionId, _recordedAudioName); + if (recordedResult?.Attachments != null) + attachments.AddRange(recordedResult.Attachments); + + ClearRecordedAudio(); + } + + return attachments; } catch (Exception ex) { @@ -560,10 +603,10 @@ _composerError = string.Empty; } - private void ClearAttachments() + private async Task ClearAttachments() { _pendingFiles.Clear(); - ClearRecordedAudio(); + await DiscardRecordedAudio(); _composerError = string.Empty; } @@ -659,7 +702,7 @@ private async Task HandleKeyPress(KeyboardEventArgs args) { - if (args.Key == "Enter" && !args.ShiftKey) + if (args.Key == "Enter" && !args.ShiftKey && !args.CtrlKey && !args.AltKey && !args.MetaKey) { await SendAsync(); } @@ -695,45 +738,44 @@ var result = await JS.InvokeAsync("llamaChat.stopAudioRecording"); _isRecordingAudio = false; - if (result == null || string.IsNullOrWhiteSpace(result.Base64)) + if (result == null || string.IsNullOrWhiteSpace(result.PreviewUrl)) { AddSystemMessage("No audio captured.", MessageKind.Info); return; } - var bytes = Convert.FromBase64String(result.Base64); - var mimeType = string.IsNullOrWhiteSpace(result.MimeType) ? "audio/webm" : result.MimeType; + var mimeType = string.IsNullOrWhiteSpace(result.MimeType) ? "audio/wav" : result.MimeType; var extension = GetAudioExtension(mimeType); var fileName = $"recording-{DateTimeOffset.Now:yyyyMMdd-HHmmss}{extension}"; - _recordedAudioFile = new MemoryBrowserFile(fileName, mimeType, bytes, DateTimeOffset.Now); _recordedAudioName = fileName; - _recordedAudioPreviewUrl = $"data:{mimeType};base64,{result.Base64}"; + _recordedAudioPreviewUrl = result.PreviewUrl; _hasRecordedAudio = true; + _recordedAudioAccepted = false; _composerError = string.Empty; } private void AcceptRecordedAudio() { - if (_recordedAudioFile == null) + if (!_hasRecordedAudio) return; - _pendingFiles.Add(_recordedAudioFile); - ClearRecordedAudio(); + _recordedAudioAccepted = true; _composerError = string.Empty; } - private void DiscardRecordedAudio() + private async Task DiscardRecordedAudio() { + await JS.InvokeVoidAsync("llamaChat.discardRecordedAudio"); ClearRecordedAudio(); } private void ClearRecordedAudio() { - _recordedAudioFile = null; _recordedAudioName = string.Empty; _recordedAudioPreviewUrl = string.Empty; _hasRecordedAudio = false; + _recordedAudioAccepted = false; } private static string GetAudioExtension(string contentType) @@ -793,12 +835,12 @@ return $"{adjusted:0.0} {sizes[order]}"; } - private InferenceOptions BuildInferenceOptions() + private InferenceOptions BuildInferenceOptions(bool simplifyForAudio = false) { return new InferenceOptions { MaxTokens = _samplingSettings.MaxTokens, - DecodeSpecialTokens = true, + DecodeSpecialTokens = false, SamplingPipeline = new DefaultSamplingPipeline { Temperature = _samplingSettings.Temperature, @@ -816,6 +858,17 @@ }; } + private static bool HasAudioAttachments(IEnumerable attachments) + { + foreach (var attachment in attachments) + { + if (attachment.Kind == AttachmentKind.Audio) + return true; + } + + return false; + } + public async ValueTask DisposeAsync() { _downloadCts?.Cancel(); @@ -879,8 +932,8 @@ { return new SamplingSettings { - MaxTokens = 512, - Temperature = 0.7f, + MaxTokens = 1024, + Temperature = 0.1f, TopP = 0.9f, TopK = 40, RepeatPenalty = 1.0f, @@ -918,7 +971,7 @@ private sealed class AudioRecordingStopResult { - public string Base64 { get; set; } + public string PreviewUrl { get; set; } public string MimeType { get; set; } public long Size { get; set; } } diff --git a/LLama.Web/Pages/Shared/_Layout.cshtml b/LLama.Web/Pages/Shared/_Layout.cshtml index 80585255f..417f38661 100644 --- a/LLama.Web/Pages/Shared/_Layout.cshtml +++ b/LLama.Web/Pages/Shared/_Layout.cshtml @@ -19,7 +19,6 @@ - diff --git a/LLama.Web/Services/ModelService.cs b/LLama.Web/Services/ModelService.cs index f5351387a..55bf18370 100644 --- a/LLama.Web/Services/ModelService.cs +++ b/LLama.Web/Services/ModelService.cs @@ -165,7 +165,7 @@ public async Task CreateContext(string modelName, string contextNa public async Task RemoveContext(string modelName, string contextName) { if (!_modelInstances.TryGetValue(modelName, out var model)) - throw new Exception("Model not found"); + return false; using (await _contextLock.LockAsync()) { diff --git a/LLama.Web/Services/ModelSessionService.cs b/LLama.Web/Services/ModelSessionService.cs index a316c8299..ef748c68c 100644 --- a/LLama.Web/Services/ModelSessionService.cs +++ b/LLama.Web/Services/ModelSessionService.cs @@ -18,6 +18,7 @@ public class ModelSessionService : IModelSessionService private readonly AsyncGuard _sessionGuard; private readonly IModelService _modelService; private readonly IAttachmentService _attachmentService; + private readonly ILogger _logger; private readonly ConcurrentDictionary _modelSessions; /// @@ -25,10 +26,11 @@ public class ModelSessionService : IModelSessionService /// /// The model service. /// The attachment service. - public ModelSessionService(IModelService modelService, IAttachmentService attachmentService) + public ModelSessionService(IModelService modelService, IAttachmentService attachmentService, ILogger logger) { _modelService = modelService; _attachmentService = attachmentService; + _logger = logger; _sessionGuard = new AsyncGuard(); _modelSessions = new ConcurrentDictionary(); } @@ -74,7 +76,7 @@ public async Task CreateAsync(string sessionId, ISessionConfig ses var (model, context) = await _modelService.GetOrCreateModelAndContext(sessionConfig.Model, sessionId); // Create session. - var modelSession = new ModelSession(model, context, sessionId, sessionConfig, inferenceConfig); + var modelSession = new ModelSession(model, context, sessionId, sessionConfig, inferenceConfig, _logger); if (!_modelSessions.TryAdd(sessionId, modelSession)) throw new Exception($"Failed to create model session"); @@ -94,7 +96,15 @@ public async Task CloseAsync(string sessionId) if (_modelSessions.TryRemove(sessionId, out var modelSession)) { modelSession.CancelInfer(); - return await _modelService.RemoveContext(modelSession.ModelName, sessionId); + try + { + return await _modelService.RemoveContext(modelSession.ModelName, sessionId); + } + catch (Exception ex) + { + _logger.LogWarning(ex, "Failed to remove context '{ContextName}' from model '{ModelName}' during session close.", sessionId, modelSession.ModelName); + return false; + } } return false; } @@ -247,6 +257,7 @@ private async Task PreparePromptAsync(string sessionId, ModelSes return new PreparedPrompt(string.Empty); var promptBuilder = new StringBuilder(request.Prompt ?? string.Empty); + var mediaPrefixBuilder = new StringBuilder(); var embeds = new List(); try @@ -271,14 +282,14 @@ private async Task PreparePromptAsync(string sessionId, ModelSes throw new Exception("This model does not support multimodal inputs."); embeds.Add(await LoadEmbedAsync(modelSession, attachment, cancellationToken)); - AppendMedia(promptBuilder, attachment, "Image", mediaMarker); + AppendMedia(modelSession, promptBuilder, mediaPrefixBuilder, mediaMarker); break; case AttachmentKind.Audio: if (!modelSession.IsMultiModal || !modelSession.SupportsAudio) throw new Exception("This model does not support audio inputs."); embeds.Add(await LoadEmbedAsync(modelSession, attachment, cancellationToken)); - AppendMedia(promptBuilder, attachment, "Audio", mediaMarker); + AppendMedia(modelSession, promptBuilder, mediaPrefixBuilder, mediaMarker); break; } } @@ -287,6 +298,9 @@ private async Task PreparePromptAsync(string sessionId, ModelSes if (embeds.Count > 0) modelSession.QueueEmbeds(embeds); + if (mediaPrefixBuilder.Length > 0) + promptBuilder.Insert(0, mediaPrefixBuilder.ToString()); + return new PreparedPrompt(promptBuilder.ToString()); } catch @@ -310,11 +324,23 @@ private static void AppendPdf(StringBuilder builder, AttachmentInfo attachment) builder.AppendLine("[PDF text truncated]"); } - private static void AppendMedia(StringBuilder builder, AttachmentInfo attachment, string mediaLabel, string mediaMarker) + private static void AppendMedia(ModelSession modelSession, StringBuilder builder, StringBuilder mediaPrefixBuilder, string mediaMarker) { - builder.AppendLine(); - builder.AppendLine($"[{mediaLabel}: {attachment.FileName}]"); - builder.AppendLine(mediaMarker); + if (ShouldPrefixMediaMarker(modelSession)) + { + mediaPrefixBuilder.Append(mediaMarker); + return; + } + + if (builder.Length > 0 && !char.IsWhiteSpace(builder[^1])) + builder.Append(' '); + + builder.Append(mediaMarker); + } + + private static bool ShouldPrefixMediaMarker(ModelSession modelSession) + { + return modelSession.ModelName.Contains("Qwen2.5-Omni", StringComparison.OrdinalIgnoreCase); } private static string GetMediaMarker() @@ -344,12 +370,23 @@ private static async Task LoadEmbedAsync(ModelSession modelSessio if (!File.Exists(attachment.FilePath)) throw new FileNotFoundException("Attachment not found.", attachment.FilePath); - var data = await File.ReadAllBytesAsync(attachment.FilePath, cancellationToken); var clipModel = modelSession.ClipModel; if (clipModel is null) throw new Exception("Multimodal model is not available for this session."); - var embed = clipModel.LoadMedia(data); + cancellationToken.ThrowIfCancellationRequested(); + + SafeMtmdEmbed? embed; + if (attachment.Kind == AttachmentKind.Audio) + { + var data = await File.ReadAllBytesAsync(attachment.FilePath, cancellationToken); + embed = clipModel.LoadMediaStandalone(data); + } + else + { + embed = clipModel.LoadMediaStandalone(attachment.FilePath); + } + if (embed == null) throw new Exception("Failed to prepare multimodal embedding."); diff --git a/LLama.Web/appsettings.json b/LLama.Web/appsettings.json index 41465acff..df6d5fd43 100644 --- a/LLama.Web/appsettings.json +++ b/LLama.Web/appsettings.json @@ -20,24 +20,7 @@ "BatchSize": 512, "UBatchSize": 512, "Threads": 4, - "GpuLayerCount": 32, - "UseMemorymap": true, - "UseMemoryLock": false, - "MainGpu": 0, - "LowVram": false, - "Seed": 1686349486, - "UseFp16Memory": false, - "Perplexity": false, - "LoraAdapter": "", - "LoraBase": "", - "EmbeddingMode": false, - "TensorSplits": null, - "GroupedQueryAttention": 1, - "RmsNormEpsilon": 0.000005, - "RopeFrequencyBase": 10000.0, - "RopeFrequencyScale": 1.0, - "MulMatQ": false, - "Encoding": "UTF-8" + "GpuLayerCount": 32 }, { "Name": "Gemma-3-4B-IT Q4_K_M", @@ -50,24 +33,7 @@ "BatchSize": 512, "UBatchSize": 512, "Threads": 4, - "GpuLayerCount": 32, - "UseMemorymap": true, - "UseMemoryLock": false, - "MainGpu": 0, - "LowVram": false, - "Seed": 1686349486, - "UseFp16Memory": false, - "Perplexity": false, - "LoraAdapter": "", - "LoraBase": "", - "EmbeddingMode": false, - "TensorSplits": null, - "GroupedQueryAttention": 1, - "RmsNormEpsilon": 0.000005, - "RopeFrequencyBase": 10000.0, - "RopeFrequencyScale": 1.0, - "MulMatQ": false, - "Encoding": "UTF-8" + "GpuLayerCount": 32 }, { "Name": "Qwen2.5-Omni-3B Q4_K_M (Audio)", @@ -80,24 +46,20 @@ "BatchSize": 512, "UBatchSize": 512, "Threads": 4, - "GpuLayerCount": 32, - "UseMemorymap": true, - "UseMemoryLock": false, - "MainGpu": 0, - "LowVram": false, - "Seed": 1686349486, - "UseFp16Memory": false, - "Perplexity": false, - "LoraAdapter": "", - "LoraBase": "", - "EmbeddingMode": false, - "TensorSplits": null, - "GroupedQueryAttention": 1, - "RmsNormEpsilon": 0.000005, - "RopeFrequencyBase": 10000.0, - "RopeFrequencyScale": 1.0, - "MulMatQ": false, - "Encoding": "UTF-8" + "GpuLayerCount": 32 + }, + { + "Name": "Voxtral-Mini-3B-2507 Q4_K_M (Audio)", + "MaxInstances": 20, + "ModelPath": "Models/Voxtral-Mini-3B-2507-Q4_K_M.gguf", + "ModelDownloadUrl": "https://huggingface.co/ggml-org/Voxtral-Mini-3B-2507-GGUF/resolve/main/Voxtral-Mini-3B-2507-Q4_K_M.gguf", + "MmprojPath": "Models/mmproj-Voxtral-Mini-3B-2507-Q8_0.gguf", + "MmprojDownloadUrl": "https://huggingface.co/ggml-org/Voxtral-Mini-3B-2507-GGUF/resolve/main/mmproj-Voxtral-Mini-3B-2507-Q8_0.gguf", + "ContextSize": 32768, + "BatchSize": 512, + "UBatchSize": 512, + "Threads": 4, + "GpuLayerCount": 32 } ] } diff --git a/LLama.Web/wwwroot/js/blazorChat.js b/LLama.Web/wwwroot/js/blazorChat.js index 5099e04b0..2023637c1 100644 --- a/LLama.Web/wwwroot/js/blazorChat.js +++ b/LLama.Web/wwwroot/js/blazorChat.js @@ -3,12 +3,46 @@ window.llamaChat = (() => { let audioRecorder = null; let audioChunks = []; let audioStream = null; + let audioContext = null; + let audioSourceNode = null; + let audioProcessorNode = null; + let audioSilenceNode = null; + let audioSampleRate = 0; + let audioSampleChunks = []; + let recordedAudioBlob = null; + let recordedAudioUploadBlob = null; + let recordedAudioPreviewUrl = null; + let markdownWarningLogged = false; + const composerBindings = new WeakMap(); + + const escapeHtml = (text) => { + return (text || "") + .replaceAll("&", "&") + .replaceAll("<", "<") + .replaceAll(">", ">") + .replaceAll("\"", """) + .replaceAll("'", "'"); + }; + + const createPlainTextRenderer = () => ({ + render: (text) => escapeHtml(text).replaceAll("\n", "
") + }); const ensureMarkdown = () => { if (markdown) { return markdown; } + if (typeof window.markdownit !== "function") { + if (!markdownWarningLogged) { + console.warn("markdown-it is not available; falling back to plain text rendering."); + markdownWarningLogged = true; + } + + markdown = createPlainTextRenderer(); + return markdown; + } + markdown = window.markdownit({ html: false, linkify: true, @@ -69,7 +103,7 @@ window.llamaChat = (() => { }; const isAudioRecordingSupported = () => { - return !!(navigator.mediaDevices && navigator.mediaDevices.getUserMedia && window.MediaRecorder); + return !!(navigator.mediaDevices && navigator.mediaDevices.getUserMedia && (window.AudioContext || window.webkitAudioContext || window.MediaRecorder)); }; const getSupportedAudioMimeType = () => { @@ -78,17 +112,26 @@ window.llamaChat = (() => { } const candidates = [ + "audio/mp4", "audio/webm;codecs=opus", "audio/webm", "audio/ogg;codecs=opus", "audio/ogg", - "audio/mp4", "audio/mpeg" ]; return candidates.find((candidate) => MediaRecorder.isTypeSupported(candidate)) || ""; }; + const canPlayAudioType = (mimeType) => { + if (!mimeType) { + return false; + } + + const audio = document.createElement("audio"); + return !!audio.canPlayType && audio.canPlayType(mimeType) !== ""; + }; + const cleanupAudioStream = () => { if (audioStream) { audioStream.getTracks().forEach((track) => track.stop()); @@ -96,14 +139,186 @@ window.llamaChat = (() => { } }; - const bufferToBase64 = (buffer) => { - const bytes = new Uint8Array(buffer); - let binary = ""; - const chunkSize = 0x8000; - for (let i = 0; i < bytes.length; i += chunkSize) { - binary += String.fromCharCode.apply(null, bytes.subarray(i, i + chunkSize)); + const cleanupAudioGraph = async () => { + if (audioProcessorNode) { + audioProcessorNode.disconnect(); + audioProcessorNode.onaudioprocess = null; + audioProcessorNode = null; + } + + if (audioSourceNode) { + audioSourceNode.disconnect(); + audioSourceNode = null; + } + + if (audioSilenceNode) { + audioSilenceNode.disconnect(); + audioSilenceNode = null; + } + + if (audioContext) { + try { + await audioContext.close(); + } catch (err) { + console.warn("Failed to close audio context.", err); + } + audioContext = null; + } + + audioSampleRate = 0; + audioSampleChunks = []; + }; + + const clearRecordedAudio = () => { + recordedAudioBlob = null; + recordedAudioUploadBlob = null; + if (recordedAudioPreviewUrl) { + URL.revokeObjectURL(recordedAudioPreviewUrl); + recordedAudioPreviewUrl = null; + } + }; + + const interleaveAudioBuffer = (audioBuffer) => { + const channelCount = audioBuffer.numberOfChannels; + const sampleCount = audioBuffer.length; + const interleaved = new Float32Array(sampleCount * channelCount); + + for (let sampleIndex = 0; sampleIndex < sampleCount; sampleIndex++) { + for (let channelIndex = 0; channelIndex < channelCount; channelIndex++) { + interleaved[sampleIndex * channelCount + channelIndex] = audioBuffer.getChannelData(channelIndex)[sampleIndex]; + } + } + + return interleaved; + }; + + const encodeWav = (audioBuffer) => { + const channelCount = audioBuffer.numberOfChannels; + const sampleRate = audioBuffer.sampleRate; + const samples = interleaveAudioBuffer(audioBuffer); + const bytesPerSample = 2; + const blockAlign = channelCount * bytesPerSample; + const buffer = new ArrayBuffer(44 + samples.length * bytesPerSample); + const view = new DataView(buffer); + + const writeString = (offset, value) => { + for (let i = 0; i < value.length; i++) { + view.setUint8(offset + i, value.charCodeAt(i)); + } + }; + + writeString(0, "RIFF"); + view.setUint32(4, 36 + samples.length * bytesPerSample, true); + writeString(8, "WAVE"); + writeString(12, "fmt "); + view.setUint32(16, 16, true); + view.setUint16(20, 1, true); + view.setUint16(22, channelCount, true); + view.setUint32(24, sampleRate, true); + view.setUint32(28, sampleRate * blockAlign, true); + view.setUint16(32, blockAlign, true); + view.setUint16(34, 16, true); + writeString(36, "data"); + view.setUint32(40, samples.length * bytesPerSample, true); + + let offset = 44; + for (let i = 0; i < samples.length; i++, offset += 2) { + const sample = Math.max(-1, Math.min(1, samples[i])); + view.setInt16(offset, sample < 0 ? sample * 0x8000 : sample * 0x7fff, true); + } + + return buffer; + }; + + const encodeMonoWav = (sampleRate, samples) => { + const bytesPerSample = 2; + const blockAlign = bytesPerSample; + const buffer = new ArrayBuffer(44 + samples.length * bytesPerSample); + const view = new DataView(buffer); + + const writeString = (offset, value) => { + for (let i = 0; i < value.length; i++) { + view.setUint8(offset + i, value.charCodeAt(i)); + } + }; + + writeString(0, "RIFF"); + view.setUint32(4, 36 + samples.length * bytesPerSample, true); + writeString(8, "WAVE"); + writeString(12, "fmt "); + view.setUint32(16, 16, true); + view.setUint16(20, 1, true); + view.setUint16(22, 1, true); + view.setUint32(24, sampleRate, true); + view.setUint32(28, sampleRate * blockAlign, true); + view.setUint16(32, blockAlign, true); + view.setUint16(34, 16, true); + writeString(36, "data"); + view.setUint32(40, samples.length * bytesPerSample, true); + + let offset = 44; + for (let i = 0; i < samples.length; i++, offset += 2) { + const sample = Math.max(-1, Math.min(1, samples[i])); + view.setInt16(offset, sample < 0 ? sample * 0x8000 : sample * 0x7fff, true); + } + + return buffer; + }; + + const mergeSampleChunks = (chunks) => { + const totalLength = chunks.reduce((sum, chunk) => sum + chunk.length, 0); + const merged = new Float32Array(totalLength); + let offset = 0; + + for (const chunk of chunks) { + merged.set(chunk, offset); + offset += chunk.length; + } + + return merged; + }; + + const mixInputToMono = (inputBuffer) => { + const channelCount = inputBuffer.numberOfChannels; + const sampleCount = inputBuffer.length; + const mono = new Float32Array(sampleCount); + + for (let channelIndex = 0; channelIndex < channelCount; channelIndex++) { + const channel = inputBuffer.getChannelData(channelIndex); + for (let sampleIndex = 0; sampleIndex < sampleCount; sampleIndex++) { + mono[sampleIndex] += channel[sampleIndex]; + } + } + + if (channelCount > 1) { + for (let sampleIndex = 0; sampleIndex < sampleCount; sampleIndex++) { + mono[sampleIndex] /= channelCount; + } + } + + return mono; + }; + + const convertBlobToWav = async (blob) => { + const AudioContextCtor = window.AudioContext || window.webkitAudioContext; + if (!AudioContextCtor) { + return blob; + } + + const audioContext = new AudioContextCtor(); + try { + const sourceBuffer = await blob.arrayBuffer(); + const decodedBuffer = await audioContext.decodeAudioData(sourceBuffer.slice(0)); + const wavBuffer = encodeWav(decodedBuffer); + return new Blob([wavBuffer], { type: "audio/wav" }); + } catch (err) { + console.warn("Failed to normalize recorded audio to WAV.", err); + return blob; + } finally { + if (typeof audioContext.close === "function") { + await audioContext.close(); + } } - return btoa(binary); }; return { @@ -134,6 +349,38 @@ window.llamaChat = (() => { } window.localStorage.setItem("llama.sidebar.tab", value); }, + bindComposerShortcuts: (textarea, sendButton) => { + if (!textarea || !sendButton) { + return; + } + + const existing = composerBindings.get(textarea); + if (existing === sendButton) { + return; + } + + if (!composerBindings.has(textarea)) { + textarea.addEventListener("keydown", (event) => { + if (event.key !== "Enter" || event.isComposing) { + return; + } + + if (event.shiftKey || event.ctrlKey || event.altKey || event.metaKey) { + return; + } + + event.preventDefault(); + const targetButton = composerBindings.get(textarea); + if (!targetButton || targetButton.disabled) { + return; + } + + targetButton.click(); + }); + } + + composerBindings.set(textarea, sendButton); + }, isAudioRecordingSupported: () => { return isAudioRecordingSupported(); }, @@ -142,12 +389,43 @@ window.llamaChat = (() => { return { started: false, error: "Audio recording is not supported in this browser." }; } - if (audioRecorder && audioRecorder.state === "recording") { + if ((audioRecorder && audioRecorder.state === "recording") || audioProcessorNode) { return { started: false, error: "Audio recording is already in progress." }; } try { - audioStream = await navigator.mediaDevices.getUserMedia({ audio: true }); + audioStream = await navigator.mediaDevices.getUserMedia({ + audio: { + channelCount: { ideal: 1 }, + echoCancellation: true, + noiseSuppression: true, + autoGainControl: true + } + }); + const AudioContextCtor = window.AudioContext || window.webkitAudioContext; + if (AudioContextCtor) { + audioContext = new AudioContextCtor(); + if (audioContext.state === "suspended" && typeof audioContext.resume === "function") { + await audioContext.resume(); + } + audioSampleRate = audioContext.sampleRate; + audioSampleChunks = []; + audioSourceNode = audioContext.createMediaStreamSource(audioStream); + audioProcessorNode = audioContext.createScriptProcessor(4096, 1, 1); + audioSilenceNode = audioContext.createGain(); + audioSilenceNode.gain.value = 0; + + audioProcessorNode.onaudioprocess = (event) => { + audioSampleChunks.push(mixInputToMono(event.inputBuffer)); + }; + + audioSourceNode.connect(audioProcessorNode); + audioProcessorNode.connect(audioSilenceNode); + audioSilenceNode.connect(audioContext.destination); + + return { started: true, mimeType: "audio/wav" }; + } + const mimeType = getSupportedAudioMimeType(); const options = mimeType ? { mimeType } : undefined; audioRecorder = new MediaRecorder(audioStream, options); @@ -159,7 +437,7 @@ window.llamaChat = (() => { } }; - audioRecorder.start(); + audioRecorder.start(250); return { started: true, mimeType: audioRecorder.mimeType || mimeType || "" }; } catch (err) { cleanupAudioStream(); @@ -167,27 +445,100 @@ window.llamaChat = (() => { } }, stopAudioRecording: () => { + if (audioProcessorNode) { + return (async () => { + const mergedSamples = mergeSampleChunks(audioSampleChunks); + if (mergedSamples.length === 0) { + await cleanupAudioGraph(); + cleanupAudioStream(); + return { previewUrl: "", mimeType: "", size: 0 }; + } + + const wavBuffer = encodeMonoWav(audioSampleRate || 16000, mergedSamples); + const wavBlob = new Blob([wavBuffer], { type: "audio/wav" }); + + clearRecordedAudio(); + recordedAudioBlob = wavBlob; + recordedAudioUploadBlob = wavBlob; + recordedAudioPreviewUrl = URL.createObjectURL(wavBlob); + + await cleanupAudioGraph(); + cleanupAudioStream(); + + return { previewUrl: recordedAudioPreviewUrl, mimeType: "audio/wav", size: wavBlob.size }; + })(); + } + if (!audioRecorder || audioRecorder.state !== "recording") { + cleanupAudioGraph(); cleanupAudioStream(); audioRecorder = null; audioChunks = []; - return Promise.resolve({ base64: "", mimeType: "", size: 0 }); + return Promise.resolve({ previewUrl: "", mimeType: "", size: 0 }); } return new Promise((resolve) => { audioRecorder.onstop = async () => { const blobType = audioRecorder.mimeType || "audio/webm"; - const blob = new Blob(audioChunks, { type: blobType }); - const buffer = await blob.arrayBuffer(); - const base64 = bufferToBase64(buffer); + const rawBlob = new Blob(audioChunks, { type: blobType }); + const uploadBlob = await convertBlobToWav(rawBlob); + const previewBlob = canPlayAudioType(rawBlob.type) ? rawBlob : uploadBlob; + + if (!previewBlob || previewBlob.size === 0) { + clearRecordedAudio(); + cleanupAudioStream(); + audioRecorder = null; + audioChunks = []; + resolve({ previewUrl: "", mimeType: "", size: 0 }); + return; + } + + clearRecordedAudio(); + recordedAudioBlob = rawBlob; + recordedAudioUploadBlob = uploadBlob; + recordedAudioPreviewUrl = URL.createObjectURL(previewBlob); cleanupAudioStream(); audioRecorder = null; audioChunks = []; - resolve({ base64, mimeType: blob.type, size: blob.size }); + resolve({ previewUrl: recordedAudioPreviewUrl, mimeType: previewBlob.type, size: previewBlob.size }); }; + try { + audioRecorder.requestData(); + } catch (err) { + console.warn("Failed to flush recorded audio before stopping.", err); + } + audioRecorder.stop(); }); + }, + discardRecordedAudio: () => { + clearRecordedAudio(); + }, + uploadRecordedAudio: async (connectionId, fileName) => { + const uploadBlob = recordedAudioUploadBlob || recordedAudioBlob; + if (!uploadBlob) { + return { attachments: [] }; + } + + const uploadFileName = fileName || "recording.wav"; + const formData = new FormData(); + formData.append("connectionId", connectionId); + formData.append("files", uploadBlob, uploadFileName); + + const response = await fetch("/api/attachments", { + method: "POST", + body: formData + }); + + if (!response.ok) { + const errorText = await response.text(); + throw new Error(errorText || "Unable to upload recorded audio."); + } + + const result = await response.json(); + clearRecordedAudio(); + return result; } }; })(); diff --git a/LLama/AntipromptProcessor.cs b/LLama/AntipromptProcessor.cs index e4ec0f188..eeecb97e5 100644 --- a/LLama/AntipromptProcessor.cs +++ b/LLama/AntipromptProcessor.cs @@ -68,7 +68,7 @@ public bool Add(string text) _buffer = _buffer.Substring(_buffer.Length - trimLength); foreach (var antiprompt in _antiprompts) - if (_buffer.EndsWith(antiprompt, StringComparison.CurrentCulture)) + if (_buffer.EndsWith(antiprompt, StringComparison.Ordinal)) return true; return false; diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs index bda7472d5..2903637aa 100644 --- a/LLama/ChatSession.cs +++ b/LLama/ChatSession.cs @@ -143,6 +143,17 @@ public ChatSession AddInputTransform(ITextTransform transform) return this; } + private StatefulExecutorBase GetStatefulExecutor() + { + return (StatefulExecutorBase)Executor; + } + + private void EnsureSessionPersistenceSupported() + { + if (GetStatefulExecutor().IsMultiModal) + throw new NotSupportedException("Session persistence is not supported for multimodal chat sessions."); + } + /// /// Use a custom output transform. /// @@ -162,6 +173,7 @@ public ChatSession WithOutputTransform(ITextStreamTransform transform) /// public void SaveSession(string path) { + EnsureSessionPersistenceSupported(); GetSessionState().Save(path); } @@ -171,7 +183,9 @@ public void SaveSession(string path) /// SessionState object representing session state in-memory public SessionState GetSessionState() { - var executorState = ((StatefulExecutorBase)Executor).GetStateData(); + EnsureSessionPersistenceSupported(); + + var executorState = GetStatefulExecutor().GetStateData(); return new SessionState( executorState.PastTokensCount > 0 ? Executor.Context.GetState() : null, @@ -191,6 +205,8 @@ public SessionState GetSessionState() /// public void LoadSession(SessionState state, bool loadTransforms = true) { + EnsureSessionPersistenceSupported(); + if (Executor is StatefulExecutorBase statefulExecutor) { if (state.ExecutorState is not null) @@ -224,12 +240,14 @@ public void LoadSession(SessionState state, bool loadTransforms = true) /// public void LoadSession(string path, bool loadTransforms = true) { + EnsureSessionPersistenceSupported(); + var state = SessionState.Load(path); // Handle non-polymorphic serialization of executor state if (state.ExecutorState is null) { var executorPath = Path.Combine(path, EXECUTOR_STATE_FILENAME); - ((StatefulExecutorBase)Executor).LoadState(filename: executorPath); + GetStatefulExecutor().LoadState(filename: executorPath); } LoadSession(state, loadTransforms); } diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index 54499e5e1..98ee304ad 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -86,12 +86,8 @@ public bool IsMultiModal /// public List Embeds { get; } - /// - /// Pending multimodal chunks produced by the MTMD tokenizer. - /// - protected SafeMtmdInputChunks? MtmdChunks { get; set; } - private string? _mtmdMarker; + private readonly Queue _mtmdPromptSegments = new(); private readonly StreamingTokenDecoder _decoder; @@ -134,6 +130,9 @@ public StatefulExecutorBase(LLamaContext context, MtmdWeights mtmdWeights, ILogg /// public StatefulExecutorBase WithSessionFile(string filename) { + if (IsMultiModal) + throw new NotSupportedException("Session files are not supported for multimodal executors."); + _pathSession = filename; if (string.IsNullOrEmpty(filename)) { @@ -191,6 +190,9 @@ public StatefulExecutorBase WithSessionFile(string filename) /// Destination path for the llama.cpp session file. public void SaveSessionFile(string filename) { + if (IsMultiModal) + throw new NotSupportedException("Session files are not supported for multimodal executors."); + var session_token_array = _session_tokens.ToArray(); NativeApi.llama_state_save_file(Context.NativeHandle, filename, session_token_array, (ulong)session_token_array.Length); } @@ -201,6 +203,9 @@ public void SaveSessionFile(string filename) /// protected virtual void HandleRunOutOfContext(int tokensToKeep) { + if (IsMultiModal) + throw new RuntimeError("Context shifting is not supported for multimodal executors. Reduce the prompt or increase the context size."); + // if we run out of context: // - take the tokensToKeep first tokens from the original prompt (via n_past) // - take half of the last (n_ctx - tokensToKeep) tokens and recompute the logits in batches @@ -215,11 +220,40 @@ protected virtual void HandleRunOutOfContext(int tokensToKeep) _pathSession = string.Empty; } + /// + /// Get the number of context positions required by the pending input. + /// + protected int GetPendingInputPositionCount() + { + if (IsMultiModal) + { + var pending = _embeds.Count; + foreach (var segment in _mtmdPromptSegments) + pending += segment.RemainingPositionCount; + + return pending; + } + + return _embeds.Count; + } + + /// + /// Shift the context window when the pending input would overflow the available KV cache. + /// + protected void EnsurePendingInputFitsContext(int tokensToKeep) + { + if (_pastTokensCount + GetPendingInputPositionCount() > Context.ContextSize) + HandleRunOutOfContext(tokensToKeep); + } + /// /// Try to reuse the matching prompt prefix from the loaded session cache before evaluating new tokens. /// protected virtual void TryReuseMatchingPrefix() { + if (IsMultiModal) + return; + if (_n_session_consumed < _session_tokens.Count) { int i = 0; @@ -250,12 +284,12 @@ protected virtual void TryReuseMatchingPrefix() } /// - /// Dispose and clear any queued multimodal chunk collection. + /// Dispose and clear any queued multimodal prompt segments. /// - protected void DisposeMtmdChunks() + protected void DisposeMtmdPromptSegments() { - MtmdChunks?.Dispose(); - MtmdChunks = null; + while (_mtmdPromptSegments.Count > 0) + _mtmdPromptSegments.Dequeue().Dispose(); } /// @@ -284,20 +318,6 @@ protected string GetMtmdMarker() return _mtmdMarker; } - /// - /// Ensure the token list fills all positional slots reported by the MTMD helper. - /// - protected static List BuildTokensWithFiller(List tokens, int totalPositions, LLamaToken fillerToken) - { - if (totalPositions <= tokens.Count) - return new List(tokens); - - var result = new List(totalPositions); - result.AddRange(tokens); - result.AddRange(Enumerable.Repeat(fillerToken, totalPositions - tokens.Count)); - return result; - } - /// /// Resolve the fallback token inserted when the tokenizer emits fewer tokens than positions. /// @@ -322,8 +342,6 @@ protected Task PreprocessMtmd(string text, InferStateArgs args, bool addBos, boo if (ClipModel is null) throw new InvalidOperationException("Multimodal execution requires a loaded mtmd clip model."); - DisposeMtmdChunks(); - var marker = GetMtmdMarker(); var prompt = text; @@ -342,32 +360,65 @@ protected Task PreprocessMtmd(string text, InferStateArgs args, bool addBos, boo SafeMtmdInputChunks? chunks = null; try { - var status = ClipModel.Tokenize(prompt, addBos, parseSpecial: true, out chunks); + var embeds = Embeds.ToArray(); + var status = Embeds.Count > 0 + ? ClipModel.Tokenize(prompt, addBos, parseSpecial: true, embeds, out chunks) + : ClipModel.Tokenize(prompt, addBos, parseSpecial: true, out chunks); if (status != 0 || chunks is null) { ClipModel.ClearMedia(); throw new RuntimeError($"Failed to tokenize multimodal prompt. Status: {status}."); } - MtmdChunks = chunks; + var promptTextTokenCount = LoadMtmdPromptSegments(chunks, GetFillerToken(marker), replaceExisting); + args.RemainedTokens -= promptTextTokenCount; + } + catch + { + chunks?.Dispose(); + throw; + } + finally + { + DisposeEmbeds(); + } - var tokens = new List(); + return Task.CompletedTask; + } + + /// + /// Build ordered MTMD prompt segments and logical placeholder tokens from the tokenized chunks. + /// + protected int LoadMtmdPromptSegments(SafeMtmdInputChunks chunks, LLamaToken fillerToken, bool replaceExisting) + { + var promptTokens = new List(); + var promptSegments = new List(); + var promptTextTokenCount = 0; + + try + { foreach (var chunk in chunks.Enumerate()) { using var scopedChunk = chunk; - if (scopedChunk.Type != SafeMtmdInputChunk.SafeMtmdInputChunkType.Text) - continue; - - foreach (var token in scopedChunk.GetTextTokensSpan()) - tokens.Add(token); + if (scopedChunk.Type == SafeMtmdInputChunk.SafeMtmdInputChunkType.Text) + { + var textTokens = scopedChunk.GetTextTokensSpan().ToArray().Select(static token => (LLamaToken)token).ToArray(); + promptTokens.AddRange(textTokens); + promptTextTokenCount += textTokens.Length; + promptSegments.Add(MtmdPromptSegment.CreateText(textTokens)); + } + else + { + var copy = scopedChunk.Copy() ?? throw new RuntimeError("Failed to copy multimodal prompt chunk."); + promptTokens.AddRange(Enumerable.Repeat(fillerToken, checked((int)scopedChunk.NTokens))); + promptSegments.Add(MtmdPromptSegment.CreateMedia(copy)); + } } - var totalPositions = (int)chunks.CountPositions(); - var fillerToken = GetFillerToken(marker); - if (replaceExisting) { - _embed_inps = BuildTokensWithFiller(tokens, totalPositions, fillerToken); + DisposeMtmdPromptSegments(); + _embed_inps = promptTokens; _consumedTokensCount = 0; } else @@ -375,66 +426,105 @@ protected Task PreprocessMtmd(string text, InferStateArgs args, bool addBos, boo if (_embed_inps.Count == 0) _embed_inps = new List(); - _embed_inps.AddRange(tokens); - var fillerCount = totalPositions - tokens.Count; - if (fillerCount > 0) - _embed_inps.AddRange(Enumerable.Repeat(fillerToken, fillerCount)); - - args.RemainedTokens -= tokens.Count; + _embed_inps.AddRange(promptTokens); } + + foreach (var segment in promptSegments) + _mtmdPromptSegments.Enqueue(segment); + + return promptTextTokenCount; } catch { - chunks?.Dispose(); - MtmdChunks = null; + foreach (var segment in promptSegments) + segment.Dispose(); + throw; } - finally - { - DisposeEmbeds(); - } - - return Task.CompletedTask; } /// - /// Apply bookkeeping after successfully evaluating multimodal chunks. + /// Queue the next multimodal text segment into if available. /// - protected void FinalizeMtmdEvaluation(int newNPast, int previousConsumed) + protected void QueueNextMtmdPromptInput() { - _pastTokensCount = newNPast; - DisposeMtmdChunks(); - - if (!string.IsNullOrEmpty(_pathSession) && _embed_inps.Count > previousConsumed) + while (_mtmdPromptSegments.Count > 0 && _embeds.Count < Context.BatchSize) { - _session_tokens.AddRange(_embed_inps.Skip(previousConsumed)); - _n_session_consumed = _session_tokens.Count; - } + var segment = _mtmdPromptSegments.Peek(); + if (segment.Kind != MtmdPromptSegmentKind.Text) + break; - _consumedTokensCount = _embed_inps.Count; - _embeds.Clear(); + var count = Math.Min(segment.RemainingTextTokenCount, checked((int)Context.BatchSize) - _embeds.Count); + var span = segment.TextTokens.AsSpan(segment.TextOffset, count); + foreach (var token in span) + { + _embeds.Add(token); + _last_n_tokens.Enqueue(token); + _consumedTokensCount++; + } + + segment.AdvanceText(count); + if (segment.RemainingTextTokenCount == 0) + _mtmdPromptSegments.Dequeue().Dispose(); + } } /// - /// Evaluate the queued MTMD chunks and update executor state. + /// Evaluate the current multimodal media segment and advance the prompt cursor. /// - protected void EvaluateMtmdChunks(ref int nPast, int previousConsumed, string executorName) + protected void EvaluateNextMtmdMediaSegment(string executorName) { if (ClipModel is null) throw new InvalidOperationException("Multimodal execution requires a loaded mtmd clip model."); - if (MtmdChunks is null) - throw new InvalidOperationException("No MTMD chunks are queued for evaluation."); + if (_mtmdPromptSegments.Count == 0) + throw new InvalidOperationException("No MTMD prompt segments are queued for evaluation."); + + var segment = _mtmdPromptSegments.Peek(); + if (segment.Kind != MtmdPromptSegmentKind.Media || segment.Chunk is null) + throw new InvalidOperationException("Current MTMD prompt segment is not a media chunk."); - var evalStatus = ClipModel.EvaluateChunks(MtmdChunks, Context.NativeHandle, ref nPast, seqId: 0, + var nPast = _pastTokensCount; + var evalStatus = ClipModel.EvaluateChunk(segment.Chunk.NativePtr, Context.NativeHandle, ref nPast, seqId: 0, nBatch: checked((int)Context.BatchSize), logitsLast: true); if (evalStatus != 0) { _logger?.LogError("[{Executor}] Failed to evaluate multimodal chunks. Status: {Status}", executorName, evalStatus); - DisposeMtmdChunks(); throw new RuntimeError($"Failed to evaluate multimodal chunks. Status: {evalStatus}."); } - FinalizeMtmdEvaluation(nPast, previousConsumed); + _pastTokensCount = nPast; + _consumedTokensCount += segment.TokenCount; + _mtmdPromptSegments.Dequeue().Dispose(); + } + + /// + /// Determine whether any prompt input is still pending evaluation. + /// + protected bool HasPendingPromptInput() + { + if (IsMultiModal) + return _embeds.Count > 0 || _mtmdPromptSegments.Count > 0; + + return _embed_inps.Count > _consumedTokensCount; + } + + /// + /// Determine whether any prompt content still remains to be queued/evaluated, excluding sampled output tokens. + /// + protected bool HasPendingPromptQueue() + { + if (IsMultiModal) + return _mtmdPromptSegments.Count > 0; + + return _embed_inps.Count > _consumedTokensCount; + } + + /// + /// Determine whether the next queued multimodal prompt segment is a media chunk. + /// + protected bool HasPendingMtmdMediaSegment() + { + return _mtmdPromptSegments.Count > 0 && _mtmdPromptSegments.Peek().Kind == MtmdPromptSegmentKind.Media; } /// @@ -509,6 +599,7 @@ public virtual async IAsyncEnumerable InferAsync(string? text, IInferenc { cancellationToken.ThrowIfCancellationRequested(); inferenceParams ??= new InferenceParams(); + _decoder.DecodeSpecialTokens = inferenceParams.DecodeSpecialTokens; var args = new InferStateArgs { @@ -578,10 +669,8 @@ public virtual async Task PrefillPromptAsync(string prompt, CancellationToken ca }; await PreprocessInputs(prompt, args, cancellationToken); - // First run adds the prompt to the _embeds - await InferInternal(inferenceParams, args, cancellationToken); - // Second run puts it through decode - await InferInternal(inferenceParams, args, cancellationToken); + while (HasPendingPromptInput()) + await InferInternal(inferenceParams, args, cancellationToken); } /// @@ -666,6 +755,55 @@ internal ExecutorDiagnostics GetDiagnostics() _pastTokensCount, _embeds.Count); } + + private enum MtmdPromptSegmentKind + { + Text, + Media, + } + + private sealed class MtmdPromptSegment : IDisposable + { + private MtmdPromptSegment(MtmdPromptSegmentKind kind, LLamaToken[]? textTokens, SafeMtmdInputChunk? chunk) + { + Kind = kind; + TextTokens = textTokens ?? []; + Chunk = chunk; + } + + public MtmdPromptSegmentKind Kind { get; } + + public LLamaToken[] TextTokens { get; } + + public SafeMtmdInputChunk? Chunk { get; } + + public int TextOffset { get; private set; } + + public int RemainingTextTokenCount => TextTokens.Length - TextOffset; + + public int RemainingPositionCount => Kind == MtmdPromptSegmentKind.Text ? RemainingTextTokenCount : checked((int)(Chunk?.NPos ?? 0)); + + public int TokenCount => Kind == MtmdPromptSegmentKind.Text ? TextTokens.Length : checked((int)(Chunk?.NTokens ?? 0)); + + public static MtmdPromptSegment CreateText(LLamaToken[] tokens) + => new(MtmdPromptSegmentKind.Text, tokens, null); + + public static MtmdPromptSegment CreateMedia(SafeMtmdInputChunk chunk) + => new(MtmdPromptSegmentKind.Media, null, chunk); + + public void AdvanceText(int count) + { + if (Kind != MtmdPromptSegmentKind.Text) + throw new InvalidOperationException("Cannot advance text offset for a media segment."); + + TextOffset = Math.Min(TextTokens.Length, TextOffset + count); + } + + public void Dispose() + { + Chunk?.Dispose(); + } + } } } diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index e7724cb65..9ba550862 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -63,6 +63,9 @@ public InstructExecutor(LLamaContext context, /// public override ExecutorBaseState GetStateData() { + if (IsMultiModal) + throw new NotSupportedException("State persistence is not supported for multimodal instruct executors."); + InstructExecutorState state = new() { ConsumedSessionCount = _n_session_consumed, @@ -84,7 +87,9 @@ public override ExecutorBaseState GetStateData() /// public override Task LoadState(ExecutorBaseState data, CancellationToken cancellationToken = default) { - DisposeMtmdChunks(); + if (IsMultiModal) + throw new NotSupportedException("State persistence is not supported for multimodal instruct executors."); + if(data is InstructExecutorState state) { _n_session_consumed = state.ConsumedSessionCount; @@ -111,6 +116,9 @@ public override Task LoadState(ExecutorBaseState data, CancellationToken cancell /// public override async Task SaveState(string filename, CancellationToken cancellationToken = default) { + if (IsMultiModal) + throw new NotSupportedException("State persistence is not supported for multimodal instruct executors."); + var state = (InstructExecutorState)GetStateData(); using (var fs = new FileStream(filename, FileMode.Create, FileAccess.Write)) { @@ -121,6 +129,9 @@ public override async Task SaveState(string filename, CancellationToken cancella /// public override async Task LoadState(string filename, CancellationToken cancellationToken = default) { + if (IsMultiModal) + throw new NotSupportedException("State persistence is not supported for multimodal instruct executors."); + using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read)) { var state = await JsonSerializer.DeserializeAsync(fs); @@ -131,7 +142,7 @@ public override async Task LoadState(string filename, CancellationToken cancella /// protected override Task GetLoopCondition(InferStateArgs args, CancellationToken cancellationToken) { - return Task.FromResult(args.RemainedTokens != 0 || _is_prompt_run); + return Task.FromResult(args.RemainedTokens != 0 || _is_prompt_run || HasPendingPromptInput()); } /// @@ -193,7 +204,7 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args, Canc /// protected override Task<(bool, IReadOnlyList)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken) { - if (_embed_inps.Count <= _consumedTokensCount) + if (!HasPendingPromptQueue()) { if (!string.IsNullOrEmpty(args.LastOutput) && AntipromptProcessor.Add(args.LastOutput)) { @@ -205,17 +216,17 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args, Canc { return Task.FromResult<(bool, IReadOnlyList)>((true, ["\n> "])); } - } - if (_embeds.Count > 0 && _embeds.Last() == Context.Vocab.EOS) - { - args.WaitForInput = true; - } + if (_embeds.Count > 0 && _embeds.Last() == Context.Vocab.EOS) + { + args.WaitForInput = true; + } - if (args.RemainedTokens <= 0 && inferenceParams.MaxTokens != -1) - { - args.RemainedTokens = inferenceParams.MaxTokens; - args.WaitForInput = true; + if (args.RemainedTokens <= 0 && inferenceParams.MaxTokens != -1) + { + args.RemainedTokens = inferenceParams.MaxTokens; + args.WaitForInput = true; + } } return Task.FromResult<(bool, IReadOnlyList)>((false, [])); } @@ -228,15 +239,13 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In if (_embeds.Count > 0) { _is_prompt_run = false; - if (_pastTokensCount + _embeds.Count > Context.ContextSize) - { - // Ported from https://github.com/ggerganov/llama.cpp/blob/60325fa56f61c228464c9f065db3aa6a61f2156e/examples/main/main.cpp#L334 - // Instruct always uses input token size. - var tokensToKeep = _embed_inps.Count; - HandleRunOutOfContext(tokensToKeep); - } + // Ported from https://github.com/ggerganov/llama.cpp/blob/60325fa56f61c228464c9f065db3aa6a61f2156e/examples/main/main.cpp#L334 + // Instruct always uses input token size. + var tokensToKeep = _embed_inps.Count; + EnsurePendingInputFitsContext(tokensToKeep); - TryReuseMatchingPrefix(); + if (!IsMultiModal) + TryReuseMatchingPrefix(); var (result, _, pastTokensCount) = await Context.DecodeAsync(_embeds, LLamaSeqId.Zero, batch, _pastTokensCount); _pastTokensCount = pastTokensCount; @@ -252,17 +261,16 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In _n_session_consumed = _session_tokens.Count; } } - else if (IsMultiModal && MtmdChunks is not null) + else if (IsMultiModal && HasPendingMtmdMediaSegment()) { _is_prompt_run = false; - var nPast = _pastTokensCount; - var previousConsumed = _consumedTokensCount; - EvaluateMtmdChunks(ref nPast, previousConsumed, nameof(InstructExecutor)); + EnsurePendingInputFitsContext(_embed_inps.Count); + EvaluateNextMtmdMediaSegment(nameof(InstructExecutor)); } _embeds.Clear(); - if (_embed_inps.Count <= _consumedTokensCount && !args.WaitForInput) + if (!HasPendingPromptInput() && !args.WaitForInput) { if (inferenceParams.MaxTokens == 0) { @@ -289,14 +297,21 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In } else { - while (_embed_inps.Count > _consumedTokensCount) + if (IsMultiModal) + { + QueueNextMtmdPromptInput(); + } + else { - _embeds.Add(_embed_inps[_consumedTokensCount]); - _last_n_tokens.Enqueue(_embed_inps[_consumedTokensCount]); - _consumedTokensCount++; - if (_embeds.Count >= Context.BatchSize) + while (_embed_inps.Count > _consumedTokensCount) { - break; + _embeds.Add(_embed_inps[_consumedTokensCount]); + _last_n_tokens.Enqueue(_embed_inps[_consumedTokensCount]); + _consumedTokensCount++; + if (_embeds.Count >= Context.BatchSize) + { + break; + } } } } diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index 91e1d4b7f..6dc8c32e1 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -25,11 +25,6 @@ public class InteractiveExecutor : StatefulExecutorBase // Indicates whether the executor is currently evaluating the initial prompt or a follow-up turn. private bool _is_prompt_run = true; - // MTMD multimodal state - private SafeMtmdInputChunks? _mtmdChunks; // Pending chunk collection produced by the multimodal tokenizer. - private string? _mtmdMarker; // Cached multimodal marker returned by the native helper. - - /// /// Create an interactive executor for text-only inference. /// @@ -54,6 +49,9 @@ public InteractiveExecutor(LLamaContext context, MtmdWeights clipModel, ILogger? /// public override ExecutorBaseState GetStateData() { + if (IsMultiModal) + throw new NotSupportedException("State persistence is not supported for multimodal interactive executors."); + InteractiveExecutorState state = new() { ConsumedSessionCount = _n_session_consumed, @@ -74,7 +72,9 @@ public override ExecutorBaseState GetStateData() /// public override Task LoadState(ExecutorBaseState data, CancellationToken cancellationToken = default) { - DisposeMtmdChunks(); + if (IsMultiModal) + throw new NotSupportedException("State persistence is not supported for multimodal interactive executors."); + if (data is InteractiveExecutorState state) { _n_session_consumed = state.ConsumedSessionCount; @@ -97,6 +97,9 @@ public override Task LoadState(ExecutorBaseState data, CancellationToken cancell /// public override async Task SaveState(string filename, CancellationToken cancellationToken = default) { + if (IsMultiModal) + throw new NotSupportedException("State persistence is not supported for multimodal interactive executors."); + var state = (InteractiveExecutorState)GetStateData(); using (var fs = new FileStream(filename, FileMode.Create, FileAccess.Write)) { @@ -107,6 +110,9 @@ public override async Task SaveState(string filename, CancellationToken cancella /// public override async Task LoadState(string filename, CancellationToken cancellationToken = default) { + if (IsMultiModal) + throw new NotSupportedException("State persistence is not supported for multimodal interactive executors."); + using var fs = new FileStream(filename, FileMode.Open, FileAccess.Read); var state = await JsonSerializer.DeserializeAsync(fs); @@ -120,7 +126,7 @@ public override async Task LoadState(string filename, CancellationToken cancella /// true to keep generating; otherwise false. protected override Task GetLoopCondition(InferStateArgs args, CancellationToken cancellationToken) { - return Task.FromResult(args.RemainedTokens != 0 && !args.WaitForInput || _is_prompt_run); + return Task.FromResult((args.RemainedTokens != 0 && !args.WaitForInput) || _is_prompt_run || HasPendingPromptInput()); } /// @@ -182,7 +188,7 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args, Canc /// Tuple describing whether to stop and any additional outputs to emit. protected override Task<(bool, IReadOnlyList)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken) { - if (_embed_inps.Count <= _consumedTokensCount) + if (!HasPendingPromptQueue()) { if (!string.IsNullOrEmpty(args.LastOutput) && AntipromptProcessor.Add(args.LastOutput)) { @@ -193,18 +199,18 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args, Canc { return Task.FromResult((true, (IReadOnlyList)[])); } - } - if (_embeds.Count > 0 && _embeds.Last().IsEndOfGeneration(Context.Vocab)) - { - return Task.FromResult((true, (IReadOnlyList)[])); - } + if (_embeds.Count > 0 && _embeds.Last().IsEndOfGeneration(Context.Vocab)) + { + return Task.FromResult((true, (IReadOnlyList)[])); + } - if (args.RemainedTokens <= 0 && inferenceParams.MaxTokens != -1) - { - args.RemainedTokens = inferenceParams.MaxTokens; - args.WaitForInput = true; - return Task.FromResult((true, (IReadOnlyList)[])); + if (args.RemainedTokens <= 0 && inferenceParams.MaxTokens != -1) + { + args.RemainedTokens = inferenceParams.MaxTokens; + args.WaitForInput = true; + return Task.FromResult((true, (IReadOnlyList)[])); + } } return Task.FromResult((false, (IReadOnlyList)[])); @@ -218,59 +224,54 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In if (_embeds.Count > 0) { _is_prompt_run = false; - if (_pastTokensCount + _embeds.Count > Context.ContextSize) + // Ported from https://github.com/ggerganov/llama.cpp/blob/60325fa56f61c228464c9f065db3aa6a61f2156e/examples/main/main.cpp#L334 + var tokensToKeep = inferenceParams.TokensKeep; + if (tokensToKeep < 0 || tokensToKeep > _embed_inps.Count) { - // number of tokens to keep when resetting context - // Ported from https://github.com/ggerganov/llama.cpp/blob/60325fa56f61c228464c9f065db3aa6a61f2156e/examples/main/main.cpp#L334 - var tokensToKeep = inferenceParams.TokensKeep; - if (tokensToKeep < 0 || tokensToKeep > _embed_inps.Count) - { - tokensToKeep = _embed_inps.Count; - } - else - { - tokensToKeep += Convert.ToInt32(Context.Vocab.ShouldAddBOS); // always keep the BOS token - } - - HandleRunOutOfContext(tokensToKeep); + tokensToKeep = _embed_inps.Count; } - - if (MtmdChunks is null) + else { - TryReuseMatchingPrefix(); + tokensToKeep += Convert.ToInt32(Context.Vocab.ShouldAddBOS); // always keep the BOS token } - if (IsMultiModal && MtmdChunks is not null) + EnsurePendingInputFitsContext(tokensToKeep); + + if (!IsMultiModal) { - var nPast = _pastTokensCount; - var previousConsumed = _consumedTokensCount; - EvaluateMtmdChunks(ref nPast, previousConsumed, nameof(InteractiveExecutor)); + TryReuseMatchingPrefix(); } - else - { - var result = await Context.DecodeAsync(_embeds, LLamaSeqId.Zero, batch, _pastTokensCount); - _pastTokensCount = result.Item3; - if (result.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(result.Item1); + var result = await Context.DecodeAsync(_embeds, LLamaSeqId.Zero, batch, _pastTokensCount); + _pastTokensCount = result.Item3; - if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession)) - { - _session_tokens.AddRange(_embeds); - _n_session_consumed = _session_tokens.Count; - } + if (result.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(result.Item1); + + if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession)) + { + _session_tokens.AddRange(_embeds); + _n_session_consumed = _session_tokens.Count; } } - else if (IsMultiModal && MtmdChunks is not null) + else if (IsMultiModal && HasPendingMtmdMediaSegment()) { _is_prompt_run = false; - var nPast = _pastTokensCount; - var previousConsumed = _consumedTokensCount; - EvaluateMtmdChunks(ref nPast, previousConsumed, nameof(InteractiveExecutor)); + var tokensToKeep = inferenceParams.TokensKeep; + if (tokensToKeep < 0 || tokensToKeep > _embed_inps.Count) + { + tokensToKeep = _embed_inps.Count; + } + else + { + tokensToKeep += Convert.ToInt32(Context.Vocab.ShouldAddBOS); + } + EnsurePendingInputFitsContext(tokensToKeep); + EvaluateNextMtmdMediaSegment(nameof(InteractiveExecutor)); } _embeds.Clear(); - if (_embed_inps.Count <= _consumedTokensCount && !args.WaitForInput) + if (!HasPendingPromptInput() && !args.WaitForInput) { if (inferenceParams.MaxTokens == 0) { @@ -294,6 +295,14 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In if (id == Context.NativeHandle.ModelHandle.Vocab.EOS) { + if (IsMultiModal) + { + _embeds.Add(id); + args.RemainedTokens--; + args.ReturnValue = false; + return; + } + id = Context.NativeHandle.ModelHandle.Vocab.Newline!.Value; if (args.Antiprompts is not null && args.Antiprompts.Count > 0) { @@ -309,14 +318,21 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In } else { - while (_embed_inps.Count > _consumedTokensCount) + if (IsMultiModal) + { + QueueNextMtmdPromptInput(); + } + else { - _embeds.Add(_embed_inps[_consumedTokensCount]); - _last_n_tokens.Enqueue(_embed_inps[_consumedTokensCount]); - _consumedTokensCount++; - if (_embeds.Count >= Context.BatchSize) + while (_embed_inps.Count > _consumedTokensCount) { - break; + _embeds.Add(_embed_inps[_consumedTokensCount]); + _last_n_tokens.Enqueue(_embed_inps[_consumedTokensCount]); + _consumedTokensCount++; + if (_embeds.Count >= Context.BatchSize) + { + break; + } } } } diff --git a/LLama/MtmdWeights.cs b/LLama/MtmdWeights.cs index 07e739a61..1655601ca 100644 --- a/LLama/MtmdWeights.cs +++ b/LLama/MtmdWeights.cs @@ -12,6 +12,8 @@ namespace LLama; public sealed class MtmdWeights : IDisposable { + private readonly object _syncRoot = new(); + /// /// The native handle, which is used in the native APIs /// @@ -79,42 +81,87 @@ public static async Task LoadFromFileAsync(string mmProject, LLamaW /// /// Load media from disk and keep it pending for the next tokenize call. /// - public SafeMtmdEmbed LoadMedia(string path) => NativeHandle.LoadMediaFromFile(path); + public SafeMtmdEmbed LoadMedia(string path) + { + lock (_syncRoot) + return NativeHandle.LoadMediaFromFile(path); + } /// /// Load media from an in-memory buffer and keep it pending for the next tokenize call. /// - public SafeMtmdEmbed LoadMedia(ReadOnlySpan data) => NativeHandle.LoadMediaFromBuffer(data); + public SafeMtmdEmbed LoadMedia(ReadOnlySpan data) + { + lock (_syncRoot) + return NativeHandle.LoadMediaFromBuffer(data); + } + + /// + /// Load media from disk as a standalone embedding without touching the shared pending-media queue. + /// + public SafeMtmdEmbed LoadMediaStandalone(string path) + { + lock (_syncRoot) + return NativeHandle.CreateMediaEmbedFromFile(path); + } + + /// + /// Load media from an in-memory buffer as a standalone embedding without touching the shared pending-media queue. + /// + public SafeMtmdEmbed LoadMediaStandalone(ReadOnlySpan data) + { + lock (_syncRoot) + return NativeHandle.CreateMediaEmbedFromBuffer(data); + } /// /// Clear any pending media buffers before or after tokenization. /// - public void ClearMedia() => NativeHandle.ClearMedia(); + public void ClearMedia() + { + lock (_syncRoot) + NativeHandle.ClearMedia(); + } /// /// Tokenize text (with optional special tokens) against the pending media buffers. /// public int Tokenize(string text, bool addSpecial, bool parseSpecial, out SafeMtmdInputChunks? chunks) - => NativeHandle.Tokenize(text, addSpecial, parseSpecial, out chunks); + { + lock (_syncRoot) + return NativeHandle.Tokenize(text, addSpecial, parseSpecial, out chunks); + } /// /// Tokenize text (with optional special tokens) against explicit media embeddings. /// The caller retains ownership of . /// public int Tokenize(string text, bool addSpecial, bool parseSpecial, ReadOnlySpan embeds, out SafeMtmdInputChunks? chunks) - => NativeHandle.Tokenize(text, addSpecial, parseSpecial, embeds, out chunks); + { + lock (_syncRoot) + return NativeHandle.Tokenize(text, addSpecial, parseSpecial, embeds, out chunks); + } /// /// Evaluate a chunk batch using the helper that performs mtmd encode + llama decode. /// public int EvaluateChunks(SafeMtmdInputChunks chunks, SafeLLamaContextHandle llamaContext, ref int nPast, int seqId, int nBatch, bool logitsLast) - => NativeHandle.EvaluateChunks(chunks, llamaContext, ref nPast, seqId, nBatch, logitsLast); + { + lock (_syncRoot) + return NativeHandle.EvaluateChunks(chunks, llamaContext, ref nPast, seqId, nBatch, logitsLast); + } public int EvaluateChunk(IntPtr chunkPtr, SafeLLamaContextHandle llamaContext, ref int nPast, int seqId, int nBatch, bool logitsLast) - => NativeHandle.EvaluateChunk(chunkPtr, llamaContext, ref nPast, seqId, nBatch, logitsLast); + { + lock (_syncRoot) + return NativeHandle.EvaluateChunk(chunkPtr, llamaContext, ref nPast, seqId, nBatch, logitsLast); + } public int DecodeImageChunk(IntPtr chunkPtr, SafeLLamaContextHandle llamaContext, IntPtr encodedEmbeddings, ref int nPast, int seqId, int nBatch) - => NativeHandle.DecodeImageChunk(chunkPtr, llamaContext, encodedEmbeddings, ref nPast, seqId, nBatch); + { + lock (_syncRoot) + return NativeHandle.DecodeImageChunk(chunkPtr, llamaContext, encodedEmbeddings, ref nPast, seqId, nBatch); + } /// /// Indicates whether the model supports vision inputs. diff --git a/LLama/Native/SafeMtmdModelHandle.cs b/LLama/Native/SafeMtmdModelHandle.cs index 4a356f637..2b31f2d7e 100644 --- a/LLama/Native/SafeMtmdModelHandle.cs +++ b/LLama/Native/SafeMtmdModelHandle.cs @@ -76,8 +76,7 @@ public SafeMtmdEmbed LoadMediaFromFile(string path) { EnsureNotDisposed(); - var embed = SafeMtmdEmbed.FromMediaFile(this, path) - ?? throw new RuntimeError($"Failed to load media '{path}'."); + var embed = CreateMediaEmbedFromFile(path); _pendingMedia.Add(embed); return embed; } @@ -93,12 +92,41 @@ public SafeMtmdEmbed LoadMediaFromBuffer(ReadOnlySpan buffer) { EnsureNotDisposed(); - var embed = SafeMtmdEmbed.FromMediaBuffer(this, buffer) - ?? throw new RuntimeError("Failed to load media from buffer."); + var embed = CreateMediaEmbedFromBuffer(buffer); _pendingMedia.Add(embed); return embed; } + /// + /// Create a standalone media embedding from disk without queueing it for the next tokenize call. + /// + /// Path to the media file on disk. + /// Safe handle to the prepared media embedding. + /// The model handle has been disposed. + /// The native loader failed to ingest the file contents. + public SafeMtmdEmbed CreateMediaEmbedFromFile(string path) + { + EnsureNotDisposed(); + + return SafeMtmdEmbed.FromMediaFile(this, path) + ?? throw new RuntimeError($"Failed to load media '{path}'."); + } + + /// + /// Create a standalone media embedding from an in-memory buffer without queueing it for the next tokenize call. + /// + /// Binary buffer containing the encoded media data. + /// Safe handle to the prepared media embedding. + /// The model handle has been disposed. + /// The native loader failed to ingest the buffer contents. + public SafeMtmdEmbed CreateMediaEmbedFromBuffer(ReadOnlySpan buffer) + { + EnsureNotDisposed(); + + return SafeMtmdEmbed.FromMediaBuffer(this, buffer) + ?? throw new RuntimeError("Failed to load media from buffer."); + } + /// /// Disposes and clears any media buffers currently queued for tokenization. /// From 9e2c72f589ee2959a83f4acecd1bc09cf48aec1a Mon Sep 17 00:00:00 2001 From: SignalRT Date: Fri, 20 Mar 2026 22:01:23 +0100 Subject: [PATCH 4/8] Clean up + Documentation Some cleanup and change documentation. only mtmd doc update. I think we should regererate all doc, but I'm not sure --- LLama.Unittest/MtmdContextGuardTests.cs | 185 ++++--- LLama.Unittest/MtmdExecutorTests.cs | 36 +- LLama.Unittest/MtmdNoCiCollection.cs | 39 +- LLama.Unittest/MtmdWeightsTests.cs | 16 +- .../NativeLibraryConfigContainerTests.cs | 117 +++++ LLama/LLamaExecutorBase.cs | 114 ++++- LLama/LLamaInstructExecutor.cs | 26 +- LLama/LLamaInteractExecutor.cs | 27 +- LLama/Native/Load/NativeLibraryConfig.cs | 6 +- docs/Examples/MtmdInteractiveModeExecute.md | 14 +- docs/QuickStart.md | 79 ++- docs/Tutorials/Executors.md | 28 +- docs/Tutorials/NativeLibraryConfig.md | 4 +- docs/xmldocs/index.md | 16 +- .../llama.abstractions.illamaexecutor.md | 12 +- docs/xmldocs/llama.batched.batchedexecutor.md | 24 + docs/xmldocs/llama.batched.conversation.md | 46 +- docs/xmldocs/llama.instructexecutor.md | 100 +++- docs/xmldocs/llama.interactiveexecutor.md | 110 ++-- docs/xmldocs/llama.mtmdweights.md | 348 +++++++++++++ .../xmldocs/llama.native.mtmdcontextparams.md | 152 ++++++ ...ama.native.nativelibraryconfigcontainer.md | 14 +- docs/xmldocs/llama.native.safemtmdembed.md | 251 ++++++++++ .../llama.native.safemtmdinputchunk.md | 180 +++++++ .../llama.native.safemtmdinputchunks.md | 149 ++++++ .../llama.native.safemtmdmodelhandle.md | 469 ++++++++++++++++++ docs/xmldocs/llama.statefulexecutorbase.md | 403 +++++++++++++-- docs/xmldocs/llama.statelessexecutor.md | 10 +- mkdocs.yml | 10 +- 29 files changed, 2571 insertions(+), 414 deletions(-) create mode 100644 LLama.Unittest/NativeLibraryConfigContainerTests.cs create mode 100644 docs/xmldocs/llama.mtmdweights.md create mode 100644 docs/xmldocs/llama.native.mtmdcontextparams.md create mode 100644 docs/xmldocs/llama.native.safemtmdembed.md create mode 100644 docs/xmldocs/llama.native.safemtmdinputchunk.md create mode 100644 docs/xmldocs/llama.native.safemtmdinputchunks.md create mode 100644 docs/xmldocs/llama.native.safemtmdmodelhandle.md diff --git a/LLama.Unittest/MtmdContextGuardTests.cs b/LLama.Unittest/MtmdContextGuardTests.cs index d2ee7b156..be726e89d 100644 --- a/LLama.Unittest/MtmdContextGuardTests.cs +++ b/LLama.Unittest/MtmdContextGuardTests.cs @@ -1,9 +1,7 @@ using System; using System.Collections.Generic; -using System.Runtime.CompilerServices; using System.IO; using System.Linq; -using System.Reflection; using System.Threading.Tasks; using LLama.Common; using LLama.Exceptions; @@ -12,83 +10,53 @@ namespace LLama.Unittest; -public sealed class MtmdContextGuardTests : IDisposable +[Collection("MtmdNoCi")] +[Trait("Category", "NoCI")] +public sealed class MtmdContextGuardTests { - private const string EnableEnvVar = "LLAMASHARP_RUN_MTMD_TESTS"; - private const string SkipReason = "MTMD tests are opt-in. Set LLAMASHARP_RUN_MTMD_TESTS=1 to run them."; + private readonly MtmdNoCiFixture _fixture; - private readonly bool _isEnabled; - private readonly LLamaWeights? _weights; - private readonly ModelParams? _modelParams; - - public MtmdContextGuardTests() + public MtmdContextGuardTests(MtmdNoCiFixture fixture) { - _isEnabled = string.Equals(Environment.GetEnvironmentVariable(EnableEnvVar), "1", StringComparison.OrdinalIgnoreCase) - || string.Equals(Environment.GetEnvironmentVariable(EnableEnvVar), "true", StringComparison.OrdinalIgnoreCase); - - if (!_isEnabled) - return; - - NativeLogConfig.llama_log_set(static (_, _) => { }); - - _modelParams = new ModelParams(Constants.GenerativeModelPath2) - { - ContextSize = 128, - BatchSize = 8, - UBatchSize = 8, - VocabOnly = false, - GpuLayerCount = Constants.CIGpuLayerCount, - }; - _weights = LLamaWeights.LoadFromFile(_modelParams); + _fixture = fixture; } - public void Dispose() - { - if (!_isEnabled) - return; - - _weights.Dispose(); - NativeLogConfig.llama_log_set((NativeLogConfig.LLamaLogCallback?)null); - } - - [SkippableFact] + [Fact] public void InteractiveExecutor_LoadMtmdPromptSegments_PreservesOrderedPromptAndMediaBoundary() { - Skip.IfNot(_isEnabled, SkipReason); + using var context = _fixture.CreateContext(); + using var chunks = _fixture.TokenizePromptWithMedia($"Before {_fixture.MediaMarker}"); + var executor = new TrackingInteractiveExecutor(context, _fixture.Mtmd); + var fillerToken = _fixture.GetFillerToken(context); - using var context = _weights!.CreateContext(_modelParams!, NullLogger.Instance); - using var chunks = new SafeMtmdInputChunks(NativeApi.mtmd_test_create_input_chunks()); - var executor = new TrackingInteractiveExecutor(context); - var fillerToken = (LLamaToken)99; - - var (imageTokenCount, imagePositionCount) = GetSyntheticImageChunkMetadata(chunks); + var (imageTokenCount, imagePositionCount) = GetImageChunkMetadata(chunks); executor.LoadPendingMtmd(chunks, fillerToken); - Assert.Equal( - new[] { (LLamaToken)1, 2, 3, 4, 5 }.Concat(Enumerable.Repeat(fillerToken, imageTokenCount)), - executor.EmbedInputs); - Assert.Equal(5 + imagePositionCount, executor.PendingPositions); + Assert.Contains(fillerToken, executor.EmbedInputs); + Assert.Equal((int)chunks.CountPositions(), executor.PendingPositions); + Assert.Equal((int)chunks.CountTokens(), executor.PendingKvTokens); executor.QueuePromptText(); - Assert.Equal(new[] { (LLamaToken)1, 2, 3, 4, 5 }, executor.PendingEmbeds); + Assert.NotEmpty(executor.PendingEmbeds); + var queuedTextCount = executor.PendingEmbeds.Count; Assert.True(executor.PendingMediaSegment); executor.ClearPendingEmbeds(); - Assert.Equal(imagePositionCount, executor.PendingPositions); + Assert.Equal((int)chunks.CountPositions() - queuedTextCount, executor.PendingPositions); + Assert.Equal((int)chunks.CountTokens() - queuedTextCount, executor.PendingKvTokens); } - [SkippableFact] + [Fact] public void InteractiveExecutor_MtmdOverflowThrowsWithoutContextShift() { - Skip.IfNot(_isEnabled, SkipReason); - - using var context = _weights!.CreateContext(_modelParams!, NullLogger.Instance); - using var chunks = new SafeMtmdInputChunks(NativeApi.mtmd_test_create_input_chunks()); - var executor = new TrackingInteractiveExecutor(context); + using var context = _fixture.CreateContext(); + using var chunks = _fixture.TokenizePromptWithMedia($"{_fixture.MediaMarker} describe"); + var executor = new TrackingInteractiveExecutor(context, _fixture.Mtmd); + var fillerToken = _fixture.GetFillerToken(context); - executor.LoadPendingMtmd(chunks, fillerToken: 99, pastTokensCount: (int)context.ContextSize - 1, pendingEmbeds: [(LLamaToken)1]); + executor.LoadPendingMtmd(chunks, fillerToken, pastTokensCount: (int)context.ContextSize - 1, pendingEmbeds: [(LLamaToken)1]); Assert.True(executor.PendingPositions > 1); @@ -97,16 +65,36 @@ public void InteractiveExecutor_MtmdOverflowThrowsWithoutContextShift() Assert.Contains("Context shifting is not supported", ex.Message); } - [SkippableFact] - public void InstructExecutor_MtmdOverflowThrowsWithoutContextShift() + [Fact] + public void InteractiveExecutor_MtmdGuardUsesKvOccupancyInsteadOfPositions() { - Skip.IfNot(_isEnabled, SkipReason); + using var context = _fixture.CreateContext(); + using var chunks = _fixture.TokenizePromptWithMedia($"Before {_fixture.MediaMarker}"); + var executor = new TrackingInteractiveExecutor(context, _fixture.Mtmd); + + executor.LoadPendingMtmd(chunks, _fixture.GetFillerToken(context)); + + var positionCompatiblePast = (int)context.ContextSize - executor.PendingPositions; + var kvOverflowPast = (int)context.ContextSize - executor.PendingKvTokens + 1; + + Assert.True(positionCompatiblePast + executor.PendingPositions <= context.ContextSize); + Assert.True(kvOverflowPast + executor.PendingKvTokens > context.ContextSize); + + executor.SetOccupiedCounts(positionCompatiblePast, kvOverflowPast); - using var context = _weights!.CreateContext(_modelParams!, NullLogger.Instance); - using var chunks = new SafeMtmdInputChunks(NativeApi.mtmd_test_create_input_chunks()); - var executor = new TrackingInstructExecutor(context); + var ex = Assert.Throws(() => executor.CheckContextGuard(new InferenceParams())); + Assert.Contains("Context shifting is not supported", ex.Message); + } + + [Fact] + public void InstructExecutor_MtmdOverflowThrowsWithoutContextShift() + { + using var context = _fixture.CreateContext(); + using var chunks = _fixture.TokenizePromptWithMedia($"{_fixture.MediaMarker} describe"); + var executor = new TrackingInstructExecutor(context, _fixture.Mtmd); + var fillerToken = _fixture.GetFillerToken(context); - executor.LoadPendingMtmd(chunks, fillerToken: 99, pastTokensCount: (int)context.ContextSize - 1, pendingEmbeds: [(LLamaToken)1]); + executor.LoadPendingMtmd(chunks, fillerToken, pastTokensCount: (int)context.ContextSize - 1, pendingEmbeds: [(LLamaToken)1]); Assert.True(executor.PendingPositions > 1); @@ -115,13 +103,11 @@ public void InstructExecutor_MtmdOverflowThrowsWithoutContextShift() Assert.Contains("Context shifting is not supported", ex.Message); } - [SkippableFact] + [Fact] public async Task InteractiveExecutor_MultimodalSessionAndStateApisAreRejected() { - Skip.IfNot(_isEnabled, SkipReason); - - using var context = _weights!.CreateContext(_modelParams!, NullLogger.Instance); - var executor = new TrackingInteractiveExecutor(context); + using var context = _fixture.CreateContext(); + var executor = new TrackingInteractiveExecutor(context, _fixture.Mtmd); var statePath = Path.Combine(Path.GetTempPath(), $"{Guid.NewGuid():N}.json"); Assert.Throws(() => executor.WithSessionFile(statePath)); @@ -131,13 +117,11 @@ public async Task InteractiveExecutor_MultimodalSessionAndStateApisAreRejected() await Assert.ThrowsAsync(() => executor.LoadState(statePath)); } - [SkippableFact] + [Fact] public async Task InstructExecutor_MultimodalSessionAndStateApisAreRejected() { - Skip.IfNot(_isEnabled, SkipReason); - - using var context = _weights!.CreateContext(_modelParams!, NullLogger.Instance); - var executor = new TrackingInstructExecutor(context); + using var context = _fixture.CreateContext(); + var executor = new TrackingInstructExecutor(context, _fixture.Mtmd); var statePath = Path.Combine(Path.GetTempPath(), $"{Guid.NewGuid():N}.json"); Assert.Throws(() => executor.WithSessionFile(statePath)); @@ -147,13 +131,11 @@ public async Task InstructExecutor_MultimodalSessionAndStateApisAreRejected() await Assert.ThrowsAsync(() => executor.LoadState(statePath)); } - [SkippableFact] + [Fact] public void ChatSession_MultimodalSessionPersistenceApisAreRejected() { - Skip.IfNot(_isEnabled, SkipReason); - - using var context = _weights!.CreateContext(_modelParams!, NullLogger.Instance); - var executor = new TrackingInteractiveExecutor(context); + using var context = _fixture.CreateContext(); + var executor = new TrackingInteractiveExecutor(context, _fixture.Mtmd); var chatSession = new ChatSession(executor); var sessionPath = Path.Combine(Path.GetTempPath(), Guid.NewGuid().ToString("N")); var state = new SessionState( @@ -170,7 +152,7 @@ public void ChatSession_MultimodalSessionPersistenceApisAreRejected() Assert.Throws(() => chatSession.LoadSession(sessionPath)); } - private static (int ImageTokenCount, int ImagePositionCount) GetSyntheticImageChunkMetadata(SafeMtmdInputChunks chunks) + private static (int ImageTokenCount, int ImagePositionCount) GetImageChunkMetadata(SafeMtmdInputChunks chunks) { foreach (var chunk in chunks.Enumerate()) { @@ -182,31 +164,17 @@ private static (int ImageTokenCount, int ImagePositionCount) GetSyntheticImageCh throw new InvalidOperationException("Synthetic MTMD test chunks do not contain an image segment."); } - private static MtmdWeights CreateFakeMtmdWeights() - => (MtmdWeights)RuntimeHelpers.GetUninitializedObject(typeof(MtmdWeights)); - - private static class TrackingMultimodalExecutorBase - { - private static readonly FieldInfo ClipModelField = typeof(StatefulExecutorBase) - .GetField("k__BackingField", BindingFlags.Instance | BindingFlags.NonPublic) - ?? throw new InvalidOperationException("Unable to find ClipModel backing field."); - - public static void MarkAsMultimodal(StatefulExecutorBase executor) - { - ClipModelField.SetValue(executor, CreateFakeMtmdWeights()); - } - } - private sealed class TrackingInteractiveExecutor : InteractiveExecutor { - public TrackingInteractiveExecutor(LLamaContext context) - : base(context, NullLogger.Instance) + public TrackingInteractiveExecutor(LLamaContext context, MtmdWeights clipModel) + : base(context, clipModel, NullLogger.Instance) { - TrackingMultimodalExecutorBase.MarkAsMultimodal(this); } public int PendingPositions => GetPendingInputPositionCount(); + public int PendingKvTokens => GetPendingInputKvTokenCount(); + public IReadOnlyList EmbedInputs => _embed_inps; public IReadOnlyList PendingEmbeds => _embeds; @@ -217,9 +185,16 @@ public void LoadPendingMtmd(SafeMtmdInputChunks chunks, LLamaToken fillerToken, { LoadMtmdPromptSegments(chunks, fillerToken, replaceExisting: true); _pastTokensCount = pastTokensCount; + _kvTokenCount = pastTokensCount; _embeds = pendingEmbeds?.ToList() ?? []; } + public void SetOccupiedCounts(int pastTokensCount, int kvTokenCount) + { + _pastTokensCount = pastTokensCount; + _kvTokenCount = kvTokenCount; + } + public void QueuePromptText() => QueueNextMtmdPromptInput(); @@ -244,21 +219,29 @@ public void CheckContextGuard(InferenceParams inferenceParams) private sealed class TrackingInstructExecutor : InstructExecutor { - public TrackingInstructExecutor(LLamaContext context) - : base(context, logger: NullLogger.Instance) + public TrackingInstructExecutor(LLamaContext context, MtmdWeights clipModel) + : base(context, clipModel, logger: NullLogger.Instance) { - TrackingMultimodalExecutorBase.MarkAsMultimodal(this); } public int PendingPositions => GetPendingInputPositionCount(); + public int PendingKvTokens => GetPendingInputKvTokenCount(); + public void LoadPendingMtmd(SafeMtmdInputChunks chunks, LLamaToken fillerToken, int pastTokensCount = 0, IReadOnlyList? pendingEmbeds = null) { LoadMtmdPromptSegments(chunks, fillerToken, replaceExisting: true); _pastTokensCount = pastTokensCount; + _kvTokenCount = pastTokensCount; _embeds = pendingEmbeds?.ToList() ?? []; } + public void SetOccupiedCounts(int pastTokensCount, int kvTokenCount) + { + _pastTokensCount = pastTokensCount; + _kvTokenCount = kvTokenCount; + } + public void CheckContextGuard() => EnsurePendingInputFitsContext(_embed_inps.Count); } diff --git a/LLama.Unittest/MtmdExecutorTests.cs b/LLama.Unittest/MtmdExecutorTests.cs index 041d10c9f..9610c015b 100644 --- a/LLama.Unittest/MtmdExecutorTests.cs +++ b/LLama.Unittest/MtmdExecutorTests.cs @@ -20,11 +20,9 @@ public MtmdExecutorTests(MtmdNoCiFixture fixture) _fixture = fixture; } - [SkippableFact] + [Fact] public async Task InteractiveExecutor_EvaluateChunks_DoesNotRetokenize() { - Skip.IfNot(_fixture.IsEnabled, MtmdNoCiFixture.SkipReason); - using var context = _fixture.CreateContext(); var executor = new InteractiveExecutor(context, _fixture.Mtmd, NullLogger.Instance); var marker = _fixture.MediaMarker; @@ -42,11 +40,9 @@ public async Task InteractiveExecutor_EvaluateChunks_DoesNotRetokenize() Assert.Equal(0, diagnostics.PendingEmbedCount); } - [SkippableFact] + [Fact] public async Task InstructExecutor_MtmdPromptAdvancesPastTokensOnce() { - Skip.IfNot(_fixture.IsEnabled, MtmdNoCiFixture.SkipReason); - using var context = _fixture.CreateContext(); var executor = new InstructExecutor(context, _fixture.Mtmd, logger: NullLogger.Instance); executor.Embeds.Add(_fixture.Mtmd.LoadMedia(Constants.MtmdImage)); @@ -62,17 +58,15 @@ public async Task InstructExecutor_MtmdPromptAdvancesPastTokensOnce() Assert.Equal(0, diagnostics.PendingEmbedCount); } - [SkippableFact] + [Fact] public async Task InteractiveExecutor_PreprocessMtmd_PreservesInterleavedPromptOrder() { - Skip.IfNot(_fixture.IsEnabled, MtmdNoCiFixture.SkipReason); - using var context = _fixture.CreateContext(); var executor = new InspectableInteractiveExecutor(context, _fixture.Mtmd); var marker = _fixture.MediaMarker; var prompt = $"Before {marker} after."; - using var expectedChunks = TokenizePromptWithMedia(prompt); - var fillerToken = GetFillerToken(context, marker); + using var expectedChunks = _fixture.TokenizePromptWithMedia(prompt); + var fillerToken = _fixture.GetFillerToken(context); var expectedPromptTokens = BuildLogicalPromptTokens(expectedChunks, fillerToken); var expectedLeadingTextTokens = GetLeadingTextTokens(expectedChunks); @@ -88,26 +82,6 @@ public async Task InteractiveExecutor_PreprocessMtmd_PreservesInterleavedPromptO Assert.True(executor.PendingMediaSegment); } - private SafeMtmdInputChunks TokenizePromptWithMedia(string prompt) - { - _fixture.Mtmd.ClearMedia(); - using var embed = _fixture.Mtmd.LoadMedia(Constants.MtmdImage); - - var status = _fixture.Mtmd.Tokenize(prompt, addSpecial: true, parseSpecial: true, out var chunks); - Assert.Equal(0, status); - Assert.NotNull(chunks); - - return chunks!; - } - - private static LLamaToken GetFillerToken(LLamaContext context, string marker) - { - var markerTokens = context.Tokenize(marker, false, true); - return markerTokens.Length > 0 - ? markerTokens[^1] - : context.Vocab.EOS ?? default; - } - private static IReadOnlyList BuildLogicalPromptTokens(SafeMtmdInputChunks chunks, LLamaToken fillerToken) { var tokens = new List(); diff --git a/LLama.Unittest/MtmdNoCiCollection.cs b/LLama.Unittest/MtmdNoCiCollection.cs index 38733ffc7..4abb6d22e 100644 --- a/LLama.Unittest/MtmdNoCiCollection.cs +++ b/LLama.Unittest/MtmdNoCiCollection.cs @@ -13,17 +13,8 @@ public sealed class MtmdNoCiCollection : ICollectionFixture public sealed class MtmdNoCiFixture : IDisposable { - public const string EnableEnvVar = "LLAMASHARP_RUN_NOCI_MTMD"; - public const string SkipReason = "MTMD NoCI tests are opt-in. Set LLAMASHARP_RUN_NOCI_MTMD=1 to run them."; - public MtmdNoCiFixture() { - IsEnabled = string.Equals(Environment.GetEnvironmentVariable(EnableEnvVar), "1", StringComparison.OrdinalIgnoreCase) - || string.Equals(Environment.GetEnvironmentVariable(EnableEnvVar), "true", StringComparison.OrdinalIgnoreCase); - - if (!IsEnabled) - return; - NativeLogConfig.llama_log_set(static (_, _) => { }); ModelParams = new ModelParams(Constants.MtmdModelPath) @@ -42,8 +33,6 @@ public MtmdNoCiFixture() MediaMarker = MtmdParams.MediaMarker ?? NativeApi.MtmdDefaultMarker() ?? ""; } - public bool IsEnabled { get; } - public ModelParams ModelParams { get; } = null!; public LLamaWeights Weights { get; } = null!; @@ -55,15 +44,31 @@ public MtmdNoCiFixture() public string MediaMarker { get; } = string.Empty; public LLamaContext CreateContext() - => IsEnabled - ? Weights.CreateContext(ModelParams, NullLogger.Instance) - : throw new InvalidOperationException(SkipReason); + => Weights.CreateContext(ModelParams, NullLogger.Instance); - public void Dispose() + public SafeMtmdInputChunks TokenizePromptWithMedia(string prompt) + => TokenizePromptWithMedia(prompt, Constants.MtmdImage); + + public SafeMtmdInputChunks TokenizePromptWithMedia(string prompt, string mediaPath) { - if (!IsEnabled) - return; + using var embed = Mtmd.LoadMediaStandalone(mediaPath); + var embeds = new[] { embed }; + var status = Mtmd.Tokenize(prompt, addSpecial: true, parseSpecial: true, embeds, out var chunks); + Assert.Equal(0, status); + Assert.NotNull(chunks); + return chunks!; + } + + public LLamaToken GetFillerToken(LLamaContext context) + { + var markerTokens = context.Tokenize(MediaMarker, false, true); + return markerTokens.Length > 0 + ? markerTokens[^1] + : context.Vocab.EOS ?? default; + } + public void Dispose() + { Mtmd.Dispose(); Weights.Dispose(); NativeLogConfig.llama_log_set((NativeLogConfig.LLamaLogCallback?)null); diff --git a/LLama.Unittest/MtmdWeightsTests.cs b/LLama.Unittest/MtmdWeightsTests.cs index 7dc6cc737..464933e08 100644 --- a/LLama.Unittest/MtmdWeightsTests.cs +++ b/LLama.Unittest/MtmdWeightsTests.cs @@ -51,11 +51,9 @@ private void AssertChunksEvaluate(SafeMtmdInputChunks chunks) Assert.True(nPast > 0); } - [SkippableFact, Trait("Category", "NoCI")] + [Fact, Trait("Category", "NoCI")] public void BasicPropertyChecks() { - Skip.IfNot(_fixture.IsEnabled, MtmdNoCiFixture.SkipReason); - Assert.False(_fixture.Mtmd.SupportsAudio); Assert.True(_fixture.Mtmd.SupportsVision); Assert.False(_fixture.Mtmd.UsesMRope); @@ -63,30 +61,24 @@ public void BasicPropertyChecks() Assert.Equal(-1, _fixture.Mtmd.AudioBitrate); } - [SkippableFact, Trait("Category", "NoCI")] + [Fact, Trait("Category", "NoCI")] public void EmbedImageAsFileName() { - Skip.IfNot(_fixture.IsEnabled, MtmdNoCiFixture.SkipReason); - using var chunks = TokenizeWithEmbed(() => _fixture.Mtmd.LoadMedia(Constants.MtmdImage)); AssertChunksEvaluate(chunks); } - [SkippableFact, Trait("Category", "NoCI")] + [Fact, Trait("Category", "NoCI")] public void EmbedImageAsBinary() { - Skip.IfNot(_fixture.IsEnabled, MtmdNoCiFixture.SkipReason); - var imageBytes = File.ReadAllBytes(Constants.MtmdImage); using var chunks = TokenizeWithEmbed(() => _fixture.Mtmd.LoadMedia(imageBytes)); AssertChunksEvaluate(chunks); } - [SkippableFact, Trait("Category", "NoCI")] + [Fact, Trait("Category", "NoCI")] public void TokenizeProvidesChunkMetadata() { - Skip.IfNot(_fixture.IsEnabled, MtmdNoCiFixture.SkipReason); - using var chunks = TokenizeWithEmbed(() => _fixture.Mtmd.LoadMedia(Constants.MtmdImage)); Assert.True(chunks.Size > 0); diff --git a/LLama.Unittest/NativeLibraryConfigContainerTests.cs b/LLama.Unittest/NativeLibraryConfigContainerTests.cs new file mode 100644 index 000000000..a14071589 --- /dev/null +++ b/LLama.Unittest/NativeLibraryConfigContainerTests.cs @@ -0,0 +1,117 @@ +using System.Reflection; +using System.Runtime.InteropServices; +using LLama.Abstractions; +using LLama.Native; + +namespace LLama.Unittest +{ + public class NativeLibraryConfigContainerTests + { + [Fact] + public void DryRunPreservesLoadedLibraries() + { + var libraryPath = FindLoadableNativeLibraryPath(); + var selectingPolicy = new StubSelectingPolicy(new StubNativeLibrary(libraryPath)); + + var llamaConfig = CreateConfig(NativeLibraryName.LLama).WithSelectingPolicy(selectingPolicy); + var mtmdConfig = CreateConfig(NativeLibraryName.Mtmd).WithSelectingPolicy(selectingPolicy); + var container = new NativeLibraryConfigContainer(llamaConfig, mtmdConfig); + + var success = container.DryRun(out var loadedLLamaNativeLibrary, out var loadedMtmdNativeLibrary); + + Assert.True(success); + Assert.NotNull(loadedLLamaNativeLibrary); + Assert.NotNull(loadedMtmdNativeLibrary); + Assert.IsType(loadedLLamaNativeLibrary); + Assert.IsType(loadedMtmdNativeLibrary); + Assert.Equal(libraryPath, ((StubNativeLibrary)loadedLLamaNativeLibrary!).LibraryPath); + Assert.Equal(libraryPath, ((StubNativeLibrary)loadedMtmdNativeLibrary!).LibraryPath); + } + + private static NativeLibraryConfig CreateConfig(NativeLibraryName nativeLibraryName) + { + var config = (NativeLibraryConfig?)Activator.CreateInstance( + typeof(NativeLibraryConfig), + BindingFlags.Instance | BindingFlags.NonPublic, + binder: null, + args: new object[] { nativeLibraryName }, + culture: null); + + return Assert.IsType(config); + } + + private static string FindLoadableNativeLibraryPath() + { + var runtimeDirectory = RuntimeEnvironment.GetRuntimeDirectory(); + var extension = OperatingSystem.IsWindows() + ? ".dll" + : OperatingSystem.IsMacOS() + ? ".dylib" + : ".so"; + + var candidateNames = OperatingSystem.IsWindows() + ? new[] { "coreclr.dll", "clrjit.dll", "hostpolicy.dll" } + : OperatingSystem.IsMacOS() + ? new[] { "libcoreclr.dylib", "libclrjit.dylib", "libhostpolicy.dylib" } + : new[] { "libcoreclr.so", "libclrjit.so", "libhostpolicy.so" }; + + foreach (var candidateName in candidateNames) + { + var candidatePath = Path.Combine(runtimeDirectory, candidateName); + if (CanLoad(candidatePath)) + { + return candidatePath; + } + } + + foreach (var candidatePath in Directory.EnumerateFiles(runtimeDirectory, $"*{extension}", SearchOption.TopDirectoryOnly)) + { + if (CanLoad(candidatePath)) + { + return candidatePath; + } + } + + throw new InvalidOperationException($"Could not find a loadable native library in '{runtimeDirectory}'."); + } + + private static bool CanLoad(string candidatePath) + { + if (!File.Exists(candidatePath)) + { + return false; + } + + if (!NativeLibrary.TryLoad(candidatePath, out var handle)) + { + return false; + } + + NativeLibrary.Free(handle); + return true; + } + + private sealed class StubSelectingPolicy(StubNativeLibrary library) : INativeLibrarySelectingPolicy + { + public IEnumerable Apply( + NativeLibraryConfig.Description description, + SystemInfo systemInfo, + NativeLogConfig.LLamaLogCallback? logCallback = null) + { + return new[] { library }; + } + } + + private sealed class StubNativeLibrary(string libraryPath) : INativeLibrary + { + public string LibraryPath { get; } = libraryPath; + + public NativeLibraryMetadata? Metadata => null; + + public IEnumerable Prepare(SystemInfo systemInfo, NativeLogConfig.LLamaLogCallback? logCallback = null) + { + return new[] { LibraryPath }; + } + } + } +} diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index 98ee304ad..6bba1b9f5 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -24,10 +24,16 @@ public abstract class StatefulExecutorBase : ILLamaExecutor /// protected ILogger? _logger; /// - /// The tokens that were already processed by the model. + /// The positional cursor that has already been processed by the model. /// protected int _pastTokensCount; // n_past /// + /// Number of KV slots currently occupied by the active sequence. + /// For MTMD prompts this can diverge from because media chunks may + /// consume a different number of decoded tokens than logical positions. + /// + protected int _kvTokenCount; + /// /// The tokens that were consumed by the model during the current inference. /// protected int _consumedTokensCount; // n_consume @@ -88,6 +94,7 @@ public bool IsMultiModal private string? _mtmdMarker; private readonly Queue _mtmdPromptSegments = new(); + private bool _hasQueuedMtmdPromptText; private readonly StreamingTokenDecoder _decoder; @@ -102,6 +109,7 @@ protected StatefulExecutorBase(LLamaContext context, ILogger? logger = null) _logger = logger; Context = context; _pastTokensCount = 0; + _kvTokenCount = 0; _consumedTokensCount = 0; _n_session_consumed = 0; _last_n_tokens = new FixedSizeQueue((int)Context.ContextSize); @@ -216,6 +224,7 @@ protected virtual void HandleRunOutOfContext(int tokensToKeep) Context.NativeHandle.MemorySequenceAdd(LLamaSeqId.Zero, tokensToKeep + n_discard, _pastTokensCount, -n_discard); _pastTokensCount -= n_discard; + _kvTokenCount = _pastTokensCount; // stop saving session if we run out of context _pathSession = string.Empty; } @@ -237,15 +246,75 @@ protected int GetPendingInputPositionCount() return _embeds.Count; } + /// + /// Get the number of KV slots required by the pending input. + /// For multimodal prompts this counts decoded tokens rather than logical positions. + /// + protected int GetPendingInputKvTokenCount() + { + if (IsMultiModal) + { + var pending = _embeds.Count; + foreach (var segment in _mtmdPromptSegments) + pending += segment.RemainingKvTokenCount; + + return pending; + } + + return _embeds.Count; + } + + /// + /// Get the number of KV slots that are already occupied by the active sequence. + /// + protected int GetOccupiedKvTokenCount() + { + return _kvTokenCount; + } + /// /// Shift the context window when the pending input would overflow the available KV cache. /// protected void EnsurePendingInputFitsContext(int tokensToKeep) { - if (_pastTokensCount + GetPendingInputPositionCount() > Context.ContextSize) + if (GetOccupiedKvTokenCount() + GetPendingInputKvTokenCount() > Context.ContextSize) HandleRunOutOfContext(tokensToKeep); } + /// + /// Decode the currently queued token batch and update the shared executor bookkeeping. + /// + protected async Task DecodeQueuedEmbedsAsync(LLamaBatch batch, int tokensToKeep, CancellationToken cancellationToken = default) + { + EnsurePendingInputFitsContext(tokensToKeep); + + if (!IsMultiModal) + TryReuseMatchingPrefix(); + + var previousPastCount = _pastTokensCount; + var (result, _, pastTokensCount) = await Context.DecodeAsync(_embeds, LLamaSeqId.Zero, batch, _pastTokensCount, cancellationToken); + + if (result != DecodeResult.Ok) + throw new LLamaDecodeError(result); + + OnTextDecodeCompleted(previousPastCount, pastTokensCount); + + if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession)) + { + _session_tokens.AddRange(_embeds); + _n_session_consumed = _session_tokens.Count; + } + } + + /// + /// Evaluate the next queued MTMD media segment after validating context availability. + /// + protected void EvaluatePendingMtmdMediaSegment(string executorName, int tokensToKeep) + { + EnsurePendingInputFitsContext(tokensToKeep); + EvaluateNextMtmdMediaSegment(executorName); + } + /// /// Try to reuse the matching prompt prefix from the loaded session cache before evaluating new tokens. /// @@ -267,6 +336,7 @@ protected virtual void TryReuseMatchingPrefix() } _pastTokensCount++; + _kvTokenCount++; _n_session_consumed++; if (_n_session_consumed >= _session_tokens.Count) @@ -290,6 +360,8 @@ protected void DisposeMtmdPromptSegments() { while (_mtmdPromptSegments.Count > 0) _mtmdPromptSegments.Dequeue().Dispose(); + + _hasQueuedMtmdPromptText = false; } /// @@ -448,6 +520,7 @@ protected int LoadMtmdPromptSegments(SafeMtmdInputChunks chunks, LLamaToken fill /// protected void QueueNextMtmdPromptInput() { + var queuedPromptText = false; while (_mtmdPromptSegments.Count > 0 && _embeds.Count < Context.BatchSize) { var segment = _mtmdPromptSegments.Peek(); @@ -461,12 +534,32 @@ protected void QueueNextMtmdPromptInput() _embeds.Add(token); _last_n_tokens.Enqueue(token); _consumedTokensCount++; + queuedPromptText = true; } segment.AdvanceText(count); if (segment.RemainingTextTokenCount == 0) _mtmdPromptSegments.Dequeue().Dispose(); } + + _hasQueuedMtmdPromptText = queuedPromptText; + } + + /// + /// Clear the marker that tracks prompt text currently staged in for MTMD execution. + /// + protected void ClearQueuedMtmdPromptText() + { + _hasQueuedMtmdPromptText = false; + } + + /// + /// Clear the currently staged decode buffer and any MTMD prompt-text marker associated with it. + /// + protected void ClearCurrentEmbeds() + { + ClearQueuedMtmdPromptText(); + _embeds.Clear(); } /// @@ -493,10 +586,20 @@ protected void EvaluateNextMtmdMediaSegment(string executorName) } _pastTokensCount = nPast; + _kvTokenCount += segment.TokenCount; _consumedTokensCount += segment.TokenCount; _mtmdPromptSegments.Dequeue().Dispose(); } + /// + /// Update counters after a successful text decode. + /// + protected void OnTextDecodeCompleted(int previousPastCount, int newPastCount) + { + _pastTokensCount = newPastCount; + _kvTokenCount += Math.Max(0, newPastCount - previousPastCount); + } + /// /// Determine whether any prompt input is still pending evaluation. /// @@ -514,7 +617,7 @@ protected bool HasPendingPromptInput() protected bool HasPendingPromptQueue() { if (IsMultiModal) - return _mtmdPromptSegments.Count > 0; + return _hasQueuedMtmdPromptText || _mtmdPromptSegments.Count > 0; return _embed_inps.Count > _consumedTokensCount; } @@ -724,6 +827,9 @@ public class ExecutorBaseState [JsonPropertyName("n_matching_session_tokens")] public int MatchingSessionTokensCount { get; set; } + [JsonPropertyName("n_kv")] + public int KvTokenCount { get; set; } + [JsonPropertyName("path_session")] public string? SessionFilePath { get; set; } @@ -785,6 +891,8 @@ private MtmdPromptSegment(MtmdPromptSegmentKind kind, LLamaToken[]? textTokens, public int TokenCount => Kind == MtmdPromptSegmentKind.Text ? TextTokens.Length : checked((int)(Chunk?.NTokens ?? 0)); + public int RemainingKvTokenCount => Kind == MtmdPromptSegmentKind.Text ? RemainingTextTokenCount : TokenCount; + public static MtmdPromptSegment CreateText(LLamaToken[] tokens) => new(MtmdPromptSegmentKind.Text, tokens, null); diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index 9ba550862..5110102a7 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -78,6 +78,7 @@ public override ExecutorBaseState GetStateData() InputSuffixTokens = _inp_sfx, MatchingSessionTokensCount = _n_matching_session_tokens, PastTokensCount = _pastTokensCount, + KvTokenCount = _kvTokenCount, SessionFilePath = _pathSession, SessionTokens = _session_tokens.ToArray(), LastTokensCapacity = _last_n_tokens.Capacity, @@ -102,6 +103,7 @@ public override Task LoadState(ExecutorBaseState data, CancellationToken cancell _inp_sfx = state.InputSuffixTokens!; _n_matching_session_tokens = state.MatchingSessionTokensCount; _pastTokensCount = state.PastTokensCount; + _kvTokenCount = state.KvTokenCount == 0 ? state.PastTokensCount : state.KvTokenCount; _pathSession = state.SessionFilePath; _session_tokens = state.SessionTokens!.ToList(); } @@ -242,33 +244,15 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In // Ported from https://github.com/ggerganov/llama.cpp/blob/60325fa56f61c228464c9f065db3aa6a61f2156e/examples/main/main.cpp#L334 // Instruct always uses input token size. var tokensToKeep = _embed_inps.Count; - EnsurePendingInputFitsContext(tokensToKeep); - - if (!IsMultiModal) - TryReuseMatchingPrefix(); - - var (result, _, pastTokensCount) = await Context.DecodeAsync(_embeds, LLamaSeqId.Zero, batch, _pastTokensCount); - _pastTokensCount = pastTokensCount; - - if (result != DecodeResult.Ok) - { - throw new LLamaDecodeError(result); - } - - if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession)) - { - _session_tokens.AddRange(_embeds); - _n_session_consumed = _session_tokens.Count; - } + await DecodeQueuedEmbedsAsync(batch, tokensToKeep, cancellationToken); } else if (IsMultiModal && HasPendingMtmdMediaSegment()) { _is_prompt_run = false; - EnsurePendingInputFitsContext(_embed_inps.Count); - EvaluateNextMtmdMediaSegment(nameof(InstructExecutor)); + EvaluatePendingMtmdMediaSegment(nameof(InstructExecutor), _embed_inps.Count); } - _embeds.Clear(); + ClearCurrentEmbeds(); if (!HasPendingPromptInput() && !args.WaitForInput) { diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index 6dc8c32e1..254f6d638 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -62,6 +62,7 @@ public override ExecutorBaseState GetStateData() LastTokens = _last_n_tokens.ToArray(), MatchingSessionTokensCount = _n_matching_session_tokens, PastTokensCount = _pastTokensCount, + KvTokenCount = _kvTokenCount, SessionFilePath = _pathSession, SessionTokens = _session_tokens.ToArray(), LastTokensCapacity = _last_n_tokens.Capacity, @@ -85,6 +86,7 @@ public override Task LoadState(ExecutorBaseState data, CancellationToken cancell _last_n_tokens = new FixedSizeQueue(state.LastTokensCapacity, state.LastTokens!); _n_matching_session_tokens = state.MatchingSessionTokensCount; _pastTokensCount = state.PastTokensCount; + _kvTokenCount = state.KvTokenCount == 0 ? state.PastTokensCount : state.KvTokenCount; _pathSession = state.SessionFilePath; _session_tokens = state.SessionTokens!.ToList(); } @@ -235,23 +237,7 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In tokensToKeep += Convert.ToInt32(Context.Vocab.ShouldAddBOS); // always keep the BOS token } - EnsurePendingInputFitsContext(tokensToKeep); - - if (!IsMultiModal) - { - TryReuseMatchingPrefix(); - } - - var result = await Context.DecodeAsync(_embeds, LLamaSeqId.Zero, batch, _pastTokensCount); - _pastTokensCount = result.Item3; - - if (result.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(result.Item1); - - if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession)) - { - _session_tokens.AddRange(_embeds); - _n_session_consumed = _session_tokens.Count; - } + await DecodeQueuedEmbedsAsync(batch, tokensToKeep, cancellationToken); } else if (IsMultiModal && HasPendingMtmdMediaSegment()) { @@ -265,11 +251,10 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In { tokensToKeep += Convert.ToInt32(Context.Vocab.ShouldAddBOS); } - EnsurePendingInputFitsContext(tokensToKeep); - EvaluateNextMtmdMediaSegment(nameof(InteractiveExecutor)); + EvaluatePendingMtmdMediaSegment(nameof(InteractiveExecutor), tokensToKeep); } - - _embeds.Clear(); + + ClearCurrentEmbeds(); if (!HasPendingPromptInput() && !args.WaitForInput) { diff --git a/LLama/Native/Load/NativeLibraryConfig.cs b/LLama/Native/Load/NativeLibraryConfig.cs index 16e3fd88b..993d7fec1 100644 --- a/LLama/Native/Load/NativeLibraryConfig.cs +++ b/LLama/Native/Load/NativeLibraryConfig.cs @@ -593,9 +593,14 @@ public NativeLibraryConfigContainer WithLogCallback(ILogger? logger) /// /// You can still modify the configuration after this calling but only before any call from . /// + /// The loaded llama native library descriptor, or null when loading failed. + /// The loaded mtmd native library descriptor, or null when loading failed. /// Whether the running is successful. public bool DryRun(out INativeLibrary? loadedLLamaNativeLibrary, out INativeLibrary? loadedMtmdNativeLibrary) { + loadedLLamaNativeLibrary = null; + loadedMtmdNativeLibrary = null; + bool success = true; foreach(var config in _configs) { @@ -613,7 +618,6 @@ public bool DryRun(out INativeLibrary? loadedLLamaNativeLibrary, out INativeLibr throw new Exception("Unknown native library config during the dry run."); } } - loadedLLamaNativeLibrary = loadedMtmdNativeLibrary = null; return success; } } diff --git a/docs/Examples/MtmdInteractiveModeExecute.md b/docs/Examples/MtmdInteractiveModeExecute.md index 378c93a1b..642695f6e 100644 --- a/docs/Examples/MtmdInteractiveModeExecute.md +++ b/docs/Examples/MtmdInteractiveModeExecute.md @@ -5,7 +5,7 @@ ## Workflow - Resolve the model, multimodal projection, and sample image paths via `UserSettings`. - Create `ModelParams` for the text model and capture the MTMD defaults with `MtmdContextParams.Default()`. -- Load the base model and context, then initialize `SafeMtmdWeights` with the multimodal projection file. +- Load the base model and context, then initialize `MtmdWeights` with the multimodal projection file. - Ask the helper for a media marker (`mtmdParameters.MediaMarker ?? NativeApi.MtmdDefaultMarker() ?? ""`) and feed it into an `InteractiveExecutor`. ```cs @@ -15,7 +15,7 @@ using var model = await LLamaWeights.LoadFromFileAsync(parameters); using var context = model.CreateContext(parameters); // Mtmd Init -using var clipModel = await SafeMtmdWeights.LoadFromFileAsync( +using var clipModel = await MtmdWeights.LoadFromFileAsync( multiModalProj, model, mtmdParameters); @@ -28,14 +28,14 @@ var ex = new InteractiveExecutor(context, clipModel); ``` ## Handling user input -- Prompts can include image paths wrapped in braces (for example `{c:/image.jpg}`); the loop searches for those markers with regular expressions. -- Every referenced file is loaded through `SafeMtmdWeights.LoadMedia`, producing `SafeMtmdEmbed` instances that are queued for the next tokenization call. -- When the user provides images, the executor clears its KV cache (`MemorySequenceRemove`) before replacing each brace-wrapped path in the prompt with the multimodal marker. -- The embeds collected for the current turn are copied into `ex.Embeds`, so the executor submits both the text prompt and the pending media to the helper before generation. +- Prompts can include media paths wrapped in double braces (for example `{{c:/image.jpg}}`); the loop searches for those markers with regular expressions. +- Every referenced file is loaded through `MtmdWeights.LoadMediaStandalone`, producing `SafeMtmdEmbed` instances owned by the current turn. +- The prompt placeholders are replaced with the multimodal marker before inference. +- The embeds collected for the current turn are copied into `ex.Embeds`, so the executor submits both the text prompt and the explicit media embeddings to the MTMD helper before generation. ## Running the sample 1. Ensure the model and projection paths returned by `UserSettings` exist locally. 2. Start the example (for instance from the examples host application) and observe the initial description printed to the console. -3. Type text normally, or reference new images by including their path inside braces. Type `/exit` to end the conversation. +3. Type text normally, or reference new images by including their path inside double braces. Type `/exit` to end the conversation. This walkthrough mirrors the logic in the sample so you can adapt it for your own multimodal workflows. diff --git a/docs/QuickStart.md b/docs/QuickStart.md index 03bf82a94..2cbc474c2 100644 --- a/docs/QuickStart.md +++ b/docs/QuickStart.md @@ -90,22 +90,23 @@ while (userInput != "exit") ``` -## Examples of chatting with LLaVA +## Example of chatting with MTMD -This example shows chatting with LLaVA to ask it to describe the picture. +This example shows chatting with an MTMD-enabled model to ask it to describe the picture. ![llava_demo](./media/llava_demo.gif) ```cs using System.Text.RegularExpressions; using LLama; using LLama.Common; +using LLama.Native; string multiModalProj = @""; -string modelPath = @""; +string modelPath = @""; string modelImage = @""; const int maxTokens = 1024; // The max tokens that could be generated. -var prompt = $"{{{modelImage}}}\nUSER:\nProvide a full description of the image.\nASSISTANT:\n"; +var prompt = $"{{{{{modelImage}}}}}\nUSER:\nProvide a full description of the image.\nASSISTANT:\n"; var parameters = new ModelParams(modelPath) { @@ -115,64 +116,44 @@ var parameters = new ModelParams(modelPath) using var model = LLamaWeights.LoadFromFile(parameters); using var context = model.CreateContext(parameters); -// Llava Init -using var clipModel = LLavaWeights.LoadFromFile(multiModalProj); +var mtmdParameters = MtmdContextParams.Default(); +using var clipModel = MtmdWeights.LoadFromFile(multiModalProj, model, mtmdParameters); +var mediaMarker = mtmdParameters.MediaMarker ?? NativeApi.MtmdDefaultMarker() ?? ""; var ex = new InteractiveExecutor(context, clipModel); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The executor has been enabled. In this example, the prompt is printed, the maximum tokens is set to {0} and the context size is {1}.", maxTokens, parameters.ContextSize); -Console.WriteLine("To send an image, enter its filename in curly braces, like this {c:/image.jpg}."); +Console.WriteLine("To send an image, enter its filename in double curly braces, like this {{c:/image.jpg}}."); var inferenceParams = new InferenceParams() { Temperature = 0.1f, AntiPrompts = new List { "\nUSER:" }, MaxTokens = maxTokens }; do { + var mediaPathsWithBraces = Regex.Matches(prompt, @"\{\{(.*?)\}\}").Select(m => m.Value).ToList(); + var mediaPaths = Regex.Matches(prompt, @"\{\{(.*?)\}\}").Select(m => m.Groups[1].Value).ToList(); - // Evaluate if we have images - // - var imageMatches = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value); - var imageCount = imageMatches.Count(); - var hasImages = imageCount > 0; - byte[][] imageBytes = null; - - if (hasImages) + var embeds = new List(); + try { - var imagePathsWithCurlyBraces = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value); - var imagePaths = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Groups[1].Value); - - try - { - imageBytes = imagePaths.Select(File.ReadAllBytes).ToArray(); - } - catch (IOException exception) - { - Console.ForegroundColor = ConsoleColor.Red; - Console.Write( - $"Could not load your {(imageCount == 1 ? "image" : "images")}:"); - Console.Write($"{exception.Message}"); - Console.ForegroundColor = ConsoleColor.Yellow; - Console.WriteLine("Please try again."); - break; - } - - - int index = 0; - foreach (var path in imagePathsWithCurlyBraces) - { - // First image replace to tag "); - else - prompt = prompt.Replace(path, ""); - } - Console.WriteLine(); - - - // Initialize Images in executor - // - ex.ImagePaths = imagePaths.ToList(); + foreach (var mediaPath in mediaPaths) + embeds.Add(clipModel.LoadMediaStandalone(mediaPath)); } + catch (IOException exception) + { + Console.ForegroundColor = ConsoleColor.Red; + Console.WriteLine($"Could not load your media: {exception.Message}"); + foreach (var embed in embeds) + embed.Dispose(); + break; + } + + foreach (var path in mediaPathsWithBraces) + prompt = prompt.Replace(path, mediaMarker, StringComparison.Ordinal); + + ex.Embeds.Clear(); + foreach (var embed in embeds) + ex.Embeds.Add(embed); Console.ForegroundColor = ConsoleColor.White; await foreach (var text in ex.InferAsync(prompt, inferenceParams)) diff --git a/docs/Tutorials/Executors.md b/docs/Tutorials/Executors.md index a0c96e800..d3e4c0fa6 100644 --- a/docs/Tutorials/Executors.md +++ b/docs/Tutorials/Executors.md @@ -16,21 +16,21 @@ public interface ILLamaExecutor /// public LLamaContext Context { get; } - // LLava Section + // Multimodal Section // /// - /// Identify if it's a multi-modal model and there is a image to process. + /// Identify if it's a multimodal executor. /// public bool IsMultiModal { get; } /// /// Multi-Modal Projections / Clip Model weights /// - public LLavaWeights? ClipModel { get; } + public MtmdWeights? ClipModel { get; } /// - /// List of images: List of images in byte array format. + /// Media embeddings queued for the next multimodal prompt. /// - public List Images { get; } + public List Embeds { get; } /// @@ -329,21 +329,23 @@ public record InferenceParams ## Save and load executor state -An executor also has its state, which can be saved and loaded. That means a lot when you want to support restore a previous session for the user in your application. +Text executors also have state, which can be saved and loaded. That matters when you want to restore a previous session for the user in your application. + +Multimodal executors do not support executor-state persistence or session files. `GetStateData`, `SaveState`, `LoadState`, `WithSessionFile`, and `SaveSessionFile` will throw `NotSupportedException` when `ClipModel` is attached. The following code shows how to use save and load executor state. ```cs -InteractiveExecutor executor = new InteractiveExecutor(model); +InteractiveExecutor executor = new InteractiveExecutor(context); // do some things... -executor.SaveState("executor.st"); +await executor.SaveState("executor.st"); var stateData = executor.GetStateData(); -InteractiveExecutor executor2 = new InteractiveExecutor(model); -executor2.LoadState(stateData); +InteractiveExecutor executor2 = new InteractiveExecutor(context); +await executor2.LoadState(stateData); // do some things... -InteractiveExecutor executor3 = new InteractiveExecutor(model); -executor3.LoadState("executor.st"); +InteractiveExecutor executor3 = new InteractiveExecutor(context); +await executor3.LoadState("executor.st"); // do some things... -``` \ No newline at end of file +``` diff --git a/docs/Tutorials/NativeLibraryConfig.md b/docs/Tutorials/NativeLibraryConfig.md index ff2681821..19247f41e 100644 --- a/docs/Tutorials/NativeLibraryConfig.md +++ b/docs/Tutorials/NativeLibraryConfig.md @@ -24,11 +24,11 @@ All you need to do is adding the following code to the very beginning of your co NativeLibraryConfig.All.WithLibrary(""); ``` -If you want to configure the loading for LLama library or llava library respectively, please call the following APIs. +If you want to configure the loading for the LLama library or the MTMD library respectively, please call the following APIs. ```cs NativeLibraryConfig.LLama.WithLibrary(""); -NativeLibraryConfig.LLava.WithLibrary(""); +NativeLibraryConfig.Mtmd.WithLibrary(""); ``` ### Automatically select one from multiple native library files diff --git a/docs/xmldocs/index.md b/docs/xmldocs/index.md index 3c3af0e90..eb477d1fc 100644 --- a/docs/xmldocs/index.md +++ b/docs/xmldocs/index.md @@ -24,7 +24,7 @@ [LLamaWeights](./llama.llamaweights.md) -[LLavaWeights](./llama.llavaweights.md) +[MtmdWeights](./llama.mtmdweights.md) [SessionState](./llama.sessionstate.md) @@ -152,8 +152,6 @@ [LLamaFtype](./llama.native.llamaftype.md) -[LLamaKvCacheViewSafeHandle](./llama.native.llamakvcacheviewsafehandle.md) - [LLamaLogitBias](./llama.native.llamalogitbias.md) [LLamaLogLevel](./llama.native.llamaloglevel.md) @@ -196,10 +194,10 @@ [LLamaVocabType](./llama.native.llamavocabtype.md) -[LLavaImageEmbed](./llama.native.llavaimageembed.md) - [LoraAdapter](./llama.native.loraadapter.md) +[MtmdContextParams](./llama.native.mtmdcontextparams.md) + [NativeApi](./llama.native.nativeapi.md) [NativeLibraryConfig](./llama.native.nativelibraryconfig.md) @@ -232,9 +230,13 @@ [SafeLLamaSamplerChainHandle](./llama.native.safellamasamplerchainhandle.md) -[SafeLlavaImageEmbedHandle](./llama.native.safellavaimageembedhandle.md) +[SafeMtmdEmbed](./llama.native.safemtmdembed.md) + +[SafeMtmdInputChunk](./llama.native.safemtmdinputchunk.md) + +[SafeMtmdInputChunks](./llama.native.safemtmdinputchunks.md) -[SafeLlavaModelHandle](./llama.native.safellavamodelhandle.md) +[SafeMtmdModelHandle](./llama.native.safemtmdmodelhandle.md) [SystemInfo](./llama.native.systeminfo.md) diff --git a/docs/xmldocs/llama.abstractions.illamaexecutor.md b/docs/xmldocs/llama.abstractions.illamaexecutor.md index d00a478de..8948f861b 100644 --- a/docs/xmldocs/llama.abstractions.illamaexecutor.md +++ b/docs/xmldocs/llama.abstractions.illamaexecutor.md @@ -45,24 +45,24 @@ public abstract bool IsMultiModal { get; } Multi-Modal Projections / Clip Model weights ```csharp -public abstract LLavaWeights ClipModel { get; } +public abstract MtmdWeights ClipModel { get; } ``` #### Property Value -[LLavaWeights](./llama.llavaweights.md)
+[MtmdWeights](./llama.mtmdweights.md)
-### **Images** +### **Embeds** -List of images: List of images in byte array format. +List of media: List of media for Multi-Modal models. ```csharp -public abstract List Images { get; } +public abstract List Embeds { get; } ``` #### Property Value -[List<Byte[]>](https://docs.microsoft.com/en-us/dotnet/api/system.collections.generic.list-1)
+[List<SafeMtmdEmbed>](https://docs.microsoft.com/en-us/dotnet/api/system.collections.generic.list-1)
## Methods diff --git a/docs/xmldocs/llama.batched.batchedexecutor.md b/docs/xmldocs/llama.batched.batchedexecutor.md index 2a90b6691..5a5a3515d 100644 --- a/docs/xmldocs/llama.batched.batchedexecutor.md +++ b/docs/xmldocs/llama.batched.batchedexecutor.md @@ -78,6 +78,16 @@ public bool IsDisposed { get; private set; } [Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+### **ClipModel** + +```csharp +public MtmdWeights ClipModel { get; } +``` + +#### Property Value + +[MtmdWeights](./llama.mtmdweights.md)
+ ## Constructors ### **BatchedExecutor(LLamaWeights, IContextParams)** @@ -96,6 +106,20 @@ The model to use `contextParams` [IContextParams](./llama.abstractions.icontextparams.md)
Parameters to create a new context +### **BatchedExecutor(LLamaWeights, IContextParams, MtmdWeights)** + +```csharp +public BatchedExecutor(LLamaWeights model, IContextParams contextParams, MtmdWeights clipModel) +``` + +#### Parameters + +`model` [LLamaWeights](./llama.llamaweights.md)
+ +`contextParams` [IContextParams](./llama.abstractions.icontextparams.md)
+ +`clipModel` [MtmdWeights](./llama.mtmdweights.md)
+ ## Methods ### **Create()** diff --git a/docs/xmldocs/llama.batched.conversation.md b/docs/xmldocs/llama.batched.conversation.md index 597cba0f2..4cd45ff1c 100644 --- a/docs/xmldocs/llama.batched.conversation.md +++ b/docs/xmldocs/llama.batched.conversation.md @@ -195,6 +195,40 @@ Thrown if this conversation was not prompted before the previous call to infer [CannotSampleRequiresInferenceException](./llama.batched.cannotsamplerequiresinferenceexception.md)
Thrown if Infer() must be called on the executor +### **Prompt(String, Boolean, Boolean)** + +```csharp +public void Prompt(string promptText, bool addBos, bool special) +``` + +#### Parameters + +`promptText` [String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
+ +`addBos` [Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +`special` [Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +### **Prompt(String, ReadOnlySpan<SafeMtmdEmbed>, Boolean)** + +Prompt this conversation with explicit multimodal embeddings. + The caller retains ownership of `embeds`. + +```csharp +public void Prompt(string promptText, ReadOnlySpan embeds, bool addBos) +``` + +#### Parameters + +`promptText` [String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
+Prompt text for the model. + +`embeds` [ReadOnlySpan<SafeMtmdEmbed>](https://docs.microsoft.com/en-us/dotnet/api/system.readonlyspan-1)
+Media embeddings to include in the multimodal prompt. + +`addBos` [Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+Whether to add the BOS token. + ### **Prompt(List<LLamaToken>, Boolean)** Add tokens to this conversation @@ -255,18 +289,6 @@ public void Prompt(LLamaToken token) [AlreadyPromptedConversationException](./llama.batched.alreadypromptedconversationexception.md)
-### **Prompt(SafeLlavaImageEmbedHandle)** - -Prompt this conversation with an image embedding - -```csharp -public void Prompt(SafeLlavaImageEmbedHandle embedding) -``` - -#### Parameters - -`embedding` [SafeLlavaImageEmbedHandle](./llama.native.safellavaimageembedhandle.md)
- ### **Prompt(ReadOnlySpan<Single>)** Prompt this conversation with embeddings diff --git a/docs/xmldocs/llama.instructexecutor.md b/docs/xmldocs/llama.instructexecutor.md index 77a3d5ac5..3041f1747 100644 --- a/docs/xmldocs/llama.instructexecutor.md +++ b/docs/xmldocs/llama.instructexecutor.md @@ -28,12 +28,22 @@ protected ILogger _logger; ### **_pastTokensCount** -The tokens that were already processed by the model. +The positional cursor that has already been processed by the model. ```csharp protected int _pastTokensCount; ``` +### **_kvTokenCount** + +Number of KV slots currently occupied by the active sequence. + For MTMD prompts this can diverge from because media chunks may + consume a different number of decoded tokens than logical positions. + +```csharp +protected int _kvTokenCount; +``` + ### **_consumedTokensCount** The tokens that were consumed by the model during the current inference. @@ -44,7 +54,7 @@ protected int _consumedTokensCount; ### **_n_session_consumed** - +Number of tokens consumed from the session cache during the current run. ```csharp protected int _n_session_consumed; @@ -52,7 +62,7 @@ protected int _n_session_consumed; ### **_n_matching_session_tokens** - +Number of prompt tokens that match the loaded session cache prefix. ```csharp protected int _n_matching_session_tokens; @@ -84,7 +94,7 @@ protected List _embed_inps; ### **_session_tokens** - +Tokens recovered from the session file and reused to warm up the KV cache. ```csharp protected List _session_tokens; @@ -112,6 +122,18 @@ public LLamaContext Context { get; } [LLamaContext](./llama.llamacontext.md)
+### **AntipromptProcessor** + +Tracks anti-prompts across streamed output. + +```csharp +protected AntipromptProcessor AntipromptProcessor { get; } +``` + +#### Property Value + +[AntipromptProcessor](./llama.antipromptprocessor.md)
+ ### **IsMultiModal** ```csharp @@ -125,22 +147,22 @@ public bool IsMultiModal { get; } ### **ClipModel** ```csharp -public LLavaWeights ClipModel { get; } +public MtmdWeights ClipModel { get; } ``` #### Property Value -[LLavaWeights](./llama.llavaweights.md)
+[MtmdWeights](./llama.mtmdweights.md)
-### **Images** +### **Embeds** ```csharp -public List Images { get; } +public List Embeds { get; } ``` #### Property Value -[List<Byte[]>](https://docs.microsoft.com/en-us/dotnet/api/system.collections.generic.list-1)
+[List<SafeMtmdEmbed>](https://docs.microsoft.com/en-us/dotnet/api/system.collections.generic.list-1)
## Constructors @@ -162,6 +184,24 @@ public InstructExecutor(LLamaContext context, string instructionPrefix, string i `logger` ILogger
+### **InstructExecutor(LLamaContext, MtmdWeights, String, String, ILogger)** + +```csharp +public InstructExecutor(LLamaContext context, MtmdWeights clipModel, string instructionPrefix, string instructionSuffix, ILogger logger) +``` + +#### Parameters + +`context` [LLamaContext](./llama.llamacontext.md)
+ +`clipModel` [MtmdWeights](./llama.mtmdweights.md)
+ +`instructionPrefix` [String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
+ +`instructionSuffix` [String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
+ +`logger` ILogger
+ ## Methods ### **GetStateData()** @@ -174,66 +214,74 @@ public ExecutorBaseState GetStateData() [ExecutorBaseState](./llama.statefulexecutorbase.executorbasestate.md)
-### **LoadState(ExecutorBaseState)** +### **LoadState(ExecutorBaseState, CancellationToken)** ```csharp -public Task LoadState(ExecutorBaseState data) +public Task LoadState(ExecutorBaseState data, CancellationToken cancellationToken) ``` #### Parameters `data` [ExecutorBaseState](./llama.statefulexecutorbase.executorbasestate.md)
+`cancellationToken` [CancellationToken](https://docs.microsoft.com/en-us/dotnet/api/system.threading.cancellationtoken)
+ #### Returns [Task](https://docs.microsoft.com/en-us/dotnet/api/system.threading.tasks.task)
-### **SaveState(String)** +### **SaveState(String, CancellationToken)** ```csharp -public Task SaveState(string filename) +public Task SaveState(string filename, CancellationToken cancellationToken) ``` #### Parameters `filename` [String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
+`cancellationToken` [CancellationToken](https://docs.microsoft.com/en-us/dotnet/api/system.threading.cancellationtoken)
+ #### Returns [Task](https://docs.microsoft.com/en-us/dotnet/api/system.threading.tasks.task)
-### **LoadState(String)** +### **LoadState(String, CancellationToken)** ```csharp -public Task LoadState(string filename) +public Task LoadState(string filename, CancellationToken cancellationToken) ``` #### Parameters `filename` [String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
+`cancellationToken` [CancellationToken](https://docs.microsoft.com/en-us/dotnet/api/system.threading.cancellationtoken)
+ #### Returns [Task](https://docs.microsoft.com/en-us/dotnet/api/system.threading.tasks.task)
-### **GetLoopCondition(InferStateArgs)** +### **GetLoopCondition(InferStateArgs, CancellationToken)** ```csharp -protected Task GetLoopCondition(InferStateArgs args) +protected Task GetLoopCondition(InferStateArgs args, CancellationToken cancellationToken) ``` #### Parameters `args` [InferStateArgs](./llama.statefulexecutorbase.inferstateargs.md)
+`cancellationToken` [CancellationToken](https://docs.microsoft.com/en-us/dotnet/api/system.threading.cancellationtoken)
+ #### Returns [Task<Boolean>](https://docs.microsoft.com/en-us/dotnet/api/system.threading.tasks.task-1)
-### **PreprocessInputs(String, InferStateArgs)** +### **PreprocessInputs(String, InferStateArgs, CancellationToken)** ```csharp -protected Task PreprocessInputs(string text, InferStateArgs args) +protected Task PreprocessInputs(string text, InferStateArgs args, CancellationToken cancellationToken) ``` #### Parameters @@ -242,14 +290,16 @@ protected Task PreprocessInputs(string text, InferStateArgs args) `args` [InferStateArgs](./llama.statefulexecutorbase.inferstateargs.md)
+`cancellationToken` [CancellationToken](https://docs.microsoft.com/en-us/dotnet/api/system.threading.cancellationtoken)
+ #### Returns [Task](https://docs.microsoft.com/en-us/dotnet/api/system.threading.tasks.task)
-### **PostProcess(IInferenceParams, InferStateArgs)** +### **PostProcess(IInferenceParams, InferStateArgs, CancellationToken)** ```csharp -protected Task>> PostProcess(IInferenceParams inferenceParams, InferStateArgs args) +protected Task>> PostProcess(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken) ``` #### Parameters @@ -258,14 +308,16 @@ protected Task>> PostProcess(IInferencePa `args` [InferStateArgs](./llama.statefulexecutorbase.inferstateargs.md)
+`cancellationToken` [CancellationToken](https://docs.microsoft.com/en-us/dotnet/api/system.threading.cancellationtoken)
+ #### Returns [Task<ValueTuple<Boolean, IReadOnlyList<String>>>](https://docs.microsoft.com/en-us/dotnet/api/system.threading.tasks.task-1)
-### **InferInternal(IInferenceParams, InferStateArgs)** +### **InferInternal(IInferenceParams, InferStateArgs, CancellationToken)** ```csharp -protected Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args) +protected Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken) ``` #### Parameters @@ -274,6 +326,8 @@ protected Task InferInternal(IInferenceParams inferenceParams, InferStateArgs ar `args` [InferStateArgs](./llama.statefulexecutorbase.inferstateargs.md)
+`cancellationToken` [CancellationToken](https://docs.microsoft.com/en-us/dotnet/api/system.threading.cancellationtoken)
+ #### Returns [Task](https://docs.microsoft.com/en-us/dotnet/api/system.threading.tasks.task)
diff --git a/docs/xmldocs/llama.interactiveexecutor.md b/docs/xmldocs/llama.interactiveexecutor.md index 8230e03c2..df3f293fb 100644 --- a/docs/xmldocs/llama.interactiveexecutor.md +++ b/docs/xmldocs/llama.interactiveexecutor.md @@ -28,12 +28,22 @@ protected ILogger _logger; ### **_pastTokensCount** -The tokens that were already processed by the model. +The positional cursor that has already been processed by the model. ```csharp protected int _pastTokensCount; ``` +### **_kvTokenCount** + +Number of KV slots currently occupied by the active sequence. + For MTMD prompts this can diverge from because media chunks may + consume a different number of decoded tokens than logical positions. + +```csharp +protected int _kvTokenCount; +``` + ### **_consumedTokensCount** The tokens that were consumed by the model during the current inference. @@ -44,7 +54,7 @@ protected int _consumedTokensCount; ### **_n_session_consumed** - +Number of tokens consumed from the session cache during the current run. ```csharp protected int _n_session_consumed; @@ -52,7 +62,7 @@ protected int _n_session_consumed; ### **_n_matching_session_tokens** - +Number of prompt tokens that match the loaded session cache prefix. ```csharp protected int _n_matching_session_tokens; @@ -84,7 +94,7 @@ protected List _embed_inps; ### **_session_tokens** - +Tokens recovered from the session file and reused to warm up the KV cache. ```csharp protected List _session_tokens; @@ -112,6 +122,18 @@ public LLamaContext Context { get; } [LLamaContext](./llama.llamacontext.md)
+### **AntipromptProcessor** + +Tracks anti-prompts across streamed output. + +```csharp +protected AntipromptProcessor AntipromptProcessor { get; } +``` + +#### Property Value + +[AntipromptProcessor](./llama.antipromptprocessor.md)
+ ### **IsMultiModal** ```csharp @@ -125,28 +147,28 @@ public bool IsMultiModal { get; } ### **ClipModel** ```csharp -public LLavaWeights ClipModel { get; } +public MtmdWeights ClipModel { get; } ``` #### Property Value -[LLavaWeights](./llama.llavaweights.md)
+[MtmdWeights](./llama.mtmdweights.md)
-### **Images** +### **Embeds** ```csharp -public List Images { get; } +public List Embeds { get; } ``` #### Property Value -[List<Byte[]>](https://docs.microsoft.com/en-us/dotnet/api/system.collections.generic.list-1)
+[List<SafeMtmdEmbed>](https://docs.microsoft.com/en-us/dotnet/api/system.collections.generic.list-1)
## Constructors ### **InteractiveExecutor(LLamaContext, ILogger)** - +Create an interactive executor for text-only inference. ```csharp public InteractiveExecutor(LLamaContext context, ILogger logger) @@ -155,24 +177,29 @@ public InteractiveExecutor(LLamaContext context, ILogger logger) #### Parameters `context` [LLamaContext](./llama.llamacontext.md)
+LLama context to operate against. `logger` ILogger
+Optional logger for diagnostic output. -### **InteractiveExecutor(LLamaContext, LLavaWeights, ILogger)** - +### **InteractiveExecutor(LLamaContext, MtmdWeights, ILogger)** +Create an interactive multimodal executor that can process text alongside media inputs. ```csharp -public InteractiveExecutor(LLamaContext context, LLavaWeights clipModel, ILogger logger) +public InteractiveExecutor(LLamaContext context, MtmdWeights clipModel, ILogger logger) ``` #### Parameters `context` [LLamaContext](./llama.llamacontext.md)
+LLama context to operate against. -`clipModel` [LLavaWeights](./llama.llavaweights.md)
+`clipModel` [MtmdWeights](./llama.mtmdweights.md)
+Multimodal weights (MTMD) to attach to the executor. `logger` ILogger
+Optional logger for diagnostic output. ## Methods @@ -186,102 +213,123 @@ public ExecutorBaseState GetStateData() [ExecutorBaseState](./llama.statefulexecutorbase.executorbasestate.md)
-### **LoadState(ExecutorBaseState)** +### **LoadState(ExecutorBaseState, CancellationToken)** ```csharp -public Task LoadState(ExecutorBaseState data) +public Task LoadState(ExecutorBaseState data, CancellationToken cancellationToken) ``` #### Parameters `data` [ExecutorBaseState](./llama.statefulexecutorbase.executorbasestate.md)
+`cancellationToken` [CancellationToken](https://docs.microsoft.com/en-us/dotnet/api/system.threading.cancellationtoken)
+ #### Returns [Task](https://docs.microsoft.com/en-us/dotnet/api/system.threading.tasks.task)
-### **SaveState(String)** +### **SaveState(String, CancellationToken)** ```csharp -public Task SaveState(string filename) +public Task SaveState(string filename, CancellationToken cancellationToken) ``` #### Parameters `filename` [String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
+`cancellationToken` [CancellationToken](https://docs.microsoft.com/en-us/dotnet/api/system.threading.cancellationtoken)
+ #### Returns [Task](https://docs.microsoft.com/en-us/dotnet/api/system.threading.tasks.task)
-### **LoadState(String)** +### **LoadState(String, CancellationToken)** ```csharp -public Task LoadState(string filename) +public Task LoadState(string filename, CancellationToken cancellationToken) ``` #### Parameters `filename` [String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
+`cancellationToken` [CancellationToken](https://docs.microsoft.com/en-us/dotnet/api/system.threading.cancellationtoken)
+ #### Returns [Task](https://docs.microsoft.com/en-us/dotnet/api/system.threading.tasks.task)
-### **GetLoopCondition(InferStateArgs)** +### **GetLoopCondition(InferStateArgs, CancellationToken)** -Define whether to continue the loop to generate responses. +Decide whether generation should continue for the current iteration. ```csharp -protected Task GetLoopCondition(InferStateArgs args) +protected Task GetLoopCondition(InferStateArgs args, CancellationToken cancellationToken) ``` #### Parameters `args` [InferStateArgs](./llama.statefulexecutorbase.inferstateargs.md)
+Mutable inference state. + +`cancellationToken` [CancellationToken](https://docs.microsoft.com/en-us/dotnet/api/system.threading.cancellationtoken)
#### Returns [Task<Boolean>](https://docs.microsoft.com/en-us/dotnet/api/system.threading.tasks.task-1)
+`true` to keep generating; otherwise `false`. + +### **PreprocessInputs(String, InferStateArgs, CancellationToken)** -### **PreprocessInputs(String, InferStateArgs)** +Preprocess the incoming prompt or continuation text before inference. ```csharp -protected Task PreprocessInputs(string text, InferStateArgs args) +protected Task PreprocessInputs(string text, InferStateArgs args, CancellationToken cancellationToken) ``` #### Parameters `text` [String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
+Prompt text or continuation provided by the caller. `args` [InferStateArgs](./llama.statefulexecutorbase.inferstateargs.md)
+Mutable inference state. + +`cancellationToken` [CancellationToken](https://docs.microsoft.com/en-us/dotnet/api/system.threading.cancellationtoken)
#### Returns [Task](https://docs.microsoft.com/en-us/dotnet/api/system.threading.tasks.task)
-### **PostProcess(IInferenceParams, InferStateArgs)** +### **PostProcess(IInferenceParams, InferStateArgs, CancellationToken)** -Return whether to break the generation. +Decide whether generation should stop based on antiprompts, token limits, or end-of-generation markers. ```csharp -protected Task>> PostProcess(IInferenceParams inferenceParams, InferStateArgs args) +protected Task>> PostProcess(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken) ``` #### Parameters `inferenceParams` [IInferenceParams](./llama.abstractions.iinferenceparams.md)
+Sampling parameters controlling generation. `args` [InferStateArgs](./llama.statefulexecutorbase.inferstateargs.md)
+Mutable inference state. + +`cancellationToken` [CancellationToken](https://docs.microsoft.com/en-us/dotnet/api/system.threading.cancellationtoken)
#### Returns [Task<ValueTuple<Boolean, IReadOnlyList<String>>>](https://docs.microsoft.com/en-us/dotnet/api/system.threading.tasks.task-1)
+Tuple describing whether to stop and any additional outputs to emit. -### **InferInternal(IInferenceParams, InferStateArgs)** +### **InferInternal(IInferenceParams, InferStateArgs, CancellationToken)** ```csharp -protected Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args) +protected Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken) ``` #### Parameters @@ -290,6 +338,8 @@ protected Task InferInternal(IInferenceParams inferenceParams, InferStateArgs ar `args` [InferStateArgs](./llama.statefulexecutorbase.inferstateargs.md)
+`cancellationToken` [CancellationToken](https://docs.microsoft.com/en-us/dotnet/api/system.threading.cancellationtoken)
+ #### Returns [Task](https://docs.microsoft.com/en-us/dotnet/api/system.threading.tasks.task)
diff --git a/docs/xmldocs/llama.mtmdweights.md b/docs/xmldocs/llama.mtmdweights.md new file mode 100644 index 000000000..816654e2e --- /dev/null +++ b/docs/xmldocs/llama.mtmdweights.md @@ -0,0 +1,348 @@ +[`< Back`](./) + +--- + +# MtmdWeights + +Namespace: LLama + +Lightweight wrapper around the MTMD native context and its helpers. + +```csharp +public sealed class MtmdWeights : System.IDisposable +``` + +Inheritance [Object](https://docs.microsoft.com/en-us/dotnet/api/system.object) → [MtmdWeights](./llama.mtmdweights.md)
+Implements [IDisposable](https://docs.microsoft.com/en-us/dotnet/api/system.idisposable)
+Attributes [NullableContextAttribute](https://docs.microsoft.com/en-us/dotnet/api/system.runtime.compilerservices.nullablecontextattribute), [NullableAttribute](https://docs.microsoft.com/en-us/dotnet/api/system.runtime.compilerservices.nullableattribute) + +## Properties + +### **NativeHandle** + +The native handle, which is used in the native APIs + +```csharp +public SafeMtmdModelHandle NativeHandle { get; } +``` + +#### Property Value + +[SafeMtmdModelHandle](./llama.native.safemtmdmodelhandle.md)
+ +**Remarks:** + +Be careful how you use this! + +### **SupportsVision** + +Indicates whether the model supports vision inputs. + +```csharp +public bool SupportsVision { get; } +``` + +#### Property Value + +[Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +### **SupportsAudio** + +Indicates whether the model supports audio inputs. + +```csharp +public bool SupportsAudio { get; } +``` + +#### Property Value + +[Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +### **UsesNonCausalAttention** + +Indicates whether the model decodes using the non-causal path. + +```csharp +public bool UsesNonCausalAttention { get; } +``` + +#### Property Value + +[Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +### **UsesMRope** + +Indicates whether the model decodes using multi-scale RoPE. + +```csharp +public bool UsesMRope { get; } +``` + +#### Property Value + +[Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +### **AudioBitrate** + +Gets the audio bitrate advertised by the model. + +```csharp +public int AudioBitrate { get; } +``` + +#### Property Value + +[Int32](https://docs.microsoft.com/en-us/dotnet/api/system.int32)
+ +## Methods + +### **LoadFromFile(String, LLamaWeights, MtmdContextParams)** + +Load weights into memory + +```csharp +public static MtmdWeights LoadFromFile(string mmProject, LLamaWeights textModel, MtmdContextParams mtmdCtxParams) +``` + +#### Parameters + +`mmProject` [String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
+Path to the mmproj file + +`textModel` [LLamaWeights](./llama.llamaweights.md)
+The text model + +`mtmdCtxParams` [MtmdContextParams](./llama.native.mtmdcontextparams.md)
+Parameters for MTMD context creation + +#### Returns + +[MtmdWeights](./llama.mtmdweights.md)
+ +### **LoadFromFileAsync(String, LLamaWeights, MtmdContextParams, CancellationToken)** + +Load weights into memory + +```csharp +public static Task LoadFromFileAsync(string mmProject, LLamaWeights textModel, MtmdContextParams mtmdCtxParams, CancellationToken token) +``` + +#### Parameters + +`mmProject` [String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
+Path to the mmproj file + +`textModel` [LLamaWeights](./llama.llamaweights.md)
+The text model + +`mtmdCtxParams` [MtmdContextParams](./llama.native.mtmdcontextparams.md)
+Parameters for MTMD context creation + +`token` [CancellationToken](https://docs.microsoft.com/en-us/dotnet/api/system.threading.cancellationtoken)
+ +#### Returns + +[Task<MtmdWeights>](https://docs.microsoft.com/en-us/dotnet/api/system.threading.tasks.task-1)
+ +### **LoadMedia(String)** + +Load media from disk and keep it pending for the next tokenize call. + +```csharp +public SafeMtmdEmbed LoadMedia(string path) +``` + +#### Parameters + +`path` [String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
+ +#### Returns + +[SafeMtmdEmbed](./llama.native.safemtmdembed.md)
+ +### **LoadMedia(ReadOnlySpan<Byte>)** + +Load media from an in-memory buffer and keep it pending for the next tokenize call. + +```csharp +public SafeMtmdEmbed LoadMedia(ReadOnlySpan data) +``` + +#### Parameters + +`data` [ReadOnlySpan<Byte>](https://docs.microsoft.com/en-us/dotnet/api/system.readonlyspan-1)
+ +#### Returns + +[SafeMtmdEmbed](./llama.native.safemtmdembed.md)
+ +### **LoadMediaStandalone(String)** + +Load media from disk as a standalone embedding without touching the shared pending-media queue. + +```csharp +public SafeMtmdEmbed LoadMediaStandalone(string path) +``` + +#### Parameters + +`path` [String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
+ +#### Returns + +[SafeMtmdEmbed](./llama.native.safemtmdembed.md)
+ +### **LoadMediaStandalone(ReadOnlySpan<Byte>)** + +Load media from an in-memory buffer as a standalone embedding without touching the shared pending-media queue. + +```csharp +public SafeMtmdEmbed LoadMediaStandalone(ReadOnlySpan data) +``` + +#### Parameters + +`data` [ReadOnlySpan<Byte>](https://docs.microsoft.com/en-us/dotnet/api/system.readonlyspan-1)
+ +#### Returns + +[SafeMtmdEmbed](./llama.native.safemtmdembed.md)
+ +### **ClearMedia()** + +Clear any pending media buffers before or after tokenization. + +```csharp +public void ClearMedia() +``` + +### **Tokenize(String, Boolean, Boolean, SafeMtmdInputChunks&)** + +Tokenize text (with optional special tokens) against the pending media buffers. + +```csharp +public int Tokenize(string text, bool addSpecial, bool parseSpecial, SafeMtmdInputChunks& chunks) +``` + +#### Parameters + +`text` [String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
+ +`addSpecial` [Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +`parseSpecial` [Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +`chunks` [SafeMtmdInputChunks&](./llama.native.safemtmdinputchunks&.md)
+ +#### Returns + +[Int32](https://docs.microsoft.com/en-us/dotnet/api/system.int32)
+ +### **Tokenize(String, Boolean, Boolean, ReadOnlySpan<SafeMtmdEmbed>, SafeMtmdInputChunks&)** + +Tokenize text (with optional special tokens) against explicit media embeddings. + The caller retains ownership of `embeds`. + +```csharp +public int Tokenize(string text, bool addSpecial, bool parseSpecial, ReadOnlySpan embeds, SafeMtmdInputChunks& chunks) +``` + +#### Parameters + +`text` [String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
+ +`addSpecial` [Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +`parseSpecial` [Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +`embeds` [ReadOnlySpan<SafeMtmdEmbed>](https://docs.microsoft.com/en-us/dotnet/api/system.readonlyspan-1)
+ +`chunks` [SafeMtmdInputChunks&](./llama.native.safemtmdinputchunks&.md)
+ +#### Returns + +[Int32](https://docs.microsoft.com/en-us/dotnet/api/system.int32)
+ +### **EvaluateChunks(SafeMtmdInputChunks, SafeLLamaContextHandle, Int32&, Int32, Int32, Boolean)** + +Evaluate a chunk batch using the helper that performs mtmd encode + llama decode. + +```csharp +public int EvaluateChunks(SafeMtmdInputChunks chunks, SafeLLamaContextHandle llamaContext, Int32& nPast, int seqId, int nBatch, bool logitsLast) +``` + +#### Parameters + +`chunks` [SafeMtmdInputChunks](./llama.native.safemtmdinputchunks.md)
+ +`llamaContext` [SafeLLamaContextHandle](./llama.native.safellamacontexthandle.md)
+ +`nPast` [Int32&](https://docs.microsoft.com/en-us/dotnet/api/system.int32&)
+ +`seqId` [Int32](https://docs.microsoft.com/en-us/dotnet/api/system.int32)
+ +`nBatch` [Int32](https://docs.microsoft.com/en-us/dotnet/api/system.int32)
+ +`logitsLast` [Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +#### Returns + +[Int32](https://docs.microsoft.com/en-us/dotnet/api/system.int32)
+ +### **EvaluateChunk(IntPtr, SafeLLamaContextHandle, Int32&, Int32, Int32, Boolean)** + +```csharp +public int EvaluateChunk(IntPtr chunkPtr, SafeLLamaContextHandle llamaContext, Int32& nPast, int seqId, int nBatch, bool logitsLast) +``` + +#### Parameters + +`chunkPtr` [IntPtr](https://docs.microsoft.com/en-us/dotnet/api/system.intptr)
+ +`llamaContext` [SafeLLamaContextHandle](./llama.native.safellamacontexthandle.md)
+ +`nPast` [Int32&](https://docs.microsoft.com/en-us/dotnet/api/system.int32&)
+ +`seqId` [Int32](https://docs.microsoft.com/en-us/dotnet/api/system.int32)
+ +`nBatch` [Int32](https://docs.microsoft.com/en-us/dotnet/api/system.int32)
+ +`logitsLast` [Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +#### Returns + +[Int32](https://docs.microsoft.com/en-us/dotnet/api/system.int32)
+ +### **DecodeImageChunk(IntPtr, SafeLLamaContextHandle, IntPtr, Int32&, Int32, Int32)** + +```csharp +public int DecodeImageChunk(IntPtr chunkPtr, SafeLLamaContextHandle llamaContext, IntPtr encodedEmbeddings, Int32& nPast, int seqId, int nBatch) +``` + +#### Parameters + +`chunkPtr` [IntPtr](https://docs.microsoft.com/en-us/dotnet/api/system.intptr)
+ +`llamaContext` [SafeLLamaContextHandle](./llama.native.safellamacontexthandle.md)
+ +`encodedEmbeddings` [IntPtr](https://docs.microsoft.com/en-us/dotnet/api/system.intptr)
+ +`nPast` [Int32&](https://docs.microsoft.com/en-us/dotnet/api/system.int32&)
+ +`seqId` [Int32](https://docs.microsoft.com/en-us/dotnet/api/system.int32)
+ +`nBatch` [Int32](https://docs.microsoft.com/en-us/dotnet/api/system.int32)
+ +#### Returns + +[Int32](https://docs.microsoft.com/en-us/dotnet/api/system.int32)
+ +### **Dispose()** + +```csharp +public void Dispose() +``` + +--- + +[`< Back`](./) diff --git a/docs/xmldocs/llama.native.mtmdcontextparams.md b/docs/xmldocs/llama.native.mtmdcontextparams.md new file mode 100644 index 000000000..19bcee5ba --- /dev/null +++ b/docs/xmldocs/llama.native.mtmdcontextparams.md @@ -0,0 +1,152 @@ +[`< Back`](./) + +--- + +# MtmdContextParams + +Namespace: LLama.Native + +Managed representation of the native `mtmd_context_params` structure used to configure multimodal helpers. + +```csharp +public class MtmdContextParams +``` + +Inheritance [Object](https://docs.microsoft.com/en-us/dotnet/api/system.object) → [MtmdContextParams](./llama.native.mtmdcontextparams.md)
+Attributes [NullableContextAttribute](https://docs.microsoft.com/en-us/dotnet/api/system.runtime.compilerservices.nullablecontextattribute), [NullableAttribute](https://docs.microsoft.com/en-us/dotnet/api/system.runtime.compilerservices.nullableattribute) + +## Properties + +### **UseGpu** + +Whether GPU acceleration should be requested when available. + +```csharp +public bool UseGpu { get; set; } +``` + +#### Property Value + +[Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +### **PrintTimings** + +Whether timing information should be emitted by the native helper. + +```csharp +public bool PrintTimings { get; set; } +``` + +#### Property Value + +[Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +### **NThreads** + +Number of worker threads to dedicate to preprocessing and tokenization. + +```csharp +public int NThreads { get; set; } +``` + +#### Property Value + +[Int32](https://docs.microsoft.com/en-us/dotnet/api/system.int32)
+ +### **ImageMarker** + +Marker token inserted into the text stream to reference an image embedding (deprecated by mtmd). + +```csharp +public string ImageMarker { get; set; } +``` + +#### Property Value + +[String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
+ +### **MediaMarker** + +Marker token inserted into the text stream to reference a generic media embedding. + +```csharp +public string MediaMarker { get; set; } +``` + +#### Property Value + +[String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
+ +### **FlashAttentionType** + +Flash attention policy forwarded to mtmd encoders. + +```csharp +public LLamaFlashAttentionType FlashAttentionType { get; set; } +``` + +#### Property Value + +[LLamaFlashAttentionType](./llama.native.llamaflashattentiontype.md)
+ +### **Warmup** + +Whether to run a warmup encode pass after initialization. + +```csharp +public bool Warmup { get; set; } +``` + +#### Property Value + +[Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +### **ImageMinTokens** + +Minimum number of image tokens for dynamic resolution (use -1 to read metadata). + +```csharp +public int ImageMinTokens { get; set; } +``` + +#### Property Value + +[Int32](https://docs.microsoft.com/en-us/dotnet/api/system.int32)
+ +### **ImageMaxTokens** + +Maximum number of image tokens for dynamic resolution (use -1 to read metadata). + +```csharp +public int ImageMaxTokens { get; set; } +``` + +#### Property Value + +[Int32](https://docs.microsoft.com/en-us/dotnet/api/system.int32)
+ +## Constructors + +### **MtmdContextParams()** + +```csharp +public MtmdContextParams() +``` + +## Methods + +### **Default()** + +Create a managed copy of the native defaults returned by . + +```csharp +public static MtmdContextParams Default() +``` + +#### Returns + +[MtmdContextParams](./llama.native.mtmdcontextparams.md)
+ +--- + +[`< Back`](./) diff --git a/docs/xmldocs/llama.native.nativelibraryconfigcontainer.md b/docs/xmldocs/llama.native.nativelibraryconfigcontainer.md index c3d607853..683148f9c 100644 --- a/docs/xmldocs/llama.native.nativelibraryconfigcontainer.md +++ b/docs/xmldocs/llama.native.nativelibraryconfigcontainer.md @@ -35,7 +35,7 @@ Load a specified native library as backend for LLamaSharp. When this method is called, all the other configurations will be ignored. ```csharp -public NativeLibraryConfigContainer WithLibrary(string llamaPath, string llavaPath) +public NativeLibraryConfigContainer WithLibrary(string llamaPath, string mtmdPath) ``` #### Parameters @@ -43,8 +43,8 @@ public NativeLibraryConfigContainer WithLibrary(string llamaPath, string llavaPa `llamaPath` [String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
The full path to the llama library to load. -`llavaPath` [String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
-The full path to the llava library to load. +`mtmdPath` [String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
+The full path to the mtmd library to load. #### Returns @@ -99,7 +99,7 @@ Thrown if `LibraryHasLoaded` is true. ### **WithAvx(AvxLevel)** -Configure the prefferred avx support level of the backend. +Configure the preferred AVX support level of the backend. ```csharp public NativeLibraryConfigContainer WithAvx(AvxLevel level) @@ -263,14 +263,16 @@ Try to load the native library with the current configurations, You can still modify the configuration after this calling but only before any call from [NativeApi](./llama.native.nativeapi.md). ```csharp -public bool DryRun(INativeLibrary& loadedLLamaNativeLibrary, INativeLibrary& loadedLLavaNativeLibrary) +public bool DryRun(INativeLibrary& loadedLLamaNativeLibrary, INativeLibrary& loadedMtmdNativeLibrary) ``` #### Parameters `loadedLLamaNativeLibrary` [INativeLibrary&](./llama.abstractions.inativelibrary&.md)
+The loaded llama native library descriptor, or null when loading failed. -`loadedLLavaNativeLibrary` [INativeLibrary&](./llama.abstractions.inativelibrary&.md)
+`loadedMtmdNativeLibrary` [INativeLibrary&](./llama.abstractions.inativelibrary&.md)
+The loaded mtmd native library descriptor, or null when loading failed. #### Returns diff --git a/docs/xmldocs/llama.native.safemtmdembed.md b/docs/xmldocs/llama.native.safemtmdembed.md new file mode 100644 index 000000000..1445b2e33 --- /dev/null +++ b/docs/xmldocs/llama.native.safemtmdembed.md @@ -0,0 +1,251 @@ +[`< Back`](./) + +--- + +# SafeMtmdEmbed + +Namespace: LLama.Native + +Managed wrapper around `mtmd_bitmap*` resources. Instances own the native pointer + and ensure proper cleanup when disposed. + +```csharp +public sealed class SafeMtmdEmbed : SafeLLamaHandleBase, System.IDisposable +``` + +Inheritance [Object](https://docs.microsoft.com/en-us/dotnet/api/system.object) → [CriticalFinalizerObject](https://docs.microsoft.com/en-us/dotnet/api/system.runtime.constrainedexecution.criticalfinalizerobject) → [SafeHandle](https://docs.microsoft.com/en-us/dotnet/api/system.runtime.interopservices.safehandle) → [SafeLLamaHandleBase](./llama.native.safellamahandlebase.md) → [SafeMtmdEmbed](./llama.native.safemtmdembed.md)
+Implements [IDisposable](https://docs.microsoft.com/en-us/dotnet/api/system.idisposable) + +## Fields + +### **handle** + +```csharp +protected IntPtr handle; +``` + +## Properties + +### **Nx** + +Width of the bitmap in pixels (or number of samples for audio embeddings). + +```csharp +public uint Nx { get; } +``` + +#### Property Value + +[UInt32](https://docs.microsoft.com/en-us/dotnet/api/system.uint32)
+ +### **Ny** + +Height of the bitmap in pixels. For audio embeddings this is typically `1`. + +```csharp +public uint Ny { get; } +``` + +#### Property Value + +[UInt32](https://docs.microsoft.com/en-us/dotnet/api/system.uint32)
+ +### **IsAudio** + +Indicates whether the embedding stores audio data instead of image pixels. + +```csharp +public bool IsAudio { get; } +``` + +#### Property Value + +[Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +### **ByteCount** + +Get the byte count of the raw bitmap/audio data in this embed + +```csharp +public ulong ByteCount { get; } +``` + +#### Property Value + +[UInt64](https://docs.microsoft.com/en-us/dotnet/api/system.uint64)
+ +### **Id** + +Optional identifier assigned to this embedding. + +```csharp +public string Id { get; set; } +``` + +#### Property Value + +[String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
+ +### **IsInvalid** + +```csharp +public bool IsInvalid { get; } +``` + +#### Property Value + +[Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +### **IsClosed** + +```csharp +public bool IsClosed { get; } +``` + +#### Property Value + +[Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +## Methods + +### **FromRgbBytes(UInt32, UInt32, ReadOnlySpan<Byte>)** + +Create an embedding from raw RGB bytes. + +```csharp +public static SafeMtmdEmbed FromRgbBytes(uint nx, uint ny, ReadOnlySpan rgbData) +``` + +#### Parameters + +`nx` [UInt32](https://docs.microsoft.com/en-us/dotnet/api/system.uint32)
+Width of the bitmap in pixels. + +`ny` [UInt32](https://docs.microsoft.com/en-us/dotnet/api/system.uint32)
+Height of the bitmap in pixels. + +`rgbData` [ReadOnlySpan<Byte>](https://docs.microsoft.com/en-us/dotnet/api/system.readonlyspan-1)
+Packed RGB data (3 bytes per pixel). + +#### Returns + +[SafeMtmdEmbed](./llama.native.safemtmdembed.md)
+Managed wrapper when initialization succeeds; otherwise `null`. + +#### Exceptions + +[ArgumentNullException](https://docs.microsoft.com/en-us/dotnet/api/system.argumentnullexception)
+The RGB buffer is null. + +### **FromAudioSamples(ReadOnlySpan<Single>)** + +Create an embedding from PCM audio samples. + +```csharp +public static SafeMtmdEmbed FromAudioSamples(ReadOnlySpan samples) +``` + +#### Parameters + +`samples` [ReadOnlySpan<Single>](https://docs.microsoft.com/en-us/dotnet/api/system.readonlyspan-1)
+Array of mono PCM samples in float format. + +#### Returns + +[SafeMtmdEmbed](./llama.native.safemtmdembed.md)
+Managed wrapper when initialization succeeds; otherwise `null`. + +#### Exceptions + +[ArgumentNullException](https://docs.microsoft.com/en-us/dotnet/api/system.argumentnullexception)
+The audio buffer is null. + +### **FromMediaFile(SafeMtmdModelHandle, String)** + +Create an embedding by decoding a media file using libmtmd helpers. + +```csharp +public static SafeMtmdEmbed FromMediaFile(SafeMtmdModelHandle mtmdContext, string path) +``` + +#### Parameters + +`mtmdContext` [SafeMtmdModelHandle](./llama.native.safemtmdmodelhandle.md)
+Model context that provides the decoder configuration. + +`path` [String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
+Path to the media file on disk. + +#### Returns + +[SafeMtmdEmbed](./llama.native.safemtmdembed.md)
+Managed wrapper when decoding succeeds; otherwise `null`. + +#### Exceptions + +[ArgumentNullException](https://docs.microsoft.com/en-us/dotnet/api/system.argumentnullexception)
+The context is null. + +[ArgumentException](https://docs.microsoft.com/en-us/dotnet/api/system.argumentexception)
+The path is null or whitespace. + +[FileNotFoundException](https://docs.microsoft.com/en-us/dotnet/api/system.io.filenotfoundexception)
+The supplied file does not exist. + +### **FromMediaBuffer(SafeMtmdModelHandle, ReadOnlySpan<Byte>)** + +Create an embedding from an in-memory media buffer (image/audio/video). + +```csharp +public static SafeMtmdEmbed FromMediaBuffer(SafeMtmdModelHandle mtmdContext, ReadOnlySpan data) +``` + +#### Parameters + +`mtmdContext` [SafeMtmdModelHandle](./llama.native.safemtmdmodelhandle.md)
+Model context that provides the decoder configuration. + +`data` [ReadOnlySpan<Byte>](https://docs.microsoft.com/en-us/dotnet/api/system.readonlyspan-1)
+Binary buffer containing the encoded media. + +#### Returns + +[SafeMtmdEmbed](./llama.native.safemtmdembed.md)
+Managed wrapper when decoding succeeds; otherwise `null`. + +#### Exceptions + +[ArgumentNullException](https://docs.microsoft.com/en-us/dotnet/api/system.argumentnullexception)
+The context is null. + +[ArgumentException](https://docs.microsoft.com/en-us/dotnet/api/system.argumentexception)
+The buffer is empty. + +### **GetData()** + +Provides safe zero-copy access to the underlying bitmap bytes. + +```csharp +public IEmbedData GetData() +``` + +#### Returns + +[IEmbedData](./llama.native.safemtmdembed.iembeddata.md)
+The data access is guaranteed to remain valid until this object is disposed. + +### **ReleaseHandle()** + +Release the underlying native bitmap. + +```csharp +protected bool ReleaseHandle() +``` + +#### Returns + +[Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +--- + +[`< Back`](./) diff --git a/docs/xmldocs/llama.native.safemtmdinputchunk.md b/docs/xmldocs/llama.native.safemtmdinputchunk.md new file mode 100644 index 000000000..258d32a6f --- /dev/null +++ b/docs/xmldocs/llama.native.safemtmdinputchunk.md @@ -0,0 +1,180 @@ +[`< Back`](./) + +--- + +# SafeMtmdInputChunk + +Namespace: LLama.Native + +Managed wrapper around a single `mtmd_input_chunk`. Instances can either own the + underlying native pointer (when created via [SafeMtmdInputChunk.Copy()](./llama.native.safemtmdinputchunk.md#copy)) or act as non-owning views + produced by the tokenizer. + +```csharp +public sealed class SafeMtmdInputChunk : SafeLLamaHandleBase, System.IDisposable +``` + +Inheritance [Object](https://docs.microsoft.com/en-us/dotnet/api/system.object) → [CriticalFinalizerObject](https://docs.microsoft.com/en-us/dotnet/api/system.runtime.constrainedexecution.criticalfinalizerobject) → [SafeHandle](https://docs.microsoft.com/en-us/dotnet/api/system.runtime.interopservices.safehandle) → [SafeLLamaHandleBase](./llama.native.safellamahandlebase.md) → [SafeMtmdInputChunk](./llama.native.safemtmdinputchunk.md)
+Implements [IDisposable](https://docs.microsoft.com/en-us/dotnet/api/system.idisposable) + +## Fields + +### **handle** + +```csharp +protected IntPtr handle; +``` + +## Properties + +### **NativePtr** + +Raw pointer to the native chunk structure. + +```csharp +public IntPtr NativePtr { get; } +``` + +#### Property Value + +[IntPtr](https://docs.microsoft.com/en-us/dotnet/api/system.intptr)
+ +### **Type** + +Chunk modality reported by the native helper. + +```csharp +public SafeMtmdInputChunkType Type { get; } +``` + +#### Property Value + +[SafeMtmdInputChunkType](./llama.native.safemtmdinputchunk.safemtmdinputchunktype.md)
+ +### **NTokens** + +Number of tokens contained in this chunk. + +```csharp +public ulong NTokens { get; } +``` + +#### Property Value + +[UInt64](https://docs.microsoft.com/en-us/dotnet/api/system.uint64)
+ +### **Id** + +Identifier assigned by the tokenizer (if any). + +```csharp +public string Id { get; } +``` + +#### Property Value + +[String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
+ +### **NPos** + +Number of positional slots consumed by this chunk. + +```csharp +public long NPos { get; } +``` + +#### Property Value + +[Int64](https://docs.microsoft.com/en-us/dotnet/api/system.int64)
+ +### **IsInvalid** + +```csharp +public bool IsInvalid { get; } +``` + +#### Property Value + +[Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +### **IsClosed** + +```csharp +public bool IsClosed { get; } +``` + +#### Property Value + +[Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +## Methods + +### **Wrap(IntPtr)** + +Wrap an existing chunk pointer without taking ownership. + +```csharp +public static SafeMtmdInputChunk Wrap(IntPtr ptr) +``` + +#### Parameters + +`ptr` [IntPtr](https://docs.microsoft.com/en-us/dotnet/api/system.intptr)
+Pointer returned by the native tokenizer. + +#### Returns + +[SafeMtmdInputChunk](./llama.native.safemtmdinputchunk.md)
+Managed wrapper, or `null` when the pointer is null. + +### **Copy()** + +Create an owning copy of the current chunk. The caller becomes responsible for disposal. + +```csharp +public SafeMtmdInputChunk Copy() +``` + +#### Returns + +[SafeMtmdInputChunk](./llama.native.safemtmdinputchunk.md)
+Owning managed wrapper, or `null` if the native copy failed. + +#### Exceptions + +[ObjectDisposedException](https://docs.microsoft.com/en-us/dotnet/api/system.objectdisposedexception)
+Thrown when the current wrapper has been disposed. + +### **GetTextTokensSpan()** + +Zero-copy view over the chunk's token buffer. The span remains valid only while the native chunk is alive. + +```csharp +public ReadOnlySpan GetTextTokensSpan() +``` + +#### Returns + +[ReadOnlySpan<Int32>](https://docs.microsoft.com/en-us/dotnet/api/system.readonlyspan-1)
+Read-only span exposing the chunk's tokens. + +#### Exceptions + +[ObjectDisposedException](https://docs.microsoft.com/en-us/dotnet/api/system.objectdisposedexception)
+Thrown when the wrapper has been disposed. + +### **ReleaseHandle()** + +Releases the native chunk when ownership is held by this instance. + +```csharp +protected bool ReleaseHandle() +``` + +#### Returns + +[Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +--- + +[`< Back`](./) diff --git a/docs/xmldocs/llama.native.safemtmdinputchunks.md b/docs/xmldocs/llama.native.safemtmdinputchunks.md new file mode 100644 index 000000000..ab1110881 --- /dev/null +++ b/docs/xmldocs/llama.native.safemtmdinputchunks.md @@ -0,0 +1,149 @@ +[`< Back`](./) + +--- + +# SafeMtmdInputChunks + +Namespace: LLama.Native + +Managed lifetime wrapper around a native `mtmd_input_chunks` collection returned by the tokenizer. + +```csharp +public sealed class SafeMtmdInputChunks : SafeLLamaHandleBase, System.IDisposable +``` + +Inheritance [Object](https://docs.microsoft.com/en-us/dotnet/api/system.object) → [CriticalFinalizerObject](https://docs.microsoft.com/en-us/dotnet/api/system.runtime.constrainedexecution.criticalfinalizerobject) → [SafeHandle](https://docs.microsoft.com/en-us/dotnet/api/system.runtime.interopservices.safehandle) → [SafeLLamaHandleBase](./llama.native.safellamahandlebase.md) → [SafeMtmdInputChunks](./llama.native.safemtmdinputchunks.md)
+Implements [IDisposable](https://docs.microsoft.com/en-us/dotnet/api/system.idisposable)
+Attributes [NullableContextAttribute](https://docs.microsoft.com/en-us/dotnet/api/system.runtime.compilerservices.nullablecontextattribute), [NullableAttribute](https://docs.microsoft.com/en-us/dotnet/api/system.runtime.compilerservices.nullableattribute) + +## Fields + +### **handle** + +```csharp +protected IntPtr handle; +``` + +## Properties + +### **Size** + +Number of chunks currently held by the native collection. + +```csharp +public ulong Size { get; } +``` + +#### Property Value + +[UInt64](https://docs.microsoft.com/en-us/dotnet/api/system.uint64)
+ +### **IsInvalid** + +```csharp +public bool IsInvalid { get; } +``` + +#### Property Value + +[Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +### **IsClosed** + +```csharp +public bool IsClosed { get; } +``` + +#### Property Value + +[Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +## Methods + +### **ReleaseHandle()** + +Releases the native chunk collection. + +```csharp +protected bool ReleaseHandle() +``` + +#### Returns + +[Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +### **CountTokens()** + +Get the number of tokens contained in this chunk collection. + +```csharp +public ulong CountTokens() +``` + +#### Returns + +[UInt64](https://docs.microsoft.com/en-us/dotnet/api/system.uint64)
+Total token count. + +### **CountPositions()** + +Get the number of positions contained in this chunk collection. + +```csharp +public long CountPositions() +``` + +#### Returns + +[Int64](https://docs.microsoft.com/en-us/dotnet/api/system.int64)
+Total number of positional slots consumed. + +### **GetChunkPtr(UInt64)** + +Get a raw pointer to a chunk. The returned [IntPtr](https://docs.microsoft.com/en-us/dotnet/api/system.intptr) is the `mtmd_input_chunk*`. + Use [SafeMtmdInputChunk.Wrap(IntPtr)](./llama.native.safemtmdinputchunk.md#wrapintptr) to create a managed wrapper if desired. + +```csharp +public IntPtr GetChunkPtr(ulong index) +``` + +#### Parameters + +`index` [UInt64](https://docs.microsoft.com/en-us/dotnet/api/system.uint64)
+Zero-based index of the chunk to retrieve. + +#### Returns + +[IntPtr](https://docs.microsoft.com/en-us/dotnet/api/system.intptr)
+Pointer to the requested chunk. + +#### Exceptions + +[ObjectDisposedException](https://docs.microsoft.com/en-us/dotnet/api/system.objectdisposedexception)
+The collection has already been disposed. + +[IndexOutOfRangeException](https://docs.microsoft.com/en-us/dotnet/api/system.indexoutofrangeexception)
+The requested index is outside of the valid range. + +### **Enumerate()** + +Enumerate the contained chunks as non-owning wrappers. Callers should dispose the returned chunk + if they create a copy. + +```csharp +public IEnumerable Enumerate() +``` + +#### Returns + +[IEnumerable<SafeMtmdInputChunk>](https://docs.microsoft.com/en-us/dotnet/api/system.collections.generic.ienumerable-1)
+Enumeration of chunk wrappers backed by the native collection. + +#### Exceptions + +[ObjectDisposedException](https://docs.microsoft.com/en-us/dotnet/api/system.objectdisposedexception)
+The collection has already been disposed. + +--- + +[`< Back`](./) diff --git a/docs/xmldocs/llama.native.safemtmdmodelhandle.md b/docs/xmldocs/llama.native.safemtmdmodelhandle.md new file mode 100644 index 000000000..9c9a355c0 --- /dev/null +++ b/docs/xmldocs/llama.native.safemtmdmodelhandle.md @@ -0,0 +1,469 @@ +[`< Back`](./) + +--- + +# SafeMtmdModelHandle + +Namespace: LLama.Native + +Wrapper to the Multi Modal Weights handle. This wrapper manages the low level + operations. + +```csharp +public sealed class SafeMtmdModelHandle : SafeLLamaHandleBase, System.IDisposable +``` + +Inheritance [Object](https://docs.microsoft.com/en-us/dotnet/api/system.object) → [CriticalFinalizerObject](https://docs.microsoft.com/en-us/dotnet/api/system.runtime.constrainedexecution.criticalfinalizerobject) → [SafeHandle](https://docs.microsoft.com/en-us/dotnet/api/system.runtime.interopservices.safehandle) → [SafeLLamaHandleBase](./llama.native.safellamahandlebase.md) → [SafeMtmdModelHandle](./llama.native.safemtmdmodelhandle.md)
+Implements [IDisposable](https://docs.microsoft.com/en-us/dotnet/api/system.idisposable)
+Attributes [NullableContextAttribute](https://docs.microsoft.com/en-us/dotnet/api/system.runtime.compilerservices.nullablecontextattribute), [NullableAttribute](https://docs.microsoft.com/en-us/dotnet/api/system.runtime.compilerservices.nullableattribute) + +## Fields + +### **handle** + +```csharp +protected IntPtr handle; +``` + +## Properties + +### **IsInvalid** + +```csharp +public bool IsInvalid { get; } +``` + +#### Property Value + +[Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +### **IsClosed** + +```csharp +public bool IsClosed { get; } +``` + +#### Property Value + +[Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +## Constructors + +### **SafeMtmdModelHandle()** + +```csharp +public SafeMtmdModelHandle() +``` + +## Methods + +### **ReleaseHandle()** + +```csharp +protected bool ReleaseHandle() +``` + +#### Returns + +[Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +### **LoadFromFile(String, LLamaWeights, MtmdContextParams)** + +Load a multimodal projection model from disk and bind it to the supplied text model. + +```csharp +public static SafeMtmdModelHandle LoadFromFile(string modelPath, LLamaWeights textModel, MtmdContextParams mtmdCtxParams) +``` + +#### Parameters + +`modelPath` [String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
+Path to the MMP (Multi-Modal Projections) file. + +`textModel` [LLamaWeights](./llama.llamaweights.md)
+Text model that provides tokenizer weights for the multimodal helper. + +`mtmdCtxParams` [MtmdContextParams](./llama.native.mtmdcontextparams.md)
+Optional context parameters; defaults are used when `null`. + +#### Returns + +[SafeMtmdModelHandle](./llama.native.safemtmdmodelhandle.md)
+Safe handle for the MTMD model. + +#### Exceptions + +[InvalidOperationException](https://docs.microsoft.com/en-us/dotnet/api/system.invalidoperationexception)
+The file exists but is not readable by the current process. + +[LoadWeightsFailedException](./llama.exceptions.loadweightsfailedexception.md)
+The native loader failed to initialize the MTMD model. + +### **LoadMediaFromFile(String)** + +Load media from disk and queue it for the next tokenize call. + +```csharp +public SafeMtmdEmbed LoadMediaFromFile(string path) +``` + +#### Parameters + +`path` [String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
+Absolute or relative path to the media asset. + +#### Returns + +[SafeMtmdEmbed](./llama.native.safemtmdembed.md)
+Safe handle to the media embedding. + +#### Exceptions + +[ObjectDisposedException](https://docs.microsoft.com/en-us/dotnet/api/system.objectdisposedexception)
+The model handle has been disposed. + +[RuntimeError](./llama.exceptions.runtimeerror.md)
+The native loader failed to ingest the file. + +### **LoadMediaFromBuffer(ReadOnlySpan<Byte>)** + +Load media from an in-memory buffer and queue it for the next tokenize call. + +```csharp +public SafeMtmdEmbed LoadMediaFromBuffer(ReadOnlySpan buffer) +``` + +#### Parameters + +`buffer` [ReadOnlySpan<Byte>](https://docs.microsoft.com/en-us/dotnet/api/system.readonlyspan-1)
+Binary buffer containing the encoded media data. + +#### Returns + +[SafeMtmdEmbed](./llama.native.safemtmdembed.md)
+Safe handle to the media embedding. + +#### Exceptions + +[ObjectDisposedException](https://docs.microsoft.com/en-us/dotnet/api/system.objectdisposedexception)
+The model handle has been disposed. + +[RuntimeError](./llama.exceptions.runtimeerror.md)
+The native loader failed to ingest the buffer contents. + +### **CreateMediaEmbedFromFile(String)** + +Create a standalone media embedding from disk without queueing it for the next tokenize call. + +```csharp +public SafeMtmdEmbed CreateMediaEmbedFromFile(string path) +``` + +#### Parameters + +`path` [String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
+Path to the media file on disk. + +#### Returns + +[SafeMtmdEmbed](./llama.native.safemtmdembed.md)
+Safe handle to the prepared media embedding. + +#### Exceptions + +[ObjectDisposedException](https://docs.microsoft.com/en-us/dotnet/api/system.objectdisposedexception)
+The model handle has been disposed. + +[RuntimeError](./llama.exceptions.runtimeerror.md)
+The native loader failed to ingest the file contents. + +### **CreateMediaEmbedFromBuffer(ReadOnlySpan<Byte>)** + +Create a standalone media embedding from an in-memory buffer without queueing it for the next tokenize call. + +```csharp +public SafeMtmdEmbed CreateMediaEmbedFromBuffer(ReadOnlySpan buffer) +``` + +#### Parameters + +`buffer` [ReadOnlySpan<Byte>](https://docs.microsoft.com/en-us/dotnet/api/system.readonlyspan-1)
+Binary buffer containing the encoded media data. + +#### Returns + +[SafeMtmdEmbed](./llama.native.safemtmdembed.md)
+Safe handle to the prepared media embedding. + +#### Exceptions + +[ObjectDisposedException](https://docs.microsoft.com/en-us/dotnet/api/system.objectdisposedexception)
+The model handle has been disposed. + +[RuntimeError](./llama.exceptions.runtimeerror.md)
+The native loader failed to ingest the buffer contents. + +### **ClearMedia()** + +Disposes and clears any media buffers currently queued for tokenization. + +```csharp +public void ClearMedia() +``` + +### **Tokenize(String, Boolean, Boolean, SafeMtmdInputChunks&)** + +Tokenize a prompt alongside the pending media buffers. Pending media is cleared on success. + +```csharp +public int Tokenize(string text, bool addSpecial, bool parseSpecial, SafeMtmdInputChunks& chunks) +``` + +#### Parameters + +`text` [String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
+Prompt text to tokenize. + +`addSpecial` [Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+Whether to append special tokens automatically. + +`parseSpecial` [Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+Whether special tokens should be treated as user-provided text. + +`chunks` [SafeMtmdInputChunks&](./llama.native.safemtmdinputchunks&.md)
+Receives the native chunk collection when tokenization succeeds. + +#### Returns + +[Int32](https://docs.microsoft.com/en-us/dotnet/api/system.int32)
+Zero on success; otherwise the native mtmd tokenize error code. + +#### Exceptions + +[ObjectDisposedException](https://docs.microsoft.com/en-us/dotnet/api/system.objectdisposedexception)
+The model handle has been disposed. + +### **Tokenize(String, Boolean, Boolean, ReadOnlySpan<SafeMtmdEmbed>, SafeMtmdInputChunks&)** + +Tokenize a prompt alongside the provided media embeddings. + The caller retains ownership of `embeds`. + +```csharp +public int Tokenize(string text, bool addSpecial, bool parseSpecial, ReadOnlySpan embeds, SafeMtmdInputChunks& chunks) +``` + +#### Parameters + +`text` [String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
+Prompt text to tokenize. + +`addSpecial` [Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+Whether to append special tokens automatically. + +`parseSpecial` [Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+Whether special tokens should be treated as user-provided text. + +`embeds` [ReadOnlySpan<SafeMtmdEmbed>](https://docs.microsoft.com/en-us/dotnet/api/system.readonlyspan-1)
+Media embeddings to include in the multimodal prompt. + +`chunks` [SafeMtmdInputChunks&](./llama.native.safemtmdinputchunks&.md)
+Receives the native chunk collection when tokenization succeeds. + +#### Returns + +[Int32](https://docs.microsoft.com/en-us/dotnet/api/system.int32)
+Zero on success; otherwise the native mtmd tokenize error code. + +#### Exceptions + +[ObjectDisposedException](https://docs.microsoft.com/en-us/dotnet/api/system.objectdisposedexception)
+The model handle has been disposed. + +[RuntimeError](./llama.exceptions.runtimeerror.md)
+The native tokenizer failed to allocate output chunks. + +### **EvaluateChunks(SafeMtmdInputChunks, SafeLLamaContextHandle, Int32&, Int32, Int32, Boolean)** + +Evaluate a batch of chunks using the helper (mirrors mtmd-helper eval logic). + +```csharp +public int EvaluateChunks(SafeMtmdInputChunks chunks, SafeLLamaContextHandle llamaContext, Int32& nPast, int seqId, int nBatch, bool logitsLast) +``` + +#### Parameters + +`chunks` [SafeMtmdInputChunks](./llama.native.safemtmdinputchunks.md)
+Chunk collection produced by [SafeMtmdModelHandle.Tokenize(String, Boolean, Boolean, SafeMtmdInputChunks&)](./llama.native.safemtmdmodelhandle.md#tokenizestring-boolean-boolean-safemtmdinputchunks&). + +`llamaContext` [SafeLLamaContextHandle](./llama.native.safellamacontexthandle.md)
+Context handle that receives the evaluated tokens. + +`nPast` [Int32&](https://docs.microsoft.com/en-us/dotnet/api/system.int32&)
+Number of past tokens; updated when evaluation succeeds. + +`seqId` [Int32](https://docs.microsoft.com/en-us/dotnet/api/system.int32)
+Sequence identifier used for KV cache management. + +`nBatch` [Int32](https://docs.microsoft.com/en-us/dotnet/api/system.int32)
+Maximum number of tokens to evaluate in a single batch. + +`logitsLast` [Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+Whether to request logits for the last token only. + +#### Returns + +[Int32](https://docs.microsoft.com/en-us/dotnet/api/system.int32)
+Zero on success; otherwise the native helper error code. + +#### Exceptions + +[ArgumentNullException](https://docs.microsoft.com/en-us/dotnet/api/system.argumentnullexception)
+Thrown when required handles are null. + +### **EvaluateChunk(IntPtr, SafeLLamaContextHandle, Int32&, Int32, Int32, Boolean)** + +Evaluate a single chunk helper. + +```csharp +public int EvaluateChunk(IntPtr chunkPtr, SafeLLamaContextHandle llamaContext, Int32& nPast, int seqId, int nBatch, bool logitsLast) +``` + +#### Parameters + +`chunkPtr` [IntPtr](https://docs.microsoft.com/en-us/dotnet/api/system.intptr)
+Pointer to the chunk to evaluate. + +`llamaContext` [SafeLLamaContextHandle](./llama.native.safellamacontexthandle.md)
+Context handle that receives the evaluated tokens. + +`nPast` [Int32&](https://docs.microsoft.com/en-us/dotnet/api/system.int32&)
+Number of past tokens; updated when evaluation succeeds. + +`seqId` [Int32](https://docs.microsoft.com/en-us/dotnet/api/system.int32)
+Sequence identifier used for KV cache management. + +`nBatch` [Int32](https://docs.microsoft.com/en-us/dotnet/api/system.int32)
+Maximum number of tokens to evaluate in a single batch. + +`logitsLast` [Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+Whether to request logits for the last token only. + +#### Returns + +[Int32](https://docs.microsoft.com/en-us/dotnet/api/system.int32)
+Zero on success; otherwise the native helper error code. + +#### Exceptions + +[ArgumentNullException](https://docs.microsoft.com/en-us/dotnet/api/system.argumentnullexception)
+Thrown when required handles are null. + +### **DecodeImageChunk(IntPtr, SafeLLamaContextHandle, IntPtr, Int32&, Int32, Int32)** + +Decode a prepared image chunk whose embedding is already computed. + +```csharp +public int DecodeImageChunk(IntPtr chunkPtr, SafeLLamaContextHandle llamaContext, IntPtr encodedEmbeddings, Int32& nPast, int seqId, int nBatch) +``` + +#### Parameters + +`chunkPtr` [IntPtr](https://docs.microsoft.com/en-us/dotnet/api/system.intptr)
+Pointer to the chunk whose embedding should be decoded. + +`llamaContext` [SafeLLamaContextHandle](./llama.native.safellamacontexthandle.md)
+Context handle used for decoding. + +`encodedEmbeddings` [IntPtr](https://docs.microsoft.com/en-us/dotnet/api/system.intptr)
+Pointer to the pre-computed embedding data. + +`nPast` [Int32&](https://docs.microsoft.com/en-us/dotnet/api/system.int32&)
+Number of past tokens; updated when evaluation succeeds. + +`seqId` [Int32](https://docs.microsoft.com/en-us/dotnet/api/system.int32)
+Sequence identifier used for KV cache management. + +`nBatch` [Int32](https://docs.microsoft.com/en-us/dotnet/api/system.int32)
+Maximum number of tokens to evaluate in a single batch. + +#### Returns + +[Int32](https://docs.microsoft.com/en-us/dotnet/api/system.int32)
+Zero on success; otherwise the native helper error code. + +#### Exceptions + +[ArgumentNullException](https://docs.microsoft.com/en-us/dotnet/api/system.argumentnullexception)
+Thrown when required handles are null. + +### **Finalize()** + +Finalizer to ensure native resources are released if Dispose was not called. + +```csharp +protected void Finalize() +``` + +### **DecodeUseNonCausal()** + +Indicates whether the model decodes using the non-causal path. + +```csharp +public bool DecodeUseNonCausal() +``` + +#### Returns + +[Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +### **DecodeUseMRope()** + +Indicates whether the model decodes using multi-scale RoPE. + +```csharp +public bool DecodeUseMRope() +``` + +#### Returns + +[Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +### **SupportVision()** + +Indicates whether the model supports vision inputs. + +```csharp +public bool SupportVision() +``` + +#### Returns + +[Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +### **SupportAudio()** + +Indicates whether the model supports audio inputs. + +```csharp +public bool SupportAudio() +``` + +#### Returns + +[Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +### **GetAudioBitrate()** + +Gets the audio bitrate advertised by the model. + +```csharp +public int GetAudioBitrate() +``` + +#### Returns + +[Int32](https://docs.microsoft.com/en-us/dotnet/api/system.int32)
+ +--- + +[`< Back`](./) diff --git a/docs/xmldocs/llama.statefulexecutorbase.md b/docs/xmldocs/llama.statefulexecutorbase.md index 8e4b2226d..fbc8744da 100644 --- a/docs/xmldocs/llama.statefulexecutorbase.md +++ b/docs/xmldocs/llama.statefulexecutorbase.md @@ -28,12 +28,22 @@ protected ILogger _logger; ### **_pastTokensCount** -The tokens that were already processed by the model. +The positional cursor that has already been processed by the model. ```csharp protected int _pastTokensCount; ``` +### **_kvTokenCount** + +Number of KV slots currently occupied by the active sequence. + For MTMD prompts this can diverge from because media chunks may + consume a different number of decoded tokens than logical positions. + +```csharp +protected int _kvTokenCount; +``` + ### **_consumedTokensCount** The tokens that were consumed by the model during the current inference. @@ -44,7 +54,7 @@ protected int _consumedTokensCount; ### **_n_session_consumed** - +Number of tokens consumed from the session cache during the current run. ```csharp protected int _n_session_consumed; @@ -52,7 +62,7 @@ protected int _n_session_consumed; ### **_n_matching_session_tokens** - +Number of prompt tokens that match the loaded session cache prefix. ```csharp protected int _n_matching_session_tokens; @@ -84,7 +94,7 @@ protected List _embed_inps; ### **_session_tokens** - +Tokens recovered from the session file and reused to warm up the KV cache. ```csharp protected List _session_tokens; @@ -112,6 +122,18 @@ public LLamaContext Context { get; } [LLamaContext](./llama.llamacontext.md)
+### **AntipromptProcessor** + +Tracks anti-prompts across streamed output. + +```csharp +protected AntipromptProcessor AntipromptProcessor { get; } +``` + +#### Property Value + +[AntipromptProcessor](./llama.antipromptprocessor.md)
+ ### **IsMultiModal** ```csharp @@ -125,28 +147,28 @@ public bool IsMultiModal { get; } ### **ClipModel** ```csharp -public LLavaWeights ClipModel { get; } +public MtmdWeights ClipModel { get; } ``` #### Property Value -[LLavaWeights](./llama.llavaweights.md)
+[MtmdWeights](./llama.mtmdweights.md)
-### **Images** +### **Embeds** ```csharp -public List Images { get; } +public List Embeds { get; } ``` #### Property Value -[List<Byte[]>](https://docs.microsoft.com/en-us/dotnet/api/system.collections.generic.list-1)
+[List<SafeMtmdEmbed>](https://docs.microsoft.com/en-us/dotnet/api/system.collections.generic.list-1)
## Constructors ### **StatefulExecutorBase(LLamaContext, ILogger)** - +Initialize a stateful executor bound to a specific context. ```csharp protected StatefulExecutorBase(LLamaContext context, ILogger logger) @@ -155,30 +177,35 @@ protected StatefulExecutorBase(LLamaContext context, ILogger logger) #### Parameters `context` [LLamaContext](./llama.llamacontext.md)
+LLama context used for all native interactions. `logger` ILogger
+Optional logger for diagnostic output. -### **StatefulExecutorBase(LLamaContext, LLavaWeights, ILogger)** - +### **StatefulExecutorBase(LLamaContext, MtmdWeights, ILogger)** +Initialize a multimodal executor with the supplied MTMD weights. ```csharp -public StatefulExecutorBase(LLamaContext context, LLavaWeights lLavaWeights, ILogger logger) +public StatefulExecutorBase(LLamaContext context, MtmdWeights mtmdWeights, ILogger logger) ``` #### Parameters `context` [LLamaContext](./llama.llamacontext.md)
+LLama context used for all native interactions. -`lLavaWeights` [LLavaWeights](./llama.llavaweights.md)
+`mtmdWeights` [MtmdWeights](./llama.mtmdweights.md)
+Multimodal weights to associate with this executor. `logger` ILogger
+Optional logger for diagnostic output. ## Methods ### **WithSessionFile(String)** -This API is currently not verified. +Attach a session cache file so the executor can reuse previous KV state if compatible. ```csharp public StatefulExecutorBase WithSessionFile(string filename) @@ -187,10 +214,12 @@ public StatefulExecutorBase WithSessionFile(string filename) #### Parameters `filename` [String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
+Path to the llama.cpp session file. #### Returns [StatefulExecutorBase](./llama.statefulexecutorbase.md)
+The current executor instance for fluent configuration. #### Exceptions @@ -200,7 +229,7 @@ public StatefulExecutorBase WithSessionFile(string filename) ### **SaveSessionFile(String)** -This API has not been verified currently. +Persist the current session cache to disk. ```csharp public void SaveSessionFile(string filename) @@ -209,6 +238,7 @@ public void SaveSessionFile(string filename) #### Parameters `filename` [String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
+Destination path for the llama.cpp session file. ### **HandleRunOutOfContext(Int32)** @@ -222,95 +252,370 @@ protected void HandleRunOutOfContext(int tokensToKeep) `tokensToKeep` [Int32](https://docs.microsoft.com/en-us/dotnet/api/system.int32)
+### **GetPendingInputPositionCount()** + +Get the number of context positions required by the pending input. + +```csharp +protected int GetPendingInputPositionCount() +``` + +#### Returns + +[Int32](https://docs.microsoft.com/en-us/dotnet/api/system.int32)
+ +### **GetPendingInputKvTokenCount()** + +Get the number of KV slots required by the pending input. + For multimodal prompts this counts decoded tokens rather than logical positions. + +```csharp +protected int GetPendingInputKvTokenCount() +``` + +#### Returns + +[Int32](https://docs.microsoft.com/en-us/dotnet/api/system.int32)
+ +### **GetOccupiedKvTokenCount()** + +Get the number of KV slots that are already occupied by the active sequence. + +```csharp +protected int GetOccupiedKvTokenCount() +``` + +#### Returns + +[Int32](https://docs.microsoft.com/en-us/dotnet/api/system.int32)
+ +### **EnsurePendingInputFitsContext(Int32)** + +Shift the context window when the pending input would overflow the available KV cache. + +```csharp +protected void EnsurePendingInputFitsContext(int tokensToKeep) +``` + +#### Parameters + +`tokensToKeep` [Int32](https://docs.microsoft.com/en-us/dotnet/api/system.int32)
+ +### **DecodeQueuedEmbedsAsync(LLamaBatch, Int32, CancellationToken)** + +Decode the currently queued token batch and update the shared executor bookkeeping. + +```csharp +protected Task DecodeQueuedEmbedsAsync(LLamaBatch batch, int tokensToKeep, CancellationToken cancellationToken) +``` + +#### Parameters + +`batch` [LLamaBatch](./llama.native.llamabatch.md)
+ +`tokensToKeep` [Int32](https://docs.microsoft.com/en-us/dotnet/api/system.int32)
+ +`cancellationToken` [CancellationToken](https://docs.microsoft.com/en-us/dotnet/api/system.threading.cancellationtoken)
+ +#### Returns + +[Task](https://docs.microsoft.com/en-us/dotnet/api/system.threading.tasks.task)
+ +### **EvaluatePendingMtmdMediaSegment(String, Int32)** + +Evaluate the next queued MTMD media segment after validating context availability. + +```csharp +protected void EvaluatePendingMtmdMediaSegment(string executorName, int tokensToKeep) +``` + +#### Parameters + +`executorName` [String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
+ +`tokensToKeep` [Int32](https://docs.microsoft.com/en-us/dotnet/api/system.int32)
+ ### **TryReuseMatchingPrefix()** -Try to reuse the matching prefix from the session file. +Try to reuse the matching prompt prefix from the loaded session cache before evaluating new tokens. ```csharp protected void TryReuseMatchingPrefix() ``` -### **GetLoopCondition(InferStateArgs)** +### **DisposeMtmdPromptSegments()** + +Dispose and clear any queued multimodal prompt segments. + +```csharp +protected void DisposeMtmdPromptSegments() +``` + +### **DisposeEmbeds()** + +Dispose and clear any pending multimodal embeddings. + +```csharp +protected void DisposeEmbeds() +``` + +### **GetMtmdMarker()** + +Retrieve the marker token used to signal media segments to the tokenizer. + +```csharp +protected string GetMtmdMarker() +``` + +#### Returns + +[String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
+ +### **GetFillerToken(String)** + +Resolve the fallback token inserted when the tokenizer emits fewer tokens than positions. + +```csharp +protected LLamaToken GetFillerToken(string marker) +``` + +#### Parameters + +`marker` [String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
+ +#### Returns + +[LLamaToken](./llama.native.llamatoken.md)
+ +### **PreprocessMtmd(String, InferStateArgs, Boolean, Boolean)** -Decide whether to continue the loop. +Prepare multimodal inputs by invoking the MTMD tokenizer and aligning filler tokens. ```csharp -protected abstract Task GetLoopCondition(InferStateArgs args) +protected Task PreprocessMtmd(string text, InferStateArgs args, bool addBos, bool replaceExisting) ``` #### Parameters +`text` [String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
+ `args` [InferStateArgs](./llama.statefulexecutorbase.inferstateargs.md)
+`addBos` [Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +`replaceExisting` [Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +#### Returns + +[Task](https://docs.microsoft.com/en-us/dotnet/api/system.threading.tasks.task)
+ +### **LoadMtmdPromptSegments(SafeMtmdInputChunks, LLamaToken, Boolean)** + +Build ordered MTMD prompt segments and logical placeholder tokens from the tokenized chunks. + +```csharp +protected int LoadMtmdPromptSegments(SafeMtmdInputChunks chunks, LLamaToken fillerToken, bool replaceExisting) +``` + +#### Parameters + +`chunks` [SafeMtmdInputChunks](./llama.native.safemtmdinputchunks.md)
+ +`fillerToken` [LLamaToken](./llama.native.llamatoken.md)
+ +`replaceExisting` [Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +#### Returns + +[Int32](https://docs.microsoft.com/en-us/dotnet/api/system.int32)
+ +### **QueueNextMtmdPromptInput()** + +Queue the next multimodal text segment into if available. + +```csharp +protected void QueueNextMtmdPromptInput() +``` + +### **ClearQueuedMtmdPromptText()** + +Clear the marker that tracks prompt text currently staged in for MTMD execution. + +```csharp +protected void ClearQueuedMtmdPromptText() +``` + +### **ClearCurrentEmbeds()** + +Clear the currently staged decode buffer and any MTMD prompt-text marker associated with it. + +```csharp +protected void ClearCurrentEmbeds() +``` + +### **EvaluateNextMtmdMediaSegment(String)** + +Evaluate the current multimodal media segment and advance the prompt cursor. + +```csharp +protected void EvaluateNextMtmdMediaSegment(string executorName) +``` + +#### Parameters + +`executorName` [String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
+ +### **OnTextDecodeCompleted(Int32, Int32)** + +Update counters after a successful text decode. + +```csharp +protected void OnTextDecodeCompleted(int previousPastCount, int newPastCount) +``` + +#### Parameters + +`previousPastCount` [Int32](https://docs.microsoft.com/en-us/dotnet/api/system.int32)
+ +`newPastCount` [Int32](https://docs.microsoft.com/en-us/dotnet/api/system.int32)
+ +### **HasPendingPromptInput()** + +Determine whether any prompt input is still pending evaluation. + +```csharp +protected bool HasPendingPromptInput() +``` + +#### Returns + +[Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +### **HasPendingPromptQueue()** + +Determine whether any prompt content still remains to be queued/evaluated, excluding sampled output tokens. + +```csharp +protected bool HasPendingPromptQueue() +``` + +#### Returns + +[Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +### **HasPendingMtmdMediaSegment()** + +Determine whether the next queued multimodal prompt segment is a media chunk. + +```csharp +protected bool HasPendingMtmdMediaSegment() +``` + +#### Returns + +[Boolean](https://docs.microsoft.com/en-us/dotnet/api/system.boolean)
+ +### **GetLoopCondition(InferStateArgs, CancellationToken)** + +Determine whether the inference loop should continue processing tokens. + +```csharp +protected abstract Task GetLoopCondition(InferStateArgs args, CancellationToken cancellationToken) +``` + +#### Parameters + +`args` [InferStateArgs](./llama.statefulexecutorbase.inferstateargs.md)
+Mutable state associated with the current inference. + +`cancellationToken` [CancellationToken](https://docs.microsoft.com/en-us/dotnet/api/system.threading.cancellationtoken)
+ #### Returns [Task<Boolean>](https://docs.microsoft.com/en-us/dotnet/api/system.threading.tasks.task-1)
+`true` to continue generating; otherwise `false`. -### **PreprocessInputs(String, InferStateArgs)** +### **PreprocessInputs(String, InferStateArgs, CancellationToken)** -Preprocess the inputs before the inference. +Prepare the executor for inference by tokenizing input and updating cached state. ```csharp -protected abstract Task PreprocessInputs(string text, InferStateArgs args) +protected abstract Task PreprocessInputs(string text, InferStateArgs args, CancellationToken cancellationToken) ``` #### Parameters `text` [String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
+Prompt text to process. `args` [InferStateArgs](./llama.statefulexecutorbase.inferstateargs.md)
+Mutable state associated with the current inference. + +`cancellationToken` [CancellationToken](https://docs.microsoft.com/en-us/dotnet/api/system.threading.cancellationtoken)
#### Returns [Task](https://docs.microsoft.com/en-us/dotnet/api/system.threading.tasks.task)
-### **PostProcess(IInferenceParams, InferStateArgs)** +### **PostProcess(IInferenceParams, InferStateArgs, CancellationToken)** -Do some post processing after the inference. +Perform any post-processing on the generated tokens. ```csharp -protected abstract Task>> PostProcess(IInferenceParams inferenceParams, InferStateArgs args) +protected abstract Task>> PostProcess(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken) ``` #### Parameters `inferenceParams` [IInferenceParams](./llama.abstractions.iinferenceparams.md)
+Parameters controlling sampling. `args` [InferStateArgs](./llama.statefulexecutorbase.inferstateargs.md)
+Mutable state associated with the current inference. + +`cancellationToken` [CancellationToken](https://docs.microsoft.com/en-us/dotnet/api/system.threading.cancellationtoken)
#### Returns [Task<ValueTuple<Boolean, IReadOnlyList<String>>>](https://docs.microsoft.com/en-us/dotnet/api/system.threading.tasks.task-1)
+A tuple indicating whether generation should stop and any extra outputs to emit. -### **InferInternal(IInferenceParams, InferStateArgs)** +### **InferInternal(IInferenceParams, InferStateArgs, CancellationToken)** -The core inference logic. +Core inference loop that advances the model by one step. ```csharp -protected abstract Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args) +protected abstract Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken) ``` #### Parameters `inferenceParams` [IInferenceParams](./llama.abstractions.iinferenceparams.md)
+Parameters controlling sampling. `args` [InferStateArgs](./llama.statefulexecutorbase.inferstateargs.md)
+Mutable state associated with the current inference. + +`cancellationToken` [CancellationToken](https://docs.microsoft.com/en-us/dotnet/api/system.threading.cancellationtoken)
#### Returns [Task](https://docs.microsoft.com/en-us/dotnet/api/system.threading.tasks.task)
-### **SaveState(String)** +### **SaveState(String, CancellationToken)** -Save the current state to a file. +Save the executor state to a serialized snapshot file. ```csharp -public abstract Task SaveState(string filename) +public abstract Task SaveState(string filename, CancellationToken cancellationToken) ``` #### Parameters `filename` [String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
+Destination file for the serialized state. + +`cancellationToken` [CancellationToken](https://docs.microsoft.com/en-us/dotnet/api/system.threading.cancellationtoken)
#### Returns @@ -318,7 +623,7 @@ public abstract Task SaveState(string filename) ### **GetStateData()** -Get the current state data. +Capture the executor state in a serializable object. ```csharp public abstract ExecutorBaseState GetStateData() @@ -327,34 +632,41 @@ public abstract ExecutorBaseState GetStateData() #### Returns [ExecutorBaseState](./llama.statefulexecutorbase.executorbasestate.md)
+State snapshot suitable for persistence. -### **LoadState(ExecutorBaseState)** +### **LoadState(ExecutorBaseState, CancellationToken)** -Load the state from data. +Restore executor state from a previously captured snapshot. ```csharp -public abstract Task LoadState(ExecutorBaseState data) +public abstract Task LoadState(ExecutorBaseState data, CancellationToken cancellationToken) ``` #### Parameters `data` [ExecutorBaseState](./llama.statefulexecutorbase.executorbasestate.md)
+State snapshot created by [StatefulExecutorBase.GetStateData()](./llama.statefulexecutorbase.md#getstatedata). + +`cancellationToken` [CancellationToken](https://docs.microsoft.com/en-us/dotnet/api/system.threading.cancellationtoken)
#### Returns [Task](https://docs.microsoft.com/en-us/dotnet/api/system.threading.tasks.task)
-### **LoadState(String)** +### **LoadState(String, CancellationToken)** -Load the state from a file. +Restore executor state from a serialized snapshot file. ```csharp -public abstract Task LoadState(string filename) +public abstract Task LoadState(string filename, CancellationToken cancellationToken) ``` #### Parameters `filename` [String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
+Path to the snapshot produced by [StatefulExecutorBase.SaveState(String, CancellationToken)](./llama.statefulexecutorbase.md#savestatestring-cancellationtoken). + +`cancellationToken` [CancellationToken](https://docs.microsoft.com/en-us/dotnet/api/system.threading.cancellationtoken)
#### Returns @@ -362,7 +674,7 @@ public abstract Task LoadState(string filename) ### **InferAsync(String, IInferenceParams, CancellationToken)** -Execute the inference. +Execute an asynchronous inference session. ```csharp public IAsyncEnumerable InferAsync(string text, IInferenceParams inferenceParams, CancellationToken cancellationToken) @@ -371,23 +683,26 @@ public IAsyncEnumerable InferAsync(string text, IInferenceParams inferen #### Parameters `text` [String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
-The prompt. If null, generation will continue where it left off previously. +Optional prompt; when null generation resumes from prior state. `inferenceParams` [IInferenceParams](./llama.abstractions.iinferenceparams.md)
+Sampling parameters to apply; defaults are used when null. `cancellationToken` [CancellationToken](https://docs.microsoft.com/en-us/dotnet/api/system.threading.cancellationtoken)
+Cancellation token for cooperative cancellation. #### Returns [IAsyncEnumerable<String>](https://docs.microsoft.com/en-us/dotnet/api/system.collections.generic.iasyncenumerable-1)
+Stream of decoded text segments as they become available. -### **PrefillPromptAsync(String)** +### **PrefillPromptAsync(String, CancellationToken)** Asynchronously runs a prompt through the model to compute KV cache without generating any new tokens. It could reduce the latency of the first time response if the first input from the user is not immediate. ```csharp -public Task PrefillPromptAsync(string prompt) +public Task PrefillPromptAsync(string prompt, CancellationToken cancellationToken) ``` #### Parameters @@ -395,6 +710,8 @@ public Task PrefillPromptAsync(string prompt) `prompt` [String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
Prompt to process +`cancellationToken` [CancellationToken](https://docs.microsoft.com/en-us/dotnet/api/system.threading.cancellationtoken)
+ #### Returns [Task](https://docs.microsoft.com/en-us/dotnet/api/system.threading.tasks.task)
diff --git a/docs/xmldocs/llama.statelessexecutor.md b/docs/xmldocs/llama.statelessexecutor.md index 60c2e6e8b..a3a25d28d 100644 --- a/docs/xmldocs/llama.statelessexecutor.md +++ b/docs/xmldocs/llama.statelessexecutor.md @@ -32,22 +32,22 @@ public bool IsMultiModal { get; } ### **ClipModel** ```csharp -public LLavaWeights ClipModel { get; } +public MtmdWeights ClipModel { get; } ``` #### Property Value -[LLavaWeights](./llama.llavaweights.md)
+[MtmdWeights](./llama.mtmdweights.md)
-### **Images** +### **Embeds** ```csharp -public List Images { get; } +public List Embeds { get; } ``` #### Property Value -[List<Byte[]>](https://docs.microsoft.com/en-us/dotnet/api/system.collections.generic.list-1)
+[List<SafeMtmdEmbed>](https://docs.microsoft.com/en-us/dotnet/api/system.collections.generic.list-1)
### **Context** diff --git a/mkdocs.yml b/mkdocs.yml index fbffdbba7..8145adbec 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -108,7 +108,7 @@ nav: - llama.llamaquantizer: ./xmldocs/llama.llamaquantizer.md - llama.llamastate: ./xmldocs/llama.llamastate.md - llama.llamatransforms: ./xmldocs/llama.llamatransforms.md - - llama.llavaweights: ./xmldocs/llama.llavaweights.md + - llama.mtmdweights: ./xmldocs/llama.mtmdweights.md - llama.native.decoderesult: ./xmldocs/llama.native.decoderesult.md - llama.native.ggmltype: ./xmldocs/llama.native.ggmltype.md - llama.native.gpusplitmode: ./xmldocs/llama.native.gpusplitmode.md @@ -139,7 +139,7 @@ nav: - llama.native.llamatokendataarraynative: ./xmldocs/llama.native.llamatokendataarraynative.md - llama.native.llamatokentype: ./xmldocs/llama.native.llamatokentype.md - llama.native.llamavocabtype: ./xmldocs/llama.native.llamavocabtype.md - - llama.native.llavaimageembed: ./xmldocs/llama.native.llavaimageembed.md + - llama.native.mtmdcontextparams: ./xmldocs/llama.native.mtmdcontextparams.md - llama.native.nativeapi: ./xmldocs/llama.native.nativeapi.md - llama.native.nativelibraryconfig: ./xmldocs/llama.native.nativelibraryconfig.md - llama.native.ropescalingtype: ./xmldocs/llama.native.ropescalingtype.md @@ -147,8 +147,10 @@ nav: - llama.native.safellamagrammarhandle: ./xmldocs/llama.native.safellamagrammarhandle.md - llama.native.safellamahandlebase: ./xmldocs/llama.native.safellamahandlebase.md - llama.native.safellamamodelhandle: ./xmldocs/llama.native.safellamamodelhandle.md - - llama.native.safellavaimageembedhandle: ./xmldocs/llama.native.safellavaimageembedhandle.md - - llama.native.safellavamodelhandle: ./xmldocs/llama.native.safellavamodelhandle.md + - llama.native.safemtmdembed: ./xmldocs/llama.native.safemtmdembed.md + - llama.native.safemtmdinputchunk: ./xmldocs/llama.native.safemtmdinputchunk.md + - llama.native.safemtmdinputchunks: ./xmldocs/llama.native.safemtmdinputchunks.md + - llama.native.safemtmdmodelhandle: ./xmldocs/llama.native.safemtmdmodelhandle.md - llama.quantizer: ./xmldocs/llama.quantizer.md - llama.sampling.basesamplingpipeline: ./xmldocs/llama.sampling.basesamplingpipeline.md - llama.sampling.defaultsamplingpipeline: ./xmldocs/llama.sampling.defaultsamplingpipeline.md From c2170b07b7c4c82826b764e60d15cab1894060aa Mon Sep 17 00:00:00 2001 From: SignalRT Date: Fri, 20 Mar 2026 22:24:17 +0100 Subject: [PATCH 5/8] Solve some issues in the WEB Stop and load the model on change Solve issue with the ENTER --- LLama.Web/Pages/Chat.razor | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/LLama.Web/Pages/Chat.razor b/LLama.Web/Pages/Chat.razor index 642814686..2f3fcd71e 100644 --- a/LLama.Web/Pages/Chat.razor +++ b/LLama.Web/Pages/Chat.razor @@ -25,7 +25,7 @@ } - + @if (!string.IsNullOrWhiteSpace(_composerError)) { @@ -436,6 +436,26 @@ } } + private async Task OnModelChangedAsync(ChangeEventArgs args) + { + var selectedModel = args.Value?.ToString() ?? string.Empty; + if (string.IsNullOrWhiteSpace(selectedModel) || string.Equals(selectedModel, _sessionConfig.Model, StringComparison.Ordinal)) + return; + + var hadLoadedSession = _sessionLoaded; + _sessionConfig.Model = selectedModel; + + if (!hadLoadedSession) + return; + + if (_isSending) + await CancelAsync(); + + _statusText = $"Switching model to {selectedModel}..."; + + await LoadSessionAsync(); + } + private async Task UnloadSessionAsync() { if (!_sessionLoaded) @@ -700,14 +720,6 @@ return $"{string.Join(", ", parts)}."; } - private async Task HandleKeyPress(KeyboardEventArgs args) - { - if (args.Key == "Enter" && !args.ShiftKey && !args.CtrlKey && !args.AltKey && !args.MetaKey) - { - await SendAsync(); - } - } - private async Task ToggleAudioRecordingAsync() { if (_isRecordingAudio) From 68622e81941ef03a47ae0f271e44c74f4a95f23a Mon Sep 17 00:00:00 2001 From: SignalRT Date: Fri, 20 Mar 2026 23:27:19 +0100 Subject: [PATCH 6/8] Add log on attachment management --- LLama.Web/Pages/Chat.razor | 7 ++ LLama.Web/Services/AttachmentService.cs | 24 +++++- LLama.Web/Services/ModelSessionService.cs | 94 +++++++++++++++++++---- LLama.Web/wwwroot/js/blazorChat.js | 71 ++++++++++++++--- 4 files changed, 170 insertions(+), 26 deletions(-) diff --git a/LLama.Web/Pages/Chat.razor b/LLama.Web/Pages/Chat.razor index 2f3fcd71e..fd285a2ae 100644 --- a/LLama.Web/Pages/Chat.razor +++ b/LLama.Web/Pages/Chat.razor @@ -750,6 +750,12 @@ var result = await JS.InvokeAsync("llamaChat.stopAudioRecording"); _isRecordingAudio = false; + if (!string.IsNullOrWhiteSpace(result?.Error)) + { + AddSystemMessage(result.Error, MessageKind.Error); + return; + } + if (result == null || string.IsNullOrWhiteSpace(result.PreviewUrl)) { AddSystemMessage("No audio captured.", MessageKind.Info); @@ -986,6 +992,7 @@ public string PreviewUrl { get; set; } public string MimeType { get; set; } public long Size { get; set; } + public string Error { get; set; } } private void SetTab(SidebarTab tab) diff --git a/LLama.Web/Services/AttachmentService.cs b/LLama.Web/Services/AttachmentService.cs index 4990bf845..a5d46e3e5 100644 --- a/LLama.Web/Services/AttachmentService.cs +++ b/LLama.Web/Services/AttachmentService.cs @@ -14,12 +14,14 @@ public class AttachmentService : IAttachmentService private const int MaxExtractedCharacters = 12000; private const long MaxUploadSize = 512L * 1024 * 1024; private readonly IWebHostEnvironment _environment; + private readonly ILogger _logger; private readonly string _uploadsRoot; private readonly ConcurrentDictionary> _attachments = new(); - public AttachmentService(IWebHostEnvironment environment) + public AttachmentService(IWebHostEnvironment environment, ILogger logger) { _environment = environment; + _logger = logger; var appDataRoot = Environment.GetFolderPath(Environment.SpecialFolder.ApplicationData); _uploadsRoot = Path.Combine(appDataRoot, "LLama.Web", "Uploads"); Directory.CreateDirectory(_uploadsRoot); @@ -69,6 +71,16 @@ public async Task SaveAsync(string connectionId, IEnumer if (info.Kind == AttachmentKind.Word) ExtractWordText(info); + _logger.LogInformation( + "Saved form attachment {AttachmentId} for session {SessionId}: name={FileName}, contentType={ContentType}, size={SizeBytes}, kind={Kind}, path={FilePath}", + info.Id, + connectionId, + info.FileName, + info.ContentType, + info.SizeBytes, + info.Kind, + info.FilePath); + storage[id] = info; result.Attachments.Add(info); } @@ -119,6 +131,16 @@ public async Task SaveAsync(string connectionId, IEnumer if (info.Kind == AttachmentKind.Word) ExtractWordText(info); + _logger.LogInformation( + "Saved browser attachment {AttachmentId} for session {SessionId}: name={FileName}, contentType={ContentType}, size={SizeBytes}, kind={Kind}, path={FilePath}", + info.Id, + connectionId, + info.FileName, + info.ContentType, + info.SizeBytes, + info.Kind, + info.FilePath); + storage[id] = info; result.Attachments.Add(info); } diff --git a/LLama.Web/Services/ModelSessionService.cs b/LLama.Web/Services/ModelSessionService.cs index ef748c68c..86e9c2e17 100644 --- a/LLama.Web/Services/ModelSessionService.cs +++ b/LLama.Web/Services/ModelSessionService.cs @@ -267,8 +267,24 @@ private async Task PreparePromptAsync(string sessionId, ModelSes var mediaMarker = GetMediaMarker(); var attachments = _attachmentService.GetAttachments(sessionId, request.AttachmentIds); + _logger.LogInformation( + "Preparing prompt for session {SessionId} with {RequestedCount} attachment id(s); resolved {ResolvedCount}: [{ResolvedAttachments}]", + sessionId, + request.AttachmentIds.Count, + attachments.Count, + string.Join(", ", attachments.Select(a => $"{a.Id}:{a.Kind}:{a.ContentType}:{a.SizeBytes}:{a.FileName}"))); + foreach (var attachment in attachments) { + _logger.LogInformation( + "Processing attachment {AttachmentId} for session {SessionId}: kind={Kind}, contentType={ContentType}, size={SizeBytes}, file={FileName}", + attachment.Id, + sessionId, + attachment.Kind, + attachment.ContentType, + attachment.SizeBytes, + attachment.FileName); + switch (attachment.Kind) { case AttachmentKind.Pdf: @@ -281,14 +297,14 @@ private async Task PreparePromptAsync(string sessionId, ModelSes if (!modelSession.IsMultiModal) throw new Exception("This model does not support multimodal inputs."); - embeds.Add(await LoadEmbedAsync(modelSession, attachment, cancellationToken)); + embeds.Add(await LoadEmbedAsync(sessionId, modelSession, attachment, cancellationToken)); AppendMedia(modelSession, promptBuilder, mediaPrefixBuilder, mediaMarker); break; case AttachmentKind.Audio: if (!modelSession.IsMultiModal || !modelSession.SupportsAudio) throw new Exception("This model does not support audio inputs."); - embeds.Add(await LoadEmbedAsync(modelSession, attachment, cancellationToken)); + embeds.Add(await LoadEmbedAsync(sessionId, modelSession, attachment, cancellationToken)); AppendMedia(modelSession, promptBuilder, mediaPrefixBuilder, mediaMarker); break; } @@ -365,7 +381,7 @@ private static void AppendWord(StringBuilder builder, AttachmentInfo attachment) builder.AppendLine("[Word text truncated]"); } - private static async Task LoadEmbedAsync(ModelSession modelSession, AttachmentInfo attachment, CancellationToken cancellationToken) + private async Task LoadEmbedAsync(string sessionId, ModelSession modelSession, AttachmentInfo attachment, CancellationToken cancellationToken) { if (!File.Exists(attachment.FilePath)) throw new FileNotFoundException("Attachment not found.", attachment.FilePath); @@ -376,21 +392,71 @@ private static async Task LoadEmbedAsync(ModelSession modelSessio cancellationToken.ThrowIfCancellationRequested(); - SafeMtmdEmbed? embed; - if (attachment.Kind == AttachmentKind.Audio) + try { - var data = await File.ReadAllBytesAsync(attachment.FilePath, cancellationToken); - embed = clipModel.LoadMediaStandalone(data); + SafeMtmdEmbed? embed; + if (attachment.Kind == AttachmentKind.Audio) + { + var data = await File.ReadAllBytesAsync(attachment.FilePath, cancellationToken); + _logger.LogInformation( + "Loading audio attachment {AttachmentId} for session {SessionId} into MTMD from buffer: file={FileName}, contentType={ContentType}, bytes={ByteCount}, model={ModelName}", + attachment.Id, + sessionId, + attachment.FileName, + attachment.ContentType, + data.Length, + modelSession.ModelName); + embed = clipModel.LoadMediaStandalone(data); + } + else + { + _logger.LogInformation( + "Loading media attachment {AttachmentId} for session {SessionId} into MTMD from file: file={FileName}, contentType={ContentType}, path={FilePath}, model={ModelName}", + attachment.Id, + sessionId, + attachment.FileName, + attachment.ContentType, + attachment.FilePath, + modelSession.ModelName); + embed = clipModel.LoadMediaStandalone(attachment.FilePath); + } + + if (embed == null) + { + _logger.LogWarning( + "MTMD returned null embed for attachment {AttachmentId} in session {SessionId}: kind={Kind}, contentType={ContentType}, file={FileName}", + attachment.Id, + sessionId, + attachment.Kind, + attachment.ContentType, + attachment.FileName); + throw new Exception("Failed to prepare multimodal embedding."); + } + + _logger.LogInformation( + "MTMD embed prepared for attachment {AttachmentId} in session {SessionId}: isAudio={IsAudio}, nx={Nx}, ny={Ny}, byteCount={ByteCount}", + attachment.Id, + sessionId, + embed.IsAudio, + embed.Nx, + embed.Ny, + embed.ByteCount); + + return embed; } - else + catch (Exception ex) { - embed = clipModel.LoadMediaStandalone(attachment.FilePath); + _logger.LogWarning( + ex, + "Failed to load MTMD embed for attachment {AttachmentId} in session {SessionId}: kind={Kind}, contentType={ContentType}, file={FileName}, path={FilePath}", + attachment.Id, + sessionId, + attachment.Kind, + attachment.ContentType, + attachment.FileName, + attachment.FilePath); + throw; } - - if (embed == null) - throw new Exception("Failed to prepare multimodal embedding."); - - return embed; } private sealed record PreparedPrompt(string Prompt); diff --git a/LLama.Web/wwwroot/js/blazorChat.js b/LLama.Web/wwwroot/js/blazorChat.js index 2023637c1..2a0fd5f95 100644 --- a/LLama.Web/wwwroot/js/blazorChat.js +++ b/LLama.Web/wwwroot/js/blazorChat.js @@ -278,6 +278,40 @@ window.llamaChat = (() => { return merged; }; + const normalizeMonoSamples = (samples) => { + if (!samples || samples.length === 0) { + return samples; + } + + let peak = 0; + for (let i = 0; i < samples.length; i++) { + const amplitude = Math.abs(samples[i]); + if (amplitude > peak) { + peak = amplitude; + } + } + + // Treat near-silence as silence instead of amplifying background noise. + if (peak < 0.01) { + return samples; + } + + const targetPeak = 0.85; + const maxGain = 8.0; + const gain = Math.min(maxGain, targetPeak / peak); + if (gain <= 1.05) { + return samples; + } + + const normalized = new Float32Array(samples.length); + for (let i = 0; i < samples.length; i++) { + const value = samples[i] * gain; + normalized[i] = Math.max(-1, Math.min(1, value)); + } + + return normalized; + }; + const mixInputToMono = (inputBuffer) => { const channelCount = inputBuffer.numberOfChannels; const sampleCount = inputBuffer.length; @@ -302,7 +336,7 @@ window.llamaChat = (() => { const convertBlobToWav = async (blob) => { const AudioContextCtor = window.AudioContext || window.webkitAudioContext; if (!AudioContextCtor) { - return blob; + return null; } const audioContext = new AudioContextCtor(); @@ -313,7 +347,7 @@ window.llamaChat = (() => { return new Blob([wavBuffer], { type: "audio/wav" }); } catch (err) { console.warn("Failed to normalize recorded audio to WAV.", err); - return blob; + return null; } finally { if (typeof audioContext.close === "function") { await audioContext.close(); @@ -397,9 +431,11 @@ window.llamaChat = (() => { audioStream = await navigator.mediaDevices.getUserMedia({ audio: { channelCount: { ideal: 1 }, - echoCancellation: true, - noiseSuppression: true, - autoGainControl: true + sampleRate: { ideal: 16000 }, + sampleSize: { ideal: 16 }, + echoCancellation: false, + noiseSuppression: false, + autoGainControl: false } }); const AudioContextCtor = window.AudioContext || window.webkitAudioContext; @@ -454,7 +490,8 @@ window.llamaChat = (() => { return { previewUrl: "", mimeType: "", size: 0 }; } - const wavBuffer = encodeMonoWav(audioSampleRate || 16000, mergedSamples); + const normalizedSamples = normalizeMonoSamples(mergedSamples); + const wavBuffer = encodeMonoWav(audioSampleRate || 16000, normalizedSamples); const wavBlob = new Blob([wavBuffer], { type: "audio/wav" }); clearRecordedAudio(); @@ -482,9 +519,21 @@ window.llamaChat = (() => { const blobType = audioRecorder.mimeType || "audio/webm"; const rawBlob = new Blob(audioChunks, { type: blobType }); const uploadBlob = await convertBlobToWav(rawBlob); - const previewBlob = canPlayAudioType(rawBlob.type) ? rawBlob : uploadBlob; + if (!uploadBlob || uploadBlob.size === 0 || uploadBlob.type !== "audio/wav") { + clearRecordedAudio(); + cleanupAudioStream(); + audioRecorder = null; + audioChunks = []; + resolve({ + previewUrl: "", + mimeType: "", + size: 0, + error: "Unable to convert the recording to WAV for upload." + }); + return; + } - if (!previewBlob || previewBlob.size === 0) { + if (uploadBlob.size === 0) { clearRecordedAudio(); cleanupAudioStream(); audioRecorder = null; @@ -494,13 +543,13 @@ window.llamaChat = (() => { } clearRecordedAudio(); - recordedAudioBlob = rawBlob; + recordedAudioBlob = uploadBlob; recordedAudioUploadBlob = uploadBlob; - recordedAudioPreviewUrl = URL.createObjectURL(previewBlob); + recordedAudioPreviewUrl = URL.createObjectURL(uploadBlob); cleanupAudioStream(); audioRecorder = null; audioChunks = []; - resolve({ previewUrl: recordedAudioPreviewUrl, mimeType: previewBlob.type, size: previewBlob.size }); + resolve({ previewUrl: recordedAudioPreviewUrl, mimeType: uploadBlob.type, size: uploadBlob.size }); }; try { From 6670e77ec30363d3f569e67dc2fb0b276830a257 Mon Sep 17 00:00:00 2001 From: SignalRT Date: Mon, 23 Mar 2026 21:45:32 +0100 Subject: [PATCH 7/8] Move that convenience up out of SafeMtmdModelHandle Keep explicit media passing as the primary API. Treat the implicit queue as optional convenience only. Move that convenience up out of SafeMtmdModelHandle --- LLama/MtmdWeights.cs | 53 +++++++++++++++-- LLama/Native/SafeMtmdModelHandle.cs | 92 ++--------------------------- 2 files changed, 53 insertions(+), 92 deletions(-) diff --git a/LLama/MtmdWeights.cs b/LLama/MtmdWeights.cs index 1655601ca..7ee7de886 100644 --- a/LLama/MtmdWeights.cs +++ b/LLama/MtmdWeights.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; using LLama.Exceptions; @@ -13,6 +14,7 @@ public sealed class MtmdWeights : IDisposable { private readonly object _syncRoot = new(); + private readonly List _pendingMedia = new(); /// /// The native handle, which is used in the native APIs @@ -84,7 +86,11 @@ public static async Task LoadFromFileAsync(string mmProject, LLamaW public SafeMtmdEmbed LoadMedia(string path) { lock (_syncRoot) - return NativeHandle.LoadMediaFromFile(path); + { + var embed = NativeHandle.CreateMediaEmbedFromFile(path); + _pendingMedia.Add(embed); + return embed; + } } /// @@ -93,7 +99,11 @@ public SafeMtmdEmbed LoadMedia(string path) public SafeMtmdEmbed LoadMedia(ReadOnlySpan data) { lock (_syncRoot) - return NativeHandle.LoadMediaFromBuffer(data); + { + var embed = NativeHandle.CreateMediaEmbedFromBuffer(data); + _pendingMedia.Add(embed); + return embed; + } } /// @@ -120,7 +130,12 @@ public SafeMtmdEmbed LoadMediaStandalone(ReadOnlySpan data) public void ClearMedia() { lock (_syncRoot) - NativeHandle.ClearMedia(); + { + foreach (var media in _pendingMedia) + media.Dispose(); + + _pendingMedia.Clear(); + } } /// @@ -129,7 +144,20 @@ public void ClearMedia() public int Tokenize(string text, bool addSpecial, bool parseSpecial, out SafeMtmdInputChunks? chunks) { lock (_syncRoot) - return NativeHandle.Tokenize(text, addSpecial, parseSpecial, out chunks); + { + if (_pendingMedia.Count == 0) + return NativeHandle.Tokenize(text, addSpecial, parseSpecial, out chunks); + + var embeds = _pendingMedia.ToArray(); + try + { + return NativeHandle.Tokenize(text, addSpecial, parseSpecial, embeds, out chunks); + } + finally + { + ClearMediaInternal(); + } + } } /// @@ -189,5 +217,20 @@ public int DecodeImageChunk(IntPtr chunkPtr, SafeLLamaContextHandle llamaContext public int AudioBitrate => NativeHandle.GetAudioBitrate(); /// - public void Dispose() => NativeHandle.Dispose(); + public void Dispose() + { + lock (_syncRoot) + { + ClearMediaInternal(); + NativeHandle.Dispose(); + } + } + + private void ClearMediaInternal() + { + foreach (var media in _pendingMedia) + media.Dispose(); + + _pendingMedia.Clear(); + } } diff --git a/LLama/Native/SafeMtmdModelHandle.cs b/LLama/Native/SafeMtmdModelHandle.cs index 2b31f2d7e..0edad5aa3 100644 --- a/LLama/Native/SafeMtmdModelHandle.cs +++ b/LLama/Native/SafeMtmdModelHandle.cs @@ -1,5 +1,4 @@ using System; -using System.Collections.Generic; using System.IO; using LLama.Exceptions; @@ -12,9 +11,6 @@ namespace LLama.Native /// public sealed class SafeMtmdModelHandle : SafeLLamaHandleBase { - // Pending media embeddings queued for the next call to Tokenize. - private readonly List _pendingMedia = new(); - /// protected override bool ReleaseHandle() { @@ -66,39 +62,7 @@ public static SafeMtmdModelHandle LoadFromFile(string modelPath, LLamaWeights te } /// - /// Load media from disk and queue it for the next tokenize call. - /// - /// Absolute or relative path to the media asset. - /// Safe handle to the media embedding. - /// The model handle has been disposed. - /// The native loader failed to ingest the file. - public SafeMtmdEmbed LoadMediaFromFile(string path) - { - EnsureNotDisposed(); - - var embed = CreateMediaEmbedFromFile(path); - _pendingMedia.Add(embed); - return embed; - } - - /// - /// Load media from an in-memory buffer and queue it for the next tokenize call. - /// - /// Binary buffer containing the encoded media data. - /// Safe handle to the media embedding. - /// The model handle has been disposed. - /// The native loader failed to ingest the buffer contents. - public SafeMtmdEmbed LoadMediaFromBuffer(ReadOnlySpan buffer) - { - EnsureNotDisposed(); - - var embed = CreateMediaEmbedFromBuffer(buffer); - _pendingMedia.Add(embed); - return embed; - } - - /// - /// Create a standalone media embedding from disk without queueing it for the next tokenize call. + /// Create a media embedding from disk. /// /// Path to the media file on disk. /// Safe handle to the prepared media embedding. @@ -113,7 +77,7 @@ public SafeMtmdEmbed CreateMediaEmbedFromFile(string path) } /// - /// Create a standalone media embedding from an in-memory buffer without queueing it for the next tokenize call. + /// Create a media embedding from an in-memory buffer. /// /// Binary buffer containing the encoded media data. /// Safe handle to the prepared media embedding. @@ -128,17 +92,7 @@ public SafeMtmdEmbed CreateMediaEmbedFromBuffer(ReadOnlySpan buffer) } /// - /// Disposes and clears any media buffers currently queued for tokenization. - /// - public void ClearMedia() - { - foreach (var media in _pendingMedia) - media.Dispose(); - _pendingMedia.Clear(); - } - - /// - /// Tokenize a prompt alongside the pending media buffers. Pending media is cleared on success. + /// Tokenize a prompt without any media embeddings. /// /// Prompt text to tokenize. /// Whether to append special tokens automatically. @@ -148,43 +102,7 @@ public void ClearMedia() /// The model handle has been disposed. public int Tokenize(string text, bool addSpecial, bool parseSpecial, out SafeMtmdInputChunks? chunks) { - EnsureNotDisposed(); - - chunks = null; - // Allocate the chunk container before invoking the native tokenizer. - var output = NativeApi.mtmd_input_chunks_init(); - if (output == IntPtr.Zero) - throw new RuntimeError("Failed to allocate mtmd_input_chunks."); - - // Collect native pointers to the queued media embeddings. - var bitmapHandles = new IntPtr[_pendingMedia.Count]; - for (var i = 0; i < _pendingMedia.Count; i++) - bitmapHandles[i] = _pendingMedia[i].NativePtr; - - try - { - var result = NativeApi.mtmd_tokenize(this, output, text, addSpecial, parseSpecial, bitmapHandles, (nuint)bitmapHandles.Length); - - if (result == 0) - { - chunks = new SafeMtmdInputChunks(output); - } - else - { - NativeApi.mtmd_input_chunks_free(output); - } - - ClearMedia(); - - return result; - } - finally - { - // bitmapHandles contains dangerous references to media embeddings objects, we must ensure that these - // remain valid for the complete time that the bitmapHandles array is in use. KeepAlive guarantees the - // array remains valid to this point, and thus all of the safe handles too. - GC.KeepAlive(bitmapHandles); - } + return Tokenize(text, addSpecial, parseSpecial, ReadOnlySpan.Empty, out chunks); } /// @@ -231,7 +149,7 @@ public int Tokenize(string text, bool addSpecial, bool parseSpecial, ReadOnlySpa /// /// Evaluate a batch of chunks using the helper (mirrors mtmd-helper eval logic). /// - /// Chunk collection produced by . + /// Chunk collection produced by . /// Context handle that receives the evaluated tokens. /// Number of past tokens; updated when evaluation succeeds. /// Sequence identifier used for KV cache management. From b34cf13aa3afe536f9b202285d7aa8e8807a15b3 Mon Sep 17 00:00:00 2001 From: SignalRT Date: Thu, 26 Mar 2026 21:48:06 +0100 Subject: [PATCH 8/8] Solve Copilot comments --- LLama.Web/Controllers/AttachmentController.cs | 39 ++++-- LLama.Web/Pages/Shared/_Layout.cshtml | 2 - LLama.Web/Services/AttachmentService.cs | 120 +++++++++++++----- LLama.Web/libman.json | 14 -- LLama.Web/wwwroot/js/blazorChat.js | 3 - LLama.Web/wwwroot/js/sessionConnectionChat.js | 2 - 6 files changed, 116 insertions(+), 64 deletions(-) diff --git a/LLama.Web/Controllers/AttachmentController.cs b/LLama.Web/Controllers/AttachmentController.cs index fe57dc6a2..72dca9c32 100644 --- a/LLama.Web/Controllers/AttachmentController.cs +++ b/LLama.Web/Controllers/AttachmentController.cs @@ -9,14 +9,16 @@ namespace LLama.Web.Controllers; public class AttachmentController : ControllerBase { private readonly IAttachmentService _attachmentService; + private readonly IModelSessionService _modelSessionService; - public AttachmentController(IAttachmentService attachmentService) + public AttachmentController(IAttachmentService attachmentService, IModelSessionService modelSessionService) { _attachmentService = attachmentService; + _modelSessionService = modelSessionService; } [HttpPost] - [RequestSizeLimit(256_000_000)] + [RequestSizeLimit(AttachmentService.MaxUploadSizeBytes)] public async Task> Upload([FromForm] string connectionId, [FromForm] List files, CancellationToken cancellationToken) { if (string.IsNullOrWhiteSpace(connectionId)) @@ -25,6 +27,9 @@ public async Task> Upload([FromForm] string if (files is null || files.Count == 0) return BadRequest("No files provided."); + if (await _modelSessionService.GetAsync(connectionId) is null) + return BadRequest("Unknown or inactive connectionId."); + try { var result = await _attachmentService.SaveAsync(connectionId, files, cancellationToken); @@ -37,22 +42,32 @@ public async Task> Upload([FromForm] string } [HttpGet("{connectionId}/{id}")] - public ActionResult Download(string connectionId, string id) + public async Task Download(string connectionId, string id) { if (string.IsNullOrWhiteSpace(connectionId) || string.IsNullOrWhiteSpace(id)) return BadRequest("Missing connectionId or attachment id."); - var attachment = _attachmentService.GetAttachment(connectionId, id); - if (attachment == null || string.IsNullOrWhiteSpace(attachment.FilePath)) - return NotFound(); + try + { + if (await _modelSessionService.GetAsync(connectionId) is null) + return NotFound(); - if (!System.IO.File.Exists(attachment.FilePath)) - return NotFound(); + var attachment = _attachmentService.GetAttachment(connectionId, id); + if (attachment == null || string.IsNullOrWhiteSpace(attachment.FilePath)) + return NotFound(); - var contentType = string.IsNullOrWhiteSpace(attachment.ContentType) - ? "application/octet-stream" - : attachment.ContentType; + if (!System.IO.File.Exists(attachment.FilePath)) + return NotFound(); - return PhysicalFile(attachment.FilePath, contentType, attachment.FileName); + var contentType = string.IsNullOrWhiteSpace(attachment.ContentType) + ? "application/octet-stream" + : attachment.ContentType; + + return PhysicalFile(attachment.FilePath, contentType, attachment.FileName); + } + catch (InvalidOperationException ex) + { + return BadRequest(ex.Message); + } } } diff --git a/LLama.Web/Pages/Shared/_Layout.cshtml b/LLama.Web/Pages/Shared/_Layout.cshtml index 417f38661..6fcd77f1b 100644 --- a/LLama.Web/Pages/Shared/_Layout.cshtml +++ b/LLama.Web/Pages/Shared/_Layout.cshtml @@ -5,7 +5,6 @@ @ViewData["Title"] - LLamaSharp Web - @@ -17,7 +16,6 @@ - diff --git a/LLama.Web/Services/AttachmentService.cs b/LLama.Web/Services/AttachmentService.cs index a5d46e3e5..d76e51be5 100644 --- a/LLama.Web/Services/AttachmentService.cs +++ b/LLama.Web/Services/AttachmentService.cs @@ -12,7 +12,8 @@ namespace LLama.Web.Services; public class AttachmentService : IAttachmentService { private const int MaxExtractedCharacters = 12000; - private const long MaxUploadSize = 512L * 1024 * 1024; + public const long MaxUploadSizeBytes = 512L * 1024 * 1024; + private static readonly StringComparison PathComparison = OperatingSystem.IsWindows() ? StringComparison.OrdinalIgnoreCase : StringComparison.Ordinal; private readonly IWebHostEnvironment _environment; private readonly ILogger _logger; private readonly string _uploadsRoot; @@ -31,14 +32,13 @@ public AttachmentService(IWebHostEnvironment environment, ILogger SaveAsync(string connectionId, IEnumerable files, CancellationToken cancellationToken) { - if (string.IsNullOrWhiteSpace(connectionId)) - throw new ArgumentException("Connection id is required.", nameof(connectionId)); + var normalizedConnectionId = NormalizeConnectionId(connectionId); ValidateUploads(files); var result = new AttachmentUploadResult(); - var storage = _attachments.GetOrAdd(connectionId, _ => new ConcurrentDictionary()); - var root = Path.Combine(_uploadsRoot, connectionId); + var storage = _attachments.GetOrAdd(normalizedConnectionId, _ => new ConcurrentDictionary()); + var root = GetConnectionRoot(normalizedConnectionId); Directory.CreateDirectory(root); foreach (var file in files) @@ -58,7 +58,7 @@ public async Task SaveAsync(string connectionId, IEnumer var info = new AttachmentInfo { Id = id, - ConnectionId = connectionId, + ConnectionId = normalizedConnectionId, FileName = safeName, FilePath = filePath, ContentType = file.ContentType, @@ -74,7 +74,7 @@ public async Task SaveAsync(string connectionId, IEnumer _logger.LogInformation( "Saved form attachment {AttachmentId} for session {SessionId}: name={FileName}, contentType={ContentType}, size={SizeBytes}, kind={Kind}, path={FilePath}", info.Id, - connectionId, + normalizedConnectionId, info.FileName, info.ContentType, info.SizeBytes, @@ -90,14 +90,13 @@ public async Task SaveAsync(string connectionId, IEnumer public async Task SaveAsync(string connectionId, IEnumerable files, CancellationToken cancellationToken) { - if (string.IsNullOrWhiteSpace(connectionId)) - throw new ArgumentException("Connection id is required.", nameof(connectionId)); + var normalizedConnectionId = NormalizeConnectionId(connectionId); ValidateUploads(files); var result = new AttachmentUploadResult(); - var storage = _attachments.GetOrAdd(connectionId, _ => new ConcurrentDictionary()); - var root = Path.Combine(_uploadsRoot, connectionId); + var storage = _attachments.GetOrAdd(normalizedConnectionId, _ => new ConcurrentDictionary()); + var root = GetConnectionRoot(normalizedConnectionId); Directory.CreateDirectory(root); foreach (var file in files) @@ -109,7 +108,7 @@ public async Task SaveAsync(string connectionId, IEnumer var safeName = Path.GetFileName(file.Name); var filePath = Path.Combine(root, $"{id}-{safeName}"); - await using (var input = file.OpenReadStream(maxAllowedSize: MaxUploadSize, cancellationToken)) + await using (var input = file.OpenReadStream(maxAllowedSize: MaxUploadSizeBytes, cancellationToken)) await using (var output = new FileStream(filePath, FileMode.Create, FileAccess.Write, FileShare.None, 81920, useAsync: true)) { await input.CopyToAsync(output, cancellationToken); @@ -118,7 +117,7 @@ public async Task SaveAsync(string connectionId, IEnumer var info = new AttachmentInfo { Id = id, - ConnectionId = connectionId, + ConnectionId = normalizedConnectionId, FileName = safeName, FilePath = filePath, ContentType = file.ContentType, @@ -134,7 +133,7 @@ public async Task SaveAsync(string connectionId, IEnumer _logger.LogInformation( "Saved browser attachment {AttachmentId} for session {SessionId}: name={FileName}, contentType={ContentType}, size={SizeBytes}, kind={Kind}, path={FilePath}", info.Id, - connectionId, + normalizedConnectionId, info.FileName, info.ContentType, info.SizeBytes, @@ -150,10 +149,12 @@ public async Task SaveAsync(string connectionId, IEnumer public IReadOnlyList GetAttachments(string connectionId, IEnumerable ids) { - if (string.IsNullOrWhiteSpace(connectionId) || ids is null) + if (ids is null) return Array.Empty(); - if (!_attachments.TryGetValue(connectionId, out var storage)) + var normalizedConnectionId = NormalizeConnectionId(connectionId); + + if (!_attachments.TryGetValue(normalizedConnectionId, out var storage)) return Array.Empty(); var list = new List(); @@ -171,7 +172,9 @@ public AttachmentInfo GetAttachment(string connectionId, string id) if (string.IsNullOrWhiteSpace(connectionId) || string.IsNullOrWhiteSpace(id)) return null; - if (!_attachments.TryGetValue(connectionId, out var storage)) + var normalizedConnectionId = NormalizeConnectionId(connectionId); + + if (!_attachments.TryGetValue(normalizedConnectionId, out var storage)) return null; return storage.TryGetValue(id, out var info) ? info : null; @@ -182,9 +185,11 @@ public Task CleanupAsync(string connectionId) if (string.IsNullOrWhiteSpace(connectionId)) return Task.CompletedTask; - if (_attachments.TryRemove(connectionId, out _)) + var normalizedConnectionId = NormalizeConnectionId(connectionId); + + if (_attachments.TryRemove(normalizedConnectionId, out _)) { - var root = Path.Combine(_uploadsRoot, connectionId); + var root = GetConnectionRoot(normalizedConnectionId); if (Directory.Exists(root)) Directory.Delete(root, recursive: true); } @@ -262,30 +267,83 @@ private static bool IsAllowedUpload(string contentType, string extension) private static void ValidateUploads(IEnumerable files) { - var invalid = files - .Where(file => file != null) - .Where(file => !IsAllowedUpload(file.ContentType?.ToLowerInvariant() ?? string.Empty, Path.GetExtension(file.FileName).ToLowerInvariant())) - .Select(file => file.FileName) - .ToList(); + var invalid = new List(); + var tooLarge = new List(); + + foreach (var file in files.Where(file => file != null)) + { + if (!IsAllowedUpload(file.ContentType?.ToLowerInvariant() ?? string.Empty, Path.GetExtension(file.FileName).ToLowerInvariant())) + invalid.Add(file.FileName); + + if (file.Length > MaxUploadSizeBytes) + tooLarge.Add(file.FileName); + } + + if (tooLarge.Count > 0) + throw new InvalidOperationException($"Files exceed the {FormatMaxUploadSize()} limit: {string.Join(", ", tooLarge)}."); if (invalid.Count == 0) return; - throw new InvalidOperationException($"Unsupported files: {string.Join(", ", invalid)}. Use PDF, DOCX, or images."); + throw new InvalidOperationException($"Unsupported files: {string.Join(", ", invalid)}. Use PDF, DOCX, images, or audio."); } private static void ValidateUploads(IEnumerable files) { - var invalid = files - .Where(file => file != null) - .Where(file => !IsAllowedUpload(file.ContentType?.ToLowerInvariant() ?? string.Empty, Path.GetExtension(file.Name).ToLowerInvariant())) - .Select(file => file.Name) - .ToList(); + var invalid = new List(); + var tooLarge = new List(); + + foreach (var file in files.Where(file => file != null)) + { + if (!IsAllowedUpload(file.ContentType?.ToLowerInvariant() ?? string.Empty, Path.GetExtension(file.Name).ToLowerInvariant())) + invalid.Add(file.Name); + + if (file.Size > MaxUploadSizeBytes) + tooLarge.Add(file.Name); + } + + if (tooLarge.Count > 0) + throw new InvalidOperationException($"Files exceed the {FormatMaxUploadSize()} limit: {string.Join(", ", tooLarge)}."); if (invalid.Count == 0) return; - throw new InvalidOperationException($"Unsupported files: {string.Join(", ", invalid)}. Use PDF, DOCX, or images."); + throw new InvalidOperationException($"Unsupported files: {string.Join(", ", invalid)}. Use PDF, DOCX, images, or audio."); + } + + private string GetConnectionRoot(string connectionId) + { + var fullUploadsRoot = Path.GetFullPath(_uploadsRoot); + var fullConnectionRoot = Path.GetFullPath(Path.Combine(fullUploadsRoot, connectionId)); + var uploadsRootWithSeparator = fullUploadsRoot.EndsWith(Path.DirectorySeparatorChar) + ? fullUploadsRoot + : fullUploadsRoot + Path.DirectorySeparatorChar; + + if (!fullConnectionRoot.StartsWith(uploadsRootWithSeparator, PathComparison)) + throw new InvalidOperationException("Invalid connection id."); + + return fullConnectionRoot; + } + + private static string NormalizeConnectionId(string connectionId) + { + if (string.IsNullOrWhiteSpace(connectionId)) + throw new InvalidOperationException("Connection id is required."); + + foreach (var character in connectionId) + { + if (char.IsAsciiLetterOrDigit(character) || character is '-' or '_') + continue; + + throw new InvalidOperationException("Invalid connection id."); + } + + return connectionId; + } + + private static string FormatMaxUploadSize() + { + return $"{MaxUploadSizeBytes / (1024 * 1024)} MB"; } private static void ExtractPdfText(AttachmentInfo info) diff --git a/LLama.Web/libman.json b/LLama.Web/libman.json index 8a7b77c9b..008384d32 100644 --- a/LLama.Web/libman.json +++ b/LLama.Web/libman.json @@ -6,15 +6,6 @@ "library": "jquery@3.7.1", "destination": "wwwroot/lib/jquery/" }, - { - "library": "katex@0.16.11", - "provider": "unpkg", - "destination": "wwwroot/lib/katex/", - "files": [ - "dist/katex.min.css", - "dist/katex.min.js" - ] - }, { "library": "markdown-it@14.1.0", "provider": "unpkg", @@ -23,11 +14,6 @@ "dist/markdown-it.min.js" ] }, - { - "library": "markdown-it-katex@2.0.3", - "provider": "unpkg", - "destination": "wwwroot/lib/markdown-it-katex/" - }, { "library": "markdown-it-task-lists@2.1.1", "provider": "unpkg", diff --git a/LLama.Web/wwwroot/js/blazorChat.js b/LLama.Web/wwwroot/js/blazorChat.js index 2a0fd5f95..3fc376904 100644 --- a/LLama.Web/wwwroot/js/blazorChat.js +++ b/LLama.Web/wwwroot/js/blazorChat.js @@ -49,9 +49,6 @@ window.llamaChat = (() => { breaks: true }); - if (window.markdownitKatex) { - markdown.use(window.markdownitKatex); - } if (window.markdownitTaskLists) { markdown.use(window.markdownitTaskLists); } diff --git a/LLama.Web/wwwroot/js/sessionConnectionChat.js b/LLama.Web/wwwroot/js/sessionConnectionChat.js index 0afd1f35b..cc1e03cc9 100644 --- a/LLama.Web/wwwroot/js/sessionConnectionChat.js +++ b/LLama.Web/wwwroot/js/sessionConnectionChat.js @@ -292,8 +292,6 @@ const createConnectionSessionChat = () => { typographer: true }); - if (window.markdownitKatex) - md.use(window.markdownitKatex); if (window.markdownitTaskLists) md.use(window.markdownitTaskLists, { enabled: true, label: true }); if (window.markdownitFootnote)