diff --git a/controller/model.go b/controller/model.go index 4dbd45838dd..93d88695f37 100644 --- a/controller/model.go +++ b/controller/model.go @@ -3,6 +3,7 @@ package controller import ( "fmt" "net/http" + "strings" "time" "github.com/QuantumNous/new-api/common" @@ -220,10 +221,7 @@ func ListModels(c *gin.Context, modelType int) { case constant.ChannelTypeGemini: userGeminiModels := make([]dto.GeminiModel, len(userOpenAiModels)) for i, model := range userOpenAiModels { - userGeminiModels[i] = dto.GeminiModel{ - Name: model.Id, - DisplayName: model.Id, - } + userGeminiModels[i] = buildGeminiModel(model) } c.JSON(200, gin.H{ "models": userGeminiModels, @@ -238,6 +236,81 @@ func ListModels(c *gin.Context, modelType int) { } } +func buildGeminiModel(openAIModel dto.OpenAIModels) dto.GeminiModel { + normalizedModelID := normalizeGeminiModelID(openAIModel.Id) + return dto.GeminiModel{ + Name: fmt.Sprintf("models/%s", normalizedModelID), + DisplayName: normalizedModelID, + SupportedGenerationMethods: getGeminiSupportedGenerationMethods(openAIModel), + } +} + +func getGeminiSupportedGenerationMethods(openAIModel dto.OpenAIModels) []string { + if methods := getGeminiSupportedGenerationMethodsFromEndpoints(openAIModel.SupportedEndpointTypes); len(methods) > 0 { + return methods + } + if methods := getGeminiSupportedGenerationMethodsFromModelID(openAIModel.Id); len(methods) > 0 { + return methods + } + + return []string{"generateContent"} +} + +func getGeminiSupportedGenerationMethodsFromEndpoints(endpointTypes []constant.EndpointType) []string { + methods := make([]string, 0, 2) + for _, endpointType := range endpointTypes { + switch endpointType { + case constant.EndpointTypeImageGeneration: + methods = appendUniqueGeminiMethod(methods, "predict") + case constant.EndpointTypeEmbeddings: + methods = appendUniqueGeminiMethod(methods, "embedContent") + methods = appendUniqueGeminiMethod(methods, "batchEmbedContents") + } + } + return methods +} + +func getGeminiSupportedGenerationMethodsFromModelID(modelID string) []string { + normalizedModelID := normalizeGeminiModelID(modelID) + switch { + case isGeminiVideoModel(normalizedModelID): + return []string{"predictLongRunning"} + case isGeminiEmbeddingModel(normalizedModelID): + return []string{"embedContent", "batchEmbedContents"} + case isGeminiPredictModel(normalizedModelID): + return []string{"predict"} + } + + return nil +} + +func normalizeGeminiModelID(modelID string) string { + return strings.ToLower(strings.TrimPrefix(modelID, "models/")) +} + +func isGeminiEmbeddingModel(modelID string) bool { + return strings.HasPrefix(modelID, "text-embedding") || + strings.HasPrefix(modelID, "embedding") || + strings.HasPrefix(modelID, "gemini-embedding") +} + +func isGeminiPredictModel(modelID string) bool { + return strings.HasPrefix(modelID, "imagen") +} + +func isGeminiVideoModel(modelID string) bool { + return strings.HasPrefix(modelID, "veo-") +} + +func appendUniqueGeminiMethod(methods []string, method string) []string { + for _, existing := range methods { + if existing == method { + return methods + } + } + return append(methods, method) +} + func ChannelListModels(c *gin.Context) { c.JSON(200, gin.H{ "success": true, diff --git a/controller/model_list_test.go b/controller/model_list_test.go index 97d27cae5c6..70ca0eb3992 100644 --- a/controller/model_list_test.go +++ b/controller/model_list_test.go @@ -26,6 +26,11 @@ type listModelsResponse struct { Object string `json:"object"` } +type geminiListModelsResponse struct { + Models []dto.GeminiModel `json:"models"` + NextPageToken interface{} `json:"nextPageToken"` +} + func setupModelListControllerTestDB(t *testing.T) *gorm.DB { t.Helper() @@ -146,6 +151,65 @@ func decodeListModelsResponse(t *testing.T, recorder *httptest.ResponseRecorder) return ids } +func decodeGeminiListModelsResponse(t *testing.T, recorder *httptest.ResponseRecorder) map[string]dto.GeminiModel { + t.Helper() + + require.Equal(t, http.StatusOK, recorder.Code) + var payload geminiListModelsResponse + require.NoError(t, common.Unmarshal(recorder.Body.Bytes(), &payload)) + + modelsByName := make(map[string]dto.GeminiModel, len(payload.Models)) + for _, item := range payload.Models { + name, ok := item.Name.(string) + require.True(t, ok) + modelsByName[name] = item + } + return modelsByName +} + +func geminiMethods(model dto.GeminiModel) []string { + return append([]string(nil), model.SupportedGenerationMethods...) +} + +func TestBuildGeminiModelNormalizesPrefixedID(t *testing.T) { + geminiModel := buildGeminiModel(dto.OpenAIModels{Id: "models/gemini-2.5-flash"}) + + require.Equal(t, "models/gemini-2.5-flash", geminiModel.Name) + require.Equal(t, "gemini-2.5-flash", geminiModel.DisplayName) + require.ElementsMatch(t, []string{"generateContent"}, geminiModel.SupportedGenerationMethods) +} + +func TestGetGeminiSupportedGenerationMethodsUsesEndpointMetadata(t *testing.T) { + require.ElementsMatch(t, + []string{"predict"}, + getGeminiSupportedGenerationMethods(dto.OpenAIModels{ + Id: "custom-image-model", + SupportedEndpointTypes: []constant.EndpointType{constant.EndpointTypeImageGeneration, constant.EndpointTypeGemini}, + }), + ) + require.ElementsMatch(t, + []string{"embedContent", "batchEmbedContents"}, + getGeminiSupportedGenerationMethods(dto.OpenAIModels{ + Id: "custom-embedding-model", + SupportedEndpointTypes: []constant.EndpointType{constant.EndpointTypeEmbeddings, constant.EndpointTypeGemini}, + }), + ) + require.ElementsMatch(t, + []string{"embedContent", "batchEmbedContents"}, + getGeminiSupportedGenerationMethods(dto.OpenAIModels{ + Id: "gemini-embedding-001", + SupportedEndpointTypes: []constant.EndpointType{constant.EndpointTypeGemini, constant.EndpointTypeOpenAI}, + }), + ) + require.ElementsMatch(t, + []string{"generateContent"}, + getGeminiSupportedGenerationMethods(dto.OpenAIModels{ + Id: "custom-text-model", + SupportedEndpointTypes: []constant.EndpointType{constant.EndpointTypeGemini}, + }), + ) +} + func pricingByModelName(pricings []model.Pricing) map[string]model.Pricing { byName := make(map[string]model.Pricing, len(pricings)) for _, pricing := range pricings { @@ -240,3 +304,44 @@ func TestListModelsTokenLimitIncludesTieredBillingModel(t *testing.T) { require.NotContains(t, ids, "zz-token-tiered-missing-expr-model") require.NotContains(t, ids, "zz-token-unpriced-model") } + +func TestListModelsGeminiIncludesSupportedGenerationMethods(t *testing.T) { + originalSelfUseMode := operation_setting.SelfUseModeEnabled + operation_setting.SelfUseModeEnabled = true + t.Cleanup(func() { + operation_setting.SelfUseModeEnabled = originalSelfUseMode + model.InvalidatePricingCache() + }) + + db := setupModelListControllerTestDB(t) + require.NoError(t, db.Create(&model.User{ + Id: 1002, + Username: "gemini-model-list-user", + Password: "password", + Group: "default", + Status: common.UserStatusEnabled, + }).Error) + require.NoError(t, db.Create(&[]model.Channel{ + {Id: 1, Type: constant.ChannelTypeGemini, Status: common.ChannelStatusEnabled, Models: "gemini-2.5-flash,gemini-embedding-001,imagen-4.0-generate-001", Group: "default"}, + }).Error) + require.NoError(t, db.Create(&[]model.Ability{ + {Group: "default", Model: "gemini-2.5-flash", ChannelId: 1, Enabled: true}, + {Group: "default", Model: "gemini-embedding-001", ChannelId: 1, Enabled: true}, + {Group: "default", Model: "imagen-4.0-generate-001", ChannelId: 1, Enabled: true}, + }).Error) + model.InvalidatePricingCache() + model.GetPricing() + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest(http.MethodGet, "/v1beta/models", nil) + ctx.Set("id", 1002) + + ListModels(ctx, constant.ChannelTypeGemini) + + models := decodeGeminiListModelsResponse(t, recorder) + + require.ElementsMatch(t, []string{"generateContent"}, geminiMethods(models["models/gemini-2.5-flash"])) + require.ElementsMatch(t, []string{"embedContent", "batchEmbedContents"}, geminiMethods(models["models/gemini-embedding-001"])) + require.ElementsMatch(t, []string{"predict"}, geminiMethods(models["models/imagen-4.0-generate-001"])) +} diff --git a/dto/pricing.go b/dto/pricing.go index 1ed8dcd31c2..db8de93ed17 100644 --- a/dto/pricing.go +++ b/dto/pricing.go @@ -19,17 +19,17 @@ type AnthropicModel struct { } type GeminiModel struct { - Name interface{} `json:"name"` - BaseModelId interface{} `json:"baseModelId"` - Version interface{} `json:"version"` - DisplayName interface{} `json:"displayName"` - Description interface{} `json:"description"` - InputTokenLimit interface{} `json:"inputTokenLimit"` - OutputTokenLimit interface{} `json:"outputTokenLimit"` - SupportedGenerationMethods []interface{} `json:"supportedGenerationMethods"` - Thinking interface{} `json:"thinking"` - Temperature interface{} `json:"temperature"` - MaxTemperature interface{} `json:"maxTemperature"` - TopP interface{} `json:"topP"` - TopK interface{} `json:"topK"` + Name interface{} `json:"name"` + BaseModelId interface{} `json:"baseModelId"` + Version interface{} `json:"version"` + DisplayName interface{} `json:"displayName"` + Description interface{} `json:"description"` + InputTokenLimit interface{} `json:"inputTokenLimit"` + OutputTokenLimit interface{} `json:"outputTokenLimit"` + SupportedGenerationMethods []string `json:"supportedGenerationMethods"` + Thinking interface{} `json:"thinking"` + Temperature interface{} `json:"temperature"` + MaxTemperature interface{} `json:"maxTemperature"` + TopP interface{} `json:"topP"` + TopK interface{} `json:"topK"` }