-
Notifications
You must be signed in to change notification settings - Fork 855
Enhancing VectorStoreWriter for better RAG support #7396
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: data-ingestion-preview2
Are you sure you want to change the base?
Changes from all commits
ea56d33
581bce0
727bcd5
e48fa9d
e9aa7fe
56fa7ad
c3c57e0
b3153bf
550781c
def5e6a
64af799
81f4817
f20bb41
4558e6c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,65 @@ | ||
| // Licensed to the .NET Foundation under one or more agreements. | ||
| // The .NET Foundation licenses this file to you under the MIT license. | ||
|
|
||
| using System; | ||
| using Microsoft.Extensions.VectorData; | ||
|
|
||
| namespace Microsoft.Extensions.DataIngestion; | ||
|
|
||
| /// <summary> | ||
| /// Represents the base record type used by <see cref="VectorStoreWriter{TChunk, TRecord}"/> to store ingested chunks in a vector store. | ||
| /// </summary> | ||
| /// <typeparam name="TChunk">The type of the chunk content.</typeparam> | ||
| /// <remarks> | ||
| /// When the vector dimension count is not known at compile time, | ||
| /// use the <see cref="VectorStoreExtensions.GetIngestionRecordCollection{TRecord, TChunk}(VectorStore, string, int, string, string?, string?)"/> | ||
| /// helper to create a <see cref="VectorStoreCollection{TKey, TRecord}"/> and pass it to the <see cref="VectorStoreWriter{TChunk, TRecord}"/> constructor. | ||
| /// When the vector dimension count is known at compile time, derive from this class and add | ||
| /// the <see cref="VectorStoreVectorAttribute"/> to the <see cref="Embedding"/> property. | ||
| /// </remarks> | ||
| public class IngestedChunkRecord<TChunk> | ||
| { | ||
| /// <summary> | ||
| /// The storage name for the <see cref="Embedding"/> property. | ||
| /// </summary> | ||
| protected const string EmbeddingStorageName = VectorStoreExtensions.EmbeddingStorageName; | ||
|
|
||
| private const string KeyStorageName = "key"; | ||
| private const string DocumentIdStorageName = "documentid"; | ||
| private const string ContentStorageName = "content"; | ||
| private const string ContextStorageName = "context"; | ||
|
|
||
| /// <summary> | ||
| /// Gets or sets the unique key for this record. | ||
| /// </summary> | ||
| [VectorStoreKey(StorageName = KeyStorageName)] | ||
| public virtual Guid Key { get; set; } | ||
|
|
||
| /// <summary> | ||
| /// Gets or sets the identifier of the document from which this chunk was extracted. | ||
| /// </summary> | ||
| [VectorStoreData(StorageName = DocumentIdStorageName)] | ||
| public virtual string DocumentId { get; set; } = string.Empty; | ||
|
|
||
| /// <summary> | ||
| /// Gets or sets the content of the chunk. | ||
| /// </summary> | ||
| [VectorStoreData(StorageName = ContentStorageName)] | ||
| public virtual TChunk? Content { get; set; } | ||
|
|
||
| /// <summary> | ||
| /// Gets or sets additional context for the chunk. | ||
| /// </summary> | ||
| [VectorStoreData(StorageName = ContextStorageName)] | ||
| public virtual string? Context { get; set; } | ||
|
|
||
| /// <summary> | ||
| /// Gets the embedding value for this record. | ||
| /// </summary> | ||
| /// <remarks> | ||
| /// By default, returns the <see cref="Content"/> value. The vector store's embedding generator | ||
| /// will convert this to a vector. Override this property in derived classes to add | ||
| /// the <see cref="VectorStoreVectorAttribute"/> with the appropriate dimension count. | ||
| /// </remarks> | ||
| public virtual TChunk? Embedding => Content; | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,60 @@ | ||
| // Licensed to the .NET Foundation under one or more agreements. | ||
| // The .NET Foundation licenses this file to you under the MIT license. | ||
|
|
||
| using System; | ||
| using System.Diagnostics.CodeAnalysis; | ||
| using Microsoft.Extensions.VectorData; | ||
|
|
||
| namespace Microsoft.Extensions.DataIngestion; | ||
|
|
||
| /// <summary> | ||
| /// Provides extension methods for working with vector stores in the context of data ingestion. | ||
| /// </summary> | ||
| public static class VectorStoreExtensions | ||
| { | ||
| internal const string EmbeddingStorageName = "embedding"; | ||
|
|
||
| /// <summary> | ||
| /// Provides a convenient method to get a vector store collection specifically designed for storing ingested chunk records. | ||
| /// </summary> | ||
| /// <typeparam name="TRecord">The type of the record to be stored in the collection.</typeparam> | ||
| /// <typeparam name="TChunk">The type of the chunk content.</typeparam> | ||
| /// <param name="vectorStore">The vector store instance to create the collection in.</param> | ||
| /// <param name="collectionName">The name of the collection to be created.</param> | ||
| /// <param name="dimensionCount">The number of dimensions that the vector has.</param> | ||
| /// <param name="storageName">The storage name for the vector property.</param> | ||
| /// <param name="distanceFunction"> | ||
| /// The distance function to use. When not provided, the default specific to given database will be used. | ||
| /// Check <see cref="DistanceFunction"/> for available values. | ||
| /// </param> | ||
| /// <param name="indexKind">The index kind to use.</param> | ||
| /// <returns>A vector store collection configured for ingested chunk records.</returns> | ||
| [RequiresDynamicCode("This API is not compatible with NativeAOT. You can implement your own IngestionChunkWriter that uses dynamic mapping via VectorStore.GetCollectionDynamic().")] | ||
| [RequiresUnreferencedCode("This API is not compatible with trimming. You can implement your own IngestionChunkWriter that uses dynamic mapping via VectorStore.GetCollectionDynamic().")] | ||
| public static VectorStoreCollection<Guid, TRecord> GetIngestionRecordCollection<TRecord, TChunk>(this VectorStore vectorStore, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To other reviewers: This is an alternative to exposing entire schema. We can just expose a factory method that does the right thing. The advantages;
The disadvantages:
|
||
| string collectionName, int dimensionCount, string storageName = EmbeddingStorageName, string? distanceFunction = null, string? indexKind = null) | ||
| where TRecord : IngestedChunkRecord<TChunk>, new() | ||
| { | ||
| _ = Shared.Diagnostics.Throw.IfNull(vectorStore); | ||
| _ = Shared.Diagnostics.Throw.IfNullOrEmpty(collectionName); | ||
| _ = Shared.Diagnostics.Throw.IfLessThanOrEqual(dimensionCount, 0); | ||
| _ = Shared.Diagnostics.Throw.IfNullOrEmpty(storageName); | ||
|
|
||
| VectorStoreCollectionDefinition additiveDefintion = new() | ||
| { | ||
| Properties = | ||
| { | ||
| // By using TChunk as the type here we allow the vector store | ||
| // to handle the conversion from TChunk to the actual vector type it supports. | ||
| new VectorStoreVectorProperty(nameof(IngestedChunkRecord<>.Embedding), typeof(TChunk), dimensionCount) | ||
| { | ||
| StorageName = storageName, | ||
| DistanceFunction = distanceFunction, | ||
| IndexKind = indexKind, | ||
| }, | ||
| }, | ||
| }; | ||
|
|
||
| return vectorStore.GetCollection<Guid, TRecord>(collectionName, additiveDefintion); | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,88 +11,72 @@ | |
| namespace Microsoft.Extensions.DataIngestion; | ||
|
|
||
| /// <summary> | ||
| /// Writes chunks to the <see cref="VectorStore"/> using the default schema. | ||
| /// Writes chunks to a <see cref="VectorStoreCollection{TKey, TRecord}"/>. | ||
| /// </summary> | ||
| /// <typeparam name="T">The type of the chunk content.</typeparam> | ||
| public sealed class VectorStoreWriter<T> : IngestionChunkWriter<T> | ||
| /// <typeparam name="TChunk">The type of the chunk content.</typeparam> | ||
| /// <typeparam name="TRecord">The type of the record stored in the vector store.</typeparam> | ||
| public class VectorStoreWriter<TChunk, TRecord> : IngestionChunkWriter<TChunk> | ||
| where TRecord : IngestedChunkRecord<TChunk>, new() | ||
| { | ||
| // The names are lowercase with no special characters to ensure compatibility with various vector stores. | ||
| private const string KeyName = "key"; | ||
| private const string EmbeddingName = "embedding"; | ||
| private const string ContentName = "content"; | ||
| private const string ContextName = "context"; | ||
| private const string DocumentIdName = "documentid"; | ||
|
|
||
| private readonly VectorStore _vectorStore; | ||
| private readonly int _dimensionCount; | ||
| private readonly VectorStoreWriterOptions _options; | ||
|
|
||
| private VectorStoreCollection<object, Dictionary<string, object?>>? _vectorStoreCollection; | ||
| private bool _collectionEnsured; | ||
|
|
||
| /// <summary> | ||
| /// Initializes a new instance of the <see cref="VectorStoreWriter{T}"/> class. | ||
| /// Initializes a new instance of the <see cref="VectorStoreWriter{TChunk, TRecord}"/> class. | ||
| /// </summary> | ||
| /// <param name="vectorStore">The <see cref="VectorStore"/> to use to store the <see cref="IngestionChunk{T}"/> instances.</param> | ||
| /// <param name="dimensionCount">The number of dimensions that the vector has. This value is required when creating collections.</param> | ||
| /// <param name="collection">The <see cref="VectorStoreCollection{TKey, TRecord}"/> to use to store the <see cref="IngestionChunk{T}"/> instances.</param> | ||
| /// <param name="options">The options for the vector store writer.</param> | ||
| /// <exception cref="ArgumentNullException">When <paramref name="vectorStore"/> is null.</exception> | ||
| /// <exception cref="ArgumentOutOfRangeException">When <paramref name="dimensionCount"/> is less or equal zero.</exception> | ||
| public VectorStoreWriter(VectorStore vectorStore, int dimensionCount, VectorStoreWriterOptions? options = default) | ||
| /// <exception cref="ArgumentNullException">When <paramref name="collection"/> is null.</exception> | ||
| /// <remarks> | ||
| /// You can use the <see cref="VectorStoreExtensions.GetIngestionRecordCollection{TRecord, TChunk}(VectorStore, string, int, string, string?, string?)"/> | ||
| /// helper to create a <see cref="VectorStoreCollection{TKey, TRecord}"/> with the appropriate schema for storing ingestion chunks. | ||
| /// </remarks> | ||
| public VectorStoreWriter(VectorStoreCollection<Guid, TRecord> collection, VectorStoreWriterOptions? options = default) | ||
| { | ||
| _vectorStore = Throw.IfNull(vectorStore); | ||
| _dimensionCount = Throw.IfLessThanOrEqual(dimensionCount, 0); | ||
| VectorStoreCollection = Throw.IfNull(collection); | ||
| _options = options ?? new VectorStoreWriterOptions(); | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Gets the underlying <see cref="VectorStoreCollection{TKey,TRecord}"/> used to store the chunks. | ||
| /// </summary> | ||
| /// <remarks> | ||
| /// The collection is initialized when <see cref="WriteAsync(IAsyncEnumerable{IngestionChunk{T}}, CancellationToken)"/> is called for the first time. | ||
| /// </remarks> | ||
| /// <exception cref="InvalidOperationException">The collection has not been initialized yet. | ||
| /// Call <see cref="WriteAsync(IAsyncEnumerable{IngestionChunk{T}}, CancellationToken)"/> first.</exception> | ||
| public VectorStoreCollection<object, Dictionary<string, object?>> VectorStoreCollection | ||
| => _vectorStoreCollection ?? throw new InvalidOperationException("The collection has not been initialized yet. Call WriteAsync first."); | ||
| public VectorStoreCollection<Guid, TRecord> VectorStoreCollection { get; } | ||
|
|
||
| /// <inheritdoc/> | ||
| public override async Task WriteAsync(IAsyncEnumerable<IngestionChunk<T>> chunks, CancellationToken cancellationToken = default) | ||
| public override async Task WriteAsync(IAsyncEnumerable<IngestionChunk<TChunk>> chunks, CancellationToken cancellationToken = default) | ||
| { | ||
| _ = Throw.IfNull(chunks); | ||
|
|
||
| IReadOnlyList<object>? preExistingKeys = null; | ||
| List<Dictionary<string, object?>>? batch = null; | ||
| IReadOnlyList<Guid>? preExistingKeys = null; | ||
| List<TRecord>? batch = null; | ||
| long currentBatchTokenCount = 0; | ||
|
|
||
| await foreach (IngestionChunk<T> chunk in chunks.WithCancellation(cancellationToken)) | ||
| await foreach (IngestionChunk<TChunk> chunk in chunks.WithCancellation(cancellationToken)) | ||
| { | ||
| if (_vectorStoreCollection is null) | ||
| if (!_collectionEnsured) | ||
| { | ||
| _vectorStoreCollection = _vectorStore.GetDynamicCollection(_options.CollectionName, GetVectorStoreRecordDefinition(chunk)); | ||
|
|
||
| await _vectorStoreCollection.EnsureCollectionExistsAsync(cancellationToken).ConfigureAwait(false); | ||
| await VectorStoreCollection.EnsureCollectionExistsAsync(cancellationToken).ConfigureAwait(false); | ||
| _collectionEnsured = true; | ||
| } | ||
|
|
||
| // We obtain the IDs of the pre-existing chunks for given document, | ||
| // and delete them after we finish inserting the new chunks, | ||
| // to avoid a situation where we delete the chunks and then fail to insert the new ones. | ||
| preExistingKeys ??= await GetPreExistingChunksIdsAsync(chunk.Document, cancellationToken).ConfigureAwait(false); | ||
|
|
||
| var key = Guid.NewGuid(); | ||
| Dictionary<string, object?> record = new() | ||
| TRecord record = new() | ||
| { | ||
| [KeyName] = key, | ||
| [ContentName] = chunk.Content, | ||
| [EmbeddingName] = chunk.Content, | ||
| [ContextName] = chunk.Context, | ||
| [DocumentIdName] = chunk.Document.Identifier, | ||
| Key = Guid.NewGuid(), | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We discussed it offline and agreed I am going to remove it in a separate PR (#7410), as it most likely going to require a NuGet update and mirror sync. |
||
| Content = chunk.Content, | ||
| Context = chunk.Context, | ||
| DocumentId = chunk.Document.Identifier, | ||
| }; | ||
|
|
||
| if (chunk.HasMetadata) | ||
| { | ||
| foreach (var metadata in chunk.Metadata) | ||
| { | ||
| record[metadata.Key] = metadata.Value; | ||
| SetMetadata(record, metadata.Key, metadata.Value); | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -102,7 +86,7 @@ public override async Task WriteAsync(IAsyncEnumerable<IngestionChunk<T>> chunks | |
| // If the batch is empty or the chunk alone exceeds the limit, add it anyway. | ||
| if (batch.Count > 0 && currentBatchTokenCount + chunk.TokenCount > _options.BatchTokenCount) | ||
| { | ||
| await _vectorStoreCollection.UpsertAsync(batch, cancellationToken).ConfigureAwait(false); | ||
| await VectorStoreCollection.UpsertAsync(batch, cancellationToken).ConfigureAwait(false); | ||
|
|
||
| batch.Clear(); | ||
| currentBatchTokenCount = 0; | ||
|
|
@@ -115,75 +99,31 @@ public override async Task WriteAsync(IAsyncEnumerable<IngestionChunk<T>> chunks | |
| // Upsert any remaining chunks in the batch | ||
| if (batch?.Count > 0) | ||
| { | ||
| await _vectorStoreCollection!.UpsertAsync(batch, cancellationToken).ConfigureAwait(false); | ||
| await VectorStoreCollection.UpsertAsync(batch, cancellationToken).ConfigureAwait(false); | ||
| } | ||
|
|
||
| if (preExistingKeys?.Count > 0) | ||
| { | ||
| await _vectorStoreCollection!.DeleteAsync(preExistingKeys, cancellationToken).ConfigureAwait(false); | ||
| } | ||
| } | ||
|
|
||
| /// <inheritdoc/> | ||
| protected override void Dispose(bool disposing) | ||
| { | ||
| try | ||
| { | ||
| _vectorStoreCollection?.Dispose(); | ||
| } | ||
| finally | ||
| { | ||
| _vectorStore.Dispose(); | ||
| base.Dispose(disposing); | ||
| await VectorStoreCollection.DeleteAsync(preExistingKeys, cancellationToken).ConfigureAwait(false); | ||
| } | ||
| } | ||
|
|
||
| private VectorStoreCollectionDefinition GetVectorStoreRecordDefinition(IngestionChunk<T> representativeChunk) | ||
| /// <summary> | ||
| /// Sets a metadata value on the record. | ||
| /// </summary> | ||
| /// <param name="record">The record on which to set the metadata.</param> | ||
| /// <param name="key">The metadata key.</param> | ||
| /// <param name="value">The metadata value.</param> | ||
| /// <remarks> | ||
| /// Override this method in derived classes to store metadata as typed properties with | ||
| /// <see cref="VectorStoreDataAttribute"/> attributes. | ||
| /// </remarks> | ||
| protected virtual void SetMetadata(TRecord record, string key, object? value) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To other reviewers: So far, we were optimized for very easy ingestion. Now, the RAG is way simpler but when you need to use metadata, you need to create a derived type and handle it on your own. We throw here to avoid silent errors. |
||
| { | ||
| VectorStoreCollectionDefinition definition = new() | ||
| { | ||
| Properties = | ||
| { | ||
| new VectorStoreKeyProperty(KeyName, typeof(Guid)), | ||
|
|
||
| // By using T as the type here we allow the vector store | ||
| // to handle the conversion from T to the actual vector type it supports. | ||
| new VectorStoreVectorProperty(EmbeddingName, typeof(T), _dimensionCount) | ||
| { | ||
| DistanceFunction = _options.DistanceFunction, | ||
| IndexKind = _options.IndexKind | ||
| }, | ||
| new VectorStoreDataProperty(ContentName, typeof(T)), | ||
| new VectorStoreDataProperty(ContextName, typeof(string)), | ||
| new VectorStoreDataProperty(DocumentIdName, typeof(string)) | ||
| { | ||
| IsIndexed = true | ||
| } | ||
| } | ||
| }; | ||
|
|
||
| if (representativeChunk.HasMetadata) | ||
| { | ||
| foreach (var metadata in representativeChunk.Metadata) | ||
| { | ||
| Type propertyType = metadata.Value.GetType(); | ||
| definition.Properties.Add(new VectorStoreDataProperty(metadata.Key, propertyType) | ||
| { | ||
| // We use lowercase storage names to ensure compatibility with various vector stores. | ||
| #pragma warning disable CA1308 // Normalize strings to uppercase | ||
| StorageName = metadata.Key.ToLowerInvariant() | ||
| #pragma warning restore CA1308 // Normalize strings to uppercase | ||
|
|
||
| // We could consider indexing for certain keys like classification etc. but for now we leave it as non-indexed. | ||
| // The reason is that not every DB supports it, moreover we would need to expose the ability to configure it. | ||
| }); | ||
| } | ||
| } | ||
|
|
||
| return definition; | ||
| throw new NotSupportedException($"Metadata key '{key}' is not supported. Override {nameof(SetMetadata)} in a derived class to handle metadata."); | ||
| } | ||
|
|
||
| private async Task<IReadOnlyList<object>> GetPreExistingChunksIdsAsync(IngestionDocument document, CancellationToken cancellationToken) | ||
| private async Task<IReadOnlyList<Guid>> GetPreExistingChunksIdsAsync(IngestionDocument document, CancellationToken cancellationToken) | ||
| { | ||
| if (!_options.IncrementalIngestion) | ||
| { | ||
|
|
@@ -193,19 +133,19 @@ private async Task<IReadOnlyList<object>> GetPreExistingChunksIdsAsync(Ingestion | |
| // Each Vector Store has a different max top count limit, so we use low value and loop. | ||
| const int MaxTopCount = 1_000; | ||
|
|
||
| List<object> keys = []; | ||
| List<Guid> keys = []; | ||
| int insertedCount; | ||
| do | ||
| { | ||
| insertedCount = 0; | ||
|
|
||
| await foreach (var record in _vectorStoreCollection!.GetAsync( | ||
| filter: record => (string)record[DocumentIdName]! == document.Identifier, | ||
| await foreach (var record in VectorStoreCollection.GetAsync( | ||
| filter: record => record.DocumentId == document.Identifier, | ||
| top: MaxTopCount, | ||
| options: new() { Skip = keys.Count }, | ||
| cancellationToken: cancellationToken).ConfigureAwait(false)) | ||
| { | ||
| keys.Add(record[KeyName]!); | ||
| keys.Add(record.Key); | ||
| insertedCount++; | ||
| } | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To other reviewers: I am going to work on removing this generic argument very soon (I want
IngestionChunkto be able to represent any input without using generic argument). But it's out of the scope of this PR.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I know you're going to remove the generic argument, but FYI the TChunk name is causing me a bit of confusion, also in
IngestionChunk<TChunk>(as if it's a chunk over itself). When reading this code I wasn't sure if withIngestedChunkRecord<TChunk>, TChunk should be string orIngestionChunk<string>.So maybe consider renaming TChunk to just T (or TContent) everywhere.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I really want to remove it and now I think I even know how ( #7404)
Please keep in mind it's Preview2 branch, so whatever gets merged does not automatically gets released to nuget.org. So I would prefer to keep
TChunkhere and just remove it completely in next PR.