diff --git a/docs/HostedConversation-ProviderMapping.md b/docs/HostedConversation-ProviderMapping.md new file mode 100644 index 00000000000..a80776b8054 --- /dev/null +++ b/docs/HostedConversation-ProviderMapping.md @@ -0,0 +1,111 @@ +# IHostedConversationClient — Provider Mapping Report + +## Overview + +`IHostedConversationClient` is an abstraction for managing server-side (hosted) conversation state across AI providers. It provides a common interface for creating, retrieving, deleting, and managing messages within persistent conversations, decoupling application code from provider-specific conversation/thread/session APIs. Each provider maps these operations to its native primitives, with escape hatches (`RawRepresentation`, `RawRepresentationFactory`, `AdditionalProperties`) for accessing provider-specific features. + +## Interface Operations + +| Operation | Description | Return Type | +|-----------|-------------|-------------| +| `CreateAsync` | Creates a new hosted conversation | `HostedConversation` | +| `GetAsync` | Retrieves conversation by ID | `HostedConversation` | +| `DeleteAsync` | Deletes a conversation | `void` (Task) | +| `AddMessagesAsync` | Adds messages to a conversation | `void` (Task) | +| `GetMessagesAsync` | Lists messages in a conversation | `IAsyncEnumerable` | + +## Provider Mapping + +### OpenAI (Implemented) + +- Maps to `ConversationClient` in `OpenAI.Conversations` namespace +- Full CRUD support via protocol-level APIs +- `RawRepresentation` set to `ClientResult` objects +- ConversationId integrates with `ChatOptions.ConversationId` for inference via `OpenAIResponsesChatClient` +- Metadata limited to 16 key-value pairs (max 64 char keys, 512 char values) + +### Azure AI Foundry + +- **Azure Foundry v2** uses the OpenAI Responses API directly, so the OpenAI `IHostedConversationClient` implementation works for Azure Foundry v2 without a separate adapter +- The deprecated v1 Agent Service SDK mapped to Thread/Message APIs (`threads.create()`, `threads.get()`, `threads.delete()`, `messages.create()`, `messages.list()`), but this is no longer the recommended path +- **Gaps**: Agent-specific concepts (Runs, Agents) are not in our abstraction; use `AdditionalProperties` for agent-specific metadata + +### AWS Bedrock + +- Maps to Session Management APIs +- `CreateAsync` → `CreateSession` with optional encryption/metadata +- `GetAsync` → `GetSession` +- `DeleteAsync` → `DeleteSession` +- `AddMessagesAsync` → `PutInvocationStep` (different item model) +- `GetMessagesAsync` → `GetInvocationSteps` (requires translation) +- **Gaps**: Session status (ACTIVE/EXPIRED/ENDED) not in abstraction; encryption config is provider-specific; use `AdditionalProperties` or `RawRepresentationFactory` + +### Google Gemini + +- Maps to Interactions API +- `CreateAsync` → `interactions.create()` (creates an interaction, not a "conversation" per se) +- `GetAsync` → `interactions.get()` +- `DeleteAsync` → `interactions.delete()` +- `AddMessagesAsync` → No direct equivalent; use `interactions.create()` with `previous_interaction_id` chain +- `GetMessagesAsync` → `interactions.get().outputs` (retrieves outputs, not full message history) +- **Gaps**: Interactions are individual turns, not conversation containers. AddMessages requires creating new interactions chained via `previous_interaction_id`. Provider adapter would need to manage this mapping. + +### Anthropic + +- **No native conversation CRUD API** — requires local adapter +- Server-side features that CAN assist: + - **Prompt Caching** (`cache_control`): Stores KV cache of message prefixes (5min/1hr TTL). Adapter should auto-apply cache breakpoints. + - **Context Compaction** (beta): Server-side summarization when conversations exceed token threshold + - **Files API** (beta): Store documents server-side for reference across requests + - **Containers** (beta): Server-side execution state with reusable IDs +- Implementation approach: `LocalHostedConversationClient` using local storage (in-memory, SQLite, Redis) with automatic prompt caching optimization +- **Gaps**: All operations are simulated client-side. No server-side conversation persistence. + +### Ollama / Local Models + +- **No server-side state** at all +- Implementation: Same local adapter pattern as Anthropic but without prompt caching optimization +- **Gaps**: Same as Anthropic — entirely client-side simulation + +## Escape Hatches for Provider-Specific Features + +### RawRepresentation + +Every `HostedConversation` response carries `RawRepresentation` (the underlying provider object). This gives access to 100% of provider functionality: + +```csharp +var conversation = await client.CreateAsync(); +var openAIResult = (ClientResult)conversation.RawRepresentation; // Access any OpenAI-specific data +``` + +### RawRepresentationFactory + +`HostedConversationCreationOptions.RawRepresentationFactory` allows passing provider-specific creation options: + +```csharp +var options = new HostedConversationCreationOptions +{ + RawRepresentationFactory = client => new ConversationCreationOptions + { + // Any provider-specific settings + } +}; +``` + +### AdditionalProperties + +`HostedConversation.AdditionalProperties` and `HostedConversationCreationOptions.AdditionalProperties` carry provider-specific data that doesn't fit the common abstraction. + +## Feature Coverage Matrix + +| Feature | OpenAI | Azure | Bedrock | Gemini | Anthropic | Ollama | +|---------|--------|-------|---------|--------|-----------|--------| +| Create | ✅ Native | ✅ Native | ✅ Native | ✅ Native | ⚠️ Local | ⚠️ Local | +| Get | ✅ Native | ✅ Native | ✅ Native | ✅ Native | ⚠️ Local | ⚠️ Local | +| Delete | ✅ Native | ✅ Native | ✅ Native | ✅ Native | ⚠️ Local | ⚠️ Local | +| AddMessages | ✅ Native | ✅ Native | ⚠️ Translated | ⚠️ Chained | ⚠️ Local | ⚠️ Local | +| GetMessages | ✅ Native | ✅ Native | ⚠️ Translated | ⚠️ Partial | ⚠️ Local | ⚠️ Local | +| Metadata | ✅ 16 KV | ✅ | ✅ | ✅ | ⚠️ Local | ⚠️ Local | +| RawRepresentation | ✅ ClientResult | ✅ AgentThread | ✅ Session | ✅ Interaction | N/A | N/A | + +Legend: ✅ = Direct mapping, ⚠️ = Requires translation/local adapter diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Conversations/DelegatingHostedConversationClient.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Conversations/DelegatingHostedConversationClient.cs new file mode 100644 index 00000000000..58fc1b08384 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Conversations/DelegatingHostedConversationClient.cs @@ -0,0 +1,99 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.DiagnosticIds; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// Provides an optional base class for an that passes through calls to another instance. +/// +/// +/// This is recommended as a base type when building clients that can be chained around an underlying . +/// The default implementation simply passes each call to the inner client instance. +/// +[Experimental(DiagnosticIds.Experiments.AIHostedConversation, UrlFormat = DiagnosticIds.UrlFormat)] +public class DelegatingHostedConversationClient : IHostedConversationClient +{ + /// + /// Initializes a new instance of the class. + /// + /// The wrapped client instance. + /// is . + protected DelegatingHostedConversationClient(IHostedConversationClient innerClient) + { + InnerClient = Throw.IfNull(innerClient); + } + + /// + public void Dispose() + { + Dispose(disposing: true); + GC.SuppressFinalize(this); + } + + /// Gets the inner . + protected IHostedConversationClient InnerClient { get; } + + /// + public virtual Task CreateAsync( + HostedConversationClientOptions? options = null, + CancellationToken cancellationToken = default) => + InnerClient.CreateAsync(options, cancellationToken); + + /// + public virtual Task GetAsync( + string conversationId, + HostedConversationClientOptions? options = null, + CancellationToken cancellationToken = default) => + InnerClient.GetAsync(conversationId, options, cancellationToken); + + /// + public virtual Task DeleteAsync( + string conversationId, + HostedConversationClientOptions? options = null, + CancellationToken cancellationToken = default) => + InnerClient.DeleteAsync(conversationId, options, cancellationToken); + + /// + public virtual Task AddMessagesAsync( + string conversationId, + IEnumerable messages, + HostedConversationClientOptions? options = null, + CancellationToken cancellationToken = default) => + InnerClient.AddMessagesAsync(conversationId, messages, options, cancellationToken); + + /// + public virtual IAsyncEnumerable GetMessagesAsync( + string conversationId, + HostedConversationClientOptions? options = null, + CancellationToken cancellationToken = default) => + InnerClient.GetMessagesAsync(conversationId, options, cancellationToken); + + /// + public virtual object? GetService(Type serviceType, object? serviceKey = null) + { + _ = Throw.IfNull(serviceType); + + // If the key is non-null, we don't know what it means so pass through to the inner service. + return + serviceKey is null && serviceType.IsInstanceOfType(this) ? this : + InnerClient.GetService(serviceType, serviceKey); + } + + /// Provides a mechanism for releasing unmanaged resources. + /// if being called from ; otherwise, . + protected virtual void Dispose(bool disposing) + { + if (disposing) + { + InnerClient.Dispose(); + } + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Conversations/HostedConversation.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Conversations/HostedConversation.cs new file mode 100644 index 00000000000..6d5b3e4df7b --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Conversations/HostedConversation.cs @@ -0,0 +1,32 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Text.Json.Serialization; +using Microsoft.Shared.DiagnosticIds; + +namespace Microsoft.Extensions.AI; + +/// Represents a hosted conversation. +[Experimental(DiagnosticIds.Experiments.AIHostedConversation, UrlFormat = DiagnosticIds.UrlFormat)] +public class HostedConversation +{ + /// Gets or sets the conversation identifier. + public string? ConversationId { get; set; } + + /// Gets or sets the creation timestamp. + public DateTimeOffset? CreatedAt { get; set; } + + /// Gets or sets the raw representation of the conversation from the underlying provider. + /// + /// If a is created to represent some underlying object from another object + /// model, this property can be used to store that original object. This can be useful for debugging or + /// for enabling a consumer to access the underlying object model if needed. + /// + [JsonIgnore] + public object? RawRepresentation { get; set; } + + /// Gets or sets any additional properties associated with the conversation. + public AdditionalPropertiesDictionary? AdditionalProperties { get; set; } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Conversations/HostedConversationClientExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Conversations/HostedConversationClientExtensions.cs new file mode 100644 index 00000000000..0e783303d3a --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Conversations/HostedConversationClientExtensions.cs @@ -0,0 +1,82 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Diagnostics.CodeAnalysis; +using Microsoft.Shared.DiagnosticIds; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// Provides a collection of static methods for extending instances. +[Experimental(DiagnosticIds.Experiments.AIHostedConversation, UrlFormat = DiagnosticIds.UrlFormat)] +public static class HostedConversationClientExtensions +{ + /// Asks the for an object of type . + /// The type of the object to be retrieved. + /// The client. + /// An optional key that can be used to help identify the target service. + /// The found object, otherwise . + /// is . + /// + /// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the , + /// including itself or any services it might be wrapping. + /// + public static TService? GetService(this IHostedConversationClient client, object? serviceKey = null) + { + _ = Throw.IfNull(client); + + return client.GetService(typeof(TService), serviceKey) is TService service ? service : default; + } + + /// + /// Asks the for an object of the specified type + /// and throws an exception if one isn't available. + /// + /// The client. + /// The type of object being requested. + /// An optional key that can be used to help identify the target service. + /// The found object. + /// is . + /// is . + /// No service of the requested type for the specified key is available. + /// + /// The purpose of this method is to allow for the retrieval of services that are required to be provided by the , + /// including itself or any services it might be wrapping. + /// + public static object GetRequiredService(this IHostedConversationClient client, Type serviceType, object? serviceKey = null) + { + _ = Throw.IfNull(client); + _ = Throw.IfNull(serviceType); + + return + client.GetService(serviceType, serviceKey) ?? + throw Throw.CreateMissingServiceException(serviceType, serviceKey); + } + + /// + /// Asks the for an object of type + /// and throws an exception if one isn't available. + /// + /// The type of the object to be retrieved. + /// The client. + /// An optional key that can be used to help identify the target service. + /// The found object. + /// is . + /// No service of the requested type for the specified key is available. + /// + /// The purpose of this method is to allow for the retrieval of strongly typed services that are required to be provided by the , + /// including itself or any services it might be wrapping. + /// + public static TService GetRequiredService(this IHostedConversationClient client, object? serviceKey = null) + { + _ = Throw.IfNull(client); + + if (client.GetService(typeof(TService), serviceKey) is not TService service) + { + throw Throw.CreateMissingServiceException(typeof(TService), serviceKey); + } + + return service; + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Conversations/HostedConversationClientMetadata.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Conversations/HostedConversationClientMetadata.cs new file mode 100644 index 00000000000..5c746dd2c9c --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Conversations/HostedConversationClientMetadata.cs @@ -0,0 +1,35 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Diagnostics.CodeAnalysis; +using Microsoft.Shared.DiagnosticIds; + +namespace Microsoft.Extensions.AI; + +/// Provides metadata about an . +[Experimental(DiagnosticIds.Experiments.AIHostedConversation, UrlFormat = DiagnosticIds.UrlFormat)] +public class HostedConversationClientMetadata +{ + /// Initializes a new instance of the class. + /// + /// The name of the hosted conversation provider, if applicable. Where possible, this should map to the + /// appropriate name defined in the OpenTelemetry Semantic Conventions for Generative AI systems. + /// + /// The URL for accessing the hosted conversation provider, if applicable. + public HostedConversationClientMetadata(string? providerName = null, Uri? providerUri = null) + { + ProviderName = providerName; + ProviderUri = providerUri; + } + + /// Gets the name of the hosted conversation provider. + /// + /// Where possible, this maps to the appropriate name defined in the + /// OpenTelemetry Semantic Conventions for Generative AI systems. + /// + public string? ProviderName { get; } + + /// Gets the URL for accessing the hosted conversation provider. + public Uri? ProviderUri { get; } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Conversations/HostedConversationClientOptions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Conversations/HostedConversationClientOptions.cs new file mode 100644 index 00000000000..bbf8228d2d6 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Conversations/HostedConversationClientOptions.cs @@ -0,0 +1,71 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Text.Json.Serialization; +using Microsoft.Shared.DiagnosticIds; + +namespace Microsoft.Extensions.AI; + +/// Represents the options for a hosted conversation client request. +[Experimental(DiagnosticIds.Experiments.AIHostedConversation, UrlFormat = DiagnosticIds.UrlFormat)] +public class HostedConversationClientOptions +{ + /// + /// Initializes a new instance of the class. + /// + public HostedConversationClientOptions() + { + } + + /// + /// Initializes a new instance of the class + /// by cloning the properties of another instance. + /// + /// The instance to clone. + protected HostedConversationClientOptions(HostedConversationClientOptions? other) + { + if (other is null) + { + return; + } + + Limit = other.Limit; + RawRepresentationFactory = other.RawRepresentationFactory; + AdditionalProperties = other.AdditionalProperties?.Clone(); + } + + /// Gets or sets the maximum number of items to return in a list operation. + /// + /// If not specified, the provider's default limit will be used. + /// + public int? Limit { get; set; } + + /// + /// Gets or sets a callback responsible for creating the raw representation of the conversation client options from an underlying implementation. + /// + /// + /// The underlying implementation may have its own representation of options. + /// When an operation is invoked with a , that implementation may convert + /// the provided options into its own representation in order to use it while performing the operation. + /// For situations where a consumer knows which concrete is being used + /// and how it represents options, a new instance of that implementation-specific options type may be returned + /// by this callback, for the implementation to use instead of creating a new + /// instance. Such implementations may mutate the supplied options instance further based on other settings + /// supplied on this instance or from other inputs, + /// therefore, it is strongly recommended to not return shared instances and instead make the callback + /// return a new instance on each call. + /// This is typically used to set an implementation-specific setting that isn't otherwise exposed from the strongly typed + /// properties on . + /// + [JsonIgnore] + public Func? RawRepresentationFactory { get; set; } + + /// Gets or sets additional properties for the request. + public AdditionalPropertiesDictionary? AdditionalProperties { get; set; } + + /// Creates a shallow clone of the current instance. + /// A shallow clone of the current instance. + public virtual HostedConversationClientOptions Clone() => new(this); +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Conversations/IHostedConversationClient.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Conversations/IHostedConversationClient.cs new file mode 100644 index 00000000000..8ed56abe2ad --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Conversations/IHostedConversationClient.cs @@ -0,0 +1,95 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.DiagnosticIds; + +namespace Microsoft.Extensions.AI; + +/// Represents a client for managing hosted conversations. +/// +/// +/// Unless otherwise specified, all members of are thread-safe for concurrent use. +/// It is expected that all implementations of support being used by multiple requests concurrently. +/// +/// +/// However, implementations of might mutate the arguments supplied to +/// and , such as by configuring the options or messages instances. +/// Thus, consumers of the interface either should avoid using shared instances of these arguments for concurrent +/// invocations or should otherwise ensure by construction that no instances +/// are used which might employ such mutation. +/// +/// +[Experimental(DiagnosticIds.Experiments.AIHostedConversation, UrlFormat = DiagnosticIds.UrlFormat)] +public interface IHostedConversationClient : IDisposable +{ + /// Creates a new hosted conversation. + /// The options to configure the conversation creation. + /// The to monitor for cancellation requests. The default is . + /// The created . + Task CreateAsync( + HostedConversationClientOptions? options = null, + CancellationToken cancellationToken = default); + + /// Retrieves an existing hosted conversation by its identifier. + /// The unique identifier of the conversation to retrieve. + /// The options for the request. + /// The to monitor for cancellation requests. The default is . + /// The matching the specified . + /// is . + Task GetAsync( + string conversationId, + HostedConversationClientOptions? options = null, + CancellationToken cancellationToken = default); + + /// Deletes an existing hosted conversation. + /// The unique identifier of the conversation to delete. + /// The options for the request. + /// The to monitor for cancellation requests. The default is . + /// A representing the asynchronous operation. + /// is . + Task DeleteAsync( + string conversationId, + HostedConversationClientOptions? options = null, + CancellationToken cancellationToken = default); + + /// Adds messages to an existing hosted conversation. + /// The unique identifier of the conversation to add messages to. + /// The sequence of chat messages to add to the conversation. + /// The options for the request. + /// The to monitor for cancellation requests. The default is . + /// A representing the asynchronous operation. + /// is . + /// is . + Task AddMessagesAsync( + string conversationId, + IEnumerable messages, + HostedConversationClientOptions? options = null, + CancellationToken cancellationToken = default); + + /// Lists the messages in an existing hosted conversation. + /// The unique identifier of the conversation to list messages from. + /// The options for the request. + /// The to monitor for cancellation requests. The default is . + /// An asynchronous sequence of instances from the conversation. + /// is . + IAsyncEnumerable GetMessagesAsync( + string conversationId, + HostedConversationClientOptions? options = null, + CancellationToken cancellationToken = default); + + /// Asks the for an object of the specified type . + /// The type of object being requested. + /// An optional key that can be used to help identify the target service. + /// The found object, otherwise . + /// is . + /// + /// The purpose of this method is to allow for the retrieval of strongly typed services that might be provided by the , + /// including itself or any services it might be wrapping. + /// + object? GetService(Type serviceType, object? serviceKey = null); +} diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIClientExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIClientExtensions.cs index 9a5ebd0d06a..85220bf8cc8 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIClientExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIClientExtensions.cs @@ -12,11 +12,13 @@ using System.Text.Json.Nodes; using System.Text.Json.Serialization; using Microsoft.Shared.DiagnosticIds; +using Microsoft.Shared.Diagnostics; using OpenAI; using OpenAI.Assistants; using OpenAI.Audio; using OpenAI.Chat; using OpenAI.Containers; +using OpenAI.Conversations; using OpenAI.Embeddings; using OpenAI.Files; using OpenAI.Images; @@ -230,6 +232,14 @@ public static IHostedFileClient AsIHostedFileClient(this OpenAIFileClient fileCl public static IHostedFileClient AsIHostedFileClient(this ContainerClient containerClient, string? defaultScope = null) => new OpenAIHostedFileClient(containerClient, defaultScope); + /// Gets an for use with this . + /// The client. + /// An that can be used to manage hosted conversations via the . + /// is . + [Experimental(DiagnosticIds.Experiments.AIHostedConversation, UrlFormat = DiagnosticIds.UrlFormat)] + public static IHostedConversationClient AsIHostedConversationClient(this ConversationClient conversationClient) => + new OpenAIHostedConversationClient(Throw.IfNull(conversationClient)); + /// Gets whether the properties specify that strict schema handling is desired. internal static bool? HasStrict(IReadOnlyDictionary? additionalProperties) => additionalProperties?.TryGetValue(StrictKey, out object? strictObj) is true && diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIHostedConversationClient.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIHostedConversationClient.cs new file mode 100644 index 00000000000..9ba556dc950 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIHostedConversationClient.cs @@ -0,0 +1,456 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Runtime.CompilerServices; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.DiagnosticIds; +using Microsoft.Shared.Diagnostics; +using OpenAI.Conversations; + +#pragma warning disable S1006 // Add the default parameter value defined in the overridden method +#pragma warning disable S3254 // Remove this default value assigned to parameter + +namespace Microsoft.Extensions.AI; + +/// Represents an for an OpenAI . +[Experimental(DiagnosticIds.Experiments.AIHostedConversation, UrlFormat = DiagnosticIds.UrlFormat)] +internal sealed class OpenAIHostedConversationClient : IHostedConversationClient +{ + /// Metadata about the client. + private readonly HostedConversationClientMetadata _metadata; + + /// The underlying . + private readonly ConversationClient _conversationClient; + + /// Initializes a new instance of the class. + /// The underlying client. + /// is . + public OpenAIHostedConversationClient(ConversationClient conversationClient) + { + _conversationClient = Throw.IfNull(conversationClient); + _metadata = new("openai", conversationClient.Endpoint); + } + + /// + public async Task CreateAsync( + HostedConversationClientOptions? options = null, + CancellationToken cancellationToken = default) + { + using BinaryContent content = CreateOrGetCreatePayload(options); + RequestOptions requestOptions = cancellationToken.ToRequestOptions(streaming: false); + + ClientResult result = await _conversationClient.CreateConversationAsync(content, requestOptions).ConfigureAwait(false); + + return ParseConversation(result); + } + + /// + public async Task GetAsync( + string conversationId, + HostedConversationClientOptions? options = null, + CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(conversationId); + + RequestOptions requestOptions = cancellationToken.ToRequestOptions(streaming: false); + + ClientResult result = await _conversationClient.GetConversationAsync(conversationId, requestOptions).ConfigureAwait(false); + + return ParseConversation(result); + } + + /// + public async Task DeleteAsync( + string conversationId, + HostedConversationClientOptions? options = null, + CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(conversationId); + + RequestOptions requestOptions = cancellationToken.ToRequestOptions(streaming: false); + _ = await _conversationClient.DeleteConversationAsync(conversationId, requestOptions).ConfigureAwait(false); + } + + /// + public async Task AddMessagesAsync( + string conversationId, + IEnumerable messages, + HostedConversationClientOptions? options = null, + CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(conversationId); + _ = Throw.IfNull(messages); + + using BinaryContent content = CreateBinaryContent(static (writer, msgs) => WriteItemsPayload(writer, msgs), messages); + RequestOptions requestOptions = cancellationToken.ToRequestOptions(streaming: false); + + _ = await _conversationClient.CreateConversationItemsAsync(conversationId, content, null, requestOptions).ConfigureAwait(false); + } + + /// + public async IAsyncEnumerable GetMessagesAsync( + string conversationId, + HostedConversationClientOptions? options = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(conversationId); + + int? limit = options?.Limit; + RequestOptions requestOptions = cancellationToken.ToRequestOptions(streaming: false); + + // Manual pagination: the SDK's GetRawPagesAsync() only yields a single page because + // GetContinuationToken returns null in the generated collection result class. + // We loop using the cursor-based 'after' parameter until all items are retrieved. + string? after = null; + + do + { + AsyncCollectionResult pages = _conversationClient.GetConversationItemsAsync( + conversationId, limit: limit, order: null, after: after, include: null, options: requestOptions); + + bool hasMore = false; + string? lastId = null; + + await foreach (ClientResult page in pages.GetRawPagesAsync().ConfigureAwait(false)) + { + using JsonDocument doc = JsonDocument.Parse(page.GetRawResponse().Content); + JsonElement root = doc.RootElement; + + if (root.TryGetProperty("has_more", out JsonElement hasMoreElement) && + hasMoreElement.ValueKind == JsonValueKind.True) + { + hasMore = true; + } + + if (root.TryGetProperty("last_id", out JsonElement lastIdElement)) + { + lastId = lastIdElement.GetString(); + } + + if (!root.TryGetProperty("data", out JsonElement dataElement)) + { + continue; + } + + foreach (JsonElement item in dataElement.EnumerateArray()) + { + if (TryConvertItemToChatMessage(item) is { } message) + { + yield return message; + } + } + } + + after = hasMore ? lastId : null; + } + while (after is not null); + } + + /// + public object? GetService(Type serviceType, object? serviceKey = null) + { + _ = Throw.IfNull(serviceType); + + return + serviceKey is not null ? null : + serviceType == typeof(HostedConversationClientMetadata) ? _metadata : + serviceType == typeof(ConversationClient) ? _conversationClient : + serviceType.IsInstanceOfType(this) ? this : + null; + } + + /// Creates a for the create conversation request, using the raw representation factory if available. + private BinaryContent CreateOrGetCreatePayload(HostedConversationClientOptions? options) + { + if (options?.RawRepresentationFactory?.Invoke(this) is BinaryContent rawContent) + { + return rawContent; + } + + return CreateBinaryContent(static (writer, opts) => WriteCreatePayload(writer, opts), options); + } + + /// + void IDisposable.Dispose() + { + // Nothing to dispose. Implementation required for the IHostedConversationClient interface. + } + + /// Creates a from a JSON writing action. + private static BinaryContent CreateBinaryContent(Action writeAction, TState state) + { + using MemoryStream stream = new(); + using (var writer = new Utf8JsonWriter(stream)) + { + writeAction(writer, state); + } + + return BinaryContent.Create(new BinaryData(stream.ToArray())); + } + + /// Writes the JSON payload for creating a conversation. + private static void WriteCreatePayload(Utf8JsonWriter writer, HostedConversationClientOptions? options) + { + writer.WriteStartObject(); + + if (options?.AdditionalProperties is { Count: > 0 } additionalProperties) + { + // Map "metadata" from AdditionalProperties if present as a dictionary of string values. + if (additionalProperties.TryGetValue("metadata", out object? metadataObj) && + metadataObj is AdditionalPropertiesDictionary metadata && + metadata.Count > 0) + { + writer.WritePropertyName("metadata"); + writer.WriteStartObject(); + foreach (var kvp in metadata) + { + writer.WriteString(kvp.Key, kvp.Value); + } + + writer.WriteEndObject(); + } + } + + writer.WriteEndObject(); + } + + /// Writes the JSON payload for adding items to a conversation. + private static void WriteItemsPayload(Utf8JsonWriter writer, IEnumerable messages) + { + writer.WriteStartObject(); + writer.WritePropertyName("items"); + WriteMessagesArray(writer, messages); + writer.WriteEndObject(); + } + + /// Writes a JSON array of conversation items from a collection of instances. + private static void WriteMessagesArray(Utf8JsonWriter writer, IEnumerable messages) + { + writer.WriteStartArray(); + + foreach (ChatMessage message in messages) + { + string role = message.Role == ChatRole.Assistant ? "assistant" : + message.Role == ChatRole.System ? "system" : + message.Role == OpenAIClientExtensions.ChatRoleDeveloper ? "developer" : + "user"; + + writer.WriteStartObject(); + writer.WriteString("type", "message"); + writer.WriteString("role", role); + writer.WritePropertyName("content"); + writer.WriteStartArray(); + + bool hasContent = false; + foreach (AIContent aiContent in message.Contents) + { + switch (aiContent) + { + case TextContent textContent: + writer.WriteStartObject(); + writer.WriteString("type", message.Role == ChatRole.Assistant ? "output_text" : "input_text"); + writer.WriteString("text", textContent.Text); + writer.WriteEndObject(); + hasContent = true; + break; + + case UriContent uriContent when uriContent.HasTopLevelMediaType("image"): + writer.WriteStartObject(); + writer.WriteString("type", "input_image"); + writer.WriteString("image_url", uriContent.Uri.ToString()); + writer.WriteEndObject(); + hasContent = true; + break; + + case DataContent dataContent when dataContent.HasTopLevelMediaType("image"): + writer.WriteStartObject(); + writer.WriteString("type", "input_image"); + writer.WriteString("image_url", dataContent.Uri); + writer.WriteEndObject(); + hasContent = true; + break; + + case HostedFileContent fileContent: + writer.WriteStartObject(); + if (fileContent.HasTopLevelMediaType("image")) + { + writer.WriteString("type", "input_image"); + writer.WritePropertyName("image_file"); + writer.WriteStartObject(); + writer.WriteString("file_id", fileContent.FileId); + writer.WriteEndObject(); + } + else + { + writer.WriteString("type", "input_file"); + writer.WriteString("file_id", fileContent.FileId); + } + + writer.WriteEndObject(); + hasContent = true; + break; + } + } + + // If no structured content parts were produced, fall back to the text property. + if (!hasContent) + { + string? text = message.Text; + if (text is not null) + { + writer.WriteStartObject(); + writer.WriteString("type", message.Role == ChatRole.Assistant ? "output_text" : "input_text"); + writer.WriteString("text", text); + writer.WriteEndObject(); + } + } + + writer.WriteEndArray(); + writer.WriteEndObject(); + } + + writer.WriteEndArray(); + } + + /// Parses a from a . + private static HostedConversation ParseConversation(ClientResult result) + { + using JsonDocument doc = JsonDocument.Parse(result.GetRawResponse().Content); + JsonElement root = doc.RootElement; + + var conversation = new HostedConversation + { + ConversationId = root.TryGetProperty("id", out JsonElement idElement) ? idElement.GetString() : null, + CreatedAt = root.TryGetProperty("created_at", out JsonElement createdAtElement) && createdAtElement.ValueKind == JsonValueKind.Number ? DateTimeOffset.FromUnixTimeSeconds(createdAtElement.GetInt64()) : null, + RawRepresentation = result, + }; + + if (ParseMetadata(root) is { } metadata) + { + conversation.AdditionalProperties ??= new(); + conversation.AdditionalProperties["metadata"] = metadata; + } + + return conversation; + } + + /// Attempts to convert a JSON element representing a conversation item to a . + private static ChatMessage? TryConvertItemToChatMessage(JsonElement item) + { + if (!item.TryGetProperty("type", out JsonElement typeElement)) + { + return null; + } + + string? type = typeElement.GetString(); + if (type != "message") + { + return null; + } + + ChatRole role = ChatRole.User; + if (item.TryGetProperty("role", out JsonElement roleElement)) + { + string? roleStr = roleElement.GetString(); + role = roleStr switch + { + "assistant" => ChatRole.Assistant, + "system" => ChatRole.System, + "developer" => OpenAIClientExtensions.ChatRoleDeveloper, + "tool" => ChatRole.Tool, + _ => ChatRole.User, + }; + } + + var message = new ChatMessage(role, (string?)null); + + if (item.TryGetProperty("id", out JsonElement idElement)) + { + message.MessageId = idElement.GetString(); + } + + if (item.TryGetProperty("content", out JsonElement contentElement) && contentElement.ValueKind == JsonValueKind.Array) + { + foreach (JsonElement part in contentElement.EnumerateArray()) + { + if (!part.TryGetProperty("type", out JsonElement partType)) + { + continue; + } + + string? partTypeStr = partType.GetString(); + switch (partTypeStr) + { + case "input_text" or "output_text" or "text": + if (part.TryGetProperty("text", out JsonElement textElement)) + { + message.Contents.Add(new TextContent(textElement.GetString())); + } + + break; + + case "refusal": + if (part.TryGetProperty("refusal", out JsonElement refusalElement)) + { + message.Contents.Add(new ErrorContent(refusalElement.GetString()) { ErrorCode = "Refusal" }); + } + + break; + + case "input_image": + if (part.TryGetProperty("image_url", out JsonElement imageUrlElement) && + imageUrlElement.GetString() is { } imageUrl && + Uri.TryCreate(imageUrl, UriKind.Absolute, out Uri? imageUri)) + { + message.Contents.Add(new UriContent(imageUri, OpenAIClientExtensions.ImageUriToMediaType(imageUri))); + } + else if (part.TryGetProperty("file_id", out JsonElement fileIdElement) && + fileIdElement.GetString() is { } fileId) + { + message.Contents.Add(new HostedFileContent(fileId) { MediaType = "image/*" }); + } + + break; + + case "input_file": + if (part.TryGetProperty("file_id", out JsonElement inputFileIdElement) && + inputFileIdElement.GetString() is { } inputFileId) + { + string? filename = part.TryGetProperty("filename", out JsonElement filenameElement) ? filenameElement.GetString() : null; + message.Contents.Add(new HostedFileContent(inputFileId) { Name = filename }); + } + + break; + } + } + } + + return message; + } + + /// Parses metadata from a JSON element. + private static AdditionalPropertiesDictionary? ParseMetadata(JsonElement root) + { + if (!root.TryGetProperty("metadata", out JsonElement metadataElement) || + metadataElement.ValueKind != JsonValueKind.Object) + { + return null; + } + + var metadata = new AdditionalPropertiesDictionary(); + foreach (JsonProperty property in metadataElement.EnumerateObject()) + { + metadata[property.Name] = property.Value.GetString() ?? string.Empty; + } + + return metadata.Count > 0 ? metadata : null; + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/HostedConversation/ConfigureOptionsHostedConversationClient.cs b/src/Libraries/Microsoft.Extensions.AI/HostedConversation/ConfigureOptionsHostedConversationClient.cs new file mode 100644 index 00000000000..7df9aa0ef24 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/HostedConversation/ConfigureOptionsHostedConversationClient.cs @@ -0,0 +1,82 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.DiagnosticIds; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// Represents a delegating hosted conversation client that configures a instance used by the remainder of the pipeline. +[Experimental(DiagnosticIds.Experiments.AIHostedConversation, UrlFormat = DiagnosticIds.UrlFormat)] +public sealed class ConfigureOptionsHostedConversationClient : DelegatingHostedConversationClient +{ + /// The callback delegate used to configure options. + private readonly Action _configureOptions; + + /// Initializes a new instance of the class with the specified callback. + /// The inner client. + /// + /// The delegate to invoke to configure the instance. It is passed a clone of the caller-supplied instance + /// (or a newly constructed instance if the caller-supplied instance is ). + /// + /// + /// The delegate is passed either a new instance of if + /// the caller didn't supply a instance, or a clone (via ) of the caller-supplied + /// instance if one was supplied. + /// + public ConfigureOptionsHostedConversationClient(IHostedConversationClient innerClient, Action configure) + : base(innerClient) + { + _configureOptions = Throw.IfNull(configure); + } + + /// + public override async Task CreateAsync( + HostedConversationClientOptions? options = null, CancellationToken cancellationToken = default) + { + return await base.CreateAsync(Configure(options), cancellationToken); + } + + /// + public override async Task GetAsync( + string conversationId, HostedConversationClientOptions? options = null, CancellationToken cancellationToken = default) + { + return await base.GetAsync(conversationId, Configure(options), cancellationToken); + } + + /// + public override async Task DeleteAsync( + string conversationId, HostedConversationClientOptions? options = null, CancellationToken cancellationToken = default) + { + await base.DeleteAsync(conversationId, Configure(options), cancellationToken); + } + + /// + public override async Task AddMessagesAsync( + string conversationId, IEnumerable messages, HostedConversationClientOptions? options = null, CancellationToken cancellationToken = default) + { + await base.AddMessagesAsync(conversationId, messages, Configure(options), cancellationToken); + } + + /// + public override IAsyncEnumerable GetMessagesAsync( + string conversationId, HostedConversationClientOptions? options = null, CancellationToken cancellationToken = default) + { + return base.GetMessagesAsync(conversationId, Configure(options), cancellationToken); + } + + /// Creates and configures the to pass along to the inner client. + private HostedConversationClientOptions Configure(HostedConversationClientOptions? options) + { + options = options?.Clone() ?? new(); + + _configureOptions(options); + + return options; + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/HostedConversation/ConfigureOptionsHostedConversationClientBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/HostedConversation/ConfigureOptionsHostedConversationClientBuilderExtensions.cs new file mode 100644 index 00000000000..7046aec6490 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/HostedConversation/ConfigureOptionsHostedConversationClientBuilderExtensions.cs @@ -0,0 +1,37 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Diagnostics.CodeAnalysis; +using Microsoft.Shared.DiagnosticIds; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// Provides extensions for configuring instances. +[Experimental(DiagnosticIds.Experiments.AIHostedConversation, UrlFormat = DiagnosticIds.UrlFormat)] +public static class ConfigureOptionsHostedConversationClientBuilderExtensions +{ + /// + /// Adds a callback that configures a to be passed to the next client in the pipeline. + /// + /// The . + /// + /// The delegate to invoke to configure the instance. + /// It is passed a clone of the caller-supplied instance (or a newly constructed instance if the caller-supplied instance is ). + /// + /// + /// This method can be used to set default options. The delegate is passed either a new instance of + /// if the caller didn't supply a instance, or a clone (via ) + /// of the caller-supplied instance if one was supplied. + /// + /// The . + public static HostedConversationClientBuilder ConfigureOptions( + this HostedConversationClientBuilder builder, Action configure) + { + _ = Throw.IfNull(builder); + _ = Throw.IfNull(configure); + + return builder.Use(innerClient => new ConfigureOptionsHostedConversationClient(innerClient, configure)); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/HostedConversation/HostedConversationClientBuilder.cs b/src/Libraries/Microsoft.Extensions.AI/HostedConversation/HostedConversationClientBuilder.cs new file mode 100644 index 00000000000..3fd720a8895 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/HostedConversation/HostedConversationClientBuilder.cs @@ -0,0 +1,82 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using Microsoft.Shared.DiagnosticIds; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// A builder for creating pipelines of . +[Experimental(DiagnosticIds.Experiments.AIHostedConversation, UrlFormat = DiagnosticIds.UrlFormat)] +public sealed class HostedConversationClientBuilder +{ + private readonly Func _innerClientFactory; + + /// The registered client factory instances. + private List>? _clientFactories; + + /// Initializes a new instance of the class. + /// The inner that represents the underlying backend. + public HostedConversationClientBuilder(IHostedConversationClient innerClient) + { + _ = Throw.IfNull(innerClient); + _innerClientFactory = _ => innerClient; + } + + /// Initializes a new instance of the class. + /// A callback that produces the inner that represents the underlying backend. + public HostedConversationClientBuilder(Func innerClientFactory) + { + _innerClientFactory = Throw.IfNull(innerClientFactory); + } + + /// Builds an that represents the entire pipeline. Calls to this instance will pass through each of the pipeline stages in turn. + /// + /// The that should provide services to the instances. + /// If null, an empty will be used. + /// + /// An instance of that represents the entire pipeline. + public IHostedConversationClient Build(IServiceProvider? services = null) + { + services ??= EmptyServiceProvider.Instance; + var client = _innerClientFactory(services); + + // To match intuitive expectations, apply the factories in reverse order, so that the first factory added is the outermost. + if (_clientFactories is not null) + { + for (var i = _clientFactories.Count - 1; i >= 0; i--) + { + client = _clientFactories[i](client, services) ?? + throw new InvalidOperationException( + $"The {nameof(HostedConversationClientBuilder)} entry at index {i} returned null. " + + $"Ensure that the callbacks passed to {nameof(Use)} return non-null {nameof(IHostedConversationClient)} instances."); + } + } + + return client; + } + + /// Adds a factory for an intermediate hosted conversation client to the hosted conversation client pipeline. + /// The client factory function. + /// The updated instance. + public HostedConversationClientBuilder Use(Func clientFactory) + { + _ = Throw.IfNull(clientFactory); + + return Use((innerClient, _) => clientFactory(innerClient)); + } + + /// Adds a factory for an intermediate hosted conversation client to the hosted conversation client pipeline. + /// The client factory function. + /// The updated instance. + public HostedConversationClientBuilder Use(Func clientFactory) + { + _ = Throw.IfNull(clientFactory); + + (_clientFactories ??= []).Add(clientFactory); + return this; + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/HostedConversation/HostedConversationClientBuilderServiceCollectionExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/HostedConversation/HostedConversationClientBuilderServiceCollectionExtensions.cs new file mode 100644 index 00000000000..94096907084 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/HostedConversation/HostedConversationClientBuilderServiceCollectionExtensions.cs @@ -0,0 +1,91 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Diagnostics.CodeAnalysis; +using Microsoft.Extensions.AI; +using Microsoft.Shared.DiagnosticIds; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.DependencyInjection; + +/// Provides extension methods for registering with a . +[Experimental(DiagnosticIds.Experiments.AIHostedConversation, UrlFormat = DiagnosticIds.UrlFormat)] +public static class HostedConversationClientBuilderServiceCollectionExtensions +{ + /// Registers a singleton in the . + /// The to which the client should be added. + /// The inner that represents the underlying backend. + /// The service lifetime for the client. Defaults to . + /// A that can be used to build a pipeline around the inner client. + /// The client is registered as a singleton service. + public static HostedConversationClientBuilder AddHostedConversationClient( + this IServiceCollection serviceCollection, + IHostedConversationClient innerClient, + ServiceLifetime lifetime = ServiceLifetime.Singleton) + { + _ = Throw.IfNull(serviceCollection); + _ = Throw.IfNull(innerClient); + + return AddHostedConversationClient(serviceCollection, _ => innerClient, lifetime); + } + + /// Registers a singleton in the . + /// The to which the client should be added. + /// A callback that produces the inner that represents the underlying backend. + /// The service lifetime for the client. Defaults to . + /// A that can be used to build a pipeline around the inner client. + /// The client is registered as a singleton service. + public static HostedConversationClientBuilder AddHostedConversationClient( + this IServiceCollection serviceCollection, + Func innerClientFactory, + ServiceLifetime lifetime = ServiceLifetime.Singleton) + { + _ = Throw.IfNull(serviceCollection); + _ = Throw.IfNull(innerClientFactory); + + var builder = new HostedConversationClientBuilder(innerClientFactory); + serviceCollection.Add(new ServiceDescriptor(typeof(IHostedConversationClient), builder.Build, lifetime)); + return builder; + } + + /// Registers a keyed singleton in the . + /// The to which the client should be added. + /// The key with which to associate the client. + /// The inner that represents the underlying backend. + /// The service lifetime for the client. Defaults to . + /// A that can be used to build a pipeline around the inner client. + /// The client is registered as a singleton service. + public static HostedConversationClientBuilder AddKeyedHostedConversationClient( + this IServiceCollection serviceCollection, + object? serviceKey, + IHostedConversationClient innerClient, + ServiceLifetime lifetime = ServiceLifetime.Singleton) + { + _ = Throw.IfNull(serviceCollection); + _ = Throw.IfNull(innerClient); + + return AddKeyedHostedConversationClient(serviceCollection, serviceKey, _ => innerClient, lifetime); + } + + /// Registers a keyed singleton in the . + /// The to which the client should be added. + /// The key with which to associate the client. + /// A callback that produces the inner that represents the underlying backend. + /// The service lifetime for the client. Defaults to . + /// A that can be used to build a pipeline around the inner client. + /// The client is registered as a singleton service. + public static HostedConversationClientBuilder AddKeyedHostedConversationClient( + this IServiceCollection serviceCollection, + object? serviceKey, + Func innerClientFactory, + ServiceLifetime lifetime = ServiceLifetime.Singleton) + { + _ = Throw.IfNull(serviceCollection); + _ = Throw.IfNull(innerClientFactory); + + var builder = new HostedConversationClientBuilder(innerClientFactory); + serviceCollection.Add(new ServiceDescriptor(typeof(IHostedConversationClient), serviceKey, factory: (services, serviceKey) => builder.Build(services), lifetime)); + return builder; + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/HostedConversation/LoggingHostedConversationClient.cs b/src/Libraries/Microsoft.Extensions.AI/HostedConversation/LoggingHostedConversationClient.cs new file mode 100644 index 00000000000..a958be67cdb --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/HostedConversation/LoggingHostedConversationClient.cs @@ -0,0 +1,304 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Microsoft.Shared.DiagnosticIds; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// A delegating hosted conversation client that logs operations to an . +/// +/// +/// The provided implementation of is thread-safe for concurrent use so long as the +/// employed is also thread-safe for concurrent use. +/// +/// +/// When the employed enables , the contents of +/// messages and options are logged. These messages and options may contain sensitive application data. +/// is disabled by default and should never be enabled in a production environment. +/// Messages and options are not logged at other logging levels. +/// +/// +[Experimental(DiagnosticIds.Experiments.AIHostedConversation, UrlFormat = DiagnosticIds.UrlFormat)] +public partial class LoggingHostedConversationClient : DelegatingHostedConversationClient +{ + /// An instance used for all logging. + private readonly ILogger _logger; + + /// The to use for serialization of state written to the logger. + private JsonSerializerOptions _jsonSerializerOptions; + + /// Initializes a new instance of the class. + /// The underlying . + /// An instance that will be used for all logging. + public LoggingHostedConversationClient(IHostedConversationClient innerClient, ILogger logger) + : base(innerClient) + { + _logger = Throw.IfNull(logger); + _jsonSerializerOptions = AIJsonUtilities.DefaultOptions; + } + + /// Gets or sets JSON serialization options to use when serializing logging data. + public JsonSerializerOptions JsonSerializerOptions + { + get => _jsonSerializerOptions; + set => _jsonSerializerOptions = Throw.IfNull(value); + } + + /// + public override async Task CreateAsync( + HostedConversationClientOptions? options = null, CancellationToken cancellationToken = default) + { + if (_logger.IsEnabled(LogLevel.Debug)) + { + if (_logger.IsEnabled(LogLevel.Trace)) + { + LogInvokedSensitive(nameof(CreateAsync), AsJson(options), AsJson(this.GetService())); + } + else + { + LogInvoked(nameof(CreateAsync)); + } + } + + try + { + var conversation = await base.CreateAsync(options, cancellationToken); + + if (_logger.IsEnabled(LogLevel.Debug)) + { + if (_logger.IsEnabled(LogLevel.Trace)) + { + LogCompletedSensitive(nameof(CreateAsync), AsJson(conversation)); + } + else + { + LogCompleted(nameof(CreateAsync)); + } + } + + return conversation; + } + catch (OperationCanceledException) + { + LogInvocationCanceled(nameof(CreateAsync)); + throw; + } + catch (Exception ex) + { + LogInvocationFailed(nameof(CreateAsync), ex); + throw; + } + } + + /// + public override async Task GetAsync( + string conversationId, HostedConversationClientOptions? options = null, CancellationToken cancellationToken = default) + { + if (_logger.IsEnabled(LogLevel.Debug)) + { + LogInvokedWithConversationId(nameof(GetAsync), conversationId); + } + + try + { + var conversation = await base.GetAsync(conversationId, options, cancellationToken); + + if (_logger.IsEnabled(LogLevel.Debug)) + { + if (_logger.IsEnabled(LogLevel.Trace)) + { + LogCompletedSensitive(nameof(GetAsync), AsJson(conversation)); + } + else + { + LogCompleted(nameof(GetAsync)); + } + } + + return conversation; + } + catch (OperationCanceledException) + { + LogInvocationCanceled(nameof(GetAsync)); + throw; + } + catch (Exception ex) + { + LogInvocationFailed(nameof(GetAsync), ex); + throw; + } + } + + /// + public override async Task DeleteAsync( + string conversationId, HostedConversationClientOptions? options = null, CancellationToken cancellationToken = default) + { + if (_logger.IsEnabled(LogLevel.Debug)) + { + LogInvokedWithConversationId(nameof(DeleteAsync), conversationId); + } + + try + { + await base.DeleteAsync(conversationId, options, cancellationToken); + LogCompleted(nameof(DeleteAsync)); + } + catch (OperationCanceledException) + { + LogInvocationCanceled(nameof(DeleteAsync)); + throw; + } + catch (Exception ex) + { + LogInvocationFailed(nameof(DeleteAsync), ex); + throw; + } + } + + /// + public override async Task AddMessagesAsync( + string conversationId, IEnumerable messages, HostedConversationClientOptions? options = null, CancellationToken cancellationToken = default) + { + if (_logger.IsEnabled(LogLevel.Debug)) + { + if (_logger.IsEnabled(LogLevel.Trace)) + { + LogAddMessagesInvokedSensitive(conversationId, AsJson(messages)); + } + else + { + LogInvokedWithConversationId(nameof(AddMessagesAsync), conversationId); + } + } + + try + { + await base.AddMessagesAsync(conversationId, messages, options, cancellationToken); + LogCompleted(nameof(AddMessagesAsync)); + } + catch (OperationCanceledException) + { + LogInvocationCanceled(nameof(AddMessagesAsync)); + throw; + } + catch (Exception ex) + { + LogInvocationFailed(nameof(AddMessagesAsync), ex); + throw; + } + } + + /// + public override async IAsyncEnumerable GetMessagesAsync( + string conversationId, HostedConversationClientOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + if (_logger.IsEnabled(LogLevel.Debug)) + { + LogInvokedWithConversationId(nameof(GetMessagesAsync), conversationId); + } + + IAsyncEnumerator e; + try + { + e = base.GetMessagesAsync(conversationId, options, cancellationToken).GetAsyncEnumerator(cancellationToken); + } + catch (OperationCanceledException) + { + LogInvocationCanceled(nameof(GetMessagesAsync)); + throw; + } + catch (Exception ex) + { + LogInvocationFailed(nameof(GetMessagesAsync), ex); + throw; + } + + try + { + ChatMessage? message = null; + while (true) + { + try + { + if (!await e.MoveNextAsync()) + { + break; + } + + message = e.Current; + } + catch (OperationCanceledException) + { + LogInvocationCanceled(nameof(GetMessagesAsync)); + throw; + } + catch (Exception ex) + { + LogInvocationFailed(nameof(GetMessagesAsync), ex); + throw; + } + + if (_logger.IsEnabled(LogLevel.Debug)) + { + if (_logger.IsEnabled(LogLevel.Trace)) + { + LogMessageReceivedSensitive(AsJson(message)); + } + else + { + LogMessageReceived(); + } + } + + yield return message; + } + + LogCompleted(nameof(GetMessagesAsync)); + } + finally + { + await e.DisposeAsync(); + } + } + + private string AsJson(T value) => TelemetryHelpers.AsJson(value, _jsonSerializerOptions); + + [LoggerMessage(LogLevel.Debug, "{MethodName} invoked.")] + private partial void LogInvoked(string methodName); + + [LoggerMessage(LogLevel.Debug, "{MethodName} invoked. ConversationId: {ConversationId}.")] + private partial void LogInvokedWithConversationId(string methodName, string conversationId); + + [LoggerMessage(LogLevel.Trace, "{MethodName} invoked: Options: {HostedConversationClientOptions}. Metadata: {HostedConversationClientMetadata}.")] + private partial void LogInvokedSensitive(string methodName, string hostedConversationClientOptions, string hostedConversationClientMetadata); + + [LoggerMessage(LogLevel.Trace, "AddMessagesAsync invoked. ConversationId: {ConversationId}. Messages: {Messages}.")] + private partial void LogAddMessagesInvokedSensitive(string conversationId, string messages); + + [LoggerMessage(LogLevel.Debug, "{MethodName} completed.")] + private partial void LogCompleted(string methodName); + + [LoggerMessage(LogLevel.Trace, "{MethodName} completed: {HostedConversationResponse}.")] + private partial void LogCompletedSensitive(string methodName, string hostedConversationResponse); + + [LoggerMessage(LogLevel.Debug, "GetMessagesAsync received message.")] + private partial void LogMessageReceived(); + + [LoggerMessage(LogLevel.Trace, "GetMessagesAsync received message: {ChatMessage}")] + private partial void LogMessageReceivedSensitive(string chatMessage); + + [LoggerMessage(LogLevel.Debug, "{MethodName} canceled.")] + private partial void LogInvocationCanceled(string methodName); + + [LoggerMessage(LogLevel.Error, "{MethodName} failed.")] + private partial void LogInvocationFailed(string methodName, Exception error); +} diff --git a/src/Libraries/Microsoft.Extensions.AI/HostedConversation/LoggingHostedConversationClientBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/HostedConversation/LoggingHostedConversationClientBuilderExtensions.cs new file mode 100644 index 00000000000..65268b6bf6f --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/HostedConversation/LoggingHostedConversationClientBuilderExtensions.cs @@ -0,0 +1,57 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Diagnostics.CodeAnalysis; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Shared.DiagnosticIds; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// Provides extensions for configuring instances. +[Experimental(DiagnosticIds.Experiments.AIHostedConversation, UrlFormat = DiagnosticIds.UrlFormat)] +public static class LoggingHostedConversationClientBuilderExtensions +{ + /// Adds logging to the hosted conversation client pipeline. + /// The . + /// + /// An optional used to create a logger with which logging should be performed. + /// If not supplied, a required instance will be resolved from the service provider. + /// + /// An optional callback that can be used to configure the instance. + /// The . + /// + /// + /// When the employed enables , the contents of + /// messages and options are logged. These messages and options may contain sensitive application data. + /// is disabled by default and should never be enabled in a production environment. + /// Messages and options are not logged at other logging levels. + /// + /// + public static HostedConversationClientBuilder UseLogging( + this HostedConversationClientBuilder builder, + ILoggerFactory? loggerFactory = null, + Action? configure = null) + { + _ = Throw.IfNull(builder); + + return builder.Use((innerClient, services) => + { + loggerFactory ??= services.GetRequiredService(); + + // If the factory we resolve is for the null logger, the LoggingHostedConversationClient will end up + // being an expensive nop, so skip adding it and just return the inner client. + if (loggerFactory == NullLoggerFactory.Instance) + { + return innerClient; + } + + var client = new LoggingHostedConversationClient(innerClient, loggerFactory.CreateLogger(typeof(LoggingHostedConversationClient))); + configure?.Invoke(client); + return client; + }); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/HostedConversation/OpenTelemetryHostedConversationClient.cs b/src/Libraries/Microsoft.Extensions.AI/HostedConversation/OpenTelemetryHostedConversationClient.cs new file mode 100644 index 00000000000..be652e7a855 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/HostedConversation/OpenTelemetryHostedConversationClient.cs @@ -0,0 +1,359 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Diagnostics.Metrics; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Microsoft.Shared.DiagnosticIds; + +#pragma warning disable SA1111 // Closing parenthesis should be on line of last parameter +#pragma warning disable SA1113 // Comma should be on the same line as previous parameter + +namespace Microsoft.Extensions.AI; + +/// Represents a delegating hosted conversation client that implements OpenTelemetry-compatible tracing and metrics for conversation operations. +/// +/// +/// Since there is currently no OpenTelemetry Semantic Convention for hosted conversation operations, this implementation +/// uses general client span conventions alongside standard conversations.* registry attributes where applicable. +/// +/// +/// The specification is subject to change as relevant OpenTelemetry conventions emerge; as such, the telemetry +/// output by this client is also subject to change. +/// +/// +[Experimental(DiagnosticIds.Experiments.AIHostedConversation, UrlFormat = DiagnosticIds.UrlFormat)] +public sealed class OpenTelemetryHostedConversationClient : DelegatingHostedConversationClient +{ + private const string CreateOperationName = "conversations.create"; + private const string GetOperationName = "conversations.get"; + private const string DeleteOperationName = "conversations.delete"; + private const string AddMessagesOperationName = "conversations.add_messages"; + private const string GetMessagesOperationName = "conversations.get_messages"; + + private const string OperationDurationMetricName = "conversations.client.operation.duration"; + private const string OperationDurationMetricDescription = "Measures the duration of a conversation operation"; + + private const string ConversationsOperationNameAttribute = "conversations.operation.name"; + private const string ConversationsProviderNameAttribute = "conversations.provider.name"; + private const string ConversationsIdAttribute = "conversations.id"; + private const string ConversationsMessagesCountAttribute = "conversations.messages.count"; + + private readonly ActivitySource _activitySource; + private readonly Meter _meter; + private readonly Histogram _operationDurationHistogram; + + private readonly string? _providerName; + private readonly string? _serverAddress; + private readonly int _serverPort; + + private readonly ILogger? _logger; + + /// Initializes a new instance of the class. + /// The underlying . + /// The to use for emitting any logging data from the client. + /// An optional source name that will be used on the telemetry data. + public OpenTelemetryHostedConversationClient(IHostedConversationClient innerClient, ILogger? logger = null, string? sourceName = null) + : base(innerClient) + { + Debug.Assert(innerClient is not null, "Should have been validated by the base ctor"); + + _logger = logger; + + if (innerClient!.GetService() is HostedConversationClientMetadata metadata) + { + _providerName = metadata.ProviderName; + _serverAddress = metadata.ProviderUri?.Host; + _serverPort = metadata.ProviderUri?.Port ?? 0; + } + + string name = string.IsNullOrEmpty(sourceName) ? OpenTelemetryConsts.DefaultSourceName : sourceName!; + _activitySource = new(name); + _meter = new(name); + + _operationDurationHistogram = _meter.CreateHistogram( + OperationDurationMetricName, + OpenTelemetryConsts.SecondsUnit, + OperationDurationMetricDescription, + advice: new() { HistogramBucketBoundaries = OpenTelemetryConsts.GenAI.Client.OperationDuration.ExplicitBucketBoundaries } + ); + } + + /// + /// Gets or sets a value indicating whether potentially sensitive information should be included in telemetry. + /// + /// + /// if potentially sensitive information should be included in telemetry; + /// if telemetry shouldn't include raw inputs and outputs. + /// The default value is , unless the OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT + /// environment variable is set to "true" (case-insensitive). + /// + /// + /// + /// By default, telemetry includes operation metadata such as provider name, duration, + /// conversation IDs, and message counts. + /// + /// + /// When enabled, telemetry will additionally include message content, which may contain sensitive information. + /// + /// + /// The default value can be overridden by setting the OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT + /// environment variable to "true". Explicitly setting this property will override the environment variable. + /// + /// + public bool EnableSensitiveData { get; set; } = TelemetryHelpers.EnableSensitiveDataDefault; + + /// + public override object? GetService(Type serviceType, object? serviceKey = null) => + serviceKey is null && serviceType == typeof(ActivitySource) ? _activitySource : + base.GetService(serviceType, serviceKey); + + /// + protected override void Dispose(bool disposing) + { + if (disposing) + { + _activitySource.Dispose(); + _meter.Dispose(); + } + + base.Dispose(disposing); + } + + /// + public override async Task CreateAsync( + HostedConversationClientOptions? options = null, CancellationToken cancellationToken = default) + { + using Activity? activity = StartActivity(CreateOperationName); + Stopwatch? stopwatch = _operationDurationHistogram.Enabled ? Stopwatch.StartNew() : null; + + HostedConversation? result = null; + Exception? error = null; + try + { + result = await base.CreateAsync(options, cancellationToken).ConfigureAwait(false); + return result; + } + catch (Exception ex) + { + error = ex; + throw; + } + finally + { + if (result is not null && activity is { IsAllDataRequested: true }) + { + _ = activity.AddTag(ConversationsIdAttribute, result.ConversationId); + } + + RecordDuration(stopwatch, CreateOperationName, error); + SetErrorStatus(activity, error); + } + } + + /// + public override async Task GetAsync( + string conversationId, HostedConversationClientOptions? options = null, CancellationToken cancellationToken = default) + { + using Activity? activity = StartActivity(GetOperationName, conversationId); + Stopwatch? stopwatch = _operationDurationHistogram.Enabled ? Stopwatch.StartNew() : null; + + Exception? error = null; + try + { + return await base.GetAsync(conversationId, options, cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) + { + error = ex; + throw; + } + finally + { + RecordDuration(stopwatch, GetOperationName, error); + SetErrorStatus(activity, error); + } + } + + /// + public override async Task DeleteAsync( + string conversationId, HostedConversationClientOptions? options = null, CancellationToken cancellationToken = default) + { + using Activity? activity = StartActivity(DeleteOperationName, conversationId); + Stopwatch? stopwatch = _operationDurationHistogram.Enabled ? Stopwatch.StartNew() : null; + + Exception? error = null; + try + { + await base.DeleteAsync(conversationId, options, cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) + { + error = ex; + throw; + } + finally + { + RecordDuration(stopwatch, DeleteOperationName, error); + SetErrorStatus(activity, error); + } + } + + /// + public override async Task AddMessagesAsync( + string conversationId, IEnumerable messages, HostedConversationClientOptions? options = null, CancellationToken cancellationToken = default) + { + using Activity? activity = StartActivity(AddMessagesOperationName, conversationId); + Stopwatch? stopwatch = _operationDurationHistogram.Enabled ? Stopwatch.StartNew() : null; + + Exception? error = null; + try + { + await base.AddMessagesAsync(conversationId, messages, options, cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) + { + error = ex; + throw; + } + finally + { + RecordDuration(stopwatch, AddMessagesOperationName, error); + SetErrorStatus(activity, error); + } + } + + /// + public override async IAsyncEnumerable GetMessagesAsync( + string conversationId, HostedConversationClientOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + using Activity? activity = StartActivity(GetMessagesOperationName, conversationId); + Stopwatch? stopwatch = _operationDurationHistogram.Enabled ? Stopwatch.StartNew() : null; + + IAsyncEnumerator e; + Exception? error = null; + try + { + e = base.GetMessagesAsync(conversationId, options, cancellationToken).GetAsyncEnumerator(cancellationToken); + } + catch (Exception ex) + { + error = ex; + RecordDuration(stopwatch, GetMessagesOperationName, error); + SetErrorStatus(activity, error); + throw; + } + + int count = 0; + try + { + while (true) + { + try + { + if (!await e.MoveNextAsync().ConfigureAwait(false)) + { + break; + } + } + catch (Exception ex) + { + error = ex; + throw; + } + + count++; + yield return e.Current; + Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802 + } + } + finally + { + if (activity is { IsAllDataRequested: true }) + { + _ = activity.AddTag(ConversationsMessagesCountAttribute, count); + } + + RecordDuration(stopwatch, GetMessagesOperationName, error); + SetErrorStatus(activity, error); + + await e.DisposeAsync().ConfigureAwait(false); + } + } + + private void SetErrorStatus(Activity? activity, Exception? error) + { + if (error is not null) + { + _ = activity? + .AddTag(OpenTelemetryConsts.Error.Type, error.GetType().FullName) + .SetStatus(ActivityStatusCode.Error, error.Message); + + if (_logger is not null) + { + OpenTelemetryLog.OperationException(_logger, error); + } + } + } + + private Activity? StartActivity(string operationName, string? conversationId = null) + { + Activity? activity = null; + if (_activitySource.HasListeners()) + { + activity = _activitySource.StartActivity( + operationName, + ActivityKind.Client); + + if (activity is { IsAllDataRequested: true }) + { + _ = activity + .AddTag(ConversationsOperationNameAttribute, operationName) + .AddTag(ConversationsProviderNameAttribute, _providerName); + + if (conversationId is not null) + { + _ = activity.AddTag(ConversationsIdAttribute, conversationId); + } + + if (_serverAddress is not null) + { + _ = activity + .AddTag(OpenTelemetryConsts.Server.Address, _serverAddress) + .AddTag(OpenTelemetryConsts.Server.Port, _serverPort); + } + } + } + + return activity; + } + + private void RecordDuration(Stopwatch? stopwatch, string operationName, Exception? error) + { + if (_operationDurationHistogram.Enabled && stopwatch is not null) + { + TagList tags = default; + tags.Add(ConversationsOperationNameAttribute, operationName); + tags.Add(ConversationsProviderNameAttribute, _providerName); + + if (_serverAddress is string address) + { + tags.Add(OpenTelemetryConsts.Server.Address, address); + tags.Add(OpenTelemetryConsts.Server.Port, _serverPort); + } + + if (error is not null) + { + tags.Add(OpenTelemetryConsts.Error.Type, error.GetType().FullName); + } + + _operationDurationHistogram.Record(stopwatch.Elapsed.TotalSeconds, tags); + } + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/HostedConversation/OpenTelemetryHostedConversationClientBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/HostedConversation/OpenTelemetryHostedConversationClientBuilderExtensions.cs new file mode 100644 index 00000000000..ef4cef6302e --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/HostedConversation/OpenTelemetryHostedConversationClientBuilderExtensions.cs @@ -0,0 +1,44 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Diagnostics.CodeAnalysis; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Shared.DiagnosticIds; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// Provides extensions for configuring instances. +[Experimental(DiagnosticIds.Experiments.AIHostedConversation, UrlFormat = DiagnosticIds.UrlFormat)] +public static class OpenTelemetryHostedConversationClientBuilderExtensions +{ + /// + /// Adds OpenTelemetry support to the hosted conversation client pipeline. + /// + /// + /// Since there is currently no OpenTelemetry Semantic Convention for hosted conversation operations, this implementation + /// uses general client span conventions alongside standard conversations.* registry attributes where applicable. + /// The telemetry output is subject to change as relevant conventions emerge. + /// + /// The . + /// An optional to use to create a logger for logging events. + /// An optional source name that will be used on the telemetry data. + /// An optional callback that can be used to configure the instance. + /// The . + public static HostedConversationClientBuilder UseOpenTelemetry( + this HostedConversationClientBuilder builder, + ILoggerFactory? loggerFactory = null, + string? sourceName = null, + Action? configure = null) => + Throw.IfNull(builder).Use((innerClient, services) => + { + loggerFactory ??= services.GetService(); + + var client = new OpenTelemetryHostedConversationClient(innerClient, loggerFactory?.CreateLogger(typeof(OpenTelemetryHostedConversationClient)), sourceName); + configure?.Invoke(client); + + return client; + }); +} diff --git a/src/Shared/DiagnosticIds/DiagnosticIds.cs b/src/Shared/DiagnosticIds/DiagnosticIds.cs index 94cc1a1f04a..9b2130880bf 100644 --- a/src/Shared/DiagnosticIds/DiagnosticIds.cs +++ b/src/Shared/DiagnosticIds/DiagnosticIds.cs @@ -60,6 +60,7 @@ internal static class Experiments internal const string AIWebSearch = AIExperiments; internal const string AIRealTime = AIExperiments; internal const string AIFiles = AIExperiments; + internal const string AIHostedConversation = AIExperiments; // These diagnostic IDs are defined by the OpenAI package for its experimental APIs. // We use the same IDs so consumers do not need to suppress additional diagnostics diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/HostedConversation/DelegatingHostedConversationClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/HostedConversation/DelegatingHostedConversationClientTests.cs new file mode 100644 index 00000000000..7063872d429 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/HostedConversation/DelegatingHostedConversationClientTests.cs @@ -0,0 +1,315 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#pragma warning disable MEAI001 + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class DelegatingHostedConversationClientTests +{ + [Fact] + public void RequiresInnerClient() + { + Assert.Throws("innerClient", () => new NoOpDelegatingHostedConversationClient(null!)); + } + + [Fact] + public async Task CreateAsyncDefaultsToInnerClientAsync() + { + // Arrange + var expectedOptions = new HostedConversationClientOptions(); + var expectedCancellationToken = CancellationToken.None; + var expectedResult = new TaskCompletionSource(); + var expectedConversation = new HostedConversation { ConversationId = "conv-1" }; + using var inner = new TestHostedConversationClient + { + CreateAsyncCallback = (options, cancellationToken) => + { + Assert.Same(expectedOptions, options); + Assert.Equal(expectedCancellationToken, cancellationToken); + return expectedResult.Task; + } + }; + + using var delegating = new NoOpDelegatingHostedConversationClient(inner); + + // Act + var resultTask = delegating.CreateAsync(expectedOptions, expectedCancellationToken); + + // Assert + Assert.False(resultTask.IsCompleted); + expectedResult.SetResult(expectedConversation); + Assert.True(resultTask.IsCompleted); + Assert.Same(expectedConversation, await resultTask); + } + + [Fact] + public async Task GetAsyncDefaultsToInnerClientAsync() + { + // Arrange + var expectedConversationId = "conv-123"; + var expectedCancellationToken = CancellationToken.None; + var expectedResult = new TaskCompletionSource(); + var expectedConversation = new HostedConversation { ConversationId = expectedConversationId }; + using var inner = new TestHostedConversationClient + { + GetAsyncCallback = (conversationId, options, cancellationToken) => + { + Assert.Equal(expectedConversationId, conversationId); + Assert.Equal(expectedCancellationToken, cancellationToken); + return expectedResult.Task; + } + }; + + using var delegating = new NoOpDelegatingHostedConversationClient(inner); + + // Act + var resultTask = delegating.GetAsync(expectedConversationId, cancellationToken: expectedCancellationToken); + + // Assert + Assert.False(resultTask.IsCompleted); + expectedResult.SetResult(expectedConversation); + Assert.True(resultTask.IsCompleted); + Assert.Same(expectedConversation, await resultTask); + } + + [Fact] + public async Task DeleteAsyncDefaultsToInnerClientAsync() + { + // Arrange + var expectedConversationId = "conv-123"; + var expectedCancellationToken = CancellationToken.None; + var expectedResult = new TaskCompletionSource(); + using var inner = new TestHostedConversationClient + { + DeleteAsyncCallback = (conversationId, options, cancellationToken) => + { + Assert.Equal(expectedConversationId, conversationId); + Assert.Equal(expectedCancellationToken, cancellationToken); + return expectedResult.Task; + } + }; + + using var delegating = new NoOpDelegatingHostedConversationClient(inner); + + // Act + var resultTask = delegating.DeleteAsync(expectedConversationId, cancellationToken: expectedCancellationToken); + + // Assert + Assert.False(resultTask.IsCompleted); + expectedResult.SetResult(true); + Assert.True(resultTask.IsCompleted); + await resultTask; + } + + [Fact] + public async Task AddMessagesAsyncDefaultsToInnerClientAsync() + { + // Arrange + var expectedConversationId = "conv-123"; + var expectedMessages = new List { new(ChatRole.User, "Hello") }; + var expectedCancellationToken = CancellationToken.None; + var expectedResult = new TaskCompletionSource(); + using var inner = new TestHostedConversationClient + { + AddMessagesAsyncCallback = (conversationId, messages, options, cancellationToken) => + { + Assert.Equal(expectedConversationId, conversationId); + Assert.Same(expectedMessages, messages); + Assert.Equal(expectedCancellationToken, cancellationToken); + return expectedResult.Task; + } + }; + + using var delegating = new NoOpDelegatingHostedConversationClient(inner); + + // Act + var resultTask = delegating.AddMessagesAsync(expectedConversationId, expectedMessages, cancellationToken: expectedCancellationToken); + + // Assert + Assert.False(resultTask.IsCompleted); + expectedResult.SetResult(true); + Assert.True(resultTask.IsCompleted); + await resultTask; + } + + [Fact] + public async Task GetMessagesAsyncDefaultsToInnerClientAsync() + { + // Arrange + var expectedConversationId = "conv-123"; + var expectedCancellationToken = CancellationToken.None; + ChatMessage[] expectedMessages = + [ + new(ChatRole.User, "Hello"), + new(ChatRole.Assistant, "Hi"), + ]; + + using var inner = new TestHostedConversationClient + { + GetMessagesAsyncCallback = (conversationId, options, cancellationToken) => + { + Assert.Equal(expectedConversationId, conversationId); + Assert.Equal(expectedCancellationToken, cancellationToken); + return YieldAsync(expectedMessages); + } + }; + + using var delegating = new NoOpDelegatingHostedConversationClient(inner); + + // Act + var resultAsyncEnumerable = delegating.GetMessagesAsync(expectedConversationId, cancellationToken: expectedCancellationToken); + + // Assert + var enumerator = resultAsyncEnumerable.GetAsyncEnumerator(); + Assert.True(await enumerator.MoveNextAsync()); + Assert.Same(expectedMessages[0], enumerator.Current); + Assert.True(await enumerator.MoveNextAsync()); + Assert.Same(expectedMessages[1], enumerator.Current); + Assert.False(await enumerator.MoveNextAsync()); + } + + [Fact] + public void GetServiceThrowsForNullType() + { + using var inner = new TestHostedConversationClient(); + using var delegating = new NoOpDelegatingHostedConversationClient(inner); + Assert.Throws("serviceType", () => delegating.GetService(null!)); + } + + [Fact] + public void GetServiceReturnsSelfIfCompatibleWithRequestAndKeyIsNull() + { + // Arrange + using var inner = new TestHostedConversationClient(); + using var delegating = new NoOpDelegatingHostedConversationClient(inner); + + // Act + var client = delegating.GetService(); + + // Assert + Assert.Same(delegating, client); + } + + [Fact] + public void GetServiceDelegatesToInnerIfKeyIsNotNull() + { + // Arrange + var expectedKey = new object(); + using var expectedResult = new TestHostedConversationClient(); + using var inner = new TestHostedConversationClient + { + GetServiceCallback = (_, _) => expectedResult + }; + using var delegating = new NoOpDelegatingHostedConversationClient(inner); + + // Act + var client = delegating.GetService(expectedKey); + + // Assert + Assert.Same(expectedResult, client); + } + + [Fact] + public void GetServiceDelegatesToInnerIfNotCompatibleWithRequest() + { + // Arrange + var expectedResult = TimeZoneInfo.Local; + var expectedKey = new object(); + using var inner = new TestHostedConversationClient + { + GetServiceCallback = (type, key) => type == expectedResult.GetType() && key == expectedKey + ? expectedResult + : throw new InvalidOperationException("Unexpected call") + }; + using var delegating = new NoOpDelegatingHostedConversationClient(inner); + + // Act + var tzi = delegating.GetService(expectedKey); + + // Assert + Assert.Same(expectedResult, tzi); + } + + [Fact] + public void DisposeDisposesInnerClient() + { + // Arrange + using var inner = new TestHostedConversationClient(); + using var delegating = new NoOpDelegatingHostedConversationClient(inner); + + Assert.False(inner.Disposed); + + // Act + delegating.Dispose(); + + // Assert + Assert.True(inner.Disposed); + } + + private static async IAsyncEnumerable YieldAsync(IEnumerable input) + { + await Task.Yield(); + foreach (var item in input) + { + yield return item; + } + } + + private sealed class NoOpDelegatingHostedConversationClient(IHostedConversationClient innerClient) + : DelegatingHostedConversationClient(innerClient); + + private sealed class TestHostedConversationClient : IHostedConversationClient + { + public TestHostedConversationClient() + { + GetServiceCallback = DefaultGetServiceCallback; + } + + public bool Disposed { get; private set; } + + public Func>? CreateAsyncCallback { get; set; } + + public Func>? GetAsyncCallback { get; set; } + + public Func? DeleteAsyncCallback { get; set; } + + public Func, HostedConversationClientOptions?, CancellationToken, Task>? AddMessagesAsyncCallback { get; set; } + + public Func>? GetMessagesAsyncCallback { get; set; } + + public Func GetServiceCallback { get; set; } + + private object? DefaultGetServiceCallback(Type serviceType, object? serviceKey) => + serviceType is not null && serviceKey is null && serviceType.IsInstanceOfType(this) ? this : null; + + public Task CreateAsync(HostedConversationClientOptions? options = null, CancellationToken cancellationToken = default) + => CreateAsyncCallback!.Invoke(options, cancellationToken); + + public Task GetAsync(string conversationId, HostedConversationClientOptions? options = null, CancellationToken cancellationToken = default) + => GetAsyncCallback!.Invoke(conversationId, options, cancellationToken); + + public Task DeleteAsync(string conversationId, HostedConversationClientOptions? options = null, CancellationToken cancellationToken = default) + => DeleteAsyncCallback!.Invoke(conversationId, options, cancellationToken); + + public Task AddMessagesAsync(string conversationId, IEnumerable messages, HostedConversationClientOptions? options = null, CancellationToken cancellationToken = default) + => AddMessagesAsyncCallback!.Invoke(conversationId, messages, options, cancellationToken); + + public IAsyncEnumerable GetMessagesAsync(string conversationId, HostedConversationClientOptions? options = null, CancellationToken cancellationToken = default) + => GetMessagesAsyncCallback!.Invoke(conversationId, options, cancellationToken); + + public object? GetService(Type serviceType, object? serviceKey = null) + => GetServiceCallback(serviceType, serviceKey); + + public void Dispose() + { + Disposed = true; + } + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/HostedConversation/HostedConversationClientOptionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/HostedConversation/HostedConversationClientOptionsTests.cs new file mode 100644 index 00000000000..d536865f1c8 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/HostedConversation/HostedConversationClientOptionsTests.cs @@ -0,0 +1,165 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#pragma warning disable MEAI001 + +using System; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class HostedConversationClientOptionsTests +{ + [Fact] + public void Constructor_Parameterless_PropsDefaulted() + { + HostedConversationClientOptions options = new(); + Assert.Null(options.Limit); + Assert.Null(options.RawRepresentationFactory); + Assert.Null(options.AdditionalProperties); + + HostedConversationClientOptions clone = options.Clone(); + Assert.Null(clone.Limit); + Assert.Null(clone.RawRepresentationFactory); + Assert.Null(clone.AdditionalProperties); + } + + [Fact] + public void Properties_Roundtrip() + { + HostedConversationClientOptions options = new(); + + Func rawRepresentationFactory = (c) => null; + + AdditionalPropertiesDictionary additionalProps = new() + { + ["key"] = "value", + }; + + options.Limit = 42; + options.RawRepresentationFactory = rawRepresentationFactory; + options.AdditionalProperties = additionalProps; + + Assert.Equal(42, options.Limit); + Assert.Same(rawRepresentationFactory, options.RawRepresentationFactory); + Assert.Same(additionalProps, options.AdditionalProperties); + } + + [Fact] + public void Clone_CreatesIndependentCopy() + { + HostedConversationClientOptions original = new() + { + Limit = 10, + RawRepresentationFactory = (c) => null, + AdditionalProperties = new() { ["prop"] = "val" }, + }; + + HostedConversationClientOptions clone = original.Clone(); + + Assert.NotSame(original, clone); + Assert.Equal(10, clone.Limit); + Assert.NotNull(clone.RawRepresentationFactory); + Assert.NotNull(clone.AdditionalProperties); + } + + [Fact] + public void Clone_CopiesRawRepresentationFactoryByReference() + { + Func factory = (c) => "test"; + + HostedConversationClientOptions original = new() + { + RawRepresentationFactory = factory, + }; + + HostedConversationClientOptions clone = original.Clone(); + + Assert.Same(original.RawRepresentationFactory, clone.RawRepresentationFactory); + } + + [Fact] + public void Clone_DeepCopiesAdditionalProperties() + { + HostedConversationClientOptions original = new() + { + AdditionalProperties = new() { ["key"] = "value" }, + }; + + HostedConversationClientOptions clone = original.Clone(); + + Assert.NotSame(original.AdditionalProperties, clone.AdditionalProperties); + Assert.Equal("value", clone.AdditionalProperties!["key"]); + + // Modifying clone should not affect original + clone.AdditionalProperties["key"] = "modified"; + Assert.Equal("value", original.AdditionalProperties["key"]); + + clone.AdditionalProperties["newkey"] = "newval"; + Assert.False(original.AdditionalProperties.ContainsKey("newkey")); + } + + [Fact] + public void Clone_CopiesLimit() + { + HostedConversationClientOptions original = new() + { + Limit = 25, + }; + + HostedConversationClientOptions clone = original.Clone(); + + Assert.Equal(25, clone.Limit); + + // Modifying clone should not affect original + clone.Limit = 50; + Assert.Equal(25, original.Limit); + } + + [Fact] + public void CopyConstructor_Null_Valid() + { + PassedNullToBaseOptions options = new(); + Assert.NotNull(options); + } + + [Fact] + public void CopyConstructors_EnableHierarchyCloning() + { + DerivedOptions derived = new() + { + Limit = 5, + CustomProperty = 42, + }; + + HostedConversationClientOptions clone = derived.Clone(); + + Assert.Equal(5, clone.Limit); + Assert.Equal(42, Assert.IsType(clone).CustomProperty); + } + + private class PassedNullToBaseOptions : HostedConversationClientOptions + { + public PassedNullToBaseOptions() + : base(null) + { + } + } + + private class DerivedOptions : HostedConversationClientOptions + { + public DerivedOptions() + { + } + + protected DerivedOptions(DerivedOptions other) + : base(other) + { + CustomProperty = other.CustomProperty; + } + + public int CustomProperty { get; set; } + + public override HostedConversationClientOptions Clone() => new DerivedOptions(this); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/HostedConversation/HostedConversationTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/HostedConversation/HostedConversationTests.cs new file mode 100644 index 00000000000..9178550bee7 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/HostedConversation/HostedConversationTests.cs @@ -0,0 +1,80 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#pragma warning disable MEAI001 + +using System; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class HostedConversationTests +{ + [Fact] + public void Constructor_Parameterless_PropsDefaulted() + { + HostedConversation conversation = new(); + Assert.Null(conversation.ConversationId); + Assert.Null(conversation.CreatedAt); + Assert.Null(conversation.RawRepresentation); + Assert.Null(conversation.AdditionalProperties); + } + + [Fact] + public void ConversationId_Roundtrips() + { + HostedConversation conversation = new(); + Assert.Null(conversation.ConversationId); + + conversation.ConversationId = "conv-123"; + Assert.Equal("conv-123", conversation.ConversationId); + + conversation.ConversationId = null; + Assert.Null(conversation.ConversationId); + } + + [Fact] + public void CreatedAt_Roundtrips() + { + HostedConversation conversation = new(); + Assert.Null(conversation.CreatedAt); + + DateTimeOffset createdAt = new(2024, 6, 15, 10, 30, 0, TimeSpan.Zero); + conversation.CreatedAt = createdAt; + Assert.Equal(createdAt, conversation.CreatedAt); + + conversation.CreatedAt = null; + Assert.Null(conversation.CreatedAt); + } + + [Fact] + public void RawRepresentation_Roundtrips() + { + HostedConversation conversation = new(); + Assert.Null(conversation.RawRepresentation); + + object raw = new(); + conversation.RawRepresentation = raw; + Assert.Same(raw, conversation.RawRepresentation); + + conversation.RawRepresentation = null; + Assert.Null(conversation.RawRepresentation); + } + + [Fact] + public void AdditionalProperties_Roundtrips() + { + HostedConversation conversation = new(); + Assert.Null(conversation.AdditionalProperties); + + AdditionalPropertiesDictionary additionalProps = new() + { + ["key"] = "value", + }; + conversation.AdditionalProperties = additionalProps; + Assert.Same(additionalProps, conversation.AdditionalProperties); + + conversation.AdditionalProperties = null; + Assert.Null(conversation.AdditionalProperties); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIHostedConversationClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIHostedConversationClientTests.cs new file mode 100644 index 00000000000..b77bacffb99 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIHostedConversationClientTests.cs @@ -0,0 +1,78 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.ClientModel; +using OpenAI; +using OpenAI.Conversations; +using Xunit; + +#pragma warning disable MEAI001 +#pragma warning disable OPENAI001 + +namespace Microsoft.Extensions.AI; + +public class OpenAIHostedConversationClientTests +{ + [Fact] + public void AsIHostedConversationClient_NullClient_Throws() + { + Assert.Throws("conversationClient", () => ((ConversationClient)null!).AsIHostedConversationClient()); + } + + [Fact] + public void GetService_ReturnsMetadata() + { + Uri endpoint = new("http://localhost/some/endpoint"); + ConversationClient conversationClient = new(new ApiKeyCredential("key"), new OpenAIClientOptions { Endpoint = endpoint }); + + IHostedConversationClient client = conversationClient.AsIHostedConversationClient(); + + var metadata = client.GetService(typeof(HostedConversationClientMetadata)) as HostedConversationClientMetadata; + Assert.NotNull(metadata); + Assert.Equal("openai", metadata.ProviderName); + Assert.Equal(endpoint, metadata.ProviderUri); + } + + [Fact] + public void GetService_ReturnsConversationClient() + { + ConversationClient conversationClient = new(new ApiKeyCredential("key")); + + IHostedConversationClient client = conversationClient.AsIHostedConversationClient(); + + Assert.Same(conversationClient, client.GetService(typeof(ConversationClient))); + } + + [Fact] + public void GetService_ReturnsSelf() + { + ConversationClient conversationClient = new(new ApiKeyCredential("key")); + + IHostedConversationClient client = conversationClient.AsIHostedConversationClient(); + + Assert.Same(client, client.GetService(typeof(IHostedConversationClient))); + } + + [Fact] + public void GetService_ReturnsNull_ForUnknownType() + { + ConversationClient conversationClient = new(new ApiKeyCredential("key")); + + IHostedConversationClient client = conversationClient.AsIHostedConversationClient(); + + Assert.Null(client.GetService(typeof(string))); + } + + [Fact] + public void GetService_ReturnsNull_ForServiceKey() + { + ConversationClient conversationClient = new(new ApiKeyCredential("key")); + + IHostedConversationClient client = conversationClient.AsIHostedConversationClient(); + + Assert.Null(client.GetService(typeof(IHostedConversationClient), "someKey")); + Assert.Null(client.GetService(typeof(HostedConversationClientMetadata), "someKey")); + Assert.Null(client.GetService(typeof(ConversationClient), "someKey")); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/HostedConversation/ConfigureOptionsHostedConversationClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/HostedConversation/ConfigureOptionsHostedConversationClientTests.cs new file mode 100644 index 00000000000..d5063013520 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/HostedConversation/ConfigureOptionsHostedConversationClientTests.cs @@ -0,0 +1,157 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#pragma warning disable MEAI001 + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class ConfigureOptionsHostedConversationClientTests +{ + [Fact] + public async Task ConfigureOptions_CallbackIsCalled() + { + // Arrange + var callbackInvoked = false; + HostedConversationClientOptions? receivedByInner = null; + + using var innerClient = new TestHostedConversationClient + { + CreateAsyncCallback = (options, _) => + { + receivedByInner = options; + return Task.FromResult(new HostedConversation { ConversationId = "conv-1" }); + } + }; + + using var client = new ConfigureOptionsHostedConversationClient(innerClient, options => + { + callbackInvoked = true; + options.AdditionalProperties ??= new(); + options.AdditionalProperties["configured"] = "true"; + }); + + // Act + await client.CreateAsync(new HostedConversationClientOptions()); + + // Assert + Assert.True(callbackInvoked); + Assert.NotNull(receivedByInner); + Assert.Equal("true", receivedByInner!.AdditionalProperties!["configured"]); + } + + [Fact] + public async Task ConfigureOptions_OptionsAreCloned_OriginalNotModified() + { + // Arrange + HostedConversationClientOptions? receivedByInner = null; + var originalOptions = new HostedConversationClientOptions + { + AdditionalProperties = new() { ["key"] = "original" } + }; + + using var innerClient = new TestHostedConversationClient + { + CreateAsyncCallback = (options, _) => + { + receivedByInner = options; + return Task.FromResult(new HostedConversation { ConversationId = "conv-1" }); + } + }; + + using var client = new ConfigureOptionsHostedConversationClient(innerClient, options => + { + options.AdditionalProperties!["key"] = "modified"; + }); + + // Act + await client.CreateAsync(originalOptions); + + // Assert - original should not be modified + Assert.Equal("original", originalOptions.AdditionalProperties["key"]); + + // Assert - inner received modified clone + Assert.NotNull(receivedByInner); + Assert.NotSame(originalOptions, receivedByInner); + Assert.Equal("modified", receivedByInner!.AdditionalProperties!["key"]); + } + + [Fact] + public async Task ConfigureOptions_NullOptions_CreatesNew() + { + // Arrange + HostedConversationClientOptions? receivedByInner = null; + + using var innerClient = new TestHostedConversationClient + { + CreateAsyncCallback = (options, _) => + { + receivedByInner = options; + return Task.FromResult(new HostedConversation { ConversationId = "conv-1" }); + } + }; + + using var client = new ConfigureOptionsHostedConversationClient(innerClient, options => + { + options.AdditionalProperties = new() { ["new"] = "value" }; + }); + + // Act + await client.CreateAsync(options: null); + + // Assert - a new options instance should have been created + Assert.NotNull(receivedByInner); + Assert.Equal("value", receivedByInner!.AdditionalProperties!["new"]); + } + + [Fact] + public void Constructor_NullCallback_Throws() + { + using var innerClient = new TestHostedConversationClient(); + Assert.Throws("configure", () => new ConfigureOptionsHostedConversationClient(innerClient, null!)); + } + + [Fact] + public void Constructor_NullInnerClient_Throws() + { + Assert.Throws("innerClient", () => new ConfigureOptionsHostedConversationClient(null!, _ => { })); + } + + private sealed class TestHostedConversationClient : IHostedConversationClient + { + public Func>? CreateAsyncCallback { get; set; } + + public Task CreateAsync(HostedConversationClientOptions? options = null, CancellationToken cancellationToken = default) + => CreateAsyncCallback?.Invoke(options, cancellationToken) ?? Task.FromResult(new HostedConversation { ConversationId = "test" }); + + public Task GetAsync(string conversationId, HostedConversationClientOptions? options = null, CancellationToken cancellationToken = default) + => Task.FromResult(new HostedConversation { ConversationId = conversationId }); + + public Task DeleteAsync(string conversationId, HostedConversationClientOptions? options = null, CancellationToken cancellationToken = default) + => Task.CompletedTask; + + public Task AddMessagesAsync(string conversationId, IEnumerable messages, HostedConversationClientOptions? options = null, CancellationToken cancellationToken = default) + => Task.CompletedTask; + + public IAsyncEnumerable GetMessagesAsync(string conversationId, HostedConversationClientOptions? options = null, CancellationToken cancellationToken = default) + => EmptyAsync(); + + private static async IAsyncEnumerable EmptyAsync() + { + await Task.CompletedTask; + yield break; + } + + public object? GetService(Type serviceType, object? serviceKey = null) + => serviceType is not null && serviceKey is null && serviceType.IsInstanceOfType(this) ? this : null; + + public void Dispose() + { + } + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/HostedConversation/HostedConversationClientBuilderTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/HostedConversation/HostedConversationClientBuilderTest.cs new file mode 100644 index 00000000000..d3fce6431eb --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/HostedConversation/HostedConversationClientBuilderTest.cs @@ -0,0 +1,152 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#pragma warning disable MEAI001 + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class HostedConversationClientBuilderTest +{ + [Fact] + public void Build_WithSimpleInnerClient_Works() + { + // Arrange + using var innerClient = new TestHostedConversationClient(); + var builder = new HostedConversationClientBuilder(innerClient); + + // Act + var result = builder.Build(); + + // Assert + Assert.Same(innerClient, result); + } + + [Fact] + public void Use_AddsMiddlewareInCorrectOrder() + { + // Arrange + using var innerClient = new TestHostedConversationClient(); + var builder = new HostedConversationClientBuilder(innerClient); + + builder.Use(next => new NamedDelegatingHostedConversationClient("First", next)); + builder.Use(next => new NamedDelegatingHostedConversationClient("Second", next)); + builder.Use(next => new NamedDelegatingHostedConversationClient("Third", next)); + + // Act + var first = (NamedDelegatingHostedConversationClient)builder.Build(); + + // Assert - outermost is first added + Assert.Equal("First", first.Name); + var second = (NamedDelegatingHostedConversationClient)first.GetInnerClient(); + Assert.Equal("Second", second.Name); + var third = (NamedDelegatingHostedConversationClient)second.GetInnerClient(); + Assert.Equal("Third", third.Name); + } + + [Fact] + public void Build_WithIServiceProvider_PassesServiceProvider() + { + // Arrange + var expectedServiceProvider = new ServiceCollection().BuildServiceProvider(); + using var expectedInnerClient = new TestHostedConversationClient(); + using var expectedOuterClient = new TestHostedConversationClient(); + + var builder = new HostedConversationClientBuilder(services => + { + Assert.Same(expectedServiceProvider, services); + return expectedInnerClient; + }); + + builder.Use((innerClient, serviceProvider) => + { + Assert.Same(expectedServiceProvider, serviceProvider); + Assert.Same(expectedInnerClient, innerClient); + return expectedOuterClient; + }); + + // Act & Assert + Assert.Same(expectedOuterClient, builder.Build(expectedServiceProvider)); + } + + [Fact] + public void Constructor_NullInnerClient_Throws() + { + Assert.Throws("innerClient", () => new HostedConversationClientBuilder((IHostedConversationClient)null!)); + } + + [Fact] + public void Constructor_NullFactory_Throws() + { + Assert.Throws("innerClientFactory", () => new HostedConversationClientBuilder((Func)null!)); + } + + [Fact] + public void Use_NullFactory_Throws() + { + using var innerClient = new TestHostedConversationClient(); + var builder = new HostedConversationClientBuilder(innerClient); + Assert.Throws("clientFactory", () => builder.Use((Func)null!)); + Assert.Throws("clientFactory", () => builder.Use((Func)null!)); + } + + [Fact] + public void Build_FactoryReturnsNull_Throws() + { + using var innerClient = new TestHostedConversationClient(); + var builder = new HostedConversationClientBuilder(innerClient); + builder.Use(_ => null!); + var ex = Assert.Throws(() => builder.Build()); + Assert.Contains("entry at index 0", ex.Message); + } + + private sealed class NamedDelegatingHostedConversationClient : DelegatingHostedConversationClient + { + public NamedDelegatingHostedConversationClient(string name, IHostedConversationClient innerClient) + : base(innerClient) + { + Name = name; + } + + public string Name { get; } + + public IHostedConversationClient GetInnerClient() => InnerClient; + } + + private sealed class TestHostedConversationClient : IHostedConversationClient + { + public Task CreateAsync(HostedConversationClientOptions? options = null, CancellationToken cancellationToken = default) + => Task.FromResult(new HostedConversation { ConversationId = "test" }); + + public Task GetAsync(string conversationId, HostedConversationClientOptions? options = null, CancellationToken cancellationToken = default) + => Task.FromResult(new HostedConversation { ConversationId = conversationId }); + + public Task DeleteAsync(string conversationId, HostedConversationClientOptions? options = null, CancellationToken cancellationToken = default) + => Task.CompletedTask; + + public Task AddMessagesAsync(string conversationId, IEnumerable messages, HostedConversationClientOptions? options = null, CancellationToken cancellationToken = default) + => Task.CompletedTask; + + public IAsyncEnumerable GetMessagesAsync(string conversationId, HostedConversationClientOptions? options = null, CancellationToken cancellationToken = default) + => EmptyAsync(); + + private static async IAsyncEnumerable EmptyAsync() + { + await Task.CompletedTask; + yield break; + } + + public object? GetService(Type serviceType, object? serviceKey = null) + => serviceType is not null && serviceKey is null && serviceType.IsInstanceOfType(this) ? this : null; + + public void Dispose() + { + } + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/HostedConversation/LoggingHostedConversationClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/HostedConversation/LoggingHostedConversationClientTests.cs new file mode 100644 index 00000000000..32fb7cbd79c --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/HostedConversation/LoggingHostedConversationClientTests.cs @@ -0,0 +1,205 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#pragma warning disable MEAI001 + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Logging.Testing; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class LoggingHostedConversationClientTests +{ + [Fact] + public void Constructor_InvalidArgs_Throws() + { + using var innerClient = new TestHostedConversationClient(); + Assert.Throws("innerClient", () => new LoggingHostedConversationClient(null!, NullLogger.Instance)); + Assert.Throws("logger", () => new LoggingHostedConversationClient(innerClient, null!)); + } + + [Fact] + public void Constructor_ValidArgs_CreatesWithoutError() + { + using var innerClient = new TestHostedConversationClient(); + using var client = new LoggingHostedConversationClient(innerClient, NullLogger.Instance); + Assert.NotNull(client); + } + + [Theory] + [InlineData(LogLevel.Trace)] + [InlineData(LogLevel.Debug)] + [InlineData(LogLevel.Information)] + public async Task CreateAsync_LogsInvocationAndCompletion(LogLevel level) + { + // Arrange + var collector = new FakeLogCollector(); + ServiceCollection c = new(); + c.AddLogging(b => b.AddProvider(new FakeLoggerProvider(collector)).SetMinimumLevel(level)); + var services = c.BuildServiceProvider(); + + using var innerClient = new TestHostedConversationClient + { + CreateAsyncCallback = (_, _) => + Task.FromResult(new HostedConversation { ConversationId = "conv-1" }) + }; + + var builder = new HostedConversationClientBuilder(innerClient); + builder.UseLogging(services.GetRequiredService()); + using var client = builder.Build(services); + + // Act + await client.CreateAsync(new HostedConversationClientOptions()); + + // Assert + var logs = collector.GetSnapshot(); + if (level is LogLevel.Trace) + { + Assert.True(logs.Count >= 2); + Assert.Contains(logs, e => e.Message.Contains("CreateAsync") && e.Message.Contains("invoked")); + Assert.Contains(logs, e => e.Message.Contains("CreateAsync") && e.Message.Contains("completed")); + } + else if (level is LogLevel.Debug) + { + Assert.True(logs.Count >= 2); + Assert.Contains(logs, e => e.Message.Contains("CreateAsync") && e.Message.Contains("invoked")); + Assert.Contains(logs, e => e.Message.Contains("CreateAsync") && e.Message.Contains("completed")); + } + else + { + Assert.Empty(logs); + } + } + + [Theory] + [InlineData(LogLevel.Trace)] + [InlineData(LogLevel.Debug)] + [InlineData(LogLevel.Information)] + public async Task GetAsync_LogsInvocationAndCompletion(LogLevel level) + { + // Arrange + var collector = new FakeLogCollector(); + ServiceCollection c = new(); + c.AddLogging(b => b.AddProvider(new FakeLoggerProvider(collector)).SetMinimumLevel(level)); + var services = c.BuildServiceProvider(); + + using var innerClient = new TestHostedConversationClient + { + GetAsyncCallback = (id, opts, _) => + Task.FromResult(new HostedConversation { ConversationId = id }) + }; + + var builder = new HostedConversationClientBuilder(innerClient); + builder.UseLogging(services.GetRequiredService()); + using var client = builder.Build(services); + + // Act + await client.GetAsync("conv-42"); + + // Assert + var logs = collector.GetSnapshot(); + if (level <= LogLevel.Debug) + { + Assert.True(logs.Count >= 2); + Assert.Contains(logs, e => e.Message.Contains("GetAsync") && e.Message.Contains("conv-42")); + Assert.Contains(logs, e => e.Message.Contains("GetAsync") && e.Message.Contains("completed")); + } + else + { + Assert.Empty(logs); + } + } + + [Theory] + [InlineData(LogLevel.Trace)] + [InlineData(LogLevel.Debug)] + [InlineData(LogLevel.Information)] + public async Task DeleteAsync_LogsInvocationAndCompletion(LogLevel level) + { + // Arrange + var collector = new FakeLogCollector(); + ServiceCollection c = new(); + c.AddLogging(b => b.AddProvider(new FakeLoggerProvider(collector)).SetMinimumLevel(level)); + var services = c.BuildServiceProvider(); + + using var innerClient = new TestHostedConversationClient + { + DeleteAsyncCallback = (_, opts, _) => Task.CompletedTask + }; + + var builder = new HostedConversationClientBuilder(innerClient); + builder.UseLogging(services.GetRequiredService()); + using var client = builder.Build(services); + + // Act + await client.DeleteAsync("conv-99"); + + // Assert + var logs = collector.GetSnapshot(); + if (level <= LogLevel.Debug) + { + Assert.True(logs.Count >= 2); + Assert.Contains(logs, e => e.Message.Contains("DeleteAsync") && e.Message.Contains("conv-99")); + Assert.Contains(logs, e => e.Message.Contains("DeleteAsync") && e.Message.Contains("completed")); + } + else + { + Assert.Empty(logs); + } + } + + [Fact] + public void UseLogging_AvoidsInjectingNopClient() + { + using var innerClient = new TestHostedConversationClient(); + + var builder = new HostedConversationClientBuilder(innerClient); + builder.UseLogging(NullLoggerFactory.Instance); + using var built = builder.Build(); + + // When NullLoggerFactory is used, LoggingHostedConversationClient should be skipped + Assert.Null(built.GetService(typeof(LoggingHostedConversationClient))); + } + + private sealed class TestHostedConversationClient : IHostedConversationClient + { + public Func>? CreateAsyncCallback { get; set; } + public Func>? GetAsyncCallback { get; set; } + public Func? DeleteAsyncCallback { get; set; } + + public Task CreateAsync(HostedConversationClientOptions? options = null, CancellationToken cancellationToken = default) + => CreateAsyncCallback?.Invoke(options, cancellationToken) ?? Task.FromResult(new HostedConversation { ConversationId = "test" }); + + public Task GetAsync(string conversationId, HostedConversationClientOptions? options = null, CancellationToken cancellationToken = default) + => GetAsyncCallback?.Invoke(conversationId, options, cancellationToken) ?? Task.FromResult(new HostedConversation { ConversationId = conversationId }); + + public Task DeleteAsync(string conversationId, HostedConversationClientOptions? options = null, CancellationToken cancellationToken = default) + => DeleteAsyncCallback?.Invoke(conversationId, options, cancellationToken) ?? Task.CompletedTask; + + public Task AddMessagesAsync(string conversationId, IEnumerable messages, HostedConversationClientOptions? options = null, CancellationToken cancellationToken = default) + => Task.CompletedTask; + + public IAsyncEnumerable GetMessagesAsync(string conversationId, HostedConversationClientOptions? options = null, CancellationToken cancellationToken = default) + => EmptyAsync(); + + private static async IAsyncEnumerable EmptyAsync() + { + await Task.CompletedTask; + yield break; + } + + public object? GetService(Type serviceType, object? serviceKey = null) + => serviceType is not null && serviceKey is null && serviceType.IsInstanceOfType(this) ? this : null; + + public void Dispose() + { + } + } +}