Skip to content

Commit f083137

Browse files
committed
Improved typechecking for tests [skip ci]
1 parent 87d0913 commit f083137

1 file changed

Lines changed: 8 additions & 8 deletions

File tree

tests/test_sqlalchemy.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -428,11 +428,11 @@ def test_select_orm(self, engine):
428428

429429
def test_avg(self, engine):
430430
with Session(engine) as session:
431-
res = session.query(avg(Item.embedding)).first()[0]
431+
res = session.query(avg(Item.embedding)).one()[0]
432432
assert res is None
433433
session.add(Item(embedding=[1, 2, 3]))
434434
session.add(Item(embedding=[4, 5, 6]))
435-
res = session.query(avg(Item.embedding)).first()[0]
435+
res = session.query(avg(Item.embedding)).one()[0]
436436
assert np.array_equal(res, np.array([2.5, 3.5, 4.5]))
437437

438438
def test_avg_orm(self, engine):
@@ -441,16 +441,16 @@ def test_avg_orm(self, engine):
441441
assert res is None
442442
session.add(Item(embedding=[1, 2, 3]))
443443
session.add(Item(embedding=[4, 5, 6]))
444-
res = session.scalars(select(avg(Item.embedding))).first()
444+
res = session.scalars(select(avg(Item.embedding))).one()
445445
assert np.array_equal(res, np.array([2.5, 3.5, 4.5]))
446446

447447
def test_sum(self, engine):
448448
with Session(engine) as session:
449-
res = session.query(sum(Item.embedding)).first()[0]
449+
res = session.query(sum(Item.embedding)).one()[0]
450450
assert res is None
451451
session.add(Item(embedding=[1, 2, 3]))
452452
session.add(Item(embedding=[4, 5, 6]))
453-
res = session.query(sum(Item.embedding)).first()[0]
453+
res = session.query(sum(Item.embedding)).one()[0]
454454
assert np.array_equal(res, np.array([5, 7, 9]))
455455

456456
def test_sum_orm(self, engine):
@@ -459,7 +459,7 @@ def test_sum_orm(self, engine):
459459
assert res is None
460460
session.add(Item(embedding=[1, 2, 3]))
461461
session.add(Item(embedding=[4, 5, 6]))
462-
res = session.scalars(select(sum(Item.embedding))).first()
462+
res = session.scalars(select(sum(Item.embedding))).one()
463463
assert np.array_equal(res, np.array([5, 7, 9]))
464464

465465
def test_bad_dimensions(self, engine):
@@ -611,7 +611,7 @@ async def test_bit(self, engine):
611611

612612
async with async_session() as session:
613613
async with session.begin():
614-
embedding = asyncpg.BitString('101') if engine == asyncpg_engine else '101'
614+
embedding = asyncpg.BitString('101') if engine == asyncpg_engine else '101' # type: ignore
615615
session.add(Item(id=1, binary_embedding=embedding))
616616
item = await session.get_one(Item, 1)
617617
assert item.binary_embedding == embedding
@@ -645,7 +645,7 @@ async def test_avg(self, engine):
645645
session.add(Item(embedding=[1, 2, 3]))
646646
session.add(Item(embedding=[4, 5, 6]))
647647
res = await session.scalars(select(avg(Item.embedding)))
648-
assert np.array_equal(res.first(), [2.5, 3.5, 4.5])
648+
assert np.array_equal(res.one(), [2.5, 3.5, 4.5])
649649

650650
await engine.dispose()
651651

0 commit comments

Comments
 (0)