Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/Core/Config/AppConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ namespace KernelMemory.Core.Config;
/// Root configuration for Kernel Memory application
/// Loaded from ~/.km/config.json or custom path
/// </summary>
#pragma warning disable CA1724 // Conflicts with Microsoft.Identity.Client.AppConfig when Azure.Identity is referenced.
public sealed class AppConfig : IValidatable
#pragma warning restore CA1724
{
/// <summary>
/// Named memory nodes (e.g., "personal", "work")
Expand Down
5 changes: 5 additions & 0 deletions src/Core/Config/Embeddings/AzureOpenAIEmbeddingsConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -81,5 +81,10 @@ public override void Validate(string path)
throw new ConfigException(path,
"Azure OpenAI: specify either ApiKey or UseManagedIdentity, not both");
}

if (this.BatchSize < 1)
{
throw new ConfigException($"{path}.BatchSize", "BatchSize must be >= 1");
}
}
}
8 changes: 8 additions & 0 deletions src/Core/Config/Embeddings/EmbeddingsConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@ public abstract class EmbeddingsConfig : IValidatable
[JsonIgnore]
public abstract EmbeddingsTypes Type { get; }

/// <summary>
/// Maximum number of texts to send per embeddings API request.
/// Providers that support batch requests should chunk input using this size.
/// Default: 10.
/// </summary>
[JsonPropertyName("batchSize")]
public int BatchSize { get; set; } = Constants.EmbeddingDefaults.DefaultBatchSize;

/// <summary>
/// Validates the embeddings configuration
/// </summary>
Expand Down
18 changes: 17 additions & 1 deletion src/Core/Config/Embeddings/HuggingFaceEmbeddingsConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,25 @@ public override void Validate(string path)
throw new ConfigException($"{path}.Model", "HuggingFace model name is required");
}

// ApiKey can be provided via config or HF_TOKEN environment variable.
// Resolve env var into config so downstream code can rely on configuration only.
if (string.IsNullOrWhiteSpace(this.ApiKey))
{
throw new ConfigException($"{path}.ApiKey", "HuggingFace API key is required");
var token = Environment.GetEnvironmentVariable("HF_TOKEN");
if (!string.IsNullOrWhiteSpace(token))
{
this.ApiKey = token;
}
}

if (string.IsNullOrWhiteSpace(this.ApiKey))
{
throw new ConfigException($"{path}.ApiKey", "HuggingFace API key is required (set ApiKey or HF_TOKEN)");
}

if (this.BatchSize < 1)
{
throw new ConfigException($"{path}.BatchSize", "BatchSize must be >= 1");
}

if (string.IsNullOrWhiteSpace(this.BaseUrl))
Expand Down
5 changes: 5 additions & 0 deletions src/Core/Config/Embeddings/OllamaEmbeddingsConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ public override void Validate(string path)
throw new ConfigException($"{path}.BaseUrl", "Ollama base URL is required");
}

if (this.BatchSize < 1)
{
throw new ConfigException($"{path}.BatchSize", "BatchSize must be >= 1");
}

if (!Uri.TryCreate(this.BaseUrl, UriKind.Absolute, out _))
{
throw new ConfigException($"{path}.BaseUrl",
Expand Down
5 changes: 5 additions & 0 deletions src/Core/Config/Embeddings/OpenAIEmbeddingsConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ public override void Validate(string path)
throw new ConfigException($"{path}.ApiKey", "OpenAI API key is required");
}

if (this.BatchSize < 1)
{
throw new ConfigException($"{path}.BatchSize", "BatchSize must be >= 1");
}

if (!string.IsNullOrWhiteSpace(this.BaseUrl) &&
!Uri.TryCreate(this.BaseUrl, UriKind.Absolute, out _))
{
Expand Down
33 changes: 33 additions & 0 deletions src/Core/Constants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,39 @@ public static bool TryGetDimensions(string modelName, out int dimensions)
}
}

/// <summary>
/// Constants for HTTP retry/backoff used by external providers (embeddings, etc.).
/// Consumed by KernelMemory.Core.Http.HttpRetryPolicy; presumably the policy applies
/// exponential backoff between BaseDelayMs and MaxDelayMs — TODO confirm against the policy implementation.
/// </summary>
public static class HttpRetryDefaults
{
/// <summary>
/// Maximum attempts including the first try.
/// A value of 5 means 1 initial call plus up to 4 retries.
/// </summary>
public const int MaxAttempts = 5;

/// <summary>
/// Default timeout per attempt (seconds).
/// Applied by <see cref="KernelMemory.Core.Http.HttpRetryPolicy"/> to avoid hanging calls.
/// </summary>
public const int DefaultPerAttemptTimeoutSeconds = 60;

/// <summary>
/// Per-attempt timeout for local Ollama calls (seconds).
/// Keep this low so local development and tests fail fast when Ollama is not running.
/// </summary>
public const int OllamaPerAttemptTimeoutSeconds = 5;

/// <summary>
/// Base delay for exponential backoff, in milliseconds.
/// </summary>
public const int BaseDelayMs = 200;

/// <summary>
/// Maximum delay between attempts, in milliseconds.
/// Caps the exponential growth so retries never wait longer than this.
/// </summary>
public const int MaxDelayMs = 5000;
}

