diff --git a/src/store/src/Bridge/ChromaDb/Store.php b/src/store/src/Bridge/ChromaDb/Store.php index 732e3c35c..45ff3f933 100644 --- a/src/store/src/Bridge/ChromaDb/Store.php +++ b/src/store/src/Bridge/ChromaDb/Store.php @@ -53,23 +53,40 @@ public function add(VectorDocument ...$documents): void } /** - * @param array{where?: array, whereDocument?: array} $options + * @param array{where?: array, whereDocument?: array, include?: array} $options */ public function query(Vector $vector, array $options = []): iterable { + $include = null; + if ([] !== ($options['include'] ?? [])) { + $include = array_values( + array_unique( + array_merge(['embeddings', 'metadatas', 'distances'], $options['include']) + ) + ); + } + $collection = $this->client->getOrCreateCollection($this->collectionName); $queryResponse = $collection->query( queryEmbeddings: [$vector->getData()], nResults: 4, where: $options['where'] ?? null, whereDocument: $options['whereDocument'] ?? null, + include: $include, ); - for ($i = 0; $i < \count($queryResponse->metadatas[0]); ++$i) { + $metaCount = \count($queryResponse->metadatas[0]); + + for ($i = 0; $i < $metaCount; ++$i) { + $metaData = new Metadata($queryResponse->metadatas[0][$i]); + if (isset($queryResponse->documents[0][$i])) { + $metaData->setText($queryResponse->documents[0][$i]); + } + yield new VectorDocument( id: Uuid::fromString($queryResponse->ids[0][$i]), vector: new Vector($queryResponse->embeddings[0][$i]), - metadata: new Metadata($queryResponse->metadatas[0][$i]), + metadata: $metaData, score: $queryResponse->distances[0][$i] ?? null, ); } diff --git a/src/store/tests/Bridge/ChromaDb/StoreTest.php b/src/store/tests/Bridge/ChromaDb/StoreTest.php index 55a291975..5cb8d9944 100644 --- a/src/store/tests/Bridge/ChromaDb/StoreTest.php +++ b/src/store/tests/Bridge/ChromaDb/StoreTest.php @@ -469,6 +469,108 @@ public function testQueryWithVariousFilterCombinations( $this->assertCount(1, $documents); } + public function testQueryReturnsMetadatasEmbeddingsDistanceWithoutInclude() + { + $queryVector = new Vector([0.15, 0.25, 0.35]); + $queryResponse = new QueryItemsResponse( + ids: [['01234567-89ab-cdef-0123-456789abcdef']], + embeddings: [[[0.1, 0.2, 0.3]]], + metadatas: [[['title' => 'Doc 1']]], + documents: null, + data: null, + uris: null, + distances: null + ); + + $collection = $this->createMock(CollectionResource::class); + $client = $this->createMock(Client::class); + + $client->expects($this->once()) + ->method('getOrCreateCollection') + ->with('test-collection') + ->willReturn($collection); + + $collection->expects($this->once()) + ->method('query') + ->willReturn($queryResponse); + + $store = new Store($client, 'test-collection'); + $documents = iterator_to_array($store->query($queryVector)); + + $this->assertCount(1, $documents); + $this->assertSame('01234567-89ab-cdef-0123-456789abcdef', (string) $documents[0]->id); + $this->assertSame([0.1, 0.2, 0.3], $documents[0]->vector->getData()); + $this->assertSame(['title' => 'Doc 1'], $documents[0]->metadata->getArrayCopy()); + } + + public function testQueryReturnsMetadatasEmbeddingsDistanceWithOnlyDocuments() + { + $queryVector = new Vector([0.15, 0.25, 0.35]); + $queryResponse = new QueryItemsResponse( + ids: [['01234567-89ab-cdef-0123-456789abcdef']], + embeddings: [[[0.1, 0.2, 0.3]]], + metadatas: [[['title' => 'Doc 1']]], + documents: [['Document content here']], + data: null, + uris: null, + distances: null + ); + + $collection = $this->createMock(CollectionResource::class); + $client = $this->createMock(Client::class); + + $client->expects($this->once()) + ->method('getOrCreateCollection') + ->with('test-collection') + ->willReturn($collection); + + $collection->expects($this->once()) + ->method('query') + ->willReturn($queryResponse); + + $store = new Store($client, 'test-collection'); + $documents = iterator_to_array($store->query($queryVector, ['include' => ['documents']])); + + $this->assertCount(1, $documents); + $this->assertSame('01234567-89ab-cdef-0123-456789abcdef', (string) $documents[0]->id); + $this->assertSame([0.1, 0.2, 0.3], $documents[0]->vector->getData()); + $this->assertSame(['title' => 'Doc 1', '_text' => 'Document content here'], $documents[0]->metadata->getArrayCopy()); + } + + public function testQueryReturnsMetadatasEmbeddingsDistanceWithAll() + { + $queryVector = new Vector([0.15, 0.25, 0.35]); + $queryResponse = new QueryItemsResponse( + ids: [['01234567-89ab-cdef-0123-456789abcdef']], + embeddings: [[[0.1, 0.2, 0.3]]], + metadatas: [[['title' => 'Doc 1']]], + documents: [['Document content here']], + data: null, + uris: null, + distances: null + ); + + $collection = $this->createMock(CollectionResource::class); + $client = $this->createMock(Client::class); + + $client->expects($this->once()) + ->method('getOrCreateCollection') + ->with('test-collection') + ->willReturn($collection); + + $collection->expects($this->once()) + ->method('query') + ->willReturn($queryResponse); + + $store = new Store($client, 'test-collection'); + $documents = iterator_to_array($store->query($queryVector, ['include' => ['embeddings', 'metadatas', 'distances', 'documents']])); + + $this->assertCount(1, $documents); + $this->assertSame('01234567-89ab-cdef-0123-456789abcdef', (string) $documents[0]->id); + $this->assertSame([0.1, 0.2, 0.3], $documents[0]->vector->getData()); + $this->assertSame(['title' => 'Doc 1', '_text' => 'Document content here'], $documents[0]->metadata->getArrayCopy()); + } + /** * @return \Iterator, whereDocument?: array},