diff --git a/go/core/cli/cmd/kagent/main.go b/go/core/cli/cmd/kagent/main.go index 8792835b0..bfec18610 100644 --- a/go/core/cli/cmd/kagent/main.go +++ b/go/core/cli/cmd/kagent/main.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "os" + "strings" "os/signal" "syscall" "time" @@ -64,6 +65,10 @@ func main() { _ = installCmd.RegisterFlagCompletionFunc("profile", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { return profiles.Profiles, cobra.ShellCompDirectiveNoFileComp }) + installCmd.Flags().StringVar(&installCfg.Provider, "provider", "", fmt.Sprintf("LLM provider to use (%s). Overrides KAGENT_DEFAULT_MODEL_PROVIDER.", strings.Join(cli.ValidProviders(), ", "))) + _ = installCmd.RegisterFlagCompletionFunc("provider", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + return cli.ValidProviders(), cobra.ShellCompDirectiveNoFileComp + }) uninstallCmd := &cobra.Command{ Use: "uninstall", diff --git a/go/core/cli/internal/cli/agent/const.go b/go/core/cli/internal/cli/agent/const.go index 885f44c5d..2592379cb 100644 --- a/go/core/cli/internal/cli/agent/const.go +++ b/go/core/cli/internal/cli/agent/const.go @@ -1,6 +1,7 @@ package cli import ( + "fmt" "os" "strings" @@ -63,3 +64,25 @@ func GetEnvVarWithDefault(envVar, defaultValue string) string { } return defaultValue } + +// ValidProviders returns the accepted --provider flag values (helm key format). +func ValidProviders() []string { + return []string{ + GetModelProviderHelmValuesKey(v1alpha2.ModelProviderOpenAI), + GetModelProviderHelmValuesKey(v1alpha2.ModelProviderAnthropic), + GetModelProviderHelmValuesKey(v1alpha2.ModelProviderAzureOpenAI), + GetModelProviderHelmValuesKey(v1alpha2.ModelProviderOllama), + } +} + +// applyProviderFlag validates the --provider value and sets KAGENT_DEFAULT_MODEL_PROVIDER so +// that GetModelProvider() picks it up. This lets users avoid setting the env var manually. +func applyProviderFlag(provider string) error { + valid := ValidProviders() + for _, v := range valid { + if provider == v { + return os.Setenv(env.KagentDefaultModelProvider.Name(), provider) + } + } + return fmt.Errorf("unknown provider %q: valid values: %s", provider, strings.Join(valid, ", ")) +} diff --git a/go/core/cli/internal/cli/agent/install.go b/go/core/cli/internal/cli/agent/install.go index 27f5577c9..7f6959d6e 100644 --- a/go/core/cli/internal/cli/agent/install.go +++ b/go/core/cli/internal/cli/agent/install.go @@ -20,8 +20,9 @@ import ( ) type InstallCfg struct { - Config *config.Config - Profile string + Config *config.Config + Profile string + Provider string } // installChart installs or upgrades a Helm chart with the given parameters @@ -76,16 +77,28 @@ func InstallCmd(ctx context.Context, cfg *InstallCfg) *PortForward { return nil } + // --provider flag takes precedence over KAGENT_DEFAULT_MODEL_PROVIDER env var + if cfg.Provider != "" { + if err := applyProviderFlag(cfg.Provider); err != nil { + fmt.Fprintln(os.Stderr, err) + return nil + } + } + // get model provider from KAGENT_DEFAULT_MODEL_PROVIDER environment variable or use DefaultModelProvider modelProvider := GetModelProvider() - // If model provider is openai, check if the API key is set + // Check if the required API key is set for this provider apiKeyName := GetProviderAPIKey(modelProvider) apiKeyValue := os.Getenv(apiKeyName) if apiKeyName != "" && apiKeyValue == "" { fmt.Fprintf(os.Stderr, "%s is not set\n", apiKeyName) fmt.Fprintf(os.Stderr, "Please set the %s environment variable\n", apiKeyName) + if cfg.Provider == "" && modelProvider == DefaultModelProvider && apiKeyName == env.OpenAIAPIKey.Name() { + fmt.Fprintf(os.Stderr, "Tip: use --provider to select a different LLM provider (e.g. --provider anthropic)\n") + fmt.Fprintf(os.Stderr, " or set %s=%s before running install\n", env.KagentDefaultModelProvider.Name(), GetModelProviderHelmValuesKey(v1alpha2.ModelProviderAnthropic)) + } return nil } @@ -120,13 +133,16 @@ func InteractiveInstallCmd(ctx context.Context, c *ishell.Context) *PortForward // get model provider from KAGENT_DEFAULT_MODEL_PROVIDER environment variable or use DefaultModelProvider modelProvider := GetModelProvider() - // if model provider is openai, check if the api key is set + // Check if the required API key is set for this provider apiKeyName := GetProviderAPIKey(modelProvider) apiKeyValue := os.Getenv(apiKeyName) if apiKeyName != "" && apiKeyValue == "" { fmt.Fprintf(os.Stderr, "%s is not set\n", apiKeyName) fmt.Fprintf(os.Stderr, "Please set the %s environment variable\n", apiKeyName) + fmt.Fprintf(os.Stderr, "Tip: set %s to select a different provider (e.g. %s=%s)\n", + env.KagentDefaultModelProvider.Name(), env.KagentDefaultModelProvider.Name(), + GetModelProviderHelmValuesKey(v1alpha2.ModelProviderAnthropic)) return nil }