From 36a37ffe4eceda02effbdd42b5730a77b5fbebdf Mon Sep 17 00:00:00 2001 From: "xinyu.lin" Date: Sat, 2 May 2026 09:43:44 +0800 Subject: [PATCH] feat(java/driver/flight-sql): add OAuth2 support --- java/driver/flight-sql/pom.xml | 7 + .../driver/flightsql/FlightSqlConnection.java | 131 +++-- .../FlightSqlConnectionProperties.java | 23 + .../driver/flightsql/FlightSqlDatabase.java | 7 + .../oauth/FlightSqlOAuthConfiguration.java | 234 ++++++++ .../oauth/FlightSqlOAuthCredentialWriter.java | 62 +++ .../oauth/FlightSqlOAuthTokenProviders.java | 101 ++++ .../adbc/driver/flightsql/OAuthTest.java | 510 ++++++++++++++++++ 8 files changed, 1035 insertions(+), 40 deletions(-) create mode 100644 java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/oauth/FlightSqlOAuthConfiguration.java create mode 100644 java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/oauth/FlightSqlOAuthCredentialWriter.java create mode 100644 java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/oauth/FlightSqlOAuthTokenProviders.java create mode 100644 java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/OAuthTest.java diff --git a/java/driver/flight-sql/pom.xml b/java/driver/flight-sql/pom.xml index e0bd881855..78b84fa249 100644 --- a/java/driver/flight-sql/pom.xml +++ b/java/driver/flight-sql/pom.xml @@ -80,6 +80,13 @@ flight-sql-jdbc-core + + + com.nimbusds + oauth2-oidc-sdk + 11.20.1 + + org.checkerframework diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnection.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnection.java index f099cb64c8..f78ebb0245 100644 --- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnection.java +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnection.java @@ -38,6 +38,7 @@ import org.apache.arrow.adbc.core.AdbcStatement; import org.apache.arrow.adbc.core.AdbcStatusCode; import org.apache.arrow.adbc.core.BulkIngestMode; +import org.apache.arrow.adbc.driver.flightsql.oauth.FlightSqlOAuthCredentialWriter; import org.apache.arrow.adbc.sql.SqlQuirks; import org.apache.arrow.flight.CallOption; import org.apache.arrow.flight.FlightCallHeaders; @@ -58,6 +59,10 @@ import org.checkerframework.checker.nullness.qual.Nullable; public class FlightSqlConnection implements AdbcConnection { + private static final String AUTH_HEADER_CONFLICT_ERROR = + "[Flight SQL] Authentication conflict: Use either Authorization header or OAuth options, " + + "or username/password parameters"; + private final BufferAllocator allocator; private final AtomicInteger counter = new AtomicInteger(0); private final FlightSqlClientWithCallOptions client; @@ -107,9 +112,18 @@ public class FlightSqlConnection implements AdbcConnection { .build( loc -> { FlightClient client = buildClient(loc); - client.handshake(callOptions); - return new FlightSqlClientWithCallOptions( - new FlightSqlClient(client), callOptions); + try { + client.handshake(callOptions); + return new FlightSqlClientWithCallOptions( + new FlightSqlClient(client), callOptions); + } catch (RuntimeException ex) { + try { + client.close(); + } catch (Exception closeEx) { + ex.addSuppressed(closeEx); + } + throw ex; + } }); this.clientCache.put(location, this.client); } @@ -262,54 +276,91 @@ private FlightClient createInitialConnection( } } - // Build the client using the above properties. - final FlightClient client = buildClient(location); - // Add user-specified headers. ArrayList options = new ArrayList<>(); final FlightCallHeaders callHeaders = new FlightCallHeaders(); - for (Map.Entry parameter : parameters.entrySet()) { - if (parameter.getKey().startsWith(FlightSqlConnectionProperties.RPC_CALL_HEADER_PREFIX)) { - String userHeaderName = - parameter - .getKey() - .substring(FlightSqlConnectionProperties.RPC_CALL_HEADER_PREFIX.length()); - - if (parameter.getValue() instanceof String) { - callHeaders.insert(userHeaderName, (String) parameter.getValue()); - } else if (parameter.getValue() instanceof byte[]) { - callHeaders.insert(userHeaderName, (byte[]) parameter.getValue()); - } else { - throw new AdbcException( - String.format( - "Header values must be String or byte[]. The header failing was %s.", - parameter.getKey()), - null, - AdbcStatusCode.INVALID_ARGUMENT, - null, - 0); + String authorizationHeader = null; + String username = null; + String password = null; + String oauthFlow = null; + if (parameters != null) { + for (Map.Entry parameter : parameters.entrySet()) { + if (parameter.getKey().startsWith(FlightSqlConnectionProperties.RPC_CALL_HEADER_PREFIX)) { + String userHeaderName = + parameter + .getKey() + .substring(FlightSqlConnectionProperties.RPC_CALL_HEADER_PREFIX.length()); + + if (parameter.getValue() instanceof String) { + callHeaders.insert(userHeaderName, (String) parameter.getValue()); + } else if (parameter.getValue() instanceof byte[]) { + callHeaders.insert(userHeaderName, (byte[]) parameter.getValue()); + } else { + throw new AdbcException( + String.format( + "Header values must be String or byte[]. The header failing was %s.", + parameter.getKey()), + null, + AdbcStatusCode.INVALID_ARGUMENT, + null, + 0); + } } } + + authorizationHeader = FlightSqlConnectionProperties.AUTHORIZATION_HEADER.get(parameters); + username = AdbcDriver.PARAM_USERNAME.get(parameters); + password = AdbcDriver.PARAM_PASSWORD.get(parameters); + oauthFlow = FlightSqlConnectionProperties.OAUTH_FLOW.get(parameters); + } + + if (authorizationHeader != null) { + callHeaders.insert("authorization", authorizationHeader); } options.add(new HeaderCallOption(callHeaders)); - // Test the connection. - String username = AdbcDriver.PARAM_USERNAME.get(parameters); - String password = AdbcDriver.PARAM_PASSWORD.get(parameters); - if (username != null && password != null) { - Optional bearerToken = - client.authenticateBasicToken(username, password); - options.add( - bearerToken.orElse( - new CredentialCallOption(new BasicAuthCredentialWriter(username, password)))); - this.callOptions = options.toArray(new CallOption[0]); - } else { - this.callOptions = options.toArray(new CallOption[0]); - client.handshake(this.callOptions); + final boolean hasAuthorizationHeader = authorizationHeader != null; + final boolean hasUsernamePassword = username != null || password != null; + final boolean hasOauth = oauthFlow != null; + + if ((hasAuthorizationHeader && (hasUsernamePassword || hasOauth)) + || (hasUsernamePassword && hasOauth)) { + throw AdbcException.invalidArgument(AUTH_HEADER_CONFLICT_ERROR); } - return client; + // Build the client using the above properties. + final FlightClient client = buildClient(location); + + try { + // Test the connection. + if (hasOauth) { + final FlightSqlOAuthCredentialWriter oauthCredentialWriter = + FlightSqlOAuthCredentialWriter.create(parameters); + oauthCredentialWriter.prefetchToken(); + options.add(new CredentialCallOption(oauthCredentialWriter)); + this.callOptions = options.toArray(new CallOption[0]); + client.handshake(this.callOptions); + } else if (username != null && password != null) { + Optional bearerToken = + client.authenticateBasicToken(username, password); + options.add( + bearerToken.orElse( + new CredentialCallOption(new BasicAuthCredentialWriter(username, password)))); + this.callOptions = options.toArray(new CallOption[0]); + } else { + this.callOptions = options.toArray(new CallOption[0]); + client.handshake(this.callOptions); + } + return client; + } catch (AdbcException | RuntimeException ex) { + try { + client.close(); + } catch (Exception closeEx) { + ex.addSuppressed(closeEx); + } + throw ex; + } } /** Returns a yet-to-be authenticated FlightClient */ diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnectionProperties.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnectionProperties.java index 4ab1955a1b..a6ed5f2650 100644 --- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnectionProperties.java +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnectionProperties.java @@ -22,10 +22,33 @@ /** Defines connection options that are used by the FlightSql driver. */ public interface FlightSqlConnectionProperties { + TypedKey AUTHORIZATION_HEADER = + new TypedKey<>("adbc.flight.sql.authorization_header", String.class); TypedKey MTLS_CERT_CHAIN = new TypedKey<>("adbc.flight.sql.client_option.mtls_cert_chain", InputStream.class); TypedKey MTLS_PRIVATE_KEY = new TypedKey<>("adbc.flight.sql.client_option.mtls_private_key", InputStream.class); + TypedKey OAUTH_FLOW = new TypedKey<>("adbc.flight.sql.oauth.flow", String.class); + TypedKey OAUTH_TOKEN_URI = + new TypedKey<>("adbc.flight.sql.oauth.token_uri", String.class); + TypedKey OAUTH_CLIENT_ID = + new TypedKey<>("adbc.flight.sql.oauth.client_id", String.class); + TypedKey OAUTH_CLIENT_SECRET = + new TypedKey<>("adbc.flight.sql.oauth.client_secret", String.class); + TypedKey OAUTH_SCOPE = new TypedKey<>("adbc.flight.sql.oauth.scope", String.class); + TypedKey OAUTH_RESOURCE = new TypedKey<>("adbc.flight.sql.oauth.resource", String.class); + TypedKey OAUTH_EXCHANGE_SUBJECT_TOKEN = + new TypedKey<>("adbc.flight.sql.oauth.exchange.subject_token", String.class); + TypedKey OAUTH_EXCHANGE_SUBJECT_TOKEN_TYPE = + new TypedKey<>("adbc.flight.sql.oauth.exchange.subject_token_type", String.class); + TypedKey OAUTH_EXCHANGE_ACTOR_TOKEN = + new TypedKey<>("adbc.flight.sql.oauth.exchange.actor_token", String.class); + TypedKey OAUTH_EXCHANGE_ACTOR_TOKEN_TYPE = + new TypedKey<>("adbc.flight.sql.oauth.exchange.actor_token_type", String.class); + TypedKey OAUTH_EXCHANGE_REQUESTED_TOKEN_TYPE = + new TypedKey<>("adbc.flight.sql.oauth.exchange.requested_token_type", String.class); + TypedKey OAUTH_EXCHANGE_AUD = + new TypedKey<>("adbc.flight.sql.oauth.exchange.aud", String.class); TypedKey TLS_OVERRIDE_HOSTNAME = new TypedKey<>("adbc.flight.sql.client_option.tls_override_hostname", String.class); TypedKey TLS_SKIP_VERIFY = diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDatabase.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDatabase.java index af8221d6b0..706f69ff9f 100644 --- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDatabase.java +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDatabase.java @@ -64,6 +64,13 @@ public AdbcConnection connect() throws AdbcException { adbcException.addSuppressed(e); } throw adbcException; + } catch (AdbcException ex) { + try { + AutoCloseables.close(connectionAllocator); + } catch (Exception e) { + ex.addSuppressed(e); + } + throw ex; } catch (Exception ex) { AdbcException adbcException = FlightSqlDriverUtil.fromGeneralException(ex); try { diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/oauth/FlightSqlOAuthConfiguration.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/oauth/FlightSqlOAuthConfiguration.java new file mode 100644 index 0000000000..203f524bb2 --- /dev/null +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/oauth/FlightSqlOAuthConfiguration.java @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.adbc.driver.flightsql.oauth; + +import com.nimbusds.oauth2.sdk.GrantType; +import com.nimbusds.oauth2.sdk.ParseException; +import com.nimbusds.oauth2.sdk.token.TokenTypeURI; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Map; +import org.apache.arrow.adbc.core.AdbcException; +import org.apache.arrow.adbc.core.TypedKey; +import org.apache.arrow.adbc.driver.flightsql.FlightSqlConnectionProperties; +import org.checkerframework.checker.nullness.qual.Nullable; + +final class FlightSqlOAuthConfiguration { + private final GrantType flowType; + private final @Nullable String tokenUri; + private final @Nullable String clientId; + private final @Nullable String clientSecret; + private final @Nullable String scope; + private final @Nullable String subjectToken; + private final @Nullable TokenTypeURI subjectTokenType; + private final @Nullable String actorToken; + private final @Nullable TokenTypeURI actorTokenType; + private final @Nullable TokenTypeURI requestedTokenType; + private final @Nullable String audience; + private final @Nullable URI resource; + + private FlightSqlOAuthConfiguration( + GrantType flowType, + @Nullable String tokenUri, + @Nullable String clientId, + @Nullable String clientSecret, + @Nullable String scope, + @Nullable String subjectToken, + @Nullable TokenTypeURI subjectTokenType, + @Nullable String actorToken, + @Nullable TokenTypeURI actorTokenType, + @Nullable TokenTypeURI requestedTokenType, + @Nullable String audience, + @Nullable URI resource) { + this.flowType = flowType; + this.tokenUri = tokenUri; + this.clientId = clientId; + this.clientSecret = clientSecret; + this.scope = scope; + this.subjectToken = subjectToken; + this.subjectTokenType = subjectTokenType; + this.actorToken = actorToken; + this.actorTokenType = actorTokenType; + this.requestedTokenType = requestedTokenType; + this.audience = audience; + this.resource = resource; + } + + static FlightSqlOAuthConfiguration from(Map parameters) throws AdbcException { + final GrantType flowType = + parseGrantType(requireOption(parameters, FlightSqlConnectionProperties.OAUTH_FLOW)); + final @Nullable String clientId = FlightSqlConnectionProperties.OAUTH_CLIENT_ID.get(parameters); + final @Nullable String clientSecret = + FlightSqlConnectionProperties.OAUTH_CLIENT_SECRET.get(parameters); + final @Nullable String subjectToken = + FlightSqlConnectionProperties.OAUTH_EXCHANGE_SUBJECT_TOKEN.get(parameters); + final @Nullable String subjectTokenTypeValue = + FlightSqlConnectionProperties.OAUTH_EXCHANGE_SUBJECT_TOKEN_TYPE.get(parameters); + final @Nullable String actorToken = + FlightSqlConnectionProperties.OAUTH_EXCHANGE_ACTOR_TOKEN.get(parameters); + final @Nullable String actorTokenTypeValue = + FlightSqlConnectionProperties.OAUTH_EXCHANGE_ACTOR_TOKEN_TYPE.get(parameters); + + final boolean clientCredentials = GrantType.CLIENT_CREDENTIALS.equals(flowType); + final boolean tokenExchange = GrantType.TOKEN_EXCHANGE.equals(flowType); + final @Nullable String tokenUri = + (clientCredentials || tokenExchange) + ? requireOption(parameters, FlightSqlConnectionProperties.OAUTH_TOKEN_URI) + : null; + + if (clientCredentials) { + requireOption(parameters, FlightSqlConnectionProperties.OAUTH_CLIENT_ID); + requireOption(parameters, FlightSqlConnectionProperties.OAUTH_CLIENT_SECRET); + } + if (tokenExchange) { + requireOption(parameters, FlightSqlConnectionProperties.OAUTH_EXCHANGE_SUBJECT_TOKEN); + requireOption(parameters, FlightSqlConnectionProperties.OAUTH_EXCHANGE_SUBJECT_TOKEN_TYPE); + if ((actorToken == null) != (actorTokenTypeValue == null)) { + throw AdbcException.invalidArgument( + "[Flight SQL] token exchange grant requires " + + FlightSqlConnectionProperties.OAUTH_EXCHANGE_ACTOR_TOKEN_TYPE.getKey() + + " when " + + FlightSqlConnectionProperties.OAUTH_EXCHANGE_ACTOR_TOKEN.getKey() + + " is provided"); + } + if ((clientId == null) != (clientSecret == null)) { + throw AdbcException.invalidArgument( + "[Flight SQL] token exchange grant requires both " + + FlightSqlConnectionProperties.OAUTH_CLIENT_ID.getKey() + + " and " + + FlightSqlConnectionProperties.OAUTH_CLIENT_SECRET.getKey() + + " when client credentials are provided"); + } + } + + final @Nullable String resource = + tokenExchange ? FlightSqlConnectionProperties.OAUTH_RESOURCE.get(parameters) : null; + return new FlightSqlOAuthConfiguration( + flowType, + tokenUri, + clientId, + clientSecret, + FlightSqlConnectionProperties.OAUTH_SCOPE.get(parameters), + subjectToken, + parseTokenTypeUri( + subjectTokenTypeValue, + FlightSqlConnectionProperties.OAUTH_EXCHANGE_SUBJECT_TOKEN_TYPE), + actorToken, + parseTokenTypeUri( + FlightSqlConnectionProperties.OAUTH_EXCHANGE_ACTOR_TOKEN_TYPE.get(parameters), + FlightSqlConnectionProperties.OAUTH_EXCHANGE_ACTOR_TOKEN_TYPE), + parseTokenTypeUri( + FlightSqlConnectionProperties.OAUTH_EXCHANGE_REQUESTED_TOKEN_TYPE.get(parameters), + FlightSqlConnectionProperties.OAUTH_EXCHANGE_REQUESTED_TOKEN_TYPE), + FlightSqlConnectionProperties.OAUTH_EXCHANGE_AUD.get(parameters), + parseResource(resource)); + } + + GrantType flowType() { + return flowType; + } + + @Nullable String tokenUri() { + return tokenUri; + } + + @Nullable String clientId() { + return clientId; + } + + @Nullable String clientSecret() { + return clientSecret; + } + + @Nullable String scope() { + return scope; + } + + @Nullable String subjectToken() { + return subjectToken; + } + + @Nullable TokenTypeURI subjectTokenType() { + return subjectTokenType; + } + + @Nullable String actorToken() { + return actorToken; + } + + @Nullable TokenTypeURI actorTokenType() { + return actorTokenType; + } + + @Nullable TokenTypeURI requestedTokenType() { + return requestedTokenType; + } + + @Nullable String audience() { + return audience; + } + + @Nullable URI resource() { + return resource; + } + + private static String requireOption(Map parameters, TypedKey option) + throws AdbcException { + final String value = option.get(parameters); + if (value == null) { + throw AdbcException.invalidArgument("[Flight SQL] OAuth flow requires " + option.getKey()); + } + return value; + } + + private static GrantType parseGrantType(String value) throws AdbcException { + try { + return GrantType.parse(value); + } catch (ParseException e) { + throw AdbcException.invalidArgument("[Flight SQL] invalid OAuth flow: " + value).withCause(e); + } + } + + private static @Nullable TokenTypeURI parseTokenTypeUri( + @Nullable String value, TypedKey option) throws AdbcException { + if (value == null) { + return null; + } + try { + return TokenTypeURI.parse(value); + } catch (ParseException e) { + throw AdbcException.invalidArgument( + "[Flight SQL] invalid OAuth token type for " + option.getKey() + ": " + value) + .withCause(e); + } + } + + private static @Nullable URI parseResource(@Nullable String resource) throws AdbcException { + if (resource == null) { + return null; + } + try { + return new URI(resource); + } catch (URISyntaxException e) { + throw AdbcException.invalidArgument( + "[Flight SQL] token exchange grant requires a valid URI for " + + FlightSqlConnectionProperties.OAUTH_RESOURCE.getKey()) + .withCause(e); + } + } +} diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/oauth/FlightSqlOAuthCredentialWriter.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/oauth/FlightSqlOAuthCredentialWriter.java new file mode 100644 index 0000000000..7a9b5ad0cd --- /dev/null +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/oauth/FlightSqlOAuthCredentialWriter.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.adbc.driver.flightsql.oauth; + +import java.sql.SQLException; +import java.util.Map; +import java.util.function.Consumer; +import org.apache.arrow.adbc.core.AdbcException; +import org.apache.arrow.driver.jdbc.client.oauth.OAuthTokenProvider; +import org.apache.arrow.flight.CallHeaders; + +public final class FlightSqlOAuthCredentialWriter implements Consumer { + private final OAuthTokenProvider tokenProvider; + + private FlightSqlOAuthCredentialWriter(OAuthTokenProvider tokenProvider) { + this.tokenProvider = tokenProvider; + } + + public static FlightSqlOAuthCredentialWriter create(Map parameters) + throws AdbcException { + final FlightSqlOAuthConfiguration configuration = + FlightSqlOAuthConfiguration.from(parameters); + return new FlightSqlOAuthCredentialWriter(FlightSqlOAuthTokenProviders.create(configuration)); + } + + public void prefetchToken() throws AdbcException { + currentAuthorizationValue(); + } + + @Override + public void accept(CallHeaders headers) { + try { + headers.insert("authorization", currentAuthorizationValue()); + } catch (AdbcException e) { + throw new IllegalStateException(e.getMessage(), e); + } + } + + private String currentAuthorizationValue() throws AdbcException { + try { + return "Bearer " + tokenProvider.getValidToken(); + } catch (SQLException e) { + throw AdbcException.io("[Flight SQL] OAuth token request failed: " + e.getMessage()) + .withCause(e); + } + } +} diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/oauth/FlightSqlOAuthTokenProviders.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/oauth/FlightSqlOAuthTokenProviders.java new file mode 100644 index 0000000000..5f1fd0187f --- /dev/null +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/oauth/FlightSqlOAuthTokenProviders.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.adbc.driver.flightsql.oauth; + +import com.nimbusds.oauth2.sdk.GrantType; +import java.util.Objects; +import org.apache.arrow.adbc.core.AdbcException; +import org.apache.arrow.driver.jdbc.client.oauth.OAuthTokenProvider; +import org.apache.arrow.driver.jdbc.client.oauth.OAuthTokenProviders; + +final class FlightSqlOAuthTokenProviders { + private FlightSqlOAuthTokenProviders() {} + + static OAuthTokenProvider create(FlightSqlOAuthConfiguration configuration) + throws AdbcException { + final GrantType flowType = configuration.flowType(); + if (GrantType.CLIENT_CREDENTIALS.equals(flowType)) { + return createClientCredentialsProvider(configuration); + } + if (GrantType.TOKEN_EXCHANGE.equals(flowType)) { + return createTokenExchangeProvider(configuration); + } + throw AdbcException.notImplemented("[Flight SQL] oauth flow not implemented: " + flowType); + } + + private static OAuthTokenProvider createClientCredentialsProvider( + FlightSqlOAuthConfiguration configuration) throws AdbcException { + try { + final OAuthTokenProviders.ClientCredentialsBuilder builder = + OAuthTokenProviders.clientCredentials() + .tokenUri(Objects.requireNonNull(configuration.tokenUri())) + .clientId(Objects.requireNonNull(configuration.clientId())) + .clientSecret(Objects.requireNonNull(configuration.clientSecret())); + + final String scope = configuration.scope(); + if (scope != null) { + builder.scope(scope); + } + return builder.build(); + } catch (RuntimeException e) { + throw AdbcException.invalidArgument( + "[Flight SQL] Invalid OAuth client credentials configuration: " + e.getMessage()) + .withCause(e); + } + } + + private static OAuthTokenProvider createTokenExchangeProvider( + FlightSqlOAuthConfiguration configuration) throws AdbcException { + try { + final OAuthTokenProviders.TokenExchangeBuilder builder = + OAuthTokenProviders.tokenExchange() + .tokenUri(Objects.requireNonNull(configuration.tokenUri())) + .subjectToken(Objects.requireNonNull(configuration.subjectToken())) + .subjectTokenType( + Objects.requireNonNull(configuration.subjectTokenType()).toString()); + + if (configuration.actorToken() != null) { + builder + .actorToken(configuration.actorToken()) + .actorTokenType(Objects.requireNonNull(configuration.actorTokenType()).toString()); + } + if (configuration.clientId() != null) { + builder.clientCredentials( + configuration.clientId(), Objects.requireNonNull(configuration.clientSecret())); + } + if (configuration.audience() != null) { + builder.audience(configuration.audience()); + } + if (configuration.requestedTokenType() != null) { + builder.requestedTokenType(configuration.requestedTokenType().toString()); + } + if (configuration.resource() != null) { + builder.resource(configuration.resource()); + } + if (configuration.scope() != null) { + builder.scope(configuration.scope()); + } + + return builder.build(); + } catch (RuntimeException e) { + throw AdbcException.invalidArgument( + "[Flight SQL] Invalid OAuth token exchange configuration: " + e.getMessage()) + .withCause(e); + } + } +} diff --git a/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/OAuthTest.java b/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/OAuthTest.java new file mode 100644 index 0000000000..c19f432bbc --- /dev/null +++ b/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/OAuthTest.java @@ -0,0 +1,510 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.adbc.driver.flightsql; + +import static org.junit.Assert.assertThrows; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.nimbusds.oauth2.sdk.GrantType; +import com.nimbusds.oauth2.sdk.http.HTTPRequest; +import com.nimbusds.oauth2.sdk.token.TokenTypeURI; +import com.sun.net.httpserver.Headers; +import com.sun.net.httpserver.HttpExchange; +import com.sun.net.httpserver.HttpHandler; +import com.sun.net.httpserver.HttpServer; +import com.sun.net.httpserver.HttpsConfigurator; +import com.sun.net.httpserver.HttpsServer; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.InetSocketAddress; +import java.net.URLDecoder; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.security.KeyFactory; +import java.security.KeyStore; +import java.security.PrivateKey; +import java.security.SecureRandom; +import java.security.cert.Certificate; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.security.spec.PKCS8EncodedKeySpec; +import java.util.ArrayList; +import java.util.Base64; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import javax.net.ssl.HostnameVerifier; +import javax.net.ssl.HttpsURLConnection; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.TrustManagerFactory; +import org.apache.arrow.adbc.core.AdbcConnection; +import org.apache.arrow.adbc.core.AdbcDatabase; +import org.apache.arrow.adbc.core.AdbcDriver; +import org.apache.arrow.adbc.core.AdbcException; +import org.apache.arrow.adbc.core.AdbcInfoCode; +import org.apache.arrow.adbc.core.AdbcStatusCode; +import org.apache.arrow.adbc.drivermanager.AdbcDriverManager; +import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer; +import org.apache.arrow.flight.CallHeaders; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.Location; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class OAuthTest { + private static final String CLIENT_ID = "adbc-client"; + private static final String CLIENT_SECRET = "adbc-secret"; + private static final String CLIENT_ACCESS_TOKEN = "client-credentials-token"; + private static final String EXCHANGE_ACCESS_TOKEN = "token-exchange-token"; + private static final char[] TRUST_STORE_PASSWORD = "changeit".toCharArray(); + + private BufferAllocator allocator; + private Map params; + private FlightServer server; + private HttpServer tokenServer; + private AdbcDatabase database; + private AdbcConnection connection; + private HeaderValidator.Factory headerValidatorFactory; + private TokenHandler tokenHandler; + private String tokenServerScheme; + private String previousTrustStore; + private String previousTrustStorePassword; + private String previousTrustStoreType; + private SSLSocketFactory previousDefaultSslSocketFactory; + private HostnameVerifier previousDefaultHostnameVerifier; + private final List tempPaths = new ArrayList<>(); + + @BeforeEach + public void setUp() throws IOException { + allocator = new RootAllocator(Long.MAX_VALUE); + params = new HashMap<>(); + headerValidatorFactory = new HeaderValidator.Factory(); + server = + FlightServer.builder() + .allocator(allocator) + .middleware(HeaderValidator.KEY, headerValidatorFactory) + .location(Location.forGrpcInsecure("localhost", 0)) + .producer(new MockFlightSqlProducer()) + .build(); + server.start(); + + tokenHandler = new TokenHandler(); + previousTrustStore = System.getProperty("javax.net.ssl.trustStore"); + previousTrustStorePassword = System.getProperty("javax.net.ssl.trustStorePassword"); + previousTrustStoreType = System.getProperty("javax.net.ssl.trustStoreType"); + previousDefaultSslSocketFactory = HTTPRequest.getDefaultSSLSocketFactory(); + previousDefaultHostnameVerifier = HTTPRequest.getDefaultHostnameVerifier(); + startHttpTokenServer(); + + params.put( + AdbcDriver.PARAM_URI.getKey(), String.format("grpc+tcp://localhost:%d", server.getPort())); + } + + @AfterEach + public void tearDown() throws Exception { + AutoCloseables.close(connection, database, server, allocator); + if (tokenServer != null) { + tokenServer.stop(0); + } + restoreJvmTrustStoreConfiguration(); + for (Path path : tempPaths) { + Files.deleteIfExists(path); + } + connection = null; + database = null; + server = null; + allocator = null; + tokenServer = null; + } + + @Test + public void testClientCredentialsFlow() throws Exception { + tokenHandler.accessToken = CLIENT_ACCESS_TOKEN; + params.put( + FlightSqlConnectionProperties.OAUTH_FLOW.getKey(), + GrantType.CLIENT_CREDENTIALS.getValue()); + params.put(FlightSqlConnectionProperties.OAUTH_TOKEN_URI.getKey(), tokenUri()); + params.put(FlightSqlConnectionProperties.OAUTH_CLIENT_ID.getKey(), CLIENT_ID); + params.put(FlightSqlConnectionProperties.OAUTH_CLIENT_SECRET.getKey(), CLIENT_SECRET); + params.put(FlightSqlConnectionProperties.OAUTH_SCOPE.getKey(), "scope-a scope-b"); + + connect(); + requestServerMetadata(); + + CallHeaders headers = headerValidatorFactory.getHeadersReceivedAtRequest(0); + assertEquals("Bearer " + CLIENT_ACCESS_TOKEN, headers.get("authorization")); + + assertEquals(1, tokenHandler.requestBodies.size()); + assertEquals("client_credentials", tokenHandler.formValue(0, "grant_type")); + assertEquals("scope-a scope-b", tokenHandler.formValue(0, "scope")); + assertTrue(tokenHandler.authorizationHeaders.get(0).startsWith("Basic ")); + } + + @Test + public void testTokenExchangeFlow() throws Exception { + tokenHandler.accessToken = EXCHANGE_ACCESS_TOKEN; + params.put( + FlightSqlConnectionProperties.OAUTH_FLOW.getKey(), GrantType.TOKEN_EXCHANGE.getValue()); + params.put(FlightSqlConnectionProperties.OAUTH_TOKEN_URI.getKey(), tokenUri()); + params.put( + FlightSqlConnectionProperties.OAUTH_EXCHANGE_SUBJECT_TOKEN.getKey(), "subject-token"); + params.put( + FlightSqlConnectionProperties.OAUTH_EXCHANGE_SUBJECT_TOKEN_TYPE.getKey(), + TokenTypeURI.ACCESS_TOKEN.toString()); + params.put(FlightSqlConnectionProperties.OAUTH_EXCHANGE_ACTOR_TOKEN.getKey(), "actor-token"); + params.put( + FlightSqlConnectionProperties.OAUTH_EXCHANGE_ACTOR_TOKEN_TYPE.getKey(), + TokenTypeURI.JWT.toString()); + params.put( + FlightSqlConnectionProperties.OAUTH_EXCHANGE_REQUESTED_TOKEN_TYPE.getKey(), + TokenTypeURI.ACCESS_TOKEN.toString()); + params.put(FlightSqlConnectionProperties.OAUTH_EXCHANGE_AUD.getKey(), "flight-service"); + params.put( + FlightSqlConnectionProperties.OAUTH_RESOURCE.getKey(), "https://resource.example.com"); + params.put(FlightSqlConnectionProperties.OAUTH_SCOPE.getKey(), "profile email"); + params.put(FlightSqlConnectionProperties.OAUTH_CLIENT_ID.getKey(), CLIENT_ID); + params.put(FlightSqlConnectionProperties.OAUTH_CLIENT_SECRET.getKey(), CLIENT_SECRET); + + connect(); + requestServerMetadata(); + + CallHeaders headers = headerValidatorFactory.getHeadersReceivedAtRequest(0); + assertEquals("Bearer " + EXCHANGE_ACCESS_TOKEN, headers.get("authorization")); + + assertEquals(1, tokenHandler.requestBodies.size()); + assertEquals( + "urn:ietf:params:oauth:grant-type:token-exchange", tokenHandler.formValue(0, "grant_type")); + assertEquals("subject-token", tokenHandler.formValue(0, "subject_token")); + assertEquals( + TokenTypeURI.ACCESS_TOKEN.toString(), tokenHandler.formValue(0, "subject_token_type")); + assertEquals("actor-token", tokenHandler.formValue(0, "actor_token")); + assertEquals(TokenTypeURI.JWT.toString(), tokenHandler.formValue(0, "actor_token_type")); + assertEquals( + TokenTypeURI.ACCESS_TOKEN.toString(), tokenHandler.formValue(0, "requested_token_type")); + assertEquals("flight-service", tokenHandler.formValue(0, "audience")); + assertEquals("https://resource.example.com", tokenHandler.formValue(0, "resource")); + assertEquals("profile email", tokenHandler.formValue(0, "scope")); + assertTrue(tokenHandler.authorizationHeaders.get(0).startsWith("Basic ")); + } + + @Test + public void testClientCredentialsFlowWithHttpsTokenEndpointWithoutTrustStore() throws Exception { + tokenHandler.accessToken = CLIENT_ACCESS_TOKEN; + startHttpsTokenServer(); + clearJvmTrustStoreConfiguration(); + + params.put( + FlightSqlConnectionProperties.OAUTH_FLOW.getKey(), + GrantType.CLIENT_CREDENTIALS.getValue()); + params.put(FlightSqlConnectionProperties.OAUTH_TOKEN_URI.getKey(), tokenUri()); + params.put(FlightSqlConnectionProperties.OAUTH_CLIENT_ID.getKey(), CLIENT_ID); + params.put(FlightSqlConnectionProperties.OAUTH_CLIENT_SECRET.getKey(), CLIENT_SECRET); + + AdbcException adbcException = assertThrows(AdbcException.class, this::connect); + assertEquals(AdbcStatusCode.IO, adbcException.getStatus()); + } + + @Test + public void testClientCredentialsFlowWithHttpsTokenEndpointUsesJvmTrustStore() throws Exception { + tokenHandler.accessToken = CLIENT_ACCESS_TOKEN; + startHttpsTokenServer(); + configureJvmTrustStore(createTrustStorePath()); + + params.put( + FlightSqlConnectionProperties.OAUTH_FLOW.getKey(), + GrantType.CLIENT_CREDENTIALS.getValue()); + params.put(FlightSqlConnectionProperties.OAUTH_TOKEN_URI.getKey(), tokenUri()); + params.put(FlightSqlConnectionProperties.OAUTH_CLIENT_ID.getKey(), CLIENT_ID); + params.put(FlightSqlConnectionProperties.OAUTH_CLIENT_SECRET.getKey(), CLIENT_SECRET); + + connect(); + requestServerMetadata(); + + CallHeaders headers = headerValidatorFactory.getHeadersReceivedAtRequest(0); + assertEquals("Bearer " + CLIENT_ACCESS_TOKEN, headers.get("authorization")); + assertEquals(1, tokenHandler.requestBodies.size()); + assertEquals("client_credentials", tokenHandler.formValue(0, "grant_type")); + } + + @Test + public void testAuthorizationHeaderConflictsWithOauth() { + params.put( + FlightSqlConnectionProperties.AUTHORIZATION_HEADER.getKey(), "Bearer existing-token"); + params.put( + FlightSqlConnectionProperties.OAUTH_FLOW.getKey(), + GrantType.CLIENT_CREDENTIALS.getValue()); + params.put(FlightSqlConnectionProperties.OAUTH_TOKEN_URI.getKey(), tokenUri()); + params.put(FlightSqlConnectionProperties.OAUTH_CLIENT_ID.getKey(), CLIENT_ID); + params.put(FlightSqlConnectionProperties.OAUTH_CLIENT_SECRET.getKey(), CLIENT_SECRET); + + AdbcException adbcException = assertThrows(AdbcException.class, this::connect); + assertEquals(AdbcStatusCode.INVALID_ARGUMENT, adbcException.getStatus()); + } + + @Test + public void testMissingRequiredParamsClientCredentials() { + params.put( + FlightSqlConnectionProperties.OAUTH_FLOW.getKey(), + GrantType.CLIENT_CREDENTIALS.getValue()); + params.put(FlightSqlConnectionProperties.OAUTH_TOKEN_URI.getKey(), tokenUri()); + params.put(FlightSqlConnectionProperties.OAUTH_CLIENT_ID.getKey(), CLIENT_ID); + + AdbcException adbcException = assertThrows(AdbcException.class, this::connect); + assertEquals(AdbcStatusCode.INVALID_ARGUMENT, adbcException.getStatus()); + } + + @Test + public void testMissingRequiredParamsTokenExchange() { + params.put( + FlightSqlConnectionProperties.OAUTH_FLOW.getKey(), GrantType.TOKEN_EXCHANGE.getValue()); + params.put(FlightSqlConnectionProperties.OAUTH_TOKEN_URI.getKey(), tokenUri()); + params.put( + FlightSqlConnectionProperties.OAUTH_EXCHANGE_SUBJECT_TOKEN.getKey(), "subject-token"); + + AdbcException adbcException = assertThrows(AdbcException.class, this::connect); + assertEquals(AdbcStatusCode.INVALID_ARGUMENT, adbcException.getStatus()); + } + + @Test + public void testActorTokenRequiresActorTokenType() { + params.put( + FlightSqlConnectionProperties.OAUTH_FLOW.getKey(), GrantType.TOKEN_EXCHANGE.getValue()); + params.put(FlightSqlConnectionProperties.OAUTH_TOKEN_URI.getKey(), tokenUri()); + params.put( + FlightSqlConnectionProperties.OAUTH_EXCHANGE_SUBJECT_TOKEN.getKey(), "subject-token"); + params.put( + FlightSqlConnectionProperties.OAUTH_EXCHANGE_SUBJECT_TOKEN_TYPE.getKey(), + TokenTypeURI.ACCESS_TOKEN.toString()); + params.put(FlightSqlConnectionProperties.OAUTH_EXCHANGE_ACTOR_TOKEN.getKey(), "actor-token"); + + AdbcException adbcException = assertThrows(AdbcException.class, this::connect); + assertEquals(AdbcStatusCode.INVALID_ARGUMENT, adbcException.getStatus()); + } + + @Test + public void testInvalidOauthFlow() { + params.put(FlightSqlConnectionProperties.OAUTH_FLOW.getKey(), "invalid-flow"); + params.put(FlightSqlConnectionProperties.OAUTH_TOKEN_URI.getKey(), tokenUri()); + + AdbcException adbcException = assertThrows(AdbcException.class, this::connect); + assertEquals(AdbcStatusCode.NOT_IMPLEMENTED, adbcException.getStatus()); + } + + private void connect() throws Exception { + database = + AdbcDriverManager.getInstance() + .connect(FlightSqlDriverFactory.class.getCanonicalName(), allocator, params); + connection = database.connect(); + } + + private void requestServerMetadata() throws Exception { + try (ArrowReader reader = connection.getInfo(new int[] {AdbcInfoCode.VENDOR_NAME.getValue()})) { + while (reader.loadNextBatch()) { + // Only interested in triggering an authenticated RPC. + } + } catch (Exception ex) { + // MockFlightSqlProducer does not implement the full SQL metadata surface. + } + } + + private String tokenUri() { + return String.format( + "%s://localhost:%d/token", tokenServerScheme, tokenServer.getAddress().getPort()); + } + + private void startHttpTokenServer() throws IOException { + if (tokenServer != null) { + tokenServer.stop(0); + } + tokenServerScheme = "http"; + tokenServer = HttpServer.create(new InetSocketAddress("localhost", 0), 0); + tokenServer.createContext("/token", tokenHandler); + tokenServer.start(); + } + + private void startHttpsTokenServer() throws Exception { + if (tokenServer != null) { + tokenServer.stop(0); + } + tokenServerScheme = "https"; + final SSLContext sslContext = createServerSslContext(); + final HttpsServer httpsServer = HttpsServer.create(new InetSocketAddress("localhost", 0), 0); + httpsServer.setHttpsConfigurator(new HttpsConfigurator(sslContext)); + httpsServer.createContext("/token", tokenHandler); + httpsServer.start(); + tokenServer = httpsServer; + } + + private void configureJvmTrustStore(Path trustStorePath) throws Exception { + System.setProperty("javax.net.ssl.trustStore", trustStorePath.toString()); + System.setProperty("javax.net.ssl.trustStorePassword", new String(TRUST_STORE_PASSWORD)); + System.setProperty("javax.net.ssl.trustStoreType", "PKCS12"); + refreshOAuthHttpsDefaults(); + } + + private void clearJvmTrustStoreConfiguration() throws Exception { + System.clearProperty("javax.net.ssl.trustStore"); + System.clearProperty("javax.net.ssl.trustStorePassword"); + System.clearProperty("javax.net.ssl.trustStoreType"); + refreshOAuthHttpsDefaults(); + } + + private void restoreJvmTrustStoreConfiguration() { + restoreSystemProperty("javax.net.ssl.trustStore", previousTrustStore); + restoreSystemProperty("javax.net.ssl.trustStorePassword", previousTrustStorePassword); + restoreSystemProperty("javax.net.ssl.trustStoreType", previousTrustStoreType); + HTTPRequest.setDefaultSSLSocketFactory(previousDefaultSslSocketFactory); + HTTPRequest.setDefaultHostnameVerifier(oauthHostnameVerifier()); + } + + private void refreshOAuthHttpsDefaults() throws Exception { + final TrustManagerFactory trustManagerFactory = + TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init((KeyStore) null); + final SSLContext sslContext = SSLContext.getInstance("TLS"); + sslContext.init(null, trustManagerFactory.getTrustManagers(), new SecureRandom()); + HTTPRequest.setDefaultSSLSocketFactory(sslContext.getSocketFactory()); + HTTPRequest.setDefaultHostnameVerifier(oauthHostnameVerifier()); + } + + private Path createTrustStorePath() throws Exception { + final KeyStore trustStore = KeyStore.getInstance("PKCS12"); + trustStore.load(null, null); + trustStore.setCertificateEntry("root", readCertificate(flightDataPath("root-ca.pem"))); + + final Path trustStorePath = Files.createTempFile("oauth-truststore", ".p12"); + tempPaths.add(trustStorePath); + try (OutputStream output = Files.newOutputStream(trustStorePath)) { + trustStore.store(output, TRUST_STORE_PASSWORD); + } + return trustStorePath; + } + + private SSLContext createServerSslContext() throws Exception { + final Certificate certificate = readCertificate(flightDataPath("cert0.pem")); + final PrivateKey privateKey = readPrivateKey(flightDataPath("cert0.pkcs1")); + final KeyStore keyStore = KeyStore.getInstance("PKCS12"); + keyStore.load(null, null); + keyStore.setKeyEntry( + "token-server", privateKey, TRUST_STORE_PASSWORD, new Certificate[] {certificate}); + + final KeyManagerFactory keyManagerFactory = + KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + keyManagerFactory.init(keyStore, TRUST_STORE_PASSWORD); + + final SSLContext sslContext = SSLContext.getInstance("TLS"); + sslContext.init(keyManagerFactory.getKeyManagers(), null, new SecureRandom()); + return sslContext; + } + + private static X509Certificate readCertificate(Path path) throws Exception { + try (InputStream input = Files.newInputStream(path)) { + final CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509"); + return (X509Certificate) certificateFactory.generateCertificate(input); + } + } + + private static PrivateKey readPrivateKey(Path path) throws Exception { + final String pem = Files.readString(path, StandardCharsets.US_ASCII); + final String base64 = + pem.replace("-----BEGIN PRIVATE KEY-----", "") + .replace("-----END PRIVATE KEY-----", "") + .replaceAll("\\s+", ""); + final byte[] der = Base64.getDecoder().decode(base64); + return KeyFactory.getInstance("RSA").generatePrivate(new PKCS8EncodedKeySpec(der)); + } + + private static Path flightDataPath(String filename) { + final String dataRoot = System.getProperty("arrow.test.dataRoot"); + if (dataRoot != null) { + return Paths.get(dataRoot).resolve("flight").resolve(filename); + } + return Paths.get("testing", "data", "flight", filename).toAbsolutePath().normalize(); + } + + private static void restoreSystemProperty(String key, String value) { + if (value == null) { + System.clearProperty(key); + } else { + System.setProperty(key, value); + } + } + + private HostnameVerifier oauthHostnameVerifier() { + if (previousDefaultHostnameVerifier != null) { + return previousDefaultHostnameVerifier; + } + return HttpsURLConnection.getDefaultHostnameVerifier(); + } + + private static final class TokenHandler implements HttpHandler { + private final List requestBodies = new ArrayList<>(); + private final List authorizationHeaders = new ArrayList<>(); + private String accessToken = CLIENT_ACCESS_TOKEN; + + @Override + public void handle(HttpExchange exchange) throws IOException { + final String body = + new String(exchange.getRequestBody().readAllBytes(), StandardCharsets.UTF_8); + requestBodies.add(body); + authorizationHeaders.add(exchange.getRequestHeaders().getFirst("Authorization")); + + final byte[] responseBytes = + ("{\"access_token\":\"" + + accessToken + + "\",\"token_type\":\"Bearer\",\"expires_in\":3600}") + .getBytes(StandardCharsets.UTF_8); + final Headers responseHeaders = exchange.getResponseHeaders(); + responseHeaders.add("Content-Type", "application/json"); + exchange.sendResponseHeaders(200, responseBytes.length); + try (OutputStream output = exchange.getResponseBody()) { + output.write(responseBytes); + } + } + + private String formValue(int requestIndex, String key) { + return decodeForm(requestBodies.get(requestIndex)).get(key); + } + + private static Map decodeForm(String body) { + final Map values = new LinkedHashMap<>(); + if (body.isEmpty()) { + return values; + } + for (String pair : body.split("&")) { + final String[] parts = pair.split("=", 2); + final String name = URLDecoder.decode(parts[0], StandardCharsets.UTF_8); + final String value = + parts.length > 1 ? URLDecoder.decode(parts[1], StandardCharsets.UTF_8) : ""; + values.put(name, value); + } + return values; + } + } +}