Skip to content
155 changes: 155 additions & 0 deletions src/Service.Tests/UnitTests/HealthCheckUtilitiesUnitTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

#nullable enable

using System;
using Azure.DataApiBuilder.Config.ObjectModel;
using Azure.DataApiBuilder.Service.HealthCheck;
using Microsoft.Extensions.Logging;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Moq;

namespace Azure.DataApiBuilder.Service.Tests.UnitTests
{
/// <summary>
/// Unit tests for health check utility methods.
/// </summary>
[TestClass]
public class HealthCheckUtilitiesUnitTests
{
/// <summary>
/// Tests that connection strings are properly normalized for supported database types.
/// </summary>
[TestMethod]
[DataRow(
DatabaseType.PostgreSQL,
"Host=localhost;Port=5432;Database=testdb;Username=testuser;Password=XXXX",
"Host=localhost",
"Database=testdb",
DisplayName = "PostgreSQL connection string normalization")]
[DataRow(
DatabaseType.MSSQL,
"Server=localhost;Database=testdb;User Id=testuser;Password=XXXX",
"Data Source=localhost",
"Initial Catalog=testdb",
DisplayName = "MSSQL connection string normalization")]
[DataRow(
DatabaseType.DWSQL,
"Server=localhost;Database=testdb;User Id=testuser;Password=XXXX",
"Data Source=localhost",
"Initial Catalog=testdb",
DisplayName = "DWSQL connection string normalization")]
[DataRow(
DatabaseType.MySQL,
"Server=localhost;Port=3306;Database=testdb;Uid=testuser;Pwd=XXXX",
"Server=localhost",
"Database=testdb",
DisplayName = "MySQL connection string normalization")]
public void NormalizeConnectionString_SupportedDatabases_Success(
DatabaseType dbType,
string connectionString,
string expectedServerPart,
string expectedDatabasePart)
{
// Act
string result = HealthCheck.Utilities.NormalizeConnectionString(connectionString, dbType);

// Assert
Assert.IsNotNull(result);
Assert.IsTrue(result.Contains(expectedServerPart));
Assert.IsTrue(result.Contains(expectedDatabasePart));
}

/// <summary>
/// Tests that unsupported database types return the original connection string.
/// </summary>
[TestMethod]
public void NormalizeConnectionString_UnsupportedType_ReturnsOriginal()
{
// Arrange
string connectionString = "AccountEndpoint=https://test.documents.azure.com:443/;AccountKey=test";
DatabaseType dbType = DatabaseType.CosmosDB_NoSQL;

// Act
string result = HealthCheck.Utilities.NormalizeConnectionString(connectionString, dbType);

// Assert
Assert.AreEqual(connectionString, result);
}

/// <summary>
/// Tests that malformed connection strings are handled gracefully.
/// </summary>
[TestMethod]
[DataRow(DatabaseType.PostgreSQL, true, DisplayName = "PostgreSQL malformed string with logger")]
[DataRow(DatabaseType.MSSQL, true, DisplayName = "MSSQL malformed string with logger")]
[DataRow(DatabaseType.MySQL, false, DisplayName = "MySQL malformed string without logger")]
public void NormalizeConnectionString_MalformedString_ReturnsOriginal(
DatabaseType dbType,
bool useLogger)
{
// Arrange
string malformedConnectionString = "InvalidConnectionString;NoEquals";
Mock<ILogger>? mockLogger = useLogger ? new Mock<ILogger>() : null;

// Act
string result = HealthCheck.Utilities.NormalizeConnectionString(
malformedConnectionString,
dbType,
mockLogger?.Object);

// Assert
Assert.AreEqual(malformedConnectionString, result);
if (useLogger && mockLogger != null)
{
mockLogger.Verify(
x => x.Log(
LogLevel.Warning,
It.IsAny<EventId>(),
It.Is<It.IsAnyType>((v, t) => true),
It.IsAny<Exception>(),
It.Is<Func<It.IsAnyType, Exception?, string>>((v, t) => true)),
Times.Once);
}
}

/// <summary>
/// Tests that PostgreSQL connection strings with lowercase keywords are normalized correctly.
/// This is the specific bug that was reported - lowercase 'host' was not supported.
/// </summary>
[TestMethod]
public void NormalizeConnectionString_PostgreSQL_LowercaseKeywords_Success()
{
// Arrange
string connectionString = "host=localhost;port=5432;database=mydb;username=myuser;password=XXXX";
DatabaseType dbType = DatabaseType.PostgreSQL;

// Act
string result = HealthCheck.Utilities.NormalizeConnectionString(connectionString, dbType);

// Assert
Assert.IsNotNull(result);
// NpgsqlConnectionStringBuilder should normalize lowercase keywords to proper format
Assert.IsTrue(result.Contains("Host=localhost") || result.Contains("host=localhost"));
Assert.IsTrue(result.Contains("Database=mydb") || result.Contains("database=mydb"));
}

/// <summary>
/// Tests that empty connection strings are handled gracefully.
/// </summary>
[TestMethod]
public void NormalizeConnectionString_EmptyString_ReturnsEmpty()
{
// Arrange
string connectionString = string.Empty;
DatabaseType dbType = DatabaseType.PostgreSQL;

// Act
string result = HealthCheck.Utilities.NormalizeConnectionString(connectionString, dbType);

// Assert
Assert.AreEqual(string.Empty, result);
}
}
}
6 changes: 3 additions & 3 deletions src/Service/HealthCheck/HealthCheckHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ private async Task UpdateDataSourceHealthCheckResultsAsync(ComprehensiveHealthCh
if (comprehensiveHealthCheckReport.Checks != null && runtimeConfig.DataSource.IsDatasourceHealthEnabled)
{
string query = Utilities.GetDatSourceQuery(runtimeConfig.DataSource.DatabaseType);
(int, string?) response = await ExecuteDatasourceQueryCheckAsync(query, runtimeConfig.DataSource.ConnectionString, Utilities.GetDbProviderFactory(runtimeConfig.DataSource.DatabaseType));
(int, string?) response = await ExecuteDatasourceQueryCheckAsync(query, runtimeConfig.DataSource.ConnectionString, Utilities.GetDbProviderFactory(runtimeConfig.DataSource.DatabaseType), runtimeConfig.DataSource.DatabaseType);
bool isResponseTimeWithinThreshold = response.Item1 >= 0 && response.Item1 < runtimeConfig.DataSource.DatasourceThresholdMs;

// Add DataSource Health Check Results
Expand All @@ -182,14 +182,14 @@ private async Task UpdateDataSourceHealthCheckResultsAsync(ComprehensiveHealthCh
}

// Executes the DB Query and keeps track of the response time and error message.
private async Task<(int, string?)> ExecuteDatasourceQueryCheckAsync(string query, string connectionString, DbProviderFactory dbProviderFactory)
private async Task<(int, string?)> ExecuteDatasourceQueryCheckAsync(string query, string connectionString, DbProviderFactory dbProviderFactory, DatabaseType databaseType)
{
string? errorMessage = null;
if (!string.IsNullOrEmpty(query) && !string.IsNullOrEmpty(connectionString))
{
Stopwatch stopwatch = new();
stopwatch.Start();
errorMessage = await _httpUtility.ExecuteDbQueryAsync(query, connectionString, dbProviderFactory);
errorMessage = await _httpUtility.ExecuteDbQueryAsync(query, connectionString, dbProviderFactory, databaseType);
stopwatch.Stop();
return string.IsNullOrEmpty(errorMessage) ? ((int)stopwatch.ElapsedMilliseconds, errorMessage) : (HealthCheckConstants.ERROR_RESPONSE_TIME_MS, errorMessage);
}
Expand Down
4 changes: 2 additions & 2 deletions src/Service/HealthCheck/HttpUtilities.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public HttpUtilities(
}

// Executes the DB query by establishing a connection to the DB.
public async Task<string?> ExecuteDbQueryAsync(string query, string connectionString, DbProviderFactory providerFactory)
public async Task<string?> ExecuteDbQueryAsync(string query, string connectionString, DbProviderFactory providerFactory, DatabaseType databaseType)
{
string? errorMessage = null;
// Execute the query on DB and return the response time.
Expand All @@ -65,7 +65,7 @@ public HttpUtilities(
{
try
{
connection.ConnectionString = connectionString;
connection.ConnectionString = Utilities.NormalizeConnectionString(connectionString, databaseType, _logger);
using (DbCommand command = connection.CreateCommand())
{
command.CommandText = query;
Expand Down
29 changes: 29 additions & 0 deletions src/Service/HealthCheck/Utilities.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
using System.Text.Json;
using Azure.DataApiBuilder.Config.ObjectModel;
using Microsoft.Data.SqlClient;
using Microsoft.Extensions.Logging;
using MySqlConnector;
using Npgsql;

namespace Azure.DataApiBuilder.Service.HealthCheck
Expand Down Expand Up @@ -69,5 +71,32 @@ public static string CreateHttpRestQuery(string entityName, int first)
// "EntityName?$first=4"
return $"/{entityName}?$first={first}";
}

public static string NormalizeConnectionString(string connectionString, DatabaseType dbType, ILogger? logger = null)
{
try
{
switch (dbType)
{
case DatabaseType.PostgreSQL:
return new NpgsqlConnectionStringBuilder(connectionString).ToString();
case DatabaseType.MySQL:
return new MySqlConnectionStringBuilder(connectionString).ToString();
case DatabaseType.MSSQL:
case DatabaseType.DWSQL:
return new SqlConnectionStringBuilder(connectionString).ToString();
default:
return connectionString;
}
}
catch (Exception ex)
{
// Log the exception if a logger is provided
logger?.LogWarning(ex, "Failed to parse connection string for database type {DatabaseType}. Returning original connection string.", dbType);
// If the connection string cannot be parsed by the builder,
// return the original string to avoid failing the health check.
return connectionString;
}
}
}
}
Loading