From 2fd22c35141c04fbff9119f2838a0493f02868b3 Mon Sep 17 00:00:00 2001 From: Ben McHone Date: Fri, 12 Dec 2025 21:34:17 -0600 Subject: [PATCH 1/2] fix(BAMLAdapter): Use docstrings to describe BaseModels --- dspy/adapters/baml_adapter.py | 42 ++++++++++++++++++++++++----- tests/adapters/test_baml_adapter.py | 18 +++++++++++-- 2 files changed, 52 insertions(+), 8 deletions(-) diff --git a/dspy/adapters/baml_adapter.py b/dspy/adapters/baml_adapter.py index 129a6614d7..da256b504c 100644 --- a/dspy/adapters/baml_adapter.py +++ b/dspy/adapters/baml_adapter.py @@ -15,6 +15,7 @@ # Changing the comment symbol to Python's # rather than other languages' // seems to help COMMENT_SYMBOL = "#" +INDENTATION = " " def _render_type_str( @@ -52,7 +53,7 @@ def _render_type_str( if origin in (types.UnionType, Union): non_none_args = [arg for arg in args if arg is not type(None)] # Render the non-None part of the union - type_render = " or ".join([_render_type_str(arg, depth + 1, indent) for arg in non_none_args]) + type_render = " or ".join([_render_type_str(arg, depth + 1, indent, seen_models) for arg in non_none_args]) # Add "or null" if None was part of the union if len(non_none_args) < len(args): return f"{type_render} or null" @@ -70,14 +71,14 @@ def _render_type_str( # Build inner schema - the Pydantic model inside should use indent level for array contents inner_schema = _build_simplified_schema(inner_type, indent + 1, seen_models) # Format with proper bracket notation and indentation - current_indent = " " * indent + current_indent = INDENTATION * indent return f"[\n{inner_schema}\n{current_indent}]" else: - return f"{_render_type_str(inner_type, depth + 1, indent)}[]" + return f"{_render_type_str(inner_type, depth + 1, indent, seen_models)}[]" # dict[T1, T2] if origin is dict: - return f"dict[{_render_type_str(args[0], depth + 1, indent)}, {_render_type_str(args[1], depth + 1, indent)}]" + return f"dict[{_render_type_str(args[0], depth + 1, indent, seen_models)}, {_render_type_str(args[1], depth + 1, indent, seen_models)}]" # fallback if hasattr(annotation, "__name__"): @@ -106,8 +107,19 @@ def _build_simplified_schema( seen_models.add(pydantic_model) lines = [] - current_indent = " " * indent - next_indent = " " * (indent + 1) + current_indent = INDENTATION * indent + next_indent = INDENTATION * (indent + 1) + + # Add model docstring as a comment above the object if it exists + # Only do this for top-level schemas (indent=0), since nested field docstrings + # are already added before the field name in the parent schema + if indent == 0 and pydantic_model.__doc__: + docstring = pydantic_model.__doc__.strip() + # Handle multiline docstrings by prefixing each line with the comment symbol + for line in docstring.split("\n"): + line = line.strip() + if line: + lines.append(f"{current_indent}{COMMENT_SYMBOL} {line}") lines.append(f"{current_indent}{{") @@ -121,6 +133,24 @@ def _build_simplified_schema( # If there's an alias but no description, show the alias as a comment lines.append(f"{next_indent}{COMMENT_SYMBOL} alias: {field.alias}") + # If the field type is a BaseModel, add its docstring as a comment before the field + field_annotation = field.annotation + # Handle Optional types + origin = get_origin(field_annotation) + if origin in (types.UnionType, Union): + args = get_args(field_annotation) + non_none_args = [arg for arg in args if arg is not type(None)] + if len(non_none_args) == 1: + field_annotation = non_none_args[0] + + if inspect.isclass(field_annotation) and issubclass(field_annotation, BaseModel): + if field_annotation.__doc__: + docstring = field_annotation.__doc__.strip() + for line in docstring.split("\n"): + line = line.strip() + if line: + lines.append(f"{next_indent}{COMMENT_SYMBOL} {line}") + rendered_type = _render_type_str(field.annotation, indent=indent + 1, seen_models=seen_models) line = f"{next_indent}{name}: {rendered_type}," diff --git a/tests/adapters/test_baml_adapter.py b/tests/adapters/test_baml_adapter.py index eaaa4f0d23..827ad80046 100644 --- a/tests/adapters/test_baml_adapter.py +++ b/tests/adapters/test_baml_adapter.py @@ -7,23 +7,29 @@ from litellm.files.main import ModelResponse import dspy -from dspy.adapters.baml_adapter import COMMENT_SYMBOL, BAMLAdapter +from dspy.adapters.baml_adapter import COMMENT_SYMBOL, INDENTATION, BAMLAdapter # Test fixtures - Pydantic models for testing class PatientAddress(pydantic.BaseModel): + """Patient Address model docstring""" street: str city: str country: Literal["US", "CA"] class PatientDetails(pydantic.BaseModel): + """ + Patient Details model docstring + Multiline docstring support test + """ name: str = pydantic.Field(description="Full name of the patient") age: int address: PatientAddress | None = None class ComplexNestedModel(pydantic.BaseModel): + """Complex model docstring""" id: int = pydantic.Field(description="Unique identifier") details: PatientDetails tags: list[str] = pydantic.Field(default_factory=list) @@ -130,12 +136,20 @@ class TestSignature(dspy.Signature): adapter = BAMLAdapter() schema = adapter.format_field_structure(TestSignature) + expected_patient_details = "\n".join([ + f"{INDENTATION}{COMMENT_SYMBOL} Patient Details model docstring", + f"{INDENTATION}{COMMENT_SYMBOL} Multiline docstring support test", + f"{INDENTATION}details:", + ]) + # Should include nested structure with comments assert f"{COMMENT_SYMBOL} Unique identifier" in schema - assert "details:" in schema + assert expected_patient_details in schema assert f"{COMMENT_SYMBOL} Full name of the patient" in schema assert "tags: string[]," in schema assert "metadata: dict[string, string]," in schema + assert f"{COMMENT_SYMBOL} Complex model docstring" in schema + assert f"{COMMENT_SYMBOL} Patient Address model docstring" in schema def test_baml_adapter_raise_error_on_circular_references(): From a2813fe71f4eefe1ae79e6db034e81cde57aa019 Mon Sep 17 00:00:00 2001 From: Ben McHone Date: Sat, 13 Dec 2025 12:29:20 -0600 Subject: [PATCH 2/2] fix(BAMLAdapter): Use signature descriptions in system prompt --- dspy/adapters/baml_adapter.py | 22 ---------------------- tests/adapters/test_baml_adapter.py | 9 ++++++--- 2 files changed, 6 insertions(+), 25 deletions(-) diff --git a/dspy/adapters/baml_adapter.py b/dspy/adapters/baml_adapter.py index da256b504c..1d9cc5e6ec 100644 --- a/dspy/adapters/baml_adapter.py +++ b/dspy/adapters/baml_adapter.py @@ -209,28 +209,6 @@ class ExtractPatientInfo(dspy.Signature): ``` """ - def format_field_description(self, signature: type[Signature]) -> str: - """Format the field description for the system message.""" - sections = [] - - # Add input field descriptions - if signature.input_fields: - sections.append("Your input fields are:") - for i, (name, field) in enumerate(signature.input_fields.items(), 1): - type_name = getattr(field.annotation, "__name__", str(field.annotation)) - description = f": {field.description}" if field.description else ":" - sections.append(f"{i}. `{name}` ({type_name}){description}") - - # Add output field descriptions - if signature.output_fields: - sections.append("Your output fields are:") - for i, (name, field) in enumerate(signature.output_fields.items(), 1): - type_name = getattr(field.annotation, "__name__", str(field.annotation)) - description = f": {field.description}" if field.description else ":" - sections.append(f"{i}. `{name}` ({type_name}){description}") - - return "\n".join(sections) - def format_field_structure(self, signature: type[Signature]) -> str: """Overrides the base method to generate a simplified schema for Pydantic models.""" diff --git a/tests/adapters/test_baml_adapter.py b/tests/adapters/test_baml_adapter.py index 827ad80046..5cc1481334 100644 --- a/tests/adapters/test_baml_adapter.py +++ b/tests/adapters/test_baml_adapter.py @@ -515,9 +515,9 @@ class SystemConfig(pydantic.BaseModel): endpoints: list[str] class TestSignature(dspy.Signature): - input_1: UserProfile = dspy.InputField() - input_2: SystemConfig = dspy.InputField() - result: str = dspy.OutputField() + input_1: UserProfile = dspy.InputField(desc="User profile information") + input_2: SystemConfig = dspy.InputField(desc="System configuration settings") + result: str = dspy.OutputField(desc="Resulting output after processing") adapter = BAMLAdapter() @@ -535,7 +535,10 @@ class TestSignature(dspy.Signature): # Test field descriptions are in the correct method field_desc = adapter.format_field_description(TestSignature) assert "Your input fields are:" in field_desc + assert "1. `input_1` (UserProfile): User profile information" in field_desc + assert "2. `input_2` (SystemConfig): System configuration settings" in field_desc assert "Your output fields are:" in field_desc + assert "1. `result` (str): Resulting output after processing" in field_desc # Test message formatting with actual Pydantic instances user_profile = UserProfile(name="John Doe", email="john@example.com", age=30)