Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 77 additions & 4 deletions controller/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package controller
import (
"fmt"
"net/http"
"strings"
"time"

"github.com/QuantumNous/new-api/common"
Expand Down Expand Up @@ -220,10 +221,7 @@ func ListModels(c *gin.Context, modelType int) {
case constant.ChannelTypeGemini:
userGeminiModels := make([]dto.GeminiModel, len(userOpenAiModels))
for i, model := range userOpenAiModels {
userGeminiModels[i] = dto.GeminiModel{
Name: model.Id,
DisplayName: model.Id,
}
userGeminiModels[i] = buildGeminiModel(model)
}
c.JSON(200, gin.H{
"models": userGeminiModels,
Expand All @@ -238,6 +236,81 @@ func ListModels(c *gin.Context, modelType int) {
}
}

func buildGeminiModel(openAIModel dto.OpenAIModels) dto.GeminiModel {
normalizedModelID := normalizeGeminiModelID(openAIModel.Id)
return dto.GeminiModel{
Name: fmt.Sprintf("models/%s", normalizedModelID),
DisplayName: normalizedModelID,
SupportedGenerationMethods: getGeminiSupportedGenerationMethods(openAIModel),
}
}

func getGeminiSupportedGenerationMethods(openAIModel dto.OpenAIModels) []string {
if methods := getGeminiSupportedGenerationMethodsFromEndpoints(openAIModel.SupportedEndpointTypes); len(methods) > 0 {
return methods
}
if methods := getGeminiSupportedGenerationMethodsFromModelID(openAIModel.Id); len(methods) > 0 {
return methods
}

return []string{"generateContent"}
}

func getGeminiSupportedGenerationMethodsFromEndpoints(endpointTypes []constant.EndpointType) []string {
methods := make([]string, 0, 2)
for _, endpointType := range endpointTypes {
switch endpointType {
case constant.EndpointTypeImageGeneration:
methods = appendUniqueGeminiMethod(methods, "predict")
case constant.EndpointTypeEmbeddings:
methods = appendUniqueGeminiMethod(methods, "embedContent")
methods = appendUniqueGeminiMethod(methods, "batchEmbedContents")
}
}
return methods
}

func getGeminiSupportedGenerationMethodsFromModelID(modelID string) []string {
normalizedModelID := normalizeGeminiModelID(modelID)
switch {
case isGeminiVideoModel(normalizedModelID):
return []string{"predictLongRunning"}
case isGeminiEmbeddingModel(normalizedModelID):
return []string{"embedContent", "batchEmbedContents"}
case isGeminiPredictModel(normalizedModelID):
return []string{"predict"}
}

return nil
}

func normalizeGeminiModelID(modelID string) string {
return strings.ToLower(strings.TrimPrefix(modelID, "models/"))
}

func isGeminiEmbeddingModel(modelID string) bool {
return strings.HasPrefix(modelID, "text-embedding") ||
strings.HasPrefix(modelID, "embedding") ||
strings.HasPrefix(modelID, "gemini-embedding")
}

func isGeminiPredictModel(modelID string) bool {
return strings.HasPrefix(modelID, "imagen")
}

func isGeminiVideoModel(modelID string) bool {
return strings.HasPrefix(modelID, "veo-")
}

func appendUniqueGeminiMethod(methods []string, method string) []string {
for _, existing := range methods {
if existing == method {
return methods
}
}
return append(methods, method)
}

func ChannelListModels(c *gin.Context) {
c.JSON(200, gin.H{
"success": true,
Expand Down
105 changes: 105 additions & 0 deletions controller/model_list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ type listModelsResponse struct {
Object string `json:"object"`
}

type geminiListModelsResponse struct {
Models []dto.GeminiModel `json:"models"`
NextPageToken interface{} `json:"nextPageToken"`
}

func setupModelListControllerTestDB(t *testing.T) *gorm.DB {
t.Helper()

Expand Down Expand Up @@ -146,6 +151,65 @@ func decodeListModelsResponse(t *testing.T, recorder *httptest.ResponseRecorder)
return ids
}

func decodeGeminiListModelsResponse(t *testing.T, recorder *httptest.ResponseRecorder) map[string]dto.GeminiModel {
t.Helper()

require.Equal(t, http.StatusOK, recorder.Code)
var payload geminiListModelsResponse
require.NoError(t, common.Unmarshal(recorder.Body.Bytes(), &payload))

modelsByName := make(map[string]dto.GeminiModel, len(payload.Models))
for _, item := range payload.Models {
name, ok := item.Name.(string)
require.True(t, ok)
modelsByName[name] = item
}
return modelsByName
}

func geminiMethods(model dto.GeminiModel) []string {
return append([]string(nil), model.SupportedGenerationMethods...)
}

func TestBuildGeminiModelNormalizesPrefixedID(t *testing.T) {
geminiModel := buildGeminiModel(dto.OpenAIModels{Id: "models/gemini-2.5-flash"})

require.Equal(t, "models/gemini-2.5-flash", geminiModel.Name)
require.Equal(t, "gemini-2.5-flash", geminiModel.DisplayName)
require.ElementsMatch(t, []string{"generateContent"}, geminiModel.SupportedGenerationMethods)
}

func TestGetGeminiSupportedGenerationMethodsUsesEndpointMetadata(t *testing.T) {
require.ElementsMatch(t,
[]string{"predict"},
getGeminiSupportedGenerationMethods(dto.OpenAIModels{
Id: "custom-image-model",
SupportedEndpointTypes: []constant.EndpointType{constant.EndpointTypeImageGeneration, constant.EndpointTypeGemini},
}),
)
require.ElementsMatch(t,
[]string{"embedContent", "batchEmbedContents"},
getGeminiSupportedGenerationMethods(dto.OpenAIModels{
Id: "custom-embedding-model",
SupportedEndpointTypes: []constant.EndpointType{constant.EndpointTypeEmbeddings, constant.EndpointTypeGemini},
}),
)
require.ElementsMatch(t,
[]string{"embedContent", "batchEmbedContents"},
getGeminiSupportedGenerationMethods(dto.OpenAIModels{
Id: "gemini-embedding-001",
SupportedEndpointTypes: []constant.EndpointType{constant.EndpointTypeGemini, constant.EndpointTypeOpenAI},
}),
)
require.ElementsMatch(t,
[]string{"generateContent"},
getGeminiSupportedGenerationMethods(dto.OpenAIModels{
Id: "custom-text-model",
SupportedEndpointTypes: []constant.EndpointType{constant.EndpointTypeGemini},
}),
)
}

func pricingByModelName(pricings []model.Pricing) map[string]model.Pricing {
byName := make(map[string]model.Pricing, len(pricings))
for _, pricing := range pricings {
Expand Down Expand Up @@ -240,3 +304,44 @@ func TestListModelsTokenLimitIncludesTieredBillingModel(t *testing.T) {
require.NotContains(t, ids, "zz-token-tiered-missing-expr-model")
require.NotContains(t, ids, "zz-token-unpriced-model")
}

func TestListModelsGeminiIncludesSupportedGenerationMethods(t *testing.T) {
originalSelfUseMode := operation_setting.SelfUseModeEnabled
operation_setting.SelfUseModeEnabled = true
t.Cleanup(func() {
operation_setting.SelfUseModeEnabled = originalSelfUseMode
model.InvalidatePricingCache()
})

db := setupModelListControllerTestDB(t)
require.NoError(t, db.Create(&model.User{
Id: 1002,
Username: "gemini-model-list-user",
Password: "password",
Group: "default",
Status: common.UserStatusEnabled,
}).Error)
require.NoError(t, db.Create(&[]model.Channel{
{Id: 1, Type: constant.ChannelTypeGemini, Status: common.ChannelStatusEnabled, Models: "gemini-2.5-flash,gemini-embedding-001,imagen-4.0-generate-001", Group: "default"},
}).Error)
require.NoError(t, db.Create(&[]model.Ability{
{Group: "default", Model: "gemini-2.5-flash", ChannelId: 1, Enabled: true},
{Group: "default", Model: "gemini-embedding-001", ChannelId: 1, Enabled: true},
{Group: "default", Model: "imagen-4.0-generate-001", ChannelId: 1, Enabled: true},
}).Error)
model.InvalidatePricingCache()
model.GetPricing()

recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest(http.MethodGet, "/v1beta/models", nil)
ctx.Set("id", 1002)

ListModels(ctx, constant.ChannelTypeGemini)

models := decodeGeminiListModelsResponse(t, recorder)

require.ElementsMatch(t, []string{"generateContent"}, geminiMethods(models["models/gemini-2.5-flash"]))
require.ElementsMatch(t, []string{"embedContent", "batchEmbedContents"}, geminiMethods(models["models/gemini-embedding-001"]))
require.ElementsMatch(t, []string{"predict"}, geminiMethods(models["models/imagen-4.0-generate-001"]))
}
26 changes: 13 additions & 13 deletions dto/pricing.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,17 @@ type AnthropicModel struct {
}

type GeminiModel struct {
Name interface{} `json:"name"`
BaseModelId interface{} `json:"baseModelId"`
Version interface{} `json:"version"`
DisplayName interface{} `json:"displayName"`
Description interface{} `json:"description"`
InputTokenLimit interface{} `json:"inputTokenLimit"`
OutputTokenLimit interface{} `json:"outputTokenLimit"`
SupportedGenerationMethods []interface{} `json:"supportedGenerationMethods"`
Thinking interface{} `json:"thinking"`
Temperature interface{} `json:"temperature"`
MaxTemperature interface{} `json:"maxTemperature"`
TopP interface{} `json:"topP"`
TopK interface{} `json:"topK"`
Name interface{} `json:"name"`
BaseModelId interface{} `json:"baseModelId"`
Version interface{} `json:"version"`
DisplayName interface{} `json:"displayName"`
Description interface{} `json:"description"`
InputTokenLimit interface{} `json:"inputTokenLimit"`
OutputTokenLimit interface{} `json:"outputTokenLimit"`
SupportedGenerationMethods []string `json:"supportedGenerationMethods"`
Thinking interface{} `json:"thinking"`
Temperature interface{} `json:"temperature"`
MaxTemperature interface{} `json:"maxTemperature"`
TopP interface{} `json:"topP"`
TopK interface{} `json:"topK"`
}