diff --git a/src/main/java/org/apache/datasketches/count/CountMinSketch.java b/src/main/java/org/apache/datasketches/count/CountMinSketch.java index 36bea38cf..eb9fb1ce8 100644 --- a/src/main/java/org/apache/datasketches/count/CountMinSketch.java +++ b/src/main/java/org/apache/datasketches/count/CountMinSketch.java @@ -22,15 +22,30 @@ import org.apache.datasketches.common.Family; import org.apache.datasketches.common.SketchesArgumentException; import org.apache.datasketches.common.SketchesException; +import org.apache.datasketches.common.Util; +import org.apache.datasketches.common.positional.PositionalSegment; import org.apache.datasketches.hash.MurmurHash3; -import org.apache.datasketches.tuple.Util; -import java.io.ByteArrayOutputStream; -import java.nio.ByteBuffer; +import java.lang.foreign.MemorySegment; import java.nio.charset.StandardCharsets; import java.util.Random; +import static java.lang.foreign.ValueLayout.JAVA_BYTE; +import static java.lang.foreign.ValueLayout.JAVA_INT_UNALIGNED; +import static java.lang.foreign.ValueLayout.JAVA_LONG_UNALIGNED; +import static java.lang.foreign.ValueLayout.JAVA_SHORT_UNALIGNED; + +/** + * Java implementation of the CountMin sketch data structure of Cormode and Muthukrishnan. + * This implementation is inspired by and compatible with the datasketches-cpp version by Charlie Dickens. + * + * The CountMin sketch is a probabilistic data structure that provides frequency estimates for items + * in a data stream. It uses multiple hash functions to distribute items across a two-dimensional array, + * providing approximate counts with configurable error bounds. + * + * Reference: http://dimacs.rutgers.edu/~graham/pubs/papers/cm-full.pdf + */ public class CountMinSketch { private final byte numHashes_; private final int numBuckets_; @@ -39,6 +54,9 @@ public class CountMinSketch { private final long[] sketchArray_; private long totalWeight_; + // Thread-local MemorySegment to avoid allocations in hot paths with explicit endianness control + private static final ThreadLocal LONG_SEGMENT = + ThreadLocal.withInitial(() -> MemorySegment.ofArray(new byte[Long.BYTES])); private enum Flag { IS_EMPTY; @@ -57,35 +75,64 @@ int mask() { * @param seed The base hash seed */ CountMinSketch(final byte numHashes, final int numBuckets, final long seed) { - numHashes_ = numHashes; - numBuckets_ = numBuckets; - seed_ = seed; - hashSeeds_ = new long[numHashes]; - sketchArray_ = new long[numHashes * numBuckets]; - totalWeight_ = 0; + // Validate numHashes + if (numHashes <= 0) { + throw new SketchesArgumentException("Number of hash functions must be positive, got: " + numHashes); + } + // Validate numBuckets with clear mathematical justification + if (numBuckets <= 0) { + throw new SketchesArgumentException("Number of buckets must be positive, got: " + numBuckets); + } if (numBuckets < 3) { - throw new SketchesArgumentException("Using fewer than 3 buckets incurs relative error greater than 1."); + throw new SketchesArgumentException("Number of buckets must be at least 3 to ensure relative error ≤ 1.0. " + + "With " + numBuckets + " buckets, relative error would be " + String.format("%.3f", Math.exp(1.0) / numBuckets)); + } + + // Check for potential overflow in array size calculation + // Use long arithmetic to detect overflow before casting + final long totalSize = (long) numHashes * (long) numBuckets; + if (totalSize > Integer.MAX_VALUE) { + throw new SketchesArgumentException("Sketch array size would overflow: " + numHashes + " * " + numBuckets + + " = " + totalSize + " > " + Integer.MAX_VALUE); } // This check is to ensure later compatibility with a Java implementation whose maximum size can only // be 2^31-1. We check only against 2^30 for simplicity. - if (numBuckets * numHashes >= 1 << 30) { - throw new SketchesArgumentException("These parameters generate a sketch that exceeds 2^30 elements. \n" + - "Try reducing either the number of buckets or the number of hash functions."); + if (totalSize >= (1L << 30)) { + throw new SketchesArgumentException("Sketch would require excessive memory: " + numHashes + " * " + numBuckets + + " = " + totalSize + " elements (~" + String.format("%d", totalSize * Long.BYTES / (1024 * 1024 * 1024)) + " GB). " + + "Consider reducing numHashes or numBuckets."); } - Random rand = new Random(seed); + numHashes_ = numHashes; + numBuckets_ = numBuckets; + seed_ = seed; + hashSeeds_ = new long[numHashes]; + sketchArray_ = new long[(int) totalSize]; + totalWeight_ = 0; + + final Random rand = new Random(seed); for (int i = 0; i < numHashes; i++) { hashSeeds_[i] = rand.nextLong(); } } - private long[] getHashes(byte[] item) { - long[] updateLocations = new long[numHashes_]; + /** + * Efficiently converts a long to byte array using thread-local MemorySegment with explicit endianness. + */ + private static byte[] longToBytes(final long value) { + final MemorySegment segment = LONG_SEGMENT.get(); + segment.set(JAVA_LONG_UNALIGNED, 0, value); + return segment.toArray(JAVA_BYTE); + } + + + private long[] getHashes(final byte[] item) { + final long[] updateLocations = new long[numHashes_]; for (int i = 0; i < numHashes_; i++) { - long[] index = MurmurHash3.hash(item, hashSeeds_[i]); + final long[] index = MurmurHash3.hash(item, hashSeeds_[i]); updateLocations[i] = i * (long)numBuckets_ + Math.floorMod(index[0], numBuckets_); } @@ -145,11 +192,11 @@ public double getRelativeError() { * @param confidence The desired confidence level between 0 and 1. * @return Suggested number of hash functions. */ - public static byte suggestNumHashes(double confidence) { + public static byte suggestNumHashes(final double confidence) { if (confidence < 0 || confidence > 1) { throw new SketchesException("Confidence must be between 0 and 1.0 (inclusive)."); } - int value = (int) Math.ceil(Math.log(1.0 / (1.0 - confidence))); + final int value = (int) Math.ceil(Math.log(1.0 / (1.0 - confidence))); return (byte) Math.min(value, 127); } @@ -158,7 +205,7 @@ public static byte suggestNumHashes(double confidence) { * @param relativeError The desired relative error. * @return Suggested number of buckets. */ - public static int suggestNumBuckets(double relativeError) { + public static int suggestNumBuckets(final double relativeError) { if (relativeError < 0.) { throw new SketchesException("Relative error must be at least 0."); } @@ -171,8 +218,7 @@ public static int suggestNumBuckets(double relativeError) { * @param weight The weight of the item. */ public void update(final long item, final long weight) { - byte[] longByte = ByteBuffer.allocate(8).putLong(item).array(); - update(longByte, weight); + update(longToBytes(item), weight); } /** @@ -199,8 +245,8 @@ public void update(final byte[] item, final long weight) { } totalWeight_ += weight > 0 ? weight : -weight; - long[] hashLocations = getHashes(item); - for (long h : hashLocations) { + final long[] hashLocations = getHashes(item); + for (final long h : hashLocations) { sketchArray_[(int) h] += weight; } } @@ -211,8 +257,7 @@ public void update(final byte[] item, final long weight) { * @return Estimated frequency. */ public long getEstimate(final long item) { - byte[] longByte = ByteBuffer.allocate(8).putLong(item).array(); - return getEstimate(longByte); + return getEstimate(longToBytes(item)); } /** @@ -239,10 +284,11 @@ public long getEstimate(final byte[] item) { return 0; } - long[] hashLocations = getHashes(item); + final long[] hashLocations = getHashes(item); long res = sketchArray_[(int) hashLocations[0]]; - for (long h : hashLocations) { - res = Math.min(res, sketchArray_[(int) h]); + // Start from index 1 to avoid processing first element twice + for (int i = 1; i < hashLocations.length; i++) { + res = Math.min(res, sketchArray_[(int) hashLocations[i]]); } return res; @@ -254,8 +300,7 @@ public long getEstimate(final byte[] item) { * @return Upper bound of estimated frequency. */ public long getUpperBound(final long item) { - byte[] longByte = ByteBuffer.allocate(8).putLong(item).array(); - return getUpperBound(longByte); + return getUpperBound(longToBytes(item)); } /** @@ -268,8 +313,8 @@ public long getUpperBound(final String item) { return 0; } - byte[] strByte = item.getBytes(StandardCharsets.UTF_8); - return getUpperBound(strByte); + final byte[] strByte = item.getBytes(StandardCharsets.UTF_8); + return getUpperBound(strByte); } /** @@ -291,8 +336,7 @@ public long getUpperBound(final byte[] item) { * @return Lower bound of estimated frequency. */ public long getLowerBound(final long item) { - byte[] longByte = ByteBuffer.allocate(8).putLong(item).array(); - return getLowerBound(longByte); + return getLowerBound(longToBytes(item)); } /** @@ -305,7 +349,7 @@ public long getLowerBound(final String item) { return 0; } - byte[] strByte = item.getBytes(StandardCharsets.UTF_8); + final byte[] strByte = item.getBytes(StandardCharsets.UTF_8); return getLowerBound(strByte); } @@ -327,8 +371,8 @@ public void merge(final CountMinSketch other) { throw new SketchesException("Cannot merge a sketch with itself"); } - boolean acceptableConfig = getNumBuckets_() == other.getNumBuckets_() && - getNumHashes_() == other.getNumHashes_() && getSeed_() == other.getSeed_(); + final boolean acceptableConfig = getNumBuckets_() == other.getNumBuckets_() + && getNumHashes_() == other.getNumHashes_() && getSeed_() == other.getSeed_(); if (!acceptableConfig) { throw new SketchesException("Incompatible sketch configuration."); @@ -342,39 +386,56 @@ public void merge(final CountMinSketch other) { } /** - * Serializes the sketch into the provided ByteBuffer. - * @param buf The ByteBuffer to write into. + * Returns the serialized size in bytes. + */ + private int getSerializedSizeBytes() { + final int preambleBytes = Family.COUNTMIN.getMinPreLongs() * Long.BYTES; + if (isEmpty()) { + return preambleBytes; + } + return preambleBytes + Long.BYTES + (sketchArray_.length * Long.BYTES); + } + + + /** + * Returns the sketch as a byte array. */ - public void serialize(ByteArrayOutputStream buf) { + public byte[] toByteArray() { + final int serializedSizeBytes = getSerializedSizeBytes(); + final byte[] bytes = new byte[serializedSizeBytes]; + final PositionalSegment posSeg = PositionalSegment.wrap(MemorySegment.ofArray(bytes)); + // Long 0 final int preambleLongs = Family.COUNTMIN.getMinPreLongs(); - buf.write((byte) preambleLongs); + posSeg.setByte((byte) preambleLongs); final int serialVersion = 1; - buf.write((byte) serialVersion); + posSeg.setByte((byte) serialVersion); final int familyId = Family.COUNTMIN.getID(); - buf.write((byte) familyId); + posSeg.setByte((byte) familyId); final int flagsByte = isEmpty() ? Flag.IS_EMPTY.mask() : 0; - buf.write((byte)flagsByte); + posSeg.setByte((byte) flagsByte); final int NULL_32 = 0; - buf.writeBytes(ByteBuffer.allocate(4).putInt(NULL_32).array()); + posSeg.setInt(NULL_32); // Long 1 - buf.writeBytes(ByteBuffer.allocate(4).putInt(numBuckets_).array()); - buf.write(numHashes_); - short hashSeed = Util.computeSeedHash(seed_); - buf.writeBytes(ByteBuffer.allocate(2).putShort(hashSeed).array()); + posSeg.setInt(numBuckets_); + posSeg.setByte(numHashes_); + final short hashSeed = Util.computeSeedHash(seed_); + posSeg.setShort(hashSeed); final byte NULL_8 = 0; - buf.write(NULL_8); + posSeg.setByte(NULL_8); + if (isEmpty()) { - return; + return bytes; } - final byte[] totWeightByte = ByteBuffer.allocate(8).putLong(totalWeight_).array(); - buf.writeBytes(totWeightByte); + posSeg.setLong(totalWeight_); - for (long w: sketchArray_) { - buf.writeBytes(ByteBuffer.allocate(8).putLong(w).array()); + for (final long w: sketchArray_) { + posSeg.setLong(w); } + + return bytes; } /** @@ -384,36 +445,51 @@ public void serialize(ByteArrayOutputStream buf) { * @return The deserialized CountMinSketch. */ public static CountMinSketch deserialize(final byte[] b, final long seed) { - ByteBuffer buf = ByteBuffer.allocate(b.length); - buf.put(b); - buf.flip(); - - final byte preambleLongs = buf.get(); - final byte serialVersion = buf.get(); - final byte familyId = buf.get(); - final byte flagsByte = buf.get(); - final int NULL_32 = buf.getInt(); + final PositionalSegment posSeg = PositionalSegment.wrap(MemorySegment.ofArray(b)); + + final byte preambleLongs = posSeg.getByte(); + final byte serialVersion = posSeg.getByte(); + final byte familyId = posSeg.getByte(); + final byte flagsByte = posSeg.getByte(); + posSeg.getInt(); // skip NULL_32 + + // Validate serialization format + final int expectedPreambleLongs = Family.COUNTMIN.getMinPreLongs(); + if (preambleLongs != expectedPreambleLongs) { + throw new SketchesArgumentException("Preamble longs mismatch: expected " + expectedPreambleLongs + + ", actual " + preambleLongs); + } + final int expectedSerialVersion = 1; + if (serialVersion != expectedSerialVersion) { + throw new SketchesArgumentException("Serial version mismatch: expected " + expectedSerialVersion + + ", actual " + serialVersion); + } + final int expectedFamilyId = Family.COUNTMIN.getID(); + if (familyId != expectedFamilyId) { + throw new SketchesArgumentException("Family ID mismatch: expected " + expectedFamilyId + + ", actual " + familyId); + } - final int numBuckets = buf.getInt(); - final byte numHashes = buf.get(); - final short seedHash = buf.getShort(); - final byte NULL_8 = buf.get(); + final int numBuckets = posSeg.getInt(); + final byte numHashes = posSeg.getByte(); + final short seedHash = posSeg.getShort(); + posSeg.getByte(); // skip NULL_8 if (seedHash != Util.computeSeedHash(seed)) { - throw new SketchesArgumentException("Incompatible seed hashes: " + String.valueOf(seedHash) + ", " - + String.valueOf(Util.computeSeedHash(seed))); + throw new SketchesArgumentException("Incompatible seed hashes: " + seedHash + ", " + + Util.computeSeedHash(seed)); } - CountMinSketch cms = new CountMinSketch(numHashes, numBuckets, seed); + final CountMinSketch cms = new CountMinSketch(numHashes, numBuckets, seed); final boolean empty = (flagsByte & Flag.IS_EMPTY.mask()) > 0; if (empty) { return cms; } - long w = buf.getLong(); + final long w = posSeg.getLong(); cms.totalWeight_ = w; for (int i = 0; i < cms.sketchArray_.length; i++) { - cms.sketchArray_[i] = buf.getLong(); + cms.sketchArray_[i] = posSeg.getLong(); } return cms; diff --git a/src/test/java/org/apache/datasketches/count/CountMinSketchTest.java b/src/test/java/org/apache/datasketches/count/CountMinSketchTest.java index cbd2fde79..cdf42c3b0 100644 --- a/src/test/java/org/apache/datasketches/count/CountMinSketchTest.java +++ b/src/test/java/org/apache/datasketches/count/CountMinSketchTest.java @@ -203,10 +203,7 @@ public void serializeDeserializeEmptyTest() { final long seed = 123456; CountMinSketch c = new CountMinSketch(numHashes, numBuckets, seed); - ByteArrayOutputStream buf = new ByteArrayOutputStream(); - c.serialize(buf); - - byte[] b = buf.toByteArray(); + byte[] b = c.toByteArray(); assertThrows(SketchesArgumentException.class, () -> CountMinSketch.deserialize(b, seed - 1)); CountMinSketch d = CountMinSketch.deserialize(b, seed); @@ -228,11 +225,10 @@ public void serializeDeserializeTest() { c.update(i, 10*i*i); } - ByteArrayOutputStream buf = new ByteArrayOutputStream(); - c.serialize(buf); + byte[] b = c.toByteArray(); - assertThrows(SketchesArgumentException.class, () -> CountMinSketch.deserialize(buf.toByteArray(), seed - 1)); - CountMinSketch d = CountMinSketch.deserialize(buf.toByteArray(), seed); + assertThrows(SketchesArgumentException.class, () -> CountMinSketch.deserialize(b, seed - 1)); + CountMinSketch d = CountMinSketch.deserialize(b, seed); assertEquals(d.getNumHashes_(), c.getNumHashes_()); assertEquals(d.getNumBuckets_(), c.getNumBuckets_()); diff --git a/src/test/java/org/apache/datasketches/cpc/CpcSketchCrossLanguageTest.java b/src/test/java/org/apache/datasketches/cpc/CpcSketchCrossLanguageTest.java index 16cc55db9..906f1a52a 100644 --- a/src/test/java/org/apache/datasketches/cpc/CpcSketchCrossLanguageTest.java +++ b/src/test/java/org/apache/datasketches/cpc/CpcSketchCrossLanguageTest.java @@ -89,7 +89,7 @@ public void checkAllFlavorsGo() throws IOException { int flavorIdx = 0; for (int n: nArr) { final byte[] bytes = Files.readAllBytes(goPath.resolve("cpc_n" + n + "_go.sk")); - final CpcSketch sketch = CpcSketch.heapify(Memory.wrap(bytes)); + final CpcSketch sketch = CpcSketch.heapify(MemorySegment.ofArray(bytes)); assertEquals(sketch.getFlavor(), flavorArr[flavorIdx++]); assertEquals(sketch.getEstimate(), n, n * 0.02); }