Skip to content

Commit 9a2a9d1

Browse files
Vectorizer improvements (#44)
This PR introduces a few improvements to the vectorizers module:
- Raises `TypeError` when the wrong input type is passed in.
- Updates the methodology for fetching the output dimensions of the embedding model to be dynamic based on the given model.
1 parent 29bb39a commit 9a2a9d1

File tree

9 files changed

+124
-28
lines changed

9 files changed

+124
-28
lines changed

redisvl/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
2-
31
from redisvl.version import __version__
42

5-
all = ["__version__"]
3+
all = ["__version__"]

redisvl/cli/index.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
logger = get_logger("[RedisVL]")
1313

14+
1415
class Index:
1516
usage = "\n".join(
1617
[

redisvl/cli/log.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,8 @@
55

66
# constants for logging
77
coloredlogs.DEFAULT_DATE_FORMAT = "%H:%M:%S"
8-
coloredlogs.DEFAULT_LOG_FORMAT = (
9-
"%(asctime)s %(name)s %(levelname)s %(message)s"
10-
)
8+
coloredlogs.DEFAULT_LOG_FORMAT = "%(asctime)s %(name)s %(levelname)s %(message)s"
9+
1110

1211
def get_logger(name, log_level="info", fmt=None):
1312
"""Return a logger instance"""

redisvl/cli/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
import sys
33

44
from redisvl.cli.index import Index
5-
from redisvl.cli.version import Version
65
from redisvl.cli.log import get_logger
6+
from redisvl.cli.version import Version
77

88
logger = get_logger(__name__)
99

@@ -42,4 +42,4 @@ def index(self):
4242

4343
def version(self):
4444
Version()
45-
exit(0)
45+
exit(0)

redisvl/cli/version.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
import sys
21
import argparse
2+
import sys
33
from argparse import Namespace
44

55
from redisvl import __version__
66
from redisvl.cli.log import get_logger
7+
78
logger = get_logger("[RedisVL]")
89

910

@@ -28,4 +29,4 @@ def version(self, args: Namespace):
2829
if args.short:
2930
print(__version__)
3031
else:
31-
logger.info(f"RedisVL version {__version__}")
32+
logger.info(f"RedisVL version {__version__}")

redisvl/vectorize/base.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
from typing import Callable, Dict, List, Optional
1+
from typing import Callable, List, Optional
22

33
from redisvl.utils.utils import array_to_buffer
44

55

66
class BaseVectorizer:
7-
def __init__(self, model: str, dims: int, api_config: Optional[Dict] = None):
8-
self._dims = dims
7+
_dims = None
8+
9+
def __init__(self, model: str):
910
self._model = model
1011

1112
@property

redisvl/vectorize/text/huggingface.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,24 @@
66
class HFTextVectorizer(BaseVectorizer):
77
# TODO - add docstring
88
def __init__(self, model: str, api_config: Optional[Dict] = None):
9-
# TODO set dims based on model
10-
dims = 768
11-
super().__init__(model, dims, api_config)
9+
super().__init__(model)
1210
try:
1311
from sentence_transformers import SentenceTransformer
1412
except ImportError:
1513
raise ImportError(
1614
"HFTextVectorizer requires sentence-transformers library. Please install with pip install sentence-transformers"
1715
)
1816

19-
self._model_client = SentenceTransformer(model)
17+
self._model_client = SentenceTransformer(self._model)
18+
19+
try:
20+
self._dims = self._set_model_dims()
21+
except:
22+
raise ValueError("Error setting embedding model dimensions")
23+
24+
def _set_model_dims(self):
25+
embedding = self._model_client.encode(["dimension check"])[0]
26+
return len(embedding)
2027

2128
def embed(
2229
self,
@@ -35,7 +42,13 @@ def embed(
3542
3643
Returns:
3744
List[float]: Embedding.
45+
46+
Raises:
47+
TypeError: If the wrong input type is passed in for the text.
3848
"""
49+
if not isinstance(text, str):
50+
raise TypeError("Must pass in a str value to embed.")
51+
3952
if preprocess:
4053
text = preprocess(text)
4154
embedding = self._model_client.encode([text])[0]
@@ -62,7 +75,15 @@ def embed_many(
6275
6376
Returns:
6477
List[List[float]]: List of embeddings.
78+
79+
Raises:
80+
TypeError: If the wrong input type is passed in for the texts.
6581
"""
82+
if not isinstance(texts, list):
83+
raise TypeError("Must pass in a list of str values to embed.")
84+
if len(texts) > 0 and not isinstance(texts[0], str):
85+
raise TypeError("Must pass in a list of str values to embed.")
86+
6687
embeddings: List = []
6788
for batch in self.batchify(texts, batch_size, preprocess):
6889
batch_embeddings = self._model_client.encode(batch)

redisvl/vectorize/text/openai.py

Lines changed: 63 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,15 @@
11
from typing import Callable, Dict, List, Optional
22

3-
from tenacity import ( # for exponential backoff
4-
retry,
5-
stop_after_attempt,
6-
wait_random_exponential,
7-
)
3+
from tenacity import retry, stop_after_attempt, wait_random_exponential
4+
from tenacity.retry import retry_if_not_exception_type
85

96
from redisvl.vectorize.base import BaseVectorizer
107

118

129
class OpenAITextVectorizer(BaseVectorizer):
1310
# TODO - add docstring
1411
def __init__(self, model: str, api_config: Optional[Dict] = None):
15-
dims = 1536
16-
super().__init__(model, dims, api_config)
12+
super().__init__(model)
1713
if not api_config:
1814
raise ValueError("OpenAI API key is required in api_config")
1915
try:
@@ -25,7 +21,23 @@ def __init__(self, model: str, api_config: Optional[Dict] = None):
2521
openai.api_key = api_config.get("api_key", None)
2622
self._model_client = openai.Embedding
2723

28-
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
24+
try:
25+
self._dims = self._set_model_dims()
26+
except:
27+
raise ValueError("Error setting embedding model dimensions")
28+
29+
def _set_model_dims(self):
30+
embedding = self._model_client.create(
31+
input=["dimension test"],
32+
engine=self._model
33+
)["data"][0]["embedding"]
34+
return len(embedding)
35+
36+
@retry(
37+
wait=wait_random_exponential(min=1, max=60),
38+
stop=stop_after_attempt(6),
39+
retry=retry_if_not_exception_type(TypeError),
40+
)
2941
def embed_many(
3042
self,
3143
texts: List[str],
@@ -46,7 +58,15 @@ def embed_many(
4658
4759
Returns:
4860
List[List[float]]: List of embeddings.
61+
62+
Raises:
63+
TypeError: If the wrong input type is passed in for the texts.
4964
"""
65+
if not isinstance(texts, list):
66+
raise TypeError("Must pass in a list of str values to embed.")
67+
if len(texts) > 0 and not isinstance(texts[0], str):
68+
raise TypeError("Must pass in a list of str values to embed.")
69+
5070
embeddings: List = []
5171
for batch in self.batchify(texts, batch_size, preprocess):
5272
response = self._model_client.create(input=batch, engine=self._model)
@@ -56,7 +76,11 @@ def embed_many(
5676
]
5777
return embeddings
5878

59-
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
79+
@retry(
80+
wait=wait_random_exponential(min=1, max=60),
81+
stop=stop_after_attempt(6),
82+
retry=retry_if_not_exception_type(TypeError),
83+
)
6084
def embed(
6185
self,
6286
text: str,
@@ -74,13 +98,23 @@ def embed(
7498
7599
Returns:
76100
List[float]: Embedding.
101+
102+
Raises:
103+
TypeError: If the wrong input type is passed in for the text.
77104
"""
105+
if not isinstance(text, str):
106+
raise TypeError("Must pass in a str value to embed.")
107+
78108
if preprocess:
79109
text = preprocess(text)
80110
result = self._model_client.create(input=[text], engine=self._model)
81111
return self._process_embedding(result["data"][0]["embedding"], as_buffer)
82112

83-
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
113+
@retry(
114+
wait=wait_random_exponential(min=1, max=60),
115+
stop=stop_after_attempt(6),
116+
retry=retry_if_not_exception_type(TypeError),
117+
)
84118
async def aembed_many(
85119
self,
86120
texts: List[str],
@@ -101,7 +135,15 @@ async def aembed_many(
101135
102136
Returns:
103137
List[List[float]]: List of embeddings.
138+
139+
Raises:
140+
TypeError: If the wrong input type is passed in for the texts.
104141
"""
142+
if not isinstance(texts, list):
143+
raise TypeError("Must pass in a list of str values to embed.")
144+
if len(texts) > 0 and not isinstance(texts[0], str):
145+
raise TypeError("Must pass in a list of str values to embed.")
146+
105147
embeddings: List = []
106148
for batch in self.batchify(texts, batch_size, preprocess):
107149
response = await self._model_client.acreate(input=batch, engine=self._model)
@@ -111,7 +153,11 @@ async def aembed_many(
111153
]
112154
return embeddings
113155

114-
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
156+
@retry(
157+
wait=wait_random_exponential(min=1, max=60),
158+
stop=stop_after_attempt(6),
159+
retry=retry_if_not_exception_type(TypeError),
160+
)
115161
async def aembed(
116162
self,
117163
text: str,
@@ -129,7 +175,13 @@ async def aembed(
129175
130176
Returns:
131177
List[float]: Embedding.
178+
179+
Raises:
180+
TypeError: If the wrong input type is passed in for the text.
132181
"""
182+
if not isinstance(text, str):
183+
raise TypeError("Must pass in a str value to embed.")
184+
133185
if preprocess:
134186
text = preprocess(text)
135187
result = await self._model_client.acreate(input=[text], engine=self._model)

tests/integration/test_vectorizers.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,17 @@ def test_vectorizer_embed_many(vectorizer):
3535
)
3636

3737

38+
def test_vectorizer_bad_input(vectorizer):
39+
with pytest.raises(TypeError):
40+
vectorizer.embed(1)
41+
42+
with pytest.raises(TypeError):
43+
vectorizer.embed({"foo": "bar"})
44+
45+
with pytest.raises(TypeError):
46+
vectorizer.embed_many(42)
47+
48+
3849
@pytest.fixture(params=[OpenAITextVectorizer])
3950
def avectorizer(request, openai_key):
4051
# Here we use actual models for integration test
@@ -63,3 +74,15 @@ async def test_vectorizer_aembed_many(avectorizer):
6374
assert all(
6475
isinstance(emb, list) and len(emb) == avectorizer.dims for emb in embeddings
6576
)
77+
78+
79+
@pytest.mark.asyncio
80+
async def test_avectorizer_bad_input(avectorizer):
81+
with pytest.raises(TypeError):
82+
avectorizer.embed(1)
83+
84+
with pytest.raises(TypeError):
85+
avectorizer.embed({"foo": "bar"})
86+
87+
with pytest.raises(TypeError):
88+
avectorizer.embed_many(42)

0 commit comments

Comments
 (0)