Skip to content

Commit ed27b6f

Browse files
DeepAnchorrs
authored andcommitted
Extend Chain API, fix typos (#13)
Extend Chain API with 'Add' and 'With' methods.
1 parent d9d9599 commit ed27b6f

5 files changed

Lines changed: 174 additions & 16 deletions

File tree

chain.go

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,47 @@ import (
66
"golang.org/x/net/context"
77
)
88

9-
// Chain is an helper to chain middleware handlers together for an easier
9+
// Chain is a helper for chaining middleware handlers together for easier
1010
// management.
1111
type Chain []func(next HandlerC) HandlerC
1212

13+
// Add appends a variable number of additional middleware handlers
14+
// to the middleware chain. Middleware handlers can either be
15+
// context-aware or non-context aware handlers with the appropriate
16+
// function signatures.
17+
func (c *Chain) Add(f ...interface{}) {
18+
for _, h := range f {
19+
switch v := h.(type) {
20+
case func(http.Handler) http.Handler:
21+
c.Use(v)
22+
case func(HandlerC) HandlerC:
23+
c.UseC(v)
24+
default:
25+
panic("Adding invalid handler to the middleware chain")
26+
}
27+
}
28+
}
29+
30+
// With creates a new middleware chain from an existing chain,
31+
// extending it with additional middleware. Middleware handlers
32+
// can either be context-aware or non-context aware handlers
33+
// with the appropriate function signatures.
34+
func (c *Chain) With(f ...interface{}) *Chain {
35+
n := make(Chain, len(*c))
36+
copy(n, *c)
37+
n.Add(f...)
38+
return &n
39+
}
40+
1341
// UseC appends a context-aware handler to the middleware chain.
1442
func (c *Chain) UseC(f func(next HandlerC) HandlerC) {
1543
*c = append(*c, f)
1644
}
1745

1846
// Use appends a standard http.Handler to the middleware chain without
19-
// lossing track of the context when inserted between two context aware handlers.
47+
// losing track of the context when inserted between two context aware handlers.
2048
//
21-
// Caveat: the f function will be called on each request so you are better to put
49+
// Caveat: the f function will be called on each request so you are better off putting
2250
// any initialization sequence outside of this function.
2351
func (c *Chain) Use(f func(next http.Handler) http.Handler) {
2452
xf := func(next HandlerC) HandlerC {
@@ -33,14 +61,14 @@ func (c *Chain) Use(f func(next http.Handler) http.Handler) {
3361
}
3462

3563
// Handler wraps the provided final handler with all the middleware appended to
36-
// the chain and return a new standard http.Handler instance.
64+
// the chain and returns a new standard http.Handler instance.
3765
// The context.Background() context is injected automatically.
3866
func (c Chain) Handler(xh HandlerC) http.Handler {
3967
ctx := context.Background()
4068
return c.HandlerCtx(ctx, xh)
4169
}
4270

43-
// HandlerFC is an helper to provide a function (HandlerFuncC) to Handler().
71+
// HandlerFC is a helper to provide a function (HandlerFuncC) to Handler().
4472
//
4573
// HandlerFC is equivalent to:
4674
// c.Handler(xhandler.HandlerFuncC(xhc))
@@ -49,18 +77,18 @@ func (c Chain) HandlerFC(xhf HandlerFuncC) http.Handler {
4977
return c.HandlerCtx(ctx, HandlerFuncC(xhf))
5078
}
5179

52-
// HandlerH is an helper to provide a standard http handler (http.HandlerFunc)
53-
// to Handler(). Your final handler won't have access the context though.
80+
// HandlerH is a helper to provide a standard http handler (http.HandlerFunc)
81+
// to Handler(). Your final handler won't have access to the context though.
5482
func (c Chain) HandlerH(h http.Handler) http.Handler {
5583
ctx := context.Background()
5684
return c.HandlerCtx(ctx, HandlerFuncC(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
5785
h.ServeHTTP(w, r)
5886
}))
5987
}
6088

61-
// HandlerF is an helper to provide a standard http handler function
89+
// HandlerF is a helper to provide a standard http handler function
6290
// (http.HandlerFunc) to Handler(). Your final handler won't have access
63-
// the context though.
91+
// to the context though.
6492
func (c Chain) HandlerF(hf http.HandlerFunc) http.Handler {
6593
ctx := context.Background()
6694
return c.HandlerCtx(ctx, HandlerFuncC(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
@@ -69,7 +97,7 @@ func (c Chain) HandlerF(hf http.HandlerFunc) http.Handler {
6997
}
7098

7199
// HandlerCtx wraps the provided final handler with all the middleware appended to
72-
// the chain and return a new standard http.Handler instance.
100+
// the chain and returns a new standard http.Handler instance.
73101
func (c Chain) HandlerCtx(ctx context.Context, xh HandlerC) http.Handler {
74102
return New(ctx, c.HandlerC(xh))
75103
}

chain_example_test.go

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,43 @@ func ExampleChain() {
3535
})))
3636
}
3737

38+
func ExampleAddChain() {
39+
c := xhandler.Chain{}
40+
41+
close := xhandler.CloseHandler
42+
cors := cors.Default().Handler
43+
timeout := xhandler.TimeoutHandler(2 * time.Second)
44+
auth := func(next xhandler.HandlerC) xhandler.HandlerC {
45+
return xhandler.HandlerFuncC(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
46+
if v := ctx.Value("Authorization"); v == nil {
47+
http.Error(w, "Not authorized", http.StatusUnauthorized)
48+
return
49+
}
50+
next.ServeHTTPC(ctx, w, r)
51+
})
52+
}
53+
54+
c.Add(close, cors, timeout)
55+
56+
mux := http.NewServeMux()
57+
58+
// Use c.Handler to terminate the chain with your final handler
59+
mux.Handle("/", c.Handler(xhandler.HandlerFuncC(func(ctx context.Context, w http.ResponseWriter, req *http.Request) {
60+
fmt.Fprintf(w, "Welcome to the home page!")
61+
})))
62+
63+
// Create a new chain from an existing one, and add route-specific middleware to it
64+
protected := c.With(auth)
65+
66+
mux.Handle("/admin", protected.Handler(xhandler.HandlerFuncC(func(ctx context.Context, w http.ResponseWriter, req *http.Request) {
67+
fmt.Fprintf(w, "protected endpoint!")
68+
})))
69+
}
70+
3871
func ExampleIf() {
3972
c := xhandler.Chain{}
4073

41-
// Add timeout handler only if the path match a prefix
74+
// Add a timeout handler only if the URL path matches a prefix
4275
c.UseC(xhandler.If(
4376
func(ctx context.Context, w http.ResponseWriter, r *http.Request) bool {
4477
return strings.HasPrefix(r.URL.Path, "/with-timeout/")

chain_test.go

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,100 @@ func TestChainHandlerC(t *testing.T) {
110110

111111
assert.Equal(t, 3, handlerCalls, "all handler called once")
112112
}
113+
114+
func TestAdd(t *testing.T) {
115+
handlerCalls := 0
116+
h1 := func(next HandlerC) HandlerC {
117+
return HandlerFuncC(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
118+
handlerCalls++
119+
ctx = context.WithValue(ctx, "test", 1)
120+
next.ServeHTTPC(ctx, w, r)
121+
})
122+
}
123+
h2 := func(next http.Handler) http.Handler {
124+
handlerCalls++
125+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
126+
// Change r and w values
127+
w = httptest.NewRecorder()
128+
r = &http.Request{}
129+
next.ServeHTTP(w, r)
130+
})
131+
}
132+
h3 := func(next HandlerC) HandlerC {
133+
return HandlerFuncC(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
134+
handlerCalls++
135+
ctx = context.WithValue(ctx, "test", 2)
136+
next.ServeHTTPC(ctx, w, r)
137+
})
138+
}
139+
140+
c := Chain{}
141+
c.Add(h1, h2, h3)
142+
h := c.HandlerC(HandlerFuncC(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
143+
handlerCalls++
144+
145+
assert.Equal(t, 2, ctx.Value("test"),
146+
"third handler should overwrite first handler's context value")
147+
assert.Equal(t, 1, ctx.Value("mainCtx"),
148+
"the mainCtx value should be pass through")
149+
}))
150+
151+
mainCtx := context.WithValue(context.Background(), "mainCtx", 1)
152+
h.ServeHTTPC(mainCtx, nil, nil)
153+
assert.Equal(t, 4, handlerCalls, "all handler called once")
154+
}
155+
156+
func TestWith(t *testing.T) {
157+
handlerCalls := 0
158+
h1 := func(next HandlerC) HandlerC {
159+
return HandlerFuncC(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
160+
handlerCalls++
161+
ctx = context.WithValue(ctx, "test", 1)
162+
next.ServeHTTPC(ctx, w, r)
163+
})
164+
}
165+
h2 := func(next http.Handler) http.Handler {
166+
handlerCalls++
167+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
168+
// Change r and w values
169+
w = httptest.NewRecorder()
170+
r = &http.Request{}
171+
next.ServeHTTP(w, r)
172+
})
173+
}
174+
h3 := func(next HandlerC) HandlerC {
175+
return HandlerFuncC(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
176+
handlerCalls++
177+
ctx = context.WithValue(ctx, "test", 2)
178+
next.ServeHTTPC(ctx, w, r)
179+
})
180+
}
181+
182+
c := Chain{}
183+
c.Add(h1)
184+
d := c.With(h2, h3)
185+
186+
h := c.HandlerC(HandlerFuncC(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
187+
handlerCalls++
188+
189+
assert.Equal(t, 1, ctx.Value("test"),
190+
"third handler should not overwrite the first handler's context value")
191+
assert.Equal(t, 1, ctx.Value("mainCtx"),
192+
"the mainCtx value should be pass through")
193+
}))
194+
i := d.HandlerC(HandlerFuncC(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
195+
handlerCalls++
196+
197+
assert.Equal(t, 2, ctx.Value("test"),
198+
"third handler should overwrite first handler's context value")
199+
assert.Equal(t, 1, ctx.Value("mainCtx"),
200+
"the mainCtx value should be pass through")
201+
}))
202+
203+
mainCtx := context.WithValue(context.Background(), "mainCtx", 1)
204+
h.ServeHTTPC(mainCtx, nil, nil)
205+
assert.Equal(t, 2, handlerCalls, "all handlers called once")
206+
handlerCalls = 0
207+
i.ServeHTTPC(mainCtx, nil, nil)
208+
assert.Equal(t, 4, handlerCalls, "all handler called once")
209+
}

