Skip to content

Commit 71cf460

Browse files
committed
Implement vector delete and search operations with corresponding results in Dredis
1 parent d16d885 commit 71cf460

4 files changed

Lines changed: 413 additions & 3 deletions

File tree

Dredis.Abstractions.Storage/IKeyValueStore.cs

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,6 +1036,78 @@ public VectorSimilarityResult(VectorResultStatus status, double? value)
10361036
public double? Value { get; }
10371037
}
10381038

1039+
/// <summary>
1040+
/// Represents a result for vector delete operations.
1041+
/// </summary>
1042+
public sealed class VectorDeleteResult
1043+
{
1044+
/// <summary>
1045+
/// Initializes a new instance of the <see cref="VectorDeleteResult"/> class.
1046+
/// </summary>
1047+
public VectorDeleteResult(VectorResultStatus status, long deleted)
1048+
{
1049+
Status = status;
1050+
Deleted = deleted;
1051+
}
1052+
1053+
/// <summary>
1054+
/// Gets the result status.
1055+
/// </summary>
1056+
public VectorResultStatus Status { get; }
1057+
/// <summary>
1058+
/// Gets the deleted vector count.
1059+
/// </summary>
1060+
public long Deleted { get; }
1061+
}
1062+
1063+
/// <summary>
1064+
/// Represents a vector search entry.
1065+
/// </summary>
1066+
public sealed class VectorSearchEntry
1067+
{
1068+
/// <summary>
1069+
/// Initializes a new instance of the <see cref="VectorSearchEntry"/> class.
1070+
/// </summary>
1071+
public VectorSearchEntry(string key, double score)
1072+
{
1073+
Key = key;
1074+
Score = score;
1075+
}
1076+
1077+
/// <summary>
1078+
/// Gets the vector key.
1079+
/// </summary>
1080+
public string Key { get; }
1081+
/// <summary>
1082+
/// Gets the vector score.
1083+
/// </summary>
1084+
public double Score { get; }
1085+
}
1086+
1087+
/// <summary>
1088+
/// Represents a result for vector search operations.
1089+
/// </summary>
1090+
public sealed class VectorSearchResult
1091+
{
1092+
/// <summary>
1093+
/// Initializes a new instance of the <see cref="VectorSearchResult"/> class.
1094+
/// </summary>
1095+
public VectorSearchResult(VectorResultStatus status, VectorSearchEntry[] entries)
1096+
{
1097+
Status = status;
1098+
Entries = entries;
1099+
}
1100+
1101+
/// <summary>
1102+
/// Gets the result status.
1103+
/// </summary>
1104+
public VectorResultStatus Status { get; }
1105+
/// <summary>
1106+
/// Gets the search result entries.
1107+
/// </summary>
1108+
public VectorSearchEntry[] Entries { get; }
1109+
}
1110+
10391111
/// <summary>
10401112
/// Key-Value Storage abstraction for Dredis.
10411113
/// </summary>
@@ -1972,5 +2044,31 @@ Task<VectorSimilarityResult> VectorSimilarityAsync(
19722044
string otherKey,
19732045
string metric,
19742046
CancellationToken token = default);
2047+
2048+
/// <summary>
2049+
/// Deletes a vector key.
2050+
/// </summary>
2051+
/// <param name="key">The vector key.</param>
2052+
/// <param name="token">A cancellation token that can be used to cancel the operation.</param>
2053+
/// <returns>A task that represents the asynchronous operation. The task result contains delete status and count.</returns>
2054+
Task<VectorDeleteResult> VectorDeleteAsync(
2055+
string key,
2056+
CancellationToken token = default);
2057+
2058+
/// <summary>
2059+
/// Searches vectors by prefix and returns top-k scored matches.
2060+
/// </summary>
2061+
/// <param name="keyPrefix">The key prefix to search (ordinal starts-with).</param>
2062+
/// <param name="topK">The maximum number of results to return.</param>
2063+
/// <param name="metric">The metric name (COSINE, DOT, or L2).</param>
2064+
/// <param name="queryVector">The query vector.</param>
2065+
/// <param name="token">A cancellation token that can be used to cancel the operation.</param>
2066+
/// <returns>A task that represents the asynchronous operation. The task result contains scored entries.</returns>
2067+
Task<VectorSearchResult> VectorSearchAsync(
2068+
string keyPrefix,
2069+
int topK,
2070+
string metric,
2071+
double[] queryVector,
2072+
CancellationToken token = default);
19752073
}
19762074
}

Dredis.Tests/DredisCommandHandlerTests.cs

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1535,6 +1535,101 @@ public async Task Vector_WrongTypeAndInvalidOp_ReturnErrors()
15351535
}
15361536
}
15371537

