diff --git a/controller/relay.go b/controller/relay.go index 5e2db44c25a..538c27b9d61 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -179,10 +179,11 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) { }() retryParam := &service.RetryParam{ - Ctx: c, - TokenGroup: relayInfo.TokenGroup, - ModelName: relayInfo.OriginModelName, - Retry: common.GetPointer(0), + Ctx: c, + TokenGroup: relayInfo.TokenGroup, + ModelName: relayInfo.OriginModelName, + RelayFormat: relayInfo.RelayFormat, + Retry: common.GetPointer(0), } relayInfo.RetryIndex = 0 relayInfo.LastError = nil @@ -507,10 +508,11 @@ func RelayTask(c *gin.Context) { }() retryParam := &service.RetryParam{ - Ctx: c, - TokenGroup: relayInfo.TokenGroup, - ModelName: relayInfo.OriginModelName, - Retry: common.GetPointer(0), + Ctx: c, + TokenGroup: relayInfo.TokenGroup, + ModelName: relayInfo.OriginModelName, + RelayFormat: relayInfo.RelayFormat, + Retry: common.GetPointer(0), } for ; retryParam.GetRetry() <= common.RetryTimes; retryParam.IncreaseRetry() { diff --git a/middleware/distributor.go b/middleware/distributor.go index 2263fae3fae..5fed4ca6011 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -129,10 +129,11 @@ func Distribute() func(c *gin.Context) { if channel == nil { channel, selectGroup, err = service.CacheGetRandomSatisfiedChannel(&service.RetryParam{ - Ctx: c, - ModelName: modelRequest.Model, - TokenGroup: usingGroup, - Retry: common.GetPointer(0), + Ctx: c, + ModelName: modelRequest.Model, + TokenGroup: usingGroup, + RelayFormat: types.InferRelayFormatFromPath(c.Request.URL.Path), + Retry: common.GetPointer(0), }) if err != nil { showGroup := usingGroup diff --git a/model/ability.go b/model/ability.go index 1d7c53fa580..29d1db43273 100644 --- a/model/ability.go +++ b/model/ability.go @@ -7,6 +7,8 @@ import ( "sync" "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/constant" + "github.com/QuantumNous/new-api/types" 
"github.com/samber/lo"
 	"gorm.io/gorm"
@@ -103,18 +105,118 @@ func getChannelQuery(group string, model string, retry int) (*gorm.DB, error) {
 	return channelQuery, nil
 }
 
-func GetChannel(group string, model string, retry int) (*Channel, error) {
-	var abilities []Ability
-
-	var err error = nil
+// getChannelQueryWithAPIType returns a query that filters channels by priority,
+// and when multiple channels share the same priority, prefers those whose
+// channel type matches the expected API type (for smart routing).
+//
+// Both abilities and channels expose a `group` column, so any reference to
+// `group` inside JOINed queries must be qualified with `abilities.` to avoid
+// "ambiguous column name" errors (notably on SQLite).
+//
+// Smart routing is best effort: if the priority lookup or the probing query
+// fails, or fewer than two candidates exist, or none of them matches
+// expectedAPIType, the plain getChannelQuery result is returned unchanged.
+// An error is returned only when getChannelQuery itself fails.
+func getChannelQueryWithAPIType(group string, model string, retry int, expectedAPIType int) (*gorm.DB, error) {
 	channelQuery, err := getChannelQuery(group, model, retry)
 	if err != nil {
 		return nil, err
 	}
+
+	abilitiesGroupCol := "abilities." + commonGroupCol
+
+	// Resolve the priority value once, shared by the probing query and the
+	// final filtered query. Reusing the existing channelQuery here would be
+	// unsafe because its WHERE clause references the bare `group` column.
+	var priorityValue any
+	if retry == 0 {
+		priorityValue = DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true)
+	} else {
+		// Distinct name so the outer err is not shadowed inside this branch.
+		p, priorityErr := getPriority(group, model, retry)
+		if priorityErr != nil {
+			return channelQuery, nil // fall back to original query
+		}
+		priorityValue = p
+	}
+
+	var abilities []AbilityWithChannel
+	err = DB.Table("abilities").
+		Select("abilities.*, channels.type as channel_type").
+		Joins("left join channels on abilities.channel_id = channels.id").
+		Where(abilitiesGroupCol+" = ? and abilities.model = ? and abilities.enabled = ? and abilities.priority = (?)", group, model, true, priorityValue).
+		Scan(&abilities).Error
+	if err != nil {
+		return channelQuery, nil // fall back to original query
+	}
+
+	if len(abilities) <= 1 {
+		return channelQuery, nil
+	}
+
+	// Check if any channel matches the expected API type
+	var hasMatch bool
+	for _, ab := range abilities {
+		channelAPIType, ok := common.ChannelType2APIType(ab.ChannelType)
+		if ok && channelAPIType == expectedAPIType {
+			hasMatch = true
+			break
+		}
+	}
+
+	if hasMatch {
+		filteredQuery := DB.Table("abilities").
+			Select("abilities.*").
+			Joins("left join channels on abilities.channel_id = channels.id").
+			Where(abilitiesGroupCol+" = ? and abilities.model = ? and abilities.enabled = ? and abilities.priority = (?)", group, model, true, priorityValue).
+			Where("channels.type IN (?)", getMatchingChannelTypes(expectedAPIType))
+		return filteredQuery, nil
+	}
+
+	return channelQuery, nil
+}
+
+// getMatchingChannelTypes returns channel types that map to the given API type.
+// It scans type IDs from 1 up to, but not including, constant.ChannelTypeDummy; the result is nil when none match.
+func getMatchingChannelTypes(expectedAPIType int) []int {
+	// Not named "types": that would shadow the imported types package.
+	var matched []int
+	for i := 1; i < constant.ChannelTypeDummy; i++ {
+		apiType, ok := common.ChannelType2APIType(i)
+		if ok && apiType == expectedAPIType {
+			matched = append(matched, i)
+		}
+	}
+	return matched
+}
+
+func GetChannel(group string, model string, retry int, relayFormat types.RelayFormat) (*Channel, error) {
+	var abilities []Ability
+
+	var err error = nil
+	var channelQuery *gorm.DB
+
+	// Use smart routing only when relayFormat maps to a known API type.
+	if relayFormat != "" {
+		if expectedAPIType, ok := types.RelayFormatToAPIType(relayFormat); ok {
+			channelQuery, err = getChannelQueryWithAPIType(group, model, retry, expectedAPIType)
+			if err != nil {
+				return nil, err
+			}
+		}
+	}
+
+	if channelQuery == nil {
+		channelQuery, err = getChannelQuery(group, model, retry)
+		if err != nil {
+			return nil, err
+		}
+	}
+
 	if common.UsingSQLite || common.UsingPostgreSQL {
-		err = channelQuery.Order("weight DESC").Find(&abilities).Error
+		err = 
channelQuery.Order("abilities.weight DESC").Find(&abilities).Error
 	} else {
-		err = channelQuery.Order("weight DESC").Find(&abilities).Error
+		err = channelQuery.Order("abilities.weight DESC").Find(&abilities).Error
 	}
 	if err != nil {
 		return nil, err
diff --git a/model/channel_cache.go b/model/channel_cache.go
index 03740d2cd3a..0e9c95b9311 100644
--- a/model/channel_cache.go
+++ b/model/channel_cache.go
@@ -13,6 +13,7 @@ import (
 	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/logger"
 	"github.com/QuantumNous/new-api/setting/ratio_setting"
+	"github.com/QuantumNous/new-api/types"
 )
 
 var group2model2channels map[string]map[string][]int // enabled channel
@@ -94,10 +95,10 @@ func SyncChannelCache(frequency int) {
 	}
 }
 
-func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
+func GetRandomSatisfiedChannel(group string, model string, retry int, relayFormat types.RelayFormat) (*Channel, error) {
 	// if memory cache is disabled, get channel directly from database
 	if !common.MemoryCacheEnabled {
-		return GetChannel(group, model, retry)
+		return GetChannel(group, model, retry, relayFormat)
 	}
 
 	channelSyncLock.RLock()
@@ -160,6 +161,31 @@ func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
 		return nil, errors.New(fmt.Sprintf("no channel found, group: %s, model: %s, priority: %d", group, model, targetPriority))
 	}
 
+	// Smart routing: when multiple channels share the same priority,
+	// prefer channels whose native API type matches the client request format.
+	// This avoids unnecessary request/response conversion.
+	if len(targetChannels) > 1 && relayFormat != "" {
+		if expectedAPIType, ok := types.RelayFormatToAPIType(relayFormat); ok {
+			var preferredChannels []*Channel
+			for _, ch := range targetChannels {
+				channelAPIType, typeOk := common.ChannelType2APIType(ch.Type)
+				if typeOk && channelAPIType == expectedAPIType {
+					preferredChannels = append(preferredChannels, ch)
+				}
+			}
+			// Only narrow the candidates if at least one matches;
+			// otherwise targetChannels is left untouched.
+			if len(preferredChannels) > 0 {
+				targetChannels = preferredChannels
+				// Recalculate sumWeight for the filtered set
+				sumWeight = 0
+				for _, ch := range targetChannels {
+					sumWeight += ch.GetWeight()
+				}
+			}
+		}
+	}
+
 	// smoothing factor and adjustment
 	smoothingFactor := 1
 	smoothingAdjustment := 0
diff --git a/service/channel_select.go b/service/channel_select.go
index a3710ef8cec..5228622c2c2 100644
--- a/service/channel_select.go
+++ b/service/channel_select.go
@@ -8,6 +8,7 @@ import (
 	"github.com/QuantumNous/new-api/logger"
 	"github.com/QuantumNous/new-api/model"
 	"github.com/QuantumNous/new-api/setting"
+	"github.com/QuantumNous/new-api/types"
 	"github.com/gin-gonic/gin"
 )
 
@@ -15,6 +16,7 @@ type RetryParam struct {
 	Ctx          *gin.Context
 	TokenGroup   string
 	ModelName    string
+	RelayFormat  types.RelayFormat // client request API format for smart channel routing
 	Retry        *int
 	resetNextTry bool
 }
@@ -115,7 +117,7 @@ func CacheGetRandomSatisfiedChannel(param *RetryParam) (*model.Channel, string,
 			}
 			logger.LogDebug(param.Ctx, "Auto selecting group: %s, priorityRetry: %d", autoGroup, priorityRetry)
 
-			channel, _ = model.GetRandomSatisfiedChannel(autoGroup, param.ModelName, priorityRetry)
+			channel, _ = model.GetRandomSatisfiedChannel(autoGroup, param.ModelName, priorityRetry, param.RelayFormat)
 			if channel == nil {
 				// Current group has no available channel for this model, try next group
 				// 当前分组没有该模型的可用渠道,尝试下一个分组
@@ 
-153,7 +155,7 @@ func CacheGetRandomSatisfiedChannel(param *RetryParam) (*model.Channel, string,
 			break
 		}
 	} else {
-		channel, err = model.GetRandomSatisfiedChannel(param.TokenGroup, param.ModelName, param.GetRetry())
+		channel, err = model.GetRandomSatisfiedChannel(param.TokenGroup, param.ModelName, param.GetRetry(), param.RelayFormat)
 		if err != nil {
 			return nil, param.TokenGroup, err
 		}
diff --git a/types/relay_format.go b/types/relay_format.go
index 9b4c86f2493..82d7fb8e628 100644
--- a/types/relay_format.go
+++ b/types/relay_format.go
@@ -1,5 +1,7 @@
 package types
 
+import "strings"
+
 type RelayFormat string
 
 const (
@@ -17,3 +19,82 @@ const (
 	RelayFormatTask    = "task"
 	RelayFormatMjProxy = "mj_proxy"
 )
+
+// RelayFormatToAPIType maps the client request relay format to the expected provider API type.
+// This is used for smart channel routing: prefer channels whose native API type matches
+// the client's request format, avoiding unnecessary request/response conversion.
+//
+// The second result is false when relayFormat has no associated API type, in which
+// case callers should skip API-type-based routing.
+func RelayFormatToAPIType(relayFormat RelayFormat) (int, bool) {
+	// The API type IDs are named once here instead of being repeated as bare
+	// literals in every branch.
+	// NOTE(review): the values come from the previous inline literals, which were
+	// annotated APITypeOpenAI / APITypeAnthropic / APITypeGemini. The declaration
+	// of those constants is not visible from this file -- confirm 13 is Gemini.
+	const (
+		apiTypeOpenAI    = 0
+		apiTypeAnthropic = 1
+		apiTypeGemini    = 13
+	)
+
+	switch relayFormat {
+	case RelayFormatOpenAI, RelayFormatOpenAIAudio, RelayFormatOpenAIImage,
+		RelayFormatOpenAIRealtime, RelayFormatOpenAIResponses,
+		RelayFormatOpenAIResponsesCompaction:
+		return apiTypeOpenAI, true
+	case RelayFormatEmbedding, RelayFormatRerank:
+		// Embedding and rerank requests can be served by several provider types;
+		// OpenAI is reported as the most common format. The result carries no
+		// "soft match" marker, so callers treat it like any other hint.
+		return apiTypeOpenAI, true
+	case RelayFormatClaude:
+		return apiTypeAnthropic, true
+	case RelayFormatGemini:
+		return apiTypeGemini, true
+	default:
+		return 0, false
+	}
+}
+
+// InferRelayFormatFromPath returns the RelayFormat that the request to the given URL path will
+// eventually be relayed as. The middleware Distribute() runs before the per-route handler that
+// sets the format explicitly, so smart channel routing has to peek at the path here.
+//
+// Matching is by path prefix and the first matching case wins, so a more specific prefix
+// (e.g. /v1/responses/compact) must stay ahead of its shorter sibling.
+//
+// Keep this in sync with router/relay-router.go. Unknown paths return an empty RelayFormat,
+// which downstream callers (e.g. model.GetChannel) treat as "no API-type hint" and fall back
+// to the original priority/weight-based selection.
+func InferRelayFormatFromPath(path string) RelayFormat {
+	switch {
+	case strings.HasPrefix(path, "/v1/chat/completions"), strings.HasPrefix(path, "/v1/completions"):
+		// OpenAI chat and legacy completion endpoints.
+		// NOTE(review): paths follow the public OpenAI API -- confirm against
+		// router/relay-router.go.
+		return RelayFormatOpenAI
+	case strings.HasPrefix(path, "/v1/messages"):
+		return RelayFormatClaude
+	case strings.HasPrefix(path, "/v1/responses/compact"):
+		return RelayFormatOpenAIResponsesCompaction
+	case strings.HasPrefix(path, "/v1/responses"):
+		return RelayFormatOpenAIResponses
+	case strings.HasPrefix(path, "/v1/realtime"):
+		return RelayFormatOpenAIRealtime
+	case strings.HasPrefix(path, "/v1/embeddings"):
+		return RelayFormatEmbedding
+	case strings.HasPrefix(path, "/v1/rerank"):
+		return RelayFormatRerank
+	case strings.HasPrefix(path, "/v1/audio/"):
+		return RelayFormatOpenAIAudio
+	case strings.HasPrefix(path, "/v1/images/"), strings.HasPrefix(path, "/v1/edits"):
+		return RelayFormatOpenAIImage
+	case strings.HasPrefix(path, "/v1/engines/") && strings.HasSuffix(path, "/embeddings"):
+		return RelayFormatEmbedding
+	case strings.HasPrefix(path, "/v1beta/models/"), strings.HasPrefix(path, "/v1/models/"):
+		return RelayFormatGemini
+	default:
+		return ""
+	}
+}