From 839fa618a1323b176d442d860b4f685e9b627fa8 Mon Sep 17 00:00:00 2001 From: Sam Whittle Date: Thu, 26 Mar 2026 09:54:57 +0100 Subject: [PATCH 1/5] Reapply "Reduce weighing overhead for caching blocks (#36897)" This reverts commit 63177cb5db6623bbfe9ba7ede1041f58dd4ddb93. --- .../apache/beam/sdk/fn/data/WeightedList.java | 11 +-- .../harness/state/StateFetchingIterators.java | 83 ++++++++++++------- 2 files changed, 55 insertions(+), 39 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/WeightedList.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/WeightedList.java index ad5e131cb2d7..5eb317fc2875 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/WeightedList.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/WeightedList.java @@ -20,6 +20,7 @@ import java.util.List; import java.util.concurrent.atomic.AtomicLong; import org.apache.beam.sdk.util.Weighted; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.LongMath; /** Facade for a {@link List} that keeps track of weight, for cache limit reasons. */ public class WeightedList implements Weighted { @@ -71,14 +72,6 @@ public void addAll(List values, long weight) { } public void accumulateWeight(long weight) { - this.weight.accumulateAndGet( - weight, - (first, second) -> { - try { - return Math.addExact(first, second); - } catch (ArithmeticException e) { - return Long.MAX_VALUE; - } - }); + this.weight.accumulateAndGet(weight, LongMath::saturatedAdd); } } diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java index 1e06c98f2e31..339ddad4061e 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java @@ -49,6 +49,8 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Throwables; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.AbstractIterator; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.LongMath; /** * Adapters which convert a logical series of chunks using continuation tokens over the Beam Fn @@ -249,15 +251,11 @@ static class BlocksPrefix extends Blocks implements Shrinkable block : blocks) { - sum = Math.addExact(sum, block.getWeight()); - } - return sum; - } catch (ArithmeticException e) { - return Long.MAX_VALUE; + long sum = 8 + blocks.size() * 8L; + for (Block block : blocks) { + sum = LongMath.saturatedAdd(sum, block.getWeight()); } + return sum; } BlocksPrefix(List> blocks) { @@ -282,8 +280,7 @@ public List> getBlocks() { @AutoValue abstract static class Block implements Weighted { - private static final Block EMPTY = - fromValues(WeightedList.of(Collections.emptyList(), 0), null); + private static final Block EMPTY = fromValues(ImmutableList.of(), 0, null); @SuppressWarnings("unchecked") // Based upon as Collections.emptyList() public static Block emptyBlock() { @@ -299,21 +296,37 @@ public static Block mutatedBlock(WeightedList values) { } public static Block fromValues(List values, @Nullable ByteString nextToken) { - return fromValues(WeightedList.of(values, Caches.weigh(values)), nextToken); + if (values.isEmpty() && nextToken == null) { + return emptyBlock(); + } + ImmutableList immutableValues = ImmutableList.copyOf(values); + long listWeight = immutableValues.size() * Caches.REFERENCE_SIZE; + for (T value : immutableValues) { + listWeight = LongMath.saturatedAdd(listWeight, Caches.weigh(value)); + } + return fromValues(immutableValues, listWeight, nextToken); } public static Block fromValues( WeightedList values, @Nullable ByteString nextToken) { - long weight = values.getWeight() + 24; + if (values.isEmpty() && nextToken == null) { + return emptyBlock(); + } + return fromValues(ImmutableList.copyOf(values.getBacking()), values.getWeight(), nextToken); + } + + private static Block fromValues( + ImmutableList values, long listWeight, @Nullable ByteString nextToken) { + long weight = LongMath.saturatedAdd(listWeight, 24); if (nextToken != null) { if (nextToken.isEmpty()) { nextToken = ByteString.EMPTY; } else { - weight += Caches.weigh(nextToken); + weight = LongMath.saturatedAdd(weight, Caches.weigh(nextToken)); } } return new AutoValue_StateFetchingIterators_CachingStateIterable_Block<>( - values.getBacking(), nextToken, weight); + values, nextToken, weight); } abstract List getValues(); @@ -372,10 +385,12 @@ public void remove(Set toRemoveStructuralValues) { totalSize += tBlock.getValues().size(); } - WeightedList allValues = WeightedList.of(new ArrayList<>(totalSize), 0L); + ImmutableList.Builder allValues = ImmutableList.builderWithExpectedSize(totalSize); + long weight = 0; + List blockValuesToKeep = new ArrayList<>(); for (Block block : blocks) { + blockValuesToKeep.clear(); boolean valueRemovedFromBlock = false; - List blockValuesToKeep = new ArrayList<>(); for (T value : block.getValues()) { if (!toRemoveStructuralValues.contains(valueCoder.structuralValue(value))) { blockValuesToKeep.add(value); @@ -387,13 +402,19 @@ public void remove(Set toRemoveStructuralValues) { // If any value was removed from this block, need to estimate the weight again. // Otherwise, just reuse the block's weight. if (valueRemovedFromBlock) { - allValues.addAll(blockValuesToKeep, Caches.weigh(block.getValues())); + allValues.addAll(blockValuesToKeep); + for (T value : blockValuesToKeep) { + weight = LongMath.saturatedAdd(weight, Caches.weigh(value)); + } } else { - allValues.addAll(block.getValues(), block.getWeight()); + allValues.addAll(block.getValues()); + weight = LongMath.saturatedAdd(weight, block.getWeight()); } } - cache.put(IterableCacheKey.INSTANCE, new MutatedBlocks<>(Block.mutatedBlock(allValues))); + cache.put( + IterableCacheKey.INSTANCE, + new MutatedBlocks<>(Block.fromValues(allValues.build(), weight, null))); } /** @@ -484,21 +505,24 @@ private void appendHelper(List newValues, long newWeight) { for (Block block : blocks) { totalSize += block.getValues().size(); } - WeightedList allValues = WeightedList.of(new ArrayList<>(totalSize), 0L); + ImmutableList.Builder allValues = ImmutableList.builderWithExpectedSize(totalSize); + long weight = 0; for (Block block : blocks) { - allValues.addAll(block.getValues(), block.getWeight()); + allValues.addAll(block.getValues()); + weight = LongMath.saturatedAdd(weight, block.getWeight()); } if (newWeight < 0) { - if (newValues.size() == 1) { - // Optimize weighing of the common value state as single single-element bag state. - newWeight = Caches.weigh(newValues.get(0)); - } else { - newWeight = Caches.weigh(newValues); + newWeight = 0; + for (T value : newValues) { + newWeight = LongMath.saturatedAdd(newWeight, Caches.weigh(value)); } } - allValues.addAll(newValues, newWeight); + allValues.addAll(newValues); + weight = LongMath.saturatedAdd(weight, newWeight); - cache.put(IterableCacheKey.INSTANCE, new MutatedBlocks<>(Block.mutatedBlock(allValues))); + cache.put( + IterableCacheKey.INSTANCE, + new MutatedBlocks<>(Block.fromValues(allValues.build(), weight, null))); } class CachingStateIterator implements PrefetchableIterator { @@ -580,8 +604,7 @@ public boolean hasNext() { return false; } // Release the block while we are loading the next one. - currentBlock = - Block.fromValues(WeightedList.of(Collections.emptyList(), 0L), ByteString.EMPTY); + currentBlock = Block.emptyBlock(); @Nullable Blocks existing = cache.peek(IterableCacheKey.INSTANCE); boolean isFirstBlock = ByteString.EMPTY.equals(nextToken); From 4bbec9dab116debdb9ad5d2b65b713f9690b17a1 Mon Sep 17 00:00:00 2001 From: Sam Whittle Date: Thu, 26 Mar 2026 13:43:33 +0100 Subject: [PATCH 2/5] Remove use of ImmutableList where values are not known to be non-null --- .../org/apache/beam/fn/harness/Caches.java | 4 +- .../harness/state/StateFetchingIterators.java | 162 ++++++++++++------ .../state/StateFetchingIteratorsTest.java | 39 +++++ 3 files changed, 147 insertions(+), 58 deletions(-) diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/Caches.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/Caches.java index 089ee3eda0fb..8782a7ed8b67 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/Caches.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/Caches.java @@ -25,6 +25,7 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.LongAdder; import java.util.function.Function; +import javax.annotation.Nullable; import org.apache.beam.fn.harness.Cache.Shrinkable; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.SdkHarnessOptions; @@ -70,7 +71,7 @@ public final class Caches { public static final long REFERENCE_SIZE = 8; /** Returns the amount of memory in bytes the provided object consumes. */ - public static long weigh(Object o) { + public static long weigh(@Nullable Object o) { if (o == null) { return REFERENCE_SIZE; } @@ -136,6 +137,7 @@ public void onRemoval( if (!(removalNotification.getValue().getValue() instanceof Cache.Shrinkable)) { return; } + @Nullable Object updatedEntry = ((Shrinkable) removalNotification.getValue().getValue()).shrink(); if (updatedEntry != null) { cache.put( diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java index 339ddad4061e..59a5ee50c0b6 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java @@ -17,12 +17,14 @@ */ package org.apache.beam.fn.harness.state; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; import com.google.auto.value.AutoValue; import java.io.IOException; import java.io.InputStream; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.Iterator; import java.util.List; @@ -56,9 +58,6 @@ * Adapters which convert a logical series of chunks using continuation tokens over the Beam Fn * State API into an {@link Iterator} of {@link ByteString}s. */ -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) public class StateFetchingIterators { // do not instantiate @@ -148,7 +147,7 @@ public DecodingIterator(PrefetchableIterator chunkIterator, Coder } @Override - protected T computeNext() { + protected @Nullable T computeNext() { try { while (currentChunk.available() == 0) { if (chunkIterator.hasNext()) { @@ -263,7 +262,7 @@ public long getWeight() { } @Override - public BlocksPrefix shrink() { + public @Nullable BlocksPrefix shrink() { // Copy the list to not hold a reference to the tail of the original list. List> subList = new ArrayList<>(getBlocks().subList(0, getBlocks().size() / 2)); if (subList.isEmpty()) { @@ -280,7 +279,61 @@ public List> getBlocks() { @AutoValue abstract static class Block implements Weighted { - private static final Block EMPTY = fromValues(ImmutableList.of(), 0, null); + private static class EmptyBlock extends Block { + @Override + List getValues() { + return ImmutableList.of(); + } + + @Nullable + @Override + ByteString getNextToken() { + return null; + } + + @Override + public long getWeight() { + return 0; + } + } + + private static final Block EMPTY = new EmptyBlock(); + + static class Builder { + private Builder(int size) { + values = new ArrayList<>(size); + } + + void addAndWeighAll(List addedValues) { + values.addAll(addedValues); + for (@Nullable T v : addedValues) { + weight = LongMath.saturatedAdd(weight, Caches.weigh(v)); + } + } + + void addAllWithWeight(List addedValues, long addedWeight) { + values.addAll(addedValues); + weight = LongMath.saturatedAdd(weight, addedWeight); + } + + void addBlock(Block b) { + values.addAll(b.getValues()); + weight = LongMath.saturatedAdd(weight, b.getWeight()); + } + + // The builder should not be used after this method. + Block build() { + weight = LongMath.saturatedAdd(values.size() * Caches.REFERENCE_SIZE, weight); + return fromValues(values, weight, false, null); + } + + private long weight; + private final ArrayList values; + } + + static Builder builder(int maxSize) { + return new Builder<>(maxSize); + } @SuppressWarnings("unchecked") // Based upon as Collections.emptyList() public static Block emptyBlock() { @@ -296,28 +349,38 @@ public static Block mutatedBlock(WeightedList values) { } public static Block fromValues(List values, @Nullable ByteString nextToken) { - if (values.isEmpty() && nextToken == null) { + if (nextToken == null && values.isEmpty()) { return emptyBlock(); } - ImmutableList immutableValues = ImmutableList.copyOf(values); - long listWeight = immutableValues.size() * Caches.REFERENCE_SIZE; - for (T value : immutableValues) { + long listWeight = values.size() * Caches.REFERENCE_SIZE; + for (@Nullable T value : values) { listWeight = LongMath.saturatedAdd(listWeight, Caches.weigh(value)); } - return fromValues(immutableValues, listWeight, nextToken); + return fromValues(values, listWeight, true, nextToken); } public static Block fromValues( WeightedList values, @Nullable ByteString nextToken) { - if (values.isEmpty() && nextToken == null) { - return emptyBlock(); - } - return fromValues(ImmutableList.copyOf(values.getBacking()), values.getWeight(), nextToken); + return fromValues(values.getBacking(), values.getWeight(), true, nextToken); } private static Block fromValues( - ImmutableList values, long listWeight, @Nullable ByteString nextToken) { - long weight = LongMath.saturatedAdd(listWeight, 24); + List values, long valuesWeight, boolean copyList, @Nullable ByteString nextToken) { + if (nextToken == null && values.isEmpty()) { + return emptyBlock(); + } + long listOverhead = 0; + if (values instanceof ImmutableList) { + listOverhead = Caches.REFERENCE_SIZE * 3; + } else if (copyList) { + // We can't use ImmutableList.copy because that requires non-null values + // and there are some null values. + @SuppressWarnings("unchecked") + List copiedValues = (List) Arrays.asList(values.toArray()); + values = copiedValues; + listOverhead = 3 * Caches.REFERENCE_SIZE + 8; + } + long weight = LongMath.saturatedAdd(valuesWeight, listOverhead); if (nextToken != null) { if (nextToken.isEmpty()) { nextToken = ByteString.EMPTY; @@ -329,6 +392,7 @@ private static Block fromValues( values, nextToken, weight); } + // The returned list should not be modified. abstract List getValues(); abstract @Nullable ByteString getNextToken(); @@ -385,8 +449,7 @@ public void remove(Set toRemoveStructuralValues) { totalSize += tBlock.getValues().size(); } - ImmutableList.Builder allValues = ImmutableList.builderWithExpectedSize(totalSize); - long weight = 0; + Block.Builder builder = Block.builder(totalSize); List blockValuesToKeep = new ArrayList<>(); for (Block block : blocks) { blockValuesToKeep.clear(); @@ -402,19 +465,13 @@ public void remove(Set toRemoveStructuralValues) { // If any value was removed from this block, need to estimate the weight again. // Otherwise, just reuse the block's weight. if (valueRemovedFromBlock) { - allValues.addAll(blockValuesToKeep); - for (T value : blockValuesToKeep) { - weight = LongMath.saturatedAdd(weight, Caches.weigh(value)); - } + builder.addAndWeighAll(blockValuesToKeep); } else { - allValues.addAll(block.getValues()); - weight = LongMath.saturatedAdd(weight, block.getWeight()); + builder.addBlock(block); } } - cache.put( - IterableCacheKey.INSTANCE, - new MutatedBlocks<>(Block.fromValues(allValues.build(), weight, null))); + cache.put(IterableCacheKey.INSTANCE, new MutatedBlocks<>(builder.build())); } /** @@ -505,24 +562,16 @@ private void appendHelper(List newValues, long newWeight) { for (Block block : blocks) { totalSize += block.getValues().size(); } - ImmutableList.Builder allValues = ImmutableList.builderWithExpectedSize(totalSize); - long weight = 0; + Block.Builder builder = Block.builder(totalSize); for (Block block : blocks) { - allValues.addAll(block.getValues()); - weight = LongMath.saturatedAdd(weight, block.getWeight()); + builder.addBlock(block); } if (newWeight < 0) { - newWeight = 0; - for (T value : newValues) { - newWeight = LongMath.saturatedAdd(newWeight, Caches.weigh(value)); - } + builder.addAndWeighAll(newValues); + } else { + builder.addAllWithWeight(newValues, newWeight); } - allValues.addAll(newValues); - weight = LongMath.saturatedAdd(weight, newWeight); - - cache.put( - IterableCacheKey.INSTANCE, - new MutatedBlocks<>(Block.fromValues(allValues.build(), weight, null))); + cache.put(IterableCacheKey.INSTANCE, new MutatedBlocks<>(builder.build())); } class CachingStateIterator implements PrefetchableIterator { @@ -539,8 +588,7 @@ public CachingStateIterator() { new DataStreamDecoder<>(valueCoder, underlyingStateFetchingIterator); this.currentBlock = Block.fromValues( - WeightedList.of(Collections.emptyList(), 0L), - stateRequestForFirstChunk.getGet().getContinuationToken()); + ImmutableList.of(), stateRequestForFirstChunk.getGet().getContinuationToken()); this.currentCachedBlockValueIndex = 0; } @@ -550,11 +598,12 @@ public boolean isReady() { if (currentBlock.getValues().size() > currentCachedBlockValueIndex) { return true; } - if (currentBlock.getNextToken() == null) { + @Nullable ByteString nextToken = currentBlock.getNextToken(); + if (nextToken == null) { return true; } - Blocks existing = cache.peek(IterableCacheKey.INSTANCE); - boolean isFirstBlock = ByteString.EMPTY.equals(currentBlock.getNextToken()); + @Nullable Blocks existing = cache.peek(IterableCacheKey.INSTANCE); + boolean isFirstBlock = ByteString.EMPTY.equals(nextToken); if (existing == null) { // If there is nothing cached and we are on the first block then we are not ready. return false; @@ -566,9 +615,7 @@ public boolean isReady() { List> blocks = existing.getBlocks(); int currentBlockIndex = 0; for (; currentBlockIndex < blocks.size(); ++currentBlockIndex) { - if (currentBlock - .getNextToken() - .equals(blocks.get(currentBlockIndex).getNextToken())) { + if (Objects.equals(nextToken, blocks.get(currentBlockIndex).getNextToken())) { break; } } @@ -626,7 +673,7 @@ public boolean hasNext() { List> blocks = existing.getBlocks(); int currentBlockIndex = 0; for (; currentBlockIndex < blocks.size(); ++currentBlockIndex) { - if (nextToken.equals(blocks.get(currentBlockIndex).getNextToken())) { + if (Objects.equals(nextToken, blocks.get(currentBlockIndex).getNextToken())) { break; } } @@ -645,7 +692,8 @@ public boolean hasNext() { // tokens. if (existing != null && !existing.getBlocks().isEmpty() - && nextToken.equals( + && Objects.equals( + nextToken, existing.getBlocks().get(existing.getBlocks().size() - 1).getNextToken())) { List> newBlocks = new ArrayList<>(currentBlockIndex + 1); newBlocks.addAll(existing.getBlocks()); @@ -662,8 +710,8 @@ public boolean hasNext() { Block loadNextBlock(ByteString continuationToken) { underlyingStateFetchingIterator.seekToContinuationToken(continuationToken); WeightedList values = dataStreamDecoder.decodeFromChunkBoundaryToChunkBoundary(); - ByteString nextToken = underlyingStateFetchingIterator.getContinuationToken(); - if (ByteString.EMPTY.equals(nextToken)) { + @Nullable ByteString nextToken = underlyingStateFetchingIterator.getContinuationToken(); + if (Objects.equals(nextToken, ByteString.EMPTY)) { nextToken = null; } return Block.fromValues(values, nextToken); @@ -690,8 +738,8 @@ static class LazyBlockingStateFetchingIterator implements PrefetchableIterator prefetchedResponse; + private @Nullable ByteString continuationToken; + private @Nullable CompletableFuture prefetchedResponse; LazyBlockingStateFetchingIterator( BeamFnStateClient beamFnStateClient, StateRequest stateRequestForFirstChunk) { @@ -762,7 +810,7 @@ public ByteString next() { prefetch(); StateResponse stateResponse; try { - stateResponse = prefetchedResponse.get(); + stateResponse = checkNotNull(prefetchedResponse).get(); } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new IllegalStateException(e); diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateFetchingIteratorsTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateFetchingIteratorsTest.java index d1cacf534ee5..43087cabbd37 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateFetchingIteratorsTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateFetchingIteratorsTest.java @@ -47,6 +47,7 @@ import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest; import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateResponse; import org.apache.beam.sdk.coders.BigEndianIntegerCoder; +import org.apache.beam.sdk.coders.NullableCoder; import org.apache.beam.sdk.fn.data.WeightedList; import org.apache.beam.sdk.fn.stream.PrefetchableIterator; import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString; @@ -296,6 +297,44 @@ public void testBlocksWeight() throws Exception { assertEquals(Long.MAX_VALUE, blocksOverflow.getWeight()); } + // Regression test: ensure null values are supported + @Test + public void testNullableValues() throws Exception { + StateRequest requestForFirstChunk = + StateRequest.newBuilder() + .setStateKey( + StateKey.newBuilder() + .setBagUserState( + StateKey.BagUserState.newBuilder() + .setTransformId("transformId") + .setUserStateId("stateId") + .setKey(ByteString.copyFromUtf8("key")) + .setWindow(ByteString.copyFromUtf8("window")))) + .setGet(StateGetRequest.getDefaultInstance()) + .build(); + + List expected = Arrays.asList(0, null, 1, null, 2); + FakeBeamFnStateClient fakeStateClient = + new FakeBeamFnStateClient( + NullableCoder.of(BigEndianIntegerCoder.of()), + ImmutableMap.of(requestForFirstChunk.getStateKey(), expected), + 2); + + CachingStateIterable iterable = + new CachingStateIterable<>( + Caches.eternal(), + fakeStateClient, + requestForFirstChunk, + NullableCoder.of(BigEndianIntegerCoder.of())); + + List results = new ArrayList<>(); + PrefetchableIterator iterator = iterable.createIterator(); + while (iterator.hasNext()) { + results.add(iterator.next()); + } + assertEquals(expected, results); + } + private CachingStateIterable create(int chunkSize, int... values) { StateRequest requestForFirstChunk = StateRequest.newBuilder() From 39bca4ebdf97eee5afd199004699a4c6d3acebfa Mon Sep 17 00:00:00 2001 From: Sam Whittle Date: Thu, 26 Mar 2026 14:36:34 +0100 Subject: [PATCH 3/5] rm suppressions --- .../org/apache/beam/fn/harness/Caches.java | 55 +++++++----- .../harness/state/StateFetchingIterators.java | 83 +++++++++---------- 2 files changed, 71 insertions(+), 67 deletions(-) diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/Caches.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/Caches.java index 8782a7ed8b67..de4303c4e7c1 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/Caches.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/Caches.java @@ -17,6 +17,8 @@ */ package org.apache.beam.fn.harness; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; + import java.util.Arrays; import java.util.Collections; import java.util.Objects; @@ -42,7 +44,6 @@ import org.slf4j.LoggerFactory; /** Utility methods used to instantiate and operate over cache instances. */ -@SuppressWarnings("nullness") public final class Caches { private static final Logger LOG = LoggerFactory.getLogger(Caches.class); @@ -116,6 +117,7 @@ static class ShrinkOnEviction implements RemovalListener> cacheBuilder, LongAdder weightInBytes) { this.cache = cacheBuilder.removalListener(this).build(); @@ -131,19 +133,19 @@ static class ShrinkOnEviction implements RemovalListener> removalNotification) { - weightInBytes.add( - -(removalNotification.getKey().getWeight() + removalNotification.getValue().getWeight())); - if (removalNotification.wasEvicted()) { - if (!(removalNotification.getValue().getValue() instanceof Cache.Shrinkable)) { - return; - } - @Nullable - Object updatedEntry = ((Shrinkable) removalNotification.getValue().getValue()).shrink(); - if (updatedEntry != null) { - cache.put( - removalNotification.getKey(), - addWeightedValue(removalNotification.getKey(), updatedEntry, weightInBytes)); - } + CompositeKey key = checkNotNull(removalNotification.getKey()); + WeightedValue value = checkNotNull(removalNotification.getValue()); + weightInBytes.add(-(key.getWeight() + value.getWeight())); + if (!removalNotification.wasEvicted()) { + return; + } + @Nullable Object v = value.getValue(); + if (!(v instanceof Cache.Shrinkable)) { + return; + } + @Nullable Object updatedEntry = ((Shrinkable) v).shrink(); + if (updatedEntry != null) { + cache.put(key, addWeightedValue(key, updatedEntry, weightInBytes)); } } } @@ -284,8 +286,8 @@ private static class SubCache implements Cache { } @Override - public V peek(K key) { - WeightedValue value = cache.getIfPresent(keyPrefix.valueKey(key)); + public @Nullable V peek(K key) { + @Nullable WeightedValue value = cache.getIfPresent(keyPrefix.valueKey(key)); if (value == null) { return null; } @@ -300,7 +302,9 @@ public V computeIfAbsent(K key, Function loadingFunction) { cache .get( compositeKey, - () -> addWeightedValue(compositeKey, loadingFunction.apply(key), weightInBytes)) + () -> + addWeightedValue( + compositeKey, checkNotNull(loadingFunction.apply(key)), weightInBytes)) .getValue(); } catch (ExecutionException e) { throw new RuntimeException(e); @@ -310,7 +314,7 @@ public V computeIfAbsent(K key, Function loadingFunction) { @Override public void put(K key, V value) { CompositeKey compositeKey = keyPrefix.valueKey(key); - cache.put(compositeKey, addWeightedValue(compositeKey, value, weightInBytes)); + cache.put(compositeKey, addWeightedValue(compositeKey, checkNotNull(value), weightInBytes)); } @Override @@ -358,7 +362,7 @@ CompositeKeyPrefix subKey(Object suffix, Object... additionalSuffixes) { return new CompositeKeyPrefix(subKey, subKeyWeight); } - CompositeKey valueKey(K k) { + CompositeKey valueKey(@Nullable K k) { return new CompositeKey(namespace, weight, k); } @@ -393,10 +397,10 @@ boolean isEquivalentNamespace(CompositeKey otherKey) { @VisibleForTesting static class CompositeKey implements Weighted { private final Object[] namespace; - private final Object key; + private final @Nullable Object key; private final long weight; - private CompositeKey(Object[] namespace, long namespaceWeight, Object key) { + private CompositeKey(Object[] namespace, long namespaceWeight, @Nullable Object key) { this.namespace = namespace; this.key = key; this.weight = namespaceWeight + weigh(key); @@ -404,11 +408,15 @@ private CompositeKey(Object[] namespace, long namespaceWeight, Object key) { @Override public String toString() { - return "CompositeKey{namespace=" + Arrays.toString(namespace) + ", key=" + key + "}"; + return "CompositeKey{namespace=" + + Arrays.toString(namespace) + + ", key=" + + String.valueOf(key) + + "}"; } @Override - public boolean equals(Object o) { + public boolean equals(@Nullable Object o) { if (this == o) { return true; } @@ -436,6 +444,7 @@ public long getWeight() { *

The set of keys that are tracked are only those provided to {@link #peek} and {@link * #computeIfAbsent}. */ + @SuppressWarnings("nullness") public static class ClearableCache extends SubCache { private final Set weakHashSet; diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java index 59a5ee50c0b6..6c3b8962eda9 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java @@ -20,7 +20,6 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; -import com.google.auto.value.AutoValue; import java.io.IOException; import java.io.InputStream; import java.util.ArrayList; @@ -277,31 +276,36 @@ public List> getBlocks() { } } - @AutoValue - abstract static class Block implements Weighted { - private static class EmptyBlock extends Block { - @Override - List getValues() { - return ImmutableList.of(); - } + static class Block implements Weighted { + private final List values; + private final @Nullable ByteString nextToken; + private final long weight; + private static final long ARRAY_LIST_OVERHEAD = 3 * Caches.REFERENCE_SIZE; + private static final Block EMPTY = new Block<>(ImmutableList.of(), null, 0); - @Nullable - @Override - ByteString getNextToken() { - return null; - } + private Block(List values, @Nullable ByteString nextToken, long weight) { + this.values = values; + this.nextToken = nextToken; + this.weight = weight; + } - @Override - public long getWeight() { - return 0; - } + public @Nullable ByteString getNextToken() { + return nextToken; } - private static final Block EMPTY = new EmptyBlock(); + @Override + public long getWeight() { + return weight; + } + + public List getValues() { + return values; + } static class Builder { - private Builder(int size) { - values = new ArrayList<>(size); + private Builder(int initialCapacity) { + values = new ArrayList<>(initialCapacity); + weight = ARRAY_LIST_OVERHEAD; } void addAndWeighAll(List addedValues) { @@ -323,16 +327,19 @@ void addBlock(Block b) { // The builder should not be used after this method. Block build() { + if (values.isEmpty()) { + return emptyBlock(); + } weight = LongMath.saturatedAdd(values.size() * Caches.REFERENCE_SIZE, weight); - return fromValues(values, weight, false, null); + return new Block<>(values, null, weight); } private long weight; private final ArrayList values; } - static Builder builder(int maxSize) { - return new Builder<>(maxSize); + static Builder builder(int initialCapacity) { + return new Builder<>(initialCapacity); } @SuppressWarnings("unchecked") // Based upon as Collections.emptyList() @@ -349,36 +356,33 @@ public static Block mutatedBlock(WeightedList values) { } public static Block fromValues(List values, @Nullable ByteString nextToken) { - if (nextToken == null && values.isEmpty()) { - return emptyBlock(); - } long listWeight = values.size() * Caches.REFERENCE_SIZE; for (@Nullable T value : values) { listWeight = LongMath.saturatedAdd(listWeight, Caches.weigh(value)); } - return fromValues(values, listWeight, true, nextToken); + return copyValues(values, listWeight, nextToken); } public static Block fromValues( WeightedList values, @Nullable ByteString nextToken) { - return fromValues(values.getBacking(), values.getWeight(), true, nextToken); + return copyValues(values.getBacking(), values.getWeight(), nextToken); } - private static Block fromValues( - List values, long valuesWeight, boolean copyList, @Nullable ByteString nextToken) { + private static Block copyValues( + List values, long valuesWeight, @Nullable ByteString nextToken) { if (nextToken == null && values.isEmpty()) { return emptyBlock(); } - long listOverhead = 0; + long listOverhead; if (values instanceof ImmutableList) { listOverhead = Caches.REFERENCE_SIZE * 3; - } else if (copyList) { + } else { // We can't use ImmutableList.copy because that requires non-null values - // and there are some null values. + // and null values are supported in Beam. @SuppressWarnings("unchecked") List copiedValues = (List) Arrays.asList(values.toArray()); values = copiedValues; - listOverhead = 3 * Caches.REFERENCE_SIZE + 8; + listOverhead = ARRAY_LIST_OVERHEAD; } long weight = LongMath.saturatedAdd(valuesWeight, listOverhead); if (nextToken != null) { @@ -388,17 +392,8 @@ private static Block fromValues( weight = LongMath.saturatedAdd(weight, Caches.weigh(nextToken)); } } - return new AutoValue_StateFetchingIterators_CachingStateIterable_Block<>( - values, nextToken, weight); + return new Block<>(values, nextToken, weight); } - - // The returned list should not be modified. - abstract List getValues(); - - abstract @Nullable ByteString getNextToken(); - - @Override - public abstract long getWeight(); } private final Cache> cache; From 112ec6d9ad5a9b7ef7d3872a5674a40a9f8c55e2 Mon Sep 17 00:00:00 2001 From: Sam Whittle Date: Thu, 26 Mar 2026 15:07:05 +0100 Subject: [PATCH 4/5] fix benchmark shutdown order --- .../apache/beam/fn/harness/jmh/ProcessBundleBenchmark.java | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sdks/java/harness/jmh/src/main/java/org/apache/beam/fn/harness/jmh/ProcessBundleBenchmark.java b/sdks/java/harness/jmh/src/main/java/org/apache/beam/fn/harness/jmh/ProcessBundleBenchmark.java index 3b8fbeaf3dd0..fecc946a184f 100644 --- a/sdks/java/harness/jmh/src/main/java/org/apache/beam/fn/harness/jmh/ProcessBundleBenchmark.java +++ b/sdks/java/harness/jmh/src/main/java/org/apache/beam/fn/harness/jmh/ProcessBundleBenchmark.java @@ -198,12 +198,16 @@ public void log(LogEntry entry) { @TearDown public void tearDown() { try { + // Shutting down the control server should terminate the sdk client. + // We do this before shutting down logging server in particular as that can + // trigger exceptions if the client was not yet shutdown. controlServer.close(); + sdkHarnessExecutorFuture.get(); + stateServer.close(); dataServer.close(); loggingServer.close(); controlClient.close(); - sdkHarnessExecutorFuture.get(); } catch (InterruptedException ignored) { Thread.currentThread().interrupt(); } catch (Exception e) { From 43c4731fe91b82323f34ae49457a4118010bba16 Mon Sep 17 00:00:00 2001 From: Sam Whittle Date: Fri, 27 Mar 2026 15:49:39 +0100 Subject: [PATCH 5/5] address another weigh of tokens --- .../apache/beam/fn/harness/state/StateFetchingIterators.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java index 6c3b8962eda9..3d83d0582e2c 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java @@ -389,7 +389,9 @@ private static Block copyValues( if (nextToken.isEmpty()) { nextToken = ByteString.EMPTY; } else { - weight = LongMath.saturatedAdd(weight, Caches.weigh(nextToken)); + // We don't expect large tokens that would not be copied by ByteString so we just count the size plus + // some overhead as weighing accurately is expensive. + weight = LongMath.saturatedAdd(weight, (long) nextToken.size() + Caches.REFERENCE_SIZE * 2); } } return new Block<>(values, nextToken, weight);