Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import java.net.InetAddress;
import java.net.ServerSocket;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
Expand Down Expand Up @@ -137,9 +138,17 @@ public void cleanUpIdpEnv() {
idpContainer.stop();
idpContainer = null;
}
if (miniHS2 != null) {
if (miniHS2 != null && miniHS2.isStarted()) {
miniHS2.stop();
}
HiveSamlAuthTokenGenerator.shutdown();
}

private static ISAMLAuthTokenGenerator createTokenGenerator(String tokenTtl) {
HiveSamlAuthTokenGenerator.shutdown();
HiveConf conf = new HiveConf();
conf.setVar(ConfVars.HIVE_SERVER2_SAML_CALLBACK_TOKEN_TTL, tokenTtl);
return HiveSamlAuthTokenGenerator.get(conf);
}

private void setupIDP(boolean useSignedAssertions, String authMode) throws Exception {
Expand Down Expand Up @@ -554,6 +563,86 @@ public void testTokenReuse() throws Exception {
}
}

@Test
public void testValidTokenRoundTrip() throws Exception {
ISAMLAuthTokenGenerator tokenGenerator = createTokenGenerator("30s");
String token = tokenGenerator.get("alice", "relay-state-1");
assertEquals("alice", tokenGenerator.validate(token));
}

@Test
public void testForgedSignatureRejected() throws Exception {
ISAMLAuthTokenGenerator tokenGenerator = createTokenGenerator("30s");
String forgedPayload = "u=alice;id=1337;time=" + System.currentTimeMillis()
+ ";rs=deadbeef;sg=bogus";
try {
String forgedToken = Base64.getEncoder().encodeToString(forgedPayload.getBytes(StandardCharsets.UTF_8));
tokenGenerator.validate(forgedToken);
fail("Expected forged token to be rejected");
} catch (HttpSamlAuthenticationException e) {
assertEquals("Token could not be verified", e.getMessage());
}
}

@Test
public void testInvalidTokenRejected() throws Exception {
ISAMLAuthTokenGenerator tokenGenerator = createTokenGenerator("30s");
try {
tokenGenerator.validate("notAValidToken");
fail("Expected malformed base64 token to be rejected");
} catch (HttpSamlAuthenticationException e) {
assertEquals("Invalid token", e.getMessage());
}
String invalidStructure = Base64.getEncoder().encodeToString("foo".getBytes());
try {
tokenGenerator.validate(invalidStructure);
fail("Expected invalid token structure to be rejected");
} catch (HttpSamlAuthenticationException e) {
assertEquals("Invalid token", e.getMessage());
}
}

@Test
public void testExpiredTokenRejected() throws Exception {
ISAMLAuthTokenGenerator tokenGenerator = createTokenGenerator("1s");
String token = tokenGenerator.get("alice", "relay-state-1");
Thread.sleep(1100);
try {
tokenGenerator.validate(token);
fail("Expected expired token to be rejected");
} catch (HttpSamlAuthenticationException e) {
assertEquals("Token is expired", e.getMessage());
}
}

@Test
public void testParseHandlesBase64PaddingInSignature() {
Map<String, String> kv = new HashMap<>();
String token = "u=alice;id=1;time=1000;rs=rs1;sg=YWJjZA==";
assertTrue(HiveSamlAuthTokenGenerator.parse(token, kv));
assertEquals("alice", kv.get("u"));
assertEquals("YWJjZA==", kv.get("sg"));
}

@Test
public void testParseRejectsEncodedBearerToken() {
Map<String, String> kv = new HashMap<>();
String encoded = Base64.getEncoder().encodeToString(
"u=alice;id=1;time=1000;rs=rs1;sg=abc".getBytes());
assertFalse(HiveSamlAuthTokenGenerator.parse(encoded, kv));
}

@Test
public void testParseDecodedTokenFromGenerator() throws Exception {
ISAMLAuthTokenGenerator tokenGenerator = createTokenGenerator("30s");
String encoded = tokenGenerator.get("bob", "relay-42");
String decoded = new String(Base64.getDecoder().decode(encoded), StandardCharsets.UTF_8);
Map<String, String> kv = new HashMap<>();
assertTrue(HiveSamlAuthTokenGenerator.parse(decoded, kv));
assertEquals("bob", kv.get("u"));
assertEquals("relay-42", kv.get(HiveSamlAuthTokenGenerator.RELAY_STATE));
}

private static void assertLoggedInUser(HiveConnection connection, String expectedUser)
throws SQLException {
Statement stmt = connection.createStatement();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package org.apache.hive.service.auth.saml;

import com.google.common.annotations.VisibleForTesting;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
Expand Down Expand Up @@ -78,11 +79,11 @@ public String get(String username, String relayStateKey) {
}

private String encode(String token) {
return Base64.getEncoder().encodeToString(token.getBytes());
return Base64.getEncoder().encodeToString(token.getBytes(StandardCharsets.UTF_8));
}

private String decode(String encodedToken) {
return new String(Base64.getDecoder().decode(encodedToken));
return new String(Base64.getDecoder().decode(encodedToken), StandardCharsets.UTF_8);
}

private String getTokenStr(String username, String id, String timestamp,
Expand All @@ -100,7 +101,7 @@ private String getTokenStr(String username, String id, String timestamp,
private String getSign(String input) {
try {
MessageDigest md = MessageDigest.getInstance("SHA-256");
md.update(input.getBytes());
md.update(input.getBytes(StandardCharsets.UTF_8));
md.update(signatureSecret);
byte[] digest = md.digest();
return Base64.getEncoder().encodeToString(digest);
Expand Down Expand Up @@ -144,7 +145,8 @@ private boolean isExpired(long currentTime, long tokenTime) {
}

private boolean signatureMatches(String origSign, String derivedSign) {
return !MessageDigest.isEqual(origSign.getBytes(), derivedSign.getBytes());
return MessageDigest.isEqual(origSign.getBytes(StandardCharsets.UTF_8),
derivedSign.getBytes(StandardCharsets.UTF_8));
}
Comment thread
saihemanth-cloudera marked this conversation as resolved.

public static boolean parse(String token, Map<String, String> kv) {
Expand All @@ -153,7 +155,7 @@ public static boolean parse(String token, Map<String, String> kv) {
return false;
}
for (String split : splits) {
String[] pair = split.split(SEPARATOR);
String[] pair = split.split(SEPARATOR, 2);
if (pair.length != 2) {
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,8 @@ private String doSamlAuth(HttpServletRequest request, HttpServletResponse respon
LOG.info("Successfully validated the token for user {}", user);
// token is valid; now confirm if the client identifier matches with the relay state.
Map<String, String> keyValues = new HashMap<>();
if (HiveSamlAuthTokenGenerator.parse(token, keyValues)) {
String decodedToken = new String(Base64.getDecoder().decode(token), java.nio.charset.StandardCharsets.UTF_8);
if (HiveSamlAuthTokenGenerator.parse(decodedToken, keyValues)) {
String relayStateKey = keyValues.get(HiveSamlAuthTokenGenerator.RELAY_STATE);
if (!HiveSamlRelayStateStore.get()
.validateClientIdentifier(relayStateKey, clientIdentifier)) {
Expand Down
Loading