Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
.azure
.azure
.aider*
4 changes: 2 additions & 2 deletions infra/main.bicep
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ param location string
param openAiSkuName string = 'S0'

@description('Version of the Chat GPT model.')
param chatGptModelVersion string = '0613'
param chatGptModelVersion string = '2024-05-13'

@description('Name of the Chat GPT deployment.')
param chatGptDeploymentName string = 'chat'
Expand All @@ -22,7 +22,7 @@ param embeddingGptModelVersion string = '2'
param embeddingGptDeploymentName string = 'embedding'

@description('Name of the Chat GPT model.')
param chatGptModelName string = 'gpt-35-turbo'
param chatGptModelName string = 'gpt-4o'

@description('The OpenAI endpoints capacity (in thousands of tokens per minute)')
param deploymentCapacity int = 30
Expand Down
44 changes: 44 additions & 0 deletions src/AuthenticationMiddleware.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.Logging;
using System.Threading.Tasks;

namespace openai_loadbalancer;

/// <summary>
/// ASP.NET Core middleware that authenticates a request by matching the value of its
/// <c>api-key</c> header against the keys listed in the <c>CustomerConfiguration</c>
/// configuration section (customer name -> API key).
/// </summary>
/// <remarks>
/// Requests without the header, with an empty or multi-valued header, or with a key that
/// matches no configured customer are answered with HTTP 401 and are not forwarded to the
/// next middleware.
/// </remarks>
public class AuthenticationMiddleware
{
    private const string ApiKeyHeaderName = "api-key";

    private readonly RequestDelegate _next;
    private readonly ILogger<AuthenticationMiddleware> _logger;
    private readonly IDictionary<string, string> _customerConfigurations;

    /// <summary>
    /// Creates the middleware and reads the customer/API-key map once from configuration.
    /// </summary>
    /// <param name="next">The next delegate in the request pipeline.</param>
    /// <param name="logger">Logger used to report authentication outcomes.</param>
    /// <param name="configuration">
    /// Application configuration; the <c>CustomerConfiguration</c> section is bound to a
    /// dictionary. A missing section results in an empty map (every request is rejected).
    /// </param>
    public AuthenticationMiddleware(RequestDelegate next, ILogger<AuthenticationMiddleware> logger, IConfiguration configuration)
    {
        _next = next;
        _logger = logger;
        _customerConfigurations = configuration.GetSection("CustomerConfiguration").Get<Dictionary<string, string>>() ?? [];
    }

    /// <summary>
    /// Validates the <c>api-key</c> header of the current request and either forwards the
    /// request to the next middleware or short-circuits with a 401 response.
    /// </summary>
    /// <param name="context">The HTTP context of the current request.</param>
    public async Task InvokeAsync(HttpContext context)
    {
        if (!context.Request.Headers.TryGetValue(ApiKeyHeaderName, out var extractedApiKey))
        {
            _logger.LogWarning("API key was not provided.");
            context.Response.StatusCode = StatusCodes.Status401Unauthorized;
            await context.Response.WriteAsync("API key is missing.");
            return;
        }

        // A header repeated several times is ambiguous and never matched a configured key
        // before either, so only a single value is considered.
        var presentedKey = extractedApiKey.Count == 1 ? extractedApiKey[0] : null;
        var customerName = FindCustomer(presentedKey);

        if (customerName is null)
        {
            // The presented key is deliberately not logged: it may be a real secret with a typo.
            _logger.LogWarning("Unauthorized client.");
            context.Response.StatusCode = StatusCodes.Status401Unauthorized;
            await context.Response.WriteAsync("Unauthorized client.");
            return;
        }

        _logger.LogInformation("Customer '{Customer}' successfully authenticated.", customerName);
        await _next(context);
    }

