Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions middleware/distributor.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ func Distribute() func(c *gin.Context) {
}
}

if preferredChannelID, found := service.GetPreferredChannelByAffinity(c, modelRequest.Model, usingGroup); found {
preferred, err := model.CacheGetChannel(preferredChannelID)
if affinitySelection, found := service.GetPreferredChannelByAffinity(c, modelRequest.Model, usingGroup); found {
preferred, err := model.CacheGetChannel(affinitySelection.ChannelID)
if err == nil && preferred != nil {
if preferred.Status != common.ChannelStatusEnabled {
if service.ShouldSkipRetryAfterChannelAffinityFailure(c) {
Expand All @@ -117,14 +117,14 @@ func Distribute() func(c *gin.Context) {
selectGroup = g
common.SetContextKey(c, constant.ContextKeyAutoGroup, g)
channel = preferred
service.MarkChannelAffinityUsed(c, g, preferred.Id)
service.MarkChannelAffinityUsed(c, g, affinitySelection)
break
}
}
} else if model.IsChannelEnabledForGroupModel(usingGroup, modelRequest.Model, preferred.Id) {
channel = preferred
selectGroup = usingGroup
service.MarkChannelAffinityUsed(c, usingGroup, preferred.Id)
service.MarkChannelAffinityUsed(c, usingGroup, affinitySelection)
}
}
}
Expand Down Expand Up @@ -421,10 +421,19 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
common.SetContextKey(c, constant.ContextKeyChannelModelMapping, channel.GetModelMapping())
common.SetContextKey(c, constant.ContextKeyChannelStatusCodeMapping, channel.GetStatusCodeMapping())

key, index, newAPIError := channel.GetNextEnabledKey()
if newAPIError != nil {
return newAPIError
var key string
var index int
var newAPIError *types.NewAPIError
if preferredKeyIndex, ok := service.GetChannelAffinityKeyIndex(c, channel.Id); ok {
key, index, newAPIError = channel.GetEnabledKeyByIndex(preferredKeyIndex)
}
if key == "" || newAPIError != nil {
key, index, newAPIError = channel.GetNextEnabledKey()
if newAPIError != nil {
return newAPIError
}
}
service.UpdateChannelAffinitySelectedKeyIndex(c, channel.Id, index)
if channel.ChannelInfo.IsMultiKey {
common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, true)
common.SetContextKey(c, constant.ContextKeyChannelMultiKeyIndex, index)
Expand Down
32 changes: 32 additions & 0 deletions model/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,38 @@ func (channel *Channel) GetKeys() []string {
return keys
}

// GetEnabledKeyByIndex returns the key stored at position index together with
// that index, provided the key is currently enabled.
//
// For a channel that is not in multi-key mode the only valid index is 0, which
// yields channel.Key unchanged. For a multi-key channel the index must fall
// inside the slice returned by GetKeys, and the entry for that index in
// MultiKeyStatusList (if any) must be common.ChannelStatusEnabled; an index with
// no entry in the status list is treated as enabled.
//
// On any failure the returned key is "", the returned index is 0 and the error
// carries types.ErrorCodeChannelNoAvailableKey. This method does not write to
// the channel's polling index.
func (channel *Channel) GetEnabledKeyByIndex(index int) (string, int, *types.NewAPIError) {
	if !channel.ChannelInfo.IsMultiKey {
		if index == 0 {
			return channel.Key, 0, nil
		}
		return "", 0, types.NewError(errors.New("invalid key index"), types.ErrorCodeChannelNoAvailableKey)
	}

	keys := channel.GetKeys()
	switch {
	case len(keys) == 0:
		return "", 0, types.NewError(errors.New("no keys available"), types.ErrorCodeChannelNoAvailableKey)
	case index < 0 || index >= len(keys):
		return "", 0, types.NewError(errors.New("invalid key index"), types.ErrorCodeChannelNoAvailableKey)
	}

	// The status list is read while holding the per-channel polling lock.
	pollingLock := GetChannelPollingLock(channel.Id)
	pollingLock.Lock()
	defer pollingLock.Unlock()

	// Indexing a nil map is safe and reports "not found", so a missing status
	// list and a missing entry both fall through to the enabled path.
	if keyStatus, listed := channel.ChannelInfo.MultiKeyStatusList[index]; listed && keyStatus != common.ChannelStatusEnabled {
		return "", 0, types.NewError(errors.New("key is disabled"), types.ErrorCodeChannelNoAvailableKey)
	}
	return keys[index], index, nil
}

func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) {
// If not in multi-key mode, return the original key string directly.
if !channel.ChannelInfo.IsMultiKey {
Expand Down
48 changes: 48 additions & 0 deletions model/channel_key_affinity_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package model

import (
"testing"

"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/stretchr/testify/require"
)

func TestGetEnabledKeyByIndex(t *testing.T) {
channel := &Channel{
Id: 9001,
Key: "key-a\nkey-b\nkey-c",
ChannelInfo: ChannelInfo{
IsMultiKey: true,
MultiKeyMode: constant.MultiKeyModePolling,
MultiKeyPollingIndex: 0,
MultiKeyStatusList: map[int]int{
1: common.ChannelStatusEnabled,
},
},
}

key, index, apiErr := channel.GetEnabledKeyByIndex(1)
require.Nil(t, apiErr)
require.Equal(t, "key-b", key)
require.Equal(t, 1, index)
require.Equal(t, 0, channel.ChannelInfo.MultiKeyPollingIndex)
}

func TestGetEnabledKeyByIndexDisabled(t *testing.T) {
channel := &Channel{
Id: 9002,
Key: "key-a\nkey-b",
ChannelInfo: ChannelInfo{
IsMultiKey: true,
MultiKeyStatusList: map[int]int{
1: common.ChannelStatusManuallyDisabled,
},
},
}

key, index, apiErr := channel.GetEnabledKeyByIndex(1)
require.NotNil(t, apiErr)
require.Empty(t, key)
require.Equal(t, 0, index)
}
97 changes: 79 additions & 18 deletions service/channel_affinity.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"time"

"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/pkg/cachex"
"github.com/QuantumNous/new-api/setting/operation_setting"
Expand All @@ -25,21 +26,28 @@ const (
ginKeyChannelAffinityMeta = "channel_affinity_meta"
ginKeyChannelAffinityLogInfo = "channel_affinity_log_info"
ginKeyChannelAffinitySkipRetry = "channel_affinity_skip_retry_on_failure"
ginKeyChannelAffinityKeyIndex = "channel_affinity_key_index"
ginKeyChannelAffinityChannelID = "channel_affinity_channel_id"

channelAffinityCacheNamespace = "new-api:channel_affinity:v1"
channelAffinityUsageCacheStatsNamespace = "new-api:channel_affinity_usage_cache_stats:v1"
channelAffinityCacheNamespace = "new-api:channel_affinity:v2"
channelAffinityUsageCacheStatsNamespace = "new-api:channel_affinity_usage_cache_stats:v2"
)

var (
channelAffinityCacheOnce sync.Once
channelAffinityCache *cachex.HybridCache[int]
channelAffinityCache *cachex.HybridCache[ChannelAffinitySelection]

channelAffinityUsageCacheStatsOnce sync.Once
channelAffinityUsageCacheStatsCache *cachex.HybridCache[ChannelAffinityUsageCacheCounters]

channelAffinityRegexCache sync.Map // map[string]*regexp.Regexp
)

// ChannelAffinitySelection pairs a channel with the index of one of its keys.
// It is serialized to JSON under the field names "channel_id" and "key_index".
type ChannelAffinitySelection struct {
	// ChannelID identifies the selected channel.
	ChannelID int `json:"channel_id"`
	// KeyIndex is the index of the selected key on that channel.
	KeyIndex int `json:"key_index"`
}

type channelAffinityMeta struct {
CacheKey string
TTLSeconds int
Expand Down Expand Up @@ -78,7 +86,7 @@ type ChannelAffinityCacheStats struct {
CacheAlgo string `json:"cache_algo"`
}

func getChannelAffinityCache() *cachex.HybridCache[int] {
func getChannelAffinityCache() *cachex.HybridCache[ChannelAffinitySelection] {
channelAffinityCacheOnce.Do(func() {
setting := operation_setting.GetChannelAffinitySetting()
capacity := setting.MaxEntries
Expand All @@ -90,15 +98,15 @@ func getChannelAffinityCache() *cachex.HybridCache[int] {
defaultTTLSeconds = 3600
}

channelAffinityCache = cachex.NewHybridCache[int](cachex.HybridCacheConfig[int]{
channelAffinityCache = cachex.NewHybridCache[ChannelAffinitySelection](cachex.HybridCacheConfig[ChannelAffinitySelection]{
Namespace: cachex.Namespace(channelAffinityCacheNamespace),
Redis: common.RDB,
RedisEnabled: func() bool {
return common.RedisEnabled && common.RDB != nil
},
RedisCodec: cachex.IntCodec{},
Memory: func() *hot.HotCache[string, int] {
return hot.NewHotCache[string, int](hot.LRU, capacity).
RedisCodec: cachex.JSONCodec[ChannelAffinitySelection]{},
Memory: func() *hot.HotCache[string, ChannelAffinitySelection] {
return hot.NewHotCache[string, ChannelAffinitySelection](hot.LRU, capacity).
WithTTL(time.Duration(defaultTTLSeconds) * time.Second).
WithJanitor().
Build()
Expand Down Expand Up @@ -547,10 +555,10 @@ func ApplyChannelAffinityOverrideTemplate(c *gin.Context, paramOverride map[stri
return mergedParam, true
}

func GetPreferredChannelByAffinity(c *gin.Context, modelName string, usingGroup string) (int, bool) {
func GetPreferredChannelByAffinity(c *gin.Context, modelName string, usingGroup string) (ChannelAffinitySelection, bool) {
setting := operation_setting.GetChannelAffinitySetting()
if setting == nil || !setting.Enabled {
return 0, false
return ChannelAffinitySelection{}, false
}
path := ""
if c != nil && c.Request != nil && c.Request.URL != nil {
Expand Down Expand Up @@ -610,17 +618,17 @@ func GetPreferredChannelByAffinity(c *gin.Context, modelName string, usingGroup
})

cache := getChannelAffinityCache()
channelID, found, err := cache.Get(cacheKeySuffix)
selection, found, err := cache.Get(cacheKeySuffix)
if err != nil {
common.SysError(fmt.Sprintf("channel affinity cache get failed: key=%s, err=%v", cacheKeyFull, err))
return 0, false
return ChannelAffinitySelection{}, false
}
if found {
return channelID, true
if found && selection.ChannelID > 0 {
return selection, true
}
return 0, false
return ChannelAffinitySelection{}, false
}
return 0, false
return ChannelAffinitySelection{}, false
}

func ShouldSkipRetryAfterChannelAffinityFailure(c *gin.Context) bool {
Expand All @@ -641,7 +649,8 @@ func ShouldSkipRetryAfterChannelAffinityFailure(c *gin.Context) bool {
return meta.SkipRetry
}

func MarkChannelAffinityUsed(c *gin.Context, selectedGroup string, channelID int) {
func MarkChannelAffinityUsed(c *gin.Context, selectedGroup string, selection ChannelAffinitySelection) {
channelID := selection.ChannelID
if c == nil || channelID <= 0 {
return
}
Expand All @@ -650,6 +659,8 @@ func MarkChannelAffinityUsed(c *gin.Context, selectedGroup string, channelID int
return
}
c.Set(ginKeyChannelAffinitySkipRetry, meta.SkipRetry)
c.Set(ginKeyChannelAffinityChannelID, channelID)
c.Set(ginKeyChannelAffinityKeyIndex, selection.KeyIndex)
info := map[string]interface{}{
"reason": meta.RuleName,
"rule_name": meta.RuleName,
Expand All @@ -658,6 +669,7 @@ func MarkChannelAffinityUsed(c *gin.Context, selectedGroup string, channelID int
"model": meta.ModelName,
"request_path": meta.RequestPath,
"channel_id": channelID,
"key_index": selection.KeyIndex,
"key_source": meta.KeySourceType,
"key_key": meta.KeySourceKey,
"key_path": meta.KeySourcePath,
Expand All @@ -667,6 +679,47 @@ func MarkChannelAffinityUsed(c *gin.Context, selectedGroup string, channelID int
c.Set(ginKeyChannelAffinityLogInfo, info)
}

// GetChannelAffinityKeyIndex returns the key index stored on the request
// context, but only when the channel ID stored alongside it equals channelID.
//
// The second result is false when c is nil, channelID is not positive, either
// context value is missing or not an int, the stored channel ID differs from
// channelID, or the stored key index is negative. In all of those cases the
// returned index is 0.
func GetChannelAffinityKeyIndex(c *gin.Context, channelID int) (int, bool) {
	if c == nil || channelID <= 0 {
		return 0, false
	}

	// contextInt fetches a context value and reports whether it was present
	// and held an int.
	contextInt := func(key string) (int, bool) {
		raw, present := c.Get(key)
		if !present {
			return 0, false
		}
		n, isInt := raw.(int)
		return n, isInt
	}

	if storedID, ok := contextInt(ginKeyChannelAffinityChannelID); !ok || storedID != channelID {
		return 0, false
	}
	storedIndex, ok := contextInt(ginKeyChannelAffinityKeyIndex)
	if !ok || storedIndex < 0 {
		return 0, false
	}
	return storedIndex, true
}

// UpdateChannelAffinitySelectedKeyIndex stores channelID and keyIndex on the
// request context and, when a log-info map is already present on the context,
// overwrites its "channel_id" and "key_index" entries with the same values.
//
// It is a no-op when c is nil, channelID is not positive, keyIndex is negative,
// or getChannelAffinityMeta reports no metadata for the request.
func UpdateChannelAffinitySelectedKeyIndex(c *gin.Context, channelID int, keyIndex int) {
	if c == nil || channelID <= 0 || keyIndex < 0 {
		return
	}
	if _, hasMeta := getChannelAffinityMeta(c); !hasMeta {
		return
	}

	c.Set(ginKeyChannelAffinityChannelID, channelID)
	c.Set(ginKeyChannelAffinityKeyIndex, keyIndex)

	rawInfo, present := c.Get(ginKeyChannelAffinityLogInfo)
	if !present {
		return
	}
	logInfo, isMap := rawInfo.(map[string]interface{})
	if !isMap {
		return
	}
	// The map is updated in place and then stored back under the same key.
	logInfo["channel_id"] = channelID
	logInfo["key_index"] = keyIndex
	c.Set(ginKeyChannelAffinityLogInfo, logInfo)
}

func AppendChannelAffinityAdminInfo(c *gin.Context, adminInfo map[string]interface{}) {
if c == nil || adminInfo == nil {
return
Expand All @@ -691,6 +744,10 @@ func RecordChannelAffinity(c *gin.Context, channelID int) {
channelID = successChannelID
}
}
keyIndex := 0
if c != nil {
keyIndex = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex)
}
cacheKey, ttlSeconds, ok := getChannelAffinityContext(c)
if !ok {
return
Expand All @@ -702,7 +759,11 @@ func RecordChannelAffinity(c *gin.Context, channelID int) {
ttlSeconds = 3600
}
cache := getChannelAffinityCache()
if err := cache.SetWithTTL(cacheKey, channelID, time.Duration(ttlSeconds)*time.Second); err != nil {
selection := ChannelAffinitySelection{
ChannelID: channelID,
KeyIndex: keyIndex,
}
if err := cache.SetWithTTL(cacheKey, selection, time.Duration(ttlSeconds)*time.Second); err != nil {
common.SysError(fmt.Sprintf("channel affinity cache set failed: key=%s, err=%v", cacheKey, err))
}
}
Expand Down
Loading