From c0912144c578d0e6b2ff5608f6d74410c106bd9e Mon Sep 17 00:00:00 2001 From: Pierre Lacave Date: Thu, 24 Jul 2025 12:04:16 +0200 Subject: [PATCH 1/6] Fix build following cms/cpc recent PR --- .../datasketches/count/CountMinSketch.java | 71 +++++++++++++------ .../cpc/CpcSketchCrossLanguageTest.java | 3 +- 2 files changed, 52 insertions(+), 22 deletions(-) diff --git a/src/main/java/org/apache/datasketches/count/CountMinSketch.java b/src/main/java/org/apache/datasketches/count/CountMinSketch.java index 36bea38cf..4ef6a8ec5 100644 --- a/src/main/java/org/apache/datasketches/count/CountMinSketch.java +++ b/src/main/java/org/apache/datasketches/count/CountMinSketch.java @@ -22,8 +22,8 @@ import org.apache.datasketches.common.Family; import org.apache.datasketches.common.SketchesArgumentException; import org.apache.datasketches.common.SketchesException; +import org.apache.datasketches.common.Util; import org.apache.datasketches.hash.MurmurHash3; -import org.apache.datasketches.tuple.Util; import java.io.ByteArrayOutputStream; import java.nio.ByteBuffer; @@ -39,6 +39,9 @@ public class CountMinSketch { private final long[] sketchArray_; private long totalWeight_; + // Thread-local ByteBuffer to avoid allocations in hot paths + private static final ThreadLocal LONG_BUFFER = + ThreadLocal.withInitial(() -> ByteBuffer.allocate(8)); private enum Flag { IS_EMPTY; @@ -57,30 +60,59 @@ int mask() { * @param seed The base hash seed */ CountMinSketch(final byte numHashes, final int numBuckets, final long seed) { - numHashes_ = numHashes; - numBuckets_ = numBuckets; - seed_ = seed; - hashSeeds_ = new long[numHashes]; - sketchArray_ = new long[numHashes * numBuckets]; - totalWeight_ = 0; + // Validate numHashes + if (numHashes <= 0) { + throw new SketchesArgumentException("Number of hash functions must be positive, got: " + numHashes); + } + // Validate numBuckets with clear mathematical justification + if (numBuckets <= 0) { + throw new SketchesArgumentException("Number of buckets must be positive, got: " + numBuckets); + } if (numBuckets < 3) { - throw new SketchesArgumentException("Using fewer than 3 buckets incurs relative error greater than 1."); + throw new SketchesArgumentException("Number of buckets must be at least 3 to ensure relative error ≤ 1.0. " + + "With " + numBuckets + " buckets, relative error would be " + String.format("%.3f", Math.exp(1.0) / numBuckets)); + } + + // Check for potential overflow in array size calculation + // Use long arithmetic to detect overflow before casting + final long totalSize = (long) numHashes * (long) numBuckets; + if (totalSize > Integer.MAX_VALUE) { + throw new SketchesArgumentException("Sketch array size would overflow: " + numHashes + " * " + numBuckets + + " = " + totalSize + " > " + Integer.MAX_VALUE); } // This check is to ensure later compatibility with a Java implementation whose maximum size can only // be 2^31-1. We check only against 2^30 for simplicity. - if (numBuckets * numHashes >= 1 << 30) { - throw new SketchesArgumentException("These parameters generate a sketch that exceeds 2^30 elements. \n" + - "Try reducing either the number of buckets or the number of hash functions."); + if (totalSize >= (1L << 30)) { + throw new SketchesArgumentException("Sketch would require excessive memory: " + numHashes + " * " + numBuckets + + " = " + totalSize + " elements (~" + String.format("%.1f", totalSize * 8.0 / (1024 * 1024 * 1024)) + " GB). " + + "Consider reducing numHashes or numBuckets."); } + numHashes_ = numHashes; + numBuckets_ = numBuckets; + seed_ = seed; + hashSeeds_ = new long[numHashes]; + sketchArray_ = new long[(int) totalSize]; + totalWeight_ = 0; + Random rand = new Random(seed); for (int i = 0; i < numHashes; i++) { hashSeeds_[i] = rand.nextLong(); } } + /** + * Efficiently converts a long to byte array using thread-local buffer to avoid allocations. + */ + private static byte[] longToBytes(final long value) { + final ByteBuffer buffer = LONG_BUFFER.get(); + buffer.clear(); + buffer.putLong(value); + return buffer.array(); + } + private long[] getHashes(byte[] item) { long[] updateLocations = new long[numHashes_]; @@ -171,8 +203,7 @@ public static int suggestNumBuckets(double relativeError) { * @param weight The weight of the item. */ public void update(final long item, final long weight) { - byte[] longByte = ByteBuffer.allocate(8).putLong(item).array(); - update(longByte, weight); + update(longToBytes(item), weight); } /** @@ -211,8 +242,7 @@ public void update(final byte[] item, final long weight) { * @return Estimated frequency. */ public long getEstimate(final long item) { - byte[] longByte = ByteBuffer.allocate(8).putLong(item).array(); - return getEstimate(longByte); + return getEstimate(longToBytes(item)); } /** @@ -241,8 +271,9 @@ public long getEstimate(final byte[] item) { long[] hashLocations = getHashes(item); long res = sketchArray_[(int) hashLocations[0]]; - for (long h : hashLocations) { - res = Math.min(res, sketchArray_[(int) h]); + // Start from index 1 to avoid processing first element twice + for (int i = 1; i < hashLocations.length; i++) { + res = Math.min(res, sketchArray_[(int) hashLocations[i]]); } return res; @@ -254,8 +285,7 @@ public long getEstimate(final byte[] item) { * @return Upper bound of estimated frequency. */ public long getUpperBound(final long item) { - byte[] longByte = ByteBuffer.allocate(8).putLong(item).array(); - return getUpperBound(longByte); + return getUpperBound(longToBytes(item)); } /** @@ -291,8 +321,7 @@ public long getUpperBound(final byte[] item) { * @return Lower bound of estimated frequency. */ public long getLowerBound(final long item) { - byte[] longByte = ByteBuffer.allocate(8).putLong(item).array(); - return getLowerBound(longByte); + return getLowerBound(longToBytes(item)); } /** diff --git a/src/test/java/org/apache/datasketches/cpc/CpcSketchCrossLanguageTest.java b/src/test/java/org/apache/datasketches/cpc/CpcSketchCrossLanguageTest.java index 16cc55db9..2346ec918 100644 --- a/src/test/java/org/apache/datasketches/cpc/CpcSketchCrossLanguageTest.java +++ b/src/test/java/org/apache/datasketches/cpc/CpcSketchCrossLanguageTest.java @@ -31,6 +31,7 @@ import java.io.IOException; import java.nio.file.Files; +import org.apache.datasketches.memory.Memory; import org.testng.annotations.Test; /** @@ -89,7 +90,7 @@ public void checkAllFlavorsGo() throws IOException { int flavorIdx = 0; for (int n: nArr) { final byte[] bytes = Files.readAllBytes(goPath.resolve("cpc_n" + n + "_go.sk")); - final CpcSketch sketch = CpcSketch.heapify(Memory.wrap(bytes)); + final CpcSketch sketch = CpcSketch.heapify(MemorySegment.ofArray(bytes)); assertEquals(sketch.getFlavor(), flavorArr[flavorIdx++]); assertEquals(sketch.getEstimate(), n, n * 0.02); } From c6c90c3ae254d9322447f8701c212ef9d3d0c8f3 Mon Sep 17 00:00:00 2001 From: Pierre Lacave Date: Thu, 24 Jul 2025 12:11:34 +0200 Subject: [PATCH 2/6] Update CpcSketchCrossLanguageTest.java remove memory import --- .../org/apache/datasketches/cpc/CpcSketchCrossLanguageTest.java | 1 - 1 file changed, 1 deletion(-) diff --git a/src/test/java/org/apache/datasketches/cpc/CpcSketchCrossLanguageTest.java b/src/test/java/org/apache/datasketches/cpc/CpcSketchCrossLanguageTest.java index 2346ec918..906f1a52a 100644 --- a/src/test/java/org/apache/datasketches/cpc/CpcSketchCrossLanguageTest.java +++ b/src/test/java/org/apache/datasketches/cpc/CpcSketchCrossLanguageTest.java @@ -31,7 +31,6 @@ import java.io.IOException; import java.nio.file.Files; -import org.apache.datasketches.memory.Memory; import org.testng.annotations.Test; /** From e0445efe45f6361ed6f0396f682eeddf3f008ce3 Mon Sep 17 00:00:00 2001 From: Pierre Lacave Date: Thu, 24 Jul 2025 20:13:53 +0200 Subject: [PATCH 3/6] use MemorySegment instead of ByteBuffer in CMS --- .../datasketches/count/CountMinSketch.java | 116 +++++++++++------- .../count/CountMinSketchTest.java | 12 +- 2 files changed, 78 insertions(+), 50 deletions(-) diff --git a/src/main/java/org/apache/datasketches/count/CountMinSketch.java b/src/main/java/org/apache/datasketches/count/CountMinSketch.java index 4ef6a8ec5..d1d79a0c2 100644 --- a/src/main/java/org/apache/datasketches/count/CountMinSketch.java +++ b/src/main/java/org/apache/datasketches/count/CountMinSketch.java @@ -25,11 +25,16 @@ import org.apache.datasketches.common.Util; import org.apache.datasketches.hash.MurmurHash3; -import java.io.ByteArrayOutputStream; -import java.nio.ByteBuffer; +import java.lang.foreign.MemorySegment; import java.nio.charset.StandardCharsets; + +import static java.lang.foreign.ValueLayout.JAVA_BYTE; import java.util.Random; +import static org.apache.datasketches.common.SpecialValueLayouts.JAVA_INT_UNALIGNED_BIG_ENDIAN; +import static org.apache.datasketches.common.SpecialValueLayouts.JAVA_LONG_UNALIGNED_BIG_ENDIAN; +import static org.apache.datasketches.common.SpecialValueLayouts.JAVA_SHORT_UNALIGNED_BIG_ENDIAN; + public class CountMinSketch { private final byte numHashes_; @@ -39,9 +44,9 @@ public class CountMinSketch { private final long[] sketchArray_; private long totalWeight_; - // Thread-local ByteBuffer to avoid allocations in hot paths - private static final ThreadLocal LONG_BUFFER = - ThreadLocal.withInitial(() -> ByteBuffer.allocate(8)); + // Thread-local MemorySegment to avoid allocations in hot paths with explicit endianness control + private static final ThreadLocal LONG_SEGMENT = + ThreadLocal.withInitial(() -> MemorySegment.ofArray(new byte[8])); private enum Flag { IS_EMPTY; @@ -104,15 +109,15 @@ int mask() { } /** - * Efficiently converts a long to byte array using thread-local buffer to avoid allocations. + * Efficiently converts a long to byte array using thread-local MemorySegment with explicit endianness. */ private static byte[] longToBytes(final long value) { - final ByteBuffer buffer = LONG_BUFFER.get(); - buffer.clear(); - buffer.putLong(value); - return buffer.array(); + final MemorySegment segment = LONG_SEGMENT.get(); + segment.set(JAVA_LONG_UNALIGNED_BIG_ENDIAN, 0, value); + return segment.toArray(JAVA_BYTE); } + private long[] getHashes(byte[] item) { long[] updateLocations = new long[numHashes_]; @@ -371,39 +376,62 @@ public void merge(final CountMinSketch other) { } /** - * Serializes the sketch into the provided ByteBuffer. - * @param buf The ByteBuffer to write into. + * Returns the serialized size in bytes. */ - public void serialize(ByteArrayOutputStream buf) { + private int getSerializedSizeBytes() { + final int preambleBytes = Family.COUNTMIN.getMinPreLongs() * Long.BYTES; + if (isEmpty()) { + return preambleBytes; + } + return preambleBytes + Long.BYTES + (sketchArray_.length * Long.BYTES); + } + + + /** + * Returns the sketch as a byte array. + */ + public byte[] toByteArray() { + final int serializedSizeBytes = getSerializedSizeBytes(); + final MemorySegment wseg = MemorySegment.ofArray(new byte[serializedSizeBytes]); + + long offset = 0; + // Long 0 final int preambleLongs = Family.COUNTMIN.getMinPreLongs(); - buf.write((byte) preambleLongs); + wseg.set(JAVA_BYTE, offset++, (byte) preambleLongs); final int serialVersion = 1; - buf.write((byte) serialVersion); + wseg.set(JAVA_BYTE, offset++, (byte) serialVersion); final int familyId = Family.COUNTMIN.getID(); - buf.write((byte) familyId); + wseg.set(JAVA_BYTE, offset++, (byte) familyId); final int flagsByte = isEmpty() ? Flag.IS_EMPTY.mask() : 0; - buf.write((byte)flagsByte); + wseg.set(JAVA_BYTE, offset++, (byte) flagsByte); final int NULL_32 = 0; - buf.writeBytes(ByteBuffer.allocate(4).putInt(NULL_32).array()); + wseg.set(JAVA_INT_UNALIGNED_BIG_ENDIAN, offset, NULL_32); + offset += 4; // Long 1 - buf.writeBytes(ByteBuffer.allocate(4).putInt(numBuckets_).array()); - buf.write(numHashes_); + wseg.set(JAVA_INT_UNALIGNED_BIG_ENDIAN, offset, numBuckets_); + offset += 4; + wseg.set(JAVA_BYTE, offset++, numHashes_); short hashSeed = Util.computeSeedHash(seed_); - buf.writeBytes(ByteBuffer.allocate(2).putShort(hashSeed).array()); + wseg.set(JAVA_SHORT_UNALIGNED_BIG_ENDIAN, offset, hashSeed); + offset += 2; final byte NULL_8 = 0; - buf.write(NULL_8); + wseg.set(JAVA_BYTE, offset++, NULL_8); + if (isEmpty()) { - return; + return wseg.toArray(JAVA_BYTE); } - final byte[] totWeightByte = ByteBuffer.allocate(8).putLong(totalWeight_).array(); - buf.writeBytes(totWeightByte); + wseg.set(JAVA_LONG_UNALIGNED_BIG_ENDIAN, offset, totalWeight_); + offset += 8; for (long w: sketchArray_) { - buf.writeBytes(ByteBuffer.allocate(8).putLong(w).array()); + wseg.set(JAVA_LONG_UNALIGNED_BIG_ENDIAN, offset, w); + offset += 8; } + + return wseg.toArray(JAVA_BYTE); } /** @@ -413,20 +441,22 @@ public void serialize(ByteArrayOutputStream buf) { * @return The deserialized CountMinSketch. */ public static CountMinSketch deserialize(final byte[] b, final long seed) { - ByteBuffer buf = ByteBuffer.allocate(b.length); - buf.put(b); - buf.flip(); - - final byte preambleLongs = buf.get(); - final byte serialVersion = buf.get(); - final byte familyId = buf.get(); - final byte flagsByte = buf.get(); - final int NULL_32 = buf.getInt(); - - final int numBuckets = buf.getInt(); - final byte numHashes = buf.get(); - final short seedHash = buf.getShort(); - final byte NULL_8 = buf.get(); + final MemorySegment buf = MemorySegment.ofArray(b); + long offset = 0; + + final byte preambleLongs = buf.get(JAVA_BYTE, offset++); + final byte serialVersion = buf.get(JAVA_BYTE, offset++); + final byte familyId = buf.get(JAVA_BYTE, offset++); + final byte flagsByte = buf.get(JAVA_BYTE, offset++); + final int NULL_32 = buf.get(JAVA_INT_UNALIGNED_BIG_ENDIAN, offset); + offset += 4; + + final int numBuckets = buf.get(JAVA_INT_UNALIGNED_BIG_ENDIAN, offset); + offset += 4; + final byte numHashes = buf.get(JAVA_BYTE, offset++); + final short seedHash = buf.get(JAVA_SHORT_UNALIGNED_BIG_ENDIAN, offset); + offset += 2; + final byte NULL_8 = buf.get(JAVA_BYTE, offset++); if (seedHash != Util.computeSeedHash(seed)) { throw new SketchesArgumentException("Incompatible seed hashes: " + String.valueOf(seedHash) + ", " @@ -438,11 +468,13 @@ public static CountMinSketch deserialize(final byte[] b, final long seed) { if (empty) { return cms; } - long w = buf.getLong(); + long w = buf.get(JAVA_LONG_UNALIGNED_BIG_ENDIAN, offset); + offset += 8; cms.totalWeight_ = w; for (int i = 0; i < cms.sketchArray_.length; i++) { - cms.sketchArray_[i] = buf.getLong(); + cms.sketchArray_[i] = buf.get(JAVA_LONG_UNALIGNED_BIG_ENDIAN, offset); + offset += 8; } return cms; diff --git a/src/test/java/org/apache/datasketches/count/CountMinSketchTest.java b/src/test/java/org/apache/datasketches/count/CountMinSketchTest.java index cbd2fde79..cdf42c3b0 100644 --- a/src/test/java/org/apache/datasketches/count/CountMinSketchTest.java +++ b/src/test/java/org/apache/datasketches/count/CountMinSketchTest.java @@ -203,10 +203,7 @@ public void serializeDeserializeEmptyTest() { final long seed = 123456; CountMinSketch c = new CountMinSketch(numHashes, numBuckets, seed); - ByteArrayOutputStream buf = new ByteArrayOutputStream(); - c.serialize(buf); - - byte[] b = buf.toByteArray(); + byte[] b = c.toByteArray(); assertThrows(SketchesArgumentException.class, () -> CountMinSketch.deserialize(b, seed - 1)); CountMinSketch d = CountMinSketch.deserialize(b, seed); @@ -228,11 +225,10 @@ public void serializeDeserializeTest() { c.update(i, 10*i*i); } - ByteArrayOutputStream buf = new ByteArrayOutputStream(); - c.serialize(buf); + byte[] b = c.toByteArray(); - assertThrows(SketchesArgumentException.class, () -> CountMinSketch.deserialize(buf.toByteArray(), seed - 1)); - CountMinSketch d = CountMinSketch.deserialize(buf.toByteArray(), seed); + assertThrows(SketchesArgumentException.class, () -> CountMinSketch.deserialize(b, seed - 1)); + CountMinSketch d = CountMinSketch.deserialize(b, seed); assertEquals(d.getNumHashes_(), c.getNumHashes_()); assertEquals(d.getNumBuckets_(), c.getNumBuckets_()); From ea9fda9409e73641cdb5e0989d278ca00352a992 Mon Sep 17 00:00:00 2001 From: Pierre Lacave Date: Thu, 24 Jul 2025 21:16:59 +0200 Subject: [PATCH 4/6] use default endian - not BE --- .../datasketches/count/CountMinSketch.java | 36 +++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/src/main/java/org/apache/datasketches/count/CountMinSketch.java b/src/main/java/org/apache/datasketches/count/CountMinSketch.java index d1d79a0c2..fdee6b044 100644 --- a/src/main/java/org/apache/datasketches/count/CountMinSketch.java +++ b/src/main/java/org/apache/datasketches/count/CountMinSketch.java @@ -31,9 +31,9 @@ import static java.lang.foreign.ValueLayout.JAVA_BYTE; import java.util.Random; -import static org.apache.datasketches.common.SpecialValueLayouts.JAVA_INT_UNALIGNED_BIG_ENDIAN; -import static org.apache.datasketches.common.SpecialValueLayouts.JAVA_LONG_UNALIGNED_BIG_ENDIAN; -import static org.apache.datasketches.common.SpecialValueLayouts.JAVA_SHORT_UNALIGNED_BIG_ENDIAN; +import static java.lang.foreign.ValueLayout.JAVA_INT_UNALIGNED; +import static java.lang.foreign.ValueLayout.JAVA_LONG_UNALIGNED; +import static java.lang.foreign.ValueLayout.JAVA_SHORT_UNALIGNED; public class CountMinSketch { @@ -113,7 +113,7 @@ int mask() { */ private static byte[] longToBytes(final long value) { final MemorySegment segment = LONG_SEGMENT.get(); - segment.set(JAVA_LONG_UNALIGNED_BIG_ENDIAN, 0, value); + segment.set(JAVA_LONG_UNALIGNED, 0, value); return segment.toArray(JAVA_BYTE); } @@ -393,9 +393,9 @@ private int getSerializedSizeBytes() { public byte[] toByteArray() { final int serializedSizeBytes = getSerializedSizeBytes(); final MemorySegment wseg = MemorySegment.ofArray(new byte[serializedSizeBytes]); - + long offset = 0; - + // Long 0 final int preambleLongs = Family.COUNTMIN.getMinPreLongs(); wseg.set(JAVA_BYTE, offset++, (byte) preambleLongs); @@ -406,31 +406,31 @@ public byte[] toByteArray() { final int flagsByte = isEmpty() ? Flag.IS_EMPTY.mask() : 0; wseg.set(JAVA_BYTE, offset++, (byte) flagsByte); final int NULL_32 = 0; - wseg.set(JAVA_INT_UNALIGNED_BIG_ENDIAN, offset, NULL_32); + wseg.set(JAVA_INT_UNALIGNED, offset, NULL_32); offset += 4; // Long 1 - wseg.set(JAVA_INT_UNALIGNED_BIG_ENDIAN, offset, numBuckets_); + wseg.set(JAVA_INT_UNALIGNED, offset, numBuckets_); offset += 4; wseg.set(JAVA_BYTE, offset++, numHashes_); short hashSeed = Util.computeSeedHash(seed_); - wseg.set(JAVA_SHORT_UNALIGNED_BIG_ENDIAN, offset, hashSeed); + wseg.set(JAVA_SHORT_UNALIGNED, offset, hashSeed); offset += 2; final byte NULL_8 = 0; wseg.set(JAVA_BYTE, offset++, NULL_8); - + if (isEmpty()) { return wseg.toArray(JAVA_BYTE); } - wseg.set(JAVA_LONG_UNALIGNED_BIG_ENDIAN, offset, totalWeight_); + wseg.set(JAVA_LONG_UNALIGNED, offset, totalWeight_); offset += 8; for (long w: sketchArray_) { - wseg.set(JAVA_LONG_UNALIGNED_BIG_ENDIAN, offset, w); + wseg.set(JAVA_LONG_UNALIGNED, offset, w); offset += 8; } - + return wseg.toArray(JAVA_BYTE); } @@ -448,13 +448,13 @@ public static CountMinSketch deserialize(final byte[] b, final long seed) { final byte serialVersion = buf.get(JAVA_BYTE, offset++); final byte familyId = buf.get(JAVA_BYTE, offset++); final byte flagsByte = buf.get(JAVA_BYTE, offset++); - final int NULL_32 = buf.get(JAVA_INT_UNALIGNED_BIG_ENDIAN, offset); + final int NULL_32 = buf.get(JAVA_INT_UNALIGNED, offset); offset += 4; - final int numBuckets = buf.get(JAVA_INT_UNALIGNED_BIG_ENDIAN, offset); + final int numBuckets = buf.get(JAVA_INT_UNALIGNED, offset); offset += 4; final byte numHashes = buf.get(JAVA_BYTE, offset++); - final short seedHash = buf.get(JAVA_SHORT_UNALIGNED_BIG_ENDIAN, offset); + final short seedHash = buf.get(JAVA_SHORT_UNALIGNED, offset); offset += 2; final byte NULL_8 = buf.get(JAVA_BYTE, offset++); @@ -468,12 +468,12 @@ public static CountMinSketch deserialize(final byte[] b, final long seed) { if (empty) { return cms; } - long w = buf.get(JAVA_LONG_UNALIGNED_BIG_ENDIAN, offset); + long w = buf.get(JAVA_LONG_UNALIGNED, offset); offset += 8; cms.totalWeight_ = w; for (int i = 0; i < cms.sketchArray_.length; i++) { - cms.sketchArray_[i] = buf.get(JAVA_LONG_UNALIGNED_BIG_ENDIAN, offset); + cms.sketchArray_[i] = buf.get(JAVA_LONG_UNALIGNED, offset); offset += 8; } From 8bd1423a7111774ae37334d115132837bdd18d76 Mon Sep 17 00:00:00 2001 From: Pierre Lacave Date: Sat, 26 Jul 2025 15:18:17 +0200 Subject: [PATCH 5/6] Address all remaining review comments --- .../datasketches/count/CountMinSketch.java | 143 +++++++++--------- 1 file changed, 74 insertions(+), 69 deletions(-) diff --git a/src/main/java/org/apache/datasketches/count/CountMinSketch.java b/src/main/java/org/apache/datasketches/count/CountMinSketch.java index fdee6b044..52801f01e 100644 --- a/src/main/java/org/apache/datasketches/count/CountMinSketch.java +++ b/src/main/java/org/apache/datasketches/count/CountMinSketch.java @@ -23,14 +23,14 @@ import org.apache.datasketches.common.SketchesArgumentException; import org.apache.datasketches.common.SketchesException; import org.apache.datasketches.common.Util; +import org.apache.datasketches.common.positional.PositionalSegment; import org.apache.datasketches.hash.MurmurHash3; import java.lang.foreign.MemorySegment; import java.nio.charset.StandardCharsets; - -import static java.lang.foreign.ValueLayout.JAVA_BYTE; import java.util.Random; +import static java.lang.foreign.ValueLayout.JAVA_BYTE; import static java.lang.foreign.ValueLayout.JAVA_INT_UNALIGNED; import static java.lang.foreign.ValueLayout.JAVA_LONG_UNALIGNED; import static java.lang.foreign.ValueLayout.JAVA_SHORT_UNALIGNED; @@ -46,7 +46,7 @@ public class CountMinSketch { // Thread-local MemorySegment to avoid allocations in hot paths with explicit endianness control private static final ThreadLocal LONG_SEGMENT = - ThreadLocal.withInitial(() -> MemorySegment.ofArray(new byte[8])); + ThreadLocal.withInitial(() -> MemorySegment.ofArray(new byte[Long.BYTES])); private enum Flag { IS_EMPTY; @@ -83,16 +83,16 @@ int mask() { // Use long arithmetic to detect overflow before casting final long totalSize = (long) numHashes * (long) numBuckets; if (totalSize > Integer.MAX_VALUE) { - throw new SketchesArgumentException("Sketch array size would overflow: " + numHashes + " * " + numBuckets + - " = " + totalSize + " > " + Integer.MAX_VALUE); + throw new SketchesArgumentException("Sketch array size would overflow: " + numHashes + " * " + numBuckets + + " = " + totalSize + " > " + Integer.MAX_VALUE); } // This check is to ensure later compatibility with a Java implementation whose maximum size can only // be 2^31-1. We check only against 2^30 for simplicity. if (totalSize >= (1L << 30)) { - throw new SketchesArgumentException("Sketch would require excessive memory: " + numHashes + " * " + numBuckets + - " = " + totalSize + " elements (~" + String.format("%.1f", totalSize * 8.0 / (1024 * 1024 * 1024)) + " GB). " + - "Consider reducing numHashes or numBuckets."); + throw new SketchesArgumentException("Sketch would require excessive memory: " + numHashes + " * " + numBuckets + + " = " + totalSize + " elements (~" + String.format("%d", totalSize * Long.BYTES / (1024 * 1024 * 1024)) + " GB). " + + "Consider reducing numHashes or numBuckets."); } numHashes_ = numHashes; @@ -102,7 +102,7 @@ int mask() { sketchArray_ = new long[(int) totalSize]; totalWeight_ = 0; - Random rand = new Random(seed); + final Random rand = new Random(seed); for (int i = 0; i < numHashes; i++) { hashSeeds_[i] = rand.nextLong(); } @@ -118,11 +118,11 @@ private static byte[] longToBytes(final long value) { } - private long[] getHashes(byte[] item) { - long[] updateLocations = new long[numHashes_]; + private long[] getHashes(final byte[] item) { + final long[] updateLocations = new long[numHashes_]; for (int i = 0; i < numHashes_; i++) { - long[] index = MurmurHash3.hash(item, hashSeeds_[i]); + final long[] index = MurmurHash3.hash(item, hashSeeds_[i]); updateLocations[i] = i * (long)numBuckets_ + Math.floorMod(index[0], numBuckets_); } @@ -182,11 +182,11 @@ public double getRelativeError() { * @param confidence The desired confidence level between 0 and 1. * @return Suggested number of hash functions. */ - public static byte suggestNumHashes(double confidence) { + public static byte suggestNumHashes(final double confidence) { if (confidence < 0 || confidence > 1) { throw new SketchesException("Confidence must be between 0 and 1.0 (inclusive)."); } - int value = (int) Math.ceil(Math.log(1.0 / (1.0 - confidence))); + final int value = (int) Math.ceil(Math.log(1.0 / (1.0 - confidence))); return (byte) Math.min(value, 127); } @@ -195,7 +195,7 @@ public static byte suggestNumHashes(double confidence) { * @param relativeError The desired relative error. * @return Suggested number of buckets. */ - public static int suggestNumBuckets(double relativeError) { + public static int suggestNumBuckets(final double relativeError) { if (relativeError < 0.) { throw new SketchesException("Relative error must be at least 0."); } @@ -235,8 +235,8 @@ public void update(final byte[] item, final long weight) { } totalWeight_ += weight > 0 ? weight : -weight; - long[] hashLocations = getHashes(item); - for (long h : hashLocations) { + final long[] hashLocations = getHashes(item); + for (final long h : hashLocations) { sketchArray_[(int) h] += weight; } } @@ -274,7 +274,7 @@ public long getEstimate(final byte[] item) { return 0; } - long[] hashLocations = getHashes(item); + final long[] hashLocations = getHashes(item); long res = sketchArray_[(int) hashLocations[0]]; // Start from index 1 to avoid processing first element twice for (int i = 1; i < hashLocations.length; i++) { @@ -303,8 +303,8 @@ public long getUpperBound(final String item) { return 0; } - byte[] strByte = item.getBytes(StandardCharsets.UTF_8); - return getUpperBound(strByte); + final byte[] strByte = item.getBytes(StandardCharsets.UTF_8); + return getUpperBound(strByte); } /** @@ -339,7 +339,7 @@ public long getLowerBound(final String item) { return 0; } - byte[] strByte = item.getBytes(StandardCharsets.UTF_8); + final byte[] strByte = item.getBytes(StandardCharsets.UTF_8); return getLowerBound(strByte); } @@ -361,8 +361,8 @@ public void merge(final CountMinSketch other) { throw new SketchesException("Cannot merge a sketch with itself"); } - boolean acceptableConfig = getNumBuckets_() == other.getNumBuckets_() && - getNumHashes_() == other.getNumHashes_() && getSeed_() == other.getSeed_(); + final boolean acceptableConfig = getNumBuckets_() == other.getNumBuckets_() + && getNumHashes_() == other.getNumHashes_() && getSeed_() == other.getSeed_(); if (!acceptableConfig) { throw new SketchesException("Incompatible sketch configuration."); @@ -392,46 +392,40 @@ private int getSerializedSizeBytes() { */ public byte[] toByteArray() { final int serializedSizeBytes = getSerializedSizeBytes(); - final MemorySegment wseg = MemorySegment.ofArray(new byte[serializedSizeBytes]); - - long offset = 0; + final byte[] bytes = new byte[serializedSizeBytes]; + final PositionalSegment posSeg = PositionalSegment.wrap(MemorySegment.ofArray(bytes)); // Long 0 final int preambleLongs = Family.COUNTMIN.getMinPreLongs(); - wseg.set(JAVA_BYTE, offset++, (byte) preambleLongs); + posSeg.setByte((byte) preambleLongs); final int serialVersion = 1; - wseg.set(JAVA_BYTE, offset++, (byte) serialVersion); + posSeg.setByte((byte) serialVersion); final int familyId = Family.COUNTMIN.getID(); - wseg.set(JAVA_BYTE, offset++, (byte) familyId); + posSeg.setByte((byte) familyId); final int flagsByte = isEmpty() ? Flag.IS_EMPTY.mask() : 0; - wseg.set(JAVA_BYTE, offset++, (byte) flagsByte); + posSeg.setByte((byte) flagsByte); final int NULL_32 = 0; - wseg.set(JAVA_INT_UNALIGNED, offset, NULL_32); - offset += 4; + posSeg.setInt(NULL_32); // Long 1 - wseg.set(JAVA_INT_UNALIGNED, offset, numBuckets_); - offset += 4; - wseg.set(JAVA_BYTE, offset++, numHashes_); - short hashSeed = Util.computeSeedHash(seed_); - wseg.set(JAVA_SHORT_UNALIGNED, offset, hashSeed); - offset += 2; + posSeg.setInt(numBuckets_); + posSeg.setByte(numHashes_); + final short hashSeed = Util.computeSeedHash(seed_); + posSeg.setShort(hashSeed); final byte NULL_8 = 0; - wseg.set(JAVA_BYTE, offset++, NULL_8); + posSeg.setByte(NULL_8); if (isEmpty()) { - return wseg.toArray(JAVA_BYTE); + return bytes; } - wseg.set(JAVA_LONG_UNALIGNED, offset, totalWeight_); - offset += 8; + posSeg.setLong(totalWeight_); - for (long w: sketchArray_) { - wseg.set(JAVA_LONG_UNALIGNED, offset, w); - offset += 8; + for (final long w: sketchArray_) { + posSeg.setLong(w); } - return wseg.toArray(JAVA_BYTE); + return bytes; } /** @@ -441,40 +435,51 @@ public byte[] toByteArray() { * @return The deserialized CountMinSketch. */ public static CountMinSketch deserialize(final byte[] b, final long seed) { - final MemorySegment buf = MemorySegment.ofArray(b); - long offset = 0; - - final byte preambleLongs = buf.get(JAVA_BYTE, offset++); - final byte serialVersion = buf.get(JAVA_BYTE, offset++); - final byte familyId = buf.get(JAVA_BYTE, offset++); - final byte flagsByte = buf.get(JAVA_BYTE, offset++); - final int NULL_32 = buf.get(JAVA_INT_UNALIGNED, offset); - offset += 4; - - final int numBuckets = buf.get(JAVA_INT_UNALIGNED, offset); - offset += 4; - final byte numHashes = buf.get(JAVA_BYTE, offset++); - final short seedHash = buf.get(JAVA_SHORT_UNALIGNED, offset); - offset += 2; - final byte NULL_8 = buf.get(JAVA_BYTE, offset++); + final PositionalSegment posSeg = PositionalSegment.wrap(MemorySegment.ofArray(b)); + + final byte preambleLongs = posSeg.getByte(); + final byte serialVersion = posSeg.getByte(); + final byte familyId = posSeg.getByte(); + final byte flagsByte = posSeg.getByte(); + posSeg.getInt(); // skip NULL_32 + + // Validate serialization format + final int expectedPreambleLongs = Family.COUNTMIN.getMinPreLongs(); + if (preambleLongs != expectedPreambleLongs) { + throw new SketchesArgumentException("Preamble longs mismatch: expected " + expectedPreambleLongs + + ", actual " + preambleLongs); + } + final int expectedSerialVersion = 1; + if (serialVersion != expectedSerialVersion) { + throw new SketchesArgumentException("Serial version mismatch: expected " + expectedSerialVersion + + ", actual " + serialVersion); + } + final int expectedFamilyId = Family.COUNTMIN.getID(); + if (familyId != expectedFamilyId) { + throw new SketchesArgumentException("Family ID mismatch: expected " + expectedFamilyId + + ", actual " + familyId); + } + + final int numBuckets = posSeg.getInt(); + final byte numHashes = posSeg.getByte(); + final short seedHash = posSeg.getShort(); + posSeg.getByte(); // skip NULL_8 if (seedHash != Util.computeSeedHash(seed)) { - throw new SketchesArgumentException("Incompatible seed hashes: " + String.valueOf(seedHash) + ", " - + String.valueOf(Util.computeSeedHash(seed))); + throw new SketchesArgumentException("Incompatible seed hashes: " + seedHash + ", " + + Util.computeSeedHash(seed)); } - CountMinSketch cms = new CountMinSketch(numHashes, numBuckets, seed); + final CountMinSketch cms = new CountMinSketch(numHashes, numBuckets, seed); final boolean empty = (flagsByte & Flag.IS_EMPTY.mask()) > 0; if (empty) { return cms; } - long w = buf.get(JAVA_LONG_UNALIGNED, offset); - offset += 8; + final long w = posSeg.getLong(); cms.totalWeight_ = w; for (int i = 0; i < cms.sketchArray_.length; i++) { - cms.sketchArray_[i] = buf.get(JAVA_LONG_UNALIGNED, offset); - offset += 8; + cms.sketchArray_[i] = posSeg.getLong(); } return cms; From 66f13338464566d7a940843ce70c70ebac6fb657 Mon Sep 17 00:00:00 2001 From: Pierre Lacave Date: Sat, 26 Jul 2025 15:21:34 +0200 Subject: [PATCH 6/6] Add java doc for CMS --- .../org/apache/datasketches/count/CountMinSketch.java | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/main/java/org/apache/datasketches/count/CountMinSketch.java b/src/main/java/org/apache/datasketches/count/CountMinSketch.java index 52801f01e..eb9fb1ce8 100644 --- a/src/main/java/org/apache/datasketches/count/CountMinSketch.java +++ b/src/main/java/org/apache/datasketches/count/CountMinSketch.java @@ -36,6 +36,16 @@ import static java.lang.foreign.ValueLayout.JAVA_SHORT_UNALIGNED; +/** + * Java implementation of the CountMin sketch data structure of Cormode and Muthukrishnan. + * This implementation is inspired by and compatible with the datasketches-cpp version by Charlie Dickens. + * + * The CountMin sketch is a probabilistic data structure that provides frequency estimates for items + * in a data stream. It uses multiple hash functions to distribute items across a two-dimensional array, + * providing approximate counts with configurable error bounds. + * + * Reference: http://dimacs.rutgers.edu/~graham/pubs/papers/cm-full.pdf + */ public class CountMinSketch { private final byte numHashes_; private final int numBuckets_;