From 0e9837cf913a209726881b87efd6db1b40891247 Mon Sep 17 00:00:00 2001 From: junjun Date: Thu, 18 Sep 2025 16:11:17 +0800 Subject: [PATCH 1/2] feat: Vector retrieval matches datasource --- backend/apps/chat/task/llm.py | 96 ++++++++++--------- backend/apps/datasource/embedding/__init__.py | 2 + .../apps/datasource/embedding/ds_embedding.py | 56 +++++++++++ 3 files changed, 110 insertions(+), 44 deletions(-) create mode 100644 backend/apps/datasource/embedding/__init__.py create mode 100644 backend/apps/datasource/embedding/ds_embedding.py diff --git a/backend/apps/chat/task/llm.py b/backend/apps/chat/task/llm.py index 5455187da..1309f59da 100644 --- a/backend/apps/chat/task/llm.py +++ b/backend/apps/chat/task/llm.py @@ -32,6 +32,7 @@ from apps.data_training.curd.data_training import get_training_template from apps.datasource.crud.datasource import get_table_schema from apps.datasource.crud.permission import get_row_permission_filters, is_normal_user +from apps.datasource.embedding.ds_embedding import get_ds_embedding from apps.datasource.models.datasource import CoreDatasource from apps.db.db import exec_sql, get_version, check_connection from apps.system.crud.assistant import AssistantOutDs, AssistantOutDsFactory, get_assistant_ds @@ -427,56 +428,63 @@ def select_datasource(self): full_thinking_text = '' full_text = '' + ds = None if not ignore_auto_select: - _ds_list_dict = [] - for _ds in _ds_list: - _ds_list_dict.append(_ds) - datasource_msg.append( - HumanMessage(self.chat_question.datasource_user_question(orjson.dumps(_ds_list_dict).decode()))) - - self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = start_log(session=self.session, - ai_modal_id=self.chat_question.ai_modal_id, - ai_modal_name=self.chat_question.ai_modal_name, - operate=OperationEnum.CHOOSE_DATASOURCE, - record_id=self.record.id, - full_message=[{'type': msg.type, - 'content': msg.content} for - msg in datasource_msg]) - - token_usage = {} - res = self.llm.stream(datasource_msg) - for chunk in res: - SQLBotLogUtil.info(chunk) - reasoning_content_chunk = '' - if 'reasoning_content' in chunk.additional_kwargs: - reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') - # else: - # reasoning_content_chunk = chunk.get('reasoning_content') - if reasoning_content_chunk is None: + if settings.EMBEDDING_ENABLED: + ds = get_ds_embedding(self.session, self.current_user, _ds_list, self.chat_question.question) + else: + _ds_list_dict = [] + for _ds in _ds_list: + _ds_list_dict.append(_ds) + datasource_msg.append( + HumanMessage(self.chat_question.datasource_user_question(orjson.dumps(_ds_list_dict).decode()))) + + self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = start_log(session=self.session, + ai_modal_id=self.chat_question.ai_modal_id, + ai_modal_name=self.chat_question.ai_modal_name, + operate=OperationEnum.CHOOSE_DATASOURCE, + record_id=self.record.id, + full_message=[{'type': msg.type, + 'content': msg.content} + for + msg in datasource_msg]) + + token_usage = {} + res = self.llm.stream(datasource_msg) + for chunk in res: + SQLBotLogUtil.info(chunk) reasoning_content_chunk = '' - full_thinking_text += reasoning_content_chunk - - full_text += chunk.content - yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk} - get_token_usage(chunk, token_usage) - datasource_msg.append(AIMessage(full_text)) - - self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = end_log(session=self.session, - log=self.current_logs[ - OperationEnum.CHOOSE_DATASOURCE], - full_message=[ - {'type': msg.type, 'content': msg.content} - for msg in datasource_msg], - reasoning_content=full_thinking_text, - token_usage=token_usage) - - json_str = extract_nested_json(full_text) + if 'reasoning_content' in chunk.additional_kwargs: + reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') + # else: + # reasoning_content_chunk = chunk.get('reasoning_content') + if reasoning_content_chunk is None: + reasoning_content_chunk = '' + full_thinking_text += reasoning_content_chunk + + full_text += chunk.content + yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk} + get_token_usage(chunk, token_usage) + datasource_msg.append(AIMessage(full_text)) + + self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = end_log(session=self.session, + log=self.current_logs[ + OperationEnum.CHOOSE_DATASOURCE], + full_message=[ + {'type': msg.type, + 'content': msg.content} + for msg in datasource_msg], + reasoning_content=full_thinking_text, + token_usage=token_usage) + + json_str = extract_nested_json(full_text) + ds = orjson.loads(json_str) _error: Exception | None = None _datasource: int | None = None _engine_type: str | None = None try: - data: dict = _ds_list[0] if ignore_auto_select else orjson.loads(json_str) + data: dict = _ds_list[0] if ignore_auto_select else ds if data.get('id') and data.get('id') != 0: _datasource = data['id'] @@ -515,7 +523,7 @@ def select_datasource(self): except Exception as e: _error = e - if not ignore_auto_select: + if not ignore_auto_select and not settings.EMBEDDING_ENABLED: self.record = save_select_datasource_answer(session=self.session, record_id=self.record.id, answer=orjson.dumps({'content': full_text}).decode(), datasource=_datasource, diff --git a/backend/apps/datasource/embedding/__init__.py b/backend/apps/datasource/embedding/__init__.py new file mode 100644 index 000000000..87bb6f5dd --- /dev/null +++ b/backend/apps/datasource/embedding/__init__.py @@ -0,0 +1,2 @@ +# Author: Junjun +# Date: 2025/9/18 diff --git a/backend/apps/datasource/embedding/ds_embedding.py b/backend/apps/datasource/embedding/ds_embedding.py new file mode 100644 index 000000000..f0524b763 --- /dev/null +++ b/backend/apps/datasource/embedding/ds_embedding.py @@ -0,0 +1,56 @@ +# Author: Junjun +# Date: 2025/9/18 +import json +import math +import traceback + +from apps.ai_model.embedding import EmbeddingModelCache +from apps.datasource.crud.datasource import get_table_schema, get_ds +from common.core.deps import SessionDep, CurrentUser + + +def cosine_similarity(vec_a, vec_b): + if len(vec_a) != len(vec_b): + raise ValueError("The vector dimension must be the same") + + dot_product = sum(a * b for a, b in zip(vec_a, vec_b)) + + norm_a = math.sqrt(sum(a * a for a in vec_a)) + norm_b = math.sqrt(sum(b * b for b in vec_b)) + + if norm_a == 0 or norm_b == 0: + return 0.0 + + return dot_product / (norm_a * norm_b) + + +def get_ds_embedding(session: SessionDep, current_user: CurrentUser, _ds_list, question: str): + _list = [] + for _ds in _ds_list: + if _ds.get('id'): + ds = get_ds(session, _ds.get('id')) + + table_schema = get_table_schema(session, current_user, ds) + ds_info = f"{ds.name}, {ds.description}\n" + ds_schema = ds_info + table_schema + + _list.append({"id": ds.id, "ds_schema": ds_schema, "cosine_similarity": 0.0, "ds": ds}) + + if _list: + try: + text = [s.get('ds_schema') for s in _list] + + model = EmbeddingModelCache.get_model() + results = model.embed_documents(text) + + q_embedding = model.embed_query(question) + for index in range(len(results)): + item = results[index] + _list[index]['cosine_similarity'] = cosine_similarity(q_embedding, item) + + _list.sort(key=lambda x: x['cosine_similarity'], reverse=True) + print(json.dumps(_list)) + ds = _list[0].get('ds') + return {"id": ds.id, "name": ds.name, "description": ds.description} + except Exception: + traceback.print_exc() From 5900980dc938ad01b19c0f024d10a70caf4d84fd Mon Sep 17 00:00:00 2001 From: junjun Date: Thu, 18 Sep 2025 17:09:35 +0800 Subject: [PATCH 2/2] feat: Vector retrieval matches datasource --- backend/apps/chat/task/llm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/backend/apps/chat/task/llm.py b/backend/apps/chat/task/llm.py index 63354fa88..4d72481dc 100644 --- a/backend/apps/chat/task/llm.py +++ b/backend/apps/chat/task/llm.py @@ -432,6 +432,7 @@ def select_datasource(self): if not ignore_auto_select: if settings.EMBEDDING_ENABLED: ds = get_ds_embedding(self.session, self.current_user, _ds_list, self.chat_question.question) + yield {'content': '{"id":' + ds.id + '}'} else: _ds_list_dict = [] for _ds in _ds_list: