diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/WeightedList.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/WeightedList.java index ad5e131cb2d7..5eb317fc2875 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/WeightedList.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/WeightedList.java @@ -20,6 +20,7 @@ import java.util.List; import java.util.concurrent.atomic.AtomicLong; import org.apache.beam.sdk.util.Weighted; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.LongMath; /** Facade for a {@link List} that keeps track of weight, for cache limit reasons. */ public class WeightedList implements Weighted { @@ -71,14 +72,6 @@ public void addAll(List values, long weight) { } public void accumulateWeight(long weight) { - this.weight.accumulateAndGet( - weight, - (first, second) -> { - try { - return Math.addExact(first, second); - } catch (ArithmeticException e) { - return Long.MAX_VALUE; - } - }); + this.weight.accumulateAndGet(weight, LongMath::saturatedAdd); } } diff --git a/sdks/java/harness/jmh/src/main/java/org/apache/beam/fn/harness/jmh/ProcessBundleBenchmark.java b/sdks/java/harness/jmh/src/main/java/org/apache/beam/fn/harness/jmh/ProcessBundleBenchmark.java index 3b8fbeaf3dd0..fecc946a184f 100644 --- a/sdks/java/harness/jmh/src/main/java/org/apache/beam/fn/harness/jmh/ProcessBundleBenchmark.java +++ b/sdks/java/harness/jmh/src/main/java/org/apache/beam/fn/harness/jmh/ProcessBundleBenchmark.java @@ -198,12 +198,16 @@ public void log(LogEntry entry) { @TearDown public void tearDown() { try { + // Shutting down the control server should terminate the SDK client. + // We do this before shutting down the logging server in particular as that can + // trigger exceptions if the client was not yet shut down. 
controlServer.close(); + sdkHarnessExecutorFuture.get(); + stateServer.close(); dataServer.close(); loggingServer.close(); controlClient.close(); - sdkHarnessExecutorFuture.get(); } catch (InterruptedException ignored) { Thread.currentThread().interrupt(); } catch (Exception e) { diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/Caches.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/Caches.java index 089ee3eda0fb..de4303c4e7c1 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/Caches.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/Caches.java @@ -17,6 +17,8 @@ */ package org.apache.beam.fn.harness; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; + import java.util.Arrays; import java.util.Collections; import java.util.Objects; @@ -25,6 +27,7 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.LongAdder; import java.util.function.Function; +import javax.annotation.Nullable; import org.apache.beam.fn.harness.Cache.Shrinkable; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.SdkHarnessOptions; @@ -41,7 +44,6 @@ import org.slf4j.LoggerFactory; /** Utility methods used to instantiate and operate over cache instances. */ -@SuppressWarnings("nullness") public final class Caches { private static final Logger LOG = LoggerFactory.getLogger(Caches.class); @@ -70,7 +72,7 @@ public final class Caches { public static final long REFERENCE_SIZE = 8; /** Returns the amount of memory in bytes the provided object consumes. 
*/ - public static long weigh(Object o) { + public static long weigh(@Nullable Object o) { if (o == null) { return REFERENCE_SIZE; } @@ -115,6 +117,7 @@ static class ShrinkOnEviction implements RemovalListener> cacheBuilder, LongAdder weightInBytes) { this.cache = cacheBuilder.removalListener(this).build(); @@ -130,18 +133,19 @@ static class ShrinkOnEviction implements RemovalListener> removalNotification) { - weightInBytes.add( - -(removalNotification.getKey().getWeight() + removalNotification.getValue().getWeight())); - if (removalNotification.wasEvicted()) { - if (!(removalNotification.getValue().getValue() instanceof Cache.Shrinkable)) { - return; - } - Object updatedEntry = ((Shrinkable) removalNotification.getValue().getValue()).shrink(); - if (updatedEntry != null) { - cache.put( - removalNotification.getKey(), - addWeightedValue(removalNotification.getKey(), updatedEntry, weightInBytes)); - } + CompositeKey key = checkNotNull(removalNotification.getKey()); + WeightedValue value = checkNotNull(removalNotification.getValue()); + weightInBytes.add(-(key.getWeight() + value.getWeight())); + if (!removalNotification.wasEvicted()) { + return; + } + @Nullable Object v = value.getValue(); + if (!(v instanceof Cache.Shrinkable)) { + return; + } + @Nullable Object updatedEntry = ((Shrinkable) v).shrink(); + if (updatedEntry != null) { + cache.put(key, addWeightedValue(key, updatedEntry, weightInBytes)); } } } @@ -282,8 +286,8 @@ private static class SubCache implements Cache { } @Override - public V peek(K key) { - WeightedValue value = cache.getIfPresent(keyPrefix.valueKey(key)); + public @Nullable V peek(K key) { + @Nullable WeightedValue value = cache.getIfPresent(keyPrefix.valueKey(key)); if (value == null) { return null; } @@ -298,7 +302,9 @@ public V computeIfAbsent(K key, Function loadingFunction) { cache .get( compositeKey, - () -> addWeightedValue(compositeKey, loadingFunction.apply(key), weightInBytes)) + () -> + addWeightedValue( + compositeKey, 
checkNotNull(loadingFunction.apply(key)), weightInBytes)) .getValue(); } catch (ExecutionException e) { throw new RuntimeException(e); @@ -308,7 +314,7 @@ public V computeIfAbsent(K key, Function loadingFunction) { @Override public void put(K key, V value) { CompositeKey compositeKey = keyPrefix.valueKey(key); - cache.put(compositeKey, addWeightedValue(compositeKey, value, weightInBytes)); + cache.put(compositeKey, addWeightedValue(compositeKey, checkNotNull(value), weightInBytes)); } @Override @@ -356,7 +362,7 @@ CompositeKeyPrefix subKey(Object suffix, Object... additionalSuffixes) { return new CompositeKeyPrefix(subKey, subKeyWeight); } - CompositeKey valueKey(K k) { + CompositeKey valueKey(@Nullable K k) { return new CompositeKey(namespace, weight, k); } @@ -391,10 +397,10 @@ boolean isEquivalentNamespace(CompositeKey otherKey) { @VisibleForTesting static class CompositeKey implements Weighted { private final Object[] namespace; - private final Object key; + private final @Nullable Object key; private final long weight; - private CompositeKey(Object[] namespace, long namespaceWeight, Object key) { + private CompositeKey(Object[] namespace, long namespaceWeight, @Nullable Object key) { this.namespace = namespace; this.key = key; this.weight = namespaceWeight + weigh(key); @@ -402,11 +408,15 @@ private CompositeKey(Object[] namespace, long namespaceWeight, Object key) { @Override public String toString() { - return "CompositeKey{namespace=" + Arrays.toString(namespace) + ", key=" + key + "}"; + return "CompositeKey{namespace=" + + Arrays.toString(namespace) + + ", key=" + + String.valueOf(key) + + "}"; } @Override - public boolean equals(Object o) { + public boolean equals(@Nullable Object o) { if (this == o) { return true; } @@ -434,6 +444,7 @@ public long getWeight() { *

The set of keys that are tracked are only those provided to {@link #peek} and {@link * #computeIfAbsent}. */ + @SuppressWarnings("nullness") public static class ClearableCache extends SubCache { private final Set weakHashSet; diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java index 1e06c98f2e31..3d83d0582e2c 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java @@ -17,12 +17,13 @@ */ package org.apache.beam.fn.harness.state; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; -import com.google.auto.value.AutoValue; import java.io.IOException; import java.io.InputStream; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.Iterator; import java.util.List; @@ -49,14 +50,13 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Throwables; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.AbstractIterator; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.LongMath; /** * Adapters which convert a logical series of chunks using continuation tokens over the Beam Fn * State API into an {@link Iterator} of {@link ByteString}s. 
*/ -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) public class StateFetchingIterators { // do not instantiate @@ -146,7 +146,7 @@ public DecodingIterator(PrefetchableIterator chunkIterator, Coder } @Override - protected T computeNext() { + protected @Nullable T computeNext() { try { while (currentChunk.available() == 0) { if (chunkIterator.hasNext()) { @@ -249,15 +249,11 @@ static class BlocksPrefix extends Blocks implements Shrinkable block : blocks) { - sum = Math.addExact(sum, block.getWeight()); - } - return sum; - } catch (ArithmeticException e) { - return Long.MAX_VALUE; + long sum = 8 + blocks.size() * 8L; + for (Block block : blocks) { + sum = LongMath.saturatedAdd(sum, block.getWeight()); } + return sum; } BlocksPrefix(List> blocks) { @@ -265,7 +261,7 @@ public long getWeight() { } @Override - public BlocksPrefix shrink() { + public @Nullable BlocksPrefix shrink() { // Copy the list to not hold a reference to the tail of the original list. 
List> subList = new ArrayList<>(getBlocks().subList(0, getBlocks().size() / 2)); if (subList.isEmpty()) { @@ -280,10 +276,71 @@ public List> getBlocks() { } } - @AutoValue - abstract static class Block implements Weighted { - private static final Block EMPTY = - fromValues(WeightedList.of(Collections.emptyList(), 0), null); + static class Block implements Weighted { + private final List values; + private final @Nullable ByteString nextToken; + private final long weight; + private static final long ARRAY_LIST_OVERHEAD = 3 * Caches.REFERENCE_SIZE; + private static final Block EMPTY = new Block<>(ImmutableList.of(), null, 0); + + private Block(List values, @Nullable ByteString nextToken, long weight) { + this.values = values; + this.nextToken = nextToken; + this.weight = weight; + } + + public @Nullable ByteString getNextToken() { + return nextToken; + } + + @Override + public long getWeight() { + return weight; + } + + public List getValues() { + return values; + } + + static class Builder { + private Builder(int initialCapacity) { + values = new ArrayList<>(initialCapacity); + weight = ARRAY_LIST_OVERHEAD; + } + + void addAndWeighAll(List addedValues) { + values.addAll(addedValues); + for (@Nullable T v : addedValues) { + weight = LongMath.saturatedAdd(weight, Caches.weigh(v)); + } + } + + void addAllWithWeight(List addedValues, long addedWeight) { + values.addAll(addedValues); + weight = LongMath.saturatedAdd(weight, addedWeight); + } + + void addBlock(Block b) { + values.addAll(b.getValues()); + weight = LongMath.saturatedAdd(weight, b.getWeight()); + } + + // The builder should not be used after this method. 
+ Block build() { + if (values.isEmpty()) { + return emptyBlock(); + } + weight = LongMath.saturatedAdd(values.size() * Caches.REFERENCE_SIZE, weight); + return new Block<>(values, null, weight); + } + + private long weight; + private final ArrayList values; + } + + static Builder builder(int initialCapacity) { + return new Builder<>(initialCapacity); + } @SuppressWarnings("unchecked") // Based upon as Collections.emptyList() public static Block emptyBlock() { @@ -299,29 +356,46 @@ public static Block mutatedBlock(WeightedList values) { } public static Block fromValues(List values, @Nullable ByteString nextToken) { - return fromValues(WeightedList.of(values, Caches.weigh(values)), nextToken); + long listWeight = values.size() * Caches.REFERENCE_SIZE; + for (@Nullable T value : values) { + listWeight = LongMath.saturatedAdd(listWeight, Caches.weigh(value)); + } + return copyValues(values, listWeight, nextToken); } public static Block fromValues( WeightedList values, @Nullable ByteString nextToken) { - long weight = values.getWeight() + 24; + return copyValues(values.getBacking(), values.getWeight(), nextToken); + } + + private static Block copyValues( + List values, long valuesWeight, @Nullable ByteString nextToken) { + if (nextToken == null && values.isEmpty()) { + return emptyBlock(); + } + long listOverhead; + if (values instanceof ImmutableList) { + listOverhead = Caches.REFERENCE_SIZE * 3; + } else { + // We can't use ImmutableList.copyOf because that requires non-null values + // and null values are supported in Beam. 
+ @SuppressWarnings("unchecked") + List copiedValues = (List) Arrays.asList(values.toArray()); + values = copiedValues; + listOverhead = ARRAY_LIST_OVERHEAD; + } + long weight = LongMath.saturatedAdd(valuesWeight, listOverhead); if (nextToken != null) { if (nextToken.isEmpty()) { nextToken = ByteString.EMPTY; } else { - weight += Caches.weigh(nextToken); + // We don't expect large tokens that would not be copied by ByteString so we just count the size plus + // some overhead as weighing accurately is expensive. + weight = LongMath.saturatedAdd(weight, (long) nextToken.size() + Caches.REFERENCE_SIZE * 2); } } - return new AutoValue_StateFetchingIterators_CachingStateIterable_Block<>( - values.getBacking(), nextToken, weight); + return new Block<>(values, nextToken, weight); } - - abstract List getValues(); - - abstract @Nullable ByteString getNextToken(); - - @Override - public abstract long getWeight(); } private final Cache> cache; @@ -372,10 +446,11 @@ public void remove(Set toRemoveStructuralValues) { totalSize += tBlock.getValues().size(); } - WeightedList allValues = WeightedList.of(new ArrayList<>(totalSize), 0L); + Block.Builder builder = Block.builder(totalSize); + List blockValuesToKeep = new ArrayList<>(); for (Block block : blocks) { + blockValuesToKeep.clear(); boolean valueRemovedFromBlock = false; - List blockValuesToKeep = new ArrayList<>(); for (T value : block.getValues()) { if (!toRemoveStructuralValues.contains(valueCoder.structuralValue(value))) { blockValuesToKeep.add(value); @@ -387,13 +462,13 @@ public void remove(Set toRemoveStructuralValues) { // If any value was removed from this block, need to estimate the weight again. // Otherwise, just reuse the block's weight. 
if (valueRemovedFromBlock) { - allValues.addAll(blockValuesToKeep, Caches.weigh(block.getValues())); + builder.addAndWeighAll(blockValuesToKeep); } else { - allValues.addAll(block.getValues(), block.getWeight()); + builder.addBlock(block); } } - cache.put(IterableCacheKey.INSTANCE, new MutatedBlocks<>(Block.mutatedBlock(allValues))); + cache.put(IterableCacheKey.INSTANCE, new MutatedBlocks<>(builder.build())); } /** @@ -484,21 +559,16 @@ private void appendHelper(List newValues, long newWeight) { for (Block block : blocks) { totalSize += block.getValues().size(); } - WeightedList allValues = WeightedList.of(new ArrayList<>(totalSize), 0L); + Block.Builder builder = Block.builder(totalSize); for (Block block : blocks) { - allValues.addAll(block.getValues(), block.getWeight()); + builder.addBlock(block); } if (newWeight < 0) { - if (newValues.size() == 1) { - // Optimize weighing of the common value state as single single-element bag state. - newWeight = Caches.weigh(newValues.get(0)); - } else { - newWeight = Caches.weigh(newValues); - } + builder.addAndWeighAll(newValues); + } else { + builder.addAllWithWeight(newValues, newWeight); } - allValues.addAll(newValues, newWeight); - - cache.put(IterableCacheKey.INSTANCE, new MutatedBlocks<>(Block.mutatedBlock(allValues))); + cache.put(IterableCacheKey.INSTANCE, new MutatedBlocks<>(builder.build())); } class CachingStateIterator implements PrefetchableIterator { @@ -515,8 +585,7 @@ public CachingStateIterator() { new DataStreamDecoder<>(valueCoder, underlyingStateFetchingIterator); this.currentBlock = Block.fromValues( - WeightedList.of(Collections.emptyList(), 0L), - stateRequestForFirstChunk.getGet().getContinuationToken()); + ImmutableList.of(), stateRequestForFirstChunk.getGet().getContinuationToken()); this.currentCachedBlockValueIndex = 0; } @@ -526,11 +595,12 @@ public boolean isReady() { if (currentBlock.getValues().size() > currentCachedBlockValueIndex) { return true; } - if (currentBlock.getNextToken() == null) 
{ + @Nullable ByteString nextToken = currentBlock.getNextToken(); + if (nextToken == null) { return true; } - Blocks existing = cache.peek(IterableCacheKey.INSTANCE); - boolean isFirstBlock = ByteString.EMPTY.equals(currentBlock.getNextToken()); + @Nullable Blocks existing = cache.peek(IterableCacheKey.INSTANCE); + boolean isFirstBlock = ByteString.EMPTY.equals(nextToken); if (existing == null) { // If there is nothing cached and we are on the first block then we are not ready. return false; @@ -542,9 +612,7 @@ public boolean isReady() { List> blocks = existing.getBlocks(); int currentBlockIndex = 0; for (; currentBlockIndex < blocks.size(); ++currentBlockIndex) { - if (currentBlock - .getNextToken() - .equals(blocks.get(currentBlockIndex).getNextToken())) { + if (Objects.equals(nextToken, blocks.get(currentBlockIndex).getNextToken())) { break; } } @@ -580,8 +648,7 @@ public boolean hasNext() { return false; } // Release the block while we are loading the next one. - currentBlock = - Block.fromValues(WeightedList.of(Collections.emptyList(), 0L), ByteString.EMPTY); + currentBlock = Block.emptyBlock(); @Nullable Blocks existing = cache.peek(IterableCacheKey.INSTANCE); boolean isFirstBlock = ByteString.EMPTY.equals(nextToken); @@ -603,7 +670,7 @@ public boolean hasNext() { List> blocks = existing.getBlocks(); int currentBlockIndex = 0; for (; currentBlockIndex < blocks.size(); ++currentBlockIndex) { - if (nextToken.equals(blocks.get(currentBlockIndex).getNextToken())) { + if (Objects.equals(nextToken, blocks.get(currentBlockIndex).getNextToken())) { break; } } @@ -622,7 +689,8 @@ public boolean hasNext() { // tokens. 
if (existing != null && !existing.getBlocks().isEmpty() - && nextToken.equals( + && Objects.equals( + nextToken, existing.getBlocks().get(existing.getBlocks().size() - 1).getNextToken())) { List> newBlocks = new ArrayList<>(currentBlockIndex + 1); newBlocks.addAll(existing.getBlocks()); @@ -639,8 +707,8 @@ public boolean hasNext() { Block loadNextBlock(ByteString continuationToken) { underlyingStateFetchingIterator.seekToContinuationToken(continuationToken); WeightedList values = dataStreamDecoder.decodeFromChunkBoundaryToChunkBoundary(); - ByteString nextToken = underlyingStateFetchingIterator.getContinuationToken(); - if (ByteString.EMPTY.equals(nextToken)) { + @Nullable ByteString nextToken = underlyingStateFetchingIterator.getContinuationToken(); + if (Objects.equals(nextToken, ByteString.EMPTY)) { nextToken = null; } return Block.fromValues(values, nextToken); @@ -667,8 +735,8 @@ static class LazyBlockingStateFetchingIterator implements PrefetchableIterator prefetchedResponse; + private @Nullable ByteString continuationToken; + private @Nullable CompletableFuture prefetchedResponse; LazyBlockingStateFetchingIterator( BeamFnStateClient beamFnStateClient, StateRequest stateRequestForFirstChunk) { @@ -739,7 +807,7 @@ public ByteString next() { prefetch(); StateResponse stateResponse; try { - stateResponse = prefetchedResponse.get(); + stateResponse = checkNotNull(prefetchedResponse).get(); } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new IllegalStateException(e); diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateFetchingIteratorsTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateFetchingIteratorsTest.java index d1cacf534ee5..43087cabbd37 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateFetchingIteratorsTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateFetchingIteratorsTest.java @@ -47,6 +47,7 @@ 
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest; import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateResponse; import org.apache.beam.sdk.coders.BigEndianIntegerCoder; +import org.apache.beam.sdk.coders.NullableCoder; import org.apache.beam.sdk.fn.data.WeightedList; import org.apache.beam.sdk.fn.stream.PrefetchableIterator; import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString; @@ -296,6 +297,44 @@ public void testBlocksWeight() throws Exception { assertEquals(Long.MAX_VALUE, blocksOverflow.getWeight()); } + // Regression test: ensure null values are supported + @Test + public void testNullableValues() throws Exception { + StateRequest requestForFirstChunk = + StateRequest.newBuilder() + .setStateKey( + StateKey.newBuilder() + .setBagUserState( + StateKey.BagUserState.newBuilder() + .setTransformId("transformId") + .setUserStateId("stateId") + .setKey(ByteString.copyFromUtf8("key")) + .setWindow(ByteString.copyFromUtf8("window")))) + .setGet(StateGetRequest.getDefaultInstance()) + .build(); + + List expected = Arrays.asList(0, null, 1, null, 2); + FakeBeamFnStateClient fakeStateClient = + new FakeBeamFnStateClient( + NullableCoder.of(BigEndianIntegerCoder.of()), + ImmutableMap.of(requestForFirstChunk.getStateKey(), expected), + 2); + + CachingStateIterable iterable = + new CachingStateIterable<>( + Caches.eternal(), + fakeStateClient, + requestForFirstChunk, + NullableCoder.of(BigEndianIntegerCoder.of())); + + List results = new ArrayList<>(); + PrefetchableIterator iterator = iterable.createIterator(); + while (iterator.hasNext()) { + results.add(iterator.next()); + } + assertEquals(expected, results); + } + private CachingStateIterable create(int chunkSize, int... values) { StateRequest requestForFirstChunk = StateRequest.newBuilder()