Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
NODE_NAME=node-001
CONTROL_ADDR=https://127.0.0.1:8888
NODE_PORT=8080
TLS_CERT=certs/client.crt
TLS_KEY=certs/client.key
TLS_CA=certs/ca.crt
44 changes: 44 additions & 0 deletions backend/cmd/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,14 @@ import (
_ "github.com/Wei-Shaw/sub2api/ent/runtime"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
// xiugai 添加节点功能
"github.com/Wei-Shaw/sub2api/internal/node"
// end
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
// xiugai 添加节点功能
"github.com/Wei-Shaw/sub2api/internal/service"
// end
"github.com/Wei-Shaw/sub2api/internal/setup"
"github.com/Wei-Shaw/sub2api/internal/web"

Expand Down Expand Up @@ -154,6 +160,13 @@ func runMainServer() {
}
defer app.Cleanup()

// xiugai 添加节点功能
// 启动节点上报服务(可选,仅当当前目录存在 .env 时生效)
reporterCtx, reporterCancel := context.WithCancel(context.Background())
defer reporterCancel()
startNodeReporter(reporterCtx, app.AccountRepo)
// end

// 启动服务器
go func() {
if err := app.Server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
Expand All @@ -179,3 +192,34 @@ func runMainServer() {

log.Println("Server exited")
}

// xiugai 添加节点功能

// startNodeReporter 尝试从当前工作目录加载 .env 配置并以后台 goroutine 方式
// 启动节点上报服务。若 .env 不存在或配置有误,仅打印日志后静默返回,不影响主服务启动。
//
// 参数:
// - ctx:上下文,取消时节点上报服务随之停止。
// - repo:账号数据库访问层,用于读取本地全量账号数据并上报给控制服务器。
func startNodeReporter(ctx context.Context, repo service.AccountRepository) {
cfg, err := node.LoadConfig(".env")
if err != nil {
// .env 不存在属于正常情况(未启用节点上报功能),其他错误才打印警告。
if !os.IsNotExist(err) {
log.Printf("[Node] Config error, reporter disabled: %v", err)
}
return
}

lister := node.NewRepoAccountLister(repo)
reporter, err := node.NewReporter(cfg, lister)
if err != nil {
log.Printf("[Node] Init error, reporter disabled: %v", err)
return
}

go reporter.Start(ctx)
log.Printf("[Node] Reporter started (node=%s, control=%s)", cfg.NodeName, cfg.ControlAddr)
}

// end
8 changes: 6 additions & 2 deletions backend/cmd/server/wire.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ import (
)

type Application struct {
Server *http.Server
Server *http.Server
// xiugai 添加节点功能
// AccountRepo 账号数据库访问层,供节点上报服务读取本地全量账号数据。
AccountRepo service.AccountRepository
// end
Cleanup func()
}

Expand Down Expand Up @@ -53,7 +57,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
provideCleanup,

// Application struct
wire.Struct(new(Application), "Server", "Cleanup"),
wire.Struct(new(Application), "Server", "AccountRepo", "Cleanup"),
)
return nil, nil
}
Expand Down
10 changes: 6 additions & 4 deletions backend/cmd/server/wire_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions backend/internal/domain/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ const (
AccountTypeAPIKey = "apikey" // API Key类型账号
AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
AccountTypeAnthropicAWS = "anthropic_aws" // Claude Platform on AWS 类型账号(Anthropic 在 AWS 上托管的外部 API,aws-external-anthropic.{region}.api.aws)
AccountTypeServiceAccount = "service_account" // Google Service Account 类型账号(用于 Vertex AI)
)

Expand Down
54 changes: 53 additions & 1 deletion backend/internal/handler/admin/account_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ type DataImportResult struct {
ProxyReused int `json:"proxy_reused"`
ProxyFailed int `json:"proxy_failed"`
AccountCreated int `json:"account_created"`
AccountSkipped int `json:"account_skipped"`
AccountFailed int `json:"account_failed"`
Errors []DataImportError `json:"errors,omitempty"`
}
Expand Down Expand Up @@ -274,6 +275,17 @@ func (h *AccountHandler) importData(ctx context.Context, req DataImportRequest)
// 收集需要异步设置隐私的 Antigravity OAuth 账号
var privacyAccounts []*service.Account

