diff --git a/go/ai/prompt.go b/go/ai/prompt.go index e41d248fc3..68bad27090 100644 --- a/go/ai/prompt.go +++ b/go/ai/prompt.go @@ -610,10 +610,16 @@ func loadPromptDir(r api.Registry, dir string, namespace string) { } } +// LoadPromptFromSource loads a single prompt from a string source into the registry. +// The name parameter can optionally include a variant suffix (e.g. "promptName.variant"), +// which will be parsed and used as the variant for the prompt. +func LoadPromptFromSource(r api.Registry, source, name, namespace string) Prompt { + return registerPrompt(r, source, name, namespace, "LoadPromptFromSource/"+name) +} + // LoadPrompt loads a single prompt into the registry. func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt { name := strings.TrimSuffix(filename, ".prompt") - name, variant, _ := strings.Cut(name, ".") sourceFile := filepath.Join(dir, filename) source, err := os.ReadFile(sourceFile) @@ -622,17 +628,22 @@ func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt { return nil } + return registerPrompt(r, string(source), name, namespace, sourceFile) +} + +func registerPrompt(r api.Registry, source, name, namespace, sourceInfo string) Prompt { + name, variant, _ := strings.Cut(name, ".") dp := r.Dotprompt() - parsedPrompt, err := dp.Parse(string(source)) + parsedPrompt, err := dp.Parse(source) if err != nil { - slog.Error("Failed to parse file as dotprompt", "file", sourceFile, "error", err) + slog.Error("Failed to parse file as dotprompt", "source", sourceInfo, "error", err) return nil } - metadata, err := dp.RenderMetadata(string(source), &parsedPrompt.PromptMetadata) + metadata, err := dp.RenderMetadata(source, &parsedPrompt.PromptMetadata) if err != nil { - slog.Error("Failed to render dotprompt metadata", "file", sourceFile, "error", err) + slog.Error("Failed to render dotprompt metadata", "source", sourceInfo, "error", err) return nil } @@ -710,7 +721,7 @@ func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt { dpMessages, err := dotprompt.ToMessages(parsedPrompt.Template, &dotprompt.DataArgument{}) if err != nil { - slog.Error("Failed to convert prompt template to messages", "file", sourceFile, "error", err) + slog.Error("Failed to convert prompt template to messages", "source", sourceInfo, "error", err) return nil } @@ -719,7 +730,7 @@ func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt { for _, dpMsg := range dpMessages { parts, err := convertToPartPointers(dpMsg.Content) if err != nil { - slog.Error("Failed to convert message parts", "file", sourceFile, "error", err) + slog.Error("Failed to convert message parts", "source", sourceInfo, "error", err) return nil } @@ -754,7 +765,7 @@ func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt { prompt := DefinePrompt(r, key, promptOpts...) - slog.Debug("Registered Dotprompt", "name", key, "file", sourceFile) + slog.Debug("Registered Dotprompt", "name", key, "source", sourceInfo) return prompt } diff --git a/go/ai/prompt_test.go b/go/ai/prompt_test.go index d8109d2eff..95296baa77 100644 --- a/go/ai/prompt_test.go +++ b/go/ai/prompt_test.go @@ -1506,3 +1506,52 @@ func TestWithOutputSchemaName_DefinePrompt_Missing(t *testing.T) { t.Errorf("Expected error 'schema \"MissingSchema\" not found', got: %v", err) } } + +func TestLoadPromptFromSource(t *testing.T) { + reg := registry.New() + + source := `--- +description: Test Prompt +model: test/model +input: + schema: + username: string +--- +Hello {{username}}` + + p := LoadPromptFromSource(reg, source, "testPrompt", "testNamespace") + if p == nil { + t.Fatal("LoadPromptFromSource returned nil") + } + + if p.Name() != "testNamespace/testPrompt" { + t.Errorf("got name %q, want %q", p.Name(), "testNamespace/testPrompt") + } + + // Verify variant handling in name + pVariant := LoadPromptFromSource(reg, source, "testPrompt.variant1", "testNamespace") + if pVariant == nil { + t.Fatal("LoadPromptFromSource with variant returned nil") + } + + if pVariant.Name() != "testNamespace/testPrompt.variant1" { + t.Errorf("got name %q, want %q", pVariant.Name(), "testNamespace/testPrompt.variant1") + } + + // Test rendering + input := map[string]any{"username": "Developer"} + req, err := p.Render(t.Context(), input) + if err != nil { + t.Fatalf("Render failed: %v", err) + } + + if len(req.Messages) == 0 || len(req.Messages[0].Content) == 0 { + t.Fatal("Render produced empty messages") + } + + gotText := req.Messages[0].Content[0].Text + wantText := "Hello Developer" + if gotText != wantText { + t.Errorf("Render got text %q, want %q", gotText, wantText) + } +} diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 75ef8c9a8a..8db5aaf99e 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -1030,6 +1030,22 @@ func LoadPromptDir(g *Genkit, dir string, namespace string) { ai.LoadPromptDir(g.reg, dir, namespace) } +// LoadPromptFromSource loads a single prompt from a string source into the registry, +// associating it with the given `namespace`, and returns the resulting [ai.prompt]. +// +// The `source` should be the string content of the prompt. +// The `name` is the name of the prompt (e.g. "greeting"). If the name contains a dot, +// the suffix is treated as the variant (e.g. "greeting.formal"). +// The `namespace` acts as a prefix to the prompt name (e.g., namespace "myApp" and +// name "greeting" results in prompt name "myApp/greeting"). Use an empty +// string for no namespace. +// +// This provides a way to load specific prompt sources programmatically, outside of the +// automatic loading done by [Init] or [LoadPromptDir]. +func LoadPromptFromSource(g *Genkit, source, name, namespace string) ai.Prompt { + return ai.LoadPromptFromSource(g.reg, source, name, namespace) +} + // LoadPrompt loads a single `.prompt` file specified by `path` into the registry, // associating it with the given `namespace`, and returns the resulting [ai.prompt]. // diff --git a/go/genkit/genkit_test.go b/go/genkit/genkit_test.go index 1513b28dd6..8cef413b29 100644 --- a/go/genkit/genkit_test.go +++ b/go/genkit/genkit_test.go @@ -123,3 +123,21 @@ func TestDefineSchemaWithType_Error(t *testing.T) { DefineSchemaFor[Invalid](g) } + +func TestLoadPromptFromSource(t *testing.T) { + g := Init(t.Context()) + + source := `--- +model: test/model +--- +Hello` + + p := LoadPromptFromSource(g, source, "testPrompt", "testNamespace") + if p == nil { + t.Fatal("LoadPromptFromSource returned nil") + } + + if p.Name() != "testNamespace/testPrompt" { + t.Errorf("got name %q, want %q", p.Name(), "testNamespace/testPrompt") + } +}