1538+
[Fact]
1539+
public async Task Vector_Delete_ReturnsCountAndRemovesVector()
1540+
{
1541+
var store = new InMemoryKeyValueStore();
1542+
var channel = new EmbeddedChannel(new DredisCommandHandler(store));
1543+
1544+
try
1545+
{
1546+
channel.WriteInbound(Command("VSET", "vec:1", "1", "2"));
1547+
channel.RunPendingTasks();
1548+
_ = ReadOutbound(channel);
1549+
1550+
channel.WriteInbound(Command("VDEL", "vec:1"));
1551+
channel.RunPendingTasks();
1552+
var delResponse = ReadOutbound(channel);
1553+
var deleted = Assert.IsType<IntegerRedisMessage>(delResponse);
1554+
Assert.Equal(1, deleted.Value);
1555+
1556+
channel.WriteInbound(Command("VGET", "vec:1"));
1557+
channel.RunPendingTasks();
1558+
var getResponse = ReadOutbound(channel);
1559+
Assert.Same(FullBulkStringRedisMessage.Null, getResponse);
1560+
}
1561+
finally
1562+
{
1563+
await channel.CloseAsync();
1564+
}
1565+
}
1566+
1567+
[Fact]
1568+
public async Task Vector_Search_ReturnsTopKByMetric()
1569+
{
1570+
var store = new InMemoryKeyValueStore();
1571+
var channel = new EmbeddedChannel(new DredisCommandHandler(store));
1572+
1573+
try
1574+
{
1575+
channel.WriteInbound(Command("VSET", "emb:a", "1", "0"));
1576+
channel.RunPendingTasks();
1577+
_ = ReadOutbound(channel);
1578+
1579+
channel.WriteInbound(Command("VSET", "emb:b", "0.9", "0.1"));
1580+
channel.RunPendingTasks();
1581+
_ = ReadOutbound(channel);
1582+
1583+
channel.WriteInbound(Command("VSET", "emb:c", "-1", "0"));
1584+
channel.RunPendingTasks();
1585+
_ = ReadOutbound(channel);
1586+
1587+
channel.WriteInbound(Command("VSEARCH", "emb:", "2", "COSINE", "1", "0"));
1588+
channel.RunPendingTasks();
1589+
var cosineResponse = ReadOutbound(channel);
1590+
var cosine = Assert.IsType<ArrayRedisMessage>(cosineResponse);
1591+
Assert.Equal(4, cosine.Children.Count);
1592+
Assert.Equal("emb:a", GetBulkString(Assert.IsType<FullBulkStringRedisMessage>(cosine.Children[0])));
1593+
Assert.Equal("emb:b", GetBulkString(Assert.IsType<FullBulkStringRedisMessage>(cosine.Children[2])));
1594+
1595+
channel.WriteInbound(Command("VSEARCH", "emb:", "1", "L2", "1", "0"));
1596+
channel.RunPendingTasks();
1597+
var l2Response = ReadOutbound(channel);
1598+
var l2 = Assert.IsType<ArrayRedisMessage>(l2Response);
1599+
Assert.Equal(2, l2.Children.Count);
1600+
Assert.Equal("emb:a", GetBulkString(Assert.IsType<FullBulkStringRedisMessage>(l2.Children[0])));
1601+
Assert.Equal("0", GetBulkString(Assert.IsType<FullBulkStringRedisMessage>(l2.Children[1])));
1602+
}
1603+
finally
1604+
{
1605+
await channel.CloseAsync();
1606+
}
1607+
}
1608+
1609+
[Fact]
1610+
public async Task Vector_Search_InvalidMetric_ReturnsError()
1611+
{
1612+
var store = new InMemoryKeyValueStore();
1613+
var channel = new EmbeddedChannel(new DredisCommandHandler(store));
1614+
1615+
try
1616+
{
1617+
channel.WriteInbound(Command("VSET", "emb:a", "1", "0"));
1618+
channel.RunPendingTasks();
1619+
_ = ReadOutbound(channel);
1620+
1621+
channel.WriteInbound(Command("VSEARCH", "emb:", "5", "BAD", "1", "0"));
1622+
channel.RunPendingTasks();
1623+
var response = ReadOutbound(channel);
1624+
var error = Assert.IsType<ErrorRedisMessage>(response);
1625+
Assert.Equal("ERR invalid vector operation", error.Content);
1626+
}
1627+
finally
1628+
{
1629+
await channel.CloseAsync();
1630+
}
1631+
}
1632+
15381633
[Fact]
15391634
public async Task Publish_NoSubscribers_ReturnsZero()
15401635
{
@@ -4350,6 +4445,122 @@ public Task<VectorSimilarityResult> VectorSimilarityAsync(
43504445
return Task.FromResult(new VectorSimilarityResult(VectorResultStatus.InvalidArgument, null));
43514446
}
43524447

4448+
public Task<VectorDeleteResult> VectorDeleteAsync(
4449+
string key,
4450+
CancellationToken token = default)
4451+
{
4452+
if (IsExpired(key))
4453+
{
4454+
RemoveKey(key);
4455+
}
4456+
4457+
if (_data.ContainsKey(key) || _hashes.ContainsKey(key) || _lists.ContainsKey(key) || _sets.ContainsKey(key) || _sortedSets.ContainsKey(key) || _streams.ContainsKey(key) || _streamGroups.ContainsKey(key))
4458+
{
4459+
return Task.FromResult(new VectorDeleteResult(VectorResultStatus.WrongType, 0));
4460+
}
4461+
4462+
var removed = _vectors.Remove(key);
4463+
if (removed)
4464+
{
4465+
_expirations.Remove(key);
4466+
}
4467+
4468+
return Task.FromResult(new VectorDeleteResult(VectorResultStatus.Ok, removed ? 1 : 0));
4469+
}
4470+
4471+
public Task<VectorSearchResult> VectorSearchAsync(
4472+
string keyPrefix,
4473+
int topK,
4474+
string metric,
4475+
double[] queryVector,
4476+
CancellationToken token = default)
4477+
{
4478+
if (queryVector.Length == 0 || topK <= 0)
4479+
{
4480+
return Task.FromResult(new VectorSearchResult(VectorResultStatus.InvalidArgument, Array.Empty<VectorSearchEntry>()));
4481+
}
4482+
4483+
var scored = new List<VectorSearchEntry>();
4484+
foreach (var kvp in _vectors)
4485+
{
4486+
if (!kvp.Key.StartsWith(keyPrefix, StringComparison.Ordinal))
4487+
{
4488+
continue;
4489+
}
4490+
4491+
if (IsExpired(kvp.Key))
4492+
{
4493+
RemoveKey(kvp.Key);
4494+
continue;
4495+
}
4496+
4497+
var candidate = kvp.Value;
4498+
if (candidate.Length != queryVector.Length)
4499+
{
4500+
continue;
4501+
}
4502+
4503+
double score;
4504+
if (metric.Equals("DOT", StringComparison.OrdinalIgnoreCase))
4505+
{
4506+
score = 0;
4507+
for (int i = 0; i < candidate.Length; i++)
4508+
{
4509+
score += queryVector[i] * candidate[i];
4510+
}
4511+
}
4512+
else if (metric.Equals("COSINE", StringComparison.OrdinalIgnoreCase))
4513+
{
4514+
double dot = 0;
4515+
double queryNorm = 0;
4516+
double candNorm = 0;
4517+
for (int i = 0; i < candidate.Length; i++)
4518+
{
4519+
dot += queryVector[i] * candidate[i];
4520+
queryNorm += queryVector[i] * queryVector[i];
4521+
candNorm += candidate[i] * candidate[i];
4522+
}
4523+
4524+
if (queryNorm <= 0 || candNorm <= 0)
4525+
{
4526+
continue;
4527+
}
4528+
4529+
score = dot / (Math.Sqrt(queryNorm) * Math.Sqrt(candNorm));
4530+
}
4531+
else if (metric.Equals("L2", StringComparison.OrdinalIgnoreCase))
4532+
{
4533+
double sum = 0;
4534+
for (int i = 0; i < candidate.Length; i++)
4535+
{
4536+
var delta = queryVector[i] - candidate[i];
4537+
sum += delta * delta;
4538+
}
4539+
4540+
score = Math.Sqrt(sum);
4541+
}
4542+
else
4543+
{
4544+
return Task.FromResult(new VectorSearchResult(VectorResultStatus.InvalidArgument, Array.Empty<VectorSearchEntry>()));
4545+
}
4546+
4547+
scored.Add(new VectorSearchEntry(kvp.Key, score));
4548+
}
4549+
4550+
IEnumerable<VectorSearchEntry> ordered;
4551+
if (metric.Equals("L2", StringComparison.OrdinalIgnoreCase))
4552+
{
4553+
ordered = scored.OrderBy(entry => entry.Score).ThenBy(entry => entry.Key, StringComparer.Ordinal);
4554+
}
4555+
else
4556+
{
4557+
ordered = scored.OrderByDescending(entry => entry.Score).ThenBy(entry => entry.Key, StringComparer.Ordinal);
4558+
}
4559+
4560+
var top = ordered.Take(topK).ToArray();
4561+
return Task.FromResult(new VectorSearchResult(VectorResultStatus.Ok, top));
4562+
}
4563+
43534564
/// <summary>
43544565
/// Adds a stream entry and returns its id.
43554566
/// </summary>

0 commit comments

Comments
 (0)