Skip to content

Commit 9b5a13a

Browse files
author
Sam Partee
authored
Add FilterQuery (#51)
Add a ``FilterQuery`` class for running a query with only a filter and no vector search, so that users who need a filter-only query do not have to fall back to using redis-py directly.
1 parent 0c619f4 commit 9b5a13a

File tree

12 files changed

+184
-79
lines changed

12 files changed

+184
-79
lines changed

redisvl/cli/index.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import argparse
22
import sys
3-
from tabulate import tabulate
43
from argparse import Namespace
54

5+
from tabulate import tabulate
66

77
from redisvl.cli.log import get_logger
8-
from redisvl.cli.utils import create_redis_url, add_index_parsing_options
8+
from redisvl.cli.utils import add_index_parsing_options, create_redis_url
99
from redisvl.index import SearchIndex
1010
from redisvl.utils.connection import get_redis_connection
1111
from redisvl.utils.utils import convert_bytes, make_dict
@@ -36,7 +36,7 @@ def __init__(self):
3636
"--format",
3737
help="Output format for info command",
3838
type=str,
39-
default="rounded_outline"
39+
default="rounded_outline",
4040
)
4141
parser = add_index_parsing_options(parser)
4242

@@ -126,6 +126,7 @@ def _connect_to_index(self, args: Namespace) -> SearchIndex:
126126

127127
return index
128128

129+
129130
def _display_in_table(index_info, output_format="rounded_outline"):
130131
print("\n")
131132
attributes = index_info.get("attributes", [])
@@ -183,4 +184,4 @@ def _display_in_table(index_info, output_format="rounded_outline"):
183184
headers=headers,
184185
tablefmt=output_format,
185186
)
186-
)
187+
)

redisvl/cli/main.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,8 @@
33

44
from redisvl.cli.index import Index
55
from redisvl.cli.log import get_logger
6-
from redisvl.cli.version import Version
76
from redisvl.cli.stats import Stats
8-
7+
from redisvl.cli.version import Version
98

109
logger = get_logger(__name__)
1110

@@ -50,4 +49,3 @@ def version(self):
5049
def stats(self):
5150
Stats()
5251
exit(0)
53-

redisvl/cli/stats.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import argparse
22
import sys
3-
from tabulate import tabulate
43
from argparse import Namespace
54

6-
from redisvl.cli.utils import create_redis_url, add_index_parsing_options
5+
from tabulate import tabulate
6+
7+
from redisvl.cli.log import get_logger
8+
from redisvl.cli.utils import add_index_parsing_options, create_redis_url
79
from redisvl.index import SearchIndex
810
from redisvl.utils.connection import get_redis_connection
911

10-
from redisvl.cli.log import get_logger
1112
logger = get_logger("[RedisVL]")
1213

1314
STATS_KEYS = [
@@ -32,6 +33,7 @@
3233
"vector_index_sz_mb",
3334
]
3435

