From 0525b91538f3a4ceb13cd6b9a4fbfaffe9811ca0 Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Thu, 18 Dec 2025 17:47:36 +0100 Subject: [PATCH 1/5] Complete feature 00007: batching, retries, managed identity - Add EmbeddingsConfig.batchSize + validate provider configs - Implement shared HttpRetryPolicy (Retry-After + exponential backoff) and use it in OpenAI/Azure/HF/Ollama - Support Azure OpenAI managed identity via DefaultAzureCredential - Support HuggingFace HF_TOKEN fallback - Update docs + expand unit/integration tests; stabilize env-var tests --- src/Core/Config/AppConfig.cs | 2 + .../Embeddings/AzureOpenAIEmbeddingsConfig.cs | 5 + .../Config/Embeddings/EmbeddingsConfig.cs | 8 + .../Embeddings/HuggingFaceEmbeddingsConfig.cs | 11 +- .../Embeddings/OllamaEmbeddingsConfig.cs | 5 + .../Embeddings/OpenAIEmbeddingsConfig.cs | 5 + src/Core/Constants.cs | 21 +++ src/Core/Core.csproj | 1 + .../AzureOpenAIEmbeddingGenerator.cs | 97 +++++++++-- .../HuggingFaceEmbeddingGenerator.cs | 53 +++++- .../Providers/OllamaEmbeddingGenerator.cs | 22 ++- .../Providers/OpenAIEmbeddingGenerator.cs | 53 +++++- src/Core/Http/HttpRetryPolicy.cs | 151 ++++++++++++++++++ .../Services/EmbeddingGeneratorFactory.cs | 14 +- .../AzureOpenAIEmbeddingGeneratorTests.cs | 93 +++++++++-- .../HuggingFaceEmbeddingGeneratorTests.cs | 27 ++-- .../OllamaEmbeddingGeneratorTests.cs | 37 +++-- .../OpenAIEmbeddingGeneratorTests.cs | 76 +++++++-- tests/Core.Tests/Http/HttpRetryPolicyTests.cs | 52 ++++++ .../Logging/EnvironmentDetectorTests.cs | 1 + .../SensitiveDataScrubbingPolicyTests.cs | 1 + tests/Core.Tests/TestCollections.cs | 8 + .../Integration/SearchProcessTests.cs | 49 ++++-- .../EmbeddingGeneratorFactoryTests.cs | 32 ++++ 24 files changed, 725 insertions(+), 99 deletions(-) create mode 100644 src/Core/Http/HttpRetryPolicy.cs create mode 100644 tests/Core.Tests/Http/HttpRetryPolicyTests.cs create mode 100644 tests/Core.Tests/TestCollections.cs diff --git a/src/Core/Config/AppConfig.cs b/src/Core/Config/AppConfig.cs index 24c9a624a..7277ee0ec 100644 --- a/src/Core/Config/AppConfig.cs +++ b/src/Core/Config/AppConfig.cs @@ -10,7 +10,9 @@ namespace KernelMemory.Core.Config; /// Root configuration for Kernel Memory application /// Loaded from ~/.km/config.json or custom path /// +#pragma warning disable CA1724 // Conflicts with Microsoft.Identity.Client.AppConfig when Azure.Identity is referenced. public sealed class AppConfig : IValidatable +#pragma warning restore CA1724 { /// /// Named memory nodes (e.g., "personal", "work") diff --git a/src/Core/Config/Embeddings/AzureOpenAIEmbeddingsConfig.cs b/src/Core/Config/Embeddings/AzureOpenAIEmbeddingsConfig.cs index e2cfc5475..a7411b6d6 100644 --- a/src/Core/Config/Embeddings/AzureOpenAIEmbeddingsConfig.cs +++ b/src/Core/Config/Embeddings/AzureOpenAIEmbeddingsConfig.cs @@ -81,5 +81,10 @@ public override void Validate(string path) throw new ConfigException(path, "Azure OpenAI: specify either ApiKey or UseManagedIdentity, not both"); } + + if (this.BatchSize < 1) + { + throw new ConfigException($"{path}.BatchSize", "BatchSize must be >= 1"); + } } } diff --git a/src/Core/Config/Embeddings/EmbeddingsConfig.cs b/src/Core/Config/Embeddings/EmbeddingsConfig.cs index 4fae204d1..d072b7acc 100644 --- a/src/Core/Config/Embeddings/EmbeddingsConfig.cs +++ b/src/Core/Config/Embeddings/EmbeddingsConfig.cs @@ -21,6 +21,14 @@ public abstract class EmbeddingsConfig : IValidatable [JsonIgnore] public abstract EmbeddingsTypes Type { get; } + /// + /// Maximum number of texts to send per embeddings API request. + /// Providers that support batch requests should chunk input using this size. + /// Default: 10. + /// + [JsonPropertyName("batchSize")] + public int BatchSize { get; set; } = Constants.EmbeddingDefaults.DefaultBatchSize; + /// /// Validates the embeddings configuration /// diff --git a/src/Core/Config/Embeddings/HuggingFaceEmbeddingsConfig.cs b/src/Core/Config/Embeddings/HuggingFaceEmbeddingsConfig.cs index 6e750c81f..0be95dcbe 100644 --- a/src/Core/Config/Embeddings/HuggingFaceEmbeddingsConfig.cs +++ b/src/Core/Config/Embeddings/HuggingFaceEmbeddingsConfig.cs @@ -44,9 +44,16 @@ public override void Validate(string path) throw new ConfigException($"{path}.Model", "HuggingFace model name is required"); } - if (string.IsNullOrWhiteSpace(this.ApiKey)) + // ApiKey can be provided via config or HF_TOKEN environment variable + if (string.IsNullOrWhiteSpace(this.ApiKey) && + string.IsNullOrWhiteSpace(Environment.GetEnvironmentVariable("HF_TOKEN"))) { - throw new ConfigException($"{path}.ApiKey", "HuggingFace API key is required"); + throw new ConfigException($"{path}.ApiKey", "HuggingFace API key is required (set ApiKey or HF_TOKEN)"); + } + + if (this.BatchSize < 1) + { + throw new ConfigException($"{path}.BatchSize", "BatchSize must be >= 1"); } if (string.IsNullOrWhiteSpace(this.BaseUrl)) diff --git a/src/Core/Config/Embeddings/OllamaEmbeddingsConfig.cs b/src/Core/Config/Embeddings/OllamaEmbeddingsConfig.cs index e938b73e9..6d16e2dfd 100644 --- a/src/Core/Config/Embeddings/OllamaEmbeddingsConfig.cs +++ b/src/Core/Config/Embeddings/OllamaEmbeddingsConfig.cs @@ -39,6 +39,11 @@ public override void Validate(string path) throw new ConfigException($"{path}.BaseUrl", "Ollama base URL is required"); } + if (this.BatchSize < 1) + { + throw new ConfigException($"{path}.BatchSize", "BatchSize must be >= 1"); + } + if (!Uri.TryCreate(this.BaseUrl, UriKind.Absolute, out _)) { throw new ConfigException($"{path}.BaseUrl", diff --git a/src/Core/Config/Embeddings/OpenAIEmbeddingsConfig.cs b/src/Core/Config/Embeddings/OpenAIEmbeddingsConfig.cs index fa70cbe3d..257c2a3c6 100644 --- a/src/Core/Config/Embeddings/OpenAIEmbeddingsConfig.cs +++ b/src/Core/Config/Embeddings/OpenAIEmbeddingsConfig.cs @@ -45,6 +45,11 @@ public override void Validate(string path) throw new ConfigException($"{path}.ApiKey", "OpenAI API key is required"); } + if (this.BatchSize < 1) + { + throw new ConfigException($"{path}.BatchSize", "BatchSize must be >= 1"); + } + if (!string.IsNullOrWhiteSpace(this.BaseUrl) && !Uri.TryCreate(this.BaseUrl, UriKind.Absolute, out _)) { diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index e611b6624..0ad6baf3c 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -210,6 +210,27 @@ public static bool TryGetDimensions(string modelName, out int dimensions) } } + /// + /// Constants for HTTP retry/backoff used by external providers (embeddings, etc.). + /// + public static class HttpRetryDefaults + { + /// + /// Maximum attempts including the first try. + /// + public const int MaxAttempts = 5; + + /// + /// Base delay for exponential backoff. + /// + public const int BaseDelayMs = 200; + + /// + /// Maximum delay between attempts. + /// + public const int MaxDelayMs = 5000; + } + /// /// Constants for the logging system including file rotation, log levels, /// and output formatting. diff --git a/src/Core/Core.csproj b/src/Core/Core.csproj index 0518136d4..ef74afee6 100644 --- a/src/Core/Core.csproj +++ b/src/Core/Core.csproj @@ -7,6 +7,7 @@ + diff --git a/src/Core/Embeddings/Providers/AzureOpenAIEmbeddingGenerator.cs b/src/Core/Embeddings/Providers/AzureOpenAIEmbeddingGenerator.cs index fade83bb0..86918e548 100644 --- a/src/Core/Embeddings/Providers/AzureOpenAIEmbeddingGenerator.cs +++ b/src/Core/Embeddings/Providers/AzureOpenAIEmbeddingGenerator.cs @@ -1,7 +1,11 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Net.Http.Headers; using System.Net.Http.Json; using System.Text.Json.Serialization; +using Azure.Core; +using Azure.Identity; using KernelMemory.Core.Config.Enums; +using KernelMemory.Core.Http; using Microsoft.Extensions.Logging; namespace KernelMemory.Core.Embeddings.Providers; @@ -9,15 +13,19 @@ namespace KernelMemory.Core.Embeddings.Providers; /// /// Azure OpenAI embedding generator implementation. /// Communicates with Azure OpenAI Service. -/// Supports API key authentication (managed identity would require Azure.Identity package). +/// Supports API key authentication or managed identity via . /// public sealed class AzureOpenAIEmbeddingGenerator : IEmbeddingGenerator { private readonly HttpClient _httpClient; private readonly string _endpoint; private readonly string _deployment; - private readonly string _apiKey; + private readonly string? _apiKey; + private readonly bool _useManagedIdentity; + private readonly TokenCredential? _credential; + private readonly int _batchSize; private readonly ILogger _logger; + private readonly Func _delayAsync; /// public EmbeddingsTypes ProviderType => EmbeddingsTypes.AzureOpenAI; @@ -38,35 +46,52 @@ public sealed class AzureOpenAIEmbeddingGenerator : IEmbeddingGenerator /// Azure OpenAI endpoint (e.g., https://myservice.openai.azure.com). /// Deployment name in Azure. /// Model name for identification. - /// Azure OpenAI API key. + /// Azure OpenAI API key (required unless is true). /// Vector dimensions produced by the model. /// Whether vectors are normalized. /// Logger instance. + /// Maximum number of texts per API request. + /// Whether to authenticate using managed identity. + /// Optional token credential (used for testing); defaults to . + /// Optional delay function for retries (used for fast unit tests). public AzureOpenAIEmbeddingGenerator( HttpClient httpClient, string endpoint, string deployment, string model, - string apiKey, + string? apiKey, int vectorDimensions, bool isNormalized, - ILogger logger) + ILogger logger, + int batchSize, + bool useManagedIdentity, + TokenCredential? credential = null, + Func? delayAsync = null) { ArgumentNullException.ThrowIfNull(httpClient, nameof(httpClient)); ArgumentNullException.ThrowIfNull(endpoint, nameof(endpoint)); ArgumentNullException.ThrowIfNull(deployment, nameof(deployment)); ArgumentNullException.ThrowIfNull(model, nameof(model)); - ArgumentNullException.ThrowIfNull(apiKey, nameof(apiKey)); ArgumentNullException.ThrowIfNull(logger, nameof(logger)); + ArgumentOutOfRangeException.ThrowIfLessThan(batchSize, 1, nameof(batchSize)); this._httpClient = httpClient; this._endpoint = endpoint.TrimEnd('/'); this._deployment = deployment; this._apiKey = apiKey; + this._useManagedIdentity = useManagedIdentity; + this._credential = credential; + this._batchSize = batchSize; this.ModelName = model; this.VectorDimensions = vectorDimensions; this.IsNormalized = isNormalized; this._logger = logger; + this._delayAsync = delayAsync ?? Task.Delay; + + if (!this._useManagedIdentity && string.IsNullOrWhiteSpace(this._apiKey)) + { + throw new ArgumentException("Azure OpenAI API key is required when not using managed identity", nameof(apiKey)); + } this._logger.LogDebug("AzureOpenAIEmbeddingGenerator initialized: {Endpoint}, deployment: {Deployment}, model: {Model}", this._endpoint, this._deployment, this.ModelName); @@ -88,6 +113,18 @@ public async Task GenerateAsync(IEnumerable texts, Ca return []; } + var allResults = new List(textArray.Length); + foreach (var chunk in Chunk(textArray, this._batchSize)) + { + var chunkResults = await this.GenerateBatchAsync(chunk, ct).ConfigureAwait(false); + allResults.AddRange(chunkResults); + } + + return allResults.ToArray(); + } + + private async Task GenerateBatchAsync(string[] textArray, CancellationToken ct) + { var url = $"{this._endpoint}/openai/deployments/{this._deployment}/embeddings?api-version={Constants.EmbeddingDefaults.AzureOpenAIApiVersion}"; var request = new AzureEmbeddingRequest @@ -95,14 +132,34 @@ public async Task GenerateAsync(IEnumerable texts, Ca Input = textArray }; - using var httpRequest = new HttpRequestMessage(HttpMethod.Post, url); - httpRequest.Headers.Add("api-key", this._apiKey); - httpRequest.Content = JsonContent.Create(request); + var bearerToken = this._useManagedIdentity + ? await this.GetManagedIdentityTokenAsync(ct).ConfigureAwait(false) + : null; this._logger.LogTrace("Calling Azure OpenAI embeddings API: deployment={Deployment}, batch size: {BatchSize}", this._deployment, textArray.Length); - var response = await this._httpClient.SendAsync(httpRequest, ct).ConfigureAwait(false); + using var response = await HttpRetryPolicy.SendAsync( + this._httpClient, + requestFactory: () => + { + var httpRequest = new HttpRequestMessage(HttpMethod.Post, url); + if (bearerToken != null) + { + httpRequest.Headers.Authorization = new AuthenticationHeaderValue("Bearer", bearerToken); + } + else + { + httpRequest.Headers.Add("api-key", this._apiKey); + } + + httpRequest.Content = JsonContent.Create(request); + return httpRequest; + }, + this._logger, + ct, + delayAsync: this._delayAsync).ConfigureAwait(false); + response.EnsureSuccessStatusCode(); var result = await response.Content.ReadFromJsonAsync(ct).ConfigureAwait(false); @@ -141,6 +198,26 @@ public async Task GenerateAsync(IEnumerable texts, Ca return results; } + private async Task GetManagedIdentityTokenAsync(CancellationToken ct) + { + var credential = this._credential ?? new DefaultAzureCredential(); + var token = await credential.GetTokenAsync( + new TokenRequestContext(["https://cognitiveservices.azure.com/.default"]), + ct).ConfigureAwait(false); + return token.Token; + } + + private static IEnumerable Chunk(string[] items, int chunkSize) + { + for (int i = 0; i < items.Length; i += chunkSize) + { + var length = Math.Min(chunkSize, items.Length - i); + var chunk = new string[length]; + Array.Copy(items, i, chunk, 0, length); + yield return chunk; + } + } + /// /// Request body for Azure OpenAI embeddings API. /// diff --git a/src/Core/Embeddings/Providers/HuggingFaceEmbeddingGenerator.cs b/src/Core/Embeddings/Providers/HuggingFaceEmbeddingGenerator.cs index 9fd4537a6..d883012f4 100644 --- a/src/Core/Embeddings/Providers/HuggingFaceEmbeddingGenerator.cs +++ b/src/Core/Embeddings/Providers/HuggingFaceEmbeddingGenerator.cs @@ -3,6 +3,7 @@ using System.Net.Http.Json; using System.Text.Json.Serialization; using KernelMemory.Core.Config.Enums; +using KernelMemory.Core.Http; using Microsoft.Extensions.Logging; namespace KernelMemory.Core.Embeddings.Providers; @@ -17,7 +18,9 @@ public sealed class HuggingFaceEmbeddingGenerator : IEmbeddingGenerator private readonly HttpClient _httpClient; private readonly string _apiKey; private readonly string _baseUrl; + private readonly int _batchSize; private readonly ILogger _logger; + private readonly Func _delayAsync; /// public EmbeddingsTypes ProviderType => EmbeddingsTypes.HuggingFace; @@ -41,6 +44,8 @@ public sealed class HuggingFaceEmbeddingGenerator : IEmbeddingGenerator /// Whether vectors are normalized. /// Optional custom base URL for inference endpoints. /// Logger instance. + /// Maximum number of texts per API request. + /// Optional delay function for retries (used for fast unit tests). public HuggingFaceEmbeddingGenerator( HttpClient httpClient, string apiKey, @@ -48,21 +53,26 @@ public HuggingFaceEmbeddingGenerator( int vectorDimensions, bool isNormalized, string? baseUrl, - ILogger logger) + ILogger logger, + int batchSize, + Func? delayAsync = null) { ArgumentNullException.ThrowIfNull(httpClient, nameof(httpClient)); ArgumentNullException.ThrowIfNull(apiKey, nameof(apiKey)); ArgumentException.ThrowIfNullOrEmpty(apiKey, nameof(apiKey)); ArgumentNullException.ThrowIfNull(model, nameof(model)); ArgumentNullException.ThrowIfNull(logger, nameof(logger)); + ArgumentOutOfRangeException.ThrowIfLessThan(batchSize, 1, nameof(batchSize)); this._httpClient = httpClient; this._apiKey = apiKey; this._baseUrl = (baseUrl ?? Constants.EmbeddingDefaults.DefaultHuggingFaceBaseUrl).TrimEnd('/'); + this._batchSize = batchSize; this.ModelName = model; this.VectorDimensions = vectorDimensions; this.IsNormalized = isNormalized; this._logger = logger; + this._delayAsync = delayAsync ?? Task.Delay; this._logger.LogDebug("HuggingFaceEmbeddingGenerator initialized: {BaseUrl}, model: {Model}, dimensions: {Dimensions}", this._baseUrl, this.ModelName, this.VectorDimensions); @@ -84,6 +94,18 @@ public async Task GenerateAsync(IEnumerable texts, Ca return []; } + var allResults = new List(textArray.Length); + foreach (var chunk in Chunk(textArray, this._batchSize)) + { + var chunkResults = await this.GenerateBatchAsync(chunk, ct).ConfigureAwait(false); + allResults.AddRange(chunkResults); + } + + return allResults.ToArray(); + } + + private async Task GenerateBatchAsync(string[] textArray, CancellationToken ct) + { var endpoint = $"{this._baseUrl}/models/{this.ModelName}"; var request = new HuggingFaceRequest @@ -91,14 +113,22 @@ public async Task GenerateAsync(IEnumerable texts, Ca Inputs = textArray }; - using var httpRequest = new HttpRequestMessage(HttpMethod.Post, endpoint); - httpRequest.Headers.Authorization = new AuthenticationHeaderValue("Bearer", this._apiKey); - httpRequest.Content = JsonContent.Create(request); - this._logger.LogTrace("Calling HuggingFace embeddings API: {Endpoint}, batch size: {BatchSize}", endpoint, textArray.Length); - var response = await this._httpClient.SendAsync(httpRequest, ct).ConfigureAwait(false); + using var response = await HttpRetryPolicy.SendAsync( + this._httpClient, + requestFactory: () => + { + var httpRequest = new HttpRequestMessage(HttpMethod.Post, endpoint); + httpRequest.Headers.Authorization = new AuthenticationHeaderValue("Bearer", this._apiKey); + httpRequest.Content = JsonContent.Create(request); + return httpRequest; + }, + this._logger, + ct, + delayAsync: this._delayAsync).ConfigureAwait(false); + response.EnsureSuccessStatusCode(); // HuggingFace returns array of embeddings directly: float[][] @@ -122,6 +152,17 @@ public async Task GenerateAsync(IEnumerable texts, Ca return results; } + private static IEnumerable Chunk(string[] items, int chunkSize) + { + for (int i = 0; i < items.Length; i += chunkSize) + { + var length = Math.Min(chunkSize, items.Length - i); + var chunk = new string[length]; + Array.Copy(items, i, chunk, 0, length); + yield return chunk; + } + } + /// /// Request body for HuggingFace Inference API. /// diff --git a/src/Core/Embeddings/Providers/OllamaEmbeddingGenerator.cs b/src/Core/Embeddings/Providers/OllamaEmbeddingGenerator.cs index c9521d2dd..3ea61750f 100644 --- a/src/Core/Embeddings/Providers/OllamaEmbeddingGenerator.cs +++ b/src/Core/Embeddings/Providers/OllamaEmbeddingGenerator.cs @@ -2,6 +2,7 @@ using System.Net.Http.Json; using System.Text.Json.Serialization; using KernelMemory.Core.Config.Enums; +using KernelMemory.Core.Http; using Microsoft.Extensions.Logging; namespace KernelMemory.Core.Embeddings.Providers; @@ -16,6 +17,7 @@ public sealed class OllamaEmbeddingGenerator : IEmbeddingGenerator private readonly HttpClient _httpClient; private readonly string _baseUrl; private readonly ILogger _logger; + private readonly Func _delayAsync; /// public EmbeddingsTypes ProviderType => EmbeddingsTypes.Ollama; @@ -38,13 +40,15 @@ public sealed class OllamaEmbeddingGenerator : IEmbeddingGenerator /// Vector dimensions produced by the model. /// Whether vectors are normalized. /// Logger instance. + /// Optional delay function for retries (used for fast unit tests). public OllamaEmbeddingGenerator( HttpClient httpClient, string baseUrl, string model, int vectorDimensions, bool isNormalized, - ILogger logger) + ILogger logger, + Func? delayAsync = null) { ArgumentNullException.ThrowIfNull(httpClient, nameof(httpClient)); ArgumentNullException.ThrowIfNull(baseUrl, nameof(baseUrl)); @@ -57,6 +61,7 @@ public OllamaEmbeddingGenerator( this.VectorDimensions = vectorDimensions; this.IsNormalized = isNormalized; this._logger = logger; + this._delayAsync = delayAsync ?? Task.Delay; this._logger.LogDebug("OllamaEmbeddingGenerator initialized: {BaseUrl}, model: {Model}, dimensions: {Dimensions}", this._baseUrl, this.ModelName, this.VectorDimensions); @@ -75,7 +80,20 @@ public async Task GenerateAsync(string text, CancellationToken this._logger.LogTrace("Calling Ollama embeddings API: {Endpoint}", endpoint); - var response = await this._httpClient.PostAsJsonAsync(endpoint, request, ct).ConfigureAwait(false); + using var response = await HttpRetryPolicy.SendAsync( + this._httpClient, + requestFactory: () => + { + var httpRequest = new HttpRequestMessage(HttpMethod.Post, endpoint) + { + Content = JsonContent.Create(request) + }; + return httpRequest; + }, + this._logger, + ct, + delayAsync: this._delayAsync).ConfigureAwait(false); + response.EnsureSuccessStatusCode(); var result = await response.Content.ReadFromJsonAsync(ct).ConfigureAwait(false); diff --git a/src/Core/Embeddings/Providers/OpenAIEmbeddingGenerator.cs b/src/Core/Embeddings/Providers/OpenAIEmbeddingGenerator.cs index 3d8ca2b02..64f5589ff 100644 --- a/src/Core/Embeddings/Providers/OpenAIEmbeddingGenerator.cs +++ b/src/Core/Embeddings/Providers/OpenAIEmbeddingGenerator.cs @@ -3,6 +3,7 @@ using System.Net.Http.Json; using System.Text.Json.Serialization; using KernelMemory.Core.Config.Enums; +using KernelMemory.Core.Http; using Microsoft.Extensions.Logging; namespace KernelMemory.Core.Embeddings.Providers; @@ -17,7 +18,9 @@ public sealed class OpenAIEmbeddingGenerator : IEmbeddingGenerator private readonly HttpClient _httpClient; private readonly string _apiKey; private readonly string _baseUrl; + private readonly int _batchSize; private readonly ILogger _logger; + private readonly Func _delayAsync; /// public EmbeddingsTypes ProviderType => EmbeddingsTypes.OpenAI; @@ -41,6 +44,8 @@ public sealed class OpenAIEmbeddingGenerator : IEmbeddingGenerator /// Whether vectors are normalized. /// Optional custom base URL for OpenAI-compatible APIs. /// Logger instance. + /// Maximum number of texts per API request. + /// Optional delay function for retries (used for fast unit tests). public OpenAIEmbeddingGenerator( HttpClient httpClient, string apiKey, @@ -48,21 +53,26 @@ public OpenAIEmbeddingGenerator( int vectorDimensions, bool isNormalized, string? baseUrl, - ILogger logger) + ILogger logger, + int batchSize, + Func? delayAsync = null) { ArgumentNullException.ThrowIfNull(httpClient, nameof(httpClient)); ArgumentNullException.ThrowIfNull(apiKey, nameof(apiKey)); ArgumentException.ThrowIfNullOrEmpty(apiKey, nameof(apiKey)); ArgumentNullException.ThrowIfNull(model, nameof(model)); ArgumentNullException.ThrowIfNull(logger, nameof(logger)); + ArgumentOutOfRangeException.ThrowIfLessThan(batchSize, 1, nameof(batchSize)); this._httpClient = httpClient; this._apiKey = apiKey; this._baseUrl = (baseUrl ?? Constants.EmbeddingDefaults.DefaultOpenAIBaseUrl).TrimEnd('/'); + this._batchSize = batchSize; this.ModelName = model; this.VectorDimensions = vectorDimensions; this.IsNormalized = isNormalized; this._logger = logger; + this._delayAsync = delayAsync ?? Task.Delay; this._logger.LogDebug("OpenAIEmbeddingGenerator initialized: {BaseUrl}, model: {Model}, dimensions: {Dimensions}", this._baseUrl, this.ModelName, this.VectorDimensions); @@ -84,6 +94,18 @@ public async Task GenerateAsync(IEnumerable texts, Ca return []; } + var allResults = new List(textArray.Length); + foreach (var chunk in Chunk(textArray, this._batchSize)) + { + var chunkResults = await this.GenerateBatchAsync(chunk, ct).ConfigureAwait(false); + allResults.AddRange(chunkResults); + } + + return allResults.ToArray(); + } + + private async Task GenerateBatchAsync(string[] textArray, CancellationToken ct) + { var endpoint = $"{this._baseUrl}/v1/embeddings"; var request = new OpenAIEmbeddingRequest @@ -92,14 +114,22 @@ public async Task GenerateAsync(IEnumerable texts, Ca Input = textArray }; - using var httpRequest = new HttpRequestMessage(HttpMethod.Post, endpoint); - httpRequest.Headers.Authorization = new AuthenticationHeaderValue("Bearer", this._apiKey); - httpRequest.Content = JsonContent.Create(request); - this._logger.LogTrace("Calling OpenAI embeddings API: {Endpoint}, batch size: {BatchSize}", endpoint, textArray.Length); - var response = await this._httpClient.SendAsync(httpRequest, ct).ConfigureAwait(false); + using var response = await HttpRetryPolicy.SendAsync( + this._httpClient, + requestFactory: () => + { + var httpRequest = new HttpRequestMessage(HttpMethod.Post, endpoint); + httpRequest.Headers.Authorization = new AuthenticationHeaderValue("Bearer", this._apiKey); + httpRequest.Content = JsonContent.Create(request); + return httpRequest; + }, + this._logger, + ct, + delayAsync: this._delayAsync).ConfigureAwait(false); + response.EnsureSuccessStatusCode(); var result = await response.Content.ReadFromJsonAsync(ct).ConfigureAwait(false); @@ -138,6 +168,17 @@ public async Task GenerateAsync(IEnumerable texts, Ca return results; } + private static IEnumerable Chunk(string[] items, int chunkSize) + { + for (int i = 0; i < items.Length; i += chunkSize) + { + var length = Math.Min(chunkSize, items.Length - i); + var chunk = new string[length]; + Array.Copy(items, i, chunk, 0, length); + yield return chunk; + } + } + /// /// Request body for OpenAI embeddings API. /// diff --git a/src/Core/Http/HttpRetryPolicy.cs b/src/Core/Http/HttpRetryPolicy.cs new file mode 100644 index 000000000..f07daccba --- /dev/null +++ b/src/Core/Http/HttpRetryPolicy.cs @@ -0,0 +1,151 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Net; +using Microsoft.Extensions.Logging; + +namespace KernelMemory.Core.Http; + +/// +/// Simple HTTP retry policy with exponential backoff for transient failures (e.g., 429/503). +/// +public static class HttpRetryPolicy +{ + /// + /// Sends an HTTP request using a request factory, retrying on transient failures. + /// The request factory must create a new each time (requests are single-use). + /// + public static async Task SendAsync( + HttpClient httpClient, + Func requestFactory, + ILogger logger, + CancellationToken ct, + Func? delayAsync = null) + { + ArgumentNullException.ThrowIfNull(httpClient, nameof(httpClient)); + ArgumentNullException.ThrowIfNull(requestFactory, nameof(requestFactory)); + ArgumentNullException.ThrowIfNull(logger, nameof(logger)); + + delayAsync ??= Task.Delay; + + Exception? lastException = null; + + for (int attempt = 1; attempt <= Constants.HttpRetryDefaults.MaxAttempts; attempt++) + { + ct.ThrowIfCancellationRequested(); + + using var request = requestFactory(); + + try + { + var response = await httpClient.SendAsync(request, ct).ConfigureAwait(false); + + if (response.IsSuccessStatusCode) + { + return response; + } + + if (!IsRetryableStatusCode(response.StatusCode) || attempt == Constants.HttpRetryDefaults.MaxAttempts) + { + return response; + } + + var delay = CalculateDelay(attempt, response); + logger.LogWarning( + "HTTP call failed with {StatusCode}. Retrying attempt {Attempt}/{MaxAttempts} after {DelayMs}ms", + (int)response.StatusCode, + attempt + 1, + Constants.HttpRetryDefaults.MaxAttempts, + (int)delay.TotalMilliseconds); + + response.Dispose(); + await delayAsync(delay, ct).ConfigureAwait(false); + } + catch (TaskCanceledException ex) when (!ct.IsCancellationRequested) + { + lastException = ex; + if (attempt == Constants.HttpRetryDefaults.MaxAttempts) + { + throw; + } + + var delay = CalculateDelay(attempt, response: null); + logger.LogWarning( + ex, + "HTTP call timed out. Retrying attempt {Attempt}/{MaxAttempts} after {DelayMs}ms", + attempt + 1, + Constants.HttpRetryDefaults.MaxAttempts, + (int)delay.TotalMilliseconds); + + await delayAsync(delay, ct).ConfigureAwait(false); + } + catch (HttpRequestException ex) + { + lastException = ex; + if (attempt == Constants.HttpRetryDefaults.MaxAttempts) + { + throw; + } + + var delay = CalculateDelay(attempt, response: null); + logger.LogWarning( + ex, + "HTTP call failed with exception. Retrying attempt {Attempt}/{MaxAttempts} after {DelayMs}ms", + attempt + 1, + Constants.HttpRetryDefaults.MaxAttempts, + (int)delay.TotalMilliseconds); + + await delayAsync(delay, ct).ConfigureAwait(false); + } + } + + throw lastException ?? new HttpRequestException("HTTP call failed after retries"); + } + + private static bool IsRetryableStatusCode(HttpStatusCode statusCode) + { + return statusCode == HttpStatusCode.TooManyRequests || + statusCode == HttpStatusCode.RequestTimeout || + statusCode == HttpStatusCode.InternalServerError || + statusCode == HttpStatusCode.BadGateway || + statusCode == HttpStatusCode.ServiceUnavailable || + statusCode == HttpStatusCode.GatewayTimeout; + } + + private static TimeSpan CalculateDelay(int attempt, HttpResponseMessage? response) + { + var retryAfter = response != null ? TryGetRetryAfterDelay(response) : null; + if (retryAfter.HasValue) + { + return ClampDelay(retryAfter.Value); + } + + var exponentialMs = Constants.HttpRetryDefaults.BaseDelayMs * Math.Pow(2, attempt - 1); + var clampedMs = Math.Min(exponentialMs, Constants.HttpRetryDefaults.MaxDelayMs); + + // Deterministic jitter (avoid Random in deterministic tests/builds) + var jitterMs = (attempt * 37) % 101; + return TimeSpan.FromMilliseconds(Math.Min(clampedMs + jitterMs, Constants.HttpRetryDefaults.MaxDelayMs)); + } + + private static TimeSpan ClampDelay(TimeSpan delay) + { + var ms = Math.Clamp(delay.TotalMilliseconds, 0, Constants.HttpRetryDefaults.MaxDelayMs); + return TimeSpan.FromMilliseconds(ms); + } + + private static TimeSpan? TryGetRetryAfterDelay(HttpResponseMessage response) + { + if (response.Headers.RetryAfter?.Delta != null) + { + return response.Headers.RetryAfter.Delta.Value; + } + + if (response.Headers.RetryAfter?.Date != null) + { + var delta = response.Headers.RetryAfter.Date.Value - DateTimeOffset.UtcNow; + return delta < TimeSpan.Zero ? TimeSpan.Zero : delta; + } + + return null; + } +} diff --git a/src/Main/Services/EmbeddingGeneratorFactory.cs b/src/Main/Services/EmbeddingGeneratorFactory.cs index 85053134b..82be77055 100644 --- a/src/Main/Services/EmbeddingGeneratorFactory.cs +++ b/src/Main/Services/EmbeddingGeneratorFactory.cs @@ -101,7 +101,8 @@ private static IEmbeddingGenerator CreateOpenAIGenerator( dimensions, isNormalized: true, // OpenAI embeddings are typically normalized config.BaseUrl, - logger); + logger, + batchSize: config.BatchSize); } /// @@ -123,10 +124,12 @@ private static IEmbeddingGenerator CreateAzureOpenAIGenerator( config.Endpoint, config.Deployment, config.Model, - config.ApiKey ?? string.Empty, + config.ApiKey, dimensions, isNormalized: true, // Azure OpenAI embeddings are typically normalized - logger); + logger, + batchSize: config.BatchSize, + useManagedIdentity: config.UseManagedIdentity); } /// @@ -144,11 +147,12 @@ private static IEmbeddingGenerator CreateHuggingFaceGenerator( return new HuggingFaceEmbeddingGenerator( httpClient, - config.ApiKey ?? string.Empty, + string.IsNullOrWhiteSpace(config.ApiKey) ? (Environment.GetEnvironmentVariable("HF_TOKEN") ?? string.Empty) : config.ApiKey, config.Model, dimensions, isNormalized: true, // Sentence-transformers models typically return normalized vectors config.BaseUrl, - logger); + logger, + batchSize: config.BatchSize); } } diff --git a/tests/Core.Tests/Embeddings/Providers/AzureOpenAIEmbeddingGeneratorTests.cs b/tests/Core.Tests/Embeddings/Providers/AzureOpenAIEmbeddingGeneratorTests.cs index d57b3e9dc..1c1a59823 100644 --- a/tests/Core.Tests/Embeddings/Providers/AzureOpenAIEmbeddingGeneratorTests.cs +++ b/tests/Core.Tests/Embeddings/Providers/AzureOpenAIEmbeddingGeneratorTests.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System.Net; using System.Text.Json; +using Azure.Core; using KernelMemory.Core.Config.Enums; using KernelMemory.Core.Embeddings.Providers; using Microsoft.Extensions.Logging; @@ -38,7 +39,10 @@ public void Properties_ShouldReflectConfiguration() apiKey: "test-key", vectorDimensions: 1536, isNormalized: true, - this._loggerMock.Object); + this._loggerMock.Object, + batchSize: 10, + useManagedIdentity: false, + delayAsync: (_, _) => Task.CompletedTask); // Assert Assert.Equal(EmbeddingsTypes.AzureOpenAI, generator.ProviderType); @@ -70,7 +74,7 @@ public async Task GenerateAsync_Single_ShouldCallAzureEndpoint() var httpClient = new HttpClient(this._httpHandlerMock.Object); var generator = new AzureOpenAIEmbeddingGenerator( - httpClient, "https://myservice.openai.azure.com", "my-embedding", "ada-002", "test-key", 1536, true, this._loggerMock.Object); + httpClient, "https://myservice.openai.azure.com", "my-embedding", "ada-002", "test-key", 1536, true, this._loggerMock.Object, batchSize: 10, useManagedIdentity: false, delayAsync: (_, _) => Task.CompletedTask); // Act var result = await generator.GenerateAsync("test text", CancellationToken.None).ConfigureAwait(false); @@ -101,7 +105,7 @@ public async Task GenerateAsync_WithApiKey_ShouldSendApiKeyHeader() var httpClient = new HttpClient(this._httpHandlerMock.Object); var generator = new AzureOpenAIEmbeddingGenerator( - httpClient, "https://myservice.openai.azure.com", "deployment", "model", "my-api-key", 1536, true, this._loggerMock.Object); + httpClient, "https://myservice.openai.azure.com", "deployment", "model", "my-api-key", 1536, true, this._loggerMock.Object, batchSize: 10, useManagedIdentity: false, delayAsync: (_, _) => Task.CompletedTask); // Act await generator.GenerateAsync("test", CancellationToken.None).ConfigureAwait(false); @@ -144,7 +148,7 @@ public async Task GenerateAsync_Batch_ShouldSendAllInputsInOneRequest() var httpClient = new HttpClient(this._httpHandlerMock.Object); var generator = new AzureOpenAIEmbeddingGenerator( - httpClient, "https://myservice.openai.azure.com", "deployment", "model", "key", 1536, true, this._loggerMock.Object); + httpClient, "https://myservice.openai.azure.com", "deployment", "model", "key", 1536, true, this._loggerMock.Object, batchSize: 10, useManagedIdentity: false, delayAsync: (_, _) => Task.CompletedTask); // Act var results = await generator.GenerateAsync(texts, CancellationToken.None).ConfigureAwait(false); @@ -177,7 +181,7 @@ public async Task GenerateAsync_ShouldIncludeApiVersion() var httpClient = new HttpClient(this._httpHandlerMock.Object); var generator = new AzureOpenAIEmbeddingGenerator( - httpClient, "https://myservice.openai.azure.com", "deployment", "model", "key", 1536, true, this._loggerMock.Object); + httpClient, "https://myservice.openai.azure.com", "deployment", "model", "key", 1536, true, this._loggerMock.Object, batchSize: 10, useManagedIdentity: false, delayAsync: (_, _) => Task.CompletedTask); // Act await generator.GenerateAsync("test", CancellationToken.None).ConfigureAwait(false); @@ -205,7 +209,7 @@ public async Task GenerateAsync_WithUnauthorizedError_ShouldThrowHttpRequestExce var httpClient = new HttpClient(this._httpHandlerMock.Object); var generator = new AzureOpenAIEmbeddingGenerator( - httpClient, "https://myservice.openai.azure.com", "deployment", "model", "bad-key", 1536, true, this._loggerMock.Object); + httpClient, "https://myservice.openai.azure.com", "deployment", "model", "bad-key", 1536, true, this._loggerMock.Object, batchSize: 10, useManagedIdentity: false, delayAsync: (_, _) => Task.CompletedTask); // Act & Assert await Assert.ThrowsAsync( @@ -226,7 +230,7 @@ public async Task GenerateAsync_WithCancellation_ShouldPropagate() var httpClient = new HttpClient(this._httpHandlerMock.Object); var generator = new AzureOpenAIEmbeddingGenerator( - httpClient, "https://myservice.openai.azure.com", "deployment", "model", "key", 1536, true, this._loggerMock.Object); + httpClient, "https://myservice.openai.azure.com", "deployment", "model", "key", 1536, true, this._loggerMock.Object, batchSize: 10, useManagedIdentity: false, delayAsync: (_, _) => Task.CompletedTask); using var cts = new CancellationTokenSource(); cts.Cancel(); @@ -242,7 +246,7 @@ public void Constructor_WithNullEndpoint_ShouldThrow() // Assert var httpClient = new HttpClient(); Assert.Throws(() => - new AzureOpenAIEmbeddingGenerator(httpClient, null!, "deployment", "model", "key", 1536, true, this._loggerMock.Object)); + new AzureOpenAIEmbeddingGenerator(httpClient, null!, "deployment", "model", "key", 1536, true, this._loggerMock.Object, batchSize: 10, useManagedIdentity: false)); } [Fact] @@ -251,16 +255,81 @@ public void Constructor_WithNullDeployment_ShouldThrow() // Assert var httpClient = new HttpClient(); Assert.Throws(() => - new AzureOpenAIEmbeddingGenerator(httpClient, "https://endpoint", null!, "model", "key", 1536, true, this._loggerMock.Object)); + new AzureOpenAIEmbeddingGenerator(httpClient, "https://endpoint", null!, "model", "key", 1536, true, this._loggerMock.Object, batchSize: 10, useManagedIdentity: false)); } [Fact] - public void Constructor_WithNullApiKey_ShouldThrow() + public void Constructor_WithNullApiKey_WithoutManagedIdentity_ShouldThrow() { // Assert var httpClient = new HttpClient(); - Assert.Throws(() => - new AzureOpenAIEmbeddingGenerator(httpClient, "https://endpoint", "deployment", "model", null!, 1536, true, this._loggerMock.Object)); + Assert.Throws(() => + new AzureOpenAIEmbeddingGenerator(httpClient, "https://endpoint", "deployment", "model", apiKey: null, 1536, true, this._loggerMock.Object, batchSize: 10, useManagedIdentity: false)); + } + + [Fact] + public async Task GenerateAsync_WithManagedIdentity_ShouldSendBearerToken() + { + // Arrange + HttpRequestMessage? capturedRequest = null; + var response = CreateAzureResponse(new[] { new[] { 0.1f } }); + + this._httpHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .Callback((req, _) => capturedRequest = req) + .ReturnsAsync(new HttpResponseMessage + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent(JsonSerializer.Serialize(response)) + }); + + var httpClient = new HttpClient(this._httpHandlerMock.Object); + var generator = new AzureOpenAIEmbeddingGenerator( + httpClient, + endpoint: "https://myservice.openai.azure.com", + deployment: "deployment", + model: "model", + apiKey: null, + vectorDimensions: 1536, + isNormalized: true, + this._loggerMock.Object, + batchSize: 10, + useManagedIdentity: true, + credential: new TestTokenCredential("test-token"), + delayAsync: (_, _) => Task.CompletedTask); + + // Act + await generator.GenerateAsync("test", CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.NotNull(capturedRequest); + Assert.False(capturedRequest!.Headers.Contains("api-key")); + Assert.Equal("Bearer", capturedRequest.Headers.Authorization?.Scheme); + Assert.Equal("test-token", capturedRequest.Headers.Authorization?.Parameter); + } + + private sealed class TestTokenCredential : TokenCredential + { + private readonly string _token; + + public TestTokenCredential(string token) + { + this._token = token; + } + + public override AccessToken GetToken(TokenRequestContext requestContext, CancellationToken cancellationToken) + { + return new AccessToken(this._token, DateTimeOffset.UtcNow.AddMinutes(5)); + } + + public override ValueTask GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken) + { + return ValueTask.FromResult(this.GetToken(requestContext, cancellationToken)); + } } private static AzureEmbeddingResponse CreateAzureResponse(float[][] embeddings) diff --git a/tests/Core.Tests/Embeddings/Providers/HuggingFaceEmbeddingGeneratorTests.cs b/tests/Core.Tests/Embeddings/Providers/HuggingFaceEmbeddingGeneratorTests.cs index f0bc4441e..3757ce0cf 100644 --- a/tests/Core.Tests/Embeddings/Providers/HuggingFaceEmbeddingGeneratorTests.cs +++ b/tests/Core.Tests/Embeddings/Providers/HuggingFaceEmbeddingGeneratorTests.cs @@ -37,7 +37,8 @@ public void Properties_ShouldReflectConfiguration() vectorDimensions: 384, isNormalized: true, baseUrl: null, - this._loggerMock.Object); + this._loggerMock.Object, + batchSize: 10); // Assert Assert.Equal(EmbeddingsTypes.HuggingFace, generator.ProviderType); @@ -69,7 +70,7 @@ public async Task GenerateAsync_Single_ShouldCallCorrectEndpoint() var httpClient = new HttpClient(this._httpHandlerMock.Object); var generator = new HuggingFaceEmbeddingGenerator( - httpClient, "hf_token", "sentence-transformers/all-MiniLM-L6-v2", 384, true, null, this._loggerMock.Object); + httpClient, "hf_token", "sentence-transformers/all-MiniLM-L6-v2", 384, true, null, this._loggerMock.Object, batchSize: 10); // Act var result = await generator.GenerateAsync("test text", CancellationToken.None).ConfigureAwait(false); @@ -100,7 +101,7 @@ public async Task GenerateAsync_Single_ShouldSendAuthorizationHeader() var httpClient = new HttpClient(this._httpHandlerMock.Object); var generator = new HuggingFaceEmbeddingGenerator( - httpClient, "hf_my_secret_token", "model", 384, true, null, this._loggerMock.Object); + httpClient, "hf_my_secret_token", "model", 384, true, null, this._loggerMock.Object, batchSize: 10); // Act await generator.GenerateAsync("test", CancellationToken.None).ConfigureAwait(false); @@ -137,7 +138,7 @@ public async Task GenerateAsync_Single_ShouldSendCorrectRequestBody() var httpClient = new HttpClient(this._httpHandlerMock.Object); var generator = new HuggingFaceEmbeddingGenerator( - httpClient, "hf_token", "model", 384, true, null, this._loggerMock.Object); + httpClient, "hf_token", "model", 384, true, null, this._loggerMock.Object, batchSize: 10); // Act await generator.GenerateAsync("hello world", CancellationToken.None).ConfigureAwait(false); @@ -175,7 +176,7 @@ public async Task GenerateAsync_Batch_ShouldProcessAllTexts() var httpClient = new HttpClient(this._httpHandlerMock.Object); var generator = new HuggingFaceEmbeddingGenerator( - httpClient, "hf_token", "model", 384, true, null, this._loggerMock.Object); + httpClient, "hf_token", "model", 384, true, null, this._loggerMock.Object, batchSize: 10); // Act var results = await generator.GenerateAsync(texts, CancellationToken.None).ConfigureAwait(false); @@ -208,7 +209,7 @@ public async Task GenerateAsync_WithCustomBaseUrl_ShouldUseIt() var httpClient = new HttpClient(this._httpHandlerMock.Object); var generator = new HuggingFaceEmbeddingGenerator( - httpClient, "hf_token", "model", 384, true, "https://custom.hf-endpoint.com", this._loggerMock.Object); + httpClient, "hf_token", "model", 384, true, "https://custom.hf-endpoint.com", this._loggerMock.Object, batchSize: 10); // Act var result = await generator.GenerateAsync("test", CancellationToken.None).ConfigureAwait(false); @@ -235,7 +236,7 @@ public async Task GenerateAsync_WithUnauthorizedError_ShouldThrowHttpRequestExce var httpClient = new HttpClient(this._httpHandlerMock.Object); var generator = new HuggingFaceEmbeddingGenerator( - httpClient, "bad_token", "model", 384, true, null, this._loggerMock.Object); + httpClient, "bad_token", "model", 384, true, null, this._loggerMock.Object, batchSize: 10); // Act & Assert await Assert.ThrowsAsync( @@ -252,7 +253,7 @@ public async Task GenerateAsync_WithModelLoadingError_ShouldThrowHttpRequestExce "SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) - .ReturnsAsync(new HttpResponseMessage + .ReturnsAsync(() => new HttpResponseMessage { StatusCode = HttpStatusCode.ServiceUnavailable, Content = new StringContent("{\"error\":\"Model is currently loading\"}") @@ -260,7 +261,7 @@ public async Task GenerateAsync_WithModelLoadingError_ShouldThrowHttpRequestExce var httpClient = new HttpClient(this._httpHandlerMock.Object); var generator = new HuggingFaceEmbeddingGenerator( - httpClient, "hf_token", "model", 384, true, null, this._loggerMock.Object); + httpClient, "hf_token", "model", 384, true, null, this._loggerMock.Object, batchSize: 10); // Act & Assert await Assert.ThrowsAsync( @@ -281,7 +282,7 @@ public async Task GenerateAsync_WithCancellation_ShouldPropagate() var httpClient = new HttpClient(this._httpHandlerMock.Object); var generator = new HuggingFaceEmbeddingGenerator( - httpClient, "hf_token", "model", 384, true, null, this._loggerMock.Object); + httpClient, "hf_token", "model", 384, true, null, this._loggerMock.Object, batchSize: 10); using var cts = new CancellationTokenSource(); cts.Cancel(); @@ -297,7 +298,7 @@ public void Constructor_WithNullApiKey_ShouldThrow() // Assert var httpClient = new HttpClient(); Assert.Throws(() => - new HuggingFaceEmbeddingGenerator(httpClient, null!, "model", 384, true, null, this._loggerMock.Object)); + new HuggingFaceEmbeddingGenerator(httpClient, null!, "model", 384, true, null, this._loggerMock.Object, batchSize: 10)); } [Fact] @@ -306,7 +307,7 @@ public void Constructor_WithEmptyApiKey_ShouldThrow() // Assert var httpClient = new HttpClient(); Assert.Throws(() => - new HuggingFaceEmbeddingGenerator(httpClient, "", "model", 384, true, null, this._loggerMock.Object)); + new HuggingFaceEmbeddingGenerator(httpClient, "", "model", 384, true, null, this._loggerMock.Object, batchSize: 10)); } [Fact] @@ -315,7 +316,7 @@ public void Constructor_WithNullModel_ShouldThrow() // Assert var httpClient = new HttpClient(); Assert.Throws(() => - new HuggingFaceEmbeddingGenerator(httpClient, "key", null!, 384, true, null, this._loggerMock.Object)); + new HuggingFaceEmbeddingGenerator(httpClient, "key", null!, 384, true, null, this._loggerMock.Object, batchSize: 10)); } // Internal request class for testing diff --git a/tests/Core.Tests/Embeddings/Providers/OllamaEmbeddingGeneratorTests.cs b/tests/Core.Tests/Embeddings/Providers/OllamaEmbeddingGeneratorTests.cs index 34fac8b4c..41adc8872 100644 --- a/tests/Core.Tests/Embeddings/Providers/OllamaEmbeddingGeneratorTests.cs +++ b/tests/Core.Tests/Embeddings/Providers/OllamaEmbeddingGeneratorTests.cs @@ -35,7 +35,8 @@ public void Properties_ShouldReflectConfiguration() model: "qwen3-embedding", vectorDimensions: 1024, isNormalized: true, - this._loggerMock.Object); + this._loggerMock.Object, + delayAsync: (_, _) => Task.CompletedTask); // Assert Assert.Equal(EmbeddingsTypes.Ollama, generator.ProviderType); @@ -67,7 +68,7 @@ public async Task GenerateAsync_Single_ShouldCallCorrectEndpoint() var httpClient = new HttpClient(this._httpHandlerMock.Object); var generator = new OllamaEmbeddingGenerator( - httpClient, "http://localhost:11434", "qwen3-embedding", 1024, true, this._loggerMock.Object); + httpClient, "http://localhost:11434", "qwen3-embedding", 1024, true, this._loggerMock.Object, delayAsync: (_, _) => Task.CompletedTask); // Act var result = await generator.GenerateAsync("test text", CancellationToken.None).ConfigureAwait(false); @@ -80,8 +81,8 @@ public async Task GenerateAsync_Single_ShouldCallCorrectEndpoint() public async Task GenerateAsync_Single_ShouldSendCorrectRequestBody() { // Arrange - HttpRequestMessage? capturedRequest = null; var response = new OllamaEmbeddingResponse { Embedding = new[] { 0.1f } }; + string? capturedContent = null; this._httpHandlerMock .Protected() @@ -89,24 +90,26 @@ public async Task GenerateAsync_Single_ShouldSendCorrectRequestBody() "SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) - .Callback((req, _) => capturedRequest = req) - .ReturnsAsync(new HttpResponseMessage + .Returns(async (req, ct) => { - StatusCode = HttpStatusCode.OK, - Content = new StringContent(JsonSerializer.Serialize(response)) + capturedContent = await req.Content!.ReadAsStringAsync(ct).ConfigureAwait(false); + return new HttpResponseMessage + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent(JsonSerializer.Serialize(response)) + }; }); var httpClient = new HttpClient(this._httpHandlerMock.Object); var generator = new OllamaEmbeddingGenerator( - httpClient, "http://localhost:11434", "qwen3-embedding", 1024, true, this._loggerMock.Object); + httpClient, "http://localhost:11434", "qwen3-embedding", 1024, true, this._loggerMock.Object, delayAsync: (_, _) => Task.CompletedTask); // Act await generator.GenerateAsync("hello world", CancellationToken.None).ConfigureAwait(false); // Assert - Assert.NotNull(capturedRequest); - var content = await capturedRequest!.Content!.ReadAsStringAsync().ConfigureAwait(false); - var requestBody = JsonSerializer.Deserialize(content); + Assert.NotNull(capturedContent); + var requestBody = JsonSerializer.Deserialize(capturedContent); Assert.Equal("qwen3-embedding", requestBody!.Model); Assert.Equal("hello world", requestBody.Prompt); } @@ -142,7 +145,7 @@ public async Task GenerateAsync_Batch_ShouldProcessAllTexts() var httpClient = new HttpClient(this._httpHandlerMock.Object); var generator = new OllamaEmbeddingGenerator( - httpClient, "http://localhost:11434", "qwen3-embedding", 1024, true, this._loggerMock.Object); + httpClient, "http://localhost:11434", "qwen3-embedding", 1024, true, this._loggerMock.Object, delayAsync: (_, _) => Task.CompletedTask); // Act var results = await generator.GenerateAsync(texts, CancellationToken.None).ConfigureAwait(false); @@ -164,7 +167,7 @@ public async Task GenerateAsync_WithHttpError_ShouldThrowHttpRequestException() "SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) - .ReturnsAsync(new HttpResponseMessage + .ReturnsAsync(() => new HttpResponseMessage { StatusCode = HttpStatusCode.InternalServerError, Content = new StringContent("Server error") @@ -172,7 +175,7 @@ public async Task GenerateAsync_WithHttpError_ShouldThrowHttpRequestException() var httpClient = new HttpClient(this._httpHandlerMock.Object); var generator = new OllamaEmbeddingGenerator( - httpClient, "http://localhost:11434", "qwen3-embedding", 1024, true, this._loggerMock.Object); + httpClient, "http://localhost:11434", "qwen3-embedding", 1024, true, this._loggerMock.Object, delayAsync: (_, _) => Task.CompletedTask); // Act & Assert await Assert.ThrowsAsync( @@ -197,7 +200,7 @@ public async Task GenerateAsync_WithMalformedResponse_ShouldThrowJsonException() var httpClient = new HttpClient(this._httpHandlerMock.Object); var generator = new OllamaEmbeddingGenerator( - httpClient, "http://localhost:11434", "qwen3-embedding", 1024, true, this._loggerMock.Object); + httpClient, "http://localhost:11434", "qwen3-embedding", 1024, true, this._loggerMock.Object, delayAsync: (_, _) => Task.CompletedTask); // Act & Assert await Assert.ThrowsAsync( @@ -218,7 +221,7 @@ public async Task GenerateAsync_WithCancellation_ShouldPropagate() var httpClient = new HttpClient(this._httpHandlerMock.Object); var generator = new OllamaEmbeddingGenerator( - httpClient, "http://localhost:11434", "qwen3-embedding", 1024, true, this._loggerMock.Object); + httpClient, "http://localhost:11434", "qwen3-embedding", 1024, true, this._loggerMock.Object, delayAsync: (_, _) => Task.CompletedTask); using var cts = new CancellationTokenSource(); cts.Cancel(); @@ -275,7 +278,7 @@ public async Task GenerateAsync_Single_ShouldReturnNullTokenCount() var httpClient = new HttpClient(this._httpHandlerMock.Object); var generator = new OllamaEmbeddingGenerator( - httpClient, "http://localhost:11434", "qwen3-embedding", 1024, true, this._loggerMock.Object); + httpClient, "http://localhost:11434", "qwen3-embedding", 1024, true, this._loggerMock.Object, delayAsync: (_, _) => Task.CompletedTask); // Act var result = await generator.GenerateAsync("test text", CancellationToken.None).ConfigureAwait(false); diff --git a/tests/Core.Tests/Embeddings/Providers/OpenAIEmbeddingGeneratorTests.cs b/tests/Core.Tests/Embeddings/Providers/OpenAIEmbeddingGeneratorTests.cs index 3b1a45d22..bea9442ee 100644 --- a/tests/Core.Tests/Embeddings/Providers/OpenAIEmbeddingGeneratorTests.cs +++ b/tests/Core.Tests/Embeddings/Providers/OpenAIEmbeddingGeneratorTests.cs @@ -36,7 +36,8 @@ public void Properties_ShouldReflectConfiguration() vectorDimensions: 1536, isNormalized: true, baseUrl: null, - this._loggerMock.Object); + this._loggerMock.Object, + batchSize: 10); // Assert Assert.Equal(EmbeddingsTypes.OpenAI, generator.ProviderType); @@ -68,7 +69,7 @@ public async Task GenerateAsync_Single_ShouldCallCorrectEndpoint() var httpClient = new HttpClient(this._httpHandlerMock.Object); var generator = new OpenAIEmbeddingGenerator( - httpClient, "test-key", "text-embedding-ada-002", 1536, true, null, this._loggerMock.Object); + httpClient, "test-key", "text-embedding-ada-002", 1536, true, null, this._loggerMock.Object, batchSize: 10); // Act var result = await generator.GenerateAsync("test text", CancellationToken.None).ConfigureAwait(false); @@ -99,7 +100,7 @@ public async Task GenerateAsync_Single_ShouldSendAuthorizationHeader() var httpClient = new HttpClient(this._httpHandlerMock.Object); var generator = new OpenAIEmbeddingGenerator( - httpClient, "sk-test-api-key", "text-embedding-ada-002", 1536, true, null, this._loggerMock.Object); + httpClient, "sk-test-api-key", "text-embedding-ada-002", 1536, true, null, this._loggerMock.Object, batchSize: 10); // Act await generator.GenerateAsync("test", CancellationToken.None).ConfigureAwait(false); @@ -142,7 +143,7 @@ public async Task GenerateAsync_Batch_ShouldSendAllInputsInOneRequest() var httpClient = new HttpClient(this._httpHandlerMock.Object); var generator = new OpenAIEmbeddingGenerator( - httpClient, "test-key", "text-embedding-ada-002", 1536, true, null, this._loggerMock.Object); + httpClient, "test-key", "text-embedding-ada-002", 1536, true, null, this._loggerMock.Object, batchSize: 10); // Act var results = await generator.GenerateAsync(texts, CancellationToken.None).ConfigureAwait(false); @@ -153,6 +154,53 @@ public async Task GenerateAsync_Batch_ShouldSendAllInputsInOneRequest() Assert.Equal(3, requestBody!.Input.Length); } + [Fact] + public async Task GenerateAsync_Batch_WithSmallBatchSize_ShouldChunkRequests() + { + // Arrange + var texts = new[] { "text1", "text2", "text3" }; + + var callIndex = 0; + var seenBatchSizes = new List(); + + this._httpHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .Returns(async (req, ct) => + { + var body = await req.Content!.ReadAsStringAsync(ct).ConfigureAwait(false); + var requestBody = JsonSerializer.Deserialize(body)!; + seenBatchSizes.Add(requestBody.Input.Length); + + callIndex++; + var embeddings = callIndex == 1 + ? new[] { new[] { 0.1f }, new[] { 0.2f } } + : new[] { new[] { 0.3f } }; + + var response = CreateOpenAIResponseWithTokenCount(embeddings, totalTokens: 3); + + return new HttpResponseMessage + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent(JsonSerializer.Serialize(response)) + }; + }); + + var httpClient = new HttpClient(this._httpHandlerMock.Object); + var generator = new OpenAIEmbeddingGenerator( + httpClient, "test-key", "text-embedding-ada-002", 1536, true, null, this._loggerMock.Object, batchSize: 2); + + // Act + var results = await generator.GenerateAsync(texts, CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.Equal(3, results.Length); + Assert.Equal(new[] { 2, 1 }, seenBatchSizes); + } + [Fact] public async Task GenerateAsync_WithCustomBaseUrl_ShouldUseIt() { @@ -174,7 +222,7 @@ public async Task GenerateAsync_WithCustomBaseUrl_ShouldUseIt() var httpClient = new HttpClient(this._httpHandlerMock.Object); var generator = new OpenAIEmbeddingGenerator( - httpClient, "test-key", "model", 1536, true, "https://custom.api.com", this._loggerMock.Object); + httpClient, "test-key", "model", 1536, true, "https://custom.api.com", this._loggerMock.Object, batchSize: 10); // Act var result = await generator.GenerateAsync("test", CancellationToken.None).ConfigureAwait(false); @@ -193,7 +241,7 @@ public async Task GenerateAsync_WithRateLimitError_ShouldThrowHttpRequestExcepti "SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) - .ReturnsAsync(new HttpResponseMessage + .ReturnsAsync(() => new HttpResponseMessage { StatusCode = HttpStatusCode.TooManyRequests, Content = new StringContent("Rate limit exceeded") @@ -201,7 +249,7 @@ public async Task GenerateAsync_WithRateLimitError_ShouldThrowHttpRequestExcepti var httpClient = new HttpClient(this._httpHandlerMock.Object); var generator = new OpenAIEmbeddingGenerator( - httpClient, "test-key", "model", 1536, true, null, this._loggerMock.Object); + httpClient, "test-key", "model", 1536, true, null, this._loggerMock.Object, batchSize: 10, delayAsync: (_, _) => Task.CompletedTask); // Act & Assert await Assert.ThrowsAsync( @@ -226,7 +274,7 @@ public async Task GenerateAsync_WithUnauthorizedError_ShouldThrowHttpRequestExce var httpClient = new HttpClient(this._httpHandlerMock.Object); var generator = new OpenAIEmbeddingGenerator( - httpClient, "bad-key", "model", 1536, true, null, this._loggerMock.Object); + httpClient, "bad-key", "model", 1536, true, null, this._loggerMock.Object, batchSize: 10); // Act & Assert await Assert.ThrowsAsync( @@ -247,7 +295,7 @@ public async Task GenerateAsync_WithCancellation_ShouldPropagate() var httpClient = new HttpClient(this._httpHandlerMock.Object); var generator = new OpenAIEmbeddingGenerator( - httpClient, "test-key", "model", 1536, true, null, this._loggerMock.Object); + httpClient, "test-key", "model", 1536, true, null, this._loggerMock.Object, batchSize: 10); using var cts = new CancellationTokenSource(); cts.Cancel(); @@ -263,7 +311,7 @@ public void Constructor_WithNullApiKey_ShouldThrow() // Assert var httpClient = new HttpClient(); Assert.Throws(() => - new OpenAIEmbeddingGenerator(httpClient, null!, "model", 1536, true, null, this._loggerMock.Object)); + new OpenAIEmbeddingGenerator(httpClient, null!, "model", 1536, true, null, this._loggerMock.Object, batchSize: 10)); } [Fact] @@ -272,7 +320,7 @@ public void Constructor_WithEmptyApiKey_ShouldThrow() // Assert var httpClient = new HttpClient(); Assert.Throws(() => - new OpenAIEmbeddingGenerator(httpClient, "", "model", 1536, true, null, this._loggerMock.Object)); + new OpenAIEmbeddingGenerator(httpClient, "", "model", 1536, true, null, this._loggerMock.Object, batchSize: 10)); } [Fact] @@ -281,7 +329,7 @@ public void Constructor_WithNullModel_ShouldThrow() // Assert var httpClient = new HttpClient(); Assert.Throws(() => - new OpenAIEmbeddingGenerator(httpClient, "key", null!, 1536, true, null, this._loggerMock.Object)); + new OpenAIEmbeddingGenerator(httpClient, "key", null!, 1536, true, null, this._loggerMock.Object, batchSize: 10)); } [Fact] @@ -305,7 +353,7 @@ public async Task GenerateAsync_Single_ShouldReturnTokenCountFromApiResponse() var httpClient = new HttpClient(this._httpHandlerMock.Object); var generator = new OpenAIEmbeddingGenerator( - httpClient, "test-key", "text-embedding-ada-002", 1536, true, null, this._loggerMock.Object); + httpClient, "test-key", "text-embedding-ada-002", 1536, true, null, this._loggerMock.Object, batchSize: 10); // Act var result = await generator.GenerateAsync("test text", CancellationToken.None).ConfigureAwait(false); @@ -341,7 +389,7 @@ public async Task GenerateAsync_Batch_ShouldDistributeTokenCountEvenly() var httpClient = new HttpClient(this._httpHandlerMock.Object); var generator = new OpenAIEmbeddingGenerator( - httpClient, "test-key", "text-embedding-ada-002", 1536, true, null, this._loggerMock.Object); + httpClient, "test-key", "text-embedding-ada-002", 1536, true, null, this._loggerMock.Object, batchSize: 10); // Act var results = await generator.GenerateAsync(new[] { "text1", "text2", "text3" }, CancellationToken.None).ConfigureAwait(false); diff --git a/tests/Core.Tests/Http/HttpRetryPolicyTests.cs b/tests/Core.Tests/Http/HttpRetryPolicyTests.cs new file mode 100644 index 000000000..c4daffa39 --- /dev/null +++ b/tests/Core.Tests/Http/HttpRetryPolicyTests.cs @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Net; +using KernelMemory.Core.Http; +using Microsoft.Extensions.Logging; +using Moq; +using Moq.Protected; + +namespace KernelMemory.Core.Tests.Http; + +public sealed class HttpRetryPolicyTests +{ + [Fact] + public async Task SendAsync_With429ThenSuccess_ShouldRetryAndReturnSuccessResponse() + { + // Arrange + var handlerMock = new Mock(); + var loggerMock = new Mock(); + var delayCalls = new List(); + + var callIndex = 0; + handlerMock + .Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .ReturnsAsync(() => + { + callIndex++; + return callIndex < 3 + ? new HttpResponseMessage { StatusCode = HttpStatusCode.TooManyRequests, Content = new StringContent("rate limit") } + : new HttpResponseMessage { StatusCode = HttpStatusCode.OK, Content = new StringContent("ok") }; + }); + + using var httpClient = new HttpClient(handlerMock.Object); + + // Act + using var response = await HttpRetryPolicy.SendAsync( + httpClient, + requestFactory: () => new HttpRequestMessage(HttpMethod.Get, "https://example.com"), + loggerMock.Object, + CancellationToken.None, + delayAsync: (d, _) => + { + delayCalls.Add(d); + return Task.CompletedTask; + }).ConfigureAwait(false); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + Assert.Equal(3, callIndex); + Assert.Equal(2, delayCalls.Count); + } +} diff --git a/tests/Core.Tests/Logging/EnvironmentDetectorTests.cs b/tests/Core.Tests/Logging/EnvironmentDetectorTests.cs index a1fd52a09..fbb282cab 100644 --- a/tests/Core.Tests/Logging/EnvironmentDetectorTests.cs +++ b/tests/Core.Tests/Logging/EnvironmentDetectorTests.cs @@ -6,6 +6,7 @@ namespace KernelMemory.Core.Tests.Logging; /// Tests for EnvironmentDetector - validates environment detection logic. /// Environment detection is critical for security (sensitive data scrubbing). /// +[Collection("EnvironmentVariables")] public sealed class EnvironmentDetectorTests : IDisposable { private readonly string? _originalDotNetEnv; diff --git a/tests/Core.Tests/Logging/SensitiveDataScrubbingPolicyTests.cs b/tests/Core.Tests/Logging/SensitiveDataScrubbingPolicyTests.cs index 09419509a..823424517 100644 --- a/tests/Core.Tests/Logging/SensitiveDataScrubbingPolicyTests.cs +++ b/tests/Core.Tests/Logging/SensitiveDataScrubbingPolicyTests.cs @@ -10,6 +10,7 @@ namespace KernelMemory.Core.Tests.Logging; /// In Production environment, all string parameters must be scrubbed to prevent data leakage. /// In Development/Staging, full logging is allowed for debugging. /// +[Collection("EnvironmentVariables")] public sealed class SensitiveDataScrubbingPolicyTests : IDisposable { private readonly string? _originalDotNetEnv; diff --git a/tests/Core.Tests/TestCollections.cs b/tests/Core.Tests/TestCollections.cs new file mode 100644 index 000000000..8e71ad1d1 --- /dev/null +++ b/tests/Core.Tests/TestCollections.cs @@ -0,0 +1,8 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace KernelMemory.Core.Tests; + +[CollectionDefinition("EnvironmentVariables", DisableParallelization = true)] +public sealed class EnvironmentVariablesTestCollectionDefinition +{ +} diff --git a/tests/Main.Tests/Integration/SearchProcessTests.cs b/tests/Main.Tests/Integration/SearchProcessTests.cs index ec9e328c8..65e83fb15 100644 --- a/tests/Main.Tests/Integration/SearchProcessTests.cs +++ b/tests/Main.Tests/Integration/SearchProcessTests.cs @@ -86,9 +86,11 @@ private async Task ExecuteKmAsync(string args) return output.Trim(); } - [OllamaFact] + [Fact] public async Task Process_PutThenSearch_FindsContent() { + this.WriteFtsOnlyConfig(); + // Act: Insert content var putOutput = await this.ExecuteKmAsync($"put \"ciao mondo\" --config {this._configPath}").ConfigureAwait(false); var putResult = JsonSerializer.Deserialize(putOutput); @@ -108,6 +110,40 @@ public async Task Process_PutThenSearch_FindsContent() Assert.Contains("ciao", results[0].GetProperty("content").GetString()!, StringComparison.OrdinalIgnoreCase); } + private void WriteFtsOnlyConfig() + { + var personalNodeDir = Path.Combine(this._tempDir, "nodes", "personal"); + Directory.CreateDirectory(personalNodeDir); + + var contentDbPath = Path.Combine(personalNodeDir, "content.db"); + var ftsDbPath = Path.Combine(personalNodeDir, "fts.db"); + + var json = $$""" + { + "nodes": { + "personal": { + "id": "personal", + "contentIndex": { + "type": "sqlite", + "path": {{JsonSerializer.Serialize(contentDbPath)}} + }, + "searchIndexes": [ + { + "type": "sqliteFTS", + "id": "sqlite-fts", + "path": {{JsonSerializer.Serialize(ftsDbPath)}}, + "enableStemming": true, + "required": true + } + ] + } + } + } + """; + + File.WriteAllText(this._configPath, json); + } + [Fact] public async Task Process_BooleanAnd_FindsOnlyMatchingBoth() { @@ -209,15 +245,4 @@ await this.ExecuteKmAsync($"put \"docker helm charts\" --config {this._configPat Assert.Contains(id1, ids); Assert.Contains(id2, ids); } - - private sealed class OllamaFactAttribute : FactAttribute - { - public OllamaFactAttribute() - { - if (string.Equals(Environment.GetEnvironmentVariable("OLLAMA_AVAILABLE"), "false", StringComparison.OrdinalIgnoreCase)) - { - this.Skip = "Skipping because OLLAMA_AVAILABLE=false (vector embeddings unavailable)."; - } - } - } } diff --git a/tests/Main.Tests/Services/EmbeddingGeneratorFactoryTests.cs b/tests/Main.Tests/Services/EmbeddingGeneratorFactoryTests.cs index 67cc5c7b6..b75d5a220 100644 --- a/tests/Main.Tests/Services/EmbeddingGeneratorFactoryTests.cs +++ b/tests/Main.Tests/Services/EmbeddingGeneratorFactoryTests.cs @@ -134,6 +134,38 @@ public void CreateGenerator_CreatesHuggingFaceGenerator() Assert.Equal(384, generator.VectorDimensions); // Known dimension for all-MiniLM-L6-v2 } + [Fact] + public void CreateGenerator_CreatesHuggingFaceGenerator_FromHfTokenEnvVar() + { + // Arrange + var originalToken = Environment.GetEnvironmentVariable("HF_TOKEN"); + Environment.SetEnvironmentVariable("HF_TOKEN", "hf_test_token_from_env"); + + try + { + var config = new HuggingFaceEmbeddingsConfig + { + Model = "sentence-transformers/all-MiniLM-L6-v2", + ApiKey = null + }; + + // Act + var generator = EmbeddingGeneratorFactory.CreateGenerator( + config, + this._httpClient, + cache: null, + this._mockLoggerFactory.Object); + + // Assert + Assert.NotNull(generator); + Assert.Equal(EmbeddingsTypes.HuggingFace, generator.ProviderType); + } + finally + { + Environment.SetEnvironmentVariable("HF_TOKEN", originalToken); + } + } + [Fact] public void CreateGenerator_WrapsWithCacheWhenProvided() { From d341f7f07614b117d8168cb015af1d59a9f96340 Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Thu, 18 Dec 2025 18:06:18 +0100 Subject: [PATCH 2/5] Revert SearchProcessTests to Ollama-gated behavior Restore the original OLLAMA_AVAILABLE gating and remove the unrequested FTS-only config change. --- .../Integration/SearchProcessTests.cs | 49 +++++-------------- 1 file changed, 12 insertions(+), 37 deletions(-) diff --git a/tests/Main.Tests/Integration/SearchProcessTests.cs b/tests/Main.Tests/Integration/SearchProcessTests.cs index 65e83fb15..ec9e328c8 100644 --- a/tests/Main.Tests/Integration/SearchProcessTests.cs +++ b/tests/Main.Tests/Integration/SearchProcessTests.cs @@ -86,11 +86,9 @@ private async Task ExecuteKmAsync(string args) return output.Trim(); } - [Fact] + [OllamaFact] public async Task Process_PutThenSearch_FindsContent() { - this.WriteFtsOnlyConfig(); - // Act: Insert content var putOutput = await this.ExecuteKmAsync($"put \"ciao mondo\" --config {this._configPath}").ConfigureAwait(false); var putResult = JsonSerializer.Deserialize(putOutput); @@ -110,40 +108,6 @@ public async Task Process_PutThenSearch_FindsContent() Assert.Contains("ciao", results[0].GetProperty("content").GetString()!, StringComparison.OrdinalIgnoreCase); } - private void WriteFtsOnlyConfig() - { - var personalNodeDir = Path.Combine(this._tempDir, "nodes", "personal"); - Directory.CreateDirectory(personalNodeDir); - - var contentDbPath = Path.Combine(personalNodeDir, "content.db"); - var ftsDbPath = Path.Combine(personalNodeDir, "fts.db"); - - var json = $$""" - { - "nodes": { - "personal": { - "id": "personal", - "contentIndex": { - "type": "sqlite", - "path": {{JsonSerializer.Serialize(contentDbPath)}} - }, - "searchIndexes": [ - { - "type": "sqliteFTS", - "id": "sqlite-fts", - "path": {{JsonSerializer.Serialize(ftsDbPath)}}, - "enableStemming": true, - "required": true - } - ] - } - } - } - """; - - File.WriteAllText(this._configPath, json); - } - [Fact] public async Task Process_BooleanAnd_FindsOnlyMatchingBoth() { @@ -245,4 +209,15 @@ await this.ExecuteKmAsync($"put \"docker helm charts\" --config {this._configPat Assert.Contains(id1, ids); Assert.Contains(id2, ids); } + + private sealed class OllamaFactAttribute : FactAttribute + { + public OllamaFactAttribute() + { + if (string.Equals(Environment.GetEnvironmentVariable("OLLAMA_AVAILABLE"), "false", StringComparison.OrdinalIgnoreCase)) + { + this.Skip = "Skipping because OLLAMA_AVAILABLE=false (vector embeddings unavailable)."; + } + } + } } From e004fec65c46bd3aed4e6febc81b242313599049 Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Thu, 18 Dec 2025 18:28:20 +0100 Subject: [PATCH 3/5] Fix HttpRetryPolicy: fail fast on unreachable services - Add per-attempt timeout support and Ollama-specific timeout - Avoid retrying connection-refused/host-not-found to prevent long-running test hangs --- src/Core/Constants.cs | 12 +++++++++ .../Providers/OllamaEmbeddingGenerator.cs | 3 ++- src/Core/Http/HttpRetryPolicy.cs | 26 ++++++++++++++++--- 3 files changed, 37 insertions(+), 4 deletions(-) diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 0ad6baf3c..3e7549509 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -220,6 +220,18 @@ public static class HttpRetryDefaults /// public const int MaxAttempts = 5; + /// + /// Default timeout per attempt (seconds). + /// Applied by to avoid hanging calls. + /// + public const int DefaultPerAttemptTimeoutSeconds = 60; + + /// + /// Per-attempt timeout for local Ollama calls (seconds). + /// Keep this low so local development and tests fail fast when Ollama is not running. + /// + public const int OllamaPerAttemptTimeoutSeconds = 5; + /// /// Base delay for exponential backoff. /// diff --git a/src/Core/Embeddings/Providers/OllamaEmbeddingGenerator.cs b/src/Core/Embeddings/Providers/OllamaEmbeddingGenerator.cs index 3ea61750f..8330ca077 100644 --- a/src/Core/Embeddings/Providers/OllamaEmbeddingGenerator.cs +++ b/src/Core/Embeddings/Providers/OllamaEmbeddingGenerator.cs @@ -92,7 +92,8 @@ public async Task GenerateAsync(string text, CancellationToken }, this._logger, ct, - delayAsync: this._delayAsync).ConfigureAwait(false); + delayAsync: this._delayAsync, + perAttemptTimeout: TimeSpan.FromSeconds(Constants.HttpRetryDefaults.OllamaPerAttemptTimeoutSeconds)).ConfigureAwait(false); response.EnsureSuccessStatusCode(); diff --git a/src/Core/Http/HttpRetryPolicy.cs b/src/Core/Http/HttpRetryPolicy.cs index f07daccba..c54f7f316 100644 --- a/src/Core/Http/HttpRetryPolicy.cs +++ b/src/Core/Http/HttpRetryPolicy.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System.Net; +using System.Net.Sockets; using Microsoft.Extensions.Logging; namespace KernelMemory.Core.Http; @@ -19,13 +20,15 @@ public static async Task SendAsync( Func requestFactory, ILogger logger, CancellationToken ct, - Func? delayAsync = null) + Func? delayAsync = null, + TimeSpan? perAttemptTimeout = null) { ArgumentNullException.ThrowIfNull(httpClient, nameof(httpClient)); ArgumentNullException.ThrowIfNull(requestFactory, nameof(requestFactory)); ArgumentNullException.ThrowIfNull(logger, nameof(logger)); delayAsync ??= Task.Delay; + perAttemptTimeout ??= TimeSpan.FromSeconds(Constants.HttpRetryDefaults.DefaultPerAttemptTimeoutSeconds); Exception? lastException = null; @@ -37,7 +40,10 @@ public static async Task SendAsync( try { - var response = await httpClient.SendAsync(request, ct).ConfigureAwait(false); + using var attemptCts = CancellationTokenSource.CreateLinkedTokenSource(ct); + attemptCts.CancelAfter(perAttemptTimeout.Value); + + var response = await httpClient.SendAsync(request, attemptCts.Token).ConfigureAwait(false); if (response.IsSuccessStatusCode) { @@ -81,7 +87,7 @@ public static async Task SendAsync( catch (HttpRequestException ex) { lastException = ex; - if (attempt == Constants.HttpRetryDefaults.MaxAttempts) + if (!IsRetryableException(ex) || attempt == Constants.HttpRetryDefaults.MaxAttempts) { throw; } @@ -101,6 +107,20 @@ public static async Task SendAsync( throw lastException ?? new HttpRequestException("HTTP call failed after retries"); } + private static bool IsRetryableException(HttpRequestException ex) + { + // Fail fast on common "service not running" / "unreachable" conditions. + // Retrying these (especially with default HttpClient timeouts) can make test runs appear hung. + if (ex.InnerException is SocketException socketEx) + { + return socketEx.SocketErrorCode != SocketError.ConnectionRefused && + socketEx.SocketErrorCode != SocketError.HostNotFound && + socketEx.SocketErrorCode != SocketError.NetworkUnreachable; + } + + return true; + } + private static bool IsRetryableStatusCode(HttpStatusCode statusCode) { return statusCode == HttpStatusCode.TooManyRequests || From 9a6049f11a8f9a4bd1326389f5684609f65936c5 Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Thu, 18 Dec 2025 18:47:52 +0100 Subject: [PATCH 4/5] Centralize HF_TOKEN handling in config Resolve HF_TOKEN into HuggingFaceEmbeddingsConfig.ApiKey during validation so factories/providers rely only on config. --- .../Embeddings/HuggingFaceEmbeddingsConfig.cs | 15 ++++++++++++--- src/Main/Services/EmbeddingGeneratorFactory.cs | 4 +++- .../Services/EmbeddingGeneratorFactoryTests.cs | 2 ++ 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/src/Core/Config/Embeddings/HuggingFaceEmbeddingsConfig.cs b/src/Core/Config/Embeddings/HuggingFaceEmbeddingsConfig.cs index 0be95dcbe..b4e49e756 100644 --- a/src/Core/Config/Embeddings/HuggingFaceEmbeddingsConfig.cs +++ b/src/Core/Config/Embeddings/HuggingFaceEmbeddingsConfig.cs @@ -44,9 +44,18 @@ public override void Validate(string path) throw new ConfigException($"{path}.Model", "HuggingFace model name is required"); } - // ApiKey can be provided via config or HF_TOKEN environment variable - if (string.IsNullOrWhiteSpace(this.ApiKey) && - string.IsNullOrWhiteSpace(Environment.GetEnvironmentVariable("HF_TOKEN"))) + // ApiKey can be provided via config or HF_TOKEN environment variable. + // Resolve env var into config so downstream code can rely on configuration only. + if (string.IsNullOrWhiteSpace(this.ApiKey)) + { + var token = Environment.GetEnvironmentVariable("HF_TOKEN"); + if (!string.IsNullOrWhiteSpace(token)) + { + this.ApiKey = token; + } + } + + if (string.IsNullOrWhiteSpace(this.ApiKey)) { throw new ConfigException($"{path}.ApiKey", "HuggingFace API key is required (set ApiKey or HF_TOKEN)"); } diff --git a/src/Main/Services/EmbeddingGeneratorFactory.cs b/src/Main/Services/EmbeddingGeneratorFactory.cs index 82be77055..50b68ef2c 100644 --- a/src/Main/Services/EmbeddingGeneratorFactory.cs +++ b/src/Main/Services/EmbeddingGeneratorFactory.cs @@ -145,9 +145,11 @@ private static IEmbeddingGenerator CreateHuggingFaceGenerator( // Get known dimensions for the model var dimensions = Constants.EmbeddingDefaults.KnownModelDimensions.GetValueOrDefault(config.Model, defaultValue: 384); + var apiKey = config.ApiKey ?? throw new InvalidOperationException("HuggingFace API key is required"); + return new HuggingFaceEmbeddingGenerator( httpClient, - string.IsNullOrWhiteSpace(config.ApiKey) ? (Environment.GetEnvironmentVariable("HF_TOKEN") ?? string.Empty) : config.ApiKey, + apiKey, config.Model, dimensions, isNormalized: true, // Sentence-transformers models typically return normalized vectors diff --git a/tests/Main.Tests/Services/EmbeddingGeneratorFactoryTests.cs b/tests/Main.Tests/Services/EmbeddingGeneratorFactoryTests.cs index b75d5a220..3c2d1798b 100644 --- a/tests/Main.Tests/Services/EmbeddingGeneratorFactoryTests.cs +++ b/tests/Main.Tests/Services/EmbeddingGeneratorFactoryTests.cs @@ -149,6 +149,8 @@ public void CreateGenerator_CreatesHuggingFaceGenerator_FromHfTokenEnvVar() ApiKey = null }; + config.Validate("Embeddings"); + // Act var generator = EmbeddingGeneratorFactory.CreateGenerator( config, From d2b36e20add62574e8d5ffe959265baa19dfc08c Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Thu, 18 Dec 2025 18:51:57 +0100 Subject: [PATCH 5/5] Docs: clarify defensive fallback in HttpRetryPolicy Add comment explaining the final throw is unreachable in normal flow and kept as defensive code. --- src/Core/Http/HttpRetryPolicy.cs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/Core/Http/HttpRetryPolicy.cs b/src/Core/Http/HttpRetryPolicy.cs index c54f7f316..e865a88af 100644 --- a/src/Core/Http/HttpRetryPolicy.cs +++ b/src/Core/Http/HttpRetryPolicy.cs @@ -104,6 +104,12 @@ public static async Task SendAsync( } } + // Defensive fallback: in normal flow, this is unreachable because: + // - Successful responses return immediately. + // - Non-retryable HTTP status codes return immediately (no exception). + // - Retryable HTTP status codes exhaust attempts and return the final response. + // - Exceptions either throw (non-retryable / last attempt) or are captured in lastException. + // Keeping this protects against unexpected future changes to the control flow. throw lastException ?? new HttpRequestException("HTTP call failed after retries"); }