diff --git a/itests/hive-unit/src/test/java/org/apache/hive/service/auth/saml/TestHttpSamlAuthentication.java b/itests/hive-unit/src/test/java/org/apache/hive/service/auth/saml/TestHttpSamlAuthentication.java index 7d119e9372c2..f58cf6a0adf5 100644 --- a/itests/hive-unit/src/test/java/org/apache/hive/service/auth/saml/TestHttpSamlAuthentication.java +++ b/itests/hive-unit/src/test/java/org/apache/hive/service/auth/saml/TestHttpSamlAuthentication.java @@ -35,6 +35,7 @@ import java.net.InetAddress; import java.net.ServerSocket; import java.nio.charset.StandardCharsets; +import java.util.Base64; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; @@ -137,9 +138,17 @@ public void cleanUpIdpEnv() { idpContainer.stop(); idpContainer = null; } - if (miniHS2 != null) { + if (miniHS2 != null && miniHS2.isStarted()) { miniHS2.stop(); } + HiveSamlAuthTokenGenerator.shutdown(); + } + + private static ISAMLAuthTokenGenerator createTokenGenerator(String tokenTtl) { + HiveSamlAuthTokenGenerator.shutdown(); + HiveConf conf = new HiveConf(); + conf.setVar(ConfVars.HIVE_SERVER2_SAML_CALLBACK_TOKEN_TTL, tokenTtl); + return HiveSamlAuthTokenGenerator.get(conf); } private void setupIDP(boolean useSignedAssertions, String authMode) throws Exception { @@ -554,6 +563,86 @@ public void testTokenReuse() throws Exception { } } + @Test + public void testValidTokenRoundTrip() throws Exception { + ISAMLAuthTokenGenerator tokenGenerator = createTokenGenerator("30s"); + String token = tokenGenerator.get("alice", "relay-state-1"); + assertEquals("alice", tokenGenerator.validate(token)); + } + + @Test + public void testForgedSignatureRejected() throws Exception { + ISAMLAuthTokenGenerator tokenGenerator = createTokenGenerator("30s"); + String forgedPayload = "u=alice;id=1337;time=" + System.currentTimeMillis() + + ";rs=deadbeef;sg=bogus"; + try { + String forgedToken = Base64.getEncoder().encodeToString(forgedPayload.getBytes(StandardCharsets.UTF_8)); + tokenGenerator.validate(forgedToken); + fail("Expected forged token to be rejected"); + } catch (HttpSamlAuthenticationException e) { + assertEquals("Token could not be verified", e.getMessage()); + } + } + + @Test + public void testInvalidTokenRejected() throws Exception { + ISAMLAuthTokenGenerator tokenGenerator = createTokenGenerator("30s"); + try { + tokenGenerator.validate("notAValidToken"); + fail("Expected malformed base64 token to be rejected"); + } catch (HttpSamlAuthenticationException e) { + assertEquals("Invalid token", e.getMessage()); + } + String invalidStructure = Base64.getEncoder().encodeToString("foo".getBytes()); + try { + tokenGenerator.validate(invalidStructure); + fail("Expected invalid token structure to be rejected"); + } catch (HttpSamlAuthenticationException e) { + assertEquals("Invalid token", e.getMessage()); + } + } + + @Test + public void testExpiredTokenRejected() throws Exception { + ISAMLAuthTokenGenerator tokenGenerator = createTokenGenerator("1s"); + String token = tokenGenerator.get("alice", "relay-state-1"); + Thread.sleep(1100); + try { + tokenGenerator.validate(token); + fail("Expected expired token to be rejected"); + } catch (HttpSamlAuthenticationException e) { + assertEquals("Token is expired", e.getMessage()); + } + } + + @Test + public void testParseHandlesBase64PaddingInSignature() { + Map kv = new HashMap<>(); + String token = "u=alice;id=1;time=1000;rs=rs1;sg=YWJjZA=="; + assertTrue(HiveSamlAuthTokenGenerator.parse(token, kv)); + assertEquals("alice", kv.get("u")); + assertEquals("YWJjZA==", kv.get("sg")); + } + + @Test + public void testParseRejectsEncodedBearerToken() { + Map kv = new HashMap<>(); + String encoded = Base64.getEncoder().encodeToString( + "u=alice;id=1;time=1000;rs=rs1;sg=abc".getBytes()); + assertFalse(HiveSamlAuthTokenGenerator.parse(encoded, kv)); + } + + @Test + public void testParseDecodedTokenFromGenerator() throws Exception { + ISAMLAuthTokenGenerator tokenGenerator = createTokenGenerator("30s"); + String encoded = tokenGenerator.get("bob", "relay-42"); + String decoded = new String(Base64.getDecoder().decode(encoded), StandardCharsets.UTF_8); + Map kv = new HashMap<>(); + assertTrue(HiveSamlAuthTokenGenerator.parse(decoded, kv)); + assertEquals("bob", kv.get("u")); + assertEquals("relay-42", kv.get(HiveSamlAuthTokenGenerator.RELAY_STATE)); + } + private static void assertLoggedInUser(HiveConnection connection, String expectedUser) throws SQLException { Statement stmt = connection.createStatement(); diff --git a/service/src/java/org/apache/hive/service/auth/saml/HiveSamlAuthTokenGenerator.java b/service/src/java/org/apache/hive/service/auth/saml/HiveSamlAuthTokenGenerator.java index 51cf646b01ee..d9060f44dfcc 100644 --- a/service/src/java/org/apache/hive/service/auth/saml/HiveSamlAuthTokenGenerator.java +++ b/service/src/java/org/apache/hive/service/auth/saml/HiveSamlAuthTokenGenerator.java @@ -19,6 +19,7 @@ package org.apache.hive.service.auth.saml; import com.google.common.annotations.VisibleForTesting; +import java.nio.charset.StandardCharsets; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; import java.security.SecureRandom; @@ -78,11 +79,11 @@ public String get(String username, String relayStateKey) { } private String encode(String token) { - return Base64.getEncoder().encodeToString(token.getBytes()); + return Base64.getEncoder().encodeToString(token.getBytes(StandardCharsets.UTF_8)); } private String decode(String encodedToken) { - return new String(Base64.getDecoder().decode(encodedToken)); + return new String(Base64.getDecoder().decode(encodedToken), StandardCharsets.UTF_8); } private String getTokenStr(String username, String id, String timestamp, @@ -100,7 +101,7 @@ private String getTokenStr(String username, String id, String timestamp, private String getSign(String input) { try { MessageDigest md = MessageDigest.getInstance("SHA-256"); - md.update(input.getBytes()); + md.update(input.getBytes(StandardCharsets.UTF_8)); md.update(signatureSecret); byte[] digest = md.digest(); return Base64.getEncoder().encodeToString(digest); @@ -144,7 +145,8 @@ private boolean isExpired(long currentTime, long tokenTime) { } private boolean signatureMatches(String origSign, String derivedSign) { - return !MessageDigest.isEqual(origSign.getBytes(), derivedSign.getBytes()); + return MessageDigest.isEqual(origSign.getBytes(StandardCharsets.UTF_8), + derivedSign.getBytes(StandardCharsets.UTF_8)); } public static boolean parse(String token, Map kv) { @@ -153,7 +155,7 @@ public static boolean parse(String token, Map kv) { return false; } for (String split : splits) { - String[] pair = split.split(SEPARATOR); + String[] pair = split.split(SEPARATOR, 2); if (pair.length != 2) { return false; } diff --git a/service/src/java/org/apache/hive/service/cli/thrift/ThriftHttpServlet.java b/service/src/java/org/apache/hive/service/cli/thrift/ThriftHttpServlet.java index 67ebb605d901..43f963107fae 100644 --- a/service/src/java/org/apache/hive/service/cli/thrift/ThriftHttpServlet.java +++ b/service/src/java/org/apache/hive/service/cli/thrift/ThriftHttpServlet.java @@ -382,7 +382,8 @@ private String doSamlAuth(HttpServletRequest request, HttpServletResponse respon LOG.info("Successfully validated the token for user {}", user); // token is valid; now confirm if the client identifier matches with the relay state. Map keyValues = new HashMap<>(); - if (HiveSamlAuthTokenGenerator.parse(token, keyValues)) { + String decodedToken = new String(Base64.getDecoder().decode(token), java.nio.charset.StandardCharsets.UTF_8); + if (HiveSamlAuthTokenGenerator.parse(decodedToken, keyValues)) { String relayStateKey = keyValues.get(HiveSamlAuthTokenGenerator.RELAY_STATE); if (!HiveSamlRelayStateStore.get() .validateClientIdentifier(relayStateKey, clientIdentifier)) {