Skip to content

Commit e2a1508

Browse files
feat(py): loading prompts automatically (#3992)
Co-authored-by: Mengqin Shen <mengqin@google.com>
1 parent 7e37ed1 commit e2a1508

File tree

14 files changed

+163
-38
lines changed

14 files changed

+163
-38
lines changed

py/packages/genkit/src/genkit/ai/_aio.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ class while customizing it with any plugins.
2323
import uuid
2424
from asyncio import Future
2525
from collections.abc import AsyncIterator
26+
from pathlib import Path
2627
from typing import Any
2728

2829
from genkit.aio import Channel
@@ -38,7 +39,7 @@ class while customizing it with any plugins.
3839
GenerateResponseWrapper,
3940
ModelMiddleware,
4041
)
41-
from genkit.blocks.prompt import PromptConfig, to_generate_action_options
42+
from genkit.blocks.prompt import PromptConfig, load_prompt_folder, to_generate_action_options
4243
from genkit.blocks.retriever import IndexerRef, IndexerRequest, RetrieverRef
4344
from genkit.core.action import ActionRunContext
4445
from genkit.core.action.types import ActionKind
@@ -72,18 +73,30 @@ def __init__(
7273
self,
7374
plugins: list[Plugin] | None = None,
7475
model: str | None = None,
76+
prompt_dir: str | Path | None = None,
7577
reflection_server_spec: ServerSpec | None = None,
7678
) -> None:
7779
"""Initialize a new Genkit instance.
7880
7981
Args:
8082
plugins: List of plugins to initialize.
8183
model: Model name to use.
84+
prompt_dir: Directory to automatically load prompts from.
85+
If not provided, defaults to loading from './prompts' if it exists.
8286
reflection_server_spec: Server spec for the reflection
8387
server.
8488
"""
8589
super().__init__(plugins=plugins, model=model, reflection_server_spec=reflection_server_spec)
8690

