Skip to content

Commit e4c7359

Browse files
committed
Converted SentenceTransformers example to psycopg [skip ci]
1 parent 105f74a commit e4c7359

File tree

1 file changed

+12
-26
lines changed

1 file changed

+12
-26
lines changed

examples/sentence_embeddings.py

Lines changed: 12 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,14 @@
1-
from pgvector.sqlalchemy import Vector
1+
from pgvector.psycopg import register_vector
2+
import psycopg
23
from sentence_transformers import SentenceTransformer
3-
from sqlalchemy import create_engine, insert, select, text, Integer, String, Text
4-
from sqlalchemy.orm import declarative_base, mapped_column, Session
54

6-
engine = create_engine('postgresql+psycopg://localhost/pgvector_example')
7-
with engine.connect() as conn:
8-
conn.execute(text('CREATE EXTENSION IF NOT EXISTS vector'))
9-
conn.commit()
5+
conn = psycopg.connect(dbname='pgvector_example', autocommit=True)
106

11-
Base = declarative_base()
7+
conn.execute('CREATE EXTENSION IF NOT EXISTS vector')
8+
register_vector(conn)
129

13-
14-
class Document(Base):
15-
__tablename__ = 'document'
16-
17-
id = mapped_column(Integer, primary_key=True)
18-
content = mapped_column(Text)
19-
embedding = mapped_column(Vector(384))
20-
21-
22-
Base.metadata.drop_all(engine)
23-
Base.metadata.create_all(engine)
10+
conn.execute('DROP TABLE IF EXISTS documents')
11+
conn.execute('CREATE TABLE documents (id bigserial PRIMARY KEY, content text, embedding vector(384))')
2412

2513
sentences = [
2614
'The dog is barking',
@@ -31,12 +19,10 @@ class Document(Base):
3119
model = SentenceTransformer('all-MiniLM-L6-v2')
3220
embeddings = model.encode(sentences)
3321

34-
documents = [dict(content=sentences[i], embedding=embedding) for i, embedding in enumerate(embeddings)]
35-
36-
session = Session(engine)
37-
session.execute(insert(Document), documents)
22+
for content, embedding in zip(sentences, embeddings):
23+
conn.execute('INSERT INTO documents (content, embedding) VALUES (%s, %s)', (content, embedding))
3824

39-
doc = session.get(Document, 1)
40-
neighbors = session.scalars(select(Document).filter(Document.id != doc.id).order_by(Document.embedding.cosine_distance(doc.embedding)).limit(5))
25+
document_id = 1
26+
neighbors = conn.execute('SELECT content FROM documents WHERE id != %(id)s ORDER BY embedding <=> (SELECT embedding FROM documents WHERE id = %(id)s) LIMIT 5', {'id': document_id}).fetchall()
4127
for neighbor in neighbors:
42-
print(neighbor.content)
28+
print(neighbor[0])

0 commit comments

Comments
 (0)