Skip to content

Commit 0e9837c

Browse files
committed
feat: Vector retrieval matches datasource
1 parent 0b53b62 commit 0e9837c

File tree

3 files changed

+110
-44
lines changed

3 files changed

+110
-44
lines changed

backend/apps/chat/task/llm.py

Lines changed: 52 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from apps.data_training.curd.data_training import get_training_template
3333
from apps.datasource.crud.datasource import get_table_schema
3434
from apps.datasource.crud.permission import get_row_permission_filters, is_normal_user
35+
from apps.datasource.embedding.ds_embedding import get_ds_embedding
3536
from apps.datasource.models.datasource import CoreDatasource
3637
from apps.db.db import exec_sql, get_version, check_connection
3738
from apps.system.crud.assistant import AssistantOutDs, AssistantOutDsFactory, get_assistant_ds
@@ -427,56 +428,63 @@ def select_datasource(self):
427428
full_thinking_text = ''
428429
full_text = ''
429430

431+
ds = None
430432
if not ignore_auto_select:
431-
_ds_list_dict = []
432-
for _ds in _ds_list:
433-
_ds_list_dict.append(_ds)
434-
datasource_msg.append(
435-
HumanMessage(self.chat_question.datasource_user_question(orjson.dumps(_ds_list_dict).decode())))
436-
437-
self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = start_log(session=self.session,
438-
ai_modal_id=self.chat_question.ai_modal_id,
439-
ai_modal_name=self.chat_question.ai_modal_name,
440-
operate=OperationEnum.CHOOSE_DATASOURCE,
441-
record_id=self.record.id,
442-
full_message=[{'type': msg.type,
443-
'content': msg.content} for
444-
msg in datasource_msg])
445-
446-
token_usage = {}
447-
res = self.llm.stream(datasource_msg)
448-
for chunk in res:
449-
SQLBotLogUtil.info(chunk)
450-
reasoning_content_chunk = ''
451-
if 'reasoning_content' in chunk.additional_kwargs:
452-
reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
453-
# else:
454-
# reasoning_content_chunk = chunk.get('reasoning_content')
455-
if reasoning_content_chunk is None:
433+
if settings.EMBEDDING_ENABLED:
434+
ds = get_ds_embedding(self.session, self.current_user, _ds_list, self.chat_question.question)
435+
else:
436+
_ds_list_dict = []
437+
for _ds in _ds_list:
438+
_ds_list_dict.append(_ds)
439+
datasource_msg.append(
440+
HumanMessage(self.chat_question.datasource_user_question(orjson.dumps(_ds_list_dict).decode())))
441+
442+
self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = start_log(session=self.session,
443+
ai_modal_id=self.chat_question.ai_modal_id,
444+
ai_modal_name=self.chat_question.ai_modal_name,
445+
operate=OperationEnum.CHOOSE_DATASOURCE,
446+
record_id=self.record.id,
447+
full_message=[{'type': msg.type,
448+
'content': msg.content}
449+
for
450+
msg in datasource_msg])
451+
452+
token_usage = {}
453+
res = self.llm.stream(datasource_msg)
454+
for chunk in res:
455+
SQLBotLogUtil.info(chunk)
456456
reasoning_content_chunk = ''
457-
full_thinking_text += reasoning_content_chunk
458-
459-
full_text += chunk.content
460-
yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
461-
get_token_usage(chunk, token_usage)
462-
datasource_msg.append(AIMessage(full_text))
463-
464-
self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = end_log(session=self.session,
465-
log=self.current_logs[
466-
OperationEnum.CHOOSE_DATASOURCE],
467-
full_message=[
468-
{'type': msg.type, 'content': msg.content}
469-
for msg in datasource_msg],
470-
reasoning_content=full_thinking_text,
471-
token_usage=token_usage)
472-
473-
json_str = extract_nested_json(full_text)
457+
if 'reasoning_content' in chunk.additional_kwargs:
458+
reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
459+
# else:
460+
# reasoning_content_chunk = chunk.get('reasoning_content')
461+
if reasoning_content_chunk is None:
462+
reasoning_content_chunk = ''
463+
full_thinking_text += reasoning_content_chunk
464+
465+
full_text += chunk.content
466+
yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
467+
get_token_usage(chunk, token_usage)
468+
datasource_msg.append(AIMessage(full_text))
469+
470+
self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = end_log(session=self.session,
471+
log=self.current_logs[
472+
OperationEnum.CHOOSE_DATASOURCE],
473+
full_message=[
474+
{'type': msg.type,
475+
'content': msg.content}
476+
for msg in datasource_msg],
477+
reasoning_content=full_thinking_text,
478+
token_usage=token_usage)
479+
480+
json_str = extract_nested_json(full_text)
481+
ds = orjson.loads(json_str)
474482

