Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 21 additions & 21 deletions internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,39 +87,39 @@ type ClientOpts struct {
ProxyPath string
}

// NewClient creates a new API client.
func NewClient(opts ClientOpts) Client {
if opts.Out == nil {
panic("unexpected nil out option")
}

flags := opts.Flags
if flags == nil {
flags = defaultFlags()
}

httpClient := http.DefaultClient

func buildTransport(opts ClientOpts, flags *Flags) *http.Transport {
transport := http.DefaultTransport.(*http.Transport).Clone()
customTransport := false

if flags.insecureSkipVerify != nil && *flags.insecureSkipVerify {
customTransport = true
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
}

if transport.TLSClientConfig == nil {
transport.TLSClientConfig = &tls.Config{}
}

if applyProxy(transport, opts.ProxyURL, opts.ProxyPath) {
customTransport = true
if opts.ProxyURL != nil || opts.ProxyPath != "" {
transport = withProxyTransport(transport, opts.ProxyURL, opts.ProxyPath)
}

if customTransport {
httpClient = &http.Client{
Transport: transport,
}
return transport
}

// NewClient creates a new API client.
func NewClient(opts ClientOpts) Client {
if opts.Out == nil {
panic("unexpected nil out option")
}

flags := opts.Flags
if flags == nil {
flags = defaultFlags()
}

transport := buildTransport(opts, flags)

httpClient := &http.Client{
Transport: transport,
}

return &client{
Expand Down
3 changes: 0 additions & 3 deletions internal/api/api_test.go

This file was deleted.

35 changes: 16 additions & 19 deletions internal/api/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@ import (
"net/url"
)

func applyProxy(transport *http.Transport, proxyURL *url.URL, proxyPath string) (applied bool) {
if proxyURL == nil && proxyPath == "" {
return false
}

// withProxyTransport modifies the given transport to handle proxying of unix, socks5 and http connections.
//
// Note: baseTransport is considered to be a clone created with transport.Clone()
//
// - If a the proxyPath is not empty, a unix socket proxy is created.
// - Otherwise, the proxyURL is used to determine if we should proxy socks5 / http connections
func withProxyTransport(baseTransport *http.Transport, proxyURL *url.URL, proxyPath string) *http.Transport {
handshakeTLS := func(ctx context.Context, conn net.Conn, addr string) (net.Conn, error) {
// Extract the hostname (without the port) for TLS SNI
host, _, err := net.SplitHostPort(addr)
Expand All @@ -26,16 +28,14 @@ func applyProxy(transport *http.Transport, proxyURL *url.URL, proxyPath string)
ServerName: host,
// Pull InsecureSkipVerify from the target host transport
// so that insecure-skip-verify flag settings are honored for the proxy server
InsecureSkipVerify: transport.TLSClientConfig.InsecureSkipVerify,
InsecureSkipVerify: baseTransport.TLSClientConfig.InsecureSkipVerify,
})
if err := tlsConn.HandshakeContext(ctx); err != nil {
return nil, err
}
return tlsConn, nil
}

proxyApplied := false

if proxyPath != "" {
dial := func(ctx context.Context, _, _ string) (net.Conn, error) {
d := net.Dialer{}
Expand All @@ -48,17 +48,15 @@ func applyProxy(transport *http.Transport, proxyURL *url.URL, proxyPath string)
}
return handshakeTLS(ctx, conn, addr)
}
transport.DialContext = dial
transport.DialTLSContext = dialTLS
baseTransport.DialContext = dial
baseTransport.DialTLSContext = dialTLS
// clear out any system proxy settings
transport.Proxy = nil
proxyApplied = true
baseTransport.Proxy = nil
} else if proxyURL != nil {
switch proxyURL.Scheme {
case "socks5", "socks5h":
// SOCKS proxies work out of the box - no need to manually dial
transport.Proxy = http.ProxyURL(proxyURL)
proxyApplied = true
baseTransport.Proxy = http.ProxyURL(proxyURL)
case "http", "https":
dial := func(ctx context.Context, network, addr string) (net.Conn, error) {
// Dial the proxy
Expand Down Expand Up @@ -126,13 +124,12 @@ func applyProxy(transport *http.Transport, proxyURL *url.URL, proxyPath string)
}
return handshakeTLS(ctx, conn, addr)
}
transport.DialContext = dial
transport.DialTLSContext = dialTLS
baseTransport.DialContext = dial
baseTransport.DialTLSContext = dialTLS
// clear out any system proxy settings
transport.Proxy = nil
proxyApplied = true
baseTransport.Proxy = nil
}
}

return proxyApplied
return baseTransport
}
Loading