/// <summary>
/// Constants for the logging system including file rotation, log levels,
/// and output formatting.
Expand Down
1 change: 1 addition & 0 deletions src/Core/Core.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Azure.Identity" />
<PackageReference Include="cuid.net" />
<PackageReference Include="Microsoft.EntityFrameworkCore.Sqlite" />
<PackageReference Include="Parlot" />
Expand Down
97 changes: 87 additions & 10 deletions src/Core/Embeddings/Providers/AzureOpenAIEmbeddingGenerator.cs
Original file line number Diff line number Diff line change
@@ -1,23 +1,31 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Net.Http.Headers;
using System.Net.Http.Json;
using System.Text.Json.Serialization;
using Azure.Core;
using Azure.Identity;
using KernelMemory.Core.Config.Enums;
using KernelMemory.Core.Http;
using Microsoft.Extensions.Logging;

namespace KernelMemory.Core.Embeddings.Providers;

/// <summary>
/// Azure OpenAI embedding generator implementation.
/// Communicates with Azure OpenAI Service.
/// Supports API key authentication (managed identity would require Azure.Identity package).
/// Supports API key authentication or managed identity via <see cref="DefaultAzureCredential"/>.
/// </summary>
public sealed class AzureOpenAIEmbeddingGenerator : IEmbeddingGenerator
{
private readonly HttpClient _httpClient;
private readonly string _endpoint;
private readonly string _deployment;
private readonly string _apiKey;
private readonly string? _apiKey;
private readonly bool _useManagedIdentity;
private readonly TokenCredential? _credential;
private readonly int _batchSize;
private readonly ILogger<AzureOpenAIEmbeddingGenerator> _logger;
private readonly Func<TimeSpan, CancellationToken, Task> _delayAsync;

/// <inheritdoc />
public EmbeddingsTypes ProviderType => EmbeddingsTypes.AzureOpenAI;
Expand All @@ -38,35 +46,52 @@ public sealed class AzureOpenAIEmbeddingGenerator : IEmbeddingGenerator
/// <param name="endpoint">Azure OpenAI endpoint (e.g., https://myservice.openai.azure.com).</param>
/// <param name="deployment">Deployment name in Azure.</param>
/// <param name="model">Model name for identification.</param>
/// <param name="apiKey">Azure OpenAI API key.</param>
/// <param name="apiKey">Azure OpenAI API key (required unless <paramref name="useManagedIdentity"/> is true).</param>
/// <param name="vectorDimensions">Vector dimensions produced by the model.</param>
/// <param name="isNormalized">Whether vectors are normalized.</param>
/// <param name="logger">Logger instance.</param>
/// <param name="batchSize">Maximum number of texts per API request.</param>
/// <param name="useManagedIdentity">Whether to authenticate using managed identity.</param>
/// <param name="credential">Optional token credential (used for testing); defaults to <see cref="DefaultAzureCredential"/>.</param>
/// <param name="delayAsync">Optional delay function for retries (used for fast unit tests).</param>
public AzureOpenAIEmbeddingGenerator(
HttpClient httpClient,
string endpoint,
string deployment,
string model,
string apiKey,
string? apiKey,
int vectorDimensions,
bool isNormalized,
ILogger<AzureOpenAIEmbeddingGenerator> logger)
ILogger<AzureOpenAIEmbeddingGenerator> logger,
int batchSize,
bool useManagedIdentity,
TokenCredential? credential = null,
Func<TimeSpan, CancellationToken, Task>? delayAsync = null)
{
ArgumentNullException.ThrowIfNull(httpClient, nameof(httpClient));
ArgumentNullException.ThrowIfNull(endpoint, nameof(endpoint));
ArgumentNullException.ThrowIfNull(deployment, nameof(deployment));
ArgumentNullException.ThrowIfNull(model, nameof(model));
ArgumentNullException.ThrowIfNull(apiKey, nameof(apiKey));
ArgumentNullException.ThrowIfNull(logger, nameof(logger));
ArgumentOutOfRangeException.ThrowIfLessThan(batchSize, 1, nameof(batchSize));

this._httpClient = httpClient;
this._endpoint = endpoint.TrimEnd('/');
this._deployment = deployment;
this._apiKey = apiKey;
this._useManagedIdentity = useManagedIdentity;
this._credential = credential;
this._batchSize = batchSize;
this.ModelName = model;
this.VectorDimensions = vectorDimensions;
this.IsNormalized = isNormalized;
this._logger = logger;
this._delayAsync = delayAsync ?? Task.Delay;

if (!this._useManagedIdentity && string.IsNullOrWhiteSpace(this._apiKey))
{
throw new ArgumentException("Azure OpenAI API key is required when not using managed identity", nameof(apiKey));
}

this._logger.LogDebug("AzureOpenAIEmbeddingGenerator initialized: {Endpoint}, deployment: {Deployment}, model: {Model}",
this._endpoint, this._deployment, this.ModelName);
Expand All @@ -88,21 +113,53 @@ public async Task<EmbeddingResult[]> GenerateAsync(IEnumerable<string> texts, Ca
return [];
}

var allResults = new List<EmbeddingResult>(textArray.Length);
foreach (var chunk in Chunk(textArray, this._batchSize))
{
var chunkResults = await this.GenerateBatchAsync(chunk, ct).ConfigureAwait(false);
allResults.AddRange(chunkResults);
}

return allResults.ToArray();
}

