From 643690c123b59842172f3df9514b78a2a4e75fb4 Mon Sep 17 00:00:00 2001 From: junjun Date: Wed, 24 Sep 2025 17:33:58 +0800 Subject: [PATCH] feat: add table relation --- backend/apps/datasource/crud/datasource.py | 58 +++++++++++++++++-- .../datasource/embedding/table_embedding.py | 10 ++-- 2 files changed, 58 insertions(+), 10 deletions(-) diff --git a/backend/apps/datasource/crud/datasource.py b/backend/apps/datasource/crud/datasource.py index 719ad79f0..fc502f0f9 100644 --- a/backend/apps/datasource/crud/datasource.py +++ b/backend/apps/datasource/crud/datasource.py @@ -354,6 +354,7 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat db_name = table_objs[0].schema schema_str += f"【DB_ID】 {db_name}\n【Schema】\n" tables = [] + all_tables = [] # temp save all tables for obj in table_objs: schema_table = '' schema_table += f"# Table: {db_name}.{obj.table.table_name}" if ds.type != "mysql" and ds.type != "es" else f"# Table: {obj.table.table_name}" @@ -376,13 +377,60 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat field_list.append(f"({field.field_name}:{field.field_type}, {field_comment})") schema_table += ",\n".join(field_list) schema_table += '\n]\n' - tables.append(schema_table) + t_obj = {"id": obj.table.id, "schema_table": schema_table} + tables.append(t_obj) + all_tables.append(t_obj) # do table embedding - if embedding: + if embedding and tables: tables = get_table_embedding(session, current_user, tables, question) + # splice schema + if tables: + for s in tables: + schema_str += s.get('schema_table') + + # field relation + if tables and ds.table_relation: + relations = list(filter(lambda x: x.get('shape') == 'edge', ds.table_relation)) + # Complete the missing table + # get tables in relation, remove irrelevant relation + embedding_table_ids = [s.get('id') for s in tables] + all_relations = list(filter(lambda x: x.get('source').get('cell') in embedding_table_ids or x.get('target').get( + 'cell') in embedding_table_ids, relations)) + + # get relation table ids, sub embedding table ids + relation_table_ids = [] + for r in all_relations: + relation_table_ids.append(r.get('source').get('cell')) + relation_table_ids.append(r.get('target').get('cell')) + relation_table_ids = list(set(relation_table_ids)) + # get table dict + table_records = session.query(CoreTable).filter(CoreTable.id.in_(list(map(int, relation_table_ids)))).all() + table_dict = {} + for ele in table_records: + table_dict[ele.id] = ele.table_name + + # get lost table ids + lost_table_ids = list(set(relation_table_ids) - set(embedding_table_ids)) + # get lost table schema and splice it + lost_tables = list(filter(lambda x: x.get('id') in lost_table_ids, all_tables)) + if lost_tables: + for s in lost_tables: + schema_str += s.get('schema_table') + + # get field dict + relation_field_ids = [] + for relation in all_relations: + relation_field_ids.append(relation.get('source').get('port')) + relation_field_ids.append(relation.get('target').get('port')) + relation_field_ids = list(set(relation_field_ids)) + field_records = session.query(CoreField).filter(CoreField.id.in_(list(map(int, relation_field_ids)))).all() + field_dict = {} + for ele in field_records: + field_dict[ele.id] = ele.field_name + + schema_str += '【Foreign keys】\n' + for ele in all_relations: + schema_str += f"{table_dict.get(int(ele.get('source').get('cell')))}.{field_dict.get(int(ele.get('source').get('port')))}={table_dict.get(int(ele.get('target').get('cell')))}.{field_dict.get(int(ele.get('target').get('port')))}\n" - # todo 外键 - for s in tables: - schema_str += s return schema_str diff --git a/backend/apps/datasource/embedding/table_embedding.py b/backend/apps/datasource/embedding/table_embedding.py index e1827174e..a14ad9891 100644 --- a/backend/apps/datasource/embedding/table_embedding.py +++ b/backend/apps/datasource/embedding/table_embedding.py @@ -10,14 +10,14 @@ from common.utils.utils import SQLBotLogUtil -def get_table_embedding(session: SessionDep, current_user: CurrentUser, tables: list[str], question: str): +def get_table_embedding(session: SessionDep, current_user: CurrentUser, tables: list[dict], question: str): _list = [] - for table_schema in tables: - _list.append({"table_schema": table_schema, "cosine_similarity": 0.0}) + for table in tables: + _list.append({"id": table.get('id'), "schema_table": table.get('schema_table'), "cosine_similarity": 0.0}) if _list: try: - text = [s.get('table_schema') for s in _list] + text = [s.get('schema_table') for s in _list] model = EmbeddingModelCache.get_model() results = model.embed_documents(text) @@ -31,7 +31,7 @@ def get_table_embedding(session: SessionDep, current_user: CurrentUser, tables: _list = _list[:settings.TABLE_EMBEDDING_COUNT] # print(len(_list)) SQLBotLogUtil.info(json.dumps(_list)) - return [t.get("table_schema") for t in _list] + return _list except Exception: traceback.print_exc() return _list