Skip to content

Commit 266de35

Browse files
Add support for CountQuery (#65)
Adds support for a simple `CountQuery` which expects a `FilterExpression` and allows users to check how many records match a particular set of filters. Also adds documentation and examples for multiple tag fields and count queries.
1 parent cf630f0 commit 266de35

File tree

5 files changed

+145
-63
lines changed

5 files changed

+145
-63
lines changed

docs/user_guide/hybrid_queries_02.ipynb

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,19 @@
194194
"result_print(index.query(v))"
195195
]
196196
},
197+
{
198+
"cell_type": "code",
199+
"execution_count": null,
200+
"metadata": {},
201+
"outputs": [],
202+
"source": [
203+
"# multiple tags\n",
204+
"t = Tag(\"credit_score\") == [\"high\", \"medium\"]\n",
205+
"\n",
206+
"v.set_filter(t)\n",
207+
"result_print(index.query(v))"
208+
]
209+
},
197210
{
198211
"cell_type": "markdown",
199212
"metadata": {},
@@ -586,6 +599,32 @@
586599
"result_print(results)"
587600
]
588601
},
602+
{
603+
"cell_type": "markdown",
604+
"metadata": {},
605+
"source": [
606+
"## Count Queries\n",
607+
"\n",
608+
"In some cases, you may only need to know how many records match a ``FilterExpression``. A ``CountQuery`` takes a ``FilterExpression`` and returns just the number of matching records. It is similar to the ``FilterQuery`` class but does not return the values of the underlying data."
609+
]
610+
},
611+
{
612+
"cell_type": "code",
613+
"execution_count": null,
614+
"metadata": {},
615+
"outputs": [],
616+
"source": [
617+
"from redisvl.query import CountQuery\n",
618+
"\n",
619+
"has_low_credit = Tag(\"credit_score\") == \"low\"\n",
620+
"\n",
621+
"filter_query = CountQuery(filter_expression=has_low_credit)\n",
622+
"\n",
623+
"count = index.query(filter_query)\n",
624+
"\n",
625+
"print(f\"{count} records match the filter expression {str(has_low_credit)} for the given index.\")"
626+
]
627+
},
589628
{
590629
"cell_type": "markdown",
591630
"metadata": {},
@@ -846,7 +885,7 @@
846885
"name": "python",
847886
"nbconvert_exporter": "python",
848887
"pygments_lexer": "ipython3",
849-
"version": "3.10.10"
888+
"version": "3.9.12"
850889
},
851890
"orig_nbformat": 4,
852891
"vscode": {

redisvl/index.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import asyncio
2-
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional
2+
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
33
from uuid import uuid4
44

55
if TYPE_CHECKING:
@@ -10,6 +10,7 @@
1010
import redis
1111
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
1212

13+
from redisvl.query.query import CountQuery
1314
from redisvl.schema import SchemaModel, read_schema
1415
from redisvl.utils.connection import (
1516
check_connected,
@@ -52,17 +53,17 @@ def client(self) -> redis.Redis:
5253
return self._redis_conn # type: ignore
5354

5455
@check_connected("_redis_conn")
55-
def search(self, *args, **kwargs) -> List["Result"]:
56+
def search(self, *args, **kwargs) -> Union["Result", Any]:
5657
"""Perform a search on this index.
5758
5859
Wrapper around redis.search.Search that adds the index name
5960
to the search query and passes along the rest of the arguments
6061
to the redis-py ft.search() method.
6162
6263
Returns:
63-
List[Result]: A list of search results
64+
Union["Result", Any]: Search results.
6465
"""
65-
results: List["Result"] = self._redis_conn.ft(self._name).search( # type: ignore
66+
results = self._redis_conn.ft(self._name).search( # type: ignore
6667
*args, **kwargs
6768
)
6869
return results
@@ -82,6 +83,8 @@ def query(self, query: "BaseQuery") -> List[Dict[str, Any]]:
8283
List[Result]: A list of search results.
8384
"""
8485
results = self.search(query.query, query_params=query.params)
86+
if isinstance(query, CountQuery):
87+
return results.total
8588
return process_results(results)
8689

8790
@classmethod
@@ -522,17 +525,19 @@ async def _load(record: dict):
522525
await asyncio.gather(*[_load(record) for record in data])
523526

524527
@check_connected("_redis_conn")
525-
async def search(self, *args, **kwargs) -> List["Result"]:
528+
async def search(self, *args, **kwargs) -> Union["Result", Any]:
526529
"""Perform a search on this index.
527530
528531
Wrapper around redis.search.Search that adds the index name
529532
to the search query and passes along the rest of the arguments
530533
to the redis-py ft.search() method.
531534
532535
Returns:
533-
List[Result]: A list of search results.
536+
Union["Result", Any]: Search results.
534537
"""
535-
results: List["Result"] = await self._redis_conn.ft(self._name).search(*args, **kwargs) # type: ignore
538+
results = await self._redis_conn.ft(self._name).search( # type: ignore
539+
*args, **kwargs
540+
)
536541
return results
537542

538543
async def query(self, query: "BaseQuery") -> List[Dict[str, Any]]:
@@ -549,6 +554,8 @@ async def query(self, query: "BaseQuery") -> List[Dict[str, Any]]:
549554
List[Result]: A list of search results.
550555
"""
551556
results = await self.search(query.query, query_params=query.params)
557+
if isinstance(query, CountQuery):
558+
return results.total
552559
return process_results(results)
553560

554561
@check_connected("_redis_conn")

redisvl/query/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from redisvl.query.query import FilterQuery, RangeQuery, VectorQuery
1+
from redisvl.query.query import CountQuery, FilterQuery, RangeQuery, VectorQuery
22

3-
__all__ = ["VectorQuery", "FilterQuery", "RangeQuery"]
3+
__all__ = ["VectorQuery", "FilterQuery", "RangeQuery", "CountQuery"]

redisvl/query/query.py

Lines changed: 77 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
1+
from typing import TYPE_CHECKING, Any, Dict, List, Optional
22

33
import numpy as np
44
from redis.commands.search.query import Query
@@ -12,6 +12,32 @@ def __init__(self, return_fields: List[str] = [], num_results: int = 10):
1212
self._return_fields = return_fields
1313
self._num_results = num_results
1414

15+
def __str__(self) -> str:
16+
return " ".join([str(x) for x in self.query.get_args()])
17+
18+
def set_filter(self, filter_expression: FilterExpression):
19+
"""Set the filter for the query.
20+
21+
Args:
22+
filter_expression (FilterExpression): The filter to apply to the query.
23+
24+
Raises:
25+
TypeError: If filter_expression is not of type redisvl.query.FilterExpression
26+
"""
27+
if not isinstance(filter_expression, FilterExpression):
28+
raise TypeError(
29+
"filter_expression must be of type redisvl.query.FilterExpression"
30+
)
31+
self._filter = filter_expression
32+
33+
def get_filter(self) -> FilterExpression:
34+
"""Get the filter for the query.
35+
36+
Returns:
37+
FilterExpression: The filter for the query.
38+
"""
39+
return self._filter
40+
1541
@property
1642
def query(self) -> "Query":
1743
raise NotImplementedError
@@ -21,6 +47,54 @@ def params(self) -> Dict[str, Any]:
2147
raise NotImplementedError
2248

2349

50+
class CountQuery(BaseQuery):
51+
def __init__(
52+
self,
53+
filter_expression: FilterExpression,
54+
params: Optional[Dict[str, Any]] = None,
55+
):
56+
"""Query for a simple count operation on a filter expression.
57+
58+
Args:
59+
filter_expression (FilterExpression): The filter expression to query for.
60+
params (Optional[Dict[str, Any]], optional): The parameters for the query. Defaults to None.
61+
62+
Raises:
63+
TypeError: If filter_expression is not of type redisvl.query.FilterExpression
64+
65+
Examples:
66+
>>> from redisvl.query import CountQuery
67+
>>> from redisvl.query.filter import Tag
68+
>>> t = Tag("brand") == "Nike"
69+
>>> q = CountQuery(filter_expression=t)
70+
>>> count = index.query(q)
71+
"""
72+
self.set_filter(filter_expression)
73+
self._params = params
74+
75+
@property
76+
def query(self) -> Query:
77+
"""Return a Redis-Py Query object representing the query.
78+
79+
Returns:
80+
redis.commands.search.query.Query: The query object.
81+
"""
82+
base_query = str(self._filter)
83+
query = Query(base_query).no_content().dialect(2)
84+
return query
85+
86+
@property
87+
def params(self) -> Dict[str, Any]:
88+
"""Return the parameters for the query.
89+
90+
Returns:
91+
Dict[str, Any]: The parameters for the query.
92+
"""
93+
if not self._params:
94+
self._params = {}
95+
return self._params
96+
97+
2498
class FilterQuery(BaseQuery):
2599
def __init__(
26100
self,
@@ -51,32 +125,6 @@ def __init__(
51125
self.set_filter(filter_expression)
52126
self._params = params
53127

54-
def __str__(self) -> str:
55-
return " ".join([str(x) for x in self.query.get_args()])
56-
57-
def set_filter(self, filter_expression: FilterExpression):
58-
"""Set the filter for the query.
59-
60-
Args:
61-
filter_expression (FilterExpression): The filter to apply to the query.
62-
63-
Raises:
64-
TypeError: If filter_expression is not of type redisvl.query.FilterExpression
65-
"""
66-
if not isinstance(filter_expression, FilterExpression):
67-
raise TypeError(
68-
"filter_expression must be of type redisvl.query.FilterExpression"
69-
)
70-
self._filter = filter_expression
71-
72-
def get_filter(self) -> FilterExpression:
73-
"""Get the filter for the query.
74-
75-
Returns:
76-
FilterExpression: The filter for the query.
77-
"""
78-
return self._filter
79-
80128
@property
81129
def query(self) -> Query:
82130
"""Return a Redis-Py Query object representing the query.
@@ -127,36 +175,14 @@ def __init__(
127175
self._vector = vector
128176
self._field = vector_field_name
129177
self._dtype = dtype.lower()
130-
self._filter = filter_expression
178+
self._filter = filter_expression # type: ignore
179+
131180
if filter_expression:
132181
self.set_filter(filter_expression)
133182

134183
if return_score:
135184
self._return_fields.append(self.DISTANCE_ID)
136185

137-
def set_filter(self, filter_expression: FilterExpression):
138-
"""Set the filter for the query.
139-
140-
Args:
141-
filter_expression (FilterExpression): The filter to apply to the query.
142-
"""
143-
if not isinstance(filter_expression, FilterExpression):
144-
raise TypeError(
145-
"filter_expression must be of type redisvl.query.FilterExpression"
146-
)
147-
self._filter = filter_expression
148-
149-
def get_filter(self) -> Optional[FilterExpression]:
150-
"""Get the filter for the query.
151-
152-
Returns:
153-
Optional[FilterExpression]: The filter for the query.
154-
"""
155-
return self._filter
156-
157-
def __str__(self):
158-
return " ".join([str(x) for x in self.query.get_args()])
159-
160186

161187
class VectorQuery(BaseVectorQuery):
162188
def __init__(

tests/integration/test_query.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from redis.commands.search.result import Result
66

77
from redisvl.index import SearchIndex
8-
from redisvl.query import FilterQuery, RangeQuery, VectorQuery
9-
from redisvl.query.filter import Geo, GeoRadius, Num, Tag, Text
8+
from redisvl.query import CountQuery, FilterQuery, RangeQuery, VectorQuery
9+
from redisvl.query.filter import FilterExpression, Geo, GeoRadius, Num, Tag, Text
1010

1111
data = [
1212
{
@@ -171,6 +171,16 @@ def test_range_query(index):
171171
assert len(results) == 2
172172

173173

174+
def test_count_query(index):
175+
c = CountQuery(FilterExpression("*"))
176+
results = index.query(c)
177+
assert results == len(data)
178+
179+
c = CountQuery(Tag("credit_score") == "high")
180+
results = index.query(c)
181+
assert results == 4
182+
183+
174184
vector_query = VectorQuery(
175185
vector=[0.1, 0.1, 0.5],
176186
vector_field_name="user_embedding",

0 commit comments

Comments
 (0)