11import asyncio
2- from typing import TYPE_CHECKING , Any , Dict , Iterable , List , Optional , Union
2+ from typing import TYPE_CHECKING , Any , Callable , Dict , Iterable , List , Optional , Union
33from uuid import uuid4
44
55if TYPE_CHECKING :
@@ -217,13 +217,17 @@ def delete(self, drop: bool = True):
217217 Args:
218218 drop (bool, optional): Delete the documents in the index. Defaults to True.
219219
220- raises :
220+ Raises :
221221 redis.exceptions.ResponseError: If the index does not exist.
222222 """
223223 raise NotImplementedError
224224
225225 def load (
226- self , data : Iterable [Dict [str , Any ]], key_field : Optional [str ] = None , ** kwargs
226+ self ,
227+ data : Iterable [Dict [str , Any ]],
228+ key_field : Optional [str ] = None ,
229+ preprocess : Optional [Callable ] = None ,
230+ ** kwargs ,
227231 ):
228232 """Load data into Redis and index using this SearchIndex object.
229233
@@ -232,8 +236,10 @@ def load(
232236 containing the data to be indexed.
233237 key_field (Optional[str], optional): A field within the record
234238 to use in the Redis hash key.
239+ preprocess (Optional[Callabl], optional): An optional preprocessor function
240+ that mutates the individual record before writing to redis.
235241
236- raises :
242+ Raises :
237243 redis.exceptions.ResponseError: If the index does not exist.
238244 """
239245 raise NotImplementedError
@@ -357,7 +363,11 @@ def delete(self, drop: bool = True):
357363
358364 @check_connected ("_redis_conn" )
359365 def load (
360- self , data : Iterable [Dict [str , Any ]], key_field : Optional [str ] = None , ** kwargs
366+ self ,
367+ data : Iterable [Dict [str , Any ]],
368+ key_field : Optional [str ] = None ,
369+ preprocess : Optional [Callable ] = None ,
370+ ** kwargs ,
361371 ):
362372 """Load data into Redis and index using this SearchIndex object.
363373
@@ -366,9 +376,16 @@ def load(
366376 containing the data to be indexed.
367377 key_field (Optional[str], optional): A field within the record to
368378 use in the Redis hash key.
379+ preprocess (Optional[Callable], optional): An optional preprocessor function
380+ that mutates the individual record before writing to redis.
369381
370382 raises:
371383 redis.exceptions.ResponseError: If the index does not exist.
384+
385+ Example:
386+ >>> data = [{"foo": "bar"}, {"test": "values"}]
387+ >>> def func(record: dict): record["new"]="value";return record
388+ >>> index.load(data, preprocess=func)
372389 """
373390 # TODO -- should we return a count of the upserts? or some kind of metadata?
374391 if data :
@@ -381,6 +398,19 @@ def load(
381398 with self ._redis_conn .pipeline (transaction = False ) as pipe : # type: ignore
382399 for record in data :
383400 key = self ._create_key (record , key_field )
401+ # Optionally preprocess the record and validate type
402+ if preprocess :
403+ try :
404+ record = preprocess (record )
405+ except Exception as e :
406+ raise RuntimeError (
407+ "Error while preprocessing records on load"
408+ ) from e
409+ if not isinstance (record , dict ):
410+ raise TypeError (
411+ f"Individual records must be of type dict, got type { type (record )} "
412+ )
413+ # Write the record to Redis
384414 pipe .hset (key , mapping = record ) # type: ignore
385415 if ttl :
386416 pipe .expire (key , ttl )
@@ -406,8 +436,8 @@ class AsyncSearchIndex(SearchIndexBase):
406436 Example:
407437 >>> from redisvl.index import AsyncSearchIndex
408438 >>> index = AsyncSearchIndex.from_yaml("schema.yaml")
409- >>> index.create(overwrite=True)
410- >>> index.load(data) # data is an iterable of dictionaries
439+ >>> await index.create(overwrite=True)
440+ >>> await index.load(data) # data is an iterable of dictionaries
411441 """
412442
413443 def __init__ (
@@ -502,7 +532,7 @@ async def delete(self, drop: bool = True):
502532 Args:
503533 drop (bool, optional): Delete the documents in the index. Defaults to True.
504534
505- raises :
535+ Raises :
506536 redis.exceptions.ResponseError: If the index does not exist.
507537 """
508538 # Delete the search index
@@ -514,6 +544,7 @@ async def load(
514544 data : Iterable [Dict [str , Any ]],
515545 concurrency : int = 10 ,
516546 key_field : Optional [str ] = None ,
547+ preprocess : Optional [Callable ] = None ,
517548 ** kwargs ,
518549 ):
519550 """Load data into Redis and index using this SearchIndex object.
@@ -524,21 +555,41 @@ async def load(
524555 concurrency (int, optional): Number of concurrent tasks to run. Defaults to 10.
525556 key_field (Optional[str], optional): A field within the record to
526557 use in the Redis hash key.
558+ preprocess (Optional[Callable], optional): An optional preprocessor function
559+ that mutates the individual record before writing to redis.
527560
528- raises :
561+ Raises :
529562 redis.exceptions.ResponseError: If the index does not exist.
563+
564+ Example:
565+ >>> data = [{"foo": "bar"}, {"test": "values"}]
566+ >>> def func(record: dict): record["new"]="value";return record
567+ >>> await index.load(data, preprocess=func)
530568 """
531569 ttl = kwargs .get ("ttl" )
532570 semaphore = asyncio .Semaphore (concurrency )
533571
534572 async def _load (record : dict ):
535573 async with semaphore :
536574 key = self ._create_key (record , key_field )
575+ # Optionally preprocess the record and validate type
576+ if preprocess :
577+ try :
578+ record = preprocess (record )
579+ except Exception as e :
580+ raise RuntimeError (
581+ "Error while preprocessing records on load"
582+ ) from e
583+ if not isinstance (record , dict ):
584+ raise TypeError (
585+ f"Individual records must be of type dict, got type { type (record )} "
586+ )
587+ # Write the record to Redis
537588 await self ._redis_conn .hset (key , mapping = record ) # type: ignore
538589 if ttl :
539590 await self ._redis_conn .expire (key , ttl ) # type: ignore
540591
541- # gather with concurrency
592+ # Gather with concurrency
542593 await asyncio .gather (* [_load (record ) for record in data ])
543594
544595 @check_connected ("_redis_conn" )
0 commit comments