diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py index 8d81948f68..57f517dad5 100644 --- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py @@ -43,7 +43,7 @@ from pydantic_ai.exceptions import ModelAPIError, ModelHTTPError, UserError from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse, download_item from pydantic_ai.providers import Provider, infer_provider -from pydantic_ai.providers.bedrock import BedrockModelProfile +from pydantic_ai.providers.bedrock import BEDROCK_GEO_PREFIXES, BedrockModelProfile from pydantic_ai.settings import ModelSettings from pydantic_ai.tools import ToolDefinition @@ -155,13 +155,6 @@ 'tool_use': 'tool_call', } -_AWS_BEDROCK_INFERENCE_GEO_PREFIXES: tuple[str, ...] = ('us.', 'eu.', 'apac.', 'jp.', 'au.', 'ca.', 'global.') -"""Geo prefixes for Bedrock inference profile IDs (e.g., 'eu.', 'us.'). - -Used to strip the geo prefix so we can pass a pure foundation model ID/ARN to CountTokens, -which does not accept profile IDs. Extend if new geos appear (e.g., 'global.', 'us-gov.'). -""" - class BedrockModelSettings(ModelSettings, total=False): """Settings for Bedrock models. @@ -693,9 +686,9 @@ def _map_tool_call(t: ToolCallPart) -> ContentBlockOutputTypeDef: @staticmethod def _remove_inference_geo_prefix(model_name: BedrockModelName) -> BedrockModelName: """Remove inference geographic prefix from model ID if present.""" - for prefix in _AWS_BEDROCK_INFERENCE_GEO_PREFIXES: - if model_name.startswith(prefix): - return model_name.removeprefix(prefix) + for prefix in BEDROCK_GEO_PREFIXES: + if model_name.startswith(f'{prefix}.'): + return model_name.removeprefix(f'{prefix}.') return model_name diff --git a/pydantic_ai_slim/pydantic_ai/providers/bedrock.py b/pydantic_ai_slim/pydantic_ai/providers/bedrock.py index f6fac74fae..3f855021e4 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/providers/bedrock.py @@ -58,6 +58,10 @@ def bedrock_deepseek_model_profile(model_name: str) -> ModelProfile | None: return profile # pragma: no cover +# Known geo prefixes for cross-region inference profile IDs +BEDROCK_GEO_PREFIXES: tuple[str, ...] = ('us', 'eu', 'apac', 'jp', 'au', 'ca', 'global', 'us-gov') + + class BedrockProvider(Provider[BaseClient]): """Provider for AWS Bedrock.""" @@ -90,10 +94,11 @@ def model_profile(self, model_name: str) -> ModelProfile | None: # Split the model name into parts parts = model_name.split('.', 2) - # Handle regional prefixes (e.g. "us.") - if len(parts) > 2 and len(parts[0]) == 2: + # Handle regional prefixes + if len(parts) > 2 and parts[0] in BEDROCK_GEO_PREFIXES: parts = parts[1:] + # required format is provider.model-name-with-version if len(parts) < 2: return None diff --git a/tests/providers/test_bedrock.py b/tests/providers/test_bedrock.py index e791fe7a43..c0dcbd2ddc 100644 --- a/tests/providers/test_bedrock.py +++ b/tests/providers/test_bedrock.py @@ -1,4 +1,4 @@ -from typing import cast +from typing import cast, get_args import pytest from pytest_mock import MockerFixture @@ -16,7 +16,8 @@ with try_import() as imports_successful: from mypy_boto3_bedrock_runtime import BedrockRuntimeClient - from pydantic_ai.providers.bedrock import BedrockModelProfile, BedrockProvider + from pydantic_ai.models.bedrock import LatestBedrockModelNames + from pydantic_ai.providers.bedrock import BEDROCK_GEO_PREFIXES, BedrockModelProfile, BedrockProvider pytestmark = pytest.mark.skipif(not imports_successful(), reason='bedrock not installed') @@ -100,3 +101,51 @@ def test_bedrock_provider_model_profile(env: TestEnv, mocker: MockerFixture): unknown_model = provider.model_profile('unknown.unknown-model') assert unknown_model is None + + +@pytest.mark.parametrize('prefix', BEDROCK_GEO_PREFIXES) +def test_bedrock_provider_model_profile_all_geo_prefixes(env: TestEnv, prefix: str): + """Test that all cross-region inference geo prefixes are correctly handled.""" + env.set('AWS_DEFAULT_REGION', 'us-east-1') + provider = BedrockProvider() + + model_name = f'{prefix}.anthropic.claude-sonnet-4-5-20250929-v1:0' + profile = provider.model_profile(model_name) + + assert profile is not None, f'model_profile returned None for {model_name}' + + +def test_bedrock_provider_model_profile_with_unknown_geo_prefix(env: TestEnv): + env.set('AWS_DEFAULT_REGION', 'us-east-1') + provider = BedrockProvider() + + model_name = 'narnia.anthropic.claude-sonnet-4-5-20250929-v1:0' + profile = provider.model_profile(model_name) + assert profile is None, f'model_profile returned {profile} for {model_name}' + + +def test_latest_bedrock_model_names_geo_prefixes_are_supported(): + """Ensure all geo prefixes used in LatestBedrockModelNames are in BEDROCK_GEO_PREFIXES. + + This test prevents adding new model names with geo prefixes that aren't handled + by the provider's model_profile method. + """ + model_names = get_args(LatestBedrockModelNames) + + missing_prefixes: set[str] = set() + + for model_name in model_names: + # Model names with geo prefixes have 3+ dot-separated parts: + # - No prefix: "anthropic.claude-xxx" (2 parts) + # - With prefix: "us.anthropic.claude-xxx" (3 parts) + parts = model_name.split('.') + if len(parts) >= 3: + geo_prefix = parts[0] + if geo_prefix not in BEDROCK_GEO_PREFIXES: # pragma: no cover + missing_prefixes.add(geo_prefix) + + if missing_prefixes: # pragma: no cover + pytest.fail( + f'Found geo prefixes in LatestBedrockModelNames that are not in BEDROCK_GEO_PREFIXES: {missing_prefixes}. ' + f'Please add them to BEDROCK_GEO_PREFIXES' + )