Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 108 additions & 1 deletion dotnet/src/VectorData/SqlServer/SqlServerCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,16 @@ namespace Microsoft.SemanticKernel.Connectors.SqlServer;
#pragma warning disable CA1711 // Identifiers should not have incorrect suffix (Collection)
public class SqlServerCollection<TKey, TRecord>
#pragma warning restore CA1711
: VectorStoreCollection<TKey, TRecord>
: VectorStoreCollection<TKey, TRecord>,
IKeywordHybridSearchable<TRecord>
where TKey : notnull
where TRecord : class
{
/// <summary>Metadata about vector store record collection.</summary>
private readonly VectorStoreCollectionMetadata _collectionMetadata;

private static readonly VectorSearchOptions<TRecord> s_defaultVectorSearchOptions = new();
private static readonly HybridSearchOptions<TRecord> s_defaultHybridSearchOptions = new();

private readonly string _connectionString;
private readonly CollectionModel _model;
Expand Down Expand Up @@ -635,6 +637,74 @@ _ when vectorProperty.EmbeddingGenerator is IEmbeddingGenerator<TInput, Embeddin
}
}

/// <inheritdoc />
public async IAsyncEnumerable<VectorSearchResult<TRecord>> HybridSearchAsync<TInput>(
TInput searchValue,
ICollection<string> keywords,
int top,
HybridSearchOptions<TRecord>? options = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
where TInput : notnull
{
Verify.NotNull(searchValue);
Verify.NotNull(keywords);
Verify.NotLessThan(top, 1);

options ??= s_defaultHybridSearchOptions;
if (options.IncludeVectors && this._model.EmbeddingGenerationRequired)
{
throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration);
}
#pragma warning disable CS0618 // Type or member is obsolete
if (options.OldFilter is not null)
{
throw new NotSupportedException("The obsolete Filter is not supported by the SQL Server connector, use Filter instead.");
}
#pragma warning restore CS0618 // Type or member is obsolete

var vectorProperty = this._model.GetVectorPropertyOrSingle(new VectorSearchOptions<TRecord> { VectorProperty = options.VectorProperty });
var textDataProperty = this._model.GetFullTextDataPropertyOrSingle(options.AdditionalProperty);

SqlVector<float> vector = searchValue switch
{
SqlVector<float> v => v,
ReadOnlyMemory<float> r => new(r),
float[] f => new(f),
Embedding<float> e => new(e.Vector),

_ when vectorProperty.EmbeddingGenerator is IEmbeddingGenerator<TInput, Embedding<float>> generator
=> new(await generator.GenerateVectorAsync(searchValue, cancellationToken: cancellationToken).ConfigureAwait(false)),

_ => vectorProperty.EmbeddingGenerator is null
? throw new NotSupportedException(VectorDataStrings.InvalidSearchInputAndNoEmbeddingGeneratorWasConfigured(searchValue.GetType(), SqlServerModelBuilder.SupportedVectorTypes))
: throw new InvalidOperationException(VectorDataStrings.IncompatibleEmbeddingGeneratorWasConfiguredForInputType(typeof(TInput), vectorProperty.EmbeddingGenerator.GetType()))
};

var keywordsCombined = string.Join(" ", keywords);

#pragma warning disable CA2000 // Dispose objects before losing scope
// Connection and command are going to be disposed by the ReadVectorSearchResultsAsync,
// when the user is done with the results.
SqlConnection connection = new(this._connectionString);
SqlCommand command = SqlServerCommandBuilder.SelectHybrid(
connection,
this._schema,
this.Name,
vectorProperty,
textDataProperty,
this._model,
top,
options,
vector,
keywordsCombined);
#pragma warning restore CA2000 // Dispose objects before losing scope

await foreach (var record in this.ReadHybridSearchResultsAsync(connection, command, options, cancellationToken).ConfigureAwait(false))
{
yield return record;
}
}

#endregion Search

/// <inheritdoc />
Expand Down Expand Up @@ -688,6 +758,43 @@ private async IAsyncEnumerable<VectorSearchResult<TRecord>> ReadVectorSearchResu
}
}

