Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions proxy/listener.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package proxy

import (
"net"
"sync"
)

// singleConnListener is a net.Listener that returns exactly one connection
// from Accept and then blocks until Close is called.
type singleConnListener struct {
conn net.Conn
once sync.Once
ch chan struct{}
}

func newSingleConnListener(c net.Conn) *singleConnListener {
return &singleConnListener{
conn: c,
ch: make(chan struct{}),
}
}

func (l *singleConnListener) Accept() (net.Conn, error) {
var c net.Conn
l.once.Do(func() { c = l.conn })
if c != nil {
return c, nil
}
// Block until Close is called.
<-l.ch
return nil, net.ErrClosed
}

func (l *singleConnListener) Close() error {
select {
case <-l.ch:
// Already closed.
default:
close(l.ch)
}
return nil
}

func (l *singleConnListener) Addr() net.Addr {
return l.conn.LocalAddr()
}
153 changes: 79 additions & 74 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package proxy

import (
"bufio"
"context"
"crypto/tls"
"errors"
Expand Down Expand Up @@ -61,8 +60,11 @@ func NewProxy(filter filter, certGenerator certGenerator, port int) (*Proxy, err
KeepAlive: 30 * time.Second,
}
p.requestTransport = &http.Transport{
Dial: p.netDialer.Dial,
DialContext: p.netDialer.DialContext,
ForceAttemptHTTP2: true,
TLSHandshakeTimeout: 20 * time.Second,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
}
p.requestClient = &http.Client{
Timeout: 60 * time.Second,
Expand Down Expand Up @@ -224,72 +226,75 @@ func (p *Proxy) proxyConnect(w http.ResponseWriter, connReq *http.Request) {

tlsConfig := &tls.Config{
Certificates: []tls.Certificate{*tlsCert},
NextProtos: []string{"h2", "http/1.1"},
MinVersion: tls.VersionTLS12,
}

tlsConn := tls.Server(clientConn, tlsConfig)
defer tlsConn.Close()
connReader := bufio.NewReader(tlsConn)

// Read requests in a loop to allow for HTTP connection reuse.
// https://en.wikipedia.org/wiki/HTTP_persistent_connection
for {
req, err := http.ReadRequest(connReader)
if err != nil {
if err != io.EOF {

msg := err.Error()
if strings.Contains(msg, "tls: ") {
log.Printf("adding %s to ignored hosts", redacted.Redacted(host))
p.addTransparentHost(host)
}

// The following errors occur when the underlying clientConn is closed.
// This usually happens during normal request/response flow when the client
// decides it no longer needs the connection to the host.
// To avoid excessive noise in the logs, we suppress these messages.
if !strings.HasSuffix(msg, "connection reset by peer") && !strings.HasSuffix(msg, "An existing connection was forcibly closed by the remote host.") {
log.Printf("reading request(%s): %v", redacted.Redacted(connReq.Host), err)
}
}
break
// Perform the TLS handshake manually so we can capture TLS errors
// and add the host to transparentHosts before entering the server loop.
if err := tlsConn.HandshakeContext(context.Background()); err != nil {
msg := err.Error()
if strings.Contains(msg, "tls: ") {
log.Printf("adding %s to ignored hosts", redacted.Redacted(host))
p.addTransparentHost(host)
}
log.Printf("TLS handshake(%s): %v", redacted.Redacted(connReq.Host), err)
return
}
Comment on lines +238 to +246
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

TLS handshake has no timeout — potential indefinite goroutine leak.

context.Background() carries no deadline. A client that sends the CONNECT request, receives 200 OK, then stalls mid-TLS-handshake will hold the goroutine and the TCP connection indefinitely. Each such connection is a goroutine pinned inside proxyConnect. Use a time-bounded context to cap the handshake duration.

🛡️ Proposed fix
+    hsCtx, hsCancel := context.WithTimeout(context.Background(), 10*time.Second)
+    defer hsCancel()
-    if err := tlsConn.HandshakeContext(context.Background()); err != nil {
+    if err := tlsConn.HandshakeContext(hsCtx); err != nil {
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if err := tlsConn.HandshakeContext(context.Background()); err != nil {
msg := err.Error()
if strings.Contains(msg, "tls: ") {
log.Printf("adding %s to ignored hosts", redacted.Redacted(host))
p.addTransparentHost(host)
}
log.Printf("TLS handshake(%s): %v", redacted.Redacted(connReq.Host), err)
return
}
hsCtx, hsCancel := context.WithTimeout(context.Background(), 10*time.Second)
defer hsCancel()
if err := tlsConn.HandshakeContext(hsCtx); err != nil {
msg := err.Error()
if strings.Contains(msg, "tls: ") {
log.Printf("adding %s to ignored hosts", redacted.Redacted(host))
p.addTransparentHost(host)
}
log.Printf("TLS handshake(%s): %v", redacted.Redacted(connReq.Host), err)
return
}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@proxy/proxy.go` around lines 236 - 244, The TLS handshake currently uses
context.Background() in the tlsConn.HandshakeContext call, allowing a stalled
client to hang the goroutine forever; change this to a time-bounded context
(e.g., use context.WithTimeout(context.Background(), tlsHandshakeTimeout))
before calling tlsConn.HandshakeContext, store the cancel func and defer
cancel(), and handle a deadline/timeout error path (log a timeout and
close/return). Update or add a tlsHandshakeTimeout constant and ensure the
existing logic around tlsConn.HandshakeContext, redacted.Redacted(host),
p.addTransparentHost(host) and logging remains but that the handshake is always
bounded and cleaned up via defer cancel().


ln := newSingleConnListener(tlsConn)

srv := &http.Server{
Handler: p.connectHandler(connReq, host, ln),
TLSConfig: tlsConfig,
ConnState: func(_ net.Conn, state http.ConnState) {
if state == http.StateClosed {
ln.Close()
}
},
ReadHeaderTimeout: 20 * time.Second,
}

if err := srv.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) && !errors.Is(err, net.ErrClosed) {
log.Printf("serving connection(%s): %v", redacted.Redacted(connReq.Host), err)
}
}

// connectHandler returns an http.Handler that processes requests on a CONNECT-tunnelled TLS connection.
func (p *Proxy) connectHandler(connReq *http.Request, host string, ln *singleConnListener) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
req.URL.Host = connReq.Host
req.URL.Scheme = "https"
req.RequestURI = ""

if isWS(req) {
// Establish transparent flow, no hop-by-hop header removal required.
p.proxyWebsocketTLS(req, tlsConfig, tlsConn)
break
// WebSocket upgrade is only done over HTTP/1.1.
if isWS(req) && req.ProtoMajor == 1 {
p.proxyWebsocketTLS(w, req)
ln.Close()
return
}

// A standard CONNECT proxy establishes a TCP connection to the requested destination and relays the stream between the client and server.
// Here, we are MITM-ing the traffic and handling the request-response flow ourselves.
// Since the client and server do not share a direct TCP connection in this setup, we must strip hop-by-hop headers.
removeHopHeaders(req.Header)
req.URL.Scheme = "https"

filterResp, err := p.filter.HandleRequest(req)
if err != nil {
log.Printf("handling request for %q: %v", redacted.Redacted(req.URL), err)
}
if filterResp != nil {
if _, err := io.Copy(io.Discard, req.Body); err != nil {
log.Printf("discarding body for %q: %v", redacted.Redacted(req.URL), err)
break
}
if err := req.Body.Close(); err != nil {
log.Printf("closing body for %q: %v", redacted.Redacted(req.URL), err)
break
}
if err := filterResp.Write(tlsConn); err != nil {
log.Printf("writing filter response for %q: %v", redacted.Redacted(req.URL), err)
break
writeResp(w, filterResp)
if filterResp.Body != nil {
filterResp.Body.Close()
}
return
}

if req.Close {
break
}
continue
// Go's HTTP server always sets a non-nil value for req.Body.
// RoundTrip interprets a non-nil Body as chunked, which causes strict servers to reject the request.
if req.ContentLength == 0 {
req.Body = nil
}

resp, err := p.requestTransport.RoundTrip(req)
Expand All @@ -298,41 +303,22 @@ func (p *Proxy) proxyConnect(w http.ResponseWriter, connReq *http.Request) {
log.Printf("adding %s to ignored hosts", redacted.Redacted(host))
p.addTransparentHost(host)
}

log.Printf("roundtrip(%s): %v", redacted.Redacted(connReq.Host), err)
// TODO: better error presentation
response := fmt.Sprintf("HTTP/1.1 502 Bad Gateway\r\n\r\n%s", err.Error())
tlsConn.Write([]byte(response))
break
http.Error(w, err.Error(), http.StatusBadGateway)
return
}
defer resp.Body.Close()

removeHopHeaders(resp.Header)

if err := p.filter.HandleResponse(req, resp); err != nil {
log.Printf("error handling response by filter for %q: %v", redacted.Redacted(req.URL), err)
if err := resp.Body.Close(); err != nil {
log.Printf("closing body for %q: %v", redacted.Redacted(req.URL), err)
}
response := fmt.Sprintf("HTTP/1.1 502 Bad Gateway\r\n\r\n%s", err.Error())
tlsConn.Write([]byte(response))
break
http.Error(w, err.Error(), http.StatusBadGateway)
return
}

if err := resp.Write(tlsConn); err != nil {
log.Printf("writing response(%q): %v", redacted.Redacted(connReq.Host), err)
if err := resp.Body.Close(); err != nil {
log.Printf("closing body(%q): %v", redacted.Redacted(connReq.Host), err)
}
break
}
if err := resp.Body.Close(); err != nil {
log.Printf("closing body(%q): %v", redacted.Redacted(connReq.Host), err)
}

if req.Close || resp.Close {
break
}
}
writeResp(w, resp)
})
}

// shouldMITM returns true if the host should be MITM'd.
Expand Down Expand Up @@ -376,6 +362,25 @@ func (p *Proxy) tunnel(w net.Conn, r *http.Request) {
linkBidirectionalTunnel(w, remoteConn)
}

// writeResp writes the response (status code, headers, and body) to the ResponseWriter.
// It is the caller's responsibility to close the response body after calling the function.
func writeResp(w http.ResponseWriter, resp *http.Response) {
for h, v := range resp.Header {
for _, vv := range v {
w.Header().Add(h, vv)
}
}
w.WriteHeader(resp.StatusCode)
if resp.Body != nil {
io.Copy(w, resp.Body)
}
for h, v := range resp.Trailer {
for _, vv := range v {
w.Header().Add(http.TrailerPrefix+h, vv)
}
}
}

func linkBidirectionalTunnel(src, dst io.ReadWriter) {
doneC := make(chan struct{}, 2)
go tunnelConn(src, dst, doneC)
Expand Down
44 changes: 19 additions & 25 deletions proxy/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,52 +5,46 @@ import (
"crypto/tls"
"io"
"log"
"net"
"net/http"
"strings"

"github.com/ZenPrivacy/zen-core/internal/redacted"
)

func (p *Proxy) proxyWebsocketTLS(req *http.Request, tlsConfig *tls.Config, clientConn *tls.Conn) {
dialer := &tls.Dialer{NetDialer: p.netDialer, Config: tlsConfig}
targetConn, err := dialer.Dial("tcp", req.URL.Host)
if err != nil {
log.Printf("dialing websocket backend(%s): %v", redacted.Redacted(req.URL.Host), err)
clientConn.Write([]byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"))
return
}
defer targetConn.Close()

if err := websocketHandshake(req, targetConn, clientConn); err != nil {
return
}

linkBidirectionalTunnel(targetConn, clientConn)
func (p *Proxy) proxyWebsocketTLS(w http.ResponseWriter, req *http.Request) {
dialer := &tls.Dialer{NetDialer: p.netDialer, Config: &tls.Config{MinVersion: tls.VersionTLS12}}
hijackAndTunnelWebsocket(w, req, dialer.Dial)
}

func (p *Proxy) proxyWebsocket(w http.ResponseWriter, req *http.Request) {
targetConn, err := p.netDialer.Dial("tcp", req.URL.Host)
if err != nil {
w.WriteHeader(http.StatusBadGateway)
log.Printf("dialing websocket backend(%s): %v", redacted.Redacted(req.URL.Host), err)
return
}
defer targetConn.Close()
hijackAndTunnelWebsocket(w, req, p.netDialer.Dial)
}

func hijackAndTunnelWebsocket(w http.ResponseWriter, req *http.Request, dial func(network, addr string) (net.Conn, error)) {
hj, ok := w.(http.Hijacker)
if !ok {
panic("http server does not support hijacking")
http.Error(w, "websocket hijack not supported", http.StatusInternalServerError)
return
}
clientConn, _, err := hj.Hijack()
if err != nil {
log.Printf("hijacking websocket client(%s): %v", redacted.Redacted(req.URL.Host), err)
log.Printf("hijacking websocket(%s): %v", redacted.Redacted(req.URL.Host), err)
return
}
defer clientConn.Close()

if err := websocketHandshake(req, targetConn, clientConn); err != nil {
targetConn, err := dial("tcp", req.URL.Host)
if err != nil {
log.Printf("dialing websocket backend(%s): %v", redacted.Redacted(req.URL.Host), err)
clientConn.Write([]byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"))
return
}
defer targetConn.Close()

if err := websocketHandshake(req, targetConn, clientConn); err != nil {
return
}
linkBidirectionalTunnel(targetConn, clientConn)
}

Expand Down
Loading