Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 12 additions & 61 deletions pkg/abstractions/pod/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,90 +298,41 @@ func (pb *PodProxyBuffer) handleConnection(conn *connection) {
}
defer pb.decrementContainerConnections(container.id)

// If it's a websocket request, upgrade the connection
if websocket.IsWebSocketUpgrade(request) {
pb.proxyWebSocket(conn, container, targetHost, subPath)
return
}

// Otherwise, use regular HTTP proxying
targetURL, err := url.Parse("http://" + targetHost)
if err != nil {
conn.ctx.String(http.StatusInternalServerError, "Invalid target URL")
return
}

isWebSocket := websocket.IsWebSocketUpgrade(request)

proxy := httputil.NewSingleHostReverseProxy(targetURL)
proxy.Transport = &http.Transport{
DialContext: func(ctx context.Context, networkType, addr string) (net.Conn, error) {
conn, err := network.ConnectToHost(ctx, addr, containerDialTimeoutDurationS, pb.tailscale, pb.tsConfig)
if err == nil {
abstractions.SetConnOptions(conn, true, connectionKeepAliveInterval, connectionReadTimeout)
readTimeout := connectionReadTimeout
if isWebSocket {
readTimeout = -1 // No read deadline for WebSocket connections
}
abstractions.SetConnOptions(conn, true, connectionKeepAliveInterval, readTimeout)
}
return conn, err
},
}
proxy.FlushInterval = -1 // Flush immediately for streaming/SSE

defer func() {
if r := recover(); r != nil {
log.Error().Err(err).Str("stubId", pb.stubId).Str("workspace", pb.workspace.Name).Msg("handled abort in pod proxy")
}
}()

proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) {}
proxy.ServeHTTP(response, request)
}

func (pb *PodProxyBuffer) proxyWebSocket(conn *connection, container container, addr string, path string) error {
subprotocols := websocket.Subprotocols(conn.ctx.Request())

upgrader := websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true // Allow all origins
},
Subprotocols: subprotocols,
}

clientConn, err := upgrader.Upgrade(conn.ctx.Response().Writer, conn.ctx.Request(), nil)
if err != nil {
return err
proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) {
log.Error().Err(err).Str("stubId", pb.stubId).Msg("pod proxy error")
rw.WriteHeader(http.StatusBadGateway)
}
Comment thread
cubic-dev-ai[bot] marked this conversation as resolved.
defer clientConn.Close()

wsURL := url.URL{Scheme: "ws", Host: addr, Path: path, RawQuery: conn.ctx.Request().URL.RawQuery}
dstDialer := websocket.Dialer{
NetDialContext: network.GetDialer(addr, pb.tailscale, pb.tsConfig),
Subprotocols: subprotocols,
}

serverConn, _, err := dstDialer.Dial(wsURL.String(), nil)
if err != nil {
return err
}
defer serverConn.Close()

wg := sync.WaitGroup{}
wg.Add(2)

proxyMessages := func(src, dst *websocket.Conn) {
defer wg.Done()

for {
messageType, message, err := src.ReadMessage()
if err != nil {
break
}
if err := dst.WriteMessage(messageType, message); err != nil {
break
}
}
}

go proxyMessages(clientConn, serverConn)
go proxyMessages(serverConn, clientConn)

wg.Wait()
return nil
proxy.ServeHTTP(response, request)
}

func (pb *PodProxyBuffer) discoverContainers() {
Expand Down
Loading