1- from typing import TYPE_CHECKING , Any , Dict , List , Optional , Union
1+ from typing import TYPE_CHECKING , Any , Dict , List , Optional
22
33import numpy as np
44from redis .commands .search .query import Query
@@ -12,6 +12,32 @@ def __init__(self, return_fields: List[str] = [], num_results: int = 10):
1212 self ._return_fields = return_fields
1313 self ._num_results = num_results
1414
15+ def __str__ (self ) -> str :
16+ return " " .join ([str (x ) for x in self .query .get_args ()])
17+
18+ def set_filter (self , filter_expression : FilterExpression ):
19+ """Set the filter for the query.
20+
21+ Args:
22+ filter_expression (FilterExpression): The filter to apply to the query.
23+
24+ Raises:
25+ TypeError: If filter_expression is not of type redisvl.query.FilterExpression
26+ """
27+ if not isinstance (filter_expression , FilterExpression ):
28+ raise TypeError (
29+ "filter_expression must be of type redisvl.query.FilterExpression"
30+ )
31+ self ._filter = filter_expression
32+
33+ def get_filter (self ) -> FilterExpression :
34+ """Get the filter for the query.
35+
36+ Returns:
37+ FilterExpression: The filter for the query.
38+ """
39+ return self ._filter
40+
1541 @property
1642 def query (self ) -> "Query" :
1743 raise NotImplementedError
@@ -21,6 +47,54 @@ def params(self) -> Dict[str, Any]:
2147 raise NotImplementedError
2248
2349
50+ class CountQuery (BaseQuery ):
51+ def __init__ (
52+ self ,
53+ filter_expression : FilterExpression ,
54+ params : Optional [Dict [str , Any ]] = None ,
55+ ):
56+ """Query for a simple count operation on a filter expression.
57+
58+ Args:
59+ filter_expression (FilterExpression): The filter expression to query for.
60+ params (Optional[Dict[str, Any]], optional): The parameters for the query. Defaults to None.
61+
62+ Raises:
63+ TypeError: If filter_expression is not of type redisvl.query.FilterExpression
64+
65+ Examples:
66+ >>> from redisvl.query import CountQuery
67+ >>> from redisvl.query.filter import Tag
68+ >>> t = Tag("brand") == "Nike"
69+ >>> q = CountQuery(filter_expression=t)
70+ >>> count = index.query(q)
71+ """
72+ self .set_filter (filter_expression )
73+ self ._params = params
74+
75+ @property
76+ def query (self ) -> Query :
77+ """Return a Redis-Py Query object representing the query.
78+
79+ Returns:
80+ redis.commands.search.query.Query: The query object.
81+ """
82+ base_query = str (self ._filter )
83+ query = Query (base_query ).no_content ().dialect (2 )
84+ return query
85+
86+ @property
87+ def params (self ) -> Dict [str , Any ]:
88+ """Return the parameters for the query.
89+
90+ Returns:
91+ Dict[str, Any]: The parameters for the query.
92+ """
93+ if not self ._params :
94+ self ._params = {}
95+ return self ._params
96+
97+
2498class FilterQuery (BaseQuery ):
2599 def __init__ (
26100 self ,
@@ -51,32 +125,6 @@ def __init__(
51125 self .set_filter (filter_expression )
52126 self ._params = params
53127
54- def __str__ (self ) -> str :
55- return " " .join ([str (x ) for x in self .query .get_args ()])
56-
57- def set_filter (self , filter_expression : FilterExpression ):
58- """Set the filter for the query.
59-
60- Args:
61- filter_expression (FilterExpression): The filter to apply to the query.
62-
63- Raises:
64- TypeError: If filter_expression is not of type redisvl.query.FilterExpression
65- """
66- if not isinstance (filter_expression , FilterExpression ):
67- raise TypeError (
68- "filter_expression must be of type redisvl.query.FilterExpression"
69- )
70- self ._filter = filter_expression
71-
72- def get_filter (self ) -> FilterExpression :
73- """Get the filter for the query.
74-
75- Returns:
76- FilterExpression: The filter for the query.
77- """
78- return self ._filter
79-
80128 @property
81129 def query (self ) -> Query :
82130 """Return a Redis-Py Query object representing the query.
@@ -127,36 +175,14 @@ def __init__(
127175 self ._vector = vector
128176 self ._field = vector_field_name
129177 self ._dtype = dtype .lower ()
130- self ._filter = filter_expression
178+ self ._filter = filter_expression # type: ignore
179+
131180 if filter_expression :
132181 self .set_filter (filter_expression )
133182
134183 if return_score :
135184 self ._return_fields .append (self .DISTANCE_ID )
136185
137- def set_filter (self , filter_expression : FilterExpression ):
138- """Set the filter for the query.
139-
140- Args:
141- filter_expression (FilterExpression): The filter to apply to the query.
142- """
143- if not isinstance (filter_expression , FilterExpression ):
144- raise TypeError (
145- "filter_expression must be of type redisvl.query.FilterExpression"
146- )
147- self ._filter = filter_expression
148-
149- def get_filter (self ) -> Optional [FilterExpression ]:
150- """Get the filter for the query.
151-
152- Returns:
153- Optional[FilterExpression]: The filter for the query.
154- """
155- return self ._filter
156-
157- def __str__ (self ):
158- return " " .join ([str (x ) for x in self .query .get_args ()])
159-
160186
161187class VectorQuery (BaseVectorQuery ):
162188 def __init__ (
0 commit comments