From 457057fe9e3e6d370a08351373aa0964d6aff981 Mon Sep 17 00:00:00 2001 From: Jonathan Raphaelson Date: Fri, 12 Jun 2026 02:17:17 -0600 Subject: [PATCH] extends LongTermMemoryService to take TagFilters as params currently, `namespace`, `userId`, `sessionId`, `topics`, and `entities` are hard coded into the LongTermMemoryService's search request as strings (for `eq`),or lists-of-strings (for `any`). replace them with a new `TagFilter` class, which mirrors the filter definitions in the server and other client just encodes the query into the JSON the server see: agent-memory-server/agent_memory_server/models.py SearchRequest see: agent-memory-client/agent-memory-client-js/src/models.ts SearchRequestParams --- .../agentmemory/models/common/TagFilter.java | 97 +++++++++++ .../models/longtermemory/SearchRequest.java | 76 ++++++-- .../services/LongTermMemoryService.java | 17 +- .../models/common/TagFilterTest.java | 55 ++++++ .../services/LongTermMemoryServiceTest.java | 163 ++++++++++++++++++ 5 files changed, 387 insertions(+), 21 deletions(-) create mode 100644 agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/common/TagFilter.java create mode 100644 agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/models/common/TagFilterTest.java diff --git a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/common/TagFilter.java b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/common/TagFilter.java new file mode 100644 index 00000000..cf2cc694 --- /dev/null +++ b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/common/TagFilter.java @@ -0,0 +1,97 @@ +package com.redis.agentmemory.models.common; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; + +import java.util.Arrays; +import java.util.List; + +/** + * Filter for tag-style string fields (user_id, session_id, namespace, topics, entities). + * Match the server-side TagFilter; supports eq, ne, any, all, and startswith operators. + * + *

+ * Example — match any of several user IDs in one search: + *