private async IAsyncEnumerable<VectorSearchResult<TRecord>> ReadHybridSearchResultsAsync(
SqlConnection connection,
SqlCommand command,
HybridSearchOptions<TRecord> options,
[EnumeratorCancellation] CancellationToken cancellationToken)
{
try
{
using SqlDataReader reader = await connection.ExecuteWithErrorHandlingAsync(
this._collectionMetadata,
operationName: "HybridSearch",
() => command.ExecuteReaderAsync(cancellationToken),
cancellationToken).ConfigureAwait(false);

int scoreIndex = -1;
while (await reader.ReadWithErrorHandlingAsync(
this._collectionMetadata,
operationName: "HybridSearch",
cancellationToken).ConfigureAwait(false))
{
if (scoreIndex < 0)
{
scoreIndex = reader.GetOrdinal("score");
}

yield return new VectorSearchResult<TRecord>(
this._mapper.MapFromStorageToDataModel(reader, options.IncludeVectors),
reader.GetDouble(scoreIndex));
}
}
finally
{
command.Dispose();
connection.Dispose();
}
}

/// <inheritdoc />
public override async IAsyncEnumerable<TRecord> GetAsync(Expression<Func<TRecord, bool>> filter, int top,
FilteredRecordRetrievalOptions<TRecord>? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
Expand Down
180 changes: 180 additions & 0 deletions dotnet/src/VectorData/SqlServer/SqlServerCommandBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,48 @@ internal static SqlCommand CreateTable(
}
}

// Create full-text catalog and index for properties marked as IsFullTextIndexed
var fullTextProperties = new List<DataPropertyModel>();
foreach (var dataProperty in model.DataProperties)
{
if (dataProperty.IsFullTextIndexed)
{
fullTextProperties.Add(dataProperty);
}
}

if (fullTextProperties.Count > 0)
{
// Generate a unique catalog name based on the table name
var catalogName = $"ftcat_{tableName}".Replace(" ", "_");

// Create full-text catalog if it doesn't exist
sb.Append("IF NOT EXISTS (SELECT 1 FROM sys.fulltext_catalogs WHERE name = '").Append(catalogName.Replace("'", "''")).AppendLine("')");
sb.Append(" CREATE FULLTEXT CATALOG ").AppendIdentifier(catalogName).AppendLine(";");

// Create full-text index on the table using dynamic SQL to look up the PK constraint name
// Full-text indexes require a unique index (we use the primary key)
sb.AppendLine("DECLARE @pkIndexName NVARCHAR(128);");
sb.Append("SELECT @pkIndexName = name FROM sys.indexes WHERE object_id = OBJECT_ID(N'");
sb.AppendTableName(schema, tableName);
sb.AppendLine("') AND is_primary_key = 1;");

sb.AppendLine("DECLARE @ftSql NVARCHAR(MAX);");
sb.Append("SET @ftSql = N'CREATE FULLTEXT INDEX ON ");
sb.AppendTableName(schema, tableName).Append(" (");
for (int i = 0; i < fullTextProperties.Count; i++)
{
sb.AppendIdentifier(fullTextProperties[i].StorageName);
if (i < fullTextProperties.Count - 1)
{
sb.Append(',');
}
}
sb.Append(") KEY INDEX ' + QUOTENAME(@pkIndexName) + N' ON ");
sb.AppendIdentifier(catalogName).AppendLine("';");
sb.AppendLine("EXEC sp_executesql @ftSql;");
}

sb.Append("END;");

return connection.CreateCommand(sb);
Expand Down Expand Up @@ -425,6 +467,144 @@ internal static SqlCommand SelectVector<TRecord>(
return command;
}