    /// <summary>
    /// Returns the name of the first configured customer whose API key equals
    /// <paramref name="presentedKey"/>, or <c>null</c> when there is none.
    /// </summary>
    /// <param name="presentedKey">The key sent by the client; may be null or empty.</param>
    private string? FindCustomer(string? presentedKey)
    {
        // An empty key must never authenticate anyone, even if a customer entry was left
        // blank in configuration (as in the default appsettings.json).
        if (string.IsNullOrEmpty(presentedKey))
        {
            return null;
        }

        var presentedBytes = System.Text.Encoding.UTF8.GetBytes(presentedKey);
        string? matchedCustomer = null;

        foreach (var (name, configuredKey) in _customerConfigurations)
        {
            if (string.IsNullOrEmpty(configuredKey))
            {
                continue;
            }

            // Fixed-time comparison so response timing does not reveal how many leading
            // characters of a key were guessed correctly. The loop intentionally keeps
            // going after a match so that the match position is not observable either.
            var configuredBytes = System.Text.Encoding.UTF8.GetBytes(configuredKey);
            var isMatch = System.Security.Cryptography.CryptographicOperations.FixedTimeEquals(presentedBytes, configuredBytes);

            if (isMatch && matchedCustomer is null)
            {
                matchedCustomer = name;
            }
        }

        return matchedCustomer;
    }
}
2 changes: 1 addition & 1 deletion src/BackendConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public static IReadOnlyDictionary<string, BackendConfig> LoadConfig(IConfigurati
var apiKey = LoadEnvironmentVariable(environmentVariables, backendIndex, "APIKEY", isMandatory: false);
var priority = Convert.ToInt32(LoadEnvironmentVariable(environmentVariables, backendIndex, "PRIORITY"));

returnDictionary.Add(key, new BackendConfig { Url = url, ApiKey = apiKey, Priority = priority, DeploymentName = deploymentName });
returnDictionary.Add(key, new BackendConfig { Url = url!, ApiKey = apiKey!, Priority = priority!, DeploymentName = deploymentName! });
}

//Load the general settings not in scope only for specific backends
Expand Down
10 changes: 10 additions & 0 deletions src/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ public static void Main(string[] args)
{
var builder = WebApplication.CreateBuilder(args);

// Add Application Insights
builder.Services.AddApplicationInsightsTelemetry(o => o.ConnectionString = builder.Configuration["APPLICATIONINSIGHTS_CONNECTION_STRING"]);


var backendConfiguration = BackendConfig.LoadConfig(builder.Configuration);
var yarpConfiguration = new YarpConfiguration(backendConfiguration);
builder.Services.AddSingleton<IPassiveHealthCheckPolicy, ThrottlingHealthPolicy>();
Expand All @@ -20,6 +24,12 @@ public static void Main(string[] args)

builder.Services.AddHealthChecks();
var app = builder.Build();

var enableAuthentication = builder.Configuration.GetValue<bool>("EnableAuthenticationMiddleware");
if (enableAuthentication)
{
app.UseMiddleware<AuthenticationMiddleware>();
}

app.MapHealthChecks("/healthz");
app.MapReverseProxy(m =>
Expand Down
8 changes: 7 additions & 1 deletion src/appsettings.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,11 @@
"Microsoft.AspNetCore": "Warning"
}
},
"AllowedHosts": "*"
"AllowedHosts": "*",
"CustomerConfiguration": {
"Customer1": "",
"Customer2": ""
},
"EnableAuthenticationMiddleware": true
}

7 changes: 4 additions & 3 deletions src/openai-loadbalancer.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Azure.Identity" Version="1.10.4" />
<PackageReference Include="Microsoft.Extensions.Azure" Version="1.7.2" />
<PackageReference Include="Microsoft.VisualStudio.Azure.Containers.Tools.Targets" Version="1.19.5" />
<PackageReference Include="Azure.Identity" Version="1.12.0" />
<PackageReference Include="Microsoft.ApplicationInsights.AspNetCore" Version="2.22.0" />
<PackageReference Include="Microsoft.Extensions.Azure" Version="1.7.4" />
<PackageReference Include="Microsoft.VisualStudio.Azure.Containers.Tools.Targets" Version="1.20.1" />
<PackageReference Include="Yarp.ReverseProxy" Version="2.1.0" />
</ItemGroup>

Expand Down