Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions go/ai/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -610,10 +610,16 @@ func loadPromptDir(r api.Registry, dir string, namespace string) {
}
}

// LoadPromptFromSource loads a single prompt from a string source into the registry.
// The name parameter can optionally include a variant suffix (e.g. "promptName.variant"),
// which will be parsed and used as the variant for the prompt.
func LoadPromptFromSource(r api.Registry, source, name, namespace string) Prompt {
return registerPrompt(r, source, name, namespace, "LoadPromptFromSource/"+name)
}

// LoadPrompt loads a single prompt into the registry.
func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt {
name := strings.TrimSuffix(filename, ".prompt")
name, variant, _ := strings.Cut(name, ".")

sourceFile := filepath.Join(dir, filename)
source, err := os.ReadFile(sourceFile)
Expand All @@ -622,17 +628,22 @@ func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt {
return nil
}

return registerPrompt(r, string(source), name, namespace, sourceFile)
}

func registerPrompt(r api.Registry, source, name, namespace, sourceInfo string) Prompt {
name, variant, _ := strings.Cut(name, ".")
dp := r.Dotprompt()

parsedPrompt, err := dp.Parse(string(source))
parsedPrompt, err := dp.Parse(source)
if err != nil {
slog.Error("Failed to parse file as dotprompt", "file", sourceFile, "error", err)
slog.Error("Failed to parse file as dotprompt", "source", sourceInfo, "error", err)
return nil
}

metadata, err := dp.RenderMetadata(string(source), &parsedPrompt.PromptMetadata)
metadata, err := dp.RenderMetadata(source, &parsedPrompt.PromptMetadata)
if err != nil {
slog.Error("Failed to render dotprompt metadata", "file", sourceFile, "error", err)
slog.Error("Failed to render dotprompt metadata", "source", sourceInfo, "error", err)
return nil
}

Expand Down Expand Up @@ -710,7 +721,7 @@ func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt {

dpMessages, err := dotprompt.ToMessages(parsedPrompt.Template, &dotprompt.DataArgument{})
if err != nil {
slog.Error("Failed to convert prompt template to messages", "file", sourceFile, "error", err)
slog.Error("Failed to convert prompt template to messages", "source", sourceInfo, "error", err)
return nil
}

Expand All @@ -719,7 +730,7 @@ func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt {
for _, dpMsg := range dpMessages {
parts, err := convertToPartPointers(dpMsg.Content)
if err != nil {
slog.Error("Failed to convert message parts", "file", sourceFile, "error", err)
slog.Error("Failed to convert message parts", "source", sourceInfo, "error", err)
return nil
}

Expand Down Expand Up @@ -754,7 +765,7 @@ func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt {

prompt := DefinePrompt(r, key, promptOpts...)

slog.Debug("Registered Dotprompt", "name", key, "file", sourceFile)
slog.Debug("Registered Dotprompt", "name", key, "source", sourceInfo)

return prompt
}
Expand Down
49 changes: 49 additions & 0 deletions go/ai/prompt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1506,3 +1506,52 @@ func TestWithOutputSchemaName_DefinePrompt_Missing(t *testing.T) {
t.Errorf("Expected error 'schema \"MissingSchema\" not found', got: %v", err)
}
}

func TestLoadPromptFromSource(t *testing.T) {
reg := registry.New()

source := `---
description: Test Prompt
model: test/model
input:
schema:
username: string
---
Hello {{username}}`

p := LoadPromptFromSource(reg, source, "testPrompt", "testNamespace")
if p == nil {
t.Fatal("LoadPromptFromSource returned nil")
}

if p.Name() != "testNamespace/testPrompt" {
t.Errorf("got name %q, want %q", p.Name(), "testNamespace/testPrompt")
}

// Verify variant handling in name
pVariant := LoadPromptFromSource(reg, source, "testPrompt.variant1", "testNamespace")
if pVariant == nil {
t.Fatal("LoadPromptFromSource with variant returned nil")
}

if pVariant.Name() != "testNamespace/testPrompt.variant1" {
t.Errorf("got name %q, want %q", pVariant.Name(), "testNamespace/testPrompt.variant1")
}

// Test rendering
input := map[string]any{"username": "Developer"}
req, err := p.Render(t.Context(), input)
if err != nil {
t.Fatalf("Render failed: %v", err)
}

if len(req.Messages) == 0 || len(req.Messages[0].Content) == 0 {
t.Fatal("Render produced empty messages")
}

gotText := req.Messages[0].Content[0].Text
wantText := "Hello Developer"
if gotText != wantText {
t.Errorf("Render got text %q, want %q", gotText, wantText)
}
}
Comment on lines +1556 to +1557
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The test for TestLoadPromptFromSource is great for the happy paths. To make it more robust, consider adding a test case for when an invalid prompt source is provided. This will ensure that the function correctly handles errors and returns nil as expected.

For example, you could add a check for a source with an unclosed template variable.

Suggested change
}
}
}
// Test invalid source
invalidSource := "---\ndescription: Invalid\n---\nHello {{ name" // unclosed template variable
pInvalid := LoadPromptFromSource(reg, invalidSource, "invalidPrompt", "testNamespace")
if pInvalid != nil {
t.Error("LoadPromptFromSource with invalid source should have returned nil")
}
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We cannot add a test case that expects LoadPromptFromSource to return nil for invalid template syntax (like an unclosed variable) because the underlying dotprompt library performs parsing loosely at load time. It does not validate the template syntax during Parse, so LoadPromptFromSource successfully returns a Prompt object. The error would only occur later during execution (e.g., when Render is called).

