-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathconfig.py
More file actions
200 lines (167 loc) · 6.77 KB
/
config.py
File metadata and controls
200 lines (167 loc) · 6.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
"""Configuration management for vector database embedder application."""
import json
import logging
import os
from dataclasses import dataclass
from typing import Dict, List
from dotenv import load_dotenv
from langchain_huggingface import HuggingFaceEmbeddings
from vector_db.db_provider import DBProvider
from vector_db.dryrun_provider import DryRunProvider
from vector_db.elastic_provider import ElasticProvider
from vector_db.mssql_provider import MSSQLProvider
from vector_db.pgvector_provider import PGVectorProvider
from vector_db.qdrant_provider import QdrantProvider
from vector_db.redis_provider import RedisProvider
@dataclass
class Config:
"""
Global configuration object for embedding and vector DB ingestion jobs.
This class loads configuration from environment variables and initializes
all the required components (e.g., DB providers, chunking strategy, input sources).
Attributes:
db_provider (DBProvider): Initialized provider for a vector database.
chunk_size (int): Character length for each document chunk.
chunk_overlap (int): Number of overlapping characters between adjacent chunks.
web_sources (List[str]): List of web URLs to scrape and embed.
repo_sources (List[Dict]): Repositories and glob patterns for file discovery.
temp_dir (str): Path to a temporary working directory.
Example:
>>> config = Config.load()
>>> print(config.chunk_size)
>>> config.db_provider.add_documents(docs)
"""
db_provider: DBProvider
chunk_size: int
chunk_overlap: int
web_sources: List[str]
repo_sources: List[Dict]
temp_dir: str
@staticmethod
def _get_required_env_var(key: str) -> str:
"""
Retrieve a required environment variable or raise an error.
Args:
key (str): The environment variable name.
Returns:
str: The value of the environment variable.
Raises:
ValueError: If the variable is not defined.
"""
value = os.getenv(key)
if not value:
raise ValueError(f"{key} environment variable is required.")
return value
@staticmethod
def _parse_log_level(log_level_name: str) -> int:
"""
Convert a string log level into a `logging` module constant.
Args:
log_level_name (str): One of DEBUG, INFO, WARNING, ERROR, CRITICAL.
Returns:
int: Corresponding `logging` level.
Raises:
ValueError: If an invalid level is provided.
"""
log_levels = {
"DEBUG": logging.DEBUG,
"INFO": logging.INFO,
"WARNING": logging.WARNING,
"ERROR": logging.ERROR,
"CRITICAL": logging.CRITICAL,
}
if log_level_name not in log_levels:
raise ValueError(
f"Invalid LOG_LEVEL: '{log_level_name}'. "
f"Must be one of: {', '.join(log_levels.keys())}"
)
return log_levels[log_level_name]
@staticmethod
def _init_db_provider(db_type: str) -> DBProvider:
"""
Factory method to initialize the correct DB provider from environment variables.
Args:
db_type (str): Type of DB specified via `DB_TYPE` (e.g., REDIS, PGVECTOR, QDRANT, etc.)
Returns:
DBProvider: Initialized instance of a provider subclass.
Raises:
ValueError: If the DB type is unsupported or required vars are missing.
"""
get = Config._get_required_env_var
db_type = db_type.upper()
embeddings = HuggingFaceEmbeddings(model_name=get("EMBEDDING_MODEL"))
match db_type:
case "REDIS":
url = get("REDIS_URL")
index = os.getenv("REDIS_INDEX", "docs")
return RedisProvider(embeddings, url, index)
case "ELASTIC":
url = get("ELASTIC_URL")
password = get("ELASTIC_PASSWORD")
index = os.getenv("ELASTIC_INDEX", "docs")
user = os.getenv("ELASTIC_USER", "elastic")
return ElasticProvider(embeddings, url, password, index, user)
case "PGVECTOR":
url = get("PGVECTOR_URL")
collection = get("PGVECTOR_COLLECTION_NAME")
return PGVectorProvider(embeddings, url, collection)
case "MSSQL":
connection_string = get("MSSQL_CONNECTION_STRING")
table = get("MSSQL_TABLE")
return MSSQLProvider(embeddings, connection_string, table)
case "QDRANT":
url = get("QDRANT_URL")
collection = get("QDRANT_COLLECTION")
return QdrantProvider(embeddings, url, collection)
case "DRYRUN":
return DryRunProvider(embeddings)
case _:
raise ValueError(f"Unsupported DB_TYPE '{db_type}'")
@staticmethod
def load() -> "Config":
"""
Load application settings from `.env` variables into a typed config object.
This includes logging level setup, DB provider initialization, and input
source validation.
Returns:
Config: A fully-initialized configuration object.
Raises:
ValueError: If required environment variables are missing or malformed.
"""
load_dotenv()
get = Config._get_required_env_var
# Logging setup
log_level = get("LOG_LEVEL").upper()
logging.basicConfig(
level=Config._parse_log_level(log_level),
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)
logger.debug("Logging initialized at level: %s", log_level)
# Database backend
db_type = get("DB_TYPE")
db_provider = Config._init_db_provider(db_type)
# Web source URLs
try:
web_sources = json.loads(get("WEB_SOURCES"))
except json.JSONDecodeError as e:
raise ValueError(f"WEB_SOURCES must be a valid JSON list: {e}") from e
# Git repositories and file matchers
try:
repo_sources = json.loads(get("REPO_SOURCES"))
except json.JSONDecodeError as e:
raise ValueError(f"Invalid REPO_SOURCES JSON: {e}") from e
# Embedding chunking strategy
chunk_size = int(get("CHUNK_SIZE"))
chunk_overlap = int(get("CHUNK_OVERLAP"))
# Temporary file location
temp_dir = get("TEMP_DIR")
return Config(
db_provider=db_provider,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
web_sources=web_sources,
repo_sources=repo_sources,
temp_dir=temp_dir,
)