Skip to content

Commit 2f23e10

Browse files
Add schema generator (#77)
The schema generator is responsible for taking a dictionary of key/values, inferring types, and converting to a dictionary of redisvl schema.
1 parent f2db78b commit 2f23e10

File tree

3 files changed

+164
-1
lines changed

3 files changed

+164
-1
lines changed

redisvl/schema.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,96 @@ def read_schema(file_path: str):
158158
schema = yaml.safe_load(f)
159159

160160
return SchemaModel(**schema)
161+
162+
163+
class MetadataSchemaGenerator:
164+
"""
165+
A class to generate a schema for metadata, categorizing fields into text, numeric, and tag types.
166+
"""
167+
168+
def _test_numeric(self, value) -> bool:
169+
"""
170+
Test if the given value can be represented as a numeric value.
171+
172+
Args:
173+
value: The value to test.
174+
175+
Returns:
176+
bool: True if the value can be converted to float, False otherwise.
177+
"""
178+
try:
179+
float(value)
180+
return True
181+
except (ValueError, TypeError):
182+
return False
183+
184+
def _infer_type(self, value) -> Optional[str]:
185+
"""
186+
Infer the type of the given value.
187+
188+
Args:
189+
value: The value to infer the type of.
190+
191+
Returns:
192+
Optional[str]: The inferred type of the value, or None if the type is unrecognized or the value is empty.
193+
"""
194+
if value is None or value == "":
195+
return None
196+
elif self._test_numeric(value):
197+
return "numeric"
198+
elif isinstance(value, (list, set, tuple)) and all(
199+
isinstance(v, str) for v in value
200+
):
201+
return "tag"
202+
elif isinstance(value, str):
203+
return "text"
204+
else:
205+
return "unknown"
206+
207+
def generate(
208+
self, metadata: Dict[str, Any], strict: Optional[bool] = False
209+
) -> Dict[str, List[Dict[str, Any]]]:
210+
"""
211+
Generate a schema from the provided metadata.
212+
213+
This method categorizes each metadata field into text, numeric, or tag types based on the field values.
214+
It also allows forcing strict type determination by raising an exception if a type cannot be inferred.
215+
216+
Args:
217+
metadata: The metadata dictionary to generate the schema from.
218+
strict: If True, the method will raise an exception for fields where the type cannot be determined.
219+
220+
Returns:
221+
Dict[str, List[Dict[str, Any]]]: A dictionary with keys 'text', 'numeric', and 'tag', each mapping to a list of field schemas.
222+
223+
Raises:
224+
ValueError: If the force parameter is True and a field's type cannot be determined.
225+
"""
226+
result: Dict[str, List[Dict[str, Any]]] = {"text": [], "numeric": [], "tag": []}
227+
228+
for key, value in metadata.items():
229+
field_type = self._infer_type(value)
230+
231+
if field_type in ["unknown", None]:
232+
if strict:
233+
raise ValueError(
234+
f"Unable to determine field type for key '{key}' with value '{value}'"
235+
)
236+
print(
237+
f"Warning: Unable to determine field type for key '{key}' with value '{value}'"
238+
)
239+
continue
240+
241+
# Extract the field class with defaults
242+
field_class = {
243+
"text": TextFieldSchema,
244+
"tag": TagFieldSchema,
245+
"numeric": NumericFieldSchema,
246+
}.get(
247+
field_type # type: ignore
248+
)
249+
250+
if field_class:
251+
result[field_type].append(field_class(name=key).dict(exclude_none=True)) # type: ignore
252+
253+
return result

tests/integration/test_vectorizers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
@pytest.fixture
1313
def skip_vectorizer() -> bool:
1414
# os.getenv returns a string
15-
return os.getenv("SKIP_VECTORIZERS", 'False').lower() == 'true'
15+
return os.getenv("SKIP_VECTORIZERS", "False").lower() == "true"
16+
1617

1718
skip_vectorizer_test = lambda: pytest.config.getfixturevalue("skip_vectorizer")
1819

tests/unit/test_schema.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
FlatVectorField,
1313
GeoFieldSchema,
1414
HNSWVectorField,
15+
MetadataSchemaGenerator,
1516
NumericFieldSchema,
1617
SchemaModel,
1718
TagFieldSchema,
@@ -167,3 +168,71 @@ def test_schema_model_validation_failures():
167168
def test_read_schema_file_not_found():
168169
with pytest.raises(FileNotFoundError):
169170
read_schema("non_existent_file.yaml")
171+
172+
173+
# Fixture for the generator instance
174+
@pytest.fixture
175+
def schema_generator():
176+
return MetadataSchemaGenerator()
177+
178+
179+
# Test cases for _test_numeric
180+
@pytest.mark.parametrize(
181+
"value, expected",
182+
[
183+
(123, True),
184+
("123", True),
185+
("123.45", True),
186+
("abc", False),
187+
(None, False),
188+
("", False),
189+
],
190+
)
191+
def test_test_numeric(schema_generator, value, expected):
192+
assert schema_generator._test_numeric(value) == expected
193+
194+
195+
# Test cases for _infer_type
196+
@pytest.mark.parametrize(
197+
"value, expected",
198+
[
199+
(123, "numeric"),
200+
("123", "numeric"),
201+
(["tag1", "tag2"], "tag"),
202+
("text", "text"),
203+
(None, None),
204+
("", None),
205+
({"key": "value"}, "unknown"),
206+
],
207+
)
208+
def test_infer_type(schema_generator, value, expected):
209+
assert schema_generator._infer_type(value) == expected
210+
211+
212+
# Test cases for generate
213+
@pytest.mark.parametrize(
214+
"metadata, strict, expected",
215+
[
216+
(
217+
{"name": "John", "age": 30, "tags": ["friend", "colleague"]},
218+
False,
219+
{
220+
"text": [TextFieldSchema(name="name").dict(exclude_none=True)],
221+
"numeric": [NumericFieldSchema(name="age").dict(exclude_none=True)],
222+
"tag": [TagFieldSchema(name="tags").dict(exclude_none=True)],
223+
},
224+
),
225+
(
226+
{"invalid": {"nested": "dict"}},
227+
False,
228+
{"text": [], "numeric": [], "tag": []},
229+
),
230+
({"invalid": {"nested": "dict"}}, True, pytest.raises(ValueError)),
231+
],
232+
)
233+
def test_generate(schema_generator, metadata, strict, expected):
234+
if not isinstance(expected, dict):
235+
with expected:
236+
schema_generator.generate(metadata, strict)
237+
else:
238+
assert schema_generator.generate(metadata, strict) == expected

0 commit comments

Comments
 (0)