Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 83 additions & 3 deletions internal/cli/doctor.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package cli

import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"maps"
Expand Down Expand Up @@ -66,6 +68,16 @@ var (
tplEFrontingDomain = template.Must(
template.New("").Parse(" ❌ {{ .address }}: {{ .error }}\n"),
)

tplOFrontingTLS = template.Must(
template.New("").Parse(" ✅ TLS certificate for {{ .host }} is valid\n"),
)
tplEFrontingTLS = template.Must(
template.New("").Parse(" ❌ TLS certificate for {{ .host }} is invalid: {{ .error }}\n"),
)
tplSFrontingTLS = template.Must(
template.New("").Parse(" ⏭ TLS certificate check skipped: proxy-protocol is enabled (the listener expects a PROXY header that mtg doctor does not send yet)\n"),
)
)

type Doctor struct {
Expand Down Expand Up @@ -339,13 +351,19 @@ func (d *Doctor) checkNetworkAddresses(ntw mtglib.Network, dc int, addresses []s
}

func (d *Doctor) checkFrontingDomain(ntw mtglib.Network) bool {
host := d.conf.Secret.Host
// SNI must always be the secret host: that is what domain fronting puts on
// the wire and what the certificate is issued for. The TCP target may be a
// different address when domain-fronting.host overrides it (in the
// sni-router setup it is an internal name like "web").
sniHost := d.conf.Secret.Host

dialHost := sniHost
if override := d.conf.GetDomainFrontingHost(); override != "" {
host = override
dialHost = override
}

port := d.conf.GetDomainFrontingPort(mtglib.DefaultDomainFrontingPort)
address := net.JoinHostPort(host, strconv.Itoa(int(port)))
address := net.JoinHostPort(dialHost, strconv.Itoa(int(port)))

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
Expand All @@ -367,9 +385,71 @@ func (d *Doctor) checkFrontingDomain(ntw mtglib.Network) bool {
"address": address,
})

// With proxy-protocol enabled the fronting listener expects a PROXY header
// before the TLS ClientHello, so a bare TLS handshake would hang or be
// rejected and report a misleading failure. mtg doctor does not emit that
// header yet, so skip the certificate probe rather than print a false
// negative. See issue #518.
if d.conf.GetDomainFrontingProxyProtocol(false) {
tplSFrontingTLS.Execute(os.Stdout, nil) //nolint: errcheck
return true
}

// A default crypto/tls client handshake against the fronting endpoint with
// ServerName = secret host validates the whole certificate in one shot:
// chain against the system roots, leaf SAN against the secret host, and
// validity period. An expired / untrusted / wrong-host certificate all
// surface as descriptive x509 errors.
if err := probeFrontingTLS(ctx, dialer, address, sniHost, nil); err != nil {
tplEFrontingTLS.Execute(os.Stdout, map[string]any{ //nolint: errcheck
"host": sniHost,
"error": err,
})
return false
}

tplOFrontingTLS.Execute(os.Stdout, map[string]any{ //nolint: errcheck
"host": sniHost,
})

return true
}

// probeFrontingTLS opens a TCP connection to dialAddress through dialer and
// runs a client-side TLS handshake over it, sending sniHost as the
// ServerName. InsecureSkipVerify is left unset, so crypto/tls performs its
// normal certificate verification against sniHost and a verification failure
// is returned as the handshake error.
//
// rootCAs is passed straight into tls.Config.RootCAs. Callers hand in nil to
// use the host's root CA set; tests supply a pool containing their own
// self-signed anchor.
//
// A dial failure is returned wrapped as "cannot dial <address>: ...". A
// handshake failure is returned as-is. The connection is always closed before
// returning.
func probeFrontingTLS(
	ctx context.Context,
	dialer *net.Dialer,
	dialAddress string,
	sniHost string,
	rootCAs *x509.CertPool,
) error {
	clientConf := &tls.Config{
		ServerName: sniHost,
		RootCAs:    rootCAs,
		MinVersion: tls.VersionTLS12,
	}

	rawConn, dialErr := dialer.DialContext(ctx, "tcp", dialAddress)
	if dialErr != nil {
		return fmt.Errorf("cannot dial %s: %w", dialAddress, dialErr)
	}
	defer rawConn.Close() //nolint: errcheck

	// Mirror the context deadline onto the socket so reads and writes on the
	// raw connection are bounded by it as well.
	deadline, hasDeadline := ctx.Deadline()
	if hasDeadline {
		rawConn.SetDeadline(deadline) //nolint: errcheck
	}

	secured := tls.Client(rawConn, clientConf)
	defer secured.Close() //nolint: errcheck

	handshakeErr := secured.HandshakeContext(ctx)

	return handshakeErr
}