16 changes: 16 additions & 0 deletions go/genkit/genkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -1030,6 +1030,22 @@ func LoadPromptDir(g *Genkit, dir string, namespace string) {
ai.LoadPromptDir(g.reg, dir, namespace)
}

// LoadPromptFromSource loads a single prompt from a string source into the registry,
// associating it with the given `namespace`, and returns the resulting [ai.prompt].
//
// The `source` should be the string content of the prompt.
// The `name` is the name of the prompt (e.g. "greeting"). If the name contains a dot,
// the suffix is treated as the variant (e.g. "greeting.formal").
// The `namespace` acts as a prefix to the prompt name (e.g., namespace "myApp" and
// name "greeting" results in prompt name "myApp/greeting"). Use an empty
// string for no namespace.
//
// This provides a way to load specific prompt sources programmatically, outside of the
// automatic loading done by [Init] or [LoadPromptDir].
func LoadPromptFromSource(g *Genkit, source, name, namespace string) ai.Prompt {
return ai.LoadPromptFromSource(g.reg, source, name, namespace)
}

// LoadPrompt loads a single `.prompt` file specified by `path` into the registry,
// associating it with the given `namespace`, and returns the resulting [ai.prompt].
//
Expand Down
18 changes: 18 additions & 0 deletions go/genkit/genkit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,21 @@ func TestDefineSchemaWithType_Error(t *testing.T) {

DefineSchemaFor[Invalid](g)
}

func TestLoadPromptFromSource(t *testing.T) {
g := Init(t.Context())

source := `---
model: test/model
---
Hello`

p := LoadPromptFromSource(g, source, "testPrompt", "testNamespace")
if p == nil {
t.Fatal("LoadPromptFromSource returned nil")
}

if p.Name() != "testNamespace/testPrompt" {
t.Errorf("got name %q, want %q", p.Name(), "testNamespace/testPrompt")
}
}
Comment on lines +142 to +143
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This test covers the successful creation of a prompt. It would be beneficial to also test the failure case, for instance, when the provided source is invalid. This ensures the wrapper function correctly propagates the nil return on error.

Suggested change
}
}
}
// Test invalid source
invalidSource := "---\nmodel: test/model\n---\nHello {{ name" // unclosed template variable
pInvalid := LoadPromptFromSource(g, invalidSource, "invalidPrompt", "testNamespace")
if pInvalid != nil {
t.Error("LoadPromptFromSource with invalid source should have returned nil")
}
}