Skip to content

Commit 3d3685e

Browse files
committed
fix: always get terminology in default workspace after select ds
1 parent a99573c commit 3d3685e

File tree

1 file changed

+15
-13
lines changed

1 file changed

+15
-13
lines changed

backend/apps/chat/task/llm.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from langchain.chat_models.base import BaseChatModel
1717
from langchain_community.utilities import SQLDatabase
1818
from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, AIMessage, BaseMessageChunk
19-
from sqlalchemy import select
19+
from sqlalchemy import and_, select
2020
from sqlalchemy.orm import sessionmaker
2121
from sqlmodel import Session
2222

@@ -404,9 +404,8 @@ def select_datasource(self):
404404
if self.current_assistant and self.current_assistant.type != 4:
405405
_ds_list = get_assistant_ds(session=self.session, llm_service=self)
406406
else:
407-
oid: str = self.current_user.oid
408407
stmt = select(CoreDatasource.id, CoreDatasource.name, CoreDatasource.description).where(
409-
CoreDatasource.oid == oid)
408+
and_(CoreDatasource.oid == self.current_user.oid))
410409
_ds_list = [
411410
{
412411
"id": ds.id,
@@ -424,7 +423,7 @@ def select_datasource(self):
424423

425424
full_thinking_text = ''
426425
full_text = ''
427-
426+
json_str: Optional[str] = None
428427
if not ignore_auto_select:
429428
_ds_list_dict = []
430429
for _ds in _ds_list:
@@ -471,6 +470,8 @@ def select_datasource(self):
471470
token_usage=token_usage)
472471

473472
json_str = extract_nested_json(full_text)
473+
if json_str is None:
474+
raise SingleMessageError(f'Cannot parse datasource from answer: {full_text}')
474475

475476
_error: Exception | None = None
476477
_datasource: int | None = None
@@ -522,11 +523,11 @@ def select_datasource(self):
522523
engine_type=_engine_type)
523524
if self.ds:
524525
oid = self.ds.oid if isinstance(self.ds, CoreDatasource) else 1
525-
dsId = self.ds.id if isinstance(self.ds, CoreDatasource) else None
526-
527-
self.chat_question.terminologies = get_terminology_template(self.session, self.chat_question.question, 1)
528-
self.chat_question.data_training = get_training_template(self.session, self.chat_question.question, dsId, oid)
529-
526+
ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None
527+
528+
self.chat_question.terminologies = get_terminology_template(self.session, self.chat_question.question, oid)
529+
self.chat_question.data_training = get_training_template(self.session, self.chat_question.question, ds_id,
530+
oid)
530531

531532
self.init_messages()
532533

@@ -938,10 +939,11 @@ def run_task(self, in_chat: bool = True):
938939
try:
939940
if self.ds:
940941
oid = self.ds.oid if isinstance(self.ds, CoreDatasource) else 1
941-
dsId = self.ds.id if isinstance(self.ds, CoreDatasource) else None
942-
self.chat_question.terminologies = get_terminology_template(self.session, self.chat_question.question, oid)
943-
self.chat_question.data_training = get_training_template(self.session, self.chat_question.question, dsId, oid)
944-
942+
ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None
943+
self.chat_question.terminologies = get_terminology_template(self.session, self.chat_question.question,
944+
oid)
945+
self.chat_question.data_training = get_training_template(self.session, self.chat_question.question,
946+
ds_id, oid)
945947

946948
self.init_messages()
947949

0 commit comments

Comments
 (0)