Skip to content

Commit 3502c22

Browse files
authored
reconfigure vector store size (microsoft#2281)
1 parent ae2508d commit 3502c22

2 files changed

Lines changed: 37 additions & 1 deletion

File tree

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "patch",
3+
"description": "reconfigure vector store size by embedding model"
4+
}

packages/graphrag/graphrag/index/validate_config.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,16 @@
66
import asyncio
77
import logging
88
import sys
9+
from typing import TYPE_CHECKING
910

1011
from graphrag_llm.completion import create_completion
1112
from graphrag_llm.embedding import create_embedding
1213

1314
from graphrag.config.models.graph_rag_config import GraphRagConfig
1415

16+
if TYPE_CHECKING:
17+
from graphrag_llm.types import LLMEmbeddingResponse
18+
1519
logger = logging.getLogger(__name__)
1620

1721

@@ -29,13 +33,41 @@ def validate_config_names(parameters: GraphRagConfig) -> None:
2933
for id, config in parameters.embedding_models.items():
3034
embed_llm = create_embedding(config)
3135
try:
32-
asyncio.run(
36+
response = asyncio.run(
3337
embed_llm.embedding_async(
3438
input=["This is an LLM Embedding Test String"]
3539
)
3640
)
3741
logger.info("Embedding LLM Config Params Validated")
42+
43+
if id == parameters.embed_text.embedding_model_id:
44+
_sync_vector_store_dimensions(parameters, response)
45+
3846
except Exception as e: # noqa: BLE001
3947
logger.error(f"Embedding configuration error detected.\n{e}") # noqa
4048
print(f"Failed to validate embedding model ({id}) params", e) # noqa: T201
4149
sys.exit(1)
50+
51+
52+
def _sync_vector_store_dimensions(
53+
parameters: GraphRagConfig,
54+
response: "LLMEmbeddingResponse",
55+
) -> None:
56+
"""Sync vector store dimensions to match the actual embedding model output."""
57+
detected = len(response.first_embedding)
58+
if detected == 0:
59+
return
60+
61+
configured = parameters.vector_store.vector_size
62+
if detected == configured:
63+
return
64+
65+
logger.warning(
66+
"Embedding model produces %d-dimensional vectors but vector_size is "
67+
"configured as %d. Overriding vector_size to match the model.",
68+
detected,
69+
configured,
70+
)
71+
parameters.vector_store.vector_size = detected
72+
for schema in parameters.vector_store.index_schema.values():
73+
schema.vector_size = detected

0 commit comments

Comments
 (0)