Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions backend/apps/chat/task/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,9 @@ def generate_analysis(self):
self.chat_question.data = orjson.dumps(data.get('data')).decode()
analysis_msg: List[Union[BaseMessage, dict[str, Any]]] = []

ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None
self.chat_question.terminologies = get_terminology_template(self.session, self.chat_question.question,
self.current_user.oid)
self.current_user.oid, ds_id)

analysis_msg.append(SystemMessage(content=self.chat_question.analysis_sys_question()))
analysis_msg.append(HumanMessage(content=self.chat_question.analysis_user_question()))
Expand Down Expand Up @@ -504,7 +505,8 @@ def select_datasource(self):
oid = self.ds.oid if isinstance(self.ds, CoreDatasource) else 1
ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None

self.chat_question.terminologies = get_terminology_template(self.session, self.chat_question.question, oid)
self.chat_question.terminologies = get_terminology_template(self.session, self.chat_question.question, oid,
ds_id)
self.chat_question.data_training = get_training_template(self.session, self.chat_question.question, ds_id,
oid)

Expand Down Expand Up @@ -897,7 +899,7 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
oid = self.ds.oid if isinstance(self.ds, CoreDatasource) else 1
ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None
self.chat_question.terminologies = get_terminology_template(self.session, self.chat_question.question,
oid)
oid, ds_id)
self.chat_question.data_training = get_training_template(self.session, self.chat_question.question,
ds_id, oid)

Expand Down
42 changes: 34 additions & 8 deletions backend/apps/terminology/curd/terminology.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import datetime
import logging
import traceback
from typing import List, Optional
from typing import List, Optional, Any
from xml.dom.minidom import parseString

import dicttoxml
Expand Down Expand Up @@ -367,17 +367,22 @@ def save_embeddings(session: Session, ids: List[int]):
embedding_sql = f"""
SELECT id, pid, word, similarity
FROM
(SELECT id, pid, word, oid,
(SELECT id, pid, word, oid, specific_ds, datasource_ids,
( 1 - (embedding <=> :embedding_array) ) AS similarity
FROM terminology AS child
) TEMP
WHERE similarity > {settings.EMBEDDING_TERMINOLOGY_SIMILARITY} and oid = :oid
WHERE similarity > {settings.EMBEDDING_TERMINOLOGY_SIMILARITY} AND oid = :oid
AND (
(:datasource IS NULL AND (specific_ds = false OR specific_ds IS NULL))
OR
(:datasource IS NOT NULL AND ((specific_ds = false OR specific_ds IS NULL) OR (specific_ds = true AND datasource_ids IS NOT NULL AND datasource_ids @> jsonb_build_array(:datasource))))
)
ORDER BY similarity DESC
LIMIT {settings.EMBEDDING_TERMINOLOGY_TOP_COUNT}
"""


def select_terminology_by_word(session: SessionDep, word: str, oid: int):
def select_terminology_by_word(session: SessionDep, word: str, oid: int, datasource: int = None):
if word.strip() == "":
return []

Expand All @@ -394,7 +399,26 @@ def select_terminology_by_word(session: SessionDep, word: str, oid: int):
)
)

results = session.execute(stmt, {'sentence': word}).fetchall()
if datasource is not None:
stmt = stmt.where(
or_(
or_(Terminology.specific_ds == False, Terminology.specific_ds.is_(None)),
and_(
Terminology.specific_ds == True,
Terminology.datasource_ids.isnot(None),
text("datasource_ids @> jsonb_build_array(:datasource)")
)
)
)
else:
stmt = stmt.where(or_(Terminology.specific_ds == False, Terminology.specific_ds.is_(None)))

# 执行查询
params: dict[str, Any] = {'sentence': word}
if datasource is not None:
params['datasource'] = datasource

results = session.execute(stmt, params).fetchall()

for row in results:
_list.append(Terminology(id=row.id, word=row.word, pid=row.pid))
Expand All @@ -405,7 +429,8 @@ def select_terminology_by_word(session: SessionDep, word: str, oid: int):

embedding = model.embed_query(word)

results = session.execute(text(embedding_sql), {'embedding_array': str(embedding), 'oid': oid})
results = session.execute(text(embedding_sql), {'embedding_array': str(embedding), 'oid': oid,
'datasource': datasource}).fetchall()

for row in results:
_list.append(Terminology(id=row.id, word=row.word, pid=row.pid))
Expand Down Expand Up @@ -481,10 +506,11 @@ def to_xml_string(_dict: list[dict] | dict, root: str = 'terminologies') -> str:
return pretty_xml


def get_terminology_template(session: SessionDep, question: str, oid: Optional[int] = 1) -> str:
def get_terminology_template(session: SessionDep, question: str, oid: Optional[int] = 1,
datasource: Optional[int] = None) -> str:
if not oid:
oid = 1
_results = select_terminology_by_word(session, question, oid)
_results = select_terminology_by_word(session, question, oid, datasource)
if _results and len(_results) > 0:
terminology = to_xml_string(_results)
template = get_base_terminology_template().format(terminologies=terminology)
Expand Down
23 changes: 18 additions & 5 deletions backend/template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ template:
<Instruction>
你是"SQLBOT",智能问数小助手,可以根据用户提问,专业生成SQL与可视化图表。
你当前的任务是根据给定的表结构和用户问题生成SQL语句、可能适合展示的图表类型以及该SQL中所用到的表名。
我们会在<Infos>块内提供给你信息,帮助你生成SQL:
<Infos>内有<db-engine><m-schema><terminologies>等信息;
我们会在<Info>块内提供给你信息,帮助你生成SQL:
<Info>内有<db-engine><m-schema><terminologies>等信息;
其中,<db-engine>:提供数据库引擎及版本信息;
<m-schema>:以 M-Schema 格式提供数据库表结构信息;
<terminologies>:提供一组术语,块内每一个<terminology>就是术语,其中同一个<words>内的多个<word>代表术语的多种叫法,也就是术语与它的同义词,<description>即该术语对应的描述,其中也可能是能够用来参考的计算公式,或者是一些其他的查询条件
Expand Down Expand Up @@ -389,12 +389,25 @@ template:
{old_questions}
analysis:
system: |
### 请使用语言:{lang} 回答,若有深度思考过程,则思考过程也需要使用 {lang} 输出
<Instruction>
你是"SQLBOT",智能问数小助手,可以根据用户提问,专业生成SQL与可视化图表。
你当前的任务是根据给定的数据分析数据,并给出你的分析结果。
我们会在<Info>块内提供给你信息,帮助你进行分析:
<Info>内有<terminologies>等信息;
<terminologies>:提供一组术语,块内每一个<terminology>就是术语,其中同一个<words>内的多个<word>代表术语的多种叫法,也就是术语与它的同义词,<description>即该术语对应的描述,其中也可能是能够用来参考的计算公式,或者是一些其他的查询条件
</Instruction>

### 说明:
你是一个数据分析师,你的任务是根据给定的数据分析数据,并给出你的分析结果。
你必须遵守以下规则:
<Rules>
<rule>
请使用语言:{lang} 回答,若有深度思考过程,则思考过程也需要使用 {lang} 输出
</rule>
</Rules>

### 下面是提供的信息
<Info>
{terminologies}
</Info>
user: |
### 字段(字段别名):
{fields}
Expand Down