diff --git a/proxy/listener.go b/proxy/listener.go new file mode 100644 index 0000000..7e6ae1f --- /dev/null +++ b/proxy/listener.go @@ -0,0 +1,46 @@ +package proxy + +import ( + "net" + "sync" +) + +// singleConnListener is a net.Listener that returns exactly one connection +// from Accept and then blocks until Close is called. +type singleConnListener struct { + conn net.Conn + once sync.Once + ch chan struct{} +} + +func newSingleConnListener(c net.Conn) *singleConnListener { + return &singleConnListener{ + conn: c, + ch: make(chan struct{}), + } +} + +func (l *singleConnListener) Accept() (net.Conn, error) { + var c net.Conn + l.once.Do(func() { c = l.conn }) + if c != nil { + return c, nil + } + // Block until Close is called. + <-l.ch + return nil, net.ErrClosed +} + +func (l *singleConnListener) Close() error { + select { + case <-l.ch: + // Already closed. + default: + close(l.ch) + } + return nil +} + +func (l *singleConnListener) Addr() net.Addr { + return l.conn.LocalAddr() +} diff --git a/proxy/proxy.go b/proxy/proxy.go index 4c84699..25bd06d 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -1,7 +1,6 @@ package proxy import ( - "bufio" "context" "crypto/tls" "errors" @@ -61,8 +60,11 @@ func NewProxy(filter filter, certGenerator certGenerator, port int) (*Proxy, err KeepAlive: 30 * time.Second, } p.requestTransport = &http.Transport{ - Dial: p.netDialer.Dial, + DialContext: p.netDialer.DialContext, + ForceAttemptHTTP2: true, TLSHandshakeTimeout: 20 * time.Second, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, } p.requestClient = &http.Client{ Timeout: 60 * time.Second, @@ -224,72 +226,75 @@ func (p *Proxy) proxyConnect(w http.ResponseWriter, connReq *http.Request) { tlsConfig := &tls.Config{ Certificates: []tls.Certificate{*tlsCert}, + NextProtos: []string{"h2", "http/1.1"}, MinVersion: tls.VersionTLS12, } tlsConn := tls.Server(clientConn, tlsConfig) defer tlsConn.Close() - connReader := bufio.NewReader(tlsConn) - // Read requests in a loop to allow for HTTP connection reuse. - // https://en.wikipedia.org/wiki/HTTP_persistent_connection - for { - req, err := http.ReadRequest(connReader) - if err != nil { - if err != io.EOF { - - msg := err.Error() - if strings.Contains(msg, "tls: ") { - log.Printf("adding %s to ignored hosts", redacted.Redacted(host)) - p.addTransparentHost(host) - } - - // The following errors occur when the underlying clientConn is closed. - // This usually happens during normal request/response flow when the client - // decides it no longer needs the connection to the host. - // To avoid excessive noise in the logs, we suppress these messages. - if !strings.HasSuffix(msg, "connection reset by peer") && !strings.HasSuffix(msg, "An existing connection was forcibly closed by the remote host.") { - log.Printf("reading request(%s): %v", redacted.Redacted(connReq.Host), err) - } - } - break + // Perform the TLS handshake manually so we can capture TLS errors + // and add the host to transparentHosts before entering the server loop. + if err := tlsConn.HandshakeContext(context.Background()); err != nil { + msg := err.Error() + if strings.Contains(msg, "tls: ") { + log.Printf("adding %s to ignored hosts", redacted.Redacted(host)) + p.addTransparentHost(host) } + log.Printf("TLS handshake(%s): %v", redacted.Redacted(connReq.Host), err) + return + } + + ln := newSingleConnListener(tlsConn) + + srv := &http.Server{ + Handler: p.connectHandler(connReq, host, ln), + TLSConfig: tlsConfig, + ConnState: func(_ net.Conn, state http.ConnState) { + if state == http.StateClosed { + ln.Close() + } + }, + ReadHeaderTimeout: 20 * time.Second, + } + + if err := srv.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) && !errors.Is(err, net.ErrClosed) { + log.Printf("serving connection(%s): %v", redacted.Redacted(connReq.Host), err) + } +} + +// connectHandler returns an http.Handler that processes requests on a CONNECT-tunnelled TLS connection. +func (p *Proxy) connectHandler(connReq *http.Request, host string, ln *singleConnListener) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req.URL.Host = connReq.Host + req.URL.Scheme = "https" + req.RequestURI = "" - if isWS(req) { - // Establish transparent flow, no hop-by-hop header removal required. - p.proxyWebsocketTLS(req, tlsConfig, tlsConn) - break + // WebSocket upgrade is only done over HTTP/1.1. + if isWS(req) && req.ProtoMajor == 1 { + p.proxyWebsocketTLS(w, req) + ln.Close() + return } - // A standard CONNECT proxy establishes a TCP connection to the requested destination and relays the stream between the client and server. - // Here, we are MITM-ing the traffic and handling the request-response flow ourselves. - // Since the client and server do not share a direct TCP connection in this setup, we must strip hop-by-hop headers. removeHopHeaders(req.Header) - req.URL.Scheme = "https" filterResp, err := p.filter.HandleRequest(req) if err != nil { log.Printf("handling request for %q: %v", redacted.Redacted(req.URL), err) } if filterResp != nil { - if _, err := io.Copy(io.Discard, req.Body); err != nil { - log.Printf("discarding body for %q: %v", redacted.Redacted(req.URL), err) - break - } - if err := req.Body.Close(); err != nil { - log.Printf("closing body for %q: %v", redacted.Redacted(req.URL), err) - break - } - if err := filterResp.Write(tlsConn); err != nil { - log.Printf("writing filter response for %q: %v", redacted.Redacted(req.URL), err) - break + writeResp(w, filterResp) + if filterResp.Body != nil { + filterResp.Body.Close() } + return + } - if req.Close { - break - } - continue + // Go's HTTP server always sets a non-nil value for req.Body. + // RoundTrip interprets a non-nil Body as chunked, which causes strict servers to reject the request. + if req.ContentLength == 0 { + req.Body = nil } resp, err := p.requestTransport.RoundTrip(req) @@ -298,41 +303,22 @@ func (p *Proxy) proxyConnect(w http.ResponseWriter, connReq *http.Request) { log.Printf("adding %s to ignored hosts", redacted.Redacted(host)) p.addTransparentHost(host) } - log.Printf("roundtrip(%s): %v", redacted.Redacted(connReq.Host), err) - // TODO: better error presentation - response := fmt.Sprintf("HTTP/1.1 502 Bad Gateway\r\n\r\n%s", err.Error()) - tlsConn.Write([]byte(response)) - break + http.Error(w, err.Error(), http.StatusBadGateway) + return } + defer resp.Body.Close() removeHopHeaders(resp.Header) if err := p.filter.HandleResponse(req, resp); err != nil { log.Printf("error handling response by filter for %q: %v", redacted.Redacted(req.URL), err) - if err := resp.Body.Close(); err != nil { - log.Printf("closing body for %q: %v", redacted.Redacted(req.URL), err) - } - response := fmt.Sprintf("HTTP/1.1 502 Bad Gateway\r\n\r\n%s", err.Error()) - tlsConn.Write([]byte(response)) - break + http.Error(w, err.Error(), http.StatusBadGateway) + return } - if err := resp.Write(tlsConn); err != nil { - log.Printf("writing response(%q): %v", redacted.Redacted(connReq.Host), err) - if err := resp.Body.Close(); err != nil { - log.Printf("closing body(%q): %v", redacted.Redacted(connReq.Host), err) - } - break - } - if err := resp.Body.Close(); err != nil { - log.Printf("closing body(%q): %v", redacted.Redacted(connReq.Host), err) - } - - if req.Close || resp.Close { - break - } - } + writeResp(w, resp) + }) } // shouldMITM returns true if the host should be MITM'd. @@ -376,6 +362,25 @@ func (p *Proxy) tunnel(w net.Conn, r *http.Request) { linkBidirectionalTunnel(w, remoteConn) } +// writeResp writes the response (status code, headers, and body) to the ResponseWriter. +// It is the caller's responsibility to close the response body after calling the function. +func writeResp(w http.ResponseWriter, resp *http.Response) { + for h, v := range resp.Header { + for _, vv := range v { + w.Header().Add(h, vv) + } + } + w.WriteHeader(resp.StatusCode) + if resp.Body != nil { + io.Copy(w, resp.Body) + } + for h, v := range resp.Trailer { + for _, vv := range v { + w.Header().Add(http.TrailerPrefix+h, vv) + } + } +} + func linkBidirectionalTunnel(src, dst io.ReadWriter) { doneC := make(chan struct{}, 2) go tunnelConn(src, dst, doneC) diff --git a/proxy/websocket.go b/proxy/websocket.go index 78fb78b..dc81908 100644 --- a/proxy/websocket.go +++ b/proxy/websocket.go @@ -5,52 +5,46 @@ import ( "crypto/tls" "io" "log" + "net" "net/http" "strings" "github.com/ZenPrivacy/zen-core/internal/redacted" ) -func (p *Proxy) proxyWebsocketTLS(req *http.Request, tlsConfig *tls.Config, clientConn *tls.Conn) { - dialer := &tls.Dialer{NetDialer: p.netDialer, Config: tlsConfig} - targetConn, err := dialer.Dial("tcp", req.URL.Host) - if err != nil { - log.Printf("dialing websocket backend(%s): %v", redacted.Redacted(req.URL.Host), err) - clientConn.Write([]byte("HTTP/1.1 502 Bad Gateway\r\n\r\n")) - return - } - defer targetConn.Close() - - if err := websocketHandshake(req, targetConn, clientConn); err != nil { - return - } - - linkBidirectionalTunnel(targetConn, clientConn) +func (p *Proxy) proxyWebsocketTLS(w http.ResponseWriter, req *http.Request) { + dialer := &tls.Dialer{NetDialer: p.netDialer, Config: &tls.Config{MinVersion: tls.VersionTLS12}} + hijackAndTunnelWebsocket(w, req, dialer.Dial) } func (p *Proxy) proxyWebsocket(w http.ResponseWriter, req *http.Request) { - targetConn, err := p.netDialer.Dial("tcp", req.URL.Host) - if err != nil { - w.WriteHeader(http.StatusBadGateway) - log.Printf("dialing websocket backend(%s): %v", redacted.Redacted(req.URL.Host), err) - return - } - defer targetConn.Close() + hijackAndTunnelWebsocket(w, req, p.netDialer.Dial) +} +func hijackAndTunnelWebsocket(w http.ResponseWriter, req *http.Request, dial func(network, addr string) (net.Conn, error)) { hj, ok := w.(http.Hijacker) if !ok { - panic("http server does not support hijacking") + http.Error(w, "websocket hijack not supported", http.StatusInternalServerError) + return } clientConn, _, err := hj.Hijack() if err != nil { - log.Printf("hijacking websocket client(%s): %v", redacted.Redacted(req.URL.Host), err) + log.Printf("hijacking websocket(%s): %v", redacted.Redacted(req.URL.Host), err) return } + defer clientConn.Close() - if err := websocketHandshake(req, targetConn, clientConn); err != nil { + targetConn, err := dial("tcp", req.URL.Host) + if err != nil { + log.Printf("dialing websocket backend(%s): %v", redacted.Redacted(req.URL.Host), err) + clientConn.Write([]byte("HTTP/1.1 502 Bad Gateway\r\n\r\n")) return } + defer targetConn.Close() + if err := websocketHandshake(req, targetConn, clientConn); err != nil { + return + } linkBidirectionalTunnel(targetConn, clientConn) }