Skip to content

Commit 9a53caf

Browse files
Add mock implementations for testing neuron components
- Introduced `MockRoundTripper`, `MockCache`, `MockAuthProvider`, `MockRateLimiter`, and `MockValidator` to facilitate unit testing without real HTTP requests. - Implemented comprehensive test cases for each mock, ensuring functionality such as request recording, error injection, and response queuing. - Enhanced thread safety across mock implementations to support concurrent testing scenarios. - Added utility functions for assertions and error handling to streamline test writing and improve maintainability.
1 parent 934057b commit 9a53caf

15 files changed

Lines changed: 4782 additions & 0 deletions

mock/assertions.go

Lines changed: 447 additions & 0 deletions
Large diffs are not rendered by default.

mock/assertions_test.go

Lines changed: 435 additions & 0 deletions
Large diffs are not rendered by default.

mock/auth.go

Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
package mock
2+
3+
import (
4+
"context"
5+
"sync"
6+
"sync/atomic"
7+
"time"
8+
)
9+
10+
// MockAuthProvider is a mock implementation of neuron.AuthProvider for testing.
11+
type MockAuthProvider struct {
12+
tokens []string
13+
tokenIndex atomic.Int32
14+
getCalls []GetTokenCall
15+
headerCalls []GetHeaderCall
16+
tokenErr atomic.Value // *errHolder
17+
headerFormat string
18+
recordCalls atomic.Bool
19+
mu sync.RWMutex
20+
}
21+
22+
// GetTokenCall records a call to GetToken.
23+
type GetTokenCall struct {
24+
Time time.Time
25+
Result string
26+
Err error
27+
}
28+
29+
// GetHeaderCall records a call to GetAuthHeader.
30+
type GetHeaderCall struct {
31+
Token string
32+
Result string
33+
Time time.Time
34+
}
35+
36+
// MockAuthProviderOptions configures mock auth provider behavior.
37+
type MockAuthProviderOptions struct {
38+
InitialToken string
39+
RecordCalls bool
40+
Tokens []string
41+
HeaderFormat string // e.g., "Bearer {}" or "X-API-Key {}"
42+
}
43+
44+
// NewMockAuthProvider creates a new mock auth provider.
45+
func NewMockAuthProvider(opts *MockAuthProviderOptions) *MockAuthProvider {
46+
ap := &MockAuthProvider{
47+
tokens: make([]string, 0),
48+
getCalls: make([]GetTokenCall, 0),
49+
headerCalls: make([]GetHeaderCall, 0),
50+
headerFormat: "Bearer {}",
51+
}
52+
53+
if opts == nil {
54+
ap.recordCalls.Store(true)
55+
return ap
56+
}
57+
58+
ap.recordCalls.Store(opts.RecordCalls)
59+
60+
if opts.InitialToken != "" {
61+
ap.tokens = append(ap.tokens, opts.InitialToken)
62+
}
63+
64+
if len(opts.Tokens) > 0 {
65+
ap.tokens = append(ap.tokens, opts.Tokens...)
66+
}
67+
68+
if opts.HeaderFormat != "" {
69+
ap.headerFormat = opts.HeaderFormat
70+
}
71+
72+
return ap
73+
}
74+
75+
// GetToken returns the current token, or the next token in the sequence if configured.
76+
func (ap *MockAuthProvider) GetToken(ctx context.Context) (string, error) {
77+
select {
78+
case <-ctx.Done():
79+
return "", ctx.Err()
80+
default:
81+
}
82+
83+
// Check for injected error
84+
if err := ap.getAndConsumeError(); err != nil {
85+
if ap.recordCalls.Load() {
86+
ap.mu.Lock()
87+
ap.getCalls = append(ap.getCalls, GetTokenCall{
88+
Time: time.Now(),
89+
Result: "",
90+
Err: err,
91+
})
92+
ap.mu.Unlock()
93+
}
94+
return "", err
95+
}
96+
97+
token := ap.getCurrentToken()
98+
99+
if ap.recordCalls.Load() {
100+
ap.mu.Lock()
101+
ap.getCalls = append(ap.getCalls, GetTokenCall{
102+
Time: time.Now(),
103+
Result: token,
104+
Err: nil,
105+
})
106+
ap.mu.Unlock()
107+
}
108+
109+
return token, nil
110+
}
111+
112+
// getCurrentToken returns the current or next token in the sequence.
113+
func (ap *MockAuthProvider) getCurrentToken() string {
114+
ap.mu.RLock()
115+
defer ap.mu.RUnlock()
116+
117+
if len(ap.tokens) == 0 {
118+
return ""
119+
}
120+
121+
idx := ap.tokenIndex.Load()
122+
if int(idx) >= len(ap.tokens) {
123+
return ap.tokens[len(ap.tokens)-1]
124+
}
125+
126+
return ap.tokens[idx]
127+
}
128+
129+
// GetAuthHeader returns the formatted authentication header value.
130+
func (ap *MockAuthProvider) GetAuthHeader(token string) string {
131+
header := formatHeader(ap.headerFormat, token)
132+
133+
if ap.recordCalls.Load() {
134+
ap.mu.Lock()
135+
ap.headerCalls = append(ap.headerCalls, GetHeaderCall{
136+
Token: token,
137+
Result: header,
138+
Time: time.Now(),
139+
})
140+
ap.mu.Unlock()
141+
}
142+
143+
return header
144+
}
145+
146+
// formatHeader substitutes {} with the token in the format string.
147+
func formatHeader(format, token string) string {
148+
for i := 0; i < len(format)-1; i++ {
149+
if format[i] == '{' && format[i+1] == '}' {
150+
return format[:i] + token + format[i+2:]
151+
}
152+
}
153+
return format + " " + token
154+
}
155+
156+
// getAndConsumeError retrieves and clears an injected error (one-shot).
157+
func (ap *MockAuthProvider) getAndConsumeError() error {
158+
val := ap.tokenErr.Load()
159+
if val == nil {
160+
return nil
161+
}
162+
163+
holder := val.(*errHolder)
164+
if holder == nil || holder.err == nil {
165+
return nil
166+
}
167+
168+
err := holder.err
169+
if holder.oneShot {
170+
ap.tokenErr.Store((*errHolder)(nil))
171+
}
172+
return err
173+
}
174+
175+
// InjectTokenError injects an error for the next GetToken call (one-shot by default).
176+
func (ap *MockAuthProvider) InjectTokenError(err error) {
177+
if err == nil {
178+
ap.tokenErr.Store((*errHolder)(nil))
179+
} else {
180+
ap.tokenErr.Store(&errHolder{err: err, oneShot: true})
181+
}
182+
}
183+
184+
// ClearInjectedErrors clears all injected errors.
185+
func (ap *MockAuthProvider) ClearInjectedErrors() {
186+
ap.tokenErr.Store((*errHolder)(nil))
187+
}
188+
189+
// SetTokens sets the token sequence for rotation testing.
190+
func (ap *MockAuthProvider) SetTokens(tokens []string) {
191+
ap.mu.Lock()
192+
ap.tokens = make([]string, len(tokens))
193+
copy(ap.tokens, tokens)
194+
ap.mu.Unlock()
195+
ap.tokenIndex.Store(0)
196+
}
197+
198+
// RotateToken advances to the next token in the sequence.
199+
func (ap *MockAuthProvider) RotateToken() {
200+
ap.mu.RLock()
201+
tokenCount := len(ap.tokens)
202+
ap.mu.RUnlock()
203+
204+
if tokenCount == 0 {
205+
return
206+
}
207+
208+
idx := ap.tokenIndex.Load()
209+
if int(idx)+1 < tokenCount {
210+
ap.tokenIndex.Store(idx + 1)
211+
}
212+
}
213+
214+
// CurrentTokenIndex returns the index of the current token.
215+
func (ap *MockAuthProvider) CurrentTokenIndex() int {
216+
ap.mu.RLock()
217+
tokenCount := len(ap.tokens)
218+
ap.mu.RUnlock()
219+
220+
idx := ap.tokenIndex.Load()
221+
if tokenCount == 0 {
222+
return -1
223+
}
224+
if int(idx) >= tokenCount {
225+
return tokenCount - 1
226+
}
227+
return int(idx)
228+
}
229+
230+
// Reset resets the token sequence to the first token.
231+
func (ap *MockAuthProvider) Reset() {
232+
ap.tokenIndex.Store(0)
233+
ap.mu.Lock()
234+
ap.getCalls = ap.getCalls[:0]
235+
ap.headerCalls = ap.headerCalls[:0]
236+
ap.mu.Unlock()
237+
ap.tokenErr.Store((*errHolder)(nil))
238+
}
239+
240+
// GetTokenCalls returns a copy of all recorded GetToken calls.
241+
func (ap *MockAuthProvider) GetTokenCalls() []GetTokenCall {
242+
ap.mu.RLock()
243+
defer ap.mu.RUnlock()
244+
245+
calls := make([]GetTokenCall, len(ap.getCalls))
246+
copy(calls, ap.getCalls)
247+
return calls
248+
}
249+
250+
// GetHeaderCalls returns a copy of all recorded GetAuthHeader calls.
251+
func (ap *MockAuthProvider) GetHeaderCalls() []GetHeaderCall {
252+
ap.mu.RLock()
253+
defer ap.mu.RUnlock()
254+
255+
calls := make([]GetHeaderCall, len(ap.headerCalls))
256+
copy(calls, ap.headerCalls)
257+
return calls
258+
}
259+
260+
// ClearRecorded clears all recorded calls.
261+
func (ap *MockAuthProvider) ClearRecorded() {
262+
ap.mu.Lock()
263+
ap.getCalls = ap.getCalls[:0]
264+
ap.headerCalls = ap.headerCalls[:0]
265+
ap.mu.Unlock()
266+
}
267+
268+
// SetHeaderFormat sets the format for GetAuthHeader.
269+
// Use "{}" as a placeholder for the token.
270+
func (ap *MockAuthProvider) SetHeaderFormat(format string) {
271+
ap.headerFormat = format
272+
}

0 commit comments

Comments
 (0)