Skip to content

Commit 7efff16

Browse files
Add Azure OpenAI provider support
Adds `azure_openai` as a provider option for both pipeline and simulation zones. Reuses `OpenAIProvider` by swapping in `AzureOpenAI`/`AsyncAzureOpenAI` clients at construction time — no new provider classes needed. Configure with: `entropy config set simulation.provider azure_openai`. Required environment variables: AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_DEPLOYMENT. Bumps version to 0.1.2.
1 parent 94a26eb commit 7efff16

7 files changed

Lines changed: 318 additions & 34 deletions

File tree

entropy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
"""Entropy: Simulate how populations respond to scenarios."""
22

3-
__version__ = "0.1.1"
3+
__version__ = "0.1.2"

entropy/cli/commands/config_cmd.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
reset_config,
99
CONFIG_FILE,
1010
get_api_key,
11+
get_azure_config,
1112
)
1213

1314

@@ -121,6 +122,20 @@ def _show_config():
121122
console.print("[bold cyan]API Keys[/bold cyan] (from env vars)")
122123
_show_key_status("openai", "OPENAI_API_KEY")
123124
_show_key_status("claude", "ANTHROPIC_API_KEY")
125+
_show_key_status("azure_openai", "AZURE_OPENAI_API_KEY")
126+
127+
# Azure-specific config (show when Azure provider is in use)
128+
active_providers = {config.pipeline.provider, config.simulation.provider}
129+
if "azure_openai" in active_providers:
130+
azure_cfg = get_azure_config("azure_openai")
131+
console.print()
132+
console.print("[bold cyan]Azure OpenAI[/bold cyan]")
133+
console.print(
134+
f" endpoint = {azure_cfg['azure_endpoint'] or '[dim]not set[/dim]'}"
135+
)
136+
console.print(f" api_version = {azure_cfg['api_version']}")
137+
if azure_cfg["azure_deployment"]:
138+
console.print(f" deployment = {azure_cfg['azure_deployment']}")
124139

125140
# Config file
126141
console.print()

entropy/config.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ def get_api_key(provider: str) -> str:
213213
Supports:
214214
- openai: OPENAI_API_KEY
215215
- claude: ANTHROPIC_API_KEY
216+
- azure_openai: AZURE_OPENAI_API_KEY
216217
217218
Returns empty string if not found (providers will raise on missing keys).
218219
"""
@@ -221,9 +222,32 @@ def get_api_key(provider: str) -> str:
221222
return os.environ.get("OPENAI_API_KEY", "")
222223
elif provider == "claude":
223224
return os.environ.get("ANTHROPIC_API_KEY", "")
225+
elif provider == "azure_openai":
226+
return os.environ.get("AZURE_OPENAI_API_KEY", "")
224227
return ""
225228

226229

230+
def get_azure_config(provider: str) -> dict[str, str]:
231+
"""Get Azure-specific configuration from environment variables.
232+
233+
Args:
234+
provider: 'azure_openai'
235+
236+
Returns:
237+
Dict of Azure config values (endpoint, api_version, deployment).
238+
"""
239+
_ensure_dotenv()
240+
if provider == "azure_openai":
241+
return {
242+
"azure_endpoint": os.environ.get("AZURE_OPENAI_ENDPOINT", ""),
243+
"api_version": os.environ.get(
244+
"AZURE_OPENAI_API_VERSION", "2025-03-01-preview"
245+
),
246+
"azure_deployment": os.environ.get("AZURE_OPENAI_DEPLOYMENT", ""),
247+
}
248+
return {}
249+
250+
227251
# =============================================================================
228252
# Global config singleton
229253
# =============================================================================

entropy/core/providers/__init__.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
"""
1010

1111
from .base import LLMProvider
12-
from ...config import get_config, get_api_key
12+
from ...config import get_config, get_api_key, get_azure_config
1313

1414

1515
# Cached simulation provider — reused across batch calls so the async
@@ -29,9 +29,25 @@ def _create_provider(provider_name: str) -> LLMProvider:
2929
from .claude import ClaudeProvider
3030

3131
return ClaudeProvider(api_key=api_key)
32+
elif provider_name == "azure_openai":
33+
from .openai import OpenAIProvider
34+
35+
azure_cfg = get_azure_config(provider_name)
36+
if not azure_cfg.get("azure_endpoint"):
37+
raise ValueError(
38+
"AZURE_OPENAI_ENDPOINT not found. Set it as an environment variable.\n"
39+
" export AZURE_OPENAI_ENDPOINT=https://<resource>.cognitiveservices.azure.com/"
40+
)
41+
return OpenAIProvider(
42+
api_key=api_key,
43+
azure_endpoint=azure_cfg["azure_endpoint"],
44+
api_version=azure_cfg.get("api_version", "2025-03-01-preview"),
45+
azure_deployment=azure_cfg.get("azure_deployment", ""),
46+
)
3247
else:
3348
raise ValueError(
34-
f"Unknown LLM provider: {provider_name}. Valid options: 'openai', 'claude'"
49+
f"Unknown LLM provider: {provider_name}. "
50+
f"Valid options: 'openai', 'claude', 'azure_openai'"
3551
)
3652

3753

