Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"

v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
"github.com/authzed/authzed-go/v1"
Expand Down Expand Up @@ -214,6 +215,43 @@
}
}

func extraHeadersUnaryInterceptor(headers map[string]string) grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {

Check failure on line 219 in internal/client/client.go

View workflow job for this annotation

GitHub Actions / Lint Go

File is not properly formatted (gofmt)
if len(headers) > 0 {
md := metadata.New(headers)
ctx = metadata.NewOutgoingContext(ctx, md)
}
return invoker(ctx, method, req, reply, cc, opts...)
}
}

func extraHeadersStreamInterceptor(headers map[string]string) grpc.StreamClientInterceptor {
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
if len(headers) > 0 {
md := metadata.New(headers)
ctx = metadata.NewOutgoingContext(ctx, md)
Comment on lines +231 to +232
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this logic functionally merge the headers, or would this overwrite headers if some other component added new headers?

}
return streamer(ctx, desc, cc, method, opts...)
}
}

func parseExtraHeaders(headerStrings []string) (map[string]string, error) {
headers := make(map[string]string)
for _, headerStr := range headerStrings {
parts := strings.SplitN(headerStr, "=", 2)
if len(parts) != 2 {
return nil, fmt.Errorf("invalid header format '%s': expected 'key=value'", headerStr)
}
key := strings.TrimSpace(parts[0])
value := strings.TrimSpace(parts[1])
if key == "" {
return nil, fmt.Errorf("invalid header format '%s': key cannot be empty", headerStr)
}
headers[key] = value
}
return headers, nil
}

// DialOptsFromFlags returns the dial options from the CLI-specified flags.
func DialOptsFromFlags(cmd *cobra.Command, token storage.Token) ([]grpc.DialOption, error) {
maxRetries := cobrautil.MustGetUint(cmd, "max-retries")
Expand All @@ -239,6 +277,17 @@
selector.StreamClientInterceptor(retry.StreamClientInterceptor(retryOpts...), selector.MatchFunc(isNoneOf(importBulkRoute, exportBulkRoute, watchRoute))),
}

// Parse and add extra headers if provided
extraHeaderStrings := cobrautil.MustGetStringSlice(cmd, "extra-header")
if len(extraHeaderStrings) > 0 {
headers, err := parseExtraHeaders(extraHeaderStrings)
if err != nil {
return nil, fmt.Errorf("failed to parse extra headers: %w", err)
}
unaryInterceptors = append(unaryInterceptors, extraHeadersUnaryInterceptor(headers))
streamInterceptors = append(streamInterceptors, extraHeadersStreamInterceptor(headers))
}

if !cobrautil.MustGetBool(cmd, "skip-version-check") {
unaryInterceptors = append(unaryInterceptors, zgrpcutil.CheckServerVersion)
}
Expand Down
2 changes: 2 additions & 0 deletions internal/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ func TestRetries(t *testing.T) {
zedtesting.StringFlag{FlagName: "proxy", FlagValue: "", Changed: true},
zedtesting.StringFlag{FlagName: "hostname-override", FlagValue: "", Changed: true},
zedtesting.IntFlag{FlagName: "max-message-size", FlagValue: 1000, Changed: true},
zedtesting.StringSliceFlag{FlagName: "extra-header", FlagValue: []string{}, Changed: false},
)
dialOpts, err := client.DialOptsFromFlags(cmd, storage.Token{Insecure: &secure})
require.NoError(t, err)
Expand Down Expand Up @@ -224,6 +225,7 @@ func TestDoesNotRetry(t *testing.T) {
zedtesting.StringFlag{FlagName: "proxy", FlagValue: "", Changed: true},
zedtesting.StringFlag{FlagName: "hostname-override", FlagValue: "", Changed: true},
zedtesting.IntFlag{FlagName: "max-message-size", FlagValue: 1000, Changed: true},
zedtesting.StringSliceFlag{FlagName: "extra-header", FlagValue: []string{}, Changed: false},
)
dialOpts, err := client.DialOptsFromFlags(cmd, storage.Token{Insecure: &secure})
require.NoError(t, err)
Expand Down
1 change: 1 addition & 0 deletions internal/cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ zed permission check --explain document:firstdoc writer user:emilia
rootCmd.PersistentFlags().Int("max-message-size", 0, "maximum size *in bytes* (defaults to 4_194_304 bytes ~= 4MB) of a gRPC message that can be sent or received by zed")
rootCmd.PersistentFlags().String("proxy", "", "specify a SOCKS5 proxy address")
rootCmd.PersistentFlags().Uint("max-retries", 10, "maximum number of sequential retries to attempt when a request fails")
rootCmd.PersistentFlags().StringSlice("extra-header", []string{}, "extra header(s) to add to gRPC requests in the format 'key=value' (can be specified multiple times)")
_ = rootCmd.PersistentFlags().MarkHidden("debug") // This cannot return its error.

versionCmd := &cobra.Command{
Expand Down
9 changes: 9 additions & 0 deletions internal/testing/test_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ type StringFlag struct {
Changed bool
}

type StringSliceFlag struct {
FlagName string
FlagValue []string
Changed bool
}

type BoolFlag struct {
FlagName string
FlagValue bool
Expand Down Expand Up @@ -116,6 +122,9 @@ func CreateTestCobraCommandWithFlagValue(t *testing.T, flagAndValues ...any) *co
case StringFlag:
c.Flags().String(f.FlagName, f.FlagValue, "")
c.Flag(f.FlagName).Changed = f.Changed
case StringSliceFlag:
c.Flags().StringSlice(f.FlagName, f.FlagValue, "")
c.Flag(f.FlagName).Changed = f.Changed
case BoolFlag:
c.Flags().Bool(f.FlagName, f.FlagValue, "")
c.Flag(f.FlagName).Changed = f.Changed
Expand Down
Loading