Skip to content

Commit 305dbcd

Browse files
RocMarshal, Tartarus0zm, and 1996fanrui
committed
[FLINK-38943][runtime] Support Adaptive Partition Selection for RescalePartitioner & RebalancePartitioner
Co-authored-by: Tartarus0zm <zhangmang1@163.com>
Co-authored-by: 1996fanrui <1996fanrui@gmail.com>
1 parent fd9a1ba commit 305dbcd

13 files changed

Lines changed: 531 additions & 11 deletions

File tree

docs/layouts/shortcodes/generated/all_taskmanager_network_section.html

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,18 @@
88
</tr>
99
</thead>
1010
<tbody>
11+
<tr>
12+
<td><h5>taskmanager.network.adaptive-partitioner.enabled</h5></td>
13+
<td style="word-wrap: break-word;">false</td>
14+
<td>Boolean</td>
15+
<td>Whether to enable adaptive partitioner feature for rescale and rebalance partitioners based on the load of the downstream tasks.</td>
16+
</tr>
17+
<tr>
18+
<td><h5>taskmanager.network.adaptive-partitioner.max-traverse-size</h5></td>
19+
<td style="word-wrap: break-word;">4</td>
20+
<td>Integer</td>
21+
<td>Maximum number of channels to traverse when looking for the most idle channel for rescale and rebalance partitioners when <code class="highlighter-rouge">taskmanager.network.adaptive-partitioner.enabled</code> is enabled.<br />Note, the value of the configuration option must be greater than `1`.</td>
22+
</tr>
1123
<tr>
1224
<td><h5>taskmanager.network.compression.codec</h5></td>
1325
<td style="word-wrap: break-word;">LZ4</td>

docs/layouts/shortcodes/generated/netty_shuffle_environment_configuration.html

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,18 @@
2626
<td>Boolean</td>
2727
<td>Enable SSL support for the taskmanager data transport. This is applicable only when the global flag for internal SSL (security.ssl.internal.enabled) is set to true</td>
2828
</tr>
29+
<tr>
30+
<td><h5>taskmanager.network.adaptive-partitioner.enabled</h5></td>
31+
<td style="word-wrap: break-word;">false</td>
32+
<td>Boolean</td>
33+
<td>Whether to enable adaptive partitioner feature for rescale and rebalance partitioners based on the load of the downstream tasks.</td>
34+
</tr>
35+
<tr>
36+
<td><h5>taskmanager.network.adaptive-partitioner.max-traverse-size</h5></td>
37+
<td style="word-wrap: break-word;">4</td>
38+
<td>Integer</td>
39+
<td>Maximum number of channels to traverse when looking for the most idle channel for rescale and rebalance partitioners when <code class="highlighter-rouge">taskmanager.network.adaptive-partitioner.enabled</code> is enabled.<br />Note, the value of the configuration option must be greater than `1`.</td>
40+
</tr>
2941
<tr>
3042
<td><h5>taskmanager.network.compression.codec</h5></td>
3143
<td style="word-wrap: break-word;">LZ4</td>

flink-core/src/main/java/org/apache/flink/configuration/NettyShuffleEnvironmentOptions.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,34 @@ public enum CompressionCodec {
325325
code(NETWORK_REQUEST_BACKOFF_MAX.key()))
326326
.build());
327327

328+
/** Whether to improve the rebalance and rescale partitioners to adaptive partition. */
329+
@Documentation.Section(Documentation.Sections.ALL_TASK_MANAGER_NETWORK)
330+
public static final ConfigOption<Boolean> ADAPTIVE_PARTITIONER_ENABLED =
331+
key("taskmanager.network.adaptive-partitioner.enabled")
332+
.booleanType()
333+
.defaultValue(false)
334+
.withDescription(
335+
"Whether to enable adaptive partitioner feature for rescale and rebalance partitioners based on the load of the downstream tasks.");
336+
337+
/**
338+
* Maximum number of channels to traverse when looking for the most idle channel for rescale and
339+
* rebalance partitioners when {@link #ADAPTIVE_PARTITIONER_ENABLED} is true.
340+
*/
341+
@Documentation.Section(Documentation.Sections.ALL_TASK_MANAGER_NETWORK)
342+
public static final ConfigOption<Integer> ADAPTIVE_PARTITIONER_MAX_TRAVERSE_SIZE =
343+
key("taskmanager.network.adaptive-partitioner.max-traverse-size")
344+
.intType()
345+
.defaultValue(4)
346+
.withDescription(
347+
Description.builder()
348+
.text(
349+
"Maximum number of channels to traverse when looking for the most idle channel for rescale and rebalance partitioners when %s is enabled.",
350+
code(ADAPTIVE_PARTITIONER_ENABLED.key()))
351+
.linebreak()
352+
.text(
353+
"Note, the value of the configuration option must be greater than `1`.")
354+
.build());
355+
328356
// ------------------------------------------------------------------------
329357

