diff --git a/adapters/cf/dev/specs/adapter.spec.ts b/adapters/cf/dev/specs/adapter.spec.ts index af7a47c..5afdb44 100644 --- a/adapters/cf/dev/specs/adapter.spec.ts +++ b/adapters/cf/dev/specs/adapter.spec.ts @@ -59,7 +59,7 @@ function createMockCloudflareBinding() { } }), - delete: vi.fn(async (ids: string[]) => { + deleteByIds: vi.fn(async (ids: string[]) => { for (const id of ids) { storage.delete(id) } @@ -79,6 +79,20 @@ function createMockCloudflareBinding() { } } +function createMockPayloadForEmbed(mockBinding: any) { + return { + config: { + custom: { + createVectorizedPayloadObject: () => ({ + getDbAdapterCustom: () => ({ _vectorizeBinding: mockBinding }), + }), + }, + }, + create: vi.fn().mockResolvedValue({ id: 'mapping-1' }), + logger: { error: vi.fn() }, + } as any +} + describe('createCloudflareVectorizeIntegration', () => { describe('validation', () => { test('should throw if vectorize binding is missing', () => { @@ -91,7 +105,7 @@ describe('createCloudflareVectorizeIntegration', () => { }) test('should create integration with valid config', () => { - const mockVectorize = { query: vi.fn(), upsert: vi.fn(), delete: vi.fn() } + const mockVectorize = { query: vi.fn(), upsert: vi.fn(), deleteByIds: vi.fn() } const integration = createCloudflareVectorizeIntegration({ config: { default: { dims: 384 } }, @@ -110,7 +124,7 @@ describe('createCloudflareVectorizeIntegration', () => { describe('getConfigExtension', () => { test('should return config with pool configurations', () => { const poolConfigs = { mainPool: { dims: 384 }, secondaryPool: { dims: 768 } } - const mockVectorize = { query: vi.fn() } + const mockVectorize = { query: vi.fn(), upsert: vi.fn(), deleteByIds: vi.fn() } const { adapter } = createCloudflareVectorizeIntegration({ config: poolConfigs, @@ -121,6 +135,20 @@ describe('createCloudflareVectorizeIntegration', () => { expect(extension.custom?._cfVectorizeAdapter).toBe(true) expect(extension.custom?._poolConfigs).toEqual(poolConfigs) }) + + 
test('should return collections with cfMappings', () => { + const mockVectorize = { query: vi.fn(), upsert: vi.fn(), deleteByIds: vi.fn() } + + const { adapter } = createCloudflareVectorizeIntegration({ + config: { default: { dims: 384 } }, + binding: mockVectorize, + }) + const extension = adapter.getConfigExtension({} as any) + + expect(extension.collections).toBeDefined() + expect(extension.collections!['vector-cf-mappings']).toBeDefined() + expect(extension.collections!['vector-cf-mappings'].slug).toBe('vector-cf-mappings') + }) }) describe('storeEmbedding', () => { @@ -132,9 +160,16 @@ describe('createCloudflareVectorizeIntegration', () => { }) const embedding = new Float32Array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]) - const mockPayload = { context: {} } as any + const mockPayload = createMockPayloadForEmbed(mockBinding) - await adapter.storeEmbedding(mockPayload, 'default', 'test-id', embedding) + await adapter.storeEmbedding( + mockPayload, + 'default', + 'test-collection', + 'doc-1', + 'test-id', + embedding, + ) expect(mockBinding.upsert).toHaveBeenCalledWith([ { @@ -144,69 +179,121 @@ describe('createCloudflareVectorizeIntegration', () => { ]) }) - test('should inject vectorize binding into context', async () => { + test('should create a mapping row', async () => { const mockBinding = createMockCloudflareBinding() const { adapter } = createCloudflareVectorizeIntegration({ config: { default: { dims: 8 } }, binding: mockBinding as any, }) - const mockPayload = { context: {} } as any + const mockPayload = createMockPayloadForEmbed(mockBinding) const embedding = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8] - await adapter.storeEmbedding(mockPayload, 'default', 'test-id', embedding) - - expect(mockPayload.context.vectorize).toBe(mockBinding) + await adapter.storeEmbedding( + mockPayload, + 'default', + 'test-collection', + 'doc-1', + 'test-id', + embedding, + ) + + expect(mockPayload.create).toHaveBeenCalledWith({ + collection: 'vector-cf-mappings', + data: { + 
vectorId: 'test-id', + poolName: 'default', + sourceCollection: 'test-collection', + docId: 'doc-1', + }, + }) }) }) describe('deleteEmbeddings', () => { - test('should query with correct where clause', async () => { + test('should look up mappings with correct where clause', async () => { const mockBinding = createMockCloudflareBinding() const { adapter } = createCloudflareVectorizeIntegration({ config: { default: { dims: 8 } }, binding: mockBinding as any, }) - const mockPayload = { context: {}, logger: { error: vi.fn() } } as any + const mockPayload = { + find: vi.fn().mockResolvedValue({ docs: [], hasNextPage: false }), + delete: vi.fn().mockResolvedValue({}), + logger: { error: vi.fn() }, + } as any await adapter.deleteEmbeddings?.(mockPayload, 'default', 'test-collection', 'doc-123') - expect(mockBinding.query).toHaveBeenCalled() - const queryCall = mockBinding.query.mock.calls[0] - const options = queryCall[1] - - expect(options.where?.and).toEqual([ - { key: 'sourceCollection', value: 'test-collection' }, - { key: 'docId', value: 'doc-123' }, - ]) + expect(mockPayload.find).toHaveBeenCalledWith( + expect.objectContaining({ + collection: 'vector-cf-mappings', + where: { + and: [ + { poolName: { equals: 'default' } }, + { sourceCollection: { equals: 'test-collection' } }, + { docId: { equals: 'doc-123' } }, + ], + }, + }), + ) }) - test('should delete matching vectors', async () => { + test('should delete matching vectors via mappings', async () => { const mockBinding = createMockCloudflareBinding() const { adapter } = createCloudflareVectorizeIntegration({ config: { default: { dims: 8 } }, binding: mockBinding as any, }) - // Manually add some vectors to the mock storage - const storage = mockBinding.__getStorage() - storage.set('vec-1', { - id: 'vec-1', - values: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], - metadata: { sourceCollection: 'test-collection', docId: 'doc-123' }, - }) - storage.set('vec-2', { - id: 'vec-2', - values: [0.2, 0.3, 0.4, 0.5, 0.6, 
0.7, 0.8, 0.9], - metadata: { sourceCollection: 'test-collection', docId: 'doc-123' }, + const mockPayload = { + find: vi.fn().mockResolvedValue({ + docs: [ + { id: 'map-1', vectorId: 'vec-1' }, + { id: 'map-2', vectorId: 'vec-2' }, + ], + hasNextPage: false, + }), + delete: vi.fn().mockResolvedValue({}), + logger: { error: vi.fn() }, + } as any + + await adapter.deleteEmbeddings?.(mockPayload, 'default', 'test-collection', 'doc-123') + + expect(mockBinding.deleteByIds).toHaveBeenCalledWith(['vec-1', 'vec-2']) + }) + + test('should clean up mapping rows after deleting vectors', async () => { + const mockBinding = createMockCloudflareBinding() + const { adapter } = createCloudflareVectorizeIntegration({ + config: { default: { dims: 8 } }, + binding: mockBinding as any, }) - const mockPayload = { context: {}, logger: { error: vi.fn() } } as any + const mockPayload = { + find: vi.fn().mockResolvedValue({ + docs: [{ id: 'map-1', vectorId: 'vec-1' }], + hasNextPage: false, + }), + delete: vi.fn().mockResolvedValue({}), + logger: { error: vi.fn() }, + } as any await adapter.deleteEmbeddings?.(mockPayload, 'default', 'test-collection', 'doc-123') - expect(mockBinding.delete).toHaveBeenCalledWith(['vec-1', 'vec-2']) + expect(mockPayload.delete).toHaveBeenCalledWith( + expect.objectContaining({ + collection: 'vector-cf-mappings', + where: { + and: [ + { poolName: { equals: 'default' } }, + { sourceCollection: { equals: 'test-collection' } }, + { docId: { equals: 'doc-123' } }, + ], + }, + }), + ) }) test('should handle empty results gracefully', async () => { @@ -216,24 +303,26 @@ describe('createCloudflareVectorizeIntegration', () => { binding: mockBinding as any, }) - const mockPayload = { context: {}, logger: { error: vi.fn() } } as any + const mockPayload = { + find: vi.fn().mockResolvedValue({ docs: [], hasNextPage: false }), + delete: vi.fn().mockResolvedValue({}), + logger: { error: vi.fn() }, + } as any await adapter.deleteEmbeddings?.(mockPayload, 'default', 
'test-collection', 'doc-123') - expect(mockBinding.delete).not.toHaveBeenCalled() + expect(mockBinding.deleteByIds).not.toHaveBeenCalled() }) - test('should handle query errors', async () => { + test('should handle errors', async () => { const mockBinding = createMockCloudflareBinding() - mockBinding.query = vi.fn().mockRejectedValue(new Error('Query failed')) - const { adapter } = createCloudflareVectorizeIntegration({ config: { default: { dims: 8 } }, binding: mockBinding as any, }) const mockPayload = { - context: {}, + find: vi.fn().mockRejectedValue(new Error('Query failed')), logger: { error: vi.fn() }, } as any diff --git a/adapters/cf/dev/specs/compliance.spec.ts b/adapters/cf/dev/specs/compliance.spec.ts index 1158440..2fae60b 100644 --- a/adapters/cf/dev/specs/compliance.spec.ts +++ b/adapters/cf/dev/specs/compliance.spec.ts @@ -73,7 +73,7 @@ function createMockVectorizeBinding() { } }), - delete: vi.fn(async (ids: string[]) => { + deleteByIds: vi.fn(async (ids: string[]) => { for (const id of ids) { storage.delete(id) } @@ -177,6 +177,14 @@ describe('Cloudflare Adapter Compliance Tests', () => { expect(extension.custom!._poolConfigs.default).toBeDefined() expect(extension.custom!._poolConfigs.default.dims).toBe(DIMS) }) + + test('collections property contains cfMappings collection', () => { + const extension = adapter.getConfigExtension({} as any) + + expect(extension.collections).toBeDefined() + expect(extension.collections!['vector-cf-mappings']).toBeDefined() + expect(extension.collections!['vector-cf-mappings'].slug).toBe('vector-cf-mappings') + }) }) describe('storeEmbedding()', () => { @@ -185,12 +193,14 @@ describe('Cloudflare Adapter Compliance Tests', () => { .fill(0) .map(() => Math.random()) + const sourceDocId = `test-embed-1-${Date.now()}` + // Create a document first const doc = await payload.create({ collection: 'default' as any, data: { sourceCollection: 'test-collection', - docId: `test-embed-1-${Date.now()}`, + docId: sourceDocId, 
chunkIndex: 0, chunkText: 'test text for embedding', embeddingVersion: 'v1-test', @@ -198,7 +208,14 @@ describe('Cloudflare Adapter Compliance Tests', () => { }) await expect( - adapter.storeEmbedding(payload, 'default', String(doc.id), embedding), + adapter.storeEmbedding( + payload, + 'default', + 'test-collection', + sourceDocId, + String(doc.id), + embedding, + ), ).resolves.not.toThrow() expect(mockVectorize.upsert).toHaveBeenCalled() @@ -211,11 +228,13 @@ describe('Cloudflare Adapter Compliance Tests', () => { .map(() => Math.random()), ) + const sourceDocId = `test-embed-2-${Date.now()}` + const doc = await payload.create({ collection: 'default' as any, data: { sourceCollection: 'test-collection', - docId: `test-embed-2-${Date.now()}`, + docId: sourceDocId, chunkIndex: 0, chunkText: 'test text for Float32Array', embeddingVersion: 'v1-test', @@ -223,7 +242,14 @@ describe('Cloudflare Adapter Compliance Tests', () => { }) await expect( - adapter.storeEmbedding(payload, 'default', String(doc.id), embedding), + adapter.storeEmbedding( + payload, + 'default', + 'test-collection', + sourceDocId, + String(doc.id), + embedding, + ), ).resolves.not.toThrow() expect(mockVectorize.upsert).toHaveBeenCalled() @@ -232,23 +258,32 @@ describe('Cloudflare Adapter Compliance Tests', () => { test('stores embedding in Vectorize with correct ID', async () => { const embedding = Array(DIMS).fill(0.5) + const sourceDocId = `test-embed-id-${Date.now()}` + const doc = await payload.create({ collection: 'default' as any, data: { sourceCollection: 'test-collection', - docId: `test-embed-id-${Date.now()}`, + docId: sourceDocId, chunkIndex: 0, chunkText: 'test text', embeddingVersion: 'v1-test', }, }) - const docId = String(doc.id) - await adapter.storeEmbedding(payload, 'default', docId, embedding) + const embeddingId = String(doc.id) + await adapter.storeEmbedding( + payload, + 'default', + 'test-collection', + sourceDocId, + embeddingId, + embedding, + ) const storage = 
mockVectorize.__getStorage() - expect(storage.has(docId)).toBe(true) - expect(storage.get(docId)?.values).toEqual(embedding) + expect(storage.has(embeddingId)).toBe(true) + expect(storage.get(embeddingId)?.values).toEqual(embedding) }) }) @@ -263,19 +298,28 @@ describe('Cloudflare Adapter Compliance Tests', () => { .fill(0.5) .map((v) => v + Math.random() * 0.05) + const sourceDocId = `test-search-similar-${Date.now()}` + // Create and embed a document const similarDoc = await payload.create({ collection: 'default' as any, data: { sourceCollection: 'test-collection', - docId: `test-search-similar-${Date.now()}`, + docId: sourceDocId, chunkIndex: 0, chunkText: 'similar document for search test', embeddingVersion: 'v1-test', }, }) similarDocId = String(similarDoc.id) - await adapter.storeEmbedding(payload, 'default', similarDocId, similarEmbedding) + await adapter.storeEmbedding( + payload, + 'default', + 'test-collection', + sourceDocId, + similarDocId, + similarEmbedding, + ) }) test('returns an array of results', async () => { @@ -328,33 +372,55 @@ describe('Cloudflare Adapter Compliance Tests', () => { }) describe('deleteEmbeddings()', () => { - test('removes embeddings from Vectorize', async () => { + test('removes embeddings from Vectorize via mapping', async () => { const embedding = Array(DIMS).fill(0.7) + const sourceDocId = `doc-to-delete-${Date.now()}` + // Create and embed a document const doc = await payload.create({ collection: 'default' as any, data: { sourceCollection: 'delete-test', - docId: `doc-to-delete-${Date.now()}`, + docId: sourceDocId, chunkIndex: 0, chunkText: 'document to delete', embeddingVersion: 'v1-test', }, }) - const docId = String(doc.id) - await adapter.storeEmbedding(payload, 'default', docId, embedding) + const embeddingId = String(doc.id) + await adapter.storeEmbedding( + payload, + 'default', + 'delete-test', + sourceDocId, + embeddingId, + embedding, + ) - // Verify it's stored + // Verify it's stored in Vectorize const storage 
= mockVectorize.__getStorage() - expect(storage.has(docId)).toBe(true) + expect(storage.has(embeddingId)).toBe(true) // Delete it - await adapter.deleteEmbeddings?.(payload, 'default', 'delete-test', docId) - - // Verify it's gone - expect(mockVectorize.delete).toHaveBeenCalledWith([docId]) + await adapter.deleteEmbeddings?.(payload, 'default', 'delete-test', sourceDocId) + + // Verify deleteByIds was called with the correct vector ID + expect(mockVectorize.deleteByIds).toHaveBeenCalledWith([embeddingId]) + + // Verify mapping rows are cleaned up + const remainingMappings = await payload.find({ + collection: 'vector-cf-mappings' as any, + where: { + and: [ + { poolName: { equals: 'default' } }, + { sourceCollection: { equals: 'delete-test' } }, + { docId: { equals: sourceDocId } }, + ], + }, + }) + expect(remainingMappings.totalDocs).toBe(0) }) test('handles non-existent embeddings gracefully', async () => { diff --git a/adapters/cf/dev/specs/constants.ts b/adapters/cf/dev/specs/constants.ts index 41483a4..5b48946 100644 --- a/adapters/cf/dev/specs/constants.ts +++ b/adapters/cf/dev/specs/constants.ts @@ -57,7 +57,7 @@ export function createMockVectorizeBinding() { } }, - delete: async (ids: string[]) => { + deleteByIds: async (ids: string[]) => { for (const id of ids) { storage.delete(id) } diff --git a/adapters/cf/src/collections/cfMappings.ts b/adapters/cf/src/collections/cfMappings.ts new file mode 100644 index 0000000..ea9743b --- /dev/null +++ b/adapters/cf/src/collections/cfMappings.ts @@ -0,0 +1,49 @@ +import type { CollectionConfig } from 'payload' + +export const CF_MAPPINGS_SLUG = 'vector-cf-mappings' + +// This collection maps Cloudflare Vectorize vector IDs to source documents, +// so we can find and delete vectors when the source document is deleted. +const CFMappingsCollection: CollectionConfig = { + slug: CF_MAPPINGS_SLUG, + admin: { + hidden: true, + description: + 'Maps Cloudflare Vectorize vector IDs to source documents. 
Managed by the CF adapter.', + }, + access: { + read: () => true, + create: ({ req }) => req?.payloadAPI === 'local', + update: ({ req }) => req?.payloadAPI === 'local', + delete: ({ req }) => req?.payloadAPI === 'local', + }, + fields: [ + { + name: 'vectorId', + type: 'text', + required: true, + index: true, + }, + { + name: 'poolName', + type: 'text', + required: true, + index: true, + }, + { + name: 'sourceCollection', + type: 'text', + required: true, + index: true, + }, + { + name: 'docId', + type: 'text', + required: true, + index: true, + }, + ], + timestamps: true, +} + +export default CFMappingsCollection diff --git a/adapters/cf/src/embed.ts b/adapters/cf/src/embed.ts index 363f7c0..b77f91a 100644 --- a/adapters/cf/src/embed.ts +++ b/adapters/cf/src/embed.ts @@ -1,6 +1,7 @@ -import { Payload } from 'payload' +import { CollectionSlug, Payload } from 'payload' import { getVectorizedPayload } from 'payloadcms-vectorize' import type { CloudflareVectorizeBinding } from './types.js' +import { CF_MAPPINGS_SLUG } from './collections/cfMappings.js' /** * Store an embedding vector in Cloudflare Vectorize @@ -8,13 +9,14 @@ import type { CloudflareVectorizeBinding } from './types.js' export default async ( payload: Payload, poolName: string, + sourceCollection: string, + sourceDocId: string, id: string, embedding: number[] | Float32Array, ) => { // Get Cloudflare binding from config - const vectorizeBinding = getVectorizedPayload(payload)?.getDbAdapterCustom()?._vectorizeBinding as - | CloudflareVectorizeBinding - | undefined + const vectorizeBinding = getVectorizedPayload(payload)?.getDbAdapterCustom() + ?._vectorizeBinding as CloudflareVectorizeBinding | undefined if (!vectorizeBinding) { throw new Error('[@payloadcms-vectorize/cf] Cloudflare Vectorize binding not found') } @@ -29,6 +31,17 @@ export default async ( values: vector, }, ]) + + // Create a mapping row so we can find this vector during deletion + await payload.create({ + collection: CF_MAPPINGS_SLUG as 
CollectionSlug, + data: { + vectorId: id, + poolName, + sourceCollection, + docId: sourceDocId, + }, + }) } catch (e) { const errorMessage = e instanceof Error ? e.message : String(e) payload.logger.error(`[@payloadcms-vectorize/cf] Failed to store embedding: ${errorMessage}`) diff --git a/adapters/cf/src/index.ts b/adapters/cf/src/index.ts index e7f5cd7..a7c5b31 100644 --- a/adapters/cf/src/index.ts +++ b/adapters/cf/src/index.ts @@ -1,5 +1,7 @@ +import type { CollectionSlug } from 'payload' import type { DbAdapter } from 'payloadcms-vectorize' import type { CloudflareVectorizeBinding, KnowledgePoolsConfig } from './types.js' +import cfMappingsCollection, { CF_MAPPINGS_SLUG } from './collections/cfMappings.js' import embed from './embed.js' import search from './search.js' @@ -45,6 +47,9 @@ export const createCloudflareVectorizeIntegration = ( const adapter: DbAdapter = { getConfigExtension: () => { return { + collections: { + [CF_MAPPINGS_SLUG]: cfMappingsCollection, + }, custom: { _cfVectorizeAdapter: true, _poolConfigs: poolConfig, @@ -57,32 +62,56 @@ export const createCloudflareVectorizeIntegration = ( return search(payload, queryEmbedding, poolName, limit, where) }, - storeEmbedding: async (payload, poolName, id, embedding) => { - return embed(payload, poolName, id, embedding) + storeEmbedding: async (payload, poolName, sourceCollection, sourceDocId, id, embedding) => { + return embed(payload, poolName, sourceCollection, sourceDocId, id, embedding) }, deleteEmbeddings: async (payload, poolName, sourceCollection, docId) => { - // Delete all embeddings for this document from Cloudflare Vectorize - // First, query to find all matching IDs const vectorizeBinding = options.binding - const dims = poolConfig[poolName]?.dims || 384 + try { - const results = await vectorizeBinding.query(new Array(dims).fill(0), { - topK: 100, - returnMetadata: 'indexed', + // Paginate through all mapping rows for this document+pool + const allVectorIds: string[] = [] + let page = 1 + 
let hasNextPage = true + + while (hasNextPage) { + const mappings = await payload.find({ + collection: CF_MAPPINGS_SLUG as CollectionSlug, + where: { + and: [ + { poolName: { equals: poolName } }, + { sourceCollection: { equals: sourceCollection } }, + { docId: { equals: docId } }, + ], + }, + page, + }) + + for (const mapping of mappings.docs) { + allVectorIds.push((mapping as Record<string, unknown>).vectorId as string) + } + + hasNextPage = mappings.hasNextPage + page++ + } + + if (allVectorIds.length === 0) { + return + } + // Delete vectors from Cloudflare Vectorize + await vectorizeBinding.deleteByIds(allVectorIds) + // Delete mapping rows + await payload.delete({ + collection: CF_MAPPINGS_SLUG as CollectionSlug, + where: { and: [ - { key: 'sourceCollection', value: sourceCollection }, - { key: 'docId', value: docId }, + { poolName: { equals: poolName } }, + { sourceCollection: { equals: sourceCollection } }, + { docId: { equals: docId } }, ], }, }) - - const idsToDelete = (results.matches || []).map((match) => match.id) - - if (idsToDelete.length > 0) { - await vectorizeBinding.deleteByIds(idsToDelete) - } } catch (error) { const errorMessage = error instanceof Error ?
error.message : String(error) payload.logger.error( @@ -96,5 +125,6 @@ export const createCloudflareVectorizeIntegration = ( return { adapter } } +export { CF_MAPPINGS_SLUG } from './collections/cfMappings.js' export type { CloudflareVectorizeBinding, KnowledgePoolsConfig } export type { KnowledgePoolsConfig as KnowledgePoolConfig } diff --git a/adapters/cf/src/search.ts b/adapters/cf/src/search.ts index 3c39fff..c27c55d 100644 --- a/adapters/cf/src/search.ts +++ b/adapters/cf/src/search.ts @@ -32,30 +32,38 @@ export default async ( return [] } - // Fetch full documents from Payload for metadata - const searchResults: VectorSearchResult[] = [] + // Batch-fetch all matched documents, paginating through results + const matchIds = results.matches.map((m) => m.id) + const scoreById = new Map(results.matches.map((m) => [m.id, m.score || 0])) - for (const match of results.matches) { - try { - const doc = await payload.findByID({ - collection: poolName as CollectionSlug, - id: match.id, - }) + const docsById = new Map<string, Record<string, unknown>>() + let page = 1 + let hasNextPage = true + while (hasNextPage) { + const found = await payload.find({ + collection: poolName as CollectionSlug, + where: { id: { in: matchIds } }, + page, + }) + for (const doc of found.docs as Record<string, unknown>[]) { + docsById.set(String(doc.id), doc) + } + hasNextPage = found.hasNextPage + page++ + } - if (doc && (!where || matchesWhere(doc as Record<string, unknown>, where))) { - // Extract fields excluding internal ones - const { id: _id, createdAt: _createdAt, updatedAt: _updatedAt, ...docFields } = - doc as Record<string, unknown> + // Build results preserving the original similarity-score order + const searchResults: VectorSearchResult[] = [] + for (const matchId of matchIds) { + const doc = docsById.get(matchId) + if (!doc || (where && !matchesWhere(doc, where))) continue - searchResults.push({ - id: match.id, - score: match.score || 0, - ...docFields, // Includes sourceCollection, docId, chunkText, embeddingVersion, extension fields - } as VectorSearchResult)
- } - } catch (_e) { - // Document not found or error fetching, skip - } + const { id: _id, createdAt: _createdAt, updatedAt: _updatedAt, ...docFields } = doc + searchResults.push({ + id: matchId, + score: scoreById.get(matchId) || 0, + ...docFields, + } as VectorSearchResult) } return searchResults diff --git a/adapters/pg/src/embed.ts b/adapters/pg/src/embed.ts index 254ed3a..4c7505b 100644 --- a/adapters/pg/src/embed.ts +++ b/adapters/pg/src/embed.ts @@ -5,6 +5,8 @@ import toSnakeCase from 'to-snake-case' export default async ( payload: Payload, poolName: string, + _sourceCollection: string, + _sourceDocId: string, id: string, embedding: number[] | Float32Array, ) => { diff --git a/dev/helpers/adapterComplianceTests.ts b/dev/helpers/adapterComplianceTests.ts index 538035e..b18a699 100644 --- a/dev/helpers/adapterComplianceTests.ts +++ b/dev/helpers/adapterComplianceTests.ts @@ -119,13 +119,14 @@ export const runAdapterComplianceTests = (getContext: AdapterTestContextFactory) describe('storeEmbedding()', () => { test('persists embedding without error (number[])', async () => { const embedding = generateRandomEmbedding(ctx.dims) + const sourceDocId = `test-embed-1-${Date.now()}` // Create a document first const doc = await ctx.payload.create({ collection: ctx.poolName as any, data: { sourceCollection: 'test-collection', - docId: `test-embed-1-${Date.now()}`, + docId: sourceDocId, chunkIndex: 0, chunkText: 'test text for embedding', embeddingVersion: 'v1-test', @@ -133,18 +134,20 @@ export const runAdapterComplianceTests = (getContext: AdapterTestContextFactory) }) await expect( - ctx.adapter.storeEmbedding(ctx.payload, ctx.poolName, String(doc.id), embedding), + ctx.adapter.storeEmbedding(ctx.payload, ctx.poolName, 'test-collection', sourceDocId, String(doc.id), embedding), ).resolves.not.toThrow() }) test('persists embedding without error (Float32Array)', async () => { const embedding = new Float32Array(generateRandomEmbedding(ctx.dims)) + const sourceDocId = 
`test-embed-2-${Date.now()}` + const doc = await ctx.payload.create({ collection: ctx.poolName as any, data: { sourceCollection: 'test-collection', - docId: `test-embed-2-${Date.now()}`, + docId: sourceDocId, chunkIndex: 0, chunkText: 'test text for Float32Array', embeddingVersion: 'v1-test', @@ -152,7 +155,7 @@ export const runAdapterComplianceTests = (getContext: AdapterTestContextFactory) }) await expect( - ctx.adapter.storeEmbedding(ctx.payload, ctx.poolName, String(doc.id), embedding), + ctx.adapter.storeEmbedding(ctx.payload, ctx.poolName, 'test-collection', sourceDocId, String(doc.id), embedding), ).resolves.not.toThrow() }) }) @@ -169,25 +172,27 @@ export const runAdapterComplianceTests = (getContext: AdapterTestContextFactory) const differentEmbedding = generateDifferentEmbedding(ctx.dims) // Create similar document + const similarSourceDocId = `test-search-similar-${Date.now()}` const similarDoc = await ctx.payload.create({ collection: ctx.poolName as any, data: { sourceCollection: 'test-collection', - docId: `test-search-similar-${Date.now()}`, + docId: similarSourceDocId, chunkIndex: 0, chunkText: 'similar document', embeddingVersion: 'v1-test', }, }) similarDocId = String(similarDoc.id) - await ctx.adapter.storeEmbedding(ctx.payload, ctx.poolName, similarDocId, similarEmbedding) + await ctx.adapter.storeEmbedding(ctx.payload, ctx.poolName, 'test-collection', similarSourceDocId, similarDocId, similarEmbedding) // Create different document + const differentSourceDocId = `test-search-different-${Date.now()}` const differentDoc = await ctx.payload.create({ collection: ctx.poolName as any, data: { sourceCollection: 'test-collection', - docId: `test-search-different-${Date.now()}`, + docId: differentSourceDocId, chunkIndex: 0, chunkText: 'different document', embeddingVersion: 'v1-test', @@ -197,6 +202,8 @@ export const runAdapterComplianceTests = (getContext: AdapterTestContextFactory) await ctx.adapter.storeEmbedding( ctx.payload, ctx.poolName, + 
'test-collection', + differentSourceDocId, differentDocId, differentEmbedding, ) }) diff --git a/dev/helpers/mockAdapter.ts b/dev/helpers/mockAdapter.ts index b018f9b..3499457 100644 --- a/dev/helpers/mockAdapter.ts +++ b/dev/helpers/mockAdapter.ts @@ -51,6 +51,8 @@ export const createMockAdapter = (options: MockAdapterOptions = {}): DbAdapter = storeEmbedding: async ( _payload: Payload, poolName: KnowledgePoolName, + _sourceCollection: string, + _sourceDocId: string, id: string, embedding: number[] | Float32Array, ): Promise<void> => { diff --git a/src/index.ts b/src/index.ts index 23ffaa4..d9a95e1 100644 --- a/src/index.ts +++ b/src/index.ts @@ -393,6 +393,15 @@ export default (pluginOptions: PayloadcmsVectorizeConfig) => config.bin = [...(config.bin || []), ...configExtension.bins] } + // Register adapter-provided collections + if (configExtension?.collections) { + for (const [_slug, collectionConfig] of Object.entries(configExtension.collections)) { + if (!config.collections!.find((c) => c.slug === collectionConfig.slug)) { + config.collections!.push(collectionConfig) + } + } + } + return config } diff --git a/src/tasks/bulkEmbedAll.ts b/src/tasks/bulkEmbedAll.ts index 189e53b..7642591 100644 --- a/src/tasks/bulkEmbedAll.ts +++ b/src/tasks/bulkEmbedAll.ts @@ -884,6 +884,8 @@ async function pollAndCompleteSingleBatch(args: { await adapter.storeEmbedding( payload, poolName, + meta.sourceCollection, + String(meta.docId), String(created.id), embeddingArray, ) diff --git a/src/tasks/vectorize.ts b/src/tasks/vectorize.ts index dc7597c..4e8b843 100644 --- a/src/tasks/vectorize.ts +++ b/src/tasks/vectorize.ts @@ -162,7 +162,7 @@ async function runVectorizeTask(args: { const id = String(created.id) - await adapter.storeEmbedding(payload, poolName, id, vector) + await adapter.storeEmbedding(payload, poolName, collection, String(sourceDoc.id), id, vector) }), ) } diff --git a/src/types.ts b/src/types.ts index bc87450..a43d370 100644 --- a/src/types.ts +++ b/src/types.ts @@ -1,4
+1,13 @@ -import type { CollectionSlug, Payload, Field, Where, Config, BasePayload, TypeWithID } from 'payload' +import type { + CollectionConfig, + CollectionSlug, + Payload, + Field, + Where, + Config, + BasePayload, + TypeWithID, +} from 'payload' /** Result from bulkEmbed method */ export type BulkEmbedResult = @@ -365,6 +374,7 @@ export type DbAdapter = { getConfigExtension: (payloadCmsConfig: Config) => { bins?: { key: string; scriptPath: string }[] custom?: Record<string, unknown> + collections?: Record<string, CollectionConfig> } search: ( payload: BasePayload, @@ -376,7 +386,9 @@ storeEmbedding: ( payload: Payload, poolName: KnowledgePoolName, - id: string, + sourceCollection: string, + sourceDocId: string, + embeddingId: string, embedding: number[] | Float32Array, ) => Promise<void> /**