internal static SqlCommand SelectHybrid<TRecord>(
SqlConnection connection, string? schema, string tableName,
VectorPropertyModel vectorProperty,
DataPropertyModel textProperty,
CollectionModel model,
int top,
HybridSearchOptions<TRecord> options,
SqlVector<float> vector,
string keywords)
{
string distanceFunction = vectorProperty.DistanceFunction ?? DistanceFunction.CosineDistance;
(string distanceMetric, _) = MapDistanceFunction(distanceFunction);

SqlCommand command = connection.CreateCommand();
command.Parameters.AddWithValue("@vector", vector);
command.Parameters.AddWithValue("@keywords", keywords);

// For RRF, we need to fetch more candidates from each search than the final top count
// to allow proper merging. The number of candidates should be at least top + skip.
// The RRF constant (k) is typically 60 in literature, but we use a smaller value
// that still allows proper ranking while keeping the query efficient.
int candidateCount = Math.Max(top + options.Skip, 20); // Fetch at least 20 candidates
const int RrfK = 60; // Standard RRF constant

command.Parameters.AddWithValue("@candidateCount", candidateCount);
command.Parameters.AddWithValue("@rrfK", RrfK);

StringBuilder sb = new(1000);

// Build the hybrid search query using CTEs with Reciprocal Rank Fusion (RRF)
// Reference: https://github.com/Azure-Samples/azure-sql-db-openai/blob/main/vector-embeddings/07-hybrid-search.sql

// CTE 1: Keyword search using FREETEXTTABLE
sb.AppendLine("WITH keyword_search AS (");
sb.AppendLine(" SELECT TOP(@candidateCount)");
sb.Append(" ").AppendIdentifier(model.KeyProperty.StorageName).AppendLine(",");
sb.AppendLine(" RANK() OVER (ORDER BY ft_rank DESC) AS [rank]");
sb.AppendLine(" FROM (");
sb.AppendLine(" SELECT TOP(@candidateCount)");
sb.Append(" w.").AppendIdentifier(model.KeyProperty.StorageName).AppendLine(",");
sb.AppendLine(" ftt.[RANK] AS ft_rank");
sb.Append(" FROM ").AppendTableName(schema, tableName).AppendLine(" w");
sb.Append(" INNER JOIN FREETEXTTABLE(").AppendTableName(schema, tableName).Append(", ")
.AppendIdentifier(textProperty.StorageName).AppendLine(", @keywords) AS ftt");
sb.Append(" ON w.").AppendIdentifier(model.KeyProperty.StorageName).AppendLine(" = ftt.[KEY]");

// Apply filter to keyword search if specified
if (options.Filter is not null)
{
int startParamIndex = command.Parameters.Count;
SqlServerFilterTranslator translator = new(model, options.Filter, sb, startParamIndex: startParamIndex, tableAlias: "w");
translator.Translate(appendWhere: true);
foreach (object parameter in translator.ParameterValues)
{
command.AddParameter(property: null, $"@_{startParamIndex++}", parameter);
}
sb.AppendLine();
}

sb.AppendLine(" ORDER BY ft_rank DESC");
sb.AppendLine(" ) AS freetext_documents");
sb.AppendLine("),");

// CTE 2: Semantic/vector search
sb.AppendLine("semantic_search AS (");
sb.AppendLine(" SELECT TOP(@candidateCount)");
sb.Append(" ").AppendIdentifier(model.KeyProperty.StorageName).AppendLine(",");
sb.AppendLine(" RANK() OVER (ORDER BY cosine_distance) AS [rank]");
sb.AppendLine(" FROM (");
sb.AppendLine(" SELECT TOP(@candidateCount)");
sb.Append(" w.").AppendIdentifier(model.KeyProperty.StorageName).AppendLine(",");
sb.Append(" VECTOR_DISTANCE('").Append(distanceMetric).Append("', ")
.AppendIdentifier(vectorProperty.StorageName)
.Append(", CAST(@vector AS VECTOR(").Append(vector.Length).AppendLine("))) AS cosine_distance");
sb.Append(" FROM ").AppendTableName(schema, tableName).AppendLine(" w");

// Apply filter to semantic search if specified
if (options.Filter is not null)
{
// We need to re-translate the filter for the semantic search CTE
// The parameters are already added from keyword search, so we start fresh for this CTE
int filterParamStart = command.Parameters.Count;
SqlServerFilterTranslator translator = new(model, options.Filter, sb, startParamIndex: filterParamStart, tableAlias: "w");
translator.Translate(appendWhere: true);
foreach (object parameter in translator.ParameterValues)
{
command.AddParameter(property: null, $"@_{filterParamStart++}", parameter);
}
sb.AppendLine();
}

sb.AppendLine(" ORDER BY cosine_distance");
sb.AppendLine(" ) AS similar_documents");
sb.AppendLine("),");

// CTE 3: Combined results with RRF scoring
sb.AppendLine("hybrid_result AS (");
sb.AppendLine(" SELECT");
sb.Append(" COALESCE(ss.").AppendIdentifier(model.KeyProperty.StorageName)
.Append(", ks.").AppendIdentifier(model.KeyProperty.StorageName).AppendLine(") AS combined_key,");
sb.AppendLine(" ss.[rank] AS semantic_rank,");
sb.AppendLine(" ks.[rank] AS keyword_rank,");
// Cast to FLOAT to match the expected return type in C# (double)
// Use @rrfK as the RRF constant (typically 60)
sb.AppendLine(" CAST(COALESCE(1.0 / (@rrfK + ss.[rank]), 0.0) + COALESCE(1.0 / (@rrfK + ks.[rank]), 0.0) AS FLOAT) AS [score]");
sb.AppendLine(" FROM semantic_search ss");
sb.Append(" FULL OUTER JOIN keyword_search ks ON ss.").AppendIdentifier(model.KeyProperty.StorageName)
.Append(" = ks.").AppendIdentifier(model.KeyProperty.StorageName).AppendLine();
sb.AppendLine(")");

// Final SELECT joining back to the main table
sb.Append("SELECT ");
foreach (var property in model.Properties)
{
if (!options.IncludeVectors && property is VectorPropertyModel)
{
continue;
}
sb.Append("w.").AppendIdentifier(property.StorageName).Append(',');
}
sb.Length--; // remove trailing comma
sb.AppendLine(",");
sb.AppendLine(" hr.[score]");
sb.AppendLine("FROM hybrid_result hr");
sb.Append("INNER JOIN ").AppendTableName(schema, tableName).AppendLine(" w");
sb.Append(" ON hr.combined_key = w.").AppendIdentifier(model.KeyProperty.StorageName).AppendLine();
if (options.ScoreThreshold.HasValue)
{
command.Parameters.AddWithValue("@scoreThreshold", options.ScoreThreshold.Value);
sb.AppendLine("WHERE hr.[score] >= @scoreThreshold");
}
sb.AppendLine("ORDER BY hr.[score] DESC");
sb.AppendFormat("OFFSET {0} ROWS FETCH NEXT {1} ROWS ONLY;", options.Skip, top);

command.CommandText = sb.ToString();
return command;
}

