diff --git a/go/core/cli/cmd/kagent/main.go b/go/core/cli/cmd/kagent/main.go index 8792835b0..bfec18610 100644 --- a/go/core/cli/cmd/kagent/main.go +++ b/go/core/cli/cmd/kagent/main.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "os" + "strings" "os/signal" "syscall" "time" @@ -64,6 +65,10 @@ func main() { _ = installCmd.RegisterFlagCompletionFunc("profile", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { return profiles.Profiles, cobra.ShellCompDirectiveNoFileComp }) + installCmd.Flags().StringVar(&installCfg.Provider, "provider", "", fmt.Sprintf("LLM provider to use (%s). Overrides KAGENT_DEFAULT_MODEL_PROVIDER.", strings.Join(cli.ValidProviders(), ", "))) + _ = installCmd.RegisterFlagCompletionFunc("provider", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + return cli.ValidProviders(), cobra.ShellCompDirectiveNoFileComp + }) uninstallCmd := &cobra.Command{ Use: "uninstall", diff --git a/go/core/cli/internal/cli/agent/const.go b/go/core/cli/internal/cli/agent/const.go index 885f44c5d..2592379cb 100644 --- a/go/core/cli/internal/cli/agent/const.go +++ b/go/core/cli/internal/cli/agent/const.go @@ -1,6 +1,7 @@ package cli import ( + "fmt" "os" "strings" @@ -63,3 +64,25 @@ func GetEnvVarWithDefault(envVar, defaultValue string) string { } return defaultValue } + +// ValidProviders returns the accepted --provider flag values (helm key format). +func ValidProviders() []string { + return []string{ + GetModelProviderHelmValuesKey(v1alpha2.ModelProviderOpenAI), + GetModelProviderHelmValuesKey(v1alpha2.ModelProviderAnthropic), + GetModelProviderHelmValuesKey(v1alpha2.ModelProviderAzureOpenAI), + GetModelProviderHelmValuesKey(v1alpha2.ModelProviderOllama), + } +} + +// applyProviderFlag validates the --provider value and sets KAGENT_DEFAULT_MODEL_PROVIDER so +// that GetModelProvider() picks it up. This lets users avoid setting the env var manually. +func applyProviderFlag(provider string) error { + valid := ValidProviders() + for _, v := range valid { + if provider == v { + return os.Setenv(env.KagentDefaultModelProvider.Name(), provider) + } + } + return fmt.Errorf("unknown provider %q: valid values: %s", provider, strings.Join(valid, ", ")) +} diff --git a/go/core/cli/internal/cli/agent/install.go b/go/core/cli/internal/cli/agent/install.go index 27f5577c9..7f6959d6e 100644 --- a/go/core/cli/internal/cli/agent/install.go +++ b/go/core/cli/internal/cli/agent/install.go @@ -20,8 +20,9 @@ import ( ) type InstallCfg struct { - Config *config.Config - Profile string + Config *config.Config + Profile string + Provider string } // installChart installs or upgrades a Helm chart with the given parameters @@ -76,16 +77,28 @@ func InstallCmd(ctx context.Context, cfg *InstallCfg) *PortForward { return nil } + // --provider flag takes precedence over KAGENT_DEFAULT_MODEL_PROVIDER env var + if cfg.Provider != "" { + if err := applyProviderFlag(cfg.Provider); err != nil { + fmt.Fprintln(os.Stderr, err) + return nil + } + } + // get model provider from KAGENT_DEFAULT_MODEL_PROVIDER environment variable or use DefaultModelProvider modelProvider := GetModelProvider() - // If model provider is openai, check if the API key is set + // Check if the required API key is set for this provider apiKeyName := GetProviderAPIKey(modelProvider) apiKeyValue := os.Getenv(apiKeyName) if apiKeyName != "" && apiKeyValue == "" { fmt.Fprintf(os.Stderr, "%s is not set\n", apiKeyName) fmt.Fprintf(os.Stderr, "Please set the %s environment variable\n", apiKeyName) + if cfg.Provider == "" && modelProvider == DefaultModelProvider && apiKeyName == env.OpenAIAPIKey.Name() { + fmt.Fprintf(os.Stderr, "Tip: use --provider to select a different LLM provider (e.g. --provider anthropic)\n") + fmt.Fprintf(os.Stderr, " or set %s=%s before running install\n", env.KagentDefaultModelProvider.Name(), GetModelProviderHelmValuesKey(v1alpha2.ModelProviderAnthropic)) + } return nil } @@ -120,13 +133,16 @@ func InteractiveInstallCmd(ctx context.Context, c *ishell.Context) *PortForward // get model provider from KAGENT_DEFAULT_MODEL_PROVIDER environment variable or use DefaultModelProvider modelProvider := GetModelProvider() - // if model provider is openai, check if the api key is set + // Check if the required API key is set for this provider apiKeyName := GetProviderAPIKey(modelProvider) apiKeyValue := os.Getenv(apiKeyName) if apiKeyName != "" && apiKeyValue == "" { fmt.Fprintf(os.Stderr, "%s is not set\n", apiKeyName) fmt.Fprintf(os.Stderr, "Please set the %s environment variable\n", apiKeyName) + fmt.Fprintf(os.Stderr, "Tip: set %s to select a different provider (e.g. %s=%s)\n", + env.KagentDefaultModelProvider.Name(), env.KagentDefaultModelProvider.Name(), + GetModelProviderHelmValuesKey(v1alpha2.ModelProviderAnthropic)) return nil } diff --git a/go/core/internal/utils/config_map.go b/go/core/internal/utils/config_map.go index a7811ead7..fe9e6bccf 100644 --- a/go/core/internal/utils/config_map.go +++ b/go/core/internal/utils/config_map.go @@ -1,20 +1,108 @@ package utils import ( + "bytes" + "compress/gzip" "context" + "encoding/base64" "fmt" + "io" + "strings" + "github.com/klauspost/compress/zstd" corev1 "k8s.io/api/core/v1" "sigs.k8s.io/controller-runtime/pkg/client" ) -// GetConfigMapData fetches all data from a ConfigMap. +const ( + // CompressionAnnotation specifies the compression algorithm used for ConfigMap + // values. Supported values: "gzip", "zstd". When set, all values in the + // ConfigMap are expected to be base64-encoded compressed data and will be + // transparently decompressed when read via GetConfigMapData. + CompressionAnnotation = "kagent.dev/compression" + + // maxDecompressedSize is the upper bound on decompressed output (10 MB). + // This prevents a small compressed payload from expanding into an + // arbitrarily large allocation that could OOM the controller. + maxDecompressedSize = 10 << 20 // 10 MiB +) + +// GetConfigMapData fetches all data from a ConfigMap. If the ConfigMap carries +// the kagent.dev/compression annotation, values are transparently decompressed. +// Compressed values must be base64-encoded in the ConfigMap's Data field (not BinaryData). func GetConfigMapData(ctx context.Context, c client.Client, ref client.ObjectKey) (map[string]string, error) { configMap := &corev1.ConfigMap{} if err := c.Get(ctx, ref, configMap); err != nil { - return nil, fmt.Errorf("failed to find ConfigMap %s: %v", ref.String(), err) + return nil, fmt.Errorf("failed to find ConfigMap %s: %w", ref.String(), err) + } + + algo := strings.ToLower(strings.TrimSpace(configMap.Annotations[CompressionAnnotation])) + if algo == "" { + return configMap.Data, nil + } + + decompressed := make(map[string]string, len(configMap.Data)) + for key, value := range configMap.Data { + plain, err := decompress(value, algo) + if err != nil { + return nil, fmt.Errorf("failed to decompress key %q in ConfigMap %s (algorithm=%s): %w", key, ref.String(), algo, err) + } + decompressed[key] = plain + } + return decompressed, nil +} + +// decompress decodes base64 data and decompresses it with the given algorithm. +// The encoded payload is whitespace-tolerant (newlines and spaces are stripped +// before decoding) and decompressed output is capped at maxDecompressedSize. +func decompress(encoded string, algo string) (string, error) { + // Strip whitespace/newlines that commonly appear in pasted base64 + cleaned := strings.Map(func(r rune) rune { + if r == ' ' || r == '\n' || r == '\r' || r == '\t' { + return -1 + } + return r + }, encoded) + + raw, err := base64.StdEncoding.DecodeString(cleaned) + if err != nil { + return "", fmt.Errorf("base64 decode: %w", err) + } + + switch algo { + case "gzip": + r, err := gzip.NewReader(bytes.NewReader(raw)) + if err != nil { + return "", fmt.Errorf("gzip reader: %w", err) + } + defer r.Close() + out, err := io.ReadAll(io.LimitReader(r, maxDecompressedSize+1)) + if err != nil { + return "", fmt.Errorf("gzip read: %w", err) + } + if len(out) > maxDecompressedSize { + return "", fmt.Errorf("decompressed output exceeds %d bytes limit", maxDecompressedSize) + } + return string(out), nil + + case "zstd": + r, err := zstd.NewReader(bytes.NewReader(raw)) + if err != nil { + return "", fmt.Errorf("zstd reader: %w", err) + } + defer r.Close() + out, err := io.ReadAll(io.LimitReader(r, maxDecompressedSize+1)) + if err != nil { + return "", fmt.Errorf("zstd read: %w", err) + } + if len(out) > maxDecompressedSize { + return "", fmt.Errorf("decompressed output exceeds %d bytes limit", maxDecompressedSize) + } + return string(out), nil + + default: + return "", fmt.Errorf("unsupported compression algorithm %q (supported: gzip, zstd)", algo) } - return configMap.Data, nil } // GetConfigMapValue fetches a value from a ConfigMap @@ -22,7 +110,7 @@ func GetConfigMapValue(ctx context.Context, c client.Client, ref client.ObjectKe configMap := &corev1.ConfigMap{} err := c.Get(ctx, ref, configMap) if err != nil { - return "", fmt.Errorf("failed to find ConfigMap for %s: %v", ref.String(), err) + return "", fmt.Errorf("failed to find ConfigMap for %s: %w", ref.String(), err) } value, exists := configMap.Data[key] diff --git a/go/core/internal/utils/config_map_compression_test.go b/go/core/internal/utils/config_map_compression_test.go new file mode 100644 index 000000000..60f66205e --- /dev/null +++ b/go/core/internal/utils/config_map_compression_test.go @@ -0,0 +1,124 @@ +package utils + +import ( + "bytes" + "compress/gzip" + "encoding/base64" + "strings" + "testing" + + "github.com/klauspost/compress/zstd" +) + +func compressGzip(t *testing.T, data string) string { + t.Helper() + var buf bytes.Buffer + w := gzip.NewWriter(&buf) + if _, err := w.Write([]byte(data)); err != nil { + t.Fatal(err) + } + if err := w.Close(); err != nil { + t.Fatal(err) + } + return base64.StdEncoding.EncodeToString(buf.Bytes()) +} + +func compressZstd(t *testing.T, data string) string { + t.Helper() + var buf bytes.Buffer + w, err := zstd.NewWriter(&buf) + if err != nil { + t.Fatal(err) + } + if _, err := w.Write([]byte(data)); err != nil { + t.Fatal(err) + } + if err := w.Close(); err != nil { + t.Fatal(err) + } + return base64.StdEncoding.EncodeToString(buf.Bytes()) +} + +func TestDecompressGzip(t *testing.T) { + original := "Section 42 of the Children and Families Act 2014 imposes an absolute duty on the local authority to secure the provision specified in Section F." + encoded := compressGzip(t, original) + + result, err := decompress(encoded, "gzip") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != original { + t.Errorf("got %q, want %q", result, original) + } +} + +func TestDecompressZstd(t *testing.T) { + original := "Section 42 of the Children and Families Act 2014 imposes an absolute duty on the local authority to secure the provision specified in Section F." + encoded := compressZstd(t, original) + + result, err := decompress(encoded, "zstd") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != original { + t.Errorf("got %q, want %q", result, original) + } +} + +func TestDecompressUnsupportedAlgorithm(t *testing.T) { + _, err := decompress(base64.StdEncoding.EncodeToString([]byte("test")), "lz4") + if err == nil { + t.Fatal("expected error for unsupported algorithm") + } +} + +func TestDecompressInvalidBase64(t *testing.T) { + _, err := decompress("not-valid-base64!!!", "gzip") + if err == nil { + t.Fatal("expected error for invalid base64") + } +} + +func TestDecompressBase64WithWhitespace(t *testing.T) { + original := "Whitespace in base64 is common when users paste wrapped output." + clean := compressGzip(t, original) + + // Insert newlines and spaces to simulate wrapped base64 + wrapped := clean[:20] + "\n" + clean[20:40] + " " + clean[40:60] + "\r\n" + clean[60:] + + result, err := decompress(wrapped, "gzip") + if err != nil { + t.Fatalf("unexpected error with whitespace in base64: %v", err) + } + if result != original { + t.Errorf("got %q, want %q", result, original) + } +} + +func TestDecompressExceedsSizeLimit(t *testing.T) { + // Create data larger than maxDecompressedSize (10MB) + // zstd compresses repeated data extremely well, so a small input can exceed the limit + huge := make([]byte, maxDecompressedSize+1) + for i := range huge { + huge[i] = 'A' + } + + var buf bytes.Buffer + w, err := zstd.NewWriter(&buf) + if err != nil { + t.Fatal(err) + } + if _, err := w.Write(huge); err != nil { + t.Fatal(err) + } + w.Close() + encoded := base64.StdEncoding.EncodeToString(buf.Bytes()) + + _, err = decompress(encoded, "zstd") + if err == nil { + t.Fatal("expected error for oversized decompressed output") + } + if !strings.Contains(err.Error(), "exceeds") { + t.Errorf("expected 'exceeds' in error message, got: %v", err) + } +} diff --git a/go/go.mod b/go/go.mod index 9da25113f..e493eff09 100644 --- a/go/go.mod +++ b/go/go.mod @@ -62,7 +62,9 @@ require ( require ( github.com/aws/aws-sdk-go-v2 v1.41.5 github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.4 + github.com/google/jsonschema-go v0.4.2 github.com/jackc/pgx/v5 v5.9.1 + github.com/klauspost/compress v1.18.5 github.com/ollama/ollama v0.20.5 github.com/testcontainers/testcontainers-go v0.42.0 github.com/testcontainers/testcontainers-go/modules/postgres v0.42.0 @@ -154,7 +156,6 @@ require ( github.com/google/cel-go v0.26.0 // indirect github.com/google/gnostic-models v0.7.0 // indirect github.com/google/go-cmp v0.7.0 // indirect - github.com/google/jsonschema-go v0.4.2 // indirect github.com/google/s2a-go v0.1.9 // indirect github.com/google/safehtml v0.1.0 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect @@ -168,7 +169,6 @@ require ( github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/json-iterator/go v1.1.12 // indirect - github.com/klauspost/compress v1.18.5 // indirect github.com/lestrrat-go/blackmagic v1.0.2 // indirect github.com/lestrrat-go/httpcc v1.0.1 // indirect github.com/lestrrat-go/httprc v1.0.6 // indirect