diff --git a/java/driver/flight-sql/pom.xml b/java/driver/flight-sql/pom.xml
index e0bd881855..78b84fa249 100644
--- a/java/driver/flight-sql/pom.xml
+++ b/java/driver/flight-sql/pom.xml
@@ -80,6 +80,13 @@
flight-sql-jdbc-core
+
+
+ com.nimbusds
+ oauth2-oidc-sdk
+ 11.20.1
+
+
org.checkerframework
diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnection.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnection.java
index f099cb64c8..f78ebb0245 100644
--- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnection.java
+++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnection.java
@@ -38,6 +38,7 @@
import org.apache.arrow.adbc.core.AdbcStatement;
import org.apache.arrow.adbc.core.AdbcStatusCode;
import org.apache.arrow.adbc.core.BulkIngestMode;
+import org.apache.arrow.adbc.driver.flightsql.oauth.FlightSqlOAuthCredentialWriter;
import org.apache.arrow.adbc.sql.SqlQuirks;
import org.apache.arrow.flight.CallOption;
import org.apache.arrow.flight.FlightCallHeaders;
@@ -58,6 +59,10 @@
import org.checkerframework.checker.nullness.qual.Nullable;
public class FlightSqlConnection implements AdbcConnection {
+ private static final String AUTH_HEADER_CONFLICT_ERROR =
+ "[Flight SQL] Authentication conflict: Use either Authorization header or OAuth options, "
+ + "or username/password parameters";
+
private final BufferAllocator allocator;
private final AtomicInteger counter = new AtomicInteger(0);
private final FlightSqlClientWithCallOptions client;
@@ -107,9 +112,18 @@ public class FlightSqlConnection implements AdbcConnection {
.build(
loc -> {
FlightClient client = buildClient(loc);
- client.handshake(callOptions);
- return new FlightSqlClientWithCallOptions(
- new FlightSqlClient(client), callOptions);
+ try {
+ client.handshake(callOptions);
+ return new FlightSqlClientWithCallOptions(
+ new FlightSqlClient(client), callOptions);
+ } catch (RuntimeException ex) {
+ try {
+ client.close();
+ } catch (Exception closeEx) {
+ ex.addSuppressed(closeEx);
+ }
+ throw ex;
+ }
});
this.clientCache.put(location, this.client);
}
@@ -262,54 +276,91 @@ private FlightClient createInitialConnection(
}
}
- // Build the client using the above properties.
- final FlightClient client = buildClient(location);
-
// Add user-specified headers.
ArrayList options = new ArrayList<>();
final FlightCallHeaders callHeaders = new FlightCallHeaders();
- for (Map.Entry parameter : parameters.entrySet()) {
- if (parameter.getKey().startsWith(FlightSqlConnectionProperties.RPC_CALL_HEADER_PREFIX)) {
- String userHeaderName =
- parameter
- .getKey()
- .substring(FlightSqlConnectionProperties.RPC_CALL_HEADER_PREFIX.length());
-
- if (parameter.getValue() instanceof String) {
- callHeaders.insert(userHeaderName, (String) parameter.getValue());
- } else if (parameter.getValue() instanceof byte[]) {
- callHeaders.insert(userHeaderName, (byte[]) parameter.getValue());
- } else {
- throw new AdbcException(
- String.format(
- "Header values must be String or byte[]. The header failing was %s.",
- parameter.getKey()),
- null,
- AdbcStatusCode.INVALID_ARGUMENT,
- null,
- 0);
+ String authorizationHeader = null;
+ String username = null;
+ String password = null;
+ String oauthFlow = null;
+ if (parameters != null) {
+ for (Map.Entry parameter : parameters.entrySet()) {
+ if (parameter.getKey().startsWith(FlightSqlConnectionProperties.RPC_CALL_HEADER_PREFIX)) {
+ String userHeaderName =
+ parameter
+ .getKey()
+ .substring(FlightSqlConnectionProperties.RPC_CALL_HEADER_PREFIX.length());
+
+ if (parameter.getValue() instanceof String) {
+ callHeaders.insert(userHeaderName, (String) parameter.getValue());
+ } else if (parameter.getValue() instanceof byte[]) {
+ callHeaders.insert(userHeaderName, (byte[]) parameter.getValue());
+ } else {
+ throw new AdbcException(
+ String.format(
+ "Header values must be String or byte[]. The header failing was %s.",
+ parameter.getKey()),
+ null,
+ AdbcStatusCode.INVALID_ARGUMENT,
+ null,
+ 0);
+ }
}
}
+
+ authorizationHeader = FlightSqlConnectionProperties.AUTHORIZATION_HEADER.get(parameters);
+ username = AdbcDriver.PARAM_USERNAME.get(parameters);
+ password = AdbcDriver.PARAM_PASSWORD.get(parameters);
+ oauthFlow = FlightSqlConnectionProperties.OAUTH_FLOW.get(parameters);
+ }
+
+ if (authorizationHeader != null) {
+ callHeaders.insert("authorization", authorizationHeader);
}
options.add(new HeaderCallOption(callHeaders));
- // Test the connection.
- String username = AdbcDriver.PARAM_USERNAME.get(parameters);
- String password = AdbcDriver.PARAM_PASSWORD.get(parameters);
- if (username != null && password != null) {
- Optional bearerToken =
- client.authenticateBasicToken(username, password);
- options.add(
- bearerToken.orElse(
- new CredentialCallOption(new BasicAuthCredentialWriter(username, password))));
- this.callOptions = options.toArray(new CallOption[0]);
- } else {
- this.callOptions = options.toArray(new CallOption[0]);
- client.handshake(this.callOptions);
+ final boolean hasAuthorizationHeader = authorizationHeader != null;
+ final boolean hasUsernamePassword = username != null || password != null;
+ final boolean hasOauth = oauthFlow != null;
+
+ if ((hasAuthorizationHeader && (hasUsernamePassword || hasOauth))
+ || (hasUsernamePassword && hasOauth)) {
+ throw AdbcException.invalidArgument(AUTH_HEADER_CONFLICT_ERROR);
}
- return client;
+ // Build the client using the above properties.
+ final FlightClient client = buildClient(location);
+
+ try {
+ // Test the connection.
+ if (hasOauth) {
+ final FlightSqlOAuthCredentialWriter oauthCredentialWriter =
+ FlightSqlOAuthCredentialWriter.create(parameters);
+ oauthCredentialWriter.prefetchToken();
+ options.add(new CredentialCallOption(oauthCredentialWriter));
+ this.callOptions = options.toArray(new CallOption[0]);
+ client.handshake(this.callOptions);
+ } else if (username != null && password != null) {
+ Optional bearerToken =
+ client.authenticateBasicToken(username, password);
+ options.add(
+ bearerToken.orElse(
+ new CredentialCallOption(new BasicAuthCredentialWriter(username, password))));
+ this.callOptions = options.toArray(new CallOption[0]);
+ } else {
+ this.callOptions = options.toArray(new CallOption[0]);
+ client.handshake(this.callOptions);
+ }
+ return client;
+ } catch (AdbcException | RuntimeException ex) {
+ try {
+ client.close();
+ } catch (Exception closeEx) {
+ ex.addSuppressed(closeEx);
+ }
+ throw ex;
+ }
}
/** Returns a yet-to-be authenticated FlightClient */
diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnectionProperties.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnectionProperties.java
index 4ab1955a1b..a6ed5f2650 100644
--- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnectionProperties.java
+++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnectionProperties.java
@@ -22,10 +22,33 @@
/** Defines connection options that are used by the FlightSql driver. */
public interface FlightSqlConnectionProperties {
+ TypedKey AUTHORIZATION_HEADER =
+ new TypedKey<>("adbc.flight.sql.authorization_header", String.class);
TypedKey MTLS_CERT_CHAIN =
new TypedKey<>("adbc.flight.sql.client_option.mtls_cert_chain", InputStream.class);
TypedKey MTLS_PRIVATE_KEY =
new TypedKey<>("adbc.flight.sql.client_option.mtls_private_key", InputStream.class);
+ TypedKey OAUTH_FLOW = new TypedKey<>("adbc.flight.sql.oauth.flow", String.class);
+ TypedKey OAUTH_TOKEN_URI =
+ new TypedKey<>("adbc.flight.sql.oauth.token_uri", String.class);
+ TypedKey OAUTH_CLIENT_ID =
+ new TypedKey<>("adbc.flight.sql.oauth.client_id", String.class);
+ TypedKey OAUTH_CLIENT_SECRET =
+ new TypedKey<>("adbc.flight.sql.oauth.client_secret", String.class);
+ TypedKey OAUTH_SCOPE = new TypedKey<>("adbc.flight.sql.oauth.scope", String.class);
+ TypedKey OAUTH_RESOURCE = new TypedKey<>("adbc.flight.sql.oauth.resource", String.class);
+ TypedKey OAUTH_EXCHANGE_SUBJECT_TOKEN =
+ new TypedKey<>("adbc.flight.sql.oauth.exchange.subject_token", String.class);
+ TypedKey OAUTH_EXCHANGE_SUBJECT_TOKEN_TYPE =
+ new TypedKey<>("adbc.flight.sql.oauth.exchange.subject_token_type", String.class);
+ TypedKey OAUTH_EXCHANGE_ACTOR_TOKEN =
+ new TypedKey<>("adbc.flight.sql.oauth.exchange.actor_token", String.class);
+ TypedKey OAUTH_EXCHANGE_ACTOR_TOKEN_TYPE =
+ new TypedKey<>("adbc.flight.sql.oauth.exchange.actor_token_type", String.class);
+ TypedKey OAUTH_EXCHANGE_REQUESTED_TOKEN_TYPE =
+ new TypedKey<>("adbc.flight.sql.oauth.exchange.requested_token_type", String.class);
+ TypedKey OAUTH_EXCHANGE_AUD =
+ new TypedKey<>("adbc.flight.sql.oauth.exchange.aud", String.class);
TypedKey TLS_OVERRIDE_HOSTNAME =
new TypedKey<>("adbc.flight.sql.client_option.tls_override_hostname", String.class);
TypedKey TLS_SKIP_VERIFY =
diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDatabase.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDatabase.java
index af8221d6b0..706f69ff9f 100644
--- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDatabase.java
+++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDatabase.java
@@ -64,6 +64,13 @@ public AdbcConnection connect() throws AdbcException {
adbcException.addSuppressed(e);
}
throw adbcException;
+ } catch (AdbcException ex) {
+ try {
+ AutoCloseables.close(connectionAllocator);
+ } catch (Exception e) {
+ ex.addSuppressed(e);
+ }
+ throw ex;
} catch (Exception ex) {
AdbcException adbcException = FlightSqlDriverUtil.fromGeneralException(ex);
try {
diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/oauth/FlightSqlOAuthConfiguration.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/oauth/FlightSqlOAuthConfiguration.java
new file mode 100644
index 0000000000..203f524bb2
--- /dev/null
+++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/oauth/FlightSqlOAuthConfiguration.java
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.adbc.driver.flightsql.oauth;
+
+import com.nimbusds.oauth2.sdk.GrantType;
+import com.nimbusds.oauth2.sdk.ParseException;
+import com.nimbusds.oauth2.sdk.token.TokenTypeURI;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.util.Map;
+import org.apache.arrow.adbc.core.AdbcException;
+import org.apache.arrow.adbc.core.TypedKey;
+import org.apache.arrow.adbc.driver.flightsql.FlightSqlConnectionProperties;
+import org.checkerframework.checker.nullness.qual.Nullable;
+
+final class FlightSqlOAuthConfiguration {
+ private final GrantType flowType;
+ private final @Nullable String tokenUri;
+ private final @Nullable String clientId;
+ private final @Nullable String clientSecret;
+ private final @Nullable String scope;
+ private final @Nullable String subjectToken;
+ private final @Nullable TokenTypeURI subjectTokenType;
+ private final @Nullable String actorToken;
+ private final @Nullable TokenTypeURI actorTokenType;
+ private final @Nullable TokenTypeURI requestedTokenType;
+ private final @Nullable String audience;
+ private final @Nullable URI resource;
+
+ private FlightSqlOAuthConfiguration(
+ GrantType flowType,
+ @Nullable String tokenUri,
+ @Nullable String clientId,
+ @Nullable String clientSecret,
+ @Nullable String scope,
+ @Nullable String subjectToken,
+ @Nullable TokenTypeURI subjectTokenType,
+ @Nullable String actorToken,
+ @Nullable TokenTypeURI actorTokenType,
+ @Nullable TokenTypeURI requestedTokenType,
+ @Nullable String audience,
+ @Nullable URI resource) {
+ this.flowType = flowType;
+ this.tokenUri = tokenUri;
+ this.clientId = clientId;
+ this.clientSecret = clientSecret;
+ this.scope = scope;
+ this.subjectToken = subjectToken;
+ this.subjectTokenType = subjectTokenType;
+ this.actorToken = actorToken;
+ this.actorTokenType = actorTokenType;
+ this.requestedTokenType = requestedTokenType;
+ this.audience = audience;
+ this.resource = resource;
+ }
+
+ static FlightSqlOAuthConfiguration from(Map parameters) throws AdbcException {
+ final GrantType flowType =
+ parseGrantType(requireOption(parameters, FlightSqlConnectionProperties.OAUTH_FLOW));
+ final @Nullable String clientId = FlightSqlConnectionProperties.OAUTH_CLIENT_ID.get(parameters);
+ final @Nullable String clientSecret =
+ FlightSqlConnectionProperties.OAUTH_CLIENT_SECRET.get(parameters);
+ final @Nullable String subjectToken =
+ FlightSqlConnectionProperties.OAUTH_EXCHANGE_SUBJECT_TOKEN.get(parameters);
+ final @Nullable String subjectTokenTypeValue =
+ FlightSqlConnectionProperties.OAUTH_EXCHANGE_SUBJECT_TOKEN_TYPE.get(parameters);
+ final @Nullable String actorToken =
+ FlightSqlConnectionProperties.OAUTH_EXCHANGE_ACTOR_TOKEN.get(parameters);
+ final @Nullable String actorTokenTypeValue =
+ FlightSqlConnectionProperties.OAUTH_EXCHANGE_ACTOR_TOKEN_TYPE.get(parameters);
+
+ final boolean clientCredentials = GrantType.CLIENT_CREDENTIALS.equals(flowType);
+ final boolean tokenExchange = GrantType.TOKEN_EXCHANGE.equals(flowType);
+ final @Nullable String tokenUri =
+ (clientCredentials || tokenExchange)
+ ? requireOption(parameters, FlightSqlConnectionProperties.OAUTH_TOKEN_URI)
+ : null;
+
+ if (clientCredentials) {
+ requireOption(parameters, FlightSqlConnectionProperties.OAUTH_CLIENT_ID);
+ requireOption(parameters, FlightSqlConnectionProperties.OAUTH_CLIENT_SECRET);
+ }
+ if (tokenExchange) {
+ requireOption(parameters, FlightSqlConnectionProperties.OAUTH_EXCHANGE_SUBJECT_TOKEN);
+ requireOption(parameters, FlightSqlConnectionProperties.OAUTH_EXCHANGE_SUBJECT_TOKEN_TYPE);
+ if ((actorToken == null) != (actorTokenTypeValue == null)) {
+ throw AdbcException.invalidArgument(
+ "[Flight SQL] token exchange grant requires "
+ + FlightSqlConnectionProperties.OAUTH_EXCHANGE_ACTOR_TOKEN_TYPE.getKey()
+ + " when "
+ + FlightSqlConnectionProperties.OAUTH_EXCHANGE_ACTOR_TOKEN.getKey()
+ + " is provided");
+ }
+ if ((clientId == null) != (clientSecret == null)) {
+ throw AdbcException.invalidArgument(
+ "[Flight SQL] token exchange grant requires both "
+ + FlightSqlConnectionProperties.OAUTH_CLIENT_ID.getKey()
+ + " and "
+ + FlightSqlConnectionProperties.OAUTH_CLIENT_SECRET.getKey()
+ + " when client credentials are provided");
+ }
+ }
+
+ final @Nullable String resource =
+ tokenExchange ? FlightSqlConnectionProperties.OAUTH_RESOURCE.get(parameters) : null;
+ return new FlightSqlOAuthConfiguration(
+ flowType,
+ tokenUri,
+ clientId,
+ clientSecret,
+ FlightSqlConnectionProperties.OAUTH_SCOPE.get(parameters),
+ subjectToken,
+ parseTokenTypeUri(
+ subjectTokenTypeValue,
+ FlightSqlConnectionProperties.OAUTH_EXCHANGE_SUBJECT_TOKEN_TYPE),
+ actorToken,
+ parseTokenTypeUri(
+ FlightSqlConnectionProperties.OAUTH_EXCHANGE_ACTOR_TOKEN_TYPE.get(parameters),
+ FlightSqlConnectionProperties.OAUTH_EXCHANGE_ACTOR_TOKEN_TYPE),
+ parseTokenTypeUri(
+ FlightSqlConnectionProperties.OAUTH_EXCHANGE_REQUESTED_TOKEN_TYPE.get(parameters),
+ FlightSqlConnectionProperties.OAUTH_EXCHANGE_REQUESTED_TOKEN_TYPE),
+ FlightSqlConnectionProperties.OAUTH_EXCHANGE_AUD.get(parameters),
+ parseResource(resource));
+ }
+
+ GrantType flowType() {
+ return flowType;
+ }
+
+ @Nullable String tokenUri() {
+ return tokenUri;
+ }
+
+ @Nullable String clientId() {
+ return clientId;
+ }
+
+ @Nullable String clientSecret() {
+ return clientSecret;
+ }
+
+ @Nullable String scope() {
+ return scope;
+ }
+
+ @Nullable String subjectToken() {
+ return subjectToken;
+ }
+
+ @Nullable TokenTypeURI subjectTokenType() {
+ return subjectTokenType;
+ }
+
+ @Nullable String actorToken() {
+ return actorToken;
+ }
+
+ @Nullable TokenTypeURI actorTokenType() {
+ return actorTokenType;
+ }
+
+ @Nullable TokenTypeURI requestedTokenType() {
+ return requestedTokenType;
+ }
+
+ @Nullable String audience() {
+ return audience;
+ }
+
+ @Nullable URI resource() {
+ return resource;
+ }
+
+ private static String requireOption(Map parameters, TypedKey option)
+ throws AdbcException {
+ final String value = option.get(parameters);
+ if (value == null) {
+ throw AdbcException.invalidArgument("[Flight SQL] OAuth flow requires " + option.getKey());
+ }
+ return value;
+ }
+
+ private static GrantType parseGrantType(String value) throws AdbcException {
+ try {
+ return GrantType.parse(value);
+ } catch (ParseException e) {
+ throw AdbcException.invalidArgument("[Flight SQL] invalid OAuth flow: " + value).withCause(e);
+ }
+ }
+
+ private static @Nullable TokenTypeURI parseTokenTypeUri(
+ @Nullable String value, TypedKey option) throws AdbcException {
+ if (value == null) {
+ return null;
+ }
+ try {
+ return TokenTypeURI.parse(value);
+ } catch (ParseException e) {
+ throw AdbcException.invalidArgument(
+ "[Flight SQL] invalid OAuth token type for " + option.getKey() + ": " + value)
+ .withCause(e);
+ }
+ }
+
+ private static @Nullable URI parseResource(@Nullable String resource) throws AdbcException {
+ if (resource == null) {
+ return null;
+ }
+ try {
+ return new URI(resource);
+ } catch (URISyntaxException e) {
+ throw AdbcException.invalidArgument(
+ "[Flight SQL] token exchange grant requires a valid URI for "
+ + FlightSqlConnectionProperties.OAUTH_RESOURCE.getKey())
+ .withCause(e);
+ }
+ }
+}
diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/oauth/FlightSqlOAuthCredentialWriter.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/oauth/FlightSqlOAuthCredentialWriter.java
new file mode 100644
index 0000000000..7a9b5ad0cd
--- /dev/null
+++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/oauth/FlightSqlOAuthCredentialWriter.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.adbc.driver.flightsql.oauth;
+
+import java.sql.SQLException;
+import java.util.Map;
+import java.util.function.Consumer;
+import org.apache.arrow.adbc.core.AdbcException;
+import org.apache.arrow.driver.jdbc.client.oauth.OAuthTokenProvider;
+import org.apache.arrow.flight.CallHeaders;
+
+public final class FlightSqlOAuthCredentialWriter implements Consumer {
+ private final OAuthTokenProvider tokenProvider;
+
+ private FlightSqlOAuthCredentialWriter(OAuthTokenProvider tokenProvider) {
+ this.tokenProvider = tokenProvider;
+ }
+
+ public static FlightSqlOAuthCredentialWriter create(Map parameters)
+ throws AdbcException {
+ final FlightSqlOAuthConfiguration configuration =
+ FlightSqlOAuthConfiguration.from(parameters);
+ return new FlightSqlOAuthCredentialWriter(FlightSqlOAuthTokenProviders.create(configuration));
+ }
+
+ public void prefetchToken() throws AdbcException {
+ currentAuthorizationValue();
+ }
+
+ @Override
+ public void accept(CallHeaders headers) {
+ try {
+ headers.insert("authorization", currentAuthorizationValue());
+ } catch (AdbcException e) {
+ throw new IllegalStateException(e.getMessage(), e);
+ }
+ }
+
+ private String currentAuthorizationValue() throws AdbcException {
+ try {
+ return "Bearer " + tokenProvider.getValidToken();
+ } catch (SQLException e) {
+ throw AdbcException.io("[Flight SQL] OAuth token request failed: " + e.getMessage())
+ .withCause(e);
+ }
+ }
+}
diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/oauth/FlightSqlOAuthTokenProviders.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/oauth/FlightSqlOAuthTokenProviders.java
new file mode 100644
index 0000000000..5f1fd0187f
--- /dev/null
+++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/oauth/FlightSqlOAuthTokenProviders.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.adbc.driver.flightsql.oauth;
+
+import com.nimbusds.oauth2.sdk.GrantType;
+import java.util.Objects;
+import org.apache.arrow.adbc.core.AdbcException;
+import org.apache.arrow.driver.jdbc.client.oauth.OAuthTokenProvider;
+import org.apache.arrow.driver.jdbc.client.oauth.OAuthTokenProviders;
+
+final class FlightSqlOAuthTokenProviders {
+ private FlightSqlOAuthTokenProviders() {}
+
+ static OAuthTokenProvider create(FlightSqlOAuthConfiguration configuration)
+ throws AdbcException {
+ final GrantType flowType = configuration.flowType();
+ if (GrantType.CLIENT_CREDENTIALS.equals(flowType)) {
+ return createClientCredentialsProvider(configuration);
+ }
+ if (GrantType.TOKEN_EXCHANGE.equals(flowType)) {
+ return createTokenExchangeProvider(configuration);
+ }
+ throw AdbcException.notImplemented("[Flight SQL] oauth flow not implemented: " + flowType);
+ }
+
+ private static OAuthTokenProvider createClientCredentialsProvider(
+ FlightSqlOAuthConfiguration configuration) throws AdbcException {
+ try {
+ final OAuthTokenProviders.ClientCredentialsBuilder builder =
+ OAuthTokenProviders.clientCredentials()
+ .tokenUri(Objects.requireNonNull(configuration.tokenUri()))
+ .clientId(Objects.requireNonNull(configuration.clientId()))
+ .clientSecret(Objects.requireNonNull(configuration.clientSecret()));
+
+ final String scope = configuration.scope();
+ if (scope != null) {
+ builder.scope(scope);
+ }
+ return builder.build();
+ } catch (RuntimeException e) {
+ throw AdbcException.invalidArgument(
+ "[Flight SQL] Invalid OAuth client credentials configuration: " + e.getMessage())
+ .withCause(e);
+ }
+ }
+
+ private static OAuthTokenProvider createTokenExchangeProvider(
+ FlightSqlOAuthConfiguration configuration) throws AdbcException {
+ try {
+ final OAuthTokenProviders.TokenExchangeBuilder builder =
+ OAuthTokenProviders.tokenExchange()
+ .tokenUri(Objects.requireNonNull(configuration.tokenUri()))
+ .subjectToken(Objects.requireNonNull(configuration.subjectToken()))
+ .subjectTokenType(
+ Objects.requireNonNull(configuration.subjectTokenType()).toString());
+
+ if (configuration.actorToken() != null) {
+ builder
+ .actorToken(configuration.actorToken())
+ .actorTokenType(Objects.requireNonNull(configuration.actorTokenType()).toString());
+ }
+ if (configuration.clientId() != null) {
+ builder.clientCredentials(
+ configuration.clientId(), Objects.requireNonNull(configuration.clientSecret()));
+ }
+ if (configuration.audience() != null) {
+ builder.audience(configuration.audience());
+ }
+ if (configuration.requestedTokenType() != null) {
+ builder.requestedTokenType(configuration.requestedTokenType().toString());
+ }
+ if (configuration.resource() != null) {
+ builder.resource(configuration.resource());
+ }
+ if (configuration.scope() != null) {
+ builder.scope(configuration.scope());
+ }
+
+ return builder.build();
+ } catch (RuntimeException e) {
+ throw AdbcException.invalidArgument(
+ "[Flight SQL] Invalid OAuth token exchange configuration: " + e.getMessage())
+ .withCause(e);
+ }
+ }
+}
diff --git a/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/OAuthTest.java b/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/OAuthTest.java
new file mode 100644
index 0000000000..c19f432bbc
--- /dev/null
+++ b/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/OAuthTest.java
@@ -0,0 +1,510 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.arrow.adbc.driver.flightsql;
+
+import static org.junit.Assert.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import com.nimbusds.oauth2.sdk.GrantType;
+import com.nimbusds.oauth2.sdk.http.HTTPRequest;
+import com.nimbusds.oauth2.sdk.token.TokenTypeURI;
+import com.sun.net.httpserver.Headers;
+import com.sun.net.httpserver.HttpExchange;
+import com.sun.net.httpserver.HttpHandler;
+import com.sun.net.httpserver.HttpServer;
+import com.sun.net.httpserver.HttpsConfigurator;
+import com.sun.net.httpserver.HttpsServer;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.net.InetSocketAddress;
+import java.net.URLDecoder;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.security.KeyFactory;
+import java.security.KeyStore;
+import java.security.PrivateKey;
+import java.security.SecureRandom;
+import java.security.cert.Certificate;
+import java.security.cert.CertificateFactory;
+import java.security.cert.X509Certificate;
+import java.security.spec.PKCS8EncodedKeySpec;
+import java.util.ArrayList;
+import java.util.Base64;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import javax.net.ssl.HostnameVerifier;
+import javax.net.ssl.HttpsURLConnection;
+import javax.net.ssl.KeyManagerFactory;
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLSocketFactory;
+import javax.net.ssl.TrustManagerFactory;
+import org.apache.arrow.adbc.core.AdbcConnection;
+import org.apache.arrow.adbc.core.AdbcDatabase;
+import org.apache.arrow.adbc.core.AdbcDriver;
+import org.apache.arrow.adbc.core.AdbcException;
+import org.apache.arrow.adbc.core.AdbcInfoCode;
+import org.apache.arrow.adbc.core.AdbcStatusCode;
+import org.apache.arrow.adbc.drivermanager.AdbcDriverManager;
+import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer;
+import org.apache.arrow.flight.CallHeaders;
+import org.apache.arrow.flight.FlightServer;
+import org.apache.arrow.flight.Location;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.util.AutoCloseables;
+import org.apache.arrow.vector.ipc.ArrowReader;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+public class OAuthTest {
+ private static final String CLIENT_ID = "adbc-client";
+ private static final String CLIENT_SECRET = "adbc-secret";
+ private static final String CLIENT_ACCESS_TOKEN = "client-credentials-token";
+ private static final String EXCHANGE_ACCESS_TOKEN = "token-exchange-token";
+ private static final char[] TRUST_STORE_PASSWORD = "changeit".toCharArray();
+
+ private BufferAllocator allocator;
+ private Map params;
+ private FlightServer server;
+ private HttpServer tokenServer;
+ private AdbcDatabase database;
+ private AdbcConnection connection;
+ private HeaderValidator.Factory headerValidatorFactory;
+ private TokenHandler tokenHandler;
+ private String tokenServerScheme;
+ private String previousTrustStore;
+ private String previousTrustStorePassword;
+ private String previousTrustStoreType;
+ private SSLSocketFactory previousDefaultSslSocketFactory;
+ private HostnameVerifier previousDefaultHostnameVerifier;
+ private final List tempPaths = new ArrayList<>();
+
+ @BeforeEach
+ public void setUp() throws IOException {
+ allocator = new RootAllocator(Long.MAX_VALUE);
+ params = new HashMap<>();
+ headerValidatorFactory = new HeaderValidator.Factory();
+ server =
+ FlightServer.builder()
+ .allocator(allocator)
+ .middleware(HeaderValidator.KEY, headerValidatorFactory)
+ .location(Location.forGrpcInsecure("localhost", 0))
+ .producer(new MockFlightSqlProducer())
+ .build();
+ server.start();
+
+ tokenHandler = new TokenHandler();
+ previousTrustStore = System.getProperty("javax.net.ssl.trustStore");
+ previousTrustStorePassword = System.getProperty("javax.net.ssl.trustStorePassword");
+ previousTrustStoreType = System.getProperty("javax.net.ssl.trustStoreType");
+ previousDefaultSslSocketFactory = HTTPRequest.getDefaultSSLSocketFactory();
+ previousDefaultHostnameVerifier = HTTPRequest.getDefaultHostnameVerifier();
+ startHttpTokenServer();
+
+ params.put(
+ AdbcDriver.PARAM_URI.getKey(), String.format("grpc+tcp://localhost:%d", server.getPort()));
+ }
+
+ @AfterEach
+ public void tearDown() throws Exception {
+ AutoCloseables.close(connection, database, server, allocator);
+ if (tokenServer != null) {
+ tokenServer.stop(0);
+ }
+ restoreJvmTrustStoreConfiguration();
+ for (Path path : tempPaths) {
+ Files.deleteIfExists(path);
+ }
+ connection = null;
+ database = null;
+ server = null;
+ allocator = null;
+ tokenServer = null;
+ }
+
+ @Test
+ public void testClientCredentialsFlow() throws Exception {
+ tokenHandler.accessToken = CLIENT_ACCESS_TOKEN;
+ params.put(
+ FlightSqlConnectionProperties.OAUTH_FLOW.getKey(),
+ GrantType.CLIENT_CREDENTIALS.getValue());
+ params.put(FlightSqlConnectionProperties.OAUTH_TOKEN_URI.getKey(), tokenUri());
+ params.put(FlightSqlConnectionProperties.OAUTH_CLIENT_ID.getKey(), CLIENT_ID);
+ params.put(FlightSqlConnectionProperties.OAUTH_CLIENT_SECRET.getKey(), CLIENT_SECRET);
+ params.put(FlightSqlConnectionProperties.OAUTH_SCOPE.getKey(), "scope-a scope-b");
+
+ connect();
+ requestServerMetadata();
+
+ CallHeaders headers = headerValidatorFactory.getHeadersReceivedAtRequest(0);
+ assertEquals("Bearer " + CLIENT_ACCESS_TOKEN, headers.get("authorization"));
+
+ assertEquals(1, tokenHandler.requestBodies.size());
+ assertEquals("client_credentials", tokenHandler.formValue(0, "grant_type"));
+ assertEquals("scope-a scope-b", tokenHandler.formValue(0, "scope"));
+ assertTrue(tokenHandler.authorizationHeaders.get(0).startsWith("Basic "));
+ }
+
+ @Test
+ public void testTokenExchangeFlow() throws Exception {
+ tokenHandler.accessToken = EXCHANGE_ACCESS_TOKEN;
+ params.put(
+ FlightSqlConnectionProperties.OAUTH_FLOW.getKey(), GrantType.TOKEN_EXCHANGE.getValue());
+ params.put(FlightSqlConnectionProperties.OAUTH_TOKEN_URI.getKey(), tokenUri());
+ params.put(
+ FlightSqlConnectionProperties.OAUTH_EXCHANGE_SUBJECT_TOKEN.getKey(), "subject-token");
+ params.put(
+ FlightSqlConnectionProperties.OAUTH_EXCHANGE_SUBJECT_TOKEN_TYPE.getKey(),
+ TokenTypeURI.ACCESS_TOKEN.toString());
+ params.put(FlightSqlConnectionProperties.OAUTH_EXCHANGE_ACTOR_TOKEN.getKey(), "actor-token");
+ params.put(
+ FlightSqlConnectionProperties.OAUTH_EXCHANGE_ACTOR_TOKEN_TYPE.getKey(),
+ TokenTypeURI.JWT.toString());
+ params.put(
+ FlightSqlConnectionProperties.OAUTH_EXCHANGE_REQUESTED_TOKEN_TYPE.getKey(),
+ TokenTypeURI.ACCESS_TOKEN.toString());
+ params.put(FlightSqlConnectionProperties.OAUTH_EXCHANGE_AUD.getKey(), "flight-service");
+ params.put(
+ FlightSqlConnectionProperties.OAUTH_RESOURCE.getKey(), "https://resource.example.com");
+ params.put(FlightSqlConnectionProperties.OAUTH_SCOPE.getKey(), "profile email");
+ params.put(FlightSqlConnectionProperties.OAUTH_CLIENT_ID.getKey(), CLIENT_ID);
+ params.put(FlightSqlConnectionProperties.OAUTH_CLIENT_SECRET.getKey(), CLIENT_SECRET);
+
+ connect();
+ requestServerMetadata();
+
+ CallHeaders headers = headerValidatorFactory.getHeadersReceivedAtRequest(0);
+ assertEquals("Bearer " + EXCHANGE_ACCESS_TOKEN, headers.get("authorization"));
+
+ assertEquals(1, tokenHandler.requestBodies.size());
+ assertEquals(
+ "urn:ietf:params:oauth:grant-type:token-exchange", tokenHandler.formValue(0, "grant_type"));
+ assertEquals("subject-token", tokenHandler.formValue(0, "subject_token"));
+ assertEquals(
+ TokenTypeURI.ACCESS_TOKEN.toString(), tokenHandler.formValue(0, "subject_token_type"));
+ assertEquals("actor-token", tokenHandler.formValue(0, "actor_token"));
+ assertEquals(TokenTypeURI.JWT.toString(), tokenHandler.formValue(0, "actor_token_type"));
+ assertEquals(
+ TokenTypeURI.ACCESS_TOKEN.toString(), tokenHandler.formValue(0, "requested_token_type"));
+ assertEquals("flight-service", tokenHandler.formValue(0, "audience"));
+ assertEquals("https://resource.example.com", tokenHandler.formValue(0, "resource"));
+ assertEquals("profile email", tokenHandler.formValue(0, "scope"));
+ assertTrue(tokenHandler.authorizationHeaders.get(0).startsWith("Basic "));
+ }
+
+ @Test
+ public void testClientCredentialsFlowWithHttpsTokenEndpointWithoutTrustStore() throws Exception {
+ tokenHandler.accessToken = CLIENT_ACCESS_TOKEN;
+ startHttpsTokenServer();
+ clearJvmTrustStoreConfiguration();
+
+ params.put(
+ FlightSqlConnectionProperties.OAUTH_FLOW.getKey(),
+ GrantType.CLIENT_CREDENTIALS.getValue());
+ params.put(FlightSqlConnectionProperties.OAUTH_TOKEN_URI.getKey(), tokenUri());
+ params.put(FlightSqlConnectionProperties.OAUTH_CLIENT_ID.getKey(), CLIENT_ID);
+ params.put(FlightSqlConnectionProperties.OAUTH_CLIENT_SECRET.getKey(), CLIENT_SECRET);
+
+ AdbcException adbcException = assertThrows(AdbcException.class, this::connect);
+ assertEquals(AdbcStatusCode.IO, adbcException.getStatus());
+ }
+
+ @Test
+ public void testClientCredentialsFlowWithHttpsTokenEndpointUsesJvmTrustStore() throws Exception {
+ tokenHandler.accessToken = CLIENT_ACCESS_TOKEN;
+ startHttpsTokenServer();
+ configureJvmTrustStore(createTrustStorePath());
+
+ params.put(
+ FlightSqlConnectionProperties.OAUTH_FLOW.getKey(),
+ GrantType.CLIENT_CREDENTIALS.getValue());
+ params.put(FlightSqlConnectionProperties.OAUTH_TOKEN_URI.getKey(), tokenUri());
+ params.put(FlightSqlConnectionProperties.OAUTH_CLIENT_ID.getKey(), CLIENT_ID);
+ params.put(FlightSqlConnectionProperties.OAUTH_CLIENT_SECRET.getKey(), CLIENT_SECRET);
+
+ connect();
+ requestServerMetadata();
+
+ CallHeaders headers = headerValidatorFactory.getHeadersReceivedAtRequest(0);
+ assertEquals("Bearer " + CLIENT_ACCESS_TOKEN, headers.get("authorization"));
+ assertEquals(1, tokenHandler.requestBodies.size());
+ assertEquals("client_credentials", tokenHandler.formValue(0, "grant_type"));
+ }
+
+ @Test
+ public void testAuthorizationHeaderConflictsWithOauth() {
+ params.put(
+ FlightSqlConnectionProperties.AUTHORIZATION_HEADER.getKey(), "Bearer existing-token");
+ params.put(
+ FlightSqlConnectionProperties.OAUTH_FLOW.getKey(),
+ GrantType.CLIENT_CREDENTIALS.getValue());
+ params.put(FlightSqlConnectionProperties.OAUTH_TOKEN_URI.getKey(), tokenUri());
+ params.put(FlightSqlConnectionProperties.OAUTH_CLIENT_ID.getKey(), CLIENT_ID);
+ params.put(FlightSqlConnectionProperties.OAUTH_CLIENT_SECRET.getKey(), CLIENT_SECRET);
+
+ AdbcException adbcException = assertThrows(AdbcException.class, this::connect);
+ assertEquals(AdbcStatusCode.INVALID_ARGUMENT, adbcException.getStatus());
+ }
+
+ @Test
+ public void testMissingRequiredParamsClientCredentials() {
+ params.put(
+ FlightSqlConnectionProperties.OAUTH_FLOW.getKey(),
+ GrantType.CLIENT_CREDENTIALS.getValue());
+ params.put(FlightSqlConnectionProperties.OAUTH_TOKEN_URI.getKey(), tokenUri());
+ params.put(FlightSqlConnectionProperties.OAUTH_CLIENT_ID.getKey(), CLIENT_ID);
+
+ AdbcException adbcException = assertThrows(AdbcException.class, this::connect);
+ assertEquals(AdbcStatusCode.INVALID_ARGUMENT, adbcException.getStatus());
+ }
+
+ @Test
+ public void testMissingRequiredParamsTokenExchange() {
+ params.put(
+ FlightSqlConnectionProperties.OAUTH_FLOW.getKey(), GrantType.TOKEN_EXCHANGE.getValue());
+ params.put(FlightSqlConnectionProperties.OAUTH_TOKEN_URI.getKey(), tokenUri());
+ params.put(
+ FlightSqlConnectionProperties.OAUTH_EXCHANGE_SUBJECT_TOKEN.getKey(), "subject-token");
+
+ AdbcException adbcException = assertThrows(AdbcException.class, this::connect);
+ assertEquals(AdbcStatusCode.INVALID_ARGUMENT, adbcException.getStatus());
+ }
+
+ @Test
+ public void testActorTokenRequiresActorTokenType() {
+ params.put(
+ FlightSqlConnectionProperties.OAUTH_FLOW.getKey(), GrantType.TOKEN_EXCHANGE.getValue());
+ params.put(FlightSqlConnectionProperties.OAUTH_TOKEN_URI.getKey(), tokenUri());
+ params.put(
+ FlightSqlConnectionProperties.OAUTH_EXCHANGE_SUBJECT_TOKEN.getKey(), "subject-token");
+ params.put(
+ FlightSqlConnectionProperties.OAUTH_EXCHANGE_SUBJECT_TOKEN_TYPE.getKey(),
+ TokenTypeURI.ACCESS_TOKEN.toString());
+ params.put(FlightSqlConnectionProperties.OAUTH_EXCHANGE_ACTOR_TOKEN.getKey(), "actor-token");
+
+ AdbcException adbcException = assertThrows(AdbcException.class, this::connect);
+ assertEquals(AdbcStatusCode.INVALID_ARGUMENT, adbcException.getStatus());
+ }
+
+ @Test
+ public void testInvalidOauthFlow() {
+ params.put(FlightSqlConnectionProperties.OAUTH_FLOW.getKey(), "invalid-flow");
+ params.put(FlightSqlConnectionProperties.OAUTH_TOKEN_URI.getKey(), tokenUri());
+
+ AdbcException adbcException = assertThrows(AdbcException.class, this::connect);
+ assertEquals(AdbcStatusCode.NOT_IMPLEMENTED, adbcException.getStatus());
+ }
+
+ private void connect() throws Exception {
+ database =
+ AdbcDriverManager.getInstance()
+ .connect(FlightSqlDriverFactory.class.getCanonicalName(), allocator, params);
+ connection = database.connect();
+ }
+
+ private void requestServerMetadata() throws Exception {
+ try (ArrowReader reader = connection.getInfo(new int[] {AdbcInfoCode.VENDOR_NAME.getValue()})) {
+ while (reader.loadNextBatch()) {
+ // Only interested in triggering an authenticated RPC.
+ }
+ } catch (Exception ex) {
+ // MockFlightSqlProducer does not implement the full SQL metadata surface.
+ }
+ }
+
+ private String tokenUri() {
+ return String.format(
+ "%s://localhost:%d/token", tokenServerScheme, tokenServer.getAddress().getPort());
+ }
+
+ private void startHttpTokenServer() throws IOException {
+ if (tokenServer != null) {
+ tokenServer.stop(0);
+ }
+ tokenServerScheme = "http";
+ tokenServer = HttpServer.create(new InetSocketAddress("localhost", 0), 0);
+ tokenServer.createContext("/token", tokenHandler);
+ tokenServer.start();
+ }
+
+ private void startHttpsTokenServer() throws Exception {
+ if (tokenServer != null) {
+ tokenServer.stop(0);
+ }
+ tokenServerScheme = "https";
+ final SSLContext sslContext = createServerSslContext();
+ final HttpsServer httpsServer = HttpsServer.create(new InetSocketAddress("localhost", 0), 0);
+ httpsServer.setHttpsConfigurator(new HttpsConfigurator(sslContext));
+ httpsServer.createContext("/token", tokenHandler);
+ httpsServer.start();
+ tokenServer = httpsServer;
+ }
+
+ private void configureJvmTrustStore(Path trustStorePath) throws Exception {
+ System.setProperty("javax.net.ssl.trustStore", trustStorePath.toString());
+ System.setProperty("javax.net.ssl.trustStorePassword", new String(TRUST_STORE_PASSWORD));
+ System.setProperty("javax.net.ssl.trustStoreType", "PKCS12");
+ refreshOAuthHttpsDefaults();
+ }
+
+ private void clearJvmTrustStoreConfiguration() throws Exception {
+ System.clearProperty("javax.net.ssl.trustStore");
+ System.clearProperty("javax.net.ssl.trustStorePassword");
+ System.clearProperty("javax.net.ssl.trustStoreType");
+ refreshOAuthHttpsDefaults();
+ }
+
+ private void restoreJvmTrustStoreConfiguration() {
+ restoreSystemProperty("javax.net.ssl.trustStore", previousTrustStore);
+ restoreSystemProperty("javax.net.ssl.trustStorePassword", previousTrustStorePassword);
+ restoreSystemProperty("javax.net.ssl.trustStoreType", previousTrustStoreType);
+ HTTPRequest.setDefaultSSLSocketFactory(previousDefaultSslSocketFactory);
+ HTTPRequest.setDefaultHostnameVerifier(oauthHostnameVerifier());
+ }
+
+ private void refreshOAuthHttpsDefaults() throws Exception {
+ final TrustManagerFactory trustManagerFactory =
+ TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
+ trustManagerFactory.init((KeyStore) null);
+ final SSLContext sslContext = SSLContext.getInstance("TLS");
+ sslContext.init(null, trustManagerFactory.getTrustManagers(), new SecureRandom());
+ HTTPRequest.setDefaultSSLSocketFactory(sslContext.getSocketFactory());
+ HTTPRequest.setDefaultHostnameVerifier(oauthHostnameVerifier());
+ }
+
+ private Path createTrustStorePath() throws Exception {
+ final KeyStore trustStore = KeyStore.getInstance("PKCS12");
+ trustStore.load(null, null);
+ trustStore.setCertificateEntry("root", readCertificate(flightDataPath("root-ca.pem")));
+
+ final Path trustStorePath = Files.createTempFile("oauth-truststore", ".p12");
+ tempPaths.add(trustStorePath);
+ try (OutputStream output = Files.newOutputStream(trustStorePath)) {
+ trustStore.store(output, TRUST_STORE_PASSWORD);
+ }
+ return trustStorePath;
+ }
+
+ private SSLContext createServerSslContext() throws Exception {
+ final Certificate certificate = readCertificate(flightDataPath("cert0.pem"));
+ final PrivateKey privateKey = readPrivateKey(flightDataPath("cert0.pkcs1"));
+ final KeyStore keyStore = KeyStore.getInstance("PKCS12");
+ keyStore.load(null, null);
+ keyStore.setKeyEntry(
+ "token-server", privateKey, TRUST_STORE_PASSWORD, new Certificate[] {certificate});
+
+ final KeyManagerFactory keyManagerFactory =
+ KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
+ keyManagerFactory.init(keyStore, TRUST_STORE_PASSWORD);
+
+ final SSLContext sslContext = SSLContext.getInstance("TLS");
+ sslContext.init(keyManagerFactory.getKeyManagers(), null, new SecureRandom());
+ return sslContext;
+ }
+
+ private static X509Certificate readCertificate(Path path) throws Exception {
+ try (InputStream input = Files.newInputStream(path)) {
+ final CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509");
+ return (X509Certificate) certificateFactory.generateCertificate(input);
+ }
+ }
+
+ private static PrivateKey readPrivateKey(Path path) throws Exception {
+ final String pem = Files.readString(path, StandardCharsets.US_ASCII);
+ final String base64 =
+ pem.replace("-----BEGIN PRIVATE KEY-----", "")
+ .replace("-----END PRIVATE KEY-----", "")
+ .replaceAll("\\s+", "");
+ final byte[] der = Base64.getDecoder().decode(base64);
+ return KeyFactory.getInstance("RSA").generatePrivate(new PKCS8EncodedKeySpec(der));
+ }
+
+ private static Path flightDataPath(String filename) {
+ final String dataRoot = System.getProperty("arrow.test.dataRoot");
+ if (dataRoot != null) {
+ return Paths.get(dataRoot).resolve("flight").resolve(filename);
+ }
+ return Paths.get("testing", "data", "flight", filename).toAbsolutePath().normalize();
+ }
+
+ private static void restoreSystemProperty(String key, String value) {
+ if (value == null) {
+ System.clearProperty(key);
+ } else {
+ System.setProperty(key, value);
+ }
+ }
+
+ private HostnameVerifier oauthHostnameVerifier() {
+ if (previousDefaultHostnameVerifier != null) {
+ return previousDefaultHostnameVerifier;
+ }
+ return HttpsURLConnection.getDefaultHostnameVerifier();
+ }
+
+ private static final class TokenHandler implements HttpHandler {
+ private final List requestBodies = new ArrayList<>();
+ private final List authorizationHeaders = new ArrayList<>();
+ private String accessToken = CLIENT_ACCESS_TOKEN;
+
+ @Override
+ public void handle(HttpExchange exchange) throws IOException {
+ final String body =
+ new String(exchange.getRequestBody().readAllBytes(), StandardCharsets.UTF_8);
+ requestBodies.add(body);
+ authorizationHeaders.add(exchange.getRequestHeaders().getFirst("Authorization"));
+
+ final byte[] responseBytes =
+ ("{\"access_token\":\""
+ + accessToken
+ + "\",\"token_type\":\"Bearer\",\"expires_in\":3600}")
+ .getBytes(StandardCharsets.UTF_8);
+ final Headers responseHeaders = exchange.getResponseHeaders();
+ responseHeaders.add("Content-Type", "application/json");
+ exchange.sendResponseHeaders(200, responseBytes.length);
+ try (OutputStream output = exchange.getResponseBody()) {
+ output.write(responseBytes);
+ }
+ }
+
+ private String formValue(int requestIndex, String key) {
+ return decodeForm(requestBodies.get(requestIndex)).get(key);
+ }
+
+ private static Map decodeForm(String body) {
+ final Map values = new LinkedHashMap<>();
+ if (body.isEmpty()) {
+ return values;
+ }
+ for (String pair : body.split("&")) {
+ final String[] parts = pair.split("=", 2);
+ final String name = URLDecoder.decode(parts[0], StandardCharsets.UTF_8);
+ final String value =
+ parts.length > 1 ? URLDecoder.decode(parts[1], StandardCharsets.UTF_8) : "";
+ values.put(name, value);
+ }
+ return values;
+ }
+ }
+}