diff --git a/example/src/App.tsx b/example/src/App.tsx index fe2f05c..a16b18a 100644 --- a/example/src/App.tsx +++ b/example/src/App.tsx @@ -1,4 +1,5 @@ -import { type Message, MemoryVectorStore, useRAG } from 'react-native-rag'; +import { type Message, useRAG } from 'react-native-rag'; +import { OPSQLiteVectorStore } from '@react-native-rag/op-sqlite'; import { QWEN3_0_6B_QUANTIZED, ALL_MINILM_L6_V2, @@ -32,7 +33,8 @@ export default function App() { const [messages, setMessages] = useState([]); const vectorStore = useMemo(() => { - return new MemoryVectorStore({ + return new OPSQLiteVectorStore({ + name: 'rag_example_db1', embeddings: new ExecuTorchEmbeddings(ALL_MINILM_L6_V2), }); }, []); @@ -55,13 +57,11 @@ export default function App() { try { if (ids.length) { for (const id of ids) { - await rag.deleteDocument({ - ids: [id], - }); + await rag.deleteDocument({ predicate: (value) => value.id === id }); } setIds([]); } - const newIds = await rag.splitAddDocument(document); + const newIds = await rag.splitAddDocument({ document }); setIds(newIds); console.log('Document splitted and added with IDs:', newIds); setModalVisible(false); diff --git a/package.json b/package.json index ccae342..24016a5 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "react-native-rag", - "version": "0.2.0", + "version": "0.2.0-rc3", "description": "Private, local RAGs. Supercharge LLMs with your own knowledge base.", "main": "./lib/module/index.js", "types": "./lib/typescript/src/index.d.ts", diff --git a/packages/executorch/package.json b/packages/executorch/package.json index fd9a6fd..7473571 100644 --- a/packages/executorch/package.json +++ b/packages/executorch/package.json @@ -1,6 +1,6 @@ { "name": "@react-native-rag/executorch", - "version": "0.2.0", + "version": "0.2.0-rc3", "main": "src/index.ts", "scripts": { "test": "echo \"Error: no test specified\" && exit 1" diff --git a/packages/executorch/src/wrappers/embeddings.ts b/packages/executorch/src/wrappers/embeddings.ts index 654ad9b..ec0aaa3 100644 --- a/packages/executorch/src/wrappers/embeddings.ts +++ b/packages/executorch/src/wrappers/embeddings.ts @@ -9,7 +9,7 @@ interface ExecuTorchEmbeddingsParams { modelSource: ResourceSource; /** Source of the tokenizer model. */ tokenizerSource: ResourceSource; - /** Optional download progress callback (0-1). */ + /** Download progress callback (0-1). */ onDownloadProgress?: (progress: number) => void; } @@ -29,7 +29,7 @@ export class ExecuTorchEmbeddings implements Embeddings { * @param params - Parameters for the instance. * @param params.modelSource - Source of the embedding model. * @param params.tokenizerSource - Source of the tokenizer. - * @param params.onDownloadProgress - Optional download progress callback (0-1). + * @param params.onDownloadProgress - Download progress callback (0-1). */ constructor({ modelSource, @@ -61,16 +61,16 @@ export class ExecuTorchEmbeddings implements Embeddings { } /** - * Unloads the underlying module. Note: unload is synchronous in ExecuTorch - * at the time of writing; this method resolves immediately after calling delete. + * Unloads the underlying module. + * Note: current ExecuTorch unload is synchronous. * Awaiting this method will not guarantee completion. - * @returns Promise that resolves when unloading is initiated. */ async unload() { console.warn( 'This function will call a synchronous unload on the instance of TextEmbeddingsModule from React Native ExecuTorch. Awaiting this method will not guarantee completion. This may change in future versions to support async unload.' ); this.module.delete(); + this.isLoaded = false; } /** diff --git a/packages/executorch/src/wrappers/llms.ts b/packages/executorch/src/wrappers/llms.ts index e9dec4f..5607b00 100644 --- a/packages/executorch/src/wrappers/llms.ts +++ b/packages/executorch/src/wrappers/llms.ts @@ -12,14 +12,14 @@ interface ExecuTorchLLMParams { /** Source of the tokenizer config. */ tokenizerConfigSource: ResourceSource; - /** Optional download progress callback (0-1). */ + /** Download progress callback (0-1). */ onDownloadProgress?: (progress: number) => void; /** Callback invoked with final full response string. */ responseCallback?: (response: string) => void; /** Reserved: callback for message history changes (not wired currently). */ messageHistoryCallback?: (messageHistory: Message[]) => void; - /** Optional chat configuration forwarded to ExecuTorch. */ + /** Chat configuration forwarded to ExecuTorch. */ chatConfig?: Partial; } @@ -43,9 +43,9 @@ export class ExecuTorchLLM implements LLM { * @param params.modelSource - Source of the LLM model. * @param params.tokenizerSource - Source of the tokenizer. * @param params.tokenizerConfigSource - Source of the tokenizer config. - * @param params.onDownloadProgress - Optional download progress callback (0-1). + * @param params.onDownloadProgress - Download progress callback (0-1). * @param params.responseCallback - Callback invoked with final full response string. - * @param params.chatConfig - Optional chat configuration forwarded to ExecuTorch. + * @param params.chatConfig - Chat configuration forwarded to ExecuTorch. */ constructor({ modelSource, @@ -88,10 +88,9 @@ export class ExecuTorchLLM implements LLM { } /** - * Interrupts current generation. Note: interrupt is synchronous in ExecuTorch - * at the time of writing; this method resolves immediately after calling interrupt. + * Interrupts current generation. + * Note: current ExecuTorch interrupt is synchronous. * Awaiting this method will not guarantee completion. - * @returns Promise that resolves when interrupt is initiated. */ async interrupt() { console.warn( @@ -101,15 +100,16 @@ export class ExecuTorchLLM implements LLM { } /** - * Unloads the underlying module. Note: unload is synchronous in ExecuTorch. + * Unloads the underlying module. + * Note: current ExecuTorch unload is synchronous. * Awaiting this method will not guarantee completion. - * @returns Promise that resolves when unload is initiated. */ async unload() { console.warn( 'This function will call a synchronous unload on the instance of LLMModule from React Native ExecuTorch. Awaiting this method will not guarantee completion. This may change in future versions to support async unload.' ); this.module.delete(); + this.isLoaded = false; } /** diff --git a/packages/op-sqlite/package.json b/packages/op-sqlite/package.json index 18b0fc3..de56590 100644 --- a/packages/op-sqlite/package.json +++ b/packages/op-sqlite/package.json @@ -1,6 +1,6 @@ { "name": "@react-native-rag/op-sqlite", - "version": "0.2.0", + "version": "0.2.0-rc3", "main": "src/index.ts", "scripts": { "test": "echo \"Error: no test specified\" && exit 1" diff --git a/packages/op-sqlite/src/wrappers/op-sqlite.ts b/packages/op-sqlite/src/wrappers/op-sqlite.ts index 92ba2e5..de31722 100644 --- a/packages/op-sqlite/src/wrappers/op-sqlite.ts +++ b/packages/op-sqlite/src/wrappers/op-sqlite.ts @@ -15,8 +15,8 @@ import { open, type DB } from '@op-engineering/op-sqlite'; * * @example * const store = await new OPSQLiteVectorStore({ name: 'vector-db', embeddings }).load(); - * await store.add({ documents: ['hello world'] }); - * const [[top]] = await store.query({ queryTexts: ['hello'] }); + * await store.add({ document: 'hello world' }); + * const [top] = await store.query({ queryText: 'hello', nResults: 1 }); * console.log(top.id, top.similarity); */ export class OPSQLiteVectorStore implements VectorStore { @@ -67,309 +67,167 @@ export class OPSQLiteVectorStore implements VectorStore { } /** - * Inserts documents with embeddings. Generates IDs when not provided. + * Inserts a document with an embedding. Generates an ID when not provided. * @param params - Parameters for the operation. - * @param params.ids - Optional IDs for each document (must match `documents.length`). If not provided, IDs will be generated. - * @param params.documents - Raw text content for each document. - * @param params.embeddings - Optional embeddings for each document. - * @param params.metadatas - Optional metadata for each document (aligned by index). - * @returns Promise that resolves to the IDs of the newly added documents. + * @param params.id - ID for the document. If not provided, it will be auto-generated. + * @param params.document - Raw text content for the document. + * @param params.embedding - Embedding for the document. If not provided, it will be generated based on the `document`. + * @param params.metadata - Metadata for the document. + * @returns Promise that resolves to the ID of the newly added document. */ public async add(params: { - ids?: string[]; - documents: string[]; - embeddings?: number[][]; - metadatas?: Record[]; - }): Promise { - const { embeddings, documents, metadatas } = params; - const ids = params.ids ?? documents.map(() => uuidv4()); - - const idsLength = ids.length; - this.assertLengthMatchIds(embeddings, idsLength); - this.assertLengthMatchIds(documents, idsLength); - this.assertLengthMatchIds(metadatas, idsLength); + id?: string; + document?: string; + embedding?: number[]; + metadata?: Record; + }): Promise { + const { id = uuidv4(), document, embedding, metadata } = params; + + if (!document && !embedding) { + throw new Error('document and embedding cannot be both undefined'); + } - for (const id of ids) { - const existing = await this.db.execute( - 'SELECT 1 FROM vectors WHERE id = ? LIMIT 1', - [id] + if (embedding && embedding.length !== this.embeddingDim) { + throw new Error( + `embedding dimension ${embedding.length} does not match collection embedding dimension ${this.embeddingDim}` ); - if (existing.rows.length > 0) { - throw new Error(`id already exists: ${id}`); - } } - if (embeddings) { - for (const emb of embeddings) { - this.assertEmbeddingDim(emb); - } + if ( + ( + await this.db.execute('SELECT 1 FROM vectors WHERE id = ? LIMIT 1', [ + id, + ]) + ).rows.length > 0 + ) { + throw new Error(`id already exists: ${id}`); } - for (let i = 0; i < idsLength; i++) { - const meta = metadatas?.[i] ? JSON.stringify(metadatas[i]) : null; - await this.db.execute( - 'INSERT INTO vectors(id, document, embedding, metadata) VALUES (?, ?, vector(?), ?)', - [ - ids[i]!, - documents[i]!, - this.arrayToScalar( - embeddings - ? embeddings[i]! - : await this.embeddings.embed(documents[i]!) - ), - meta, - ] - ); - } + await this.db.execute( + 'INSERT INTO vectors(id, document, embedding, metadata) VALUES (?, ?, vector(?), ?)', + [ + id, + document ?? '', + `[${(embedding ?? (await this.embeddings.embed(document!))).join(',')}]`, + metadata ? JSON.stringify(metadata) : null, + ] + ); - return ids; + return id; } /** - * Updates documents by ID. If `documents` are provided and `embeddings` are not, - * new embeddings are computed. + * Updates a document by ID. * @param params - Parameters for the update. - * @param params.ids - IDs of the documents to update. - * @param params.embeddings - New embeddings (optional; aligned by index if provided). - * @param params.documents - New content (optional; aligned by index if provided). - * @param params.metadatas - New metadata (optional; aligned by index if provided). + * @param params.id - ID of the document to update. + * @param params.document - New content for the document. + * @param params.embedding - New embeddings for the document. If not provided, it will be generated based on the `document`. + * @param params.metadata - New metadata for the document. * @returns Promise that resolves when the update completes. */ public async update(params: { - ids: string[]; - embeddings?: number[][]; - documents?: string[]; - metadatas?: Record[]; + id: string; + embedding?: number[]; + document?: string; + metadata?: Record; }): Promise { - const { ids, embeddings, documents, metadatas } = params; - - const idsLength = ids.length; - this.assertLengthMatchIds(embeddings, idsLength); - this.assertLengthMatchIds(documents, idsLength); - this.assertLengthMatchIds(metadatas, idsLength); + const { id, document, embedding, metadata } = params; - for (const id of ids) { - const existing = await this.db.execute( - 'SELECT 1 FROM vectors WHERE id = ? LIMIT 1', - [id] + if (embedding && embedding.length !== this.embeddingDim) { + throw new Error( + `embedding dimension ${embedding.length} does not match collection embedding dimension ${this.embeddingDim}` ); - if (existing.rows.length === 0) { - throw new Error(`id not found: ${id}`); - } } - if (embeddings) { - for (const emb of embeddings) { - this.assertEmbeddingDim(emb); - } - } - - for (let i = 0; i < idsLength; i++) { - const id = ids[i]!; - const row = await this.db.execute( - 'SELECT document, embedding, metadata FROM vectors WHERE id = ?', - [id] - ); - await this.db.execute( - ` - UPDATE vectors - SET document = ?, - embedding = vector(?), - metadata = ? - WHERE id = ? - `, - [ - documents ? documents[i]! : row.rows[0]!.document!, - embeddings - ? this.arrayToScalar(embeddings[i]!) - : documents - ? this.arrayToScalar(await this.embeddings.embed(documents[i]!)) - : row.rows[0]!.embedding!, - metadatas ? JSON.stringify(metadatas[i]) : row.rows[0]!.metadata!, + if ( + ( + await this.db.execute('SELECT 1 FROM vectors WHERE id = ? LIMIT 1', [ id, - ] - ); + ]) + ).rows.length === 0 + ) { + throw new Error(`id not found: ${id}`); } + + await this.db.execute( + 'UPDATE vectors SET document = ?, embedding = vector(?), metadata = ? WHERE id = ?', + [ + document ?? '', + `[${(embedding ?? (await this.embeddings.embed(document!))).join(',')}]`, + metadata ? JSON.stringify(metadata) : null, + id, + ] + ); } /** - * Deletes documents by IDs and/or predicate. + * Deletes documents by predicate. * @param params - Parameters for deletion. - * @param params.ids - List of document IDs to delete. * @param params.predicate - Predicate to match documents for deletion. - * @returns Promise that resolves when deletion completes. + * @returns Promise that resolves once the documents are deleted. */ public async delete(params: { - ids?: string[]; - predicate?: (value: GetResult) => boolean; + predicate: (value: GetResult) => boolean; }): Promise { - const { ids, predicate } = params; - - if (ids && predicate) { - for (const id of ids) { - const existing = await this.db.execute( - 'SELECT 1 FROM vectors WHERE id = ? LIMIT 1', - [id] - ); - if (existing.rows.length === 0) { - throw new Error(`id not found: ${id}`); - } - } - - const existingRows = await this.getRowsByIds(ids); - const toDelete = existingRows.filter(predicate).map((r) => r.id); - if (toDelete.length > 0) { - await this.db.execute( - `DELETE FROM vectors WHERE id IN (${toDelete.map(() => '?').join(',')})`, - toDelete - ); - } - } else if (ids) { - for (const id of ids) { - const existing = await this.db.execute( - 'SELECT 1 FROM vectors WHERE id = ? LIMIT 1', - [id] - ); - if (existing.rows.length === 0) { - throw new Error(`id not found: ${id}`); - } - } + const { predicate } = params; + for (const row of ( await this.db.execute( - `DELETE FROM vectors WHERE id IN (${ids.map(() => '?').join(',')})`, - ids - ); - } else if (predicate) { - const allRows = await this.db.execute( 'SELECT id, document, embedding, metadata FROM vectors' - ); - const toDelete: string[] = []; - for (const row of allRows.rows) { - const getRes = this.rowToGetResult(row); - if (predicate(getRes)) { - toDelete.push(getRes.id); - } - } - - if (toDelete.length > 0) { - await this.db.execute( - `DELETE FROM vectors WHERE id IN (${toDelete.map(() => '?').join(',')})`, - toDelete - ); + ) + ).rows) { + const getResult = this.rowToGetResult(row); + if (predicate(getResult)) { + await this.db.execute('DELETE FROM vectors WHERE id = ?', [ + getResult.id, + ]); } } } /** - * Executes a cosine-similarity query using SQLite vector functions. - * Provide exactly one of `queryTexts` or `queryEmbeddings`. + * Executes a cosine-similarity query over stored vectors. + * Provide exactly one of `queryText` or `queryEmbedding`. * @param params - Query parameters. - * @param params.queryTexts - Raw query strings to search for. - * @param params.queryEmbeddings - Precomputed query embeddings. + * @param params.queryText - Raw query string to search for. + * @param params.queryEmbedding - Precomputed query embedding. * @param params.nResults - Number of top results to return. - * @param params.ids - Restrict the search to these document IDs. * @param params.predicate - Function to filter results after retrieval. - * @returns Promise resolving to arrays of scored results for each query. + * @returns Promise that resolves to an array of {@link QueryResult}. */ public async query(params: { - queryTexts?: string[]; - queryEmbeddings?: number[][]; + queryText?: string; + queryEmbedding?: number[]; nResults?: number; - ids?: string[]; predicate?: (value: QueryResult) => boolean; - }): Promise { - const { - queryTexts, - queryEmbeddings, - nResults, - ids, - predicate = () => true, - } = params; - if (!queryTexts === !queryEmbeddings) { - throw new Error( - 'Exactly one of queryTexts or queryEmbeddings must be provided' - ); - } + }): Promise { + const { queryText, queryEmbedding, nResults, predicate } = params; - if (ids) { - for (const id of ids) { - const existing = await this.db.execute( - 'SELECT 1 FROM vectors WHERE id = ? LIMIT 1', - [id] - ); - if (existing.rows.length === 0) { - throw new Error(`id not found: ${id}`); - } - } + if (!queryText && !queryEmbedding) { + throw new Error('queryText and queryEmbedding cannot be both undefined'); } - const queries: number[][] = []; - - if (queryEmbeddings) { - for (const emb of queryEmbeddings) { - this.assertEmbeddingDim(emb); - queries.push(emb); - } - } else if (queryTexts) { - for (const text of queryTexts) { - const emb = await this.embeddings.embed(text); - queries.push(emb); - } + if (queryEmbedding && queryEmbedding.length !== this.embeddingDim) { + throw new Error( + `queryEmbedding dimension ${queryEmbedding.length} does not match collection embedding dimension ${this.embeddingDim}` + ); } - const pool: GetResult[] = ids?.length - ? await this.getRowsByIds(ids) - : ( - await this.db.execute( - 'SELECT id, document, embedding, metadata FROM vectors' - ) - ).rows.map((r: any) => this.rowToGetResult(r)); - - const results: QueryResult[][] = []; + const searchEmbedding = + queryEmbedding ?? (await this.embeddings.embed(queryText!)); - for (const q of queries) { - const qScalar = this.arrayToScalar(q); - let res; - if (ids && ids.length) { - res = await this.db.execute( - ` - SELECT - id, - document, - embedding, - metadata, - (1.0 - vector_distance_cos(embedding, vector(?))) AS similarity - FROM vectors - WHERE vectors.id IN (${ids.map(() => '?').join(',')}) - ORDER BY similarity DESC - `, - [qScalar, ...pool.map((r) => r.id)] - ); - } else { - res = await this.db.execute( - ` - SELECT - id, - document, - embedding, - metadata, - (1.0 - vector_distance_cos(embedding, vector(?))) AS similarity - FROM vectors - ORDER BY similarity DESC - `, - [qScalar] - ); - } - - const scored = res.rows - .map((r) => this.rowToGetResult(r) as QueryResult) - .filter(predicate) - .slice(0, nResults); - - results.push(scored); - } + const res = await this.db.execute( + 'SELECT id, document, embedding, metadata, (1.0 - vector_distance_cos(embedding, vector(?))) AS similarity FROM vectors ORDER BY similarity DESC', + [`[${searchEmbedding.join(',')}]`] + ); - return results; + return res.rows + .map((r: any) => ({ + ...this.rowToGetResult(r), + similarity: r.similarity as number, + })) + .filter(predicate ?? (() => true)) + .slice(0, nResults); } /** @@ -383,61 +241,11 @@ export class OPSQLiteVectorStore implements VectorStore { * Maps a DB row to a {@link GetResult} object. */ private rowToGetResult(row: any): GetResult { - const embedding = Array.isArray(row.embedding) - ? (row.embedding as number[]) - : Array.from(row.embedding as Float32Array); return { - id: row.id as string, - document: (row.document as string) ?? '', - embedding, - metadata: row.metadata - ? (JSON.parse(row.metadata as string) as Record) - : undefined, + id: row.id, + document: row.document, + embedding: Array.from(new Float32Array(row.embedding)), + metadata: row.metadata ? JSON.parse(row.metadata) : undefined, }; } - - /** - * Fetches rows by IDs and returns them as {@link GetResult} objects. - */ - private async getRowsByIds(ids: string[]): Promise { - if (ids.length === 0) return []; - const placeholders = ids.map(() => '?').join(','); - const res = await this.db.execute( - `SELECT id, document, embedding, metadata FROM vectors WHERE id IN (${placeholders})`, - ids - ); - return res.rows.map((r: any) => this.rowToGetResult(r)); - } - - /** - * Converts an array of numbers to a string representation suitable for `vector(?)` binding. - */ - private arrayToScalar(arr: number[]): string { - return `[${arr.join(',')}]`; - } - - /** - * Verifies optional arrays match expected length. - */ - private assertLengthMatchIds(arr: T[] | undefined, idsLength: number) { - if (arr && arr.length !== idsLength) { - throw new Error('array length must match ids length'); - } - } - - /** - * Ensures all embeddings share the same dimensionality, setting it on first use. - */ - private assertEmbeddingDim(vec: number[]) { - if (!Array.isArray(vec) || vec.length === 0) { - throw new Error('embedding must be a non-empty vector'); - } - if (this.embeddingDim === undefined) { - this.embeddingDim = vec.length; - } else if (vec.length !== this.embeddingDim) { - throw new Error( - `embedding dimension ${vec.length} does not match collection dimension ${this.embeddingDim}` - ); - } - } } diff --git a/src/__tests__/memoryVectorStore.test.ts b/src/__tests__/memoryVectorStore.test.ts index 7476cb3..56f8699 100644 --- a/src/__tests__/memoryVectorStore.test.ts +++ b/src/__tests__/memoryVectorStore.test.ts @@ -32,263 +32,216 @@ class MockEmbeddings implements Embeddings { } describe('MemoryVectorStore', () => { - test('add() with provided embeddings + metadata', async () => { + test('add() with provided embedding + metadata', async () => { const emb = new MockEmbeddings(3); const store = new MemoryVectorStore({ embeddings: emb }); + await store.load(); - const ids = ['a', 'b']; - const documents = ['hello', 'world']; - const embeddings = [ - [1, 2, 3], - [3, 2, 1], - ]; - const metadatas = [{ tag: 't1' }, { tag: 't2' }]; - - await store.add({ ids, documents, embeddings, metadatas }); + await store.add({ + id: 'a', + document: 'hello', + embedding: [1, 2, 3], + metadata: { tag: 't1' }, + }); + await store.add({ + id: 'b', + document: 'world', + embedding: [3, 2, 1], + metadata: { tag: 't2' }, + }); const res = await store.query({ - queryEmbeddings: [[1, 2, 3]], + queryEmbedding: [1, 2, 3], nResults: 2, }); - expect(res).toHaveLength(1); - expect(res[0]!.map((r) => r.id)).toEqual(['a', 'b']); - expect(res[0]![0]!.metadata).toEqual({ tag: 't1' }); + expect(res.map((r) => r.id)).toEqual(['a', 'b']); + expect(res[0]!.metadata).toEqual({ tag: 't1' }); }); - test('add() without embeddings computes them via embeddings.embed', async () => { + test('add() without embedding computes it via embeddings.embed', async () => { const emb = new MockEmbeddings(4); const store = new MemoryVectorStore({ embeddings: emb }); + await store.load(); + emb.embedCalls = []; - await store.add({ - ids: ['x', 'y'], - documents: ['Doc X', 'Doc Y'], - }); + await store.add({ id: 'x', document: 'Doc X' }); + await store.add({ id: 'y', document: 'Doc Y' }); expect(emb.embedCalls).toEqual(['Doc X', 'Doc Y']); - const res = await store.query({ queryTexts: ['Doc X'], nResults: 2 }); - expect(res[0]![0]!.id).toBe('x'); + const res = await store.query({ queryText: 'Doc X', nResults: 2 }); + expect(res[0]!.id).toBe('x'); }); test('add() rejects duplicate ids', async () => { const store = new MemoryVectorStore({ embeddings: new MockEmbeddings(2) }); + await store.load(); - await store.add({ ids: ['1'], documents: ['one'] }); - await expect(store.add({ ids: ['1'], documents: ['uno'] })).rejects.toThrow( + await store.add({ id: '1', document: 'one' }); + await expect(store.add({ id: '1', document: 'uno' })).rejects.toThrow( /id already exists/i ); }); - test('add() rejects array length mismatches', async () => { + test('add() rejects when both document and embedding are missing', async () => { const store = new MemoryVectorStore({ embeddings: new MockEmbeddings(2) }); - - await expect( - store.add({ - ids: ['1', '2'], - documents: ['only-one'], - }) - ).rejects.toThrow(/array length must match ids length/i); + await store.load(); + await expect(store.add({ id: 'x' })).rejects.toThrow( + /document and embedding cannot be both undefined/i + ); }); test('add() rejects embedding dimension mismatch', async () => { const store = new MemoryVectorStore({ embeddings: new MockEmbeddings(3) }); + await store.load(); - await store.add({ - ids: ['a'], - documents: ['foo'], - embeddings: [[1, 1, 1]], - }); + await store.add({ id: 'a', document: 'foo', embedding: [1, 1, 1] }); await expect( - store.add({ - ids: ['b'], - documents: ['bar'], - embeddings: [[1, 2]], - }) + store.add({ id: 'b', document: 'bar', embedding: [1, 2] }) ).rejects.toThrow( - /embedding dimension .* does not match collection dimension/i + /embedding dimension .* does not match collection embedding dimension/i ); }); test('update() updates document + recomputes embedding when no embedding provided', async () => { const emb = new MockEmbeddings(4); const store = new MemoryVectorStore({ embeddings: emb }); + await store.load(); - await store.add({ - ids: ['k'], - documents: ['old doc'], - }); + await store.add({ id: 'k', document: 'old doc' }); emb.embedCalls = []; - await store.update({ - ids: ['k'], - documents: ['new doc'], - }); + await store.update({ id: 'k', document: 'new doc' }); expect(emb.embedCalls).toEqual(['new doc']); - const res = await store.query({ queryTexts: ['new doc'], nResults: 1 }); - expect(res[0]![0]!.id).toBe('k'); - expect(res[0]![0]!.document).toBe('new doc'); + const res = await store.query({ queryText: 'new doc', nResults: 1 }); + expect(res[0]!.id).toBe('k'); + expect(res[0]!.document).toBe('new doc'); }); test('update() updates embedding + metadata directly', async () => { const store = new MemoryVectorStore({ embeddings: new MockEmbeddings(3) }); + await store.load(); - await store.add({ - ids: ['m'], - documents: ['alpha'], - metadatas: [{ a: 1 }], - }); + await store.add({ id: 'm', document: 'alpha', metadata: { a: 1 } }); - await store.update({ - ids: ['m'], - embeddings: [[9, 9, 9]], - metadatas: [{ b: 2 }], - }); + await store.update({ id: 'm', embedding: [9, 9, 9], metadata: { b: 2 } }); - const res = await store.query({ - queryEmbeddings: [[9, 9, 9]], - nResults: 1, - }); + const res = await store.query({ queryEmbedding: [9, 9, 9], nResults: 1 }); - expect(res[0]![0]!.id).toBe('m'); - expect(res[0]![0]!.metadata).toEqual({ b: 2 }); + expect(res[0]!.id).toBe('m'); + expect(res[0]!.metadata).toEqual({ b: 2 }); }); test('update() rejects unknown ids', async () => { const store = new MemoryVectorStore({ embeddings: new MockEmbeddings(2) }); + await store.load(); await expect( - store.update({ ids: ['missing'], documents: ['nope'] }) + store.update({ id: 'missing', document: 'nope' }) ).rejects.toThrow(/id not found/i); }); - test('delete() by ids works', async () => { + test('delete() by predicate (by id) works', async () => { const store = new MemoryVectorStore({ embeddings: new MockEmbeddings(2) }); - await store.add({ - ids: ['a', 'b', 'c'], - documents: ['A', 'B', 'C'], - }); + await store.load(); + await store.add({ id: 'a', document: 'A' }); + await store.add({ id: 'b', document: 'B' }); + await store.add({ id: 'c', document: 'C' }); - await store.delete({ ids: ['b'] }); + await store.delete({ predicate: (row) => row.id === 'b' }); - const res = await store.query({ queryTexts: ['A'], nResults: 10 }); - const remaining = new Set(res[0]!.map((r) => r.id)); + const res = await store.query({ queryText: 'A', nResults: 10 }); + const remaining = new Set(res.map((r) => r.id)); expect(remaining.has('b')).toBe(false); }); test('delete() by predicate works', async () => { const store = new MemoryVectorStore({ embeddings: new MockEmbeddings(2) }); - await store.add({ - ids: ['a', 'b', 'c'], - documents: ['keep', 'drop', 'keep-too'], - metadatas: [{ role: 'x' }, { role: 'y' }, { role: 'x' }], - }); + await store.load(); + await store.add({ id: 'a', document: 'keep', metadata: { role: 'x' } }); + await store.add({ id: 'b', document: 'drop', metadata: { role: 'y' } }); + await store.add({ id: 'c', document: 'keep-too', metadata: { role: 'x' } }); - await store.delete({ - predicate: (row) => row.metadata!.role === 'y', - }); + await store.delete({ predicate: (row) => row.metadata!.role === 'y' }); - const res = await store.query({ queryTexts: ['keep'], nResults: 10 }); - const ids = new Set(res[0]!.map((r) => r.id)); + const res = await store.query({ queryText: 'keep', nResults: 10 }); + const ids = new Set(res.map((r) => r.id)); expect(ids.has('b')).toBe(false); }); - test('delete() by ids + predicate works', async () => { + test('delete() supports complex predicates', async () => { const store = new MemoryVectorStore({ embeddings: new MockEmbeddings(2) }); - await store.add({ - ids: ['a', 'b'], - documents: ['doc-a', 'doc-b'], - metadatas: [{ keep: false }, { keep: true }], - }); + await store.load(); + await store.add({ id: 'a', document: 'doc-a', metadata: { keep: false } }); + await store.add({ id: 'b', document: 'doc-b', metadata: { keep: true } }); - await store.delete({ - ids: ['a', 'b'], - predicate: (row) => row.metadata!.keep === false, - }); + await store.delete({ predicate: (row) => row.metadata!.keep === false }); - const res = await store.query({ queryTexts: ['doc'], nResults: 10 }); - const ids = new Set(res[0]!.map((r) => r.id)); + const res = await store.query({ queryText: 'doc', nResults: 10 }); + const ids = new Set(res.map((r) => r.id)); expect(ids.has('a')).toBe(false); expect(ids.has('b')).toBe(true); }); - test('query() by text returns top n sorted, supports ids filter and predicate', async () => { + test('query() by text returns top n sorted and supports predicate', async () => { const store = new MemoryVectorStore({ embeddings: new MockEmbeddings(4) }); + await store.load(); + await store.add({ id: 'p', document: 'apple pie', metadata: { cat: 'x' } }); + await store.add({ + id: 'q', + document: 'banana split', + metadata: { cat: 'y' }, + }); await store.add({ - ids: ['p', 'q', 'r'], - documents: ['apple pie', 'banana split', 'ripe banana'], - metadatas: [{ cat: 'x' }, { cat: 'y' }, { cat: 'y' }], + id: 'r', + document: 'ripe banana', + metadata: { cat: 'y' }, }); const res = await store.query({ - queryTexts: ['banana'], + queryText: 'banana', nResults: 2, - ids: ['q', 'r'], predicate: (row) => row.metadata!.cat === 'y', }); - expect(res).toHaveLength(1); - expect(res[0]).toHaveLength(2); - expect(res[0]!.map((x) => x.id)).toEqual(['q', 'r']); - expect(res[0]![0]!.similarity).toBeGreaterThanOrEqual( - res[0]![1]!.similarity - ); + expect(res).toHaveLength(2); + expect(res.map((x) => x.id)).toEqual(['q', 'r']); + expect(res[0]!.similarity).toBeGreaterThanOrEqual(res[1]!.similarity); }); test('query() by embeddings works and enforces dimension', async () => { const store = new MemoryVectorStore({ embeddings: new MockEmbeddings(3) }); - await store.add({ - ids: ['u', 'v'], - documents: ['foo', 'bar'], - embeddings: [ - [1, 0, 0], - [0, 1, 0], - ], - }); + await store.load(); + await store.add({ id: 'u', document: 'foo', embedding: [1, 0, 0] }); + await store.add({ id: 'v', document: 'bar', embedding: [0, 1, 0] }); const ok = await store.query({ - queryEmbeddings: [[0.9, 0.1, 0]], + queryEmbedding: [0.9, 0.1, 0], nResults: 1, }); - expect(ok[0]![0]!.id).toBe('u'); + expect(ok[0]!.id).toBe('u'); await expect( - store.query({ - queryEmbeddings: [[1, 2]], - nResults: 1, - }) + store.query({ queryEmbedding: [1, 2], nResults: 1 }) ).rejects.toThrow( - /embedding dimension .* does not match collection dimension/i + /queryEmbedding dimension .* does not match collection embedding dimension/i ); }); - test('query() arg validation: exactly one of queryTexts or queryEmbeddings must be provided', async () => { + test('query() arg validation: queryText and queryEmbedding cannot both be undefined', async () => { const store = new MemoryVectorStore({ embeddings: new MockEmbeddings(2) }); - await store.add({ ids: ['a'], documents: ['alpha'] }); + await store.load(); + await store.add({ id: 'a', document: 'alpha' }); await expect(store.query({ nResults: 1 })).rejects.toThrow( - /exactly one of queryTexts or queryEmbeddings must be provided/i - ); - - await expect( - store.query({ queryTexts: ['x'], queryEmbeddings: [[1, 2]], nResults: 1 }) - ).rejects.toThrow( - /exactly one of queryTexts or queryEmbeddings must be provided/i + /queryText and queryEmbedding cannot be both undefined/i ); await expect( - store.query({ queryTexts: ['alpha'], nResults: 1 }) + store.query({ queryText: 'alpha', nResults: 1 }) ).resolves.toBeDefined(); }); - - test('query() rejects unknown ids in ids filter', async () => { - const store = new MemoryVectorStore({ embeddings: new MockEmbeddings(2) }); - await store.add({ ids: ['a'], documents: ['alpha'] }); - - await expect( - store.query({ queryTexts: ['alpha'], ids: ['missing'], nResults: 1 }) - ).rejects.toThrow(/id not found/i); - }); }); diff --git a/src/__tests__/rag.integration.test.ts b/src/__tests__/rag.integration.test.ts index 1aa0b69..08c1c97 100644 --- a/src/__tests__/rag.integration.test.ts +++ b/src/__tests__/rag.integration.test.ts @@ -132,13 +132,13 @@ describe('RAG (integration with MemoryVectorStore + MockEmbeddings)', () => { await rag.load(); const doc = 'a short document that fits in one chunk'; - const ids = await rag.splitAddDocument(doc); + const ids = await rag.splitAddDocument({ document: doc }); expect(ids).toHaveLength(1); expect(embeddings.embedCalls).toContain(doc); - const results = await store.query({ queryTexts: [doc], nResults: 3 }); - const topIds = results[0]!.map((r) => r.id); + const results = await store.query({ queryText: doc, nResults: 3 }); + const topIds = results.map((r) => r.id); expect(ids.some((id) => topIds.includes(id))).toBe(true); }); @@ -158,12 +158,16 @@ describe('RAG (integration with MemoryVectorStore + MockEmbeddings)', () => { chunks.map((c, i) => ({ idx: i, len: c.length })); const doc = 'DOC-DOC'; - const ids = await rag.splitAddDocument(doc, metaGen, splitter); + const ids = await rag.splitAddDocument({ + document: doc, + metadataGenerator: metaGen, + textSplitter: splitter, + }); expect(ids.length).toBeGreaterThan(1); - const res = await store.query({ queryTexts: ['DOC'], nResults: 10 }); - const gotIds = new Set(res[0]!.map((r) => r.id)); + const res = await store.query({ queryText: 'DOC', nResults: 10 }); + const gotIds = new Set(res.map((r) => r.id)); expect(ids.some((id) => gotIds.has(id))).toBe(true); }); @@ -180,17 +184,17 @@ describe('RAG (integration with MemoryVectorStore + MockEmbeddings)', () => { const spyDelete = jest.spyOn(store, 'delete'); await rag.addDocument({ - ids: ['id1'], - embeddings: [[1, 0, 0]], - documents: ['alpha'], - metadatas: [{ a: 1 }], + id: 'id1', + embedding: [1, 0, 0], + document: 'alpha', + metadata: { a: 1 }, }); expect(spyAdd).toHaveBeenCalledTimes(1); - await rag.updateDocument({ ids: ['id1'], documents: ['alpha-new'] }); + await rag.updateDocument({ id: 'id1', document: 'alpha-new' }); expect(spyUpdate).toHaveBeenCalledTimes(1); - await rag.deleteDocument({ ids: ['id1'] }); + await rag.deleteDocument({ predicate: (r) => r.id === 'id1' }); expect(spyDelete).toHaveBeenCalledTimes(1); }); @@ -221,10 +225,8 @@ describe('RAG (integration with MemoryVectorStore + MockEmbeddings)', () => { await rag.load(); - await store.add({ - ids: ['d1', 'd2'], - documents: ['bananas are yellow', 'apples are red'], - }); + await store.add({ id: 'd1', document: 'bananas are yellow' }); + await store.add({ id: 'd2', document: 'apples are red' }); const tokens: string[] = []; const out = await rag.generate({ @@ -253,10 +255,7 @@ describe('RAG (integration with MemoryVectorStore + MockEmbeddings)', () => { await rag.load(); - await store.add({ - ids: ['z'], - documents: ['custom context'], - }); + await store.add({ id: 'z', document: 'custom context' }); const qSpy = jest.spyOn(store, 'query'); @@ -270,7 +269,7 @@ describe('RAG (integration with MemoryVectorStore + MockEmbeddings)', () => { promptGenerator, }); - expect(qSpy.mock.calls[0]![0]!.queryTexts).toEqual(['custom-question']); + expect(qSpy.mock.calls[0]![0]!.queryText).toEqual('custom-question'); const llmInput = llm.generateCalls[0]!.input; expect(llmInput[1]!.content).toBe('PROMPT(orig) :: custom context'); @@ -285,9 +284,19 @@ describe('RAG (integration with MemoryVectorStore + MockEmbeddings)', () => { await rag.load(); await store.add({ - ids: ['a', 'b', 'c'], - documents: ['keep this', 'drop this', 'keep as well'], - metadatas: [{ role: 'x' }, { role: 'y' }, { role: 'x' }], + id: 'a', + document: 'keep this', + metadata: { role: 'x' }, + }); + await store.add({ + id: 'b', + document: 'drop this', + metadata: { role: 'y' }, + }); + await store.add({ + id: 'c', + document: 'keep as well', + metadata: { role: 'x' }, }); const predicate = (r: QueryResult) => (r as any).metadata?.role === 'x'; diff --git a/src/hooks/rag.ts b/src/hooks/rag.ts index e5f5cd7..48ca8ed 100644 --- a/src/hooks/rag.ts +++ b/src/hooks/rag.ts @@ -6,19 +6,10 @@ import type { VectorStore } from '../interfaces/vectorStore'; import type { LLM } from '../interfaces/llm'; /** - * A React hook for Retrieval Augmented Generation (RAG). - * Manages RAG system lifecycle, loading, unloading, generation, and document storage. + * React hook for Retrieval Augmented Generation. + * Manages load/unload, generation, and document storage. * - * @param params - RAG configuration. - * @returns An object with state and RAG operations: `response`, `isReady`, `isGenerating`, `isStoring`, `error`, and functions `generate`, `interrupt`, `splitAddDocument`, `addDocument`, `updateDocument`, `deleteDocument`. - * - * @example - * // Basic usage in a component - * const { isReady, response, generate } = useRAG({ vectorStore, llm }); - * useEffect(() => { - * if (!isReady) return; - * generate({ input: 'What is RAG?' }); - * }, [isReady]); + * @returns State and operations: `response`, `isReady`, `isGenerating`, `isStoring`, `error`, `generate`, `interrupt`, `splitAddDocument`, `addDocument`, `updateDocument`, `deleteDocument`. */ export function useRAG({ vectorStore, @@ -69,14 +60,14 @@ export function useRAG({ /** * Generates a text response. - * @param params - Object containing: - * - `input`: User input as a string or array of messages. - * - `augmentedGeneration` (optional): Whether to use RAG augmentation (default: true). - * - `nResults` (optional): Number of documents to retrieve for augmentation (default: 3). - * - `predicate` (optional): Predicate to filter retrieved documents. - * - `questionGenerator` (optional): Function to generate a question from messages. - * - `promptGenerator` (optional): Function to generate a prompt from messages and retrieved documents. - * - `callback` (optional): Callback function for streaming tokens. + * @param params - Parameters for the generation. + * @param params.input - User input as a string or array of messages. + * @param params.augmentedGeneration - Whether to use RAG augmentation (default: true). + * @param params.nResults - Number of documents to retrieve for augmentation (default: 3). + * @param params.predicate - Predicate to filter retrieved documents. + * @param params.questionGenerator - Function to generate a question from messages. + * @param params.promptGenerator - Function to generate a prompt from messages and retrieved documents. + * @param params.callback - Callback function for streaming tokens. * @returns A promise that resolves to the generated text. * @throws Error if RAG is not ready or is currently generating. */ @@ -129,27 +120,24 @@ export function useRAG({ /** * Splits and adds a document to the vector store. - * @param document - Document content. - * @param metadataGenerator - Optional metadata generator. - * @param textSplitter - Optional text splitter. + * @param params - Parameters for the operation. + * @param params.document - Document content. + * @param params.metadataGenerator - Metadata generator. + * @param params.textSplitter - Text splitter. * @returns IDs of added chunks. */ const splitAddDocument = useCallback( - async ( - document: string, - metadataGenerator?: (chunks: string[]) => Record[], - textSplitter?: TextSplitter - ): Promise => { + async (params: { + document: string; + metadataGenerator?: (chunks: string[]) => Record[]; + textSplitter?: TextSplitter; + }): Promise => { if (!isReady) throw new Error('RAG not ready.'); if (isStoring) throw new Error('RAG busy storing.'); setError(null); try { setIsStoring(true); - return await rag.splitAddDocument( - document, - metadataGenerator, - textSplitter - ); + return await rag.splitAddDocument(params); } catch (e) { setError(e instanceof Error ? e.message : 'Split/add doc error.'); throw e; @@ -161,22 +149,22 @@ export function useRAG({ ); /** - * Adds documents to the vector store. - * @param params - Object containing: - * - `ids`: (optional) The IDs of the documents. If not provided, they will be auto-generated. - * - `documents`: Raw text content of the documents. - * - `embeddings` (optional): Embeddings for the documents. - * - `metadatas` (optional): Metadata associated with each document. - * @returns A promise that resolves to the IDs of the newly added documents. + * Adds a document to the vector store. + * @param params - Parameters for the operation. + * @param params.id - The ID of the document. If not provided, it will be auto-generated. + * @param params.document - Raw text content of the document. + * @param params.embedding - Embedding for the document. If not provided, it will be generated based on the `document`. + * @param params.metadata - Metadata associated with the document. + * @returns A promise that resolves to the ID of the newly added document. * @throws Error if RAG is not ready or is currently storing. */ const addDocument = useCallback( async (params: { - ids?: string[]; - documents: string[]; - embeddings?: number[][]; - metadatas?: Record[]; - }): Promise => { + id?: string; + document?: string; + embedding?: number[]; + metadata?: Record; + }): Promise => { if (!isReady) throw new Error('RAG not ready.'); if (isStoring) throw new Error('RAG busy storing.'); setError(null); @@ -194,21 +182,21 @@ export function useRAG({ ); /** - * Updates documents in the vector store. - * @param params - Object containing: - * - `ids`: The IDs of the documents to update. - * - `embeddings` (optional): New embeddings for the documents. - * - `documents` (optional): New content for the documents. - * - `metadatas` (optional): New metadata for the documents. - * @returns A promise that resolves when the documents are updated. + * Updates a document in the vector store by its ID. + * @param params - Parameters for the update. + * @param params.id - The ID of the document to update. + * @param params.document - New content for the document. + * @param params.embedding - New embedding for the document. If not provided, it will be generated based on the `document`. + * @param params.metadata - New metadata for the document. + * @returns A promise that resolves when the document is updated. * @throws Error if RAG is not ready or is currently storing. */ const updateDocument = useCallback( async (params: { - ids: string[]; - embeddings?: number[][]; - documents?: string[]; - metadatas?: Record[]; + id: string; + embedding?: number[]; + document?: string; + metadata?: Record; }): Promise => { if (!isReady) throw new Error('RAG not ready.'); if (isStoring) throw new Error('RAG busy storing.'); @@ -224,17 +212,15 @@ export function useRAG({ ); /** - * Deletes documents from the vector store. - * @param params - Object containing: - * - `ids` (optional): List of document IDs to delete. - * - `predicate` (optional): Predicate to match documents for deletion. - * @returns A promise that resolves when the documents are deleted. + * Deletes documents from the vector store by the provided predicate. + * @param params - Parameters for deletion. + * @param params.predicate - Predicate to match documents for deletion. + * @returns A promise that resolves once the documents are deleted. * @throws Error if RAG is not ready or is currently storing. */ const deleteDocument = useCallback( async (params: { - ids?: string[]; - predicate?: (value: GetResult) => boolean; + predicate: (value: GetResult) => boolean; }): Promise => { if (!isReady) throw new Error('RAG not ready.'); if (isStoring) throw new Error('RAG busy storing.'); diff --git a/src/interfaces/vectorStore.ts b/src/interfaces/vectorStore.ts index e4c8576..5119fd7 100644 --- a/src/interfaces/vectorStore.ts +++ b/src/interfaces/vectorStore.ts @@ -10,76 +10,69 @@ import type { GetResult, QueryResult } from '../types/common'; export interface VectorStore { /** * Initializes the vector store, loading necessary resources. - * @returns A promise that resolves to the initialized vector store instance. + * @returns Promise that resolves to the initialized vector store instance. */ load: () => Promise; /** * Unloads the vector store, releasing any resources used. - * @returns A promise that resolves when the vector store is unloaded. + * @returns Promise that resolves when the vector store is unloaded. */ unload: () => Promise; /** - * Adds documents to the vector store. + * Adds a document to the vector store. * @param params - Object containing: - * - `ids`: (optional) The IDs of the documents. If not provided, they will be auto-generated. - * - `documents`: Raw text content of the documents. - * - `embeddings` (optional): Embeddings for the documents. - * - `metadatas` (optional): Metadata associated with each document. - * @returns A promise that resolves to the IDs of the newly added documents. + * @param params.id - The ID of the document. If not provided, it will be auto-generated. + * @param params.document - Raw text content of the document. + * @param params.embedding - Embedding for the document. If not provided, it will be generated based on the `document`. + * @param params.metadata - Metadata associated with the document. + * @returns Promise that resolves to the ID of the newly added document. */ add(params: { - ids?: string[]; - documents: string[]; - embeddings?: number[][]; - metadatas?: Record[]; - }): Promise; + id?: string; + document?: string; + embedding?: number[]; + metadata?: Record; + }): Promise; /** - * Updates documents in the vector store by their IDs. - * If `documents` are provided, and `embeddings` are not, new embeddings will be generated. + * Updates a document in the vector store by its ID. * @param params - Object containing: - * - `ids`: The IDs of the documents to update. - * - `embeddings` (optional): New embeddings for the documents. - * - `documents` (optional): New content for the documents. - * - `metadatas` (optional): New metadata for the documents. - * @returns A promise that resolves when the documents are updated. + * @param params.id - The ID of the document to update. + * @param params.document - New content for the document. + * @param params.embedding - New embedding for the document. If not provided, it will be generated based on the `document`. + * @param params.metadata - New metadata for the document. + * @returns Promise that resolves once the document is updated. */ update(params: { - ids: string[]; - embeddings?: number[][]; - documents?: string[]; - metadatas?: Record[]; + id: string; + document?: string; + embedding?: number[]; + metadata?: Record; }): Promise; /** - * Deletes documents from the vector store. + * Deletes documents from the vector store by the provided predicate. * @param params - Object containing: - * - `ids` (optional): List of document IDs to delete. - * - `predicate` (optional): Predicate to match documents for deletion. - * @returns A promise that resolves when the documents are deleted. + * @param params.predicate - Predicate to match documents for deletion. + * @returns Promise that resolves once the documents are deleted. */ - delete(params: { - ids?: string[]; - predicate?: (value: GetResult) => boolean; - }): Promise; + delete(params: { predicate: (value: GetResult) => boolean }): Promise; /** * Performs a similarity search against the stored vectors. * @param params - Object containing: - * - `queryTexts` (optional): The raw query strings to search for. - * - `queryEmbeddings` (optional): Pre-computed embeddings for the queries. - * - `nResults` (optional): The number of top similar results to return per query. - * - `ids` (optional): Restrict the search to these document IDs. - * - `predicate` (optional): Function to filter results after retrieval. - * @returns A promise that resolves to an array of result arrays (one per query). + * @param params.queryText - The raw query string to search for. + * @param params.queryEmbedding - Pre-computed embedding for the query. + * @param params.nResults - The number of top similar results to return. + * @param params.predicate - Function to filter results after retrieval. + * @returns Promise that resolves to an array of {@link QueryResult}. */ query(params: { - queryTexts?: string[]; - queryEmbeddings?: number[][]; + queryText?: string; + queryEmbedding?: number[]; nResults?: number; - ids?: string[]; predicate?: (value: QueryResult) => boolean; - }): Promise; + }): Promise; } diff --git a/src/rag/rag.ts b/src/rag/rag.ts index e45e85a..cc610f7 100644 --- a/src/rag/rag.ts +++ b/src/rag/rag.ts @@ -6,17 +6,12 @@ import type { VectorStore } from '../interfaces/vectorStore'; import { uuidv4 } from '../utils/uuidv4'; /** - * Core Retrieval Augmented Generation orchestrator. - * - * The `RAG` class coordinates a `VectorStore` and an `LLM` to: - * - Split and ingest documents for retrieval - * - Retrieve relevant context for a query - * - Generate responses with or without augmented context + * Orchestrates Retrieval Augmented Generation. + * Coordinates a `VectorStore` and an `LLM` to ingest, retrieve, and generate. * * @example * const rag = await new RAG({ vectorStore, llm }).load(); * const answer = await rag.generate({ input: 'What is RAG?' }); - * console.log(answer); */ export class RAG { private vectorStore: VectorStore; @@ -57,17 +52,19 @@ export class RAG { * If no `textSplitter` is provided, a default * `RecursiveCharacterTextSplitter({ chunkSize: 500, chunkOverlap: 100 })` is used. * - * @param document - The content of the document to split and add. - * @param metadataGenerator - Optional function to generate metadata for each chunk. - * Must return an array which length is equal to the number of chunks. - * @param textSplitter - Optional text splitter implementation. + * @param params - Parameters for the operation. + * @param params.document - The content of the document to split and add. + * @param params.metadataGenerator - Function to generate metadata for each chunk. Must return an array which length is equal to the number of chunks. + * @param params.textSplitter - Text splitter implementation. * @returns Promise that resolves to the IDs of the newly added chunks. */ - async splitAddDocument( - document: string, - metadataGenerator?: (chunks: string[]) => Record[], - textSplitter?: TextSplitter - ): Promise { + async splitAddDocument(params: { + document: string; + metadataGenerator?: (chunks: string[]) => Record[]; + textSplitter?: TextSplitter; + }): Promise { + let { document, metadataGenerator, textSplitter } = params; + if (!textSplitter) { textSplitter = new RecursiveCharacterTextSplitter({ chunkSize: 500, @@ -84,61 +81,61 @@ export class RAG { ); } - await this.vectorStore.add({ - ids, - documents: chunks, - metadatas, - }); + for (let i = 0; i < ids.length; i++) { + await this.vectorStore.add({ + id: ids[i], + document: chunks[i], + metadata: metadatas ? metadatas[i] : undefined, + }); + } return ids; } /** - * Adds documents to the vector store. + * Adds a document to the vector store. * @param params - Parameters for the operation. - * @param params.ids - Optional IDs for each document (must match `documents.length`). - * @param params.documents - Raw text content for each document. - * @param params.embeddings - Optional embeddings for each document. - * @param params.metadatas - Optional metadata for each document (aligned by index). - * @returns Promise that resolves to the IDs of the newly added documents. + * @param params.id - ID for the document. + * @param params.document - Raw text content for the document. + * @param params.embedding - Embedding for the document. + * @param params.metadata - Metadata for the document. + * @returns Promise that resolves to the ID of the newly added document. */ async addDocument(params: { - ids?: string[]; - documents: string[]; - embeddings?: number[][]; - metadatas?: Record[]; - }): Promise { + id?: string; + document?: string; + embedding?: number[]; + metadata?: Record; + }): Promise { return this.vectorStore.add(params); } /** - * Updates documents in the vector store by their IDs. + * Updates a document in the vector store by its ID. * @param params - Parameters for the update. - * @param params.ids - IDs of the documents to update. - * @param params.embeddings - New embeddings (optional; aligned by index if provided). - * @param params.documents - New content (optional; aligned by index if provided). - * @param params.metadatas - New metadata (optional; aligned by index if provided). - * @returns Promise that resolves when the update completes. + * @param params.id - The ID of the document to update. + * @param params.document - New content for the document. + * @param params.embedding - New embedding for the document. If not provided, it will be generated based on the `document`. + * @param params.metadata - New metadata for the document. + * @returns Promise that resolves once the document is updated. */ async updateDocument(params: { - ids: string[]; - embeddings?: number[][]; - documents?: string[]; - metadatas?: Record[]; + id: string; + embedding?: number[]; + document?: string; + metadata?: Record; }): Promise { return this.vectorStore.update(params); } /** - * Deletes documents from the vector store. + * Deletes documents from the vector store by the provided predicate. * @param params - Parameters for deletion. - * @param params.ids - List of document IDs to delete. * @param params.predicate - Predicate to match documents for deletion. - * @returns Promise that resolves when deletion completes. + * @returns Promise that resolves once the documents are deleted. */ async deleteDocument(params: { - ids?: string[]; - predicate?: (value: GetResult) => boolean; + predicate: (value: GetResult) => boolean; }): Promise { return this.vectorStore.delete(params); } @@ -161,15 +158,14 @@ Context: ${retrievedDocs.map((result) => result.document).join('\n')}`; * Generates a response based on the input messages and retrieved documents. * If `augmentedGeneration` is true, it retrieves relevant documents from the vector store * and includes them in the prompt for the LLM. - * * @param params - Generation parameters. * @param params.input - Input messages or a single string. * @param params.augmentedGeneration - Whether to augment with retrieved context (default: true). * @param params.nResults - Number of docs to retrieve (default: 3). - * @param params.predicate - Optional filter applied to retrieved docs. + * @param params.predicate - Filter applied to retrieved docs. * @param params.questionGenerator - Maps the message list to a search query (default: last message content). * @param params.promptGenerator - Builds the context-augmented prompt from messages and retrieved docs. - * @param params.callback - Optional token callback for streaming. + * @param params.callback - Token callback for streaming. * @returns Promise that resolves to the generated text. */ public async generate(params: { @@ -210,12 +206,11 @@ Context: ${retrievedDocs.map((result) => result.document).join('\n')}`; throw new Error('Last message has no content'); } const retrievedDocs = await this.vectorStore.query({ - queryTexts: [questionGenerator(input)], + queryText: questionGenerator(input), nResults, - queryEmbeddings: undefined, predicate, }); - const prompt = promptGenerator(input, retrievedDocs[0] ?? []); + const prompt = promptGenerator(input, retrievedDocs); const augmentedInput: Message[] = [ ...input, { role: 'user', content: prompt }, diff --git a/src/types/common.ts b/src/types/common.ts index 9ab9fbb..08bee67 100644 --- a/src/types/common.ts +++ b/src/types/common.ts @@ -1,10 +1,10 @@ /** - * Represents a chat message exchanged in a conversation. - * - `role`: Identifies the sender (`user`, `assistant`, or `system`). - * - `content`: The text content of the message. + * Chat message in a conversation. */ export interface Message { + /** Sender role. */ role: 'user' | 'assistant' | 'system'; + /** Message text content. */ content: string; } @@ -17,24 +17,23 @@ export interface Message { export type ResourceSource = string | number | object; /** - * Represents a single retrieval result. - * Each field is aligned by index. - * - `document`: Retrieved document text. - * - `embedding`: Embedding vector for the document. - * - `id`: Document identifier. - * - `metadata`: Optional metadata object (`Record`). + * Single retrieval result. */ export interface GetResult { - document: string; - embedding: number[]; + /** Document identifier. */ id: string; + /** Retrieved document text. */ + document?: string; + /** Embedding vector for the document. */ + embedding: number[]; + /** Document metadata. */ metadata?: Record; } /** - * Represents a single scored result from a similarity query. - * Extends {@link GetResult} with a `similarity` score in the range [-1, 1]. + * Retrieval result with cosine similarity score. */ export interface QueryResult extends GetResult { + /** Similarity score. */ similarity: number; } diff --git a/src/utils/vectorMath.ts b/src/utils/vectorMath.ts index 0202ede..d209df8 100644 --- a/src/utils/vectorMath.ts +++ b/src/utils/vectorMath.ts @@ -1,11 +1,6 @@ /** - * Calculates the dot product of two vectors. - * The dot product is a scalar value that represents the sum of the products of corresponding components of two vectors. - * It is a measure of how much two vectors are in the same direction. - * @param a The first vector (array of numbers). - * @param b The second vector (array of numbers). - * @returns The dot product of the two vectors. - * @throws {Error} If the vectors are not of the same length. + * Returns the dot product of two equal-length vectors. + * @throws {Error} If vector lengths differ. */ export function dotProduct(a: number[], b: number[]): number { if (a.length !== b.length) { @@ -15,24 +10,15 @@ export function dotProduct(a: number[], b: number[]): number { } /** - * Calculates the Euclidean magnitude (or L2-norm) of a vector. - * The magnitude represents the length of the vector from the origin to its coordinates. - * @param a The vector (array of numbers). - * @returns The magnitude of the vector. + * Returns the Euclidean (L2) norm of `a`. */ export function magnitude(a: number[]): number { return Math.sqrt(a.reduce((sum, ai) => sum + ai * ai, 0)); } /** - * Calculates the cosine similarity between two vectors. - * Cosine similarity measures the cosine of the angle between two non-zero vectors. - * It is a measure of similarity between two vectors, ranging from -1 (opposite) to 1 (identical), - * with 0 indicating orthogonality (no similarity). - * Note: both vectors must have non-zero magnitude to avoid division by zero. - * @param a - The first vector. - * @param b - The second vector. - * @returns The cosine similarity between the two vectors. + * Returns cosine similarity of two vectors in [-1, 1]. + * Inputs must be non-zero vectors. */ export function cosine(a: number[], b: number[]): number { return dotProduct(a, b) / (magnitude(a) * magnitude(b)); diff --git a/src/vector_stores/memoryVectorStore.ts b/src/vector_stores/memoryVectorStore.ts index 214e219..6337792 100644 --- a/src/vector_stores/memoryVectorStore.ts +++ b/src/vector_stores/memoryVectorStore.ts @@ -19,7 +19,7 @@ import { cosine } from '../utils/vectorMath'; export class MemoryVectorStore implements VectorStore { private embeddings: Embeddings; private rows = new Map(); - private dim?: number; + private embeddingDim?: number; /** * Creates a new in-memory vector store. @@ -34,6 +34,7 @@ export class MemoryVectorStore implements VectorStore { */ public async load(): Promise { await this.embeddings.load(); + this.embeddingDim = (await this.embeddings.embed('dummy')).length; return this; } @@ -46,251 +47,139 @@ export class MemoryVectorStore implements VectorStore { } /** - * Adds one or more documents to the in-memory store. Generates IDs when not provided. + * Adds a document to the in-memory store. * @param params - Parameters for the operation. - * @param params.ids - Optional IDs for each document (must match `documents.length`). - * @param params.documents - Raw text content for each document. - * @param params.embeddings - Optional embeddings for each document. - * @param params.metadatas - Optional metadata for each document (aligned by index). - * @returns Promise that resolves to the IDs of the newly added documents. + * @param params.id - ID for the document. + * @param params.document - Raw text content for the document. + * @param params.embedding - Embeddings for the document. + * @param params.metadata - Metadata for the document. + * @returns Promise that resolves to the ID of the newly added document. */ public async add(params: { - ids?: string[]; - documents: string[]; - embeddings?: number[][]; - metadatas?: Record[]; - }): Promise { - const { embeddings, documents, metadatas } = params; - const ids = params.ids ?? documents.map(() => uuidv4()); - - const idsLength = ids.length; - this.assertLengthMatchIds(embeddings, idsLength); - this.assertLengthMatchIds(documents, idsLength); - this.assertLengthMatchIds(metadatas, idsLength); - - for (const id of ids) { - if (this.rows.has(id)) { - throw new Error(`id already exists: ${id}`); - } + id?: string; + document?: string; + embedding?: number[]; + metadata?: Record; + }): Promise { + const { id = uuidv4(), document, embedding, metadata } = params; + + if (!document && !embedding) { + throw new Error('document and embedding cannot be both undefined'); } - if (embeddings) { - for (const emb of embeddings) { - this.assertAndSetDim(emb); - } + if (embedding && embedding.length !== this.embeddingDim) { + throw new Error( + `embedding dimension ${embedding.length} does not match collection embedding dimension ${this.embeddingDim}` + ); } - for (let i = 0; i < idsLength; i++) { - this.rows.set(ids[i]!, { - id: ids[i]!, - document: documents[i]!, - embedding: embeddings - ? embeddings[i]! - : await this.embeddings.embed(documents[i]!), - metadata: metadatas ? metadatas[i]! : undefined, - }); + if (this.rows.has(id)) { + throw new Error(`id already exists: ${id}`); } - return ids; + this.rows.set(id, { + id, + document, + embedding: embedding ?? (await this.embeddings.embed(document!)), + metadata, + }); + + return id; } /** - * Updates one or more documents by ID. If `documents` are provided and - * `embeddings` are not, fresh embeddings are generated automatically. - * @param params - Parameters for the update. - * @param params.ids - IDs of the documents to update. - * @param params.embeddings - New embeddings (optional; aligned by index if provided). - * @param params.documents - New content (optional; aligned by index if provided). - * @param params.metadatas - New metadata (optional; aligned by index if provided). + * Updates a document by ID. + * Recomputes the embedding when `document` is provided and `embedding` is omitted. + * @param params - Update parameters. + * @param params.id - ID of the document to update. + * @param params.document - New content. + * @param params.embedding - New embedding. + * @param params.metadata - New metadata. * @returns Promise that resolves when the update completes. */ public async update(params: { - ids: string[]; - embeddings?: number[][]; - documents?: string[]; - metadatas?: Record[]; + id: string; + document?: string; + embedding?: number[]; + metadata?: Record; }): Promise { - const { ids, embeddings, documents, metadatas } = params; - - const n = ids.length; - this.assertLengthMatchIds(embeddings, n); - this.assertLengthMatchIds(documents, n); - this.assertLengthMatchIds(metadatas, n); + const { id, document, embedding, metadata } = params; - for (const id of ids) { - if (!this.rows.has(id)) { - throw new Error(`id not found: ${id}`); - } + if (embedding && embedding.length !== this.embeddingDim) { + throw new Error( + `embedding dimension ${embedding.length} does not match collection embedding dimension ${this.embeddingDim}` + ); } - if (embeddings) { - for (const emb of embeddings) { - this.assertAndSetDim(emb); - } + if (!this.rows.has(id)) { + throw new Error(`id not found: ${id}`); } - for (let i = 0; i < n; i++) { - const id = ids[i]!; - const row = this.rows.get(id)!; + const oldRow = this.rows.get(id)!; - this.rows.set(id, { - id, - document: documents ? documents[i]! : row.document, - embedding: embeddings - ? embeddings[i]! - : documents - ? await this.embeddings.embed(documents[i]!) - : row.embedding, - metadata: metadatas ? metadatas[i]! : row.metadata, - }); - } + this.rows.set(id, { + id, + document: document ?? oldRow.document, + embedding: + embedding ?? + (document ? await this.embeddings.embed(document!) : oldRow.embedding), + metadata: metadata ?? oldRow.metadata, + }); } /** - * Deletes documents by IDs and/or predicate. + * Deletes documents by predicate. * @param params - Parameters for deletion. - * @param params.ids - List of document IDs to delete. * @param params.predicate - Predicate to match documents for deletion. - * @returns Promise that resolves when deletion completes. + * @returns Promise that resolves once the documents are deleted. */ public async delete(params: { - ids?: string[]; - predicate?: (value: GetResult) => boolean; + predicate: (value: GetResult) => boolean; }): Promise { - const { ids, predicate } = params; - - if (ids && predicate) { - for (const id of ids) { - if (!this.rows.has(id)) { - throw new Error(`id not found: ${id}`); - } - } + const { predicate } = params; - for (const id of ids) { - const row = this.rows.get(id)!; - if (predicate(row)) { - this.rows.delete(id); - } - } - } else if (ids) { - for (const id of ids) { - if (!this.rows.has(id)) { - throw new Error(`id not found: ${id}`); - } - } - - for (const id of ids) { + for (const [id, row] of this.rows) { + if (predicate(row)) { this.rows.delete(id); } - } else if (predicate) { - for (const [id, row] of this.rows) { - if (predicate(row)) { - this.rows.delete(id); - } - } } } /** * Executes a cosine-similarity query over the in-memory vectors. - * Provide exactly one of `queryTexts` or `queryEmbeddings`. + * Provide exactly one of `queryText` or `queryEmbedding`. * @param params - Query parameters. - * @param params.queryTexts - Raw query strings to search for. - * @param params.queryEmbeddings - Precomputed query embeddings. + * @param params.queryText - Raw query string to search for. + * @param params.queryEmbedding - Precomputed query embedding. * @param params.nResults - Number of top results to return. - * @param params.ids - Restrict the search to these document IDs. * @param params.predicate - Function to filter results after retrieval. - * @returns Promise resolving to arrays of scored results for each query. + * @returns Promise that resolves to an array of {@link QueryResult}. */ public async query(params: { - queryTexts?: string[]; - queryEmbeddings?: number[][]; + queryText?: string; + queryEmbedding?: number[]; nResults?: number; - ids?: string[]; predicate?: (value: QueryResult) => boolean; - }): Promise { - const { - queryTexts, - queryEmbeddings, - nResults, - ids, - predicate = () => true, - } = params; - if (!queryTexts === !queryEmbeddings) { - throw new Error( - 'Exactly one of queryTexts or queryEmbeddings must be provided' - ); - } + }): Promise { + const { queryText, queryEmbedding, nResults, predicate } = params; - if (ids) { - for (const id of ids) { - if (!this.rows.has(id)) { - throw new Error(`id not found: ${id}`); - } - } - } - - const queries: number[][] = []; - - if (queryEmbeddings) { - for (const emb of queryEmbeddings) { - this.assertAndSetDim(emb); - queries.push(emb); - } - } else if (queryTexts) { - for (const text of queryTexts) { - const emb = await this.embeddings.embed(text); - queries.push(emb); - } - } - - const pool: GetResult[] = ids?.length - ? ids.map((id) => this.rows.get(id)!) - : Array.from(this.rows.values()); - - const result: QueryResult[][] = []; - - for (const q of queries) { - const scored = pool - .map( - (r) => - ({ - ...r, - similarity: cosine(q, r.embedding), - }) as QueryResult - ) - .filter(predicate) - .sort((a, b) => b.similarity - a.similarity) - .slice(0, nResults); - - result.push(scored); + if (!queryText && !queryEmbedding) { + throw new Error('queryText and queryEmbedding cannot be both undefined'); } - return result; - } - - /** - * Ensures all embeddings share the same dimensionality, setting it on first use. - */ - private assertAndSetDim(vec: number[]) { - if (!Array.isArray(vec) || vec.length === 0) { - throw new Error('embedding must be a non-empty vector'); - } - if (this.dim === undefined) { - this.dim = vec.length; - } else if (vec.length !== this.dim) { + if (queryEmbedding && queryEmbedding.length !== this.embeddingDim) { throw new Error( - `embedding dimension ${vec.length} does not match collection dimension ${this.dim}` + `queryEmbedding dimension ${queryEmbedding.length} does not match collection embedding dimension ${this.embeddingDim}` ); } - } - /** - * Verifies optional arrays match expected length. - */ - private assertLengthMatchIds(arr: T[] | undefined, idsLength: number) { - if (arr && arr.length !== idsLength) { - throw new Error('array length must match ids length'); - } + const searchEmbedding = + queryEmbedding ?? (await this.embeddings.embed(queryText!)); + + return Array.from(this.rows.values()) + .map((r) => ({ ...r, similarity: cosine(searchEmbedding, r.embedding) })) + .filter(predicate ?? (() => true)) + .sort((a, b) => b.similarity - a.similarity) + .slice(0, nResults); } }