Add tool_choice setting
#3611
base: main
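This PR threads a new `tool_choice` model setting through the Anthropic and Bedrock models. As a minimal sketch of how the setting might be used — assuming it is exposed through `ModelSettings` under the key `tool_choice` and accepts the modes visible in this diff (`'auto'`, `'required'`, `'none'`, or specific tool names); the model string and tool are illustrative only:

```python
from pydantic_ai import Agent

agent = Agent('anthropic:claude-sonnet-4-0')  # illustrative model name

@agent.tool_plain
def get_weather(city: str) -> str:
    """Toy tool used only for this example."""
    return f'It is sunny in {city}.'

result = agent.run_sync(
    'What is the weather in Paris?',
    # Hypothetical: force the model to call some tool. Based on the resolved
    # modes in this diff, 'none' would hide function tools, and a list of
    # names would restrict the model to those tools.
    model_settings={'tool_choice': 'required'},
)
print(result.output)
```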
Changes from all commits
First file (the Anthropic model):

```diff
@@ -1,6 +1,7 @@
 from __future__ import annotations as _annotations

 import io
+import warnings
 from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field, replace
@@ -42,7 +43,15 @@
 from ..providers.anthropic import AsyncAnthropicClient
 from ..settings import ModelSettings, merge_model_settings
 from ..tools import ToolDefinition
-from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent
+from . import (
+    Model,
+    ModelRequestParameters,
+    StreamedResponse,
+    _resolve_tool_choice,  # pyright: ignore[reportPrivateUsage]
+    check_allow_model_requests,
+    download_item,
+    get_user_agent,
+)

 _FINISH_REASON_MAP: dict[BetaStopReason, FinishReason] = {
     'end_turn': 'stop',
@@ -386,11 +395,9 @@ async def _messages_create(
         This is the last step before sending the request to the API.
         Most preprocessing has happened in `prepare_request()`.
         """
-        tools = self._get_tools(model_request_parameters, model_settings)
+        tools, tool_choice = self._infer_tool_choice(model_settings, model_request_parameters)
         tools, mcp_servers, builtin_tool_betas = self._add_builtin_tools(tools, model_request_parameters)

-        tool_choice = self._infer_tool_choice(tools, model_settings, model_request_parameters)
-
         system_prompt, anthropic_messages = await self._map_message(messages, model_request_parameters, model_settings)
         self._limit_cache_points(system_prompt, anthropic_messages, tools)
         output_format = self._native_output_format(model_request_parameters)
@@ -474,11 +481,9 @@ async def _messages_count_tokens(
         raise UserError('AsyncAnthropicBedrock client does not support `count_tokens` api.')

         # standalone function to make it easier to override
-        tools = self._get_tools(model_request_parameters, model_settings)
+        tools, tool_choice = self._infer_tool_choice(model_settings, model_request_parameters)
         tools, mcp_servers, builtin_tool_betas = self._add_builtin_tools(tools, model_request_parameters)

-        tool_choice = self._infer_tool_choice(tools, model_settings, model_request_parameters)
-
         system_prompt, anthropic_messages = await self._map_message(messages, model_request_parameters, model_settings)
         self._limit_cache_points(system_prompt, anthropic_messages, tools)
         output_format = self._native_output_format(model_request_parameters)
@@ -584,22 +589,6 @@ async def _process_streamed_response(
             _provider_url=self._provider.base_url,
         )

-    def _get_tools(
-        self, model_request_parameters: ModelRequestParameters, model_settings: AnthropicModelSettings
-    ) -> list[BetaToolUnionParam]:
-        tools: list[BetaToolUnionParam] = [
-            self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()
-        ]
-
-        # Add cache_control to the last tool if enabled
-        if tools and (cache_tool_defs := model_settings.get('anthropic_cache_tool_definitions')):
-            # If True, use '5m'; otherwise use the specified ttl value
-            ttl: Literal['5m', '1h'] = '5m' if cache_tool_defs is True else cache_tool_defs
-            last_tool = tools[-1]
-            last_tool['cache_control'] = self._build_cache_control(ttl)
-
-        return tools
-
     def _add_builtin_tools(
         self, tools: list[BetaToolUnionParam], model_request_parameters: ModelRequestParameters
     ) -> tuple[list[BetaToolUnionParam], list[BetaRequestMCPServerURLDefinitionParam], set[str]]:
@@ -663,26 +652,91 @@ def _add_builtin_tools(
         )
         return tools, mcp_servers, beta_features

-    def _infer_tool_choice(
+    def _infer_tool_choice(  # noqa: C901
         self,
-        tools: list[BetaToolUnionParam],
         model_settings: AnthropicModelSettings,
         model_request_parameters: ModelRequestParameters,
-    ) -> BetaToolChoiceParam | None:
-        if not tools:
-            return None
-
-        tool_choice: BetaToolChoiceParam
-
-        if not model_request_parameters.allow_text_output:
-            tool_choice = {'type': 'any'}
-        else:
-            tool_choice = {'type': 'auto'}
-
-        if 'parallel_tool_calls' in model_settings:
-            tool_choice['disable_parallel_tool_use'] = not model_settings['parallel_tool_calls']
-
-        return tool_choice
+    ) -> tuple[list[BetaToolUnionParam], BetaToolChoiceParam | None]:
+        """Determine which tools to send and the API tool_choice value.
+
+        Returns:
+            A tuple of (filtered_tools, tool_choice).
+        """
+        thinking_enabled = model_settings.get('anthropic_thinking') is not None
+        function_tools = model_request_parameters.function_tools
+        output_tools = model_request_parameters.output_tools
+
+        resolved = _resolve_tool_choice(model_settings, model_request_parameters)
+
+        if resolved is None:
+            tool_defs_to_send = [*function_tools, *output_tools]
+        else:
+            tool_defs_to_send = resolved.filter_tools(function_tools, output_tools)
+
+        # Map ToolDefinitions to Anthropic format
+        tools: list[BetaToolUnionParam] = [self._map_tool_definition(t) for t in tool_defs_to_send]
+
+        # Add cache_control to the last tool if enabled
+        if tools and (cache_tool_defs := model_settings.get('anthropic_cache_tool_definitions')):
+            ttl: Literal['5m', '1h'] = '5m' if cache_tool_defs is True else cache_tool_defs
+            last_tool = tools[-1]
+            last_tool['cache_control'] = self._build_cache_control(ttl)
+
+        if not tools:
+            return tools, None
+
+        tool_choice: BetaToolChoiceParam
+
+        if resolved is None:
+            if not model_request_parameters.allow_text_output:
+                tool_choice = {'type': 'any'}
+            else:
+                tool_choice = {'type': 'auto'}
+
+        elif resolved.mode == 'auto':
+            if not model_request_parameters.allow_text_output:
+                tool_choice = {'type': 'any'}
+            else:
+                tool_choice = {'type': 'auto'}
+
+        elif resolved.mode == 'required':
+            if thinking_enabled:
+                raise UserError(
+                    "tool_choice='required' is not supported with Anthropic thinking mode. "
+                    'Use `output_type=NativeOutput(...)` or `PromptedOutput(...)` instead.'
+                )
+            tool_choice = {'type': 'any'}
+
+        elif resolved.mode == 'none':
+            if len(output_tools) == 1:
+                tool_choice = {'type': 'tool', 'name': output_tools[0].name}
+            else:
+                warnings.warn(
+                    "Anthropic only supports forcing a single tool. Falling back to 'auto' for multiple output tools."
+                )
+                tool_choice = {'type': 'auto'}
+
+        elif resolved.mode == 'specific':
+            if thinking_enabled:
```
> **Collaborator:** See above; if we can't honor the user's wishes at all, we should raise an error.
```diff
+                raise UserError(
+                    'Forcing specific tools is not supported with Anthropic thinking mode. '
+                    'Use `output_type=NativeOutput(...)` or `PromptedOutput(...)` instead.'
+                )
+            if len(resolved.tool_names) == 1:
+                tool_choice = {'type': 'tool', 'name': resolved.tool_names[0]}
+            else:
+                warnings.warn(
+                    "Anthropic only supports forcing a single tool. Falling back to 'any' for multiple specific tools."
+                )
+                tool_choice = {'type': 'any'}
+
+        else:
+            assert_never(resolved.mode)
+
+        if 'parallel_tool_calls' in model_settings:
+            tool_choice['disable_parallel_tool_use'] = not model_settings['parallel_tool_calls']
+
+        return tools, tool_choice

     async def _map_message(  # noqa: C901
         self,
```
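Condensed, the branches above map the resolved modes onto Anthropic's Messages API `tool_choice` payloads roughly as follows (both tool names here are invented for illustration):

```python
# Summary of _infer_tool_choice; each value is the Anthropic `tool_choice`
# payload sent for that case. 'final_result' and 'get_weather' are made up.
ANTHROPIC_TOOL_CHOICE_BY_MODE = {
    'default/auto, text output allowed': {'type': 'auto'},
    'default/auto, text output not allowed': {'type': 'any'},
    'required': {'type': 'any'},  # raises UserError when thinking is enabled
    'none, single output tool': {'type': 'tool', 'name': 'final_result'},
    'specific, single tool': {'type': 'tool', 'name': 'get_weather'},
}
```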
```diff
@@ -887,9 +941,10 @@ async def _map_message(  # noqa: C901
             system_prompt_parts.insert(0, instructions)
         system_prompt = '\n\n'.join(system_prompt_parts)

+        ttl: Literal['5m', '1h']
         # Add cache_control to the last message content if anthropic_cache_messages is enabled
         if anthropic_messages and (cache_messages := model_settings.get('anthropic_cache_messages')):
-            ttl: Literal['5m', '1h'] = '5m' if cache_messages is True else cache_messages
+            ttl = '5m' if cache_messages is True else cache_messages
             m = anthropic_messages[-1]
             content = m['content']
             if isinstance(content, str):
@@ -909,7 +964,7 @@
         # If anthropic_cache_instructions is enabled, return system prompt as a list with cache_control
         if system_prompt and (cache_instructions := model_settings.get('anthropic_cache_instructions')):
             # If True, use '5m'; otherwise use the specified ttl value
-            ttl: Literal['5m', '1h'] = '5m' if cache_instructions is True else cache_instructions
+            ttl = '5m' if cache_instructions is True else cache_instructions
             system_prompt_blocks = [
                 BetaTextBlockParam(
                     type='text',
```
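Both files import the private helper `_resolve_tool_choice` from `pydantic_ai.models`; its implementation is not part of this diff. From the call sites, it appears to return `None` when no preference is set, or an object exposing `mode`, `tool_names`, and `filter_tools()`. A speculative sketch of that contract, reconstructed only from how the diff uses it:

```python
from dataclasses import dataclass, field
from typing import Literal

from pydantic_ai.tools import ToolDefinition


@dataclass
class ResolvedToolChoice:
    """Hypothetical shape of the value returned by _resolve_tool_choice."""

    mode: Literal['auto', 'required', 'none', 'specific']
    tool_names: list[str] = field(default_factory=list)  # only used for 'specific'

    def filter_tools(
        self,
        function_tools: list[ToolDefinition],
        output_tools: list[ToolDefinition],
    ) -> list[ToolDefinition]:
        # 'none' hides function tools but keeps output tools so structured
        # output still works; 'specific' keeps only the named function tools.
        if self.mode == 'none':
            return list(output_tools)
        if self.mode == 'specific':
            named = [t for t in function_tools if t.name in self.tool_names]
            return [*named, *output_tools]
        return [*function_tools, *output_tools]
```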
Second file (the Bedrock model):

```diff
@@ -2,6 +2,7 @@

 import functools
 import typing
+import warnings
 from collections.abc import AsyncIterator, Iterable, Iterator, Mapping
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
@@ -41,7 +42,13 @@
 )
 from pydantic_ai._run_context import RunContext
 from pydantic_ai.exceptions import ModelAPIError, ModelHTTPError, UserError
-from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse, download_item
+from pydantic_ai.models import (
+    Model,
+    ModelRequestParameters,
+    StreamedResponse,
+    _resolve_tool_choice,  # pyright: ignore[reportPrivateUsage]
+    download_item,
+)
 from pydantic_ai.providers import Provider, infer_provider
 from pydantic_ai.providers.bedrock import BEDROCK_GEO_PREFIXES, BedrockModelProfile
 from pydantic_ai.settings import ModelSettings
@@ -254,9 +261,6 @@ def system(self) -> str:
         """The model provider."""
         return self._provider.name

-    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolTypeDef]:
-        return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]
-
     @staticmethod
     def _map_tool_definition(f: ToolDefinition) -> ToolTypeDef:
         tool_spec: ToolSpecificationTypeDef = {'name': f.name, 'inputSchema': {'json': f.parameters_json_schema}}
@@ -422,7 +426,7 @@ async def _messages_create(
             'inferenceConfig': inference_config,
         }

-        tool_config = self._map_tool_config(model_request_parameters)
+        tool_config = self._map_tool_config(model_request_parameters, model_settings)
         if tool_config:
             params['toolConfig'] = tool_config

@@ -478,17 +482,58 @@ def _map_inference_config(

         return inference_config

-    def _map_tool_config(self, model_request_parameters: ModelRequestParameters) -> ToolConfigurationTypeDef | None:
-        tools = self._get_tools(model_request_parameters)
-        if not tools:
-            return None
-
-        tool_choice: ToolChoiceTypeDef
-        if not model_request_parameters.allow_text_output:
-            tool_choice = {'any': {}}
-        else:
-            tool_choice = {'auto': {}}
+    def _map_tool_config(
+        self,
+        model_request_parameters: ModelRequestParameters,
+        model_settings: BedrockModelSettings | None,
+    ) -> ToolConfigurationTypeDef | None:
+        resolved = _resolve_tool_choice(model_settings, model_request_parameters)
+        function_tools = model_request_parameters.function_tools
+        output_tools = model_request_parameters.output_tools
+
+        if resolved is None:
+            tool_defs_to_send = [*function_tools, *output_tools]
+        else:
+            tool_defs_to_send = resolved.filter_tools(function_tools, output_tools)
+
+        if not tool_defs_to_send:
+            return None
+
+        tools = [self._map_tool_definition(t) for t in tool_defs_to_send]
+        tool_choice: ToolChoiceTypeDef
+
+        if resolved is None:
+            # Default behavior: infer from allow_text_output
+            if not model_request_parameters.allow_text_output:
+                tool_choice = {'any': {}}
+            else:
+                tool_choice = {'auto': {}}
+
+        elif resolved.mode == 'auto':
+            if not model_request_parameters.allow_text_output:
+                tool_choice = {'any': {}}
+            else:
+                tool_choice = {'auto': {}}
+
+        elif resolved.mode == 'required':
+            tool_choice = {'any': {}}
```
> **Collaborator:** Same issue as above: `required` means "function tool", so we should not allow output tools to be called yet.
```diff
+
+        elif resolved.mode == 'none':
+            # We've already filtered to only output tools, use 'auto' to let model choose
+            tool_choice = {'auto': {}}
+
+        elif resolved.mode == 'specific':
+            if not resolved.tool_names:  # pragma: no cover
+                raise RuntimeError('Internal error: resolved.tool_names is empty for specific tool choice.')
+            if len(resolved.tool_names) == 1:
+                tool_choice = {'tool': {'name': resolved.tool_names[0]}}
+            else:
+                warnings.warn("Bedrock only supports forcing a single tool. Falling back to 'any'.")
+                tool_choice = {'any': {}}
+
+        else:
+            assert_never(resolved.mode)

         tool_config: ToolConfigurationTypeDef = {'tools': tools}
         if tool_choice and BedrockModelProfile.from_profile(self.profile).bedrock_supports_tool_choice:
             tool_config['toolChoice'] = tool_choice
```
> **Collaborator:** I wonder if this is too restrictive, because some tools may only be conditionally available, but it's possible that if they are available, they would be one of the allowed tools. I'm OK with having an error for now, and removing it if someone complains.
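For reference, a hypothetical illustration of the Converse API payload that `_map_tool_config` above would produce for `tool_choice='required'` with a single function tool (the tool name and schema are invented):

```python
# Shape of params['toolConfig'] sent to Bedrock's Converse API in this case.
tool_config = {
    'tools': [
        {
            'toolSpec': {
                'name': 'get_weather',
                'inputSchema': {
                    'json': {
                        'type': 'object',
                        'properties': {'city': {'type': 'string'}},
                        'required': ['city'],
                    }
                },
            }
        }
    ],
    # 'required' maps to Bedrock's {'any': {}}; only emitted when the model
    # profile reports bedrock_supports_tool_choice.
    'toolChoice': {'any': {}},
}
```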