diff --git a/docs/docs/api/optimizers/GEPA/GEPA_Advanced.md b/docs/docs/api/optimizers/GEPA/GEPA_Advanced.md index 624e580ad1..6eb49b0629 100644 --- a/docs/docs/api/optimizers/GEPA/GEPA_Advanced.md +++ b/docs/docs/api/optimizers/GEPA/GEPA_Advanced.md @@ -443,3 +443,146 @@ gepa = dspy.GEPA( auto="medium" ) ``` + +## Tool Optimization + +### What is enable_tool_optimization? + +When `enable_tool_optimization=True`, GEPA jointly optimizes `dspy.ReAct` modules: predictor instructions and tool descriptions and argument descriptions are updated together, instead of being tuned in isolation. This lets the model learn better patterns for when to call a tool and how to use it from the same execution traces and feedback that drive core GEPA. + +### Usage and constraints + +- **Expose tools as `dspy.Tool` in signatures and examples.** GEPA only optimizes tools that are represented as `dspy.Tool` and actually passed as `dspy.Tool` objects into your modules. +- **Treat `Tool.name` as a stable identifier.** `Tool.name` is the tool's name, and GEPA uses it to attach improved descriptions and argument descriptions. If you reuse the same `Tool.name` for different tools, they will share the same text updates. +- **Avoid custom tools named `"finish"`.** The built-in ReAct `"finish"` tool is reserved and excluded from optimization. Custom tools with the name `"finish"` are also not optimized. +- **Custom instruction proposers handle all modules and tool updates.** When you provide an `instruction_proposer`, GEPA routes every optimized module through your proposer instead of the built-in instruction proposer. If `enable_tool_optimization=True`, modules that call tools are still included, and your proposer is also responsible for updating their tool descriptions and argument descriptions. + +### Tool Module Optimization Prompt + +GEPA uses `ToolProposer` to optimize ReAct modules when `enable_tool_optimization=True`. For each module, the proposer builds a dynamic signature from the base `GenerateImprovedToolModuleDescriptionsFromFeedback` signature shown below, then appends output fields for each tool description and each tool argument description in that module. For ReAct modules, the proposer also appends input and output fields for the extract instruction. + +```python +class GenerateImprovedToolModuleDescriptionsFromFeedback(dspy.Signature): + """I provided an assistant with predictor instructions and tool descriptions, + but its performance needs improvement based on the examples_with_feedback below. + + Your task is to propose better predictor instructions, tool descriptions, and + tool argument descriptions that address the issues shown in these examples. + Focus on reinforcing patterns that clearly improve the assistant's performance + on similar tasks, rather than rewriting everything from scratch unless necessary. + These components are progressively optimized - refine only what needs to change. + + Analyze the examples_with_feedback to identify success and failure patterns, + and write improved instructions and descriptions at their appropriate level + of abstraction and/or specificity, so that each layer plays a clear, + complementary role without unnecessary repetition or verbosity unless + redundancy clearly helps the assistant's performance. 
+ """ + + current_predictor_instruction = dspy.InputField( + desc="Current instruction guiding the predictor" + ) + current_tools = dspy.InputField( + annotation=list[dspy.Tool], + desc="Available tools with their complete schemas" + ) + examples_with_feedback = dspy.InputField( + desc="Execution examples with feedback showing successes and failures" + ) + + improved_predictor_instruction: str | None = dspy.OutputField( + desc="Improved instruction for the predictor", + default=None + ) + + # GEPA appends output fields dynamically for each tool and argument: + # - improved_tool_{name}_desc with desc="Improved description of tool '{name}'" + # - improved_tool_{name}_arg_{param}_desc with desc="Improved description of the argument '{param}' of tool '{name}'" + # For ReAct modules, GEPA also appends: + # - current_extract_instruction (input) with desc="Current instruction for extraction predictor" + # - improved_extract_instruction (output) with desc="Improved instruction for extraction" +``` + +The reflection LM uses this dynamically-built signature to jointly propose updates across predictor instructions, tool descriptions, and argument descriptions based on execution feedback. Updates are coordinated rather than made in isolation: the LM sees all current components together and can selectively update any subset by returning new text, or return `None` to keep a component unchanged. + +### How Tool Optimization Works + +When `enable_tool_optimization=True`, GEPA: + +1. **Discovers ReAct modules** - Identifies `dspy.ReAct` modules and their associated tools +2. **Treats them as joint optimization units** - Instead of only optimizing predictor instructions, GEPA optimizes predictor instructions and tool descriptions together as a coordinated set; for ReAct this includes both the react and extract instructions +3. **Routes to specialized proposer** - Separates components by type and routes them appropriately: + - **With custom `instruction_proposer`**: Your custom proposer receives both ReAct modules and plain predictors, and is responsible for updating all components + - **With default proposer**: Plain predictors use the default instruction proposer; ReAct modules use `ToolProposer`, which employs the dynamic signature mechanism described above +4. **Optimizes jointly** - `ToolProposer` improves predictor instructions and tool descriptions together based on execution feedback, coordinating updates across all components rather than tuning them in isolation +5. **Applies updates** - Improved instructions update predictor signatures; improved tool descriptions and argument descriptions update all `dspy.Tool` objects with matching tool names throughout the program + +Modules without tools (like `dspy.Predict` or `dspy.ChainOfThought`) continue using standard GEPA instruction-only optimization. + +### When to Use Tool Optimization + +Enable `enable_tool_optimization=True` when tools are central to your program's behavior and you want GEPA to jointly optimize predictor instructions and tool descriptions together. Common scenarios: + +1. **Wrong tool selection** - Predictor with `search` and `weather` tools keeps searching when it should check weather, or vice versa. GEPA refines predictor instructions and tool descriptions to clarify when to use each tool. + +2. **Underused tools** - Predictor responds "I don't know" without using available tools that could answer the question. GEPA improves predictor instructions to be more proactive about tool usage. + +3. 
**Tool call loops** - Agent keeps calling `web_search` multiple times with similar queries instead of synthesizing information. GEPA improves instructions to encourage synthesis and tool descriptions to clarify when searches are sufficient. + +4. **Extraction failures (ReAct)** - Agent executes tools correctly but fails to extract the final answer from the trajectory. GEPA improves extract instruction to better identify and format answers from tool outputs. + +5. **Multi-agent delegation** - Parent agent has delegation tools to specialized sub-agents but doesn't understand when to use each. GEPA optimizes instructions and tool descriptions across both parent and sub-agent modules for coherent delegation. + +See the usage example below for tool-using programs. + +### Usage Example + +```python +import dspy + +def search_web(query: str) -> str: + return f"Search results for: {query}" + +def get_weather(city: str) -> str: + """Get the current weather for a city.""" + return f"The weather in {city} is sunny and 75°F" + +# Create tools with basic descriptions +search_tool = dspy.Tool(search_web, name="search_web", desc="Search tool") +weather_tool = dspy.Tool(get_weather, name="get_weather", desc="Weather tool") + +program = dspy.ReAct("question -> answer", tools=[search_tool, weather_tool]) + +# Enable tool optimization +gepa = dspy.GEPA( + metric=my_metric, + reflection_lm=dspy.LM(model="gpt-5-mini"), + enable_tool_optimization=True, + auto="medium" +) + +optimized_program = gepa.compile(program, trainset=train_examples, valset=val_examples) +``` + +### Inspecting Optimized Programs + +View optimization results and metadata (requires `track_stats=True`): + +```python +# High-level optimization metadata +optimized_program.detailed_results +``` + +Access optimized instructions and tool descriptions directly: + +```python +# Predictor instructions +for name, predictor in optimized_program.named_predictors(): + print(f"{name}: {predictor.signature.instructions}") + +# Tool descriptions and argument descriptions +for tool_name, tool in optimized_program.tools.items(): + print(f"{tool_name}: {tool.desc}") + for arg_name, arg_schema in tool.args.items(): + print(f" {arg_name}: {arg_schema.get('description', 'N/A')}") +``` diff --git a/docs/docs/api/optimizers/GEPA/overview.md b/docs/docs/api/optimizers/GEPA/overview.md index 0125702bea..4bca3b1cc8 100644 --- a/docs/docs/api/optimizers/GEPA/overview.md +++ b/docs/docs/api/optimizers/GEPA/overview.md @@ -117,6 +117,12 @@ Practical Recipe for GEPA-Friendly Feedback: - **Multi-Objective Tasks** (e.g., PUPA): Decompose aggregate scores to reveal contributions from each objective, highlighting tradeoffs (e.g., quality vs. privacy). - **Stacked Pipelines** (e.g., code generation: parse → compile → run → profile → evaluate): Expose stage-specific failures; natural-language traces often suffice for LLM self-correction. +## Tool Optimization with GEPA + +When `enable_tool_optimization=True`, GEPA jointly optimizes `dspy.ReAct` modules with the tools - GEPA updates predictor instructions and tool descriptions/argument descriptions together, based on execution traces and feedback, instead of keeping tool behavior fixed. + +For details, examples, and the underlying design (tool discovery, naming requirements, and interaction with custom instruction proposers), see [Tool Optimization](GEPA_Advanced.md#tool-optimization). 
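
For quick orientation, here is a minimal sketch of turning the flag on, mirroring the usage example in the advanced guide; `my_metric`, `train_examples`, and `val_examples` are assumed to be defined elsewhere, and the fuller walkthrough lives in the linked Tool Optimization section.

```python
import dspy

def get_weather(city: str) -> str:
    """Return a stub weather report for a city."""
    return f"The weather in {city} is sunny."

# A tool-using program whose tool and argument descriptions GEPA may rewrite.
weather_tool = dspy.Tool(get_weather, name="get_weather", desc="Weather tool")
program = dspy.ReAct("question -> answer", tools=[weather_tool])

gepa = dspy.GEPA(
    metric=my_metric,                           # assumed: a metric returning score + feedback
    reflection_lm=dspy.LM(model="gpt-5-mini"),
    enable_tool_optimization=True,              # jointly optimize instructions and tool text
    auto="light",
)
optimized = gepa.compile(program, trainset=train_examples, valset=val_examples)
```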
+ ## Custom Instruction Proposal For advanced customization of GEPA's instruction proposal mechanism, including custom instruction proposers and component selectors, see [Advanced Features](GEPA_Advanced.md). diff --git a/dspy/teleprompt/gepa/gepa.py b/dspy/teleprompt/gepa/gepa.py index c35e916691..f79cd65053 100644 --- a/dspy/teleprompt/gepa/gepa.py +++ b/dspy/teleprompt/gepa/gepa.py @@ -1,4 +1,5 @@ import inspect +import json import logging import random from dataclasses import dataclass @@ -9,8 +10,15 @@ from gepa.proposer.reflective_mutation.base import ReflectionComponentSelector from dspy.clients.lm import LM +from dspy.predict.react import ReAct from dspy.primitives import Example, Module, Prediction -from dspy.teleprompt.gepa.gepa_utils import DspyAdapter, DSPyTrace, PredictorFeedbackFn, ScoreWithFeedback +from dspy.teleprompt.gepa.gepa_utils import ( + TOOL_MODULE_PREFIX, + DspyAdapter, + DSPyTrace, + PredictorFeedbackFn, + ScoreWithFeedback, +) from dspy.teleprompt.teleprompt import Teleprompter from dspy.utils.annotation import experimental @@ -22,6 +30,7 @@ "heavy": {"n": 18}, } + @experimental(version="3.0.0") class GEPAFeedbackMetric(Protocol): def __call__( @@ -36,22 +45,23 @@ def __call__( - gold: The gold example. - pred: The predicted output. - trace: Optional. The trace of the program's execution. - - pred_name: Optional. The name of the target predictor currently being optimized by GEPA, for which + - pred_name: Optional. The name of the target predictor currently being optimized by GEPA, for which the feedback is being requested. - pred_trace: Optional. The trace of the target predictor's execution GEPA is seeking feedback for. Note the `pred_name` and `pred_trace` arguments. During optimization, GEPA will call the metric to obtain feedback for individual predictors being optimized. GEPA provides the name of the predictor in `pred_name` and the sub-trace (of the trace) corresponding to the predictor in `pred_trace`. - If available at the predictor level, the metric should return dspy.Prediction(score: float, feedback: str) corresponding - to the predictor. + If available at the predictor level, the metric should return dspy.Prediction(score: float, feedback: str) + corresponding to the predictor. If not available at the predictor level, the metric can also return a text feedback at the program level (using just the gold, pred and trace). - If no feedback is returned, GEPA will use a simple text feedback consisting of just the score: + If no feedback is returned, GEPA will use a simple text feedback consisting of just the score: f"This trajectory got a score of {score}." """ ... + @experimental(version="3.0.0") @dataclass(frozen=True) class DspyGEPAResult: @@ -74,6 +84,7 @@ class DspyGEPAResult: - best_idx: candidate index with the highest val_aggregate_scores - best_candidate: the program text mapping for best_idx """ + # Data about the proposed candidates candidates: list[Module] parents: list[list[int | None]] @@ -108,10 +119,7 @@ def highest_score_achieved_per_val_task(self) -> list[float]: ] def to_dict(self) -> dict[str, Any]: - cands = [ - {k: v for k, v in cand.items()} - for cand in self.candidates - ] + cands = [{k: v for k, v in cand.items()} for cand in self.candidates] return dict( candidates=cands, @@ -144,6 +152,7 @@ def from_gepa_result(gepa_result: "GEPAResult", adapter: "DspyAdapter") -> "Dspy seed=gepa_result.seed, ) + @experimental(version="3.0.0") class GEPA(Teleprompter): """ @@ -172,18 +181,18 @@ def metric( - gold: The gold example. 
- pred: The predicted output. - trace: Optional. The trace of the program's execution. - - pred_name: Optional. The name of the target predictor currently being optimized by GEPA, for which + - pred_name: Optional. The name of the target predictor currently being optimized by GEPA, for which the feedback is being requested. - pred_trace: Optional. The trace of the target predictor's execution GEPA is seeking feedback for. Note the `pred_name` and `pred_trace` arguments. During optimization, GEPA will call the metric to obtain feedback for individual predictors being optimized. GEPA provides the name of the predictor in `pred_name` and the sub-trace (of the trace) corresponding to the predictor in `pred_trace`. - If available at the predictor level, the metric should return {'score': float, 'feedback': str} corresponding + If available at the predictor level, the metric should return {'score': float, 'feedback': str} corresponding to the predictor. If not available at the predictor level, the metric can also return a text feedback at the program level (using just the gold, pred and trace). - If no feedback is returned, GEPA will use a simple text feedback consisting of just the score: + If no feedback is returned, GEPA will use a simple text feedback consisting of just the score: f"This trajectory got a score of {score}." \""" ... @@ -207,123 +216,129 @@ def metric( max_full_evals: The maximum number of full evaluations to perform. max_metric_calls: The maximum number of metric calls to perform. reflection_minibatch_size: The number of examples to use for reflection in a single GEPA step. Default is 3. - candidate_selection_strategy: The strategy to use for candidate selection. Default is "pareto", - which stochastically selects candidates from the Pareto frontier of all validation scores. + candidate_selection_strategy: The strategy to use for candidate selection. Default is "pareto", + which stochastically selects candidates from the Pareto frontier of all validation scores. Options: "pareto", "current_best". - reflection_lm: The language model to use for reflection. Required parameter. GEPA benefits from - a strong reflection model. Consider using `dspy.LM(model='gpt-5', temperature=1.0, max_tokens=32000)` + reflection_lm: The language model to use for reflection. Required parameter. GEPA benefits from + a strong reflection model. Consider using `dspy.LM(model='gpt-5', temperature=1.0, max_tokens=32000)` for optimal performance. skip_perfect_score: Whether to skip examples with perfect scores during reflection. Default is True. instruction_proposer: Optional custom instruction proposer implementing GEPA's ProposalFn protocol. - **Default: None (recommended for most users)** - Uses GEPA's proven instruction proposer from - the [GEPA library](https://github.com/gepa-ai/gepa), which implements the - [`ProposalFn`](https://github.com/gepa-ai/gepa/blob/main/src/gepa/core/adapter.py). This default - proposer is highly capable and was validated across diverse experiments reported in the GEPA + **Default: None (recommended for most users)** - Uses GEPA's proven instruction proposer from + the [GEPA library](https://github.com/gepa-ai/gepa), which implements the + [`ProposalFn`](https://github.com/gepa-ai/gepa/blob/main/src/gepa/core/adapter.py). This default + proposer is highly capable and was validated across diverse experiments reported in the GEPA paper and tutorials. 
- See documentation on custom instruction proposers + See documentation on custom instruction proposers [here](https://dspy.ai/api/optimizers/GEPA/GEPA_Advanced/#custom-instruction-proposers). - + **Advanced Feature**: Only needed for specialized scenarios: - **Multi-modal handling**: Processing dspy.Image inputs alongside textual information - - **Nuanced control over constraints**: Fine-grained control over instruction length, format, + - **Nuanced control over constraints**: Fine-grained control over instruction length, format, and structural requirements beyond standard feedback mechanisms - - **Domain-specific knowledge injection**: Specialized terminology or context that cannot be + - **Domain-specific knowledge injection**: Specialized terminology or context that cannot be provided through feedback_func alone - - **Provider-specific prompting**: Optimizations for specific LLM providers (OpenAI, Anthropic) + - **Provider-specific prompting**: Optimizations for specific LLM providers (OpenAI, Anthropic) with unique formatting preferences - - **Coupled component updates**: Coordinated updates of multiple components together rather + - **Coupled component updates**: Coordinated updates of multiple components together rather than independent optimization - **External knowledge integration**: Runtime access to databases, APIs, or knowledge bases - - The default proposer handles the vast majority of use cases effectively. Use - MultiModalInstructionProposer() from dspy.teleprompt.gepa.instruction_proposal for visual + + The default proposer handles the vast majority of use cases effectively. Use + MultiModalInstructionProposer() from dspy.teleprompt.gepa.instruction_proposal for visual content or implement custom ProposalFn for highly specialized requirements. - - Note: When both instruction_proposer and reflection_lm are set, the instruction_proposer is called - in the reflection_lm context. However, reflection_lm is optional when using a custom instruction_proposer. + + Note: When both instruction_proposer and reflection_lm are set, the instruction_proposer is called + in the reflection_lm context. However, reflection_lm is optional when using a custom instruction_proposer. Custom instruction proposers can invoke their own LLMs if needed. component_selector: Custom component selector implementing the [ReflectionComponentSelector](https://github.com/gepa-ai/gepa/blob/main/src/gepa/proposer/reflective_mutation/base.py) protocol, - or a string specifying a built-in selector strategy. Controls which components (predictors) are selected - for optimization at each iteration. Defaults to 'round_robin' strategy which cycles through components - one at a time. Available string options: 'round_robin' (cycles through components sequentially), - 'all' (selects all components for simultaneous optimization). Custom selectors can implement strategies - using LLM-driven selection logic based on optimization state and trajectories. - See [gepa component selectors](https://github.com/gepa-ai/gepa/blob/main/src/gepa/strategies/component_selector.py) + or a string specifying a built-in selector strategy. Controls which components (predictors) are selected + for optimization at each iteration. Defaults to 'round_robin' strategy which cycles through components + one at a time. Available string options: 'round_robin' (cycles through components sequentially), + 'all' (selects all components for simultaneous optimization). 
Custom selectors can implement strategies + using LLM-driven selection logic based on optimization state and trajectories. + See [gepa component selectors](https://github.com/gepa-ai/gepa/blob/main/src/gepa/strategies/component_selector.py) for available built-in selectors and the ReflectionComponentSelector protocol for implementing custom selectors. add_format_failure_as_feedback: Whether to add format failures as feedback. Default is False. use_merge: Whether to use merge-based optimization. Default is True. max_merge_invocations: The maximum number of merge invocations to perform. Default is 5. num_threads: The number of threads to use for evaluation with `Evaluate`. Optional. failure_score: The score to assign to failed examples. Default is 0.0. - perfect_score: The maximum score achievable by the metric. Default is 1.0. Used by GEPA + perfect_score: The maximum score achievable by the metric. Default is 1.0. Used by GEPA to determine if all examples in a minibatch are perfect. - log_dir: The directory to save the logs. GEPA saves elaborate logs, along with all candidate - programs, in this directory. Running GEPA with the same `log_dir` will resume the run + log_dir: The directory to save the logs. GEPA saves elaborate logs, along with all candidate + programs, in this directory. Running GEPA with the same `log_dir` will resume the run from the last checkpoint. - track_stats: Whether to return detailed results and all proposed programs in the `detailed_results` + track_stats: Whether to return detailed results and all proposed programs in the `detailed_results` attribute of the optimized program. Default is False. use_wandb: Whether to use wandb for logging. Default is False. - wandb_api_key: The API key to use for wandb. If not provided, wandb will use the API key + wandb_api_key: The API key to use for wandb. If not provided, wandb will use the API key from the environment variable `WANDB_API_KEY`. wandb_init_kwargs: Additional keyword arguments to pass to `wandb.init`. - track_best_outputs: Whether to track the best outputs on the validation set. track_stats must - be True if track_best_outputs is True. The optimized program's `detailed_results.best_outputs_valset` + track_best_outputs: Whether to track the best outputs on the validation set. track_stats must + be True if track_best_outputs is True. The optimized program's `detailed_results.best_outputs_valset` will contain the best outputs for each task in the validation set. - warn_on_score_mismatch: GEPA (currently) expects the metric to return the same module-level score when - called with and without the pred_name. This flag (defaults to True) determines whether a warning is + warn_on_score_mismatch: GEPA (currently) expects the metric to return the same module-level score when + called with and without the pred_name. This flag (defaults to True) determines whether a warning is raised if a mismatch in module-level and predictor-level score is detected. + enable_tool_optimization: Whether to enable joint optimization of dspy.ReAct modules. + When enabled, GEPA jointly optimizes predictor instructions and tool descriptions together + for dspy.ReAct modules. See the + [Tool Optimization guide](https://dspy.ai/api/optimizers/GEPA/GEPA_Advanced/#tool-optimization) + for details on when to use this feature and how it works. Default is False. seed: The random seed to use for reproducibility. Default is 0. 
gepa_kwargs: (Optional) Additional keyword arguments to pass directly to [gepa.optimize](https://github.com/gepa-ai/gepa/blob/main/src/gepa/api.py). Useful for accessing advanced GEPA features not directly exposed through DSPy's GEPA interface. - + Available parameters: - - batch_sampler: Strategy for selecting training examples. Can be a [BatchSampler](https://github.com/gepa-ai/gepa/blob/main/src/gepa/strategies/batch_sampler.py) instance or a string + - batch_sampler: Strategy for selecting training examples. Can be a [BatchSampler](https://github.com/gepa-ai/gepa/blob/main/src/gepa/strategies/batch_sampler.py) instance or a string ('epoch_shuffled'). Defaults to 'epoch_shuffled'. Only valid when reflection_minibatch_size is None. - - merge_val_overlap_floor: Minimum number of shared validation ids required between parents before - attempting a merge subsample. Only relevant when using `val_evaluation_policy` other than 'full_eval'. + - merge_val_overlap_floor: Minimum number of shared validation ids required between parents before + attempting a merge subsample. Only relevant when using `val_evaluation_policy` other than 'full_eval'. Default is 5. - - stop_callbacks: Optional stopper(s) that return True when optimization should stop. Can be a single - [StopperProtocol](https://github.com/gepa-ai/gepa/blob/main/src/gepa/utils/stop_condition.py) or a list of StopperProtocol instances. - Examples: [FileStopper](https://github.com/gepa-ai/gepa/blob/main/src/gepa/utils/stop_condition.py), - [TimeoutStopCondition](https://github.com/gepa-ai/gepa/blob/main/src/gepa/utils/stop_condition.py), - [SignalStopper](https://github.com/gepa-ai/gepa/blob/main/src/gepa/utils/stop_condition.py), - [NoImprovementStopper](https://github.com/gepa-ai/gepa/blob/main/src/gepa/utils/stop_condition.py), - or custom stopping logic. Note: This overrides the default + - stop_callbacks: Optional stopper(s) that return True when optimization should stop. Can be a single + [StopperProtocol](https://github.com/gepa-ai/gepa/blob/main/src/gepa/utils/stop_condition.py) or a list of StopperProtocol instances. + Examples: [FileStopper](https://github.com/gepa-ai/gepa/blob/main/src/gepa/utils/stop_condition.py), + [TimeoutStopCondition](https://github.com/gepa-ai/gepa/blob/main/src/gepa/utils/stop_condition.py), + [SignalStopper](https://github.com/gepa-ai/gepa/blob/main/src/gepa/utils/stop_condition.py), + [NoImprovementStopper](https://github.com/gepa-ai/gepa/blob/main/src/gepa/utils/stop_condition.py), + or custom stopping logic. Note: This overrides the default max_metric_calls stopping condition. - - use_cloudpickle: Use cloudpickle instead of pickle for serialization. Can be helpful when the + - use_cloudpickle: Use cloudpickle instead of pickle for serialization. Can be helpful when the serialized state contains dynamically generated DSPy signatures. Default is False. - - val_evaluation_policy: Strategy controlling which validation ids to score each iteration. Can be + - val_evaluation_policy: Strategy controlling which validation ids to score each iteration. Can be 'full_eval' (evaluate every id each time) or an [EvaluationPolicy](https://github.com/gepa-ai/gepa/blob/main/src/gepa/strategies/eval_policy.py) instance. Default is 'full_eval'. - - use_mlflow: If True, enables MLflow integration to log optimization progress. + - use_mlflow: If True, enables MLflow integration to log optimization progress. MLflow can be used alongside Weights & Biases (WandB). 
- mlflow_tracking_uri: The tracking URI to use for MLflow (when use_mlflow=True). - mlflow_experiment_name: The experiment name to use for MLflow (when use_mlflow=True). - - Note: Parameters already handled by DSPy's GEPA class will be overridden by the direct parameters + + Note: Parameters already handled by DSPy's GEPA class will be overridden by the direct parameters and should not be passed through gepa_kwargs. - + Note: Budget Configuration: Exactly one of `auto`, `max_full_evals`, or `max_metric_calls` must be provided. The `auto` parameter provides preset configurations: "light" for quick experimentation, "medium" for balanced optimization, and "heavy" for thorough optimization. - + Reflection Configuration: The `reflection_lm` parameter is required and should be a strong language model. GEPA performs best with models like `dspy.LM(model='gpt-5', temperature=1.0, max_tokens=32000)`. The reflection process analyzes failed examples to generate feedback for program improvement. - + Merge Configuration: GEPA can merge successful program variants using `use_merge=True`. The `max_merge_invocations` parameter controls how many merge attempts are made during optimization. - - Evaluation Configuration: Use `num_threads` to parallelize evaluation. The `failure_score` and + + Evaluation Configuration: Use `num_threads` to parallelize evaluation. The `failure_score` and `perfect_score` parameters help GEPA understand your metric's range and optimize accordingly. - + Logging Configuration: Set `log_dir` to save detailed logs and enable checkpoint resuming. Use `track_stats=True` to access detailed optimization results via the `detailed_results` attribute. Enable `use_wandb=True` for experiment tracking and visualization. - + Reproducibility: Set `seed` to ensure consistent results across runs with the same configuration. """ + def __init__( self, metric: GEPAFeedbackMetric, @@ -348,18 +363,19 @@ def __init__( failure_score: float = 0.0, perfect_score: float = 1.0, # Logging - log_dir: str = None, + log_dir: str | None = None, track_stats: bool = False, use_wandb: bool = False, wandb_api_key: str | None = None, wandb_init_kwargs: dict[str, Any] | None = None, track_best_outputs: bool = False, warn_on_score_mismatch: bool = True, + enable_tool_optimization: bool = False, use_mlflow: bool = False, # Reproducibility seed: int | None = 0, # GEPA passthrough kwargs - gepa_kwargs: dict | None = None + gepa_kwargs: dict | None = None, ): try: inspect.signature(metric).bind(None, None, None, None, None) @@ -372,12 +388,7 @@ def __init__( self.metric_fn = metric # Budget configuration - assert ( - (max_metric_calls is not None) + - (max_full_evals is not None) + - (auto is not None) - == 1 - ), ( + assert (max_metric_calls is not None) + (max_full_evals is not None) + (auto is not None) == 1, ( "Exactly one of max_metric_calls, max_full_evals, auto must be set. 
" f"You set max_metric_calls={max_metric_calls}, " f"max_full_evals={max_full_evals}, " @@ -417,6 +428,7 @@ def __init__( self.wandb_api_key = wandb_api_key self.wandb_init_kwargs = wandb_init_kwargs self.warn_on_score_mismatch = warn_on_score_mismatch + self.enable_tool_optimization = enable_tool_optimization self.use_mlflow = use_mlflow if track_best_outputs: @@ -430,8 +442,85 @@ def __init__( self.component_selector = component_selector self.gepa_kwargs = gepa_kwargs or {} - def auto_budget(self, num_preds, num_candidates, valset_size: int, minibatch_size: int = 35, full_eval_steps: int = 5) -> int: + def _build_seed_candidate(self, student: Module) -> dict[str, str]: + """ + Build the seed candidate configuration from the student module. + + For ReAct modules (when tool optimization is enabled), creates a JSON config containing: + - react predictor instructions + - extract predictor instructions + - tool descriptions and argument descriptions + + For regular predictors, uses their signature instructions directly. + + Returns: + A dictionary mapping component names to their text representations (instructions or JSON configs). + """ + seed_candidate = {} + claimed_predictor_names = set() + + # Process ReAct modules when tool optimization is enabled + if self.enable_tool_optimization: + for module_path, module in student.named_sub_modules(): + if not isinstance(module, ReAct): + continue + + # Verify DSPy's two-predictor ReAct design + assert hasattr(module, "extract") and hasattr(module.extract, "predict"), ( + f"ReAct module '{module_path}' missing extract.predict - DSPy design may have changed" + ) + + # Get predictor names via object identity + extract_predictor = module.extract.predict + react_predictor = module.react + extract_predictor_name = None + react_predictor_name = None + for name, pred in student.named_predictors(): + if pred is extract_predictor: + extract_predictor_name = name + elif pred is react_predictor: + react_predictor_name = name + + # Use extract.predict as the key since it is the target predictor for feedback lookup + module_key = f"{TOOL_MODULE_PREFIX}:{extract_predictor_name}" + + # Build JSON config with dynamic predictor names as keys + config = { + react_predictor_name: react_predictor.signature.instructions, + extract_predictor_name: extract_predictor.signature.instructions, + "tools": { + tool_name: {"desc": tool.desc, "args": tool.args} + for tool_name, tool in module.tools.items() + if tool_name != "finish" # Skip the built-in finish tool + }, + } + + seed_candidate[module_key] = json.dumps(config, indent=2) + # Track predictor names that are part of ReAct modules + claimed_predictor_names.add(react_predictor_name) + claimed_predictor_names.add(extract_predictor_name) + else: + # Warn if ReAct modules found but tool optimization disabled + for module_path, module in student.named_sub_modules(): + if isinstance(module, ReAct): + logger.info( + f"Detected ReAct module at '{module_path}'. Consider using " + "`enable_tool_optimization=True` to jointly optimize react instructions, " + "extract instructions, tool descriptions, and tool argument descriptions." 
+ ) + + # Add individual predictors that aren't part of ReAct module configs + for name, pred in student.named_predictors(): + if name not in claimed_predictor_names: + seed_candidate[name] = pred.signature.instructions + + return seed_candidate + + def auto_budget( + self, num_preds, num_candidates, valset_size: int, minibatch_size: int = 35, full_eval_steps: int = 5 + ) -> int: import numpy as np + num_trials = int(max(2 * (num_preds * 2) * np.log2(num_candidates), 1.5 * num_candidates)) if num_trials < 0 or valset_size < 0 or minibatch_size < 0: raise ValueError("num_trials, valset_size, and minibatch_size must be >= 0.") @@ -496,12 +585,18 @@ def compile( else: assert self.max_metric_calls is not None, "Either auto, max_full_evals, or max_metric_calls must be set." - logger.info(f"Running GEPA for approx {self.max_metric_calls} metric calls of the program. This amounts to {self.max_metric_calls / len(trainset) if valset is None else self.max_metric_calls / (len(trainset) + len(valset)):.2f} full evals on the {'train' if valset is None else 'train+val'} set.") + logger.info( + f"Running GEPA for approx {self.max_metric_calls} metric calls of the program. This amounts to {self.max_metric_calls / len(trainset) if valset is None else self.max_metric_calls / (len(trainset) + len(valset)):.2f} full evals on the {'train' if valset is None else 'train+val'} set." + ) if valset is None: - logger.warning("No valset provided; Using trainset as valset. This is useful as an inference-time scaling strategy where you want GEPA to find the best solutions for the provided tasks in the trainset, as it makes GEPA overfit prompts to the provided trainset. In order to ensure generalization and perform well on unseen tasks, please provide separate trainset and valset. Provide the smallest valset that is just large enough to match the downstream task distribution, while keeping trainset as large as possible.") + logger.warning( + "No valset provided; Using trainset as valset. This is useful as an inference-time scaling strategy where you want GEPA to find the best solutions for the provided tasks in the trainset, as it makes GEPA overfit prompts to the provided trainset. In order to ensure generalization and perform well on unseen tasks, please provide separate trainset and valset. Provide the smallest valset that is just large enough to match the downstream task distribution, while keeping trainset as large as possible." + ) valset = valset or trainset - logger.info(f"Using {len(valset)} examples for tracking Pareto scores. You can consider using a smaller sample of the valset to allow GEPA to explore more diverse solutions within the same budget. GEPA requires you to provide the smallest valset that is just large enough to match your downstream task distribution, while providing as large trainset as possible.") + logger.info( + f"Using {len(valset)} examples for tracking Pareto scores. You can consider using a smaller sample of the valset to allow GEPA to explore more diverse solutions within the same budget. GEPA requires you to provide the smallest valset that is just large enough to match your downstream task distribution, while providing as large trainset as possible." 
+ ) rng = random.Random(self.seed) @@ -527,12 +622,10 @@ def feedback_fn( return o else: return dict(score=o, feedback=f"This trajectory got a score of {o}.") + return feedback_fn - feedback_map = { - k: feedback_fn_creator(k, v) - for k, v in student.named_predictors() - } + feedback_map = {k: feedback_fn_creator(k, v) for k, v in student.named_predictors()} # Build the DSPy adapter that encapsulates evaluation, trace capture, feedback extraction, and instruction proposal adapter = DspyAdapter( @@ -546,33 +639,30 @@ def feedback_fn( reflection_lm=self.reflection_lm, custom_instruction_proposer=self.custom_instruction_proposer, warn_on_score_mismatch=self.warn_on_score_mismatch, + enable_tool_optimization=self.enable_tool_optimization, reflection_minibatch_size=self.reflection_minibatch_size, ) - # Instantiate GEPA with the simpler adapter-based API - base_program = {name: pred.signature.instructions for name, pred in student.named_predictors()} + # Build the seed candidate configuration + seed_candidate = self._build_seed_candidate(student) + gepa_result: GEPAResult = optimize( - seed_candidate=base_program, + seed_candidate=seed_candidate, trainset=trainset, valset=valset, adapter=adapter, - # Reflection-based configuration reflection_lm=(lambda x: self.reflection_lm(x)[0]) if self.reflection_lm is not None else None, candidate_selection_strategy=self.candidate_selection_strategy, skip_perfect_score=self.skip_perfect_score, reflection_minibatch_size=self.reflection_minibatch_size, module_selector=self.component_selector, - perfect_score=self.perfect_score, - # Merge-based configuration use_merge=self.use_merge, max_merge_invocations=self.max_merge_invocations, - # Budget max_metric_calls=self.max_metric_calls, - # Logging logger=LoggerAdapter(logger), run_dir=self.log_dir, @@ -583,10 +673,9 @@ def feedback_fn( track_best_outputs=self.track_best_outputs, display_progress_bar=True, raise_on_exception=True, - # Reproducibility seed=self.seed, - **self.gepa_kwargs + **self.gepa_kwargs, ) new_prog = adapter.build_program(gepa_result.best_candidate) diff --git a/dspy/teleprompt/gepa/gepa_utils.py b/dspy/teleprompt/gepa/gepa_utils.py index d2e6772cef..d2acd1f826 100644 --- a/dspy/teleprompt/gepa/gepa_utils.py +++ b/dspy/teleprompt/gepa/gepa_utils.py @@ -1,20 +1,28 @@ +import json import logging import random from typing import Any, Callable, Protocol, TypedDict from gepa import EvaluationBatch, GEPAAdapter from gepa.core.adapter import ProposalFn +from gepa.strategies.instruction_proposal import InstructionProposalSignature import dspy from dspy.adapters.chat_adapter import ChatAdapter from dspy.adapters.types import History from dspy.adapters.types.base_type import Type +from dspy.adapters.types.tool import Tool from dspy.evaluate import Evaluate -from dspy.primitives import Example, Prediction -from dspy.teleprompt.bootstrap_trace import TraceData +from dspy.primitives import Example, Module, Prediction +from dspy.teleprompt.bootstrap_trace import FailedPrediction, TraceData logger = logging.getLogger(__name__) + +# Constants for module optimization +TOOL_MODULE_PREFIX = "tool_module" + + class LoggerAdapter: def __init__(self, logger: logging.Logger): self.logger = logger @@ -22,6 +30,7 @@ def __init__(self, logger: logging.Logger): def log(self, x: str): self.logger.info(x) + DSPyTrace = list[tuple[Any, dict[str, Any], Prediction]] @@ -31,15 +40,17 @@ class ReflectiveExample(TypedDict): Each example contains the predictor inputs, generated outputs, and feedback from evaluation. 
""" - Inputs: dict[str, Any] # Predictor inputs (may include str, dspy.Image, etc.) - Generated_Outputs: dict[str, Any] | str # Success: dict with output fields, Failure: error message string - Feedback: str # Always a string - from metric function or parsing error message + + Inputs: dict[str, Any] # Predictor inputs (may include str, dspy.Image, etc.) + Generated_Outputs: dict[str, Any] | str # Success: dict with output fields, Failure: error message string + Feedback: str # Always a string - from metric function or parsing error message class ScoreWithFeedback(Prediction): score: float feedback: str + class PredictorFeedbackFn(Protocol): def __call__( predictor_output: dict[str, Any], @@ -64,6 +75,7 @@ def __call__( """ ... + class DspyAdapter(GEPAAdapter[Example, TraceData, Prediction]): def __init__( self, @@ -77,6 +89,7 @@ def __init__( reflection_lm=None, custom_instruction_proposer: "ProposalFn | None" = None, warn_on_score_mismatch: bool = True, + enable_tool_optimization: bool = False, reflection_minibatch_size: int | None = None, ): self.student = student_module @@ -89,48 +102,155 @@ def __init__( self.reflection_lm = reflection_lm self.custom_instruction_proposer = custom_instruction_proposer self.warn_on_score_mismatch = warn_on_score_mismatch + self.enable_tool_optimization = enable_tool_optimization self.reflection_minibatch_size = reflection_minibatch_size - if self.custom_instruction_proposer is not None: - # We are only overriding the propose_new_texts method when a custom - # instruction proposer is provided. Otherwise, we use the GEPA - # default propose_new_texts. - - def custom_propose_new_texts( - candidate: dict[str, str], - reflective_dataset: dict[str, list[dict[str, Any]]], - components_to_update: list[str] - ) -> dict[str, str]: - if self.reflection_lm is not None: - with dspy.context(lm=self.reflection_lm): - return self.custom_instruction_proposer( - candidate=candidate, - reflective_dataset=reflective_dataset, - components_to_update=components_to_update - ) - else: - return self.custom_instruction_proposer( + def propose_new_texts( + self, + candidate: dict[str, str], + reflective_dataset: dict[str, list[dict[str, Any]]], + components_to_update: list[str], + ) -> dict[str, str]: + reflection_lm = self.reflection_lm or dspy.settings.lm + # If custom proposer provided, override everything with custom proposer + if self.custom_instruction_proposer: + with dspy.context(lm=reflection_lm): + return self.custom_instruction_proposer( + candidate=candidate, + reflective_dataset=reflective_dataset, + components_to_update=components_to_update, + ) + + # Otherwise, route to appropriate proposers + # Separate into two categories: tool-using modules (ReAct) vs regular instructions + # TODO: Add generic tool module support when DSPy trace lineage is improved + tool_components = [] + instruction_components = [] + + for c in components_to_update: + if c.startswith(TOOL_MODULE_PREFIX): + tool_components.append(c) + else: + instruction_components.append(c) + + results: dict[str, str] = {} + + with dspy.context(lm=reflection_lm): + # Handle regular instruction components + if instruction_components: + for name in instruction_components: + base_instruction = candidate[name] + dataset_with_feedback = reflective_dataset[name] + results[name] = InstructionProposalSignature.run( + lm=(lambda x: reflection_lm(x)[0]), + input_dict={ + "current_instruction_doc": base_instruction, + "dataset_with_feedback": dataset_with_feedback, + }, + )["new_instruction"] + + # Handle ReAct modules + 
if tool_components: + from dspy.teleprompt.gepa.instruction_proposal import ToolProposer + + tool_proposer = ToolProposer() + results.update( + tool_proposer( candidate=candidate, reflective_dataset=reflective_dataset, - components_to_update=components_to_update + components_to_update=tool_components, ) + ) - self.propose_new_texts = custom_propose_new_texts - - # Cache predictor names/signatures - self.named_predictors = list(self.student.named_predictors()) - + return results def build_program(self, candidate: dict[str, str]): new_prog = self.student.deepcopy() + + # Start with plain string instructions from candidate + predictor_candidates = {k: v for k, v in candidate.items() if not k.startswith(TOOL_MODULE_PREFIX)} + + tool_candidates = {} + if self.enable_tool_optimization: + for key, value in candidate.items(): + if not key.startswith(TOOL_MODULE_PREFIX): + continue + + config = json.loads(value) + + for pred_name, instruction in config.items(): + if isinstance(instruction, str): + predictor_candidates[pred_name] = instruction + + tool_candidates.update(config.get("tools", {})) + + # Update predictor instructions for name, pred in new_prog.named_predictors(): - if name in candidate: - pred.signature = pred.signature.with_instructions(candidate[name]) + if name in predictor_candidates: + pred.signature = pred.signature.with_instructions(predictor_candidates[name]) + + # Update tool descriptions + if tool_candidates: + self._update_tool_descriptions(new_prog, tool_candidates) + return new_prog + def _update_tool_descriptions(self, program: Module, tool_candidates: dict[str, Any]) -> None: + all_tools = self._collect_tools(program) + + for tool_name, tool_config in tool_candidates.items(): + if tool_name not in all_tools: + logger.warning( + f"Skipping updates for tool:'{tool_name}' because it cannot be detected on the student program." + ) + continue + + tool = all_tools[tool_name] + + # Update tool description if present. + if tool_config.get("desc"): + tool.desc = tool_config["desc"] + + # Update arg descriptions if present. 
+ args_schema = tool_config.get("args") or {} + for arg_name, arg_schema in args_schema.items(): + if arg_schema.get("description") is not None: + tool.args[arg_name]["description"] = arg_schema["description"] + + def _collect_tools(self, module: Module) -> dict[str, Tool]: + """Recursively collect all Tool instances from a module and its sub-modules.""" + all_tools = {} + visited = set() + + def _collect_from_attribute(attr_value): + if isinstance(attr_value, Tool): + all_tools[attr_value.name] = attr_value + elif isinstance(attr_value, dspy.Module): + _traverse(attr_value) + elif isinstance(attr_value, list | dict): + items = attr_value if isinstance(attr_value, list) else attr_value.values() + for item in items: + if isinstance(item, Tool): + all_tools[item.name] = item + + def _traverse(current_module): + if id(current_module) in visited or not hasattr(current_module, "__dict__"): + return + visited.add(id(current_module)) + + for attr_value in current_module.__dict__.values(): + _collect_from_attribute(attr_value) + + _traverse(module) + return all_tools + def evaluate(self, batch, candidate, capture_traces=False): program = self.build_program(candidate) - callback_metadata = {"metric_key": "eval_full"} if self.reflection_minibatch_size is None or len(batch) > self.reflection_minibatch_size else {"disable_logging": True} + callback_metadata = ( + {"metric_key": "eval_full"} + if self.reflection_minibatch_size is None or len(batch) > self.reflection_minibatch_size + else {"disable_logging": True} + ) if capture_traces: # bootstrap_trace_data-like flow with trace capture @@ -158,6 +278,7 @@ def evaluate(self, batch, candidate, capture_traces=False): if hasattr(score, "score"): score = score["score"] scores.append(score) + return EvaluationBatch(outputs=outputs, scores=scores, trajectories=trajs) else: evaluator = Evaluate( @@ -176,19 +297,29 @@ def evaluate(self, batch, candidate, capture_traces=False): scores = [s["score"] if hasattr(s, "score") else s for s in scores] return EvaluationBatch(outputs=outputs, scores=scores, trajectories=None) - def make_reflective_dataset(self, candidate, eval_batch, components_to_update) -> dict[str, list[ReflectiveExample]]: - from dspy.teleprompt.bootstrap_trace import FailedPrediction + def make_reflective_dataset( + self, candidate, eval_batch, components_to_update + ) -> dict[str, list[ReflectiveExample]]: program = self.build_program(candidate) ret_d: dict[str, list[ReflectiveExample]] = {} + for pred_name in components_to_update: + # Extract predictor name from component key + if pred_name.startswith(TOOL_MODULE_PREFIX): + target_name = pred_name.removeprefix(f"{TOOL_MODULE_PREFIX}:") + else: + target_name = pred_name + + # Find the predictor object module = None for name, m in program.named_predictors(): - if name == pred_name: + if name == target_name: module = m break - assert module is not None + assert module is not None, f"Predictor not found: {target_name}" + # Create reflective examples from traces items: list[ReflectiveExample] = [] for data in eval_batch.trajectories or []: trace = data["trace"] @@ -265,7 +396,8 @@ def make_reflective_dataset(self, candidate, eval_batch, components_to_update) - d["Feedback"] = "Your output failed to parse. 
Follow this structure:\n" + structure_instruction # d['score'] = self.failure_score else: - feedback_fn = self.feedback_map[pred_name] + # Use actual predictor name for feedback lookup + feedback_fn = self.feedback_map[target_name] fb = feedback_fn( predictor_output=outputs, predictor_inputs=inputs, @@ -276,15 +408,18 @@ def make_reflective_dataset(self, candidate, eval_batch, components_to_update) - d["Feedback"] = fb["feedback"] if fb["score"] != module_score: if self.warn_on_score_mismatch: - logger.warning("The score returned by the metric with pred_name is different from the overall metric score. This can indicate 2 things: Either the metric is non-deterministic (e.g., LLM-as-judge, Semantic score, etc.) or the metric returned a score specific to pred_name that differs from the module level score. Currently, GEPA does not support predictor level scoring (support coming soon), and only requires a feedback text to be provided, which can be specific to the predictor or program level. GEPA will ignore the differing score returned, and instead use module level score. You can safely ignore this warning if using a semantic metric, however, if this mismatch is caused due to predictor scoring, please return module-level scores. To disable this warning, set warn_on_score_mismatch=False.") + logger.warning( + "The score returned by the metric with pred_name is different from the overall metric score. This can indicate 2 things: Either the metric is non-deterministic (e.g., LLM-as-judge, Semantic score, etc.) or the metric returned a score specific to pred_name that differs from the module level score. Currently, GEPA does not support predictor level scoring (support coming soon), and only requires a feedback text to be provided, which can be specific to the predictor or program level. GEPA will ignore the differing score returned, and instead use module level score. You can safely ignore this warning if using a semantic metric, however, if this mismatch is caused due to predictor scoring, please return module-level scores. To disable this warning, set warn_on_score_mismatch=False." + ) self.warn_on_score_mismatch = False fb["score"] = module_score items.append(d) if len(items) == 0: - # raise Exception(f"No valid predictions found for module {module.signature}.") + logger.warning(f" No valid reflective examples found for {pred_name}") continue + ret_d[pred_name] = items if len(ret_d) == 0: @@ -292,6 +427,14 @@ def make_reflective_dataset(self, candidate, eval_batch, components_to_update) - return ret_d + # TODO: Generic tool module optimization - pending DSPy trace lineage improvements + # Currently only ReAct modules are supported for tool optimization. + # Re-enable _update_candidate_tools when DSPy provides better tool→trace lineage. + # + # def _update_candidate_tools(self, candidate, program, trajectories) -> None: + # """Extract dspy.Tool objects from traces for tool modules and update candidate["tools"].""" + # ... + # TODO: The current DSPyAdapter implementation uses the GEPA default propose_new_texts. # We can potentially override this, to use the instruction proposal similar to MIPROv2. 
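
To make the candidate format handled above concrete, the following is a minimal sketch of how a single ReAct module ends up serialized in the seed candidate: a JSON config keyed under the `tool_module:` prefix by the extract predictor's name, holding the react and extract instructions plus the tool schemas. The predictor names (`react`, `extract.predict`) and the toy tool are illustrative assumptions; the adapter resolves the real names by object identity via `named_predictors()`.

```python
import json
import dspy

def search(query: str) -> str:
    """Toy search tool used only for illustration."""
    return f"Search results for: {query}"

program = dspy.ReAct("question -> answer", tools=[dspy.Tool(search, name="search", desc="Search tool")])

# Assumed predictor names for a top-level ReAct module; GEPA looks these up
# via student.named_predictors() rather than hard-coding them.
config = {
    "react": program.react.signature.instructions,
    "extract.predict": program.extract.predict.signature.instructions,
    "tools": {
        name: {"desc": tool.desc, "args": tool.args}
        for name, tool in program.tools.items()
        if name != "finish"  # the built-in finish tool is never optimized
    },
}

# One seed-candidate entry: "tool_module:<extract predictor name>" -> JSON config
print("tool_module:extract.predict")
print(json.dumps(config, indent=2))
```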
diff --git a/dspy/teleprompt/gepa/instruction_proposal.py b/dspy/teleprompt/gepa/instruction_proposal.py index 23810b9a02..19d3ac4007 100644 --- a/dspy/teleprompt/gepa/instruction_proposal.py +++ b/dspy/teleprompt/gepa/instruction_proposal.py @@ -1,3 +1,5 @@ +import json +import logging from typing import Any from gepa.core.adapter import ProposalFn @@ -6,6 +8,8 @@ from dspy.adapters.types.base_type import Type from dspy.teleprompt.gepa.gepa_utils import ReflectiveExample +logger = logging.getLogger(__name__) + class GenerateEnhancedMultimodalInstructionFromFeedback(dspy.Signature): """I provided an assistant with instructions to perform a task involving visual content, but the assistant's performance needs improvement based on the examples and feedback below. @@ -310,3 +314,184 @@ def __call__( updated_components[component_name] = new_instruction return updated_components + + +class GenerateImprovedToolModuleDescriptionsFromFeedback(dspy.Signature): + """I provided an assistant with predictor instructions and tool descriptions, + but its performance needs improvement based on the examples_with_feedback below. + + Your task is to propose better predictor instructions, tool descriptions, and tool argument descriptions that address the issues shown in these examples. + Focus on reinforcing patterns that clearly improve the assistant's performance on similar tasks, rather than rewriting everything from scratch unless necessary. + These components are progressively optimized - refine only what needs to change. + + Analyze the examples_with_feedback to identify success and failure patterns, and write improved instructions and descriptions at their appropriate level of abstraction and/or specificity, + so that each layer plays a clear, complementary role without unnecessary repetition or verbosity unless redundancy clearly helps the assistant's performance. + """ + + current_predictor_instruction = dspy.InputField(desc="Current instruction guiding the predictor") + current_tools = dspy.InputField(annotation=list[dspy.Tool], desc="Available tools with their complete schemas") + examples_with_feedback = dspy.InputField(desc="Execution examples with feedback showing successes and failures") + + improved_predictor_instruction: str | None = dspy.OutputField( + desc="Improved instruction for the predictor", default=None + ) + + +class ToolProposer(ProposalFn): + """Proposer for optimizing tool-using module configurations. + + Supports two types of modules: + - Tool modules (1 predictor): Optimizes predictor instruction and tool descriptions + - ReAct modules (2 predictors): Jointly optimizes react instruction, extract instruction, and tool descriptions + + Uses dynamic signature generation to create output fields for each tool and parameter, + enabling the reflection LM to optimize all components cohesively based on execution feedback. + + This joint optimization approach allows the LM to see how instructions and tool descriptions + work together, leading to more coherent improvements than optimizing each component separately. + """ + + def __call__( + self, + candidate: dict[str, str], + reflective_dataset: dict[str, list[ReflectiveExample]], + components_to_update: list[str], + ) -> dict[str, str]: + """Optimize tool-using module components. 
+ + Args: + candidate: Current component name -> JSON config mapping + reflective_dataset: Component name -> list of reflective examples + components_to_update: List of tool-using module component names to update + + Returns: + dict: Mapping of component names to improved JSON configs + """ + + updated_components = {} + + for module_key in components_to_update: + if module_key not in candidate or module_key not in reflective_dataset: + logger.debug( + f"Skipping {module_key}: not in candidate={module_key not in candidate}, not in " + "reflective_dataset={module_key not in reflective_dataset}" + ) + continue + current_module_config = json.loads(candidate[module_key]) + + # Predictor keys: ReAct has 2 predictors (react + extract) + predictor_keys = [k for k, v in current_module_config.items() if isinstance(v, str)] + primary_predictor_key = predictor_keys[0] + extract_predictor_key = predictor_keys[1] if len(predictor_keys) > 1 else None + + # Reconstruct Tool objects from JSON (func is placeholder since it can't be serialized) + current_tools_dict = current_module_config.get("tools", {}) + tools_list = [] + for tool_name, tool_info in current_tools_dict.items(): + tool = dspy.Tool( + func=lambda *args, **kwargs: None, # Placeholder - Tool requires Callable, but only schema is used + name=tool_name, + desc=tool_info.get("desc", ""), + ) + tool.args = tool_info.get("args", {}) + tools_list.append(tool) + + # Build dynamic signature with tool-specific output fields + signature = GenerateImprovedToolModuleDescriptionsFromFeedback + + for tool in tools_list: + tool_name = tool.name + tool_info = current_tools_dict[tool_name] + + signature = signature.append( + f"improved_tool_{tool_name}_desc", + dspy.OutputField(desc=f"Improved description of tool '{tool_name}'", default=None), + ) + + for arg_name in tool_info["args"].keys(): + signature = signature.append( + f"improved_tool_{tool_name}_arg_{arg_name}_desc", + dspy.OutputField( + desc=f"Improved description of the argument '{arg_name}' of tool '{tool_name}'", + default=None, + ), + ) + + kwargs = { + "current_predictor_instruction": current_module_config[primary_predictor_key], + "current_tools": tools_list, + "examples_with_feedback": self._format_examples(reflective_dataset[module_key]), + } + # If module has extract predictor, add extract fields + if extract_predictor_key is not None: + signature = signature.append( + "current_extract_instruction", dspy.InputField(desc="Current instruction for extraction predictor") + ) + signature = signature.append( + "improved_extract_instruction", + dspy.OutputField(desc="Improved instruction for extraction", default=None), + ) + kwargs["current_extract_instruction"] = current_module_config[extract_predictor_key] + + propose_descriptions = dspy.Predict(signature) + result = propose_descriptions(**kwargs) + + # Build improved config (reflection LM returns None to keep original, or new text) + improved_module_config = {} + + if result.improved_predictor_instruction is not None: + improved_module_config[primary_predictor_key] = result.improved_predictor_instruction + + if extract_predictor_key is not None and result.improved_extract_instruction is not None: + improved_module_config[extract_predictor_key] = result.improved_extract_instruction + + improved_module_config["tools"] = {} + for tool_name, tool_info in current_tools_dict.items(): + # Update tool description if LM proposed a change + improved_tool_desc = getattr(result, f"improved_tool_{tool_name}_desc", None) + if improved_tool_desc is not None: + 
+                    tool_info["desc"] = improved_tool_desc
+
+                # Update arg descriptions if LM proposed changes
+                for arg_name in tool_info["args"].keys():
+                    improved_tool_arg_desc = getattr(result, f"improved_tool_{tool_name}_arg_{arg_name}_desc", None)
+                    if improved_tool_arg_desc is not None:
+                        tool_info["args"][arg_name]["description"] = improved_tool_arg_desc
+
+                improved_module_config["tools"][tool_name] = tool_info
+
+            updated_components[module_key] = json.dumps(improved_module_config, indent=2)
+
+        return updated_components
+
+    def _format_examples(self, reflective_dataset: list[ReflectiveExample]) -> str:
+        """Format reflective examples using GEPA's markdown structure."""
+
+        def render_value(value, level=3):
+            if isinstance(value, dict):
+                s = ""
+                for key, val in value.items():
+                    s += f"{'#' * level} {key}\n"
+                    s += render_value(val, min(level + 1, 6))
+                if not value:
+                    s += "\n"
+                return s
+            if isinstance(value, (list, tuple)):
+                s = ""
+                for index, item in enumerate(value):
+                    s += f"{'#' * level} Item {index + 1}\n"
+                    s += render_value(item, min(level + 1, 6))
+                if not value:
+                    s += "\n"
+                return s
+            return f"{str(value).strip()}\n\n"
+
+        def convert_sample_to_markdown(sample, example_num):
+            s = f"# Example {example_num}\n"
+            for key, val in sample.items():
+                s += f"## {key}\n"
+                s += render_value(val, level=3)
+            return s
+
+        formatted_parts = [convert_sample_to_markdown(example, i + 1) for i, example in enumerate(reflective_dataset)]
+        return "\n\n".join(formatted_parts)
diff --git a/tests/teleprompt/test_gepa_tool_optimization.py b/tests/teleprompt/test_gepa_tool_optimization.py
new file mode 100644
index 0000000000..0c414e8491
--- /dev/null
+++ b/tests/teleprompt/test_gepa_tool_optimization.py
@@ -0,0 +1,353 @@
+"""Tests for GEPA's tool optimization (ReAct modules).
+
+Test categories:
+1. Detection - Compile-time detection of dspy.ReAct modules
+2. Application - build_program applies optimized instructions and tool descriptions
+
+DSPy ReAct Design Note:
+    DSPy's ReAct uses two predictors:
+    - react: reasoning/acting loop
+    - extract: structured output synthesis
+
+    We optimize extract.predict as it's called once with the complete trajectory
+    and produces all output fields.
+"""
+
+import json
+
+import gepa
+from gepa import optimize as gepa_optimize
+
+import dspy
+from dspy.teleprompt.gepa.gepa_utils import TOOL_MODULE_PREFIX, DspyAdapter
+from dspy.utils.dummies import DummyLM
+
+
+# Test tool fixtures
+def search(query: str) -> str:
+    """Test search tool."""
+    return f"Search: {query}"
+
+
+def calculate(expr: str) -> str:
+    """Test calculator tool."""
+    return str(eval(expr))
+
+
+def analyze(data: str) -> str:
+    """Test analyzer tool."""
+    return f"Analysis: {data}"
+
+
+def setup_seed_candidate_capture(monkeypatch):
+    """Capture seed_candidate dict passed to gepa.optimize."""
+    captured = {}
+
+    def capture_optimize(seed_candidate, **kwargs):
+        captured.update(seed_candidate)
+        return gepa_optimize(seed_candidate=seed_candidate, **kwargs)
+
+    monkeypatch.setattr(gepa, "optimize", capture_optimize)
+    return captured
+
+
+def create_optimizer(task_responses, reflection_responses):
+    """Create GEPA optimizer with explicit LM responses.
+
+    Args:
+        task_responses: List of dicts for task LM (e.g., [{"answer": "test"}])
+        reflection_responses: List of dicts for reflection LM
+
+    Returns:
+        tuple: (optimizer, trainset)
+    """
+    task_lm = DummyLM(task_responses)
+    reflection_lm = DummyLM(reflection_responses)
+
+    dspy.settings.configure(lm=task_lm)
+
+    optimizer = dspy.GEPA(
+        metric=lambda example, pred, trace=None, pred_name=None, pred_trace=None: dspy.Prediction(score=0.5, feedback="ok"),
+        reflection_lm=reflection_lm,
+        max_metric_calls=2,
+        enable_tool_optimization=True,
+    )
+
+    trainset = [dspy.Example(query="test", answer="test").with_inputs("query")]
+    return optimizer, trainset
+
+
+def get_predictor_name(program, predictor):
+    """Find predictor name by object identity in named_predictors().
+
+    Args:
+        program: DSPy module
+        predictor: Predictor object to find
+
+    Returns:
+        str: Predictor name (e.g., "pred", "agent.pred")
+    """
+    for name, pred in program.named_predictors():
+        if pred is predictor:
+            return name
+    raise ValueError(f"Predictor not found: {predictor}")
+
+
+def test_skip_predictor_without_tools(monkeypatch):
+    """Skip predictors without Tool annotations."""
+    seed_candidate = setup_seed_candidate_capture(monkeypatch)
+
+    class PlainSignature(dspy.Signature):
+        """Answer questions."""
+        query: str = dspy.InputField()
+        answer: str = dspy.OutputField()
+
+    class PlainAgent(dspy.Module):
+        def __init__(self):
+            super().__init__()
+            self.pred = dspy.Predict(PlainSignature)
+
+        def forward(self, query):
+            return self.pred(query=query)
+
+    program = PlainAgent()
+    optimizer, trainset = create_optimizer(
+        task_responses=[{"answer": "test"}] * 20,  # Repeat for GEPA iterations
+        reflection_responses=[{"improved_instruction": "optimized"}] * 20  # Repeat for GEPA iterations
+    )
+    optimizer.compile(program, trainset=trainset, valset=trainset)
+
+    predictor_name = get_predictor_name(program, program.pred)
+    assert predictor_name in seed_candidate
+
+    # Should be plain string instruction, not JSON config
+    instruction = seed_candidate[predictor_name]
+    assert isinstance(instruction, str)
+
+
+def test_detect_react_module(monkeypatch):
+    """Detect ReAct module with tools."""
+    seed_candidate = setup_seed_candidate_capture(monkeypatch)
+
+    program = dspy.ReAct("question -> answer", tools=[search])
+    optimizer, trainset = create_optimizer(
+        task_responses=[
+            {"next_thought": "I should search", "next_tool_name": "search", "next_tool_args": {"query": "test"}},
+            {"next_thought": "Done", "next_tool_name": "finish", "next_tool_args": {}},
+            {"reasoning": "Based on search", "answer": "test"},
+        ] * 20,  # Repeat for GEPA iterations
+        reflection_responses=[
+            {
+                "improved_predictor_instruction": "optimized react",
+                "improved_extract_instruction": "optimized extract",
+                "improved_tool_search_desc": "optimized search desc",
+                "improved_tool_search_arg_query_desc": "optimized query desc"
+            }
+        ] * 20  # Repeat for GEPA iterations
+    )
+    optimizer.compile(program, trainset=trainset, valset=trainset)
+
+    # Verify detection - use extract.predict as primary (for tracing)
+    extract_name = get_predictor_name(program, program.extract.predict)
+    component_key = f"{TOOL_MODULE_PREFIX}:{extract_name}"
+    assert component_key in seed_candidate
+
+    tool_config = json.loads(seed_candidate[component_key])
+    assert "tools" in tool_config
+
+
+def test_detect_multiple_react_modules(monkeypatch):
+    """Detect multiple ReAct modules in workflow."""
+    seed_candidate = setup_seed_candidate_capture(monkeypatch)
+
+    class Workflow(dspy.Module):
+        def __init__(self):
+            super().__init__()
+            self.searcher = dspy.ReAct("query -> results", tools=[search])
+            self.analyzer = dspy.ReAct("data -> analysis", tools=[analyze])
+
+        def forward(self, query):
+            results = self.searcher(query=query)
+            return self.analyzer(data=results.results)
+
+    program = Workflow()
+    optimizer, trainset = create_optimizer(
+        task_responses=[
+            {"next_thought": "Searching", "next_tool_name": "search", "next_tool_args": {"query": "test"}},
+            {"next_thought": "Done", "next_tool_name": "finish", "next_tool_args": {}},
+            {"reasoning": "Found results", "results": "data"},
+            {"next_thought": "Analyzing", "next_tool_name": "analyze", "next_tool_args": {"data": "test"}},
+            {"next_thought": "Done", "next_tool_name": "finish", "next_tool_args": {}},
+            {"reasoning": "Analyzed", "analysis": "result"},
+        ] * 20,  # Repeat for GEPA iterations
+        reflection_responses=[
+            {
+                "improved_predictor_instruction": "opt react search",
+                "improved_extract_instruction": "opt extract search",
+                "improved_tool_search_desc": "opt search desc",
+                "improved_tool_search_arg_query_desc": "opt query desc"
+            },
+            {
+                "improved_predictor_instruction": "opt react analyze",
+                "improved_extract_instruction": "opt extract analyze",
+                "improved_tool_analyze_desc": "opt analyze desc",
+                "improved_tool_analyze_arg_data_desc": "opt data desc"
+            }
+        ] * 20  # Repeat for GEPA iterations
+    )
+    optimizer.compile(program, trainset=trainset, valset=trainset)
+
+    # Verify both detected - use extract.predict as primary (for tracing)
+    searcher_name = get_predictor_name(program, program.searcher.extract.predict)
+    analyzer_name = get_predictor_name(program, program.analyzer.extract.predict)
+
+    searcher_key = f"{TOOL_MODULE_PREFIX}:{searcher_name}"
+    analyzer_key = f"{TOOL_MODULE_PREFIX}:{analyzer_name}"
+
+    assert searcher_key in seed_candidate
+    assert analyzer_key in seed_candidate
+
+
+def test_apply_optimized_react_descriptions():
+    """Apply optimized tool descriptions to ReAct modules."""
+
+    program = dspy.ReAct("question -> answer", tools=[search])
+
+    # Create mock optimized candidate - use extract.predict as primary (for tracing)
+    react_name = get_predictor_name(program, program.react)
+    extract_predict_name = get_predictor_name(program, program.extract.predict)
+
+    component_key = f"{TOOL_MODULE_PREFIX}:{extract_predict_name}"
+
+    optimized_candidate = {
+        component_key: json.dumps({
+            react_name: "OPTIMIZED: React instruction",
+            extract_predict_name: "OPTIMIZED: Extract instruction",
+            "tools": {
+                "search": {
+                    "desc": "OPTIMIZED: Search tool",
+                    "args": {"query": {"type": "string"}},
+                }
+            }
+        })
+    }
+
+    # Apply optimizations
+    adapter = DspyAdapter(
+        student_module=program,
+        metric_fn=lambda example, pred, trace=None: 0.5,
+        feedback_map={},
+        enable_tool_optimization=True,
+    )
+    rebuilt = adapter.build_program(optimized_candidate)
+
+    # Verify instructions updated
+    assert rebuilt.react.signature.instructions == "OPTIMIZED: React instruction"
+    assert rebuilt.extract.predict.signature.instructions == "OPTIMIZED: Extract instruction"
+
+    # Verify tool updated
+    assert rebuilt.tools["search"].desc == "OPTIMIZED: Search tool"
+
+
+def test_detect_nested_react_modules(monkeypatch):
+    """Detect ReAct modules in nested program structure."""
+    seed_candidate = setup_seed_candidate_capture(monkeypatch)
+
+    class Worker(dspy.Module):
+        def __init__(self):
+            super().__init__()
+            self.react = dspy.ReAct("task -> result", tools=[analyze])
+
+        def forward(self, task):
+            return self.react(task=task)
+
+    class Orchestrator(dspy.Module):
+        def __init__(self):
+            super().__init__()
+            self.searcher = dspy.ReAct("query -> results", tools=[search])
+            self.worker = Worker()
+
+        def forward(self, query):
+            results = self.searcher(query=query)
+            return self.worker(task=results.results)
+
+    program = Orchestrator()
+    optimizer, trainset = create_optimizer(
+        task_responses=[
+            {"next_thought": "Search", "next_tool_name": "search", "next_tool_args": {"query": "test"}},
+            {"next_thought": "Done", "next_tool_name": "finish", "next_tool_args": {}},
+            {"reasoning": "Found", "results": "data"},
+            {"next_thought": "Analyze", "next_tool_name": "analyze", "next_tool_args": {"data": "test"}},
+            {"next_thought": "Done", "next_tool_name": "finish", "next_tool_args": {}},
+            {"reasoning": "Analyzed", "result": "final"},
+        ] * 20,  # Repeat for GEPA iterations
+        reflection_responses=[
+            {
+                "improved_predictor_instruction": "opt react search",
+                "improved_extract_instruction": "opt extract search",
+                "improved_tool_search_desc": "opt search desc",
+                "improved_tool_search_arg_query_desc": "opt query desc"
+            },
+            {
+                "improved_predictor_instruction": "opt react analyze",
+                "improved_extract_instruction": "opt extract analyze",
+                "improved_tool_analyze_desc": "opt analyze desc",
+                "improved_tool_analyze_arg_data_desc": "opt data desc"
+            }
+        ] * 20  # Repeat for GEPA iterations
+    )
+    optimizer.compile(program, trainset=trainset, valset=trainset)
+
+    # Verify nested modules detected with full paths - use extract.predict as primary (for tracing)
+    searcher_name = get_predictor_name(program, program.searcher.extract.predict)
+    worker_extract_name = get_predictor_name(program, program.worker.react.extract.predict)
+
+    searcher_key = f"{TOOL_MODULE_PREFIX}:{searcher_name}"
+    worker_key = f"{TOOL_MODULE_PREFIX}:{worker_extract_name}"
+
+    assert searcher_key in seed_candidate
+    assert worker_key in seed_candidate
+
+    # Verify full paths preserved (not truncated)
+    assert "searcher" in searcher_name  # Contains parent path
+    assert "worker" in worker_extract_name  # Contains nested path
+
+
+def test_selective_optimization_with_none_returns():
+    """Verify selective optimization when reflection LM returns None for some fields."""
+
+    program = dspy.ReAct("question -> answer", tools=[search, calculate])
+
+    react_name = get_predictor_name(program, program.react)
+    extract_name = get_predictor_name(program, program.extract.predict)
+    component_key = f"{TOOL_MODULE_PREFIX}:{extract_name}"
+
+    # Mock selective optimization (only react instruction and search tool updated)
+    optimized_candidate = {
+        component_key: json.dumps({
+            react_name: "OPTIMIZED: React instruction",
+            extract_name: program.extract.predict.signature.instructions,
+            "tools": {
+                "search": {
+                    "desc": "OPTIMIZED: Search tool",
+                    "args": {"query": {"type": "string"}},
+                }
+            }
+        })
+    }
+
+    adapter = DspyAdapter(
+        student_module=program,
+        metric_fn=lambda example, pred, trace=None: 0.5,
+        feedback_map={},
+        enable_tool_optimization=True,
+    )
+    rebuilt = adapter.build_program(optimized_candidate)
+
+    # Verify selective updates
+    assert rebuilt.react.signature.instructions == "OPTIMIZED: React instruction"
+    assert rebuilt.extract.predict.signature.instructions == program.extract.predict.signature.instructions
+    assert rebuilt.tools["search"].desc == "OPTIMIZED: Search tool"
+
+    # Original unchanged (calculate not in optimized candidate)
+    assert rebuilt.tools["calculate"].desc == program.tools["calculate"].desc