475483
_error: Exception | None = None
476484
_datasource: int | None = None
477485
_engine_type: str | None = None
478486
try:
479-
data: dict = _ds_list[0] if ignore_auto_select else orjson.loads(json_str)
487+
data: dict = _ds_list[0] if ignore_auto_select else ds
480488

481489
if data.get('id') and data.get('id') != 0:
482490
_datasource = data['id']
@@ -515,7 +523,7 @@ def select_datasource(self):
515523
except Exception as e:
516524
_error = e
517525

518-
if not ignore_auto_select:
526+
if not ignore_auto_select and not settings.EMBEDDING_ENABLED:
519527
self.record = save_select_datasource_answer(session=self.session, record_id=self.record.id,
520528
answer=orjson.dumps({'content': full_text}).decode(),
521529
datasource=_datasource,
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Author: Junjun
2+
# Date: 2025/9/18
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Author: Junjun
2+
# Date: 2025/9/18
3+
import json
4+
import math
5+
import traceback
6+
7+
from apps.ai_model.embedding import EmbeddingModelCache
8+
from apps.datasource.crud.datasource import get_table_schema, get_ds
9+
from common.core.deps import SessionDep, CurrentUser
10+
11+
12+
def cosine_similarity(vec_a, vec_b):
    """Return the cosine similarity of two equal-length numeric vectors.

    Args:
        vec_a: First vector (any sized iterable of numbers).
        vec_b: Second vector, same length as ``vec_a``.

    Returns:
        The dot product divided by the product of the Euclidean norms, or
        ``0.0`` when either vector has zero norm (including empty vectors).

    Raises:
        ValueError: If the two vectors differ in length.
    """
    size_a, size_b = len(vec_a), len(vec_b)
    if size_a != size_b:
        raise ValueError("The vector dimension must be the same")

    def _magnitude(vec):
        # Euclidean norm: square root of the sum of squared components.
        return math.sqrt(sum(component * component for component in vec))

    inner = sum(x * y for x, y in zip(vec_a, vec_b))
    mag_a = _magnitude(vec_a)
    mag_b = _magnitude(vec_b)

    # A zero-length vector has no direction; report "no similarity" instead
    # of dividing by zero.
    if mag_a != 0 and mag_b != 0:
        return inner / (mag_a * mag_b)
    return 0.0
25+
26+
27+
def get_ds_embedding(session: SessionDep, current_user: CurrentUser, _ds_list, question: str):
    """Pick the datasource whose schema text is most similar to the question.

    Every entry of ``_ds_list`` that carries a truthy ``id`` is loaded with
    ``get_ds`` and described by its name, description and table schema. The
    descriptions and the question are embedded with the model returned by
    ``EmbeddingModelCache.get_model()`` and ranked by cosine similarity.

    Args:
        session: Database session, passed through to ``get_ds`` and
            ``get_table_schema``.
        current_user: Current user, passed through to ``get_table_schema``.
        _ds_list: Iterable of dict-like datasource entries; entries without a
            truthy ``id`` are ignored.
        question: The user's question text.

    Returns:
        A dict with the ``id``, ``name`` and ``description`` of the best
        matching datasource, or ``None`` when there is no candidate or the
        embedding step fails (the failure is printed, not raised).
    """
    candidates = []
    for _ds in _ds_list:
        ds_id = _ds.get('id')
        if not ds_id:
            continue

        ds = get_ds(session, ds_id)
        if ds is None:
            # Defensive: an id that does not resolve to a datasource cannot be ranked.
            continue

        table_schema = get_table_schema(session, current_user, ds)
        ds_schema = f"{ds.name}, {ds.description}\n" + table_schema
        candidates.append((ds, ds_schema))

    if not candidates:
        return None

    try:
        model = EmbeddingModelCache.get_model()
        doc_embeddings = model.embed_documents([schema for _, schema in candidates])
        q_embedding = model.embed_query(question)

        scored = [
            (cosine_similarity(q_embedding, embedding), ds)
            for (ds, _), embedding in zip(candidates, doc_embeddings)
        ]
        # NOTE: the previous debug `print(json.dumps(...))` tried to serialise
        # the datasource objects themselves, raised TypeError and made this
        # function always fall into the except branch. It is intentionally gone.
        # max() returns the first maximal element, matching the tie-breaking
        # of the earlier stable descending sort.
        _, best = max(scored, key=lambda pair: pair[0])
        return {"id": best.id, "name": best.name, "description": best.description}
    except Exception:
        # Best-effort: a failing embedding backend must not break the caller;
        # report the problem and signal "no selection".
        traceback.print_exc()
        return None

0 commit comments

Comments
 (0)