|
16 | 16 |
|
17 | 17 | package org.springframework.security.config.annotation.web.configurers.saml2; |
18 | 18 |
|
| 19 | +import java.util.ArrayList; |
19 | 20 | import java.util.LinkedHashMap; |
| 21 | +import java.util.List; |
20 | 22 | import java.util.Map; |
21 | 23 |
|
| 24 | +import jakarta.servlet.http.HttpServletRequest; |
| 25 | + |
22 | 26 | import org.springframework.beans.factory.NoSuchBeanDefinitionException; |
23 | 27 | import org.springframework.context.ApplicationContext; |
24 | 28 | import org.springframework.security.authentication.AuthenticationManager; |
|
33 | 37 | import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider; |
34 | 38 | import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; |
35 | 39 | import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; |
| 40 | +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrations; |
36 | 41 | import org.springframework.security.saml2.provider.service.web.HttpSessionSaml2AuthenticationRequestRepository; |
37 | 42 | import org.springframework.security.saml2.provider.service.web.OpenSamlAuthenticationTokenConverter; |
38 | 43 | import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestRepository; |
|
50 | 55 | import org.springframework.security.web.util.matcher.AntPathRequestMatcher; |
51 | 56 | import org.springframework.security.web.util.matcher.NegatedRequestMatcher; |
52 | 57 | import org.springframework.security.web.util.matcher.OrRequestMatcher; |
| 58 | +import org.springframework.security.web.util.matcher.ParameterRequestMatcher; |
53 | 59 | import org.springframework.security.web.util.matcher.RequestHeaderRequestMatcher; |
54 | 60 | import org.springframework.security.web.util.matcher.RequestMatcher; |
55 | 61 | import org.springframework.security.web.util.matcher.RequestMatchers; |
@@ -111,7 +117,13 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>> |
111 | 117 |
|
112 | 118 | private String loginPage; |
113 | 119 |
|
114 | | - private String authenticationRequestUri = Saml2AuthenticationRequestResolver.DEFAULT_AUTHENTICATION_REQUEST_URI; |
| 120 | + private String authenticationRequestUri = "/saml2/authenticate"; |
| 121 | + |
| 122 | + private String[] authenticationRequestParams = { "registrationId={registrationId}" }; |
| 123 | + |
| 124 | + private RequestMatcher authenticationRequestMatcher = RequestMatchers.anyOf( |
| 125 | + new AntPathRequestMatcher(Saml2AuthenticationRequestResolver.DEFAULT_AUTHENTICATION_REQUEST_URI), |
| 126 | + new AntPathQueryRequestMatcher(this.authenticationRequestUri, this.authenticationRequestParams)); |
115 | 127 |
|
116 | 128 | private Saml2AuthenticationRequestResolver authenticationRequestResolver; |
117 | 129 |
|
@@ -196,11 +208,31 @@ public Saml2LoginConfigurer<B> authenticationRequestResolver( |
196 | 208 | * Request |
197 | 209 | * @return the {@link Saml2LoginConfigurer} for further configuration |
198 | 210 | * @since 6.0 |
| 211 | + * @deprecated Use {@link #authenticationRequestUriQuery} instead |
199 | 212 | */ |
200 | 213 | public Saml2LoginConfigurer<B> authenticationRequestUri(String authenticationRequestUri) { |
201 | | - Assert.state(authenticationRequestUri.contains("{registrationId}"), |
202 | | - "authenticationRequestUri must contain {registrationId} path variable"); |
203 | | - this.authenticationRequestUri = authenticationRequestUri; |
| 214 | + return authenticationRequestUriQuery(authenticationRequestUri); |
| 215 | + } |
| 216 | + |
| 217 | + /** |
| 218 | + * Customize the URL that the SAML Authentication Request will be sent to. This method |
| 219 | + * also supports query parameters like so: <pre> |
| 220 | + * authenticationRequestUriQuery("/saml/authenticate?registrationId={registrationId}") |
| 221 | + * </pre> {@link RelyingPartyRegistrations} |
| 222 | + * @param authenticationRequestUriQuery the URI and query to use for the SAML 2.0 |
| 223 | + * Authentication Request |
| 224 | + * @return the {@link Saml2LoginConfigurer} for further configuration |
| 225 | + * @since 6.0 |
| 226 | + */ |
| 227 | + public Saml2LoginConfigurer<B> authenticationRequestUriQuery(String authenticationRequestUriQuery) { |
| 228 | + Assert.state(authenticationRequestUriQuery.contains("{registrationId}"), |
| 229 | + "authenticationRequestUri must contain {registrationId} path variable or query value"); |
| 230 | + String[] parts = authenticationRequestUriQuery.split("[?&]"); |
| 231 | + this.authenticationRequestUri = parts[0]; |
| 232 | + this.authenticationRequestParams = new String[parts.length - 1]; |
| 233 | + System.arraycopy(parts, 1, this.authenticationRequestParams, 0, parts.length - 1); |
| 234 | + this.authenticationRequestMatcher = new AntPathQueryRequestMatcher(this.authenticationRequestUri, |
| 235 | + this.authenticationRequestParams); |
204 | 236 | return this; |
205 | 237 | } |
206 | 238 |
|
@@ -255,7 +287,7 @@ public void init(B http) throws Exception { |
255 | 287 | } |
256 | 288 | else { |
257 | 289 | Map<String, String> providerUrlMap = getIdentityProviderUrlMap(this.authenticationRequestUri, |
258 | | - this.relyingPartyRegistrationRepository); |
| 290 | + this.authenticationRequestParams, this.relyingPartyRegistrationRepository); |
259 | 291 | boolean singleProvider = providerUrlMap.size() == 1; |
260 | 292 | if (singleProvider) { |
261 | 293 | // Setup auto-redirect to provider login page |
@@ -336,8 +368,7 @@ private Saml2AuthenticationRequestResolver getAuthenticationRequestResolver(B ht |
336 | 368 | } |
337 | 369 | OpenSaml4AuthenticationRequestResolver openSaml4AuthenticationRequestResolver = new OpenSaml4AuthenticationRequestResolver( |
338 | 370 | relyingPartyRegistrationRepository(http)); |
339 | | - openSaml4AuthenticationRequestResolver |
340 | | - .setRequestMatcher(new AntPathRequestMatcher(this.authenticationRequestUri)); |
| 371 | + openSaml4AuthenticationRequestResolver.setRequestMatcher(this.authenticationRequestMatcher); |
341 | 372 | return openSaml4AuthenticationRequestResolver; |
342 | 373 | } |
343 | 374 |
|
@@ -382,20 +413,28 @@ private void initDefaultLoginFilter(B http) { |
382 | 413 | return; |
383 | 414 | } |
384 | 415 | loginPageGeneratingFilter.setSaml2LoginEnabled(true); |
385 | | - loginPageGeneratingFilter.setSaml2AuthenticationUrlToProviderName( |
386 | | - this.getIdentityProviderUrlMap(this.authenticationRequestUri, this.relyingPartyRegistrationRepository)); |
| 416 | + loginPageGeneratingFilter |
| 417 | + .setSaml2AuthenticationUrlToProviderName(this.getIdentityProviderUrlMap(this.authenticationRequestUri, |
| 418 | + this.authenticationRequestParams, this.relyingPartyRegistrationRepository)); |
387 | 419 | loginPageGeneratingFilter.setLoginPageUrl(this.getLoginPage()); |
388 | 420 | loginPageGeneratingFilter.setFailureUrl(this.getFailureUrl()); |
389 | 421 | } |
390 | 422 |
|
391 | 423 | @SuppressWarnings("unchecked") |
392 | | - private Map<String, String> getIdentityProviderUrlMap(String authRequestPrefixUrl, |
| 424 | + private Map<String, String> getIdentityProviderUrlMap(String authRequestPrefixUrl, String[] authRequestQueryParams, |
393 | 425 | RelyingPartyRegistrationRepository idpRepo) { |
394 | 426 | Map<String, String> idps = new LinkedHashMap<>(); |
395 | 427 | if (idpRepo instanceof Iterable) { |
396 | 428 | Iterable<RelyingPartyRegistration> repo = (Iterable<RelyingPartyRegistration>) idpRepo; |
397 | | - repo.forEach((p) -> idps.put(authRequestPrefixUrl.replace("{registrationId}", p.getRegistrationId()), |
398 | | - p.getRegistrationId())); |
| 429 | + StringBuilder authRequestQuery = new StringBuilder("?"); |
| 430 | + for (String authRequestQueryParam : authRequestQueryParams) { |
| 431 | + authRequestQuery.append(authRequestQueryParam + "&"); |
| 432 | + } |
| 433 | + authRequestQuery.deleteCharAt(authRequestQuery.length() - 1); |
| 434 | + String authenticationRequestUriQuery = authRequestPrefixUrl + authRequestQuery; |
| 435 | + repo.forEach( |
| 436 | + (p) -> idps.put(authenticationRequestUriQuery.replace("{registrationId}", p.getRegistrationId()), |
| 437 | + p.getRegistrationId())); |
399 | 438 | } |
400 | 439 | return idps; |
401 | 440 | } |
@@ -437,4 +476,35 @@ private <C> void setSharedObject(B http, Class<C> clazz, C object) { |
437 | 476 | } |
438 | 477 | } |
439 | 478 |
|
| 479 | + static class AntPathQueryRequestMatcher implements RequestMatcher { |
| 480 | + |
| 481 | + private final RequestMatcher matcher; |
| 482 | + |
| 483 | + AntPathQueryRequestMatcher(String path, String... params) { |
| 484 | + List<RequestMatcher> matchers = new ArrayList<>(); |
| 485 | + matchers.add(new AntPathRequestMatcher(path)); |
| 486 | + for (String param : params) { |
| 487 | + String[] parts = param.split("="); |
| 488 | + if (parts.length == 1) { |
| 489 | + matchers.add(new ParameterRequestMatcher(parts[0])); |
| 490 | + } |
| 491 | + else { |
| 492 | + matchers.add(new ParameterRequestMatcher(parts[0], parts[1])); |
| 493 | + } |
| 494 | + } |
| 495 | + this.matcher = new AndRequestMatcher(matchers); |
| 496 | + } |
| 497 | + |
| 498 | + @Override |
| 499 | + public boolean matches(HttpServletRequest request) { |
| 500 | + return matcher(request).isMatch(); |
| 501 | + } |
| 502 | + |
| 503 | + @Override |
| 504 | + public MatchResult matcher(HttpServletRequest request) { |
| 505 | + return this.matcher.matcher(request); |
| 506 | + } |
| 507 | + |
| 508 | + } |
| 509 | + |
440 | 510 | } |
0 commit comments