diff --git a/netty-socketio-core/pom.xml b/netty-socketio-core/pom.xml index d297376d..45bfc6c3 100644 --- a/netty-socketio-core/pom.xml +++ b/netty-socketio-core/pom.xml @@ -80,6 +80,11 @@ io.netty netty-codec + + io.netty + netty-tcnative-boringssl-static + true + io.netty netty-transport-native-epoll diff --git a/netty-socketio-core/src/main/java/com/socketio4j/socketio/SocketIOChannelInitializer.java b/netty-socketio-core/src/main/java/com/socketio4j/socketio/SocketIOChannelInitializer.java index dbffb960..6170c013 100644 --- a/netty-socketio-core/src/main/java/com/socketio4j/socketio/SocketIOChannelInitializer.java +++ b/netty-socketio-core/src/main/java/com/socketio4j/socketio/SocketIOChannelInitializer.java @@ -16,12 +16,12 @@ */ package com.socketio4j.socketio; +import java.io.ByteArrayInputStream; +import java.io.InputStream; import java.security.KeyStore; import javax.net.ssl.KeyManagerFactory; -import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; -import javax.net.ssl.TrustManager; import javax.net.ssl.TrustManagerFactory; import org.slf4j.Logger; @@ -57,7 +57,11 @@ import io.netty.handler.codec.http.HttpRequestDecoder; import io.netty.handler.codec.http.HttpResponseEncoder; import io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketServerCompressionHandler; +import io.netty.handler.ssl.OpenSsl; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.ssl.SslHandler; +import io.netty.handler.ssl.SslProvider; public class SocketIOChannelInitializer extends ChannelInitializer implements DisconnectableHub { @@ -91,7 +95,7 @@ public class SocketIOChannelInitializer extends ChannelInitializer impl private CancelableScheduler scheduler = new HashedWheelTimeoutScheduler(); private InPacketHandler packetHandler; - private SSLContext sslContext; + private SslContext sslContext; private Configuration configuration; @Override @@ -111,7 +115,7 @@ public void start(Configuration configuration, NamespacesHub namespacesHub) { String connectPath = configuration.getContext() + "/"; SocketSslConfig socketSslConfig = configuration.getSocketSslConfig(); - boolean isSsl = socketSslConfig != null && socketSslConfig.getKeyStore() != null; + boolean isSsl = socketSslConfig != null && socketSslConfig.hasKeyStore(); if (isSsl) { try { sslContext = createSSLContext(socketSslConfig); @@ -154,11 +158,11 @@ protected void initChannel(Channel ch) throws Exception { */ protected void addSslHandler(ChannelPipeline pipeline) { if (sslContext != null) { - SSLEngine engine = sslContext.createSSLEngine(); + SSLEngine engine = sslContext.newEngine(pipeline.channel().alloc()); engine.setUseClientMode(false); if (configuration.isNeedClientAuth() && configuration.getSocketSslConfig() != null - && configuration.getSocketSslConfig().getTrustStore() != null) { + && configuration.getSocketSslConfig().hasTrustStore()) { engine.setNeedClientAuth(true); } pipeline.addLast(SSL_HANDLER, new SslHandler(engine)); @@ -200,26 +204,62 @@ protected Object newContinueResponse(HttpMessage start, int maxContentLength, pipeline.addLast(WRONG_URL_HANDLER, wrongUrlHandler); } - private SSLContext createSSLContext(SocketSslConfig socketSslConfig) throws Exception { - TrustManager[] managers = null; - - if (socketSslConfig.getTrustStore() != null) { - KeyStore ts = KeyStore.getInstance(socketSslConfig.getTrustStoreFormat()); - ts.load(socketSslConfig.getTrustStore(), socketSslConfig.getTrustStorePassword().toCharArray()); - TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); - tmf.init(ts); - managers = tmf.getTrustManagers(); + private SslContext createSSLContext(SocketSslConfig socketSslConfig) throws Exception { + byte[] keyMaterial = socketSslConfig.resolveKeyStoreBytes(); + if (keyMaterial == null) { + throw new IllegalStateException("SocketSslConfig key store material is missing"); } - KeyStore ks = KeyStore.getInstance(socketSslConfig.getKeyStoreFormat()); - ks.load(socketSslConfig.getKeyStore(), socketSslConfig.getKeyStorePassword().toCharArray()); + try (InputStream keyStoreStream = new ByteArrayInputStream(keyMaterial)) { + ks.load(keyStoreStream, socketSslConfig.getKeyStorePassword().toCharArray()); + } KeyManagerFactory kmf = KeyManagerFactory.getInstance(socketSslConfig.getKeyManagerFactoryAlgorithm()); kmf.init(ks, socketSslConfig.getKeyStorePassword().toCharArray()); - SSLContext serverContext = SSLContext.getInstance(socketSslConfig.getSSLProtocol()); - serverContext.init(kmf.getKeyManagers(), managers, null); - return serverContext; + SslProvider sslProvider; + if (OpenSsl.isAvailable()) { + sslProvider = SslProvider.OPENSSL; + } else { + sslProvider = SslProvider.JDK; + } + + SslContextBuilder builder = SslContextBuilder.forServer(kmf).sslProvider(sslProvider); + String sslProtocol = socketSslConfig.getSSLProtocol(); + if (sslProtocol != null) { + // SocketSslConfig historically accepted SSLContext algorithm names like "TLS". + // SslContextBuilder.protocols(...) expects concrete enabled protocol versions. + if (isTlsProtocolVersion(sslProtocol)) { + builder.protocols(sslProtocol); + } else { + log.warn("Ignoring SocketSslConfig.sslProtocol='{}' because it is not a concrete TLS protocol " + + "version (expected values like 'TLSv1.2' or 'TLSv1.3'). Using provider defaults instead.", + sslProtocol); + } + } + if (socketSslConfig.hasTrustStore()) { + byte[] trustMaterial = socketSslConfig.resolveTrustStoreBytes(); + if (trustMaterial == null) { + throw new IllegalStateException("SocketSslConfig trust store material is missing"); + } + TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + KeyStore ts = KeyStore.getInstance(socketSslConfig.getTrustStoreFormat()); + try (InputStream trustStoreStream = new ByteArrayInputStream(trustMaterial)) { + ts.load(trustStoreStream, socketSslConfig.getTrustStorePassword().toCharArray()); + } + tmf.init(ts); + builder.trustManager(tmf); + } + return builder.build(); + } + + private static boolean isTlsProtocolVersion(String value) { + // Common enabled-protocol tokens used by JSSE/Netty. + // Accept "TLSv1", "TLSv1.1", "TLSv1.2", "TLSv1.3"; reject "TLSv1.0" and other dotted minors. + if (value == null) { + return false; + } + return value.matches("^TLSv1(\\.(1|2|3))?$"); } @Override diff --git a/netty-socketio-core/src/main/java/com/socketio4j/socketio/SocketIOServer.java b/netty-socketio-core/src/main/java/com/socketio4j/socketio/SocketIOServer.java index 62fc8486..b10d233b 100644 --- a/netty-socketio-core/src/main/java/com/socketio4j/socketio/SocketIOServer.java +++ b/netty-socketio-core/src/main/java/com/socketio4j/socketio/SocketIOServer.java @@ -598,6 +598,16 @@ public Future startAsync() { if (future.isSuccess()) { ChannelFuture cf = (ChannelFuture) future; serverChannel.set(cf.channel()); + if (configCopy.getPort() == 0) { + try { + InetSocketAddress local = (InetSocketAddress) cf.channel().localAddress(); + int actualPort = local.getPort(); + configCopy.setPort(actualPort); + configuration.setPort(actualPort); + } catch (Exception ignore) { + // keep configured port if localAddress is not InetSocketAddress + } + } serverStatus.set(ServerStatus.STARTED); log.info("SocketIO server started on port {}", configCopy.getPort()); installShutdownHookOnce(); diff --git a/netty-socketio-core/src/main/java/com/socketio4j/socketio/SocketSslConfig.java b/netty-socketio-core/src/main/java/com/socketio4j/socketio/SocketSslConfig.java index 6b76b8fc..93640aa3 100644 --- a/netty-socketio-core/src/main/java/com/socketio4j/socketio/SocketSslConfig.java +++ b/netty-socketio-core/src/main/java/com/socketio4j/socketio/SocketSslConfig.java @@ -16,6 +16,8 @@ */ package com.socketio4j.socketio; +import java.io.ByteArrayOutputStream; +import java.io.IOException; import java.io.InputStream; import javax.net.ssl.KeyManagerFactory; @@ -31,6 +33,10 @@ public class SocketSslConfig { private InputStream trustStore; private String trustStorePassword; + private final Object sslMaterialLock = new Object(); + private byte[] cachedKeyStoreBytes; + private byte[] cachedTrustStoreBytes; + private String keyManagerFactoryAlgorithm = KeyManagerFactory.getDefaultAlgorithm(); /** @@ -47,7 +53,12 @@ public String getKeyStorePassword() { } /** - * SSL key store stream, maybe appointed to any source + * SSL key store stream, maybe appointed to any source. + *

+ * On the first TLS context build when the server starts, the stream is read fully into memory and closed; + * later start/stop cycles reuse the buffered bytes so the same {@code SocketSslConfig} instance remains valid. + * After buffering, {@link #getKeyStore()} returns {@code null}. + *

* * @param keyStore - key store input stream */ @@ -59,6 +70,37 @@ public InputStream getKeyStore() { return keyStore; } + /** + * Whether a key store is configured (stream not yet consumed or already buffered). + */ + public boolean hasKeyStore() { + synchronized (sslMaterialLock) { + return keyStore != null || cachedKeyStoreBytes != null; + } + } + + byte[] resolveKeyStoreBytes() throws IOException { + synchronized (sslMaterialLock) { + if (cachedKeyStoreBytes != null) { + return cachedKeyStoreBytes; + } + if (keyStore == null) { + return null; + } + try (InputStream in = keyStore) { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + byte[] buffer = new byte[4096]; + int read; + while ((read = in.read(buffer)) != -1) { + out.write(buffer, 0, read); + } + cachedKeyStoreBytes = out.toByteArray(); + } + keyStore = null; + return cachedKeyStoreBytes; + } + } + /** * Key store format * @@ -85,10 +127,46 @@ public InputStream getTrustStore() { return trustStore; } + /** + * Trust store stream. Same buffering and lifecycle as {@link #setKeyStore(InputStream)}. + * + * @param trustStore trust store input stream + */ public void setTrustStore(InputStream trustStore) { this.trustStore = trustStore; } + /** + * Whether a trust store is configured (stream not yet consumed or already buffered). + */ + public boolean hasTrustStore() { + synchronized (sslMaterialLock) { + return trustStore != null || cachedTrustStoreBytes != null; + } + } + + byte[] resolveTrustStoreBytes() throws IOException { + synchronized (sslMaterialLock) { + if (cachedTrustStoreBytes != null) { + return cachedTrustStoreBytes; + } + if (trustStore == null) { + return null; + } + try (InputStream in = trustStore) { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + byte[] buffer = new byte[4096]; + int read; + while ((read = in.read(buffer)) != -1) { + out.write(buffer, 0, read); + } + cachedTrustStoreBytes = out.toByteArray(); + } + trustStore = null; + return cachedTrustStoreBytes; + } + } + public String getTrustStorePassword() { return trustStorePassword; } diff --git a/netty-socketio-core/src/main/java/com/socketio4j/socketio/transport/PollingTransport.java b/netty-socketio-core/src/main/java/com/socketio4j/socketio/transport/PollingTransport.java index 0aef5412..1e34aeae 100644 --- a/netty-socketio-core/src/main/java/com/socketio4j/socketio/transport/PollingTransport.java +++ b/netty-socketio-core/src/main/java/com/socketio4j/socketio/transport/PollingTransport.java @@ -43,9 +43,12 @@ import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpRequestDecoder; import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpServerCodec; import io.netty.handler.codec.http.QueryStringDecoder; +import io.netty.handler.codec.http.websocketx.WebSocket13FrameDecoder; import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1; @@ -172,7 +175,19 @@ private void onPost(UUID sessionId, ChannelHandlerContext ctx, String origin, By content = decoder.preprocessJson(jsonIndex, content); } - ctx.pipeline().fireChannelRead(new PacketsMessage(client, content, Transport.POLLING)); + ChannelHandlerContext codecCtx = ctx.pipeline().context(HttpRequestDecoder.class); + if (codecCtx == null) { + codecCtx = ctx.pipeline().context(WebSocket13FrameDecoder.class); + } + if (codecCtx == null) { + codecCtx = ctx.pipeline().context(HttpServerCodec.class); + } + PacketsMessage packetsMessage = new PacketsMessage(client, content, Transport.POLLING); + if (codecCtx != null) { + codecCtx.fireChannelRead(packetsMessage); + } else { + ctx.pipeline().fireChannelRead(packetsMessage); + } } protected void onGet(UUID sessionId, ChannelHandlerContext ctx, String origin) { diff --git a/netty-socketio-core/src/main/java/com/socketio4j/socketio/transport/WebSocketTransport.java b/netty-socketio-core/src/main/java/com/socketio4j/socketio/transport/WebSocketTransport.java index 40524fd4..d1fb1609 100644 --- a/netty-socketio-core/src/main/java/com/socketio4j/socketio/transport/WebSocketTransport.java +++ b/netty-socketio-core/src/main/java/com/socketio4j/socketio/transport/WebSocketTransport.java @@ -46,10 +46,16 @@ import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpRequestDecoder; +import io.netty.handler.codec.http.HttpServerCodec; import io.netty.handler.codec.http.QueryStringDecoder; import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; +import io.netty.handler.codec.http.websocketx.PingWebSocketFrame; +import io.netty.handler.codec.http.websocketx.PongWebSocketFrame; import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocket13FrameDecoder; +import io.netty.handler.codec.http.websocketx.WebSocketFrame; import io.netty.handler.codec.http.websocketx.WebSocketFrameAggregator; import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker; import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory; @@ -82,6 +88,15 @@ public WebSocketTransport(boolean isSsl, public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { if (msg instanceof CloseWebSocketFrame) { ctx.channel().writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE); + } else if (msg instanceof PingWebSocketFrame) { + // keep connection alive, mirror pong + ctx.channel().writeAndFlush(new PongWebSocketFrame(((PingWebSocketFrame) msg).content().retain())); + ((PingWebSocketFrame) msg).release(); + return; + } else if (msg instanceof PongWebSocketFrame) { + // ignore + ((PongWebSocketFrame) msg).release(); + return; } else if (msg instanceof BinaryWebSocketFrame || msg instanceof TextWebSocketFrame) { ByteBufHolder frame = (ByteBufHolder) msg; @@ -97,10 +112,18 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception // Retain its content since we pass it further down the pipeline. PacketsMessage packetsMessage = new PacketsMessage(client, frame.content().retain(), Transport.WEBSOCKET); try { - ctx.pipeline().fireChannelRead(packetsMessage); + firePacketsMessageToPacketHandler(ctx, packetsMessage); } finally { frame.release(); } + } else if (msg instanceof WebSocketFrame) { + // Some clients may send fragmented frames (ContinuationWebSocketFrame) or other control frames. + // Log and release to avoid leaks and to surface missing handling. + if (log.isDebugEnabled()) { + log.debug("Unhandled WebSocketFrame type: {}", msg.getClass().getName()); + } + ((WebSocketFrame) msg).release(); + return; } else if (msg instanceof FullHttpRequest) { FullHttpRequest req = (FullHttpRequest) msg; QueryStringDecoder queryDecoder = new QueryStringDecoder(req.uri()); @@ -141,8 +164,28 @@ public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { ClientHead client = clientsBox.get(ctx.channel()); if (client != null && client.isTransportChannel(ctx.channel(), Transport.WEBSOCKET)) { ctx.flush(); + } + super.channelReadComplete(ctx); + } + + /** + * Deliver engine/socket.io payload to {@link com.socketio4j.socketio.handler.InPacketHandler} without + * re-entering the pipeline from the head (avoids passing {@link PacketsMessage} through {@code SslHandler} + * and the HTTP/WebSocket frame decoder again). + */ + private static void firePacketsMessageToPacketHandler(ChannelHandlerContext ctx, PacketsMessage packetsMessage) { + // After WebSocket handshake, Netty replaces HttpRequestDecoder with WebSocket13FrameDecoder named "wsdecoder". + ChannelHandlerContext codecCtx = ctx.pipeline().context(HttpRequestDecoder.class); + if (codecCtx == null) { + codecCtx = ctx.pipeline().context(WebSocket13FrameDecoder.class); + } + if (codecCtx == null) { + codecCtx = ctx.pipeline().context(HttpServerCodec.class); + } + if (codecCtx != null) { + codecCtx.fireChannelRead(packetsMessage); } else { - super.channelReadComplete(ctx); + ctx.pipeline().fireChannelRead(packetsMessage); } } @@ -168,7 +211,7 @@ private void handshake(ChannelHandlerContext ctx, final UUID sessionId, String p final Channel channel = ctx.channel(); WebSocketServerHandshakerFactory factory = - new WebSocketServerHandshakerFactory(getWebSocketLocation(req), null, true, configuration.getMaxFramePayloadLength()); + new WebSocketServerHandshakerFactory(getWebSocketLocation(req), null, configuration.isWebsocketCompression(), configuration.getMaxFramePayloadLength()); WebSocketServerHandshaker handshaker = factory.newHandshaker(req); if (handshaker != null) { try { diff --git a/netty-socketio-core/src/test/java/com/socketio4j/socketio/SocketSslServerRestartTest.java b/netty-socketio-core/src/test/java/com/socketio4j/socketio/SocketSslServerRestartTest.java new file mode 100644 index 00000000..7322c7b7 --- /dev/null +++ b/netty-socketio-core/src/test/java/com/socketio4j/socketio/SocketSslServerRestartTest.java @@ -0,0 +1,72 @@ +/** + * Copyright (c) 2025 The Socketio4j Project + * Parent project : Copyright (c) 2012-2025 Nikita Koksharov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.socketio4j.socketio; + +import java.io.InputStream; +import java.util.concurrent.TimeUnit; + +import org.junit.jupiter.api.Test; + +import com.socketio4j.socketio.nativeio.TransportType; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Ensures TLS material from {@link SocketSslConfig} survives stop/start when streams are not reusable. + */ +public class SocketSslServerRestartTest { + + @Test + public void shouldStartStopStartWithSameSocketSslConfig() throws Exception { + Configuration cfg = new Configuration(); + cfg.setPort(0); + cfg.setOrigin("*"); + cfg.setTransportType(TransportType.NIO); + + SocketSslConfig ssl = new SocketSslConfig(); + ssl.setSSLProtocol("TLSv1.2"); + ssl.setKeyStoreFormat("PKCS12"); + ssl.setKeyStorePassword("password"); + InputStream ks = SocketSslServerRestartTest.class.getClassLoader() + .getResourceAsStream("ssl/test-socketio.p12"); + assertNotNull(ks, "Missing test keystore ssl/test-socketio.p12"); + ssl.setKeyStore(ks); + + cfg.setSocketSslConfig(ssl); + + SocketIOServer server = new SocketIOServer(cfg); + server.start(); + int port = awaitBoundPort(server); + assertTrue(port > 0); + server.stop(); + + assertDoesNotThrow(server::start, "second start should rebuild SSL from buffered keystore bytes"); + server.stop(); + } + + private static int awaitBoundPort(SocketIOServer server) throws InterruptedException { + long deadlineNs = System.nanoTime() + TimeUnit.SECONDS.toNanos(5); + int port = server.getConfiguration().getPort(); + while (port == 0 && System.nanoTime() < deadlineNs) { + Thread.sleep(10); + port = server.getConfiguration().getPort(); + } + return port; + } +} diff --git a/netty-socketio-core/src/test/java/com/socketio4j/socketio/transport/SocketIoJavaClientSslTest.java b/netty-socketio-core/src/test/java/com/socketio4j/socketio/transport/SocketIoJavaClientSslTest.java new file mode 100644 index 00000000..ab5a262d --- /dev/null +++ b/netty-socketio-core/src/test/java/com/socketio4j/socketio/transport/SocketIoJavaClientSslTest.java @@ -0,0 +1,469 @@ +/** + * Copyright (c) 2025 The Socketio4j Project + * Parent project : Copyright (c) 2012-2025 Nikita Koksharov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.socketio4j.socketio.transport; + +import java.io.InputStream; +import java.security.SecureRandom; +import java.security.cert.X509Certificate; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.TrustManager; +import javax.net.ssl.X509TrustManager; + +import org.json.JSONObject; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +import com.socketio4j.socketio.Configuration; +import com.socketio4j.socketio.SocketIOServer; +import com.socketio4j.socketio.SocketSslConfig; +import com.socketio4j.socketio.nativeio.TransportType; + +import io.socket.client.Ack; +import io.socket.client.IO; +import io.socket.client.Socket; +import okhttp3.OkHttpClient; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * End-to-end tests using the official Java {@code socket.io-client} (OkHttp/WebSocket). + */ +public class SocketIoJavaClientSslTest { + + private SocketIOServer server; + + @AfterEach + public void tearDown() { + if (server != null) { + server.stop(); + server = null; + } + } + + @Test + public void shouldReceiveHelloEventAndAckOverWssFromJavaClient() throws Exception { + CountDownLatch serverReceivedHello = new CountDownLatch(1); + CountDownLatch clientReceivedAck = new CountDownLatch(1); + CountDownLatch engineHandshakeDone = new CountDownLatch(1); + AtomicReference connectError = new AtomicReference<>(); + + server = startServer(0, testSslConfig(), serverReceivedHello); + int port = awaitBoundPort(server); + assertTrue(port > 0, "server did not bind an ephemeral port"); + + X509TrustManager trustAll = new X509TrustManager() { + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType) { + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType) { + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return new X509Certificate[0]; + } + }; + SSLContext sslContext = SSLContext.getInstance("TLS"); + sslContext.init(null, new TrustManager[] { trustAll }, new SecureRandom()); + + OkHttpClient okHttp = new OkHttpClient.Builder() + .sslSocketFactory(sslContext.getSocketFactory(), trustAll) + .hostnameVerifier((hostname, session) -> true) + .readTimeout(1, TimeUnit.MINUTES) + .build(); + + IO.Options opts = new IO.Options(); + opts.forceNew = true; + opts.reconnection = false; + opts.transports = new String[] { "websocket" }; + opts.webSocketFactory = okHttp; + opts.callFactory = okHttp; + + Socket socket = IO.socket("https://127.0.0.1:" + port, opts); + try { + socket.on(Socket.EVENT_CONNECT_ERROR, args -> { + Object first = args.length > 0 ? args[0] : null; + if (first instanceof Throwable) { + connectError.set((Throwable) first); + } else { + connectError.set(new IllegalStateException(String.valueOf(first))); + } + engineHandshakeDone.countDown(); + }); + socket.on(Socket.EVENT_CONNECT, args -> { + try { + JSONObject payload = new JSONObject(); + payload.put("a", 1); + socket.emit("hello", payload, (Ack) ackArgs -> { + if (ackArgs.length > 0 && "ok".equals(String.valueOf(ackArgs[0]))) { + clientReceivedAck.countDown(); + } + }); + } catch (Exception e) { + connectError.set(e); + } finally { + engineHandshakeDone.countDown(); + } + }); + socket.connect(); + + assertTrue(engineHandshakeDone.await(20, TimeUnit.SECONDS), + () -> "Engine.IO handshake did not complete: " + connectError.get()); + assertNull(connectError.get(), () -> "connect_error: " + connectError.get()); + assertTrue(serverReceivedHello.await(15, TimeUnit.SECONDS), "server did not receive hello event"); + assertTrue(clientReceivedAck.await(15, TimeUnit.SECONDS), "client did not receive ack"); + } finally { + socket.disconnect(); + } + } + + /** + * Exercises polling first (POST body via {@code PollingTransport.onPost}), then upgrade to WebSocket. + * Server pushes an event right after connect so delivery runs while the client may still be on polling. + */ + @Test + public void shouldPollUpgradeToWebSocketWithServerPushAndClientHelloAckOverWss() throws Exception { + CountDownLatch serverReceivedHello = new CountDownLatch(1); + CountDownLatch clientReceivedAck = new CountDownLatch(1); + CountDownLatch clientReceivedWelcome = new CountDownLatch(1); + CountDownLatch engineHandshakeDone = new CountDownLatch(1); + AtomicReference connectError = new AtomicReference<>(); + AtomicBoolean unexpectedDisconnect = new AtomicBoolean(false); + + server = startServer(0, testSslConfig(), serverReceivedHello, true); + int port = awaitBoundPort(server); + assertTrue(port > 0, "server did not bind an ephemeral port"); + + X509TrustManager trustAll = new X509TrustManager() { + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType) { + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType) { + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return new X509Certificate[0]; + } + }; + SSLContext sslContext = SSLContext.getInstance("TLS"); + sslContext.init(null, new TrustManager[] { trustAll }, new SecureRandom()); + + OkHttpClient okHttp = new OkHttpClient.Builder() + .sslSocketFactory(sslContext.getSocketFactory(), trustAll) + .hostnameVerifier((hostname, session) -> true) + .readTimeout(1, TimeUnit.MINUTES) + .build(); + + IO.Options opts = new IO.Options(); + opts.forceNew = true; + opts.reconnection = false; + opts.transports = new String[] { "polling", "websocket" }; + opts.webSocketFactory = okHttp; + opts.callFactory = okHttp; + + Socket socket = IO.socket("https://127.0.0.1:" + port, opts); + try { + socket.on(Socket.EVENT_DISCONNECT, args -> unexpectedDisconnect.set(true)); + socket.on(Socket.EVENT_CONNECT_ERROR, args -> { + Object first = args.length > 0 ? args[0] : null; + if (first instanceof Throwable) { + connectError.set((Throwable) first); + } else { + connectError.set(new IllegalStateException(String.valueOf(first))); + } + engineHandshakeDone.countDown(); + }); + socket.on("welcome", args -> clientReceivedWelcome.countDown()); + socket.on(Socket.EVENT_CONNECT, args -> { + try { + JSONObject payload = new JSONObject(); + payload.put("a", 1); + socket.emit("hello", payload, (Ack) ackArgs -> { + if (ackArgs.length > 0 && "ok".equals(String.valueOf(ackArgs[0]))) { + clientReceivedAck.countDown(); + } + }); + } catch (Exception e) { + connectError.set(e); + } finally { + engineHandshakeDone.countDown(); + } + }); + socket.connect(); + + assertTrue(engineHandshakeDone.await(20, TimeUnit.SECONDS), + () -> "Engine.IO handshake did not complete: " + connectError.get()); + assertNull(connectError.get(), () -> "connect_error: " + connectError.get()); + assertFalse(unexpectedDisconnect.get(), "client disconnected before assertions"); + assertTrue(clientReceivedWelcome.await(15, TimeUnit.SECONDS), "client did not receive server welcome on polling/upgrade path"); + assertTrue(serverReceivedHello.await(15, TimeUnit.SECONDS), "server did not receive hello event"); + assertTrue(clientReceivedAck.await(15, TimeUnit.SECONDS), "client did not receive ack"); + assertFalse(unexpectedDisconnect.get(), "client disconnected during polling upgrade or ack"); + } finally { + socket.disconnect(); + } + } + + @Test + public void shouldReceiveHelloEventAndAckOverPlainWebSocketFromJavaClient() throws Exception { + CountDownLatch serverReceivedHello = new CountDownLatch(1); + CountDownLatch clientReceivedAck = new CountDownLatch(1); + CountDownLatch engineHandshakeDone = new CountDownLatch(1); + AtomicReference connectError = new AtomicReference<>(); + + server = startServer(0, null, serverReceivedHello); + int port = awaitBoundPort(server); + assertTrue(port > 0, "server did not bind an ephemeral port"); + + IO.Options opts = new IO.Options(); + opts.forceNew = true; + opts.reconnection = false; + opts.transports = new String[] { "websocket" }; + + Socket socket = IO.socket("http://127.0.0.1:" + port, opts); + try { + socket.on(Socket.EVENT_CONNECT_ERROR, args -> { + Object first = args.length > 0 ? args[0] : null; + if (first instanceof Throwable) { + connectError.set((Throwable) first); + } else { + connectError.set(new IllegalStateException(String.valueOf(first))); + } + engineHandshakeDone.countDown(); + }); + socket.on(Socket.EVENT_CONNECT, args -> { + try { + JSONObject payload = new JSONObject(); + payload.put("a", 1); + socket.emit("hello", payload, (Ack) ackArgs -> { + if (ackArgs.length > 0 && "ok".equals(String.valueOf(ackArgs[0]))) { + clientReceivedAck.countDown(); + } + }); + } catch (Exception e) { + connectError.set(e); + } finally { + engineHandshakeDone.countDown(); + } + }); + socket.connect(); + + assertTrue(engineHandshakeDone.await(20, TimeUnit.SECONDS), + () -> "Engine.IO handshake did not complete: " + connectError.get()); + assertNull(connectError.get(), () -> "connect_error: " + connectError.get()); + assertTrue(serverReceivedHello.await(15, TimeUnit.SECONDS), "server did not receive hello event"); + assertTrue(clientReceivedAck.await(15, TimeUnit.SECONDS), "client did not receive ack"); + } finally { + socket.disconnect(); + } + } + + @Test + public void shouldReceiveHelloEventAndAckOverPollingFromJavaClient() throws Exception { + CountDownLatch serverReceivedHello = new CountDownLatch(1); + CountDownLatch clientReceivedAck = new CountDownLatch(1); + CountDownLatch engineHandshakeDone = new CountDownLatch(1); + AtomicReference connectError = new AtomicReference<>(); + + server = startServer(0, testSslConfig(), serverReceivedHello); + int port = awaitBoundPort(server); + assertTrue(port > 0, "server did not bind an ephemeral port"); + + X509TrustManager trustAll = new X509TrustManager() { + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType) { + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType) { + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return new X509Certificate[0]; + } + }; + SSLContext sslContext = SSLContext.getInstance("TLS"); + sslContext.init(null, new TrustManager[] { trustAll }, new SecureRandom()); + + OkHttpClient okHttp = new OkHttpClient.Builder() + .sslSocketFactory(sslContext.getSocketFactory(), trustAll) + .hostnameVerifier((hostname, session) -> true) + .readTimeout(1, TimeUnit.MINUTES) + .build(); + + IO.Options opts = new IO.Options(); + opts.forceNew = true; + opts.reconnection = false; + opts.transports = new String[] { "polling" }; + opts.webSocketFactory = okHttp; + opts.callFactory = okHttp; + + Socket socket = IO.socket("https://127.0.0.1:" + port, opts); + try { + socket.on(Socket.EVENT_CONNECT_ERROR, args -> { + Object first = args.length > 0 ? args[0] : null; + if (first instanceof Throwable) { + connectError.set((Throwable) first); + } else { + connectError.set(new IllegalStateException(String.valueOf(first))); + } + engineHandshakeDone.countDown(); + }); + socket.on(Socket.EVENT_CONNECT, args -> { + try { + JSONObject payload = new JSONObject(); + payload.put("a", 1); + socket.emit("hello", payload, (Ack) ackArgs -> { + if (ackArgs.length > 0 && "ok".equals(String.valueOf(ackArgs[0]))) { + clientReceivedAck.countDown(); + } + }); + } catch (Exception e) { + connectError.set(e); + } finally { + engineHandshakeDone.countDown(); + } + }); + socket.connect(); + + assertTrue(engineHandshakeDone.await(20, TimeUnit.SECONDS), + () -> "Engine.IO handshake did not complete: " + connectError.get()); + assertNull(connectError.get(), () -> "connect_error: " + connectError.get()); + assertTrue(serverReceivedHello.await(15, TimeUnit.SECONDS), "server did not receive hello event"); + assertTrue(clientReceivedAck.await(15, TimeUnit.SECONDS), "client did not receive ack"); + } finally { + socket.disconnect(); + } + } + + @Test + public void shouldReceiveHelloEventAndAckOverPlainPollingFromJavaClient() throws Exception { + CountDownLatch serverReceivedHello = new CountDownLatch(1); + CountDownLatch clientReceivedAck = new CountDownLatch(1); + CountDownLatch engineHandshakeDone = new CountDownLatch(1); + AtomicReference connectError = new AtomicReference<>(); + + server = startServer(0, null, serverReceivedHello); + int port = awaitBoundPort(server); + assertTrue(port > 0, "server did not bind an ephemeral port"); + + IO.Options opts = new IO.Options(); + opts.forceNew = true; + opts.reconnection = false; + opts.transports = new String[] { "polling" }; + + Socket socket = IO.socket("http://127.0.0.1:" + port, opts); + try { + socket.on(Socket.EVENT_CONNECT_ERROR, args -> { + Object first = args.length > 0 ? args[0] : null; + if (first instanceof Throwable) { + connectError.set((Throwable) first); + } else { + connectError.set(new IllegalStateException(String.valueOf(first))); + } + engineHandshakeDone.countDown(); + }); + socket.on(Socket.EVENT_CONNECT, args -> { + try { + JSONObject payload = new JSONObject(); + payload.put("a", 1); + socket.emit("hello", payload, (Ack) ackArgs -> { + if (ackArgs.length > 0 && "ok".equals(String.valueOf(ackArgs[0]))) { + clientReceivedAck.countDown(); + } + }); + } catch (Exception e) { + connectError.set(e); + } finally { + engineHandshakeDone.countDown(); + } + }); + socket.connect(); + + assertTrue(engineHandshakeDone.await(20, TimeUnit.SECONDS), + () -> "Engine.IO handshake did not complete: " + connectError.get()); + assertNull(connectError.get(), () -> "connect_error: " + connectError.get()); + assertTrue(serverReceivedHello.await(15, TimeUnit.SECONDS), "server did not receive hello event"); + assertTrue(clientReceivedAck.await(15, TimeUnit.SECONDS), "client did not receive ack"); + } finally { + socket.disconnect(); + } + } + + private SocketIOServer startServer(int port, SocketSslConfig ssl, CountDownLatch hello) { + return startServer(port, ssl, hello, false); + } + + private SocketIOServer startServer(int port, SocketSslConfig ssl, CountDownLatch hello, boolean sendWelcomeOnConnect) { + Configuration cfg = new Configuration(); + cfg.setPort(port); + cfg.setOrigin("*"); + if (ssl != null) { + cfg.setSocketSslConfig(ssl); + } + cfg.setTransportType(TransportType.NIO); + + SocketIOServer s = new SocketIOServer(cfg); + s.addEventListener("hello", Map.class, (client, data, ackSender) -> { + hello.countDown(); + ackSender.sendAckData("ok"); + }); + if (sendWelcomeOnConnect) { + s.addConnectListener(client -> client.sendEvent("welcome", "from-server")); + } + s.start(); + return s; + } + + private SocketSslConfig testSslConfig() throws Exception { + SocketSslConfig ssl = new SocketSslConfig(); + ssl.setSSLProtocol("TLSv1.2"); + ssl.setKeyStoreFormat("PKCS12"); + ssl.setKeyStorePassword("password"); + + InputStream ks = SocketIoJavaClientSslTest.class.getClassLoader() + .getResourceAsStream("ssl/test-socketio.p12"); + assertNotNull(ks, "Missing test keystore resource ssl/test-socketio.p12"); + ssl.setKeyStore(ks); + return ssl; + } + + private static int awaitBoundPort(SocketIOServer server) throws InterruptedException { + long deadlineNs = System.nanoTime() + TimeUnit.SECONDS.toNanos(5); + int port = server.getConfiguration().getPort(); + while (port == 0 && System.nanoTime() < deadlineNs) { + Thread.sleep(10); + port = server.getConfiguration().getPort(); + } + return port; + } +} diff --git a/netty-socketio-core/src/test/resources/ssl/test-socketio.p12 b/netty-socketio-core/src/test/resources/ssl/test-socketio.p12 new file mode 100644 index 00000000..1a471d2a Binary files /dev/null and b/netty-socketio-core/src/test/resources/ssl/test-socketio.p12 differ diff --git a/netty-socketio-examples/netty-socketio-examples-spring-boot-base/pom.xml b/netty-socketio-examples/netty-socketio-examples-spring-boot-base/pom.xml index e1c52cc5..d8fa47a3 100644 --- a/netty-socketio-examples/netty-socketio-examples-spring-boot-base/pom.xml +++ b/netty-socketio-examples/netty-socketio-examples-spring-boot-base/pom.xml @@ -15,7 +15,7 @@ NettySocketIO Spring Boot Examples - 4.1.119.Final + 4.2.9.Final diff --git a/netty-socketio-spring-boot-starter/src/test/java/com/socketio4j/socketio/test/spring/boot/starter/config/SocketIOOriginConfigurationTest.java b/netty-socketio-spring-boot-starter/src/test/java/com/socketio4j/socketio/test/spring/boot/starter/config/SocketIOOriginConfigurationTest.java index 4918911e..186c078a 100644 --- a/netty-socketio-spring-boot-starter/src/test/java/com/socketio4j/socketio/test/spring/boot/starter/config/SocketIOOriginConfigurationTest.java +++ b/netty-socketio-spring-boot-starter/src/test/java/com/socketio4j/socketio/test/spring/boot/starter/config/SocketIOOriginConfigurationTest.java @@ -34,10 +34,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; @DisplayName("Test for Socket.IO configuration properties") public class SocketIOOriginConfigurationTest extends BaseSpringApplicationTest { - private static final int PORT = 9090; + private static final int PORT = 19090; private static final int MAX_HEADER_SIZE = 1024; private static final boolean TCP_KEEP_ALIVE = true; @@ -149,7 +150,8 @@ public void testSocketConfigProperties() { @Test @DisplayName("Test SSL configuration properties") public void testSslConfigProperties() { - assertNotNull(nettySocketIOSslConfigProperties.getKeyStore(), "Key store should be loaded"); + assertTrue(nettySocketIOSslConfigProperties.hasKeyStore(), + "Key store should be configured (stream may be null after TLS material is buffered on server start)"); assertNotNull(nettySocketIOSslConfigProperties.getKeyStorePassword(), "Key store password should be loaded"); SocketSslConfig socketSslConfig = new SocketSslConfig(); diff --git a/pom.xml b/pom.xml index a2917e19..da8838fb 100644 --- a/pom.xml +++ b/pom.xml @@ -66,6 +66,7 @@ UTF-8 2.0.3 4.2.9.Final + 2.0.74.Final 1.50 1.18.4 6.0.2 @@ -278,6 +279,11 @@ netty-codec ${netty.version}
+ + io.netty + netty-tcnative-boringssl-static + ${netty.tcnative.version} + io.netty netty-transport-native-epoll