diff --git a/flashback-smartproxy/src/main/java/com/linkedin/flashback/smartproxy/FlashbackRunner.java b/flashback-smartproxy/src/main/java/com/linkedin/flashback/smartproxy/FlashbackRunner.java index b54e779..119064b 100644 --- a/flashback-smartproxy/src/main/java/com/linkedin/flashback/smartproxy/FlashbackRunner.java +++ b/flashback-smartproxy/src/main/java/com/linkedin/flashback/smartproxy/FlashbackRunner.java @@ -21,6 +21,7 @@ import io.netty.handler.codec.http.HttpRequest; import java.io.FileInputStream; import java.io.FileNotFoundException; +import java.io.IOException; import java.io.InputStream; import java.util.List; import org.apache.log4j.Logger; @@ -173,6 +174,13 @@ public ProxyModeController create(HttpRequest httpRequest) { builder._certificateAuthority); proxyServerBuilder.connectionFlow(Protocol.HTTPS, httpsConnectionFlow); } + + try { + builder._rootCertificateInputStream.close(); + } catch (IOException e) { + LOG.error("Failed to close root certificate input stream", e); + } + return proxyServerBuilder.build(); } diff --git a/mitm/src/main/java/com/linkedin/mitm/model/HttpRequestInfo.java b/mitm/src/main/java/com/linkedin/mitm/model/HttpRequestInfo.java new file mode 100644 index 0000000..03f2ad1 --- /dev/null +++ b/mitm/src/main/java/com/linkedin/mitm/model/HttpRequestInfo.java @@ -0,0 +1,56 @@ +package com.linkedin.mitm.model; + +import io.netty.handler.codec.http.HttpMethod; +import org.apache.commons.lang3.builder.EqualsBuilder; +import org.apache.commons.lang3.builder.HashCodeBuilder; + + +/** + * Holds information about the http request such as path and method. + */ +public class HttpRequestInfo extends RequestInfo { + private String _path; + private HttpMethod _method; + + public HttpRequestInfo(RequestInfo requestInfo) { + super(requestInfo.getDomain(), requestInfo.getPort()); + } + + public String getPath() { + return _path; + } + + public HttpRequestInfo setPath(String path) { + _path = path; + return this; + } + + public HttpMethod getMethod() { + return _method; + } + + public HttpRequestInfo setMethod(HttpMethod method) { + _method = method; + return this; + } + + @Override + public Protocol getProtocol() { + return Protocol.HTTP; + } + + @Override + public int hashCode() { + return HashCodeBuilder.reflectionHashCode(this); + } + + @Override + public boolean equals(Object obj) { + return EqualsBuilder.reflectionEquals(this, obj); + } + + @Override + public String toString() { + return "path=" + _path + ", method=" + _method + ", " + super.toString(); + } +} diff --git a/mitm/src/main/java/com/linkedin/mitm/model/Protocol.java b/mitm/src/main/java/com/linkedin/mitm/model/Protocol.java index 7a64902..f2e32b6 100644 --- a/mitm/src/main/java/com/linkedin/mitm/model/Protocol.java +++ b/mitm/src/main/java/com/linkedin/mitm/model/Protocol.java @@ -11,5 +11,5 @@ * @author shfeng */ public enum Protocol { - HTTP, HTTPS + HTTP, HTTPS, BINARY } diff --git a/mitm/src/main/java/com/linkedin/mitm/model/RequestInfo.java b/mitm/src/main/java/com/linkedin/mitm/model/RequestInfo.java new file mode 100644 index 0000000..39fe3b0 --- /dev/null +++ b/mitm/src/main/java/com/linkedin/mitm/model/RequestInfo.java @@ -0,0 +1,46 @@ +package com.linkedin.mitm.model; + +import org.apache.commons.lang3.builder.EqualsBuilder; +import org.apache.commons.lang3.builder.HashCodeBuilder; + + +/** + * Holds the domain and port of a request. Can be extended to hold other properties off requests once the protocol is known. + */ +public class RequestInfo { + private final String _domain; + private final int _port; + + public RequestInfo(String domain, int port) { + _domain = domain; + _port = port; + } + + public String getDomain() { + return _domain; + } + + public int getPort() { + return _port; + } + + @Override + public int hashCode() { + return HashCodeBuilder.reflectionHashCode(this); + } + + @Override + public boolean equals(Object obj) { + return EqualsBuilder.reflectionEquals(this, obj); + } + + @Override + public String toString() { + return "port=" + _port + ", domain=" + _domain; + } + + // can be overriden when extending + public Protocol getProtocol() { + return Protocol.BINARY; + } +} \ No newline at end of file diff --git a/mitm/src/main/java/com/linkedin/mitm/proxy/connectionflow/steps/HandshakeWithClient.java b/mitm/src/main/java/com/linkedin/mitm/proxy/connectionflow/steps/HandshakeWithClient.java index 1326522..193f108 100644 --- a/mitm/src/main/java/com/linkedin/mitm/proxy/connectionflow/steps/HandshakeWithClient.java +++ b/mitm/src/main/java/com/linkedin/mitm/proxy/connectionflow/steps/HandshakeWithClient.java @@ -22,6 +22,7 @@ import java.security.UnrecoverableKeyException; import java.security.cert.CertificateException; import java.util.ArrayList; +import java.util.concurrent.ConcurrentHashMap; import javax.net.ssl.SSLContext; import org.apache.log4j.Logger; import org.bouncycastle.operator.OperatorCreationException; @@ -38,6 +39,8 @@ public class HandshakeWithClient implements ConnectionFlowStep { private final CertificateKeyStoreFactory _certificateKeyStoreFactory; private final CertificateAuthority _certificateAuthority; + private final ConcurrentHashMap _sslContextCache = new ConcurrentHashMap<>(); + public HandshakeWithClient(CertificateKeyStoreFactory certificateKeyStoreFactory, CertificateAuthority certificateAuthority) { _certificateKeyStoreFactory = certificateKeyStoreFactory; @@ -46,18 +49,21 @@ public HandshakeWithClient(CertificateKeyStoreFactory certificateKeyStoreFactory @Override public Future execute(ChannelMediator channelMediator, InetSocketAddress remoteAddress) { - //dynamically create SSLEngine based on CN and SANs - LOG.debug("Starting client to proxy connection handshaking"); - try { - //TODO: if connect request only contains ip address, we need get either CA - //TODO: or SANS from server response - KeyStore keyStore = _certificateKeyStoreFactory.create(remoteAddress.getHostName(), new ArrayList<>()); - SSLContext sslContext = SSLContextGenerator.createClientContext(keyStore, _certificateAuthority.getPassPhrase()); - return channelMediator.handshakeWithClient(sslContext.createSSLEngine()); - } catch (NoSuchAlgorithmException | KeyStoreException | IOException | CertificateException | OperatorCreationException - | NoSuchProviderException | InvalidKeyException | SignatureException | KeyManagementException | UnrecoverableKeyException e) { - throw new RuntimeException("Failed to create server identity certificate", e); - } + LOG.debug("Starting client connection handshaking"); + + //TODO: if connect request only contains ip address, we need get either CA + //TODO: or SANS from server response + String remoteHost = remoteAddress.getHostName(); + SSLContext sslContext = _sslContextCache.computeIfAbsent(remoteHost, key -> { + try { + KeyStore keyStore = _certificateKeyStoreFactory.create(remoteHost, new ArrayList<>()); + return SSLContextGenerator.createClientContext(keyStore, _certificateAuthority.getPassPhrase()); + } catch (Exception e) { + throw new RuntimeException("Failed to create server identity certificate for " + remoteHost, e); + } + }); + return channelMediator.handshakeWithClient(sslContext.createSSLEngine()); } + } diff --git a/mitm/src/main/java/com/linkedin/mitm/services/AccessService.java b/mitm/src/main/java/com/linkedin/mitm/services/AccessService.java new file mode 100644 index 0000000..7a0a033 --- /dev/null +++ b/mitm/src/main/java/com/linkedin/mitm/services/AccessService.java @@ -0,0 +1,26 @@ +package com.linkedin.mitm.services; + +import com.linkedin.mitm.model.RequestInfo; + + +/** + * Is used to determine if a connection or a request is allowed by the proxy. + */ +public interface AccessService { + /** + * + * Called when a connect request is issued so we only have info about the service name, destination and port.\ + * + * @param serviceName + * @param requestInfo + * @return true or false depending on if the domain/port combination for the specified service has been whitelisted + */ + boolean isValidConnection(String serviceName, RequestInfo requestInfo); + + /** + * @param serviceName + * @param requestInfo + * @return true if any of the protocol rules defined in the config evaluates to true for the given request info. + */ + boolean isValidRequest(String serviceName, RequestInfo requestInfo); +} diff --git a/mitm/src/main/java/com/linkedin/mitm/services/HttpCustomProtocolService.java b/mitm/src/main/java/com/linkedin/mitm/services/HttpCustomProtocolService.java new file mode 100644 index 0000000..be39c93 --- /dev/null +++ b/mitm/src/main/java/com/linkedin/mitm/services/HttpCustomProtocolService.java @@ -0,0 +1,22 @@ +package com.linkedin.mitm.services; + +import io.netty.handler.codec.http.HttpRequest; + + +/** + * This service provides an opportunity for the user to handle any customized HTTP protocol that they use. Most importantly, + * we want to allow the user to have some kind of authentication mechanism for HTTP requests other than TLS. + * For example, a client may choose BEARER auth token as an authentication mechanism. + */ +public interface HttpCustomProtocolService { + + /** + * Check if the incoming request is allowed to proceed or not. + * @param request + * @return true if the request should proceed further in the proxy flow. + */ + default boolean isAllowed(HttpRequest request) { + return true; + } + +} diff --git a/mitm/src/main/java/com/linkedin/mitm/services/LongRunningTaskService.java b/mitm/src/main/java/com/linkedin/mitm/services/LongRunningTaskService.java new file mode 100644 index 0000000..d54b362 --- /dev/null +++ b/mitm/src/main/java/com/linkedin/mitm/services/LongRunningTaskService.java @@ -0,0 +1,93 @@ +package com.linkedin.mitm.services; + +import io.netty.util.concurrent.Promise; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.ReentrantLock; +import org.apache.log4j.Logger; + + +/** + * This class runs long running tasks in a separate thread pool. When a task completes, its associated callback is invoked. + */ +public class LongRunningTaskService { + private static final Logger LOG = Logger.getLogger(LongRunningTaskService.class); + + public interface LongRunningTaskCallback { + R fullfillPromise() throws Exception; + + Promise getPromise(); + } + + private static final ExecutorService EXECUTOR_SERVICE = Executors.newCachedThreadPool(); + + private static final List _callbackList = new ArrayList<>(); + + private static final ReentrantLock LOCK = new ReentrantLock(); + private static final Condition NOT_EMPTY = LOCK.newCondition(); + + static { + EXECUTOR_SERVICE.submit(() -> { + while (true) { + LOCK.lock(); + try { + while (_callbackList.isEmpty()) { + NOT_EMPTY.await(); + } + for (Iterator it = _callbackList.iterator(); it.hasNext(); ) { + LongRunningTaskCallback callback = it.next(); + Promise promise = callback.getPromise(); + try { + Object result = callback.fullfillPromise(); + promise.setSuccess(result); + } catch (Exception e) { + LOG.error("Failed to complete callback", e); + promise.setFailure(e); + } finally { + _callbackList.remove(callback); + } + } + } catch (InterruptedException e) { + LOG.debug("shutting down thread pool in LongRunningTaskService"); + return; + } finally { + LOCK.unlock(); + } + } + } + ); + } + + public static void submitTaskCallback(LongRunningTaskCallback callback) { + LOCK.lock(); + try { + _callbackList.add(callback); + NOT_EMPTY.signal(); + } finally { + LOCK.unlock(); + } + } + + public static void shutdownAndAwaitTermination() { + EXECUTOR_SERVICE.shutdown(); // Disable new tasks from being submitted + try { + // Wait a while for existing tasks to terminate + if (!EXECUTOR_SERVICE.awaitTermination(60, TimeUnit.SECONDS)) { + EXECUTOR_SERVICE.shutdownNow(); // Cancel currently executing tasks + // Wait a while for tasks to respond to being cancelled + if (!EXECUTOR_SERVICE.awaitTermination(60, TimeUnit.SECONDS)) + LOG.error("LongRunningTaskService thread pool did not terminate"); + } + } catch (InterruptedException ie) { + // (Re-)Cancel if current thread also interrupted + EXECUTOR_SERVICE.shutdownNow(); + // Preserve interrupt status + Thread.currentThread().interrupt(); + } + } +} diff --git a/mitm/src/main/java/com/linkedin/mitm/services/SSLContextGenerator.java b/mitm/src/main/java/com/linkedin/mitm/services/SSLContextGenerator.java index 28116fa..22cdedf 100644 --- a/mitm/src/main/java/com/linkedin/mitm/services/SSLContextGenerator.java +++ b/mitm/src/main/java/com/linkedin/mitm/services/SSLContextGenerator.java @@ -5,13 +5,14 @@ package com.linkedin.mitm.services; -import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import java.io.IOException; +import java.io.InputStream; import java.security.KeyManagementException; import java.security.KeyStore; import java.security.KeyStoreException; import java.security.NoSuchAlgorithmException; import java.security.SecureRandom; -import java.security.UnrecoverableKeyException; +import java.security.cert.CertificateException; import javax.net.ssl.KeyManager; import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLContext; @@ -25,32 +26,60 @@ */ public class SSLContextGenerator { private static final String SSL_CONTEXT_PROTOCOL = "TLS"; - private static final String KEY_MANAGER_TYPE = "SunX509"; private static final String TRUST_MANAGER_TYPE = "SunX509"; + private static final String KEY_STORE_TYPE = "JKS"; + private static final String CA_STORE = "/etc/riddler/cacerts"; + private static final String CA_PASSWORD = "changeit"; /** - * Create client side SSLContext {@link javax.net.ssl.SSLContext} - * - * */ - public static SSLContext createClientContext(KeyStore keyStore, char[] passphrase) - throws NoSuchAlgorithmException, KeyManagementException, KeyStoreException, UnrecoverableKeyException { + * Create SSLContext for ssl traffic between client and proxy {@link javax.net.ssl.SSLContext} + **/ + public static SSLContext createClientContext(KeyStore keyStore, char[] passphrase) throws Exception { + TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(TRUST_MANAGER_TYPE); + trustManagerFactory.init(keyStore); String keyManAlg = KeyManagerFactory.getDefaultAlgorithm(); KeyManagerFactory kmf = KeyManagerFactory.getInstance(keyManAlg); kmf.init(keyStore, passphrase); KeyManager[] keyManagers = kmf.getKeyManagers(); - return create(keyManagers, InsecureTrustManagerFactory.INSTANCE.getTrustManagers(), + + // set up trust manager factory to use riddler trust store + TrustManagerFactory + tmf = TrustManagerFactoryGenerator.newTrustManagerFactory(CA_STORE, CA_PASSWORD.toCharArray(), KEY_STORE_TYPE); + + TrustManager[] trustManagers = tmf.getTrustManagers(); + + return create(keyManagers, trustManagers, RandomNumberGenerator.getInstance().getSecureRandom()); } /** - * Create default server side SSLContext {@link javax.net.ssl.SSLContext} - * - * */ - public static SSLContext createDefaultServerContext() - throws KeyManagementException, NoSuchAlgorithmException, UnrecoverableKeyException, KeyStoreException { + * Create SSLContext for ssl traffic between proxy and destination server {@link javax.net.ssl.SSLContext} + **/ + public static SSLContext createDefaultServerContext() throws KeyManagementException, NoSuchAlgorithmException { return create(null, null, RandomNumberGenerator.getInstance().getSecureRandom()); } + /** + * Create SSLContext for ssl traffic between proxy and destination server {@link javax.net.ssl.SSLContext} + **/ + public static SSLContext createCustomServerContext(InputStream inputStream, String password) + throws KeyManagementException, NoSuchAlgorithmException, KeyStoreException { + KeyStore keyStore = KeyStore.getInstance("JKS"); + + // load default Riddler keystore + try { + keyStore.load(inputStream, password.toCharArray()); + + TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + tmf.init(keyStore); + + return create(null, tmf.getTrustManagers(), RandomNumberGenerator.getInstance().getSecureRandom()); + } catch (IOException | CertificateException e) { + e.printStackTrace(); + throw new RuntimeException(e); + } + } + private static SSLContext create(KeyManager[] keyManagers, TrustManager[] trustManagers, SecureRandom secureRandom) throws NoSuchAlgorithmException, KeyManagementException { SSLContext sslContext = SSLContext.getInstance(SSL_CONTEXT_PROTOCOL); diff --git a/mitm/src/main/java/com/linkedin/mitm/services/ServiceNameExtractor.java b/mitm/src/main/java/com/linkedin/mitm/services/ServiceNameExtractor.java new file mode 100644 index 0000000..522d328 --- /dev/null +++ b/mitm/src/main/java/com/linkedin/mitm/services/ServiceNameExtractor.java @@ -0,0 +1,8 @@ +package com.linkedin.mitm.services; + +import java.security.cert.Certificate; + + +public interface ServiceNameExtractor { + String extractServiceName(Certificate certificate); +} diff --git a/mitm/src/main/java/com/linkedin/mitm/services/TrustManagerFactoryGenerator.java b/mitm/src/main/java/com/linkedin/mitm/services/TrustManagerFactoryGenerator.java new file mode 100644 index 0000000..dbb88e2 --- /dev/null +++ b/mitm/src/main/java/com/linkedin/mitm/services/TrustManagerFactoryGenerator.java @@ -0,0 +1,24 @@ +package com.linkedin.mitm.services; + +import java.io.FileInputStream; +import java.security.KeyStore; +import javax.net.ssl.TrustManagerFactory; + + +/** + * Helper class for creating a TrustManagerFactory + */ +public class TrustManagerFactoryGenerator { + + public static TrustManagerFactory newTrustManagerFactory(String trustStorePath, char[] trustStorePwd, String keyStoreType) throws + Exception { + try (FileInputStream fileInputStream = new FileInputStream(trustStorePath)) { + KeyStore ts = KeyStore.getInstance(keyStoreType); + ts.load(fileInputStream, trustStorePwd); + + TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + tmf.init(ts); + return tmf; + } + } +} diff --git a/mitm/src/main/java/com/linkedin/mitm/store/PKC12KeyStoreReadWriter.java b/mitm/src/main/java/com/linkedin/mitm/store/PKC12KeyStoreReadWriter.java index 86a135c..57e2f8a 100644 --- a/mitm/src/main/java/com/linkedin/mitm/store/PKC12KeyStoreReadWriter.java +++ b/mitm/src/main/java/com/linkedin/mitm/store/PKC12KeyStoreReadWriter.java @@ -30,11 +30,8 @@ public class PKC12KeyStoreReadWriter implements KeyStoreReader, KeyStoreWriter { public KeyStore load(InputStream inputstream, String password) throws CertificateException, NoSuchAlgorithmException, KeyStoreException, IOException { KeyStore ksKeys = KeyStore.getInstance(KEY_STORE_TYPE); - try { - ksKeys.load(inputstream, password.toCharArray()); - } finally { - inputstream.close(); - } + ksKeys.load(inputstream, password.toCharArray()); + return ksKeys; }