From 779567ff1e9e133dde542ebedf4786aaff7e26ee Mon Sep 17 00:00:00 2001 From: Zhigao Hong Date: Fri, 19 Dec 2025 02:11:28 +0800 Subject: [PATCH 1/4] feat: support proxy to specific host --- proxy.go | 7 ++++- scope.go | 86 ++++++++++++++++++++++++++++++++++++++++++++++++++++++-- utils.go | 47 +++++++++++++++++++++++++++++++ 3 files changed, 137 insertions(+), 3 deletions(-) diff --git a/proxy.go b/proxy.go index 72ec90bb..490b3092 100644 --- a/proxy.go +++ b/proxy.go @@ -917,6 +917,11 @@ func (rp *reverseProxy) getScope(req *http.Request) (*scope, int, error) { return nil, http.StatusForbidden, fmt.Errorf("cluster user %q is not allowed to access", cu.name) } - s := newScope(req, u, c, cu, sessionId, sessionTimeout) + replicaIndex, nodeIndex, err := getSpecificHostIndex(req, c) + if err != nil { + return nil, http.StatusBadRequest, err + } + + s := newScope(req, u, c, cu, sessionId, sessionTimeout, replicaIndex, nodeIndex) return s, 0, nil } diff --git a/scope.go b/scope.go index ef970565..821d8c5b 100644 --- a/scope.go +++ b/scope.go @@ -57,10 +57,14 @@ type scope struct { requestPacketSize int } -func newScope(req *http.Request, u *user, c *cluster, cu *clusterUser, sessionId string, sessionTimeout int) *scope { - h := c.getHost() +func newScope(req *http.Request, u *user, c *cluster, cu *clusterUser, sessionId string, sessionTimeout int, replicaIndex, nodeIndex int) *scope { + var h *topology.Node if sessionId != "" { h = c.getHostSticky(sessionId) + } else if replicaIndex > 0 || nodeIndex > 0 { + h = c.getSpecificHost(replicaIndex, nodeIndex) + } else { + h = c.getHost() } var localAddr string if addr, ok := req.Context().Value(http.LocalAddrContextKey).(net.Addr); ok { @@ -720,6 +724,8 @@ func newReplicas(replicasCfg []config.Replica, nodes []string, scheme string, c return nil, err } r.hosts = hosts + c.maxNodeIndex = len(r.hosts) + c.maxReplicaIndex = 1 return []*replica{r}, nil } @@ -735,7 +741,9 @@ func newReplicas(replicasCfg []config.Replica, nodes []string, scheme string, c } r.hosts = hosts replicas[i] = r + c.maxNodeIndex = max(c.maxNodeIndex, len(r.hosts)) } + c.maxReplicaIndex = len(replicas) return replicas, nil } @@ -775,6 +783,9 @@ type cluster struct { replicas []*replica nextReplicaIdx uint32 + maxReplicaIndex int + maxNodeIndex int + users map[string]*clusterUser killQueryUserName string @@ -937,6 +948,69 @@ func (r *replica) getHostSticky(sessionId string) *topology.Node { return h } +// getSpecificReplica returns specific replica by replicaIndex from the cluster. +// +// Always returns non-nil. +func (c *cluster) getSpecificReplica(replicaIndex, nodeIndex int) *replica { + if replicaIndex > 0 { + return c.replicas[replicaIndex-1] + } + if nodeIndex == 0 { + return c.getReplica() + } + + idx := atomic.AddUint32(&c.nextReplicaIdx, 1) + n := uint32(len(c.replicas)) + if n == 1 { + return c.replicas[0] + } + + idx %= n + r := c.replicas[idx] + reqs := r.load() + + // Set least priority to inactive replica. + if !r.isActive() { + reqs = ^uint32(0) + } + + if reqs == 0 && nodeIndex <= len(r.hosts) { + return r + } + + // Scan all the replicas for the least loaded and nodeIndex-satisfied replica. 
+ for i := uint32(1); i < n; i++ { + tmpIdx := (idx + i) % n + tmpR := c.replicas[tmpIdx] + if !tmpR.isActive() || nodeIndex > len(tmpR.hosts) { + continue + } + tmpReqs := tmpR.load() + if tmpReqs == 0 && nodeIndex <= len(tmpR.hosts) { + return tmpR + } + if tmpReqs < reqs && nodeIndex <= len(tmpR.hosts) { + r = tmpR + reqs = tmpReqs + } + } + // The returned replica may be inactive. This is OK, + // since this means all the replicas are inactive, + // so let's try proxying the request to any replica. + return r +} + +// getSpecificHost returns specific host by nodeIndex from replica. +// +// Always returns non-nil. +func (r *replica) getSpecificHost(nodeIndex int) *topology.Node { + if nodeIndex > 0 { + return r.hosts[nodeIndex-1] + } + + return r.getHost() +} + // getHost returns least loaded + round-robin host from replica. // // Always returns non-nil. @@ -991,6 +1065,14 @@ func (c *cluster) getHostSticky(sessionId string) *topology.Node { return r.getHostSticky(sessionId) } +// getSpecificHost returns specific host by index from cluster. +// +// Always returns non-nil. +func (c *cluster) getSpecificHost(replicaIndex, nodeIndex int) *topology.Node { + r := c.getSpecificReplica(replicaIndex, nodeIndex) + return r.getSpecificHost(nodeIndex) +} + // getHost returns least loaded + round-robin host from cluster. // // Always returns non-nil. diff --git a/utils.go b/utils.go index c3d97d14..420599bc 100644 --- a/utils.go +++ b/utils.go @@ -63,6 +63,53 @@ func getSessionTimeout(req *http.Request) int { return 60 } +// getSpecificHostIndex retrieves specific host index, including replica and node index +// index starts from 1, 0 means no specific host index +// shard_index is alias for node_index, and override node_index if both are specified +func getSpecificHostIndex(req *http.Request, c *cluster) (int, int, error) { + params := req.URL.Query() + var replicaIndex, nodeIndex int + var err error + // replica index + replicaIndexStr := params.Get("replica_index") + if replicaIndexStr != "" { + replicaIndex, err = strconv.Atoi(replicaIndexStr) + if err != nil { + return -1, -1, fmt.Errorf("invalid replica index %q", replicaIndexStr) + } + if replicaIndex < 0 || replicaIndex > c.maxReplicaIndex { + return -1, -1, fmt.Errorf("invalid replica index %q", replicaIndexStr) + } + } + // node index (shard_index is alias for node_index) + nodeIndexStr := params.Get("node_index") + if nodeIndexStr != "" { + nodeIndex, err = strconv.Atoi(nodeIndexStr) + if err != nil { + return -1, -1, fmt.Errorf("invalid node index %q", nodeIndexStr) + } + if nodeIndex < 0 || nodeIndex > c.maxNodeIndex { + return -1, -1, fmt.Errorf("invalid node index %q", nodeIndexStr) + } + } + shardIndexStr := params.Get("shard_index") + if shardIndexStr != "" { + nodeIndex, err = strconv.Atoi(shardIndexStr) + if err != nil { + return -1, -1, fmt.Errorf("invalid shard index %q", shardIndexStr) + } + if nodeIndex < 0 || nodeIndex > c.maxNodeIndex { + return -1, -1, fmt.Errorf("invalid shard index %q", shardIndexStr) + } + } + // validate if both replicaIndex and nodeIndex are specified + if replicaIndex > 0 && nodeIndex > 0 && nodeIndex > len(c.replicas[replicaIndex-1].hosts) { + return -1, -1, fmt.Errorf("invalid host index (%q, %q)", replicaIndexStr, nodeIndexStr) + } + + return replicaIndex, nodeIndex, nil +} + // getQuerySnippet returns query snippet. // // getQuerySnippet must be called only for error reporting. 
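Usage note for reviewers: a minimal client-side sketch of the query parameters added in PATCH 1. The listen address, port, and user below are placeholders chosen for illustration, not values taken from this series; per the getSpecificHostIndex doc comment, replica_index and node_index (shard_index is an accepted alias) are 1-based, and 0 or absence means no host pinning.

package main

import (
	"fmt"
	"io"
	"net/http"
	"net/url"
)

func main() {
	// Pin the query to the 2nd node of the 1st replica through chproxy.
	// The address and user are placeholders; adjust them to your deployment.
	q := url.Values{
		"query":         {"SELECT 1"},
		"user":          {"default"},
		"replica_index": {"1"}, // 1-based; 0 or absent means "no specific replica"
		"node_index":    {"2"}, // 1-based; shard_index is an alias
	}
	resp, err := http.Get("http://127.0.0.1:9090/?" + q.Encode())
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	body, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.Status, string(body))
}
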
From 87ad5bc82b4b270e911f8f249e45f2c7ee2e640d Mon Sep 17 00:00:00 2001 From: Zhigao Hong Date: Sat, 20 Dec 2025 19:11:35 +0800 Subject: [PATCH 2/4] fix: getSpecificReplica may return non-nodeIndex-satisfied replic --- scope.go | 36 ++++++++++++++---------------------- 1 file changed, 14 insertions(+), 22 deletions(-) diff --git a/scope.go b/scope.go index 821d8c5b..5ad24395 100644 --- a/scope.go +++ b/scope.go @@ -965,37 +965,27 @@ func (c *cluster) getSpecificReplica(replicaIndex, nodeIndex int) *replica { return c.replicas[0] } - idx %= n - r := c.replicas[idx] - reqs := r.load() - - // Set least priority to inactive replica. - if !r.isActive() { - reqs = ^uint32(0) - } - - if reqs == 0 && nodeIndex <= len(r.hosts) { - return r - } + var r *replica + reqs := ^uint32(0) // Scan all the replicas for the least loaded and nodeIndex-satisfied replica. - for i := uint32(1); i < n; i++ { + for i := uint32(0); i < n; i++ { tmpIdx := (idx + i) % n tmpR := c.replicas[tmpIdx] - if !tmpR.isActive() || nodeIndex > len(tmpR.hosts) { + if nodeIndex > len(tmpR.hosts) { continue } - tmpReqs := tmpR.load() - if tmpReqs == 0 && nodeIndex <= len(tmpR.hosts) { - return tmpR - } - if tmpReqs < reqs && nodeIndex <= len(tmpR.hosts) { - r = tmpR - reqs = tmpReqs + if tmpR.isActive() || r == nil { + tmpReqs := tmpR.load() + if tmpReqs < reqs || !r.isActive() { + r = tmpR + reqs = tmpReqs + } } } + // The returned replica may be inactive. This is OK, - // since this means all the replicas are inactive, + // since this means all the nodeIndex-satisfied replicas are inactive, // so let's try proxying the request to any replica. return r } @@ -1066,6 +1056,8 @@ func (c *cluster) getHostSticky(sessionId string) *topology.Node { } // getSpecificHost returns specific host by index from cluster. +// Both replicaIndex/nodeIndex start from 1 and satisfy [0, maxReplicaIndex/maxNodeIndex], 0 means no specific host index. +// If both are 0, getSpecificHost equals to getHost. // // Always returns non-nil. 
func (c *cluster) getSpecificHost(replicaIndex, nodeIndex int) *topology.Node { From 0e23bafda1c8dd4811112c2fac180d382dbbc28d Mon Sep 17 00:00:00 2001 From: Zhigao Hong Date: Sat, 20 Dec 2025 19:27:33 +0800 Subject: [PATCH 3/4] feat: select specific host in retry and concurrency limiting scenarios --- proxy.go | 7 ++++++- scope.go | 12 +++++++++--- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/proxy.go b/proxy.go index 490b3092..da1c62b0 100644 --- a/proxy.go +++ b/proxy.go @@ -238,7 +238,12 @@ func executeWithRetry( // comment s.host.dec() line to avoid double increment; issue #322 // s.host.dec() s.host.SetIsActive(false) - nextHost := s.cluster.getHost() + var nextHost *topology.Node + if s.replicaIndex > 0 || s.nodeIndex > 0 { + nextHost = s.cluster.getSpecificHost(s.replicaIndex, s.nodeIndex) + } else { + nextHost = s.cluster.getHost() + } // The query could be retried if it has no stickiness to a certain server if numRetry < maxRetry && nextHost.IsActive() && s.sessionId == "" { // the query execution has been failed diff --git a/scope.go b/scope.go index 5ad24395..0b4691d6 100644 --- a/scope.go +++ b/scope.go @@ -45,6 +45,8 @@ type scope struct { sessionId string sessionTimeout int + replicaIndex int + nodeIndex int remoteAddr string localAddr string @@ -79,6 +81,8 @@ func newScope(req *http.Request, u *user, c *cluster, cu *clusterUser, sessionId clusterUser: cu, sessionId: sessionId, sessionTimeout: sessionTimeout, + replicaIndex: replicaIndex, + nodeIndex: nodeIndex, remoteAddr: req.RemoteAddr, localAddr: localAddr, @@ -189,11 +193,13 @@ func (s *scope) waitUntilAllowStart(sleep time.Duration, deadline time.Time, lab var h *topology.Node // Choose new host, since the previous one may become obsolete // after sleeping. - if s.sessionId == "" { - h = s.cluster.getHost() - } else { + if s.sessionId != "" { // if request has session_id, set same host h = s.cluster.getHostSticky(s.sessionId) + } else if s.replicaIndex > 0 || s.nodeIndex > 0 { + h = s.cluster.getSpecificHost(s.replicaIndex, s.nodeIndex) + } else { + h = s.cluster.getHost() } s.host = h From 7c6ab745ac82a6289fc9ae04dd8762ca98b79607 Mon Sep 17 00:00:00 2001 From: Zhigao Hong Date: Sat, 20 Dec 2025 20:17:27 +0800 Subject: [PATCH 4/4] feat: add test cases for specific host selection --- scope_test.go | 78 ++++++++++++++++++++++++++ utils_test.go | 148 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 226 insertions(+) diff --git a/scope_test.go b/scope_test.go index 438966c6..22c18da5 100644 --- a/scope_test.go +++ b/scope_test.go @@ -410,6 +410,78 @@ func TestGetHostSticky(t *testing.T) { } } +func TestGetSpecificHost(t *testing.T) { + c := testGetCluster() + + t.Run("SpecifyReplicaIndex", func(t *testing.T) { + h := c.getSpecificHost(1, 0) + if h.Host() != "127.0.0.11" && h.Host() != "127.0.0.22" { + t.Fatalf("Expected host from replica1, got: %s", h.Host()) + } + + h = c.getSpecificHost(2, 0) + if h.Host() != "127.0.0.33" && h.Host() != "127.0.0.44" { + t.Fatalf("Expected host from replica2, got: %s", h.Host()) + } + + h = c.getSpecificHost(3, 0) + if h.Host() != "127.0.0.55" && h.Host() != "127.0.0.66" { + t.Fatalf("Expected host from replica3, got: %s", h.Host()) + } + }) + + t.Run("SpecifyNodeIndex", func(t *testing.T) { + h := c.getSpecificHost(0, 1) + if h.Host() != "127.0.0.11" && h.Host() != "127.0.0.33" && h.Host() != "127.0.0.55" { + t.Fatalf("Expected first node from any replica, got: %s", h.Host()) + } + + h = c.getSpecificHost(0, 2) + if h.Host() != "127.0.0.22" && h.Host() 
!= "127.0.0.44" && h.Host() != "127.0.0.66" { + t.Fatalf("Expected second node from any replica, got: %s", h.Host()) + } + }) + + t.Run("SpecifyReplicaIndexAndNodeIndex", func(t *testing.T) { + h := c.getSpecificHost(1, 1) + if h.Host() != "127.0.0.11" { + t.Fatalf("Expected 127.0.0.11, got: %s", h.Host()) + } + + h = c.getSpecificHost(1, 2) + if h.Host() != "127.0.0.22" { + t.Fatalf("Expected 127.0.0.22, got: %s", h.Host()) + } + + h = c.getSpecificHost(2, 1) + if h.Host() != "127.0.0.33" { + t.Fatalf("Expected 127.0.0.33, got: %s", h.Host()) + } + }) + + t.Run("SpecifyBothIndicesZero", func(t *testing.T) { + h := c.getSpecificHost(0, 0) + if h == nil { + t.Fatalf("getSpecificHost(0, 0) returned nil") + } + found := false + for _, r := range c.replicas { + for _, node := range r.hosts { + if h.Host() == node.Host() { + found = true + break + } + } + if found { + break + } + } + if !found { + t.Fatalf("getSpecificHost(0, 0) returned unknown host: %s", h.Host()) + } + }) +} + func TestIncQueued(t *testing.T) { u := testGetUser() cu := testGetClusterUser() @@ -485,6 +557,12 @@ func testGetCluster() *cluster { topology.NewNode(&url.URL{Host: "127.0.0.66"}, nil, "", r3.name, topology.WithDefaultActiveState(true)), } r3.name = "replica3" + + c.maxReplicaIndex = len(c.replicas) + for _, r := range c.replicas { + c.maxNodeIndex = max(c.maxNodeIndex, len(r.hosts)) + } + return c } diff --git a/utils_test.go b/utils_test.go index 9fbbc31d..630b050e 100644 --- a/utils_test.go +++ b/utils_test.go @@ -4,6 +4,7 @@ import ( "bytes" "compress/gzip" "fmt" + "github.com/contentsquare/chproxy/internal/topology" "github.com/stretchr/testify/assert" "net/http" "net/url" @@ -374,3 +375,150 @@ func TestCalcMapHash(t *testing.T) { }) } } + +func TestGetSpecificHostIndex(t *testing.T) { + // Create a test cluster with 2 replicas, each having 3 nodes + testCluster := &cluster{ + name: "test_cluster", + replicas: []*replica{ + { + name: "replica1", + hosts: []*topology.Node{{}, {}, {}}, + }, + { + name: "replica2", + hosts: []*topology.Node{{}, {}, {}}, + }, + }, + maxReplicaIndex: 2, + maxNodeIndex: 3, + } + // Set the cluster reference for each replica + for _, r := range testCluster.replicas { + r.cluster = testCluster + } + + testCases := []struct { + name string + params map[string]string + expectedRI int + expectedNI int + expectedError bool + }{ + { + "no parameters", + map[string]string{}, + 0, + 0, + false, + }, + { + "only replica_index", + map[string]string{"replica_index": "1"}, + 1, + 0, + false, + }, + { + "only node_index", + map[string]string{"node_index": "2"}, + 0, + 2, + false, + }, + { + "only shard_index", + map[string]string{"shard_index": "3"}, + 0, + 3, + false, + }, + { + "replica_index and node_index", + map[string]string{"replica_index": "1", "node_index": "2"}, + 1, + 2, + false, + }, + { + "invalid replica_index", + map[string]string{"replica_index": "invalid"}, + 0, + 0, + true, + }, + { + "invalid node_index", + map[string]string{"node_index": "-1"}, + 0, + 0, + true, + }, + { + "replica_index out of range", + map[string]string{"replica_index": "3"}, + 0, + 0, + true, + }, + { + "node_index out of range", + map[string]string{"node_index": "4"}, + 0, + 0, + true, + }, + { + "node_index out of range for specific replica", + map[string]string{"replica_index": "1", "node_index": "4"}, + 0, + 0, + true, + }, + { + "replica_index is zero", + map[string]string{"replica_index": "0"}, + 0, + 0, + false, + }, + { + "node_index is zero", + map[string]string{"node_index": "0"}, + 0, + 0, + false, 
+ }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req, err := http.NewRequest("GET", "", nil) + checkErr(t, err) + + // Set up the URL parameters + params := make(url.Values) + for k, v := range tc.params { + params.Set(k, v) + } + req.URL.RawQuery = params.Encode() + + replicaIndex, nodeIndex, err := getSpecificHostIndex(req, testCluster) + if tc.expectedError { + if err == nil { + t.Fatalf("expected error but got none") + } + } else { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if replicaIndex != tc.expectedRI { + t.Fatalf("unexpected replicaIndex: got %d, expecting %d", replicaIndex, tc.expectedRI) + } + if nodeIndex != tc.expectedNI { + t.Fatalf("unexpected nodeIndex: got %d, expecting %d", nodeIndex, tc.expectedNI) + } + } + }) + } +}
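
Summary sketch of the host-selection precedence introduced across this series, for reviewers. This is a standalone illustration only; the type and method names below are simplified stand-ins and are not code from the patches. A sticky session wins first, then an explicit replica/node pin, then the default least-loaded round-robin pick, mirroring the branching in newScope, waitUntilAllowStart and executeWithRetry.

package main

import "fmt"

// pinnedScope is a simplified stand-in for the scope fields relevant to host selection.
type pinnedScope struct {
	sessionID    string
	replicaIndex int // 1-based; 0 means unpinned
	nodeIndex    int // 1-based; 0 means unpinned
}

// pickStrategy reports which selection path a request with these fields would take.
func (s pinnedScope) pickStrategy() string {
	switch {
	case s.sessionID != "":
		return "sticky host for session " + s.sessionID
	case s.replicaIndex > 0 || s.nodeIndex > 0:
		return fmt.Sprintf("specific host (replica_index=%d, node_index=%d)", s.replicaIndex, s.nodeIndex)
	default:
		return "least-loaded round-robin host"
	}
}

func main() {
	fmt.Println(pinnedScope{sessionID: "abc"}.pickStrategy())
	fmt.Println(pinnedScope{replicaIndex: 2, nodeIndex: 1}.pickStrategy())
	fmt.Println(pinnedScope{}.pickStrategy())
}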