Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions netty-socketio-core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@
<groupId>io.netty</groupId>
<artifactId>netty-codec</artifactId>
</dependency>
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-tcnative-boringssl-static</artifactId>
<optional>true</optional>
</dependency>
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-transport-native-epoll</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
*/
package com.socketio4j.socketio;

import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.security.KeyStore;

import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;

import org.slf4j.Logger;
Expand Down Expand Up @@ -57,7 +57,11 @@
import io.netty.handler.codec.http.HttpRequestDecoder;
import io.netty.handler.codec.http.HttpResponseEncoder;
import io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketServerCompressionHandler;
import io.netty.handler.ssl.OpenSsl;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslProvider;

public class SocketIOChannelInitializer extends ChannelInitializer<Channel> implements DisconnectableHub {

Expand Down Expand Up @@ -91,7 +95,7 @@ public class SocketIOChannelInitializer extends ChannelInitializer<Channel> impl
private CancelableScheduler scheduler = new HashedWheelTimeoutScheduler();

private InPacketHandler packetHandler;
private SSLContext sslContext;
private SslContext sslContext;
private Configuration configuration;

@Override
Expand All @@ -111,7 +115,7 @@ public void start(Configuration configuration, NamespacesHub namespacesHub) {
String connectPath = configuration.getContext() + "/";

SocketSslConfig socketSslConfig = configuration.getSocketSslConfig();
boolean isSsl = socketSslConfig != null && socketSslConfig.getKeyStore() != null;
boolean isSsl = socketSslConfig != null && socketSslConfig.hasKeyStore();
if (isSsl) {
try {
sslContext = createSSLContext(socketSslConfig);
Expand Down Expand Up @@ -154,11 +158,11 @@ protected void initChannel(Channel ch) throws Exception {
*/
protected void addSslHandler(ChannelPipeline pipeline) {
if (sslContext != null) {
SSLEngine engine = sslContext.createSSLEngine();
SSLEngine engine = sslContext.newEngine(pipeline.channel().alloc());
engine.setUseClientMode(false);
if (configuration.isNeedClientAuth()
&& configuration.getSocketSslConfig() != null
&& configuration.getSocketSslConfig().getTrustStore() != null) {
&& configuration.getSocketSslConfig().hasTrustStore()) {
engine.setNeedClientAuth(true);
}
pipeline.addLast(SSL_HANDLER, new SslHandler(engine));
Expand Down Expand Up @@ -200,26 +204,62 @@ protected Object newContinueResponse(HttpMessage start, int maxContentLength,
pipeline.addLast(WRONG_URL_HANDLER, wrongUrlHandler);
}

private SSLContext createSSLContext(SocketSslConfig socketSslConfig) throws Exception {
TrustManager[] managers = null;

if (socketSslConfig.getTrustStore() != null) {
KeyStore ts = KeyStore.getInstance(socketSslConfig.getTrustStoreFormat());
ts.load(socketSslConfig.getTrustStore(), socketSslConfig.getTrustStorePassword().toCharArray());
TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
tmf.init(ts);
managers = tmf.getTrustManagers();
private SslContext createSSLContext(SocketSslConfig socketSslConfig) throws Exception {
byte[] keyMaterial = socketSslConfig.resolveKeyStoreBytes();
if (keyMaterial == null) {
throw new IllegalStateException("SocketSslConfig key store material is missing");
}

KeyStore ks = KeyStore.getInstance(socketSslConfig.getKeyStoreFormat());
ks.load(socketSslConfig.getKeyStore(), socketSslConfig.getKeyStorePassword().toCharArray());
try (InputStream keyStoreStream = new ByteArrayInputStream(keyMaterial)) {
ks.load(keyStoreStream, socketSslConfig.getKeyStorePassword().toCharArray());
}

KeyManagerFactory kmf = KeyManagerFactory.getInstance(socketSslConfig.getKeyManagerFactoryAlgorithm());
kmf.init(ks, socketSslConfig.getKeyStorePassword().toCharArray());

SSLContext serverContext = SSLContext.getInstance(socketSslConfig.getSSLProtocol());
serverContext.init(kmf.getKeyManagers(), managers, null);
return serverContext;
SslProvider sslProvider;
if (OpenSsl.isAvailable()) {
sslProvider = SslProvider.OPENSSL;
} else {
sslProvider = SslProvider.JDK;
}

SslContextBuilder builder = SslContextBuilder.forServer(kmf).sslProvider(sslProvider);
String sslProtocol = socketSslConfig.getSSLProtocol();
if (sslProtocol != null) {
// SocketSslConfig historically accepted SSLContext algorithm names like "TLS".
// SslContextBuilder.protocols(...) expects concrete enabled protocol versions.
if (isTlsProtocolVersion(sslProtocol)) {
builder.protocols(sslProtocol);
} else {
log.warn("Ignoring SocketSslConfig.sslProtocol='{}' because it is not a concrete TLS protocol " +
"version (expected values like 'TLSv1.2' or 'TLSv1.3'). Using provider defaults instead.",
sslProtocol);
}
}
if (socketSslConfig.hasTrustStore()) {
byte[] trustMaterial = socketSslConfig.resolveTrustStoreBytes();
if (trustMaterial == null) {
throw new IllegalStateException("SocketSslConfig trust store material is missing");
}
TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
KeyStore ts = KeyStore.getInstance(socketSslConfig.getTrustStoreFormat());
try (InputStream trustStoreStream = new ByteArrayInputStream(trustMaterial)) {
ts.load(trustStoreStream, socketSslConfig.getTrustStorePassword().toCharArray());
}
tmf.init(ts);
builder.trustManager(tmf);
}
return builder.build();
}

private static boolean isTlsProtocolVersion(String value) {
// Common enabled-protocol tokens used by JSSE/Netty.
// Accept "TLSv1", "TLSv1.1", "TLSv1.2", "TLSv1.3"; reject "TLSv1.0" and other dotted minors.
if (value == null) {
return false;
}
return value.matches("^TLSv1(\\.(1|2|3))?$");
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,16 @@ public Future<Void> startAsync() {
if (future.isSuccess()) {
ChannelFuture cf = (ChannelFuture) future;
serverChannel.set(cf.channel());
if (configCopy.getPort() == 0) {
try {
InetSocketAddress local = (InetSocketAddress) cf.channel().localAddress();
int actualPort = local.getPort();
configCopy.setPort(actualPort);
configuration.setPort(actualPort);
} catch (Exception ignore) {
// keep configured port if localAddress is not InetSocketAddress
}
}
serverStatus.set(ServerStatus.STARTED);
log.info("SocketIO server started on port {}", configCopy.getPort());
installShutdownHookOnce();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
*/
package com.socketio4j.socketio;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;

import javax.net.ssl.KeyManagerFactory;
Expand All @@ -31,6 +33,10 @@ public class SocketSslConfig {
private InputStream trustStore;
private String trustStorePassword;

private final Object sslMaterialLock = new Object();
private byte[] cachedKeyStoreBytes;
private byte[] cachedTrustStoreBytes;

private String keyManagerFactoryAlgorithm = KeyManagerFactory.getDefaultAlgorithm();

/**
Expand All @@ -47,7 +53,12 @@ public String getKeyStorePassword() {
}

/**
* SSL key store stream, maybe appointed to any source
* SSL key store stream, maybe appointed to any source.
* <p>
* On the first TLS context build when the server starts, the stream is read fully into memory and closed;
* later start/stop cycles reuse the buffered bytes so the same {@code SocketSslConfig} instance remains valid.
* After buffering, {@link #getKeyStore()} returns {@code null}.
* </p>
*
* @param keyStore - key store input stream
*/
Expand All @@ -59,6 +70,37 @@ public InputStream getKeyStore() {
return keyStore;
}

/**
* Whether a key store is configured (stream not yet consumed or already buffered).
*/
public boolean hasKeyStore() {
synchronized (sslMaterialLock) {
return keyStore != null || cachedKeyStoreBytes != null;
}
}

byte[] resolveKeyStoreBytes() throws IOException {
synchronized (sslMaterialLock) {
if (cachedKeyStoreBytes != null) {
return cachedKeyStoreBytes;
}
if (keyStore == null) {
return null;
}
try (InputStream in = keyStore) {
ByteArrayOutputStream out = new ByteArrayOutputStream();
byte[] buffer = new byte[4096];
int read;
while ((read = in.read(buffer)) != -1) {
out.write(buffer, 0, read);
}
cachedKeyStoreBytes = out.toByteArray();
}
keyStore = null;
return cachedKeyStoreBytes;
}
}

/**
* Key store format
*
Expand All @@ -85,10 +127,46 @@ public InputStream getTrustStore() {
return trustStore;
}

/**
* Trust store stream. Same buffering and lifecycle as {@link #setKeyStore(InputStream)}.
*
* @param trustStore trust store input stream
*/
public void setTrustStore(InputStream trustStore) {
this.trustStore = trustStore;
}

/**
* Whether a trust store is configured (stream not yet consumed or already buffered).
*/
public boolean hasTrustStore() {
synchronized (sslMaterialLock) {
return trustStore != null || cachedTrustStoreBytes != null;
}
}

byte[] resolveTrustStoreBytes() throws IOException {
synchronized (sslMaterialLock) {
if (cachedTrustStoreBytes != null) {
return cachedTrustStoreBytes;
}
if (trustStore == null) {
return null;
}
try (InputStream in = trustStore) {
ByteArrayOutputStream out = new ByteArrayOutputStream();
byte[] buffer = new byte[4096];
int read;
while ((read = in.read(buffer)) != -1) {
out.write(buffer, 0, read);
}
cachedTrustStoreBytes = out.toByteArray();
}
trustStore = null;
return cachedTrustStoreBytes;
}
}

public String getTrustStorePassword() {
return trustStorePassword;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,12 @@
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpRequestDecoder;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.QueryStringDecoder;
import io.netty.handler.codec.http.websocketx.WebSocket13FrameDecoder;

import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;

Expand Down Expand Up @@ -172,7 +175,19 @@ private void onPost(UUID sessionId, ChannelHandlerContext ctx, String origin, By
content = decoder.preprocessJson(jsonIndex, content);
}

ctx.pipeline().fireChannelRead(new PacketsMessage(client, content, Transport.POLLING));
ChannelHandlerContext codecCtx = ctx.pipeline().context(HttpRequestDecoder.class);
if (codecCtx == null) {
codecCtx = ctx.pipeline().context(WebSocket13FrameDecoder.class);
}
if (codecCtx == null) {
codecCtx = ctx.pipeline().context(HttpServerCodec.class);
}
PacketsMessage packetsMessage = new PacketsMessage(client, content, Transport.POLLING);
if (codecCtx != null) {
codecCtx.fireChannelRead(packetsMessage);
} else {
ctx.pipeline().fireChannelRead(packetsMessage);
}
}

protected void onGet(UUID sessionId, ChannelHandlerContext ctx, String origin) {
Expand Down
Loading
Loading