diff --git a/config/config.go b/config/config.go index 62ec6f2..1bcfeb5 100644 --- a/config/config.go +++ b/config/config.go @@ -7,6 +7,9 @@ import ( "net/url" "os" + "crypto/tls" + "crypto/x509" + "gopkg.in/yaml.v3" ) @@ -45,7 +48,10 @@ type DynamicClientRegistration struct { } type DexGRPCClient struct { - Addr string `yaml:"addr"` + Addr string `yaml:"addr"` + TLSCert string `yaml:"tlsCert,omitempty"` + TLSKey string `yaml:"tlsKey,omitempty"` + TLSClientCA string `yaml:"tlsClientCA,omitempty"` } type Proxy struct { @@ -133,6 +139,50 @@ func (c *Config) YAMLString() (string, error) { } } +func (g *DexGRPCClient) ClientTLSConfig() (*tls.Config, error) { + // Check if TLS fields are set - must be all or nothing + tlsFieldsSet := 0 + if g.TLSCert != "" { + tlsFieldsSet++ + } + if g.TLSKey != "" { + tlsFieldsSet++ + } + if g.TLSClientCA != "" { + tlsFieldsSet++ + } + + if tlsFieldsSet == 0 { + // No TLS configured - return nil for insecure connection + return nil, nil + } + + if tlsFieldsSet != 3 { + return nil, fmt.Errorf("all three TLS fields (tlsCert, tlsKey, tlsClientCA) must be set together or all left empty") + } + + // All three fields are set - configure mTLS + cPool := x509.NewCertPool() + caCert, err := os.ReadFile(g.TLSClientCA) + if err != nil { + return nil, fmt.Errorf("failed to read TLS client CA: %w", err) + } + if !cPool.AppendCertsFromPEM(caCert) { + return nil, fmt.Errorf("failed to append TLS client CA certificate") + } + + clientCert, err := tls.LoadX509KeyPair(g.TLSCert, g.TLSKey) + if err != nil { + return nil, fmt.Errorf("failed to load TLS certificate and key: %w", err) + } + + clientTLSConfig := &tls.Config{ + RootCAs: cPool, + Certificates: []tls.Certificate{clientCert}, + } + return clientTLSConfig, nil +} + func (c *Config) Validate() error { if c.Host == nil { return fmt.Errorf("host is required") @@ -150,6 +200,12 @@ func (c *Config) Validate() error { if c.DexGRPCClient == nil || c.DexGRPCClient.Addr == "" { return fmt.Errorf("dexGRPCClient is required when dynamicClientRegistrationEnabled is true") } + + _, err := c.DexGRPCClient.ClientTLSConfig() + if err != nil { + return fmt.Errorf("dexGRPCClient TLS configuration is invalid: %w", err) + } + } return nil diff --git a/oauth/dynamic_client_registration.go b/oauth/dynamic_client_registration.go index dfad4f5..6f16af2 100644 --- a/oauth/dynamic_client_registration.go +++ b/oauth/dynamic_client_registration.go @@ -11,6 +11,7 @@ import ( "github.com/hyprmcp/mcp-gateway/config" "github.com/hyprmcp/mcp-gateway/log" "google.golang.org/grpc" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" ) @@ -27,9 +28,22 @@ type ClientInformation struct { } func NewDynamicClientRegistrationHandler(config *config.Config, meta map[string]any) (http.Handler, error) { + clientTLSConfig, err := config.DexGRPCClient.ClientTLSConfig() + if err != nil { + return nil, err + } + + var creds credentials.TransportCredentials + + if clientTLSConfig != nil { + creds = credentials.NewTLS(clientTLSConfig) + } else { + creds = insecure.NewCredentials() + } + grpcClient, err := grpc.NewClient( config.DexGRPCClient.Addr, - grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithTransportCredentials(creds), ) if err != nil { return nil, err