Skip to content

Commit 4e3de2d

Browse files
tylerhutchersonSam Partee
andauthored
Add tests for token escaper class (#69)
Adds a set of unit tests on both the underlying token escaping class as well as the `Tag` filterable fields that utilize it. --------- Co-authored-by: Sam Partee <sam.partee@redis.com>
1 parent 6345cc1 commit 4e3de2d

File tree

5 files changed

+210
-35
lines changed

5 files changed

+210
-35
lines changed

redisvl/query/filter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from functools import wraps
33
from typing import Any, Callable, Dict, List, Optional, Union
44

5-
from redisvl.utils.utils import TokenEscaper
5+
from redisvl.utils.token_escaper import TokenEscaper
66

77
# disable mypy error for dunder method overrides
88
# mypy: disable-error-code="override"

redisvl/utils/token_escaper.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import re
2+
from typing import Optional, Pattern
3+
4+
5+
class TokenEscaper:
6+
"""
7+
Escape punctuation within an input string. Adapted from RedisOM Python.
8+
"""
9+
10+
# Characters that RediSearch requires us to escape during queries.
11+
# Source: https://redis.io/docs/stack/search/reference/escaping/#the-rules-of-text-field-tokenization
12+
DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/ ]"
13+
14+
def __init__(self, escape_chars_re: Optional[Pattern] = None):
15+
if escape_chars_re:
16+
self.escaped_chars_re = escape_chars_re
17+
else:
18+
self.escaped_chars_re = re.compile(self.DEFAULT_ESCAPED_CHARS)
19+
20+
def escape(self, value: str) -> str:
21+
if not isinstance(value, str):
22+
raise TypeError(
23+
f"Value must be a string object for token escaping, got type {type(value)}"
24+
)
25+
26+
def escape_symbol(match):
27+
value = match.group(0)
28+
return f"\\{value}"
29+
30+
return self.escaped_chars_re.sub(escape_symbol, value)

redisvl/utils/utils.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
import re
2-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern
1+
from typing import TYPE_CHECKING, Any, Dict, List
32

43
if TYPE_CHECKING:
54
from redis.commands.search.result import Result
@@ -8,29 +7,6 @@
87
import numpy as np
98

109

11-
class TokenEscaper:
12-
"""
13-
Escape punctuation within an input string. Taken from RedisOM Python.
14-
"""
15-
16-
# Characters that RediSearch requires us to escape during queries.
17-
# Source: https://redis.io/docs/stack/search/reference/escaping/#the-rules-of-text-field-tokenization
18-
DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/ ]"
19-
20-
def __init__(self, escape_chars_re: Optional[Pattern] = None):
21-
if escape_chars_re:
22-
self.escaped_chars_re = escape_chars_re
23-
else:
24-
self.escaped_chars_re = re.compile(self.DEFAULT_ESCAPED_CHARS)
25-
26-
def escape(self, value: str) -> str:
27-
def escape_symbol(match):
28-
value = match.group(0)
29-
return f"\\{value}"
30-
31-
return self.escaped_chars_re.sub(escape_symbol, value)
32-
33-
3410
def make_dict(values: List[Any]):
3511
# TODO make this a real function
3612
i = 0

tests/test_filter.py

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,54 @@
33
from redisvl.query.filter import Geo, GeoRadius, Num, Tag, Text
44

55

6-
def test_tag_filter():
7-
tf = Tag("tag_field") == ["tag1", "tag2"]
8-
assert str(tf) == "@tag_field:{tag1|tag2}"
9-
10-
tf = Tag("tag_field") == "tag1"
11-
assert str(tf) == "@tag_field:{tag1}"
12-
13-
tf = Tag("tag_field") != ["tag1", "tag2"]
14-
assert str(tf) == "(-@tag_field:{tag1|tag2})"
6+
# Test cases for various scenarios of tag usage, combinations, and their string representations.
7+
@pytest.mark.parametrize(
8+
"operation,tags,expected",
9+
[
10+
# Testing single tags
11+
("==", "simpletag", "@tag_field:{simpletag}"),
12+
(
13+
"==",
14+
"tag with space",
15+
"@tag_field:{tag\\ with\\ space}",
16+
), # Escaping spaces within quotes
17+
(
18+
"==",
19+
"special$char",
20+
"@tag_field:{special\\$char}",
21+
), # Escaping a special character
22+
("!=", "negated", "(-@tag_field:{negated})"),
23+
# Testing multiple tags
24+
("==", ["tag1", "tag2"], "@tag_field:{tag1|tag2}"),
25+
(
26+
"==",
27+
["alpha", "beta with space", "gamma$special"],
28+
"@tag_field:{alpha|beta\\ with\\ space|gamma\\$special}",
29+
), # Multiple tags with spaces and special chars
30+
("!=", ["tagA", "tagB"], "(-@tag_field:{tagA|tagB})"),
31+
# Complex tag scenarios with special characters
32+
("==", "weird:tag", "@tag_field:{weird\\:tag}"), # Tags with colon
33+
("==", "tag&another", "@tag_field:{tag\\&another}"), # Tags with ampersand
34+
# Escaping various special characters within tags
35+
("==", "tag/with/slashes", "@tag_field:{tag\\/with\\/slashes}"),
36+
(
37+
"==",
38+
["hypen-tag", "under_score", "dot.tag"],
39+
"@tag_field:{hypen\\-tag|under_score|dot\\.tag}",
40+
),
41+
# ...additional unique cases as desired...
42+
],
43+
)
44+
def test_tag_filter_varied(operation, tags, expected):
45+
if operation == "==":
46+
tf = Tag("tag_field") == tags
47+
elif operation == "!=":
48+
tf = Tag("tag_field") != tags
49+
else:
50+
raise ValueError(f"Unsupported operation: {operation}")
51+
52+
# Verify the string representation matches the expected RediSearch query part
53+
assert str(tf) == expected
1554

