From df9091cdd6421871589e417d11025ec9ea74f358 Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Tue, 2 Dec 2025 16:44:19 +0100 Subject: [PATCH 1/6] feat: add embedding generators and cache (#00007) Implement embedding generation infrastructure with caching support: Core Components: - IEmbeddingGenerator interface for all providers - EmbeddingConstants with known model dimensions - CachedEmbeddingGenerator decorator for transparent caching Cache System: - IEmbeddingCache interface with SQLite implementation - EmbeddingCacheKey with SHA256 hashing (privacy: text never stored) - CachedEmbedding model with provider/model metadata - CacheModes enum: ReadWrite, ReadOnly, WriteOnly Embedding Providers: - OllamaEmbeddingGenerator for local models - OpenAIEmbeddingGenerator for OpenAI API - AzureOpenAIEmbeddingGenerator for Azure OpenAI - HuggingFaceEmbeddingGenerator for HuggingFace Inference API Configuration: - HuggingFaceEmbeddingsConfig with JSON discriminator - Updated EmbeddingsTypes enum with HuggingFace value Test coverage: 86.33% (133 new tests) --- .../Config/Embeddings/EmbeddingsConfig.cs | 1 + .../Embeddings/HuggingFaceEmbeddingsConfig.cs | 64 +++ src/Core/Config/Enums/CacheModes.cs | 30 ++ src/Core/Config/Enums/EmbeddingsTypes.cs | 5 +- src/Core/Embeddings/Cache/CachedEmbedding.cs | 30 ++ .../Embeddings/Cache/EmbeddingCacheKey.cs | 101 ++++ src/Core/Embeddings/Cache/IEmbeddingCache.cs | 36 ++ .../Embeddings/Cache/SqliteEmbeddingCache.cs | 204 ++++++++ .../Embeddings/CachedEmbeddingGenerator.cs | 170 ++++++ src/Core/Embeddings/EmbeddingConstants.cs | 77 +++ src/Core/Embeddings/IEmbeddingGenerator.cs | 57 ++ .../AzureOpenAIEmbeddingGenerator.cs | 163 ++++++ .../HuggingFaceEmbeddingGenerator.cs | 126 +++++ .../Providers/OllamaEmbeddingGenerator.cs | 130 +++++ .../Providers/OpenAIEmbeddingGenerator.cs | 163 ++++++ tests/Core.Tests/Config/CacheModesTests.cs | 53 ++ .../HuggingFaceEmbeddingsConfigTests.cs | 186 +++++++ .../Cache/SqliteEmbeddingCacheTests.cs | 376 
+++++++++++++ .../CachedEmbeddingGeneratorTests.cs | 494 ++++++++++++++++++ .../Embeddings/CachedEmbeddingTests.cs | 138 +++++ .../Embeddings/EmbeddingCacheKeyTests.cs | 230 ++++++++ .../Embeddings/EmbeddingConstantsTests.cs | 103 ++++ .../AzureOpenAIEmbeddingGeneratorTests.cs | 308 +++++++++++ .../HuggingFaceEmbeddingGeneratorTests.cs | 327 ++++++++++++ .../OllamaEmbeddingGeneratorTests.cs | 272 ++++++++++ .../OpenAIEmbeddingGeneratorTests.cs | 332 ++++++++++++ tests/Core.Tests/GlobalUsings.cs | 9 + 27 files changed, 4184 insertions(+), 1 deletion(-) create mode 100644 src/Core/Config/Embeddings/HuggingFaceEmbeddingsConfig.cs create mode 100644 src/Core/Config/Enums/CacheModes.cs create mode 100644 src/Core/Embeddings/Cache/CachedEmbedding.cs create mode 100644 src/Core/Embeddings/Cache/EmbeddingCacheKey.cs create mode 100644 src/Core/Embeddings/Cache/IEmbeddingCache.cs create mode 100644 src/Core/Embeddings/Cache/SqliteEmbeddingCache.cs create mode 100644 src/Core/Embeddings/CachedEmbeddingGenerator.cs create mode 100644 src/Core/Embeddings/EmbeddingConstants.cs create mode 100644 src/Core/Embeddings/IEmbeddingGenerator.cs create mode 100644 src/Core/Embeddings/Providers/AzureOpenAIEmbeddingGenerator.cs create mode 100644 src/Core/Embeddings/Providers/HuggingFaceEmbeddingGenerator.cs create mode 100644 src/Core/Embeddings/Providers/OllamaEmbeddingGenerator.cs create mode 100644 src/Core/Embeddings/Providers/OpenAIEmbeddingGenerator.cs create mode 100644 tests/Core.Tests/Config/CacheModesTests.cs create mode 100644 tests/Core.Tests/Config/HuggingFaceEmbeddingsConfigTests.cs create mode 100644 tests/Core.Tests/Embeddings/Cache/SqliteEmbeddingCacheTests.cs create mode 100644 tests/Core.Tests/Embeddings/CachedEmbeddingGeneratorTests.cs create mode 100644 tests/Core.Tests/Embeddings/CachedEmbeddingTests.cs create mode 100644 tests/Core.Tests/Embeddings/EmbeddingCacheKeyTests.cs create mode 100644 tests/Core.Tests/Embeddings/EmbeddingConstantsTests.cs create mode 
100644 tests/Core.Tests/Embeddings/Providers/AzureOpenAIEmbeddingGeneratorTests.cs create mode 100644 tests/Core.Tests/Embeddings/Providers/HuggingFaceEmbeddingGeneratorTests.cs create mode 100644 tests/Core.Tests/Embeddings/Providers/OllamaEmbeddingGeneratorTests.cs create mode 100644 tests/Core.Tests/Embeddings/Providers/OpenAIEmbeddingGeneratorTests.cs diff --git a/src/Core/Config/Embeddings/EmbeddingsConfig.cs b/src/Core/Config/Embeddings/EmbeddingsConfig.cs index 154864d48..4fae204d1 100644 --- a/src/Core/Config/Embeddings/EmbeddingsConfig.cs +++ b/src/Core/Config/Embeddings/EmbeddingsConfig.cs @@ -12,6 +12,7 @@ namespace KernelMemory.Core.Config.Embeddings; [JsonDerivedType(typeof(OllamaEmbeddingsConfig), typeDiscriminator: "ollama")] [JsonDerivedType(typeof(OpenAIEmbeddingsConfig), typeDiscriminator: "openai")] [JsonDerivedType(typeof(AzureOpenAIEmbeddingsConfig), typeDiscriminator: "azureOpenAI")] +[JsonDerivedType(typeof(HuggingFaceEmbeddingsConfig), typeDiscriminator: "huggingFace")] public abstract class EmbeddingsConfig : IValidatable { /// diff --git a/src/Core/Config/Embeddings/HuggingFaceEmbeddingsConfig.cs b/src/Core/Config/Embeddings/HuggingFaceEmbeddingsConfig.cs new file mode 100644 index 000000000..5ae7c810f --- /dev/null +++ b/src/Core/Config/Embeddings/HuggingFaceEmbeddingsConfig.cs @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft. All rights reserved. +using System.Text.Json.Serialization; +using KernelMemory.Core.Config.Enums; +using KernelMemory.Core.Config.Validation; +using KernelMemory.Core.Embeddings; + +namespace KernelMemory.Core.Config.Embeddings; + +/// +/// HuggingFace Inference API embeddings provider configuration. +/// Supports the serverless Inference API for embedding models. 
+/// +public sealed class HuggingFaceEmbeddingsConfig : EmbeddingsConfig +{ + /// + [JsonIgnore] + public override EmbeddingsTypes Type => EmbeddingsTypes.HuggingFace; + + /// + /// HuggingFace model name (e.g., "sentence-transformers/all-MiniLM-L6-v2", "BAAI/bge-base-en-v1.5"). + /// + [JsonPropertyName("model")] + public string Model { get; set; } = EmbeddingConstants.DefaultHuggingFaceModel; + + /// + /// HuggingFace API key (token). + /// Can also be set via HF_TOKEN environment variable. + /// + [JsonPropertyName("apiKey")] + public string? ApiKey { get; set; } + + /// + /// HuggingFace Inference API base URL. + /// Default: https://api-inference.huggingface.co + /// Can be changed for custom inference endpoints. + /// + [JsonPropertyName("baseUrl")] + public string BaseUrl { get; set; } = EmbeddingConstants.DefaultHuggingFaceBaseUrl; + + /// + public override void Validate(string path) + { + if (string.IsNullOrWhiteSpace(this.Model)) + { + throw new ConfigException($"{path}.Model", "HuggingFace model name is required"); + } + + if (string.IsNullOrWhiteSpace(this.ApiKey)) + { + throw new ConfigException($"{path}.ApiKey", "HuggingFace API key is required"); + } + + if (string.IsNullOrWhiteSpace(this.BaseUrl)) + { + throw new ConfigException($"{path}.BaseUrl", "HuggingFace base URL is required"); + } + + if (!Uri.TryCreate(this.BaseUrl, UriKind.Absolute, out _)) + { + throw new ConfigException($"{path}.BaseUrl", + $"Invalid HuggingFace base URL: {this.BaseUrl}"); + } + } +} diff --git a/src/Core/Config/Enums/CacheModes.cs b/src/Core/Config/Enums/CacheModes.cs new file mode 100644 index 000000000..b34c7f3d4 --- /dev/null +++ b/src/Core/Config/Enums/CacheModes.cs @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft. All rights reserved. +using System.Text.Json.Serialization; + +namespace KernelMemory.Core.Config.Enums; + +/// +/// Modes for embedding cache operations. +/// Controls whether the cache reads, writes, or both. 
+/// +[JsonConverter(typeof(JsonStringEnumConverter))] +public enum CacheModes +{ + /// + /// Both read from and write to cache (default). + /// Cache hits return stored embeddings, misses are generated and stored. + /// + ReadWrite, + + /// + /// Only read from cache, never write. + /// Useful for read-only deployments or when cache is pre-populated. + /// + ReadOnly, + + /// + /// Only write to cache, never read. + /// Useful for warming up a cache without affecting current behavior. + /// + WriteOnly +} diff --git a/src/Core/Config/Enums/EmbeddingsTypes.cs b/src/Core/Config/Enums/EmbeddingsTypes.cs index 77d93c471..b2269a6c4 100644 --- a/src/Core/Config/Enums/EmbeddingsTypes.cs +++ b/src/Core/Config/Enums/EmbeddingsTypes.cs @@ -16,5 +16,8 @@ public enum EmbeddingsTypes OpenAI, /// Azure OpenAI Service - AzureOpenAI + AzureOpenAI, + + /// Hugging Face Inference API + HuggingFace } diff --git a/src/Core/Embeddings/Cache/CachedEmbedding.cs b/src/Core/Embeddings/Cache/CachedEmbedding.cs new file mode 100644 index 000000000..8ce123b1e --- /dev/null +++ b/src/Core/Embeddings/Cache/CachedEmbedding.cs @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft. All rights reserved. +using System.Diagnostics.CodeAnalysis; + +namespace KernelMemory.Core.Embeddings.Cache; + +/// +/// Represents a cached embedding vector with metadata. +/// Stores the vector, optional token count, and timestamp of when it was cached. +/// +public sealed class CachedEmbedding +{ + /// + /// The embedding vector as a float array. + /// Array is intentional for performance - embeddings are read-only after creation. + /// + [SuppressMessage("Performance", "CA1819:Properties should not return arrays", + Justification = "Embedding vectors are performance-critical and read-only after creation")] + public required float[] Vector { get; init; } + + /// + /// Optional token count from the provider response. + /// Null if the provider did not return token count. + /// + public int? 
TokenCount { get; init; } + + /// + /// Timestamp when this embedding was stored in the cache. + /// + public required DateTimeOffset Timestamp { get; init; } +} diff --git a/src/Core/Embeddings/Cache/EmbeddingCacheKey.cs b/src/Core/Embeddings/Cache/EmbeddingCacheKey.cs new file mode 100644 index 000000000..55777d832 --- /dev/null +++ b/src/Core/Embeddings/Cache/EmbeddingCacheKey.cs @@ -0,0 +1,101 @@ +// Copyright (c) Microsoft. All rights reserved. +using System.Security.Cryptography; +using System.Text; + +namespace KernelMemory.Core.Embeddings.Cache; + +/// +/// Cache key for embeddings. Uniquely identifies an embedding by provider, model, +/// dimensions, normalization state, and content hash. +/// The input text is NOT stored - only a SHA256 hash is used for security. +/// +public sealed class EmbeddingCacheKey +{ + /// + /// Provider type name (e.g., "OpenAI", "Ollama", "AzureOpenAI", "HuggingFace"). + /// + public required string Provider { get; init; } + + /// + /// Model name (e.g., "text-embedding-ada-002", "qwen3-embedding"). + /// + public required string Model { get; init; } + + /// + /// Vector dimensions produced by this model. + /// + public required int VectorDimensions { get; init; } + + /// + /// Whether the vectors are normalized. + /// + public required bool IsNormalized { get; init; } + + /// + /// Length of the original text in characters. + /// Used as an additional collision prevention measure. + /// + public required int TextLength { get; init; } + + /// + /// SHA256 hash of the original text (hex string). + /// The text itself is never stored for security/privacy. + /// + public required string TextHash { get; init; } + + /// + /// Creates a cache key from the given parameters. + /// The text is hashed using SHA256 and not stored. + /// + /// Provider type name. + /// Model name. + /// Vector dimensions. + /// Whether vectors are normalized. + /// The text to hash. + /// A new EmbeddingCacheKey instance. 
+ /// When provider, model, or text is null. + /// When vectorDimensions is less than 1. + public static EmbeddingCacheKey Create( + string provider, + string model, + int vectorDimensions, + bool isNormalized, + string text) + { + ArgumentNullException.ThrowIfNull(provider, nameof(provider)); + ArgumentNullException.ThrowIfNull(model, nameof(model)); + ArgumentNullException.ThrowIfNull(text, nameof(text)); + ArgumentOutOfRangeException.ThrowIfLessThan(vectorDimensions, 1, nameof(vectorDimensions)); + + return new EmbeddingCacheKey + { + Provider = provider, + Model = model, + VectorDimensions = vectorDimensions, + IsNormalized = isNormalized, + TextLength = text.Length, + TextHash = ComputeSha256Hash(text) + }; + } + + /// + /// Generates a composite key string for use as a database primary key. + /// Format: Provider|Model|Dimensions|IsNormalized|TextLength|TextHash + /// + /// A string suitable for use as a cache key. + public string ToCompositeKey() + { + return $"{this.Provider}|{this.Model}|{this.VectorDimensions}|{this.IsNormalized}|{this.TextLength}|{this.TextHash}"; + } + + /// + /// Computes SHA256 hash of the input text and returns as lowercase hex string. + /// + /// The text to hash. + /// 64-character lowercase hex string. + private static string ComputeSha256Hash(string text) + { + byte[] bytes = SHA256.HashData(Encoding.UTF8.GetBytes(text)); + return Convert.ToHexStringLower(bytes); + } +} diff --git a/src/Core/Embeddings/Cache/IEmbeddingCache.cs b/src/Core/Embeddings/Cache/IEmbeddingCache.cs new file mode 100644 index 000000000..201692161 --- /dev/null +++ b/src/Core/Embeddings/Cache/IEmbeddingCache.cs @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft. All rights reserved. +using KernelMemory.Core.Config.Enums; + +namespace KernelMemory.Core.Embeddings.Cache; + +/// +/// Interface for embedding cache implementations. +/// Supports dependency injection and multiple cache implementations (SQLite, etc.). 
+/// +public interface IEmbeddingCache +{ + /// + /// Cache mode (read-write, read-only, write-only). + /// Controls whether read and write operations are allowed. + /// + CacheModes Mode { get; } + + /// + /// Try to retrieve a cached embedding by key. + /// Returns null if not found or if mode is WriteOnly. + /// + /// The cache key to look up. + /// Cancellation token. + /// The cached embedding if found, null otherwise. + Task TryGetAsync(EmbeddingCacheKey key, CancellationToken ct = default); + + /// + /// Store an embedding in the cache. + /// Does nothing if mode is ReadOnly. + /// + /// The cache key. + /// The embedding vector to store. + /// Optional token count from the provider. + /// Cancellation token. + Task StoreAsync(EmbeddingCacheKey key, float[] vector, int? tokenCount, CancellationToken ct = default); +} diff --git a/src/Core/Embeddings/Cache/SqliteEmbeddingCache.cs b/src/Core/Embeddings/Cache/SqliteEmbeddingCache.cs new file mode 100644 index 000000000..babc9327c --- /dev/null +++ b/src/Core/Embeddings/Cache/SqliteEmbeddingCache.cs @@ -0,0 +1,204 @@ +// Copyright (c) Microsoft. All rights reserved. +using System.Globalization; +using KernelMemory.Core.Config.Enums; +using Microsoft.Data.Sqlite; +using Microsoft.Extensions.Logging; + +namespace KernelMemory.Core.Embeddings.Cache; + +/// +/// SQLite-based embedding cache implementation. +/// Stores embedding vectors as BLOBs for efficient storage. +/// Uses WAL mode for better concurrency support. 
+/// +public sealed class SqliteEmbeddingCache : IEmbeddingCache, IDisposable +{ + private const string CreateTableSql = """ + CREATE TABLE IF NOT EXISTS embeddings_cache ( + key TEXT PRIMARY KEY, + vector BLOB NOT NULL, + token_count INTEGER NULL, + timestamp TEXT NOT NULL + ); + CREATE INDEX IF NOT EXISTS idx_timestamp ON embeddings_cache(timestamp); + """; + + private const string SelectSql = "SELECT vector, token_count, timestamp FROM embeddings_cache WHERE key = @key"; + private const string UpsertSql = """ + INSERT INTO embeddings_cache (key, vector, token_count, timestamp) VALUES (@key, @vector, @tokenCount, @timestamp) + ON CONFLICT(key) DO UPDATE SET vector = @vector, token_count = @tokenCount, timestamp = @timestamp + """; + + private readonly SqliteConnection _connection; + private readonly ILogger _logger; + private bool _disposed; + + /// + public CacheModes Mode { get; } + + /// + /// Creates a new SQLite embedding cache. + /// The database file is created if it doesn't exist. + /// WAL mode is enabled for better concurrency. + /// + /// Path to the SQLite database file. + /// Cache mode (ReadWrite, ReadOnly, WriteOnly). + /// Logger instance. 
+ public SqliteEmbeddingCache(string dbPath, CacheModes mode, ILogger logger) + { + ArgumentNullException.ThrowIfNull(dbPath, nameof(dbPath)); + ArgumentNullException.ThrowIfNull(logger, nameof(logger)); + + this.Mode = mode; + this._logger = logger; + + // Ensure directory exists + var directory = Path.GetDirectoryName(dbPath); + if (!string.IsNullOrEmpty(directory) && !Directory.Exists(directory)) + { + Directory.CreateDirectory(directory); + this._logger.LogDebug("Created directory for embedding cache: {Directory}", directory); + } + + var connectionString = new SqliteConnectionStringBuilder + { + DataSource = dbPath, + Mode = SqliteOpenMode.ReadWriteCreate, + Cache = SqliteCacheMode.Shared + }.ToString(); + + this._connection = new SqliteConnection(connectionString); + this._connection.Open(); + + // Enable WAL mode for better concurrency + using var walCommand = this._connection.CreateCommand(); + walCommand.CommandText = "PRAGMA journal_mode=WAL;"; + walCommand.ExecuteNonQuery(); + + // Set busy timeout to handle concurrent access + using var busyCommand = this._connection.CreateCommand(); + busyCommand.CommandText = "PRAGMA busy_timeout=5000;"; + busyCommand.ExecuteNonQuery(); + + // Create table if not exists + using var createCommand = this._connection.CreateCommand(); + createCommand.CommandText = CreateTableSql; + createCommand.ExecuteNonQuery(); + + this._logger.LogInformation("Embedding cache initialized at {Path} with mode {Mode}", dbPath, mode); + } + + /// + public async Task TryGetAsync(EmbeddingCacheKey key, CancellationToken ct = default) + { + ct.ThrowIfCancellationRequested(); + + // Skip read in WriteOnly mode + if (this.Mode == CacheModes.WriteOnly) + { + this._logger.LogTrace("Skipping cache read in WriteOnly mode"); + return null; + } + + var compositeKey = key.ToCompositeKey(); + + var command = this._connection.CreateCommand(); + await using (command.ConfigureAwait(false)) + { + command.CommandText = SelectSql; + 
command.Parameters.AddWithValue("@key", compositeKey); + + var reader = await command.ExecuteReaderAsync(ct).ConfigureAwait(false); + await using (reader.ConfigureAwait(false)) + { + if (!await reader.ReadAsync(ct).ConfigureAwait(false)) + { + this._logger.LogTrace("Cache miss for key: {KeyPrefix}...", compositeKey[..Math.Min(50, compositeKey.Length)]); + return null; + } + + var vectorBlob = (byte[])reader["vector"]; + var vector = BytesToFloatArray(vectorBlob); + + int? tokenCount = reader["token_count"] == DBNull.Value ? null : Convert.ToInt32(reader["token_count"], CultureInfo.InvariantCulture); + var timestamp = DateTimeOffset.Parse((string)reader["timestamp"], CultureInfo.InvariantCulture); + + this._logger.LogTrace("Cache hit for key: {KeyPrefix}..., vector dimensions: {Dimensions}", + compositeKey[..Math.Min(50, compositeKey.Length)], vector.Length); + + return new CachedEmbedding + { + Vector = vector, + TokenCount = tokenCount, + Timestamp = timestamp + }; + } + } + } + + /// + public async Task StoreAsync(EmbeddingCacheKey key, float[] vector, int? tokenCount, CancellationToken ct = default) + { + ct.ThrowIfCancellationRequested(); + + // Skip write in ReadOnly mode + if (this.Mode == CacheModes.ReadOnly) + { + this._logger.LogTrace("Skipping cache write in ReadOnly mode"); + return; + } + + var compositeKey = key.ToCompositeKey(); + var vectorBlob = FloatArrayToBytes(vector); + var timestamp = DateTimeOffset.UtcNow.ToString("o", CultureInfo.InvariantCulture); + + var command = this._connection.CreateCommand(); + await using (command.ConfigureAwait(false)) + { + command.CommandText = UpsertSql; + command.Parameters.AddWithValue("@key", compositeKey); + command.Parameters.AddWithValue("@vector", vectorBlob); + command.Parameters.AddWithValue("@tokenCount", tokenCount.HasValue ? 
tokenCount.Value : DBNull.Value); + command.Parameters.AddWithValue("@timestamp", timestamp); + + await command.ExecuteNonQueryAsync(ct).ConfigureAwait(false); + + this._logger.LogTrace("Stored embedding in cache: {KeyPrefix}..., vector dimensions: {Dimensions}", + compositeKey[..Math.Min(50, compositeKey.Length)], vector.Length); + } + } + + /// + /// Converts a float array to a byte array for BLOB storage. + /// + private static byte[] FloatArrayToBytes(float[] array) + { + var bytes = new byte[array.Length * sizeof(float)]; + Buffer.BlockCopy(array, 0, bytes, 0, bytes.Length); + return bytes; + } + + /// + /// Converts a byte array from BLOB storage back to a float array. + /// + private static float[] BytesToFloatArray(byte[] bytes) + { + var array = new float[bytes.Length / sizeof(float)]; + Buffer.BlockCopy(bytes, 0, array, 0, bytes.Length); + return array; + } + + /// + /// Disposes the SQLite connection. + /// + public void Dispose() + { + if (this._disposed) { return; } + + this._connection.Close(); + this._connection.Dispose(); + this._disposed = true; + + this._logger.LogDebug("Embedding cache disposed"); + } +} diff --git a/src/Core/Embeddings/CachedEmbeddingGenerator.cs b/src/Core/Embeddings/CachedEmbeddingGenerator.cs new file mode 100644 index 000000000..26a100676 --- /dev/null +++ b/src/Core/Embeddings/CachedEmbeddingGenerator.cs @@ -0,0 +1,170 @@ +// Copyright (c) Microsoft. All rights reserved. +using KernelMemory.Core.Config.Enums; +using KernelMemory.Core.Embeddings.Cache; +using Microsoft.Extensions.Logging; + +namespace KernelMemory.Core.Embeddings; + +/// +/// Decorator that wraps an IEmbeddingGenerator with caching support. +/// Checks the cache before generating embeddings, and stores new embeddings in the cache. +/// Supports different cache modes (ReadWrite, ReadOnly, WriteOnly). 
+/// +public sealed class CachedEmbeddingGenerator : IEmbeddingGenerator +{ + private readonly IEmbeddingGenerator _inner; + private readonly IEmbeddingCache _cache; + private readonly ILogger _logger; + + /// + public EmbeddingsTypes ProviderType => this._inner.ProviderType; + + /// + public string ModelName => this._inner.ModelName; + + /// + public int VectorDimensions => this._inner.VectorDimensions; + + /// + public bool IsNormalized => this._inner.IsNormalized; + + /// + /// Creates a new cached embedding generator decorator. + /// + /// The inner generator to wrap. + /// The cache to use for storing/retrieving embeddings. + /// Logger instance. + /// When inner, cache, or logger is null. + public CachedEmbeddingGenerator( + IEmbeddingGenerator inner, + IEmbeddingCache cache, + ILogger logger) + { + ArgumentNullException.ThrowIfNull(inner, nameof(inner)); + ArgumentNullException.ThrowIfNull(cache, nameof(cache)); + ArgumentNullException.ThrowIfNull(logger, nameof(logger)); + + this._inner = inner; + this._cache = cache; + this._logger = logger; + + this._logger.LogDebug( + "CachedEmbeddingGenerator initialized for {Provider}/{Model} with cache mode {Mode}", + inner.ProviderType, inner.ModelName, cache.Mode); + } + + /// + public async Task GenerateAsync(string text, CancellationToken ct = default) + { + var key = this.BuildCacheKey(text); + + // Try cache read (if mode allows) + if (this._cache.Mode != CacheModes.WriteOnly) + { + var cached = await this._cache.TryGetAsync(key, ct).ConfigureAwait(false); + if (cached != null) + { + this._logger.LogDebug("Cache hit for single embedding, dimensions: {Dimensions}", cached.Vector.Length); + return cached.Vector; + } + } + + // Generate embedding + this._logger.LogDebug("Cache miss for single embedding, calling {Provider}", this.ProviderType); + var vector = await this._inner.GenerateAsync(text, ct).ConfigureAwait(false); + + // Store in cache (if mode allows) + if (this._cache.Mode != CacheModes.ReadOnly) + { + 
await this._cache.StoreAsync(key, vector, tokenCount: null, ct).ConfigureAwait(false); + this._logger.LogDebug("Stored embedding in cache, dimensions: {Dimensions}", vector.Length); + } + + return vector; + } + + /// + public async Task GenerateAsync(IEnumerable texts, CancellationToken ct = default) + { + var textList = texts.ToList(); + if (textList.Count == 0) + { + return Array.Empty(); + } + + // Initialize result array with nulls + var results = new float[textList.Count][]; + + // Track which texts need to be generated + var toGenerate = new List<(int Index, string Text)>(); + + // Try cache reads (if mode allows) + if (this._cache.Mode != CacheModes.WriteOnly) + { + for (int i = 0; i < textList.Count; i++) + { + var key = this.BuildCacheKey(textList[i]); + var cached = await this._cache.TryGetAsync(key, ct).ConfigureAwait(false); + + if (cached != null) + { + results[i] = cached.Vector; + } + else + { + toGenerate.Add((i, textList[i])); + } + } + + this._logger.LogDebug( + "Batch cache lookup: {HitCount} hits, {MissCount} misses", + textList.Count - toGenerate.Count, toGenerate.Count); + } + else + { + // WriteOnly mode - all texts need to be generated + for (int i = 0; i < textList.Count; i++) + { + toGenerate.Add((i, textList[i])); + } + } + + // Generate missing embeddings + if (toGenerate.Count > 0) + { + var textsToGenerate = toGenerate.Select(x => x.Text); + var generatedVectors = await this._inner.GenerateAsync(textsToGenerate, ct).ConfigureAwait(false); + + // Map generated vectors back to results and store in cache + for (int i = 0; i < toGenerate.Count; i++) + { + var (originalIndex, text) = toGenerate[i]; + results[originalIndex] = generatedVectors[i]; + + // Store in cache (if mode allows) + if (this._cache.Mode != CacheModes.ReadOnly) + { + var key = this.BuildCacheKey(text); + await this._cache.StoreAsync(key, generatedVectors[i], tokenCount: null, ct).ConfigureAwait(false); + } + } + + this._logger.LogDebug("Generated and cached {Count} 
embeddings", toGenerate.Count); + } + + return results; + } + + /// + /// Builds a cache key for the given text using the inner generator's properties. + /// + private EmbeddingCacheKey BuildCacheKey(string text) + { + return EmbeddingCacheKey.Create( + provider: this._inner.ProviderType.ToString(), + model: this._inner.ModelName, + vectorDimensions: this._inner.VectorDimensions, + isNormalized: this._inner.IsNormalized, + text: text); + } +} diff --git a/src/Core/Embeddings/EmbeddingConstants.cs b/src/Core/Embeddings/EmbeddingConstants.cs new file mode 100644 index 000000000..d8b9dcd4a --- /dev/null +++ b/src/Core/Embeddings/EmbeddingConstants.cs @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft. All rights reserved. +namespace KernelMemory.Core.Embeddings; + +/// +/// Constants for embedding generation including known model dimensions, +/// default configurations, and batch sizes. +/// +public static class EmbeddingConstants +{ + /// + /// Default batch size for embedding generation requests. + /// Configurable per provider, but this is the default. + /// + public const int DefaultBatchSize = 10; + + /// + /// Default Ollama model for embeddings. + /// + public const string DefaultOllamaModel = "qwen3-embedding"; + + /// + /// Default Ollama base URL. + /// + public const string DefaultOllamaBaseUrl = "http://localhost:11434"; + + /// + /// Default HuggingFace model for embeddings. + /// + public const string DefaultHuggingFaceModel = "sentence-transformers/all-MiniLM-L6-v2"; + + /// + /// Default HuggingFace Inference API base URL. + /// + public const string DefaultHuggingFaceBaseUrl = "https://api-inference.huggingface.co"; + + /// + /// Default OpenAI API base URL. + /// + public const string DefaultOpenAIBaseUrl = "https://api.openai.com"; + + /// + /// Azure OpenAI API version. + /// + public const string AzureOpenAIApiVersion = "2024-02-01"; + + /// + /// Known model dimensions for common embedding models. 
+ /// These values are fixed per model and used for validation and cache key generation. + /// + public static readonly IReadOnlyDictionary KnownModelDimensions = new Dictionary + { + // Ollama models + ["qwen3-embedding"] = 1024, + ["nomic-embed-text"] = 768, + ["embeddinggemma"] = 768, + + // OpenAI models + ["text-embedding-ada-002"] = 1536, + ["text-embedding-3-small"] = 1536, + ["text-embedding-3-large"] = 3072, + + // HuggingFace models + ["sentence-transformers/all-MiniLM-L6-v2"] = 384, + ["BAAI/bge-base-en-v1.5"] = 768 + }; + + /// + /// Try to get the dimensions for a known model. + /// + /// The model name to look up. + /// The dimensions if found, 0 otherwise. + /// True if the model is known, false otherwise. + public static bool TryGetDimensions(string modelName, out int dimensions) + { + return KnownModelDimensions.TryGetValue(modelName, out dimensions); + } +} diff --git a/src/Core/Embeddings/IEmbeddingGenerator.cs b/src/Core/Embeddings/IEmbeddingGenerator.cs new file mode 100644 index 000000000..a037269a3 --- /dev/null +++ b/src/Core/Embeddings/IEmbeddingGenerator.cs @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft. All rights reserved. +using KernelMemory.Core.Config.Enums; + +namespace KernelMemory.Core.Embeddings; + +/// +/// Interface for generating text embeddings using various providers. +/// Implementations handle communication with embedding APIs (OpenAI, Azure, Ollama, HuggingFace). +/// +public interface IEmbeddingGenerator +{ + /// + /// Provider type (OpenAI, Azure, Ollama, HuggingFace). + /// Used for cache key generation to ensure embeddings from different providers are not mixed. + /// + EmbeddingsTypes ProviderType { get; } + + /// + /// Model name being used for embedding generation. + /// Used for cache key generation to ensure embeddings from different models are not mixed. + /// + string ModelName { get; } + + /// + /// Vector dimensions produced by this model. + /// Used for validation and cache key generation. 
+ /// + int VectorDimensions { get; } + + /// + /// Whether this generator returns normalized vectors. + /// Normalized vectors have unit length (magnitude 1.0). + /// Used for cache key generation since normalized and non-normalized vectors are not interchangeable. + /// + bool IsNormalized { get; } + + /// + /// Generate embedding for a single text. + /// + /// The text to generate embedding for. + /// Cancellation token. + /// The embedding vector as a float array. + /// When the API call fails. + /// When the operation is cancelled. + Task GenerateAsync(string text, CancellationToken ct = default); + + /// + /// Generate embeddings for multiple texts (batch). + /// Implementations may send texts in batches to the API based on provider-specific limits. + /// + /// The texts to generate embeddings for. + /// Cancellation token. + /// Array of embedding vectors, in the same order as the input texts. + /// When the API call fails. + /// When the operation is cancelled. + Task GenerateAsync(IEnumerable texts, CancellationToken ct = default); +} diff --git a/src/Core/Embeddings/Providers/AzureOpenAIEmbeddingGenerator.cs b/src/Core/Embeddings/Providers/AzureOpenAIEmbeddingGenerator.cs new file mode 100644 index 000000000..3e7a61302 --- /dev/null +++ b/src/Core/Embeddings/Providers/AzureOpenAIEmbeddingGenerator.cs @@ -0,0 +1,163 @@ +// Copyright (c) Microsoft. All rights reserved. +using System.Net.Http.Json; +using System.Text.Json.Serialization; +using KernelMemory.Core.Config.Enums; +using Microsoft.Extensions.Logging; + +namespace KernelMemory.Core.Embeddings.Providers; + +/// +/// Azure OpenAI embedding generator implementation. +/// Communicates with Azure OpenAI Service. +/// Supports API key authentication (managed identity would require Azure.Identity package). 
/// <summary>
/// Azure OpenAI embedding generator implementation.
/// Calls the Azure OpenAI embeddings REST API using API key authentication.
/// Batch requests are supported; results are re-ordered by the "index" field
/// returned by the service so output order always matches input order.
/// </summary>
public sealed class AzureOpenAIEmbeddingGenerator : IEmbeddingGenerator
{
    private readonly HttpClient _httpClient;
    private readonly string _endpoint;   // normalized: trailing '/' removed in ctor
    private readonly string _deployment;
    private readonly string _apiKey;
    private readonly ILogger _logger;

