diff --git a/model/channel.go b/model/channel.go index 3e6d1866a09..78a1477c327 100644 --- a/model/channel.go +++ b/model/channel.go @@ -643,13 +643,25 @@ func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int, reason if len(keys) == 0 { channel.Status = status } else { - var keyIndex int + keyIndex := -1 for i, key := range keys { if key == usingKey { keyIndex = i break } } + if keyIndex < 0 { + if usingKey != "" { + common.SysLog(fmt.Sprintf("failed to update multi-key status: channel_id=%d, using key not found", channel.Id)) + return + } + channel.Status = status + info := channel.GetOtherInfo() + info["status_reason"] = reason + info["status_time"] = common.GetTimestamp() + channel.SetOtherInfo(info) + return + } if channel.ChannelInfo.MultiKeyStatusList == nil { channel.ChannelInfo.MultiKeyStatusList = make(map[int]int) } @@ -666,16 +678,31 @@ func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int, reason channel.ChannelInfo.MultiKeyDisabledReason[keyIndex] = reason channel.ChannelInfo.MultiKeyDisabledTime[keyIndex] = common.GetTimestamp() } - if len(channel.ChannelInfo.MultiKeyStatusList) >= channel.ChannelInfo.MultiKeySize { + if !hasEnabledMultiKey(keys, channel.ChannelInfo.MultiKeyStatusList) { channel.Status = common.ChannelStatusAutoDisabled info := channel.GetOtherInfo() info["status_reason"] = "All keys are disabled" info["status_time"] = common.GetTimestamp() channel.SetOtherInfo(info) + } else if status == common.ChannelStatusEnabled { + channel.Status = common.ChannelStatusEnabled } } } +func hasEnabledMultiKey(keys []string, statusList map[int]int) bool { + for i := range keys { + if statusList == nil { + return true + } + status, ok := statusList[i] + if !ok || status == common.ChannelStatusEnabled { + return true + } + } + return false +} + func UpdateChannelStatus(channelId int, usingKey string, status int, reason string) bool { if common.MemoryCacheEnabled { channelStatusLock.Lock() @@ -687,11 +714,15 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri } if channelCache.ChannelInfo.IsMultiKey { // Use per-channel lock to prevent concurrent map read/write with GetNextEnabledKey + beforeStatus := channelCache.Status pollingLock := GetChannelPollingLock(channelId) pollingLock.Lock() // 如果是多Key模式,更新缓存中的状态 handlerMultiKeyUpdate(channelCache, usingKey, status, reason) pollingLock.Unlock() + if beforeStatus != channelCache.Status { + CacheUpdateChannelStatus(channelId, channelCache.Status) + } //CacheUpdateChannel(channelCache) //return true } else { diff --git a/model/channel_cache_test.go b/model/channel_cache_test.go new file mode 100644 index 00000000000..61795c81c5e --- /dev/null +++ b/model/channel_cache_test.go @@ -0,0 +1,71 @@ +package model + +import ( + "testing" + + "github.com/QuantumNous/new-api/common" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func withChannelCacheTestState(t *testing.T) { + t.Helper() + originalMemoryCacheEnabled := common.MemoryCacheEnabled + originalGroup2Model2Channels := group2model2channels + originalChannelsIDM := channelsIDM + + common.MemoryCacheEnabled = true + group2model2channels = map[string]map[string][]int{} + channelsIDM = map[int]*Channel{} + + t.Cleanup(func() { + common.MemoryCacheEnabled = originalMemoryCacheEnabled + group2model2channels = originalGroup2Model2Channels + channelsIDM = originalChannelsIDM + }) +} + +func TestUpdateChannelStatusEvictsMultiKeyChannelFromRouteCache(t *testing.T) { + withChannelCacheTestState(t) + + channelsIDM = map[int]*Channel{ + 1: { + Id: 1, + Status: common.ChannelStatusEnabled, + Key: "k1", + Group: "default", + Models: "gpt-test", + ChannelInfo: ChannelInfo{ + IsMultiKey: true, + MultiKeySize: 1, + MultiKeyStatusList: map[int]int{}, + }, + }, + 2: { + Id: 2, + Status: common.ChannelStatusEnabled, + Group: "default", + Models: "gpt-test", + }, + } + group2model2channels = map[string]map[string][]int{ + "default": {"gpt-test": {1, 2}}, + } + + cache := channelsIDM[1] + pollingLock := GetChannelPollingLock(cache.Id) + pollingLock.Lock() + beforeStatus := cache.Status + handlerMultiKeyUpdate(cache, "k1", common.ChannelStatusAutoDisabled, "test reason") + pollingLock.Unlock() + require.NotEqual(t, beforeStatus, cache.Status, "channel should auto-disable when all keys are disabled") + CacheUpdateChannelStatus(cache.Id, cache.Status) + + assert.NotContains(t, group2model2channels["default"]["gpt-test"], 1, + "auto-disabled multi-key channel should be removed from route cache") + + channel, err := GetRandomSatisfiedChannel("default", "gpt-test", 0) + require.NoError(t, err) + require.NotNil(t, channel) + assert.Equal(t, 2, channel.Id) +}