Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.netty.handler.codec.http.HttpRequest;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.util.List;
import org.apache.log4j.Logger;
Expand Down Expand Up @@ -173,6 +174,13 @@ public ProxyModeController create(HttpRequest httpRequest) {
builder._certificateAuthority);
proxyServerBuilder.connectionFlow(Protocol.HTTPS, httpsConnectionFlow);
}

try {
builder._rootCertificateInputStream.close();
} catch (IOException e) {
LOG.error("Failed to close root certificate input stream", e);
}

return proxyServerBuilder.build();
}

Expand Down
56 changes: 56 additions & 0 deletions mitm/src/main/java/com/linkedin/mitm/model/HttpRequestInfo.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package com.linkedin.mitm.model;

import io.netty.handler.codec.http.HttpMethod;
import org.apache.commons.lang3.builder.EqualsBuilder;
import org.apache.commons.lang3.builder.HashCodeBuilder;


/**
* Holds information about the http request such as path and method.
*/
public class HttpRequestInfo extends RequestInfo {
private String _path;
private HttpMethod _method;

public HttpRequestInfo(RequestInfo requestInfo) {
super(requestInfo.getDomain(), requestInfo.getPort());
}

public String getPath() {
return _path;
}

public HttpRequestInfo setPath(String path) {
_path = path;
return this;
}

public HttpMethod getMethod() {
return _method;
}

public HttpRequestInfo setMethod(HttpMethod method) {
_method = method;
return this;
}

@Override
public Protocol getProtocol() {
return Protocol.HTTP;
}

@Override
public int hashCode() {
return HashCodeBuilder.reflectionHashCode(this);
}

@Override
public boolean equals(Object obj) {
return EqualsBuilder.reflectionEquals(this, obj);
}

@Override
public String toString() {
return "path=" + _path + ", method=" + _method + ", " + super.toString();
}
}
2 changes: 1 addition & 1 deletion mitm/src/main/java/com/linkedin/mitm/model/Protocol.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@
* @author shfeng
*/
public enum Protocol {
HTTP, HTTPS
HTTP, HTTPS, BINARY
}
46 changes: 46 additions & 0 deletions mitm/src/main/java/com/linkedin/mitm/model/RequestInfo.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package com.linkedin.mitm.model;

import org.apache.commons.lang3.builder.EqualsBuilder;
import org.apache.commons.lang3.builder.HashCodeBuilder;


