Skip to content

Commit d31d279

Browse files
committed
Improved test [skip ci]
1 parent 536c3fb commit d31d279

1 file changed

Lines changed: 10 additions & 2 deletions

File tree

tests/test_asyncpg.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,12 @@
33
from pgvector.asyncpg import register_vector
44
import pytest
55

6+
try:
7+
import numpy as np
8+
NUMPY_AVAILABLE = True
9+
except ImportError:
10+
NUMPY_AVAILABLE = False
11+
612

713
class TestAsyncpg:
814
async def setup_connection(self):
@@ -19,12 +25,14 @@ async def test_vector(self):
1925

2026
embedding = Vector([1.5, 2, 3])
2127
embedding2 = [4.5, 5, 6]
22-
await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), ($2), (NULL)", embedding, embedding2)
28+
embedding3 = np.array([7.5, 8, 9]) if NUMPY_AVAILABLE else [7.5, 8, 9]
29+
await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), ($2), ($3), (NULL)", embedding, embedding2, embedding3)
2330

2431
res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id")
2532
assert res[0]['embedding'] == embedding
2633
assert res[1]['embedding'] == Vector(embedding2)
27-
assert res[2]['embedding'] is None
34+
assert res[2]['embedding'] == Vector(embedding3)
35+
assert res[3]['embedding'] is None
2836

2937
# ensures binary format is correct
3038
text_res = await conn.fetch("SELECT embedding::text FROM asyncpg_items ORDER BY id LIMIT 1")

0 commit comments

Comments
 (0)