diff --git a/docs/auth/byok.md b/docs/auth/byok.md index a4a131913..b4718c19a 100644 --- a/docs/auth/byok.md +++ b/docs/auth/byok.md @@ -175,6 +175,7 @@ Console.WriteLine(response?.Data.Content); | `apiKey` / `api_key` | string | API key (optional for local providers like Ollama) | | `bearerToken` / `bearer_token` | string | Bearer token auth (takes precedence over apiKey) | | `wireApi` / `wire_api` | `"completions"` \| `"responses"` | API format (default: `"completions"`) | +| `headers` | `Record` | Custom HTTP headers for all outbound requests ([details](#custom-headers)) | | `azure.apiVersion` / `azure.api_version` | string | Azure API version (default: `"2024-10-21"`) | ### Wire API Format @@ -304,6 +305,327 @@ provider: { > **Note:** The `bearerToken` option accepts a **static token string** only. The SDK does not refresh this token automatically. If your token expires, requests will fail and you'll need to create a new session with a fresh token. +## Custom Headers + +Custom headers let you attach additional HTTP headers to every outbound model request. This is useful when your provider endpoint sits behind an API gateway or proxy that requires extra authentication or routing headers. + +### Use Cases + +| Scenario | Example Header | +|----------|---------------| +| Azure API Management / AI Gateway | `Ocp-Apim-Subscription-Key` | +| Cloudflare Tunnel authentication | `CF-Access-Client-Id`, `CF-Access-Client-Secret` | +| Custom API gateways with proprietary auth | `X-Gateway-Auth`, `X-Tenant-Id` | +| BYOK routing through enterprise proxies | `X-Proxy-Authorization`, `X-Route-Target` | + +### Session-Level Headers + +Set `headers` on `ProviderConfig` when creating a session. These headers are included in **every** outbound request for the lifetime of the session. + +
+Node.js / TypeScript + +```typescript +import { CopilotClient } from "@github/copilot-sdk"; + +const client = new CopilotClient(); +const session = await client.createSession({ + model: "gpt-4.1", + provider: { + type: "openai", + baseUrl: "https://my-gateway.example.com/v1", + apiKey: process.env.OPENAI_API_KEY, + headers: { + "Ocp-Apim-Subscription-Key": process.env.APIM_KEY!, + "X-Tenant-Id": "my-team", + }, + }, +}); +``` + +
+ +
+Python + +```python +import os +from copilot import CopilotClient + +client = CopilotClient() +await client.start() + +session = await client.create_session( + model="gpt-4.1", + provider={ + "type": "openai", + "base_url": "https://my-gateway.example.com/v1", + "api_key": os.environ["OPENAI_API_KEY"], + "headers": { + "Ocp-Apim-Subscription-Key": os.environ["APIM_KEY"], + "X-Tenant-Id": "my-team", + }, + }, +) +``` + +
+ +
+Go + +```go +session, err := client.CreateSession(ctx, &copilot.SessionConfig{ + Model: "gpt-4.1", + Provider: &copilot.ProviderConfig{ + Type: "openai", + BaseURL: "https://my-gateway.example.com/v1", + APIKey: os.Getenv("OPENAI_API_KEY"), + Headers: map[string]string{ + "Ocp-Apim-Subscription-Key": os.Getenv("APIM_KEY"), + "X-Tenant-Id": "my-team", + }, + }, +}) +``` + +
+ +
+.NET + +```csharp +var session = await client.CreateSessionAsync(new SessionConfig +{ + Model = "gpt-4.1", + Provider = new ProviderConfig + { + Type = "openai", + BaseUrl = "https://my-gateway.example.com/v1", + ApiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY"), + Headers = new Dictionary + { + ["Ocp-Apim-Subscription-Key"] = Environment.GetEnvironmentVariable("APIM_KEY")!, + ["X-Tenant-Id"] = "my-team", + }, + }, +}); +``` + +
+ +### Per-Turn Headers + +Pass `requestHeaders` on `send()` to include headers for a **single turn** only. This is useful when headers change between requests (e.g., per-request trace IDs or rotating tokens). + +
+Node.js / TypeScript + +```typescript +await session.send({ + prompt: "Summarize this document", + requestHeaders: { + "X-Request-Id": crypto.randomUUID(), + }, +}); +``` + +
+ +
+Python + +```python +import uuid + +await session.send( + "Summarize this document", + request_headers={ + "X-Request-Id": str(uuid.uuid4()), + }, +) +``` + +
+ +
+Go + +```go +_, err := session.Send(ctx, copilot.MessageOptions{ + Prompt: "Summarize this document", + RequestHeaders: map[string]string{ + "X-Request-Id": uuid.NewString(), + }, +}) +``` + +
+ +
+.NET + +```csharp +await session.SendAsync(new MessageOptions +{ + Prompt = "Summarize this document", + RequestHeaders = new Dictionary + { + ["X-Request-Id"] = Guid.NewGuid().ToString(), + }, +}); +``` + +
+ +### Header Merge Strategy + +When you provide both session-level `headers` and per-turn `requestHeaders`, the `headerMergeStrategy` controls how they combine. + +| Strategy | Behavior | +|----------|----------| +| `"override"` (default) | Per-turn headers **completely replace** session-level headers. No session headers are sent for that turn. This is the safest default — no unexpected header leakage. | +| `"merge"` | Per-turn headers are **merged** with session-level headers. Per-turn values win on key conflicts. | + +#### Override (Default) + +```typescript +// Session created with headers: { "X-Team": "alpha", "X-Env": "prod" } + +await session.send({ + prompt: "Hello", + requestHeaders: { "X-Request-Id": "abc-123" }, + // headerMergeStrategy defaults to "override" +}); +// Only "X-Request-Id" is sent — session headers are NOT included +``` + +#### Merge + +```typescript +// Session created with headers: { "X-Team": "alpha", "X-Env": "prod" } + +await session.send({ + prompt: "Hello", + requestHeaders: { "X-Env": "staging", "X-Request-Id": "abc-123" }, + headerMergeStrategy: "merge", +}); +// Sent headers: { "X-Team": "alpha", "X-Env": "staging", "X-Request-Id": "abc-123" } +// "X-Env" from per-turn wins over session-level value +``` + +The merge strategy setting is available in all languages: + +| Language | Field | +|----------|-------| +| TypeScript | `headerMergeStrategy: "override" \| "merge"` | +| Python | `header_merge_strategy: Literal["override", "merge"]` | +| Go | `HeaderMergeStrategy: copilot.HeaderMergeStrategyOverride \| copilot.HeaderMergeStrategyMerge` | +| C# | `HeaderMergeStrategy = HeaderMergeStrategy.Override \| HeaderMergeStrategy.Merge` | + +### Updating Provider Configuration Mid-Session + +Use `updateProvider()` to change provider configuration — including headers — between turns without recreating the session. This is useful for rotating API keys, switching tenants, or adjusting gateway headers on the fly. + +
+Node.js / TypeScript + +```typescript +// Rotate the subscription key between turns +await session.updateProvider({ + headers: { + "Ocp-Apim-Subscription-Key": newSubscriptionKey, + "X-Tenant-Id": "new-team", + }, +}); + +// Subsequent sends use the updated headers +await session.send({ prompt: "Continue" }); +``` + +
+ +
+Python + +```python +await session.update_provider({ + "headers": { + "Ocp-Apim-Subscription-Key": new_subscription_key, + "X-Tenant-Id": "new-team", + }, +}) + +await session.send("Continue") +``` + +
+ +
+Go + +```go +err := session.UpdateProvider(ctx, copilot.ProviderConfig{ + Headers: map[string]string{ + "Ocp-Apim-Subscription-Key": newSubscriptionKey, + "X-Tenant-Id": "new-team", + }, +}) + +_, err = session.Send(ctx, copilot.MessageOptions{Prompt: "Continue"}) +``` + +
+ +
+.NET + +```csharp +await session.UpdateProviderAsync(new ProviderConfig +{ + Headers = new Dictionary + { + ["Ocp-Apim-Subscription-Key"] = newSubscriptionKey, + ["X-Tenant-Id"] = "new-team", + }, +}); + +await session.SendAsync(new MessageOptions { Prompt = "Continue" }); +``` + +
+ +### Environment Variable Expansion + +Header values support environment variable expansion at the runtime level. This lets you reference secrets without hardcoding them in your application code. + +| Syntax | Behavior | +|--------|----------| +| `${VAR}` | Replaced with the value of `VAR`. Fails if `VAR` is not set. | +| `$VAR` | Same as `${VAR}`. | +| `${VAR:-default}` | Replaced with the value of `VAR`, or `default` if `VAR` is not set. | + +```typescript +provider: { + type: "openai", + baseUrl: "https://my-gateway.example.com/v1", + headers: { + // Expanded at runtime from the APIM_KEY environment variable + "Ocp-Apim-Subscription-Key": "${APIM_KEY}", + // Falls back to "default-tenant" if X_TENANT is not set + "X-Tenant-Id": "${X_TENANT:-default-tenant}", + }, +} +``` + +> **Note:** Expansion is performed by the CLI server, not the SDK client. The SDK passes header values as-is to the server, which resolves environment variables before sending requests to your provider. + +### Security Considerations + +- **Scoped to your endpoint** — Custom headers are sent only to the configured `baseUrl`. They are never sent to GitHub Copilot servers or other endpoints. +- **Prefer env var expansion** — Use `${VAR}` syntax for sensitive values like API keys and tokens rather than hardcoding them. This avoids secrets in source code and logs. +- **Override is the safe default** — The default `headerMergeStrategy` of `"override"` ensures per-turn headers completely replace session-level headers, preventing accidental leakage of session headers into turns that specify their own. + ## Custom Model Listing When using BYOK, the CLI server may not know which models your provider supports. You can supply a custom `onListModels` handler at the client level so that `client.listModels()` returns your provider's models in the standard `ModelInfo` format. This lets downstream consumers discover available models without querying the CLI. diff --git a/dotnet/README.md b/dotnet/README.md index 0f67fb11a..02c35a220 100644 --- a/dotnet/README.md +++ b/dotnet/README.md @@ -596,6 +596,29 @@ var session = await client.CreateSessionAsync(new SessionConfig }); ``` +### Custom Headers + +You can attach custom HTTP headers to outbound model requests — useful for API gateways, proxy authentication, or tenant routing: + +```csharp +var session = await client.CreateSessionAsync(new SessionConfig +{ + Model = "gpt-4.1", + Provider = new ProviderConfig + { + Type = "openai", + BaseUrl = "https://my-gateway.example.com/v1", + ApiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY"), + Headers = new Dictionary + { + ["Ocp-Apim-Subscription-Key"] = "${APIM_KEY}", + }, + }, +}); +``` + +Per-turn headers and merge strategies are also supported. See the [Custom Headers](docs/auth/byok.md#custom-headers) section in the BYOK guide for full details. + ## Telemetry The SDK supports OpenTelemetry for distributed tracing. Provide a `Telemetry` config to enable trace export and automatic W3C Trace Context propagation. diff --git a/dotnet/src/Session.cs b/dotnet/src/Session.cs index 675a3e0c0..7e32e7716 100644 --- a/dotnet/src/Session.cs +++ b/dotnet/src/Session.cs @@ -162,6 +162,8 @@ public async Task SendAsync(MessageOptions options, CancellationToken ca Prompt = options.Prompt, Attachments = options.Attachments, Mode = options.Mode, + RequestHeaders = options.RequestHeaders, + HeaderMergeStrategy = options.HeaderMergeStrategy, Traceparent = traceparent, Tracestate = tracestate }; @@ -783,6 +785,20 @@ await InvokeRpcAsync( "session.abort", [new SessionAbortRequest { SessionId = SessionId }], cancellationToken); } + /// + /// Update the provider configuration for this session. + /// This allows changing headers, authentication, or other provider settings between turns. + /// + /// The provider configuration to update. + /// Optional cancellation token. + public async Task UpdateProviderAsync(ProviderConfig provider, CancellationToken cancellationToken = default) + { + await InvokeRpcAsync( + "session.provider.update", + [new SessionProviderUpdateRequest { SessionId = SessionId, Provider = provider }], + cancellationToken); + } + /// /// Changes the model for this session. /// The new model takes effect for the next message. Conversation history is preserved. @@ -906,6 +922,8 @@ internal record SendMessageRequest public string Prompt { get; init; } = string.Empty; public List? Attachments { get; init; } public string? Mode { get; init; } + public Dictionary? RequestHeaders { get; init; } + public string? HeaderMergeStrategy { get; init; } public string? Traceparent { get; init; } public string? Tracestate { get; init; } } @@ -935,6 +953,12 @@ internal record SessionDestroyRequest public string SessionId { get; init; } = string.Empty; } + internal record SessionProviderUpdateRequest + { + public string SessionId { get; init; } = string.Empty; + public ProviderConfig Provider { get; init; } = new(); + } + [JsonSourceGenerationOptions( JsonSerializerDefaults.Web, AllowOutOfOrderMetadataProperties = true, @@ -946,6 +970,7 @@ internal record SessionDestroyRequest [JsonSerializable(typeof(SendMessageResponse))] [JsonSerializable(typeof(SessionAbortRequest))] [JsonSerializable(typeof(SessionDestroyRequest))] + [JsonSerializable(typeof(SessionProviderUpdateRequest))] [JsonSerializable(typeof(UserMessageDataAttachmentsItem))] [JsonSerializable(typeof(PreToolUseHookInput))] [JsonSerializable(typeof(PreToolUseHookOutput))] diff --git a/dotnet/src/Types.cs b/dotnet/src/Types.cs index d6530f9c7..4c96e9c8b 100644 --- a/dotnet/src/Types.cs +++ b/dotnet/src/Types.cs @@ -1110,6 +1110,13 @@ public class ProviderConfig [JsonPropertyName("bearerToken")] public string? BearerToken { get; set; } + /// + /// Custom HTTP headers to include in all outbound requests to the provider. + /// Supports env var expansion (e.g. ${VAR}, ${VAR:-default}). + /// + [JsonPropertyName("headers")] + public Dictionary? Headers { get; set; } + /// /// Azure-specific configuration options. /// @@ -1129,6 +1136,17 @@ public class AzureOptions public string? ApiVersion { get; set; } } +/// +/// Strategy for merging per-turn request headers with session-level provider headers. +/// +public static class HeaderMergeStrategy +{ + /// Per-turn headers completely replace session-level headers. + public const string Override = "override"; + /// Per-turn headers are merged with session-level headers; per-turn wins on conflicts. + public const string Merge = "merge"; +} + // ============================================================================ // MCP Server Configuration Types // ============================================================================ @@ -1686,6 +1704,8 @@ protected MessageOptions(MessageOptions? other) Attachments = other.Attachments is not null ? [.. other.Attachments] : null; Mode = other.Mode; Prompt = other.Prompt; + RequestHeaders = other.RequestHeaders is not null ? new(other.RequestHeaders) : null; + HeaderMergeStrategy = other.HeaderMergeStrategy; } /// @@ -1700,6 +1720,17 @@ protected MessageOptions(MessageOptions? other) /// Interaction mode for the message (e.g., "plan", "edit"). /// public string? Mode { get; set; } + /// + /// Custom HTTP headers to include in outbound model requests for this turn only. + /// + [JsonPropertyName("requestHeaders")] + public Dictionary? RequestHeaders { get; set; } + /// + /// Strategy for merging per-turn request headers with session-level provider headers. + /// Defaults to "override". + /// + [JsonPropertyName("headerMergeStrategy")] + public string? HeaderMergeStrategy { get; set; } /// /// Creates a shallow clone of this instance. diff --git a/dotnet/test/CustomHeadersTests.cs b/dotnet/test/CustomHeadersTests.cs new file mode 100644 index 000000000..93d34c984 --- /dev/null +++ b/dotnet/test/CustomHeadersTests.cs @@ -0,0 +1,182 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +using Xunit; + +namespace GitHub.Copilot.SDK.Test; + +/// +/// Unit tests for custom headers support in MessageOptions, ProviderConfig, +/// HeaderMergeStrategy, and related serialization. +/// +public class CustomHeadersTests +{ + [Fact] + public void ProviderConfig_Headers_CanBeSet() + { + var config = new ProviderConfig + { + BaseUrl = "https://api.example.com", + ApiKey = "test-key", + Headers = new Dictionary + { + ["X-Custom"] = "value", + ["Authorization"] = "Bearer tok", + }, + }; + + Assert.Equal("value", config.Headers["X-Custom"]); + Assert.Equal("Bearer tok", config.Headers["Authorization"]); + } + + [Fact] + public void ProviderConfig_Headers_NullByDefault() + { + var config = new ProviderConfig(); + Assert.Null(config.Headers); + } + + [Fact] + public void ProviderConfig_Headers_CanBeEmpty() + { + var config = new ProviderConfig + { + Headers = new Dictionary(), + }; + + Assert.NotNull(config.Headers); + Assert.Empty(config.Headers); + } + + [Fact] + public void HeaderMergeStrategy_HasExpectedValues() + { + Assert.Equal("override", HeaderMergeStrategy.Override); + Assert.Equal("merge", HeaderMergeStrategy.Merge); + } + + [Fact] + public void MessageOptions_RequestHeaders_CanBeSet() + { + var options = new MessageOptions + { + Prompt = "test", + RequestHeaders = new Dictionary + { + ["X-Custom"] = "value", + ["X-Another"] = "other", + }, + }; + + Assert.Equal("value", options.RequestHeaders["X-Custom"]); + Assert.Equal("other", options.RequestHeaders["X-Another"]); + } + + [Fact] + public void MessageOptions_HeaderMergeStrategy_Override() + { + var options = new MessageOptions + { + Prompt = "test", + HeaderMergeStrategy = HeaderMergeStrategy.Override, + }; + + Assert.Equal("override", options.HeaderMergeStrategy); + } + + [Fact] + public void MessageOptions_HeaderMergeStrategy_Merge() + { + var options = new MessageOptions + { + Prompt = "test", + HeaderMergeStrategy = HeaderMergeStrategy.Merge, + }; + + Assert.Equal("merge", options.HeaderMergeStrategy); + } + + [Fact] + public void MessageOptions_RequestHeaders_NullByDefault() + { + var options = new MessageOptions { Prompt = "test" }; + Assert.Null(options.RequestHeaders); + Assert.Null(options.HeaderMergeStrategy); + } + + [Fact] + public void MessageOptions_Clone_CopiesHeaders() + { + var original = new MessageOptions + { + Prompt = "test", + RequestHeaders = new Dictionary + { + ["X-Custom"] = "value", + }, + HeaderMergeStrategy = HeaderMergeStrategy.Merge, + }; + + var clone = original.Clone(); + + Assert.Equal(original.Prompt, clone.Prompt); + Assert.Equal(original.HeaderMergeStrategy, clone.HeaderMergeStrategy); + Assert.NotNull(clone.RequestHeaders); + Assert.Equal("value", clone.RequestHeaders["X-Custom"]); + } + + [Fact] + public void MessageOptions_Clone_HeadersAreIndependent() + { + var original = new MessageOptions + { + Prompt = "test", + RequestHeaders = new Dictionary + { + ["X-Custom"] = "value", + }, + }; + + var clone = original.Clone(); + + clone.RequestHeaders!["X-New"] = "added"; + + Assert.False(original.RequestHeaders.ContainsKey("X-New")); + Assert.Single(original.RequestHeaders); + } + + [Fact] + public void MessageOptions_Clone_NullHeaders_StayNull() + { + var original = new MessageOptions + { + Prompt = "test", + }; + + var clone = original.Clone(); + + Assert.Null(clone.RequestHeaders); + Assert.Null(clone.HeaderMergeStrategy); + } + + [Fact] + public void MessageOptions_WithAllHeaderFields() + { + var options = new MessageOptions + { + Prompt = "hello", + RequestHeaders = new Dictionary + { + ["X-Request-Id"] = "req-123", + }, + HeaderMergeStrategy = HeaderMergeStrategy.Merge, + Mode = "enqueue", + }; + + Assert.Equal("hello", options.Prompt); + Assert.Equal("req-123", options.RequestHeaders["X-Request-Id"]); + Assert.Equal("merge", options.HeaderMergeStrategy); + Assert.Equal("enqueue", options.Mode); + } +} diff --git a/go/README.md b/go/README.md index f29ef9fb7..7e2101764 100644 --- a/go/README.md +++ b/go/README.md @@ -539,6 +539,26 @@ session, err := client.CreateSession(context.Background(), &copilot.SessionConfi > - For Azure OpenAI endpoints (`*.openai.azure.com`), you **must** use `Type: "azure"`, not `Type: "openai"`. > - The `BaseURL` should be just the host (e.g., `https://my-resource.openai.azure.com`). Do **not** include `/openai/v1` in the URL - the SDK handles path construction automatically. +### Custom Headers + +You can attach custom HTTP headers to outbound model requests — useful for API gateways, proxy authentication, or tenant routing: + +```go +session, err := client.CreateSession(ctx, &copilot.SessionConfig{ + Model: "gpt-4.1", + Provider: &copilot.ProviderConfig{ + Type: "openai", + BaseURL: "https://my-gateway.example.com/v1", + APIKey: os.Getenv("OPENAI_API_KEY"), + Headers: map[string]string{ + "Ocp-Apim-Subscription-Key": "${APIM_KEY}", + }, + }, +}) +``` + +Per-turn headers and merge strategies are also supported. See the [Custom Headers](docs/auth/byok.md#custom-headers) section in the BYOK guide for full details. + ## Telemetry The SDK supports OpenTelemetry for distributed tracing. Provide a `Telemetry` config to enable trace export and automatic W3C Trace Context propagation. diff --git a/go/custom_headers_test.go b/go/custom_headers_test.go new file mode 100644 index 000000000..1e4d16d00 --- /dev/null +++ b/go/custom_headers_test.go @@ -0,0 +1,256 @@ +package copilot + +import ( + "encoding/json" + "testing" +) + +func TestProviderConfig_Headers_JSONRoundTrip(t *testing.T) { + t.Run("serializes headers", func(t *testing.T) { + config := ProviderConfig{ + BaseURL: "https://api.example.com", + APIKey: "test-key", + Headers: map[string]string{ + "X-Custom": "value", + "Authorization": "Bearer tok", + }, + } + + data, err := json.Marshal(config) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var decoded ProviderConfig + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if decoded.Headers["X-Custom"] != "value" { + t.Errorf("expected X-Custom=value, got %q", decoded.Headers["X-Custom"]) + } + if decoded.Headers["Authorization"] != "Bearer tok" { + t.Errorf("expected Authorization=Bearer tok, got %q", decoded.Headers["Authorization"]) + } + }) + + t.Run("omits headers when nil", func(t *testing.T) { + config := ProviderConfig{ + BaseURL: "https://api.example.com", + } + + data, err := json.Marshal(config) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + jsonStr := string(data) + if contains(jsonStr, "headers") { + t.Errorf("expected headers to be omitted, got %s", jsonStr) + } + }) + + t.Run("omits empty headers with omitempty", func(t *testing.T) { + config := ProviderConfig{ + BaseURL: "https://api.example.com", + Headers: map[string]string{}, + } + + data, err := json.Marshal(config) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + jsonStr := string(data) + // Go's omitempty omits empty maps + if contains(jsonStr, `"headers"`) { + t.Errorf("expected empty headers to be omitted with omitempty, got %s", jsonStr) + } + }) +} + +func TestHeaderMergeStrategy_Constants(t *testing.T) { + tests := []struct { + name string + strategy HeaderMergeStrategy + expected string + }{ + {"Override", HeaderMergeStrategyOverride, "override"}, + {"Merge", HeaderMergeStrategyMerge, "merge"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if string(tt.strategy) != tt.expected { + t.Errorf("expected %q, got %q", tt.expected, string(tt.strategy)) + } + }) + } +} + +func TestSessionSendRequest_Headers_JSONSerialization(t *testing.T) { + t.Run("includes requestHeaders and headerMergeStrategy", func(t *testing.T) { + req := sessionSendRequest{ + SessionID: "sess-1", + Prompt: "hello", + RequestHeaders: map[string]string{"X-Custom": "value", "X-Another": "other"}, + HeaderMergeStrategy: HeaderMergeStrategyMerge, + } + + data, err := json.Marshal(req) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var decoded map[string]interface{} + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + headers, ok := decoded["requestHeaders"].(map[string]interface{}) + if !ok { + t.Fatal("expected requestHeaders to be a map") + } + if headers["X-Custom"] != "value" { + t.Errorf("expected X-Custom=value, got %v", headers["X-Custom"]) + } + if headers["X-Another"] != "other" { + t.Errorf("expected X-Another=other, got %v", headers["X-Another"]) + } + if decoded["headerMergeStrategy"] != "merge" { + t.Errorf("expected headerMergeStrategy=merge, got %v", decoded["headerMergeStrategy"]) + } + }) + + t.Run("includes override strategy", func(t *testing.T) { + req := sessionSendRequest{ + SessionID: "sess-1", + Prompt: "hello", + RequestHeaders: map[string]string{"X-Key": "val"}, + HeaderMergeStrategy: HeaderMergeStrategyOverride, + } + + data, err := json.Marshal(req) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var decoded map[string]interface{} + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if decoded["headerMergeStrategy"] != "override" { + t.Errorf("expected headerMergeStrategy=override, got %v", decoded["headerMergeStrategy"]) + } + }) + + t.Run("omits requestHeaders when nil", func(t *testing.T) { + req := sessionSendRequest{ + SessionID: "sess-1", + Prompt: "hello", + } + + data, err := json.Marshal(req) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + jsonStr := string(data) + if contains(jsonStr, "requestHeaders") { + t.Errorf("expected requestHeaders to be omitted, got %s", jsonStr) + } + if contains(jsonStr, "headerMergeStrategy") { + t.Errorf("expected headerMergeStrategy to be omitted, got %s", jsonStr) + } + }) + + t.Run("roundtrip with headers preserves values", func(t *testing.T) { + original := sessionSendRequest{ + SessionID: "sess-1", + Prompt: "test", + RequestHeaders: map[string]string{"Authorization": "Bearer token123"}, + HeaderMergeStrategy: HeaderMergeStrategyOverride, + } + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var decoded sessionSendRequest + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if decoded.RequestHeaders["Authorization"] != "Bearer token123" { + t.Errorf("expected Authorization=Bearer token123, got %q", decoded.RequestHeaders["Authorization"]) + } + if decoded.HeaderMergeStrategy != HeaderMergeStrategyOverride { + t.Errorf("expected strategy override, got %q", decoded.HeaderMergeStrategy) + } + }) + + t.Run("omits empty requestHeaders with omitempty", func(t *testing.T) { + req := sessionSendRequest{ + SessionID: "sess-1", + Prompt: "hello", + RequestHeaders: map[string]string{}, + } + + data, err := json.Marshal(req) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + jsonStr := string(data) + // Go's omitempty omits empty maps + if contains(jsonStr, `"requestHeaders"`) { + t.Errorf("expected empty requestHeaders to be omitted with omitempty, got %s", jsonStr) + } + }) +} + +func TestMessageOptions_Headers(t *testing.T) { + t.Run("can set request headers and merge strategy", func(t *testing.T) { + opts := MessageOptions{ + Prompt: "test", + RequestHeaders: map[string]string{"X-Key": "val"}, + HeaderMergeStrategy: HeaderMergeStrategyMerge, + } + + if opts.RequestHeaders["X-Key"] != "val" { + t.Errorf("expected X-Key=val, got %q", opts.RequestHeaders["X-Key"]) + } + if opts.HeaderMergeStrategy != HeaderMergeStrategyMerge { + t.Errorf("expected merge strategy, got %q", opts.HeaderMergeStrategy) + } + }) + + t.Run("defaults to empty values", func(t *testing.T) { + opts := MessageOptions{ + Prompt: "test", + } + + if opts.RequestHeaders != nil { + t.Errorf("expected nil RequestHeaders, got %v", opts.RequestHeaders) + } + if opts.HeaderMergeStrategy != "" { + t.Errorf("expected empty HeaderMergeStrategy, got %q", opts.HeaderMergeStrategy) + } + }) +} + +// contains checks if substr is present in s. +func contains(s, substr string) bool { + return len(s) >= len(substr) && searchString(s, substr) +} + +func searchString(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/go/session.go b/go/session.go index cf970450d..da0c017bb 100644 --- a/go/session.go +++ b/go/session.go @@ -123,12 +123,14 @@ func newSession(sessionID string, client *jsonrpc2.Client, workspacePath string) func (s *Session) Send(ctx context.Context, options MessageOptions) (string, error) { traceparent, tracestate := getTraceContext(ctx) req := sessionSendRequest{ - SessionID: s.SessionID, - Prompt: options.Prompt, - Attachments: options.Attachments, - Mode: options.Mode, - Traceparent: traceparent, - Tracestate: tracestate, + SessionID: s.SessionID, + Prompt: options.Prompt, + Attachments: options.Attachments, + Mode: options.Mode, + RequestHeaders: options.RequestHeaders, + HeaderMergeStrategy: options.HeaderMergeStrategy, + Traceparent: traceparent, + Tracestate: tracestate, } result, err := s.client.Request("session.send", req) @@ -840,6 +842,23 @@ func (s *Session) SetModel(ctx context.Context, model string, opts *SetModelOpti return nil } +// UpdateProvider updates the provider configuration for this session. +// This allows changing headers, authentication, or other provider settings between turns. +func (s *Session) UpdateProvider(ctx context.Context, provider ProviderConfig) error { + req := struct { + SessionID string `json:"sessionId"` + Provider ProviderConfig `json:"provider"` + }{ + SessionID: s.SessionID, + Provider: provider, + } + _, err := s.client.Request("session.provider.update", req) + if err != nil { + return fmt.Errorf("failed to update provider: %w", err) + } + return nil +} + // LogOptions configures optional parameters for [Session.Log]. type LogOptions struct { // Level sets the log severity. Valid values are [rpc.LevelInfo] (default), diff --git a/go/types.go b/go/types.go index f888c9b6e..787a88bd2 100644 --- a/go/types.go +++ b/go/types.go @@ -599,6 +599,9 @@ type ProviderConfig struct { // Use this for services requiring bearer token auth instead of API key. // Takes precedence over APIKey when both are set. BearerToken string `json:"bearerToken,omitempty"` + // Headers contains custom HTTP headers to include in all outbound requests to the provider. + // Supports env var expansion (e.g. ${VAR}, ${VAR:-default}). + Headers map[string]string `json:"headers,omitempty"` // Azure contains Azure-specific options Azure *AzureProviderOptions `json:"azure,omitempty"` } @@ -609,6 +612,16 @@ type AzureProviderOptions struct { APIVersion string `json:"apiVersion,omitempty"` } +// HeaderMergeStrategy defines how per-turn request headers are merged with session-level provider headers. +type HeaderMergeStrategy string + +const ( + // HeaderMergeStrategyOverride means per-turn headers completely replace session-level headers. + HeaderMergeStrategyOverride HeaderMergeStrategy = "override" + // HeaderMergeStrategyMerge means per-turn headers are merged with session-level headers; per-turn wins on conflicts. + HeaderMergeStrategyMerge HeaderMergeStrategy = "merge" +) + // ToolBinaryResult represents binary payloads returned by tools. type ToolBinaryResult struct { Data string `json:"data"` @@ -625,6 +638,10 @@ type MessageOptions struct { Attachments []Attachment // Mode is the message delivery mode (default: "enqueue") Mode string + // RequestHeaders contains custom HTTP headers for this turn only. + RequestHeaders map[string]string + // HeaderMergeStrategy defines how per-turn headers merge with session-level headers. Defaults to "override". + HeaderMergeStrategy HeaderMergeStrategy } // SessionEventHandler is a callback for session events @@ -935,12 +952,14 @@ type sessionAbortRequest struct { } type sessionSendRequest struct { - SessionID string `json:"sessionId"` - Prompt string `json:"prompt"` - Attachments []Attachment `json:"attachments,omitempty"` - Mode string `json:"mode,omitempty"` - Traceparent string `json:"traceparent,omitempty"` - Tracestate string `json:"tracestate,omitempty"` + SessionID string `json:"sessionId"` + Prompt string `json:"prompt"` + Attachments []Attachment `json:"attachments,omitempty"` + Mode string `json:"mode,omitempty"` + RequestHeaders map[string]string `json:"requestHeaders,omitempty"` + HeaderMergeStrategy HeaderMergeStrategy `json:"headerMergeStrategy,omitempty"` + Traceparent string `json:"traceparent,omitempty"` + Tracestate string `json:"tracestate,omitempty"` } // sessionSendResponse is the response from session.send diff --git a/nodejs/README.md b/nodejs/README.md index eee4c2b65..fa465a0a7 100644 --- a/nodejs/README.md +++ b/nodejs/README.md @@ -754,6 +754,26 @@ const session = await client.createSession({ > - For Azure OpenAI endpoints (`*.openai.azure.com`), you **must** use `type: "azure"`, not `type: "openai"`. > - The `baseUrl` should be just the host (e.g., `https://my-resource.openai.azure.com`). Do **not** include `/openai/v1` in the URL - the SDK handles path construction automatically. +### Custom Headers + +You can attach custom HTTP headers to outbound model requests — useful for API gateways, proxy authentication, or tenant routing: + +```typescript +const session = await client.createSession({ + model: "gpt-4.1", + provider: { + type: "openai", + baseUrl: "https://my-gateway.example.com/v1", + apiKey: process.env.OPENAI_API_KEY, + headers: { + "Ocp-Apim-Subscription-Key": "${APIM_KEY}", + }, + }, +}); +``` + +Per-turn headers and merge strategies are also supported. See the [Custom Headers](docs/auth/byok.md#custom-headers) section in the BYOK guide for full details. + ## Telemetry The SDK supports OpenTelemetry for distributed tracing. Provide a `telemetry` config to enable trace export from the CLI process — this is all most users need: diff --git a/nodejs/src/index.ts b/nodejs/src/index.ts index 4c41d2dfe..f838addf7 100644 --- a/nodejs/src/index.ts +++ b/nodejs/src/index.ts @@ -28,6 +28,7 @@ export type { ForegroundSessionInfo, GetAuthStatusResponse, GetStatusResponse, + HeaderMergeStrategy, InfiniteSessionConfig, InputOptions, MCPLocalServerConfig, @@ -41,6 +42,7 @@ export type { PermissionHandler, PermissionRequest, PermissionRequestResult, + ProviderConfig, ResumeSessionConfig, SectionOverride, SectionOverrideAction, diff --git a/nodejs/src/session.ts b/nodejs/src/session.ts index c046edabf..6039ae286 100644 --- a/nodejs/src/session.ts +++ b/nodejs/src/session.ts @@ -23,6 +23,7 @@ import type { PermissionHandler, PermissionRequest, PermissionRequestResult, + ProviderConfig, ReasoningEffort, SectionTransformFn, SessionCapabilities, @@ -183,6 +184,8 @@ export class CopilotSession { prompt: options.prompt, attachments: options.attachments, mode: options.mode, + requestHeaders: options.requestHeaders, + headerMergeStrategy: options.headerMergeStrategy, }); return (response as { messageId: string }).messageId; @@ -1011,6 +1014,24 @@ export class CopilotSession { }); } + /** + * Update the provider configuration for this session. + * This allows changing headers, authentication, or other provider settings between turns. + * + * @param provider - Partial provider configuration to update + * + * @example + * ```typescript + * await session.updateProvider({ headers: { "X-Custom": "value" } }); + * ``` + */ + async updateProvider(provider: Partial): Promise { + await this.connection.sendRequest("session.provider.update", { + sessionId: this.sessionId, + provider, + }); + } + /** * Change the model for this session. * The new model takes effect for the next message. Conversation history is preserved. diff --git a/nodejs/src/types.ts b/nodejs/src/types.ts index ceca07d64..274fc6073 100644 --- a/nodejs/src/types.ts +++ b/nodejs/src/types.ts @@ -1278,8 +1278,21 @@ export interface ProviderConfig { */ apiVersion?: string; }; + + /** + * Custom HTTP headers to include in all outbound requests to the provider. + * Supports env var expansion (e.g. ${VAR}, ${VAR:-default}). + */ + headers?: Record; } +/** + * Strategy for merging per-turn request headers with session-level provider headers. + * - "override": Per-turn headers completely replace session-level headers (default) + * - "merge": Per-turn headers are merged with session-level headers; per-turn wins on conflicts + */ +export type HeaderMergeStrategy = "override" | "merge"; + /** * Options for sending a message to a session */ @@ -1327,6 +1340,17 @@ export interface MessageOptions { * - "immediate": Send immediately */ mode?: "enqueue" | "immediate"; + + /** + * Custom HTTP headers to include in outbound model requests for this turn only. + */ + requestHeaders?: Record; + + /** + * Strategy for merging per-turn requestHeaders with session-level provider headers. + * Defaults to "override". + */ + headerMergeStrategy?: HeaderMergeStrategy; } /** diff --git a/nodejs/test/client.test.ts b/nodejs/test/client.test.ts index 0b98ebcb8..69a1a2e0d 100644 --- a/nodejs/test/client.test.ts +++ b/nodejs/test/client.test.ts @@ -977,4 +977,221 @@ describe("CopilotClient", () => { rpcSpy.mockRestore(); }); }); + + describe("custom headers", () => { + it("sends requestHeaders in session.send RPC call", async () => { + const client = new CopilotClient(); + await client.start(); + onTestFinished(() => client.forceStop()); + + const session = await client.createSession({ onPermissionRequest: approveAll }); + const spy = vi + .spyOn((client as any).connection!, "sendRequest") + .mockImplementation(async (method: string) => { + if (method === "session.send") return { messageId: "msg-1" }; + throw new Error(`Unexpected method: ${method}`); + }); + + await session.send({ + prompt: "hello", + requestHeaders: { "X-Custom": "value", "X-Another": "other" }, + }); + + expect(spy).toHaveBeenCalledWith( + "session.send", + expect.objectContaining({ + requestHeaders: { "X-Custom": "value", "X-Another": "other" }, + }) + ); + spy.mockRestore(); + }); + + it("sends headerMergeStrategy 'override' in session.send RPC call", async () => { + const client = new CopilotClient(); + await client.start(); + onTestFinished(() => client.forceStop()); + + const session = await client.createSession({ onPermissionRequest: approveAll }); + const spy = vi + .spyOn((client as any).connection!, "sendRequest") + .mockImplementation(async (method: string) => { + if (method === "session.send") return { messageId: "msg-1" }; + throw new Error(`Unexpected method: ${method}`); + }); + + await session.send({ + prompt: "hello", + requestHeaders: { "X-Custom": "value" }, + headerMergeStrategy: "override", + }); + + expect(spy).toHaveBeenCalledWith( + "session.send", + expect.objectContaining({ + headerMergeStrategy: "override", + }) + ); + spy.mockRestore(); + }); + + it("sends headerMergeStrategy 'merge' in session.send RPC call", async () => { + const client = new CopilotClient(); + await client.start(); + onTestFinished(() => client.forceStop()); + + const session = await client.createSession({ onPermissionRequest: approveAll }); + const spy = vi + .spyOn((client as any).connection!, "sendRequest") + .mockImplementation(async (method: string) => { + if (method === "session.send") return { messageId: "msg-1" }; + throw new Error(`Unexpected method: ${method}`); + }); + + await session.send({ + prompt: "hello", + requestHeaders: { "X-Custom": "value" }, + headerMergeStrategy: "merge", + }); + + expect(spy).toHaveBeenCalledWith( + "session.send", + expect.objectContaining({ + headerMergeStrategy: "merge", + }) + ); + spy.mockRestore(); + }); + + it("sends empty requestHeaders when provided", async () => { + const client = new CopilotClient(); + await client.start(); + onTestFinished(() => client.forceStop()); + + const session = await client.createSession({ onPermissionRequest: approveAll }); + const spy = vi + .spyOn((client as any).connection!, "sendRequest") + .mockImplementation(async (method: string) => { + if (method === "session.send") return { messageId: "msg-1" }; + throw new Error(`Unexpected method: ${method}`); + }); + + await session.send({ + prompt: "hello", + requestHeaders: {}, + }); + + expect(spy).toHaveBeenCalledWith( + "session.send", + expect.objectContaining({ + requestHeaders: {}, + }) + ); + spy.mockRestore(); + }); + + it("does not include requestHeaders when undefined", async () => { + const client = new CopilotClient(); + await client.start(); + onTestFinished(() => client.forceStop()); + + const session = await client.createSession({ onPermissionRequest: approveAll }); + const spy = vi + .spyOn((client as any).connection!, "sendRequest") + .mockImplementation(async (method: string) => { + if (method === "session.send") return { messageId: "msg-1" }; + throw new Error(`Unexpected method: ${method}`); + }); + + await session.send({ prompt: "hello" }); + + const [, params] = spy.mock.calls.find(([method]) => method === "session.send")!; + expect(params.requestHeaders).toBeUndefined(); + expect(params.headerMergeStrategy).toBeUndefined(); + spy.mockRestore(); + }); + + it("sends provider headers via updateProvider RPC call", async () => { + const client = new CopilotClient(); + await client.start(); + onTestFinished(() => client.forceStop()); + + const session = await client.createSession({ onPermissionRequest: approveAll }); + const spy = vi + .spyOn((client as any).connection!, "sendRequest") + .mockImplementation(async (method: string) => { + if (method === "session.provider.update") return {}; + throw new Error(`Unexpected method: ${method}`); + }); + + await session.updateProvider({ + headers: { Authorization: "Bearer token123", "X-Custom": "val" }, + }); + + expect(spy).toHaveBeenCalledWith( + "session.provider.update", + expect.objectContaining({ + sessionId: session.sessionId, + provider: { + headers: { Authorization: "Bearer token123", "X-Custom": "val" }, + }, + }) + ); + spy.mockRestore(); + }); + + it("sends updateProvider with empty headers", async () => { + const client = new CopilotClient(); + await client.start(); + onTestFinished(() => client.forceStop()); + + const session = await client.createSession({ onPermissionRequest: approveAll }); + const spy = vi + .spyOn((client as any).connection!, "sendRequest") + .mockImplementation(async (method: string) => { + if (method === "session.provider.update") return {}; + throw new Error(`Unexpected method: ${method}`); + }); + + await session.updateProvider({ headers: {} }); + + expect(spy).toHaveBeenCalledWith( + "session.provider.update", + expect.objectContaining({ + provider: { headers: {} }, + }) + ); + spy.mockRestore(); + }); + + it("sends both requestHeaders and headerMergeStrategy together", async () => { + const client = new CopilotClient(); + await client.start(); + onTestFinished(() => client.forceStop()); + + const session = await client.createSession({ onPermissionRequest: approveAll }); + const spy = vi + .spyOn((client as any).connection!, "sendRequest") + .mockImplementation(async (method: string) => { + if (method === "session.send") return { messageId: "msg-1" }; + throw new Error(`Unexpected method: ${method}`); + }); + + await session.send({ + prompt: "test", + requestHeaders: { "X-Request-Id": "req-123" }, + headerMergeStrategy: "merge", + }); + + expect(spy).toHaveBeenCalledWith( + "session.send", + expect.objectContaining({ + requestHeaders: { "X-Request-Id": "req-123" }, + headerMergeStrategy: "merge", + sessionId: session.sessionId, + prompt: "test", + }) + ); + spy.mockRestore(); + }); + }); }); diff --git a/python/README.md b/python/README.md index 33f62c2d4..a085ff319 100644 --- a/python/README.md +++ b/python/README.md @@ -493,6 +493,26 @@ async with await client.create_session( > - For Azure OpenAI endpoints (`*.openai.azure.com`), you **must** use `type: "azure"`, not `type: "openai"`. > - The `base_url` should be just the host (e.g., `https://my-resource.openai.azure.com`). Do **not** include `/openai/v1` in the URL - the SDK handles path construction automatically. +### Custom Headers + +You can attach custom HTTP headers to outbound model requests — useful for API gateways, proxy authentication, or tenant routing: + +```python +session = await client.create_session( + model="gpt-4.1", + provider={ + "type": "openai", + "base_url": "https://my-gateway.example.com/v1", + "api_key": os.environ["OPENAI_API_KEY"], + "headers": { + "Ocp-Apim-Subscription-Key": "${APIM_KEY}", + }, + }, +) +``` + +Per-turn headers and merge strategies are also supported. See the [Custom Headers](docs/auth/byok.md#custom-headers) section in the BYOK guide for full details. + ## Telemetry The SDK supports OpenTelemetry for distributed tracing. Provide a `telemetry` config to enable trace export and automatic W3C Trace Context propagation. diff --git a/python/copilot/client.py b/python/copilot/client.py index ab8074756..8ebd9720d 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -1955,6 +1955,8 @@ def _convert_provider_to_wire_format( wire_provider["wireApi"] = provider["wire_api"] if "bearer_token" in provider: wire_provider["bearerToken"] = provider["bearer_token"] + if "headers" in provider: + wire_provider["headers"] = provider["headers"] if "azure" in provider: azure = provider["azure"] wire_azure: dict[str, Any] = {} diff --git a/python/copilot/session.py b/python/copilot/session.py index c4feb82de..2ab9f9937 100644 --- a/python/copilot/session.py +++ b/python/copilot/session.py @@ -507,6 +507,12 @@ class ProviderConfig(TypedDict, total=False): # Takes precedence over api_key when both are set. bearer_token: str azure: AzureProviderOptions # Azure-specific options + # Custom HTTP headers to include in all outbound requests to the provider. + # Supports env var expansion (e.g. ${VAR}, ${VAR:-default}). + headers: dict[str, str] + + +HeaderMergeStrategy = Literal["override", "merge"] class SessionConfig(TypedDict, total=False): @@ -706,6 +712,8 @@ async def send( *, attachments: list[Attachment] | None = None, mode: Literal["enqueue", "immediate"] | None = None, + request_headers: dict[str, str] | None = None, + header_merge_strategy: HeaderMergeStrategy | None = None, ) -> str: """ Send a message to this session. @@ -718,6 +726,9 @@ async def send( prompt: The message text to send. attachments: Optional file, directory, or selection attachments. mode: Message delivery mode (``"enqueue"`` or ``"immediate"``). + request_headers: Custom HTTP headers for this turn only. + header_merge_strategy: Strategy for merging per-turn headers with + session-level provider headers. Defaults to ``"override"``. Returns: The message ID assigned by the server, which can be used to correlate events. @@ -739,6 +750,10 @@ async def send( params["attachments"] = attachments if mode is not None: params["mode"] = mode + if request_headers is not None: + params["requestHeaders"] = request_headers + if header_merge_strategy is not None: + params["headerMergeStrategy"] = header_merge_strategy params.update(get_trace_context()) response = await self._client.request("session.send", params) @@ -1374,6 +1389,28 @@ async def abort(self) -> None: """ await self._client.request("session.abort", {"sessionId": self.session_id}) + async def update_provider(self, provider: ProviderConfig) -> None: + """ + Update the provider configuration for this session. + + This allows changing headers, authentication, or other provider settings + between turns. + + Args: + provider: Provider configuration to update. + + Raises: + Exception: If the session has been destroyed or the connection fails. + + Example: + >>> await session.update_provider({"headers": {"X-Custom": "value"}}) + """ + wire_provider = self._client._convert_provider_to_wire_format(provider) + await self._client.request( + "session.provider.update", + {"sessionId": self.session_id, "provider": wire_provider}, + ) + async def set_model(self, model: str, *, reasoning_effort: str | None = None) -> None: """ Change the model for this session. diff --git a/python/test_custom_headers.py b/python/test_custom_headers.py new file mode 100644 index 000000000..da5b8e828 --- /dev/null +++ b/python/test_custom_headers.py @@ -0,0 +1,243 @@ +"""Tests for custom headers in session.send() and provider configuration.""" + +from unittest.mock import AsyncMock + +import pytest + +from copilot.session import CopilotSession, HeaderMergeStrategy, ProviderConfig + + +def _make_session(client: AsyncMock) -> CopilotSession: + """Create a CopilotSession with a mocked client for unit testing.""" + return CopilotSession(session_id="sess-1", client=client) + + +class TestProviderConfigHeaders: + """Test ProviderConfig TypedDict construction with headers.""" + + def test_provider_config_with_headers(self): + """ProviderConfig can include custom headers.""" + config: ProviderConfig = { + "base_url": "https://api.example.com", + "api_key": "test-key", + "headers": {"X-Custom": "value", "Authorization": "Bearer tok"}, + } + assert config["headers"]["X-Custom"] == "value" + assert config["headers"]["Authorization"] == "Bearer tok" + + def test_provider_config_with_empty_headers(self): + """ProviderConfig can include an empty headers dict.""" + config: ProviderConfig = { + "base_url": "https://api.example.com", + "headers": {}, + } + assert config["headers"] == {} + + def test_provider_config_without_headers(self): + """ProviderConfig works without the optional headers field.""" + config: ProviderConfig = { + "base_url": "https://api.example.com", + } + assert "headers" not in config + + +class TestHeaderMergeStrategy: + """Test HeaderMergeStrategy literal values.""" + + def test_override_value(self): + strategy: HeaderMergeStrategy = "override" + assert strategy == "override" + + def test_merge_value(self): + strategy: HeaderMergeStrategy = "merge" + assert strategy == "merge" + + +class TestSendWithCustomHeaders: + """Test that send() passes requestHeaders and headerMergeStrategy to the RPC call.""" + + @pytest.mark.asyncio + async def test_send_includes_request_headers(self): + """Verify requestHeaders are forwarded in the RPC params.""" + client = AsyncMock() + client.request = AsyncMock(return_value={"messageId": "msg-1"}) + + session = _make_session(client) + + await session.send( + "test prompt", + request_headers={"X-Custom": "value", "X-Another": "other"}, + ) + + client.request.assert_called_once() + args, _ = client.request.call_args + assert args[0] == "session.send" + params = args[1] + assert params["requestHeaders"] == {"X-Custom": "value", "X-Another": "other"} + + @pytest.mark.asyncio + async def test_send_includes_header_merge_strategy_override(self): + """Verify headerMergeStrategy 'override' is forwarded.""" + client = AsyncMock() + client.request = AsyncMock(return_value={"messageId": "msg-1"}) + + session = _make_session(client) + + await session.send( + "test", + request_headers={"X-Key": "val"}, + header_merge_strategy="override", + ) + + args, _ = client.request.call_args + params = args[1] + assert params["headerMergeStrategy"] == "override" + + @pytest.mark.asyncio + async def test_send_includes_header_merge_strategy_merge(self): + """Verify headerMergeStrategy 'merge' is forwarded.""" + client = AsyncMock() + client.request = AsyncMock(return_value={"messageId": "msg-1"}) + + session = _make_session(client) + + await session.send( + "test", + request_headers={"X-Key": "val"}, + header_merge_strategy="merge", + ) + + args, _ = client.request.call_args + params = args[1] + assert params["headerMergeStrategy"] == "merge" + + @pytest.mark.asyncio + async def test_send_omits_headers_when_none(self): + """Verify requestHeaders and headerMergeStrategy are omitted when not provided.""" + client = AsyncMock() + client.request = AsyncMock(return_value={"messageId": "msg-1"}) + + session = _make_session(client) + + await session.send("test") + + args, _ = client.request.call_args + params = args[1] + assert "requestHeaders" not in params + assert "headerMergeStrategy" not in params + + @pytest.mark.asyncio + async def test_send_with_empty_request_headers(self): + """Verify empty requestHeaders dict is forwarded.""" + client = AsyncMock() + client.request = AsyncMock(return_value={"messageId": "msg-1"}) + + session = _make_session(client) + + await session.send("test", request_headers={}) + + args, _ = client.request.call_args + params = args[1] + assert params["requestHeaders"] == {} + + @pytest.mark.asyncio + async def test_send_with_both_headers_and_strategy(self): + """Verify both requestHeaders and headerMergeStrategy are forwarded together.""" + client = AsyncMock() + client.request = AsyncMock(return_value={"messageId": "msg-1"}) + + session = _make_session(client) + + await session.send( + "hello", + request_headers={"X-Request-Id": "req-123"}, + header_merge_strategy="merge", + ) + + args, _ = client.request.call_args + params = args[1] + assert params["sessionId"] == "sess-1" + assert params["prompt"] == "hello" + assert params["requestHeaders"] == {"X-Request-Id": "req-123"} + assert params["headerMergeStrategy"] == "merge" + + +class TestUpdateProvider: + """Test that update_provider() makes the correct RPC call.""" + + @pytest.mark.asyncio + async def test_update_provider_with_headers(self): + """Verify update_provider sends headers in wire format.""" + from copilot.client import CopilotClient + + client_mock = AsyncMock() + client_mock.request = AsyncMock(return_value={}) + + # Use the real wire format conversion + client_mock._convert_provider_to_wire_format = ( + CopilotClient._convert_provider_to_wire_format.__get__(client_mock) + ) + + session = _make_session(client_mock) + + await session.update_provider( + {"headers": {"Authorization": "Bearer token", "X-Custom": "val"}} + ) + + client_mock.request.assert_called_once() + args, _ = client_mock.request.call_args + assert args[0] == "session.provider.update" + params = args[1] + assert params["sessionId"] == "sess-1" + assert params["provider"]["headers"] == { + "Authorization": "Bearer token", + "X-Custom": "val", + } + + @pytest.mark.asyncio + async def test_update_provider_with_empty_headers(self): + """Verify update_provider with empty headers dict.""" + from copilot.client import CopilotClient + + client_mock = AsyncMock() + client_mock.request = AsyncMock(return_value={}) + client_mock._convert_provider_to_wire_format = ( + CopilotClient._convert_provider_to_wire_format.__get__(client_mock) + ) + + session = _make_session(client_mock) + + await session.update_provider({"headers": {}}) + + args, _ = client_mock.request.call_args + params = args[1] + assert params["provider"]["headers"] == {} + + @pytest.mark.asyncio + async def test_update_provider_wire_format_conversion(self): + """Verify provider config is converted from snake_case to camelCase.""" + from copilot.client import CopilotClient + + client_mock = AsyncMock() + client_mock.request = AsyncMock(return_value={}) + client_mock._convert_provider_to_wire_format = ( + CopilotClient._convert_provider_to_wire_format.__get__(client_mock) + ) + + session = _make_session(client_mock) + + await session.update_provider( + { + "base_url": "https://api.example.com", + "api_key": "key-123", + "headers": {"X-Custom": "value"}, + } + ) + + args, _ = client_mock.request.call_args + provider = args[1]["provider"] + assert provider["baseUrl"] == "https://api.example.com" + assert provider["apiKey"] == "key-123" + assert provider["headers"] == {"X-Custom": "value"} + assert "base_url" not in provider + assert "api_key" not in provider