From e5d2e4386e2eb7a67584f36c1891ef92bf598c6a Mon Sep 17 00:00:00 2001 From: Anup Navin Date: Thu, 26 Jun 2025 11:42:31 +0530 Subject: [PATCH] adding websocket support * websocket requests are translated into HTTP requests between proxy and agent * client-proxy and agent-server are connected with websockets --- go.mod | 1 + go.sum | 4 + server/server.go | 331 +++++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 325 insertions(+), 11 deletions(-) diff --git a/go.mod b/go.mod index 0fba961..3009ae4 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.18 require ( cloud.google.com/go/compute/metadata v0.2.3 cloud.google.com/go/monitoring v1.13.0 + github.com/coder/websocket v1.8.13 github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da github.com/golang/protobuf v1.5.3 github.com/google/go-cmp v0.5.9 diff --git a/go.sum b/go.sum index 8e4cf84..f7dbe7c 100644 --- a/go.sum +++ b/go.sum @@ -1,16 +1,20 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.110.0 h1:Zc8gqp3+a9/Eyph2KDmcGaPtbKRIoqq4YTlL4NMD0Ys= +cloud.google.com/go v0.110.0/go.mod h1:SJnCLqQ0FCFGSZMUNUf84MV3Aia54kn7pi8st7tMzaY= cloud.google.com/go/compute v1.19.1 h1:am86mquDUgjGNWxiGn+5PGLbmgiWXlE/yNWpIpNvuXY= cloud.google.com/go/compute v1.19.1/go.mod h1:6ylj3a05WF8leseCdIf77NK0g1ey+nj5IKd5/kvShxE= cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGBW5aJ7UnBMY= cloud.google.com/go/compute/metadata v0.2.3/go.mod h1:VAV5nSsACxMJvgaAuX6Pk2AawlZn8kiOGuCv6gTkwuA= cloud.google.com/go/longrunning v0.4.1 h1:v+yFJOfKC3yZdY6ZUI933pIYdhyhV8S3NpWrXWmg7jM= +cloud.google.com/go/longrunning v0.4.1/go.mod h1:4iWDqhBZ70CvZ6BfETbvam3T8FMvLK+eFj0E6AaRQTo= cloud.google.com/go/monitoring v1.13.0 h1:2qsrgXGVoRXpP7otZ14eE1I568zAa92sJSDPyOJvwjM= cloud.google.com/go/monitoring v1.13.0/go.mod h1:k2yMBAB1H9JT/QETjNkgdCGD9bPF712XiLTVr+cBrpw= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= +github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE= +github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= diff --git a/server/server.go b/server/server.go index 72e547e..2ada39b 100644 --- a/server/server.go +++ b/server/server.go @@ -24,9 +24,12 @@ package main import ( "bufio" + "bytes" "context" "crypto/sha256" + "encoding/base64" "encoding/json" + "errors" "flag" "fmt" "io" @@ -38,11 +41,15 @@ import ( "sync" "time" + "github.com/coder/websocket" "github.com/google/inverting-proxy/agent/utils" ) var ( - port = flag.Int("port", 0, "Port on which to listen") + port int + setReadLimit int + bufSize int + shimPath string ) // pendingRequest represents a frontend request @@ -77,6 +84,31 @@ func newProxy() *proxy { } } +type sessionMessage struct { + ID string `json:"id,omitempty"` + Message interface{} `json:"msg,omitempty"` + Version int `json:"v,omitempty"` + Subprotocol string `json:"s,omitempty"` +} + +type messageData struct { + data []byte + mt websocket.MessageType +} + +type wsSessionHelper struct { + sessionInfo sessionMessage + writeChan chan []byte + readChan chan []messageData +} + +func newWsSessionHelper() *wsSessionHelper { + return &wsSessionHelper{ + readChan: make(chan []messageData), + writeChan: make(chan []byte, bufSize), + } +} + func (p *proxy) handleAgentPostResponse(w http.ResponseWriter, r *http.Request, requestID string) { p.Lock() pending, ok := p.requests[requestID] @@ -203,11 +235,55 @@ func isHopByHopHeader(name string) bool { } } -func (p *proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if backendID := r.Header.Get(utils.HeaderBackendID); backendID != "" { - p.handleAgentRequest(w, r, backendID) - return +func websocketShimResponseHandlerOpen(resp *http.Response, ws *wsSessionHelper) error { + p, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("%v/open: failed to read response from agent: %v", shimPath, err) + } + err = json.Unmarshal(p, &ws.sessionInfo) + if err != nil { + return fmt.Errorf("%v/open: failed to parse JSON encoded data: %v", shimPath, err) + } + return nil +} + +func websocketShimResponseHandlerData(resp *http.Response, ws *wsSessionHelper) error { + if resp.StatusCode != http.StatusOK { + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("%v/data: http status code is %v, error reading response body", shimPath, resp.StatusCode) + } + return fmt.Errorf("%v/data: http status code %v, response: %v", shimPath, resp.StatusCode, string(respBody)) + } + return nil +} + +func websocketShimResponseHandlerPoll(resp *http.Response, ws *wsSessionHelper) error { + if resp.StatusCode != http.StatusOK { + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("%v/poll: http status code %v, error reading response body", shimPath, resp.StatusCode) + } + return fmt.Errorf("%v/poll: http status code %v, response: %v", shimPath, resp.StatusCode, string(respBody)) + } + p, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response from agent: %v", err) + } + ws.writeChan <- p + return nil +} + +func websocketShimResponseHandlerClose(resp *http.Response, ws *wsSessionHelper) error { + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("%v/close: http status code %v", shimPath, resp.StatusCode) } + return nil +} + +// if the request is a normal HTTP request and not related to the websocket shim, nil value +// can be passed into the wsSessionHelper parameter +func (p *proxy) handleFrontendRequest(w http.ResponseWriter, r *http.Request, ws *wsSessionHelper) error { id := p.newID() log.Printf("Received new frontend request %q", id) // Filter out hop-by-hop headers from the request @@ -225,8 +301,7 @@ func (p *proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { select { case <-r.Context().Done(): // The client request was cancelled - log.Printf("Timeout waiting to enqueue the request ID for %q", id) - return + return fmt.Errorf("timeout waiting to enqueue the request ID for %q", id) case p.requestIDs <- id: } log.Printf("Request %q enqueued after %s", id, time.Since(pending.startTime)) @@ -235,9 +310,20 @@ func (p *proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { select { case <-r.Context().Done(): // The client request was cancelled - log.Printf("Timeout waiting for the response to %q", id) - return + return fmt.Errorf("timeout waiting for the response to %q", id) case resp := <-pending.respChan: + // websocket shim endpoint handling + switch resp.Request.URL.Path { + case shimPath + "/open": + return websocketShimResponseHandlerOpen(resp, ws) + case shimPath + "/data": + return websocketShimResponseHandlerData(resp, ws) + case shimPath + "/poll": + return websocketShimResponseHandlerPoll(resp, ws) + case shimPath + "/close": + return websocketShimResponseHandlerClose(resp, ws) + } + // Copy all of the non-hop-by-hop headers to the proxied response for name, vals := range resp.Header { if isHopByHopHeader(name) { @@ -257,15 +343,238 @@ func (p *proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Add(http.TrailerPrefix+name, v) } } + } + return nil +} + +func parseServerMessage(buf interface{}) ([]byte, websocket.MessageType, error) { + if data, ok := buf.(string); ok { + return []byte(data), websocket.MessageText, nil + } + + if arrData, ok := buf.([]interface{}); ok { + if b64data, ok := arrData[0].(string); ok { + data, err := base64.StdEncoding.DecodeString(b64data) + if err == nil { + return data, websocket.MessageBinary, nil + } + } + } + + return nil, websocket.MessageBinary, errors.New("unexpected data format from server") +} + +func (p *proxy) handleWebsocketRequest(w http.ResponseWriter, r *http.Request) error { + ws := newWsSessionHelper() + + // websocket: shimPath/open + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("http://:%v%v/open", port, shimPath), nil) + if err != nil { + return fmt.Errorf("failed to create a new %v/open request: %v", shimPath, err) + } + + for k, v := range r.Header { + if !isHopByHopHeader(k) { + req.Header.Set(k, strings.Join(v, ", ")) + } + } + // to avoid CORS errors + req.Header.Set("Origin", "") + req.Header.Set("X-Websocket-Shim-Version", "1") + + err = p.handleFrontendRequest(nil, req, ws) + if err != nil { + return err + } + + conn, err := websocket.Accept(w, r, nil) + if err != nil { + return err + } + + background := context.Background() + ctx, ctxCancel := context.WithCancel(background) + + // websocket: shimPath/close + defer func() { + buf, err := json.Marshal(sessionMessage{ + ID: ws.sessionInfo.ID, + }) + if err != nil { + log.Printf("Failed to encoded data in JSON format: %v", err) + return + } + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("http://:%v%v/close", port, shimPath), bytes.NewBuffer(buf)) + if err != nil { + log.Printf("Failed to create a new open request: %v", err) + return + } + req.Header.Set("X-Websocket-Shim-Version", "1") + p.handleFrontendRequest(nil, req, nil) + ctxCancel() + conn.CloseNow() + }() + + if setReadLimit != 0 { + conn.SetReadLimit(int64(setReadLimit)) + } + + // goroutine to read client messages + go func() { + for { + mt, ior, err := conn.Reader(ctx) + if err != nil { + log.Printf("Failed to create websocket reader: %v", err) + return + } + var bufArray []messageData + for { + buf := make([]byte, bufSize) + bytesRead, err := ior.Read(buf) + if bytesRead > 0 { + bufArray = append(bufArray, messageData{buf[:bytesRead], mt}) + } + if err != nil { + break + } + } + ws.readChan <- bufArray + } + }() + + // goroutine to read server messages + go func() { + for { + payload := sessionMessage{ + ID: ws.sessionInfo.ID, + } + + buf, err := json.Marshal(payload) + if err != nil { + log.Printf("Failed to encoded data in JSON format: %v", err) + return + } + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("http://:%v%v/poll", port, shimPath), bytes.NewBuffer(buf)) + if err != nil { + log.Printf("Failed to create a new open request: %v", err) + return + } + req.Header.Set("X-Websocket-Shim-Version", "1") + err = p.handleFrontendRequest(nil, req, ws) + if err != nil { + log.Print(err) + return + } + } + }() + + // loop to process read/write events + for { + select { + case <-ctx.Done(): + return errors.New("context closed") + case msg := <-ws.readChan: + var payload []sessionMessage + for _, md := range msg { + if md.mt == websocket.MessageText { + payload = append(payload, sessionMessage{ + ID: ws.sessionInfo.ID, + Version: ws.sessionInfo.Version, + Message: string(md.data), + Subprotocol: ws.sessionInfo.Subprotocol, + }) + } else { + payload = append(payload, sessionMessage{ + ID: ws.sessionInfo.ID, + Version: ws.sessionInfo.Version, + Message: []interface{}{md.data}, + Subprotocol: ws.sessionInfo.Subprotocol, + }) + } + } + + buf, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("failed to encoded data in JSON format: %v", err) + } + + // send data to agent via HTTP + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("http://:%v%v/data", port, shimPath), bytes.NewBuffer(buf)) + if err != nil { + return fmt.Errorf("failed to create a new open request: %v", err) + } + req.Header.Set("X-Websocket-Shim-Version", "1") + + err = p.handleFrontendRequest(nil, req, nil) + if err != nil { + return err + } + case buf := <-ws.writeChan: + var decodedBuf []interface{} + err = json.Unmarshal(buf, &decodedBuf) + if err != nil { + return err + } + + var msg []byte + var msgType websocket.MessageType + + for i := range decodedBuf { + msg, msgType, err = parseServerMessage(decodedBuf[i]) + if err != nil { + return err + } + + iow, err := conn.Writer(ctx, msgType) + if err != nil { + return err + } + + bytesWritten, err := iow.Write(msg) + if bytesWritten < len(msg) { + if err != nil { + return err + } else { + return errors.New("unexpected error while writing data from proxy to client") + } + } + iow.Close() + } + } + } +} + +func (p *proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if backendID := r.Header.Get(utils.HeaderBackendID); backendID != "" { + p.handleAgentRequest(w, r, backendID) return } + + if shimPath != "" && strings.ToLower(r.Header.Get("Connection")) == "upgrade" && strings.ToLower(r.Header.Get("Upgrade")) == "websocket" { + err := p.handleWebsocketRequest(w, r) + if err != nil { + log.Print(err) + } + return + } + + err := p.handleFrontendRequest(w, r, nil) + if err != nil { + log.Print(err) + } } func main() { + flag.IntVar(&port, "port", 0, "Port on which to listen") + flag.IntVar(&setReadLimit, "ws-read-limit", 0, "websocket read limit from client in bytes") + flag.IntVar(&bufSize, "ws-buffer-size", 1024*4, "websocket buffer size for writes") + flag.StringVar(&shimPath, "shim-path", "", "Path under which to handle websocket shim requests") + flag.Parse() - listener, err := net.Listen("tcp", fmt.Sprintf(":%d", *port)) + + listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) if err != nil { - log.Fatalf("Failed to create the TCP listener for port %d: %v", *port, err) + log.Fatalf("Failed to create the TCP listener for port %d: %v", port, err) } log.Printf("Listening on %s", listener.Addr()) log.Fatal(http.Serve(listener, newProxy()))