From 9c3b4aa6bdd4e67848557004dd8b058465272a7e Mon Sep 17 00:00:00 2001 From: Arthas Date: Fri, 29 May 2026 14:23:21 +0800 Subject: [PATCH 1/2] feat(openai): add group codex client restriction --- backend/ent/group.go | 13 ++- backend/ent/group/group.go | 10 ++ backend/ent/group/where.go | 15 +++ backend/ent/group_create.go | 65 ++++++++++++ backend/ent/group_update.go | 34 +++++++ backend/ent/migrate/schema.go | 3 +- backend/ent/mutation.go | 56 ++++++++++- backend/ent/runtime/runtime.go | 26 ++--- backend/ent/schema/group.go | 5 + backend/go.sum | 2 + .../internal/handler/admin/group_handler.go | 4 + backend/internal/handler/dto/mappers.go | 1 + backend/internal/handler/dto/types.go | 2 + .../handler/openai_chat_completions.go | 3 + backend/internal/handler/openai_embeddings.go | 3 + .../handler/openai_gateway_handler.go | 9 ++ .../handler/openai_gateway_handler_test.go | 98 +++++++++++++++++++ .../handler/openai_group_restriction.go | 32 ++++++ backend/internal/handler/openai_images.go | 3 + backend/internal/repository/api_key_repo.go | 1 + backend/internal/repository/group_repo.go | 2 + backend/internal/server/api_contract_test.go | 1 + backend/internal/service/admin_service.go | 8 ++ .../internal/service/api_key_auth_cache.go | 1 + .../service/api_key_auth_cache_impl.go | 4 +- .../service/api_key_service_cache_test.go | 2 + backend/internal/service/group.go | 2 + .../openai_client_restriction_detector.go | 17 +++- ...openai_client_restriction_detector_test.go | 54 ++++++++++ .../service/openai_gateway_service.go | 10 +- ...nai_gateway_service_codex_cli_only_test.go | 4 + .../service/openai_group_restriction.go | 40 ++++++++ ...5_add_group_codex_official_restriction.sql | 9 ++ frontend/src/i18n/locales/en.ts | 4 + frontend/src/i18n/locales/zh.ts | 4 + frontend/src/types/index.ts | 4 + frontend/src/views/admin/GroupsView.vue | 77 +++++++++++++++ 37 files changed, 602 insertions(+), 26 deletions(-) create mode 100644 backend/internal/handler/openai_group_restriction.go create mode 100644 backend/internal/service/openai_group_restriction.go create mode 100644 backend/migrations/145_add_group_codex_official_restriction.sql diff --git a/backend/ent/group.go b/backend/ent/group.go index 298df88abbc..262381c3e26 100644 --- a/backend/ent/group.go +++ b/backend/ent/group.go @@ -65,6 +65,8 @@ type Group struct { FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` // 无效请求兜底使用的分组 ID FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"` + // OpenAI 分组是否仅允许 Codex 官方客户端 + CodexOfficialOnly bool `json:"codex_official_only,omitempty"` // 模型路由配置:模型模式 -> 优先账号ID列表 ModelRouting map[string][]int64 `json:"model_routing,omitempty"` // 是否启用模型路由配置 @@ -197,7 +199,7 @@ func (*Group) scanValues(columns []string) ([]any, error) { switch columns[i] { case group.FieldModelRouting, group.FieldSupportedModelScopes, group.FieldMessagesDispatchModelConfig, group.FieldModelsListConfig: values[i] = new([]byte) - case group.FieldIsExclusive, group.FieldAllowImageGeneration, group.FieldImageRateIndependent, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldRequireOauthOnly, group.FieldRequirePrivacySet: + case group.FieldIsExclusive, group.FieldAllowImageGeneration, group.FieldImageRateIndependent, group.FieldClaudeCodeOnly, group.FieldCodexOfficialOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldRequireOauthOnly, group.FieldRequirePrivacySet: values[i] = new(sql.NullBool) case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImageRateMultiplier, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k: values[i] = new(sql.NullFloat64) @@ -376,6 +378,12 @@ func (_m *Group) assignValues(columns []string, values []any) error { _m.FallbackGroupIDOnInvalidRequest = new(int64) *_m.FallbackGroupIDOnInvalidRequest = value.Int64 } + case group.FieldCodexOfficialOnly: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field codex_official_only", values[i]) + } else if value.Valid { + _m.CodexOfficialOnly = value.Bool + } case group.FieldModelRouting: if value, ok := values[i].(*[]byte); !ok { return fmt.Errorf("unexpected type %T for field model_routing", values[i]) @@ -621,6 +629,9 @@ func (_m *Group) String() string { builder.WriteString(fmt.Sprintf("%v", *v)) } builder.WriteString(", ") + builder.WriteString("codex_official_only=") + builder.WriteString(fmt.Sprintf("%v", _m.CodexOfficialOnly)) + builder.WriteString(", ") builder.WriteString("model_routing=") builder.WriteString(fmt.Sprintf("%v", _m.ModelRouting)) builder.WriteString(", ") diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go index ebe9bd7e820..0f7893f5cd9 100644 --- a/backend/ent/group/group.go +++ b/backend/ent/group/group.go @@ -62,6 +62,8 @@ const ( FieldFallbackGroupID = "fallback_group_id" // FieldFallbackGroupIDOnInvalidRequest holds the string denoting the fallback_group_id_on_invalid_request field in the database. FieldFallbackGroupIDOnInvalidRequest = "fallback_group_id_on_invalid_request" + // FieldCodexOfficialOnly holds the string denoting the codex_official_only field in the database. + FieldCodexOfficialOnly = "codex_official_only" // FieldModelRouting holds the string denoting the model_routing field in the database. FieldModelRouting = "model_routing" // FieldModelRoutingEnabled holds the string denoting the model_routing_enabled field in the database. @@ -184,6 +186,7 @@ var Columns = []string{ FieldClaudeCodeOnly, FieldFallbackGroupID, FieldFallbackGroupIDOnInvalidRequest, + FieldCodexOfficialOnly, FieldModelRouting, FieldModelRoutingEnabled, FieldMcpXMLInject, @@ -259,6 +262,8 @@ var ( DefaultImageRateMultiplier float64 // DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field. DefaultClaudeCodeOnly bool + // DefaultCodexOfficialOnly holds the default value on creation for the "codex_official_only" field. + DefaultCodexOfficialOnly bool // DefaultModelRoutingEnabled holds the default value on creation for the "model_routing_enabled" field. DefaultModelRoutingEnabled bool // DefaultMcpXMLInject holds the default value on creation for the "mcp_xml_inject" field. @@ -408,6 +413,11 @@ func ByFallbackGroupIDOnInvalidRequest(opts ...sql.OrderTermOption) OrderOption return sql.OrderByField(FieldFallbackGroupIDOnInvalidRequest, opts...).ToFunc() } +// ByCodexOfficialOnly orders the results by the codex_official_only field. +func ByCodexOfficialOnly(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCodexOfficialOnly, opts...).ToFunc() +} + // ByModelRoutingEnabled orders the results by the model_routing_enabled field. func ByModelRoutingEnabled(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldModelRoutingEnabled, opts...).ToFunc() diff --git a/backend/ent/group/where.go b/backend/ent/group/where.go index d3223a92577..2e2175e93a6 100644 --- a/backend/ent/group/where.go +++ b/backend/ent/group/where.go @@ -170,6 +170,11 @@ func FallbackGroupIDOnInvalidRequest(v int64) predicate.Group { return predicate.Group(sql.FieldEQ(FieldFallbackGroupIDOnInvalidRequest, v)) } +// CodexOfficialOnly applies equality check predicate on the "codex_official_only" field. It's identical to CodexOfficialOnlyEQ. +func CodexOfficialOnly(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldCodexOfficialOnly, v)) +} + // ModelRoutingEnabled applies equality check predicate on the "model_routing_enabled" field. It's identical to ModelRoutingEnabledEQ. func ModelRoutingEnabled(v bool) predicate.Group { return predicate.Group(sql.FieldEQ(FieldModelRoutingEnabled, v)) @@ -1235,6 +1240,16 @@ func FallbackGroupIDOnInvalidRequestNotNil() predicate.Group { return predicate.Group(sql.FieldNotNull(FieldFallbackGroupIDOnInvalidRequest)) } +// CodexOfficialOnlyEQ applies the EQ predicate on the "codex_official_only" field. +func CodexOfficialOnlyEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldCodexOfficialOnly, v)) +} + +// CodexOfficialOnlyNEQ applies the NEQ predicate on the "codex_official_only" field. +func CodexOfficialOnlyNEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldCodexOfficialOnly, v)) +} + // ModelRoutingIsNil applies the IsNil predicate on the "model_routing" field. func ModelRoutingIsNil() predicate.Group { return predicate.Group(sql.FieldIsNull(FieldModelRouting)) diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go index d5ed0c19a42..b2502b055c9 100644 --- a/backend/ent/group_create.go +++ b/backend/ent/group_create.go @@ -343,6 +343,20 @@ func (_c *GroupCreate) SetNillableFallbackGroupIDOnInvalidRequest(v *int64) *Gro return _c } +// SetCodexOfficialOnly sets the "codex_official_only" field. +func (_c *GroupCreate) SetCodexOfficialOnly(v bool) *GroupCreate { + _c.mutation.SetCodexOfficialOnly(v) + return _c +} + +// SetNillableCodexOfficialOnly sets the "codex_official_only" field if the given value is not nil. +func (_c *GroupCreate) SetNillableCodexOfficialOnly(v *bool) *GroupCreate { + if v != nil { + _c.SetCodexOfficialOnly(*v) + } + return _c +} + // SetModelRouting sets the "model_routing" field. func (_c *GroupCreate) SetModelRouting(v map[string][]int64) *GroupCreate { _c.mutation.SetModelRouting(v) @@ -676,6 +690,10 @@ func (_c *GroupCreate) defaults() error { v := group.DefaultClaudeCodeOnly _c.mutation.SetClaudeCodeOnly(v) } + if _, ok := _c.mutation.CodexOfficialOnly(); !ok { + v := group.DefaultCodexOfficialOnly + _c.mutation.SetCodexOfficialOnly(v) + } if _, ok := _c.mutation.ModelRoutingEnabled(); !ok { v := group.DefaultModelRoutingEnabled _c.mutation.SetModelRoutingEnabled(v) @@ -784,6 +802,9 @@ func (_c *GroupCreate) check() error { if _, ok := _c.mutation.ClaudeCodeOnly(); !ok { return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)} } + if _, ok := _c.mutation.CodexOfficialOnly(); !ok { + return &ValidationError{Name: "codex_official_only", err: errors.New(`ent: missing required field "Group.codex_official_only"`)} + } if _, ok := _c.mutation.ModelRoutingEnabled(); !ok { return &ValidationError{Name: "model_routing_enabled", err: errors.New(`ent: missing required field "Group.model_routing_enabled"`)} } @@ -941,6 +962,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value) _node.FallbackGroupIDOnInvalidRequest = &value } + if value, ok := _c.mutation.CodexOfficialOnly(); ok { + _spec.SetField(group.FieldCodexOfficialOnly, field.TypeBool, value) + _node.CodexOfficialOnly = value + } if value, ok := _c.mutation.ModelRouting(); ok { _spec.SetField(group.FieldModelRouting, field.TypeJSON, value) _node.ModelRouting = value @@ -1535,6 +1560,18 @@ func (u *GroupUpsert) ClearFallbackGroupIDOnInvalidRequest() *GroupUpsert { return u } +// SetCodexOfficialOnly sets the "codex_official_only" field. +func (u *GroupUpsert) SetCodexOfficialOnly(v bool) *GroupUpsert { + u.Set(group.FieldCodexOfficialOnly, v) + return u +} + +// UpdateCodexOfficialOnly sets the "codex_official_only" field to the value that was provided on create. +func (u *GroupUpsert) UpdateCodexOfficialOnly() *GroupUpsert { + u.SetExcluded(group.FieldCodexOfficialOnly) + return u +} + // SetModelRouting sets the "model_routing" field. func (u *GroupUpsert) SetModelRouting(v map[string][]int64) *GroupUpsert { u.Set(group.FieldModelRouting, v) @@ -2197,6 +2234,20 @@ func (u *GroupUpsertOne) ClearFallbackGroupIDOnInvalidRequest() *GroupUpsertOne }) } +// SetCodexOfficialOnly sets the "codex_official_only" field. +func (u *GroupUpsertOne) SetCodexOfficialOnly(v bool) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetCodexOfficialOnly(v) + }) +} + +// UpdateCodexOfficialOnly sets the "codex_official_only" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateCodexOfficialOnly() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateCodexOfficialOnly() + }) +} + // SetModelRouting sets the "model_routing" field. func (u *GroupUpsertOne) SetModelRouting(v map[string][]int64) *GroupUpsertOne { return u.Update(func(s *GroupUpsert) { @@ -3052,6 +3103,20 @@ func (u *GroupUpsertBulk) ClearFallbackGroupIDOnInvalidRequest() *GroupUpsertBul }) } +// SetCodexOfficialOnly sets the "codex_official_only" field. +func (u *GroupUpsertBulk) SetCodexOfficialOnly(v bool) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetCodexOfficialOnly(v) + }) +} + +// UpdateCodexOfficialOnly sets the "codex_official_only" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateCodexOfficialOnly() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateCodexOfficialOnly() + }) +} + // SetModelRouting sets the "model_routing" field. func (u *GroupUpsertBulk) SetModelRouting(v map[string][]int64) *GroupUpsertBulk { return u.Update(func(s *GroupUpsert) { diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go index c10d60ecb85..29d4f4981a0 100644 --- a/backend/ent/group_update.go +++ b/backend/ent/group_update.go @@ -473,6 +473,20 @@ func (_u *GroupUpdate) ClearFallbackGroupIDOnInvalidRequest() *GroupUpdate { return _u } +// SetCodexOfficialOnly sets the "codex_official_only" field. +func (_u *GroupUpdate) SetCodexOfficialOnly(v bool) *GroupUpdate { + _u.mutation.SetCodexOfficialOnly(v) + return _u +} + +// SetNillableCodexOfficialOnly sets the "codex_official_only" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableCodexOfficialOnly(v *bool) *GroupUpdate { + if v != nil { + _u.SetCodexOfficialOnly(*v) + } + return _u +} + // SetModelRouting sets the "model_routing" field. func (_u *GroupUpdate) SetModelRouting(v map[string][]int64) *GroupUpdate { _u.mutation.SetModelRouting(v) @@ -1085,6 +1099,9 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.FallbackGroupIDOnInvalidRequestCleared() { _spec.ClearField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64) } + if value, ok := _u.mutation.CodexOfficialOnly(); ok { + _spec.SetField(group.FieldCodexOfficialOnly, field.TypeBool, value) + } if value, ok := _u.mutation.ModelRouting(); ok { _spec.SetField(group.FieldModelRouting, field.TypeJSON, value) } @@ -1886,6 +1903,20 @@ func (_u *GroupUpdateOne) ClearFallbackGroupIDOnInvalidRequest() *GroupUpdateOne return _u } +// SetCodexOfficialOnly sets the "codex_official_only" field. +func (_u *GroupUpdateOne) SetCodexOfficialOnly(v bool) *GroupUpdateOne { + _u.mutation.SetCodexOfficialOnly(v) + return _u +} + +// SetNillableCodexOfficialOnly sets the "codex_official_only" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableCodexOfficialOnly(v *bool) *GroupUpdateOne { + if v != nil { + _u.SetCodexOfficialOnly(*v) + } + return _u +} + // SetModelRouting sets the "model_routing" field. func (_u *GroupUpdateOne) SetModelRouting(v map[string][]int64) *GroupUpdateOne { _u.mutation.SetModelRouting(v) @@ -2528,6 +2559,9 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) if _u.mutation.FallbackGroupIDOnInvalidRequestCleared() { _spec.ClearField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64) } + if value, ok := _u.mutation.CodexOfficialOnly(); ok { + _spec.SetField(group.FieldCodexOfficialOnly, field.TypeBool, value) + } if value, ok := _u.mutation.ModelRouting(); ok { _spec.SetField(group.FieldModelRouting, field.TypeJSON, value) } diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index 7abe4c601e3..1a85e6d77a6 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -659,6 +659,7 @@ var ( {Name: "claude_code_only", Type: field.TypeBool, Default: false}, {Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true}, {Name: "fallback_group_id_on_invalid_request", Type: field.TypeInt64, Nullable: true}, + {Name: "codex_official_only", Type: field.TypeBool, Default: false}, {Name: "model_routing", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}}, {Name: "model_routing_enabled", Type: field.TypeBool, Default: false}, {Name: "mcp_xml_inject", Type: field.TypeBool, Default: true}, @@ -706,7 +707,7 @@ var ( { Name: "group_sort_order", Unique: false, - Columns: []*schema.Column{GroupsColumns[28]}, + Columns: []*schema.Column{GroupsColumns[29]}, }, }, } diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 003e25d5015..7950652a0fb 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -14889,6 +14889,7 @@ type GroupMutation struct { addfallback_group_id *int64 fallback_group_id_on_invalid_request *int64 addfallback_group_id_on_invalid_request *int64 + codex_official_only *bool model_routing *map[string][]int64 model_routing_enabled *bool mcp_xml_inject *bool @@ -16212,6 +16213,42 @@ func (m *GroupMutation) ResetFallbackGroupIDOnInvalidRequest() { delete(m.clearedFields, group.FieldFallbackGroupIDOnInvalidRequest) } +// SetCodexOfficialOnly sets the "codex_official_only" field. +func (m *GroupMutation) SetCodexOfficialOnly(b bool) { + m.codex_official_only = &b +} + +// CodexOfficialOnly returns the value of the "codex_official_only" field in the mutation. +func (m *GroupMutation) CodexOfficialOnly() (r bool, exists bool) { + v := m.codex_official_only + if v == nil { + return + } + return *v, true +} + +// OldCodexOfficialOnly returns the old "codex_official_only" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldCodexOfficialOnly(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCodexOfficialOnly is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCodexOfficialOnly requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCodexOfficialOnly: %w", err) + } + return oldValue.CodexOfficialOnly, nil +} + +// ResetCodexOfficialOnly resets all changes to the "codex_official_only" field. +func (m *GroupMutation) ResetCodexOfficialOnly() { + m.codex_official_only = nil +} + // SetModelRouting sets the "model_routing" field. func (m *GroupMutation) SetModelRouting(value map[string][]int64) { m.model_routing = &value @@ -17070,7 +17107,7 @@ func (m *GroupMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *GroupMutation) Fields() []string { - fields := make([]string, 0, 35) + fields := make([]string, 0, 36) if m.created_at != nil { fields = append(fields, group.FieldCreatedAt) } @@ -17140,6 +17177,9 @@ func (m *GroupMutation) Fields() []string { if m.fallback_group_id_on_invalid_request != nil { fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest) } + if m.codex_official_only != nil { + fields = append(fields, group.FieldCodexOfficialOnly) + } if m.model_routing != nil { fields = append(fields, group.FieldModelRouting) } @@ -17230,6 +17270,8 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { return m.FallbackGroupID() case group.FieldFallbackGroupIDOnInvalidRequest: return m.FallbackGroupIDOnInvalidRequest() + case group.FieldCodexOfficialOnly: + return m.CodexOfficialOnly() case group.FieldModelRouting: return m.ModelRouting() case group.FieldModelRoutingEnabled: @@ -17309,6 +17351,8 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e return m.OldFallbackGroupID(ctx) case group.FieldFallbackGroupIDOnInvalidRequest: return m.OldFallbackGroupIDOnInvalidRequest(ctx) + case group.FieldCodexOfficialOnly: + return m.OldCodexOfficialOnly(ctx) case group.FieldModelRouting: return m.OldModelRouting(ctx) case group.FieldModelRoutingEnabled: @@ -17503,6 +17547,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { } m.SetFallbackGroupIDOnInvalidRequest(v) return nil + case group.FieldCodexOfficialOnly: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCodexOfficialOnly(v) + return nil case group.FieldModelRouting: v, ok := value.(map[string][]int64) if !ok { @@ -17933,6 +17984,9 @@ func (m *GroupMutation) ResetField(name string) error { case group.FieldFallbackGroupIDOnInvalidRequest: m.ResetFallbackGroupIDOnInvalidRequest() return nil + case group.FieldCodexOfficialOnly: + m.ResetCodexOfficialOnly() + return nil case group.FieldModelRouting: m.ResetModelRouting() return nil diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index fdb837e805f..90c3ae55f86 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -832,50 +832,54 @@ func init() { groupDescClaudeCodeOnly := groupFields[17].Descriptor() // group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field. group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool) + // groupDescCodexOfficialOnly is the schema descriptor for codex_official_only field. + groupDescCodexOfficialOnly := groupFields[20].Descriptor() + // group.DefaultCodexOfficialOnly holds the default value on creation for the codex_official_only field. + group.DefaultCodexOfficialOnly = groupDescCodexOfficialOnly.Default.(bool) // groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field. - groupDescModelRoutingEnabled := groupFields[21].Descriptor() + groupDescModelRoutingEnabled := groupFields[22].Descriptor() // group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field. group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool) // groupDescMcpXMLInject is the schema descriptor for mcp_xml_inject field. - groupDescMcpXMLInject := groupFields[22].Descriptor() + groupDescMcpXMLInject := groupFields[23].Descriptor() // group.DefaultMcpXMLInject holds the default value on creation for the mcp_xml_inject field. group.DefaultMcpXMLInject = groupDescMcpXMLInject.Default.(bool) // groupDescSupportedModelScopes is the schema descriptor for supported_model_scopes field. - groupDescSupportedModelScopes := groupFields[23].Descriptor() + groupDescSupportedModelScopes := groupFields[24].Descriptor() // group.DefaultSupportedModelScopes holds the default value on creation for the supported_model_scopes field. group.DefaultSupportedModelScopes = groupDescSupportedModelScopes.Default.([]string) // groupDescSortOrder is the schema descriptor for sort_order field. - groupDescSortOrder := groupFields[24].Descriptor() + groupDescSortOrder := groupFields[25].Descriptor() // group.DefaultSortOrder holds the default value on creation for the sort_order field. group.DefaultSortOrder = groupDescSortOrder.Default.(int) // groupDescAllowMessagesDispatch is the schema descriptor for allow_messages_dispatch field. - groupDescAllowMessagesDispatch := groupFields[25].Descriptor() + groupDescAllowMessagesDispatch := groupFields[26].Descriptor() // group.DefaultAllowMessagesDispatch holds the default value on creation for the allow_messages_dispatch field. group.DefaultAllowMessagesDispatch = groupDescAllowMessagesDispatch.Default.(bool) // groupDescRequireOauthOnly is the schema descriptor for require_oauth_only field. - groupDescRequireOauthOnly := groupFields[26].Descriptor() + groupDescRequireOauthOnly := groupFields[27].Descriptor() // group.DefaultRequireOauthOnly holds the default value on creation for the require_oauth_only field. group.DefaultRequireOauthOnly = groupDescRequireOauthOnly.Default.(bool) // groupDescRequirePrivacySet is the schema descriptor for require_privacy_set field. - groupDescRequirePrivacySet := groupFields[27].Descriptor() + groupDescRequirePrivacySet := groupFields[28].Descriptor() // group.DefaultRequirePrivacySet holds the default value on creation for the require_privacy_set field. group.DefaultRequirePrivacySet = groupDescRequirePrivacySet.Default.(bool) // groupDescDefaultMappedModel is the schema descriptor for default_mapped_model field. - groupDescDefaultMappedModel := groupFields[28].Descriptor() + groupDescDefaultMappedModel := groupFields[29].Descriptor() // group.DefaultDefaultMappedModel holds the default value on creation for the default_mapped_model field. group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string) // group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save. group.DefaultMappedModelValidator = groupDescDefaultMappedModel.Validators[0].(func(string) error) // groupDescMessagesDispatchModelConfig is the schema descriptor for messages_dispatch_model_config field. - groupDescMessagesDispatchModelConfig := groupFields[29].Descriptor() + groupDescMessagesDispatchModelConfig := groupFields[30].Descriptor() // group.DefaultMessagesDispatchModelConfig holds the default value on creation for the messages_dispatch_model_config field. group.DefaultMessagesDispatchModelConfig = groupDescMessagesDispatchModelConfig.Default.(domain.OpenAIMessagesDispatchModelConfig) // groupDescModelsListConfig is the schema descriptor for models_list_config field. - groupDescModelsListConfig := groupFields[30].Descriptor() + groupDescModelsListConfig := groupFields[31].Descriptor() // group.DefaultModelsListConfig holds the default value on creation for the models_list_config field. group.DefaultModelsListConfig = groupDescModelsListConfig.Default.(domain.GroupModelsListConfig) // groupDescRpmLimit is the schema descriptor for rpm_limit field. - groupDescRpmLimit := groupFields[31].Descriptor() + groupDescRpmLimit := groupFields[32].Descriptor() // group.DefaultRpmLimit holds the default value on creation for the rpm_limit field. group.DefaultRpmLimit = groupDescRpmLimit.Default.(int) idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin() diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go index 2a1715f8b2b..7ed988e3904 100644 --- a/backend/ent/schema/group.go +++ b/backend/ent/schema/group.go @@ -110,6 +110,11 @@ func (Group) Fields() []ent.Field { Nillable(). Comment("无效请求兜底使用的分组 ID"), + // OpenAI Codex 官方客户端限制 (added by migration 145) + field.Bool("codex_official_only"). + Default(false). + Comment("OpenAI 分组是否仅允许 Codex 官方客户端"), + // 模型路由配置 (added by migration 040) field.JSON("model_routing", map[string][]int64{}). Optional(). diff --git a/backend/go.sum b/backend/go.sum index 7735fda29e3..fbc04494ce5 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -164,6 +164,8 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= +github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= +github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4= diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index dbf6f709a49..0b15792f18f 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -100,6 +100,7 @@ type CreateGroupRequest struct { ImagePrice4K *float64 `json:"image_price_4k"` ClaudeCodeOnly bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id"` + CodexOfficialOnly bool `json:"codex_official_only"` FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"` // 模型路由配置(仅 anthropic 平台使用) ModelRouting map[string][]int64 `json:"model_routing"` @@ -141,6 +142,7 @@ type UpdateGroupRequest struct { ImagePrice4K *float64 `json:"image_price_4k"` ClaudeCodeOnly *bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id"` + CodexOfficialOnly *bool `json:"codex_official_only"` FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"` // 模型路由配置(仅 anthropic 平台使用) ModelRouting map[string][]int64 `json:"model_routing"` @@ -289,6 +291,7 @@ func (h *GroupHandler) Create(c *gin.Context) { ImagePrice4K: req.ImagePrice4K, ClaudeCodeOnly: req.ClaudeCodeOnly, FallbackGroupID: req.FallbackGroupID, + CodexOfficialOnly: req.CodexOfficialOnly, FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest, ModelRouting: req.ModelRouting, ModelRoutingEnabled: req.ModelRoutingEnabled, @@ -345,6 +348,7 @@ func (h *GroupHandler) Update(c *gin.Context) { ImagePrice4K: req.ImagePrice4K, ClaudeCodeOnly: req.ClaudeCodeOnly, FallbackGroupID: req.FallbackGroupID, + CodexOfficialOnly: req.CodexOfficialOnly, FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest, ModelRouting: req.ModelRouting, ModelRoutingEnabled: req.ModelRoutingEnabled, diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 51a11ea7782..128b1515f1c 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -185,6 +185,7 @@ func groupFromServiceBase(g *service.Group) Group { ImagePrice4K: g.ImagePrice4K, ClaudeCodeOnly: g.ClaudeCodeOnly, FallbackGroupID: g.FallbackGroupID, + CodexOfficialOnly: g.CodexOfficialOnly, FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest, AllowMessagesDispatch: g.AllowMessagesDispatch, RequireOAuthOnly: g.RequireOAuthOnly, diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index b1841c622eb..a4d1aef8b4e 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -106,6 +106,8 @@ type Group struct { // Claude Code 客户端限制 ClaudeCodeOnly bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id"` + // OpenAI Codex 官方客户端限制 + CodexOfficialOnly bool `json:"codex_official_only"` // 无效请求兜底分组 FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"` diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go index 9805bf8a703..fdea4627d98 100644 --- a/backend/internal/handler/openai_chat_completions.go +++ b/backend/internal/handler/openai_chat_completions.go @@ -44,6 +44,9 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { zap.Int64("api_key_id", apiKey.ID), zap.Any("group_id", apiKey.GroupID), ) + if h.rejectOpenAINonCodexOfficialClient(c, apiKey, false) { + return + } if !h.ensureResponsesDependencies(c, reqLog) { return diff --git a/backend/internal/handler/openai_embeddings.go b/backend/internal/handler/openai_embeddings.go index 81713f7f9f4..7461d0eaaa1 100644 --- a/backend/internal/handler/openai_embeddings.go +++ b/backend/internal/handler/openai_embeddings.go @@ -42,6 +42,9 @@ func (h *OpenAIGatewayHandler) Embeddings(c *gin.Context) { zap.Int64("api_key_id", apiKey.ID), zap.Any("group_id", apiKey.GroupID), ) + if h.rejectOpenAINonCodexOfficialClient(c, apiKey, false) { + return + } if !h.ensureResponsesDependencies(c, reqLog) { return } diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 979aaa1c484..ec8eed52b28 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -110,6 +110,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { zap.Int64("api_key_id", apiKey.ID), zap.Any("group_id", apiKey.GroupID), ) + if h.rejectOpenAINonCodexOfficialClient(c, apiKey, false) { + return + } if !h.ensureResponsesDependencies(c, reqLog) { return } @@ -584,6 +587,9 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { "This group does not allow /v1/messages dispatch") return } + if h.rejectOpenAINonCodexOfficialClient(c, apiKey, true) { + return + } if !h.ensureResponsesDependencies(c, reqLog) { return @@ -1123,6 +1129,9 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { zap.Any("group_id", apiKey.GroupID), zap.Bool("openai_ws_mode", true), ) + if h.rejectOpenAINonCodexOfficialClient(c, apiKey, false) { + return + } if !h.ensureResponsesDependencies(c, reqLog) { return } diff --git a/backend/internal/handler/openai_gateway_handler_test.go b/backend/internal/handler/openai_gateway_handler_test.go index b3fb35eee9a..e322c0ad684 100644 --- a/backend/internal/handler/openai_gateway_handler_test.go +++ b/backend/internal/handler/openai_gateway_handler_test.go @@ -481,6 +481,104 @@ func TestOpenAIResponses_MissingDependencies_ReturnsServiceUnavailable(t *testin assert.Equal(t, "Service temporarily unavailable", errorObj["message"]) } +func TestOpenAIResponses_RejectsNonCodexWhenGroupRequiresOfficialClient(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(`{"model":"gpt-5","stream":false}`)) + c.Request.Header.Set("Content-Type", "application/json") + c.Request.Header.Set("User-Agent", "curl/8.0") + c.Request.Header.Set("originator", "my_client") + + groupID := int64(2) + c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{ + ID: 10, + GroupID: &groupID, + User: &service.User{ID: 1}, + Group: &service.Group{ + ID: groupID, + Platform: service.PlatformOpenAI, + CodexOfficialOnly: true, + }, + }) + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{ + UserID: 1, + Concurrency: 1, + }) + + h := &OpenAIGatewayHandler{gatewayService: &service.OpenAIGatewayService{}} + h.Responses(c) + + require.Equal(t, http.StatusForbidden, w.Code) + require.Contains(t, w.Body.String(), "This group only allows Codex official clients") +} + +func TestOpenAIResponses_AllowsOfficialCodexPastGroupRestriction(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(`{"model":"gpt-5","stream":false}`)) + c.Request.Header.Set("Content-Type", "application/json") + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.99.0") + + groupID := int64(2) + c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{ + ID: 10, + GroupID: &groupID, + User: &service.User{ID: 1}, + Group: &service.Group{ + ID: groupID, + Platform: service.PlatformOpenAI, + CodexOfficialOnly: true, + }, + }) + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{ + UserID: 1, + Concurrency: 1, + }) + + h := &OpenAIGatewayHandler{gatewayService: &service.OpenAIGatewayService{}} + h.Responses(c) + + require.NotEqual(t, http.StatusForbidden, w.Code) +} + +func TestOpenAIMessages_RejectsNonCodexWhenGroupRequiresOfficialClient(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(`{"model":"claude-sonnet-4-5","messages":[]}`)) + c.Request.Header.Set("Content-Type", "application/json") + c.Request.Header.Set("User-Agent", "curl/8.0") + c.Request.Header.Set("originator", "my_client") + + groupID := int64(2) + c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{ + ID: 10, + GroupID: &groupID, + User: &service.User{ID: 1}, + Group: &service.Group{ + ID: groupID, + Platform: service.PlatformOpenAI, + AllowMessagesDispatch: true, + CodexOfficialOnly: true, + }, + }) + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{ + UserID: 1, + Concurrency: 1, + }) + + h := &OpenAIGatewayHandler{gatewayService: &service.OpenAIGatewayService{}} + h.Messages(c) + + require.Equal(t, http.StatusForbidden, w.Code) + require.Contains(t, w.Body.String(), "This group only allows Codex official clients") +} + func TestOpenAIResponses_SetsClientTransportHTTP(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/handler/openai_group_restriction.go b/backend/internal/handler/openai_group_restriction.go new file mode 100644 index 00000000000..ed17af617e9 --- /dev/null +++ b/backend/internal/handler/openai_group_restriction.go @@ -0,0 +1,32 @@ +package handler + +import ( + "net/http" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +const openAIGroupCodexOfficialOnlyMessage = "This group only allows Codex official clients" + +func (h *OpenAIGatewayHandler) rejectOpenAINonCodexOfficialClient(c *gin.Context, apiKey *service.APIKey, anthropicFormat bool) bool { + if apiKey == nil { + return false + } + var gatewayService *service.OpenAIGatewayService + if h != nil { + gatewayService = h.gatewayService + } + result := gatewayService.DetectGroupCodexOfficialRestriction(c, apiKey.Group) + if !result.Enabled || result.Matched { + return false + } + + service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonLocalPolicyDenied) + if anthropicFormat { + h.anthropicErrorResponse(c, http.StatusForbidden, "permission_error", openAIGroupCodexOfficialOnlyMessage) + return true + } + h.errorResponse(c, http.StatusForbidden, "forbidden_error", openAIGroupCodexOfficialOnlyMessage) + return true +} diff --git a/backend/internal/handler/openai_images.go b/backend/internal/handler/openai_images.go index bbb080149c2..54d94d73b93 100644 --- a/backend/internal/handler/openai_images.go +++ b/backend/internal/handler/openai_images.go @@ -44,6 +44,9 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) { zap.Int64("api_key_id", apiKey.ID), zap.Any("group_id", apiKey.GroupID), ) + if h.rejectOpenAINonCodexOfficialClient(c, apiKey, false) { + return + } if !h.ensureResponsesDependencies(c, reqLog) { return } diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index bfe09283002..dc9fbe1e7f4 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -713,6 +713,7 @@ func groupEntityToService(g *dbent.Group) *service.Group { DefaultValidityDays: g.DefaultValidityDays, ClaudeCodeOnly: g.ClaudeCodeOnly, FallbackGroupID: g.FallbackGroupID, + CodexOfficialOnly: g.CodexOfficialOnly, FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest, ModelRouting: g.ModelRouting, ModelRoutingEnabled: g.ModelRoutingEnabled, diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index ac8669abdd9..db419334e55 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -59,6 +59,7 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). SetNillableFallbackGroupID(groupIn.FallbackGroupID). SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest). + SetCodexOfficialOnly(groupIn.CodexOfficialOnly). SetModelRoutingEnabled(groupIn.ModelRoutingEnabled). SetMcpXMLInject(groupIn.MCPXMLInject). SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch). @@ -135,6 +136,7 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er SetNillableImagePrice4k(groupIn.ImagePrice4K). SetDefaultValidityDays(groupIn.DefaultValidityDays). SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). + SetCodexOfficialOnly(groupIn.CodexOfficialOnly). SetModelRoutingEnabled(groupIn.ModelRoutingEnabled). SetMcpXMLInject(groupIn.MCPXMLInject). SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch). diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 9eea092452f..1e94d06bd25 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -357,6 +357,7 @@ func TestAPIContracts(t *testing.T) { "image_rate_independent": false, "image_rate_multiplier": 0, "claude_code_only": false, + "codex_official_only": false, "allow_messages_dispatch": false, "fallback_group_id": null, "fallback_group_id_on_invalid_request": null, diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index d46b636f2cb..1f24f0104a6 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -206,6 +206,7 @@ type CreateGroupInput struct { ImagePrice4K *float64 ClaudeCodeOnly bool // 仅允许 Claude Code 客户端 FallbackGroupID *int64 // 降级分组 ID + CodexOfficialOnly bool // OpenAI 分组仅允许 Codex 官方客户端 // 无效请求兜底分组 ID(仅 anthropic 平台使用) FallbackGroupIDOnInvalidRequest *int64 // 模型路由配置(仅 anthropic 平台使用) @@ -247,6 +248,7 @@ type UpdateGroupInput struct { ImagePrice4K *float64 ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端 FallbackGroupID *int64 // 降级分组 ID + CodexOfficialOnly *bool // OpenAI 分组仅允许 Codex 官方客户端 // 无效请求兜底分组 ID(仅 anthropic 平台使用) FallbackGroupIDOnInvalidRequest *int64 // 模型路由配置(仅 anthropic 平台使用) @@ -1769,6 +1771,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ImagePrice4K: imagePrice4K, ClaudeCodeOnly: input.ClaudeCodeOnly, FallbackGroupID: input.FallbackGroupID, + CodexOfficialOnly: input.CodexOfficialOnly, FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest, ModelRouting: input.ModelRouting, MCPXMLInject: mcpXMLInject, @@ -1782,6 +1785,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn RPMLimit: input.RPMLimit, } sanitizeGroupMessagesDispatchFields(group) + sanitizeGroupCodexOfficialRestrictionFields(group) if err := s.groupRepo.Create(ctx, group); err != nil { return nil, err } @@ -1979,6 +1983,9 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd group.FallbackGroupID = nil } } + if input.CodexOfficialOnly != nil { + group.CodexOfficialOnly = *input.CodexOfficialOnly + } fallbackOnInvalidRequest := group.FallbackGroupIDOnInvalidRequest if input.FallbackGroupIDOnInvalidRequest != nil { if *input.FallbackGroupIDOnInvalidRequest > 0 { @@ -2033,6 +2040,7 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd group.RPMLimit = *input.RPMLimit } sanitizeGroupMessagesDispatchFields(group) + sanitizeGroupCodexOfficialRestrictionFields(group) if err := s.groupRepo.Update(ctx, group); err != nil { return nil, err diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go index 74163179c8f..dd1d1d0a3d5 100644 --- a/backend/internal/service/api_key_auth_cache.go +++ b/backend/internal/service/api_key_auth_cache.go @@ -72,6 +72,7 @@ type APIKeyAuthGroupSnapshot struct { ImagePrice4K *float64 `json:"image_price_4k,omitempty"` ClaudeCodeOnly bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` + CodexOfficialOnly bool `json:"codex_official_only"` FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"` // Model routing is used by gateway account selection, so it must be part of auth cache snapshot. diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index 69c6086f75d..e8edd8c399b 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -14,7 +14,7 @@ import ( "github.com/dgraph-io/ristretto" ) -const apiKeyAuthSnapshotVersion = 11 // v11: reload snapshots for custom models_list_config +const apiKeyAuthSnapshotVersion = 12 // v12: reload snapshots for group codex_official_only type apiKeyAuthCacheConfig struct { l1Size int @@ -264,6 +264,7 @@ func (s *APIKeyService) snapshotFromAPIKey(ctx context.Context, apiKey *APIKey) ImagePrice4K: apiKey.Group.ImagePrice4K, ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly, FallbackGroupID: apiKey.Group.FallbackGroupID, + CodexOfficialOnly: apiKey.Group.CodexOfficialOnly, FallbackGroupIDOnInvalidRequest: apiKey.Group.FallbackGroupIDOnInvalidRequest, ModelRouting: apiKey.Group.ModelRouting, ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled, @@ -335,6 +336,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho ImagePrice4K: snapshot.Group.ImagePrice4K, ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly, FallbackGroupID: snapshot.Group.FallbackGroupID, + CodexOfficialOnly: snapshot.Group.CodexOfficialOnly, FallbackGroupIDOnInvalidRequest: snapshot.Group.FallbackGroupIDOnInvalidRequest, ModelRouting: snapshot.Group.ModelRouting, ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled, diff --git a/backend/internal/service/api_key_service_cache_test.go b/backend/internal/service/api_key_service_cache_test.go index eaac9a1c898..c7cfbb08d40 100644 --- a/backend/internal/service/api_key_service_cache_test.go +++ b/backend/internal/service/api_key_service_cache_test.go @@ -251,6 +251,7 @@ func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t Status: StatusActive, SubscriptionType: SubscriptionTypeStandard, RateMultiplier: 1, + CodexOfficialOnly: true, AllowMessagesDispatch: true, DefaultMappedModel: "gpt-5.4", MessagesDispatchModelConfig: OpenAIMessagesDispatchModelConfig{ @@ -270,6 +271,7 @@ func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t require.NotNil(t, roundTrip) require.Equal(t, apiKey.Name, roundTrip.Name) require.NotNil(t, roundTrip.Group) + require.True(t, roundTrip.Group.CodexOfficialOnly) require.Equal(t, apiKey.Group.MessagesDispatchModelConfig, roundTrip.Group.MessagesDispatchModelConfig) } diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index 9aa2a52f4aa..ca9c327a5e5 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -37,6 +37,8 @@ type Group struct { // Claude Code 客户端限制 ClaudeCodeOnly bool FallbackGroupID *int64 + // OpenAI Codex 官方客户端限制(仅 openai 平台使用) + CodexOfficialOnly bool // 无效请求兜底分组(仅 anthropic 平台使用) FallbackGroupIDOnInvalidRequest *int64 diff --git a/backend/internal/service/openai_client_restriction_detector.go b/backend/internal/service/openai_client_restriction_detector.go index 8589737ade7..3e6fec958cf 100644 --- a/backend/internal/service/openai_client_restriction_detector.go +++ b/backend/internal/service/openai_client_restriction_detector.go @@ -33,6 +33,7 @@ type CodexClientRestrictionDetectionResult struct { // CodexClientRestrictionDetector 定义 codex_cli_only 统一检测入口。 type CodexClientRestrictionDetector interface { Detect(c *gin.Context, account *Account, globalAllowedClients []string) CodexClientRestrictionDetectionResult + DetectPolicy(c *gin.Context, enabled bool, allowedClients []string, globalAllowedClients []string) CodexClientRestrictionDetectionResult } // OpenAICodexClientRestrictionDetector 为 OpenAI OAuth codex_cli_only 的默认实现。 @@ -52,7 +53,17 @@ func (d *OpenAICodexClientRestrictionDetector) Detect(c *gin.Context, account *A Reason: CodexClientRestrictionReasonDisabled, } } + return d.DetectPolicy(c, true, account.GetCodexCLIOnlyAllowedClients(), globalAllowedClients) +} +func (d *OpenAICodexClientRestrictionDetector) DetectPolicy(c *gin.Context, enabled bool, allowedClients []string, globalAllowedClients []string) CodexClientRestrictionDetectionResult { + if !enabled { + return CodexClientRestrictionDetectionResult{ + Enabled: false, + Matched: false, + Reason: CodexClientRestrictionReasonDisabled, + } + } if d != nil && d.cfg != nil && d.cfg.Gateway.ForceCodexCLI { return CodexClientRestrictionDetectionResult{ Enabled: true, @@ -82,9 +93,9 @@ func (d *OpenAICodexClientRestrictionDetector) Detect(c *gin.Context, account *A } } - // 官方客户端白名单未命中时,先尝试账号级额外放行的命名客户端预设(如 Claude Code codex 插件)。 - if allowed := account.GetCodexCLIOnlyAllowedClients(); len(allowed) > 0 && - openai.MatchAllowedClients(userAgent, originator, allowed) { + // 官方客户端白名单未命中时,先尝试策略级额外放行的命名客户端预设(如账号级 Claude Code codex 插件)。 + if len(allowedClients) > 0 && + openai.MatchAllowedClients(userAgent, originator, allowedClients) { return CodexClientRestrictionDetectionResult{ Enabled: true, Matched: true, diff --git a/backend/internal/service/openai_client_restriction_detector_test.go b/backend/internal/service/openai_client_restriction_detector_test.go index fc115128cb4..52c862f25d9 100644 --- a/backend/internal/service/openai_client_restriction_detector_test.go +++ b/backend/internal/service/openai_client_restriction_detector_test.go @@ -259,3 +259,57 @@ func TestOpenAICodexClientRestrictionDetector_Detect_AllowedClients(t *testing.T require.Equal(t, CodexClientRestrictionReasonMatchedAllowedClient, result.Reason) }) } + +func TestOpenAICodexClientRestrictionDetector_DetectPolicy_GroupCodexOfficialOnly(t *testing.T) { + gin.SetMode(gin.TestMode) + + const ( + claudeCodeUA = "Claude Code/0.5.0 (Macos 15.5; arm64) iTerm2.app (Claude Code; 1.0.4)" + claudeCodeOriginator = "Claude Code" + ) + + t.Run("disabled policy bypasses", func(t *testing.T) { + detector := NewOpenAICodexClientRestrictionDetector(nil) + + result := detector.DetectPolicy(newCodexDetectorTestContext("curl/8.0", "my_client"), false, nil, nil) + + require.False(t, result.Enabled) + require.False(t, result.Matched) + require.Equal(t, CodexClientRestrictionReasonDisabled, result.Reason) + }) + + t.Run("enabled policy rejects non official client", func(t *testing.T) { + detector := NewOpenAICodexClientRestrictionDetector(nil) + + result := detector.DetectPolicy(newCodexDetectorTestContext("curl/8.0", "my_client"), true, nil, nil) + + require.True(t, result.Enabled) + require.False(t, result.Matched) + require.Equal(t, CodexClientRestrictionReasonNotMatchedUA, result.Reason) + }) + + t.Run("enabled policy allows official codex user agent", func(t *testing.T) { + detector := NewOpenAICodexClientRestrictionDetector(nil) + + result := detector.DetectPolicy(newCodexDetectorTestContext("codex_cli_rs/0.99.0", ""), true, nil, nil) + + require.True(t, result.Enabled) + require.True(t, result.Matched) + require.Equal(t, CodexClientRestrictionReasonMatchedUA, result.Reason) + }) + + t.Run("enabled policy allows global claude code preset", func(t *testing.T) { + detector := NewOpenAICodexClientRestrictionDetector(nil) + + result := detector.DetectPolicy( + newCodexDetectorTestContext(claudeCodeUA, claudeCodeOriginator), + true, + nil, + []string{"claude_code"}, + ) + + require.True(t, result.Enabled) + require.True(t, result.Matched) + require.Equal(t, CodexClientRestrictionReasonMatchedGlobalAllowedClient, result.Reason) + }) +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 6d618ee4820..70aa84fb8fd 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -902,14 +902,8 @@ func SnapshotOpenAICompatibilityFallbackMetrics() OpenAICompatibilityFallbackMet func (s *OpenAIGatewayService) detectCodexClientRestriction(c *gin.Context, account *Account) CodexClientRestrictionDetectionResult { var globalAllowedClients []string - if account != nil && account.IsCodexCLIOnlyEnabled() && s != nil && s.settingService != nil { - ctx := context.Background() - if c != nil && c.Request != nil { - ctx = c.Request.Context() - } - if s.settingService.IsOpenAIAllowClaudeCodeCodexPluginEnabled(ctx) { - globalAllowedClients = []string{openai.AllowedClientClaudeCode} - } + if account != nil && account.IsCodexCLIOnlyEnabled() { + globalAllowedClients = s.codexRestrictionGlobalAllowedClients(c) } return s.getCodexClientRestrictionDetector().Detect(c, account, globalAllowedClients) } diff --git a/backend/internal/service/openai_gateway_service_codex_cli_only_test.go b/backend/internal/service/openai_gateway_service_codex_cli_only_test.go index 10d58654093..c9db4d821db 100644 --- a/backend/internal/service/openai_gateway_service_codex_cli_only_test.go +++ b/backend/internal/service/openai_gateway_service_codex_cli_only_test.go @@ -22,6 +22,10 @@ func (s *stubCodexRestrictionDetector) Detect(_ *gin.Context, _ *Account, _ []st return s.result } +func (s *stubCodexRestrictionDetector) DetectPolicy(_ *gin.Context, _ bool, _ []string, _ []string) CodexClientRestrictionDetectionResult { + return s.result +} + func TestOpenAIGatewayService_GetCodexClientRestrictionDetector(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/service/openai_group_restriction.go b/backend/internal/service/openai_group_restriction.go new file mode 100644 index 00000000000..08d2d61137c --- /dev/null +++ b/backend/internal/service/openai_group_restriction.go @@ -0,0 +1,40 @@ +package service + +import ( + "context" + + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/gin-gonic/gin" +) + +func sanitizeGroupCodexOfficialRestrictionFields(g *Group) { + if g == nil || g.Platform == PlatformOpenAI { + return + } + g.CodexOfficialOnly = false +} + +func (s *OpenAIGatewayService) DetectGroupCodexOfficialRestriction(c *gin.Context, group *Group) CodexClientRestrictionDetectionResult { + if group == nil || group.Platform != PlatformOpenAI || !group.CodexOfficialOnly { + return CodexClientRestrictionDetectionResult{ + Enabled: false, + Matched: false, + Reason: CodexClientRestrictionReasonDisabled, + } + } + return s.getCodexClientRestrictionDetector().DetectPolicy(c, true, nil, s.codexRestrictionGlobalAllowedClients(c)) +} + +func (s *OpenAIGatewayService) codexRestrictionGlobalAllowedClients(c *gin.Context) []string { + if s == nil || s.settingService == nil { + return nil + } + ctx := context.Background() + if c != nil && c.Request != nil { + ctx = c.Request.Context() + } + if s.settingService.IsOpenAIAllowClaudeCodeCodexPluginEnabled(ctx) { + return []string{openai.AllowedClientClaudeCode} + } + return nil +} diff --git a/backend/migrations/145_add_group_codex_official_restriction.sql b/backend/migrations/145_add_group_codex_official_restriction.sql new file mode 100644 index 00000000000..751e961969f --- /dev/null +++ b/backend/migrations/145_add_group_codex_official_restriction.sql @@ -0,0 +1,9 @@ +-- Add OpenAI group-level Codex official client restriction. + +ALTER TABLE groups +ADD COLUMN IF NOT EXISTS codex_official_only BOOLEAN NOT NULL DEFAULT FALSE; + +CREATE INDEX IF NOT EXISTS idx_groups_codex_official_only +ON groups(codex_official_only) WHERE deleted_at IS NULL; + +COMMENT ON COLUMN groups.codex_official_only IS 'OpenAI 分组是否仅允许 Codex 官方客户端'; diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index b2aeb2f83f5..07f14536c17 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -2194,6 +2194,10 @@ export default { fallbackHint: 'Non-Claude Code requests will use this group. Leave empty to reject directly.', noFallback: 'No Fallback (Reject)' }, + openaiCodex: { + title: 'Codex Official Client Restriction', + hint: 'When enabled, API keys in this OpenAI group only accept official Codex clients; other clients are rejected.' + }, openaiMessages: { title: 'OpenAI Messages Dispatch', allowDispatch: 'Allow /v1/messages dispatch', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 85d1feee456..f9c3c9fe94a 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -2278,6 +2278,10 @@ export default { fallbackHint: '非 Claude Code 请求将使用此分组,留空则直接拒绝', noFallback: '不降级(直接拒绝)' }, + openaiCodex: { + title: 'Codex 官方客户端限制', + hint: '启用后,此 OpenAI 分组的 API Key 仅允许 Codex 官方客户端访问,其他客户端将被拒绝。' + }, openaiMessages: { title: 'OpenAI Messages 调度配置', allowDispatch: '允许 /v1/messages 调度', diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index c21361693af..d765e667e0d 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -518,6 +518,8 @@ export interface Group { // Claude Code 客户端限制 claude_code_only: boolean fallback_group_id: number | null + // OpenAI Codex 官方客户端限制 + codex_official_only: boolean fallback_group_id_on_invalid_request: number | null // OpenAI Messages 调度开关(用户侧需要此字段判断是否展示 Claude Code 教程) allow_messages_dispatch?: boolean @@ -635,6 +637,7 @@ export interface CreateGroupRequest { image_price_4k?: number | null claude_code_only?: boolean fallback_group_id?: number | null + codex_official_only?: boolean fallback_group_id_on_invalid_request?: number | null mcp_xml_inject?: boolean supported_model_scopes?: string[] @@ -670,6 +673,7 @@ export interface UpdateGroupRequest { image_price_4k?: number | null claude_code_only?: boolean fallback_group_id?: number | null + codex_official_only?: boolean fallback_group_id_on_invalid_request?: number | null mcp_xml_inject?: boolean supported_model_scopes?: string[] diff --git a/frontend/src/views/admin/GroupsView.vue b/frontend/src/views/admin/GroupsView.vue index 0b583a095e3..a7d755c352a 100644 --- a/frontend/src/views/admin/GroupsView.vue +++ b/frontend/src/views/admin/GroupsView.vue @@ -1069,6 +1069,40 @@ {{ t("admin.groups.openaiMessages.title") }} + +
+
+ + +
+

+ {{ t("admin.groups.openaiCodex.hint") }} +

+
+
+ +
+
+ +
+ +
+
+

+ {{ t("admin.groups.openaiCodex.tooltip") }} +

+
+
+
+
+
+
+ + + {{ + createForm.codex_official_only + ? t("admin.groups.openaiCodex.enabled") + : t("admin.groups.openaiCodex.disabled") + }} + +
+
+ + +

+ {{ t("admin.groups.openaiCodex.fallbackHint") }} +

+
+
+
- -
-
- - -
-

- {{ t("admin.groups.openaiCodex.hint") }} -

-
-