diff --git a/proxy.go b/proxy.go index 4cc6c56..f46e26a 100644 --- a/proxy.go +++ b/proxy.go @@ -144,6 +144,23 @@ func NewProxy(options *Options) (*Proxy, error) { } fastdialerOptions.BaseResolvers = []string{"127.0.0.1" + options.ListenDNSAddr} } + + if len(options.UpstreamHTTPProxies) > 0 { + proxyDialer, err := newHTTPProxyRoundRobinDialer(options.UpstreamHTTPProxies) + if err != nil { + return nil, err + } + fastdialerOptions.ProxyDialer = &proxyDialer + } + + if len(options.UpstreamSOCKS5Proxies) > 0 { + dialer, err := newSOCKS5ProxyRoundRobinDialer(options.UpstreamSOCKS5Proxies) + if err != nil { + return nil, err + } + fastdialerOptions.ProxyDialer = &dialer + } + dialer, err := fastdialer.NewDialer(fastdialerOptions) if err != nil { return nil, err diff --git a/upstream.go b/upstream.go new file mode 100644 index 0000000..882733f --- /dev/null +++ b/upstream.go @@ -0,0 +1,154 @@ +package proxify + +import ( + "bufio" + "encoding/base64" + "fmt" + "net" + "net/http" + "net/url" + + rbtransport "github.com/projectdiscovery/roundrobin/transport" + "golang.org/x/net/proxy" +) + +type httpProxyDialer struct { + proxyURL *url.URL + forward proxy.Dialer +} + +// Dial connects to the address using the HTTP proxy. +func (d *httpProxyDialer) Dial(_, addr string) (net.Conn, error) { + conn, err := d.forward.Dial("tcp", d.proxyURL.Host) + if err != nil { + return nil, err + } + + connectReq := &http.Request{ + Method: "CONNECT", + URL: &url.URL{Opaque: addr}, + Host: addr, + Header: make(http.Header), + } + if d.proxyURL.User != nil { + encodedUserinfo := base64.StdEncoding.EncodeToString([]byte(d.proxyURL.User.String())) + connectReq.Header.Set("Proxy-Authorization", "Basic "+encodedUserinfo) + } + + if err := connectReq.Write(conn); err != nil { + _ = conn.Close() + return nil, err + } + + br := bufio.NewReader(conn) + resp, err := http.ReadResponse(br, connectReq) + if err != nil { + _ = conn.Close() + return nil, err + } + if resp.StatusCode != http.StatusOK { + _ = conn.Close() + return nil, fmt.Errorf("unexpected response from proxy: %s", resp.Status) + } + + return conn, nil +} + +type httpProxyRoundRobinDialer struct { + proxyDialers map[string]httpProxyDialer + transport *rbtransport.RoundTransport +} + +// Dial connects to the address on the named network via one of the HTTP proxies using round-robin scheduling. +func (d *httpProxyRoundRobinDialer) Dial(network, addr string) (net.Conn, error) { + nextProxyURL := d.transport.Next() + dialer, ok := d.proxyDialers[nextProxyURL] + if !ok { + return nil, fmt.Errorf("no matching proxy dialer found") + } + return dialer.Dial(network, addr) +} + +func newHTTPProxyRoundRobinDialer(upstreamProxies []string) (proxy.Dialer, error) { + if len(upstreamProxies) == 0 { + return nil, fmt.Errorf("proxy URLs cannot be empty") + } + + proxyURLs := make([]*url.URL, 0, len(upstreamProxies)) + dialers := make(map[string]httpProxyDialer) + for _, proxyAddr := range upstreamProxies { + proxyURL, err := url.Parse(proxyAddr) + if err != nil { + return nil, err + } + proxyURLs = append(proxyURLs, proxyURL) + dialer := httpProxyDialer{proxyURL: proxyURL, forward: proxy.Direct} + dialers[proxyURL.String()] = dialer + } + + robin, err := rbtransport.NewWithOptions(1, toStringSlice(proxyURLs)...) + if err != nil { + return nil, err + } + + return &httpProxyRoundRobinDialer{proxyDialers: dialers, transport: robin}, nil +} + +type socks5ProxyRoundRobinDialer struct { + proxyDialers map[string]proxy.Dialer + robin *rbtransport.RoundTransport +} + +// Dial connects to the address on the named network via one of the SOCKS5 proxies using round-robin scheduling. +func (d *socks5ProxyRoundRobinDialer) Dial(network, addr string) (net.Conn, error) { + nextProxyURL := d.robin.Next() + dialer, ok := d.proxyDialers[nextProxyURL] + if !ok { + return nil, fmt.Errorf("no matching proxy dialer found") + } + return dialer.Dial(network, addr) +} + +func newSOCKS5ProxyRoundRobinDialer(upstreamProxies []string) (proxy.Dialer, error) { + if len(upstreamProxies) == 0 { + return nil, fmt.Errorf("proxy URLs cannot be empty") + } + + proxyURLs := make([]*url.URL, 0, len(upstreamProxies)) + dialers := make(map[string]proxy.Dialer) + for _, proxyAddr := range upstreamProxies { + proxyURL, err := url.Parse(proxyAddr) + if err != nil { + return nil, err + } + proxyURLs = append(proxyURLs, proxyURL) + var auth *proxy.Auth + if proxyURL.User != nil { + password, _ := proxyURL.User.Password() + auth = &proxy.Auth{ + User: proxyURL.User.Username(), + Password: password, + } + } + dialer, err := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct) + if err != nil { + return nil, err + } + dialers[proxyAddr] = dialer + } + + robin, err := rbtransport.NewWithOptions(1, toStringSlice(proxyURLs)...) + if err != nil { + return nil, err + } + + return &socks5ProxyRoundRobinDialer{proxyDialers: dialers, robin: robin}, nil +} + +func toStringSlice(urls []*url.URL) []string { + s := make([]string, len(urls)) + for i, u := range urls { + s[i] = u.String() + } + return s +} diff --git a/upstream_test.go b/upstream_test.go new file mode 100644 index 0000000..fa8008c --- /dev/null +++ b/upstream_test.go @@ -0,0 +1,128 @@ +package proxify + +import ( + "net/url" + "reflect" + "testing" +) + +func TestNewHTTPProxyRoundRobinDialer(t *testing.T) { + tests := []struct { + name string + upstreamProxies []string + shouldThrowErr bool + }{ + { + name: "empty", + upstreamProxies: []string{}, + shouldThrowErr: true, + }, + { + name: "one", + upstreamProxies: []string{"http://localhost:7777"}, + shouldThrowErr: false, + }, + { + name: "multiple", + upstreamProxies: []string{"http://localhost:7777", "http://localhost:9999"}, + shouldThrowErr: false, + }, + { + name: "invalid", + upstreamProxies: []string{"http://:invalid"}, + shouldThrowErr: true, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + actual, actualErr := newHTTPProxyRoundRobinDialer(test.upstreamProxies) + if (actualErr != nil) != test.shouldThrowErr { + t.Errorf("newHTTPProxyRoundRobinDialer() actualErr = %v, shouldThrowErr = %v", actualErr, test.shouldThrowErr) + return + } + if !test.shouldThrowErr && actual == nil { + t.Errorf("newHTTPProxyRoundRobinDialer() actual = %v, expected non-nil", actual) + } + }) + } +} + +func TestNewSOCKS5ProxyRoundRobinDialer(t *testing.T) { + tests := []struct { + name string + upstreamProxies []string + shouldThrowErr bool + }{ + { + name: "empty", + upstreamProxies: []string{}, + shouldThrowErr: true, + }, + { + name: "one", + upstreamProxies: []string{"socks5://localhost:10070"}, + shouldThrowErr: false, + }, + { + name: "multiple", + upstreamProxies: []string{"socks5://localhost:10070", "socks5://localhost:10090"}, + shouldThrowErr: false, + }, + { + name: "invalid", + upstreamProxies: []string{"socks5://:invalid"}, + shouldThrowErr: true, + }, + { + name: "auth", + upstreamProxies: []string{"socks5://user:pass@localhost:10070"}, + shouldThrowErr: false, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + actual, actualErr := newSOCKS5ProxyRoundRobinDialer(test.upstreamProxies) + if (actualErr != nil) != test.shouldThrowErr { + t.Errorf("newSOCKS5ProxyRoundRobinDialer() actualErr = %v, shouldThrowErr = %v", actualErr, test.shouldThrowErr) + return + } + if !test.shouldThrowErr && actual == nil { + t.Errorf("newSOCKS5ProxyRoundRobinDialer() actual = %v, expected non-nil", actual) + } + }) + } +} + +func TestToStringSlice(t *testing.T) { + tests := []struct { + name string + urls []*url.URL + expected []string + }{ + { + name: "single", + urls: []*url.URL{{Scheme: "http", Host: "localhost:8080"}}, + expected: []string{"http://localhost:8080"}, + }, + { + name: "multiple", + urls: []*url.URL{ + {Scheme: "http", Host: "localhost:8080"}, + {Scheme: "socks5", User: url.UserPassword("user", "pass"), Host: "localhost:8081"}, + }, + expected: []string{"http://localhost:8080", "socks5://user:pass@localhost:8081"}, + }, + { + name: "empty", + urls: []*url.URL{}, + expected: []string{}, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if actual := toStringSlice(test.urls); !reflect.DeepEqual(actual, test.expected) { + t.Errorf("toStringSlice() actual = %v, expected = %v", actual, test.expected) + } + }) + } +}