@@ -26,7 +26,8 @@ async def test_vector(self):
2626 embedding = Vector ([1.5 , 2 , 3 ])
2727 embedding2 = [4.5 , 5 , 6 ]
2828 embedding3 = np .array ([7.5 , 8 , 9 ]) if NUMPY_AVAILABLE else [7.5 , 8 , 9 ]
29- await conn .execute ("INSERT INTO asyncpg_items (embedding) VALUES ($1), ($2), ($3), (NULL)" , embedding , embedding2 , embedding3 )
29+ embedding4 = None
30+ await conn .execute ("INSERT INTO asyncpg_items (embedding) VALUES ($1), ($2), ($3), ($4)" , embedding , embedding2 , embedding3 , embedding4 )
3031
3132 res = await conn .fetch ("SELECT * FROM asyncpg_items ORDER BY id" )
3233 assert res [0 ]['embedding' ] == embedding
@@ -47,7 +48,8 @@ async def test_halfvec(self):
4748
4849 embedding = HalfVector ([1.5 , 2 , 3 ])
4950 embedding2 = [4.5 , 5 , 6 ]
50- await conn .execute ("INSERT INTO asyncpg_items (embedding) VALUES ($1), ($2), (NULL)" , embedding , embedding2 )
51+ embedding3 = None
52+ await conn .execute ("INSERT INTO asyncpg_items (embedding) VALUES ($1), ($2), ($3)" , embedding , embedding2 , embedding3 )
5153
5254 res = await conn .fetch ("SELECT * FROM asyncpg_items ORDER BY id" )
5355 assert res [0 ]['embedding' ] == embedding
@@ -66,7 +68,8 @@ async def test_bit(self):
6668 await conn .execute ('CREATE TABLE asyncpg_items (id bigserial PRIMARY KEY, embedding bit(3))' )
6769
6870 embedding = asyncpg .BitString ('101' ) # type: ignore
69- await conn .execute ("INSERT INTO asyncpg_items (embedding) VALUES ($1), (NULL)" , embedding )
71+ embedding2 = None
72+ await conn .execute ("INSERT INTO asyncpg_items (embedding) VALUES ($1), ($2)" , embedding , embedding2 )
7073
7174 res = await conn .fetch ("SELECT * FROM asyncpg_items ORDER BY id" )
7275 assert res [0 ]['embedding' ].as_string () == '101'
@@ -85,7 +88,8 @@ async def test_sparsevec(self):
8588 await conn .execute ('CREATE TABLE asyncpg_items (id bigserial PRIMARY KEY, embedding sparsevec(3))' )
8689
8790 embedding = SparseVector ([1.5 , 2 , 3 ])
88- await conn .execute ("INSERT INTO asyncpg_items (embedding) VALUES ($1), (NULL)" , embedding )
91+ embedding2 = None
92+ await conn .execute ("INSERT INTO asyncpg_items (embedding) VALUES ($1), ($2)" , embedding , embedding2 )
8993
9094 res = await conn .fetch ("SELECT * FROM asyncpg_items ORDER BY id" )
9195 assert res [0 ]['embedding' ] == embedding
@@ -108,9 +112,15 @@ async def test_vector_array(self):
108112 embeddings2 = [[1.5 , 2 , 3 ], [4.5 , 5 , 6 ]]
109113 await conn .execute ("INSERT INTO asyncpg_items (embeddings) VALUES (ARRAY[$1, $2]::vector[])" , embeddings2 [0 ], embeddings2 [1 ])
110114
115+ if NUMPY_AVAILABLE :
116+ embeddings3 = [np .array ([1.5 , 2 , 3 ]), np .array ([4.5 , 5 , 6 ])]
117+ await conn .execute ("INSERT INTO asyncpg_items (embeddings) VALUES (ARRAY[$1, $2]::vector[])" , embeddings3 [0 ], embeddings3 [1 ])
118+
111119 res = await conn .fetch ("SELECT * FROM asyncpg_items ORDER BY id" )
112120 assert res [0 ]['embeddings' ] == embeddings
113121 assert res [1 ]['embeddings' ] == [Vector (e ) for e in embeddings2 ]
122+ if NUMPY_AVAILABLE :
123+ assert res [2 ]['embeddings' ] == [Vector (e ) for e in embeddings3 ]
114124
115125 await conn .close ()
116126
@@ -128,7 +138,8 @@ async def init(conn):
128138
129139 embedding = Vector ([1.5 , 2 , 3 ])
130140 embedding2 = [1.5 , 2 , 3 ]
131- await conn .execute ("INSERT INTO asyncpg_items (embedding) VALUES ($1), ($2), (NULL)" , embedding , embedding2 )
141+ embedding3 = None
142+ await conn .execute ("INSERT INTO asyncpg_items (embedding) VALUES ($1), ($2), ($3)" , embedding , embedding2 , embedding3 )
132143
133144 res = await conn .fetch ("SELECT * FROM asyncpg_items ORDER BY id" )
134145 assert res [0 ]['embedding' ] == embedding
0 commit comments