diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py index 8696697d3a..6197577b67 100644 --- a/dspy/adapters/base.py +++ b/dspy/adapters/base.py @@ -1,3 +1,4 @@ +import json import logging from typing import TYPE_CHECKING, Any, get_origin @@ -8,7 +9,9 @@ from dspy.adapters.types.base_type import split_message_content_for_custom_types from dspy.adapters.types.reasoning import Reasoning from dspy.adapters.types.tool import Tool, ToolCalls +from dspy.adapters.utils import serialize_for_json from dspy.experimental import Citations +from dspy.signatures.field import InputField, OutputField from dspy.signatures.signature import Signature from dspy.utils.callback import BaseCallback, with_callbacks @@ -452,13 +455,13 @@ def format_demos(self, signature: type[Signature], demos: list[dict[str, Any]]) return messages - def _get_history_field_name(self, signature: type[Signature]) -> bool: + def _get_history_field_name(self, signature: type[Signature]) -> str | None: for name, field in signature.input_fields.items(): if field.annotation == History: return name return None - def _get_tool_call_input_field_name(self, signature: type[Signature]) -> bool: + def _get_tool_call_input_field_name(self, signature: type[Signature]) -> str | None: for name, field in signature.input_fields.items(): # Look for annotation `list[dspy.Tool]` or `dspy.Tool` origin = get_origin(field.annotation) @@ -468,54 +471,112 @@ def _get_tool_call_input_field_name(self, signature: type[Signature]) -> bool: return name return None - def _get_tool_call_output_field_name(self, signature: type[Signature]) -> bool: + def _get_tool_call_output_field_name(self, signature: type[Signature]) -> str | None: for name, field in signature.output_fields.items(): if field.annotation == ToolCalls: return name return None + def _serialize_kv_value(self, v: Any) -> str: + """Serialize a value to string for flat-mode history formatting. + + Uses the same pattern as format_field_value in adapters/utils.py. + """ + jsonable = serialize_for_json(v) + if isinstance(jsonable, (dict, list)): + return json.dumps(jsonable, ensure_ascii=False) + return str(jsonable) + def format_conversation_history( self, signature: type[Signature], history_field_name: str, inputs: dict[str, Any], ) -> list[dict[str, Any]]: - """Format the conversation history. + """Format the conversation history as multiturn messages. - This method formats the conversation history and the current input as multiturn messages. - - Args: - signature: The DSPy signature for which to format the conversation history. - history_field_name: The name of the history field in the signature. - inputs: The input arguments to the DSPy module. + Supports four modes: + - raw: Direct LM messages → passed through as-is + - demo: {"input_fields": {...}, "output_fields": {...}} → user/assistant pairs + - flat: Arbitrary kv pairs → single user message per dict (default) + - signature: Dict keys match signature fields → user/assistant pairs - Returns: - A list of multiturn messages. + For backward compatibility, flat-mode histories whose message keys are subsets of the + signature fields (and overlap output fields) are treated as signature-mode. """ - conversation_history = inputs[history_field_name].messages if history_field_name in inputs else None - - if conversation_history is None: + history = inputs.get(history_field_name) + if history is None: return [] - messages = [] - for message in conversation_history: - messages.append( - { + del inputs[history_field_name] + + if history.mode == "raw": + return [dict(msg) for msg in history.messages] + if history.mode == "demo": + return self._format_demo_history(history.messages) + if history.mode == "signature": + return self._format_signature_history(signature, history.messages) + + # Backward-compat shim: treat flat-mode as signature-mode if messages look like + # signature-style conversation history (keys subset of signature fields, overlapping outputs) + if history.mode == "flat" and history.messages: + sig_keys = set(signature.fields.keys()) + output_keys = set(signature.output_fields.keys()) + msg_key_sets = [set(m.keys()) for m in history.messages] + + if all(ks <= sig_keys for ks in msg_key_sets): + if any(ks & output_keys for ks in msg_key_sets): + return self._format_signature_history(signature, history.messages) + + return self._format_flat_history(history.messages) + + def _format_demo_history(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Format demo-mode history (input_fields/output_fields → user/assistant).""" + result = [] + for msg in messages: + if "input_fields" in msg: + input_dict = {k: self._serialize_kv_value(v) for k, v in msg["input_fields"].items()} + sig = Signature({k: InputField() for k in input_dict.keys()}, instructions="") + result.append({ "role": "user", - "content": self.format_user_message_content(signature, message), - } - ) - messages.append( - { + "content": self.format_user_message_content(sig, input_dict), + }) + if "output_fields" in msg: + output_dict = {k: self._serialize_kv_value(v) for k, v in msg["output_fields"].items()} + sig = Signature({k: OutputField() for k in output_dict.keys()}, instructions="") + result.append({ "role": "assistant", - "content": self.format_assistant_message_content(signature, message), - } - ) - - # Remove the history field from the inputs - del inputs[history_field_name] + "content": self.format_assistant_message_content(sig, output_dict), + }) + return result - return messages + def _format_signature_history( + self, signature: type[Signature], messages: list[dict[str, Any]] + ) -> list[dict[str, Any]]: + """Format signature-mode history (signature fields → user/assistant pairs).""" + result = [] + for msg in messages: + result.append({ + "role": "user", + "content": self.format_user_message_content(signature, msg), + }) + result.append({ + "role": "assistant", + "content": self.format_assistant_message_content(signature, msg), + }) + return result + + def _format_flat_history(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Format flat-mode history (all kv pairs in single user message).""" + result = [] + for msg in messages: + serialized = {k: self._serialize_kv_value(v) for k, v in msg.items()} + sig = Signature({k: InputField() for k in serialized.keys()}, instructions="") + result.append({ + "role": "user", + "content": self.format_user_message_content(sig, serialized), + }) + return result def parse(self, signature: type[Signature], completion: str) -> dict[str, Any]: """Parse the LM output into a dictionary of the output fields. diff --git a/dspy/adapters/types/history.py b/dspy/adapters/types/history.py index 2c39d5c4ab..6f34175372 100644 --- a/dspy/adapters/types/history.py +++ b/dspy/adapters/types/history.py @@ -1,64 +1,98 @@ -from typing import Any +import warnings +from typing import Any, Literal import pydantic class History(pydantic.BaseModel): - """Class representing the conversation history. - - The conversation history is a list of messages, each message entity should have keys from the associated signature. - For example, if you have the following signature: - - ``` - class MySignature(dspy.Signature): - question: str = dspy.InputField() - history: dspy.History = dspy.InputField() - answer: str = dspy.OutputField() - ``` + """ + Class representing conversation history for DSPy modules. - Then the history should be a list of dictionaries with keys "question" and "answer". + The `History` class allows you to attach previous conversation turns or context to a module. + By default, DSPy will auto-detect the appropriate message format based on your data's structure, + making it easy to work with different LM and prompt formats without manual intervention. - Example: + In most cases, you can just use: + ```python + history = dspy.History(messages=[...]) ``` - import dspy + and DSPy will infer the message mode (chat, demo, flat, etc.) for you. - dspy.configure(lm=dspy.LM("openai/gpt-4o-mini")) + **If needed, you can specify the mode explicitly** + (e.g., if auto-detection guesses incorrectly, or for advanced use cases): + ```python + history = dspy.History(messages=[...], mode="demo") # force demo mode + ``` - class MySignature(dspy.Signature): - question: str = dspy.InputField() - history: dspy.History = dspy.InputField() - answer: str = dspy.OutputField() + Modes: + - "raw": LM-style messages with role/content + - "demo": Few-shot examples with input_fields/output_fields + - "signature": Dict keys match signature fields → user/assistant pairs + - "flat": Arbitrary key-value pairs → single user messages (default) + + **Raw mode: Chat-style history (LM messages):** + ```python + history = dspy.History(messages=[ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ]) + ``` - history = dspy.History( - messages=[ - {"question": "What is the capital of France?", "answer": "Paris"}, - {"question": "What is the capital of Germany?", "answer": "Berlin"}, - ] - ) + **Signature mode: Signature-matched history (previous input/output pairs):** + ```python + history = dspy.History(messages=[ + {"question": "What is 2+2?", "answer": "4"}, + ]) + ``` - predict = dspy.Predict(MySignature) - outputs = predict(question="What is the capital of France?", history=history) + **Demo mode: Few-shot demonstrations:** + ```python + history = dspy.History(messages=[ + {"input_fields": {"question": "2+2?"}, "output_fields": {"answer": "4"}}, + ]) ``` - Example of capturing the conversation history: + **Flat mode: Arbitrary context (key-value pairs as user messages):** + ```python + history = dspy.History(messages=[ + {"thought": "I need to search", "tool": "search", "result": "Found it"}, + ]) ``` + + In summary: Just pass `dspy.History(messages=[...])` and auto-detect will do the right thing most of the time. + Override the `mode` argument only if DSPy cannot reliably infer the correct message format for your use case. + + Example: + ```python import dspy dspy.configure(lm=dspy.LM("openai/gpt-4o-mini")) - class MySignature(dspy.Signature): + class QA(dspy.Signature): question: str = dspy.InputField() history: dspy.History = dspy.InputField() answer: str = dspy.OutputField() - predict = dspy.Predict(MySignature) - outputs = predict(question="What is the capital of France?") - history = dspy.History(messages=[{"question": "What is the capital of France?", **outputs}]) - outputs_with_history = predict(question="Are you sure?", history=history) + predict = dspy.Predict(QA) + + # First turn + result = predict(question="What is the capital of France?") + + # Build history from previous turn using auto-detect + history = dspy.History(messages=[ + {"question": "What is the capital of France?", **result} + ]) + + # Or, explicitly specify mode if needed + # history = dspy.History(messages=[...], mode="signature") + + # Follow-up with context + result = predict(question="What about Germany?", history=history) ``` """ messages: list[dict[str, Any]] + mode: Literal["signature", "demo", "flat", "raw"] = "flat" model_config = pydantic.ConfigDict( frozen=True, @@ -66,3 +100,150 @@ class MySignature(dspy.Signature): validate_assignment=True, extra="forbid", ) + + @staticmethod + def _infer_mode_from_msg(msg: dict) -> str: + """Infer the mode from a message's structure. + + Detection rules (conservative): + - Raw: has "role" key and ONLY LM-like keys (role, content, tool_calls, tool_call_id, name) + - Demo: keys are ONLY "input_fields" and/or "output_fields" + - Flat: everything else (signature mode must be explicit, and is applied retroactively inside the adapter) + """ + keys = set(msg.keys()) + lm_keys = {"role", "content", "tool_calls", "tool_call_id", "name"} + + if "role" in keys and keys <= lm_keys: + return "raw" + + if keys <= {"input_fields", "output_fields"} and keys: + return "demo" + + return "flat" + + def _validate_msg_for_mode(self, msg: dict, mode: str) -> None: + """Validate a message conforms to the expected mode structure.""" + if mode == "raw": + if not isinstance(msg.get("role"), str): + raise ValueError(f"Raw mode: 'role' must be a string: {msg}") + content = msg.get("content") + if content is not None and not isinstance(content, (str, list)): + raise ValueError(f"Raw mode: 'content' must be a string, list, or None: {msg}") + + elif mode == "demo": + if "input_fields" in msg and not isinstance(msg["input_fields"], dict): + raise ValueError(f"Demo mode: 'input_fields' must be a dict: {msg}") + if "output_fields" in msg and not isinstance(msg["output_fields"], dict): + raise ValueError(f"Demo mode: 'output_fields' must be a dict: {msg}") + + elif mode == "signature": + if not isinstance(msg, dict) or not msg: + raise ValueError(f"Signature mode: messages must be non-empty dicts: {msg}") + + def _warn_if_likely_wrong_mode(self, msg: dict, stacklevel: int = 2) -> None: + """Warn if a flat-mode message looks like it was intended for another mode.""" + keys = set(msg.keys()) + + if "role" in keys: + warnings.warn( + f"History message has 'role' key but is in flat mode. " + f"Did you mean to use mode='raw'? Message keys: {sorted(keys)}", + UserWarning, + stacklevel=stacklevel, + ) + elif keys & {"input_fields", "output_fields"}: + warnings.warn( + f"History message has 'input_fields'/'output_fields' but is in flat mode. " + f"Did you mean to use mode='demo'? Message keys: {sorted(keys)}", + UserWarning, + stacklevel=stacklevel, + ) + + @pydantic.model_validator(mode="after") + def _validate_messages(self) -> "History": + if not self.messages: + return self + + # Only infer if mode is the default "flat" and messages clearly match another mode + if self.mode == "flat": + inferred = self._infer_mode_from_msg(self.messages[0]) + if inferred in {"raw", "demo"}: + object.__setattr__(self, "mode", inferred) + + for msg in self.messages: + self._validate_msg_for_mode(msg, self.mode) + if self.mode == "flat": + # stacklevel=6: warn -> _warn_if_likely_wrong_mode -> _validate_messages -> validator -> __init__ -> caller + self._warn_if_likely_wrong_mode(msg, stacklevel=6) + + return self + + def with_messages(self, messages: list[dict[str, Any]]) -> "History": + """Return a new History with additional messages appended.""" + return History(messages=[*self.messages, *messages], mode=self.mode) + + @classmethod + def from_raw(cls, messages: list[dict[str, Any]]) -> "History": + """Create History from LM-style messages with role/content. + + Use this for chat-style conversation history or ReAct trajectories + that are already formatted as LM messages. + + Example: + ```python + history = dspy.History.from_raw([ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ]) + ``` + """ + return cls(messages=messages, mode="raw") + + @classmethod + def from_demos(cls, examples: list[dict[str, Any]]) -> "History": + """Create History from few-shot demonstration examples. + + Each example should have 'input_fields' and/or 'output_fields' keys + containing the respective field dictionaries. + + Example: + ```python + history = dspy.History.from_demos([ + {"input_fields": {"question": "2+2?"}, "output_fields": {"answer": "4"}}, + ]) + ``` + """ + return cls(messages=examples, mode="demo") + + @classmethod + def from_signature_pairs(cls, messages: list[dict[str, Any]]) -> "History": + """Create History from signature-matched field pairs. + + Each message dict should have keys matching the signature's input/output + fields. Each dict becomes a user/assistant message pair. + + Example: + ```python + history = dspy.History.from_signature_pairs([ + {"question": "What is 2+2?", "answer": "4"}, + ]) + ``` + """ + return cls(messages=messages, mode="signature") + + @classmethod + def from_kv(cls, messages: list[dict[str, Any]]) -> "History": + """Create History from arbitrary key-value context. + + Each dict becomes a single user message containing all key-value pairs. + Use this when you want to pass context that should NOT be split into + user/assistant turns. + + Example: + ```python + history = dspy.History.from_kv([ + {"thought": "I need to search", "tool": "search", "result": "Found it"}, + ]) + ``` + """ + return cls(messages=messages, mode="flat") diff --git a/dspy/utils/inspect_history.py b/dspy/utils/inspect_history.py index 07934157fd..65a32fab25 100644 --- a/dspy/utils/inspect_history.py +++ b/dspy/utils/inspect_history.py @@ -10,6 +10,14 @@ def _blue(text: str, end: str = "\n"): return "\x1b[34m" + str(text) + "\x1b[0m" + end +def _yellow(text: str, end: str = "\n"): + return "\x1b[33m" + str(text) + "\x1b[0m" + end + + +def _cyan(text: str, end: str = "\n"): + return "\x1b[36m" + str(text) + "\x1b[0m" + end + + def pretty_print_history(history, n: int = 1): """Prints the last n prompts and their completions.""" @@ -22,37 +30,67 @@ def pretty_print_history(history, n: int = 1): print("\x1b[34m" + f"[{timestamp}]" + "\x1b[0m" + "\n") for msg in messages: - print(_red(f"{msg['role'].capitalize()} message:")) - if isinstance(msg["content"], str): - print(msg["content"].strip()) - else: - if isinstance(msg["content"], list): - for c in msg["content"]: - if c["type"] == "text": - print(c["text"].strip()) - elif c["type"] == "image_url": - image_str = "" - if "base64" in c["image_url"].get("url", ""): - len_base64 = len(c["image_url"]["url"].split("base64,")[1]) - image_str = ( - f"<{c['image_url']['url'].split('base64,')[0]}base64," - f"" - ) - else: - image_str = f"" - print(_blue(image_str.strip())) - elif c["type"] == "input_audio": - audio_format = c["input_audio"]["format"] - len_audio = len(c["input_audio"]["data"]) - audio_str = f"