1655

1756
def test_numeric_filter():

tests/test_token_escaper.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
import pytest
2+
3+
from redisvl.utils.token_escaper import TokenEscaper
4+
5+
6+
@pytest.fixture
7+
def escaper():
8+
return TokenEscaper()
9+
10+
11+
@pytest.mark.parametrize(
12+
("test_input,expected"),
13+
[
14+
(r"a [big] test.", r"a\ \[big\]\ test\."),
15+
(r"hello, world!", r"hello\,\ world\!"),
16+
(
17+
r'special "quotes" (and parentheses)',
18+
r"special\ \"quotes\"\ \(and\ parentheses\)",
19+
),
20+
(
21+
r"& symbols, like * and ?",
22+
r"\&\ symbols\,\ like\ \*\ and\ ?",
23+
), # TODO: question marks are not caught?
24+
# underscores are ignored
25+
(r"-dashes_and_underscores-", r"\-dashes_and_underscores\-"),
26+
],
27+
ids=[
28+
"brackets",
29+
"commas",
30+
"quotes",
31+
"symbols",
32+
"underscores"
33+
]
34+
)
35+
def test_escape_text_chars(escaper, test_input, expected):
36+
assert escaper.escape(test_input) == expected
37+
38+
39+
@pytest.mark.parametrize(
40+
("test_input,expected"),
41+
[
42+
# Simple tags
43+
("user:name", r"user\:name"),
44+
("123#comment", r"123\#comment"),
45+
("hyphen-separated", r"hyphen\-separated"),
46+
# Tags with special characters
47+
("price$", r"price\$"),
48+
("super*star", r"super\*star"),
49+
("tag&value", r"tag\&value"),
50+
("@username", r"\@username"),
51+
# Space-containing tags often used in search scenarios
52+
("San Francisco", r"San\ Francisco"),
53+
("New Zealand", r"New\ Zealand"),
54+
# Multi-special-character tags
55+
("complex/tag:value", r"complex\/tag\:value"),
56+
("$special$tag$", r"\$special\$tag\$"),
57+
("tag-with-hyphen", r"tag\-with\-hyphen"),
58+
# Tags with less common, but legal characters
59+
("_underscore_", r"_underscore_"),
60+
("dot.tag", r"dot\.tag"),
61+
# ("pipe|tag", r"pipe\|tag"), #TODO - pipes are not caught?
62+
# More edge cases with special characters
63+
("(parentheses)", r"\(parentheses\)"),
64+
("[brackets]", r"\[brackets\]"),
65+
("{braces}", r"\{braces\}"),
66+
# ("question?mark", r"question\?mark"), #TODO - question marks are not caught?
67+
# Unicode characters in tags
68+
("你好", r"你好"), # Assuming non-Latin characters don't need escaping
69+
("emoji:😊", r"emoji\:😊"),
70+
# ...other cases as needed...
71+
],
72+
ids=[
73+
":",
74+
"#",
75+
"-",
76+
"$",
77+
"*",
78+
"&",
79+
"@",
80+
"space",
81+
"space-2",
82+
"complex",
83+
"special",
84+
"hyphen",
85+
"underscore",
86+
"dot",
87+
"parentheses",
88+
"brackets",
89+
"braces",
90+
"non-latin",
91+
"emoji"
92+
]
93+
)
94+
def test_escape_tag_like_values(escaper, test_input, expected):
95+
assert escaper.escape(test_input) == expected
96+
97+
98+
@pytest.mark.parametrize("test_input", [123, 45.67, None, [], {}])
99+
def test_escape_non_string_input(escaper, test_input):
100+
with pytest.raises(TypeError):
101+
escaper.escape(test_input)
102+
103+
104+
@pytest.mark.parametrize(
105+
"test_input,expected",
106+
[
107+
# ('你好,世界!', r'你好\,世界\!'), # TODO - non latin chars?
108+
("😊 ❤️ 👍", r"😊\ ❤️\ 👍"),
109+
# ...other cases as needed...
110+
],
111+
ids=[
112+
"emoji"
113+
]
114+
)
115+
def test_escape_unicode_characters(escaper, test_input, expected):
116+
assert escaper.escape(test_input) == expected
117+
118+
119+
def test_escape_empty_string(escaper):
120+
assert escaper.escape("") == ""
121+
122+
123+
def test_escape_long_string(escaper):
124+
# Construct a very long string
125+
long_str = "a," * 1000 # This creates a string "a,a,a,a,...a,"
126+
expected = r"a\," * 1000 # Expected escaped string
127+
128+
# Use pytest's benchmark fixture to check performance
129+
escaped = escaper.escape(long_str)
130+
assert escaped == expected

0 commit comments

Comments
 (0)