{@code
+ * SearchRequest.builder()
+ *     .userId(TagFilter.any("user-123", "__account__"))
+ *     .build()
+ * }
+ */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class TagFilter { + + @Nullable + private String eq; + + @Nullable + private String ne; + + @Nullable + @JsonProperty("any") + private List any; + + @Nullable + @JsonProperty("all") + private List all; + + @Nullable + private String startswith; + + private TagFilter() {} + + public static TagFilter eq(@NotNull String value) { + TagFilter f = new TagFilter(); + f.eq = value; + return f; + } + + public static TagFilter ne(@NotNull String value) { + TagFilter f = new TagFilter(); + f.ne = value; + return f; + } + + public static TagFilter any(@NotNull List values) { + TagFilter f = new TagFilter(); + f.any = List.copyOf(values); + return f; + } + + public static TagFilter any(@NotNull String... values) { + return any(Arrays.asList(values)); + } + + public static TagFilter all(@NotNull List values) { + TagFilter f = new TagFilter(); + f.all = List.copyOf(values); + return f; + } + + public static TagFilter all(@NotNull String... values) { + return all(Arrays.asList(values)); + } + + public static TagFilter startsWith(@NotNull String prefix) { + TagFilter f = new TagFilter(); + f.startswith = prefix; + return f; + } + + @Nullable + public String getEq() { return eq; } + + @Nullable + public String getNe() { return ne; } + + @Nullable + public List getAny() { return any; } + + @Nullable + public List getAll() { return all; } + + @Nullable + public String getStartswith() { return startswith; } +} diff --git a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/SearchRequest.java b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/SearchRequest.java index 2e0c9ef9..5168d15c 100644 --- a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/SearchRequest.java +++ b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/SearchRequest.java @@ -2,6 +2,7 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; +import com.redis.agentmemory.models.common.TagFilter; import org.jetbrains.annotations.Nullable; import java.util.List; @@ -29,20 +30,20 @@ public class SearchRequest { @Nullable @JsonProperty("session_id") - private String sessionId; + private TagFilter sessionId; @Nullable - private String namespace; + private TagFilter namespace; @Nullable - private List topics; + private TagFilter topics; @Nullable - private List entities; + private TagFilter entities; @Nullable @JsonProperty("user_id") - private String userId; + private TagFilter userId; @Nullable @JsonProperty("distance_threshold") @@ -125,47 +126,69 @@ public void setTextScorer(@Nullable String textScorer) { } @Nullable - public String getSessionId() { + public TagFilter getSessionId() { return sessionId; } public void setSessionId(@Nullable String sessionId) { + this.sessionId = sessionId != null ? TagFilter.eq(sessionId) : null; + } + + public void setSessionId(@Nullable TagFilter sessionId) { this.sessionId = sessionId; } @Nullable - public String getNamespace() { + public TagFilter getNamespace() { return namespace; } public void setNamespace(@Nullable String namespace) { + this.namespace = namespace != null ? TagFilter.eq(namespace) : null; + } + + public void setNamespace(@Nullable TagFilter namespace) { this.namespace = namespace; } @Nullable - public List getTopics() { + public TagFilter getTopics() { return topics; } public void setTopics(@Nullable List topics) { + var present = topics != null && !topics.isEmpty(); + this.topics = present ? TagFilter.any(topics) : null; + } + + public void setTopics(@Nullable TagFilter topics) { this.topics = topics; } @Nullable - public List getEntities() { + public TagFilter getEntities() { return entities; } public void setEntities(@Nullable List entities) { + var present = entities != null && !entities.isEmpty(); + this.entities = present ? TagFilter.any(entities) : null; + } + + public void setEntities(@Nullable TagFilter entities) { this.entities = entities; } @Nullable - public String getUserId() { + public TagFilter getUserId() { return userId; } public void setUserId(@Nullable String userId) { + this.userId = userId != null ? TagFilter.eq(userId) : null; + } + + public void setUserId(@Nullable TagFilter userId) { this.userId = userId; } @@ -273,11 +296,11 @@ public String toString() { ", searchMode='" + searchMode + '\'' + ", hybridAlpha=" + hybridAlpha + ", textScorer='" + textScorer + '\'' + - ", sessionId='" + sessionId + '\'' + - ", namespace='" + namespace + '\'' + + ", sessionId=" + sessionId + + ", namespace=" + namespace + ", topics=" + topics + ", entities=" + entities + - ", userId='" + userId + '\'' + + ", userId=" + userId + ", distanceThreshold=" + distanceThreshold + ", limit=" + limit + ", offset=" + offset + @@ -327,26 +350,53 @@ public Builder textScorer(@Nullable String textScorer) { } public Builder sessionId(@Nullable String sessionId) { + request.sessionId = sessionId != null ? TagFilter.eq(sessionId) : null; + return this; + } + + public Builder sessionId(@Nullable TagFilter sessionId) { request.sessionId = sessionId; return this; } public Builder namespace(@Nullable String namespace) { + request.namespace = namespace != null ? TagFilter.eq(namespace) : null; + return this; + } + + public Builder namespace(@Nullable TagFilter namespace) { request.namespace = namespace; return this; } public Builder topics(@Nullable List topics) { + var present = topics != null && !topics.isEmpty(); + request.topics = present ? TagFilter.any(topics) : null; + return this; + } + + public Builder topics(@Nullable TagFilter topics) { request.topics = topics; return this; } public Builder entities(@Nullable List entities) { + var present = entities != null && !entities.isEmpty(); + request.entities = present ? TagFilter.any(entities) : null; + return this; + } + + public Builder entities(@Nullable TagFilter entities) { request.entities = entities; return this; } public Builder userId(@Nullable String userId) { + request.userId = userId != null ? TagFilter.eq(userId) : null; + return this; + } + + public Builder userId(@Nullable TagFilter userId) { request.userId = userId; return this; } diff --git a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/services/LongTermMemoryService.java b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/services/LongTermMemoryService.java index faeaf090..b9a66906 100644 --- a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/services/LongTermMemoryService.java +++ b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/services/LongTermMemoryService.java @@ -3,6 +3,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.redis.agentmemory.exceptions.MemoryClientException; import com.redis.agentmemory.models.common.AckResponse; +import com.redis.agentmemory.models.common.TagFilter; import com.redis.agentmemory.models.longtermemory.*; import okhttp3.*; import org.jetbrains.annotations.NotNull; @@ -86,22 +87,22 @@ public MemoryRecordResults searchLongTermMemories(@NotNull SearchRequest request // Add filters if present if (request.getSessionId() != null) { - payload.put("session_id", Map.of("eq", request.getSessionId())); + payload.put("session_id", request.getSessionId()); } if (request.getUserId() != null) { - payload.put("user_id", Map.of("eq", request.getUserId())); + payload.put("user_id", request.getUserId()); } if (request.getNamespace() != null) { - payload.put("namespace", Map.of("eq", request.getNamespace())); + payload.put("namespace", request.getNamespace()); } else if (defaultNamespace != null) { - payload.put("namespace", Map.of("eq", defaultNamespace)); + payload.put("namespace", TagFilter.eq(defaultNamespace)); } - if (request.getTopics() != null && !request.getTopics().isEmpty()) { - payload.put("topics", Map.of("any", request.getTopics())); + if (request.getTopics() != null) { + payload.put("topics", request.getTopics()); } - if (request.getEntities() != null && !request.getEntities().isEmpty()) { - payload.put("entities", Map.of("any", request.getEntities())); + if (request.getEntities() != null) { + payload.put("entities", request.getEntities()); } if (request.getDistanceThreshold() != null) { payload.put("distance_threshold", request.getDistanceThreshold()); diff --git a/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/models/common/TagFilterTest.java b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/models/common/TagFilterTest.java new file mode 100644 index 00000000..43f834e3 --- /dev/null +++ b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/models/common/TagFilterTest.java @@ -0,0 +1,55 @@ +package com.redis.agentmemory.models.common; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.*; + +class TagFilterTest { + + private ObjectMapper objectMapper; + + @BeforeEach + void setUp() { + objectMapper = new ObjectMapper(); + } + + @Test + void eq_serializesToEqOperator() throws Exception { + String json = objectMapper.writeValueAsString(TagFilter.eq("user-123")); + assertEquals("{\"eq\":\"user-123\"}", json); + } + + @Test + void ne_serializesToNeOperator() throws Exception { + String json = objectMapper.writeValueAsString(TagFilter.ne("user-123")); + assertEquals("{\"ne\":\"user-123\"}", json); + } + + @Test + void any_varargs_serializesToAnyOperator() throws Exception { + String json = objectMapper.writeValueAsString(TagFilter.any("user-123", "__account__")); + assertEquals("{\"any\":[\"user-123\",\"__account__\"]}", json); + } + + @Test + void any_list_serializesToAnyOperator() throws Exception { + String json = objectMapper.writeValueAsString(TagFilter.any(List.of("a", "b", "c"))); + assertEquals("{\"any\":[\"a\",\"b\",\"c\"]}", json); + } + + @Test + void all_serializesToAllOperator() throws Exception { + String json = objectMapper.writeValueAsString(TagFilter.all("x", "y")); + assertEquals("{\"all\":[\"x\",\"y\"]}", json); + } + + @Test + void startsWith_serializesToStartswithOperator() throws Exception { + String json = objectMapper.writeValueAsString(TagFilter.startsWith("tenant-")); + assertEquals("{\"startswith\":\"tenant-\"}", json); + } +} diff --git a/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/services/LongTermMemoryServiceTest.java b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/services/LongTermMemoryServiceTest.java index 6ffea49a..bc3fef01 100644 --- a/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/services/LongTermMemoryServiceTest.java +++ b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/services/LongTermMemoryServiceTest.java @@ -4,6 +4,7 @@ import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; import com.redis.agentmemory.MemoryAPIClient; import com.redis.agentmemory.models.common.AckResponse; +import com.redis.agentmemory.models.common.TagFilter; import com.redis.agentmemory.models.longtermemory.*; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; @@ -435,6 +436,162 @@ void testSearchLongTermMemories_WithDistanceThreshold() throws Exception { assertTrue(requestBody.contains("\"distance_threshold\":0.35")); } + @Test + void testSearchLongTermMemories_WithUserIdString() throws Exception { + // backward compat: plain String userId still serializes {"eq": "..."} + mockServer.enqueue(new MockResponse() + .setBody(objectMapper.writeValueAsString(emptyResults())) + .addHeader("Content-Type", "application/json")); + + SearchRequest request = SearchRequest.builder() + .text("query") + .userId("user-123") + .build(); + client.longTermMemory().searchLongTermMemories(request); + + RecordedRequest recorded = mockServer.takeRequest(); + String body = recorded.getBody().readUtf8(); + assertTrue(body.contains("\"user_id\":{\"eq\":\"user-123\"}")); + } + + + @Test + void testSearchLongTermMemories_WithUserIdAny() throws Exception { + mockServer.enqueue(new MockResponse() + .setBody(objectMapper.writeValueAsString(emptyResults())) + .addHeader("Content-Type", "application/json")); + + SearchRequest request = SearchRequest.builder() + .text("query") + .userId(TagFilter.any("user-123", "__account__")) + .build(); + client.longTermMemory().searchLongTermMemories(request); + + RecordedRequest recorded = mockServer.takeRequest(); + String body = recorded.getBody().readUtf8(); + assertTrue(body.contains("\"user_id\":{\"any\":[\"user-123\",\"__account__\"]}")); + assertFalse(body.contains("\"eq\"")); + } + + @Test + void testSearchLongTermMemory_WithNamespaceTagFilter() throws Exception { + mockServer.enqueue(new MockResponse() + .setBody(objectMapper.writeValueAsString(emptyResults())) + .addHeader("Content-Type", "application/json")); + + SearchRequest request = SearchRequest.builder() + .text("query") + .namespace(TagFilter.startsWith("tenant-")) + .build(); + client.longTermMemory().searchLongTermMemories(request); + + RecordedRequest recorded = mockServer.takeRequest(); + String body = recorded.getBody().readUtf8(); + assertTrue(body.contains("\"namespace\":{\"startswith\":\"tenant-\"}")); + } + + @Test + void testSearchLongTermMemory_WithSessionIdAny() throws Exception { + mockServer.enqueue(new MockResponse() + .setBody(objectMapper.writeValueAsString(emptyResults())) + .addHeader("Content-Type", "application/json")); + + SearchRequest request = SearchRequest.builder() + .text("query") + .sessionId(TagFilter.any("session-a", "session-b")) + .build(); + client.longTermMemory().searchLongTermMemories(request); + + RecordedRequest recorded = mockServer.takeRequest(); + String body = recorded.getBody().readUtf8(); + assertTrue(body.contains("\"session_id\":{\"any\":[\"session-a\",\"session-b\"]}")); + } + + @Test + void testSearchLongTermMemory_WithTopicsList() throws Exception { + // backward compat: List convenience API still serializes {"any": [...]} + mockServer.enqueue(new MockResponse() + .setBody(objectMapper.writeValueAsString(emptyResults())) + .addHeader("Content-Type", "application/json")); + + SearchRequest request = SearchRequest.builder() + .text("query") + .topics(List.of("finance", "real-estate")) + .build(); + client.longTermMemory().searchLongTermMemories(request); + + RecordedRequest recorded = mockServer.takeRequest(); + String body = recorded.getBody().readUtf8(); + assertTrue(body.contains("\"topics\":{\"any\":[\"finance\",\"real-estate\"]}")); + } + + @Test + void testSearchLongTermMemory_WithTopicsAllFilter() throws Exception { + mockServer.enqueue(new MockResponse() + .setBody(objectMapper.writeValueAsString(emptyResults())) + .addHeader("Content-Type", "application/json")); + + SearchRequest request = SearchRequest.builder() + .text("query") + .topics(TagFilter.all("finance", "real-estate")) + .build(); + client.longTermMemory().searchLongTermMemories(request); + + RecordedRequest recorded = mockServer.takeRequest(); + String body = recorded.getBody().readUtf8(); + assertTrue(body.contains("\"topics\":{\"all\":[\"finance\",\"real-estate\"]}")); + } + + @Test + void testSearchLongTermMemory_WithEntitiesTagFilter() throws Exception { + mockServer.enqueue(new MockResponse() + .setBody(objectMapper.writeValueAsString(emptyResults())) + .addHeader("Content-Type", "application/json")); + + SearchRequest request = SearchRequest.builder() + .text("query") + .entities(TagFilter.any("Google", "Apple")) + .build(); + client.longTermMemory().searchLongTermMemories(request); + + RecordedRequest recorded = mockServer.takeRequest(); + String body = recorded.getBody().readUtf8(); + assertTrue(body.contains("\"entities\":{\"any\":[\"Google\",\"Apple\"]}")); + } + + @Test + void testSearchLongTermMemory_WithEmptyTopicsList() throws Exception { + mockServer.enqueue(new MockResponse() + .setBody(objectMapper.writeValueAsString(emptyResults())) + .addHeader("Content-Type", "application/json")); + + SearchRequest request = SearchRequest.builder() + .text("query") + .topics(List.of()) + .build(); + client.longTermMemory().searchLongTermMemories(request); + + RecordedRequest recorded = mockServer.takeRequest(); + String body = recorded.getBody().readUtf8(); + assertFalse(body.contains("\"topics\"")); + } + + @Test + void testSearchLongTermMemory_WithNoUserId() throws Exception { + mockServer.enqueue(new MockResponse() + .setBody(objectMapper.writeValueAsString(emptyResults())) + .addHeader("Content-Type", "application/json")); + + SearchRequest request = SearchRequest.builder() + .text("query") + .build(); + client.longTermMemory().searchLongTermMemories(request); + + RecordedRequest recorded = mockServer.takeRequest(); + String body = recorded.getBody().readUtf8(); + assertFalse(body.contains("\"user_id\"")); + } + @Test void testSearchRequestBuilder_AllRecencyFields() { // Test that all recency fields can be set via builder @@ -461,4 +618,10 @@ void testSearchRequestBuilder_AllRecencyFields() { assertFalse(request.getServerSideRecency()); } + private MemoryRecordResults emptyResults() { + MemoryRecordResults r = new MemoryRecordResults(); + r.setMemories(new ArrayList<>()); + r.setTotal(0); + return r; + } }