Skip to content

Commit 52dd761

Browse files
authored
fix oidc token exchange endpoint (#2490)
* fix oidc token exchange endpoint * cleanup
1 parent f69464e commit 52dd761

4 files changed

Lines changed: 259 additions & 24 deletions

File tree

src/main/java/org/ohdsi/webapi/OidcConfCreator.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
*/
1919
package org.ohdsi.webapi;
2020

21-
import com.nimbusds.jose.JWSAlgorithm;
2221
import com.nimbusds.oauth2.sdk.pkce.CodeChallengeMethod;
2322
import org.pac4j.oidc.config.OidcConfiguration;
2423
import org.slf4j.Logger;
@@ -107,7 +106,7 @@ public OidcConfiguration build() {
107106
scopes += extraScopes;
108107
}
109108
conf.setScope(scopes);
110-
conf.setPreferredJwsAlgorithm(JWSAlgorithm.RS256);
109+
// Use all algorithms from provider metadata (supports RS256, ES384, etc.)
111110
conf.setPkceMethod(CodeChallengeMethod.S256);
112111

113112
try {
Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
package org.ohdsi.webapi.shiro.filters;
2+
3+
import com.nimbusds.jose.JOSEException;
4+
import com.nimbusds.jose.JWSHeader;
5+
import com.nimbusds.jose.JWSVerifier;
6+
import com.nimbusds.jose.crypto.ECDSAVerifier;
7+
import com.nimbusds.jose.crypto.RSASSAVerifier;
8+
import com.nimbusds.jose.jwk.ECKey;
9+
import com.nimbusds.jose.jwk.JWK;
10+
import com.nimbusds.jose.jwk.JWKSet;
11+
import com.nimbusds.jose.jwk.RSAKey;
12+
import com.nimbusds.jwt.JWTClaimsSet;
13+
import com.nimbusds.jwt.SignedJWT;
14+
import jakarta.servlet.ServletRequest;
15+
import jakarta.servlet.ServletResponse;
16+
import jakarta.servlet.http.HttpServletRequest;
17+
import org.apache.shiro.authc.AuthenticationException;
18+
import org.ohdsi.webapi.shiro.PermissionManager;
19+
import org.ohdsi.webapi.shiro.ServletBridge;
20+
import org.ohdsi.webapi.shiro.tokens.JwtAuthToken;
21+
import org.pac4j.oidc.config.OidcConfiguration;
22+
import org.slf4j.Logger;
23+
import org.slf4j.LoggerFactory;
24+
25+
import java.net.URI;
26+
import java.security.interfaces.ECPublicKey;
27+
import java.security.interfaces.RSAPublicKey;
28+
import java.text.ParseException;
29+
import java.util.*;
30+
import java.util.concurrent.ConcurrentHashMap;
31+
32+
/**
33+
* Validates OIDC JWT bearer tokens using the provider's JWKS.
34+
* Used for token exchange: external OIDC token -> WebAPI JWT.
35+
*/
36+
public class OidcJwtAuthFilter extends AtlasAuthFilter {
37+
38+
private static final Logger logger = LoggerFactory.getLogger(OidcJwtAuthFilter.class);
39+
private static final String AUTHORIZATION_HEADER = "Authorization";
40+
private static final String BEARER_PREFIX = "Bearer ";
41+
private static final long JWKS_CACHE_DURATION_MS = 300_000;
42+
43+
public static final String OIDC_EXTERNAL_TOKEN = "oidc_external_token";
44+
45+
private final OidcConfiguration oidcConfiguration;
46+
private final PermissionManager authorizer;
47+
private final Set<String> defaultRoles;
48+
private final Map<String, JWK> keyCache = new ConcurrentHashMap<>();
49+
private volatile long lastJwksFetch = 0;
50+
51+
public OidcJwtAuthFilter(OidcConfiguration oidcConfiguration,
52+
PermissionManager authorizer,
53+
Set<String> defaultRoles,
54+
int tokenExpirationIntervalInSeconds) {
55+
this.oidcConfiguration = oidcConfiguration;
56+
this.authorizer = authorizer;
57+
this.defaultRoles = defaultRoles;
58+
}
59+
60+
@Override
61+
protected JwtAuthToken createToken(ServletRequest request, ServletResponse response) throws Exception {
62+
String bearerToken = extractBearerToken(request);
63+
if (bearerToken == null) {
64+
throw new AuthenticationException("No bearer token found");
65+
}
66+
return new JwtAuthToken(verifyAndExtractSubject(bearerToken));
67+
}
68+
69+
@Override
70+
protected boolean onAccessDenied(ServletRequest request, ServletResponse response) throws Exception {
71+
String bearerToken = extractBearerToken(request);
72+
if (bearerToken == null) {
73+
return true;
74+
}
75+
76+
try {
77+
String subject = verifyAndExtractSubject(bearerToken);
78+
String name = extractName(bearerToken, subject);
79+
authorizer.registerUser(subject, name, defaultRoles);
80+
request.setAttribute(OIDC_EXTERNAL_TOKEN, true);
81+
return executeLogin(request, response);
82+
} catch (AuthenticationException e) {
83+
logger.warn("OIDC JWT authentication failed for request from {}: {}",
84+
request.getRemoteAddr(), e.getMessage());
85+
return true;
86+
}
87+
}
88+
89+
private String extractBearerToken(ServletRequest request) {
90+
HttpServletRequest httpRequest = ServletBridge.toHttp(request);
91+
String authHeader = httpRequest.getHeader(AUTHORIZATION_HEADER);
92+
if (authHeader != null && authHeader.startsWith(BEARER_PREFIX)) {
93+
return authHeader.substring(BEARER_PREFIX.length());
94+
}
95+
return null;
96+
}
97+
98+
private String verifyAndExtractSubject(String jwtToken) throws AuthenticationException {
99+
try {
100+
SignedJWT signedJwt = SignedJWT.parse(jwtToken);
101+
JWSHeader header = signedJwt.getHeader();
102+
JWTClaimsSet claims = signedJwt.getJWTClaimsSet();
103+
104+
Date now = new Date();
105+
Date expiration = claims.getExpirationTime();
106+
if (expiration != null && expiration.before(now)) {
107+
throw new AuthenticationException("Token expired");
108+
}
109+
110+
Date notBefore = claims.getNotBeforeTime();
111+
if (notBefore != null && notBefore.after(now)) {
112+
throw new AuthenticationException("Token not yet valid");
113+
}
114+
115+
String expectedIssuer = getExpectedIssuer();
116+
if (expectedIssuer != null && !expectedIssuer.equals(claims.getIssuer())) {
117+
throw new AuthenticationException("Invalid token issuer");
118+
}
119+
120+
String expectedAudience = oidcConfiguration.getClientId();
121+
List<String> audiences = claims.getAudience();
122+
if (expectedAudience != null && (audiences == null || !audiences.contains(expectedAudience))) {
123+
throw new AuthenticationException("Invalid token audience");
124+
}
125+
126+
JWK jwk = getKey(header.getKeyID());
127+
if (jwk == null) {
128+
throw new AuthenticationException("Signing key not found");
129+
}
130+
131+
if (!signedJwt.verify(createVerifier(jwk))) {
132+
throw new AuthenticationException("Invalid signature");
133+
}
134+
135+
String email = (String) claims.getClaim("email");
136+
return (email != null && !email.isEmpty()) ? email : claims.getSubject();
137+
138+
} catch (ParseException | JOSEException e) {
139+
throw new AuthenticationException("JWT validation failed: " + e.getMessage(), e);
140+
}
141+
}
142+
143+
private String extractName(String jwtToken, String fallback) {
144+
try {
145+
SignedJWT signedJwt = SignedJWT.parse(jwtToken);
146+
String name = (String) signedJwt.getJWTClaimsSet().getClaim("name");
147+
return (name != null && !name.isEmpty()) ? name : fallback;
148+
} catch (ParseException e) {
149+
return fallback;
150+
}
151+
}
152+
153+
private String getExpectedIssuer() {
154+
try {
155+
var resolver = oidcConfiguration.getOpMetadataResolver();
156+
if (resolver != null) {
157+
var metadata = resolver.load();
158+
if (metadata != null) {
159+
return metadata.getIssuer().getValue();
160+
}
161+
}
162+
} catch (Exception e) {
163+
logger.warn("Failed to get OIDC issuer: {}", e.getMessage());
164+
}
165+
return null;
166+
}
167+
168+
private JWK getKey(String kid) {
169+
JWK jwk = keyCache.get(kid);
170+
if (jwk == null) {
171+
long currentTime = System.currentTimeMillis();
172+
if (currentTime - lastJwksFetch > JWKS_CACHE_DURATION_MS) {
173+
synchronized (this) {
174+
if (currentTime - lastJwksFetch > JWKS_CACHE_DURATION_MS) {
175+
refreshJwks();
176+
}
177+
}
178+
jwk = keyCache.get(kid);
179+
}
180+
}
181+
return jwk;
182+
}
183+
184+
private void refreshJwks() {
185+
try {
186+
URI jwksUri = getJwksUri();
187+
if (jwksUri == null) {
188+
logger.error("No JWKS URI available");
189+
return;
190+
}
191+
192+
JWKSet jwkSet = JWKSet.load(jwksUri.toURL());
193+
keyCache.clear();
194+
for (JWK key : jwkSet.getKeys()) {
195+
if (key.getKeyID() != null) {
196+
keyCache.put(key.getKeyID(), key);
197+
}
198+
}
199+
} catch (Exception e) {
200+
logger.error("Failed to fetch JWKS: {}", e.getMessage());
201+
} finally {
202+
lastJwksFetch = System.currentTimeMillis();
203+
}
204+
}
205+
206+
private URI getJwksUri() {
207+
try {
208+
var resolver = oidcConfiguration.getOpMetadataResolver();
209+
if (resolver != null) {
210+
var metadata = resolver.load();
211+
if (metadata != null) {
212+
return metadata.getJWKSetURI();
213+
}
214+
}
215+
} catch (Exception e) {
216+
logger.warn("Failed to get JWKS URI: {}", e.getMessage());
217+
}
218+
return null;
219+
}
220+
221+
private JWSVerifier createVerifier(JWK jwk) throws JOSEException {
222+
if (jwk instanceof ECKey) {
223+
return new ECDSAVerifier(((ECKey) jwk).toECPublicKey());
224+
} else if (jwk instanceof RSAKey) {
225+
return new RSASSAVerifier(((RSAKey) jwk).toRSAPublicKey());
226+
}
227+
throw new JOSEException("Unsupported key type: " + jwk.getKeyType());
228+
}
229+
}

src/main/java/org/ohdsi/webapi/shiro/filters/UpdateAccessTokenFilter.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,12 @@ protected boolean preHandle(ServletRequest request, ServletResponse response) th
134134

135135
String sessionId = (String) request.getAttribute(Constants.SESSION_ID);
136136
if (sessionId == null) {
137-
final String token = TokenManager.extractToken(request);
138-
if (token != null) {
139-
sessionId = (String) TokenManager.getBody(token).get(Constants.SESSION_ID);
137+
Boolean isOidcToken = (Boolean) request.getAttribute(OidcJwtAuthFilter.OIDC_EXTERNAL_TOKEN);
138+
if (!Boolean.TRUE.equals(isOidcToken)) {
139+
final String token = TokenManager.extractToken(request);
140+
if (token != null) {
141+
sessionId = (String) TokenManager.getBody(token).get(Constants.SESSION_ID);
142+
}
140143
}
141144
}
142145

src/main/java/org/ohdsi/webapi/shiro/management/AtlasRegularSecurity.java

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.ohdsi.webapi.shiro.Entities.UserRepository;
1616
import org.ohdsi.webapi.shiro.PermissionManager;
1717
import org.ohdsi.webapi.shiro.filters.*;
18+
import org.ohdsi.webapi.shiro.filters.OidcJwtAuthFilter;
1819
import org.ohdsi.webapi.shiro.filters.auth.ActiveDirectoryAuthFilter;
1920
import org.ohdsi.webapi.shiro.filters.auth.AtlasJwtAuthFilter;
2021
import org.ohdsi.webapi.shiro.filters.auth.JdbcAuthFilter;
@@ -49,8 +50,6 @@
4950
import org.pac4j.oauth.client.Google2Client;
5051
import org.pac4j.oidc.client.OidcClient;
5152
import org.pac4j.oidc.config.OidcConfiguration;
52-
import org.pac4j.oidc.credentials.authenticator.OidcAuthenticator;
53-
import org.pac4j.http.client.direct.DirectBearerAuthClient;
5453
import org.pac4j.saml.client.SAML2Client;
5554
import org.pac4j.saml.config.SAML2Configuration;
5655
import org.slf4j.Logger;
@@ -331,17 +330,18 @@ public Map<FilterTemplates, Filter> getFilters() {
331330
clients.add(githubClient);
332331
}
333332

333+
OidcConfiguration oidcConfiguration = null;
334334
if (this.openidAuthEnabled) {
335-
OidcConfiguration configuration = oidcConfCreator.build();
336-
if (StringUtils.isNotBlank(configuration.getClientId())) {
335+
oidcConfiguration = oidcConfCreator.build();
336+
if (StringUtils.isNotBlank(oidcConfiguration.getClientId())) {
337337
// https://www.pac4j.org/4.0.x/docs/clients/openid-connect.html
338338
// OidcClient allows indirect login through UI with code flow
339-
OidcClient oidcClient = new OidcClient(configuration);
339+
OidcClient oidcClient = new OidcClient(oidcConfiguration);
340340
oidcClient.setCallbackUrl(oauthApiCallback);
341341
oidcClient.setCallbackUrlResolver(urlResolver);
342342

343343
// URL rewriting: discovery from internal URL, redirect to external URL
344-
String internalUrl = configuration.getDiscoveryURI();
344+
String internalUrl = oidcConfiguration.getDiscoveryURI();
345345
String externalUrl = oidcConfCreator.getExternalUrl();
346346
if (externalUrl != null && !externalUrl.isEmpty()) {
347347
org.ohdsi.webapi.shiro.filters.ExternalUrlOidcRedirectionActionBuilder redirectBuilder =
@@ -353,12 +353,6 @@ public Map<FilterTemplates, Filter> getFilters() {
353353

354354
// Configuration already initialized; pac4j handles lazy init
355355
clients.add(oidcClient);
356-
357-
// Bearer token authentication for API access (pac4j 6.x)
358-
// OidcAuthenticator requires both configuration and client
359-
OidcAuthenticator authenticator = new OidcAuthenticator(configuration, oidcClient);
360-
DirectBearerAuthClient bearerClient = new DirectBearerAuthClient(authenticator);
361-
clients.add(bearerClient);
362356
} else {
363357
logger.warn("openidAuth is enabled but no client id is provided");
364358
}
@@ -405,11 +399,6 @@ public Map<FilterTemplates, Filter> getFilters() {
405399
oidcFilter.setConfig(cfg);
406400
oidcFilter.setClients("OidcClient");
407401
filters.put(OIDC_AUTH, oidcFilter);
408-
409-
SecurityFilter oidcDirectFilter = new SecurityFilter();
410-
oidcDirectFilter.setConfig(cfg);
411-
oidcDirectFilter.setClients("HeaderClient");
412-
filters.put(OIDC_DIRECT_AUTH, oidcDirectFilter);
413402
}
414403

415404
io.buji.pac4j.filter.CallbackFilter callbackFilter = new io.buji.pac4j.filter.CallbackFilter();
@@ -427,6 +416,17 @@ public Map<FilterTemplates, Filter> getFilters() {
427416
filters.put(HANDLE_UNSUCCESSFUL_OAUTH, new RedirectOnFailedOAuthFilter(this.oauthUiCallback));
428417
}
429418

419+
// OIDC token exchange filter
420+
if (this.openidAuthEnabled && oidcConfiguration != null) {
421+
OidcJwtAuthFilter oidcJwtFilter = new OidcJwtAuthFilter(
422+
oidcConfiguration,
423+
this.authorizer,
424+
this.defaultRoles,
425+
this.tokenExpirationIntervalInSeconds
426+
);
427+
filters.put(OIDC_DIRECT_AUTH, oidcJwtFilter);
428+
}
429+
430430
if (this.casAuthEnabled) {
431431
this.setUpCAS(filters);
432432
}
@@ -440,8 +440,12 @@ public Map<FilterTemplates, Filter> getFilters() {
440440
@Override
441441
protected FilterChainBuilder getFilterChainBuilder() {
442442

443-
List<FilterTemplates> authcFilters = googleAccessTokenEnabled ? Arrays.asList(ACCESS_AUTHC, JWT_AUTHC) :
444-
Collections.singletonList(JWT_AUTHC);
443+
// Build authentication filter chain: try JWT first, then OIDC if enabled
444+
List<FilterTemplates> authcFilters = new ArrayList<>();
445+
if (googleAccessTokenEnabled) {
446+
authcFilters.add(ACCESS_AUTHC);
447+
}
448+
authcFilters.add(JWT_AUTHC);
445449
// the order does matter - first match wins
446450
FilterChainBuilder filterChainBuilder = new FilterChainBuilder()
447451
.setRestFilters(SSL, NO_SESSION_CREATION, CORS, NO_CACHE)

0 commit comments

Comments
 (0)