/**
* Holds the domain and port of a request. Can be extended to hold other properties off requests once the protocol is known.
*/
public class RequestInfo {
private final String _domain;
private final int _port;

public RequestInfo(String domain, int port) {
_domain = domain;
_port = port;
}

public String getDomain() {
return _domain;
}

public int getPort() {
return _port;
}

@Override
public int hashCode() {
return HashCodeBuilder.reflectionHashCode(this);
}

@Override
public boolean equals(Object obj) {
return EqualsBuilder.reflectionEquals(this, obj);
}

@Override
public String toString() {
return "port=" + _port + ", domain=" + _domain;
}

// can be overriden when extending
public Protocol getProtocol() {
return Protocol.BINARY;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.security.UnrecoverableKeyException;
import java.security.cert.CertificateException;
import java.util.ArrayList;
import java.util.concurrent.ConcurrentHashMap;
import javax.net.ssl.SSLContext;
import org.apache.log4j.Logger;
import org.bouncycastle.operator.OperatorCreationException;
Expand All @@ -38,6 +39,8 @@ public class HandshakeWithClient implements ConnectionFlowStep {
private final CertificateKeyStoreFactory _certificateKeyStoreFactory;
private final CertificateAuthority _certificateAuthority;

private final ConcurrentHashMap<String, SSLContext> _sslContextCache = new ConcurrentHashMap<>();

public HandshakeWithClient(CertificateKeyStoreFactory certificateKeyStoreFactory,
CertificateAuthority certificateAuthority) {
_certificateKeyStoreFactory = certificateKeyStoreFactory;
Expand All @@ -46,18 +49,21 @@ public HandshakeWithClient(CertificateKeyStoreFactory certificateKeyStoreFactory

@Override
public Future execute(ChannelMediator channelMediator, InetSocketAddress remoteAddress) {

//dynamically create SSLEngine based on CN and SANs
LOG.debug("Starting client to proxy connection handshaking");
try {
//TODO: if connect request only contains ip address, we need get either CA
//TODO: or SANS from server response
KeyStore keyStore = _certificateKeyStoreFactory.create(remoteAddress.getHostName(), new ArrayList<>());
SSLContext sslContext = SSLContextGenerator.createClientContext(keyStore, _certificateAuthority.getPassPhrase());
return channelMediator.handshakeWithClient(sslContext.createSSLEngine());
} catch (NoSuchAlgorithmException | KeyStoreException | IOException | CertificateException | OperatorCreationException
| NoSuchProviderException | InvalidKeyException | SignatureException | KeyManagementException | UnrecoverableKeyException e) {
throw new RuntimeException("Failed to create server identity certificate", e);
}
LOG.debug("Starting client connection handshaking");

//TODO: if connect request only contains ip address, we need get either CA
//TODO: or SANS from server response
String remoteHost = remoteAddress.getHostName();
SSLContext sslContext = _sslContextCache.computeIfAbsent(remoteHost, key -> {
try {
KeyStore keyStore = _certificateKeyStoreFactory.create(remoteHost, new ArrayList<>());
return SSLContextGenerator.createClientContext(keyStore, _certificateAuthority.getPassPhrase());
} catch (Exception e) {
throw new RuntimeException("Failed to create server identity certificate for " + remoteHost, e);
}
});
return channelMediator.handshakeWithClient(sslContext.createSSLEngine());
}

}
26 changes: 26 additions & 0 deletions mitm/src/main/java/com/linkedin/mitm/services/AccessService.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package com.linkedin.mitm.services;

import com.linkedin.mitm.model.RequestInfo;


/**
* Is used to determine if a connection or a request is allowed by the proxy.
*/
public interface AccessService {
/**
*
* Called when a connect request is issued so we only have info about the service name, destination and port.\
*
* @param serviceName
* @param requestInfo
* @return true or false depending on if the domain/port combination for the specified service has been whitelisted
*/
boolean isValidConnection(String serviceName, RequestInfo requestInfo);

/**
* @param serviceName
* @param requestInfo
* @return true if any of the protocol rules defined in the config evaluates to true for the given request info.
*/
boolean isValidRequest(String serviceName, RequestInfo requestInfo);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package com.linkedin.mitm.services;

import io.netty.handler.codec.http.HttpRequest;


/**
* This service provides an opportunity for the user to handle any customized HTTP protocol that they use. Most importantly,
* we want to allow the user to have some kind of authentication mechanism for HTTP requests other than TLS.
* For example, a client may choose BEARER auth token as an authentication mechanism.
*/
public interface HttpCustomProtocolService {

/**
* Check if the incoming request is allowed to proceed or not.
* @param request
* @return true if the request should proceed further in the proxy flow.
*/
default boolean isAllowed(HttpRequest request) {
return true;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package com.linkedin.mitm.services;

import io.netty.util.concurrent.Promise;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;
import org.apache.log4j.Logger;


/**
* This class runs long running tasks in a separate thread pool. When a task completes, its associated callback is invoked.
*/
public class LongRunningTaskService {
private static final Logger LOG = Logger.getLogger(LongRunningTaskService.class);

public interface LongRunningTaskCallback<R> {
R fullfillPromise() throws Exception;

Promise<R> getPromise();
}

private static final ExecutorService EXECUTOR_SERVICE = Executors.newCachedThreadPool();

private static final List<LongRunningTaskCallback> _callbackList = new ArrayList<>();

private static final ReentrantLock LOCK = new ReentrantLock();
private static final Condition NOT_EMPTY = LOCK.newCondition();

static {
EXECUTOR_SERVICE.submit(() -> {
while (true) {
LOCK.lock();
try {
while (_callbackList.isEmpty()) {
NOT_EMPTY.await();
}
for (Iterator<LongRunningTaskCallback> it = _callbackList.iterator(); it.hasNext(); ) {
LongRunningTaskCallback callback = it.next();
Promise promise = callback.getPromise();
try {
Object result = callback.fullfillPromise();
promise.setSuccess(result);
} catch (Exception e) {
LOG.error("Failed to complete callback", e);
promise.setFailure(e);
} finally {
_callbackList.remove(callback);
}
}
} catch (InterruptedException e) {
LOG.debug("shutting down thread pool in LongRunningTaskService");
return;
} finally {
LOCK.unlock();
}
}
}
);
}

public static void submitTaskCallback(LongRunningTaskCallback callback) {
LOCK.lock();
try {
_callbackList.add(callback);
NOT_EMPTY.signal();
} finally {
LOCK.unlock();
}
}

public static void shutdownAndAwaitTermination() {
EXECUTOR_SERVICE.shutdown(); // Disable new tasks from being submitted
try {
// Wait a while for existing tasks to terminate
if (!EXECUTOR_SERVICE.awaitTermination(60, TimeUnit.SECONDS)) {
EXECUTOR_SERVICE.shutdownNow(); // Cancel currently executing tasks
// Wait a while for tasks to respond to being cancelled
if (!EXECUTOR_SERVICE.awaitTermination(60, TimeUnit.SECONDS))
LOG.error("LongRunningTaskService thread pool did not terminate");
}
} catch (InterruptedException ie) {
// (Re-)Cancel if current thread also interrupted
EXECUTOR_SERVICE.shutdownNow();
// Preserve interrupt status
Thread.currentThread().interrupt();
}
}
}
Loading