Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
9d1d995
Refactored formatters + added support for formatting streams.
apascal07 Dec 7, 2025
f445059
Added V2 of formatters that have legacy behavior.
apascal07 Dec 7, 2025
6dc77f1
Added back backward compatibility for JSONL.
apascal07 Dec 7, 2025
f5d9f5c
Removed v2 formats.
apascal07 Dec 12, 2025
3f1e7aa
Merge branch 'main' into ap/go-structured-streaming
apascal07 Dec 12, 2025
af4ecf6
Update gen.go
apascal07 Dec 12, 2025
39f9f74
Update extract_test.go
apascal07 Dec 12, 2025
c48d79c
Update formatter_test.go
apascal07 Dec 12, 2025
72425be
Update formatter_test.go
apascal07 Dec 15, 2025
0e5b472
Update gemini_test.go
apascal07 Dec 15, 2025
31f46e7
Rewrote formatter tests.
apascal07 Dec 15, 2025
327540a
Added iterator streaming functions and typed prompts.
apascal07 Dec 11, 2025
1d57a9f
Update generate.go
apascal07 Dec 13, 2025
b490aa8
Iterated on the new APIs.
apascal07 Dec 18, 2025
b9e90ea
Changed `Prompt` to `prompt` in `DataPrompt`.
apascal07 Dec 18, 2025
d75d708
Update prompt.go
apascal07 Dec 18, 2025
bd36083
Merge branch 'main' into ap/go-streaming
apascal07 Dec 18, 2025
e8957fe
Added some samples.
apascal07 Dec 18, 2025
334803e
Update main.go
apascal07 Dec 19, 2025
dadfcea
Merge branch 'main' into ap/go-streaming
apascal07 Dec 19, 2025
0700475
Update formatter_test.go
apascal07 Dec 19, 2025
befc705
Delete basic-prompts
apascal07 Dec 19, 2025
dfc9a65
Added unit tests.
apascal07 Dec 19, 2025
18ec2db
Further improvements.
apascal07 Dec 20, 2025
0c12ed0
Updated samples.
apascal07 Dec 23, 2025
faa9171
Update main.go
apascal07 Dec 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions go/ai/formatter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -661,17 +661,14 @@ func TestResolveFormat(t *testing.T) {
}
})