    /// <inheritdoc/>
    public EmbeddingsTypes ProviderType => EmbeddingsTypes.AzureOpenAI;

    /// <inheritdoc/>
    public string ModelName { get; }

    /// <inheritdoc/>
    public int VectorDimensions { get; }

    /// <inheritdoc/>
    public bool IsNormalized { get; }

    /// <summary>
    /// Creates a new Azure OpenAI embedding generator with API key authentication.
    /// </summary>
    /// <param name="httpClient">HTTP client for API calls.</param>
    /// <param name="endpoint">Azure OpenAI endpoint (e.g., https://myservice.openai.azure.com).</param>
    /// <param name="deployment">Deployment name in Azure.</param>
    /// <param name="model">Model name for identification.</param>
    /// <param name="apiKey">Azure OpenAI API key.</param>
    /// <param name="vectorDimensions">Vector dimensions produced by the model.</param>
    /// <param name="isNormalized">Whether vectors are normalized.</param>
    /// <param name="logger">Logger instance.</param>
    public AzureOpenAIEmbeddingGenerator(
        HttpClient httpClient,
        string endpoint,
        string deployment,
        string model,
        string apiKey,
        int vectorDimensions,
        bool isNormalized,
        ILogger logger)
    {
        ArgumentNullException.ThrowIfNull(httpClient, nameof(httpClient));
        ArgumentNullException.ThrowIfNull(endpoint, nameof(endpoint));
        ArgumentNullException.ThrowIfNull(deployment, nameof(deployment));
        ArgumentNullException.ThrowIfNull(model, nameof(model));
        ArgumentNullException.ThrowIfNull(apiKey, nameof(apiKey));
        ArgumentNullException.ThrowIfNull(logger, nameof(logger));

        this._httpClient = httpClient;
        this._endpoint = endpoint.TrimEnd('/');
        this._deployment = deployment;
        this._apiKey = apiKey;
        this.ModelName = model;
        this.VectorDimensions = vectorDimensions;
        this.IsNormalized = isNormalized;
        this._logger = logger;

        this._logger.LogDebug("AzureOpenAIEmbeddingGenerator initialized: {Endpoint}, deployment: {Deployment}, model: {Model}",
            this._endpoint, this._deployment, this.ModelName);
    }

    /// <inheritdoc/>
    public async Task<float[]> GenerateAsync(string text, CancellationToken ct = default)
    {
        // Single-text generation is a degenerate batch of one.
        var results = await this.GenerateAsync(new[] { text }, ct).ConfigureAwait(false);
        return results[0];
    }

    /// <inheritdoc/>
    public async Task<float[][]> GenerateAsync(IEnumerable<string> texts, CancellationToken ct = default)
    {
        var textArray = texts.ToArray();
        if (textArray.Length == 0)
        {
            return Array.Empty<float[]>();
        }

        var url = $"{this._endpoint}/openai/deployments/{this._deployment}/embeddings?api-version={EmbeddingConstants.AzureOpenAIApiVersion}";

        var request = new AzureEmbeddingRequest
        {
            Input = textArray
        };

        using var httpRequest = new HttpRequestMessage(HttpMethod.Post, url);
        httpRequest.Headers.Add("api-key", this._apiKey);
        httpRequest.Content = JsonContent.Create(request);

        this._logger.LogTrace("Calling Azure OpenAI embeddings API: deployment={Deployment}, batch size: {BatchSize}",
            this._deployment, textArray.Length);

        // FIX: dispose the response message so the underlying connection and
        // content buffers are released deterministically (previously it was
        // left to the garbage collector).
        using var response = await this._httpClient.SendAsync(httpRequest, ct).ConfigureAwait(false);
        response.EnsureSuccessStatusCode();

        var result = await response.Content.ReadFromJsonAsync<AzureEmbeddingResponse>(ct).ConfigureAwait(false);

        if (result?.Data == null || result.Data.Length == 0)
        {
            throw new InvalidOperationException("Azure OpenAI returned empty embedding response");
        }

        // Sort by index to ensure correct ordering
        var sortedData = result.Data.OrderBy(d => d.Index).ToArray();
        var embeddings = sortedData.Select(d => d.Embedding).ToArray();

        this._logger.LogTrace("Azure OpenAI returned {Count} embeddings, usage: {TotalTokens} tokens",
            embeddings.Length, result.Usage?.TotalTokens);

        return embeddings;
    }

    /// <summary>
    /// Request body for Azure OpenAI embeddings API.
    /// </summary>
    private sealed class AzureEmbeddingRequest
    {
        [JsonPropertyName("input")]
        public string[] Input { get; set; } = Array.Empty<string>();
    }

    /// <summary>
    /// Response from Azure OpenAI embeddings API.
    /// </summary>
    private sealed class AzureEmbeddingResponse
    {
        [JsonPropertyName("data")]
        public EmbeddingData[] Data { get; set; } = Array.Empty<EmbeddingData>();

        [JsonPropertyName("usage")]
        public UsageInfo? Usage { get; set; }
    }

    private sealed class EmbeddingData
    {
        [JsonPropertyName("index")]
        public int Index { get; set; }

        [JsonPropertyName("embedding")]
        public float[] Embedding { get; set; } = Array.Empty<float>();
    }

    private sealed class UsageInfo
    {
        [JsonPropertyName("prompt_tokens")]
        public int PromptTokens { get; set; }

        [JsonPropertyName("total_tokens")]
        public int TotalTokens { get; set; }
    }
}
/// <summary>
/// HuggingFace Inference API embedding generator implementation.
/// Communicates with the HuggingFace serverless Inference API.
/// Default model: sentence-transformers/all-MiniLM-L6-v2.
/// The API returns the embedding matrix directly as a float[][] JSON array.
/// </summary>
public sealed class HuggingFaceEmbeddingGenerator : IEmbeddingGenerator
{
    private readonly HttpClient _httpClient;
    private readonly string _apiKey;
    private readonly string _baseUrl;   // normalized: trailing '/' removed in ctor
    private readonly ILogger _logger;

    /// <inheritdoc/>
    public EmbeddingsTypes ProviderType => EmbeddingsTypes.HuggingFace;

    /// <inheritdoc/>
    public string ModelName { get; }

    /// <inheritdoc/>
    public int VectorDimensions { get; }

    /// <inheritdoc/>
    public bool IsNormalized { get; }

    /// <summary>
    /// Creates a new HuggingFace embedding generator.
    /// </summary>
    /// <param name="httpClient">HTTP client for API calls.</param>
    /// <param name="apiKey">HuggingFace API token (HF_TOKEN).</param>
    /// <param name="model">Model name (e.g., sentence-transformers/all-MiniLM-L6-v2).</param>
    /// <param name="vectorDimensions">Vector dimensions produced by the model.</param>
    /// <param name="isNormalized">Whether vectors are normalized.</param>
    /// <param name="baseUrl">Optional custom base URL for inference endpoints.</param>
    /// <param name="logger">Logger instance.</param>
    public HuggingFaceEmbeddingGenerator(
        HttpClient httpClient,
        string apiKey,
        string model,
        int vectorDimensions,
        bool isNormalized,
        string? baseUrl,
        ILogger logger)
    {
        ArgumentNullException.ThrowIfNull(httpClient, nameof(httpClient));
        // FIX: ThrowIfNullOrEmpty already throws ArgumentNullException when the
        // value is null, so the previous separate ThrowIfNull(apiKey) call was
        // redundant and has been removed (observable behavior is unchanged).
        ArgumentException.ThrowIfNullOrEmpty(apiKey, nameof(apiKey));
        ArgumentNullException.ThrowIfNull(model, nameof(model));
        ArgumentNullException.ThrowIfNull(logger, nameof(logger));

        this._httpClient = httpClient;
        this._apiKey = apiKey;
        this._baseUrl = (baseUrl ?? EmbeddingConstants.DefaultHuggingFaceBaseUrl).TrimEnd('/');
        this.ModelName = model;
        this.VectorDimensions = vectorDimensions;
        this.IsNormalized = isNormalized;
        this._logger = logger;

        this._logger.LogDebug("HuggingFaceEmbeddingGenerator initialized: {BaseUrl}, model: {Model}, dimensions: {Dimensions}",
            this._baseUrl, this.ModelName, this.VectorDimensions);
    }

    /// <inheritdoc/>
    public async Task<float[]> GenerateAsync(string text, CancellationToken ct = default)
    {
        // Single-text generation is a degenerate batch of one.
        var results = await this.GenerateAsync(new[] { text }, ct).ConfigureAwait(false);
        return results[0];
    }

    /// <inheritdoc/>
    public async Task<float[][]> GenerateAsync(IEnumerable<string> texts, CancellationToken ct = default)
    {
        var textArray = texts.ToArray();
        if (textArray.Length == 0)
        {
            return Array.Empty<float[]>();
        }

        var endpoint = $"{this._baseUrl}/models/{this.ModelName}";

        var request = new HuggingFaceRequest
        {
            Inputs = textArray
        };

        using var httpRequest = new HttpRequestMessage(HttpMethod.Post, endpoint);
        httpRequest.Headers.Authorization = new AuthenticationHeaderValue("Bearer", this._apiKey);
        httpRequest.Content = JsonContent.Create(request);

        this._logger.LogTrace("Calling HuggingFace embeddings API: {Endpoint}, batch size: {BatchSize}",
            endpoint, textArray.Length);

        // FIX: dispose the response message so connection and content buffers
        // are released deterministically (previously left to the GC).
        using var response = await this._httpClient.SendAsync(httpRequest, ct).ConfigureAwait(false);
        response.EnsureSuccessStatusCode();

        // HuggingFace returns array of embeddings directly: float[][]
        var embeddings = await response.Content.ReadFromJsonAsync<float[][]>(ct).ConfigureAwait(false);

        if (embeddings == null || embeddings.Length == 0)
        {
            throw new InvalidOperationException("HuggingFace returned empty embedding response");
        }

        this._logger.LogTrace("HuggingFace returned {Count} embeddings with {Dimensions} dimensions each",
            embeddings.Length, embeddings[0].Length);

        return embeddings;
    }

    /// <summary>
    /// Request body for HuggingFace Inference API.
    /// </summary>
    private sealed class HuggingFaceRequest
    {
        [JsonPropertyName("inputs")]
        public string[] Inputs { get; set; } = Array.Empty<string>();
    }
}
/// <summary>
/// Ollama embedding generator implementation.
/// Communicates with a local Ollama instance via HTTP.
/// Default model: qwen3-embedding.
/// The /api/embeddings endpoint accepts a single prompt, so batch requests
/// are processed sequentially, one HTTP call per input text.
/// </summary>
public sealed class OllamaEmbeddingGenerator : IEmbeddingGenerator
{
    private readonly HttpClient _httpClient;
    private readonly string _baseUrl;   // normalized: trailing '/' removed in ctor
    private readonly ILogger _logger;

    /// <inheritdoc/>
    public EmbeddingsTypes ProviderType => EmbeddingsTypes.Ollama;

    /// <inheritdoc/>
    public string ModelName { get; }

    /// <inheritdoc/>
    public int VectorDimensions { get; }

    /// <inheritdoc/>
    public bool IsNormalized { get; }

    /// <summary>
    /// Creates a new Ollama embedding generator.
    /// </summary>
    /// <param name="httpClient">HTTP client for API calls.</param>
    /// <param name="baseUrl">Ollama base URL (e.g., http://localhost:11434).</param>
    /// <param name="model">Model name (e.g., qwen3-embedding).</param>
    /// <param name="vectorDimensions">Vector dimensions produced by the model.</param>
    /// <param name="isNormalized">Whether vectors are normalized.</param>
    /// <param name="logger">Logger instance.</param>
    public OllamaEmbeddingGenerator(
        HttpClient httpClient,
        string baseUrl,
        string model,
        int vectorDimensions,
        bool isNormalized,
        ILogger logger)
    {
        ArgumentNullException.ThrowIfNull(httpClient, nameof(httpClient));
        ArgumentNullException.ThrowIfNull(baseUrl, nameof(baseUrl));
        ArgumentNullException.ThrowIfNull(model, nameof(model));
        ArgumentNullException.ThrowIfNull(logger, nameof(logger));

        this._httpClient = httpClient;
        this._baseUrl = baseUrl.TrimEnd('/');
        this.ModelName = model;
        this.VectorDimensions = vectorDimensions;
        this.IsNormalized = isNormalized;
        this._logger = logger;

        this._logger.LogDebug("OllamaEmbeddingGenerator initialized: {BaseUrl}, model: {Model}, dimensions: {Dimensions}",
            this._baseUrl, this.ModelName, this.VectorDimensions);
    }

