Skip to content
This repository was archived by the owner on Jul 1, 2025. It is now read-only.
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions runtime/common/src/main/proto/ControlMessage.proto
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ enum MessageType {
PipeInit = 13;
RequestPipeLoc = 14;
PipeLocInfo = 15;
ParentTaskDataCollected = 16;
CurrentlyProcessedBytesCollected = 17;
}

message Message {
Expand All @@ -107,6 +109,8 @@ message Message {
optional PipeInitMessage pipeInitMsg = 16;
optional RequestPipeLocationMessage requestPipeLocMsg = 17;
optional PipeLocationInfoMessage pipeLocInfoMsg = 18;
optional ParentTaskDataCollectMsg ParentTaskDataCollected = 19;
optional CurrentlyProcessedBytesCollectMsg currentlyProcessedBytesCollected = 20;
}

// Messages from Master to Executors
Expand Down Expand Up @@ -256,3 +260,13 @@ message PipeLocationInfoMessage {
required int64 requestId = 1; // To find the matching request msg
required string executorId = 2;
}

// Executor -> Master: reports the per-partition output sizes of a completed
// parent task. partitionSizeMap is a Java-serialized Map<Integer, Long>
// (hashed key -> partition size in bytes); used for work-stealing skew detection.
message ParentTaskDataCollectMsg {
required string taskId = 1;
required bytes partitionSizeMap = 2;
}

// Executor -> Master: reports how many bytes a still-running task has
// processed so far, in reply to a scheduler request (work-stealing support).
message CurrentlyProcessedBytesCollectMsg {
required string taskId = 1;
required int64 processedDataBytes = 2;
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ public final class BlockOutputWriter implements OutputWriter {

private long writtenBytes;

private Optional<Map<Integer, Long>> partitionSizeMap;

/**
* Constructor.
*
Expand Down Expand Up @@ -109,7 +111,7 @@ public void close() {
final DataPersistenceProperty.Value persistence = (DataPersistenceProperty.Value) runtimeEdge
.getPropertyValue(DataPersistenceProperty.class).orElseThrow(IllegalStateException::new);

final Optional<Map<Integer, Long>> partitionSizeMap = blockToWrite.commit();
partitionSizeMap = blockToWrite.commit();
// Return the total size of the committed block.
if (partitionSizeMap.isPresent()) {
long blockSizeTotal = 0;
Expand All @@ -123,6 +125,16 @@ public void close() {
blockManagerWorker.writeBlock(blockToWrite, blockStoreValue, getExpectedRead(), persistence);
}

@Override
public Optional<Map<Integer, Long>> getPartitionSizeMap() {
  // The field is only assigned in close() (from blockToWrite.commit()); guard
  // against calls before close() so we return empty instead of throwing NPE.
  // The original isPresent()/else branches both returned the same value, so
  // the conditional collapses to a single null-safe return.
  return partitionSizeMap == null ? Optional.empty() : partitionSizeMap;
}

@Override
public Optional<Long> getWrittenBytes() {
if (writtenBytes == -1) {
return Optional.empty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import org.apache.nemo.common.punctuation.Watermark;

import java.util.Map;
import java.util.Optional;

/**
Expand All @@ -45,5 +46,10 @@ public interface OutputWriter {
*/
Optional<Long> getWrittenBytes();

/**
 * Returns the sizes of the partitions written by this writer, if tracked.
 *
 * @return the map of hashed key to partition size in bytes, or
 *         {@link Optional#empty()} if this writer does not track partition sizes.
 */
Optional<Map<Integer, Long>> getPartitionSizeMap();

void close();
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
Expand Down Expand Up @@ -113,6 +114,11 @@ public Optional<Long> getWrittenBytes() {
return Optional.empty();
}

@Override
public Optional<Map<Integer, Long>> getPartitionSizeMap() {
  // This writer does not track per-partition sizes, so the result is always absent.
  return Optional.empty();
}

@Override
public void close() {
if (!initialized) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import org.apache.nemo.common.ir.OutputCollector;
import org.apache.nemo.common.ir.vertex.IRVertex;
import org.apache.nemo.runtime.executor.MetricMessageSender;

import java.io.IOException;

Expand Down Expand Up @@ -49,6 +50,21 @@ abstract class DataFetcher implements AutoCloseable {
*/
abstract Object fetchDataElement() throws IOException;

/**
 * Identical to fetchDataElement(), except that it also reports the
 * intermediate serializedReadBytes to the MetricStore as the fetch progresses.
 * This method exists for the WorkStealing implementation in Nemo.
 *
 * @param taskId id of the task this fetcher belongs to
 * @param metricMessageSender sender used to publish the intermediate metrics
 *
 * @return the fetched data element
 * @throws IOException upon I/O error
 * @throws java.util.NoSuchElementException if no more element is available
 */
abstract Object fetchDataElementWithTrace(String taskId,
                                          MetricMessageSender metricMessageSender) throws IOException;

OutputCollector getOutputCollector() {
return outputCollector;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.nemo.common.ir.vertex.IRVertex;
import org.apache.nemo.common.punctuation.Finishmark;
import org.apache.nemo.common.punctuation.Watermark;
import org.apache.nemo.runtime.executor.MetricMessageSender;
import org.apache.nemo.runtime.executor.data.DataUtil;
import org.apache.nemo.runtime.executor.datatransfer.*;
import org.slf4j.Logger;
Expand Down Expand Up @@ -100,6 +101,12 @@ Object fetchDataElement() throws IOException {
}
}

/**
 * {@inheritDoc}
 *
 * This fetcher does not report intermediate read-byte metrics, so the traced
 * variant simply delegates to the plain {@code fetchDataElement()}.
 */
@Override
Object fetchDataElementWithTrace(final String taskId,
                                 final MetricMessageSender metricMessageSender) throws IOException {
  return fetchDataElement();
}

private void fetchDataLazily() {
final List<CompletableFuture<DataUtil.IteratorWithNumBytes>> futures = readersForParentTask.read();
numOfIterators = futures.size();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
*/
package org.apache.nemo.runtime.executor.task;

import org.apache.commons.lang3.SerializationUtils;
import org.apache.nemo.common.ir.OutputCollector;
import org.apache.nemo.common.ir.edge.executionproperty.BlockFetchFailureProperty;
import org.apache.nemo.common.ir.vertex.IRVertex;
import org.apache.nemo.common.punctuation.Finishmark;
import org.apache.nemo.runtime.executor.MetricMessageSender;
import org.apache.nemo.runtime.executor.data.DataUtil;
import org.apache.nemo.runtime.executor.datatransfer.InputReader;
import org.slf4j.Logger;
Expand Down Expand Up @@ -100,6 +102,49 @@ Object fetchDataElement() throws IOException {
return Finishmark.getInstance();
}

/**
 * Same fetch loop as {@code fetchDataElement()}, but publishes the cumulative
 * serialized read bytes ({@code serBytes}) to the MetricStore every time the
 * fetcher advances to the next iterator. Used by the WorkStealing implementation.
 *
 * @param taskId id of the task, used as the metric key
 * @param metricMessageSender sender used to publish the metric
 * @return the next data element, or a {@code Finishmark} once all iterators
 *         are exhausted
 * @throws IOException if fetching fails for any reason (see catch block below)
 */
@Override
Object fetchDataElementWithTrace(final String taskId,
                                 final MetricMessageSender metricMessageSender) throws IOException {
  try {
    // Lazily start the fetch and position on the first iterator exactly once.
    if (firstFetch) {
      fetchDataLazily();
      advanceIterator();
      firstFetch = false;
    }

    while (true) {
      // This iterator has the element
      if (this.currentIterator.hasNext()) {
        return this.currentIterator.next();
      }

      // This iterator does not have the element
      if (currentIteratorIndex < expectedNumOfIterators) {
        // Next iterator has the element
        countBytes(currentIterator);
        // Send the cumulative serBytes to MetricStore
        metricMessageSender.send("TaskMetric", taskId, "serializedReadBytes",
          SerializationUtils.serialize(serBytes));
        advanceIterator();
        continue;
      } else {
        // We've consumed all the iterators
        break;
      }

    }
  } catch (final Throwable e) {
    // Any failure is caught and thrown as an IOException, so that the task is retried.
    // In particular, we catch unchecked exceptions like RuntimeException thrown by DataUtil.IteratorWithNumBytes
    // when remote data fetching fails for whatever reason.
    // Note that we rely on unchecked exceptions because the Iterator interface does not provide the standard
    // "throw Exception" that the TaskExecutor thread can catch and handle.
    throw new IOException(e);
  }

  // All iterators consumed: signal completion to the caller.
  return Finishmark.getInstance();
}

private void advanceIterator() throws IOException {
// Take from iteratorQueue
final Object iteratorOrThrowable;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.nemo.common.ir.vertex.SourceVertex;
import org.apache.nemo.common.punctuation.Finishmark;
import org.apache.nemo.common.punctuation.Watermark;
import org.apache.nemo.runtime.executor.MetricMessageSender;

import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
Expand Down Expand Up @@ -74,6 +75,11 @@ Object fetchDataElement() {
}
}

/**
 * {@inheritDoc}
 *
 * Source vertices read from the source directly and have no serialized
 * read-byte metric to trace, so this delegates to {@code fetchDataElement()}.
 */
@Override
Object fetchDataElementWithTrace(final String taskId, final MetricMessageSender metricMessageSender) {
  return fetchDataElement();
}

final long getBoundedSourceReadTime() {
return boundedSourceReadTime;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package org.apache.nemo.runtime.executor.task;

import com.google.common.collect.Lists;
import com.google.protobuf.ByteString;
import org.apache.commons.lang3.SerializationUtils;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.nemo.common.Pair;
Expand Down Expand Up @@ -458,7 +459,7 @@ private boolean handleDataFetchers(final List<DataFetcher> fetchers) {
while (availableIterator.hasNext()) {
final DataFetcher dataFetcher = availableIterator.next();
try {
final Object element = dataFetcher.fetchDataElement();
final Object element = dataFetcher.fetchDataElementWithTrace(taskId, metricMessageSender);
onEventFromDataFetcher(element, dataFetcher);
if (element instanceof Finishmark) {
availableIterator.remove();
Expand Down Expand Up @@ -688,12 +689,21 @@ public void setIRVertexPutOnHold(final IRVertex irVertex) {
*/
private void finalizeOutputWriters(final VertexHarness vertexHarness) {
final List<Long> writtenBytesList = new ArrayList<>();
final HashMap<Integer, Long> partitionSizeMap = new HashMap<>();

// finalize OutputWriters for main children
vertexHarness.getWritersToMainChildrenTasks().forEach(outputWriter -> {
  outputWriter.close();
  final Optional<Long> writtenBytes = outputWriter.getWrittenBytes();
  writtenBytes.ifPresent(writtenBytesList::add);

  // Accumulate per-partition sizes so they can be sent to the Scheduler
  // (work-stealing skew detection). The original `if (true)` wrapper was a
  // dead conditional and has been removed.
  final Optional<Map<Integer, Long>> partitionSizes = outputWriter.getPartitionSizeMap();
  if (partitionSizes.isPresent()) {
    computePartitionSizeMap(partitionSizeMap, partitionSizes.get());
  }
});

// finalize OutputWriters for additional tagged children
Expand All @@ -702,6 +712,14 @@ private void finalizeOutputWriters(final VertexHarness vertexHarness) {
outputWriter.close();
final Optional<Long> writtenBytes = outputWriter.getWrittenBytes();
writtenBytes.ifPresent(writtenBytesList::add);

// Accumulate per-partition sizes so they can be sent to the Scheduler
// (work-stealing skew detection). The original `if (true)` wrapper was a
// dead conditional and has been removed.
final Optional<Map<Integer, Long>> partitionSizes = outputWriter.getPartitionSizeMap();
if (partitionSizes.isPresent()) {
  computePartitionSizeMap(partitionSizeMap, partitionSizes.get());
}
})
);

Expand All @@ -713,5 +731,57 @@ private void finalizeOutputWriters(final VertexHarness vertexHarness) {
// TODO #236: Decouple metric collection and sending logic
metricMessageSender.send(TASK_METRIC_ID, taskId, "taskOutputBytes",
SerializationUtils.serialize(totalWrittenBytes));

if (!partitionSizeMap.isEmpty()) {
persistentConnectionToMasterMap.getMessageSender(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID).send(
ControlMessage.Message.newBuilder()
.setId(RuntimeIdManager.generateMessageId())
.setListenerId(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID)
.setType(ControlMessage.MessageType.ParentTaskDataCollected)
.setParentTaskDataCollected(ControlMessage.ParentTaskDataCollectMsg.newBuilder()
.setTaskId(taskId)
.setPartitionSizeMap(ByteString.copyFrom(SerializationUtils.serialize(partitionSizeMap)))
.build())
.build());
}
}

// Methods for work stealing
/**
 * Gather the KV statistics of processed data when execution is completed.
 * This method is for work stealing implementation: the accumulated statistics will be used to
 * detect skewed tasks of the child stage.
 *
 * @param totalPartitionSizeMap accumulated partitionSizeMap of task (mutated in place).
 * @param singlePartitionSizeMap partitionSizeMap gained from single OutputWriter.
 */
private void computePartitionSizeMap(final Map<Integer, Long> totalPartitionSizeMap,
                                     final Map<Integer, Long> singlePartitionSizeMap) {
  // Map.merge covers both the "first occurrence" and "already present" cases,
  // replacing the hand-rolled containsKey/compute/put branching.
  singlePartitionSizeMap.forEach((hashedKey, partitionSize) ->
      totalPartitionSizeMap.merge(hashedKey, partitionSize, Long::sum));
}

/**
 * Send the temporarily processed bytes of the current task on request from the scheduler.
 * This method is for work stealing implementation.
 */
public void onRequestForProcessedData() {
  // Replying to a scheduler request is expected, routine behavior — log at
  // INFO rather than ERROR (the original level was misleading).
  LOG.info("{}, bytes {}, replying for the request", taskId, serializedReadBytes);
  persistentConnectionToMasterMap.getMessageSender(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID).send(
    ControlMessage.Message.newBuilder()
      .setId(RuntimeIdManager.generateMessageId())
      .setListenerId(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID)
      .setType(ControlMessage.MessageType.CurrentlyProcessedBytesCollected)
      .setCurrentlyProcessedBytesCollected(ControlMessage.CurrentlyProcessedBytesCollectMsg.newBuilder()
        .setTaskId(this.taskId)
        .setProcessedDataBytes(serializedReadBytes)
        .build())
      .build());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,7 @@
import javax.inject.Inject;
import java.io.Serializable;
import java.nio.file.Paths;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;

Expand Down Expand Up @@ -481,6 +478,22 @@ private void handleControlMessage(final ControlMessage.Message message) {
.setDataCollected(ControlMessage.DataCollectMessage.newBuilder().setData(serializedData).build())
.build());
break;
case ParentTaskDataCollected:
if (scheduler instanceof BatchScheduler) {
final ControlMessage.ParentTaskDataCollectMsg workStealingMsg = message.getParentTaskDataCollected();
final String taskId = workStealingMsg.getTaskId();
final Map<Integer, Long> partitionSizeMap = SerializationUtils
.deserialize(workStealingMsg.getPartitionSizeMap().toByteArray());
((BatchScheduler) scheduler).aggregateStageIdToPartitionSizeMap(taskId, partitionSizeMap);
}
break;
case CurrentlyProcessedBytesCollected:
  if (scheduler instanceof BatchScheduler) {
    ((BatchScheduler) scheduler).aggregateTaskIdToProcessedBytes(
      message.getCurrentlyProcessedBytesCollected().getTaskId(),
      message.getCurrentlyProcessedBytesCollected().getProcessedDataBytes()
    );
  }
  // BUG FIX: a missing `break` here fell through into `case MetricFlushed`,
  // spuriously counting down metricCountDownLatch on every processed-bytes report.
  break;
case MetricFlushed:
metricCountDownLatch.countDown();
break;
Expand Down
Loading