diff --git a/go/ai/formatter_test.go b/go/ai/formatter_test.go index 40afd0008d..5f75d77acf 100644 --- a/go/ai/formatter_test.go +++ b/go/ai/formatter_test.go @@ -661,17 +661,14 @@ func TestResolveFormat(t *testing.T) { } }) - t.Run("defaults to text even when schema present but no format", func(t *testing.T) { + t.Run("defaults to json when schema present but no format", func(t *testing.T) { schema := map[string]any{"type": "object"} formatter, err := resolveFormat(r, schema, "") if err != nil { t.Fatalf("resolveFormat() error = %v", err) } - // Note: The current implementation defaults to text when format is empty, - // even if schema is present. The schema/format combination is typically - // handled at a higher level (e.g., in Generate options). - if formatter.Name() != OutputFormatText { - t.Errorf("resolveFormat() = %q, want %q", formatter.Name(), OutputFormatText) + if formatter.Name() != OutputFormatJSON { + t.Errorf("resolveFormat() = %q, want %q", formatter.Name(), OutputFormatJSON) } }) diff --git a/go/ai/generate.go b/go/ai/generate.go index f26cc9f09a..0d3d75e7db 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -21,6 +21,7 @@ import ( "encoding/json" "errors" "fmt" + "iter" "slices" "strings" @@ -550,7 +551,7 @@ func GenerateText(ctx context.Context, r api.Registry, opts ...GenerateOption) ( return res.Text(), nil } -// Generate run generate request for this model. Returns ModelResponse struct. +// GenerateData runs a generate request and returns strongly-typed output. func GenerateData[Out any](ctx context.Context, r api.Registry, opts ...GenerateOption) (*Out, *ModelResponse, error) { var value Out opts = append(opts, WithOutputType(value)) @@ -568,6 +569,104 @@ func GenerateData[Out any](ctx context.Context, r api.Registry, opts ...Generate return &value, resp, nil } +// StreamValue is either a streamed chunk or the final response of a generate request. +type StreamValue[Out, Stream any] struct { + Done bool + Chunk Stream // valid if Done is false + Output Out // valid if Done is true + Response *ModelResponse // valid if Done is true +} + +// ModelStreamValue is a stream value for a model response. +// Out is never set because the output is already available in the Response field. +type ModelStreamValue = StreamValue[struct{}, *ModelResponseChunk] + +// errGenerateStop is a sentinel error used to signal early termination of streaming. +var errGenerateStop = errors.New("stop") + +// GenerateStream generates a model response and streams the output. +// It returns an iterator that yields streaming results. +// +// If the yield function is passed a non-nil error, generation has failed with that +// error; the yield function will not be called again. +// +// If the yield function's [ModelStreamValue] argument has Done == true, the value's +// Response field contains the final response; the yield function will not be called +// again. +// +// Otherwise the Chunk field of the passed [ModelStreamValue] holds a streamed chunk. +func GenerateStream(ctx context.Context, r api.Registry, opts ...GenerateOption) iter.Seq2[*ModelStreamValue, error] { + return func(yield func(*ModelStreamValue, error) bool) { + cb := func(ctx context.Context, chunk *ModelResponseChunk) error { + if ctx.Err() != nil { + return ctx.Err() + } + if !yield(&ModelStreamValue{Chunk: chunk}, nil) { + return errGenerateStop + } + return nil + } + + allOpts := append(slices.Clone(opts), WithStreaming(cb)) + + resp, err := Generate(ctx, r, allOpts...) 
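+		// If the consumer stopped the loop early, cb returned errGenerateStop and
+		// Generate surfaces that error here; per the range-over-func contract,
+		// yield must not be called again once it has returned false. A minimal
+		// guard sketch (assumes Generate propagates the callback error, directly
+		// or wrapped):
+		if errors.Is(err, errGenerateStop) {
+			return
+		}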
+ if err != nil { + yield(nil, err) + } else { + yield(&ModelStreamValue{Done: true, Response: resp}, nil) + } + } +} + +// GenerateDataStream generates a model response with streaming and returns strongly-typed output. +// It returns an iterator that yields streaming results. +// +// If the yield function is passed a non-nil error, generation has failed with that +// error; the yield function will not be called again. +// +// If the yield function's [StreamValue] argument has Done == true, the value's +// Output and Response fields contain the final typed output and response; the yield function +// will not be called again. +// +// Otherwise the Chunk field of the passed [StreamValue] holds a streamed chunk. +func GenerateDataStream[Out any](ctx context.Context, r api.Registry, opts ...GenerateOption) iter.Seq2[*StreamValue[Out, Out], error] { + return func(yield func(*StreamValue[Out, Out], error) bool) { + cb := func(ctx context.Context, chunk *ModelResponseChunk) error { + if ctx.Err() != nil { + return ctx.Err() + } + var streamValue Out + if err := chunk.Output(&streamValue); err != nil { + yield(nil, err) + return err + } + if !yield(&StreamValue[Out, Out]{Chunk: streamValue}, nil) { + return errGenerateStop + } + return nil + } + + // Prepend WithOutputType so the user can override the output format. + var value Out + allOpts := append([]GenerateOption{WithOutputType(value)}, opts...) + allOpts = append(allOpts, WithStreaming(cb)) + + resp, err := Generate(ctx, r, allOpts...) + if err != nil { + yield(nil, err) + return + } + + output, err := extractTypedOutput[Out](resp) + if err != nil { + yield(nil, err) + return + } + + yield(&StreamValue[Out, Out]{Done: true, Output: output, Response: resp}, nil) + } +} + // Generate applies the [Action] to provided request. func (m *model) Generate(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { if m == nil { @@ -744,7 +843,7 @@ func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, // [ModelResponse] as a string. It returns an empty string if there // are no candidates or if the candidate has no message. func (mr *ModelResponse) Text() string { - if mr.Message == nil { + if mr == nil || mr.Message == nil { return "" } return mr.Message.Text() @@ -753,7 +852,7 @@ func (mr *ModelResponse) Text() string { // History returns messages from the request combined with the response message // to represent the conversation history. func (mr *ModelResponse) History() []*Message { - if mr.Message == nil { + if mr == nil || mr.Message == nil { return mr.Request.Messages } return append(mr.Request.Messages, mr.Message) @@ -762,7 +861,7 @@ func (mr *ModelResponse) History() []*Message { // Reasoning concatenates all reasoning parts present in the message func (mr *ModelResponse) Reasoning() string { var sb strings.Builder - if mr.Message == nil { + if mr == nil || mr.Message == nil { return "" } @@ -806,7 +905,7 @@ func (mr *ModelResponse) Output(v any) error { // ToolRequests returns the tool requests from the response. func (mr *ModelResponse) ToolRequests() []*ToolRequest { toolReqs := []*ToolRequest{} - if mr.Message == nil { + if mr == nil || mr.Message == nil { return toolReqs } for _, part := range mr.Message.Content { @@ -820,7 +919,7 @@ func (mr *ModelResponse) ToolRequests() []*ToolRequest { // Interrupts returns the interrupted tool request parts from the response. 
func (mr *ModelResponse) Interrupts() []*Part { parts := []*Part{} - if mr.Message == nil { + if mr == nil || mr.Message == nil { return parts } for _, part := range mr.Message.Content { @@ -833,7 +932,7 @@ func (mr *ModelResponse) Interrupts() []*Part { // Media returns the media content of the [ModelResponse] as a string. func (mr *ModelResponse) Media() string { - if mr.Message == nil { + if mr == nil || mr.Message == nil { return "" } for _, part := range mr.Message.Content { @@ -902,17 +1001,41 @@ func (c *ModelResponseChunk) Output(v any) error { // outputer is an interface for types that can unmarshal structured output. type outputer interface { - Output(v any) error + // Text returns the contents of the output as a string. + Text() string + // Output parses the structured output from the response and unmarshals it into value. + Output(value any) error } // OutputFrom is a convenience function that parses structured output from a // [ModelResponse] or [ModelResponseChunk] and returns it as a typed value. // This is equivalent to calling Output() but returns the value directly instead // of requiring a pointer argument. If you need to handle the error, use Output() instead. -func OutputFrom[T any](src outputer) T { - var v T - src.Output(&v) - return v +func OutputFrom[Out any](src outputer) Out { + output, err := extractTypedOutput[Out](src) + if err != nil { + return base.Zero[Out]() + } + return output +} + +// extractTypedOutput extracts the typed output from a model response. +// It supports string output by calling Text() and returning the result. +func extractTypedOutput[Out any](o outputer) (Out, error) { + var output Out + + switch any(output).(type) { + case string: + text := o.Text() + // Type assertion to convert string to Out (which we know is string). + result := any(text).(Out) + return result, nil + default: + if err := o.Output(&output); err != nil { + return base.Zero[Out](), fmt.Errorf("failed to parse output: %w", err) + } + return output, nil + } } // Text returns the contents of a [Message] as a string. 
It diff --git a/go/ai/generate_test.go b/go/ai/generate_test.go index cac1f9d508..fd50408ee6 100644 --- a/go/ai/generate_test.go +++ b/go/ai/generate_test.go @@ -18,6 +18,7 @@ package ai import ( "context" + "errors" "fmt" "math" "strings" @@ -1745,3 +1746,354 @@ func TestMultipartTools(t *testing.T) { } }) } + +// streamingTestData holds test output structures +type streamingTestData struct { + Name string `json:"name"` + Value int `json:"value"` +} + +func TestGenerateStream(t *testing.T) { + r := registry.New() + ConfigureFormats(r) + DefineGenerateAction(context.Background(), r) + + t.Run("yields chunks then final response", func(t *testing.T) { + chunkTexts := []string{"Hello", " ", "World"} + chunkIndex := 0 + + streamModel := DefineModel(r, "test/streamModel", &ModelOptions{ + Supports: &ModelSupports{Multiturn: true}, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + if cb != nil { + for _, text := range chunkTexts { + cb(ctx, &ModelResponseChunk{ + Content: []*Part{NewTextPart(text)}, + }) + } + } + return &ModelResponse{ + Request: req, + Message: NewModelTextMessage("Hello World"), + }, nil + }) + + var receivedChunks []*ModelResponseChunk + var finalResponse *ModelResponse + + for val, err := range GenerateStream(context.Background(), r, + WithModel(streamModel), + WithPrompt("test streaming"), + ) { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if val.Done { + finalResponse = val.Response + } else { + receivedChunks = append(receivedChunks, val.Chunk) + chunkIndex++ + } + } + + if len(receivedChunks) != len(chunkTexts) { + t.Errorf("expected %d chunks, got %d", len(chunkTexts), len(receivedChunks)) + } + + for i, chunk := range receivedChunks { + if chunk.Text() != chunkTexts[i] { + t.Errorf("chunk %d: expected %q, got %q", i, chunkTexts[i], chunk.Text()) + } + } + + if finalResponse == nil { + t.Fatal("expected final response") + } + if finalResponse.Text() != "Hello World" { + t.Errorf("expected final text %q, got %q", "Hello World", finalResponse.Text()) + } + }) + + t.Run("handles no streaming callback gracefully", func(t *testing.T) { + noStreamModel := DefineModel(r, "test/noStreamModel", &ModelOptions{ + Supports: &ModelSupports{Multiturn: true}, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + return &ModelResponse{ + Request: req, + Message: NewModelTextMessage("response without streaming"), + }, nil + }) + + var finalResponse *ModelResponse + chunkCount := 0 + + for val, err := range GenerateStream(context.Background(), r, + WithModel(noStreamModel), + WithPrompt("test no stream"), + ) { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if val.Done { + finalResponse = val.Response + } else { + chunkCount++ + } + } + + if chunkCount != 0 { + t.Errorf("expected 0 chunks when model doesn't stream, got %d", chunkCount) + } + if finalResponse == nil { + t.Fatal("expected final response") + } + if finalResponse.Text() != "response without streaming" { + t.Errorf("expected text %q, got %q", "response without streaming", finalResponse.Text()) + } + }) + + t.Run("propagates generation errors", func(t *testing.T) { + expectedErr := errors.New("generation failed") + + errorModel := DefineModel(r, "test/errorModel", &ModelOptions{ + Supports: &ModelSupports{Multiturn: true}, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + return nil, expectedErr + }) + + var receivedErr error + for _, err := 
range GenerateStream(context.Background(), r, + WithModel(errorModel), + WithPrompt("test error"), + ) { + if err != nil { + receivedErr = err + break + } + } + + if receivedErr == nil { + t.Fatal("expected error to be propagated") + } + if !errors.Is(receivedErr, expectedErr) { + t.Errorf("expected error %v, got %v", expectedErr, receivedErr) + } + }) + + t.Run("context cancellation stops iteration", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + streamModel := DefineModel(r, "test/cancelModel", &ModelOptions{ + Supports: &ModelSupports{Multiturn: true}, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + if cb != nil { + for i := 0; i < 100; i++ { + err := cb(ctx, &ModelResponseChunk{ + Content: []*Part{NewTextPart("chunk")}, + }) + if err != nil { + return nil, err + } + } + } + return &ModelResponse{ + Request: req, + Message: NewModelTextMessage("done"), + }, nil + }) + + chunksReceived := 0 + var receivedErr error + for val, err := range GenerateStream(ctx, r, + WithModel(streamModel), + WithPrompt("test cancel"), + ) { + if err != nil { + receivedErr = err + break + } + if !val.Done { + chunksReceived++ + if chunksReceived == 2 { + cancel() + } + } + } + + if chunksReceived < 2 { + t.Errorf("expected at least 2 chunks before cancellation, got %d", chunksReceived) + } + if receivedErr == nil { + t.Error("expected error from cancelled context") + } + }) +} + +func TestGenerateDataStream(t *testing.T) { + r := registry.New() + ConfigureFormats(r) + DefineGenerateAction(context.Background(), r) + + t.Run("yields typed chunks and final output", func(t *testing.T) { + streamModel := DefineModel(r, "test/typedStreamModel", &ModelOptions{ + Supports: &ModelSupports{ + Multiturn: true, + Constrained: ConstrainedSupportAll, + }, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + if cb != nil { + cb(ctx, &ModelResponseChunk{ + Content: []*Part{NewJSONPart(`{"name":"partial","value":1}`)}, + }) + cb(ctx, &ModelResponseChunk{ + Content: []*Part{NewJSONPart(`{"name":"complete","value":42}`)}, + }) + } + return &ModelResponse{ + Request: req, + Message: &Message{ + Role: RoleModel, + Content: []*Part{NewJSONPart(`{"name":"final","value":42}`)}, + }, + }, nil + }) + + var chunks []streamingTestData + var finalOutput streamingTestData + var finalResponse *ModelResponse + + for val, err := range GenerateDataStream[streamingTestData](context.Background(), r, + WithModel(streamModel), + WithPrompt("test typed streaming"), + ) { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if val.Done { + finalOutput = val.Output + finalResponse = val.Response + } else { + chunks = append(chunks, val.Chunk) + } + } + + if len(chunks) < 1 { + t.Errorf("expected at least 1 chunk, got %d", len(chunks)) + } + + if finalOutput.Name != "final" || finalOutput.Value != 42 { + t.Errorf("expected final output {final, 42}, got %+v", finalOutput) + } + if finalResponse == nil { + t.Fatal("expected final response") + } + }) + + t.Run("final output is correctly typed", func(t *testing.T) { + streamModel := DefineModel(r, "test/finalTypedModel", &ModelOptions{ + Supports: &ModelSupports{ + Multiturn: true, + Constrained: ConstrainedSupportAll, + }, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + return &ModelResponse{ + Request: req, + Message: &Message{ + Role: RoleModel, + Content: 
[]*Part{NewJSONPart(`{"name":"result","value":123}`)}, + }, + }, nil + }) + + var finalOutput streamingTestData + var gotFinal bool + + for val, err := range GenerateDataStream[streamingTestData](context.Background(), r, + WithModel(streamModel), + WithPrompt("test final typed"), + ) { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if val.Done { + finalOutput = val.Output + gotFinal = true + } + } + + if !gotFinal { + t.Fatal("expected to receive final output") + } + if finalOutput.Name != "result" || finalOutput.Value != 123 { + t.Errorf("expected final output {result, 123}, got %+v", finalOutput) + } + }) + + t.Run("automatically sets output type", func(t *testing.T) { + var capturedRequest *ModelRequest + + streamModel := DefineModel(r, "test/autoOutputModel", &ModelOptions{ + Supports: &ModelSupports{ + Multiturn: true, + Constrained: ConstrainedSupportAll, + }, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + capturedRequest = req + return &ModelResponse{ + Request: req, + Message: &Message{ + Role: RoleModel, + Content: []*Part{NewJSONPart(`{"name":"test","value":1}`)}, + }, + }, nil + }) + + for range GenerateDataStream[streamingTestData](context.Background(), r, + WithModel(streamModel), + WithPrompt("test auto output type"), + ) { + } + + if capturedRequest == nil { + t.Fatal("expected request to be captured") + } + if capturedRequest.Output == nil || capturedRequest.Output.Schema == nil { + t.Error("expected output schema to be set automatically") + } + }) + + t.Run("propagates chunk parsing errors", func(t *testing.T) { + streamModel := DefineModel(r, "test/parseErrorModel", &ModelOptions{ + Supports: &ModelSupports{ + Multiturn: true, + Constrained: ConstrainedSupportAll, + }, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + if cb != nil { + cb(ctx, &ModelResponseChunk{ + Content: []*Part{NewTextPart("not valid json")}, + }) + } + return &ModelResponse{ + Request: req, + Message: NewModelTextMessage("done"), + }, nil + }) + + var receivedErr error + for _, err := range GenerateDataStream[streamingTestData](context.Background(), r, + WithModel(streamModel), + WithPrompt("test parse error"), + ) { + if err != nil { + receivedErr = err + break + } + } + + if receivedErr == nil { + t.Error("expected parsing error to be propagated") + } + }) +} diff --git a/go/ai/prompt.go b/go/ai/prompt.go index db4ec264cd..5bb9135548 100644 --- a/go/ai/prompt.go +++ b/go/ai/prompt.go @@ -19,11 +19,13 @@ import ( "encoding/json" "errors" "fmt" + "iter" "log/slog" "maps" "os" "path/filepath" "reflect" + "slices" "strings" "github.com/firebase/genkit/go/core" @@ -40,6 +42,8 @@ type Prompt interface { Name() string // Execute executes the prompt with the given options and returns a [ModelResponse]. Execute(ctx context.Context, opts ...PromptExecuteOption) (*ModelResponse, error) + // ExecuteStream executes the prompt with streaming and returns an iterator. + ExecuteStream(ctx context.Context, opts ...PromptExecuteOption) iter.Seq2[*ModelStreamValue, error] // Render renders the prompt with the given input and returns a [GenerateActionOptions] to be used with [GenerateWithRequest]. Render(ctx context.Context, input any) (*GenerateActionOptions, error) } @@ -51,6 +55,13 @@ type prompt struct { registry api.Registry } +// DataPrompt is a prompt with strongly-typed input and output. +// It wraps an underlying [Prompt] and provides type-safe Execute and Render methods. 
+// The Out type parameter can be string for text outputs or any struct type for JSON outputs. +type DataPrompt[In, Out any] struct { + prompt +} + // DefinePrompt creates a new [Prompt] and registers it. func DefinePrompt(r api.Registry, name string, opts ...PromptOption) Prompt { if name == "" { @@ -89,10 +100,7 @@ func DefinePrompt(r api.Registry, name string, opts ...PromptOption) Prompt { } metadata["type"] = api.ActionTypeExecutablePrompt - baseName := name - if idx := strings.LastIndex(name, "."); idx != -1 { - baseName = name[:idx] - } + baseName, variant, _ := strings.Cut(name, ".") promptMetadata := map[string]any{ "name": baseName, @@ -105,6 +113,9 @@ func DefinePrompt(r api.Registry, name string, opts ...PromptOption) Prompt { "tools": tools, "maxTurns": p.MaxTurns, } + if variant != "" { + promptMetadata["variant"] = variant + } if m, ok := metadata["prompt"].(map[string]any); ok { maps.Copy(m, promptMetadata) } else { @@ -133,7 +144,7 @@ func LookupPrompt(r api.Registry, name string) Prompt { // passes the rendered template to the AI model specified by the prompt. func (p *prompt) Execute(ctx context.Context, opts ...PromptExecuteOption) (*ModelResponse, error) { if p == nil { - return nil, errors.New("Prompt.Execute: execute called on a nil Prompt; check that all prompts are defined") + return nil, core.NewError(core.INVALID_ARGUMENT, "Prompt.Execute: prompt is nil") } execOpts := &promptExecutionOptions{} @@ -239,10 +250,50 @@ func (p *prompt) Execute(ctx context.Context, opts ...PromptExecuteOption) (*Mod return GenerateWithRequest(ctx, r, actionOpts, execOpts.Middleware, execOpts.Stream) } +// ExecuteStream executes the prompt with streaming and returns an iterator. +// +// If the yield function is passed a non-nil error, execution has failed with that +// error; the yield function will not be called again. +// +// If the yield function's [ModelStreamValue] argument has Done == true, the value's +// Response field contains the final response; the yield function will not be called again. +// +// Otherwise the Chunk field of the passed [ModelStreamValue] holds a streamed chunk. +func (p *prompt) ExecuteStream(ctx context.Context, opts ...PromptExecuteOption) iter.Seq2[*ModelStreamValue, error] { + return func(yield func(*ModelStreamValue, error) bool) { + if p == nil { + yield(nil, core.NewError(core.INVALID_ARGUMENT, "Prompt.ExecuteStream: prompt is nil")) + return + } + + cb := func(ctx context.Context, chunk *ModelResponseChunk) error { + if ctx.Err() != nil { + return ctx.Err() + } + if !yield(&ModelStreamValue{Chunk: chunk}, nil) { + return errPromptStop + } + return nil + } + + allOpts := append(slices.Clone(opts), WithStreaming(cb)) + resp, err := p.Execute(ctx, allOpts...) + if err != nil { + yield(nil, err) + return + } + + yield(&ModelStreamValue{Done: true, Response: resp}, nil) + } +} + +// errPromptStop is a sentinel error used to signal early termination of streaming. +var errPromptStop = errors.New("stop") + // Render renders the prompt template based on user input. 
func (p *prompt) Render(ctx context.Context, input any) (*GenerateActionOptions, error) { if p == nil { - return nil, errors.New("Prompt.Render: called on a nil prompt; check that all prompts are defined") + return nil, core.NewError(core.INVALID_ARGUMENT, "Prompt.Render: prompt is nil") } if len(p.Middleware) > 0 { @@ -807,3 +858,129 @@ func contentType(ct, uri string) (string, []byte, error) { return "", nil, errors.New("uri content type not found") } + +// DefineDataPrompt creates a new data prompt and registers it. +// It automatically infers input schema from the In type parameter and configures +// output schema and JSON format from the Out type parameter (unless Out is string). +func DefineDataPrompt[In, Out any](r api.Registry, name string, opts ...PromptOption) *DataPrompt[In, Out] { + if name == "" { + panic("ai.DefineDataPrompt: name is required") + } + + var in In + allOpts := []PromptOption{WithInputType(in)} + + var out Out + switch any(out).(type) { + case string: + // String output - no schema needed + default: + // Prepend WithOutputType so the user can override the output format. + allOpts = append(allOpts, WithOutputType(out)) + } + + allOpts = append(allOpts, opts...) + p := DefinePrompt(r, name, allOpts...) + + return &DataPrompt[In, Out]{prompt: *p.(*prompt)} +} + +// LookupDataPrompt looks up a prompt by name and wraps it with type information. +// This is useful for wrapping prompts loaded from .prompt files with strong types. +// It returns nil if the prompt was not found. +func LookupDataPrompt[In, Out any](r api.Registry, name string) *DataPrompt[In, Out] { + return AsDataPrompt[In, Out](LookupPrompt(r, name)) +} + +// AsDataPrompt wraps an existing Prompt with type information, returning a DataPrompt. +// This is useful for adding strong typing to a dynamically obtained prompt. +func AsDataPrompt[In, Out any](p Prompt) *DataPrompt[In, Out] { + if p == nil { + return nil + } + + return &DataPrompt[In, Out]{prompt: *p.(*prompt)} +} + +// Execute executes the typed prompt and returns the strongly-typed output along with the full model response. +// For structured output types (non-string Out), the prompt must be configured with the appropriate +// output schema, either through [DefineDataPrompt] or by using [WithOutputType] when defining the prompt. +func (dp *DataPrompt[In, Out]) Execute(ctx context.Context, input In, opts ...PromptExecuteOption) (Out, *ModelResponse, error) { + if dp == nil { + return base.Zero[Out](), nil, core.NewError(core.INVALID_ARGUMENT, "DataPrompt.Execute: prompt is nil") + } + + allOpts := append(slices.Clone(opts), WithInput(input)) + resp, err := dp.prompt.Execute(ctx, allOpts...) + if err != nil { + return base.Zero[Out](), nil, err + } + + output, err := extractTypedOutput[Out](resp) + if err != nil { + return base.Zero[Out](), resp, err + } + + return output, resp, nil +} + +// ExecuteStream executes the typed prompt with streaming and returns an iterator. +// +// If the yield function is passed a non-nil error, execution has failed with that +// error; the yield function will not be called again. +// +// If the yield function's StreamValue argument has Done == true, the value's +// Output and Response fields contain the final typed output and response; the yield function +// will not be called again. +// +// Otherwise the Chunk field of the passed StreamValue holds a streamed chunk. 
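+//
+// A minimal consumption sketch (dp, ctx, and input here are illustrative):
+//
+//	for val, err := range dp.ExecuteStream(ctx, input) {
+//		if err != nil {
+//			return err
+//		}
+//		if val.Done {
+//			use(val.Output) // final typed output
+//		} else {
+//			use(val.Chunk) // partial typed chunk
+//		}
+//	}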
+// +// For structured output types (non-string Out), the prompt must be configured with the appropriate +// output schema, either through [DefineDataPrompt] or by using [WithOutputType] when defining the prompt. +func (dp *DataPrompt[In, Out]) ExecuteStream(ctx context.Context, input In, opts ...PromptExecuteOption) iter.Seq2[*StreamValue[Out, Out], error] { + return func(yield func(*StreamValue[Out, Out], error) bool) { + if dp == nil { + yield(nil, core.NewError(core.INVALID_ARGUMENT, "DataPrompt.ExecuteStream: prompt is nil")) + return + } + + cb := func(ctx context.Context, chunk *ModelResponseChunk) error { + if ctx.Err() != nil { + return ctx.Err() + } + streamValue, err := extractTypedOutput[Out](chunk) + if err != nil { + yield(nil, err) + return err + } + if !yield(&StreamValue[Out, Out]{Chunk: streamValue}, nil) { + return errGenerateStop + } + return nil + } + + allOpts := append(slices.Clone(opts), WithInput(input), WithStreaming(cb)) + resp, err := dp.prompt.Execute(ctx, allOpts...) + if err != nil { + yield(nil, err) + return + } + + output, err := extractTypedOutput[Out](resp) + if err != nil { + yield(nil, err) + return + } + + yield(&StreamValue[Out, Out]{Done: true, Output: output, Response: resp}, nil) + } +} + +// Render renders the typed prompt template with the given input. +func (dp *DataPrompt[In, Out]) Render(ctx context.Context, input In) (*GenerateActionOptions, error) { + if dp == nil { + return nil, errors.New("DataPrompt.Render: prompt is nil") + } + + return dp.prompt.Render(ctx, input) +} diff --git a/go/ai/prompt_test.go b/go/ai/prompt_test.go index f711f6321b..7fb04ed0e4 100644 --- a/go/ai/prompt_test.go +++ b/go/ai/prompt_test.go @@ -16,6 +16,7 @@ package ai import ( "context" + "errors" "fmt" "os" "path/filepath" @@ -1096,6 +1097,50 @@ Hello, {{name}}! 
} } +func TestDefinePrompt_WithVariant(t *testing.T) { + reg := registry.New() + + DefinePrompt(reg, "example.code", WithPrompt("Hello, {{name}}!")) + + prompt := LookupPrompt(reg, "example.code") + if prompt == nil { + t.Fatalf("Prompt was not registered") + } + + promptMetadata, ok := prompt.(api.Action).Desc().Metadata["prompt"].(map[string]any) + if !ok { + t.Fatalf("Expected Metadata['prompt'] to be a map") + } + if promptMetadata["name"] != "example" { + t.Errorf("Expected metadata name 'example', got '%s'", promptMetadata["name"]) + } + if promptMetadata["variant"] != "code" { + t.Errorf("Expected variant 'code', got '%v'", promptMetadata["variant"]) + } +} + +func TestDefinePrompt_WithoutVariant(t *testing.T) { + reg := registry.New() + + DefinePrompt(reg, "simple", WithPrompt("Hello, world!")) + + prompt := LookupPrompt(reg, "simple") + if prompt == nil { + t.Fatalf("Prompt was not registered") + } + + promptMetadata, ok := prompt.(api.Action).Desc().Metadata["prompt"].(map[string]any) + if !ok { + t.Fatalf("Expected Metadata['prompt'] to be a map") + } + if promptMetadata["name"] != "simple" { + t.Errorf("Expected metadata name 'simple', got '%s'", promptMetadata["name"]) + } + if _, exists := promptMetadata["variant"]; exists { + t.Errorf("Expected no variant for prompt without dot, got '%v'", promptMetadata["variant"]) + } +} + func TestLoadPromptFolder(t *testing.T) { // Create a temporary directory for testing tempDir := t.TempDir() @@ -1518,3 +1563,490 @@ func TestWithOutputSchemaName_DefinePrompt_Missing(t *testing.T) { t.Errorf("Expected error 'schema \"MissingSchema\" not found', got: %v", err) } } + +func TestDataPromptExecute(t *testing.T) { + r := registry.New() + ConfigureFormats(r) + DefineGenerateAction(context.Background(), r) + + type GreetingInput struct { + Name string `json:"name"` + } + + type GreetingOutput struct { + Message string `json:"message"` + Count int `json:"count"` + } + + t.Run("typed input and output", func(t *testing.T) { + var capturedInput any + + testModel := DefineModel(r, "test/dataPromptModel", &ModelOptions{ + Supports: &ModelSupports{ + Multiturn: true, + Constrained: ConstrainedSupportAll, + }, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + capturedInput = req.Messages[0].Text() + return &ModelResponse{ + Request: req, + Message: &Message{ + Role: RoleModel, + Content: []*Part{NewJSONPart(`{"message":"Hello, Alice!","count":1}`)}, + }, + }, nil + }) + + dp := DefineDataPrompt[GreetingInput, GreetingOutput](r, "greetingPrompt", + WithModel(testModel), + WithPrompt("Greet {{name}}"), + ) + + output, resp, err := dp.Execute(context.Background(), GreetingInput{Name: "Alice"}) + if err != nil { + t.Fatalf("Execute failed: %v", err) + } + + if capturedInput != "Greet Alice" { + t.Errorf("expected input %q, got %q", "Greet Alice", capturedInput) + } + + if output.Message != "Hello, Alice!" 
{ + t.Errorf("expected message %q, got %q", "Hello, Alice!", output.Message) + } + if output.Count != 1 { + t.Errorf("expected count 1, got %d", output.Count) + } + if resp == nil { + t.Error("expected response to be returned") + } + }) + + t.Run("string output type", func(t *testing.T) { + testModel := DefineModel(r, "test/stringDataPromptModel", &ModelOptions{ + Supports: &ModelSupports{Multiturn: true}, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + return &ModelResponse{ + Request: req, + Message: NewModelTextMessage("Hello, World!"), + }, nil + }) + + dp := DefineDataPrompt[GreetingInput, string](r, "stringOutputPrompt", + WithModel(testModel), + WithPrompt("Say hello to {{name}}"), + ) + + output, resp, err := dp.Execute(context.Background(), GreetingInput{Name: "World"}) + if err != nil { + t.Fatalf("Execute failed: %v", err) + } + + if output != "Hello, World!" { + t.Errorf("expected output %q, got %q", "Hello, World!", output) + } + if resp == nil { + t.Error("expected response to be returned") + } + }) + + t.Run("nil prompt returns error", func(t *testing.T) { + var dp *DataPrompt[GreetingInput, GreetingOutput] + + _, _, err := dp.Execute(context.Background(), GreetingInput{Name: "test"}) + if err == nil { + t.Error("expected error for nil prompt") + } + }) + + t.Run("additional options passed through", func(t *testing.T) { + var capturedConfig any + + testModel := DefineModel(r, "test/optionsDataPromptModel", &ModelOptions{ + Supports: &ModelSupports{ + Multiturn: true, + Constrained: ConstrainedSupportAll, + }, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + capturedConfig = req.Config + return &ModelResponse{ + Request: req, + Message: &Message{ + Role: RoleModel, + Content: []*Part{NewJSONPart(`{"message":"test","count":0}`)}, + }, + }, nil + }) + + dp := DefineDataPrompt[GreetingInput, GreetingOutput](r, "optionsPrompt", + WithModel(testModel), + WithPrompt("Test {{name}}"), + ) + + _, _, err := dp.Execute(context.Background(), GreetingInput{Name: "test"}, + WithConfig(&GenerationCommonConfig{Temperature: 0.5}), + ) + if err != nil { + t.Fatalf("Execute failed: %v", err) + } + + config, ok := capturedConfig.(*GenerationCommonConfig) + if !ok { + t.Fatalf("expected *GenerationCommonConfig, got %T", capturedConfig) + } + if config.Temperature != 0.5 { + t.Errorf("expected temperature 0.5, got %v", config.Temperature) + } + }) + + t.Run("returns error for invalid output parsing", func(t *testing.T) { + testModel := DefineModel(r, "test/parseFailDataPromptModel", &ModelOptions{ + Supports: &ModelSupports{ + Multiturn: true, + Constrained: ConstrainedSupportAll, + }, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + return &ModelResponse{ + Request: req, + Message: NewModelTextMessage("not valid json"), + }, nil + }) + + dp := DefineDataPrompt[GreetingInput, GreetingOutput](r, "parseFailPrompt", + WithModel(testModel), + WithPrompt("Test {{name}}"), + ) + + _, _, err := dp.Execute(context.Background(), GreetingInput{Name: "test"}) + if err == nil { + t.Error("expected error for invalid JSON output") + } + }) +} + +func TestDataPromptExecuteStream(t *testing.T) { + r := registry.New() + ConfigureFormats(r) + DefineGenerateAction(context.Background(), r) + + type StreamInput struct { + Topic string `json:"topic"` + } + + type StreamOutput struct { + Text string `json:"text"` + Index int `json:"index"` + } + + t.Run("typed 
streaming with struct output", func(t *testing.T) { + testModel := DefineModel(r, "test/streamDataPromptModel", &ModelOptions{ + Supports: &ModelSupports{ + Multiturn: true, + Constrained: ConstrainedSupportAll, + }, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + if cb != nil { + cb(ctx, &ModelResponseChunk{ + Content: []*Part{NewJSONPart(`{"text":"chunk1","index":1}`)}, + }) + cb(ctx, &ModelResponseChunk{ + Content: []*Part{NewJSONPart(`{"text":"final","index":99}`)}, + }) + } + return &ModelResponse{ + Request: req, + Message: &Message{ + Role: RoleModel, + Content: []*Part{NewJSONPart(`{"text":"final","index":99}`)}, + }, + }, nil + }) + + dp := DefineDataPrompt[StreamInput, StreamOutput](r, "streamPrompt", + WithModel(testModel), + WithPrompt("Stream about {{topic}}"), + ) + + var chunks []StreamOutput + var finalOutput StreamOutput + var finalResponse *ModelResponse + + for val, err := range dp.ExecuteStream(context.Background(), StreamInput{Topic: "testing"}) { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if val.Done { + finalOutput = val.Output + finalResponse = val.Response + } else { + chunks = append(chunks, val.Chunk) + } + } + + if len(chunks) < 1 { + t.Errorf("expected at least 1 chunk, got %d", len(chunks)) + } + + if finalOutput.Text != "final" || finalOutput.Index != 99 { + t.Errorf("expected final {final, 99}, got %+v", finalOutput) + } + if finalResponse == nil { + t.Error("expected final response") + } + }) + + t.Run("string output streaming", func(t *testing.T) { + testModel := DefineModel(r, "test/stringStreamDataPromptModel", &ModelOptions{ + Supports: &ModelSupports{Multiturn: true}, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + if cb != nil { + cb(ctx, &ModelResponseChunk{ + Content: []*Part{NewTextPart("First ")}, + }) + cb(ctx, &ModelResponseChunk{ + Content: []*Part{NewTextPart("Second")}, + }) + } + return &ModelResponse{ + Request: req, + Message: NewModelTextMessage("First Second"), + }, nil + }) + + dp := DefineDataPrompt[StreamInput, string](r, "stringStreamPrompt", + WithModel(testModel), + WithPrompt("Generate text about {{topic}}"), + ) + + var chunks []string + var finalOutput string + + for val, err := range dp.ExecuteStream(context.Background(), StreamInput{Topic: "strings"}) { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if val.Done { + finalOutput = val.Output + } else { + chunks = append(chunks, val.Chunk) + } + } + + if len(chunks) != 2 { + t.Errorf("expected 2 chunks, got %d", len(chunks)) + } + if chunks[0] != "First " { + t.Errorf("chunk 0: expected %q, got %q", "First ", chunks[0]) + } + if chunks[1] != "Second" { + t.Errorf("chunk 1: expected %q, got %q", "Second", chunks[1]) + } + + if finalOutput != "First Second" { + t.Errorf("expected final %q, got %q", "First Second", finalOutput) + } + }) + + t.Run("nil prompt returns error", func(t *testing.T) { + var dp *DataPrompt[StreamInput, StreamOutput] + + var receivedErr error + for _, err := range dp.ExecuteStream(context.Background(), StreamInput{Topic: "test"}) { + if err != nil { + receivedErr = err + break + } + } + + if receivedErr == nil { + t.Error("expected error for nil prompt") + } + }) + + t.Run("handles options passed at execute time", func(t *testing.T) { + var capturedConfig any + + testModel := DefineModel(r, "test/optionsStreamModel", &ModelOptions{ + Supports: &ModelSupports{ + Multiturn: true, + Constrained: ConstrainedSupportAll, 
+ }, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + capturedConfig = req.Config + if cb != nil { + cb(ctx, &ModelResponseChunk{ + Content: []*Part{NewJSONPart(`{"text":"chunk","index":1}`)}, + }) + } + return &ModelResponse{ + Request: req, + Message: &Message{ + Role: RoleModel, + Content: []*Part{NewJSONPart(`{"text":"done","index":2}`)}, + }, + }, nil + }) + + dp := DefineDataPrompt[StreamInput, StreamOutput](r, "optionsStreamPrompt", + WithModel(testModel), + WithPrompt("Test {{topic}}"), + ) + + for range dp.ExecuteStream(context.Background(), StreamInput{Topic: "options"}, + WithConfig(&GenerationCommonConfig{Temperature: 0.7}), + ) { + } + + config, ok := capturedConfig.(*GenerationCommonConfig) + if !ok { + t.Fatalf("expected *GenerationCommonConfig, got %T", capturedConfig) + } + if config.Temperature != 0.7 { + t.Errorf("expected temperature 0.7, got %v", config.Temperature) + } + }) + + t.Run("propagates errors", func(t *testing.T) { + expectedErr := errors.New("stream failed") + + testModel := DefineModel(r, "test/errorStreamDataPromptModel", &ModelOptions{ + Supports: &ModelSupports{Multiturn: true}, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + return nil, expectedErr + }) + + dp := DefineDataPrompt[StreamInput, StreamOutput](r, "errorStreamPrompt", + WithModel(testModel), + WithPrompt("Test {{topic}}"), + ) + + var receivedErr error + for _, err := range dp.ExecuteStream(context.Background(), StreamInput{Topic: "error"}) { + if err != nil { + receivedErr = err + break + } + } + + if receivedErr == nil { + t.Error("expected error to be propagated") + } + if !errors.Is(receivedErr, expectedErr) { + t.Errorf("expected error %v, got %v", expectedErr, receivedErr) + } + }) +} + +func TestPromptExecuteStream(t *testing.T) { + r := registry.New() + ConfigureFormats(r) + DefineGenerateAction(context.Background(), r) + + t.Run("yields chunks then final response", func(t *testing.T) { + chunkTexts := []string{"A", "B", "C"} + + testModel := DefineModel(r, "test/promptStreamModel", &ModelOptions{ + Supports: &ModelSupports{Multiturn: true}, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + if cb != nil { + for _, text := range chunkTexts { + cb(ctx, &ModelResponseChunk{ + Content: []*Part{NewTextPart(text)}, + }) + } + } + return &ModelResponse{ + Request: req, + Message: NewModelTextMessage("ABC"), + }, nil + }) + + p := DefinePrompt(r, "streamTestPrompt", + WithModel(testModel), + WithPrompt("Test"), + ) + + var chunks []*ModelResponseChunk + var finalResponse *ModelResponse + + for val, err := range p.ExecuteStream(context.Background()) { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if val.Done { + finalResponse = val.Response + } else { + chunks = append(chunks, val.Chunk) + } + } + + if len(chunks) != 3 { + t.Errorf("expected 3 chunks, got %d", len(chunks)) + } + for i, chunk := range chunks { + if chunk.Text() != chunkTexts[i] { + t.Errorf("chunk %d: expected %q, got %q", i, chunkTexts[i], chunk.Text()) + } + } + + if finalResponse == nil { + t.Fatal("expected final response") + } + if finalResponse.Text() != "ABC" { + t.Errorf("expected final text %q, got %q", "ABC", finalResponse.Text()) + } + }) + + t.Run("nil prompt returns error", func(t *testing.T) { + var p *prompt + + var receivedErr error + for _, err := range p.ExecuteStream(context.Background()) { + if err != nil { + receivedErr = err + 
break + } + } + + if receivedErr == nil { + t.Error("expected error for nil prompt") + } + }) + + t.Run("handles execution options", func(t *testing.T) { + var capturedConfig any + + testModel := DefineModel(r, "test/optionsPromptExecModel", &ModelOptions{ + Supports: &ModelSupports{Multiturn: true}, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + capturedConfig = req.Config + if cb != nil { + cb(ctx, &ModelResponseChunk{ + Content: []*Part{NewTextPart("chunk")}, + }) + } + return &ModelResponse{ + Request: req, + Message: NewModelTextMessage("done"), + }, nil + }) + + p := DefinePrompt(r, "execOptionsTestPrompt", + WithModel(testModel), + WithPrompt("Test"), + ) + + for range p.ExecuteStream(context.Background(), + WithConfig(&GenerationCommonConfig{Temperature: 0.9}), + ) { + } + + config, ok := capturedConfig.(*GenerationCommonConfig) + if !ok { + t.Fatalf("expected *GenerationCommonConfig, got %T", capturedConfig) + } + if config.Temperature != 0.9 { + t.Errorf("expected temperature 0.9, got %v", config.Temperature) + } + }) +} diff --git a/go/core/flow.go b/go/core/flow.go index 0cd12120f2..ea514365c2 100644 --- a/go/core/flow.go +++ b/go/core/flow.go @@ -71,6 +71,9 @@ func DefineStreamingFlow[In, Out, Stream any](r api.Registry, name string, fn St flowName: name, } ctx = flowContextKey.NewContext(ctx, fc) + if cb == nil { + cb = func(context.Context, Stream) error { return nil } + } return fn(ctx, input, cb) })) } diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 75ef8c9a8a..286952aae4 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -21,6 +21,7 @@ import ( "context" "errors" "fmt" + "iter" "log/slog" "os" "os/signal" @@ -268,7 +269,7 @@ func DefineFlow[In, Out any](g *Genkit, name string, fn core.Func[In, Out]) *cor // Example: // // counterFlow := genkit.DefineStreamingFlow(g, "counter", -// func(ctx context.Context, limit int, stream func(context.Context, int) error) (string, error) { +// func(ctx context.Context, limit int, stream core.StreamCallback[int]) (string, error) { // if stream == nil { // Non-streaming case // return fmt.Sprintf("Counted up to %d", limit), nil // } @@ -717,6 +718,42 @@ func DefineSchemaFor[T any](g *Genkit) { core.DefineSchemaFor[T](g.reg) } +// DefineDataPrompt creates a new [ai.DataPrompt] with strongly-typed input and output. +// It automatically infers input schema from the In type parameter and configures +// output schema and JSON format from the Out type parameter (unless Out is string). +// +// Example: +// +// type GeoInput struct { +// Country string `json:"country"` +// } +// +// type GeoOutput struct { +// Capital string `json:"capital"` +// } +// +// capitalPrompt := genkit.DefineDataPrompt[GeoInput, GeoOutput](g, "findCapital", +// ai.WithModelName("googleai/gemini-2.5-flash"), +// ai.WithSystem("You are a helpful geography assistant."), +// ai.WithPrompt("What is the capital of {{country}}?"), +// ) +// +// output, resp, err := capitalPrompt.Execute(ctx, GeoInput{Country: "France"}) +// if err != nil { +// log.Fatalf("Execute failed: %v", err) +// } +// fmt.Printf("Capital: %s\n", output.Capital) +func DefineDataPrompt[In, Out any](g *Genkit, name string, opts ...ai.PromptOption) *ai.DataPrompt[In, Out] { + return ai.DefineDataPrompt[In, Out](g.reg, name, opts...) +} + +// LookupDataPrompt looks up a prompt by name and wraps it with type information. +// This is useful for wrapping prompts loaded from .prompt files with strong types. 
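+//
+// Example (assumes a structured-joke .prompt file plus JokeRequest and Joke
+// types like those in the basic-prompts sample):
+//
+//	jokePrompt := genkit.LookupDataPrompt[JokeRequest, *Joke](g, "structured-joke")
+//	output, resp, err := jokePrompt.Execute(ctx, JokeRequest{Topic: "airplane food"})
+//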
+// It returns nil if the prompt was not found. +func LookupDataPrompt[In, Out any](g *Genkit, name string) *ai.DataPrompt[In, Out] { + return ai.LookupDataPrompt[In, Out](g.reg, name) +} + // GenerateWithRequest performs a model generation request using explicitly provided // [ai.GenerateActionOptions]. This function is typically used in conjunction with // prompts defined via [DefinePrompt], where [ai.prompt.Render] produces the @@ -766,6 +803,35 @@ func Generate(ctx context.Context, g *Genkit, opts ...ai.GenerateOption) (*ai.Mo return ai.Generate(ctx, g.reg, opts...) } +// GenerateStream generates a model response and streams the output. +// It returns an iterator that yields streaming results. +// +// If the yield function is passed a non-nil error, generation has failed with that +// error; the yield function will not be called again. +// +// If the yield function's [ai.ModelStreamValue] argument has Done == true, the value's +// Response field contains the final response; the yield function will not be called again. +// +// Otherwise the Chunk field of the passed [ai.ModelStreamValue] holds a streamed chunk. +// +// Example: +// +// for result, err := range genkit.GenerateStream(ctx, g, +// ai.WithPrompt("Tell me a story about a brave knight."), +// ) { +// if err != nil { +// log.Fatalf("Stream error: %v", err) +// } +// if result.Done { +// fmt.Println("\nFinal response:", result.Response.Text()) +// } else { +// fmt.Print(result.Chunk.Text()) +// } +// } +func GenerateStream(ctx context.Context, g *Genkit, opts ...ai.GenerateOption) iter.Seq2[*ai.ModelStreamValue, error] { + return ai.GenerateStream(ctx, g.reg, opts...) +} + // GenerateOperation performs a model generation request using a flexible set of options // provided via [ai.GenerateOption] arguments. It's a convenient way to make // generation calls without pre-defining a prompt object. @@ -854,6 +920,41 @@ func GenerateData[Out any](ctx context.Context, g *Genkit, opts ...ai.GenerateOp return ai.GenerateData[Out](ctx, g.reg, opts...) } +// GenerateDataStream generates a model response with streaming and returns strongly-typed output. +// It returns an iterator that yields streaming results. +// +// If the yield function is passed a non-nil error, generation has failed with that +// error; the yield function will not be called again. +// +// If the yield function's [ai.StreamValue] argument has Done == true, the value's +// Output and Response fields contain the final typed output and response; the yield function +// will not be called again. +// +// Otherwise the Chunk field of the passed [ai.StreamValue] holds a streamed chunk. +// +// Example: +// +// type Story struct { +// Title string `json:"title"` +// Content string `json:"content"` +// } +// +// for result, err := range genkit.GenerateDataStream[Story, *ai.ModelResponseChunk](ctx, g, +// ai.WithPrompt("Write a short story about a brave knight."), +// ) { +// if err != nil { +// log.Fatalf("Stream error: %v", err) +// } +// if result.Done { +// fmt.Printf("Story: %+v\n", result.Output) +// } else { +// fmt.Print(result.Chunk.Text()) +// } +// } +func GenerateDataStream[Out any](ctx context.Context, g *Genkit, opts ...ai.GenerateOption) iter.Seq2[*ai.StreamValue[Out, Out], error] { + return ai.GenerateDataStream[Out](ctx, g.reg, opts...) +} + // Retrieve performs a document retrieval request using a flexible set of options // provided via [ai.RetrieverOption] arguments. 
It's a convenient way to retrieve // relevant documents from registered retrievers without directly calling the diff --git a/go/plugins/googlegenai/googlegenai.go b/go/plugins/googlegenai/googlegenai.go index d056e6fb1c..8ddbdfbe4a 100644 --- a/go/plugins/googlegenai/googlegenai.go +++ b/go/plugins/googlegenai/googlegenai.go @@ -283,14 +283,19 @@ func (v *VertexAI) IsDefinedEmbedder(g *genkit.Genkit, name string) bool { return genkit.LookupEmbedder(g, api.NewName(vertexAIProvider, name)) != nil } -// GoogleAIModelRef creates a new ModelRef for a Google AI model with the given name and configuration. -func GoogleAIModelRef(name string, config *genai.GenerateContentConfig) ai.ModelRef { +// ModelRef creates a new ModelRef for a Google Gen AI model with the given name and configuration. +func ModelRef(name string, config *genai.GenerateContentConfig) ai.ModelRef { return ai.NewModelRef(googleAIProvider+"/"+name, config) } -// VertexAIModelRef creates a new ModelRef for a Vertex AI model with the given name and configuration. -func VertexAIModelRef(name string, config *genai.GenerateContentConfig) ai.ModelRef { - return ai.NewModelRef(vertexAIProvider+"/"+name, config) +// GoogleAIModelRef creates a new ModelRef for a Google AI model with the given ID and configuration. +func GoogleAIModelRef(id string, config *genai.GenerateContentConfig) ai.ModelRef { + return ai.NewModelRef(googleAIProvider+"/"+id, config) +} + +// VertexAIModelRef creates a new ModelRef for a Vertex AI model with the given ID and configuration. +func VertexAIModelRef(id string, config *genai.GenerateContentConfig) ai.ModelRef { + return ai.NewModelRef(vertexAIProvider+"/"+id, config) } // GoogleAIModel returns the [ai.Model] with the given name. diff --git a/go/samples/basic-gemini-with-context/main.go b/go/samples/basic-gemini-with-context/main.go deleted file mode 100644 index f971ecc9bc..0000000000 --- a/go/samples/basic-gemini-with-context/main.go +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright 2025 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "context" - "fmt" - - "github.com/firebase/genkit/go/ai" - "github.com/firebase/genkit/go/genkit" - "github.com/firebase/genkit/go/plugins/googlegenai" - "google.golang.org/genai" -) - -func main() { - ctx := context.Background() - - // Initialize Genkit with the Google AI plugin. When you pass nil for the - // Config parameter, the Google AI plugin will get the API key from the - // GEMINI_API_KEY or GOOGLE_API_KEY environment variable, which is the recommended - // practice. 
- g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) - - // Define a simple flow that generates jokes about a given topic with a context of bananas - genkit.DefineFlow(g, "contextFlow", func(ctx context.Context, input string) (string, error) { - resp, err := genkit.Generate(ctx, g, - ai.WithModelName("googleai/gemini-2.5-flash"), - ai.WithConfig(&genai.GenerateContentConfig{ - Temperature: genai.Ptr[float32](1.0), - }), - ai.WithPrompt(fmt.Sprintf(`Tell silly short jokes about %s`, input)), - ai.WithDocs(ai.DocumentFromText("Bananas are plentiful in the tropics.", nil))) - if err != nil { - return "", err - } - - text := resp.Text() - return text, nil - }) - - <-ctx.Done() -} diff --git a/go/samples/basic-gemini/main.go b/go/samples/basic-gemini/main.go deleted file mode 100644 index e61ec9df42..0000000000 --- a/go/samples/basic-gemini/main.go +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright 2025 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "context" - - "github.com/firebase/genkit/go/ai" - "github.com/firebase/genkit/go/genkit" - "github.com/firebase/genkit/go/plugins/googlegenai" - "google.golang.org/genai" -) - -func main() { - ctx := context.Background() - - // Initialize Genkit with the Google AI plugin. When you pass nil for the - // Config parameter, the Google AI plugin will get the API key from the - // GEMINI_API_KEY or GOOGLE_API_KEY environment variable, which is the recommended - // practice. - g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) - - // Define a simple flow that generates jokes about a given topic - genkit.DefineStreamingFlow(g, "jokesFlow", func(ctx context.Context, input string, cb ai.ModelStreamCallback) (string, error) { - type Joke struct { - Joke string `json:"joke"` - Category string `json:"jokeCategory" description:"What is the joke about"` - } - - genkit.DefineSchemaFor[Joke](g) - - resp, err := genkit.Generate(ctx, g, - ai.WithModelName("googleai/gemini-2.5-flash"), - ai.WithConfig(&genai.GenerateContentConfig{ - Temperature: genai.Ptr[float32](1.0), - ThinkingConfig: &genai.ThinkingConfig{ - ThinkingBudget: genai.Ptr[int32](0), - }, - }), - ai.WithStreaming(cb), - ai.WithOutputSchemaName("Joke"), - ai.WithPrompt(`Tell short jokes about %s`, input)) - if err != nil { - return "", err - } - - return resp.Text(), nil - }) - - <-ctx.Done() -} diff --git a/go/samples/basic-prompts/main.go b/go/samples/basic-prompts/main.go new file mode 100644 index 0000000000..ccbc308a13 --- /dev/null +++ b/go/samples/basic-prompts/main.go @@ -0,0 +1,287 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "fmt" + "log" + "net/http" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/googlegenai" + "github.com/firebase/genkit/go/plugins/server" + "google.golang.org/genai" +) + +type JokeRequest struct { + Topic string `json:"topic" jsonschema:"default=airplane food"` +} + +// Note how the fields are annotated with jsonschema tags to describe the output schema. +// This is vital for the model to understand the intent of the fields. +type Joke struct { + Joke string `json:"joke" jsonschema:"description=The joke text"` + Category string `json:"category" jsonschema:"description=The joke category"` +} + +type RecipeRequest struct { + Dish string `json:"dish" jsonschema:"default=pasta"` + Cuisine string `json:"cuisine" jsonschema:"default=Italian"` + ServingSize int `json:"servingSize" jsonschema:"default=4"` + MaxPrepMinutes int `json:"maxPrepMinutes" jsonschema:"default=30"` + DietaryRestrictions []string `json:"dietaryRestrictions,omitempty"` +} + +type Ingredient struct { + Name string `json:"name" jsonschema:"description=The ingredient name"` + Amount string `json:"amount" jsonschema:"description=The ingredient amount (e.g. 1 cup, 2 tablespoons, etc.)"` + Optional bool `json:"optional,omitempty" jsonschema:"description=Whether the ingredient is optional in the recipe"` +} + +type Recipe struct { + Title string `json:"title" jsonschema:"description=The recipe title (e.g. 'Spicy Chicken Tacos')"` + Description string `json:"description,omitempty" jsonschema:"description=The recipe description (under 100 characters)"` + Ingredients []*Ingredient `json:"ingredients" jsonschema:"description=The recipe ingredients (group by type and order by importance)"` + Instructions []string `json:"instructions" jsonschema:"description=The recipe instructions (step by step)"` + PrepTime string `json:"prepTime" jsonschema:"description=The recipe preparation time (e.g. 10 minutes, 30 minutes, etc.)"` + Difficulty string `json:"difficulty" jsonschema:"enum=easy,enum=medium,enum=hard"` +} + +func main() { + ctx := context.Background() + + // Initialize Genkit with the Google AI plugin. When you pass nil for the + // Config parameter, the Google AI plugin will get the API key from the + // GEMINI_API_KEY or GOOGLE_API_KEY environment variable, which is the recommended + // practice. + g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) + + // Define schemas for the expected input and output types so that the Dotprompt files can reference them. + // Alternatively, you can specify the JSON schema by hand in the Dotprompt metadata. + // Code-defined prompts do not need to have schemas defined in advance but they too can reference them. + genkit.DefineSchemaFor[JokeRequest](g) + genkit.DefineSchemaFor[Joke](g) + genkit.DefineSchemaFor[RecipeRequest](g) + genkit.DefineSchemaFor[Recipe](g) + + // TODO: Include partials and helpers. + + // Define the prompts and flows. 
+ DefineSimpleJokeWithInlinePrompt(g)
+ DefineSimpleJokeWithDotprompt(g)
+ DefineStructuredJokeWithInlinePrompt(g)
+ DefineStructuredJokeWithDotprompt(g)
+ DefineRecipeWithInlinePrompt(g)
+ DefineRecipeWithDotprompt(g)
+
+ // Optionally, start a web server to make the flows callable via HTTP.
+ mux := http.NewServeMux()
+ for _, a := range genkit.ListFlows(g) {
+ mux.HandleFunc("POST /"+a.Name(), genkit.Handler(a))
+ }
+ log.Fatal(server.Start(ctx, "127.0.0.1:8080", mux))
+}
+
+// DefineSimpleJokeWithInlinePrompt demonstrates defining a prompt in code using DefinePrompt.
+// The prompt has no output schema defined, so it will always return a string.
+// When executing the prompt, we pass in a map[string]any with the input fields.
+func DefineSimpleJokeWithInlinePrompt(g *genkit.Genkit) {
+ jokePrompt := genkit.DefinePrompt(
+ g, "joke.code",
+ ai.WithModel(googlegenai.ModelRef("gemini-2.5-flash", &genai.GenerateContentConfig{
+ ThinkingConfig: &genai.ThinkingConfig{
+ ThinkingBudget: genai.Ptr[int32](0),
+ },
+ })),
+ // Despite JokeRequest having defaults set in jsonschema tags, we can override them with values set in WithInputType.
+ ai.WithInputType(JokeRequest{Topic: "rush hour traffic"}),
+ ai.WithPrompt("Share a long joke about {{topic}}."),
+ )
+
+ genkit.DefineStreamingFlow(g, "simpleJokePromptFlow",
+ func(ctx context.Context, topic string, sendChunk core.StreamCallback[string]) (string, error) {
+ // One way to pass input is using a map[string]any. This is useful when there is no structured input type.
+ stream := jokePrompt.ExecuteStream(ctx, ai.WithInput(map[string]any{"topic": topic}))
+ for result, err := range stream {
+ if err != nil {
+ return "", fmt.Errorf("could not generate joke: %w", err)
+ }
+ if result.Done {
+ return result.Response.Text(), nil
+ }
+ sendChunk(ctx, result.Chunk.Text())
+ }
+
+ return "", nil
+ },
+ )
+}
+
+// DefineSimpleJokeWithDotprompt demonstrates looking up a prompt defined in a .prompt file using
+// LookupPrompt. The prompt configuration (model, input schema, defaults) is defined in the
+// file. Input is passed as a map since the .prompt file defines its own schema.
+func DefineSimpleJokeWithDotprompt(g *genkit.Genkit) {
+ genkit.DefineStreamingFlow(g, "simpleJokeDotpromptFlow",
+ func(ctx context.Context, topic string, sendChunk core.StreamCallback[string]) (string, error) {
+ jokePrompt := genkit.LookupPrompt(g, "joke")
+ // One way to pass input is using a map[string]any. This is useful when there is no structured input type.
+ stream := jokePrompt.ExecuteStream(ctx, ai.WithInput(map[string]any{"topic": topic}))
+ for result, err := range stream {
+ if err != nil {
+ return "", fmt.Errorf("could not generate joke: %w", err)
+ }
+ if result.Done {
+ return result.Response.Text(), nil
+ }
+ sendChunk(ctx, result.Chunk.Text())
+ }
+
+ return "", nil
+ },
+ )
+}
+
+// DefineStructuredJokeWithInlinePrompt demonstrates DefineDataPrompt for strongly-typed
+// input and output. The type parameters automatically configure input/output schemas
+// and JSON output format. ExecuteStream returns typed chunks and final output.
+func DefineStructuredJokeWithInlinePrompt(g *genkit.Genkit) { + jokePrompt := genkit.DefineDataPrompt[JokeRequest, *Joke]( + g, "structured-joke.code", + ai.WithModel(googlegenai.ModelRef("gemini-2.5-flash", &genai.GenerateContentConfig{ + ThinkingConfig: &genai.ThinkingConfig{ + ThinkingBudget: genai.Ptr[int32](0), + }, + })), + ai.WithPrompt("Share a long joke about {{topic}}."), + ) + + genkit.DefineStreamingFlow(g, "structuredJokePromptFlow", + func(ctx context.Context, input JokeRequest, sendChunk core.StreamCallback[*Joke]) (*Joke, error) { + for result, err := range jokePrompt.ExecuteStream(ctx, input) { + if err != nil { + return nil, fmt.Errorf("could not generate joke: %w", err) + } + if result.Done { + return result.Output, nil + } + sendChunk(ctx, result.Chunk) + } + + return nil, nil + }, + ) +} + +// DefineStructuredJokeWithDotprompt demonstrates LookupDataPrompt to wrap a .prompt file +// with Go type information. The .prompt file references registered schemas by name +// (e.g., "schema: Joke"), which must be defined via DefineSchemaFor before loading. +func DefineStructuredJokeWithDotprompt(g *genkit.Genkit) { + genkit.DefineStreamingFlow(g, "structuredJokeDotpromptFlow", + func(ctx context.Context, input JokeRequest, sendChunk core.StreamCallback[*Joke]) (*Joke, error) { + jokePrompt := genkit.LookupDataPrompt[JokeRequest, *Joke](g, "structured-joke") + stream := jokePrompt.ExecuteStream(ctx, input) + for result, err := range stream { + if err != nil { + return nil, fmt.Errorf("could not generate joke: %w", err) + } + if result.Done { + return result.Output, nil + } + sendChunk(ctx, result.Chunk) + } + return nil, nil + }, + ) +} + +// DefineRecipeWithInlinePrompt demonstrates DefineDataPrompt with complex nested types +// and Handlebars conditionals/loops in the prompt template. The streaming flow applies +// default values before execution and streams partial ingredients as they arrive. +func DefineRecipeWithInlinePrompt(g *genkit.Genkit) { + recipePrompt := genkit.DefineDataPrompt[RecipeRequest, *Recipe]( + g, "recipe.code", + ai.WithModel(googlegenai.ModelRef("gemini-2.5-flash", &genai.GenerateContentConfig{ + ThinkingConfig: &genai.ThinkingConfig{ + ThinkingBudget: genai.Ptr[int32](0), + }, + })), + ai.WithSystem("You are an experienced chef. Come up with easy, creative recipes."), + ai.WithPrompt("Create a {{cuisine}} {{dish}} recipe for {{servingSize}} people that takes under {{maxPrepMinutes}} minutes to prepare. "+ + "{{#if dietaryRestrictions}}Dietary restrictions: {{#each dietaryRestrictions}}{{this}}{{#unless @last}}, {{/unless}}{{/each}}.{{/if}}"), + ) + + genkit.DefineStreamingFlow(g, "recipePromptFlow", + func(ctx context.Context, input RecipeRequest, sendChunk core.StreamCallback[*Ingredient]) (*Recipe, error) { + // This is not necessary for this example but it shows how to easily have more control over what you stream. + filterNew := newIngredientFilter() + for result, err := range recipePrompt.ExecuteStream(ctx, input) { + if err != nil { + return nil, fmt.Errorf("could not generate recipe: %w", err) + } + if result.Done { + return result.Output, nil + } + for _, i := range filterNew(result.Chunk.Ingredients) { + sendChunk(ctx, i) + } + } + return nil, nil + }, + ) +} + +// DefineRecipeWithDotprompt demonstrates LookupDataPrompt with a .prompt file that uses +// multi-message format (system/user roles) and references registered schemas. +// Streams partial ingredients as they arrive via ExecuteStream. 
+func DefineRecipeWithDotprompt(g *genkit.Genkit) { + genkit.DefineStreamingFlow(g, "recipeDotpromptFlow", + func(ctx context.Context, input RecipeRequest, sendChunk core.StreamCallback[*Ingredient]) (*Recipe, error) { + // This is not necessary for this example but it shows how to easily have more control over what you stream. + filterNew := newIngredientFilter() + recipePrompt := genkit.LookupDataPrompt[RecipeRequest, *Recipe](g, "recipe") + stream := recipePrompt.ExecuteStream(ctx, input) + for result, err := range stream { + if err != nil { + return nil, fmt.Errorf("could not generate recipe: %w", err) + } + if result.Done { + return result.Output, nil + } + for _, i := range filterNew(result.Chunk.Ingredients) { + sendChunk(ctx, i) + } + } + return nil, nil + }, + ) +} + +// newIngredientFilter is a helper function to filter out duplicate ingredients. +// This allows us to stream only new ingredients as they are identified, avoiding duplicates. +func newIngredientFilter() func([]*Ingredient) []*Ingredient { + seen := map[string]struct{}{} + return func(ings []*Ingredient) (newIngs []*Ingredient) { + for _, ing := range ings { + if _, ok := seen[ing.Name]; !ok { + seen[ing.Name] = struct{}{} + newIngs = append(newIngs, ing) + } + } + return + } +} diff --git a/go/samples/basic-prompts/prompts/joke.prompt b/go/samples/basic-prompts/prompts/joke.prompt new file mode 100644 index 0000000000..fc1add0957 --- /dev/null +++ b/go/samples/basic-prompts/prompts/joke.prompt @@ -0,0 +1,13 @@ +--- +model: googleai/gemini-2.5-flash +config: + thinkingConfig: + thinkingBudget: 0 +input: + schema: + topic?: string + default: + topic: airplane food +--- +Share a long joke about {{topic}}. + diff --git a/go/samples/basic-prompts/prompts/recipe.prompt b/go/samples/basic-prompts/prompts/recipe.prompt new file mode 100644 index 0000000000..d132ba615e --- /dev/null +++ b/go/samples/basic-prompts/prompts/recipe.prompt @@ -0,0 +1,20 @@ +--- +model: googleai/gemini-2.5-flash +config: + thinkingConfig: + thinkingBudget: 0 +input: + schema: RecipeRequest +output: + format: json + schema: Recipe +--- +{{role "system"}} +You are an experienced chef. Come up with easy, creative recipes. + +{{role "user"}} +Create a {{cuisine}} {{dish}} recipe for {{servingSize}} people that takes under {{maxPrepMinutes}} minutes to prepare. +{{#if dietaryRestrictions}} +Dietary restrictions: {{#each dietaryRestrictions}}{{this}}{{#unless @last}}, {{/unless}}{{/each}}. +{{/if}} + diff --git a/go/samples/basic-prompts/prompts/structured-joke.prompt b/go/samples/basic-prompts/prompts/structured-joke.prompt new file mode 100644 index 0000000000..7184b15483 --- /dev/null +++ b/go/samples/basic-prompts/prompts/structured-joke.prompt @@ -0,0 +1,13 @@ +--- +model: googleai/gemini-2.5-flash +config: + thinkingConfig: + thinkingBudget: 0 +input: + schema: JokeRequest +output: + format: json + schema: Joke +--- +Share a long joke about {{topic}}. + diff --git a/go/samples/basic-structured/main.go b/go/samples/basic-structured/main.go new file mode 100644 index 0000000000..428636de4d --- /dev/null +++ b/go/samples/basic-structured/main.go @@ -0,0 +1,181 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "fmt" + "log" + "net/http" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/googlegenai" + "github.com/firebase/genkit/go/plugins/server" + "google.golang.org/genai" +) + +type JokeRequest struct { + Topic string `json:"topic" jsonschema:"default=airplane food"` +} + +// Note how the fields are annotated with jsonschema tags to describe the output schema. +// This is vital for the model to understand the intent of the fields. +type Joke struct { + Joke string `json:"joke" jsonschema:"description=The joke text"` + Category string `json:"category" jsonschema:"description=The joke category"` +} + +type RecipeRequest struct { + Dish string `json:"dish" jsonschema:"default=pasta"` + Cuisine string `json:"cuisine" jsonschema:"default=Italian"` + ServingSize int `json:"servingSize" jsonschema:"default=4"` + MaxPrepMinutes int `json:"maxPrepMinutes" jsonschema:"default=30"` + DietaryRestrictions []string `json:"dietaryRestrictions,omitempty"` +} + +type Ingredient struct { + Name string `json:"name" jsonschema:"description=The ingredient name"` + Amount string `json:"amount" jsonschema:"description=The ingredient amount (e.g. 1 cup, 2 tablespoons, etc.)"` + Optional bool `json:"optional,omitempty" jsonschema:"description=Whether the ingredient is optional in the recipe"` +} + +type Recipe struct { + Title string `json:"title" jsonschema:"description=The recipe title (e.g. 'Spicy Chicken Tacos')"` + Description string `json:"description,omitempty" jsonschema:"description=The recipe description (under 100 characters)"` + Ingredients []*Ingredient `json:"ingredients" jsonschema:"description=The recipe ingredients (order by type first and then importance)"` + Instructions []string `json:"instructions" jsonschema:"description=The recipe instructions (step by step)"` + PrepTime string `json:"prepTime" jsonschema:"description=The recipe preparation time (e.g. 10 minutes, 30 minutes, etc.)"` + Difficulty string `json:"difficulty" jsonschema:"enum=easy,enum=medium,enum=hard"` +} + +func main() { + ctx := context.Background() + + // Initialize Genkit with the Google AI plugin. When you pass nil for the + // Config parameter, the Google AI plugin will get the API key from the + // GEMINI_API_KEY or GOOGLE_API_KEY environment variable, which is the recommended + // practice. + g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) + + // Define the flows. + DefineSimpleJoke(g) + DefineStructuredJoke(g) + DefineRecipe(g) + + // Optionally, start a web server to make the flows callable via HTTP. + mux := http.NewServeMux() + for _, a := range genkit.ListFlows(g) { + mux.HandleFunc("POST /"+a.Name(), genkit.Handler(a)) + } + log.Fatal(server.Start(ctx, "127.0.0.1:8080", mux)) +} + +// DefineSimpleJoke demonstrates defining a streaming flow that generates a joke about a given topic. 
+func DefineSimpleJoke(g *genkit.Genkit) { + genkit.DefineStreamingFlow(g, "simpleJokesFlow", + func(ctx context.Context, input string, sendChunk core.StreamCallback[string]) (string, error) { + stream := genkit.GenerateStream(ctx, g, + ai.WithModel(googlegenai.ModelRef("gemini-2.5-flash", &genai.GenerateContentConfig{ + ThinkingConfig: &genai.ThinkingConfig{ + ThinkingBudget: genai.Ptr[int32](0), + }, + })), + ai.WithPrompt("Share a long joke about %s.", input), + ) + + for result, err := range stream { + if err != nil { + return "", fmt.Errorf("could not generate joke: %w", err) + } + if result.Done { + return result.Response.Text(), nil + } + sendChunk(ctx, result.Chunk.Text()) + } + + return "", nil + }, + ) +} + +// DefineStructuredJoke demonstrates defining a streaming flow that generates a joke about a given topic. +// The input is a strongly-typed JokeRequest struct and the output is a strongly-typed Joke struct. +func DefineStructuredJoke(g *genkit.Genkit) { + genkit.DefineStreamingFlow(g, "structuredJokesFlow", + func(ctx context.Context, input JokeRequest, sendChunk core.StreamCallback[*Joke]) (*Joke, error) { + stream := genkit.GenerateDataStream[*Joke](ctx, g, + ai.WithModel(googlegenai.ModelRef("gemini-2.5-flash", &genai.GenerateContentConfig{ + ThinkingConfig: &genai.ThinkingConfig{ + ThinkingBudget: genai.Ptr[int32](0), + }, + })), + ai.WithPrompt("Share a long joke about %s.", input.Topic), + ) + + for result, err := range stream { + if err != nil { + return nil, fmt.Errorf("could not generate joke: %w", err) + } + if result.Done { + return result.Output, nil + } + sendChunk(ctx, result.Chunk) + } + + return nil, nil + }) +} + +// DefineRecipe demonstrates defining a streaming flow that generates a recipe based on a given RecipeRequest struct. +// The input is a strongly-typed RecipeRequest struct and the output is a strongly-typed Recipe struct. +func DefineRecipe(g *genkit.Genkit) { + genkit.DefineStreamingFlow(g, "recipeFlow", + func(ctx context.Context, input RecipeRequest, sendChunk core.StreamCallback[[]*Ingredient]) (*Recipe, error) { + stream := genkit.GenerateDataStream[*Recipe](ctx, g, + ai.WithModel(googlegenai.ModelRef("gemini-2.5-flash", &genai.GenerateContentConfig{ + ThinkingConfig: &genai.ThinkingConfig{ + ThinkingBudget: genai.Ptr[int32](0), + }, + })), + ai.WithSystem("You are an experienced chef. Come up with easy, creative recipes."), + // Here we are passing WithPromptFn() since our prompt takes some string manipulation to build. + // Alternatively, we could pass WithPrompt() with the complete prompt string. 
+ ai.WithPromptFn(func(ctx context.Context, _ any) (string, error) { + prompt := fmt.Sprintf( + "Create a %s %s recipe for %d people that takes under %d minutes to prepare.", + input.Cuisine, input.Dish, input.ServingSize, input.MaxPrepMinutes, + ) + if len(input.DietaryRestrictions) > 0 { + prompt += fmt.Sprintf(" Dietary restrictions: %v.", input.DietaryRestrictions) + } + return prompt, nil + }), + ) + + for result, err := range stream { + if err != nil { + return nil, fmt.Errorf("could not generate recipe: %w", err) + } + if result.Done { + return result.Output, nil + } + sendChunk(ctx, result.Chunk.Ingredients) + } + + return nil, nil + }) +} diff --git a/go/samples/basic/main.go b/go/samples/basic/main.go new file mode 100644 index 0000000000..2031340ac5 --- /dev/null +++ b/go/samples/basic/main.go @@ -0,0 +1,86 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "fmt" + "log" + "net/http" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/googlegenai" + "github.com/firebase/genkit/go/plugins/server" + "google.golang.org/genai" +) + +func main() { + ctx := context.Background() + + // Initialize Genkit with the Google AI plugin. When you pass nil for the + // Config parameter, the Google AI plugin will get the API key from the + // GEMINI_API_KEY or GOOGLE_API_KEY environment variable, which is the recommended + // practice. + g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) + + // Define a non-streaming flow that generates jokes about a given topic. + genkit.DefineFlow(g, "jokesFlow", func(ctx context.Context, input string) (string, error) { + if input == "" { + input = "airplane food" + } + + return genkit.GenerateText(ctx, g, + ai.WithModel(googlegenai.ModelRef("gemini-2.5-flash", &genai.GenerateContentConfig{ + ThinkingConfig: &genai.ThinkingConfig{ + ThinkingBudget: genai.Ptr[int32](0), + }, + })), + ai.WithPrompt("Share a joke about %s.", input), + ) + }, + ) + + // Define a streaming flow that generates jokes about a given topic with passthrough streaming. + genkit.DefineStreamingFlow(g, "streamingJokesFlow", + func(ctx context.Context, input string, sendChunk ai.ModelStreamCallback) (string, error) { + if input == "" { + input = "airplane food" + } + + resp, err := genkit.Generate(ctx, g, + ai.WithModel(googlegenai.ModelRef("gemini-2.5-flash", &genai.GenerateContentConfig{ + ThinkingConfig: &genai.ThinkingConfig{ + ThinkingBudget: genai.Ptr[int32](0), + }, + })), + ai.WithPrompt("Share a joke about %s.", input), + ai.WithStreaming(sendChunk), + ) + if err != nil { + return "", fmt.Errorf("could not generate joke: %w", err) + } + + return resp.Text(), nil + }, + ) + + // Optionally, start a web server to make the flow callable via HTTP. 
+ mux := http.NewServeMux() + for _, a := range genkit.ListFlows(g) { + mux.HandleFunc("POST /"+a.Name(), genkit.Handler(a)) + } + log.Fatal(server.Start(ctx, "127.0.0.1:8080", mux)) +} diff --git a/go/samples/prompts-dir/main.go b/go/samples/prompts-dir/main.go deleted file mode 100644 index 59e5e83843..0000000000 --- a/go/samples/prompts-dir/main.go +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2025 Google LLC -// SPDX-License-Identifier: Apache-2.0 - -// [START main] -package main - -import ( - "context" - "errors" - - // Import Genkit and the Google AI plugin - "github.com/firebase/genkit/go/ai" - "github.com/firebase/genkit/go/genkit" - "github.com/firebase/genkit/go/plugins/googlegenai" -) - -func main() { - ctx := context.Background() - - g := genkit.Init(ctx, - genkit.WithPlugins(&googlegenai.GoogleAI{}), - genkit.WithPromptDir("prompts"), - ) - - type greetingStyle struct { - Style string `json:"style"` - Location string `json:"location"` - Name string `json:"name"` - } - - type greeting struct { - Greeting string `json:"greeting"` - } - - // Define a simple flow that prompts an LLM to generate greetings using a - // given style. - genkit.DefineFlow(g, "assistantGreetingFlow", func(ctx context.Context, input greetingStyle) (string, error) { - // Look up the prompt by name - prompt := genkit.LookupPrompt(g, "example") - if prompt == nil { - return "", errors.New("assistantGreetingFlow: failed to find prompt") - } - - // Execute the prompt with the provided input - resp, err := prompt.Execute(ctx, ai.WithInput(input)) - if err != nil { - return "", err - } - - var output greeting - if err = resp.Output(&output); err != nil { - return "", err - } - - return output.Greeting, nil - }) - - <-ctx.Done() -} - -// [END main] diff --git a/go/samples/prompts-dir/prompts/example.prompt b/go/samples/prompts-dir/prompts/example.prompt deleted file mode 100644 index 0492cfd326..0000000000 --- a/go/samples/prompts-dir/prompts/example.prompt +++ /dev/null @@ -1,19 +0,0 @@ ---- -model: googleai/gemini-2.5-flash -config: - temperature: 0.9 -input: - schema: - location: string - style?: string - name?: string - default: - name: Rutuja -output: - schema: - greeting: string ---- - -You are the world's most welcoming AI assistant and are currently working at {{location}}. - -Greet a guest{{#if name}} named {{name}}{{/if}}{{#if style}} in the style of {{style}}{{/if}}.
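The new samples above all expose their flows over HTTP by registering genkit.Handler routes on 127.0.0.1:8080. A minimal client sketch for smoke-testing one of them follows; it assumes the server from go/samples/basic/main.go is running and that flow handlers accept a JSON body of the form {"data": <flow input>} and wrap the flow output in their reply (assumptions about the handler envelope, not something this change defines).

package main

import (
	"bytes"
	"fmt"
	"io"
	"log"
	"net/http"
)

func main() {
	// Assumption: Genkit flow handlers accept a request body of the form {"data": <flow input>}.
	body := bytes.NewBufferString(`{"data": "airplane food"}`)
	resp, err := http.Post("http://127.0.0.1:8080/jokesFlow", "application/json", body)
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	// Assumption: the response wraps the flow output in a JSON envelope, e.g. {"result": "..."}.
	out, err := io.ReadAll(resp.Body)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(string(out))
}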