11from typing import Callable , Dict , List , Optional
22
3- from tenacity import ( # for exponential backoff
4- retry ,
5- stop_after_attempt ,
6- wait_random_exponential ,
7- )
3+ from tenacity import retry , stop_after_attempt , wait_random_exponential
4+ from tenacity .retry import retry_if_not_exception_type
85
96from redisvl .vectorize .base import BaseVectorizer
107
118
129class OpenAITextVectorizer (BaseVectorizer ):
1310 # TODO - add docstring
1411 def __init__ (self , model : str , api_config : Optional [Dict ] = None ):
15- dims = 1536
16- super ().__init__ (model , dims , api_config )
12+ super ().__init__ (model )
1713 if not api_config :
1814 raise ValueError ("OpenAI API key is required in api_config" )
1915 try :
@@ -25,7 +21,23 @@ def __init__(self, model: str, api_config: Optional[Dict] = None):
2521 openai .api_key = api_config .get ("api_key" , None )
2622 self ._model_client = openai .Embedding
2723
28- @retry (wait = wait_random_exponential (min = 1 , max = 60 ), stop = stop_after_attempt (6 ))
24+ try :
25+ self ._dims = self ._set_model_dims ()
26+ except :
27+ raise ValueError ("Error setting embedding model dimensions" )
28+
def _set_model_dims(self):
    """Return the embedding length produced by the configured model.

    Sends one fixed probe string to the embedding client and measures the
    length of the vector found in the first item of the response.

    Returns:
        The number of components in the probe embedding.
    """
    # Only the vector's length matters here, not its values.
    response = self._model_client.create(
        input=["dimension test"],
        engine=self._model,
    )
    first_item = response["data"][0]
    probe_vector = first_item["embedding"]
    return len(probe_vector)
35+
36+ @retry (
37+ wait = wait_random_exponential (min = 1 , max = 60 ),
38+ stop = stop_after_attempt (6 ),
39+ retry = retry_if_not_exception_type (TypeError ),
40+ )
2941 def embed_many (
3042 self ,
3143 texts : List [str ],
@@ -46,7 +58,15 @@ def embed_many(
4658
4759 Returns:
4860 List[List[float]]: List of embeddings.
61+
62+ Raises:
63+ TypeError: If texts is not a list, or its first element is not a str.
4964 """
65+ if not isinstance (texts , list ):
66+ raise TypeError ("Must pass in a list of str values to embed." )
67+ if len (texts ) > 0 and not isinstance (texts [0 ], str ):
68+ raise TypeError ("Must pass in a list of str values to embed." )
69+
5070 embeddings : List = []
5171 for batch in self .batchify (texts , batch_size , preprocess ):
5272 response = self ._model_client .create (input = batch , engine = self ._model )
@@ -56,7 +76,11 @@ def embed_many(
5676 ]
5777 return embeddings
5878
59- @retry (wait = wait_random_exponential (min = 1 , max = 60 ), stop = stop_after_attempt (6 ))
79+ @retry (
80+ wait = wait_random_exponential (min = 1 , max = 60 ),
81+ stop = stop_after_attempt (6 ),
82+ retry = retry_if_not_exception_type (TypeError ),
83+ )
6084 def embed (
6185 self ,
6286 text : str ,
@@ -74,13 +98,23 @@ def embed(
7498
7599 Returns:
76100 List[float]: Embedding.
101+
102+ Raises:
103+ TypeError: If text is not a str.
77104 """
105+ if not isinstance (text , str ):
106+ raise TypeError ("Must pass in a str value to embed." )
107+
78108 if preprocess :
79109 text = preprocess (text )
80110 result = self ._model_client .create (input = [text ], engine = self ._model )
81111 return self ._process_embedding (result ["data" ][0 ]["embedding" ], as_buffer )
82112
83- @retry (wait = wait_random_exponential (min = 1 , max = 60 ), stop = stop_after_attempt (6 ))
113+ @retry (
114+ wait = wait_random_exponential (min = 1 , max = 60 ),
115+ stop = stop_after_attempt (6 ),
116+ retry = retry_if_not_exception_type (TypeError ),
117+ )
84118 async def aembed_many (
85119 self ,
86120 texts : List [str ],
@@ -101,7 +135,15 @@ async def aembed_many(
101135
102136 Returns:
103137 List[List[float]]: List of embeddings.
138+
139+ Raises:
140+ TypeError: If texts is not a list, or its first element is not a str.
104141 """
142+ if not isinstance (texts , list ):
143+ raise TypeError ("Must pass in a list of str values to embed." )
144+ if len (texts ) > 0 and not isinstance (texts [0 ], str ):
145+ raise TypeError ("Must pass in a list of str values to embed." )
146+
105147 embeddings : List = []
106148 for batch in self .batchify (texts , batch_size , preprocess ):
107149 response = await self ._model_client .acreate (input = batch , engine = self ._model )
@@ -111,7 +153,11 @@ async def aembed_many(
111153 ]
112154 return embeddings
113155
114- @retry (wait = wait_random_exponential (min = 1 , max = 60 ), stop = stop_after_attempt (6 ))
156+ @retry (
157+ wait = wait_random_exponential (min = 1 , max = 60 ),
158+ stop = stop_after_attempt (6 ),
159+ retry = retry_if_not_exception_type (TypeError ),
160+ )
115161 async def aembed (
116162 self ,
117163 text : str ,
@@ -129,7 +175,13 @@ async def aembed(
129175
130176 Returns:
131177 List[float]: Embedding.
178+
179+ Raises:
180+ TypeError: If text is not a str.
132181 """
182+ if not isinstance (text , str ):
183+ raise TypeError ("Must pass in a str value to embed." )
184+
133185 if preprocess :
134186 text = preprocess (text )
135187 result = await self ._model_client .acreate (input = [text ], engine = self ._model )
0 commit comments