1616from langchain .chat_models .base import BaseChatModel
1717from langchain_community .utilities import SQLDatabase
1818from langchain_core .messages import BaseMessage , SystemMessage , HumanMessage , AIMessage , BaseMessageChunk
19- from sqlalchemy import select
19+ from sqlalchemy import and_ , select
2020from sqlalchemy .orm import sessionmaker
2121from sqlmodel import Session
2222
@@ -404,9 +404,8 @@ def select_datasource(self):
404404 if self .current_assistant and self .current_assistant .type != 4 :
405405 _ds_list = get_assistant_ds (session = self .session , llm_service = self )
406406 else :
407- oid : str = self .current_user .oid
408407 stmt = select (CoreDatasource .id , CoreDatasource .name , CoreDatasource .description ).where (
409- CoreDatasource .oid == oid )
408+ and_ ( CoreDatasource .oid == self . current_user . oid ) )
410409 _ds_list = [
411410 {
412411 "id" : ds .id ,
@@ -424,7 +423,7 @@ def select_datasource(self):
424423
425424 full_thinking_text = ''
426425 full_text = ''
427-
426+ json_str : Optional [ str ] = None
428427 if not ignore_auto_select :
429428 _ds_list_dict = []
430429 for _ds in _ds_list :
@@ -471,6 +470,8 @@ def select_datasource(self):
471470 token_usage = token_usage )
472471
473472 json_str = extract_nested_json (full_text )
473+ if json_str is None :
474+ raise SingleMessageError (f'Cannot parse datasource from answer: { full_text } ' )
474475
475476 _error : Exception | None = None
476477 _datasource : int | None = None
@@ -522,11 +523,11 @@ def select_datasource(self):
522523 engine_type = _engine_type )
523524 if self .ds :
524525 oid = self .ds .oid if isinstance (self .ds , CoreDatasource ) else 1
525- dsId = self .ds .id if isinstance (self .ds , CoreDatasource ) else None
526-
527- self .chat_question .terminologies = get_terminology_template (self .session , self .chat_question .question , 1 )
528- self .chat_question .data_training = get_training_template (self .session , self .chat_question .question , dsId , oid )
529-
526+ ds_id = self .ds .id if isinstance (self .ds , CoreDatasource ) else None
527+
528+ self .chat_question .terminologies = get_terminology_template (self .session , self .chat_question .question , oid )
529+ self .chat_question .data_training = get_training_template (self .session , self .chat_question .question , ds_id ,
530+ oid )
530531
531532 self .init_messages ()
532533
@@ -938,10 +939,11 @@ def run_task(self, in_chat: bool = True):
938939 try :
939940 if self .ds :
940941 oid = self .ds .oid if isinstance (self .ds , CoreDatasource ) else 1
941- dsId = self .ds .id if isinstance (self .ds , CoreDatasource ) else None
942- self .chat_question .terminologies = get_terminology_template (self .session , self .chat_question .question , oid )
943- self .chat_question .data_training = get_training_template (self .session , self .chat_question .question , dsId , oid )
944-
942+ ds_id = self .ds .id if isinstance (self .ds , CoreDatasource ) else None
943+ self .chat_question .terminologies = get_terminology_template (self .session , self .chat_question .question ,
944+ oid )
945+ self .chat_question .data_training = get_training_template (self .session , self .chat_question .question ,
946+ ds_id , oid )
945947
946948 self .init_messages ()
947949
0 commit comments