Skip to content

Commit 061de10

Browse files
committed
Improved tests [skip ci]
1 parent d31d279 commit 061de10

1 file changed

Lines changed: 16 additions & 5 deletions

File tree

tests/test_asyncpg.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ async def test_vector(self):
2626
embedding = Vector([1.5, 2, 3])
2727
embedding2 = [4.5, 5, 6]
2828
embedding3 = np.array([7.5, 8, 9]) if NUMPY_AVAILABLE else [7.5, 8, 9]
29-
await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), ($2), ($3), (NULL)", embedding, embedding2, embedding3)
29+
embedding4 = None
30+
await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), ($2), ($3), ($4)", embedding, embedding2, embedding3, embedding4)
3031

3132
res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id")
3233
assert res[0]['embedding'] == embedding
@@ -47,7 +48,8 @@ async def test_halfvec(self):
4748

4849
embedding = HalfVector([1.5, 2, 3])
4950
embedding2 = [4.5, 5, 6]
50-
await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), ($2), (NULL)", embedding, embedding2)
51+
embedding3 = None
52+
await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), ($2), ($3)", embedding, embedding2, embedding3)
5153

5254
res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id")
5355
assert res[0]['embedding'] == embedding
@@ -66,7 +68,8 @@ async def test_bit(self):
6668
await conn.execute('CREATE TABLE asyncpg_items (id bigserial PRIMARY KEY, embedding bit(3))')
6769

6870
embedding = asyncpg.BitString('101') # type: ignore
69-
await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), (NULL)", embedding)
71+
embedding2 = None
72+
await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), ($2)", embedding, embedding2)
7073

7174
res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id")
7275
assert res[0]['embedding'].as_string() == '101'
@@ -85,7 +88,8 @@ async def test_sparsevec(self):
8588
await conn.execute('CREATE TABLE asyncpg_items (id bigserial PRIMARY KEY, embedding sparsevec(3))')
8689

8790
embedding = SparseVector([1.5, 2, 3])
88-
await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), (NULL)", embedding)
91+
embedding2 = None
92+
await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), ($2)", embedding, embedding2)
8993

9094
res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id")
9195
assert res[0]['embedding'] == embedding
@@ -108,9 +112,15 @@ async def test_vector_array(self):
108112
embeddings2 = [[1.5, 2, 3], [4.5, 5, 6]]
109113
await conn.execute("INSERT INTO asyncpg_items (embeddings) VALUES (ARRAY[$1, $2]::vector[])", embeddings2[0], embeddings2[1])
110114

115+
if NUMPY_AVAILABLE:
116+
embeddings3 = [np.array([1.5, 2, 3]), np.array([4.5, 5, 6])]
117+
await conn.execute("INSERT INTO asyncpg_items (embeddings) VALUES (ARRAY[$1, $2]::vector[])", embeddings3[0], embeddings3[1])
118+
111119
res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id")
112120
assert res[0]['embeddings'] == embeddings
113121
assert res[1]['embeddings'] == [Vector(e) for e in embeddings2]
122+
if NUMPY_AVAILABLE:
123+
assert res[2]['embeddings'] == [Vector(e) for e in embeddings3]
114124

115125
await conn.close()
116126

@@ -128,7 +138,8 @@ async def init(conn):
128138

129139
embedding = Vector([1.5, 2, 3])
130140
embedding2 = [1.5, 2, 3]
131-
await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), ($2), (NULL)", embedding, embedding2)
141+
embedding3 = None
142+
await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), ($2), ($3)", embedding, embedding2, embedding3)
132143

133144
res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id")
134145
assert res[0]['embedding'] == embedding

0 commit comments

Comments
 (0)