// 预加载该批次涉及的所有平台账号,构建身份索引,用于跨已存在账号去重。
importPlatforms := make([]string, 0, len(dataPayload.Accounts))
for _, item := range dataPayload.Accounts {
importPlatforms = append(importPlatforms, item.Platform)
}
accountIndex, indexErr := h.loadAccountIdentityIndex(ctx, importPlatforms)
if indexErr != nil {
return result, indexErr
}
seenAccountIdentity := make(map[string]int, len(dataPayload.Accounts))

for i := range dataPayload.Accounts {
item := dataPayload.Accounts[i]
if err := validateDataAccount(item); err != nil {
Expand Down Expand Up @@ -304,6 +316,36 @@ func (h *AccountHandler) importData(ctx context.Context, req DataImportRequest)

enrichCredentialsFromIDToken(&item)

// 去重:批内已见 → 跳过;DB 已存在 → 跳过。
identityKeys := buildAccountIdentityKeys(accountIdentityInput{
Platform: item.Platform,
Type: item.Type,
Credentials: item.Credentials,
Extra: item.Extra,
})
if len(identityKeys) > 0 {
if dupIndex, ok := firstSeenIdentity(seenAccountIdentity, identityKeys); ok {
result.AccountSkipped++
result.Errors = append(result.Errors, DataImportError{
Kind: "account_skipped",
Name: item.Name,
Message: fmt.Sprintf("与本次导入第 %d 条重复,已跳过", dupIndex),
})
continue
}
if existing := accountIndex.Find(identityKeys); existing != nil {
result.AccountSkipped++
result.Errors = append(result.Errors, DataImportError{
Kind: "account_skipped",
Name: item.Name,
Message: fmt.Sprintf("账号已存在(id=%d, name=%s),已跳过", existing.ID, existing.Name),
})
markIdentitySeen(seenAccountIdentity, identityKeys, i+1)
continue
}
markIdentitySeen(seenAccountIdentity, identityKeys, i+1)
}

accountInput := &service.CreateAccountInput{
Name: item.Name,
Notes: item.Notes,
Expand Down Expand Up @@ -331,6 +373,10 @@ func (h *AccountHandler) importData(ctx context.Context, req DataImportRequest)
})
continue
}
// 新建账号加入索引,保证同批次后续项可以发现它。
if created != nil {
accountIndex.Add(*created)
}
// 收集 Antigravity OAuth 账号,稍后异步设置隐私
if created.Platform == service.PlatformAntigravity && created.Type == service.AccountTypeOAuth {
privacyAccounts = append(privacyAccounts, created)
Expand Down Expand Up @@ -563,7 +609,13 @@ func validateDataAccount(item DataAccount) error {
return errors.New("account credentials is required")
}
switch item.Type {
case service.AccountTypeOAuth, service.AccountTypeSetupToken, service.AccountTypeAPIKey, service.AccountTypeUpstream:
case service.AccountTypeOAuth,
service.AccountTypeSetupToken,
service.AccountTypeAPIKey,
service.AccountTypeUpstream,
service.AccountTypeBedrock,
service.AccountTypeAnthropicAWS,
service.AccountTypeServiceAccount:
default:
return fmt.Errorf("account type is invalid: %s", item.Type)
}
Expand Down
73 changes: 71 additions & 2 deletions backend/internal/handler/admin/account_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ type CreateAccountRequest struct {
Name string `json:"name" binding:"required"`
Notes *string `json:"notes"`
Platform string `json:"platform" binding:"required"`
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock service_account"`
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock anthropic_aws service_account"`
Credentials map[string]any `json:"credentials" binding:"required"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Expand All @@ -117,7 +117,7 @@ type CreateAccountRequest struct {
type UpdateAccountRequest struct {
Name string `json:"name"`
Notes *string `json:"notes"`
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock service_account"`
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock anthropic_aws service_account"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Expand Down Expand Up @@ -533,6 +533,24 @@ func (h *AccountHandler) Create(c *gin.Context) {
var createdAccount *service.Account

result, err := executeAdminIdempotent(c, "admin.accounts.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
// 去重检查:根据平台与凭证生成身份键,命中即返回 409。
if existing, dupErr := h.findDuplicateAccount(ctx, accountIdentityInput{
Platform: req.Platform,
Type: req.Type,
Credentials: req.Credentials,
Extra: req.Extra,
}); dupErr != nil {
return nil, dupErr
} else if existing != nil {
return nil, infraerrors.Conflict(
"ACCOUNT_DUPLICATE",
fmt.Sprintf("账号已存在(id=%d, name=%s)", existing.ID, existing.Name),
).WithMetadata(map[string]string{
"existing_id": strconv.FormatInt(existing.ID, 10),
"existing_name": existing.Name,
})
}

account, execErr := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
Name: req.Name,
Notes: req.Notes,
Expand Down Expand Up @@ -1308,12 +1326,24 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
executeAdminIdempotentJSON(c, "admin.accounts.batch_create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
success := 0
failed := 0
skipped := 0
results := make([]gin.H, 0, len(req.Accounts))
// 收集需要异步设置隐私的 OAuth 账号
var antigravityPrivacyAccounts []*service.Account
var openaiPrivacyAccounts []*service.Account

// 预加载该批次涉及的所有平台账号,构建身份索引用于跨已存在账号去重。
platforms := make([]string, 0, len(req.Accounts))
for _, item := range req.Accounts {
platforms = append(platforms, item.Platform)
}
index, indexErr := h.loadAccountIdentityIndex(ctx, platforms)
if indexErr != nil {
return nil, indexErr
}
seenIdentity := make(map[string]int, len(req.Accounts))

for batchIdx, item := range req.Accounts {
if item.RateMultiplier != nil && *item.RateMultiplier < 0 {
failed++
results = append(results, gin.H{
Expand All @@ -1327,6 +1357,40 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
// base_rpm 输入校验:负值归零,超过 10000 截断
sanitizeExtraBaseRPM(item.Extra)

// 去重:先检查批内已见,再检查数据库已存在账号。
identityKeys := buildAccountIdentityKeys(accountIdentityInput{
Platform: item.Platform,
Type: item.Type,
Credentials: item.Credentials,
Extra: item.Extra,
})
if len(identityKeys) > 0 {
if dupIndex, ok := firstSeenIdentity(seenIdentity, identityKeys); ok {
skipped++
results = append(results, gin.H{
"name": item.Name,
"success": false,
"skipped": true,
"error": fmt.Sprintf("与第 %d 条重复,已跳过", dupIndex),
})
continue
}
if existing := index.Find(identityKeys); existing != nil {
skipped++
results = append(results, gin.H{
"name": item.Name,
"success": false,
"skipped": true,
"existing_id": existing.ID,
"existing_name": existing.Name,
"error": fmt.Sprintf("账号已存在(id=%d, name=%s),已跳过", existing.ID, existing.Name),
})
markIdentitySeen(seenIdentity, identityKeys, batchIdx+1)
continue
}
markIdentitySeen(seenIdentity, identityKeys, batchIdx+1)
}

skipCheck := item.ConfirmMixedChannelRisk != nil && *item.ConfirmMixedChannelRisk

account, err := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
Expand Down Expand Up @@ -1354,6 +1418,10 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
})
continue
}
// 新建账号加入索引,保证同批次后续项可以发现它。
if account != nil {
index.Add(*account)
}
// 收集需要异步设置隐私的 OAuth 账号
if account.Type == service.AccountTypeOAuth {
switch account.Platform {
Expand Down Expand Up @@ -1407,6 +1475,7 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
return gin.H{
"success": success,
"failed": failed,
"skipped": skipped,
"results": results,
}, nil
})
Expand Down
Loading
Loading