diff --git a/backend/alembic/versions/047_table_embedding.py b/backend/alembic/versions/047_table_embedding.py
index 1b6c766ec..ba6b89604 100644
--- a/backend/alembic/versions/047_table_embedding.py
+++ b/backend/alembic/versions/047_table_embedding.py
@@ -20,10 +20,12 @@
 def upgrade():
     # ### commands auto generated by Alembic - please adjust! ###
     op.add_column('core_table', sa.Column('embedding', sa.Text(), nullable=True))
+    op.add_column('core_datasource', sa.Column('embedding', sa.Text(), nullable=True))
     # ### end Alembic commands ###
 
 
 def downgrade():
     # ### commands auto generated by Alembic - please adjust! ###
     op.drop_column('core_table', 'embedding')
+    op.drop_column('core_datasource', 'embedding')
     # ### end Alembic commands ###
diff --git a/backend/apps/datasource/crud/datasource.py b/backend/apps/datasource/crud/datasource.py
index f9339bdc4..ba0965fc8 100644
--- a/backend/apps/datasource/crud/datasource.py
+++ b/backend/apps/datasource/crud/datasource.py
@@ -15,7 +15,7 @@
 from apps.db.engine import get_engine_config, get_engine_conn
 from common.core.config import settings
 from common.core.deps import SessionDep, CurrentUser, Trans
-from common.utils.embedding_threads import run_save_table_embeddings
+from common.utils.embedding_threads import run_save_table_embeddings, run_save_ds_embeddings
 from common.utils.utils import deepcopy_ignore_extra
 from .table import get_tables_by_ds_id
 from ..crud.field import delete_field_by_ds_id, update_field
