Skip to content

Commit 6345cc1

Browse files
Add optional data load preprocessor hook (#67)
Another finding from working on the arXiv example: often you need to unpack or edit a record before writing it to the index. If you do this before invoking redisvl, you add an additional loop, one that is unnecessary in the end. So this allows devs to optionally supply a preprocessor method on load that is called against each record.
1 parent 5b9b0c7 commit 6345cc1

File tree

2 files changed

+81
-10
lines changed

2 files changed

+81
-10
lines changed

redisvl/index.py

Lines changed: 61 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import asyncio
2-
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
2+
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Union
33
from uuid import uuid4
44

55
if TYPE_CHECKING:
@@ -217,13 +217,17 @@ def delete(self, drop: bool = True):
217217
Args:
218218
drop (bool, optional): Delete the documents in the index. Defaults to True.
219219
220-
raises:
220+
Raises:
221221
redis.exceptions.ResponseError: If the index does not exist.
222222
"""
223223
raise NotImplementedError
224224

225225
def load(
226-
self, data: Iterable[Dict[str, Any]], key_field: Optional[str] = None, **kwargs
226+
self,
227+
data: Iterable[Dict[str, Any]],
228+
key_field: Optional[str] = None,
229+
preprocess: Optional[Callable] = None,
230+
**kwargs,
227231
):
228232
"""Load data into Redis and index using this SearchIndex object.
229233
@@ -232,8 +236,10 @@ def load(
232236
containing the data to be indexed.
233237
key_field (Optional[str], optional): A field within the record
234238
to use in the Redis hash key.
239+
preprocess (Optional[Callable], optional): An optional preprocessor function
240+
that mutates the individual record before writing to redis.
235241
236-
raises:
242+
Raises:
237243
redis.exceptions.ResponseError: If the index does not exist.
238244
"""
239245
raise NotImplementedError
@@ -357,7 +363,11 @@ def delete(self, drop: bool = True):
357363

358364
@check_connected("_redis_conn")
359365
def load(
360-
self, data: Iterable[Dict[str, Any]], key_field: Optional[str] = None, **kwargs
366+
self,
367+
data: Iterable[Dict[str, Any]],
368+
key_field: Optional[str] = None,
369+
preprocess: Optional[Callable] = None,
370+
**kwargs,
361371
):
362372
"""Load data into Redis and index using this SearchIndex object.
363373
@@ -366,9 +376,16 @@ def load(
366376
containing the data to be indexed.
367377
key_field (Optional[str], optional): A field within the record to
368378
use in the Redis hash key.
379+
preprocess (Optional[Callable], optional): An optional preprocessor function
380+
that mutates the individual record before writing to redis.
369381
370382
raises:
371383
redis.exceptions.ResponseError: If the index does not exist.
384+
385+
Example:
386+
>>> data = [{"foo": "bar"}, {"test": "values"}]
387+
>>> def func(record: dict): record["new"]="value";return record
388+
>>> index.load(data, preprocess=func)
372389
"""
373390
# TODO -- should we return a count of the upserts? or some kind of metadata?
374391
if data:
@@ -381,6 +398,19 @@ def load(
381398
with self._redis_conn.pipeline(transaction=False) as pipe: # type: ignore
382399
for record in data:
383400
key = self._create_key(record, key_field)
401+
# Optionally preprocess the record and validate type
402+
if preprocess:
403+
try:
404+
record = preprocess(record)
405+
except Exception as e:
406+
raise RuntimeError(
407+
"Error while preprocessing records on load"
408+
) from e
409+
if not isinstance(record, dict):
410+
raise TypeError(
411+
f"Individual records must be of type dict, got type {type(record)}"
412+
)
413+
# Write the record to Redis
384414
pipe.hset(key, mapping=record) # type: ignore
385415
if ttl:
386416
pipe.expire(key, ttl)
@@ -406,8 +436,8 @@ class AsyncSearchIndex(SearchIndexBase):
406436
Example:
407437
>>> from redisvl.index import AsyncSearchIndex
408438
>>> index = AsyncSearchIndex.from_yaml("schema.yaml")
409-
>>> index.create(overwrite=True)
410-
>>> index.load(data) # data is an iterable of dictionaries
439+
>>> await index.create(overwrite=True)
440+
>>> await index.load(data) # data is an iterable of dictionaries
411441
"""
412442

413443
def __init__(
@@ -502,7 +532,7 @@ async def delete(self, drop: bool = True):
502532
Args:
503533
drop (bool, optional): Delete the documents in the index. Defaults to True.
504534
505-
raises:
535+
Raises:
506536
redis.exceptions.ResponseError: If the index does not exist.
507537
"""
508538
# Delete the search index
@@ -514,6 +544,7 @@ async def load(
514544
data: Iterable[Dict[str, Any]],
515545
concurrency: int = 10,
516546
key_field: Optional[str] = None,
547+
preprocess: Optional[Callable] = None,
517548
**kwargs,
518549
):
519550
"""Load data into Redis and index using this SearchIndex object.
@@ -524,21 +555,41 @@ async def load(
524555
concurrency (int, optional): Number of concurrent tasks to run. Defaults to 10.
525556
key_field (Optional[str], optional): A field within the record to
526557
use in the Redis hash key.
558+
preprocess (Optional[Callable], optional): An optional preprocessor function
559+
that mutates the individual record before writing to redis.
527560
528-
raises:
561+
Raises:
529562
redis.exceptions.ResponseError: If the index does not exist.
563+
564+
Example:
565+
>>> data = [{"foo": "bar"}, {"test": "values"}]
566+
>>> def func(record: dict): record["new"]="value";return record
567+
>>> await index.load(data, preprocess=func)
530568
"""
531569
ttl = kwargs.get("ttl")
532570
semaphore = asyncio.Semaphore(concurrency)
533571

534572
async def _load(record: dict):
535573
async with semaphore:
536574
key = self._create_key(record, key_field)
575+
# Optionally preprocess the record and validate type
576+
if preprocess:
577+
try:
578+
record = preprocess(record)
579+
except Exception as e:
580+
raise RuntimeError(
581+
"Error while preprocessing records on load"
582+
) from e
583+
if not isinstance(record, dict):
584+
raise TypeError(
585+
f"Individual records must be of type dict, got type {type(record)}"
586+
)
587+
# Write the record to Redis
537588
await self._redis_conn.hset(key, mapping=record) # type: ignore
538589
if ttl:
539590
await self._redis_conn.expire(key, ttl) # type: ignore
540591

541-
# gather with concurrency
592+
# Gather with concurrency
542593
await asyncio.gather(*[_load(record) for record in data])
543594

544595
@check_connected("_redis_conn")

tests/test_index.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,26 @@ def test_search_index_load(client):
6262
assert convert_bytes(client.hget("rvl:1", "value")) == "test"
6363

6464

65+
def test_search_index_load_preprocess(client):
66+
si = SearchIndex("my_index", fields=fields)
67+
si.set_client(client)
68+
si.create(overwrite=True)
69+
data = [{"id": "1", "value": "test"}]
70+
71+
def preprocess(record):
72+
record["test"] = "foo"
73+
return record
74+
75+
si.load(data, key_field="id", preprocess=preprocess)
76+
assert convert_bytes(client.hget("rvl:1", "test")) == "foo"
77+
78+
def bad_preprocess(record):
79+
return 1
80+
81+
with pytest.raises(TypeError):
82+
si.load(data, key_field="id", preprocess=bad_preprocess)
83+
84+
6585
@pytest.mark.asyncio
6686
async def test_async_search_index_creation(async_client):
6787
asi = AsyncSearchIndex("my_index", fields=fields)

0 commit comments

Comments
 (0)