Skip to content

Commit 907f678

Browse files
committed
Added literal binds support for SQLAlchemy - closes #51
1 parent e9232f9 commit 907f678

File tree

3 files changed

+16
-1
lines changed

3 files changed

+16
-1
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
## 0.2.5 (unreleased)
2+
3+
- Added literal binds support for SQLAlchemy
4+
15
## 0.2.4 (2023-11-24)
26

37
- Improved reflection with SQLAlchemy

pgvector/sqlalchemy/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from sqlalchemy.dialects.postgresql.base import ischema_names
2-
from sqlalchemy.types import UserDefinedType, Float
2+
from sqlalchemy.types import UserDefinedType, Float, String
33
from ..utils import from_db, to_db
44

55
__all__ = ['Vector']
66

77

88
class Vector(UserDefinedType):
99
cache_ok = True
10+
_string = String()
1011

1112
def __init__(self, dim=None):
1213
super(UserDefinedType, self).__init__()
@@ -22,6 +23,12 @@ def process(value):
2223
return to_db(value, self.dim)
2324
return process
2425

26+
def literal_processor(self, dialect):
27+
string_literal_processor = self._string._cached_literal_processor(dialect)
28+
def process(value):
29+
return string_literal_processor(to_db(value, self.dim))
30+
return process
31+
2532
def result_processor(self, dialect, coltype):
2633
def process(value):
2734
return from_db(value)

tests/test_sqlalchemy.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,10 @@ def test_inspect(self):
219219
columns = inspect(engine).get_columns('orm_item')
220220
assert isinstance(columns[1]['type'], Vector)
221221

222+
def test_literal_binds(self):
223+
sql = select(Item).order_by(Item.embedding.l2_distance([1, 2, 3])).compile(compile_kwargs={'literal_binds': True})
224+
assert "embedding <-> '[1.0,2.0,3.0]'" in str(sql)
225+
222226
@pytest.mark.asyncio
223227
async def test_async(self):
224228
engine = create_async_engine('postgresql+psycopg://localhost/pgvector_python_test')

0 commit comments

Comments
 (0)