Skip to content

Commit bfd8561

Browse files
committed
feat(llm): 添加获取所有表关系的功能并更新意图识别模板
1 parent 1947a22 commit bfd8561

2 files changed

Lines changed: 43 additions & 1 deletion

File tree

backend/apps/chat/task/llm.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from apps.datasource.crud.datasource import get_table_schema, get_table_schema_by_names, get_table_obj_by_ds, get_mschema_by_table_names
3939
from apps.datasource.crud.permission import get_row_permission_filters, is_normal_user
4040
from apps.datasource.embedding.ds_embedding import get_ds_embedding
41-
from apps.datasource.models.datasource import CoreDatasource
41+
from apps.datasource.models.datasource import CoreDatasource, CoreTable, CoreField
4242
from apps.db.db import exec_sql, get_version, check_connection
4343
from apps.system.crud.assistant import AssistantOutDs, AssistantOutDsFactory, get_assistant_ds
4444
from apps.system.crud.parameter_manage import get_groups
@@ -244,18 +244,57 @@ def _get_all_tables_brief(self, session: Session) -> str:
244244
lines.append(f"- {name}: {comment}")
245245
return "\n".join(lines)
246246

247+
def _get_all_table_relations(self, session: Session) -> str:
248+
"""获取所有表关系(用于意图识别)"""
249+
if not self.ds.table_relation:
250+
return ""
251+
252+
relations = [r for r in self.ds.table_relation if r.get('shape') == 'edge']
253+
if not relations:
254+
return ""
255+
256+
# 获取所有涉及的表ID和字段ID
257+
table_ids = set()
258+
field_ids = set()
259+
for r in relations:
260+
table_ids.add(r.get('source').get('cell'))
261+
table_ids.add(r.get('target').get('cell'))
262+
field_ids.add(r.get('source').get('port'))
263+
field_ids.add(r.get('target').get('port'))
264+
265+
# 查询表名和字段名映射
266+
table_records = session.query(CoreTable).filter(CoreTable.id.in_(list(map(int, table_ids)))).all()
267+
table_dict = {t.id: t.table_name for t in table_records}
268+
269+
field_records = session.query(CoreField).filter(CoreField.id.in_(list(map(int, field_ids)))).all()
270+
field_dict = {f.id: f.field_name for f in field_records}
271+
272+
# 拼接输出
273+
lines = ["【Foreign keys】"]
274+
for r in relations:
275+
src_table = table_dict.get(int(r.get('source').get('cell')))
276+
src_field = field_dict.get(int(r.get('source').get('port')))
277+
tgt_table = table_dict.get(int(r.get('target').get('cell')))
278+
tgt_field = field_dict.get(int(r.get('target').get('port')))
279+
if src_table and src_field and tgt_table and tgt_field:
280+
lines.append(f"{src_table}.{src_field}={tgt_table}.{tgt_field}")
281+
282+
return "\n".join(lines) if len(lines) > 1 else ""
283+
247284
def init_intent_messages(self, session: Session):
248285
"""构建意图识别的消息列表(仿照 init_messages)"""
249286
last_messages = self.intent_logs[-1].messages if len(self.intent_logs) > 0 else []
250287

251288
self.intent_message = []
252289
# 1. System Prompt
253290
table_list = self._get_all_tables_brief(session)
291+
table_relations = self._get_all_table_relations(session)
254292
template = get_base_template()
255293
intent_template = template['template']['intent_recognition']['system']
256294
self.intent_message.append(SystemMessage(
257295
content=intent_template.format(
258296
table_list=table_list,
297+
table_relations=table_relations,
259298
lang=self.chat_question.lang
260299
)
261300
))

backend/templates/template.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,9 @@ template:
639639
- 当前数据库表列表:
640640
{table_list}
641641
642+
- 表关系:
643+
{table_relations}
644+
642645
[用户输入格式]
643646
用户的提问在 <user-question> 内,<background-infos> 内的 <current-time> 会告诉你用户当前提问的时间。
644647
请结合当前时间理解用户问题中的时间相关表述(如"今年"、"上个月"、"最近"等)。

0 commit comments

Comments
 (0)