diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java index e1cf22a612ea4..bb13c138b263f 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java @@ -82,25 +82,16 @@ public void addToChannel(Channel ch) throws GeneralSecurityException { @VisibleForTesting class EncryptionHandler extends ChannelOutboundHandlerAdapter { - private final ByteBuffer plaintextBuffer; - private final ByteBuffer ciphertextBuffer; private final AesGcmHkdfStreaming aesGcmHkdfStreaming; EncryptionHandler() throws InvalidAlgorithmParameterException { aesGcmHkdfStreaming = getAesGcmHkdfStreaming(); - plaintextBuffer = ByteBuffer.allocate(aesGcmHkdfStreaming.getPlaintextSegmentSize()); - ciphertextBuffer = ByteBuffer.allocate(aesGcmHkdfStreaming.getCiphertextSegmentSize()); } @Override public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { - GcmEncryptedMessage encryptedMessage = new GcmEncryptedMessage( - aesGcmHkdfStreaming, - msg, - plaintextBuffer, - ciphertextBuffer); - ctx.write(encryptedMessage, promise); + ctx.write(new GcmEncryptedMessage(aesGcmHkdfStreaming, msg), promise); } } @@ -116,15 +107,15 @@ static class GcmEncryptedMessage extends AbstractFileRegion { private final long encryptedCount; GcmEncryptedMessage(AesGcmHkdfStreaming aesGcmHkdfStreaming, - Object plaintextMessage, - ByteBuffer plaintextBuffer, - ByteBuffer ciphertextBuffer) throws GeneralSecurityException { + Object plaintextMessage) throws GeneralSecurityException { JavaUtils.checkArgument( plaintextMessage instanceof ByteBuf || plaintextMessage instanceof FileRegion, "Unrecognized message type: %s", plaintextMessage.getClass().getName()); this.plaintextMessage = plaintextMessage; - 
this.plaintextBuffer = plaintextBuffer; - this.ciphertextBuffer = ciphertextBuffer; + this.plaintextBuffer = + ByteBuffer.allocate(aesGcmHkdfStreaming.getPlaintextSegmentSize()); + this.ciphertextBuffer = + ByteBuffer.allocate(aesGcmHkdfStreaming.getCiphertextSegmentSize()); // If the ciphertext buffer cannot be fully written the target, transferTo may // return with it containing some unwritten data. The initial call we'll explicitly // set its limit to 0 to indicate the first call to transferTo. @@ -289,7 +280,8 @@ class DecryptionHandler extends ChannelInboundHandlerAdapter { private final ByteBuffer headerBuffer; private final ByteBuffer ciphertextBuffer; private final AesGcmHkdfStreaming aesGcmHkdfStreaming; - private final StreamSegmentDecrypter decrypter; + private StreamSegmentDecrypter decrypter; + private final int headerLength; private final int plaintextSegmentSize; private boolean decrypterInit = false; private boolean completed = false; @@ -299,17 +291,46 @@ class DecryptionHandler extends ChannelInboundHandlerAdapter { DecryptionHandler() throws GeneralSecurityException { aesGcmHkdfStreaming = getAesGcmHkdfStreaming(); + headerLength = aesGcmHkdfStreaming.getHeaderLength(); expectedLengthBuffer = ByteBuffer.allocate(LENGTH_HEADER_BYTES); - headerBuffer = ByteBuffer.allocate(aesGcmHkdfStreaming.getHeaderLength()); + headerBuffer = ByteBuffer.allocate(headerLength); ciphertextBuffer = ByteBuffer.allocate(aesGcmHkdfStreaming.getCiphertextSegmentSize()); decrypter = aesGcmHkdfStreaming.newStreamSegmentDecrypter(); plaintextSegmentSize = aesGcmHkdfStreaming.getPlaintextSegmentSize(); } - private boolean initalizeExpectedLength(ByteBuf ciphertextNettyBuf) { + /** + * Resets all per-message state so that the next incoming GCM message can be + * decoded through the same channel handler instance. 
This must be called after + * every successfully completed message because AesGcmHkdfStreaming is a one-shot + * streaming primitive: each encrypted message carries its own random IV and must + * be decrypted with a fresh StreamSegmentDecrypter. + */ + private void resetForNextMessage() throws GeneralSecurityException { + expectedLength = -1; + expectedLengthBuffer.clear(); + headerBuffer.clear(); + ciphertextBuffer.clear(); + decrypterInit = false; + completed = false; + segmentNumber = 0; + ciphertextRead = 0; + decrypter = aesGcmHkdfStreaming.newStreamSegmentDecrypter(); + } + + private boolean initializeExpectedLength(ByteBuf ciphertextNettyBuf) { if (expectedLength < 0) { - ciphertextNettyBuf.readBytes(expectedLengthBuffer); + // ByteBuf.readBytes(ByteBuffer) throws if fewer than dst.remaining() bytes + // are available, so temporarily narrow the limit to what is actually present. + int toRead = Math.min(ciphertextNettyBuf.readableBytes(), + expectedLengthBuffer.remaining()); + if (toRead > 0) { + int savedLimit = expectedLengthBuffer.limit(); + expectedLengthBuffer.limit(expectedLengthBuffer.position() + toRead); + ciphertextNettyBuf.readBytes(expectedLengthBuffer); + expectedLengthBuffer.limit(savedLimit); + } if (expectedLengthBuffer.hasRemaining()) { // We did not read enough bytes to initialize the expected length. return false; @@ -324,12 +345,22 @@ private boolean initalizeExpectedLength(ByteBuf ciphertextNettyBuf) { return true; } - private boolean initalizeDecrypter(ByteBuf ciphertextNettyBuf) + private boolean initializeDecrypter(ByteBuf ciphertextNettyBuf) throws GeneralSecurityException { // Check if the ciphertext header has been read. This contains // the IV and other internal metadata. if (!decrypterInit) { - ciphertextNettyBuf.readBytes(headerBuffer); + // ByteBuf.readBytes(ByteBuffer) throws if fewer than dst.remaining() bytes + // are available. 
Under TCP fragmentation the header can arrive in multiple + // chunks, so temporarily narrow the limit to what is actually present. + int toRead = Math.min(ciphertextNettyBuf.readableBytes(), + headerBuffer.remaining()); + if (toRead > 0) { + int savedLimit = headerBuffer.limit(); + headerBuffer.limit(headerBuffer.position() + toRead); + ciphertextNettyBuf.readBytes(headerBuffer); + headerBuffer.limit(savedLimit); + } if (headerBuffer.hasRemaining()) { // We did not read enough bytes to initialize the header. return false; @@ -338,7 +369,7 @@ private boolean initalizeDecrypter(ByteBuf ciphertextNettyBuf) byte[] lengthAad = Longs.toByteArray(expectedLength); decrypter.init(headerBuffer, lengthAad); decrypterInit = true; - ciphertextRead += aesGcmHkdfStreaming.getHeaderLength(); + ciphertextRead += headerLength; if (expectedLength == ciphertextRead) { // If the expected length is just the header, the ciphertext is 0 length. completed = true; @@ -354,56 +385,75 @@ public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage) "Unrecognized message type: %s", ciphertextMessage.getClass().getName()); ByteBuf ciphertextNettyBuf = (ByteBuf) ciphertextMessage; - // The format of the output is: + // The format of each message is: // [8 byte length][Internal IV and header][Ciphertext][Auth Tag] + // + // A single channelRead() call may deliver bytes from multiple back-to-back + // GCM messages (common under shuffle load when TCP coalesces writes). The + // outer loop processes as many complete messages as possible from the buffer + // before releasing it, so that bytes belonging to the next message are never + // discarded mid-stream. try { - if (!initalizeExpectedLength(ciphertextNettyBuf)) { - // We have not read enough bytes to initialize the expected length. - return; - } - if (!initalizeDecrypter(ciphertextNettyBuf)) { - // We have not read enough bytes to initialize a header, needed to - // initialize a decrypter. 
- return; - } - int nettyBufReadableBytes = ciphertextNettyBuf.readableBytes(); - while (nettyBufReadableBytes > 0 && !completed) { - // Read the ciphertext into the local buffer - int readableBytes = Integer.min( - nettyBufReadableBytes, - ciphertextBuffer.remaining()); - int expectedRemaining = (int) (expectedLength - ciphertextRead); - int bytesToRead = Integer.min(readableBytes, expectedRemaining); - // The smallest ciphertext size is 16 bytes for the auth tag - ciphertextBuffer.limit(ciphertextBuffer.position() + bytesToRead); - ciphertextNettyBuf.readBytes(ciphertextBuffer); - ciphertextRead += bytesToRead; - // Check if this is the last segment - if (ciphertextRead == expectedLength) { - completed = true; - } else if (ciphertextRead > expectedLength) { - throw new IllegalStateException("Read more ciphertext than expected."); + while (true) { + if (!initializeExpectedLength(ciphertextNettyBuf)) { + // We have not read enough bytes to initialize the expected length. + break; + } + if (!initializeDecrypter(ciphertextNettyBuf)) { + // We have not read enough bytes to initialize a header, needed to + // initialize a decrypter. 
+ break; + } + int nettyBufReadableBytes = ciphertextNettyBuf.readableBytes(); + while (nettyBufReadableBytes > 0 && !completed) { + // Read the ciphertext into the local buffer + int readableBytes = Math.min( + nettyBufReadableBytes, + ciphertextBuffer.remaining()); + int expectedRemaining = (int) (expectedLength - ciphertextRead); + int bytesToRead = Math.min(readableBytes, expectedRemaining); + // The smallest ciphertext size is 16 bytes for the auth tag + ciphertextBuffer.limit(ciphertextBuffer.position() + bytesToRead); + ciphertextNettyBuf.readBytes(ciphertextBuffer); + ciphertextRead += bytesToRead; + // Check if this is the last segment + if (ciphertextRead == expectedLength) { + completed = true; + } else if (ciphertextRead > expectedLength) { + throw new IllegalStateException("Read more ciphertext than expected."); + } + // If the ciphertext buffer is full, or this is the last segment, + // then decrypt it and fire a read. + if (ciphertextBuffer.limit() == ciphertextBuffer.capacity() || completed) { + ByteBuffer plaintextBuffer = ByteBuffer.allocate(plaintextSegmentSize); + ciphertextBuffer.flip(); + decrypter.decryptSegment( + ciphertextBuffer, + segmentNumber, + completed, + plaintextBuffer); + segmentNumber++; + // Clear the ciphertext buffer because it's been read + ciphertextBuffer.clear(); + plaintextBuffer.flip(); + ctx.fireChannelRead(Unpooled.wrappedBuffer(plaintextBuffer)); + } else { + // Set the ciphertext buffer up to read the next chunk + ciphertextBuffer.limit(ciphertextBuffer.capacity()); + } + nettyBufReadableBytes = ciphertextNettyBuf.readableBytes(); + } + if (!completed) { + // Partial message: more bytes needed from the next channelRead() call. + break; } - // If the ciphertext buffer is full, or this is the last segment, - // then decrypt it and fire a read. 
- if (ciphertextBuffer.limit() == ciphertextBuffer.capacity() || completed) { - ByteBuffer plaintextBuffer = ByteBuffer.allocate(plaintextSegmentSize); - ciphertextBuffer.flip(); - decrypter.decryptSegment( - ciphertextBuffer, - segmentNumber, - completed, - plaintextBuffer); - segmentNumber++; - // Clear the ciphertext buffer because it's been read - ciphertextBuffer.clear(); - plaintextBuffer.flip(); - ctx.fireChannelRead(Unpooled.wrappedBuffer(plaintextBuffer)); - } else { - // Set the ciphertext buffer up to read the next chunk - ciphertextBuffer.limit(ciphertextBuffer.capacity()); + // Current message is fully decoded. Reset state so the handler can + // decode the next independent GCM message on the same channel. + resetForNextMessage(); + if (ciphertextNettyBuf.readableBytes() == 0) { + break; } - nettyBufReadableBytes = ciphertextNettyBuf.readableBytes(); + // Remaining bytes may belong to another message; loop to process them. } } finally { ciphertextNettyBuf.release(); diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/GcmAuthEngineSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/GcmAuthEngineSuite.java index 20efb8d57dcbf..cb7d6681822a0 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/crypto/GcmAuthEngineSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/GcmAuthEngineSuite.java @@ -290,6 +290,221 @@ public void testGcmUnalignedDecryption() throws Exception { } } + /** + * Verifies that the same DecryptionHandler instance correctly decodes multiple independent + * GCM-encrypted messages sent over the same channel. This is the regression test for the + * bug where DecryptionHandler.completed was never reset, causing every message after the + * first to be silently dropped — which manifested as YARN container launch failures. 
+ */ + @Test + public void testMultipleMessages() throws Exception { + TransportConf gcmConf = getConf(2, false); + try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf); + AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher cipher = server.sessionCipher(); + assert (cipher instanceof GcmTransportCipher); + GcmTransportCipher gcmTransportCipher = (GcmTransportCipher) cipher; + + GcmTransportCipher.EncryptionHandler encryptionHandler = + gcmTransportCipher.getEncryptionHandler(); + GcmTransportCipher.DecryptionHandler decryptionHandler = + gcmTransportCipher.getDecryptionHandler(); + + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ChannelPromise promise = mock(ChannelPromise.class); + + // --- First message --- + byte[] data1 = new byte[1024]; + Arrays.fill(data1, (byte) 'A'); + ArgumentCaptor<GcmTransportCipher.GcmEncryptedMessage> captor1 = + ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); + encryptionHandler.write(ctx, Unpooled.wrappedBuffer(data1), promise); + verify(ctx).write(captor1.capture(), eq(promise)); + ByteBuffer ct1 = ByteBuffer.allocate((int) captor1.getValue().count()); + captor1.getValue().transferTo(new ByteBufferWriteableChannel(ct1), 0); + ct1.flip(); + + ArgumentCaptor<ByteBuf> plaintextCaptor1 = ArgumentCaptor.forClass(ByteBuf.class); + decryptionHandler.channelRead(ctx, Unpooled.wrappedBuffer(ct1)); + verify(ctx, atLeastOnce()).fireChannelRead(plaintextCaptor1.capture()); + byte[] decrypted1 = new byte[data1.length]; + int offset = 0; + for (ByteBuf segment : plaintextCaptor1.getAllValues()) { + int len = segment.readableBytes(); + segment.readBytes(decrypted1, offset, len); + offset += len; + } + assertEquals(data1.length, offset); + assertArrayEquals(data1, decrypted1); + + // --- Second message (same handler, different
content) --- + reset(ctx); + byte[] data2 = new byte[2048]; + Arrays.fill(data2, (byte) 'B'); + ArgumentCaptor<GcmTransportCipher.GcmEncryptedMessage> captor2 = + ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); + encryptionHandler.write(ctx, Unpooled.wrappedBuffer(data2), promise); + verify(ctx).write(captor2.capture(), eq(promise)); + ByteBuffer ct2 = ByteBuffer.allocate((int) captor2.getValue().count()); + captor2.getValue().transferTo(new ByteBufferWriteableChannel(ct2), 0); + ct2.flip(); + + ArgumentCaptor<ByteBuf> plaintextCaptor2 = ArgumentCaptor.forClass(ByteBuf.class); + decryptionHandler.channelRead(ctx, Unpooled.wrappedBuffer(ct2)); + verify(ctx, atLeastOnce()).fireChannelRead(plaintextCaptor2.capture()); + byte[] decrypted2 = new byte[data2.length]; + offset = 0; + for (ByteBuf segment : plaintextCaptor2.getAllValues()) { + int len = segment.readableBytes(); + segment.readBytes(decrypted2, offset, len); + offset += len; + } + assertEquals(data2.length, offset); + assertArrayEquals(data2, decrypted2); + } + } + + /** + * Verifies that multiple GCM-encrypted messages delivered inside a single channelRead() + * call (TCP coalescing) are all decoded correctly. This is the regression test for the + * IllegalStateException("Invalid expected ciphertext length") observed under SparkSQL + * shuffle load: when Netty batches two messages into one ByteBuf, the old code released + * the buffer after the first message, discarding remaining bytes. The next channelRead() + * then read bytes from the middle of the second message as a length header.
+ */ + @Test + public void testBatchedMessages() throws Exception { + TransportConf gcmConf = getConf(2, false); + try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf); + AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher cipher = server.sessionCipher(); + assert (cipher instanceof GcmTransportCipher); + GcmTransportCipher gcmTransportCipher = (GcmTransportCipher) cipher; + + GcmTransportCipher.EncryptionHandler encryptionHandler = + gcmTransportCipher.getEncryptionHandler(); + GcmTransportCipher.DecryptionHandler decryptionHandler = + gcmTransportCipher.getDecryptionHandler(); + + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ChannelPromise promise = mock(ChannelPromise.class); + + byte[] data1 = new byte[1024]; + Arrays.fill(data1, (byte) 'A'); + ArgumentCaptor<GcmTransportCipher.GcmEncryptedMessage> captor1 = + ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); + encryptionHandler.write(ctx, Unpooled.wrappedBuffer(data1), promise); + verify(ctx).write(captor1.capture(), eq(promise)); + ByteBuffer ct1 = ByteBuffer.allocate((int) captor1.getValue().count()); + captor1.getValue().transferTo(new ByteBufferWriteableChannel(ct1), 0); + ct1.flip(); + + reset(ctx); + byte[] data2 = new byte[2048]; + Arrays.fill(data2, (byte) 'B'); + ArgumentCaptor<GcmTransportCipher.GcmEncryptedMessage> captor2 = + ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); + encryptionHandler.write(ctx, Unpooled.wrappedBuffer(data2), promise); + verify(ctx).write(captor2.capture(), eq(promise)); + ByteBuffer ct2 = ByteBuffer.allocate((int) captor2.getValue().count()); + captor2.getValue().transferTo(new ByteBufferWriteableChannel(ct2), 0); + ct2.flip(); + + // Simulate TCP coalescing: deliver both ciphertexts in one channelRead() call.
+ reset(ctx); + ByteBuf batched = Unpooled.wrappedBuffer(ct1, ct2); + ArgumentCaptor<ByteBuf> plaintextCaptor = ArgumentCaptor.forClass(ByteBuf.class); + decryptionHandler.channelRead(ctx, batched); + verify(ctx, atLeastOnce()).fireChannelRead(plaintextCaptor.capture()); + + byte[] decrypted = new byte[data1.length + data2.length]; + int offset = 0; + for (ByteBuf segment : plaintextCaptor.getAllValues()) { + int len = segment.readableBytes(); + segment.readBytes(decrypted, offset, len); + offset += len; + } + assertEquals(data1.length + data2.length, offset); + assertArrayEquals(data1, Arrays.copyOfRange(decrypted, 0, data1.length)); + assertArrayEquals(data2, Arrays.copyOfRange(decrypted, data1.length, decrypted.length)); + } + } + + /** + * Verifies that DecryptionHandler correctly handles a GCM message whose framing header + * is split across two channelRead() calls. This is the regression test for the + * IndexOutOfBoundsException in initializeDecrypter observed in benchmarking: when only + * 4 bytes of the 24-byte GCM internal header arrived in one Netty buffer, + * ByteBuf.readBytes(ByteBuffer) threw because it requires all dst.remaining() bytes to + * be available rather than performing a partial fill.
+ */ + @Test + public void testSplitHeader() throws Exception { + TransportConf gcmConf = getConf(2, false); + try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf); + AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher cipher = server.sessionCipher(); + assert (cipher instanceof GcmTransportCipher); + GcmTransportCipher gcmTransportCipher = (GcmTransportCipher) cipher; + + GcmTransportCipher.EncryptionHandler encryptionHandler = + gcmTransportCipher.getEncryptionHandler(); + GcmTransportCipher.DecryptionHandler decryptionHandler = + gcmTransportCipher.getDecryptionHandler(); + + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ChannelPromise promise = mock(ChannelPromise.class); + + byte[] data = new byte[1024]; + Arrays.fill(data, (byte) 'X'); + ArgumentCaptor<GcmTransportCipher.GcmEncryptedMessage> captor = + ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); + encryptionHandler.write(ctx, Unpooled.wrappedBuffer(data), promise); + verify(ctx).write(captor.capture(), eq(promise)); + + ByteBuffer ciphertextBuffer = ByteBuffer.allocate((int) captor.getValue().count()); + captor.getValue().transferTo(new ByteBufferWriteableChannel(ciphertextBuffer), 0); + ciphertextBuffer.flip(); + byte[] ciphertext = new byte[ciphertextBuffer.remaining()]; + ciphertextBuffer.get(ciphertext); + + // Split in the middle of the 24-byte GCM internal header: + // chunk1 = [8-byte length field][4 bytes of GCM header] + // chunk2 = [remaining 20 bytes of GCM header][full ciphertext] + int splitPoint = 8 + 4; + ByteBuf chunk1 = Unpooled.wrappedBuffer(ciphertext, 0, splitPoint); + ByteBuf chunk2 = Unpooled.wrappedBuffer( + ciphertext, splitPoint, ciphertext.length - splitPoint); + + decryptionHandler.channelRead(ctx, chunk1); + // Only a partial header was delivered; no
plaintext should be emitted yet. + verify(ctx, never()).fireChannelRead(any()); + + ArgumentCaptor<ByteBuf> plaintextCaptor = ArgumentCaptor.forClass(ByteBuf.class); + decryptionHandler.channelRead(ctx, chunk2); + verify(ctx, atLeastOnce()).fireChannelRead(plaintextCaptor.capture()); + + byte[] decrypted = new byte[data.length]; + int offset = 0; + for (ByteBuf segment : plaintextCaptor.getAllValues()) { + int len = segment.readableBytes(); + segment.readBytes(decrypted, offset, len); + offset += len; + } + assertEquals(data.length, offset); + assertArrayEquals(data, decrypted); + } + } + @Test public void testCorruptGcmEncryptedMessage() throws Exception { TransportConf gcmConf = getConf(2, false);