Commit 7251c98
Refactor configuration management to use Pydantic for better validation and structure
- Replaced the existing config loading mechanism with a Pydantic-based Settings model.
- Updated all modules to access configuration values through the new settings instance.
- Enhanced the knowledge base and chat configurations for improved clarity and usability.
- Added new fields for API keys and base URLs, ensuring they are loaded from environment variables or .env files.
- Improved the display of configuration information in the console for better user experience.
- Removed deprecated functions and streamlined the codebase for maintainability.
1 parent d7d2bb3 commit 7251c98

13 files changed

Lines changed: 361 additions & 221 deletions
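
The settings object referenced throughout the diffs below is defined in src/utils/config.py, one of the changed files not excerpted on this page. As context for reading the hunks, here is a minimal, hypothetical sketch of what such a pydantic-settings model could look like, reconstructed only from the attributes this commit accesses (settings.grok_api_key, settings.default_llm_provider, config.provider, and so on); the real field names, types, and defaults may differ:

    # Hypothetical sketch; not the actual src/utils/config.py from this commit.
    from typing import Dict, Optional
    from pydantic import BaseModel
    from pydantic_settings import BaseSettings, SettingsConfigDict

    class ModelConfiguration(BaseModel):
        provider: str    # e.g. "openai", "siliconflow"
        model_name: str  # e.g. "gpt-4o-mini"

    class Settings(BaseSettings):
        # Values load from environment variables or a .env file.
        model_config = SettingsConfigDict(env_file=".env", extra="ignore")

        # API credentials (assumed optional so missing keys fail at provider init)
        grok_api_key: Optional[str] = None
        grok_base_url: str = "https://api.x.ai/v1"
        jina_api_key: Optional[str] = None

        # Active provider keys (defaults are assumptions)
        default_llm_provider: str = "openai"
        default_embedding_provider: str = "openai"
        default_rerank_provider: str = "siliconflow"

        # Chat/retrieval behaviour (defaults are assumptions)
        chat_rerank_enabled: bool = False
        chat_top_k: int = 5
        chat_score_threshold: float = 0.0

        # Named model configurations keyed by provider key
        llm_configurations: Dict[str, ModelConfiguration] = {}
        embedding_configurations: Dict[str, ModelConfiguration] = {}
        rerank_configurations: Dict[str, ModelConfiguration] = {}

    # Single shared instance, imported as `from src.utils.config import settings`
    settings = Settings()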


config.ini.example

Lines changed: 3 additions & 2 deletions
@@ -45,7 +45,8 @@ KB_REPLACE_WHITESPACE = False # Whether to replace all consecutive whitespace during text preprocessing
 KB_REMOVE_SPACES = False # Whether to remove all spaces during text preprocessing (use with caution)
 KB_REMOVE_URLS = False # Whether to remove URLs and email addresses during text preprocessing
 KB_USE_QA_SEGMENTATION = False # Whether to enable QA-pair segmentation mode (splits documents by a predefined question-answer format)
-KB_SPLITTER_SEPARATORS = ### # Separators used by the text splitter; multiple values may be set, separated by commas
+KB_SPLITTER_SEPARATORS = ###
+# Separators used by the text splitter. To use several, separate them with ASCII commas, e.g.: ###,---,===
 KB_CHUNK_SIZE = 1500 # Maximum length of a text chunk (in characters)
 KB_CHUNK_OVERLAP = 150 # Overlap between adjacent text chunks (in characters)
 KB_EMBEDDING_BATCH_SIZE = 32 # Number of texts per embedding batch (tune to GPU memory or API limits)

@@ -58,7 +59,7 @@ DEFAULT_RERANK_PROVIDER = siliconflow # Default rerank model (must be
 
 [CHAT]
 # Core chatbot configuration
-CHAT_RETRIEVAL_METHOD = HYBRID_SEARCH # Retrieval method; options: SEMANTIC_SEARCH (vector search), FULL_TEXT_SEARCH (keyword search), HYBRID_SEARCH (hybrid search)
+CHAT_RETRIEVAL_METHOD = HYBRID_SEARCH # Retrieval method. Allowed values: SEMANTIC_SEARCH (or "向量检索"), FULL_TEXT_SEARCH (or "全文检索"), HYBRID_SEARCH (or "混合检索")
 CHAT_VECTOR_WEIGHT = 0.3 # Weight of vector-search results in hybrid search (vector and keyword weights should sum to 1)
 CHAT_KEYWORD_WEIGHT = 0.7 # Weight of keyword-search results in hybrid search (vector and keyword weights should sum to 1)
 CHAT_RERANK_ENABLED = False # Whether to enable a rerank model to re-sort retrieval results
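
The new comment documents a comma-separated format for KB_SPLITTER_SEPARATORS. The parsing code is not part of this diff; a minimal sketch of how such a value might be split into a separator list (the helper name is hypothetical):

    def parse_separators(raw: str) -> list[str]:
        """Split a comma-separated separator string such as '###,---,===' into a list."""
        return [part.strip() for part in raw.split(",") if part.strip()]

    print(parse_separators("###,---,==="))  # ['###', '---', '===']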

main.py

Lines changed: 2 additions & 1 deletion
@@ -62,7 +62,8 @@ def display_banner():
     welcome_text.append("Mison", style="default")
     welcome_text.append(" · 邮箱: ", style="bold")
     welcome_text.append("1360962086@qq.com", style="default")
-    welcome_text.append(" · GitHub: ", style="bold")
+    welcome_text.append("\n")  # line break
+    welcome_text.append("GitHub: ", style="bold")
     # Use the correct GitHub repository URL
     github_url = "https://github.com/MisonL/PyRAG-Kit"
     welcome_text.append(github_url, style=f"link {github_url}")

requirements.txt

Lines changed: 2 additions & 0 deletions
@@ -14,3 +14,5 @@ rank_bm25
 questionary
 pyfiglet
 prompt_toolkit
+pydantic
+pydantic-settings

scripts/embed_knowledge_base.py

Lines changed: 60 additions & 27 deletions
@@ -19,10 +19,10 @@
 from rich.console import Console
 from rich.panel import Panel
 from rich.table import Table
-from rich.console import Group
+from rich import box
 
 # Import the refactored modules from src
-from src.utils.config import KB_CONFIG, CHAT_CONFIG, API_CONFIG
+from src.utils.config import settings, KB_CONFIG, API_CONFIG
 from src.providers.factory import ModelProviderFactory
 from src.retrieval.retriever import VectorStore  # VectorStore is now imported from here
 from src.ui.display_utils import CONSOLE_WIDTH, get_relative_path

@@ -126,35 +126,68 @@ def process_documents(vector_store: VectorStore):
 # 4. Helpers & main (HELPERS & MAIN)
 # =================================================================
 def display_config_and_confirm():
-    """Display the full configuration and ask the user to confirm."""
+    """Display the full configuration in a nicely formatted table and ask the user to confirm."""
     console = Console()
 
     def mask_api_key(key: Optional[str]) -> str:
-        if not key: return "[dim]未设置[/dim]"
-        return f"{key[:10]}..."
-
-    kb_table = Table(title="[bold green]知识库构建配置 (KB_CONFIG)[/bold green]", show_header=False, box=None, padding=(0, 1))
-    kb_table.add_column(style="cyan")
-    kb_table.add_column(style="bold white")
-    active_embedding = KB_CONFIG['active_embedding_configuration']
-    embedding_model_details = KB_CONFIG['embedding_configurations'][active_embedding]
-    kb_table.add_row("激活的嵌入模型:", f"{active_embedding} ({embedding_model_details['provider']}: {embedding_model_details['model_name']})")
-    kb_table.add_row("知识库目录:", get_relative_path(KB_CONFIG['kb_dir']))
-    kb_table.add_row("输出文件:", get_relative_path(KB_CONFIG['output_file']))
+        """Mask the API key so it can be displayed more safely."""
+        if not key or key == "lm-studio":
+            return "[dim]未设置或无需设置[/dim]"
+        if len(key) > 12:
+            return f"[white]{key[:6]}...{key[-4:]}[/white]"
+        return "[white]已设置[/white]"
+
+    table = Table(
+        box=box.ROUNDED,
+        padding=(0, 2),
+        title="[bold yellow]向量化脚本配置总览[/bold yellow]",
+        show_header=False,
+        width=CONSOLE_WIDTH
+    )
+    # Parameter-name column: blue
+    table.add_column(justify="right", style="cyan", no_wrap=True, width=28)
+    # Parameter-value column: bright white
+    table.add_column(style="bright_white")
+
+    # --- Knowledge base configuration ---
+    table.add_row("[bold green]知识库配置[/bold green]", "")
+    table.add_row("知识库目录", f"[bold cyan]{get_relative_path(settings.knowledge_base_path)}[/bold cyan]")
+    table.add_row("输出文件路径", f"[bold cyan]{get_relative_path(settings.pkl_path)}[/bold cyan]")
+    table.add_row("文本切分块大小 (Chunk Size)", f"[bold magenta]{settings.kb_chunk_size}[/bold magenta]")
+    table.add_row("文本切分重叠量 (Overlap)", f"[bold magenta]{settings.kb_chunk_overlap}[/bold magenta]")
+    table.add_row("切分分隔符 (Separators)", f"[bold bright_white]{settings.kb_splitter_separators}[/bold bright_white]")
+    table.add_section()
+
+    # --- Model & API configuration ---
+    table.add_row("[bold green]模型与API配置[/bold green]", "")
+    active_embedding_key = settings.default_embedding_provider
+    embedding_model_details = settings.embedding_configurations[active_embedding_key]
+    provider = embedding_model_details.provider
 
-    api_table = Table(title="[bold green]相关API配置[/bold green]", show_header=False, box=None, padding=(0, 1))
-    api_table.add_column(style="cyan")
-    api_table.add_column(style="bold white")
-    provider = embedding_model_details['provider']
-    key_name = f"{provider.upper()}_API_KEY"
-    url_name = f"{provider.upper()}_BASE_URL"
-    if key_name in API_CONFIG:
-        api_table.add_row(f"{key_name}:", mask_api_key(API_CONFIG.get(key_name)))
-    if url_name in API_CONFIG and API_CONFIG.get(url_name):
-        api_table.add_row(f"{url_name}:", API_CONFIG.get(url_name))
-
-    console.print(Panel(Group(kb_table, api_table), title="[bold yellow]向量化脚本配置总览[/bold yellow]", border_style="blue", width=CONSOLE_WIDTH))
-    console.print("[yellow]配置信息来源于 [bold]src/utils/config.py[/bold] 和 [bold].env[/bold] 文件。[/yellow]")
+    table.add_row("激活的嵌入模型 (Provider)", f"[bold green]{active_embedding_key}[/bold green] ([dim]{provider}[/dim])")
+    table.add_row("模型名称 (Model Name)", f"[bold bright_white]{embedding_model_details.model_name}[/bold bright_white]")
+
+    # Dynamically derive the field names for the API key and base URL
+    key_field_name = f"{provider.lower()}_api_key"
+
+    # Handle the inconsistently named URL fields
+    url_field_name = ""
+    if provider == "openai":
+        url_field_name = "openai_api_base"
+    elif hasattr(settings, f"{provider.lower()}_base_url"):
+        url_field_name = f"{provider.lower()}_base_url"
+
+    api_key_value = getattr(settings, key_field_name, None)
+    # mask_api_key already applies its own styling, so no extra markup is needed
+    table.add_row(f"对应的 API Key ({key_field_name.upper()})", mask_api_key(api_key_value))
+
+    if url_field_name:
+        base_url_value = getattr(settings, url_field_name, None)
+        if base_url_value:
+            table.add_row(f"对应的 Base URL ({url_field_name.upper()})", f"[bold cyan]{base_url_value}[/bold cyan]")
+
+    console.print(table)
+    console.print("[yellow]配置信息来源于 [bold]config.ini[/bold], [bold].env[/bold] 文件或 [bold]环境变量[/bold]。[/yellow]")
 
     choice = console.input("是否使用以上配置继续处理? ([bold green]y[/bold green]/[bold red]n[/bold red]): ").lower()
     if choice not in ['y', 'yes']:
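
The rewritten mask_api_key keeps only a short prefix and suffix of long keys. A standalone rendering of the same rule, with the rich markup stripped so the behaviour is easy to see:

    def mask(key: str | None) -> str:
        # Same masking rule as mask_api_key above, minus the rich markup.
        if not key or key == "lm-studio":
            return "not set / not required"
        if len(key) > 12:
            return f"{key[:6]}...{key[-4:]}"
        return "set"

    print(mask("sk-0123456789abcdef"))  # sk-012...cdef
    print(mask("short"))                # set
    print(mask(None))                   # not set / not required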

src/chat/core.py

Lines changed: 6 additions & 6 deletions
@@ -20,7 +20,7 @@
 # Use relative imports
 from ..providers.factory import ModelProviderFactory
 from ..providers.__base__.model_provider import LargeLanguageModel
-from ..utils.config import CHAT_CONFIG, KB_CONFIG, LOG_PATH
+from ..utils.config import CHAT_CONFIG, KB_CONFIG, LOG_PATH, PKL_PATH, settings
 from ..ui.config_menu import launch_config_editor
 from ..ui.display_utils import display_chat_config
 from ..retrieval.retriever import VectorStore, retrieve_documents

@@ -59,7 +59,7 @@ def __init__(self, console: Console):
         self.reload_llm()  # initial load
 
     def _load_vector_store(self) -> Optional[VectorStore]:
-        file_path = KB_CONFIG["output_file"]
+        file_path = str(PKL_PATH)
         try:
             with open(file_path, "rb") as f:
                 data = pickle.load(f)

@@ -78,7 +78,7 @@ def _load_vector_store(self) -> Optional[VectorStore]:
     def reload_llm(self) -> bool:
         """Reload or initialize the LLM model."""
         try:
-            active_llm_key = CHAT_CONFIG['active_llm_configuration']
+            active_llm_key = settings.default_llm_provider
             self.console.print(f"[dim]正在加载LLM: [bold cyan]{active_llm_key}[/bold cyan]...[/dim]")
             self.llm_model = ModelProviderFactory.get_llm_provider(active_llm_key)
             if self.llm_model:

@@ -103,7 +103,7 @@ def _identify_intent(self, user_query: str) -> str:
 
     def _retrieve_knowledge(self, query: str) -> List[Dict[str, Any]]:
         if not self.vector_store: return []
-        self.console.print(f"[dim]正在使用 '[yellow]{CHAT_CONFIG['retrieval_method'].value}[/yellow]' 模式检索...[/dim]")
+        self.console.print(f"[dim]正在使用 '[yellow]{settings.chat_retrieval_method.value}[/yellow]' 模式检索...[/dim]")
         return retrieve_documents(query, self.vector_store, self.console)
 
     def _generate_answer_stream(self, user_query: str, intent: str, retrieved_docs: List[Dict[str, Any]]) -> "Generator[str, None, None]":

@@ -181,7 +181,7 @@ def start_chat_session():
 
     if bot.llm_model:
         display_chat_config(console)
-        console.print(f"我是你的智能客服(由 [bold green]{CHAT_CONFIG['active_llm_configuration']}[/bold green] 支持),请输入问题(输入'[bold]/quit[/bold]'或'[bold]/config[/bold]'):")
+        console.print(f"我是你的智能客服(由 [bold green]{settings.default_llm_provider}[/bold green] 支持),请输入问题(输入'[bold]/quit[/bold]'或'[bold]/config[/bold]'):")
 
     while True:
         try:

@@ -197,7 +197,7 @@ def start_chat_session():
             console.print("[yellow]检测到LLM配置变更,正在重载模型...[/yellow]")
             bot.reload_llm()
             display_chat_config(console)
-            console.print(f"我是你的智能客服(由 [bold green]{CHAT_CONFIG['active_llm_configuration']}[/bold green] 支持),请输入问题(输入'[bold]/quit[/bold]'或'[bold]/config[/bold]'):")
+            console.print(f"我是你的智能客服(由 [bold green]{settings.default_llm_provider}[/bold green] 支持),请输入问题(输入'[bold]/quit[/bold]'或'[bold]/config[/bold]'):")
             continue
 
         answer_stream = bot.chat(user_query)

src/providers/factory.py

Lines changed: 13 additions & 18 deletions
@@ -7,12 +7,7 @@
     RerankModel,
     TextEmbeddingModel,
 )
-from src.utils.config import (
-    CHAT_CONFIG,
-    EMBEDDING_CONFIGS,
-    LLM_CONFIGS,
-    RERANK_CONFIGS,
-)
+from src.utils.config import settings
 
 class ModelProviderFactory:
     """Model provider factory"""

@@ -48,44 +43,44 @@ def _get_provider_class(provider_name: str) -> Type:
     @staticmethod
     def get_llm_provider(provider_key: str) -> LargeLanguageModel:
         """Get a language model provider instance"""
-        if provider_key not in LLM_CONFIGS:
+        if provider_key not in settings.llm_configurations:
             raise ValueError(f"在LLM配置中未找到key: {provider_key}")
 
-        config = LLM_CONFIGS[provider_key]
-        provider_name = config["provider"]
-        model_name = config["model_name"]
+        config = settings.llm_configurations[provider_key]
+        provider_name = config.provider
+        model_name = config.model_name
 
         ProviderClass = ModelProviderFactory._get_provider_class(provider_name)
         return ProviderClass(model_name=model_name)
 
     @staticmethod
     def get_embedding_provider(provider_key: str) -> TextEmbeddingModel:
         """Get a text embedding model provider instance"""
-        if provider_key not in EMBEDDING_CONFIGS:
+        if provider_key not in settings.embedding_configurations:
             raise ValueError(f"在Embedding配置中未找到key: {provider_key}")
 
-        config = EMBEDDING_CONFIGS[provider_key]
-        provider_name = config["provider"]
-        model_name = config["model_name"]
+        config = settings.embedding_configurations[provider_key]
+        provider_name = config.provider
+        model_name = config.model_name
 
         ProviderClass = ModelProviderFactory._get_provider_class(provider_name)
         return ProviderClass(model_name=model_name)
 
     @staticmethod
     def get_rerank_provider(provider_key: str) -> RerankModel:
         """Get a rerank model provider instance"""
-        if provider_key not in RERANK_CONFIGS:
+        if provider_key not in settings.rerank_configurations:
             raise ValueError(f"在Rerank配置中未找到key: {provider_key}")
 
-        config = RERANK_CONFIGS[provider_key]
+        config = settings.rerank_configurations[provider_key]
         # A rerank provider's key may collide with an LLM/Embedding provider's key (e.g. siliconflow),
        # so we use a special key here, or specify the provider_map key directly in the config.
         # For simplicity, we assume the rerank provider name is unique.
-        provider_name = config["provider"]
+        provider_name = config.provider
         if provider_name == "siliconflow":
             provider_name = "siliconflow_rerank"  # map to the unique rerank provider
 
-        model_name = config["model_name"]
+        model_name = config.model_name
 
         ProviderClass = ModelProviderFactory._get_provider_class(provider_name)
         return ProviderClass(model_name=model_name)
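
With the factory now reading from settings, a typical call site looks like this (a sketch assembled from the call patterns visible elsewhere in this commit; the provider keys resolve to whatever defaults are configured):

    from src.providers.factory import ModelProviderFactory
    from src.utils.config import settings

    # Instantiate providers from the keys selected in config.ini / .env
    llm = ModelProviderFactory.get_llm_provider(settings.default_llm_provider)
    embedder = ModelProviderFactory.get_embedding_provider(settings.default_embedding_provider)
    if settings.chat_rerank_enabled:
        reranker = ModelProviderFactory.get_rerank_provider(settings.default_rerank_provider)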

src/providers/grok.py

Lines changed: 3 additions & 3 deletions
@@ -4,7 +4,7 @@
 import requests
 
 from src.providers.__base__.model_provider import LargeLanguageModel
-from src.utils.config import API_CONFIG
+from src.utils.config import settings
 
 
 class GrokProvider(LargeLanguageModel):

@@ -14,10 +14,10 @@ class GrokProvider(LargeLanguageModel):
 
     def __init__(self, model_name: str):
         self._model_name = model_name
-        self._api_key = API_CONFIG.get("GROK_API_KEY")
+        self._api_key = settings.grok_api_key
         if not self._api_key:
             raise ValueError("Grok配置不完整:缺少 GROK_API_KEY。")
-        self._base_url = API_CONFIG.get("GROK_BASE_URL", "https://api.x.ai/v1")
+        self._base_url = str(settings.grok_base_url)
 
     def invoke(
         self,
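
Note the str(settings.grok_base_url) cast (volcengine.py below does the same for volc_base_url). This suggests the URL fields are typed as pydantic URL objects rather than plain strings, an assumption since the Settings model is not in this diff; pydantic v2 URL objects must be converted before string concatenation:

    from pydantic import AnyHttpUrl, TypeAdapter

    # Validate a URL the way a field typed AnyHttpUrl would (assumption).
    url = TypeAdapter(AnyHttpUrl).validate_python("https://api.x.ai/v1")
    # url is a Url object, not a str; cast it before building request endpoints.
    endpoint = f"{str(url).rstrip('/')}/chat/completions"
    print(endpoint)  # https://api.x.ai/v1/chat/completions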

src/providers/jina.py

Lines changed: 2 additions & 2 deletions
@@ -4,7 +4,7 @@
 import requests
 
 from src.providers.__base__.model_provider import RerankModel
-from src.utils.config import API_CONFIG
+from src.utils.config import settings
 
 
 class JinaProvider(RerankModel):

@@ -14,7 +14,7 @@ class JinaProvider(RerankModel):
 
     def __init__(self, model_name: str):
         self._model_name = model_name
-        self._api_key = API_CONFIG.get("JINA_API_KEY")
+        self._api_key = settings.jina_api_key
         if not self._api_key:
             raise ValueError("错误:Jina Rerank 提供商需要 API 密钥。")
         self._base_url = "https://api.jina.ai/v1/rerank"

src/providers/volcengine.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
LargeLanguageModel,
1212
TextEmbeddingModel,
1313
)
14-
from src.utils.config import API_CONFIG
14+
from src.utils.config import settings
1515

1616

1717
class VolcengineProvider(LargeLanguageModel, TextEmbeddingModel):
@@ -21,9 +21,9 @@ class VolcengineProvider(LargeLanguageModel, TextEmbeddingModel):
2121

2222
def __init__(self, model_name: str):
2323
self._model_name = model_name
24-
self._access_key = API_CONFIG.get("VOLC_ACCESS_KEY")
25-
self._secret_key = API_CONFIG.get("VOLC_SECRET_KEY")
26-
self._base_url = API_CONFIG.get("VOLC_BASE_URL")
24+
self._access_key = settings.volc_access_key
25+
self._secret_key = settings.volc_secret_key
26+
self._base_url = str(settings.volc_base_url)
2727

2828
if not all([self._access_key, self._secret_key, self._base_url]):
2929
raise ValueError("火山引擎配置不完整:缺少 VOLC_ACCESS_KEY, VOLC_SECRET_KEY, 或 VOLC_BASE_URL。")

src/retrieval/retriever.py

Lines changed: 7 additions & 7 deletions
@@ -11,7 +11,7 @@
 from rich.console import Console
 
 # Use relative imports to reference modules in the same src directory
-from ..utils.config import CHAT_CONFIG, KB_CONFIG, RetrievalMethod
+from ..utils.config import CHAT_CONFIG, KB_CONFIG, RetrievalMethod, settings
 from ..providers.factory import ModelProviderFactory
 
 # =================================================================

@@ -74,12 +74,12 @@ def rerank(self, documents: List[Dict]) -> List[Dict]:
         return sorted(documents, key=lambda x: x["score"], reverse=True)
 
 def retrieve_documents(query: str, vector_store: VectorStore, console: Console) -> List[Dict]:
-    retrieval_method = CHAT_CONFIG["retrieval_method"]
-    top_k = CHAT_CONFIG["top_k"]
-    score_threshold = CHAT_CONFIG["score_threshold"]
+    retrieval_method = settings.chat_retrieval_method
+    top_k = settings.chat_top_k
+    score_threshold = settings.chat_score_threshold
 
     # Semantic search
-    active_embedding_key = KB_CONFIG['active_embedding_configuration']
+    active_embedding_key = settings.default_embedding_provider
     embedding_provider = ModelProviderFactory.get_embedding_provider(active_embedding_key)
     query_embedding = np.array(embedding_provider.embed_documents([query])[0])
     semantic_results = vector_store.semantic_search(query_embedding, top_k, score_threshold)

@@ -103,8 +103,8 @@ def retrieve_documents(query: str, vector_store: VectorStore, console: Console)
     ranked_results = sorted(full_text_results, key=lambda x: x.get("keyword_score", 0), reverse=True)
 
     # Use the reranker (if enabled)
-    if CHAT_CONFIG["rerank_enabled"]:
-        active_rerank_key = CHAT_CONFIG['active_rerank_configuration']
+    if settings.chat_rerank_enabled:
+        active_rerank_key = settings.default_rerank_provider
         rerank_provider = ModelProviderFactory.get_rerank_provider(active_rerank_key)
         if rerank_provider and ranked_results:
             console.print(f"[dim]正在使用 '{active_rerank_key}' 进行重排...[/dim]")
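
The CHAT_VECTOR_WEIGHT and CHAT_KEYWORD_WEIGHT settings from config.ini.example above govern how hybrid search fuses the two result lists. The fusion itself happens outside the hunks shown here; a minimal sketch of weighted score fusion, assuming both scores are normalized to [0, 1]:

    def hybrid_score(semantic: float, keyword: float,
                     vector_weight: float = 0.3, keyword_weight: float = 0.7) -> float:
        """Weighted fusion of a semantic-search score and a keyword (BM25) score.
        Assumes normalized inputs; the two weights should sum to 1."""
        return vector_weight * semantic + keyword_weight * keyword

    print(round(hybrid_score(0.9, 0.4), 2))  # 0.55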
