Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions internal/cli/access.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cli

import (
"context"
"encoding/json"
"fmt"
"net"
Expand Down Expand Up @@ -54,39 +55,40 @@ func (a *Access) Run(cli *CLI, version string) error {
return fmt.Errorf("cannot init network: %w", err)
}

wg := &sync.WaitGroup{}
ctx, cancel := context.WithTimeout(context.Background(), getIPTimeout)
defer cancel()

wg := &sync.WaitGroup{}
wg.Go(func() {
ip := a.PublicIPv4

if ip == nil {
ip = conf.PublicIPv4.Get(nil)
}

if ip == nil {
ip = getIP(ntw, "tcp4")
ip, _ = getIP(ctx, ntw, "tcp4")
}

if ip != nil {
ip = ip.To4()
resp.IPv4 = a.makeURLs(conf, ip.To4())
}

resp.IPv4 = a.makeURLs(conf, ip)
})
wg.Go(func() {
ip := a.PublicIPv6

if ip == nil {
ip = conf.PublicIPv6.Get(nil)
}

if ip == nil {
ip = getIP(ntw, "tcp6")
ip, _ = getIP(ctx, ntw, "tcp6")
}

if ip != nil {
ip = ip.To16()
resp.IPv6 = a.makeURLs(conf, ip.To16())
}

resp.IPv6 = a.makeURLs(conf, ip)
})

wg.Wait()

encoder := json.NewEncoder(os.Stdout)
Expand Down
102 changes: 63 additions & 39 deletions internal/cli/doctor.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,54 +24,76 @@ import (
)

var (
funcs = template.FuncMap{
"join": strings.Join,
}

tplError = template.Must(
template.New("").Parse(" ‼️ {{ .description }}: {{ .error }}\n"),
template.New("").
Funcs(funcs).
Parse(" ‼️ {{ .description }}: {{ .error }}\n"),
)

tplWDeprecatedConfig = template.Must(
template.New("").
Funcs(funcs).
Parse(` ⚠️ Option {{ .old | printf "%q" }}{{ if .old_section }} from section [{{ .old_section }}]{{ end }} is deprecated and will be removed in v{{ .when }}. Please use {{ .new | printf "%q" }}{{ if .new_section }} in [{{ .new_section }}] section{{ end }} instead.` + "\n"),
)

tplOTimeSkewness = template.Must(
template.New("").
Funcs(funcs).
Parse(" ✅ Time drift is {{ .drift }}, but tolerate-time-skewness is {{ .value }}\n"),
)
tplWTimeSkewness = template.Must(
template.New("").
Funcs(funcs).
Parse(" ⚠️ Time drift is {{ .drift }}, but tolerate-time-skewness is {{ .value }}. Please check ntp.\n"),
)
tplETimeSkewness = template.Must(
template.New("").
Funcs(funcs).
Parse(" ❌ Time drift is {{ .drift }}, but tolerate-time-skewness is {{ .value }}. You will get many rejected connections!\n"),
)

tplODCConnect = template.Must(
template.New("").Parse(" ✅ DC {{ .dc }} (rpc {{ .rtt }})\n"),
template.New("").
Funcs(funcs).
Parse(" ✅ DC {{ .dc }} (rpc {{ .rtt }})\n"),
)
tplEDCConnect = template.Must(
template.New("").Parse(" ❌ DC {{ .dc }}: {{ .error }}\n"),
template.New("").
Funcs(funcs).
Parse(" ❌ DC {{ .dc }}: {{ .error }}\n"),
)

tplODNSSNIMatch = template.Must(
template.New("").Parse(" ✅ IP address {{ .ip }} matches secret hostname {{ .hostname }}\n"),
template.New("").
Funcs(funcs).
Parse(" ✅ IP address {{ .ip }} matches secret hostname {{ .hostname }}\n"),
)
tplEDNSSNIMatch = template.Must(
template.New("").Parse(" ❌ Hostname {{ .hostname }} {{ if .resolved }}resolves to {{ .resolved }}, but the proxy's public IP is {{ if .ip4 }}{{ .ip4 }}{{ else }}<not detected>{{ end }} (IPv4) / {{ if .ip6 }}{{ .ip6 }}{{ else }}<not detected>{{ end }} (IPv6) — none of the resolved addresses match{{ else }}cannot be resolved to any host{{ end }}\n"),
template.New("").
Funcs(funcs).
Parse(` ❌ Hostname {{ .hostname }} resolves to {{ join ", " .resolved }} but public IP is {{ .ip }}` + "\n"),
)

tplOFrontingDomain = template.Must(
template.New("").Parse(" ✅ {{ .address }} is reachable\n"),
template.New("").
Funcs(funcs).
Parse(" ✅ {{ .address }} is reachable\n"),
)
tplEFrontingDomain = template.Must(
template.New("").Parse(" ❌ {{ .address }}: {{ .error }}\n"),
template.New("").
Funcs(funcs).
Parse(" ❌ {{ .address }}: {{ .error }}\n"),
)
)

type Doctor struct {
conf *config.Config

ConfigPath string `kong:"arg,required,type='existingfile',help='Path to the configuration file.',name='config-path'"` //nolint: lll
ConfigPath string `kong:"arg,required,type='existingfile',help='Path to the configuration file.',name='config-path'"` //nolint: lll
SkipNativeCheck bool `kong:"help='Skip the native network connectivity check (useful when proxy chaining is configured and direct egress is not expected to work).',name='skip-native-check'"` //nolint: lll
}

Expand Down Expand Up @@ -371,53 +393,55 @@ func (d *Doctor) checkFrontingDomain(ntw mtglib.Network) bool {
}

func (d *Doctor) checkSecretHost(resolver *net.Resolver, ntw mtglib.Network) bool {
res := runSNICheck(context.Background(), resolver, d.conf, ntw)

if res.ResolveErr != nil {
res, err := runSNICheck(context.Background(), d.conf, resolver, ntw)
if err != nil {
tplError.Execute(os.Stdout, map[string]any{ //nolint: errcheck
"description": fmt.Sprintf("cannot resolve DNS name of %s", d.conf.Secret.Host),
"error": res.ResolveErr,
"error": err,
})
return false
}

if !res.PublicIPKnown() {
if res.OurIP4 == "" && res.OurIP6 == "" {
tplError.Execute(os.Stdout, map[string]any{ //nolint: errcheck
"description": "cannot detect public IP address",
"error": errors.New("cannot detect automatically and public-ipv4/public-ipv6 are not set in config"),
})
return false
}

if res.IPv4Match || res.IPv6Match {
var matched net.IP
ok := true

for _, ip := range res.Resolved {
if (res.OurIPv4 != nil && ip.String() == res.OurIPv4.String()) ||
(res.OurIPv6 != nil && ip.String() == res.OurIPv6.String()) {
matched = ip
break
}
if len(res.ResolvedIP4) > 0 {
if slices.Contains(res.ResolvedIP4, res.OurIP4) {
tplODNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck
"ip": res.OurIP4,
"hostname": d.conf.Secret.Host,
})
} else {
tplEDNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck
"ip": res.OurIP4,
"resolved": res.ResolvedIP4,
"hostname": d.conf.Secret.Host,
})
ok = false
}

tplODNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck
"ip": matched,
"hostname": d.conf.Secret.Host,
})
return true
}