private async Task<EmbeddingResult[]> GenerateBatchAsync(string[] textArray, CancellationToken ct)
{
var url = $"{this._endpoint}/openai/deployments/{this._deployment}/embeddings?api-version={Constants.EmbeddingDefaults.AzureOpenAIApiVersion}";

var request = new AzureEmbeddingRequest
{
Input = textArray
};

using var httpRequest = new HttpRequestMessage(HttpMethod.Post, url);
httpRequest.Headers.Add("api-key", this._apiKey);
httpRequest.Content = JsonContent.Create(request);
var bearerToken = this._useManagedIdentity
? await this.GetManagedIdentityTokenAsync(ct).ConfigureAwait(false)
: null;

this._logger.LogTrace("Calling Azure OpenAI embeddings API: deployment={Deployment}, batch size: {BatchSize}",
this._deployment, textArray.Length);

var response = await this._httpClient.SendAsync(httpRequest, ct).ConfigureAwait(false);
using var response = await HttpRetryPolicy.SendAsync(
this._httpClient,
requestFactory: () =>
{
var httpRequest = new HttpRequestMessage(HttpMethod.Post, url);
if (bearerToken != null)
{
httpRequest.Headers.Authorization = new AuthenticationHeaderValue("Bearer", bearerToken);
}
else
{
httpRequest.Headers.Add("api-key", this._apiKey);
}

httpRequest.Content = JsonContent.Create(request);
return httpRequest;
},
this._logger,
ct,
delayAsync: this._delayAsync).ConfigureAwait(false);

response.EnsureSuccessStatusCode();

var result = await response.Content.ReadFromJsonAsync<AzureEmbeddingResponse>(ct).ConfigureAwait(false);
Expand Down Expand Up @@ -141,6 +198,26 @@ public async Task<EmbeddingResult[]> GenerateAsync(IEnumerable<string> texts, Ca
return results;
}

/// <summary>
/// Acquires a bearer token for the Azure Cognitive Services scope
/// ("https://cognitiveservices.azure.com/.default") using managed identity.
/// Uses the injected <c>_credential</c> when one was supplied (e.g. by tests);
/// otherwise falls back to <see cref="DefaultAzureCredential"/>.
/// </summary>
/// <param name="ct">Cancellation token flowed into the token request.</param>
/// <returns>The raw bearer token string to place in the Authorization header.</returns>
private async Task<string> GetManagedIdentityTokenAsync(CancellationToken ct)
{
// NOTE(review): when _credential is null, a new DefaultAzureCredential is constructed
// on every call; Azure.Identity caches tokens per credential instance, so caching the
// credential in a field would likely avoid repeated auth handshakes — TODO confirm.
var credential = this._credential ?? new DefaultAzureCredential();
var token = await credential.GetTokenAsync(
new TokenRequestContext(["https://cognitiveservices.azure.com/.default"]),
ct).ConfigureAwait(false);
return token.Token;
}

/// <summary>
/// Splits <paramref name="items"/> into consecutive batches of at most
/// <paramref name="chunkSize"/> elements; the final batch may be shorter.
/// Lazily yields each batch as a newly allocated array.
/// </summary>
/// <param name="items">Source array to partition; an empty array yields nothing.</param>
/// <param name="chunkSize">Maximum elements per batch (caller guarantees &gt;= 1).</param>
private static IEnumerable<string[]> Chunk(string[] items, int chunkSize)
{
var offset = 0;
while (offset < items.Length)
{
// Last batch may be smaller than chunkSize.
var take = Math.Min(chunkSize, items.Length - offset);
var batch = new string[take];
Array.Copy(items, offset, batch, 0, take);
offset += take;
yield return batch;
}
}

/// <summary>
/// Request body for Azure OpenAI embeddings API.
/// </summary>
Expand Down
Loading
Loading