Commits (27)
8d52d65  Clarify usage of agent factories (dsfaccini, Nov 28, 2025)
3128b4a  Merge branch 'pydantic:main' into main (dsfaccini, Nov 30, 2025)
6d942f5  implement tool choice resolution per model (dsfaccini, Dec 1, 2025)
0585347  - centralize logic in utility and add tests for all providers (dsfaccini, Dec 2, 2025)
96681ac  coverage? (dsfaccini, Dec 2, 2025)
e71dc86  coverage (dsfaccini, Dec 2, 2025)
5c387fd  imrpove tests (dsfaccini, Dec 2, 2025)
4dcfbe4  Merge branch 'main' into tool-choice (dsfaccini, Dec 4, 2025)
363c718  improvde code quality (dsfaccini, Dec 5, 2025)
31bb4e1  deduplicate openai logic (dsfaccini, Dec 5, 2025)
338a073  remove cast (dsfaccini, Dec 5, 2025)
07fcb6b  re-run existent cassettes and record new ones for new tool choice tests (dsfaccini, Dec 8, 2025)
51cada5  fix snapshots (dsfaccini, Dec 8, 2025)
6597e0b  fix tests (dsfaccini, Dec 9, 2025)
914748f  Merge branch 'main' into tool-choice (dsfaccini, Dec 9, 2025)
57ff6bf  upgrade to newer models (dsfaccini, Dec 9, 2025)
8a24d41  Merge upstream/main into tool-choice (dsfaccini, Dec 9, 2025)
1a46a7b  support tool choice callable to force tools on first request or arbit… (dsfaccini, Dec 9, 2025)
5e4cfb7  Merge branch 'main' into tool-choice (dsfaccini, Dec 9, 2025)
70dc917  skip tests (dsfaccini, Dec 9, 2025)
f924378  fix: skip lint/test for docstring examples with RunContext (dsfaccini, Dec 9, 2025)
80ece45  Merge branch 'main' into tool-choice (dsfaccini, Dec 10, 2025)
24962c0  add note about serialization obligation (dsfaccini, Dec 9, 2025)
dc97d4e  revert: remove callable tool_choice and force_first_request (dsfaccini, Dec 9, 2025)
88884f5  fix: align tool_choice tests with warning strategy and resolve merge … (dsfaccini, Dec 10, 2025)
b55bac9  covergae (dsfaccini, Dec 11, 2025)
50c9db1  simplify branches (dsfaccini, Dec 11, 2025)
84 changes: 84 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -545,6 +545,90 @@ def prompted_output_instructions(self) -> str | None:
__repr__ = _utils.dataclasses_no_defaults_repr


@dataclass
class _ResolvedToolChoice:
"""Provider-agnostic resolved tool choice.

This is the result of validating and resolving the user's `tool_choice` setting.
Providers should map this to their API-specific format.
"""

mode: Literal['none', 'auto', 'required', 'specific']
"""The resolved tool choice mode."""

tool_names: list[str] = field(default_factory=list)
"""For 'specific' mode, the list of tool names to force. Empty for other modes."""

def filter_tools(
self,
function_tools: list[ToolDefinition],
output_tools: list[ToolDefinition],
) -> list[ToolDefinition]:
"""Filter tools based on the resolved mode.

- 'none': only output_tools
- 'required': only function_tools
- 'specific': specified function_tools + output_tools
- 'auto': all tools
"""
if self.mode == 'none':
return list(output_tools)
elif self.mode == 'required':
return list(function_tools)
elif self.mode == 'specific':
allowed = set(self.tool_names)
return [t for t in function_tools if t.name in allowed] + list(output_tools)
else: # 'auto'
return [*function_tools, *output_tools]


def _resolve_tool_choice( # pyright: ignore[reportUnusedFunction]
model_settings: ModelSettings | None,
model_request_parameters: ModelRequestParameters,
) -> _ResolvedToolChoice | None:
"""Resolve and validate tool_choice from model settings.

This centralizes the common logic for handling tool_choice across all providers:
- Validates tool names in list[str] against available function_tools
- Returns a provider-agnostic _ResolvedToolChoice for the provider to map to its API-specific format

Args:
model_settings: The model settings containing tool_choice.
model_request_parameters: The request parameters containing tool definitions.

Returns:
_ResolvedToolChoice if an explicit tool_choice was provided and validated,
None if tool_choice was not set (provider should use default behavior based on allow_text_output).

Raises:
UserError: If tool names in list[str] are invalid.
"""
user_tool_choice = (model_settings or {}).get('tool_choice')

if user_tool_choice is None:
return None

if user_tool_choice == 'none':
return _ResolvedToolChoice(mode='none')

if user_tool_choice in ('auto', 'required'):
return _ResolvedToolChoice(mode=user_tool_choice)

if isinstance(user_tool_choice, list):
if not user_tool_choice:
return _ResolvedToolChoice(mode='none')
function_tool_names = {t.name for t in model_request_parameters.function_tools}
invalid_names = set(user_tool_choice) - function_tool_names
if invalid_names:
Reviewer comment (Collaborator):

I wonder if this is too restrictive because some tools may only be conditionally available, but it's possible that if they are available, they would be one of the allowed tools. But I'm OK with having an error now, and removing it if someone complains

raise UserError(
f'Invalid tool names in `tool_choice`: {invalid_names}. '
f'Available function tools: {function_tool_names or "none"}'
)
return _ResolvedToolChoice(mode='specific', tool_names=list(user_tool_choice))

return None # pragma: no cover


class Model(ABC):
"""Abstract class for a model."""

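To make the resolution semantics concrete, here is a minimal, self-contained sketch of the filtering behavior. It inlines a trimmed copy of the `_ResolvedToolChoice` dataclass from the hunk above, and uses `SimpleNamespace` stand-ins for `ToolDefinition`, since `filter_tools` only reads each tool's `.name`:

from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import Literal

@dataclass
class ResolvedToolChoice:  # trimmed copy of _ResolvedToolChoice above
    mode: Literal['none', 'auto', 'required', 'specific']
    tool_names: list[str] = field(default_factory=list)

    def filter_tools(self, function_tools: list, output_tools: list) -> list:
        if self.mode == 'none':
            return list(output_tools)  # output tools only
        elif self.mode == 'required':
            return list(function_tools)  # function tools only
        elif self.mode == 'specific':
            allowed = set(self.tool_names)  # named function tools + all output tools
            return [t for t in function_tools if t.name in allowed] + list(output_tools)
        else:  # 'auto'
            return [*function_tools, *output_tools]

# Stand-ins for ToolDefinition; only `.name` is read by filter_tools.
weather = SimpleNamespace(name='get_weather')
news = SimpleNamespace(name='get_news')
final = SimpleNamespace(name='final_result')

specific = ResolvedToolChoice(mode='specific', tool_names=['get_weather'])
assert [t.name for t in specific.filter_tools([weather, news], [final])] == ['get_weather', 'final_result']
assert [t.name for t in ResolvedToolChoice(mode='none').filter_tools([weather, news], [final])] == ['final_result']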
123 changes: 89 additions & 34 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
@@ -1,6 +1,7 @@
from __future__ import annotations as _annotations

import io
import warnings
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator
from contextlib import asynccontextmanager
from dataclasses import dataclass, field, replace
@@ -42,7 +43,15 @@
from ..providers.anthropic import AsyncAnthropicClient
from ..settings import ModelSettings, merge_model_settings
from ..tools import ToolDefinition
from . import (
Model,
ModelRequestParameters,
StreamedResponse,
_resolve_tool_choice, # pyright: ignore[reportPrivateUsage]
check_allow_model_requests,
download_item,
get_user_agent,
)

_FINISH_REASON_MAP: dict[BetaStopReason, FinishReason] = {
'end_turn': 'stop',
@@ -386,11 +395,9 @@ async def _messages_create(
This is the last step before sending the request to the API.
Most preprocessing has happened in `prepare_request()`.
"""
tools, tool_choice = self._infer_tool_choice(model_settings, model_request_parameters)
tools, mcp_servers, builtin_tool_betas = self._add_builtin_tools(tools, model_request_parameters)

system_prompt, anthropic_messages = await self._map_message(messages, model_request_parameters, model_settings)
self._limit_cache_points(system_prompt, anthropic_messages, tools)
output_format = self._native_output_format(model_request_parameters)
@@ -474,11 +481,9 @@ async def _messages_count_tokens(
raise UserError('AsyncAnthropicBedrock client does not support `count_tokens` api.')

# standalone function to make it easier to override
tools, tool_choice = self._infer_tool_choice(model_settings, model_request_parameters)
tools, mcp_servers, builtin_tool_betas = self._add_builtin_tools(tools, model_request_parameters)

system_prompt, anthropic_messages = await self._map_message(messages, model_request_parameters, model_settings)
self._limit_cache_points(system_prompt, anthropic_messages, tools)
output_format = self._native_output_format(model_request_parameters)
@@ -584,22 +589,6 @@ async def _process_streamed_response(
_provider_url=self._provider.base_url,
)

def _get_tools(
self, model_request_parameters: ModelRequestParameters, model_settings: AnthropicModelSettings
) -> list[BetaToolUnionParam]:
tools: list[BetaToolUnionParam] = [
self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()
]

# Add cache_control to the last tool if enabled
if tools and (cache_tool_defs := model_settings.get('anthropic_cache_tool_definitions')):
# If True, use '5m'; otherwise use the specified ttl value
ttl: Literal['5m', '1h'] = '5m' if cache_tool_defs is True else cache_tool_defs
last_tool = tools[-1]
last_tool['cache_control'] = self._build_cache_control(ttl)

return tools

def _add_builtin_tools(
self, tools: list[BetaToolUnionParam], model_request_parameters: ModelRequestParameters
) -> tuple[list[BetaToolUnionParam], list[BetaRequestMCPServerURLDefinitionParam], set[str]]:
@@ -663,26 +652,91 @@ def _add_builtin_tools(
)
return tools, mcp_servers, beta_features

def _infer_tool_choice(  # noqa: C901
self,
model_settings: AnthropicModelSettings,
model_request_parameters: ModelRequestParameters,
) -> tuple[list[BetaToolUnionParam], BetaToolChoiceParam | None]:
"""Determine which tools to send and the API tool_choice value.

Returns:
A tuple of (filtered_tools, tool_choice).
"""
thinking_enabled = model_settings.get('anthropic_thinking') is not None
function_tools = model_request_parameters.function_tools
output_tools = model_request_parameters.output_tools

resolved = _resolve_tool_choice(model_settings, model_request_parameters)

if resolved is None:
tool_defs_to_send = [*function_tools, *output_tools]
else:
tool_defs_to_send = resolved.filter_tools(function_tools, output_tools)

# Map ToolDefinitions to Anthropic format
tools: list[BetaToolUnionParam] = [self._map_tool_definition(t) for t in tool_defs_to_send]

# Add cache_control to the last tool if enabled
if tools and (cache_tool_defs := model_settings.get('anthropic_cache_tool_definitions')):
ttl: Literal['5m', '1h'] = '5m' if cache_tool_defs is True else cache_tool_defs
last_tool = tools[-1]
last_tool['cache_control'] = self._build_cache_control(ttl)

if not tools:
return tools, None

tool_choice: BetaToolChoiceParam

if resolved is None:
if not model_request_parameters.allow_text_output:
tool_choice = {'type': 'any'}
else:
tool_choice = {'type': 'auto'}

elif resolved.mode == 'auto':
if not model_request_parameters.allow_text_output:
tool_choice = {'type': 'any'}
else:
tool_choice = {'type': 'auto'}

elif resolved.mode == 'required':
if thinking_enabled:
raise UserError(
"tool_choice='required' is not supported with Anthropic thinking mode. "
'Use `output_type=NativeOutput(...)` or `PromptedOutput(...)` instead.'
)
tool_choice = {'type': 'any'}

elif resolved.mode == 'none':
if len(output_tools) == 1:
tool_choice = {'type': 'tool', 'name': output_tools[0].name}
else:
warnings.warn(
"Anthropic only supports forcing a single tool. Falling back to 'auto' for multiple output tools."
)
tool_choice = {'type': 'auto'}

elif resolved.mode == 'specific':
if thinking_enabled:
Reviewer comment (Collaborator):

See above; if we can't honor the user's wishes at all, we should raise an error

raise UserError(
'Forcing specific tools is not supported with Anthropic thinking mode. '
'Use `output_type=NativeOutput(...)` or `PromptedOutput(...)` instead.'
)
if len(resolved.tool_names) == 1:
tool_choice = {'type': 'tool', 'name': resolved.tool_names[0]}
else:
warnings.warn(
"Anthropic only supports forcing a single tool. Falling back to 'any' for multiple specific tools."
)
tool_choice = {'type': 'any'}

else:
assert_never(resolved.mode)

if 'parallel_tool_calls' in model_settings:
tool_choice['disable_parallel_tool_use'] = not model_settings['parallel_tool_calls']

return tools, tool_choice

async def _map_message( # noqa: C901
self,
@@ -887,9 +941,10 @@ async def _map_message( # noqa: C901
system_prompt_parts.insert(0, instructions)
system_prompt = '\n\n'.join(system_prompt_parts)

ttl: Literal['5m', '1h']
# Add cache_control to the last message content if anthropic_cache_messages is enabled
if anthropic_messages and (cache_messages := model_settings.get('anthropic_cache_messages')):
ttl = '5m' if cache_messages is True else cache_messages
m = anthropic_messages[-1]
content = m['content']
if isinstance(content, str):
@@ -909,7 +964,7 @@ async def _map_message( # noqa: C901
# If anthropic_cache_instructions is enabled, return system prompt as a list with cache_control
if system_prompt and (cache_instructions := model_settings.get('anthropic_cache_instructions')):
# If True, use '5m'; otherwise use the specified ttl value
ttl = '5m' if cache_instructions is True else cache_instructions
system_prompt_blocks = [
BetaTextBlockParam(
type='text',
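Taken together, the Anthropic mapping can be exercised end to end from the public API. A hypothetical sketch; the model name and the `tool_choice` settings key are assumptions based on this diff, not a released interface:

from pydantic_ai import Agent

agent = Agent('anthropic:claude-sonnet-4-0')  # model name is illustrative

@agent.tool_plain
def get_weather(city: str) -> str:
    return f'It is sunny in {city}.'

# Per _infer_tool_choice above, ['get_weather'] resolves to mode='specific' and
# is sent to Anthropic as {'type': 'tool', 'name': 'get_weather'}; with
# anthropic_thinking enabled the same call would raise UserError instead.
result = agent.run_sync(
    'What is the weather in Paris?',
    model_settings={'tool_choice': ['get_weather']},
)
print(result.output)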
65 changes: 55 additions & 10 deletions pydantic_ai_slim/pydantic_ai/models/bedrock.py
@@ -2,6 +2,7 @@

import functools
import typing
import warnings
from collections.abc import AsyncIterator, Iterable, Iterator, Mapping
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
@@ -41,7 +42,13 @@
)
from pydantic_ai._run_context import RunContext
from pydantic_ai.exceptions import ModelAPIError, ModelHTTPError, UserError
from pydantic_ai.models import (
Model,
ModelRequestParameters,
StreamedResponse,
_resolve_tool_choice, # pyright: ignore[reportPrivateUsage]
download_item,
)
from pydantic_ai.providers import Provider, infer_provider
from pydantic_ai.providers.bedrock import BEDROCK_GEO_PREFIXES, BedrockModelProfile
from pydantic_ai.settings import ModelSettings
@@ -254,9 +261,6 @@ def system(self) -> str:
"""The model provider."""
return self._provider.name

def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolTypeDef]:
return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]

@staticmethod
def _map_tool_definition(f: ToolDefinition) -> ToolTypeDef:
tool_spec: ToolSpecificationTypeDef = {'name': f.name, 'inputSchema': {'json': f.parameters_json_schema}}
@@ -422,7 +426,7 @@ async def _messages_create(
'inferenceConfig': inference_config,
}

tool_config = self._map_tool_config(model_request_parameters, model_settings)
if tool_config:
params['toolConfig'] = tool_config

@@ -478,17 +482,58 @@ def _map_inference_config(

return inference_config

def _map_tool_config(
self,
model_request_parameters: ModelRequestParameters,
model_settings: BedrockModelSettings | None,
) -> ToolConfigurationTypeDef | None:
resolved = _resolve_tool_choice(model_settings, model_request_parameters)
function_tools = model_request_parameters.function_tools
output_tools = model_request_parameters.output_tools

if resolved is None:
tool_defs_to_send = [*function_tools, *output_tools]
else:
tool_defs_to_send = resolved.filter_tools(function_tools, output_tools)

if not tool_defs_to_send:
return None

tools = [self._map_tool_definition(t) for t in tool_defs_to_send]
tool_choice: ToolChoiceTypeDef
if resolved is None:
# Default behavior: infer from allow_text_output
if not model_request_parameters.allow_text_output:
tool_choice = {'any': {}}
else:
tool_choice = {'auto': {}}

elif resolved.mode == 'auto':
if not model_request_parameters.allow_text_output:
tool_choice = {'any': {}}
else:
tool_choice = {'auto': {}}

elif resolved.mode == 'required':
tool_choice = {'any': {}}
Reviewer comment (Collaborator):

Same issue as above: required means "function tool", so we should not allow output tools to be called yet

elif resolved.mode == 'none':
# We've already filtered down to output tools only; use 'auto' to let the model choose
tool_choice = {'auto': {}}

elif resolved.mode == 'specific':
if not resolved.tool_names: # pragma: no cover
raise RuntimeError('Internal error: resolved.tool_names is empty for specific tool choice.')
if len(resolved.tool_names) == 1:
tool_choice = {'tool': {'name': resolved.tool_names[0]}}
else:
warnings.warn("Bedrock only supports forcing a single tool. Falling back to 'any'.")
tool_choice = {'any': {}}

else:
assert_never(resolved.mode)

tool_config: ToolConfigurationTypeDef = {'tools': tools}
if tool_choice and BedrockModelProfile.from_profile(self.profile).bedrock_supports_tool_choice:
tool_config['toolChoice'] = tool_choice
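As a compact summary of the branching in `_map_tool_config`, here is a sketch of the Converse-API `toolChoice` payload each resolved mode produces, assuming tools have already been filtered as above. This simplifies the real code, which also gates the payload on `bedrock_supports_tool_choice`:

import warnings

def bedrock_tool_choice(mode: str, tool_names: list[str], allow_text_output: bool = True) -> dict:
    # Mode -> Bedrock Converse toolChoice payload, mirroring _map_tool_config above.
    if mode in ('auto', 'none'):
        # 'none' was already reduced to output tools only, so 'auto' is safe there.
        return {'any': {}} if mode == 'auto' and not allow_text_output else {'auto': {}}
    if mode == 'required':
        return {'any': {}}
    if mode == 'specific':
        if len(tool_names) == 1:
            return {'tool': {'name': tool_names[0]}}
        warnings.warn("Bedrock only supports forcing a single tool. Falling back to 'any'.")
        return {'any': {}}
    raise ValueError(f'unknown mode: {mode}')

assert bedrock_tool_choice('specific', ['get_weather']) == {'tool': {'name': 'get_weather'}}
assert bedrock_tool_choice('required', []) == {'any': {}}
assert bedrock_tool_choice('auto', [], allow_text_output=False) == {'any': {}}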