diff --git a/proxy.go b/proxy.go
index 72ec90bb..da1c62b0 100644
--- a/proxy.go
+++ b/proxy.go
@@ -238,7 +238,12 @@ func executeWithRetry(
 			// comment s.host.dec() line to avoid double increment; issue #322
 			// s.host.dec()
 			s.host.SetIsActive(false)
-			nextHost := s.cluster.getHost()
+			var nextHost *topology.Node
+			if s.replicaIndex > 0 || s.nodeIndex > 0 {
+				nextHost = s.cluster.getSpecificHost(s.replicaIndex, s.nodeIndex)
+			} else {
+				nextHost = s.cluster.getHost()
+			}
 			// The query could be retried if it has no stickiness to a certain server
 			if numRetry < maxRetry && nextHost.IsActive() && s.sessionId == "" {
 				// the query execution has been failed
@@ -917,6 +922,11 @@ func (rp *reverseProxy) getScope(req *http.Request) (*scope, int, error) {
 		return nil, http.StatusForbidden, fmt.Errorf("cluster user %q is not allowed to access", cu.name)
 	}
 
-	s := newScope(req, u, c, cu, sessionId, sessionTimeout)
+	replicaIndex, nodeIndex, err := getSpecificHostIndex(req, c)
+	if err != nil {
+		return nil, http.StatusBadRequest, err
+	}
+
+	s := newScope(req, u, c, cu, sessionId, sessionTimeout, replicaIndex, nodeIndex)
 	return s, 0, nil
 }
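Note: the selection precedence introduced here — sticky session first, then explicit replica/node indexes, then the least-loaded default — is repeated in newScope and waitUntilAllowStart in scope.go below (the retry path above simply never sees a sticky session, since sticky queries are not retried). A possible consolidation, sketched as a hypothetical helper that is not part of this patch:

// pickHost (hypothetical, not part of this diff) centralizes the
// host-selection precedence used by executeWithRetry, newScope and
// waitUntilAllowStart: session stickiness wins, then explicit 1-based
// replica/node indexes, then the least-loaded round-robin default.
func pickHost(c *cluster, sessionId string, replicaIndex, nodeIndex int) *topology.Node {
	switch {
	case sessionId != "":
		return c.getHostSticky(sessionId)
	case replicaIndex > 0 || nodeIndex > 0:
		return c.getSpecificHost(replicaIndex, nodeIndex)
	default:
		return c.getHost()
	}
}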
diff --git a/scope.go b/scope.go
index ef970565..0b4691d6 100644
--- a/scope.go
+++ b/scope.go
@@ -45,6 +45,8 @@ type scope struct {
 	sessionId      string
 	sessionTimeout int
+	replicaIndex   int
+	nodeIndex      int
 
 	remoteAddr string
 	localAddr  string
 
@@ -57,10 +59,14 @@ type scope struct {
 	requestPacketSize int
 }
 
-func newScope(req *http.Request, u *user, c *cluster, cu *clusterUser, sessionId string, sessionTimeout int) *scope {
-	h := c.getHost()
+func newScope(req *http.Request, u *user, c *cluster, cu *clusterUser, sessionId string, sessionTimeout int, replicaIndex, nodeIndex int) *scope {
+	var h *topology.Node
 	if sessionId != "" {
 		h = c.getHostSticky(sessionId)
+	} else if replicaIndex > 0 || nodeIndex > 0 {
+		h = c.getSpecificHost(replicaIndex, nodeIndex)
+	} else {
+		h = c.getHost()
 	}
 	var localAddr string
 	if addr, ok := req.Context().Value(http.LocalAddrContextKey).(net.Addr); ok {
@@ -75,6 +81,8 @@ func newScope(req *http.Request, u *user, c *cluster, cu *clusterUser, sessionId
 		clusterUser:    cu,
 		sessionId:      sessionId,
 		sessionTimeout: sessionTimeout,
+		replicaIndex:   replicaIndex,
+		nodeIndex:      nodeIndex,
 
 		remoteAddr: req.RemoteAddr,
 		localAddr:  localAddr,
@@ -185,11 +193,13 @@ func (s *scope) waitUntilAllowStart(sleep time.Duration, deadline time.Time, lab
 	var h *topology.Node
 	// Choose new host, since the previous one may become obsolete
 	// after sleeping.
-	if s.sessionId == "" {
-		h = s.cluster.getHost()
-	} else {
+	if s.sessionId != "" {
 		// if the request has a session_id, keep the same host
 		h = s.cluster.getHostSticky(s.sessionId)
+	} else if s.replicaIndex > 0 || s.nodeIndex > 0 {
+		h = s.cluster.getSpecificHost(s.replicaIndex, s.nodeIndex)
+	} else {
+		h = s.cluster.getHost()
 	}
 	s.host = h
 
@@ -720,6 +730,8 @@ func newReplicas(replicasCfg []config.Replica, nodes []string, scheme string, c
 			return nil, err
 		}
 		r.hosts = hosts
+		c.maxNodeIndex = len(r.hosts)
+		c.maxReplicaIndex = 1
 		return []*replica{r}, nil
 	}
 
@@ -735,7 +747,9 @@ func newReplicas(replicasCfg []config.Replica, nodes []string, scheme string, c
 		}
 		r.hosts = hosts
 		replicas[i] = r
+		c.maxNodeIndex = max(c.maxNodeIndex, len(r.hosts))
 	}
+	c.maxReplicaIndex = len(replicas)
 	return replicas, nil
 }
 
@@ -775,6 +789,9 @@ type cluster struct {
 	replicas       []*replica
 	nextReplicaIdx uint32
 
+	maxReplicaIndex int
+	maxNodeIndex    int
+
 	users map[string]*clusterUser
 
 	killQueryUserName string
@@ -937,6 +954,59 @@ func (r *replica) getHostSticky(sessionId string) *topology.Node {
 	return h
 }
 
+// getSpecificReplica returns the specific replica selected by replicaIndex from the cluster.
+//
+// Always returns non-nil.
+func (c *cluster) getSpecificReplica(replicaIndex, nodeIndex int) *replica {
+	if replicaIndex > 0 {
+		return c.replicas[replicaIndex-1]
+	}
+	if nodeIndex == 0 {
+		return c.getReplica()
+	}
+
+	idx := atomic.AddUint32(&c.nextReplicaIdx, 1)
+	n := uint32(len(c.replicas))
+	if n == 1 {
+		return c.replicas[0]
+	}
+
+	var r *replica
+	reqs := ^uint32(0)
+
+	// Scan all the replicas for the least loaded replica that satisfies nodeIndex.
+	for i := uint32(0); i < n; i++ {
+		tmpIdx := (idx + i) % n
+		tmpR := c.replicas[tmpIdx]
+		if nodeIndex > len(tmpR.hosts) {
+			continue
+		}
+		if tmpR.isActive() || r == nil {
+			tmpReqs := tmpR.load()
+			if r == nil || tmpReqs < reqs || !r.isActive() {
+				r = tmpR
+				reqs = tmpReqs
+			}
+		}
+	}
+
+	// The returned replica may be inactive. This is OK,
+	// since this means all the nodeIndex-satisfying replicas are inactive,
+	// so let's try proxying the request to any replica.
+	return r
+}
+
+// getSpecificHost returns the specific host selected by nodeIndex from the replica.
+//
+// Always returns non-nil.
+func (r *replica) getSpecificHost(nodeIndex int) *topology.Node {
+	if nodeIndex > 0 {
+		return r.hosts[nodeIndex-1]
+	}
+
+	return r.getHost()
+}
+
 // getHost returns least loaded + round-robin host from replica.
 //
 // Always returns non-nil.
@@ -991,6 +1061,16 @@ func (c *cluster) getHostSticky(sessionId string) *topology.Node {
 	return r.getHostSticky(sessionId)
 }
 
+// getSpecificHost returns the specific host selected by index from the cluster.
+// Both replicaIndex and nodeIndex are 1-based and must lie in [0, maxReplicaIndex] and [0, maxNodeIndex]; 0 means no specific index.
+// If both are 0, getSpecificHost is equivalent to getHost.
+//
+// Always returns non-nil.
+func (c *cluster) getSpecificHost(replicaIndex, nodeIndex int) *topology.Node {
+	r := c.getSpecificReplica(replicaIndex, nodeIndex)
+	return r.getSpecificHost(nodeIndex)
+}
+
 // getHost returns least loaded + round-robin host from cluster.
 //
 // Always returns non-nil.
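The indexing convention above is worth spelling out: both getSpecificReplica and replica.getSpecificHost treat their index as 1-based, with 0 meaning "no preference, fall back to the balancer". A minimal, self-contained illustration of that convention (pickIndexed and the sample data are hypothetical, not part of this patch):

package main

import "fmt"

// pickIndexed mirrors the 1-based convention of getSpecificReplica and
// getSpecificHost: idx == 0 falls back to the balancer, idx >= 1
// addresses items[idx-1].
func pickIndexed[T any](items []T, idx int, fallback func() T) T {
	if idx > 0 {
		return items[idx-1]
	}
	return fallback()
}

func main() {
	replicas := []string{"replica1", "replica2", "replica3"}
	balanced := func() string { return replicas[0] } // stand-in for the least-loaded pick
	fmt.Println(pickIndexed(replicas, 2, balanced))  // replica2: 1-based addressing
	fmt.Println(pickIndexed(replicas, 0, balanced))  // replica1: fell back to the balancer
}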
diff --git a/scope_test.go b/scope_test.go
index 438966c6..22c18da5 100644
--- a/scope_test.go
+++ b/scope_test.go
@@ -410,6 +410,78 @@ func TestGetHostSticky(t *testing.T) {
 	}
 }
 
+func TestGetSpecificHost(t *testing.T) {
+	c := testGetCluster()
+
+	t.Run("SpecifyReplicaIndex", func(t *testing.T) {
+		h := c.getSpecificHost(1, 0)
+		if h.Host() != "127.0.0.11" && h.Host() != "127.0.0.22" {
+			t.Fatalf("Expected host from replica1, got: %s", h.Host())
+		}
+
+		h = c.getSpecificHost(2, 0)
+		if h.Host() != "127.0.0.33" && h.Host() != "127.0.0.44" {
+			t.Fatalf("Expected host from replica2, got: %s", h.Host())
+		}
+
+		h = c.getSpecificHost(3, 0)
+		if h.Host() != "127.0.0.55" && h.Host() != "127.0.0.66" {
+			t.Fatalf("Expected host from replica3, got: %s", h.Host())
+		}
+	})
+
+	t.Run("SpecifyNodeIndex", func(t *testing.T) {
+		h := c.getSpecificHost(0, 1)
+		if h.Host() != "127.0.0.11" && h.Host() != "127.0.0.33" && h.Host() != "127.0.0.55" {
+			t.Fatalf("Expected first node from any replica, got: %s", h.Host())
+		}
+
+		h = c.getSpecificHost(0, 2)
+		if h.Host() != "127.0.0.22" && h.Host() != "127.0.0.44" && h.Host() != "127.0.0.66" {
+			t.Fatalf("Expected second node from any replica, got: %s", h.Host())
+		}
+	})
+
+	t.Run("SpecifyReplicaIndexAndNodeIndex", func(t *testing.T) {
+		h := c.getSpecificHost(1, 1)
+		if h.Host() != "127.0.0.11" {
+			t.Fatalf("Expected 127.0.0.11, got: %s", h.Host())
+		}
+
+		h = c.getSpecificHost(1, 2)
+		if h.Host() != "127.0.0.22" {
+			t.Fatalf("Expected 127.0.0.22, got: %s", h.Host())
+		}
+
+		h = c.getSpecificHost(2, 1)
+		if h.Host() != "127.0.0.33" {
+			t.Fatalf("Expected 127.0.0.33, got: %s", h.Host())
+		}
+	})
+
+	t.Run("SpecifyBothIndicesZero", func(t *testing.T) {
+		h := c.getSpecificHost(0, 0)
+		if h == nil {
+			t.Fatalf("getSpecificHost(0, 0) returned nil")
+		}
+		found := false
+		for _, r := range c.replicas {
+			for _, node := range r.hosts {
+				if h.Host() == node.Host() {
+					found = true
+					break
+				}
+			}
+			if found {
+				break
+			}
+		}
+		if !found {
+			t.Fatalf("getSpecificHost(0, 0) returned unknown host: %s", h.Host())
+		}
+	})
+}
+
 func TestIncQueued(t *testing.T) {
 	u := testGetUser()
 	cu := testGetClusterUser()
@@ -485,6 +557,12 @@ func testGetCluster() *cluster {
 		topology.NewNode(&url.URL{Host: "127.0.0.66"}, nil, "", r3.name, topology.WithDefaultActiveState(true)),
 	}
 	r3.name = "replica3"
+
+	c.maxReplicaIndex = len(c.replicas)
+	for _, r := range c.replicas {
+		c.maxNodeIndex = max(c.maxNodeIndex, len(r.hosts))
+	}
+
 	return c
 }
 
"" { + nodeIndex, err = strconv.Atoi(nodeIndexStr) + if err != nil { + return -1, -1, fmt.Errorf("invalid node index %q", nodeIndexStr) + } + if nodeIndex < 0 || nodeIndex > c.maxNodeIndex { + return -1, -1, fmt.Errorf("invalid node index %q", nodeIndexStr) + } + } + shardIndexStr := params.Get("shard_index") + if shardIndexStr != "" { + nodeIndex, err = strconv.Atoi(shardIndexStr) + if err != nil { + return -1, -1, fmt.Errorf("invalid shard index %q", shardIndexStr) + } + if nodeIndex < 0 || nodeIndex > c.maxNodeIndex { + return -1, -1, fmt.Errorf("invalid shard index %q", shardIndexStr) + } + } + // validate if both replicaIndex and nodeIndex are specified + if replicaIndex > 0 && nodeIndex > 0 && nodeIndex > len(c.replicas[replicaIndex-1].hosts) { + return -1, -1, fmt.Errorf("invalid host index (%q, %q)", replicaIndexStr, nodeIndexStr) + } + + return replicaIndex, nodeIndex, nil +} + // getQuerySnippet returns query snippet. // // getQuerySnippet must be called only for error reporting. diff --git a/utils_test.go b/utils_test.go index 9fbbc31d..630b050e 100644 --- a/utils_test.go +++ b/utils_test.go @@ -4,6 +4,7 @@ import ( "bytes" "compress/gzip" "fmt" + "github.com/contentsquare/chproxy/internal/topology" "github.com/stretchr/testify/assert" "net/http" "net/url" @@ -374,3 +375,150 @@ func TestCalcMapHash(t *testing.T) { }) } } + +func TestGetSpecificHostIndex(t *testing.T) { + // Create a test cluster with 2 replicas, each having 3 nodes + testCluster := &cluster{ + name: "test_cluster", + replicas: []*replica{ + { + name: "replica1", + hosts: []*topology.Node{{}, {}, {}}, + }, + { + name: "replica2", + hosts: []*topology.Node{{}, {}, {}}, + }, + }, + maxReplicaIndex: 2, + maxNodeIndex: 3, + } + // Set the cluster reference for each replica + for _, r := range testCluster.replicas { + r.cluster = testCluster + } + + testCases := []struct { + name string + params map[string]string + expectedRI int + expectedNI int + expectedError bool + }{ + { + "no parameters", + map[string]string{}, + 0, + 0, + false, + }, + { + "only replica_index", + map[string]string{"replica_index": "1"}, + 1, + 0, + false, + }, + { + "only node_index", + map[string]string{"node_index": "2"}, + 0, + 2, + false, + }, + { + "only shard_index", + map[string]string{"shard_index": "3"}, + 0, + 3, + false, + }, + { + "replica_index and node_index", + map[string]string{"replica_index": "1", "node_index": "2"}, + 1, + 2, + false, + }, + { + "invalid replica_index", + map[string]string{"replica_index": "invalid"}, + 0, + 0, + true, + }, + { + "invalid node_index", + map[string]string{"node_index": "-1"}, + 0, + 0, + true, + }, + { + "replica_index out of range", + map[string]string{"replica_index": "3"}, + 0, + 0, + true, + }, + { + "node_index out of range", + map[string]string{"node_index": "4"}, + 0, + 0, + true, + }, + { + "node_index out of range for specific replica", + map[string]string{"replica_index": "1", "node_index": "4"}, + 0, + 0, + true, + }, + { + "replica_index is zero", + map[string]string{"replica_index": "0"}, + 0, + 0, + false, + }, + { + "node_index is zero", + map[string]string{"node_index": "0"}, + 0, + 0, + false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req, err := http.NewRequest("GET", "", nil) + checkErr(t, err) + + // Set up the URL parameters + params := make(url.Values) + for k, v := range tc.params { + params.Set(k, v) + } + req.URL.RawQuery = params.Encode() + + replicaIndex, nodeIndex, err := getSpecificHostIndex(req, testCluster) + if 
diff --git a/utils_test.go b/utils_test.go
index 9fbbc31d..630b050e 100644
--- a/utils_test.go
+++ b/utils_test.go
@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"compress/gzip"
 	"fmt"
+	"github.com/contentsquare/chproxy/internal/topology"
 	"github.com/stretchr/testify/assert"
 	"net/http"
 	"net/url"
@@ -374,3 +375,150 @@ func TestCalcMapHash(t *testing.T) {
 		})
 	}
 }
+
+func TestGetSpecificHostIndex(t *testing.T) {
+	// Create a test cluster with 2 replicas: replica1 has 3 nodes, replica2 has 2
+	testCluster := &cluster{
+		name: "test_cluster",
+		replicas: []*replica{
+			{
+				name:  "replica1",
+				hosts: []*topology.Node{{}, {}, {}},
+			},
+			{
+				name:  "replica2",
+				hosts: []*topology.Node{{}, {}},
+			},
+		},
+		maxReplicaIndex: 2,
+		maxNodeIndex:    3,
+	}
+	// Set the cluster reference for each replica
+	for _, r := range testCluster.replicas {
+		r.cluster = testCluster
+	}
+
+	testCases := []struct {
+		name          string
+		params        map[string]string
+		expectedRI    int
+		expectedNI    int
+		expectedError bool
+	}{
+		{
+			"no parameters",
+			map[string]string{},
+			0,
+			0,
+			false,
+		},
+		{
+			"only replica_index",
+			map[string]string{"replica_index": "1"},
+			1,
+			0,
+			false,
+		},
+		{
+			"only node_index",
+			map[string]string{"node_index": "2"},
+			0,
+			2,
+			false,
+		},
+		{
+			"only shard_index",
+			map[string]string{"shard_index": "3"},
+			0,
+			3,
+			false,
+		},
+		{
+			"replica_index and node_index",
+			map[string]string{"replica_index": "1", "node_index": "2"},
+			1,
+			2,
+			false,
+		},
+		{
+			"invalid replica_index",
+			map[string]string{"replica_index": "invalid"},
+			0,
+			0,
+			true,
+		},
+		{
+			"invalid node_index",
+			map[string]string{"node_index": "-1"},
+			0,
+			0,
+			true,
+		},
+		{
+			"replica_index out of range",
+			map[string]string{"replica_index": "3"},
+			0,
+			0,
+			true,
+		},
+		{
+			"node_index out of range",
+			map[string]string{"node_index": "4"},
+			0,
+			0,
+			true,
+		},
+		{
+			"node_index out of range for specific replica",
+			map[string]string{"replica_index": "2", "node_index": "3"},
+			0,
+			0,
+			true,
+		},
+		{
+			"replica_index is zero",
+			map[string]string{"replica_index": "0"},
+			0,
+			0,
+			false,
+		},
+		{
+			"node_index is zero",
+			map[string]string{"node_index": "0"},
+			0,
+			0,
+			false,
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			req, err := http.NewRequest("GET", "", nil)
+			checkErr(t, err)
+
+			// Set up the URL parameters
+			params := make(url.Values)
+			for k, v := range tc.params {
+				params.Set(k, v)
+			}
+			req.URL.RawQuery = params.Encode()
+
+			replicaIndex, nodeIndex, err := getSpecificHostIndex(req, testCluster)
+			if tc.expectedError {
+				if err == nil {
+					t.Fatalf("expected error but got none")
+				}
+			} else {
+				if err != nil {
+					t.Fatalf("unexpected error: %v", err)
+				}
+				if replicaIndex != tc.expectedRI {
+					t.Fatalf("unexpected replicaIndex: got %d, expecting %d", replicaIndex, tc.expectedRI)
+				}
+				if nodeIndex != tc.expectedNI {
+					t.Fatalf("unexpected nodeIndex: got %d, expecting %d", nodeIndex, tc.expectedNI)
+				}
+			}
+		})
+	}
+}
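One precedence rule the table above does not exercise is shard_index overriding node_index when both parameters are present. A sketch of an additional test for it, reusing the same fixture shape (the test name and fixture values are suggestions, not part of this patch):

// Sketch only: verifies the documented rule that shard_index overrides
// node_index when both query parameters are present.
func TestShardIndexOverridesNodeIndex(t *testing.T) {
	testCluster := &cluster{
		name: "test_cluster",
		replicas: []*replica{
			{name: "replica1", hosts: []*topology.Node{{}, {}, {}}},
		},
		maxReplicaIndex: 1,
		maxNodeIndex:    3,
	}

	req, err := http.NewRequest("GET", "", nil)
	checkErr(t, err)
	req.URL.RawQuery = url.Values{
		"node_index":  {"1"},
		"shard_index": {"2"},
	}.Encode()

	// getSpecificHostIndex parses node_index first, then lets shard_index win.
	_, nodeIndex, err := getSpecificHostIndex(req, testCluster)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if nodeIndex != 2 {
		t.Fatalf("shard_index should override node_index: got %d, want 2", nodeIndex)
	}
}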