Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -82,25 +82,16 @@ public void addToChannel(Channel ch) throws GeneralSecurityException {

@VisibleForTesting
class EncryptionHandler extends ChannelOutboundHandlerAdapter {
private final ByteBuffer plaintextBuffer;
private final ByteBuffer ciphertextBuffer;
private final AesGcmHkdfStreaming aesGcmHkdfStreaming;

EncryptionHandler() throws InvalidAlgorithmParameterException {
aesGcmHkdfStreaming = getAesGcmHkdfStreaming();
plaintextBuffer = ByteBuffer.allocate(aesGcmHkdfStreaming.getPlaintextSegmentSize());
ciphertextBuffer = ByteBuffer.allocate(aesGcmHkdfStreaming.getCiphertextSegmentSize());
}

@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
throws Exception {
GcmEncryptedMessage encryptedMessage = new GcmEncryptedMessage(
aesGcmHkdfStreaming,
msg,
plaintextBuffer,
ciphertextBuffer);
ctx.write(encryptedMessage, promise);
ctx.write(new GcmEncryptedMessage(aesGcmHkdfStreaming, msg), promise);
}
}

Expand All @@ -116,15 +107,15 @@ static class GcmEncryptedMessage extends AbstractFileRegion {
private final long encryptedCount;

GcmEncryptedMessage(AesGcmHkdfStreaming aesGcmHkdfStreaming,
Object plaintextMessage,
ByteBuffer plaintextBuffer,
ByteBuffer ciphertextBuffer) throws GeneralSecurityException {
Object plaintextMessage) throws GeneralSecurityException {
JavaUtils.checkArgument(
plaintextMessage instanceof ByteBuf || plaintextMessage instanceof FileRegion,
"Unrecognized message type: %s", plaintextMessage.getClass().getName());
this.plaintextMessage = plaintextMessage;
this.plaintextBuffer = plaintextBuffer;
this.ciphertextBuffer = ciphertextBuffer;
this.plaintextBuffer =
ByteBuffer.allocate(aesGcmHkdfStreaming.getPlaintextSegmentSize());
this.ciphertextBuffer =
ByteBuffer.allocate(aesGcmHkdfStreaming.getCiphertextSegmentSize());
// If the ciphertext buffer cannot be fully written the target, transferTo may
// return with it containing some unwritten data. The initial call we'll explicitly
// set its limit to 0 to indicate the first call to transferTo.
Expand Down Expand Up @@ -289,7 +280,7 @@ class DecryptionHandler extends ChannelInboundHandlerAdapter {
private final ByteBuffer headerBuffer;
private final ByteBuffer ciphertextBuffer;
private final AesGcmHkdfStreaming aesGcmHkdfStreaming;
private final StreamSegmentDecrypter decrypter;
private StreamSegmentDecrypter decrypter;
private final int plaintextSegmentSize;
private boolean decrypterInit = false;
private boolean completed = false;
Expand All @@ -307,6 +298,25 @@ class DecryptionHandler extends ChannelInboundHandlerAdapter {
plaintextSegmentSize = aesGcmHkdfStreaming.getPlaintextSegmentSize();
}

/**
* Resets all per-message state so that the next incoming GCM message can be
* decoded through the same channel handler instance. This must be called after
* every successfully completed message because AesGcmHkdfStreaming is a one-shot
* streaming primitive: each encrypted message carries its own random IV and must
* be decrypted with a fresh StreamSegmentDecrypter.
*/
private void resetForNextMessage() throws GeneralSecurityException {
expectedLength = -1;
expectedLengthBuffer.clear();
headerBuffer.clear();
ciphertextBuffer.clear();
decrypterInit = false;
completed = false;
segmentNumber = 0;
ciphertextRead = 0;
decrypter = aesGcmHkdfStreaming.newStreamSegmentDecrypter();
}

private boolean initalizeExpectedLength(ByteBuf ciphertextNettyBuf) {
if (expectedLength < 0) {
ciphertextNettyBuf.readBytes(expectedLengthBuffer);
Expand Down Expand Up @@ -354,56 +364,75 @@ public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage)
"Unrecognized message type: %s",
ciphertextMessage.getClass().getName());
ByteBuf ciphertextNettyBuf = (ByteBuf) ciphertextMessage;
// The format of the output is:
// The format of each message is:
// [8 byte length][Internal IV and header][Ciphertext][Auth Tag]
//
// A single channelRead() call may deliver bytes from multiple back-to-back
// GCM messages (common under shuffle load when TCP coalesces writes). The
// outer loop processes as many complete messages as possible from the buffer
// before releasing it, so that bytes belonging to the next message are never
// discarded mid-stream.
try {
if (!initalizeExpectedLength(ciphertextNettyBuf)) {
// We have not read enough bytes to initialize the expected length.
return;
}
if (!initalizeDecrypter(ciphertextNettyBuf)) {
// We have not read enough bytes to initialize a header, needed to
// initialize a decrypter.
return;
}
int nettyBufReadableBytes = ciphertextNettyBuf.readableBytes();
while (nettyBufReadableBytes > 0 && !completed) {
// Read the ciphertext into the local buffer
int readableBytes = Integer.min(
nettyBufReadableBytes,
ciphertextBuffer.remaining());
int expectedRemaining = (int) (expectedLength - ciphertextRead);
int bytesToRead = Integer.min(readableBytes, expectedRemaining);
// The smallest ciphertext size is 16 bytes for the auth tag
ciphertextBuffer.limit(ciphertextBuffer.position() + bytesToRead);
ciphertextNettyBuf.readBytes(ciphertextBuffer);
ciphertextRead += bytesToRead;
// Check if this is the last segment
if (ciphertextRead == expectedLength) {
completed = true;
} else if (ciphertextRead > expectedLength) {
throw new IllegalStateException("Read more ciphertext than expected.");
while (true) {
if (!initalizeExpectedLength(ciphertextNettyBuf)) {
// We have not read enough bytes to initialize the expected length.
break;
}
if (!initalizeDecrypter(ciphertextNettyBuf)) {
// We have not read enough bytes to initialize a header, needed to
// initialize a decrypter.
break;
}
int nettyBufReadableBytes = ciphertextNettyBuf.readableBytes();
while (nettyBufReadableBytes > 0 && !completed) {
// Read the ciphertext into the local buffer
int readableBytes = Integer.min(
nettyBufReadableBytes,
ciphertextBuffer.remaining());
int expectedRemaining = (int) (expectedLength - ciphertextRead);
int bytesToRead = Integer.min(readableBytes, expectedRemaining);
// The smallest ciphertext size is 16 bytes for the auth tag
ciphertextBuffer.limit(ciphertextBuffer.position() + bytesToRead);
ciphertextNettyBuf.readBytes(ciphertextBuffer);
ciphertextRead += bytesToRead;
// Check if this is the last segment
if (ciphertextRead == expectedLength) {
completed = true;
} else if (ciphertextRead > expectedLength) {
throw new IllegalStateException("Read more ciphertext than expected.");
}
// If the ciphertext buffer is full, or this is the last segment,
// then decrypt it and fire a read.
if (ciphertextBuffer.limit() == ciphertextBuffer.capacity() || completed) {
ByteBuffer plaintextBuffer = ByteBuffer.allocate(plaintextSegmentSize);
ciphertextBuffer.flip();
decrypter.decryptSegment(
ciphertextBuffer,
segmentNumber,
completed,
plaintextBuffer);
segmentNumber++;
// Clear the ciphertext buffer because it's been read
ciphertextBuffer.clear();
plaintextBuffer.flip();
ctx.fireChannelRead(Unpooled.wrappedBuffer(plaintextBuffer));
} else {
// Set the ciphertext buffer up to read the next chunk
ciphertextBuffer.limit(ciphertextBuffer.capacity());
}
nettyBufReadableBytes = ciphertextNettyBuf.readableBytes();
}
if (!completed) {
// Partial message: more bytes needed from the next channelRead() call.
break;
}
// If the ciphertext buffer is full, or this is the last segment,
// then decrypt it and fire a read.
if (ciphertextBuffer.limit() == ciphertextBuffer.capacity() || completed) {
ByteBuffer plaintextBuffer = ByteBuffer.allocate(plaintextSegmentSize);
ciphertextBuffer.flip();
decrypter.decryptSegment(
ciphertextBuffer,
segmentNumber,
completed,
plaintextBuffer);
segmentNumber++;
// Clear the ciphertext buffer because it's been read
ciphertextBuffer.clear();
plaintextBuffer.flip();
ctx.fireChannelRead(Unpooled.wrappedBuffer(plaintextBuffer));
} else {
// Set the ciphertext buffer up to read the next chunk
ciphertextBuffer.limit(ciphertextBuffer.capacity());
// Current message is fully decoded. Reset state so the handler can
// decode the next independent GCM message on the same channel.
resetForNextMessage();
if (ciphertextNettyBuf.readableBytes() == 0) {
break;
}
nettyBufReadableBytes = ciphertextNettyBuf.readableBytes();
// Remaining bytes may belong to another message; loop to process them.
}
} finally {
ciphertextNettyBuf.release();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -336,4 +336,163 @@ public void testCorruptGcmEncryptedMessage() throws Exception {
assertThrows(AEADBadTagException.class, () -> decryptionHandler.channelRead(ctx, ciphertext));
}
}

/**
* Verifies that the same DecryptionHandler instance correctly decodes multiple independent
* GCM-encrypted messages sent over the same channel. This is the regression test for the
* bug where DecryptionHandler.completed was never reset, causing every message after the
* first to be silently dropped — which manifested as YARN container launch failures.
*/
@Test
public void testMultipleMessages() throws Exception {
TransportConf gcmConf = getConf(2, false);
try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf);
AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) {
AuthMessage clientChallenge = client.challenge();
AuthMessage serverResponse = server.response(clientChallenge);
client.deriveSessionCipher(clientChallenge, serverResponse);
TransportCipher cipher = server.sessionCipher();
assert (cipher instanceof GcmTransportCipher);
GcmTransportCipher gcmTransportCipher = (GcmTransportCipher) cipher;

GcmTransportCipher.EncryptionHandler encryptionHandler =
gcmTransportCipher.getEncryptionHandler();
GcmTransportCipher.DecryptionHandler decryptionHandler =
gcmTransportCipher.getDecryptionHandler();

ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
ChannelPromise promise = mock(ChannelPromise.class);

// --- First message ---
byte[] data1 = new byte[1024];
Arrays.fill(data1, (byte) 'A');
ByteBuf buf1 = Unpooled.wrappedBuffer(data1);
ArgumentCaptor<GcmTransportCipher.GcmEncryptedMessage> captor1 =
ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class);
encryptionHandler.write(ctx, buf1, promise);
verify(ctx).write(captor1.capture(), eq(promise));
GcmTransportCipher.GcmEncryptedMessage enc1 = captor1.getValue();
ByteBuffer ct1 = ByteBuffer.allocate((int) enc1.count());
enc1.transferTo(new ByteBufferWriteableChannel(ct1), 0);
ct1.flip();

ArgumentCaptor<ByteBuf> plaintextCaptor1 = ArgumentCaptor.forClass(ByteBuf.class);
decryptionHandler.channelRead(ctx, Unpooled.wrappedBuffer(ct1));
verify(ctx, atLeastOnce()).fireChannelRead(plaintextCaptor1.capture());
byte[] decrypted1 = new byte[data1.length];
int offset = 0;
for (ByteBuf segment : plaintextCaptor1.getAllValues()) {
int len = segment.readableBytes();
segment.readBytes(decrypted1, offset, len);
offset += len;
}
assertEquals(data1.length, offset);
assertEquals('A', decrypted1[0]);
assertEquals('A', decrypted1[data1.length - 1]);

// --- Second message (same handler, different content) ---
reset(ctx);
byte[] data2 = new byte[2048];
Arrays.fill(data2, (byte) 'B');
ByteBuf buf2 = Unpooled.wrappedBuffer(data2);
ArgumentCaptor<GcmTransportCipher.GcmEncryptedMessage> captor2 =
ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class);
encryptionHandler.write(ctx, buf2, promise);
verify(ctx).write(captor2.capture(), eq(promise));
GcmTransportCipher.GcmEncryptedMessage enc2 = captor2.getValue();
ByteBuffer ct2 = ByteBuffer.allocate((int) enc2.count());
enc2.transferTo(new ByteBufferWriteableChannel(ct2), 0);
ct2.flip();

ArgumentCaptor<ByteBuf> plaintextCaptor2 = ArgumentCaptor.forClass(ByteBuf.class);
decryptionHandler.channelRead(ctx, Unpooled.wrappedBuffer(ct2));
verify(ctx, atLeastOnce()).fireChannelRead(plaintextCaptor2.capture());
byte[] decrypted2 = new byte[data2.length];
offset = 0;
for (ByteBuf segment : plaintextCaptor2.getAllValues()) {
int len = segment.readableBytes();
segment.readBytes(decrypted2, offset, len);
offset += len;
}
assertEquals(data2.length, offset);
assertEquals('B', decrypted2[0]);
assertEquals('B', decrypted2[data2.length - 1]);
}
}

/**
* Verifies that multiple GCM-encrypted messages delivered inside a single channelRead()
* call (i.e., batched into one ByteBuf by TCP coalescing) are all decoded correctly.
* This is the regression test for the IllegalStateException("Invalid expected ciphertext
* length") error observed under SparkSQL shuffle load: when Netty batches two messages
* into one ByteBuf, the old code released the buffer after the first message was decoded,
* discarding the remaining bytes. The next channelRead() then read bytes from the middle
* of the second message as its length header, producing a negative long value.
*/
@Test
public void testBatchedMessages() throws Exception {
TransportConf gcmConf = getConf(2, false);
try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf);
AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) {
AuthMessage clientChallenge = client.challenge();
AuthMessage serverResponse = server.response(clientChallenge);
client.deriveSessionCipher(clientChallenge, serverResponse);
TransportCipher cipher = server.sessionCipher();
assert (cipher instanceof GcmTransportCipher);
GcmTransportCipher gcmTransportCipher = (GcmTransportCipher) cipher;

GcmTransportCipher.EncryptionHandler encryptionHandler =
gcmTransportCipher.getEncryptionHandler();
GcmTransportCipher.DecryptionHandler decryptionHandler =
gcmTransportCipher.getDecryptionHandler();

ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
ChannelPromise promise = mock(ChannelPromise.class);

// Encrypt two independent messages.
byte[] data1 = new byte[1024];
Arrays.fill(data1, (byte) 'A');
ArgumentCaptor<GcmTransportCipher.GcmEncryptedMessage> captor1 =
ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class);
encryptionHandler.write(ctx, Unpooled.wrappedBuffer(data1), promise);
verify(ctx).write(captor1.capture(), eq(promise));
ByteBuffer ct1 = ByteBuffer.allocate((int) captor1.getValue().count());
captor1.getValue().transferTo(new ByteBufferWriteableChannel(ct1), 0);
ct1.flip();

reset(ctx);
byte[] data2 = new byte[2048];
Arrays.fill(data2, (byte) 'B');
ArgumentCaptor<GcmTransportCipher.GcmEncryptedMessage> captor2 =
ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class);
encryptionHandler.write(ctx, Unpooled.wrappedBuffer(data2), promise);
verify(ctx).write(captor2.capture(), eq(promise));
ByteBuffer ct2 = ByteBuffer.allocate((int) captor2.getValue().count());
captor2.getValue().transferTo(new ByteBufferWriteableChannel(ct2), 0);
ct2.flip();

// Simulate TCP coalescing: deliver both ciphertexts in one channelRead() call.
reset(ctx);
ByteBuf batched = Unpooled.wrappedBuffer(ct1, ct2);
ArgumentCaptor<ByteBuf> plaintextCaptor = ArgumentCaptor.forClass(ByteBuf.class);
decryptionHandler.channelRead(ctx, batched);
verify(ctx, atLeastOnce()).fireChannelRead(plaintextCaptor.capture());

// Collect all decrypted segments from both messages.
byte[] decrypted = new byte[data1.length + data2.length];
int offset = 0;
for (ByteBuf segment : plaintextCaptor.getAllValues()) {
int len = segment.readableBytes();
segment.readBytes(decrypted, offset, len);
offset += len;
}
assertEquals(data1.length + data2.length, offset);
// Verify message 1 content ('A').
assertEquals('A', decrypted[0]);
assertEquals('A', decrypted[data1.length - 1]);
// Verify message 2 content ('B').
assertEquals('B', decrypted[data1.length]);
assertEquals('B', decrypted[data1.length + data2.length - 1]);
}
}
}