t.Run("defaults to text even when schema present but no format", func(t *testing.T) {
t.Run("defaults to json when schema present but no format", func(t *testing.T) {
schema := map[string]any{"type": "object"}
formatter, err := resolveFormat(r, schema, "")
if err != nil {
t.Fatalf("resolveFormat() error = %v", err)
}
// Note: The current implementation defaults to text when format is empty,
// even if schema is present. The schema/format combination is typically
// handled at a higher level (e.g., in Generate options).
if formatter.Name() != OutputFormatText {
t.Errorf("resolveFormat() = %q, want %q", formatter.Name(), OutputFormatText)
if formatter.Name() != OutputFormatJSON {
t.Errorf("resolveFormat() = %q, want %q", formatter.Name(), OutputFormatJSON)
}
})

Expand Down
147 changes: 135 additions & 12 deletions go/ai/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"encoding/json"
"errors"
"fmt"
"iter"
"slices"
"strings"

Expand Down Expand Up @@ -550,7 +551,7 @@ func GenerateText(ctx context.Context, r api.Registry, opts ...GenerateOption) (
return res.Text(), nil
}

// Generate run generate request for this model. Returns ModelResponse struct.
// GenerateData runs a generate request and returns strongly-typed output.
func GenerateData[Out any](ctx context.Context, r api.Registry, opts ...GenerateOption) (*Out, *ModelResponse, error) {
var value Out
opts = append(opts, WithOutputType(value))
Expand All @@ -568,6 +569,104 @@ func GenerateData[Out any](ctx context.Context, r api.Registry, opts ...Generate
return &value, resp, nil
}

// StreamValue is either a streamed chunk or the final response of a generate request.
type StreamValue[Out, Stream any] struct {
Done bool
Chunk Stream // valid if Done is false
Output Out // valid if Done is true
Response *ModelResponse // valid if Done is true
}

// ModelStreamValue is a stream value for a model response.
// Out is never set because the output is already available in the Response field.
type ModelStreamValue = StreamValue[struct{}, *ModelResponseChunk]

// errGenerateStop is a sentinel error used to signal early termination of streaming.
var errGenerateStop = errors.New("stop")

// GenerateStream generates a model response and streams the output.
// It returns an iterator that yields streaming results.
//
// If the yield function is passed a non-nil error, generation has failed with that
// error; the yield function will not be called again.
//
// If the yield function's [ModelStreamValue] argument has Done == true, the value's
// Response field contains the final response; the yield function will not be called
// again.
//
// Otherwise the Chunk field of the passed [ModelStreamValue] holds a streamed chunk.
func GenerateStream(ctx context.Context, r api.Registry, opts ...GenerateOption) iter.Seq2[*ModelStreamValue, error] {
return func(yield func(*ModelStreamValue, error) bool) {
cb := func(ctx context.Context, chunk *ModelResponseChunk) error {
if ctx.Err() != nil {
return ctx.Err()
}
if !yield(&ModelStreamValue{Chunk: chunk}, nil) {
return errGenerateStop
}
return nil
}

allOpts := append(slices.Clone(opts), WithStreaming(cb))

resp, err := Generate(ctx, r, allOpts...)
if err != nil {
yield(nil, err)
} else {
yield(&ModelStreamValue{Done: true, Response: resp}, nil)
}
}
}

// GenerateDataStream generates a model response with streaming and returns strongly-typed output.
// It returns an iterator that yields streaming results.
//
// If the yield function is passed a non-nil error, generation has failed with that
// error; the yield function will not be called again.
//
// If the yield function's [StreamValue] argument has Done == true, the value's
// Output and Response fields contain the final typed output and response; the yield function
// will not be called again.
//
// Otherwise the Chunk field of the passed [StreamValue] holds a streamed chunk.
func GenerateDataStream[Out any](ctx context.Context, r api.Registry, opts ...GenerateOption) iter.Seq2[*StreamValue[Out, Out], error] {
return func(yield func(*StreamValue[Out, Out], error) bool) {
cb := func(ctx context.Context, chunk *ModelResponseChunk) error {
if ctx.Err() != nil {
return ctx.Err()
}
var streamValue Out
if err := chunk.Output(&streamValue); err != nil {
yield(nil, err)
return err
}
if !yield(&StreamValue[Out, Out]{Chunk: streamValue}, nil) {
return errGenerateStop
}
return nil
}

// Prepend WithOutputType so the user can override the output format.
var value Out
allOpts := append([]GenerateOption{WithOutputType(value)}, opts...)
allOpts = append(allOpts, WithStreaming(cb))

resp, err := Generate(ctx, r, allOpts...)
if err != nil {
yield(nil, err)
return
}

output, err := extractTypedOutput[Out](resp)
if err != nil {
yield(nil, err)
return
}

yield(&StreamValue[Out, Out]{Done: true, Output: output, Response: resp}, nil)
}
}

// Generate applies the [Action] to provided request.
func (m *model) Generate(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) {
if m == nil {
Expand Down Expand Up @@ -744,7 +843,7 @@ func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest,
// [ModelResponse] as a string. It returns an empty string if there
// are no candidates or if the candidate has no message.
func (mr *ModelResponse) Text() string {
if mr.Message == nil {
if mr == nil || mr.Message == nil {
return ""
}
return mr.Message.Text()
Expand All @@ -753,7 +852,7 @@ func (mr *ModelResponse) Text() string {
// History returns messages from the request combined with the response message
// to represent the conversation history.
func (mr *ModelResponse) History() []*Message {
if mr.Message == nil {
if mr == nil || mr.Message == nil {
return mr.Request.Messages
}
return append(mr.Request.Messages, mr.Message)
Expand All @@ -762,7 +861,7 @@ func (mr *ModelResponse) History() []*Message {
// Reasoning concatenates all reasoning parts present in the message
func (mr *ModelResponse) Reasoning() string {
var sb strings.Builder
if mr.Message == nil {
if mr == nil || mr.Message == nil {
return ""
}

Expand Down Expand Up @@ -806,7 +905,7 @@ func (mr *ModelResponse) Output(v any) error {
// ToolRequests returns the tool requests from the response.
func (mr *ModelResponse) ToolRequests() []*ToolRequest {
toolReqs := []*ToolRequest{}
if mr.Message == nil {
if mr == nil || mr.Message == nil {
return toolReqs
}
for _, part := range mr.Message.Content {
Expand All @@ -820,7 +919,7 @@ func (mr *ModelResponse) ToolRequests() []*ToolRequest {
// Interrupts returns the interrupted tool request parts from the response.
func (mr *ModelResponse) Interrupts() []*Part {
parts := []*Part{}
if mr.Message == nil {
if mr == nil || mr.Message == nil {
return parts
}
for _, part := range mr.Message.Content {
Expand All @@ -833,7 +932,7 @@ func (mr *ModelResponse) Interrupts() []*Part {

// Media returns the media content of the [ModelResponse] as a string.
func (mr *ModelResponse) Media() string {
if mr.Message == nil {
if mr == nil || mr.Message == nil {
return ""
}
for _, part := range mr.Message.Content {
Expand Down Expand Up @@ -902,17 +1001,41 @@ func (c *ModelResponseChunk) Output(v any) error {

// outputer is an interface for types that can unmarshal structured output.
type outputer interface {
Output(v any) error
// Text returns the contents of the output as a string.
Text() string
// Output parses the structured output from the response and unmarshals it into value.
Output(value any) error
}

// OutputFrom is a convenience function that parses structured output from a
// [ModelResponse] or [ModelResponseChunk] and returns it as a typed value.
// This is equivalent to calling Output() but returns the value directly instead
// of requiring a pointer argument. If you need to handle the error, use Output() instead.
func OutputFrom[T any](src outputer) T {
var v T
src.Output(&v)
return v
func OutputFrom[Out any](src outputer) Out {
output, err := extractTypedOutput[Out](src)
if err != nil {
return base.Zero[Out]()
}
return output
}

// extractTypedOutput extracts the typed output from a model response.
// It supports string output by calling Text() and returning the result.
func extractTypedOutput[Out any](o outputer) (Out, error) {
var output Out

switch any(output).(type) {
case string:
text := o.Text()
// Type assertion to convert string to Out (which we know is string).
result := any(text).(Out)
return result, nil
default:
if err := o.Output(&output); err != nil {
return base.Zero[Out](), fmt.Errorf("failed to parse output: %w", err)
}
return output, nil
}
}

// Text returns the contents of a [Message] as a string. It
Expand Down
Loading
Loading