Skip to content

Commit 87d0913

Browse files
committed
Improved typechecking for tests [skip ci]
1 parent 68e1a80 commit 87d0913

3 files changed

Lines changed: 18 additions & 18 deletions

File tree

tests/test_asyncpg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ async def test_bit(self):
6464

6565
await register_vector(conn)
6666

67-
embedding = asyncpg.BitString('101')
67+
embedding = asyncpg.BitString('101') # type: ignore
6868
await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), (NULL)", embedding)
6969

7070
res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id")

tests/test_sqlalchemy.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def test_vector(self, engine):
190190
with Session(engine) as session:
191191
session.add(Item(id=1, embedding=[1, 2, 3]))
192192
session.commit()
193-
item = session.get(Item, 1)
193+
item = session.get_one(Item, 1)
194194
assert np.array_equal(item.embedding, [1, 2, 3])
195195

196196
def test_vector_l2_distance(self, engine):
@@ -245,7 +245,7 @@ def test_halfvec(self, engine):
245245
with Session(engine) as session:
246246
session.add(Item(id=1, half_embedding=[1, 2, 3]))
247247
session.commit()
248-
item = session.get(Item, 1)
248+
item = session.get_one(Item, 1)
249249
assert item.half_embedding == HalfVector([1, 2, 3])
250250

251251
def test_halfvec_l2_distance(self, engine):
@@ -300,7 +300,7 @@ def test_bit(self, engine):
300300
with Session(engine) as session:
301301
session.add(Item(id=1, binary_embedding='101'))
302302
session.commit()
303-
item = session.get(Item, 1)
303+
item = session.get_one(Item, 1)
304304
assert item.binary_embedding == '101'
305305

306306
def test_bit_hamming_distance(self, engine):
@@ -337,7 +337,7 @@ def test_sparsevec(self, engine):
337337
with Session(engine) as session:
338338
session.add(Item(id=1, sparse_embedding=[1, 2, 3]))
339339
session.commit()
340-
item = session.get(Item, 1)
340+
item = session.get_one(Item, 1)
341341
assert item.sparse_embedding == SparseVector([1, 2, 3])
342342

343343
def test_sparsevec_l2_distance(self, engine):
@@ -560,7 +560,7 @@ def test_vector_array(self, engine):
560560
session.commit()
561561

562562
# this fails if the driver does not cast arrays
563-
item = session.get(Item, 1)
563+
item = session.get_one(Item, 1)
564564
assert np.array_equal(item.embeddings[0], [1, 2, 3])
565565
assert np.array_equal(item.embeddings[1], [4, 5, 6])
566566

@@ -570,7 +570,7 @@ def test_halfvec_array(self, engine):
570570
session.commit()
571571

572572
# this fails if the driver does not cast arrays
573-
item = session.get(Item, 1)
573+
item = session.get_one(Item, 1)
574574
assert item.half_embeddings == [HalfVector([1, 2, 3]), HalfVector([4, 5, 6])]
575575

576576

@@ -587,7 +587,7 @@ async def test_vector(self, engine):
587587
async with session.begin():
588588
embedding = np.array([1, 2, 3])
589589
session.add(Item(id=1, embedding=embedding))
590-
item = await session.get(Item, 1)
590+
item = await session.get_one(Item, 1)
591591
assert np.array_equal(item.embedding, embedding)
592592

593593
await engine.dispose()
@@ -600,7 +600,7 @@ async def test_halfvec(self, engine):
600600
async with session.begin():
601601
embedding = [1, 2, 3]
602602
session.add(Item(id=1, half_embedding=embedding))
603-
item = await session.get(Item, 1)
603+
item = await session.get_one(Item, 1)
604604
assert item.half_embedding == HalfVector(embedding)
605605

606606
await engine.dispose()
@@ -613,12 +613,12 @@ async def test_bit(self, engine):
613613
async with session.begin():
614614
embedding = asyncpg.BitString('101') if engine == asyncpg_engine else '101'
615615
session.add(Item(id=1, binary_embedding=embedding))
616-
item = await session.get(Item, 1)
616+
item = await session.get_one(Item, 1)
617617
assert item.binary_embedding == embedding
618618

619619
if engine == asyncpg_engine:
620620
session.add(Item(id=2, binary_embedding='101'))
621-
item = await session.get(Item, 2)
621+
item = await session.get_one(Item, 2)
622622
assert item.binary_embedding == embedding
623623

624624
await engine.dispose()
@@ -631,7 +631,7 @@ async def test_sparsevec(self, engine):
631631
async with session.begin():
632632
embedding = [1, 2, 3]
633633
session.add(Item(id=1, sparse_embedding=embedding))
634-
item = await session.get(Item, 1)
634+
item = await session.get_one(Item, 1)
635635
assert item.sparse_embedding == SparseVector(embedding)
636636

637637
await engine.dispose()
@@ -662,12 +662,12 @@ async def test_vector_array(self, engine):
662662
async with async_session() as session:
663663
async with session.begin():
664664
session.add(Item(id=1, embeddings=[Vector([1, 2, 3]), Vector([4, 5, 6])]))
665-
item = await session.get(Item, 1)
665+
item = await session.get_one(Item, 1)
666666
assert np.array_equal(item.embeddings[0], [1, 2, 3])
667667
assert np.array_equal(item.embeddings[1], [4, 5, 6])
668668

669669
session.add(Item(id=2, embeddings=[np.array([1, 2, 3]), np.array([4, 5, 6])]))
670-
item = await session.get(Item, 2)
670+
item = await session.get_one(Item, 2)
671671
assert np.array_equal(item.embeddings[0], [1, 2, 3])
672672
assert np.array_equal(item.embeddings[1], [4, 5, 6])
673673

tests/test_sqlmodel.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def test_vector(self):
7575
with Session(engine) as session:
7676
session.add(Item(id=1, embedding=[1, 2, 3]))
7777
session.commit()
78-
item = session.get(Item, 1)
78+
item = session.get_one(Item, 1)
7979
assert np.array_equal(item.embedding, np.array([1, 2, 3]))
8080

8181
def test_vector_l2_distance(self):
@@ -106,7 +106,7 @@ def test_halfvec(self):
106106
with Session(engine) as session:
107107
session.add(Item(id=1, half_embedding=[1, 2, 3]))
108108
session.commit()
109-
item = session.get(Item, 1)
109+
item = session.get_one(Item, 1)
110110
assert item.half_embedding == HalfVector([1, 2, 3])
111111

112112
def test_halfvec_l2_distance(self):
@@ -137,7 +137,7 @@ def test_bit(self):
137137
with Session(engine) as session:
138138
session.add(Item(id=1, binary_embedding='101'))
139139
session.commit()
140-
item = session.get(Item, 1)
140+
item = session.get_one(Item, 1)
141141
assert item.binary_embedding == '101'
142142

143143
def test_bit_hamming_distance(self):
@@ -156,7 +156,7 @@ def test_sparsevec(self):
156156
with Session(engine) as session:
157157
session.add(Item(id=1, sparse_embedding=[1, 2, 3]))
158158
session.commit()
159-
item = session.get(Item, 1)
159+
item = session.get_one(Item, 1)
160160
assert item.sparse_embedding == SparseVector([1, 2, 3])
161161

162162
def test_sparsevec_l2_distance(self):

0 commit comments

Comments
 (0)