Skip to content

Commit 6fe7c08

Browse files
committed
Improved typechecking for tests [skip ci]
1 parent f083137 commit 6fe7c08

1 file changed

Lines changed: 7 additions & 7 deletions

File tree

tests/test_sqlmodel.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,17 @@
88

99
engine = create_engine('postgresql+psycopg2://localhost/pgvector_python_test')
1010
with Session(engine) as session:
11-
session.exec(text('CREATE EXTENSION IF NOT EXISTS vector'))
11+
session.exec(text('CREATE EXTENSION IF NOT EXISTS vector')) # type: ignore
1212

1313

1414
class Item(SQLModel, table=True):
1515
__tablename__ = 'sqlmodel_item'
1616

1717
id: Optional[int] = Field(default=None, primary_key=True)
18-
embedding: Optional[Any] = Field(default=None, sa_type=VECTOR(3))
19-
half_embedding: Optional[Any] = Field(default=None, sa_type=HALFVEC(3))
20-
binary_embedding: Optional[Any] = Field(default=None, sa_type=BIT(3))
21-
sparse_embedding: Optional[Any] = Field(default=None, sa_type=SPARSEVEC(3))
18+
embedding: Optional[Any] = Field(default=None, sa_type=VECTOR(3)) # type: ignore
19+
half_embedding: Optional[Any] = Field(default=None, sa_type=HALFVEC(3)) # type: ignore
20+
binary_embedding: Optional[Any] = Field(default=None, sa_type=BIT(3)) # type: ignore
21+
sparse_embedding: Optional[Any] = Field(default=None, sa_type=SPARSEVEC(3)) # type: ignore
2222

2323

2424
SQLModel.metadata.drop_all(engine)
@@ -202,7 +202,7 @@ def test_vector_avg(self):
202202
session.add(Item(embedding=[1, 2, 3]))
203203
session.add(Item(embedding=[4, 5, 6]))
204204
res = session.exec(select(avg(Item.embedding))).first()
205-
assert np.array_equal(res, np.array([2.5, 3.5, 4.5]))
205+
assert np.array_equal(res, np.array([2.5, 3.5, 4.5])) # type: ignore
206206

207207
def test_vector_sum(self):
208208
with Session(engine) as session:
@@ -211,7 +211,7 @@ def test_vector_sum(self):
211211
session.add(Item(embedding=[1, 2, 3]))
212212
session.add(Item(embedding=[4, 5, 6]))
213213
res = session.exec(select(sum(Item.embedding))).first()
214-
assert np.array_equal(res, np.array([5, 7, 9]))
214+
assert np.array_equal(res, np.array([5, 7, 9])) # type: ignore
215215

216216
def test_halfvec_avg(self):
217217
with Session(engine) as session:

0 commit comments

Comments
 (0)