# Patch summary (leading commentary; everything from the first "diff --git"
# header onward is the original patch text, unchanged).
#
# Purpose: let a terminology entry be tied to specific datasources.
#
# backend/alembic/env.py
#   - Enables the SQLModel import from apps.terminology.models.terminology_model
#     and comments out the one from apps.datasource.models.datasource.
#     `target_metadata = SQLModel.metadata` itself is unchanged.
#   - NOTE(review): presumably the set of imported model modules decides which
#     tables are registered on the metadata used for autogenerate -- confirm
#     the datasource tables are still registered after this swap.
#
# backend/alembic/versions/045_modify_terminolog.py (new file)
#   - Revision 45e7e52bf2b8, revises 455b8ce69e80.
#   - upgrade(): adds two nullable columns to `terminology`:
#     `specific_ds` (Boolean) and `datasource_ids` (PostgreSQL JSONB).
#   - downgrade(): drops the same two columns, in reverse order.
#   - Neither column gets a server default, so rows that already exist hold
#     NULL in both; page_terminology maps None to False / [] when it builds
#     the result objects.
#   - NOTE(review): `sqlmodel.sql.sqltypes` is imported but not referenced
#     anywhere in this migration file.
#
# backend/apps/terminology/curd/terminology.py
#   - page_terminology: both query branches gain a subquery that expands
#     Terminology.datasource_ids with jsonb_array_elements(...) cast to
#     BigInteger, outer-joins it to CoreDatasource on id, and aggregates
#     CoreDatasource.name into `datasource_names` with jsonb_agg. The selected
#     columns now also include specific_ds and datasource_ids.
#   - The two Chinese comments in the payload translate to "create a subquery
#     to fetch the datasource names (with a type cast)" and "join the
#     datasource-name subquery and the CoreDatasource table"; the third,
#     shorter one means "main query".
#   - NOTE(review): in the `else` branch the statement outer-joins both
#     `child` (one row per synonym) and the datasource subquery (one row per
#     datasource id) to Terminology before a single GROUP BY, so every
#     synonym row is paired with every datasource row. jsonb_agg(child.word)
#     and jsonb_agg(CoreDatasource.name) may therefore contain repeated
#     values when a term has several of each -- verify, and consider
#     aggregating one side in its own subquery, as the first branch already
#     does for `other_words`.
#   - NOTE(review): the `else` branch groups by Terminology.id and
#     Terminology.word only while selecting further Terminology columns;
#     this presumably relies on Terminology.id being the primary key --
#     confirm.
#   - create_terminology / update_terminology: default specific_ds to False
#     and datasource_ids to [] when the incoming value is None, then write
#     both values to the parent row and to every synonym (child) row.
#
# backend/apps/terminology/models/terminology_model.py
#   - Terminology gains `specific_ds` (Boolean column, default False) and
#     `datasource_ids` (JSONB column, default []).
#   - TerminologyInfo gains `specific_ds`, `datasource_ids` and
#     `datasource_names`, with defaults False, [] and [].
#
diff --git a/backend/alembic/env.py b/backend/alembic/env.py index 03f66af98..79dfbdc68 100755 --- a/backend/alembic/env.py +++ b/backend/alembic/env.py @@ -25,11 +25,11 @@ # from apps.system.models.user import SQLModel # noqa # from apps.settings.models.setting_models import SQLModel # from apps.chat.models.chat_model import SQLModel -# from apps.terminology.models.terminology_model import SQLModel +from apps.terminology.models.terminology_model import SQLModel # from apps.data_training.models.data_training_model import SQLModel # from apps.dashboard.models.dashboard_model import SQLModel from common.core.config import settings # noqa -from apps.datasource.models.datasource import SQLModel +#from apps.datasource.models.datasource import SQLModel target_metadata = SQLModel.metadata diff --git a/backend/alembic/versions/045_modify_terminolog.py b/backend/alembic/versions/045_modify_terminolog.py new file mode 100644 index 000000000..a452bb6c6 --- /dev/null +++ b/backend/alembic/versions/045_modify_terminolog.py @@ -0,0 +1,31 @@ +"""045_modify_terminolog + +Revision ID: 45e7e52bf2b8 +Revises: 455b8ce69e80 +Create Date: 2025-09-25 14:49:24.521795 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '45e7e52bf2b8' +down_revision = '455b8ce69e80' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('terminology', sa.Column('specific_ds', sa.Boolean(), nullable=True)) + op.add_column('terminology', sa.Column('datasource_ids', postgresql.JSONB(astext_type=sa.Text()), nullable=True)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column('terminology', 'datasource_ids') + op.drop_column('terminology', 'specific_ds') + # ### end Alembic commands ### diff --git a/backend/apps/terminology/curd/terminology.py b/backend/apps/terminology/curd/terminology.py index 7d0ef1c8d..7b8d0a044 100644 --- a/backend/apps/terminology/curd/terminology.py +++ b/backend/apps/terminology/curd/terminology.py @@ -5,12 +5,14 @@ from xml.dom.minidom import parseString import dicttoxml +from sqlalchemy import BigInteger from sqlalchemy import and_, or_, select, func, delete, update, union from sqlalchemy import text from sqlalchemy.orm import aliased from sqlalchemy.orm.session import Session from apps.ai_model.embedding import EmbeddingModelCache +from apps.datasource.models.datasource import CoreDatasource from apps.template.generate_chart.generator import get_base_terminology_template from apps.terminology.models.terminology_model import Terminology, TerminologyInfo from common.core.config import settings @@ -80,6 +82,16 @@ def page_terminology(session: SessionDep, current_page: int = 1, page_size: int .subquery() ) + # 创建子查询来获取数据源名称,添加类型转换 + datasource_names_subquery = ( + select( + func.jsonb_array_elements(Terminology.datasource_ids).cast(BigInteger).label('ds_id'), + Terminology.id.label('term_id') + ) + .where(Terminology.id.in_(paginated_parent_ids)) + .subquery() + ) + # 主查询 stmt = ( select( @@ -87,13 +99,34 @@ def page_terminology(session: SessionDep, current_page: int = 1, page_size: int Terminology.word, Terminology.create_time, Terminology.description, - children_subquery.c.other_words + Terminology.specific_ds, + Terminology.datasource_ids, + children_subquery.c.other_words, + func.jsonb_agg(CoreDatasource.name).filter(CoreDatasource.id.isnot(None)).label('datasource_names') ) .outerjoin( children_subquery, Terminology.id == children_subquery.c.pid ) + # 关联数据源名称子查询和 CoreDatasource 表 + .outerjoin( + datasource_names_subquery, + datasource_names_subquery.c.term_id == Terminology.id + ) + 
.outerjoin( + CoreDatasource, + CoreDatasource.id == datasource_names_subquery.c.ds_id + ) .where(and_(Terminology.id.in_(paginated_parent_ids), Terminology.oid == oid)) + .group_by( + Terminology.id, + Terminology.word, + Terminology.create_time, + Terminology.description, + Terminology.specific_ds, + Terminology.datasource_ids, + children_subquery.c.other_words + ) .order_by(Terminology.create_time.desc()) ) else: @@ -116,15 +149,37 @@ def page_terminology(session: SessionDep, current_page: int = 1, page_size: int .subquery() ) + # 创建子查询来获取数据源名称 + datasource_names_subquery = ( + select( + func.jsonb_array_elements(Terminology.datasource_ids).cast(BigInteger).label('ds_id'), + Terminology.id.label('term_id') + ) + .where(Terminology.id.in_(paginated_parent_ids)) + .subquery() + ) + stmt = ( select( Terminology.id, Terminology.word, Terminology.create_time, Terminology.description, - func.jsonb_agg(child.word).filter(child.word.isnot(None)).label('other_words') + Terminology.specific_ds, + Terminology.datasource_ids, + func.jsonb_agg(child.word).filter(child.word.isnot(None)).label('other_words'), + func.jsonb_agg(CoreDatasource.name).filter(CoreDatasource.id.isnot(None)).label('datasource_names') ) .outerjoin(child, and_(Terminology.id == child.pid)) + # 关联数据源名称子查询和 CoreDatasource 表 + .outerjoin( + datasource_names_subquery, + datasource_names_subquery.c.term_id == Terminology.id + ) + .outerjoin( + CoreDatasource, + CoreDatasource.id == datasource_names_subquery.c.ds_id + ) .where(and_(Terminology.id.in_(paginated_parent_ids), Terminology.oid == oid)) .group_by(Terminology.id, Terminology.word) .order_by(Terminology.create_time.desc()) @@ -139,6 +194,9 @@ def page_terminology(session: SessionDep, current_page: int = 1, page_size: int create_time=row.create_time, description=row.description, other_words=row.other_words if row.other_words else [], + specific_ds=row.specific_ds if row.specific_ds is not None else False, + datasource_ids=row.datasource_ids if 
row.datasource_ids is not None else [], + datasource_names=row.datasource_names if row.datasource_names is not None else [], )) return current_page, page_size, total_count, total_pages, _list @@ -146,7 +204,13 @@ def page_terminology(session: SessionDep, current_page: int = 1, page_size: int def create_terminology(session: SessionDep, info: TerminologyInfo, oid: int, trans: Trans): create_time = datetime.datetime.now() - parent = Terminology(word=info.word, create_time=create_time, description=info.description, oid=oid) + + specific_ds = info.specific_ds if info.specific_ds is not None else False + datasource_ids = info.datasource_ids if info.datasource_ids is not None else [] + + parent = Terminology(word=info.word, create_time=create_time, description=info.description, oid=oid, + specific_ds=specific_ds, + datasource_ids=datasource_ids) words = [info.word] for child in info.other_words: @@ -175,7 +239,8 @@ def create_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra if other_word.strip() == "": continue _list.append( - Terminology(pid=result.id, word=other_word, create_time=create_time, oid=oid)) + Terminology(pid=result.id, word=other_word, create_time=create_time, oid=oid, + specific_ds=specific_ds, datasource_ids=datasource_ids)) session.bulk_save_objects(_list) session.flush() session.commit() @@ -214,9 +279,14 @@ def update_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra if exists: raise Exception(trans("i18n_terminology.exists_in_db")) + specific_ds = info.specific_ds if info.specific_ds is not None else False + datasource_ids = info.datasource_ids if info.datasource_ids is not None else [] + stmt = update(Terminology).where(and_(Terminology.id == info.id)).values( word=info.word, description=info.description, + specific_ds=specific_ds, + datasource_ids=datasource_ids ) session.execute(stmt) session.commit() @@ -232,7 +302,8 @@ def update_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra if 
other_word.strip() == "": continue _list.append( - Terminology(pid=info.id, word=other_word, create_time=create_time, oid=oid)) + Terminology(pid=info.id, word=other_word, create_time=create_time, oid=oid, + specific_ds=specific_ds, datasource_ids=datasource_ids)) session.bulk_save_objects(_list) session.flush() session.commit() diff --git a/backend/apps/terminology/models/terminology_model.py b/backend/apps/terminology/models/terminology_model.py index 57c35ce4a..b90486593 100644 --- a/backend/apps/terminology/models/terminology_model.py +++ b/backend/apps/terminology/models/terminology_model.py @@ -3,7 +3,8 @@ from pgvector.sqlalchemy import VECTOR from pydantic import BaseModel -from sqlalchemy import Column, Text, BigInteger, DateTime, Identity +from sqlalchemy import Column, Text, BigInteger, DateTime, Identity, Boolean +from sqlalchemy.dialects.postgresql import JSONB from sqlmodel import SQLModel, Field @@ -16,6 +17,8 @@ class Terminology(SQLModel, table=True): word: Optional[str] = Field(max_length=255) description: Optional[str] = Field(sa_column=Column(Text, nullable=True)) embedding: Optional[List[float]] = Field(sa_column=Column(VECTOR(), nullable=True)) + specific_ds: Optional[bool] = Field(sa_column=Column(Boolean, default=False)) + datasource_ids: Optional[list[int]] = Field(sa_column=Column(JSONB), default=[]) class TerminologyInfo(BaseModel): @@ -24,5 +27,6 @@ class TerminologyInfo(BaseModel): word: Optional[str] = None description: Optional[str] = None other_words: Optional[List[str]] = [] - - + specific_ds: Optional[bool] = False + datasource_ids: Optional[list[int]] = [] + datasource_names: Optional[list[str]] = []