strAddresses := make([]string, 0, len(res.Resolved))
for _, ip := range res.Resolved {
strAddresses = append(strAddresses, `"`+ip.String()+`"`)
if len(res.ResolvedIP6) > 0 {
if slices.Contains(res.ResolvedIP6, res.OurIP6) {
tplODNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck
"ip": res.OurIP6,
"hostname": d.conf.Secret.Host,
})
} else {
tplEDNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck
"ip": res.OurIP6,
"resolved": res.ResolvedIP6,
"hostname": d.conf.Secret.Host,
})
ok = false
}
}

tplEDNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck
"hostname": d.conf.Secret.Host,
"resolved": strings.Join(strAddresses, ", "),
"ip4": res.OurIPv4,
"ip6": res.OurIPv6,
})

return false
return ok
}
138 changes: 138 additions & 0 deletions internal/cli/get_ip.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
package cli

import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"sync"
"time"

"github.com/9seconds/mtg/v2/essentials"
"github.com/9seconds/mtg/v2/mtglib"
)

const (
getIPTimeout = 5 * time.Second
)

var getIPServicesPlain = []string{
"https://ifconfig.co",
"https://ifconfig.me",
"https://api.ipify.org",
"https://ipecho.net/plain",
}

func getIP(ctx context.Context, ntw mtglib.Network, protocol string) (net.IP, error) {
ctx, cancel := context.WithTimeout(ctx, getIPTimeout)
defer cancel()

ctx, cancelCause := context.WithCancelCause(ctx)
defer cancelCause(nil)

var ip net.IP

rvChan := make(chan net.IP)
errChan := make(chan error)
errs := []error{}
wg := &sync.WaitGroup{}
dialer := ntw.NativeDialer()
client := ntw.MakeHTTPClient(func(_ context.Context, network, address string) (essentials.Conn, error) {
conn, err := dialer.DialContext(ctx, protocol, address)
if err != nil {
return nil, err
}
return essentials.WrapNetConn(conn), err
})

for _, url := range getIPServicesPlain {
wg.Go(func() {
lErrChan := errChan
rChan := rvChan

ip, err := getIPAddressPlain(ctx, client, url)
if err == nil {
lErrChan = nil
} else {
rChan = nil
}

select {
case <-ctx.Done():
case lErrChan <- fmt.Errorf("%s: %w", url, err):
case rChan <- ip:
}
})
}

wg.Go(func() {
defer cancelCause(nil)

for {
select {
case <-ctx.Done():
return
case foundIP := <-rvChan:
ip = foundIP
return
case err := <-errChan:
errs = append(errs, err)
if len(errs) == len(getIPServicesPlain) {
cancelCause(fmt.Errorf(
"cannot resolve %s address: %w",
protocol,
errors.Join(errs...),
))
}
}
}
})

wg.Wait()

if ip != nil {
return ip, nil
}

return nil, context.Cause(ctx)
}

func getIPAddressPlain(ctx context.Context, client *http.Client, address string) (net.IP, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, address, nil)
if err != nil {
panic(err)
}

req.Header.Add("Accept", "text/plain")

resp, err := client.Do(req)
if err != nil {
return nil, err
}

defer func() {
io.Copy(io.Discard, resp.Body) //nolint: errcheck
resp.Body.Close() //nolint: errcheck
}()

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status code %d", resp.StatusCode)
}

data, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}

data = bytes.TrimSpace(data)
ip := net.ParseIP(string(data))

if ip == nil {
return nil, errors.New("cannot parse as IP address")
}

return ip, nil
}
Loading
Loading