Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ docs/*
!docs/ADMIN_PAYMENT_INTEGRATION_API.md
.serena/
.codex/
.worktrees/
frontend/coverage/
aicodex
output/
24 changes: 24 additions & 0 deletions backend/internal/handler/admin/account_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package admin

import (
"context"
"encoding/json"
"errors"
"fmt"
"strconv"
Expand Down Expand Up @@ -66,6 +67,29 @@ type DataImportRequest struct {
SkipDefaultGroupBind *bool `json:"skip_default_group_bind"`
}

func (r *DataImportRequest) UnmarshalJSON(data []byte) error {
var wrapped struct {
Data DataPayload `json:"data"`
SkipDefaultGroupBind *bool `json:"skip_default_group_bind"`
}
if err := json.Unmarshal(data, &wrapped); err != nil {
return err
}
if wrapped.Data.Accounts != nil || wrapped.Data.Proxies != nil || wrapped.Data.ExportedAt != "" || wrapped.Data.Type != "" || wrapped.Data.Version != 0 {
r.Data = wrapped.Data
r.SkipDefaultGroupBind = wrapped.SkipDefaultGroupBind
return nil
}

var payload DataPayload
if err := json.Unmarshal(data, &payload); err != nil {
return err
}
r.Data = payload
r.SkipDefaultGroupBind = wrapped.SkipDefaultGroupBind
return nil
}

type DataImportResult struct {
ProxyCreated int `json:"proxy_created"`
ProxyReused int `json:"proxy_reused"`
Expand Down
73 changes: 73 additions & 0 deletions backend/internal/handler/admin/account_data_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"testing"

"github.com/Wei-Shaw/sub2api/internal/service"
Expand Down Expand Up @@ -275,3 +276,75 @@ func TestImportDataReusesProxyAndSkipsDefaultGroup(t *testing.T) {
require.Len(t, adminSvc.createdAccounts, 1)
require.True(t, adminSvc.createdAccounts[0].SkipDefaultGroupBind)
}

func TestImportDataAcceptsTopLevelExportPayload(t *testing.T) {
router, adminSvc := setupAccountDataRouter()

dataPayload := map[string]any{
"exported_at": "2026-05-29T10:09:22Z",
"proxies": []map[string]any{},
"accounts": []map[string]any{
{
"name": "agora1",
"platform": service.PlatformGemini,
"type": service.AccountTypeOAuth,
"concurrency": 0,
"priority": 0,
"auto_pause_on_expired": true,
"credentials": map[string]any{
"_token_version": 2,
"access_token": "access-token",
"refresh_token": "refresh-token",
"token_type": "Bearer",
"expires_at": float64(1790000000),
"oauth_type": "gemini",
"scope": "https://www.googleapis.com/auth/cloud-platform",
"tier_id": "free-tier",
"project_id": "gemini-project",
},
},
},
}

body, _ := json.Marshal(dataPayload)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/data", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)

require.Len(t, adminSvc.createdAccounts, 1)
created := adminSvc.createdAccounts[0]
require.Equal(t, "agora1", created.Name)
require.Equal(t, service.PlatformGemini, created.Platform)
require.Equal(t, service.AccountTypeOAuth, created.Type)
require.True(t, created.SkipDefaultGroupBind)
require.Equal(t, "refresh-token", created.Credentials["refresh_token"])
require.Equal(t, "gemini-project", created.Credentials["project_id"])
}

func TestImportDataAcceptsLocalGeminiExportFixture(t *testing.T) {
fixturePath := os.Getenv("SUB2API_ACCOUNT_IMPORT_FIXTURE")
if fixturePath == "" {
t.Skip("SUB2API_ACCOUNT_IMPORT_FIXTURE is not set")
}

body, err := os.ReadFile(fixturePath)
require.NoError(t, err)

router, adminSvc := setupAccountDataRouter()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/data", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)

require.Len(t, adminSvc.createdAccounts, 10)
for _, created := range adminSvc.createdAccounts {
require.Equal(t, service.PlatformGemini, created.Platform)
require.Equal(t, service.AccountTypeOAuth, created.Type)
require.True(t, created.SkipDefaultGroupBind)
require.NotEmpty(t, created.Credentials["refresh_token"])
require.NotEmpty(t, created.Credentials["project_id"])
}
}
7 changes: 5 additions & 2 deletions backend/internal/handler/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ const (
func NormalizeInboundEndpoint(path string) string {
path = strings.TrimSpace(path)
switch {
case strings.Contains(path, EndpointEmbeddings):
case strings.Contains(path, EndpointEmbeddings) || strings.Contains(path, "/embeddings"):
return EndpointEmbeddings
case strings.Contains(path, EndpointChatCompletions):
case strings.Contains(path, EndpointChatCompletions) || strings.Contains(path, "/chat/completions"):
return EndpointChatCompletions
case strings.Contains(path, EndpointMessages):
return EndpointMessages
Expand Down Expand Up @@ -92,6 +92,9 @@ func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string {
return EndpointMessages

case service.PlatformGemini:
if inbound == EndpointEmbeddings || inbound == EndpointImagesGenerations || inbound == EndpointImagesEdits {
return inbound
}
return EndpointGeminiModels

case service.PlatformAntigravity:
Expand Down
5 changes: 5 additions & 0 deletions backend/internal/handler/endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ func TestNormalizeInboundEndpoint(t *testing.T) {
{"/v1/images/generations", EndpointImagesGenerations},
{"/v1/images/edits", EndpointImagesEdits},
{"/v1beta/models", EndpointGeminiModels},
{"/v1beta/openai/chat/completions", EndpointChatCompletions},
{"/v1beta/openai/embeddings", EndpointEmbeddings},
{"/v1beta/openai/images/generations", EndpointImagesGenerations},

// Prefixed paths (antigravity, openai).
{"/antigravity/v1/messages", EndpointMessages},
Expand Down Expand Up @@ -71,6 +74,8 @@ func TestDeriveUpstreamEndpoint(t *testing.T) {

// Gemini.
{"gemini models", EndpointGeminiModels, "/v1beta/models/gemini:gen", service.PlatformGemini, EndpointGeminiModels},
{"gemini openai embeddings", EndpointEmbeddings, "/v1beta/openai/embeddings", service.PlatformGemini, EndpointEmbeddings},
{"gemini openai image generations", EndpointImagesGenerations, "/v1beta/openai/images/generations", service.PlatformGemini, EndpointImagesGenerations},

// OpenAI — always /v1/responses.
{"openai responses root", EndpointResponses, "/v1/responses", service.PlatformOpenAI, EndpointResponses},
Expand Down
Loading
Loading