internal static SqlCommand SelectWhere<TRecord>(
Expression<Func<TRecord, bool>> filter,
int top,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,19 @@ namespace Microsoft.SemanticKernel.Connectors.SqlServer;
internal sealed class SqlServerFilterTranslator : SqlFilterTranslator
{
private readonly List<object> _parameterValues = [];
private readonly string? _tableAlias;
private int _parameterIndex;

internal SqlServerFilterTranslator(
CollectionModel model,
LambdaExpression lambdaExpression,
StringBuilder sql,
int startParamIndex)
int startParamIndex,
string? tableAlias = null)
: base(model, lambdaExpression, sql)
{
this._parameterIndex = startParamIndex;
this._tableAlias = tableAlias;
}

internal List<object> ParameterValues => this._parameterValues;
Expand Down Expand Up @@ -65,6 +68,10 @@ protected override void TranslateConstant(object? value, bool isSearchCondition)
protected override void GenerateColumn(PropertyModel property, bool isSearchCondition = false)
{
// StorageName is considered to be a safe input, we quote and escape it mostly to produce valid SQL.
if (this._tableAlias is not null)
{
this._sql.Append(this._tableAlias).Append('.');
}
this._sql.Append('[').Append(property.StorageName.Replace("]", "]]")).Append(']');

// "SELECT * FROM MyTable WHERE BooleanColumn;" is not supported.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ public static IServiceCollection AddKeyedSqlServerCollection<TKey, TRecord>(
services.Add(new ServiceDescriptor(typeof(IVectorSearchable<TRecord>), serviceKey,
static (sp, key) => sp.GetRequiredKeyedService<SqlServerCollection<TKey, TRecord>>(key), lifetime));

// Once HybridSearch supports get implemented (https://github.com/microsoft/semantic-kernel/issues/11080)
// we need to add IKeywordHybridSearchable abstraction here as well.
services.Add(new ServiceDescriptor(typeof(IKeywordHybridSearchable<TRecord>), serviceKey,
static (sp, key) => sp.GetRequiredKeyedService<SqlServerCollection<TKey, TRecord>>(key), lifetime));

return services;
}
Expand Down
Loading
Loading