Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 27 additions & 7 deletions backend/apps/terminology/curd/terminology.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,10 +390,23 @@ def save_embeddings(session: Session, ids: List[int]):
FROM terminology AS child
) TEMP
WHERE similarity > {settings.EMBEDDING_TERMINOLOGY_SIMILARITY} AND oid = :oid
AND (specific_ds = false OR specific_ds IS NULL)
ORDER BY similarity DESC
LIMIT {settings.EMBEDDING_TERMINOLOGY_TOP_COUNT}
"""

embedding_sql_with_datasource = f"""
SELECT id, pid, word, similarity
FROM
(SELECT id, pid, word, oid, specific_ds, datasource_ids,
( 1 - (embedding <=> :embedding_array) ) AS similarity
FROM terminology AS child
) TEMP
WHERE similarity > {settings.EMBEDDING_TERMINOLOGY_SIMILARITY} AND oid = :oid
AND (
(:datasource IS NULL AND (specific_ds = false OR specific_ds IS NULL))
OR
(:datasource IS NOT NULL AND ((specific_ds = false OR specific_ds IS NULL) OR (specific_ds = true AND datasource_ids IS NOT NULL AND datasource_ids @> jsonb_build_array(:datasource))))
(specific_ds = false OR specific_ds IS NULL)
OR
(specific_ds = true AND datasource_ids IS NOT NULL AND datasource_ids @> jsonb_build_array(:datasource))
)
ORDER BY similarity DESC
LIMIT {settings.EMBEDDING_TERMINOLOGY_TOP_COUNT}
Expand Down Expand Up @@ -447,14 +460,21 @@ def select_terminology_by_word(session: SessionDep, word: str, oid: int, datasou

embedding = model.embed_query(word)

results = session.execute(text(embedding_sql), {'embedding_array': str(embedding), 'oid': oid,
'datasource': datasource}).fetchall()
with session.begin():
if datasource is not None:
results = session.execute(text(embedding_sql_with_datasource),
{'embedding_array': str(embedding), 'oid': oid,
'datasource': datasource}).fetchall()
else:
results = session.execute(text(embedding_sql),
{'embedding_array': str(embedding), 'oid': oid}).fetchall()

for row in results:
_list.append(Terminology(id=row.id, word=row.word, pid=row.pid))
for row in results:
_list.append(Terminology(id=row.id, word=row.word, pid=row.pid))

except Exception:
traceback.print_exc()
session.rollback()

_map: dict = {}
_ids: list[int] = []
Expand Down