From 9897b938f8aa9e87e6ebcd4b99a082e8c53baa00 Mon Sep 17 00:00:00 2001 From: junjun Date: Fri, 10 Oct 2025 09:59:42 +0800 Subject: [PATCH] refactor: Service initiation vectorization task --- .../apps/data_training/curd/data_training.py | 25 ++++++++----- backend/apps/datasource/crud/datasource.py | 2 +- backend/apps/datasource/crud/table.py | 31 +++++----------- backend/apps/terminology/curd/terminology.py | 35 ++++++++++++------- backend/common/utils/embedding_threads.py | 32 +++++++++++------ backend/main.py | 2 +- 6 files changed, 70 insertions(+), 57 deletions(-) diff --git a/backend/apps/data_training/curd/data_training.py b/backend/apps/data_training/curd/data_training.py index 15260a9aa..b9530437a 100644 --- a/backend/apps/data_training/curd/data_training.py +++ b/backend/apps/data_training/curd/data_training.py @@ -7,7 +7,6 @@ import dicttoxml from sqlalchemy import and_, select, func, delete, update, or_ from sqlalchemy import text -from sqlalchemy.orm.session import Session from apps.ai_model.embedding import EmbeddingModelCache from apps.data_training.models.data_training_model import DataTrainingInfo, DataTraining @@ -160,24 +159,30 @@ def delete_training(session: SessionDep, ids: list[int]): # executor.submit(run_fill_empty_embeddings) -def run_fill_empty_embeddings(session: Session): - if not settings.EMBEDDING_ENABLED: - return +def run_fill_empty_embeddings(session_maker): + try: + if not settings.EMBEDDING_ENABLED: + return - stmt = select(DataTraining.id).where(and_(DataTraining.embedding.is_(None))) - results = session.execute(stmt).scalars().all() + session = session_maker() + stmt = select(DataTraining.id).where(and_(DataTraining.embedding.is_(None))) + results = session.execute(stmt).scalars().all() - save_embeddings(session, results) + save_embeddings(session_maker, results) + except Exception: + traceback.print_exc() + finally: + session_maker.remove() -def save_embeddings(session: Session, ids: List[int]): +def save_embeddings(session_maker, ids: List[int]): if not settings.EMBEDDING_ENABLED: return if not ids or len(ids) == 0: return try: - + session = session_maker() _list = session.query(DataTraining).filter(and_(DataTraining.id.in_(ids))).all() _question_list = [item.question for item in _list] @@ -194,6 +199,8 @@ def save_embeddings(session: Session, ids: List[int]): except Exception: traceback.print_exc() + finally: + session_maker.remove() embedding_sql = f""" diff --git a/backend/apps/datasource/crud/datasource.py b/backend/apps/datasource/crud/datasource.py index 696513b17..f9339bdc4 100644 --- a/backend/apps/datasource/crud/datasource.py +++ b/backend/apps/datasource/crud/datasource.py @@ -15,7 +15,7 @@ from apps.db.engine import get_engine_config, get_engine_conn from common.core.config import settings from common.core.deps import SessionDep, CurrentUser, Trans -from apps.datasource.crud.table import run_save_table_embeddings +from common.utils.embedding_threads import run_save_table_embeddings from common.utils.utils import deepcopy_ignore_extra from .table import get_tables_by_ds_id from ..crud.field import delete_field_by_ds_id, update_field diff --git a/backend/apps/datasource/crud/table.py b/backend/apps/datasource/crud/table.py index e4325a95c..535d343b8 100644 --- a/backend/apps/datasource/crud/table.py +++ b/backend/apps/datasource/crud/table.py @@ -1,12 +1,9 @@ import json import time import traceback -from concurrent.futures import ThreadPoolExecutor from typing import List from sqlalchemy import and_, select, update -from sqlalchemy.orm import sessionmaker -from sqlalchemy.orm.session import Session from apps.ai_model.embedding import EmbeddingModelCache from common.core.config import settings @@ -14,13 +11,6 @@ from common.utils.utils import SQLBotLogUtil from ..models.datasource import CoreTable, CoreField -executor = ThreadPoolExecutor(max_workers=200) - -from common.core.db import engine - -session_maker = sessionmaker(bind=engine) -session = session_maker() - def delete_table_by_ds_id(session: SessionDep, id: int): session.query(CoreTable).filter(CoreTable.ds_id == id).delete(synchronize_session=False) @@ -40,12 +30,13 @@ def update_table(session: SessionDep, item: CoreTable): session.commit() -def run_fill_empty_table_embedding(session: Session): +def run_fill_empty_table_embedding(session_maker): try: if not settings.TABLE_EMBEDDING_ENABLED: return SQLBotLogUtil.info('get tables') + session = session_maker() stmt = select(CoreTable.id).where(and_(CoreTable.embedding.is_(None))) results = session.execute(stmt).scalars().all() SQLBotLogUtil.info('result: ' + str(len(results))) @@ -53,9 +44,11 @@ def run_fill_empty_table_embedding(session: Session): save_table_embedding(session, results) except Exception: traceback.print_exc() + finally: + session_maker.remove() -def save_table_embedding(session: Session, ids: List[int]): +def save_table_embedding(session_maker, ids: List[int]): if not settings.TABLE_EMBEDDING_ENABLED: return @@ -65,6 +58,7 @@ def save_table_embedding(session: Session, ids: List[int]): SQLBotLogUtil.info('start table embedding') start_time = time.time() model = EmbeddingModelCache.get_model() + session = session_maker() for _id in ids: table = session.query(CoreTable).filter(CoreTable.id == _id).first() fields = session.query(CoreField).filter(CoreField.table_id == table.id).all() @@ -102,14 +96,5 @@ def save_table_embedding(session: Session, ids: List[int]): SQLBotLogUtil.info('table embedding finished in: ' + str(end_time - start_time) + ' seconds') except Exception: traceback.print_exc() - - -def run_save_table_embeddings(ids: List[int]): - executor.submit(save_table_embedding, session, ids) - - -def fill_empty_table_embeddings(): - try: - executor.submit(run_fill_empty_table_embedding, session) - except Exception: - traceback.print_exc() + finally: + session_maker.remove() diff --git a/backend/apps/terminology/curd/terminology.py b/backend/apps/terminology/curd/terminology.py index f5a382eba..1d6ea4cae 100644 --- a/backend/apps/terminology/curd/terminology.py +++ b/backend/apps/terminology/curd/terminology.py @@ -7,7 +7,6 @@ import dicttoxml from sqlalchemy import and_, or_, select, func, delete, update, union, text, BigInteger from sqlalchemy.orm import aliased -from sqlalchemy.orm.session import Session from apps.ai_model.embedding import EmbeddingModelCache from apps.datasource.models.datasource import CoreDatasource @@ -407,26 +406,36 @@ def delete_terminology(session: SessionDep, ids: list[int]): # # def fill_empty_embeddings(): # executor.submit(run_fill_empty_embeddings) +# from sqlalchemy import create_engine +# from sqlalchemy.orm import sessionmaker,scoped_session +# engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI)) +# session_maker = scoped_session(sessionmaker(bind=engine)) - -def run_fill_empty_embeddings(session: Session): - if not settings.EMBEDDING_ENABLED: - return - stmt1 = select(Terminology.id).where(and_(Terminology.embedding.is_(None), Terminology.pid.is_(None))) - stmt2 = select(Terminology.pid).where(and_(Terminology.embedding.is_(None), Terminology.pid.isnot(None))).distinct() - combined_stmt = union(stmt1, stmt2) - results = session.execute(combined_stmt).scalars().all() - save_embeddings(session, results) +def run_fill_empty_embeddings(session_maker): + try: + if not settings.EMBEDDING_ENABLED: + return + session = session_maker() + stmt1 = select(Terminology.id).where(and_(Terminology.embedding.is_(None), Terminology.pid.is_(None))) + stmt2 = select(Terminology.pid).where( + and_(Terminology.embedding.is_(None), Terminology.pid.isnot(None))).distinct() + combined_stmt = union(stmt1, stmt2) + results = session.execute(combined_stmt).scalars().all() + save_embeddings(session_maker, results) + except Exception: + traceback.print_exc() + finally: + session_maker.remove() -def save_embeddings(session: Session, ids: List[int]): +def save_embeddings(session_maker, ids: List[int]): if not settings.EMBEDDING_ENABLED: return if not ids or len(ids) == 0: return try: - + session = session_maker() _list = session.query(Terminology).filter(or_(Terminology.id.in_(ids), Terminology.pid.in_(ids))).all() _words_list = [item.word for item in _list] @@ -443,6 +452,8 @@ def save_embeddings(session: Session, ids: List[int]): except Exception: traceback.print_exc() + finally: + session_maker.remove() embedding_sql = f""" diff --git a/backend/common/utils/embedding_threads.py b/backend/common/utils/embedding_threads.py index a38b66f0d..ab0632096 100644 --- a/backend/common/utils/embedding_threads.py +++ b/backend/common/utils/embedding_threads.py @@ -1,33 +1,43 @@ from concurrent.futures import ThreadPoolExecutor from typing import List -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker - -from common.core.config import settings +from sqlalchemy.orm import sessionmaker, scoped_session executor = ThreadPoolExecutor(max_workers=200) -engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI)) -session_maker = sessionmaker(bind=engine) -session = session_maker() +from common.core.db import engine + +session_maker = scoped_session(sessionmaker(bind=engine)) + + +# session = session_maker() def run_save_terminology_embeddings(ids: List[int]): from apps.terminology.curd.terminology import save_embeddings - executor.submit(save_embeddings, session, ids) + executor.submit(save_embeddings, session_maker, ids) def fill_empty_terminology_embeddings(): from apps.terminology.curd.terminology import run_fill_empty_embeddings - executor.submit(run_fill_empty_embeddings, session) + executor.submit(run_fill_empty_embeddings, session_maker) def run_save_data_training_embeddings(ids: List[int]): from apps.data_training.curd.data_training import save_embeddings - executor.submit(save_embeddings, session, ids) + executor.submit(save_embeddings, session_maker, ids) def fill_empty_data_training_embeddings(): from apps.data_training.curd.data_training import run_fill_empty_embeddings - executor.submit(run_fill_empty_embeddings, session) + executor.submit(run_fill_empty_embeddings, session_maker) + + +def run_save_table_embeddings(ids: List[int]): + from apps.datasource.crud.table import save_table_embedding + executor.submit(save_table_embedding, session_maker, ids) + + +def fill_empty_table_embeddings(): + from apps.datasource.crud.table import run_fill_empty_table_embedding + executor.submit(run_fill_empty_table_embedding, session_maker) diff --git a/backend/main.py b/backend/main.py index fb14b53f4..c63711f31 100644 --- a/backend/main.py +++ b/backend/main.py @@ -12,7 +12,7 @@ from alembic import command from apps.api import api_router -from apps.datasource.crud.table import fill_empty_table_embeddings +from common.utils.embedding_threads import fill_empty_table_embeddings from apps.system.crud.aimodel_manage import async_model_info from apps.system.crud.assistant import init_dynamic_cors from apps.system.middleware.auth import TokenMiddleware