330358
/** Not intended to be instantiated. */
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.runtime.io.network.api.writer;
20+
21+
import org.apache.flink.annotation.Internal;
22+
import org.apache.flink.annotation.VisibleForTesting;
23+
import org.apache.flink.core.io.IOReadableWritable;
24+
25+
import java.io.IOException;
26+
import java.nio.ByteBuffer;
27+
28+
/**
29+
* A record writer based on load of downstream tasks for {@link
30+
* org.apache.flink.streaming.runtime.partitioner.RescalePartitioner} and {@link
31+
* org.apache.flink.streaming.runtime.partitioner.RebalancePartitioner}.
32+
*
33+
* <pre>
34+
*
35+
* Here are clarifications for some items to provide quick understanding.
36+
*
37+
* - Two new immutable attributes are introduced in this class:
38+
* -- `numberOfSubpartitions` represents the number of downstream partitions that can be written to.
39+
* -- `maxTraverseSize` represents the maximum number of partitions that the current partition selector can compare when performing rescale or rebalance.
40+
*
41+
* - Why do `maxTraverseSize` and `numberOfSubpartitions` not share a common attribute ?
42+
* If the same field were shared and `maxTraverseSize` were less than `numberOfSubpartitions` (e.g., 2 < 6), it would result in some downstream partitions (4 in this case) never being written to, which is incorrect behavior.
43+
*
44+
* - Why is it described that users cannot explicitly configure `maxTraverseSize` as 1 ?
45+
* Users should not explicitly set it to 1, as this would mean no load comparison is performed, effectively disabling the adaptive partitioning feature.
46+
*
47+
* - Why the internal value of `maxTraverseSize` may become 1:
48+
* This is reasonable if and only if the number of downstream partitions is exactly 1 (since no comparison is needed). This situation can arise from framework behaviors such as the {@link org.apache.flink.runtime.scheduler.adaptive.AdaptiveScheduler}, which are not directly controlled by users.
49+
* For example, when the following job enables the AdaptiveScheduler before rescaling:
50+
*
51+
* JobVertexA(parallelism=4, slotSharingGroup=SSG-A) --(rescale)--> JobVertexA(parallelism=5, slotSharingGroup=SSG-B)
52+
*
53+
* If the job scales down and only 2 slots are available, the parallelism configuration of the job changes to:
54+
*
55+
* JobVertexA(parallelism=1, slotSharingGroup=SSG-A) --(rescale)--> JobVertexA(parallelism=1, slotSharingGroup=SSG-B)
56+
*
57+
* In this case, the task of JobVertexA has only one writable downstream partition, so a `maxTraverseSize` of 1 is reasonable and meaningful.
58+
*
59+
* </pre>
60+
*
61+
* @param <T> The type of IOReadableWritable records.
62+
*/
63+
@Internal
64+
public final class AdaptiveLoadBasedRecordWriter<T extends IOReadableWritable>
65+
extends RecordWriter<T> {
66+
67+
private final int maxTraverseSize;
68+
private final int numberOfSubpartitions;
69+
private int currentChannel = -1;
70+
71+
AdaptiveLoadBasedRecordWriter(
72+
ResultPartitionWriter writer, long timeout, String taskName, int maxTraverseSize) {
73+
super(writer, timeout, taskName);
74+
this.numberOfSubpartitions = writer.getNumberOfSubpartitions();
75+
this.maxTraverseSize = Math.min(maxTraverseSize, numberOfSubpartitions);
76+
}
77+
78+
@Override
79+
public void emit(T record) throws IOException {
80+
checkErroneous();
81+
82+
currentChannel = getIdlestChannelIndex();
83+
84+
ByteBuffer byteBuffer = serializeRecord(serializer, record);
85+
targetPartition.emitRecord(byteBuffer, currentChannel);
86+
87+
if (flushAlways) {
88+
targetPartition.flush(currentChannel);
89+
}
90+
}
91+
92+
@VisibleForTesting
93+
int getIdlestChannelIndex() {
94+
int bestChannelBuffersCount = Integer.MAX_VALUE;
95+
long bestChannelBytesInQueue = Long.MAX_VALUE;
96+
int bestChannel = 0;
97+
for (int i = 1; i <= maxTraverseSize; i++) {
98+
int candidateChannel = (currentChannel + i) % numberOfSubpartitions;
99+
int candidateChannelBuffersCount =
100+
targetPartition.getBuffersCountUnsafe(candidateChannel);
101+
long candidateChannelBytesInQueue =
102+
targetPartition.getBytesInQueueUnsafe(candidateChannel);
103+
104+
if (candidateChannelBuffersCount == 0) {
105+
// If there isn't any pending data in the current channel, choose this channel
106+
// directly.
107+
return candidateChannel;
108+
}
109+
110+
if (candidateChannelBuffersCount < bestChannelBuffersCount
111+
|| (candidateChannelBuffersCount == bestChannelBuffersCount
112+
&& candidateChannelBytesInQueue < bestChannelBytesInQueue)) {
113+
bestChannel = candidateChannel;
114+
bestChannelBuffersCount = candidateChannelBuffersCount;
115+
bestChannelBytesInQueue = candidateChannelBytesInQueue;
116+
}
117+
}
118+
return bestChannel;
119+
}
120+
121+
/** Copy from {@link ChannelSelectorRecordWriter#broadcastEmit}. */
122+
@Override
123+
public void broadcastEmit(T record) throws IOException {
124+
checkErroneous();
125+
126+
// Emitting to all channels in a for loop can be better than calling
127+
// ResultPartitionWriter#broadcastRecord because the broadcastRecord
128+
// method incurs extra overhead.
129+
ByteBuffer serializedRecord = serializeRecord(serializer, record);
130+
for (int channelIndex = 0; channelIndex < numberOfSubpartitions; channelIndex++) {
131+
serializedRecord.rewind();
132+
emit(record, channelIndex);
133+
}
134+
135+
if (flushAlways) {
136+
flushAll();
137+
}
138+
}
139+
}

flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterBuilder.java

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
package org.apache.flink.runtime.io.network.api.writer;
2020

21+
import org.apache.flink.configuration.NettyShuffleEnvironmentOptions;
2122
import org.apache.flink.core.io.IOReadableWritable;
2223

2324
/** Utility class to encapsulate the logic of building a {@link RecordWriter} instance. */
@@ -29,6 +30,11 @@ public class RecordWriterBuilder<T extends IOReadableWritable> {
2930

3031
private String taskName = "test";
3132

33+
private boolean enabledAdaptivePartitioner = false;
34+
35+
private int maxTraverseSize =
36+
NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_MAX_TRAVERSE_SIZE.defaultValue();
37+
3238
public RecordWriterBuilder<T> setChannelSelector(ChannelSelector<T> selector) {
3339
this.selector = selector;
3440
return this;
@@ -44,11 +50,24 @@ public RecordWriterBuilder<T> setTaskName(String taskName) {
4450
return this;
4551
}
4652

53+
public RecordWriterBuilder<T> setEnabledAdaptivePartitioner(
54+
boolean enabledAdaptivePartitioner) {
55+
this.enabledAdaptivePartitioner = enabledAdaptivePartitioner;
56+
return this;
57+
}
58+
59+
public RecordWriterBuilder<T> setMaxTraverseSize(int maxTraverseSize) {
60+
this.maxTraverseSize = maxTraverseSize;
61+
return this;
62+
}
63+
4764
public RecordWriter<T> build(ResultPartitionWriter writer) {
4865
if (selector.isBroadcast()) {
4966
return new BroadcastRecordWriter<>(writer, timeout, taskName);
50-
} else {
51-
return new ChannelSelectorRecordWriter<>(writer, selector, timeout, taskName);
5267
}
68+
if (enabledAdaptivePartitioner) {
69+
return new AdaptiveLoadBasedRecordWriter<>(writer, timeout, taskName, maxTraverseSize);
70+
}
71+
return new ChannelSelectorRecordWriter<>(writer, selector, timeout, taskName);
5372
}
5473
}

flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,14 @@ public interface ResultPartitionWriter extends AutoCloseable, AvailabilityProvid
6060
/** Writes the given serialized record to the target subpartition. */
6161
void emitRecord(ByteBuffer record, int targetSubpartition) throws IOException;
6262

63+
default long getBytesInQueueUnsafe(int targetSubpartition) {
64+
return 0;
65+
}
66+
67+
default int getBuffersCountUnsafe(int targetSubpartition) {
68+
return 0;
69+
}
70+
6371
/**
6472
* Writes the given serialized record to all subpartitions. One can also achieve the same effect
6573
* by emitting the same record to all subpartitions one by one, however, this method can have

flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferPool.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,4 +75,9 @@ public interface BufferPool extends BufferProvider, BufferRecycler {
7575

7676
/** Returns the number of used buffers of this buffer pool. */
7777
int bestEffortGetNumOfUsedBuffers();
78+
79+
/** Returns the requested buffer count for target channel. */
80+
default int getBuffersCountUnsafe(int targetChannel) {
81+
return 0;
82+
}
7883
}

flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/LocalBufferPool.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -824,4 +824,9 @@ public static AvailabilityStatus from(
824824
}
825825
}
826826
}
827+
828+
@Override
829+
public int getBuffersCountUnsafe(int targetChannel) {
830+
return subpartitionBuffersCount[targetChannel];
831+
}
827832
}

flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BufferWritingResultPartition.java

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ public abstract class BufferWritingResultPartition extends ResultPartition {
6565

6666
private TimerGauge hardBackPressuredTimeMsPerSecond = new TimerGauge();
6767

68-
private long totalWrittenBytes;
68+
private final long[] writtenBytesPerSubpartition;
6969

7070
public BufferWritingResultPartition(
7171
String owningTaskName,
@@ -91,6 +91,7 @@ public BufferWritingResultPartition(
9191

9292
this.subpartitions = checkNotNull(subpartitions);
9393
this.unicastBufferBuilders = new BufferBuilder[subpartitions.length];
94+
this.writtenBytesPerSubpartition = new long[subpartitions.length];
9495
}
9596

9697
@Override
@@ -114,6 +115,11 @@ public int getNumberOfQueuedBuffers() {
114115

115116
@Override
116117
public long getSizeOfQueuedBuffersUnsafe() {
118+
long totalWrittenBytes = 0;
119+
for (int i = 0; i < subpartitions.length; i++) {
120+
totalWrittenBytes += writtenBytesPerSubpartition[i];
121+
}
122+
117123
long totalNumberOfBytes = 0;
118124

119125
for (ResultSubpartition subpartition : subpartitions) {
@@ -123,6 +129,12 @@ public long getSizeOfQueuedBuffersUnsafe() {
123129
return totalWrittenBytes - totalNumberOfBytes;
124130
}
125131

132+
@Override
133+
public long getBytesInQueueUnsafe(int targetSubpartition) {
134+
return writtenBytesPerSubpartition[targetSubpartition]
135+
- subpartitions[targetSubpartition].getTotalNumberOfBytesUnsafe();
136+
}
137+
126138
@Override
127139
public int getNumberOfQueuedBuffers(int targetSubpartition) {
128140
checkArgument(targetSubpartition >= 0 && targetSubpartition < numSubpartitions);
@@ -151,7 +163,7 @@ protected void flushAllSubpartitions(boolean finishProducers) {
151163

152164
@Override
153165
public void emitRecord(ByteBuffer record, int targetSubpartition) throws IOException {
154-
totalWrittenBytes += record.remaining();
166+
writtenBytesPerSubpartition[targetSubpartition] += record.remaining();
155167

156168
BufferBuilder buffer = appendUnicastDataForNewRecord(record, targetSubpartition);
157169

@@ -171,7 +183,9 @@ public void emitRecord(ByteBuffer record, int targetSubpartition) throws IOExcep
171183

172184
@Override
173185
public void broadcastRecord(ByteBuffer record) throws IOException {
174-
totalWrittenBytes += ((long) record.remaining() * numSubpartitions);
186+
for (int i = 0; i < subpartitions.length; i++) {
187+
writtenBytesPerSubpartition[i] += record.remaining();
188+
}
175189

176190
BufferBuilder buffer = appendBroadcastDataForNewRecord(record);
177191

@@ -197,11 +211,11 @@ public void broadcastEvent(AbstractEvent event, boolean isPriorityEvent) throws
197211

198212
try (BufferConsumer eventBufferConsumer =
199213
EventSerializer.toBufferConsumer(event, isPriorityEvent)) {
200-
totalWrittenBytes += ((long) eventBufferConsumer.getWrittenBytes() * numSubpartitions);
201-
for (ResultSubpartition subpartition : subpartitions) {
214+
for (int i = 0; i < subpartitions.length; i++) {
202215
// Retain the buffer so that it can be recycled by each subpartition of
203216
// targetPartition
204-
subpartition.add(eventBufferConsumer.copy(), 0);
217+
subpartitions[i].add(eventBufferConsumer.copy(), 0);
218+
writtenBytesPerSubpartition[i] += eventBufferConsumer.getWrittenBytes();
205219
}
206220
}
207221
}
@@ -246,8 +260,8 @@ public void finish() throws IOException {
246260
finishBroadcastBufferBuilder();
247261
finishUnicastBufferBuilders();
248262

249-
for (ResultSubpartition subpartition : subpartitions) {
250-
totalWrittenBytes += subpartition.finish();
263+
for (int i = 0; i < subpartitions.length; i++) {
264+
writtenBytesPerSubpartition[i] += subpartitions[i].finish();
251265
}
252266

253267
super.finish();
@@ -340,7 +354,7 @@ private void addToSubpartition(
340354
protected int addToSubpartition(
341355
int targetSubpartition, BufferConsumer bufferConsumer, int partialRecordLength)
342356
throws IOException {
343-
totalWrittenBytes += bufferConsumer.getWrittenBytes();
357+
writtenBytesPerSubpartition[targetSubpartition] += bufferConsumer.getWrittenBytes();
344358
return subpartitions[targetSubpartition].add(bufferConsumer, partialRecordLength);
345359
}
346360

0 commit comments

Comments
 (0)