diff --git a/streams/src/main/java/org/apache/kafka/streams/state/AggregationWithHeaders.java b/streams/src/main/java/org/apache/kafka/streams/state/AggregationWithHeaders.java index 0e1173c2ed2b5..40f75afa89c66 100644 --- a/streams/src/main/java/org/apache/kafka/streams/state/AggregationWithHeaders.java +++ b/streams/src/main/java/org/apache/kafka/streams/state/AggregationWithHeaders.java @@ -18,6 +18,7 @@ import org.apache.kafka.common.header.Headers; +import java.util.Arrays; import java.util.Objects; /** @@ -96,12 +97,18 @@ public boolean equals(final Object o) { } final AggregationWithHeaders that = (AggregationWithHeaders) o; return Objects.equals(aggregation, that.aggregation) - && Objects.equals(this.headers, that.headers); + && headersEqual(this.headers, that.headers); + } + + private static boolean headersEqual(final Headers a, final Headers b) { + if (a == b) return true; + if (a == null || b == null) return false; + return Arrays.equals(a.toArray(), b.toArray()); } @Override public int hashCode() { - return Objects.hash(aggregation, headers); + return Objects.hash(aggregation, Arrays.hashCode(headers.toArray())); } @Override diff --git a/streams/src/main/java/org/apache/kafka/streams/state/ValueTimestampHeaders.java b/streams/src/main/java/org/apache/kafka/streams/state/ValueTimestampHeaders.java index 6a92413486797..8d16cb2f94be9 100644 --- a/streams/src/main/java/org/apache/kafka/streams/state/ValueTimestampHeaders.java +++ b/streams/src/main/java/org/apache/kafka/streams/state/ValueTimestampHeaders.java @@ -19,6 +19,7 @@ import org.apache.kafka.common.header.Headers; import org.apache.kafka.common.header.internals.RecordHeaders; +import java.util.Arrays; import java.util.Objects; /** @@ -108,12 +109,18 @@ public boolean equals(final Object o) { final ValueTimestampHeaders that = (ValueTimestampHeaders) o; return timestamp == that.timestamp && Objects.equals(value, that.value) - && Objects.equals(this.headers, that.headers); + && 
headersEqual(this.headers, that.headers); + } + + private static boolean headersEqual(final Headers a, final Headers b) { + if (a == b) return true; + if (a == null || b == null) return false; + return Arrays.equals(a.toArray(), b.toArray()); } @Override public int hashCode() { - return Objects.hash(value, timestamp, headers); + return Objects.hash(value, timestamp, Arrays.hashCode(headers.toArray())); } @Override diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/LazyHeaders.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/LazyHeaders.java new file mode 100644 index 0000000000000..b52d0fb383ba1 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/LazyHeaders.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.header.internals.RecordHeader; +import org.apache.kafka.common.header.internals.RecordHeaders; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.Objects; + +/** + * A lazy implementation of {@link Headers} that defers deserialization of header bytes + * until first read access. This avoids unnecessary parsing when the downstream + * deserializer does not inspect headers. + * + * <p>Headers added via {@link #add(Header)} or {@link #add(String, byte[])} before + * materialization are accumulated in a side list and merged on first read access. + * + * <p>Instances are confined to a single {@code StreamThread} and are not shared + * across threads, so no synchronization is needed. + */ +class LazyHeaders implements Headers { + + private final byte[] rawHeaders; + private RecordHeaders materialized; + private List<Header>
pendingAdds; + + /** + * Creates a new LazyHeaders wrapping the given raw header bytes. + * + * @param rawHeaders the serialized header bytes (without the varint size prefix), + * as expected by {@link HeadersDeserializer#deserialize(byte[])}. + * May be null or empty for empty headers. + */ + LazyHeaders(final byte[] rawHeaders) { + this.rawHeaders = rawHeaders; + } + + private RecordHeaders materialize() { + if (materialized == null) { + final Headers deserialized = HeadersDeserializer.deserialize(rawHeaders); + materialized = (deserialized instanceof RecordHeaders) + ? (RecordHeaders) deserialized + : new RecordHeaders(deserialized); + if (pendingAdds != null) { + for (final Header h : pendingAdds) { + materialized.add(h); + } + pendingAdds = null; + } + } + return materialized; + } + + /** + * Returns true if the headers have been deserialized. + * Visible for testing. + */ + boolean isDeserialized() { + return materialized != null; + } + + @Override + public Headers add(final Header header) throws IllegalStateException { + Objects.requireNonNull(header, "header cannot be null"); + if (materialized != null) { + materialized.add(header); + } else { + if (pendingAdds == null) { + pendingAdds = new ArrayList<>(); + } + pendingAdds.add(header); + } + return this; + } + + @Override + public Headers add(final String key, final byte[] value) throws IllegalStateException { + return add(new RecordHeader(key, value)); + } + + @Override + public Headers remove(final String key) throws IllegalStateException { + materialize().remove(key); + return this; + } + + @Override + public Header lastHeader(final String key) { + return materialize().lastHeader(key); + } + + @Override + public Iterable
<Header> headers(final String key) { + return materialize().headers(key); + } + + @Override + public Header[] toArray() { + return materialize().toArray(); + } + + @Override + public Iterator<Header>
iterator() { + return materialize().iterator(); + } + + @Override + public boolean equals(final Object o) { + if (this == o) return true; + if (!(o instanceof Headers)) return false; + final Headers other = (o instanceof LazyHeaders) + ? ((LazyHeaders) o).materialize() + : (Headers) o; + return Arrays.equals(materialize().toArray(), other.toArray()); + } + + @Override + public int hashCode() { + return Arrays.hashCode(materialize().toArray()); + } + + @Override + public String toString() { + if (materialized != null) { + return materialized.toString(); + } + return "LazyHeaders(not yet deserialized)"; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/Utils.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/Utils.java index 6b9d382758775..acc47d497848a 100644 --- a/streams/src/main/java/org/apache/kafka/streams/state/internals/Utils.java +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/Utils.java @@ -43,7 +43,7 @@ public static Headers headers(final byte[] valueWithHeaders) { return null; } - // If the header is empty, simply return it + // If the header is empty, return empty headers without lazy wrapping if (hasEmptyHeaders(valueWithHeaders)) { return new RecordHeaders(); } @@ -173,7 +173,7 @@ public static byte[] readBytes(final ByteBuffer buffer, final int length) { public static Headers readHeaders(final ByteBuffer buffer) { final int headersSize = ByteUtils.readVarint(buffer); final byte[] rawHeaders = readBytes(buffer, headersSize); - return HeadersDeserializer.deserialize(rawHeaders); + return (headersSize == 0) ? 
new RecordHeaders() : new LazyHeaders(rawHeaders); } /** diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/ValueTimestampHeadersDeserializer.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/ValueTimestampHeadersDeserializer.java index 73f41be794ca7..9fa29db6ae5e9 100644 --- a/streams/src/main/java/org/apache/kafka/streams/state/internals/ValueTimestampHeadersDeserializer.java +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/ValueTimestampHeadersDeserializer.java @@ -17,6 +17,7 @@ package org.apache.kafka.streams.state.internals; import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.header.internals.RecordHeaders; import org.apache.kafka.common.serialization.Deserializer; import org.apache.kafka.common.serialization.LongDeserializer; import org.apache.kafka.common.utils.internals.ByteUtils; @@ -73,7 +74,7 @@ public ValueTimestampHeaders deserialize(final String topic, final byte[] val final int headersSize = ByteUtils.readVarint(buffer); final byte[] rawHeaders = readBytes(buffer, headersSize); - final Headers headers = HeadersDeserializer.deserialize(rawHeaders); + final Headers headers = (headersSize == 0) ? 
new RecordHeaders() : new LazyHeaders(rawHeaders); final byte[] rawTimestamp = readBytes(buffer, Long.BYTES); final long timestamp = timestampDeserializer.deserialize(topic, rawTimestamp); final byte[] rawValue = readBytes(buffer, buffer.remaining()); diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/ChangeLoggingSessionBytesStoreWithHeadersTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/ChangeLoggingSessionBytesStoreWithHeadersTest.java index af108a8dc34d1..0b6280516381b 100644 --- a/streams/src/test/java/org/apache/kafka/streams/state/internals/ChangeLoggingSessionBytesStoreWithHeadersTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/ChangeLoggingSessionBytesStoreWithHeadersTest.java @@ -27,6 +27,8 @@ import org.apache.kafka.streams.state.AggregationWithHeaders; import org.apache.kafka.streams.state.SessionStore; +import java.util.Arrays; + import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -38,6 +40,8 @@ import static org.apache.kafka.common.utils.Utils.mkEntry; import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -100,12 +104,12 @@ public void shouldLogPutWithHeaders() { verify(inner).put(key1, serializedValue); verify(context).logChange( - store.name(), - binaryKey, - value1, - 0L, - headers, - Position.emptyPosition() + eq(store.name()), + eq(binaryKey), + eq(value1), + eq(0L), + argThat(actual -> Arrays.equals(actual.toArray(), headers.toArray())), + eq(Position.emptyPosition()) ); } @@ -124,12 +128,12 @@ public void shouldLogPutWithPosition() { verify(inner).put(key1, serializedValue); verify(context).logChange( - store.name(), - binaryKey, - value1, - 0L, - headers, - POSITION + 
eq(store.name()), + eq(binaryKey), + eq(value1), + eq(0L), + argThat(actual -> Arrays.equals(actual.toArray(), headers.toArray())), + eq(POSITION) ); } @@ -217,12 +221,12 @@ public void shouldHandleMultipleHeadersInSingleRecord() { verify(inner).put(key1, serializedValue); verify(context).logChange( - store.name(), - binaryKey, - value1, - 0L, - headers, - Position.emptyPosition() + eq(store.name()), + eq(binaryKey), + eq(value1), + eq(0L), + argThat(actual -> Arrays.equals(actual.toArray(), headers.toArray())), + eq(Position.emptyPosition()) ); } } diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/LazyHeadersTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/LazyHeadersTest.java new file mode 100644 index 0000000000000..18e535cae3fb3 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/LazyHeadersTest.java @@ -0,0 +1,271 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.header.internals.RecordHeader; +import org.apache.kafka.common.header.internals.RecordHeaders; + +import org.junit.jupiter.api.Test; + +import java.nio.charset.StandardCharsets; +import java.util.Iterator; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class LazyHeadersTest { + + private byte[] serializeHeaders(final Headers headers) { + return HeadersSerializer.serialize(headers); + } + + @Test + void shouldNotDeserializeOnConstruction() { + final Headers original = new RecordHeaders() + .add("key1", "value1".getBytes(StandardCharsets.UTF_8)); + final byte[] raw = serializeHeaders(original); + + final LazyHeaders lazy = new LazyHeaders(raw); + + assertFalse(lazy.isDeserialized()); + } + + @Test + void shouldDeserializeOnFirstReadAccess() { + final Headers original = new RecordHeaders() + .add("key1", "value1".getBytes(StandardCharsets.UTF_8)); + final byte[] raw = serializeHeaders(original); + + final LazyHeaders lazy = new LazyHeaders(raw); + assertFalse(lazy.isDeserialized()); + + final Header[] array = lazy.toArray(); + assertTrue(lazy.isDeserialized()); + assertEquals(1, array.length); + assertEquals("key1", array[0].key()); + assertArrayEquals("value1".getBytes(StandardCharsets.UTF_8), array[0].value()); + } + + @Test + void shouldDeserializeOnIterator() { + final Headers original = new RecordHeaders() + .add("key1", "value1".getBytes(StandardCharsets.UTF_8)) + .add("key2", "value2".getBytes(StandardCharsets.UTF_8)); + final byte[] raw = 
serializeHeaders(original); + + final LazyHeaders lazy = new LazyHeaders(raw); + assertFalse(lazy.isDeserialized()); + + final Iterator<Header>
iter = lazy.iterator(); + assertTrue(lazy.isDeserialized()); + assertTrue(iter.hasNext()); + + final Header h1 = iter.next(); + assertEquals("key1", h1.key()); + + final Header h2 = iter.next(); + assertEquals("key2", h2.key()); + + assertFalse(iter.hasNext()); + } + + @Test + void shouldDeserializeOnLastHeader() { + final Headers original = new RecordHeaders() + .add("key", "v1".getBytes(StandardCharsets.UTF_8)) + .add("key", "v2".getBytes(StandardCharsets.UTF_8)); + final byte[] raw = serializeHeaders(original); + + final LazyHeaders lazy = new LazyHeaders(raw); + + final Header last = lazy.lastHeader("key"); + assertTrue(lazy.isDeserialized()); + assertArrayEquals("v2".getBytes(StandardCharsets.UTF_8), last.value()); + } + + @Test + void shouldDeserializeOnHeadersByKey() { + final Headers original = new RecordHeaders() + .add("key", "v1".getBytes(StandardCharsets.UTF_8)) + .add("other", "v2".getBytes(StandardCharsets.UTF_8)) + .add("key", "v3".getBytes(StandardCharsets.UTF_8)); + final byte[] raw = serializeHeaders(original); + + final LazyHeaders lazy = new LazyHeaders(raw); + + int count = 0; + for (final Header h : lazy.headers("key")) { + count++; + assertEquals("key", h.key()); + } + assertEquals(2, count); + assertTrue(lazy.isDeserialized()); + } + + @Test + void shouldAddWithoutDeserializing() { + final Headers original = new RecordHeaders() + .add("existing", "value".getBytes(StandardCharsets.UTF_8)); + final byte[] raw = serializeHeaders(original); + + final LazyHeaders lazy = new LazyHeaders(raw); + lazy.add("new-key", "new-value".getBytes(StandardCharsets.UTF_8)); + + assertFalse(lazy.isDeserialized()); + + // Now access forces deserialization and merging + final Header[] all = lazy.toArray(); + assertTrue(lazy.isDeserialized()); + assertEquals(2, all.length); + assertEquals("existing", all[0].key()); + assertEquals("new-key", all[1].key()); + } + + @Test + void shouldAddMultipleBeforeDeserialization() { + final Headers original = new RecordHeaders() 
+ .add("h1", "v1".getBytes(StandardCharsets.UTF_8)); + final byte[] raw = serializeHeaders(original); + + final LazyHeaders lazy = new LazyHeaders(raw); + lazy.add("h2", "v2".getBytes(StandardCharsets.UTF_8)); + lazy.add("h3", "v3".getBytes(StandardCharsets.UTF_8)); + assertFalse(lazy.isDeserialized()); + + final Header[] all = lazy.toArray(); + assertEquals(3, all.length); + assertEquals("h1", all[0].key()); + assertEquals("h2", all[1].key()); + assertEquals("h3", all[2].key()); + } + + @Test + void shouldAddAfterDeserialization() { + final Headers original = new RecordHeaders() + .add("existing", "value".getBytes(StandardCharsets.UTF_8)); + final byte[] raw = serializeHeaders(original); + + final LazyHeaders lazy = new LazyHeaders(raw); + + // Force deserialization + lazy.toArray(); + assertTrue(lazy.isDeserialized()); + + // Add after deserialization + lazy.add("post", "post-value".getBytes(StandardCharsets.UTF_8)); + + final Header[] all = lazy.toArray(); + assertEquals(2, all.length); + assertEquals("post", all[1].key()); + } + + @Test + void shouldHandleRemove() { + final Headers original = new RecordHeaders() + .add("keep", "v1".getBytes(StandardCharsets.UTF_8)) + .add("remove", "v2".getBytes(StandardCharsets.UTF_8)); + final byte[] raw = serializeHeaders(original); + + final LazyHeaders lazy = new LazyHeaders(raw); + lazy.remove("remove"); + + assertTrue(lazy.isDeserialized()); + assertEquals(1, lazy.toArray().length); + assertEquals("keep", lazy.toArray()[0].key()); + } + + @Test + void shouldHandleNullRawHeaders() { + final LazyHeaders lazy = new LazyHeaders(null); + assertFalse(lazy.isDeserialized()); + + final Header[] all = lazy.toArray(); + assertTrue(lazy.isDeserialized()); + assertEquals(0, all.length); + } + + @Test + void shouldHandleEmptyRawHeaders() { + final LazyHeaders lazy = new LazyHeaders(new byte[0]); + assertFalse(lazy.isDeserialized()); + + final Header[] all = lazy.toArray(); + assertTrue(lazy.isDeserialized()); + assertEquals(0, 
all.length); + } + + @Test + void shouldPreserveNullHeaderValues() { + final Headers original = new RecordHeaders() + .add("nullable", null); + final byte[] raw = serializeHeaders(original); + + final LazyHeaders lazy = new LazyHeaders(raw); + final Header last = lazy.lastHeader("nullable"); + assertNotNull(last); + assertNull(last.value()); + } + + @Test + void shouldReturnCorrectToStringBeforeDeserialization() { + final LazyHeaders lazy = new LazyHeaders(new byte[0]); + assertEquals("LazyHeaders(not yet deserialized)", lazy.toString()); + } + + @Test + void shouldReturnDelegateToStringAfterDeserialization() { + final LazyHeaders lazy = new LazyHeaders(new byte[0]); + lazy.toArray(); // force materialization + assertNotNull(lazy.toString()); + assertFalse(lazy.toString().contains("not yet deserialized")); + } + + @Test + void shouldBeEqualToEquivalentRecordHeaders() { + final Headers original = new RecordHeaders() + .add("k1", "v1".getBytes(StandardCharsets.UTF_8)); + final byte[] raw = serializeHeaders(original); + + final LazyHeaders lazy = new LazyHeaders(raw); + + final RecordHeaders expected = new RecordHeaders(); + expected.add(new RecordHeader("k1", "v1".getBytes(StandardCharsets.UTF_8))); + + // LazyHeaders.equals(RecordHeaders) uses content-based comparison + assertEquals(lazy, expected); + // Also verify symmetric content equality via toArray + assertArrayEquals(expected.toArray(), lazy.toArray()); + } + + @Test + void shouldAddWithHeaderObject() { + final LazyHeaders lazy = new LazyHeaders(null); + + lazy.add(new RecordHeader("key", "value".getBytes(StandardCharsets.UTF_8))); + assertFalse(lazy.isDeserialized()); + + final Header[] all = lazy.toArray(); + assertEquals(1, all.length); + assertEquals("key", all[0].key()); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/UtilsTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/UtilsTest.java index 1d08e89899659..020cbd5caeded 100644 --- 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/UtilsTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/UtilsTest.java @@ -43,6 +43,7 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -240,4 +241,48 @@ private static byte[] timestampedValueOf(final byte[] value) { buf.put(value); return res; } + + @Test + public void shouldReturnLazyHeadersForNonEmptyHeaders() { + final RecordHeaders originalHeaders = new RecordHeaders(); + originalHeaders.add("key1", "val1".getBytes(StandardCharsets.UTF_8)); + final byte[] serializedHeaders = HeadersSerializer.serialize(originalHeaders); + final byte[] storedBytes = headersTimestampValueOf(headersOf(serializedHeaders), VALUE); + + final Headers result = Utils.headers(storedBytes); + + assertInstanceOf(LazyHeaders.class, result); + assertFalse(((LazyHeaders) result).isDeserialized()); + } + + @Test + public void shouldReturnRecordHeadersForEmptyHeaders() { + final byte[] storedBytes = timestampedValueWithEmptyHeaders(VALUE); + + final Headers result = Utils.headers(storedBytes); + + assertInstanceOf(RecordHeaders.class, result); + assertFalse(result instanceof LazyHeaders); + } + + @Test + public void shouldRemainLazyThroughChangelogFlow() { + final RecordHeaders originalHeaders = new RecordHeaders(); + originalHeaders.add("h1", "v1".getBytes(StandardCharsets.UTF_8)); + final byte[] serializedHeaders = HeadersSerializer.serialize(originalHeaders); + final byte[] storedBytes = headersTimestampValueOf(headersOf(serializedHeaders), VALUE); + + // Step 1: Utils.headers() returns LazyHeaders, not yet parsed + 
final Headers headers = Utils.headers(storedBytes); + assertInstanceOf(LazyHeaders.class, headers); + assertFalse(((LazyHeaders) headers).isDeserialized()); + + // Step 2: add() does not force parsing (simulates addVectorClockToHeaders) + headers.add("clock", new byte[]{0x01, 0x02}); + assertFalse(((LazyHeaders) headers).isDeserialized()); + + // Step 3: toArray() forces parsing (simulates producer serialization) + headers.toArray(); + assertTrue(((LazyHeaders) headers).isDeserialized()); + } }