func (d *Doctor) checkSecretHost(resolver *net.Resolver, ntw mtglib.Network) bool {
res := runSNICheck(context.Background(), resolver, d.conf, ntw)

Expand Down
174 changes: 174 additions & 0 deletions internal/cli/doctor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
package cli

import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"math/big"
"net"
"strings"
"testing"
"time"
)

// makeCert builds a self-signed certificate for tests and returns it as a
// tls.Certificate (with Leaf populated) together with an x509.CertPool that
// contains that same certificate, so the pool can be used as the client's
// trust anchor.
//
// The certificate lists dnsName as its only DNS SAN and the IPv4 and IPv6
// loopback addresses as IP SANs. It is marked as a CA so that it can act as
// its own root in the returned pool.
//
// notAfter is the end of the validity period. The start is normally one hour
// in the past; when notAfter is not later than that (the "already expired"
// case), the start is moved to one hour before notAfter so the certificate
// always carries a well-formed NotBefore < NotAfter window and is rejected for
// being expired rather than for having an empty or inverted validity period.
//
// Any failure aborts the test via t.Fatalf.
func makeCert(t *testing.T, dnsName string, notAfter time.Time) (tls.Certificate, *x509.CertPool) {
	t.Helper()

	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatalf("generate key: %v", err)
	}

	notBefore := time.Now().Add(-time.Hour)
	if !notAfter.After(notBefore) {
		// Keep the window non-empty even for certificates that must already
		// be expired at creation time.
		notBefore = notAfter.Add(-time.Hour)
	}

	tmpl := &x509.Certificate{
		SerialNumber:          big.NewInt(1),
		Subject:               pkix.Name{CommonName: dnsName},
		NotBefore:             notBefore,
		NotAfter:              notAfter,
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		IsCA:                  true,
		DNSNames:              []string{dnsName},
		IPAddresses:           []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
	}

	// Template and parent are the same certificate: it signs itself.
	der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
	if err != nil {
		t.Fatalf("create certificate: %v", err)
	}

	leaf, err := x509.ParseCertificate(der)
	if err != nil {
		t.Fatalf("parse certificate: %v", err)
	}

	pool := x509.NewCertPool()
	pool.AddCert(leaf)

	return tls.Certificate{Certificate: [][]byte{der}, PrivateKey: key, Leaf: leaf}, pool
}

// startTLSServer opens a TLS listener on an ephemeral 127.0.0.1 port that
// presents cert, and returns the listener's address. Every accepted
// connection is handshaked and then closed straight away. The listener is
// closed through t.Cleanup, which also ends the accept loop.
func startTLSServer(t *testing.T, cert tls.Certificate) string {
	t.Helper()

	serverConf := &tls.Config{
		MinVersion:   tls.VersionTLS12,
		Certificates: []tls.Certificate{cert},
	}

	listener, err := tls.Listen("tcp", "127.0.0.1:0", serverConf)
	if err != nil {
		t.Fatalf("listen: %v", err)
	}

	t.Cleanup(func() { _ = listener.Close() })

	go func() {
		for {
			accepted, acceptErr := listener.Accept()
			if acceptErr != nil {
				// Accept fails once the listener has been closed.
				return
			}

			go func() {
				defer func() { _ = accepted.Close() }()

				tlsConn, isTLS := accepted.(*tls.Conn)
				if !isTLS {
					return
				}

				// Run the server half of the handshake so the client's
				// handshake can finish; the result is irrelevant here.
				_ = tlsConn.Handshake()
			}()
		}
	}()

	return listener.Addr().String()
}

