diff --git a/internal/api/api.go b/internal/api/api.go index 4dfad40d3b..e0bc234558 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -87,24 +87,10 @@ type ClientOpts struct { ProxyPath string } -// NewClient creates a new API client. -func NewClient(opts ClientOpts) Client { - if opts.Out == nil { - panic("unexpected nil out option") - } - - flags := opts.Flags - if flags == nil { - flags = defaultFlags() - } - - httpClient := http.DefaultClient - +func buildTransport(opts ClientOpts, flags *Flags) *http.Transport { transport := http.DefaultTransport.(*http.Transport).Clone() - customTransport := false if flags.insecureSkipVerify != nil && *flags.insecureSkipVerify { - customTransport = true transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} } @@ -112,14 +98,28 @@ func NewClient(opts ClientOpts) Client { transport.TLSClientConfig = &tls.Config{} } - if applyProxy(transport, opts.ProxyURL, opts.ProxyPath) { - customTransport = true + if opts.ProxyURL != nil || opts.ProxyPath != "" { + transport = withProxyTransport(transport, opts.ProxyURL, opts.ProxyPath) } - if customTransport { - httpClient = &http.Client{ - Transport: transport, - } + return transport +} + +// NewClient creates a new API client. +func NewClient(opts ClientOpts) Client { + if opts.Out == nil { + panic("unexpected nil out option") + } + + flags := opts.Flags + if flags == nil { + flags = defaultFlags() + } + + transport := buildTransport(opts, flags) + + httpClient := &http.Client{ + Transport: transport, } return &client{ diff --git a/internal/api/api_test.go b/internal/api/api_test.go deleted file mode 100644 index 0aaf3ec14a..0000000000 --- a/internal/api/api_test.go +++ /dev/null @@ -1,3 +0,0 @@ -package api - -// TODO: implement a super basic GraphQL server that can return canned results. diff --git a/internal/api/proxy.go b/internal/api/proxy.go index 3cf8673829..9589b9beb5 100644 --- a/internal/api/proxy.go +++ b/internal/api/proxy.go @@ -11,11 +11,13 @@ import ( "net/url" ) -func applyProxy(transport *http.Transport, proxyURL *url.URL, proxyPath string) (applied bool) { - if proxyURL == nil && proxyPath == "" { - return false - } - +// withProxyTransport modifies the given transport to handle proxying of unix, socks5 and http connections. +// +// Note: baseTransport is considered to be a clone created with transport.Clone() +// +// - If a the proxyPath is not empty, a unix socket proxy is created. +// - Otherwise, the proxyURL is used to determine if we should proxy socks5 / http connections +func withProxyTransport(baseTransport *http.Transport, proxyURL *url.URL, proxyPath string) *http.Transport { handshakeTLS := func(ctx context.Context, conn net.Conn, addr string) (net.Conn, error) { // Extract the hostname (without the port) for TLS SNI host, _, err := net.SplitHostPort(addr) @@ -26,7 +28,7 @@ func applyProxy(transport *http.Transport, proxyURL *url.URL, proxyPath string) ServerName: host, // Pull InsecureSkipVerify from the target host transport // so that insecure-skip-verify flag settings are honored for the proxy server - InsecureSkipVerify: transport.TLSClientConfig.InsecureSkipVerify, + InsecureSkipVerify: baseTransport.TLSClientConfig.InsecureSkipVerify, }) if err := tlsConn.HandshakeContext(ctx); err != nil { return nil, err @@ -34,8 +36,6 @@ func applyProxy(transport *http.Transport, proxyURL *url.URL, proxyPath string) return tlsConn, nil } - proxyApplied := false - if proxyPath != "" { dial := func(ctx context.Context, _, _ string) (net.Conn, error) { d := net.Dialer{} @@ -48,17 +48,15 @@ func applyProxy(transport *http.Transport, proxyURL *url.URL, proxyPath string) } return handshakeTLS(ctx, conn, addr) } - transport.DialContext = dial - transport.DialTLSContext = dialTLS + baseTransport.DialContext = dial + baseTransport.DialTLSContext = dialTLS // clear out any system proxy settings - transport.Proxy = nil - proxyApplied = true + baseTransport.Proxy = nil } else if proxyURL != nil { switch proxyURL.Scheme { case "socks5", "socks5h": // SOCKS proxies work out of the box - no need to manually dial - transport.Proxy = http.ProxyURL(proxyURL) - proxyApplied = true + baseTransport.Proxy = http.ProxyURL(proxyURL) case "http", "https": dial := func(ctx context.Context, network, addr string) (net.Conn, error) { // Dial the proxy @@ -126,13 +124,12 @@ func applyProxy(transport *http.Transport, proxyURL *url.URL, proxyPath string) } return handshakeTLS(ctx, conn, addr) } - transport.DialContext = dial - transport.DialTLSContext = dialTLS + baseTransport.DialContext = dial + baseTransport.DialTLSContext = dialTLS // clear out any system proxy settings - transport.Proxy = nil - proxyApplied = true + baseTransport.Proxy = nil } } - return proxyApplied + return baseTransport }