1- from pgvector .sqlalchemy import Vector
1+ from pgvector .psycopg import register_vector
2+ import psycopg
23from sentence_transformers import SentenceTransformer
3- from sqlalchemy import create_engine , insert , select , text , Integer , String , Text
4- from sqlalchemy .orm import declarative_base , mapped_column , Session
54
6- engine = create_engine ('postgresql+psycopg://localhost/pgvector_example' )
7- with engine .connect () as conn :
8- conn .execute (text ('CREATE EXTENSION IF NOT EXISTS vector' ))
9- conn .commit ()
5+ conn = psycopg .connect (dbname = 'pgvector_example' , autocommit = True )
106
11- Base = declarative_base ()
7+ conn .execute ('CREATE EXTENSION IF NOT EXISTS vector' )
8+ register_vector (conn )
129
13-
14- class Document (Base ):
15- __tablename__ = 'document'
16-
17- id = mapped_column (Integer , primary_key = True )
18- content = mapped_column (Text )
19- embedding = mapped_column (Vector (384 ))
20-
21-
22- Base .metadata .drop_all (engine )
23- Base .metadata .create_all (engine )
10+ conn .execute ('DROP TABLE IF EXISTS documents' )
11+ conn .execute ('CREATE TABLE documents (id bigserial PRIMARY KEY, content text, embedding vector(384))' )
2412
2513sentences = [
2614 'The dog is barking' ,
@@ -31,12 +19,10 @@ class Document(Base):
3119model = SentenceTransformer ('all-MiniLM-L6-v2' )
3220embeddings = model .encode (sentences )
3321
34- documents = [dict (content = sentences [i ], embedding = embedding ) for i , embedding in enumerate (embeddings )]
35-
36- session = Session (engine )
37- session .execute (insert (Document ), documents )
22+ for content , embedding in zip (sentences , embeddings ):
23+ conn .execute ('INSERT INTO documents (content, embedding) VALUES (%s, %s)' , (content , embedding ))
3824
39- doc = session . get ( Document , 1 )
40- neighbors = session . scalars ( select ( Document ). filter ( Document . id != doc . id ). order_by ( Document . embedding . cosine_distance ( doc . embedding )). limit ( 5 ) )
25+ document_id = 1
26+ neighbors = conn . execute ( 'SELECT content FROM documents WHERE id != %(id)s ORDER BY embedding <=> (SELECT embedding FROM documents WHERE id = %(id)s) LIMIT 5' , { 'id' : document_id }). fetchall ( )
4127for neighbor in neighbors :
42- print (neighbor . content )
28+ print (neighbor [ 0 ] )
0 commit comments