44import json
55import logging
66from collections .abc import AsyncGenerator
7+ from typing import Any , Callable , Dict
78
89# Third-Party
910from langchain .agents import AgentExecutor , create_openai_functions_agent
@@ -54,7 +55,7 @@ def create_llm(config: AgentConfig) -> BaseChatModel:
5455 provider = config .llm_provider .lower ()
5556
5657 # Common LLM arguments
57- common_args = {
58+ common_args : Dict [ str , Any ] = {
5859 "temperature" : config .temperature ,
5960 "streaming" : config .streaming_enabled ,
6061 }
@@ -64,68 +65,89 @@ def create_llm(config: AgentConfig) -> BaseChatModel:
6465 if config .top_p :
6566 common_args ["top_p" ] = config .top_p
6667
67- if provider == "openai" :
68- if not config .openai_api_key :
69- raise ValueError ("OPENAI_API_KEY is required for OpenAI provider" )
68+ # Provider factory functions
69+ providers : Dict [str , Callable [[AgentConfig , Dict [str , Any ]], BaseChatModel ]] = {
70+ "openai" : _create_openai_llm ,
71+ "azure" : _create_azure_llm ,
72+ "bedrock" : _create_bedrock_llm ,
73+ "ollama" : _create_ollama_llm ,
74+ "anthropic" : _create_anthropic_llm ,
75+ }
7076
71- openai_args = {"model" : config .default_model , "api_key" : config .openai_api_key , ** common_args }
77+ if provider not in providers :
78+ raise ValueError (f"Unsupported LLM provider: { provider } . " f"Supported providers: { ', ' .join (providers .keys ())} " )
7279
73- if config .openai_base_url :
74- openai_args ["base_url" ] = config .openai_base_url
75- if config .openai_organization :
76- openai_args ["organization" ] = config .openai_organization
80+ return providers [provider ](config , common_args )
7781
78- return ChatOpenAI (** openai_args )
7982
80- elif provider == "azure" :
81- if not all ([config .azure_openai_api_key , config .azure_openai_endpoint , config .azure_deployment_name ]):
82- raise ValueError (
83- "Azure OpenAI requires AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_DEPLOYMENT_NAME"
84- )
83+ def _create_openai_llm (config : AgentConfig , common_args : Dict [str , Any ]) -> BaseChatModel :
84+ """Create OpenAI LLM instance."""
8585
86- return AzureChatOpenAI (
87- api_key = config .azure_openai_api_key ,
88- azure_endpoint = config .azure_openai_endpoint ,
89- api_version = config .azure_openai_api_version ,
90- azure_deployment = config .azure_deployment_name ,
91- ** common_args ,
92- )
93-
94- elif provider == "bedrock" :
95- if BedrockChat is None :
96- raise ImportError ("langchain-aws is required for Bedrock support. Install with: pip install langchain-aws" )
97- if not all ([config .aws_access_key_id , config .aws_secret_access_key , config .bedrock_model_id ]):
98- raise ValueError ("AWS Bedrock requires AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and BEDROCK_MODEL_ID" )
99-
100- return BedrockChat (
101- model_id = config .bedrock_model_id ,
102- region_name = config .aws_region ,
103- credentials_profile_name = None , # Use environment variables
104- ** common_args ,
105- )
106-
107- elif provider == "ollama" :
108- if ChatOllama is None :
109- raise ImportError (
110- "langchain-community is required for OLLAMA support. Install with: pip install langchain-community"
111- )
112- if not config .ollama_model :
113- raise ValueError ("OLLAMA_MODEL is required for OLLAMA provider" )
86+ if not config .openai_api_key :
87+ raise ValueError ("OPENAI_API_KEY is required for OpenAI provider" )
11488
115- return ChatOllama ( model = config .ollama_model , base_url = config .ollama_base_url , ** common_args )
89+ openai_args = { " model" : config .default_model , "api_key" : config .openai_api_key , ** common_args }
11690
117- elif provider == "anthropic" :
118- if ChatAnthropic is None :
119- raise ImportError (
120- "langchain-anthropic is required for Anthropic support. Install with: pip install langchain-anthropic"
121- )
122- if not config .anthropic_api_key :
123- raise ValueError ("ANTHROPIC_API_KEY is required for Anthropic provider" )
91+ if config .openai_base_url :
92+ openai_args ["base_url" ] = config .openai_base_url
93+ if config .openai_organization :
94+ openai_args ["organization" ] = config .openai_organization
95+
96+ return ChatOpenAI (** openai_args )
97+
98+
99+ def _create_azure_llm (config : AgentConfig , common_args : Dict [str , Any ]) -> BaseChatModel :
100+ """Create Azure OpenAI LLM instance."""
101+
102+ required_fields = [config .azure_openai_api_key , config .azure_openai_endpoint , config .azure_deployment_name ]
103+
104+ if not all (required_fields ):
105+ raise ValueError ("Azure OpenAI requires AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_DEPLOYMENT_NAME" )
106+
107+ return AzureChatOpenAI (
108+ api_key = config .azure_openai_api_key , azure_endpoint = config .azure_openai_endpoint , api_version = config .azure_openai_api_version , azure_deployment = config .azure_deployment_name , ** common_args
109+ )
110+
111+
112+ def _create_bedrock_llm (config : AgentConfig , common_args : Dict [str , Any ]) -> BaseChatModel :
113+ """Create AWS Bedrock LLM instance."""
124114
125- return ChatAnthropic (model = config .default_model , api_key = config .anthropic_api_key , ** common_args )
115+ if BedrockChat is None :
116+ raise ImportError ("langchain-aws is required for Bedrock support. " "Install with: pip install langchain-aws" )
126117
127- else :
128- raise ValueError (f"Unsupported LLM provider: { provider } . Supported: openai, azure, bedrock, ollama, anthropic" )
118+ required_fields = [config .aws_access_key_id , config .aws_secret_access_key , config .bedrock_model_id ]
119+
120+ if not all (required_fields ):
121+ raise ValueError ("AWS Bedrock requires AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and BEDROCK_MODEL_ID" )
122+
123+ return BedrockChat (
124+ model_id = config .bedrock_model_id ,
125+ region_name = config .aws_region ,
126+ credentials_profile_name = None , # Use environment variables
127+ ** common_args ,
128+ )
129+
130+
131+ def _create_ollama_llm (config : AgentConfig , common_args : Dict [str , Any ]) -> BaseChatModel :
132+ """Create OLLAMA LLM instance."""
133+ if ChatOllama is None :
134+ raise ImportError ("langchain-community is required for OLLAMA support. " "Install with: pip install langchain-community" )
135+
136+ if not config .ollama_model :
137+ raise ValueError ("OLLAMA_MODEL is required for OLLAMA provider" )
138+
139+ return ChatOllama (model = config .ollama_model , base_url = config .ollama_base_url , ** common_args )
140+
141+
142+ def _create_anthropic_llm (config : AgentConfig , common_args : Dict [str , Any ]) -> BaseChatModel :
143+ """Create Anthropic LLM instance."""
144+ if ChatAnthropic is None :
145+ raise ImportError ("langchain-anthropic is required for Anthropic support. " "Install with: pip install langchain-anthropic" )
146+
147+ if not config .anthropic_api_key :
148+ raise ValueError ("ANTHROPIC_API_KEY is required for Anthropic provider" )
149+
150+ return ChatAnthropic (model = config .default_model , api_key = config .anthropic_api_key , ** common_args )
129151
130152
131153class MCPTool (BaseTool ):
@@ -309,12 +331,7 @@ def is_initialized(self) -> bool:
309331 async def check_readiness (self ) -> bool :
310332 """Check if agent is ready to handle requests"""
311333 try :
312- return (
313- self ._initialized
314- and self .agent_executor is not None
315- and len (self .tools ) >= 0 # Allow 0 tools for testing
316- and await self .test_gateway_connection ()
317- )
334+ return self ._initialized and self .agent_executor is not None and len (self .tools ) >= 0 and await self .test_gateway_connection () # Allow 0 tools for testing
318335 except Exception :
319336 return False
320337
@@ -366,9 +383,7 @@ async def run_async(
366383 chat_history .append (SystemMessage (content = msg ["content" ]))
367384
368385 # Run the agent
369- result = await self .agent_executor .ainvoke (
370- {"input" : input_text , "chat_history" : chat_history , "tool_names" : [tool .name for tool in self .tools ]}
371- )
386+ result = await self .agent_executor .ainvoke ({"input" : input_text , "chat_history" : chat_history , "tool_names" : [tool .name for tool in self .tools ]})
372387
373388 return result ["output" ]
374389
0 commit comments