diff --git a/backend/apps/datasource/crud/datasource.py b/backend/apps/datasource/crud/datasource.py index fc502f0f9..3a35fd478 100644 --- a/backend/apps/datasource/crud/datasource.py +++ b/backend/apps/datasource/crud/datasource.py @@ -392,45 +392,48 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat # field relation if tables and ds.table_relation: relations = list(filter(lambda x: x.get('shape') == 'edge', ds.table_relation)) - # Complete the missing table - # get tables in relation, remove irrelevant relation - embedding_table_ids = [s.get('id') for s in tables] - all_relations = list(filter(lambda x: x.get('source').get('cell') in embedding_table_ids or x.get('target').get( - 'cell') in embedding_table_ids, relations)) - - # get relation table ids, sub embedding table ids - relation_table_ids = [] - for r in all_relations: - relation_table_ids.append(r.get('source').get('cell')) - relation_table_ids.append(r.get('target').get('cell')) - relation_table_ids = list(set(relation_table_ids)) - # get table dict - table_records = session.query(CoreTable).filter(CoreTable.id.in_(list(map(int, relation_table_ids)))).all() - table_dict = {} - for ele in table_records: - table_dict[ele.id] = ele.table_name - - # get lost table ids - lost_table_ids = list(set(relation_table_ids) - set(embedding_table_ids)) - # get lost table schema and splice it - lost_tables = list(filter(lambda x: x.get('id') in lost_table_ids, all_tables)) - if lost_tables: - for s in lost_tables: - schema_str += s.get('schema_table') - - # get field dict - relation_field_ids = [] - for relation in all_relations: - relation_field_ids.append(relation.get('source').get('port')) - relation_field_ids.append(relation.get('target').get('port')) - relation_field_ids = list(set(relation_field_ids)) - field_records = session.query(CoreField).filter(CoreField.id.in_(list(map(int, relation_field_ids)))).all() - field_dict = {} - for ele in field_records: - field_dict[ele.id] = ele.field_name - - schema_str += '【Foreign keys】\n' - for ele in all_relations: - schema_str += f"{table_dict.get(int(ele.get('source').get('cell')))}.{field_dict.get(int(ele.get('source').get('port')))}={table_dict.get(int(ele.get('target').get('cell')))}.{field_dict.get(int(ele.get('target').get('port')))}\n" + if relations: + # Complete the missing table + # get tables in relation, remove irrelevant relation + embedding_table_ids = [s.get('id') for s in tables] + all_relations = list( + filter(lambda x: x.get('source').get('cell') in embedding_table_ids or x.get('target').get( + 'cell') in embedding_table_ids, relations)) + + # get relation table ids, sub embedding table ids + relation_table_ids = [] + for r in all_relations: + relation_table_ids.append(r.get('source').get('cell')) + relation_table_ids.append(r.get('target').get('cell')) + relation_table_ids = list(set(relation_table_ids)) + # get table dict + table_records = session.query(CoreTable).filter(CoreTable.id.in_(list(map(int, relation_table_ids)))).all() + table_dict = {} + for ele in table_records: + table_dict[ele.id] = ele.table_name + + # get lost table ids + lost_table_ids = list(set(relation_table_ids) - set(embedding_table_ids)) + # get lost table schema and splice it + lost_tables = list(filter(lambda x: x.get('id') in lost_table_ids, all_tables)) + if lost_tables: + for s in lost_tables: + schema_str += s.get('schema_table') + + # get field dict + relation_field_ids = [] + for relation in all_relations: + relation_field_ids.append(relation.get('source').get('port')) + relation_field_ids.append(relation.get('target').get('port')) + relation_field_ids = list(set(relation_field_ids)) + field_records = session.query(CoreField).filter(CoreField.id.in_(list(map(int, relation_field_ids)))).all() + field_dict = {} + for ele in field_records: + field_dict[ele.id] = ele.field_name + + if all_relations: + schema_str += '【Foreign keys】\n' + for ele in all_relations: + schema_str += f"{table_dict.get(int(ele.get('source').get('cell')))}.{field_dict.get(int(ele.get('source').get('port')))}={table_dict.get(int(ele.get('target').get('cell')))}.{field_dict.get(int(ele.get('target').get('port')))}\n" return schema_str