From f4407c3854adc58b2c6c93c6f7c74870fdbb10f9 Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Thu, 18 Dec 2025 13:51:31 +0100 Subject: [PATCH 1/5] Add vector index feature, centralize constants, and instrument CLI/e2e logging --- docs | 2 +- e2e-tests.sh | 60 ++ format.sh | 15 +- src/Core/Config/AppConfig.cs | 9 +- src/Core/Config/ConfigParser.cs | 23 +- .../Embeddings/HuggingFaceEmbeddingsConfig.cs | 5 +- src/Core/Config/NodeConfig.cs | 20 +- src/Core/Config/SearchConfig.cs | 35 +- .../SearchIndex/VectorSearchIndexConfig.cs | 12 +- src/Core/Constants.cs | 370 +++++++ src/Core/Embeddings/Cache/CachedEmbedding.cs | 14 +- src/Core/Embeddings/Cache/IEmbeddingCache.cs | 5 +- .../Embeddings/Cache/SqliteEmbeddingCache.cs | 30 +- .../Embeddings/CachedEmbeddingGenerator.cs | 38 +- src/Core/Embeddings/EmbeddingConstants.cs | 77 -- src/Core/Embeddings/EmbeddingResult.cs | 40 + src/Core/Embeddings/IEmbeddingGenerator.cs | 8 +- .../AzureOpenAIEmbeddingGenerator.cs | 33 +- .../HuggingFaceEmbeddingGenerator.cs | 17 +- .../Providers/OllamaEmbeddingGenerator.cs | 9 +- .../Providers/OpenAIEmbeddingGenerator.cs | 33 +- src/Core/Logging/ActivityEnricher.cs | 4 +- src/Core/Logging/EnvironmentDetector.cs | 10 +- src/Core/Logging/LoggingConstants.cs | 92 -- .../Logging/SensitiveDataScrubbingPolicy.cs | 2 +- src/Core/Logging/SerilogFactory.cs | 28 +- src/Core/Search/IVectorIndex.cs | 33 + src/Core/Search/Models/RerankingConfig.cs | 2 +- src/Core/Search/Models/SearchRequest.cs | 4 +- src/Core/Search/NodeSearchService.cs | 8 +- .../Reranking/WeightedDiminishingReranker.cs | 8 +- src/Core/Search/SearchConstants.cs | 127 --- src/Core/Search/SearchService.cs | 12 +- src/Core/Search/SqliteFtsIndex.cs | 12 +- src/Core/Search/SqliteVectorIndex.cs | 338 ++++++ src/Core/Search/VectorMatch.cs | 20 + src/Core/Search/VectorMath.cs | 127 +++ src/Core/Storage/ContentStorageDbContext.cs | 2 +- src/Core/Storage/ContentStorageService.cs | 46 +- src/Directory.Packages.props | 14 +- 
src/Main/CLI/CliApplicationBuilder.cs | 27 +- src/Main/CLI/Commands/BaseCommand.cs | 36 +- src/Main/CLI/Commands/ConfigCommand.cs | 19 +- src/Main/CLI/Commands/DeleteCommand.cs | 3 +- src/Main/CLI/Commands/DoctorCommand.cs | 959 ++++++++++++++++++ .../CLI/Commands/DoctorCommandSettings.cs | 13 + src/Main/CLI/Commands/GetCommand.cs | 7 +- src/Main/CLI/Commands/ListCommand.cs | 26 +- src/Main/CLI/Commands/NodesCommand.cs | 3 +- src/Main/CLI/Commands/SearchCommand.cs | 93 +- src/Main/CLI/Commands/UpsertCommand.cs | 3 +- src/Main/CLI/ModeRouter.cs | 4 +- .../OutputFormatters/HumanOutputFormatter.cs | 9 +- src/Main/Constants.cs | 43 - .../Services/EmbeddingGeneratorFactory.cs | 154 +++ src/Main/Services/SearchIndexFactory.cs | 79 +- tests/Core.Tests/Config/AppConfigTests.cs | 28 +- .../Config/ConfigParserAutoCreateTests.cs | 2 +- tests/Core.Tests/Config/ConfigParserTests.cs | 2 +- tests/Core.Tests/Config/SearchConfigTests.cs | 13 +- .../Cache/SqliteEmbeddingCacheTests.cs | 28 +- .../CachedEmbeddingGeneratorTests.cs | 156 ++- .../Embeddings/CachedEmbeddingTests.cs | 9 +- .../Embeddings/EmbeddingConstantsTests.cs | 22 +- .../AzureOpenAIEmbeddingGeneratorTests.cs | 2 +- .../HuggingFaceEmbeddingGeneratorTests.cs | 10 +- .../OllamaEmbeddingGeneratorTests.cs | 38 +- .../OpenAIEmbeddingGeneratorTests.cs | 80 +- tests/Core.Tests/GlobalUsings.cs | 4 +- .../Logging/ActivityEnricherTests.cs | 1 - .../Logging/EnvironmentDetectorTests.cs | 70 +- .../Logging/LoggerExtensionsTests.cs | 1 - .../Core.Tests/Logging/LoggingConfigTests.cs | 1 - .../Logging/LoggingConstantsTests.cs | 20 +- .../SensitiveDataScrubbingPolicyTests.cs | 33 +- .../Core.Tests/Logging/SerilogFactoryTests.cs | 1 - tests/Core.Tests/Logging/TestLoggerFactory.cs | 1 - .../Logging/TestLoggerFactoryTests.cs | 1 - .../Search/FtsIndexPersistenceTest.cs | 1 - .../Core.Tests/Search/FtsIntegrationTests.cs | 3 +- .../Search/FtsQueryExtractionTest.cs | 1 - .../Search/Models/SearchRequestTests.cs | 5 +- 
.../Search/NodeSearchServiceIndexIdTests.cs | 9 +- .../Core.Tests/Search/SearchConstantsTests.cs | 40 +- .../Core.Tests/Search/SearchEndToEndTests.cs | 3 +- .../Search/SearchServiceFunctionalTests.cs | 5 +- .../Search/SearchServiceIndexWeightsTests.cs | 7 +- tests/Core.Tests/Search/SimpleSearchTest.cs | 1 - .../Core.Tests/Search/SqliteFtsIndexTests.cs | 1 - .../SqliteVectorIndexErrorHandlingTests.cs | 237 +++++ .../SqliteVectorIndexPersistenceTests.cs | 279 +++++ .../Search/SqliteVectorIndexTests.cs | 358 +++++++ tests/Core.Tests/Search/VectorMathTests.cs | 271 +++++ tests/Main.Tests/GlobalUsings.cs | 1 + .../Integration/CliIntegrationTests.cs | 58 +- .../Integration/CommandExecutionTests.cs | 6 +- .../Integration/ConfigCommandTests.cs | 32 +- .../DefaultConfigVectorIndexTests.cs | 266 +++++ .../Integration/NodeSelectionTests.cs | 389 +++++++ .../Integration/ReadonlyCommandTests.cs | 10 +- .../Integration/UserDataProtectionTests.cs | 2 +- .../EmbeddingGeneratorFactoryTests.cs | 289 ++++++ .../Services/SearchIndexFactoryVectorTests.cs | 264 +++++ tests/Main.Tests/Unit/CLI/ModeRouterTests.cs | 2 +- .../Unit/Commands/BaseCommandTests.cs | 8 +- .../Unit/Commands/DoctorCommandTests.cs | 458 +++++++++ .../HumanOutputFormatterTests.cs | 4 +- .../Unit/Settings/ListCommandSettingsTests.cs | 2 +- tests/e2e/framework/__init__.py | 1 + tests/e2e/framework/cli.py | 101 ++ tests/e2e/framework/db.py | 49 + tests/e2e/framework/logging.py | 38 + tests/e2e/requirements.txt | 1 + tests/e2e/test_01_put_get_delete.py | 148 +++ tests/e2e/test_02_search_with_broken_node.py | 121 +++ tests/e2e/test_03_fts_stemming.py | 141 +++ tests/e2e/test_04_vector_search.py | 239 +++++ tests/e2e/test_05_embeddings_cache.py | 210 ++++ 118 files changed, 6952 insertions(+), 854 deletions(-) create mode 100755 e2e-tests.sh create mode 100644 src/Core/Constants.cs delete mode 100644 src/Core/Embeddings/EmbeddingConstants.cs create mode 100644 src/Core/Embeddings/EmbeddingResult.cs delete mode 100644 
src/Core/Logging/LoggingConstants.cs create mode 100644 src/Core/Search/IVectorIndex.cs delete mode 100644 src/Core/Search/SearchConstants.cs create mode 100644 src/Core/Search/SqliteVectorIndex.cs create mode 100644 src/Core/Search/VectorMatch.cs create mode 100644 src/Core/Search/VectorMath.cs create mode 100644 src/Main/CLI/Commands/DoctorCommand.cs create mode 100644 src/Main/CLI/Commands/DoctorCommandSettings.cs delete mode 100644 src/Main/Constants.cs create mode 100644 src/Main/Services/EmbeddingGeneratorFactory.cs create mode 100644 tests/Core.Tests/Search/SqliteVectorIndexErrorHandlingTests.cs create mode 100644 tests/Core.Tests/Search/SqliteVectorIndexPersistenceTests.cs create mode 100644 tests/Core.Tests/Search/SqliteVectorIndexTests.cs create mode 100644 tests/Core.Tests/Search/VectorMathTests.cs create mode 100644 tests/Main.Tests/Integration/DefaultConfigVectorIndexTests.cs create mode 100644 tests/Main.Tests/Integration/NodeSelectionTests.cs create mode 100644 tests/Main.Tests/Services/EmbeddingGeneratorFactoryTests.cs create mode 100644 tests/Main.Tests/Services/SearchIndexFactoryVectorTests.cs create mode 100644 tests/Main.Tests/Unit/Commands/DoctorCommandTests.cs create mode 100644 tests/e2e/framework/__init__.py create mode 100644 tests/e2e/framework/cli.py create mode 100644 tests/e2e/framework/db.py create mode 100644 tests/e2e/framework/logging.py create mode 100644 tests/e2e/requirements.txt create mode 100755 tests/e2e/test_01_put_get_delete.py create mode 100755 tests/e2e/test_02_search_with_broken_node.py create mode 100755 tests/e2e/test_03_fts_stemming.py create mode 100755 tests/e2e/test_04_vector_search.py create mode 100755 tests/e2e/test_05_embeddings_cache.py diff --git a/docs b/docs index 23845bf23..a0321cf66 160000 --- a/docs +++ b/docs @@ -1 +1 @@ -Subproject commit 23845bf23aa39bb2a443fbe47e100a7f8de6c5db +Subproject commit a0321cf667c81bf16b0e830e54afe0890b279520 diff --git a/e2e-tests.sh b/e2e-tests.sh new file mode 100755 
index 000000000..7eb2a06cf --- /dev/null +++ b/e2e-tests.sh @@ -0,0 +1,60 @@ +#!/usr/bin/env bash + +set -e + +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]:-$0}")" && pwd)" +cd "$ROOT" + +echo "=======================================" +echo " Running E2E Tests" +echo "=======================================" +echo "" + +# Choose build configuration (default Release to align with build.sh) +CONFIGURATION="${CONFIGURATION:-Release}" +KM_BIN="$ROOT/src/Main/bin/$CONFIGURATION/net10.0/KernelMemory.Main.dll" + +# Ensure km binary is built at the selected configuration +if [ ! -f "$KM_BIN" ]; then + echo "km binary not found at $KM_BIN. Building ($CONFIGURATION)..." + dotnet build src/Main/Main.csproj -c "$CONFIGURATION" +fi + +if [ ! -f "$KM_BIN" ]; then + echo "❌ km binary still not found at $KM_BIN after build. Set KM_BIN to a valid path." + exit 1 +fi + +export KM_BIN + +FAILED=0 +PASSED=0 + +# Run each test file +for test_file in tests/e2e/test_*.py; do + if [ -f "$test_file" ]; then + echo "" + echo "Running: $(basename "$test_file")" + echo "---------------------------------------" + + if python3 "$test_file"; then + PASSED=$((PASSED + 1)) + else + FAILED=$((FAILED + 1)) + fi + fi +done + +echo "" +echo "=======================================" +echo " E2E Test Results" +echo "=======================================" +echo "Passed: $PASSED" +echo "Failed: $FAILED" +echo "=======================================" + +if [ $FAILED -gt 0 ]; then + exit 1 +fi + +exit 0 diff --git a/format.sh b/format.sh index e07cac8a3..011073326 100755 --- a/format.sh +++ b/format.sh @@ -1 +1,14 @@ -dotnet format +#!/usr/bin/env bash + +set -e + +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]:-$0}")" && pwd)" +cd "$ROOT" +TMPDIR="$ROOT/.tmp" +mkdir -p "$TMPDIR" +export TMPDIR + +dotnet format src/Core/Core.csproj +dotnet format src/Main/Main.csproj +dotnet format tests/Core.Tests/Core.Tests.csproj +dotnet format tests/Main.Tests/Main.Tests.csproj diff --git a/src/Core/Config/AppConfig.cs 
b/src/Core/Config/AppConfig.cs index 98f95d0d6..24c9a624a 100644 --- a/src/Core/Config/AppConfig.cs +++ b/src/Core/Config/AppConfig.cs @@ -83,7 +83,8 @@ public static AppConfig CreateDefault() /// /// Creates a default configuration with a single "personal" node - /// using local SQLite storage in the specified base directory + /// using local SQLite storage in the specified base directory. + /// Includes embeddings cache for efficient vector search operations. /// /// Base directory for data storage public static AppConfig CreateDefault(string baseDir) @@ -95,8 +96,10 @@ public static AppConfig CreateDefault(string baseDir) Nodes = new Dictionary { ["personal"] = NodeConfig.CreateDefaultPersonalNode(personalNodeDir) - } - // EmbeddingsCache and LLMCache intentionally omitted - add when features are implemented + }, + EmbeddingsCache = CacheConfig.CreateDefaultSqliteCache( + Path.Combine(baseDir, "embeddings-cache.db")) + // LLMCache intentionally omitted - add when LLM features are implemented }; } } diff --git a/src/Core/Config/ConfigParser.cs b/src/Core/Config/ConfigParser.cs index f20bd2c40..f0111d0d0 100644 --- a/src/Core/Config/ConfigParser.cs +++ b/src/Core/Config/ConfigParser.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. using System.Text.Json; +using System.Text.Json.Serialization; using System.Text.Json.Serialization.Metadata; using KernelMemory.Core.Config.Cache; using KernelMemory.Core.Config.ContentIndex; @@ -28,7 +29,8 @@ public static class ConfigParser ReadCommentHandling = JsonCommentHandling.Skip, AllowTrailingCommas = true, PropertyNamingPolicy = JsonNamingPolicy.CamelCase, - TypeInfoResolver = new DefaultJsonTypeInfoResolver() + TypeInfoResolver = new DefaultJsonTypeInfoResolver(), + Converters = { new JsonStringEnumConverter() } }; /// @@ -46,13 +48,14 @@ public static class ConfigParser /// /// Loads configuration from a file, or creates default config if file doesn't exist. 
- /// The config file is always ensured to exist on disk after loading. + /// Optionally ensures the config file exists on disk after loading (for write operations). /// Performs tilde expansion on paths (~/ → home directory) /// /// Path to configuration file + /// If true, writes config to disk if missing (default: true for backward compatibility) /// Validated AppConfig instance /// Thrown when file exists but parsing or validation fails - public static AppConfig LoadFromFile(string filePath) + public static AppConfig LoadFromFile(string filePath, bool ensureFileExists = true) { AppConfig config; @@ -65,8 +68,11 @@ public static AppConfig LoadFromFile(string filePath) // Create default config relative to config file location config = AppConfig.CreateDefault(baseDir); - // Write the config file - WriteConfigFile(filePath, config); + // Write the config file only if requested + if (ensureFileExists) + { + WriteConfigFile(filePath, config); + } return config; } @@ -82,8 +88,11 @@ public static AppConfig LoadFromFile(string filePath) // Expand tilde paths ExpandTildePaths(config); - // Always ensure the config file exists (recreate if deleted between load and save) - WriteConfigFileIfMissing(filePath, config); + // Optionally ensure the config file exists (recreate if deleted between load and save) + if (ensureFileExists) + { + WriteConfigFileIfMissing(filePath, config); + } return config; } diff --git a/src/Core/Config/Embeddings/HuggingFaceEmbeddingsConfig.cs b/src/Core/Config/Embeddings/HuggingFaceEmbeddingsConfig.cs index 5ae7c810f..6e750c81f 100644 --- a/src/Core/Config/Embeddings/HuggingFaceEmbeddingsConfig.cs +++ b/src/Core/Config/Embeddings/HuggingFaceEmbeddingsConfig.cs @@ -2,7 +2,6 @@ using System.Text.Json.Serialization; using KernelMemory.Core.Config.Enums; using KernelMemory.Core.Config.Validation; -using KernelMemory.Core.Embeddings; namespace KernelMemory.Core.Config.Embeddings; @@ -20,7 +19,7 @@ public sealed class HuggingFaceEmbeddingsConfig : 
EmbeddingsConfig /// HuggingFace model name (e.g., "sentence-transformers/all-MiniLM-L6-v2", "BAAI/bge-base-en-v1.5"). /// [JsonPropertyName("model")] - public string Model { get; set; } = EmbeddingConstants.DefaultHuggingFaceModel; + public string Model { get; set; } = Constants.EmbeddingDefaults.DefaultHuggingFaceModel; /// /// HuggingFace API key (token). @@ -35,7 +34,7 @@ public sealed class HuggingFaceEmbeddingsConfig : EmbeddingsConfig /// Can be changed for custom inference endpoints. /// [JsonPropertyName("baseUrl")] - public string BaseUrl { get; set; } = EmbeddingConstants.DefaultHuggingFaceBaseUrl; + public string BaseUrl { get; set; } = Constants.EmbeddingDefaults.DefaultHuggingFaceBaseUrl; /// public override void Validate(string path) diff --git a/src/Core/Config/NodeConfig.cs b/src/Core/Config/NodeConfig.cs index a0a30c989..e9c800978 100644 --- a/src/Core/Config/NodeConfig.cs +++ b/src/Core/Config/NodeConfig.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System.Text.Json.Serialization; using KernelMemory.Core.Config.ContentIndex; +using KernelMemory.Core.Config.Embeddings; using KernelMemory.Core.Config.Enums; using KernelMemory.Core.Config.SearchIndex; using KernelMemory.Core.Config.Storage; @@ -106,7 +107,8 @@ public void Validate(string path) } /// - /// Creates a default "personal" node configuration + /// Creates a default "personal" node configuration with FTS and vector search. + /// Uses Ollama with qwen3-embedding model (1024 dimensions) for local, offline-capable vector search. 
/// /// internal static NodeConfig CreateDefaultPersonalNode(string nodeDir) @@ -128,7 +130,21 @@ internal static NodeConfig CreateDefaultPersonalNode(string nodeDir) Id = "sqlite-fts", Type = SearchIndexTypes.SqliteFTS, Path = Path.Combine(nodeDir, "fts.db"), - EnableStemming = true + EnableStemming = true, + Required = true + }, + new VectorSearchIndexConfig + { + Id = "sqlite-vector", + Type = SearchIndexTypes.SqliteVector, + Path = Path.Combine(nodeDir, "vector.db"), + Dimensions = 1024, + UseSqliteVec = false, + Embeddings = new OllamaEmbeddingsConfig + { + Model = Constants.EmbeddingDefaults.DefaultOllamaModel, + BaseUrl = Constants.EmbeddingDefaults.DefaultOllamaBaseUrl + } } } }; diff --git a/src/Core/Config/SearchConfig.cs b/src/Core/Config/SearchConfig.cs index 74dfce6e1..7fcec6c5f 100644 --- a/src/Core/Config/SearchConfig.cs +++ b/src/Core/Config/SearchConfig.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. using System.Text.Json.Serialization; using KernelMemory.Core.Config.Validation; -using KernelMemory.Core.Search; namespace KernelMemory.Core.Config; @@ -17,14 +16,14 @@ public sealed class SearchConfig : IValidatable /// Default: 0.3 (moderate threshold). /// [JsonPropertyName("defaultMinRelevance")] - public float DefaultMinRelevance { get; set; } = SearchConstants.DefaultMinRelevance; + public float DefaultMinRelevance { get; set; } = Constants.SearchDefaults.DefaultMinRelevance; /// /// Default maximum number of results to return per search. /// Default: 20 results. /// [JsonPropertyName("defaultLimit")] - public int DefaultLimit { get; set; } = SearchConstants.DefaultLimit; + public int DefaultLimit { get; set; } = Constants.SearchDefaults.DefaultLimit; /// /// Search timeout in seconds per node. @@ -32,7 +31,7 @@ public sealed class SearchConfig : IValidatable /// Default: 30 seconds. 
/// [JsonPropertyName("searchTimeoutSeconds")] - public int SearchTimeoutSeconds { get; set; } = SearchConstants.DefaultSearchTimeoutSeconds; + public int SearchTimeoutSeconds { get; set; } = Constants.SearchDefaults.DefaultSearchTimeoutSeconds; /// /// Default maximum results to retrieve from each node (memory safety). @@ -41,7 +40,7 @@ public sealed class SearchConfig : IValidatable /// Default: 1000 results per node. /// [JsonPropertyName("maxResultsPerNode")] - public int MaxResultsPerNode { get; set; } = SearchConstants.DefaultMaxResultsPerNode; + public int MaxResultsPerNode { get; set; } = Constants.SearchDefaults.DefaultMaxResultsPerNode; /// /// Default nodes to search when no explicit --nodes flag is provided. @@ -50,7 +49,7 @@ public sealed class SearchConfig : IValidatable /// [JsonPropertyName("defaultNodes")] [System.Diagnostics.CodeAnalysis.SuppressMessage("Performance", "CA1819:Properties should not return arrays")] - public string[] DefaultNodes { get; set; } = [SearchConstants.AllNodesWildcard]; + public string[] DefaultNodes { get; set; } = [Constants.SearchDefaults.AllNodesWildcard]; /// /// Nodes to exclude from search by default. @@ -67,7 +66,7 @@ public sealed class SearchConfig : IValidatable /// Default: 10 levels. /// [JsonPropertyName("maxQueryDepth")] - public int MaxQueryDepth { get; set; } = SearchConstants.MaxQueryDepth; + public int MaxQueryDepth { get; set; } = Constants.SearchDefaults.MaxQueryDepth; /// /// Maximum number of boolean operators (AND/OR/NOT) in a single query. @@ -75,7 +74,7 @@ public sealed class SearchConfig : IValidatable /// Default: 50 operators. /// [JsonPropertyName("maxBooleanOperators")] - public int MaxBooleanOperators { get; set; } = SearchConstants.MaxBooleanOperators; + public int MaxBooleanOperators { get; set; } = Constants.SearchDefaults.MaxBooleanOperators; /// /// Maximum length of a field value in query (characters). 
@@ -83,7 +82,7 @@ public sealed class SearchConfig : IValidatable /// Default: 1000 characters. /// [JsonPropertyName("maxFieldValueLength")] - public int MaxFieldValueLength { get; set; } = SearchConstants.MaxFieldValueLength; + public int MaxFieldValueLength { get; set; } = Constants.SearchDefaults.MaxFieldValueLength; /// /// Maximum time allowed for query parsing (milliseconds). @@ -91,42 +90,42 @@ public sealed class SearchConfig : IValidatable /// Default: 1000ms (1 second). /// [JsonPropertyName("queryParseTimeoutMs")] - public int QueryParseTimeoutMs { get; set; } = SearchConstants.QueryParseTimeoutMs; + public int QueryParseTimeoutMs { get; set; } = Constants.SearchDefaults.QueryParseTimeoutMs; /// /// Default snippet length in characters when --snippet flag is used. /// Default: 200 characters. /// [JsonPropertyName("snippetLength")] - public int SnippetLength { get; set; } = SearchConstants.DefaultSnippetLength; + public int SnippetLength { get; set; } = Constants.SearchDefaults.DefaultSnippetLength; /// /// Default maximum number of snippets per result when --snippet flag is used. /// Default: 1 snippet. /// [JsonPropertyName("maxSnippetsPerResult")] - public int MaxSnippetsPerResult { get; set; } = SearchConstants.DefaultMaxSnippetsPerResult; + public int MaxSnippetsPerResult { get; set; } = Constants.SearchDefaults.DefaultMaxSnippetsPerResult; /// /// Separator string between multiple snippets. /// Default: "..." (ellipsis). /// [JsonPropertyName("snippetSeparator")] - public string SnippetSeparator { get; set; } = SearchConstants.DefaultSnippetSeparator; + public string SnippetSeparator { get; set; } = Constants.SearchDefaults.DefaultSnippetSeparator; /// /// Prefix marker for highlighting matched terms. /// Default: "<mark>" (HTML-style). 
/// [JsonPropertyName("highlightPrefix")] - public string HighlightPrefix { get; set; } = SearchConstants.DefaultHighlightPrefix; + public string HighlightPrefix { get; set; } = Constants.SearchDefaults.DefaultHighlightPrefix; /// /// Suffix marker for highlighting matched terms. /// Default: "</mark>" (HTML-style). /// [JsonPropertyName("highlightSuffix")] - public string HighlightSuffix { get; set; } = SearchConstants.DefaultHighlightSuffix; + public string HighlightSuffix { get; set; } = Constants.SearchDefaults.DefaultHighlightSuffix; /// /// Validates the search configuration. @@ -135,10 +134,10 @@ public sealed class SearchConfig : IValidatable public void Validate(string path) { // Validate min relevance score - if (this.DefaultMinRelevance < SearchConstants.MinRelevanceScore || this.DefaultMinRelevance > SearchConstants.MaxRelevanceScore) + if (this.DefaultMinRelevance < Constants.SearchDefaults.MinRelevanceScore || this.DefaultMinRelevance > Constants.SearchDefaults.MaxRelevanceScore) { throw new ConfigException($"{path}.DefaultMinRelevance", - $"Must be between {SearchConstants.MinRelevanceScore} and {SearchConstants.MaxRelevanceScore}"); + $"Must be between {Constants.SearchDefaults.MinRelevanceScore} and {Constants.SearchDefaults.MaxRelevanceScore}"); } // Validate default limit @@ -167,7 +166,7 @@ public void Validate(string path) } // Validate no contradictory node configuration - if (this.DefaultNodes.Length == 1 && this.DefaultNodes[0] == SearchConstants.AllNodesWildcard) + if (this.DefaultNodes.Length == 1 && this.DefaultNodes[0] == Constants.SearchDefaults.AllNodesWildcard) { // Using wildcard - excludeNodes is OK } diff --git a/src/Core/Config/SearchIndex/VectorSearchIndexConfig.cs b/src/Core/Config/SearchIndex/VectorSearchIndexConfig.cs index f320c12f9..72e15f06e 100644 --- a/src/Core/Config/SearchIndex/VectorSearchIndexConfig.cs +++ b/src/Core/Config/SearchIndex/VectorSearchIndexConfig.cs @@ -32,11 +32,21 @@ public sealed class 
VectorSearchIndexConfig : SearchIndexConfig public int Dimensions { get; set; } = 768; /// - /// Distance/similarity metric for vector comparison + /// Distance/similarity metric for vector comparison. + /// Note: Implementation normalizes vectors at write time and uses dot product, + /// which is equivalent to cosine similarity for normalized vectors. /// [JsonPropertyName("metric")] public VectorMetrics Metric { get; set; } = VectorMetrics.Cosine; + /// + /// Whether to attempt loading the sqlite-vec extension for accelerated vector operations. + /// Default: false (uses pure BLOB storage with C# distance calculations). + /// If true and extension is not available, gracefully falls back to BLOB storage with a warning. + /// + [JsonPropertyName("useSqliteVec")] + public bool UseSqliteVec { get; set; } = false; + /// public override void Validate(string path) { diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs new file mode 100644 index 000000000..e611b6624 --- /dev/null +++ b/src/Core/Constants.cs @@ -0,0 +1,370 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Serilog.Events; + +namespace KernelMemory.Core; + +/// +/// Centralized constants for the Core module. +/// Organized in nested classes by domain for maintainability and discoverability. +/// All magic values should be defined here rather than hardcoded throughout the codebase. +/// +public static class Constants +{ + /// + /// Constants for search functionality including FTS and vector search. + /// + public static class SearchDefaults + { + /// + /// Default minimum relevance score threshold (0.0-1.0). + /// Results below this score are filtered out. + /// + public const float DefaultMinRelevance = 0.3f; + + /// + /// Default maximum number of results to return per search. + /// + public const int DefaultLimit = 20; + + /// + /// Default search timeout in seconds per node. 
+ /// + public const int DefaultSearchTimeoutSeconds = 30; + + /// + /// Default maximum results to retrieve from each node (memory safety). + /// Prevents memory exhaustion from large result sets. + /// + public const int DefaultMaxResultsPerNode = 1000; + + /// + /// Default node weight for relevance scoring. + /// + public const float DefaultNodeWeight = 1.0f; + + /// + /// Default search index weight for relevance scoring. + /// + public const float DefaultIndexWeight = 1.0f; + + /// + /// BM25 score normalization divisor for exponential mapping. + /// Maps BM25 range [-10, 0] to [0.37, 1.0] using exp(score/divisor). + /// + public const double Bm25NormalizationDivisor = 10.0; + + /// + /// Maximum nesting depth for query parentheses. + /// Prevents DoS attacks via deeply nested queries. + /// + public const int MaxQueryDepth = 10; + + /// + /// Maximum number of boolean operators (AND/OR/NOT) in a single query. + /// Prevents query complexity attacks. + /// + public const int MaxBooleanOperators = 50; + + /// + /// Maximum length of a field value in query (characters). + /// Prevents oversized query values. + /// + public const int MaxFieldValueLength = 1000; + + /// + /// Maximum time allowed for query parsing (milliseconds). + /// Prevents regex catastrophic backtracking. + /// + public const int QueryParseTimeoutMs = 1000; + + /// + /// Default snippet length in characters. + /// + public const int DefaultSnippetLength = 200; + + /// + /// Default maximum number of snippets per result. + /// + public const int DefaultMaxSnippetsPerResult = 1; + + /// + /// Default snippet separator between multiple snippets. + /// + public const string DefaultSnippetSeparator = "..."; + + /// + /// Default highlight prefix marker. + /// + public const string DefaultHighlightPrefix = ""; + + /// + /// Default highlight suffix marker. 
+ /// + public const string DefaultHighlightSuffix = ""; + + /// + /// Diminishing returns multipliers for aggregating multiple appearances of same record. + /// First appearance: 1.0 (full weight) + /// Second appearance: 0.5 (50% boost) + /// Third appearance: 0.25 (25% boost) + /// Fourth appearance: 0.125 (12.5% boost) + /// Each subsequent multiplier is half of the previous. + /// + public static readonly float[] DefaultDiminishingMultipliers = [1.0f, 0.5f, 0.25f, 0.125f]; + + /// + /// Wildcard character for "all nodes" in node selection. + /// + public const string AllNodesWildcard = "*"; + + /// + /// Maximum relevance score (scores are capped at this value). + /// + public const float MaxRelevanceScore = 1.0f; + + /// + /// Minimum relevance score. + /// + public const float MinRelevanceScore = 0.0f; + + /// + /// Default FTS index ID used when not specified in configuration. + /// This is the identifier assigned to search results from the full-text search index. + /// + public const string DefaultFtsIndexId = "fts-main"; + } + + /// + /// Constants for embedding generation including known model dimensions, + /// default configurations, and batch sizes. + /// + public static class EmbeddingDefaults + { + /// + /// Default batch size for embedding generation requests. + /// Configurable per provider, but this is the default. + /// + public const int DefaultBatchSize = 10; + + /// + /// Default Ollama model for embeddings. + /// + public const string DefaultOllamaModel = "qwen3-embedding:0.6b"; + + /// + /// Default Ollama base URL. + /// + public const string DefaultOllamaBaseUrl = "http://localhost:11434"; + + /// + /// Default HuggingFace model for embeddings. + /// + public const string DefaultHuggingFaceModel = "sentence-transformers/all-MiniLM-L6-v2"; + + /// + /// Default HuggingFace Inference API base URL. + /// + public const string DefaultHuggingFaceBaseUrl = "https://api-inference.huggingface.co"; + + /// + /// Default OpenAI API base URL. 
+ /// + public const string DefaultOpenAIBaseUrl = "https://api.openai.com"; + + /// + /// Azure OpenAI API version. + /// + public const string AzureOpenAIApiVersion = "2024-02-01"; + + /// + /// Known model dimensions for common embedding models. + /// These values are fixed per model and used for validation and cache key generation. + /// + public static readonly IReadOnlyDictionary KnownModelDimensions = new Dictionary + { + // Ollama models + ["qwen3-embedding"] = 1024, + ["nomic-embed-text"] = 768, + ["embeddinggemma"] = 768, + + // OpenAI models + ["text-embedding-ada-002"] = 1536, + ["text-embedding-3-small"] = 1536, + ["text-embedding-3-large"] = 3072, + + // HuggingFace models + ["sentence-transformers/all-MiniLM-L6-v2"] = 384, + ["BAAI/bge-base-en-v1.5"] = 768 + }; + + /// + /// Try to get the dimensions for a known model. + /// + /// The model name to look up. + /// The dimensions if found, 0 otherwise. + /// True if the model is known, false otherwise. + public static bool TryGetDimensions(string modelName, out int dimensions) + { + return KnownModelDimensions.TryGetValue(modelName, out dimensions); + } + } + + /// + /// Constants for the logging system including file rotation, log levels, + /// and output formatting. + /// + public static class LoggingDefaults + { + /// + /// Default maximum file size before rotation (100MB). + /// Balances history retention with disk usage. + /// + public const long DefaultFileSizeLimitBytes = 100 * 1024 * 1024; + + /// + /// Default number of log files to retain (30 files). + /// Approximately 1 month of daily logs or ~3GB max storage. + /// + public const int DefaultRetainedFileCountLimit = 30; + + /// + /// Default minimum log level for file output. + /// Information level provides useful diagnostics without excessive verbosity. + /// + public const LogEventLevel DefaultFileLogLevel = LogEventLevel.Information; + + /// + /// Default minimum log level for console/stderr output. 
+ /// Only warnings and errors appear on stderr by default. + /// + public const LogEventLevel DefaultConsoleLogLevel = LogEventLevel.Warning; + + /// + /// Environment variable for .NET runtime environment detection. + /// Takes precedence over ASPNETCORE_ENVIRONMENT. + /// + public const string DotNetEnvironmentVariable = "DOTNET_ENVIRONMENT"; + + /// + /// Fallback environment variable for ASP.NET Core applications. + /// Used when DOTNET_ENVIRONMENT is not set. + /// + public const string AspNetCoreEnvironmentVariable = "ASPNETCORE_ENVIRONMENT"; + + /// + /// Default environment when no environment variable is set. + /// Defaults to Development for developer safety (full logging enabled). + /// + public const string DefaultEnvironment = "Development"; + + /// + /// Production environment name for comparison. + /// Sensitive data is scrubbed only in Production. + /// + public const string ProductionEnvironment = "Production"; + + /// + /// Placeholder text for redacted sensitive data. + /// Used to indicate data was intentionally removed from logs. + /// + public const string RedactedPlaceholder = "[REDACTED]"; + + /// + /// Human-readable output template for log messages. + /// Includes timestamp, level, source context, message, and optional exception. + /// + public const string HumanReadableOutputTemplate = + "{Timestamp:yyyy-MM-dd HH:mm:ss.fff} [{Level:u3}] {SourceContext}: {Message:lj}{NewLine}{Exception}"; + + /// + /// Compact output template for console (stderr) output. + /// Shorter format suitable for CLI error reporting. + /// + public const string ConsoleOutputTemplate = + "{Timestamp:HH:mm:ss} [{Level:u3}] {Message:lj}{NewLine}{Exception}"; + + /// + /// Empty trace ID value (32 zeros) used when no Activity is present. + /// Indicates no distributed tracing context is available. + /// + public const string EmptyTraceId = "00000000000000000000000000000000"; + + /// + /// Empty span ID value (16 zeros) used when no Activity is present. 
+ /// Indicates no distributed tracing context is available. + /// + public const string EmptySpanId = "0000000000000000"; + } + + /// + /// Constants for application configuration and setup. + /// + public static class ConfigDefaults + { + /// + /// Default configuration file name. + /// + public const string DefaultConfigFileName = "config.json"; + + /// + /// Default configuration directory name in user's home directory. + /// + public const string DefaultConfigDirName = ".km"; + } + + /// + /// Constants for application exit codes and CLI behavior. + /// + public static class App + { + /// + /// Exit code for successful operation. + /// + public const int ExitCodeSuccess = 0; + + /// + /// Exit code for user errors (bad input, not found, validation failure). + /// + public const int ExitCodeUserError = 1; + + /// + /// Exit code for system errors (storage failure, config error, unexpected exception). + /// + public const int ExitCodeSystemError = 2; + + /// + /// Default pagination size for list operations. + /// + public const int DefaultPageSize = 20; + + /// + /// Maximum content length to display in truncated view (characters). + /// + public const int MaxContentDisplayLength = 100; + } + + /// + /// Constants for database and storage operations. + /// + public static class Database + { + /// + /// SQLite busy timeout in milliseconds for handling concurrent access. + /// Waits up to this duration before throwing a busy exception. + /// + public const int SqliteBusyTimeoutMs = 5000; + + /// + /// Maximum length for MIME type field in content storage. + /// Prevents excessively long MIME type values. + /// + public const int MaxMimeTypeLength = 255; + + /// + /// Default snippet preview length in characters for SQL queries. + /// Used when displaying content excerpts in search results. 
+ /// + public const int DefaultSqlSnippetLength = 200; + } +} diff --git a/src/Core/Embeddings/Cache/CachedEmbedding.cs b/src/Core/Embeddings/Cache/CachedEmbedding.cs index 43f847c13..34570ce26 100644 --- a/src/Core/Embeddings/Cache/CachedEmbedding.cs +++ b/src/Core/Embeddings/Cache/CachedEmbedding.cs @@ -4,7 +4,7 @@ namespace KernelMemory.Core.Embeddings.Cache; /// -/// Represents a cached embedding vector. +/// Represents a cached embedding vector with metadata. /// public sealed class CachedEmbedding { @@ -14,4 +14,16 @@ public sealed class CachedEmbedding [SuppressMessage("Performance", "CA1819:Properties should not return arrays", Justification = "Embedding vectors are read-only after creation and passed to storage layer")] public required float[] Vector { get; init; } + + /// + /// Optional token count returned by the provider. + /// Null if provider doesn't report token usage. + /// + public int? TokenCount { get; init; } + + /// + /// When this cache entry was created. + /// Used for debugging and potential future cache eviction. + /// + public required DateTimeOffset Timestamp { get; init; } } diff --git a/src/Core/Embeddings/Cache/IEmbeddingCache.cs b/src/Core/Embeddings/Cache/IEmbeddingCache.cs index 59fd15743..cc4ab10d1 100644 --- a/src/Core/Embeddings/Cache/IEmbeddingCache.cs +++ b/src/Core/Embeddings/Cache/IEmbeddingCache.cs @@ -25,11 +25,12 @@ public interface IEmbeddingCache Task TryGetAsync(EmbeddingCacheKey key, CancellationToken ct = default); /// - /// Store an embedding in the cache. + /// Store an embedding in the cache with optional token count. /// Does nothing if mode is ReadOnly. /// /// The cache key. /// The embedding vector to store. + /// Optional token count if provider reports it. /// Cancellation token. - Task StoreAsync(EmbeddingCacheKey key, float[] vector, CancellationToken ct = default); + Task StoreAsync(EmbeddingCacheKey key, float[] vector, int? 
tokenCount, CancellationToken ct = default); } diff --git a/src/Core/Embeddings/Cache/SqliteEmbeddingCache.cs b/src/Core/Embeddings/Cache/SqliteEmbeddingCache.cs index 39690c6d7..22c437c46 100644 --- a/src/Core/Embeddings/Cache/SqliteEmbeddingCache.cs +++ b/src/Core/Embeddings/Cache/SqliteEmbeddingCache.cs @@ -21,21 +21,24 @@ CREATE TABLE IF NOT EXISTS embeddings_cache ( text_length INTEGER NOT NULL, text_hash TEXT NOT NULL, vector BLOB NOT NULL, + token_count INTEGER NULL, + timestamp TEXT NOT NULL, PRIMARY KEY (provider, model, dimensions, is_normalized, text_hash) ); + CREATE INDEX IF NOT EXISTS idx_timestamp ON embeddings_cache(timestamp); """; private const string SelectSql = """ - SELECT vector FROM embeddings_cache + SELECT vector, token_count, timestamp FROM embeddings_cache WHERE provider = @provider AND model = @model AND dimensions = @dimensions AND is_normalized = @isNormalized AND text_hash = @textHash """; private const string UpsertSql = """ - INSERT INTO embeddings_cache (provider, model, dimensions, is_normalized, text_length, text_hash, vector) - VALUES (@provider, @model, @dimensions, @isNormalized, @textLength, @textHash, @vector) + INSERT INTO embeddings_cache (provider, model, dimensions, is_normalized, text_length, text_hash, vector, token_count, timestamp) + VALUES (@provider, @model, @dimensions, @isNormalized, @textLength, @textHash, @vector, @tokenCount, @timestamp) ON CONFLICT(provider, model, dimensions, is_normalized, text_hash) - DO UPDATE SET vector = @vector + DO UPDATE SET vector = @vector, token_count = @tokenCount, timestamp = @timestamp """; private readonly SqliteConnection _connection; @@ -86,7 +89,9 @@ public SqliteEmbeddingCache(string dbPath, CacheModes mode, ILogger - public async Task StoreAsync(EmbeddingCacheKey key, float[] vector, CancellationToken ct = default) + public async Task StoreAsync(EmbeddingCacheKey key, float[] vector, int? 
tokenCount, CancellationToken ct = default) { ct.ThrowIfCancellationRequested(); @@ -156,6 +165,7 @@ public async Task StoreAsync(EmbeddingCacheKey key, float[] vector, Cancellation } var vectorBlob = FloatArrayToBytes(vector); + var timestamp = DateTimeOffset.UtcNow.ToString("O"); // ISO 8601 format var command = this._connection.CreateCommand(); await using (command.ConfigureAwait(false)) @@ -168,11 +178,13 @@ public async Task StoreAsync(EmbeddingCacheKey key, float[] vector, Cancellation command.Parameters.AddWithValue("@textLength", key.TextLength); command.Parameters.AddWithValue("@textHash", key.TextHash); command.Parameters.AddWithValue("@vector", vectorBlob); + command.Parameters.AddWithValue("@tokenCount", tokenCount.HasValue ? (object)tokenCount.Value : DBNull.Value); + command.Parameters.AddWithValue("@timestamp", timestamp); await command.ExecuteNonQueryAsync(ct).ConfigureAwait(false); - this._logger.LogTrace("Stored embedding in cache: {Provider}/{Model} hash: {HashPrefix}..., dimensions: {Dimensions}", - key.Provider, key.Model, key.TextHash[..Math.Min(16, key.TextHash.Length)], vector.Length); + this._logger.LogTrace("Stored embedding in cache: {Provider}/{Model} hash: {HashPrefix}..., dimensions: {Dimensions}, tokens: {TokenCount}", + key.Provider, key.Model, key.TextHash[..Math.Min(16, key.TextHash.Length)], vector.Length, tokenCount); } } diff --git a/src/Core/Embeddings/CachedEmbeddingGenerator.cs b/src/Core/Embeddings/CachedEmbeddingGenerator.cs index 897942cef..dc64912a3 100644 --- a/src/Core/Embeddings/CachedEmbeddingGenerator.cs +++ b/src/Core/Embeddings/CachedEmbeddingGenerator.cs @@ -54,7 +54,7 @@ public CachedEmbeddingGenerator( } /// - public async Task GenerateAsync(string text, CancellationToken ct = default) + public async Task GenerateAsync(string text, CancellationToken ct = default) { var key = this.BuildCacheKey(text); @@ -65,35 +65,39 @@ public async Task GenerateAsync(string text, CancellationToken ct = def if (cached != null) { 
this._logger.LogDebug("Cache hit for single embedding, dimensions: {Dimensions}", cached.Vector.Length); - return cached.Vector; + // Return cached result with token count if available + return cached.TokenCount.HasValue + ? EmbeddingResult.FromVectorWithTokens(cached.Vector, cached.TokenCount.Value) + : EmbeddingResult.FromVector(cached.Vector); } } // Generate embedding this._logger.LogDebug("Cache miss for single embedding, calling {Provider}", this.ProviderType); - var vector = await this._inner.GenerateAsync(text, ct).ConfigureAwait(false); + var result = await this._inner.GenerateAsync(text, ct).ConfigureAwait(false); // Store in cache (if mode allows) if (this._cache.Mode != CacheModes.ReadOnly) { - await this._cache.StoreAsync(key, vector, ct).ConfigureAwait(false); - this._logger.LogDebug("Stored embedding in cache, dimensions: {Dimensions}", vector.Length); + await this._cache.StoreAsync(key, result.Vector, result.TokenCount, ct).ConfigureAwait(false); + this._logger.LogDebug("Stored embedding in cache, dimensions: {Dimensions}, tokenCount: {TokenCount}", + result.Vector.Length, result.TokenCount); } - return vector; + return result; } /// - public async Task GenerateAsync(IEnumerable texts, CancellationToken ct = default) + public async Task GenerateAsync(IEnumerable texts, CancellationToken ct = default) { var textList = texts.ToList(); if (textList.Count == 0) { - return Array.Empty(); + return []; } // Initialize result array with nulls - var results = new float[textList.Count][]; + var results = new EmbeddingResult?[textList.Count]; // Track which texts need to be generated var toGenerate = new List<(int Index, string Text)>(); @@ -108,7 +112,10 @@ public async Task GenerateAsync(IEnumerable texts, Cancellati if (cached != null) { - results[i] = cached.Vector; + // Return cached result with token count if available + results[i] = cached.TokenCount.HasValue + ? 
EmbeddingResult.FromVectorWithTokens(cached.Vector, cached.TokenCount.Value) + : EmbeddingResult.FromVector(cached.Vector); } else { @@ -133,26 +140,27 @@ public async Task GenerateAsync(IEnumerable texts, Cancellati if (toGenerate.Count > 0) { var textsToGenerate = toGenerate.Select(x => x.Text); - var generatedVectors = await this._inner.GenerateAsync(textsToGenerate, ct).ConfigureAwait(false); + var generatedResults = await this._inner.GenerateAsync(textsToGenerate, ct).ConfigureAwait(false); - // Map generated vectors back to results and store in cache + // Map generated results back to results array and store in cache for (int i = 0; i < toGenerate.Count; i++) { var (originalIndex, text) = toGenerate[i]; - results[originalIndex] = generatedVectors[i]; + results[originalIndex] = generatedResults[i]; // Store in cache (if mode allows) if (this._cache.Mode != CacheModes.ReadOnly) { var key = this.BuildCacheKey(text); - await this._cache.StoreAsync(key, generatedVectors[i], ct).ConfigureAwait(false); + await this._cache.StoreAsync(key, generatedResults[i].Vector, generatedResults[i].TokenCount, ct).ConfigureAwait(false); } } this._logger.LogDebug("Generated and cached {Count} embeddings", toGenerate.Count); } - return results; + // Convert nullable array to non-nullable (all slots should be filled now) + return results.Select(r => r!).ToArray(); } /// diff --git a/src/Core/Embeddings/EmbeddingConstants.cs b/src/Core/Embeddings/EmbeddingConstants.cs deleted file mode 100644 index d8b9dcd4a..000000000 --- a/src/Core/Embeddings/EmbeddingConstants.cs +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. -namespace KernelMemory.Core.Embeddings; - -/// -/// Constants for embedding generation including known model dimensions, -/// default configurations, and batch sizes. -/// -public static class EmbeddingConstants -{ - /// - /// Default batch size for embedding generation requests. - /// Configurable per provider, but this is the default. 
- /// - public const int DefaultBatchSize = 10; - - /// - /// Default Ollama model for embeddings. - /// - public const string DefaultOllamaModel = "qwen3-embedding"; - - /// - /// Default Ollama base URL. - /// - public const string DefaultOllamaBaseUrl = "http://localhost:11434"; - - /// - /// Default HuggingFace model for embeddings. - /// - public const string DefaultHuggingFaceModel = "sentence-transformers/all-MiniLM-L6-v2"; - - /// - /// Default HuggingFace Inference API base URL. - /// - public const string DefaultHuggingFaceBaseUrl = "https://api-inference.huggingface.co"; - - /// - /// Default OpenAI API base URL. - /// - public const string DefaultOpenAIBaseUrl = "https://api.openai.com"; - - /// - /// Azure OpenAI API version. - /// - public const string AzureOpenAIApiVersion = "2024-02-01"; - - /// - /// Known model dimensions for common embedding models. - /// These values are fixed per model and used for validation and cache key generation. - /// - public static readonly IReadOnlyDictionary KnownModelDimensions = new Dictionary - { - // Ollama models - ["qwen3-embedding"] = 1024, - ["nomic-embed-text"] = 768, - ["embeddinggemma"] = 768, - - // OpenAI models - ["text-embedding-ada-002"] = 1536, - ["text-embedding-3-small"] = 1536, - ["text-embedding-3-large"] = 3072, - - // HuggingFace models - ["sentence-transformers/all-MiniLM-L6-v2"] = 384, - ["BAAI/bge-base-en-v1.5"] = 768 - }; - - /// - /// Try to get the dimensions for a known model. - /// - /// The model name to look up. - /// The dimensions if found, 0 otherwise. - /// True if the model is known, false otherwise. 
- public static bool TryGetDimensions(string modelName, out int dimensions) - { - return KnownModelDimensions.TryGetValue(modelName, out dimensions); - } -} diff --git a/src/Core/Embeddings/EmbeddingResult.cs b/src/Core/Embeddings/EmbeddingResult.cs new file mode 100644 index 000000000..cc7cac5e9 --- /dev/null +++ b/src/Core/Embeddings/EmbeddingResult.cs @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics.CodeAnalysis; + +namespace KernelMemory.Core.Embeddings; + +/// +/// Result of embedding generation including the vector and optional metadata. +/// +public sealed class EmbeddingResult +{ + /// + /// The generated embedding vector. + /// + [SuppressMessage("Performance", "CA1819:Properties should not return arrays", + Justification = "Embedding vectors are read-only after creation")] + public required float[] Vector { get; init; } + + /// + /// Optional token count if the provider reports it. + /// Used for cost tracking and usage monitoring. + /// + public int? TokenCount { get; init; } + + /// + /// Creates an EmbeddingResult with just a vector (no token count). + /// + public static EmbeddingResult FromVector(float[] vector) + { + return new EmbeddingResult { Vector = vector, TokenCount = null }; + } + + /// + /// Creates an EmbeddingResult with vector and token count. + /// + public static EmbeddingResult FromVectorWithTokens(float[] vector, int tokenCount) + { + return new EmbeddingResult { Vector = vector, TokenCount = tokenCount }; + } +} diff --git a/src/Core/Embeddings/IEmbeddingGenerator.cs b/src/Core/Embeddings/IEmbeddingGenerator.cs index a037269a3..08abd1a1c 100644 --- a/src/Core/Embeddings/IEmbeddingGenerator.cs +++ b/src/Core/Embeddings/IEmbeddingGenerator.cs @@ -39,10 +39,10 @@ public interface IEmbeddingGenerator /// /// The text to generate embedding for. /// Cancellation token. - /// The embedding vector as a float array. + /// The embedding result with vector and optional token count. 
/// When the API call fails. /// When the operation is cancelled. - Task GenerateAsync(string text, CancellationToken ct = default); + Task GenerateAsync(string text, CancellationToken ct = default); /// /// Generate embeddings for multiple texts (batch). @@ -50,8 +50,8 @@ public interface IEmbeddingGenerator /// /// The texts to generate embeddings for. /// Cancellation token. - /// Array of embedding vectors, in the same order as the input texts. + /// Array of embedding results with vectors and optional token counts, in the same order as the input texts. /// When the API call fails. /// When the operation is cancelled. - Task GenerateAsync(IEnumerable texts, CancellationToken ct = default); + Task GenerateAsync(IEnumerable texts, CancellationToken ct = default); } diff --git a/src/Core/Embeddings/Providers/AzureOpenAIEmbeddingGenerator.cs b/src/Core/Embeddings/Providers/AzureOpenAIEmbeddingGenerator.cs index 3e7a61302..fade83bb0 100644 --- a/src/Core/Embeddings/Providers/AzureOpenAIEmbeddingGenerator.cs +++ b/src/Core/Embeddings/Providers/AzureOpenAIEmbeddingGenerator.cs @@ -73,22 +73,22 @@ public AzureOpenAIEmbeddingGenerator( } /// - public async Task GenerateAsync(string text, CancellationToken ct = default) + public async Task GenerateAsync(string text, CancellationToken ct = default) { var results = await this.GenerateAsync(new[] { text }, ct).ConfigureAwait(false); return results[0]; } /// - public async Task GenerateAsync(IEnumerable texts, CancellationToken ct = default) + public async Task GenerateAsync(IEnumerable texts, CancellationToken ct = default) { var textArray = texts.ToArray(); if (textArray.Length == 0) { - return Array.Empty(); + return []; } - var url = $"{this._endpoint}/openai/deployments/{this._deployment}/embeddings?api-version={EmbeddingConstants.AzureOpenAIApiVersion}"; + var url = $"{this._endpoint}/openai/deployments/{this._deployment}/embeddings?api-version={Constants.EmbeddingDefaults.AzureOpenAIApiVersion}"; var request = new 
AzureEmbeddingRequest { @@ -114,12 +114,31 @@ public async Task GenerateAsync(IEnumerable texts, Cancellati // Sort by index to ensure correct ordering var sortedData = result.Data.OrderBy(d => d.Index).ToArray(); - var embeddings = sortedData.Select(d => d.Embedding).ToArray(); + + // Get total token count from API response + var totalTokens = result.Usage?.TotalTokens; this._logger.LogTrace("Azure OpenAI returned {Count} embeddings, usage: {TotalTokens} tokens", - embeddings.Length, result.Usage?.TotalTokens); + sortedData.Length, totalTokens); + + // Calculate per-embedding token count if total tokens available + // For batch requests, we distribute tokens evenly across embeddings (approximation) + int? perEmbeddingTokens = null; + if (totalTokens.HasValue && sortedData.Length > 0) + { + perEmbeddingTokens = totalTokens.Value / sortedData.Length; + } + + // Create EmbeddingResult for each embedding with token count + var results = new EmbeddingResult[sortedData.Length]; + for (int i = 0; i < sortedData.Length; i++) + { + results[i] = perEmbeddingTokens.HasValue + ? EmbeddingResult.FromVectorWithTokens(sortedData[i].Embedding, perEmbeddingTokens.Value) + : EmbeddingResult.FromVector(sortedData[i].Embedding); + } - return embeddings; + return results; } /// diff --git a/src/Core/Embeddings/Providers/HuggingFaceEmbeddingGenerator.cs b/src/Core/Embeddings/Providers/HuggingFaceEmbeddingGenerator.cs index bc13c3ee6..9fd4537a6 100644 --- a/src/Core/Embeddings/Providers/HuggingFaceEmbeddingGenerator.cs +++ b/src/Core/Embeddings/Providers/HuggingFaceEmbeddingGenerator.cs @@ -58,7 +58,7 @@ public HuggingFaceEmbeddingGenerator( this._httpClient = httpClient; this._apiKey = apiKey; - this._baseUrl = (baseUrl ?? EmbeddingConstants.DefaultHuggingFaceBaseUrl).TrimEnd('/'); + this._baseUrl = (baseUrl ?? 
Constants.EmbeddingDefaults.DefaultHuggingFaceBaseUrl).TrimEnd('/'); this.ModelName = model; this.VectorDimensions = vectorDimensions; this.IsNormalized = isNormalized; @@ -69,19 +69,19 @@ public HuggingFaceEmbeddingGenerator( } /// - public async Task GenerateAsync(string text, CancellationToken ct = default) + public async Task GenerateAsync(string text, CancellationToken ct = default) { var results = await this.GenerateAsync(new[] { text }, ct).ConfigureAwait(false); return results[0]; } /// - public async Task GenerateAsync(IEnumerable texts, CancellationToken ct = default) + public async Task GenerateAsync(IEnumerable texts, CancellationToken ct = default) { var textArray = texts.ToArray(); if (textArray.Length == 0) { - return Array.Empty(); + return []; } var endpoint = $"{this._baseUrl}/models/{this.ModelName}"; @@ -112,7 +112,14 @@ public async Task GenerateAsync(IEnumerable texts, Cancellati this._logger.LogTrace("HuggingFace returned {Count} embeddings with {Dimensions} dimensions each", embeddings.Length, embeddings[0].Length); - return embeddings; + // HuggingFace API does not return token count + var results = new EmbeddingResult[embeddings.Length]; + for (int i = 0; i < embeddings.Length; i++) + { + results[i] = EmbeddingResult.FromVector(embeddings[i]); + } + + return results; } /// diff --git a/src/Core/Embeddings/Providers/OllamaEmbeddingGenerator.cs b/src/Core/Embeddings/Providers/OllamaEmbeddingGenerator.cs index d43e0058c..c9521d2dd 100644 --- a/src/Core/Embeddings/Providers/OllamaEmbeddingGenerator.cs +++ b/src/Core/Embeddings/Providers/OllamaEmbeddingGenerator.cs @@ -63,7 +63,7 @@ public OllamaEmbeddingGenerator( } /// - public async Task GenerateAsync(string text, CancellationToken ct = default) + public async Task GenerateAsync(string text, CancellationToken ct = default) { var endpoint = $"{this._baseUrl}/api/embeddings"; @@ -87,15 +87,16 @@ public async Task GenerateAsync(string text, CancellationToken ct = def 
this._logger.LogTrace("Ollama returned embedding with {Dimensions} dimensions", result.Embedding.Length); - return result.Embedding; + // Ollama API does not return token count + return EmbeddingResult.FromVector(result.Embedding); } /// - public async Task GenerateAsync(IEnumerable texts, CancellationToken ct = default) + public async Task GenerateAsync(IEnumerable texts, CancellationToken ct = default) { // Ollama doesn't support batch embedding natively, so process one at a time var textList = texts.ToList(); - var results = new float[textList.Count][]; + var results = new EmbeddingResult[textList.Count]; this._logger.LogDebug("Generating {Count} embeddings via Ollama (sequential)", textList.Count); diff --git a/src/Core/Embeddings/Providers/OpenAIEmbeddingGenerator.cs b/src/Core/Embeddings/Providers/OpenAIEmbeddingGenerator.cs index 81e4bce18..3d8ca2b02 100644 --- a/src/Core/Embeddings/Providers/OpenAIEmbeddingGenerator.cs +++ b/src/Core/Embeddings/Providers/OpenAIEmbeddingGenerator.cs @@ -58,7 +58,7 @@ public OpenAIEmbeddingGenerator( this._httpClient = httpClient; this._apiKey = apiKey; - this._baseUrl = (baseUrl ?? EmbeddingConstants.DefaultOpenAIBaseUrl).TrimEnd('/'); + this._baseUrl = (baseUrl ?? 
Constants.EmbeddingDefaults.DefaultOpenAIBaseUrl).TrimEnd('/'); this.ModelName = model; this.VectorDimensions = vectorDimensions; this.IsNormalized = isNormalized; @@ -69,19 +69,19 @@ public OpenAIEmbeddingGenerator( } /// - public async Task GenerateAsync(string text, CancellationToken ct = default) + public async Task GenerateAsync(string text, CancellationToken ct = default) { var results = await this.GenerateAsync(new[] { text }, ct).ConfigureAwait(false); return results[0]; } /// - public async Task GenerateAsync(IEnumerable texts, CancellationToken ct = default) + public async Task GenerateAsync(IEnumerable texts, CancellationToken ct = default) { var textArray = texts.ToArray(); if (textArray.Length == 0) { - return Array.Empty(); + return []; } var endpoint = $"{this._baseUrl}/v1/embeddings"; @@ -111,12 +111,31 @@ public async Task GenerateAsync(IEnumerable texts, Cancellati // Sort by index to ensure correct ordering var sortedData = result.Data.OrderBy(d => d.Index).ToArray(); - var embeddings = sortedData.Select(d => d.Embedding).ToArray(); + + // Get total token count from API response + var totalTokens = result.Usage?.TotalTokens; this._logger.LogTrace("OpenAI returned {Count} embeddings, usage: {TotalTokens} tokens", - embeddings.Length, result.Usage?.TotalTokens); + sortedData.Length, totalTokens); + + // Calculate per-embedding token count if total tokens available + // For batch requests, we distribute tokens evenly across embeddings (approximation) + int? perEmbeddingTokens = null; + if (totalTokens.HasValue && sortedData.Length > 0) + { + perEmbeddingTokens = totalTokens.Value / sortedData.Length; + } + + // Create EmbeddingResult for each embedding with token count + var results = new EmbeddingResult[sortedData.Length]; + for (int i = 0; i < sortedData.Length; i++) + { + results[i] = perEmbeddingTokens.HasValue + ? 
EmbeddingResult.FromVectorWithTokens(sortedData[i].Embedding, perEmbeddingTokens.Value) + : EmbeddingResult.FromVector(sortedData[i].Embedding); + } - return embeddings; + return results; } /// diff --git a/src/Core/Logging/ActivityEnricher.cs b/src/Core/Logging/ActivityEnricher.cs index 8efeadde2..5f80a2c06 100644 --- a/src/Core/Logging/ActivityEnricher.cs +++ b/src/Core/Logging/ActivityEnricher.cs @@ -37,14 +37,14 @@ public void Enrich(LogEvent logEvent, ILogEventPropertyFactory propertyFactory) // Add TraceId for correlating logs across the entire operation var traceId = activity.TraceId.ToString(); - if (!string.IsNullOrEmpty(traceId) && traceId != LoggingConstants.EmptyTraceId) + if (!string.IsNullOrEmpty(traceId) && traceId != Constants.LoggingDefaults.EmptyTraceId) { logEvent.AddPropertyIfAbsent(propertyFactory.CreateProperty(TraceIdPropertyName, traceId)); } // Add SpanId for correlating logs within a specific span var spanId = activity.SpanId.ToString(); - if (!string.IsNullOrEmpty(spanId) && spanId != LoggingConstants.EmptySpanId) + if (!string.IsNullOrEmpty(spanId) && spanId != Constants.LoggingDefaults.EmptySpanId) { logEvent.AddPropertyIfAbsent(propertyFactory.CreateProperty(SpanIdPropertyName, spanId)); } diff --git a/src/Core/Logging/EnvironmentDetector.cs b/src/Core/Logging/EnvironmentDetector.cs index c48e24595..a7bafd49b 100644 --- a/src/Core/Logging/EnvironmentDetector.cs +++ b/src/Core/Logging/EnvironmentDetector.cs @@ -17,21 +17,21 @@ public static class EnvironmentDetector public static string GetEnvironment() { // Check DOTNET_ENVIRONMENT first (takes precedence) - var dotNetEnv = Environment.GetEnvironmentVariable(LoggingConstants.DotNetEnvironmentVariable); + var dotNetEnv = Environment.GetEnvironmentVariable(Constants.LoggingDefaults.DotNetEnvironmentVariable); if (!string.IsNullOrWhiteSpace(dotNetEnv)) { return dotNetEnv; } // Fall back to ASPNETCORE_ENVIRONMENT - var aspNetEnv = 
Environment.GetEnvironmentVariable(LoggingConstants.AspNetCoreEnvironmentVariable); + var aspNetEnv = Environment.GetEnvironmentVariable(Constants.LoggingDefaults.AspNetCoreEnvironmentVariable); if (!string.IsNullOrWhiteSpace(aspNetEnv)) { return aspNetEnv; } // Default to Development for safety (full logging) - return LoggingConstants.DefaultEnvironment; + return Constants.LoggingDefaults.DefaultEnvironment; } /// @@ -43,7 +43,7 @@ public static bool IsProduction() { return string.Equals( GetEnvironment(), - LoggingConstants.ProductionEnvironment, + Constants.LoggingDefaults.ProductionEnvironment, StringComparison.OrdinalIgnoreCase); } @@ -56,7 +56,7 @@ public static bool IsDevelopment() { return string.Equals( GetEnvironment(), - LoggingConstants.DefaultEnvironment, + Constants.LoggingDefaults.DefaultEnvironment, StringComparison.OrdinalIgnoreCase); } } diff --git a/src/Core/Logging/LoggingConstants.cs b/src/Core/Logging/LoggingConstants.cs deleted file mode 100644 index 52f2768ba..000000000 --- a/src/Core/Logging/LoggingConstants.cs +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using Serilog.Events; - -namespace KernelMemory.Core.Logging; - -/// -/// Centralized constants for the logging system. -/// All magic values related to logging are defined here for maintainability. -/// -public static class LoggingConstants -{ - /// - /// Default maximum file size before rotation (100MB). - /// Balances history retention with disk usage. - /// - public const long DefaultFileSizeLimitBytes = 100 * 1024 * 1024; - - /// - /// Default number of log files to retain (30 files). - /// Approximately 1 month of daily logs or ~3GB max storage. - /// - public const int DefaultRetainedFileCountLimit = 30; - - /// - /// Default minimum log level for file output. - /// Information level provides useful diagnostics without excessive verbosity. 
- /// - public const LogEventLevel DefaultFileLogLevel = LogEventLevel.Information; - - /// - /// Default minimum log level for console/stderr output. - /// Only warnings and errors appear on stderr by default. - /// - public const LogEventLevel DefaultConsoleLogLevel = LogEventLevel.Warning; - - /// - /// Environment variable for .NET runtime environment detection. - /// Takes precedence over ASPNETCORE_ENVIRONMENT. - /// - public const string DotNetEnvironmentVariable = "DOTNET_ENVIRONMENT"; - - /// - /// Fallback environment variable for ASP.NET Core applications. - /// Used when DOTNET_ENVIRONMENT is not set. - /// - public const string AspNetCoreEnvironmentVariable = "ASPNETCORE_ENVIRONMENT"; - - /// - /// Default environment when no environment variable is set. - /// Defaults to Development for developer safety (full logging enabled). - /// - public const string DefaultEnvironment = "Development"; - - /// - /// Production environment name for comparison. - /// Sensitive data is scrubbed only in Production. - /// - public const string ProductionEnvironment = "Production"; - - /// - /// Placeholder text for redacted sensitive data. - /// Used to indicate data was intentionally removed from logs. - /// - public const string RedactedPlaceholder = "[REDACTED]"; - - /// - /// Human-readable output template for log messages. - /// Includes timestamp, level, source context, message, and optional exception. - /// - public const string HumanReadableOutputTemplate = - "{Timestamp:yyyy-MM-dd HH:mm:ss.fff} [{Level:u3}] {SourceContext}: {Message:lj}{NewLine}{Exception}"; - - /// - /// Compact output template for console (stderr) output. - /// Shorter format suitable for CLI error reporting. - /// - public const string ConsoleOutputTemplate = - "{Timestamp:HH:mm:ss} [{Level:u3}] {Message:lj}{NewLine}{Exception}"; - - /// - /// Empty trace ID value (32 zeros) used when no Activity is present. - /// Indicates no distributed tracing context is available. 
- /// - public const string EmptyTraceId = "00000000000000000000000000000000"; - - /// - /// Empty span ID value (16 zeros) used when no Activity is present. - /// Indicates no distributed tracing context is available. - /// - public const string EmptySpanId = "0000000000000000"; -} diff --git a/src/Core/Logging/SensitiveDataScrubbingPolicy.cs b/src/Core/Logging/SensitiveDataScrubbingPolicy.cs index 7861479ed..e67eec61d 100644 --- a/src/Core/Logging/SensitiveDataScrubbingPolicy.cs +++ b/src/Core/Logging/SensitiveDataScrubbingPolicy.cs @@ -45,7 +45,7 @@ public bool TryDestructure( // - File contents, queries if (value is string) { - result = new ScalarValue(LoggingConstants.RedactedPlaceholder); + result = new ScalarValue(Constants.LoggingDefaults.RedactedPlaceholder); return true; } diff --git a/src/Core/Logging/SerilogFactory.cs b/src/Core/Logging/SerilogFactory.cs index 11244ddbb..69013dddc 100644 --- a/src/Core/Logging/SerilogFactory.cs +++ b/src/Core/Logging/SerilogFactory.cs @@ -37,8 +37,8 @@ public static ILoggerFactory CreateLoggerFactory(LoggingConfig config) // Configure console output (stderr) for warnings and errors loggerConfig = loggerConfig.WriteTo.Console( - restrictedToMinimumLevel: LoggingConstants.DefaultConsoleLogLevel, - outputTemplate: LoggingConstants.ConsoleOutputTemplate, + restrictedToMinimumLevel: Constants.LoggingDefaults.DefaultConsoleLogLevel, + outputTemplate: Constants.LoggingDefaults.ConsoleOutputTemplate, formatProvider: CultureInfo.InvariantCulture, standardErrorFromLevel: LogEventLevel.Verbose); @@ -81,21 +81,19 @@ private static LoggerConfiguration ConfigureFileLogging( loggerConfig = loggerConfig.WriteTo.Async(a => a.File( new CompactJsonFormatter(), filePath, - fileSizeLimitBytes: LoggingConstants.DefaultFileSizeLimitBytes, - rollingInterval: RollingInterval.Day, + fileSizeLimitBytes: Constants.LoggingDefaults.DefaultFileSizeLimitBytes, rollOnFileSizeLimit: true, - retainedFileCountLimit: 
LoggingConstants.DefaultRetainedFileCountLimit)); + retainedFileCountLimit: Constants.LoggingDefaults.DefaultRetainedFileCountLimit)); } else { loggerConfig = loggerConfig.WriteTo.Async(a => a.File( filePath, - outputTemplate: LoggingConstants.HumanReadableOutputTemplate, + outputTemplate: Constants.LoggingDefaults.HumanReadableOutputTemplate, formatProvider: CultureInfo.InvariantCulture, - fileSizeLimitBytes: LoggingConstants.DefaultFileSizeLimitBytes, - rollingInterval: RollingInterval.Day, + fileSizeLimitBytes: Constants.LoggingDefaults.DefaultFileSizeLimitBytes, rollOnFileSizeLimit: true, - retainedFileCountLimit: LoggingConstants.DefaultRetainedFileCountLimit)); + retainedFileCountLimit: Constants.LoggingDefaults.DefaultRetainedFileCountLimit)); } } else @@ -106,21 +104,19 @@ private static LoggerConfiguration ConfigureFileLogging( loggerConfig = loggerConfig.WriteTo.File( new CompactJsonFormatter(), filePath, - fileSizeLimitBytes: LoggingConstants.DefaultFileSizeLimitBytes, - rollingInterval: RollingInterval.Day, + fileSizeLimitBytes: Constants.LoggingDefaults.DefaultFileSizeLimitBytes, rollOnFileSizeLimit: true, - retainedFileCountLimit: LoggingConstants.DefaultRetainedFileCountLimit); + retainedFileCountLimit: Constants.LoggingDefaults.DefaultRetainedFileCountLimit); } else { loggerConfig = loggerConfig.WriteTo.File( filePath, - outputTemplate: LoggingConstants.HumanReadableOutputTemplate, + outputTemplate: Constants.LoggingDefaults.HumanReadableOutputTemplate, formatProvider: CultureInfo.InvariantCulture, - fileSizeLimitBytes: LoggingConstants.DefaultFileSizeLimitBytes, - rollingInterval: RollingInterval.Day, + fileSizeLimitBytes: Constants.LoggingDefaults.DefaultFileSizeLimitBytes, rollOnFileSizeLimit: true, - retainedFileCountLimit: LoggingConstants.DefaultRetainedFileCountLimit); + retainedFileCountLimit: Constants.LoggingDefaults.DefaultRetainedFileCountLimit); } } diff --git a/src/Core/Search/IVectorIndex.cs b/src/Core/Search/IVectorIndex.cs new file 
mode 100644 index 000000000..81de43de7 --- /dev/null +++ b/src/Core/Search/IVectorIndex.cs @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft. All rights reserved. +namespace KernelMemory.Core.Search; + +/// +/// Interface for vector search index operations. +/// Extends ISearchIndex with vector-specific capabilities. +/// All vectors are normalized at write time, searches use dot product (equivalent to cosine similarity). +/// +public interface IVectorIndex : ISearchIndex +{ + /// + /// Vector dimensions for this index (must match embedding model). + /// + int VectorDimensions { get; } + + /// + /// Indexes content with vector embedding. + /// Generates embedding using configured generator, normalizes it, then stores. + /// + /// Unique content identifier. + /// Text to generate embedding for. + /// Cancellation token. + new Task IndexAsync(string contentId, string text, CancellationToken cancellationToken = default); + + /// + /// Searches the vector index for similar content using dot product on normalized vectors. + /// + /// Query text to generate embedding for. + /// Maximum number of results. + /// Cancellation token. + /// List of matches ordered by similarity (highest score first). 
+ Task> SearchAsync(string queryText, int limit = 10, CancellationToken cancellationToken = default); +} diff --git a/src/Core/Search/Models/RerankingConfig.cs b/src/Core/Search/Models/RerankingConfig.cs index a23af480c..756ac102f 100644 --- a/src/Core/Search/Models/RerankingConfig.cs +++ b/src/Core/Search/Models/RerankingConfig.cs @@ -29,5 +29,5 @@ public sealed class RerankingConfig /// Fourth appearance: multiplier = 0.125 (12.5% boost) /// [System.Diagnostics.CodeAnalysis.SuppressMessage("Performance", "CA1819:Properties should not return arrays")] - public float[] DiminishingMultipliers { get; init; } = SearchConstants.DefaultDiminishingMultipliers; + public float[] DiminishingMultipliers { get; init; } = Constants.SearchDefaults.DefaultDiminishingMultipliers; } diff --git a/src/Core/Search/Models/SearchRequest.cs b/src/Core/Search/Models/SearchRequest.cs index 3202764c9..28c18effa 100644 --- a/src/Core/Search/Models/SearchRequest.cs +++ b/src/Core/Search/Models/SearchRequest.cs @@ -53,7 +53,7 @@ public sealed class SearchRequest /// Maximum number of results to return. /// Default: 20 (from config or SearchConstants). /// - public int Limit { get; set; } = SearchConstants.DefaultLimit; + public int Limit { get; set; } = Constants.SearchDefaults.DefaultLimit; /// /// Pagination offset (skip first N results). @@ -66,7 +66,7 @@ public sealed class SearchRequest /// Results below this score are filtered out. /// Default: 0.3 (from config or SearchConstants). /// - public float MinRelevance { get; set; } = SearchConstants.DefaultMinRelevance; + public float MinRelevance { get; set; } = Constants.SearchDefaults.DefaultMinRelevance; /// /// Memory safety limit per node. 
diff --git a/src/Core/Search/NodeSearchService.cs b/src/Core/Search/NodeSearchService.cs index cd7333abc..ba5669f1b 100644 --- a/src/Core/Search/NodeSearchService.cs +++ b/src/Core/Search/NodeSearchService.cs @@ -43,12 +43,12 @@ public sealed class NodeSearchService /// The node ID this service operates on. /// The FTS index for this node. /// The content storage for loading full records. - /// Optional index ID for this FTS index. Defaults to SearchConstants.DefaultFtsIndexId. + /// Optional index ID for this FTS index. Defaults to Constants.SearchDefaults.DefaultFtsIndexId. public NodeSearchService( string nodeId, IFtsIndex ftsIndex, IContentStorage contentStorage, - string indexId = SearchConstants.DefaultFtsIndexId) + string indexId = Constants.SearchDefaults.DefaultFtsIndexId) { this._nodeId = nodeId; this._indexId = indexId; @@ -73,12 +73,12 @@ public NodeSearchService( try { // Apply timeout - var timeout = request.TimeoutSeconds ?? SearchConstants.DefaultSearchTimeoutSeconds; + var timeout = request.TimeoutSeconds ?? Constants.SearchDefaults.DefaultSearchTimeoutSeconds; using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); cts.CancelAfter(TimeSpan.FromSeconds(timeout)); // Query the FTS index - var maxResults = request.MaxResultsPerNode ?? SearchConstants.DefaultMaxResultsPerNode; + var maxResults = request.MaxResultsPerNode ?? 
Constants.SearchDefaults.DefaultMaxResultsPerNode; // Convert QueryNode to FTS query string and extract NOT terms for post-filtering var queryResult = this.ExtractFtsQuery(queryNode); diff --git a/src/Core/Search/Reranking/WeightedDiminishingReranker.cs b/src/Core/Search/Reranking/WeightedDiminishingReranker.cs index 2a93a37ce..e803e2afa 100644 --- a/src/Core/Search/Reranking/WeightedDiminishingReranker.cs +++ b/src/Core/Search/Reranking/WeightedDiminishingReranker.cs @@ -58,10 +58,10 @@ private float ApplyWeights(SearchIndexResult result, RerankingConfig config) // Get node weight (default to 1.0 if not configured) var nodeWeight = config.NodeWeights.TryGetValue(result.NodeId, out var nw) ? nw - : SearchConstants.DefaultNodeWeight; + : Constants.SearchDefaults.DefaultNodeWeight; // Get index weight (default to 1.0 if not configured) - var indexWeight = SearchConstants.DefaultIndexWeight; + var indexWeight = Constants.SearchDefaults.DefaultIndexWeight; if (config.IndexWeights.TryGetValue(result.NodeId, out var nodeIndexes)) { if (nodeIndexes.TryGetValue(result.IndexId, out var iw)) @@ -103,9 +103,9 @@ private SearchResult AggregateRecord( } // Cap at 1.0 (max relevance) - if (finalScore > SearchConstants.MaxRelevanceScore) + if (finalScore > Constants.SearchDefaults.MaxRelevanceScore) { - finalScore = SearchConstants.MaxRelevanceScore; + finalScore = Constants.SearchDefaults.MaxRelevanceScore; } // Use the highest-scored appearance for the record data diff --git a/src/Core/Search/SearchConstants.cs b/src/Core/Search/SearchConstants.cs deleted file mode 100644 index f6e3b2812..000000000 --- a/src/Core/Search/SearchConstants.cs +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. -namespace KernelMemory.Core.Search; - -/// -/// Constants for search functionality. -/// Centralizes all magic values for maintainability. -/// -public static class SearchConstants -{ - /// - /// Default minimum relevance score threshold (0.0-1.0). 
- /// Results below this score are filtered out. - /// - public const float DefaultMinRelevance = 0.3f; - - /// - /// Default maximum number of results to return per search. - /// - public const int DefaultLimit = 20; - - /// - /// Default search timeout in seconds per node. - /// - public const int DefaultSearchTimeoutSeconds = 30; - - /// - /// Default maximum results to retrieve from each node (memory safety). - /// Prevents memory exhaustion from large result sets. - /// - public const int DefaultMaxResultsPerNode = 1000; - - /// - /// Default node weight for relevance scoring. - /// - public const float DefaultNodeWeight = 1.0f; - - /// - /// Default search index weight for relevance scoring. - /// - public const float DefaultIndexWeight = 1.0f; - - /// - /// BM25 score normalization divisor for exponential mapping. - /// Maps BM25 range [-10, 0] to [0.37, 1.0] using exp(score/divisor). - /// - public const double Bm25NormalizationDivisor = 10.0; - - /// - /// Maximum nesting depth for query parentheses. - /// Prevents DoS attacks via deeply nested queries. - /// - public const int MaxQueryDepth = 10; - - /// - /// Maximum number of boolean operators (AND/OR/NOT) in a single query. - /// Prevents query complexity attacks. - /// - public const int MaxBooleanOperators = 50; - - /// - /// Maximum length of a field value in query (characters). - /// Prevents oversized query values. - /// - public const int MaxFieldValueLength = 1000; - - /// - /// Maximum time allowed for query parsing (milliseconds). - /// Prevents regex catastrophic backtracking. - /// - public const int QueryParseTimeoutMs = 1000; - - /// - /// Default snippet length in characters. - /// - public const int DefaultSnippetLength = 200; - - /// - /// Default maximum number of snippets per result. - /// - public const int DefaultMaxSnippetsPerResult = 1; - - /// - /// Default snippet separator between multiple snippets. 
- /// - public const string DefaultSnippetSeparator = "..."; - - /// - /// Default highlight prefix marker. - /// - public const string DefaultHighlightPrefix = ""; - - /// - /// Default highlight suffix marker. - /// - public const string DefaultHighlightSuffix = ""; - - /// - /// Diminishing returns multipliers for aggregating multiple appearances of same record. - /// First appearance: 1.0 (full weight) - /// Second appearance: 0.5 (50% boost) - /// Third appearance: 0.25 (25% boost) - /// Fourth appearance: 0.125 (12.5% boost) - /// Each subsequent multiplier is half of the previous. - /// - public static readonly float[] DefaultDiminishingMultipliers = [1.0f, 0.5f, 0.25f, 0.125f]; - - /// - /// Wildcard character for "all nodes" in node selection. - /// - public const string AllNodesWildcard = "*"; - - /// - /// Maximum relevance score (scores are capped at this value). - /// - public const float MaxRelevanceScore = 1.0f; - - /// - /// Minimum relevance score. - /// - public const float MinRelevanceScore = 0.0f; - - /// - /// Default FTS index ID used when not specified in configuration. - /// This is the identifier assigned to search results from the full-text search index. - /// - public const string DefaultFtsIndexId = "fts-main"; -} diff --git a/src/Core/Search/SearchService.cs b/src/Core/Search/SearchService.cs index 621c3c03e..8bd0626da 100644 --- a/src/Core/Search/SearchService.cs +++ b/src/Core/Search/SearchService.cs @@ -24,7 +24,7 @@ public sealed class SearchService : ISearchService /// /// Per-node, per-index weights for relevance scoring. /// Outer key = node ID, Inner key = index ID, Value = weight multiplier. - /// If null or missing entries, defaults to SearchConstants.DefaultIndexWeight (1.0). + /// If null or missing entries, defaults to Constants.SearchDefaults.DefaultIndexWeight (1.0). /// /// Reranking implementation (default: WeightedDiminishingReranker). 
public SearchService( @@ -223,7 +223,7 @@ private RerankingConfig BuildRerankingConfig(SearchRequest request, string[] nod } else { - nodeWeights[nodeId] = SearchConstants.DefaultNodeWeight; + nodeWeights[nodeId] = Constants.SearchDefaults.DefaultNodeWeight; } } @@ -239,9 +239,9 @@ private RerankingConfig BuildRerankingConfig(SearchRequest request, string[] nod var nodeIndexWeights = new Dictionary(configuredNodeIndexWeights); // Ensure default FTS index has a weight (use configured or default) - if (!nodeIndexWeights.ContainsKey(SearchConstants.DefaultFtsIndexId)) + if (!nodeIndexWeights.ContainsKey(Constants.SearchDefaults.DefaultFtsIndexId)) { - nodeIndexWeights[SearchConstants.DefaultFtsIndexId] = SearchConstants.DefaultIndexWeight; + nodeIndexWeights[Constants.SearchDefaults.DefaultFtsIndexId] = Constants.SearchDefaults.DefaultIndexWeight; } indexWeights[nodeId] = nodeIndexWeights; @@ -251,7 +251,7 @@ private RerankingConfig BuildRerankingConfig(SearchRequest request, string[] nod // No configured weights for this node, use default indexWeights[nodeId] = new Dictionary { - [SearchConstants.DefaultFtsIndexId] = SearchConstants.DefaultIndexWeight + [Constants.SearchDefaults.DefaultFtsIndexId] = Constants.SearchDefaults.DefaultIndexWeight }; } } @@ -260,7 +260,7 @@ private RerankingConfig BuildRerankingConfig(SearchRequest request, string[] nod { NodeWeights = nodeWeights, IndexWeights = indexWeights, - DiminishingMultipliers = SearchConstants.DefaultDiminishingMultipliers + DiminishingMultipliers = Constants.SearchDefaults.DefaultDiminishingMultipliers }; } } diff --git a/src/Core/Search/SqliteFtsIndex.cs b/src/Core/Search/SqliteFtsIndex.cs index 779f811cd..6ea0da048 100644 --- a/src/Core/Search/SqliteFtsIndex.cs +++ b/src/Core/Search/SqliteFtsIndex.cs @@ -201,7 +201,7 @@ LIMIT @limit // BM25 scores are typically in range [-10, 0] // Use exponential function to map to [0, 1]: score = exp(raw_score / divisor) // This gives: -10 → 0.37, -5 → 0.61, -1 → 0.90, 0 → 
1.0 - var normalizedScore = Math.Exp(rawScore / SearchConstants.Bm25NormalizationDivisor); + var normalizedScore = Math.Exp(rawScore / Constants.SearchDefaults.Bm25NormalizationDivisor); results.Add(new FtsMatch { @@ -228,18 +228,14 @@ private async Task> GetAllDocumentsAsync(int limit, Canc { // Select all documents without FTS MATCH filtering // Since there's no FTS query, we can't use bm25() - assign a default score of 1.0 - var searchSql = $""" - SELECT - content_id, - substr(content, 1, 200) as snippet - FROM {TableName} - LIMIT @limit - """; + var searchSql = "SELECT content_id, substr(content, 1, " + Constants.Database.DefaultSqlSnippetLength + ") as snippet FROM " + TableName + " LIMIT @limit"; var searchCommand = this._connection!.CreateCommand(); await using (searchCommand.ConfigureAwait(false)) { +#pragma warning disable CA2100 // SQL string uses only constants and table name - no user input searchCommand.CommandText = searchSql; +#pragma warning restore CA2100 searchCommand.Parameters.AddWithValue("@limit", limit); var results = new List(); diff --git a/src/Core/Search/SqliteVectorIndex.cs b/src/Core/Search/SqliteVectorIndex.cs new file mode 100644 index 000000000..b68ce4607 --- /dev/null +++ b/src/Core/Search/SqliteVectorIndex.cs @@ -0,0 +1,338 @@ +// Copyright (c) Microsoft. All rights reserved. +using KernelMemory.Core.Embeddings; +using Microsoft.Data.Sqlite; +using Microsoft.Extensions.Logging; + +namespace KernelMemory.Core.Search; + +/// +/// SQLite-based vector search index implementation. +/// Stores normalized vectors as BLOBs and performs K-NN search using dot product. +/// Optionally supports sqlite-vec extension for accelerated distance calculations. 
+/// +public sealed class SqliteVectorIndex : IVectorIndex, IDisposable +{ + private const string TableName = "km_vectors"; + private readonly string _connectionString; + private readonly int _configuredDimensions; + private readonly bool _useSqliteVec; + private readonly IEmbeddingGenerator _embeddingGenerator; + private readonly ILogger _logger; + private SqliteConnection? _connection; + private bool _dimensionsValidated; + private bool _disposed; + private bool _sqliteVecAvailable; + + /// + public int VectorDimensions => this._configuredDimensions; + + /// + /// Initializes a new instance of SqliteVectorIndex. + /// + /// Path to the SQLite database file. + /// Expected vector dimensions (validated on first use). + /// Whether to attempt loading sqlite-vec extension. + /// The embedding generator to use. + /// Logger instance. + public SqliteVectorIndex( + string dbPath, + int dimensions, + bool useSqliteVec, + IEmbeddingGenerator embeddingGenerator, + ILogger logger) + { + ArgumentNullException.ThrowIfNull(dbPath, nameof(dbPath)); + ArgumentNullException.ThrowIfNull(embeddingGenerator, nameof(embeddingGenerator)); + ArgumentNullException.ThrowIfNull(logger, nameof(logger)); + + if (dimensions <= 0) + { + throw new ArgumentOutOfRangeException(nameof(dimensions), "Dimensions must be positive"); + } + + this._connectionString = $"Data Source={dbPath}"; + this._configuredDimensions = dimensions; + this._useSqliteVec = useSqliteVec; + this._embeddingGenerator = embeddingGenerator; + this._logger = logger; + } + + /// + /// Ensures the database connection is open and tables exist. + /// + /// Cancellation token. 
+ public async Task InitializeAsync(CancellationToken cancellationToken = default) + { + if (this._connection != null) + { + return; + } + + this._connection = new SqliteConnection(this._connectionString); + await this._connection.OpenAsync(cancellationToken).ConfigureAwait(false); + + // Set synchronous=FULL to ensure writes are immediately persisted to disk + // This prevents data loss when connections are disposed quickly (CLI scenario) + using (var pragmaCmd = this._connection.CreateCommand()) + { + pragmaCmd.CommandText = "PRAGMA synchronous=FULL;"; + await pragmaCmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + } + + // Attempt to load sqlite-vec extension if configured + if (this._useSqliteVec) + { + this._sqliteVecAvailable = await this.TryLoadSqliteVecExtensionAsync(cancellationToken).ConfigureAwait(false); + if (!this._sqliteVecAvailable) + { + this._logger.LogWarning( + "sqlite-vec extension not found, using pure BLOB storage. " + + "For better performance with large datasets (>100K vectors), install sqlite-vec extension."); + } + } + + // Create vectors table if it doesn't exist + // Schema: content_id (primary key), vector (normalized float32 BLOB), created_at (timestamp) + var createTableSql = $""" + CREATE TABLE IF NOT EXISTS {TableName} ( + content_id TEXT PRIMARY KEY, + vector BLOB NOT NULL, + created_at TEXT NOT NULL + ); + """; + + var command = this._connection.CreateCommand(); + await using (command.ConfigureAwait(false)) + { + command.CommandText = createTableSql; + await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + } + + this._logger.LogDebug( + "SqliteVectorIndex initialized at {ConnectionString}, dimensions: {Dimensions}, sqlite-vec: {SqliteVec}", + this._connectionString, this._configuredDimensions, this._sqliteVecAvailable); + } + + /// + public async Task IndexAsync(string contentId, string text, CancellationToken cancellationToken = default) + { + 
ArgumentException.ThrowIfNullOrWhiteSpace(contentId, nameof(contentId)); + ArgumentNullException.ThrowIfNull(text, nameof(text)); + + await this.InitializeAsync(cancellationToken).ConfigureAwait(false); + + // Generate embedding + this._logger.LogDebug("Generating embedding for content {ContentId}", contentId); + var result = await this._embeddingGenerator.GenerateAsync(text, cancellationToken).ConfigureAwait(false); + var embedding = result.Vector; + + // Validate dimensions on first use (lazy validation) + if (!this._dimensionsValidated) + { + if (embedding.Length != this._configuredDimensions) + { + throw new InvalidOperationException( + $"Embedding model returned {embedding.Length} dimensions but config specifies {this._configuredDimensions}. " + + "Update config dimensions to match model output."); + } + + this._dimensionsValidated = true; + this._logger.LogDebug("Dimensions validated: {Dimensions}", this._configuredDimensions); + } + + // Normalize vector at write time (magnitude = 1) + var normalizedVector = VectorMath.NormalizeVector(embedding); + + // Serialize to BLOB (float32 array -> bytes) + var vectorBlob = VectorMath.VectorToBlob(normalizedVector); + + // Remove existing entry first (upsert semantics) + await this.RemoveAsync(contentId, cancellationToken).ConfigureAwait(false); + + // Insert new entry + var insertSql = $"INSERT INTO {TableName}(content_id, vector, created_at) VALUES (@contentId, @vector, @createdAt)"; + + var insertCommand = this._connection!.CreateCommand(); + await using (insertCommand.ConfigureAwait(false)) + { + insertCommand.CommandText = insertSql; + insertCommand.Parameters.AddWithValue("@contentId", contentId); + insertCommand.Parameters.AddWithValue("@vector", vectorBlob); + insertCommand.Parameters.AddWithValue("@createdAt", DateTimeOffset.UtcNow.ToString("o")); + await insertCommand.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + } + + this._logger.LogDebug("Indexed vector for content {ContentId}, 
dimensions: {Dimensions}", contentId, embedding.Length); + } + + /// + public async Task> SearchAsync(string queryText, int limit = 10, CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(queryText, nameof(queryText)); + + if (string.IsNullOrWhiteSpace(queryText)) + { + return []; + } + + await this.InitializeAsync(cancellationToken).ConfigureAwait(false); + + // Generate query embedding + this._logger.LogDebug("Generating query embedding for vector search"); + var queryResult = await this._embeddingGenerator.GenerateAsync(queryText, cancellationToken).ConfigureAwait(false); + var queryEmbedding = queryResult.Vector; + + // Validate dimensions + if (queryEmbedding.Length != this._configuredDimensions) + { + throw new InvalidOperationException( + $"Query embedding has {queryEmbedding.Length} dimensions but index expects {this._configuredDimensions}"); + } + + // Normalize query vector + var normalizedQuery = VectorMath.NormalizeVector(queryEmbedding); + + // Retrieve all vectors and compute dot product (linear scan K-NN) + // For large datasets, sqlite-vec would provide optimized search, but we fall back to C# implementation + var selectSql = $"SELECT content_id, vector FROM {TableName}"; + + var selectCommand = this._connection!.CreateCommand(); + await using (selectCommand.ConfigureAwait(false)) + { + selectCommand.CommandText = selectSql; + + var matches = new List<(string ContentId, double Score)>(); + + var reader = await selectCommand.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + await using (reader.ConfigureAwait(false)) + { + while (await reader.ReadAsync(cancellationToken).ConfigureAwait(false)) + { + var contentId = reader.GetString(0); + var vectorBlob = (byte[])reader.GetValue(1); + var storedVector = VectorMath.BlobToVector(vectorBlob); + + // Compute dot product (for normalized vectors, this equals cosine similarity) + var score = VectorMath.DotProduct(normalizedQuery, storedVector); + 
matches.Add((contentId, score)); + } + } + + // Sort by score descending (highest similarity first) and take top N + var results = matches + .OrderByDescending(m => m.Score) + .Take(limit) + .Select(m => new VectorMatch { ContentId = m.ContentId, Score = m.Score }) + .ToList(); + + this._logger.LogDebug("Vector search returned {Count} results from {Total} vectors", results.Count, matches.Count); + return results; + } + } + + /// + public async Task RemoveAsync(string contentId, CancellationToken cancellationToken = default) + { + await this.InitializeAsync(cancellationToken).ConfigureAwait(false); + + var deleteSql = $"DELETE FROM {TableName} WHERE content_id = @contentId"; + + var deleteCommand = this._connection!.CreateCommand(); + await using (deleteCommand.ConfigureAwait(false)) + { + deleteCommand.CommandText = deleteSql; + deleteCommand.Parameters.AddWithValue("@contentId", contentId); + var rowsAffected = await deleteCommand.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + + if (rowsAffected > 0) + { + this._logger.LogDebug("Removed vector for content {ContentId}", contentId); + } + } + } + + /// + public async Task ClearAsync(CancellationToken cancellationToken = default) + { + await this.InitializeAsync(cancellationToken).ConfigureAwait(false); + + var deleteSql = $"DELETE FROM {TableName}"; + + var clearCommand = this._connection!.CreateCommand(); + await using (clearCommand.ConfigureAwait(false)) + { + clearCommand.CommandText = deleteSql; + await clearCommand.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + } + + this._logger.LogInformation("Cleared all vectors from vector index"); + } + + /// + /// Disposes the database connection. + /// Ensures all pending writes are flushed to disk before closing. 
+ /// + public void Dispose() + { + if (this._disposed) + { + return; + } + + // Flush any pending writes before closing the connection + // SQLite needs explicit close to ensure writes are persisted + if (this._connection != null) + { + try + { + // Execute a checkpoint to flush WAL to disk (if WAL mode is enabled) + using var cmd = this._connection.CreateCommand(); + cmd.CommandText = "PRAGMA wal_checkpoint(TRUNCATE);"; + cmd.ExecuteNonQuery(); + } + catch (SqliteException ex) + { + this._logger.LogWarning(ex, "Failed to checkpoint WAL during vector index disposal"); + } + catch (InvalidOperationException ex) + { + this._logger.LogWarning(ex, "Failed to checkpoint WAL during vector index disposal - connection in invalid state"); + } + + this._connection.Close(); + this._connection.Dispose(); + this._connection = null; + } + + this._disposed = true; + } + + /// + /// Attempts to load the sqlite-vec extension for accelerated vector operations. + /// + /// True if extension loaded successfully, false otherwise. 
+ private async Task TryLoadSqliteVecExtensionAsync(CancellationToken cancellationToken) + { + try + { + // sqlite-vec extension name varies by platform + // Linux: vec0, Windows: vec0.dll, macOS: vec0.dylib + using var cmd = this._connection!.CreateCommand(); + cmd.CommandText = "SELECT load_extension('vec0')"; + await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + + this._logger.LogInformation("sqlite-vec extension loaded successfully"); + return true; + } + catch (SqliteException ex) when (ex.Message.Contains("not authorized") || ex.Message.Contains("cannot open")) + { + this._logger.LogDebug(ex, "sqlite-vec extension not available: {Message}", ex.Message); + return false; + } + catch (SqliteException ex) + { + this._logger.LogDebug(ex, "Failed to load sqlite-vec extension: {Message}", ex.Message); + return false; + } + } +} diff --git a/src/Core/Search/VectorMatch.cs b/src/Core/Search/VectorMatch.cs new file mode 100644 index 000000000..899487295 --- /dev/null +++ b/src/Core/Search/VectorMatch.cs @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft. All rights reserved. +namespace KernelMemory.Core.Search; + +/// +/// Represents a match from a vector similarity search. +/// Score is a dot product of normalized vectors (range -1 to 1, where 1 is most similar). +/// +public sealed class VectorMatch +{ + /// + /// The content ID that matched the search query. + /// + public required string ContentId { get; init; } + + /// + /// The similarity score (dot product of normalized vectors). + /// Range: -1 to 1, where 1 indicates highest similarity. + /// + public required double Score { get; init; } +} diff --git a/src/Core/Search/VectorMath.cs b/src/Core/Search/VectorMath.cs new file mode 100644 index 000000000..2d2934b9d --- /dev/null +++ b/src/Core/Search/VectorMath.cs @@ -0,0 +1,127 @@ +// Copyright (c) Microsoft. All rights reserved. 
+using System.Buffers.Binary; + +namespace KernelMemory.Core.Search; + +/// +/// Static utility class for vector mathematics operations. +/// Provides normalization, distance calculations, and serialization for vector search. +/// +public static class VectorMath +{ + /// + /// Normalizes a vector to unit length (magnitude = 1). + /// Normalized vectors allow dot product to be used for cosine similarity. + /// + /// The vector to normalize. + /// A new normalized vector. + /// If the vector is zero-length or empty. + /// If vector is null. + public static float[] NormalizeVector(float[] vector) + { + ArgumentNullException.ThrowIfNull(vector, nameof(vector)); + + if (vector.Length == 0) + { + throw new ArgumentException("Cannot normalize empty vector", nameof(vector)); + } + + // Calculate magnitude (L2 norm) + double sumOfSquares = 0.0; + for (int i = 0; i < vector.Length; i++) + { + sumOfSquares += vector[i] * (double)vector[i]; + } + + var magnitude = Math.Sqrt(sumOfSquares); + + if (magnitude < double.Epsilon) + { + throw new ArgumentException("Cannot normalize zero vector", nameof(vector)); + } + + // Create normalized vector + var normalized = new float[vector.Length]; + var magnitudeF = (float)magnitude; + for (int i = 0; i < vector.Length; i++) + { + normalized[i] = vector[i] / magnitudeF; + } + + return normalized; + } + + /// + /// Computes dot product of two vectors. + /// For normalized vectors, this equals cosine similarity. + /// + /// First vector. + /// Second vector. + /// Dot product value (range -1 to 1 for normalized vectors). + /// If vectors have different lengths. + /// If either vector is null. 
+ public static double DotProduct(float[] a, float[] b) + { + ArgumentNullException.ThrowIfNull(a, nameof(a)); + ArgumentNullException.ThrowIfNull(b, nameof(b)); + + if (a.Length != b.Length) + { + throw new ArgumentException($"Vectors must have same length: {a.Length} vs {b.Length}"); + } + + double sum = 0.0; + for (int i = 0; i < a.Length; i++) + { + sum += a[i] * (double)b[i]; + } + + return sum; + } + + /// + /// Serializes a float32 vector to a byte array (BLOB). + /// Uses little-endian format for cross-platform compatibility. + /// + /// The vector to serialize. + /// Byte array representation. + /// If vector is null. + public static byte[] VectorToBlob(float[] vector) + { + ArgumentNullException.ThrowIfNull(vector, nameof(vector)); + + var blob = new byte[vector.Length * sizeof(float)]; + for (int i = 0; i < vector.Length; i++) + { + BinaryPrimitives.WriteSingleLittleEndian(blob.AsSpan(i * sizeof(float)), vector[i]); + } + + return blob; + } + + /// + /// Deserializes a byte array (BLOB) to a float32 vector. + /// Expects little-endian format. + /// + /// The byte array to deserialize. + /// Float array representation. + /// If blob is null. + /// If blob length is not divisible by sizeof(float). 
+ public static float[] BlobToVector(byte[] blob) + { + ArgumentNullException.ThrowIfNull(blob, nameof(blob)); + + if (blob.Length % sizeof(float) != 0) + { + throw new ArgumentException($"BLOB length {blob.Length} is not divisible by sizeof(float)", nameof(blob)); + } + + var vector = new float[blob.Length / sizeof(float)]; + for (int i = 0; i < vector.Length; i++) + { + vector[i] = BinaryPrimitives.ReadSingleLittleEndian(blob.AsSpan(i * sizeof(float))); + } + + return vector; + } +} diff --git a/src/Core/Storage/ContentStorageDbContext.cs b/src/Core/Storage/ContentStorageDbContext.cs index 39a8974b1..e1091b447 100644 --- a/src/Core/Storage/ContentStorageDbContext.cs +++ b/src/Core/Storage/ContentStorageDbContext.cs @@ -41,7 +41,7 @@ protected override void OnModelCreating(ModelBuilder modelBuilder) entity.Property(e => e.MimeType) .IsRequired() - .HasMaxLength(255); + .HasMaxLength(Constants.Database.MaxMimeTypeLength); entity.Property(e => e.ByteSize) .IsRequired(); diff --git a/src/Core/Storage/ContentStorageService.cs b/src/Core/Storage/ContentStorageService.cs index 123de8a84..e1839acf0 100644 --- a/src/Core/Storage/ContentStorageService.cs +++ b/src/Core/Storage/ContentStorageService.cs @@ -72,18 +72,29 @@ public async Task UpsertAsync(UpsertRequest request, CancellationTo // Phase 1: Queue the operation (MUST succeed - throws if fails) var operationId = await this.QueueUpsertOperationAsync(contentId, request, cancellationToken).ConfigureAwait(false); - this._logger.LogDebug("Phase 1 complete: Operation {OperationId} queued for content {ContentId}", operationId, contentId); + this._logger.LogInformation("Upsert queued successfully - ContentId: {ContentId}, OperationId: {OperationId}, MimeType: {MimeType}, Size: {ByteSize} bytes", + contentId, operationId, request.MimeType, request.Content.Length); // Phase 2: Try to cancel superseded operations (best effort) try { - await this.TryCancelSupersededUpsertOperationsAsync(contentId, operationId, 
cancellationToken).ConfigureAwait(false); - this._logger.LogDebug("Phase 2 complete: Cancelled superseded operations for content {ContentId}", contentId); + var cancelledCount = await this.TryCancelSupersededUpsertOperationsAsync(contentId, operationId, cancellationToken).ConfigureAwait(false); + if (cancelledCount > 0) + { + this._logger.LogInformation("Superseded {CancelledCount} older operation(s) for ContentId: {ContentId}, keeping latest OperationId: {OperationId}", + cancelledCount, contentId, operationId); + } + else + { + this._logger.LogDebug("No superseded operations to cancel for ContentId: {ContentId}, OperationId: {OperationId}", + contentId, operationId); + } } catch (Exception ex) { // Best effort - log but don't fail - this._logger.LogWarning(ex, "Phase 2 failed to cancel superseded operations for content {ContentId} - continuing anyway", contentId); + this._logger.LogWarning(ex, "Failed to cancel superseded operations for ContentId: {ContentId}, OperationId: {OperationId} - continuing anyway", + contentId, operationId); } // Processing: Try to process the new operation synchronously @@ -114,18 +125,29 @@ public async Task DeleteAsync(string id, CancellationToken cancella // Phase 1: Queue the operation (MUST succeed - throws if fails) var operationId = await this.QueueDeleteOperationAsync(id, cancellationToken).ConfigureAwait(false); - this._logger.LogDebug("Phase 1 complete: Operation {OperationId} queued for content {ContentId}", operationId, id); + this._logger.LogInformation("Delete queued successfully - ContentId: {ContentId}, OperationId: {OperationId}", + id, operationId); // Phase 2: Try to cancel ALL previous operations (best effort) try { - await this.TryCancelAllOperationsAsync(id, operationId, cancellationToken).ConfigureAwait(false); - this._logger.LogDebug("Phase 2 complete: Cancelled all previous operations for content {ContentId}", id); + var cancelledCount = await this.TryCancelAllOperationsAsync(id, operationId, 
cancellationToken).ConfigureAwait(false); + if (cancelledCount > 0) + { + this._logger.LogInformation("Cancelled {CancelledCount} previous operation(s) for ContentId: {ContentId}, proceeding with delete OperationId: {OperationId}", + cancelledCount, id, operationId); + } + else + { + this._logger.LogDebug("No previous operations to cancel for ContentId: {ContentId}, OperationId: {OperationId}", + id, operationId); + } } catch (Exception ex) { // Best effort - log but don't fail - this._logger.LogWarning(ex, "Phase 2 failed to cancel previous operations for content {ContentId} - continuing anyway", id); + this._logger.LogWarning(ex, "Failed to cancel previous operations for ContentId: {ContentId}, OperationId: {OperationId} - continuing anyway", + id, operationId); } // Processing: Try to process the new operation synchronously @@ -305,7 +327,7 @@ private async Task QueueDeleteOperationAsync(string contentId, Cancellat /// /// /// - private async Task TryCancelSupersededUpsertOperationsAsync(string contentId, string newOperationId, CancellationToken cancellationToken) + private async Task TryCancelSupersededUpsertOperationsAsync(string contentId, string newOperationId, CancellationToken cancellationToken) { // Find incomplete operations with same ContentId and older Timestamp // Exclude Delete operations (they must complete) @@ -332,6 +354,8 @@ private async Task TryCancelSupersededUpsertOperationsAsync(string contentId, st { await this._context.SaveChangesAsync(cancellationToken).ConfigureAwait(false); } + + return superseded.Count; } /// @@ -341,7 +365,7 @@ private async Task TryCancelSupersededUpsertOperationsAsync(string contentId, st /// /// /// - private async Task TryCancelAllOperationsAsync(string contentId, string newOperationId, CancellationToken cancellationToken) + private async Task TryCancelAllOperationsAsync(string contentId, string newOperationId, CancellationToken cancellationToken) { // Find incomplete operations with same ContentId and older 
Timestamp var timestamp = await this._context.Operations @@ -366,6 +390,8 @@ private async Task TryCancelAllOperationsAsync(string contentId, string newOpera { await this._context.SaveChangesAsync(cancellationToken).ConfigureAwait(false); } + + return superseded.Count; } // ========== Processing: Execute Operations ========== diff --git a/src/Directory.Packages.props b/src/Directory.Packages.props index 9a36b4a1e..2d1053e60 100644 --- a/src/Directory.Packages.props +++ b/src/Directory.Packages.props @@ -7,15 +7,15 @@ - - + + - - - - - + + + + + diff --git a/src/Main/CLI/CliApplicationBuilder.cs b/src/Main/CLI/CliApplicationBuilder.cs index ef2c68e03..fc494a6b7 100644 --- a/src/Main/CLI/CliApplicationBuilder.cs +++ b/src/Main/CLI/CliApplicationBuilder.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using KernelMemory.Core; using KernelMemory.Core.Config; using KernelMemory.Core.Logging; using KernelMemory.Main.CLI.Commands; @@ -41,6 +42,8 @@ public sealed class CliApplicationBuilder private static readonly string[] s_searchExample4 = new[] { "search", "{\"content\": \"kubernetes\"}", "--format", "json" }; private static readonly string[] s_examplesExample1 = new[] { "examples" }; private static readonly string[] s_examplesExample2 = new[] { "examples", "--command", "search" }; + private static readonly string[] s_doctorExample1 = new[] { "doctor" }; + private static readonly string[] s_doctorExample2 = new[] { "doctor", "-f", "json" }; /// /// Creates and configures a CommandApp with all CLI commands. @@ -71,6 +74,18 @@ public CommandApp Build(string[]? args = null) // 4. Create logger factory using Serilog ILoggerFactory loggerFactory = SerilogFactory.CreateLoggerFactory(loggingConfig); + var bootstrapLogger = loggerFactory.CreateLogger(); + var commandName = actualArgs.Length > 0 ? actualArgs[0] : ""; + bootstrapLogger.LogInformation( + "km CLI starting. 
Command={CommandName}, ConfigPath={ConfigPath}, LogFile={LogFile}", + commandName, + configPath, + loggingConfig.FilePath ?? "(not set)"); + + if (actualArgs.Length > 1) + { + bootstrapLogger.LogDebug("km CLI args: {Args}", string.Join(" ", actualArgs)); + } // 5. Create DI container and register services ServiceCollection services = new(); @@ -109,8 +124,8 @@ private string DetermineConfigPath(string[] args) // Default: ~/.km/config.json return Path.Combine( Environment.GetFolderPath(Environment.SpecialFolder.UserProfile), - Constants.DefaultConfigDirName, - Constants.DefaultConfigFileName); + Constants.ConfigDefaults.DefaultConfigDirName, + Constants.ConfigDefaults.DefaultConfigFileName); } /// @@ -241,7 +256,13 @@ public void Configure(CommandApp app) .WithExample(s_examplesExample1) .WithExample(s_examplesExample2); + // Doctor command + config.AddCommand("doctor") + .WithDescription("Validate configuration and check system health") + .WithExample(s_doctorExample1) + .WithExample(s_doctorExample2); + config.ValidateExamples(); }); } -} \ No newline at end of file +} diff --git a/src/Main/CLI/Commands/BaseCommand.cs b/src/Main/CLI/Commands/BaseCommand.cs index e2c393215..3a360e03a 100644 --- a/src/Main/CLI/Commands/BaseCommand.cs +++ b/src/Main/CLI/Commands/BaseCommand.cs @@ -1,8 +1,11 @@ // Copyright (c) Microsoft. All rights reserved. 
using System.Diagnostics.CodeAnalysis; +using KernelMemory.Core; using KernelMemory.Core.Config; using KernelMemory.Core.Config.ContentIndex; +using KernelMemory.Core.Config.Enums; +using KernelMemory.Core.Embeddings.Cache; using KernelMemory.Core.Storage; using KernelMemory.Main.CLI.Exceptions; using KernelMemory.Main.CLI.OutputFormatters; @@ -130,9 +133,34 @@ protected ContentService CreateContentService(NodeConfig node, bool readonlyMode // Create dependencies var cuidGenerator = new CuidGenerator(); var logger = this._loggerFactory.CreateLogger(); + var httpClient = new HttpClient(); - // Create search indexes from node configuration using injected logger factory - var searchIndexes = SearchIndexFactory.CreateIndexes(node.SearchIndexes, this._loggerFactory); + // Create embedding cache if configured + IEmbeddingCache? embeddingCache = null; + if (this._config.EmbeddingsCache != null) + { + var cachePath = this._config.EmbeddingsCache.Path + ?? throw new InvalidOperationException("Embeddings cache path is required"); + var cacheLogger = this._loggerFactory.CreateLogger(); + + // Determine cache mode from allowRead/allowWrite flags + var cacheMode = (this._config.EmbeddingsCache.AllowRead, this._config.EmbeddingsCache.AllowWrite) switch + { + (true, true) => CacheModes.ReadWrite, + (true, false) => CacheModes.ReadOnly, + (false, true) => CacheModes.WriteOnly, + (false, false) => throw new InvalidOperationException("Embeddings cache must allow at least read or write") + }; + + embeddingCache = new SqliteEmbeddingCache(cachePath, cacheMode, cacheLogger); + } + + // Create all search indexes from node configuration + var searchIndexes = SearchIndexFactory.CreateIndexes( + node.SearchIndexes, + httpClient, + embeddingCache, + this._loggerFactory); // Create storage service with search indexes var storage = new ContentStorageService(context, cuidGenerator, logger, searchIndexes); @@ -154,10 +182,10 @@ protected int HandleError(Exception ex, IOutputFormatter formatter) 
// User errors: InvalidOperationException, ArgumentException if (ex is InvalidOperationException or ArgumentException) { - return Constants.ExitCodeUserError; + return Constants.App.ExitCodeUserError; } // System errors: everything else - return Constants.ExitCodeSystemError; + return Constants.App.ExitCodeSystemError; } } diff --git a/src/Main/CLI/Commands/ConfigCommand.cs b/src/Main/CLI/Commands/ConfigCommand.cs index f651705e2..a5c3821f2 100644 --- a/src/Main/CLI/Commands/ConfigCommand.cs +++ b/src/Main/CLI/Commands/ConfigCommand.cs @@ -4,6 +4,7 @@ using System.Diagnostics.CodeAnalysis; using System.Text.Json; using System.Text.Json.Serialization; +using KernelMemory.Core; using KernelMemory.Core.Config; using KernelMemory.Main.CLI.Infrastructure; using KernelMemory.Main.CLI.Models; @@ -62,9 +63,10 @@ public ConfigCommand( [SuppressMessage("Design", "CA1031:Do not catch general exception types", Justification = "Top-level command handler must catch all exceptions to return appropriate exit codes and error messages")] - public async Task ExecuteAsync( + public override async Task ExecuteAsync( CommandContext context, - ConfigCommandSettings settings) + ConfigCommandSettings settings, + CancellationToken cancellationToken) { try { @@ -131,7 +133,7 @@ public async Task ExecuteAsync( formatter.Format(output); - return Constants.ExitCodeSuccess; + return Constants.App.ExitCodeSuccess; } catch (Exception ex) { @@ -140,11 +142,6 @@ public async Task ExecuteAsync( } } - public override Task ExecuteAsync(CommandContext context, ConfigCommandSettings settings, CancellationToken cancellationToken) - { - throw new NotImplementedException(); - } - /// /// Handles the --create flag to write the configuration to disk. 
/// @@ -162,7 +159,7 @@ private int HandleCreateConfig(string configPath, bool configFileExists, IOutput if (configFileExists) { formatter.FormatError($"Configuration file already exists: {configPath}"); - return Constants.ExitCodeUserError; + return Constants.App.ExitCodeUserError; } // Ensure directory exists @@ -180,12 +177,12 @@ private int HandleCreateConfig(string configPath, bool configFileExists, IOutput formatter.Format(new { Message = $"Configuration file created: {configPath}" }); - return Constants.ExitCodeSuccess; + return Constants.App.ExitCodeSuccess; } catch (Exception ex) { formatter.FormatError($"Failed to create configuration file: {ex.Message}"); - return Constants.ExitCodeSystemError; + return Constants.App.ExitCodeSystemError; } } } diff --git a/src/Main/CLI/Commands/DeleteCommand.cs b/src/Main/CLI/Commands/DeleteCommand.cs index 35ac6ac0b..721d7f843 100644 --- a/src/Main/CLI/Commands/DeleteCommand.cs +++ b/src/Main/CLI/Commands/DeleteCommand.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System.ComponentModel; +using KernelMemory.Core; using KernelMemory.Core.Config; using KernelMemory.Main.CLI.OutputFormatters; using Microsoft.Extensions.Logging; @@ -78,7 +79,7 @@ public override async Task ExecuteAsync( }); } - return Constants.ExitCodeSuccess; + return Constants.App.ExitCodeSuccess; } catch (Exception ex) { diff --git a/src/Main/CLI/Commands/DoctorCommand.cs b/src/Main/CLI/Commands/DoctorCommand.cs new file mode 100644 index 000000000..9f167aaa3 --- /dev/null +++ b/src/Main/CLI/Commands/DoctorCommand.cs @@ -0,0 +1,959 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System.Diagnostics.CodeAnalysis; +using System.Net.Http.Json; +using System.Text.Json; +using System.Text.Json.Serialization; +using KernelMemory.Core; +using KernelMemory.Core.Config; +using KernelMemory.Core.Config.Cache; +using KernelMemory.Core.Config.ContentIndex; +using KernelMemory.Core.Config.Embeddings; +using KernelMemory.Core.Config.Enums; +using KernelMemory.Core.Config.SearchIndex; +using KernelMemory.Main.CLI.Infrastructure; +using Microsoft.Extensions.Logging; +using Spectre.Console; +using Spectre.Console.Cli; + +namespace KernelMemory.Main.CLI.Commands; + +/// +/// Diagnostic levels for health checks. +/// OK = working, Warning = will work but suboptimal, Error = will not work. +/// +public enum DiagnosticLevels +{ + OK, + Warning, + Error +} + +/// +/// Result of a single diagnostic check. +/// Includes component name, status, message, and optional node association. +/// +public sealed record DiagnosticResult +{ + /// + /// Name of the component being checked (e.g., "Config file", "Content index"). + /// + public required string Component { get; init; } + + /// + /// Diagnostic level indicating severity. + /// + public required DiagnosticLevels Level { get; init; } + + /// + /// Human-readable description of the check result. + /// + public required string Message { get; init; } + + /// + /// Node ID this check belongs to, or null for global checks. + /// Used for grouping output by node. + /// + public string? NodeId { get; init; } +} + +/// +/// Command to validate configuration and check system health. +/// Checks config file, content indexes, search indexes (FTS/vector), and caches. +/// Groups output by node for clarity when multiple nodes are configured. 
+/// +public sealed class DoctorCommand : AsyncCommand, IDisposable +{ + private static readonly JsonSerializerOptions s_jsonOptions = new() + { + WriteIndented = true, + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull + }; + + private readonly AppConfig _config; + private readonly ILoggerFactory _loggerFactory; + private readonly ILogger _logger; + private readonly ConfigPathService _configPathService; + private readonly HttpClient _httpClient; + + /// + /// Initializes a new instance of the class. + /// + /// Application configuration (injected by DI). + /// Logger factory for creating loggers (injected by DI). + /// Service providing the config file path (injected by DI). + public DoctorCommand( + AppConfig config, + ILoggerFactory loggerFactory, + ConfigPathService configPathService) + { + this._config = config ?? throw new ArgumentNullException(nameof(config)); + this._loggerFactory = loggerFactory ?? throw new ArgumentNullException(nameof(loggerFactory)); + this._logger = loggerFactory.CreateLogger(); + this._configPathService = configPathService ?? throw new ArgumentNullException(nameof(configPathService)); + this._httpClient = new HttpClient { Timeout = TimeSpan.FromSeconds(10) }; + } + + /// + /// Constructor for testing without ConfigPathService. + /// Uses a default config path for diagnostic purposes. + /// + /// Application configuration. + /// Logger factory for creating loggers. + internal DoctorCommand(AppConfig config, ILoggerFactory loggerFactory) + : this(config, loggerFactory, new ConfigPathService(GetDefaultConfigPath())) + { + } + + /// + /// Disposes the HTTP client. 
+ /// + public void Dispose() + { + this._httpClient.Dispose(); + } + + [SuppressMessage("Design", "CA1031:Do not catch general exception types", + Justification = "Top-level command handler must catch all exceptions to return appropriate exit codes")] + public override async Task ExecuteAsync( + CommandContext context, + DoctorCommandSettings settings, + CancellationToken cancellationToken) + { + var results = new List(); + + try + { + // Global check: Config file + results.Add(this.CheckConfigFile()); + + // Per-node checks + foreach (var (nodeId, nodeConfig) in this._config.Nodes) + { + // Content index check + results.Add(this.CheckContentIndex(nodeId, nodeConfig.ContentIndex)); + + // Search index checks + foreach (var searchIndex in nodeConfig.SearchIndexes) + { + var indexResult = await this.CheckSearchIndexAsync(nodeId, searchIndex, cancellationToken) + .ConfigureAwait(false); + results.Add(indexResult); + } + } + + // Global check: Embeddings cache + if (this._config.EmbeddingsCache != null) + { + results.Add(this.CheckCache("Embeddings cache", this._config.EmbeddingsCache)); + } + + // Global check: LLM cache + if (this._config.LLMCache != null) + { + results.Add(this.CheckCache("LLM cache", this._config.LLMCache)); + } + } + catch (Exception ex) + { + this._logger.LogError(ex, "Unexpected error during doctor checks"); + results.Add(new DiagnosticResult + { + Component = "Doctor command", + Level = DiagnosticLevels.Error, + Message = $"Unexpected error: {ex.Message}" + }); + } + + // Display results + this.DisplayResults(results, settings); + + // Return appropriate exit code + var hasErrors = results.Any(r => r.Level == DiagnosticLevels.Error); + return hasErrors ? Constants.App.ExitCodeUserError : Constants.App.ExitCodeSuccess; + } + + /// + /// Checks the configuration file accessibility. 
+    /// </summary>
+    private DiagnosticResult CheckConfigFile()
+    {
+        var configPath = this._configPathService.Path;
+
+        if (!File.Exists(configPath))
+        {
+            return new DiagnosticResult
+            {
+                Component = "Config file",
+                Level = DiagnosticLevels.Warning,
+                Message = $"Using default configuration, file does not exist: {configPath}"
+            };
+        }
+
+        try
+        {
+            var fileInfo = new FileInfo(configPath);
+
+            // Opening the file is the actual read-access test: File.OpenRead
+            // throws on failure, and a stream it returns always reports
+            // CanRead == true, so no separate post-open check is needed.
+            using var stream = File.OpenRead(configPath);
+
+            return new DiagnosticResult
+            {
+                Component = "Config file",
+                Level = DiagnosticLevels.OK,
+                Message = $"{configPath} readable ({fileInfo.Length} bytes)"
+            };
+        }
+        catch (UnauthorizedAccessException)
+        {
+            return new DiagnosticResult
+            {
+                Component = "Config file",
+                Level = DiagnosticLevels.Error,
+                Message = $"Permission denied reading config file: {configPath}"
+            };
+        }
+        catch (IOException ex)
+        {
+            return new DiagnosticResult
+            {
+                Component = "Config file",
+                Level = DiagnosticLevels.Error,
+                Message = $"Error reading config file: {ex.Message}"
+            };
+        }
+    }
+
+    /// <summary>
+    /// Checks the content index (SQLite database) accessibility.
+ /// + private DiagnosticResult CheckContentIndex(string nodeId, ContentIndexConfig config) + { + if (config is not SqliteContentIndexConfig sqliteConfig) + { + return new DiagnosticResult + { + Component = "Content index", + Level = DiagnosticLevels.Error, + Message = $"Unsupported content index type: {config.Type}", + NodeId = nodeId + }; + } + + var dbPath = sqliteConfig.Path; + var dirPath = Path.GetDirectoryName(dbPath); + + if (File.Exists(dbPath)) + { + // Database exists - test read/write access + try + { + using var stream = File.Open(dbPath, FileMode.Open, FileAccess.ReadWrite); + return new DiagnosticResult + { + Component = "Content index", + Level = DiagnosticLevels.OK, + Message = $"Content database readable at {dbPath}", + NodeId = nodeId + }; + } + catch (UnauthorizedAccessException) + { + return new DiagnosticResult + { + Component = "Content index", + Level = DiagnosticLevels.Error, + Message = $"Permission denied accessing database: {dbPath}", + NodeId = nodeId + }; + } + catch (IOException ex) + { + return new DiagnosticResult + { + Component = "Content index", + Level = DiagnosticLevels.Error, + Message = $"Error accessing database: {ex.Message}", + NodeId = nodeId + }; + } + } + + // Database doesn't exist - check if directory is writable + if (!string.IsNullOrEmpty(dirPath)) + { + if (!Directory.Exists(dirPath)) + { + // Try to create the directory to test permissions + try + { + Directory.CreateDirectory(dirPath); + return new DiagnosticResult + { + Component = "Content index", + Level = DiagnosticLevels.Warning, + Message = $"Directory writable, content database will be created at {dbPath}", + NodeId = nodeId + }; + } + catch (UnauthorizedAccessException) + { + return new DiagnosticResult + { + Component = "Content index", + Level = DiagnosticLevels.Error, + Message = $"Permission denied creating directory: {dirPath}", + NodeId = nodeId + }; + } + catch (IOException ex) + { + return new DiagnosticResult + { + Component = "Content index", + 
Level = DiagnosticLevels.Error, + Message = $"Error creating directory: {ex.Message}", + NodeId = nodeId + }; + } + } + + // Directory exists - test write permissions by creating a temp file + var canWrite = this.CanWriteToDirectory(dirPath); + if (canWrite) + { + return new DiagnosticResult + { + Component = "Content index", + Level = DiagnosticLevels.Warning, + Message = $"Directory writable, content database will be created at {dbPath}", + NodeId = nodeId + }; + } + + return new DiagnosticResult + { + Component = "Content index", + Level = DiagnosticLevels.Error, + Message = $"Directory not writable: {dirPath}", + NodeId = nodeId + }; + } + + return new DiagnosticResult + { + Component = "Content index", + Level = DiagnosticLevels.Error, + Message = "Invalid database path configuration", + NodeId = nodeId + }; + } + + /// + /// Checks a search index configuration and connectivity. + /// + private async Task CheckSearchIndexAsync( + string nodeId, + SearchIndexConfig config, + CancellationToken cancellationToken) + { + return config switch + { + FtsSearchIndexConfig ftsConfig => this.CheckFtsIndex(nodeId, ftsConfig), + VectorSearchIndexConfig vectorConfig => await this.CheckVectorIndexAsync(nodeId, vectorConfig, cancellationToken) + .ConfigureAwait(false), + _ => new DiagnosticResult + { + Component = $"Search index '{config.Id}'", + Level = DiagnosticLevels.Warning, + Message = $"Unknown search index type: {config.GetType().Name}", + NodeId = nodeId + } + }; + } + + /// + /// Checks an FTS index (SQLite FTS5 database) accessibility. 
+ /// + private DiagnosticResult CheckFtsIndex(string nodeId, FtsSearchIndexConfig config) + { + var dbPath = config.Path; + if (string.IsNullOrEmpty(dbPath)) + { + return new DiagnosticResult + { + Component = $"FTS index '{config.Id}'", + Level = DiagnosticLevels.Error, + Message = "FTS index path not configured", + NodeId = nodeId + }; + } + + var dirPath = Path.GetDirectoryName(dbPath); + + if (File.Exists(dbPath)) + { + try + { + using var stream = File.Open(dbPath, FileMode.Open, FileAccess.ReadWrite); + return new DiagnosticResult + { + Component = $"FTS index '{config.Id}'", + Level = DiagnosticLevels.OK, + Message = $"FTS database readable at {dbPath}", + NodeId = nodeId + }; + } + catch (UnauthorizedAccessException) + { + return new DiagnosticResult + { + Component = $"FTS index '{config.Id}'", + Level = DiagnosticLevels.Error, + Message = $"Permission denied accessing FTS database: {dbPath}", + NodeId = nodeId + }; + } + catch (IOException ex) + { + return new DiagnosticResult + { + Component = $"FTS index '{config.Id}'", + Level = DiagnosticLevels.Error, + Message = $"Error accessing FTS database: {ex.Message}", + NodeId = nodeId + }; + } + } + + // Database doesn't exist - check directory write permissions + if (!string.IsNullOrEmpty(dirPath)) + { + if (!Directory.Exists(dirPath)) + { + try + { + Directory.CreateDirectory(dirPath); + return new DiagnosticResult + { + Component = $"FTS index '{config.Id}'", + Level = DiagnosticLevels.Warning, + Message = $"Directory writable, FTS database will be created at {dbPath}", + NodeId = nodeId + }; + } + catch (Exception ex) + { + return new DiagnosticResult + { + Component = $"FTS index '{config.Id}'", + Level = DiagnosticLevels.Error, + Message = $"Cannot create directory: {ex.Message}", + NodeId = nodeId + }; + } + } + + var canWrite = this.CanWriteToDirectory(dirPath); + if (canWrite) + { + return new DiagnosticResult + { + Component = $"FTS index '{config.Id}'", + Level = DiagnosticLevels.Warning, + 
Message = $"Directory writable, FTS database will be created at {dbPath}", + NodeId = nodeId + }; + } + + return new DiagnosticResult + { + Component = $"FTS index '{config.Id}'", + Level = DiagnosticLevels.Error, + Message = $"Directory not writable: {dirPath}", + NodeId = nodeId + }; + } + + return new DiagnosticResult + { + Component = $"FTS index '{config.Id}'", + Level = DiagnosticLevels.Error, + Message = "Invalid FTS database path configuration", + NodeId = nodeId + }; + } + + /// + /// Checks a vector index, including embeddings provider connectivity. + /// + private async Task CheckVectorIndexAsync( + string nodeId, + VectorSearchIndexConfig config, + CancellationToken cancellationToken) + { + // First check database accessibility + var dbPath = config.Path; + if (string.IsNullOrEmpty(dbPath)) + { + return new DiagnosticResult + { + Component = $"Vector index '{config.Id}'", + Level = DiagnosticLevels.Error, + Message = "Vector index path not configured", + NodeId = nodeId + }; + } + + // Check embeddings provider if configured + if (config.Embeddings == null) + { + return new DiagnosticResult + { + Component = $"Vector index '{config.Id}'", + Level = DiagnosticLevels.Error, + Message = "No embeddings provider configured for vector index", + NodeId = nodeId + }; + } + + return config.Embeddings switch + { + OllamaEmbeddingsConfig ollamaConfig => await this.CheckOllamaEmbeddingsAsync( + nodeId, config.Id, ollamaConfig, config.Dimensions, cancellationToken).ConfigureAwait(false), + OpenAIEmbeddingsConfig openAiConfig => this.CheckOpenAIEmbeddings(nodeId, config.Id, openAiConfig, config.Dimensions), + _ => new DiagnosticResult + { + Component = $"Vector index '{config.Id}'", + Level = DiagnosticLevels.Warning, + Message = $"Unsupported embeddings provider: {config.Embeddings.GetType().Name}", + NodeId = nodeId + } + }; + } + + /// + /// Checks Ollama embeddings provider by actually calling the API. 
+ /// + private async Task CheckOllamaEmbeddingsAsync( + string nodeId, + string indexId, + OllamaEmbeddingsConfig config, + int expectedDimensions, + CancellationToken cancellationToken) + { + try + { + // Actually test embedding generation with a POST request + var endpoint = $"{config.BaseUrl.TrimEnd('/')}/api/embed"; + var request = new { model = config.Model, input = "test" }; + + using var response = await this._httpClient.PostAsJsonAsync(endpoint, request, cancellationToken) + .ConfigureAwait(false); + + if (!response.IsSuccessStatusCode) + { + var errorContent = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false); + return new DiagnosticResult + { + Component = $"Vector index '{indexId}'", + Level = DiagnosticLevels.Error, + Message = $"Ollama API error ({response.StatusCode}): {errorContent.Substring(0, Math.Min(100, errorContent.Length))}", + NodeId = nodeId + }; + } + + // Parse response to verify dimensions + var responseJson = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false); + using var doc = JsonDocument.Parse(responseJson); + + if (doc.RootElement.TryGetProperty("embeddings", out var embeddingsArray) && + embeddingsArray.GetArrayLength() > 0) + { + var firstEmbedding = embeddingsArray[0]; + var actualDimensions = firstEmbedding.GetArrayLength(); + + if (actualDimensions != expectedDimensions) + { + return new DiagnosticResult + { + Component = $"Vector index '{indexId}'", + Level = DiagnosticLevels.Error, + Message = $"Dimension mismatch: model produces {actualDimensions}D, config expects {expectedDimensions}D", + NodeId = nodeId + }; + } + + return new DiagnosticResult + { + Component = $"Vector index '{indexId}'", + Level = DiagnosticLevels.OK, + Message = $"Ollama embeddings working ({config.Model}, {actualDimensions}D)", + NodeId = nodeId + }; + } + + return new DiagnosticResult + { + Component = $"Vector index '{indexId}'", + Level = DiagnosticLevels.Warning, + Message = "Ollama 
responded but could not verify dimensions", + NodeId = nodeId + }; + } + catch (HttpRequestException ex) + { + return new DiagnosticResult + { + Component = $"Vector index '{indexId}'", + Level = DiagnosticLevels.Error, + Message = $"Cannot connect to Ollama at {config.BaseUrl}: {ex.Message}", + NodeId = nodeId + }; + } + catch (TaskCanceledException) + { + return new DiagnosticResult + { + Component = $"Vector index '{indexId}'", + Level = DiagnosticLevels.Error, + Message = $"Timeout connecting to Ollama at {config.BaseUrl}", + NodeId = nodeId + }; + } + catch (JsonException ex) + { + return new DiagnosticResult + { + Component = $"Vector index '{indexId}'", + Level = DiagnosticLevels.Warning, + Message = $"Ollama responded but response parsing failed: {ex.Message}", + NodeId = nodeId + }; + } + } + + /// + /// Checks OpenAI embeddings configuration (API key presence, not connectivity). + /// + private DiagnosticResult CheckOpenAIEmbeddings( + string nodeId, + string indexId, + OpenAIEmbeddingsConfig config, + int expectedDimensions) + { + // Check if API key is configured + var apiKey = config.ApiKey; + + // Try environment variable if not set directly + if (string.IsNullOrWhiteSpace(apiKey)) + { + apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY"); + } + + if (string.IsNullOrWhiteSpace(apiKey)) + { + return new DiagnosticResult + { + Component = $"Vector index '{indexId}'", + Level = DiagnosticLevels.Error, + Message = "OpenAI API key not configured (set in config or OPENAI_API_KEY env var)", + NodeId = nodeId + }; + } + + // Verify dimensions match known model dimensions + if (Constants.EmbeddingDefaults.TryGetDimensions(config.Model, out var knownDimensions)) + { + if (knownDimensions != expectedDimensions) + { + return new DiagnosticResult + { + Component = $"Vector index '{indexId}'", + Level = DiagnosticLevels.Error, + Message = $"Dimension mismatch: {config.Model} produces {knownDimensions}D, config expects {expectedDimensions}D", + NodeId = 
nodeId + }; + } + } + + return new DiagnosticResult + { + Component = $"Vector index '{indexId}'", + Level = DiagnosticLevels.OK, + Message = $"OpenAI API key configured ({config.Model}, {expectedDimensions}D)", + NodeId = nodeId + }; + } + + /// + /// Checks cache configuration and accessibility. + /// + private DiagnosticResult CheckCache(string name, CacheConfig config) + { + if (config.Type != CacheTypes.Sqlite || string.IsNullOrEmpty(config.Path)) + { + return new DiagnosticResult + { + Component = name, + Level = DiagnosticLevels.Warning, + Message = "Non-SQLite cache or missing path" + }; + } + + var dbPath = config.Path; + var dirPath = Path.GetDirectoryName(dbPath); + + if (File.Exists(dbPath)) + { + try + { + using var stream = File.Open(dbPath, FileMode.Open, FileAccess.ReadWrite); + return new DiagnosticResult + { + Component = name, + Level = DiagnosticLevels.OK, + Message = $"Cache database readable at {dbPath}" + }; + } + catch (Exception ex) + { + return new DiagnosticResult + { + Component = name, + Level = DiagnosticLevels.Error, + Message = $"Error accessing cache: {ex.Message}" + }; + } + } + + // Database doesn't exist - check directory write permissions + if (!string.IsNullOrEmpty(dirPath)) + { + if (!Directory.Exists(dirPath)) + { + try + { + Directory.CreateDirectory(dirPath); + return new DiagnosticResult + { + Component = name, + Level = DiagnosticLevels.Warning, + Message = $"Directory writable, cache database will be created at {dbPath}" + }; + } + catch (Exception ex) + { + return new DiagnosticResult + { + Component = name, + Level = DiagnosticLevels.Error, + Message = $"Cannot create directory: {ex.Message}" + }; + } + } + + var canWrite = this.CanWriteToDirectory(dirPath); + if (canWrite) + { + return new DiagnosticResult + { + Component = name, + Level = DiagnosticLevels.Warning, + Message = $"Directory writable, cache database will be created at {dbPath}" + }; + } + + return new DiagnosticResult + { + Component = name, + Level = 
DiagnosticLevels.Error,
+                Message = $"Directory not writable: {dirPath}"
+            };
+        }
+
+        return new DiagnosticResult
+        {
+            Component = name,
+            Level = DiagnosticLevels.Error,
+            Message = "Invalid cache path configuration"
+        };
+    }
+
+    /// <summary>
+    /// Tests if a directory is writable by creating and deleting a temp file.
+    /// </summary>
+    private bool CanWriteToDirectory(string dirPath)
+    {
+        try
+        {
+            var testFile = Path.Combine(dirPath, $".km-doctor-test-{Guid.NewGuid()}");
+            File.WriteAllText(testFile, "test");
+            File.Delete(testFile);
+            return true;
+        }
+        catch
+        {
+            // Broad catch is intentional: any failure (permissions, invalid
+            // path characters, transient I/O error) means "not writable"
+            // for diagnostic purposes.
+            return false;
+        }
+    }
+
+    /// <summary>
+    /// Displays results, grouped by node for clarity.
+    /// </summary>
+    private void DisplayResults(List<DiagnosticResult> results, DoctorCommandSettings settings)
+    {
+        // JSON output. string.Equals is null-safe: an unset Format simply
+        // falls through to the human-readable path instead of throwing.
+        if (string.Equals(settings.Format, "json", StringComparison.OrdinalIgnoreCase))
+        {
+            this.DisplayResultsAsJson(results);
+            return;
+        }
+
+        // Human-readable output
+        this.DisplayResultsGroupedByNode(results, settings);
+    }
+
+    /// <summary>
+    /// Displays results in JSON format.
+    /// </summary>
+    private void DisplayResultsAsJson(List<DiagnosticResult> results)
+    {
+        var output = new
+        {
+            results = results.Select(r => new
+            {
+                component = r.Component,
+                level = r.Level.ToString().ToLowerInvariant(),
+                message = r.Message,
+                nodeId = r.NodeId
+            }).ToList(),
+            summary = new
+            {
+                total = results.Count,
+                ok = results.Count(r => r.Level == DiagnosticLevels.OK),
+                warnings = results.Count(r => r.Level == DiagnosticLevels.Warning),
+                errors = results.Count(r => r.Level == DiagnosticLevels.Error)
+            }
+        };
+
+        var json = JsonSerializer.Serialize(output, s_jsonOptions);
+        Console.WriteLine(json);
+    }
+
+    /// <summary>
+    /// Displays results grouped by node with visual formatting.
+    /// Global checks (NodeId == null) are shown at the top and bottom.
+    /// Node-specific checks are indented under node headers.
+ /// + private void DisplayResultsGroupedByNode(List results, DoctorCommandSettings settings) + { + var useColor = !settings.NoColor; + + // Separate global checks and node-specific checks + var globalChecks = results.Where(r => r.NodeId == null).ToList(); + var nodeGroups = results + .Where(r => r.NodeId != null) + .GroupBy(r => r.NodeId!) + .OrderBy(g => g.Key) + .ToList(); + + // Display global checks first (config file) + var configCheck = globalChecks.FirstOrDefault(r => r.Component == "Config file"); + if (configCheck != null) + { + this.DisplayCheck(configCheck, indent: 0, useColor); + AnsiConsole.WriteLine(); + } + + // Display node-grouped checks + foreach (var nodeGroup in nodeGroups) + { + // Node header in bold + if (useColor) + { + AnsiConsole.MarkupLine($"[bold]Node '{Markup.Escape(nodeGroup.Key)}':[/]"); + } + else + { + Console.WriteLine($"Node '{nodeGroup.Key}':"); + } + + // Indented checks for this node + foreach (var check in nodeGroup) + { + this.DisplayCheck(check, indent: 2, useColor); + } + + AnsiConsole.WriteLine(); + } + + // Display remaining global checks (caches) + foreach (var check in globalChecks.Where(r => r.Component != "Config file")) + { + this.DisplayCheck(check, indent: 0, useColor); + } + + // Summary line + var errorCount = results.Count(r => r.Level == DiagnosticLevels.Error); + var warningCount = results.Count(r => r.Level == DiagnosticLevels.Warning); + + AnsiConsole.WriteLine(); + if (useColor) + { + var summaryColor = errorCount > 0 ? "red" : (warningCount > 0 ? "yellow" : "green"); + AnsiConsole.MarkupLine($"[{summaryColor}]Summary: {warningCount} warning(s), {errorCount} error(s)[/]"); + } + else + { + Console.WriteLine($"Summary: {warningCount} warning(s), {errorCount} error(s)"); + } + } + + /// + /// Displays a single check result with appropriate formatting. 
+ /// + private void DisplayCheck(DiagnosticResult result, int indent, bool useColor) + { + var prefix = new string(' ', indent); + var (symbol, color) = result.Level switch + { + DiagnosticLevels.OK => ("V", "green"), // checkmark + DiagnosticLevels.Warning => ("!", "yellow"), // warning + DiagnosticLevels.Error => ("X", "red"), // error + _ => ("?", "grey") + }; + + if (useColor) + { + AnsiConsole.MarkupLine($"{prefix}[{color}]{symbol}[/] {Markup.Escape(result.Component)}: {Markup.Escape(result.Message)}"); + } + else + { + Console.WriteLine($"{prefix}{symbol} {result.Component}: {result.Message}"); + } + } + + /// + /// Gets the default config path. + /// + private static string GetDefaultConfigPath() + { + return Path.Combine( + Environment.GetFolderPath(Environment.SpecialFolder.UserProfile), + Constants.ConfigDefaults.DefaultConfigDirName, + Constants.ConfigDefaults.DefaultConfigFileName); + } +} diff --git a/src/Main/CLI/Commands/DoctorCommandSettings.cs b/src/Main/CLI/Commands/DoctorCommandSettings.cs new file mode 100644 index 000000000..e5c4b531c --- /dev/null +++ b/src/Main/CLI/Commands/DoctorCommandSettings.cs @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace KernelMemory.Main.CLI.Commands; + +/// +/// Settings for the doctor command. +/// The doctor command validates configuration dependencies and checks system health. +/// +public sealed class DoctorCommandSettings : GlobalOptions +{ + // Doctor command has no additional settings beyond global options + // Uses config file, node selection, and output format from GlobalOptions +} diff --git a/src/Main/CLI/Commands/GetCommand.cs b/src/Main/CLI/Commands/GetCommand.cs index 88fa1a997..93f3c642b 100644 --- a/src/Main/CLI/Commands/GetCommand.cs +++ b/src/Main/CLI/Commands/GetCommand.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. 
using System.ComponentModel; +using KernelMemory.Core; using KernelMemory.Core.Config; using KernelMemory.Core.Storage.Models; using KernelMemory.Main.CLI.Exceptions; @@ -70,7 +71,7 @@ public override async Task ExecuteAsync( if (result == null) { formatter.FormatError($"Content with ID '{settings.Id}' not found"); - return Constants.ExitCodeUserError; + return Constants.App.ExitCodeUserError; } // Wrap result with node information @@ -78,13 +79,13 @@ public override async Task ExecuteAsync( formatter.Format(response); - return Constants.ExitCodeSuccess; + return Constants.App.ExitCodeSuccess; } catch (DatabaseNotFoundException) { // First-run scenario: no database exists yet (expected state) this.ShowFirstRunMessage(settings); - return Constants.ExitCodeSuccess; // Not a user error + return Constants.App.ExitCodeSuccess; // Not a user error } catch (Exception ex) { diff --git a/src/Main/CLI/Commands/ListCommand.cs b/src/Main/CLI/Commands/ListCommand.cs index fd09d038f..48b9a8b89 100644 --- a/src/Main/CLI/Commands/ListCommand.cs +++ b/src/Main/CLI/Commands/ListCommand.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. 
using System.ComponentModel; +using KernelMemory.Core; using KernelMemory.Core.Config; using KernelMemory.Core.Storage.Models; using KernelMemory.Main.CLI.Exceptions; @@ -24,7 +25,7 @@ public class ListCommandSettings : GlobalOptions [CommandOption("--take")] [Description("Number of items to take (default: 20)")] [DefaultValue(20)] - public int Take { get; init; } = Constants.DefaultPageSize; + public int Take { get; init; } = Constants.App.DefaultPageSize; public override ValidationResult Validate() { @@ -67,9 +68,10 @@ public override async Task ExecuteAsync( ListCommandSettings settings, CancellationToken cancellationToken) { + var (config, node, formatter) = this.Initialize(settings); + try { - var (config, node, formatter) = this.Initialize(settings); using var service = this.CreateContentService(node, readonlyMode: true); // Get total count @@ -85,17 +87,16 @@ public override async Task ExecuteAsync( // Format list with pagination info formatter.FormatList(itemsWithNode, totalCount, settings.Skip, settings.Take); - return Constants.ExitCodeSuccess; + return Constants.App.ExitCodeSuccess; } catch (DatabaseNotFoundException) { // First-run scenario: no database exists yet (expected state) - this.ShowFirstRunMessage(settings); - return Constants.ExitCodeSuccess; // Not a user error + this.ShowFirstRunMessage(settings, node.Id); + return Constants.App.ExitCodeSuccess; // Not a user error } catch (Exception ex) { - var formatter = OutputFormatterFactory.Create(settings); return this.HandleError(ex, formatter); } } @@ -104,7 +105,8 @@ public override async Task ExecuteAsync( /// Shows a friendly first-run message when no database exists yet. /// /// Command settings for output format. - private void ShowFirstRunMessage(ListCommandSettings settings) + /// The node ID being listed. 
+ private void ShowFirstRunMessage(ListCommandSettings settings, string nodeId) { var formatter = OutputFormatterFactory.Create(settings); @@ -116,16 +118,20 @@ private void ShowFirstRunMessage(ListCommandSettings settings) } // Human format: friendly welcome message + // Include --node parameter if not using the first (default) node + var isDefaultNode = nodeId == this.Config.Nodes.Keys.First(); + var nodeParam = isDefaultNode ? "" : $" --node {nodeId}"; + AnsiConsole.WriteLine(); AnsiConsole.MarkupLine("[bold green]Welcome to Kernel Memory! 🚀[/]"); AnsiConsole.WriteLine(); - AnsiConsole.MarkupLine("[dim]No content found yet. This is your first run.[/]"); + AnsiConsole.MarkupLine($"[dim]No content found in node '{nodeId}' yet.[/]"); AnsiConsole.WriteLine(); AnsiConsole.MarkupLine("[bold]To get started:[/]"); - AnsiConsole.MarkupLine(" [cyan]km put \"Your content here\"[/]"); + AnsiConsole.MarkupLine($" [cyan]km put \"Your content here\"{nodeParam}[/]"); AnsiConsole.WriteLine(); AnsiConsole.MarkupLine("[bold]Example:[/]"); - AnsiConsole.MarkupLine(" [cyan]km put \"Hello, world!\" --id greeting[/]"); + AnsiConsole.MarkupLine($" [cyan]km put \"Hello, world!\" --id greeting{nodeParam}[/]"); AnsiConsole.WriteLine(); } } diff --git a/src/Main/CLI/Commands/NodesCommand.cs b/src/Main/CLI/Commands/NodesCommand.cs index ec6abba7b..4bd56ae6c 100644 --- a/src/Main/CLI/Commands/NodesCommand.cs +++ b/src/Main/CLI/Commands/NodesCommand.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. 
+using KernelMemory.Core; using KernelMemory.Core.Config; using KernelMemory.Main.CLI.OutputFormatters; using Microsoft.Extensions.Logging; @@ -44,7 +45,7 @@ public override async Task ExecuteAsync( // Format as list formatter.FormatList(nodeIds, totalCount, 0, totalCount); - return Constants.ExitCodeSuccess; + return Constants.App.ExitCodeSuccess; } catch (Exception ex) { diff --git a/src/Main/CLI/Commands/SearchCommand.cs b/src/Main/CLI/Commands/SearchCommand.cs index 0939db150..242bf700d 100644 --- a/src/Main/CLI/Commands/SearchCommand.cs +++ b/src/Main/CLI/Commands/SearchCommand.cs @@ -2,6 +2,7 @@ using System.ComponentModel; using System.Diagnostics.CodeAnalysis; +using KernelMemory.Core; using KernelMemory.Core.Config; using KernelMemory.Core.Search; using KernelMemory.Core.Search.Exceptions; @@ -141,6 +142,8 @@ public override ValidationResult Validate() /// public class SearchCommand : BaseCommand { + private readonly ILogger _logger; + /// /// Initializes a new instance of the class. /// @@ -148,6 +151,7 @@ public class SearchCommand : BaseCommand /// Logger factory for creating loggers (injected by DI). 
public SearchCommand(AppConfig config, ILoggerFactory loggerFactory) : base(config, loggerFactory) { + this._logger = loggerFactory.CreateLogger(); } public override async Task ExecuteAsync( @@ -177,19 +181,19 @@ public override async Task ExecuteAsync( // Format and display results this.FormatSearchResults(response, settings, formatter); - return Constants.ExitCodeSuccess; + return Constants.App.ExitCodeSuccess; } catch (DatabaseNotFoundException) { // First-run scenario: no database exists yet this.ShowFirstRunMessage(settings); - return Constants.ExitCodeSuccess; // Not a user error + return Constants.App.ExitCodeSuccess; // Not a user error } catch (SearchException ex) { var formatter = OutputFormatterFactory.Create(settings); formatter.FormatError($"Search error: {ex.Message}"); - return Constants.ExitCodeUserError; + return Constants.App.ExitCodeUserError; } catch (Exception ex) { @@ -251,7 +255,7 @@ private async Task ValidateQueryAsync( } } - return result.IsValid ? Constants.ExitCodeSuccess : Constants.ExitCodeUserError; + return result.IsValid ? Constants.App.ExitCodeSuccess : Constants.App.ExitCodeUserError; } /// @@ -450,49 +454,80 @@ private void FormatSearchResultsHuman(SearchResponse response, SearchCommandSett /// /// Creates a SearchService instance with all configured nodes. + /// Skips nodes with missing databases gracefully (logs warning, continues with working nodes). /// /// A configured SearchService. + /// Thrown when ALL nodes have missing databases. [SuppressMessage("Reliability", "CA2000:Dispose objects before losing scope", Justification = "ContentService instances must remain alive for the duration of the search operation. 
CLI commands are short-lived and process exit handles cleanup.")] private SearchService CreateSearchService() { var nodeServices = new Dictionary(); var indexWeights = new Dictionary>(); + var skippedNodes = new List(); foreach (var (nodeId, nodeConfig) in this.Config.Nodes) { - // Create ContentService for this node - // Don't dispose - NodeSearchService needs access to its Storage and SearchIndexes - var contentService = this.CreateContentService(nodeConfig, readonlyMode: true); - - // Get FTS index from the content service's registered indexes - // The content service already has FTS indexes registered and keeps them in sync - var ftsIndex = contentService.SearchIndexes.Values.OfType().FirstOrDefault(); - if (ftsIndex == null) + try { - throw new InvalidOperationException($"Node '{nodeId}' does not have an FTS index configured"); - } + // Create ContentService for this node + // Don't dispose - NodeSearchService needs access to its Storage and SearchIndexes + var contentService = this.CreateContentService(nodeConfig, readonlyMode: true); + + // Get FTS index from the content service's registered indexes + // The content service already has FTS indexes registered and keeps them in sync + var ftsIndex = contentService.SearchIndexes.Values.OfType().FirstOrDefault(); + if (ftsIndex == null) + { + this._logger.LogWarning("Skipping node '{NodeId}': No FTS index configured", nodeId); + skippedNodes.Add(nodeId); + continue; + } - // Create NodeSearchService - var nodeSearchService = new NodeSearchService( - nodeId, - ftsIndex, - contentService.Storage - ); + // Create NodeSearchService + var nodeSearchService = new NodeSearchService( + nodeId, + ftsIndex, + contentService.Storage + ); - nodeServices[nodeId] = nodeSearchService; + nodeServices[nodeId] = nodeSearchService; - // Extract index weights from configuration - if (nodeConfig.SearchIndexes.Count > 0) - { - var nodeIndexWeights = new Dictionary(); - foreach (var searchIndex in nodeConfig.SearchIndexes) + // 
Extract index weights from configuration + if (nodeConfig.SearchIndexes.Count > 0) { - // Use the configured weight for each search index - nodeIndexWeights[searchIndex.Id] = searchIndex.Weight; + var nodeIndexWeights = new Dictionary(); + foreach (var searchIndex in nodeConfig.SearchIndexes) + { + // Use the configured weight for each search index + nodeIndexWeights[searchIndex.Id] = searchIndex.Weight; + } + indexWeights[nodeId] = nodeIndexWeights; } - indexWeights[nodeId] = nodeIndexWeights; } + catch (DatabaseNotFoundException ex) + { + // Node's database doesn't exist - skip this node and continue with others + this._logger.LogWarning("Skipping node '{NodeId}': {Message}", nodeId, ex.Message); + skippedNodes.Add(nodeId); + } + } + + // If ALL nodes failed, throw to trigger first-run message + if (nodeServices.Count == 0) + { + throw new DatabaseNotFoundException( + $"No nodes available for search. All {skippedNodes.Count} node(s) have missing databases: {string.Join(", ", skippedNodes)}"); + } + + // Log summary if some nodes were skipped + if (skippedNodes.Count > 0) + { + this._logger.LogInformation( + "Search using {ActiveCount} of {TotalCount} nodes. Skipped: {SkippedNodes}", + nodeServices.Count, + this.Config.Nodes.Count, + string.Join(", ", skippedNodes)); } return new SearchService(nodeServices, indexWeights); diff --git a/src/Main/CLI/Commands/UpsertCommand.cs b/src/Main/CLI/Commands/UpsertCommand.cs index 2d50cbb3f..ba5eaca07 100644 --- a/src/Main/CLI/Commands/UpsertCommand.cs +++ b/src/Main/CLI/Commands/UpsertCommand.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. 
using System.ComponentModel; +using KernelMemory.Core; using KernelMemory.Core.Config; using KernelMemory.Core.Storage.Models; using KernelMemory.Main.CLI.OutputFormatters; @@ -118,7 +119,7 @@ public override async Task ExecuteAsync( }); } - return Constants.ExitCodeSuccess; + return Constants.App.ExitCodeSuccess; } catch (Exception ex) { diff --git a/src/Main/CLI/ModeRouter.cs b/src/Main/CLI/ModeRouter.cs index 5d26014ec..1be2e3da9 100644 --- a/src/Main/CLI/ModeRouter.cs +++ b/src/Main/CLI/ModeRouter.cs @@ -1,5 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. +using KernelMemory.Core; + namespace KernelMemory.Main.CLI; /// @@ -40,6 +42,6 @@ public int HandleUnimplementedMode(string mode, string description) { Console.Error.WriteLine($"Error: {mode} mode not yet implemented"); Console.Error.WriteLine(description); - return Constants.ExitCodeSystemError; + return Constants.App.ExitCodeSystemError; } } diff --git a/src/Main/CLI/OutputFormatters/HumanOutputFormatter.cs b/src/Main/CLI/OutputFormatters/HumanOutputFormatter.cs index bbff39234..c87d6d59e 100644 --- a/src/Main/CLI/OutputFormatters/HumanOutputFormatter.cs +++ b/src/Main/CLI/OutputFormatters/HumanOutputFormatter.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. 
using System.Text.Json; +using KernelMemory.Core; using KernelMemory.Core.Storage.Models; using Spectre.Console; @@ -126,9 +127,9 @@ private void FormatContent(ContentDto content) // Truncate content unless verbose var displayContent = content.Content; - if (!isVerbose && displayContent.Length > Constants.MaxContentDisplayLength) + if (!isVerbose && displayContent.Length > Constants.App.MaxContentDisplayLength) { - displayContent = string.Concat(displayContent.AsSpan(0, Constants.MaxContentDisplayLength), "..."); + displayContent = string.Concat(displayContent.AsSpan(0, Constants.App.MaxContentDisplayLength), "..."); } table.AddRow("[yellow]Content[/]", Markup.Escape(displayContent)); @@ -308,9 +309,9 @@ private void FormatContentWithNode(Core.Storage.Models.ContentDtoWithNode conten // Truncate content unless verbose var displayContent = content.Content; - if (!isVerbose && displayContent.Length > Constants.MaxContentDisplayLength) + if (!isVerbose && displayContent.Length > Constants.App.MaxContentDisplayLength) { - displayContent = string.Concat(displayContent.AsSpan(0, Constants.MaxContentDisplayLength), "..."); + displayContent = string.Concat(displayContent.AsSpan(0, Constants.App.MaxContentDisplayLength), "..."); } table.AddRow("[yellow]Content[/]", Markup.Escape(displayContent)); diff --git a/src/Main/Constants.cs b/src/Main/Constants.cs deleted file mode 100644 index 908f18638..000000000 --- a/src/Main/Constants.cs +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. -namespace KernelMemory.Main; - -/// -/// Application-wide constants. -/// -public static class Constants -{ - /// - /// Default configuration file name. - /// - public const string DefaultConfigFileName = "config.json"; - - /// - /// Default configuration directory name in user's home directory. - /// - public const string DefaultConfigDirName = ".km"; - - /// - /// Exit code for successful operation. 
- /// - public const int ExitCodeSuccess = 0; - - /// - /// Exit code for user errors (bad input, not found, validation failure). - /// - public const int ExitCodeUserError = 1; - - /// - /// Exit code for system errors (storage failure, config error, unexpected exception). - /// - public const int ExitCodeSystemError = 2; - - /// - /// Default pagination size for list operations. - /// - public const int DefaultPageSize = 20; - - /// - /// Maximum content length to display in truncated view (characters). - /// - public const int MaxContentDisplayLength = 100; -} diff --git a/src/Main/Services/EmbeddingGeneratorFactory.cs b/src/Main/Services/EmbeddingGeneratorFactory.cs new file mode 100644 index 000000000..85053134b --- /dev/null +++ b/src/Main/Services/EmbeddingGeneratorFactory.cs @@ -0,0 +1,154 @@ +// Copyright (c) Microsoft. All rights reserved. +using KernelMemory.Core; +using KernelMemory.Core.Config.Embeddings; +using KernelMemory.Core.Embeddings; +using KernelMemory.Core.Embeddings.Cache; +using KernelMemory.Core.Embeddings.Providers; +using Microsoft.Extensions.Logging; + +namespace KernelMemory.Main.Services; + +/// +/// Factory for creating embedding generators from configuration. +/// Supports caching decorator when cache is provided. +/// +public static class EmbeddingGeneratorFactory +{ + /// + /// Creates an embedding generator from configuration. + /// + /// Embeddings configuration. + /// HTTP client for API calls. + /// Optional embedding cache (applies caching decorator if provided). + /// Logger factory for creating component loggers. + /// The embedding generator instance. + /// If configuration type is not supported. + public static IEmbeddingGenerator CreateGenerator( + EmbeddingsConfig config, + HttpClient httpClient, + IEmbeddingCache? 
cache, + ILoggerFactory loggerFactory) + { + ArgumentNullException.ThrowIfNull(config, nameof(config)); + ArgumentNullException.ThrowIfNull(httpClient, nameof(httpClient)); + ArgumentNullException.ThrowIfNull(loggerFactory, nameof(loggerFactory)); + + IEmbeddingGenerator innerGenerator = config switch + { + OllamaEmbeddingsConfig ollama => CreateOllamaGenerator(ollama, httpClient, loggerFactory), + OpenAIEmbeddingsConfig openai => CreateOpenAIGenerator(openai, httpClient, loggerFactory), + AzureOpenAIEmbeddingsConfig azure => CreateAzureOpenAIGenerator(azure, httpClient, loggerFactory), + HuggingFaceEmbeddingsConfig hf => CreateHuggingFaceGenerator(hf, httpClient, loggerFactory), + _ => throw new InvalidOperationException($"Unsupported embeddings config type: {config.GetType().Name}") + }; + + // Wrap with caching decorator if cache is provided + if (cache != null) + { + var cacheLogger = loggerFactory.CreateLogger(); + return new CachedEmbeddingGenerator(innerGenerator, cache, cacheLogger); + } + + return innerGenerator; + } + + /// + /// Creates an Ollama embedding generator. + /// + private static IEmbeddingGenerator CreateOllamaGenerator( + OllamaEmbeddingsConfig config, + HttpClient httpClient, + ILoggerFactory loggerFactory) + { + var logger = loggerFactory.CreateLogger(); + + // Try to get known dimensions for the model + var dimensions = Constants.EmbeddingDefaults.KnownModelDimensions.GetValueOrDefault(config.Model, defaultValue: 0); + + if (dimensions == 0) + { + // Unknown model - we'll validate on first use + dimensions = Constants.EmbeddingDefaults.KnownModelDimensions.GetValueOrDefault( + Constants.EmbeddingDefaults.DefaultOllamaModel, defaultValue: 1024); + } + + return new OllamaEmbeddingGenerator( + httpClient, + config.BaseUrl, + config.Model, + dimensions, + isNormalized: true, // Ollama models typically return normalized vectors + logger); + } + + /// + /// Creates an OpenAI embedding generator. 
+ /// + private static IEmbeddingGenerator CreateOpenAIGenerator( + OpenAIEmbeddingsConfig config, + HttpClient httpClient, + ILoggerFactory loggerFactory) + { + var logger = loggerFactory.CreateLogger(); + + // Get known dimensions for the model + var dimensions = Constants.EmbeddingDefaults.KnownModelDimensions.GetValueOrDefault(config.Model, defaultValue: 1536); + + return new OpenAIEmbeddingGenerator( + httpClient, + config.ApiKey, + config.Model, + dimensions, + isNormalized: true, // OpenAI embeddings are typically normalized + config.BaseUrl, + logger); + } + + /// + /// Creates an Azure OpenAI embedding generator. + /// Constructor signature: httpClient, endpoint, deployment, model, apiKey, vectorDimensions, isNormalized, logger + /// + private static IEmbeddingGenerator CreateAzureOpenAIGenerator( + AzureOpenAIEmbeddingsConfig config, + HttpClient httpClient, + ILoggerFactory loggerFactory) + { + var logger = loggerFactory.CreateLogger(); + + // Get known dimensions for the model + var dimensions = Constants.EmbeddingDefaults.KnownModelDimensions.GetValueOrDefault(config.Model, defaultValue: 1536); + + return new AzureOpenAIEmbeddingGenerator( + httpClient, + config.Endpoint, + config.Deployment, + config.Model, + config.ApiKey ?? string.Empty, + dimensions, + isNormalized: true, // Azure OpenAI embeddings are typically normalized + logger); + } + + /// + /// Creates a HuggingFace embedding generator. + /// + private static IEmbeddingGenerator CreateHuggingFaceGenerator( + HuggingFaceEmbeddingsConfig config, + HttpClient httpClient, + ILoggerFactory loggerFactory) + { + var logger = loggerFactory.CreateLogger(); + + // Get known dimensions for the model + var dimensions = Constants.EmbeddingDefaults.KnownModelDimensions.GetValueOrDefault(config.Model, defaultValue: 384); + + return new HuggingFaceEmbeddingGenerator( + httpClient, + config.ApiKey ?? 
string.Empty, + config.Model, + dimensions, + isNormalized: true, // Sentence-transformers models typically return normalized vectors + config.BaseUrl, + logger); + } +} diff --git a/src/Main/Services/SearchIndexFactory.cs b/src/Main/Services/SearchIndexFactory.cs index 06641c2db..643f24b08 100644 --- a/src/Main/Services/SearchIndexFactory.cs +++ b/src/Main/Services/SearchIndexFactory.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. using KernelMemory.Core.Config.SearchIndex; +using KernelMemory.Core.Embeddings.Cache; using KernelMemory.Core.Search; using Microsoft.Extensions.Logging; @@ -11,13 +12,17 @@ namespace KernelMemory.Main.Services; public static class SearchIndexFactory { /// - /// Creates search indexes from configuration as a dictionary keyed by index ID. + /// Creates all search indexes from configuration (FTS, vector, and future types). /// /// List of search index configurations. - /// Logger factory for creating index loggers. + /// HTTP client for embedding API calls (required for vector indexes). + /// Optional embedding cache for vector indexes. + /// Logger factory for creating component loggers. /// Dictionary of index ID to ISearchIndex instance. public static IReadOnlyDictionary CreateIndexes( List configs, + HttpClient httpClient, + IEmbeddingCache? 
embeddingCache, ILoggerFactory loggerFactory) { var indexes = new Dictionary(); @@ -26,21 +31,73 @@ public static IReadOnlyDictionary CreateIndexes( { if (config is FtsSearchIndexConfig ftsConfig) { - if (string.IsNullOrWhiteSpace(ftsConfig.Path)) - { - throw new InvalidOperationException($"FTS index '{config.Id}' has no Path configured"); - } - - var logger = loggerFactory.CreateLogger(); - var index = new SqliteFtsIndex(ftsConfig.Path, ftsConfig.EnableStemming, logger); - indexes[config.Id] = index; + var ftsIndex = CreateFtsIndexFromConfig(ftsConfig, loggerFactory); + indexes[config.Id] = ftsIndex; + } + else if (config is VectorSearchIndexConfig vectorConfig) + { + var vectorIndex = CreateVectorIndexFromConfig(vectorConfig, httpClient, embeddingCache, loggerFactory); + indexes[config.Id] = vectorIndex; } - // Add other index types here (vector, hybrid, etc.) + // Add other index types here (graph, hybrid, etc.) } return indexes; } + /// + /// Creates an FTS index from configuration. + /// + private static SqliteFtsIndex CreateFtsIndexFromConfig( + FtsSearchIndexConfig config, + ILoggerFactory loggerFactory) + { + if (string.IsNullOrWhiteSpace(config.Path)) + { + throw new InvalidOperationException($"FTS index '{config.Id}' has no Path configured"); + } + + var logger = loggerFactory.CreateLogger(); + return new SqliteFtsIndex(config.Path, config.EnableStemming, logger); + } + + /// + /// Creates a vector index from configuration. + /// Requires embeddings configuration to be present. + /// + private static SqliteVectorIndex CreateVectorIndexFromConfig( + VectorSearchIndexConfig config, + HttpClient httpClient, + IEmbeddingCache? 
embeddingCache, + ILoggerFactory loggerFactory) + { + if (string.IsNullOrWhiteSpace(config.Path)) + { + throw new InvalidOperationException($"Vector index '{config.Id}' has no Path configured"); + } + + if (config.Embeddings == null) + { + throw new InvalidOperationException($"Vector index '{config.Id}' has no Embeddings configuration"); + } + + // Create embedding generator from config + var embeddingGenerator = EmbeddingGeneratorFactory.CreateGenerator( + config.Embeddings, + httpClient, + embeddingCache, + loggerFactory); + + var logger = loggerFactory.CreateLogger(); + + return new SqliteVectorIndex( + config.Path, + config.Dimensions, + config.UseSqliteVec, + embeddingGenerator, + logger); + } + /// /// Creates the first FTS index from configuration. /// Returns null if no FTS index is configured. diff --git a/tests/Core.Tests/Config/AppConfigTests.cs b/tests/Core.Tests/Config/AppConfigTests.cs index 2f3ec61ae..090eddab2 100644 --- a/tests/Core.Tests/Config/AppConfigTests.cs +++ b/tests/Core.Tests/Config/AppConfigTests.cs @@ -2,6 +2,7 @@ using KernelMemory.Core.Config; using KernelMemory.Core.Config.Cache; using KernelMemory.Core.Config.ContentIndex; +using KernelMemory.Core.Config.Embeddings; using KernelMemory.Core.Config.Enums; using KernelMemory.Core.Config.SearchIndex; using KernelMemory.Core.Config.Validation; @@ -23,8 +24,10 @@ public void CreateDefault_ShouldCreateValidConfiguration() Assert.NotNull(config); Assert.Single(config.Nodes); Assert.True(config.Nodes.ContainsKey("personal")); - // Cache configs intentionally null - only created when features are implemented - Assert.Null(config.EmbeddingsCache); + // Embeddings cache now included in default (Feature 00007+00008 complete) + Assert.NotNull(config.EmbeddingsCache); + Assert.NotNull(config.EmbeddingsCache.Path); + // LLM cache still not included (feature not yet implemented) Assert.Null(config.LLMCache); // Verify personal node structure @@ -35,16 +38,27 @@ public void 
CreateDefault_ShouldCreateValidConfiguration() Assert.IsType(personalNode.ContentIndex); Assert.Null(personalNode.FileStorage); Assert.Null(personalNode.RepoStorage); - Assert.Single(personalNode.SearchIndexes); + Assert.Equal(2, personalNode.SearchIndexes.Count); // FTS + Vector - // Verify search indexes (only FTS for now - vectors not yet implemented) - Assert.IsType(personalNode.SearchIndexes[0]); - - var ftsIndex = (FtsSearchIndexConfig)personalNode.SearchIndexes[0]; + // Verify FTS index + var ftsIndex = personalNode.SearchIndexes.First(i => i is FtsSearchIndexConfig) as FtsSearchIndexConfig; + Assert.NotNull(ftsIndex); Assert.Equal(SearchIndexTypes.SqliteFTS, ftsIndex.Type); Assert.True(ftsIndex.EnableStemming); + Assert.True(ftsIndex.Required); Assert.NotNull(ftsIndex.Path); Assert.Contains("fts.db", ftsIndex.Path); + + // Verify Vector index + var vectorIndex = personalNode.SearchIndexes.First(i => i is VectorSearchIndexConfig) as VectorSearchIndexConfig; + Assert.NotNull(vectorIndex); + Assert.Equal(SearchIndexTypes.SqliteVector, vectorIndex.Type); + Assert.False(vectorIndex.Required); // Optional - Ollama may not be running + Assert.Equal(1024, vectorIndex.Dimensions); + Assert.NotNull(vectorIndex.Path); + Assert.Contains("vector.db", vectorIndex.Path); + Assert.NotNull(vectorIndex.Embeddings); + Assert.IsType(vectorIndex.Embeddings); } [Fact] diff --git a/tests/Core.Tests/Config/ConfigParserAutoCreateTests.cs b/tests/Core.Tests/Config/ConfigParserAutoCreateTests.cs index ee348b553..76da8b792 100644 --- a/tests/Core.Tests/Config/ConfigParserAutoCreateTests.cs +++ b/tests/Core.Tests/Config/ConfigParserAutoCreateTests.cs @@ -98,7 +98,7 @@ public void LoadFromFile_WhenFileDoesNotExist_CreatedConfigIsValid() Assert.Single(config.Nodes); Assert.True(config.Nodes.ContainsKey("personal")); // Cache configs intentionally null - only created when features are implemented - Assert.Null(config.EmbeddingsCache); + Assert.NotNull(config.EmbeddingsCache); // Now 
included in default config Assert.Null(config.LLMCache); } diff --git a/tests/Core.Tests/Config/ConfigParserTests.cs b/tests/Core.Tests/Config/ConfigParserTests.cs index afaec9b1f..f74ee9986 100644 --- a/tests/Core.Tests/Config/ConfigParserTests.cs +++ b/tests/Core.Tests/Config/ConfigParserTests.cs @@ -23,7 +23,7 @@ public void LoadFromFile_WhenFileMissing_ShouldReturnDefaultConfig() Assert.Single(config.Nodes); Assert.True(config.Nodes.ContainsKey("personal")); // Cache configs intentionally null in default config - Assert.Null(config.EmbeddingsCache); + Assert.NotNull(config.EmbeddingsCache); // Now included in default config } [Fact] diff --git a/tests/Core.Tests/Config/SearchConfigTests.cs b/tests/Core.Tests/Config/SearchConfigTests.cs index b70604b43..b529e3f68 100644 --- a/tests/Core.Tests/Config/SearchConfigTests.cs +++ b/tests/Core.Tests/Config/SearchConfigTests.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. using KernelMemory.Core.Config; using KernelMemory.Core.Config.Validation; -using KernelMemory.Core.Search; namespace KernelMemory.Core.Tests.Config; @@ -17,12 +16,12 @@ public void DefaultValues_MatchConstants() var config = new SearchConfig(); // Assert - verify defaults match SearchConstants - Assert.Equal(SearchConstants.DefaultMinRelevance, config.DefaultMinRelevance); - Assert.Equal(SearchConstants.DefaultLimit, config.DefaultLimit); - Assert.Equal(SearchConstants.DefaultSearchTimeoutSeconds, config.SearchTimeoutSeconds); - Assert.Equal(SearchConstants.DefaultMaxResultsPerNode, config.MaxResultsPerNode); + Assert.Equal(Constants.SearchDefaults.DefaultMinRelevance, config.DefaultMinRelevance); + Assert.Equal(Constants.SearchDefaults.DefaultLimit, config.DefaultLimit); + Assert.Equal(Constants.SearchDefaults.DefaultSearchTimeoutSeconds, config.SearchTimeoutSeconds); + Assert.Equal(Constants.SearchDefaults.DefaultMaxResultsPerNode, config.MaxResultsPerNode); Assert.Single(config.DefaultNodes); - 
Assert.Equal(SearchConstants.AllNodesWildcard, config.DefaultNodes[0]); + Assert.Equal(Constants.SearchDefaults.AllNodesWildcard, config.DefaultNodes[0]); Assert.Empty(config.ExcludeNodes); } @@ -118,7 +117,7 @@ public void Validate_WildcardWithExclusions_Succeeds() // Arrange - wildcard with exclusions is valid var config = new SearchConfig { - DefaultNodes = [SearchConstants.AllNodesWildcard], + DefaultNodes = [Constants.SearchDefaults.AllNodesWildcard], ExcludeNodes = ["archive", "temp"] }; diff --git a/tests/Core.Tests/Embeddings/Cache/SqliteEmbeddingCacheTests.cs b/tests/Core.Tests/Embeddings/Cache/SqliteEmbeddingCacheTests.cs index 82495d072..4f86d606b 100644 --- a/tests/Core.Tests/Embeddings/Cache/SqliteEmbeddingCacheTests.cs +++ b/tests/Core.Tests/Embeddings/Cache/SqliteEmbeddingCacheTests.cs @@ -53,7 +53,7 @@ public async Task StoreAsync_AndTryGetAsync_ShouldRoundTrip() var vector = new float[] { 0.1f, 0.2f, 0.3f, 0.4f }; // Act - await cache.StoreAsync(key, vector, CancellationToken.None).ConfigureAwait(false); + await cache.StoreAsync(key, vector, tokenCount: null, CancellationToken.None).ConfigureAwait(false); var result = await cache.TryGetAsync(key, CancellationToken.None).ConfigureAwait(false); // Assert @@ -74,7 +74,7 @@ public async Task StoreAsync_WithLargeVector_ShouldRoundTrip() } // Act - await cache.StoreAsync(key, vector, CancellationToken.None).ConfigureAwait(false); + await cache.StoreAsync(key, vector, tokenCount: null, CancellationToken.None).ConfigureAwait(false); var result = await cache.TryGetAsync(key, CancellationToken.None).ConfigureAwait(false); // Assert @@ -96,8 +96,8 @@ public async Task StoreAsync_WithSameKey_ShouldOverwrite() var vector2 = new float[] { 0.9f, 0.8f, 0.7f }; // Act - await cache.StoreAsync(key, vector1, CancellationToken.None).ConfigureAwait(false); - await cache.StoreAsync(key, vector2, CancellationToken.None).ConfigureAwait(false); + await cache.StoreAsync(key, vector1, tokenCount: null, 
CancellationToken.None).ConfigureAwait(false); + await cache.StoreAsync(key, vector2, tokenCount: null, CancellationToken.None).ConfigureAwait(false); var result = await cache.TryGetAsync(key, CancellationToken.None).ConfigureAwait(false); // Assert @@ -116,8 +116,8 @@ public async Task TryGetAsync_WithDifferentKeys_ShouldReturnCorrectValues() var vector2 = new float[] { 0.4f, 0.5f, 0.6f }; // Act - await cache.StoreAsync(key1, vector1, CancellationToken.None).ConfigureAwait(false); - await cache.StoreAsync(key2, vector2, CancellationToken.None).ConfigureAwait(false); + await cache.StoreAsync(key1, vector1, tokenCount: null, CancellationToken.None).ConfigureAwait(false); + await cache.StoreAsync(key2, vector2, tokenCount: null, CancellationToken.None).ConfigureAwait(false); var result1 = await cache.TryGetAsync(key1, CancellationToken.None).ConfigureAwait(false); var result2 = await cache.TryGetAsync(key2, CancellationToken.None).ConfigureAwait(false); @@ -136,7 +136,7 @@ public async Task ReadOnlyMode_TryGetAsync_ShouldWork() { var key = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, "test text"); var vector = new float[] { 0.1f, 0.2f, 0.3f }; - await writeCache.StoreAsync(key, vector, CancellationToken.None).ConfigureAwait(false); + await writeCache.StoreAsync(key, vector, tokenCount: null, CancellationToken.None).ConfigureAwait(false); } // Act - Then read with read-only mode @@ -158,7 +158,7 @@ public async Task ReadOnlyMode_StoreAsync_ShouldNotWrite() var vector = new float[] { 0.1f, 0.2f, 0.3f }; // Act - Store should be ignored in read-only mode - await cache.StoreAsync(key, vector, CancellationToken.None).ConfigureAwait(false); + await cache.StoreAsync(key, vector, tokenCount: null, CancellationToken.None).ConfigureAwait(false); var result = await cache.TryGetAsync(key, CancellationToken.None).ConfigureAwait(false); // Assert @@ -174,7 +174,7 @@ public async Task WriteOnlyMode_StoreAsync_ShouldWork() var vector = new float[] { 0.1f, 0.2f, 0.3f }; 
// Act - await cache.StoreAsync(key, vector, CancellationToken.None).ConfigureAwait(false); + await cache.StoreAsync(key, vector, tokenCount: null, CancellationToken.None).ConfigureAwait(false); // Assert - verify by reading with read-write cache using var readCache = new SqliteEmbeddingCache(this._tempDbPath, CacheModes.ReadWrite, this._loggerMock.Object); @@ -191,7 +191,7 @@ public async Task WriteOnlyMode_TryGetAsync_ShouldReturnNull() { var key = EmbeddingCacheKey.Create("OpenAI", "model", 1536, true, "test text"); var vector = new float[] { 0.1f, 0.2f, 0.3f }; - await writeCache.StoreAsync(key, vector, CancellationToken.None).ConfigureAwait(false); + await writeCache.StoreAsync(key, vector, tokenCount: null, CancellationToken.None).ConfigureAwait(false); } // Act - Read with write-only mode should return null @@ -256,7 +256,7 @@ public async Task VectorBlobStorage_ShouldPreserveFloatPrecision() }; // Act - await cache.StoreAsync(key, vector, CancellationToken.None).ConfigureAwait(false); + await cache.StoreAsync(key, vector, tokenCount: null, CancellationToken.None).ConfigureAwait(false); var result = await cache.TryGetAsync(key, CancellationToken.None).ConfigureAwait(false); // Assert @@ -277,7 +277,7 @@ public async Task CacheDoesNotStoreInputText() var vector = new float[] { 0.1f, 0.2f, 0.3f }; // Act - await cache.StoreAsync(key, vector, CancellationToken.None).ConfigureAwait(false); + await cache.StoreAsync(key, vector, tokenCount: null, CancellationToken.None).ConfigureAwait(false); // Assert - Read database file and verify text is not present var dbContent = await File.ReadAllBytesAsync(this._tempDbPath).ConfigureAwait(false); @@ -295,7 +295,7 @@ public async Task CachePersistence_ShouldSurviveReopen() // Store and close using (var cache = new SqliteEmbeddingCache(this._tempDbPath, CacheModes.ReadWrite, this._loggerMock.Object)) { - await cache.StoreAsync(key, vector, CancellationToken.None).ConfigureAwait(false); + await cache.StoreAsync(key, vector, 
tokenCount: null, CancellationToken.None).ConfigureAwait(false); } // Act - Reopen and read @@ -319,7 +319,7 @@ public async Task StoreAsync_WithCancellationToken_ShouldRespectCancellation() // Act & Assert await Assert.ThrowsAsync( - () => cache.StoreAsync(key, vector, cts.Token)).ConfigureAwait(false); + () => cache.StoreAsync(key, vector, tokenCount: null, cts.Token)).ConfigureAwait(false); } [Fact] diff --git a/tests/Core.Tests/Embeddings/CachedEmbeddingGeneratorTests.cs b/tests/Core.Tests/Embeddings/CachedEmbeddingGeneratorTests.cs index 3e3d4c749..684a0d3fe 100644 --- a/tests/Core.Tests/Embeddings/CachedEmbeddingGeneratorTests.cs +++ b/tests/Core.Tests/Embeddings/CachedEmbeddingGeneratorTests.cs @@ -1,6 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. using KernelMemory.Core.Config.Enums; -using KernelMemory.Core.Embeddings; using KernelMemory.Core.Embeddings.Cache; using Microsoft.Extensions.Logging; using Moq; @@ -54,7 +53,8 @@ public async Task GenerateAsync_Single_WithCacheHit_ShouldReturnCachedVector() var cachedVector = new float[] { 0.1f, 0.2f, 0.3f }; var cachedEmbedding = new CachedEmbedding { - Vector = cachedVector + Vector = cachedVector, + Timestamp = DateTimeOffset.UtcNow }; this._cacheMock.Setup(x => x.Mode).Returns(CacheModes.ReadWrite); @@ -71,7 +71,7 @@ public async Task GenerateAsync_Single_WithCacheHit_ShouldReturnCachedVector() var result = await cachedGenerator.GenerateAsync("test text", CancellationToken.None).ConfigureAwait(false); // Assert - Assert.Equal(cachedVector, result); + Assert.Equal(cachedVector, result.Vector); this._innerGeneratorMock.Verify( x => x.GenerateAsync(It.IsAny(), It.IsAny()), Times.Never); @@ -82,6 +82,7 @@ public async Task GenerateAsync_Single_WithCacheMiss_ShouldCallInnerGenerator() { // Arrange var generatedVector = new float[] { 0.4f, 0.5f, 0.6f }; + var generatedResult = EmbeddingResult.FromVector(generatedVector); this._cacheMock.Setup(x => x.Mode).Returns(CacheModes.ReadWrite); this._cacheMock 
@@ -90,7 +91,7 @@ public async Task GenerateAsync_Single_WithCacheMiss_ShouldCallInnerGenerator() this._innerGeneratorMock .Setup(x => x.GenerateAsync(It.IsAny(), It.IsAny())) - .ReturnsAsync(generatedVector); + .ReturnsAsync(generatedResult); var cachedGenerator = new CachedEmbeddingGenerator( this._innerGeneratorMock.Object, @@ -101,7 +102,7 @@ public async Task GenerateAsync_Single_WithCacheMiss_ShouldCallInnerGenerator() var result = await cachedGenerator.GenerateAsync("test text", CancellationToken.None).ConfigureAwait(false); // Assert - Assert.Equal(generatedVector, result); + Assert.Equal(generatedVector, result.Vector); this._innerGeneratorMock.Verify( x => x.GenerateAsync("test text", It.IsAny()), Times.Once); @@ -112,6 +113,7 @@ public async Task GenerateAsync_Single_WithCacheMiss_ShouldStoreInCache() { // Arrange var generatedVector = new float[] { 0.4f, 0.5f, 0.6f }; + var generatedResult = EmbeddingResult.FromVector(generatedVector); this._cacheMock.Setup(x => x.Mode).Returns(CacheModes.ReadWrite); this._cacheMock @@ -120,7 +122,7 @@ public async Task GenerateAsync_Single_WithCacheMiss_ShouldStoreInCache() this._innerGeneratorMock .Setup(x => x.GenerateAsync(It.IsAny(), It.IsAny())) - .ReturnsAsync(generatedVector); + .ReturnsAsync(generatedResult); var cachedGenerator = new CachedEmbeddingGenerator( this._innerGeneratorMock.Object, @@ -132,7 +134,7 @@ public async Task GenerateAsync_Single_WithCacheMiss_ShouldStoreInCache() // Assert this._cacheMock.Verify( - x => x.StoreAsync(It.IsAny(), generatedVector, It.IsAny()), + x => x.StoreAsync(It.IsAny(), generatedVector, It.IsAny(), It.IsAny()), Times.Once); } @@ -141,12 +143,13 @@ public async Task GenerateAsync_Single_WithWriteOnlyCache_ShouldSkipCacheRead() { // Arrange var generatedVector = new float[] { 0.4f, 0.5f, 0.6f }; + var generatedResult = EmbeddingResult.FromVector(generatedVector); this._cacheMock.Setup(x => x.Mode).Returns(CacheModes.WriteOnly); this._innerGeneratorMock .Setup(x => 
x.GenerateAsync(It.IsAny(), It.IsAny())) - .ReturnsAsync(generatedVector); + .ReturnsAsync(generatedResult); var cachedGenerator = new CachedEmbeddingGenerator( this._innerGeneratorMock.Object, @@ -157,12 +160,12 @@ public async Task GenerateAsync_Single_WithWriteOnlyCache_ShouldSkipCacheRead() var result = await cachedGenerator.GenerateAsync("test text", CancellationToken.None).ConfigureAwait(false); // Assert - Assert.Equal(generatedVector, result); + Assert.Equal(generatedVector, result.Vector); this._cacheMock.Verify( x => x.TryGetAsync(It.IsAny(), It.IsAny()), Times.Never); this._cacheMock.Verify( - x => x.StoreAsync(It.IsAny(), It.IsAny(), It.IsAny()), + x => x.StoreAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny()), Times.Once); } @@ -171,6 +174,7 @@ public async Task GenerateAsync_Single_WithReadOnlyCache_ShouldSkipCacheWrite() { // Arrange var generatedVector = new float[] { 0.4f, 0.5f, 0.6f }; + var generatedResult = EmbeddingResult.FromVector(generatedVector); this._cacheMock.Setup(x => x.Mode).Returns(CacheModes.ReadOnly); this._cacheMock @@ -179,7 +183,7 @@ public async Task GenerateAsync_Single_WithReadOnlyCache_ShouldSkipCacheWrite() this._innerGeneratorMock .Setup(x => x.GenerateAsync(It.IsAny(), It.IsAny())) - .ReturnsAsync(generatedVector); + .ReturnsAsync(generatedResult); var cachedGenerator = new CachedEmbeddingGenerator( this._innerGeneratorMock.Object, @@ -190,9 +194,9 @@ public async Task GenerateAsync_Single_WithReadOnlyCache_ShouldSkipCacheWrite() var result = await cachedGenerator.GenerateAsync("test text", CancellationToken.None).ConfigureAwait(false); // Assert - Assert.Equal(generatedVector, result); + Assert.Equal(generatedVector, result.Vector); this._cacheMock.Verify( - x => x.StoreAsync(It.IsAny(), It.IsAny(), It.IsAny()), + x => x.StoreAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny()), Times.Never); } @@ -203,9 +207,9 @@ public async Task GenerateAsync_Batch_AllCacheHits_ShouldNotCallInnerGenerator() var texts = new[] { 
"text1", "text2", "text3" }; var cachedVectors = new Dictionary { - ["text1"] = new[] { 0.1f, 0.2f }, - ["text2"] = new[] { 0.3f, 0.4f }, - ["text3"] = new[] { 0.5f, 0.6f } + ["text1"] = [0.1f, 0.2f], + ["text2"] = [0.3f, 0.4f], + ["text3"] = [0.5f, 0.6f] }; this._cacheMock.Setup(x => x.Mode).Returns(CacheModes.ReadWrite); @@ -219,7 +223,7 @@ public async Task GenerateAsync_Batch_AllCacheHits_ShouldNotCallInnerGenerator() var testKey = EmbeddingCacheKey.Create("OpenAI", "text-embedding-ada-002", 1536, true, kvp.Key); if (testKey.TextHash == key.TextHash) { - return new CachedEmbedding { Vector = kvp.Value }; + return new CachedEmbedding { Vector = kvp.Value, Timestamp = DateTimeOffset.UtcNow }; } } @@ -246,11 +250,11 @@ public async Task GenerateAsync_Batch_AllCacheMisses_ShouldCallInnerGeneratorWit { // Arrange var texts = new[] { "text1", "text2", "text3" }; - var generatedVectors = new[] + var generatedResults = new EmbeddingResult[] { - new[] { 0.1f, 0.2f }, - new[] { 0.3f, 0.4f }, - new[] { 0.5f, 0.6f } + EmbeddingResult.FromVector([0.1f, 0.2f]), + EmbeddingResult.FromVector([0.3f, 0.4f]), + EmbeddingResult.FromVector([0.5f, 0.6f]) }; this._cacheMock.Setup(x => x.Mode).Returns(CacheModes.ReadWrite); @@ -260,7 +264,7 @@ public async Task GenerateAsync_Batch_AllCacheMisses_ShouldCallInnerGeneratorWit this._innerGeneratorMock .Setup(x => x.GenerateAsync(It.IsAny>(), It.IsAny())) - .ReturnsAsync(generatedVectors); + .ReturnsAsync(generatedResults); var cachedGenerator = new CachedEmbeddingGenerator( this._innerGeneratorMock.Object, @@ -283,10 +287,10 @@ public async Task GenerateAsync_Batch_MixedHitsAndMisses_ShouldOnlyGenerateMisse // Arrange var texts = new[] { "cached", "not-cached-1", "not-cached-2" }; var cachedVector = new[] { 0.1f, 0.2f }; - var generatedVectors = new[] + var generatedResults = new EmbeddingResult[] { - new[] { 0.3f, 0.4f }, - new[] { 0.5f, 0.6f } + EmbeddingResult.FromVector([0.3f, 0.4f]), + EmbeddingResult.FromVector([0.5f, 0.6f]) }; var 
cachedKey = EmbeddingCacheKey.Create("OpenAI", "text-embedding-ada-002", 1536, true, "cached"); @@ -298,7 +302,7 @@ public async Task GenerateAsync_Batch_MixedHitsAndMisses_ShouldOnlyGenerateMisse { if (key.TextHash == cachedKey.TextHash) { - return new CachedEmbedding { Vector = cachedVector }; + return new CachedEmbedding { Vector = cachedVector, Timestamp = DateTimeOffset.UtcNow }; } return null; @@ -306,7 +310,7 @@ public async Task GenerateAsync_Batch_MixedHitsAndMisses_ShouldOnlyGenerateMisse this._innerGeneratorMock .Setup(x => x.GenerateAsync(It.IsAny>(), It.IsAny())) - .ReturnsAsync(generatedVectors); + .ReturnsAsync(generatedResults); var cachedGenerator = new CachedEmbeddingGenerator( this._innerGeneratorMock.Object, @@ -319,10 +323,10 @@ public async Task GenerateAsync_Batch_MixedHitsAndMisses_ShouldOnlyGenerateMisse // Assert Assert.Equal(3, results.Length); // First result should be cached - Assert.Equal(cachedVector, results[0]); + Assert.Equal(cachedVector, results[0].Vector); // Other results should be generated - Assert.Equal(generatedVectors[0], results[1]); - Assert.Equal(generatedVectors[1], results[2]); + Assert.Equal(new[] { 0.3f, 0.4f }, results[1].Vector); + Assert.Equal(new[] { 0.5f, 0.6f }, results[2].Vector); // Verify only non-cached texts were sent to generator this._innerGeneratorMock.Verify( @@ -335,10 +339,10 @@ public async Task GenerateAsync_Batch_ShouldStoreGeneratedInCache() { // Arrange var texts = new[] { "text1", "text2" }; - var generatedVectors = new[] + var generatedResults = new EmbeddingResult[] { - new[] { 0.1f, 0.2f }, - new[] { 0.3f, 0.4f } + EmbeddingResult.FromVector([0.1f, 0.2f]), + EmbeddingResult.FromVector([0.3f, 0.4f]) }; this._cacheMock.Setup(x => x.Mode).Returns(CacheModes.ReadWrite); @@ -348,7 +352,7 @@ public async Task GenerateAsync_Batch_ShouldStoreGeneratedInCache() this._innerGeneratorMock .Setup(x => x.GenerateAsync(It.IsAny>(), It.IsAny())) - .ReturnsAsync(generatedVectors); + 
.ReturnsAsync(generatedResults); var cachedGenerator = new CachedEmbeddingGenerator( this._innerGeneratorMock.Object, @@ -360,7 +364,7 @@ public async Task GenerateAsync_Batch_ShouldStoreGeneratedInCache() // Assert - Both generated vectors should be stored this._cacheMock.Verify( - x => x.StoreAsync(It.IsAny(), It.IsAny(), It.IsAny()), + x => x.StoreAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny()), Times.Exactly(2)); } @@ -397,10 +401,10 @@ public async Task GenerateAsync_Batch_ShouldPreserveOrder() var vectorB = new[] { 2.0f }; var vectorD = new[] { 4.0f }; - var generatedVectors = new[] + var generatedResults = new EmbeddingResult[] { - new[] { 1.0f }, // for "a" - new[] { 3.0f } // for "c" + EmbeddingResult.FromVector([1.0f]), // for "a" + EmbeddingResult.FromVector([3.0f]) // for "c" }; this._cacheMock.Setup(x => x.Mode).Returns(CacheModes.ReadWrite); @@ -410,12 +414,12 @@ public async Task GenerateAsync_Batch_ShouldPreserveOrder() { if (key.TextHash == cachedB.TextHash) { - return new CachedEmbedding { Vector = vectorB }; + return new CachedEmbedding { Vector = vectorB, Timestamp = DateTimeOffset.UtcNow }; } if (key.TextHash == cachedD.TextHash) { - return new CachedEmbedding { Vector = vectorD }; + return new CachedEmbedding { Vector = vectorD, Timestamp = DateTimeOffset.UtcNow }; } return null; @@ -423,7 +427,7 @@ public async Task GenerateAsync_Batch_ShouldPreserveOrder() this._innerGeneratorMock .Setup(x => x.GenerateAsync(It.IsAny>(), It.IsAny())) - .ReturnsAsync(generatedVectors); + .ReturnsAsync(generatedResults); var cachedGenerator = new CachedEmbeddingGenerator( this._innerGeneratorMock.Object, @@ -435,10 +439,10 @@ public async Task GenerateAsync_Batch_ShouldPreserveOrder() // Assert - Order must be preserved: a, b, c, d Assert.Equal(4, results.Length); - Assert.Equal(new[] { 1.0f }, results[0]); // a - generated - Assert.Equal(new[] { 2.0f }, results[1]); // b - cached - Assert.Equal(new[] { 3.0f }, results[2]); // c - generated - 
Assert.Equal(new[] { 4.0f }, results[3]); // d - cached + Assert.Equal(new[] { 1.0f }, results[0].Vector); // a - generated + Assert.Equal(new[] { 2.0f }, results[1].Vector); // b - cached + Assert.Equal(new[] { 3.0f }, results[2].Vector); // c - generated + Assert.Equal(new[] { 4.0f }, results[3].Vector); // d - cached } [Fact] @@ -490,4 +494,66 @@ public void Constructor_WithNullLogger_ShouldThrow() Assert.Throws(() => new CachedEmbeddingGenerator(this._innerGeneratorMock.Object, this._cacheMock.Object, null!)); } + + [Fact] + public async Task GenerateAsync_Single_WithTokenCount_ShouldStoreTokenCountInCache() + { + // Arrange + var generatedVector = new float[] { 0.4f, 0.5f, 0.6f }; + const int tokenCount = 10; + var generatedResult = EmbeddingResult.FromVectorWithTokens(generatedVector, tokenCount); + + this._cacheMock.Setup(x => x.Mode).Returns(CacheModes.ReadWrite); + this._cacheMock + .Setup(x => x.TryGetAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync((CachedEmbedding?)null); + + this._innerGeneratorMock + .Setup(x => x.GenerateAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(generatedResult); + + var cachedGenerator = new CachedEmbeddingGenerator( + this._innerGeneratorMock.Object, + this._cacheMock.Object, + this._loggerMock.Object); + + // Act + await cachedGenerator.GenerateAsync("test text", CancellationToken.None).ConfigureAwait(false); + + // Assert - Token count should be passed to cache + this._cacheMock.Verify( + x => x.StoreAsync(It.IsAny(), generatedVector, tokenCount, It.IsAny()), + Times.Once); + } + + [Fact] + public async Task GenerateAsync_Single_WithCacheHitAndTokenCount_ShouldReturnTokenCount() + { + // Arrange + var cachedVector = new float[] { 0.1f, 0.2f, 0.3f }; + const int tokenCount = 15; + var cachedEmbedding = new CachedEmbedding + { + Vector = cachedVector, + TokenCount = tokenCount, + Timestamp = DateTimeOffset.UtcNow + }; + + this._cacheMock.Setup(x => x.Mode).Returns(CacheModes.ReadWrite); + this._cacheMock + .Setup(x => 
x.TryGetAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(cachedEmbedding); + + var cachedGenerator = new CachedEmbeddingGenerator( + this._innerGeneratorMock.Object, + this._cacheMock.Object, + this._loggerMock.Object); + + // Act + var result = await cachedGenerator.GenerateAsync("test text", CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.Equal(cachedVector, result.Vector); + Assert.Equal(tokenCount, result.TokenCount); + } } diff --git a/tests/Core.Tests/Embeddings/CachedEmbeddingTests.cs b/tests/Core.Tests/Embeddings/CachedEmbeddingTests.cs index 9dcfb8672..85c31dbef 100644 --- a/tests/Core.Tests/Embeddings/CachedEmbeddingTests.cs +++ b/tests/Core.Tests/Embeddings/CachedEmbeddingTests.cs @@ -18,7 +18,8 @@ public void CachedEmbedding_WithRequiredProperties_ShouldBeCreated() // Act var cached = new CachedEmbedding { - Vector = vector + Vector = vector, + Timestamp = DateTimeOffset.UtcNow }; // Assert @@ -34,7 +35,8 @@ public void CachedEmbedding_VectorShouldPreserveFloatPrecision() // Act var cached = new CachedEmbedding { - Vector = vector + Vector = vector, + Timestamp = DateTimeOffset.UtcNow }; // Assert @@ -57,7 +59,8 @@ public void CachedEmbedding_WithLargeVector_ShouldPreserveAllDimensions() // Act var cached = new CachedEmbedding { - Vector = vector + Vector = vector, + Timestamp = DateTimeOffset.UtcNow }; // Assert diff --git a/tests/Core.Tests/Embeddings/EmbeddingConstantsTests.cs b/tests/Core.Tests/Embeddings/EmbeddingConstantsTests.cs index eeca003f2..747ab15c3 100644 --- a/tests/Core.Tests/Embeddings/EmbeddingConstantsTests.cs +++ b/tests/Core.Tests/Embeddings/EmbeddingConstantsTests.cs @@ -1,6 +1,4 @@ // Copyright (c) Microsoft. All rights reserved. 
-using KernelMemory.Core.Embeddings; - namespace KernelMemory.Core.Tests.Embeddings; /// @@ -20,7 +18,7 @@ public sealed class EmbeddingConstantsTests public void KnownModelDimensions_ShouldContainExpectedValues(string modelName, int expectedDimensions) { // Act - var exists = EmbeddingConstants.KnownModelDimensions.TryGetValue(modelName, out var dimensions); + var exists = Constants.EmbeddingDefaults.KnownModelDimensions.TryGetValue(modelName, out var dimensions); // Assert Assert.True(exists, $"Model '{modelName}' should be in KnownModelDimensions"); @@ -31,14 +29,14 @@ public void KnownModelDimensions_ShouldContainExpectedValues(string modelName, i public void KnownModelDimensions_ShouldNotBeEmpty() { // Assert - Assert.NotEmpty(EmbeddingConstants.KnownModelDimensions); + Assert.NotEmpty(Constants.EmbeddingDefaults.KnownModelDimensions); } [Fact] public void KnownModelDimensions_AllValuesShouldBePositive() { // Assert - foreach (var kvp in EmbeddingConstants.KnownModelDimensions) + foreach (var kvp in Constants.EmbeddingDefaults.KnownModelDimensions) { Assert.True(kvp.Value > 0, $"Model '{kvp.Key}' has invalid dimensions: {kvp.Value}"); } @@ -48,7 +46,7 @@ public void KnownModelDimensions_AllValuesShouldBePositive() public void TryGetDimensions_WithKnownModel_ShouldReturnTrue() { // Act - var result = EmbeddingConstants.TryGetDimensions("text-embedding-ada-002", out var dimensions); + var result = Constants.EmbeddingDefaults.TryGetDimensions("text-embedding-ada-002", out var dimensions); // Assert Assert.True(result); @@ -59,7 +57,7 @@ public void TryGetDimensions_WithKnownModel_ShouldReturnTrue() public void TryGetDimensions_WithUnknownModel_ShouldReturnFalse() { // Act - var result = EmbeddingConstants.TryGetDimensions("unknown-model", out var dimensions); + var result = Constants.EmbeddingDefaults.TryGetDimensions("unknown-model", out var dimensions); // Assert Assert.False(result); @@ -70,34 +68,34 @@ public void 
TryGetDimensions_WithUnknownModel_ShouldReturnFalse() public void DefaultBatchSize_ShouldBe10() { // Assert - Assert.Equal(10, EmbeddingConstants.DefaultBatchSize); + Assert.Equal(10, Constants.EmbeddingDefaults.DefaultBatchSize); } [Fact] public void DefaultOllamaModel_ShouldBeQwen3Embedding() { // Assert - Assert.Equal("qwen3-embedding", EmbeddingConstants.DefaultOllamaModel); + Assert.Equal("qwen3-embedding:0.6b", Constants.EmbeddingDefaults.DefaultOllamaModel); } [Fact] public void DefaultOllamaBaseUrl_ShouldBeLocalhost() { // Assert - Assert.Equal("http://localhost:11434", EmbeddingConstants.DefaultOllamaBaseUrl); + Assert.Equal("http://localhost:11434", Constants.EmbeddingDefaults.DefaultOllamaBaseUrl); } [Fact] public void DefaultHuggingFaceModel_ShouldBeAllMiniLM() { // Assert - Assert.Equal("sentence-transformers/all-MiniLM-L6-v2", EmbeddingConstants.DefaultHuggingFaceModel); + Assert.Equal("sentence-transformers/all-MiniLM-L6-v2", Constants.EmbeddingDefaults.DefaultHuggingFaceModel); } [Fact] public void DefaultHuggingFaceBaseUrl_ShouldBeInferenceApi() { // Assert - Assert.Equal("https://api-inference.huggingface.co", EmbeddingConstants.DefaultHuggingFaceBaseUrl); + Assert.Equal("https://api-inference.huggingface.co", Constants.EmbeddingDefaults.DefaultHuggingFaceBaseUrl); } } diff --git a/tests/Core.Tests/Embeddings/Providers/AzureOpenAIEmbeddingGeneratorTests.cs b/tests/Core.Tests/Embeddings/Providers/AzureOpenAIEmbeddingGeneratorTests.cs index 9cb79d13f..d57b3e9dc 100644 --- a/tests/Core.Tests/Embeddings/Providers/AzureOpenAIEmbeddingGeneratorTests.cs +++ b/tests/Core.Tests/Embeddings/Providers/AzureOpenAIEmbeddingGeneratorTests.cs @@ -76,7 +76,7 @@ public async Task GenerateAsync_Single_ShouldCallAzureEndpoint() var result = await generator.GenerateAsync("test text", CancellationToken.None).ConfigureAwait(false); // Assert - Assert.Equal(new[] { 0.1f, 0.2f, 0.3f }, result); + Assert.Equal(new[] { 0.1f, 0.2f, 0.3f }, result.Vector); } [Fact] diff --git 
a/tests/Core.Tests/Embeddings/Providers/HuggingFaceEmbeddingGeneratorTests.cs b/tests/Core.Tests/Embeddings/Providers/HuggingFaceEmbeddingGeneratorTests.cs index 94d5811f5..f0bc4441e 100644 --- a/tests/Core.Tests/Embeddings/Providers/HuggingFaceEmbeddingGeneratorTests.cs +++ b/tests/Core.Tests/Embeddings/Providers/HuggingFaceEmbeddingGeneratorTests.cs @@ -75,7 +75,7 @@ public async Task GenerateAsync_Single_ShouldCallCorrectEndpoint() var result = await generator.GenerateAsync("test text", CancellationToken.None).ConfigureAwait(false); // Assert - Assert.Equal(new[] { 0.1f, 0.2f, 0.3f }, result); + Assert.Equal(new[] { 0.1f, 0.2f, 0.3f }, result.Vector); } [Fact] @@ -182,9 +182,9 @@ public async Task GenerateAsync_Batch_ShouldProcessAllTexts() // Assert Assert.Equal(3, results.Length); - Assert.Equal(new[] { 0.1f }, results[0]); - Assert.Equal(new[] { 0.2f }, results[1]); - Assert.Equal(new[] { 0.3f }, results[2]); + Assert.Equal(new[] { 0.1f }, results[0].Vector); + Assert.Equal(new[] { 0.2f }, results[1].Vector); + Assert.Equal(new[] { 0.3f }, results[2].Vector); } [Fact] @@ -214,7 +214,7 @@ public async Task GenerateAsync_WithCustomBaseUrl_ShouldUseIt() var result = await generator.GenerateAsync("test", CancellationToken.None).ConfigureAwait(false); // Assert - Assert.Equal(new[] { 0.1f }, result); + Assert.Equal(new[] { 0.1f }, result.Vector); } [Fact] diff --git a/tests/Core.Tests/Embeddings/Providers/OllamaEmbeddingGeneratorTests.cs b/tests/Core.Tests/Embeddings/Providers/OllamaEmbeddingGeneratorTests.cs index 65d5ba62a..34fac8b4c 100644 --- a/tests/Core.Tests/Embeddings/Providers/OllamaEmbeddingGeneratorTests.cs +++ b/tests/Core.Tests/Embeddings/Providers/OllamaEmbeddingGeneratorTests.cs @@ -73,7 +73,7 @@ public async Task GenerateAsync_Single_ShouldCallCorrectEndpoint() var result = await generator.GenerateAsync("test text", CancellationToken.None).ConfigureAwait(false); // Assert - Assert.Equal(new[] { 0.1f, 0.2f, 0.3f }, result); + Assert.Equal(new[] { 
0.1f, 0.2f, 0.3f }, result.Vector); } [Fact] @@ -149,9 +149,9 @@ public async Task GenerateAsync_Batch_ShouldProcessAllTexts() // Assert Assert.Equal(3, results.Length); - Assert.Equal(new[] { 0.1f }, results[0]); - Assert.Equal(new[] { 0.2f }, results[1]); - Assert.Equal(new[] { 0.3f }, results[2]); + Assert.Equal(new[] { 0.1f }, results[0].Vector); + Assert.Equal(new[] { 0.2f }, results[1].Vector); + Assert.Equal(new[] { 0.3f }, results[2].Vector); } [Fact] @@ -254,6 +254,36 @@ public void Constructor_WithNullModel_ShouldThrow() new OllamaEmbeddingGenerator(httpClient, "http://localhost", null!, 1024, true, this._loggerMock.Object)); } + [Fact] + public async Task GenerateAsync_Single_ShouldReturnNullTokenCount() + { + // Arrange - Ollama API does not return token count + var response = new OllamaEmbeddingResponse { Embedding = new[] { 0.1f, 0.2f, 0.3f } }; + var responseJson = JsonSerializer.Serialize(response); + + this._httpHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync(new HttpResponseMessage + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent(responseJson) + }); + + var httpClient = new HttpClient(this._httpHandlerMock.Object); + var generator = new OllamaEmbeddingGenerator( + httpClient, "http://localhost:11434", "qwen3-embedding", 1024, true, this._loggerMock.Object); + + // Act + var result = await generator.GenerateAsync("test text", CancellationToken.None).ConfigureAwait(false); + + // Assert - Ollama does NOT provide token count, so it should be null + Assert.Null(result.TokenCount); + } + // Internal request/response classes for testing private sealed class OllamaEmbeddingRequest { diff --git a/tests/Core.Tests/Embeddings/Providers/OpenAIEmbeddingGeneratorTests.cs b/tests/Core.Tests/Embeddings/Providers/OpenAIEmbeddingGeneratorTests.cs index 1e7923cd0..3b1a45d22 100644 --- a/tests/Core.Tests/Embeddings/Providers/OpenAIEmbeddingGeneratorTests.cs +++ 
b/tests/Core.Tests/Embeddings/Providers/OpenAIEmbeddingGeneratorTests.cs @@ -74,7 +74,7 @@ public async Task GenerateAsync_Single_ShouldCallCorrectEndpoint() var result = await generator.GenerateAsync("test text", CancellationToken.None).ConfigureAwait(false); // Assert - Assert.Equal(new[] { 0.1f, 0.2f, 0.3f }, result); + Assert.Equal(new[] { 0.1f, 0.2f, 0.3f }, result.Vector); } [Fact] @@ -180,7 +180,7 @@ public async Task GenerateAsync_WithCustomBaseUrl_ShouldUseIt() var result = await generator.GenerateAsync("test", CancellationToken.None).ConfigureAwait(false); // Assert - Assert.Equal(new[] { 0.1f }, result); + Assert.Equal(new[] { 0.1f }, result.Vector); } [Fact] @@ -284,12 +284,86 @@ public void Constructor_WithNullModel_ShouldThrow() new OpenAIEmbeddingGenerator(httpClient, "key", null!, 1536, true, null, this._loggerMock.Object)); } + [Fact] + public async Task GenerateAsync_Single_ShouldReturnTokenCountFromApiResponse() + { + // Arrange + var response = CreateOpenAIResponseWithTokenCount(new[] { new[] { 0.1f, 0.2f, 0.3f } }, totalTokens: 42); + var responseJson = JsonSerializer.Serialize(response); + + this._httpHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync(new HttpResponseMessage + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent(responseJson) + }); + + var httpClient = new HttpClient(this._httpHandlerMock.Object); + var generator = new OpenAIEmbeddingGenerator( + httpClient, "test-key", "text-embedding-ada-002", 1536, true, null, this._loggerMock.Object); + + // Act + var result = await generator.GenerateAsync("test text", CancellationToken.None).ConfigureAwait(false); + + // Assert - Token count should be extracted from API response + Assert.NotNull(result.TokenCount); + Assert.Equal(42, result.TokenCount.Value); + } + + [Fact] + public async Task GenerateAsync_Batch_ShouldDistributeTokenCountEvenly() + { + // Arrange - Response with 30 total tokens for 3 embeddings = 10 
tokens each + var response = CreateOpenAIResponseWithTokenCount(new[] + { + new[] { 0.1f }, + new[] { 0.2f }, + new[] { 0.3f } + }, totalTokens: 30); + var responseJson = JsonSerializer.Serialize(response); + + this._httpHandlerMock + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync(new HttpResponseMessage + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent(responseJson) + }); + + var httpClient = new HttpClient(this._httpHandlerMock.Object); + var generator = new OpenAIEmbeddingGenerator( + httpClient, "test-key", "text-embedding-ada-002", 1536, true, null, this._loggerMock.Object); + + // Act + var results = await generator.GenerateAsync(new[] { "text1", "text2", "text3" }, CancellationToken.None).ConfigureAwait(false); + + // Assert - Each result should have evenly distributed token count (30/3 = 10) + Assert.Equal(3, results.Length); + Assert.Equal(10, results[0].TokenCount); + Assert.Equal(10, results[1].TokenCount); + Assert.Equal(10, results[2].TokenCount); + } + private static OpenAIEmbeddingResponse CreateOpenAIResponse(float[][] embeddings) + { + return CreateOpenAIResponseWithTokenCount(embeddings, totalTokens: 10); + } + + private static OpenAIEmbeddingResponse CreateOpenAIResponseWithTokenCount(float[][] embeddings, int totalTokens) { return new OpenAIEmbeddingResponse { Data = embeddings.Select((e, i) => new EmbeddingData { Index = i, Embedding = e }).ToArray(), - Usage = new UsageInfo { PromptTokens = 10, TotalTokens = 10 } + Usage = new UsageInfo { PromptTokens = totalTokens, TotalTokens = totalTokens } }; } diff --git a/tests/Core.Tests/GlobalUsings.cs b/tests/Core.Tests/GlobalUsings.cs index 543dc179d..ab250a6b9 100644 --- a/tests/Core.Tests/GlobalUsings.cs +++ b/tests/Core.Tests/GlobalUsings.cs @@ -1,7 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. 
+global using KernelMemory.Core.Embeddings; +global using KernelMemory.Core.Logging; +global using KernelMemory.Core.Search; global using Xunit; - using System.Diagnostics.CodeAnalysis; // Test files create disposable objects that are managed by the test framework lifecycle diff --git a/tests/Core.Tests/Logging/ActivityEnricherTests.cs b/tests/Core.Tests/Logging/ActivityEnricherTests.cs index eff2737c9..1e92a018c 100644 --- a/tests/Core.Tests/Logging/ActivityEnricherTests.cs +++ b/tests/Core.Tests/Logging/ActivityEnricherTests.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. using System.Diagnostics; -using KernelMemory.Core.Logging; using Serilog.Events; namespace KernelMemory.Core.Tests.Logging; diff --git a/tests/Core.Tests/Logging/EnvironmentDetectorTests.cs b/tests/Core.Tests/Logging/EnvironmentDetectorTests.cs index 47ca8999b..a1fd52a09 100644 --- a/tests/Core.Tests/Logging/EnvironmentDetectorTests.cs +++ b/tests/Core.Tests/Logging/EnvironmentDetectorTests.cs @@ -1,7 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. 
-using KernelMemory.Core.Logging; - namespace KernelMemory.Core.Tests.Logging; /// @@ -19,8 +17,8 @@ public sealed class EnvironmentDetectorTests : IDisposable public EnvironmentDetectorTests() { // Capture original values to restore after tests - this._originalDotNetEnv = Environment.GetEnvironmentVariable(LoggingConstants.DotNetEnvironmentVariable); - this._originalAspNetEnv = Environment.GetEnvironmentVariable(LoggingConstants.AspNetCoreEnvironmentVariable); + this._originalDotNetEnv = Environment.GetEnvironmentVariable(Constants.LoggingDefaults.DotNetEnvironmentVariable); + this._originalAspNetEnv = Environment.GetEnvironmentVariable(Constants.LoggingDefaults.AspNetCoreEnvironmentVariable); } /// @@ -31,20 +29,20 @@ public void Dispose() // Restore original environment variables if (this._originalDotNetEnv != null) { - Environment.SetEnvironmentVariable(LoggingConstants.DotNetEnvironmentVariable, this._originalDotNetEnv); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.DotNetEnvironmentVariable, this._originalDotNetEnv); } else { - Environment.SetEnvironmentVariable(LoggingConstants.DotNetEnvironmentVariable, null); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.DotNetEnvironmentVariable, null); } if (this._originalAspNetEnv != null) { - Environment.SetEnvironmentVariable(LoggingConstants.AspNetCoreEnvironmentVariable, this._originalAspNetEnv); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.AspNetCoreEnvironmentVariable, this._originalAspNetEnv); } else { - Environment.SetEnvironmentVariable(LoggingConstants.AspNetCoreEnvironmentVariable, null); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.AspNetCoreEnvironmentVariable, null); } GC.SuppressFinalize(this); @@ -58,8 +56,8 @@ public void Dispose() public void GetEnvironment_WhenDotNetEnvSet_ShouldReturnDotNetEnv() { // Arrange - Environment.SetEnvironmentVariable(LoggingConstants.DotNetEnvironmentVariable, "Production"); - 
Environment.SetEnvironmentVariable(LoggingConstants.AspNetCoreEnvironmentVariable, "Staging"); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.DotNetEnvironmentVariable, "Production"); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.AspNetCoreEnvironmentVariable, "Staging"); // Act var result = EnvironmentDetector.GetEnvironment(); @@ -76,8 +74,8 @@ public void GetEnvironment_WhenDotNetEnvSet_ShouldReturnDotNetEnv() public void GetEnvironment_WhenOnlyAspNetEnvSet_ShouldReturnAspNetEnv() { // Arrange - Environment.SetEnvironmentVariable(LoggingConstants.DotNetEnvironmentVariable, null); - Environment.SetEnvironmentVariable(LoggingConstants.AspNetCoreEnvironmentVariable, "Staging"); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.DotNetEnvironmentVariable, null); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.AspNetCoreEnvironmentVariable, "Staging"); // Act var result = EnvironmentDetector.GetEnvironment(); @@ -94,8 +92,8 @@ public void GetEnvironment_WhenOnlyAspNetEnvSet_ShouldReturnAspNetEnv() public void GetEnvironment_WhenNothingSet_ShouldReturnDevelopment() { // Arrange - Environment.SetEnvironmentVariable(LoggingConstants.DotNetEnvironmentVariable, null); - Environment.SetEnvironmentVariable(LoggingConstants.AspNetCoreEnvironmentVariable, null); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.DotNetEnvironmentVariable, null); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.AspNetCoreEnvironmentVariable, null); // Act var result = EnvironmentDetector.GetEnvironment(); @@ -111,8 +109,8 @@ public void GetEnvironment_WhenNothingSet_ShouldReturnDevelopment() public void IsProduction_WhenProductionSet_ShouldReturnTrue() { // Arrange - clear both env vars to ensure isolation - Environment.SetEnvironmentVariable(LoggingConstants.DotNetEnvironmentVariable, "Production"); - Environment.SetEnvironmentVariable(LoggingConstants.AspNetCoreEnvironmentVariable, null); + 
Environment.SetEnvironmentVariable(Constants.LoggingDefaults.DotNetEnvironmentVariable, "Production"); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.AspNetCoreEnvironmentVariable, null); // Act var result = EnvironmentDetector.IsProduction(); @@ -129,8 +127,8 @@ public void IsProduction_WhenProductionSet_ShouldReturnTrue() public void IsProduction_WhenProductionLowercase_ShouldReturnTrue() { // Arrange - clear both env vars to ensure isolation - Environment.SetEnvironmentVariable(LoggingConstants.DotNetEnvironmentVariable, "production"); - Environment.SetEnvironmentVariable(LoggingConstants.AspNetCoreEnvironmentVariable, null); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.DotNetEnvironmentVariable, "production"); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.AspNetCoreEnvironmentVariable, null); // Act var result = EnvironmentDetector.IsProduction(); @@ -146,8 +144,8 @@ public void IsProduction_WhenProductionLowercase_ShouldReturnTrue() public void IsProduction_WhenProductionUppercase_ShouldReturnTrue() { // Arrange - clear both env vars to ensure isolation - Environment.SetEnvironmentVariable(LoggingConstants.DotNetEnvironmentVariable, "PRODUCTION"); - Environment.SetEnvironmentVariable(LoggingConstants.AspNetCoreEnvironmentVariable, null); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.DotNetEnvironmentVariable, "PRODUCTION"); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.AspNetCoreEnvironmentVariable, null); // Act var result = EnvironmentDetector.IsProduction(); @@ -164,8 +162,8 @@ public void IsProduction_WhenDevelopment_ShouldReturnFalse() { // Arrange - set DOTNET_ENVIRONMENT to Development (takes precedence over ASPNETCORE_ENVIRONMENT) // Set both to Development to ensure no Production leaks from other tests - Environment.SetEnvironmentVariable(LoggingConstants.DotNetEnvironmentVariable, "Development"); - 
Environment.SetEnvironmentVariable(LoggingConstants.AspNetCoreEnvironmentVariable, string.Empty); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.DotNetEnvironmentVariable, "Development"); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.AspNetCoreEnvironmentVariable, string.Empty); // Act var result = EnvironmentDetector.IsProduction(); @@ -181,8 +179,8 @@ public void IsProduction_WhenDevelopment_ShouldReturnFalse() public void IsProduction_WhenStaging_ShouldReturnFalse() { // Arrange - clear both env vars to ensure isolation - Environment.SetEnvironmentVariable(LoggingConstants.DotNetEnvironmentVariable, "Staging"); - Environment.SetEnvironmentVariable(LoggingConstants.AspNetCoreEnvironmentVariable, null); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.DotNetEnvironmentVariable, "Staging"); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.AspNetCoreEnvironmentVariable, null); // Act var result = EnvironmentDetector.IsProduction(); @@ -199,8 +197,8 @@ public void IsProduction_WhenStaging_ShouldReturnFalse() public void IsProduction_WhenNotSet_ShouldReturnFalse() { // Arrange - Environment.SetEnvironmentVariable(LoggingConstants.DotNetEnvironmentVariable, null); - Environment.SetEnvironmentVariable(LoggingConstants.AspNetCoreEnvironmentVariable, null); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.DotNetEnvironmentVariable, null); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.AspNetCoreEnvironmentVariable, null); // Act var result = EnvironmentDetector.IsProduction(); @@ -216,8 +214,8 @@ public void IsProduction_WhenNotSet_ShouldReturnFalse() public void IsDevelopment_WhenDevelopmentSet_ShouldReturnTrue() { // Arrange - clear both env vars to ensure isolation - Environment.SetEnvironmentVariable(LoggingConstants.DotNetEnvironmentVariable, "Development"); - Environment.SetEnvironmentVariable(LoggingConstants.AspNetCoreEnvironmentVariable, null); + 
Environment.SetEnvironmentVariable(Constants.LoggingDefaults.DotNetEnvironmentVariable, "Development"); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.AspNetCoreEnvironmentVariable, null); // Act var result = EnvironmentDetector.IsDevelopment(); @@ -233,8 +231,8 @@ public void IsDevelopment_WhenDevelopmentSet_ShouldReturnTrue() public void IsDevelopment_WhenNotSet_ShouldReturnTrue() { // Arrange - Environment.SetEnvironmentVariable(LoggingConstants.DotNetEnvironmentVariable, null); - Environment.SetEnvironmentVariable(LoggingConstants.AspNetCoreEnvironmentVariable, null); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.DotNetEnvironmentVariable, null); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.AspNetCoreEnvironmentVariable, null); // Act var result = EnvironmentDetector.IsDevelopment(); @@ -250,8 +248,8 @@ public void IsDevelopment_WhenNotSet_ShouldReturnTrue() public void IsDevelopment_WhenProduction_ShouldReturnFalse() { // Arrange - clear both env vars to ensure isolation - Environment.SetEnvironmentVariable(LoggingConstants.DotNetEnvironmentVariable, "Production"); - Environment.SetEnvironmentVariable(LoggingConstants.AspNetCoreEnvironmentVariable, null); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.DotNetEnvironmentVariable, "Production"); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.AspNetCoreEnvironmentVariable, null); // Act var result = EnvironmentDetector.IsDevelopment(); @@ -267,8 +265,8 @@ public void IsDevelopment_WhenProduction_ShouldReturnFalse() public void GetEnvironment_WhenDotNetEnvIsEmpty_ShouldFallbackToAspNet() { // Arrange - Environment.SetEnvironmentVariable(LoggingConstants.DotNetEnvironmentVariable, string.Empty); - Environment.SetEnvironmentVariable(LoggingConstants.AspNetCoreEnvironmentVariable, "Staging"); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.DotNetEnvironmentVariable, string.Empty); + 
Environment.SetEnvironmentVariable(Constants.LoggingDefaults.AspNetCoreEnvironmentVariable, "Staging"); // Act var result = EnvironmentDetector.GetEnvironment(); @@ -284,8 +282,8 @@ public void GetEnvironment_WhenDotNetEnvIsEmpty_ShouldFallbackToAspNet() public void GetEnvironment_WhenDotNetEnvIsWhitespace_ShouldFallbackToAspNet() { // Arrange - Environment.SetEnvironmentVariable(LoggingConstants.DotNetEnvironmentVariable, " "); - Environment.SetEnvironmentVariable(LoggingConstants.AspNetCoreEnvironmentVariable, "Staging"); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.DotNetEnvironmentVariable, " "); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.AspNetCoreEnvironmentVariable, "Staging"); // Act var result = EnvironmentDetector.GetEnvironment(); diff --git a/tests/Core.Tests/Logging/LoggerExtensionsTests.cs b/tests/Core.Tests/Logging/LoggerExtensionsTests.cs index 90c404c89..5ccce5cbf 100644 --- a/tests/Core.Tests/Logging/LoggerExtensionsTests.cs +++ b/tests/Core.Tests/Logging/LoggerExtensionsTests.cs @@ -1,6 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. -using KernelMemory.Core.Logging; using Microsoft.Extensions.Logging; using Moq; diff --git a/tests/Core.Tests/Logging/LoggingConfigTests.cs b/tests/Core.Tests/Logging/LoggingConfigTests.cs index 33c3e0c0e..26faba254 100644 --- a/tests/Core.Tests/Logging/LoggingConfigTests.cs +++ b/tests/Core.Tests/Logging/LoggingConfigTests.cs @@ -1,6 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. 
-using KernelMemory.Core.Logging; using Serilog.Events; namespace KernelMemory.Core.Tests.Logging; diff --git a/tests/Core.Tests/Logging/LoggingConstantsTests.cs b/tests/Core.Tests/Logging/LoggingConstantsTests.cs index 0f502a210..25a7d3d11 100644 --- a/tests/Core.Tests/Logging/LoggingConstantsTests.cs +++ b/tests/Core.Tests/Logging/LoggingConstantsTests.cs @@ -19,7 +19,7 @@ public void DefaultFileSizeLimitBytes_ShouldBe100MB() const long expectedBytes = 100 * 1024 * 1024; // Assert - Assert.Equal(expectedBytes, Core.Logging.LoggingConstants.DefaultFileSizeLimitBytes); + Assert.Equal(expectedBytes, Constants.LoggingDefaults.DefaultFileSizeLimitBytes); } /// @@ -30,7 +30,7 @@ public void DefaultFileSizeLimitBytes_ShouldBe100MB() public void DefaultRetainedFileCountLimit_ShouldBe30() { // Assert - Assert.Equal(30, Core.Logging.LoggingConstants.DefaultRetainedFileCountLimit); + Assert.Equal(30, Constants.LoggingDefaults.DefaultRetainedFileCountLimit); } /// @@ -41,7 +41,7 @@ public void DefaultRetainedFileCountLimit_ShouldBe30() public void DefaultFileLogLevel_ShouldBeInformation() { // Assert - Assert.Equal(Serilog.Events.LogEventLevel.Information, Core.Logging.LoggingConstants.DefaultFileLogLevel); + Assert.Equal(Serilog.Events.LogEventLevel.Information, Constants.LoggingDefaults.DefaultFileLogLevel); } /// @@ -52,7 +52,7 @@ public void DefaultFileLogLevel_ShouldBeInformation() public void DefaultConsoleLogLevel_ShouldBeWarning() { // Assert - Assert.Equal(Serilog.Events.LogEventLevel.Warning, Core.Logging.LoggingConstants.DefaultConsoleLogLevel); + Assert.Equal(Serilog.Events.LogEventLevel.Warning, Constants.LoggingDefaults.DefaultConsoleLogLevel); } /// @@ -62,7 +62,7 @@ public void DefaultConsoleLogLevel_ShouldBeWarning() public void DotNetEnvironmentVariable_ShouldBeDefined() { // Assert - Assert.Equal("DOTNET_ENVIRONMENT", Core.Logging.LoggingConstants.DotNetEnvironmentVariable); + Assert.Equal("DOTNET_ENVIRONMENT", 
Constants.LoggingDefaults.DotNetEnvironmentVariable); } /// @@ -72,7 +72,7 @@ public void DotNetEnvironmentVariable_ShouldBeDefined() public void AspNetCoreEnvironmentVariable_ShouldBeDefined() { // Assert - Assert.Equal("ASPNETCORE_ENVIRONMENT", Core.Logging.LoggingConstants.AspNetCoreEnvironmentVariable); + Assert.Equal("ASPNETCORE_ENVIRONMENT", Constants.LoggingDefaults.AspNetCoreEnvironmentVariable); } /// @@ -82,7 +82,7 @@ public void AspNetCoreEnvironmentVariable_ShouldBeDefined() public void DefaultEnvironment_ShouldBeDevelopment() { // Assert - Assert.Equal("Development", Core.Logging.LoggingConstants.DefaultEnvironment); + Assert.Equal("Development", Constants.LoggingDefaults.DefaultEnvironment); } /// @@ -92,7 +92,7 @@ public void DefaultEnvironment_ShouldBeDevelopment() public void ProductionEnvironment_ShouldBeDefined() { // Assert - Assert.Equal("Production", Core.Logging.LoggingConstants.ProductionEnvironment); + Assert.Equal("Production", Constants.LoggingDefaults.ProductionEnvironment); } /// @@ -102,7 +102,7 @@ public void ProductionEnvironment_ShouldBeDefined() public void RedactedPlaceholder_ShouldBeDefined() { // Assert - Assert.Equal("[REDACTED]", Core.Logging.LoggingConstants.RedactedPlaceholder); + Assert.Equal("[REDACTED]", Constants.LoggingDefaults.RedactedPlaceholder); } /// @@ -112,7 +112,7 @@ public void RedactedPlaceholder_ShouldBeDefined() public void HumanReadableOutputTemplate_ShouldContainTimestampAndLevel() { // Arrange & Act - const string template = Core.Logging.LoggingConstants.HumanReadableOutputTemplate; + const string template = Constants.LoggingDefaults.HumanReadableOutputTemplate; // Assert - template should contain key elements Assert.Contains("{Timestamp", template); diff --git a/tests/Core.Tests/Logging/SensitiveDataScrubbingPolicyTests.cs b/tests/Core.Tests/Logging/SensitiveDataScrubbingPolicyTests.cs index b8d9d6ca2..09419509a 100644 --- a/tests/Core.Tests/Logging/SensitiveDataScrubbingPolicyTests.cs +++ 
b/tests/Core.Tests/Logging/SensitiveDataScrubbingPolicyTests.cs @@ -1,6 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. -using KernelMemory.Core.Logging; using Serilog.Core; using Serilog.Events; @@ -21,7 +20,7 @@ public sealed class SensitiveDataScrubbingPolicyTests : IDisposable /// public SensitiveDataScrubbingPolicyTests() { - this._originalDotNetEnv = Environment.GetEnvironmentVariable(LoggingConstants.DotNetEnvironmentVariable); + this._originalDotNetEnv = Environment.GetEnvironmentVariable(Constants.LoggingDefaults.DotNetEnvironmentVariable); this._policy = new SensitiveDataScrubbingPolicy(); } @@ -32,11 +31,11 @@ public void Dispose() { if (this._originalDotNetEnv != null) { - Environment.SetEnvironmentVariable(LoggingConstants.DotNetEnvironmentVariable, this._originalDotNetEnv); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.DotNetEnvironmentVariable, this._originalDotNetEnv); } else { - Environment.SetEnvironmentVariable(LoggingConstants.DotNetEnvironmentVariable, null); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.DotNetEnvironmentVariable, null); } GC.SuppressFinalize(this); @@ -50,7 +49,7 @@ public void Dispose() public void TryDestructure_WhenProductionAndString_ShouldScrub() { // Arrange - Environment.SetEnvironmentVariable(LoggingConstants.DotNetEnvironmentVariable, "Production"); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.DotNetEnvironmentVariable, "Production"); const string sensitiveValue = "secret-api-key-12345"; // Act @@ -60,7 +59,7 @@ public void TryDestructure_WhenProductionAndString_ShouldScrub() Assert.True(handled); Assert.NotNull(result); Assert.IsType<ScalarValue>(result); - Assert.Equal(LoggingConstants.RedactedPlaceholder, ((ScalarValue)result).Value); + Assert.Equal(Constants.LoggingDefaults.RedactedPlaceholder, ((ScalarValue)result).Value); } /// @@ -71,7 +70,7 @@ public void TryDestructure_WhenProductionAndString_ShouldScrub() public void
TryDestructure_WhenDevelopmentAndString_ShouldNotScrub() { // Arrange - Environment.SetEnvironmentVariable(LoggingConstants.DotNetEnvironmentVariable, "Development"); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.DotNetEnvironmentVariable, "Development"); const string value = "test-value"; // Act @@ -90,8 +89,8 @@ public void TryDestructure_WhenDevelopmentAndString_ShouldNotScrub() public void TryDestructure_WhenNoEnvironmentAndString_ShouldNotScrub() { // Arrange - Environment.SetEnvironmentVariable(LoggingConstants.DotNetEnvironmentVariable, null); - Environment.SetEnvironmentVariable(LoggingConstants.AspNetCoreEnvironmentVariable, null); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.DotNetEnvironmentVariable, null); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.AspNetCoreEnvironmentVariable, null); const string value = "test-value"; // Act @@ -110,7 +109,7 @@ public void TryDestructure_WhenNoEnvironmentAndString_ShouldNotScrub() public void TryDestructure_WhenProductionAndInteger_ShouldNotScrub() { // Arrange - Environment.SetEnvironmentVariable(LoggingConstants.DotNetEnvironmentVariable, "Production"); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.DotNetEnvironmentVariable, "Production"); // Act var handled = this._policy.TryDestructure(42, new TestPropertyValueFactory(), out var result); @@ -128,7 +127,7 @@ public void TryDestructure_WhenProductionAndInteger_ShouldNotScrub() public void TryDestructure_WhenProductionAndDateTime_ShouldNotScrub() { // Arrange - Environment.SetEnvironmentVariable(LoggingConstants.DotNetEnvironmentVariable, "Production"); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.DotNetEnvironmentVariable, "Production"); var dateTime = DateTimeOffset.UtcNow; // Act @@ -147,7 +146,7 @@ public void TryDestructure_WhenProductionAndDateTime_ShouldNotScrub() public void TryDestructure_WhenProductionAndBoolean_ShouldNotScrub() { // Arrange - 
Environment.SetEnvironmentVariable(LoggingConstants.DotNetEnvironmentVariable, "Production"); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.DotNetEnvironmentVariable, "Production"); // Act var handled = this._policy.TryDestructure(true, new TestPropertyValueFactory(), out var result); @@ -165,7 +164,7 @@ public void TryDestructure_WhenProductionAndBoolean_ShouldNotScrub() public void TryDestructure_WhenProductionAndGuid_ShouldNotScrub() { // Arrange - Environment.SetEnvironmentVariable(LoggingConstants.DotNetEnvironmentVariable, "Production"); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.DotNetEnvironmentVariable, "Production"); var guid = Guid.NewGuid(); // Act @@ -184,7 +183,7 @@ public void TryDestructure_WhenProductionAndGuid_ShouldNotScrub() public void TryDestructure_WhenProductionAndEmptyString_ShouldScrub() { // Arrange - Environment.SetEnvironmentVariable(LoggingConstants.DotNetEnvironmentVariable, "Production"); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.DotNetEnvironmentVariable, "Production"); // Act var handled = this._policy.TryDestructure(string.Empty, new TestPropertyValueFactory(), out var result); @@ -192,7 +191,7 @@ public void TryDestructure_WhenProductionAndEmptyString_ShouldScrub() // Assert Assert.True(handled); Assert.NotNull(result); - Assert.Equal(LoggingConstants.RedactedPlaceholder, ((ScalarValue)result).Value); + Assert.Equal(Constants.LoggingDefaults.RedactedPlaceholder, ((ScalarValue)result).Value); } /// @@ -203,7 +202,7 @@ public void TryDestructure_WhenProductionAndEmptyString_ShouldScrub() public void TryDestructure_WhenProductionAndNull_ShouldNotScrub() { // Arrange - Environment.SetEnvironmentVariable(LoggingConstants.DotNetEnvironmentVariable, "Production"); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.DotNetEnvironmentVariable, "Production"); // Act var handled = this._policy.TryDestructure(null!, new TestPropertyValueFactory(), out var result); @@ -221,7 
+220,7 @@ public void TryDestructure_WhenProductionAndNull_ShouldNotScrub() public void TryDestructure_WhenStagingAndString_ShouldNotScrub() { // Arrange - Environment.SetEnvironmentVariable(LoggingConstants.DotNetEnvironmentVariable, "Staging"); + Environment.SetEnvironmentVariable(Constants.LoggingDefaults.DotNetEnvironmentVariable, "Staging"); const string value = "test-value"; // Act diff --git a/tests/Core.Tests/Logging/SerilogFactoryTests.cs b/tests/Core.Tests/Logging/SerilogFactoryTests.cs index c772058bb..2772ce9fa 100644 --- a/tests/Core.Tests/Logging/SerilogFactoryTests.cs +++ b/tests/Core.Tests/Logging/SerilogFactoryTests.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. using System.Diagnostics.CodeAnalysis; -using KernelMemory.Core.Logging; using Microsoft.Extensions.Logging; using Serilog.Events; diff --git a/tests/Core.Tests/Logging/TestLoggerFactory.cs b/tests/Core.Tests/Logging/TestLoggerFactory.cs index 5db62384a..580702960 100644 --- a/tests/Core.Tests/Logging/TestLoggerFactory.cs +++ b/tests/Core.Tests/Logging/TestLoggerFactory.cs @@ -2,7 +2,6 @@ using System.Diagnostics.CodeAnalysis; using System.Globalization; -using KernelMemory.Core.Logging; using Microsoft.Extensions.Logging; using Serilog; using Serilog.Extensions.Logging; diff --git a/tests/Core.Tests/Logging/TestLoggerFactoryTests.cs b/tests/Core.Tests/Logging/TestLoggerFactoryTests.cs index cc1f2eb1f..92d8842dd 100644 --- a/tests/Core.Tests/Logging/TestLoggerFactoryTests.cs +++ b/tests/Core.Tests/Logging/TestLoggerFactoryTests.cs @@ -1,6 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. 
-using KernelMemory.Core.Logging; using Microsoft.Extensions.Logging; using Xunit.Abstractions; diff --git a/tests/Core.Tests/Search/FtsIndexPersistenceTest.cs b/tests/Core.Tests/Search/FtsIndexPersistenceTest.cs index 0fc4c1169..e7e985361 100644 --- a/tests/Core.Tests/Search/FtsIndexPersistenceTest.cs +++ b/tests/Core.Tests/Search/FtsIndexPersistenceTest.cs @@ -1,6 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. -using KernelMemory.Core.Search; using Microsoft.Extensions.Logging; using Moq; diff --git a/tests/Core.Tests/Search/FtsIntegrationTests.cs b/tests/Core.Tests/Search/FtsIntegrationTests.cs index bfa656540..ea12f51a0 100644 --- a/tests/Core.Tests/Search/FtsIntegrationTests.cs +++ b/tests/Core.Tests/Search/FtsIntegrationTests.cs @@ -1,5 +1,4 @@ // Copyright (c) Microsoft. All rights reserved. -using KernelMemory.Core.Search; using KernelMemory.Core.Storage; using KernelMemory.Core.Storage.Models; using Microsoft.Data.Sqlite; @@ -61,7 +60,7 @@ public FtsIntegrationTests() this._context, this._mockCuidGenerator.Object, this._mockStorageLogger.Object, - searchIndexById); + (IReadOnlyDictionary)searchIndexById); } public void Dispose() diff --git a/tests/Core.Tests/Search/FtsQueryExtractionTest.cs b/tests/Core.Tests/Search/FtsQueryExtractionTest.cs index f444db1ef..85e529e32 100644 --- a/tests/Core.Tests/Search/FtsQueryExtractionTest.cs +++ b/tests/Core.Tests/Search/FtsQueryExtractionTest.cs @@ -1,6 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. -using KernelMemory.Core.Search; using KernelMemory.Core.Search.Models; using KernelMemory.Core.Search.Query.Parsers; using KernelMemory.Core.Storage; diff --git a/tests/Core.Tests/Search/Models/SearchRequestTests.cs b/tests/Core.Tests/Search/Models/SearchRequestTests.cs index 5153cdc45..25c771fb2 100644 --- a/tests/Core.Tests/Search/Models/SearchRequestTests.cs +++ b/tests/Core.Tests/Search/Models/SearchRequestTests.cs @@ -1,5 +1,4 @@ // Copyright (c) Microsoft. All rights reserved. 
-using KernelMemory.Core.Search; using KernelMemory.Core.Search.Models; namespace KernelMemory.Core.Tests.Search.Models; @@ -21,9 +20,9 @@ public void DefaultValues_AreCorrect() Assert.Empty(request.ExcludeNodes); Assert.Empty(request.SearchIndexes); Assert.Empty(request.ExcludeIndexes); - Assert.Equal(SearchConstants.DefaultLimit, request.Limit); + Assert.Equal(Constants.SearchDefaults.DefaultLimit, request.Limit); Assert.Equal(0, request.Offset); - Assert.Equal(SearchConstants.DefaultMinRelevance, request.MinRelevance); + Assert.Equal(Constants.SearchDefaults.DefaultMinRelevance, request.MinRelevance); Assert.Null(request.MaxResultsPerNode); Assert.Null(request.NodeWeights); Assert.False(request.SnippetOnly); diff --git a/tests/Core.Tests/Search/NodeSearchServiceIndexIdTests.cs b/tests/Core.Tests/Search/NodeSearchServiceIndexIdTests.cs index 628326182..a04872aa8 100644 --- a/tests/Core.Tests/Search/NodeSearchServiceIndexIdTests.cs +++ b/tests/Core.Tests/Search/NodeSearchServiceIndexIdTests.cs @@ -1,6 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. -using KernelMemory.Core.Search; using KernelMemory.Core.Search.Models; using KernelMemory.Core.Search.Query.Ast; using KernelMemory.Core.Storage; @@ -137,7 +136,7 @@ await storage.UpsertAsync(new KernelMemory.Core.Storage.Models.UpsertRequest // Assert Assert.NotEmpty(results); - Assert.All(results, r => Assert.Equal(SearchConstants.DefaultFtsIndexId, r.IndexId)); + Assert.All(results, r => Assert.Equal(Constants.SearchDefaults.DefaultFtsIndexId, r.IndexId)); } /// @@ -193,14 +192,14 @@ await storage2.UpsertAsync(new KernelMemory.Core.Storage.Models.UpsertRequest } /// - /// Tests that SearchConstants.DefaultFtsIndexId constant has the expected value. + /// Tests that Constants.SearchDefaults.DefaultFtsIndexId constant has the expected value. /// Validates the constant is properly defined. 
/// [Fact] public void DefaultFtsIndexId_HasExpectedValue() { // Assert - Assert.Equal("fts-main", SearchConstants.DefaultFtsIndexId); + Assert.Equal("fts-main", Constants.SearchDefaults.DefaultFtsIndexId); } /// @@ -226,7 +225,7 @@ public void DefaultFtsIndexId_HasExpectedValue() var cuidGenerator = new CuidGenerator(); var searchIndexes = new Dictionary { ["fts"] = ftsIndex }; - var storage = new ContentStorageService(context, cuidGenerator, this._mockStorageLogger.Object, searchIndexes); + var storage = new ContentStorageService(context, cuidGenerator, this._mockStorageLogger.Object, (IReadOnlyDictionary)searchIndexes); return (ftsIndex, storage); } diff --git a/tests/Core.Tests/Search/SearchConstantsTests.cs b/tests/Core.Tests/Search/SearchConstantsTests.cs index c6cbc41de..c51ed9a72 100644 --- a/tests/Core.Tests/Search/SearchConstantsTests.cs +++ b/tests/Core.Tests/Search/SearchConstantsTests.cs @@ -1,6 +1,4 @@ // Copyright (c) Microsoft. All rights reserved. -using KernelMemory.Core.Search; - namespace KernelMemory.Core.Tests.Search; /// @@ -12,40 +10,40 @@ public sealed class SearchConstantsTests public void DefaultValues_AreCorrect() { // Verify default values match requirements - Assert.Equal(0.3f, SearchConstants.DefaultMinRelevance); - Assert.Equal(20, SearchConstants.DefaultLimit); - Assert.Equal(30, SearchConstants.DefaultSearchTimeoutSeconds); - Assert.Equal(1000, SearchConstants.DefaultMaxResultsPerNode); - Assert.Equal(1.0f, SearchConstants.DefaultNodeWeight); - Assert.Equal(1.0f, SearchConstants.DefaultIndexWeight); + Assert.Equal(0.3f, Constants.SearchDefaults.DefaultMinRelevance); + Assert.Equal(20, Constants.SearchDefaults.DefaultLimit); + Assert.Equal(30, Constants.SearchDefaults.DefaultSearchTimeoutSeconds); + Assert.Equal(1000, Constants.SearchDefaults.DefaultMaxResultsPerNode); + Assert.Equal(1.0f, Constants.SearchDefaults.DefaultNodeWeight); + Assert.Equal(1.0f, Constants.SearchDefaults.DefaultIndexWeight); } [Fact] public void 
QueryComplexityLimits_AreReasonable() { // Verify query complexity limits are set - Assert.Equal(10, SearchConstants.MaxQueryDepth); - Assert.Equal(50, SearchConstants.MaxBooleanOperators); - Assert.Equal(1000, SearchConstants.MaxFieldValueLength); - Assert.Equal(1000, SearchConstants.QueryParseTimeoutMs); + Assert.Equal(10, Constants.SearchDefaults.MaxQueryDepth); + Assert.Equal(50, Constants.SearchDefaults.MaxBooleanOperators); + Assert.Equal(1000, Constants.SearchDefaults.MaxFieldValueLength); + Assert.Equal(1000, Constants.SearchDefaults.QueryParseTimeoutMs); } [Fact] public void SnippetDefaults_AreConfigured() { // Verify snippet configuration - Assert.Equal(200, SearchConstants.DefaultSnippetLength); - Assert.Equal(1, SearchConstants.DefaultMaxSnippetsPerResult); - Assert.Equal("...", SearchConstants.DefaultSnippetSeparator); - Assert.Equal("", SearchConstants.DefaultHighlightPrefix); - Assert.Equal("", SearchConstants.DefaultHighlightSuffix); + Assert.Equal(200, Constants.SearchDefaults.DefaultSnippetLength); + Assert.Equal(1, Constants.SearchDefaults.DefaultMaxSnippetsPerResult); + Assert.Equal("...", Constants.SearchDefaults.DefaultSnippetSeparator); + Assert.Equal("", Constants.SearchDefaults.DefaultHighlightPrefix); + Assert.Equal("", Constants.SearchDefaults.DefaultHighlightSuffix); } [Fact] public void DiminishingMultipliers_FollowPattern() { // Verify diminishing returns pattern (each is half of previous) - var multipliers = SearchConstants.DefaultDiminishingMultipliers; + var multipliers = Constants.SearchDefaults.DefaultDiminishingMultipliers; Assert.Equal(4, multipliers.Length); Assert.Equal(1.0f, multipliers[0]); Assert.Equal(0.5f, multipliers[1]); @@ -57,14 +55,14 @@ public void DiminishingMultipliers_FollowPattern() public void RelevanceScoreBounds_AreCorrect() { // Verify score boundaries - Assert.Equal(1.0f, SearchConstants.MaxRelevanceScore); - Assert.Equal(0.0f, SearchConstants.MinRelevanceScore); + Assert.Equal(1.0f, 
Constants.SearchDefaults.MaxRelevanceScore); + Assert.Equal(0.0f, Constants.SearchDefaults.MinRelevanceScore); } [Fact] public void AllNodesWildcard_IsAsterisk() { // Verify wildcard character - Assert.Equal("*", SearchConstants.AllNodesWildcard); + Assert.Equal("*", Constants.SearchDefaults.AllNodesWildcard); } } diff --git a/tests/Core.Tests/Search/SearchEndToEndTests.cs b/tests/Core.Tests/Search/SearchEndToEndTests.cs index 7792982cf..dd82564b1 100644 --- a/tests/Core.Tests/Search/SearchEndToEndTests.cs +++ b/tests/Core.Tests/Search/SearchEndToEndTests.cs @@ -1,6 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. -using KernelMemory.Core.Search; using KernelMemory.Core.Search.Models; using KernelMemory.Core.Storage; using Microsoft.EntityFrameworkCore; @@ -44,7 +43,7 @@ public SearchEndToEndTests() this._ftsIndex = new SqliteFtsIndex(ftsDbPath, enableStemming: true, mockFtsLogger.Object); var searchIndexes = new Dictionary { ["fts"] = this._ftsIndex }; - this._storage = new ContentStorageService(this._context, cuidGenerator, mockStorageLogger.Object, searchIndexes); + this._storage = new ContentStorageService(this._context, cuidGenerator, mockStorageLogger.Object, (IReadOnlyDictionary)searchIndexes); var nodeService = new NodeSearchService("test-node", this._ftsIndex, this._storage); this._searchService = new SearchService(new Dictionary { ["test-node"] = nodeService }); diff --git a/tests/Core.Tests/Search/SearchServiceFunctionalTests.cs b/tests/Core.Tests/Search/SearchServiceFunctionalTests.cs index 808f131d0..b85f7ff1a 100644 --- a/tests/Core.Tests/Search/SearchServiceFunctionalTests.cs +++ b/tests/Core.Tests/Search/SearchServiceFunctionalTests.cs @@ -1,6 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. 
-using KernelMemory.Core.Search; using KernelMemory.Core.Search.Models; using KernelMemory.Core.Storage; using Microsoft.EntityFrameworkCore; @@ -46,7 +45,7 @@ public SearchServiceFunctionalTests() var fts1DbPath = Path.Combine(this._tempDir, "node1_fts.db"); this._fts1 = new SqliteFtsIndex(fts1DbPath, enableStemming: true, mockFtsLogger1.Object); var searchIndexes1 = new Dictionary { ["fts"] = this._fts1 }; - this._storage1 = new ContentStorageService(this._context1, cuidGenerator, mockStorageLogger1.Object, searchIndexes1); + this._storage1 = new ContentStorageService(this._context1, cuidGenerator, mockStorageLogger1.Object, (IReadOnlyDictionary)searchIndexes1); var node1Service = new NodeSearchService("node1", this._fts1, this._storage1); // Node 2 @@ -60,7 +59,7 @@ public SearchServiceFunctionalTests() var fts2DbPath = Path.Combine(this._tempDir, "node2_fts.db"); this._fts2 = new SqliteFtsIndex(fts2DbPath, enableStemming: true, mockFtsLogger2.Object); var searchIndexes2 = new Dictionary { ["fts"] = this._fts2 }; - this._storage2 = new ContentStorageService(this._context2, cuidGenerator, mockStorageLogger2.Object, searchIndexes2); + this._storage2 = new ContentStorageService(this._context2, cuidGenerator, mockStorageLogger2.Object, (IReadOnlyDictionary)searchIndexes2); var node2Service = new NodeSearchService("node2", this._fts2, this._storage2); var nodeServices = new Dictionary diff --git a/tests/Core.Tests/Search/SearchServiceIndexWeightsTests.cs b/tests/Core.Tests/Search/SearchServiceIndexWeightsTests.cs index 768298eaa..9f840ee30 100644 --- a/tests/Core.Tests/Search/SearchServiceIndexWeightsTests.cs +++ b/tests/Core.Tests/Search/SearchServiceIndexWeightsTests.cs @@ -1,6 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. 
-using KernelMemory.Core.Search; using KernelMemory.Core.Search.Models; using KernelMemory.Core.Storage; using Microsoft.EntityFrameworkCore; @@ -43,7 +42,7 @@ public SearchServiceIndexWeightsTests() this._ftsIndex = new SqliteFtsIndex(ftsDbPath, enableStemming: true, mockFtsLogger.Object); var searchIndexes = new Dictionary { ["fts"] = this._ftsIndex }; - this._storage = new ContentStorageService(this._context, cuidGenerator, mockStorageLogger.Object, searchIndexes); + this._storage = new ContentStorageService(this._context, cuidGenerator, mockStorageLogger.Object, (IReadOnlyDictionary)searchIndexes); } public void Dispose() @@ -79,7 +78,7 @@ await this._storage.UpsertAsync(new KernelMemory.Core.Storage.Models.UpsertReque { ["test-node"] = new Dictionary { - [SearchConstants.DefaultFtsIndexId] = 0.5f // Custom weight + [Constants.SearchDefaults.DefaultFtsIndexId] = 0.5f // Custom weight } }; @@ -161,7 +160,7 @@ await this._storage.UpsertAsync(new KernelMemory.Core.Storage.Models.UpsertReque { ["test-node"] = new Dictionary { - [SearchConstants.DefaultFtsIndexId] = 0.7f, // FTS index weight + [Constants.SearchDefaults.DefaultFtsIndexId] = 0.7f, // FTS index weight ["vector-main"] = 0.3f // Vector index weight (not used here, but configured) } }; diff --git a/tests/Core.Tests/Search/SimpleSearchTest.cs b/tests/Core.Tests/Search/SimpleSearchTest.cs index f1f47027b..3e05a4db6 100644 --- a/tests/Core.Tests/Search/SimpleSearchTest.cs +++ b/tests/Core.Tests/Search/SimpleSearchTest.cs @@ -1,6 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. 
-using KernelMemory.Core.Search; using KernelMemory.Core.Search.Models; using KernelMemory.Core.Storage; using Microsoft.EntityFrameworkCore; diff --git a/tests/Core.Tests/Search/SqliteFtsIndexTests.cs b/tests/Core.Tests/Search/SqliteFtsIndexTests.cs index 8b72e8320..d792282c8 100644 --- a/tests/Core.Tests/Search/SqliteFtsIndexTests.cs +++ b/tests/Core.Tests/Search/SqliteFtsIndexTests.cs @@ -1,5 +1,4 @@ // Copyright (c) Microsoft. All rights reserved. -using KernelMemory.Core.Search; using Microsoft.Extensions.Logging; using Moq; diff --git a/tests/Core.Tests/Search/SqliteVectorIndexErrorHandlingTests.cs b/tests/Core.Tests/Search/SqliteVectorIndexErrorHandlingTests.cs new file mode 100644 index 000000000..49e784367 --- /dev/null +++ b/tests/Core.Tests/Search/SqliteVectorIndexErrorHandlingTests.cs @@ -0,0 +1,237 @@ +// Copyright (c) Microsoft. All rights reserved. + +using KernelMemory.Core.Config.Enums; +using KernelMemory.Core.Embeddings.Cache; +using Microsoft.Extensions.Logging; +using Moq; + +namespace KernelMemory.Core.Tests.Search; + +/// +/// Tests for SqliteVectorIndex error handling scenarios. +/// Validates that errors are handled gracefully with appropriate warnings/exceptions. 
+/// +public sealed class SqliteVectorIndexErrorHandlingTests : IDisposable +{ + private readonly string _dbPath; + private readonly Mock _mockGenerator; + private readonly Mock> _mockLogger; + + public SqliteVectorIndexErrorHandlingTests() + { + this._dbPath = Path.Combine(Path.GetTempPath(), $"vector-error-test-{Guid.NewGuid()}.db"); + this._mockGenerator = new Mock(); + this._mockLogger = new Mock>(); + + // Setup mock generator to return predictable embeddings + this._mockGenerator.Setup(g => g.VectorDimensions).Returns(3); + this._mockGenerator.Setup(g => g.GenerateAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(EmbeddingResult.FromVector([1.0f, 0.0f, 0.0f])); + } + + public void Dispose() + { + if (File.Exists(this._dbPath)) + { + File.Delete(this._dbPath); + } + } + + /// + /// Verifies that cache write failures generate warnings but don't prevent indexing. + /// This tests the non-blocking cache error handling requirement. + /// + [Fact] + public async Task IndexAsync_WhenCacheWriteFails_ContinuesWithWarning() + { + // Arrange - Create a cache that fails on write + var mockCache = new Mock(); + mockCache.Setup(c => c.Mode).Returns(CacheModes.ReadWrite); + mockCache.Setup(c => c.TryGetAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync((CachedEmbedding?)null); // Cache miss + mockCache.Setup(c => c.StoreAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .ThrowsAsync(new IOException("Disk full")); // Cache write fails + + var mockCachedGenerator = new Mock(); + mockCachedGenerator.Setup(g => g.VectorDimensions).Returns(3); + mockCachedGenerator.Setup(g => g.GenerateAsync(It.IsAny(), It.IsAny())) + .ThrowsAsync(new IOException("Cache write failed")); // Simulates CachedEmbeddingGenerator catching cache error + + // But the actual generator should work + this._mockGenerator.Setup(g => g.GenerateAsync("test content", It.IsAny())) + .ReturnsAsync(EmbeddingResult.FromVector([1.0f, 0.0f, 0.0f])); + + using var index = new SqliteVectorIndex(this._dbPath, 3, 
useSqliteVec: false, this._mockGenerator.Object, this._mockLogger.Object); + + // Act - Should succeed despite cache error + await index.IndexAsync("test-id", "test content", CancellationToken.None).ConfigureAwait(false); + + // Assert - Warning should be logged but operation succeeds + // Verify data was actually stored + var results = await index.SearchAsync("test content", limit: 10, CancellationToken.None).ConfigureAwait(false); + Assert.Single(results); + Assert.Equal("test-id", results[0].ContentId); + } + + /// + /// Verifies that cache read failures don't prevent indexing. + /// System should continue without cache when read fails. + /// + [Fact] + public async Task IndexAsync_WhenCacheReadFails_ContinuesWithoutCache() + { + // Arrange - Cache that fails on read + var mockCache = new Mock(); + mockCache.Setup(c => c.Mode).Returns(CacheModes.ReadWrite); + mockCache.Setup(c => c.TryGetAsync(It.IsAny(), It.IsAny())) + .ThrowsAsync(new IOException("Cache corrupted")); + + // Even with cache read failure, generator should be called + this._mockGenerator.Setup(g => g.GenerateAsync("test", It.IsAny())) + .ReturnsAsync(EmbeddingResult.FromVector([1.0f, 0.0f, 0.0f])); + + using var index = new SqliteVectorIndex(this._dbPath, 3, useSqliteVec: false, this._mockGenerator.Object, this._mockLogger.Object); + + // Act - Should succeed by calling generator directly + await index.IndexAsync("id1", "test", CancellationToken.None).ConfigureAwait(false); + + // Assert - Data stored successfully + var results = await index.SearchAsync("test", 10, CancellationToken.None).ConfigureAwait(false); + Assert.Single(results); + } + + /// + /// Verifies that when embedding provider is unreachable, operation throws with clear message. + /// This is the blocking behavior - operation should be queued for retry. 
+ /// + [Fact] + public async Task IndexAsync_WhenProviderUnreachable_ThrowsWithClearMessage() + { + // Arrange - Generator that simulates Ollama being down + var failingGenerator = new Mock(); + failingGenerator.Setup(g => g.VectorDimensions).Returns(3); + failingGenerator.Setup(g => g.GenerateAsync(It.IsAny(), It.IsAny())) + .ThrowsAsync(new HttpRequestException("Connection refused")); + + using var index = new SqliteVectorIndex(this._dbPath, 3, useSqliteVec: false, failingGenerator.Object, this._mockLogger.Object); + + // Act & Assert - Should throw and propagate to caller + var ex = await Assert.ThrowsAsync(async () => + await index.IndexAsync("id1", "test content", CancellationToken.None).ConfigureAwait(false)).ConfigureAwait(false); + + Assert.Contains("Connection refused", ex.Message); + } + + /// + /// Verifies that invalid API key errors are propagated with clear messages. + /// Operation should fail and be queued for retry. + /// + [Fact] + public async Task IndexAsync_WhenApiKeyInvalid_ThrowsWithClearMessage() + { + // Arrange - Generator that simulates invalid API key + var failingGenerator = new Mock(); + failingGenerator.Setup(g => g.VectorDimensions).Returns(3); + failingGenerator.Setup(g => g.GenerateAsync(It.IsAny(), It.IsAny())) + .ThrowsAsync(new UnauthorizedAccessException("Invalid API key")); + + using var index = new SqliteVectorIndex(this._dbPath, 3, useSqliteVec: false, failingGenerator.Object, this._mockLogger.Object); + + // Act & Assert + var ex = await Assert.ThrowsAsync(async () => + await index.IndexAsync("id1", "test content", CancellationToken.None).ConfigureAwait(false)).ConfigureAwait(false); + + Assert.Contains("Invalid API key", ex.Message); + } + + /// + /// Verifies that when useSqliteVec is true but extension is unavailable, + /// system logs a warning and falls back to C# implementation. + /// This tests the graceful degradation requirement. + /// Warning is logged during first IndexAsync call (lazy initialization). 
+ /// + [Fact] + public async Task IndexAsync_WhenSqliteVecUnavailableButRequested_LogsWarningAndContinues() + { + // Arrange - Request sqlite-vec but it won't be available (no extension installed) + using var index = new SqliteVectorIndex( + this._dbPath, + dimensions: 3, + useSqliteVec: true, // Request extension + this._mockGenerator.Object, + this._mockLogger.Object); + + // Act - First IndexAsync triggers initialization and warning + await index.IndexAsync("test-id", "test content", CancellationToken.None).ConfigureAwait(false); + + // Assert - Operation should succeed + var results = await index.SearchAsync("test content", 10, CancellationToken.None).ConfigureAwait(false); + Assert.Single(results); + + // Verify warning was logged about sqlite-vec fallback + this._mockLogger.Verify( + x => x.Log( + LogLevel.Warning, + It.IsAny(), + It.Is((v, t) => v.ToString()!.Contains("sqlite-vec")), + It.IsAny(), + It.IsAny>()), + Times.AtLeastOnce); + } + + /// + /// Verifies that vector search produces same results whether using + /// sqlite-vec extension or C# implementation (when extension unavailable). + /// This ensures data format compatibility. 
+ /// + [Fact] + public async Task SearchAsync_WithAndWithoutSqliteVec_ProducesSameResults() + { + // Arrange - Create two indexes with same data + var dbPath1 = Path.Combine(Path.GetTempPath(), $"vec-test-1-{Guid.NewGuid()}.db"); + var dbPath2 = Path.Combine(Path.GetTempPath(), $"vec-test-2-{Guid.NewGuid()}.db"); + + try + { + this._mockGenerator.Setup(g => g.GenerateAsync("hello world", It.IsAny())) + .ReturnsAsync(EmbeddingResult.FromVector([0.6f, 0.8f, 0.0f])); + this._mockGenerator.Setup(g => g.GenerateAsync("goodbye world", It.IsAny())) + .ReturnsAsync(EmbeddingResult.FromVector([0.8f, 0.6f, 0.0f])); + this._mockGenerator.Setup(g => g.GenerateAsync("hello", It.IsAny())) + .ReturnsAsync(EmbeddingResult.FromVector([1.0f, 0.0f, 0.0f])); + + // Index without extension (C# implementation) + using var index1 = new SqliteVectorIndex(dbPath1, 3, useSqliteVec: false, this._mockGenerator.Object, this._mockLogger.Object); + await index1.IndexAsync("id1", "hello world", CancellationToken.None).ConfigureAwait(false); + await index1.IndexAsync("id2", "goodbye world", CancellationToken.None).ConfigureAwait(false); + + // Index with extension requested (will fall back to C# if unavailable) + using var index2 = new SqliteVectorIndex(dbPath2, 3, useSqliteVec: true, this._mockGenerator.Object, this._mockLogger.Object); + await index2.IndexAsync("id1", "hello world", CancellationToken.None).ConfigureAwait(false); + await index2.IndexAsync("id2", "goodbye world", CancellationToken.None).ConfigureAwait(false); + + // Act - Search both indexes + var results1 = await index1.SearchAsync("hello", 10, CancellationToken.None).ConfigureAwait(false); + var results2 = await index2.SearchAsync("hello", 10, CancellationToken.None).ConfigureAwait(false); + + // Assert - Results should be identical (same ranking, same scores) + Assert.Equal(results1.Count, results2.Count); + Assert.Equal(results1[0].ContentId, results2[0].ContentId); + Assert.Equal(results1[0].Score, results2[0].Score, 
precision: 5); + Assert.Equal(results1[1].ContentId, results2[1].ContentId); + Assert.Equal(results1[1].Score, results2[1].Score, precision: 5); + } + finally + { + if (File.Exists(dbPath1)) + { + File.Delete(dbPath1); + } + + if (File.Exists(dbPath2)) + { + File.Delete(dbPath2); + } + } + } +} diff --git a/tests/Core.Tests/Search/SqliteVectorIndexPersistenceTests.cs b/tests/Core.Tests/Search/SqliteVectorIndexPersistenceTests.cs new file mode 100644 index 000000000..e849438f6 --- /dev/null +++ b/tests/Core.Tests/Search/SqliteVectorIndexPersistenceTests.cs @@ -0,0 +1,279 @@ +// Copyright (c) Microsoft. All rights reserved. +using Microsoft.Extensions.Logging; +using Moq; + +namespace KernelMemory.Core.Tests.Search; + +/// +/// Tests for SqliteVectorIndex data persistence across dispose/recreate cycles. +/// Ensures vectors survive database reconnection. +/// +public sealed class SqliteVectorIndexPersistenceTests : IDisposable +{ + private readonly string _dbPath; + private readonly Mock> _mockLogger; + private readonly Mock _mockEmbeddingGenerator; + private const int TestDimensions = 4; + + public SqliteVectorIndexPersistenceTests() + { + // Use temp file for SQLite + this._dbPath = Path.Combine(Path.GetTempPath(), $"vector_persist_test_{Guid.NewGuid()}.db"); + this._mockLogger = new Mock>(); + this._mockEmbeddingGenerator = new Mock(); + + // Configure mock to return predictable embeddings + this._mockEmbeddingGenerator + .Setup(x => x.GenerateAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync((string text, CancellationToken _) => EmbeddingResult.FromVector(this.GenerateTestEmbedding(text))); + } + + public void Dispose() + { + // Clean up temp file + if (File.Exists(this._dbPath)) + { + File.Delete(this._dbPath); + } + + // Clean up WAL files + var walPath = this._dbPath + "-wal"; + var shmPath = this._dbPath + "-shm"; + if (File.Exists(walPath)) + { + File.Delete(walPath); + } + + if (File.Exists(shmPath)) + { + File.Delete(shmPath); + } + + 
GC.SuppressFinalize(this); + } + + /// + /// Generates a deterministic test embedding based on text hash. + /// + /// + private float[] GenerateTestEmbedding(string text) + { + var hash = text.GetHashCode(); + var embedding = new float[TestDimensions]; + for (int i = 0; i < TestDimensions; i++) + { + embedding[i] = ((hash >> (i * 8)) & 0xFF) / 255.0f; + } + + return embedding; + } + + [Fact] + public async Task VectorsPersistAcrossDisposeAndRecreate() + { + // Arrange - Create and populate index + const string contentId = "persist-test-1"; + const string text = "This is persisted content"; + + using (var firstIndex = new SqliteVectorIndex( + this._dbPath, + TestDimensions, + useSqliteVec: false, + this._mockEmbeddingGenerator.Object, + this._mockLogger.Object)) + { + await firstIndex.IndexAsync(contentId, text).ConfigureAwait(false); + + // Verify it exists before dispose + var beforeDispose = await firstIndex.SearchAsync(text).ConfigureAwait(false); + Assert.Single(beforeDispose); + } + + // Act - Create new index pointing to same file + using (var secondIndex = new SqliteVectorIndex( + this._dbPath, + TestDimensions, + useSqliteVec: false, + this._mockEmbeddingGenerator.Object, + this._mockLogger.Object)) + { + // Assert - Data should persist + var afterRecreate = await secondIndex.SearchAsync(text).ConfigureAwait(false); + Assert.Single(afterRecreate); + Assert.Equal(contentId, afterRecreate[0].ContentId); + } + } + + [Fact] + public async Task MultipleVectorsPersistCorrectly() + { + // Arrange + var testData = new Dictionary + { + { "id1", "First document about science" }, + { "id2", "Second document about history" }, + { "id3", "Third document about mathematics" } + }; + + // Create and populate + using (var firstIndex = new SqliteVectorIndex( + this._dbPath, + TestDimensions, + useSqliteVec: false, + this._mockEmbeddingGenerator.Object, + this._mockLogger.Object)) + { + foreach (var (id, content) in testData) + { + await firstIndex.IndexAsync(id, 
content).ConfigureAwait(false); + } + } + + // Act - Recreate and verify + using (var secondIndex = new SqliteVectorIndex( + this._dbPath, + TestDimensions, + useSqliteVec: false, + this._mockEmbeddingGenerator.Object, + this._mockLogger.Object)) + { + var results = await secondIndex.SearchAsync("document").ConfigureAwait(false); + + // Assert + Assert.Equal(3, results.Count); + var contentIds = results.Select(r => r.ContentId).ToHashSet(); + Assert.Contains("id1", contentIds); + Assert.Contains("id2", contentIds); + Assert.Contains("id3", contentIds); + } + } + + [Fact] + public async Task RemovalPersistsCorrectly() + { + // Arrange + const string toKeep = "keep-id"; + const string toRemove = "remove-id"; + + using (var firstIndex = new SqliteVectorIndex( + this._dbPath, + TestDimensions, + useSqliteVec: false, + this._mockEmbeddingGenerator.Object, + this._mockLogger.Object)) + { + await firstIndex.IndexAsync(toKeep, "Content to keep").ConfigureAwait(false); + await firstIndex.IndexAsync(toRemove, "Content to remove").ConfigureAwait(false); + await firstIndex.RemoveAsync(toRemove).ConfigureAwait(false); + } + + // Act - Recreate and verify + using (var secondIndex = new SqliteVectorIndex( + this._dbPath, + TestDimensions, + useSqliteVec: false, + this._mockEmbeddingGenerator.Object, + this._mockLogger.Object)) + { + var results = await secondIndex.SearchAsync("Content").ConfigureAwait(false); + + // Assert + Assert.Single(results); + Assert.Equal(toKeep, results[0].ContentId); + } + } + + [Fact] + public async Task ClearPersistsCorrectly() + { + // Arrange + using (var firstIndex = new SqliteVectorIndex( + this._dbPath, + TestDimensions, + useSqliteVec: false, + this._mockEmbeddingGenerator.Object, + this._mockLogger.Object)) + { + await firstIndex.IndexAsync("id1", "First content").ConfigureAwait(false); + await firstIndex.IndexAsync("id2", "Second content").ConfigureAwait(false); + await firstIndex.ClearAsync().ConfigureAwait(false); + } + + // Act - Recreate 
and verify + using (var secondIndex = new SqliteVectorIndex( + this._dbPath, + TestDimensions, + useSqliteVec: false, + this._mockEmbeddingGenerator.Object, + this._mockLogger.Object)) + { + var results = await secondIndex.SearchAsync("content").ConfigureAwait(false); + + // Assert + Assert.Empty(results); + } + } + + [Fact] + public async Task UpdatePersistsCorrectly() + { + // Arrange + const string contentId = "update-test"; + + using (var firstIndex = new SqliteVectorIndex( + this._dbPath, + TestDimensions, + useSqliteVec: false, + this._mockEmbeddingGenerator.Object, + this._mockLogger.Object)) + { + await firstIndex.IndexAsync(contentId, "Original content").ConfigureAwait(false); + await firstIndex.IndexAsync(contentId, "Updated content").ConfigureAwait(false); + } + + // Act - Recreate and verify + using (var secondIndex = new SqliteVectorIndex( + this._dbPath, + TestDimensions, + useSqliteVec: false, + this._mockEmbeddingGenerator.Object, + this._mockLogger.Object)) + { + var results = await secondIndex.SearchAsync("Updated content").ConfigureAwait(false); + + // Assert - Should only have one entry for this ID + Assert.Single(results); + Assert.Equal(contentId, results[0].ContentId); + } + } + + [Fact] + public async Task CanIndexAfterReopen() + { + // Arrange + using (var firstIndex = new SqliteVectorIndex( + this._dbPath, + TestDimensions, + useSqliteVec: false, + this._mockEmbeddingGenerator.Object, + this._mockLogger.Object)) + { + await firstIndex.IndexAsync("id1", "First content").ConfigureAwait(false); + } + + // Act - Reopen and add more content + using (var secondIndex = new SqliteVectorIndex( + this._dbPath, + TestDimensions, + useSqliteVec: false, + this._mockEmbeddingGenerator.Object, + this._mockLogger.Object)) + { + await secondIndex.IndexAsync("id2", "Second content").ConfigureAwait(false); + var results = await secondIndex.SearchAsync("content").ConfigureAwait(false); + + // Assert + Assert.Equal(2, results.Count); + } + } +} diff --git 
a/tests/Core.Tests/Search/SqliteVectorIndexTests.cs b/tests/Core.Tests/Search/SqliteVectorIndexTests.cs new file mode 100644 index 000000000..97b4bbe53 --- /dev/null +++ b/tests/Core.Tests/Search/SqliteVectorIndexTests.cs @@ -0,0 +1,358 @@ +// Copyright (c) Microsoft. All rights reserved. +using Microsoft.Extensions.Logging; +using Moq; + +namespace KernelMemory.Core.Tests.Search; + +/// +/// Unit tests for SqliteVectorIndex using mock embedding generator. +/// Tests cover indexing, searching, and removal operations with normalized vectors. +/// +public sealed class SqliteVectorIndexTests : IDisposable +{ + private readonly string _dbPath; + private readonly Mock> _mockLogger; + private readonly Mock _mockEmbeddingGenerator; + private readonly SqliteVectorIndex _vectorIndex; + private const int TestDimensions = 4; + + public SqliteVectorIndexTests() + { + // Use temp file for SQLite + this._dbPath = Path.Combine(Path.GetTempPath(), $"vector_test_{Guid.NewGuid()}.db"); + this._mockLogger = new Mock>(); + this._mockEmbeddingGenerator = new Mock(); + + // Configure mock to return predictable embeddings + this._mockEmbeddingGenerator + .Setup(x => x.GenerateAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync((string text, CancellationToken _) => EmbeddingResult.FromVector(GenerateTestEmbedding(text))); + + this._vectorIndex = new SqliteVectorIndex( + this._dbPath, + TestDimensions, + useSqliteVec: false, + this._mockEmbeddingGenerator.Object, + this._mockLogger.Object); + } + + public void Dispose() + { + this._vectorIndex.Dispose(); + + // Clean up temp file + if (File.Exists(this._dbPath)) + { + File.Delete(this._dbPath); + } + + // Clean up WAL files + var walPath = this._dbPath + "-wal"; + var shmPath = this._dbPath + "-shm"; + if (File.Exists(walPath)) + { + File.Delete(walPath); + } + + if (File.Exists(shmPath)) + { + File.Delete(shmPath); + } + + GC.SuppressFinalize(this); + } + + /// + /// Generates a deterministic test embedding based on text hash. 
// NOTE(review): this span reconstructs a flattened git patch. Generic type arguments
// (Assert.ThrowsAsync<T>, It.IsAny<T>, ...) were stripped by the extraction process and
// have been restored from surrounding context (comments, assertion messages). Exception
// types marked with a review note should be confirmed against SqliteVectorIndex.

    /// <summary>
    /// Returns a deterministic pseudo-embedding for the given text, derived from its hash.
    /// Returns consistent embeddings for the same text.
    /// </summary>
    private static float[] GenerateTestEmbedding(string text)
    {
        // Simple deterministic embedding based on text hash
        var hash = text.GetHashCode();
        var embedding = new float[TestDimensions];
        for (int i = 0; i < TestDimensions; i++)
        {
            embedding[i] = ((hash >> (i * 8)) & 0xFF) / 255.0f;
        }

        return embedding;
    }

    [Fact]
    public async Task IndexAsync_IndexesContentSuccessfully()
    {
        // Arrange
        const string contentId = "test-id-1";
        const string text = "This is a test document for vector search.";

        // Act
        await this._vectorIndex.IndexAsync(contentId, text).ConfigureAwait(false);

        // Assert - Search should find it
        var results = await this._vectorIndex.SearchAsync(text).ConfigureAwait(false);
        Assert.Single(results);
        Assert.Equal(contentId, results[0].ContentId);
    }

    [Fact]
    public async Task IndexAsync_ReplacesExistingContent()
    {
        // Arrange
        const string contentId = "test-id-replace";
        await this._vectorIndex.IndexAsync(contentId, "original content about cats").ConfigureAwait(false);

        // Act - Replace with new content
        await this._vectorIndex.IndexAsync(contentId, "updated content about dogs").ConfigureAwait(false);

        // Assert - Search should only find one result for this ID
        var results = await this._vectorIndex.SearchAsync("updated content about dogs").ConfigureAwait(false);
        Assert.Single(results);
        Assert.Equal(contentId, results[0].ContentId);
    }

    [Fact]
    public async Task SearchAsync_ReturnsEmptyForEmptyQuery()
    {
        // Arrange
        await this._vectorIndex.IndexAsync("id1", "some content").ConfigureAwait(false);

        // Act
        var results = await this._vectorIndex.SearchAsync("").ConfigureAwait(false);

        // Assert
        Assert.Empty(results);
    }

    [Fact]
    public async Task SearchAsync_ReturnsMultipleMatches()
    {
        // Arrange
        await this._vectorIndex.IndexAsync("id1", "The quick brown fox jumps").ConfigureAwait(false);
        await this._vectorIndex.IndexAsync("id2", "A quick rabbit runs fast").ConfigureAwait(false);
        await this._vectorIndex.IndexAsync("id3", "Slow turtle walks slowly").ConfigureAwait(false);

        // Act - Search for something
        var results = await this._vectorIndex.SearchAsync("fast animal").ConfigureAwait(false);

        // Assert - Should return all indexed items (brute-force scan returns everything, ranked)
        Assert.Equal(3, results.Count);
    }

    [Fact]
    public async Task SearchAsync_RespectsLimit()
    {
        // Arrange - Create many documents
        for (int i = 0; i < 20; i++)
        {
            await this._vectorIndex.IndexAsync($"id{i}", $"Document number {i} with common word test").ConfigureAwait(false);
        }

        // Act
        var results = await this._vectorIndex.SearchAsync("test", limit: 5).ConfigureAwait(false);

        // Assert
        Assert.Equal(5, results.Count);
    }

    [Fact]
    public async Task SearchAsync_OrdersByScore()
    {
        // Arrange - Create documents
        await this._vectorIndex.IndexAsync("id1", "hello world").ConfigureAwait(false);
        await this._vectorIndex.IndexAsync("id2", "goodbye world").ConfigureAwait(false);

        // Act
        var results = await this._vectorIndex.SearchAsync("hello world").ConfigureAwait(false);

        // Assert - First result should have exact match (highest score)
        Assert.Equal(2, results.Count);
        Assert.Equal("id1", results[0].ContentId);
        Assert.True(results[0].Score >= results[1].Score, "Results should be ordered by score descending");
    }

    [Fact]
    public async Task RemoveAsync_RemovesIndexedContent()
    {
        // Arrange
        const string contentId = "test-remove";
        await this._vectorIndex.IndexAsync(contentId, "content to be removed").ConfigureAwait(false);

        // Verify it exists
        var beforeRemove = await this._vectorIndex.SearchAsync("content to be removed").ConfigureAwait(false);
        Assert.Single(beforeRemove);

        // Act
        await this._vectorIndex.RemoveAsync(contentId).ConfigureAwait(false);

        // Assert
        var afterRemove = await this._vectorIndex.SearchAsync("content to be removed").ConfigureAwait(false);
        Assert.Empty(afterRemove);
    }

    [Fact]
    public async Task RemoveAsync_IsIdempotent()
    {
        // Arrange
        const string contentId = "non-existent-id";

        // Act & Assert - Should not throw for non-existent content;
        // reaching the end of the test without an exception is the assertion.
        await this._vectorIndex.RemoveAsync(contentId).ConfigureAwait(false);
        await this._vectorIndex.RemoveAsync(contentId).ConfigureAwait(false);
    }

    [Fact]
    public async Task ClearAsync_RemovesAllContent()
    {
        // Arrange
        await this._vectorIndex.IndexAsync("id1", "first document").ConfigureAwait(false);
        await this._vectorIndex.IndexAsync("id2", "second document").ConfigureAwait(false);
        await this._vectorIndex.IndexAsync("id3", "third document").ConfigureAwait(false);

        // Verify content exists
        var beforeClear = await this._vectorIndex.SearchAsync("document").ConfigureAwait(false);
        Assert.Equal(3, beforeClear.Count);

        // Act
        await this._vectorIndex.ClearAsync().ConfigureAwait(false);

        // Assert
        var afterClear = await this._vectorIndex.SearchAsync("document").ConfigureAwait(false);
        Assert.Empty(afterClear);
    }

    [Fact]
    public async Task ScoreProperty_IsInValidRange()
    {
        // Arrange
        await this._vectorIndex.IndexAsync("id1", "test content for scoring").ConfigureAwait(false);

        // Act
        var results = await this._vectorIndex.SearchAsync("test").ConfigureAwait(false);

        // Assert - Score should be in valid range for normalized dot product
        Assert.Single(results);
        Assert.True(results[0].Score >= -1.0 && results[0].Score <= 1.0, "Score should be in range [-1, 1] for normalized vectors");
    }

    [Fact]
    public void VectorDimensions_ReturnsConfiguredValue()
    {
        // Assert - property read is synchronous; the original patch declared this
        // test "async Task" without any await (CS1998), fixed here.
        Assert.Equal(TestDimensions, this._vectorIndex.VectorDimensions);
    }

    [Fact]
    public async Task IndexAsync_ValidatesDimensionMismatch()
    {
        // Arrange - Configure mock to return wrong dimensions
        var wrongDimensions = new float[] { 1.0f, 2.0f }; // Only 2 dimensions instead of 4
        this._mockEmbeddingGenerator
            .Setup(x => x.GenerateAsync(It.IsAny<string>(), It.IsAny<CancellationToken>()))
            .ReturnsAsync(EmbeddingResult.FromVector(wrongDimensions));

        // Create new index with mismatch
        var mismatchedIndex = new SqliteVectorIndex(
            Path.Combine(Path.GetTempPath(), $"mismatch_test_{Guid.NewGuid()}.db"),
            TestDimensions,
            useSqliteVec: false,
            this._mockEmbeddingGenerator.Object,
            this._mockLogger.Object);

        try
        {
            // Act & Assert
            // NOTE(review): exception type reconstructed — a generator returning the wrong
            // dimension count is a runtime state error, so InvalidOperationException is
            // assumed here; confirm against SqliteVectorIndex.IndexAsync.
            var ex = await Assert.ThrowsAsync<InvalidOperationException>(
                () => mismatchedIndex.IndexAsync("id1", "test content")).ConfigureAwait(false);
            Assert.Contains("dimensions", ex.Message, StringComparison.OrdinalIgnoreCase);
        }
        finally
        {
            mismatchedIndex.Dispose();
        }
    }

    [Fact]
    public void Constructor_ThrowsForInvalidDimensions()
    {
        // Act & Assert - non-positive dimension counts are rejected up front
        Assert.Throws<ArgumentOutOfRangeException>(() =>
            new SqliteVectorIndex(
                this._dbPath + "_invalid",
                dimensions: 0,
                useSqliteVec: false,
                this._mockEmbeddingGenerator.Object,
                this._mockLogger.Object));

        Assert.Throws<ArgumentOutOfRangeException>(() =>
            new SqliteVectorIndex(
                this._dbPath + "_invalid2",
                dimensions: -1,
                useSqliteVec: false,
                this._mockEmbeddingGenerator.Object,
                this._mockLogger.Object));
    }

    [Fact]
    public void Constructor_ThrowsForNullArguments()
    {
        // Act & Assert - path, embedding generator and logger are all required
        Assert.Throws<ArgumentNullException>(() =>
            new SqliteVectorIndex(
                null!,
                TestDimensions,
                useSqliteVec: false,
                this._mockEmbeddingGenerator.Object,
                this._mockLogger.Object));

        Assert.Throws<ArgumentNullException>(() =>
            new SqliteVectorIndex(
                this._dbPath + "_null",
                TestDimensions,
                useSqliteVec: false,
                null!,
                this._mockLogger.Object));

        Assert.Throws<ArgumentNullException>(() =>
            new SqliteVectorIndex(
                this._dbPath + "_null2",
                TestDimensions,
                useSqliteVec: false,
                this._mockEmbeddingGenerator.Object,
                null!));
    }

    [Fact]
    public async Task IndexAsync_ThrowsForNullContentId()
    {
        // Act & Assert - null throws ArgumentNullException, empty/whitespace throws ArgumentException
        await Assert.ThrowsAsync<ArgumentNullException>(
            () => this._vectorIndex.IndexAsync(null!, "test content")).ConfigureAwait(false);

        await Assert.ThrowsAsync<ArgumentException>(
            () => this._vectorIndex.IndexAsync("", "test content")).ConfigureAwait(false);

        await Assert.ThrowsAsync<ArgumentException>(
            () => this._vectorIndex.IndexAsync("   ", "test content")).ConfigureAwait(false);
    }

    [Fact]
    public async Task IndexAsync_ThrowsForNullText()
    {
        // Act & Assert
        await Assert.ThrowsAsync<ArgumentNullException>(
            () => this._vectorIndex.IndexAsync("id1", null!)).ConfigureAwait(false);
    }

    [Fact]
    public async Task SearchAsync_ThrowsForNullQuery()
    {
        // Act & Assert
        await Assert.ThrowsAsync<ArgumentNullException>(
            () => this._vectorIndex.SearchAsync(null!)).ConfigureAwait(false);
    }
}

// --- patch continues: new file tests/Core.Tests/Search/VectorMathTests.cs (@@ -0,0 +1,271 @@) ---
// Copyright (c) Microsoft. All rights reserved.
namespace KernelMemory.Core.Tests.Search;

/// <summary>
/// Unit tests for VectorMath class: normalization, dot product, and serialization operations.
/// </summary>
// NOTE(review): reconstructed from a flattened git patch; stripped generic type
// arguments (Assert.Throws<T>, Array.Empty<float>) restored from context. Exception
// types for VectorMath argument failures are assumed (ArgumentException for invalid
// values, ArgumentNullException for nulls) — confirm against VectorMath.cs.
public sealed class VectorMathTests
{
    // Absolute tolerance for floating-point comparisons.
    private const double Tolerance = 1e-6;

    [Fact]
    public void NormalizeVector_ProducesMagnitudeOne()
    {
        // Arrange
        var vector = new float[] { 3.0f, 4.0f }; // 3-4-5 triangle

        // Act
        var normalized = VectorMath.NormalizeVector(vector);

        // Assert - Magnitude should be 1
        var magnitude = Math.Sqrt(normalized.Sum(x => x * x));
        Assert.Equal(1.0, magnitude, Tolerance);

        // Check individual components: 3/5 and 4/5
        Assert.Equal(0.6f, normalized[0], (float)Tolerance);
        Assert.Equal(0.8f, normalized[1], (float)Tolerance);
    }

    [Fact]
    public void NormalizeVector_PreservesDirection()
    {
        // Arrange
        var vector = new float[] { 1.0f, 2.0f, 3.0f };

        // Act
        var normalized = VectorMath.NormalizeVector(vector);

        // Assert - Ratios should be preserved (normalization only scales the vector)
        var ratio01 = vector[0] / vector[1];
        var normalizedRatio01 = normalized[0] / normalized[1];
        Assert.Equal(ratio01, normalizedRatio01, Tolerance);

        var ratio12 = vector[1] / vector[2];
        var normalizedRatio12 = normalized[1] / normalized[2];
        Assert.Equal(ratio12, normalizedRatio12, Tolerance);
    }

    [Fact]
    public void NormalizeVector_ThrowsForZeroVector()
    {
        // Arrange
        var zeroVector = new float[] { 0.0f, 0.0f, 0.0f };

        // Act & Assert - zero magnitude cannot be normalized
        Assert.Throws<ArgumentException>(() => VectorMath.NormalizeVector(zeroVector));
    }

    [Fact]
    public void NormalizeVector_ThrowsForEmptyVector()
    {
        // Arrange
        var emptyVector = Array.Empty<float>();

        // Act & Assert
        Assert.Throws<ArgumentException>(() => VectorMath.NormalizeVector(emptyVector));
    }

    [Fact]
    public void NormalizeVector_ThrowsForNullVector()
    {
        // Act & Assert
        Assert.Throws<ArgumentNullException>(() => VectorMath.NormalizeVector(null!));
    }

    [Fact]
    public void NormalizeVector_HandlesNearZeroMagnitude()
    {
        // Arrange - Values so tiny that magnitude underflows to zero
        // float.Epsilon is the smallest positive float that is greater than zero
        // Values below this effectively produce zero magnitude
        var zeroVector = new float[] { 0f, 0f };

        // Act & Assert - Zero vector should throw because magnitude is zero
        Assert.Throws<ArgumentException>(() => VectorMath.NormalizeVector(zeroVector));
    }

    [Fact]
    public void DotProduct_ReturnsOneForIdenticalNormalizedVectors()
    {
        // Arrange
        var vector = new float[] { 3.0f, 4.0f };
        var normalized = VectorMath.NormalizeVector(vector);

        // Act
        var dotProduct = VectorMath.DotProduct(normalized, normalized);

        // Assert - Dot product of identical normalized vectors should be 1
        Assert.Equal(1.0, dotProduct, Tolerance);
    }

    [Fact]
    public void DotProduct_ReturnsZeroForOrthogonalVectors()
    {
        // Arrange - Two orthogonal (perpendicular) normalized vectors
        var v1 = VectorMath.NormalizeVector(new float[] { 1.0f, 0.0f });
        var v2 = VectorMath.NormalizeVector(new float[] { 0.0f, 1.0f });

        // Act
        var dotProduct = VectorMath.DotProduct(v1, v2);

        // Assert
        Assert.Equal(0.0, dotProduct, Tolerance);
    }

    [Fact]
    public void DotProduct_ReturnsNegativeOneForOppositeVectors()
    {
        // Arrange
        var v1 = VectorMath.NormalizeVector(new float[] { 1.0f, 0.0f });
        var v2 = VectorMath.NormalizeVector(new float[] { -1.0f, 0.0f });

        // Act
        var dotProduct = VectorMath.DotProduct(v1, v2);

        // Assert
        Assert.Equal(-1.0, dotProduct, Tolerance);
    }

    [Fact]
    public void DotProduct_ThrowsForDifferentLengths()
    {
        // Arrange
        var v1 = new float[] { 1.0f, 2.0f };
        var v2 = new float[] { 1.0f, 2.0f, 3.0f };

        // Act & Assert - dimension mismatch is an invalid argument
        Assert.Throws<ArgumentException>(() => VectorMath.DotProduct(v1, v2));
    }

    [Fact]
    public void DotProduct_ThrowsForNullVectors()
    {
        // Arrange
        var vector = new float[] { 1.0f, 2.0f };

        // Act & Assert
        Assert.Throws<ArgumentNullException>(() => VectorMath.DotProduct(null!, vector));
        Assert.Throws<ArgumentNullException>(() => VectorMath.DotProduct(vector, null!));
    }

    [Fact]
    public void DotProduct_EqualsCosineSimilarityForNormalizedVectors()
    {
        // Arrange - Two vectors at 60 degrees (cosine = 0.5)
        var v1 = VectorMath.NormalizeVector(new float[] { 1.0f, 0.0f });
        var v2 = VectorMath.NormalizeVector(new float[] { 0.5f, (float)Math.Sqrt(0.75) }); // 60 degrees

        // Act
        var dotProduct = VectorMath.DotProduct(v1, v2);

        // Assert - Dot product should equal cosine of angle
        Assert.Equal(0.5, dotProduct, Tolerance);
    }

    [Fact]
    public void VectorToBlob_And_BlobToVector_RoundTrip()
    {
        // Arrange
        var original = new float[] { 1.5f, -2.5f, 3.14159f, 0.0f };

        // Act
        var blob = VectorMath.VectorToBlob(original);
        var restored = VectorMath.BlobToVector(blob);

        // Assert - exact bit-for-bit round trip
        Assert.Equal(original.Length, restored.Length);
        for (int i = 0; i < original.Length; i++)
        {
            Assert.Equal(original[i], restored[i]);
        }
    }

    [Fact]
    public void VectorToBlob_ProducesCorrectSize()
    {
        // Arrange
        var vector = new float[] { 1.0f, 2.0f, 3.0f, 4.0f };

        // Act
        var blob = VectorMath.VectorToBlob(vector);

        // Assert - Should be 4 bytes per float
        Assert.Equal(vector.Length * sizeof(float), blob.Length);
    }

    [Fact]
    public void BlobToVector_ThrowsForInvalidBlobLength()
    {
        // Arrange - Blob length not divisible by sizeof(float)
        var invalidBlob = new byte[] { 0, 1, 2 };

        // Act & Assert
        Assert.Throws<ArgumentException>(() => VectorMath.BlobToVector(invalidBlob));
    }

    [Fact]
    public void VectorToBlob_ThrowsForNullVector()
    {
        // Act & Assert
        Assert.Throws<ArgumentNullException>(() => VectorMath.VectorToBlob(null!));
    }

    [Fact]
    public void BlobToVector_ThrowsForNullBlob()
    {
        // Act & Assert
        Assert.Throws<ArgumentNullException>(() => VectorMath.BlobToVector(null!));
    }

    [Fact]
    public void VectorToBlob_HandlesEmptyVector()
    {
        // Arrange
        var emptyVector = Array.Empty<float>();

        // Act
        var blob = VectorMath.VectorToBlob(emptyVector);
        var restored = VectorMath.BlobToVector(blob);

        // Assert
        Assert.Empty(blob);
        Assert.Empty(restored);
    }

    [Fact]
    public void VectorSerialization_HandlesSpecialFloatValues()
    {
        // Arrange
        var specialValues = new float[] { float.MaxValue, float.MinValue, float.Epsilon, -float.Epsilon };

        // Act
        var blob = VectorMath.VectorToBlob(specialValues);
        var restored = VectorMath.BlobToVector(blob);

        // Assert
        for (int i = 0; i < specialValues.Length; i++)
        {
            Assert.Equal(specialValues[i], restored[i]);
        }
    }

    [Fact]
    public void NormalizeVector_HandlesHighDimensionalVectors()
    {
        // Arrange - 1024 dimensions (common for embedding models)
        // Using deterministic values instead of Random for reproducibility and security compliance
        var highDimVector = new float[1024];
        for (int i = 0; i < highDimVector.Length; i++)
        {
            // Deterministic pattern: sine wave values between -1 and 1
            highDimVector[i] = (float)Math.Sin(i * 0.123);
        }

        // Act
        var normalized = VectorMath.NormalizeVector(highDimVector);

        // Assert - Magnitude should be 1 (accumulate in double to limit rounding error)
        var magnitude = Math.Sqrt(normalized.Sum(x => x * (double)x));
        Assert.Equal(1.0, magnitude, Tolerance);
    }
}

// --- patch continues: tests/Main.Tests/GlobalUsings.cs (@@ -1,3 +1,4 @@) ---
// Copyright (c) Microsoft. All rights reserved.
+global using KernelMemory.Core; global using Xunit; diff --git a/tests/Main.Tests/Integration/CliIntegrationTests.cs b/tests/Main.Tests/Integration/CliIntegrationTests.cs index 6b063c329..e36c75aa0 100644 --- a/tests/Main.Tests/Integration/CliIntegrationTests.cs +++ b/tests/Main.Tests/Integration/CliIntegrationTests.cs @@ -78,7 +78,7 @@ public async Task UpsertCommand_WithMinimalOptions_CreatesContent() var exitCode = await command.ExecuteAsync(context, settings, CancellationToken.None).ConfigureAwait(false); // Assert - Assert.Equal(Constants.ExitCodeSuccess, exitCode); + Assert.Equal(Constants.App.ExitCodeSuccess, exitCode); } [Fact] @@ -103,7 +103,7 @@ public async Task UpsertCommand_WithCustomId_UsesProvidedId() var exitCode = await command.ExecuteAsync(context, settings, CancellationToken.None).ConfigureAwait(false); // Assert - Assert.Equal(Constants.ExitCodeSuccess, exitCode); + Assert.Equal(Constants.App.ExitCodeSuccess, exitCode); // Verify content exists with custom ID var getSettings = new GetCommandSettings @@ -115,7 +115,7 @@ public async Task UpsertCommand_WithCustomId_UsesProvidedId() var getCommand = new GetCommand(config, NullLoggerFactory.Instance); var getExitCode = await getCommand.ExecuteAsync(context, getSettings, CancellationToken.None).ConfigureAwait(false); - Assert.Equal(Constants.ExitCodeSuccess, getExitCode); + Assert.Equal(Constants.App.ExitCodeSuccess, getExitCode); } [Fact] @@ -141,7 +141,7 @@ public async Task UpsertCommand_WithAllMetadata_StoresAllFields() var exitCode = await command.ExecuteAsync(context, settings, CancellationToken.None).ConfigureAwait(false); // Assert - Assert.Equal(Constants.ExitCodeSuccess, exitCode); + Assert.Equal(Constants.App.ExitCodeSuccess, exitCode); } [Fact] @@ -174,7 +174,7 @@ public async Task GetCommand_ExistingId_ReturnsContent() var exitCode = await getCommand.ExecuteAsync(context, getSettings, CancellationToken.None).ConfigureAwait(false); // Assert - Assert.Equal(Constants.ExitCodeSuccess, 
exitCode); + Assert.Equal(Constants.App.ExitCodeSuccess, exitCode); } [Fact] @@ -208,7 +208,7 @@ public async Task GetCommand_NonExistentId_ReturnsUserError() var exitCode = await command.ExecuteAsync(context, settings, CancellationToken.None).ConfigureAwait(false); // Assert - ID not found in existing DB is user error - Assert.Equal(Constants.ExitCodeUserError, exitCode); + Assert.Equal(Constants.App.ExitCodeUserError, exitCode); } [Fact] @@ -243,7 +243,7 @@ public async Task GetCommand_WithFullFlag_ReturnsAllDetails() var exitCode = await getCommand.ExecuteAsync(context, getSettings, CancellationToken.None).ConfigureAwait(false); // Assert - Assert.Equal(Constants.ExitCodeSuccess, exitCode); + Assert.Equal(Constants.App.ExitCodeSuccess, exitCode); } [Fact] @@ -286,7 +286,7 @@ public async Task ListCommand_EmptyDatabase_ReturnsEmptyList() var exitCode = await command.ExecuteAsync(listContext, settings, CancellationToken.None).ConfigureAwait(false); // Assert - Assert.Equal(Constants.ExitCodeSuccess, exitCode); + Assert.Equal(Constants.App.ExitCodeSuccess, exitCode); } [Fact] @@ -333,7 +333,7 @@ public async Task Bug3_ListCommand_EmptyDatabase_HumanFormat_ShouldHandleGracefu var exitCode = await command.ExecuteAsync(listContext, settings, CancellationToken.None).ConfigureAwait(false); // Assert - Assert.Equal(Constants.ExitCodeSuccess, exitCode); + Assert.Equal(Constants.App.ExitCodeSuccess, exitCode); // TODO: Capture stdout and verify it doesn't show an empty table // Expected: A message like "No content found" instead of empty table } @@ -365,7 +365,7 @@ public async Task ListCommand_WithContent_ReturnsList() var exitCode = await listCommand.ExecuteAsync(context, listSettings, CancellationToken.None).ConfigureAwait(false); // Assert - Assert.Equal(Constants.ExitCodeSuccess, exitCode); + Assert.Equal(Constants.App.ExitCodeSuccess, exitCode); } [Fact] @@ -400,7 +400,7 @@ public async Task ListCommand_WithPagination_RespectsSkipAndTake() var exitCode = await 
listCommand.ExecuteAsync(context, listSettings, CancellationToken.None).ConfigureAwait(false); // Assert - Assert.Equal(Constants.ExitCodeSuccess, exitCode); + Assert.Equal(Constants.App.ExitCodeSuccess, exitCode); } [Fact] @@ -433,7 +433,7 @@ public async Task DeleteCommand_ExistingId_DeletesSuccessfully() var exitCode = await deleteCommand.ExecuteAsync(context, deleteSettings, CancellationToken.None).ConfigureAwait(false); // Assert - Assert.Equal(Constants.ExitCodeSuccess, exitCode); + Assert.Equal(Constants.App.ExitCodeSuccess, exitCode); // Verify content is gone var getSettings = new GetCommandSettings @@ -445,7 +445,7 @@ public async Task DeleteCommand_ExistingId_DeletesSuccessfully() var getCommand = new GetCommand(config, NullLoggerFactory.Instance); var getExitCode = await getCommand.ExecuteAsync(context, getSettings, CancellationToken.None).ConfigureAwait(false); - Assert.Equal(Constants.ExitCodeUserError, getExitCode); + Assert.Equal(Constants.App.ExitCodeUserError, getExitCode); } [Fact] @@ -479,7 +479,7 @@ public async Task DeleteCommand_WithQuietVerbosity_SucceedsWithMinimalOutput() var exitCode = await deleteCommand.ExecuteAsync(context, deleteSettings, CancellationToken.None).ConfigureAwait(false); // Assert - Assert.Equal(Constants.ExitCodeSuccess, exitCode); + Assert.Equal(Constants.App.ExitCodeSuccess, exitCode); } [Fact] @@ -501,7 +501,7 @@ public async Task EndToEndWorkflow_UpsertGetListDelete_AllSucceed() }; var upsertCommand = new UpsertCommand(config, NullLoggerFactory.Instance); var upsertExitCode = await upsertCommand.ExecuteAsync(context, upsertSettings, CancellationToken.None).ConfigureAwait(false); - Assert.Equal(Constants.ExitCodeSuccess, upsertExitCode); + Assert.Equal(Constants.App.ExitCodeSuccess, upsertExitCode); // 2. 
Get var getSettings = new GetCommandSettings @@ -512,7 +512,7 @@ public async Task EndToEndWorkflow_UpsertGetListDelete_AllSucceed() }; var getCommand = new GetCommand(config, NullLoggerFactory.Instance); var getExitCode = await getCommand.ExecuteAsync(context, getSettings, CancellationToken.None).ConfigureAwait(false); - Assert.Equal(Constants.ExitCodeSuccess, getExitCode); + Assert.Equal(Constants.App.ExitCodeSuccess, getExitCode); // 3. List var listSettings = new ListCommandSettings @@ -522,7 +522,7 @@ public async Task EndToEndWorkflow_UpsertGetListDelete_AllSucceed() }; var listCommand = new ListCommand(config, NullLoggerFactory.Instance); var listExitCode = await listCommand.ExecuteAsync(context, listSettings, CancellationToken.None).ConfigureAwait(false); - Assert.Equal(Constants.ExitCodeSuccess, listExitCode); + Assert.Equal(Constants.App.ExitCodeSuccess, listExitCode); // 4. Delete var deleteSettings = new DeleteCommandSettings @@ -533,11 +533,11 @@ public async Task EndToEndWorkflow_UpsertGetListDelete_AllSucceed() }; var deleteCommand = new DeleteCommand(config, NullLoggerFactory.Instance); var deleteExitCode = await deleteCommand.ExecuteAsync(context, deleteSettings, CancellationToken.None).ConfigureAwait(false); - Assert.Equal(Constants.ExitCodeSuccess, deleteExitCode); + Assert.Equal(Constants.App.ExitCodeSuccess, deleteExitCode); // 5. 
Verify deleted var verifyExitCode = await getCommand.ExecuteAsync(context, getSettings, CancellationToken.None).ConfigureAwait(false); - Assert.Equal(Constants.ExitCodeUserError, verifyExitCode); + Assert.Equal(Constants.App.ExitCodeUserError, verifyExitCode); } [Fact] @@ -558,7 +558,7 @@ public async Task NodesCommand_WithJsonFormat_ListsAllNodes() var exitCode = await command.ExecuteAsync(context, settings, CancellationToken.None).ConfigureAwait(false); // Assert - Assert.Equal(Constants.ExitCodeSuccess, exitCode); + Assert.Equal(Constants.App.ExitCodeSuccess, exitCode); } [Fact] @@ -579,7 +579,7 @@ public async Task NodesCommand_WithYamlFormat_ListsAllNodes() var exitCode = await command.ExecuteAsync(context, settings, CancellationToken.None).ConfigureAwait(false); // Assert - Assert.Equal(Constants.ExitCodeSuccess, exitCode); + Assert.Equal(Constants.App.ExitCodeSuccess, exitCode); } [Fact] @@ -598,10 +598,10 @@ public async Task ConfigCommand_Default_ShowsCurrentNode() var context = CreateTestContext("config"); // Act - var exitCode = await command.ExecuteAsync(context, settings).ConfigureAwait(false); + var exitCode = await command.ExecuteAsync(context, settings, CancellationToken.None).ConfigureAwait(false); // Assert - Assert.Equal(Constants.ExitCodeSuccess, exitCode); + Assert.Equal(Constants.App.ExitCodeSuccess, exitCode); } [Fact] @@ -621,10 +621,10 @@ public async Task ConfigCommand_WithShowNodes_ShowsAllNodes() var context = CreateTestContext("config"); // Act - var exitCode = await command.ExecuteAsync(context, settings).ConfigureAwait(false); + var exitCode = await command.ExecuteAsync(context, settings, CancellationToken.None).ConfigureAwait(false); // Assert - Assert.Equal(Constants.ExitCodeSuccess, exitCode); + Assert.Equal(Constants.App.ExitCodeSuccess, exitCode); } [Fact] @@ -644,10 +644,10 @@ public async Task ConfigCommand_WithShowCache_ShowsCacheConfig() var context = CreateTestContext("config"); // Act - var exitCode = await 
command.ExecuteAsync(context, settings).ConfigureAwait(false); + var exitCode = await command.ExecuteAsync(context, settings, CancellationToken.None).ConfigureAwait(false); // Assert - Assert.Equal(Constants.ExitCodeSuccess, exitCode); + Assert.Equal(Constants.App.ExitCodeSuccess, exitCode); } [Fact] @@ -697,10 +697,10 @@ public async Task Bug2_ConfigCommand_HumanFormat_ShouldNotLeakTypeNames() var context = CreateTestContext("config"); // Act - var exitCode = await command.ExecuteAsync(context, settings).ConfigureAwait(false); + var exitCode = await command.ExecuteAsync(context, settings, CancellationToken.None).ConfigureAwait(false); // Assert - Assert.Equal(Constants.ExitCodeSuccess, exitCode); + Assert.Equal(Constants.App.ExitCodeSuccess, exitCode); // Note: AnsiConsole output cannot be easily captured in tests. // The fix ensures that HumanOutputFormatter.Format() handles DTO objects diff --git a/tests/Main.Tests/Integration/CommandExecutionTests.cs b/tests/Main.Tests/Integration/CommandExecutionTests.cs index fd0cdc760..8bc3870a7 100644 --- a/tests/Main.Tests/Integration/CommandExecutionTests.cs +++ b/tests/Main.Tests/Integration/CommandExecutionTests.cs @@ -187,7 +187,7 @@ public async Task ConfigCommand_WithoutFlags_ReturnsSuccess() var command = new ConfigCommand(config, NullLoggerFactory.Instance, configPathService); var context = new CommandContext(new[] { "--config", this._configPath }, new EmptyRemainingArguments(), "config", null); - var result = await command.ExecuteAsync(context, settings).ConfigureAwait(false); + var result = await command.ExecuteAsync(context, settings, CancellationToken.None).ConfigureAwait(false); Assert.Equal(0, result); } @@ -206,7 +206,7 @@ public async Task ConfigCommand_WithShowNodes_ReturnsSuccess() var command = new ConfigCommand(config, NullLoggerFactory.Instance, configPathService); var context = new CommandContext(new[] { "--config", this._configPath }, new EmptyRemainingArguments(), "config", null); - var result = 
await command.ExecuteAsync(context, settings).ConfigureAwait(false); + var result = await command.ExecuteAsync(context, settings, CancellationToken.None).ConfigureAwait(false); Assert.Equal(0, result); } @@ -225,7 +225,7 @@ public async Task ConfigCommand_WithShowCache_ReturnsSuccess() var command = new ConfigCommand(config, NullLoggerFactory.Instance, configPathService); var context = new CommandContext(new[] { "--config", this._configPath }, new EmptyRemainingArguments(), "config", null); - var result = await command.ExecuteAsync(context, settings).ConfigureAwait(false); + var result = await command.ExecuteAsync(context, settings, CancellationToken.None).ConfigureAwait(false); Assert.Equal(0, result); } diff --git a/tests/Main.Tests/Integration/ConfigCommandTests.cs b/tests/Main.Tests/Integration/ConfigCommandTests.cs index b1ac3f461..b954ed0a1 100644 --- a/tests/Main.Tests/Integration/ConfigCommandTests.cs +++ b/tests/Main.Tests/Integration/ConfigCommandTests.cs @@ -98,10 +98,10 @@ public void ConfigCommand_WithoutFlags_ShouldShowEntireConfiguration() try { // Act - var exitCode = command.ExecuteAsync(context, settings).GetAwaiter().GetResult(); + var exitCode = command.ExecuteAsync(context, settings, CancellationToken.None).GetAwaiter().GetResult(); // Assert - Assert.Equal(Constants.ExitCodeSuccess, exitCode); + Assert.Equal(Constants.App.ExitCodeSuccess, exitCode); var output = outputCapture.ToString(); @@ -151,10 +151,10 @@ public void ConfigCommand_OutputStructure_ShouldMatchAppConfigFormat() try { // Act - var exitCode = command.ExecuteAsync(context, settings).GetAwaiter().GetResult(); + var exitCode = command.ExecuteAsync(context, settings, CancellationToken.None).GetAwaiter().GetResult(); // Assert - Assert.Equal(Constants.ExitCodeSuccess, exitCode); + Assert.Equal(Constants.App.ExitCodeSuccess, exitCode); var output = outputCapture.ToString(); var outputJson = System.Text.Json.JsonDocument.Parse(output); @@ -209,10 +209,10 @@ public void 
ConfigCommand_WithShowNodesFlag_ShouldShowAllNodesSummary() try { // Act - var exitCode = command.ExecuteAsync(context, settings).GetAwaiter().GetResult(); + var exitCode = command.ExecuteAsync(context, settings, CancellationToken.None).GetAwaiter().GetResult(); // Assert - Assert.Equal(Constants.ExitCodeSuccess, exitCode); + Assert.Equal(Constants.App.ExitCodeSuccess, exitCode); var output = outputCapture.ToString(); @@ -252,10 +252,10 @@ public void ConfigCommand_WithCreate_CreatesConfigFile() null); // Act - var exitCode = command.ExecuteAsync(context, settings).GetAwaiter().GetResult(); + var exitCode = command.ExecuteAsync(context, settings, CancellationToken.None).GetAwaiter().GetResult(); // Assert - Assert.Equal(Constants.ExitCodeSuccess, exitCode); + Assert.Equal(Constants.App.ExitCodeSuccess, exitCode); Assert.True(File.Exists(newConfigPath), "Config file should be created"); // Verify the file content is valid JSON @@ -298,10 +298,10 @@ public void ConfigCommand_WithCreate_WhenFileExists_ReturnsError() try { // Act - var exitCode = command.ExecuteAsync(context, settings).GetAwaiter().GetResult(); + var exitCode = command.ExecuteAsync(context, settings, CancellationToken.None).GetAwaiter().GetResult(); // Assert - Assert.Equal(Constants.ExitCodeUserError, exitCode); + Assert.Equal(Constants.App.ExitCodeUserError, exitCode); // Error message goes to Console.Error var errorOutput = errorCapture.ToString(); @@ -346,11 +346,11 @@ public void ConfigCommand_WithoutConfigFile_StillSucceeds() try { // Act - var exitCode = command.ExecuteAsync(context, settings).GetAwaiter().GetResult(); + var exitCode = command.ExecuteAsync(context, settings, CancellationToken.None).GetAwaiter().GetResult(); // Assert // The key behavior: command succeeds even without config file - Assert.Equal(Constants.ExitCodeSuccess, exitCode); + Assert.Equal(Constants.App.ExitCodeSuccess, exitCode); // Should still output valid JSON config var output = outputCapture.ToString(); @@ -392,10 
+392,10 @@ public void ConfigCommand_OutputJson_DoesNotContainNullFields() try { // Act - var exitCode = command.ExecuteAsync(context, settings).GetAwaiter().GetResult(); + var exitCode = command.ExecuteAsync(context, settings, CancellationToken.None).GetAwaiter().GetResult(); // Assert - Assert.Equal(Constants.ExitCodeSuccess, exitCode); + Assert.Equal(Constants.App.ExitCodeSuccess, exitCode); var output = outputCapture.ToString(); @@ -440,10 +440,10 @@ public void ConfigCommand_OutputJson_ContainsCorrectDiscriminators() try { // Act - var exitCode = command.ExecuteAsync(context, settings).GetAwaiter().GetResult(); + var exitCode = command.ExecuteAsync(context, settings, CancellationToken.None).GetAwaiter().GetResult(); // Assert - Assert.Equal(Constants.ExitCodeSuccess, exitCode); + Assert.Equal(Constants.App.ExitCodeSuccess, exitCode); var output = outputCapture.ToString(); var outputJson = System.Text.Json.JsonDocument.Parse(output); diff --git a/tests/Main.Tests/Integration/DefaultConfigVectorIndexTests.cs b/tests/Main.Tests/Integration/DefaultConfigVectorIndexTests.cs new file mode 100644 index 000000000..bf503d71c --- /dev/null +++ b/tests/Main.Tests/Integration/DefaultConfigVectorIndexTests.cs @@ -0,0 +1,266 @@ +// Copyright (c) Microsoft. All rights reserved. + +using KernelMemory.Core.Config; +using KernelMemory.Core.Config.ContentIndex; +using KernelMemory.Core.Config.Embeddings; +using KernelMemory.Core.Config.Enums; +using KernelMemory.Core.Config.SearchIndex; +using KernelMemory.Core.Search; +using KernelMemory.Core.Storage; +using KernelMemory.Core.Storage.Models; +using KernelMemory.Main.Services; +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.Logging; + +namespace KernelMemory.Main.Tests.Integration; + +/// +/// Integration tests verifying that the default configuration includes +/// vector search indexes and that all configured indexes are properly created +/// and used during ingestion operations. 
/// </summary>
// NOTE(review): reconstructed from a flattened git patch; stripped generic type
// arguments restored from context. Type names marked TODO should be confirmed
// against the repository (base config type, index interface types).
public sealed class DefaultConfigVectorIndexTests
{
    /// <summary>
    /// Verifies that the default configuration includes both FTS and vector search indexes
    /// as specified in Feature 00001.
    /// This is a regression test to catch if vector indexes are accidentally removed from defaults.
    /// </summary>
    [Fact]
    public void DefaultConfig_ShouldIncludeBothFtsAndVectorIndexes()
    {
        // Arrange & Act
        var config = AppConfig.CreateDefault("/tmp/test");

        // Assert
        var personalNode = config.Nodes["personal"];
        Assert.NotNull(personalNode);
        Assert.Equal(2, personalNode.SearchIndexes.Count);

        // Verify FTS index exists
        var ftsIndex = personalNode.SearchIndexes.FirstOrDefault(i => i is FtsSearchIndexConfig);
        Assert.NotNull(ftsIndex);
        Assert.Equal("sqlite-fts", ftsIndex.Id);
        Assert.True(ftsIndex.Required); // FTS should be required

        // Verify Vector index exists
        var vectorIndex = personalNode.SearchIndexes.FirstOrDefault(i => i is VectorSearchIndexConfig) as VectorSearchIndexConfig;
        Assert.NotNull(vectorIndex);
        Assert.Equal("sqlite-vector", vectorIndex.Id);
        Assert.False(vectorIndex.Required); // Vector should be optional (Ollama may not be running)
        Assert.Equal(1024, vectorIndex.Dimensions);
        Assert.False(vectorIndex.UseSqliteVec);

        // Verify vector index has Ollama embeddings configured
        Assert.NotNull(vectorIndex.Embeddings);
        Assert.IsType<OllamaEmbeddingsConfig>(vectorIndex.Embeddings);
        var ollamaConfig = (OllamaEmbeddingsConfig)vectorIndex.Embeddings;
        Assert.Equal("qwen3-embedding:0.6b", ollamaConfig.Model);
        Assert.Equal("http://localhost:11434", ollamaConfig.BaseUrl);
    }

    /// <summary>
    /// Verifies that the default configuration includes embeddings cache
    /// as specified in Feature 00001.
    /// </summary>
    [Fact]
    public void DefaultConfig_ShouldIncludeEmbeddingsCache()
    {
        // Arrange & Act
        var config = AppConfig.CreateDefault("/tmp/test");

        // Assert
        Assert.NotNull(config.EmbeddingsCache);
        Assert.Equal("/tmp/test/embeddings-cache.db", config.EmbeddingsCache.Path);
        Assert.True(config.EmbeddingsCache.AllowRead);
        Assert.True(config.EmbeddingsCache.AllowWrite);
    }

    /// <summary>
    /// Verifies that when a node has multiple search indexes (20+ mixed types),
    /// ALL indexes are created and registered for use during ingestion.
    /// This is a critical regression test ensuring no indexes are skipped.
    /// </summary>
    [Fact]
    public void CreateIndexes_WithManyMixedIndexTypes_ShouldCreateAllIndexes()
    {
        // Arrange - Create config with 20 indexes (mix of FTS and Vector)
        // TODO(review): base element type assumed to be SearchIndexConfig — confirm.
        var configs = new List<SearchIndexConfig>();

        // Add 10 FTS indexes
        for (int i = 0; i < 10; i++)
        {
            configs.Add(new FtsSearchIndexConfig
            {
                Id = $"fts-{i}",
                Type = SearchIndexTypes.SqliteFTS,
                Path = $"/tmp/test/fts-{i}.db",
                EnableStemming = i % 2 == 0 // Alternate stemming
            });
        }

        // Add 10 Vector indexes with different dimensions
        var dimensions = new[] { 384, 768, 1024, 1536, 3072 };
        for (int i = 0; i < 10; i++)
        {
            configs.Add(new VectorSearchIndexConfig
            {
                Id = $"vector-{i}",
                Type = SearchIndexTypes.SqliteVector,
                Path = $"/tmp/test/vector-{i}.db",
                Dimensions = dimensions[i % dimensions.Length],
                UseSqliteVec = false,
                Embeddings = new OllamaEmbeddingsConfig
                {
                    Model = "qwen3-embedding",
                    BaseUrl = "http://localhost:11434"
                }
            });
        }

        // Act
        using var httpClient = new HttpClient();
        using var loggerFactory = LoggerFactory.Create(builder => builder.AddConsole());
        var indexes = SearchIndexFactory.CreateIndexes(configs, httpClient, null, loggerFactory);

        // Assert - ALL 20 indexes should be created
        Assert.Equal(20, indexes.Count);

        // Verify all FTS indexes present
        // TODO(review): asserted runtime types reconstructed — confirm SqliteFtsIndex / IVectorIndex.
        for (int i = 0; i < 10; i++)
        {
            Assert.Contains($"fts-{i}", indexes.Keys);
            Assert.IsAssignableFrom<SqliteFtsIndex>(indexes[$"fts-{i}"]);
        }

        // Verify all Vector indexes present
        for (int i = 0; i < 10; i++)
        {
            Assert.Contains($"vector-{i}", indexes.Keys);
            Assert.IsAssignableFrom<IVectorIndex>(indexes[$"vector-{i}"]);
        }

        // Cleanup
        foreach (var index in indexes.Values.OfType<IDisposable>())
        {
            index.Dispose();
        }
    }

    /// <summary>
    /// Verifies that upsert operations create steps for ALL configured indexes.
    /// This ensures the ingestion pipeline will update every index during km put.
    /// </summary>
    [Fact]
    public async Task UpsertOperation_WithMultipleIndexes_ShouldCreateStepsForAllIndexes()
    {
        // Arrange
        var tempDir = Path.Combine(Path.GetTempPath(), $"km-test-{Guid.NewGuid()}");
        var nodeDir = Path.Combine(tempDir, "nodes", "multi");
        Directory.CreateDirectory(nodeDir);

        try
        {
            // Create config with 3 search indexes
            var config = new AppConfig
            {
                Nodes = new Dictionary<string, NodeConfig>
                {
                    ["multi"] = new NodeConfig
                    {
                        Id = "multi",
                        Access = NodeAccessLevels.Full,
                        ContentIndex = new SqliteContentIndexConfig
                        {
                            Path = Path.Combine(nodeDir, "content.db")
                        },
                        SearchIndexes = new List<SearchIndexConfig>
                        {
                            new FtsSearchIndexConfig
                            {
                                Id = "fts-1",
                                Type = SearchIndexTypes.SqliteFTS,
                                Path = Path.Combine(nodeDir, "fts-1.db"),
                                EnableStemming = true
                            },
                            new FtsSearchIndexConfig
                            {
                                Id = "fts-2",
                                Type = SearchIndexTypes.SqliteFTS,
                                Path = Path.Combine(nodeDir, "fts-2.db"),
                                EnableStemming = false
                            },
                            new VectorSearchIndexConfig
                            {
                                Id = "vector-1",
                                Type = SearchIndexTypes.SqliteVector,
                                Path = Path.Combine(nodeDir, "vector-1.db"),
                                Dimensions = 1024,
                                UseSqliteVec = false,
                                Embeddings = new OllamaEmbeddingsConfig
                                {
                                    Model = "qwen3-embedding",
                                    BaseUrl = "http://localhost:11434"
                                }
                            }
                        }
                    }
                }
            };

            // Create content storage service with all indexes.
            // "using var" replaces the original explicit context.Dispose() so the
            // DbContext is released even when an assertion fails, before the finally
            // block deletes the directory.
            var connectionString = "Data Source=" + Path.Combine(nodeDir, "content.db");
            var optionsBuilder = new DbContextOptionsBuilder<ContentStorageDbContext>();
            optionsBuilder.UseSqlite(connectionString);
            using var context = new ContentStorageDbContext(optionsBuilder.Options);
            context.Database.EnsureCreated();

            var cuidGenerator = new CuidGenerator();
            using var loggerFactory = LoggerFactory.Create(builder => builder.AddConsole().SetMinimumLevel(LogLevel.Debug));
            var storageLogger = loggerFactory.CreateLogger<ContentStorageService>();

            using var httpClient = new HttpClient();
            var indexes = SearchIndexFactory.CreateIndexes(
                config.Nodes["multi"].SearchIndexes,
                httpClient,
                embeddingCache: null,
                loggerFactory);

            var storage = new ContentStorageService(context, cuidGenerator, storageLogger, indexes);

            // Act - Queue an upsert operation
            var request = new UpsertRequest
            {
                Content = "Test content for multi-index ingestion",
                MimeType = "text/plain"
            };

            var result = await storage.UpsertAsync(request, CancellationToken.None).ConfigureAwait(false);

            // Assert - Verify operation was queued with steps for ALL 3 indexes
            var operation = await context.Operations
                .FirstOrDefaultAsync(o => o.ContentId == result.Id).ConfigureAwait(false);

            Assert.NotNull(operation);
            Assert.Contains("upsert", operation.PlannedSteps);
            Assert.Contains("index:fts-1", operation.PlannedSteps);
            Assert.Contains("index:fts-2", operation.PlannedSteps);
            Assert.Contains("index:vector-1", operation.PlannedSteps);
            Assert.Equal(4, operation.PlannedSteps.Length); // upsert + 3 index steps

            // Cleanup
            foreach (var index in indexes.Values.OfType<IDisposable>())
            {
                index.Dispose();
            }
        }
        finally
        {
            if (Directory.Exists(tempDir))
            {
                Directory.Delete(tempDir, recursive: true);
            }
        }
    }
}

// --- patch continues: new file tests/Main.Tests/Integration/NodeSelectionTests.cs (@@ -0,0 +1,389 @@) ---
// Copyright (c) Microsoft. All rights reserved.
+ +using System.Text.Json; +using KernelMemory.Core.Config; +using KernelMemory.Core.Config.ContentIndex; +using KernelMemory.Core.Config.Enums; +using KernelMemory.Core.Config.SearchIndex; +using KernelMemory.Main.CLI.Commands; +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.Logging.Abstractions; +using Spectre.Console.Cli; + +namespace KernelMemory.Main.Tests.Integration; + +/// +/// Integration tests for node selection and broken node handling. +/// Issue 00006: km search should skip broken nodes gracefully and not crash. +/// +public sealed class NodeSelectionTests : IDisposable +{ + private static readonly JsonSerializerOptions s_jsonOptions = new() { WriteIndented = true }; + private readonly string _tempDir; + private readonly string _configPath; + private readonly string _goodNodeDbPath; + private readonly string _brokenNodeDbPath; + + public NodeSelectionTests() + { + // Create temp directory for test config + this._tempDir = Path.Combine(Path.GetTempPath(), $"km-node-selection-test-{Guid.NewGuid()}"); + Directory.CreateDirectory(this._tempDir); + + // Good node: will have a valid database + this._goodNodeDbPath = Path.Combine(this._tempDir, "nodes", "good-node", "content.db"); + + // Broken node: will NOT have a database (simulating missing database) + this._brokenNodeDbPath = Path.Combine(this._tempDir, "nodes", "broken-node", "content.db"); + + this._configPath = Path.Combine(this._tempDir, "config.json"); + } + + public void Dispose() + { + if (Directory.Exists(this._tempDir)) + { + Directory.Delete(this._tempDir, recursive: true); + } + } + + /// + /// Creates a test config with both good and broken nodes. + /// The good node has its database created, the broken node does not. 
+ /// + private AppConfig CreateTestConfigWithBothNodes() + { + // Create the good node with its database + var goodNodeDir = Path.GetDirectoryName(this._goodNodeDbPath)!; + Directory.CreateDirectory(goodNodeDir); + + // Create actual database for good node using the same setup as real commands + var optionsBuilder = new Microsoft.EntityFrameworkCore.DbContextOptionsBuilder(); + optionsBuilder.UseSqlite("Data Source=" + this._goodNodeDbPath); + using var context = new KernelMemory.Core.Storage.ContentStorageDbContext(optionsBuilder.Options); + context.Database.EnsureCreated(); + + // Create FTS index database for good node + var goodNodeFtsPath = Path.Combine(goodNodeDir, "fts.db"); + + // Config with BOTH nodes (good first, broken second) + var config = new AppConfig + { + Nodes = new Dictionary + { + ["good-node"] = new NodeConfig + { + Id = "good-node", + ContentIndex = new SqliteContentIndexConfig { Path = this._goodNodeDbPath }, + SearchIndexes = + [ + new FtsSearchIndexConfig + { + Id = "fts", + Type = SearchIndexTypes.SqliteFTS, + Path = goodNodeFtsPath, + Weight = 1.0f + } + ] + }, + ["broken-node"] = new NodeConfig + { + Id = "broken-node", + ContentIndex = new SqliteContentIndexConfig { Path = this._brokenNodeDbPath }, + SearchIndexes = + [ + new FtsSearchIndexConfig + { + Id = "fts", + Type = SearchIndexTypes.SqliteFTS, + Path = Path.Combine(this._tempDir, "nodes", "broken-node", "fts.db"), + Weight = 1.0f + } + ] + } + } + }; + + var json = JsonSerializer.Serialize(config, s_jsonOptions); + File.WriteAllText(this._configPath, json); + + return config; + } + + /// + /// Creates a test config where the broken node is first (first in insertion order). + /// This tests that search still works even when the first node is broken. 
+ /// + private AppConfig CreateTestConfigWithBrokenNodeFirst() + { + // Create the good node with its database + var goodNodeDir = Path.GetDirectoryName(this._goodNodeDbPath)!; + Directory.CreateDirectory(goodNodeDir); + + // Create actual database for good node + var optionsBuilder = new Microsoft.EntityFrameworkCore.DbContextOptionsBuilder(); + optionsBuilder.UseSqlite("Data Source=" + this._goodNodeDbPath); + using var context = new KernelMemory.Core.Storage.ContentStorageDbContext(optionsBuilder.Options); + context.Database.EnsureCreated(); + + // FTS paths + var goodNodeFtsPath = Path.Combine(goodNodeDir, "fts.db"); + + // Config with broken node FIRST (to test graceful skip) + var config = new AppConfig + { + Nodes = new Dictionary + { + ["broken-node"] = new NodeConfig + { + Id = "broken-node", + ContentIndex = new SqliteContentIndexConfig { Path = this._brokenNodeDbPath }, + SearchIndexes = + [ + new FtsSearchIndexConfig + { + Id = "fts", + Type = SearchIndexTypes.SqliteFTS, + Path = Path.Combine(this._tempDir, "nodes", "broken-node", "fts.db"), + Weight = 1.0f + } + ] + }, + ["good-node"] = new NodeConfig + { + Id = "good-node", + ContentIndex = new SqliteContentIndexConfig { Path = this._goodNodeDbPath }, + SearchIndexes = + [ + new FtsSearchIndexConfig + { + Id = "fts", + Type = SearchIndexTypes.SqliteFTS, + Path = goodNodeFtsPath, + Weight = 1.0f + } + ] + } + } + }; + + var json = JsonSerializer.Serialize(config, s_jsonOptions); + File.WriteAllText(this._configPath, json); + + return config; + } + + private static CommandContext CreateTestContext(string commandName) + { + return new CommandContext([], new EmptyRemainingArguments(), commandName, null); + } + + /// + /// Tests that search command succeeds even when one node has a broken/missing database. + /// The search should skip the broken node and use the working node. + /// Issue 00006: km search crashes on broken nodes instead of skipping them. 
+ /// + [Fact] + public async Task SearchCommand_WithBrokenNode_ShouldSkipBrokenNodeAndSucceed() + { + // Arrange: Create config with good node first, broken node second + var config = this.CreateTestConfigWithBothNodes(); + + var settings = new SearchCommandSettings + { + ConfigPath = this._configPath, + Format = "json", + Query = "test", + Limit = 20, + Offset = 0, + MinRelevance = 0.3f + }; + + var command = new SearchCommand(config, NullLoggerFactory.Instance); + var context = CreateTestContext("search"); + + // Act: Execute search - this should NOT crash even though broken-node has no database + var exitCode = await command.ExecuteAsync(context, settings, CancellationToken.None).ConfigureAwait(false); + + // Assert: Search should succeed (may return empty results, but shouldn't crash) + Assert.Equal(Constants.App.ExitCodeSuccess, exitCode); + } + + /// + /// Tests that search command succeeds even when the FIRST node in config is broken. + /// This tests the scenario where iteration order could cause early failure. + /// + [Fact] + public async Task SearchCommand_WithBrokenNodeFirst_ShouldSkipBrokenNodeAndSucceed() + { + // Arrange: Create config with BROKEN node first + var config = this.CreateTestConfigWithBrokenNodeFirst(); + + var settings = new SearchCommandSettings + { + ConfigPath = this._configPath, + Format = "json", + Query = "test", + Limit = 20, + Offset = 0, + MinRelevance = 0.3f + }; + + var command = new SearchCommand(config, NullLoggerFactory.Instance); + var context = CreateTestContext("search"); + + // Act: Execute search - this should NOT crash even though broken-node is first + var exitCode = await command.ExecuteAsync(context, settings, CancellationToken.None).ConfigureAwait(false); + + // Assert: Search should succeed (may return empty results, but shouldn't crash) + Assert.Equal(Constants.App.ExitCodeSuccess, exitCode); + } + + /// + /// Tests that search command shows first-run message when ALL nodes are broken. 
+ /// This is different from the partial failure case - if no nodes work at all, + /// we should show a helpful first-run message. + /// + [Fact] + public async Task SearchCommand_WithAllNodesBroken_ShouldShowFirstRunMessage() + { + // Arrange: Create config with ONLY broken nodes (no databases exist) + var config = new AppConfig + { + Nodes = new Dictionary + { + ["broken1"] = new NodeConfig + { + Id = "broken1", + ContentIndex = new SqliteContentIndexConfig { Path = Path.Combine(this._tempDir, "nonexistent1.db") }, + SearchIndexes = + [ + new FtsSearchIndexConfig + { + Id = "fts", + Type = SearchIndexTypes.SqliteFTS, + Path = Path.Combine(this._tempDir, "nonexistent1_fts.db"), + Weight = 1.0f + } + ] + }, + ["broken2"] = new NodeConfig + { + Id = "broken2", + ContentIndex = new SqliteContentIndexConfig { Path = Path.Combine(this._tempDir, "nonexistent2.db") }, + SearchIndexes = + [ + new FtsSearchIndexConfig + { + Id = "fts", + Type = SearchIndexTypes.SqliteFTS, + Path = Path.Combine(this._tempDir, "nonexistent2_fts.db"), + Weight = 1.0f + } + ] + } + } + }; + + var settings = new SearchCommandSettings + { + ConfigPath = this._configPath, + Format = "json", + Query = "test", + Limit = 20, + Offset = 0, + MinRelevance = 0.3f + }; + + var command = new SearchCommand(config, NullLoggerFactory.Instance); + var context = CreateTestContext("search"); + + // Act: Execute search - all nodes broken, should show first-run message (not crash) + var exitCode = await command.ExecuteAsync(context, settings, CancellationToken.None).ConfigureAwait(false); + + // Assert: Should return success (first-run scenario) not error + Assert.Equal(Constants.App.ExitCodeSuccess, exitCode); + } + + /// + /// Tests that list command uses the first node in config file order. + /// Issue 00006: km list uses wrong default node (dictionary order vs config order). 
+ /// + [Fact] + public async Task ListCommand_WithoutNodeFlag_ShouldUseFirstNodeInConfigOrder() + { + // Arrange: Create config with good node first + var config = this.CreateTestConfigWithBothNodes(); + + var settings = new ListCommandSettings + { + ConfigPath = this._configPath, + Format = "json", + // NOTE: Not specifying NodeName - should use first node in config + Skip = 0, + Take = 20 + }; + + var command = new ListCommand(config, NullLoggerFactory.Instance); + var context = CreateTestContext("list"); + + // Act: Execute list without --node flag + var exitCode = await command.ExecuteAsync(context, settings, CancellationToken.None).ConfigureAwait(false); + + // Assert: Should succeed using first node (good-node) + Assert.Equal(Constants.App.ExitCodeSuccess, exitCode); + } + + /// + /// Tests that the config preserves node order from JSON file. + /// System.Text.Json should preserve property order since .NET 6+. + /// + [Fact] + public void ConfigParser_ShouldPreserveNodeOrder() + { + // Arrange: Create config with specific order (must include type discriminators and all required fields) + const string json = """ + { + "nodes": { + "alpha": { + "id": "alpha", + "contentIndex": { "type": "sqlite", "path": "/tmp/alpha.db" }, + "searchIndexes": [{ "type": "sqliteFTS", "id": "fts", "path": "/tmp/alpha_fts.db", "weight": 1.0 }] + }, + "beta": { + "id": "beta", + "contentIndex": { "type": "sqlite", "path": "/tmp/beta.db" }, + "searchIndexes": [{ "type": "sqliteFTS", "id": "fts", "path": "/tmp/beta_fts.db", "weight": 1.0 }] + }, + "gamma": { + "id": "gamma", + "contentIndex": { "type": "sqlite", "path": "/tmp/gamma.db" }, + "searchIndexes": [{ "type": "sqliteFTS", "id": "fts", "path": "/tmp/gamma_fts.db", "weight": 1.0 }] + } + } + } + """; + + var configPath = Path.Combine(this._tempDir, "order-test-config.json"); + File.WriteAllText(configPath, json); + + // Act + var config = ConfigParser.LoadFromFile(configPath); + + // Assert: Order should be preserved 
(Dictionary.Keys preserves insertion order in .NET) + var nodeIds = config.Nodes.Keys.ToList(); + Assert.Equal(3, nodeIds.Count); + Assert.Equal("alpha", nodeIds[0]); // First in JSON = first in dictionary + Assert.Equal("beta", nodeIds[1]); // Second in JSON = second in dictionary + Assert.Equal("gamma", nodeIds[2]); // Third in JSON = third in dictionary + } + + /// + /// Simple test implementation of IRemainingArguments. + /// + private sealed class EmptyRemainingArguments : IRemainingArguments + { + public IReadOnlyList Raw => Array.Empty(); + public ILookup Parsed => Enumerable.Empty().ToLookup(x => x, x => (string?)null); + } +} diff --git a/tests/Main.Tests/Integration/ReadonlyCommandTests.cs b/tests/Main.Tests/Integration/ReadonlyCommandTests.cs index 8b0af4f5d..c125c9c52 100644 --- a/tests/Main.Tests/Integration/ReadonlyCommandTests.cs +++ b/tests/Main.Tests/Integration/ReadonlyCommandTests.cs @@ -90,7 +90,7 @@ public async Task BugA_ListCommand_NonExistentDatabase_ShouldNotCreateDirectory( // Assert - With friendly first-run UX, missing DB returns success (0) not error // The key is that it should NOT create any files/directories - Assert.Equal(Constants.ExitCodeSuccess, exitCode); // First-run is not an error + Assert.Equal(Constants.App.ExitCodeSuccess, exitCode); // First-run is not an error Assert.False(Directory.Exists(dbDir), $"BUG: ListCommand (readonly) should NOT create directory: {dbDir}"); Assert.False(File.Exists(this._dbPath), @@ -126,7 +126,7 @@ public async Task BugA_GetCommand_NonExistentDatabase_ShouldNotCreateDirectory() // Assert - With friendly first-run UX, missing DB returns success (0) not error // The key is that it should NOT create any files/directories - Assert.Equal(Constants.ExitCodeSuccess, exitCode); // First-run is not an error + Assert.Equal(Constants.App.ExitCodeSuccess, exitCode); // First-run is not an error Assert.False(Directory.Exists(dbDir), $"BUG: GetCommand (readonly) should NOT create directory: {dbDir}"); 
Assert.False(File.Exists(this._dbPath), @@ -161,7 +161,7 @@ public async Task BugA_NodesCommand_NonExistentDatabase_ShouldNotCreateDirectory // Assert - This test SHOULD FAIL initially (reproducing the bug) // NodesCommand only reads config, shouldn't touch the database at all - Assert.Equal(Constants.ExitCodeSuccess, exitCode); + Assert.Equal(Constants.App.ExitCodeSuccess, exitCode); Assert.False(Directory.Exists(dbDir), $"BUG: NodesCommand (readonly) should NOT create directory: {dbDir}"); Assert.False(File.Exists(this._dbPath), @@ -193,11 +193,11 @@ public async Task BugA_ConfigCommand_NonExistentDatabase_ShouldNotCreateDirector var context = CreateTestContext("config"); // Act - var exitCode = await command.ExecuteAsync(context, settings).ConfigureAwait(false); + var exitCode = await command.ExecuteAsync(context, settings, CancellationToken.None).ConfigureAwait(false); // Assert - This test SHOULD FAIL initially (reproducing the bug) // ConfigCommand only reads config, shouldn't touch the database at all - Assert.Equal(Constants.ExitCodeSuccess, exitCode); + Assert.Equal(Constants.App.ExitCodeSuccess, exitCode); Assert.False(Directory.Exists(dbDir), $"BUG: ConfigCommand (readonly) should NOT create directory: {dbDir}"); Assert.False(File.Exists(this._dbPath), diff --git a/tests/Main.Tests/Integration/UserDataProtectionTests.cs b/tests/Main.Tests/Integration/UserDataProtectionTests.cs index 105761038..52c47cf9f 100644 --- a/tests/Main.Tests/Integration/UserDataProtectionTests.cs +++ b/tests/Main.Tests/Integration/UserDataProtectionTests.cs @@ -158,7 +158,7 @@ public async Task Fixed_SettingsWithConfigPath_MustUseTestDirectory() var exitCode = await command.ExecuteAsync(context, settingsWithConfigPath, CancellationToken.None).ConfigureAwait(false); // Assert - Assert.Equal(Constants.ExitCodeSuccess, exitCode); + Assert.Equal(Constants.App.ExitCodeSuccess, exitCode); // Verify test used temp directory, not ~/.km var testDbPath = Path.Combine(this._tempDir, 
"nodes", "test", "content.db"); diff --git a/tests/Main.Tests/Services/EmbeddingGeneratorFactoryTests.cs b/tests/Main.Tests/Services/EmbeddingGeneratorFactoryTests.cs new file mode 100644 index 000000000..67cc5c7b6 --- /dev/null +++ b/tests/Main.Tests/Services/EmbeddingGeneratorFactoryTests.cs @@ -0,0 +1,289 @@ +// Copyright (c) Microsoft. All rights reserved. +using KernelMemory.Core.Config.Embeddings; +using KernelMemory.Core.Config.Enums; +using KernelMemory.Core.Embeddings; +using KernelMemory.Core.Embeddings.Cache; +using KernelMemory.Core.Embeddings.Providers; +using KernelMemory.Main.Services; +using Microsoft.Extensions.Logging; +using Moq; + +namespace KernelMemory.Main.Tests.Services; + +/// +/// Unit tests for EmbeddingGeneratorFactory. +/// Tests verify correct generator creation from configuration. +/// +public sealed class EmbeddingGeneratorFactoryTests : IDisposable +{ + private readonly Mock _mockLoggerFactory; + private readonly HttpClient _httpClient; + + public EmbeddingGeneratorFactoryTests() + { + // Setup mock logger factory + this._mockLoggerFactory = new Mock(); + this._mockLoggerFactory + .Setup(x => x.CreateLogger(It.IsAny())) + .Returns(new Mock().Object); + + this._httpClient = new HttpClient(); + } + + public void Dispose() + { + this._httpClient.Dispose(); + GC.SuppressFinalize(this); + } + + [Fact] + public void CreateGenerator_CreatesOllamaGenerator() + { + // Arrange + var config = new OllamaEmbeddingsConfig + { + Model = "qwen3-embedding", + BaseUrl = "http://localhost:11434" + }; + + // Act + var generator = EmbeddingGeneratorFactory.CreateGenerator( + config, + this._httpClient, + cache: null, + this._mockLoggerFactory.Object); + + // Assert + Assert.NotNull(generator); + Assert.Equal(EmbeddingsTypes.Ollama, generator.ProviderType); + Assert.Equal("qwen3-embedding", generator.ModelName); + Assert.Equal(1024, generator.VectorDimensions); // Known dimension for qwen3-embedding + } + + [Fact] + public void 
CreateGenerator_CreatesOpenAIGenerator() + { + // Arrange + var config = new OpenAIEmbeddingsConfig + { + Model = "text-embedding-3-small", + ApiKey = "test-api-key" + }; + + // Act + var generator = EmbeddingGeneratorFactory.CreateGenerator( + config, + this._httpClient, + cache: null, + this._mockLoggerFactory.Object); + + // Assert + Assert.NotNull(generator); + Assert.Equal(EmbeddingsTypes.OpenAI, generator.ProviderType); + Assert.Equal("text-embedding-3-small", generator.ModelName); + Assert.Equal(1536, generator.VectorDimensions); // Known dimension for text-embedding-3-small + } + + [Fact] + public void CreateGenerator_CreatesAzureOpenAIGenerator() + { + // Arrange + var config = new AzureOpenAIEmbeddingsConfig + { + Model = "text-embedding-ada-002", + Endpoint = "https://test.openai.azure.com", + Deployment = "test-deployment", + ApiKey = "test-api-key" + }; + + // Act + var generator = EmbeddingGeneratorFactory.CreateGenerator( + config, + this._httpClient, + cache: null, + this._mockLoggerFactory.Object); + + // Assert + Assert.NotNull(generator); + Assert.Equal(EmbeddingsTypes.AzureOpenAI, generator.ProviderType); + Assert.Equal("text-embedding-ada-002", generator.ModelName); + Assert.Equal(1536, generator.VectorDimensions); // Known dimension for ada-002 + } + + [Fact] + public void CreateGenerator_CreatesHuggingFaceGenerator() + { + // Arrange + var config = new HuggingFaceEmbeddingsConfig + { + Model = "sentence-transformers/all-MiniLM-L6-v2", + ApiKey = "test-api-key" + }; + + // Act + var generator = EmbeddingGeneratorFactory.CreateGenerator( + config, + this._httpClient, + cache: null, + this._mockLoggerFactory.Object); + + // Assert + Assert.NotNull(generator); + Assert.Equal(EmbeddingsTypes.HuggingFace, generator.ProviderType); + Assert.Equal("sentence-transformers/all-MiniLM-L6-v2", generator.ModelName); + Assert.Equal(384, generator.VectorDimensions); // Known dimension for all-MiniLM-L6-v2 + } + + [Fact] + public void 
CreateGenerator_WrapsWithCacheWhenProvided() + { + // Arrange + var config = new OllamaEmbeddingsConfig + { + Model = "qwen3-embedding", + BaseUrl = "http://localhost:11434" + }; + + var mockCache = new Mock(); + mockCache.Setup(x => x.Mode).Returns(CacheModes.ReadWrite); + + // Act + var generator = EmbeddingGeneratorFactory.CreateGenerator( + config, + this._httpClient, + mockCache.Object, + this._mockLoggerFactory.Object); + + // Assert + Assert.NotNull(generator); + Assert.IsType(generator); + } + + [Fact] + public void CreateGenerator_DoesNotWrapWithCacheWhenNull() + { + // Arrange + var config = new OllamaEmbeddingsConfig + { + Model = "qwen3-embedding", + BaseUrl = "http://localhost:11434" + }; + + // Act + var generator = EmbeddingGeneratorFactory.CreateGenerator( + config, + this._httpClient, + cache: null, + this._mockLoggerFactory.Object); + + // Assert + Assert.NotNull(generator); + Assert.IsType(generator); + } + + [Fact] + public void CreateGenerator_ThrowsForNullConfig() + { + // Act & Assert + Assert.Throws(() => + EmbeddingGeneratorFactory.CreateGenerator( + null!, + this._httpClient, + cache: null, + this._mockLoggerFactory.Object)); + } + + [Fact] + public void CreateGenerator_ThrowsForNullHttpClient() + { + // Arrange + var config = new OllamaEmbeddingsConfig + { + Model = "qwen3-embedding", + BaseUrl = "http://localhost:11434" + }; + + // Act & Assert + Assert.Throws(() => + EmbeddingGeneratorFactory.CreateGenerator( + config, + null!, + cache: null, + this._mockLoggerFactory.Object)); + } + + [Fact] + public void CreateGenerator_ThrowsForNullLoggerFactory() + { + // Arrange + var config = new OllamaEmbeddingsConfig + { + Model = "qwen3-embedding", + BaseUrl = "http://localhost:11434" + }; + + // Act & Assert + Assert.Throws(() => + EmbeddingGeneratorFactory.CreateGenerator( + config, + this._httpClient, + cache: null, + null!)); + } + + [Fact] + public void CreateGenerator_UsesDefaultDimensionsForUnknownModel() + { + // Arrange + var config = 
new OllamaEmbeddingsConfig + { + Model = "unknown-model-xyz", + BaseUrl = "http://localhost:11434" + }; + + // Act + var generator = EmbeddingGeneratorFactory.CreateGenerator( + config, + this._httpClient, + cache: null, + this._mockLoggerFactory.Object); + + // Assert - Should use default Ollama dimension + Assert.NotNull(generator); + Assert.True(generator.VectorDimensions > 0); + } + + [Fact] + public void CreateGenerator_SetsIsNormalizedTrue() + { + // Arrange + var ollamaConfig = new OllamaEmbeddingsConfig + { + Model = "qwen3-embedding", + BaseUrl = "http://localhost:11434" + }; + + var openaiConfig = new OpenAIEmbeddingsConfig + { + Model = "text-embedding-3-small", + ApiKey = "test-key" + }; + + // Act + var ollamaGen = EmbeddingGeneratorFactory.CreateGenerator( + ollamaConfig, + this._httpClient, + cache: null, + this._mockLoggerFactory.Object); + + var openaiGen = EmbeddingGeneratorFactory.CreateGenerator( + openaiConfig, + this._httpClient, + cache: null, + this._mockLoggerFactory.Object); + + // Assert - All generators should be normalized + Assert.True(ollamaGen.IsNormalized); + Assert.True(openaiGen.IsNormalized); + } +} diff --git a/tests/Main.Tests/Services/SearchIndexFactoryVectorTests.cs b/tests/Main.Tests/Services/SearchIndexFactoryVectorTests.cs new file mode 100644 index 000000000..edc458557 --- /dev/null +++ b/tests/Main.Tests/Services/SearchIndexFactoryVectorTests.cs @@ -0,0 +1,264 @@ +// Copyright (c) Microsoft. All rights reserved. +using KernelMemory.Core.Config.Embeddings; +using KernelMemory.Core.Config.SearchIndex; +using KernelMemory.Core.Search; +using KernelMemory.Main.Services; +using Microsoft.Extensions.Logging; +using Moq; + +namespace KernelMemory.Main.Tests.Services; + +/// +/// Unit tests for SearchIndexFactory vector index creation. +/// Tests verify correct index creation from configuration. 
+/// +public sealed class SearchIndexFactoryVectorTests : IDisposable +{ + private readonly string _tempDir; + private readonly Mock _mockLoggerFactory; + private readonly HttpClient _httpClient; + + public SearchIndexFactoryVectorTests() + { + this._tempDir = Path.Combine(Path.GetTempPath(), $"factory_test_{Guid.NewGuid()}"); + Directory.CreateDirectory(this._tempDir); + + // Setup mock logger factory + this._mockLoggerFactory = new Mock(); + this._mockLoggerFactory + .Setup(x => x.CreateLogger(It.IsAny())) + .Returns(new Mock().Object); + + this._httpClient = new HttpClient(); + } + + public void Dispose() + { + this._httpClient.Dispose(); + + // Clean up temp directory + if (Directory.Exists(this._tempDir)) + { + Directory.Delete(this._tempDir, recursive: true); + } + + GC.SuppressFinalize(this); + } + + [Fact] + public void CreateIndexesWithEmbeddings_CreatesFtsIndex() + { + // Arrange + var ftsPath = Path.Combine(this._tempDir, "fts.db"); + var configs = new List + { + new FtsSearchIndexConfig + { + Id = "fts-test", + Path = ftsPath, + EnableStemming = true + } + }; + + // Act + var indexes = SearchIndexFactory.CreateIndexes( + configs, + this._httpClient, + embeddingCache: null, + this._mockLoggerFactory.Object); + + // Assert + Assert.Single(indexes); + Assert.Contains("fts-test", indexes.Keys); + Assert.IsType(indexes["fts-test"]); + + // Cleanup + ((IDisposable)indexes["fts-test"]).Dispose(); + } + + [Fact] + public void CreateIndexesWithEmbeddings_CreatesVectorIndex() + { + // Arrange + var vectorPath = Path.Combine(this._tempDir, "vector.db"); + var configs = new List + { + new VectorSearchIndexConfig + { + Id = "vector-test", + Path = vectorPath, + Dimensions = 384, + UseSqliteVec = false, + Embeddings = new OllamaEmbeddingsConfig + { + Model = "test-model", + BaseUrl = "http://localhost:11434" + } + } + }; + + // Act + var indexes = SearchIndexFactory.CreateIndexes( + configs, + this._httpClient, + embeddingCache: null, + 
this._mockLoggerFactory.Object); + + // Assert + Assert.Single(indexes); + Assert.Contains("vector-test", indexes.Keys); + Assert.IsType(indexes["vector-test"]); + + // Verify dimensions + var vectorIndex = (SqliteVectorIndex)indexes["vector-test"]; + Assert.Equal(384, vectorIndex.VectorDimensions); + + // Cleanup + vectorIndex.Dispose(); + } + + [Fact] + public void CreateIndexesWithEmbeddings_CreatesMixedIndexes() + { + // Arrange + var ftsPath = Path.Combine(this._tempDir, "fts-mixed.db"); + var vectorPath = Path.Combine(this._tempDir, "vector-mixed.db"); + var configs = new List + { + new FtsSearchIndexConfig + { + Id = "fts-mixed", + Path = ftsPath, + EnableStemming = true + }, + new VectorSearchIndexConfig + { + Id = "vector-mixed", + Path = vectorPath, + Dimensions = 768, + UseSqliteVec = false, + Embeddings = new OllamaEmbeddingsConfig + { + Model = "test-model", + BaseUrl = "http://localhost:11434" + } + } + }; + + // Act + var indexes = SearchIndexFactory.CreateIndexes( + configs, + this._httpClient, + embeddingCache: null, + this._mockLoggerFactory.Object); + + // Assert + Assert.Equal(2, indexes.Count); + Assert.IsType(indexes["fts-mixed"]); + Assert.IsType(indexes["vector-mixed"]); + + // Cleanup + ((IDisposable)indexes["fts-mixed"]).Dispose(); + ((IDisposable)indexes["vector-mixed"]).Dispose(); + } + + [Fact] + public void CreateIndexesWithEmbeddings_ThrowsForVectorIndexWithoutPath() + { + // Arrange + var configs = new List + { + new VectorSearchIndexConfig + { + Id = "vector-no-path", + Path = null, + Dimensions = 768, + Embeddings = new OllamaEmbeddingsConfig + { + Model = "test-model", + BaseUrl = "http://localhost:11434" + } + } + }; + + // Act & Assert + Assert.Throws(() => + SearchIndexFactory.CreateIndexes( + configs, + this._httpClient, + embeddingCache: null, + this._mockLoggerFactory.Object)); + } + + [Fact] + public void CreateIndexesWithEmbeddings_ThrowsForVectorIndexWithoutEmbeddings() + { + // Arrange + var vectorPath = 
Path.Combine(this._tempDir, "vector-no-embeddings.db"); + var configs = new List + { + new VectorSearchIndexConfig + { + Id = "vector-no-embeddings", + Path = vectorPath, + Dimensions = 768, + Embeddings = null + } + }; + + // Act & Assert + Assert.Throws(() => + SearchIndexFactory.CreateIndexes( + configs, + this._httpClient, + embeddingCache: null, + this._mockLoggerFactory.Object)); + } + + [Fact] + public void CreateIndexes_DoesNotCreateVectorIndexes() + { + // Arrange - Vector config should be ignored by CreateIndexes (no embedding support) + var ftsPath = Path.Combine(this._tempDir, "fts-only.db"); + var vectorPath = Path.Combine(this._tempDir, "vector-ignored.db"); + var configs = new List + { + new FtsSearchIndexConfig + { + Id = "fts-only", + Path = ftsPath, + EnableStemming = true + }, + new VectorSearchIndexConfig + { + Id = "vector-ignored", + Path = vectorPath, + Dimensions = 768, + Embeddings = new OllamaEmbeddingsConfig + { + Model = "test-model", + BaseUrl = "http://localhost:11434" + } + } + }; + + // Act + using var httpClient = new HttpClient(); + var indexes = SearchIndexFactory.CreateIndexes( + configs, + httpClient, + embeddingCache: null, + this._mockLoggerFactory.Object); + + // Assert - Both FTS and Vector indexes created + Assert.Equal(2, indexes.Count); + Assert.Contains("fts-only", indexes.Keys); + Assert.Contains("vector-ignored", indexes.Keys); + + // Cleanup + foreach (var index in indexes.Values.OfType()) + { + index.Dispose(); + } + } +} diff --git a/tests/Main.Tests/Unit/CLI/ModeRouterTests.cs b/tests/Main.Tests/Unit/CLI/ModeRouterTests.cs index d5ed66633..bcac1d48e 100644 --- a/tests/Main.Tests/Unit/CLI/ModeRouterTests.cs +++ b/tests/Main.Tests/Unit/CLI/ModeRouterTests.cs @@ -66,6 +66,6 @@ public void HandleUnimplementedMode_ReturnsSystemError() { var router = new ModeRouter(); var exitCode = router.HandleUnimplementedMode("Test", "Test description"); - Assert.Equal(Constants.ExitCodeSystemError, exitCode); + 
Assert.Equal(Constants.App.ExitCodeSystemError, exitCode); } } diff --git a/tests/Main.Tests/Unit/Commands/BaseCommandTests.cs b/tests/Main.Tests/Unit/Commands/BaseCommandTests.cs index 1147b58ce..7ed9cb8ed 100644 --- a/tests/Main.Tests/Unit/Commands/BaseCommandTests.cs +++ b/tests/Main.Tests/Unit/Commands/BaseCommandTests.cs @@ -25,7 +25,7 @@ public void HandleError_WithInvalidOperationException_ReturnsUserError() var exitCode = command.TestHandleError(exception, mockFormatter.Object); // Assert - Assert.Equal(Constants.ExitCodeUserError, exitCode); + Assert.Equal(Constants.App.ExitCodeUserError, exitCode); mockFormatter.Verify(f => f.FormatError("Invalid operation"), Times.Once); } @@ -41,7 +41,7 @@ public void HandleError_WithArgumentException_ReturnsUserError() var exitCode = command.TestHandleError(exception, mockFormatter.Object); // Assert - Assert.Equal(Constants.ExitCodeUserError, exitCode); + Assert.Equal(Constants.App.ExitCodeUserError, exitCode); mockFormatter.Verify(f => f.FormatError("Invalid argument"), Times.Once); } @@ -57,7 +57,7 @@ public void HandleError_WithGenericException_ReturnsSystemError() var exitCode = command.TestHandleError(exception, mockFormatter.Object); // Assert - Assert.Equal(Constants.ExitCodeSystemError, exitCode); + Assert.Equal(Constants.App.ExitCodeSystemError, exitCode); mockFormatter.Verify(f => f.FormatError("System failure"), Times.Once); } @@ -73,7 +73,7 @@ public void HandleError_WithIOException_ReturnsSystemError() var exitCode = command.TestHandleError(exception, mockFormatter.Object); // Assert - Assert.Equal(Constants.ExitCodeSystemError, exitCode); + Assert.Equal(Constants.App.ExitCodeSystemError, exitCode); mockFormatter.Verify(f => f.FormatError("File access error"), Times.Once); } diff --git a/tests/Main.Tests/Unit/Commands/DoctorCommandTests.cs b/tests/Main.Tests/Unit/Commands/DoctorCommandTests.cs new file mode 100644 index 000000000..a40cfb316 --- /dev/null +++ 
b/tests/Main.Tests/Unit/Commands/DoctorCommandTests.cs @@ -0,0 +1,458 @@ +// Copyright (c) Microsoft. All rights reserved. + +using KernelMemory.Core.Config; +using KernelMemory.Core.Config.Cache; +using KernelMemory.Core.Config.ContentIndex; +using KernelMemory.Core.Config.Embeddings; +using KernelMemory.Core.Config.Enums; +using KernelMemory.Core.Config.SearchIndex; +using KernelMemory.Main.CLI.Commands; +using Microsoft.Extensions.Logging; +using Moq; +using Spectre.Console.Cli; + +namespace KernelMemory.Main.Tests.Unit.Commands; + +/// +/// Unit tests for DoctorCommand validating configuration and system health checks. +/// Tests verify that doctor correctly identifies configuration issues and provides actionable fixes. +/// +public sealed class DoctorCommandTests +{ + private readonly Mock _mockLoggerFactory; + private readonly Mock> _mockLogger; + + public DoctorCommandTests() + { + this._mockLoggerFactory = new Mock(); + this._mockLogger = new Mock>(); + this._mockLoggerFactory + .Setup(f => f.CreateLogger(It.IsAny())) + .Returns(this._mockLogger.Object); + } + + /// + /// Verifies that doctor command succeeds when all dependencies are properly configured. + /// This is the happy path test. 
+ /// + [Fact] + public async Task ExecuteAsync_WithValidConfigAndNoDependencies_ReturnsSuccess() + { + // Arrange - Create minimal config with only FTS (no external dependencies) + var tempDir = Path.Combine(Path.GetTempPath(), $"km-doctor-test-{Guid.NewGuid()}"); + Directory.CreateDirectory(tempDir); + + try + { + var config = new AppConfig + { + Nodes = new Dictionary + { + ["test"] = new NodeConfig + { + Id = "test", + Access = NodeAccessLevels.Full, + ContentIndex = new SqliteContentIndexConfig + { + Path = Path.Combine(tempDir, "content.db") + }, + SearchIndexes = new List + { + new FtsSearchIndexConfig + { + Id = "fts", + Type = SearchIndexTypes.SqliteFTS, + Path = Path.Combine(tempDir, "fts.db") + } + } + } + } + }; + + using var command = new DoctorCommand(config, this._mockLoggerFactory.Object); + var settings = new DoctorCommandSettings { NoColor = true, Format = "json" }; + var cliContext = new CommandContext([], new EmptyRemainingArguments(), "doctor", null!); + + // Act + var exitCode = await command.ExecuteAsync(cliContext, settings, CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.Equal(0, exitCode); + } + finally + { + if (Directory.Exists(tempDir)) + { + Directory.Delete(tempDir, recursive: true); + } + } + } + + /// + /// Verifies that doctor detects missing Ollama and returns error exit code. + /// Tests the critical use case of vector search with unavailable provider. 
+ /// + [Fact] + public async Task ExecuteAsync_WithOllamaConfigButServerDown_ReturnsError() + { + // Arrange - Config with Ollama vector index (Ollama likely not running on port 9999) + var tempDir = Path.Combine(Path.GetTempPath(), $"km-doctor-test-{Guid.NewGuid()}"); + Directory.CreateDirectory(tempDir); + + try + { + var config = new AppConfig + { + Nodes = new Dictionary + { + ["test"] = new NodeConfig + { + Id = "test", + Access = NodeAccessLevels.Full, + ContentIndex = new SqliteContentIndexConfig + { + Path = Path.Combine(tempDir, "content.db") + }, + SearchIndexes = new List + { + new VectorSearchIndexConfig + { + Id = "vector", + Type = SearchIndexTypes.SqliteVector, + Path = Path.Combine(tempDir, "vector.db"), + Dimensions = 1024, + Embeddings = new OllamaEmbeddingsConfig + { + Model = "qwen3-embedding", + BaseUrl = "http://localhost:9999" // Non-existent port + } + } + } + } + } + }; + + using var command = new DoctorCommand(config, this._mockLoggerFactory.Object); + var settings = new DoctorCommandSettings { NoColor = true, Format = "json" }; + var cliContext = new CommandContext([], new EmptyRemainingArguments(), "doctor", null!); + + // Act + var exitCode = await command.ExecuteAsync(cliContext, settings, CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.Equal(1, exitCode); // User error - configuration issue + } + finally + { + if (Directory.Exists(tempDir)) + { + Directory.Delete(tempDir, recursive: true); + } + } + } + + /// + /// Verifies that doctor detects missing OpenAI API key. + /// Tests that environment variable checks work correctly. 
+ /// + [Fact] + public async Task ExecuteAsync_WithOpenAIButNoApiKey_ReturnsError() + { + // Arrange + var tempDir = Path.Combine(Path.GetTempPath(), $"km-doctor-test-{Guid.NewGuid()}"); + Directory.CreateDirectory(tempDir); + + // Save and clear the OPENAI_API_KEY environment variable for this test + var originalApiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY"); + Environment.SetEnvironmentVariable("OPENAI_API_KEY", null); + + try + { + var config = new AppConfig + { + Nodes = new Dictionary + { + ["test"] = new NodeConfig + { + Id = "test", + Access = NodeAccessLevels.Full, + ContentIndex = new SqliteContentIndexConfig + { + Path = Path.Combine(tempDir, "content.db") + }, + SearchIndexes = new List + { + new VectorSearchIndexConfig + { + Id = "vector", + Type = SearchIndexTypes.SqliteVector, + Path = Path.Combine(tempDir, "vector.db"), + Dimensions = 1536, + Embeddings = new OpenAIEmbeddingsConfig + { + Model = "text-embedding-3-small" + // No ApiKey set, and OPENAI_API_KEY env var cleared above + } + } + } + } + } + }; + + using var command = new DoctorCommand(config, this._mockLoggerFactory.Object); + var settings = new DoctorCommandSettings { NoColor = true, Format = "json" }; + var cliContext = new CommandContext([], new EmptyRemainingArguments(), "doctor", null!); + + // Act + var exitCode = await command.ExecuteAsync(cliContext, settings, CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.Equal(1, exitCode); // Error exit code + } + finally + { + // Restore the original environment variable + Environment.SetEnvironmentVariable("OPENAI_API_KEY", originalApiKey); + + if (Directory.Exists(tempDir)) + { + Directory.Delete(tempDir, recursive: true); + } + } + } + + /// + /// Verifies that doctor warns about missing cache file but doesn't fail. + /// Cache will be created on first use, so this is a warning not an error. 
+ /// + [Fact] + public async Task ExecuteAsync_WithMissingCacheFile_ReturnsSuccessWithWarning() + { + // Arrange + var tempDir = Path.Combine(Path.GetTempPath(), $"km-doctor-test-{Guid.NewGuid()}"); + Directory.CreateDirectory(tempDir); + + try + { + var config = new AppConfig + { + Nodes = new Dictionary + { + ["test"] = new NodeConfig + { + Id = "test", + Access = NodeAccessLevels.Full, + ContentIndex = new SqliteContentIndexConfig + { + Path = Path.Combine(tempDir, "content.db") + }, + SearchIndexes = new List + { + new FtsSearchIndexConfig + { + Id = "fts", + Type = SearchIndexTypes.SqliteFTS, + Path = Path.Combine(tempDir, "fts.db") + } + } + } + }, + EmbeddingsCache = new CacheConfig + { + Path = Path.Combine(tempDir, "cache.db"), + AllowRead = true, + AllowWrite = true, + Type = CacheTypes.Sqlite + } + }; + + using var command = new DoctorCommand(config, this._mockLoggerFactory.Object); + var settings = new DoctorCommandSettings { NoColor = true, Format = "json" }; + var cliContext = new CommandContext([], new EmptyRemainingArguments(), "doctor", null!); + + // Act + var exitCode = await command.ExecuteAsync(cliContext, settings, CancellationToken.None).ConfigureAwait(false); + + // Assert - Should succeed (warnings don't cause failure, only errors do) + Assert.Equal(0, exitCode); + } + finally + { + if (Directory.Exists(tempDir)) + { + Directory.Delete(tempDir, recursive: true); + } + } + } + + /// + /// Verifies that doctor correctly groups output by node when multiple nodes are configured. + /// Tests that NodeId is properly set for node-specific checks (Content index, FTS index). + /// Global checks (Config file, cache) should have null NodeId. 
+ /// + [Fact] + public async Task ExecuteAsync_WithMultipleNodes_GroupsOutputByNode() + { + // Arrange - Create config with 3 nodes for thorough testing + var tempDir = Path.Combine(Path.GetTempPath(), $"km-doctor-test-{Guid.NewGuid()}"); + Directory.CreateDirectory(tempDir); + + try + { + // Create node directories + var node1Dir = Path.Combine(tempDir, "node1"); + var node2Dir = Path.Combine(tempDir, "node2"); + var node3Dir = Path.Combine(tempDir, "node3"); + Directory.CreateDirectory(node1Dir); + Directory.CreateDirectory(node2Dir); + Directory.CreateDirectory(node3Dir); + + var config = new AppConfig + { + Nodes = new Dictionary + { + ["personal"] = new NodeConfig + { + Id = "personal", + Access = NodeAccessLevels.Full, + ContentIndex = new SqliteContentIndexConfig + { + Path = Path.Combine(node1Dir, "content.db") + }, + SearchIndexes = new List + { + new FtsSearchIndexConfig + { + Id = "fts-1", + Type = SearchIndexTypes.SqliteFTS, + Path = Path.Combine(node1Dir, "fts.db") + } + } + }, + ["work"] = new NodeConfig + { + Id = "work", + Access = NodeAccessLevels.Full, + ContentIndex = new SqliteContentIndexConfig + { + Path = Path.Combine(node2Dir, "content.db") + }, + SearchIndexes = new List + { + new FtsSearchIndexConfig + { + Id = "fts-2", + Type = SearchIndexTypes.SqliteFTS, + Path = Path.Combine(node2Dir, "fts.db") + } + } + }, + ["archive"] = new NodeConfig + { + Id = "archive", + Access = NodeAccessLevels.ReadOnly, + ContentIndex = new SqliteContentIndexConfig + { + Path = Path.Combine(node3Dir, "content.db") + }, + SearchIndexes = new List + { + new FtsSearchIndexConfig + { + Id = "fts-3", + Type = SearchIndexTypes.SqliteFTS, + Path = Path.Combine(node3Dir, "fts.db") + } + } + } + }, + EmbeddingsCache = new CacheConfig + { + Path = Path.Combine(tempDir, "cache.db"), + AllowRead = true, + AllowWrite = true, + Type = CacheTypes.Sqlite + } + }; + + using var command = new DoctorCommand(config, this._mockLoggerFactory.Object); + var settings = new 
DoctorCommandSettings { NoColor = true, Format = "json" }; + var cliContext = new CommandContext([], new EmptyRemainingArguments(), "doctor", null!); + + // Capture console output to verify JSON output contains nodeId + var originalOut = Console.Out; + using var stringWriter = new StringWriter(); + Console.SetOut(stringWriter); + + try + { + // Act + var exitCode = await command.ExecuteAsync(cliContext, settings, CancellationToken.None).ConfigureAwait(false); + + // Assert - Should succeed (only FTS, no external dependencies) + Assert.Equal(0, exitCode); + + // Parse JSON output to verify nodeId is set correctly + var output = stringWriter.ToString(); + using var doc = System.Text.Json.JsonDocument.Parse(output); + var root = doc.RootElement; + + Assert.True(root.TryGetProperty("results", out var results)); + var resultsList = results.EnumerateArray().ToList(); + + // Should have checks for: config file (1) + 3 nodes x 2 checks (content + FTS) + 1 cache = 8 checks + Assert.Equal(8, resultsList.Count); + + // Verify config file check has no nodeId (null is omitted in JSON) + var configCheck = resultsList.First(r => r.GetProperty("component").GetString() == "Config file"); + // null values are omitted due to JsonIgnoreCondition.WhenWritingNull + Assert.False(configCheck.TryGetProperty("nodeId", out _)); + + // Verify cache check has no nodeId (null is omitted in JSON) + var cacheCheck = resultsList.First(r => r.GetProperty("component").GetString() == "Embeddings cache"); + Assert.False(cacheCheck.TryGetProperty("nodeId", out _)); + + // Verify node-specific checks have correct nodeId + var nodeChecks = resultsList.Where(r => + r.TryGetProperty("nodeId", out var nid) && + nid.ValueKind == System.Text.Json.JsonValueKind.String).ToList(); + + // Should have 6 node-specific checks (3 nodes x 2 checks each) + Assert.Equal(6, nodeChecks.Count); + + // Verify each node has both content and FTS checks + var personalChecks = nodeChecks.Where(r => 
r.GetProperty("nodeId").GetString() == "personal").ToList(); + var workChecks = nodeChecks.Where(r => r.GetProperty("nodeId").GetString() == "work").ToList(); + var archiveChecks = nodeChecks.Where(r => r.GetProperty("nodeId").GetString() == "archive").ToList(); + + Assert.Equal(2, personalChecks.Count); + Assert.Equal(2, workChecks.Count); + Assert.Equal(2, archiveChecks.Count); + + // Verify summary + Assert.True(root.TryGetProperty("summary", out var summary)); + Assert.Equal(8, summary.GetProperty("total").GetInt32()); + } + finally + { + Console.SetOut(originalOut); + } + } + finally + { + if (Directory.Exists(tempDir)) + { + Directory.Delete(tempDir, recursive: true); + } + } + } +} + +/// +/// Empty IRemainingArguments implementation for CommandContext in tests. +/// +internal sealed class EmptyRemainingArguments : IRemainingArguments +{ + public IReadOnlyList Raw => Array.Empty(); + public ILookup Parsed => Enumerable.Empty().ToLookup(x => x, x => (string?)null); +} diff --git a/tests/Main.Tests/Unit/OutputFormatters/HumanOutputFormatterTests.cs b/tests/Main.Tests/Unit/OutputFormatters/HumanOutputFormatterTests.cs index b0152558d..f0068cfbe 100644 --- a/tests/Main.Tests/Unit/OutputFormatters/HumanOutputFormatterTests.cs +++ b/tests/Main.Tests/Unit/OutputFormatters/HumanOutputFormatterTests.cs @@ -133,7 +133,7 @@ public void Format_WithLongContent_TruncatesInNormalMode() { // Arrange var formatter = new HumanOutputFormatter("normal", useColors: false); - var longContent = new string('x', Constants.MaxContentDisplayLength + 100); + var longContent = new string('x', Constants.App.MaxContentDisplayLength + 100); var content = new ContentDto { Id = "long-content-id", @@ -151,7 +151,7 @@ public void Format_WithLongContent_DoesNotTruncateInVerboseMode() { // Arrange var formatter = new HumanOutputFormatter("verbose", useColors: false); - var longContent = new string('y', Constants.MaxContentDisplayLength + 100); + var longContent = new string('y', 
Constants.App.MaxContentDisplayLength + 100); var content = new ContentDto { Id = "long-verbose-id", diff --git a/tests/Main.Tests/Unit/Settings/ListCommandSettingsTests.cs b/tests/Main.Tests/Unit/Settings/ListCommandSettingsTests.cs index 9da7b797d..f094fca70 100644 --- a/tests/Main.Tests/Unit/Settings/ListCommandSettingsTests.cs +++ b/tests/Main.Tests/Unit/Settings/ListCommandSettingsTests.cs @@ -138,6 +138,6 @@ public void DefaultValues_AreSetCorrectly() // Assert Assert.Equal(0, settings.Skip); - Assert.Equal(Constants.DefaultPageSize, settings.Take); + Assert.Equal(Constants.App.DefaultPageSize, settings.Take); } } diff --git a/tests/e2e/framework/__init__.py b/tests/e2e/framework/__init__.py new file mode 100644 index 000000000..de1d6e5ac --- /dev/null +++ b/tests/e2e/framework/__init__.py @@ -0,0 +1 @@ +"""E2E testing framework for km CLI.""" diff --git a/tests/e2e/framework/cli.py b/tests/e2e/framework/cli.py new file mode 100644 index 000000000..674df35c7 --- /dev/null +++ b/tests/e2e/framework/cli.py @@ -0,0 +1,101 @@ +"""CLI execution wrapper for testing km commands.""" +import subprocess +import json +import os +from pathlib import Path +from typing import Optional + + +class KmResult: + """Result of executing a km command.""" + + def __init__(self, stdout: str, stderr: str, exit_code: int): + self.stdout = stdout + self.stderr = stderr + self.exit_code = exit_code + + @property + def stdout_json(self): + """Parse stdout as JSON.""" + return json.loads(self.stdout) + + def assert_success(self): + """Assert command succeeded.""" + assert self.exit_code == 0, f"Command failed with exit code {self.exit_code}\nstderr: {self.stderr}" + + +class KmCli: + """Wrapper for executing km CLI commands.""" + + def __init__(self, config_path: Optional[str] = None, km_binary: Optional[str] = None): + self.config_path = config_path + + # Find km binary (Main.dll) + if km_binary: + self.km_binary = Path(km_binary) + else: + self.km_binary = locate_km_binary() + + def 
run(self, *args, timeout: int = 30) -> KmResult: + """Execute km command and return result.""" + cmd = ["dotnet", str(self.km_binary)] + cmd.extend(args) + + # Add config if specified + if self.config_path and "--config" not in args: + cmd.extend(["--config", self.config_path]) + + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=timeout + ) + + return KmResult(result.stdout, result.stderr, result.returncode) + + def get_database_path(self, node: str = "personal") -> Optional[Path]: + """Get path to node's content database from config.""" + if not self.config_path: + return None + + with open(self.config_path) as f: + config = json.load(f) + + node_config = config["nodes"].get(node) + if not node_config: + return None + + db_path = node_config["contentIndex"]["path"] + return Path(db_path) + + +def locate_km_binary() -> Path: + """ + Locate the built km CLI (KernelMemory.Main.dll). + + Priority: + 1) KM_BIN environment variable + 2) Debug output + 3) Release output + """ + env_bin = os.environ.get("KM_BIN") + if env_bin: + path = Path(env_bin) + if path.exists(): + return path + raise FileNotFoundError(f"KM_BIN is set but does not exist: {path}") + + repo_root = Path(__file__).parent.parent.parent.parent + candidates = [ + repo_root / "src/Main/bin/Debug/net10.0/KernelMemory.Main.dll", + repo_root / "src/Main/bin/Release/net10.0/KernelMemory.Main.dll", + ] + + for candidate in candidates: + if candidate.exists(): + return candidate + + raise FileNotFoundError( + "km binary not found. Set KM_BIN to the path of KernelMemory.Main.dll or build the project." 
+ ) diff --git a/tests/e2e/framework/db.py b/tests/e2e/framework/db.py new file mode 100644 index 000000000..3ad77d931 --- /dev/null +++ b/tests/e2e/framework/db.py @@ -0,0 +1,49 @@ +"""SQLite database inspection utilities.""" +import sqlite3 +from pathlib import Path +from typing import Optional, List, Dict, Any + + +class SqliteDb: + """Wrapper for inspecting SQLite databases.""" + + def __init__(self, db_path: Path): + self.db_path = db_path + self.conn: Optional[sqlite3.Connection] = None + + def __enter__(self): + self.conn = sqlite3.connect(str(self.db_path)) + self.conn.row_factory = sqlite3.Row + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.conn: + self.conn.close() + + def count_rows(self, table: str) -> int: + """Count rows in a table.""" + cursor = self.conn.execute(f"SELECT COUNT(*) FROM {table}") + return cursor.fetchone()[0] + + def get_row(self, table: str, id_value: str) -> Optional[Dict[str, Any]]: + """Get a single row by ID.""" + cursor = self.conn.execute(f"SELECT * FROM {table} WHERE Id = ?", (id_value,)) + row = cursor.fetchone() + return dict(row) if row else None + + def has_table(self, table_name: str) -> bool: + """Check if table exists.""" + cursor = self.conn.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name=?", + (table_name,) + ) + return cursor.fetchone() is not None + + def get_columns(self, table_name: str) -> List[str]: + """Get list of column names in a table.""" + cursor = self.conn.execute(f"PRAGMA table_info({table_name})") + return [row[1] for row in cursor.fetchall()] + + def has_column(self, table: str, column: str) -> bool: + """Check if table has a specific column.""" + return column in self.get_columns(table) diff --git a/tests/e2e/framework/logging.py b/tests/e2e/framework/logging.py new file mode 100644 index 000000000..632f195b2 --- /dev/null +++ b/tests/e2e/framework/logging.py @@ -0,0 +1,38 @@ +"""Helpers for per-test C# log files.""" +from __future__ import 
annotations + +from pathlib import Path + + +def prepare_log_path(log_path: Path) -> Path: + """ + Prepare a dedicated log file path for a test (in the same folder as the test). + + - Ensures the directory exists. + - Removes any previous log files for this test (including rolled files). + - Returns the full path to the log file to place in the C# config/CLI options. + """ + log_path = log_path.resolve() + log_path.parent.mkdir(parents=True, exist_ok=True) + + pattern = f"{log_path.stem}*.log" + for file in log_path.parent.glob(pattern): + file.unlink(missing_ok=True) + + return log_path + + +def assert_log_has_entries(log_path: Path, markers: list[str] | None = None) -> None: + """ + Verify the C# log file exists, is non-empty, and contains expected markers. + + markers are short substrings expected to come from the C# logging output, + e.g., "km CLI starting" or "Command=put". + """ + assert log_path.exists(), f"Expected log file at {log_path}" + assert log_path.stat().st_size > 0, f"Log file {log_path} should not be empty" + + if markers: + content = log_path.read_text(encoding="utf-8", errors="ignore") + for marker in markers: + assert marker in content, f"Expected log marker '{marker}' in {log_path}" diff --git a/tests/e2e/requirements.txt b/tests/e2e/requirements.txt new file mode 100644 index 000000000..5b1f97bd7 --- /dev/null +++ b/tests/e2e/requirements.txt @@ -0,0 +1 @@ +pytest>=7.4.0 diff --git a/tests/e2e/test_01_put_get_delete.py b/tests/e2e/test_01_put_get_delete.py new file mode 100755 index 000000000..0f13f22b6 --- /dev/null +++ b/tests/e2e/test_01_put_get_delete.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python3 +""" +E2E Test: Basic CRUD workflow (put → get → delete) + +Tests the most fundamental user workflow. +Verifies database state at each step. 
+""" +import subprocess +import json +import sqlite3 +import tempfile +import shutil +from pathlib import Path +from framework.cli import locate_km_binary +from framework.logging import assert_log_has_entries, prepare_log_path + + +def run_km(*args, config_path, log_path): + """Execute km command and return result.""" + km_binary = locate_km_binary() + cmd = ["dotnet", str(km_binary)] + list(args) + [ + "--config", + config_path, + "--log-file", + str(log_path), + "--verbosity", + "verbose", + ] + result = subprocess.run(cmd, capture_output=True, text=True, timeout=30) + return result + + +def main(): + print("="*70) + print("TEST: Basic CRUD workflow (put → get → delete)") + print("="*70) + + tmp_dir = Path(tempfile.mkdtemp(prefix="km-e2e-test01-")) + log_path = prepare_log_path(Path(__file__).with_suffix(".log")) + + try: + # Setup: Create config + print("\n[SETUP] Creating test config...") + config = { + "nodes": { + "test": { + "id": "test", + "access": "Full", + "contentIndex": {"type": "sqlite", "path": str(tmp_dir / "content.db")}, + "searchIndexes": [ + {"type": "sqliteFTS", "id": "fts", "path": str(tmp_dir / "fts.db"), "required": True} + ], + } + } + } + config_path = str(tmp_dir / "config.json") + with open(config_path, 'w') as f: + json.dump(config, f) + print(f" Config: {config_path}") + + # Step 1: Put content + print("\n[STEP 1] Running: km put 'Hello world' --id test-1") + result = run_km("put", "Hello world", "--id", "test-1", "--format", "json", config_path=config_path, log_path=log_path) + + print(f" Exit code: {result.returncode}") + assert result.returncode == 0, f"Put failed: {result.stderr}" + + data = json.loads(result.stdout) + print(f" Response: id={data['id']}, completed={data['completed']}") + assert data["id"] == "test-1", "ID should be test-1" + assert data["completed"] == True, "Should complete immediately" + print(" ✓ PASS: Content created") + + # Step 2: Verify database + print("\n[STEP 2] Checking database state...") + db_path = 
tmp_dir / "content.db" + print(f" Database: {db_path}") + assert db_path.exists(), "Database file should exist" + print(" ✓ Database file exists") + + conn = sqlite3.connect(str(db_path)) + cursor = conn.execute("SELECT COUNT(*) FROM km_content") + count = cursor.fetchone()[0] + print(f" Row count in km_content: {count}") + assert count == 1, f"Expected 1 row, got {count}" + print(" ✓ PASS: Database has 1 row") + + cursor = conn.execute("SELECT Id, Content FROM km_content WHERE Id = ?", ("test-1",)) + row = cursor.fetchone() + assert row, "Content row should exist" + print(f" Content: '{row[1]}'") + assert "Hello world" in row[1], "Content should contain 'Hello world'" + print(" ✓ PASS: Content matches") + conn.close() + + # Step 3: Get content + print("\n[STEP 3] Running: km get test-1") + result = run_km("get", "test-1", "--format", "json", config_path=config_path, log_path=log_path) + + print(f" Exit code: {result.returncode}") + assert result.returncode == 0, f"Get failed: {result.stderr}" + + data = json.loads(result.stdout) + print(f" Retrieved content: '{data['content'][:50]}'") + assert "Hello world" in data["content"], "Should retrieve correct content" + print(" ✓ PASS: Get succeeded") + + # Step 4: Delete content + print("\n[STEP 4] Running: km delete test-1") + result = run_km("delete", "test-1", "--format", "json", config_path=config_path, log_path=log_path) + + print(f" Exit code: {result.returncode}") + assert result.returncode == 0, f"Delete failed: {result.stderr}" + print(" ✓ PASS: Delete succeeded") + + # Step 5: Verify deletion + print("\n[STEP 5] Verifying content deleted from database...") + conn = sqlite3.connect(str(db_path)) + cursor = conn.execute("SELECT COUNT(*) FROM km_content") + count = cursor.fetchone()[0] + print(f" Row count after delete: {count}") + assert count == 0, "Content should be deleted" + print(" ✓ PASS: Content deleted from database") + conn.close() + + print("\n[VERIFY] Checking C# log file...") + 
assert_log_has_entries(log_path, markers=["km CLI starting", "Command=put", "Command=get", "Command=delete"]) + print(f" ✓ PASS: C# log captured at {log_path}") + + print("\n" + "="*70) + print("✅ TEST PASSED: All steps completed successfully") + print("="*70) + return 0 + + except AssertionError as e: + print(f"\n❌ TEST FAILED: {e}") + return 1 + except Exception as e: + print(f"\n❌ TEST ERROR: {e}") + import traceback + traceback.print_exc() + return 1 + finally: + shutil.rmtree(tmp_dir) + + +if __name__ == "__main__": + exit(main()) diff --git a/tests/e2e/test_02_search_with_broken_node.py b/tests/e2e/test_02_search_with_broken_node.py new file mode 100755 index 000000000..1e0d88c05 --- /dev/null +++ b/tests/e2e/test_02_search_with_broken_node.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python3 +""" +E2E Test: Search with broken node + +Regression test for bug found 2025-12-04: +km search was crashing when one node had missing database. +Should skip broken node and search working nodes. +""" +import subprocess +import json +import tempfile +import shutil +from pathlib import Path +from framework.cli import locate_km_binary +from framework.logging import assert_log_has_entries, prepare_log_path + + +def run_km(*args, config_path, log_path): + """Execute km command and return result.""" + km_binary = locate_km_binary() + cmd = ["dotnet", str(km_binary)] + list(args) + [ + "--config", + config_path, + "--log-file", + str(log_path), + "--verbosity", + "verbose", + "--format", + "json", + ] + return subprocess.run(cmd, capture_output=True, text=True, timeout=30) + + +def main(): + print("="*70) + print("TEST: Search with broken node (regression test)") + print("="*70) + + tmp_dir = Path(tempfile.mkdtemp(prefix="km-e2e-test02-")) + log_path = prepare_log_path(Path(__file__).with_suffix(".log")) + + try: + # Setup: Config with 2 nodes + print("\n[SETUP] Creating config with 2 nodes...") + config = { + "nodes": { + "working": { + "id": "working", + "access": "Full", + 
"contentIndex": {"type": "sqlite", "path": str(tmp_dir / "working/content.db")}, + "searchIndexes": [ + {"type": "sqliteFTS", "id": "fts1", "path": str(tmp_dir / "working/fts.db"), "required": True} + ] + }, + "broken": { + "id": "broken", + "access": "Full", + "contentIndex": {"type": "sqlite", "path": str(tmp_dir / "broken/content.db")}, + "searchIndexes": [ + {"type": "sqliteFTS", "id": "fts2", "path": str(tmp_dir / "broken/fts.db"), "required": True} + ] + } + } + } + config_path = str(tmp_dir / "config.json") + with open(config_path, 'w') as f: + json.dump(config, f) + print(f" Nodes: working, broken") + + # Step 1: Put content only to working node + print("\n[STEP 1] Adding content to 'working' node only...") + result = run_km("put", "searchable content", "--node", "working", config_path=config_path, log_path=log_path) + assert result.returncode == 0, f"Put failed: {result.stderr}" + print(" ✓ Content added to working node") + print(" Note: 'broken' node has no database") + + # Step 2: Search across all nodes + print("\n[STEP 2] Running: km search 'searchable' (searches all nodes)") + result = run_km("search", "searchable", config_path=config_path, log_path=log_path) + + print(f" Exit code: {result.returncode}") + print(f" Stderr: {result.stderr[:200] if result.stderr else '(empty)'}...") + + # Check: Should NOT crash + assert result.returncode == 0, \ + f"Search should succeed even with broken node. Exit code: {result.returncode}, stderr: {result.stderr}" + print(" ✓ PASS: Search did not crash") + + # Check: Should find content from working node + assert "searchable content" in result.stdout, \ + f"Should find content from working node. Output: {result.stdout}" + print(" ✓ PASS: Found content from working node") + + # Check: Should log warning about broken node + has_warning = "broken" in result.stderr.lower() or "skipping" in result.stderr.lower() + assert has_warning, f"Should warn about skipping broken node. 
stderr: {result.stderr}" + print(" ✓ PASS: Warning logged for broken node") + + print("\n[VERIFY] Checking C# log file...") + assert_log_has_entries(log_path, markers=["km CLI starting", "Command=put", "Command=search"]) + print(f" ✓ PASS: C# log captured at {log_path}") + + print("\n" + "="*70) + print("✅ TEST PASSED: Search handles broken nodes gracefully") + print("="*70) + return 0 + + except AssertionError as e: + print(f"\n❌ TEST FAILED: {e}") + return 1 + except Exception as e: + print(f"\n❌ TEST ERROR: {e}") + import traceback + traceback.print_exc() + return 1 + finally: + shutil.rmtree(tmp_dir) + + +if __name__ == "__main__": + exit(main()) diff --git a/tests/e2e/test_03_fts_stemming.py b/tests/e2e/test_03_fts_stemming.py new file mode 100755 index 000000000..3be8c8136 --- /dev/null +++ b/tests/e2e/test_03_fts_stemming.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python3 +""" +E2E Test: FTS Stemming + +Verifies that FTS stemming is working correctly. +With stemming enabled, searching for "running" should find "run". 
+""" +import subprocess +import json +import tempfile +import shutil +from pathlib import Path +from framework.cli import locate_km_binary +from framework.logging import assert_log_has_entries, prepare_log_path + + +def run_km(*args, config_path, log_path): + """Execute km command and return result.""" + km_binary = locate_km_binary() + cmd = ["dotnet", str(km_binary)] + list(args) + [ + "--config", + config_path, + "--log-file", + str(log_path), + "--verbosity", + "verbose", + "--format", + "json", + ] + return subprocess.run(cmd, capture_output=True, text=True, timeout=30) + + +def main(): + print("="*70) + print("TEST: FTS Stemming") + print("="*70) + + tmp_dir = Path(tempfile.mkdtemp(prefix="km-e2e-test03-")) + log_path = prepare_log_path(Path(__file__).with_suffix(".log")) + + try: + # Setup: Create config with stemming enabled + print("\n[SETUP] Creating config with FTS stemming enabled...") + config = { + "nodes": { + "test": { + "id": "test", + "access": "Full", + "contentIndex": {"type": "sqlite", "path": str(tmp_dir / "content.db")}, + "searchIndexes": [ + { + "type": "sqliteFTS", + "id": "fts-stemmed", + "path": str(tmp_dir / "fts.db"), + "enableStemming": True, + "required": True + } + ] + } + } + } + config_path = str(tmp_dir / "config.json") + with open(config_path, 'w') as f: + json.dump(config, f) + print(" FTS stemming: enabled") + + # Step 1: Put content with base word "test" + print("\n[STEP 1] Adding content with word 'test'...") + result = run_km("put", "We test the software thoroughly", "--id", "doc1", "--format", "json", config_path=config_path, log_path=log_path) + assert result.returncode == 0, f"Put failed: {result.stderr}" + print(" ✓ Content added: 'We test the software thoroughly'") + + # Step 2: Put content with variant word "testing" + print("\n[STEP 2] Adding content with word 'testing'...") + result = run_km("put", "Testing is important for quality", "--id", "doc2", "--format", "json", config_path=config_path, log_path=log_path) + 
assert result.returncode == 0, f"Put failed: {result.stderr}" + print(" ✓ Content added: 'Testing is important for quality'") + + # Step 3: Put content with variant word "tests" + print("\n[STEP 3] Adding content with word 'tests'...") + result = run_km("put", "All tests passed successfully", "--id", "doc3", "--format", "json", config_path=config_path, log_path=log_path) + assert result.returncode == 0, f"Put failed: {result.stderr}" + print(" ✓ Content added: 'All tests passed successfully'") + + # Step 4: Search for "testing" - should find all 3 due to stemming + print("\n[STEP 4] Searching for 'testing' (should find all variants due to stemming)...") + result = run_km("search", "testing", "--format", "json", config_path=config_path, log_path=log_path) + assert result.returncode == 0, f"Search failed: {result.stderr}" + + data = json.loads(result.stdout) + total_results = data.get("totalResults", 0) + print(f" Total results found: {total_results}") + + # With stemming: "testing" should match "test", "testing", "tests" + assert total_results == 3, f"Stemming should find all 3 variants. 
Found: {total_results}" + print(" ✓ PASS: Stemming found all 3 variants (test, testing, tests)") + + # Verify all 3 documents are in results + result_ids = {r["id"] for r in data["results"]} + print(f" Result IDs: {result_ids}") + assert "doc1" in result_ids, "Should find 'test'" + assert "doc2" in result_ids, "Should find 'testing'" + assert "doc3" in result_ids, "Should find 'tests'" + print(" ✓ PASS: All expected documents found") + + # Step 5: Search for variant not in documents - verify stemming finds the stem + print("\n[STEP 5] Searching for 'tested' (not in any document, stems to 'test')...") + result = run_km("search", "tested", "--format", "json", config_path=config_path, log_path=log_path) + assert result.returncode == 0 + + data = json.loads(result.stdout) + total_results = data.get("totalResults", 0) + print(f" Total results: {total_results}") + + # "tested" stems to "test", should find all 3 documents + assert total_results == 3, f"Stemming 'tested' should find all 'test' variants. 
Found: {total_results}" + print(" ✓ PASS: Stemming works for 'tested' → finds all 'test' variants") + + print("\n[VERIFY] Checking C# log file...") + assert_log_has_entries(log_path, markers=["km CLI starting", "Command=put", "Command=search"]) + print(f" ✓ PASS: C# log captured at {log_path}") + + print("\n" + "="*70) + print("✅ TEST PASSED: FTS stemming works correctly") + print("="*70) + return 0 + + except AssertionError as e: + print(f"\n❌ TEST FAILED: {e}") + return 1 + except Exception as e: + print(f"\n❌ TEST ERROR: {e}") + import traceback + traceback.print_exc() + return 1 + finally: + shutil.rmtree(tmp_dir) + + +if __name__ == "__main__": + exit(main()) diff --git a/tests/e2e/test_04_vector_search.py b/tests/e2e/test_04_vector_search.py new file mode 100755 index 000000000..5f1d1f141 --- /dev/null +++ b/tests/e2e/test_04_vector_search.py @@ -0,0 +1,239 @@ +#!/usr/bin/env python3 +""" +E2E Test: Vector Search + +Verifies that vector search indexing works end-to-end. +Tests that embeddings are generated and stored in vector database. +FAILS if Ollama is not available (this is intentional - integration test). 
+""" +import subprocess +import json +import sqlite3 +import tempfile +import shutil +from pathlib import Path +from framework.cli import locate_km_binary +from framework.logging import assert_log_has_entries, prepare_log_path + + +def run_km(*args, config_path, log_path): + """Execute km command and return result.""" + km_binary = locate_km_binary() + cmd = ["dotnet", str(km_binary)] + list(args) + [ + "--config", + config_path, + "--log-file", + str(log_path), + "--verbosity", + "verbose", + "--format", + "json", + ] + return subprocess.run(cmd, capture_output=True, text=True, timeout=30) + + +def main(): + print("="*70) + print("TEST: Vector Search with Embeddings") + print("="*70) + + tmp_dir = Path(tempfile.mkdtemp(prefix="km-e2e-test04-")) + log_path = prepare_log_path(Path(__file__).with_suffix(".log")) + + try: + # Setup: Create config with vector search + print("\n[SETUP] Creating config with vector search index...") + config = { + "nodes": { + "test": { + "id": "test", + "access": "Full", + "contentIndex": {"type": "sqlite", "path": str(tmp_dir / "content.db")}, + "searchIndexes": [ + { + "type": "sqliteFTS", + "id": "fts", + "path": str(tmp_dir / "fts.db"), + "required": True + }, + { + "type": "sqliteVector", + "id": "vector", + "path": str(tmp_dir / "vector.db"), + "dimensions": 1024, + "useSqliteVec": False, + "embeddings": { + "type": "ollama", + "model": "qwen3-embedding:0.6b", + "baseUrl": "http://localhost:11434" + } + } + ] + } + }, + "embeddingsCache": { + "type": "Sqlite", + "path": str(tmp_dir / "cache.db"), + "allowRead": True, + "allowWrite": True + } + } + config_path = str(tmp_dir / "config.json") + with open(config_path, 'w') as f: + json.dump(config, f) + print(" Vector index: configured with Ollama qwen3-embedding:0.6b") + print(" Embeddings cache: enabled") + + # Step 1: Verify Ollama is available + print("\n[STEP 1] Checking if Ollama is available...") + import urllib.request + try: + 
urllib.request.urlopen("http://localhost:11434/api/tags", timeout=2) + print(" ✓ Ollama is reachable") + except Exception as e: + print(f" ❌ TEST SKIPPED: Ollama not available ({e})") + print(" This test requires Ollama running with qwen3-embedding:0.6b model") + config_result = run_km("config", "--format", "json", config_path=config_path, log_path=log_path) + assert config_result.returncode == 0, f"Config command failed while skipping: {config_result.stderr}" + assert_log_has_entries(log_path, markers=["km CLI starting", "Command=config"]) + return 0 + + # Step 2: Put content (should generate embedding and store in vector index) + print("\n[STEP 2] Running: km put 'machine learning concepts'...") + result = run_km("put", "machine learning concepts", "--id", "ml-doc", "--format", "json", config_path=config_path, log_path=log_path) + + print(f" Exit code: {result.returncode}") + if result.returncode != 0: + print(f" Stderr: {result.stderr}") + + data = json.loads(result.stdout) + print(f" Response: id={data['id']}, completed={data['completed']}, queued={data.get('queued', False)}") + + # Must complete successfully (not just queued) + assert data["completed"] == True, f"Operation should complete. 
If queued, check Ollama: {result.stderr}" + print(" ✓ PASS: Content indexed (embeddings generated)") + + # Step 3: Verify vector database was created and has data + print("\n[STEP 3] Inspecting vector database...") + vector_db_path = tmp_dir / "vector.db" + assert vector_db_path.exists(), "Vector database should be created" + print(f" Database: {vector_db_path}") + + conn = sqlite3.connect(str(vector_db_path)) + + # Check table exists + cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='km_vectors'") + table = cursor.fetchone() + assert table is not None, "km_vectors table should exist" + print(" ✓ km_vectors table exists") + + # Check row count + cursor = conn.execute("SELECT COUNT(*) FROM km_vectors") + count = cursor.fetchone()[0] + print(f" Row count: {count}") + assert count == 1, f"Should have 1 vector, got {count}" + print(" ✓ PASS: Vector database has 1 row") + + # Check actual vector data + cursor = conn.execute("SELECT content_id, LENGTH(vector), created_at FROM km_vectors WHERE content_id = ?", ("ml-doc",)) + row = cursor.fetchone() + assert row is not None, "Vector for ml-doc should exist" + + content_id, vector_size, created_at = row + print(f" Content ID: {content_id}") + print(f" Vector size: {vector_size} bytes") + print(f" Created at: {created_at}") + + # Vector should be 1024 dimensions * 4 bytes (float32) = 4096 bytes + expected_size = 1024 * 4 + assert vector_size == expected_size, f"Vector should be {expected_size} bytes (1024 dims), got {vector_size}" + print(f" ✓ PASS: Vector size correct (1024 dimensions * 4 bytes = {expected_size} bytes)") + + conn.close() + + # Step 4: Verify embeddings cache was populated + print("\n[STEP 4] Inspecting embeddings cache...") + cache_db_path = tmp_dir / "cache.db" + assert cache_db_path.exists(), "Cache database should be created" + print(f" Cache: {cache_db_path}") + + conn = sqlite3.connect(str(cache_db_path)) + + # Check cache has entry + cursor = conn.execute("SELECT 
COUNT(*) FROM embeddings_cache") + count = cursor.fetchone()[0] + print(f" Cached embeddings: {count}") + assert count == 1, f"Cache should have 1 entry, got {count}" + print(" ✓ PASS: Embedding was cached") + + # Check cache entry details + cursor = conn.execute("SELECT provider, model, dimensions, LENGTH(vector), timestamp FROM embeddings_cache LIMIT 1") + row = cursor.fetchone() + provider, model, dims, vec_size, timestamp = row + print(f" Provider: {provider}") + print(f" Model: {model}") + print(f" Dimensions: {dims}") + print(f" Vector size: {vec_size} bytes") + print(f" Timestamp: {timestamp}") + + assert provider == "Ollama", f"Provider should be Ollama, got {provider}" + assert model == "qwen3-embedding:0.6b", f"Model should be qwen3-embedding:0.6b, got {model}" + assert dims == 1024, f"Dimensions should be 1024, got {dims}" + print(" ✓ PASS: Cache entry has correct metadata") + + conn.close() + + # Step 5: Put second document - should use cache + print("\n[STEP 5] Adding second document with same text (should use cache)...") + result = run_km("put", "machine learning concepts", "--id", "ml-doc-2", "--format", "json", config_path=config_path, log_path=log_path) + assert result.returncode == 0 + data = json.loads(result.stdout) + assert data["completed"] == True + print(" ✓ Second document indexed") + + # Check cache still has only 1 entry (same text = cache hit) + conn = sqlite3.connect(str(cache_db_path)) + cursor = conn.execute("SELECT COUNT(*) FROM embeddings_cache") + count = cursor.fetchone()[0] + print(f" Cache entries: {count}") + assert count == 1, "Cache should still have 1 entry (cache hit for same text)" + print(" ✓ PASS: Cache was reused (no new entry)") + conn.close() + + # Check vector DB has 2 entries + conn = sqlite3.connect(str(vector_db_path)) + cursor = conn.execute("SELECT COUNT(*) FROM km_vectors") + count = cursor.fetchone()[0] + print(f" Vector DB entries: {count}") + assert count == 2, f"Vector DB should have 2 entries, got {count}" 
+ print(" ✓ PASS: Both documents have vectors") + conn.close() + + print("\n[VERIFY] Checking C# log file...") + assert_log_has_entries(log_path, markers=["km CLI starting", "Command=put"]) + print(f" ✓ PASS: C# log captured at {log_path}") + + print("\n" + "="*70) + print("✅ TEST PASSED: Vector search and caching work correctly") + print("="*70) + print("\nVerified:") + print(" ✓ Embeddings generated via Ollama") + print(" ✓ Vectors stored in vector database (1024 dimensions)") + print(" ✓ Embeddings cached (same text reused cache)") + print(" ✓ Multiple documents share cached embedding") + return 0 + + except AssertionError as e: + print(f"\n❌ TEST FAILED: {e}") + return 1 + except Exception as e: + print(f"\n❌ TEST ERROR: {e}") + import traceback + traceback.print_exc() + return 1 + finally: + shutil.rmtree(tmp_dir) + + +if __name__ == "__main__": + exit(main()) diff --git a/tests/e2e/test_05_embeddings_cache.py b/tests/e2e/test_05_embeddings_cache.py new file mode 100755 index 000000000..c9b4ed132 --- /dev/null +++ b/tests/e2e/test_05_embeddings_cache.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 +""" +E2E Test: Embeddings Cache + +Verifies that embeddings are cached and reused. +Tests cache hit/miss scenarios and verifies timestamp and token_count are stored. 
+""" +import subprocess +import json +import sqlite3 +import tempfile +import shutil +from pathlib import Path +from framework.cli import locate_km_binary +from framework.logging import assert_log_has_entries, prepare_log_path + + +def run_km(*args, config_path, log_path): + """Execute km command and return result.""" + km_binary = locate_km_binary() + cmd = ["dotnet", str(km_binary)] + list(args) + [ + "--config", + config_path, + "--log-file", + str(log_path), + "--verbosity", + "verbose", + "--format", + "json", + ] + return subprocess.run(cmd, capture_output=True, text=True, timeout=30) + + +def main(): + print("="*70) + print("TEST: Embeddings Cache") + print("="*70) + + tmp_dir = Path(tempfile.mkdtemp(prefix="km-e2e-test05-")) + log_path = prepare_log_path(Path(__file__).with_suffix(".log")) + + try: + # Setup + print("\n[SETUP] Creating config with vector search and cache...") + config = { + "nodes": { + "test": { + "id": "test", + "access": "Full", + "contentIndex": {"type": "sqlite", "path": str(tmp_dir / "content.db")}, + "searchIndexes": [ + { + "type": "sqliteVector", + "id": "vector", + "path": str(tmp_dir / "vector.db"), + "dimensions": 1024, + "embeddings": { + "type": "ollama", + "model": "qwen3-embedding:0.6b" + } + } + ] + } + }, + "embeddingsCache": { + "type": "Sqlite", + "path": str(tmp_dir / "cache.db"), + "allowRead": True, + "allowWrite": True + } + } + config_path = str(tmp_dir / "config.json") + with open(config_path, 'w') as f: + json.dump(config, f) + print(" Cache enabled: ReadWrite mode") + + # Check Ollama + print("\n[PREREQ] Checking Ollama availability...") + import urllib.request + try: + urllib.request.urlopen("http://localhost:11434/api/tags", timeout=2) + print(" ✓ Ollama reachable") + except Exception as e: + print(f" ❌ TEST SKIPPED: Ollama not available") + config_result = run_km("config", "--format", "json", config_path=config_path, log_path=log_path) + assert config_result.returncode == 0, f"Config command failed while 
skipping: {config_result.stderr}" + assert_log_has_entries(log_path, markers=["km CLI starting", "Command=config"]) + return 0 + + # Step 1: Put first document (cache miss) + print("\n[STEP 1] Adding first document (cache miss)...") + result = run_km("put", "artificial intelligence", "--id", "doc1", "--format", "json", config_path=config_path, log_path=log_path) + assert result.returncode == 0 + data = json.loads(result.stdout) + assert data["completed"] == True, "Should complete" + print(" ✓ Document indexed") + + # Verify cache has 1 entry + cache_db = tmp_dir / "cache.db" + assert cache_db.exists(), "Cache database should exist" + + conn = sqlite3.connect(str(cache_db)) + cursor = conn.execute("SELECT COUNT(*) FROM embeddings_cache") + count = cursor.fetchone()[0] + print(f" Cache entries: {count}") + assert count == 1, "Cache should have 1 entry" + print(" ✓ PASS: Embedding cached (cache miss → cached)") + + # Check cache entry has timestamp and token_count + cursor = conn.execute("SELECT provider, model, dimensions, timestamp, token_count FROM embeddings_cache LIMIT 1") + row = cursor.fetchone() + provider, model, dims, timestamp, token_count = row + print(f" Cached: provider={provider}, model={model}, dims={dims}") + print(f" Timestamp: {timestamp}") + print(f" Token count: {token_count}") + + assert timestamp is not None, "Timestamp should be stored" + assert len(timestamp) > 0, "Timestamp should not be empty" + print(" ✓ PASS: Timestamp stored") + + # Token count may be None (Ollama doesn't provide it) + if token_count is None: + print(" ℹ Token count: None (Ollama doesn't provide token count)") + else: + print(f" ✓ Token count: {token_count}") + + conn.close() + + # Step 2: Put same text again (cache hit) + print("\n[STEP 2] Adding document with same text (cache hit)...") + result = run_km("put", "artificial intelligence", "--id", "doc2", "--format", "json", config_path=config_path, log_path=log_path) + assert result.returncode == 0 + data = 
json.loads(result.stdout) + assert data["completed"] == True + print(" ✓ Document indexed") + + # Cache should still have 1 entry (cache hit, no new embedding generated) + conn = sqlite3.connect(str(cache_db)) + cursor = conn.execute("SELECT COUNT(*) FROM embeddings_cache") + count = cursor.fetchone()[0] + print(f" Cache entries: {count}") + assert count == 1, "Cache should still have 1 entry (cache hit)" + print(" ✓ PASS: Cache reused (no new entry)") + conn.close() + + # Step 3: Put different text (cache miss) + print("\n[STEP 3] Adding document with different text (cache miss)...") + result = run_km("put", "deep learning neural networks", "--id", "doc3", "--format", "json", config_path=config_path, log_path=log_path) + assert result.returncode == 0 + data = json.loads(result.stdout) + assert data["completed"] == True + print(" ✓ Document indexed") + + # Cache should now have 2 entries + conn = sqlite3.connect(str(cache_db)) + cursor = conn.execute("SELECT COUNT(*) FROM embeddings_cache") + count = cursor.fetchone()[0] + print(f" Cache entries: {count}") + assert count == 2, f"Cache should have 2 entries, got {count}" + print(" ✓ PASS: New embedding cached") + conn.close() + + # Step 4: Verify vector database has 3 entries + print("\n[STEP 4] Verifying vector database...") + vector_db = tmp_dir / "vector.db" + conn = sqlite3.connect(str(vector_db)) + cursor = conn.execute("SELECT COUNT(*) FROM km_vectors") + count = cursor.fetchone()[0] + print(f" Vector DB entries: {count}") + assert count == 3, f"Should have 3 vectors, got {count}" + print(" ✓ PASS: All 3 documents have vectors") + + # Verify each vector is correct size (1024 dims * 4 bytes = 4096) + cursor = conn.execute("SELECT content_id, LENGTH(vector) FROM km_vectors ORDER BY content_id") + for row in cursor.fetchall(): + cid, size = row + print(f" {cid}: {size} bytes") + assert size == 4096, f"Vector should be 4096 bytes, got {size}" + + print(" ✓ PASS: All vectors are correct size (1024 dimensions)") + 
conn.close() + + print("\n[VERIFY] Checking C# log file...") + assert_log_has_entries(log_path, markers=["km CLI starting", "Command=put"]) + print(f" ✓ PASS: C# log captured at {log_path}") + + print("\n" + "="*70) + print("✅ TEST PASSED: Embeddings cache works correctly") + print("="*70) + print("\nVerified:") + print(" ✓ Cache miss → embedding generated and cached") + print(" ✓ Cache hit → embedding reused") + print(" ✓ Timestamp stored in cache") + print(" ✓ Multiple documents share cached embeddings when text matches") + return 0 + + except AssertionError as e: + print(f"\n❌ TEST FAILED: {e}") + return 1 + except Exception as e: + print(f"\n❌ TEST ERROR: {e}") + import traceback + traceback.print_exc() + return 1 + finally: + shutil.rmtree(tmp_dir) + + +if __name__ == "__main__": + exit(main()) From bfbd39b58c96cb400cdf4771e878aba91c9871ef Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Thu, 18 Dec 2025 14:20:02 +0100 Subject: [PATCH 2/5] ignore python cache --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 61a52de24..5c084db64 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ archived/ node_modules/ obj/ bin/ +__pycache__/ _dev/ .dev/ .vs/ @@ -68,4 +69,4 @@ publish/ *.crt *.key *.pem -certs/ \ No newline at end of file +certs/ From b5d126b4b8f3ff5b1372ebdd15132db37c010794 Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Thu, 18 Dec 2025 15:00:13 +0100 Subject: [PATCH 3/5] Skip SearchProcessTests when Ollama unavailable --- .github/workflows/coverage.yml | 1 + tests/Main.Tests/Integration/SearchProcessTests.cs | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index ad0a11c27..4c64b3753 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -25,6 +25,7 @@ jobs: chmod +x ./coverage.sh ./coverage.sh 80 env: + OLLAMA_AVAILABLE: "false" MIN_COVERAGE: 80 - name: Upload coverage report diff --git 
a/tests/Main.Tests/Integration/SearchProcessTests.cs b/tests/Main.Tests/Integration/SearchProcessTests.cs index e325212e3..052334c82 100644 --- a/tests/Main.Tests/Integration/SearchProcessTests.cs +++ b/tests/Main.Tests/Integration/SearchProcessTests.cs @@ -2,6 +2,7 @@ using System.Diagnostics; using System.Text.Json; +using Xunit.Sdk; namespace KernelMemory.Main.Tests.Integration; @@ -89,6 +90,11 @@ private async Task ExecuteKmAsync(string args) [Fact] public async Task Process_PutThenSearch_FindsContent() { + if (string.Equals(Environment.GetEnvironmentVariable("OLLAMA_AVAILABLE"), "false", StringComparison.OrdinalIgnoreCase)) + { + throw new SkipException("Skipping because OLLAMA_AVAILABLE=false (vector embeddings unavailable)."); + } + // Act: Insert content var putOutput = await this.ExecuteKmAsync($"put \"ciao mondo\" --config {this._configPath}").ConfigureAwait(false); var putResult = JsonSerializer.Deserialize(putOutput); From ef11bb3027ebafe5f765177e151aa97b6d02c1d2 Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Thu, 18 Dec 2025 15:07:27 +0100 Subject: [PATCH 4/5] Skip test requiring ollama for embeddings --- .../Integration/SearchProcessTests.cs | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/tests/Main.Tests/Integration/SearchProcessTests.cs b/tests/Main.Tests/Integration/SearchProcessTests.cs index 052334c82..ec9e328c8 100644 --- a/tests/Main.Tests/Integration/SearchProcessTests.cs +++ b/tests/Main.Tests/Integration/SearchProcessTests.cs @@ -2,7 +2,6 @@ using System.Diagnostics; using System.Text.Json; -using Xunit.Sdk; namespace KernelMemory.Main.Tests.Integration; @@ -87,14 +86,9 @@ private async Task ExecuteKmAsync(string args) return output.Trim(); } - [Fact] + [OllamaFact] public async Task Process_PutThenSearch_FindsContent() { - if (string.Equals(Environment.GetEnvironmentVariable("OLLAMA_AVAILABLE"), "false", StringComparison.OrdinalIgnoreCase)) - { - throw new SkipException("Skipping because 
OLLAMA_AVAILABLE=false (vector embeddings unavailable)."); - } - // Act: Insert content var putOutput = await this.ExecuteKmAsync($"put \"ciao mondo\" --config {this._configPath}").ConfigureAwait(false); var putResult = JsonSerializer.Deserialize(putOutput); @@ -215,4 +209,15 @@ await this.ExecuteKmAsync($"put \"docker helm charts\" --config {this._configPat Assert.Contains(id1, ids); Assert.Contains(id2, ids); } + + private sealed class OllamaFactAttribute : FactAttribute + { + public OllamaFactAttribute() + { + if (string.Equals(Environment.GetEnvironmentVariable("OLLAMA_AVAILABLE"), "false", StringComparison.OrdinalIgnoreCase)) + { + this.Skip = "Skipping because OLLAMA_AVAILABLE=false (vector embeddings unavailable)."; + } + } + } } From 0532c176599eefc4cbbe3c345d5e3ec1ac6e9e0f Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Thu, 18 Dec 2025 06:59:30 -0800 Subject: [PATCH 5/5] Update format.sh --- format.sh | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/format.sh b/format.sh index 011073326..e07cac8a3 100755 --- a/format.sh +++ b/format.sh @@ -1,14 +1 @@ -#!/usr/bin/env bash - -set -e - -ROOT="$(cd "$(dirname "${BASH_SOURCE[0]:-$0}")" && pwd)" -cd "$ROOT" -TMPDIR="$ROOT/.tmp" -mkdir -p "$TMPDIR" -export TMPDIR - -dotnet format src/Core/Core.csproj -dotnet format src/Main/Main.csproj -dotnet format tests/Core.Tests/Core.Tests.csproj -dotnet format tests/Main.Tests/Main.Tests.csproj +dotnet format