diff --git a/docs/agents.md b/docs/agents.md index 28a069a0be..9129d9e999 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -15,6 +15,7 @@ The [`Agent`][pydantic_ai.Agent] class has full API documentation, but conceptua | [Dependency type constraint](dependencies.md) | Dynamic instructions functions, tools, and output functions may all use dependencies when they're run. | | [LLM model](api/models/base.md) | Optional default LLM model associated with the agent. Can also be specified when running the agent. | | [Model Settings](#additional-configuration) | Optional default model settings to help fine tune requests. Can also be specified when running the agent. | +| [Prompt Configuration](#prompt-configuration) | Optional configuration for customizing system-generated messages, tool descriptions, and retry prompts. | In typing terms, agents are generic in their dependency and output types, e.g., an agent which required dependencies of type `#!python Foobar` and produced outputs of type `#!python list[str]` would have type `Agent[Foobar, list[str]]`. In practice, you shouldn't need to care about this, it should just mean your IDE can tell you when you have the right type, and if you choose to use [static type checking](#static-type-checking) it should work well with Pydantic AI. @@ -751,6 +752,135 @@ except UnexpectedModelBehavior as e: 1. This error is raised because the safety thresholds were exceeded. +### Prompt Configuration + +Pydantic AI provides [`PromptConfig`][pydantic_ai.PromptConfig] to customize the system-generated messages +that are sent to models during agent runs. This includes retry prompts, tool return confirmations, +validation error messages, and tool descriptions. + +#### Customizing System Messages with PromptTemplates + +[`PromptTemplates`][pydantic_ai.PromptTemplates] allows you to override the default messages that Pydantic AI +sends to the model for retries, tool results, and other system-generated content. + +```python {title="prompt_templates_example.py"} +from pydantic_ai import Agent, PromptConfig, PromptTemplates + +# Using static strings +agent = Agent( + 'openai:gpt-5', + prompt_config=PromptConfig( + templates=PromptTemplates( + validation_errors_retry='Please correct the validation errors and try again.', + final_result_processed='Result received successfully.', + ), + ), +) +``` + +You can also use callable functions for dynamic messages that have access to the message part +and the [`RunContext`][pydantic_ai.RunContext]: + +```python {title="prompt_templates_dynamic.py"} +from pydantic_ai import Agent, PromptConfig, PromptTemplates +from pydantic_ai.messages import RetryPromptPart +from pydantic_ai.tools import RunContext + + +def custom_retry_message(part: RetryPromptPart, ctx: RunContext) -> str: + return f'Attempt #{ctx.retry + 1}: Please fix the errors and try again.' 
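+
+# The callable form can also inspect the part it replaces. As a rough sketch (assuming
+# `RetryPromptPart.tool_name` is populated when the retry originates from a tool call),
+# a tool-aware message you could pass instead might look like:
+def tool_aware_retry_message(part: RetryPromptPart, ctx: RunContext) -> str:
+    tool_name = part.tool_name or 'the output'
+    return f'Attempt #{ctx.retry + 1}: the arguments for {tool_name} were invalid, please fix them and try again.'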
+ +agent = Agent( + 'openai:gpt-5', + prompt_config=PromptConfig( + templates=PromptTemplates( + validation_errors_retry=custom_retry_message, + ), + ), +) +``` + +The available template fields in [`PromptTemplates`][pydantic_ai.PromptTemplates] include: + +| Template Field | Description | +|----------------|-------------| +| `final_result_processed` | Confirmation message when a final result is successfully processed | +| `output_tool_not_executed` | Message when an output tool call is skipped because a result was already found | +| `function_tool_not_executed` | Message when a function tool call is skipped because a result was already found | +| `tool_call_denied` | Message when a tool call is denied by an approval handler | +| `validation_errors_retry` | Message appended to validation errors when asking the model to retry | +| `model_retry_string_tool` | Message when a `ModelRetry` exception is raised from a tool | +| `model_retry_string_no_tool` | Message when a `ModelRetry` exception is raised outside of a tool context | +| `prompted_output_template` | Template for prompted output schema instructions (uses `{schema}` placeholder) | + +#### Customizing Tool Descriptions with ToolConfig + +[`ToolConfig`][pydantic_ai.ToolConfig] allows you to override tool descriptions and argument descriptions +at runtime without modifying the original tool definitions. This is useful when you want to provide +different descriptions for the same tool in different contexts or agent runs. + +```python {title="tool_config_example.py"} +from pydantic_ai import Agent, PromptConfig, ToolConfig + +agent = Agent( + 'openai:gpt-5', + prompt_config=PromptConfig( + tool_config={ + 'search_database': ToolConfig( + tool_description='Search the customer database for user records by name or email.', + tool_args_descriptions={ + 'query': 'Search term to match against user names or email addresses.', + }, + ), + 'send_notification': ToolConfig( + tool_description='Send an urgent notification to the user via their preferred channel.', + tool_args_descriptions={ + 'user_id': 'The unique identifier of the user to notify.', + 'message': 'The notification message content (max 500 characters).', + }, + ), + }, + ), +) + + +@agent.tool_plain +def search_database(query: str) -> list[str]: + """Original description that will be overridden.""" + return ['result1', 'result2'] + + +@agent.tool_plain +def send_notification(user_id: str, message: str) -> bool: + """Original description that will be overridden.""" + return True +``` + +You can also override `prompt_config` at runtime using the `prompt_config` parameter in the run methods, +or temporarily using [`agent.override()`][pydantic_ai.Agent.override]: + +```python {title="prompt_config_override.py"} +from pydantic_ai import Agent, PromptConfig, PromptTemplates + +agent = Agent('openai:gpt-5') + +# Override at runtime +result = agent.run_sync( + 'Hello', + prompt_config=PromptConfig( + templates=PromptTemplates(validation_errors_retry='Custom retry message for this run.') + ), +) + +# Or use agent.override() context manager +with agent.override( + prompt_config=PromptConfig( + templates=PromptTemplates(validation_errors_retry='Another custom message.') + ) +): + result = agent.run_sync('Hello') +``` + ## Runs vs. Conversations An agent **run** might represent an entire conversation — there's no limit to how many messages can be exchanged in a single run. 
However, a **conversation** might also be composed of multiple runs, especially if you need to maintain state between separate interactions or API calls. diff --git a/docs/api/prompt_config.md b/docs/api/prompt_config.md new file mode 100644 index 0000000000..1b677292a7 --- /dev/null +++ b/docs/api/prompt_config.md @@ -0,0 +1,9 @@ +# `pydantic_ai.prompt_config` + +::: pydantic_ai.prompt_config + options: + inherited_members: true + members: + - PromptConfig + - PromptTemplates + - ToolConfig diff --git a/docs/deferred-tools.md b/docs/deferred-tools.md index e9f0a2ccee..eefac14ae7 100644 --- a/docs/deferred-tools.md +++ b/docs/deferred-tools.md @@ -150,6 +150,7 @@ print(result.all_messages()) content="File 'README.md' updated: 'Hello, world!'", tool_call_id='update_file_readme', timestamp=datetime.datetime(...), + return_kind='tool-executed', ) ], run_id='...', @@ -161,12 +162,14 @@ print(result.all_messages()) content="File '.env' updated: ''", tool_call_id='update_file_dotenv', timestamp=datetime.datetime(...), + return_kind='tool-executed', ), ToolReturnPart( tool_name='delete_file', content='Deleting files is not allowed', tool_call_id='delete_file', timestamp=datetime.datetime(...), + return_kind='tool-denied', ), UserPromptPart( content='Now create a backup of README.md', @@ -195,6 +198,7 @@ print(result.all_messages()) content="File 'README.md.bak' updated: 'Hello, world!'", tool_call_id='update_file_backup', timestamp=datetime.datetime(...), + return_kind='tool-executed', ) ], run_id='...', @@ -348,6 +352,7 @@ async def main(): content=42, tool_call_id='pyd_ai_tool_call_id', timestamp=datetime.datetime(...), + return_kind='tool-executed', ) ], run_id='...', diff --git a/docs/testing.md b/docs/testing.md index 3089585ab0..99d2e01472 100644 --- a/docs/testing.md +++ b/docs/testing.md @@ -156,6 +156,7 @@ async def test_forecast(): content='Sunny with a chance of rain', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ), ], run_id=IsStr(), diff --git a/docs/tools.md b/docs/tools.md index fc8641251b..3440054788 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -108,6 +108,7 @@ print(dice_result.all_messages()) content='4', tool_call_id='pyd_ai_tool_call_id', timestamp=datetime.datetime(...), + return_kind='tool-executed', ) ], run_id='...', @@ -130,6 +131,7 @@ print(dice_result.all_messages()) content='Anne', tool_call_id='pyd_ai_tool_call_id', timestamp=datetime.datetime(...), + return_kind='tool-executed', ) ], run_id='...', diff --git a/mkdocs.yml b/mkdocs.yml index a1c944da65..bdca475f12 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -164,6 +164,7 @@ nav: - api/models/test.md - api/models/wrapper.md - api/profiles.md + - api/prompt_config.md - api/providers.md - api/retries.md - api/run.md diff --git a/pydantic_ai_slim/pydantic_ai/__init__.py b/pydantic_ai_slim/pydantic_ai/__init__.py index dfecd6288f..b0f4c0caf9 100644 --- a/pydantic_ai_slim/pydantic_ai/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/__init__.py @@ -95,6 +95,18 @@ ModelProfile, ModelProfileSpec, ) +from .prompt_config import ( + DEFAULT_FINAL_RESULT_PROCESSED, + DEFAULT_FUNCTION_TOOL_NOT_EXECUTED, + DEFAULT_MODEL_RETRY, + DEFAULT_OUTPUT_TOOL_NOT_EXECUTED, + DEFAULT_OUTPUT_VALIDATION_FAILED, + DEFAULT_PROMPTED_OUTPUT_TEMPLATE, + DEFAULT_TOOL_CALL_DENIED, + PromptConfig, + PromptTemplates, + ToolConfig, +) from .run import AgentRun, AgentRunResult, AgentRunResultEvent from .settings import ModelSettings from .tools import DeferredToolRequests, DeferredToolResults, RunContext, Tool, 
ToolApproved, ToolDefinition, ToolDenied @@ -231,6 +243,17 @@ 'PromptedOutput', 'TextOutput', 'StructuredDict', + # prompt_config + 'PromptConfig', + 'PromptTemplates', + 'ToolConfig', + 'DEFAULT_FINAL_RESULT_PROCESSED', + 'DEFAULT_FUNCTION_TOOL_NOT_EXECUTED', + 'DEFAULT_MODEL_RETRY', + 'DEFAULT_OUTPUT_TOOL_NOT_EXECUTED', + 'DEFAULT_OUTPUT_VALIDATION_FAILED', + 'DEFAULT_PROMPTED_OUTPUT_TEMPLATE', + 'DEFAULT_TOOL_CALL_DENIED', # format_prompt 'format_as_xml', # settings diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 92c45a0c52..64670314a8 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -25,7 +25,16 @@ from pydantic_graph.beta import Graph, GraphBuilder from pydantic_graph.nodes import End, NodeRunEndT -from . import _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage +from . import ( + _output, + _system_prompt, + exceptions, + messages as _messages, + models, + prompt_config as _prompt_config, + result, + usage as _usage, +) from .exceptions import ToolRetryError from .output import OutputDataT, OutputSpec from .settings import ModelSettings @@ -133,6 +142,7 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]): model: models.Model model_settings: ModelSettings | None + prompt_config: _prompt_config.PromptConfig | None = None usage_limits: _usage.UsageLimits max_result_retries: int end_strategy: EndStrategy @@ -379,9 +389,18 @@ async def _prepare_request_parameters( """Build tools and create an agent model.""" output_schema = ctx.deps.output_schema - prompted_output_template = ( - output_schema.template if isinstance(output_schema, _output.PromptedOutputSchema) else None - ) + # Get the prompted output template with precedence: + # PromptConfig template > PromptedOutput.template > model profile default (handled downstream) + prompted_output_template: str | None = None + if isinstance(output_schema, _output.PromptedOutputSchema): + if ( + (prompt_config := ctx.deps.prompt_config) + and (prompt_templates := prompt_config.templates) + and (template := prompt_templates.prompted_output_template) + ): + prompted_output_template = template + else: + prompted_output_template = output_schema.template function_tools: list[ToolDefinition] = [] output_tools: list[ToolDefinition] = [] @@ -504,6 +523,13 @@ async def _prepare_request( # Update the new message index to ensure `result.new_messages()` returns the correct messages ctx.deps.new_message_index -= len(original_history) - len(message_history) + prompt_config = ctx.deps.prompt_config + + if prompt_config and (templates := prompt_config.templates): + message_history = templates.apply_template_message_history(message_history, run_context) + + ctx.state.message_history[:] = message_history + # Merge possible consecutive trailing `ModelRequest`s into one, with tool call parts before user parts, # but don't store it in the message history on state. This is just for the benefit of model classes that want clear user/assistant boundaries. 
# See `tests/test_tools.py::test_parallel_tool_return_with_deferred` for an example where this is necessary @@ -780,7 +806,14 @@ def _handle_final_result( # For backwards compatibility, append a new ModelRequest using the tool returns and retries if tool_responses: - messages.append(_messages.ModelRequest(parts=tool_responses, run_id=ctx.state.run_id)) + # Only apply templates if explicitly configured + message = _messages.ModelRequest(parts=tool_responses, run_id=ctx.state.run_id) + + if (prompt_config := ctx.deps.prompt_config) and (prompt_templates := prompt_config.templates): + run_ctx = build_run_context(ctx) + message = prompt_templates.apply_template_message_history([message], run_ctx)[0] + + messages.append(message) return End(final_result) @@ -865,8 +898,9 @@ async def process_tool_calls( # noqa: C901 if final_result and final_result.tool_call_id == call.tool_call_id: part = _messages.ToolReturnPart( tool_name=call.tool_name, - content='Final result processed.', + content=_prompt_config.DEFAULT_FINAL_RESULT_PROCESSED, tool_call_id=call.tool_call_id, + return_kind='final-result-processed', ) output_parts.append(part) # Early strategy is chosen and final result is already set @@ -874,8 +908,9 @@ async def process_tool_calls( # noqa: C901 yield _messages.FunctionToolCallEvent(call) part = _messages.ToolReturnPart( tool_name=call.tool_name, - content='Output tool not used - a final result was already processed.', + content=_prompt_config.DEFAULT_OUTPUT_TOOL_NOT_EXECUTED, tool_call_id=call.tool_call_id, + return_kind='output-tool-not-executed', ) yield _messages.FunctionToolResultEvent(part) output_parts.append(part) @@ -892,8 +927,9 @@ async def process_tool_calls( # noqa: C901 yield _messages.FunctionToolCallEvent(call) part = _messages.ToolReturnPart( tool_name=call.tool_name, - content='Output tool not used - output failed validation.', + content=_prompt_config.DEFAULT_OUTPUT_VALIDATION_FAILED, tool_call_id=call.tool_call_id, + return_kind='output-validation-failed', ) output_parts.append(part) yield _messages.FunctionToolResultEvent(part) @@ -916,8 +952,9 @@ async def process_tool_calls( # noqa: C901 else: part = _messages.ToolReturnPart( tool_name=call.tool_name, - content='Final result processed.', + content=_prompt_config.DEFAULT_FINAL_RESULT_PROCESSED, tool_call_id=call.tool_call_id, + return_kind='final-result-processed', ) output_parts.append(part) @@ -932,8 +969,9 @@ async def process_tool_calls( # noqa: C901 output_parts.append( _messages.ToolReturnPart( tool_name=call.tool_name, - content='Tool not executed - a final result was already processed.', + content=_prompt_config.DEFAULT_FUNCTION_TOOL_NOT_EXECUTED, tool_call_id=call.tool_call_id, + return_kind='function-tool-not-executed', ) ) else: @@ -990,8 +1028,9 @@ async def process_tool_calls( # noqa: C901 output_parts.append( _messages.ToolReturnPart( tool_name=call.tool_name, - content='Tool not executed - a final result was already processed.', + content=_prompt_config.DEFAULT_FUNCTION_TOOL_NOT_EXECUTED, tool_call_id=call.tool_call_id, + return_kind='function-tool-not-executed', ) ) elif calls: @@ -1148,6 +1187,7 @@ async def _call_tool( tool_name=tool_call.tool_name, content=tool_call_result.message, tool_call_id=tool_call.tool_call_id, + return_kind='tool-denied', ), None elif isinstance(tool_call_result, exceptions.ModelRetry): m = _messages.RetryPromptPart( @@ -1210,6 +1250,7 @@ async def _call_tool( tool_call_id=tool_call.tool_call_id, content=tool_return.return_value, # type: ignore metadata=tool_return.metadata, 
+ return_kind='tool-executed', ) return return_part, tool_return.content or None diff --git a/pydantic_ai_slim/pydantic_ai/agent/__init__.py b/pydantic_ai_slim/pydantic_ai/agent/__init__.py index 0bc1418470..71d7968940 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/agent/__init__.py @@ -24,6 +24,7 @@ exceptions, messages as _messages, models, + prompt_config as _prompt_config, usage as _usage, ) from .._agent_graph import ( @@ -63,7 +64,7 @@ ) from ..toolsets.combined import CombinedToolset from ..toolsets.function import FunctionToolset -from ..toolsets.prepared import PreparedToolset +from ..toolsets.prepared import PreparedToolset, ToolConfigPreparedToolset from .abstract import AbstractAgent, EventStreamHandler, Instructions, RunOutputDataT from .wrapper import WrapperAgent @@ -133,6 +134,9 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]): be merged with this value, with the runtime argument taking priority. """ + prompt_config: _prompt_config.PromptConfig | None + """Optional prompt configuration used to customize the system-injected messages for this agent.""" + _output_type: OutputSpec[OutputDataT] instrument: InstrumentationSettings | bool | None @@ -176,6 +180,7 @@ def __init__( deps_type: type[AgentDepsT] = NoneType, name: str | None = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, retries: int = 1, validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None, output_retries: int | None = None, @@ -230,6 +235,7 @@ def __init__( deps_type: type[AgentDepsT] = NoneType, name: str | None = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, retries: int = 1, validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None, output_retries: int | None = None, @@ -264,6 +270,8 @@ def __init__( name: The name of the agent, used for logging. If `None`, we try to infer the agent name from the call frame when the agent is first run. model_settings: Optional model request settings to use for this agent's runs, by default. + prompt_config: Optional prompt configuration to customize how system-injected messages + (like retry prompts or tool return wrappers) are rendered for this agent. retries: The default number of retries to allow for tool calls and output validation, before raising an error. For model request retries, see the [HTTP Request Retries](../retries.md) documentation. validation_context: Pydantic [validation context](https://docs.pydantic.dev/latest/concepts/validators/#validation-context) used to validate tool arguments and outputs. 
@@ -310,6 +318,7 @@ def __init__( self._name = name self.end_strategy = end_strategy self.model_settings = model_settings + self.prompt_config = prompt_config self._output_type = output_type self.instrument = instrument @@ -376,6 +385,9 @@ def __init__( self._override_instructions: ContextVar[ _utils.Option[list[str | _system_prompt.SystemPromptFunc[AgentDepsT]]] ] = ContextVar('_override_instructions', default=None) + self._override_prompt_config: ContextVar[_utils.Option[_prompt_config.PromptConfig]] = ContextVar( + '_override_prompt_config', default=None + ) self._enter_lock = Lock() self._entered_count = 0 @@ -443,6 +455,7 @@ def iter( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -462,6 +475,7 @@ def iter( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -481,6 +495,7 @@ async def iter( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -557,6 +572,8 @@ async def main(): instructions: Optional additional instructions to use for this run. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are + phrased for this specific run, falling back to the agent's defaults if omitted. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. @@ -581,6 +598,7 @@ async def main(): # may change the result type from the restricted type to something else. Therefore, we consider the following # typecast reasonable, even though it is possible to violate it with otherwise-type-checked code. 
output_validators = self._output_validators + prompt_config = self._get_prompt_config(prompt_config) output_toolset = self._output_toolset if output_schema != self._output_schema or output_validators: @@ -588,7 +606,9 @@ async def main(): if output_toolset: output_toolset.max_retries = self._max_result_retries output_toolset.output_validators = output_validators - toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets) + toolset = self._get_toolset( + output_toolset=output_toolset, additional_toolsets=toolsets, prompt_config=prompt_config + ) tool_manager = ToolManager[AgentDepsT](toolset) # Build the graph @@ -634,6 +654,7 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: new_message_index=len(message_history) if message_history else 0, model=model_used, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, max_result_retries=self._max_result_retries, end_strategy=self.end_strategy, @@ -768,6 +789,7 @@ def override( toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET, instructions: Instructions[AgentDepsT] | _utils.Unset = _utils.UNSET, + prompt_config: _prompt_config.PromptConfig | _utils.Unset = _utils.UNSET, ) -> Iterator[None]: """Context manager to temporarily override agent name, dependencies, model, toolsets, tools, or instructions. @@ -781,6 +803,7 @@ def override( toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run. tools: The tools to use instead of the tools registered with the agent. instructions: The instructions to use instead of the instructions registered with the agent. + prompt_config: The prompt config to use instead of the prompt config passed to the agent constructor and agent run. """ if _utils.is_set(name): name_token = self._override_name.set(_utils.Some(name)) @@ -813,6 +836,11 @@ def override( else: instructions_token = None + if _utils.is_set(prompt_config): + prompt_config_token = self._override_prompt_config.set(_utils.Some(prompt_config)) + else: + prompt_config_token = None + try: yield finally: @@ -828,6 +856,8 @@ def override( self._override_tools.reset(tools_token) if instructions_token is not None: self._override_instructions.reset(instructions_token) + if prompt_config_token is not None: + self._override_prompt_config.reset(prompt_config_token) @overload def instructions( @@ -1352,6 +1382,20 @@ def _get_deps(self: Agent[T, OutputDataT], deps: T) -> T: else: return deps + def _get_prompt_config( + self, prompt_config: _prompt_config.PromptConfig | None + ) -> _prompt_config.PromptConfig | None: + """Get prompt_config for a run. + + If we've overridden prompt_config via `_override_prompt_config`, use that, + otherwise use the prompt_config passed to the call, falling back to the agent default. + Returns None if no prompt_config is configured at any level. 
+ """ + if some_prompt_config := self._override_prompt_config.get(): + return some_prompt_config.value + else: + return prompt_config or self.prompt_config + def _normalize_instructions( self, instructions: Instructions[AgentDepsT], @@ -1390,12 +1434,14 @@ def _get_toolset( self, output_toolset: AbstractToolset[AgentDepsT] | None | _utils.Unset = _utils.UNSET, additional_toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, ) -> AbstractToolset[AgentDepsT]: """Get the complete toolset. Args: output_toolset: The output toolset to use instead of the one built at agent construction time. additional_toolsets: Additional toolsets to add, unless toolsets have been overridden. + prompt_config: The prompt config to use for tool descriptions. If None, uses agent-level or default. """ toolsets = self.toolsets # Don't add additional toolsets if the toolsets have been overridden @@ -1413,13 +1459,28 @@ def copy_dynamic_toolsets(toolset: AbstractToolset[AgentDepsT]) -> AbstractTools toolset = toolset.visit_and_replace(copy_dynamic_toolsets) + # Resolve tool_config from the prompt_config precedence chain: + # 1. Context override (agent.override(prompt_config=...)) + # 2. Per-call parameter (agent.run(..., prompt_config=...)) + # 3. Agent-level default (Agent(..., prompt_config=...)) + tool_config = ( + effective_prompt_config.tool_config + if (effective_prompt_config := self._get_prompt_config(prompt_config)) + else None + ) + if self._prepare_tools: toolset = PreparedToolset(toolset, self._prepare_tools) + if tool_config: + toolset = ToolConfigPreparedToolset(toolset, tool_config) + output_toolset = output_toolset if _utils.is_set(output_toolset) else self._output_toolset if output_toolset is not None: if self._prepare_output_tools: output_toolset = PreparedToolset(output_toolset, self._prepare_output_tools) + if tool_config: + output_toolset = ToolConfigPreparedToolset(output_toolset, tool_config) toolset = CombinedToolset([output_toolset, toolset]) return toolset diff --git a/pydantic_ai_slim/pydantic_ai/agent/abstract.py b/pydantic_ai_slim/pydantic_ai/agent/abstract.py index cc99f80e74..dfae29f20d 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/abstract.py +++ b/pydantic_ai_slim/pydantic_ai/agent/abstract.py @@ -21,6 +21,7 @@ exceptions, messages as _messages, models, + prompt_config as _prompt_config, result, usage as _usage, ) @@ -160,6 +161,7 @@ async def run( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -180,6 +182,7 @@ async def run( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -199,6 +202,7 @@ async def run( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -233,6 +237,8 @@ async def main(): instructions: Optional additional instructions to use for this run. deps: Optional dependencies to use for this run. 
model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for + this specific run, falling back to the agent defaults if omitted. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. @@ -257,6 +263,7 @@ async def main(): instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, toolsets=toolsets, @@ -284,6 +291,7 @@ def run_sync( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -304,6 +312,7 @@ def run_sync( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -323,6 +332,7 @@ def run_sync( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -356,6 +366,8 @@ def run_sync( instructions: Optional additional instructions to use for this run. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for + this specific run, falling back to the agent defaults if omitted. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. 
@@ -379,6 +391,7 @@ def run_sync( instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=False, @@ -400,6 +413,7 @@ def run_stream( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -420,6 +434,7 @@ def run_stream( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -440,6 +455,7 @@ async def run_stream( # noqa: C901 instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -481,6 +497,8 @@ async def main(): instructions: Optional additional instructions to use for this run. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for + this specific run, falling back to the agent defaults if omitted. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. @@ -510,6 +528,7 @@ async def main(): deps=deps, instructions=instructions, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=False, @@ -632,6 +651,7 @@ def run_stream_sync( model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -651,6 +671,7 @@ def run_stream_sync( model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -669,6 +690,7 @@ def run_stream_sync( model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -712,6 +734,8 @@ def main(): model: Optional model to use for this run, required if `model` was not set when creating the agent. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for + this specific run, falling back to the agent defaults if omitted. usage_limits: Optional limits on model request count or token usage. 
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. @@ -736,6 +760,7 @@ async def _consume_stream(): model=model, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=infer_name, @@ -760,6 +785,7 @@ def run_stream_events( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -779,6 +805,7 @@ def run_stream_events( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -797,6 +824,7 @@ def run_stream_events( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -847,6 +875,8 @@ async def main(): instructions: Optional additional instructions to use for this run. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for + this specific run, falling back to the agent defaults if omitted. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. 
@@ -872,6 +902,7 @@ async def main(): instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, toolsets=toolsets, @@ -889,6 +920,7 @@ async def _run_stream_events( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, @@ -915,6 +947,7 @@ async def run_agent() -> AgentRunResult[Any]: instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=False, @@ -944,6 +977,7 @@ def iter( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -963,6 +997,7 @@ def iter( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -983,6 +1018,7 @@ async def iter( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -1059,6 +1095,8 @@ async def main(): instructions: Optional additional instructions to use for this run. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for + this specific run, falling back to the agent defaults if omitted. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. @@ -1082,6 +1120,7 @@ def override( toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET, instructions: Instructions[AgentDepsT] | _utils.Unset = _utils.UNSET, + prompt_config: _prompt_config.PromptConfig | _utils.Unset = _utils.UNSET, ) -> Iterator[None]: """Context manager to temporarily override agent name, dependencies, model, toolsets, tools, or instructions. @@ -1095,6 +1134,7 @@ def override( toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run. tools: The tools to use instead of the tools registered with the agent. instructions: The instructions to use instead of the instructions registered with the agent. + prompt_config: The prompt configuration to use instead of the prompt config registered with the agent. 
""" raise NotImplementedError yield diff --git a/pydantic_ai_slim/pydantic_ai/agent/wrapper.py b/pydantic_ai_slim/pydantic_ai/agent/wrapper.py index f363b5d990..7dc0d63e92 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/wrapper.py +++ b/pydantic_ai_slim/pydantic_ai/agent/wrapper.py @@ -8,6 +8,7 @@ _utils, messages as _messages, models, + prompt_config as _prompt_config, usage as _usage, ) from .._json_schema import JsonSchema @@ -84,6 +85,7 @@ def iter( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -103,6 +105,7 @@ def iter( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -122,6 +125,7 @@ async def iter( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -198,6 +202,8 @@ async def main(): instructions: Optional additional instructions to use for this run. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for + this specific run, falling back to the agent's defaults if omitted. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. @@ -216,6 +222,7 @@ async def main(): instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=infer_name, @@ -234,6 +241,7 @@ def override( toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET, instructions: Instructions[AgentDepsT] | _utils.Unset = _utils.UNSET, + prompt_config: _prompt_config.PromptConfig | _utils.Unset = _utils.UNSET, ) -> Iterator[None]: """Context manager to temporarily override agent name, dependencies, model, toolsets, tools, or instructions. @@ -247,6 +255,7 @@ def override( toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run. tools: The tools to use instead of the tools registered with the agent. instructions: The instructions to use instead of the instructions registered with the agent. + prompt_config: The prompt configuration to use instead of the prompt config passed to the agent constructor and agent run. 
""" with self.wrapped.override( name=name, @@ -255,5 +264,6 @@ def override( toolsets=toolsets, tools=tools, instructions=instructions, + prompt_config=prompt_config, ): yield diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py b/pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py index c5adf5221d..35b4147de6 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py @@ -13,6 +13,7 @@ _utils, messages as _messages, models, + prompt_config as _prompt_config, usage as _usage, ) from pydantic_ai.agent import AbstractAgent, AgentRun, AgentRunResult, EventStreamHandler, WrapperAgent @@ -135,6 +136,7 @@ async def wrapped_run_workflow( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -153,6 +155,7 @@ async def wrapped_run_workflow( instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=infer_name, @@ -174,8 +177,9 @@ def wrapped_run_sync_workflow( deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT, - model_settings: ModelSettings | None = None, instructions: Instructions[AgentDepsT] = None, + model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -195,6 +199,7 @@ def wrapped_run_sync_workflow( deps=deps, model_settings=model_settings, usage_limits=usage_limits, + prompt_config=prompt_config, usage=usage, infer_name=infer_name, toolsets=toolsets, @@ -268,6 +273,7 @@ async def run( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -288,6 +294,7 @@ async def run( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -307,6 +314,7 @@ async def run( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -343,6 +351,7 @@ async def main(): deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. usage_limits: Optional limits on model request count or token usage. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional additional toolsets for this run. 
@@ -365,6 +374,7 @@ async def main(): instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=infer_name, @@ -386,6 +396,7 @@ def run_sync( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -406,6 +417,7 @@ def run_sync( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -425,6 +437,7 @@ def run_sync( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -459,6 +472,7 @@ def run_sync( instructions: Optional additional instructions to use for this run. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for this run. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. @@ -482,6 +496,7 @@ def run_sync( instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=infer_name, @@ -503,6 +518,7 @@ def run_stream( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -520,9 +536,10 @@ def run_stream( message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, - deps: AgentDepsT = None, instructions: Instructions[AgentDepsT] = None, + deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -622,6 +639,7 @@ def run_stream_events( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -641,6 +659,7 @@ def run_stream_events( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -659,6 +678,7 @@ def run_stream_events( instructions: Instructions[AgentDepsT] =
None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -709,6 +729,7 @@ async def main(): instructions: Optional additional instructions to use for this run. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for this run. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. @@ -736,6 +757,7 @@ def iter( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -756,6 +778,7 @@ def iter( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -776,6 +799,7 @@ async def iter( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -853,6 +877,7 @@ async def main(): instructions: Optional additional instructions to use for this run. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for this run. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. @@ -877,6 +902,7 @@ async def main(): instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=infer_name, @@ -896,6 +922,7 @@ def override( toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET, instructions: Instructions[AgentDepsT] | _utils.Unset = _utils.UNSET, + prompt_config: _prompt_config.PromptConfig | _utils.Unset = _utils.UNSET, ) -> Iterator[None]: """Context manager to temporarily override agent name, dependencies, model, toolsets, tools, or instructions. @@ -909,6 +936,7 @@ def override( toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run. tools: The tools to use instead of the tools registered with the agent. instructions: The instructions to use instead of the instructions registered with the agent. + prompt_config: The prompt configuration to use instead of the prompt config registered with the agent.
""" if _utils.is_set(model) and not isinstance(model, (DBOSModel)): raise UserError( @@ -922,5 +950,6 @@ def override( toolsets=toolsets, tools=tools, instructions=instructions, + prompt_config=prompt_config, ): yield diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/prefect/_agent.py b/pydantic_ai_slim/pydantic_ai/durable_exec/prefect/_agent.py index 60c8122686..711d4ae247 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/prefect/_agent.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/prefect/_agent.py @@ -16,6 +16,7 @@ _utils, messages as _messages, models, + prompt_config as _prompt_config, usage as _usage, ) from pydantic_ai.agent import AbstractAgent, AgentRun, AgentRunResult, EventStreamHandler, WrapperAgent @@ -184,6 +185,7 @@ async def run( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -204,6 +206,7 @@ async def run( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -223,6 +226,7 @@ async def run( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -258,6 +262,7 @@ async def main(): instructions: Optional additional instructions to use for this run. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for this run. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. 
@@ -284,6 +289,7 @@ async def wrapped_run_flow() -> AgentRunResult[Any]: instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=infer_name, @@ -308,6 +314,7 @@ def run_sync( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -328,6 +335,7 @@ def run_sync( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -347,6 +355,7 @@ def run_sync( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -381,6 +390,7 @@ def run_sync( instructions: Optional additional instructions to use for this run. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for this run. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. @@ -409,6 +419,7 @@ def wrapped_run_sync_flow() -> AgentRunResult[Any]: instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=infer_name, @@ -434,6 +445,7 @@ def run_stream( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -454,6 +466,7 @@ def run_stream( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -474,6 +487,7 @@ async def run_stream( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -506,6 +520,7 @@ async def main(): instructions: Optional additional instructions to use for this run. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for this run. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. 
infer_name: Whether to try to infer the agent name from the call frame if it's not set. @@ -531,6 +546,7 @@ async def main(): instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=infer_name, @@ -553,6 +569,7 @@ def run_stream_events( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -572,6 +589,7 @@ def run_stream_events( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -590,6 +608,7 @@ def run_stream_events( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -640,6 +659,7 @@ async def main(): instructions: Optional additional instructions to use for this run. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for this run. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. @@ -665,6 +685,7 @@ async def main(): instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=infer_name, @@ -684,6 +705,7 @@ def iter( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -703,6 +725,7 @@ def iter( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -722,6 +745,7 @@ async def iter( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -798,6 +822,7 @@ async def main(): deps: Optional dependencies to use for this run. instructions: Optional additional instructions to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for this run. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. 
infer_name: Whether to try to infer the agent name from the call frame if it's not set. @@ -822,6 +847,7 @@ async def main(): instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=infer_name, @@ -839,6 +865,7 @@ def override( toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET, instructions: Instructions[AgentDepsT] | _utils.Unset = _utils.UNSET, + prompt_config: _prompt_config.PromptConfig | _utils.Unset = _utils.UNSET, ) -> Iterator[None]: """Context manager to temporarily override agent dependencies, model, toolsets, tools, or instructions. @@ -852,6 +879,7 @@ def override( toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run. tools: The tools to use instead of the tools registered with the agent. instructions: The instructions to use instead of the instructions registered with the agent. + prompt_config: The prompt configuration to use instead of the prompt config registered with the agent. """ if _utils.is_set(model) and not isinstance(model, PrefectModel): raise UserError( @@ -859,6 +887,12 @@ ) with super().override( - name=name, deps=deps, model=model, toolsets=toolsets, tools=tools, instructions=instructions + name=name, + deps=deps, + model=model, + toolsets=toolsets, + tools=tools, + instructions=instructions, + prompt_config=prompt_config, ): yield diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py index 0c9de7e29b..68e7b8f90c 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py @@ -21,6 +21,7 @@ _utils, messages as _messages, models, + prompt_config as _prompt_config, usage as _usage, ) from pydantic_ai.agent import AbstractAgent, AgentRun, AgentRunResult, EventStreamHandler, WrapperAgent @@ -270,6 +271,7 @@ async def run( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -290,6 +292,7 @@ async def run( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -309,6 +312,7 @@ async def run( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -344,6 +348,7 @@ async def main(): instructions: Optional additional instructions to use for this run. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for this run. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
infer_name: Whether to try to infer the agent name from the call frame if it's not set. @@ -369,6 +374,7 @@ async def main(): instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=infer_name, @@ -390,6 +396,7 @@ def run_sync( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -410,6 +417,7 @@ def run_sync( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -429,6 +437,7 @@ def run_sync( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -463,6 +472,7 @@ def run_sync( instructions: Optional additional instructions to use for this run. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for this run. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. @@ -487,6 +497,7 @@ def run_sync( instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=infer_name, @@ -508,6 +519,7 @@ def run_stream( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -528,6 +540,7 @@ def run_stream( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -548,6 +561,7 @@ async def run_stream( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -580,6 +594,7 @@ async def main(): instructions: Optional additional instructions to use for this run. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for this run. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
infer_name: Whether to try to infer the agent name from the call frame if it's not set. @@ -605,6 +620,7 @@ async def main(): instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=infer_name, @@ -627,6 +643,7 @@ def run_stream_events( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -646,6 +663,7 @@ def run_stream_events( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -664,6 +682,7 @@ def run_stream_events( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -714,6 +733,7 @@ async def main(): instructions: Optional additional instructions to use for this run. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for this run. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. @@ -739,6 +759,7 @@ async def main(): instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=infer_name, @@ -758,6 +779,7 @@ def iter( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -778,6 +800,7 @@ def iter( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -798,6 +821,7 @@ async def iter( instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, + prompt_config: _prompt_config.PromptConfig | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.RunUsage | None = None, infer_name: bool = True, @@ -875,6 +899,7 @@ async def main(): instructions: Optional additional instructions to use for this run. deps: Optional dependencies to use for this run. model_settings: Optional settings to use for this model's request. + prompt_config: Optional prompt configuration to override how system-generated parts are phrased for this run. usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
infer_name: Whether to try to infer the agent name from the call frame if it's not set. @@ -909,6 +934,7 @@ async def main(): instructions=instructions, deps=deps, model_settings=model_settings, + prompt_config=prompt_config, usage_limits=usage_limits, usage=usage, infer_name=infer_name, @@ -928,6 +954,7 @@ def override( toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET, instructions: Instructions[AgentDepsT] | _utils.Unset = _utils.UNSET, + prompt_config: _prompt_config.PromptConfig | _utils.Unset = _utils.UNSET, ) -> Iterator[None]: """Context manager to temporarily override agent name, dependencies, model, toolsets, tools, or instructions. @@ -941,6 +968,7 @@ def override( toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run. tools: The tools to use instead of the tools registered with the agent. instructions: The instructions to use instead of the instructions registered with the agent. + prompt_config: The prompt configuration to use instead of the prompt config registered with the agent. """ if workflow.in_workflow(): if _utils.is_set(model): @@ -963,5 +991,6 @@ def override( toolsets=toolsets, tools=tools, instructions=instructions, + prompt_config=prompt_config, ): yield diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 43c6a3bb4b..680f6a385b 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -873,6 +873,16 @@ def has_content(self) -> bool: __repr__ = _utils.dataclasses_no_defaults_repr +ReturnKind: TypeAlias = Literal[ + 'final-result-processed', + 'output-tool-not-executed', + 'function-tool-not-executed', + 'tool-executed', + 'tool-denied', + 'output-validation-failed', +] + + @dataclass(repr=False) class ToolReturnPart(BaseToolReturnPart): """A tool return message, this encodes the result of running a tool.""" @@ -882,6 +892,18 @@ class ToolReturnPart(BaseToolReturnPart): part_kind: Literal['tool-return'] = 'tool-return' """Part type identifier, this is available on all parts as a discriminator.""" + return_kind: ReturnKind | None = None + """How the tool call was resolved, used for disambiguating return parts. + + * `tool-executed`: the tool ran successfully and produced a return value + * `final-result-processed`: an output tool produced the run's final result + * `output-tool-not-executed`: an output tool was skipped because a final result already existed + * `function-tool-not-executed`: a function tool was skipped due to early termination after a final result + * `tool-denied`: the tool call was rejected by an approval handler + * `output-validation-failed`: the tool call was rejected by an output validator + + """ + @dataclass(repr=False) class BuiltinToolReturnPart(BaseToolReturnPart): @@ -944,20 +966,27 @@ class RetryPromptPart: part_kind: Literal['retry-prompt'] = 'retry-prompt' """Part type identifier, this is available on all parts as a discriminator.""" + retry_message: str | None = None + """Pre-rendered retry message. 
When set by PromptTemplates, model_response() returns this directly.""" + def model_response(self) -> str: """Return a string message describing why the retry is requested.""" + # If templates were applied, return the pre-rendered message + if self.retry_message is not None: + return self.retry_message + + # Default fallback when no template was applied + from .prompt_config import DEFAULT_MODEL_RETRY, default_validation_error, default_validation_feedback + if isinstance(self.content, str): if self.tool_name is None: - description = f'Validation feedback:\n{self.content}' + description = default_validation_feedback(self.content) else: description = self.content else: - json_errors = error_details_ta.dump_json(self.content, exclude={'__all__': {'ctx'}}, indent=2) - plural = isinstance(self.content, list) and len(self.content) != 1 - description = ( - f'{len(self.content)} validation error{"s" if plural else ""}:\n```json\n{json_errors.decode()}\n```' - ) - return f'{description}\n\nFix the errors and try again.' + description = default_validation_error(self.content) + + return f'{description}\n\n{DEFAULT_MODEL_RETRY}' def otel_event(self, settings: InstrumentationSettings) -> LogRecord: if self.tool_name is None: diff --git a/pydantic_ai_slim/pydantic_ai/prompt_config.py b/pydantic_ai_slim/pydantic_ai/prompt_config.py new file mode 100644 index 0000000000..84da294cd5 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/prompt_config.py @@ -0,0 +1,392 @@ +from __future__ import annotations as _annotations + +from collections.abc import Callable +from dataclasses import dataclass, replace +from textwrap import dedent +from typing import TYPE_CHECKING, Any + +import pydantic_core + +from pydantic_ai.usage import RunUsage + +from ._run_context import RunContext +from .messages import ModelMessage, ModelRequest, ModelRequestPart, RetryPromptPart, ToolReturnPart + +if TYPE_CHECKING: + from pydantic_ai.agent import Agent + from pydantic_ai.models import Model + +# Default template strings - used when template field is None +DEFAULT_FINAL_RESULT_PROCESSED = 'Final result processed.' +"""Default confirmation message when a final result is successfully processed.""" + +DEFAULT_OUTPUT_TOOL_NOT_EXECUTED = 'Output tool not used - a final result was already processed.' +"""Default message when an output tool call is skipped because a result was already found.""" + +DEFAULT_OUTPUT_VALIDATION_FAILED = 'Output tool not used - output failed validation.' +"""Default message when an output tool fails validation.""" + +DEFAULT_FUNCTION_TOOL_NOT_EXECUTED = 'Tool not executed - a final result was already processed.' +"""Default message when a function tool call is skipped because a result was already found.""" + +DEFAULT_TOOL_CALL_DENIED = 'The tool call was denied.' +"""Default message when a tool call is denied by an approval handler.""" + +DEFAULT_MODEL_RETRY = 'Fix the errors and try again.' +"""Default message appended to retry prompts.""" + +DEFAULT_PROMPTED_OUTPUT_TEMPLATE = dedent( + """ + Always respond with a JSON object that's compatible with this schema: + + {schema} + + Don't include any text or Markdown fencing before or after. 
+ """ +) +"""Default template for prompted output schema instructions.""" + + +def default_validation_feedback(content: str | list[pydantic_core.ErrorDetails]) -> str: + """Generate a default validation feedback message.""" + assert isinstance(content, str) + return f'Validation feedback:\n{content}' + + +def default_validation_error(content: str | list[pydantic_core.ErrorDetails]) -> str: + """Generate a default validation error message from a list of Pydantic `ErrorDetails`.""" + from .messages import error_details_ta + + assert isinstance(content, list) + + json_errors = error_details_ta.dump_json(content, exclude={'__all__': {'ctx'}}, indent=2) + plural = len(content) != 1 + return f'{len(content)} validation error{"s" if plural else ""}:\n```json\n{json_errors.decode()}\n```' + + +return_kind_to_default_prompt_template: dict[str, str] = { + 'final-result-processed': DEFAULT_FINAL_RESULT_PROCESSED, + 'output-tool-not-executed': DEFAULT_OUTPUT_TOOL_NOT_EXECUTED, + 'output-validation-failed': DEFAULT_OUTPUT_VALIDATION_FAILED, + 'function-tool-not-executed': DEFAULT_FUNCTION_TOOL_NOT_EXECUTED, + 'tool-denied': DEFAULT_TOOL_CALL_DENIED, +} + + +@dataclass +class PromptTemplates: + """Templates for customizing system-generated messages that Pydantic AI sends to models. + + Each template can be either: + - `None` to use the default message (or preserve existing content for `tool_call_denied`) + - A static string that replaces the default message + - A callable that receives the message part and [`RunContext`][pydantic_ai.RunContext] + and returns a dynamically generated string + + These templates are used within [`PromptConfig`][pydantic_ai.PromptConfig] to customize + retry prompts, tool return confirmations, validation error messages, and more. + + Example: + ```python + from pydantic_ai import Agent, PromptConfig, PromptTemplates + + # Using static strings + templates = PromptTemplates( + validation_errors_retry='Please fix the validation errors.', + final_result_processed='Done!', + ) + + # Using callable for dynamic messages + templates = PromptTemplates( + validation_errors_retry=lambda part, ctx: f'Retry #{ctx.retry}: Fix the errors.', + ) + + agent = Agent('openai:gpt-4o', prompt_config=PromptConfig(templates=templates)) + ``` + """ + + final_result_processed: str | Callable[[ToolReturnPart, RunContext[Any]], str] | None = None + """Confirmation message sent when a final result is successfully processed. + + If `None`, uses the default: 'Final result processed.' + """ + + output_tool_not_executed: str | Callable[[ToolReturnPart, RunContext[Any]], str] | None = None + """Message sent when an output tool call is skipped because a result was already found. + + If `None`, uses the default: 'Output tool not used - a final result was already processed.' + """ + + output_validation_failed: str | Callable[[ToolReturnPart, RunContext[Any]], str] | None = None + """Message sent when an output tool fails validation.""" + + function_tool_not_executed: str | Callable[[ToolReturnPart, RunContext[Any]], str] | None = None + """Message sent when a function tool call is skipped because a result was already found. + + If `None`, uses the default: 'Tool not executed - a final result was already processed.' + """ + + tool_denied: str | Callable[[ToolReturnPart, RunContext[Any]], str] | None = None + """Message sent when a tool call is denied by an approval handler. + + If `None`, preserves the custom message from `ToolDenied` (or uses the default if none was set). 
+ Set explicitly to override all denied tool messages. + """ + + validation_errors_retry: str | Callable[[RetryPromptPart, RunContext[Any]], str] | None = None + """Message appended to validation errors when asking the model to retry. + + If `None`, uses the default: 'Fix the errors and try again.' + """ + + model_retry_string_tool: str | Callable[[RetryPromptPart, RunContext[Any]], str] | None = None + """Message sent when a `ModelRetry` exception is raised from a tool. + + If `None`, uses the default: 'Fix the errors and try again.' + """ + + model_retry_string_no_tool: str | Callable[[RetryPromptPart, RunContext[Any]], str] | None = None + """Message sent when a `ModelRetry` exception is raised outside of a tool context. + + If `None`, uses the default: 'Fix the errors and try again.' + """ + + prompted_output_template: str | None = None + """Template for prompted output schema instructions. + + If `None`, uses the template from `PromptedOutput` if set, otherwise the model's + profile-specific default template is used. + Set explicitly to override the template for all prompted outputs. + """ + + description_template: Callable[[str | list[pydantic_core.ErrorDetails]], str] | None = None + """Format a description message while asking the model to retry.""" + + def apply_template(self, message_part: ModelRequestPart, ctx: RunContext[Any]) -> ModelRequestPart: + if isinstance(message_part, ToolReturnPart): + if message_part.return_kind in (None, 'tool-executed'): + return message_part + + field_name = message_part.return_kind.replace('-', '_') + template = getattr(self, field_name, None) + # Map return_kind directly to template attribute name (e.g. 'final-result-processed' -> 'final_result_processed') + + # Special case for tool-denied: only apply if template is explicitly set + if message_part.return_kind == 'tool-denied': + return self._apply_tool_template(message_part, ctx, template) if template else message_part + + if template := template or return_kind_to_default_prompt_template.get(message_part.return_kind): + return self._apply_tool_template(message_part, ctx, template) + + elif isinstance(message_part, RetryPromptPart): + return self._apply_retry_template(message_part, ctx) + + return message_part + + def apply_template_message_history(self, _messages: list[ModelMessage], ctx: RunContext[Any]) -> list[ModelMessage]: + return [ + replace( + message, + parts=[self.apply_template(part, ctx) for part in message.parts], + ) + if isinstance(message, ModelRequest) + else message + for message in _messages + ] + + def _apply_retry_template( + self, + message_part: RetryPromptPart, + ctx: RunContext[Any], + ) -> RetryPromptPart: + """Render the full retry response based on content type. + + Selects the appropriate templates and applies them in a single pass, + pre-rendering everything so model_response() can just return the result. 
+ """ + content = message_part.content + + if isinstance(content, str): + if message_part.tool_name is None: + # String without tool context (e.g., output validator raising ModelRetry) + description_template = self.description_template or default_validation_feedback + description = description_template(content) + retry_template = self.model_retry_string_no_tool or DEFAULT_MODEL_RETRY + else: + # String from a tool - use content directly + description = content + retry_template = self.model_retry_string_tool or DEFAULT_MODEL_RETRY + else: + # List of ErrorDetails (validation errors) + description_template = self.description_template or default_validation_error + description = description_template(content) + retry_template = self.validation_errors_retry or DEFAULT_MODEL_RETRY + + # Resolve callable if needed + if callable(retry_template): + retry_template = retry_template(message_part, ctx) + + return replace(message_part, retry_message=f'{description}\n\n{retry_template}') + + def _apply_tool_template( + self, + message_part: ToolReturnPart, + ctx: RunContext[Any], + template: str | Callable[[ToolReturnPart, RunContext[Any]], str], + ) -> ToolReturnPart: + content = template(message_part, ctx) if callable(template) else template + return replace(message_part, content=content) + + +@dataclass +class ToolConfig: + """Configuration for customizing tool descriptions and argument descriptions at runtime. + + This allows you to override tool metadata without modifying the original tool definitions. + """ + + name: str | None = None + tool_description: str | None = None + strict: bool | None = None + tool_args_descriptions: dict[str, str] | None = None + + +@dataclass +class PromptConfig: + """Configuration for customizing all strings and prompts sent to the model by Pydantic AI. + + `PromptConfig` provides a clean, extensible interface for overriding any text that + Pydantic AI sends to the model. This includes: + + - **Prompt Templates**: Messages for retry prompts, tool return confirmations, + validation errors, and other system-generated text via [`PromptTemplates`][pydantic_ai.PromptTemplates]. + - **Tool Configuration**: Tool descriptions, parameter descriptions, and other + tool metadata - allowing you to override descriptions and args for tools at the agent level. + + This allows you to fully customize how your agent communicates with the model + without modifying the underlying tool or agent code. + + Example: + ```python + from pydantic_ai import Agent, PromptConfig, PromptTemplates + + agent = Agent( + 'openai:gpt-4o', + prompt_config=PromptConfig( + templates=PromptTemplates( + validation_errors_retry='Please correct the errors and try again.', + final_result_processed='Result received successfully.', + ), + ), + ) + ``` + + Attributes: + templates: Templates for customizing system-generated messages like retry prompts, + tool return confirmations, and validation error messages. + tool_config: Configuration for customizing tool descriptions and metadata. + """ + + templates: PromptTemplates | None = None + """Templates for customizing system-generated messages sent to the model. + + See [`PromptTemplates`][pydantic_ai.PromptTemplates] for available template options. + """ + + tool_config: dict[str, ToolConfig] | None = None + """Configuration for customizing tool descriptions and metadata, keyed by tool name. + See [`ToolConfig`][pydantic_ai.ToolConfig] for available configuration options. 
+ """ + + @staticmethod + async def generate_prompt_config_from_agent(agent: Agent, model: Model) -> PromptConfig: + """Generate a PromptConfig instance based on an Agent instance. + + The information we can find we will fill in, everything else can be None or defaults if any. + + """ + tool_config: dict[str, ToolConfig] = {} + + prompt_templates: PromptTemplates = PromptTemplates( + final_result_processed=DEFAULT_FINAL_RESULT_PROCESSED, + output_tool_not_executed=DEFAULT_OUTPUT_TOOL_NOT_EXECUTED, + output_validation_failed=DEFAULT_OUTPUT_VALIDATION_FAILED, + function_tool_not_executed=DEFAULT_FUNCTION_TOOL_NOT_EXECUTED, + tool_denied=DEFAULT_TOOL_CALL_DENIED, + validation_errors_retry=DEFAULT_MODEL_RETRY, + model_retry_string_tool=DEFAULT_MODEL_RETRY, + model_retry_string_no_tool=DEFAULT_MODEL_RETRY, + prompted_output_template=DEFAULT_PROMPTED_OUTPUT_TEMPLATE, + description_template=None, # No default description template should be picked as per the constraints of the model + ) + + run_ctx = RunContext(deps=None, model=model, usage=RunUsage()) + + # Include both regular and output tools + from .toolsets import CombinedToolset + + all_toolsets = [*agent.toolsets] + if output_toolset := getattr(agent, '_output_toolset', None): + all_toolsets.append(output_toolset) + + toolset = CombinedToolset(all_toolsets) + tools = await toolset.get_tools(run_ctx) + + for tool_name, toolset_tool in tools.items(): + tool_def = toolset_tool.tool_def + tool_config[tool_name] = ToolConfig( + name=tool_name, + tool_description=tool_def.description, + strict=tool_def.strict, + tool_args_descriptions=_extract_descriptions_from_json_schema(tool_def.parameters_json_schema), + ) + + return PromptConfig( + tool_config=tool_config, + templates=prompt_templates, + ) + + +# JSON Schema keys +_PROPERTIES = 'properties' +_DEFS = '$defs' +_REF = '$ref' +_REF_PREFIX = '#/$defs/' +_DESCRIPTION = 'description' + + +def _extract_descriptions_from_json_schema(parameters_json_schema: dict[str, Any]) -> dict[str, str]: + """Extract field descriptions from a JSON schema into dot notation format. + + Recursively traverses the schema's properties to build a flat dictionary mapping + dot-notation paths to their descriptions. This is useful for prompt optimizers + that need to modify tool argument descriptions. 
+ """ + properties = parameters_json_schema.get(_PROPERTIES, {}) + if not properties: + return {} + + result: dict[str, str] = {} + defs = parameters_json_schema.get(_DEFS, {}) + visited: set[str] = set() + + def extract_from_properties(path: str, props: dict[str, Any]) -> None: + """Recursively extract descriptions from properties.""" + for key, value in props.items(): + full_path = f'{path}.{key}' if path else key + + if description := value.get(_DESCRIPTION): + result[full_path] = description + + if nested_props := value.get(_PROPERTIES): + extract_from_properties(full_path, nested_props) + elif (ref := value.get(_REF)) and ref.startswith(_REF_PREFIX): + def_name = ref[len(_REF_PREFIX) :] + if def_name not in visited: + visited.add(def_name) + if nested_props := defs.get(def_name, {}).get(_PROPERTIES): + extract_from_properties(full_path, nested_props) + visited.remove(def_name) + + extract_from_properties('', properties) + return result diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 900278ce44..b23b7cb06b 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -14,6 +14,7 @@ from .builtin_tools import AbstractBuiltinTool from .exceptions import ModelRetry from .messages import RetryPromptPart, ToolCallPart, ToolReturn +from .prompt_config import DEFAULT_TOOL_CALL_DENIED __all__ = ( 'AgentDepsT', @@ -176,7 +177,7 @@ class ToolApproved: class ToolDenied: """Indicates that a tool call has been denied and that a denial message should be returned to the model.""" - message: str = 'The tool call was denied.' + message: str = DEFAULT_TOOL_CALL_DENIED """The message to return to the model.""" _: KW_ONLY diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/prepared.py b/pydantic_ai_slim/pydantic_ai/toolsets/prepared.py index af604d4328..02776d5cd3 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/prepared.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/prepared.py @@ -1,6 +1,10 @@ from __future__ import annotations +import copy from dataclasses import dataclass, replace +from typing import Any + +from pydantic_ai.prompt_config import ToolConfig from .._run_context import AgentDepsT, RunContext from ..exceptions import UserError @@ -34,3 +38,115 @@ async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[ name: replace(original_tools[name], tool_def=tool_def) for name, tool_def in prepared_tool_defs_by_name.items() } + + +@dataclass +class ToolConfigPreparedToolset(WrapperToolset[AgentDepsT]): + """A toolset that prepares the tools it contains using a ToolConfig.""" + + tool_config: dict[str, ToolConfig] + + async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: + original_tools = await super().get_tools(ctx) + + # Start with a shallow copy to avoid mutating the parent's dict + result_tools = dict(original_tools) + + # Iterate tool_config - skip tools that don't exist in this toolset + # (tool_config may be shared across multiple toolsets, e.g. 
function tools + output tools) + for tool_name, config in self.tool_config.items(): + if tool_name not in original_tools: + continue + + tool = original_tools[tool_name] + original_tool_def = tool.tool_def + parameters_json_schema = copy.deepcopy(original_tool_def.parameters_json_schema) + + if config.tool_args_descriptions: + self._update_arg_descriptions(parameters_json_schema, config.tool_args_descriptions, tool_name) + + updated_tool_def = replace( + original_tool_def, + parameters_json_schema=parameters_json_schema, + **{ + k: v + for k, v in { + 'name': config.name, + 'description': config.tool_description, + 'strict': config.strict, + }.items() + if v is not None + }, + ) + + updated_tool = replace(tool, tool_def=updated_tool_def) + + # Handle renaming: remove old key if renamed, then add with final name + final_tool_name = config.name if config.name is not None else tool_name + if final_tool_name != tool_name: + del result_tools[tool_name] + result_tools[final_tool_name] = updated_tool + + return result_tools + + def _update_arg_descriptions( + self, + schema: dict[str, Any], + arg_descriptions: dict[str, str], + tool_name: str, + ) -> None: + """Update descriptions for argument paths in the JSON schema (modifies schema in place).""" + defs = schema.get(_DEFS, {}) + + for arg_path, description in arg_descriptions.items(): + current = schema + parts = arg_path.split('.') + + for i, part in enumerate(parts): + # Resolve $ref if present. + # We inline the definition to avoid modifying shared definitions in $defs + # and to handle chained references (A -> B -> C). + visited_refs: set[str] = set() + while (ref := current.get(_REF)) and isinstance(ref, str) and ref.startswith(_REF_PREFIX): + if ref in visited_refs: + raise UserError(f"Circular reference detected in schema at '{arg_path}': {ref}") + visited_refs.add(ref) + + def_name = ref[len(_REF_PREFIX) :] + if def_name not in defs: + raise UserError(f"Invalid path '{arg_path}' for tool '{tool_name}': undefined $ref '{ref}'.") + + # Inline the definition: replace 'current' contents with the definition's contents + # This ensures we don't mutate the shared definition in $defs + # "sender": { "$ref": "#/$defs/User" } + # "receiver": { "$ref": "#/$defs/User" } + # We write to the 'sender' key itself and not the $defs key because we don't want to mutate the shared definition which is also used by receiver. + # "sender": { + # "description": "I CHANGED THIS!", + # # ... copy of other User fields ... + # } + + target_def = defs[def_name] + current.pop(_REF) + current.update(copy.deepcopy(target_def)) + + props = current.get(_PROPERTIES, {}) + if part not in props: + available = ', '.join(f"'{p}'" for p in props) + raise UserError( + f"Invalid path '{arg_path}' for tool '{tool_name}': " + f"'{part}' not found. 
Available properties: {available}" + ) + + if i == len(parts) - 1: + props[part][_DESCRIPTION] = description + else: + current = props[part] + + +# JSON Schema keys used for traversal +_PROPERTIES = 'properties' +_DEFS = '$defs' +_REF = '$ref' +_REF_PREFIX = '#/$defs/' +_DESCRIPTION = 'description' diff --git a/pydantic_ai_slim/pydantic_ai/ui/_event_stream.py b/pydantic_ai_slim/pydantic_ai/ui/_event_stream.py index 391cf06f2f..91a2980ef0 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/_event_stream.py +++ b/pydantic_ai_slim/pydantic_ai/ui/_event_stream.py @@ -170,6 +170,7 @@ async def transform_stream( # noqa: C901 tool_call_id=tool_call_id, tool_name=tool_name, content='Final result processed.', + return_kind='final-result-processed', ) ) async for e in self.handle_function_tool_result(output_tool_result_event): diff --git a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py index fa82b9255b..0e10951769 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py +++ b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py @@ -203,7 +203,11 @@ def load_messages(cls, messages: Sequence[UIMessage]) -> list[ModelMessage]: # if part.state == 'output-available': builder.add( - ToolReturnPart(tool_name=tool_name, tool_call_id=tool_call_id, content=part.output) + ToolReturnPart( + tool_name=tool_name, + tool_call_id=tool_call_id, + content=part.output, + ) ) elif part.state == 'output-error': builder.add( diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index c7be2c3ce2..f5928f4829 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -1084,6 +1084,7 @@ async def test_request_structured_response(allow_model_requests: None): content='Final result processed.', tool_call_id='123', timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -1184,6 +1185,7 @@ async def get_location(loc_name: str) -> str: content='{"lat": 51, "lng": 0}', tool_call_id='2', timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -1301,6 +1303,7 @@ async def retrieve_entity_info(name: str) -> str: content="alice is bob's wife", tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', part_kind='tool-return', ), ToolReturnPart( @@ -1308,6 +1311,7 @@ async def retrieve_entity_info(name: str) -> str: content="bob is alice's husband", tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', part_kind='tool-return', ), ToolReturnPart( @@ -1315,6 +1319,7 @@ async def retrieve_entity_info(name: str) -> str: content="charlie is alice's son", tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', part_kind='tool-return', ), ToolReturnPart( @@ -1322,6 +1327,7 @@ async def retrieve_entity_info(name: str) -> str: content="daisy is bob's daughter and charlie's younger sister", tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', part_kind='tool-return', ), ] @@ -1735,6 +1741,7 @@ async def get_image() -> BinaryContent: content='See file 1c8566', tool_call_id='toolu_01WALUz3dC75yywrmL6dF3Bc', timestamp=IsDatetime(), + return_kind='tool-executed', ), UserPromptPart( content=[ @@ -6587,6 +6594,7 @@ async def get_user_country() -> str: content='Mexico', tool_call_id='toolu_01X9wcHKKAZD9tBC711xipPa', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -6625,6 +6633,7 @@ async def get_user_country() -> str: content='Final result processed.', 
tool_call_id='toolu_01LZABsgreMefH2Go8D5PQbW', timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -6696,6 +6705,7 @@ async def get_user_country() -> str: content='Mexico', tool_call_id='toolu_01JJ8TequDsrEU2pv1QFRWAK', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -6789,6 +6799,7 @@ async def get_user_country() -> str: content='Mexico', tool_call_id='toolu_01ArHq5f2wxRpRF2PVQcKExM', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), diff --git a/tests/models/test_bedrock.py b/tests/models/test_bedrock.py index 9647a0eb7c..205f1a43b5 100644 --- a/tests/models/test_bedrock.py +++ b/tests/models/test_bedrock.py @@ -337,6 +337,7 @@ async def temperature(city: str, date: datetime.date) -> str: content='30°C', tool_call_id='tooluse_5WEci1UmQ8ifMFkUcy2gHQ', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -368,6 +369,7 @@ async def temperature(city: str, date: datetime.date) -> str: content='Final result processed.', tool_call_id='tooluse_9AjloJSaQDKmpPFff-2Clg', timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -636,6 +638,7 @@ async def get_temperature(city: str) -> str: content='30°C', tool_call_id='tooluse_lAG_zP8QRHmSYOwZzzaCqA', timestamp=IsDatetime(), + return_kind='tool-executed', ) ), PartStartEvent(index=0, part=TextPart(content='The')), diff --git a/tests/models/test_cohere.py b/tests/models/test_cohere.py index 2a574870d1..60273d0ad9 100644 --- a/tests/models/test_cohere.py +++ b/tests/models/test_cohere.py @@ -237,6 +237,7 @@ async def test_request_structured_response(allow_model_requests: None): content='Final result processed.', tool_call_id='123', timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -358,6 +359,7 @@ async def get_location(loc_name: str) -> str: content='{"lat": 51, "lng": 0}', tool_call_id='2', timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index fe6fa8399a..9eeadfa69c 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -642,6 +642,7 @@ async def test_request_structured_response(get_gemini_client: GetGeminiClient): content='Final result processed.', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr(), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -732,12 +733,14 @@ async def get_location(loc_name: str) -> str: content='{"lat": 51, "lng": 0}', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr(), + return_kind='tool-executed', ), ToolReturnPart( tool_name='get_location', content='{"lat": 41, "lng": -74}', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr(), + return_kind='tool-executed', ), ], run_id=IsStr(), @@ -916,10 +919,18 @@ async def bar(y: str) -> str: ModelRequest( parts=[ ToolReturnPart( - tool_name='foo', content='a', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr() + tool_name='foo', + content='a', + timestamp=IsNow(tz=timezone.utc), + tool_call_id=IsStr(), + return_kind='tool-executed', ), ToolReturnPart( - tool_name='bar', content='b', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr() + tool_name='bar', + content='b', + timestamp=IsNow(tz=timezone.utc), + tool_call_id=IsStr(), + return_kind='tool-executed', ), ], run_id=IsStr(), @@ -940,6 +951,7 @@ async def bar(y: str) -> str: content='Final result processed.', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr(), + 
return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -1013,6 +1025,7 @@ def get_location(loc_name: str) -> str: content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='function-tool-not-executed', ) ], run_id=IsStr(), @@ -1222,6 +1235,7 @@ async def get_image() -> BinaryContent: content='See file 1c8566', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ), UserPromptPart( content=[ @@ -1733,6 +1747,7 @@ async def bar() -> str: content='hello', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -1762,6 +1777,7 @@ async def bar() -> str: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -1817,6 +1833,7 @@ async def get_user_country() -> str: content='Mexico', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -1846,6 +1863,7 @@ async def get_user_country() -> str: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -2130,6 +2148,7 @@ async def get_user_country() -> str: content='Mexico', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), diff --git a/tests/models/test_google.py b/tests/models/test_google.py index c82afbff34..04c7c7b660 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -232,7 +232,11 @@ async def temperature(city: str, date: datetime.date) -> str: ModelRequest( parts=[ ToolReturnPart( - tool_name='temperature', content='30°C', tool_call_id=IsStr(), timestamp=IsDatetime() + tool_name='temperature', + content='30°C', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -266,6 +270,7 @@ async def temperature(city: str, date: datetime.date) -> str: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -638,6 +643,7 @@ async def get_capital(country: str) -> str: content='Paris', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -763,7 +769,11 @@ async def get_temperature(city: str) -> str: IsInstance(FunctionToolCallEvent), FunctionToolResultEvent( result=ToolReturnPart( - tool_name='get_capital', content='Paris', tool_call_id=IsStr(), timestamp=IsDatetime() + tool_name='get_capital', + content='Paris', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + return_kind='tool-executed', ) ), PartStartEvent( @@ -781,7 +791,11 @@ async def get_temperature(city: str) -> str: IsInstance(FunctionToolCallEvent), FunctionToolResultEvent( result=ToolReturnPart( - tool_name='get_temperature', content='30°C', tool_call_id=IsStr(), timestamp=IsDatetime() + tool_name='get_temperature', + content='30°C', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + return_kind='tool-executed', ) ), PartStartEvent(index=0, part=TextPart(content='The temperature in Paris')), @@ -2474,6 +2488,7 @@ async def bar() -> str: content='hello', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -2505,6 +2520,7 @@ async def bar() -> str: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -2572,6 +2588,7 @@ async def 
get_user_country() -> str: content='Mexico', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -2603,6 +2620,7 @@ async def get_user_country() -> str: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -2660,6 +2678,7 @@ async def get_user_country() -> str: content='Mexico', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -2906,6 +2925,7 @@ async def get_user_country() -> str: content='Mexico', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -4564,6 +4584,7 @@ def get_country() -> str: content='Mexico', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -4628,6 +4649,7 @@ def get_country() -> str: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -4927,6 +4949,7 @@ def get_country() -> str: content='Mexico', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -4979,6 +5002,7 @@ def get_country() -> str: content='Mexico', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ), PartStartEvent(index=0, part=TextPart(content='The')), diff --git a/tests/models/test_groq.py b/tests/models/test_groq.py index 0f10b4b7e4..71074fe4bf 100644 --- a/tests/models/test_groq.py +++ b/tests/models/test_groq.py @@ -264,6 +264,7 @@ async def test_request_structured_response(allow_model_requests: None): content='Final result processed.', tool_call_id='123', timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -388,6 +389,7 @@ async def get_location(loc_name: str) -> str: content='{"lat": 51, "lng": 0}', tool_call_id='2', timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -529,6 +531,7 @@ async def test_stream_structured(allow_model_requests: None): content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -636,6 +639,7 @@ async def get_image() -> BinaryContent: content='See file 1c8566', tool_call_id='call_wkpd', timestamp=IsDatetime(), + return_kind='tool-executed', ), UserPromptPart( content=[ @@ -5401,6 +5405,7 @@ async def get_something_by_name(name: str) -> str: content='Something with name: nonexistent', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ], instructions='Be concise. Never use pretty double quotes, just regular ones.', @@ -5530,6 +5535,7 @@ async def get_something_by_name(name: str) -> str: content='Something with name: test_name', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ], instructions='Be concise. 
Never use pretty double quotes, just regular ones.', diff --git a/tests/models/test_huggingface.py b/tests/models/test_huggingface.py index ed99de4e56..d111b086fd 100644 --- a/tests/models/test_huggingface.py +++ b/tests/models/test_huggingface.py @@ -417,6 +417,7 @@ async def get_location(loc_name: str) -> str: content='{"lat": 51, "lng": 0}', tool_call_id='2', timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py index dffcc5512f..92dcfaacdb 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -469,6 +469,7 @@ class CityLocation(BaseModel): content='Final result processed.', tool_call_id='123', timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -538,6 +539,7 @@ class CityLocation(BaseModel): content='Final result processed.', tool_call_id='123', timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -604,6 +606,7 @@ async def test_request_output_type_with_arguments_str_response(allow_model_reque content='Final result processed.', tool_call_id='123', timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -1189,6 +1192,7 @@ async def get_location(loc_name: str) -> str: content='{"lat": 51, "lng": 0}', tool_call_id='2', timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -1349,6 +1353,7 @@ async def get_location(loc_name: str) -> str: content='{"lat": 51, "lng": 0}', tool_call_id='2', timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -1378,6 +1383,7 @@ async def get_location(loc_name: str) -> str: content='Final result processed.', tool_call_id='1', timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -1484,6 +1490,7 @@ async def get_location(loc_name: str) -> str: content='{"lat": 51, "lng": 0}', tool_call_id='1', timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -1507,6 +1514,7 @@ async def get_location(loc_name: str) -> str: content='Final result processed.', tool_call_id='1', timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -1600,6 +1608,7 @@ async def get_location(loc_name: str) -> str: content='{"lat": 51, "lng": 0}', tool_call_id='1', timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -1747,6 +1756,7 @@ async def get_location(loc_name: str) -> str: content='{"lat": 51, "lng": 0}', tool_call_id='2', timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -1951,6 +1961,7 @@ async def get_image() -> BinaryContent: content='See file 1c8566', tool_call_id='GJYBCIkcS', timestamp=IsDatetime(), + return_kind='tool-executed', ), UserPromptPart( content=[ diff --git a/tests/models/test_model_function.py b/tests/models/test_model_function.py index 196b140454..83b79533a2 100644 --- a/tests/models/test_model_function.py +++ b/tests/models/test_model_function.py @@ -187,6 +187,7 @@ def test_weather(): content='{"lat": 51, "lng": 0}', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -205,6 +206,7 @@ def test_weather(): content='Raining', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -257,6 +259,7 @@ def test_var_args(): 'metadata': None, 
'timestamp': IsStr() & IsNow(iso_string=True, tz=timezone.utc), # type: ignore[reportUnknownMemberType] 'part_kind': 'tool-return', + 'return_kind': 'tool-executed', } ) @@ -389,19 +392,39 @@ def test_call_all(): ModelRequest( parts=[ ToolReturnPart( - tool_name='foo', content='1', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr() + tool_name='foo', + content='1', + timestamp=IsNow(tz=timezone.utc), + tool_call_id=IsStr(), + return_kind='tool-executed', ), ToolReturnPart( - tool_name='bar', content='2', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr() + tool_name='bar', + content='2', + timestamp=IsNow(tz=timezone.utc), + tool_call_id=IsStr(), + return_kind='tool-executed', ), ToolReturnPart( - tool_name='baz', content='3', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr() + tool_name='baz', + content='3', + timestamp=IsNow(tz=timezone.utc), + tool_call_id=IsStr(), + return_kind='tool-executed', ), ToolReturnPart( - tool_name='qux', content='4', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr() + tool_name='qux', + content='4', + timestamp=IsNow(tz=timezone.utc), + tool_call_id=IsStr(), + return_kind='tool-executed', ), ToolReturnPart( - tool_name='quz', content='a', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr() + tool_name='quz', + content='a', + timestamp=IsNow(tz=timezone.utc), + tool_call_id=IsStr(), + return_kind='tool-executed', ), ], run_id=IsStr(), diff --git a/tests/models/test_model_test.py b/tests/models/test_model_test.py index f7a6809a71..c6f0a30c76 100644 --- a/tests/models/test_model_test.py +++ b/tests/models/test_model_test.py @@ -100,6 +100,7 @@ def test_custom_output_args(): content='Final result processed.', tool_call_id='pyd_ai_tool_call_id__final_result', timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -147,6 +148,7 @@ class Foo(BaseModel): content='Final result processed.', tool_call_id='pyd_ai_tool_call_id__final_result', timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -190,6 +192,7 @@ def test_output_type(): content='Final result processed.', tool_call_id='pyd_ai_tool_call_id__final_result', timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -248,7 +251,11 @@ async def my_ret(x: int) -> str: ModelRequest( parts=[ ToolReturnPart( - tool_name='my_ret', content='1', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc) + tool_name='my_ret', + content='1', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 993eebcc88..b65fb32c4f 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -306,6 +306,7 @@ async def test_request_structured_response(allow_model_requests: None): content='Final result processed.', tool_call_id='123', timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -440,6 +441,7 @@ async def get_location(loc_name: str) -> str: content='{"lat": 51, "lng": 0}', tool_call_id='2', timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -1159,6 +1161,7 @@ async def get_image() -> ImageUrl: content='See file bd38f5', tool_call_id='call_4hrT4QP9jfojtK69vGiFCFjG', timestamp=IsDatetime(), + return_kind='tool-executed', ), UserPromptPart( content=[ @@ -1248,6 +1251,7 @@ async def get_image() -> BinaryContent: content='See file 1c8566', 
tool_call_id='call_Btn0GIzGr4ugNlLmkQghQUMY', timestamp=IsDatetime(), + return_kind='tool-executed', ), UserPromptPart( content=[ @@ -2363,7 +2367,11 @@ async def get_temperature(city: str) -> float: ModelRequest( parts=[ ToolReturnPart( - tool_name='get_temperature', content=20.0, tool_call_id=IsStr(), timestamp=IsDatetime() + tool_name='get_temperature', + content=20.0, + tool_call_id=IsStr(), + timestamp=IsDatetime(), + return_kind='tool-executed', ) ], instructions='You are a helpful assistant.', @@ -2789,6 +2797,7 @@ async def get_user_country() -> str: content='Mexico', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -2827,6 +2836,7 @@ async def get_user_country() -> str: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -2891,6 +2901,7 @@ async def get_user_country() -> str: content='Mexico', tool_call_id='call_J1YabdC7G7kzEZNbbZopwenH', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -2979,6 +2990,7 @@ async def get_user_country() -> str: content='Mexico', tool_call_id='call_PkRGedQNRFUzJp2R7dO7avWR', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -3069,6 +3081,7 @@ async def get_user_country() -> str: content='Mexico', tool_call_id='call_SIttSeiOistt33Htj4oiHOOX', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -3159,6 +3172,7 @@ async def get_user_country() -> str: content='Mexico', tool_call_id='call_s7oT9jaLAsEqTgvxZTmFh0wB', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -3249,6 +3263,7 @@ async def get_user_country() -> str: content='Mexico', tool_call_id='call_wJD14IyJ4KKVtjCrGyNCHO09', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), diff --git a/tests/models/test_openai_responses.py b/tests/models/test_openai_responses.py index 18016ccf4c..d32059ec21 100644 --- a/tests/models/test_openai_responses.py +++ b/tests/models/test_openai_responses.py @@ -345,6 +345,7 @@ async def get_location(loc_name: str) -> str: content='{"lat": 51, "lng": 0}', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ), ], run_id=IsStr(), @@ -423,6 +424,7 @@ async def get_image() -> BinaryContent: content='See file 1c8566', tool_call_id='call_FLm3B1f8QAan0KpbUXhNY8bA', timestamp=IsDatetime(), + return_kind='tool-executed', ), UserPromptPart( content=[ @@ -1457,6 +1459,7 @@ async def get_user_country() -> str: content='Mexico', tool_call_id='call_ZWkVhdUjupo528U9dqgFeRkH', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -1487,6 +1490,7 @@ async def get_user_country() -> str: content='Final result processed.', tool_call_id='call_iFBd0zULhSZRR908DfH73VwN', timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -1548,6 +1552,7 @@ async def get_user_country() -> str: content='Mexico', tool_call_id='call_aTJhYjzmixZaVGqwl5gn2Ncr', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -1629,6 +1634,7 @@ async def get_user_country() -> str: content='Mexico', tool_call_id='call_tTAThu8l2S9hNky2krdwijGP', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -1712,6 +1718,7 @@ async def get_user_country() -> str: content='Mexico', tool_call_id='call_UaLahjOtaM2tTyYZLxTCbOaP', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -1791,6 +1798,7 @@ 
async def get_user_country() -> str: content='Mexico', tool_call_id='call_FrlL4M0CbAy8Dhv4VqF1Shom', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -1874,6 +1882,7 @@ async def get_user_country() -> str: content='Mexico', tool_call_id='call_my4OyoVXRT0m7bLWmsxcaCQI', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -2444,6 +2453,7 @@ def update_plan(plan: str) -> str: content='plan updated', tool_call_id='call_gL7JE6GDeGGsFubqO2XGytyO', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], instructions="You are a helpful assistant that uses planning. You MUST use the update_plan tool and continually update it as you make progress against the user's prompt", @@ -3718,6 +3728,7 @@ def get_meaning_of_life() -> int: content=42, tool_call_id='call_3WCunBU7lCG1HHaLmnnRJn8I', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -6168,6 +6179,7 @@ class Animal(BaseModel): content='Final result processed.', tool_call_id='call_eE7MHM5WMJnMt5srV69NmBJk', timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -6383,6 +6395,7 @@ async def get_animal() -> str: content='axolotl', tool_call_id='call_t76xO1K2zqrJkawkU3tur8vj', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -6695,6 +6708,7 @@ class CityLocation(BaseModel): content='Final result processed.', tool_call_id='call_LIXPi261Xx3dGYzlDsOoyHGk', timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], run_id=IsStr(), diff --git a/tests/test_a2a.py b/tests/test_a2a.py index 4e5f74f476..edf9e6045e 100644 --- a/tests/test_a2a.py +++ b/tests/test_a2a.py @@ -638,6 +638,7 @@ def track_messages(messages: list[ModelMessage], info: AgentInfo) -> ModelRespon content='Final result processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='final-result-processed', ), UserPromptPart(content='Second message', timestamp=IsDatetime()), ], diff --git a/tests/test_agent.py b/tests/test_agent.py index 6ce2d91c54..a6954cdc73 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -68,6 +68,7 @@ from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel from pydantic_ai.models.test import TestModel from pydantic_ai.output import OutputObjectDefinition, StructuredDict, ToolOutput +from pydantic_ai.prompt_config import PromptConfig, PromptTemplates, ToolConfig from pydantic_ai.result import RunUsage from pydantic_ai.settings import ModelSettings from pydantic_ai.tools import DeferredToolRequests, DeferredToolResults, ToolDefinition, ToolDenied @@ -226,6 +227,7 @@ def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -235,6 +237,595 @@ def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse assert result.all_messages_json().startswith(b'[{"parts":[{"content":"Hello",') +def test_prompt_config_callable(): + """Test all prompt templates: validation_errors_retry, final_result_processed, output_tool_not_executed, and function_tool_not_executed.""" + + def my_function_tool() -> str: # pragma: no cover + return 'function executed' + + def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + + if len(messages) == 1: + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, 
'{"a": "wrong", "b": "foo"}')]) + + else: + assert info.function_tools is not None + return ModelResponse( + parts=[ + ToolCallPart(info.output_tools[0].name, '{"a": 42, "b": "foo"}'), # Succeeds + ToolCallPart(info.output_tools[0].name, '{"a": 99, "b": "bar"}'), # Not executed + ToolCallPart(info.function_tools[0].name, '{}'), # Not executed + ] + ) + + agent = Agent( + FunctionModel(return_model), + output_type=Foo, + prompt_config=PromptConfig( + templates=PromptTemplates( + validation_errors_retry=lambda part, ctx: 'Please fix these validation errors and try again.', + final_result_processed=lambda part, ctx: f'Custom final result {part.content}', + output_tool_not_executed=lambda part, ctx: f'Custom output not executed: {part.tool_name}', + function_tool_not_executed=lambda part, ctx: f'Custom function not executed: {part.tool_name}', + ) + ), + ) + + agent.tool_plain(my_function_tool) + + result = agent.run_sync('Hello') + assert result.output.model_dump() == {'a': 42, 'b': 'foo'} + + retry_request = result.all_messages()[2] + assert isinstance(retry_request, ModelRequest) + retry_part = retry_request.parts[0] + assert isinstance(retry_part, RetryPromptPart) + # model_response() returns validation errors + retry_message appended + assert retry_part.model_response() == snapshot("""\ +1 validation error: +```json +[ + { + "type": "int_parsing", + "loc": [ + "a" + ], + "msg": "Input should be a valid integer, unable to parse string as an integer", + "input": "wrong" + } +] +``` + +Please fix these validation errors and try again.\ +""") + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))], + run_id=IsStr(), + ), + ModelResponse( + parts=[ToolCallPart(tool_name='final_result', args='{"a": "wrong", "b": "foo"}', tool_call_id=IsStr())], + usage=RequestUsage(input_tokens=51, output_tokens=7), + model_name='function:return_model:', + timestamp=IsNow(tz=timezone.utc), + run_id=IsStr(), + ), + ModelRequest( + parts=[ + RetryPromptPart( + tool_name='final_result', + content=[ + { + 'type': 'int_parsing', + 'loc': ('a',), + 'msg': 'Input should be a valid integer, unable to parse string as an integer', + 'input': 'wrong', + } + ], + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + retry_message="""\ +1 validation error: +```json +[ + { + "type": "int_parsing", + "loc": [ + "a" + ], + "msg": "Input should be a valid integer, unable to parse string as an integer", + "input": "wrong" + } +] +``` + +Please fix these validation errors and try again.\ +""", + ) + ], + run_id=IsStr(), + ), + ModelResponse( + parts=[ + ToolCallPart(tool_name='final_result', args='{"a": 42, "b": "foo"}', tool_call_id=IsStr()), + ToolCallPart(tool_name='final_result', args='{"a": 99, "b": "bar"}', tool_call_id=IsStr()), + ToolCallPart(tool_name='my_function_tool', args='{}', tool_call_id=IsStr()), + ], + usage=RequestUsage(input_tokens=91, output_tokens=23), + model_name='function:return_model:', + timestamp=IsNow(tz=timezone.utc), + run_id=IsStr(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', + content='Custom final result Final result processed.', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', + ), + ToolReturnPart( + tool_name='final_result', + content='Custom output not executed: final_result', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + return_kind='output-tool-not-executed', + ), + ToolReturnPart( + 
tool_name='my_function_tool', + content='Custom function not executed: my_function_tool', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + return_kind='function-tool-not-executed', + ), + ], + run_id=IsStr(), + ), + ] + ) + + # Covering the branch where retry prompt falls to the default template + from pydantic_ai.prompt_config import DEFAULT_MODEL_RETRY + + templates = PromptTemplates(final_result_processed='test') # Only set non-retry template + retry_part_for_default = RetryPromptPart(content='error', tool_name='some_tool') + result_part = templates.apply_template(retry_part_for_default, None) # type: ignore[arg-type] + assert isinstance(result_part, RetryPromptPart) + assert result_part.retry_message == f'error\n\n{DEFAULT_MODEL_RETRY}' + + +def test_prompt_config_string_and_override_prompt_config(): + """Test all prompt templates: validation_errors_retry, final_result_processed, output_tool_not_executed, and function_tool_not_executed.""" + + def my_function_tool() -> str: # pragma: no cover + return 'function executed' + + def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + + if len(messages) == 1: + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, '{"a": "wrong", "b": "foo"}')]) + + else: + assert info.function_tools is not None + return ModelResponse( + parts=[ + ToolCallPart(info.output_tools[0].name, '{"a": 42, "b": "foo"}'), # Succeeds + ToolCallPart(info.output_tools[0].name, '{"a": 99, "b": "bar"}'), # Not executed + ToolCallPart(info.function_tools[0].name, '{}'), # Not executed + ] + ) + + agent = Agent( + FunctionModel(return_model), + output_type=Foo, + prompt_config=PromptConfig( + templates=PromptTemplates( + validation_errors_retry='Custom retry message', + final_result_processed='Custom final result', + output_tool_not_executed='Custom output not executed:', + function_tool_not_executed='Custom function not executed', + ) + ), + ) + + agent.tool_plain(my_function_tool) + + result = agent.run_sync('Hello') + assert result.output.model_dump() == {'a': 42, 'b': 'foo'} + + retry_request = result.all_messages()[2] + assert isinstance(retry_request, ModelRequest) + retry_part = retry_request.parts[0] + assert isinstance(retry_part, RetryPromptPart) + # model_response() returns validation errors + retry_message appended + assert retry_part.model_response() == snapshot("""\ +1 validation error: +```json +[ + { + "type": "int_parsing", + "loc": [ + "a" + ], + "msg": "Input should be a valid integer, unable to parse string as an integer", + "input": "wrong" + } +] +``` + +Custom retry message""") + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))], + run_id=IsStr(), + ), + ModelResponse( + parts=[ToolCallPart(tool_name='final_result', args='{"a": "wrong", "b": "foo"}', tool_call_id=IsStr())], + usage=RequestUsage(input_tokens=51, output_tokens=7), + model_name='function:return_model:', + timestamp=IsNow(tz=timezone.utc), + run_id=IsStr(), + ), + ModelRequest( + parts=[ + RetryPromptPart( + tool_name='final_result', + content=[ + { + 'type': 'int_parsing', + 'loc': ('a',), + 'msg': 'Input should be a valid integer, unable to parse string as an integer', + 'input': 'wrong', + } + ], + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + retry_message="""\ +1 validation error: +```json +[ + { + "type": "int_parsing", + "loc": [ + "a" + ], + "msg": "Input should be a valid integer, unable 
to parse string as an integer", + "input": "wrong" + } +] +``` + +Custom retry message\ +""", + ) + ], + run_id=IsStr(), + ), + ModelResponse( + parts=[ + ToolCallPart(tool_name='final_result', args='{"a": 42, "b": "foo"}', tool_call_id=IsStr()), + ToolCallPart(tool_name='final_result', args='{"a": 99, "b": "bar"}', tool_call_id=IsStr()), + ToolCallPart(tool_name='my_function_tool', args='{}', tool_call_id=IsStr()), + ], + usage=RequestUsage(input_tokens=85, output_tokens=23), + model_name='function:return_model:', + timestamp=IsNow(tz=timezone.utc), + run_id=IsStr(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', + content='Custom final result', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', + ), + ToolReturnPart( + tool_name='final_result', + content='Custom output not executed:', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + return_kind='output-tool-not-executed', + ), + ToolReturnPart( + tool_name='my_function_tool', + content='Custom function not executed', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + return_kind='function-tool-not-executed', + ), + ], + run_id=IsStr(), + ), + ] + ) + + # Verify prompt_config can be overridden + with agent.override( + prompt_config=PromptConfig(templates=PromptTemplates(validation_errors_retry='Custom retry message override')) + ): + result = agent.run_sync('Hello') + assert result.output.model_dump() == {'a': 42, 'b': 'foo'} + retry_request = result.all_messages()[2] + assert isinstance(retry_request, ModelRequest) + retry_part = retry_request.parts[0] + assert isinstance(retry_part, RetryPromptPart) + # model_response() returns validation errors + retry_message appended + assert retry_part.model_response() == snapshot("""\ +1 validation error: +```json +[ + { + "type": "int_parsing", + "loc": [ + "a" + ], + "msg": "Input should be a valid integer, unable to parse string as an integer", + "input": "wrong" + } +] +``` + +Custom retry message override""") + + def model_with_tool_retry(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.function_tools is not None + if len(messages) == 1: + return ModelResponse(parts=[ToolCallPart('retry_tool', '{}')]) + else: + return ModelResponse(parts=[TextPart('done')]) + + agent_tool_retry = Agent( + FunctionModel(model_with_tool_retry), + output_type=str, + prompt_config=PromptConfig(templates=PromptTemplates(model_retry_string_tool='Custom tool retry message')), + ) + + @agent_tool_retry.tool_plain + def retry_tool() -> str: + raise ModelRetry('Tool failed') + + result_tool_retry = agent_tool_retry.run_sync('Test') + assert result_tool_retry.output == 'done' + assert result_tool_retry.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='Test', + timestamp=IsDatetime(), + ) + ], + run_id=IsStr(), + ), + ModelResponse( + parts=[ToolCallPart(tool_name='retry_tool', args='{}', tool_call_id=IsStr())], + usage=RequestUsage(input_tokens=51, output_tokens=2), + model_name='function:model_with_tool_retry:', + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelRequest( + parts=[ + RetryPromptPart( + content='Tool failed', + tool_name='retry_tool', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + retry_message="""\ +Tool failed + +Custom tool retry message\ +""", + ) + ], + run_id=IsStr(), + ), + ModelResponse( + parts=[TextPart(content='done')], + usage=RequestUsage(input_tokens=57, output_tokens=3), + 
model_name='function:model_with_tool_retry:', + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ] + ) + + # Test model_retry_string_no_tool template (RetryPromptPart with string content, no tool) + def model_with_no_tool_retry(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + if len(messages) == 1: + return ModelResponse(parts=[TextPart('invalid')]) + else: + return ModelResponse(parts=[TextPart('valid')]) + + agent_no_tool_retry = Agent( + FunctionModel(model_with_no_tool_retry), + output_type=str, + prompt_config=PromptConfig(templates=PromptTemplates(model_retry_string_no_tool='Custom no-tool retry')), + ) + + @agent_no_tool_retry.output_validator + def check_valid(ctx: RunContext[None], output: str) -> str: + if output == 'invalid': + raise ModelRetry('Output is invalid') + return output + + result_no_tool = agent_no_tool_retry.run_sync('Test') + assert result_no_tool.output == 'valid' + assert result_no_tool.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='Test', + timestamp=IsDatetime(), + ) + ], + run_id=IsStr(), + ), + ModelResponse( + parts=[TextPart(content='invalid')], + usage=RequestUsage(input_tokens=51, output_tokens=1), + model_name='function:model_with_no_tool_retry:', + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelRequest( + parts=[ + RetryPromptPart( + content='Output is invalid', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + retry_message="""\ +Validation feedback: +Output is invalid + +Custom no-tool retry\ +""", + ) + ], + run_id=IsStr(), + ), + ModelResponse( + parts=[TextPart(content='valid')], + usage=RequestUsage(input_tokens=59, output_tokens=2), + model_name='function:model_with_no_tool_retry:', + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ] + ) + + +def test_prompt_config_tool_config_descriptions(): + """Test that ToolConfig.tool_description updates tool descriptions at the agent level.""" + + def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + # Verify the tool description was updated + assert info.function_tools is not None + my_tool = next(t for t in info.function_tools if t.name == 'my_tool') + assert my_tool.description == 'Custom tool description from ToolConfig' + return ModelResponse(parts=[TextPart('Done')]) + + agent = Agent( + FunctionModel(return_model), + prompt_config=PromptConfig( + tool_config={ + 'my_tool': ToolConfig( + name=None, + tool_description='Custom tool description from ToolConfig', + strict=None, + tool_args_descriptions=None, + ) + } + ), + ) + + @agent.tool_plain + def my_tool(x: int) -> int: # pragma: no cover + """Original description that should be overridden""" + return x * 2 + + result = agent.run_sync('Hello') + assert result.output == 'Done' + + +def test_prompt_config_tool_config_descriptions_at_runtime(): + """Test that ToolConfig.tool_description passed to run_sync() overrides agent-level prompt_config.""" + observed_descriptions: list[str | None] = [] + + def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.function_tools is not None + basic_tool = next(tool for tool in info.function_tools if tool.name == 'basic_tool') + observed_descriptions.append(basic_tool.description) + return ModelResponse(parts=[TextPart('Done')]) + + # Agent with agent-level prompt_config + agent = Agent( + FunctionModel(return_model), + prompt_config=PromptConfig( + tool_config={ + 'basic_tool': ToolConfig( + name=None, tool_description='Agent-level tool description', strict=None, 
tool_args_descriptions=None + ), + 'not_present_basic_tool': ToolConfig( + name=None, tool_description='Should not be used', strict=None, tool_args_descriptions=None + ), + } + ), + ) + + @agent.tool_plain + def basic_tool(x: int) -> int: # pragma: no cover + """Original description that should be overridden""" + return x * 2 + + # First run: no runtime prompt_config, should use agent-level description + result = agent.run_sync('Hello') + assert result.output == 'Done' + assert observed_descriptions[-1] == 'Agent-level tool description' + + # Second run: pass runtime prompt_config, should override agent-level description + result = agent.run_sync( + 'Hello', + prompt_config=PromptConfig( + tool_config={ + 'basic_tool': ToolConfig( + name=None, + tool_description='Runtime custom tool description', + strict=None, + tool_args_descriptions=None, + ) + } + ), + ) + assert result.output == 'Done' + assert observed_descriptions[-1] == 'Runtime custom tool description' + + +def test_prompt_config_tool_config_output_tool_descriptions(): + """Test that ToolConfig.tool_description updates output tool descriptions (covers ToolConfigPreparedToolset for output_toolset).""" + + def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + # Verify the output tool description was updated + assert info.output_tools is not None + output_tool = next(t for t in info.output_tools if t.name == 'final_result') + assert output_tool.description == 'Custom output tool description from ToolConfig' + assert output_tool.description != 'Output model for testing.' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, '{"a": 42, "b": "bar"}')]) + + class OutputModel(BaseModel): + """Output model for testing.""" + + a: int + b: str + + agent = Agent( + FunctionModel(return_model), + output_type=OutputModel, + prompt_config=PromptConfig( + tool_config={ + 'final_result': ToolConfig( + name=None, + tool_description='Custom output tool description from ToolConfig', + strict=None, + tool_args_descriptions=None, + ) + } + ), + ) + + result = agent.run_sync('Hello') + assert isinstance(result.output, OutputModel) + + def test_result_pydantic_model_validation_error(): def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: assert info.output_tools is not None @@ -289,7 +880,8 @@ def check_b(cls, v: str) -> str: ] ``` -Fix the errors and try again.""") +Fix the errors and try again.\ +""") def test_output_validator(): @@ -352,6 +944,7 @@ def validate_output(ctx: RunContext[None], o: Foo) -> Foo: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -497,6 +1090,7 @@ def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -508,7 +1102,11 @@ def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: ModelRequest( parts=[ ToolReturnPart( - tool_name='final_result', content='foobar', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc) + tool_name='final_result', + content='foobar', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -522,6 +1120,7 @@ def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: content='Final result processed.', tool_call_id=IsStr(), 
timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -1096,6 +1695,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -1149,9 +1749,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: ModelRequest( parts=[ RetryPromptPart( - content='City not found, I only know Mexico City', - tool_call_id=IsStr(), - timestamp=IsDatetime(), + content='City not found, I only know Mexico City', tool_call_id=IsStr(), timestamp=IsDatetime() ) ], run_id=IsStr(), @@ -1407,6 +2005,7 @@ def call_handoff_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelRes content='Final result processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -1445,6 +2044,7 @@ def call_handoff_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelRes content='Final result processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -2028,9 +2628,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: ModelRequest( parts=[ RetryPromptPart( - content='City not found, I only know Mexico City', - tool_call_id=IsStr(), - timestamp=IsDatetime(), + content='City not found, I only know Mexico City', tool_call_id=IsStr(), timestamp=IsDatetime() ) ], run_id=IsStr(), @@ -2075,7 +2673,11 @@ async def ret_a(x: str) -> str: ModelRequest( parts=[ ToolReturnPart( - tool_name='ret_a', content='a-apple', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc) + tool_name='ret_a', + content='a-apple', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -2111,7 +2713,11 @@ async def ret_a(x: str) -> str: ModelRequest( parts=[ ToolReturnPart( - tool_name='ret_a', content='a-apple', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc) + tool_name='ret_a', + content='a-apple', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -2176,7 +2782,11 @@ async def ret_a(x: str) -> str: ModelRequest( parts=[ ToolReturnPart( - tool_name='ret_a', content='a-apple', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc) + tool_name='ret_a', + content='a-apple', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -2240,7 +2850,11 @@ async def ret_a(x: str) -> str: ModelRequest( parts=[ ToolReturnPart( - tool_name='ret_a', content='a-apple', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc) + tool_name='ret_a', + content='a-apple', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -2265,6 +2879,7 @@ async def ret_a(x: str) -> str: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -2292,7 +2907,11 @@ async def ret_a(x: str) -> str: ModelRequest( parts=[ ToolReturnPart( - tool_name='ret_a', content='a-apple', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc) + tool_name='ret_a', + content='a-apple', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -2311,6 +2930,7 @@ async def ret_a(x: str) -> str: content='Final result 
processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ), ], run_id=IsStr(), @@ -2336,6 +2956,7 @@ async def ret_a(x: str) -> str: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ), ], run_id=IsStr(), @@ -2461,6 +3082,7 @@ def test_tool() -> str: content='Test response', tool_call_id='call_123', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -3023,24 +3645,28 @@ def deferred_tool(x: int) -> int: # pragma: no cover content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ), ToolReturnPart( tool_name='regular_tool', content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='function-tool-not-executed', ), ToolReturnPart( tool_name='another_tool', content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='function-tool-not-executed', ), ToolReturnPart( tool_name='deferred_tool', content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='function-tool-not-executed', ), ], run_id=IsStr(), @@ -3113,12 +3739,14 @@ def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ), ToolReturnPart( tool_name='second_output', content='Output tool not used - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='output-tool-not-executed', ), ], run_id=IsStr(), @@ -3169,12 +3797,14 @@ def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ), ToolReturnPart( tool_name='final_result', content='Output tool not used - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='output-tool-not-executed', ), ], run_id=IsStr(), @@ -3259,18 +3889,21 @@ def deferred_tool(x: int) -> int: # pragma: no cover content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ), ToolReturnPart( tool_name='regular_tool', content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='function-tool-not-executed', ), ToolReturnPart( tool_name='another_tool', content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='function-tool-not-executed', ), RetryPromptPart( content="Unknown tool name: 'unknown_tool'. 
Available tools: 'final_result', 'regular_tool', 'another_tool', 'deferred_tool'", @@ -3283,6 +3916,7 @@ def deferred_tool(x: int) -> int: # pragma: no cover content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='function-tool-not-executed', ), ], run_id=IsStr(), @@ -3370,18 +4004,21 @@ def regular_tool(x: int) -> int: # pragma: no cover content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ), ToolReturnPart( tool_name='regular_tool', content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='function-tool-not-executed', ), ToolReturnPart( tool_name='external_tool', content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='function-tool-not-executed', ), ], run_id=IsStr(), @@ -3458,6 +4095,7 @@ def regular_tool(x: int) -> int: content=1, tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -3513,6 +4151,7 @@ def regular_tool(x: int) -> int: content=0, tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -3537,6 +4176,7 @@ def regular_tool(x: int) -> int: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -3622,21 +4262,28 @@ def deferred_tool(x: int) -> int: # pragma: no cover content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ), ToolReturnPart( tool_name='final_result', content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ), ToolReturnPart( tool_name='regular_tool', content=42, tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ), ToolReturnPart( - tool_name='another_tool', content=2, tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc) + tool_name='another_tool', + content=2, + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ), RetryPromptPart( content="Unknown tool name: 'unknown_tool'. 
Available tools: 'final_result', 'regular_tool', 'another_tool', 'deferred_tool'", @@ -3649,6 +4296,7 @@ def deferred_tool(x: int) -> int: # pragma: no cover content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='function-tool-not-executed', ), ], run_id=IsStr(), @@ -3721,12 +4369,14 @@ def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ), ToolReturnPart( tool_name='second_output', content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ), ], run_id=IsStr(), @@ -3805,6 +4455,7 @@ def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ), ], run_id=IsStr(), @@ -3878,12 +4529,14 @@ def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ), ToolReturnPart( tool_name='second_output', content='Output tool not used - output failed validation.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='output-validation-failed', ), ], run_id=IsStr(), @@ -3957,6 +4610,7 @@ def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='final-result-processed', ), RetryPromptPart( content='Second output validation failed', @@ -4052,6 +4706,7 @@ def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: content='Final result processed.', timestamp=IsNow(tz=timezone.utc), tool_call_id='second', + return_kind='final-result-processed', ), ], run_id=IsStr(), @@ -4136,6 +4791,7 @@ async def get_location(loc_name: str) -> str: content='{"lat": 51, "lng": 0}', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -4443,6 +5099,7 @@ async def foobar(x: str) -> str: content='inner agent result', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -4764,6 +5421,7 @@ def get_image() -> BinaryContent: content='See file image_id_1', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ), UserPromptPart( content=[ @@ -4812,6 +5470,7 @@ def get_files(): content=['See file img_001', 'See file vid_002', 'See file aud_003', 'See file doc_004'], tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ), UserPromptPart( content=[ @@ -5039,6 +5698,7 @@ class Output(BaseModel): content='Final result processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -5088,7 +5748,11 @@ def my_tool(x: int) -> int: ModelRequest( parts=[ ToolReturnPart( - tool_name='my_tool', content=2, tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc) + tool_name='my_tool', + content=2, + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -5106,7 +5770,11 @@ def my_tool(x: int) -> int: ModelRequest( parts=[ ToolReturnPart( - tool_name='my_tool', content=4, tool_call_id=IsStr(), 
timestamp=IsNow(tz=timezone.utc) + tool_name='my_tool', + content=4, + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -5178,6 +5846,7 @@ def foo_tool(foo: Foo) -> int: 'tool_call_id': IsStr(), 'timestamp': IsStr(), 'part_kind': 'retry-prompt', + 'retry_message': None, } ], 'instructions': None, @@ -5276,6 +5945,7 @@ def analyze_data() -> ToolReturn: tool_call_id=IsStr(), metadata={'foo': 'bar'}, timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ), UserPromptPart( content=[ @@ -5357,6 +6027,7 @@ def analyze_data() -> ToolReturn: tool_call_id=IsStr(), metadata={'foo': 'bar'}, timestamp=IsNow(tz=timezone.utc), + return_kind='tool-executed', ), ], run_id=IsStr(), @@ -5653,6 +6324,7 @@ def respond(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: content='foo tool added', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -5671,6 +6343,7 @@ def respond(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: content='Hello from foo', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -5750,6 +6423,7 @@ async def only_if_plan_presented( content='a', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -5774,6 +6448,7 @@ async def only_if_plan_presented( content='Final result processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], run_id=IsStr(), @@ -6061,9 +6736,7 @@ def model_function(messages: list[ModelMessage], info: AgentInfo) -> ModelRespon ModelRequest( parts=[ RetryPromptPart( - content='Please return text or call a tool.', - tool_call_id=IsStr(), - timestamp=IsDatetime(), + content='Please return text or call a tool.', tool_call_id=IsStr(), timestamp=IsDatetime() ) ], run_id=IsStr(), @@ -6102,7 +6775,12 @@ def model_function(messages: list[ModelMessage], info: AgentInfo) -> ModelRespon model = FunctionModel(model_function) - agent = Agent(model, output_type=[str, DeferredToolRequests]) + # Test with tool_call_denied template set (covers the True branch at line 139) + agent = Agent( + model, + output_type=[str, DeferredToolRequests], + prompt_config=PromptConfig(templates=PromptTemplates(tool_denied='Tool call denied custom message.')), + ) @agent.tool_plain(requires_approval=True) def delete_file(path: str) -> str: @@ -6151,6 +6829,7 @@ def create_file(path: str, content: str) -> str: content='File \'new_file.py\' created with content: print("Hello, world!")', tool_call_id='create_file', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -6209,6 +6888,7 @@ def create_file(path: str, content: str) -> str: content='File \'new_file.py\' created with content: print("Hello, world!")', tool_call_id='create_file', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -6220,19 +6900,21 @@ def create_file(path: str, content: str) -> str: content="File 'ok_to_delete.py' deleted", tool_call_id='ok_to_delete', timestamp=IsDatetime(), + return_kind='tool-executed', ), ToolReturnPart( tool_name='delete_file', - content='File cannot be deleted', + content='Tool call denied custom message.', tool_call_id='never_delete', timestamp=IsDatetime(), + return_kind='tool-denied', ), ], run_id=IsStr(), ), ModelResponse( parts=[TextPart(content='Done!')], - usage=RequestUsage(input_tokens=78, output_tokens=24), + usage=RequestUsage(input_tokens=80, 
output_tokens=24), model_name='function:model_function:', timestamp=IsDatetime(), run_id=IsStr(), @@ -6250,19 +6932,21 @@ def create_file(path: str, content: str) -> str: content="File 'ok_to_delete.py' deleted", tool_call_id='ok_to_delete', timestamp=IsDatetime(), + return_kind='tool-executed', ), ToolReturnPart( tool_name='delete_file', - content='File cannot be deleted', + content='Tool call denied custom message.', tool_call_id='never_delete', timestamp=IsDatetime(), + return_kind='tool-denied', ), ], run_id=IsStr(), ), ModelResponse( parts=[TextPart(content='Done!')], - usage=RequestUsage(input_tokens=78, output_tokens=24), + usage=RequestUsage(input_tokens=80, output_tokens=24), model_name='function:model_function:', timestamp=IsDatetime(), run_id=IsStr(), @@ -6270,6 +6954,74 @@ def create_file(path: str, content: str) -> str: ] ) + def model_function_for_none_template(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + if len(messages) == 1: + return ModelResponse( + parts=[ + ToolCallPart(tool_name='protected_delete', args={'path': 'file.txt'}, tool_call_id='del_call'), + ] + ) + else: + return ModelResponse(parts=[TextPart('Done!')]) + + agent_no_template = Agent( + FunctionModel(model_function_for_none_template), + output_type=[str, DeferredToolRequests], + prompt_config=PromptConfig(templates=PromptTemplates(final_result_processed='Done')), + ) + + @agent_no_template.tool_plain(requires_approval=True) + def protected_delete(path: str) -> str: # pragma: no cover + return f'File {path!r} deleted' + + result_no_template = await agent_no_template.run('Delete file.txt') + result_no_template = await agent_no_template.run( + message_history=result_no_template.all_messages(), + deferred_tool_results=DeferredToolResults( + approvals={'del_call': ToolDenied('Original denial message preserved')}, + ), + ) + + assert result_no_template.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='Delete file.txt', + timestamp=IsDatetime(), + ) + ], + run_id=IsStr(), + ), + ModelResponse( + parts=[ToolCallPart(tool_name='protected_delete', args={'path': 'file.txt'}, tool_call_id='del_call')], + usage=RequestUsage(input_tokens=53, output_tokens=6), + model_name='function:model_function_for_none_template:', + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='protected_delete', + content='Original denial message preserved', + tool_call_id='del_call', + timestamp=IsDatetime(), + return_kind='tool-denied', + ) + ], + run_id=IsStr(), + ), + ModelResponse( + parts=[TextPart(content='Done!')], + usage=RequestUsage(input_tokens=57, output_tokens=7), + model_name='function:model_function_for_none_template:', + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ] + ) + async def test_run_with_deferred_tool_results_errors(): agent = Agent('test') @@ -6440,6 +7192,7 @@ def update_file(ctx: RunContext, path: str, content: str) -> str: content="File '.env' updated", tool_call_id='update_file_1', timestamp=IsDatetime(), + return_kind='tool-executed', ), UserPromptPart(content='continue with the operation', timestamp=IsDatetime()), ], @@ -6852,6 +7605,7 @@ def roll_dice() -> int: content=4, tool_call_id='pyd_ai_tool_call_id__roll_dice', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -6876,6 +7630,7 @@ def roll_dice() -> int: content='Final result processed.', tool_call_id='pyd_ai_tool_call_id__final_result', timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], 
run_id=IsStr(), @@ -6910,6 +7665,7 @@ def roll_dice() -> int: content=4, tool_call_id='pyd_ai_tool_call_id__roll_dice', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -6934,6 +7690,7 @@ def roll_dice() -> int: content='Final result processed.', tool_call_id='pyd_ai_tool_call_id__final_result', timestamp=IsDatetime(), + return_kind='final-result-processed', ) ], run_id=IsStr(), diff --git a/tests/test_agent_output_schemas.py b/tests/test_agent_output_schemas.py index 5c63343126..f6fb90df01 100644 --- a/tests/test_agent_output_schemas.py +++ b/tests/test_agent_output_schemas.py @@ -551,3 +551,20 @@ async def test_deferred_output_json_schema(): }, } ) + + +def test_build_instructions_appends_schema_placeholder(): + """Test that build_instructions appends {schema} when template doesn't contain it.""" + from pydantic_ai._output import OutputObjectDefinition, PromptedOutputSchema + + object_def = OutputObjectDefinition( + json_schema={'type': 'object', 'properties': {'name': {'type': 'string'}}}, + name='TestOutput', + description='A test output', + ) + template_without_schema = 'Please respond with JSON.' + + result = PromptedOutputSchema.build_instructions(template_without_schema, object_def) + assert result == snapshot( + 'Please respond with JSON.\n\n{"type": "object", "properties": {"name": {"type": "string"}}, "title": "TestOutput", "description": "A test output"}' + ) diff --git a/tests/test_dbos.py b/tests/test_dbos.py index de99f1b1d0..fbec58a658 100644 --- a/tests/test_dbos.py +++ b/tests/test_dbos.py @@ -371,7 +371,7 @@ async def test_complex_agent_run_in_workflow(allow_model_requests: None, dbos: D BasicSpan(content='ctx.run_step=1'), BasicSpan( content=IsStr( - regex=r'{"result":{"tool_name":"get_country","content":"Mexico","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}' + regex=r'{"result":{"tool_name":"get_country","content":"Mexico","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","metadata":null,"timestamp":".+?","part_kind":"tool-return","return_kind":"tool-executed"},"content":null,"event_kind":"function_tool_result"}' ) ), ], @@ -386,7 +386,7 @@ async def test_complex_agent_run_in_workflow(allow_model_requests: None, dbos: D BasicSpan(content='ctx.run_step=1'), BasicSpan( content=IsStr( - regex=r'{"result":{"tool_name":"get_product_name","content":"Pydantic AI","tool_call_id":"call_Xw9XMKBJU48kAAd78WgIswDx","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}' + regex=r'{"result":{"tool_name":"get_product_name","content":"Pydantic AI","tool_call_id":"call_Xw9XMKBJU48kAAd78WgIswDx","metadata":null,"timestamp":".+?","part_kind":"tool-return","return_kind":"tool-executed"},"content":null,"event_kind":"function_tool_result"}' ) ), ], @@ -450,7 +450,7 @@ async def test_complex_agent_run_in_workflow(allow_model_requests: None, dbos: D BasicSpan(content='ctx.run_step=2'), BasicSpan( content=IsStr( - regex=r'{"result":{"tool_name":"get_weather","content":"sunny","tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}' + regex=r'{"result":{"tool_name":"get_weather","content":"sunny","tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","metadata":null,"timestamp":".+?","part_kind":"tool-return","return_kind":"tool-executed"},"content":null,"event_kind":"function_tool_result"}' ) ), ], @@ 
-672,6 +672,7 @@ async def event_stream_handler( content='Mexico', tool_call_id='call_q2UyBRP7eXNTzAoR8lEhjc9Z', timestamp=IsDatetime(), + return_kind='tool-executed', ) ), FunctionToolResultEvent( @@ -680,6 +681,7 @@ async def event_stream_handler( content='Pydantic AI', tool_call_id='call_b51ijcpFkDiTQG1bQzsrmtW5', timestamp=IsDatetime(), + return_kind='tool-executed', ) ), PartStartEvent( @@ -721,6 +723,7 @@ async def event_stream_handler( content='sunny', tool_call_id='call_LwxJUB9KppVyogRRLQsamRJv', timestamp=IsDatetime(), + return_kind='tool-executed', ) ), PartStartEvent( @@ -1446,12 +1449,14 @@ async def hitl_main_loop(prompt: str) -> AgentRunResult[str | DeferredToolReques content=True, tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ), ToolReturnPart( tool_name='create_file', content='Success', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ), ], instructions='Just call tools without asking for confirmation.', @@ -1579,12 +1584,14 @@ def hitl_main_loop_sync(prompt: str) -> AgentRunResult[str | DeferredToolRequest content=True, tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ), ToolReturnPart( tool_name='create_file', content='Success', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ), ], instructions='Just call tools without asking for confirmation.', @@ -1720,6 +1727,7 @@ async def test_dbos_agent_with_model_retry(allow_model_requests: None, dbos: DBO content='sunny', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), diff --git a/tests/test_examples.py b/tests/test_examples.py index ccf2cb3174..6573ef8918 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -325,6 +325,7 @@ async def call_tool( text_responses: dict[str, str | ToolCallPart | Sequence[ToolCallPart]] = { + 'Hello': 'Hello! 
How can I help you today?', 'Use the web to get the current time.': "In San Francisco, it's 8:21:41 pm PDT on Wednesday, August 6, 2025.", 'Give me a sentence with the biggest news in AI this week.': 'Scientists have developed a universal AI detector that can identify deepfake videos.', 'How many days between 2000-01-01 and 2025-03-18?': 'There are 9,208 days between January 1, 2000, and March 18, 2025.', diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 02bab17cc3..f737b20d92 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -262,6 +262,7 @@ async def test_agent_with_stdio_server(allow_model_requests: None, agent: Agent) content=32.0, tool_call_id='call_QssdxTGkPblTYHmyVES1tKBj', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -433,6 +434,7 @@ async def test_tool_returning_str(allow_model_requests: None, agent: Agent): content='The weather in Mexico City is sunny and 26 degrees Celsius.', tool_call_id='call_m9goNwaHBbU926w47V7RtWPt', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -515,6 +517,7 @@ async def test_tool_returning_text_resource(allow_model_requests: None, agent: A content='Pydantic AI', tool_call_id='call_LaiWltzI39sdquflqeuF0EyE', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -593,6 +596,7 @@ async def test_tool_returning_text_resource_link(allow_model_requests: None, age content='Pydantic AI\n', tool_call_id='call_qi5GtBeIEyT7Y3yJvVFIi062', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -673,6 +677,7 @@ async def test_tool_returning_image_resource(allow_model_requests: None, agent: content='See file 1c8566', tool_call_id='call_nFsDHYDZigO0rOHqmChZ3pmt', timestamp=IsDatetime(), + return_kind='tool-executed', ), UserPromptPart(content=['This is file 1c8566:', image_content], timestamp=IsDatetime()), ], @@ -760,6 +765,7 @@ async def test_tool_returning_image_resource_link( content='See file 1c8566', tool_call_id='call_eVFgn54V9Nuh8Y4zvuzkYjUp', timestamp=IsDatetime(), + return_kind='tool-executed', ), UserPromptPart(content=['This is file 1c8566:', image_content], timestamp=IsDatetime()), ], @@ -828,6 +834,7 @@ async def test_tool_returning_audio_resource( content='See file 2d36ae', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ), UserPromptPart(content=['This is file 2d36ae:', audio_content], timestamp=IsDatetime()), ], @@ -900,6 +907,7 @@ async def test_tool_returning_audio_resource_link( content='See file 2d36ae', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ), UserPromptPart( content=[ @@ -981,6 +989,7 @@ async def test_tool_returning_image(allow_model_requests: None, agent: Agent, im content='See file 1c8566', tool_call_id='call_Q7xG8CCG0dyevVfUS0ubsDdN', timestamp=IsDatetime(), + return_kind='tool-executed', ), UserPromptPart( content=[ @@ -1060,6 +1069,7 @@ async def test_tool_returning_dict(allow_model_requests: None, agent: Agent): content={'foo': 'bar', 'baz': 123}, tool_call_id='call_oqKviITBj8PwpQjGyUu4Zu5x', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -1136,6 +1146,7 @@ async def test_tool_returning_unstructured_dict(allow_model_requests: None, agen content={'foo': 'bar', 'baz': 123}, tool_call_id='call_R0n2R7S9vL2aZOX25T9jahTd', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -1254,6 +1265,7 @@ async def test_tool_returning_error(allow_model_requests: None, agent: Agent): content='This is not an 
error', tool_call_id='call_4xGyvdghYKHN8x19KWkRtA5N', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -1330,6 +1342,7 @@ async def test_tool_returning_none(allow_model_requests: None, agent: Agent): content=[], tool_call_id='call_mJTuQ2Cl5SaHPTJbIILEUhJC', timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -1415,6 +1428,7 @@ async def test_tool_returning_multiple_items(allow_model_requests: None, agent: ], tool_call_id='call_kL0TvjEVQBDGZrn1Zv7iNYOW', timestamp=IsDatetime(), + return_kind='tool-executed', ), UserPromptPart( content=[ diff --git a/tests/test_prefect.py b/tests/test_prefect.py index b1c18b9803..715c80f251 100644 --- a/tests/test_prefect.py +++ b/tests/test_prefect.py @@ -305,7 +305,7 @@ async def run_complex_agent() -> Response: BasicSpan(content='ctx.run_step=1'), BasicSpan( content=IsStr( - regex=r'\{"result":\{"tool_name":"get_country","content":"Mexico","tool_call_id":"call_rI3WKPYvVwlOgCGRjsPP2hEx","metadata":null,"timestamp":"[^"]+","part_kind":"tool-return"\},"content":null,"event_kind":"function_tool_result"\}' + regex=r'\{"result":\{"tool_name":"get_country","content":"Mexico","tool_call_id":"call_rI3WKPYvVwlOgCGRjsPP2hEx","metadata":null,"timestamp":"[^"]+","part_kind":"tool-return","return_kind":"tool-executed"\},"content":null,"event_kind":"function_tool_result"\}' ) ), ], @@ -389,7 +389,7 @@ async def run_complex_agent() -> Response: BasicSpan(content='ctx.run_step=2'), BasicSpan( content=IsStr( - regex=r'\{"result":\{"tool_name":"get_weather","content":"sunny","tool_call_id":"call_NS4iQj14cDFwc0BnrKqDHavt","metadata":null,"timestamp":"[^"]+","part_kind":"tool-return"\},"content":null,"event_kind":"function_tool_result"\}' + regex=r'\{"result":\{"tool_name":"get_weather","content":"sunny","tool_call_id":"call_NS4iQj14cDFwc0BnrKqDHavt","metadata":null,"timestamp":"[^"]+","part_kind":"tool-return","return_kind":"tool-executed"\},"content":null,"event_kind":"function_tool_result"\}' ) ), ], @@ -406,7 +406,7 @@ async def run_complex_agent() -> Response: BasicSpan(content='ctx.run_step=2'), BasicSpan( content=IsStr( - regex=r'\{"result":\{"tool_name":"get_product_name","content":"Pydantic AI","tool_call_id":"call_SkGkkGDvHQEEk0CGbnAh2AQw","metadata":null,"timestamp":"[^"]+","part_kind":"tool-return"\},"content":null,"event_kind":"function_tool_result"\}' + regex=r'\{"result":\{"tool_name":"get_product_name","content":"Pydantic AI","tool_call_id":"call_SkGkkGDvHQEEk0CGbnAh2AQw","metadata":null,"timestamp":"[^"]+","part_kind":"tool-return","return_kind":"tool-executed"\},"content":null,"event_kind":"function_tool_result"\}' ) ), ], diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 9149d19d1b..9cddaeb8db 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -87,7 +87,11 @@ async def ret_a(x: str) -> str: ModelRequest( parts=[ ToolReturnPart( - tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr() + tool_name='ret_a', + content='a-apple', + timestamp=IsNow(tz=timezone.utc), + tool_call_id=IsStr(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -123,7 +127,11 @@ async def ret_a(x: str) -> str: ModelRequest( parts=[ ToolReturnPart( - tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr() + tool_name='ret_a', + content='a-apple', + timestamp=IsNow(tz=timezone.utc), + tool_call_id=IsStr(), + return_kind='tool-executed', ) ], run_id=IsStr(), @@ -179,7 +187,11 @@ async def ret_a(x: str) -> str: 
        ModelRequest(
            parts=[
                ToolReturnPart(
-                   tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr()
+                   tool_name='ret_a',
+                   content='a-apple',
+                   timestamp=IsNow(tz=timezone.utc),
+                   tool_call_id=IsStr(),
+                   return_kind='tool-executed',
                )
            ],
            run_id=IsStr(),
@@ -225,7 +237,11 @@ async def ret_a(x: str) -> str:
        ModelRequest(
            parts=[
                ToolReturnPart(
-                   tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr()
+                   tool_name='ret_a',
+                   content='a-apple',
+                   timestamp=IsNow(tz=timezone.utc),
+                   tool_call_id=IsStr(),
+                   return_kind='tool-executed',
                )
            ],
            run_id=IsStr(),
@@ -606,6 +622,7 @@ async def ret_a(x: str) -> str:
                    content='hello world',
                    timestamp=IsNow(tz=timezone.utc),
                    tool_call_id=IsStr(),
+                   return_kind='tool-executed',
                )
            ],
            run_id=IsStr(),
@@ -633,6 +650,7 @@ async def ret_a(x: str) -> str:
                    content='hello world',
                    timestamp=IsNow(tz=timezone.utc),
                    tool_call_id=IsStr(),
+                   return_kind='tool-executed',
                )
            ],
            run_id=IsStr(),
@@ -657,6 +675,7 @@ async def ret_a(x: str) -> str:
                    content='Final result processed.',
                    timestamp=IsNow(tz=timezone.utc),
                    tool_call_id=IsStr(),
+                   return_kind='final-result-processed',
                )
            ],
            run_id=IsStr(),
@@ -826,24 +845,28 @@ def deferred_tool(x: int) -> int:  # pragma: no cover
                    content='Final result processed.',
                    timestamp=IsNow(tz=timezone.utc),
                    tool_call_id=IsStr(),
+                   return_kind='final-result-processed',
                ),
                ToolReturnPart(
                    tool_name='regular_tool',
                    content='Tool not executed - a final result was already processed.',
                    timestamp=IsNow(tz=timezone.utc),
                    tool_call_id=IsStr(),
+                   return_kind='function-tool-not-executed',
                ),
                ToolReturnPart(
                    tool_name='another_tool',
                    content='Tool not executed - a final result was already processed.',
                    timestamp=IsNow(tz=timezone.utc),
                    tool_call_id=IsStr(),
+                   return_kind='function-tool-not-executed',
                ),
                ToolReturnPart(
                    tool_name='deferred_tool',
                    content='Tool not executed - a final result was already processed.',
                    timestamp=IsNow(tz=timezone.utc),
                    tool_call_id=IsStr(),
+                   return_kind='function-tool-not-executed',
                ),
            ],
            run_id=IsStr(),
@@ -918,12 +941,14 @@ async def stream_function(_: list[ModelMessage], info: AgentInfo) -> AsyncIterat
                    content='Final result processed.',
                    tool_call_id=IsStr(),
                    timestamp=IsNow(tz=timezone.utc),
+                   return_kind='final-result-processed',
                ),
                ToolReturnPart(
                    tool_name='second_output',
                    content='Output tool not used - a final result was already processed.',
                    tool_call_id=IsStr(),
                    timestamp=IsNow(tz=timezone.utc),
+                   return_kind='output-tool-not-executed',
                ),
            ],
            run_id=IsStr(),
@@ -970,12 +995,14 @@ async def sf(_: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str | Delt
                    content='Final result processed.',
                    timestamp=IsNow(tz=timezone.utc),
                    tool_call_id=IsStr(),
+                   return_kind='final-result-processed',
                ),
                ToolReturnPart(
                    tool_name='final_result',
                    content='Output tool not used - a final result was already processed.',
                    timestamp=IsNow(tz=timezone.utc),
                    tool_call_id=IsStr(),
+                   return_kind='output-tool-not-executed',
                ),
            ],
            run_id=IsStr(),
@@ -1076,18 +1103,21 @@ def deferred_tool(x: int) -> int:  # pragma: no cover
                    content='Final result processed.',
                    tool_call_id=IsStr(),
                    timestamp=IsNow(tz=datetime.timezone.utc),
+                   return_kind='final-result-processed',
                ),
                ToolReturnPart(
                    tool_name='regular_tool',
                    content='Tool not executed - a final result was already processed.',
                    tool_call_id=IsStr(),
                    timestamp=IsNow(tz=datetime.timezone.utc),
+                   return_kind='function-tool-not-executed',
                ),
                ToolReturnPart(
                    tool_name='another_tool',
                    content='Tool not executed - a final result was already processed.',
                    tool_call_id=IsStr(),
                    timestamp=IsNow(tz=datetime.timezone.utc),
+                   return_kind='function-tool-not-executed',
                ),
                RetryPromptPart(
                    content="Unknown tool name: 'unknown_tool'. Available tools: 'final_result', 'regular_tool', 'another_tool', 'deferred_tool'",
@@ -1100,6 +1130,7 @@ def deferred_tool(x: int) -> int:  # pragma: no cover
                    content='Tool not executed - a final result was already processed.',
                    tool_call_id=IsStr(),
                    timestamp=IsNow(tz=datetime.timezone.utc),
+                   return_kind='function-tool-not-executed',
                ),
            ],
            run_id=IsStr(),
@@ -1199,12 +1230,14 @@ def regular_tool(x: int) -> int:  # pragma: no cover
                    content='Output tool not used - a final result was already processed.',
                    tool_call_id=IsStr(),
                    timestamp=IsNow(tz=datetime.timezone.utc),
+                   return_kind='output-tool-not-executed',
                ),
                ToolReturnPart(
                    tool_name='regular_tool',
                    content='Tool not executed - a final result was already processed.',
                    tool_call_id=IsStr(),
                    timestamp=IsNow(tz=datetime.timezone.utc),
+                   return_kind='function-tool-not-executed',
                ),
            ],
            run_id=IsStr(),
@@ -1279,6 +1312,7 @@ def regular_tool(x: int) -> int:
                    content=1,
                    tool_call_id=IsStr(),
                    timestamp=IsNow(tz=datetime.timezone.utc),
+                   return_kind='tool-executed',
                )
            ],
            run_id=IsStr(),
@@ -1338,6 +1372,7 @@ def regular_tool(x: int) -> int:
                    content=0,
                    tool_call_id=IsStr(),
                    timestamp=IsNow(tz=datetime.timezone.utc),
+                   return_kind='tool-executed',
                )
            ],
            run_id=IsStr(),
@@ -1363,6 +1398,7 @@ def regular_tool(x: int) -> int:
                    content='Final result processed.',
                    tool_call_id=IsStr(),
                    timestamp=IsNow(tz=datetime.timezone.utc),
+                   return_kind='final-result-processed',
                )
            ],
            run_id=IsStr(),
@@ -1443,21 +1479,28 @@ def deferred_tool(x: int) -> int:  # pragma: no cover
                    content='Final result processed.',
                    timestamp=IsNow(tz=timezone.utc),
                    tool_call_id=IsStr(),
+                   return_kind='final-result-processed',
                ),
                ToolReturnPart(
                    tool_name='final_result',
                    content='Final result processed.',
                    timestamp=IsNow(tz=timezone.utc),
                    tool_call_id=IsStr(),
+                   return_kind='final-result-processed',
                ),
                ToolReturnPart(
                    tool_name='regular_tool',
                    content=42,
                    tool_call_id=IsStr(),
                    timestamp=IsNow(tz=timezone.utc),
+                   return_kind='tool-executed',
                ),
                ToolReturnPart(
-                   tool_name='another_tool', content=2, tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc)
+                   tool_name='another_tool',
+                   content=2,
+                   tool_call_id=IsStr(),
+                   timestamp=IsNow(tz=timezone.utc),
+                   return_kind='tool-executed',
                ),
                RetryPromptPart(
                    content="Unknown tool name: 'unknown_tool'. Available tools: 'final_result', 'regular_tool', 'another_tool', 'deferred_tool'",
@@ -1470,6 +1513,7 @@ def deferred_tool(x: int) -> int:  # pragma: no cover
                    content='Tool not executed - a final result was already processed.',
                    tool_call_id=IsStr(),
                    timestamp=IsNow(tz=timezone.utc),
+                   return_kind='function-tool-not-executed',
                ),
            ],
            run_id=IsStr(),
@@ -1544,12 +1588,14 @@ async def stream_function(_: list[ModelMessage], info: AgentInfo) -> AsyncIterat
                    content='Final result processed.',
                    tool_call_id=IsStr(),
                    timestamp=IsNow(tz=timezone.utc),
+                   return_kind='final-result-processed',
                ),
                ToolReturnPart(
                    tool_name='second_output',
                    content='Final result processed.',
                    tool_call_id=IsStr(),
                    timestamp=IsNow(tz=timezone.utc),
+                   return_kind='final-result-processed',
                ),
            ],
            run_id=IsStr(),
@@ -1705,12 +1751,14 @@ async def stream_function(_: list[ModelMessage], info: AgentInfo) -> AsyncIterat
                    content='Final result processed.',
                    tool_call_id=IsStr(),
                    timestamp=IsNow(tz=timezone.utc),
+                   return_kind='final-result-processed',
                ),
                ToolReturnPart(
                    tool_name='second_output',
                    content='Output tool not used - output failed validation.',
                    tool_call_id=IsStr(),
                    timestamp=IsNow(tz=timezone.utc),
+                   return_kind='output-validation-failed',
                ),
            ],
            run_id=IsStr(),
@@ -1789,6 +1837,7 @@ async def stream_function(_: list[ModelMessage], info: AgentInfo) -> AsyncIterat
                    content='Final result processed.',
                    tool_call_id=IsStr(),
                    timestamp=IsNow(tz=datetime.timezone.utc),
+                   return_kind='final-result-processed',
                ),
                RetryPromptPart(
                    content='Second output validation failed',
@@ -2100,6 +2149,7 @@ def known_tool(x: int) -> int:
                    content=10,
                    tool_call_id=IsStr(),
                    timestamp=IsNow(tz=timezone.utc),
+                   return_kind='tool-executed',
                ),
            ),
        ]
@@ -2346,6 +2396,7 @@ def my_tool(ctx: RunContext[None], x: int) -> int:
                    content=84,
                    tool_call_id='my_tool',
                    timestamp=IsDatetime(),
+                   return_kind='tool-executed',
                )
            ],
            run_id=IsStr(),
@@ -2529,6 +2580,7 @@ async def event_stream_handler(ctx: RunContext[None], stream: AsyncIterable[Agen
                    content='a-apple',
                    tool_call_id=IsStr(),
                    timestamp=IsNow(tz=timezone.utc),
+                   return_kind='tool-executed',
                )
            ),
            PartStartEvent(index=0, part=TextPart(content='')),
@@ -2575,6 +2627,7 @@ async def event_stream_handler(ctx: RunContext[None], stream: AsyncIterable[Agen
                    content='a-apple',
                    tool_call_id=IsStr(),
                    timestamp=IsNow(tz=timezone.utc),
+                   return_kind='tool-executed',
                )
            ),
            PartStartEvent(index=0, part=TextPart(content='')),
@@ -2624,6 +2677,7 @@ async def event_stream_handler(ctx: RunContext[None], stream: AsyncIterable[Agen
                    content='a-apple',
                    tool_call_id=IsStr(),
                    timestamp=IsNow(tz=timezone.utc),
+                   return_kind='tool-executed',
                )
            ),
            PartStartEvent(index=0, part=TextPart(content='')),
@@ -2667,6 +2721,7 @@ async def event_stream_handler(ctx: RunContext[None], stream: AsyncIterable[Agen
                    content='See file bd38f5',
                    tool_call_id=IsStr(),
                    timestamp=IsNow(tz=timezone.utc),
+                   return_kind='tool-executed',
                ),
                content=[
                    'This is file bd38f5:',
@@ -2716,6 +2771,7 @@ async def ret_a(x: str) -> str:
                    content='a-apple',
                    tool_call_id=IsStr(),
                    timestamp=IsNow(tz=timezone.utc),
+                   return_kind='tool-executed',
                )
            ),
            PartStartEvent(index=0, part=TextPart(content='')),
diff --git a/tests/test_temporal.py b/tests/test_temporal.py
index c12c37ded4..282dfbbd21 100644
--- a/tests/test_temporal.py
+++ b/tests/test_temporal.py
@@ -455,7 +455,7 @@ async def test_complex_agent_run_in_workflow(
            BasicSpan(content='ctx.run_step=1'),
            BasicSpan(
                content=IsStr(
regex=r'{"result":{"tool_name":"get_country","content":"Mexico","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}' + regex=r'{"result":{"tool_name":"get_country","content":"Mexico","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","metadata":null,"timestamp":".+?","part_kind":"tool-return"(,"return_kind":"tool-executed")?},"content":null,"event_kind":"function_tool_result"}' ) ), ], @@ -484,7 +484,7 @@ async def test_complex_agent_run_in_workflow( BasicSpan(content='ctx.run_step=1'), BasicSpan( content=IsStr( - regex=r'{"result":{"tool_name":"get_product_name","content":"Pydantic AI","tool_call_id":"call_Xw9XMKBJU48kAAd78WgIswDx","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}' + regex=r'{"result":{"tool_name":"get_product_name","content":"Pydantic AI","tool_call_id":"call_Xw9XMKBJU48kAAd78WgIswDx","metadata":null,"timestamp":".+?","part_kind":"tool-return"(,"return_kind":"tool-executed")?},"content":null,"event_kind":"function_tool_result"}' ) ), ], @@ -578,7 +578,7 @@ async def test_complex_agent_run_in_workflow( BasicSpan(content='ctx.run_step=2'), BasicSpan( content=IsStr( - regex=r'{"result":{"tool_name":"get_weather","content":"sunny","tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}' + regex=r'{"result":{"tool_name":"get_weather","content":"sunny","tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","metadata":null,"timestamp":".+?","part_kind":"tool-return"(,"return_kind":"tool-executed")?},"content":null,"event_kind":"function_tool_result"}' ) ), ], @@ -811,6 +811,7 @@ async def event_stream_handler( content='Mexico', tool_call_id='call_q2UyBRP7eXNTzAoR8lEhjc9Z', timestamp=IsDatetime(), + return_kind='tool-executed', ) ), FunctionToolResultEvent( @@ -819,6 +820,7 @@ async def event_stream_handler( content='Pydantic AI', tool_call_id='call_b51ijcpFkDiTQG1bQzsrmtW5', timestamp=IsDatetime(), + return_kind='tool-executed', ) ), PartStartEvent( @@ -860,6 +862,7 @@ async def event_stream_handler( content='sunny', tool_call_id='call_LwxJUB9KppVyogRRLQsamRJv', timestamp=IsDatetime(), + return_kind='tool-executed', ) ), PartStartEvent( @@ -1960,12 +1963,14 @@ async def test_temporal_agent_with_hitl_tool(allow_model_requests: None, client: content=True, tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ), ToolReturnPart( tool_name='create_file', content='Success', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ), ], instructions='Just call tools without asking for confirmation.', @@ -2119,6 +2124,7 @@ async def test_temporal_agent_with_model_retry(allow_model_requests: None, clien content='sunny', tool_call_id=IsStr(), timestamp=IsDatetime(), + return_kind='tool-executed', ) ], run_id=IsStr(), diff --git a/tests/test_tools.py b/tests/test_tools.py index 0031f702cd..2c6274ae23 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -30,6 +30,7 @@ ToolReturnPart, UserError, UserPromptPart, + prompt_config, ) from pydantic_ai.exceptions import ApprovalRequired, CallDeferred, ModelRetry, UnexpectedModelBehavior from pydantic_ai.models.function import AgentInfo, FunctionModel @@ -155,6 +156,50 @@ def test_docstring_google(docstring_format: Literal['google', 'auto']): ) +@pytest.mark.parametrize('docstring_format', ['google', 'auto']) +def 
+def test_docstring_google_prompt_config(docstring_format: Literal['google', 'auto']):
+    agent = Agent(FunctionModel(get_json_schema))
+    agent.tool_plain(docstring_format=docstring_format)(google_style_docstring)
+    p_config = prompt_config.PromptConfig(
+        tool_config={
+            'google_style_docstring': prompt_config.ToolConfig(
+                name=None,
+                tool_description=None,
+                strict=None,
+                tool_args_descriptions={
+                    'foo': 'The foo thing from tool config.',
+                    'bar': 'The bar thing from tool config.',
+                },
+            )
+        }
+    )
+
+    result = agent.run_sync('Hello', prompt_config=p_config)
+    json_schema = json.loads(result.output)
+
+    assert json_schema == snapshot(
+        {
+            'name': 'google_style_docstring',
+            'description': 'Do foobar stuff, a lot.',
+            'parameters_json_schema': {
+                'properties': {
+                    'foo': {'description': 'The foo thing from tool config.', 'type': 'integer'},
+                    'bar': {'description': 'The bar thing from tool config.', 'type': 'string'},
+                },
+                'required': ['foo', 'bar'],
+                'type': 'object',
+                'additionalProperties': False,
+            },
+            'outer_typed_dict_key': None,
+            'strict': None,
+            'kind': 'function',
+            'sequential': False,
+            'metadata': None,
+            'timeout': None,
+        }
+    )
+
+
def sphinx_style_docstring(foo: int, /) -> str:  # pragma: no cover
    """Sphinx style docstring.
@@ -1397,6 +1442,7 @@ def my_tool(ctx: RunContext[None], x: int) -> int:
                content=84,
                tool_call_id='my_tool',
                timestamp=IsDatetime(),
+               return_kind='tool-executed',
            )
        ],
        run_id=IsStr(),
@@ -1797,6 +1843,7 @@ def buy(fruit: str):
                tool_call_id='get_price_apple',
                metadata={'fruit': 'apple', 'price': 10.0},
                timestamp=IsDatetime(),
+               return_kind='tool-executed',
            ),
            RetryPromptPart(
                content='Unknown fruit: banana',
@@ -1810,6 +1857,7 @@ def buy(fruit: str):
                tool_call_id='get_price_pear',
                metadata={'fruit': 'pear', 'price': 10.0},
                timestamp=IsDatetime(),
+               return_kind='tool-executed',
            ),
            RetryPromptPart(
                content='Unknown fruit: grape',
@@ -1890,6 +1938,7 @@ def buy(fruit: str):
                tool_call_id='get_price_apple',
                metadata={'fruit': 'apple', 'price': 10.0},
                timestamp=IsDatetime(),
+               return_kind='tool-executed',
            ),
            RetryPromptPart(
                content='Unknown fruit: banana',
@@ -1903,6 +1952,7 @@ def buy(fruit: str):
                tool_call_id='get_price_pear',
                metadata={'fruit': 'pear', 'price': 10.0},
                timestamp=IsDatetime(),
+               return_kind='tool-executed',
            ),
            RetryPromptPart(
                content='Unknown fruit: grape',
@@ -1935,6 +1985,7 @@ def buy(fruit: str):
                tool_call_id='buy_banana',
                metadata={'fruit': 'banana', 'price': 100.0},
                timestamp=IsDatetime(),
+               return_kind='tool-executed',
            ),
            RetryPromptPart(
                content='The purchase of pears was denied.',
@@ -1975,6 +2026,7 @@ def buy(fruit: str):
                tool_call_id='buy_banana',
                metadata={'fruit': 'banana', 'price': 100.0},
                timestamp=IsDatetime(),
+               return_kind='tool-executed',
            ),
            RetryPromptPart(
                content='The purchase of pears was denied.',
@@ -2033,6 +2085,7 @@ def buy(fruit: str):
                tool_call_id='get_price_apple',
                metadata={'fruit': 'apple', 'price': 10.0},
                timestamp=IsDatetime(),
+               return_kind='tool-executed',
            ),
            RetryPromptPart(
                content='Unknown fruit: banana',
@@ -2046,6 +2099,7 @@ def buy(fruit: str):
                tool_call_id='get_price_pear',
                metadata={'fruit': 'pear', 'price': 10.0},
                timestamp=IsDatetime(),
+               return_kind='tool-executed',
            ),
            RetryPromptPart(
                content='Unknown fruit: grape',
@@ -2065,6 +2119,7 @@ def buy(fruit: str):
                tool_call_id='buy_banana',
                metadata={'fruit': 'banana', 'price': 100.0},
                timestamp=IsDatetime(),
+               return_kind='tool-executed',
            ),
            RetryPromptPart(
                content='The purchase of pears was denied.',
@@ -2185,6 +2240,7 @@ def bar(x: int) -> int:
                content=9,
                tool_call_id='bar',
                timestamp=IsDatetime(),
+               return_kind='tool-executed',
            )
        ],
        run_id=IsStr(),
@@ -2238,6 +2294,7 @@ def bar(x: int) -> int:
                content=9,
                tool_call_id='bar',
                timestamp=IsDatetime(),
+               return_kind='tool-executed',
            )
        ],
        run_id=IsStr(),
@@ -2249,12 +2306,14 @@ def bar(x: int) -> int:
                content=2,
                tool_call_id='foo1',
                timestamp=IsDatetime(),
+               return_kind='tool-executed',
            ),
            ToolReturnPart(
                tool_name='foo',
                content='The tool call was denied.',
                tool_call_id='foo2',
                timestamp=IsDatetime(),
+               return_kind='tool-denied',
            ),
        ],
        run_id=IsStr(),
@@ -2312,6 +2371,7 @@ def test_deferred_tool_results_serializable():
                'tool_call_id': 'foo',
                'timestamp': IsDatetime(),
                'part_kind': 'retry-prompt',
+               'retry_message': None,
            },
            'any': {'foo': 'bar'},
        },
@@ -2456,6 +2516,7 @@ def always_fail(ctx: RunContext[None]) -> str:
                content='I guess you never learn',
                tool_call_id=IsStr(),
                timestamp=IsDatetime(),
+               return_kind='tool-executed',
            )
        ],
        run_id=IsStr(),
diff --git a/tests/test_toolsets.py b/tests/test_toolsets.py
index 9a4a344d12..6fb103de9c 100644
--- a/tests/test_toolsets.py
+++ b/tests/test_toolsets.py
@@ -8,10 +8,12 @@
import pytest
from inline_snapshot import snapshot
+from pydantic import BaseModel, Field
from typing_extensions import Self
from pydantic_ai import (
    AbstractToolset,
+   Agent,
    CombinedToolset,
    FilteredToolset,
    FunctionToolset,
@@ -25,8 +27,14 @@
from pydantic_ai._tool_manager import ToolManager
from pydantic_ai.exceptions import ModelRetry, ToolRetryError, UnexpectedModelBehavior, UserError
from pydantic_ai.models.test import TestModel
+from pydantic_ai.prompt_config import (
+    DEFAULT_FINAL_RESULT_PROCESSED,
+    PromptConfig,
+    ToolConfig,
+)
from pydantic_ai.tools import ToolDefinition
from pydantic_ai.toolsets._dynamic import DynamicToolset
+from pydantic_ai.toolsets.prepared import ToolConfigPreparedToolset
from pydantic_ai.usage import RunUsage
pytestmark = pytest.mark.anyio
@@ -34,6 +42,21 @@
T = TypeVar('T')
+class DoubleNestedArg(BaseModel):
+    """Deeply nested configuration."""
+
+    c: str = Field(description='The C parameter for deep config.')
+    d: str = Field(description='The D parameter for deep config.')
+
+
+class NestedArg(BaseModel):
+    """Nested configuration for a tool."""
+
+    a: str = Field(description='The A parameter.')
+    b: str = Field(description='The B parameter.')
+    nested: DoubleNestedArg = Field(description='Nested deep configuration.')
+
+
def build_run_context(deps: T, run_step: int = 0) -> RunContext[T]:
    return RunContext(
        deps=deps,
@@ -483,6 +506,73 @@ async def prepare_add_context(ctx: RunContext[TestDeps], tool_defs: list[ToolDef
        ]
    )
+    partial_args_toolset = FunctionToolset[None]()
+
+    @partial_args_toolset.tool
+    def calc(x: int, y: int, z: int, arg: NestedArg) -> int:  # pragma: no cover
+        """Calculate sum"""
+        return x + y + z
+
+    partial_tool_config = {
+        'calc': ToolConfig(
+            tool_args_descriptions={
+                'x': 'First number',
+                'z': 'Third number',
+                # 'y' intentionally missing
+                'arg.b': 'Nested b argument',
+            }
+        )
+    }
+    prepared_partial = ToolConfigPreparedToolset(partial_args_toolset, partial_tool_config)
+    partial_context = build_run_context(None)
+    tool_config_prepared_toolset = await ToolManager[None](prepared_partial).for_run_step(partial_context)
+
+    assert tool_config_prepared_toolset.tool_defs == snapshot(
+        [
+            ToolDefinition(
+                name='calc',
+                parameters_json_schema={
+                    '$defs': {
+                        'DoubleNestedArg': {
+                            'description': 'Deeply nested configuration.',
+                            'properties': {
+                                'c': {'description': 'The C parameter for deep config.', 'type': 'string'},
+                                'd': {'description': 'The D parameter for deep config.', 'type': 'string'},
+                            },
+                            'required': ['c', 'd'],
+                            'title': 'DoubleNestedArg',
+                            'type': 'object',
+                        },
+                        'NestedArg': {
+                            'description': 'Nested configuration for a tool.',
+                            'properties': {
+                                'a': {'description': 'The A parameter.', 'type': 'string'},
+                                'b': {'type': 'string', 'description': 'Nested b argument'},
+                                'nested': {
+                                    '$ref': '#/$defs/DoubleNestedArg',
+                                    'description': 'Nested deep configuration.',
+                                },
+                            },
+                            'required': ['a', 'b', 'nested'],
+                            'title': 'NestedArg',
+                            'type': 'object',
+                        },
+                    },
+                    'additionalProperties': False,
+                    'properties': {
+                        'x': {'type': 'integer', 'description': 'First number'},
+                        'y': {'type': 'integer'},
+                        'z': {'type': 'integer', 'description': 'Third number'},
+                        'arg': {'$ref': '#/$defs/NestedArg'},
+                    },
+                    'required': ['x', 'y', 'z', 'arg'],
+                    'type': 'object',
+                },
+                description='Calculate sum',
+            )
+        ]
+    )
+
async def test_context_manager():
    try:
@@ -822,3 +912,65 @@ def no_toolset_func(ctx: RunContext[None]) -> None:
    assert tools == {}
    assert toolset._toolset is None  # pyright: ignore[reportPrivateUsage]
+
+
+async def test_generate_prompt_config_from_agent():
+    """Test generating a PromptConfig from an Agent - demonstrates optimizer workflow.
+
+    This test shows how an optimizer would:
+    1. Create an agent with tools (including nested BaseModel arguments)
+    2. Generate a complete PromptConfig with all defaults filled in
+    3. Use the generated config as a starting point for prompt optimization
+    """
+    # Create an agent with tools that have nested BaseModel arguments
+    agent: Agent[None, str] = Agent('test')
+
+    @agent.tool_plain
+    def calculate(
+        x: int,
+        y: int,
+        config: NestedArg,
+    ) -> str:
+        """Perform a calculation with configuration.
+
+        Args:
+            x: The first operand.
+            y: The second operand.
+            config: Configuration for the calculation.
+        """
+        return f'{x} + {y} with config {config}'  # pragma: no cover
+
+    @agent.tool_plain
+    def simple_tool(name: str) -> str:
+        """A simple tool with no nested args.
+
+        Args:
+            name: The name to greet.
+        """
+        return f'Hello {name}'  # pragma: no cover
+
+    model = TestModel()
+    generated_config = await PromptConfig.generate_prompt_config_from_agent(agent, model)
+
+    assert generated_config.templates is not None
+    templates = generated_config.templates
+
+    assert templates.final_result_processed == DEFAULT_FINAL_RESULT_PROCESSED
+
+    assert generated_config.tool_config is not None
+    tool_config = generated_config.tool_config
+
+    calc_config = tool_config['calculate']
+
+    assert calc_config.tool_args_descriptions == snapshot(
+        {
+            'x': 'The first operand.',
+            'y': 'The second operand.',
+            'config': 'Configuration for the calculation.',
+            'config.a': 'The A parameter.',
+            'config.b': 'The B parameter.',
+            'config.nested': 'Nested deep configuration.',
+            'config.nested.c': 'The C parameter for deep config.',
+            'config.nested.d': 'The D parameter for deep config.',
+        }
+    )
diff --git a/tests/test_usage_limits.py b/tests/test_usage_limits.py
index ac17fd0be5..c49cce3c97 100644
--- a/tests/test_usage_limits.py
+++ b/tests/test_usage_limits.py
@@ -123,6 +123,7 @@ async def ret_a(x: str) -> str:
                content='a-apple',
                timestamp=IsNow(tz=timezone.utc),
                tool_call_id=IsStr(),
+               return_kind='tool-executed',
            )
        ],
        run_id=IsStr(),