diff --git a/go/ai/x/option.go b/go/ai/x/option.go new file mode 100644 index 0000000000..d7380b4b83 --- /dev/null +++ b/go/ai/x/option.go @@ -0,0 +1,111 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package aix + +import "errors" + +// --- SessionFlowOption --- + +// SessionFlowOption configures a SessionFlow. +type SessionFlowOption[State any] interface { + applySessionFlow(*sessionFlowOptions[State]) error +} + +type sessionFlowOptions[State any] struct { + store SnapshotStore[State] + callback SnapshotCallback[State] +} + +func (o *sessionFlowOptions[State]) applySessionFlow(opts *sessionFlowOptions[State]) error { + if o.store != nil { + if opts.store != nil { + return errors.New("cannot set snapshot store more than once (WithSnapshotStore)") + } + opts.store = o.store + } + if o.callback != nil { + if opts.callback != nil { + return errors.New("cannot set snapshot callback more than once (WithSnapshotCallback)") + } + opts.callback = o.callback + } + return nil +} + +// WithSnapshotStore sets the store for persisting snapshots. +func WithSnapshotStore[State any](store SnapshotStore[State]) SessionFlowOption[State] { + return &sessionFlowOptions[State]{store: store} +} + +// WithSnapshotCallback configures when snapshots are created. +// If not provided and a store is configured, snapshots are always created. +func WithSnapshotCallback[State any](cb SnapshotCallback[State]) SessionFlowOption[State] { + return &sessionFlowOptions[State]{callback: cb} +} + +// --- StreamBidiOption --- + +// StreamBidiOption configures a StreamBidi call. +type StreamBidiOption[State any] interface { + applyStreamBidi(*streamBidiOptions[State]) error +} + +type streamBidiOptions[State any] struct { + state *SessionState[State] + snapshotID string + promptInput any +} + +func (o *streamBidiOptions[State]) applyStreamBidi(opts *streamBidiOptions[State]) error { + if o.state != nil { + if opts.state != nil { + return errors.New("cannot set state more than once (WithState)") + } + opts.state = o.state + } + if o.snapshotID != "" { + if opts.snapshotID != "" { + return errors.New("cannot set snapshot ID more than once (WithSnapshotID)") + } + opts.snapshotID = o.snapshotID + } + if o.promptInput != nil { + if opts.promptInput != nil { + return errors.New("cannot set prompt input more than once (WithPromptInput)") + } + opts.promptInput = o.promptInput + } + return nil +} + +// WithState sets the initial state for the invocation. +// Use this for client-managed state where the client sends state directly. +func WithState[State any](state *SessionState[State]) StreamBidiOption[State] { + return &streamBidiOptions[State]{state: state} +} + +// WithSnapshotID loads state from a persisted snapshot by ID. +// Use this for server-managed state where snapshots are stored. +func WithSnapshotID[State any](id string) StreamBidiOption[State] { + return &streamBidiOptions[State]{snapshotID: id} +} + +// WithPromptInput overrides the default prompt input for a prompt-backed session flow. +// Used with DefineSessionFlowFromPrompt to customize the prompt rendering per invocation. +func WithPromptInput[State any](input any) StreamBidiOption[State] { + return &streamBidiOptions[State]{promptInput: input} +} diff --git a/go/ai/x/prompt_session_flow_test.go b/go/ai/x/prompt_session_flow_test.go new file mode 100644 index 0000000000..70dade4e8c --- /dev/null +++ b/go/ai/x/prompt_session_flow_test.go @@ -0,0 +1,375 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package aix + +import ( + "context" + "strings" + "testing" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/internal/registry" +) + +// setupPromptTestRegistry creates a registry with an echo model and generate action. +func setupPromptTestRegistry(t *testing.T) *registry.Registry { + t.Helper() + reg := registry.New() + ctx := context.Background() + + ai.ConfigureFormats(reg) + ai.DefineModel(reg, "test/echo", &ai.ModelOptions{Supports: &ai.ModelSupports{Multiturn: true}}, + func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + // Echo back the last user message text. + var text string + for i := len(req.Messages) - 1; i >= 0; i-- { + if req.Messages[i].Role == ai.RoleUser { + text = req.Messages[i].Text() + break + } + } + if text == "" { + text = "no input" + } + + resp := &ai.ModelResponse{ + Message: ai.NewModelTextMessage("echo: " + text), + } + + if cb != nil { + if err := cb(ctx, &ai.ModelResponseChunk{ + Content: resp.Message.Content, + }); err != nil { + return nil, err + } + } + + return resp, nil + }, + ) + ai.DefineGenerateAction(ctx, reg) + return reg +} + +func TestPromptSessionFlow_Basic(t *testing.T) { + ctx := context.Background() + reg := setupPromptTestRegistry(t) + + prompt := ai.DefinePrompt(reg, "testPrompt", + ai.WithModelName("test/echo"), + ai.WithSystem("You are a test assistant."), + ) + + sf := DefineSessionFlowFromPrompt[testState]( + reg, "promptFlow", prompt, nil, + ) + + conn, err := sf.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + // Turn 1. + if err := conn.SendText("hello"); err != nil { + t.Fatalf("SendText failed: %v", err) + } + + var gotChunk bool + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.Chunk != nil { + gotChunk = true + } + if chunk.EndTurn { + break + } + } + if !gotChunk { + t.Error("expected at least one streaming chunk") + } + + // Turn 2. + if err := conn.SendText("world"); err != nil { + t.Fatalf("SendText failed: %v", err) + } + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + + conn.Close() + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // 2 user messages + 2 model replies = 4. + if got := len(response.State.Messages); got != 4 { + t.Errorf("expected 4 messages, got %d", got) + for i, m := range response.State.Messages { + t.Logf(" msg[%d]: role=%s text=%s", i, m.Role, m.Text()) + } + } +} + +func TestPromptSessionFlow_PromptInputOverride(t *testing.T) { + ctx := context.Background() + reg := setupPromptTestRegistry(t) + + type greetInput struct { + Name string `json:"name"` + } + + prompt := ai.DefineDataPrompt[greetInput, string](reg, "greetPrompt", + ai.WithModelName("test/echo"), + ai.WithPrompt("Hello {{name}}!"), + ) + + sf := DefineSessionFlowFromPrompt[testState]( + reg, "promptInputFlow", prompt, greetInput{Name: "default"}, + ) + + // Use WithPromptInput to override. + conn, err := sf.StreamBidi(ctx, + WithPromptInput[testState](greetInput{Name: "override"}), + ) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + if err := conn.SendText("hi"); err != nil { + t.Fatalf("SendText failed: %v", err) + } + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + conn.Close() + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // Verify the override was stored in session state. + if response.State.PromptInput == nil { + t.Fatal("expected PromptInput in state") + } + + // The model echoes the last user message, which is "hi". + // But the prompt was rendered with "override" so "Hello override!" should appear + // in the messages sent to the model (verified via the echo). + // We primarily verify the state was set correctly. + inputMap, ok := response.State.PromptInput.(map[string]any) + if !ok { + t.Fatalf("expected PromptInput to be map[string]any, got %T", response.State.PromptInput) + } + if name, _ := inputMap["name"].(string); name != "override" { + t.Errorf("expected PromptInput name='override', got %q", name) + } +} + +func TestPromptSessionFlow_MultiTurnHistory(t *testing.T) { + ctx := context.Background() + reg := setupPromptTestRegistry(t) + + // Use a model that echoes all message count so we can verify history grows. + ai.DefineModel(reg, "test/history", &ai.ModelOptions{Supports: &ai.ModelSupports{Multiturn: true}}, + func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + // Count total messages received (includes prompt-rendered + history). + var parts []string + for _, m := range req.Messages { + parts = append(parts, string(m.Role)+":"+m.Text()) + } + text := strings.Join(parts, "|") + + resp := &ai.ModelResponse{ + Message: ai.NewModelTextMessage(text), + } + if cb != nil { + cb(ctx, &ai.ModelResponseChunk{Content: resp.Message.Content}) + } + return resp, nil + }, + ) + + prompt := ai.DefinePrompt(reg, "historyPrompt", + ai.WithModelName("test/history"), + ai.WithSystem("system prompt"), + ) + + sf := DefineSessionFlowFromPrompt[testState]( + reg, "historyFlow", prompt, nil, + ) + + conn, err := sf.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + // Turn 1. + conn.SendText("turn1") + var turn1Response string + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.Chunk != nil { + turn1Response += chunk.Chunk.Text() + } + if chunk.EndTurn { + break + } + } + + // Turn 1 should have: system message + user message "turn1" (2 messages total from prompt + history). + // The system message comes from the prompt, "turn1" from session history. + if !strings.Contains(turn1Response, "turn1") { + t.Errorf("turn1 response should contain 'turn1', got: %s", turn1Response) + } + + // Turn 2. + conn.SendText("turn2") + var turn2Response string + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.Chunk != nil { + turn2Response += chunk.Chunk.Text() + } + if chunk.EndTurn { + break + } + } + + // Turn 2 should have: system + turn1 user + turn1 model reply + turn2 user (4 messages from prompt + history). + if !strings.Contains(turn2Response, "turn1") || !strings.Contains(turn2Response, "turn2") { + t.Errorf("turn2 response should contain both 'turn1' and 'turn2', got: %s", turn2Response) + } + + conn.Close() + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // Session should have: turn1 user + turn1 model + turn2 user + turn2 model = 4 messages. + if got := len(response.State.Messages); got != 4 { + t.Errorf("expected 4 messages in session, got %d", got) + for i, m := range response.State.Messages { + t.Logf(" msg[%d]: role=%s text=%s", i, m.Role, m.Text()) + } + } +} + +func TestPromptSessionFlow_SnapshotPersistsPromptInput(t *testing.T) { + ctx := context.Background() + reg := setupPromptTestRegistry(t) + store := NewInMemorySnapshotStore[testState]() + + prompt := ai.DefinePrompt(reg, "snapPrompt", + ai.WithModelName("test/echo"), + ai.WithSystem("You are a test assistant."), + ) + + sf := DefineSessionFlowFromPrompt[testState]( + reg, "snapPromptFlow", prompt, nil, + WithSnapshotStore(store), + ) + + // Start with prompt input. + conn, err := sf.StreamBidi(ctx, + WithPromptInput[testState](map[string]any{"key": "value"}), + ) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + conn.SendText("hello") + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + conn.Close() + + resp, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + if resp.SnapshotID == "" { + t.Fatal("expected snapshot ID") + } + + // Verify the snapshot contains PromptInput. + snap, err := store.GetSnapshot(ctx, resp.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot failed: %v", err) + } + if snap.State.PromptInput == nil { + t.Error("expected PromptInput in snapshot state") + } + + // Resume from snapshot — the PromptInput should be preserved. + conn2, err := sf.StreamBidi(ctx, WithSnapshotID[testState](resp.SnapshotID)) + if err != nil { + t.Fatalf("StreamBidi with snapshot failed: %v", err) + } + + conn2.SendText("continued") + for chunk, err := range conn2.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + conn2.Close() + + resp2, err := conn2.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // Should have messages from both invocations. + if got := len(resp2.State.Messages); got != 4 { + t.Errorf("expected 4 messages after resume, got %d", got) + } + + // PromptInput should still be present. + if resp2.State.PromptInput == nil { + t.Error("expected PromptInput preserved after resume") + } +} diff --git a/go/ai/x/session_flow.go b/go/ai/x/session_flow.go new file mode 100644 index 0000000000..06621014c0 --- /dev/null +++ b/go/ai/x/session_flow.go @@ -0,0 +1,743 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +// Package aix provides experimental AI primitives for Genkit. +// +// APIs in this package are under active development and may change in any +// minor version release. +package aix + +import ( + "context" + "encoding/json" + "fmt" + "iter" + "log/slog" + "sync" + "time" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/core/api" + "github.com/firebase/genkit/go/core/tracing" + "github.com/google/uuid" + "go.opentelemetry.io/otel/attribute" + oteltrace "go.opentelemetry.io/otel/trace" +) + +// SessionFlowArtifact represents a named collection of parts produced during a session. +// Examples: generated files, images, code snippets, diagrams, etc. +type SessionFlowArtifact struct { + // Name identifies the artifact (e.g., "generated_code.go", "diagram.png"). + Name string `json:"name,omitempty"` + // Parts contains the artifact content (text, media, etc.). + Parts []*ai.Part `json:"parts"` + // Metadata contains additional artifact-specific data. + Metadata map[string]any `json:"metadata,omitempty"` +} + +// SessionFlowInput is the input sent to a session flow during a conversation turn. +type SessionFlowInput struct { + // Messages contains the user's input for this turn. + Messages []*ai.Message `json:"messages,omitempty"` +} + +// SessionFlowInit is the input for starting a session flow invocation. +// Provide either SnapshotID (to load from store) or State (direct state). +type SessionFlowInit[State any] struct { + // SnapshotID loads state from a persisted snapshot. + // Mutually exclusive with State. + SnapshotID string `json:"snapshotId,omitempty"` + // State provides direct state for the invocation. + // Mutually exclusive with SnapshotID. + State *SessionState[State] `json:"state,omitempty"` + // PromptInput overrides the default prompt input for this invocation. + // Used by prompt-backed session flows (DefineSessionFlowFromPrompt). + PromptInput any `json:"promptInput,omitempty"` +} + +// SessionFlowOutput is the output when a session flow invocation completes. +type SessionFlowOutput[State any] struct { + // SnapshotID is the ID of the snapshot created at the end of this invocation. + // Empty if no snapshot was created (callback returned false or no store configured). + SnapshotID string `json:"snapshotId,omitempty"` + // State contains the final conversation state. + State *SessionState[State] `json:"state"` +} + +// SessionFlowStreamChunk represents a single item in the session flow's output stream. +// Multiple fields can be populated in a single chunk. +type SessionFlowStreamChunk[Stream any] struct { + // Chunk contains token-level generation data. + Chunk *ai.ModelResponseChunk `json:"chunk,omitempty"` + // Status contains user-defined structured status information. + // The Stream type parameter defines the shape of this data. + Status Stream `json:"status,omitempty"` + // Artifact contains a newly produced artifact. + Artifact *SessionFlowArtifact `json:"artifact,omitempty"` + // SnapshotCreated contains the ID of a snapshot that was just persisted. + SnapshotCreated string `json:"snapshotCreated,omitempty"` + // EndTurn signals that the session flow has finished processing the current input. + // When true, the client should stop iterating and may send the next input. + EndTurn bool `json:"endTurn,omitempty"` +} + +// --- Session --- + +// Session holds the working state during a session flow invocation. +// It is propagated through context and provides read/write access to state. +type Session[State any] struct { + mu sync.RWMutex + state SessionState[State] + store SnapshotStore[State] + + snapshotCallback SnapshotCallback[State] + + // onEndTurn is set by the framework; triggers snapshot + EndTurn chunk. + onEndTurn func(ctx context.Context) + inCh <-chan *SessionFlowInput + + // Snapshot tracking + lastSnapshot *SessionSnapshot[State] + turnIndex int +} + +// Run loops over the input channel, calling fn for each turn. Each turn is +// wrapped in a trace span for observability. Input messages are automatically +// added to the session before fn is called. After fn returns successfully, an +// EndTurn chunk is sent and a snapshot check is triggered. +func (s *Session[State]) Run( + ctx context.Context, + fn func(ctx context.Context, input *SessionFlowInput) error, +) error { + for input := range s.inCh { + spanMeta := &tracing.SpanMetadata{ + Name: fmt.Sprintf("sessionFlow/turn/%d", s.turnIndex), + Type: "sessionFlowTurn", + Subtype: "sessionFlowTurn", + } + + _, err := tracing.RunInNewSpan(ctx, spanMeta, input, + func(ctx context.Context, input *SessionFlowInput) (struct{}, error) { + s.AddMessages(input.Messages...) + + if err := fn(ctx, input); err != nil { + return struct{}{}, err + } + + s.onEndTurn(ctx) + s.turnIndex++ + return struct{}{}, nil + }, + ) + if err != nil { + return err + } + } + return nil +} + +// State returns a copy of the current session flow state. +func (s *Session[State]) State() *SessionState[State] { + s.mu.RLock() + defer s.mu.RUnlock() + copied := s.copyStateLocked() + return &copied +} + +// Messages returns the current conversation history. +func (s *Session[State]) Messages() []*ai.Message { + s.mu.RLock() + defer s.mu.RUnlock() + msgs := make([]*ai.Message, len(s.state.Messages)) + copy(msgs, s.state.Messages) + return msgs +} + +// AddMessages appends messages to the conversation history. +func (s *Session[State]) AddMessages(messages ...*ai.Message) { + s.mu.Lock() + defer s.mu.Unlock() + s.state.Messages = append(s.state.Messages, messages...) +} + +// SetMessages replaces the entire conversation history. +func (s *Session[State]) SetMessages(messages []*ai.Message) { + s.mu.Lock() + defer s.mu.Unlock() + s.state.Messages = messages +} + +// Custom returns the current user-defined custom state. +func (s *Session[State]) Custom() State { + s.mu.RLock() + defer s.mu.RUnlock() + return s.state.Custom +} + +// SetCustom updates the user-defined custom state. +func (s *Session[State]) SetCustom(custom State) { + s.mu.Lock() + defer s.mu.Unlock() + s.state.Custom = custom +} + +// UpdateCustom atomically reads the current custom state, applies the given +// function, and writes the result back. +func (s *Session[State]) UpdateCustom(fn func(State) State) { + s.mu.Lock() + defer s.mu.Unlock() + s.state.Custom = fn(s.state.Custom) +} + +// PromptInput returns the prompt input stored in the session state. +func (s *Session[State]) PromptInput() any { + s.mu.RLock() + defer s.mu.RUnlock() + return s.state.PromptInput +} + +// Artifacts returns the current artifacts. +func (s *Session[State]) Artifacts() []*SessionFlowArtifact { + s.mu.RLock() + defer s.mu.RUnlock() + arts := make([]*SessionFlowArtifact, len(s.state.Artifacts)) + copy(arts, s.state.Artifacts) + return arts +} + +// AddArtifacts adds artifacts to the session. If an artifact with the same +// name already exists, it is replaced. +func (s *Session[State]) AddArtifacts(artifacts ...*SessionFlowArtifact) { + s.mu.Lock() + defer s.mu.Unlock() + for _, a := range artifacts { + replaced := false + if a.Name != "" { + for i, existing := range s.state.Artifacts { + if existing.Name == a.Name { + s.state.Artifacts[i] = a + replaced = true + break + } + } + } + if !replaced { + s.state.Artifacts = append(s.state.Artifacts, a) + } + } +} + +// SetArtifacts replaces the entire artifact list. +func (s *Session[State]) SetArtifacts(artifacts []*SessionFlowArtifact) { + s.mu.Lock() + defer s.mu.Unlock() + s.state.Artifacts = artifacts +} + +// maybeSnapshot creates a snapshot if conditions are met (store configured, +// callback approves). Returns the snapshot ID or empty string. +func (s *Session[State]) maybeSnapshot(ctx context.Context, event SnapshotEvent) string { + if s.store == nil { + return "" + } + + s.mu.RLock() + currentState := s.copyStateLocked() + turnIndex := s.turnIndex + s.mu.RUnlock() + + shouldSnapshot := true + if s.snapshotCallback != nil { + var prevState *SessionState[State] + if s.lastSnapshot != nil { + prevState = &s.lastSnapshot.State + } + shouldSnapshot = s.snapshotCallback(ctx, &SnapshotContext[State]{ + State: ¤tState, + PrevState: prevState, + TurnIndex: turnIndex, + Event: event, + }) + } + + if !shouldSnapshot { + return "" + } + + snapshot := &SessionSnapshot[State]{ + SnapshotID: uuid.New().String(), + CreatedAt: time.Now(), + TurnIndex: turnIndex, + Event: event, + State: currentState, + } + if s.lastSnapshot != nil { + snapshot.ParentID = s.lastSnapshot.SnapshotID + } + + if err := s.store.SaveSnapshot(ctx, snapshot); err != nil { + slog.Error("session flow: failed to save snapshot", "err", err) + return "" + } + + // Set snapshotId in last message metadata. + s.mu.Lock() + if msgs := s.state.Messages; len(msgs) > 0 { + lastMsg := msgs[len(msgs)-1] + if lastMsg.Metadata == nil { + lastMsg.Metadata = make(map[string]any) + } + lastMsg.Metadata["snapshotId"] = snapshot.SnapshotID + } + s.mu.Unlock() + + s.lastSnapshot = snapshot + + // Record on OTel span. + span := oteltrace.SpanFromContext(ctx) + span.SetAttributes( + attribute.String("genkit:metadata:snapshotId", snapshot.SnapshotID), + ) + + return snapshot.SnapshotID +} + +// copyStateLocked returns a deep copy of the state. Caller must hold mu (read or write). +func (s *Session[State]) copyStateLocked() SessionState[State] { + bytes, err := json.Marshal(s.state) + if err != nil { + panic(fmt.Sprintf("session flow: failed to marshal state: %v", err)) + } + var copied SessionState[State] + if err := json.Unmarshal(bytes, &copied); err != nil { + panic(fmt.Sprintf("session flow: failed to unmarshal state: %v", err)) + } + return copied +} + +// --- Session context --- + +type sessionContextKey struct{} + +type sessionHolder struct { + session any +} + +// NewSessionContext returns a new context with the session attached. +func NewSessionContext[State any](ctx context.Context, s *Session[State]) context.Context { + return context.WithValue(ctx, sessionContextKey{}, &sessionHolder{session: s}) +} + +// SessionFromContext retrieves the current session from context. +// Returns nil if no session is in context or if the type doesn't match. +func SessionFromContext[State any](ctx context.Context) *Session[State] { + holder, ok := ctx.Value(sessionContextKey{}).(*sessionHolder) + if !ok || holder == nil { + return nil + } + session, ok := holder.session.(*Session[State]) + if !ok { + return nil + } + return session +} + +// --- Responder --- + +// Responder is the output channel for a session flow. Chunks sent through it +// are automatically inspected: if a chunk contains an artifact, it is added to +// the session before being forwarded to the client. +// +// Convenience methods are provided for common chunk types. +type Responder[Stream any] chan<- *SessionFlowStreamChunk[Stream] + +// SendChunk sends a generation chunk (token-level streaming). +func (r Responder[Stream]) SendChunk(chunk *ai.ModelResponseChunk) { + r <- &SessionFlowStreamChunk[Stream]{Chunk: chunk} +} + +// SendStatus sends a user-defined status update. +func (r Responder[Stream]) SendStatus(status Stream) { + r <- &SessionFlowStreamChunk[Stream]{Status: status} +} + +// SendArtifact sends an artifact to the stream and adds it to the session. +// If an artifact with the same name already exists in the session, it is replaced. +func (r Responder[Stream]) SendArtifact(artifact *SessionFlowArtifact) { + r <- &SessionFlowStreamChunk[Stream]{Artifact: artifact} +} + +// --- SessionFlowParams --- + +// SessionFlowParams contains the parameters passed to a session flow function. +type SessionFlowParams[State any] struct { + // Session provides access to the working state. + Session *Session[State] +} + +// --- SessionFlowFunc --- + +// SessionFlowFunc is the function signature for session flows. +// Type parameters: +// - Stream: Type for status updates sent via the responder +// - State: Type for user-defined state in snapshots +type SessionFlowFunc[Stream, State any] func( + ctx context.Context, + resp Responder[Stream], + params *SessionFlowParams[State], +) error + +// --- SessionFlow --- + +// SessionFlow is a bidirectional streaming flow with automatic snapshot management. +type SessionFlow[Stream, State any] struct { + flow *core.Flow[*SessionFlowInput, *SessionFlowOutput[State], *SessionFlowStreamChunk[Stream], *SessionFlowInit[State]] + store SnapshotStore[State] + snapshotCallback SnapshotCallback[State] +} + +// DefineSessionFlow creates a SessionFlow with automatic snapshot management and registers it. +func DefineSessionFlow[Stream, State any]( + r api.Registry, + name string, + fn SessionFlowFunc[Stream, State], + opts ...SessionFlowOption[State], +) *SessionFlow[Stream, State] { + sfOpts := &sessionFlowOptions[State]{} + for _, opt := range opts { + if err := opt.applySessionFlow(sfOpts); err != nil { + panic(fmt.Errorf("DefineSessionFlow %q: %w", name, err)) + } + } + + sf := &SessionFlow[Stream, State]{ + store: sfOpts.store, + snapshotCallback: sfOpts.callback, + } + + bidiFn := func( + ctx context.Context, + init *SessionFlowInit[State], + inCh <-chan *SessionFlowInput, + outCh chan<- *SessionFlowStreamChunk[Stream], + ) (*SessionFlowOutput[State], error) { + return sf.runWrapped(ctx, init, inCh, outCh, fn) + } + + sf.flow = core.DefineBidiFlow(r, name, bidiFn) + + // Register snapshot store action for reflection API. + if sfOpts.store != nil { + registerSnapshotStoreAction(r, name, sfOpts.store) + } + + return sf +} + +// StreamBidi starts a new session flow invocation. +func (sf *SessionFlow[Stream, State]) StreamBidi( + ctx context.Context, + opts ...StreamBidiOption[State], +) (*SessionFlowConnection[Stream, State], error) { + sbOpts := &streamBidiOptions[State]{} + for _, opt := range opts { + if err := opt.applyStreamBidi(sbOpts); err != nil { + return nil, fmt.Errorf("SessionFlow.StreamBidi %q: %w", sf.flow.Name(), err) + } + } + + init := &SessionFlowInit[State]{ + SnapshotID: sbOpts.snapshotID, + State: sbOpts.state, + PromptInput: sbOpts.promptInput, + } + + conn, err := sf.flow.StreamBidi(ctx, init) + if err != nil { + return nil, err + } + + return &SessionFlowConnection[Stream, State]{conn: conn}, nil +} + +// runWrapped is the BidiFunc implementation. It sets up the session, +// responder, and wiring, then delegates to the user's function. +func (sf *SessionFlow[Stream, State]) runWrapped( + ctx context.Context, + init *SessionFlowInit[State], + inCh <-chan *SessionFlowInput, + outCh chan<- *SessionFlowStreamChunk[Stream], + fn SessionFlowFunc[Stream, State], +) (*SessionFlowOutput[State], error) { + session, err := newSessionFromInit(ctx, init, sf.store, sf.snapshotCallback) + if err != nil { + return nil, err + } + session.inCh = inCh + ctx = NewSessionContext(ctx, session) + + // Intermediary channel: intercepts artifacts before forwarding to outCh. + respCh := make(chan *SessionFlowStreamChunk[Stream]) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + for chunk := range respCh { + if chunk.Artifact != nil { + session.AddArtifacts(chunk.Artifact) + } + outCh <- chunk + } + }() + + // Wire up onEndTurn: triggers snapshot + sends EndTurn chunk. + // Writes through respCh to preserve ordering with user chunks. + session.onEndTurn = func(turnCtx context.Context) { + snapshotID := session.maybeSnapshot(turnCtx, SnapshotEventTurnEnd) + if snapshotID != "" { + respCh <- &SessionFlowStreamChunk[Stream]{SnapshotCreated: snapshotID} + } + respCh <- &SessionFlowStreamChunk[Stream]{EndTurn: true} + } + + params := &SessionFlowParams[State]{ + Session: session, + } + + fnErr := fn(ctx, Responder[Stream](respCh), params) + close(respCh) + wg.Wait() + + if fnErr != nil { + return nil, fnErr + } + + // Final snapshot at invocation end. + snapshotID := session.maybeSnapshot(ctx, SnapshotEventInvocationEnd) + + return &SessionFlowOutput[State]{ + State: session.State(), + SnapshotID: snapshotID, + }, nil +} + +// newSessionFromInit creates a session from initialization data. +func newSessionFromInit[State any]( + ctx context.Context, + init *SessionFlowInit[State], + store SnapshotStore[State], + cb SnapshotCallback[State], +) (*Session[State], error) { + s := &Session[State]{ + store: store, + snapshotCallback: cb, + } + + if init != nil { + if init.SnapshotID != "" && store != nil { + snapshot, err := store.GetSnapshot(ctx, init.SnapshotID) + if err != nil { + return nil, core.NewError(core.INTERNAL, "failed to load snapshot %q: %v", init.SnapshotID, err) + } + if snapshot == nil { + return nil, core.NewError(core.NOT_FOUND, "snapshot %q not found", init.SnapshotID) + } + s.state = snapshot.State + s.lastSnapshot = snapshot + s.turnIndex = snapshot.TurnIndex + } else if init.State != nil { + s.state = *init.State + } + if init.PromptInput != nil { + s.state.PromptInput = init.PromptInput + } + } + + return s, nil +} + +// --- Snapshot store reflection action --- + +type getSnapshotInput struct { + SnapshotID string `json:"snapshotId"` +} + +func registerSnapshotStoreAction[State any](r api.Registry, flowName string, store SnapshotStore[State]) { + core.DefineAction(r, flowName+"/getSnapshot", api.ActionTypeSnapshotStore, nil, nil, + func(ctx context.Context, input getSnapshotInput) (*SessionSnapshot[State], error) { + return store.GetSnapshot(ctx, input.SnapshotID) + }, + ) +} + +// --- SessionFlowConnection --- + +// SessionFlowConnection wraps BidiConnection with session flow-specific functionality. +// It provides a Receive() iterator that supports multi-turn patterns: breaking out +// of the iterator between turns does not cancel the underlying connection. +type SessionFlowConnection[Stream, State any] struct { + conn *core.BidiConnection[*SessionFlowInput, *SessionFlowOutput[State], *SessionFlowStreamChunk[Stream]] + + // chunks buffers stream chunks from the underlying connection so that + // breaking from Receive() between turns doesn't cancel the context. + chunks chan *SessionFlowStreamChunk[Stream] + chunkErr error + initOnce sync.Once +} + +// initReceiver starts a goroutine that drains the underlying BidiConnection's +// Receive into a channel. This goroutine never breaks from the underlying +// iterator, preventing context cancellation. +func (c *SessionFlowConnection[Stream, State]) initReceiver() { + c.initOnce.Do(func() { + c.chunks = make(chan *SessionFlowStreamChunk[Stream], 1) + go func() { + defer close(c.chunks) + for chunk, err := range c.conn.Receive() { + if err != nil { + c.chunkErr = err + return + } + c.chunks <- chunk + } + }() + }) +} + +// Send sends a SessionFlowInput to the session flow. +func (c *SessionFlowConnection[Stream, State]) Send(input *SessionFlowInput) error { + return c.conn.Send(input) +} + +// SendMessages sends messages to the session flow. +func (c *SessionFlowConnection[Stream, State]) SendMessages(messages ...*ai.Message) error { + return c.conn.Send(&SessionFlowInput{Messages: messages}) +} + +// SendText sends a single user text message to the session flow. +func (c *SessionFlowConnection[Stream, State]) SendText(text string) error { + return c.conn.Send(&SessionFlowInput{ + Messages: []*ai.Message{ai.NewUserTextMessage(text)}, + }) +} + +// Close signals that no more inputs will be sent. +func (c *SessionFlowConnection[Stream, State]) Close() error { + return c.conn.Close() +} + +// Receive returns an iterator for receiving stream chunks. +// Unlike the underlying BidiConnection.Receive, breaking out of this iterator +// does not cancel the connection. This enables multi-turn patterns where the +// caller breaks on EndTurn, sends the next input, then calls Receive again. +func (c *SessionFlowConnection[Stream, State]) Receive() iter.Seq2[*SessionFlowStreamChunk[Stream], error] { + c.initReceiver() + return func(yield func(*SessionFlowStreamChunk[Stream], error) bool) { + for { + chunk, ok := <-c.chunks + if !ok { + if err := c.chunkErr; err != nil { + var zero *SessionFlowStreamChunk[Stream] + yield(zero, err) + } + return + } + if !yield(chunk, nil) { + return + } + } + } +} + +// Output returns the final response after the session flow completes. +func (c *SessionFlowConnection[Stream, State]) Output() (*SessionFlowOutput[State], error) { + return c.conn.Output() +} + +// Done returns a channel closed when the connection completes. +func (c *SessionFlowConnection[Stream, State]) Done() <-chan struct{} { + return c.conn.Done() +} + +// --- Prompt-backed SessionFlow --- + +// PromptRenderer renders a prompt with typed input into GenerateActionOptions. +// This interface is satisfied by both ai.Prompt (with In=any) and +// *ai.DataPrompt[In, Out]. +type PromptRenderer[In any] interface { + Render(ctx context.Context, input In) (*ai.GenerateActionOptions, error) +} + +// DefineSessionFlowFromPrompt creates a prompt-backed SessionFlow with an +// automatic conversation loop. Each turn renders the prompt, appends +// conversation history, calls GenerateWithRequest, streams chunks to the +// client, and adds the model response to the session. +// +// The defaultInput is used for prompt rendering unless overridden per +// invocation via WithPromptInput. +func DefineSessionFlowFromPrompt[State, PromptIn any]( + r api.Registry, + name string, + p PromptRenderer[PromptIn], + defaultInput PromptIn, + opts ...SessionFlowOption[State], +) *SessionFlow[struct{}, State] { + fn := func(ctx context.Context, resp Responder[struct{}], params *SessionFlowParams[State]) error { + return params.Session.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + sess := params.Session + + // Resolve prompt input: session state override > default. + promptInput := defaultInput + if stored := sess.PromptInput(); stored != nil { + typed, ok := stored.(PromptIn) + if !ok { + return fmt.Errorf("prompt input type mismatch: got %T, want %T", stored, promptInput) + } + promptInput = typed + } + + // Render the prompt template. + actionOpts, err := p.Render(ctx, promptInput) + if err != nil { + return fmt.Errorf("prompt render: %w", err) + } + + // Append conversation history after the prompt-rendered messages. + actionOpts.Messages = append(actionOpts.Messages, sess.Messages()...) + + // Call the model with streaming. + modelResp, err := ai.GenerateWithRequest(ctx, r, actionOpts, nil, + func(ctx context.Context, chunk *ai.ModelResponseChunk) error { + resp.SendChunk(chunk) + return nil + }, + ) + if err != nil { + return fmt.Errorf("generate: %w", err) + } + + // Add the model response message to session history. + if modelResp.Message != nil { + sess.AddMessages(modelResp.Message) + } + + return nil + }) + } + + return DefineSessionFlow(r, name, fn, opts...) +} diff --git a/go/ai/x/session_flow_test.go b/go/ai/x/session_flow_test.go new file mode 100644 index 0000000000..d5722863fd --- /dev/null +++ b/go/ai/x/session_flow_test.go @@ -0,0 +1,746 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package aix + +import ( + "context" + "fmt" + "testing" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/internal/registry" +) + +type testState struct { + Counter int `json:"counter"` + Topics []string `json:"topics,omitempty"` +} + +type testStatus struct { + Phase string `json:"phase"` +} + +func newTestRegistry(t *testing.T) *registry.Registry { + t.Helper() + return registry.New() +} + +func TestSessionFlow_BasicMultiTurn(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + sf := DefineSessionFlow(reg, "basicFlow", + func(ctx context.Context, resp Responder[testStatus], params *SessionFlowParams[testState]) error { + return params.Session.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + sess := params.Session + resp.SendStatus(testStatus{Phase: "generating"}) + // Echo back the user's message. + if len(input.Messages) > 0 { + reply := ai.NewModelTextMessage("echo: " + input.Messages[0].Content[0].Text) + sess.AddMessages(reply) + } + sess.UpdateCustom(func(s testState) testState { + s.Counter++ + return s + }) + resp.SendStatus(testStatus{Phase: "complete"}) + return nil + }) + }, + ) + + conn, err := sf.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + // Turn 1. + if err := conn.SendText("hello"); err != nil { + t.Fatalf("SendText failed: %v", err) + } + var turn1Chunks int + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + turn1Chunks++ + if chunk.EndTurn { + break + } + } + if turn1Chunks < 2 { // at least status + endTurn + t.Errorf("expected at least 2 chunks in turn 1, got %d", turn1Chunks) + } + + // Turn 2. + if err := conn.SendText("world"); err != nil { + t.Fatalf("SendText failed: %v", err) + } + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + + conn.Close() + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // 2 user messages + 2 echo replies = 4. + if got := len(response.State.Messages); got != 4 { + t.Errorf("expected 4 messages, got %d", got) + } + if got := response.State.Custom.Counter; got != 2 { + t.Errorf("expected counter=2, got %d", got) + } +} + +func TestSessionFlow_WithSnapshotStore(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + store := NewInMemorySnapshotStore[testState]() + + sf := DefineSessionFlow(reg, "snapshotFlow", + func(ctx context.Context, resp Responder[testStatus], params *SessionFlowParams[testState]) error { + return params.Session.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + sess := params.Session + if len(input.Messages) > 0 { + sess.AddMessages(ai.NewModelTextMessage("reply")) + } + sess.UpdateCustom(func(s testState) testState { + s.Counter++ + return s + }) + return nil + }) + }, + WithSnapshotStore[testState](store), + ) + + conn, err := sf.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + conn.SendText("turn1") + + var snapshotIDs []string + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.SnapshotCreated != "" { + snapshotIDs = append(snapshotIDs, chunk.SnapshotCreated) + } + if chunk.EndTurn { + break + } + } + + if len(snapshotIDs) != 1 { + t.Fatalf("expected 1 snapshot from turn, got %d", len(snapshotIDs)) + } + + // Verify the snapshot was persisted. + snap, err := store.GetSnapshot(ctx, snapshotIDs[0]) + if err != nil { + t.Fatalf("GetSnapshot failed: %v", err) + } + if snap == nil { + t.Fatal("expected snapshot, got nil") + } + if snap.State.Custom.Counter != 1 { + t.Errorf("expected counter=1 in snapshot, got %d", snap.State.Custom.Counter) + } + if snap.TurnIndex != 0 { + t.Errorf("expected turnIndex=0, got %d", snap.TurnIndex) + } + + conn.Close() + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // Final snapshot at invocation end. + if response.SnapshotID == "" { + t.Error("expected final snapshot ID") + } +} + +func TestSessionFlow_ResumeFromSnapshot(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + store := NewInMemorySnapshotStore[testState]() + + sf := DefineSessionFlow(reg, "resumeFlow", + func(ctx context.Context, resp Responder[testStatus], params *SessionFlowParams[testState]) error { + return params.Session.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + sess := params.Session + if len(input.Messages) > 0 { + sess.AddMessages(ai.NewModelTextMessage("reply")) + } + sess.UpdateCustom(func(s testState) testState { + s.Counter++ + return s + }) + return nil + }) + }, + WithSnapshotStore[testState](store), + ) + + // First invocation: create a snapshot. + conn1, err := sf.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + conn1.SendText("first message") + for chunk, err := range conn1.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + conn1.Close() + resp1, err := conn1.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + if resp1.SnapshotID == "" { + t.Fatal("expected snapshot ID from first invocation") + } + + // Second invocation: resume from snapshot. + conn2, err := sf.StreamBidi(ctx, WithSnapshotID[testState](resp1.SnapshotID)) + if err != nil { + t.Fatalf("StreamBidi with snapshot failed: %v", err) + } + conn2.SendText("continued message") + for chunk, err := range conn2.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + conn2.Close() + resp2, err := conn2.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // Should have messages from both invocations: + // first: user + reply (2) + second: user + reply (2) = 4. + if got := len(resp2.State.Messages); got != 4 { + t.Errorf("expected 4 messages after resume, got %d", got) + } + // Counter should be 2 (1 from first + 1 from second). + if got := resp2.State.Custom.Counter; got != 2 { + t.Errorf("expected counter=2, got %d", got) + } + + // The new snapshot should reference the previous as parent. + if resp2.SnapshotID == "" { + t.Fatal("expected snapshot ID from second invocation") + } + snap2, err := store.GetSnapshot(ctx, resp2.SnapshotID) + if err != nil { + t.Fatalf("GetSnapshot failed: %v", err) + } + // The parent chain: snap2's parent is a turn-end snapshot from the second invocation, + // which itself has a parent from the first invocation's final snapshot. + // We just verify that the parent chain exists (not empty). + if snap2.ParentID == "" { + t.Error("expected parent ID on resumed snapshot") + } +} + +func TestSessionFlow_ClientManagedState(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + sf := DefineSessionFlow(reg, "clientStateFlow", + func(ctx context.Context, resp Responder[testStatus], params *SessionFlowParams[testState]) error { + return params.Session.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + sess := params.Session + if len(input.Messages) > 0 { + sess.AddMessages(ai.NewModelTextMessage("reply")) + } + sess.UpdateCustom(func(s testState) testState { + s.Counter++ + return s + }) + return nil + }) + }, + ) + + // Start with client-provided state. + clientState := &SessionState[testState]{ + Messages: []*ai.Message{ + ai.NewUserTextMessage("previous message"), + ai.NewModelTextMessage("previous reply"), + }, + Custom: testState{Counter: 5}, + } + + conn, err := sf.StreamBidi(ctx, WithState(clientState)) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + conn.SendText("new message") + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + conn.Close() + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // 2 previous + 1 new user + 1 reply = 4. + if got := len(response.State.Messages); got != 4 { + t.Errorf("expected 4 messages, got %d", got) + } + // Counter should be 6 (started at 5, incremented once). + if got := response.State.Custom.Counter; got != 6 { + t.Errorf("expected counter=6, got %d", got) + } + // No snapshot since no store was configured. + if response.SnapshotID != "" { + t.Errorf("expected no snapshot ID without store, got %q", response.SnapshotID) + } +} + +func TestSessionFlow_Artifacts(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + sf := DefineSessionFlow(reg, "artifactFlow", + func(ctx context.Context, resp Responder[testStatus], params *SessionFlowParams[testState]) error { + return params.Session.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + sess := params.Session + + resp.SendArtifact(&SessionFlowArtifact{ + Name: "code.go", + Parts: []*ai.Part{ai.NewTextPart("package main")}, + }) + + // Replace artifact with same name. + resp.SendArtifact(&SessionFlowArtifact{ + Name: "code.go", + Parts: []*ai.Part{ai.NewTextPart("package main\nfunc main() {}")}, + }) + + // Add another artifact. + resp.SendArtifact(&SessionFlowArtifact{ + Name: "readme.md", + Parts: []*ai.Part{ai.NewTextPart("# README")}, + }) + + sess.AddMessages(ai.NewModelTextMessage("done")) + return nil + }) + }, + ) + + conn, err := sf.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + conn.SendText("generate code") + var receivedArtifacts []*SessionFlowArtifact + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.Artifact != nil { + receivedArtifacts = append(receivedArtifacts, chunk.Artifact) + } + if chunk.EndTurn { + break + } + } + conn.Close() + + if len(receivedArtifacts) != 3 { // all 3 sends are streamed + t.Errorf("expected 3 streamed artifacts, got %d", len(receivedArtifacts)) + } + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // Session should have 2 unique artifacts (code.go was replaced). + if got := len(response.State.Artifacts); got != 2 { + t.Errorf("expected 2 artifacts in state, got %d", got) + } +} + +func TestSessionFlow_SnapshotCallback(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + store := NewInMemorySnapshotStore[testState]() + + // Only snapshot on even turns. + callbackCalls := 0 + sf := DefineSessionFlow(reg, "callbackFlow", + func(ctx context.Context, resp Responder[testStatus], params *SessionFlowParams[testState]) error { + return params.Session.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + sess := params.Session + sess.AddMessages(ai.NewModelTextMessage("reply")) + sess.UpdateCustom(func(s testState) testState { + s.Counter++ + return s + }) + return nil + }) + }, + WithSnapshotStore[testState](store), + WithSnapshotCallback(func(ctx context.Context, sc *SnapshotContext[testState]) bool { + callbackCalls++ + return sc.TurnIndex%2 == 0 // only snapshot on even turns + }), + ) + + conn, err := sf.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + var snapshotIDs []string + for i := 0; i < 3; i++ { + conn.SendText(fmt.Sprintf("turn %d", i)) + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error on turn %d: %v", i, err) + } + if chunk.SnapshotCreated != "" { + snapshotIDs = append(snapshotIDs, chunk.SnapshotCreated) + } + if chunk.EndTurn { + break + } + } + } + conn.Close() + conn.Output() // drain + + // Turn 0 (even) → snapshot, Turn 1 (odd) → no, Turn 2 (even) → snapshot. + // That's 2 turn snapshots from the callback. + if got := len(snapshotIDs); got != 2 { + t.Errorf("expected 2 turn snapshots, got %d", got) + } + // Callback should have been called 3 times (once per turn). + if callbackCalls < 3 { + t.Errorf("expected at least 3 callback calls, got %d", callbackCalls) + } +} + +func TestSessionFlow_SendMessages(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + sf := DefineSessionFlow(reg, "sendMsgsFlow", + func(ctx context.Context, resp Responder[testStatus], params *SessionFlowParams[testState]) error { + return params.Session.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + return nil + }) + }, + ) + + conn, err := sf.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + // Send multiple messages at once. + err = conn.SendMessages( + ai.NewUserTextMessage("msg1"), + ai.NewUserTextMessage("msg2"), + ) + if err != nil { + t.Fatalf("SendMessages failed: %v", err) + } + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + conn.Close() + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // Both messages should have been added. + if got := len(response.State.Messages); got != 2 { + t.Errorf("expected 2 messages, got %d", got) + } +} + +func TestSessionFlow_SessionContext(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + var retrievedCounter int + sf := DefineSessionFlow(reg, "contextFlow", + func(ctx context.Context, resp Responder[testStatus], params *SessionFlowParams[testState]) error { + return params.Session.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + // Session should be retrievable from context. + sess := SessionFromContext[testState](ctx) + if sess == nil { + t.Error("expected session from context") + return nil + } + sess.UpdateCustom(func(s testState) testState { + s.Counter = 42 + return s + }) + retrievedCounter = sess.Custom().Counter + return nil + }) + }, + ) + + conn, err := sf.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + conn.SendText("test") + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + conn.Close() + conn.Output() + + if retrievedCounter != 42 { + t.Errorf("expected counter=42 from context, got %d", retrievedCounter) + } +} + +func TestSessionFlow_ErrorInTurn(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + sf := DefineSessionFlow(reg, "errorFlow", + func(ctx context.Context, resp Responder[testStatus], params *SessionFlowParams[testState]) error { + return params.Session.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + return fmt.Errorf("turn failed") + }) + }, + ) + + conn, err := sf.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + conn.SendText("trigger error") + conn.Close() + + _, err = conn.Output() + if err == nil { + t.Fatal("expected error from failed turn") + } +} + +func TestSessionFlow_SetMessages(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + + sf := DefineSessionFlow(reg, "setMsgsFlow", + func(ctx context.Context, resp Responder[testStatus], params *SessionFlowParams[testState]) error { + return params.Session.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + sess := params.Session + // Replace all messages with just one. + sess.SetMessages([]*ai.Message{ai.NewModelTextMessage("replaced")}) + return nil + }) + }, + ) + + conn, err := sf.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + conn.SendText("original") + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + conn.Close() + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // SetMessages replaced everything with 1 message. + if got := len(response.State.Messages); got != 1 { + t.Errorf("expected 1 message after SetMessages, got %d", got) + } +} + +func TestSessionFlow_SnapshotIDInMessageMetadata(t *testing.T) { + ctx := context.Background() + reg := newTestRegistry(t) + store := NewInMemorySnapshotStore[testState]() + + sf := DefineSessionFlow(reg, "metadataFlow", + func(ctx context.Context, resp Responder[testStatus], params *SessionFlowParams[testState]) error { + return params.Session.Run(ctx, func(ctx context.Context, input *SessionFlowInput) error { + sess := params.Session + sess.AddMessages(ai.NewModelTextMessage("reply")) + return nil + }) + }, + WithSnapshotStore[testState](store), + ) + + conn, err := sf.StreamBidi(ctx) + if err != nil { + t.Fatalf("StreamBidi failed: %v", err) + } + + conn.SendText("hello") + for chunk, err := range conn.Receive() { + if err != nil { + t.Fatalf("Receive error: %v", err) + } + if chunk.EndTurn { + break + } + } + conn.Close() + + response, err := conn.Output() + if err != nil { + t.Fatalf("Output failed: %v", err) + } + + // The last message should have snapshotId in its metadata. + msgs := response.State.Messages + if len(msgs) == 0 { + t.Fatal("expected messages in response") + } + lastMsg := msgs[len(msgs)-1] + if lastMsg.Metadata == nil { + t.Fatal("expected metadata on last message") + } + if _, ok := lastMsg.Metadata["snapshotId"]; !ok { + t.Error("expected snapshotId in last message metadata") + } +} + +func TestInMemorySnapshotStore(t *testing.T) { + ctx := context.Background() + store := NewInMemorySnapshotStore[testState]() + + // Get non-existent. + snap, err := store.GetSnapshot(ctx, "nonexistent") + if err != nil { + t.Fatalf("GetSnapshot failed: %v", err) + } + if snap != nil { + t.Errorf("expected nil, got %v", snap) + } + + // Save and retrieve. + snapshot := &SessionSnapshot[testState]{ + SnapshotID: "snap-1", + TurnIndex: 0, + State: SessionState[testState]{ + Custom: testState{Counter: 1}, + }, + } + if err := store.SaveSnapshot(ctx, snapshot); err != nil { + t.Fatalf("SaveSnapshot failed: %v", err) + } + + retrieved, err := store.GetSnapshot(ctx, "snap-1") + if err != nil { + t.Fatalf("GetSnapshot failed: %v", err) + } + if retrieved == nil { + t.Fatal("expected snapshot") + } + if retrieved.State.Custom.Counter != 1 { + t.Errorf("expected counter=1, got %d", retrieved.State.Custom.Counter) + } + + // Verify isolation. + snapshot.State.Custom.Counter = 999 + retrieved2, _ := store.GetSnapshot(ctx, "snap-1") + if retrieved2.State.Custom.Counter != 1 { + t.Errorf("expected counter=1 (isolation), got %d", retrieved2.State.Custom.Counter) + } +} + +func TestSessionFlow_SnapshotStoreReflectionAction(t *testing.T) { + _ = context.Background() + reg := newTestRegistry(t) + store := NewInMemorySnapshotStore[testState]() + + DefineSessionFlow(reg, "reflectFlow", + func(ctx context.Context, resp Responder[testStatus], params *SessionFlowParams[testState]) error { + return nil + }, + WithSnapshotStore[testState](store), + ) + + // The getSnapshot action should be registered. + action := reg.LookupAction("/snapshot-store/reflectFlow/getSnapshot") + if action == nil { + t.Fatal("expected getSnapshot action to be registered") + } +} diff --git a/go/ai/x/snapshot.go b/go/ai/x/snapshot.go new file mode 100644 index 0000000000..e9411aa32f --- /dev/null +++ b/go/ai/x/snapshot.go @@ -0,0 +1,164 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package aix + +import ( + "context" + "encoding/json" + "sync" + "time" + + "github.com/firebase/genkit/go/ai" +) + +// SessionState is the portable conversation state that flows between client +// and server. It contains only the data needed for conversation continuity. +type SessionState[State any] struct { + // Messages is the conversation history (user/model exchanges). + // Does NOT include prompt-rendered messages — those are rendered fresh each turn. + Messages []*ai.Message `json:"messages,omitempty"` + // Custom is the user-defined state associated with this conversation. + Custom State `json:"custom,omitempty"` + // Artifacts are named collections of parts produced during the conversation. + Artifacts []*SessionFlowArtifact `json:"artifacts,omitempty"` + // PromptInput is the input used for prompt rendering in prompt-backed session flows. + // Stored as any to support type-erased persistence across snapshot boundaries. + PromptInput any `json:"promptInput,omitempty"` +} + +// SnapshotEvent identifies what triggered a snapshot. +type SnapshotEvent string + +const ( + // TurnEnd indicates the snapshot was triggered at the end of a turn. + SnapshotEventTurnEnd SnapshotEvent = "turnEnd" + // InvocationEnd indicates the snapshot was triggered at the end of the invocation. + SnapshotEventInvocationEnd SnapshotEvent = "invocationEnd" +) + +// SessionSnapshot is a persisted point-in-time capture of session state. +type SessionSnapshot[State any] struct { + // SnapshotID is the unique identifier for this snapshot (UUID). + SnapshotID string `json:"snapshotId"` + // ParentID is the ID of the previous snapshot in this timeline. + ParentID string `json:"parentId,omitempty"` + // CreatedAt is when the snapshot was created. + CreatedAt time.Time `json:"createdAt"` + // TurnIndex is the turn number when this snapshot was created (0-indexed). + TurnIndex int `json:"turnIndex"` + // Event is what triggered this snapshot. + Event SnapshotEvent `json:"event"` + // State is the actual conversation state. + State SessionState[State] `json:"state"` +} + +// SnapshotContext provides context for snapshot decision callbacks. +type SnapshotContext[State any] struct { + // State is the current state that will be snapshotted if the callback returns true. + State *SessionState[State] + // PrevState is the state at the last snapshot, or nil if none exists. + PrevState *SessionState[State] + // TurnIndex is the current turn number. + TurnIndex int + // Event is what triggered this snapshot check. + Event SnapshotEvent +} + +// SnapshotCallback decides whether to create a snapshot. +// If not provided and a store is configured, snapshots are always created. +type SnapshotCallback[State any] = func(ctx context.Context, sc *SnapshotContext[State]) bool + +// SnapshotStore persists and retrieves snapshots. +type SnapshotStore[State any] interface { + // GetSnapshot retrieves a snapshot by ID. Returns nil if not found. + GetSnapshot(ctx context.Context, snapshotID string) (*SessionSnapshot[State], error) + // SaveSnapshot persists a snapshot. + SaveSnapshot(ctx context.Context, snapshot *SessionSnapshot[State]) error +} + +// InMemorySnapshotStore provides a thread-safe in-memory snapshot store. +type InMemorySnapshotStore[State any] struct { + snapshots map[string]*SessionSnapshot[State] + mu sync.RWMutex +} + +// NewInMemorySnapshotStore creates a new in-memory snapshot store. +func NewInMemorySnapshotStore[State any]() *InMemorySnapshotStore[State] { + return &InMemorySnapshotStore[State]{ + snapshots: make(map[string]*SessionSnapshot[State]), + } +} + +// GetSnapshot retrieves a snapshot by ID. Returns nil if not found. +func (s *InMemorySnapshotStore[State]) GetSnapshot(_ context.Context, snapshotID string) (*SessionSnapshot[State], error) { + s.mu.RLock() + defer s.mu.RUnlock() + + snap, exists := s.snapshots[snapshotID] + if !exists { + return nil, nil + } + + copied, err := copySnapshot(snap) + if err != nil { + return nil, err + } + return copied, nil +} + +// SaveSnapshot persists a snapshot. +func (s *InMemorySnapshotStore[State]) SaveSnapshot(_ context.Context, snapshot *SessionSnapshot[State]) error { + s.mu.Lock() + defer s.mu.Unlock() + + copied, err := copySnapshot(snapshot) + if err != nil { + return err + } + s.snapshots[copied.SnapshotID] = copied + return nil +} + +// copySnapshot creates a deep copy of a snapshot using JSON marshaling. +func copySnapshot[State any](snap *SessionSnapshot[State]) (*SessionSnapshot[State], error) { + if snap == nil { + return nil, nil + } + bytes, err := json.Marshal(snap) + if err != nil { + return nil, err + } + var copied SessionSnapshot[State] + if err := json.Unmarshal(bytes, &copied); err != nil { + return nil, err + } + return &copied, nil +} + +// SnapshotOn returns a SnapshotCallback that only allows snapshots for the +// specified events. For example, SnapshotOn[MyState](TurnEnd) will skip the +// invocation-end snapshot. +func SnapshotOn[State any](events ...SnapshotEvent) SnapshotCallback[State] { + set := make(map[SnapshotEvent]struct{}, len(events)) + for _, e := range events { + set[e] = struct{}{} + } + return func(_ context.Context, sc *SnapshotContext[State]) bool { + _, ok := set[sc.Event] + return ok + } +} diff --git a/go/core/api/action.go b/go/core/api/action.go index a38958af51..3150fe88ae 100644 --- a/go/core/api/action.go +++ b/go/core/api/action.go @@ -64,6 +64,8 @@ const ( ActionTypeCustom ActionType = "custom" ActionTypeCheckOperation ActionType = "check-operation" ActionTypeCancelOperation ActionType = "cancel-operation" + ActionTypeSessionFlow ActionType = "session-flow" + ActionTypeSnapshotStore ActionType = "snapshot-store" ) // ActionDesc is a descriptor of an action. diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 40e137a2b8..78298c15af 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -30,6 +30,7 @@ import ( "syscall" "github.com/firebase/genkit/go/ai" + aix "github.com/firebase/genkit/go/ai/x" "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/core/api" "github.com/firebase/genkit/go/internal/registry" @@ -407,6 +408,65 @@ func DefineBidiFlow[In, Out, Stream, Init any](g *Genkit, name string, fn core.B return core.DefineBidiFlow(g.reg, name, fn) } +// DefineSessionFlow creates a SessionFlow with automatic snapshot management +// and registers it as a flow action. +// +// A SessionFlow is a stateful, multi-turn conversational flow with automatic +// snapshot persistence and turn semantics. It builds on bidirectional streaming +// to enable ongoing conversations with managed state. +// +// Type parameters: +// - Stream: Type for status updates sent via the responder +// - State: Type for user-defined state in snapshots +// +// Example: +// +// type ChatState struct { +// TopicHistory []string `json:"topicHistory,omitempty"` +// } +// +// type ChatStatus struct { +// Phase string `json:"phase"` +// } +// +// chatFlow := genkit.DefineSessionFlow(g, "chatFlow", +// func(ctx context.Context, resp aix.Responder[ChatStatus], params *aix.SessionFlowParams[ChatState]) error { +// return params.Session.Run(ctx, func(ctx context.Context, input *aix.SessionFlowInput) error { +// // ... handle each turn ... +// return nil +// }) +// }, +// aix.WithSnapshotStore(store), +// ) +func DefineSessionFlow[Stream, State any]( + g *Genkit, + name string, + fn aix.SessionFlowFunc[Stream, State], + opts ...aix.SessionFlowOption[State], +) *aix.SessionFlow[Stream, State] { + return aix.DefineSessionFlow(g.reg, name, fn, opts...) +} + +// DefineSessionFlowFromPrompt creates a prompt-backed SessionFlow with an +// automatic conversation loop. Each turn renders the prompt, appends +// conversation history, calls the model with streaming, and updates session state. +// +// The defaultInput is used for prompt rendering unless overridden per +// invocation via [aix.WithPromptInput]. +// +// Type parameters: +// - State: Type for user-defined state in snapshots +// - PromptIn: The prompt input type (inferred from the PromptRenderer) +func DefineSessionFlowFromPrompt[State, PromptIn any]( + g *Genkit, + name string, + p aix.PromptRenderer[PromptIn], + defaultInput PromptIn, + opts ...aix.SessionFlowOption[State], +) *aix.SessionFlow[struct{}, State] { + return aix.DefineSessionFlowFromPrompt(g.reg, name, p, defaultInput, opts...) +} + // Run executes the given function `fn` within the context of the current flow run, // creating a distinct trace span for this step. It's used to add observability // to specific sub-operations within a flow defined by [DefineFlow] or [DefineStreamingFlow]. diff --git a/go/samples/basic-session-flow/main.go b/go/samples/basic-session-flow/main.go new file mode 100644 index 0000000000..50a940c253 --- /dev/null +++ b/go/samples/basic-session-flow/main.go @@ -0,0 +1,119 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This sample demonstrates the SessionFlow API for multi-turn conversation +// with token-level streaming. It runs a CLI REPL where conversation history +// is managed automatically by the session. +package main + +import ( + "bufio" + "context" + "fmt" + "os" + "strings" + + "github.com/firebase/genkit/go/ai" + aix "github.com/firebase/genkit/go/ai/x" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/googlegenai" + "google.golang.org/genai" +) + +func main() { + ctx := context.Background() + g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) + + store := aix.NewInMemorySnapshotStore[struct{}]() + + chatFlow := genkit.DefineSessionFlow(g, "chat", + func(ctx context.Context, resp aix.Responder[any], params *aix.SessionFlowParams[struct{}]) error { + sess := params.Session + return sess.Run(ctx, func(ctx context.Context, input *aix.SessionFlowInput) error { + for chunk, err := range genkit.GenerateStream(ctx, g, + ai.WithModel(googlegenai.ModelRef("googleai/gemini-3-flash-preview", &genai.GenerateContentConfig{ + ThinkingConfig: &genai.ThinkingConfig{ + ThinkingBudget: genai.Ptr[int32](0), + }, + })), + ai.WithSystem("You are a helpful assistant. Keep responses concise."), + ai.WithMessages(sess.Messages()...), + ) { + if err != nil { + return err + } + if chunk.Done { + sess.AddMessages(chunk.Response.Message) + break + } + resp.SendChunk(chunk.Chunk) + } + + return nil + }) + }, + aix.WithSnapshotStore(store), + aix.WithSnapshotCallback(aix.SnapshotOn[struct{}](aix.SnapshotEventTurnEnd)), + ) + + fmt.Println("Session Flow Chat (type 'quit' to exit)") + fmt.Println() + + conn, err := chatFlow.StreamBidi(ctx) + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + + reader := bufio.NewReader(os.Stdin) + for { + fmt.Print("> ") + input, _ := reader.ReadString('\n') + input = strings.TrimSpace(input) + + if input == "quit" || input == "exit" { + break + } + if input == "" { + continue + } + + if err := conn.SendText(input); err != nil { + fmt.Fprintf(os.Stderr, "Send error: %v\n", err) + break + } + + fmt.Println() + + for chunk, err := range conn.Receive() { + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + break + } + if chunk.Chunk != nil { + fmt.Print(chunk.Chunk.Text()) + } + if chunk.SnapshotCreated != "" { + fmt.Printf("\n[snapshot: %s]", chunk.SnapshotCreated) + } + if chunk.EndTurn { + fmt.Println() + fmt.Println() + break + } + } + } + + conn.Close() +} diff --git a/go/samples/prompt-session-flow/main.go b/go/samples/prompt-session-flow/main.go new file mode 100644 index 0000000000..d05b4b8929 --- /dev/null +++ b/go/samples/prompt-session-flow/main.go @@ -0,0 +1,101 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This sample demonstrates DefineSessionFlowFromPrompt, which creates a +// multi-turn conversational session flow backed by a .prompt file. The +// conversation loop (render prompt, call model, stream chunks, update history) +// is handled automatically. Compare with basic-session-flow which wires +// the same loop manually. +package main + +import ( + "bufio" + "context" + "fmt" + "os" + "strings" + + aix "github.com/firebase/genkit/go/ai/x" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/googlegenai" +) + +type ChatPromptInput struct { + Personality string `json:"personality"` +} + +func main() { + ctx := context.Background() + g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) + + chatPrompt := genkit.LookupDataPrompt[ChatPromptInput, string](g, "chat") + + chatFlow := genkit.DefineSessionFlowFromPrompt[struct{}]( + g, "chat", chatPrompt, ChatPromptInput{Personality: "a sarcastic pirate"}, + aix.WithSnapshotStore(aix.NewInMemorySnapshotStore[struct{}]()), + aix.WithSnapshotCallback(func(ctx context.Context, sc *aix.SnapshotContext[struct{}]) bool { + return sc.Event == aix.SnapshotEventInvocationEnd || sc.TurnIndex%5 == 0 + }), + ) + + fmt.Println("Prompt Session Flow Chat (type 'quit' to exit)") + fmt.Println() + + conn, err := chatFlow.StreamBidi(ctx) + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + + reader := bufio.NewReader(os.Stdin) + for { + fmt.Print("> ") + input, _ := reader.ReadString('\n') + input = strings.TrimSpace(input) + + if input == "quit" || input == "exit" { + break + } + if input == "" { + continue + } + + if err := conn.SendText(input); err != nil { + fmt.Fprintf(os.Stderr, "Send error: %v\n", err) + break + } + + fmt.Println() + + for chunk, err := range conn.Receive() { + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + break + } + if chunk.Chunk != nil { + fmt.Print(chunk.Chunk.Text()) + } + if chunk.SnapshotCreated != "" { + fmt.Printf("\n[snapshot: %s]", chunk.SnapshotCreated) + } + if chunk.EndTurn { + fmt.Println() + fmt.Println() + break + } + } + } + + conn.Close() +} diff --git a/go/samples/prompt-session-flow/prompts/chat.prompt b/go/samples/prompt-session-flow/prompts/chat.prompt new file mode 100644 index 0000000000..6a78a99b07 --- /dev/null +++ b/go/samples/prompt-session-flow/prompts/chat.prompt @@ -0,0 +1,12 @@ +--- +model: googleai/gemini-3-flash-preview +config: + thinkingConfig: + thinkingBudget: 0 +input: + schema: + personality: string + default: + personality: a helpful assistant +--- +You are {{personality}}. Keep responses concise.