Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .semversioner/next-release/minor-20250918193514400373.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "minor",
"description": "add customization to vector store"
}
55 changes: 47 additions & 8 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,24 @@
"name": "Indexer",
"type": "debugpy",
"request": "launch",
"module": "uv",
"module": "graphrag",
"args": [
"poe", "index",
"--root", "<path_to_ragtest_root_demo>"
"index",
"--root",
"<path_to_index_folder>"
],
"console": "integratedTerminal"
},
{
"name": "Query",
"type": "debugpy",
"request": "launch",
"module": "uv",
"module": "graphrag",
"args": [
"poe", "query",
"--root", "<path_to_ragtest_root_demo>",
"--method", "global",
"query",
"--root",
"<path_to_index_folder>",
"--method", "basic",
"--query", "What are the top themes in this story",
]
},
Expand All @@ -34,6 +37,42 @@
"--config",
"<path_to_ragtest_root_demo>/settings.yaml",
]
}
},
{
"name": "Debug Integration Pytest",
"type": "debugpy",
"request": "launch",
"module": "pytest",
"args": [
"./tests/integration/vector_stores",
"-k", "test_azure_ai_search"
],
"console": "integratedTerminal",
"justMyCode": false
},
{
"name": "Debug Verbs Pytest",
"type": "debugpy",
"request": "launch",
"module": "pytest",
"args": [
"./tests/verbs",
"-k", "test_generate_text_embeddings"
],
"console": "integratedTerminal",
"justMyCode": false
},
{
"name": "Debug Smoke Pytest",
"type": "debugpy",
"request": "launch",
"module": "pytest",
"args": [
"./tests/smoke",
"-k", "test_fixtures"
],
"console": "integratedTerminal",
"justMyCode": false
},
]
}
1 change: 1 addition & 0 deletions graphrag/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ class VectorStoreDefaults:
api_key: None = None
audience: None = None
database_name: None = None
schema: None = None


@dataclass
Expand Down
6 changes: 3 additions & 3 deletions graphrag/config/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@
]