36+
3537
class Stats:
3638
usage = "\n".join(
3739
[
@@ -43,11 +45,7 @@ def __init__(self):
4345
parser = argparse.ArgumentParser(usage=self.usage)
4446

4547
parser.add_argument(
46-
"-f",
47-
"--format",
48-
help="Output format",
49-
type=str,
50-
default="rounded_outline"
48+
"-f", "--format", help="Output format", type=str, default="rounded_outline"
5149
)
5250
parser = add_index_parsing_options(parser)
5351
args = parser.parse_args(sys.argv[2:])
@@ -57,7 +55,6 @@ def __init__(self):
5755
logger.error(e)
5856
exit(0)
5957

60-
6158
def stats(self, args: Namespace):
6259
"""Obtain stats about an index
6360

redisvl/cli/utils.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from argparse import Namespace, ArgumentParser
2+
from argparse import ArgumentParser, Namespace
33

44

55
def create_redis_url(args: Namespace) -> str:
@@ -18,20 +18,15 @@ def create_redis_url(args: Namespace) -> str:
1818
url += args.host + ":" + str(args.port)
1919
return url
2020

21+
2122
def add_index_parsing_options(parser: ArgumentParser) -> ArgumentParser:
22-
parser.add_argument(
23-
"-i", "--index", help="Index name", type=str, required=False
24-
)
23+
parser.add_argument("-i", "--index", help="Index name", type=str, required=False)
2524
parser.add_argument(
2625
"-s", "--schema", help="Path to schema file", type=str, required=False
2726
)
2827
parser.add_argument("--host", help="Redis host", type=str, default="localhost")
2928
parser.add_argument("-p", "--port", help="Redis port", type=int, default=6379)
30-
parser.add_argument(
31-
"--user", help="Redis username", type=str, default="default"
32-
)
29+
parser.add_argument("--user", help="Redis username", type=str, default="default")
3330
parser.add_argument("--ssl", help="Use SSL", action="store_true")
34-
parser.add_argument(
35-
"-a", "--password", help="Redis password", type=str, default=""
36-
)
31+
parser.add_argument("-a", "--password", help="Redis password", type=str, default="")
3732
return parser

redisvl/query/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1-
from redisvl.query.query import VectorQuery
1+
from redisvl.query.query import FilterQuery, VectorQuery
22

3-
__all__ = ["VectorQuery"]
3+
__all__ = [
4+
"VectorQuery",
5+
"FilterQuery",
6+
]

redisvl/query/filter.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ class Geo(FilterField):
117117
field in a Redis index.
118118
119119
"""
120+
120121
OPERATORS = {
121122
FilterOperator.EQ: "==",
122123
FilterOperator.NE: "!=",
@@ -174,12 +175,13 @@ def __init__(self, longitude: float, latitude: float, unit: str = "km"):
174175

175176
class GeoRadius(GeoSpec):
176177
"""A GeoRadius is a GeoSpec representing a geographic radius"""
178+
177179
def __init__(
178180
self,
179181
longitude: float,
180182
latitude: float,
181183
radius: Optional[int] = 1,
182-
unit: Optional[str] = "km"
184+
unit: Optional[str] = "km",
183185
):
184186
"""Create a GeoRadius specification (GeoSpec)
185187
@@ -202,6 +204,7 @@ def get_args(self) -> List[Union[float, int, str]]:
202204

203205
class Num(FilterField):
204206
"""A Num is a FilterField representing a numeric field in a Redis index."""
207+
205208
OPERATORS = {
206209
FilterOperator.EQ: "==",
207210
FilterOperator.NE: "!=",
@@ -311,6 +314,7 @@ def __le__(self, other: str) -> "FilterExpression":
311314

312315
class Text(FilterField):
313316
"""A Text is a FilterField representing a text field in a Redis index."""
317+
314318
OPERATORS = {
315319
FilterOperator.EQ: "==",
316320
FilterOperator.NE: "!=",
@@ -399,6 +403,7 @@ class FilterExpression:
399403
... filter_expression=filter,
400404
... )
401405
"""
406+
402407
def __init__(
403408
self,
404409
_filter: str = None,

redisvl/query/query.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,88 @@ def params(self) -> Dict[str, Any]:
2323
pass
2424

2525

26+
class FilterQuery(BaseQuery):
    """A query that applies only a filter expression, with no vector search.

    The filter is validated and stored in its string form; the ``query``
    property turns that string into a redis-py ``Query`` object.
    """

    def __init__(
        self,
        return_fields: List[str],
        filter_expression: FilterExpression,
        num_results: Optional[int] = 10,
        params: Optional[Dict[str, Any]] = None,
    ):
        """Query for a filter expression.

        Args:
            return_fields (List[str]): The fields to return.
            filter_expression (FilterExpression): The filter expression to query for.
            num_results (Optional[int], optional): The number of results to return. Defaults to 10.
            params (Optional[Dict[str, Any]], optional): The parameters for the query. Defaults to None.

        Raises:
            TypeError: If filter_expression is not of type redisvl.query.FilterExpression

        Examples:
            >>> from redisvl.query import FilterQuery
            >>> from redisvl.query.filter import Tag
            >>> t = Tag("brand") == "Nike"
            >>> q = FilterQuery(return_fields=["brand", "price"], filter_expression=t)
        """
        super().__init__(return_fields, num_results)
        # Goes through set_filter so the type check also applies at construction.
        self.set_filter(filter_expression)
        self._params = params

    def __str__(self) -> str:
        """Return the arguments of the underlying Query joined by spaces."""
        return " ".join([str(x) for x in self.query.get_args()])

    def set_filter(self, filter_expression: FilterExpression):
        """Set the filter for the query.

        Args:
            filter_expression (FilterExpression): The filter to apply to the query.

        Raises:
            TypeError: If filter_expression is not of type redisvl.query.FilterExpression
        """
        if not isinstance(filter_expression, FilterExpression):
            raise TypeError(
                "filter_expression must be of type redisvl.query.FilterExpression"
            )
        # Only the string form is kept; the FilterExpression object is not retained.
        self._filter = str(filter_expression)

    def get_filter(self) -> str:
        """Get the filter for the query.

        Returns:
            str: The string form of the filter expression, as stored by
                ``set_filter`` (not the original FilterExpression object).
        """
        return self._filter

    @property
    def query(self) -> Query:
        """Return a Redis-Py Query object representing the query.

        Returns:
            redis.commands.search.query.Query: The query object.
        """
        base_query = str(self._filter)
        # _return_fields / _num_results are presumably set by BaseQuery.__init__
        # (not visible in this hunk) -- confirm against the base class.
        query = (
            Query(base_query)
            .return_fields(*self._return_fields)
            .paging(0, self._num_results)
            .dialect(2)
        )
        return query

    @property
    def params(self) -> Dict[str, Any]:
        """Return the parameters for the query.

        Returns:
            Dict[str, Any]: The parameters passed to the constructor. This is
                ``None`` when no ``params`` argument was supplied.
        """
        return self._params
106+
107+
26108
class VectorQuery(BaseQuery):
27109
dtypes = {
28110
"float32": np.float32,

redisvl/schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,4 +156,4 @@ def read_schema(file_path: str):
156156
with open(fp, "r") as f:
157157
schema = yaml.safe_load(f)
158158

159-
return SchemaModel(**schema)
159+
return SchemaModel(**schema)

redisvl/vectorize/text/huggingface.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,9 @@ def embed_many(
8080
TypeError: If the wrong input type is passed in for the test.
8181
"""
8282
if not isinstance(texts, list):
83-
raise TypeError("Must pass in a list of str values to embed.")
84-
if len(texts) > 0 and not isinstance(texts[0], str):
85-
raise TypeError("Must pass in a list of str values to embed.")
83+
raise TypeError("Must pass in a list of str values to embed.")
84+
if len(texts) > 0 and not isinstance(texts[0], str):
85+
raise TypeError("Must pass in a list of str values to embed.")
8686

8787
embeddings: List = []
8888
for batch in self.batchify(texts, batch_size, preprocess):

redisvl/vectorize/text/openai.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ class OpenAITextVectorizer(BaseVectorizer):
1313
API key to be passed in the api_config dictionary. The API key can be obtained from
1414
https://api.openai.com/.
1515
"""
16+
1617
def __init__(self, model: str, api_config: Optional[Dict] = None):
1718
"""Initialize the OpenAI vectorizer.
1819
@@ -45,14 +46,13 @@ def __init__(self, model: str, api_config: Optional[Dict] = None):
4546
def _set_model_dims(self) -> int:
4647
try:
4748
embedding = self._model_client.create(
48-
input=["dimension test"],
49-
engine=self._model
49+
input=["dimension test"], engine=self._model
5050
)["data"][0]["embedding"]
5151
except (KeyError, IndexError) as ke:
5252
raise ValueError(f"Unexpected response from the OpenAI API: {str(ke)}")
5353
except openai.error.AuthenticationError as ae:
5454
raise ValueError(f"Error authenticating with the OpenAI API: {str(ae)}")
55-
except Exception as e: # pylint: disable=broad-except
55+
except Exception as e: # pylint: disable=broad-except
5656
# fall back (TODO get more specific)
5757
raise ValueError(f"Error setting embedding model dimensions: {str(e)}")
5858
return len(embedding)
@@ -87,9 +87,9 @@ def embed_many(
8787
TypeError: If the wrong input type is passed in for the test.
8888
"""
8989
if not isinstance(texts, list):
90-
raise TypeError("Must pass in a list of str values to embed.")
91-
if len(texts) > 0 and not isinstance(texts[0], str):
92-
raise TypeError("Must pass in a list of str values to embed.")
90+
raise TypeError("Must pass in a list of str values to embed.")
91+
if len(texts) > 0 and not isinstance(texts[0], str):
92+
raise TypeError("Must pass in a list of str values to embed.")
9393

9494
embeddings: List = []
9595
for batch in self.batchify(texts, batch_size, preprocess):
@@ -164,9 +164,9 @@ async def aembed_many(
164164
TypeError: If the wrong input type is passed in for the test.
165165
"""
166166
if not isinstance(texts, list):
167-
raise TypeError("Must pass in a list of str values to embed.")
168-
if len(texts) > 0 and not isinstance(texts[0], str):
169-
raise TypeError("Must pass in a list of str values to embed.")
167+
raise TypeError("Must pass in a list of str values to embed.")
168+
if len(texts) > 0 and not isinstance(texts[0], str):
169+
raise TypeError("Must pass in a list of str values to embed.")
170170

171171
embeddings: List = []
172172
for batch in self.batchify(texts, batch_size, preprocess):

0 commit comments

Comments
 (0)