|
16 | 16 |
|
17 | 17 | package org.springframework.security.web.context; |
18 | 18 |
|
| 19 | +import java.io.IOException; |
19 | 20 | import java.lang.annotation.ElementType; |
20 | 21 | import java.lang.annotation.Retention; |
21 | 22 | import java.lang.annotation.RetentionPolicy; |
22 | 23 | import java.lang.annotation.Target; |
23 | 24 |
|
| 25 | +import javax.servlet.Filter; |
| 26 | +import javax.servlet.ServletException; |
24 | 27 | import javax.servlet.ServletOutputStream; |
| 28 | +import javax.servlet.http.HttpServlet; |
25 | 29 | import javax.servlet.http.HttpServletRequest; |
26 | 30 | import javax.servlet.http.HttpServletRequestWrapper; |
27 | 31 | import javax.servlet.http.HttpServletResponse; |
|
31 | 35 | import org.junit.After; |
32 | 36 | import org.junit.Test; |
33 | 37 |
|
| 38 | +import org.springframework.mock.web.MockFilterChain; |
34 | 39 | import org.springframework.mock.web.MockHttpServletRequest; |
35 | 40 | import org.springframework.mock.web.MockHttpServletResponse; |
36 | 41 | import org.springframework.mock.web.MockHttpSession; |
37 | 42 | import org.springframework.security.authentication.AbstractAuthenticationToken; |
38 | 43 | import org.springframework.security.authentication.AnonymousAuthenticationToken; |
39 | 44 | import org.springframework.security.authentication.AuthenticationTrustResolver; |
40 | 45 | import org.springframework.security.authentication.TestingAuthenticationToken; |
| 46 | +import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; |
41 | 47 | import org.springframework.security.core.Transient; |
42 | 48 | import org.springframework.security.core.authority.AuthorityUtils; |
43 | 49 | import org.springframework.security.core.context.SecurityContext; |
44 | 50 | import org.springframework.security.core.context.SecurityContextHolder; |
| 51 | +import org.springframework.security.core.context.SecurityContextImpl; |
| 52 | +import org.springframework.security.core.userdetails.User; |
| 53 | +import org.springframework.security.core.userdetails.UserDetails; |
45 | 54 |
|
46 | 55 | import static org.assertj.core.api.Assertions.assertThat; |
47 | 56 | import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; |
@@ -162,6 +171,48 @@ public void saveContextCallsSetAttributeIfContextIsModifiedDirectlyDuringRequest |
162 | 171 | verify(session).setAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY, ctx); |
163 | 172 | } |
164 | 173 |
|
| 174 | + @Test |
| 175 | + public void saveContextWhenSaveNewContextThenOriginalContextThenOriginalContextSaved() throws Exception { |
| 176 | + HttpSessionSecurityContextRepository repository = new HttpSessionSecurityContextRepository(); |
| 177 | + SecurityContextPersistenceFilter securityContextPersistenceFilter = new SecurityContextPersistenceFilter( |
| 178 | + repository); |
| 179 | + |
| 180 | + UserDetails original = User.withUsername("user").password("password").roles("USER").build(); |
| 181 | + SecurityContext originalContext = createSecurityContext(original); |
| 182 | + UserDetails impersonate = User.withUserDetails(original).username("impersonate").build(); |
| 183 | + SecurityContext impersonateContext = createSecurityContext(impersonate); |
| 184 | + |
| 185 | + MockHttpServletRequest mockRequest = new MockHttpServletRequest(); |
| 186 | + MockHttpServletResponse mockResponse = new MockHttpServletResponse(); |
| 187 | + |
| 188 | + Filter saveImpersonateContext = (request, response, chain) -> { |
| 189 | + SecurityContextHolder.setContext(impersonateContext); |
| 190 | + // ensure the response is committed to trigger save |
| 191 | + response.flushBuffer(); |
| 192 | + chain.doFilter(request, response); |
| 193 | + }; |
| 194 | + Filter saveOriginalContext = (request, response, chain) -> { |
| 195 | + SecurityContextHolder.setContext(originalContext); |
| 196 | + chain.doFilter(request, response); |
| 197 | + }; |
| 198 | + HttpServlet servlet = new HttpServlet() { |
| 199 | + @Override |
| 200 | + protected void service(HttpServletRequest req, HttpServletResponse resp) |
| 201 | + throws ServletException, IOException { |
| 202 | + resp.getWriter().write("Hi"); |
| 203 | + } |
| 204 | + }; |
| 205 | + |
| 206 | + SecurityContextHolder.setContext(originalContext); |
| 207 | + MockFilterChain chain = new MockFilterChain(servlet, saveImpersonateContext, saveOriginalContext); |
| 208 | + |
| 209 | + securityContextPersistenceFilter.doFilter(mockRequest, mockResponse, chain); |
| 210 | + |
| 211 | + assertThat( |
| 212 | + mockRequest.getSession().getAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY)) |
| 213 | + .isEqualTo(originalContext); |
| 214 | + } |
| 215 | + |
165 | 216 | @Test |
166 | 217 | public void nonSecurityContextInSessionIsIgnored() { |
167 | 218 | HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository(); |
@@ -577,6 +628,13 @@ public void saveContextWhenTransientAuthenticationWithCustomAnnotationThenSkippe |
577 | 628 | assertThat(session).isNull(); |
578 | 629 | } |
579 | 630 |
|
| 631 | + private SecurityContext createSecurityContext(UserDetails userDetails) { |
| 632 | + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken(userDetails, |
| 633 | + userDetails.getPassword(), userDetails.getAuthorities()); |
| 634 | + SecurityContext securityContext = new SecurityContextImpl(token); |
| 635 | + return securityContext; |
| 636 | + } |
| 637 | + |
580 | 638 | @Transient |
581 | 639 | private static class SomeTransientAuthentication extends AbstractAuthenticationToken { |
582 | 640 |
|
|
0 commit comments