Skip to content

Commit 0d9d384

Browse files
authored
Support us-gov. and other multi-character Bedrock geo prefixes (#3645)
1 parent 302fad5 commit 0d9d384

File tree

3 files changed

+62
-15
lines changed

3 files changed

+62
-15
lines changed

pydantic_ai_slim/pydantic_ai/models/bedrock.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
from pydantic_ai.exceptions import ModelAPIError, ModelHTTPError, UserError
4444
from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse, download_item
4545
from pydantic_ai.providers import Provider, infer_provider
46-
from pydantic_ai.providers.bedrock import BedrockModelProfile
46+
from pydantic_ai.providers.bedrock import BEDROCK_GEO_PREFIXES, BedrockModelProfile
4747
from pydantic_ai.settings import ModelSettings
4848
from pydantic_ai.tools import ToolDefinition
4949

@@ -155,13 +155,6 @@
155155
'tool_use': 'tool_call',
156156
}
157157

158-
_AWS_BEDROCK_INFERENCE_GEO_PREFIXES: tuple[str, ...] = ('us.', 'eu.', 'apac.', 'jp.', 'au.', 'ca.', 'global.')
159-
"""Geo prefixes for Bedrock inference profile IDs (e.g., 'eu.', 'us.').
160-
161-
Used to strip the geo prefix so we can pass a pure foundation model ID/ARN to CountTokens,
162-
which does not accept profile IDs. Extend if new geos appear (e.g., 'global.', 'us-gov.').
163-
"""
164-
165158

166159
class BedrockModelSettings(ModelSettings, total=False):
167160
"""Settings for Bedrock models.
@@ -693,9 +686,9 @@ def _map_tool_call(t: ToolCallPart) -> ContentBlockOutputTypeDef:
693686
@staticmethod
694687
def _remove_inference_geo_prefix(model_name: BedrockModelName) -> BedrockModelName:
695688
"""Remove inference geographic prefix from model ID if present."""
696-
for prefix in _AWS_BEDROCK_INFERENCE_GEO_PREFIXES:
697-
if model_name.startswith(prefix):
698-
return model_name.removeprefix(prefix)
689+
for prefix in BEDROCK_GEO_PREFIXES:
690+
if model_name.startswith(f'{prefix}.'):
691+
return model_name.removeprefix(f'{prefix}.')
699692
return model_name
700693

701694

pydantic_ai_slim/pydantic_ai/providers/bedrock.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ def bedrock_deepseek_model_profile(model_name: str) -> ModelProfile | None:
5858
return profile # pragma: no cover
5959

6060

61+
# Known geo prefixes for cross-region inference profile IDs
62+
BEDROCK_GEO_PREFIXES: tuple[str, ...] = ('us', 'eu', 'apac', 'jp', 'au', 'ca', 'global', 'us-gov')
63+
64+
6165
class BedrockProvider(Provider[BaseClient]):
6266
"""Provider for AWS Bedrock."""
6367

@@ -90,10 +94,11 @@ def model_profile(self, model_name: str) -> ModelProfile | None:
9094
# Split the model name into parts
9195
parts = model_name.split('.', 2)
9296

93-
# Handle regional prefixes (e.g. "us.")
94-
if len(parts) > 2 and len(parts[0]) == 2:
97+
# Handle regional prefixes
98+
if len(parts) > 2 and parts[0] in BEDROCK_GEO_PREFIXES:
9599
parts = parts[1:]
96100

101+
# required format is provider.model-name-with-version
97102
if len(parts) < 2:
98103
return None
99104

tests/providers/test_bedrock.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import cast
1+
from typing import cast, get_args
22

33
import pytest
44
from pytest_mock import MockerFixture
@@ -16,7 +16,8 @@
1616
with try_import() as imports_successful:
1717
from mypy_boto3_bedrock_runtime import BedrockRuntimeClient
1818

19-
from pydantic_ai.providers.bedrock import BedrockModelProfile, BedrockProvider
19+
from pydantic_ai.models.bedrock import LatestBedrockModelNames
20+
from pydantic_ai.providers.bedrock import BEDROCK_GEO_PREFIXES, BedrockModelProfile, BedrockProvider
2021

2122

2223
pytestmark = pytest.mark.skipif(not imports_successful(), reason='bedrock not installed')
@@ -100,3 +101,51 @@ def test_bedrock_provider_model_profile(env: TestEnv, mocker: MockerFixture):
100101

101102
unknown_model = provider.model_profile('unknown.unknown-model')
102103
assert unknown_model is None
104+
105+
106+
@pytest.mark.parametrize('prefix', BEDROCK_GEO_PREFIXES)
107+
def test_bedrock_provider_model_profile_all_geo_prefixes(env: TestEnv, prefix: str):
108+
"""Test that all cross-region inference geo prefixes are correctly handled."""
109+
env.set('AWS_DEFAULT_REGION', 'us-east-1')
110+
provider = BedrockProvider()
111+
112+
model_name = f'{prefix}.anthropic.claude-sonnet-4-5-20250929-v1:0'
113+
profile = provider.model_profile(model_name)
114+
115+
assert profile is not None, f'model_profile returned None for {model_name}'
116+
117+
118+
def test_bedrock_provider_model_profile_with_unknown_geo_prefix(env: TestEnv):
119+
env.set('AWS_DEFAULT_REGION', 'us-east-1')
120+
provider = BedrockProvider()
121+
122+
model_name = 'narnia.anthropic.claude-sonnet-4-5-20250929-v1:0'
123+
profile = provider.model_profile(model_name)
124+
assert profile is None, f'model_profile returned {profile} for {model_name}'
125+
126+
127+
def test_latest_bedrock_model_names_geo_prefixes_are_supported():
128+
"""Ensure all geo prefixes used in LatestBedrockModelNames are in BEDROCK_GEO_PREFIXES.
129+
130+
This test prevents adding new model names with geo prefixes that aren't handled
131+
by the provider's model_profile method.
132+
"""
133+
model_names = get_args(LatestBedrockModelNames)
134+
135+
missing_prefixes: set[str] = set()
136+
137+
for model_name in model_names:
138+
# Model names with geo prefixes have 3+ dot-separated parts:
139+
# - No prefix: "anthropic.claude-xxx" (2 parts)
140+
# - With prefix: "us.anthropic.claude-xxx" (3 parts)
141+
parts = model_name.split('.')
142+
if len(parts) >= 3:
143+
geo_prefix = parts[0]
144+
if geo_prefix not in BEDROCK_GEO_PREFIXES: # pragma: no cover
145+
missing_prefixes.add(geo_prefix)
146+
147+
if missing_prefixes: # pragma: no cover
148+
pytest.fail(
149+
f'Found geo prefixes in LatestBedrockModelNames that are not in BEDROCK_GEO_PREFIXES: {missing_prefixes}. '
150+
f'Please add them to BEDROCK_GEO_PREFIXES'
151+
)

0 commit comments

Comments
 (0)