diff --git a/go/ai/embedder.go b/go/ai/embedder.go index b84e3df81b..d93d5b52fa 100644 --- a/go/ai/embedder.go +++ b/go/ai/embedder.go @@ -85,7 +85,7 @@ type EmbedderOptions struct { // embedder is an action with functions specific to converting documents to multidimensional vectors such as Embed(). type embedder struct { - core.ActionDef[*EmbedRequest, *EmbedResponse, struct{}] + core.Action[*EmbedRequest, *EmbedResponse, struct{}, struct{}] } // NewEmbedder creates a new [Embedder]. @@ -127,7 +127,7 @@ func NewEmbedder(name string, opts *EmbedderOptions, fn EmbedderFunc) Embedder { } return &embedder{ - ActionDef: *core.NewAction(name, api.ActionTypeEmbedder, metadata, inputSchema, fn), + Action: *core.NewAction(name, api.ActionTypeEmbedder, metadata, inputSchema, fn), } } @@ -143,12 +143,12 @@ func DefineEmbedder(r api.Registry, name string, opts *EmbedderOptions, fn Embed // It will try to resolve the embedder dynamically if the embedder is not found. // It returns nil if the embedder was not resolved. func LookupEmbedder(r api.Registry, name string) Embedder { - action := core.ResolveActionFor[*EmbedRequest, *EmbedResponse, struct{}](r, api.ActionTypeEmbedder, name) + action := core.ResolveActionFor[*EmbedRequest, *EmbedResponse, struct{}, struct{}](r, api.ActionTypeEmbedder, name) if action == nil { return nil } return &embedder{ - ActionDef: *action, + Action: *action, } } diff --git a/go/ai/evaluator.go b/go/ai/evaluator.go index aa536fac9b..dd79a511ba 100644 --- a/go/ai/evaluator.go +++ b/go/ai/evaluator.go @@ -72,7 +72,7 @@ func (e EvaluatorRef) Config() any { // evaluator is an action with functions specific to evaluating a dataset. type evaluator struct { - core.ActionDef[*EvaluatorRequest, *EvaluatorResponse, struct{}] + core.Action[*EvaluatorRequest, *EvaluatorResponse, struct{}, struct{}] } // Example is a single example that requires evaluation @@ -190,7 +190,7 @@ func NewEvaluator(name string, opts *EvaluatorOptions, fn EvaluatorFunc) Evaluat } return &evaluator{ - ActionDef: *core.NewAction(name, api.ActionTypeEvaluator, metadata, inputSchema, func(ctx context.Context, req *EvaluatorRequest) (output *EvaluatorResponse, err error) { + Action: *core.NewAction(name, api.ActionTypeEvaluator, metadata, inputSchema, func(ctx context.Context, req *EvaluatorRequest) (output *EvaluatorResponse, err error) { var results []EvaluationResult for _, datapoint := range req.Dataset { if datapoint.TestCaseId == "" { @@ -275,7 +275,7 @@ func NewBatchEvaluator(name string, opts *EvaluatorOptions, fn BatchEvaluatorFun } return &evaluator{ - ActionDef: *core.NewAction(name, api.ActionTypeEvaluator, metadata, nil, fn), + Action: *core.NewAction(name, api.ActionTypeEvaluator, metadata, nil, fn), } } @@ -291,12 +291,12 @@ func DefineBatchEvaluator(r api.Registry, name string, opts *EvaluatorOptions, f // LookupEvaluator looks up an [Evaluator] registered by [DefineEvaluator]. // It returns nil if the evaluator was not defined. func LookupEvaluator(r api.Registry, name string) Evaluator { - action := core.ResolveActionFor[*EvaluatorRequest, *EvaluatorResponse, struct{}](r, api.ActionTypeEvaluator, name) + action := core.ResolveActionFor[*EvaluatorRequest, *EvaluatorResponse, struct{}, struct{}](r, api.ActionTypeEvaluator, name) if action == nil { return nil } return &evaluator{ - ActionDef: *action, + Action: *action, } } diff --git a/go/ai/generate.go b/go/ai/generate.go index 003eb0b653..6aeb1e6642 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -71,11 +71,11 @@ type ModelMiddleware = core.Middleware[*ModelRequest, *ModelResponse, *ModelResp // model is an action with functions specific to model generation such as Generate(). type model struct { - core.ActionDef[*ModelRequest, *ModelResponse, *ModelResponseChunk] + core.Action[*ModelRequest, *ModelResponse, *ModelResponseChunk, struct{}] } // generateAction is the type for a utility model generation action that takes in a GenerateActionOptions instead of a ModelRequest. -type generateAction = core.ActionDef[*GenerateActionOptions, *ModelResponse, *ModelResponseChunk] +type generateAction = core.Action[*GenerateActionOptions, *ModelResponse, *ModelResponseChunk, struct{}] // result is a generic struct for parallel operation results with index, value, and error. type result[T any] struct { @@ -191,12 +191,12 @@ func DefineModel(r api.Registry, name string, opts *ModelOptions, fn ModelFunc) // It will try to resolve the model dynamically if the model is not found. // It returns nil if the model was not resolved. func LookupModel(r api.Registry, name string) Model { - action := core.ResolveActionFor[*ModelRequest, *ModelResponse, *ModelResponseChunk](r, api.ActionTypeModel, name) + action := core.ResolveActionFor[*ModelRequest, *ModelResponse, *ModelResponseChunk, struct{}](r, api.ActionTypeModel, name) if action == nil { return nil } return &model{ - ActionDef: *action, + Action: *action, } } @@ -699,7 +699,7 @@ func (m *model) Generate(ctx context.Context, req *ModelRequest, cb ModelStreamC return nil, core.NewError(core.INVALID_ARGUMENT, "Model.Generate: generate called on a nil model; check that all models are defined") } - return m.ActionDef.Run(ctx, req, cb) + return m.Action.Run(ctx, req, cb) } // supportsConstrained returns whether the model supports constrained output. @@ -708,7 +708,7 @@ func (m *model) supportsConstrained(hasTools bool) bool { return false } - metadata := m.ActionDef.Desc().Metadata + metadata := m.Action.Desc().Metadata if metadata == nil { return false } diff --git a/go/ai/prompt.go b/go/ai/prompt.go index 4d0151c4c8..b176805b36 100644 --- a/go/ai/prompt.go +++ b/go/ai/prompt.go @@ -52,7 +52,7 @@ type Prompt interface { // prompt is a prompt template that can be executed to generate a model response. type prompt struct { - core.ActionDef[any, *GenerateActionOptions, struct{}] + core.Action[any, *GenerateActionOptions, struct{}, struct{}] promptOptions registry api.Registry } @@ -124,7 +124,7 @@ func DefinePrompt(r api.Registry, name string, opts ...PromptOption) Prompt { metadata["prompt"] = promptMetadata } - p.ActionDef = *core.DefineAction(r, name, api.ActionTypeExecutablePrompt, metadata, p.InputSchema, p.buildRequest) + p.Action = *core.DefineAction(r, name, api.ActionTypeExecutablePrompt, metadata, p.InputSchema, p.buildRequest) return p } @@ -132,13 +132,13 @@ func DefinePrompt(r api.Registry, name string, opts ...PromptOption) Prompt { // LookupPrompt looks up a [Prompt] registered by [DefinePrompt]. // It returns nil if the prompt was not defined. func LookupPrompt(r api.Registry, name string) Prompt { - action := core.ResolveActionFor[any, *GenerateActionOptions, struct{}](r, api.ActionTypeExecutablePrompt, name) + action := core.ResolveActionFor[any, *GenerateActionOptions, struct{}, struct{}](r, api.ActionTypeExecutablePrompt, name) if action == nil { return nil } return &prompt{ - ActionDef: *action, - registry: r, + Action: *action, + registry: r, } } @@ -312,7 +312,7 @@ func (p *prompt) Render(ctx context.Context, input any) (*GenerateActionOptions, // Desc returns a descriptor of the prompt with resolved schema references. func (p *prompt) Desc() api.ActionDesc { - desc := p.ActionDef.Desc() + desc := p.Action.Desc() promptMeta := desc.Metadata["prompt"].(map[string]any) if inputMeta, ok := promptMeta["input"].(map[string]any); ok { if inputSchema, ok := inputMeta["schema"].(map[string]any); ok { diff --git a/go/ai/resource.go b/go/ai/resource.go index e84ca1193c..18e1823019 100644 --- a/go/ai/resource.go +++ b/go/ai/resource.go @@ -109,7 +109,7 @@ type ResourceFunc = func(context.Context, *ResourceInput) (*ResourceOutput, erro // It holds the underlying core action and allows looking up resources // by name without knowing their specific input/output api. type resource struct { - core.ActionDef[*ResourceInput, *ResourceOutput, struct{}] + core.Action[*ResourceInput, *ResourceOutput, struct{}, struct{}] } // Resource represents an instance of a resource. @@ -129,7 +129,7 @@ type Resource interface { // DefineResource creates a resource and registers it with the given Registry. func DefineResource(r api.Registry, name string, opts *ResourceOptions, fn ResourceFunc) Resource { metadata := resourceMetadata(name, opts) - return &resource{ActionDef: *core.DefineAction(r, name, api.ActionTypeResource, metadata, nil, fn)} + return &resource{Action: *core.DefineAction(r, name, api.ActionTypeResource, metadata, nil, fn)} } // NewResource creates a resource but does not register it in the registry. @@ -137,7 +137,7 @@ func DefineResource(r api.Registry, name string, opts *ResourceOptions, fn Resou func NewResource(name string, opts *ResourceOptions, fn ResourceFunc) Resource { metadata := resourceMetadata(name, opts) metadata["dynamic"] = true - return &resource{ActionDef: *core.NewAction(name, api.ActionTypeResource, metadata, nil, fn)} + return &resource{Action: *core.NewAction(name, api.ActionTypeResource, metadata, nil, fn)} } // resourceMetadata creates the metadata common to both DefineResource and NewResource. @@ -227,8 +227,8 @@ func (r *resource) Execute(ctx context.Context, input *ResourceInput) (*Resource // FindMatchingResource finds a resource that matches the given URI. func FindMatchingResource(r api.Registry, uri string) (Resource, *ResourceInput, error) { for _, a := range r.ListActions() { - if action, ok := a.(*core.ActionDef[*ResourceInput, *ResourceOutput, struct{}]); ok { - res := &resource{ActionDef: *action} + if action, ok := a.(*core.Action[*ResourceInput, *ResourceOutput, struct{}, struct{}]); ok { + res := &resource{Action: *action} if res.Matches(uri) { variables, err := res.ExtractVariables(uri) if err != nil { @@ -244,9 +244,9 @@ func FindMatchingResource(r api.Registry, uri string) (Resource, *ResourceInput, // LookupResource looks up the resource in the registry by provided name and returns it. func LookupResource(r api.Registry, name string) Resource { - action := core.ResolveActionFor[*ResourceInput, *ResourceOutput, struct{}](r, api.ActionTypeResource, name) + action := core.ResolveActionFor[*ResourceInput, *ResourceOutput, struct{}, struct{}](r, api.ActionTypeResource, name) if action == nil { return nil } - return &resource{ActionDef: *action} + return &resource{Action: *action} } diff --git a/go/ai/retriever.go b/go/ai/retriever.go index 64f048d6d8..9ab97f17ce 100644 --- a/go/ai/retriever.go +++ b/go/ai/retriever.go @@ -40,7 +40,7 @@ type Retriever interface { // retriever is an action with functions specific to document retrieval such as Retrieve(). type retriever struct { - core.ActionDef[*RetrieverRequest, *RetrieverResponse, struct{}] + core.Action[*RetrieverRequest, *RetrieverResponse, struct{}, struct{}] } // RetrieverArg is the interface for retriever arguments. It can either be the retriever action itself or a reference to be looked up. @@ -121,7 +121,7 @@ func NewRetriever(name string, opts *RetrieverOptions, fn RetrieverFunc) Retriev } return &retriever{ - ActionDef: *core.NewAction(name, api.ActionTypeRetriever, metadata, inputSchema, fn), + Action: *core.NewAction(name, api.ActionTypeRetriever, metadata, inputSchema, fn), } } @@ -136,12 +136,12 @@ func DefineRetriever(r api.Registry, name string, opts *RetrieverOptions, fn Ret // It will try to resolve the retriever dynamically if the retriever is not found. // It returns nil if the retriever was not resolved. func LookupRetriever(r api.Registry, name string) Retriever { - action := core.ResolveActionFor[*RetrieverRequest, *RetrieverResponse, struct{}](r, api.ActionTypeRetriever, name) + action := core.ResolveActionFor[*RetrieverRequest, *RetrieverResponse, struct{}, struct{}](r, api.ActionTypeRetriever, name) if action == nil { return nil } return &retriever{ - ActionDef: *action, + Action: *action, } } diff --git a/go/core/action.go b/go/core/action.go index 50c1aa63a5..125e44961f 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -19,7 +19,9 @@ package core import ( "context" "encoding/json" + "iter" "reflect" + "sync" "time" "github.com/firebase/genkit/go/core/api" @@ -38,19 +40,25 @@ type StreamingFunc[In, Out, Stream any] = func(context.Context, In, StreamCallba // StreamCallback is a function that is called during streaming to return the next chunk of the stream. type StreamCallback[Stream any] = func(context.Context, Stream) error -// An ActionDef is a named, observable operation that underlies all Genkit primitives. +// BidiFunc is the function signature for bidirectional streaming actions. +// It receives initialization data, reads inputs from inCh, and writes +// streamed outputs to outCh. It returns a final output when complete. +type BidiFunc[In, Out, Stream, Init any] = func(ctx context.Context, init Init, inCh <-chan In, outCh chan<- Stream) (Out, error) + +// An Action is a named, observable operation that underlies all Genkit primitives. // It consists of a function that takes an input of type I and returns an output // of type O, optionally streaming values of type S incrementally by invoking a callback. // It optionally has other metadata, like a description and JSON Schemas for its input and // output which it validates against. // -// Each time an ActionDef is run, it results in a new trace span. +// Each time an Action is run, it results in a new trace span. // // For internal use only. -type ActionDef[In, Out, Stream any] struct { - fn StreamingFunc[In, Out, Stream] // Function that is called during runtime. May not actually support streaming. - desc *api.ActionDesc // Descriptor of the action. - registry api.Registry // Registry for schema resolution. Set when registered. +type Action[In, Out, Stream, Init any] struct { + fn StreamingFunc[In, Out, Stream] // Function that is called during runtime. May not actually support streaming. + bidiFn BidiFunc[In, Out, Stream, Init] // Non-nil for bidi actions only. + desc *api.ActionDesc // Descriptor of the action. + registry api.Registry // Registry for schema resolution. Set when registered. } type noStream = func(context.Context, struct{}) error @@ -63,8 +71,8 @@ func NewAction[In, Out any]( metadata map[string]any, inputSchema map[string]any, fn Func[In, Out], -) *ActionDef[In, Out, struct{}] { - return newAction(name, atype, metadata, inputSchema, +) *Action[In, Out, struct{}, struct{}] { + return newAction[In, Out, struct{}, struct{}](name, atype, metadata, inputSchema, func(ctx context.Context, in In, cb noStream) (Out, error) { return fn(ctx, in) }) @@ -78,8 +86,69 @@ func NewStreamingAction[In, Out, Stream any]( metadata map[string]any, inputSchema map[string]any, fn StreamingFunc[In, Out, Stream], -) *ActionDef[In, Out, Stream] { - return newAction(name, atype, metadata, inputSchema, fn) +) *Action[In, Out, Stream, struct{}] { + return newAction[In, Out, Stream, struct{}](name, atype, metadata, inputSchema, fn) +} + +// ActionOptions configures a bidi action. Nil schema fields are inferred from type parameters. +type ActionOptions struct { + Metadata map[string]any // Arbitrary key-value data attached to the action descriptor. + InputSchema map[string]any // JSON schema for the action's input. Inferred from In if nil. + OutputSchema map[string]any // JSON schema for the action's output. Inferred from Out if nil. + StreamSchema map[string]any // JSON schema for streamed chunks. Inferred from Stream if nil. Not used for non-streaming actions. + InitSchema map[string]any // JSON schema for bidi initialization data. Inferred from Init if nil. Not used for non-bidi actions. +} + +// NewBidiAction creates a new bidirectional streaming [Action] without registering it. +func NewBidiAction[In, Out, Stream, Init any]( + name string, + atype api.ActionType, + opts *ActionOptions, + fn BidiFunc[In, Out, Stream, Init], +) *Action[In, Out, Stream, Init] { + if opts == nil { + opts = &ActionOptions{} + } + + metadata := opts.Metadata + if metadata == nil { + metadata = map[string]any{} + } + metadata["bidi"] = true + + a := newAction[In, Out, Stream, Init](name, atype, metadata, opts.InputSchema, wrapBidiAsStreaming(fn)) + a.bidiFn = fn + + if opts.OutputSchema != nil { + a.desc.OutputSchema = opts.OutputSchema + } + if opts.StreamSchema != nil { + a.desc.StreamSchema = opts.StreamSchema + } + + if opts.InitSchema != nil { + a.desc.InitSchema = opts.InitSchema + } else { + var init Init + if reflect.ValueOf(init).Kind() != reflect.Invalid { + a.desc.InitSchema = InferSchemaMap(init) + } + } + + return a +} + +// DefineBidiAction creates and registers a bidirectional streaming [Action]. +func DefineBidiAction[In, Out, Stream, Init any]( + r api.Registry, + name string, + atype api.ActionType, + opts *ActionOptions, + fn BidiFunc[In, Out, Stream, Init], +) *Action[In, Out, Stream, Init] { + a := NewBidiAction(name, atype, opts, fn) + a.Register(r) + return a } // DefineAction creates a new non-streaming Action and registers it. @@ -91,8 +160,8 @@ func DefineAction[In, Out any]( metadata map[string]any, inputSchema map[string]any, fn Func[In, Out], -) *ActionDef[In, Out, struct{}] { - return defineAction(r, name, atype, metadata, inputSchema, +) *Action[In, Out, struct{}, struct{}] { + return defineAction[In, Out, struct{}, struct{}](r, name, atype, metadata, inputSchema, func(ctx context.Context, in In, cb noStream) (Out, error) { return fn(ctx, in) }) @@ -107,20 +176,20 @@ func DefineStreamingAction[In, Out, Stream any]( metadata map[string]any, inputSchema map[string]any, fn StreamingFunc[In, Out, Stream], -) *ActionDef[In, Out, Stream] { - return defineAction(r, name, atype, metadata, inputSchema, fn) +) *Action[In, Out, Stream, struct{}] { + return defineAction[In, Out, Stream, struct{}](r, name, atype, metadata, inputSchema, fn) } // defineAction creates an action and registers it with the given Registry. -func defineAction[In, Out, Stream any]( +func defineAction[In, Out, Stream, Init any]( r api.Registry, name string, atype api.ActionType, metadata map[string]any, inputSchema map[string]any, fn StreamingFunc[In, Out, Stream], -) *ActionDef[In, Out, Stream] { - a := newAction(name, atype, metadata, inputSchema, fn) +) *Action[In, Out, Stream, Init] { + a := newAction[In, Out, Stream, Init](name, atype, metadata, inputSchema, fn) a.Register(r) return a } @@ -128,13 +197,13 @@ func defineAction[In, Out, Stream any]( // newAction creates a new Action with the given name and arguments. // If registry is nil, tracing state is left nil to be set later. // If inputSchema is nil, it is inferred from In. -func newAction[In, Out, Stream any]( +func newAction[In, Out, Stream, Init any]( name string, atype api.ActionType, metadata map[string]any, inputSchema map[string]any, fn StreamingFunc[In, Out, Stream], -) *ActionDef[In, Out, Stream] { +) *Action[In, Out, Stream, Init] { if inputSchema == nil { var i In if reflect.ValueOf(i).Kind() != reflect.Invalid { @@ -148,12 +217,18 @@ func newAction[In, Out, Stream any]( outputSchema = InferSchemaMap(o) } + var s Stream + var streamSchema map[string]any + if reflect.ValueOf(s).Kind() != reflect.Invalid { + streamSchema = InferSchemaMap(s) + } + var description string if desc, ok := metadata["description"].(string); ok { description = desc } - return &ActionDef[In, Out, Stream]{ + return &Action[In, Out, Stream, Init]{ fn: func(ctx context.Context, input In, cb StreamCallback[Stream]) (Out, error) { return fn(ctx, input, cb) }, @@ -164,16 +239,17 @@ func newAction[In, Out, Stream any]( Description: description, InputSchema: inputSchema, OutputSchema: outputSchema, + StreamSchema: streamSchema, Metadata: metadata, }, } } // Name returns the Action's Name. -func (a *ActionDef[In, Out, Stream]) Name() string { return a.desc.Name } +func (a *Action[In, Out, Stream, Init]) Name() string { return a.desc.Name } // Run executes the Action's function in a new trace span. -func (a *ActionDef[In, Out, Stream]) Run(ctx context.Context, input In, cb StreamCallback[Stream]) (output Out, err error) { +func (a *Action[In, Out, Stream, Init]) Run(ctx context.Context, input In, cb StreamCallback[Stream]) (output Out, err error) { r, err := a.runWithTelemetry(ctx, input, cb) if err != nil { return base.Zero[Out](), err @@ -182,7 +258,7 @@ func (a *ActionDef[In, Out, Stream]) Run(ctx context.Context, input In, cb Strea } // Run executes the Action's function in a new trace span. -func (a *ActionDef[In, Out, Stream]) runWithTelemetry(ctx context.Context, input In, cb StreamCallback[Stream]) (output api.ActionRunResult[Out], err error) { +func (a *Action[In, Out, Stream, Init]) runWithTelemetry(ctx context.Context, input In, cb StreamCallback[Stream]) (output api.ActionRunResult[Out], err error) { inputBytes, _ := json.Marshal(input) logger.FromContext(ctx).Debug("Action.Run", "name", a.Name(), @@ -263,7 +339,7 @@ func (a *ActionDef[In, Out, Stream]) runWithTelemetry(ctx context.Context, input } // RunJSON runs the action with a JSON input, and returns a JSON result. -func (a *ActionDef[In, Out, Stream]) RunJSON(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (json.RawMessage, error) { +func (a *Action[In, Out, Stream, Init]) RunJSON(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (json.RawMessage, error) { r, err := a.RunJSONWithTelemetry(ctx, input, cb) if err != nil { return nil, err @@ -272,7 +348,7 @@ func (a *ActionDef[In, Out, Stream]) RunJSON(ctx context.Context, input json.Raw } // RunJSONWithTelemetry runs the action with a JSON input, and returns a JSON result along with telemetry info. -func (a *ActionDef[In, Out, Stream]) RunJSONWithTelemetry(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (*api.ActionRunResult[json.RawMessage], error) { +func (a *Action[In, Out, Stream, Init]) RunJSONWithTelemetry(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (*api.ActionRunResult[json.RawMessage], error) { i, err := base.UnmarshalAndNormalize[In](input, a.desc.InputSchema) if err != nil { return nil, NewError(INVALID_ARGUMENT, err.Error()) @@ -310,27 +386,79 @@ func (a *ActionDef[In, Out, Stream]) RunJSONWithTelemetry(ctx context.Context, i } // Desc returns a descriptor of the action with resolved schema references. -func (a *ActionDef[In, Out, Stream]) Desc() api.ActionDesc { +func (a *Action[In, Out, Stream, Init]) Desc() api.ActionDesc { return *a.desc } // Register registers the action with the given registry. -func (a *ActionDef[In, Out, Stream]) Register(r api.Registry) { +func (a *Action[In, Out, Stream, Init]) Register(r api.Registry) { a.registry = r r.RegisterAction(a.desc.Key, a) } +// StreamBidi starts a bidirectional streaming connection. +// Returns an error if the action is not a bidi action. +// A trace span is created that remains open for the lifetime of the connection. +func (a *Action[In, Out, Stream, Init]) StreamBidi(ctx context.Context, init Init) (*BidiConnection[In, Out, Stream], error) { + if a.bidiFn == nil { + return nil, NewError(FAILED_PRECONDITION, "StreamBidi called on non-bidi action %q", a.desc.Name) + } + + ctx, cancel := context.WithCancel(ctx) + conn := &BidiConnection[In, Out, Stream]{ + inputCh: make(chan In, 1), + streamCh: make(chan Stream, 1), + doneCh: make(chan struct{}), + ctx: ctx, + cancel: cancel, + } + + spanMetadata := &tracing.SpanMetadata{ + Name: a.desc.Name, + Type: "action", + Subtype: string(a.desc.Type), + Metadata: make(map[string]string), + } + if flowName := FlowNameFromContext(ctx); flowName != "" { + spanMetadata.Metadata["flow:name"] = flowName + } + + go func() { + defer close(conn.doneCh) + defer close(conn.streamCh) + output, err := tracing.RunInNewSpan(conn.ctx, spanMetadata, init, + func(ctx context.Context, init Init) (Out, error) { + start := time.Now() + output, err := a.bidiFn(ctx, init, conn.inputCh, conn.streamCh) + latency := time.Since(start) + if err != nil { + metrics.WriteActionFailure(ctx, a.desc.Name, latency, err) + } else { + metrics.WriteActionSuccess(ctx, a.desc.Name, latency) + } + return output, err + }, + ) + conn.mu.Lock() + conn.output = output + conn.err = err + conn.mu.Unlock() + }() + + return conn, nil +} + // ResolveActionFor returns the action for the given key in the global registry, // or nil if there is none. // It panics if the action is of the wrong api. -func ResolveActionFor[In, Out, Stream any](r api.Registry, atype api.ActionType, name string) *ActionDef[In, Out, Stream] { +func ResolveActionFor[In, Out, Stream, Init any](r api.Registry, atype api.ActionType, name string) *Action[In, Out, Stream, Init] { provider, id := api.ParseName(name) key := api.NewKey(atype, provider, id) a := r.ResolveAction(key) if a == nil { return nil } - return a.(*ActionDef[In, Out, Stream]) + return a.(*Action[In, Out, Stream, Init]) } // LookupActionFor returns the action for the given key in the global registry, @@ -338,12 +466,138 @@ func ResolveActionFor[In, Out, Stream any](r api.Registry, atype api.ActionType, // It panics if the action is of the wrong api. // // Deprecated: Use ResolveActionFor. -func LookupActionFor[In, Out, Stream any](r api.Registry, atype api.ActionType, name string) *ActionDef[In, Out, Stream] { +func LookupActionFor[In, Out, Stream, Init any](r api.Registry, atype api.ActionType, name string) *Action[In, Out, Stream, Init] { provider, id := api.ParseName(name) key := api.NewKey(atype, provider, id) a := r.LookupAction(key) if a == nil { return nil } - return a.(*ActionDef[In, Out, Stream]) + return a.(*Action[In, Out, Stream, Init]) +} + +// wrapBidiAsStreaming wraps a BidiFunc into a StreamingFunc for use with Run/RunJSON. +// The input is sent as a single message, and stream chunks are forwarded to the callback. +func wrapBidiAsStreaming[In, Out, Stream, Init any](fn BidiFunc[In, Out, Stream, Init]) StreamingFunc[In, Out, Stream] { + return func(ctx context.Context, input In, cb StreamCallback[Stream]) (Out, error) { + inCh := make(chan In, 1) + outCh := make(chan Stream, 1) + doneCh := make(chan struct{}) + + var output Out + var fnErr error + + go func() { + defer close(doneCh) + defer close(outCh) + var init Init + output, fnErr = fn(ctx, init, inCh, outCh) + }() + + // Send the single input and close. + inCh <- input + close(inCh) + + // Forward streamed chunks to the callback. + if cb != nil { + for chunk := range outCh { + if err := cb(ctx, chunk); err != nil { + return base.Zero[Out](), err + } + } + } else { + // Drain the channel even without a callback. + for range outCh { + } + } + + <-doneCh + return output, fnErr + } +} + +// BidiConnection represents an active bidirectional streaming session. +type BidiConnection[In, Out, Stream any] struct { + inputCh chan In + streamCh chan Stream + doneCh chan struct{} + output Out + err error + ctx context.Context + cancel context.CancelFunc + mu sync.Mutex + closed bool +} + +// Send sends an input message to the bidi action. +// Returns an error if the connection is closed or the context is cancelled. +func (c *BidiConnection[In, Out, Stream]) Send(input In) (err error) { + defer func() { + if r := recover(); r != nil { + err = NewError(FAILED_PRECONDITION, "connection is closed") + } + }() + + select { + case c.inputCh <- input: + return nil + case <-c.ctx.Done(): + return c.ctx.Err() + case <-c.doneCh: + return NewError(FAILED_PRECONDITION, "action has completed") + } +} + +// Close signals that no more inputs will be sent. +func (c *BidiConnection[In, Out, Stream]) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return nil + } + c.closed = true + close(c.inputCh) + return nil +} + +// Receive returns an iterator for receiving streamed response chunks. +// The iterator completes when the action finishes. +func (c *BidiConnection[In, Out, Stream]) Receive() iter.Seq2[Stream, error] { + return func(yield func(Stream, error) bool) { + for { + select { + case chunk, ok := <-c.streamCh: + if !ok { + return + } + if !yield(chunk, nil) { + c.cancel() + return + } + case <-c.ctx.Done(): + var zero Stream + yield(zero, c.ctx.Err()) + return + } + } + } +} + +// Output returns the final output after the action completes. +// Blocks until done or context cancelled. +func (c *BidiConnection[In, Out, Stream]) Output() (Out, error) { + select { + case <-c.doneCh: + c.mu.Lock() + defer c.mu.Unlock() + return c.output, c.err + case <-c.ctx.Done(): + var zero Out + return zero, c.ctx.Err() + } +} + +// Done returns a channel that is closed when the connection completes. +func (c *BidiConnection[In, Out, Stream]) Done() <-chan struct{} { + return c.doneCh } diff --git a/go/core/action_test.go b/go/core/action_test.go index 65309d850d..93cd1471eb 100644 --- a/go/core/action_test.go +++ b/go/core/action_test.go @@ -34,7 +34,7 @@ func inc(_ context.Context, x int, _ noStream) (int, error) { func TestActionRun(t *testing.T) { r := registry.New() - a := defineAction(r, "test/inc", api.ActionTypeCustom, nil, nil, inc) + a := DefineStreamingAction(r, "test/inc", api.ActionTypeCustom, nil, nil, inc) got, err := a.Run(context.Background(), 3, nil) if err != nil { t.Fatal(err) @@ -46,7 +46,7 @@ func TestActionRun(t *testing.T) { func TestActionRunJSON(t *testing.T) { r := registry.New() - a := defineAction(r, "test/inc", api.ActionTypeCustom, nil, nil, inc) + a := DefineStreamingAction(r, "test/inc", api.ActionTypeCustom, nil, nil, inc) input := []byte("3") want := []byte("4") got, err := a.RunJSON(context.Background(), input, nil) @@ -73,7 +73,7 @@ func count(ctx context.Context, n int, cb func(context.Context, int) error) (int func TestActionStreaming(t *testing.T) { ctx := context.Background() r := registry.New() - a := defineAction(r, "test/count", api.ActionTypeCustom, nil, nil, count) + a := DefineStreamingAction(r, "test/count", api.ActionTypeCustom, nil, nil, count) const n = 3 // Non-streaming. @@ -108,7 +108,7 @@ func TestActionTracing(t *testing.T) { tc := tracing.NewTestOnlyTelemetryClient() tracing.WriteTelemetryImmediate(tc) name := api.NewName("test", "TestTracing-inc") - a := defineAction(r, name, api.ActionTypeCustom, nil, nil, inc) + a := DefineStreamingAction(r, name, api.ActionTypeCustom, nil, nil, inc) if _, err := a.Run(context.Background(), 3, nil); err != nil { t.Fatal(err) } @@ -309,7 +309,7 @@ func TestResolveActionFor(t *testing.T) { } DefineAction(r, "test/resolvable", api.ActionTypeCustom, nil, nil, fn) - found := ResolveActionFor[int, int, struct{}](r, api.ActionTypeCustom, "test/resolvable") + found := ResolveActionFor[int, int, struct{}, struct{}](r, api.ActionTypeCustom, "test/resolvable") if found == nil { t.Fatal("ResolveActionFor returned nil") @@ -322,7 +322,7 @@ func TestResolveActionFor(t *testing.T) { t.Run("returns nil for non-existent action", func(t *testing.T) { r := registry.New() - found := ResolveActionFor[int, int, struct{}](r, api.ActionTypeCustom, "test/nonexistent") + found := ResolveActionFor[int, int, struct{}, struct{}](r, api.ActionTypeCustom, "test/nonexistent") if found != nil { t.Errorf("ResolveActionFor returned %v, want nil", found) @@ -338,7 +338,7 @@ func TestLookupActionFor(t *testing.T) { } DefineAction(r, "test/lookupable", api.ActionTypeCustom, nil, nil, fn) - found := LookupActionFor[string, string, struct{}](r, api.ActionTypeCustom, "test/lookupable") + found := LookupActionFor[string, string, struct{}, struct{}](r, api.ActionTypeCustom, "test/lookupable") if found == nil { t.Fatal("LookupActionFor returned nil") @@ -348,7 +348,7 @@ func TestLookupActionFor(t *testing.T) { t.Run("returns nil for non-existent action", func(t *testing.T) { r := registry.New() - found := LookupActionFor[string, string, struct{}](r, api.ActionTypeCustom, "test/missing") + found := LookupActionFor[string, string, struct{}, struct{}](r, api.ActionTypeCustom, "test/missing") if found != nil { t.Errorf("LookupActionFor returned %v, want nil", found) diff --git a/go/core/api/action.go b/go/core/api/action.go index 18ffcfa67b..a38958af51 100644 --- a/go/core/api/action.go +++ b/go/core/api/action.go @@ -68,11 +68,13 @@ const ( // ActionDesc is a descriptor of an action. type ActionDesc struct { - Type ActionType `json:"type"` // Type of the action. - Key string `json:"key"` // Key of the action. - Name string `json:"name"` // Name of the action. - Description string `json:"description"` // Description of the action. - InputSchema map[string]any `json:"inputSchema"` // JSON schema to validate against the action's input. - OutputSchema map[string]any `json:"outputSchema"` // JSON schema to validate against the action's output. - Metadata map[string]any `json:"metadata"` // Metadata for the action. + Type ActionType `json:"type"` // Type of the action. + Key string `json:"key"` // Key of the action. + Name string `json:"name"` // Name of the action. + Description string `json:"description"` // Description of the action. + InputSchema map[string]any `json:"inputSchema"` // JSON schema to validate against the action's input. + OutputSchema map[string]any `json:"outputSchema"` // JSON schema to validate against the action's output. + StreamSchema map[string]any `json:"streamSchema,omitempty"` // JSON schema to validate against the action's streamed chunks. + InitSchema map[string]any `json:"initSchema,omitempty"` // JSON schema to validate against the action's initialization data. + Metadata map[string]any `json:"metadata"` // Metadata for the action. } diff --git a/go/core/background_action.go b/go/core/background_action.go index e6af50399b..c41aaed05a 100644 --- a/go/core/background_action.go +++ b/go/core/background_action.go @@ -45,10 +45,10 @@ type Operation[Out any] struct { // // For internal use only. type BackgroundActionDef[In, Out any] struct { - *ActionDef[In, *Operation[Out], struct{}] + *Action[In, *Operation[Out], struct{}, struct{}] - check *ActionDef[*Operation[Out], *Operation[Out], struct{}] // Sub-action that checks the status of a background operation. - cancel *ActionDef[*Operation[Out], *Operation[Out], struct{}] // Sub-action that cancels a background operation. + check *Action[*Operation[Out], *Operation[Out], struct{}, struct{}] // Sub-action that checks the status of a background operation. + cancel *Action[*Operation[Out], *Operation[Out], struct{}, struct{}] // Sub-action that cancels a background operation. } // Start starts a background operation. @@ -77,7 +77,7 @@ func (b *BackgroundActionDef[In, Out]) SupportsCancel() bool { // Register registers the model with the given registry. func (b *BackgroundActionDef[In, Out]) Register(r api.Registry) { - b.ActionDef.Register(r) + b.Action.Register(r) b.check.Register(r) if b.cancel != nil { b.cancel.Register(r) @@ -140,7 +140,7 @@ func NewBackgroundAction[In, Out any]( return updatedOp, nil }) - var cancelAction *ActionDef[*Operation[Out], *Operation[Out], struct{}] + var cancelAction *Action[*Operation[Out], *Operation[Out], struct{}, struct{}] if cancelFn != nil { cancelAction = NewAction(name, api.ActionTypeCancelOperation, metadata, nil, func(ctx context.Context, op *Operation[Out]) (*Operation[Out], error) { @@ -154,9 +154,9 @@ func NewBackgroundAction[In, Out any]( } return &BackgroundActionDef[In, Out]{ - ActionDef: startAction, - check: checkAction, - cancel: cancelAction, + Action: startAction, + check: checkAction, + cancel: cancelAction, } } @@ -165,22 +165,22 @@ func LookupBackgroundAction[In, Out any](r api.Registry, key string) *Background atype, provider, id := api.ParseKey(key) name := api.NewName(provider, id) - startAction := ResolveActionFor[In, *Operation[Out], struct{}](r, atype, name) + startAction := ResolveActionFor[In, *Operation[Out], struct{}, struct{}](r, atype, name) if startAction == nil { return nil } - checkAction := ResolveActionFor[*Operation[Out], *Operation[Out], struct{}](r, api.ActionTypeCheckOperation, name) + checkAction := ResolveActionFor[*Operation[Out], *Operation[Out], struct{}, struct{}](r, api.ActionTypeCheckOperation, name) if checkAction == nil { return nil } - cancelAction := ResolveActionFor[*Operation[Out], *Operation[Out], struct{}](r, api.ActionTypeCancelOperation, name) + cancelAction := ResolveActionFor[*Operation[Out], *Operation[Out], struct{}, struct{}](r, api.ActionTypeCancelOperation, name) return &BackgroundActionDef[In, Out]{ - ActionDef: startAction, - check: checkAction, - cancel: cancelAction, + Action: startAction, + check: checkAction, + cancel: cancelAction, } } diff --git a/go/core/flow.go b/go/core/flow.go index ea514365c2..c173a0306c 100644 --- a/go/core/flow.go +++ b/go/core/flow.go @@ -18,7 +18,6 @@ package core import ( "context" - "encoding/json" "errors" "fmt" @@ -27,8 +26,11 @@ import ( "github.com/firebase/genkit/go/internal/base" ) -// A Flow is a user-defined Action. A Flow[In, Out, Stream] represents a function from In to Out. The Stream parameter is for flows that support streaming: providing their results incrementally. -type Flow[In, Out, Stream any] ActionDef[In, Out, Stream] +// A Flow is a user-defined Action. A Flow[In, Out, Stream, Init] represents a function from In to Out. +// The Stream parameter is for flows that support streaming: providing their results incrementally. The Init parameter is for bidi flows. +type Flow[In, Out, Stream, Init any] struct { + *Action[In, Out, Stream, Init] +} // StreamingFlowValue is either a streamed value or a final output of a flow. type StreamingFlowValue[Out, Stream any] struct { @@ -46,14 +48,14 @@ type flowContext struct { } // DefineFlow creates a Flow that runs fn, and registers it as an action. fn takes an input of type In and returns an output of type Out. -func DefineFlow[In, Out any](r api.Registry, name string, fn Func[In, Out]) *Flow[In, Out, struct{}] { - return (*Flow[In, Out, struct{}])(DefineAction(r, name, api.ActionTypeFlow, nil, nil, func(ctx context.Context, input In) (Out, error) { +func DefineFlow[In, Out any](r api.Registry, name string, fn Func[In, Out]) *Flow[In, Out, struct{}, struct{}] { + return &Flow[In, Out, struct{}, struct{}]{DefineAction(r, name, api.ActionTypeFlow, nil, nil, func(ctx context.Context, input In) (Out, error) { fc := &flowContext{ flowName: name, } ctx = flowContextKey.NewContext(ctx, fc) return fn(ctx, input) - })) + })} } // DefineStreamingFlow creates a streaming Flow that runs fn, and registers it as an action. @@ -65,8 +67,8 @@ func DefineFlow[In, Out any](r api.Registry, name string, fn Func[In, Out]) *Flo // stream the results by invoking the callback periodically, ultimately returning // with a final return value that includes all the streamed data. // Otherwise, it should ignore the callback and just return a result. -func DefineStreamingFlow[In, Out, Stream any](r api.Registry, name string, fn StreamingFunc[In, Out, Stream]) *Flow[In, Out, Stream] { - return (*Flow[In, Out, Stream])(DefineStreamingAction(r, name, api.ActionTypeFlow, nil, nil, func(ctx context.Context, input In, cb func(context.Context, Stream) error) (Out, error) { +func DefineStreamingFlow[In, Out, Stream any](r api.Registry, name string, fn StreamingFunc[In, Out, Stream]) *Flow[In, Out, Stream, struct{}] { + return &Flow[In, Out, Stream, struct{}]{DefineStreamingAction(r, name, api.ActionTypeFlow, nil, nil, func(ctx context.Context, input In, cb func(context.Context, Stream) error) (Out, error) { fc := &flowContext{ flowName: name, } @@ -75,7 +77,25 @@ func DefineStreamingFlow[In, Out, Stream any](r api.Registry, name string, fn St cb = func(context.Context, Stream) error { return nil } } return fn(ctx, input, cb) - })) + })} +} + +// NewBidiFlow creates a bidirectional streaming Flow without registering it. +// Flow context is injected so that [Run] works inside the bidi function. +func NewBidiFlow[In, Out, Stream, Init any](name string, fn BidiFunc[In, Out, Stream, Init]) *Flow[In, Out, Stream, Init] { + wrapped := func(ctx context.Context, init Init, inCh <-chan In, outCh chan<- Stream) (Out, error) { + ctx = flowContextKey.NewContext(ctx, &flowContext{flowName: name}) + return fn(ctx, init, inCh, outCh) + } + return &Flow[In, Out, Stream, Init]{NewBidiAction(name, api.ActionTypeFlow, nil, wrapped)} +} + +// DefineBidiFlow creates a bidirectional streaming Flow that runs fn, and registers it as an action. +// Flow context is injected so that [Run] works inside the bidi function. +func DefineBidiFlow[In, Out, Stream, Init any](r api.Registry, name string, fn BidiFunc[In, Out, Stream, Init]) *Flow[In, Out, Stream, Init] { + f := NewBidiFlow[In, Out, Stream, Init](name, fn) + f.Register(r) + return f } // Run runs the function f in the context of the current flow @@ -105,29 +125,9 @@ func Run[Out any](ctx context.Context, name string, fn func() (Out, error)) (Out }) } -// Name returns the name of the flow. -func (f *Flow[In, Out, Stream]) Name() string { - return (*ActionDef[In, Out, Stream])(f).Name() -} - -// RunJSON runs the flow with JSON input and streaming callback and returns the output as JSON. -func (f *Flow[In, Out, Stream]) RunJSON(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (json.RawMessage, error) { - return (*ActionDef[In, Out, Stream])(f).RunJSON(ctx, input, cb) -} - -// RunJSONWithTelemetry runs the flow with JSON input and streaming callback and returns the output as JSON along with telemetry info. -func (f *Flow[In, Out, Stream]) RunJSONWithTelemetry(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (*api.ActionRunResult[json.RawMessage], error) { - return (*ActionDef[In, Out, Stream])(f).RunJSONWithTelemetry(ctx, input, cb) -} - -// Desc returns the descriptor of the flow. -func (f *Flow[In, Out, Stream]) Desc() api.ActionDesc { - return (*ActionDef[In, Out, Stream])(f).Desc() -} - // Run runs the flow in the context of another flow. -func (f *Flow[In, Out, Stream]) Run(ctx context.Context, input In) (Out, error) { - return (*ActionDef[In, Out, Stream])(f).Run(ctx, input, nil) +func (f *Flow[In, Out, Stream, Init]) Run(ctx context.Context, input In) (Out, error) { + return f.Action.Run(ctx, input, nil) } // Stream runs the flow in the context of another flow and streams the output. @@ -142,7 +142,7 @@ func (f *Flow[In, Out, Stream]) Run(ctx context.Context, input In) (Out, error) // again. // // Otherwise the Stream field of the passed [StreamingFlowValue] holds a streamed result. -func (f *Flow[In, Out, Stream]) Stream(ctx context.Context, input In) func(func(*StreamingFlowValue[Out, Stream], error) bool) { +func (f *Flow[In, Out, Stream, Init]) Stream(ctx context.Context, input In) func(func(*StreamingFlowValue[Out, Stream], error) bool) { return func(yield func(*StreamingFlowValue[Out, Stream], error) bool) { cb := func(ctx context.Context, s Stream) error { if ctx.Err() != nil { @@ -153,7 +153,7 @@ func (f *Flow[In, Out, Stream]) Stream(ctx context.Context, input In) func(func( } return nil } - output, err := (*ActionDef[In, Out, Stream])(f).Run(ctx, input, cb) + output, err := f.Action.Run(ctx, input, cb) if err != nil { yield(nil, err) } else { @@ -162,11 +162,6 @@ func (f *Flow[In, Out, Stream]) Stream(ctx context.Context, input In) func(func( } } -// Register registers the flow with the given registry. -func (f *Flow[In, Out, Stream]) Register(r api.Registry) { - (*ActionDef[In, Out, Stream])(f).Register(r) -} - var errStop = errors.New("stop") // FlowNameFromContext returns the flow name from context if we're in a flow, empty string otherwise. diff --git a/go/core/flow_test.go b/go/core/flow_test.go index e3c3e6b463..7da8d31778 100644 --- a/go/core/flow_test.go +++ b/go/core/flow_test.go @@ -18,9 +18,12 @@ package core import ( "context" + "fmt" "slices" + "strings" "testing" + "github.com/firebase/genkit/go/core/api" "github.com/firebase/genkit/go/internal/registry" ) @@ -69,7 +72,7 @@ func TestRunFlow(t *testing.T) { func TestFlowNameFromContext(t *testing.T) { r := registry.New() - flows := []*Flow[struct{}, string, struct{}]{ + flows := []*Flow[struct{}, string, struct{}, struct{}]{ DefineFlow(r, "DefineFlow", func(ctx context.Context, _ struct{}) (string, error) { return FlowNameFromContext(ctx), nil }), @@ -257,3 +260,302 @@ func TestFlowNameFromContextOutsideFlow(t *testing.T) { } }) } + +func TestBidiActionEcho(t *testing.T) { + ctx := context.Background() + + action := NewBidiAction( + "echo", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + var count int + for input := range inCh { + count++ + outCh <- fmt.Sprintf("echo: %s", input) + } + return fmt.Sprintf("processed %d messages", count), nil + }, + ) + + conn, err := action.StreamBidi(ctx, struct{}{}) + if err != nil { + t.Fatal(err) + } + + // With unbuffered channels, we must send and receive concurrently. + go func() { + conn.Send("hello") + conn.Send("world") + conn.Close() + }() + + var chunks []string + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatal(err) + } + chunks = append(chunks, chunk) + } + + if len(chunks) != 2 { + t.Fatalf("expected 2 chunks, got %d: %v", len(chunks), chunks) + } + if chunks[0] != "echo: hello" { + t.Errorf("expected 'echo: hello', got %q", chunks[0]) + } + if chunks[1] != "echo: world" { + t.Errorf("expected 'echo: world', got %q", chunks[1]) + } + + output, err := conn.Output() + if err != nil { + t.Fatal(err) + } + if output != "processed 2 messages" { + t.Errorf("expected 'processed 2 messages', got %q", output) + } +} + +func TestBidiActionWithInit(t *testing.T) { + ctx := context.Background() + + type Config struct { + Prefix string + } + + action := NewBidiAction( + "prefixed", api.ActionTypeCustom, nil, + func(ctx context.Context, init Config, inCh <-chan string, outCh chan<- string) (string, error) { + for input := range inCh { + outCh <- fmt.Sprintf("%s: %s", init.Prefix, input) + } + return "done", nil + }, + ) + + conn, err := action.StreamBidi(ctx, Config{Prefix: "INFO"}) + if err != nil { + t.Fatal(err) + } + + go func() { + conn.Send("test message") + conn.Close() + }() + + var chunks []string + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatal(err) + } + chunks = append(chunks, chunk) + } + + if len(chunks) != 1 || chunks[0] != "INFO: test message" { + t.Errorf("unexpected chunks: %v", chunks) + } +} + +func TestBidiConnectionSendAfterClose(t *testing.T) { + ctx := context.Background() + + action := NewBidiAction( + "test", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + for range inCh { + } + return "", nil + }, + ) + + conn, err := action.StreamBidi(ctx, struct{}{}) + if err != nil { + t.Fatal(err) + } + + conn.Close() + // Wait for completion so we know the state is settled. + <-conn.Done() + + if err := conn.Send("after close"); err == nil { + t.Error("expected error sending after close") + } +} + +func TestBidiConnectionContextCancellation(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + action := NewBidiAction( + "blocking", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + <-ctx.Done() + return "", ctx.Err() + }, + ) + + conn, err := action.StreamBidi(ctx, struct{}{}) + if err != nil { + t.Fatal(err) + } + + cancel() + + _, err = conn.Output() + if err == nil { + t.Error("expected error after context cancellation") + } +} + +func TestBidiFlowRegistration(t *testing.T) { + r := registry.New() + + flow := DefineBidiFlow( + r, "echoFlow", + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + for input := range inCh { + outCh <- input + } + return "done", nil + }, + ) + + if flow.Name() != "echoFlow" { + t.Errorf("expected name 'echoFlow', got %q", flow.Name()) + } + + desc := flow.Desc() + if desc.Type != api.ActionTypeFlow { + t.Errorf("expected type %q, got %q", api.ActionTypeFlow, desc.Type) + } + + // Verify bidi metadata is set. + if bidi, ok := desc.Metadata["bidi"].(bool); !ok || !bidi { + t.Error("expected metadata[\"bidi\"] = true") + } + + // Verify registered in registry. + action := r.LookupAction(desc.Key) + if action == nil { + t.Error("expected action to be registered") + } +} + +func TestBidiFlowEcho(t *testing.T) { + r := registry.New() + ctx := context.Background() + + flow := DefineBidiFlow( + r, "echoFlow", + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + var count int + for input := range inCh { + count++ + outCh <- fmt.Sprintf("echo: %s", input) + } + return fmt.Sprintf("processed %d", count), nil + }, + ) + + conn, err := flow.StreamBidi(ctx, struct{}{}) + if err != nil { + t.Fatal(err) + } + + go func() { + conn.Send("a") + conn.Send("b") + conn.Close() + }() + + var chunks []string + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatal(err) + } + chunks = append(chunks, chunk) + } + + if len(chunks) != 2 { + t.Fatalf("expected 2 chunks, got %d", len(chunks)) + } + + output, err := conn.Output() + if err != nil { + t.Fatal(err) + } + if output != "processed 2" { + t.Errorf("expected 'processed 2', got %q", output) + } +} + +func TestBidiFlowCoreRunWorks(t *testing.T) { + r := registry.New() + ctx := context.Background() + + flow := DefineBidiFlow( + r, "withSteps", + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + for input := range inCh { + // core.Run should work inside a BidiFlow. + result, err := Run(ctx, "uppercase", func() (string, error) { + return strings.ToUpper(input), nil + }) + if err != nil { + return "", err + } + outCh <- result + } + return "done", nil + }, + ) + + conn, err := flow.StreamBidi(ctx, struct{}{}) + if err != nil { + t.Fatal(err) + } + + go func() { + conn.Send("hello") + conn.Close() + }() + + var chunks []string + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatal(err) + } + chunks = append(chunks, chunk) + } + + if len(chunks) != 1 || chunks[0] != "HELLO" { + t.Errorf("expected [HELLO], got %v", chunks) + } +} + +func TestBidiActionDone(t *testing.T) { + ctx := context.Background() + + action := NewBidiAction( + "quick", api.ActionTypeCustom, nil, + func(ctx context.Context, _ struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + for range inCh { + } + return "finished", nil + }, + ) + + conn, err := action.StreamBidi(ctx, struct{}{}) + if err != nil { + t.Fatal(err) + } + + conn.Close() + <-conn.Done() + + output, err := conn.Output() + if err != nil { + t.Fatal(err) + } + if output != "finished" { + t.Errorf("expected 'finished', got %q", output) + } +} diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 377fb5e836..40e137a2b8 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -304,7 +304,7 @@ func RegisterAction(g *Genkit, action api.Registerable) { // // handle error // } // fmt.Println(result) // Output: Hello, World! -func DefineFlow[In, Out any](g *Genkit, name string, fn core.Func[In, Out]) *core.Flow[In, Out, struct{}] { +func DefineFlow[In, Out any](g *Genkit, name string, fn core.Func[In, Out]) *core.Flow[In, Out, struct{}, struct{}] { return core.DefineFlow(g.reg, name, fn) } @@ -355,10 +355,58 @@ func DefineFlow[In, Out any](g *Genkit, name string, fn core.Func[In, Out]) *cor // fmt.Println("Stream Chunk:", result.Stream) // Outputs: 1, 2, 3, 4, 5 // } // } -func DefineStreamingFlow[In, Out, Stream any](g *Genkit, name string, fn core.StreamingFunc[In, Out, Stream]) *core.Flow[In, Out, Stream] { +func DefineStreamingFlow[In, Out, Stream any](g *Genkit, name string, fn core.StreamingFunc[In, Out, Stream]) *core.Flow[In, Out, Stream, struct{}] { return core.DefineStreamingFlow(g.reg, name, fn) } +// DefineBidiFlow defines a bidirectional streaming flow, registers it as a [core.Action] of type Flow, +// and returns a [core.Flow] capable of bidirectional streaming. +// +// The provided function `fn` receives initialization data of type `Init`, reads +// inputs of type `In` from an input channel, and writes streamed outputs of type +// `Stream` to an output channel. It returns a final output of type `Out` when complete. +// +// Example: +// +// chatFlow := genkit.DefineBidiFlow(g, "chat", +// func(ctx context.Context, init struct{}, inCh <-chan string, outCh chan<- string) (string, error) { +// var count int +// for input := range inCh { +// count++ +// outCh <- fmt.Sprintf("reply: %s", input) +// } +// return fmt.Sprintf("processed %d messages", count), nil +// }, +// ) +// +// // Start a bidi connection: +// conn, err := chatFlow.StreamBidi(ctx, struct{}{}) +// if err != nil { +// // handle error +// } +// +// // Send messages concurrently: +// go func() { +// conn.Send("hello") +// conn.Send("world") +// conn.Close() +// }() +// +// // Receive streamed responses: +// for chunk, err := range conn.Receive() { +// if err != nil { +// // handle error +// } +// fmt.Println(chunk) // Outputs: "reply: hello", "reply: world" +// } +// +// // Get the final output: +// output, err := conn.Output() +// fmt.Println(output) // Output: "processed 2 messages" +func DefineBidiFlow[In, Out, Stream, Init any](g *Genkit, name string, fn core.BidiFunc[In, Out, Stream, Init]) *core.Flow[In, Out, Stream, Init] { + return core.DefineBidiFlow(g.reg, name, fn) +} + // Run executes the given function `fn` within the context of the current flow run, // creating a distinct trace span for this step. It's used to add observability // to specific sub-operations within a flow defined by [DefineFlow] or [DefineStreamingFlow].