Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions controller/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,11 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
}()

retryParam := &service.RetryParam{
Ctx: c,
TokenGroup: relayInfo.TokenGroup,
ModelName: relayInfo.OriginModelName,
Retry: common.GetPointer(0),
Ctx: c,
TokenGroup: relayInfo.TokenGroup,
ModelName: relayInfo.OriginModelName,
RelayFormat: relayInfo.RelayFormat,
Retry: common.GetPointer(0),
}
relayInfo.RetryIndex = 0
relayInfo.LastError = nil
Expand Down Expand Up @@ -507,10 +508,11 @@ func RelayTask(c *gin.Context) {
}()

retryParam := &service.RetryParam{
Ctx: c,
TokenGroup: relayInfo.TokenGroup,
ModelName: relayInfo.OriginModelName,
Retry: common.GetPointer(0),
Ctx: c,
TokenGroup: relayInfo.TokenGroup,
ModelName: relayInfo.OriginModelName,
RelayFormat: relayInfo.RelayFormat,
Retry: common.GetPointer(0),
}

for ; retryParam.GetRetry() <= common.RetryTimes; retryParam.IncreaseRetry() {
Expand Down
9 changes: 5 additions & 4 deletions middleware/distributor.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,11 @@ func Distribute() func(c *gin.Context) {

if channel == nil {
channel, selectGroup, err = service.CacheGetRandomSatisfiedChannel(&service.RetryParam{
Ctx: c,
ModelName: modelRequest.Model,
TokenGroup: usingGroup,
Retry: common.GetPointer(0),
Ctx: c,
ModelName: modelRequest.Model,
TokenGroup: usingGroup,
RelayFormat: types.InferRelayFormatFromPath(c.Request.URL.Path),
Retry: common.GetPointer(0),
})
if err != nil {
showGroup := usingGroup
Expand Down
106 changes: 100 additions & 6 deletions model/ability.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"sync"

"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/types"

"github.com/samber/lo"
"gorm.io/gorm"
Expand Down Expand Up @@ -103,18 +105,110 @@ func getChannelQuery(group string, model string, retry int) (*gorm.DB, error) {
return channelQuery, nil
}

func GetChannel(group string, model string, retry int) (*Channel, error) {
var abilities []Ability

var err error = nil
// getChannelQueryWithAPIType returns a query that filters channels by priority,
// and when multiple channels share the same priority, prefers those whose
// channel type matches the expected API type (for smart routing).
//
// Both abilities and channels expose a `group` column, so any reference to
// `group` inside JOINed queries must be qualified with `abilities.` to avoid
// "ambiguous column name" errors (notably on SQLite).
func getChannelQueryWithAPIType(group string, model string, retry int, expectedAPIType int) (*gorm.DB, error) {
channelQuery, err := getChannelQuery(group, model, retry)
if err != nil {
return nil, err
}

abilitiesGroupCol := "abilities." + commonGroupCol

// Resolve the priority value once, shared by the probing query and the
// final filtered query. Reusing the existing channelQuery here would be
// unsafe because its WHERE clause references the bare `group` column.
var priorityValue interface{}
if retry == 0 {
priorityValue = DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true)
} else {
p, err := getPriority(group, model, retry)
if err != nil {
return channelQuery, nil
}
priorityValue = p
}

var abilities []AbilityWithChannel
err = DB.Table("abilities").
Select("abilities.*, channels.type as channel_type").
Joins("left join channels on abilities.channel_id = channels.id").
Where(abilitiesGroupCol+" = ? and abilities.model = ? and abilities.enabled = ? and abilities.priority = (?)", group, model, true, priorityValue).
Scan(&abilities).Error
if err != nil {
return channelQuery, nil // fall back to original query
}

if len(abilities) <= 1 {
return channelQuery, nil
}

// Check if any channel matches the expected API type
var hasMatch bool
for _, ab := range abilities {
channelAPIType, ok := common.ChannelType2APIType(ab.ChannelType)
if ok && channelAPIType == expectedAPIType {
hasMatch = true
break
}
}

if hasMatch {
filteredQuery := DB.Table("abilities").
Select("abilities.*").
Joins("left join channels on abilities.channel_id = channels.id").
Where(abilitiesGroupCol+" = ? and abilities.model = ? and abilities.enabled = ? and abilities.priority = (?)", group, model, true, priorityValue).
Where("channels.type IN (?)", getMatchingChannelTypes(expectedAPIType))
return filteredQuery, nil
}

return channelQuery, nil
}

// getMatchingChannelTypes returns channel types that map to the given API type.
func getMatchingChannelTypes(expectedAPIType int) []int {
var types []int
for i := 1; i < constant.ChannelTypeDummy; i++ {
apiType, ok := common.ChannelType2APIType(i)
if ok && apiType == expectedAPIType {
types = append(types, i)
}
}
return types
}

func GetChannel(group string, model string, retry int, relayFormat types.RelayFormat) (*Channel, error) {
var abilities []Ability

var err error = nil
var channelQuery *gorm.DB

// Use smart routing when relayFormat is provided and memory cache is disabled
if relayFormat != "" {
if expectedAPIType, ok := types.RelayFormatToAPIType(relayFormat); ok {
channelQuery, err = getChannelQueryWithAPIType(group, model, retry, expectedAPIType)
if err != nil {
return nil, err
}
}
}

if channelQuery == nil {
channelQuery, err = getChannelQuery(group, model, retry)
if err != nil {
return nil, err
}
}

if common.UsingSQLite || common.UsingPostgreSQL {
err = channelQuery.Order("weight DESC").Find(&abilities).Error
err = channelQuery.Order("abilities.weight DESC").Find(&abilities).Error
} else {
err = channelQuery.Order("weight DESC").Find(&abilities).Error
err = channelQuery.Order("abilities.weight DESC").Find(&abilities).Error
}
if err != nil {
return nil, err
Expand Down
33 changes: 31 additions & 2 deletions model/channel_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/setting/ratio_setting"
"github.com/QuantumNous/new-api/types"
)

var group2model2channels map[string]map[string][]int // enabled channel
Expand Down Expand Up @@ -94,10 +95,10 @@ func SyncChannelCache(frequency int) {
}
}

func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
func GetRandomSatisfiedChannel(group string, model string, retry int, relayFormat types.RelayFormat) (*Channel, error) {
// if memory cache is disabled, get channel directly from database
if !common.MemoryCacheEnabled {
return GetChannel(group, model, retry)
return GetChannel(group, model, retry, relayFormat)
}

channelSyncLock.RLock()
Expand Down Expand Up @@ -160,6 +161,34 @@ func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
return nil, errors.New(fmt.Sprintf("no channel found, group: %s, model: %s, priority: %d", group, model, targetPriority))
}

// Smart routing: when multiple channels share the same priority,
// prefer channels whose native API type matches the client request format.
// This avoids unnecessary request/response conversion.
if len(targetChannels) > 1 && relayFormat != "" {
if expectedAPIType, ok := types.RelayFormatToAPIType(relayFormat); ok {
var preferredChannels []*Channel
var fallbackChannels []*Channel
for _, ch := range targetChannels {
channelAPIType, typeOk := common.ChannelType2APIType(ch.Type)
if typeOk && channelAPIType == expectedAPIType {
preferredChannels = append(preferredChannels, ch)
} else {
fallbackChannels = append(fallbackChannels, ch)
}
}
// Only use preferred channels if at least one matches;
// otherwise fall back to the original set.
if len(preferredChannels) > 0 {
targetChannels = preferredChannels
// Recalculate sumWeight for the filtered set
sumWeight = 0
for _, ch := range targetChannels {
sumWeight += ch.GetWeight()
}
}
}
}

// smoothing factor and adjustment
smoothingFactor := 1
smoothingAdjustment := 0
Expand Down
6 changes: 4 additions & 2 deletions service/channel_select.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@ import (
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
)

type RetryParam struct {
Ctx *gin.Context
TokenGroup string
ModelName string
RelayFormat types.RelayFormat // client request API format for smart channel routing
Retry *int
resetNextTry bool
}
Expand Down Expand Up @@ -115,7 +117,7 @@ func CacheGetRandomSatisfiedChannel(param *RetryParam) (*model.Channel, string,
}
logger.LogDebug(param.Ctx, "Auto selecting group: %s, priorityRetry: %d", autoGroup, priorityRetry)

channel, _ = model.GetRandomSatisfiedChannel(autoGroup, param.ModelName, priorityRetry)
channel, _ = model.GetRandomSatisfiedChannel(autoGroup, param.ModelName, priorityRetry, param.RelayFormat)
if channel == nil {
// Current group has no available channel for this model, try next group
// 当前分组没有该模型的可用渠道,尝试下一个分组
Expand Down Expand Up @@ -153,7 +155,7 @@ func CacheGetRandomSatisfiedChannel(param *RetryParam) (*model.Channel, string,
break
}
} else {
channel, err = model.GetRandomSatisfiedChannel(param.TokenGroup, param.ModelName, param.GetRetry())
channel, err = model.GetRandomSatisfiedChannel(param.TokenGroup, param.ModelName, param.GetRetry(), param.RelayFormat)
if err != nil {
return nil, param.TokenGroup, err
}
Expand Down
59 changes: 59 additions & 0 deletions types/relay_format.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package types

import "strings"

type RelayFormat string

const (
Expand All @@ -17,3 +19,60 @@ const (
RelayFormatTask = "task"
RelayFormatMjProxy = "mj_proxy"
)

// RelayFormatToAPIType maps the client request relay format to the expected provider API type.
// This is used for smart channel routing: prefer channels whose native API type matches
// the client's request format, avoiding unnecessary request/response conversion.
func RelayFormatToAPIType(relayFormat RelayFormat) (int, bool) {
switch relayFormat {
case RelayFormatOpenAI, RelayFormatOpenAIAudio, RelayFormatOpenAIImage, RelayFormatOpenAIRealtime, RelayFormatOpenAIResponses, RelayFormatOpenAIResponsesCompaction:
return 0, true // APITypeOpenAI
case RelayFormatClaude:
return 1, true // APITypeAnthropic
case RelayFormatGemini:
return 13, true // APITypeGemini
case RelayFormatEmbedding:
// Embedding requests can be handled by multiple provider types;
// return OpenAI as the most common format, but let the caller
// decide whether to enforce strict matching.
return 0, true // APITypeOpenAI
case RelayFormatRerank:
return 0, true // APITypeOpenAI
default:
return 0, false
}
}

// InferRelayFormatFromPath returns the RelayFormat that the request to the given URL path will
// eventually be relayed as. The middleware Distribute() runs before the per-route handler that
// sets the format explicitly, so smart channel routing has to peek at the path here.
//
// Keep this in sync with router/relay-router.go. Unknown paths return an empty RelayFormat,
// which downstream callers (e.g. model.GetChannel) treat as "no API-type hint" and fall back
// to the original priority/weight-based selection.
func InferRelayFormatFromPath(path string) RelayFormat {
switch {
case strings.HasPrefix(path, "/v1/messages"):
return RelayFormatClaude
case strings.HasPrefix(path, "/v1/responses/compact"):
return RelayFormatOpenAIResponsesCompaction
case strings.HasPrefix(path, "/v1/responses"):
return RelayFormatOpenAIResponses
case strings.HasPrefix(path, "/v1/realtime"):
return RelayFormatOpenAIRealtime
case strings.HasPrefix(path, "/v1/embeddings"):
return RelayFormatEmbedding
case strings.HasPrefix(path, "/v1/rerank"):
return RelayFormatRerank
case strings.HasPrefix(path, "/v1/audio/"):
return RelayFormatOpenAIAudio
case strings.HasPrefix(path, "/v1/images/"), strings.HasPrefix(path, "/v1/edits"):
return RelayFormatOpenAIImage
case strings.HasPrefix(path, "/v1/engines/") && strings.HasSuffix(path, "/embeddings"):
return RelayFormatEmbedding
case strings.HasPrefix(path, "/v1beta/models/"), strings.HasPrefix(path, "/v1/models/"):
return RelayFormatGemini
default:
return ""
}
}