diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java index 98f764132fe..ef50dd82bad 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java @@ -176,6 +176,18 @@ public static OkHttpChannelBuilder forTarget(String target, ChannelCredentials c return new OkHttpChannelBuilder(target, creds, result.callCredentials, result.factory); } + private static ConnectionSpec connectionSpecFromChannelCredentials(ChannelCredentials channelCredentials) { + if (channelCredentials instanceof SslSocketFactoryChannelCredentials.ChannelCredentials) { + return ((SslSocketFactoryChannelCredentials.ChannelCredentials) channelCredentials).getConnectionSpec(); + } else if (channelCredentials instanceof CompositeChannelCredentials) { + return connectionSpecFromChannelCredentials( + ((CompositeChannelCredentials) channelCredentials).getChannelCredentials() + ); + } else { + return null; + } + } + private ObjectPool transportExecutorPool = DEFAULT_TRANSPORT_EXECUTOR_POOL; private ObjectPool scheduledExecutorServicePool = SharedResourcePool.forResource(GrpcUtil.TIMER_SERVICE); @@ -222,6 +234,10 @@ private OkHttpChannelBuilder(String target) { this.negotiationType = factory == null ? NegotiationType.PLAINTEXT : NegotiationType.TLS; this.freezeSecurityConfiguration = true; this.channelCredentials = channelCreds; + ConnectionSpec connectionSpec = connectionSpecFromChannelCredentials(channelCreds); + if (connectionSpec != null) { + this.connectionSpec = connectionSpec; + } } private final class OkHttpChannelTransportFactoryBuilder diff --git a/okhttp/src/main/java/io/grpc/okhttp/SslSocketFactoryChannelCredentials.java b/okhttp/src/main/java/io/grpc/okhttp/SslSocketFactoryChannelCredentials.java index 059a0972e49..965d4c75918 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/SslSocketFactoryChannelCredentials.java +++ b/okhttp/src/main/java/io/grpc/okhttp/SslSocketFactoryChannelCredentials.java @@ -18,6 +18,8 @@ import com.google.common.base.Preconditions; import io.grpc.ExperimentalApi; +import io.grpc.okhttp.internal.ConnectionSpec; + import javax.net.ssl.SSLSocketFactory; /** A credential with full control over the SSLSocketFactory. */ @@ -29,18 +31,43 @@ public static io.grpc.ChannelCredentials create(SSLSocketFactory factory) { return new ChannelCredentials(factory); } + public static io.grpc.ChannelCredentials create( + SSLSocketFactory factory, com.squareup.okhttp.ConnectionSpec connectionSpec) { + return new ChannelCredentials(factory, Utils.convertSpec(connectionSpec)); + } + + public static io.grpc.ChannelCredentials create( + SSLSocketFactory factory, String[] tlsVersions, String[] cipherSuiteList, boolean supportsTlsExtensions) { + ConnectionSpec connectionSpec = new ConnectionSpec.Builder(true) + .tlsVersions(tlsVersions) + .cipherSuites(cipherSuiteList) + .supportsTlsExtensions(supportsTlsExtensions) + .build(); + return new ChannelCredentials(factory, connectionSpec); + } + // Hide implementation detail of how these credentials operate static final class ChannelCredentials extends io.grpc.ChannelCredentials { private final SSLSocketFactory factory; + private final ConnectionSpec connectionSpec; - private ChannelCredentials(SSLSocketFactory factory) { + ChannelCredentials(SSLSocketFactory factory) { + this(factory, OkHttpChannelBuilder.INTERNAL_DEFAULT_CONNECTION_SPEC); + } + + ChannelCredentials(SSLSocketFactory factory, ConnectionSpec connectionSpec) { this.factory = Preconditions.checkNotNull(factory, "factory"); + this.connectionSpec = Preconditions.checkNotNull(connectionSpec, "connectionSpec"); } public SSLSocketFactory getFactory() { return factory; } + public ConnectionSpec getConnectionSpec() { + return connectionSpec; + } + @Override public io.grpc.ChannelCredentials withoutBearerTokens() { return this; diff --git a/okhttp/src/main/java/io/grpc/okhttp/SslSocketFactoryServerCredentials.java b/okhttp/src/main/java/io/grpc/okhttp/SslSocketFactoryServerCredentials.java index 63c6f33ff79..22f650de494 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/SslSocketFactoryServerCredentials.java +++ b/okhttp/src/main/java/io/grpc/okhttp/SslSocketFactoryServerCredentials.java @@ -35,6 +35,16 @@ public static io.grpc.ServerCredentials create( return new ServerCredentials(factory, Utils.convertSpec(connectionSpec)); } + public static io.grpc.ServerCredentials create( + SSLSocketFactory factory, String[] tlsVersions, String[] cipherSuiteList, boolean supportsTlsExtensions) { + ConnectionSpec connectionSpec = new ConnectionSpec.Builder(true) + .tlsVersions(tlsVersions) + .cipherSuites(cipherSuiteList) + .supportsTlsExtensions(supportsTlsExtensions) + .build(); + return new ServerCredentials(factory, connectionSpec); + } + // Hide implementation detail of how these credentials operate static final class ServerCredentials extends io.grpc.ServerCredentials { private final SSLSocketFactory factory; diff --git a/okhttp/src/main/java/io/grpc/okhttp/TlsServerHandshakerSocketFactory.java b/okhttp/src/main/java/io/grpc/okhttp/TlsServerHandshakerSocketFactory.java index c375d6246cc..d0554f3824d 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/TlsServerHandshakerSocketFactory.java +++ b/okhttp/src/main/java/io/grpc/okhttp/TlsServerHandshakerSocketFactory.java @@ -26,6 +26,7 @@ import java.io.IOException; import java.net.Socket; import java.util.Arrays; +import java.util.Collections; import javax.net.ssl.SSLSocket; import javax.net.ssl.SSLSocketFactory; @@ -58,7 +59,7 @@ public HandshakeResult handshake(Socket socket, Attributes attributes) throws IO String negotiatedProtocol = OkHttpProtocolNegotiator.get().negotiate( sslSocket, null, - connectionSpec.supportsTlsExtensions() ? Arrays.asList(expectedProtocol) : null); + connectionSpec.supportsTlsExtensions() ? Collections.singletonList(expectedProtocol) : null); if (!expectedProtocol.toString().equals(negotiatedProtocol)) { throw new IOException("Expected NPN/ALPN " + expectedProtocol + ": " + negotiatedProtocol); } diff --git a/okhttp/src/main/java/io/grpc/okhttp/Utils.java b/okhttp/src/main/java/io/grpc/okhttp/Utils.java index 4546143cf3b..29d3698e918 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/Utils.java +++ b/okhttp/src/main/java/io/grpc/okhttp/Utils.java @@ -75,15 +75,25 @@ static ConnectionSpec convertSpec(com.squareup.okhttp.ConnectionSpec spec) { Preconditions.checkArgument(spec.isTls(), "plaintext ConnectionSpec is not accepted"); List tlsVersionList = spec.tlsVersions(); - String[] tlsVersions = new String[tlsVersionList.size()]; - for (int i = 0; i < tlsVersions.length; i++) { - tlsVersions[i] = tlsVersionList.get(i).javaName(); + String[] tlsVersions; + if (tlsVersionList != null) { + tlsVersions = new String[tlsVersionList.size()]; + for (int i = 0; i < tlsVersions.length; i++) { + tlsVersions[i] = tlsVersionList.get(i).javaName(); + } + } else { + tlsVersions = null; } List cipherSuiteList = spec.cipherSuites(); - CipherSuite[] cipherSuites = new CipherSuite[cipherSuiteList.size()]; - for (int i = 0; i < cipherSuites.length; i++) { - cipherSuites[i] = CipherSuite.valueOf(cipherSuiteList.get(i).name()); + CipherSuite[] cipherSuites; + if (cipherSuiteList != null) { + cipherSuites = new CipherSuite[cipherSuiteList.size()]; + for (int i = 0; i < cipherSuites.length; i++) { + cipherSuites[i] = CipherSuite.valueOf(cipherSuiteList.get(i).name()); + } + } else { + cipherSuites = null; } return new ConnectionSpec.Builder(spec.isTls()) diff --git a/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/ConnectionSpec.java b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/ConnectionSpec.java index b84a1ff94ee..e892320aa7e 100644 --- a/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/ConnectionSpec.java +++ b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/ConnectionSpec.java @@ -128,7 +128,11 @@ public boolean supportsTlsExtensions() { public void apply(SSLSocket sslSocket, boolean isFallback) { ConnectionSpec specToApply = supportedSpec(sslSocket, isFallback); - sslSocket.setEnabledProtocols(specToApply.tlsVersions); + String[] tlsVersionsToEnable = specToApply.tlsVersions; + // null means "use default set". + if (tlsVersionsToEnable != null) { + sslSocket.setEnabledProtocols(tlsVersionsToEnable); + } String[] cipherSuitesToEnable = specToApply.cipherSuites; // null means "use default set". @@ -169,8 +173,12 @@ private ConnectionSpec supportedSpec(SSLSocket sslSocket, boolean isFallback) { } } - String[] protocolsToSelectFrom = sslSocket.getEnabledProtocols(); - String[] protocolsToEnable = Util.intersect(String.class, tlsVersions, protocolsToSelectFrom); + String[] protocolsToEnable = null; + if (tlsVersions != null) { + String[] protocolsToSelectFrom = sslSocket.getEnabledProtocols(); + protocolsToEnable = Util.intersect(String.class, tlsVersions, protocolsToSelectFrom); + } + return new Builder(this) .cipherSuites(cipherSuitesToEnable) .tlsVersions(protocolsToEnable)