@@ -110,3 +110,100 @@ func TestChainHandlerC(t *testing.T) {
110110
111111 assert .Equal (t , 3 , handlerCalls , "all handler called once" )
112112}
113+
114+ func TestAdd (t * testing.T ) {
115+ handlerCalls := 0
116+ h1 := func (next HandlerC ) HandlerC {
117+ return HandlerFuncC (func (ctx context.Context , w http.ResponseWriter , r * http.Request ) {
118+ handlerCalls ++
119+ ctx = context .WithValue (ctx , "test" , 1 )
120+ next .ServeHTTPC (ctx , w , r )
121+ })
122+ }
123+ h2 := func (next http.Handler ) http.Handler {
124+ handlerCalls ++
125+ return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
126+ // Change r and w values
127+ w = httptest .NewRecorder ()
128+ r = & http.Request {}
129+ next .ServeHTTP (w , r )
130+ })
131+ }
132+ h3 := func (next HandlerC ) HandlerC {
133+ return HandlerFuncC (func (ctx context.Context , w http.ResponseWriter , r * http.Request ) {
134+ handlerCalls ++
135+ ctx = context .WithValue (ctx , "test" , 2 )
136+ next .ServeHTTPC (ctx , w , r )
137+ })
138+ }
139+
140+ c := Chain {}
141+ c .Add (h1 , h2 , h3 )
142+ h := c .HandlerC (HandlerFuncC (func (ctx context.Context , w http.ResponseWriter , r * http.Request ) {
143+ handlerCalls ++
144+
145+ assert .Equal (t , 2 , ctx .Value ("test" ),
146+ "third handler should overwrite first handler's context value" )
147+ assert .Equal (t , 1 , ctx .Value ("mainCtx" ),
148+ "the mainCtx value should be pass through" )
149+ }))
150+
151+ mainCtx := context .WithValue (context .Background (), "mainCtx" , 1 )
152+ h .ServeHTTPC (mainCtx , nil , nil )
153+ assert .Equal (t , 4 , handlerCalls , "all handler called once" )
154+ }
155+
156+ func TestWith (t * testing.T ) {
157+ handlerCalls := 0
158+ h1 := func (next HandlerC ) HandlerC {
159+ return HandlerFuncC (func (ctx context.Context , w http.ResponseWriter , r * http.Request ) {
160+ handlerCalls ++
161+ ctx = context .WithValue (ctx , "test" , 1 )
162+ next .ServeHTTPC (ctx , w , r )
163+ })
164+ }
165+ h2 := func (next http.Handler ) http.Handler {
166+ handlerCalls ++
167+ return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
168+ // Change r and w values
169+ w = httptest .NewRecorder ()
170+ r = & http.Request {}
171+ next .ServeHTTP (w , r )
172+ })
173+ }
174+ h3 := func (next HandlerC ) HandlerC {
175+ return HandlerFuncC (func (ctx context.Context , w http.ResponseWriter , r * http.Request ) {
176+ handlerCalls ++
177+ ctx = context .WithValue (ctx , "test" , 2 )
178+ next .ServeHTTPC (ctx , w , r )
179+ })
180+ }
181+
182+ c := Chain {}
183+ c .Add (h1 )
184+ d := c .With (h2 , h3 )
185+
186+ h := c .HandlerC (HandlerFuncC (func (ctx context.Context , w http.ResponseWriter , r * http.Request ) {
187+ handlerCalls ++
188+
189+ assert .Equal (t , 1 , ctx .Value ("test" ),
190+ "third handler should not overwrite the first handler's context value" )
191+ assert .Equal (t , 1 , ctx .Value ("mainCtx" ),
192+ "the mainCtx value should be pass through" )
193+ }))
194+ i := d .HandlerC (HandlerFuncC (func (ctx context.Context , w http.ResponseWriter , r * http.Request ) {
195+ handlerCalls ++
196+
197+ assert .Equal (t , 2 , ctx .Value ("test" ),
198+ "third handler should overwrite first handler's context value" )
199+ assert .Equal (t , 1 , ctx .Value ("mainCtx" ),
200+ "the mainCtx value should be pass through" )
201+ }))
202+
203+ mainCtx := context .WithValue (context .Background (), "mainCtx" , 1 )
204+ h .ServeHTTPC (mainCtx , nil , nil )
205+ assert .Equal (t , 2 , handlerCalls , "all handlers called once" )
206+ handlerCalls = 0
207+ i .ServeHTTPC (mainCtx , nil , nil )
208+ assert .Equal (t , 4 , handlerCalls , "all handler called once" )
209+ }
0 commit comments