diff --git a/backend/ent/account.go b/backend/ent/account.go index 2dbfc3a278b..e822c047cb6 100644 --- a/backend/ent/account.go +++ b/backend/ent/account.go @@ -57,6 +57,8 @@ type Account struct { ExpiresAt *time.Time `json:"expires_at,omitempty"` // Auto pause scheduling when account expires. AutoPauseOnExpired bool `json:"auto_pause_on_expired,omitempty"` + // Strip reasoning_effort when converting Responses→ChatCompletions for upstreams that reject it (e.g. b.ai). + StripReasoningEffortOnCc bool `json:"strip_reasoning_effort_on_cc,omitempty"` // Schedulable holds the value of the "schedulable" field. Schedulable bool `json:"schedulable,omitempty"` // RateLimitedAt holds the value of the "rate_limited_at" field. @@ -141,7 +143,7 @@ func (*Account) scanValues(columns []string) ([]any, error) { switch columns[i] { case account.FieldCredentials, account.FieldExtra: values[i] = new([]byte) - case account.FieldAutoPauseOnExpired, account.FieldSchedulable: + case account.FieldAutoPauseOnExpired, account.FieldStripReasoningEffortOnCc, account.FieldSchedulable: values[i] = new(sql.NullBool) case account.FieldRateMultiplier: values[i] = new(sql.NullFloat64) @@ -297,6 +299,12 @@ func (_m *Account) assignValues(columns []string, values []any) error { } else if value.Valid { _m.AutoPauseOnExpired = value.Bool } + case account.FieldStripReasoningEffortOnCc: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field strip_reasoning_effort_on_cc", values[i]) + } else if value.Valid { + _m.StripReasoningEffortOnCc = value.Bool + } case account.FieldSchedulable: if value, ok := values[i].(*sql.NullBool); !ok { return fmt.Errorf("unexpected type %T for field schedulable", values[i]) @@ -486,6 +494,9 @@ func (_m *Account) String() string { builder.WriteString("auto_pause_on_expired=") builder.WriteString(fmt.Sprintf("%v", _m.AutoPauseOnExpired)) builder.WriteString(", ") + builder.WriteString("strip_reasoning_effort_on_cc=") + 
builder.WriteString(fmt.Sprintf("%v", _m.StripReasoningEffortOnCc)) + builder.WriteString(", ") builder.WriteString("schedulable=") builder.WriteString(fmt.Sprintf("%v", _m.Schedulable)) builder.WriteString(", ") diff --git a/backend/ent/account/account.go b/backend/ent/account/account.go index 4c1346490a1..e89dd5a8edb 100644 --- a/backend/ent/account/account.go +++ b/backend/ent/account/account.go @@ -53,6 +53,8 @@ const ( FieldExpiresAt = "expires_at" // FieldAutoPauseOnExpired holds the string denoting the auto_pause_on_expired field in the database. FieldAutoPauseOnExpired = "auto_pause_on_expired" + // FieldStripReasoningEffortOnCc holds the string denoting the strip_reasoning_effort_on_cc field in the database. + FieldStripReasoningEffortOnCc = "strip_reasoning_effort_on_cc" // FieldSchedulable holds the string denoting the schedulable field in the database. FieldSchedulable = "schedulable" // FieldRateLimitedAt holds the string denoting the rate_limited_at field in the database. @@ -131,6 +133,7 @@ var Columns = []string{ FieldLastUsedAt, FieldExpiresAt, FieldAutoPauseOnExpired, + FieldStripReasoningEffortOnCc, FieldSchedulable, FieldRateLimitedAt, FieldRateLimitResetAt, @@ -194,6 +197,8 @@ var ( StatusValidator func(string) error // DefaultAutoPauseOnExpired holds the default value on creation for the "auto_pause_on_expired" field. DefaultAutoPauseOnExpired bool + // DefaultStripReasoningEffortOnCc holds the default value on creation for the "strip_reasoning_effort_on_cc" field. + DefaultStripReasoningEffortOnCc bool // DefaultSchedulable holds the default value on creation for the "schedulable" field. DefaultSchedulable bool // SessionWindowStatusValidator is a validator for the "session_window_status" field. It is called by the builders before save. 
@@ -293,6 +298,11 @@ func ByAutoPauseOnExpired(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldAutoPauseOnExpired, opts...).ToFunc() } +// ByStripReasoningEffortOnCc orders the results by the strip_reasoning_effort_on_cc field. +func ByStripReasoningEffortOnCc(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStripReasoningEffortOnCc, opts...).ToFunc() +} + // BySchedulable orders the results by the schedulable field. func BySchedulable(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldSchedulable, opts...).ToFunc() diff --git a/backend/ent/account/where.go b/backend/ent/account/where.go index 3749b45c556..d757a0fd353 100644 --- a/backend/ent/account/where.go +++ b/backend/ent/account/where.go @@ -140,6 +140,11 @@ func AutoPauseOnExpired(v bool) predicate.Account { return predicate.Account(sql.FieldEQ(FieldAutoPauseOnExpired, v)) } +// StripReasoningEffortOnCc applies equality check predicate on the "strip_reasoning_effort_on_cc" field. It's identical to StripReasoningEffortOnCcEQ. +func StripReasoningEffortOnCc(v bool) predicate.Account { + return predicate.Account(sql.FieldEQ(FieldStripReasoningEffortOnCc, v)) +} + // Schedulable applies equality check predicate on the "schedulable" field. It's identical to SchedulableEQ. func Schedulable(v bool) predicate.Account { return predicate.Account(sql.FieldEQ(FieldSchedulable, v)) @@ -1035,6 +1040,16 @@ func AutoPauseOnExpiredNEQ(v bool) predicate.Account { return predicate.Account(sql.FieldNEQ(FieldAutoPauseOnExpired, v)) } +// StripReasoningEffortOnCcEQ applies the EQ predicate on the "strip_reasoning_effort_on_cc" field. +func StripReasoningEffortOnCcEQ(v bool) predicate.Account { + return predicate.Account(sql.FieldEQ(FieldStripReasoningEffortOnCc, v)) +} + +// StripReasoningEffortOnCcNEQ applies the NEQ predicate on the "strip_reasoning_effort_on_cc" field. 
+func StripReasoningEffortOnCcNEQ(v bool) predicate.Account { + return predicate.Account(sql.FieldNEQ(FieldStripReasoningEffortOnCc, v)) +} + // SchedulableEQ applies the EQ predicate on the "schedulable" field. func SchedulableEQ(v bool) predicate.Account { return predicate.Account(sql.FieldEQ(FieldSchedulable, v)) diff --git a/backend/ent/account_create.go b/backend/ent/account_create.go index d6046c79775..ae7222cec71 100644 --- a/backend/ent/account_create.go +++ b/backend/ent/account_create.go @@ -251,6 +251,20 @@ func (_c *AccountCreate) SetNillableAutoPauseOnExpired(v *bool) *AccountCreate { return _c } +// SetStripReasoningEffortOnCc sets the "strip_reasoning_effort_on_cc" field. +func (_c *AccountCreate) SetStripReasoningEffortOnCc(v bool) *AccountCreate { + _c.mutation.SetStripReasoningEffortOnCc(v) + return _c +} + +// SetNillableStripReasoningEffortOnCc sets the "strip_reasoning_effort_on_cc" field if the given value is not nil. +func (_c *AccountCreate) SetNillableStripReasoningEffortOnCc(v *bool) *AccountCreate { + if v != nil { + _c.SetStripReasoningEffortOnCc(*v) + } + return _c +} + // SetSchedulable sets the "schedulable" field. 
func (_c *AccountCreate) SetSchedulable(v bool) *AccountCreate { _c.mutation.SetSchedulable(v) @@ -497,6 +511,10 @@ func (_c *AccountCreate) defaults() error { v := account.DefaultAutoPauseOnExpired _c.mutation.SetAutoPauseOnExpired(v) } + if _, ok := _c.mutation.StripReasoningEffortOnCc(); !ok { + v := account.DefaultStripReasoningEffortOnCc + _c.mutation.SetStripReasoningEffortOnCc(v) + } if _, ok := _c.mutation.Schedulable(); !ok { v := account.DefaultSchedulable _c.mutation.SetSchedulable(v) @@ -562,6 +580,9 @@ func (_c *AccountCreate) check() error { if _, ok := _c.mutation.AutoPauseOnExpired(); !ok { return &ValidationError{Name: "auto_pause_on_expired", err: errors.New(`ent: missing required field "Account.auto_pause_on_expired"`)} } + if _, ok := _c.mutation.StripReasoningEffortOnCc(); !ok { + return &ValidationError{Name: "strip_reasoning_effort_on_cc", err: errors.New(`ent: missing required field "Account.strip_reasoning_effort_on_cc"`)} + } if _, ok := _c.mutation.Schedulable(); !ok { return &ValidationError{Name: "schedulable", err: errors.New(`ent: missing required field "Account.schedulable"`)} } @@ -669,6 +690,10 @@ func (_c *AccountCreate) createSpec() (*Account, *sqlgraph.CreateSpec) { _spec.SetField(account.FieldAutoPauseOnExpired, field.TypeBool, value) _node.AutoPauseOnExpired = value } + if value, ok := _c.mutation.StripReasoningEffortOnCc(); ok { + _spec.SetField(account.FieldStripReasoningEffortOnCc, field.TypeBool, value) + _node.StripReasoningEffortOnCc = value + } if value, ok := _c.mutation.Schedulable(); ok { _spec.SetField(account.FieldSchedulable, field.TypeBool, value) _node.Schedulable = value @@ -1092,6 +1117,18 @@ func (u *AccountUpsert) UpdateAutoPauseOnExpired() *AccountUpsert { return u } +// SetStripReasoningEffortOnCc sets the "strip_reasoning_effort_on_cc" field. 
+func (u *AccountUpsert) SetStripReasoningEffortOnCc(v bool) *AccountUpsert { + u.Set(account.FieldStripReasoningEffortOnCc, v) + return u +} + +// UpdateStripReasoningEffortOnCc sets the "strip_reasoning_effort_on_cc" field to the value that was provided on create. +func (u *AccountUpsert) UpdateStripReasoningEffortOnCc() *AccountUpsert { + u.SetExcluded(account.FieldStripReasoningEffortOnCc) + return u +} + // SetSchedulable sets the "schedulable" field. func (u *AccountUpsert) SetSchedulable(v bool) *AccountUpsert { u.Set(account.FieldSchedulable, v) @@ -1622,6 +1659,20 @@ func (u *AccountUpsertOne) UpdateAutoPauseOnExpired() *AccountUpsertOne { }) } +// SetStripReasoningEffortOnCc sets the "strip_reasoning_effort_on_cc" field. +func (u *AccountUpsertOne) SetStripReasoningEffortOnCc(v bool) *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.SetStripReasoningEffortOnCc(v) + }) +} + +// UpdateStripReasoningEffortOnCc sets the "strip_reasoning_effort_on_cc" field to the value that was provided on create. +func (u *AccountUpsertOne) UpdateStripReasoningEffortOnCc() *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.UpdateStripReasoningEffortOnCc() + }) +} + // SetSchedulable sets the "schedulable" field. func (u *AccountUpsertOne) SetSchedulable(v bool) *AccountUpsertOne { return u.Update(func(s *AccountUpsert) { @@ -2344,6 +2395,20 @@ func (u *AccountUpsertBulk) UpdateAutoPauseOnExpired() *AccountUpsertBulk { }) } +// SetStripReasoningEffortOnCc sets the "strip_reasoning_effort_on_cc" field. +func (u *AccountUpsertBulk) SetStripReasoningEffortOnCc(v bool) *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.SetStripReasoningEffortOnCc(v) + }) +} + +// UpdateStripReasoningEffortOnCc sets the "strip_reasoning_effort_on_cc" field to the value that was provided on create. 
+func (u *AccountUpsertBulk) UpdateStripReasoningEffortOnCc() *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.UpdateStripReasoningEffortOnCc() + }) +} + // SetSchedulable sets the "schedulable" field. func (u *AccountUpsertBulk) SetSchedulable(v bool) *AccountUpsertBulk { return u.Update(func(s *AccountUpsert) { diff --git a/backend/ent/account_update.go b/backend/ent/account_update.go index 6f443c65e06..d3882c6bfec 100644 --- a/backend/ent/account_update.go +++ b/backend/ent/account_update.go @@ -329,6 +329,20 @@ func (_u *AccountUpdate) SetNillableAutoPauseOnExpired(v *bool) *AccountUpdate { return _u } +// SetStripReasoningEffortOnCc sets the "strip_reasoning_effort_on_cc" field. +func (_u *AccountUpdate) SetStripReasoningEffortOnCc(v bool) *AccountUpdate { + _u.mutation.SetStripReasoningEffortOnCc(v) + return _u +} + +// SetNillableStripReasoningEffortOnCc sets the "strip_reasoning_effort_on_cc" field if the given value is not nil. +func (_u *AccountUpdate) SetNillableStripReasoningEffortOnCc(v *bool) *AccountUpdate { + if v != nil { + _u.SetStripReasoningEffortOnCc(*v) + } + return _u +} + // SetSchedulable sets the "schedulable" field. func (_u *AccountUpdate) SetSchedulable(v bool) *AccountUpdate { _u.mutation.SetSchedulable(v) @@ -756,6 +770,9 @@ func (_u *AccountUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.AutoPauseOnExpired(); ok { _spec.SetField(account.FieldAutoPauseOnExpired, field.TypeBool, value) } + if value, ok := _u.mutation.StripReasoningEffortOnCc(); ok { + _spec.SetField(account.FieldStripReasoningEffortOnCc, field.TypeBool, value) + } if value, ok := _u.mutation.Schedulable(); ok { _spec.SetField(account.FieldSchedulable, field.TypeBool, value) } @@ -1256,6 +1273,20 @@ func (_u *AccountUpdateOne) SetNillableAutoPauseOnExpired(v *bool) *AccountUpdat return _u } +// SetStripReasoningEffortOnCc sets the "strip_reasoning_effort_on_cc" field. 
+func (_u *AccountUpdateOne) SetStripReasoningEffortOnCc(v bool) *AccountUpdateOne { + _u.mutation.SetStripReasoningEffortOnCc(v) + return _u +} + +// SetNillableStripReasoningEffortOnCc sets the "strip_reasoning_effort_on_cc" field if the given value is not nil. +func (_u *AccountUpdateOne) SetNillableStripReasoningEffortOnCc(v *bool) *AccountUpdateOne { + if v != nil { + _u.SetStripReasoningEffortOnCc(*v) + } + return _u +} + // SetSchedulable sets the "schedulable" field. func (_u *AccountUpdateOne) SetSchedulable(v bool) *AccountUpdateOne { _u.mutation.SetSchedulable(v) @@ -1713,6 +1744,9 @@ func (_u *AccountUpdateOne) sqlSave(ctx context.Context) (_node *Account, err er if value, ok := _u.mutation.AutoPauseOnExpired(); ok { _spec.SetField(account.FieldAutoPauseOnExpired, field.TypeBool, value) } + if value, ok := _u.mutation.StripReasoningEffortOnCc(); ok { + _spec.SetField(account.FieldStripReasoningEffortOnCc, field.TypeBool, value) + } if value, ok := _u.mutation.Schedulable(); ok { _spec.SetField(account.FieldSchedulable, field.TypeBool, value) } diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index 7abe4c601e3..4ff60d19cd0 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -102,7 +102,7 @@ var ( {Name: "name", Type: field.TypeString, Size: 100}, {Name: "notes", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, {Name: "platform", Type: field.TypeString, Size: 50}, - {Name: "type", Type: field.TypeString, Size: 20}, + {Name: "type", Type: field.TypeString, Size: 40}, {Name: "credentials", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, {Name: "extra", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, {Name: "concurrency", Type: field.TypeInt, Default: 3}, @@ -114,6 +114,7 @@ var ( {Name: "last_used_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: 
"expires_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "auto_pause_on_expired", Type: field.TypeBool, Default: true}, + {Name: "strip_reasoning_effort_on_cc", Type: field.TypeBool, Default: false}, {Name: "schedulable", Type: field.TypeBool, Default: true}, {Name: "rate_limited_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "rate_limit_reset_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, @@ -133,7 +134,7 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "accounts_proxies_proxy", - Columns: []*schema.Column{AccountsColumns[28]}, + Columns: []*schema.Column{AccountsColumns[29]}, RefColumns: []*schema.Column{ProxiesColumns[0]}, OnDelete: schema.SetNull, }, @@ -157,7 +158,7 @@ var ( { Name: "account_proxy_id", Unique: false, - Columns: []*schema.Column{AccountsColumns[28]}, + Columns: []*schema.Column{AccountsColumns[29]}, }, { Name: "account_priority", @@ -172,22 +173,22 @@ var ( { Name: "account_schedulable", Unique: false, - Columns: []*schema.Column{AccountsColumns[19]}, + Columns: []*schema.Column{AccountsColumns[20]}, }, { Name: "account_rate_limited_at", Unique: false, - Columns: []*schema.Column{AccountsColumns[20]}, + Columns: []*schema.Column{AccountsColumns[21]}, }, { Name: "account_rate_limit_reset_at", Unique: false, - Columns: []*schema.Column{AccountsColumns[21]}, + Columns: []*schema.Column{AccountsColumns[22]}, }, { Name: "account_overload_until", Unique: false, - Columns: []*schema.Column{AccountsColumns[22]}, + Columns: []*schema.Column{AccountsColumns[23]}, }, { Name: "account_platform_priority", diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 003e25d5015..5987777b4dd 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -2274,52 +2274,53 @@ func (m *APIKeyMutation) ResetEdge(name string) error { // AccountMutation represents an 
operation that mutates the Account nodes in the graph. type AccountMutation struct { config - op Op - typ string - id *int64 - created_at *time.Time - updated_at *time.Time - deleted_at *time.Time - name *string - notes *string - platform *string - _type *string - credentials *map[string]interface{} - extra *map[string]interface{} - concurrency *int - addconcurrency *int - load_factor *int - addload_factor *int - priority *int - addpriority *int - rate_multiplier *float64 - addrate_multiplier *float64 - status *string - error_message *string - last_used_at *time.Time - expires_at *time.Time - auto_pause_on_expired *bool - schedulable *bool - rate_limited_at *time.Time - rate_limit_reset_at *time.Time - overload_until *time.Time - temp_unschedulable_until *time.Time - temp_unschedulable_reason *string - session_window_start *time.Time - session_window_end *time.Time - session_window_status *string - clearedFields map[string]struct{} - groups map[int64]struct{} - removedgroups map[int64]struct{} - clearedgroups bool - proxy *int64 - clearedproxy bool - usage_logs map[int64]struct{} - removedusage_logs map[int64]struct{} - clearedusage_logs bool - done bool - oldValue func(context.Context) (*Account, error) - predicates []predicate.Account + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + deleted_at *time.Time + name *string + notes *string + platform *string + _type *string + credentials *map[string]interface{} + extra *map[string]interface{} + concurrency *int + addconcurrency *int + load_factor *int + addload_factor *int + priority *int + addpriority *int + rate_multiplier *float64 + addrate_multiplier *float64 + status *string + error_message *string + last_used_at *time.Time + expires_at *time.Time + auto_pause_on_expired *bool + strip_reasoning_effort_on_cc *bool + schedulable *bool + rate_limited_at *time.Time + rate_limit_reset_at *time.Time + overload_until *time.Time + temp_unschedulable_until *time.Time + 
temp_unschedulable_reason *string + session_window_start *time.Time + session_window_end *time.Time + session_window_status *string + clearedFields map[string]struct{} + groups map[int64]struct{} + removedgroups map[int64]struct{} + clearedgroups bool + proxy *int64 + clearedproxy bool + usage_logs map[int64]struct{} + removedusage_logs map[int64]struct{} + clearedusage_logs bool + done bool + oldValue func(context.Context) (*Account, error) + predicates []predicate.Account } var _ ent.Mutation = (*AccountMutation)(nil) @@ -3276,6 +3277,42 @@ func (m *AccountMutation) ResetAutoPauseOnExpired() { m.auto_pause_on_expired = nil } +// SetStripReasoningEffortOnCc sets the "strip_reasoning_effort_on_cc" field. +func (m *AccountMutation) SetStripReasoningEffortOnCc(b bool) { + m.strip_reasoning_effort_on_cc = &b +} + +// StripReasoningEffortOnCc returns the value of the "strip_reasoning_effort_on_cc" field in the mutation. +func (m *AccountMutation) StripReasoningEffortOnCc() (r bool, exists bool) { + v := m.strip_reasoning_effort_on_cc + if v == nil { + return + } + return *v, true +} + +// OldStripReasoningEffortOnCc returns the old "strip_reasoning_effort_on_cc" field's value of the Account entity. +// If the Account object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *AccountMutation) OldStripReasoningEffortOnCc(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldStripReasoningEffortOnCc is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldStripReasoningEffortOnCc requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldStripReasoningEffortOnCc: %w", err) + } + return oldValue.StripReasoningEffortOnCc, nil +} + +// ResetStripReasoningEffortOnCc resets all changes to the "strip_reasoning_effort_on_cc" field. +func (m *AccountMutation) ResetStripReasoningEffortOnCc() { + m.strip_reasoning_effort_on_cc = nil +} + // SetSchedulable sets the "schedulable" field. func (m *AccountMutation) SetSchedulable(b bool) { m.schedulable = &b @@ -3873,7 +3910,7 @@ func (m *AccountMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). 
func (m *AccountMutation) Fields() []string { - fields := make([]string, 0, 28) + fields := make([]string, 0, 29) if m.created_at != nil { fields = append(fields, account.FieldCreatedAt) } @@ -3931,6 +3968,9 @@ func (m *AccountMutation) Fields() []string { if m.auto_pause_on_expired != nil { fields = append(fields, account.FieldAutoPauseOnExpired) } + if m.strip_reasoning_effort_on_cc != nil { + fields = append(fields, account.FieldStripReasoningEffortOnCc) + } if m.schedulable != nil { fields = append(fields, account.FieldSchedulable) } @@ -4004,6 +4044,8 @@ func (m *AccountMutation) Field(name string) (ent.Value, bool) { return m.ExpiresAt() case account.FieldAutoPauseOnExpired: return m.AutoPauseOnExpired() + case account.FieldStripReasoningEffortOnCc: + return m.StripReasoningEffortOnCc() case account.FieldSchedulable: return m.Schedulable() case account.FieldRateLimitedAt: @@ -4069,6 +4111,8 @@ func (m *AccountMutation) OldField(ctx context.Context, name string) (ent.Value, return m.OldExpiresAt(ctx) case account.FieldAutoPauseOnExpired: return m.OldAutoPauseOnExpired(ctx) + case account.FieldStripReasoningEffortOnCc: + return m.OldStripReasoningEffortOnCc(ctx) case account.FieldSchedulable: return m.OldSchedulable(ctx) case account.FieldRateLimitedAt: @@ -4229,6 +4273,13 @@ func (m *AccountMutation) SetField(name string, value ent.Value) error { } m.SetAutoPauseOnExpired(v) return nil + case account.FieldStripReasoningEffortOnCc: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetStripReasoningEffortOnCc(v) + return nil case account.FieldSchedulable: v, ok := value.(bool) if !ok { @@ -4542,6 +4593,9 @@ func (m *AccountMutation) ResetField(name string) error { case account.FieldAutoPauseOnExpired: m.ResetAutoPauseOnExpired() return nil + case account.FieldStripReasoningEffortOnCc: + m.ResetStripReasoningEffortOnCc() + return nil case account.FieldSchedulable: m.ResetSchedulable() return nil diff 
--git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index fdb837e805f..bfb92011da8 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -244,12 +244,16 @@ func init() { accountDescAutoPauseOnExpired := accountFields[15].Descriptor() // account.DefaultAutoPauseOnExpired holds the default value on creation for the auto_pause_on_expired field. account.DefaultAutoPauseOnExpired = accountDescAutoPauseOnExpired.Default.(bool) + // accountDescStripReasoningEffortOnCc is the schema descriptor for strip_reasoning_effort_on_cc field. + accountDescStripReasoningEffortOnCc := accountFields[16].Descriptor() + // account.DefaultStripReasoningEffortOnCc holds the default value on creation for the strip_reasoning_effort_on_cc field. + account.DefaultStripReasoningEffortOnCc = accountDescStripReasoningEffortOnCc.Default.(bool) // accountDescSchedulable is the schema descriptor for schedulable field. - accountDescSchedulable := accountFields[16].Descriptor() + accountDescSchedulable := accountFields[17].Descriptor() // account.DefaultSchedulable holds the default value on creation for the schedulable field. account.DefaultSchedulable = accountDescSchedulable.Default.(bool) // accountDescSessionWindowStatus is the schema descriptor for session_window_status field. - accountDescSessionWindowStatus := accountFields[24].Descriptor() + accountDescSessionWindowStatus := accountFields[25].Descriptor() // account.SessionWindowStatusValidator is a validator for the "session_window_status" field. It is called by the builders before save. 
account.SessionWindowStatusValidator = accountDescSessionWindowStatus.Validators[0].(func(string) error) accountgroupFields := schema.AccountGroup{}.Fields() diff --git a/backend/ent/schema/account.go b/backend/ent/schema/account.go index 5616d39915b..96dc3998531 100644 --- a/backend/ent/schema/account.go +++ b/backend/ent/schema/account.go @@ -68,7 +68,7 @@ func (Account) Fields() []ent.Field { // type: 认证类型,如 "api_key", "oauth", "cookie" 等 // 不同类型决定了 credentials 中存储的数据结构 field.String("type"). - MaxLen(20). + MaxLen(40). NotEmpty(), // credentials: 认证凭证,以 JSONB 格式存储 @@ -136,6 +136,10 @@ func (Account) Fields() []ent.Field { field.Bool("auto_pause_on_expired"). Default(true). Comment("Auto pause scheduling when account expires."), + // strip_reasoning_effort_on_cc: 转换 Responses→ChatCompletions 时剥离 reasoning_effort + field.Bool("strip_reasoning_effort_on_cc"). + Default(false). + Comment("Strip reasoning_effort when converting Responses→ChatCompletions for upstreams that reject it (e.g. b.ai)."), // ========== 调度和速率限制相关字段 ========== // 这些字段在 migrations/005_schema_parity.sql 中添加 diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go index 7601f35bc95..47cbef9063f 100644 --- a/backend/internal/domain/constants.go +++ b/backend/internal/domain/constants.go @@ -32,6 +32,7 @@ const ( AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游) AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分) AccountTypeServiceAccount = "service_account" // Google Service Account 类型账号(用于 Vertex AI) + AccountTypeAPIKeyChatCompletions = "apikey-chat-completions" ) // Redeem type constants diff --git a/backend/internal/handler/admin/account_data.go b/backend/internal/handler/admin/account_data.go index 50beadf68e6..35c69d76faf 100644 --- a/backend/internal/handler/admin/account_data.go +++ b/backend/internal/handler/admin/account_data.go @@ -4,6 +4,7 @@ import ( "context" "errors" 
"fmt" + "net/url" "strconv" "strings" "time" @@ -47,18 +48,19 @@ type DataProxy struct { // Credentials 原文返回。这是"管理员备份"这一显式行为的一部分;如未来需要导出脱敏版本, // 应新增独立结构而非修改这里。 type DataAccount struct { - Name string `json:"name"` - Notes *string `json:"notes,omitempty"` - Platform string `json:"platform"` - Type string `json:"type"` - Credentials map[string]any `json:"credentials"` - Extra map[string]any `json:"extra,omitempty"` - ProxyKey *string `json:"proxy_key,omitempty"` - Concurrency int `json:"concurrency"` - Priority int `json:"priority"` - RateMultiplier *float64 `json:"rate_multiplier,omitempty"` - ExpiresAt *int64 `json:"expires_at,omitempty"` - AutoPauseOnExpired *bool `json:"auto_pause_on_expired,omitempty"` + Name string `json:"name"` + Notes *string `json:"notes,omitempty"` + Platform string `json:"platform"` + Type string `json:"type"` + Credentials map[string]any `json:"credentials"` + Extra map[string]any `json:"extra,omitempty"` + ProxyKey *string `json:"proxy_key,omitempty"` + Concurrency int `json:"concurrency"` + Priority int `json:"priority"` + RateMultiplier *float64 `json:"rate_multiplier,omitempty"` + ExpiresAt *int64 `json:"expires_at,omitempty"` + AutoPauseOnExpired *bool `json:"auto_pause_on_expired,omitempty"` + StripReasoningEffortOnCC *bool `json:"strip_reasoning_effort_on_cc,omitempty"` } type DataImportRequest struct { @@ -151,18 +153,19 @@ func (h *AccountHandler) ExportData(c *gin.Context) { expiresAt = &v } dataAccounts = append(dataAccounts, DataAccount{ - Name: acc.Name, - Notes: acc.Notes, - Platform: acc.Platform, - Type: acc.Type, - Credentials: acc.Credentials, - Extra: acc.Extra, - ProxyKey: proxyKey, - Concurrency: acc.Concurrency, - Priority: acc.Priority, - RateMultiplier: acc.RateMultiplier, - ExpiresAt: expiresAt, - AutoPauseOnExpired: &acc.AutoPauseOnExpired, + Name: acc.Name, + Notes: acc.Notes, + Platform: acc.Platform, + Type: acc.Type, + Credentials: acc.Credentials, + Extra: acc.Extra, + ProxyKey: proxyKey, + Concurrency: 
acc.Concurrency, + Priority: acc.Priority, + RateMultiplier: acc.RateMultiplier, + ExpiresAt: expiresAt, + AutoPauseOnExpired: &acc.AutoPauseOnExpired, + StripReasoningEffortOnCC: &acc.StripReasoningEffortOnCC, }) } @@ -305,20 +308,21 @@ func (h *AccountHandler) importData(ctx context.Context, req DataImportRequest) enrichCredentialsFromIDToken(&item) accountInput := &service.CreateAccountInput{ - Name: item.Name, - Notes: item.Notes, - Platform: item.Platform, - Type: item.Type, - Credentials: item.Credentials, - Extra: item.Extra, - ProxyID: proxyID, - Concurrency: item.Concurrency, - Priority: item.Priority, - RateMultiplier: item.RateMultiplier, - GroupIDs: nil, - ExpiresAt: item.ExpiresAt, - AutoPauseOnExpired: item.AutoPauseOnExpired, - SkipDefaultGroupBind: skipDefaultGroupBind, + Name: item.Name, + Notes: item.Notes, + Platform: item.Platform, + Type: item.Type, + Credentials: item.Credentials, + Extra: item.Extra, + ProxyID: proxyID, + Concurrency: item.Concurrency, + Priority: item.Priority, + RateMultiplier: item.RateMultiplier, + GroupIDs: nil, + ExpiresAt: item.ExpiresAt, + AutoPauseOnExpired: item.AutoPauseOnExpired, + StripReasoningEffortOnCC: item.StripReasoningEffortOnCC, + SkipDefaultGroupBind: skipDefaultGroupBind, } created, err := h.adminService.CreateAccount(ctx, accountInput) @@ -564,6 +568,10 @@ func validateDataAccount(item DataAccount) error { } switch item.Type { case service.AccountTypeOAuth, service.AccountTypeSetupToken, service.AccountTypeAPIKey, service.AccountTypeUpstream: + case service.AccountTypeAPIKeyChatCompletions: + if err := validateAPIKeyChatCompletionsCredentials(item.Credentials); err != nil { + return err + } default: return fmt.Errorf("account type is invalid: %s", item.Type) } @@ -579,6 +587,57 @@ func validateDataAccount(item DataAccount) error { return nil } +// validateAPIKeyChatCompletionsCredentials checks that an +// AccountTypeAPIKeyChatCompletions account carries the minimum set of +// credentials required by 
the upstream Chat Completions integration: +// a valid http(s) chat_completions_url plus a non-empty api_key. +func validateAPIKeyChatCompletionsCredentials(creds map[string]any) error { + if creds == nil { + return errors.New("credentials are required for apikey-chat-completions accounts") + } + + rawURL, _ := creds["chat_completions_url"].(string) + rawURL = strings.TrimSpace(rawURL) + if rawURL == "" { + return errors.New("chat_completions_url is required for apikey-chat-completions accounts") + } + parsed, err := url.Parse(rawURL) + if err != nil { + return fmt.Errorf("chat_completions_url is invalid: %w", err) + } + scheme := strings.ToLower(parsed.Scheme) + if scheme != "http" && scheme != "https" { + return fmt.Errorf("chat_completions_url must use http or https scheme, got %q", parsed.Scheme) + } + if parsed.Host == "" { + return errors.New("chat_completions_url must include a host") + } + + apiKey, _ := creds["api_key"].(string) + if strings.TrimSpace(apiKey) == "" { + return errors.New("api_key is required for apikey-chat-completions accounts") + } + + return validateOpenAICompatibleAuthHeaderCredential(creds) +} + +func validateAPIKeyChatCompletionsCredentialsUpdate(creds map[string]any) error { + if creds == nil { + return nil + } + return validateOpenAICompatibleAuthHeaderCredential(creds) +} + +func validateOpenAICompatibleAuthHeaderCredential(creds map[string]any) error { + if rawAuthHeader, ok := creds["auth_header"]; ok { + authHeader, _ := rawAuthHeader.(string) + if _, valid := service.NormalizeOpenAICompatibleAuthHeader(authHeader); !valid { + return fmt.Errorf("auth_header must be one of %q, %q, or %q", service.OpenAICompatibleAuthHeaderAuthorization, service.OpenAICompatibleAuthHeaderAPIKey, service.OpenAICompatibleAuthHeaderXAPIKey) + } + } + return nil +} + func defaultProxyName(name string) string { if strings.TrimSpace(name) == "" { return "imported-proxy" diff --git a/backend/internal/handler/admin/account_data_auth_header_test.go 
b/backend/internal/handler/admin/account_data_auth_header_test.go new file mode 100644 index 00000000000..6fe09f4ab1a --- /dev/null +++ b/backend/internal/handler/admin/account_data_auth_header_test.go @@ -0,0 +1,34 @@ +//go:build unit + +package admin + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestValidateAPIKeyChatCompletionsCredentialsAuthHeader(t *testing.T) { + baseCreds := func() map[string]any { + return map[string]any{ + "chat_completions_url": "https://compat-upstream.example/v1/chat/completions", + "api_key": "sk-test", + } + } + + for _, value := range []string{"", "Authorization", "authorization", "api-key", "x-api-key"} { + t.Run("allows_"+value, func(t *testing.T) { + creds := baseCreds() + creds["auth_header"] = value + require.NoError(t, validateAPIKeyChatCompletionsCredentials(creds)) + }) + } + + for _, value := range []string{"X-Custom-Key", "api-key\nInjected: yes", "Authorization: Bearer"} { + t.Run("rejects_"+value, func(t *testing.T) { + creds := baseCreds() + creds["auth_header"] = value + require.ErrorContains(t, validateAPIKeyChatCompletionsCredentials(creds), "auth_header must be one of") + }) + } +} diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 5719534230b..3996cb684f9 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -95,41 +95,43 @@ func NewAccountHandler( // CreateAccountRequest represents create account request type CreateAccountRequest struct { - Name string `json:"name" binding:"required"` - Notes *string `json:"notes"` - Platform string `json:"platform" binding:"required"` - Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock service_account"` - Credentials map[string]any `json:"credentials" binding:"required"` - Extra map[string]any `json:"extra"` - ProxyID *int64 `json:"proxy_id"` - Concurrency int 
`json:"concurrency"` - Priority int `json:"priority"` - RateMultiplier *float64 `json:"rate_multiplier"` - LoadFactor *int `json:"load_factor"` - GroupIDs []int64 `json:"group_ids"` - ExpiresAt *int64 `json:"expires_at"` - AutoPauseOnExpired *bool `json:"auto_pause_on_expired"` - ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险 + Name string `json:"name" binding:"required"` + Notes *string `json:"notes"` + Platform string `json:"platform" binding:"required"` + Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock service_account apikey-chat-completions"` + Credentials map[string]any `json:"credentials" binding:"required"` + Extra map[string]any `json:"extra"` + ProxyID *int64 `json:"proxy_id"` + Concurrency int `json:"concurrency"` + Priority int `json:"priority"` + RateMultiplier *float64 `json:"rate_multiplier"` + LoadFactor *int `json:"load_factor"` + GroupIDs []int64 `json:"group_ids"` + ExpiresAt *int64 `json:"expires_at"` + AutoPauseOnExpired *bool `json:"auto_pause_on_expired"` + StripReasoningEffortOnCC *bool `json:"strip_reasoning_effort_on_cc"` + ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险 } // UpdateAccountRequest represents update account request // 使用指针类型来区分"未提供"和"设置为0" type UpdateAccountRequest struct { - Name string `json:"name"` - Notes *string `json:"notes"` - Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock service_account"` - Credentials map[string]any `json:"credentials"` - Extra map[string]any `json:"extra"` - ProxyID *int64 `json:"proxy_id"` - Concurrency *int `json:"concurrency"` - Priority *int `json:"priority"` - RateMultiplier *float64 `json:"rate_multiplier"` - LoadFactor *int `json:"load_factor"` - Status string `json:"status" binding:"omitempty,oneof=active inactive error"` - GroupIDs *[]int64 `json:"group_ids"` - ExpiresAt *int64 `json:"expires_at"` - AutoPauseOnExpired *bool 
`json:"auto_pause_on_expired"` - ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险 + Name string `json:"name"` + Notes *string `json:"notes"` + Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock service_account apikey-chat-completions"` + Credentials map[string]any `json:"credentials"` + Extra map[string]any `json:"extra"` + ProxyID *int64 `json:"proxy_id"` + Concurrency *int `json:"concurrency"` + Priority *int `json:"priority"` + RateMultiplier *float64 `json:"rate_multiplier"` + LoadFactor *int `json:"load_factor"` + Status string `json:"status" binding:"omitempty,oneof=active inactive error"` + GroupIDs *[]int64 `json:"group_ids"` + ExpiresAt *int64 `json:"expires_at"` + AutoPauseOnExpired *bool `json:"auto_pause_on_expired"` + StripReasoningEffortOnCC *bool `json:"strip_reasoning_effort_on_cc"` + ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险 } // BulkUpdateAccountsRequest represents the payload for bulk editing accounts @@ -522,6 +524,12 @@ func (h *AccountHandler) Create(c *gin.Context) { response.BadRequest(c, "rate_multiplier must be >= 0") return } + if req.Type == service.AccountTypeAPIKeyChatCompletions { + if err := validateAPIKeyChatCompletionsCredentials(req.Credentials); err != nil { + response.BadRequest(c, err.Error()) + return + } + } // base_rpm 输入校验:负值归零,超过 10000 截断 sanitizeExtraBaseRPM(req.Extra) @@ -534,21 +542,22 @@ func (h *AccountHandler) Create(c *gin.Context) { result, err := executeAdminIdempotent(c, "admin.accounts.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { account, execErr := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{ - Name: req.Name, - Notes: req.Notes, - Platform: req.Platform, - Type: req.Type, - Credentials: req.Credentials, - Extra: req.Extra, - ProxyID: req.ProxyID, - Concurrency: req.Concurrency, - Priority: req.Priority, - RateMultiplier: 
req.RateMultiplier, - LoadFactor: req.LoadFactor, - GroupIDs: req.GroupIDs, - ExpiresAt: req.ExpiresAt, - AutoPauseOnExpired: req.AutoPauseOnExpired, - SkipMixedChannelCheck: skipCheck, + Name: req.Name, + Notes: req.Notes, + Platform: req.Platform, + Type: req.Type, + Credentials: req.Credentials, + Extra: req.Extra, + ProxyID: req.ProxyID, + Concurrency: req.Concurrency, + Priority: req.Priority, + RateMultiplier: req.RateMultiplier, + LoadFactor: req.LoadFactor, + GroupIDs: req.GroupIDs, + ExpiresAt: req.ExpiresAt, + AutoPauseOnExpired: req.AutoPauseOnExpired, + StripReasoningEffortOnCC: req.StripReasoningEffortOnCC, + SkipMixedChannelCheck: skipCheck, }) if execErr != nil { return nil, execErr @@ -606,6 +615,12 @@ func (h *AccountHandler) Update(c *gin.Context) { response.BadRequest(c, "rate_multiplier must be >= 0") return } + if req.Type == service.AccountTypeAPIKeyChatCompletions && len(req.Credentials) > 0 { + if err := validateAPIKeyChatCompletionsCredentialsUpdate(req.Credentials); err != nil { + response.BadRequest(c, err.Error()) + return + } + } // base_rpm 输入校验:负值归零,超过 10000 截断 sanitizeExtraBaseRPM(req.Extra) @@ -613,21 +628,22 @@ func (h *AccountHandler) Update(c *gin.Context) { skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk account, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{ - Name: req.Name, - Notes: req.Notes, - Type: req.Type, - Credentials: req.Credentials, - Extra: req.Extra, - ProxyID: req.ProxyID, - Concurrency: req.Concurrency, // 指针类型,nil 表示未提供 - Priority: req.Priority, // 指针类型,nil 表示未提供 - RateMultiplier: req.RateMultiplier, - LoadFactor: req.LoadFactor, - Status: req.Status, - GroupIDs: req.GroupIDs, - ExpiresAt: req.ExpiresAt, - AutoPauseOnExpired: req.AutoPauseOnExpired, - SkipMixedChannelCheck: skipCheck, + Name: req.Name, + Notes: req.Notes, + Type: req.Type, + Credentials: req.Credentials, + Extra: req.Extra, + ProxyID: req.ProxyID, + Concurrency: 
req.Concurrency, // 指针类型,nil 表示未提供 + Priority: req.Priority, // 指针类型,nil 表示未提供 + RateMultiplier: req.RateMultiplier, + LoadFactor: req.LoadFactor, + Status: req.Status, + GroupIDs: req.GroupIDs, + ExpiresAt: req.ExpiresAt, + AutoPauseOnExpired: req.AutoPauseOnExpired, + StripReasoningEffortOnCC: req.StripReasoningEffortOnCC, + SkipMixedChannelCheck: skipCheck, }) if err != nil { // 检查是否为混合渠道错误 @@ -1330,20 +1346,21 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) { skipCheck := item.ConfirmMixedChannelRisk != nil && *item.ConfirmMixedChannelRisk account, err := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{ - Name: item.Name, - Notes: item.Notes, - Platform: item.Platform, - Type: item.Type, - Credentials: item.Credentials, - Extra: item.Extra, - ProxyID: item.ProxyID, - Concurrency: item.Concurrency, - Priority: item.Priority, - RateMultiplier: item.RateMultiplier, - GroupIDs: item.GroupIDs, - ExpiresAt: item.ExpiresAt, - AutoPauseOnExpired: item.AutoPauseOnExpired, - SkipMixedChannelCheck: skipCheck, + Name: item.Name, + Notes: item.Notes, + Platform: item.Platform, + Type: item.Type, + Credentials: item.Credentials, + Extra: item.Extra, + ProxyID: item.ProxyID, + Concurrency: item.Concurrency, + Priority: item.Priority, + RateMultiplier: item.RateMultiplier, + GroupIDs: item.GroupIDs, + ExpiresAt: item.ExpiresAt, + AutoPauseOnExpired: item.AutoPauseOnExpired, + StripReasoningEffortOnCC: item.StripReasoningEffortOnCC, + SkipMixedChannelCheck: skipCheck, }) if err != nil { failed++ diff --git a/backend/internal/handler/anthropic_messages_apikey_cc_test.go b/backend/internal/handler/anthropic_messages_apikey_cc_test.go new file mode 100644 index 00000000000..5390bf66010 --- /dev/null +++ b/backend/internal/handler/anthropic_messages_apikey_cc_test.go @@ -0,0 +1,249 @@ +package handler + +import ( + "bytes" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + 
"github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func anthropicAPIKeyCCTestConfig() *config.Config { + return &config.Config{ + Security: config.SecurityConfig{ + URLAllowlist: config.URLAllowlistConfig{ + Enabled: false, + AllowInsecureHTTP: true, + }, + }, + } +} + +func anthropicAPIKeyCCTestAccount(chatCompletionsURL string) *service.Account { + return &service.Account{ + ID: 22002, + Name: "anthropic-cc-handler-test", + Platform: service.PlatformAnthropic, + Type: service.AccountTypeAPIKeyChatCompletions, + Status: service.StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-anth-cc-handler", + "chat_completions_url": chatCompletionsURL, + }, + } +} + +func newAnthropicCCGatewayService(t *testing.T, httpUp service.HTTPUpstream) *service.GatewayService { + t.Helper() + cfg := anthropicAPIKeyCCTestConfig() + return service.NewGatewayService( + nil, nil, nil, nil, nil, nil, nil, + nil, // cache + cfg, + nil, // schedulerSnapshot + nil, // concurrencyService + nil, // billingService + nil, // rateLimitService + nil, // billingCacheService + nil, // identityService + httpUp, + nil, // deferredService + nil, // claudeTokenProvider + nil, // sessionLimitCache + nil, // rpmCache + nil, // digestStore + nil, // settingService + nil, // tlsFPProfileService + nil, // channelService + nil, // resolver + nil, // balanceNotifyService + nil, // userPlatformQuotaRepo + ) +} + +// TestAPIKeyCCMessages_NonStream verifies that an Anthropic-platform +// apikey-chat-completions account correctly translates a /v1/messages request +// into a Chat Completions upstream call and converts the upstream SSE response +// back to an Anthropic JSON envelope for the client. 
+func TestAPIKeyCCMessages_NonStream(t *testing.T) { + gin.SetMode(gin.TestMode) + + upstreamSSE := strings.Join([]string{ + `data: {"id":"chk","object":"chat.completion.chunk","model":"deepseek-chat","choices":[{"index":0,"delta":{"role":"assistant"}}]}`, + ``, + `data: {"id":"chk","object":"chat.completion.chunk","model":"deepseek-chat","choices":[{"index":0,"delta":{"content":"hello"}}]}`, + ``, + `data: {"id":"chk","object":"chat.completion.chunk","model":"deepseek-chat","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7}}`, + ``, + `data: [DONE]`, + ``, + }, "\n") + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(upstreamSSE)) + })) + defer upstream.Close() + + httpUp := &apikeyCCHTTPUpstream{client: upstream.Client()} + svc := newAnthropicCCGatewayService(t, httpUp) + account := anthropicAPIKeyCCTestAccount(upstream.URL + "/v1/chat/completions") + + body := []byte(`{"model":"claude-sonnet-4","max_tokens":64,"stream":false,"messages":[{"role":"user","content":"ping"}]}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + c.Request.Header.Set("anthropic-beta", "tools-2024-05-16") + + parsed := &service.ParsedRequest{Body: service.NewRequestBodyRef(body), Model: "claude-sonnet-4", Stream: false} + result, err := svc.ForwardAnthropicAsChatCompletions(c.Request.Context(), c, account, parsed) + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.Stream) + + require.Equal(t, int32(1), atomic.LoadInt32(&httpUp.requestCount)) + require.Equal(t, upstream.URL+"/v1/chat/completions", httpUp.lastReqURL) + require.Equal(t, "Bearer sk-anth-cc-handler", 
httpUp.lastAuthz) + + // anthropic-* headers must NOT leak to the OpenAI-compatible upstream. + require.NotContains(t, strings.ToLower(string(httpUp.lastBody)), "anthropic-beta") + + clientBody := rec.Body.String() + require.True(t, gjson.Valid(clientBody)) + require.Equal(t, "message", gjson.Get(clientBody, "type").String()) + require.Contains(t, clientBody, "hello") +} + +func TestAPIKeyCCMessages_Stream(t *testing.T) { + gin.SetMode(gin.TestMode) + + upstreamSSE := strings.Join([]string{ + `data: {"id":"chk","object":"chat.completion.chunk","model":"m","choices":[{"index":0,"delta":{"role":"assistant"}}]}`, + ``, + `data: {"id":"chk","object":"chat.completion.chunk","model":"m","choices":[{"index":0,"delta":{"content":"yo"}}]}`, + ``, + `data: {"id":"chk","object":"chat.completion.chunk","model":"m","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":2,"completion_tokens":1,"total_tokens":3}}`, + ``, + `data: [DONE]`, + ``, + }, "\n") + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(upstreamSSE)) + })) + defer upstream.Close() + + httpUp := &apikeyCCHTTPUpstream{client: upstream.Client()} + svc := newAnthropicCCGatewayService(t, httpUp) + account := anthropicAPIKeyCCTestAccount(upstream.URL + "/v1/chat/completions") + + body := []byte(`{"model":"claude-sonnet-4","max_tokens":64,"stream":true,"messages":[{"role":"user","content":"hi"}]}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + parsed := &service.ParsedRequest{Body: service.NewRequestBodyRef(body), Model: "claude-sonnet-4", Stream: true} + result, err := svc.ForwardAnthropicAsChatCompletions(c.Request.Context(), c, account, parsed) + 
require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.Stream) + + clientBody := rec.Body.String() + require.Contains(t, clientBody, "event: message_start") + require.Contains(t, clientBody, "event: content_block_delta") + require.Contains(t, clientBody, "event: message_delta") + require.Contains(t, clientBody, "event: message_stop") +} + +// TestAPIKeyCCMessages_UpstreamError 验证可 failover 状态码(500/429/etc.) +// 不直接写客户端,而是返回 *UpstreamFailoverError 让 handler 端的 failover 循环处理 +// (决定是否切换到下一个账号,或最终调用 handleFailoverExhausted 写出错误响应)。 +func TestAPIKeyCCMessages_UpstreamError(t *testing.T) { + gin.SetMode(gin.TestMode) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"error":{"message":"boom"}}`)) + })) + defer upstream.Close() + + httpUp := &apikeyCCHTTPUpstream{client: upstream.Client()} + svc := newAnthropicCCGatewayService(t, httpUp) + account := anthropicAPIKeyCCTestAccount(upstream.URL + "/v1/chat/completions") + + body := []byte(`{"model":"claude-sonnet-4","max_tokens":64,"stream":false,"messages":[{"role":"user","content":"ping"}]}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + + parsed := &service.ParsedRequest{Body: service.NewRequestBodyRef(body), Model: "claude-sonnet-4", Stream: false} + _, err := svc.ForwardAnthropicAsChatCompletions(c.Request.Context(), c, account, parsed) + require.Error(t, err) + + var failoverErr *service.UpstreamFailoverError + require.ErrorAs(t, err, &failoverErr) + require.Equal(t, http.StatusInternalServerError, failoverErr.StatusCode) + require.Contains(t, string(failoverErr.ResponseBody), "boom") + + // 客户端响应未被写出,由外层 failover loop 决定后续动作 + require.Equal(t, 0, rec.Body.Len()) +} + +// TestAPIKeyCCChatCompletions_ClaudePlatform verifies that on an Anthropic +// platform group, 
a client POST /v1/chat/completions against an +// apikey-chat-completions account performs a raw passthrough to the upstream +// CC URL (no protocol conversion). +func TestAPIKeyCCChatCompletions_ClaudePlatform(t *testing.T) { + gin.SetMode(gin.TestMode) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{ + "id":"chatcmpl_anth_raw", + "object":"chat.completion", + "created":1, + "model":"deepseek-chat", + "choices":[{"index":0,"message":{"role":"assistant","content":"pong"},"finish_reason":"stop"}], + "usage":{"prompt_tokens":3,"completion_tokens":1,"total_tokens":4} + }`)) + })) + defer upstream.Close() + + httpUp := &apikeyCCHTTPUpstream{client: upstream.Client()} + svc := newAnthropicCCGatewayService(t, httpUp) + account := anthropicAPIKeyCCTestAccount(upstream.URL + "/v1/chat/completions") + + body := []byte(`{"model":"deepseek-chat","messages":[{"role":"user","content":"ping"}],"stream":false}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + result, err := svc.ForwardClaudeChatCompletionsRaw(c.Request.Context(), c, account, body, &service.ParsedRequest{Body: service.NewRequestBodyRef(body)}) + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.Stream) + + require.Equal(t, upstream.URL+"/v1/chat/completions", httpUp.lastReqURL) + require.Equal(t, "Bearer sk-anth-cc-handler", httpUp.lastAuthz) + + clientBody := rec.Body.String() + require.True(t, gjson.Valid(clientBody)) + require.Equal(t, "chat.completion", gjson.Get(clientBody, "object").String()) + require.Contains(t, clientBody, "pong") +} diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 
86f98f15e2e..0dab9d7337d 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -219,7 +219,8 @@ func AccountFromServiceShallow(a *service.Account) *Account { ErrorMessage: a.ErrorMessage, LastUsedAt: a.LastUsedAt, ExpiresAt: timeToUnixSeconds(a.ExpiresAt), - AutoPauseOnExpired: a.AutoPauseOnExpired, + AutoPauseOnExpired: a.AutoPauseOnExpired, + StripReasoningEffortOnCC: a.StripReasoningEffortOnCC, CreatedAt: a.CreatedAt, UpdatedAt: a.UpdatedAt, Schedulable: a.Schedulable, diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 08dc657205a..1e5cb0b496f 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -160,21 +160,22 @@ type Account struct { Type string `json:"type"` // Credentials 经 RedactCredentials 处理后只含非敏感子键;敏感 token / api_key / 私钥 // 的存在性通过 CredentialsStatus(has_)暴露,原始值不返回前端。 - Credentials map[string]any `json:"credentials"` - CredentialsStatus map[string]bool `json:"credentials_status,omitempty"` - Extra map[string]any `json:"extra"` - ProxyID *int64 `json:"proxy_id"` - Concurrency int `json:"concurrency"` - LoadFactor *int `json:"load_factor,omitempty"` - Priority int `json:"priority"` - RateMultiplier float64 `json:"rate_multiplier"` - Status string `json:"status"` - ErrorMessage string `json:"error_message"` - LastUsedAt *time.Time `json:"last_used_at"` - ExpiresAt *int64 `json:"expires_at"` - AutoPauseOnExpired bool `json:"auto_pause_on_expired"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + Credentials map[string]any `json:"credentials"` + CredentialsStatus map[string]bool `json:"credentials_status,omitempty"` + Extra map[string]any `json:"extra"` + ProxyID *int64 `json:"proxy_id"` + Concurrency int `json:"concurrency"` + LoadFactor *int `json:"load_factor,omitempty"` + Priority int `json:"priority"` + RateMultiplier float64 `json:"rate_multiplier"` + Status string `json:"status"` 
+ ErrorMessage string `json:"error_message"` + LastUsedAt *time.Time `json:"last_used_at"` + ExpiresAt *int64 `json:"expires_at"` + AutoPauseOnExpired bool `json:"auto_pause_on_expired"` + StripReasoningEffortOnCC bool `json:"strip_reasoning_effort_on_cc"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` Schedulable bool `json:"schedulable"` diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 3bd5e82e7b3..9587e80dc50 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -504,6 +504,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { requestPayloadHash := service.HashUsageRequestPayload(body) inboundEndpoint := GetInboundEndpoint(c) upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) + if account.IsOpenAIChatCompletionsUpstream() { + // /v1/messages 流量被适配为 /v1/chat/completions 上游调用,与 OpenAI 侧 /responses → CC + // 改造保持一致,便于 Ops 区分。 + upstreamEndpoint = "/v1/chat/completions" + } if result.ReasoningEffort == nil { result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort) @@ -913,6 +918,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { requestPayloadHash := service.HashUsageRequestPayload(attemptParsedReq.Body.Bytes()) inboundEndpoint := GetInboundEndpoint(c) upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) + if account.IsOpenAIChatCompletionsUpstream() { + // /v1/messages 流量被适配为 /v1/chat/completions 上游调用,与 OpenAI 侧 /responses → CC + // 改造保持一致,便于 Ops 区分。 + upstreamEndpoint = "/v1/chat/completions" + } if result.ReasoningEffort == nil { result.ReasoningEffort = service.NormalizeClaudeOutputEffort(attemptParsedReq.OutputEffort) diff --git a/backend/internal/handler/gateway_handler_chat_completions.go b/backend/internal/handler/gateway_handler_chat_completions.go index 719700aa9f6..689198f952b 100644 --- a/backend/internal/handler/gateway_handler_chat_completions.go 
+++ b/backend/internal/handler/gateway_handler_chat_completions.go @@ -251,6 +251,10 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { return } result, err = h.geminiCompatService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody) + } else if account.IsOpenAIChatCompletionsUpstream() { + // Client sent CC, upstream is also CC: raw passthrough, no protocol conversion. + result, err = h.gatewayService.ForwardClaudeChatCompletionsRaw(c.Request.Context(), c, account, forwardBody, parsedReq) + } else { result, err = h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, parsedReq) } @@ -291,6 +295,11 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { requestPayloadHash := service.HashUsageRequestPayload(body) inboundEndpoint := GetInboundEndpoint(c) upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) + if account.IsOpenAIChatCompletionsUpstream() { + // 上游实际打到 OpenAI 兼容 /v1/chat/completions,覆盖默认按 platform 派生的 /v1/messages, + // 与 OpenAI 侧 /responses → CC 改造保持一致,便于 Ops 区分。 + upstreamEndpoint = "/v1/chat/completions" + } quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey) h.submitUsageRecordTask(c.Request.Context(), func(ctx context.Context) { diff --git a/backend/internal/handler/openai_chat_completions_apikey_cc_test.go b/backend/internal/handler/openai_chat_completions_apikey_cc_test.go new file mode 100644 index 00000000000..e286ecc0989 --- /dev/null +++ b/backend/internal/handler/openai_chat_completions_apikey_cc_test.go @@ -0,0 +1,222 @@ +package handler + +import ( + "bytes" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +// apikeyCCHTTPUpstream 是一个轻量 HTTPUpstream 实现,把请求直接发给一个 +// 
由测试启动的 httptest.Server。专用于 apikey-chat-completions 类型账号的 +// handler 集成 smoke test:验证 service 层在该账号类型下确实把请求送到 +// account.GetOpenAIChatCompletionsURL() 配置的 URL,且 body/header 正确。 +type apikeyCCHTTPUpstream struct { + client *http.Client + requestCount int32 + lastReqURL string + lastAuthz string + lastAccept string + lastBody []byte +} + +func (u *apikeyCCHTTPUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { + atomic.AddInt32(&u.requestCount, 1) + u.lastReqURL = req.URL.String() + u.lastAuthz = req.Header.Get("Authorization") + u.lastAccept = req.Header.Get("Accept") + if req.Body != nil { + buf := new(bytes.Buffer) + _, _ = buf.ReadFrom(req.Body) + u.lastBody = buf.Bytes() + _ = req.Body.Close() + req.Body = http.NoBody + req.Body = newReusableBody(u.lastBody) + req.ContentLength = int64(len(u.lastBody)) + } + return u.client.Do(req) +} + +func (u *apikeyCCHTTPUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) { + return u.Do(req, proxyURL, accountID, accountConcurrency) +} + +func newReusableBody(b []byte) *reusableBody { + return &reusableBody{Reader: bytes.NewReader(b)} +} + +type reusableBody struct { + *bytes.Reader +} + +func (r *reusableBody) Close() error { return nil } + +func apikeyCCTestAccount(chatCompletionsURL string) *service.Account { + return &service.Account{ + ID: 11001, + Name: "apikey-cc-handler-test", + Platform: service.PlatformOpenAI, + Type: service.AccountTypeAPIKeyChatCompletions, + Status: service.StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-handler-cc-test", + "chat_completions_url": chatCompletionsURL, + }, + } +} + +func apikeyCCTestConfig() *config.Config { + return &config.Config{ + Security: config.SecurityConfig{ + URLAllowlist: config.URLAllowlistConfig{ + Enabled: false, + AllowInsecureHTTP: true, + }, + }, 
+ } +} + +// TestAPIKeyCCChatCompletions_NonStream 验证 apikey-chat-completions 账号在 +// /v1/chat/completions 路径下的非流式转发:上游 URL 必须等于 +// account.GetOpenAIChatCompletionsURL(),Authorization 必须使用账号 api_key, +// 客户端能拿到合法的 ChatCompletions JSON。 +func TestAPIKeyCCChatCompletions_NonStream(t *testing.T) { + gin.SetMode(gin.TestMode) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{ + "id":"chatcmpl_apikey_cc_1", + "object":"chat.completion", + "created":1, + "model":"gpt-4o", + "choices":[{"index":0,"message":{"role":"assistant","content":"pong"},"finish_reason":"stop"}], + "usage":{"prompt_tokens":3,"completion_tokens":1,"total_tokens":4} + }`)) + })) + defer upstream.Close() + + httpUp := &apikeyCCHTTPUpstream{client: upstream.Client()} + cfg := apikeyCCTestConfig() + + // 通过 NewOpenAIGatewayService 构造完整的 service,但仅注入测试需要的字段。 + svc := service.NewOpenAIGatewayService( + nil, nil, nil, nil, nil, nil, nil, + cfg, + nil, nil, + service.NewBillingService(cfg, nil), + nil, nil, + httpUp, + &service.DeferredService{}, + nil, nil, nil, nil, nil, + nil, + ) + + account := apikeyCCTestAccount(upstream.URL + "/v1/chat/completions") + + body := []byte(`{"model":"gpt-4o","messages":[{"role":"user","content":"ping"}],"stream":false}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + result, err := svc.ForwardAsChatCompletions(c.Request.Context(), c, account, body, "", "") + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.Stream) + + // 上游 URL 与认证头 + require.Equal(t, int32(1), atomic.LoadInt32(&httpUp.requestCount), "应只调用一次上游") + require.Equal(t, upstream.URL+"/v1/chat/completions", httpUp.lastReqURL) + require.Equal(t, 
"Bearer sk-handler-cc-test", httpUp.lastAuthz) + require.Equal(t, "application/json", httpUp.lastAccept) + + // 上游收到的 body 必须仍是 ChatCompletions 格式(含 messages,不含 input) + require.Contains(t, string(httpUp.lastBody), `"messages"`) + require.NotContains(t, string(httpUp.lastBody), `"input":[`) + + // 客户端响应是合法 ChatCompletions JSON + clientBody := rec.Body.String() + require.True(t, gjson.Valid(clientBody), "客户端响应必须是合法 JSON: %s", clientBody) + require.Equal(t, "chat.completion", gjson.Get(clientBody, "object").String()) + require.Contains(t, clientBody, "pong") +} + +// TestAPIKeyCCChatCompletions_Stream 验证 apikey-chat-completions 账号在 +// /v1/chat/completions 路径下的流式转发:上游 stream_options.include_usage +// 必被网关强制打开,客户端能收到 SSE 事件并以 [DONE] 结束。 +func TestAPIKeyCCChatCompletions_Stream(t *testing.T) { + gin.SetMode(gin.TestMode) + + upstreamSSE := strings.Join([]string{ + `data: {"id":"chatcmpl_stream","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}`, + "", + `data: {"id":"chatcmpl_stream","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"hi"},"finish_reason":null}]}`, + "", + `data: {"id":"chatcmpl_stream","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":2,"completion_tokens":1,"total_tokens":3}}`, + "", + "data: [DONE]", + "", + }, "\n") + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(upstreamSSE)) + })) + defer upstream.Close() + + httpUp := &apikeyCCHTTPUpstream{client: upstream.Client()} + cfg := apikeyCCTestConfig() + + svc := service.NewOpenAIGatewayService( + nil, nil, nil, nil, nil, nil, nil, + cfg, + nil, nil, + service.NewBillingService(cfg, nil), + nil, nil, + httpUp, + 
&service.DeferredService{}, + nil, nil, nil, nil, nil, + nil, + ) + + account := apikeyCCTestAccount(upstream.URL + "/v1/chat/completions") + + body := []byte(`{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}],"stream":true}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + result, err := svc.ForwardAsChatCompletions(c.Request.Context(), c, account, body, "", "") + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.Stream) + + // 上游 Accept: text/event-stream + require.Equal(t, "text/event-stream", httpUp.lastAccept) + require.Equal(t, upstream.URL+"/v1/chat/completions", httpUp.lastReqURL) + require.Equal(t, "Bearer sk-handler-cc-test", httpUp.lastAuthz) + + // 上游收到的 stream=true 且 stream_options.include_usage 被网关强制打开 + require.True(t, gjson.GetBytes(httpUp.lastBody, "stream").Bool()) + require.True(t, gjson.GetBytes(httpUp.lastBody, "stream_options.include_usage").Bool()) + + // 客户端收到 SSE 事件,含 chunk 与终止符 + clientBody := rec.Body.String() + require.Contains(t, clientBody, `"object":"chat.completion.chunk"`) + require.Contains(t, clientBody, "data: [DONE]") +} diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 0aa477b08cd..5869f03dc4f 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -359,6 +359,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { if channelMapping.Mapped { forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel) } + // apikey-chat-completions 类型账号:将 /responses 请求转换为 /v1/chat/completions 调上游。 + isChatCompletionsUpstream := account.IsOpenAIChatCompletionsUpstream() writerSizeBeforeForward := c.Writer.Size() result, err := func() 
(*service.OpenAIForwardResult, error) { defer func() { @@ -366,8 +368,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { accountReleaseFunc() } }() + if isChatCompletionsUpstream { + return h.gatewayService.ForwardResponsesAsChatCompletions(c.Request.Context(), c, account, forwardBody) + } return h.gatewayService.Forward(c.Request.Context(), c, account, forwardBody) }() + forwardDurationMs := time.Since(forwardStart).Milliseconds() upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey) responseLatencyMs := forwardDurationMs @@ -462,6 +468,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { requestPayloadHash := service.HashUsageRequestPayload(body) inboundEndpoint := GetInboundEndpoint(c) upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) + if isChatCompletionsUpstream { + // /responses 流量被适配为 /v1/chat/completions 上游调用,便于 Ops 区分。 + upstreamEndpoint = "/v1/chat/completions" + } // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 h.submitOpenAIUsageRecordTask(c.Request.Context(), result, func(ctx context.Context) { @@ -1358,11 +1368,15 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { reqLog.Warn("openai.websocket_bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) } - token, _, err := h.gatewayService.GetAccessToken(ctx, account) - if err != nil { - reqLog.Warn("openai.websocket_get_access_token_failed", zap.Int64("account_id", account.ID), zap.Error(err)) - closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to get access token") - return + token := "" + if !account.IsOpenAIChatCompletionsUpstream() { + var err error + token, _, err = h.gatewayService.GetAccessToken(ctx, account) + if err != nil { + reqLog.Warn("openai.websocket_get_access_token_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to get access token") + return + } } 
reqLog.Debug("openai.websocket_account_selected", @@ -1447,6 +1461,9 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) inboundEndpoint := GetInboundEndpoint(c) upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) + if account.IsOpenAIChatCompletionsUpstream() { + upstreamEndpoint = "/v1/chat/completions" + } h.submitOpenAIUsageRecordTask(ctx, result, func(taskCtx context.Context) { if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{ Result: result, @@ -1481,9 +1498,15 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { // WebSocket 首包可能很大,hash 必须在 hooks 外算成字符串,避免 AfterTurn 闭包保活请求体。 requestPayloadHash = service.HashUsageRequestPayload(wsFirstMessage) - if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, wsFirstMessage, hooks); err != nil { + var proxyErr error + if account.IsOpenAIChatCompletionsUpstream() { + proxyErr = h.gatewayService.ProxyResponsesWebSocketAsChatCompletions(ctx, c, wsConn, account, wsFirstMessage, hooks) + } else { + proxyErr = h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, wsFirstMessage, hooks) + } + if proxyErr != nil { var failoverErr *service.UpstreamFailoverError - if errors.As(err, &failoverErr) { + if errors.As(proxyErr, &failoverErr) { h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) releaseAccountSlot() failedAccountIDs[account.ID] = struct{}{} @@ -1519,7 +1542,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { zap.String("close_reason", closeReason), ) var closeErr *service.OpenAIWSClientCloseError - if errors.As(err, &closeErr) { + if errors.As(proxyErr, &closeErr) { closeOpenAIClientWS(wsConn, closeErr.StatusCode(), closeErr.Reason()) return } diff --git a/backend/internal/handler/openai_responses_apikey_cc_test.go 
b/backend/internal/handler/openai_responses_apikey_cc_test.go new file mode 100644 index 00000000000..34e113f8983 --- /dev/null +++ b/backend/internal/handler/openai_responses_apikey_cc_test.go @@ -0,0 +1,167 @@ +package handler + +import ( + "bytes" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func apikeyCCResponsesTestConfig() *config.Config { + return &config.Config{ + Security: config.SecurityConfig{ + URLAllowlist: config.URLAllowlistConfig{ + Enabled: false, + AllowInsecureHTTP: true, + }, + }, + } +} + +// TestAPIKeyCCResponses_NonStream 验证 apikey-chat-completions 账号在 +// /v1/responses 路径下的非流式协议转换:service 层会把 Responses 请求转成 +// ChatCompletions 发给上游配置的 chat_completions_url,并把上游返回的 +// ChatCompletions JSON 转回 Responses 形态返回给客户端。 +func TestAPIKeyCCResponses_NonStream(t *testing.T) { + gin.SetMode(gin.TestMode) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{ + "id":"chatcmpl_resp_apikey_cc", + "object":"chat.completion", + "created":1, + "model":"gpt-4o", + "choices":[{"index":0,"message":{"role":"assistant","content":"hello world"},"finish_reason":"stop"}], + "usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7} + }`)) + })) + defer upstream.Close() + + httpUp := &apikeyCCHTTPUpstream{client: upstream.Client()} + cfg := apikeyCCResponsesTestConfig() + + svc := service.NewOpenAIGatewayService( + nil, nil, nil, nil, nil, nil, nil, + cfg, + nil, nil, + service.NewBillingService(cfg, nil), + nil, nil, + httpUp, + &service.DeferredService{}, + nil, nil, nil, nil, nil, + nil, + ) + + account := apikeyCCTestAccount(upstream.URL + "/v1/chat/completions") + + body := 
[]byte(`{"model":"gpt-4o","input":"ping","stream":false}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + result, err := svc.ForwardResponsesAsChatCompletions(c.Request.Context(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.Stream) + require.Equal(t, 5, result.Usage.InputTokens) + require.Equal(t, 2, result.Usage.OutputTokens) + + // 上游 URL 必须是账号配置的 chat_completions_url + require.Equal(t, upstream.URL+"/v1/chat/completions", httpUp.lastReqURL) + require.Equal(t, "Bearer sk-handler-cc-test", httpUp.lastAuthz) + require.Equal(t, "application/json", httpUp.lastAccept) + + // 发给上游的必须是 ChatCompletions(messages,不带 Responses 的 input) + require.Contains(t, string(httpUp.lastBody), `"messages"`) + require.NotContains(t, string(httpUp.lastBody), `"input":"ping"`) + + // 客户端响应应为 Responses API 形态 JSON + clientBody := rec.Body.String() + require.True(t, gjson.Valid(clientBody), "客户端响应必须是合法 JSON: %s", clientBody) + require.Equal(t, "response", gjson.Get(clientBody, "object").String()) + require.Contains(t, clientBody, "hello world") + require.Contains(t, clientBody, "output_text") +} + +// TestAPIKeyCCResponses_Stream 验证 apikey-chat-completions 账号在 +// /v1/responses 流式路径下的协议转换:上游 ChatCompletions SSE chunk 必须 +// 被服务端转换为 Responses API SSE 事件序列(response.created / +// response.output_text.delta / response.completed / [DONE])。 +func TestAPIKeyCCResponses_Stream(t *testing.T) { + gin.SetMode(gin.TestMode) + + upstreamSSE := strings.Join([]string{ + `data: {"id":"chatcmpl_resp_stream","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}`, + "", + `data: 
{"id":"chatcmpl_resp_stream","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"Hi"},"finish_reason":null}]}`, + "", + `data: {"id":"chatcmpl_resp_stream","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":4,"completion_tokens":1,"total_tokens":5}}`, + "", + "data: [DONE]", + "", + }, "\n") + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(upstreamSSE)) + })) + defer upstream.Close() + + httpUp := &apikeyCCHTTPUpstream{client: upstream.Client()} + cfg := apikeyCCResponsesTestConfig() + + svc := service.NewOpenAIGatewayService( + nil, nil, nil, nil, nil, nil, nil, + cfg, + nil, nil, + service.NewBillingService(cfg, nil), + nil, nil, + httpUp, + &service.DeferredService{}, + nil, nil, nil, nil, nil, + nil, + ) + + account := apikeyCCTestAccount(upstream.URL + "/v1/chat/completions") + + body := []byte(`{"model":"gpt-4o","input":"hi","stream":true}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + result, err := svc.ForwardResponsesAsChatCompletions(c.Request.Context(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.Stream) + require.Equal(t, 4, result.Usage.InputTokens) + require.Equal(t, 1, result.Usage.OutputTokens) + + // 上游 URL 与流式 Accept 头 + require.Equal(t, upstream.URL+"/v1/chat/completions", httpUp.lastReqURL) + require.Equal(t, "text/event-stream", httpUp.lastAccept) + require.Equal(t, "Bearer sk-handler-cc-test", httpUp.lastAuthz) + + // 客户端收到的必须是 Responses SSE 事件序列 + clientBody := rec.Body.String() + require.Contains(t, clientBody, 
"event: response.created") + require.Contains(t, clientBody, "event: response.output_item.added") + require.Contains(t, clientBody, "event: response.content_part.added") + require.Contains(t, clientBody, "event: response.output_text.delta") + require.Contains(t, clientBody, "event: response.output_text.done") + require.Contains(t, clientBody, "event: response.content_part.done") + require.Contains(t, clientBody, "event: response.output_item.done") + require.Contains(t, clientBody, "event: response.completed") + require.Contains(t, clientBody, "data: [DONE]") +} diff --git a/backend/internal/pkg/apicompat/anthropic_to_chatcompletions_chain_test.go b/backend/internal/pkg/apicompat/anthropic_to_chatcompletions_chain_test.go new file mode 100644 index 00000000000..d9006346c92 --- /dev/null +++ b/backend/internal/pkg/apicompat/anthropic_to_chatcompletions_chain_test.go @@ -0,0 +1,144 @@ +package apicompat + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestAnthropicToCCChain_BasicText 验证 Anthropic→Responses→CC 的请求侧链式转换: +// 一段简单的 user 文本消息能完整保留到 ChatCompletionsRequest.messages。 +func TestAnthropicToCCChain_BasicText(t *testing.T) { + req := &AnthropicRequest{ + Model: "deepseek-chat", + MaxTokens: 256, + Messages: []AnthropicMessage{ + {Role: "user", Content: json.RawMessage(`"Hello"`)}, + }, + } + + respReq, err := AnthropicToResponses(req) + require.NoError(t, err) + + ccReq, err := ResponsesToChatCompletionsRequest(respReq) + require.NoError(t, err) + assert.Equal(t, "deepseek-chat", ccReq.Model) + require.GreaterOrEqual(t, len(ccReq.Messages), 1) + last := ccReq.Messages[len(ccReq.Messages)-1] + assert.Equal(t, "user", last.Role) + assert.Contains(t, string(last.Content), "Hello") +} + +// TestAnthropicToCCChain_SystemBlock 验证 Anthropic system 数组形态会落地为 +// CC 的 system message。 +func TestAnthropicToCCChain_SystemBlock(t *testing.T) { + req := &AnthropicRequest{ + Model: 
"deepseek-chat", + MaxTokens: 64, + System: json.RawMessage(`[{"type":"text","text":"You are helpful."}]`), + Messages: []AnthropicMessage{ + {Role: "user", Content: json.RawMessage(`"Hi"`)}, + }, + } + respReq, err := AnthropicToResponses(req) + require.NoError(t, err) + ccReq, err := ResponsesToChatCompletionsRequest(respReq) + require.NoError(t, err) + require.GreaterOrEqual(t, len(ccReq.Messages), 2) + first := ccReq.Messages[0] + assert.Equal(t, "system", first.Role) + assert.Contains(t, string(first.Content), "You are helpful") +} + +// TestAnthropicToCCChain_ToolUseHistory 验证 assistant 的 tool_use 与后续 user 的 +// tool_result 都能在链路中正确转换为 CC 的 assistant.tool_calls + tool message。 +func TestAnthropicToCCChain_ToolUseHistory(t *testing.T) { + req := &AnthropicRequest{ + Model: "deepseek-chat", + MaxTokens: 64, + Tools: []AnthropicTool{{ + Name: "lookup", + Description: "fake", + InputSchema: json.RawMessage(`{"type":"object","properties":{"q":{"type":"string"}}}`), + }}, + Messages: []AnthropicMessage{ + {Role: "user", Content: json.RawMessage(`"search docs"`)}, + {Role: "assistant", Content: json.RawMessage(`[ + {"type":"tool_use","id":"toolu_1","name":"lookup","input":{"q":"foo"}} + ]`)}, + {Role: "user", Content: json.RawMessage(`[ + {"type":"tool_result","tool_use_id":"toolu_1","content":"answer"} + ]`)}, + }, + } + respReq, err := AnthropicToResponses(req) + require.NoError(t, err) + ccReq, err := ResponsesToChatCompletionsRequest(respReq) + require.NoError(t, err) + require.GreaterOrEqual(t, len(ccReq.Messages), 3) + + var sawToolCall, sawToolMsg bool + for _, m := range ccReq.Messages { + if m.Role == "assistant" && len(m.ToolCalls) > 0 { + sawToolCall = true + assert.Equal(t, "lookup", m.ToolCalls[0].Function.Name) + } + if m.Role == "tool" && m.ToolCallID != "" { + sawToolMsg = true + assert.Contains(t, string(m.Content), "answer") + } + } + assert.True(t, sawToolCall, "expected an assistant tool_calls message") + assert.True(t, sawToolMsg, "expected a 
tool role message carrying the tool_result") +} + +// TestAnthropicToCCChain_ImageBlock 验证 image 块经 Anthropic→Responses→CC 后落到 +// CC 的 image_url.content 形态(base64 dataURI)。 +func TestAnthropicToCCChain_ImageBlock(t *testing.T) { + req := &AnthropicRequest{ + Model: "deepseek-chat", + MaxTokens: 64, + Messages: []AnthropicMessage{ + {Role: "user", Content: json.RawMessage(`[ + {"type":"text","text":"What is in this image?"}, + {"type":"image","source":{"type":"base64","media_type":"image/png","data":"AAA"}} + ]`)}, + }, + } + respReq, err := AnthropicToResponses(req) + require.NoError(t, err) + ccReq, err := ResponsesToChatCompletionsRequest(respReq) + require.NoError(t, err) + require.GreaterOrEqual(t, len(ccReq.Messages), 1) + bodyJSON, _ := json.Marshal(ccReq.Messages[len(ccReq.Messages)-1]) + assert.Contains(t, string(bodyJSON), "image_url") + assert.Contains(t, string(bodyJSON), "data:image/png;base64,AAA") +} + +// TestAnthropicToCCChain_MultiTurn 多轮对话历史的角色顺序应当被保持。 +func TestAnthropicToCCChain_MultiTurn(t *testing.T) { + req := &AnthropicRequest{ + Model: "deepseek-chat", + MaxTokens: 64, + Messages: []AnthropicMessage{ + {Role: "user", Content: json.RawMessage(`"q1"`)}, + {Role: "assistant", Content: json.RawMessage(`"a1"`)}, + {Role: "user", Content: json.RawMessage(`"q2"`)}, + }, + } + respReq, err := AnthropicToResponses(req) + require.NoError(t, err) + ccReq, err := ResponsesToChatCompletionsRequest(respReq) + require.NoError(t, err) + require.GreaterOrEqual(t, len(ccReq.Messages), 3) + roles := []string{} + for _, m := range ccReq.Messages { + if m.Role == "system" { + continue + } + roles = append(roles, m.Role) + } + assert.Equal(t, []string{"user", "assistant", "user"}, roles) +} diff --git a/backend/internal/pkg/apicompat/cc_chunks_to_anthropic_chain_test.go b/backend/internal/pkg/apicompat/cc_chunks_to_anthropic_chain_test.go new file mode 100644 index 00000000000..2f116539ada --- /dev/null +++ 
b/backend/internal/pkg/apicompat/cc_chunks_to_anthropic_chain_test.go @@ -0,0 +1,119 @@ +package apicompat + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// chainCCToAnthropic 把一组 CC SSE 帧 → Responses events → Anthropic events 后 +// 拼接为 Anthropic SSE 文本,方便 assert。 +func chainCCToAnthropic(t *testing.T, lines []string) (string, []AnthropicStreamEvent, *ResponsesUsage) { + t.Helper() + ccState := NewCCStreamState() + anthState := NewResponsesEventToAnthropicState() + var sseOut strings.Builder + var anthEvents []AnthropicStreamEvent + + pumpFrame := func(frame []byte) { + // Each frame is "event: \ndata: \n\n" or "data: [DONE]\n\n". + for _, raw := range strings.Split(string(frame), "\n") { + line := strings.TrimSpace(raw) + if !strings.HasPrefix(line, "data:") { + continue + } + payload := strings.TrimSpace(line[len("data:"):]) + if payload == "" || payload == "[DONE]" { + continue + } + var evt ResponsesStreamEvent + if err := json.Unmarshal([]byte(payload), &evt); err != nil { + continue + } + for _, anth := range ResponsesEventToAnthropicEvents(&evt, anthState) { + anthEvents = append(anthEvents, anth) + sse, err := ResponsesAnthropicEventToSSE(anth) + require.NoError(t, err) + sseOut.WriteString(sse) + } + } + } + + for _, line := range lines { + frames, err := ConvertChatCompletionsSSEChunkToResponsesEvents([]byte(line), ccState) + if err != nil && !strings.Contains(err.Error(), "parse chat completions chunk") { + require.NoError(t, err) + } + for _, f := range frames { + pumpFrame(f) + } + } + if !ccState.CompletedSent { + for _, f := range FinalizeCCStream(ccState) { + pumpFrame(f) + } + } + for _, anth := range FinalizeResponsesAnthropicStream(anthState) { + anthEvents = append(anthEvents, anth) + sse, err := ResponsesAnthropicEventToSSE(anth) + require.NoError(t, err) + sseOut.WriteString(sse) + } + return sseOut.String(), anthEvents, ccState.Usage +} + +// 
TestCCChunksToAnthropicChain_TextDelta 验证 CC 的 text delta 链路转换为 +// Anthropic 的 message_start / content_block_start / content_block_delta 序列。 +func TestCCChunksToAnthropicChain_TextDelta(t *testing.T) { + lines := []string{ + `data: {"id":"chk_1","object":"chat.completion.chunk","model":"deepseek-chat","choices":[{"index":0,"delta":{"role":"assistant"}}]}`, + `data: {"id":"chk_1","object":"chat.completion.chunk","model":"deepseek-chat","choices":[{"index":0,"delta":{"content":"Hi"}}]}`, + `data: {"id":"chk_1","object":"chat.completion.chunk","model":"deepseek-chat","choices":[{"index":0,"delta":{"content":" there"}}]}`, + `data: {"id":"chk_1","object":"chat.completion.chunk","model":"deepseek-chat","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":3,"completion_tokens":2,"total_tokens":5}}`, + "data: [DONE]", + } + sse, events, usage := chainCCToAnthropic(t, lines) + assert.Contains(t, sse, "event: message_start") + assert.Contains(t, sse, "event: content_block_start") + assert.Contains(t, sse, "event: content_block_delta") + assert.Contains(t, sse, "event: content_block_stop") + assert.Contains(t, sse, "event: message_delta") + assert.Contains(t, sse, "event: message_stop") + require.NotNil(t, usage) + assert.Equal(t, 3, usage.InputTokens) + assert.Equal(t, 2, usage.OutputTokens) + assert.NotEmpty(t, events) +} + +// TestCCChunksToAnthropicChain_ToolCall 验证 CC 的 tool_calls.arguments 增量被 +// 翻译为 Anthropic input_json_delta + tool_use 块。 +func TestCCChunksToAnthropicChain_ToolCall(t *testing.T) { + lines := []string{ + `data: {"id":"chk","object":"chat.completion.chunk","model":"m","choices":[{"index":0,"delta":{"role":"assistant"}}]}`, + `data: {"id":"chk","object":"chat.completion.chunk","model":"m","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"call_1","function":{"name":"lookup"}}]}}]}`, + `data: 
{"id":"chk","object":"chat.completion.chunk","model":"m","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"q\":"}}]}}]}`, + `data: {"id":"chk","object":"chat.completion.chunk","model":"m","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"x\"}"}}]}}]}`, + `data: {"id":"chk","object":"chat.completion.chunk","model":"m","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`, + "data: [DONE]", + } + sse, _, _ := chainCCToAnthropic(t, lines) + assert.Contains(t, sse, "tool_use") + assert.Contains(t, sse, `"name":"lookup"`) + assert.Contains(t, sse, "input_json_delta") + assert.Contains(t, sse, "tool_use") +} + +// TestCCChunksToAnthropicChain_FinishLength 验证 finish_reason=length 在 Anthropic 侧映射为 +// stop_reason=max_tokens(与 Responses incomplete + max_output_tokens 对应)。 +func TestCCChunksToAnthropicChain_FinishLength(t *testing.T) { + lines := []string{ + `data: {"id":"c","object":"chat.completion.chunk","model":"m","choices":[{"index":0,"delta":{"role":"assistant","content":"x"}}]}`, + `data: {"id":"c","object":"chat.completion.chunk","model":"m","choices":[{"index":0,"delta":{},"finish_reason":"length"}]}`, + "data: [DONE]", + } + sse, _, _ := chainCCToAnthropic(t, lines) + assert.Contains(t, sse, "max_tokens") +} diff --git a/backend/internal/pkg/apicompat/chatcompletions_responses_bridge.go b/backend/internal/pkg/apicompat/chatcompletions_responses_bridge.go deleted file mode 100644 index 09b680c7c73..00000000000 --- a/backend/internal/pkg/apicompat/chatcompletions_responses_bridge.go +++ /dev/null @@ -1,727 +0,0 @@ -package apicompat - -import ( - "encoding/json" - "fmt" - "strings" - "time" -) - -// ResponsesToChatCompletionsRequest converts a Responses API request into a -// Chat Completions request for upstreams that only implement -// /v1/chat/completions. 
-func ResponsesToChatCompletionsRequest(req *ResponsesRequest) (*ChatCompletionsRequest, error) { - if req == nil { - return nil, fmt.Errorf("responses request is nil") - } - - messages, err := responsesInputToChatMessages(req.Instructions, req.Input) - if err != nil { - return nil, err - } - - out := &ChatCompletionsRequest{ - Model: req.Model, - Messages: messages, - MaxCompletionTokens: req.MaxOutputTokens, - Temperature: req.Temperature, - TopP: req.TopP, - Stream: req.Stream, - ServiceTier: req.ServiceTier, - } - if req.Reasoning != nil { - out.ReasoningEffort = req.Reasoning.Effort - } - if len(req.Tools) > 0 { - out.Tools = responsesToolsToChatTools(req.Tools) - } - if len(req.ToolChoice) > 0 { - out.ToolChoice = responsesToolChoiceToChatToolChoice(req.ToolChoice) - } - - return out, nil -} - -func responsesInputToChatMessages(instructions string, inputRaw json.RawMessage) ([]ChatMessage, error) { - var messages []ChatMessage - if strings.TrimSpace(instructions) != "" { - content, _ := json.Marshal(instructions) - messages = append(messages, ChatMessage{ - Role: "system", - Content: content, - }) - } - - inputRaw = bytesTrimSpace(inputRaw) - if len(inputRaw) == 0 || string(inputRaw) == "null" { - return messages, nil - } - - var inputText string - if err := json.Unmarshal(inputRaw, &inputText); err == nil { - content, _ := json.Marshal(inputText) - messages = append(messages, ChatMessage{ - Role: "user", - Content: content, - }) - return messages, nil - } - - var rawItems []json.RawMessage - if err := json.Unmarshal(inputRaw, &rawItems); err != nil { - return nil, fmt.Errorf("parse responses input: %w", err) - } - - for _, raw := range rawItems { - raw = bytesTrimSpace(raw) - if len(raw) == 0 || string(raw) == "null" { - continue - } - - var item map[string]json.RawMessage - if err := json.Unmarshal(raw, &item); err != nil { - var text string - if textErr := json.Unmarshal(raw, &text); textErr == nil { - content, _ := json.Marshal(text) - messages = 
append(messages, ChatMessage{Role: "user", Content: content}) - continue - } - return nil, fmt.Errorf("parse responses input item: %w", err) - } - - role := chatCompletionsBridgeRole(rawString(item["role"])) - itemType := rawString(item["type"]) - switch itemType { - case "function_call": - arguments := rawString(item["arguments"]) - if strings.TrimSpace(arguments) == "" { - arguments = "{}" - } - messages = append(messages, ChatMessage{ - Role: "assistant", - ToolCalls: []ChatToolCall{{ - ID: rawString(item["call_id"]), - Type: "function", - Function: ChatFunctionCall{ - Name: rawString(item["name"]), - Arguments: arguments, - }, - }}, - }) - continue - case "function_call_output": - content, _ := json.Marshal(rawString(item["output"])) - messages = append(messages, ChatMessage{ - Role: "tool", - ToolCallID: rawString(item["call_id"]), - Content: content, - }) - continue - case "input_text", "text": - content, _ := json.Marshal(rawString(item["text"])) - messages = append(messages, ChatMessage{Role: "user", Content: content}) - continue - case "input_image": - content, err := chatContentFromSingleResponsesPart(itemType, item) - if err != nil { - return nil, err - } - messages = append(messages, ChatMessage{Role: "user", Content: content}) - continue - } - - content := item["content"] - if len(bytesTrimSpace(content)) == 0 { - if text := rawString(item["text"]); text != "" { - content, _ = json.Marshal(text) - } - } - chatContent, err := responsesContentToChatContent(content, role) - if err != nil { - return nil, err - } - messages = append(messages, ChatMessage{ - Role: role, - Content: chatContent, - }) - } - - return messages, nil -} - -func chatCompletionsBridgeRole(role string) string { - trimmed := strings.TrimSpace(role) - if trimmed == "" { - return "user" - } - if strings.EqualFold(trimmed, "developer") { - return "system" - } - return role -} - -func responsesContentToChatContent(raw json.RawMessage, role string) (json.RawMessage, error) { - raw = 
bytesTrimSpace(raw) - if len(raw) == 0 || string(raw) == "null" { - empty, _ := json.Marshal("") - return empty, nil - } - - var text string - if err := json.Unmarshal(raw, &text); err == nil { - return raw, nil - } - - var rawParts []json.RawMessage - if err := json.Unmarshal(raw, &rawParts); err == nil { - return responsesContentPartsToChatContent(rawParts, role) - } - - var obj map[string]json.RawMessage - if err := json.Unmarshal(raw, &obj); err == nil { - return chatContentFromSingleResponsesPart(rawString(obj["type"]), obj) - } - - return raw, nil -} - -func responsesContentPartsToChatContent(rawParts []json.RawMessage, role string) (json.RawMessage, error) { - var textParts []string - var chatParts []ChatContentPart - hasNonText := false - - for _, rawPart := range rawParts { - var part map[string]json.RawMessage - if err := json.Unmarshal(rawPart, &part); err != nil { - continue - } - partType := rawString(part["type"]) - switch partType { - case "input_text", "output_text", "text", "": - text := rawString(part["text"]) - if text == "" { - continue - } - textParts = append(textParts, text) - chatParts = append(chatParts, ChatContentPart{Type: "text", Text: text}) - case "input_image", "image_url": - imageURL := rawString(part["image_url"]) - if imageURL == "" { - imageURL = rawNestedString(part["image_url"], "url") - } - if imageURL == "" { - continue - } - hasNonText = true - chatParts = append(chatParts, ChatContentPart{ - Type: "image_url", - ImageURL: &ChatImageURL{URL: imageURL}, - }) - } - } - - if !hasNonText { - joined, _ := json.Marshal(strings.Join(textParts, "\n\n")) - return joined, nil - } - if role != "user" { - joined, _ := json.Marshal(strings.Join(textParts, "\n\n")) - return joined, nil - } - if len(chatParts) == 0 { - empty, _ := json.Marshal("") - return empty, nil - } - return json.Marshal(chatParts) -} - -func chatContentFromSingleResponsesPart(partType string, part map[string]json.RawMessage) (json.RawMessage, error) { - switch 
partType { - case "input_image", "image_url": - imageURL := rawString(part["image_url"]) - if imageURL == "" { - imageURL = rawNestedString(part["image_url"], "url") - } - return json.Marshal([]ChatContentPart{{ - Type: "image_url", - ImageURL: &ChatImageURL{URL: imageURL}, - }}) - default: - return json.Marshal(rawString(part["text"])) - } -} - -func responsesToolsToChatTools(tools []ResponsesTool) []ChatTool { - out := make([]ChatTool, 0, len(tools)) - for _, tool := range tools { - if tool.Type != "function" { - continue - } - out = append(out, ChatTool{ - Type: "function", - Function: &ChatFunction{ - Name: tool.Name, - Description: tool.Description, - Parameters: tool.Parameters, - Strict: tool.Strict, - }, - }) - } - return out -} - -func responsesToolChoiceToChatToolChoice(raw json.RawMessage) json.RawMessage { - var choice map[string]json.RawMessage - if err := json.Unmarshal(raw, &choice); err != nil { - return raw - } - if rawString(choice["type"]) != "function" { - return raw - } - name := rawString(choice["name"]) - if name == "" { - name = rawNestedString(choice["function"], "name") - } - if name == "" { - return raw - } - out, err := json.Marshal(map[string]any{ - "type": "function", - "function": map[string]string{ - "name": name, - }, - }) - if err != nil { - return raw - } - return out -} - -// ChatCompletionsResponseToResponses converts a non-streaming Chat Completions -// response into a Responses API response. 
-func ChatCompletionsResponseToResponses(resp *ChatCompletionsResponse, model string) *ResponsesResponse { - id := "" - if resp != nil { - id = resp.ID - } - if id == "" { - id = generateResponsesID() - } - - out := &ResponsesResponse{ - ID: id, - Object: "response", - Model: model, - Status: "completed", - } - if resp == nil { - out.Output = []ResponsesOutput{emptyResponsesMessageOutput()} - return out - } - if out.Model == "" { - out.Model = resp.Model - } - - if len(resp.Choices) > 0 { - choice := resp.Choices[0] - out.Output = chatMessageToResponsesOutput(choice.Message) - if choice.FinishReason == "length" { - out.Status = "incomplete" - out.IncompleteDetails = &ResponsesIncompleteDetails{Reason: "max_output_tokens"} - } - } - if len(out.Output) == 0 { - out.Output = []ResponsesOutput{emptyResponsesMessageOutput()} - } - if resp.Usage != nil { - out.Usage = ChatUsageToResponsesUsage(resp.Usage) - } - return out -} - -func chatMessageToResponsesOutput(message ChatMessage) []ResponsesOutput { - var outputs []ResponsesOutput - if message.ReasoningContent != "" { - outputs = append(outputs, ResponsesOutput{ - Type: "reasoning", - ID: generateItemID(), - Summary: []ResponsesSummary{{ - Type: "summary_text", - Text: message.ReasoningContent, - }}, - }) - } - - text := chatMessageContentText(message.Content) - if text != "" || len(message.ToolCalls) == 0 { - outputs = append(outputs, ResponsesOutput{ - Type: "message", - ID: generateItemID(), - Role: "assistant", - Content: []ResponsesContentPart{{ - Type: "output_text", - Text: text, - }}, - Status: "completed", - }) - } - - for _, toolCall := range message.ToolCalls { - arguments := toolCall.Function.Arguments - if strings.TrimSpace(arguments) == "" { - arguments = "{}" - } - outputs = append(outputs, ResponsesOutput{ - Type: "function_call", - ID: generateItemID(), - CallID: toolCall.ID, - Name: toolCall.Function.Name, - Arguments: arguments, - Status: "completed", - }) - } - - return outputs -} - -func 
emptyResponsesMessageOutput() ResponsesOutput { - return ResponsesOutput{ - Type: "message", - ID: generateItemID(), - Role: "assistant", - Content: []ResponsesContentPart{{Type: "output_text", Text: ""}}, - Status: "completed", - } -} - -func chatMessageContentText(raw json.RawMessage) string { - raw = bytesTrimSpace(raw) - if len(raw) == 0 || string(raw) == "null" { - return "" - } - var text string - if err := json.Unmarshal(raw, &text); err == nil { - return text - } - var parts []ChatContentPart - if err := json.Unmarshal(raw, &parts); err == nil { - var texts []string - for _, part := range parts { - if part.Type == "text" && part.Text != "" { - texts = append(texts, part.Text) - } - } - return strings.Join(texts, "\n\n") - } - return "" -} - -// ChatUsageToResponsesUsage converts Chat Completions token usage to Responses -// usage shape. -func ChatUsageToResponsesUsage(usage *ChatUsage) *ResponsesUsage { - if usage == nil { - return nil - } - out := &ResponsesUsage{ - InputTokens: usage.PromptTokens, - OutputTokens: usage.CompletionTokens, - TotalTokens: usage.TotalTokens, - } - if out.TotalTokens == 0 { - out.TotalTokens = out.InputTokens + out.OutputTokens - } - if usage.PromptTokensDetails != nil && usage.PromptTokensDetails.CachedTokens > 0 { - out.InputTokensDetails = &ResponsesInputTokensDetails{ - CachedTokens: usage.PromptTokensDetails.CachedTokens, - } - } - return out -} - -// ChatCompletionsToResponsesStreamState tracks state while converting Chat -// Completions SSE chunks into Responses SSE events. -type ChatCompletionsToResponsesStreamState struct { - ResponseID string - Model string - Created int64 - SequenceNumber int - CreatedSent bool - CompletedSent bool - - MessageItemID string - Text strings.Builder - Reasoning strings.Builder - ToolCalls map[int]*ChatToolCall - - FinishReason string - Usage *ResponsesUsage -} - -// NewChatCompletionsToResponsesStreamState returns an initialized stream state. 
-func NewChatCompletionsToResponsesStreamState(model string) *ChatCompletionsToResponsesStreamState { - return &ChatCompletionsToResponsesStreamState{ - ResponseID: generateResponsesID(), - Model: model, - Created: time.Now().Unix(), - ToolCalls: make(map[int]*ChatToolCall), - } -} - -// ChatCompletionsChunkToResponsesEvents converts one Chat Completions stream -// chunk into zero or more Responses stream events. -func ChatCompletionsChunkToResponsesEvents( - chunk *ChatCompletionsChunk, - state *ChatCompletionsToResponsesStreamState, -) []ResponsesStreamEvent { - if chunk == nil || state == nil { - return nil - } - if chunk.ID != "" { - state.ResponseID = chunk.ID - } - if state.Model == "" && chunk.Model != "" { - state.Model = chunk.Model - } - if chunk.Usage != nil { - state.Usage = ChatUsageToResponsesUsage(chunk.Usage) - } - - var events []ResponsesStreamEvent - events = append(events, ensureChatToResponsesCreated(state)...) - - for _, choice := range chunk.Choices { - if choice.Delta.Content != nil { - events = append(events, ensureChatToResponsesMessageItem(state)...) 
- _, _ = state.Text.WriteString(*choice.Delta.Content) - events = append(events, chatToResponsesEvent(state, "response.output_text.delta", &ResponsesStreamEvent{ - OutputIndex: 0, - ContentIndex: 0, - Delta: *choice.Delta.Content, - ItemID: state.MessageItemID, - })) - } - if choice.Delta.ReasoningContent != nil { - _, _ = state.Reasoning.WriteString(*choice.Delta.ReasoningContent) - events = append(events, chatToResponsesEvent(state, "response.reasoning_summary_text.delta", &ResponsesStreamEvent{ - OutputIndex: 0, - SummaryIndex: 0, - Delta: *choice.Delta.ReasoningContent, - })) - } - for _, toolCall := range choice.Delta.ToolCalls { - idx := 0 - if toolCall.Index != nil { - idx = *toolCall.Index - } - stored, ok := state.ToolCalls[idx] - if !ok { - copyCall := toolCall - if copyCall.ID == "" { - copyCall.ID = generateItemID() - } - copyCall.Type = "function" - state.ToolCalls[idx] = ©Call - stored = ©Call - events = append(events, chatToResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{ - OutputIndex: idx + 1, - Item: &ResponsesOutput{ - Type: "function_call", - ID: generateItemID(), - CallID: stored.ID, - Name: stored.Function.Name, - Status: "in_progress", - }, - })) - } else { - if toolCall.ID != "" { - stored.ID = toolCall.ID - } - if toolCall.Function.Name != "" { - stored.Function.Name = toolCall.Function.Name - } - } - if toolCall.Function.Arguments != "" { - stored.Function.Arguments += toolCall.Function.Arguments - events = append(events, chatToResponsesEvent(state, "response.function_call_arguments.delta", &ResponsesStreamEvent{ - OutputIndex: idx + 1, - Delta: toolCall.Function.Arguments, - CallID: stored.ID, - Name: stored.Function.Name, - })) - } - } - if choice.FinishReason != nil && *choice.FinishReason != "" { - state.FinishReason = *choice.FinishReason - } - } - - return events -} - -// FinalizeChatCompletionsResponsesStream emits terminal Responses events. 
-func FinalizeChatCompletionsResponsesStream(state *ChatCompletionsToResponsesStreamState) []ResponsesStreamEvent { - if state == nil || state.CompletedSent { - return nil - } - var events []ResponsesStreamEvent - events = append(events, ensureChatToResponsesCreated(state)...) - if state.MessageItemID != "" { - events = append(events, chatToResponsesEvent(state, "response.output_text.done", &ResponsesStreamEvent{ - OutputIndex: 0, - ContentIndex: 0, - Text: state.Text.String(), - ItemID: state.MessageItemID, - })) - events = append(events, chatToResponsesEvent(state, "response.output_item.done", &ResponsesStreamEvent{ - OutputIndex: 0, - Item: &ResponsesOutput{ - Type: "message", - ID: state.MessageItemID, - Role: "assistant", - Status: "completed", - }, - })) - } - - status := "completed" - var incompleteDetails *ResponsesIncompleteDetails - if state.FinishReason == "length" { - status = "incomplete" - incompleteDetails = &ResponsesIncompleteDetails{Reason: "max_output_tokens"} - } - - state.CompletedSent = true - events = append(events, chatToResponsesEvent(state, "response.completed", &ResponsesStreamEvent{ - Response: &ResponsesResponse{ - ID: state.ResponseID, - Object: "response", - Model: state.Model, - Status: status, - Output: state.chatOutput(), - Usage: state.Usage, - IncompleteDetails: incompleteDetails, - }, - })) - return events -} - -func ensureChatToResponsesCreated(state *ChatCompletionsToResponsesStreamState) []ResponsesStreamEvent { - if state.CreatedSent { - return nil - } - state.CreatedSent = true - return []ResponsesStreamEvent{chatToResponsesEvent(state, "response.created", &ResponsesStreamEvent{ - Response: &ResponsesResponse{ - ID: state.ResponseID, - Object: "response", - Model: state.Model, - Status: "in_progress", - Output: []ResponsesOutput{}, - }, - })} -} - -func ensureChatToResponsesMessageItem(state *ChatCompletionsToResponsesStreamState) []ResponsesStreamEvent { - if state.MessageItemID != "" { - return nil - } - 
state.MessageItemID = generateItemID() - return []ResponsesStreamEvent{chatToResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{ - OutputIndex: 0, - Item: &ResponsesOutput{ - Type: "message", - ID: state.MessageItemID, - Role: "assistant", - Status: "in_progress", - }, - })} -} - -func (state *ChatCompletionsToResponsesStreamState) chatOutput() []ResponsesOutput { - var outputs []ResponsesOutput - if state.Reasoning.Len() > 0 { - outputs = append(outputs, ResponsesOutput{ - Type: "reasoning", - ID: generateItemID(), - Summary: []ResponsesSummary{{ - Type: "summary_text", - Text: state.Reasoning.String(), - }}, - }) - } - if state.MessageItemID != "" || len(state.ToolCalls) == 0 { - outputs = append(outputs, ResponsesOutput{ - Type: "message", - ID: nonEmpty(state.MessageItemID, generateItemID()), - Role: "assistant", - Content: []ResponsesContentPart{{ - Type: "output_text", - Text: state.Text.String(), - }}, - Status: "completed", - }) - } - for i := 0; i < len(state.ToolCalls); i++ { - toolCall, ok := state.ToolCalls[i] - if !ok || toolCall == nil { - continue - } - arguments := toolCall.Function.Arguments - if strings.TrimSpace(arguments) == "" { - arguments = "{}" - } - outputs = append(outputs, ResponsesOutput{ - Type: "function_call", - ID: generateItemID(), - CallID: toolCall.ID, - Name: toolCall.Function.Name, - Arguments: arguments, - Status: "completed", - }) - } - return outputs -} - -func chatToResponsesEvent( - state *ChatCompletionsToResponsesStreamState, - eventType string, - template *ResponsesStreamEvent, -) ResponsesStreamEvent { - seq := state.SequenceNumber - state.SequenceNumber++ - evt := *template - evt.Type = eventType - evt.SequenceNumber = seq - return evt -} - -func rawString(raw json.RawMessage) string { - raw = bytesTrimSpace(raw) - if len(raw) == 0 || string(raw) == "null" { - return "" - } - var s string - if err := json.Unmarshal(raw, &s); err == nil { - return s - } - return "" -} - -func rawNestedString(raw 
json.RawMessage, key string) string { - var obj map[string]json.RawMessage - if err := json.Unmarshal(raw, &obj); err != nil { - return "" - } - return rawString(obj[key]) -} - -func bytesTrimSpace(raw json.RawMessage) json.RawMessage { - return json.RawMessage(strings.TrimSpace(string(raw))) -} - -func nonEmpty(value, fallback string) string { - if value != "" { - return value - } - return fallback -} diff --git a/backend/internal/pkg/apicompat/chatcompletions_responses_bridge_test.go b/backend/internal/pkg/apicompat/chatcompletions_responses_bridge_test.go deleted file mode 100644 index 3e55e23a814..00000000000 --- a/backend/internal/pkg/apicompat/chatcompletions_responses_bridge_test.go +++ /dev/null @@ -1,82 +0,0 @@ -package apicompat - -import ( - "encoding/json" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestResponsesInputToChatMessages_DeveloperRoleMapsToSystem(t *testing.T) { - messages, err := responsesInputToChatMessages("", json.RawMessage(`[{"role":"developer","content":"follow project instructions"}]`)) - require.NoError(t, err) - require.Len(t, messages, 1) - - assert.Equal(t, "system", messages[0].Role) - assert.JSONEq(t, `"follow project instructions"`, string(messages[0].Content)) -} - -func TestResponsesInputToChatMessages_KeepsChatCompletionRoles(t *testing.T) { - input := json.RawMessage(`[ - {"role":"system","content":"system message"}, - {"role":"user","content":"user message"}, - {"role":"assistant","content":"assistant message"}, - {"role":"tool","content":"tool message"} - ]`) - - messages, err := responsesInputToChatMessages("", input) - require.NoError(t, err) - require.Len(t, messages, 4) - - assert.Equal(t, []string{"system", "user", "assistant", "tool"}, chatMessageRoles(messages)) -} - -func TestResponsesInputToChatMessages_EmptyRoleFallsBackToUser(t *testing.T) { - messages, err := responsesInputToChatMessages("", json.RawMessage(`[{"role":"","content":"hello"}]`)) - 
require.NoError(t, err) - require.Len(t, messages, 1) - - assert.Equal(t, "user", messages[0].Role) -} - -func TestResponsesInputToChatMessages_DeveloperRoleTrimAndCaseInsensitive(t *testing.T) { - input := json.RawMessage(`[ - {"role":" Developer ","content":"one"}, - {"role":"\tDEVELOPER\n","content":"two"} - ]`) - - messages, err := responsesInputToChatMessages("", input) - require.NoError(t, err) - require.Len(t, messages, 2) - - assert.Equal(t, []string{"system", "system"}, chatMessageRoles(messages)) -} - -func TestResponsesToChatCompletionsRequest_InstructionsAndInputDeveloperRole(t *testing.T) { - req := &ResponsesRequest{ - Model: "gpt-4o", - Instructions: "Use concise answers.", - Input: json.RawMessage(`[ - {"role":"developer","content":[{"type":"input_text","text":"Prefer JSON."}]}, - {"role":"user","content":"Hello"} - ]`), - } - - out, err := ResponsesToChatCompletionsRequest(req) - require.NoError(t, err) - require.Len(t, out.Messages, 3) - - assert.Equal(t, []string{"system", "system", "user"}, chatMessageRoles(out.Messages)) - assert.JSONEq(t, `"Use concise answers."`, string(out.Messages[0].Content)) - assert.JSONEq(t, `"Prefer JSON."`, string(out.Messages[1].Content)) - assert.JSONEq(t, `"Hello"`, string(out.Messages[2].Content)) -} - -func chatMessageRoles(messages []ChatMessage) []string { - roles := make([]string, 0, len(messages)) - for _, message := range messages { - roles = append(roles, message.Role) - } - return roles -} diff --git a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go index b03b012fc7a..cf059594640 100644 --- a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go +++ b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go @@ -1160,6 +1160,34 @@ func TestChatChunkToSSE(t *testing.T) { assert.True(t, len(sse) > 10) } +func TestConvertChatCompletionsSSEChunkToResponsesEvents_TextContentPartLifecycle(t *testing.T) { + state 
:= NewCCStreamState() + state.Model = "gpt-4o" + + content := "hello" + events, err := ConvertChatCompletionsSSEChunkToResponsesEvents([]byte(`data: {"id":"chatcmpl_test","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"hello"},"finish_reason":null}]}`), state) + require.NoError(t, err) + require.Len(t, events, 4) + assert.Contains(t, string(events[0]), "event: response.created") + assert.Contains(t, string(events[1]), "event: response.output_item.added") + assert.Contains(t, string(events[2]), "event: response.content_part.added") + assert.Contains(t, string(events[2]), `"type":"output_text"`) + assert.Contains(t, string(events[3]), "event: response.output_text.delta") + assert.Contains(t, string(events[3]), content) + + events, err = ConvertChatCompletionsSSEChunkToResponsesEvents([]byte(`data: {"id":"chatcmpl_test","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`), state) + require.NoError(t, err) + require.Len(t, events, 4) + assert.Contains(t, string(events[0]), "event: response.output_text.done") + assert.Contains(t, string(events[0]), content) + assert.Contains(t, string(events[1]), "event: response.content_part.done") + assert.Contains(t, string(events[1]), content) + assert.Contains(t, string(events[2]), "event: response.output_item.done") + assert.Contains(t, string(events[2]), content) + assert.Contains(t, string(events[3]), "event: response.completed") + assert.Contains(t, string(events[3]), content) +} + // --------------------------------------------------------------------------- // Stream round-trip test // --------------------------------------------------------------------------- diff --git a/backend/internal/pkg/apicompat/chatcompletions_to_responses_response.go b/backend/internal/pkg/apicompat/chatcompletions_to_responses_response.go new file mode 100644 index 00000000000..39dfba0ee13 --- /dev/null +++ 
b/backend/internal/pkg/apicompat/chatcompletions_to_responses_response.go @@ -0,0 +1,1159 @@ +package apicompat + +import ( + "bytes" + "encoding/json" + "fmt" + "strings" + "time" +) + +// --------------------------------------------------------------------------- +// Non-streaming: ChatCompletionsResponse → ResponsesResponse +// --------------------------------------------------------------------------- + +// ChatCompletionsToResponsesResponse converts a Chat Completions response into a +// Responses API response. choices[0].message becomes the output items, and usage +// is mapped to the Responses usage shape. +// +// Mapping rules: +// - choices[0].message.reasoning_content → reasoning output item (with summary_text) +// - choices[0].message.content (string) → message output item with output_text part +// - choices[0].message.tool_calls[] → function_call output items +// - choices[0].finish_reason → status / incomplete_details +// - usage.prompt_tokens / completion_tokens → usage.input_tokens / output_tokens +// - usage.prompt_tokens_details.cached_tokens → usage.input_tokens_details.cached_tokens +func ChatCompletionsToResponsesResponse(resp *ChatCompletionsResponse, model string) *ResponsesResponse { + id := resp.ID + if id == "" { + id = generateResponsesID() + } + if model == "" { + model = resp.Model + } + + out := &ResponsesResponse{ + ID: id, + Object: "response", + Model: model, + } + + var outputs []ResponsesOutput + finishReason := "" + + if len(resp.Choices) > 0 { + choice := resp.Choices[0] + finishReason = choice.FinishReason + msg := choice.Message + + if msg.ReasoningContent != "" { + outputs = append(outputs, ResponsesOutput{ + Type: "reasoning", + ID: generateItemID(), + Summary: []ResponsesSummary{{ + Type: "summary_text", + Text: msg.ReasoningContent, + }}, + }) + } + + text := extractChatMessageText(msg.Content) + if text != "" { + outputs = append(outputs, ResponsesOutput{ + Type: "message", + ID: generateItemID(), + Role: "assistant", + 
Content: []ResponsesContentPart{{ + Type: "output_text", + Text: text, + }}, + Status: "completed", + }) + } + + for _, tc := range msg.ToolCalls { + args := tc.Function.Arguments + if args == "" { + args = "{}" + } + callID := tc.ID + outputs = append(outputs, ResponsesOutput{ + Type: "function_call", + ID: generateItemID(), + CallID: callID, + Name: tc.Function.Name, + Arguments: args, + Status: "completed", + }) + } + + if msg.FunctionCall != nil { + args := msg.FunctionCall.Arguments + if args == "" { + args = "{}" + } + outputs = append(outputs, ResponsesOutput{ + Type: "function_call", + ID: generateItemID(), + CallID: generateCallID(), + Name: msg.FunctionCall.Name, + Arguments: args, + Status: "completed", + }) + } + } + + if len(outputs) == 0 { + outputs = append(outputs, ResponsesOutput{ + Type: "message", + ID: generateItemID(), + Role: "assistant", + Content: []ResponsesContentPart{{Type: "output_text", Text: ""}}, + Status: "completed", + }) + } + out.Output = outputs + + out.Status = chatFinishReasonToResponsesStatus(finishReason) + if out.Status == "incomplete" { + reason := "max_output_tokens" + if finishReason == "content_filter" { + reason = "content_filter" + } + out.IncompleteDetails = &ResponsesIncompleteDetails{Reason: reason} + } + + if resp.Usage != nil { + out.Usage = &ResponsesUsage{ + InputTokens: resp.Usage.PromptTokens, + OutputTokens: resp.Usage.CompletionTokens, + TotalTokens: resp.Usage.TotalTokens, + } + if out.Usage.TotalTokens == 0 { + out.Usage.TotalTokens = out.Usage.InputTokens + out.Usage.OutputTokens + } + if resp.Usage.PromptTokensDetails != nil && resp.Usage.PromptTokensDetails.CachedTokens > 0 { + out.Usage.InputTokensDetails = &ResponsesInputTokensDetails{ + CachedTokens: resp.Usage.PromptTokensDetails.CachedTokens, + } + } + if resp.Usage.CompletionTokensDetails != nil && resp.Usage.CompletionTokensDetails.ReasoningTokens > 0 { + out.Usage.OutputTokensDetails = &ResponsesOutputTokensDetails{ + ReasoningTokens: 
resp.Usage.CompletionTokensDetails.ReasoningTokens, + } + } + } + + return out +} + +// extractChatMessageText returns the textual content of a Chat message Content +// field. Content may be either a JSON string or an array of typed parts; for +// arrays, text parts are concatenated and non-text parts are ignored. +func extractChatMessageText(raw json.RawMessage) string { + if len(raw) == 0 { + return "" + } + var s string + if err := json.Unmarshal(raw, &s); err == nil { + return s + } + var parts []ChatContentPart + if err := json.Unmarshal(raw, &parts); err == nil { + var b strings.Builder + for _, p := range parts { + if p.Type == "text" && p.Text != "" { + b.WriteString(p.Text) + } + } + return b.String() + } + return "" +} + +// chatFinishReasonToResponsesStatus maps a Chat Completions finish_reason to +// the Responses API status field. +func chatFinishReasonToResponsesStatus(finishReason string) string { + switch finishReason { + case "length": + return "incomplete" + case "content_filter": + return "incomplete" + case "stop", "tool_calls", "function_call", "": + return "completed" + default: + return "completed" + } +} + +// --------------------------------------------------------------------------- +// Streaming: SSE chunk (Chat Completions) → []ResponsesStreamEvent (stateful) +// --------------------------------------------------------------------------- + +// CCStreamState tracks state for converting a stream of Chat Completions SSE +// chunks into Responses SSE events. +type CCStreamState struct { + ResponseID string + Model string + OutputIndex int + + CreatedSent bool + CompletedSent bool + + // Message item tracking — opened lazily on first text/reasoning delta. + MessageItemID string + MessageOpen bool + MessageOutputIdx int + ContentIndex int + MessageText string + + // Reasoning item tracking — opened lazily on first reasoning delta. 
+ ReasoningItemID string + ReasoningOpen bool + ReasoningOutputIdx int + ReasoningSummaryIdx int + ReasoningText string + + // Tool calls indexed by Chat tool_call index → state. + ToolCalls map[int]*ccToolCallState + + // CompletedOutputs keeps terminal response.output complete even when items + // were emitted incrementally in earlier SSE events. + CompletedOutputs []ResponsesOutput + + // Sequence number for emitted Responses events. + SequenceNumber int + + // Final finish reason / usage captured from the closing chunk. + FinishReason string + Usage *ResponsesUsage +} + +type ccToolCallState struct { + ItemID string + CallID string + Name string + Arguments string + OutputIndex int + Opened bool +} + +// NewCCStreamState returns an initialised stream state. +func NewCCStreamState() *CCStreamState { + return &CCStreamState{ + ToolCalls: make(map[int]*ccToolCallState), + } +} + +// ConvertChatCompletionsSSEChunkToResponsesEvents parses a single SSE line from +// a Chat Completions stream (the raw bytes of one "data: ..." line, with or +// without trailing newline) and emits the corresponding Responses SSE events as +// "event: \ndata: \n\n" formatted byte slices. +// +// Supported inputs: +// - "data: {chunk JSON}" +// - "data: [DONE]" → finalises the stream (emits response.completed if not +// already sent, plus a literal "data: [DONE]\n\n" terminator) +// - blank lines / event lines without a payload → ignored +// - non-data lines (comments, SSE event:) → passed through unchanged so that +// any upstream noise reaches the client transparently +// +// Errors are returned only when JSON parsing fails for a non-[DONE] data line. +// In that case the original bytes are still returned (transparent passthrough) +// so the caller can decide whether to abort or forward. 
+func ConvertChatCompletionsSSEChunkToResponsesEvents(chunk []byte, state *CCStreamState) ([][]byte, error) { + line := bytes.TrimRight(chunk, "\r\n") + if len(line) == 0 { + return nil, nil + } + + // Pass through SSE comments and event: lines untouched. + if bytes.HasPrefix(line, []byte(":")) || bytes.HasPrefix(line, []byte("event:")) { + return [][]byte{appendNewlines(line)}, nil + } + + if !bytes.HasPrefix(line, []byte("data:")) { + // Unknown line shape — pass through so we don't lose information. + return [][]byte{appendNewlines(line)}, nil + } + + payload := bytes.TrimSpace(line[len("data:"):]) + if len(payload) == 0 { + return nil, nil + } + + if bytes.Equal(payload, []byte("[DONE]")) { + return finalizeCCStream(state), nil + } + + var chk ChatCompletionsChunk + if err := json.Unmarshal(payload, &chk); err != nil { + // Transparent passthrough for malformed chunks; surface the error so + // callers can log it without dropping the data. + return [][]byte{appendNewlines(line)}, fmt.Errorf("parse chat completions chunk: %w", err) + } + + events := convertChatChunk(&chk, state) + out := make([][]byte, 0, len(events)) + for _, evt := range events { + b, err := marshalResponsesEvent(evt) + if err != nil { + return out, err + } + out = append(out, b) + } + return out, nil +} + +// FinalizeCCStream emits any closing events that have not yet been written +// (response.completed + [DONE]) when the upstream stream ended without an +// explicit [DONE] marker. Idempotent. +func FinalizeCCStream(state *CCStreamState) [][]byte { + return finalizeCCStream(state) +} + +func finalizeCCStream(state *CCStreamState) [][]byte { + var out [][]byte + + if !state.CompletedSent { + if state.ResponseID == "" { + state.ResponseID = generateResponsesID() + } + if !state.CreatedSent { + created := makeCCCreatedEvent(state) + if b, err := marshalResponsesEvent(created); err == nil { + out = append(out, b) + } + state.CreatedSent = true + } + // Close any still-open items. 
+ for _, evt := range closeCCOpenItems(state) { + if b, err := marshalResponsesEvent(evt); err == nil { + out = append(out, b) + } + } + completed := makeCCCompletedEvent(state) + if b, err := marshalResponsesEvent(completed); err == nil { + out = append(out, b) + } + state.CompletedSent = true + } + + out = append(out, []byte("data: [DONE]\n\n")) + return out +} + +func convertChatChunk(chk *ChatCompletionsChunk, state *CCStreamState) []ResponsesStreamEvent { + var events []ResponsesStreamEvent + + if state.ResponseID == "" { + if chk.ID != "" { + state.ResponseID = chk.ID + } else { + state.ResponseID = generateResponsesID() + } + } + if chk.Model != "" { + state.Model = chk.Model + } + + if chk.Usage != nil { + state.Usage = &ResponsesUsage{ + InputTokens: chk.Usage.PromptTokens, + OutputTokens: chk.Usage.CompletionTokens, + TotalTokens: chk.Usage.TotalTokens, + } + if state.Usage.TotalTokens == 0 { + state.Usage.TotalTokens = state.Usage.InputTokens + state.Usage.OutputTokens + } + if chk.Usage.PromptTokensDetails != nil && chk.Usage.PromptTokensDetails.CachedTokens > 0 { + state.Usage.InputTokensDetails = &ResponsesInputTokensDetails{ + CachedTokens: chk.Usage.PromptTokensDetails.CachedTokens, + } + } + if chk.Usage.CompletionTokensDetails != nil && chk.Usage.CompletionTokensDetails.ReasoningTokens > 0 { + state.Usage.OutputTokensDetails = &ResponsesOutputTokensDetails{ + ReasoningTokens: chk.Usage.CompletionTokensDetails.ReasoningTokens, + } + } + } + + if len(chk.Choices) == 0 { + return nil + } + + if !state.CreatedSent { + events = append(events, makeCCCreatedEvent(state)) + state.CreatedSent = true + } + + for _, ch := range chk.Choices { + // Reasoning delta + if ch.Delta.ReasoningContent != nil && *ch.Delta.ReasoningContent != "" { + if !state.ReasoningOpen { + state.ReasoningItemID = generateItemID() + state.ReasoningOpen = true + state.ReasoningOutputIdx = state.OutputIndex + state.OutputIndex++ + state.ReasoningText = "" + events = append(events, 
makeCCEvent(state, "response.output_item.added", ResponsesStreamEvent{ + OutputIndex: state.ReasoningOutputIdx, + Item: &ResponsesOutput{ + Type: "reasoning", + ID: state.ReasoningItemID, + }, + })) + } + events = append(events, makeCCEvent(state, "response.reasoning_summary_text.delta", ResponsesStreamEvent{ + OutputIndex: state.ReasoningOutputIdx, + SummaryIndex: state.ReasoningSummaryIdx, + Delta: *ch.Delta.ReasoningContent, + ItemID: state.ReasoningItemID, + })) + state.ReasoningText += *ch.Delta.ReasoningContent + } + + // Text content delta + if ch.Delta.Content != nil && *ch.Delta.Content != "" { + // If reasoning was open, close it before opening message. + if state.ReasoningOpen { + events = append(events, closeCCReasoningItem(state)...) + } + if !state.MessageOpen { + state.MessageItemID = generateItemID() + state.MessageOpen = true + state.MessageOutputIdx = state.OutputIndex + state.OutputIndex++ + state.MessageText = "" + events = append(events, makeCCEvent(state, "response.output_item.added", ResponsesStreamEvent{ + OutputIndex: state.MessageOutputIdx, + Item: &ResponsesOutput{ + Type: "message", + ID: state.MessageItemID, + Role: "assistant", + Status: "in_progress", + }, + })) + events = append(events, makeCCEvent(state, "response.content_part.added", ResponsesStreamEvent{ + OutputIndex: state.MessageOutputIdx, + ContentIndex: state.ContentIndex, + ItemID: state.MessageItemID, + Part: &ResponsesContentPart{ + Type: "output_text", + Text: "", + }, + })) + } + state.MessageText += *ch.Delta.Content + events = append(events, makeCCEvent(state, "response.output_text.delta", ResponsesStreamEvent{ + OutputIndex: state.MessageOutputIdx, + ContentIndex: state.ContentIndex, + Delta: *ch.Delta.Content, + ItemID: state.MessageItemID, + })) + } + + // Tool call deltas + for _, tc := range ch.Delta.ToolCalls { + idx := 0 + if tc.Index != nil { + idx = *tc.Index + } + st, ok := state.ToolCalls[idx] + if !ok { + st = &ccToolCallState{} + state.ToolCalls[idx] = st 
+ } + if tc.ID != "" { + st.CallID = tc.ID + } + if tc.Function.Name != "" { + st.Name = tc.Function.Name + } + if !st.Opened && (st.CallID != "" || st.Name != "" || tc.Function.Arguments != "") { + // Close message/reasoning before opening function_call so that + // output indices stay consistent with the order they appear. + if state.MessageOpen { + events = append(events, closeCCMessageItem(state)...) + } + if state.ReasoningOpen { + events = append(events, closeCCReasoningItem(state)...) + } + if st.CallID == "" { + st.CallID = generateCallID() + } + st.ItemID = generateItemID() + st.OutputIndex = state.OutputIndex + state.OutputIndex++ + st.Opened = true + events = append(events, makeCCEvent(state, "response.output_item.added", ResponsesStreamEvent{ + OutputIndex: st.OutputIndex, + Item: &ResponsesOutput{ + Type: "function_call", + ID: st.ItemID, + CallID: st.CallID, + Name: st.Name, + Status: "in_progress", + }, + })) + } + if tc.Function.Arguments != "" && st.Opened { + st.Arguments += tc.Function.Arguments + events = append(events, makeCCEvent(state, "response.function_call_arguments.delta", ResponsesStreamEvent{ + OutputIndex: st.OutputIndex, + Delta: tc.Function.Arguments, + ItemID: st.ItemID, + CallID: st.CallID, + Name: st.Name, + })) + } + } + + // Finish reason → close items + completed + if ch.FinishReason != nil && *ch.FinishReason != "" { + state.FinishReason = *ch.FinishReason + events = append(events, closeCCOpenItems(state)...) + if !state.CompletedSent { + events = append(events, makeCCCompletedEvent(state)) + state.CompletedSent = true + } + } + } + + return events +} + +func closeCCOpenItems(state *CCStreamState) []ResponsesStreamEvent { + var events []ResponsesStreamEvent + if state.MessageOpen { + events = append(events, closeCCMessageItem(state)...) + } + if state.ReasoningOpen { + events = append(events, closeCCReasoningItem(state)...) 
+ } + for _, idx := range sortedToolCallKeys(state.ToolCalls) { + st := state.ToolCalls[idx] + if st.Opened { + events = append(events, closeCCToolCall(state, st)...) + } + } + return events +} + +func closeCCMessageItem(state *CCStreamState) []ResponsesStreamEvent { + if !state.MessageOpen { + return nil + } + events := []ResponsesStreamEvent{ + makeCCEvent(state, "response.output_text.done", ResponsesStreamEvent{ + OutputIndex: state.MessageOutputIdx, + ContentIndex: state.ContentIndex, + ItemID: state.MessageItemID, + Text: state.MessageText, + }), + makeCCEvent(state, "response.content_part.done", ResponsesStreamEvent{ + OutputIndex: state.MessageOutputIdx, + ContentIndex: state.ContentIndex, + ItemID: state.MessageItemID, + Part: &ResponsesContentPart{ + Type: "output_text", + Text: state.MessageText, + }, + }), + makeCCEvent(state, "response.output_item.done", ResponsesStreamEvent{ + OutputIndex: state.MessageOutputIdx, + Item: &ResponsesOutput{ + Type: "message", + ID: state.MessageItemID, + Role: "assistant", + Status: "completed", + Content: []ResponsesContentPart{{ + Type: "output_text", + Text: state.MessageText, + }}, + }, + }), + } + state.MessageOpen = false + state.CompletedOutputs = append(state.CompletedOutputs, ResponsesOutput{ + Type: "message", + ID: state.MessageItemID, + Role: "assistant", + Status: "completed", + Content: []ResponsesContentPart{{ + Type: "output_text", + Text: state.MessageText, + }}, + }) + return events +} + +func closeCCReasoningItem(state *CCStreamState) []ResponsesStreamEvent { + if !state.ReasoningOpen { + return nil + } + events := []ResponsesStreamEvent{ + makeCCEvent(state, "response.reasoning_summary_text.done", ResponsesStreamEvent{ + OutputIndex: state.ReasoningOutputIdx, + SummaryIndex: state.ReasoningSummaryIdx, + ItemID: state.ReasoningItemID, + Text: state.ReasoningText, + }), + makeCCEvent(state, "response.output_item.done", ResponsesStreamEvent{ + OutputIndex: state.ReasoningOutputIdx, + Item: 
&ResponsesOutput{ + Type: "reasoning", + ID: state.ReasoningItemID, + Status: "completed", + Summary: []ResponsesSummary{{ + Type: "summary_text", + Text: state.ReasoningText, + }}, + }, + }), + } + state.ReasoningOpen = false + state.CompletedOutputs = append(state.CompletedOutputs, ResponsesOutput{ + Type: "reasoning", + ID: state.ReasoningItemID, + Status: "completed", + Summary: []ResponsesSummary{{ + Type: "summary_text", + Text: state.ReasoningText, + }}, + }) + return events +} + +func closeCCToolCall(state *CCStreamState, st *ccToolCallState) []ResponsesStreamEvent { + args := st.Arguments + if args == "" { + args = "{}" + } + events := []ResponsesStreamEvent{ + makeCCEvent(state, "response.function_call_arguments.done", ResponsesStreamEvent{ + OutputIndex: st.OutputIndex, + ItemID: st.ItemID, + CallID: st.CallID, + Name: st.Name, + Arguments: args, + }), + makeCCEvent(state, "response.output_item.done", ResponsesStreamEvent{ + OutputIndex: st.OutputIndex, + Item: &ResponsesOutput{ + Type: "function_call", + ID: st.ItemID, + CallID: st.CallID, + Name: st.Name, + Arguments: args, + Status: "completed", + }, + }), + } + st.Opened = false + state.CompletedOutputs = append(state.CompletedOutputs, ResponsesOutput{ + Type: "function_call", + ID: st.ItemID, + CallID: st.CallID, + Name: st.Name, + Arguments: args, + Status: "completed", + }) + return events +} + +func sortedToolCallKeys(m map[int]*ccToolCallState) []int { + keys := make([]int, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + // simple insertion sort to keep deterministic order without importing sort. 
+ for i := 1; i < len(keys); i++ { + for j := i; j > 0 && keys[j-1] > keys[j]; j-- { + keys[j-1], keys[j] = keys[j], keys[j-1] + } + } + return keys +} + +func makeCCCreatedEvent(state *CCStreamState) ResponsesStreamEvent { + seq := state.SequenceNumber + state.SequenceNumber++ + return ResponsesStreamEvent{ + Type: "response.created", + SequenceNumber: seq, + Response: &ResponsesResponse{ + ID: state.ResponseID, + Object: "response", + Model: state.Model, + Status: "in_progress", + Output: []ResponsesOutput{}, + }, + } +} + +func makeCCCompletedEvent(state *CCStreamState) ResponsesStreamEvent { + seq := state.SequenceNumber + state.SequenceNumber++ + + status := chatFinishReasonToResponsesStatus(state.FinishReason) + var incompleteDetails *ResponsesIncompleteDetails + if status == "incomplete" { + reason := "max_output_tokens" + if state.FinishReason == "content_filter" { + reason = "content_filter" + } + incompleteDetails = &ResponsesIncompleteDetails{Reason: reason} + } + + output := append([]ResponsesOutput(nil), state.CompletedOutputs...) 
+ + return ResponsesStreamEvent{ + Type: "response.completed", + SequenceNumber: seq, + Response: &ResponsesResponse{ + ID: state.ResponseID, + Object: "response", + Model: state.Model, + Status: status, + Output: output, + Usage: state.Usage, + IncompleteDetails: incompleteDetails, + }, + } +} + +func generateCallID() string { + id := generateItemID() + return "call_" + strings.TrimPrefix(id, "item_") +} + +func makeCCEvent(state *CCStreamState, eventType string, template ResponsesStreamEvent) ResponsesStreamEvent { + seq := state.SequenceNumber + state.SequenceNumber++ + template.Type = eventType + template.SequenceNumber = seq + return template +} + +func marshalResponsesEvent(evt ResponsesStreamEvent) ([]byte, error) { + data, err := json.Marshal(evt) + if err != nil { + return nil, err + } + return []byte(fmt.Sprintf("event: %s\ndata: %s\n\n", evt.Type, data)), nil +} + +func appendNewlines(line []byte) []byte { + out := make([]byte, 0, len(line)+2) + out = append(out, line...) + out = append(out, '\n', '\n') + return out +} + +func ChatCompletionsResponseToResponses(resp *ChatCompletionsResponse, model string) *ResponsesResponse { + id := "" + if resp != nil { + id = resp.ID + } + if id == "" { + id = generateResponsesID() + } + + out := &ResponsesResponse{ + ID: id, + Object: "response", + Model: model, + Status: "completed", + } + if resp == nil { + out.Output = []ResponsesOutput{emptyResponsesMessageOutput()} + return out + } + if out.Model == "" { + out.Model = resp.Model + } + + if len(resp.Choices) > 0 { + choice := resp.Choices[0] + out.Output = chatMessageToResponsesOutput(choice.Message) + if choice.FinishReason == "length" { + out.Status = "incomplete" + out.IncompleteDetails = &ResponsesIncompleteDetails{Reason: "max_output_tokens"} + } + } + if len(out.Output) == 0 { + out.Output = []ResponsesOutput{emptyResponsesMessageOutput()} + } + if resp.Usage != nil { + out.Usage = ChatUsageToResponsesUsage(resp.Usage) + } + return out +} + +func 
chatMessageToResponsesOutput(message ChatMessage) []ResponsesOutput { + var outputs []ResponsesOutput + if message.ReasoningContent != "" { + outputs = append(outputs, ResponsesOutput{ + Type: "reasoning", + ID: generateItemID(), + Summary: []ResponsesSummary{{ + Type: "summary_text", + Text: message.ReasoningContent, + }}, + }) + } + + text := chatMessageContentText(message.Content) + if text != "" || len(message.ToolCalls) == 0 { + outputs = append(outputs, ResponsesOutput{ + Type: "message", + ID: generateItemID(), + Role: "assistant", + Content: []ResponsesContentPart{{ + Type: "output_text", + Text: text, + }}, + Status: "completed", + }) + } + + for _, toolCall := range message.ToolCalls { + arguments := toolCall.Function.Arguments + if strings.TrimSpace(arguments) == "" { + arguments = "{}" + } + outputs = append(outputs, ResponsesOutput{ + Type: "function_call", + ID: generateItemID(), + CallID: toolCall.ID, + Name: toolCall.Function.Name, + Arguments: arguments, + Status: "completed", + }) + } + + return outputs +} + +func emptyResponsesMessageOutput() ResponsesOutput { + return ResponsesOutput{ + Type: "message", + ID: generateItemID(), + Role: "assistant", + Content: []ResponsesContentPart{{Type: "output_text", Text: ""}}, + Status: "completed", + } +} + +func chatMessageContentText(raw json.RawMessage) string { + raw = bytesTrimSpace(raw) + if len(raw) == 0 || string(raw) == "null" { + return "" + } + var text string + if err := json.Unmarshal(raw, &text); err == nil { + return text + } + var parts []ChatContentPart + if err := json.Unmarshal(raw, &parts); err == nil { + var texts []string + for _, part := range parts { + if part.Type == "text" && part.Text != "" { + texts = append(texts, part.Text) + } + } + return strings.Join(texts, "\n\n") + } + return "" +} + +// ChatUsageToResponsesUsage converts Chat Completions token usage to Responses +// usage shape. 
+func ChatUsageToResponsesUsage(usage *ChatUsage) *ResponsesUsage { + if usage == nil { + return nil + } + out := &ResponsesUsage{ + InputTokens: usage.PromptTokens, + OutputTokens: usage.CompletionTokens, + TotalTokens: usage.TotalTokens, + } + if out.TotalTokens == 0 { + out.TotalTokens = out.InputTokens + out.OutputTokens + } + if usage.PromptTokensDetails != nil && usage.PromptTokensDetails.CachedTokens > 0 { + out.InputTokensDetails = &ResponsesInputTokensDetails{ + CachedTokens: usage.PromptTokensDetails.CachedTokens, + } + } + return out +} + +// ChatCompletionsToResponsesStreamState tracks state while converting Chat +// Completions SSE chunks into Responses SSE events. +type ChatCompletionsToResponsesStreamState struct { + ResponseID string + Model string + Created int64 + SequenceNumber int + CreatedSent bool + CompletedSent bool + + MessageItemID string + Text strings.Builder + Reasoning strings.Builder + ToolCalls map[int]*ChatToolCall + + FinishReason string + Usage *ResponsesUsage +} + +// NewChatCompletionsToResponsesStreamState returns an initialized stream state. +func NewChatCompletionsToResponsesStreamState(model string) *ChatCompletionsToResponsesStreamState { + return &ChatCompletionsToResponsesStreamState{ + ResponseID: generateResponsesID(), + Model: model, + Created: time.Now().Unix(), + ToolCalls: make(map[int]*ChatToolCall), + } +} + +// ChatCompletionsChunkToResponsesEvents converts one Chat Completions stream +// chunk into zero or more Responses stream events. 
+func ChatCompletionsChunkToResponsesEvents( + chunk *ChatCompletionsChunk, + state *ChatCompletionsToResponsesStreamState, +) []ResponsesStreamEvent { + if chunk == nil || state == nil { + return nil + } + if chunk.ID != "" { + state.ResponseID = chunk.ID + } + if state.Model == "" && chunk.Model != "" { + state.Model = chunk.Model + } + if chunk.Usage != nil { + state.Usage = ChatUsageToResponsesUsage(chunk.Usage) + } + + var events []ResponsesStreamEvent + events = append(events, ensureChatToResponsesCreated(state)...) + + for _, choice := range chunk.Choices { + if choice.Delta.Content != nil { + events = append(events, ensureChatToResponsesMessageItem(state)...) + _, _ = state.Text.WriteString(*choice.Delta.Content) + events = append(events, chatToResponsesEvent(state, "response.output_text.delta", &ResponsesStreamEvent{ + OutputIndex: 0, + ContentIndex: 0, + Delta: *choice.Delta.Content, + ItemID: state.MessageItemID, + })) + } + if choice.Delta.ReasoningContent != nil { + _, _ = state.Reasoning.WriteString(*choice.Delta.ReasoningContent) + events = append(events, chatToResponsesEvent(state, "response.reasoning_summary_text.delta", &ResponsesStreamEvent{ + OutputIndex: 0, + SummaryIndex: 0, + Delta: *choice.Delta.ReasoningContent, + })) + } + for _, toolCall := range choice.Delta.ToolCalls { + idx := 0 + if toolCall.Index != nil { + idx = *toolCall.Index + } + stored, ok := state.ToolCalls[idx] + if !ok { + copyCall := toolCall + if copyCall.ID == "" { + copyCall.ID = generateItemID() + } + copyCall.Type = "function" + state.ToolCalls[idx] = ©Call + stored = ©Call + events = append(events, chatToResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{ + OutputIndex: idx + 1, + Item: &ResponsesOutput{ + Type: "function_call", + ID: generateItemID(), + CallID: stored.ID, + Name: stored.Function.Name, + Status: "in_progress", + }, + })) + } else { + if toolCall.ID != "" { + stored.ID = toolCall.ID + } + if toolCall.Function.Name != "" { + 
stored.Function.Name = toolCall.Function.Name + } + } + if toolCall.Function.Arguments != "" { + stored.Function.Arguments += toolCall.Function.Arguments + events = append(events, chatToResponsesEvent(state, "response.function_call_arguments.delta", &ResponsesStreamEvent{ + OutputIndex: idx + 1, + Delta: toolCall.Function.Arguments, + CallID: stored.ID, + Name: stored.Function.Name, + })) + } + } + if choice.FinishReason != nil && *choice.FinishReason != "" { + state.FinishReason = *choice.FinishReason + } + } + + return events +} + +// FinalizeChatCompletionsResponsesStream emits terminal Responses events. +func FinalizeChatCompletionsResponsesStream(state *ChatCompletionsToResponsesStreamState) []ResponsesStreamEvent { + if state == nil || state.CompletedSent { + return nil + } + var events []ResponsesStreamEvent + events = append(events, ensureChatToResponsesCreated(state)...) + if state.MessageItemID != "" { + events = append(events, chatToResponsesEvent(state, "response.output_text.done", &ResponsesStreamEvent{ + OutputIndex: 0, + ContentIndex: 0, + Text: state.Text.String(), + ItemID: state.MessageItemID, + })) + events = append(events, chatToResponsesEvent(state, "response.output_item.done", &ResponsesStreamEvent{ + OutputIndex: 0, + Item: &ResponsesOutput{ + Type: "message", + ID: state.MessageItemID, + Role: "assistant", + Status: "completed", + }, + })) + } + + status := "completed" + var incompleteDetails *ResponsesIncompleteDetails + if state.FinishReason == "length" { + status = "incomplete" + incompleteDetails = &ResponsesIncompleteDetails{Reason: "max_output_tokens"} + } + + state.CompletedSent = true + events = append(events, chatToResponsesEvent(state, "response.completed", &ResponsesStreamEvent{ + Response: &ResponsesResponse{ + ID: state.ResponseID, + Object: "response", + Model: state.Model, + Status: status, + Output: state.chatOutput(), + Usage: state.Usage, + IncompleteDetails: incompleteDetails, + }, + })) + return events +} + +func 
ensureChatToResponsesCreated(state *ChatCompletionsToResponsesStreamState) []ResponsesStreamEvent { + if state.CreatedSent { + return nil + } + state.CreatedSent = true + return []ResponsesStreamEvent{chatToResponsesEvent(state, "response.created", &ResponsesStreamEvent{ + Response: &ResponsesResponse{ + ID: state.ResponseID, + Object: "response", + Model: state.Model, + Status: "in_progress", + Output: []ResponsesOutput{}, + }, + })} +} + +func ensureChatToResponsesMessageItem(state *ChatCompletionsToResponsesStreamState) []ResponsesStreamEvent { + if state.MessageItemID != "" { + return nil + } + state.MessageItemID = generateItemID() + return []ResponsesStreamEvent{chatToResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{ + OutputIndex: 0, + Item: &ResponsesOutput{ + Type: "message", + ID: state.MessageItemID, + Role: "assistant", + Status: "in_progress", + }, + })} +} + +func (state *ChatCompletionsToResponsesStreamState) chatOutput() []ResponsesOutput { + var outputs []ResponsesOutput + if state.Reasoning.Len() > 0 { + outputs = append(outputs, ResponsesOutput{ + Type: "reasoning", + ID: generateItemID(), + Summary: []ResponsesSummary{{ + Type: "summary_text", + Text: state.Reasoning.String(), + }}, + }) + } + if state.MessageItemID != "" || len(state.ToolCalls) == 0 { + outputs = append(outputs, ResponsesOutput{ + Type: "message", + ID: nonEmpty(state.MessageItemID, generateItemID()), + Role: "assistant", + Content: []ResponsesContentPart{{ + Type: "output_text", + Text: state.Text.String(), + }}, + Status: "completed", + }) + } + for i := 0; i < len(state.ToolCalls); i++ { + toolCall, ok := state.ToolCalls[i] + if !ok || toolCall == nil { + continue + } + arguments := toolCall.Function.Arguments + if strings.TrimSpace(arguments) == "" { + arguments = "{}" + } + outputs = append(outputs, ResponsesOutput{ + Type: "function_call", + ID: generateItemID(), + CallID: toolCall.ID, + Name: toolCall.Function.Name, + Arguments: arguments, + 
Status: "completed", + }) + } + return outputs +} + +func chatToResponsesEvent( + state *ChatCompletionsToResponsesStreamState, + eventType string, + template *ResponsesStreamEvent, +) ResponsesStreamEvent { + seq := state.SequenceNumber + state.SequenceNumber++ + evt := *template + evt.Type = eventType + evt.SequenceNumber = seq + return evt +} + +func bytesTrimSpace(raw json.RawMessage) json.RawMessage { + return bytes.TrimSpace(raw) +} + +func nonEmpty(value, fallback string) string { + if strings.TrimSpace(value) != "" { + return value + } + return fallback +} diff --git a/backend/internal/pkg/apicompat/chatcompletions_to_responses_response_test.go b/backend/internal/pkg/apicompat/chatcompletions_to_responses_response_test.go new file mode 100644 index 00000000000..ef2a3838b80 --- /dev/null +++ b/backend/internal/pkg/apicompat/chatcompletions_to_responses_response_test.go @@ -0,0 +1,424 @@ +package apicompat + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// Non-streaming tests +// --------------------------------------------------------------------------- + +func TestChatCompletionsToResponsesResponse_BasicText(t *testing.T) { + contentRaw, _ := json.Marshal("Hello, world!") + resp := &ChatCompletionsResponse{ + ID: "chatcmpl-1", + Model: "gpt-4o", + Choices: []ChatChoice{ + { + Index: 0, + Message: ChatMessage{ + Role: "assistant", + Content: contentRaw, + }, + FinishReason: "stop", + }, + }, + Usage: &ChatUsage{PromptTokens: 7, CompletionTokens: 4, TotalTokens: 11}, + } + + out := ChatCompletionsToResponsesResponse(resp, "gpt-4o") + assert.Equal(t, "response", out.Object) + assert.Equal(t, "gpt-4o", out.Model) + assert.Equal(t, "completed", out.Status) + require.Len(t, out.Output, 1) + assert.Equal(t, "message", out.Output[0].Type) + assert.Equal(t, "assistant", out.Output[0].Role) + 
require.Len(t, out.Output[0].Content, 1) + assert.Equal(t, "output_text", out.Output[0].Content[0].Type) + assert.Equal(t, "Hello, world!", out.Output[0].Content[0].Text) + + require.NotNil(t, out.Usage) + assert.Equal(t, 7, out.Usage.InputTokens) + assert.Equal(t, 4, out.Usage.OutputTokens) + assert.Equal(t, 11, out.Usage.TotalTokens) +} + +func TestChatCompletionsToResponsesResponse_ToolCalls(t *testing.T) { + resp := &ChatCompletionsResponse{ + ID: "chatcmpl-tc", + Choices: []ChatChoice{ + { + Message: ChatMessage{ + Role: "assistant", + ToolCalls: []ChatToolCall{ + { + ID: "call_abc", + Type: "function", + Function: ChatFunctionCall{ + Name: "get_weather", + Arguments: `{"city":"NYC"}`, + }, + }, + }, + }, + FinishReason: "tool_calls", + }, + }, + } + + out := ChatCompletionsToResponsesResponse(resp, "gpt-4o") + assert.Equal(t, "completed", out.Status) + require.Len(t, out.Output, 1) + assert.Equal(t, "function_call", out.Output[0].Type) + assert.Equal(t, "call_abc", out.Output[0].CallID) + assert.Equal(t, "get_weather", out.Output[0].Name) + assert.Equal(t, `{"city":"NYC"}`, out.Output[0].Arguments) +} + +func TestChatCompletionsToResponsesResponse_LegacyFunctionCall(t *testing.T) { + resp := &ChatCompletionsResponse{ + ID: "chatcmpl-fn", + Choices: []ChatChoice{{ + Message: ChatMessage{ + Role: "assistant", + FunctionCall: &ChatFunctionCall{ + Name: "get_weather", + Arguments: `{"city":"SF"}`, + }, + }, + FinishReason: "function_call", + }}, + } + + out := ChatCompletionsToResponsesResponse(resp, "gpt-4o") + require.Len(t, out.Output, 1) + assert.Equal(t, "function_call", out.Output[0].Type) + assert.NotEmpty(t, out.Output[0].CallID) + assert.Equal(t, "get_weather", out.Output[0].Name) + assert.Equal(t, `{"city":"SF"}`, out.Output[0].Arguments) +} + +func TestChatCompletionsToResponsesResponse_LengthFinish(t *testing.T) { + contentRaw, _ := json.Marshal("partial...") + resp := &ChatCompletionsResponse{ + Choices: []ChatChoice{ + { + Message: ChatMessage{Role: 
"assistant", Content: contentRaw}, + FinishReason: "length", + }, + }, + } + out := ChatCompletionsToResponsesResponse(resp, "gpt-4o") + assert.Equal(t, "incomplete", out.Status) + require.NotNil(t, out.IncompleteDetails) + assert.Equal(t, "max_output_tokens", out.IncompleteDetails.Reason) +} + +func TestChatCompletionsToResponsesResponse_Reasoning(t *testing.T) { + contentRaw, _ := json.Marshal("the answer is 42") + resp := &ChatCompletionsResponse{ + Choices: []ChatChoice{ + { + Message: ChatMessage{ + Role: "assistant", + Content: contentRaw, + ReasoningContent: "I considered the question carefully.", + }, + FinishReason: "stop", + }, + }, + } + out := ChatCompletionsToResponsesResponse(resp, "gpt-4o") + require.Len(t, out.Output, 2) + assert.Equal(t, "reasoning", out.Output[0].Type) + require.Len(t, out.Output[0].Summary, 1) + assert.Equal(t, "I considered the question carefully.", out.Output[0].Summary[0].Text) + assert.Equal(t, "message", out.Output[1].Type) + assert.Equal(t, "the answer is 42", out.Output[1].Content[0].Text) +} + +func TestChatCompletionsToResponsesResponse_CachedTokens(t *testing.T) { + contentRaw, _ := json.Marshal("cached") + resp := &ChatCompletionsResponse{ + Choices: []ChatChoice{ + {Message: ChatMessage{Role: "assistant", Content: contentRaw}, FinishReason: "stop"}, + }, + Usage: &ChatUsage{ + PromptTokens: 100, + CompletionTokens: 10, + TotalTokens: 110, + PromptTokensDetails: &ChatTokenDetails{CachedTokens: 80}, + }, + } + out := ChatCompletionsToResponsesResponse(resp, "gpt-4o") + require.NotNil(t, out.Usage) + require.NotNil(t, out.Usage.InputTokensDetails) + assert.Equal(t, 80, out.Usage.InputTokensDetails.CachedTokens) +} + +func TestChatCompletionsToResponsesResponse_ReasoningTokenUsage(t *testing.T) { + contentRaw, _ := json.Marshal("answer") + resp := &ChatCompletionsResponse{ + Choices: []ChatChoice{{Message: ChatMessage{Role: "assistant", Content: contentRaw}, FinishReason: "stop"}}, + Usage: &ChatUsage{ + PromptTokens: 3, 
+ CompletionTokens: 7, + TotalTokens: 10, + CompletionTokensDetails: &ChatCompletionTokenDetails{ReasoningTokens: 5}, + }, + } + out := ChatCompletionsToResponsesResponse(resp, "gpt-4o") + require.NotNil(t, out.Usage) + require.NotNil(t, out.Usage.OutputTokensDetails) + assert.Equal(t, 5, out.Usage.OutputTokensDetails.ReasoningTokens) +} + +func TestChatCompletionsToResponsesResponse_EmptyChoicesFallback(t *testing.T) { + resp := &ChatCompletionsResponse{ID: "chatcmpl-x"} + out := ChatCompletionsToResponsesResponse(resp, "gpt-4o") + require.Len(t, out.Output, 1) + assert.Equal(t, "message", out.Output[0].Type) +} + +// --------------------------------------------------------------------------- +// Streaming tests +// --------------------------------------------------------------------------- + +func mustEvent(t *testing.T, raw []byte) (string, ResponsesStreamEvent) { + t.Helper() + s := string(raw) + require.True(t, strings.HasPrefix(s, "event: "), "missing event prefix: %q", s) + idx := strings.Index(s, "\ndata: ") + require.GreaterOrEqual(t, idx, 0) + eventType := strings.TrimPrefix(s[:idx], "event: ") + payload := strings.TrimSuffix(s[idx+len("\ndata: "):], "\n\n") + var evt ResponsesStreamEvent + require.NoError(t, json.Unmarshal([]byte(payload), &evt)) + return eventType, evt +} + +func sseChunk(payload string) []byte { + return []byte("data: " + payload) +} + +func TestChatCompletionsToResponses_Stream_TextSequence(t *testing.T) { + state := NewCCStreamState() + + // First chunk: role only (often opens the stream). 
+ first := `{"id":"chatcmpl-stream","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}` + out, err := ConvertChatCompletionsSSEChunkToResponsesEvents(sseChunk(first), state) + require.NoError(t, err) + require.Len(t, out, 1) + typ, evt := mustEvent(t, out[0]) + assert.Equal(t, "response.created", typ) + require.NotNil(t, evt.Response) + assert.Equal(t, "chatcmpl-stream", evt.Response.ID) + assert.True(t, state.CreatedSent) + + // Text deltas + for _, txt := range []string{"Hello", ", ", "world"} { + chunk := `{"id":"chatcmpl-stream","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":` + jsonString(txt) + `},"finish_reason":null}]}` + out, err = ConvertChatCompletionsSSEChunkToResponsesEvents(sseChunk(chunk), state) + require.NoError(t, err) + } + // First text delta should also have produced response.output_item.added. + // Verify by re-running for the first text via fresh state — already covered via assertions below. 
+ + // Final chunk: finish_reason + usage + final := `{"id":"chatcmpl-stream","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":3,"total_tokens":8}}` + out, err = ConvertChatCompletionsSSEChunkToResponsesEvents(sseChunk(final), state) + require.NoError(t, err) + + // Expect output_text.done, output_item.done (message), response.completed + types := make([]string, 0, len(out)) + for _, b := range out { + typ, _ := mustEvent(t, b) + types = append(types, typ) + } + assert.Contains(t, types, "response.output_text.done") + assert.Contains(t, types, "response.output_item.done") + assert.Contains(t, types, "response.completed") + assert.True(t, state.CompletedSent) + + // Verify completed carries usage + for _, b := range out { + typ, evt := mustEvent(t, b) + if typ == "response.completed" { + require.NotNil(t, evt.Response) + require.NotNil(t, evt.Response.Usage) + assert.Equal(t, 5, evt.Response.Usage.InputTokens) + assert.Equal(t, 3, evt.Response.Usage.OutputTokens) + } + } + + // [DONE] should add the terminator only. + out, err = ConvertChatCompletionsSSEChunkToResponsesEvents(sseChunk("[DONE]"), state) + require.NoError(t, err) + require.Len(t, out, 1) + assert.Equal(t, "data: [DONE]\n\n", string(out[0])) +} + +func TestChatCompletionsToResponses_Stream_ToolCalls(t *testing.T) { + state := NewCCStreamState() + + // Chunk 1: opens with role. + c1 := `{"id":"chatcmpl-tc","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}` + _, err := ConvertChatCompletionsSSEChunkToResponsesEvents(sseChunk(c1), state) + require.NoError(t, err) + + // Chunk 2: tool call header (id + name). 
+ c2 := `{"id":"chatcmpl-tc","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"call_1","type":"function","function":{"name":"get_weather","arguments":""}}]},"finish_reason":null}]}` + out, err := ConvertChatCompletionsSSEChunkToResponsesEvents(sseChunk(c2), state) + require.NoError(t, err) + var sawAdded bool + for _, b := range out { + typ, evt := mustEvent(t, b) + if typ == "response.output_item.added" { + require.NotNil(t, evt.Item) + if evt.Item.Type == "function_call" { + assert.Equal(t, "call_1", evt.Item.CallID) + assert.Equal(t, "get_weather", evt.Item.Name) + sawAdded = true + } + } + } + assert.True(t, sawAdded, "expected function_call output_item.added") + + // Chunk 3: arguments delta. + c3 := `{"id":"chatcmpl-tc","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"name":"","arguments":"{\"city\":"}}]},"finish_reason":null}]}` + out, err = ConvertChatCompletionsSSEChunkToResponsesEvents(sseChunk(c3), state) + require.NoError(t, err) + var sawArgsDelta bool + for _, b := range out { + typ, evt := mustEvent(t, b) + if typ == "response.function_call_arguments.delta" { + assert.Equal(t, `{"city":`, evt.Delta) + sawArgsDelta = true + } + } + assert.True(t, sawArgsDelta, "expected function_call_arguments.delta") + + // Chunk 4: finish. 
+ c4 := `{"id":"chatcmpl-tc","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}` + out, err = ConvertChatCompletionsSSEChunkToResponsesEvents(sseChunk(c4), state) + require.NoError(t, err) + types := make([]string, 0, len(out)) + for _, b := range out { + typ, _ := mustEvent(t, b) + types = append(types, typ) + } + assert.Contains(t, types, "response.function_call_arguments.done") + assert.Contains(t, types, "response.output_item.done") + assert.Contains(t, types, "response.completed") + + for _, b := range out { + typ, evt := mustEvent(t, b) + if typ == "response.function_call_arguments.done" { + assert.Equal(t, `{"city":`, evt.Arguments) + } + if typ == "response.completed" { + require.NotNil(t, evt.Response) + require.Len(t, evt.Response.Output, 1) + assert.Equal(t, "function_call", evt.Response.Output[0].Type) + assert.Equal(t, `{"city":`, evt.Response.Output[0].Arguments) + } + } +} + +func TestChatCompletionsToResponses_Stream_UsageOnlyChunkDoesNotCreateResponse(t *testing.T) { + state := NewCCStreamState() + usageOnly := `{"id":"chatcmpl-usage","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[],"usage":{"prompt_tokens":9,"completion_tokens":2,"total_tokens":11,"completion_tokens_details":{"reasoning_tokens":1}}}` + out, err := ConvertChatCompletionsSSEChunkToResponsesEvents(sseChunk(usageOnly), state) + require.NoError(t, err) + require.Empty(t, out) + assert.False(t, state.CreatedSent) + require.NotNil(t, state.Usage) + assert.Equal(t, 9, state.Usage.InputTokens) + require.NotNil(t, state.Usage.OutputTokensDetails) + assert.Equal(t, 1, state.Usage.OutputTokensDetails.ReasoningTokens) + + out, err = ConvertChatCompletionsSSEChunkToResponsesEvents(sseChunk("[DONE]"), state) + require.NoError(t, err) + require.GreaterOrEqual(t, len(out), 3) + typ, evt := mustEvent(t, out[0]) + assert.Equal(t, "response.created", typ) + require.NotNil(t, evt.Response) + 
assert.Equal(t, "chatcmpl-usage", evt.Response.ID) +} + +func TestChatCompletionsToResponses_Stream_ToolArgumentsBeforeHeader(t *testing.T) { + state := NewCCStreamState() + chunk := `{"id":"chatcmpl-tc2","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\\\"x\\\":1}"}}]},"finish_reason":null}]}` + out, err := ConvertChatCompletionsSSEChunkToResponsesEvents(sseChunk(chunk), state) + require.NoError(t, err) + var sawAdded bool + for _, b := range out { + typ, evt := mustEvent(t, b) + if typ == "response.output_item.added" && evt.Item != nil && evt.Item.Type == "function_call" { + sawAdded = true + assert.NotEmpty(t, evt.Item.CallID) + } + } + assert.True(t, sawAdded) +} + +func TestChatCompletionsToResponses_Stream_Done(t *testing.T) { + state := NewCCStreamState() + state.CompletedSent = true // pretend we already finalized. + out, err := ConvertChatCompletionsSSEChunkToResponsesEvents(sseChunk("[DONE]"), state) + require.NoError(t, err) + require.Len(t, out, 1) + assert.Equal(t, "data: [DONE]\n\n", string(out[0])) +} + +func TestChatCompletionsToResponses_Stream_DoneWithoutFinish(t *testing.T) { + // Stream ends with [DONE] before any explicit finish_reason — finalize must + // emit a synthetic response.completed. 
+ state := NewCCStreamState() + state.CreatedSent = true + state.ResponseID = "resp_x" + + out, err := ConvertChatCompletionsSSEChunkToResponsesEvents(sseChunk("[DONE]"), state) + require.NoError(t, err) + // Expect at least: response.completed + [DONE] + require.GreaterOrEqual(t, len(out), 2) + last := out[len(out)-1] + assert.Equal(t, "data: [DONE]\n\n", string(last)) + typ, _ := mustEvent(t, out[len(out)-2]) + assert.Equal(t, "response.completed", typ) +} + +func TestChatCompletionsToResponses_Stream_MalformedChunkPassthrough(t *testing.T) { + state := NewCCStreamState() + out, err := ConvertChatCompletionsSSEChunkToResponsesEvents(sseChunk("not json"), state) + require.Error(t, err) + require.Len(t, out, 1) + assert.Equal(t, "data: not json\n\n", string(out[0])) +} + +func TestChatCompletionsToResponses_Stream_BlankLineIgnored(t *testing.T) { + state := NewCCStreamState() + out, err := ConvertChatCompletionsSSEChunkToResponsesEvents([]byte(""), state) + require.NoError(t, err) + require.Empty(t, out) +} + +func TestChatCompletionsToResponses_Stream_FinalizeIdempotent(t *testing.T) { + state := NewCCStreamState() + state.CreatedSent = true + state.CompletedSent = true + a := FinalizeCCStream(state) + require.Len(t, a, 1) + assert.Equal(t, "data: [DONE]\n\n", string(a[0])) +} + +// jsonString returns a JSON-encoded string literal (with quotes and escaping). 
// jsonString encodes s as a JSON string literal, including the surrounding
// quotes and whatever escaping json.Marshal applies.
func jsonString(s string) string {
	// Marshaling a plain Go string cannot fail, so the error is discarded.
	encoded, _ := json.Marshal(s)
	return string(encoded)
}
// ResponsesToChatCompletionsRequest converts an OpenAI Responses API request
// into a Chat Completions request. This is the reverse of
// ChatCompletionsToResponses and enables Chat-Completions-only upstreams to
// accept Responses API traffic by translating it back to the native
// /v1/chat/completions format before forwarding.
//
// It is shorthand for ResponsesToChatCompletionsRequestWithOptions called with
// the zero ConvertResponsesOptions, i.e. reasoning_effort is kept and the
// non-standard instructions field is not emitted.
//
// Field mapping:
//   - input (string or array) → messages (system/user/assistant/tool roles restored)
//   - instructions            → leading system message
//   - tools                   → tools
//   - tool_choice             → tool_choice
//   - temperature             → temperature
//   - top_p                   → top_p
//   - max_output_tokens       → max_tokens (only when > 0)
//   - stream                  → stream
//   - reasoning.effort        → reasoning_effort
//   - service_tier            → service_tier
//   - parallel_tool_calls     → parallel_tool_calls
//
// Fields without a Chat Completions counterpart (previous_response_id,
// prompt_cache_key, include, store, text, reasoning.summary) are not
// forwarded; each one that is set is reported in a debug log to help with
// troubleshooting.
func ResponsesToChatCompletionsRequest(req *ResponsesRequest) (*ChatCompletionsRequest, error) {
	return ResponsesToChatCompletionsRequestWithOptions(req, ConvertResponsesOptions{})
}
+func ResponsesToChatCompletionsRequestWithOptions(req *ResponsesRequest, opts ConvertResponsesOptions) (*ChatCompletionsRequest, error) { + if req == nil { + return nil, fmt.Errorf("nil ResponsesRequest") + } + + messages, err := convertResponsesInputToChatMessages(req.Input) + if err != nil { + return nil, fmt.Errorf("convert input: %w", err) + } + + // instructions 在 Chat Completions 中没有等价独立字段。为保持语义完整, + // 默认在最前置插入一条 system 消息;非标准 instructions 字段仅在显式 + // PreserveInstructionsField 时透传,避免严格 CC-only upstream 拒绝请求。 + if req.Instructions != "" { + raw, err := json.Marshal(req.Instructions) + if err != nil { + return nil, fmt.Errorf("marshal instructions: %w", err) + } + sysMsg := ChatMessage{Role: "system", Content: raw} + messages = append([]ChatMessage{sysMsg}, messages...) + } + + out := &ChatCompletionsRequest{ + Model: req.Model, + Messages: messages, + Temperature: req.Temperature, + TopP: req.TopP, + Stop: req.Stop, + User: req.User, + Metadata: req.Metadata, + Seed: req.Seed, + PresencePenalty: req.PresencePenalty, + FrequencyPenalty: req.FrequencyPenalty, + Logprobs: req.Logprobs, + TopLogprobs: req.TopLogprobs, + Stream: req.Stream, + StreamOptions: req.StreamOptions, + ServiceTier: req.ServiceTier, + ParallelToolCalls: req.ParallelToolCalls, + } + if opts.PreserveInstructionsField { + out.Instructions = req.Instructions + } + + if req.MaxOutputTokens != nil && *req.MaxOutputTokens > 0 { + v := *req.MaxOutputTokens + out.MaxTokens = &v + } + + if req.Reasoning != nil && req.Reasoning.Effort != "" { + if opts.StripReasoningEffort { + slog.Debug("apicompat: stripping reasoning_effort per account config") + } else { + out.ReasoningEffort = req.Reasoning.Effort + } + } + + if len(req.Tools) > 0 { + out.Tools = convertResponsesToolsToChat(req.Tools) + } + + if len(req.ToolChoice) > 0 { + toolChoice, err := convertResponsesToolChoiceToChat(req.ToolChoice) + if err != nil { + return nil, fmt.Errorf("convert tool_choice: %w", err) + } + out.ToolChoice = 
toolChoice + } + + logIgnoredResponsesFields(req) + + return out, nil +} + +// convertResponsesInputToChatMessages 把 Responses API 的 input 还原为 Chat +// Completions 的 messages 数组。input 既可能是字符串(等价于一条 user 消息), +// 也可能是 ResponsesInputItem 数组。 +func convertResponsesInputToChatMessages(raw json.RawMessage) ([]ChatMessage, error) { + if len(raw) == 0 { + return nil, nil + } + + // Plain string input → single user message. + var s string + if err := json.Unmarshal(raw, &s); err == nil { + content, err := json.Marshal(s) + if err != nil { + return nil, err + } + return []ChatMessage{{Role: "user", Content: content}}, nil + } + + var items []ResponsesInputItem + if err := json.Unmarshal(raw, &items); err != nil { + return nil, fmt.Errorf("input is neither string nor array: %w", err) + } + + out := make([]ChatMessage, 0, len(items)) + // 多个 assistant function_call 在 Chat Completions 中需要合并到同一条 + // assistant 消息的 tool_calls 数组里;用游标跟踪最后一条 assistant 消息。 + var lastAssistantIdx int = -1 + + for _, it := range items { + switch it.Type { + case "function_call": + tc := ChatToolCall{ + ID: it.CallID, + Type: "function", + Function: ChatFunctionCall{ + Name: it.Name, + Arguments: it.Arguments, + }, + } + if lastAssistantIdx >= 0 && out[lastAssistantIdx].Role == "assistant" { + out[lastAssistantIdx].ToolCalls = append(out[lastAssistantIdx].ToolCalls, tc) + } else { + out = append(out, ChatMessage{ + Role: "assistant", + ToolCalls: []ChatToolCall{tc}, + }) + lastAssistantIdx = len(out) - 1 + } + case "function_call_output": + content, err := json.Marshal(it.Output) + if err != nil { + return nil, err + } + out = append(out, ChatMessage{ + Role: "tool", + ToolCallID: it.CallID, + Content: content, + }) + lastAssistantIdx = -1 + case "", "message": + role := it.Role + if role == "" { + role = "user" + } + // Responses API 把 system 角色叫 "developer",回退兼容。 + if role == "developer" { + role = "system" + } + content, err := convertResponsesContentToChat(role, it.Content) + if err != nil { + 
return nil, err + } + msg := ChatMessage{Role: role, Content: content} + out = append(out, msg) + if role == "assistant" { + lastAssistantIdx = len(out) - 1 + } else { + lastAssistantIdx = -1 + } + default: + // 未识别 type 兜底为 user 文本,避免静默丢失。 + slog.Debug("apicompat: unknown Responses input item type, fallback to user message", + slog.String("type", it.Type)) + content, err := convertResponsesContentToChat("user", it.Content) + if err != nil { + return nil, err + } + out = append(out, ChatMessage{Role: "user", Content: content}) + lastAssistantIdx = -1 + } + } + + return normalizeChatMessagesForToolCallPairs(out), nil +} + +// normalizeChatMessagesForToolCallPairs enforces the Chat Completions invariant +// required by strict upstreams: an assistant message containing tool_calls must +// be immediately followed by tool messages for those tool_call IDs. Responses +// histories can contain pending function_call items without corresponding +// function_call_output items; forwarding those verbatim causes upstream 400s +// such as "An assistant message with 'tool_calls' must be followed by tool +// messages". Unanswered tool calls are removed from the assistant message, and +// orphan tool messages are preserved as user-visible text instead of illegal +// role=tool messages. +func normalizeChatMessagesForToolCallPairs(messages []ChatMessage) []ChatMessage { + out := make([]ChatMessage, 0, len(messages)) + + for i := 0; i < len(messages); i++ { + msg := messages[i] + if msg.Role == "tool" { + // Only degrade to a user message when there is no call_id — a tool + // message that carries a proper ToolCallID may appear at the start + // of a conversation history (e.g. a lone function_call_output from + // the Responses API) and should be forwarded as-is. 
+ if msg.ToolCallID == "" { + out = append(out, orphanToolMessageAsUser(msg)) + } else { + out = append(out, msg) + } + continue + } + if msg.Role != "assistant" || len(msg.ToolCalls) == 0 { + out = append(out, msg) + continue + } + + j := i + 1 + for j < len(messages) && messages[j].Role == "tool" { + j++ + } + toolMessages := messages[i+1 : j] + matchingToolMessages := make(map[string][]ChatMessage, len(toolMessages)) + for _, toolMsg := range toolMessages { + if toolMsg.ToolCallID == "" { + continue + } + matchingToolMessages[toolMsg.ToolCallID] = append(matchingToolMessages[toolMsg.ToolCallID], toolMsg) + } + + validIDs := make(map[string]bool, len(msg.ToolCalls)) + validCalls := make([]ChatToolCall, 0, len(msg.ToolCalls)) + for _, tc := range msg.ToolCalls { + if tc.ID == "" { + slog.Debug("apicompat: dropping assistant tool_call without id") + continue + } + if len(matchingToolMessages[tc.ID]) == 0 { + slog.Debug("apicompat: dropping unanswered assistant tool_call", + slog.String("tool_call_id", tc.ID), + slog.String("name", tc.Function.Name)) + continue + } + validIDs[tc.ID] = true + validCalls = append(validCalls, tc) + } + + if len(validCalls) > 0 { + msg.ToolCalls = validCalls + out = append(out, msg) + } else if chatMessageHasNonEmptyContent(msg) { + msg.ToolCalls = nil + out = append(out, msg) + } + + for _, toolMsg := range toolMessages { + if validIDs[toolMsg.ToolCallID] { + out = append(out, toolMsg) + } else { + out = append(out, orphanToolMessageAsUser(toolMsg)) + } + } + i = j - 1 + } + + return out +} + +func chatMessageHasNonEmptyContent(msg ChatMessage) bool { + if len(msg.Content) == 0 || string(msg.Content) == "null" { + return false + } + var s string + if err := json.Unmarshal(msg.Content, &s); err == nil { + return s != "" + } + return true +} + +func orphanToolMessageAsUser(msg ChatMessage) ChatMessage { + toolOutput := "" + if len(msg.Content) > 0 { + if err := json.Unmarshal(msg.Content, &toolOutput); err != nil { + toolOutput = 
string(msg.Content) + } + } + text := "Tool result" + if msg.ToolCallID != "" { + text += " for " + msg.ToolCallID + } + if toolOutput != "" { + text += ": " + toolOutput + } + content, _ := json.Marshal(text) + return ChatMessage{Role: "user", Content: content} +} + +// convertResponsesContentToChat 把 Responses 的 message.content 还原为 Chat +// Completions 的 content。两种形式:字符串原样透传;对象数组转换为 +// ChatContentPart 数组(input_text/output_text → text;input_image → image_url)。 +func convertResponsesContentToChat(role string, raw json.RawMessage) (json.RawMessage, error) { + if len(raw) == 0 { + return nil, nil + } + + var s string + if err := json.Unmarshal(raw, &s); err == nil { + return json.Marshal(s) + } + + var parts []ResponsesContentPart + if err := json.Unmarshal(raw, &parts); err != nil { + return nil, fmt.Errorf("parse content as string or parts array: %w", err) + } + + // assistant 角色:仅取文本拼接为字符串,保持与 Chat Completions 习惯一致。 + if role == "assistant" { + var b string + for _, p := range parts { + if p.Type == "output_text" || p.Type == "input_text" { + b += p.Text + } + } + return json.Marshal(b) + } + + chatParts := make([]ChatContentPart, 0, len(parts)) + for _, p := range parts { + switch p.Type { + case "input_text", "output_text", "text": + if p.Text == "" { + continue + } + chatParts = append(chatParts, ChatContentPart{Type: "text", Text: p.Text}) + case "input_image": + if p.ImageURL == "" { + continue + } + chatParts = append(chatParts, ChatContentPart{ + Type: "image_url", + ImageURL: &ChatImageURL{URL: p.ImageURL}, + }) + default: + slog.Debug("apicompat: unknown Responses content part type", + slog.String("type", p.Type)) + } + } + + // 单一 text part 折叠为字符串,以贴近 Chat Completions 的常见形态。 + if len(chatParts) == 1 && chatParts[0].Type == "text" { + return json.Marshal(chatParts[0].Text) + } + return json.Marshal(chatParts) +} + +// convertResponsesToolsToChat 把 Responses 的 tool 定义还原为 Chat Completions +// 形态(type=function, 
// convertResponsesToolChoiceToChat maps Responses tool_choice variants to Chat
// Completions variants.
//
//   - Empty input or JSON null means "no tool_choice" and yields (nil, nil).
//   - String choices are compatible and are re-encoded as a JSON string.
//   - The Responses object form {"type":"function","name":"foo"} becomes
//     Chat's {"type":"function","function":{"name":"foo"}}.
//   - Objects that already carry a "function" key, objects whose type is not
//     "function", and function objects without a usable name are returned
//     unchanged.
//
// An error is returned when raw is neither a string, null, nor a JSON object.
func convertResponsesToolChoiceToChat(raw json.RawMessage) (json.RawMessage, error) {
	if len(raw) == 0 {
		return nil, nil
	}

	// Decode into *string rather than string: JSON null unmarshals into a
	// plain string without error and would otherwise be forwarded as the
	// invalid choice "".
	var s *string
	if err := json.Unmarshal(raw, &s); err == nil {
		if s == nil {
			return nil, nil
		}
		return json.Marshal(*s)
	}

	var obj map[string]json.RawMessage
	if err := json.Unmarshal(raw, &obj); err != nil {
		return nil, err
	}

	// Decode errors are ignored on purpose: a non-string "type" leaves typ
	// empty and falls into the pass-through branch below.
	var typ string
	if v := obj["type"]; len(v) > 0 {
		_ = json.Unmarshal(v, &typ)
	}
	if typ != "function" {
		// Built-in/server-side tool choices have no Chat Completions equivalent;
		// pass through so permissive upstreams can still decide what to do.
		return raw, nil
	}
	if _, ok := obj["function"]; ok {
		// Already in Chat Completions shape.
		return raw, nil
	}

	// Same best-effort decoding as for "type": an unusable name means the
	// object cannot be rewritten, so it is passed through as-is.
	var name string
	if v := obj["name"]; len(v) > 0 {
		_ = json.Unmarshal(v, &name)
	}
	if name == "" {
		return raw, nil
	}
	return json.Marshal(map[string]any{
		"type": "function",
		"function": map[string]any{
			"name": name,
		},
	})
}
:= json.Marshal(v) + if err != nil { + t.Fatalf("marshal: %v", err) + } + return b +} + +func TestResponsesToChatCompletionsRequest_TextOnlySingleMessage(t *testing.T) { + req := &ResponsesRequest{ + Model: "gpt-4o", + Input: mustMarshal(t, []ResponsesInputItem{ + {Role: "user", Content: mustMarshal(t, "hello world")}, + }), + } + + got, err := ResponsesToChatCompletionsRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got.Model != "gpt-4o" { + t.Errorf("model: got %q want %q", got.Model, "gpt-4o") + } + if len(got.Messages) != 1 { + t.Fatalf("messages len: got %d want 1", len(got.Messages)) + } + if got.Messages[0].Role != "user" { + t.Errorf("role: got %q want user", got.Messages[0].Role) + } + var s string + if err := json.Unmarshal(got.Messages[0].Content, &s); err != nil { + t.Fatalf("unmarshal content: %v", err) + } + if s != "hello world" { + t.Errorf("content: got %q want %q", s, "hello world") + } +} + +func TestResponsesToChatCompletionsRequest_StringInput(t *testing.T) { + req := &ResponsesRequest{ + Model: "gpt-4o", + Input: mustMarshal(t, "ping"), + } + + got, err := ResponsesToChatCompletionsRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(got.Messages) != 1 || got.Messages[0].Role != "user" { + t.Fatalf("expected single user message, got %+v", got.Messages) + } + var s string + _ = json.Unmarshal(got.Messages[0].Content, &s) + if s != "ping" { + t.Errorf("content: got %q want ping", s) + } +} + +func TestResponsesToChatCompletionsRequest_MultiMessageHistory(t *testing.T) { + req := &ResponsesRequest{ + Model: "gpt-4o", + Input: mustMarshal(t, []ResponsesInputItem{ + {Role: "system", Content: mustMarshal(t, "you are helpful")}, + {Role: "user", Content: mustMarshal(t, "hi")}, + {Role: "assistant", Content: mustMarshal(t, []ResponsesContentPart{ + {Type: "output_text", Text: "hello!"}, + })}, + {Role: "user", Content: mustMarshal(t, []ResponsesContentPart{ + {Type: "input_text", Text: 
"and now?"}, + })}, + }), + } + + got, err := ResponsesToChatCompletionsRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(got.Messages) != 4 { + t.Fatalf("messages len: got %d want 4", len(got.Messages)) + } + wantRoles := []string{"system", "user", "assistant", "user"} + for i, want := range wantRoles { + if got.Messages[i].Role != want { + t.Errorf("messages[%d].role: got %q want %q", i, got.Messages[i].Role, want) + } + } + + // assistant content should be folded back to plain string + var assistantContent string + if err := json.Unmarshal(got.Messages[2].Content, &assistantContent); err != nil { + t.Fatalf("assistant content not a string: %v", err) + } + if assistantContent != "hello!" { + t.Errorf("assistant content: got %q want %q", assistantContent, "hello!") + } + + // last user content (single text part) should be folded to string too + var userContent string + if err := json.Unmarshal(got.Messages[3].Content, &userContent); err != nil { + t.Fatalf("user content not a string: %v", err) + } + if userContent != "and now?" 
{ + t.Errorf("user content: got %q want %q", userContent, "and now?") + } +} + +func TestResponsesToChatCompletionsRequest_ToolCallRoundTrip(t *testing.T) { + req := &ResponsesRequest{ + Model: "gpt-4o", + Input: mustMarshal(t, []ResponsesInputItem{ + {Role: "user", Content: mustMarshal(t, "what is the weather in SF?")}, + { + Type: "function_call", + CallID: "call_123", + Name: "get_weather", + Arguments: `{"city":"SF"}`, + }, + { + Type: "function_call_output", + CallID: "call_123", + Output: `{"temp":68}`, + }, + {Role: "assistant", Content: mustMarshal(t, "It is 68F.")}, + }), + Tools: []ResponsesTool{ + { + Type: "function", + Name: "get_weather", + Description: "Get current weather", + Parameters: json.RawMessage(`{"type":"object"}`), + }, + }, + ToolChoice: json.RawMessage(`"auto"`), + } + + got, err := ResponsesToChatCompletionsRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Expect: user, assistant(tool_calls), tool, assistant(text) + if len(got.Messages) != 4 { + t.Fatalf("messages len: got %d want 4 (%+v)", len(got.Messages), got.Messages) + } + if got.Messages[0].Role != "user" { + t.Errorf("[0] role: got %q want user", got.Messages[0].Role) + } + if got.Messages[1].Role != "assistant" { + t.Errorf("[1] role: got %q want assistant", got.Messages[1].Role) + } + if len(got.Messages[1].ToolCalls) != 1 { + t.Fatalf("[1] tool_calls len: got %d want 1", len(got.Messages[1].ToolCalls)) + } + tc := got.Messages[1].ToolCalls[0] + if tc.ID != "call_123" || tc.Type != "function" || tc.Function.Name != "get_weather" || tc.Function.Arguments != `{"city":"SF"}` { + t.Errorf("tool_call mismatch: %+v", tc) + } + + if got.Messages[2].Role != "tool" { + t.Errorf("[2] role: got %q want tool", got.Messages[2].Role) + } + if got.Messages[2].ToolCallID != "call_123" { + t.Errorf("[2] tool_call_id: got %q want call_123", got.Messages[2].ToolCallID) + } + var toolOutput string + _ = json.Unmarshal(got.Messages[2].Content, &toolOutput) + if 
toolOutput != `{"temp":68}` { + t.Errorf("tool output: got %q want %q", toolOutput, `{"temp":68}`) + } + + if got.Messages[3].Role != "assistant" { + t.Errorf("[3] role: got %q want assistant", got.Messages[3].Role) + } + + // tools mapping + if len(got.Tools) != 1 { + t.Fatalf("tools len: got %d want 1", len(got.Tools)) + } + if got.Tools[0].Type != "function" || got.Tools[0].Function == nil || + got.Tools[0].Function.Name != "get_weather" { + t.Errorf("tool mismatch: %+v", got.Tools[0]) + } + + // tool_choice pass-through + if string(got.ToolChoice) != `"auto"` { + t.Errorf("tool_choice: got %s want \"auto\"", string(got.ToolChoice)) + } +} + +func TestResponsesToChatCompletionsRequest_UnansweredToolCallsDropped(t *testing.T) { + req := &ResponsesRequest{ + Model: "gpt-4o", + Input: mustMarshal(t, []ResponsesInputItem{ + {Role: "user", Content: mustMarshal(t, "do two things")}, + {Type: "function_call", CallID: "c1", Name: "fn_a", Arguments: `{}`}, + {Type: "function_call", CallID: "c2", Name: "fn_b", Arguments: `{}`}, + }), + } + + got, err := ResponsesToChatCompletionsRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(got.Messages) != 1 { + t.Fatalf("messages len: got %d want 1 (%+v)", len(got.Messages), got.Messages) + } + if got.Messages[0].Role != "user" { + t.Errorf("[0] role: got %q want user", got.Messages[0].Role) + } +} + +func TestResponsesToChatCompletionsRequest_MultipleAnsweredToolCallsMerged(t *testing.T) { + req := &ResponsesRequest{ + Model: "gpt-4o", + Input: mustMarshal(t, []ResponsesInputItem{ + {Role: "user", Content: mustMarshal(t, "do two things")}, + {Type: "function_call", CallID: "c1", Name: "fn_a", Arguments: `{}`}, + {Type: "function_call", CallID: "c2", Name: "fn_b", Arguments: `{}`}, + {Type: "function_call_output", CallID: "c1", Output: `ok1`}, + {Type: "function_call_output", CallID: "c2", Output: `ok2`}, + }), + } + + got, err := ResponsesToChatCompletionsRequest(req) + if err != nil { + 
t.Fatalf("unexpected error: %v", err) + } + if len(got.Messages) != 4 { + t.Fatalf("messages len: got %d want 4 (%+v)", len(got.Messages), got.Messages) + } + if got.Messages[1].Role != "assistant" || len(got.Messages[1].ToolCalls) != 2 { + t.Fatalf("expected assistant with two tool calls, got %+v", got.Messages[1]) + } + if got.Messages[2].Role != "tool" || got.Messages[2].ToolCallID != "c1" { + t.Fatalf("expected tool c1, got %+v", got.Messages[2]) + } + if got.Messages[3].Role != "tool" || got.Messages[3].ToolCallID != "c2" { + t.Fatalf("expected tool c2, got %+v", got.Messages[3]) + } +} + +func TestResponsesToChatCompletionsRequest_PartiallyAnsweredToolCallsFiltered(t *testing.T) { + req := &ResponsesRequest{ + Model: "gpt-4o", + Input: mustMarshal(t, []ResponsesInputItem{ + {Role: "user", Content: mustMarshal(t, "do two things")}, + {Type: "function_call", CallID: "c1", Name: "fn_a", Arguments: `{}`}, + {Type: "function_call", CallID: "c2", Name: "fn_b", Arguments: `{}`}, + {Type: "function_call_output", CallID: "c2", Output: `ok2`}, + }), + } + + got, err := ResponsesToChatCompletionsRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(got.Messages) != 3 { + t.Fatalf("messages len: got %d want 3 (%+v)", len(got.Messages), got.Messages) + } + if got.Messages[1].Role != "assistant" || len(got.Messages[1].ToolCalls) != 1 || got.Messages[1].ToolCalls[0].ID != "c2" { + t.Fatalf("expected only answered c2 tool call, got %+v", got.Messages[1]) + } + if got.Messages[2].Role != "tool" || got.Messages[2].ToolCallID != "c2" { + t.Fatalf("expected tool c2, got %+v", got.Messages[2]) + } +} + +func TestResponsesToChatCompletionsRequest_StreamTrue(t *testing.T) { + req := &ResponsesRequest{ + Model: "gpt-4o", + Input: mustMarshal(t, "hi"), + Stream: true, + } + got, err := ResponsesToChatCompletionsRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !got.Stream { + t.Error("stream: got false want true") + } +} + 
+func TestResponsesToChatCompletionsRequest_ReasoningEffortAllLevels(t *testing.T) { + for _, effort := range []string{"low", "medium", "high"} { + t.Run(effort, func(t *testing.T) { + req := &ResponsesRequest{ + Model: "o1", + Input: mustMarshal(t, "think"), + Reasoning: &ResponsesReasoning{ + Effort: effort, + Summary: "auto", + }, + } + got, err := ResponsesToChatCompletionsRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got.ReasoningEffort != effort { + t.Errorf("reasoning_effort: got %q want %q", got.ReasoningEffort, effort) + } + }) + } +} + +func TestResponsesToChatCompletionsRequest_ScalarFieldMapping(t *testing.T) { + temp := 0.7 + topP := 0.9 + maxTok := 1024 + req := &ResponsesRequest{ + Model: "gpt-4o", + Input: mustMarshal(t, "x"), + Temperature: &temp, + TopP: &topP, + MaxOutputTokens: &maxTok, + ServiceTier: "auto", + Instructions: "be terse", + } + + got, err := ResponsesToChatCompletionsRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got.Temperature == nil || *got.Temperature != 0.7 { + t.Errorf("temperature mismatch: %v", got.Temperature) + } + if got.TopP == nil || *got.TopP != 0.9 { + t.Errorf("top_p mismatch: %v", got.TopP) + } + if got.MaxTokens == nil || *got.MaxTokens != 1024 { + t.Errorf("max_tokens mismatch: %v", got.MaxTokens) + } + if got.ServiceTier != "auto" { + t.Errorf("service_tier mismatch: %q", got.ServiceTier) + } + if got.Instructions != "" { + t.Errorf("instructions should not be passed to strict Chat Completions upstreams by default, got %q", got.Instructions) + } + // instructions should also be inserted as a leading system message + if len(got.Messages) < 1 || got.Messages[0].Role != "system" { + t.Fatalf("expected leading system message from instructions, got %+v", got.Messages) + } + var s string + _ = json.Unmarshal(got.Messages[0].Content, &s) + if s != "be terse" { + t.Errorf("instructions system content: got %q want %q", s, "be terse") + } +} + +func 
TestResponsesToChatCompletionsRequestWithOptions_PreserveInstructionsField(t *testing.T) { + req := &ResponsesRequest{ + Model: "gpt-4o", + Input: mustMarshal(t, "x"), + Instructions: "be terse", + } + + got, err := ResponsesToChatCompletionsRequestWithOptions(req, ConvertResponsesOptions{PreserveInstructionsField: true}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got.Instructions != "be terse" { + t.Errorf("instructions mismatch: %q", got.Instructions) + } +} + +func TestResponsesToChatCompletionsRequest_ToolChoiceObjectMapping(t *testing.T) { + req := &ResponsesRequest{ + Model: "gpt-4o", + Input: mustMarshal(t, "x"), + ToolChoice: json.RawMessage(`{"type":"function","name":"get_weather"}`), + } + + got, err := ResponsesToChatCompletionsRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + var choice struct { + Type string `json:"type"` + Function struct { + Name string `json:"name"` + } `json:"function"` + } + if err := json.Unmarshal(got.ToolChoice, &choice); err != nil { + t.Fatalf("unmarshal tool_choice: %v", err) + } + if choice.Type != "function" || choice.Function.Name != "get_weather" { + t.Fatalf("unexpected tool_choice: %s", string(got.ToolChoice)) + } +} + +func TestResponsesToChatCompletionsRequest_FunctionCallOutputObject(t *testing.T) { + var req ResponsesRequest + body := []byte(`{"model":"gpt-4o","input":[{"type":"function_call_output","call_id":"call_1","output":{"temp":68,"unit":"f"}}]}`) + if err := json.Unmarshal(body, &req); err != nil { + t.Fatalf("unmarshal request: %v", err) + } + + got, err := ResponsesToChatCompletionsRequest(&req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(got.Messages) != 1 || got.Messages[0].Role != "tool" { + t.Fatalf("expected single tool message, got %+v", got.Messages) + } + var content string + if err := json.Unmarshal(got.Messages[0].Content, &content); err != nil { + t.Fatalf("unmarshal tool content: %v", err) + } + if content != 
`{"temp":68,"unit":"f"}` { + t.Errorf("tool content: got %q", content) + } +} + +func TestResponsesToChatCompletionsRequest_IgnoredFieldsDoNotError(t *testing.T) { + store := false + parallel := true + req := &ResponsesRequest{ + Model: "gpt-4o", + Input: mustMarshal(t, "x"), + PreviousResponseID: "resp_abc", + PromptCacheKey: "ck_1", + Include: []string{"reasoning.encrypted_content"}, + Store: &store, + ParallelToolCalls: ¶llel, + Text: &ResponsesText{Verbosity: "low"}, + } + if _, err := ResponsesToChatCompletionsRequest(req); err != nil { + t.Fatalf("unexpected error with ignored fields: %v", err) + } +} + +func TestResponsesToChatCompletionsRequest_NilInputNoCrash(t *testing.T) { + req := &ResponsesRequest{Model: "gpt-4o"} + got, err := ResponsesToChatCompletionsRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(got.Messages) != 0 { + t.Errorf("expected no messages, got %d", len(got.Messages)) + } +} + +func TestResponsesToChatCompletionsRequest_NilRequestErrors(t *testing.T) { + if _, err := ResponsesToChatCompletionsRequest(nil); err == nil { + t.Error("expected error for nil request") + } +} + +func TestResponsesToChatCompletionsRequestWithOptions_StripReasoningEffort(t *testing.T) { + req := &ResponsesRequest{ + Model: "gpt-5.4", + Reasoning: &ResponsesReasoning{Effort: "medium"}, + Input: mustMarshal(t, "hello"), + } + + // Default (StripReasoningEffort=false): reasoning_effort should be preserved. + got, err := ResponsesToChatCompletionsRequestWithOptions(req, ConvertResponsesOptions{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got.ReasoningEffort != "medium" { + t.Errorf("default: expected reasoning_effort=medium, got %q", got.ReasoningEffort) + } + + // StripReasoningEffort=true: reasoning_effort should be omitted. 
+ got, err = ResponsesToChatCompletionsRequestWithOptions(req, ConvertResponsesOptions{StripReasoningEffort: true}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got.ReasoningEffort != "" { + t.Errorf("strip=true: expected empty reasoning_effort, got %q", got.ReasoningEffort) + } +} diff --git a/backend/internal/pkg/apicompat/types.go b/backend/internal/pkg/apicompat/types.go index b4451f235bb..c17f1980a44 100644 --- a/backend/internal/pkg/apicompat/types.go +++ b/backend/internal/pkg/apicompat/types.go @@ -195,7 +195,16 @@ type ResponsesRequest struct { MaxOutputTokens *int `json:"max_output_tokens,omitempty"` Temperature *float64 `json:"temperature,omitempty"` TopP *float64 `json:"top_p,omitempty"` + Stop json.RawMessage `json:"stop,omitempty"` // string or []string + User string `json:"user,omitempty"` + Metadata json.RawMessage `json:"metadata,omitempty"` + Seed *int `json:"seed,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + Logprobs *bool `json:"logprobs,omitempty"` + TopLogprobs *int `json:"top_logprobs,omitempty"` Stream bool `json:"stream,omitempty"` + StreamOptions *ChatStreamOptions `json:"stream_options,omitempty"` Tools []ResponsesTool `json:"tools,omitempty"` Include []string `json:"include,omitempty"` Store *bool `json:"store,omitempty"` @@ -239,6 +248,49 @@ type ResponsesInputItem struct { Output string `json:"output,omitempty"` } +// UnmarshalJSON accepts Responses function_call_output.output in both the +// official string form and the object/array forms commonly emitted by gateway +// clients. Object/array values are compacted into a JSON string so existing +// conversion code can keep treating Output as Chat Completions tool message +// content. 
+func (it *ResponsesInputItem) UnmarshalJSON(data []byte) error { + type responsesInputItemAlias ResponsesInputItem + var fields map[string]json.RawMessage + if err := json.Unmarshal(data, &fields); err != nil { + return err + } + output := fields["output"] + delete(fields, "output") + withoutOutput, err := json.Marshal(fields) + if err != nil { + return err + } + var alias responsesInputItemAlias + if err := json.Unmarshal(withoutOutput, &alias); err != nil { + return err + } + *it = ResponsesInputItem(alias) + it.Output = decodeResponsesFunctionOutput(output) + return nil +} + +func decodeResponsesFunctionOutput(raw json.RawMessage) string { + if len(raw) == 0 || string(raw) == "null" { + return "" + } + var s string + if err := json.Unmarshal(raw, &s); err == nil { + return s + } + var v any + if err := json.Unmarshal(raw, &v); err == nil { + if b, err := json.Marshal(v); err == nil { + return string(b) + } + } + return string(raw) +} + // ResponsesContentPart is a typed content part in a Responses message. 
type ResponsesContentPart struct { Type string `json:"type"` // "input_text" | "output_text" | "input_image" @@ -390,6 +442,9 @@ type ResponsesStreamEvent struct { // response.output_item.added / response.output_item.done Item *ResponsesOutput `json:"item,omitempty"` + // response.content_part.added / response.content_part.done + Part *ResponsesContentPart `json:"part,omitempty"` + // response.output_text.delta / response.output_text.done OutputIndex int `json:"output_index,omitempty"` ContentIndex int `json:"content_index,omitempty"` @@ -427,13 +482,21 @@ type ChatCompletionsRequest struct { MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"` Temperature *float64 `json:"temperature,omitempty"` TopP *float64 `json:"top_p,omitempty"` + Stop json.RawMessage `json:"stop,omitempty"` // string or []string + User string `json:"user,omitempty"` + Metadata json.RawMessage `json:"metadata,omitempty"` + Seed *int `json:"seed,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + Logprobs *bool `json:"logprobs,omitempty"` + TopLogprobs *int `json:"top_logprobs,omitempty"` Stream bool `json:"stream,omitempty"` StreamOptions *ChatStreamOptions `json:"stream_options,omitempty"` Tools []ChatTool `json:"tools,omitempty"` ToolChoice json.RawMessage `json:"tool_choice,omitempty"` ReasoningEffort string `json:"reasoning_effort,omitempty"` // "low" | "medium" | "high" | "xhigh" ServiceTier string `json:"service_tier,omitempty"` - Stop json.RawMessage `json:"stop,omitempty"` // string or []string + ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` // Legacy function calling (deprecated but still supported) Functions []ChatFunction `json:"functions,omitempty"` @@ -521,11 +584,11 @@ type ChatChoice struct { // ChatUsage holds token counts in Chat Completions format. 
type ChatUsage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` - PromptTokensDetails *ChatTokenDetails `json:"prompt_tokens_details,omitempty"` - CompletionTokensDetails *ChatTokenDetails `json:"completion_tokens_details,omitempty"` + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + PromptTokensDetails *ChatTokenDetails `json:"prompt_tokens_details,omitempty"` + CompletionTokensDetails *ChatCompletionTokenDetails `json:"completion_tokens_details,omitempty"` } // ChatTokenDetails provides a breakdown of token usage. The same type is @@ -537,9 +600,15 @@ type ChatUsage struct { // - completion_tokens_details: reasoning_tokens, audio_tokens, // accepted_prediction_tokens, rejected_prediction_tokens type ChatTokenDetails struct { - CachedTokens int `json:"cached_tokens,omitempty"` - AudioTokens int `json:"audio_tokens,omitempty"` - ReasoningTokens int `json:"reasoning_tokens,omitempty"` + CachedTokens int `json:"cached_tokens,omitempty"` + AudioTokens int `json:"audio_tokens,omitempty"` +} + +// ChatCompletionTokenDetails provides completion-side token breakdowns. +type ChatCompletionTokenDetails struct { + ReasoningTokens int `json:"reasoning_tokens,omitempty"` + AudioTokens int `json:"audio_tokens,omitempty"` + AcceptedPredictionTokens int `json:"accepted_prediction_tokens,omitempty"` RejectedPredictionTokens int `json:"rejected_prediction_tokens,omitempty"` } diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index bc970f76075..6337520af49 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -93,7 +93,8 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account SetStatus(account.Status). SetErrorMessage(account.ErrorMessage). SetSchedulable(account.Schedulable). 
- SetAutoPauseOnExpired(account.AutoPauseOnExpired) + SetAutoPauseOnExpired(account.AutoPauseOnExpired). + SetStripReasoningEffortOnCc(account.StripReasoningEffortOnCC) if account.RateMultiplier != nil { builder.SetRateMultiplier(*account.RateMultiplier) @@ -334,7 +335,8 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account SetStatus(account.Status). SetErrorMessage(account.ErrorMessage). SetSchedulable(schedulable). - SetAutoPauseOnExpired(account.AutoPauseOnExpired) + SetAutoPauseOnExpired(account.AutoPauseOnExpired). + SetStripReasoningEffortOnCc(account.StripReasoningEffortOnCC) if account.RateMultiplier != nil { builder.SetRateMultiplier(*account.RateMultiplier) @@ -1743,34 +1745,35 @@ func accountEntityToService(m *dbent.Account) *service.Account { rateMultiplier := m.RateMultiplier return &service.Account{ - ID: m.ID, - Name: m.Name, - Notes: m.Notes, - Platform: m.Platform, - Type: m.Type, - Credentials: copyJSONMap(m.Credentials), - Extra: copyJSONMap(m.Extra), - ProxyID: m.ProxyID, - Concurrency: m.Concurrency, - Priority: m.Priority, - RateMultiplier: &rateMultiplier, - LoadFactor: m.LoadFactor, - Status: m.Status, - ErrorMessage: derefString(m.ErrorMessage), - LastUsedAt: m.LastUsedAt, - ExpiresAt: m.ExpiresAt, - AutoPauseOnExpired: m.AutoPauseOnExpired, - CreatedAt: m.CreatedAt, - UpdatedAt: m.UpdatedAt, - Schedulable: m.Schedulable, - RateLimitedAt: m.RateLimitedAt, - RateLimitResetAt: m.RateLimitResetAt, - OverloadUntil: m.OverloadUntil, - TempUnschedulableUntil: m.TempUnschedulableUntil, - TempUnschedulableReason: derefString(m.TempUnschedulableReason), - SessionWindowStart: m.SessionWindowStart, - SessionWindowEnd: m.SessionWindowEnd, - SessionWindowStatus: derefString(m.SessionWindowStatus), + ID: m.ID, + Name: m.Name, + Notes: m.Notes, + Platform: m.Platform, + Type: m.Type, + Credentials: copyJSONMap(m.Credentials), + Extra: copyJSONMap(m.Extra), + ProxyID: m.ProxyID, + Concurrency: m.Concurrency, + Priority: 
m.Priority, + RateMultiplier: &rateMultiplier, + LoadFactor: m.LoadFactor, + Status: m.Status, + ErrorMessage: derefString(m.ErrorMessage), + LastUsedAt: m.LastUsedAt, + ExpiresAt: m.ExpiresAt, + AutoPauseOnExpired: m.AutoPauseOnExpired, + StripReasoningEffortOnCC: m.StripReasoningEffortOnCc, + CreatedAt: m.CreatedAt, + UpdatedAt: m.UpdatedAt, + Schedulable: m.Schedulable, + RateLimitedAt: m.RateLimitedAt, + RateLimitResetAt: m.RateLimitResetAt, + OverloadUntil: m.OverloadUntil, + TempUnschedulableUntil: m.TempUnschedulableUntil, + TempUnschedulableReason: derefString(m.TempUnschedulableReason), + SessionWindowStart: m.SessionWindowStart, + SessionWindowEnd: m.SessionWindowEnd, + SessionWindowStatus: derefString(m.SessionWindowStatus), } } diff --git a/backend/internal/repository/scheduler_cache.go b/backend/internal/repository/scheduler_cache.go index cf19dedaa0b..12fa7990dab 100644 --- a/backend/internal/repository/scheduler_cache.go +++ b/backend/internal/repository/scheduler_cache.go @@ -439,8 +439,9 @@ func buildSchedulerMetadataAccount(account service.Account) service.Account { Status: account.Status, LastUsedAt: account.LastUsedAt, ExpiresAt: account.ExpiresAt, - AutoPauseOnExpired: account.AutoPauseOnExpired, - Schedulable: account.Schedulable, + AutoPauseOnExpired: account.AutoPauseOnExpired, + StripReasoningEffortOnCC: account.StripReasoningEffortOnCC, + Schedulable: account.Schedulable, RateLimitedAt: account.RateLimitedAt, RateLimitResetAt: account.RateLimitResetAt, OverloadUntil: account.OverloadUntil, diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index fb95201f200..2e6ee3ae097 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -29,15 +29,16 @@ type Account struct { Priority int // RateMultiplier 账号计费倍率(>=0,允许 0 表示该账号计费为 0)。 // 使用指针用于兼容旧版本调度缓存(Redis)中缺字段的情况:nil 表示按 1.0 处理。 - RateMultiplier *float64 - LoadFactor *int // 调度负载因子;nil 表示使用 Concurrency - Status string - 
ErrorMessage string - LastUsedAt *time.Time - ExpiresAt *time.Time - AutoPauseOnExpired bool - CreatedAt time.Time - UpdatedAt time.Time + RateMultiplier *float64 + LoadFactor *int // 调度负载因子;nil 表示使用 Concurrency + Status string + ErrorMessage string + LastUsedAt *time.Time + ExpiresAt *time.Time + AutoPauseOnExpired bool + StripReasoningEffortOnCC bool + CreatedAt time.Time + UpdatedAt time.Time Schedulable bool @@ -829,7 +830,7 @@ func matchWildcardMappingResult(mapping map[string]string, requestedModel string } func (a *Account) IsCustomErrorCodesEnabled() bool { - if a.Type != AccountTypeAPIKey || a.Credentials == nil { + if (a.Type != AccountTypeAPIKey && a.Type != AccountTypeAPIKeyChatCompletions) || a.Credentials == nil { return false } if v, ok := a.Credentials["custom_error_codes_enabled"]; ok { @@ -1043,7 +1044,7 @@ func (a *Account) IsBedrockAPIKey() bool { // IsAPIKeyOrBedrock 返回账号类型是否支持配额和池模式等特性 func (a *Account) IsAPIKeyOrBedrock() bool { - return a.Type == AccountTypeAPIKey || a.Type == AccountTypeBedrock + return a.Type == AccountTypeAPIKey || a.Type == AccountTypeAPIKeyChatCompletions || a.Type == AccountTypeBedrock } func (a *Account) IsOpenAI() bool { @@ -1059,14 +1060,90 @@ func (a *Account) IsOpenAIOAuth() bool { } func (a *Account) IsOpenAIApiKey() bool { - return a.IsOpenAI() && a.Type == AccountTypeAPIKey + return a.IsOpenAI() && (a.Type == AccountTypeAPIKey || a.Type == AccountTypeAPIKeyChatCompletions) +} + +// IsOpenAIChatCompletionsUpstream 返回账号是否为 OpenAI 兼容 Chat Completions 上游账号。 +// +// 平台无关:只要账号类型是 AccountTypeAPIKeyChatCompletions 即视为该上游形态。 +// 适用场景包括: +// - Platform=openai:客户端调 /v1/chat/completions 走 raw 透传;调 /v1/responses 走 +// Responses↔CC 协议转换(参见 OpenAIGatewayService.ForwardResponsesAsChatCompletions)。 +// - Platform=anthropic:客户端调 /v1/messages 走 Anthropic↔CC 协议转换 +// (参见 GatewayService.ForwardAnthropicAsChatCompletions);调 /v1/chat/completions +// 走 raw 透传(参见 GatewayService.ForwardClaudeChatCompletionsRaw)。 +func (a *Account) 
IsOpenAIChatCompletionsUpstream() bool { + return a.Type == AccountTypeAPIKeyChatCompletions +} + +// GetOpenAIChatCompletionsURL 返回 OpenAI 兼容 Chat Completions 上游的完整 endpoint URL。 +// credentials.chat_completions_url 应存储完整 URL(如 https://api.deepseek.com/v1/chat/completions)。 +func (a *Account) GetOpenAIChatCompletionsURL() string { + if a.Credentials == nil { + return "" + } + if v, ok := a.Credentials["chat_completions_url"]; ok { + if s, ok := v.(string); ok { + return strings.TrimSpace(s) + } + } + return "" +} + +// GetUpstreamAPIKey 返回 credentials.api_key,适用于 apikey-chat-completions 等 +// 不受 GetOpenAIApiKey 类型限制约束的上游账号。 +func (a *Account) GetUpstreamAPIKey() string { + if a.Credentials == nil { + return "" + } + if v, ok := a.Credentials["api_key"]; ok { + if s, ok := v.(string); ok { + return strings.TrimSpace(s) + } + } + return "" +} + +const ( + OpenAICompatibleAuthHeaderAuthorization = "Authorization" + OpenAICompatibleAuthHeaderAPIKey = "api-key" + OpenAICompatibleAuthHeaderXAPIKey = "x-api-key" +) + +func NormalizeOpenAICompatibleAuthHeader(value string) (string, bool) { + switch strings.ToLower(strings.TrimSpace(value)) { + case "": + return OpenAICompatibleAuthHeaderAuthorization, true + case strings.ToLower(OpenAICompatibleAuthHeaderAuthorization): + return OpenAICompatibleAuthHeaderAuthorization, true + case OpenAICompatibleAuthHeaderAPIKey: + return OpenAICompatibleAuthHeaderAPIKey, true + case OpenAICompatibleAuthHeaderXAPIKey: + return OpenAICompatibleAuthHeaderXAPIKey, true + default: + return "", false + } +} + +func (a *Account) OpenAICompatibleAuthHeader() string { + if a == nil || a.Credentials == nil || !a.IsOpenAIChatCompletionsUpstream() { + return OpenAICompatibleAuthHeaderAuthorization + } + if v, ok := a.Credentials["auth_header"]; ok { + if s, ok := v.(string); ok { + if header, valid := NormalizeOpenAICompatibleAuthHeader(s); valid { + return header + } + } + } + return OpenAICompatibleAuthHeaderAuthorization } func (a *Account) 
GetOpenAIBaseURL() string { if !a.IsOpenAI() { return "" } - if a.Type == AccountTypeAPIKey { + if a.Type == AccountTypeAPIKey || a.Type == AccountTypeAPIKeyChatCompletions { baseURL := a.GetCredential("base_url") if baseURL != "" { return baseURL @@ -1211,7 +1288,7 @@ func (a *Account) SupportsOpenAIImageCapability(capability OpenAIImagesCapabilit } switch capability { case OpenAIImagesCapabilityBasic, OpenAIImagesCapabilityNative: - return a.Type == AccountTypeOAuth || a.Type == AccountTypeAPIKey + return a.Type == AccountTypeOAuth || a.Type == AccountTypeAPIKey || a.Type == AccountTypeAPIKeyChatCompletions default: return true } diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index 748840b75d7..5853bf38f1c 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -103,8 +103,9 @@ type CreateAccountRequest struct { Concurrency int `json:"concurrency"` Priority int `json:"priority"` GroupIDs []int64 `json:"group_ids"` - ExpiresAt *time.Time `json:"expires_at"` - AutoPauseOnExpired *bool `json:"auto_pause_on_expired"` + ExpiresAt *time.Time `json:"expires_at"` + AutoPauseOnExpired *bool `json:"auto_pause_on_expired"` + StripReasoningEffortOnCC *bool `json:"strip_reasoning_effort_on_cc"` } // UpdateAccountRequest 更新账号请求 @@ -118,8 +119,9 @@ type UpdateAccountRequest struct { Priority *int `json:"priority"` Status *string `json:"status"` GroupIDs *[]int64 `json:"group_ids"` - ExpiresAt *time.Time `json:"expires_at"` - AutoPauseOnExpired *bool `json:"auto_pause_on_expired"` + ExpiresAt *time.Time `json:"expires_at"` + AutoPauseOnExpired *bool `json:"auto_pause_on_expired"` + StripReasoningEffortOnCC *bool `json:"strip_reasoning_effort_on_cc"` } // AccountService 账号管理服务 @@ -168,6 +170,9 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) ( } else { account.AutoPauseOnExpired = true } + if req.StripReasoningEffortOnCC != nil { + 
account.StripReasoningEffortOnCC = *req.StripReasoningEffortOnCC + } if err := s.accountRepo.Create(ctx, account); err != nil { return nil, fmt.Errorf("create account: %w", err) @@ -276,6 +281,9 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount if req.AutoPauseOnExpired != nil { account.AutoPauseOnExpired = *req.AutoPauseOnExpired } + if req.StripReasoningEffortOnCC != nil { + account.StripReasoningEffortOnCC = *req.StripReasoningEffortOnCC + } // 先验证分组是否存在(在任何写操作之前) if req.GroupIDs != nil { diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 032c13b18e0..9bde4e1a933 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -179,6 +179,30 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int return s.sendErrorAndEnd(c, "Account not found") } + // apikey-chat-completions accounts use the same CC endpoint test regardless of platform + // (the upstream is an arbitrary OpenAI-compatible CC endpoint, not platform-bound). 
+ if account.Type == AccountTypeAPIKeyChatCompletions { + testModelID := modelID + if testModelID == "" { + testModelID = openai.DefaultTestModel + } + testModelID = account.GetMappedModel(testModelID) + authToken := account.GetUpstreamAPIKey() + if authToken == "" { + return s.sendErrorAndEnd(c, "No API key available") + } + ccURL := account.GetOpenAIChatCompletionsURL() + if ccURL == "" { + return s.sendErrorAndEnd(c, "No chat_completions_url configured for this account") + } + chatCompletionsURL, err := s.validateUpstreamBaseURL(ccURL) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid chat_completions_url: %s", err.Error())) + } + return s.testOpenAIChatCompletionsConnection(c, account, testModelID, prompt, chatCompletionsURL, authToken) + + } + // Route to platform-specific test method if account.IsOpenAI() { return s.testOpenAIAccountConnection(c, account, modelID, prompt, normalizeAccountTestMode(mode)) @@ -510,6 +534,24 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account return s.testOpenAICompactConnection(c, account, testModelID) } + // apikey-chat-completions accounts use the CC endpoint directly + if account.Type == AccountTypeAPIKeyChatCompletions { + authToken := account.GetUpstreamAPIKey() + if authToken == "" { + return s.sendErrorAndEnd(c, "No API key available") + } + ccURL := account.GetOpenAIChatCompletionsURL() + if ccURL == "" { + return s.sendErrorAndEnd(c, "No chat_completions_url configured for this account") + } + chatCompletionsURL, err := s.validateUpstreamBaseURL(ccURL) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid chat_completions_url: %s", err.Error())) + } + return s.testOpenAIChatCompletionsConnection(c, account, testModelID, prompt, chatCompletionsURL, authToken) + + } + // Route to image generation test if an image model is selected if isOpenAIImageModel(testModelID) { imagePrompt := strings.TrimSpace(prompt) @@ -555,7 +597,7 @@ func (s *AccountTestService) 
testOpenAIAccountConnection(c *gin.Context, account return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error())) } if !openai_compat.ShouldUseResponsesAPI(account.Extra) { - return s.testOpenAIChatCompletionsConnection(c, account, testModelID, prompt, normalizedBaseURL, authToken) + return s.testOpenAIChatCompletionsConnection(c, account, testModelID, prompt, buildOpenAIChatCompletionsURL(normalizedBaseURL), authToken) } apiURL = buildOpenAIResponsesURL(normalizedBaseURL) } else { @@ -638,11 +680,11 @@ func (s *AccountTestService) testOpenAIChatCompletionsConnection( account *Account, testModelID string, prompt string, - normalizedBaseURL string, + chatCompletionsURL string, authToken string, ) error { ctx := c.Request.Context() - apiURL := buildOpenAIChatCompletionsURL(normalizedBaseURL) + apiURL := strings.TrimSpace(chatCompletionsURL) c.Writer.Header().Set("Content-Type", "text/event-stream") c.Writer.Header().Set("Cache-Control", "no-cache") @@ -663,7 +705,7 @@ func (s *AccountTestService) testOpenAIChatCompletionsConnection( req = req.WithContext(WithHTTPUpstreamProfile(req.Context(), HTTPUpstreamProfileOpenAI)) req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "text/event-stream") - req.Header.Set("Authorization", "Bearer "+authToken) + applyOpenAICompatibleAPIKeyAuth(req, account, authToken) proxyURL := "" if account.ProxyID != nil && account.Proxy != nil { diff --git a/backend/internal/service/account_test_service_openai_test.go b/backend/internal/service/account_test_service_openai_test.go index 970c723a9b6..d697880d33b 100644 --- a/backend/internal/service/account_test_service_openai_test.go +++ b/backend/internal/service/account_test_service_openai_test.go @@ -389,6 +389,107 @@ func TestAccountTestService_OpenAIAPIKeyResponsesUnsupportedUsesChatCompletionsP require.NotContains(t, body, "当前测试接口仅支持 Responses API 路径") } +func TestAccountTestService_APIKeyChatCompletionsUsesConfiguredAuthHeaderAndExactURL(t 
*testing.T) { + gin.SetMode(gin.TestMode) + ctx, _ := newTestContext() + + upstreamBody := strings.Join([]string{ + `data: {"id":"chatcmpl_test","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"pong"},"finish_reason":null}]}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + svc := &AccountTestService{ + httpUpstream: upstream, + cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}}, + } + account := &Account{ + ID: 95, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKeyChatCompletions, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "auth_header": "api-key", + "chat_completions_url": "https://compat-upstream.example/custom/chat?tenant=abc", + }, + } + + err := svc.testOpenAIChatCompletionsConnection(ctx, account, "gpt-5.4", "hello", account.GetOpenAIChatCompletionsURL(), account.GetUpstreamAPIKey()) + require.NoError(t, err) + require.NotNil(t, upstream.lastReq) + require.Equal(t, "https://compat-upstream.example/custom/chat?tenant=abc", upstream.lastReq.URL.String()) + require.Equal(t, "sk-test", upstream.lastReq.Header.Get("api-key")) + require.Empty(t, upstream.lastReq.Header.Get("Authorization")) +} + +func TestAccountTestService_APIKeyChatCompletionsDefaultsToAuthorizationBearer(t *testing.T) { + gin.SetMode(gin.TestMode) + ctx, _ := newTestContext() + + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader("data: [DONE]\n\n")), + }} + svc := &AccountTestService{ + httpUpstream: upstream, + cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}}, + } + 
account := &Account{ + ID: 96, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKeyChatCompletions, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "chat_completions_url": "https://compat-upstream.example/v1/chat/completions", + }, + } + + err := svc.testOpenAIChatCompletionsConnection(ctx, account, "gpt-5.4", "hello", account.GetOpenAIChatCompletionsURL(), account.GetUpstreamAPIKey()) + require.NoError(t, err) + require.Equal(t, "Bearer sk-test", upstream.lastReq.Header.Get("Authorization")) + require.Empty(t, upstream.lastReq.Header.Get("api-key")) + require.Empty(t, upstream.lastReq.Header.Get("x-api-key")) +} + +func TestAccountTestService_APIKeyChatCompletionsSupportsXAPIKeyHeader(t *testing.T) { + gin.SetMode(gin.TestMode) + ctx, _ := newTestContext() + + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader("data: [DONE]\n\n")), + }} + svc := &AccountTestService{ + httpUpstream: upstream, + cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}}, + } + account := &Account{ + ID: 97, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKeyChatCompletions, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "auth_header": "x-api-key", + "chat_completions_url": "https://compat-upstream.example/v1/chat/completions", + }, + } + + err := svc.testOpenAIChatCompletionsConnection(ctx, account, "gpt-5.4", "hello", account.GetOpenAIChatCompletionsURL(), account.GetUpstreamAPIKey()) + require.NoError(t, err) + require.Equal(t, "sk-test", upstream.lastReq.Header.Get("x-api-key")) + require.Empty(t, upstream.lastReq.Header.Get("Authorization")) +} + func TestAccountTestService_OpenAIChatCompletionsPathReturns4xx(t *testing.T) { gin.SetMode(gin.TestMode) ctx, recorder := newTestContext() diff --git 
a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 81d1f022791..2885f270b4c 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -283,7 +283,8 @@ type CreateAccountInput struct { LoadFactor *int GroupIDs []int64 ExpiresAt *int64 - AutoPauseOnExpired *bool + AutoPauseOnExpired *bool + StripReasoningEffortOnCC *bool // SkipDefaultGroupBind prevents auto-binding to platform default group when GroupIDs is empty. SkipDefaultGroupBind bool // SkipMixedChannelCheck skips the mixed channel risk check when binding groups. @@ -305,8 +306,9 @@ type UpdateAccountInput struct { Status string GroupIDs *[]int64 ExpiresAt *int64 - AutoPauseOnExpired *bool - SkipMixedChannelCheck bool // 跳过混合渠道检查(用户已确认风险) + AutoPauseOnExpired *bool + StripReasoningEffortOnCC *bool + SkipMixedChannelCheck bool // 跳过混合渠道检查(用户已确认风险) } // BulkUpdateAccountsInput describes the payload for bulk updating accounts. @@ -2501,6 +2503,9 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou } else { account.AutoPauseOnExpired = true } + if input.StripReasoningEffortOnCC != nil { + account.StripReasoningEffortOnCC = *input.StripReasoningEffortOnCC + } if input.RateMultiplier != nil { if *input.RateMultiplier < 0 { return nil, errors.New("rate_multiplier must be >= 0") @@ -2646,6 +2651,9 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U if input.AutoPauseOnExpired != nil { account.AutoPauseOnExpired = *input.AutoPauseOnExpired } + if input.StripReasoningEffortOnCC != nil { + account.StripReasoningEffortOnCC = *input.StripReasoningEffortOnCC + } // 先验证分组是否存在(在任何写操作之前) if input.GroupIDs != nil { diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index b64412383fc..f4dc4572daf 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -71,6 +71,7 @@ 
const ( AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游) AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分) AccountTypeServiceAccount = domain.AccountTypeServiceAccount // Google Service Account 类型账号(用于 Vertex AI) + AccountTypeAPIKeyChatCompletions = domain.AccountTypeAPIKeyChatCompletions ) // Redeem type constants diff --git a/backend/internal/service/gateway_claude_chat_completions_raw.go b/backend/internal/service/gateway_claude_chat_completions_raw.go new file mode 100644 index 00000000000..14453caf9a6 --- /dev/null +++ b/backend/internal/service/gateway_claude_chat_completions_raw.go @@ -0,0 +1,316 @@ +package service + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "go.uber.org/zap" +) + +// ForwardClaudeChatCompletionsRaw 直转客户端的 Chat Completions 请求到一个 OpenAI 兼容的 +// /v1/chat/completions 上游(账号类型 = apikey-chat-completions),不做任何 +// 协议转换。客户端的请求体本就是 CC 格式,上游也是 CC 端点,所以是纯透传。 +// +// 入口:handler.GatewayHandler.ChatCompletions(Anthropic 平台分组下的 /v1/chat/completions) +// 在该 handler 中按 account.IsOpenAIChatCompletionsUpstream() 分流到此函数。 +// +// 与 OpenAIGatewayService.forwardAsRawChatCompletions 高度相似,但归属于 GatewayService, +// 不依赖 OpenAI 专属的 fast policy / OAuth transform / WebSocket adapter。 +// +// TODO: dedupe with openai_gateway_chat_completions_raw.go::forwardAsRawChatCompletions +// (计划在抽出公共 raw-CC 透传 helper 后合并)。 +func (s *GatewayService) ForwardClaudeChatCompletionsRaw( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + parsed *ParsedRequest, +) (*ForwardResult, error) { + startTime := time.Now() + + // 1. 
Parse minimal fields needed for routing/billing. + originalModel := gjson.GetBytes(body, "model").String() + if originalModel == "" { + writeChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", "model is required") + return nil, fmt.Errorf("missing model in request") + } + clientStream := gjson.GetBytes(body, "stream").Bool() + + // 2. Model mapping. + mappedModel := account.GetMappedModel(originalModel) + upstreamBody := body + if mappedModel != originalModel { + upstreamBody = ReplaceModelInBody(body, mappedModel) + } + + // 3. Stream usage enforcement so billing is recoverable from the closing chunk. + if clientStream { + var usageErr error + upstreamBody, usageErr = ensureOpenAIChatStreamUsage(upstreamBody) + if usageErr != nil { + writeChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", "Failed to prepare stream options") + return nil, fmt.Errorf("enable stream usage: %w", usageErr) + } + } + + logger.L().Debug("gateway claude chat_completions raw: forwarding without protocol conversion", + zap.Int64("account_id", account.ID), + zap.String("original_model", originalModel), + zap.String("upstream_model", mappedModel), + zap.Bool("stream", clientStream), + ) + + // 4. Resolve upstream URL + API key. 
+ apiKey := account.GetUpstreamAPIKey() + if apiKey == "" { + writeChatCompletionsError(c, http.StatusBadGateway, "api_error", fmt.Sprintf("account %d missing api_key", account.ID)) + return nil, fmt.Errorf("account %d missing api_key", account.ID) + } + targetURL := account.GetOpenAIChatCompletionsURL() + if targetURL == "" { + writeChatCompletionsError(c, http.StatusBadGateway, "api_error", fmt.Sprintf("account %d missing chat_completions_url", account.ID)) + return nil, fmt.Errorf("account %d missing chat_completions_url", account.ID) + } + if _, err := s.validateUpstreamBaseURL(targetURL); err != nil { + writeChatCompletionsError(c, http.StatusBadGateway, "api_error", "Invalid upstream URL") + return nil, fmt.Errorf("invalid chat_completions_url: %w", err) + } + + // 5. Build upstream request. + upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx) + upstreamReq, err := http.NewRequestWithContext(upstreamCtx, http.MethodPost, targetURL, bytes.NewReader(upstreamBody)) + releaseUpstreamCtx() + if err != nil { + writeChatCompletionsError(c, http.StatusBadGateway, "api_error", "Failed to build upstream request") + return nil, fmt.Errorf("build upstream request: %w", err) + } + upstreamReq.Header.Set("Content-Type", "application/json") + applyOpenAICompatibleAPIKeyAuth(upstreamReq, account, apiKey) + if clientStream { + upstreamReq.Header.Set("Accept", "text/event-stream") + } else { + upstreamReq.Header.Set("Accept", "application/json") + } + if c != nil && c.Request != nil { + for key, values := range c.Request.Header { + if openaiCCRawAllowedHeaders[strings.ToLower(key)] { + for _, v := range values { + upstreamReq.Header.Add(key, v) + } + } + } + } + + // 6. Send. 
+ proxyURL := "" + if account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed") + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + defer func() { _ = resp.Body.Close() }() + + // 7. Error handling. + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + if upstreamMsg == "" { + upstreamMsg = http.StatusText(resp.StatusCode) + } + + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "upstream_error", + Message: upstreamMsg, + }) + if s.rateLimitService != nil { + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + } + // 可 failover 的状态码:交给 handler 失败循环决定是否切换账号;不在此处写客户端响应。 + if s.shouldFailoverUpstreamError(resp.StatusCode) { + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + ResponseHeaders: resp.Header, + RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + } + } + writeChatCompletionsError(c, mapUpstreamStatusCode(resp.StatusCode), "upstream_error", upstreamMsg) + return nil, 
fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg) + } + + // 上游已接受请求(2xx),提前释放用户串行锁——不等流完成。 + // 当前 CC handler 尚未注入 OnUpstreamAccepted(无队列串行化),此处保持向前兼容。 + if parsed != nil && parsed.OnUpstreamAccepted != nil { + parsed.OnUpstreamAccepted() + } + + // 8. Forward response. + if clientStream { + return s.streamClaudeRawChatCompletions(c, resp, originalModel, mappedModel, startTime) + } + return s.bufferClaudeRawChatCompletions(c, resp, originalModel, mappedModel, startTime) +} + +func (s *GatewayService) streamClaudeRawChatCompletions( + c *gin.Context, + resp *http.Response, + originalModel string, + upstreamModel string, + startTime time.Time, +) (*ForwardResult, error) { + requestID := resp.Header.Get("x-request-id") + + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.WriteHeader(http.StatusOK) + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + + var usage ClaudeUsage + var firstTokenMs *int + clientDisconnected := false + + for scanner.Scan() { + line := scanner.Text() + if payload, ok := extractOpenAISSEDataLine(line); ok { + trimmedPayload := strings.TrimSpace(payload) + if trimmedPayload != "[DONE]" { + usageOnlyChunk := isOpenAIChatUsageOnlyStreamChunk(payload) + if u := extractCCStreamUsage(payload); u != nil { + usage.InputTokens = u.InputTokens + usage.OutputTokens = u.OutputTokens + usage.CacheReadInputTokens = u.CacheReadInputTokens + } + if firstTokenMs == nil && !usageOnlyChunk { + elapsed := int(time.Since(startTime).Milliseconds()) + firstTokenMs 
= &elapsed + } + } + } + + if !clientDisconnected { + if _, werr := c.Writer.WriteString(line + "\n"); werr != nil { + clientDisconnected = true + logger.L().Debug("gateway claude chat_completions raw: client disconnected, draining for billing", + zap.Error(werr), + zap.String("request_id", requestID), + ) + } + } + if !clientDisconnected { + c.Writer.Flush() + } + } + if err := scanner.Err(); err != nil { + if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { + logger.L().Warn("gateway claude chat_completions raw: stream read error", + zap.Error(err), + zap.String("request_id", requestID), + ) + } + } + + return &ForwardResult{ + RequestID: requestID, + Usage: usage, + Model: originalModel, + UpstreamModel: upstreamModel, + Stream: true, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ClientDisconnect: clientDisconnected, + }, nil +} + +func (s *GatewayService) bufferClaudeRawChatCompletions( + c *gin.Context, + resp *http.Response, + originalModel string, + upstreamModel string, + startTime time.Time, +) (*ForwardResult, error) { + requestID := resp.Header.Get("x-request-id") + + respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError) + if err != nil { + if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + writeChatCompletionsError(c, http.StatusBadGateway, "api_error", "Failed to read upstream response") + } + return nil, fmt.Errorf("read upstream body: %w", err) + } + + var ccResp apicompat.ChatCompletionsResponse + var usage ClaudeUsage + if err := json.Unmarshal(respBody, &ccResp); err == nil && ccResp.Usage != nil { + usage.InputTokens = ccResp.Usage.PromptTokens + usage.OutputTokens = ccResp.Usage.CompletionTokens + if ccResp.Usage.PromptTokensDetails != nil { + usage.CacheReadInputTokens = ccResp.Usage.PromptTokensDetails.CachedTokens + } + } + + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + 
} + if ct := resp.Header.Get("Content-Type"); ct != "" { + c.Writer.Header().Set("Content-Type", ct) + } else { + c.Writer.Header().Set("Content-Type", "application/json") + } + c.Writer.WriteHeader(http.StatusOK) + _, _ = c.Writer.Write(respBody) + + return &ForwardResult{ + RequestID: requestID, + Usage: usage, + Model: originalModel, + UpstreamModel: upstreamModel, + Stream: false, + Duration: time.Since(startTime), + }, nil +} diff --git a/backend/internal/service/gateway_claude_chat_completions_raw_test.go b/backend/internal/service/gateway_claude_chat_completions_raw_test.go new file mode 100644 index 00000000000..64b8c29e16a --- /dev/null +++ b/backend/internal/service/gateway_claude_chat_completions_raw_test.go @@ -0,0 +1,161 @@ +//go:build unit + +package service + +import ( + "bytes" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +// claudeRawCCTestAccount 构造一个 Claude 平台 + apikey-chat-completions 账号, +// 上游指向给定的 CC URL。 +func claudeRawCCTestAccount(chatCompletionsURL string) *Account { + return &Account{ + ID: 22101, + Name: "claude-raw-cc-svc-test", + Platform: PlatformAnthropic, + Type: AccountTypeAPIKeyChatCompletions, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-claude-raw-cc", + "chat_completions_url": chatCompletionsURL, + }, + } +} + +// TestForwardClaudeChatCompletionsRaw_FailoverOn429 验证可 failover 状态码(429) +// 不直接吐错给客户端,而是返回 *UpstreamFailoverError,让 handler 决定是否切账号。 +func TestForwardClaudeChatCompletionsRaw_FailoverOn429(t *testing.T) { + gin.SetMode(gin.TestMode) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusTooManyRequests) + _, _ = w.Write([]byte(`{"error":{"message":"rate limited"}}`)) + })) + defer upstream.Close() + + httpUp := 
&anthropicCCStubUpstream{client: upstream.Client()} + svc := newAnthropicCCService(t, httpUp) + account := claudeRawCCTestAccount(upstream.URL + "/v1/chat/completions") + + body := []byte(`{"model":"deepseek-chat","messages":[{"role":"user","content":"ping"}],"stream":false}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + + parsed := &ParsedRequest{Body: NewRequestBodyRef(body), Model: "deepseek-chat", Stream: false} + _, err := svc.ForwardClaudeChatCompletionsRaw(c.Request.Context(), c, account, body, parsed) + require.Error(t, err) + + var failoverErr *UpstreamFailoverError + require.ErrorAs(t, err, &failoverErr) + require.Equal(t, http.StatusTooManyRequests, failoverErr.StatusCode) + require.Contains(t, string(failoverErr.ResponseBody), "rate limited") + + // 客户端响应未被写出 + require.Equal(t, 0, rec.Body.Len()) +} + +// TestForwardClaudeChatCompletionsRaw_NoFailoverOn400 验证不可 failover 的 4xx +// 仍然写出 OpenAI CC 格式的 error envelope(不是 Anthropic 格式)。 +func TestForwardClaudeChatCompletionsRaw_NoFailoverOn400(t *testing.T) { + gin.SetMode(gin.TestMode) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error":{"message":"bad model"}}`)) + })) + defer upstream.Close() + + httpUp := &anthropicCCStubUpstream{client: upstream.Client()} + svc := newAnthropicCCService(t, httpUp) + account := claudeRawCCTestAccount(upstream.URL + "/v1/chat/completions") + + body := []byte(`{"model":"deepseek-chat","messages":[{"role":"user","content":"ping"}],"stream":false}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + + parsed := &ParsedRequest{Body: NewRequestBodyRef(body), Model: 
"deepseek-chat", Stream: false} + _, err := svc.ForwardClaudeChatCompletionsRaw(c.Request.Context(), c, account, body, parsed) + require.Error(t, err) + + var failoverErr *UpstreamFailoverError + require.False(t, errorIsFailover(err, &failoverErr), "400 must not propagate as UpstreamFailoverError") + + // CC error envelope: {"error":{"type":"upstream_error","message":"..."}} + clientBody := rec.Body.String() + require.Contains(t, clientBody, `"error"`) + require.Contains(t, clientBody, `upstream_error`) +} + +// TestForwardClaudeChatCompletionsRaw_OnUpstreamAcceptedFires 验证 2xx 后立即触发回调。 +func TestForwardClaudeChatCompletionsRaw_OnUpstreamAcceptedFires(t *testing.T) { + gin.SetMode(gin.TestMode) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"id":"x","object":"chat.completion","model":"deepseek-chat","choices":[{"index":0,"message":{"role":"assistant","content":"pong"},"finish_reason":"stop"}],"usage":{"prompt_tokens":2,"completion_tokens":1,"total_tokens":3}}`)) + })) + defer upstream.Close() + + httpUp := &anthropicCCStubUpstream{client: upstream.Client()} + svc := newAnthropicCCService(t, httpUp) + account := claudeRawCCTestAccount(upstream.URL + "/v1/chat/completions") + + body := []byte(`{"model":"deepseek-chat","messages":[{"role":"user","content":"ping"}],"stream":false}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + + var acceptedCount int32 + parsed := &ParsedRequest{ + Body: NewRequestBodyRef(body), + Model: "deepseek-chat", + Stream: false, + OnUpstreamAccepted: func() { + atomic.AddInt32(&acceptedCount, 1) + }, + } + + result, err := svc.ForwardClaudeChatCompletionsRaw(c.Request.Context(), c, account, body, parsed) + require.NoError(t, err) + require.NotNil(t, result) + 
require.Equal(t, int32(1), atomic.LoadInt32(&acceptedCount)) +} + +// TestForwardClaudeChatCompletionsRaw_NilParsedSafe 验证 parsed=nil 时不 panic +// (兼容当前 CC handler 未注入回调的情况)。 +func TestForwardClaudeChatCompletionsRaw_NilParsedSafe(t *testing.T) { + gin.SetMode(gin.TestMode) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"id":"x","object":"chat.completion","model":"deepseek-chat","choices":[{"index":0,"message":{"role":"assistant","content":"pong"},"finish_reason":"stop"}],"usage":{"prompt_tokens":2,"completion_tokens":1,"total_tokens":3}}`)) + })) + defer upstream.Close() + + httpUp := &anthropicCCStubUpstream{client: upstream.Client()} + svc := newAnthropicCCService(t, httpUp) + account := claudeRawCCTestAccount(upstream.URL + "/v1/chat/completions") + + body := []byte(`{"model":"deepseek-chat","messages":[{"role":"user","content":"ping"}],"stream":false}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + + require.NotPanics(t, func() { + _, _ = svc.ForwardClaudeChatCompletionsRaw(c.Request.Context(), c, account, body, nil) + }) +} diff --git a/backend/internal/service/gateway_forward_anthropic_as_cc.go b/backend/internal/service/gateway_forward_anthropic_as_cc.go new file mode 100644 index 00000000000..20bb7120f37 --- /dev/null +++ b/backend/internal/service/gateway_forward_anthropic_as_cc.go @@ -0,0 +1,561 @@ +package service + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// anthropicAsCCAllowedHeaders 是 Anthropic→CC 转换路径的客户端 header 透传白名单。 
+// +// 故意不包含 anthropic-* 系列 header(anthropic-version / anthropic-beta / x-api-key +// 等)——上游是 OpenAI 兼容的 Chat Completions 端点,对这些 header 一无所知, +// 透传只会被忽略或触发上游报错。授权头由本路径显式设置 Authorization: Bearer。 +var anthropicAsCCAllowedHeaders = map[string]bool{ + "accept-language": true, + "user-agent": true, +} + +// ForwardAnthropicAsChatCompletions accepts an Anthropic Messages API request +// and forwards it to a Chat-Completions-only upstream by: +// +// 1. Parsing the body as AnthropicRequest +// 2. Converting Anthropic → Responses → Chat Completions (chained via apicompat) +// 3. POSTing the CC body to account.GetOpenAIChatCompletionsURL() with stream=true +// 4. Converting the upstream CC SSE chunks back to Anthropic Messages format: +// - stream: per-chunk CC→Responses→Anthropic SSE conversion +// - non-stream: buffer SSE → assemble Responses → AnthropicResponse JSON +// +// Adapter for `apikey-chat-completions` accounts hit through the /v1/messages +// entry on Anthropic-platform groups. Mirrors OpenAIGatewayService.ForwardResponsesAsChatCompletions +// but with Anthropic as the client-facing protocol. +func (s *GatewayService) ForwardAnthropicAsChatCompletions( + ctx context.Context, + c *gin.Context, + account *Account, + parsed *ParsedRequest, +) (*ForwardResult, error) { + startTime := time.Now() + if parsed == nil { + return nil, fmt.Errorf("parse request: empty parsed request") + } + body := parsed.Body.Bytes() + + // 1. 
Parse Anthropic request + var anthropicReq apicompat.AnthropicRequest + if err := json.Unmarshal(body, &anthropicReq); err != nil { + writeAnthropicCCError(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") + return nil, fmt.Errorf("parse anthropic request: %w", err) + } + originalModel := strings.TrimSpace(anthropicReq.Model) + if originalModel == "" { + writeAnthropicCCError(c, http.StatusBadRequest, "invalid_request_error", "model is required") + return nil, fmt.Errorf("model is required") + } + clientStream := anthropicReq.Stream + + // 2. Model mapping (apikey-chat-completions: 由用户自定义上游模型 ID, 不做 Anthropic 规范化) + mappedModel := account.GetMappedModel(originalModel) + if mappedModel != originalModel { + anthropicReq.Model = mappedModel + } + + // 3. Chain conversion: Anthropic → Responses → CC + responsesReq, err := apicompat.AnthropicToResponses(&anthropicReq) + if err != nil { + writeAnthropicCCError(c, http.StatusBadRequest, "invalid_request_error", "Failed to convert request: "+err.Error()) + return nil, fmt.Errorf("convert anthropic to responses: %w", err) + } + ccReq, err := apicompat.ResponsesToChatCompletionsRequest(responsesReq) + if err != nil { + writeAnthropicCCError(c, http.StatusBadRequest, "invalid_request_error", "Failed to convert request: "+err.Error()) + return nil, fmt.Errorf("convert responses to chat completions: %w", err) + } + // 4. Force upstream stream so we can reuse the CC-SSE chain regardless of client. + ccReq.Stream = true + ccReq.Model = mappedModel + ccBody, err := json.Marshal(ccReq) + if err != nil { + return nil, fmt.Errorf("marshal chat completions request: %w", err) + } + ccBody, err = ensureOpenAIChatStreamUsage(ccBody) + if err != nil { + return nil, fmt.Errorf("enable stream usage: %w", err) + } + + // 5. 
Upstream URL + API key + apiKey := account.GetUpstreamAPIKey() + if apiKey == "" { + writeAnthropicCCError(c, http.StatusBadGateway, "api_error", fmt.Sprintf("account %d missing api_key", account.ID)) + return nil, fmt.Errorf("account %d missing api_key", account.ID) + } + targetURL := account.GetOpenAIChatCompletionsURL() + if targetURL == "" { + writeAnthropicCCError(c, http.StatusBadGateway, "api_error", fmt.Sprintf("account %d missing chat_completions_url", account.ID)) + return nil, fmt.Errorf("account %d missing chat_completions_url", account.ID) + } + if _, err := s.validateUpstreamBaseURL(targetURL); err != nil { + writeAnthropicCCError(c, http.StatusBadGateway, "api_error", "Invalid upstream URL") + return nil, fmt.Errorf("invalid chat_completions_url: %w", err) + } + + logger.L().Debug("gateway anthropic→cc: forwarding", + zap.Int64("account_id", account.ID), + zap.String("original_model", originalModel), + zap.String("upstream_model", mappedModel), + zap.String("upstream_url", targetURL), + zap.Bool("client_stream", clientStream), + ) + + // 6. Build upstream request + upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx) + upstreamReq, err := http.NewRequestWithContext(upstreamCtx, http.MethodPost, targetURL, bytes.NewReader(ccBody)) + releaseUpstreamCtx() + if err != nil { + return nil, fmt.Errorf("build upstream request: %w", err) + } + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("Accept", "text/event-stream") + applyOpenAICompatibleAPIKeyAuth(upstreamReq, account, apiKey) + if c != nil && c.Request != nil { + for key, values := range c.Request.Header { + if anthropicAsCCAllowedHeaders[strings.ToLower(key)] { + for _, v := range values { + upstreamReq.Header.Add(key, v) + } + } + } + } + + // 7. 
Send + proxyURL := "" + if account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + writeAnthropicCCError(c, http.StatusBadGateway, "api_error", "Upstream request failed") + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + defer func() { _ = resp.Body.Close() }() + + // 8. Error handling — translate to Anthropic error envelope + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + if upstreamMsg == "" { + upstreamMsg = http.StatusText(resp.StatusCode) + } + + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "upstream_error", + Message: upstreamMsg, + }) + if s.rateLimitService != nil { + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + } + // 可重试 / 可切账号的状态码(401/403/429/529/5xx)走 failover loop —— 不在此处写客户端响应, + // 让 handler 端根据 stream-written guard 决定是否切换到下一个账号。 + if s.shouldFailoverUpstreamError(resp.StatusCode) { + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + ResponseHeaders: resp.Header, + RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + } + } + // 不可 failover 的错误(如 
400 / 404 / 422):原样写 Anthropic 错误信封并返回普通错误 + writeAnthropicCCError(c, mapUpstreamStatusCode(resp.StatusCode), anthropicErrorTypeForStatus(resp.StatusCode), upstreamMsg) + return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg) + } + + // 上游已接受请求(2xx),提前释放用户串行锁——不等流完成 + if parsed.OnUpstreamAccepted != nil { + parsed.OnUpstreamAccepted() + } + + // 9. Forward response + if clientStream { + return s.streamAnthropicFromCC(c, resp, originalModel, mappedModel, startTime) + } + return s.bufferAnthropicFromCC(c, resp, originalModel, mappedModel, startTime) +} + +// streamAnthropicFromCC reads upstream Chat Completions SSE chunks line-by-line, +// converts each through CC→Responses→Anthropic, and writes Anthropic SSE to client. +func (s *GatewayService) streamAnthropicFromCC( + c *gin.Context, + resp *http.Response, + originalModel string, + upstreamModel string, + startTime time.Time, +) (*ForwardResult, error) { + requestID := resp.Header.Get("x-request-id") + + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.WriteHeader(http.StatusOK) + + ccState := apicompat.NewCCStreamState() + ccState.Model = originalModel + anthState := apicompat.NewResponsesEventToAnthropicState() + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + + var firstTokenMs *int + var usage ClaudeUsage + clientDisconnected := false + + writeAnthropicEvent := func(evt apicompat.AnthropicStreamEvent) bool { + if clientDisconnected { + return true + } + sse, err := 
apicompat.ResponsesAnthropicEventToSSE(evt) + if err != nil { + return false + } + if _, err := c.Writer.WriteString(sse); err != nil { + clientDisconnected = true + logger.L().Debug("gateway anthropic→cc stream: client disconnected", + zap.Error(err), + zap.String("request_id", requestID), + ) + return true + } + c.Writer.Flush() + return false + } + + processResponsesEvents := func(events [][]byte) bool { + for _, frame := range events { + // Each frame is "event: \ndata: \n\n" or "data: [DONE]\n\n". + payloads := extractResponsesSSEDataPayloads(frame) + for _, payload := range payloads { + if bytes.Equal(bytes.TrimSpace(payload), []byte("[DONE]")) { + continue + } + var resEvt apicompat.ResponsesStreamEvent + if err := json.Unmarshal(payload, &resEvt); err != nil { + continue + } + anthEvents := apicompat.ResponsesEventToAnthropicEvents(&resEvt, anthState) + for _, anthEvt := range anthEvents { + if firstTokenMs == nil { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + mergeAnthropicUsageFromEvent(&usage, anthEvt) + if disconnected := writeAnthropicEvent(anthEvt); disconnected { + return true + } + } + } + } + return false + } + + for scanner.Scan() { + line := scanner.Bytes() + if len(bytes.TrimSpace(line)) == 0 { + continue + } + events, err := apicompat.ConvertChatCompletionsSSEChunkToResponsesEvents(line, ccState) + if err != nil { + logger.L().Warn("gateway anthropic→cc stream: failed to convert chunk", + zap.Error(err), + zap.String("request_id", requestID), + ) + } + if processResponsesEvents(events) { + break + } + } + if err := scanner.Err(); err != nil { + if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { + logger.L().Warn("gateway anthropic→cc stream: read error", + zap.Error(err), + zap.String("request_id", requestID), + ) + } + } + + // Drain CC final events if upstream did not send [DONE]. 
+ if !ccState.CompletedSent { + processResponsesEvents(apicompat.FinalizeCCStream(ccState)) + } + // Flush trailing Anthropic events from the Responses→Anthropic state machine. + for _, evt := range apicompat.FinalizeResponsesAnthropicStream(anthState) { + mergeAnthropicUsageFromEvent(&usage, evt) + if writeAnthropicEvent(evt) { + break + } + } + + // CC final usage may only be present in the closing chunk; pull from CC state if event-derived + // usage is empty. + if usage.InputTokens == 0 && usage.OutputTokens == 0 && ccState.Usage != nil { + usage.InputTokens = ccState.Usage.InputTokens + usage.OutputTokens = ccState.Usage.OutputTokens + if ccState.Usage.InputTokensDetails != nil { + usage.CacheReadInputTokens = ccState.Usage.InputTokensDetails.CachedTokens + } + } + + return &ForwardResult{ + RequestID: requestID, + Usage: usage, + Model: originalModel, + UpstreamModel: upstreamModel, + Stream: true, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ClientDisconnect: clientDisconnected, + }, nil +} + +// bufferAnthropicFromCC reads upstream CC SSE (we forced stream=true), assembles +// a complete ChatCompletionsResponse from the chunks, then converts to AnthropicResponse JSON. 
+func (s *GatewayService) bufferAnthropicFromCC( + c *gin.Context, + resp *http.Response, + originalModel string, + upstreamModel string, + startTime time.Time, +) (*ForwardResult, error) { + requestID := resp.Header.Get("x-request-id") + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + + ccResp := apicompat.ChatCompletionsResponse{ + Object: "chat.completion", + Model: originalModel, + } + var contentBuf strings.Builder + var reasoningBuf strings.Builder + finishReason := "" + type toolCallAcc struct { + ID string + Name string + Arguments strings.Builder + } + toolCalls := map[int]*toolCallAcc{} + var toolCallOrder []int + + for scanner.Scan() { + line := bytes.TrimRight(scanner.Bytes(), "\r\n") + if len(line) == 0 { + continue + } + if !bytes.HasPrefix(line, []byte("data:")) { + continue + } + payload := bytes.TrimSpace(line[len("data:"):]) + if len(payload) == 0 || bytes.Equal(payload, []byte("[DONE]")) { + continue + } + var chk apicompat.ChatCompletionsChunk + if err := json.Unmarshal(payload, &chk); err != nil { + continue + } + if ccResp.ID == "" && chk.ID != "" { + ccResp.ID = chk.ID + } + if chk.Model != "" { + ccResp.Model = chk.Model + } + if chk.Usage != nil { + ccResp.Usage = chk.Usage + } + for _, ch := range chk.Choices { + if ch.Delta.Content != nil { + contentBuf.WriteString(*ch.Delta.Content) + } + if ch.Delta.ReasoningContent != nil { + reasoningBuf.WriteString(*ch.Delta.ReasoningContent) + } + for _, tc := range ch.Delta.ToolCalls { + idx := 0 + if tc.Index != nil { + idx = *tc.Index + } + st, ok := toolCalls[idx] + if !ok { + st = &toolCallAcc{} + toolCalls[idx] = st + toolCallOrder = append(toolCallOrder, idx) + } + if tc.ID != "" { + st.ID = tc.ID + } + if tc.Function.Name != "" { + st.Name = tc.Function.Name + } + if tc.Function.Arguments != "" { + 
st.Arguments.WriteString(tc.Function.Arguments) + } + } + if ch.FinishReason != nil && *ch.FinishReason != "" { + finishReason = *ch.FinishReason + } + } + } + if err := scanner.Err(); err != nil { + if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { + logger.L().Warn("gateway anthropic→cc buffered: read error", + zap.Error(err), + zap.String("request_id", requestID), + ) + } + } + + msg := apicompat.ChatMessage{Role: "assistant"} + contentText := contentBuf.String() + if contentText != "" { + raw, _ := json.Marshal(contentText) + msg.Content = raw + } else { + msg.Content = json.RawMessage(`""`) + } + if reasoningBuf.Len() > 0 { + msg.ReasoningContent = reasoningBuf.String() + } + for _, idx := range toolCallOrder { + st := toolCalls[idx] + args := st.Arguments.String() + if args == "" { + args = "{}" + } + msg.ToolCalls = append(msg.ToolCalls, apicompat.ChatToolCall{ + ID: st.ID, + Type: "function", + Function: apicompat.ChatFunctionCall{ + Name: st.Name, + Arguments: args, + }, + }) + } + if finishReason == "" { + if len(msg.ToolCalls) > 0 { + finishReason = "tool_calls" + } else { + finishReason = "stop" + } + } + ccResp.Choices = []apicompat.ChatChoice{{Index: 0, Message: msg, FinishReason: finishReason}} + + responsesResp := apicompat.ChatCompletionsToResponsesResponse(&ccResp, originalModel) + anthResp := apicompat.ResponsesToAnthropic(responsesResp, originalModel) + + usage := ClaudeUsage{ + InputTokens: anthResp.Usage.InputTokens, + OutputTokens: anthResp.Usage.OutputTokens, + CacheReadInputTokens: anthResp.Usage.CacheReadInputTokens, + } + + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + out, err := json.Marshal(anthResp) + if err != nil { + writeAnthropicCCError(c, http.StatusBadGateway, "api_error", "Failed to marshal response") + return nil, fmt.Errorf("marshal anthropic response: %w", err) + } + c.Data(http.StatusOK, 
"application/json; charset=utf-8", out) + + return &ForwardResult{ + RequestID: requestID, + Usage: usage, + Model: originalModel, + UpstreamModel: upstreamModel, + Stream: false, + Duration: time.Since(startTime), + }, nil +} + +// mergeAnthropicUsageFromEvent inspects an Anthropic stream event and merges any +// usage information found into the running ClaudeUsage tally. +func mergeAnthropicUsageFromEvent(target *ClaudeUsage, evt apicompat.AnthropicStreamEvent) { + if evt.Type == "message_start" && evt.Message != nil { + mergeAnthropicUsage(target, evt.Message.Usage) + } + if evt.Type == "message_delta" && evt.Usage != nil { + mergeAnthropicUsage(target, *evt.Usage) + } +} + +// writeAnthropicCCError writes an Anthropic-style error envelope. Used by the +// Anthropic→CC forwarding path so Claude SDK clients can parse failures. +func writeAnthropicCCError(c *gin.Context, statusCode int, errType, message string) { + if c == nil { + return + } + c.JSON(statusCode, gin.H{ + "type": "error", + "error": gin.H{ + "type": errType, + "message": message, + }, + }) +} + +// anthropicErrorTypeForStatus maps an upstream HTTP status code to an Anthropic +// error envelope `error.type` string. 
+func anthropicErrorTypeForStatus(status int) string { + switch { + case status == http.StatusUnauthorized || status == http.StatusForbidden: + return "authentication_error" + case status == http.StatusTooManyRequests: + return "rate_limit_error" + case status >= 400 && status < 500: + return "invalid_request_error" + default: + return "api_error" + } +} diff --git a/backend/internal/service/gateway_forward_anthropic_as_cc_test.go b/backend/internal/service/gateway_forward_anthropic_as_cc_test.go new file mode 100644 index 00000000000..d21f9de597e --- /dev/null +++ b/backend/internal/service/gateway_forward_anthropic_as_cc_test.go @@ -0,0 +1,432 @@ +//go:build unit + +package service + +import ( + "bytes" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +// anthropicCCStubUpstream forwards Do/DoWithTLS to an httptest.Server's client, +// remembering the last request's URL / Authorization / body for assertion. 
+type anthropicCCStubUpstream struct { + client *http.Client + requestCount int32 + lastURL string + lastAuthz string + lastAccept string + lastBody []byte + lastHeaders http.Header +} + +func (u *anthropicCCStubUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { + atomic.AddInt32(&u.requestCount, 1) + u.lastURL = req.URL.String() + u.lastAuthz = req.Header.Get("Authorization") + u.lastAccept = req.Header.Get("Accept") + u.lastHeaders = req.Header.Clone() + if req.Body != nil { + buf := new(bytes.Buffer) + _, _ = buf.ReadFrom(req.Body) + u.lastBody = buf.Bytes() + _ = req.Body.Close() + req.Body = anthropicCCReusableBody(u.lastBody) + req.ContentLength = int64(len(u.lastBody)) + } + return u.client.Do(req) +} + +func (u *anthropicCCStubUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) { + return u.Do(req, proxyURL, accountID, accountConcurrency) +} + +type anthropicCCBody struct { + *bytes.Reader +} + +func (r *anthropicCCBody) Close() error { return nil } + +func anthropicCCReusableBody(b []byte) *anthropicCCBody { + return &anthropicCCBody{Reader: bytes.NewReader(b)} +} + +func anthropicCCTestConfig() *config.Config { + return &config.Config{ + Security: config.SecurityConfig{ + URLAllowlist: config.URLAllowlistConfig{ + Enabled: false, + AllowInsecureHTTP: true, + }, + }, + } +} + +func anthropicCCTestAccount(chatCompletionsURL string) *Account { + return &Account{ + ID: 22001, + Name: "anthropic-cc-svc-test", + Platform: PlatformAnthropic, + Type: AccountTypeAPIKeyChatCompletions, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-anthropic-cc-test", + "chat_completions_url": chatCompletionsURL, + }, + } +} + +func newAnthropicCCService(t *testing.T, httpUp HTTPUpstream) *GatewayService { + t.Helper() + cfg := anthropicCCTestConfig() + 
svc := &GatewayService{ + cfg: cfg, + httpUpstream: httpUp, + } + return svc +} + +func TestForwardAnthropicAsChatCompletions_NonStream(t *testing.T) { + gin.SetMode(gin.TestMode) + + upstreamSSE := strings.Join([]string{ + `data: {"id":"chk1","object":"chat.completion.chunk","model":"deepseek-chat","choices":[{"index":0,"delta":{"role":"assistant"}}]}`, + ``, + `data: {"id":"chk1","object":"chat.completion.chunk","model":"deepseek-chat","choices":[{"index":0,"delta":{"content":"hello"}}]}`, + ``, + `data: {"id":"chk1","object":"chat.completion.chunk","model":"deepseek-chat","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":4,"completion_tokens":1,"total_tokens":5}}`, + ``, + `data: [DONE]`, + ``, + }, "\n") + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(upstreamSSE)) + })) + defer upstream.Close() + + httpUp := &anthropicCCStubUpstream{client: upstream.Client()} + svc := newAnthropicCCService(t, httpUp) + account := anthropicCCTestAccount(upstream.URL + "/v1/chat/completions") + + body := []byte(`{"model":"claude-sonnet-4","max_tokens":64,"stream":false,"messages":[{"role":"user","content":"ping"}]}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + c.Request.Header.Set("anthropic-beta", "tools-2024-05-16") + c.Request.Header.Set("anthropic-version", "2023-06-01") + + parsed := &ParsedRequest{Body: NewRequestBodyRef(body), Model: "claude-sonnet-4", Stream: false} + result, err := svc.ForwardAnthropicAsChatCompletions(c.Request.Context(), c, account, parsed) + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.Stream) + + require.Equal(t, upstream.URL+"/v1/chat/completions", 
httpUp.lastURL) + require.Equal(t, "Bearer sk-anthropic-cc-test", httpUp.lastAuthz) + require.Equal(t, "text/event-stream", httpUp.lastAccept) + + // Header whitelist: anthropic-* MUST NOT leak. + require.Empty(t, httpUp.lastHeaders.Get("anthropic-beta")) + require.Empty(t, httpUp.lastHeaders.Get("anthropic-version")) + require.Empty(t, httpUp.lastHeaders.Get("x-api-key")) + + // Upstream body must be ChatCompletions (messages, not Anthropic 'system' style). + require.Contains(t, string(httpUp.lastBody), `"messages"`) + require.True(t, gjson.GetBytes(httpUp.lastBody, "stream").Bool()) + require.True(t, gjson.GetBytes(httpUp.lastBody, "stream_options.include_usage").Bool()) + + // Client got an Anthropic JSON response. + clientBody := rec.Body.String() + require.True(t, gjson.Valid(clientBody)) + require.Equal(t, "message", gjson.Get(clientBody, "type").String()) + require.Contains(t, clientBody, "hello") +} + +func TestForwardAnthropicAsChatCompletions_Stream(t *testing.T) { + gin.SetMode(gin.TestMode) + + upstreamSSE := strings.Join([]string{ + `data: {"id":"chk2","object":"chat.completion.chunk","model":"deepseek-chat","choices":[{"index":0,"delta":{"role":"assistant"}}]}`, + ``, + `data: {"id":"chk2","object":"chat.completion.chunk","model":"deepseek-chat","choices":[{"index":0,"delta":{"content":"Hi"}}]}`, + ``, + `data: {"id":"chk2","object":"chat.completion.chunk","model":"deepseek-chat","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":3,"completion_tokens":1,"total_tokens":4}}`, + ``, + `data: [DONE]`, + ``, + }, "\n") + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(upstreamSSE)) + })) + defer upstream.Close() + + httpUp := &anthropicCCStubUpstream{client: upstream.Client()} + svc := newAnthropicCCService(t, httpUp) + account := anthropicCCTestAccount(upstream.URL + 
"/v1/chat/completions") + + body := []byte(`{"model":"claude-sonnet-4","max_tokens":64,"stream":true,"messages":[{"role":"user","content":"hi"}]}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + parsed := &ParsedRequest{Body: NewRequestBodyRef(body), Model: "claude-sonnet-4", Stream: true} + result, err := svc.ForwardAnthropicAsChatCompletions(c.Request.Context(), c, account, parsed) + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.Stream) + + clientBody := rec.Body.String() + require.Contains(t, clientBody, "event: message_start") + require.Contains(t, clientBody, "event: content_block_start") + require.Contains(t, clientBody, "event: content_block_delta") + require.Contains(t, clientBody, "event: message_delta") + require.Contains(t, clientBody, "event: message_stop") +} + +func TestForwardAnthropicAsChatCompletions_ToolUse(t *testing.T) { + gin.SetMode(gin.TestMode) + + upstreamSSE := strings.Join([]string{ + `data: {"id":"chk3","object":"chat.completion.chunk","model":"m","choices":[{"index":0,"delta":{"role":"assistant"}}]}`, + ``, + `data: {"id":"chk3","object":"chat.completion.chunk","model":"m","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"call_x","function":{"name":"lookup"}}]}}]}`, + ``, + `data: {"id":"chk3","object":"chat.completion.chunk","model":"m","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"q\":\"x\"}"}}]}}]}`, + ``, + `data: {"id":"chk3","object":"chat.completion.chunk","model":"m","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`, + ``, + `data: [DONE]`, + ``, + }, "\n") + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + _, _ = 
w.Write([]byte(upstreamSSE)) + })) + defer upstream.Close() + + httpUp := &anthropicCCStubUpstream{client: upstream.Client()} + svc := newAnthropicCCService(t, httpUp) + account := anthropicCCTestAccount(upstream.URL + "/v1/chat/completions") + + body := []byte(`{"model":"claude-sonnet-4","max_tokens":64,"stream":true,"tools":[{"name":"lookup","input_schema":{"type":"object","properties":{"q":{"type":"string"}}}}],"messages":[{"role":"user","content":"please lookup"}]}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + parsed := &ParsedRequest{Body: NewRequestBodyRef(body), Model: "claude-sonnet-4", Stream: true} + result, err := svc.ForwardAnthropicAsChatCompletions(c.Request.Context(), c, account, parsed) + require.NoError(t, err) + require.NotNil(t, result) + + clientBody := rec.Body.String() + require.Contains(t, clientBody, "tool_use") + require.Contains(t, clientBody, "lookup") + require.Contains(t, clientBody, "input_json_delta") +} + +// TestForwardAnthropicAsChatCompletions_FailoverOn429 验证可 failover 状态码(429) +// 不直接吐错给客户端,而是返回 *UpstreamFailoverError,让外层 handler 决定是否切账号。 +func TestForwardAnthropicAsChatCompletions_FailoverOn429(t *testing.T) { + gin.SetMode(gin.TestMode) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Cf-Ray", "abc123") + w.WriteHeader(http.StatusTooManyRequests) + _, _ = w.Write([]byte(`{"error":{"message":"rate limited"}}`)) + })) + defer upstream.Close() + + httpUp := &anthropicCCStubUpstream{client: upstream.Client()} + svc := newAnthropicCCService(t, httpUp) + account := anthropicCCTestAccount(upstream.URL + "/v1/chat/completions") + + body := 
[]byte(`{"model":"claude-sonnet-4","max_tokens":64,"stream":false,"messages":[{"role":"user","content":"ping"}]}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + + parsed := &ParsedRequest{Body: NewRequestBodyRef(body), Model: "claude-sonnet-4", Stream: false} + _, err := svc.ForwardAnthropicAsChatCompletions(c.Request.Context(), c, account, parsed) + require.Error(t, err) + + var failoverErr *UpstreamFailoverError + require.ErrorAs(t, err, &failoverErr) + require.Equal(t, http.StatusTooManyRequests, failoverErr.StatusCode) + require.Contains(t, string(failoverErr.ResponseBody), "rate limited") + require.Equal(t, "abc123", failoverErr.ResponseHeaders.Get("X-Cf-Ray")) + + // 客户端响应未被写出,让外层 handler 决定后续动作 + require.Equal(t, 0, rec.Body.Len()) +} + +// TestForwardAnthropicAsChatCompletions_FailoverOn500 同样验证 5xx 走 failover。 +func TestForwardAnthropicAsChatCompletions_FailoverOn500(t *testing.T) { + gin.SetMode(gin.TestMode) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"error":{"message":"upstream busted"}}`)) + })) + defer upstream.Close() + + httpUp := &anthropicCCStubUpstream{client: upstream.Client()} + svc := newAnthropicCCService(t, httpUp) + account := anthropicCCTestAccount(upstream.URL + "/v1/chat/completions") + + body := []byte(`{"model":"claude-sonnet-4","max_tokens":64,"stream":false,"messages":[{"role":"user","content":"ping"}]}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + + parsed := &ParsedRequest{Body: NewRequestBodyRef(body), Model: "claude-sonnet-4", Stream: false} + _, err := svc.ForwardAnthropicAsChatCompletions(c.Request.Context(), c, account, parsed) + + var failoverErr *UpstreamFailoverError + 
require.ErrorAs(t, err, &failoverErr) + require.Equal(t, http.StatusInternalServerError, failoverErr.StatusCode) + require.Equal(t, 0, rec.Body.Len()) +} + +// TestForwardAnthropicAsChatCompletions_NoFailoverOn400 验证不可 failover 的 4xx +// (如 400/404/422)保持原有行为:写 Anthropic error envelope + 返回 plain error。 +func TestForwardAnthropicAsChatCompletions_NoFailoverOn400(t *testing.T) { + gin.SetMode(gin.TestMode) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error":{"message":"bad model name"}}`)) + })) + defer upstream.Close() + + httpUp := &anthropicCCStubUpstream{client: upstream.Client()} + svc := newAnthropicCCService(t, httpUp) + account := anthropicCCTestAccount(upstream.URL + "/v1/chat/completions") + + body := []byte(`{"model":"claude-sonnet-4","max_tokens":64,"stream":false,"messages":[{"role":"user","content":"ping"}]}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + + parsed := &ParsedRequest{Body: NewRequestBodyRef(body), Model: "claude-sonnet-4", Stream: false} + _, err := svc.ForwardAnthropicAsChatCompletions(c.Request.Context(), c, account, parsed) + require.Error(t, err) + + var failoverErr *UpstreamFailoverError + require.False(t, errorIsFailover(err, &failoverErr), "400 must not propagate as UpstreamFailoverError") + + clientBody := rec.Body.String() + require.True(t, gjson.Valid(clientBody)) + require.Equal(t, "error", gjson.Get(clientBody, "type").String()) + require.Equal(t, "invalid_request_error", gjson.Get(clientBody, "error.type").String()) +} + +// TestForwardAnthropicAsChatCompletions_OnUpstreamAcceptedFires 验证上游返回 2xx +// 之后立即触发 OnUpstreamAccepted 回调,便于 handler 提前释放用户串行锁。 +func TestForwardAnthropicAsChatCompletions_OnUpstreamAcceptedFires(t *testing.T) { + 
gin.SetMode(gin.TestMode) + + upstreamSSE := strings.Join([]string{ + `data: {"id":"chk1","object":"chat.completion.chunk","model":"deepseek-chat","choices":[{"index":0,"delta":{"role":"assistant"}}]}`, + ``, + `data: {"id":"chk1","object":"chat.completion.chunk","model":"deepseek-chat","choices":[{"index":0,"delta":{"content":"ok"}}]}`, + ``, + `data: {"id":"chk1","object":"chat.completion.chunk","model":"deepseek-chat","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":3,"completion_tokens":1,"total_tokens":4}}`, + ``, + `data: [DONE]`, + ``, + }, "\n") + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(upstreamSSE)) + })) + defer upstream.Close() + + httpUp := &anthropicCCStubUpstream{client: upstream.Client()} + svc := newAnthropicCCService(t, httpUp) + account := anthropicCCTestAccount(upstream.URL + "/v1/chat/completions") + + body := []byte(`{"model":"claude-sonnet-4","max_tokens":64,"stream":false,"messages":[{"role":"user","content":"ping"}]}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + + var acceptedCount int32 + parsed := &ParsedRequest{ + Body: NewRequestBodyRef(body), + Model: "claude-sonnet-4", + Stream: false, + OnUpstreamAccepted: func() { + atomic.AddInt32(&acceptedCount, 1) + }, + } + + result, err := svc.ForwardAnthropicAsChatCompletions(c.Request.Context(), c, account, parsed) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, int32(1), atomic.LoadInt32(&acceptedCount)) +} + +// errorIsFailover 是 errors.As 的薄包装,避免重复使用 Go 1.20 ErrorAs 的笔误。 +func errorIsFailover(err error, target **UpstreamFailoverError) bool { + if err == nil { + return false + } + for cur := err; cur != nil; { + if cast, ok := cur.(*UpstreamFailoverError); 
ok { + *target = cast + return true + } + type unwrapper interface{ Unwrap() error } + if u, ok := cur.(unwrapper); ok { + cur = u.Unwrap() + continue + } + break + } + return false +} + +func TestForwardAnthropicAsChatCompletions_MissingURL(t *testing.T) { + gin.SetMode(gin.TestMode) + + httpUp := &anthropicCCStubUpstream{client: http.DefaultClient} + svc := newAnthropicCCService(t, httpUp) + account := anthropicCCTestAccount("") + + body := []byte(`{"model":"claude-sonnet-4","max_tokens":64,"stream":false,"messages":[{"role":"user","content":"ping"}]}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + + parsed := &ParsedRequest{Body: NewRequestBodyRef(body), Model: "claude-sonnet-4", Stream: false} + _, err := svc.ForwardAnthropicAsChatCompletions(c.Request.Context(), c, account, parsed) + require.Error(t, err) + require.Contains(t, err.Error(), "chat_completions_url") + require.Equal(t, int32(0), atomic.LoadInt32(&httpUp.requestCount)) +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 812780dc4b3..77ca42ecb5e 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -4439,6 +4439,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A return s.handleWebSearchEmulation(ctx, c, account, parsed) } + // apikey-chat-completions:客户端 /v1/messages → Anthropic↔CC 协议转换 → 上游 OpenAI 兼容 CC 端点。 + // 必须在 Anthropic API key passthrough / Bedrock 等分支之前,因为 IsAnthropicAPIKeyPassthroughEnabled + // 只看类型 + 配置,可能错误命中此账号类型。 + if account != nil && account.IsOpenAIChatCompletionsUpstream() { + return s.ForwardAnthropicAsChatCompletions(ctx, c, account, parsed) + } + if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() { passthroughBody := parsed.Body.Bytes() passthroughModel := parsed.Model diff --git 
a/backend/internal/service/openai_compatible_auth.go b/backend/internal/service/openai_compatible_auth.go new file mode 100644 index 00000000000..bb1481643c3 --- /dev/null +++ b/backend/internal/service/openai_compatible_auth.go @@ -0,0 +1,14 @@ +package service + +import "net/http" + +func applyOpenAICompatibleAPIKeyAuth(req *http.Request, account *Account, apiKey string) { + switch account.OpenAICompatibleAuthHeader() { + case OpenAICompatibleAuthHeaderAPIKey: + req.Header.Set(OpenAICompatibleAuthHeaderAPIKey, apiKey) + case OpenAICompatibleAuthHeaderXAPIKey: + req.Header.Set(OpenAICompatibleAuthHeaderXAPIKey, apiKey) + default: + req.Header.Set(OpenAICompatibleAuthHeaderAuthorization, "Bearer "+apiKey) + } +} diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go index 6e91d85c83a..a845a763b72 100644 --- a/backend/internal/service/openai_gateway_chat_completions.go +++ b/backend/internal/service/openai_gateway_chat_completions.go @@ -61,9 +61,13 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( promptCacheKey string, defaultMappedModel string, ) (*OpenAIForwardResult, error) { - // 入口分流:APIKey 账号 + 强制或已探测确认上游不支持 Responses,走 CC 直转。 - // 自动模式下标记缺失(未探测)按"现状即证据"原则继续走下方原 Responses 转换路径。 - if account.Type == AccountTypeAPIKey && !openai_compat.ShouldUseResponsesAPI(account.Extra) { + // 入口分流:APIKey / APIKeyChatCompletions 账号走 CC 直转。 + // - APIKeyChatCompletions:上游为 OpenAI 兼容 Chat Completions 端点,原生不支持 Responses + // - APIKey + 已探测且确认上游不支持 Responses + // 标记缺失(未探测)按"现状即证据"原则继续走下方原 Responses 转换路径。 + if account.IsOpenAIChatCompletionsUpstream() || + (account.Type == AccountTypeAPIKey && !openai_compat.ShouldUseResponsesAPI(account.Extra)) { + return s.forwardAsRawChatCompletions(ctx, c, account, body, defaultMappedModel) } diff --git a/backend/internal/service/openai_gateway_chat_completions_raw.go b/backend/internal/service/openai_gateway_chat_completions_raw.go index 
3ff6fac4aa8..98f0419c8a8 100644 --- a/backend/internal/service/openai_gateway_chat_completions_raw.go +++ b/backend/internal/service/openai_gateway_chat_completions_raw.go @@ -128,7 +128,19 @@ func (s *OpenAIGatewayService) forwardAsRawChatCompletions( if err != nil { return nil, fmt.Errorf("invalid base_url: %w", err) } - targetURL := buildOpenAIChatCompletionsURL(validatedURL) + var targetURL string + if account.IsOpenAIChatCompletionsUpstream() { + ccURL := account.GetOpenAIChatCompletionsURL() + if ccURL == "" { + return nil, fmt.Errorf("account %d missing chat_completions_url", account.ID) + } + if _, err := s.validateUpstreamBaseURL(ccURL); err != nil { + return nil, fmt.Errorf("invalid chat_completions_url: %w", err) + } + targetURL = ccURL + } else { + targetURL = buildOpenAIChatCompletionsURL(validatedURL) + } upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx) upstreamReq, err := http.NewRequestWithContext(upstreamCtx, http.MethodPost, targetURL, bytes.NewReader(upstreamBody)) @@ -138,7 +150,7 @@ func (s *OpenAIGatewayService) forwardAsRawChatCompletions( } upstreamReq = upstreamReq.WithContext(WithHTTPUpstreamProfile(upstreamReq.Context(), HTTPUpstreamProfileOpenAI)) upstreamReq.Header.Set("Content-Type", "application/json") - upstreamReq.Header.Set("Authorization", "Bearer "+apiKey) + applyOpenAICompatibleAPIKeyAuth(upstreamReq, account, apiKey) if clientStream { upstreamReq.Header.Set("Accept", "text/event-stream") } else { diff --git a/backend/internal/service/openai_gateway_responses_to_cc.go b/backend/internal/service/openai_gateway_responses_to_cc.go new file mode 100644 index 00000000000..16130ff0884 --- /dev/null +++ b/backend/internal/service/openai_gateway_responses_to_cc.go @@ -0,0 +1,610 @@ +package service + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" + 
"github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + coderws "github.com/coder/websocket" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "go.uber.org/zap" +) + +type responsesAsChatCompletionsForwardRequest struct { + originalModel string + clientStream bool + reasoningEffort *string + serviceTier *string + billingModel string + upstreamModel string + ccBody []byte + targetURL string + apiKey string +} + +func (s *OpenAIGatewayService) prepareResponsesAsChatCompletionsRequest(account *Account, body []byte, forceStream bool) (*responsesAsChatCompletionsForwardRequest, error) { + var responsesReq apicompat.ResponsesRequest + if err := json.Unmarshal(body, &responsesReq); err != nil { + return nil, fmt.Errorf("parse responses request: %w", err) + } + originalModel := strings.TrimSpace(responsesReq.Model) + if originalModel == "" { + return nil, fmt.Errorf("model is required") + } + if forceStream { + responsesReq.Stream = true + } + clientStream := responsesReq.Stream + reasoningEffort := ExtractResponsesReasoningEffortFromBody(body) + var serviceTier *string + if responsesReq.ServiceTier != "" { + st := responsesReq.ServiceTier + serviceTier = &st + } + + ccReq, err := apicompat.ResponsesToChatCompletionsRequestWithOptions(&responsesReq, apicompat.ConvertResponsesOptions{ + StripReasoningEffort: account.StripReasoningEffortOnCC, + }) + if err != nil { + return nil, fmt.Errorf("convert responses to chat completions: %w", err) + } + + billingModel := resolveOpenAIForwardModel(account, originalModel, "") + upstreamModel := normalizeOpenAIModelForUpstream(account, billingModel) + ccReq.Model = upstreamModel + + ccBody, err := json.Marshal(ccReq) + if err != nil { + return nil, fmt.Errorf("marshal chat completions request: %w", err) + } + if clientStream { + ccBody, err = ensureOpenAIChatStreamUsage(ccBody) + if err != nil { + return nil, fmt.Errorf("enable stream usage: %w", err) + } + } + + apiKey 
:= account.GetUpstreamAPIKey() + if apiKey == "" { + return nil, fmt.Errorf("account %d missing api_key", account.ID) + } + targetURL := account.GetOpenAIChatCompletionsURL() + if targetURL == "" { + return nil, fmt.Errorf("account %d missing chat_completions_url", account.ID) + } + if _, err := s.validateUpstreamBaseURL(targetURL); err != nil { + return nil, fmt.Errorf("invalid chat_completions_url: %w", err) + } + + return &responsesAsChatCompletionsForwardRequest{ + originalModel: originalModel, + clientStream: clientStream, + reasoningEffort: reasoningEffort, + serviceTier: serviceTier, + billingModel: billingModel, + upstreamModel: upstreamModel, + ccBody: ccBody, + targetURL: targetURL, + apiKey: apiKey, + }, nil +} + +// ForwardResponsesAsChatCompletions accepts an OpenAI Responses API request +// and forwards it to a Chat-Completions-only upstream by: +// +// 1. Parsing the body as ResponsesRequest +// 2. Converting to ChatCompletionsRequest via apicompat.ResponsesToChatCompletionsRequest +// 3. POSTing the CC body to account.GetOpenAIChatCompletionsURL() +// 4. Converting the upstream CC response back to Responses format: +// - non-stream: ChatCompletionsResponse → ResponsesResponse JSON +// - stream: per-chunk SSE conversion via apicompat.ConvertChatCompletionsSSEChunkToResponsesEvents +// +// Adapter for `apikey-chat-completions` accounts hit through the /v1/responses +// or /responses entry. Mirrors the inverse of ForwardAsChatCompletions which +// handles CC requests targeting Responses upstreams. 
+func (s *OpenAIGatewayService) ForwardResponsesAsChatCompletions( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, +) (*OpenAIForwardResult, error) { + startTime := time.Now() + forwardReq, err := s.prepareResponsesAsChatCompletionsRequest(account, body, false) + if err != nil { + if c != nil { + writeResponsesError(c, http.StatusBadRequest, "invalid_request_error", err.Error()) + } + return nil, err + } + originalModel := forwardReq.originalModel + clientStream := forwardReq.clientStream + reasoningEffort := forwardReq.reasoningEffort + serviceTier := forwardReq.serviceTier + billingModel := forwardReq.billingModel + upstreamModel := forwardReq.upstreamModel + ccBody := forwardReq.ccBody + targetURL := forwardReq.targetURL + apiKey := forwardReq.apiKey + + logger.L().Debug("openai responses→cc: forwarding", + zap.Int64("account_id", account.ID), + zap.String("original_model", originalModel), + zap.String("billing_model", billingModel), + zap.String("upstream_model", upstreamModel), + zap.String("upstream_url", targetURL), + zap.Bool("stream", clientStream), + ) + + // 6. Build upstream request + upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx) + upstreamReq, err := http.NewRequestWithContext(upstreamCtx, http.MethodPost, targetURL, bytes.NewReader(ccBody)) + releaseUpstreamCtx() + if err != nil { + return nil, fmt.Errorf("build upstream request: %w", err) + } + upstreamReq.Header.Set("Content-Type", "application/json") + applyOpenAICompatibleAPIKeyAuth(upstreamReq, account, apiKey) + if clientStream { + upstreamReq.Header.Set("Accept", "text/event-stream") + } else { + upstreamReq.Header.Set("Accept", "application/json") + } + for key, values := range c.Request.Header { + if openaiCCRawAllowedHeaders[strings.ToLower(key)] { + for _, v := range values { + upstreamReq.Header.Add(key, v) + } + } + } + if customUA := account.GetOpenAIUserAgent(); customUA != "" { + upstreamReq.Header.Set("user-agent", customUA) + } + + // 7. 
Send + proxyURL := "" + if account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + writeResponsesError(c, http.StatusBadGateway, "server_error", "Upstream request failed") + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + defer func() { _ = resp.Body.Close() }() + + // 8. Error handling + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "upstream_error", + Message: upstreamMsg, + }) + if s.rateLimitService != nil { + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + } + writeResponsesError(c, mapUpstreamStatusCode(resp.StatusCode), "server_error", upstreamMsg) + return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg) + } + + // 9. 
Forward response + if clientStream { + return s.streamResponsesFromCC(c, resp, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime) + } + return s.bufferResponsesFromCC(c, resp, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime) +} + +// bufferResponsesFromCC reads the upstream non-streaming Chat Completions JSON +// response, converts it to a Responses API JSON response, and writes it to the +// client. +func (s *OpenAIGatewayService) bufferResponsesFromCC( + c *gin.Context, + resp *http.Response, + originalModel string, + billingModel string, + upstreamModel string, + reasoningEffort *string, + serviceTier *string, + startTime time.Time, +) (*OpenAIForwardResult, error) { + requestID := resp.Header.Get("x-request-id") + + respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError) + if err != nil { + if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + writeResponsesError(c, http.StatusBadGateway, "server_error", "Failed to read upstream response") + } + return nil, fmt.Errorf("read upstream body: %w", err) + } + + var ccResp apicompat.ChatCompletionsResponse + if err := json.Unmarshal(respBody, &ccResp); err != nil { + writeResponsesError(c, http.StatusBadGateway, "server_error", "Invalid upstream response") + return nil, fmt.Errorf("decode chat completions response: %w", err) + } + + var usage OpenAIUsage + if ccResp.Usage != nil { + usage = OpenAIUsage{ + InputTokens: ccResp.Usage.PromptTokens, + OutputTokens: ccResp.Usage.CompletionTokens, + } + if ccResp.Usage.PromptTokensDetails != nil { + usage.CacheReadInputTokens = ccResp.Usage.PromptTokensDetails.CachedTokens + } + } + + responsesResp := apicompat.ChatCompletionsToResponsesResponse(&ccResp, originalModel) + + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + out, err := json.Marshal(responsesResp) + if err != nil { + 
writeResponsesError(c, http.StatusBadGateway, "server_error", "Failed to marshal response") + return nil, fmt.Errorf("marshal responses response: %w", err) + } + c.Data(http.StatusOK, "application/json; charset=utf-8", out) + + return &OpenAIForwardResult{ + RequestID: requestID, + ResponseID: responsesResp.ID, + Usage: usage, + Model: originalModel, + BillingModel: billingModel, + UpstreamModel: upstreamModel, + ReasoningEffort: reasoningEffort, + ServiceTier: serviceTier, + Stream: false, + Duration: time.Since(startTime), + }, nil +} + +// streamResponsesFromCC reads upstream Chat Completions SSE chunks line-by-line, +// converts each to Responses API SSE events, and writes them to the client. +func (s *OpenAIGatewayService) streamResponsesFromCC( + c *gin.Context, + resp *http.Response, + originalModel string, + billingModel string, + upstreamModel string, + reasoningEffort *string, + serviceTier *string, + startTime time.Time, +) (*OpenAIForwardResult, error) { + requestID := resp.Header.Get("x-request-id") + + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.WriteHeader(http.StatusOK) + + state := apicompat.NewCCStreamState() + state.Model = originalModel + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + + var firstTokenMs *int + clientDisconnected := false + + flushEvents := func(events [][]byte) { + if clientDisconnected || len(events) == 0 { + return + } + for _, evt := range events { + if _, err := c.Writer.Write(evt); err != nil { + clientDisconnected = true + 
logger.L().Debug("openai responses→cc stream: client disconnected", + zap.Error(err), + zap.String("request_id", requestID), + ) + return + } + } + c.Writer.Flush() + } + + for scanner.Scan() { + line := scanner.Bytes() + // Skip blank separator lines. + trimmed := bytes.TrimSpace(line) + if len(trimmed) == 0 { + continue + } + events, err := apicompat.ConvertChatCompletionsSSEChunkToResponsesEvents(line, state) + if err != nil { + logger.L().Warn("openai responses→cc stream: failed to convert chunk", + zap.Error(err), + zap.String("request_id", requestID), + ) + } + if firstTokenMs == nil && len(events) > 0 { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + flushEvents(events) + } + + if err := scanner.Err(); err != nil { + if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { + logger.L().Warn("openai responses→cc stream: read error", + zap.Error(err), + zap.String("request_id", requestID), + ) + } + } + + // Ensure stream is finalised even if upstream stopped without [DONE]. 
+ if !state.CompletedSent { + flushEvents(apicompat.FinalizeCCStream(state)) + } + + usage := OpenAIUsage{} + if state.Usage != nil { + usage.InputTokens = state.Usage.InputTokens + usage.OutputTokens = state.Usage.OutputTokens + if state.Usage.InputTokensDetails != nil { + usage.CacheReadInputTokens = state.Usage.InputTokensDetails.CachedTokens + } + } + + return &OpenAIForwardResult{ + RequestID: requestID, + ResponseID: state.ResponseID, + Usage: usage, + Model: originalModel, + BillingModel: billingModel, + UpstreamModel: upstreamModel, + ReasoningEffort: reasoningEffort, + ServiceTier: serviceTier, + Stream: true, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + }, nil +} + +func (s *OpenAIGatewayService) ProxyResponsesWebSocketAsChatCompletions( + ctx context.Context, + c *gin.Context, + clientConn *coderws.Conn, + account *Account, + firstClientMessage []byte, + hooks *OpenAIWSIngressHooks, +) error { + if clientConn == nil { + return errors.New("client websocket is nil") + } + if account == nil { + return errors.New("account is nil") + } + + turn := 1 + nextMessage := firstClientMessage + for { + if hooks != nil && hooks.BeforeTurn != nil { + if err := hooks.BeforeTurn(turn); err != nil { + return err + } + } + + result, turnErr := s.proxyResponsesWebSocketTurnAsChatCompletions(ctx, c, clientConn, account, nextMessage) + if hooks != nil && hooks.AfterTurn != nil { + hooks.AfterTurn(turn, result, turnErr) + } + if turnErr != nil { + return turnErr + } + + msgType, message, err := clientConn.Read(ctx) + if err != nil { + if isOpenAIWSClientDisconnectError(err) { + return nil + } + return fmt.Errorf("read client websocket message: %w", err) + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + return NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "unsupported websocket message type", nil) + } + if !gjson.ValidBytes(message) { + return NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid 
JSON payload", nil) + } + nextMessage = message + turn++ + } +} + +func (s *OpenAIGatewayService) proxyResponsesWebSocketTurnAsChatCompletions( + ctx context.Context, + c *gin.Context, + clientConn *coderws.Conn, + account *Account, + body []byte, +) (*OpenAIForwardResult, error) { + startTime := time.Now() + forwardReq, err := s.prepareResponsesAsChatCompletionsRequest(account, body, true) + if err != nil { + return nil, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, err.Error(), err) + } + + upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx) + upstreamReq, err := http.NewRequestWithContext(upstreamCtx, http.MethodPost, forwardReq.targetURL, bytes.NewReader(forwardReq.ccBody)) + releaseUpstreamCtx() + if err != nil { + return nil, fmt.Errorf("build upstream request: %w", err) + } + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("Authorization", "Bearer "+forwardReq.apiKey) + upstreamReq.Header.Set("Accept", "text/event-stream") + for key, values := range c.Request.Header { + if openaiCCRawAllowedHeaders[strings.ToLower(key)] { + for _, v := range values { + upstreamReq.Header.Add(key, v) + } + } + } + if customUA := account.GetOpenAIUserAgent(); customUA != "" { + upstreamReq.Header.Set("user-agent", customUA) + } + + proxyURL := "" + if account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + return nil, fmt.Errorf("upstream request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + if upstreamMsg == "" { + upstreamMsg = http.StatusText(resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg) + } + + requestID := 
resp.Header.Get("x-request-id") + state := apicompat.NewCCStreamState() + state.Model = forwardReq.originalModel + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + + var firstTokenMs *int + for scanner.Scan() { + line := bytes.TrimSpace(scanner.Bytes()) + if len(line) == 0 { + continue + } + events, err := apicompat.ConvertChatCompletionsSSEChunkToResponsesEvents(line, state) + if err != nil { + logger.L().Warn("openai responses→cc websocket: failed to convert chunk", + zap.Error(err), + zap.String("request_id", requestID), + ) + } + if firstTokenMs == nil && len(events) > 0 { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + for _, eventFrame := range events { + payloads := extractResponsesSSEDataPayloads(eventFrame) + for _, payload := range payloads { + if len(payload) == 0 || bytes.Equal(bytes.TrimSpace(payload), []byte("[DONE]")) { + continue + } + writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout()) + err := clientConn.Write(writeCtx, coderws.MessageText, payload) + cancel() + if err != nil { + return nil, fmt.Errorf("write client websocket event: %w", err) + } + } + } + } + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("read upstream stream: %w", err) + } + if !state.CompletedSent { + for _, eventFrame := range apicompat.FinalizeCCStream(state) { + for _, payload := range extractResponsesSSEDataPayloads(eventFrame) { + if len(payload) == 0 || bytes.Equal(bytes.TrimSpace(payload), []byte("[DONE]")) { + continue + } + writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout()) + err := clientConn.Write(writeCtx, coderws.MessageText, payload) + cancel() + if err != nil { + return nil, fmt.Errorf("write client websocket final event: %w", err) + } + } + } + } + + usage := OpenAIUsage{} + if state.Usage != nil { 
+ usage.InputTokens = state.Usage.InputTokens + usage.OutputTokens = state.Usage.OutputTokens + if state.Usage.InputTokensDetails != nil { + usage.CacheReadInputTokens = state.Usage.InputTokensDetails.CachedTokens + } + } + + return &OpenAIForwardResult{ + RequestID: requestID, + ResponseID: state.ResponseID, + Usage: usage, + Model: forwardReq.originalModel, + BillingModel: forwardReq.billingModel, + UpstreamModel: forwardReq.upstreamModel, + ReasoningEffort: forwardReq.reasoningEffort, + ServiceTier: forwardReq.serviceTier, + Stream: true, + OpenAIWSMode: true, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + }, nil +} + +func extractResponsesSSEDataPayloads(frame []byte) [][]byte { + var payloads [][]byte + for _, rawLine := range bytes.Split(frame, []byte("\n")) { + line := bytes.TrimSpace(rawLine) + if !bytes.HasPrefix(line, []byte("data:")) { + continue + } + payload := bytes.TrimSpace(line[len("data:"):]) + if len(payload) == 0 { + continue + } + payloadCopy := append([]byte(nil), payload...) 
+ payloads = append(payloads, payloadCopy) + } + return payloads +} diff --git a/backend/internal/service/openai_gateway_responses_to_cc_test.go b/backend/internal/service/openai_gateway_responses_to_cc_test.go new file mode 100644 index 00000000000..73f2b36754f --- /dev/null +++ b/backend/internal/service/openai_gateway_responses_to_cc_test.go @@ -0,0 +1,221 @@ +//go:build unit + +package service + +import ( + "bytes" + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func responsesToCCTestAccount(chatCompletionsURL string) *Account { + return &Account{ + ID: 202, + Name: "responses-to-cc", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKeyChatCompletions, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test-rcc", + "chat_completions_url": chatCompletionsURL, + }, + } +} + +func responsesToCCTestConfig() *config.Config { + return &config.Config{ + Security: config.SecurityConfig{ + URLAllowlist: config.URLAllowlistConfig{ + Enabled: false, + AllowInsecureHTTP: true, + }, + }, + } +} + +func TestForward_APIKeyChatCompletionsResponsesUsesAdapter(t *testing.T) { + gin.SetMode(gin.TestMode) + + upstreamResp := `{ + "id":"chatcmpl_rcc_forward", + "object":"chat.completion", + "created":1, + "model":"gpt-5.4", + "choices":[{"index":0,"message":{"role":"assistant","content":"hello from cc"},"finish_reason":"stop"}], + "usage":{"prompt_tokens":6,"completion_tokens":3,"total_tokens":9} + }` + + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid_rcc_forward"}}, + Body: io.NopCloser(strings.NewReader(upstreamResp)), + }} + + body := []byte(`{"model":"gpt-5.4","input":"hi","stream":false}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = 
httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + svc := &OpenAIGatewayService{ + cfg: responsesToCCTestConfig(), + httpUpstream: upstream, + } + account := responsesToCCTestAccount("http://upstream.example/v1/chat/completions") + + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.Stream) + require.Equal(t, "gpt-5.4", result.Model) + require.Equal(t, 6, result.Usage.InputTokens) + require.Equal(t, 3, result.Usage.OutputTokens) + + require.NotNil(t, upstream.lastReq) + require.Equal(t, "http://upstream.example/v1/chat/completions", upstream.lastReq.URL.String()) + require.Equal(t, "Bearer sk-test-rcc", upstream.lastReq.Header.Get("Authorization")) + require.Contains(t, string(upstream.lastBody), `"messages"`) + require.NotContains(t, string(upstream.lastBody), `"input":"hi"`) + + clientBody := rec.Body.String() + require.Contains(t, clientBody, `"object":"response"`) + require.Contains(t, clientBody, "hello from cc") +} + +func TestForwardResponsesAsChatCompletions_NonStreamingTextResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + + upstreamResp := `{ + "id":"chatcmpl_rcc_1", + "object":"chat.completion", + "created":1, + "model":"gpt-4o", + "choices":[{"index":0,"message":{"role":"assistant","content":"hello world"},"finish_reason":"stop"}], + "usage":{"prompt_tokens":7,"completion_tokens":2,"total_tokens":9,"prompt_tokens_details":{"cached_tokens":1}} + }` + + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid_rcc_buffer"}}, + Body: io.NopCloser(strings.NewReader(upstreamResp)), + }} + + body := []byte(`{"model":"gpt-4o","input":"hi","stream":false}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = 
httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + svc := &OpenAIGatewayService{ + cfg: responsesToCCTestConfig(), + httpUpstream: upstream, + } + account := responsesToCCTestAccount("http://upstream.example/v1/chat/completions") + + result, err := svc.ForwardResponsesAsChatCompletions(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.Stream) + require.Equal(t, "gpt-4o", result.Model) + require.Equal(t, 7, result.Usage.InputTokens) + require.Equal(t, 2, result.Usage.OutputTokens) + require.Equal(t, 1, result.Usage.CacheReadInputTokens) + + // Upstream request was sent to the configured URL with Bearer + JSON headers. + require.NotNil(t, upstream.lastReq) + require.Equal(t, "http://upstream.example/v1/chat/completions", upstream.lastReq.URL.String()) + require.Equal(t, "application/json", upstream.lastReq.Header.Get("Content-Type")) + require.Equal(t, "Bearer sk-test-rcc", upstream.lastReq.Header.Get("Authorization")) + require.Equal(t, "application/json", upstream.lastReq.Header.Get("Accept")) + + // Upstream body must be a Chat Completions JSON (not a Responses payload). + require.Contains(t, string(upstream.lastBody), `"messages"`) + require.NotContains(t, string(upstream.lastBody), `"input"`) + + // Response body sent to client is Responses-shaped JSON containing the assistant text. + clientBody := rec.Body.String() + require.Contains(t, clientBody, `"object":"response"`) + require.Contains(t, clientBody, "hello world") + require.Contains(t, clientBody, `"output_text"`) +} + +func TestForwardResponsesAsChatCompletions_StreamingSSEConversion(t *testing.T) { + gin.SetMode(gin.TestMode) + + // Mock upstream emits Chat Completions SSE chunks; service must convert each + // to Responses SSE events and stream them back to the client. 
+ upstreamSSE := strings.Join([]string{ + `data: {"id":"chatcmpl_stream","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}`, + "", + `data: {"id":"chatcmpl_stream","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"Hi"},"finish_reason":null}]}`, + "", + `data: {"id":"chatcmpl_stream","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":4,"completion_tokens":1,"total_tokens":5}}`, + "", + "data: [DONE]", + "", + }, "\n") + + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_rcc_stream"}}, + Body: io.NopCloser(strings.NewReader(upstreamSSE)), + }} + + body := []byte(`{"model":"gpt-4o","input":"hi","stream":true}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + svc := &OpenAIGatewayService{ + cfg: responsesToCCTestConfig(), + httpUpstream: upstream, + } + account := responsesToCCTestAccount("http://upstream.example/v1/chat/completions") + + result, err := svc.ForwardResponsesAsChatCompletions(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.Stream) + require.Equal(t, 4, result.Usage.InputTokens) + require.Equal(t, 1, result.Usage.OutputTokens) + + // Upstream Accept must be SSE for stream requests. + require.Equal(t, "text/event-stream", upstream.lastReq.Header.Get("Accept")) + + // Client received Responses SSE events (not raw CC chunks) — should at minimum + // contain a response.created and response.completed event plus terminator. 
+ clientBody := rec.Body.String() + require.Contains(t, clientBody, "event: response.created") + require.Contains(t, clientBody, "event: response.output_text.delta") + require.Contains(t, clientBody, "event: response.completed") + require.Contains(t, clientBody, "data: [DONE]") +} + +func TestForwardResponsesAsChatCompletions_MissingModelReturnsBadRequest(t *testing.T) { + gin.SetMode(gin.TestMode) + + body := []byte(`{"input":"hi"}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + svc := &OpenAIGatewayService{ + cfg: responsesToCCTestConfig(), + httpUpstream: &httpUpstreamRecorder{}, + } + account := responsesToCCTestAccount("http://upstream.example/v1/chat/completions") + + _, err := svc.ForwardResponsesAsChatCompletions(context.Background(), c, account, body) + require.Error(t, err) + require.Equal(t, http.StatusBadRequest, rec.Code) +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 89ddaa7d1c5..b95ac66ee43 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -2363,6 +2363,10 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco reqModel, reqStream, promptCacheKey := requestView.Model, requestView.Stream, requestView.PromptCacheKey originalModel := reqModel + if account.IsOpenAIChatCompletionsUpstream() { + return s.ForwardResponsesAsChatCompletions(ctx, c, account, body) + } + if account.Type == AccountTypeAPIKey && !openai_compat.ShouldUseResponsesAPI(account.Extra) { return s.forwardResponsesViaRawChatCompletions(ctx, c, account, body) } diff --git a/backend/migrations/020_widen_accounts_type.sql b/backend/migrations/020_widen_accounts_type.sql new file mode 100644 index 00000000000..917de65f8a5 --- /dev/null +++ 
b/backend/migrations/020_widen_accounts_type.sql @@ -0,0 +1,2 @@ +-- Widen accounts.type column to support longer type names (e.g. "apikey-chat-completions" = 23 chars) +ALTER TABLE accounts ALTER COLUMN type TYPE VARCHAR(40); diff --git a/backend/migrations/021_add_accounts_strip_reasoning_effort.sql b/backend/migrations/021_add_accounts_strip_reasoning_effort.sql new file mode 100644 index 00000000000..bad4e902f71 --- /dev/null +++ b/backend/migrations/021_add_accounts_strip_reasoning_effort.sql @@ -0,0 +1,2 @@ +ALTER TABLE accounts + ADD COLUMN IF NOT EXISTS strip_reasoning_effort_on_cc BOOLEAN NOT NULL DEFAULT FALSE; diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index d1a7729a58d..45c8f28f81d 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -174,7 +174,7 @@ > -
+
{{ t('admin.accounts.claudeCode') }} @@ -204,7 +204,7 @@ >
-
+
{{ t('admin.accounts.claudeConsole') }} @@ -234,7 +234,7 @@ >
-
+
{{ t('admin.accounts.bedrockLabel') }} @@ -264,12 +264,38 @@ >
-
+
Vertex Service Account
+ +

{{ t('admin.accounts.vertexAnthropicHint') }}

+
+

{{ t('admin.accounts.types.apikeyChatCompletionsAnthropicHint') }}

+
-
+
+ +
@@ -1008,6 +1066,148 @@
+ +
+
+ + +

{{ t('admin.accounts.types.chatCompletionsUrlHint') }}

+
+
+ + +

{{ t('admin.accounts.types.chatCompletionsApiKeyHint') }}

+
+
+ + +

{{ t('admin.accounts.types.chatCompletionsAuthHeaderHint') }}

+
+ + +
+ + + +
+ + +
+ + +
+ +

+ {{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }} + {{ t('admin.accounts.supportsAllModels') }} +

+
+ + +
+
+

+ + + + {{ t('admin.accounts.mapRequestModels') }} +

+
+ +
+
+ + + + + + +
+
+ + + +
+ +
+
+
+
+
@@ -2780,6 +2980,34 @@
+
+
+
+ +

+ {{ t('admin.accounts.stripReasoningEffortOnCCDesc') }} +

+
+ +
+
+
@@ -3357,7 +3585,12 @@ interface TempUnschedRuleForm { // State const step = ref(1) const submitting = ref(false) -const accountCategory = ref<'oauth-based' | 'apikey' | 'bedrock' | 'service_account'>('oauth-based') // UI selection for account category +type ChatCompletionsAuthHeader = 'Authorization' | 'api-key' | 'x-api-key' + +const accountCategory = ref<'oauth-based' | 'apikey' | 'apikey-chat-completions' | 'bedrock' | 'service_account'>('oauth-based') // UI selection for account category +const chatCompletionsUrl = ref('') +const chatCompletionsApiKey = ref('') +const chatCompletionsAuthHeader = ref('Authorization') const addMethod = ref('oauth') // For oauth-based: 'oauth' or 'setup-token' const apiKeyBaseUrl = ref('https://api.anthropic.com') const apiKeyValue = ref('') @@ -3413,6 +3646,7 @@ const selectedErrorCodes = ref([]) const customErrorCodeInput = ref(null) const interceptWarmupRequests = ref(false) const autoPauseOnExpired = ref(true) +const stripReasoningEffortOnCC = ref(false) const openaiPassthroughEnabled = ref(false) const openAICompactMode = ref('auto') const openAIResponsesMode = ref('auto') @@ -3781,6 +4015,8 @@ watch( form.type = 'service_account' as AccountType } else if (category === 'oauth-based') { form.type = method as AccountType // 'oauth' or 'setup-token' + } else if (category === 'apikey-chat-completions') { + form.type = 'apikey-chat-completions' as AccountType } else { form.type = 'apikey' } @@ -3823,6 +4059,9 @@ watch( if (newPlatform !== 'anthropic' && accountCategory.value === 'bedrock') { accountCategory.value = 'oauth-based' } + if (newPlatform !== 'openai' && newPlatform !== 'anthropic' && accountCategory.value === 'apikey-chat-completions') { + accountCategory.value = 'oauth-based' + } // Reset Bedrock fields when switching platforms bedrockAccessKeyId.value = '' bedrockSecretAccessKey.value = '' @@ -4213,6 +4452,9 @@ const resetForm = () => { addMethod.value = 'oauth' apiKeyBaseUrl.value = 'https://api.anthropic.com' 
apiKeyValue.value = '' + chatCompletionsUrl.value = '' + chatCompletionsApiKey.value = '' + chatCompletionsAuthHeader.value = 'Authorization' editQuotaLimit.value = null editQuotaDailyLimit.value = null editQuotaWeeklyLimit.value = null @@ -4240,6 +4482,7 @@ const resetForm = () => { customErrorCodeInput.value = null interceptWarmupRequests.value = false autoPauseOnExpired.value = true + stripReasoningEffortOnCC.value = false openaiPassthroughEnabled.value = false openAICompactMode.value = 'auto' openAIResponsesMode.value = 'auto' @@ -4605,6 +4848,27 @@ const handleSubmit = async () => { return } + // For apikey-chat-completions type, create directly + if (form.type === 'apikey-chat-completions') { + if (!chatCompletionsUrl.value.trim()) { + appStore.showError(t('admin.accounts.types.chatCompletionsUrlPlaceholder')) + return + } + if (!chatCompletionsApiKey.value.trim()) { + appStore.showError(t('admin.accounts.pleaseEnterApiKey')) + return + } + const credentials: Record = { + chat_completions_url: chatCompletionsUrl.value.trim(), + api_key: chatCompletionsApiKey.value.trim(), + auth_header: chatCompletionsAuthHeader.value + } + const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value) + if (modelMapping) credentials.model_mapping = modelMapping + await createAccountAndFinish(form.platform, 'apikey-chat-completions' as AccountType, credentials) + return + } + // For apikey type, create directly if (!apiKeyValue.value.trim()) { appStore.showError(t('admin.accounts.pleaseEnterApiKey')) @@ -4783,7 +5047,8 @@ const createAccountAndFinish = async ( rate_multiplier: form.rate_multiplier, group_ids: form.group_ids, expires_at: form.expires_at, - auto_pause_on_expired: autoPauseOnExpired.value + auto_pause_on_expired: autoPauseOnExpired.value, + ...(type === 'apikey-chat-completions' ? 
{ strip_reasoning_effort_on_cc: stripReasoningEffortOnCC.value } : {}) }) } diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 8dc85d0eb84..a2e55b492e6 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -579,6 +579,137 @@
+ +
+
+ + +

{{ t('admin.accounts.types.chatCompletionsUrlHint') }}

+
+
+ + +

{{ t('admin.accounts.leaveEmptyToKeep') }}

+
+
+ + +

{{ t('admin.accounts.types.chatCompletionsAuthHeaderHint') }}

+
+ + +
+ + + +
+ + +
+ + +
+ +

+ {{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }} + {{ t('admin.accounts.supportsAllModels') }} +

+
+ + +
+
+

+ {{ t('admin.accounts.mapRequestModels') }} +

+
+ +
+
+ + + + + + +
+
+ + + +
+ +
+
+
+
+
@@ -1787,6 +1918,34 @@
+
+
+
+ +

+ {{ t('admin.accounts.stripReasoningEffortOnCCDesc') }} +

+
+ +
+
+

{{ t('admin.accounts.autoPauseThresholdHint') }}

+
@@ -2460,10 +2620,19 @@ interface TempUnschedRuleForm { description: string } +type ChatCompletionsAuthHeader = 'Authorization' | 'api-key' | 'x-api-key' + +const normalizeChatCompletionsAuthHeader = (value: unknown): ChatCompletionsAuthHeader => { + if (value === 'api-key' || value === 'x-api-key' || value === 'Authorization') return value + return 'Authorization' +} + // State const submitting = ref(false) const editBaseUrl = ref('https://api.anthropic.com') const editApiKey = ref('') +const editChatCompletionsUrl = ref('') +const editChatCompletionsAuthHeader = ref('Authorization') // Bedrock credentials const editBedrockAccessKeyId = ref('') const editBedrockSecretAccessKey = ref('') @@ -2525,10 +2694,12 @@ const selectedErrorCodes = ref([]) const customErrorCodeInput = ref(null) const interceptWarmupRequests = ref(false) const autoPauseOnExpired = ref(false) +const stripReasoningEffortOnCC = ref(false) const autoPause5hThreshold = ref(null) const autoPause7dThreshold = ref(null) const autoPause5hDisabled = ref(false) const autoPause7dDisabled = ref(false) + const mixedScheduling = ref(false) // For antigravity accounts: enable mixed scheduling const allowOverages = ref(false) // For antigravity accounts: enable AI Credits overages const antigravityModelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist') @@ -2937,6 +3108,7 @@ const syncFormFromAccount = (newAccount: Account | null) => { const credentials = newAccount.credentials as Record | undefined interceptWarmupRequests.value = credentials?.intercept_warmup_requests === true autoPauseOnExpired.value = newAccount.auto_pause_on_expired === true + stripReasoningEffortOnCC.value = newAccount.strip_reasoning_effort_on_cc === true editVertexProjectId.value = '' editVertexClientEmail.value = '' editVertexLocation.value = 'us-central1' @@ -3118,6 +3290,31 @@ const syncFormFromAccount = (newAccount: Account | null) => { } else { selectedErrorCodes.value = [] } + } else if (newAccount.type === 
'apikey-chat-completions' && newAccount.credentials) { + const ccCreds = newAccount.credentials as Record<string, unknown> + editChatCompletionsUrl.value = (ccCreds.chat_completions_url as string) || '' + editChatCompletionsAuthHeader.value = normalizeChatCompletionsAuthHeader(ccCreds.auth_header) + editApiKey.value = '' + + // Load model mappings and detect mode + const existingMappings = ccCreds.model_mapping as Record<string, string> | undefined + if (existingMappings && typeof existingMappings === 'object') { + const entries = Object.entries(existingMappings) + const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to) + if (isWhitelistMode) { + modelRestrictionMode.value = 'whitelist' + allowedModels.value = entries.map(([from]) => from) + modelMappings.value = [] + } else { + modelRestrictionMode.value = 'mapping' + modelMappings.value = entries.map(([from, to]) => ({ from, to })) + allowedModels.value = [] + } + } else { + modelRestrictionMode.value = 'whitelist' + modelMappings.value = [] + allowedModels.value = [] + } } else if (newAccount.type === 'bedrock' && newAccount.credentials) { const bedrockCreds = newAccount.credentials as Record const authMode = (bedrockCreds.auth_mode as string) || 'sigv4' @@ -3680,6 +3877,9 @@ const handleSubmit = async () => { updatePayload.load_factor = 0 } updatePayload.auto_pause_on_expired = autoPauseOnExpired.value + if (props.account?.type === 'apikey-chat-completions') { + updatePayload.strip_reasoning_effort_on_cc = stripReasoningEffortOnCC.value + } // For apikey type, handle credentials update if (props.account.type === 'apikey') { @@ -3759,6 +3959,33 @@ const handleSubmit = async () => { return } + updatePayload.credentials = newCredentials + } else if (props.account.type === 'apikey-chat-completions') { + const currentCredentials = (props.account.credentials as Record<string, unknown>) || {} + const newCredentials: Record<string, unknown> = { ...currentCredentials } + + if (!editChatCompletionsUrl.value.trim()) { + 
appStore.showError(t('admin.accounts.types.chatCompletionsUrlPlaceholder')) + return + } + newCredentials.chat_completions_url = editChatCompletionsUrl.value.trim() + newCredentials.auth_header = editChatCompletionsAuthHeader.value + + if (editApiKey.value.trim()) { + newCredentials.api_key = editApiKey.value.trim() + } + + const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value) + if (modelMapping) { + newCredentials.model_mapping = modelMapping + } else { + delete newCredentials.model_mapping + } + + if (!applyTempUnschedConfig(newCredentials)) { + return + } + updatePayload.credentials = newCredentials } else if (props.account.type === 'upstream') { const currentCredentials = (props.account.credentials as Record) || {} diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index b4e231e73f1..cb44aa0d045 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -3061,7 +3061,22 @@ export default { antigravityOauth: 'Antigravity OAuth', antigravityApikey: 'Connect via Base URL + API Key', upstream: 'Upstream', - upstreamDesc: 'Connect via Base URL + API Key' + upstreamDesc: 'Connect via Base URL + API Key', + apikeyChatCompletions: 'OpenAI Chat Completions Upstream', + apikeyChatCompletionsDesc: + 'Connect to any OpenAI-compatible /v1/chat/completions endpoint via full URL + API Key', + apikeyChatCompletionsAnthropicHint: + 'On the Anthropic platform, the gateway automatically converts /v1/messages requests to OpenAI Chat Completions and translates the upstream response back. The /v1/chat/completions endpoint passes through unchanged.', + chatCompletionsUrl: 'Chat Completions URL', + chatCompletionsUrlPlaceholder: + 'Full endpoint URL, e.g. 
https://api.deepseek.com/v1/chat/completions', + chatCompletionsUrlHint: 'Full /v1/chat/completions endpoint URL of the upstream service', + chatCompletionsApiKey: 'API Key', + chatCompletionsApiKeyPlaceholder: 'API Key for the upstream service', + chatCompletionsApiKeyHint: 'API Key used to authenticate against the upstream', + chatCompletionsAuthHeader: 'Auth Header', + chatCompletionsAuthHeaderHint: + 'Select the API Key header required by the upstream. Most OpenAI-compatible APIs use Authorization: Bearer; some use api-key or x-api-key.' }, status: { active: 'Active', @@ -3476,12 +3491,15 @@ export default { 'When enabled, warmup requests like title generation will return mock responses without consuming upstream tokens', autoPauseOnExpired: 'Auto Pause On Expired', autoPauseOnExpiredDesc: 'When enabled, the account will auto pause scheduling after it expires', - autoPause5hThreshold: '5h Usage Threshold (%)', - autoPause7dThreshold: '7d Usage Threshold (%)', - autoPauseThresholdHint: 'Leave empty or set 0 to use the global default threshold (configured in Ops settings); set a value to override the global default. Reaching the threshold only skips the account during scheduling and does not modify schedulable.', - autoPause5hDisabled: 'Disable 5h auto-pause', - autoPause7dDisabled: 'Disable 7d auto-pause', - autoPauseDisabledHint: 'When enabled, this account is never auto-paused (even if a global default threshold is configured).', +stripReasoningEffortOnCC: 'Strip reasoning_effort (Chat Completions)', + stripReasoningEffortOnCCDesc: 'Drop reasoning_effort when forwarding /responses traffic as /v1/chat/completions. Enable for upstreams that reject this combination, e.g. b.ai.', + autoPause5hThreshold: '5h Usage Threshold (%)', + autoPause7dThreshold: '7d Usage Threshold (%)', + autoPauseThresholdHint: 'Leave empty or set 0 to use the global default threshold (configured in Ops settings); set a value to override the global default. 
Reaching the threshold only skips the account during scheduling and does not modify schedulable.', + autoPause5hDisabled: 'Disable 5h auto-pause', + autoPause7dDisabled: 'Disable 7d auto-pause', + autoPauseDisabledHint: 'When enabled, this account is never auto-paused (even if a global default threshold is configured).', + // Quota control (Anthropic OAuth/SetupToken only) quotaControl: { title: 'Quota Control', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 70a7f6dfcbb..d5c3fe5ec9e 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -3249,6 +3249,21 @@ export default { antigravityApikey: '通过 Base URL + API Key 连接', upstream: '对接上游', upstreamDesc: '通过 Base URL + API Key 连接上游', + apikeyChatCompletions: 'OpenAI Chat Completions 上游', + apikeyChatCompletionsDesc: + '通过完整 URL + API Key 对接任意 OpenAI 兼容的 /v1/chat/completions 端点', + apikeyChatCompletionsAnthropicHint: + 'Anthropic 平台下,网关会自动把 /v1/messages 请求转换为 OpenAI Chat Completions 发给上游,再将响应翻译回 Messages 协议;/v1/chat/completions 走原样透传。', + chatCompletionsUrl: 'Chat Completions URL', + chatCompletionsUrlPlaceholder: + '完整端点 URL,例如 https://api.deepseek.com/v1/chat/completions', + chatCompletionsUrlHint: '上游服务的 /v1/chat/completions 完整端点 URL', + chatCompletionsApiKey: 'API Key', + chatCompletionsApiKeyPlaceholder: '上游服务的 API Key', + chatCompletionsApiKeyHint: '用于鉴权上游服务的 API Key', + chatCompletionsAuthHeader: '鉴权 Header', + chatCompletionsAuthHeaderHint: + '选择上游接口要求的 API Key Header。多数 OpenAI 兼容接口使用 Authorization: Bearer;部分接口使用 api-key 或 x-api-key。', api_key: 'API Key', cookie: 'Cookie' }, @@ -3614,12 +3629,15 @@ export default { interceptWarmupRequestsDesc: '启用后,标题生成等预热请求将返回 mock 响应,不消耗上游 token', autoPauseOnExpired: '过期自动暂停调度', autoPauseOnExpiredDesc: '启用后,账号过期将自动暂停调度', - autoPause5hThreshold: '5h 用量阈值(%)', - autoPause7dThreshold: '7d 用量阈值(%)', - autoPauseThresholdHint: '留空或填 0 表示使用全局默认阈值(在运维设置中配置);填具体值则覆盖全局默认。达到阈值后仅在调度时跳过账号,不修改 schedulable。', - 
autoPause5hDisabled: '禁用 5h 自动暂停', - autoPause7dDisabled: '禁用 7d 自动暂停', - autoPauseDisabledHint: '开启后该账号永不进入自动暂停(即使全局默认阈值已配置)。', +stripReasoningEffortOnCC: '剥离 reasoning_effort(Chat Completions)', + stripReasoningEffortOnCCDesc: '当 /responses 请求被转换为 /v1/chat/completions 转发时丢弃 reasoning_effort。适用于拒绝该组合的上游,例如 b.ai。', + autoPause5hThreshold: '5h 用量阈值(%)', + autoPause7dThreshold: '7d 用量阈值(%)', + autoPauseThresholdHint: '留空或填 0 表示使用全局默认阈值(在运维设置中配置);填具体值则覆盖全局默认。达到阈值后仅在调度时跳过账号,不修改 schedulable。', + autoPause5hDisabled: '禁用 5h 自动暂停', + autoPause7dDisabled: '禁用 7d 自动暂停', + autoPauseDisabledHint: '开启后该账号永不进入自动暂停(即使全局默认阈值已配置)。', + // Quota control (Anthropic OAuth/SetupToken only) quotaControl: { title: '配额控制', diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 718c35cb27e..b4cf9a46d3d 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -689,7 +689,14 @@ export interface UpdateGroupRequest { // ==================== Account & Proxy Types ==================== export type AccountPlatform = 'anthropic' | 'openai' | 'gemini' | 'antigravity' -export type AccountType = 'oauth' | 'setup-token' | 'apikey' | 'upstream' | 'bedrock' | 'service_account' +export type AccountType = + | 'oauth' + | 'setup-token' + | 'apikey' + | 'apikey-chat-completions' + | 'upstream' + | 'bedrock' + | 'service_account' export type OAuthAddMethod = 'oauth' | 'setup-token' export type ProxyProtocol = 'http' | 'https' | 'socks5' | 'socks5h' @@ -839,6 +846,7 @@ export interface Account { last_used_at: string | null expires_at: number | null auto_pause_on_expired: boolean + strip_reasoning_effort_on_cc: boolean created_at: string updated_at: string proxy?: Proxy @@ -1028,6 +1036,7 @@ export interface CreateAccountRequest { group_ids?: number[] expires_at?: number | null auto_pause_on_expired?: boolean + strip_reasoning_effort_on_cc?: boolean confirm_mixed_channel_risk?: boolean } @@ -1047,6 +1056,7 @@ export interface UpdateAccountRequest { group_ids?: number[] 
expires_at?: number | null auto_pause_on_expired?: boolean + strip_reasoning_effort_on_cc?: boolean confirm_mixed_channel_risk?: boolean } @@ -1121,6 +1131,7 @@ export interface AdminDataAccount { rate_multiplier?: number | null expires_at?: number | null auto_pause_on_expired?: boolean + strip_reasoning_effort_on_cc?: boolean } export interface AdminDataImportError {