middleware.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ import (
77
"golang.org/x/net/context"
88
)
99

10-
// CloseHandler returns a Handler cancelling the context when the client
11-
// connection close unexpectedly.
10+
// CloseHandler returns a Handler, cancelling the context when the client
11+
// connection closes unexpectedly.
1212
func CloseHandler(next HandlerC) HandlerC {
1313
return HandlerFuncC(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
1414
// Cancel the context if the client closes the connection
@@ -33,7 +33,7 @@ func CloseHandler(next HandlerC) HandlerC {
3333

3434
// TimeoutHandler returns a Handler which adds a timeout to the context.
3535
//
36-
// Child handlers have the responsability to obey the context deadline and to return
36+
// Child handlers have the responsability of obeying the context deadline and to return
3737
// an appropriate error (or not) response in case of timeout.
3838
func TimeoutHandler(timeout time.Duration) func(next HandlerC) HandlerC {
3939
return func(next HandlerC) HandlerC {

xhandler.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
// the connection unexpectedly.
99
//
1010
// You may create net/context aware middlewares pretty much the same way as
11-
// you would do with http.Handler.
11+
// you would with http.Handler.
1212
package xhandler // import "github.com/rs/xhandler"
1313

1414
import (
@@ -23,7 +23,7 @@ type HandlerC interface {
2323
}
2424

2525
// HandlerFuncC type is an adapter to allow the use of ordinary functions
26-
// as a xhandler.Handler. If f is a function with the appropriate signature,
26+
// as an xhandler.Handler. If f is a function with the appropriate signature,
2727
// xhandler.HandlerFuncC(f) is a xhandler.Handler object that calls f.
2828
type HandlerFuncC func(context.Context, http.ResponseWriter, *http.Request)
2929

0 commit comments

Comments
 (0)