diff --git a/internal/cli/access.go b/internal/cli/access.go index c93c97b19..00c96a993 100644 --- a/internal/cli/access.go +++ b/internal/cli/access.go @@ -1,6 +1,7 @@ package cli import ( + "context" "encoding/json" "fmt" "net" @@ -54,39 +55,40 @@ func (a *Access) Run(cli *CLI, version string) error { return fmt.Errorf("cannot init network: %w", err) } - wg := &sync.WaitGroup{} + ctx, cancel := context.WithTimeout(context.Background(), getIPTimeout) + defer cancel() + wg := &sync.WaitGroup{} wg.Go(func() { ip := a.PublicIPv4 + if ip == nil { ip = conf.PublicIPv4.Get(nil) } + if ip == nil { - ip = getIP(ntw, "tcp4") + ip, _ = getIP(ctx, ntw, "tcp4") } if ip != nil { - ip = ip.To4() + resp.IPv4 = a.makeURLs(conf, ip.To4()) } - - resp.IPv4 = a.makeURLs(conf, ip) }) wg.Go(func() { ip := a.PublicIPv6 + if ip == nil { ip = conf.PublicIPv6.Get(nil) } + if ip == nil { - ip = getIP(ntw, "tcp6") + ip, _ = getIP(ctx, ntw, "tcp6") } if ip != nil { - ip = ip.To16() + resp.IPv6 = a.makeURLs(conf, ip.To16()) } - - resp.IPv6 = a.makeURLs(conf, ip) }) - wg.Wait() encoder := json.NewEncoder(os.Stdout) diff --git a/internal/cli/doctor.go b/internal/cli/doctor.go index 50876bd2c..6b7f07445 100644 --- a/internal/cli/doctor.go +++ b/internal/cli/doctor.go @@ -24,54 +24,76 @@ import ( ) var ( + funcs = template.FuncMap{ + "join": strings.Join, + } + tplError = template.Must( - template.New("").Parse(" ‼️ {{ .description }}: {{ .error }}\n"), + template.New(""). + Funcs(funcs). + Parse(" ‼️ {{ .description }}: {{ .error }}\n"), ) tplWDeprecatedConfig = template.Must( template.New(""). + Funcs(funcs). Parse(` ⚠️ Option {{ .old | printf "%q" }}{{ if .old_section }} from section [{{ .old_section }}]{{ end }} is deprecated and will be removed in v{{ .when }}. Please use {{ .new | printf "%q" }}{{ if .new_section }} in [{{ .new_section }}] section{{ end }} instead.` + "\n"), ) tplOTimeSkewness = template.Must( template.New(""). + Funcs(funcs). Parse(" ✅ Time drift is {{ .drift }}, but tolerate-time-skewness is {{ .value }}\n"), ) tplWTimeSkewness = template.Must( template.New(""). + Funcs(funcs). Parse(" ⚠️ Time drift is {{ .drift }}, but tolerate-time-skewness is {{ .value }}. Please check ntp.\n"), ) tplETimeSkewness = template.Must( template.New(""). + Funcs(funcs). Parse(" ❌ Time drift is {{ .drift }}, but tolerate-time-skewness is {{ .value }}. You will get many rejected connections!\n"), ) tplODCConnect = template.Must( - template.New("").Parse(" ✅ DC {{ .dc }} (rpc {{ .rtt }})\n"), + template.New(""). + Funcs(funcs). + Parse(" ✅ DC {{ .dc }} (rpc {{ .rtt }})\n"), ) tplEDCConnect = template.Must( - template.New("").Parse(" ❌ DC {{ .dc }}: {{ .error }}\n"), + template.New(""). + Funcs(funcs). + Parse(" ❌ DC {{ .dc }}: {{ .error }}\n"), ) tplODNSSNIMatch = template.Must( - template.New("").Parse(" ✅ IP address {{ .ip }} matches secret hostname {{ .hostname }}\n"), + template.New(""). + Funcs(funcs). + Parse(" ✅ IP address {{ .ip }} matches secret hostname {{ .hostname }}\n"), ) tplEDNSSNIMatch = template.Must( - template.New("").Parse(" ❌ Hostname {{ .hostname }} {{ if .resolved }}resolves to {{ .resolved }}, but the proxy's public IP is {{ if .ip4 }}{{ .ip4 }}{{ else }}{{ end }} (IPv4) / {{ if .ip6 }}{{ .ip6 }}{{ else }}{{ end }} (IPv6) — none of the resolved addresses match{{ else }}cannot be resolved to any host{{ end }}\n"), + template.New(""). + Funcs(funcs). + Parse(` ❌ Hostname {{ .hostname }} resolves to {{ join ", " .resolved }} but public IP is {{ .ip }}` + "\n"), ) tplOFrontingDomain = template.Must( - template.New("").Parse(" ✅ {{ .address }} is reachable\n"), + template.New(""). + Funcs(funcs). + Parse(" ✅ {{ .address }} is reachable\n"), ) tplEFrontingDomain = template.Must( - template.New("").Parse(" ❌ {{ .address }}: {{ .error }}\n"), + template.New(""). + Funcs(funcs). + Parse(" ❌ {{ .address }}: {{ .error }}\n"), ) ) type Doctor struct { conf *config.Config - ConfigPath string `kong:"arg,required,type='existingfile',help='Path to the configuration file.',name='config-path'"` //nolint: lll + ConfigPath string `kong:"arg,required,type='existingfile',help='Path to the configuration file.',name='config-path'"` //nolint: lll SkipNativeCheck bool `kong:"help='Skip the native network connectivity check (useful when proxy chaining is configured and direct egress is not expected to work).',name='skip-native-check'"` //nolint: lll } @@ -371,17 +393,16 @@ func (d *Doctor) checkFrontingDomain(ntw mtglib.Network) bool { } func (d *Doctor) checkSecretHost(resolver *net.Resolver, ntw mtglib.Network) bool { - res := runSNICheck(context.Background(), resolver, d.conf, ntw) - - if res.ResolveErr != nil { + res, err := runSNICheck(context.Background(), d.conf, resolver, ntw) + if err != nil { tplError.Execute(os.Stdout, map[string]any{ //nolint: errcheck "description": fmt.Sprintf("cannot resolve DNS name of %s", d.conf.Secret.Host), - "error": res.ResolveErr, + "error": err, }) return false } - if !res.PublicIPKnown() { + if res.OurIP4 == "" && res.OurIP6 == "" { tplError.Execute(os.Stdout, map[string]any{ //nolint: errcheck "description": "cannot detect public IP address", "error": errors.New("cannot detect automatically and public-ipv4/public-ipv6 are not set in config"), @@ -389,35 +410,38 @@ func (d *Doctor) checkSecretHost(resolver *net.Resolver, ntw mtglib.Network) boo return false } - if res.IPv4Match || res.IPv6Match { - var matched net.IP + ok := true - for _, ip := range res.Resolved { - if (res.OurIPv4 != nil && ip.String() == res.OurIPv4.String()) || - (res.OurIPv6 != nil && ip.String() == res.OurIPv6.String()) { - matched = ip - break - } + if len(res.ResolvedIP4) > 0 { + if slices.Contains(res.ResolvedIP4, res.OurIP4) { + tplODNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck + "ip": res.OurIP4, + "hostname": d.conf.Secret.Host, + }) + } else { + tplEDNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck + "ip": res.OurIP4, + "resolved": res.ResolvedIP4, + "hostname": d.conf.Secret.Host, + }) + ok = false } - - tplODNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck - "ip": matched, - "hostname": d.conf.Secret.Host, - }) - return true } - - strAddresses := make([]string, 0, len(res.Resolved)) - for _, ip := range res.Resolved { - strAddresses = append(strAddresses, `"`+ip.String()+`"`) + if len(res.ResolvedIP6) > 0 { + if slices.Contains(res.ResolvedIP6, res.OurIP6) { + tplODNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck + "ip": res.OurIP6, + "hostname": d.conf.Secret.Host, + }) + } else { + tplEDNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck + "ip": res.OurIP6, + "resolved": res.ResolvedIP6, + "hostname": d.conf.Secret.Host, + }) + ok = false + } } - tplEDNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck - "hostname": d.conf.Secret.Host, - "resolved": strings.Join(strAddresses, ", "), - "ip4": res.OurIPv4, - "ip6": res.OurIPv6, - }) - - return false + return ok } diff --git a/internal/cli/get_ip.go b/internal/cli/get_ip.go new file mode 100644 index 000000000..bc70a80d3 --- /dev/null +++ b/internal/cli/get_ip.go @@ -0,0 +1,138 @@ +package cli + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net" + "net/http" + "sync" + "time" + + "github.com/9seconds/mtg/v2/essentials" + "github.com/9seconds/mtg/v2/mtglib" +) + +const ( + getIPTimeout = 5 * time.Second +) + +var getIPServicesPlain = []string{ + "https://ifconfig.co", + "https://ifconfig.me", + "https://api.ipify.org", + "https://ipecho.net/plain", +} + +func getIP(ctx context.Context, ntw mtglib.Network, protocol string) (net.IP, error) { + ctx, cancel := context.WithTimeout(ctx, getIPTimeout) + defer cancel() + + ctx, cancelCause := context.WithCancelCause(ctx) + defer cancelCause(nil) + + var ip net.IP + + rvChan := make(chan net.IP) + errChan := make(chan error) + errs := []error{} + wg := &sync.WaitGroup{} + dialer := ntw.NativeDialer() + client := ntw.MakeHTTPClient(func(_ context.Context, network, address string) (essentials.Conn, error) { + conn, err := dialer.DialContext(ctx, protocol, address) + if err != nil { + return nil, err + } + return essentials.WrapNetConn(conn), err + }) + + for _, url := range getIPServicesPlain { + wg.Go(func() { + lErrChan := errChan + rChan := rvChan + + ip, err := getIPAddressPlain(ctx, client, url) + if err == nil { + lErrChan = nil + } else { + rChan = nil + } + + select { + case <-ctx.Done(): + case lErrChan <- fmt.Errorf("%s: %w", url, err): + case rChan <- ip: + } + }) + } + + wg.Go(func() { + defer cancelCause(nil) + + for { + select { + case <-ctx.Done(): + return + case foundIP := <-rvChan: + ip = foundIP + return + case err := <-errChan: + errs = append(errs, err) + if len(errs) == len(getIPServicesPlain) { + cancelCause(fmt.Errorf( + "cannot resolve %s address: %w", + protocol, + errors.Join(errs...), + )) + } + } + } + }) + + wg.Wait() + + if ip != nil { + return ip, nil + } + + return nil, context.Cause(ctx) +} + +func getIPAddressPlain(ctx context.Context, client *http.Client, address string) (net.IP, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, address, nil) + if err != nil { + panic(err) + } + + req.Header.Add("Accept", "text/plain") + + resp, err := client.Do(req) + if err != nil { + return nil, err + } + + defer func() { + io.Copy(io.Discard, resp.Body) //nolint: errcheck + resp.Body.Close() //nolint: errcheck + }() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status code %d", resp.StatusCode) + } + + data, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + data = bytes.TrimSpace(data) + ip := net.ParseIP(string(data)) + + if ip == nil { + return nil, errors.New("cannot parse as IP address") + } + + return ip, nil +} diff --git a/internal/cli/run_proxy.go b/internal/cli/run_proxy.go index 704666381..f1ce8af8b 100644 --- a/internal/cli/run_proxy.go +++ b/internal/cli/run_proxy.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "os" + "slices" "strings" "github.com/9seconds/mtg/v2/antireplay" @@ -215,58 +216,32 @@ func warnSNIMismatch(conf *config.Config, ntw mtglib.Network, log mtglib.Logger) return } - res := runSNICheck(context.Background(), net.DefaultResolver, conf, ntw) + log = log.BindStr("hostname", host) - if res.ResolveErr != nil { - log.BindStr("hostname", host). - WarningError("SNI-DNS check: cannot resolve secret hostname", res.ResolveErr) + res, err := runSNICheck(context.Background(), conf, net.DefaultResolver, ntw) + if err != nil { + log.WarningError("SNI-DNS check: cannot resolve secret hostname", err) return } - if !res.PublicIPKnown() { + if res.OurIP4 == "" && res.OurIP6 == "" { log.Warning("SNI-DNS check: cannot detect public IP address; set public-ipv4/public-ipv6 in config or run 'mtg doctor'") return } - v4Match := res.OurIPv4 == nil || res.IPv4Match - v6Match := res.OurIPv6 == nil || res.IPv6Match - - if v4Match && v6Match { - return - } - - resolved := make([]string, 0, len(res.Resolved)) - for _, ip := range res.Resolved { - resolved = append(resolved, ip.String()) + if len(res.ResolvedIP4) > 0 && !slices.Contains(res.ResolvedIP4, res.OurIP4) { + log. + BindStr("public_ip", res.OurIP4). + BindStr("resolved", strings.Join(res.ResolvedIP4, ",")). + Warning("SNI-DNS check: address mismatch") } - our := "" - if res.OurIPv4 != nil { - our = res.OurIPv4.String() + if len(res.ResolvedIP6) > 0 && !slices.Contains(res.ResolvedIP6, res.OurIP6) { + log. + BindStr("public_ip", res.OurIP6). + BindStr("resolved", strings.Join(res.ResolvedIP6, ",")). + Warning("SNI-DNS check: address mismatch") } - - if res.OurIPv6 != nil { - if our != "" { - our += "/" - } - - our += res.OurIPv6.String() - } - - entry := log.BindStr("hostname", host). - BindStr("resolved", strings.Join(resolved, ", ")). - BindStr("public_ip", our) - - if res.OurIPv4 != nil { - entry = entry.BindStr("ipv4_match", fmt.Sprintf("%t", v4Match)) - } - - if res.OurIPv6 != nil { - entry = entry.BindStr("ipv6_match", fmt.Sprintf("%t", v6Match)) - } - - entry.Warning("SNI-DNS mismatch: secret hostname does not resolve to this server's public IP. " + - "DPI may detect and block the proxy. See 'mtg doctor' for details") } func warnDeprecatedDomainFronting(conf *config.Config, log mtglib.Logger) { diff --git a/internal/cli/sni_check.go b/internal/cli/sni_check.go index d4dbade9b..b92469cbe 100644 --- a/internal/cli/sni_check.go +++ b/internal/cli/sni_check.go @@ -2,77 +2,88 @@ package cli import ( "context" + "fmt" "net" + "sync" "github.com/9seconds/mtg/v2/internal/config" "github.com/9seconds/mtg/v2/mtglib" ) -// sniCheckResult holds the data gathered while comparing the secret -// hostname's DNS records against this server's public IP addresses. -// -// IPv4Match / IPv6Match report whether a resolved record actually equals the -// corresponding public IP. They are false when that family's public IP could -// not be determined — there is nothing to compare against. Callers decide -// what counts as a clean result from these fields: `mtg doctor` and the -// startup warning apply different rules. type sniCheckResult struct { - Resolved []net.IP - OurIPv4 net.IP - OurIPv6 net.IP - IPv4Match bool - IPv6Match bool - ResolveErr error + ResolvedIP4 []string + ResolvedIP6 []string + OurIP4 string + OurIP6 string } -// PublicIPKnown reports whether at least one public IP family was detected. -func (r sniCheckResult) PublicIPKnown() bool { - return r.OurIPv4 != nil || r.OurIPv6 != nil -} - -// runSNICheck resolves conf.Secret.Host and compares the records with this -// server's public IPv4 and IPv6. Public IPs come from config first and fall -// back to on-the-fly detection via ntw. It gathers data only — it does not -// decide success; see sniCheckResult. func runSNICheck( ctx context.Context, - resolver *net.Resolver, conf *config.Config, + resolver *net.Resolver, ntw mtglib.Network, -) sniCheckResult { +) (sniCheckResult, error) { res := sniCheckResult{} + ctx, cancelCause := context.WithCancelCause(ctx) + defer cancelCause(nil) + addrs, err := resolver.LookupIPAddr(ctx, conf.Secret.Host) if err != nil { - res.ResolveErr = err - - return res + return res, fmt.Errorf("cannot resolve addresses of %s: %w", conf.Secret.Host, err) } - res.Resolved = make([]net.IP, 0, len(addrs)) - for _, a := range addrs { - res.Resolved = append(res.Resolved, a.IP) + if len(addrs) == 0 { + return res, fmt.Errorf("no known addresses for %s", conf.Secret.Host) } - res.OurIPv4 = conf.PublicIPv4.Get(nil) - if res.OurIPv4 == nil { - res.OurIPv4 = getIP(ntw, "tcp4") + for _, addr := range addrs { + if ip := addr.IP.To4(); ip == nil { + res.ResolvedIP6 = append(res.ResolvedIP6, addr.IP.To16().String()) + } else { + res.ResolvedIP4 = append(res.ResolvedIP4, ip.String()) + } } - res.OurIPv6 = conf.PublicIPv6.Get(nil) - if res.OurIPv6 == nil { - res.OurIPv6 = getIP(ntw, "tcp6") + wg := &sync.WaitGroup{} + + if len(res.ResolvedIP4) > 0 { + wg.Go(func() { + var err error + + ip := conf.PublicIPv4.Get(nil) + if ip == nil { + ip, err = getIP(ctx, ntw, "tcp4") + if err != nil { + cancelCause(err) + } + } + + if ip != nil { + res.OurIP4 = ip.To4().String() + } + }) } - for _, ip := range res.Resolved { - if res.OurIPv4 != nil && ip.String() == res.OurIPv4.String() { - res.IPv4Match = true - } + if len(res.ResolvedIP6) > 0 { + wg.Go(func() { + var err error - if res.OurIPv6 != nil && ip.String() == res.OurIPv6.String() { - res.IPv6Match = true - } + ip := conf.PublicIPv6.Get(nil) + if ip == nil { + ip, err = getIP(ctx, ntw, "tcp6") + if err != nil { + cancelCause(err) + } + } + + if ip != nil { + res.OurIP6 = ip.To16().String() + } + }) } - return res + wg.Wait() + + return res, context.Cause(ctx) } diff --git a/internal/cli/utils.go b/internal/cli/utils.go deleted file mode 100644 index db8af549b..000000000 --- a/internal/cli/utils.go +++ /dev/null @@ -1,51 +0,0 @@ -package cli - -import ( - "context" - "io" - "net" - "net/http" - "strings" - - "github.com/9seconds/mtg/v2/essentials" - "github.com/9seconds/mtg/v2/mtglib" -) - -func getIP(ntw mtglib.Network, protocol string) net.IP { - dialer := ntw.NativeDialer() - client := ntw.MakeHTTPClient(func(ctx context.Context, network, address string) (essentials.Conn, error) { - conn, err := dialer.DialContext(ctx, protocol, address) - if err != nil { - return nil, err - } - return essentials.WrapNetConn(conn), err - }) - - req, err := http.NewRequest(http.MethodGet, "https://ifconfig.co", nil) //nolint: noctx - if err != nil { - panic(err) - } - - req.Header.Add("Accept", "text/plain") - - resp, err := client.Do(req) - if err != nil { - return nil - } - - if resp.StatusCode != http.StatusOK { - return nil - } - - defer func() { - io.Copy(io.Discard, resp.Body) //nolint: errcheck - resp.Body.Close() //nolint: errcheck - }() - - data, err := io.ReadAll(resp.Body) - if err != nil { - return nil - } - - return net.ParseIP(strings.TrimSpace(string(data))) -}