// TestProbeFrontingTLS_ValidCert checks the happy path: the server presents a
// currently valid certificate for the probed host name and the client trusts
// it, so the probe must return nil.
func TestProbeFrontingTLS_ValidCert(t *testing.T) {
	const frontHost = "front.example.org"

	serverCert, trusted := makeCert(t, frontHost, time.Now().Add(24*time.Hour))
	listenAddr := startTLSServer(t, serverCert)

	probeCtx, stop := context.WithTimeout(context.Background(), 5*time.Second)
	defer stop()

	if err := probeFrontingTLS(probeCtx, &net.Dialer{}, listenAddr, frontHost, trusted); err != nil {
		t.Fatalf("expected success, got error: %v", err)
	}
}

// TestProbeFrontingTLS_WrongHost presents a trusted certificate issued for
// front.example.org while probing with other.example.org as the server name.
// The probe must fail with an x509 error.
func TestProbeFrontingTLS_WrongHost(t *testing.T) {
	const (
		certHost   = "front.example.org"
		probedHost = "other.example.org"
	)

	serverCert, trusted := makeCert(t, certHost, time.Now().Add(24*time.Hour))
	listenAddr := startTLSServer(t, serverCert)

	probeCtx, stop := context.WithTimeout(context.Background(), 5*time.Second)
	defer stop()

	switch err := probeFrontingTLS(probeCtx, &net.Dialer{}, listenAddr, probedHost, trusted); {
	case err == nil:
		t.Fatal("expected SAN-mismatch failure, got success")
	case !strings.Contains(err.Error(), "x509"):
		t.Fatalf("expected x509 verification error, got: %v", err)
	}
}

// TestProbeFrontingTLS_UntrustedCA serves a self-signed certificate but gives
// the client an empty root pool, so nothing anchors the chain. The probe must
// fail with an x509 error.
func TestProbeFrontingTLS_UntrustedCA(t *testing.T) {
	const frontHost = "front.example.org"

	// The pool returned by makeCert is deliberately discarded.
	serverCert, _ := makeCert(t, frontHost, time.Now().Add(24*time.Hour))
	listenAddr := startTLSServer(t, serverCert)

	probeCtx, stop := context.WithTimeout(context.Background(), 5*time.Second)
	defer stop()

	emptyRoots := x509.NewCertPool()

	switch err := probeFrontingTLS(probeCtx, &net.Dialer{}, listenAddr, frontHost, emptyRoots); {
	case err == nil:
		t.Fatal("expected untrusted-CA failure, got success")
	case !strings.Contains(err.Error(), "x509"):
		t.Fatalf("expected x509 verification error, got: %v", err)
	}
}

// TestProbeFrontingTLS_ExpiredCert serves a certificate that the client
// trusts and that names the right host, but whose NotAfter is one hour in the
// past. The probe must fail with an error mentioning expiry.
func TestProbeFrontingTLS_ExpiredCert(t *testing.T) {
	const frontHost = "front.example.org"

	alreadyExpired := time.Now().Add(-time.Hour)
	serverCert, trusted := makeCert(t, frontHost, alreadyExpired)
	listenAddr := startTLSServer(t, serverCert)

	probeCtx, stop := context.WithTimeout(context.Background(), 5*time.Second)
	defer stop()

	switch err := probeFrontingTLS(probeCtx, &net.Dialer{}, listenAddr, frontHost, trusted); {
	case err == nil:
		t.Fatal("expected expiry failure, got success")
	case !strings.Contains(err.Error(), "expired"):
		t.Fatalf("expected expiry error, got: %v", err)
	}
}

// TestProbeFrontingTLS_OverrideDialDifferentFromSNI models the
// domain-fronting override: the TCP target is the listener's bare IP:port
// while the server name used for verification is the secret host. The
// certificate is issued for the secret host, so the probe must succeed even
// though the dial target and the server name differ.
func TestProbeFrontingTLS_OverrideDialDifferentFromSNI(t *testing.T) {
	const secretHost = "secret.example.org"

	serverCert, trusted := makeCert(t, secretHost, time.Now().Add(24*time.Hour))

	// Of the form 127.0.0.1:NNNNN, i.e. not the name in the certificate.
	dialTarget := startTLSServer(t, serverCert)

	probeCtx, stop := context.WithTimeout(context.Background(), 5*time.Second)
	defer stop()

	if err := probeFrontingTLS(probeCtx, &net.Dialer{}, dialTarget, secretHost, trusted); err != nil {
		t.Fatalf("expected success when dialing override addr with secret-host SNI, got: %v", err)
	}
}
Loading