Skip to content

Commit 5b9b0c7

Browse files
Improve VectorField schema default args and tests (#68)
By default, `VectorField`s in Redis do NOT need to have the block size or initial cap args set. This change allows those params to be set; they are included in the field args only when set, and are omitted otherwise. Also includes a small refactor of the schema classes for vectors as well as new schema unit tests.
1 parent 4aab8b7 commit 5b9b0c7

File tree

6 files changed

+207
-25
lines changed

6 files changed

+207
-25
lines changed

conftest.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,13 @@ def client():
2828
def openai_key():
2929
return os.getenv("OPENAI_API_KEY")
3030

31+
@pytest.fixture
def gcp_location():
    """Provide the value of the ``GCP_LOCATION`` environment variable.

    Yields ``None`` to the requesting test when the variable is not set.
    """
    return os.environ.get("GCP_LOCATION")
34+
35+
@pytest.fixture
def gcp_project_id():
    """Provide the value of the ``GCP_PROJECT_ID`` environment variable.

    Yields ``None`` to the requesting test when the variable is not set.
    """
    return os.environ.get("GCP_PROJECT_ID")
3138

3239
@pytest.fixture(scope="session")
3340
def event_loop():

redisvl/schema.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pathlib import Path
2-
from typing import List, Optional, Union
2+
from typing import Any, Dict, List, Optional, Union
33
from uuid import uuid4
44

55
import yaml
@@ -64,48 +64,54 @@ class BaseVectorField(BaseModel):
6464
algorithm: object = Field(...)
6565
datatype: str = Field(default="FLOAT32")
6666
distance_metric: str = Field(default="COSINE")
67+
initial_cap: Optional[int] = None
6768

6869
@validator("algorithm", "datatype", "distance_metric", pre=True, each_item=True)
6970
def uppercase_strings(cls, v):
7071
return v.upper()
7172

73+
def as_field(self) -> Dict[str, Any]:
74+
field_data = {
75+
"TYPE": self.datatype,
76+
"DIM": self.dims,
77+
"DISTANCE_METRIC": self.distance_metric,
78+
}
79+
if self.initial_cap is not None: # Only include it if it's set
80+
field_data["INITIAL_CAP"] = self.initial_cap
81+
return field_data
82+
7283

7384
class FlatVectorField(BaseVectorField):
74-
algorithm: object = Literal["FLAT"]
85+
algorithm: Literal["FLAT"] = "FLAT"
86+
block_size: Optional[int] = None
7587

7688
def as_field(self):
77-
return VectorField(
78-
self.name,
79-
self.algorithm,
80-
{
81-
"TYPE": self.datatype,
82-
"DIM": self.dims,
83-
"DISTANCE_METRIC": self.distance_metric,
84-
},
85-
)
89+
# grab base field params and augment with flat-specific fields
90+
field_data = super().as_field()
91+
if self.block_size is not None:
92+
field_data["BLOCK_SIZE"] = self.block_size
93+
return VectorField(self.name, self.algorithm, field_data)
8694

8795

8896
class HNSWVectorField(BaseVectorField):
89-
algorithm: object = Literal["HNSW"]
97+
algorithm: Literal["HNSW"] = "HNSW"
9098
m: int = Field(default=16)
9199
ef_construction: int = Field(default=200)
92100
ef_runtime: int = Field(default=10)
93-
epsilon: float = Field(default=0.8)
101+
epsilon: float = Field(default=0.01)
94102

95103
def as_field(self):
96-
return VectorField(
97-
self.name,
98-
self.algorithm,
104+
# grab base field params and augment with hnsw-specific fields
105+
field_data = super().as_field()
106+
field_data.update(
99107
{
100-
"TYPE": self.datatype,
101-
"DIM": self.dims,
102-
"DISTANCE_METRIC": self.distance_metric,
103108
"M": self.m,
104109
"EF_CONSTRUCTION": self.ef_construction,
105110
"EF_RUNTIME": self.ef_runtime,
106111
"EPSILON": self.epsilon,
107-
},
112+
}
108113
)
114+
return VectorField(self.name, self.algorithm, field_data)
109115

110116

111117
class IndexModel(BaseModel):

redisvl/vectorize/text/openai.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __init__(self, model: str, api_config: Optional[Dict] = None):
3636
import openai
3737
except ImportError:
3838
raise ImportError(
39-
"OpenAI vectorizer requires the openai library. Please install with pip install openai"
39+
"OpenAI vectorizer requires the openai library. Please install with `pip install openai`"
4040
)
4141

4242
if not api_config or "api_key" not in api_config:

redisvl/vectorize/text/vertexai.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def __init__(
4848
except ImportError:
4949
raise ImportError(
5050
"VertexAI vectorizer requires the google-cloud-aiplatform library."
51-
"Please install with pip install google-cloud-aiplatform>=1.26"
51+
"Please install with `pip install google-cloud-aiplatform>=1.26`"
5252
)
5353

5454
self._model_client = TextEmbeddingModel.from_pretrained(model)

tests/integration/test_vectorizers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111

1212
@pytest.fixture(params=[HFTextVectorizer, OpenAITextVectorizer, VertexAITextVectorizer])
13-
def vectorizer(request, openai_key):
13+
def vectorizer(request, openai_key, gcp_location, gcp_project_id):
1414
# Here we use actual models for integration test
1515
if request.param == HFTextVectorizer:
1616
return request.param(model="sentence-transformers/all-mpnet-base-v2")
@@ -23,8 +23,8 @@ def vectorizer(request, openai_key):
2323
return request.param(
2424
model="textembedding-gecko",
2525
api_config={
26-
"location": os.environ["GCP_LOCATION"],
27-
"project_id": os.environ["GCP_PROJECT_ID"],
26+
"location": gcp_location,
27+
"project_id": gcp_project_id,
2828
},
2929
)
3030

tests/test_schema.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
import pytest
2+
from pydantic import ValidationError
3+
from redis.commands.search.field import (
4+
GeoField,
5+
NumericField,
6+
TagField,
7+
TextField,
8+
VectorField,
9+
)
10+
11+
from redisvl.schema import (
12+
FlatVectorField,
13+
GeoFieldSchema,
14+
HNSWVectorField,
15+
NumericFieldSchema,
16+
SchemaModel,
17+
TagFieldSchema,
18+
TextFieldSchema,
19+
read_schema,
20+
)
21+
22+
23+
# Utility functions to create schema instances with default values
24+
def create_text_field_schema(**kwargs):
    """Build a ``TextFieldSchema`` from baseline values.

    Any keyword argument overrides (or extends) the baseline values.
    """
    params = {"name": "example_textfield", "sortable": False, "weight": 1.0, **kwargs}
    return TextFieldSchema(**params)
28+
29+
30+
def create_tag_field_schema(**kwargs):
    """Build a ``TagFieldSchema`` from baseline values.

    Any keyword argument overrides (or extends) the baseline values.
    """
    params = {"name": "example_tagfield", "sortable": False, "separator": ",", **kwargs}
    return TagFieldSchema(**params)
34+
35+
36+
def create_numeric_field_schema(**kwargs):
    """Build a ``NumericFieldSchema`` from baseline values.

    Any keyword argument overrides (or extends) the baseline values.
    """
    params = {"name": "example_numericfield", "sortable": False, **kwargs}
    return NumericFieldSchema(**params)
40+
41+
42+
def create_geo_field_schema(**kwargs):
    """Build a ``GeoFieldSchema`` from baseline values.

    Any keyword argument overrides (or extends) the baseline values.
    """
    params = {"name": "example_geofield", "sortable": False, **kwargs}
    return GeoFieldSchema(**params)
46+
47+
48+
def create_flat_vector_field(**kwargs):
    """Build a ``FlatVectorField`` from baseline values.

    Any keyword argument overrides (or extends) the baseline values.
    """
    params = {
        "name": "example_flatvectorfield",
        "dims": 128,
        "algorithm": "FLAT",
        **kwargs,
    }
    return FlatVectorField(**params)
52+
53+
54+
def create_hnsw_vector_field(**kwargs):
    """Build an ``HNSWVectorField`` from baseline values.

    Any keyword argument overrides (or extends) the baseline values.
    """
    params = dict(
        name="example_hnswvectorfield",
        dims=128,
        algorithm="HNSW",
        m=16,
        ef_construction=200,
        ef_runtime=10,
        epsilon=0.01,
    )
    params.update(kwargs)
    return HNSWVectorField(**params)
66+
67+
68+
# Tests for field schema creation and validation
69+
@pytest.mark.parametrize(
    "schema_func,field_class",
    [
        (create_text_field_schema, TextField),
        (create_tag_field_schema, TagField),
        (create_numeric_field_schema, NumericField),
        (create_geo_field_schema, GeoField),
    ],
)
def test_field_schema_as_field(schema_func, field_class):
    """Each basic schema exports the matching field type under the expected name."""
    # The factory helpers name their fields "example_<lowercased field class name>".
    expected_name = "example_" + field_class.__name__.lower()

    field = schema_func().as_field()

    assert isinstance(field, field_class)
    assert field.name == expected_name
83+
84+
85+
def test_vector_fields_as_field():
    """Flat and HNSW vector schemas both export a ``VectorField`` with their name."""
    cases = (
        (create_flat_vector_field, "example_flatvectorfield"),
        (create_hnsw_vector_field, "example_hnswvectorfield"),
    )
    for make_schema, expected_name in cases:
        exported = make_schema().as_field()
        assert isinstance(exported, VectorField)
        assert exported.name == expected_name
95+
96+
97+
@pytest.mark.parametrize(
    "vector_schema_func,extra_params",
    [
        (create_flat_vector_field, {"block_size": 100}),
        (create_hnsw_vector_field, {"m": 24, "ef_construction": 300}),
    ],
)
def test_vector_fields_with_optional_params(vector_schema_func, extra_params):
    """Optional parameters passed to a vector schema appear in the exported args."""
    # Create a vector schema with additional parameters set and export it.
    vector_field = vector_schema_func(**extra_params).as_field()
    assert isinstance(vector_field, VectorField)

    # Each parameter is looked up by its upper-cased name; its value is
    # expected in the slot immediately after that name.
    exported_args = vector_field.args
    for param, value in extra_params.items():
        key = param.upper()
        assert key in exported_args
        assert exported_args[exported_args.index(key) + 1] == value
115+
116+
117+
def test_hnsw_vector_field_optional_params_not_set():
    """An HNSW field built without tuning params uses and exports the defaults."""
    # Create HNSW vector field without setting optional params
    hnsw_field = HNSWVectorField(name="example_vector", dims=128, algorithm="HNSW")

    expected_defaults = {
        "m": 16,
        "ef_construction": 200,
        "ef_runtime": 10,
        "epsilon": 0.01,
    }

    # The model attributes carry the default values
    for attr, default in expected_defaults.items():
        assert getattr(hnsw_field, attr) == default

    # Check the default values are correctly applied in the exported object
    exported_args = hnsw_field.as_field().args
    for attr, default in expected_defaults.items():
        assert exported_args[exported_args.index(attr.upper()) + 1] == default
133+
134+
135+
def test_flat_vector_field_block_size_not_set():
    """A Flat field built without optional sizing params omits them on export."""
    # Create Flat vector field without setting block_size or initial_cap
    flat_field = FlatVectorField(name="example_vector", dims=128, algorithm="FLAT")
    exported_args = flat_field.as_field().args

    # block_size and initial_cap should not be in the exported field if they were not set
    for unset_option in ("BLOCK_SIZE", "INITIAL_CAP"):
        assert unset_option not in exported_args
143+
144+
145+
# Test for schema model validation
146+
def test_schema_model_validation_success():
    """A hash-storage index with one text field validates into a ``SchemaModel``."""
    schema_model = SchemaModel(
        index={"name": "test_index", "storage_type": "hash"},
        fields={"text": [create_text_field_schema()]},
    )

    index = schema_model.index
    assert index.name == "test_index"
    assert index.storage_type == "hash"
    assert len(schema_model.fields.text) == 1
154+
155+
156+
def test_schema_model_validation_failures():
    """``SchemaModel`` rejects an unsupported storage type and an empty index."""
    # Invalid storage type
    invalid_index = {"name": "test_index", "storage_type": "unsupported"}
    with pytest.raises(ValueError):
        SchemaModel(index=invalid_index, fields={})

    # Missing required field
    empty_index = {}
    with pytest.raises(ValidationError):
        SchemaModel(index=empty_index, fields={})
165+
166+
167+
def test_read_schema_file_not_found():
    """``read_schema`` raises ``FileNotFoundError`` for a path that does not exist."""
    missing_path = "non_existent_file.yaml"
    with pytest.raises(FileNotFoundError):
        read_schema(missing_path)

0 commit comments

Comments
 (0)