2323
2424
2525class OpenAIProvider (LLMProvider ):
26- """OpenAI LLM provider using the Responses API."""
26+ """OpenAI LLM provider using the Responses API.
27+
28+ Supports both standard OpenAI and Azure OpenAI endpoints.
29+ When azure_endpoint is provided, uses AzureOpenAI/AsyncAzureOpenAI clients.
30+ """
2731
2832 provider_name = "openai"
2933
30- def __init__ (self , api_key : str = "" ) -> None :
34+ def __init__ (
35+ self ,
36+ api_key : str = "" ,
37+ * ,
38+ azure_endpoint : str = "" ,
39+ api_version : str = "" ,
40+ azure_deployment : str = "" ,
41+ ) -> None :
42+ self ._is_azure = bool (azure_endpoint )
43+ self ._azure_endpoint = azure_endpoint
44+ self ._api_version = api_version
45+ self ._azure_deployment = azure_deployment
46+
3147 if not api_key :
32- raise ValueError (
33- "OPENAI_API_KEY not found. Set it as an environment variable.\n "
34- " export OPENAI_API_KEY=sk-..."
35- )
48+ if self ._is_azure :
49+ raise ValueError (
50+ "AZURE_OPENAI_API_KEY not found. Set it as an environment variable.\n "
51+ " export AZURE_OPENAI_API_KEY=<your-subscription-key>"
52+ )
53+ else :
54+ raise ValueError (
55+ "OPENAI_API_KEY not found. Set it as an environment variable.\n "
56+ " export OPENAI_API_KEY=sk-..."
57+ )
3658 super ().__init__ (api_key )
3759
60+ if self ._is_azure :
61+ self .provider_name = "azure_openai"
62+
63+ def _resolve_model (self , model : str | None , default : str ) -> str :
64+ """Resolve model name, using Azure deployment as fallback when applicable."""
65+ if model :
66+ return model
67+ if self ._is_azure and self ._azure_deployment :
68+ return self ._azure_deployment
69+ return default
70+
3871 @staticmethod
3972 def _extract_output_text (response ) -> str | None :
4073 """Extract the output text content from an OpenAI Responses API response.
@@ -100,11 +133,28 @@ def default_research_model(self) -> str:
100133 return "gpt-5"
101134
102135 def _get_client (self ) -> OpenAI :
136+ if self ._is_azure :
137+ from openai import AzureOpenAI
138+
139+ return AzureOpenAI (
140+ api_key = self ._api_key ,
141+ azure_endpoint = self ._azure_endpoint ,
142+ api_version = self ._api_version ,
143+ )
103144 return OpenAI (api_key = self ._api_key )
104145
105146 def _get_async_client (self ) -> AsyncOpenAI :
106147 if self ._cached_async_client is None :
107- self ._cached_async_client = AsyncOpenAI (api_key = self ._api_key )
148+ if self ._is_azure :
149+ from openai import AsyncAzureOpenAI
150+
151+ self ._cached_async_client = AsyncAzureOpenAI (
152+ api_key = self ._api_key ,
153+ azure_endpoint = self ._azure_endpoint ,
154+ api_version = self ._api_version ,
155+ )
156+ else :
157+ self ._cached_async_client = AsyncOpenAI (api_key = self ._api_key )
108158 return self ._cached_async_client
109159
110160 def simple_call (
@@ -116,7 +166,7 @@ def simple_call(
116166 log : bool = True ,
117167 max_tokens : int | None = None ,
118168 ) -> dict :
119- model = model or self .default_simple_model
169+ model = self . _resolve_model ( model , self .default_simple_model )
120170 client = self ._get_client ()
121171
122172 # Acquire rate limit capacity before making the call
@@ -169,7 +219,7 @@ async def simple_call_async(
169219 model : str | None = None ,
170220 max_tokens : int | None = None ,
171221 ) -> dict :
172- model = model or self .default_simple_model
222+ model = self . _resolve_model ( model , self .default_simple_model )
173223 client = self ._get_async_client ()
174224
175225 request_params = {
@@ -211,7 +261,7 @@ def reasoning_call(
211261 max_retries : int = 2 ,
212262 on_retry : RetryCallback | None = None ,
213263 ) -> dict :
214- model = model or self .default_reasoning_model
264+ model = self . _resolve_model ( model , self .default_reasoning_model )
215265 client = self ._get_client ()
216266
217267 effective_prompt = prompt
@@ -272,7 +322,7 @@ def agentic_research(
272322 max_retries : int = 2 ,
273323 on_retry : RetryCallback | None = None ,
274324 ) -> tuple [dict , list [str ]]:
275- model = model or self .default_research_model
325+ model = self . _resolve_model ( model , self .default_research_model )
276326 client = self ._get_client ()
277327
278328 effective_prompt = prompt
0 commit comments