def create_collection_name(
def create_index_name(
container_name: str, embedding_name: str, validate: bool = True
) -> str:
"""
Create a collection name for the embedding store.
Create a index name for the embedding store.

Within any given vector store, we can have multiple sets of embeddings organized into projects.
The `container` param is used for this partitioning, and is added as a prefix to the collection name for differentiation.
The `container` param is used for this partitioning, and is added as a prefix to the index name for differentiation.

The embedding name is fixed, with the available list defined in graphrag.index.config.embeddings

Expand Down
18 changes: 18 additions & 0 deletions graphrag/config/models/vector_store_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from pydantic import BaseModel, Field, model_validator

from graphrag.config.defaults import vector_store_defaults
from graphrag.config.embeddings import all_embeddings
from graphrag.config.enums import VectorStoreType
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig


class VectorStoreConfig(BaseModel):
Expand Down Expand Up @@ -85,9 +87,25 @@ def _validate_url(self) -> None:
default=vector_store_defaults.overwrite,
)

embeddings_schema: dict[str, VectorStoreSchemaConfig] = {}

def _validate_embeddings_schema(self) -> None:
"""Validate the embeddings schema."""
for name in self.embeddings_schema:
if name not in all_embeddings:
msg = f"vector_store.embeddings_schema contains an invalid embedding schema name: {name}. Please update your settings.yaml and select the correct embedding schema names."
raise ValueError(msg)

if self.type == VectorStoreType.CosmosDB:
for id_field in self.embeddings_schema:
if id_field != "id":
msg = "When using CosmosDB, the id_field in embeddings_schema must be 'id'. Please update your settings.yaml and set the id_field to 'id'."
raise ValueError(msg)

@model_validator(mode="after")
def _validate_model(self):
"""Validate the model."""
self._validate_db_uri()
self._validate_url()
self._validate_embeddings_schema()
return self
66 changes: 66 additions & 0 deletions graphrag/config/models/vector_store_schema_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Parameterization settings for the default configuration."""

import re

from pydantic import BaseModel, Field, model_validator

DEFAULT_VECTOR_SIZE: int = 1536

VALID_IDENTIFIER_REGEX = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")


def is_valid_field_name(field: str) -> bool:
"""Check if a field name is valid for CosmosDB."""
return bool(VALID_IDENTIFIER_REGEX.match(field))


class VectorStoreSchemaConfig(BaseModel):
"""The default configuration section for Vector Store Schema."""

id_field: str = Field(
description="The ID field to use.",
default="id",
)

vector_field: str = Field(
description="The vector field to use.",
default="vector",
)

text_field: str = Field(
description="The text field to use.",
default="text",
)

attributes_field: str = Field(
description="The attributes field to use.",
default="attributes",
)

vector_size: int = Field(
description="The vector size to use.",
default=DEFAULT_VECTOR_SIZE,
)

index_name: str | None = Field(description="The index name to use.", default=None)

def _validate_schema(self) -> None:
"""Validate the schema."""
for field in [
self.id_field,
self.vector_field,
self.text_field,
self.attributes_field,
]:
if not is_valid_field_name(field):
msg = f"Unsafe or invalid field name: {field}"
raise ValueError(msg)

@model_validator(mode="after")
def _validate_model(self):
"""Validate the model."""
self._validate_schema()
return self
42 changes: 31 additions & 11 deletions graphrag/index/operations/embed_text/embed_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.embeddings import create_collection_name
from graphrag.config.embeddings import create_index_name
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingStrategy
from graphrag.vector_stores.base import BaseVectorStore, VectorStoreDocument
from graphrag.vector_stores.factory import VectorStoreFactory
Expand Down Expand Up @@ -49,9 +50,9 @@ async def embed_text(
vector_store_config = strategy.get("vector_store")

if vector_store_config:
collection_name = _get_collection_name(vector_store_config, embedding_name)
index_name = _get_index_name(vector_store_config, embedding_name)
vector_store: BaseVectorStore = _create_vector_store(
vector_store_config, collection_name
vector_store_config, index_name, embedding_name
)
vector_store_workflow_config = vector_store_config.get(
embedding_name, vector_store_config
Expand Down Expand Up @@ -183,27 +184,46 @@ async def _text_embed_with_vector_store(


def _create_vector_store(
vector_store_config: dict, collection_name: str
vector_store_config: dict, index_name: str, embedding_name: str | None = None
) -> BaseVectorStore:
vector_store_type: str = str(vector_store_config.get("type"))
if collection_name:
vector_store_config.update({"collection_name": collection_name})

embeddings_schema: dict[str, VectorStoreSchemaConfig] = vector_store_config.get(
"embeddings_schema", {}
)
single_embedding_config: VectorStoreSchemaConfig = VectorStoreSchemaConfig()

if (
embeddings_schema is not None
and embedding_name is not None
and embedding_name in embeddings_schema
):
raw_config = embeddings_schema[embedding_name]
if isinstance(raw_config, dict):
single_embedding_config = VectorStoreSchemaConfig(**raw_config)
else:
single_embedding_config = raw_config

if single_embedding_config.index_name is None:
single_embedding_config.index_name = index_name

vector_store = VectorStoreFactory().create_vector_store(
vector_store_type, kwargs=vector_store_config
vector_store_schema_config=single_embedding_config,
vector_store_type=vector_store_type,
kwargs=vector_store_config,
)

vector_store.connect(**vector_store_config)
return vector_store


def _get_collection_name(vector_store_config: dict, embedding_name: str) -> str:
def _get_index_name(vector_store_config: dict, embedding_name: str) -> str:
container_name = vector_store_config.get("container_name", "default")
collection_name = create_collection_name(container_name, embedding_name)
index_name = create_index_name(container_name, embedding_name)

msg = f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {collection_name}"
msg = f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {index_name}"
logger.info(msg)
return collection_name
return index_name


def load_strategy(strategy: TextEmbedStrategyType) -> TextEmbeddingStrategy:
Expand Down
28 changes: 25 additions & 3 deletions graphrag/utils/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@

from graphrag.cache.factory import CacheFactory
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.config.embeddings import create_collection_name
from graphrag.config.embeddings import create_index_name
from graphrag.config.models.cache_config import CacheConfig
from graphrag.config.models.storage_config import StorageConfig
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.data_model.types import TextEmbedder
from graphrag.storage.factory import StorageFactory
from graphrag.storage.pipeline_storage import PipelineStorage
Expand Down Expand Up @@ -103,12 +104,33 @@ def get_embedding_store(
index_names = []
for index, store in config_args.items():
vector_store_type = store["type"]
collection_name = create_collection_name(
index_name = create_index_name(
store.get("container_name", "default"), embedding_name
)

embeddings_schema: dict[str, VectorStoreSchemaConfig] = store.get(
"embeddings_schema", {}
)
single_embedding_config: VectorStoreSchemaConfig = VectorStoreSchemaConfig()

if (
embeddings_schema is not None
and embedding_name is not None
and embedding_name in embeddings_schema
):
raw_config = embeddings_schema[embedding_name]
if isinstance(raw_config, dict):
single_embedding_config = VectorStoreSchemaConfig(**raw_config)
else:
single_embedding_config = raw_config

if single_embedding_config.index_name is None:
single_embedding_config.index_name = index_name

embedding_store = VectorStoreFactory().create_vector_store(
vector_store_type=vector_store_type,
kwargs={**store, "collection_name": collection_name},
vector_store_schema_config=single_embedding_config,
kwargs={**store},
)
embedding_store.connect(**store)
# If there is only a single index, return the embedding store directly
Expand Down
Loading
Loading