@@ -105,6 +105,8 @@ def update_ds(session: SessionDep, trans: Trans, user: CurrentUser, ds: CoreData
         setattr(record, field, value)
     session.add(record)
     session.commit()
+
+    run_save_ds_embeddings([ds.id])
     return ds
 
 
@@ -197,6 +199,7 @@ def sync_table(session: SessionDep, ds: CoreDatasource, tables: List[CoreTable])
 
     # do table embedding
     run_save_table_embeddings(id_list)
+    run_save_ds_embeddings([ds.id])
 
 
 def sync_fields(session: SessionDep, ds: CoreDatasource, table: CoreTable, fields: List[ColumnSchema]):
@@ -238,6 +241,7 @@ def update_table_and_fields(session: SessionDep, data: TableObj):
 
     # do table embedding
     run_save_table_embeddings([data.table.id])
+    run_save_ds_embeddings([data.table.ds_id])
 
 
 def updateTable(session: SessionDep, table: CoreTable):
@@ -245,6 +249,7 @@ def updateTable(session: SessionDep, table: CoreTable):
 
     # do table embedding
     run_save_table_embeddings([table.id])
+    run_save_ds_embeddings([table.ds_id])
 
 
 def updateField(session: SessionDep, field: CoreField):
@@ -252,6 +257,7 @@ def updateField(session: SessionDep, field: CoreField):
 
     # do table embedding
     run_save_table_embeddings([field.table_id])
+    run_save_ds_embeddings([field.ds_id])
 
 
 def preview(session: SessionDep, current_user: CurrentUser, id: int, data: TableObj):
diff --git a/backend/apps/datasource/crud/table.py b/backend/apps/datasource/crud/table.py
index 535d343b8..d5685e575 100644
--- a/backend/apps/datasource/crud/table.py
+++ b/backend/apps/datasource/crud/table.py
@@ -9,7 +9,7 @@
 from common.core.config import settings
 from common.core.deps import SessionDep
 from common.utils.utils import SQLBotLogUtil
-from ..models.datasource import CoreTable, CoreField
+from ..models.datasource import CoreTable, CoreField, CoreDatasource
 
 
 def delete_table_by_ds_id(session: SessionDep, id: int):
@@ -30,18 +30,27 @@ def update_table(session: SessionDep, item: CoreTable):
     session.commit()
 
 
-def run_fill_empty_table_embedding(session_maker):
+def run_fill_empty_table_and_ds_embedding(session_maker):
     try:
         if not settings.TABLE_EMBEDDING_ENABLED:
             return
 
-        SQLBotLogUtil.info('get tables')
         session = session_maker()
+
+        SQLBotLogUtil.info('get tables')
         stmt = select(CoreTable.id).where(and_(CoreTable.embedding.is_(None)))
         results = session.execute(stmt).scalars().all()
-        SQLBotLogUtil.info('result: ' + str(len(results)))
-
-        save_table_embedding(session, results)
+        SQLBotLogUtil.info('table result: ' + str(len(results)))
+
+        SQLBotLogUtil.info('get datasource')
+        ds_stmt = select(CoreDatasource.id).where(and_(CoreDatasource.embedding.is_(None)))
+        ds_results = session.execute(ds_stmt).scalars().all()
+        SQLBotLogUtil.info('datasource result: ' + str(len(ds_results)))
+
+        # Run both queries before saving: each save_* call finishes with
+        # session_maker.remove(), so `session` must not be reused afterwards.
+        save_table_embedding(session_maker, results)
+        save_ds_embedding(session_maker, ds_results)
     except Exception:
         traceback.print_exc()
     finally:
@@ -98,3 +107,66 @@ def save_table_embedding(session_maker, ids: List[int]):
         traceback.print_exc()
     finally:
         session_maker.remove()
+
+
+def save_ds_embedding(session_maker, ids: List[int]):
+    """Embed each datasource in ``ids`` and store the vector as JSON text.
+
+    The embedded text is the datasource name and description followed by its
+    tables and fields. No-op when TABLE_EMBEDDING_ENABLED is off or ids is empty.
+    """
+    if not settings.TABLE_EMBEDDING_ENABLED:
+        return
+
+    if not ids:
+        return
+    try:
+        SQLBotLogUtil.info('start datasource embedding')
+        start_time = time.time()
+        model = EmbeddingModelCache.get_model()
+        session = session_maker()
+        for _id in ids:
+            schema_table = ''
+            ds = session.query(CoreDatasource).filter(CoreDatasource.id == _id).first()
+            if ds is None:
+                # row may be gone by the time this task runs; keep processing the other ids
+                continue
+            schema_table += f"{ds.name}, {ds.description}\n"
+            tables = session.query(CoreTable).filter(CoreTable.ds_id == ds.id).all()
+            for table in tables:
+                fields = session.query(CoreField).filter(CoreField.table_id == table.id).all()
+
+                schema_table += f"# Table: {table.table_name}"
+                table_comment = ''
+                if table.custom_comment:
+                    table_comment = table.custom_comment.strip()
+                if table_comment == '':
+                    schema_table += '\n[\n'
+                else:
+                    schema_table += f", {table_comment}\n[\n"
+
+                if fields:
+                    field_list = []
+                    for field in fields:
+                        field_comment = ''
+                        if field.custom_comment:
+                            field_comment = field.custom_comment.strip()
+                        if field_comment == '':
+                            field_list.append(f"({field.field_name}:{field.field_type})")
+                        else:
+                            field_list.append(f"({field.field_name}:{field.field_type}, {field_comment})")
+                    schema_table += ",\n".join(field_list)
+                schema_table += '\n]\n'
+            # table_schema.append(schema_table)
+            emb = json.dumps(model.embed_query(schema_table))
+
+            stmt = update(CoreDatasource).where(and_(CoreDatasource.id == _id)).values(embedding=emb)
+            session.execute(stmt)
+            session.commit()
+
+        end_time = time.time()
+        SQLBotLogUtil.info('datasource embedding finished in: ' + str(end_time - start_time) + ' seconds')
+    except Exception:
+        traceback.print_exc()
+    finally:
+        session_maker.remove()
diff --git a/backend/apps/datasource/embedding/ds_embedding.py b/backend/apps/datasource/embedding/ds_embedding.py
index a3570178e..34ee7c7b1 100644
--- a/backend/apps/datasource/embedding/ds_embedding.py
+++ b/backend/apps/datasource/embedding/ds_embedding.py
@@ -1,11 +1,11 @@
 # Author: Junjun
 # Date: 2025/9/18
 import json
+import time
 import traceback
 from typing import Optional
 
 from apps.ai_model.embedding import EmbeddingModelCache
-from apps.datasource.crud.datasource import get_table_schema
 from apps.datasource.embedding.utils import cosine_similarity
 from apps.datasource.models.datasource import CoreDatasource
 from apps.system.crud.assistant import AssistantOutDs
@@ -18,7 +18,7 @@ def get_ds_embedding(session: SessionDep, current_user: CurrentUser, _ds_list, o
                      question: str,
                      current_assistant: Optional[CurrentAssistant] = None):
     _list = []
-    if current_assistant and current_assistant.type != 4:
+    if current_assistant and current_assistant.type == 1:
         if out_ds.ds_list:
             for _ds in out_ds.ds_list:
                 ds = out_ds.get_ds(_ds.id)
@@ -26,34 +26,64 @@ def get_ds_embedding(session: SessionDep, current_user: CurrentUser, _ds_list, o
                 ds_info = f"{ds.name}, {ds.description}\n"
                 ds_schema = ds_info + table_schema
                 _list.append({"id": ds.id, "ds_schema": ds_schema, "cosine_similarity": 0.0, "ds": ds})
+
+        if _list:
+            try:
+                text = [s.get('ds_schema') for s in _list]
+
+                model = EmbeddingModelCache.get_model()
+                results = model.embed_documents(text)
+
+                q_embedding = model.embed_query(question)
+                for index in range(len(results)):
+                    item = results[index]
+                    _list[index]['cosine_similarity'] = cosine_similarity(q_embedding, item)
+
+                _list.sort(key=lambda x: x['cosine_similarity'], reverse=True)
+                # print(len(_list))
+                SQLBotLogUtil.info(json.dumps(
+                    [{"id": ele.get("id"), "name": ele.get("ds").name,
+                      "cosine_similarity": ele.get("cosine_similarity")}
+                     for ele in _list]))
+                ds = _list[0].get('ds')
+                return {"id": ds.id, "name": ds.name, "description": ds.description}
+            except Exception:
+                traceback.print_exc()
     else:
         for _ds in _ds_list:
             if _ds.get('id'):
                 ds = session.get(CoreDatasource, _ds.get('id'))
-                table_schema = get_table_schema(session, current_user, ds, question, embedding=False)
-                ds_info = f"{ds.name}, {ds.description}\n"
-                ds_schema = ds_info + table_schema
-                _list.append({"id": ds.id, "ds_schema": ds_schema, "cosine_similarity": 0.0, "ds": ds})
+                # table_schema = get_table_schema(session, current_user, ds, question, embedding=False)
+                # ds_info = f"{ds.name}, {ds.description}\n"
+                # ds_schema = ds_info + table_schema
+                _list.append({"id": ds.id, "cosine_similarity": 0.0, "ds": ds, "embedding": ds.embedding})
+
+        if _list:
+            try:
+                # text = [s.get('ds_schema') for s in _list]
+
+                model = EmbeddingModelCache.get_model()
+                start_time = time.time()
+                # results = model.embed_documents(text)
+                results = [item.get('embedding') for item in _list]
+
+                q_embedding = model.embed_query(question)
+                for index in range(len(results)):
+                    item = results[index]
+                    if item:
+                        # the column holds JSON text written with json.dumps; decode it to a vector first
+                        _list[index]['cosine_similarity'] = cosine_similarity(q_embedding, json.loads(item))
 
-    if _list:
-        try:
-            text = [s.get('ds_schema') for s in _list]
-
-            model = EmbeddingModelCache.get_model()
-            results = model.embed_documents(text)
-
-            q_embedding = model.embed_query(question)
-            for index in range(len(results)):
-                item = results[index]
-                _list[index]['cosine_similarity'] = cosine_similarity(q_embedding, item)
-
-            _list.sort(key=lambda x: x['cosine_similarity'], reverse=True)
-            # print(len(_list))
-            SQLBotLogUtil.info(json.dumps(
-                [{"id": ele.get("id"), "name": ele.get("ds").name, "cosine_similarity": ele.get("cosine_similarity")}
-                 for ele in _list]))
-            ds = _list[0].get('ds')
-            return {"id": ds.id, "name": ds.name, "description": ds.description}
-        except Exception:
-            traceback.print_exc()
+                _list.sort(key=lambda x: x['cosine_similarity'], reverse=True)
+                # print(len(_list))
+                end_time = time.time()
+                SQLBotLogUtil.info(str(end_time - start_time))
+                SQLBotLogUtil.info(json.dumps(
+                    [{"id": ele.get("id"), "name": ele.get("ds").name,
+                      "cosine_similarity": ele.get("cosine_similarity")}
+                     for ele in _list]))
+                ds = _list[0].get('ds')
+                return {"id": ds.id, "name": ds.name, "description": ds.description}
+            except Exception:
+                traceback.print_exc()
     return _list
diff --git a/backend/apps/datasource/embedding/table_embedding.py b/backend/apps/datasource/embedding/table_embedding.py
index 0724d7dd2..c467ecd8d 100644
--- a/backend/apps/datasource/embedding/table_embedding.py
+++ b/backend/apps/datasource/embedding/table_embedding.py
@@ -52,7 +52,7 @@ def calc_table_embedding(tables: list[dict], question: str):
             # text = [s.get('schema_table') for s in _list]
 
             # model = EmbeddingModelCache.get_model()
-            # start_time = time.time()
+            start_time = time.time()
             # results = model.embed_documents(text)
             # end_time = time.time()
             # SQLBotLogUtil.info(str(end_time - start_time))
@@ -67,7 +67,11 @@ def calc_table_embedding(tables: list[dict], question: str):
             _list.sort(key=lambda x: x['cosine_similarity'], reverse=True)
             _list = _list[:settings.TABLE_EMBEDDING_COUNT]
             # print(len(_list))
-            SQLBotLogUtil.info(json.dumps(_list))
+            end_time = time.time()
+            SQLBotLogUtil.info(str(end_time - start_time))
+            SQLBotLogUtil.info(json.dumps([{"id": ele.get('id'), "schema_table": ele.get('schema_table'),
+                                            "cosine_similarity": ele.get('cosine_similarity')}
+                                           for ele in _list]))
             return _list
         except Exception:
             traceback.print_exc()
diff --git a/backend/apps/datasource/models/datasource.py b/backend/apps/datasource/models/datasource.py
index 9584ed3f2..a86eb7ca8 100644
--- a/backend/apps/datasource/models/datasource.py
+++ b/backend/apps/datasource/models/datasource.py
@@ -21,6 +21,7 @@ class CoreDatasource(SQLModel, table=True):
     num: str = Field(max_length=256, nullable=True)
     oid: int = Field(sa_column=Column(BigInteger()))
     table_relation: List = Field(sa_column=Column(JSONB, nullable=True))
+    embedding: str = Field(sa_column=Column(Text, nullable=True))
 
 
 class CoreTable(SQLModel, table=True):
diff --git a/backend/common/utils/embedding_threads.py b/backend/common/utils/embedding_threads.py
index ab0632096..d4a2fe6b5 100644
--- a/backend/common/utils/embedding_threads.py
+++ b/backend/common/utils/embedding_threads.py
@@ -38,6 +38,11 @@ def run_save_table_embeddings(ids: List[int]):
     executor.submit(save_table_embedding, session_maker, ids)
 
 
-def fill_empty_table_embeddings():
-    from apps.datasource.crud.table import run_fill_empty_table_embedding
-    executor.submit(run_fill_empty_table_embedding, session_maker)
+def run_save_ds_embeddings(ids: List[int]):
+    from apps.datasource.crud.table import save_ds_embedding
+    executor.submit(save_ds_embedding, session_maker, ids)
+
+
+def fill_empty_table_and_ds_embeddings():
+    from apps.datasource.crud.table import run_fill_empty_table_and_ds_embedding
+    executor.submit(run_fill_empty_table_and_ds_embedding, session_maker)
diff --git a/backend/main.py b/backend/main.py
index c63711f31..48e124f79 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -12,7 +12,7 @@
 from alembic import command
 
 from apps.api import api_router
-from common.utils.embedding_threads import fill_empty_table_embeddings
+from common.utils.embedding_threads import fill_empty_table_and_ds_embeddings
 from apps.system.crud.aimodel_manage import async_model_info
 from apps.system.crud.assistant import init_dynamic_cors
 from apps.system.middleware.auth import TokenMiddleware
@@ -36,8 +36,8 @@ def init_data_training_embedding_data():
     fill_empty_data_training_embeddings()
 
 
-def init_table_embedding():
-    fill_empty_table_embeddings()
+def init_table_and_ds_embedding():
+    fill_empty_table_and_ds_embeddings()
 
 
 @asynccontextmanager
@@ -47,7 +47,7 @@ async def lifespan(app: FastAPI):
     init_dynamic_cors(app)
     init_terminology_embedding_data()
     init_data_training_embedding_data()
-    init_table_embedding()
+    init_table_and_ds_embedding()
     SQLBotLogUtil.info("✅ SQLBot 初始化完成")
     await sqlbot_xpack.core.clean_xpack_cache()
     await async_model_info()  # 异步加密已有模型的密钥和地址