|
38 | 38 | from apps.datasource.crud.datasource import get_table_schema, get_table_schema_by_names, get_table_obj_by_ds, get_mschema_by_table_names |
39 | 39 | from apps.datasource.crud.permission import get_row_permission_filters, is_normal_user |
40 | 40 | from apps.datasource.embedding.ds_embedding import get_ds_embedding |
41 | | -from apps.datasource.models.datasource import CoreDatasource |
| 41 | +from apps.datasource.models.datasource import CoreDatasource, CoreTable, CoreField |
42 | 42 | from apps.db.db import exec_sql, get_version, check_connection |
43 | 43 | from apps.system.crud.assistant import AssistantOutDs, AssistantOutDsFactory, get_assistant_ds |
44 | 44 | from apps.system.crud.parameter_manage import get_groups |
@@ -244,18 +244,57 @@ def _get_all_tables_brief(self, session: Session) -> str: |
244 | 244 | lines.append(f"- {name}: {comment}") |
245 | 245 | return "\n".join(lines) |
246 | 246 |
|
| 247 | + def _get_all_table_relations(self, session: Session) -> str: |
| 248 | + """获取所有表关系(用于意图识别)""" |
| 249 | + if not self.ds.table_relation: |
| 250 | + return "" |
| 251 | + |
| 252 | + relations = [r for r in self.ds.table_relation if r.get('shape') == 'edge'] |
| 253 | + if not relations: |
| 254 | + return "" |
| 255 | + |
| 256 | + # 获取所有涉及的表ID和字段ID |
| 257 | + table_ids = set() |
| 258 | + field_ids = set() |
| 259 | + for r in relations: |
| 260 | + table_ids.add(r.get('source').get('cell')) |
| 261 | + table_ids.add(r.get('target').get('cell')) |
| 262 | + field_ids.add(r.get('source').get('port')) |
| 263 | + field_ids.add(r.get('target').get('port')) |
| 264 | + |
| 265 | + # 查询表名和字段名映射 |
| 266 | + table_records = session.query(CoreTable).filter(CoreTable.id.in_(list(map(int, table_ids)))).all() |
| 267 | + table_dict = {t.id: t.table_name for t in table_records} |
| 268 | + |
| 269 | + field_records = session.query(CoreField).filter(CoreField.id.in_(list(map(int, field_ids)))).all() |
| 270 | + field_dict = {f.id: f.field_name for f in field_records} |
| 271 | + |
| 272 | + # 拼接输出 |
| 273 | + lines = ["【Foreign keys】"] |
| 274 | + for r in relations: |
| 275 | + src_table = table_dict.get(int(r.get('source').get('cell'))) |
| 276 | + src_field = field_dict.get(int(r.get('source').get('port'))) |
| 277 | + tgt_table = table_dict.get(int(r.get('target').get('cell'))) |
| 278 | + tgt_field = field_dict.get(int(r.get('target').get('port'))) |
| 279 | + if src_table and src_field and tgt_table and tgt_field: |
| 280 | + lines.append(f"{src_table}.{src_field}={tgt_table}.{tgt_field}") |
| 281 | + |
| 282 | + return "\n".join(lines) if len(lines) > 1 else "" |
| 283 | + |
247 | 284 | def init_intent_messages(self, session: Session): |
248 | 285 | """构建意图识别的消息列表(仿照 init_messages)""" |
249 | 286 | last_messages = self.intent_logs[-1].messages if len(self.intent_logs) > 0 else [] |
250 | 287 |
|
251 | 288 | self.intent_message = [] |
252 | 289 | # 1. System Prompt |
253 | 290 | table_list = self._get_all_tables_brief(session) |
| 291 | + table_relations = self._get_all_table_relations(session) |
254 | 292 | template = get_base_template() |
255 | 293 | intent_template = template['template']['intent_recognition']['system'] |
256 | 294 | self.intent_message.append(SystemMessage( |
257 | 295 | content=intent_template.format( |
258 | 296 | table_list=table_list, |
| 297 | + table_relations=table_relations, |
259 | 298 | lang=self.chat_question.lang |
260 | 299 | ) |
261 | 300 | )) |
|
0 commit comments