    /// <inheritdoc/>
    public async Task<float[]> GenerateAsync(string text, CancellationToken ct = default)
    {
        var endpoint = $"{this._baseUrl}/api/embeddings";

        var request = new OllamaEmbeddingRequest
        {
            Model = this.ModelName,
            Prompt = text
        };

        this._logger.LogTrace("Calling Ollama embeddings API: {Endpoint}", endpoint);

        // FIX: dispose the response message so connection and content buffers
        // are released deterministically (previously left to the GC).
        using var response = await this._httpClient.PostAsJsonAsync(endpoint, request, ct).ConfigureAwait(false);
        response.EnsureSuccessStatusCode();

        var result = await response.Content.ReadFromJsonAsync<OllamaEmbeddingResponse>(ct).ConfigureAwait(false);

        if (result?.Embedding == null || result.Embedding.Length == 0)
        {
            throw new InvalidOperationException("Ollama returned empty embedding");
        }

        this._logger.LogTrace("Ollama returned embedding with {Dimensions} dimensions", result.Embedding.Length);

        return result.Embedding;
    }

    /// <inheritdoc/>
    public async Task<float[][]> GenerateAsync(IEnumerable<string> texts, CancellationToken ct = default)
    {
        // Ollama doesn't support batch embedding natively, so process one at a time
        var textList = texts.ToList();
        var results = new float[textList.Count][];

        this._logger.LogDebug("Generating {Count} embeddings via Ollama (sequential)", textList.Count);

        for (int i = 0; i < textList.Count; i++)
        {
            results[i] = await this.GenerateAsync(textList[i], ct).ConfigureAwait(false);
        }

        return results;
    }

    /// <summary>
    /// Request body for Ollama embeddings API.
    /// </summary>
    private sealed class OllamaEmbeddingRequest
    {
        [JsonPropertyName("model")]
        public string Model { get; set; } = string.Empty;

        [JsonPropertyName("prompt")]
        public string Prompt { get; set; } = string.Empty;
    }

    /// <summary>
    /// Response from Ollama embeddings API.
    /// </summary>
    private sealed class OllamaEmbeddingResponse
    {
        [JsonPropertyName("embedding")]
        public float[] Embedding { get; set; } = Array.Empty<float>();
    }
}
/// <summary>
/// OpenAI embedding generator implementation.
/// Communicates with the OpenAI API (or compatible endpoints).
/// Supports batch embedding requests; results are re-ordered by the "index"
/// field returned by the service so output order always matches input order.
/// </summary>
public sealed class OpenAIEmbeddingGenerator : IEmbeddingGenerator
{
    private readonly HttpClient _httpClient;
    private readonly string _apiKey;
    private readonly string _baseUrl;   // normalized: trailing '/' removed in ctor
    private readonly ILogger _logger;

    /// <inheritdoc/>
    public EmbeddingsTypes ProviderType => EmbeddingsTypes.OpenAI;

    /// <inheritdoc/>
    public string ModelName { get; }

    /// <inheritdoc/>
    public int VectorDimensions { get; }

    /// <inheritdoc/>
    public bool IsNormalized { get; }

    /// <summary>
    /// Creates a new OpenAI embedding generator.
    /// </summary>
    /// <param name="httpClient">HTTP client for API calls.</param>
    /// <param name="apiKey">OpenAI API key.</param>
    /// <param name="model">Model name (e.g., text-embedding-ada-002).</param>
    /// <param name="vectorDimensions">Vector dimensions produced by the model.</param>
    /// <param name="isNormalized">Whether vectors are normalized.</param>
    /// <param name="baseUrl">Optional custom base URL for OpenAI-compatible APIs.</param>
    /// <param name="logger">Logger instance.</param>
    public OpenAIEmbeddingGenerator(
        HttpClient httpClient,
        string apiKey,
        string model,
        int vectorDimensions,
        bool isNormalized,
        string? baseUrl,
        ILogger logger)
    {
        ArgumentNullException.ThrowIfNull(httpClient, nameof(httpClient));
        // FIX: ThrowIfNullOrEmpty already throws ArgumentNullException when the
        // value is null, so the previous separate ThrowIfNull(apiKey) call was
        // redundant and has been removed (observable behavior is unchanged).
        ArgumentException.ThrowIfNullOrEmpty(apiKey, nameof(apiKey));
        ArgumentNullException.ThrowIfNull(model, nameof(model));
        ArgumentNullException.ThrowIfNull(logger, nameof(logger));

        this._httpClient = httpClient;
        this._apiKey = apiKey;
        this._baseUrl = (baseUrl ?? EmbeddingConstants.DefaultOpenAIBaseUrl).TrimEnd('/');
        this.ModelName = model;
        this.VectorDimensions = vectorDimensions;
        this.IsNormalized = isNormalized;
        this._logger = logger;

        this._logger.LogDebug("OpenAIEmbeddingGenerator initialized: {BaseUrl}, model: {Model}, dimensions: {Dimensions}",
            this._baseUrl, this.ModelName, this.VectorDimensions);
    }

    /// <inheritdoc/>
    public async Task<float[]> GenerateAsync(string text, CancellationToken ct = default)
    {
        // Single-text generation is a degenerate batch of one.
        var results = await this.GenerateAsync(new[] { text }, ct).ConfigureAwait(false);
        return results[0];
    }

    /// <inheritdoc/>
    public async Task<float[][]> GenerateAsync(IEnumerable<string> texts, CancellationToken ct = default)
    {
        var textArray = texts.ToArray();
        if (textArray.Length == 0)
        {
            return Array.Empty<float[]>();
        }

        var endpoint = $"{this._baseUrl}/v1/embeddings";

        var request = new OpenAIEmbeddingRequest
        {
            Model = this.ModelName,
            Input = textArray
        };

        using var httpRequest = new HttpRequestMessage(HttpMethod.Post, endpoint);
        httpRequest.Headers.Authorization = new AuthenticationHeaderValue("Bearer", this._apiKey);
        httpRequest.Content = JsonContent.Create(request);

        this._logger.LogTrace("Calling OpenAI embeddings API: {Endpoint}, batch size: {BatchSize}",
            endpoint, textArray.Length);

        // FIX: dispose the response message so connection and content buffers
        // are released deterministically (previously left to the GC).
        using var response = await this._httpClient.SendAsync(httpRequest, ct).ConfigureAwait(false);
        response.EnsureSuccessStatusCode();

        var result = await response.Content.ReadFromJsonAsync<OpenAIEmbeddingResponse>(ct).ConfigureAwait(false);

        if (result?.Data == null || result.Data.Length == 0)
        {
            throw new InvalidOperationException("OpenAI returned empty embedding response");
        }

        // Sort by index to ensure correct ordering
        var sortedData = result.Data.OrderBy(d => d.Index).ToArray();
        var embeddings = sortedData.Select(d => d.Embedding).ToArray();

        this._logger.LogTrace("OpenAI returned {Count} embeddings, usage: {TotalTokens} tokens",
            embeddings.Length, result.Usage?.TotalTokens);

        return embeddings;
    }

    /// <summary>
    /// Request body for OpenAI embeddings API.
    /// </summary>
    private sealed class OpenAIEmbeddingRequest
    {
        [JsonPropertyName("model")]
        public string Model { get; set; } = string.Empty;

        [JsonPropertyName("input")]
        public string[] Input { get; set; } = Array.Empty<string>();
    }

    /// <summary>
    /// Response from OpenAI embeddings API.
    /// </summary>
    private sealed class OpenAIEmbeddingResponse
    {
        [JsonPropertyName("data")]
        public EmbeddingData[] Data { get; set; } = Array.Empty<EmbeddingData>();

        [JsonPropertyName("usage")]
        public UsageInfo? Usage { get; set; }
    }

    private sealed class EmbeddingData
    {
        [JsonPropertyName("index")]
        public int Index { get; set; }

        [JsonPropertyName("embedding")]
        public float[] Embedding { get; set; } = Array.Empty<float>();
    }

    private sealed class UsageInfo
    {
        [JsonPropertyName("prompt_tokens")]
        public int PromptTokens { get; set; }

        [JsonPropertyName("total_tokens")]
        public int TotalTokens { get; set; }
    }
}
+/// +public sealed class CacheModesTests +{ + [Fact] + public void CacheModes_ShouldHaveReadWriteValue() + { + // Assert + Assert.True(Enum.IsDefined(CacheModes.ReadWrite)); + } + + [Fact] + public void CacheModes_ShouldHaveReadOnlyValue() + { + // Assert + Assert.True(Enum.IsDefined(CacheModes.ReadOnly)); + } + + [Fact] + public void CacheModes_ShouldHaveWriteOnlyValue() + { + // Assert + Assert.True(Enum.IsDefined(CacheModes.WriteOnly)); + } + + [Fact] + public void CacheModes_ShouldHaveExactlyThreeValues() + { + // Assert + var values = Enum.GetValues(); + Assert.Equal(3, values.Length); + } + + [Theory] + [InlineData("ReadWrite", CacheModes.ReadWrite)] + [InlineData("ReadOnly", CacheModes.ReadOnly)] + [InlineData("WriteOnly", CacheModes.WriteOnly)] + public void CacheModes_ShouldParseFromString(string name, CacheModes expected) + { + // Act + var parsed = Enum.Parse(name); + + // Assert + Assert.Equal(expected, parsed); + } +} diff --git a/tests/Core.Tests/Config/HuggingFaceEmbeddingsConfigTests.cs b/tests/Core.Tests/Config/HuggingFaceEmbeddingsConfigTests.cs new file mode 100644 index 000000000..bc39ad2f3 --- /dev/null +++ b/tests/Core.Tests/Config/HuggingFaceEmbeddingsConfigTests.cs @@ -0,0 +1,186 @@ +// Copyright (c) Microsoft. All rights reserved. +using KernelMemory.Core.Config; +using KernelMemory.Core.Config.Embeddings; +using KernelMemory.Core.Config.Enums; +using KernelMemory.Core.Config.Validation; + +namespace KernelMemory.Core.Tests.Config; + +/// +/// Tests for HuggingFaceEmbeddingsConfig to verify configuration parsing, validation, and defaults. 
/// <summary>
/// Tests for HuggingFaceEmbeddingsConfig covering default values, Validate()
/// failure modes, and round-trip deserialization via ConfigParser.
/// </summary>
public sealed class HuggingFaceEmbeddingsConfigTests
{
    [Fact]
    public void Type_ShouldReturnHuggingFace()
    {
        // Arrange
        var config = new HuggingFaceEmbeddingsConfig();

        // Assert: discriminator must identify the HuggingFace provider.
        Assert.Equal(EmbeddingsTypes.HuggingFace, config.Type);
    }

    [Fact]
    public void DefaultModel_ShouldBeAllMiniLM()
    {
        // Arrange
        var config = new HuggingFaceEmbeddingsConfig();

        // Assert
        Assert.Equal("sentence-transformers/all-MiniLM-L6-v2", config.Model);
    }

    [Fact]
    public void DefaultBaseUrl_ShouldBeInferenceApi()
    {
        // Arrange
        var config = new HuggingFaceEmbeddingsConfig();

        // Assert
        Assert.Equal("https://api-inference.huggingface.co", config.BaseUrl);
    }

    [Fact]
    public void ApiKey_ShouldDefaultToNull()
    {
        // Arrange
        var config = new HuggingFaceEmbeddingsConfig();

        // Assert: no token baked in by default.
        Assert.Null(config.ApiKey);
    }

    [Fact]
    public void Validate_WithValidConfig_ShouldNotThrow()
    {
        // Arrange
        var config = new HuggingFaceEmbeddingsConfig
        {
            Model = "BAAI/bge-base-en-v1.5",
            ApiKey = "hf_test_token",
            BaseUrl = "https://api-inference.huggingface.co"
        };

        // Act & Assert - should not throw
        config.Validate("test");
    }

    [Fact]
    public void Validate_WithEmptyModel_ShouldThrowConfigException()
    {
        // Arrange
        var config = new HuggingFaceEmbeddingsConfig
        {
            Model = "",
            ApiKey = "hf_test_token"
        };

        // Act & Assert: the error message should point at the Model field.
        var ex = Assert.Throws<ConfigException>(() => config.Validate("test"));
        Assert.Contains("Model", ex.Message);
    }

    [Fact]
    public void Validate_WithEmptyApiKey_ShouldThrowConfigException()
    {
        // Arrange
        var config = new HuggingFaceEmbeddingsConfig
        {
            Model = "model",
            ApiKey = ""
        };

        // Act & Assert
        var ex = Assert.Throws<ConfigException>(() => config.Validate("test"));
        Assert.Contains("API", ex.Message, StringComparison.OrdinalIgnoreCase);
    }

    [Fact]
    public void Validate_WithNullApiKey_ShouldThrowConfigException()
    {
        // Arrange
        var config = new HuggingFaceEmbeddingsConfig
        {
            Model = "model",
            ApiKey = null
        };

        // Act & Assert
        var ex = Assert.Throws<ConfigException>(() => config.Validate("test"));
        Assert.Contains("API", ex.Message, StringComparison.OrdinalIgnoreCase);
    }

    [Fact]
    public void Validate_WithInvalidBaseUrl_ShouldThrowConfigException()
    {
        // Arrange
        var config = new HuggingFaceEmbeddingsConfig
        {
            Model = "model",
            ApiKey = "hf_token",
            BaseUrl = "not-a-valid-url"
        };

        // Act & Assert
        var ex = Assert.Throws<ConfigException>(() => config.Validate("test"));
        Assert.Contains("URL", ex.Message, StringComparison.OrdinalIgnoreCase);
    }

    [Fact]
    public void LoadFromFile_WithHuggingFaceEmbeddings_ShouldDeserialize()
    {
        // Arrange: write a config file containing a huggingFace embeddings node.
        var tempFile = Path.Combine(Path.GetTempPath(), $"config-{Guid.NewGuid()}.json");
        const string json = @"{
    ""nodes"": {
        ""test"": {
            ""id"": ""test"",
            ""access"": ""Full"",
            ""contentIndex"": {
                ""type"": ""sqlite"",
                ""path"": ""test.db""
            },
            ""searchIndexes"": [
                {
                    ""type"": ""sqliteVector"",
                    ""id"": ""vector-test"",
                    ""path"": ""vector.db"",
                    ""dimensions"": 384,
                    ""embeddings"": {
                        ""type"": ""huggingFace"",
                        ""model"": ""sentence-transformers/all-MiniLM-L6-v2"",
                        ""apiKey"": ""hf_test_token"",
                        ""baseUrl"": ""https://api-inference.huggingface.co""
                    }
                }
            ]
        }
    }
}";

        try
        {
            File.WriteAllText(tempFile, json);

            // Act
            var config = ConfigParser.LoadFromFile(tempFile);

            // Assert: the polymorphic discriminator resolves to the HF config type.
            Assert.NotNull(config);
            var searchIndex = config.Nodes["test"].SearchIndexes[0];
            Assert.NotNull(searchIndex.Embeddings);
            var hfConfig = Assert.IsType<HuggingFaceEmbeddingsConfig>(searchIndex.Embeddings);
            Assert.Equal("sentence-transformers/all-MiniLM-L6-v2", hfConfig.Model);
            Assert.Equal("hf_test_token", hfConfig.ApiKey);
            Assert.Equal("https://api-inference.huggingface.co", hfConfig.BaseUrl);
        }
        finally
        {
            // Always remove the temp file, even when an assertion failed.
            if (File.Exists(tempFile))
            {
                File.Delete(tempFile);
            }
        }
    }
}
/// <summary>
/// Tests for SqliteEmbeddingCache covering CRUD round-trips, cache modes
/// (ReadWrite / ReadOnly / WriteOnly), BLOB float precision, privacy (input
/// text never persisted), persistence across reopen, and cancellation.
/// Each test uses a fresh temporary SQLite database for isolation.
/// </summary>
public sealed class SqliteEmbeddingCacheTests : IDisposable
{
    private readonly string _tempDbPath;
    private readonly Mock<ILogger<SqliteEmbeddingCache>> _loggerMock;

    public SqliteEmbeddingCacheTests()
    {
        this._tempDbPath = Path.Combine(Path.GetTempPath(), $"test-cache-{Guid.NewGuid()}.db");
        this._loggerMock = new Mock<ILogger<SqliteEmbeddingCache>>();
    }

    public void Dispose()
    {
        // Clean up test database
        if (File.Exists(this._tempDbPath))
        {
            File.Delete(this._tempDbPath);
        }
    }

    [Fact]
    public async Task TryGetAsync_WithEmptyCache_ShouldReturnNull()
    {
        // Arrange
        using var cache = new SqliteEmbeddingCache(this._tempDbPath, CacheModes.ReadWrite, this._loggerMock.Object);
        var key = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, "test text");

        // Act
        var result = await cache.TryGetAsync(key, CancellationToken.None).ConfigureAwait(false);