91+
load_path = prompt_dir
92+
if load_path is None:
93+
default_prompts_path = Path('./prompts')
94+
if default_prompts_path.is_dir():
95+
load_path = default_prompts_path
96+
97+
if load_path:
98+
load_prompt_folder(self.registry, dir_path=load_path)
99+
87100
async def generate(
88101
self,
89102
model: str | None = None,

py/packages/genkit/src/genkit/ai/_registry.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,11 @@
5151
from genkit.blocks.evaluator import BatchEvaluatorFn, EvaluatorFn
5252
from genkit.blocks.formats.types import FormatDef
5353
from genkit.blocks.model import ModelFn, ModelMiddleware
54-
from genkit.blocks.prompt import define_prompt
54+
from genkit.blocks.prompt import (
55+
define_helper,
56+
define_prompt,
57+
lookup_prompt,
58+
)
5559
from genkit.blocks.retriever import IndexerFn, RetrieverFn
5660
from genkit.blocks.tools import ToolRunContext
5761
from genkit.codec import dump_dict
@@ -168,6 +172,15 @@ def sync_wrapper(*args, **kwargs):
168172

169173
return wrapper
170174

175+
def define_helper(self, name: str, fn: Callable) -> None:
176+
"""Define a Handlebars helper function in the registry.
177+
178+
Args:
179+
name: The name of the helper function.
180+
fn: The helper function to register.
181+
"""
182+
define_helper(self.registry, name, fn)
183+
171184
def tool(self, name: str | None = None, description: str | None = None) -> Callable[[Callable], Callable]:
172185
"""Decorator to register a function as a tool.
173186

py/packages/genkit/src/genkit/blocks/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ def my_model(request: GenerateRequest) -> GenerateResponse:
3636

3737
from pydantic import BaseModel, Field
3838

39-
from genkit.ai import ActionKind
4039
from genkit.core.action import ActionMetadata, ActionRunContext
40+
from genkit.core.action.types import ActionKind
4141
from genkit.core.extract import extract_json
4242
from genkit.core.schema import to_json_schema
4343
from genkit.core.typing import (

py/packages/genkit/src/genkit/blocks/prompt.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -659,10 +659,11 @@ async def render_dotprompt_to_parts(
659659
Raises:
660660
Exception: If the template produces more than one message.
661661
"""
662-
merged_input = input_
662+
# Flatten input and context for template resolution
663+
flattened_data = {**(context or {}), **(input_ or {})}
663664
rendered = await prompt_function(
664665
data=DataArgument[dict[str, Any]](
665-
input=merged_input,
666+
input=flattened_data,
666667
context=context,
667668
),
668669
options=options,
@@ -718,9 +719,11 @@ async def render_message_prompt(
718719
if isinstance(options.messages, list):
719720
messages_ = [e.model_dump() for e in options.messages]
720721

722+
# Flatten input and context for template resolution
723+
flattened_data = {**(context or {}), **(input or {})}
721724
rendered = await prompt_cache.messages(
722725
data=DataArgument[dict[str, Any]](
723-
input=input,
726+
input=flattened_data,
724727
context=context,
725728
messages=messages_,
726729
),
@@ -841,7 +844,7 @@ def define_helper(registry: Registry, name: str, fn: Callable) -> None:
841844
logger.debug(f'Registered Dotprompt helper "{name}"')
842845

843846

844-
def load_prompt(registry: Registry, path: Path, filename: str, prefix: str = '', ns: str = 'dotprompt') -> None:
847+
def load_prompt(registry: Registry, path: Path, filename: str, prefix: str = '', ns: str = '') -> None:
845848
"""Load a single prompt file and register it in the registry.
846849
847850
This function loads a .prompt file, parses it, and registers it as a lazy-loaded
@@ -1091,6 +1094,13 @@ def load_prompt_folder_recursively(registry: Registry, dir_path: Path, ns: str,
10911094
partial_name = entry.name[1:-7] # Remove "_" prefix and ".prompt" suffix
10921095
with open(entry.path, 'r', encoding='utf-8') as f:
10931096
source = f.read()
1097+
1098+
# Strip frontmatter if present
1099+
if source.startswith('---'):
1100+
end_frontmatter = source.find('---', 3)
1101+
if end_frontmatter != -1:
1102+
source = source[end_frontmatter + 3 :].strip()
1103+
10941104
define_partial(registry, partial_name, source)
10951105
logger.debug(f'Registered Dotprompt partial "{partial_name}" from "{entry.path}"')
10961106
else:
@@ -1107,7 +1117,7 @@ def load_prompt_folder_recursively(registry: Registry, dir_path: Path, ns: str,
11071117
logger.error(f'Error loading prompts from {full_path}: {e}')
11081118

11091119

1110-
def load_prompt_folder(registry: Registry, dir_path: str | Path = './prompts', ns: str = 'dotprompt') -> None:
1120+
def load_prompt_folder(registry: Registry, dir_path: str | Path = './prompts', ns: str = '') -> None:
11111121
"""Load all prompt files from a directory.
11121122
11131123
This is the main entry point for loading prompts from a directory.

py/packages/genkit/tests/genkit/blocks/prompt_test.py

Lines changed: 73 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -465,9 +465,9 @@ async def test_file_based_prompt_registers_two_actions() -> None:
465465
# Load prompts from directory
466466
load_prompt_folder(ai.registry, prompt_dir)
467467

468-
# Actions are registered with registry_definition_key (e.g., "dotprompt/filePrompt")
468+
# Actions are registered with registry_definition_key (e.g., "filePrompt")
469469
# We need to look them up by kind and name (without the /prompt/ prefix)
470-
action_name = 'dotprompt/filePrompt' # registry_definition_key format
470+
action_name = 'filePrompt' # registry_definition_key format
471471

472472
prompt_action = ai.registry.lookup_action(ActionKind.PROMPT, action_name)
473473
executable_prompt_action = ai.registry.lookup_action(ActionKind.EXECUTABLE_PROMPT, action_name)
@@ -491,7 +491,7 @@ async def test_prompt_and_executable_prompt_return_types() -> None:
491491
prompt_file.write_text('hello {{name}}')
492492

493493
load_prompt_folder(ai.registry, prompt_dir)
494-
action_name = 'dotprompt/testPrompt'
494+
action_name = 'testPrompt'
495495

496496
prompt_action = ai.registry.lookup_action(ActionKind.PROMPT, action_name)
497497
executable_prompt_action = ai.registry.lookup_action(ActionKind.EXECUTABLE_PROMPT, action_name)
@@ -540,7 +540,73 @@ async def test_prompt_function_uses_lookup_prompt() -> None:
540540

541541
load_prompt_folder(ai.registry, prompt_dir)
542542

543-
# Use prompt() function to look up the file-based prompt
544-
executable = await prompt(ai.registry, 'promptFuncTest')
545-
response = await executable({'name': 'World'})
546-
assert 'World' in response.text
543+
# Use ai.prompt() to look up the file-based prompt
544+
executable = await ai.prompt('promptFuncTest')
545+
546+
# Verify it can be executed
547+
response = await executable({'name': 'Genkit'})
548+
assert 'Genkit' in response.text
549+
550+
551+
@pytest.mark.asyncio
552+
async def test_automatic_prompt_loading():
553+
"""Test that Genkit automatically loads prompts from a directory."""
554+
with tempfile.TemporaryDirectory() as tmp_dir:
555+
# Create a prompt file
556+
prompt_content = """---
557+
name: testPrompt
558+
---
559+
Hello {{name}}!
560+
"""
561+
prompt_file = Path(tmp_dir) / 'test.prompt'
562+
prompt_file.write_text(prompt_content)
563+
564+
# Initialize Genkit with the temporary directory
565+
ai = Genkit(prompt_dir=tmp_dir)
566+
567+
# Verify the prompt is registered
568+
# File-based prompts are registered with an empty namespace by default
569+
actions = ai.registry.list_serializable_actions()
570+
assert '/prompt/test' in actions
571+
assert '/executable-prompt/test' in actions
572+
573+
574+
@pytest.mark.asyncio
575+
async def test_automatic_prompt_loading_default_none():
576+
"""Test that Genkit does not load prompts if prompt_dir is None."""
577+
ai = Genkit(prompt_dir=None)
578+
actions = ai.registry.list_serializable_actions()
579+
580+
# Check that no prompts are registered (assuming a clean environment)
581+
dotprompts = [key for key in actions.keys() if '/prompt/' in key or '/executable-prompt/' in key]
582+
assert len(dotprompts) == 0
583+
584+
585+
@pytest.mark.asyncio
586+
async def test_automatic_prompt_loading_defaults_mock():
587+
"""Test that Genkit defaults to ./prompts when prompt_dir is not specified and dir exists."""
588+
from unittest.mock import ANY, MagicMock, patch
589+
590+
with patch('genkit.ai._aio.load_prompt_folder') as mock_load, patch('genkit.ai._aio.Path') as mock_path:
591+
# Setup mock to simulate ./prompts existing
592+
mock_path_instance = MagicMock()
593+
mock_path_instance.is_dir.return_value = True
594+
mock_path.return_value = mock_path_instance
595+
596+
Genkit()
597+
mock_load.assert_called_once_with(ANY, dir_path=mock_path_instance)
598+
599+
600+
@pytest.mark.asyncio
601+
async def test_automatic_prompt_loading_defaults_missing():
602+
"""Test that Genkit skips loading when ./prompts is missing."""
603+
from unittest.mock import ANY, MagicMock, patch
604+
605+
with patch('genkit.ai._aio.load_prompt_folder') as mock_load, patch('genkit.ai._aio.Path') as mock_path:
606+
# Setup mock to simulate ./prompts missing
607+
mock_path_instance = MagicMock()
608+
mock_path_instance.is_dir.return_value = False
609+
mock_path.return_value = mock_path_instance
610+
611+
Genkit()
612+
mock_load.assert_not_called()

py/samples/evaluator-demo/src/eval_demo.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import json
1818
import os
19-
from typing import Any
19+
from typing import Any, List
2020

2121
import pytest
2222
import structlog
@@ -27,7 +27,7 @@
2727

2828
logger = structlog.get_logger(__name__)
2929

30-
ai = Genkit(plugins=[GoogleAI()])
30+
ai = Genkit(plugins=[GoogleAI()], model='googleai/gemini-2.5-flash')
3131

3232

3333
async def substring_match(datapoint: BaseDataPoint, options: Any | None):
@@ -54,15 +54,19 @@ async def substring_match(datapoint: BaseDataPoint, options: Any | None):
5454
)
5555

5656

57+
5758
# Define a flow that programmatically runs the evaluation
5859
@ai.flow()
59-
async def run_eval_demo(input: Any = None):
60-
# Load dataset
61-
data_path = os.path.join(os.path.dirname(__file__), '..', 'data', 'dataset.json')
62-
with open(data_path, 'r') as f:
63-
raw_data = json.load(f)
64-
65-
dataset = [BaseDataPoint(**d) for d in raw_data]
60+
async def run_eval_demo(dataset_input: List[BaseDataPoint] | None = None):
61+
if dataset_input:
62+
dataset = dataset_input
63+
else:
64+
# Load dataset
65+
data_path = os.path.join(os.path.dirname(__file__), '..', 'data', 'dataset.json')
66+
with open(data_path, 'r') as f:
67+
raw_data = json.load(f)
68+
69+
dataset = [BaseDataPoint(**d) for d in raw_data]
6670

6771
logger.info('Running evaluation...', count=len(dataset))
6872

py/samples/prompt_demo/README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,6 @@ genkit start -- uv run src/prompt_demo.py
1919

2020
## Prompt Structure
2121

22-
- `data/`: Contains `.prompt` files (using [Dotprompt](https://genkit.dev/docs/dotprompt)).
23-
- `data/_shared_partial.prompt`: A partial that can be included in other prompts.
24-
- `data/nested/nested_hello.prompt`: A prompt demonstrating nested structure and partial inclusion.
22+
- `prompts/`: Contains `.prompt` files (using [Dotprompt](https://genkit.dev/docs/dotprompt)).
23+
- `prompts/_shared_partial.prompt`: A partial that can be included in other prompts.
24+
- `prompts/nested/nested_hello.prompt`: A prompt demonstrating nested structure and partial inclusion.

py/samples/prompt_demo/data/hello.prompt

Lines changed: 0 additions & 1 deletion
This file was deleted.
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
---
2-
model: googleai/gemini-1.5-flash
2+
model: googleai/gemini-2.5-flash
33
---
44
This is a PARTIAL that says: {{my_helper "Partial content with helper"}}
File renamed without changes.

0 commit comments

Comments
 (0)