diff --git a/internal/cli/doctor.go b/internal/cli/doctor.go index 50876bd2c..bcbdfd238 100644 --- a/internal/cli/doctor.go +++ b/internal/cli/doctor.go @@ -2,6 +2,8 @@ package cli import ( "context" + "crypto/tls" + "crypto/x509" "errors" "fmt" "maps" @@ -66,6 +68,16 @@ var ( tplEFrontingDomain = template.Must( template.New("").Parse(" ❌ {{ .address }}: {{ .error }}\n"), ) + + tplOFrontingTLS = template.Must( + template.New("").Parse(" ✅ TLS certificate for {{ .host }} is valid\n"), + ) + tplEFrontingTLS = template.Must( + template.New("").Parse(" ❌ TLS certificate for {{ .host }} is invalid: {{ .error }}\n"), + ) + tplSFrontingTLS = template.Must( + template.New("").Parse(" ⏭ TLS certificate check skipped: proxy-protocol is enabled (the listener expects a PROXY header that mtg doctor does not send yet)\n"), + ) ) type Doctor struct { @@ -339,13 +351,19 @@ func (d *Doctor) checkNetworkAddresses(ntw mtglib.Network, dc int, addresses []s } func (d *Doctor) checkFrontingDomain(ntw mtglib.Network) bool { - host := d.conf.Secret.Host + // SNI must always be the secret host: that is what domain fronting puts on + // the wire and what the certificate is issued for. The TCP target may be a + // different address when domain-fronting.host overrides it (in the + // sni-router setup it is an internal name like "web"). 
+ sniHost := d.conf.Secret.Host + + dialHost := sniHost if override := d.conf.GetDomainFrontingHost(); override != "" { - host = override + dialHost = override } port := d.conf.GetDomainFrontingPort(mtglib.DefaultDomainFrontingPort) - address := net.JoinHostPort(host, strconv.Itoa(int(port))) + address := net.JoinHostPort(dialHost, strconv.Itoa(int(port))) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() @@ -367,9 +385,71 @@ func (d *Doctor) checkFrontingDomain(ntw mtglib.Network) bool { "address": address, }) + // With proxy-protocol enabled the fronting listener expects a PROXY header + // before the TLS ClientHello, so a bare TLS handshake would hang or be + // rejected and report a misleading failure. mtg doctor does not emit that + // header yet, so skip the certificate probe rather than print a false + // negative. See issue #518. + if d.conf.GetDomainFrontingProxyProtocol(false) { + tplSFrontingTLS.Execute(os.Stdout, nil) //nolint: errcheck + return true + } + + // A default crypto/tls client handshake against the fronting endpoint with + // ServerName = secret host validates the whole certificate in one shot: + // chain against the system roots, leaf SAN against the secret host, and + // validity period. An expired / untrusted / wrong-host certificate all + // surface as descriptive x509 errors. + if err := probeFrontingTLS(ctx, dialer, address, sniHost, nil); err != nil { + tplEFrontingTLS.Execute(os.Stdout, map[string]any{ //nolint: errcheck + "host": sniHost, + "error": err, + }) + return false + } + + tplOFrontingTLS.Execute(os.Stdout, map[string]any{ //nolint: errcheck + "host": sniHost, + }) + return true } +// probeFrontingTLS dials dialAddress over TCP and performs a TLS handshake +// presenting sniHost as the SNI / ServerName. 
// probeFrontingTLS dials dialAddress over TCP and performs a TLS handshake
// that presents sniHost as the SNI / ServerName. Certificate verification is
// left at the crypto/tls default (InsecureSkipVerify is not set), so the
// handshake fails with a descriptive x509 error if the certificate chain is
// untrusted, the leaf does not cover sniHost, or the certificate is expired or
// not yet valid.
//
// dialer may be nil, in which case a zero-value net.Dialer is used.
//
// rootCAs overrides the trust anchors; a nil pool makes crypto/tls fall back
// to the host's root CA set. Tests pass a pool containing a self-signed
// anchor.
//
// An empty sniHost is rejected before any connection is opened: crypto/tls
// would refuse such a configuration anyway, only after a pointless dial.
//
// Every returned error wraps its cause with %w, so errors.Is / errors.As still
// reach the underlying net or x509 error.
func probeFrontingTLS(
	ctx context.Context,
	dialer *net.Dialer,
	dialAddress string,
	sniHost string,
	rootCAs *x509.CertPool,
) error {
	if sniHost == "" {
		return errors.New("cannot verify TLS certificate: sni host is empty")
	}

	if dialer == nil {
		dialer = &net.Dialer{}
	}

	conn, err := dialer.DialContext(ctx, "tcp", dialAddress)
	if err != nil {
		return fmt.Errorf("cannot dial %s: %w", dialAddress, err)
	}

	tlsConn := tls.Client(conn, &tls.Config{
		ServerName: sniHost,
		RootCAs:    rootCAs,
		MinVersion: tls.VersionTLS12,
	})
	// Closing the TLS connection also closes the underlying TCP connection,
	// so a single deferred Close releases the socket on every path.
	defer tlsConn.Close() //nolint: errcheck

	// Bound all I/O on the connection (including the close_notify sent by
	// Close) with the caller's deadline, if it has one.
	if deadline, ok := ctx.Deadline(); ok {
		tlsConn.SetDeadline(deadline) //nolint: errcheck
	}

	if err := tlsConn.HandshakeContext(ctx); err != nil {
		// Name the endpoint so that non-certificate failures (EOF, timeout,
		// protocol alerts) are attributable to the address that was probed.
		return fmt.Errorf("tls handshake with %s: %w", dialAddress, err)
	}

	return nil
}
// makeCert builds a self-signed certificate for dnsName that expires at
// notAfter. It returns the certificate as a tls.Certificate (with Leaf
// populated, ready to be placed in a server's tls.Config) together with an
// x509 pool that trusts exactly that certificate.
//
// Besides dnsName, the certificate lists the IPv4 and IPv6 loopback addresses
// as SANs. It is marked as a CA with the cert-sign key usage so the same
// certificate can serve as its own trust anchor.
//
// NotBefore is normally one hour in the past. When notAfter is not later than
// that (the expired-certificate case), NotBefore is moved to one hour before
// notAfter so the validity window is never empty or inverted: the result is a
// certificate that was valid once and has since expired.
//
// Any failure aborts the calling test via t.Fatalf.
func makeCert(t *testing.T, dnsName string, notAfter time.Time) (tls.Certificate, *x509.CertPool) {
	t.Helper()

	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatalf("generate key: %v", err)
	}

	notBefore := time.Now().Add(-time.Hour)
	if !notAfter.After(notBefore) {
		notBefore = notAfter.Add(-time.Hour)
	}

	tmpl := &x509.Certificate{
		SerialNumber:          big.NewInt(1),
		Subject:               pkix.Name{CommonName: dnsName},
		NotBefore:             notBefore,
		NotAfter:              notAfter,
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		IsCA:                  true,
		DNSNames:              []string{dnsName},
		IPAddresses:           []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
	}

	// Template and parent are the same certificate: self-signed.
	der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
	if err != nil {
		t.Fatalf("create certificate: %v", err)
	}

	leaf, err := x509.ParseCertificate(der)
	if err != nil {
		t.Fatalf("parse certificate: %v", err)
	}

	pool := x509.NewCertPool()
	pool.AddCert(leaf)

	return tls.Certificate{Certificate: [][]byte{der}, PrivateKey: key, Leaf: leaf}, pool
}
+ if tc, ok := conn.(*tls.Conn); ok { + _ = tc.Handshake() + } + _ = conn.Close() + }() + } + }() + + return ln.Addr().String() +} + +func TestProbeFrontingTLS_ValidCert(t *testing.T) { + cert, pool := makeCert(t, "front.example.org", time.Now().Add(24*time.Hour)) + addr := startTLSServer(t, cert) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + err := probeFrontingTLS(ctx, &net.Dialer{}, addr, "front.example.org", pool) + if err != nil { + t.Fatalf("expected success, got error: %v", err) + } +} + +func TestProbeFrontingTLS_WrongHost(t *testing.T) { + // Cert is for front.example.org, but we verify against other.example.org. + cert, pool := makeCert(t, "front.example.org", time.Now().Add(24*time.Hour)) + addr := startTLSServer(t, cert) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + err := probeFrontingTLS(ctx, &net.Dialer{}, addr, "other.example.org", pool) + if err == nil { + t.Fatal("expected SAN-mismatch failure, got success") + } + if !strings.Contains(err.Error(), "x509") { + t.Fatalf("expected x509 verification error, got: %v", err) + } +} + +func TestProbeFrontingTLS_UntrustedCA(t *testing.T) { + // Server cert is self-signed; we hand the client an empty pool that does + // not trust it. Default verification must reject. + cert, _ := makeCert(t, "front.example.org", time.Now().Add(24*time.Hour)) + addr := startTLSServer(t, cert) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + err := probeFrontingTLS(ctx, &net.Dialer{}, addr, "front.example.org", x509.NewCertPool()) + if err == nil { + t.Fatal("expected untrusted-CA failure, got success") + } + if !strings.Contains(err.Error(), "x509") { + t.Fatalf("expected x509 verification error, got: %v", err) + } +} + +func TestProbeFrontingTLS_ExpiredCert(t *testing.T) { + // Same trust anchor, but the cert is already expired. 
+ cert, pool := makeCert(t, "front.example.org", time.Now().Add(-time.Hour)) + addr := startTLSServer(t, cert) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + err := probeFrontingTLS(ctx, &net.Dialer{}, addr, "front.example.org", pool) + if err == nil { + t.Fatal("expected expiry failure, got success") + } + if !strings.Contains(err.Error(), "expired") { + t.Fatalf("expected expiry error, got: %v", err) + } +} + +func TestProbeFrontingTLS_OverrideDialDifferentFromSNI(t *testing.T) { + // Domain-fronting override: dial one address (the listener bound to + // 127.0.0.1), but verify the cert against the secret host name. The cert + // is issued for the secret host, so verification must pass even though the + // dial target is a bare IP:port. + cert, pool := makeCert(t, "secret.example.org", time.Now().Add(24*time.Hour)) + addr := startTLSServer(t, cert) // e.g. 127.0.0.1:NNNNN + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // dial-target (addr, an IP:port) != SNI (secret.example.org) + err := probeFrontingTLS(ctx, &net.Dialer{}, addr, "secret.example.org", pool) + if err != nil { + t.Fatalf("expected success when dialing override addr with secret-host SNI, got: %v", err) + } +}