        // Assert: nothing stored yet, so a miss is expected.
        Assert.Null(result);
    }

    [Fact]
    public async Task StoreAsync_AndTryGetAsync_ShouldRoundTrip()
    {
        // Arrange
        using var cache = new SqliteEmbeddingCache(this._tempDbPath, CacheModes.ReadWrite, this._loggerMock.Object);
        var key = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, "test text");
        var vector = new float[] { 0.1f, 0.2f, 0.3f, 0.4f };

        // Act
        await cache.StoreAsync(key, vector, tokenCount: null, CancellationToken.None).ConfigureAwait(false);
        var result = await cache.TryGetAsync(key, CancellationToken.None).ConfigureAwait(false);

        // Assert
        Assert.NotNull(result);
        Assert.Equal(vector, result.Vector);
    }

    [Fact]
    public async Task StoreAsync_WithTokenCount_ShouldPreserveValue()
    {
        // Arrange
        using var cache = new SqliteEmbeddingCache(this._tempDbPath, CacheModes.ReadWrite, this._loggerMock.Object);
        var key = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, "test text");
        var vector = new float[] { 0.1f, 0.2f, 0.3f };
        const int tokenCount = 42;

        // Act
        await cache.StoreAsync(key, vector, tokenCount, CancellationToken.None).ConfigureAwait(false);
        var result = await cache.TryGetAsync(key, CancellationToken.None).ConfigureAwait(false);

        // Assert
        Assert.NotNull(result);
        Assert.Equal(tokenCount, result.TokenCount);
    }

    [Fact]
    public async Task StoreAsync_ShouldSetTimestamp()
    {
        // Arrange
        using var cache = new SqliteEmbeddingCache(this._tempDbPath, CacheModes.ReadWrite, this._loggerMock.Object);
        var key = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, "test text");
        var vector = new float[] { 0.1f, 0.2f, 0.3f };
        var beforeStore = DateTimeOffset.UtcNow;

        // Act
        await cache.StoreAsync(key, vector, tokenCount: null, CancellationToken.None).ConfigureAwait(false);
        var result = await cache.TryGetAsync(key, CancellationToken.None).ConfigureAwait(false);
        var afterStore = DateTimeOffset.UtcNow;

        // Assert: timestamp falls within the store window, +/- 1s tolerance
        // to absorb clock granularity in storage.
        Assert.NotNull(result);
        Assert.True(result.Timestamp >= beforeStore.AddSeconds(-1));
        Assert.True(result.Timestamp <= afterStore.AddSeconds(1));
    }

    [Fact]
    public async Task StoreAsync_WithLargeVector_ShouldRoundTrip()
    {
        // Arrange - 1536 dimensions (OpenAI ada-002)
        using var cache = new SqliteEmbeddingCache(this._tempDbPath, CacheModes.ReadWrite, this._loggerMock.Object);
        var key = EmbeddingCacheKey.Create("OpenAI", "text-embedding-ada-002", 1536, true, "test text");
        var vector = new float[1536];
        for (int i = 0; i < vector.Length; i++)
        {
            vector[i] = (float)i / 1536;
        }

        // Act
        await cache.StoreAsync(key, vector, tokenCount: null, CancellationToken.None).ConfigureAwait(false);
        var result = await cache.TryGetAsync(key, CancellationToken.None).ConfigureAwait(false);

        // Assert: every element survives the BLOB round-trip.
        Assert.NotNull(result);
        Assert.Equal(1536, result.Vector.Length);
        for (int i = 0; i < vector.Length; i++)
        {
            Assert.Equal(vector[i], result.Vector[i]);
        }
    }

    [Fact]
    public async Task StoreAsync_WithSameKey_ShouldOverwrite()
    {
        // Arrange
        using var cache = new SqliteEmbeddingCache(this._tempDbPath, CacheModes.ReadWrite, this._loggerMock.Object);
        var key = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, "test text");
        var vector1 = new float[] { 0.1f, 0.2f, 0.3f };
        var vector2 = new float[] { 0.9f, 0.8f, 0.7f };

        // Act: second store under the same key wins.
        await cache.StoreAsync(key, vector1, tokenCount: null, CancellationToken.None).ConfigureAwait(false);
        await cache.StoreAsync(key, vector2, tokenCount: null, CancellationToken.None).ConfigureAwait(false);
        var result = await cache.TryGetAsync(key, CancellationToken.None).ConfigureAwait(false);

        // Assert
        Assert.NotNull(result);
        Assert.Equal(vector2, result.Vector);
    }

    [Fact]
    public async Task TryGetAsync_WithDifferentKeys_ShouldReturnCorrectValues()
    {
        // Arrange: two distinct texts produce two distinct keys.
        using var cache = new SqliteEmbeddingCache(this._tempDbPath, CacheModes.ReadWrite, this._loggerMock.Object);
        var key1 = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, "text one");
        var key2 = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, "text two");
        var vector1 = new float[] { 0.1f, 0.2f, 0.3f };
        var vector2 = new float[] { 0.4f, 0.5f, 0.6f };

        // Act
        await cache.StoreAsync(key1, vector1, tokenCount: null, CancellationToken.None).ConfigureAwait(false);
        await cache.StoreAsync(key2, vector2, tokenCount: null, CancellationToken.None).ConfigureAwait(false);
        var result1 = await cache.TryGetAsync(key1, CancellationToken.None).ConfigureAwait(false);
        var result2 = await cache.TryGetAsync(key2, CancellationToken.None).ConfigureAwait(false);

        // Assert: no cross-contamination between entries.
        Assert.NotNull(result1);
        Assert.NotNull(result2);
        Assert.Equal(vector1, result1.Vector);
        Assert.Equal(vector2, result2.Vector);
    }

    [Fact]
    public async Task ReadOnlyMode_TryGetAsync_ShouldWork()
    {
        // Arrange - First write with read-write mode
        using (var writeCache = new SqliteEmbeddingCache(this._tempDbPath, CacheModes.ReadWrite, this._loggerMock.Object))
        {
            var key = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, "test text");
            var vector = new float[] { 0.1f, 0.2f, 0.3f };
            await writeCache.StoreAsync(key, vector, tokenCount: null, CancellationToken.None).ConfigureAwait(false);
        }

        // Act - Then read with read-only mode
        using var readCache = new SqliteEmbeddingCache(this._tempDbPath, CacheModes.ReadOnly, this._loggerMock.Object);
        var readKey = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, "test text");
        var result = await readCache.TryGetAsync(readKey, CancellationToken.None).ConfigureAwait(false);

        // Assert
        Assert.NotNull(result);
        Assert.Equal(new float[] { 0.1f, 0.2f, 0.3f }, result.Vector);
    }

    [Fact]
    public async Task ReadOnlyMode_StoreAsync_ShouldNotWrite()
    {
        // Arrange
        using var cache = new SqliteEmbeddingCache(this._tempDbPath, CacheModes.ReadOnly, this._loggerMock.Object);
        var key = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, "test text");
        var vector = new float[] { 0.1f, 0.2f, 0.3f };

        // Act - Store should be ignored in read-only mode
        await cache.StoreAsync(key, vector, tokenCount: null, CancellationToken.None).ConfigureAwait(false);
        var result = await cache.TryGetAsync(key, CancellationToken.None).ConfigureAwait(false);

        // Assert
        Assert.Null(result);
    }

    [Fact]
    public async Task WriteOnlyMode_StoreAsync_ShouldWork()
    {
        // Arrange
        using var cache = new SqliteEmbeddingCache(this._tempDbPath, CacheModes.WriteOnly, this._loggerMock.Object);
        var key = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, "test text");
        var vector = new float[] { 0.1f, 0.2f, 0.3f };

        // Act
        await cache.StoreAsync(key, vector, tokenCount: null, CancellationToken.None).ConfigureAwait(false);

        // Assert - verify by reading with read-write cache
        using var readCache = new SqliteEmbeddingCache(this._tempDbPath, CacheModes.ReadWrite, this._loggerMock.Object);
        var result = await readCache.TryGetAsync(key, CancellationToken.None).ConfigureAwait(false);
        Assert.NotNull(result);
        Assert.Equal(vector, result.Vector);
    }

    [Fact]
    public async Task WriteOnlyMode_TryGetAsync_ShouldReturnNull()
    {
        // Arrange - First store with read-write
        using (var writeCache = new SqliteEmbeddingCache(this._tempDbPath, CacheModes.ReadWrite, this._loggerMock.Object))
        {
            var key = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, "test text");
            var vector = new float[] { 0.1f, 0.2f, 0.3f };
            await writeCache.StoreAsync(key, vector, tokenCount: null, CancellationToken.None).ConfigureAwait(false);
        }

        // Act - Read with write-only mode should return null
        using var cache = new SqliteEmbeddingCache(this._tempDbPath, CacheModes.WriteOnly, this._loggerMock.Object);
        var readKey = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, "test text");
        var result = await cache.TryGetAsync(readKey, CancellationToken.None).ConfigureAwait(false);

        // Assert
        Assert.Null(result);
    }

    [Fact]
    public void Mode_ShouldReflectConstructorValue()
    {
        // Arrange & Act: one cache per mode, each on its own database file.
        using var readWriteCache = new SqliteEmbeddingCache(this._tempDbPath, CacheModes.ReadWrite, this._loggerMock.Object);
        var path2 = Path.Combine(Path.GetTempPath(), $"test-cache-{Guid.NewGuid()}.db");
        using var readOnlyCache = new SqliteEmbeddingCache(path2, CacheModes.ReadOnly, this._loggerMock.Object);
        var path3 = Path.Combine(Path.GetTempPath(), $"test-cache-{Guid.NewGuid()}.db");
        using var writeOnlyCache = new SqliteEmbeddingCache(path3, CacheModes.WriteOnly, this._loggerMock.Object);

        try
        {
            // Assert
            Assert.Equal(CacheModes.ReadWrite, readWriteCache.Mode);
            Assert.Equal(CacheModes.ReadOnly, readOnlyCache.Mode);
            Assert.Equal(CacheModes.WriteOnly, writeOnlyCache.Mode);
        }
        finally
        {
            // Cleanup of the extra database files (primary path is handled by Dispose).
            if (File.Exists(path2))
            {
                File.Delete(path2);
            }

            if (File.Exists(path3))
            {
                File.Delete(path3);
            }
        }
    }

    [Fact]
    public async Task VectorBlobStorage_ShouldPreserveFloatPrecision()
    {
        // Arrange
        using var cache = new SqliteEmbeddingCache(this._tempDbPath, CacheModes.ReadWrite, this._loggerMock.Object);
        var key = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, "test text");

        // Include edge cases for float precision
        var vector = new float[]
        {
            0.123456789f,
            -0.987654321f,
            float.Epsilon,
            float.MaxValue / 2,
            float.MinValue / 2,
            0.0f,
            1.0f,
            -1.0f
        };

        // Act
        await cache.StoreAsync(key, vector, tokenCount: null, CancellationToken.None).ConfigureAwait(false);
        var result = await cache.TryGetAsync(key, CancellationToken.None).ConfigureAwait(false);

        // Assert: bit-exact equality expected from binary BLOB storage.
        Assert.NotNull(result);
        for (int i = 0; i < vector.Length; i++)
        {
            Assert.Equal(vector[i], result.Vector[i]);
        }
    }

    [Fact]
    public async Task CacheDoesNotStoreInputText()
    {
        // Arrange: privacy guarantee — only the SHA256-derived key is persisted.
        using var cache = new SqliteEmbeddingCache(this._tempDbPath, CacheModes.ReadWrite, this._loggerMock.Object);
        const string sensitiveText = "This is sensitive PII data that should not be stored";
        var key = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, sensitiveText);
        var vector = new float[] { 0.1f, 0.2f, 0.3f };

        // Act
        await cache.StoreAsync(key, vector, tokenCount: null, CancellationToken.None).ConfigureAwait(false);

        // Assert - Read database file and verify text is not present
        var dbContent = await File.ReadAllBytesAsync(this._tempDbPath).ConfigureAwait(false);
        var dbString = System.Text.Encoding.UTF8.GetString(dbContent);
        Assert.DoesNotContain(sensitiveText, dbString);
    }

    [Fact]
    public async Task CachePersistence_ShouldSurviveReopen()
    {
        // Arrange
        var key = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, "test text");
        var vector = new float[] { 0.1f, 0.2f, 0.3f };

        // Store and close
        using (var cache = new SqliteEmbeddingCache(this._tempDbPath, CacheModes.ReadWrite, this._loggerMock.Object))
        {
            await cache.StoreAsync(key, vector, tokenCount: null, CancellationToken.None).ConfigureAwait(false);
        }

        // Act - Reopen and read
        using var reopenedCache = new SqliteEmbeddingCache(this._tempDbPath, CacheModes.ReadWrite, this._loggerMock.Object);
        var result = await reopenedCache.TryGetAsync(key, CancellationToken.None).ConfigureAwait(false);

        // Assert
        Assert.NotNull(result);
        Assert.Equal(vector, result.Vector);
    }

    [Fact]
    public async Task StoreAsync_WithCancellationToken_ShouldRespectCancellation()
    {
        // Arrange
        using var cache = new SqliteEmbeddingCache(this._tempDbPath, CacheModes.ReadWrite, this._loggerMock.Object);
        var key = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, "test text");
        var vector = new float[] { 0.1f, 0.2f, 0.3f };
        var cts = new CancellationTokenSource();
        cts.Cancel();

        // Act & Assert
        // NOTE(review): the exception type argument was lost in the mangled diff;
        // OperationCanceledException is assumed here — confirm against the
        // original (TaskCanceledException is the other plausible candidate).
        await Assert.ThrowsAsync<OperationCanceledException>(
            () => cache.StoreAsync(key, vector, tokenCount: null, cts.Token)).ConfigureAwait(false);
    }

    [Fact]
    public async Task TryGetAsync_WithCancellationToken_ShouldRespectCancellation()
    {
        // Arrange
        using var cache = new SqliteEmbeddingCache(this._tempDbPath, CacheModes.ReadWrite, this._loggerMock.Object);
        var key = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, "test text");
        var cts = new CancellationTokenSource();
        cts.Cancel();

        // Act & Assert
        // NOTE(review): exception type assumed, as above — confirm.
        await Assert.ThrowsAsync<OperationCanceledException>(
            () => cache.TryGetAsync(key, cts.Token)).ConfigureAwait(false);
    }
}
a/tests/Core.Tests/Embeddings/CachedEmbeddingGeneratorTests.cs b/tests/Core.Tests/Embeddings/CachedEmbeddingGeneratorTests.cs new file mode 100644 index 000000000..cb74ffb30 --- /dev/null +++ b/tests/Core.Tests/Embeddings/CachedEmbeddingGeneratorTests.cs @@ -0,0 +1,494 @@ +// Copyright (c) Microsoft. All rights reserved. +using KernelMemory.Core.Config.Enums; +using KernelMemory.Core.Embeddings; +using KernelMemory.Core.Embeddings.Cache; +using Microsoft.Extensions.Logging; +using Moq; + +namespace KernelMemory.Core.Tests.Embeddings; + +/// +/// Tests for CachedEmbeddingGenerator decorator to verify cache hit/miss logic, +/// batch processing with mixed cache hits, and mode-based behavior. +/// +public sealed class CachedEmbeddingGeneratorTests +{ + private readonly Mock _innerGeneratorMock; + private readonly Mock _cacheMock; + private readonly Mock> _loggerMock; + + public CachedEmbeddingGeneratorTests() + { + this._innerGeneratorMock = new Mock(); + this._cacheMock = new Mock(); + this._loggerMock = new Mock>(); + + // Setup default inner generator properties + this._innerGeneratorMock.Setup(x => x.ProviderType).Returns(EmbeddingsTypes.OpenAI); + this._innerGeneratorMock.Setup(x => x.ModelName).Returns("text-embedding-ada-002"); + this._innerGeneratorMock.Setup(x => x.VectorDimensions).Returns(1536); + this._innerGeneratorMock.Setup(x => x.IsNormalized).Returns(true); + } + + [Fact] + public void Properties_ShouldDelegateToInnerGenerator() + { + // Arrange + this._cacheMock.Setup(x => x.Mode).Returns(CacheModes.ReadWrite); + var cachedGenerator = new CachedEmbeddingGenerator( + this._innerGeneratorMock.Object, + this._cacheMock.Object, + this._loggerMock.Object); + + // Act & Assert + Assert.Equal(EmbeddingsTypes.OpenAI, cachedGenerator.ProviderType); + Assert.Equal("text-embedding-ada-002", cachedGenerator.ModelName); + Assert.Equal(1536, cachedGenerator.VectorDimensions); + Assert.True(cachedGenerator.IsNormalized); + } + + [Fact] + public async Task 
GenerateAsync_Single_WithCacheHit_ShouldReturnCachedVector() + { + // Arrange + var cachedVector = new float[] { 0.1f, 0.2f, 0.3f }; + var cachedEmbedding = new CachedEmbedding + { + Vector = cachedVector, + Timestamp = DateTimeOffset.UtcNow + }; + + this._cacheMock.Setup(x => x.Mode).Returns(CacheModes.ReadWrite); + this._cacheMock + .Setup(x => x.TryGetAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(cachedEmbedding); + + var cachedGenerator = new CachedEmbeddingGenerator( + this._innerGeneratorMock.Object, + this._cacheMock.Object, + this._loggerMock.Object); + + // Act + var result = await cachedGenerator.GenerateAsync("test text", CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.Equal(cachedVector, result); + this._innerGeneratorMock.Verify( + x => x.GenerateAsync(It.IsAny(), It.IsAny()), + Times.Never); + } + + [Fact] + public async Task GenerateAsync_Single_WithCacheMiss_ShouldCallInnerGenerator() + { + // Arrange + var generatedVector = new float[] { 0.4f, 0.5f, 0.6f }; + + this._cacheMock.Setup(x => x.Mode).Returns(CacheModes.ReadWrite); + this._cacheMock + .Setup(x => x.TryGetAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync((CachedEmbedding?)null); + + this._innerGeneratorMock + .Setup(x => x.GenerateAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(generatedVector); + + var cachedGenerator = new CachedEmbeddingGenerator( + this._innerGeneratorMock.Object, + this._cacheMock.Object, + this._loggerMock.Object); + + // Act + var result = await cachedGenerator.GenerateAsync("test text", CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.Equal(generatedVector, result); + this._innerGeneratorMock.Verify( + x => x.GenerateAsync("test text", It.IsAny()), + Times.Once); + } + + [Fact] + public async Task GenerateAsync_Single_WithCacheMiss_ShouldStoreInCache() + { + // Arrange + var generatedVector = new float[] { 0.4f, 0.5f, 0.6f }; + + this._cacheMock.Setup(x => x.Mode).Returns(CacheModes.ReadWrite); + this._cacheMock + .Setup(x 
=> x.TryGetAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync((CachedEmbedding?)null); + + this._innerGeneratorMock + .Setup(x => x.GenerateAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(generatedVector); + + var cachedGenerator = new CachedEmbeddingGenerator( + this._innerGeneratorMock.Object, + this._cacheMock.Object, + this._loggerMock.Object); + + // Act + await cachedGenerator.GenerateAsync("test text", CancellationToken.None).ConfigureAwait(false); + + // Assert + this._cacheMock.Verify( + x => x.StoreAsync(It.IsAny(), generatedVector, null, It.IsAny()), + Times.Once); + } + + [Fact] + public async Task GenerateAsync_Single_WithWriteOnlyCache_ShouldSkipCacheRead() + { + // Arrange + var generatedVector = new float[] { 0.4f, 0.5f, 0.6f }; + + this._cacheMock.Setup(x => x.Mode).Returns(CacheModes.WriteOnly); + + this._innerGeneratorMock + .Setup(x => x.GenerateAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(generatedVector); + + var cachedGenerator = new CachedEmbeddingGenerator( + this._innerGeneratorMock.Object, + this._cacheMock.Object, + this._loggerMock.Object); + + // Act + var result = await cachedGenerator.GenerateAsync("test text", CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.Equal(generatedVector, result); + this._cacheMock.Verify( + x => x.TryGetAsync(It.IsAny(), It.IsAny()), + Times.Never); + this._cacheMock.Verify( + x => x.StoreAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny()), + Times.Once); + } + + [Fact] + public async Task GenerateAsync_Single_WithReadOnlyCache_ShouldSkipCacheWrite() + { + // Arrange + var generatedVector = new float[] { 0.4f, 0.5f, 0.6f }; + + this._cacheMock.Setup(x => x.Mode).Returns(CacheModes.ReadOnly); + this._cacheMock + .Setup(x => x.TryGetAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync((CachedEmbedding?)null); + + this._innerGeneratorMock + .Setup(x => x.GenerateAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(generatedVector); + + var cachedGenerator = new CachedEmbeddingGenerator( + 
this._innerGeneratorMock.Object, + this._cacheMock.Object, + this._loggerMock.Object); + + // Act + var result = await cachedGenerator.GenerateAsync("test text", CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.Equal(generatedVector, result); + this._cacheMock.Verify( + x => x.StoreAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny()), + Times.Never); + } + + [Fact] + public async Task GenerateAsync_Batch_AllCacheHits_ShouldNotCallInnerGenerator() + { + // Arrange + var texts = new[] { "text1", "text2", "text3" }; + var cachedVectors = new Dictionary + { + ["text1"] = new[] { 0.1f, 0.2f }, + ["text2"] = new[] { 0.3f, 0.4f }, + ["text3"] = new[] { 0.5f, 0.6f } + }; + + this._cacheMock.Setup(x => x.Mode).Returns(CacheModes.ReadWrite); + this._cacheMock + .Setup(x => x.TryGetAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync((EmbeddingCacheKey key, CancellationToken ct) => + { + // Find the matching text by checking the hash + foreach (var kvp in cachedVectors) + { + var testKey = EmbeddingCacheKey.Create("OpenAI", "text-embedding-ada-002", 1536, true, kvp.Key); + if (testKey.TextHash == key.TextHash) + { + return new CachedEmbedding { Vector = kvp.Value, Timestamp = DateTimeOffset.UtcNow }; + } + } + + return null; + }); + + var cachedGenerator = new CachedEmbeddingGenerator( + this._innerGeneratorMock.Object, + this._cacheMock.Object, + this._loggerMock.Object); + + // Act + var results = await cachedGenerator.GenerateAsync(texts, CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.Equal(3, results.Length); + this._innerGeneratorMock.Verify( + x => x.GenerateAsync(It.IsAny>(), It.IsAny()), + Times.Never); + } + + [Fact] + public async Task GenerateAsync_Batch_AllCacheMisses_ShouldCallInnerGeneratorWithAllTexts() + { + // Arrange + var texts = new[] { "text1", "text2", "text3" }; + var generatedVectors = new[] + { + new[] { 0.1f, 0.2f }, + new[] { 0.3f, 0.4f }, + new[] { 0.5f, 0.6f } + }; + + this._cacheMock.Setup(x => 
x.Mode).Returns(CacheModes.ReadWrite); + this._cacheMock + .Setup(x => x.TryGetAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync((CachedEmbedding?)null); + + this._innerGeneratorMock + .Setup(x => x.GenerateAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync(generatedVectors); + + var cachedGenerator = new CachedEmbeddingGenerator( + this._innerGeneratorMock.Object, + this._cacheMock.Object, + this._loggerMock.Object); + + // Act + var results = await cachedGenerator.GenerateAsync(texts, CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.Equal(3, results.Length); + this._innerGeneratorMock.Verify( + x => x.GenerateAsync(It.Is>(t => t.Count() == 3), It.IsAny()), + Times.Once); + } + + [Fact] + public async Task GenerateAsync_Batch_MixedHitsAndMisses_ShouldOnlyGenerateMisses() + { + // Arrange + var texts = new[] { "cached", "not-cached-1", "not-cached-2" }; + var cachedVector = new[] { 0.1f, 0.2f }; + var generatedVectors = new[] + { + new[] { 0.3f, 0.4f }, + new[] { 0.5f, 0.6f } + }; + + var cachedKey = EmbeddingCacheKey.Create("OpenAI", "text-embedding-ada-002", 1536, true, "cached"); + + this._cacheMock.Setup(x => x.Mode).Returns(CacheModes.ReadWrite); + this._cacheMock + .Setup(x => x.TryGetAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync((EmbeddingCacheKey key, CancellationToken ct) => + { + if (key.TextHash == cachedKey.TextHash) + { + return new CachedEmbedding { Vector = cachedVector, Timestamp = DateTimeOffset.UtcNow }; + } + + return null; + }); + + this._innerGeneratorMock + .Setup(x => x.GenerateAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync(generatedVectors); + + var cachedGenerator = new CachedEmbeddingGenerator( + this._innerGeneratorMock.Object, + this._cacheMock.Object, + this._loggerMock.Object); + + // Act + var results = await cachedGenerator.GenerateAsync(texts, CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.Equal(3, results.Length); + // First result should be cached + Assert.Equal(cachedVector, results[0]); + 
// Other results should be generated + Assert.Equal(generatedVectors[0], results[1]); + Assert.Equal(generatedVectors[1], results[2]); + + // Verify only non-cached texts were sent to generator + this._innerGeneratorMock.Verify( + x => x.GenerateAsync(It.Is>(t => t.Count() == 2), It.IsAny()), + Times.Once); + } + + [Fact] + public async Task GenerateAsync_Batch_ShouldStoreGeneratedInCache() + { + // Arrange + var texts = new[] { "text1", "text2" }; + var generatedVectors = new[] + { + new[] { 0.1f, 0.2f }, + new[] { 0.3f, 0.4f } + }; + + this._cacheMock.Setup(x => x.Mode).Returns(CacheModes.ReadWrite); + this._cacheMock + .Setup(x => x.TryGetAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync((CachedEmbedding?)null); + + this._innerGeneratorMock + .Setup(x => x.GenerateAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync(generatedVectors); + + var cachedGenerator = new CachedEmbeddingGenerator( + this._innerGeneratorMock.Object, + this._cacheMock.Object, + this._loggerMock.Object); + + // Act + await cachedGenerator.GenerateAsync(texts, CancellationToken.None).ConfigureAwait(false); + + // Assert - Both generated vectors should be stored + this._cacheMock.Verify( + x => x.StoreAsync(It.IsAny(), It.IsAny(), null, It.IsAny()), + Times.Exactly(2)); + } + + [Fact] + public async Task GenerateAsync_Batch_EmptyInput_ShouldReturnEmptyArray() + { + // Arrange + this._cacheMock.Setup(x => x.Mode).Returns(CacheModes.ReadWrite); + + var cachedGenerator = new CachedEmbeddingGenerator( + this._innerGeneratorMock.Object, + this._cacheMock.Object, + this._loggerMock.Object); + + // Act + var results = await cachedGenerator.GenerateAsync(Array.Empty(), CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.Empty(results); + this._innerGeneratorMock.Verify( + x => x.GenerateAsync(It.IsAny>(), It.IsAny()), + Times.Never); + } + + [Fact] + public async Task GenerateAsync_Batch_ShouldPreserveOrder() + { + // Arrange - Specifically test that results maintain input order + var 
texts = new[] { "a", "b", "c", "d" }; + + // Make only "b" and "d" cached + var cachedB = EmbeddingCacheKey.Create("OpenAI", "text-embedding-ada-002", 1536, true, "b"); + var cachedD = EmbeddingCacheKey.Create("OpenAI", "text-embedding-ada-002", 1536, true, "d"); + var vectorB = new[] { 2.0f }; + var vectorD = new[] { 4.0f }; + + var generatedVectors = new[] + { + new[] { 1.0f }, // for "a" + new[] { 3.0f } // for "c" + }; + + this._cacheMock.Setup(x => x.Mode).Returns(CacheModes.ReadWrite); + this._cacheMock + .Setup(x => x.TryGetAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync((EmbeddingCacheKey key, CancellationToken ct) => + { + if (key.TextHash == cachedB.TextHash) + { + return new CachedEmbedding { Vector = vectorB, Timestamp = DateTimeOffset.UtcNow }; + } + + if (key.TextHash == cachedD.TextHash) + { + return new CachedEmbedding { Vector = vectorD, Timestamp = DateTimeOffset.UtcNow }; + } + + return null; + }); + + this._innerGeneratorMock + .Setup(x => x.GenerateAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync(generatedVectors); + + var cachedGenerator = new CachedEmbeddingGenerator( + this._innerGeneratorMock.Object, + this._cacheMock.Object, + this._loggerMock.Object); + + // Act + var results = await cachedGenerator.GenerateAsync(texts, CancellationToken.None).ConfigureAwait(false); + + // Assert - Order must be preserved: a, b, c, d + Assert.Equal(4, results.Length); + Assert.Equal(new[] { 1.0f }, results[0]); // a - generated + Assert.Equal(new[] { 2.0f }, results[1]); // b - cached + Assert.Equal(new[] { 3.0f }, results[2]); // c - generated + Assert.Equal(new[] { 4.0f }, results[3]); // d - cached + } + + [Fact] + public async Task GenerateAsync_WithCancellation_ShouldPropagate() + { + // Arrange + this._cacheMock.Setup(x => x.Mode).Returns(CacheModes.ReadWrite); + this._cacheMock + .Setup(x => x.TryGetAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync((CachedEmbedding?)null); + + this._innerGeneratorMock + .Setup(x => x.GenerateAsync(It.IsAny(), 
It.IsAny())) + .ThrowsAsync(new OperationCanceledException()); + + var cachedGenerator = new CachedEmbeddingGenerator( + this._innerGeneratorMock.Object, + this._cacheMock.Object, + this._loggerMock.Object); + + var cts = new CancellationTokenSource(); + cts.Cancel(); + + // Act & Assert + await Assert.ThrowsAsync( + () => cachedGenerator.GenerateAsync("test", cts.Token)).ConfigureAwait(false); + } + + [Fact] + public void Constructor_WithNullInnerGenerator_ShouldThrow() + { + // Arrange & Act & Assert + Assert.Throws(() => + new CachedEmbeddingGenerator(null!, this._cacheMock.Object, this._loggerMock.Object)); + } + + [Fact] + public void Constructor_WithNullCache_ShouldThrow() + { + // Arrange & Act & Assert + Assert.Throws(() => + new CachedEmbeddingGenerator(this._innerGeneratorMock.Object, null!, this._loggerMock.Object)); + } + + [Fact] + public void Constructor_WithNullLogger_ShouldThrow() + { + // Arrange & Act & Assert + Assert.Throws(() => + new CachedEmbeddingGenerator(this._innerGeneratorMock.Object, this._cacheMock.Object, null!)); + } +} diff --git a/tests/Core.Tests/Embeddings/CachedEmbeddingTests.cs b/tests/Core.Tests/Embeddings/CachedEmbeddingTests.cs new file mode 100644 index 000000000..5c68e34a6 --- /dev/null +++ b/tests/Core.Tests/Embeddings/CachedEmbeddingTests.cs @@ -0,0 +1,138 @@ +// Copyright (c) Microsoft. All rights reserved. +using KernelMemory.Core.Embeddings.Cache; + +namespace KernelMemory.Core.Tests.Embeddings; + +/// +/// Tests for CachedEmbedding model to verify proper construction and immutability. +/// CachedEmbedding stores the embedding vector, optional token count, and timestamp. 
+/// +public sealed class CachedEmbeddingTests +{ + [Fact] + public void CachedEmbedding_WithRequiredProperties_ShouldBeCreated() + { + // Arrange + var vector = new float[] { 0.1f, 0.2f, 0.3f }; + var timestamp = DateTimeOffset.UtcNow; + + // Act + var cached = new CachedEmbedding + { + Vector = vector, + Timestamp = timestamp + }; + + // Assert + Assert.Equal(vector, cached.Vector); + Assert.Equal(timestamp, cached.Timestamp); + Assert.Null(cached.TokenCount); + } + + [Fact] + public void CachedEmbedding_WithTokenCount_ShouldStoreValue() + { + // Arrange + var vector = new float[] { 0.1f, 0.2f, 0.3f }; + var timestamp = DateTimeOffset.UtcNow; + const int tokenCount = 42; + + // Act + var cached = new CachedEmbedding + { + Vector = vector, + Timestamp = timestamp, + TokenCount = tokenCount + }; + + // Assert + Assert.Equal(tokenCount, cached.TokenCount); + } + + [Fact] + public void CachedEmbedding_WithNullTokenCount_ShouldBeNull() + { + // Arrange + var vector = new float[] { 0.1f, 0.2f, 0.3f }; + var timestamp = DateTimeOffset.UtcNow; + + // Act + var cached = new CachedEmbedding + { + Vector = vector, + Timestamp = timestamp, + TokenCount = null + }; + + // Assert + Assert.Null(cached.TokenCount); + } + + [Fact] + public void CachedEmbedding_VectorShouldPreserveFloatPrecision() + { + // Arrange + var vector = new float[] { 0.123456789f, -0.987654321f, float.MaxValue, float.MinValue }; + var timestamp = DateTimeOffset.UtcNow; + + // Act + var cached = new CachedEmbedding + { + Vector = vector, + Timestamp = timestamp + }; + + // Assert + Assert.Equal(0.123456789f, cached.Vector[0]); + Assert.Equal(-0.987654321f, cached.Vector[1]); + Assert.Equal(float.MaxValue, cached.Vector[2]); + Assert.Equal(float.MinValue, cached.Vector[3]); + } + + [Fact] + public void CachedEmbedding_WithLargeVector_ShouldPreserveAllDimensions() + { + // Arrange - simulate 1536-dimension embedding (OpenAI ada-002) + var vector = new float[1536]; + for (int i = 0; i < vector.Length; i++) + 
{ + vector[i] = (float)i / 1536; + } + + var timestamp = DateTimeOffset.UtcNow; + + // Act + var cached = new CachedEmbedding + { + Vector = vector, + Timestamp = timestamp + }; + + // Assert + Assert.Equal(1536, cached.Vector.Length); + Assert.Equal(0.0f, cached.Vector[0]); + Assert.Equal(1535f / 1536, cached.Vector[1535], precision: 6); + } + + [Fact] + public void CachedEmbedding_TimestampShouldBeDateTimeOffset() + { + // Arrange + var vector = new float[] { 0.1f }; + var timestamp = new DateTimeOffset(2025, 12, 1, 10, 30, 0, TimeSpan.FromHours(-5)); + + // Act + var cached = new CachedEmbedding + { + Vector = vector, + Timestamp = timestamp + }; + + // Assert + Assert.Equal(timestamp.Year, cached.Timestamp.Year); + Assert.Equal(timestamp.Month, cached.Timestamp.Month); + Assert.Equal(timestamp.Day, cached.Timestamp.Day); + Assert.Equal(timestamp.Hour, cached.Timestamp.Hour); + Assert.Equal(timestamp.Offset, cached.Timestamp.Offset); + } +} diff --git a/tests/Core.Tests/Embeddings/EmbeddingCacheKeyTests.cs b/tests/Core.Tests/Embeddings/EmbeddingCacheKeyTests.cs new file mode 100644 index 000000000..1f169a2a6 --- /dev/null +++ b/tests/Core.Tests/Embeddings/EmbeddingCacheKeyTests.cs @@ -0,0 +1,230 @@ +// Copyright (c) Microsoft. All rights reserved. +using KernelMemory.Core.Embeddings.Cache; + +namespace KernelMemory.Core.Tests.Embeddings; + +/// +/// Tests for EmbeddingCacheKey to verify correct cache key generation and SHA256 hashing. +/// Cache keys must uniquely identify embeddings by provider, model, dimensions, normalization, and text content. 
+/// +public sealed class EmbeddingCacheKeyTests +{ + [Fact] + public void Create_WithValidParameters_ShouldGenerateCorrectCacheKey() + { + // Arrange & Act + var key = EmbeddingCacheKey.Create( + provider: "OpenAI", + model: "text-embedding-ada-002", + vectorDimensions: 1536, + isNormalized: true, + text: "Hello world"); + + // Assert + Assert.Equal("OpenAI", key.Provider); + Assert.Equal("text-embedding-ada-002", key.Model); + Assert.Equal(1536, key.VectorDimensions); + Assert.True(key.IsNormalized); + Assert.Equal(11, key.TextLength); + Assert.NotNull(key.TextHash); + Assert.NotEmpty(key.TextHash); + } + + [Fact] + public void Create_WithSameText_ShouldGenerateSameHash() + { + // Arrange + const string text = "Test embedding text"; + + // Act + var key1 = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, text); + var key2 = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, text); + + // Assert + Assert.Equal(key1.TextHash, key2.TextHash); + } + + [Fact] + public void Create_WithDifferentText_ShouldGenerateDifferentHash() + { + // Arrange + const string text1 = "Text one"; + const string text2 = "Text two"; + + // Act + var key1 = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, text1); + var key2 = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, text2); + + // Assert + Assert.NotEqual(key1.TextHash, key2.TextHash); + } + + [Fact] + public void TextHash_ShouldBeSha256Format() + { + // Arrange & Act + var key = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, "test"); + + // Assert - SHA256 produces 64 character hex string + Assert.Equal(64, key.TextHash.Length); + Assert.Matches("^[a-fA-F0-9]+$", key.TextHash); + } + + [Fact] + public void ToCompositeKey_ShouldIncludeAllComponents() + { + // Arrange + var key = EmbeddingCacheKey.Create( + provider: "Ollama", + model: "qwen3-embedding", + vectorDimensions: 1024, + isNormalized: false, + text: "sample"); + + // Act + var compositeKey = key.ToCompositeKey(); + + // Assert + 
Assert.Contains("Ollama", compositeKey); + Assert.Contains("qwen3-embedding", compositeKey); + Assert.Contains("1024", compositeKey); + Assert.Contains("False", compositeKey); + Assert.Contains("6", compositeKey); // Text length + Assert.Contains(key.TextHash, compositeKey); + } + + [Fact] + public void ToCompositeKey_WithSameParameters_ShouldBeIdentical() + { + // Arrange + var key1 = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, "text"); + var key2 = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, "text"); + + // Act + var composite1 = key1.ToCompositeKey(); + var composite2 = key2.ToCompositeKey(); + + // Assert + Assert.Equal(composite1, composite2); + } + + [Fact] + public void ToCompositeKey_WithDifferentProvider_ShouldBeDifferent() + { + // Arrange + var key1 = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, "text"); + var key2 = EmbeddingCacheKey.Create("Ollama", "model", 1536, true, "text"); + + // Act & Assert + Assert.NotEqual(key1.ToCompositeKey(), key2.ToCompositeKey()); + } + + [Fact] + public void ToCompositeKey_WithDifferentModel_ShouldBeDifferent() + { + // Arrange + var key1 = EmbeddingCacheKey.Create("OpenAI", "model-a", 1536, true, "text"); + var key2 = EmbeddingCacheKey.Create("OpenAI", "model-b", 1536, true, "text"); + + // Act & Assert + Assert.NotEqual(key1.ToCompositeKey(), key2.ToCompositeKey()); + } + + [Fact] + public void ToCompositeKey_WithDifferentDimensions_ShouldBeDifferent() + { + // Arrange + var key1 = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, "text"); + var key2 = EmbeddingCacheKey.Create("OpenAI", "model", 768, true, "text"); + + // Act & Assert + Assert.NotEqual(key1.ToCompositeKey(), key2.ToCompositeKey()); + } + + [Fact] + public void ToCompositeKey_WithDifferentNormalization_ShouldBeDifferent() + { + // Arrange + var key1 = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, "text"); + var key2 = EmbeddingCacheKey.Create("OpenAI", "model", 1536, false, "text"); + + // Act & Assert 
+ Assert.NotEqual(key1.ToCompositeKey(), key2.ToCompositeKey()); + } + + [Fact] + public void Create_WithEmptyText_ShouldGenerateValidHash() + { + // Arrange & Act + var key = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, ""); + + // Assert + Assert.Equal(0, key.TextLength); + Assert.NotNull(key.TextHash); + Assert.Equal(64, key.TextHash.Length); + } + + [Fact] + public void Create_WithUnicodeText_ShouldGenerateValidHash() + { + // Arrange & Act + var key = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, "Hello, "); + + // Assert + Assert.NotNull(key.TextHash); + Assert.Equal(64, key.TextHash.Length); + } + + [Fact] + public void Create_WithLongText_ShouldCaptureCorrectLength() + { + // Arrange + var longText = new string('a', 10000); + + // Act + var key = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, longText); + + // Assert + Assert.Equal(10000, key.TextLength); + } + + [Fact] + public void Create_WithNullText_ShouldThrowArgumentNullException() + { + // Arrange & Act & Assert + Assert.Throws(() => + EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, null!)); + } + + [Fact] + public void Create_WithNullProvider_ShouldThrowArgumentNullException() + { + // Arrange & Act & Assert + Assert.Throws(() => + EmbeddingCacheKey.Create(null!, "model", 1536, true, "text")); + } + + [Fact] + public void Create_WithNullModel_ShouldThrowArgumentNullException() + { + // Arrange & Act & Assert + Assert.Throws(() => + EmbeddingCacheKey.Create("OpenAI", null!, 1536, true, "text")); + } + + [Fact] + public void Create_WithZeroDimensions_ShouldThrowArgumentOutOfRangeException() + { + // Arrange & Act & Assert + Assert.Throws(() => + EmbeddingCacheKey.Create("OpenAI", "model", 0, true, "text")); + } + + [Fact] + public void Create_WithNegativeDimensions_ShouldThrowArgumentOutOfRangeException() + { + // Arrange & Act & Assert + Assert.Throws(() => + EmbeddingCacheKey.Create("OpenAI", "model", -1, true, "text")); + } +} diff --git 
a/tests/Core.Tests/Embeddings/EmbeddingConstantsTests.cs b/tests/Core.Tests/Embeddings/EmbeddingConstantsTests.cs new file mode 100644 index 000000000..eeca003f2 --- /dev/null +++ b/tests/Core.Tests/Embeddings/EmbeddingConstantsTests.cs @@ -0,0 +1,103 @@ +// Copyright (c) Microsoft. All rights reserved. +using KernelMemory.Core.Embeddings; + +namespace KernelMemory.Core.Tests.Embeddings; + +/// +/// Tests for EmbeddingConstants to verify known model dimensions are correct. +/// These values are critical for cache key generation and validation. +/// +public sealed class EmbeddingConstantsTests +{ + [Theory] + [InlineData("qwen3-embedding", 1024)] + [InlineData("nomic-embed-text", 768)] + [InlineData("text-embedding-ada-002", 1536)] + [InlineData("text-embedding-3-small", 1536)] + [InlineData("text-embedding-3-large", 3072)] + [InlineData("sentence-transformers/all-MiniLM-L6-v2", 384)] + [InlineData("BAAI/bge-base-en-v1.5", 768)] + public void KnownModelDimensions_ShouldContainExpectedValues(string modelName, int expectedDimensions) + { + // Act + var exists = EmbeddingConstants.KnownModelDimensions.TryGetValue(modelName, out var dimensions); + + // Assert + Assert.True(exists, $"Model '{modelName}' should be in KnownModelDimensions"); + Assert.Equal(expectedDimensions, dimensions); + } + + [Fact] + public void KnownModelDimensions_ShouldNotBeEmpty() + { + // Assert + Assert.NotEmpty(EmbeddingConstants.KnownModelDimensions); + } + + [Fact] + public void KnownModelDimensions_AllValuesShouldBePositive() + { + // Assert + foreach (var kvp in EmbeddingConstants.KnownModelDimensions) + { + Assert.True(kvp.Value > 0, $"Model '{kvp.Key}' has invalid dimensions: {kvp.Value}"); + } + } + + [Fact] + public void TryGetDimensions_WithKnownModel_ShouldReturnTrue() + { + // Act + var result = EmbeddingConstants.TryGetDimensions("text-embedding-ada-002", out var dimensions); + + // Assert + Assert.True(result); + Assert.Equal(1536, dimensions); + } + + [Fact] + public void 
TryGetDimensions_WithUnknownModel_ShouldReturnFalse() + { + // Act + var result = EmbeddingConstants.TryGetDimensions("unknown-model", out var dimensions); + + // Assert + Assert.False(result); + Assert.Equal(0, dimensions); + } + + [Fact] + public void DefaultBatchSize_ShouldBe10() + { + // Assert + Assert.Equal(10, EmbeddingConstants.DefaultBatchSize); + } + + [Fact] + public void DefaultOllamaModel_ShouldBeQwen3Embedding() + { + // Assert + Assert.Equal("qwen3-embedding", EmbeddingConstants.DefaultOllamaModel); + } + + [Fact] + public void DefaultOllamaBaseUrl_ShouldBeLocalhost() + { + // Assert + Assert.Equal("http://localhost:11434", EmbeddingConstants.DefaultOllamaBaseUrl); + } + + [Fact] + public void DefaultHuggingFaceModel_ShouldBeAllMiniLM() + { + // Assert + Assert.Equal("sentence-transformers/all-MiniLM-L6-v2", EmbeddingConstants.DefaultHuggingFaceModel); + } + + [Fact] + public void DefaultHuggingFaceBaseUrl_ShouldBeInferenceApi() + { + // Assert + Assert.Equal("https://api-inference.huggingface.co", EmbeddingConstants.DefaultHuggingFaceBaseUrl); + } +} diff --git a/tests/Core.Tests/Embeddings/Providers/AzureOpenAIEmbeddingGeneratorTests.cs b/tests/Core.Tests/Embeddings/Providers/AzureOpenAIEmbeddingGeneratorTests.cs new file mode 100644 index 000000000..485844874 --- /dev/null +++ b/tests/Core.Tests/Embeddings/Providers/AzureOpenAIEmbeddingGeneratorTests.cs @@ -0,0 +1,308 @@ +// Copyright (c) Microsoft. All rights reserved. +using System.Net; +using System.Text.Json; +using KernelMemory.Core.Config.Enums; +using KernelMemory.Core.Embeddings.Providers; +using Microsoft.Extensions.Logging; +using Moq; +using Moq.Protected; + +namespace KernelMemory.Core.Tests.Embeddings.Providers; + +/// +/// Tests for AzureOpenAIEmbeddingGenerator to verify Azure endpoint communication, +/// API key authentication, and managed identity support. +/// Uses mocked HttpMessageHandler to avoid real Azure OpenAI calls in unit tests. 
+/// +public sealed class AzureOpenAIEmbeddingGeneratorTests +{ + private readonly Mock> _loggerMock; + private readonly Mock _httpHandlerMock; + + public AzureOpenAIEmbeddingGeneratorTests() + { + this._loggerMock = new Mock>(); + this._httpHandlerMock = new Mock(); + } + + [Fact] + public void Properties_ShouldReflectConfiguration() + { + // Arrange + var httpClient = new HttpClient(this._httpHandlerMock.Object); + var generator = new AzureOpenAIEmbeddingGenerator( + httpClient, + endpoint: "https://myservice.openai.azure.com", + deployment: "my-embedding", + model: "text-embedding-ada-002", + apiKey: "test-key", + vectorDimensions: 1536, + isNormalized: true, + this._loggerMock.Object); + + // Assert + Assert.Equal(EmbeddingsTypes.AzureOpenAI, generator.ProviderType); + Assert.Equal("text-embedding-ada-002", generator.ModelName); + Assert.Equal(1536, generator.VectorDimensions); + Assert.True(generator.IsNormalized); + } + + [Fact] + public async Task GenerateAsync_Single_ShouldCallAzureEndpoint() + { + // Arrange + var response = CreateAzureResponse(new[] { new[] { 0.1f, 0.2f, 0.3f } }); + var responseJson = JsonSerializer.Serialize(response); + + this._httpHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.Is(req => + req.Method == HttpMethod.Post && + req.RequestUri!.ToString().StartsWith("https://myservice.openai.azure.com/openai/deployments/my-embedding/embeddings")), + ItExpr.IsAny()) + .ReturnsAsync(new HttpResponseMessage + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent(responseJson) + }); + + var httpClient = new HttpClient(this._httpHandlerMock.Object); + var generator = new AzureOpenAIEmbeddingGenerator( + httpClient, "https://myservice.openai.azure.com", "my-embedding", "ada-002", "test-key", 1536, true, this._loggerMock.Object); + + // Act + var result = await generator.GenerateAsync("test text", CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.Equal(new[] { 0.1f, 0.2f, 0.3f }, result); + } + + 
[Fact] + public async Task GenerateAsync_WithApiKey_ShouldSendApiKeyHeader() + { + // Arrange + HttpRequestMessage? capturedRequest = null; + var response = CreateAzureResponse(new[] { new[] { 0.1f } }); + + this._httpHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .Callback((req, _) => capturedRequest = req) + .ReturnsAsync(new HttpResponseMessage + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent(JsonSerializer.Serialize(response)) + }); + + var httpClient = new HttpClient(this._httpHandlerMock.Object); + var generator = new AzureOpenAIEmbeddingGenerator( + httpClient, "https://myservice.openai.azure.com", "deployment", "model", "my-api-key", 1536, true, this._loggerMock.Object); + + // Act + await generator.GenerateAsync("test", CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.NotNull(capturedRequest); + Assert.True(capturedRequest!.Headers.Contains("api-key")); + Assert.Equal("my-api-key", capturedRequest.Headers.GetValues("api-key").First()); + } + + [Fact] + public async Task GenerateAsync_Batch_ShouldSendAllInputsInOneRequest() + { + // Arrange + var texts = new[] { "text1", "text2", "text3" }; + var response = CreateAzureResponse(new[] + { + new[] { 0.1f }, + new[] { 0.2f }, + new[] { 0.3f } + }); + + string? 
capturedContent = null; + this._httpHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .Returns(async (req, ct) => + { + // Capture content before it's disposed + capturedContent = await req.Content!.ReadAsStringAsync(ct).ConfigureAwait(false); + return new HttpResponseMessage + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent(JsonSerializer.Serialize(response)) + }; + }); + + var httpClient = new HttpClient(this._httpHandlerMock.Object); + var generator = new AzureOpenAIEmbeddingGenerator( + httpClient, "https://myservice.openai.azure.com", "deployment", "model", "key", 1536, true, this._loggerMock.Object); + + // Act + var results = await generator.GenerateAsync(texts, CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.Equal(3, results.Length); + var requestBody = JsonSerializer.Deserialize(capturedContent!); + Assert.Equal(3, requestBody!.Input.Length); + } + + [Fact] + public async Task GenerateAsync_ShouldIncludeApiVersion() + { + // Arrange + HttpRequestMessage? 
capturedRequest = null; + var response = CreateAzureResponse(new[] { new[] { 0.1f } }); + + this._httpHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .Callback((req, _) => capturedRequest = req) + .ReturnsAsync(new HttpResponseMessage + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent(JsonSerializer.Serialize(response)) + }); + + var httpClient = new HttpClient(this._httpHandlerMock.Object); + var generator = new AzureOpenAIEmbeddingGenerator( + httpClient, "https://myservice.openai.azure.com", "deployment", "model", "key", 1536, true, this._loggerMock.Object); + + // Act + await generator.GenerateAsync("test", CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.NotNull(capturedRequest); + Assert.Contains("api-version=", capturedRequest!.RequestUri!.Query); + } + + [Fact] + public async Task GenerateAsync_WithUnauthorizedError_ShouldThrowHttpRequestException() + { + // Arrange + this._httpHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync(new HttpResponseMessage + { + StatusCode = HttpStatusCode.Unauthorized, + Content = new StringContent("Access denied") + }); + + var httpClient = new HttpClient(this._httpHandlerMock.Object); + var generator = new AzureOpenAIEmbeddingGenerator( + httpClient, "https://myservice.openai.azure.com", "deployment", "model", "bad-key", 1536, true, this._loggerMock.Object); + + // Act & Assert + await Assert.ThrowsAsync( + () => generator.GenerateAsync("test", CancellationToken.None)).ConfigureAwait(false); + } + + [Fact] + public async Task GenerateAsync_WithCancellation_ShouldPropagate() + { + // Arrange + this._httpHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ThrowsAsync(new OperationCanceledException()); + + var httpClient = new HttpClient(this._httpHandlerMock.Object); + var generator = new AzureOpenAIEmbeddingGenerator( + httpClient, 
"https://myservice.openai.azure.com", "deployment", "model", "key", 1536, true, this._loggerMock.Object); + + var cts = new CancellationTokenSource(); + cts.Cancel(); + + // Act & Assert - TaskCanceledException inherits from OperationCanceledException + await Assert.ThrowsAnyAsync( + () => generator.GenerateAsync("test", cts.Token)).ConfigureAwait(false); + } + + [Fact] + public void Constructor_WithNullEndpoint_ShouldThrow() + { + // Assert + var httpClient = new HttpClient(); + Assert.Throws(() => + new AzureOpenAIEmbeddingGenerator(httpClient, null!, "deployment", "model", "key", 1536, true, this._loggerMock.Object)); + } + + [Fact] + public void Constructor_WithNullDeployment_ShouldThrow() + { + // Assert + var httpClient = new HttpClient(); + Assert.Throws(() => + new AzureOpenAIEmbeddingGenerator(httpClient, "https://endpoint", null!, "model", "key", 1536, true, this._loggerMock.Object)); + } + + [Fact] + public void Constructor_WithNullApiKey_ShouldThrow() + { + // Assert + var httpClient = new HttpClient(); + Assert.Throws(() => + new AzureOpenAIEmbeddingGenerator(httpClient, "https://endpoint", "deployment", "model", null!, 1536, true, this._loggerMock.Object)); + } + + private static AzureEmbeddingResponse CreateAzureResponse(float[][] embeddings) + { + return new AzureEmbeddingResponse + { + Data = embeddings.Select((e, i) => new EmbeddingData { Index = i, Embedding = e }).ToArray(), + Usage = new UsageInfo { PromptTokens = 10, TotalTokens = 10 } + }; + } + + // Internal request/response classes for testing + private sealed class AzureEmbeddingRequest + { + [System.Text.Json.Serialization.JsonPropertyName("input")] + public string[] Input { get; set; } = Array.Empty(); + } + + private sealed class AzureEmbeddingResponse + { + [System.Text.Json.Serialization.JsonPropertyName("data")] + public EmbeddingData[] Data { get; set; } = Array.Empty(); + + [System.Text.Json.Serialization.JsonPropertyName("usage")] + public UsageInfo Usage { get; set; } = new(); + 
} + + private sealed class EmbeddingData + { + [System.Text.Json.Serialization.JsonPropertyName("index")] + public int Index { get; set; } + + [System.Text.Json.Serialization.JsonPropertyName("embedding")] + public float[] Embedding { get; set; } = Array.Empty(); + } + + private sealed class UsageInfo + { + [System.Text.Json.Serialization.JsonPropertyName("prompt_tokens")] + public int PromptTokens { get; set; } + + [System.Text.Json.Serialization.JsonPropertyName("total_tokens")] + public int TotalTokens { get; set; } + } +} diff --git a/tests/Core.Tests/Embeddings/Providers/HuggingFaceEmbeddingGeneratorTests.cs b/tests/Core.Tests/Embeddings/Providers/HuggingFaceEmbeddingGeneratorTests.cs new file mode 100644 index 000000000..dd42f4112 --- /dev/null +++ b/tests/Core.Tests/Embeddings/Providers/HuggingFaceEmbeddingGeneratorTests.cs @@ -0,0 +1,327 @@ +// Copyright (c) Microsoft. All rights reserved. +using System.Net; +using System.Text.Json; +using KernelMemory.Core.Config.Enums; +using KernelMemory.Core.Embeddings.Providers; +using Microsoft.Extensions.Logging; +using Moq; +using Moq.Protected; + +namespace KernelMemory.Core.Tests.Embeddings.Providers; + +/// +/// Tests for HuggingFaceEmbeddingGenerator to verify HF Inference API communication, +/// authentication via HF_TOKEN, and response parsing. +/// Uses mocked HttpMessageHandler to avoid real HuggingFace API calls in unit tests. 
+/// +public sealed class HuggingFaceEmbeddingGeneratorTests +{ + private readonly Mock> _loggerMock; + private readonly Mock _httpHandlerMock; + + public HuggingFaceEmbeddingGeneratorTests() + { + this._loggerMock = new Mock>(); + this._httpHandlerMock = new Mock(); + } + + [Fact] + public void Properties_ShouldReflectConfiguration() + { + // Arrange + var httpClient = new HttpClient(this._httpHandlerMock.Object); + var generator = new HuggingFaceEmbeddingGenerator( + httpClient, + apiKey: "hf_test_token", + model: "sentence-transformers/all-MiniLM-L6-v2", + vectorDimensions: 384, + isNormalized: true, + baseUrl: null, + this._loggerMock.Object); + + // Assert + Assert.Equal(EmbeddingsTypes.HuggingFace, generator.ProviderType); + Assert.Equal("sentence-transformers/all-MiniLM-L6-v2", generator.ModelName); + Assert.Equal(384, generator.VectorDimensions); + Assert.True(generator.IsNormalized); + } + + [Fact] + public async Task GenerateAsync_Single_ShouldCallCorrectEndpoint() + { + // Arrange + var response = new[] { new[] { 0.1f, 0.2f, 0.3f } }; + var responseJson = JsonSerializer.Serialize(response); + + this._httpHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.Is(req => + req.Method == HttpMethod.Post && + req.RequestUri!.ToString() == "https://api-inference.huggingface.co/models/sentence-transformers/all-MiniLM-L6-v2"), + ItExpr.IsAny()) + .ReturnsAsync(new HttpResponseMessage + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent(responseJson) + }); + + var httpClient = new HttpClient(this._httpHandlerMock.Object); + var generator = new HuggingFaceEmbeddingGenerator( + httpClient, "hf_token", "sentence-transformers/all-MiniLM-L6-v2", 384, true, null, this._loggerMock.Object); + + // Act + var result = await generator.GenerateAsync("test text", CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.Equal(new[] { 0.1f, 0.2f, 0.3f }, result); + } + + [Fact] + public async Task 
GenerateAsync_Single_ShouldSendAuthorizationHeader() + { + // Arrange + HttpRequestMessage? capturedRequest = null; + var response = new[] { new[] { 0.1f } }; + + this._httpHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .Callback((req, _) => capturedRequest = req) + .ReturnsAsync(new HttpResponseMessage + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent(JsonSerializer.Serialize(response)) + }); + + var httpClient = new HttpClient(this._httpHandlerMock.Object); + var generator = new HuggingFaceEmbeddingGenerator( + httpClient, "hf_my_secret_token", "model", 384, true, null, this._loggerMock.Object); + + // Act + await generator.GenerateAsync("test", CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.NotNull(capturedRequest); + Assert.Equal("Bearer", capturedRequest!.Headers.Authorization?.Scheme); + Assert.Equal("hf_my_secret_token", capturedRequest.Headers.Authorization?.Parameter); + } + + [Fact] + public async Task GenerateAsync_Single_ShouldSendCorrectRequestBody() + { + // Arrange + string? 
capturedContent = null; + var response = new[] { new[] { 0.1f } }; + + this._httpHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .Returns(async (req, ct) => + { + // Capture content before it's disposed + capturedContent = await req.Content!.ReadAsStringAsync(ct).ConfigureAwait(false); + return new HttpResponseMessage + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent(JsonSerializer.Serialize(response)) + }; + }); + + var httpClient = new HttpClient(this._httpHandlerMock.Object); + var generator = new HuggingFaceEmbeddingGenerator( + httpClient, "hf_token", "model", 384, true, null, this._loggerMock.Object); + + // Act + await generator.GenerateAsync("hello world", CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.NotNull(capturedContent); + var requestBody = JsonSerializer.Deserialize(capturedContent!); + Assert.NotNull(requestBody); + Assert.Contains("hello world", requestBody!.Inputs); + } + + [Fact] + public async Task GenerateAsync_Batch_ShouldProcessAllTexts() + { + // Arrange + var texts = new[] { "text1", "text2", "text3" }; + var response = new[] + { + new[] { 0.1f }, + new[] { 0.2f }, + new[] { 0.3f } + }; + + this._httpHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync(new HttpResponseMessage + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent(JsonSerializer.Serialize(response)) + }); + + var httpClient = new HttpClient(this._httpHandlerMock.Object); + var generator = new HuggingFaceEmbeddingGenerator( + httpClient, "hf_token", "model", 384, true, null, this._loggerMock.Object); + + // Act + var results = await generator.GenerateAsync(texts, CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.Equal(3, results.Length); + Assert.Equal(new[] { 0.1f }, results[0]); + Assert.Equal(new[] { 0.2f }, results[1]); + Assert.Equal(new[] { 0.3f }, results[2]); + } + + [Fact] + public async Task 
GenerateAsync_WithCustomBaseUrl_ShouldUseIt() + { + // Arrange + var response = new[] { new[] { 0.1f } }; + + this._httpHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.Is(req => + req.RequestUri!.ToString().StartsWith("https://custom.hf-endpoint.com")), + ItExpr.IsAny()) + .ReturnsAsync(new HttpResponseMessage + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent(JsonSerializer.Serialize(response)) + }); + + var httpClient = new HttpClient(this._httpHandlerMock.Object); + var generator = new HuggingFaceEmbeddingGenerator( + httpClient, "hf_token", "model", 384, true, "https://custom.hf-endpoint.com", this._loggerMock.Object); + + // Act + var result = await generator.GenerateAsync("test", CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.Equal(new[] { 0.1f }, result); + } + + [Fact] + public async Task GenerateAsync_WithUnauthorizedError_ShouldThrowHttpRequestException() + { + // Arrange + this._httpHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync(new HttpResponseMessage + { + StatusCode = HttpStatusCode.Unauthorized, + Content = new StringContent("Invalid token") + }); + + var httpClient = new HttpClient(this._httpHandlerMock.Object); + var generator = new HuggingFaceEmbeddingGenerator( + httpClient, "bad_token", "model", 384, true, null, this._loggerMock.Object); + + // Act & Assert + await Assert.ThrowsAsync( + () => generator.GenerateAsync("test", CancellationToken.None)).ConfigureAwait(false); + } + + [Fact] + public async Task GenerateAsync_WithModelLoadingError_ShouldThrowHttpRequestException() + { + // Arrange - HuggingFace returns 503 when model is loading + this._httpHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync(new HttpResponseMessage + { + StatusCode = HttpStatusCode.ServiceUnavailable, + Content = new StringContent("{\"error\":\"Model is currently loading\"}") + }); + + var httpClient = new 
HttpClient(this._httpHandlerMock.Object); + var generator = new HuggingFaceEmbeddingGenerator( + httpClient, "hf_token", "model", 384, true, null, this._loggerMock.Object); + + // Act & Assert + await Assert.ThrowsAsync( + () => generator.GenerateAsync("test", CancellationToken.None)).ConfigureAwait(false); + } + + [Fact] + public async Task GenerateAsync_WithCancellation_ShouldPropagate() + { + // Arrange + this._httpHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ThrowsAsync(new OperationCanceledException()); + + var httpClient = new HttpClient(this._httpHandlerMock.Object); + var generator = new HuggingFaceEmbeddingGenerator( + httpClient, "hf_token", "model", 384, true, null, this._loggerMock.Object); + + var cts = new CancellationTokenSource(); + cts.Cancel(); + + // Act & Assert - TaskCanceledException inherits from OperationCanceledException + await Assert.ThrowsAnyAsync( + () => generator.GenerateAsync("test", cts.Token)).ConfigureAwait(false); + } + + [Fact] + public void Constructor_WithNullApiKey_ShouldThrow() + { + // Assert + var httpClient = new HttpClient(); + Assert.Throws(() => + new HuggingFaceEmbeddingGenerator(httpClient, null!, "model", 384, true, null, this._loggerMock.Object)); + } + + [Fact] + public void Constructor_WithEmptyApiKey_ShouldThrow() + { + // Assert + var httpClient = new HttpClient(); + Assert.Throws(() => + new HuggingFaceEmbeddingGenerator(httpClient, "", "model", 384, true, null, this._loggerMock.Object)); + } + + [Fact] + public void Constructor_WithNullModel_ShouldThrow() + { + // Assert + var httpClient = new HttpClient(); + Assert.Throws(() => + new HuggingFaceEmbeddingGenerator(httpClient, "key", null!, 384, true, null, this._loggerMock.Object)); + } + + // Internal request class for testing + private sealed class HuggingFaceRequest + { + [System.Text.Json.Serialization.JsonPropertyName("inputs")] + public string[] Inputs { get; set; } = Array.Empty(); + } +} diff --git 
a/tests/Core.Tests/Embeddings/Providers/OllamaEmbeddingGeneratorTests.cs b/tests/Core.Tests/Embeddings/Providers/OllamaEmbeddingGeneratorTests.cs new file mode 100644 index 000000000..217977da0 --- /dev/null +++ b/tests/Core.Tests/Embeddings/Providers/OllamaEmbeddingGeneratorTests.cs @@ -0,0 +1,272 @@ +// Copyright (c) Microsoft. All rights reserved. +using System.Net; +using System.Text.Json; +using KernelMemory.Core.Config.Enums; +using KernelMemory.Core.Embeddings.Providers; +using Microsoft.Extensions.Logging; +using Moq; +using Moq.Protected; + +namespace KernelMemory.Core.Tests.Embeddings.Providers; + +/// +/// Tests for OllamaEmbeddingGenerator to verify HTTP communication, response parsing, and error handling. +/// Uses mocked HttpMessageHandler to avoid real Ollama calls in unit tests. +/// +public sealed class OllamaEmbeddingGeneratorTests +{ + private readonly Mock> _loggerMock; + private readonly Mock _httpHandlerMock; + + public OllamaEmbeddingGeneratorTests() + { + this._loggerMock = new Mock>(); + this._httpHandlerMock = new Mock(); + } + + [Fact] + public void Properties_ShouldReflectConfiguration() + { + // Arrange + var httpClient = new HttpClient(this._httpHandlerMock.Object); + var generator = new OllamaEmbeddingGenerator( + httpClient, + baseUrl: "http://localhost:11434", + model: "qwen3-embedding", + vectorDimensions: 1024, + isNormalized: true, + this._loggerMock.Object); + + // Assert + Assert.Equal(EmbeddingsTypes.Ollama, generator.ProviderType); + Assert.Equal("qwen3-embedding", generator.ModelName); + Assert.Equal(1024, generator.VectorDimensions); + Assert.True(generator.IsNormalized); + } + + [Fact] + public async Task GenerateAsync_Single_ShouldCallCorrectEndpoint() + { + // Arrange + var response = new OllamaEmbeddingResponse { Embedding = new[] { 0.1f, 0.2f, 0.3f } }; + var responseJson = JsonSerializer.Serialize(response); + + this._httpHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.Is(req => + req.Method == 
HttpMethod.Post && + req.RequestUri!.ToString() == "http://localhost:11434/api/embeddings"), + ItExpr.IsAny()) + .ReturnsAsync(new HttpResponseMessage + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent(responseJson) + }); + + var httpClient = new HttpClient(this._httpHandlerMock.Object); + var generator = new OllamaEmbeddingGenerator( + httpClient, "http://localhost:11434", "qwen3-embedding", 1024, true, this._loggerMock.Object); + + // Act + var result = await generator.GenerateAsync("test text", CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.Equal(new[] { 0.1f, 0.2f, 0.3f }, result); + } + + [Fact] + public async Task GenerateAsync_Single_ShouldSendCorrectRequestBody() + { + // Arrange + HttpRequestMessage? capturedRequest = null; + var response = new OllamaEmbeddingResponse { Embedding = new[] { 0.1f } }; + + this._httpHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .Callback((req, _) => capturedRequest = req) + .ReturnsAsync(new HttpResponseMessage + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent(JsonSerializer.Serialize(response)) + }); + + var httpClient = new HttpClient(this._httpHandlerMock.Object); + var generator = new OllamaEmbeddingGenerator( + httpClient, "http://localhost:11434", "qwen3-embedding", 1024, true, this._loggerMock.Object); + + // Act + await generator.GenerateAsync("hello world", CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.NotNull(capturedRequest); + var content = await capturedRequest!.Content!.ReadAsStringAsync().ConfigureAwait(false); + var requestBody = JsonSerializer.Deserialize(content); + Assert.Equal("qwen3-embedding", requestBody!.Model); + Assert.Equal("hello world", requestBody.Prompt); + } + + [Fact] + public async Task GenerateAsync_Batch_ShouldProcessAllTexts() + { + // Arrange + var texts = new[] { "text1", "text2", "text3" }; + var vectors = new[] + { + new[] { 0.1f }, + new[] { 0.2f }, + new[] { 
0.3f } + }; + + var callIndex = 0; + this._httpHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync(() => + { + var response = new OllamaEmbeddingResponse { Embedding = vectors[callIndex++] }; + return new HttpResponseMessage + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent(JsonSerializer.Serialize(response)) + }; + }); + + var httpClient = new HttpClient(this._httpHandlerMock.Object); + var generator = new OllamaEmbeddingGenerator( + httpClient, "http://localhost:11434", "qwen3-embedding", 1024, true, this._loggerMock.Object); + + // Act + var results = await generator.GenerateAsync(texts, CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.Equal(3, results.Length); + Assert.Equal(new[] { 0.1f }, results[0]); + Assert.Equal(new[] { 0.2f }, results[1]); + Assert.Equal(new[] { 0.3f }, results[2]); + } + + [Fact] + public async Task GenerateAsync_WithHttpError_ShouldThrowHttpRequestException() + { + // Arrange + this._httpHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync(new HttpResponseMessage + { + StatusCode = HttpStatusCode.InternalServerError, + Content = new StringContent("Server error") + }); + + var httpClient = new HttpClient(this._httpHandlerMock.Object); + var generator = new OllamaEmbeddingGenerator( + httpClient, "http://localhost:11434", "qwen3-embedding", 1024, true, this._loggerMock.Object); + + // Act & Assert + await Assert.ThrowsAsync( + () => generator.GenerateAsync("test", CancellationToken.None)).ConfigureAwait(false); + } + + [Fact] + public async Task GenerateAsync_WithMalformedResponse_ShouldThrowJsonException() + { + // Arrange + this._httpHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync(new HttpResponseMessage + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent("not json") + }); + + var httpClient = new 
HttpClient(this._httpHandlerMock.Object); + var generator = new OllamaEmbeddingGenerator( + httpClient, "http://localhost:11434", "qwen3-embedding", 1024, true, this._loggerMock.Object); + + // Act & Assert + await Assert.ThrowsAsync( + () => generator.GenerateAsync("test", CancellationToken.None)).ConfigureAwait(false); + } + + [Fact] + public async Task GenerateAsync_WithCancellation_ShouldPropagate() + { + // Arrange + this._httpHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ThrowsAsync(new OperationCanceledException()); + + var httpClient = new HttpClient(this._httpHandlerMock.Object); + var generator = new OllamaEmbeddingGenerator( + httpClient, "http://localhost:11434", "qwen3-embedding", 1024, true, this._loggerMock.Object); + + var cts = new CancellationTokenSource(); + cts.Cancel(); + + // Act & Assert - TaskCanceledException inherits from OperationCanceledException + await Assert.ThrowsAnyAsync( + () => generator.GenerateAsync("test", cts.Token)).ConfigureAwait(false); + } + + [Fact] + public void Constructor_WithNullHttpClient_ShouldThrow() + { + // Assert + Assert.Throws(() => + new OllamaEmbeddingGenerator(null!, "http://localhost", "model", 1024, true, this._loggerMock.Object)); + } + + [Fact] + public void Constructor_WithNullBaseUrl_ShouldThrow() + { + // Assert + var httpClient = new HttpClient(); + Assert.Throws(() => + new OllamaEmbeddingGenerator(httpClient, null!, "model", 1024, true, this._loggerMock.Object)); + } + + [Fact] + public void Constructor_WithNullModel_ShouldThrow() + { + // Assert + var httpClient = new HttpClient(); + Assert.Throws(() => + new OllamaEmbeddingGenerator(httpClient, "http://localhost", null!, 1024, true, this._loggerMock.Object)); + } + + // Internal request/response classes for testing + private sealed class OllamaEmbeddingRequest + { + [System.Text.Json.Serialization.JsonPropertyName("model")] + public string Model { get; set; } = string.Empty; + + 
[System.Text.Json.Serialization.JsonPropertyName("prompt")] + public string Prompt { get; set; } = string.Empty; + } + + private sealed class OllamaEmbeddingResponse + { + [System.Text.Json.Serialization.JsonPropertyName("embedding")] + public float[] Embedding { get; set; } = Array.Empty(); + } +} diff --git a/tests/Core.Tests/Embeddings/Providers/OpenAIEmbeddingGeneratorTests.cs b/tests/Core.Tests/Embeddings/Providers/OpenAIEmbeddingGeneratorTests.cs new file mode 100644 index 000000000..7dced11ba --- /dev/null +++ b/tests/Core.Tests/Embeddings/Providers/OpenAIEmbeddingGeneratorTests.cs @@ -0,0 +1,332 @@ +// Copyright (c) Microsoft. All rights reserved. +using System.Net; +using System.Text.Json; +using KernelMemory.Core.Config.Enums; +using KernelMemory.Core.Embeddings.Providers; +using Microsoft.Extensions.Logging; +using Moq; +using Moq.Protected; + +namespace KernelMemory.Core.Tests.Embeddings.Providers; + +/// +/// Tests for OpenAIEmbeddingGenerator to verify API communication, batch processing, and token count extraction. +/// Uses mocked HttpMessageHandler to avoid real OpenAI API calls in unit tests. 
+/// +public sealed class OpenAIEmbeddingGeneratorTests +{ + private readonly Mock> _loggerMock; + private readonly Mock _httpHandlerMock; + + public OpenAIEmbeddingGeneratorTests() + { + this._loggerMock = new Mock>(); + this._httpHandlerMock = new Mock(); + } + + [Fact] + public void Properties_ShouldReflectConfiguration() + { + // Arrange + var httpClient = new HttpClient(this._httpHandlerMock.Object); + var generator = new OpenAIEmbeddingGenerator( + httpClient, + apiKey: "test-key", + model: "text-embedding-ada-002", + vectorDimensions: 1536, + isNormalized: true, + baseUrl: null, + this._loggerMock.Object); + + // Assert + Assert.Equal(EmbeddingsTypes.OpenAI, generator.ProviderType); + Assert.Equal("text-embedding-ada-002", generator.ModelName); + Assert.Equal(1536, generator.VectorDimensions); + Assert.True(generator.IsNormalized); + } + + [Fact] + public async Task GenerateAsync_Single_ShouldCallCorrectEndpoint() + { + // Arrange + var response = CreateOpenAIResponse(new[] { new[] { 0.1f, 0.2f, 0.3f } }); + var responseJson = JsonSerializer.Serialize(response); + + this._httpHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.Is(req => + req.Method == HttpMethod.Post && + req.RequestUri!.ToString() == "https://api.openai.com/v1/embeddings"), + ItExpr.IsAny()) + .ReturnsAsync(new HttpResponseMessage + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent(responseJson) + }); + + var httpClient = new HttpClient(this._httpHandlerMock.Object); + var generator = new OpenAIEmbeddingGenerator( + httpClient, "test-key", "text-embedding-ada-002", 1536, true, null, this._loggerMock.Object); + + // Act + var result = await generator.GenerateAsync("test text", CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.Equal(new[] { 0.1f, 0.2f, 0.3f }, result); + } + + [Fact] + public async Task GenerateAsync_Single_ShouldSendAuthorizationHeader() + { + // Arrange + HttpRequestMessage? 
capturedRequest = null; + var response = CreateOpenAIResponse(new[] { new[] { 0.1f } }); + + this._httpHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .Callback((req, _) => capturedRequest = req) + .ReturnsAsync(new HttpResponseMessage + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent(JsonSerializer.Serialize(response)) + }); + + var httpClient = new HttpClient(this._httpHandlerMock.Object); + var generator = new OpenAIEmbeddingGenerator( + httpClient, "sk-test-api-key", "text-embedding-ada-002", 1536, true, null, this._loggerMock.Object); + + // Act + await generator.GenerateAsync("test", CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.NotNull(capturedRequest); + Assert.Equal("Bearer", capturedRequest!.Headers.Authorization?.Scheme); + Assert.Equal("sk-test-api-key", capturedRequest.Headers.Authorization?.Parameter); + } + + [Fact] + public async Task GenerateAsync_Batch_ShouldSendAllInputsInOneRequest() + { + // Arrange + var texts = new[] { "text1", "text2", "text3" }; + var response = CreateOpenAIResponse(new[] + { + new[] { 0.1f }, + new[] { 0.2f }, + new[] { 0.3f } + }); + + string? 
capturedContent = null; + this._httpHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .Returns(async (req, ct) => + { + // Capture content before it's disposed + capturedContent = await req.Content!.ReadAsStringAsync(ct).ConfigureAwait(false); + return new HttpResponseMessage + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent(JsonSerializer.Serialize(response)) + }; + }); + + var httpClient = new HttpClient(this._httpHandlerMock.Object); + var generator = new OpenAIEmbeddingGenerator( + httpClient, "test-key", "text-embedding-ada-002", 1536, true, null, this._loggerMock.Object); + + // Act + var results = await generator.GenerateAsync(texts, CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.Equal(3, results.Length); + var requestBody = JsonSerializer.Deserialize(capturedContent!); + Assert.Equal(3, requestBody!.Input.Length); + } + + [Fact] + public async Task GenerateAsync_WithCustomBaseUrl_ShouldUseIt() + { + // Arrange + var response = CreateOpenAIResponse(new[] { new[] { 0.1f } }); + + this._httpHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.Is(req => + req.RequestUri!.ToString() == "https://custom.api.com/v1/embeddings"), + ItExpr.IsAny()) + .ReturnsAsync(new HttpResponseMessage + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent(JsonSerializer.Serialize(response)) + }); + + var httpClient = new HttpClient(this._httpHandlerMock.Object); + var generator = new OpenAIEmbeddingGenerator( + httpClient, "test-key", "model", 1536, true, "https://custom.api.com", this._loggerMock.Object); + + // Act + var result = await generator.GenerateAsync("test", CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.Equal(new[] { 0.1f }, result); + } + + [Fact] + public async Task GenerateAsync_WithRateLimitError_ShouldThrowHttpRequestException() + { + // Arrange + this._httpHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + 
ItExpr.IsAny()) + .ReturnsAsync(new HttpResponseMessage + { + StatusCode = HttpStatusCode.TooManyRequests, + Content = new StringContent("Rate limit exceeded") + }); + + var httpClient = new HttpClient(this._httpHandlerMock.Object); + var generator = new OpenAIEmbeddingGenerator( + httpClient, "test-key", "model", 1536, true, null, this._loggerMock.Object); + + // Act & Assert + await Assert.ThrowsAsync( + () => generator.GenerateAsync("test", CancellationToken.None)).ConfigureAwait(false); + } + + [Fact] + public async Task GenerateAsync_WithUnauthorizedError_ShouldThrowHttpRequestException() + { + // Arrange + this._httpHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync(new HttpResponseMessage + { + StatusCode = HttpStatusCode.Unauthorized, + Content = new StringContent("Invalid API key") + }); + + var httpClient = new HttpClient(this._httpHandlerMock.Object); + var generator = new OpenAIEmbeddingGenerator( + httpClient, "bad-key", "model", 1536, true, null, this._loggerMock.Object); + + // Act & Assert + await Assert.ThrowsAsync( + () => generator.GenerateAsync("test", CancellationToken.None)).ConfigureAwait(false); + } + + [Fact] + public async Task GenerateAsync_WithCancellation_ShouldPropagate() + { + // Arrange + this._httpHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ThrowsAsync(new OperationCanceledException()); + + var httpClient = new HttpClient(this._httpHandlerMock.Object); + var generator = new OpenAIEmbeddingGenerator( + httpClient, "test-key", "model", 1536, true, null, this._loggerMock.Object); + + var cts = new CancellationTokenSource(); + cts.Cancel(); + + // Act & Assert - TaskCanceledException inherits from OperationCanceledException + await Assert.ThrowsAnyAsync( + () => generator.GenerateAsync("test", cts.Token)).ConfigureAwait(false); + } + + [Fact] + public void Constructor_WithNullApiKey_ShouldThrow() + { + // Assert + var httpClient = new 
HttpClient(); + Assert.Throws(() => + new OpenAIEmbeddingGenerator(httpClient, null!, "model", 1536, true, null, this._loggerMock.Object)); + } + + [Fact] + public void Constructor_WithEmptyApiKey_ShouldThrow() + { + // Assert + var httpClient = new HttpClient(); + Assert.Throws(() => + new OpenAIEmbeddingGenerator(httpClient, "", "model", 1536, true, null, this._loggerMock.Object)); + } + + [Fact] + public void Constructor_WithNullModel_ShouldThrow() + { + // Assert + var httpClient = new HttpClient(); + Assert.Throws(() => + new OpenAIEmbeddingGenerator(httpClient, "key", null!, 1536, true, null, this._loggerMock.Object)); + } + + private static OpenAIEmbeddingResponse CreateOpenAIResponse(float[][] embeddings) + { + return new OpenAIEmbeddingResponse + { + Data = embeddings.Select((e, i) => new EmbeddingData { Index = i, Embedding = e }).ToArray(), + Usage = new UsageInfo { PromptTokens = 10, TotalTokens = 10 } + }; + } + + // Internal request/response classes for testing + private sealed class OpenAIEmbeddingRequest + { + [System.Text.Json.Serialization.JsonPropertyName("model")] + public string Model { get; set; } = string.Empty; + + [System.Text.Json.Serialization.JsonPropertyName("input")] + public string[] Input { get; set; } = Array.Empty(); + } + + private sealed class OpenAIEmbeddingResponse + { + [System.Text.Json.Serialization.JsonPropertyName("data")] + public EmbeddingData[] Data { get; set; } = Array.Empty(); + + [System.Text.Json.Serialization.JsonPropertyName("usage")] + public UsageInfo Usage { get; set; } = new(); + } + + private sealed class EmbeddingData + { + [System.Text.Json.Serialization.JsonPropertyName("index")] + public int Index { get; set; } + + [System.Text.Json.Serialization.JsonPropertyName("embedding")] + public float[] Embedding { get; set; } = Array.Empty(); + } + + private sealed class UsageInfo + { + [System.Text.Json.Serialization.JsonPropertyName("prompt_tokens")] + public int PromptTokens { get; set; } + + 
[System.Text.Json.Serialization.JsonPropertyName("total_tokens")] + public int TotalTokens { get; set; } + } +} diff --git a/tests/Core.Tests/GlobalUsings.cs b/tests/Core.Tests/GlobalUsings.cs index cbd4300e0..543dc179d 100644 --- a/tests/Core.Tests/GlobalUsings.cs +++ b/tests/Core.Tests/GlobalUsings.cs @@ -1,3 +1,12 @@ // Copyright (c) Microsoft. All rights reserved. global using Xunit; + +using System.Diagnostics.CodeAnalysis; + +// Test files create disposable objects that are managed by the test framework lifecycle +// These suppressions apply to test code only and are acceptable for unit tests +[assembly: SuppressMessage("Reliability", "CA2000:Dispose objects before losing scope", + Justification = "Test objects are disposed when test completes - managed by xUnit")] +[assembly: SuppressMessage("Performance", "CA1861:Prefer 'static readonly' fields over constant array arguments", + Justification = "Arrays in tests are for readability and small test data")] From b8fe07cb5b48e2636f3e4efd7e95a3e0bbd965c6 Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Tue, 2 Dec 2025 17:37:44 +0100 Subject: [PATCH 2/6] refactor: use composite primary key in embedding cache Changed from single TEXT key to composite primary key with individual columns (provider, model, dimensions, normalized, text_hash) for: - Easier debugging when inspecting SQLite database contents - Selective SELECT/DELETE with individual discriminators - Better query flexibility for cache management Added indexes on provider and (provider, model) for common query patterns. 
--- .../Embeddings/Cache/SqliteEmbeddingCache.cs | 52 +++++++++++++------ 1 file changed, 37 insertions(+), 15 deletions(-) diff --git a/src/Core/Embeddings/Cache/SqliteEmbeddingCache.cs b/src/Core/Embeddings/Cache/SqliteEmbeddingCache.cs index babc9327c..bcc806036 100644 --- a/src/Core/Embeddings/Cache/SqliteEmbeddingCache.cs +++ b/src/Core/Embeddings/Cache/SqliteEmbeddingCache.cs @@ -15,18 +15,33 @@ public sealed class SqliteEmbeddingCache : IEmbeddingCache, IDisposable { private const string CreateTableSql = """ CREATE TABLE IF NOT EXISTS embeddings_cache ( - key TEXT PRIMARY KEY, + provider TEXT NOT NULL, + model TEXT NOT NULL, + dimensions INTEGER NOT NULL, + normalized INTEGER NOT NULL, + text_length INTEGER NOT NULL, + text_hash TEXT NOT NULL, vector BLOB NOT NULL, token_count INTEGER NULL, - timestamp TEXT NOT NULL + timestamp TEXT NOT NULL, + PRIMARY KEY (provider, model, dimensions, normalized, text_hash) ); CREATE INDEX IF NOT EXISTS idx_timestamp ON embeddings_cache(timestamp); + CREATE INDEX IF NOT EXISTS idx_provider ON embeddings_cache(provider); + CREATE INDEX IF NOT EXISTS idx_model ON embeddings_cache(provider, model); + """; + + private const string SelectSql = """ + SELECT vector, token_count, timestamp FROM embeddings_cache + WHERE provider = @provider AND model = @model AND dimensions = @dimensions + AND normalized = @normalized AND text_hash = @textHash """; - private const string SelectSql = "SELECT vector, token_count, timestamp FROM embeddings_cache WHERE key = @key"; private const string UpsertSql = """ - INSERT INTO embeddings_cache (key, vector, token_count, timestamp) VALUES (@key, @vector, @tokenCount, @timestamp) - ON CONFLICT(key) DO UPDATE SET vector = @vector, token_count = @tokenCount, timestamp = @timestamp + INSERT INTO embeddings_cache (provider, model, dimensions, normalized, text_length, text_hash, vector, token_count, timestamp) + VALUES (@provider, @model, @dimensions, @normalized, @textLength, @textHash, @vector, 
@tokenCount, @timestamp) + ON CONFLICT(provider, model, dimensions, normalized, text_hash) + DO UPDATE SET vector = @vector, token_count = @tokenCount, timestamp = @timestamp """; private readonly SqliteConnection _connection; @@ -100,20 +115,23 @@ public SqliteEmbeddingCache(string dbPath, CacheModes mode, ILogger Date: Tue, 2 Dec 2025 17:59:03 +0100 Subject: [PATCH 3/6] fix: address PR review comments for embedding cache Changes: - Allow null text in EmbeddingCacheKey.Create(), normalize to empty string - Remove unnecessary indexes (idx_timestamp, idx_provider, idx_model) - Rename 'normalized' column to 'is_normalized' for clarity - Rename 'timestamp' column to 'created_at' for clarity - Simplify timestamp documentation (debugging only, no invalidation) --- docs | 2 +- src/Core/Embeddings/Cache/CachedEmbedding.cs | 1 + .../Embeddings/Cache/EmbeddingCacheKey.cs | 8 +++-- .../Embeddings/Cache/SqliteEmbeddingCache.cs | 33 +++++++++---------- .../Embeddings/EmbeddingCacheKeyTests.cs | 11 ++++--- 5 files changed, 29 insertions(+), 26 deletions(-) diff --git a/docs b/docs index 03362c93d..23845bf23 160000 --- a/docs +++ b/docs @@ -1 +1 @@ -Subproject commit 03362c93d3f478ae345db7976db07b5e1043348c +Subproject commit 23845bf23aa39bb2a443fbe47e100a7f8de6c5db diff --git a/src/Core/Embeddings/Cache/CachedEmbedding.cs b/src/Core/Embeddings/Cache/CachedEmbedding.cs index 8ce123b1e..f943af03d 100644 --- a/src/Core/Embeddings/Cache/CachedEmbedding.cs +++ b/src/Core/Embeddings/Cache/CachedEmbedding.cs @@ -25,6 +25,7 @@ public sealed class CachedEmbedding /// /// Timestamp when this embedding was stored in the cache. + /// Useful for debugging and diagnostics. 
/// public required DateTimeOffset Timestamp { get; init; } } diff --git a/src/Core/Embeddings/Cache/EmbeddingCacheKey.cs b/src/Core/Embeddings/Cache/EmbeddingCacheKey.cs index 55777d832..2607bb4f6 100644 --- a/src/Core/Embeddings/Cache/EmbeddingCacheKey.cs +++ b/src/Core/Embeddings/Cache/EmbeddingCacheKey.cs @@ -53,20 +53,22 @@ public sealed class EmbeddingCacheKey /// Whether vectors are normalized. /// The text to hash. /// A new EmbeddingCacheKey instance. - /// When provider, model, or text is null. + /// When provider or model is null. /// When vectorDimensions is less than 1. public static EmbeddingCacheKey Create( string provider, string model, int vectorDimensions, bool isNormalized, - string text) + string? text) { ArgumentNullException.ThrowIfNull(provider, nameof(provider)); ArgumentNullException.ThrowIfNull(model, nameof(model)); - ArgumentNullException.ThrowIfNull(text, nameof(text)); ArgumentOutOfRangeException.ThrowIfLessThan(vectorDimensions, 1, nameof(vectorDimensions)); + // Normalize null to empty string for consistent hashing + text ??= string.Empty; + return new EmbeddingCacheKey { Provider = provider, diff --git a/src/Core/Embeddings/Cache/SqliteEmbeddingCache.cs b/src/Core/Embeddings/Cache/SqliteEmbeddingCache.cs index bcc806036..7bbb1d2a4 100644 --- a/src/Core/Embeddings/Cache/SqliteEmbeddingCache.cs +++ b/src/Core/Embeddings/Cache/SqliteEmbeddingCache.cs @@ -18,30 +18,27 @@ CREATE TABLE IF NOT EXISTS embeddings_cache ( provider TEXT NOT NULL, model TEXT NOT NULL, dimensions INTEGER NOT NULL, - normalized INTEGER NOT NULL, + is_normalized INTEGER NOT NULL, text_length INTEGER NOT NULL, text_hash TEXT NOT NULL, vector BLOB NOT NULL, token_count INTEGER NULL, - timestamp TEXT NOT NULL, - PRIMARY KEY (provider, model, dimensions, normalized, text_hash) + created_at TEXT NOT NULL, + PRIMARY KEY (provider, model, dimensions, is_normalized, text_hash) ); - CREATE INDEX IF NOT EXISTS idx_timestamp ON embeddings_cache(timestamp); - CREATE INDEX IF 
NOT EXISTS idx_provider ON embeddings_cache(provider); - CREATE INDEX IF NOT EXISTS idx_model ON embeddings_cache(provider, model); """; private const string SelectSql = """ - SELECT vector, token_count, timestamp FROM embeddings_cache + SELECT vector, token_count, created_at FROM embeddings_cache WHERE provider = @provider AND model = @model AND dimensions = @dimensions - AND normalized = @normalized AND text_hash = @textHash + AND is_normalized = @isNormalized AND text_hash = @textHash """; private const string UpsertSql = """ - INSERT INTO embeddings_cache (provider, model, dimensions, normalized, text_length, text_hash, vector, token_count, timestamp) - VALUES (@provider, @model, @dimensions, @normalized, @textLength, @textHash, @vector, @tokenCount, @timestamp) - ON CONFLICT(provider, model, dimensions, normalized, text_hash) - DO UPDATE SET vector = @vector, token_count = @tokenCount, timestamp = @timestamp + INSERT INTO embeddings_cache (provider, model, dimensions, is_normalized, text_length, text_hash, vector, token_count, created_at) + VALUES (@provider, @model, @dimensions, @isNormalized, @textLength, @textHash, @vector, @tokenCount, @createdAt) + ON CONFLICT(provider, model, dimensions, is_normalized, text_hash) + DO UPDATE SET vector = @vector, token_count = @tokenCount, created_at = @createdAt """; private readonly SqliteConnection _connection; @@ -122,7 +119,7 @@ public SqliteEmbeddingCache(string dbPath, CacheModes mode, ILogger(() => - EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, null!)); + // Arrange & Act + var key = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, null); + + // Assert - null should be normalized to empty string + Assert.Equal(0, key.TextLength); + Assert.NotEmpty(key.TextHash); // Hash of empty string } [Fact] From 06bdcbc4d21707eefa3c7b956ac73d388659bbef Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Tue, 2 Dec 2025 18:09:37 +0100 Subject: [PATCH 4/6] refactor: remove timestamp from embedding cache 
Embeddings are immutable - same text always produces the same vector. No need for cache invalidation or timestamps. Simplifies the schema. --- src/Core/Embeddings/Cache/CachedEmbedding.cs | 6 --- .../Embeddings/Cache/SqliteEmbeddingCache.cs | 15 +++---- .../Cache/SqliteEmbeddingCacheTests.cs | 20 --------- .../CachedEmbeddingGeneratorTests.cs | 11 +++-- .../Embeddings/CachedEmbeddingTests.cs | 42 ++----------------- 5 files changed, 14 insertions(+), 80 deletions(-) diff --git a/src/Core/Embeddings/Cache/CachedEmbedding.cs b/src/Core/Embeddings/Cache/CachedEmbedding.cs index f943af03d..6a876e411 100644 --- a/src/Core/Embeddings/Cache/CachedEmbedding.cs +++ b/src/Core/Embeddings/Cache/CachedEmbedding.cs @@ -22,10 +22,4 @@ public sealed class CachedEmbedding /// Null if the provider did not return token count. /// public int? TokenCount { get; init; } - - /// - /// Timestamp when this embedding was stored in the cache. - /// Useful for debugging and diagnostics. - /// - public required DateTimeOffset Timestamp { get; init; } } diff --git a/src/Core/Embeddings/Cache/SqliteEmbeddingCache.cs b/src/Core/Embeddings/Cache/SqliteEmbeddingCache.cs index 7bbb1d2a4..a93822d82 100644 --- a/src/Core/Embeddings/Cache/SqliteEmbeddingCache.cs +++ b/src/Core/Embeddings/Cache/SqliteEmbeddingCache.cs @@ -23,22 +23,21 @@ CREATE TABLE IF NOT EXISTS embeddings_cache ( text_hash TEXT NOT NULL, vector BLOB NOT NULL, token_count INTEGER NULL, - created_at TEXT NOT NULL, PRIMARY KEY (provider, model, dimensions, is_normalized, text_hash) ); """; private const string SelectSql = """ - SELECT vector, token_count, created_at FROM embeddings_cache + SELECT vector, token_count FROM embeddings_cache WHERE provider = @provider AND model = @model AND dimensions = @dimensions AND is_normalized = @isNormalized AND text_hash = @textHash """; private const string UpsertSql = """ - INSERT INTO embeddings_cache (provider, model, dimensions, is_normalized, text_length, text_hash, vector, token_count, 
created_at) - VALUES (@provider, @model, @dimensions, @isNormalized, @textLength, @textHash, @vector, @tokenCount, @createdAt) + INSERT INTO embeddings_cache (provider, model, dimensions, is_normalized, text_length, text_hash, vector, token_count) + VALUES (@provider, @model, @dimensions, @isNormalized, @textLength, @textHash, @vector, @tokenCount) ON CONFLICT(provider, model, dimensions, is_normalized, text_hash) - DO UPDATE SET vector = @vector, token_count = @tokenCount, created_at = @createdAt + DO UPDATE SET vector = @vector, token_count = @tokenCount """; private readonly SqliteConnection _connection; @@ -136,7 +135,6 @@ public SqliteEmbeddingCache(string dbPath, CacheModes mode, ILogger= beforeStore.AddSeconds(-1)); // Allow 1 second tolerance - Assert.True(result.Timestamp <= afterStore.AddSeconds(1)); - } - [Fact] public async Task StoreAsync_WithLargeVector_ShouldRoundTrip() { diff --git a/tests/Core.Tests/Embeddings/CachedEmbeddingGeneratorTests.cs b/tests/Core.Tests/Embeddings/CachedEmbeddingGeneratorTests.cs index cb74ffb30..66e2ae17e 100644 --- a/tests/Core.Tests/Embeddings/CachedEmbeddingGeneratorTests.cs +++ b/tests/Core.Tests/Embeddings/CachedEmbeddingGeneratorTests.cs @@ -54,8 +54,7 @@ public async Task GenerateAsync_Single_WithCacheHit_ShouldReturnCachedVector() var cachedVector = new float[] { 0.1f, 0.2f, 0.3f }; var cachedEmbedding = new CachedEmbedding { - Vector = cachedVector, - Timestamp = DateTimeOffset.UtcNow + Vector = cachedVector }; this._cacheMock.Setup(x => x.Mode).Returns(CacheModes.ReadWrite); @@ -220,7 +219,7 @@ public async Task GenerateAsync_Batch_AllCacheHits_ShouldNotCallInnerGenerator() var testKey = EmbeddingCacheKey.Create("OpenAI", "text-embedding-ada-002", 1536, true, kvp.Key); if (testKey.TextHash == key.TextHash) { - return new CachedEmbedding { Vector = kvp.Value, Timestamp = DateTimeOffset.UtcNow }; + return new CachedEmbedding { Vector = kvp.Value }; } } @@ -299,7 +298,7 @@ public async Task 
GenerateAsync_Batch_MixedHitsAndMisses_ShouldOnlyGenerateMisse { if (key.TextHash == cachedKey.TextHash) { - return new CachedEmbedding { Vector = cachedVector, Timestamp = DateTimeOffset.UtcNow }; + return new CachedEmbedding { Vector = cachedVector }; } return null; @@ -411,12 +410,12 @@ public async Task GenerateAsync_Batch_ShouldPreserveOrder() { if (key.TextHash == cachedB.TextHash) { - return new CachedEmbedding { Vector = vectorB, Timestamp = DateTimeOffset.UtcNow }; + return new CachedEmbedding { Vector = vectorB }; } if (key.TextHash == cachedD.TextHash) { - return new CachedEmbedding { Vector = vectorD, Timestamp = DateTimeOffset.UtcNow }; + return new CachedEmbedding { Vector = vectorD }; } return null; diff --git a/tests/Core.Tests/Embeddings/CachedEmbeddingTests.cs b/tests/Core.Tests/Embeddings/CachedEmbeddingTests.cs index 5c68e34a6..e309f370c 100644 --- a/tests/Core.Tests/Embeddings/CachedEmbeddingTests.cs +++ b/tests/Core.Tests/Embeddings/CachedEmbeddingTests.cs @@ -5,7 +5,7 @@ namespace KernelMemory.Core.Tests.Embeddings; /// /// Tests for CachedEmbedding model to verify proper construction and immutability. -/// CachedEmbedding stores the embedding vector, optional token count, and timestamp. +/// CachedEmbedding stores the embedding vector and optional token count. 
/// public sealed class CachedEmbeddingTests { @@ -14,18 +14,15 @@ public void CachedEmbedding_WithRequiredProperties_ShouldBeCreated() { // Arrange var vector = new float[] { 0.1f, 0.2f, 0.3f }; - var timestamp = DateTimeOffset.UtcNow; // Act var cached = new CachedEmbedding { - Vector = vector, - Timestamp = timestamp + Vector = vector }; // Assert Assert.Equal(vector, cached.Vector); - Assert.Equal(timestamp, cached.Timestamp); Assert.Null(cached.TokenCount); } @@ -34,14 +31,12 @@ public void CachedEmbedding_WithTokenCount_ShouldStoreValue() { // Arrange var vector = new float[] { 0.1f, 0.2f, 0.3f }; - var timestamp = DateTimeOffset.UtcNow; const int tokenCount = 42; // Act var cached = new CachedEmbedding { Vector = vector, - Timestamp = timestamp, TokenCount = tokenCount }; @@ -54,13 +49,11 @@ public void CachedEmbedding_WithNullTokenCount_ShouldBeNull() { // Arrange var vector = new float[] { 0.1f, 0.2f, 0.3f }; - var timestamp = DateTimeOffset.UtcNow; // Act var cached = new CachedEmbedding { Vector = vector, - Timestamp = timestamp, TokenCount = null }; @@ -73,13 +66,11 @@ public void CachedEmbedding_VectorShouldPreserveFloatPrecision() { // Arrange var vector = new float[] { 0.123456789f, -0.987654321f, float.MaxValue, float.MinValue }; - var timestamp = DateTimeOffset.UtcNow; // Act var cached = new CachedEmbedding { - Vector = vector, - Timestamp = timestamp + Vector = vector }; // Assert @@ -99,13 +90,10 @@ public void CachedEmbedding_WithLargeVector_ShouldPreserveAllDimensions() vector[i] = (float)i / 1536; } - var timestamp = DateTimeOffset.UtcNow; - // Act var cached = new CachedEmbedding { - Vector = vector, - Timestamp = timestamp + Vector = vector }; // Assert @@ -113,26 +101,4 @@ public void CachedEmbedding_WithLargeVector_ShouldPreserveAllDimensions() Assert.Equal(0.0f, cached.Vector[0]); Assert.Equal(1535f / 1536, cached.Vector[1535], precision: 6); } - - [Fact] - public void CachedEmbedding_TimestampShouldBeDateTimeOffset() - { - // Arrange - 
var vector = new float[] { 0.1f }; - var timestamp = new DateTimeOffset(2025, 12, 1, 10, 30, 0, TimeSpan.FromHours(-5)); - - // Act - var cached = new CachedEmbedding - { - Vector = vector, - Timestamp = timestamp - }; - - // Assert - Assert.Equal(timestamp.Year, cached.Timestamp.Year); - Assert.Equal(timestamp.Month, cached.Timestamp.Month); - Assert.Equal(timestamp.Day, cached.Timestamp.Day); - Assert.Equal(timestamp.Hour, cached.Timestamp.Hour); - Assert.Equal(timestamp.Offset, cached.Timestamp.Offset); - } } From b6449e74c8c15041081e01ef6ef39f79c4dead82 Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Tue, 2 Dec 2025 18:23:23 +0100 Subject: [PATCH 5/6] refactor: simplify embedding cache - remove TokenCount CachedEmbedding now only stores the vector: - Removed TokenCount (token counting should be done before embedding) - Removed timestamp (embeddings are immutable, no invalidation needed) - StoreAsync signature simplified: (key, vector, ct) Final schema: provider, model, dimensions, is_normalized, text_length, text_hash, vector PRIMARY KEY (provider, model, dimensions, is_normalized, text_hash) --- src/Core/Embeddings/Cache/CachedEmbedding.cs | 14 ++---- src/Core/Embeddings/Cache/IEmbeddingCache.cs | 3 +- .../Embeddings/Cache/SqliteEmbeddingCache.cs | 18 +++----- .../Embeddings/CachedEmbeddingGenerator.cs | 4 +- .../Cache/SqliteEmbeddingCacheTests.cs | 46 ++++++------------- .../CachedEmbeddingGeneratorTests.cs | 8 ++-- .../Embeddings/CachedEmbeddingTests.cs | 40 +--------------- 7 files changed, 32 insertions(+), 101 deletions(-) diff --git a/src/Core/Embeddings/Cache/CachedEmbedding.cs b/src/Core/Embeddings/Cache/CachedEmbedding.cs index 6a876e411..43f847c13 100644 --- a/src/Core/Embeddings/Cache/CachedEmbedding.cs +++ b/src/Core/Embeddings/Cache/CachedEmbedding.cs @@ -4,22 +4,14 @@ namespace KernelMemory.Core.Embeddings.Cache; /// -/// Represents a cached embedding vector with metadata. 
-/// Stores the vector, optional token count, and timestamp of when it was cached. +/// Represents a cached embedding vector. /// public sealed class CachedEmbedding { /// - /// The embedding vector as a float array. - /// Array is intentional for performance - embeddings are read-only after creation. + /// The embedding vector. /// [SuppressMessage("Performance", "CA1819:Properties should not return arrays", - Justification = "Embedding vectors are performance-critical and read-only after creation")] + Justification = "Embedding vectors are read-only after creation and passed to storage layer")] public required float[] Vector { get; init; } - - /// - /// Optional token count from the provider response. - /// Null if the provider did not return token count. - /// - public int? TokenCount { get; init; } } diff --git a/src/Core/Embeddings/Cache/IEmbeddingCache.cs b/src/Core/Embeddings/Cache/IEmbeddingCache.cs index 201692161..59fd15743 100644 --- a/src/Core/Embeddings/Cache/IEmbeddingCache.cs +++ b/src/Core/Embeddings/Cache/IEmbeddingCache.cs @@ -30,7 +30,6 @@ public interface IEmbeddingCache /// /// The cache key. /// The embedding vector to store. - /// Optional token count from the provider. /// Cancellation token. - Task StoreAsync(EmbeddingCacheKey key, float[] vector, int? tokenCount, CancellationToken ct = default); + Task StoreAsync(EmbeddingCacheKey key, float[] vector, CancellationToken ct = default); } diff --git a/src/Core/Embeddings/Cache/SqliteEmbeddingCache.cs b/src/Core/Embeddings/Cache/SqliteEmbeddingCache.cs index a93822d82..39690c6d7 100644 --- a/src/Core/Embeddings/Cache/SqliteEmbeddingCache.cs +++ b/src/Core/Embeddings/Cache/SqliteEmbeddingCache.cs @@ -1,5 +1,4 @@ // Copyright (c) Microsoft. All rights reserved. 
-using System.Globalization; using KernelMemory.Core.Config.Enums; using Microsoft.Data.Sqlite; using Microsoft.Extensions.Logging; @@ -22,22 +21,21 @@ CREATE TABLE IF NOT EXISTS embeddings_cache ( text_length INTEGER NOT NULL, text_hash TEXT NOT NULL, vector BLOB NOT NULL, - token_count INTEGER NULL, PRIMARY KEY (provider, model, dimensions, is_normalized, text_hash) ); """; private const string SelectSql = """ - SELECT vector, token_count FROM embeddings_cache + SELECT vector FROM embeddings_cache WHERE provider = @provider AND model = @model AND dimensions = @dimensions AND is_normalized = @isNormalized AND text_hash = @textHash """; private const string UpsertSql = """ - INSERT INTO embeddings_cache (provider, model, dimensions, is_normalized, text_length, text_hash, vector, token_count) - VALUES (@provider, @model, @dimensions, @isNormalized, @textLength, @textHash, @vector, @tokenCount) + INSERT INTO embeddings_cache (provider, model, dimensions, is_normalized, text_length, text_hash, vector) + VALUES (@provider, @model, @dimensions, @isNormalized, @textLength, @textHash, @vector) ON CONFLICT(provider, model, dimensions, is_normalized, text_hash) - DO UPDATE SET vector = @vector, token_count = @tokenCount + DO UPDATE SET vector = @vector """; private readonly SqliteConnection _connection; @@ -134,22 +132,19 @@ public SqliteEmbeddingCache(string dbPath, CacheModes mode, ILogger - public async Task StoreAsync(EmbeddingCacheKey key, float[] vector, int? tokenCount, CancellationToken ct = default) + public async Task StoreAsync(EmbeddingCacheKey key, float[] vector, CancellationToken ct = default) { ct.ThrowIfCancellationRequested(); @@ -173,7 +168,6 @@ public async Task StoreAsync(EmbeddingCacheKey key, float[] vector, int? 
tokenCo command.Parameters.AddWithValue("@textLength", key.TextLength); command.Parameters.AddWithValue("@textHash", key.TextHash); command.Parameters.AddWithValue("@vector", vectorBlob); - command.Parameters.AddWithValue("@tokenCount", tokenCount.HasValue ? tokenCount.Value : DBNull.Value); await command.ExecuteNonQueryAsync(ct).ConfigureAwait(false); diff --git a/src/Core/Embeddings/CachedEmbeddingGenerator.cs b/src/Core/Embeddings/CachedEmbeddingGenerator.cs index 26a100676..897942cef 100644 --- a/src/Core/Embeddings/CachedEmbeddingGenerator.cs +++ b/src/Core/Embeddings/CachedEmbeddingGenerator.cs @@ -76,7 +76,7 @@ public async Task GenerateAsync(string text, CancellationToken ct = def // Store in cache (if mode allows) if (this._cache.Mode != CacheModes.ReadOnly) { - await this._cache.StoreAsync(key, vector, tokenCount: null, ct).ConfigureAwait(false); + await this._cache.StoreAsync(key, vector, ct).ConfigureAwait(false); this._logger.LogDebug("Stored embedding in cache, dimensions: {Dimensions}", vector.Length); } @@ -145,7 +145,7 @@ public async Task GenerateAsync(IEnumerable texts, Cancellati if (this._cache.Mode != CacheModes.ReadOnly) { var key = this.BuildCacheKey(text); - await this._cache.StoreAsync(key, generatedVectors[i], tokenCount: null, ct).ConfigureAwait(false); + await this._cache.StoreAsync(key, generatedVectors[i], ct).ConfigureAwait(false); } } diff --git a/tests/Core.Tests/Embeddings/Cache/SqliteEmbeddingCacheTests.cs b/tests/Core.Tests/Embeddings/Cache/SqliteEmbeddingCacheTests.cs index 9f134ccb2..9eb1e89ba 100644 --- a/tests/Core.Tests/Embeddings/Cache/SqliteEmbeddingCacheTests.cs +++ b/tests/Core.Tests/Embeddings/Cache/SqliteEmbeddingCacheTests.cs @@ -53,7 +53,7 @@ public async Task StoreAsync_AndTryGetAsync_ShouldRoundTrip() var vector = new float[] { 0.1f, 0.2f, 0.3f, 0.4f }; // Act - await cache.StoreAsync(key, vector, tokenCount: null, CancellationToken.None).ConfigureAwait(false); + await cache.StoreAsync(key, vector, 
CancellationToken.None).ConfigureAwait(false); var result = await cache.TryGetAsync(key, CancellationToken.None).ConfigureAwait(false); // Assert @@ -61,24 +61,6 @@ public async Task StoreAsync_AndTryGetAsync_ShouldRoundTrip() Assert.Equal(vector, result.Vector); } - [Fact] - public async Task StoreAsync_WithTokenCount_ShouldPreserveValue() - { - // Arrange - using var cache = new SqliteEmbeddingCache(this._tempDbPath, CacheModes.ReadWrite, this._loggerMock.Object); - var key = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, "test text"); - var vector = new float[] { 0.1f, 0.2f, 0.3f }; - const int tokenCount = 42; - - // Act - await cache.StoreAsync(key, vector, tokenCount, CancellationToken.None).ConfigureAwait(false); - var result = await cache.TryGetAsync(key, CancellationToken.None).ConfigureAwait(false); - - // Assert - Assert.NotNull(result); - Assert.Equal(tokenCount, result.TokenCount); - } - [Fact] public async Task StoreAsync_WithLargeVector_ShouldRoundTrip() { @@ -92,7 +74,7 @@ public async Task StoreAsync_WithLargeVector_ShouldRoundTrip() } // Act - await cache.StoreAsync(key, vector, tokenCount: null, CancellationToken.None).ConfigureAwait(false); + await cache.StoreAsync(key, vector, CancellationToken.None).ConfigureAwait(false); var result = await cache.TryGetAsync(key, CancellationToken.None).ConfigureAwait(false); // Assert @@ -114,8 +96,8 @@ public async Task StoreAsync_WithSameKey_ShouldOverwrite() var vector2 = new float[] { 0.9f, 0.8f, 0.7f }; // Act - await cache.StoreAsync(key, vector1, tokenCount: null, CancellationToken.None).ConfigureAwait(false); - await cache.StoreAsync(key, vector2, tokenCount: null, CancellationToken.None).ConfigureAwait(false); + await cache.StoreAsync(key, vector1, CancellationToken.None).ConfigureAwait(false); + await cache.StoreAsync(key, vector2, CancellationToken.None).ConfigureAwait(false); var result = await cache.TryGetAsync(key, CancellationToken.None).ConfigureAwait(false); // Assert @@ -134,8 
+116,8 @@ public async Task TryGetAsync_WithDifferentKeys_ShouldReturnCorrectValues() var vector2 = new float[] { 0.4f, 0.5f, 0.6f }; // Act - await cache.StoreAsync(key1, vector1, tokenCount: null, CancellationToken.None).ConfigureAwait(false); - await cache.StoreAsync(key2, vector2, tokenCount: null, CancellationToken.None).ConfigureAwait(false); + await cache.StoreAsync(key1, vector1, CancellationToken.None).ConfigureAwait(false); + await cache.StoreAsync(key2, vector2, CancellationToken.None).ConfigureAwait(false); var result1 = await cache.TryGetAsync(key1, CancellationToken.None).ConfigureAwait(false); var result2 = await cache.TryGetAsync(key2, CancellationToken.None).ConfigureAwait(false); @@ -154,7 +136,7 @@ public async Task ReadOnlyMode_TryGetAsync_ShouldWork() { var key = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, "test text"); var vector = new float[] { 0.1f, 0.2f, 0.3f }; - await writeCache.StoreAsync(key, vector, tokenCount: null, CancellationToken.None).ConfigureAwait(false); + await writeCache.StoreAsync(key, vector, CancellationToken.None).ConfigureAwait(false); } // Act - Then read with read-only mode @@ -176,7 +158,7 @@ public async Task ReadOnlyMode_StoreAsync_ShouldNotWrite() var vector = new float[] { 0.1f, 0.2f, 0.3f }; // Act - Store should be ignored in read-only mode - await cache.StoreAsync(key, vector, tokenCount: null, CancellationToken.None).ConfigureAwait(false); + await cache.StoreAsync(key, vector, CancellationToken.None).ConfigureAwait(false); var result = await cache.TryGetAsync(key, CancellationToken.None).ConfigureAwait(false); // Assert @@ -192,7 +174,7 @@ public async Task WriteOnlyMode_StoreAsync_ShouldWork() var vector = new float[] { 0.1f, 0.2f, 0.3f }; // Act - await cache.StoreAsync(key, vector, tokenCount: null, CancellationToken.None).ConfigureAwait(false); + await cache.StoreAsync(key, vector, CancellationToken.None).ConfigureAwait(false); // Assert - verify by reading with read-write cache using var 
readCache = new SqliteEmbeddingCache(this._tempDbPath, CacheModes.ReadWrite, this._loggerMock.Object); @@ -209,7 +191,7 @@ public async Task WriteOnlyMode_TryGetAsync_ShouldReturnNull() { var key = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, "test text"); var vector = new float[] { 0.1f, 0.2f, 0.3f }; - await writeCache.StoreAsync(key, vector, tokenCount: null, CancellationToken.None).ConfigureAwait(false); + await writeCache.StoreAsync(key, vector, CancellationToken.None).ConfigureAwait(false); } // Act - Read with write-only mode should return null @@ -274,7 +256,7 @@ public async Task VectorBlobStorage_ShouldPreserveFloatPrecision() }; // Act - await cache.StoreAsync(key, vector, tokenCount: null, CancellationToken.None).ConfigureAwait(false); + await cache.StoreAsync(key, vector, CancellationToken.None).ConfigureAwait(false); var result = await cache.TryGetAsync(key, CancellationToken.None).ConfigureAwait(false); // Assert @@ -295,7 +277,7 @@ public async Task CacheDoesNotStoreInputText() var vector = new float[] { 0.1f, 0.2f, 0.3f }; // Act - await cache.StoreAsync(key, vector, tokenCount: null, CancellationToken.None).ConfigureAwait(false); + await cache.StoreAsync(key, vector, CancellationToken.None).ConfigureAwait(false); // Assert - Read database file and verify text is not present var dbContent = await File.ReadAllBytesAsync(this._tempDbPath).ConfigureAwait(false); @@ -313,7 +295,7 @@ public async Task CachePersistence_ShouldSurviveReopen() // Store and close using (var cache = new SqliteEmbeddingCache(this._tempDbPath, CacheModes.ReadWrite, this._loggerMock.Object)) { - await cache.StoreAsync(key, vector, tokenCount: null, CancellationToken.None).ConfigureAwait(false); + await cache.StoreAsync(key, vector, CancellationToken.None).ConfigureAwait(false); } // Act - Reopen and read @@ -337,7 +319,7 @@ public async Task StoreAsync_WithCancellationToken_ShouldRespectCancellation() // Act & Assert await Assert.ThrowsAsync( - () => 
cache.StoreAsync(key, vector, tokenCount: null, cts.Token)).ConfigureAwait(false); + () => cache.StoreAsync(key, vector, cts.Token)).ConfigureAwait(false); } [Fact] diff --git a/tests/Core.Tests/Embeddings/CachedEmbeddingGeneratorTests.cs b/tests/Core.Tests/Embeddings/CachedEmbeddingGeneratorTests.cs index 66e2ae17e..d2c257784 100644 --- a/tests/Core.Tests/Embeddings/CachedEmbeddingGeneratorTests.cs +++ b/tests/Core.Tests/Embeddings/CachedEmbeddingGeneratorTests.cs @@ -132,7 +132,7 @@ public async Task GenerateAsync_Single_WithCacheMiss_ShouldStoreInCache() // Assert this._cacheMock.Verify( - x => x.StoreAsync(It.IsAny(), generatedVector, null, It.IsAny()), + x => x.StoreAsync(It.IsAny(), generatedVector, It.IsAny()), Times.Once); } @@ -162,7 +162,7 @@ public async Task GenerateAsync_Single_WithWriteOnlyCache_ShouldSkipCacheRead() x => x.TryGetAsync(It.IsAny(), It.IsAny()), Times.Never); this._cacheMock.Verify( - x => x.StoreAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny()), + x => x.StoreAsync(It.IsAny(), It.IsAny(), It.IsAny()), Times.Once); } @@ -192,7 +192,7 @@ public async Task GenerateAsync_Single_WithReadOnlyCache_ShouldSkipCacheWrite() // Assert Assert.Equal(generatedVector, result); this._cacheMock.Verify( - x => x.StoreAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny()), + x => x.StoreAsync(It.IsAny(), It.IsAny(), It.IsAny()), Times.Never); } @@ -360,7 +360,7 @@ public async Task GenerateAsync_Batch_ShouldStoreGeneratedInCache() // Assert - Both generated vectors should be stored this._cacheMock.Verify( - x => x.StoreAsync(It.IsAny(), It.IsAny(), null, It.IsAny()), + x => x.StoreAsync(It.IsAny(), It.IsAny(), It.IsAny()), Times.Exactly(2)); } diff --git a/tests/Core.Tests/Embeddings/CachedEmbeddingTests.cs b/tests/Core.Tests/Embeddings/CachedEmbeddingTests.cs index e309f370c..9dcfb8672 100644 --- a/tests/Core.Tests/Embeddings/CachedEmbeddingTests.cs +++ b/tests/Core.Tests/Embeddings/CachedEmbeddingTests.cs @@ -4,8 +4,8 @@ namespace 
KernelMemory.Core.Tests.Embeddings; /// -/// Tests for CachedEmbedding model to verify proper construction and immutability. -/// CachedEmbedding stores the embedding vector and optional token count. +/// Tests for CachedEmbedding model to verify proper construction. +/// CachedEmbedding stores the embedding vector. /// public sealed class CachedEmbeddingTests { @@ -23,42 +23,6 @@ public void CachedEmbedding_WithRequiredProperties_ShouldBeCreated() // Assert Assert.Equal(vector, cached.Vector); - Assert.Null(cached.TokenCount); - } - - [Fact] - public void CachedEmbedding_WithTokenCount_ShouldStoreValue() - { - // Arrange - var vector = new float[] { 0.1f, 0.2f, 0.3f }; - const int tokenCount = 42; - - // Act - var cached = new CachedEmbedding - { - Vector = vector, - TokenCount = tokenCount - }; - - // Assert - Assert.Equal(tokenCount, cached.TokenCount); - } - - [Fact] - public void CachedEmbedding_WithNullTokenCount_ShouldBeNull() - { - // Arrange - var vector = new float[] { 0.1f, 0.2f, 0.3f }; - - // Act - var cached = new CachedEmbedding - { - Vector = vector, - TokenCount = null - }; - - // Assert - Assert.Null(cached.TokenCount); } [Fact] From adfdc0fa6b768cd73b169daafd905012e3f74259 Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Tue, 2 Dec 2025 18:27:57 +0100 Subject: [PATCH 6/6] fix: dispose CancellationTokenSource in tests --- .../Core.Tests/Embeddings/Cache/SqliteEmbeddingCacheTests.cs | 4 ++-- tests/Core.Tests/Embeddings/CachedEmbeddingGeneratorTests.cs | 2 +- .../Providers/AzureOpenAIEmbeddingGeneratorTests.cs | 2 +- .../Providers/HuggingFaceEmbeddingGeneratorTests.cs | 2 +- .../Embeddings/Providers/OllamaEmbeddingGeneratorTests.cs | 2 +- .../Embeddings/Providers/OpenAIEmbeddingGeneratorTests.cs | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/Core.Tests/Embeddings/Cache/SqliteEmbeddingCacheTests.cs b/tests/Core.Tests/Embeddings/Cache/SqliteEmbeddingCacheTests.cs index 9eb1e89ba..82495d072 100644 --- 
a/tests/Core.Tests/Embeddings/Cache/SqliteEmbeddingCacheTests.cs +++ b/tests/Core.Tests/Embeddings/Cache/SqliteEmbeddingCacheTests.cs @@ -314,7 +314,7 @@ public async Task StoreAsync_WithCancellationToken_ShouldRespectCancellation() using var cache = new SqliteEmbeddingCache(this._tempDbPath, CacheModes.ReadWrite, this._loggerMock.Object); var key = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, "test text"); var vector = new float[] { 0.1f, 0.2f, 0.3f }; - var cts = new CancellationTokenSource(); + using var cts = new CancellationTokenSource(); cts.Cancel(); // Act & Assert @@ -328,7 +328,7 @@ public async Task TryGetAsync_WithCancellationToken_ShouldRespectCancellation() // Arrange using var cache = new SqliteEmbeddingCache(this._tempDbPath, CacheModes.ReadWrite, this._loggerMock.Object); var key = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, "test text"); - var cts = new CancellationTokenSource(); + using var cts = new CancellationTokenSource(); cts.Cancel(); // Act & Assert diff --git a/tests/Core.Tests/Embeddings/CachedEmbeddingGeneratorTests.cs b/tests/Core.Tests/Embeddings/CachedEmbeddingGeneratorTests.cs index d2c257784..3e3d4c749 100644 --- a/tests/Core.Tests/Embeddings/CachedEmbeddingGeneratorTests.cs +++ b/tests/Core.Tests/Embeddings/CachedEmbeddingGeneratorTests.cs @@ -459,7 +459,7 @@ public async Task GenerateAsync_WithCancellation_ShouldPropagate() this._cacheMock.Object, this._loggerMock.Object); - var cts = new CancellationTokenSource(); + using var cts = new CancellationTokenSource(); cts.Cancel(); // Act & Assert diff --git a/tests/Core.Tests/Embeddings/Providers/AzureOpenAIEmbeddingGeneratorTests.cs b/tests/Core.Tests/Embeddings/Providers/AzureOpenAIEmbeddingGeneratorTests.cs index 485844874..9cb79d13f 100644 --- a/tests/Core.Tests/Embeddings/Providers/AzureOpenAIEmbeddingGeneratorTests.cs +++ b/tests/Core.Tests/Embeddings/Providers/AzureOpenAIEmbeddingGeneratorTests.cs @@ -228,7 +228,7 @@ public async Task 
GenerateAsync_WithCancellation_ShouldPropagate() var generator = new AzureOpenAIEmbeddingGenerator( httpClient, "https://myservice.openai.azure.com", "deployment", "model", "key", 1536, true, this._loggerMock.Object); - var cts = new CancellationTokenSource(); + using var cts = new CancellationTokenSource(); cts.Cancel(); // Act & Assert - TaskCanceledException inherits from OperationCanceledException diff --git a/tests/Core.Tests/Embeddings/Providers/HuggingFaceEmbeddingGeneratorTests.cs b/tests/Core.Tests/Embeddings/Providers/HuggingFaceEmbeddingGeneratorTests.cs index dd42f4112..94d5811f5 100644 --- a/tests/Core.Tests/Embeddings/Providers/HuggingFaceEmbeddingGeneratorTests.cs +++ b/tests/Core.Tests/Embeddings/Providers/HuggingFaceEmbeddingGeneratorTests.cs @@ -283,7 +283,7 @@ public async Task GenerateAsync_WithCancellation_ShouldPropagate() var generator = new HuggingFaceEmbeddingGenerator( httpClient, "hf_token", "model", 384, true, null, this._loggerMock.Object); - var cts = new CancellationTokenSource(); + using var cts = new CancellationTokenSource(); cts.Cancel(); // Act & Assert - TaskCanceledException inherits from OperationCanceledException diff --git a/tests/Core.Tests/Embeddings/Providers/OllamaEmbeddingGeneratorTests.cs b/tests/Core.Tests/Embeddings/Providers/OllamaEmbeddingGeneratorTests.cs index 217977da0..65d5ba62a 100644 --- a/tests/Core.Tests/Embeddings/Providers/OllamaEmbeddingGeneratorTests.cs +++ b/tests/Core.Tests/Embeddings/Providers/OllamaEmbeddingGeneratorTests.cs @@ -220,7 +220,7 @@ public async Task GenerateAsync_WithCancellation_ShouldPropagate() var generator = new OllamaEmbeddingGenerator( httpClient, "http://localhost:11434", "qwen3-embedding", 1024, true, this._loggerMock.Object); - var cts = new CancellationTokenSource(); + using var cts = new CancellationTokenSource(); cts.Cancel(); // Act & Assert - TaskCanceledException inherits from OperationCanceledException diff --git 
a/tests/Core.Tests/Embeddings/Providers/OpenAIEmbeddingGeneratorTests.cs b/tests/Core.Tests/Embeddings/Providers/OpenAIEmbeddingGeneratorTests.cs index 7dced11ba..1e7923cd0 100644 --- a/tests/Core.Tests/Embeddings/Providers/OpenAIEmbeddingGeneratorTests.cs +++ b/tests/Core.Tests/Embeddings/Providers/OpenAIEmbeddingGeneratorTests.cs @@ -249,7 +249,7 @@ public async Task GenerateAsync_WithCancellation_ShouldPropagate() var generator = new OpenAIEmbeddingGenerator( httpClient, "test-key", "model", 1536, true, null, this._loggerMock.Object); - var cts = new CancellationTokenSource(); + using var cts = new CancellationTokenSource(); cts.Cancel(); // Act & Assert - TaskCanceledException inherits from OperationCanceledException