Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 52 additions & 44 deletions backend/apps/chat/task/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from apps.data_training.curd.data_training import get_training_template
from apps.datasource.crud.datasource import get_table_schema
from apps.datasource.crud.permission import get_row_permission_filters, is_normal_user
from apps.datasource.embedding.ds_embedding import get_ds_embedding
from apps.datasource.models.datasource import CoreDatasource
from apps.db.db import exec_sql, get_version, check_connection
from apps.system.crud.assistant import AssistantOutDs, AssistantOutDsFactory, get_assistant_ds
Expand Down Expand Up @@ -427,56 +428,63 @@ def select_datasource(self):
full_thinking_text = ''
full_text = ''

ds = None
if not ignore_auto_select:
_ds_list_dict = []
for _ds in _ds_list:
_ds_list_dict.append(_ds)
datasource_msg.append(
HumanMessage(self.chat_question.datasource_user_question(orjson.dumps(_ds_list_dict).decode())))

self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = start_log(session=self.session,
ai_modal_id=self.chat_question.ai_modal_id,
ai_modal_name=self.chat_question.ai_modal_name,
operate=OperationEnum.CHOOSE_DATASOURCE,
record_id=self.record.id,
full_message=[{'type': msg.type,
'content': msg.content} for
msg in datasource_msg])

token_usage = {}
res = self.llm.stream(datasource_msg)
for chunk in res:
SQLBotLogUtil.info(chunk)
reasoning_content_chunk = ''
if 'reasoning_content' in chunk.additional_kwargs:
reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
# else:
# reasoning_content_chunk = chunk.get('reasoning_content')
if reasoning_content_chunk is None:
if settings.EMBEDDING_ENABLED:
ds = get_ds_embedding(self.session, self.current_user, _ds_list, self.chat_question.question)
else:
_ds_list_dict = []
for _ds in _ds_list:
_ds_list_dict.append(_ds)
datasource_msg.append(
HumanMessage(self.chat_question.datasource_user_question(orjson.dumps(_ds_list_dict).decode())))

self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = start_log(session=self.session,
ai_modal_id=self.chat_question.ai_modal_id,
ai_modal_name=self.chat_question.ai_modal_name,
operate=OperationEnum.CHOOSE_DATASOURCE,
record_id=self.record.id,
full_message=[{'type': msg.type,
'content': msg.content}
for
msg in datasource_msg])

token_usage = {}
res = self.llm.stream(datasource_msg)
for chunk in res:
SQLBotLogUtil.info(chunk)
reasoning_content_chunk = ''
full_thinking_text += reasoning_content_chunk

full_text += chunk.content
yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
get_token_usage(chunk, token_usage)
datasource_msg.append(AIMessage(full_text))

self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = end_log(session=self.session,
log=self.current_logs[
OperationEnum.CHOOSE_DATASOURCE],
full_message=[
{'type': msg.type, 'content': msg.content}
for msg in datasource_msg],
reasoning_content=full_thinking_text,
token_usage=token_usage)

json_str = extract_nested_json(full_text)
if 'reasoning_content' in chunk.additional_kwargs:
reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
# else:
# reasoning_content_chunk = chunk.get('reasoning_content')
if reasoning_content_chunk is None:
reasoning_content_chunk = ''
full_thinking_text += reasoning_content_chunk

full_text += chunk.content
yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
get_token_usage(chunk, token_usage)
datasource_msg.append(AIMessage(full_text))

self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = end_log(session=self.session,
log=self.current_logs[
OperationEnum.CHOOSE_DATASOURCE],
full_message=[
{'type': msg.type,
'content': msg.content}
for msg in datasource_msg],
reasoning_content=full_thinking_text,
token_usage=token_usage)

json_str = extract_nested_json(full_text)
ds = orjson.loads(json_str)

_error: Exception | None = None
_datasource: int | None = None
_engine_type: str | None = None
try:
data: dict = _ds_list[0] if ignore_auto_select else orjson.loads(json_str)
data: dict = _ds_list[0] if ignore_auto_select else ds

if data.get('id') and data.get('id') != 0:
_datasource = data['id']
Expand Down Expand Up @@ -515,7 +523,7 @@ def select_datasource(self):
except Exception as e:
_error = e

if not ignore_auto_select:
if not ignore_auto_select and not settings.EMBEDDING_ENABLED:
self.record = save_select_datasource_answer(session=self.session, record_id=self.record.id,
answer=orjson.dumps({'content': full_text}).decode(),
datasource=_datasource,
Expand Down
2 changes: 2 additions & 0 deletions backend/apps/datasource/embedding/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Author: Junjun
# Date: 2025/9/18
56 changes: 56 additions & 0 deletions backend/apps/datasource/embedding/ds_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Author: Junjun
# Date: 2025/9/18
import json
import math
import traceback

from apps.ai_model.embedding import EmbeddingModelCache
from apps.datasource.crud.datasource import get_table_schema, get_ds
from common.core.deps import SessionDep, CurrentUser


def cosine_similarity(vec_a, vec_b):
if len(vec_a) != len(vec_b):
raise ValueError("The vector dimension must be the same")

dot_product = sum(a * b for a, b in zip(vec_a, vec_b))

norm_a = math.sqrt(sum(a * a for a in vec_a))
norm_b = math.sqrt(sum(b * b for b in vec_b))

if norm_a == 0 or norm_b == 0:
return 0.0

return dot_product / (norm_a * norm_b)


def get_ds_embedding(session: SessionDep, current_user: CurrentUser, _ds_list, question: str):
_list = []
for _ds in _ds_list:
if _ds.get('id'):
ds = get_ds(session, _ds.get('id'))

table_schema = get_table_schema(session, current_user, ds)
ds_info = f"{ds.name}, {ds.description}\n"
ds_schema = ds_info + table_schema

_list.append({"id": ds.id, "ds_schema": ds_schema, "cosine_similarity": 0.0, "ds": ds})

if _list:
try:
text = [s.get('ds_schema') for s in _list]

model = EmbeddingModelCache.get_model()
results = model.embed_documents(text)

q_embedding = model.embed_query(question)
for index in range(len(results)):
item = results[index]
_list[index]['cosine_similarity'] = cosine_similarity(q_embedding, item)

_list.sort(key=lambda x: x['cosine_similarity'], reverse=True)
print(json.dumps(_list))
ds = _list[0].get('ds')
return {"id": ds.id, "name": ds.name, "description": ds.description}
except Exception:
traceback.print_exc()