diff --git a/go/ai/generate.go b/go/ai/generate.go index 3c0966cfa3..83c7f34179 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -284,7 +284,7 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi // Native constrained output is enabled only when the user has // requested it, the model supports it, and there's a JSON schema. outputCfg.Constrained = opts.Output.JsonSchema != nil && - opts.Output.Constrained && outputCfg.Constrained && m.(*model).supportsConstrained(len(toolDefs) > 0) + opts.Output.Constrained && outputCfg.Constrained && m != nil && m.(*model).supportsConstrained(len(toolDefs) > 0) // Add schema instructions to prompt when not using native constraints. // This is a no-op for unstructured output requests. @@ -313,12 +313,14 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi Output: &outputCfg, } - fn := m.Generate + var fn ModelFunc if bm != nil { if cb != nil { logger.FromContext(ctx).Warn("background model does not support streaming", "model", bm.Name()) } fn = backgroundModelToModelFn(bm.Start) + } else { + fn = m.Generate } fn = core.ChainMiddleware(mw...)(fn) diff --git a/go/plugins/googlegenai/actions.go b/go/plugins/googlegenai/actions.go new file mode 100644 index 0000000000..50c9c79218 --- /dev/null +++ b/go/plugins/googlegenai/actions.go @@ -0,0 +1,151 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 + +package googlegenai + +import ( + "context" + "fmt" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/core/api" + "google.golang.org/genai" +) + +// ListActions lists all the actions supported by the Google AI plugin. +func (ga *GoogleAI) ListActions(ctx context.Context) []api.ActionDesc { + return listActions(ctx, ga.gclient, googleAIProvider) +} + +// ListActions lists all the actions supported by the Vertex AI plugin. 
+func (v *VertexAI) ListActions(ctx context.Context) []api.ActionDesc {
+	return listActions(ctx, v.gclient, vertexAIProvider)
+}
+
+// listActions collects the action descriptors of every model and embedder
+// that listGenaiModels reports for client, built with the options of the
+// given provider. It returns nil when the model listing fails; the error is
+// not surfaced because the result type carries no error.
+func listActions(ctx context.Context, client *genai.Client, provider string) []api.ActionDesc {
+	available, err := listGenaiModels(ctx, client)
+	if err != nil {
+		return nil
+	}
+
+	descs := make([]api.ActionDesc, 0)
+
+	// record appends the descriptor of v when v also implements api.Action;
+	// any other value is skipped.
+	record := func(v any) {
+		if action, ok := v.(api.Action); ok {
+			descs = append(descs, action.Desc())
+		}
+	}
+
+	// Gemini and Imagen models are both built through newModel.
+	for _, id := range available.gemini {
+		record(newModel(client, id, GetModelOptions(id, provider)))
+	}
+	for _, id := range available.imagen {
+		record(newModel(client, id, GetModelOptions(id, provider)))
+	}
+
+	// Veo models (background models) have their own constructor.
+	for _, id := range available.veo {
+		record(newVeoModel(client, id, GetModelOptions(id, provider)))
+	}
+
+	// Embedders take their options by pointer, hence the named variable.
+	for _, id := range available.embedders {
+		embedOpts := GetEmbedderOptions(id, provider)
+		record(newEmbedder(client, id, &embedOpts))
+	}
+
+	return descs
+}
+
+// ResolveAction resolves the Google AI action of the given type and name by
+// delegating to resolveAction with the plugin's client.
+func (ga *GoogleAI) ResolveAction(atype api.ActionType, name string) api.Action {
+	return resolveAction(ga.gclient, googleAIProvider, atype, name)
+}
+
+// ResolveAction resolves the Vertex AI action of the given type and name by
+// delegating to resolveAction with the plugin's client.
+func (v *VertexAI) ResolveAction(atype api.ActionType, name string) api.Action {
+	return resolveAction(v.gclient, vertexAIProvider, atype, name)
+}
+
+// resolveAction is the shared implementation for resolving actions.
+func resolveAction(client *genai.Client, provider string, atype api.ActionType, name string) api.Action {
+	// It returns nil whenever the (type, name) pair cannot be served: an
+	// unsupported action type, a Veo name requested as a regular model, a
+	// non-Veo name requested as a background model or check operation, or a
+	// constructed value that does not implement api.Action.
+	mt := ClassifyModel(name)
+
+	switch atype {
+	case api.ActionTypeEmbedder:
+		opts := GetEmbedderOptions(name, provider)
+		// Checked assertion, as in listActions: a mismatch yields nil
+		// instead of a panic.
+		if action, ok := newEmbedder(client, name, &opts).(api.Action); ok {
+			return action
+		}
+		return nil
+
+	case api.ActionTypeModel:
+		// Veo models should not be resolved as regular models.
+		if mt == ModelTypeVeo {
+			return nil
+		}
+		opts := GetModelOptions(name, provider)
+		// The config schema always comes from the model type's default
+		// config, replacing whatever GetModelOptions supplied.
+		opts.ConfigSchema = configToMap(mt.DefaultConfig())
+		if action, ok := newModel(client, name, opts).(api.Action); ok {
+			return action
+		}
+		return nil
+
+	case api.ActionTypeBackgroundModel:
+		if mt != ModelTypeVeo {
+			return nil
+		}
+		return createVeoBackgroundAction(client, name, provider)
+
+	case api.ActionTypeCheckOperation:
+		if mt != ModelTypeVeo {
+			return nil
+		}
+		return createVeoCheckAction(client, name, provider)
+	}
+
+	return nil
+}
+
+// createVeoBackgroundAction creates a background model action for Veo.
+//
+// The action is named "<provider>/<name>". Running it starts the Veo model
+// with the given request and stamps the returned operation with the
+// background-model key of this action.
+func createVeoBackgroundAction(client *genai.Client, name, provider string) api.Action {
+	opts := GetModelOptions(name, provider)
+	veoModel := newVeoModel(client, name, opts)
+	actionName := fmt.Sprintf("%s/%s", provider, name)
+
+	return core.NewAction(actionName, api.ActionTypeBackgroundModel, nil, nil,
+		func(ctx context.Context, input *ai.ModelRequest) (*core.Operation[*ai.ModelResponse], error) {
+			op, err := veoModel.Start(ctx, input)
+			if err != nil {
+				return nil, err
+			}
+			op.Action = api.KeyFromName(api.ActionTypeBackgroundModel, actionName)
+			return op, nil
+		})
+}
+
+// createVeoCheckAction creates a check operation action for Veo.
+func createVeoCheckAction(client *genai.Client, name, provider string) api.Action {
+	veo := newVeoModel(client, name, GetModelOptions(name, provider))
+	fullName := fmt.Sprintf("%s/%s", provider, name)
+	metadata := map[string]any{
+		"description": fmt.Sprintf("Check status of %s operation", name),
+	}
+
+	// check asks the Veo model for the current state of op.
+	check := func(ctx context.Context, op *core.Operation[*ai.ModelResponse]) (*core.Operation[*ai.ModelResponse], error) {
+		refreshed, err := veo.Check(ctx, op)
+		if err != nil {
+			return nil, err
+		}
+		// The refreshed operation is tagged with the background-model key,
+		// not the check-operation key (presumably so it stays tied to the
+		// model that started it).
+		refreshed.Action = api.KeyFromName(api.ActionTypeBackgroundModel, fullName)
+		return refreshed, nil
+	}
+
+	return core.NewAction(fullName, api.ActionTypeCheckOperation, metadata, nil, check)
+}
diff --git a/go/plugins/googlegenai/code_execution.go b/go/plugins/googlegenai/code_execution.go
new file mode 100644
index 0000000000..71251fedb3
--- /dev/null
+++ b/go/plugins/googlegenai/code_execution.go
@@ -0,0 +1,121 @@
+// Copyright 2025 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+
+package googlegenai
+
+import (
+	"github.com/firebase/genkit/go/ai"
+)
+
+// CodeExecutionResult holds the outcome and the output of one code
+// execution, as carried in a custom part under "codeExecutionResult".
+type CodeExecutionResult struct {
+	Outcome string `json:"outcome"`
+	Output  string `json:"output"`
+}
+
+// ExecutableCode holds a piece of code and its language, as carried in a
+// custom part under "executableCode".
+type ExecutableCode struct {
+	Language string `json:"language"`
+	Code     string `json:"code"`
+}
+
+// newCodeExecutionResultPart wraps outcome and output in a custom Part under
+// the "codeExecutionResult" key.
+// This is internal and used by translateCandidate.
+func newCodeExecutionResultPart(outcome string, output string) *ai.Part {
+	payload := map[string]any{
+		"outcome": outcome,
+		"output":  output,
+	}
+	return ai.NewCustomPart(map[string]any{"codeExecutionResult": payload})
+}
+
+// newExecutableCodePart returns a Part containing executable code.
+// This is internal and used by translateCandidate.
+func newExecutableCodePart(language string, code string) *ai.Part {
+	return ai.NewCustomPart(map[string]any{
+		"executableCode": map[string]any{
+			"language": language,
+			"code":     code,
+		},
+	})
+}
+
+// customPartField returns the nested map stored under key in a custom part.
+// The boolean is false when part is nil, is not a custom part, has no entry
+// for key, or the entry is not a map[string]any.
+func customPartField(part *ai.Part, key string) (map[string]any, bool) {
+	if part == nil || !part.IsCustom() {
+		return nil, false
+	}
+	// Indexing a missing key yields a nil value, which fails the assertion.
+	fields, ok := part.Custom[key].(map[string]any)
+	return fields, ok
+}
+
+// ToCodeExecutionResult tries to convert an ai.Part to a CodeExecutionResult.
+// Returns nil if the part is nil or doesn't contain code execution results.
+// Missing or non-string "outcome"/"output" entries are returned as "".
+func ToCodeExecutionResult(part *ai.Part) *CodeExecutionResult {
+	fields, ok := customPartField(part, "codeExecutionResult")
+	if !ok {
+		return nil
+	}
+
+	// A failed assertion deliberately leaves the zero value "".
+	outcome, _ := fields["outcome"].(string)
+	output, _ := fields["output"].(string)
+
+	return &CodeExecutionResult{
+		Outcome: outcome,
+		Output:  output,
+	}
+}
+
+// ToExecutableCode tries to convert an ai.Part to an ExecutableCode.
+// Returns nil if the part is nil or doesn't contain executable code.
+// Missing or non-string "language"/"code" entries are returned as "".
+func ToExecutableCode(part *ai.Part) *ExecutableCode {
+	fields, ok := customPartField(part, "executableCode")
+	if !ok {
+		return nil
+	}
+
+	// A failed assertion deliberately leaves the zero value "".
+	language, _ := fields["language"].(string)
+	codeStr, _ := fields["code"].(string)
+
+	return &ExecutableCode{
+		Language: language,
+		Code:     codeStr,
+	}
+}
+
+// HasCodeExecution checks if a message contains code execution results or executable code.
+// A nil message contains neither.
+func HasCodeExecution(msg *ai.Message) bool {
+	if msg == nil {
+		return false
+	}
+	return GetCodeExecutionResult(msg) != nil || GetExecutableCode(msg) != nil
+}
+
+// GetExecutableCode returns the first executable code from a message.
+// Returns nil if the message is nil or doesn't contain executable code.
+func GetExecutableCode(msg *ai.Message) *ExecutableCode {
+	if msg == nil {
+		return nil
+	}
+	for _, part := range msg.Content {
+		if code := ToExecutableCode(part); code != nil {
+			return code
+		}
+	}
+	return nil
+}
+
+// GetCodeExecutionResult returns the first code execution result from a message.
+// Returns nil if msg is nil or doesn't contain a code execution result.
+func GetCodeExecutionResult(msg *ai.Message) *CodeExecutionResult {
+	if msg == nil {
+		return nil
+	}
+	for _, part := range msg.Content {
+		if result := ToCodeExecutionResult(part); result != nil {
+			return result
+		}
+	}
+	return nil
+}
diff --git a/go/plugins/googlegenai/embedder.go b/go/plugins/googlegenai/embedder.go
new file mode 100644
index 0000000000..759d5a7332
--- /dev/null
+++ b/go/plugins/googlegenai/embedder.go
@@ -0,0 +1,54 @@
+// Copyright 2025 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+
+package googlegenai
+
+import (
+	"context"
+
+	"github.com/firebase/genkit/go/ai"
+	"github.com/firebase/genkit/go/core"
+	"github.com/firebase/genkit/go/core/api"
+	"google.golang.org/genai"
+)
+
+// newEmbedder creates an embedder without registering it.
+//
+// The provider prefix of the embedder name is chosen from the client's
+// backend (Vertex AI or Google AI). A nil embedOpts is treated as empty
+// options. When embedOpts carries no config schema, one inferred from
+// genai.EmbedContentConfig is stored into it.
+func newEmbedder(client *genai.Client, name string, embedOpts *ai.EmbedderOptions) ai.Embedder {
+	provider := googleAIProvider
+	if client.ClientConfig().Backend == genai.BackendVertexAI {
+		provider = vertexAIProvider
+	}
+
+	// Guard against a nil options pointer before it is dereferenced below.
+	if embedOpts == nil {
+		embedOpts = &ai.EmbedderOptions{}
+	}
+	if embedOpts.ConfigSchema == nil {
+		embedOpts.ConfigSchema = core.InferSchemaMap(genai.EmbedContentConfig{})
+	}
+
+	return ai.NewEmbedder(api.NewName(provider, name), embedOpts, func(ctx context.Context, req *ai.EmbedRequest) (*ai.EmbedResponse, error) {
+		// Options of any other type are ignored and leave the config nil.
+		embedConfig, _ := req.Options.(*genai.EmbedContentConfig)
+
+		// One genai.Content per input document.
+		var content []*genai.Content
+		for _, doc := range req.Input {
+			parts, err := toGeminiParts(doc.Content)
+			if err != nil {
+				return nil, err
+			}
+			content = append(content, &genai.Content{
+				Parts: parts,
+			})
+		}
+
+		r, err := client.Models.EmbedContent(ctx, name, content, embedConfig)
+		if err != nil {
+			return nil, err
+		}
+		var res ai.EmbedResponse
+		for _, emb := range r.Embeddings {
+			res.Embeddings = append(res.Embeddings, &ai.Embedding{Embedding: emb.Values})
+		}
+		return &res, nil
+	})
+}
diff --git
a/go/plugins/googlegenai/gemini.go b/go/plugins/googlegenai/gemini.go index bfd9ef5410..5ad7bd7cee 100644 --- a/go/plugins/googlegenai/gemini.go +++ b/go/plugins/googlegenai/gemini.go @@ -19,15 +19,11 @@ package googlegenai import ( "context" "encoding/base64" - "encoding/json" "errors" "fmt" "net/http" "net/url" - "reflect" - "regexp" "slices" - "strconv" "strings" "github.com/firebase/genkit/go/ai" @@ -40,31 +36,7 @@ import ( "google.golang.org/genai" ) -const ( - // Tool name regex - toolNameRegex = "^[a-zA-Z_][a-zA-Z0-9_.-]{0,63}$" -) - var ( - // BasicText describes model capabilities for text-only Gemini models. - BasicText = ai.ModelSupports{ - Multiturn: true, - Tools: true, - ToolChoice: true, - SystemRole: true, - Media: false, - } - - // Multimodal describes model capabilities for multimodal Gemini models. - Multimodal = ai.ModelSupports{ - Multiturn: true, - Tools: true, - ToolChoice: true, - SystemRole: true, - Media: true, - Constrained: ai.ConstrainedSupportNoTools, - } - // Attribution header xGoogApiClientHeader = http.CanonicalHeaderKey("x-goog-api-client") genkitClientHeader = http.Header{ @@ -72,16 +44,6 @@ var ( } ) -// EmbedOptions are options for the Vertex AI embedder. -// Set [ai.EmbedRequest.Options] to a value of type *[EmbedOptions]. -type EmbedOptions struct { - // Document title. - Title string `json:"title,omitempty"` - // Task type: RETRIEVAL_QUERY, RETRIEVAL_DOCUMENT, and so forth. - // See the Vertex AI text embedding docs. - TaskType string `json:"task_type,omitempty"` -} - // configToMap converts a config struct to a map[string]any. func configToMap(config any) map[string]any { r := jsonschema.Reflector{ @@ -124,21 +86,16 @@ func configFromRequest(input *ai.ModelRequest) (*genai.GenerateContentConfig, er return &result, nil } -// newModel creates a model without registering it +// newModel creates a model without registering it. 
func newModel(client *genai.Client, name string, opts ai.ModelOptions) ai.Model { provider := googleAIProvider if client.ClientConfig().Backend == genai.BackendVertexAI { provider = vertexAIProvider } - var config any - config = &genai.GenerateContentConfig{} - if strings.Contains(name, "imagen") { - config = &genai.GenerateImagesConfig{} - } else if vi, fnd := supportedVideoModels[name]; fnd { - config = &genai.GenerateVideosConfig{} - opts = vi - } + mt := ClassifyModel(name) + config := mt.DefaultConfig() + meta := &ai.ModelOptions{ Label: opts.Label, Supports: opts.Supports, @@ -152,13 +109,14 @@ func newModel(client *genai.Client, name string, opts ai.ModelOptions) ai.Model input *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error, ) (*ai.ModelResponse, error) { - switch config.(type) { - case *genai.GenerateImagesConfig: + switch mt { + case ModelTypeImagen: return generateImage(ctx, client, name, input, cb) default: return generate(ctx, client, name, input, cb) } } + // the gemini api doesn't support downloading media from http(s) if opts.Supports.Media { fn = core.ChainMiddleware(ai.DownloadRequestMedia(&ai.DownloadMediaOptions{ @@ -182,49 +140,8 @@ func newModel(client *genai.Client, name string, opts ai.ModelOptions) ai.Model return ai.NewModel(api.NewName(provider, name), meta, fn) } -// newEmbedder creates an embedder without registering it -func newEmbedder(client *genai.Client, name string, embedOpts *ai.EmbedderOptions) ai.Embedder { - provider := googleAIProvider - if client.ClientConfig().Backend == genai.BackendVertexAI { - provider = vertexAIProvider - } - - if embedOpts.ConfigSchema == nil { - embedOpts.ConfigSchema = core.InferSchemaMap(genai.EmbedContentConfig{}) - } - - return ai.NewEmbedder(api.NewName(provider, name), embedOpts, func(ctx context.Context, req *ai.EmbedRequest) (*ai.EmbedResponse, error) { - var content []*genai.Content - var embedConfig *genai.EmbedContentConfig - - if config, ok := 
req.Options.(*genai.EmbedContentConfig); ok { - embedConfig = config - } - - for _, doc := range req.Input { - parts, err := toGeminiParts(doc.Content) - if err != nil { - return nil, err - } - content = append(content, &genai.Content{ - Parts: parts, - }) - } - - r, err := genai.Models.EmbedContent(*client.Models, ctx, name, content, embedConfig) - if err != nil { - return nil, err - } - var res ai.EmbedResponse - for _, emb := range r.Embeddings { - res.Embeddings = append(res.Embeddings, &ai.Embedding{Embedding: emb.Values}) - } - return &res, nil - }) -} - -// Generate requests generate call to the specified model with the provided -// configuration +// generate requests generate call to the specified model with the provided +// configuration. func generate( ctx context.Context, client *genai.Client, @@ -447,344 +364,6 @@ func toGeminiRequest(input *ai.ModelRequest, cache *genai.CachedContent) (*genai return gcc, nil } -// toGeminiTools translates a slice of [ai.ToolDefinition] to a slice of [genai.Tool]. 
-func toGeminiTools(inTools []*ai.ToolDefinition) ([]*genai.Tool, error) { - var outTools []*genai.Tool - functions := []*genai.FunctionDeclaration{} - - for _, t := range inTools { - if !validToolName(t.Name) { - return nil, fmt.Errorf(`invalid tool name: %q, must start with a letter or an underscore, must be alphanumeric, underscores, dots or dashes with a max length of 64 chars`, t.Name) - } - inputSchema, err := toGeminiSchema(t.InputSchema, t.InputSchema) - if err != nil { - return nil, err - } - fd := &genai.FunctionDeclaration{ - Name: t.Name, - Parameters: inputSchema, - Description: t.Description, - } - functions = append(functions, fd) - } - - if len(functions) > 0 { - outTools = append(outTools, &genai.Tool{ - FunctionDeclarations: functions, - }) - } - - return outTools, nil -} - -// toGeminiFunctionResponsePart translates a slice of [ai.Part] to a slice of [genai.FunctionResponsePart] -func toGeminiFunctionResponsePart(parts []*ai.Part) ([]*genai.FunctionResponsePart, error) { - frp := []*genai.FunctionResponsePart{} - for _, p := range parts { - switch { - case p.IsData(): - contentType, data, err := uri.Data(p) - if err != nil { - return nil, err - } - frp = append(frp, genai.NewFunctionResponsePartFromBytes(data, contentType)) - case p.IsMedia(): - if strings.HasPrefix(p.Text, "data:") { - contentType, data, err := uri.Data(p) - if err != nil { - return nil, err - } - frp = append(frp, genai.NewFunctionResponsePartFromBytes(data, contentType)) - continue - } - frp = append(frp, genai.NewFunctionResponsePartFromURI(p.Text, p.ContentType)) - default: - return nil, fmt.Errorf("unsupported function response part type: %d", p.Kind) - } - } - return frp, nil -} - -// mergeTools consolidates all FunctionDeclarations into a single Tool -// while preserving non-function tools (Retrieval, GoogleSearch, CodeExecution, etc.) 
-func mergeTools(ts []*genai.Tool) []*genai.Tool { - var decls []*genai.FunctionDeclaration - var out []*genai.Tool - - for _, t := range ts { - if t == nil { - continue - } - if len(t.FunctionDeclarations) == 0 { - out = append(out, t) - continue - } - decls = append(decls, t.FunctionDeclarations...) - if cpy := cloneToolWithoutFunctions(t); cpy != nil && !reflect.ValueOf(*cpy).IsZero() { - out = append(out, cpy) - } - } - - if len(decls) > 0 { - out = append([]*genai.Tool{{FunctionDeclarations: decls}}, out...) - } - return out -} - -func cloneToolWithoutFunctions(t *genai.Tool) *genai.Tool { - if t == nil { - return nil - } - clone := *t - clone.FunctionDeclarations = nil - return &clone -} - -// toGeminiSchema translates a map representing a standard JSON schema to a more -// limited [genai.Schema]. -func toGeminiSchema(originalSchema map[string]any, genkitSchema map[string]any) (*genai.Schema, error) { - // this covers genkitSchema == nil and {} - // genkitSchema will be {} if it's any - if len(genkitSchema) == 0 { - return nil, nil - } - if v, ok := genkitSchema["$ref"]; ok { - ref, ok := v.(string) - if !ok { - return nil, fmt.Errorf("invalid $ref value: not a string") - } - s, err := resolveRef(originalSchema, ref) - if err != nil { - return nil, err - } - return toGeminiSchema(originalSchema, s) - } - - // Handle "anyOf" subschemas by finding the first valid schema definition - if v, ok := genkitSchema["anyOf"]; ok { - if anyOfList, isList := v.([]map[string]any); isList { - for _, subSchema := range anyOfList { - if subSchemaType, hasType := subSchema["type"]; hasType { - if typeStr, isString := subSchemaType.(string); isString && typeStr != "null" { - if title, ok := genkitSchema["title"]; ok { - subSchema["title"] = title - } - if description, ok := genkitSchema["description"]; ok { - subSchema["description"] = description - } - // Found a schema like: {"type": "string"} - return toGeminiSchema(originalSchema, subSchema) - } - } - } - } - } - - schema 
:= &genai.Schema{} - typeVal, ok := genkitSchema["type"] - if !ok { - return nil, fmt.Errorf("schema is missing the 'type' field: %#v", genkitSchema) - } - - typeStr, ok := typeVal.(string) - if !ok { - return nil, fmt.Errorf("schema 'type' field is not a string, but %T", typeVal) - } - - switch typeStr { - case "string": - schema.Type = genai.TypeString - case "float64", "number": - schema.Type = genai.TypeNumber - case "integer": - schema.Type = genai.TypeInteger - case "boolean": - schema.Type = genai.TypeBoolean - case "object": - schema.Type = genai.TypeObject - case "array": - schema.Type = genai.TypeArray - default: - return nil, fmt.Errorf("schema type %q not allowed", genkitSchema["type"]) - } - if v, ok := genkitSchema["required"]; ok { - schema.Required = castToStringArray(v) - } - if v, ok := genkitSchema["propertyOrdering"]; ok { - schema.PropertyOrdering = castToStringArray(v) - } - if v, ok := genkitSchema["description"]; ok { - schema.Description = v.(string) - } - if v, ok := genkitSchema["format"]; ok { - schema.Format = v.(string) - } - if v, ok := genkitSchema["title"]; ok { - schema.Title = v.(string) - } - if v, ok := genkitSchema["minItems"]; ok { - if i64, ok := castToInt64(v); ok { - schema.MinItems = genai.Ptr(i64) - } - } - if v, ok := genkitSchema["maxItems"]; ok { - if i64, ok := castToInt64(v); ok { - schema.MaxItems = genai.Ptr(i64) - } - } - if v, ok := genkitSchema["maximum"]; ok { - if f64, ok := castToFloat64(v); ok { - schema.Maximum = genai.Ptr(f64) - } - } - if v, ok := genkitSchema["minimum"]; ok { - if f64, ok := castToFloat64(v); ok { - schema.Minimum = genai.Ptr(f64) - } - } - if v, ok := genkitSchema["enum"]; ok { - schema.Enum = castToStringArray(v) - } - if v, ok := genkitSchema["items"]; ok { - items, err := toGeminiSchema(originalSchema, v.(map[string]any)) - if err != nil { - return nil, err - } - schema.Items = items - } - if val, ok := genkitSchema["properties"]; ok { - props := map[string]*genai.Schema{} - for k, v 
:= range val.(map[string]any) { - p, err := toGeminiSchema(originalSchema, v.(map[string]any)) - if err != nil { - return nil, err - } - props[k] = p - } - schema.Properties = props - } - // Nullable -- not supported in jsonschema.Schema - - return schema, nil -} - -func resolveRef(originalSchema map[string]any, ref string) (map[string]any, error) { - tkns := strings.Split(ref, "/") - // refs look like: $/ref/foo -- we need the foo part - name := tkns[len(tkns)-1] - if defs, ok := originalSchema["$defs"].(map[string]any); ok { - if def, ok := defs[name].(map[string]any); ok { - return def, nil - } - } - // definitions (legacy) - if defs, ok := originalSchema["definitions"].(map[string]any); ok { - if def, ok := defs[name].(map[string]any); ok { - return def, nil - } - } - return nil, fmt.Errorf("unable to resolve schema reference") -} - -// castToStringArray converts either []any or []string to []string, filtering non-strings. -// This handles enum values from JSON Schema which may come as either type depending on unmarshaling. -// Filter out non-string types from if v is []any type. -func castToStringArray(v any) []string { - switch a := v.(type) { - case []string: - // Return a shallow copy to avoid aliasing - out := make([]string, 0, len(a)) - for _, s := range a { - if s != "" { - out = append(out, s) - } - } - return out - case []any: - var out []string - for _, it := range a { - if s, ok := it.(string); ok && s != "" { - out = append(out, s) - } - } - return out - default: - return nil - } -} - -// castToInt64 converts v to int64 when possible. 
-func castToInt64(v any) (int64, bool) { - switch t := v.(type) { - case int: - return int64(t), true - case int64: - return t, true - case float64: - return int64(t), true - case string: - if i, err := strconv.ParseInt(t, 10, 64); err == nil { - return i, true - } - case json.Number: - if i, err := t.Int64(); err == nil { - return i, true - } - } - return 0, false -} - -// castToFloat64 converts v to float64 when possible. -func castToFloat64(v any) (float64, bool) { - switch t := v.(type) { - case float64: - return t, true - case int: - return float64(t), true - case int64: - return float64(t), true - case string: - if f, err := strconv.ParseFloat(t, 64); err == nil { - return f, true - } - case json.Number: - if f, err := t.Float64(); err == nil { - return f, true - } - } - return 0, false -} - -func toGeminiToolChoice(toolChoice ai.ToolChoice, tools []*ai.ToolDefinition) (*genai.ToolConfig, error) { - var mode genai.FunctionCallingConfigMode - switch toolChoice { - case "": - return nil, nil - case ai.ToolChoiceAuto: - mode = genai.FunctionCallingConfigModeAuto - case ai.ToolChoiceRequired: - mode = genai.FunctionCallingConfigModeAny - case ai.ToolChoiceNone: - mode = genai.FunctionCallingConfigModeNone - default: - return nil, fmt.Errorf("tool choice mode %q not supported", toolChoice) - } - - var toolNames []string - // Per docs, only set AllowedToolNames with mode set to ANY. - if mode == genai.FunctionCallingConfigModeAny { - for _, t := range tools { - toolNames = append(toolNames, t.Name) - } - } - return &genai.ToolConfig{ - FunctionCallingConfig: &genai.FunctionCallingConfig{ - Mode: mode, - AllowedFunctionNames: toolNames, - }, - }, nil -} - // translateCandidate translates from a genai.GenerateContentResponse to an ai.ModelResponse. 
func translateCandidate(cand *genai.Candidate) (*ai.ModelResponse, error) { m := &ai.ModelResponse{} @@ -865,14 +444,14 @@ func translateCandidate(cand *genai.Candidate) (*ai.ModelResponse, error) { } if part.CodeExecutionResult != nil { partFound++ - p = NewCodeExecutionResultPart( + p = newCodeExecutionResultPart( string(part.CodeExecutionResult.Outcome), part.CodeExecutionResult.Output, ) } if part.ExecutableCode != nil { partFound++ - p = NewExecutableCodePart( + p = newExecutableCodePart( string(part.ExecutableCode.Language), part.ExecutableCode.Code, ) @@ -897,7 +476,7 @@ func translateCandidate(cand *genai.Candidate) (*ai.ModelResponse, error) { return m, nil } -// Translate from a genai.GenerateContentResponse to a ai.ModelResponse. +// translateResponse translates from a genai.GenerateContentResponse to a ai.ModelResponse. func translateResponse(resp *genai.GenerateContentResponse) (*ai.ModelResponse, error) { var r *ai.ModelResponse var err error @@ -1027,125 +606,3 @@ func toGeminiPart(p *ai.Part) (*genai.Part, error) { return gp, nil } - -// validToolName checks whether the provided tool name matches the -// following criteria: -// - Start with a letter or an underscore -// - Must be alphanumeric and can include underscores, dots or dashes -// - Maximum length of 64 chars -func validToolName(n string) bool { - re := regexp.MustCompile(toolNameRegex) - - return re.MatchString(n) -} - -// CodeExecutionResult represents the result of a code execution. -type CodeExecutionResult struct { - Outcome string `json:"outcome"` - Output string `json:"output"` -} - -// ExecutableCode represents executable code. -type ExecutableCode struct { - Language string `json:"language"` - Code string `json:"code"` -} - -// NewCodeExecutionResultPart returns a Part containing the result of code execution. 
-func NewCodeExecutionResultPart(outcome string, output string) *ai.Part { - return ai.NewCustomPart(map[string]any{ - "codeExecutionResult": map[string]any{ - "outcome": outcome, - "output": output, - }, - }) -} - -// NewExecutableCodePart returns a Part containing executable code. -func NewExecutableCodePart(language string, code string) *ai.Part { - return ai.NewCustomPart(map[string]any{ - "executableCode": map[string]any{ - "language": language, - "code": code, - }, - }) -} - -// ToCodeExecutionResult tries to convert an ai.Part to a CodeExecutionResult. -// Returns nil if the part doesn't contain code execution results. -func ToCodeExecutionResult(part *ai.Part) *CodeExecutionResult { - if !part.IsCustom() { - return nil - } - - codeExec, ok := part.Custom["codeExecutionResult"] - if !ok { - return nil - } - - result, ok := codeExec.(map[string]any) - if !ok { - return nil - } - - outcome, _ := result["outcome"].(string) - output, _ := result["output"].(string) - - return &CodeExecutionResult{ - Outcome: outcome, - Output: output, - } -} - -// ToExecutableCode tries to convert an ai.Part to an ExecutableCode. -// Returns nil if the part doesn't contain executable code. -func ToExecutableCode(part *ai.Part) *ExecutableCode { - if !part.IsCustom() { - return nil - } - - execCode, ok := part.Custom["executableCode"] - if !ok { - return nil - } - - code, ok := execCode.(map[string]any) - if !ok { - return nil - } - - language, _ := code["language"].(string) - codeStr, _ := code["code"].(string) - - return &ExecutableCode{ - Language: language, - Code: codeStr, - } -} - -// HasCodeExecution checks if a message contains code execution results or executable code. -func HasCodeExecution(msg *ai.Message) bool { - return GetCodeExecutionResult(msg) != nil || GetExecutableCode(msg) != nil -} - -// GetExecutableCode returns the first executable code from a message. -// Returns nil if the message doesn't contain executable code. 
-func GetExecutableCode(msg *ai.Message) *ExecutableCode { - for _, part := range msg.Content { - if code := ToExecutableCode(part); code != nil { - return code - } - } - return nil -} - -// GetCodeExecutionResult returns the first code execution result from a message. -// Returns nil if the message doesn't contain a code execution result. -func GetCodeExecutionResult(msg *ai.Message) *CodeExecutionResult { - for _, part := range msg.Content { - if result := ToCodeExecutionResult(part); result != nil { - return result - } - } - return nil -} diff --git a/go/plugins/googlegenai/googleai_live_test.go b/go/plugins/googlegenai/googleai_live_test.go index 783eccd239..f5b2375bae 100644 --- a/go/plugins/googlegenai/googleai_live_test.go +++ b/go/plugins/googlegenai/googleai_live_test.go @@ -70,8 +70,6 @@ func TestGoogleAILive(t *testing.T) { genkit.WithPlugins(&googlegenai.GoogleAI{APIKey: apiKey}), ) - embedder := googlegenai.GoogleAIEmbedder(g, "embedding-001") - gablorkenTool := genkit.DefineTool(g, "gablorken", "use this tool when the user asks to calculate a gablorken, carefuly inspect the user input to determine which value from the prompt corresponds to the input structure", func(ctx *ai.ToolContext, input struct { Value int @@ -89,7 +87,7 @@ func TestGoogleAILive(t *testing.T) { ) t.Run("embedder", func(t *testing.T) { - res, err := genkit.Embed(ctx, g, ai.WithEmbedder(embedder), ai.WithTextDocs("yellow banana")) + res, err := genkit.Embed(ctx, g, ai.WithEmbedderName("googleai/gemini-embedding-001"), ai.WithTextDocs("yellow banana")) if err != nil { t.Fatal(err) } diff --git a/go/plugins/googlegenai/googlegenai.go b/go/plugins/googlegenai/googlegenai.go index 04f361f74b..1009f5d2ab 100644 --- a/go/plugins/googlegenai/googlegenai.go +++ b/go/plugins/googlegenai/googlegenai.go @@ -9,13 +9,11 @@ import ( "fmt" "net/http" "os" - "strings" "sync" "cloud.google.com/go/auth/credentials" "cloud.google.com/go/auth/httptransport" "github.com/firebase/genkit/go/ai" - 
"github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/core/api" "github.com/firebase/genkit/go/genkit" @@ -31,27 +29,6 @@ const ( vertexAILabelPrefix = "Vertex AI" ) -var ( - defaultGeminiOpts = ai.ModelOptions{ - Supports: &Multimodal, - Versions: []string{}, - Stage: ai.ModelStageUnstable, - } - - defaultImagenOpts = ai.ModelOptions{ - Supports: &Media, - Versions: []string{}, - Stage: ai.ModelStageUnstable, - } - - defaultEmbedOpts = ai.EmbedderOptions{ - Supports: &ai.EmbedderSupports{ - Input: []string{"text"}, - }, - Dimensions: 768, - } -) - // GoogleAI is a Genkit plugin for interacting with the Google AI service. type GoogleAI struct { APIKey string // API key to access the service. If empty, the values of the environment variables GEMINI_API_KEY or GOOGLE_API_KEY will be consulted, in that order. @@ -283,283 +260,34 @@ func (v *VertexAI) IsDefinedEmbedder(g *genkit.Genkit, name string) bool { return genkit.LookupEmbedder(g, api.NewName(vertexAIProvider, name)) != nil } -// ModelRef creates a new ModelRef for a Google Gen AI model with the given name and configuration. -func ModelRef(name string, config *genai.GenerateContentConfig) ai.ModelRef { - return ai.NewModelRef(name, config) -} - -// GoogleAIModelRef creates a new ModelRef for a Google AI model with the given ID and configuration. -func GoogleAIModelRef(id string, config *genai.GenerateContentConfig) ai.ModelRef { - return ai.NewModelRef(googleAIProvider+"/"+id, config) -} - -// VertexAIModelRef creates a new ModelRef for a Vertex AI model with the given ID and configuration. -func VertexAIModelRef(id string, config *genai.GenerateContentConfig) ai.ModelRef { - return ai.NewModelRef(vertexAIProvider+"/"+id, config) -} - // GoogleAIModel returns the [ai.Model] with the given name. // It returns nil if the model was not defined. +// +// Deprecated: Use genkit.LookupModel instead. 
func GoogleAIModel(g *genkit.Genkit, name string) ai.Model { return genkit.LookupModel(g, api.NewName(googleAIProvider, name)) } // VertexAIModel returns the [ai.Model] with the given name. // It returns nil if the model was not defined. +// +// Deprecated: Use genkit.LookupModel instead. func VertexAIModel(g *genkit.Genkit, name string) ai.Model { return genkit.LookupModel(g, api.NewName(vertexAIProvider, name)) } // GoogleAIEmbedder returns the [ai.Embedder] with the given name. // It returns nil if the embedder was not defined. +// +// Deprecated: Use genkit.LookupEmbedder instead. func GoogleAIEmbedder(g *genkit.Genkit, name string) ai.Embedder { return genkit.LookupEmbedder(g, api.NewName(googleAIProvider, name)) } // VertexAIEmbedder returns the [ai.Embedder] with the given name. // It returns nil if the embedder was not defined. +// +// Deprecated: Use genkit.LookupEmbedder instead. func VertexAIEmbedder(g *genkit.Genkit, name string) ai.Embedder { return genkit.LookupEmbedder(g, api.NewName(vertexAIProvider, name)) } - -// ListActions lists all the actions supported by the Google AI plugin. -func (ga *GoogleAI) ListActions(ctx context.Context) []api.ActionDesc { - models, err := listGenaiModels(ctx, ga.gclient) - if err != nil { - return nil - } - - actions := []api.ActionDesc{} - - // Generative models. - for _, name := range models.gemini { - var opts ai.ModelOptions - if knownOpts, ok := supportedGeminiModels[name]; ok { - opts = knownOpts - opts.Label = fmt.Sprintf("%s - %s", googleAILabelPrefix, opts.Label) - } else { - opts = defaultGeminiOpts - opts.Label = fmt.Sprintf("%s - %s", googleAILabelPrefix, name) - } - - model := newModel(ga.gclient, name, opts) - if actionDef, ok := model.(api.Action); ok { - actions = append(actions, actionDef.Desc()) - } - } - - // Imagen models. 
- for _, name := range models.imagen { - var opts ai.ModelOptions - if knownOpts, ok := supportedImagenModels[name]; ok { - opts = knownOpts - opts.Label = fmt.Sprintf("%s - %s", googleAILabelPrefix, opts.Label) - } else { - opts = defaultImagenOpts - opts.Label = fmt.Sprintf("%s - %s", googleAILabelPrefix, name) - } - - model := newModel(ga.gclient, name, opts) - if actionDef, ok := model.(api.Action); ok { - actions = append(actions, actionDef.Desc()) - } - } - - // Embedders. - for _, e := range models.embedders { - var embedOpts ai.EmbedderOptions - if knownOpts, ok := googleAIEmbedderConfig[e]; ok { - embedOpts = knownOpts - } else { - embedOpts = defaultEmbedOpts - embedOpts.Label = fmt.Sprintf("%s - %s", googleAILabelPrefix, e) - } - - embedder := newEmbedder(ga.gclient, e, &embedOpts) - if actionDef, ok := embedder.(api.Action); ok { - actions = append(actions, actionDef.Desc()) - } - } - - return actions -} - -// ResolveAction resolves an action with the given name. -func (ga *GoogleAI) ResolveAction(atype api.ActionType, name string) api.Action { - switch atype { - case api.ActionTypeEmbedder: - return newEmbedder(ga.gclient, name, &ai.EmbedderOptions{}).(api.Action) - case api.ActionTypeModel: - var supports *ai.ModelSupports - var config any - - // TODO: Add veo case. 
- switch { - case strings.Contains(name, "imagen"): - supports = &Media - config = &genai.GenerateImagesConfig{} - default: - supports = &Multimodal - config = &genai.GenerateContentConfig{} - } - - return newModel(ga.gclient, name, ai.ModelOptions{ - Label: fmt.Sprintf("%s - %s", googleAILabelPrefix, name), - Stage: ai.ModelStageStable, - Versions: []string{}, - Supports: supports, - ConfigSchema: configToMap(config), - }).(api.Action) - case api.ActionTypeBackgroundModel: - // Handle VEO models as background models - if strings.HasPrefix(name, "veo") { - veoModel := newVeoModel(ga.gclient, name, ai.ModelOptions{ - Label: fmt.Sprintf("%s - %s", googleAILabelPrefix, name), - Stage: ai.ModelStageStable, - Versions: []string{}, - Supports: &ai.ModelSupports{ - Media: true, - Multiturn: false, - Tools: false, - SystemRole: false, - Output: []string{"media"}, - LongRunning: true, - }, - }) - actionName := fmt.Sprintf("%s/%s", googleAIProvider, name) - return core.NewAction(actionName, api.ActionTypeBackgroundModel, nil, nil, - func(ctx context.Context, input *ai.ModelRequest) (*core.Operation[*ai.ModelResponse], error) { - op, err := veoModel.Start(ctx, input) - if err != nil { - return nil, err - } - op.Action = api.KeyFromName(api.ActionTypeBackgroundModel, actionName) - return op, nil - }) - } - return nil - case api.ActionTypeCheckOperation: - // Handle VEO model check operations - if strings.HasPrefix(name, "veo") { - veoModel := newVeoModel(ga.gclient, name, ai.ModelOptions{ - Label: fmt.Sprintf("%s - %s", googleAILabelPrefix, name), - Stage: ai.ModelStageStable, - Versions: []string{}, - Supports: &ai.ModelSupports{ - Media: true, - Multiturn: false, - Tools: false, - SystemRole: false, - Output: []string{"media"}, - LongRunning: true, - }, - }) - - actionName := fmt.Sprintf("%s/%s", googleAIProvider, name) - return core.NewAction(actionName, api.ActionTypeCheckOperation, - map[string]any{"description": fmt.Sprintf("Check status of %s operation", name)}, nil, - 
func(ctx context.Context, op *core.Operation[*ai.ModelResponse]) (*core.Operation[*ai.ModelResponse], error) { - updatedOp, err := veoModel.Check(ctx, op) - if err != nil { - return nil, err - } - updatedOp.Action = api.KeyFromName(api.ActionTypeBackgroundModel, actionName) - return updatedOp, nil - }) - } - return nil - } - return nil -} - -// ListActions lists all the actions supported by the Vertex AI plugin. -func (v *VertexAI) ListActions(ctx context.Context) []api.ActionDesc { - models, err := listGenaiModels(ctx, v.gclient) - if err != nil { - return nil - } - - actions := []api.ActionDesc{} - - // Gemini generative models. - for _, name := range models.gemini { - var opts ai.ModelOptions - if knownOpts, ok := supportedGeminiModels[name]; ok { - opts = knownOpts - opts.Label = fmt.Sprintf("%s - %s", vertexAILabelPrefix, opts.Label) - } else { - opts = defaultGeminiOpts - opts.Label = fmt.Sprintf("%s - %s", vertexAILabelPrefix, name) - } - - model := newModel(v.gclient, name, opts) - if actionDef, ok := model.(api.Action); ok { - actions = append(actions, actionDef.Desc()) - } - } - - // Imagen models. - for _, name := range models.imagen { - var opts ai.ModelOptions - if knownOpts, ok := supportedImagenModels[name]; ok { - opts = knownOpts - opts.Label = fmt.Sprintf("%s - %s", vertexAILabelPrefix, opts.Label) - } else { - opts = defaultImagenOpts - opts.Label = fmt.Sprintf("%s - %s", vertexAILabelPrefix, name) - } - - model := newModel(v.gclient, name, opts) - if actionDef, ok := model.(api.Action); ok { - actions = append(actions, actionDef.Desc()) - } - } - - // Embedders. 
- for _, e := range models.embedders { - var embedOpts ai.EmbedderOptions - if knownOpts, ok := googleAIEmbedderConfig[e]; ok { - embedOpts = knownOpts - } else { - embedOpts = defaultEmbedOpts - embedOpts.Label = fmt.Sprintf("%s - %s", vertexAILabelPrefix, e) - } - - embedder := newEmbedder(v.gclient, e, &embedOpts) - if actionDef, ok := embedder.(api.Action); ok { - actions = append(actions, actionDef.Desc()) - } - } - - return actions -} - -// ResolveAction resolves an action with the given name. -func (v *VertexAI) ResolveAction(atype api.ActionType, id string) api.Action { - switch atype { - case api.ActionTypeEmbedder: - return newEmbedder(v.gclient, id, &ai.EmbedderOptions{}).(api.Action) - case api.ActionTypeModel: - var supports *ai.ModelSupports - var config any - - // TODO: Add veo case. - switch { - case strings.Contains(id, "imagen"): - supports = &Media - config = &genai.GenerateImagesConfig{} - default: - supports = &Multimodal - config = &genai.GenerateContentConfig{} - } - - return newModel(v.gclient, id, ai.ModelOptions{ - Label: fmt.Sprintf("%s - %s", vertexAILabelPrefix, id), - Stage: ai.ModelStageStable, - Versions: []string{}, - Supports: supports, - ConfigSchema: configToMap(config), - }).(api.Action) - } - return nil -} diff --git a/go/plugins/googlegenai/imagen.go b/go/plugins/googlegenai/imagen.go index 6003494a30..86083a0832 100644 --- a/go/plugins/googlegenai/imagen.go +++ b/go/plugins/googlegenai/imagen.go @@ -26,16 +26,6 @@ import ( "google.golang.org/genai" ) -// Media describes model capabilities for Gemini models with media and text -// input and image only output -var Media = ai.ModelSupports{ - Media: true, - Multiturn: false, - Tools: false, - ToolChoice: false, - SystemRole: false, -} - // imagenConfigFromRequest translates an [*ai.ModelRequest] configuration to [*genai.GenerateImagesConfig] func imagenConfigFromRequest(input *ai.ModelRequest) (*genai.GenerateImagesConfig, error) { var result genai.GenerateImagesConfig diff 
--git a/go/plugins/googlegenai/model_type.go b/go/plugins/googlegenai/model_type.go new file mode 100644 index 0000000000..201f50a25e --- /dev/null +++ b/go/plugins/googlegenai/model_type.go @@ -0,0 +1,83 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 + +package googlegenai + +import ( + "strings" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core/api" + "google.golang.org/genai" +) + +// ModelType categorizes models by their generation modality. +type ModelType int + +const ( + ModelTypeUnknown ModelType = iota + ModelTypeGemini // Text/multimodal generation (gemini-*, gemma-*) + ModelTypeImagen // Image generation (imagen-*) + ModelTypeVeo // Video generation (veo-*), long-running + ModelTypeEmbedder // Embedding models (*embedding*) +) + +// ClassifyModel determines the model type from its name. +// This is the single source of truth for model type classification. +func ClassifyModel(name string) ModelType { + switch { + case strings.HasPrefix(name, "veo"): + return ModelTypeVeo + case strings.HasPrefix(name, "imagen"), strings.HasPrefix(name, "image"): + return ModelTypeImagen + case strings.HasPrefix(name, "gemini"), strings.HasPrefix(name, "gemma"): + return ModelTypeGemini + case strings.Contains(name, "embedding"): + // Covers: text-embedding-*, embedding-*, textembedding-*, multimodalembedding + return ModelTypeEmbedder + default: + return ModelTypeUnknown + } +} + +// ActionType returns the appropriate API action type for this model type. +func (mt ModelType) ActionType() api.ActionType { + switch mt { + case ModelTypeVeo: + return api.ActionTypeBackgroundModel + case ModelTypeEmbedder: + return api.ActionTypeEmbedder + default: + return api.ActionTypeModel + } +} + +// DefaultSupports returns the default ModelSupports for this model type. 
+func (mt ModelType) DefaultSupports() *ai.ModelSupports { + switch mt { + case ModelTypeGemini: + return &Multimodal + case ModelTypeImagen: + return &Media + case ModelTypeVeo: + return &VeoSupports + default: + return nil + } +} + +// DefaultConfig returns the default config struct for this model type. +func (mt ModelType) DefaultConfig() any { + switch mt { + case ModelTypeGemini: + return &genai.GenerateContentConfig{} + case ModelTypeImagen: + return &genai.GenerateImagesConfig{} + case ModelTypeVeo: + return &genai.GenerateVideosConfig{} + case ModelTypeEmbedder: + return &genai.EmbedContentConfig{} + default: + return nil + } +} diff --git a/go/plugins/googlegenai/models.go b/go/plugins/googlegenai/models.go index d550253642..20603e5bbf 100644 --- a/go/plugins/googlegenai/models.go +++ b/go/plugins/googlegenai/models.go @@ -6,7 +6,6 @@ package googlegenai import ( "context" "fmt" - "log" "slices" "strings" @@ -14,6 +13,70 @@ import ( "google.golang.org/genai" ) +// Model capability definitions - these describe what different model types support. +var ( + // BasicText describes model capabilities for text-only Gemini models. + BasicText = ai.ModelSupports{ + Multiturn: true, + Tools: true, + ToolChoice: true, + SystemRole: true, + Media: false, + } + + // Multimodal describes model capabilities for multimodal Gemini models. + Multimodal = ai.ModelSupports{ + Multiturn: true, + Tools: true, + ToolChoice: true, + SystemRole: true, + Media: true, + Constrained: ai.ConstrainedSupportNoTools, + } + + // Media describes model capabilities for image generation models (Imagen). + Media = ai.ModelSupports{ + Multiturn: false, + Tools: false, + SystemRole: false, + Media: true, + Output: []string{"media"}, + } + + // VeoSupports describes model capabilities for video generation models (Veo). 
+ VeoSupports = ai.ModelSupports{ + Media: true, + Multiturn: false, + Tools: false, + SystemRole: false, + Output: []string{"media"}, + LongRunning: true, + } +) + +// Default options for unknown models of each type. +var ( + defaultGeminiOpts = ai.ModelOptions{ + Supports: &Multimodal, + Stage: ai.ModelStageUnstable, + } + + defaultImagenOpts = ai.ModelOptions{ + Supports: &Media, + Stage: ai.ModelStageUnstable, + } + + defaultVeoOpts = ai.ModelOptions{ + Supports: &VeoSupports, + Stage: ai.ModelStageUnstable, + } + + defaultEmbedOpts = ai.EmbedderOptions{ + Supports: &ai.EmbedderSupports{Input: []string{"text"}}, + Dimensions: 768, + } +) + const ( gemini15Flash = "gemini-1.5-flash" gemini15Pro = "gemini-1.5-pro" @@ -250,47 +313,26 @@ var ( supportedVideoModels = map[string]ai.ModelOptions{ veo20Generate001: { - Label: "Google AI - Veo 2.0 Generate 001", + Label: "Veo 2.0 Generate 001", Versions: []string{}, - Supports: &ai.ModelSupports{ - Media: true, - Multiturn: false, - Tools: false, - SystemRole: false, - Output: []string{"media"}, - LongRunning: true, - }, - Stage: ai.ModelStageStable, + Supports: &VeoSupports, + Stage: ai.ModelStageStable, }, veo30Generate001: { - Label: "Google AI - Veo 3.0 Generate 001", + Label: "Veo 3.0 Generate 001", Versions: []string{}, - Supports: &ai.ModelSupports{ - Media: true, - Multiturn: false, - Tools: false, - SystemRole: false, - Output: []string{"media"}, - LongRunning: true, - }, - Stage: ai.ModelStageStable, + Supports: &VeoSupports, + Stage: ai.ModelStageStable, }, veo30FastGenerate001: { - Label: "Google AI - Veo 3.0 Fast Generate 001", + Label: "Veo 3.0 Fast Generate 001", Versions: []string{}, - Supports: &ai.ModelSupports{ - Media: true, - Multiturn: false, - Tools: false, - SystemRole: false, - Output: []string{"media"}, - LongRunning: true, - }, - Stage: ai.ModelStageStable, + Supports: &VeoSupports, + Stage: ai.ModelStageStable, }, } - googleAIEmbedderConfig = map[string]ai.EmbedderOptions{ + embedderConfig = 
map[string]ai.EmbedderOptions{ textembedding004: { Dimensions: 768, Label: "Google Gen AI - Text Embedding 001", @@ -354,39 +396,99 @@ var ( } ) +// GetModelOptions returns ModelOptions for a model name with provider-prefixed label. +func GetModelOptions(name, provider string) ai.ModelOptions { + mt := ClassifyModel(name) + var opts ai.ModelOptions + var ok bool + + switch mt { + case ModelTypeGemini: + opts, ok = supportedGeminiModels[name] + if !ok { + opts = defaultGeminiOpts + } + case ModelTypeImagen: + opts, ok = supportedImagenModels[name] + if !ok { + opts = defaultImagenOpts + } + case ModelTypeVeo: + opts, ok = supportedVideoModels[name] + if !ok { + opts = defaultVeoOpts + } + default: + opts = defaultGeminiOpts + } + + // Set label with provider prefix + prefix := googleAILabelPrefix + if provider == vertexAIProvider { + prefix = vertexAILabelPrefix + } + if opts.Label == "" { + opts.Label = name + } + opts.Label = fmt.Sprintf("%s - %s", prefix, opts.Label) + + return opts +} + +// GetEmbedderOptions returns EmbedderOptions for an embedder name with provider-prefixed label. +func GetEmbedderOptions(name, provider string) ai.EmbedderOptions { + opts, ok := embedderConfig[name] + if !ok { + opts = defaultEmbedOpts + } + + prefix := googleAILabelPrefix + if provider == vertexAIProvider { + prefix = vertexAILabelPrefix + } + if opts.Label == "" { + opts.Label = name + } + opts.Label = fmt.Sprintf("%s - %s", prefix, opts.Label) + + return opts +} + // listModels returns a map of supported models and their capabilities -// based on the detected backend +// based on the detected backend. 
func listModels(provider string) (map[string]ai.ModelOptions, error) { var names []string - var prefix string switch provider { case googleAIProvider: names = googleAIModels - prefix = googleAILabelPrefix case vertexAIProvider: names = vertexAIModels - prefix = vertexAILabelPrefix default: return nil, fmt.Errorf("unknown provider detected %s", provider) } - models := make(map[string]ai.ModelOptions, 0) + models := make(map[string]ai.ModelOptions, len(names)) for _, n := range names { + mt := ClassifyModel(n) var m ai.ModelOptions var ok bool - if strings.HasPrefix(n, "image") { + + switch mt { + case ModelTypeImagen: m, ok = supportedImagenModels[n] - } else if strings.HasPrefix(n, "veo") { + case ModelTypeVeo: m, ok = supportedVideoModels[n] - } else { + default: m, ok = supportedGeminiModels[n] } if !ok { return nil, fmt.Errorf("model %s not found for provider %s", n, provider) } + models[n] = GetModelOptions(n, provider) + // Preserve original fields that GetModelOptions doesn't copy models[n] = ai.ModelOptions{ - Label: prefix + " - " + m.Label, + Label: models[n].Label, Versions: m.Versions, Supports: m.Supports, ConfigSchema: m.ConfigSchema, @@ -406,44 +508,42 @@ type genaiModels struct { } // listGenaiModels returns a list of supported models and embedders from the -// Go Genai SDK +// Go Genai SDK, categorized by model type. 
func listGenaiModels(ctx context.Context, client *genai.Client) (genaiModels, error) { models := genaiModels{} - allowedModels := []string{"gemini", "gemma"} for item, err := range client.Models.All(ctx) { - var name string - var description string if err != nil { - log.Fatal(err) + return genaiModels{}, err } if !strings.HasPrefix(item.Name, "models/") { continue } - description = strings.ToLower(item.Description) + description := strings.ToLower(item.Description) if strings.Contains(description, "deprecated") { continue } - name = strings.TrimPrefix(item.Name, "models/") - if slices.Contains(item.SupportedActions, "embedContent") { - models.embedders = append(models.embedders, name) - continue - } - - if slices.Contains(item.SupportedActions, "predict") && strings.Contains(name, "imagen") { - models.imagen = append(models.imagen, name) - continue - } + name := strings.TrimPrefix(item.Name, "models/") + mt := ClassifyModel(name) - if slices.Contains(item.SupportedActions, "generateContent") { - found := slices.ContainsFunc(allowedModels, func(s string) bool { - return strings.Contains(name, s) - }) - // filter out: Aqa, Text-bison, Chat, learnlm - if found { + switch mt { + case ModelTypeEmbedder: + if slices.Contains(item.SupportedActions, "embedContent") { + models.embedders = append(models.embedders, name) + } + case ModelTypeImagen: + if slices.Contains(item.SupportedActions, "predict") { + models.imagen = append(models.imagen, name) + } + case ModelTypeVeo: + // Veo uses predict for long-running operations + if slices.Contains(item.SupportedActions, "predict") { + models.veo = append(models.veo, name) + } + case ModelTypeGemini: + if slices.Contains(item.SupportedActions, "generateContent") { models.gemini = append(models.gemini, name) - continue } } } diff --git a/go/plugins/googlegenai/refs.go b/go/plugins/googlegenai/refs.go new file mode 100644 index 0000000000..61d4d33785 --- /dev/null +++ b/go/plugins/googlegenai/refs.go @@ -0,0 +1,55 @@ +// Copyright 
2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 + +package googlegenai + +import ( + "github.com/firebase/genkit/go/ai" + "google.golang.org/genai" +) + +// --- Gemini (text generation) --- + +// ModelRef creates a ModelRef for a Gemini model. +// The name should include provider prefix (e.g., "googleai/gemini-2.0-flash"). +func ModelRef(name string, config *genai.GenerateContentConfig) ai.ModelRef { + return ai.NewModelRef(name, config) +} + +// GoogleAIModelRef creates a ModelRef for a Google AI Gemini model. +// +// Deprecated: Use ModelRef with full name instead. +func GoogleAIModelRef(id string, config *genai.GenerateContentConfig) ai.ModelRef { + return ai.NewModelRef(googleAIProvider+"/"+id, config) +} + +// VertexAIModelRef creates a ModelRef for a Vertex AI Gemini model. +// +// Deprecated: Use ModelRef with full name instead. +func VertexAIModelRef(id string, config *genai.GenerateContentConfig) ai.ModelRef { + return ai.NewModelRef(vertexAIProvider+"/"+id, config) +} + +// --- Image generation (Imagen) --- + +// ImageModelRef creates a ModelRef for an image generation model. +// The name should include provider prefix (e.g., "googleai/imagen-3.0-generate-001"). +func ImageModelRef(name string, config *genai.GenerateImagesConfig) ai.ModelRef { + return ai.NewModelRef(name, config) +} + +// --- Video generation (Veo) --- + +// VideoModelRef creates a ModelRef for a video generation model. +// The name should include provider prefix (e.g., "googleai/veo-2.0-generate-001"). +func VideoModelRef(name string, config *genai.GenerateVideosConfig) ai.ModelRef { + return ai.NewModelRef(name, config) +} + +// --- Embedders --- + +// EmbedderRef creates an EmbedderRef for an embedding model. +// The name should include provider prefix (e.g., "googleai/text-embedding-004"). 
+func EmbedderRef(name string, config *genai.EmbedContentConfig) ai.EmbedderRef { + return ai.NewEmbedderRef(name, config) +} diff --git a/go/plugins/googlegenai/schema.go b/go/plugins/googlegenai/schema.go new file mode 100644 index 0000000000..f352104232 --- /dev/null +++ b/go/plugins/googlegenai/schema.go @@ -0,0 +1,229 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 + +package googlegenai + +import ( + "encoding/json" + "fmt" + "strconv" + "strings" + + "google.golang.org/genai" +) + +// toGeminiSchema translates a map representing a standard JSON schema to a more +// limited [genai.Schema]. +func toGeminiSchema(originalSchema map[string]any, genkitSchema map[string]any) (*genai.Schema, error) { + // this covers genkitSchema == nil and {} + // genkitSchema will be {} if it's any + if len(genkitSchema) == 0 { + return nil, nil + } + if v, ok := genkitSchema["$ref"]; ok { + ref, ok := v.(string) + if !ok { + return nil, fmt.Errorf("invalid $ref value: not a string") + } + s, err := resolveRef(originalSchema, ref) + if err != nil { + return nil, err + } + return toGeminiSchema(originalSchema, s) + } + + // Handle "anyOf" subschemas by finding the first valid schema definition + if v, ok := genkitSchema["anyOf"]; ok { + if anyOfList, isList := v.([]map[string]any); isList { + for _, subSchema := range anyOfList { + if subSchemaType, hasType := subSchema["type"]; hasType { + if typeStr, isString := subSchemaType.(string); isString && typeStr != "null" { + if title, ok := genkitSchema["title"]; ok { + subSchema["title"] = title + } + if description, ok := genkitSchema["description"]; ok { + subSchema["description"] = description + } + // Found a schema like: {"type": "string"} + return toGeminiSchema(originalSchema, subSchema) + } + } + } + } + } + + schema := &genai.Schema{} + typeVal, ok := genkitSchema["type"] + if !ok { + return nil, fmt.Errorf("schema is missing the 'type' field: %#v", genkitSchema) + } + + typeStr, ok := 
typeVal.(string) + if !ok { + return nil, fmt.Errorf("schema 'type' field is not a string, but %T", typeVal) + } + + switch typeStr { + case "string": + schema.Type = genai.TypeString + case "float64", "number": + schema.Type = genai.TypeNumber + case "integer": + schema.Type = genai.TypeInteger + case "boolean": + schema.Type = genai.TypeBoolean + case "object": + schema.Type = genai.TypeObject + case "array": + schema.Type = genai.TypeArray + default: + return nil, fmt.Errorf("schema type %q not allowed", genkitSchema["type"]) + } + if v, ok := genkitSchema["required"]; ok { + schema.Required = castToStringArray(v) + } + if v, ok := genkitSchema["propertyOrdering"]; ok { + schema.PropertyOrdering = castToStringArray(v) + } + if v, ok := genkitSchema["description"]; ok { + schema.Description = v.(string) + } + if v, ok := genkitSchema["format"]; ok { + schema.Format = v.(string) + } + if v, ok := genkitSchema["title"]; ok { + schema.Title = v.(string) + } + if v, ok := genkitSchema["minItems"]; ok { + if i64, ok := castToInt64(v); ok { + schema.MinItems = genai.Ptr(i64) + } + } + if v, ok := genkitSchema["maxItems"]; ok { + if i64, ok := castToInt64(v); ok { + schema.MaxItems = genai.Ptr(i64) + } + } + if v, ok := genkitSchema["maximum"]; ok { + if f64, ok := castToFloat64(v); ok { + schema.Maximum = genai.Ptr(f64) + } + } + if v, ok := genkitSchema["minimum"]; ok { + if f64, ok := castToFloat64(v); ok { + schema.Minimum = genai.Ptr(f64) + } + } + if v, ok := genkitSchema["enum"]; ok { + schema.Enum = castToStringArray(v) + } + if v, ok := genkitSchema["items"]; ok { + items, err := toGeminiSchema(originalSchema, v.(map[string]any)) + if err != nil { + return nil, err + } + schema.Items = items + } + if val, ok := genkitSchema["properties"]; ok { + props := map[string]*genai.Schema{} + for k, v := range val.(map[string]any) { + p, err := toGeminiSchema(originalSchema, v.(map[string]any)) + if err != nil { + return nil, err + } + props[k] = p + } + 
schema.Properties = props + } + // Nullable -- not supported in jsonschema.Schema + + return schema, nil +} + +// resolveRef resolves a $ref reference in a JSON schema. +func resolveRef(originalSchema map[string]any, ref string) (map[string]any, error) { + tkns := strings.Split(ref, "/") + // refs look like: $/ref/foo -- we need the foo part + name := tkns[len(tkns)-1] + if defs, ok := originalSchema["$defs"].(map[string]any); ok { + if def, ok := defs[name].(map[string]any); ok { + return def, nil + } + } + // definitions (legacy) + if defs, ok := originalSchema["definitions"].(map[string]any); ok { + if def, ok := defs[name].(map[string]any); ok { + return def, nil + } + } + return nil, fmt.Errorf("unable to resolve schema reference") +} + +// castToStringArray converts either []any or []string to []string, filtering non-strings. +// This handles enum values from JSON Schema which may come as either type depending on unmarshaling. +// Filter out non-string types from if v is []any type. +func castToStringArray(v any) []string { + switch a := v.(type) { + case []string: + // Return a shallow copy to avoid aliasing + out := make([]string, 0, len(a)) + for _, s := range a { + if s != "" { + out = append(out, s) + } + } + return out + case []any: + var out []string + for _, it := range a { + if s, ok := it.(string); ok && s != "" { + out = append(out, s) + } + } + return out + default: + return nil + } +} + +// castToInt64 converts v to int64 when possible. +func castToInt64(v any) (int64, bool) { + switch t := v.(type) { + case int: + return int64(t), true + case int64: + return t, true + case float64: + return int64(t), true + case string: + if i, err := strconv.ParseInt(t, 10, 64); err == nil { + return i, true + } + case json.Number: + if i, err := t.Int64(); err == nil { + return i, true + } + } + return 0, false +} + +// castToFloat64 converts v to float64 when possible. 
+func castToFloat64(v any) (float64, bool) { + switch t := v.(type) { + case float64: + return t, true + case int: + return float64(t), true + case int64: + return float64(t), true + case string: + if f, err := strconv.ParseFloat(t, 64); err == nil { + return f, true + } + case json.Number: + if f, err := t.Float64(); err == nil { + return f, true + } + } + return 0, false +} diff --git a/go/plugins/googlegenai/tools.go b/go/plugins/googlegenai/tools.go new file mode 100644 index 0000000000..7c31ffdbd5 --- /dev/null +++ b/go/plugins/googlegenai/tools.go @@ -0,0 +1,154 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 + +package googlegenai + +import ( + "fmt" + "reflect" + "regexp" + "strings" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/plugins/internal/uri" + "google.golang.org/genai" +) + +const ( + // toolNameRegex validates tool names. + toolNameRegex = "^[a-zA-Z_][a-zA-Z0-9_.-]{0,63}$" +) + +// toGeminiTools translates a slice of [ai.ToolDefinition] to a slice of [genai.Tool]. 
+func toGeminiTools(inTools []*ai.ToolDefinition) ([]*genai.Tool, error) { + var outTools []*genai.Tool + functions := []*genai.FunctionDeclaration{} + + for _, t := range inTools { + if !validToolName(t.Name) { + return nil, fmt.Errorf(`invalid tool name: %q, must start with a letter or an underscore, must be alphanumeric, underscores, dots or dashes with a max length of 64 chars`, t.Name) + } + inputSchema, err := toGeminiSchema(t.InputSchema, t.InputSchema) + if err != nil { + return nil, err + } + fd := &genai.FunctionDeclaration{ + Name: t.Name, + Parameters: inputSchema, + Description: t.Description, + } + functions = append(functions, fd) + } + + if len(functions) > 0 { + outTools = append(outTools, &genai.Tool{ + FunctionDeclarations: functions, + }) + } + + return outTools, nil +} + +// toGeminiFunctionResponsePart translates a slice of [ai.Part] to a slice of [genai.FunctionResponsePart] +func toGeminiFunctionResponsePart(parts []*ai.Part) ([]*genai.FunctionResponsePart, error) { + frp := []*genai.FunctionResponsePart{} + for _, p := range parts { + switch { + case p.IsData(): + contentType, data, err := uri.Data(p) + if err != nil { + return nil, err + } + frp = append(frp, genai.NewFunctionResponsePartFromBytes(data, contentType)) + case p.IsMedia(): + if strings.HasPrefix(p.Text, "data:") { + contentType, data, err := uri.Data(p) + if err != nil { + return nil, err + } + frp = append(frp, genai.NewFunctionResponsePartFromBytes(data, contentType)) + continue + } + frp = append(frp, genai.NewFunctionResponsePartFromURI(p.Text, p.ContentType)) + default: + return nil, fmt.Errorf("unsupported function response part type: %d", p.Kind) + } + } + return frp, nil +} + +// mergeTools consolidates all FunctionDeclarations into a single Tool +// while preserving non-function tools (Retrieval, GoogleSearch, CodeExecution, etc.) 
+func mergeTools(ts []*genai.Tool) []*genai.Tool { + var decls []*genai.FunctionDeclaration + var out []*genai.Tool + + for _, t := range ts { + if t == nil { + continue + } + if len(t.FunctionDeclarations) == 0 { + out = append(out, t) + continue + } + decls = append(decls, t.FunctionDeclarations...) + if cpy := cloneToolWithoutFunctions(t); cpy != nil && !reflect.ValueOf(*cpy).IsZero() { + out = append(out, cpy) + } + } + + if len(decls) > 0 { + out = append([]*genai.Tool{{FunctionDeclarations: decls}}, out...) + } + return out +} + +func cloneToolWithoutFunctions(t *genai.Tool) *genai.Tool { + if t == nil { + return nil + } + clone := *t + clone.FunctionDeclarations = nil + return &clone +} + +// toGeminiToolChoice translates tool choice settings to Gemini tool config. +func toGeminiToolChoice(toolChoice ai.ToolChoice, tools []*ai.ToolDefinition) (*genai.ToolConfig, error) { + var mode genai.FunctionCallingConfigMode + switch toolChoice { + case "": + return nil, nil + case ai.ToolChoiceAuto: + mode = genai.FunctionCallingConfigModeAuto + case ai.ToolChoiceRequired: + mode = genai.FunctionCallingConfigModeAny + case ai.ToolChoiceNone: + mode = genai.FunctionCallingConfigModeNone + default: + return nil, fmt.Errorf("tool choice mode %q not supported", toolChoice) + } + + var toolNames []string + // Per docs, only set AllowedToolNames with mode set to ANY. 
+ if mode == genai.FunctionCallingConfigModeAny { + for _, t := range tools { + toolNames = append(toolNames, t.Name) + } + } + return &genai.ToolConfig{ + FunctionCallingConfig: &genai.FunctionCallingConfig{ + Mode: mode, + AllowedFunctionNames: toolNames, + }, + }, nil +} + +// validToolName checks whether the provided tool name matches the +// following criteria: +// - Start with a letter or an underscore +// - Must be alphanumeric and can include underscores, dots or dashes +// - Maximum length of 64 chars +func validToolName(n string) bool { + re := regexp.MustCompile(toolNameRegex) + return re.MatchString(n) +} diff --git a/go/plugins/googlegenai/veo.go b/go/plugins/googlegenai/veo.go index 328b4598f7..cf58aab98f 100644 --- a/go/plugins/googlegenai/veo.go +++ b/go/plugins/googlegenai/veo.go @@ -19,6 +19,7 @@ package googlegenai import ( "context" "fmt" + "strings" "time" "github.com/firebase/genkit/go/ai" @@ -41,15 +42,19 @@ func newVeoModel( return nil, fmt.Errorf("no text prompt found in request") } + video := extractVeoVideoFromRequest(req) image := extractVeoImageFromRequest(req) - videoConfig := toVeoParameters(req) + sourceConfig := &genai.GenerateVideosSource{ + Prompt: prompt, + Image: image, + Video: video, + } - operation, err := client.Models.GenerateVideos( + operation, err := client.Models.GenerateVideosFromSource( ctx, name, - prompt, - image, + sourceConfig, videoConfig, ) if err != nil { @@ -131,7 +136,7 @@ func extractVeoImageFromRequest(request *ai.ModelRequest) *genai.Image { for _, message := range request.Messages { for _, part := range message.Content { - if part.IsMedia() { + if part.IsMedia() && !part.IsVideo() { _, data, err := uri.Data(part) if err != nil { return nil @@ -147,6 +152,37 @@ func extractVeoImageFromRequest(request *ai.ModelRequest) *genai.Image { return nil } +// extractVeoVideoFromRequest extracts video content from a model request for Veo. 
+func extractVeoVideoFromRequest(request *ai.ModelRequest) *genai.Video { + if len(request.Messages) == 0 { + return nil + } + + for _, message := range request.Messages { + for _, part := range message.Content { + if !part.IsVideo() { + continue + } + if strings.HasPrefix(part.Text, "data:") { + contentType, data, err := uri.Data(part) + if err != nil { + return nil + } + return &genai.Video{ + VideoBytes: data, + MIMEType: contentType, + } + } + return &genai.Video{ + URI: part.Text, + // MIMEType: part.ContentType, + } + } + } + + return nil +} + // toVeoParameters converts model request configuration to Veo video generation parameters. func toVeoParameters(request *ai.ModelRequest) *genai.GenerateVideosConfig { params := &genai.GenerateVideosConfig{} diff --git a/go/samples/veo/main.go b/go/samples/veo/main.go index dceb622ff9..d9e40c4d2b 100644 --- a/go/samples/veo/main.go +++ b/go/samples/veo/main.go @@ -16,6 +16,7 @@ package main import ( "context" + "encoding/base64" "fmt" "io" "log" @@ -36,35 +37,114 @@ func main() { g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) - operation, err := genkit.GenerateOperation(ctx, g, - ai.WithMessages(ai.NewUserTextMessage("Cat racing mouse")), - ai.WithModelName("googleai/veo-3.0-generate-001"), - ai.WithConfig(&genai.GenerateVideosConfig{ - NumberOfVideos: 1, - AspectRatio: "16:9", - DurationSeconds: genai.Ptr(int32(8)), - Resolution: "720p", - }), - ) - if err != nil { - log.Fatalf("Failed to start video generation: %v", err) - } + genkit.DefineFlow(g, "text-to-video", func(ctx context.Context, input string) (string, error) { + if input == "" { + input = "Cat racing mouse" + } + operation, err := genkit.GenerateOperation(ctx, g, + ai.WithMessages(ai.NewUserTextMessage(input)), + ai.WithModelName("googleai/veo-3.1-generate-preview"), + ai.WithConfig(&genai.GenerateVideosConfig{ + NumberOfVideos: 1, + AspectRatio: "16:9", + DurationSeconds: genai.Ptr(int32(8)), + Resolution: "720p", + }), + ) + if err != 
nil { + log.Fatalf("Failed to start video generation: %v", err) + } + printStatus(operation) - log.Printf("Started operation: %s", operation.ID) - printStatus(operation) + operation, err = waitForCompletion(ctx, g, operation) + if err != nil { + log.Fatalf("Operation failed: %v", err) + } + log.Println("Video generation completed successfully!") - operation, err = waitForCompletion(ctx, g, operation) - if err != nil { - log.Fatalf("Operation failed: %v", err) - } + if err := downloadGeneratedVideo(ctx, operation); err != nil { + log.Fatalf("Failed to download video: %v", err) + } + + // Return the video URI for chaining + uri, err := extractVideoURL(operation) + if err != nil { + return "", err + } + return uri, nil + }) - log.Println("Video generation completed successfully!") + genkit.DefineFlow(g, "image-to-video", func(ctx context.Context, input any) (string, error) { + imgb64, err := fetchImgAsBase64() + if err != nil { + log.Fatalf("unable to download image: %v", err) + } + operation, err := genkit.GenerateOperation(ctx, g, + ai.WithModelName("googleai/veo-3.1-generate-preview"), + ai.WithMessages(ai.NewUserMessage(ai.NewTextPart("Generate a video of the following image, the cat should wake up and start accelerating the go-kart as if it just acquired a mushroom from Mario Kart"), + ai.NewMediaPart("image/jpeg", "data:image/jpeg;base64,"+imgb64), + )), + ai.WithConfig(&genai.GenerateVideosConfig{ + NumberOfVideos: 1, + AspectRatio: "16:9", + DurationSeconds: genai.Ptr(int32(8)), + }), + ) + if err != nil { + log.Fatalf("Failed to start video generation: %v", err) + } + printStatus(operation) - if err := downloadGeneratedVideo(ctx, operation); err != nil { - log.Fatalf("Failed to download video: %v", err) - } + operation, err = waitForCompletion(ctx, g, operation) + if err != nil { + log.Fatalf("Operation failed: %v", err) + } + log.Println("Video generation completed successfully!") + + if err := downloadGeneratedVideo(ctx, operation); err != nil { + 
log.Fatalf("Failed to download video: %v", err) + } + + return "Video successfully downloaded to veo3_video.mp4", nil + }) + + genkit.DefineFlow(g, "video-to-video", func(ctx context.Context, inputURI string) (string, error) { + if inputURI == "" { + return "", fmt.Errorf("input URI is required for video extension") + } + + log.Printf("Extending video from URI: %s", inputURI) + + operation, err := genkit.GenerateOperation(ctx, g, + ai.WithModelName("googleai/veo-3.1-generate-preview"), + ai.WithMessages(ai.NewUserMessage( + ai.NewTextPart("Edit the original video backround to be a rainforest, also change the video style to be a cartoon from 1950! You must keep the characters from the original video"), + ai.NewMediaPart("video/mp4", inputURI), + )), + ai.WithConfig(&genai.GenerateVideosConfig{ + NumberOfVideos: 1, + AspectRatio: "16:9", + DurationSeconds: genai.Ptr(int32(8)), + }), + ) + if err != nil { + log.Fatalf("Failed to start video generation: %v", err) + } + printStatus(operation) + + operation, err = waitForCompletion(ctx, g, operation) + if err != nil { + log.Fatalf("Operation failed: %v", err) + } + log.Println("Video extension completed successfully!") - log.Println("Video successfully downloaded to veo3_video.mp4") + if err := downloadGeneratedVideo(ctx, operation); err != nil { + log.Fatalf("Failed to download video: %v", err) + } + + return "Video successfully downloaded to veo3_video.mp4", nil + }) + <-ctx.Done() } // waitForCompletion polls the operation status until it completes. 
// fetchImgAsBase64 downloads a predefined image and returns its bytes encoded
// as a standard base64 string.
//
// It returns an error when the request fails, when the server answers with a
// status other than 200 OK, or when the response body cannot be read.
func fetchImgAsBase64() (string, error) {
	// CC0 license image
	const imgURL = "https://pd.w.org/2025/07/896686fbbcd9990c9.84605288-2048x1365.jpg"

	resp, err := http.Get(imgURL)
	if err != nil {
		return "", fmt.Errorf("fetching %s: %w", imgURL, err)
	}
	defer resp.Body.Close()

	// err is nil here, so a non-200 status needs an error of its own;
	// otherwise the caller would receive an empty string and no error.
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("fetching %s: unexpected status %s", imgURL, resp.Status)
	}

	imageBytes, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("reading image body: %w", err)
	}

	return base64.StdEncoding.EncodeToString(imageBytes), nil
}