Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,12 @@ func executeWithRetry(
// comment s.host.dec() line to avoid double increment; issue #322
// s.host.dec()
s.host.SetIsActive(false)
nextHost := s.cluster.getHost()
var nextHost *topology.Node
if s.replicaIndex > 0 || s.nodeIndex > 0 {
nextHost = s.cluster.getSpecificHost(s.replicaIndex, s.nodeIndex)
} else {
nextHost = s.cluster.getHost()
}
// The query could be retried if it has no stickiness to a certain server
if numRetry < maxRetry && nextHost.IsActive() && s.sessionId == "" {
// the query execution has been failed
Expand Down Expand Up @@ -917,6 +922,11 @@ func (rp *reverseProxy) getScope(req *http.Request) (*scope, int, error) {
return nil, http.StatusForbidden, fmt.Errorf("cluster user %q is not allowed to access", cu.name)
}

s := newScope(req, u, c, cu, sessionId, sessionTimeout)
replicaIndex, nodeIndex, err := getSpecificHostIndex(req, c)
if err != nil {
return nil, http.StatusBadRequest, err
}

s := newScope(req, u, c, cu, sessionId, sessionTimeout, replicaIndex, nodeIndex)
return s, 0, nil
}
90 changes: 85 additions & 5 deletions scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ type scope struct {

sessionId string
sessionTimeout int
replicaIndex int
nodeIndex int

remoteAddr string
localAddr string
Expand All @@ -57,10 +59,14 @@ type scope struct {
requestPacketSize int
}

func newScope(req *http.Request, u *user, c *cluster, cu *clusterUser, sessionId string, sessionTimeout int) *scope {
h := c.getHost()
func newScope(req *http.Request, u *user, c *cluster, cu *clusterUser, sessionId string, sessionTimeout int, replicaIndex, nodeIndex int) *scope {
var h *topology.Node
if sessionId != "" {
h = c.getHostSticky(sessionId)
} else if replicaIndex > 0 || nodeIndex > 0 {
h = c.getSpecificHost(replicaIndex, nodeIndex)
} else {
h = c.getHost()
}
var localAddr string
if addr, ok := req.Context().Value(http.LocalAddrContextKey).(net.Addr); ok {
Expand All @@ -75,6 +81,8 @@ func newScope(req *http.Request, u *user, c *cluster, cu *clusterUser, sessionId
clusterUser: cu,
sessionId: sessionId,
sessionTimeout: sessionTimeout,
replicaIndex: replicaIndex,
nodeIndex: nodeIndex,

remoteAddr: req.RemoteAddr,
localAddr: localAddr,
Expand Down Expand Up @@ -185,11 +193,13 @@ func (s *scope) waitUntilAllowStart(sleep time.Duration, deadline time.Time, lab
var h *topology.Node
// Choose new host, since the previous one may become obsolete
// after sleeping.
if s.sessionId == "" {
h = s.cluster.getHost()
} else {
if s.sessionId != "" {
// if request has session_id, set same host
h = s.cluster.getHostSticky(s.sessionId)
} else if s.replicaIndex > 0 || s.nodeIndex > 0 {
h = s.cluster.getSpecificHost(s.replicaIndex, s.nodeIndex)
} else {
h = s.cluster.getHost()
}

s.host = h
Expand Down Expand Up @@ -720,6 +730,8 @@ func newReplicas(replicasCfg []config.Replica, nodes []string, scheme string, c
return nil, err
}
r.hosts = hosts
c.maxNodeIndex = len(r.hosts)
c.maxReplicaIndex = 1
return []*replica{r}, nil
}

Expand All @@ -735,7 +747,9 @@ func newReplicas(replicasCfg []config.Replica, nodes []string, scheme string, c
}
r.hosts = hosts
replicas[i] = r
c.maxNodeIndex = max(c.maxNodeIndex, len(r.hosts))
}
c.maxReplicaIndex = len(replicas)
return replicas, nil
}

Expand Down Expand Up @@ -775,6 +789,9 @@ type cluster struct {
replicas []*replica
nextReplicaIdx uint32

maxReplicaIndex int
maxNodeIndex int

users map[string]*clusterUser

killQueryUserName string
Expand Down Expand Up @@ -937,6 +954,59 @@ func (r *replica) getHostSticky(sessionId string) *topology.Node {
return h
}

// getSpecificReplica returns specific replica by replicaIndex from the cluster.
//
// Always returns non-nil.
func (c *cluster) getSpecificReplica(replicaIndex, nodeIndex int) *replica {
if replicaIndex > 0 {
return c.replicas[replicaIndex-1]
}
if nodeIndex == 0 {
return c.getReplica()
}

idx := atomic.AddUint32(&c.nextReplicaIdx, 1)
n := uint32(len(c.replicas))
if n == 1 {
return c.replicas[0]
}

var r *replica
reqs := ^uint32(0)

// Scan all the replicas for the least loaded and nodeIndex-satisfied replica.
for i := uint32(0); i < n; i++ {
tmpIdx := (idx + i) % n
tmpR := c.replicas[tmpIdx]
if nodeIndex > len(tmpR.hosts) {
continue
}
if tmpR.isActive() || r == nil {
tmpReqs := tmpR.load()
if tmpReqs < reqs || !r.isActive() {
r = tmpR
reqs = tmpReqs
}
}
}

// The returned replica may be inactive. This is OK,
// since this means all the nodeIndex-satisfied replicas are inactive,
// so let's try proxying the request to any replica.
return r
}

// getSpecificHost returns specific host by nodeIndex from replica.
//
// Always returns non-nil.
func (r *replica) getSpecificHost(nodeIndex int) *topology.Node {
if nodeIndex > 0 {
return r.hosts[nodeIndex-1]
}

return r.getHost()
}

// getHost returns least loaded + round-robin host from replica.
//
// Always returns non-nil.
Expand Down Expand Up @@ -991,6 +1061,16 @@ func (c *cluster) getHostSticky(sessionId string) *topology.Node {
return r.getHostSticky(sessionId)
}

// getSpecificHost returns specific host by index from cluster.
// Both replicaIndex/nodeIndex start from 1 and satisfy [0, maxReplicaIndex/maxNodeIndex], 0 means no specific host index.
// If both are 0, getSpecificHost equals to getHost.
//
// Always returns non-nil.
func (c *cluster) getSpecificHost(replicaIndex, nodeIndex int) *topology.Node {
r := c.getSpecificReplica(replicaIndex, nodeIndex)
return r.getSpecificHost(nodeIndex)
}

// getHost returns least loaded + round-robin host from cluster.
//
// Always returns non-nil.
Expand Down
78 changes: 78 additions & 0 deletions scope_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,78 @@ func TestGetHostSticky(t *testing.T) {
}
}

func TestGetSpecificHost(t *testing.T) {
c := testGetCluster()

t.Run("SpecifyReplicaIndex", func(t *testing.T) {
h := c.getSpecificHost(1, 0)
if h.Host() != "127.0.0.11" && h.Host() != "127.0.0.22" {
t.Fatalf("Expected host from replica1, got: %s", h.Host())
}

h = c.getSpecificHost(2, 0)
if h.Host() != "127.0.0.33" && h.Host() != "127.0.0.44" {
t.Fatalf("Expected host from replica2, got: %s", h.Host())
}

h = c.getSpecificHost(3, 0)
if h.Host() != "127.0.0.55" && h.Host() != "127.0.0.66" {
t.Fatalf("Expected host from replica3, got: %s", h.Host())
}
})

t.Run("SpecifyNodeIndex", func(t *testing.T) {
h := c.getSpecificHost(0, 1)
if h.Host() != "127.0.0.11" && h.Host() != "127.0.0.33" && h.Host() != "127.0.0.55" {
t.Fatalf("Expected first node from any replica, got: %s", h.Host())
}

h = c.getSpecificHost(0, 2)
if h.Host() != "127.0.0.22" && h.Host() != "127.0.0.44" && h.Host() != "127.0.0.66" {
t.Fatalf("Expected second node from any replica, got: %s", h.Host())
}
})

t.Run("SpecifyReplicaIndexAndNodeIndex", func(t *testing.T) {
h := c.getSpecificHost(1, 1)
if h.Host() != "127.0.0.11" {
t.Fatalf("Expected 127.0.0.11, got: %s", h.Host())
}

h = c.getSpecificHost(1, 2)
if h.Host() != "127.0.0.22" {
t.Fatalf("Expected 127.0.0.22, got: %s", h.Host())
}

h = c.getSpecificHost(2, 1)
if h.Host() != "127.0.0.33" {
t.Fatalf("Expected 127.0.0.33, got: %s", h.Host())
}
})

t.Run("SpecifyBothIndicesZero", func(t *testing.T) {
h := c.getSpecificHost(0, 0)
if h == nil {
t.Fatalf("getSpecificHost(0, 0) returned nil")
}
found := false
for _, r := range c.replicas {
for _, node := range r.hosts {
if h.Host() == node.Host() {
found = true
break
}
}
if found {
break
}
}
if !found {
t.Fatalf("getSpecificHost(0, 0) returned unknown host: %s", h.Host())
}
})
}

func TestIncQueued(t *testing.T) {
u := testGetUser()
cu := testGetClusterUser()
Expand Down Expand Up @@ -485,6 +557,12 @@ func testGetCluster() *cluster {
topology.NewNode(&url.URL{Host: "127.0.0.66"}, nil, "", r3.name, topology.WithDefaultActiveState(true)),
}
r3.name = "replica3"

c.maxReplicaIndex = len(c.replicas)
for _, r := range c.replicas {
c.maxNodeIndex = max(c.maxNodeIndex, len(r.hosts))
}

return c
}

Expand Down
47 changes: 47 additions & 0 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,53 @@ func getSessionTimeout(req *http.Request) int {
return 60
}

// getSpecificHostIndex retrieves specific host index, including replica and node index
// index starts from 1, 0 means no specific host index
// shard_index is alias for node_index, and override node_index if both are specified
func getSpecificHostIndex(req *http.Request, c *cluster) (int, int, error) {
params := req.URL.Query()
var replicaIndex, nodeIndex int
var err error
// replica index
replicaIndexStr := params.Get("replica_index")
if replicaIndexStr != "" {
replicaIndex, err = strconv.Atoi(replicaIndexStr)
if err != nil {
return -1, -1, fmt.Errorf("invalid replica index %q", replicaIndexStr)
}
if replicaIndex < 0 || replicaIndex > c.maxReplicaIndex {
return -1, -1, fmt.Errorf("invalid replica index %q", replicaIndexStr)
}
}
// node index (shard_index is alias for node_index)
nodeIndexStr := params.Get("node_index")
if nodeIndexStr != "" {
nodeIndex, err = strconv.Atoi(nodeIndexStr)
if err != nil {
return -1, -1, fmt.Errorf("invalid node index %q", nodeIndexStr)
}
if nodeIndex < 0 || nodeIndex > c.maxNodeIndex {
return -1, -1, fmt.Errorf("invalid node index %q", nodeIndexStr)
}
}
shardIndexStr := params.Get("shard_index")
if shardIndexStr != "" {
nodeIndex, err = strconv.Atoi(shardIndexStr)
if err != nil {
return -1, -1, fmt.Errorf("invalid shard index %q", shardIndexStr)
}
if nodeIndex < 0 || nodeIndex > c.maxNodeIndex {
return -1, -1, fmt.Errorf("invalid shard index %q", shardIndexStr)
}
}
// validate if both replicaIndex and nodeIndex are specified
if replicaIndex > 0 && nodeIndex > 0 && nodeIndex > len(c.replicas[replicaIndex-1].hosts) {
return -1, -1, fmt.Errorf("invalid host index (%q, %q)", replicaIndexStr, nodeIndexStr)
}

return replicaIndex, nodeIndex, nil
}

// getQuerySnippet returns query snippet.
//
// getQuerySnippet must be called only for error reporting.
Expand Down
Loading