entropy/core/providers/openai.py

Lines changed: 61 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,51 @@
2323

2424

2525
class OpenAIProvider(LLMProvider):
26-
"""OpenAI LLM provider using the Responses API."""
26+
"""OpenAI LLM provider using the Responses API.
27+
28+
Supports both standard OpenAI and Azure OpenAI endpoints.
29+
When azure_endpoint is provided, uses AzureOpenAI/AsyncAzureOpenAI clients.
30+
"""
2731

2832
provider_name = "openai"
2933

30-
def __init__(self, api_key: str = "") -> None:
34+
def __init__(
35+
self,
36+
api_key: str = "",
37+
*,
38+
azure_endpoint: str = "",
39+
api_version: str = "",
40+
azure_deployment: str = "",
41+
) -> None:
42+
self._is_azure = bool(azure_endpoint)
43+
self._azure_endpoint = azure_endpoint
44+
self._api_version = api_version
45+
self._azure_deployment = azure_deployment
46+
3147
if not api_key:
32-
raise ValueError(
33-
"OPENAI_API_KEY not found. Set it as an environment variable.\n"
34-
" export OPENAI_API_KEY=sk-..."
35-
)
48+
if self._is_azure:
49+
raise ValueError(
50+
"AZURE_OPENAI_API_KEY not found. Set it as an environment variable.\n"
51+
" export AZURE_OPENAI_API_KEY=<your-subscription-key>"
52+
)
53+
else:
54+
raise ValueError(
55+
"OPENAI_API_KEY not found. Set it as an environment variable.\n"
56+
" export OPENAI_API_KEY=sk-..."
57+
)
3658
super().__init__(api_key)
3759

60+
if self._is_azure:
61+
self.provider_name = "azure_openai"
62+
63+
def _resolve_model(self, model: str | None, default: str) -> str:
64+
"""Resolve model name, using Azure deployment as fallback when applicable."""
65+
if model:
66+
return model
67+
if self._is_azure and self._azure_deployment:
68+
return self._azure_deployment
69+
return default
70+
3871
@staticmethod
3972
def _extract_output_text(response) -> str | None:
4073
"""Extract the output text content from an OpenAI Responses API response.
@@ -100,11 +133,28 @@ def default_research_model(self) -> str:
100133
return "gpt-5"
101134

102135
def _get_client(self) -> OpenAI:
136+
if self._is_azure:
137+
from openai import AzureOpenAI
138+
139+
return AzureOpenAI(
140+
api_key=self._api_key,
141+
azure_endpoint=self._azure_endpoint,
142+
api_version=self._api_version,
143+
)
103144
return OpenAI(api_key=self._api_key)
104145

105146
def _get_async_client(self) -> AsyncOpenAI:
106147
if self._cached_async_client is None:
107-
self._cached_async_client = AsyncOpenAI(api_key=self._api_key)
148+
if self._is_azure:
149+
from openai import AsyncAzureOpenAI
150+
151+
self._cached_async_client = AsyncAzureOpenAI(
152+
api_key=self._api_key,
153+
azure_endpoint=self._azure_endpoint,
154+
api_version=self._api_version,
155+
)
156+
else:
157+
self._cached_async_client = AsyncOpenAI(api_key=self._api_key)
108158
return self._cached_async_client
109159

110160
def simple_call(
@@ -116,7 +166,7 @@ def simple_call(
116166
log: bool = True,
117167
max_tokens: int | None = None,
118168
) -> dict:
119-
model = model or self.default_simple_model
169+
model = self._resolve_model(model, self.default_simple_model)
120170
client = self._get_client()
121171

122172
# Acquire rate limit capacity before making the call
@@ -169,7 +219,7 @@ async def simple_call_async(
169219
model: str | None = None,
170220
max_tokens: int | None = None,
171221
) -> dict:
172-
model = model or self.default_simple_model
222+
model = self._resolve_model(model, self.default_simple_model)
173223
client = self._get_async_client()
174224

175225
request_params = {
@@ -211,7 +261,7 @@ def reasoning_call(
211261
max_retries: int = 2,
212262
on_retry: RetryCallback | None = None,
213263
) -> dict:
214-
model = model or self.default_reasoning_model
264+
model = self._resolve_model(model, self.default_reasoning_model)
215265
client = self._get_client()
216266

217267
effective_prompt = prompt
@@ -272,7 +322,7 @@ def agentic_research(
272322
max_retries: int = 2,
273323
on_retry: RetryCallback | None = None,
274324
) -> tuple[dict, list[str]]:
275-
model = model or self.default_research_model
325+
model = self._resolve_model(model, self.default_research_model)
276326
client = self._get_client()
277327

278328
effective_prompt = prompt

entropy/core/rate_limits.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,9 @@
8888
# Map "claude" provider name to anthropic profiles
8989
RATE_LIMIT_PROFILES["claude"] = RATE_LIMIT_PROFILES["anthropic"]
9090

91+
# Azure OpenAI uses the same rate limit profiles as standard OpenAI
92+
RATE_LIMIT_PROFILES["azure_openai"] = RATE_LIMIT_PROFILES["openai"]
93+
9194

9295
def get_limits(
9396
provider: str,

0 commit comments

Comments (0)