diff --git a/runtime/common/src/main/proto/ControlMessage.proto b/runtime/common/src/main/proto/ControlMessage.proto
index 97e30fb4e7..53adea3d3a 100644
--- a/runtime/common/src/main/proto/ControlMessage.proto
+++ b/runtime/common/src/main/proto/ControlMessage.proto
@@ -86,6 +86,8 @@ enum MessageType {
   PipeInit = 13;
   RequestPipeLoc = 14;
   PipeLocInfo = 15;
+  ParentTaskDataCollected = 16;
+  CurrentlyProcessedBytesCollected = 17;
 }
 
 message Message {
@@ -107,6 +109,8 @@ message Message {
   optional PipeInitMessage pipeInitMsg = 16;
   optional RequestPipeLocationMessage requestPipeLocMsg = 17;
   optional PipeLocationInfoMessage pipeLocInfoMsg = 18;
+  optional ParentTaskDataCollectMsg parentTaskDataCollected = 19;
+  optional CurrentlyProcessedBytesCollectMsg currentlyProcessedBytesCollected = 20;
 }
 
 // Messages from Master to Executors
@@ -256,3 +260,13 @@ message PipeLocationInfoMessage {
   required int64 requestId = 1; // To find the matching request msg
   required string executorId = 2;
 }
+
+message ParentTaskDataCollectMsg {
+  required string taskId = 1;
+  required bytes partitionSizeMap = 2;
+}
+
+message CurrentlyProcessedBytesCollectMsg {
+  required string taskId = 1;
+  required int64 processedDataBytes = 2;
+}
diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/BlockOutputWriter.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/BlockOutputWriter.java
index 97cde037c6..b9cf9ef129 100644
--- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/BlockOutputWriter.java
+++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/BlockOutputWriter.java
@@ -52,6 +52,8 @@ public final class BlockOutputWriter implements OutputWriter {
 
   private long writtenBytes;
 
+  private Optional<Map<Integer, Long>> partitionSizeMap = Optional.empty();
+
   /**
    * Constructor.
    *
@@ -109,7 +111,7 @@ public void close() {
     final DataPersistenceProperty.Value persistence = (DataPersistenceProperty.Value) runtimeEdge
       .getPropertyValue(DataPersistenceProperty.class).orElseThrow(IllegalStateException::new);
 
-    final Optional<Map<Integer, Long>> partitionSizeMap = blockToWrite.commit();
+    partitionSizeMap = blockToWrite.commit();
     // Return the total size of the committed block.
     if (partitionSizeMap.isPresent()) {
       long blockSizeTotal = 0;
@@ -123,6 +125,16 @@
     blockManagerWorker.writeBlock(blockToWrite, blockStoreValue, getExpectedRead(), persistence);
   }
+  @Override
+  public Optional<Map<Integer, Long>> getPartitionSizeMap() {
+    if (partitionSizeMap.isPresent()) {
+      return partitionSizeMap;
+    } else {
+      return Optional.empty();
+    }
+  }
+
+
   @Override
   public Optional<Long> getWrittenBytes() {
     if (writtenBytes == -1) {
       return Optional.empty();
diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/OutputWriter.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/OutputWriter.java
index bf6ff84e69..a1862f5f2d 100644
--- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/OutputWriter.java
+++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/OutputWriter.java
@@ -20,6 +20,7 @@
 
 import org.apache.nemo.common.punctuation.Watermark;
 
+import java.util.Map;
 import java.util.Optional;
 
 /**
@@ -45,5 +46,10 @@ public interface OutputWriter {
    */
   Optional<Long> getWrittenBytes();
 
+  /**
+   * @return the map of hashed key to partition size.
+   */
+  Optional<Map<Integer, Long>> getPartitionSizeMap();
+
   void close();
 }
diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/PipeOutputWriter.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/PipeOutputWriter.java
index 544d64d921..d0025428aa 100644
--- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/PipeOutputWriter.java
+++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/PipeOutputWriter.java
@@ -34,6 +34,7 @@
 import java.io.IOException;
 import java.util.Collections;
 import java.util.List;
+import java.util.Map;
 import java.util.Optional;
 
 /**
@@ -113,6 +114,11 @@ public Optional<Long> getWrittenBytes() {
     return Optional.empty();
   }
 
+  @Override
+  public Optional<Map<Integer, Long>> getPartitionSizeMap() {
+    return Optional.empty();
+  }
+
   @Override
   public void close() {
     if (!initialized) {
diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/DataFetcher.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/DataFetcher.java
index 7af08852eb..b1a828c13c 100644
--- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/DataFetcher.java
+++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/DataFetcher.java
@@ -20,6 +20,7 @@
 
 import org.apache.nemo.common.ir.OutputCollector;
 import org.apache.nemo.common.ir.vertex.IRVertex;
+import org.apache.nemo.runtime.executor.MetricMessageSender;
 
 import java.io.IOException;
 
@@ -49,6 +50,21 @@ abstract class DataFetcher implements AutoCloseable {
    */
   abstract Object fetchDataElement() throws IOException;
 
+  /**
+   * Identical with fetchDataElement(), except it sends intermediate serializedReadBytes to MetricStore
+   * on every iterator advance.
+   * This method is for WorkStealing implementation in Nemo.
+   *
+   * @param taskId              task id
+   * @param metricMessageSender metricMessageSender
+   *
+   * @return data element
+   * @throws IOException upon I/O error
+   * @throws java.util.NoSuchElementException if no more element is available
+   */
+  abstract Object fetchDataElementWithTrace(String taskId,
+                                            MetricMessageSender metricMessageSender) throws IOException;
+
   OutputCollector getOutputCollector() {
     return outputCollector;
   }
diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/MultiThreadParentTaskDataFetcher.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/MultiThreadParentTaskDataFetcher.java
index 797818ce44..d7947e8c78 100644
--- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/MultiThreadParentTaskDataFetcher.java
+++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/MultiThreadParentTaskDataFetcher.java
@@ -22,6 +22,7 @@
 import org.apache.nemo.common.ir.vertex.IRVertex;
 import org.apache.nemo.common.punctuation.Finishmark;
 import org.apache.nemo.common.punctuation.Watermark;
+import org.apache.nemo.runtime.executor.MetricMessageSender;
 import org.apache.nemo.runtime.executor.data.DataUtil;
 import org.apache.nemo.runtime.executor.datatransfer.*;
 import org.slf4j.Logger;
@@ -100,6 +101,12 @@ Object fetchDataElement() throws IOException {
     }
   }
 
+  @Override
+  Object fetchDataElementWithTrace(final String taskId,
+                                   final MetricMessageSender metricMessageSender) throws IOException {
+    return fetchDataElement();
+  }
+
   private void fetchDataLazily() {
     final List<CompletableFuture<DataUtil.IteratorWithNumBytes>> futures = readersForParentTask.read();
     numOfIterators = futures.size();
diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcher.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcher.java
index a8ae4a9306..3a92cbc8a9 100644
--- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcher.java
+++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcher.java
@@ -18,10 +18,12 @@
  */
 package org.apache.nemo.runtime.executor.task;
 
+import org.apache.commons.lang3.SerializationUtils;
 import org.apache.nemo.common.ir.OutputCollector;
 import org.apache.nemo.common.ir.edge.executionproperty.BlockFetchFailureProperty;
 import org.apache.nemo.common.ir.vertex.IRVertex;
 import org.apache.nemo.common.punctuation.Finishmark;
+import org.apache.nemo.runtime.executor.MetricMessageSender;
 import org.apache.nemo.runtime.executor.data.DataUtil;
 import org.apache.nemo.runtime.executor.datatransfer.InputReader;
 import org.slf4j.Logger;
@@ -100,6 +102,49 @@ Object fetchDataElement() throws IOException {
     return Finishmark.getInstance();
   }
 
+  @Override
+  Object fetchDataElementWithTrace(final String taskId,
+                                   final MetricMessageSender metricMessageSender) throws IOException {
+    try {
+      if (firstFetch) {
+        fetchDataLazily();
+        advanceIterator();
+        firstFetch = false;
+      }
+
+      while (true) {
+        // This iterator has the element
+        if (this.currentIterator.hasNext()) {
+          return this.currentIterator.next();
+        }
+
+        // This iterator does not have the element
+        if (currentIteratorIndex < expectedNumOfIterators) {
+          // Next iterator has the element
+          countBytes(currentIterator);
+          // Send the cumulative serBytes to MetricStore
+          metricMessageSender.send("TaskMetric", taskId, "serializedReadBytes",
+            SerializationUtils.serialize(serBytes));
+          advanceIterator();
+          continue;
+        } else {
+          // We've consumed all the iterators
+          break;
+        }
+
+      }
+    } catch (final Throwable e) {
+      // Any failure is caught and thrown as an IOException, so that the task is retried.
+      // In particular, we catch unchecked exceptions like RuntimeException thrown by DataUtil.IteratorWithNumBytes
+      // when remote data fetching fails for whatever reason.
+      // Note that we rely on unchecked exceptions because the Iterator interface does not provide the standard
+      // "throw Exception" that the TaskExecutor thread can catch and handle.
+      throw new IOException(e);
+    }
+
+    return Finishmark.getInstance();
+  }
+
   private void advanceIterator() throws IOException {
     // Take from iteratorQueue
     final Object iteratorOrThrowable;
diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/SourceVertexDataFetcher.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/SourceVertexDataFetcher.java
index 2d82898d7a..68a3362d27 100644
--- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/SourceVertexDataFetcher.java
+++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/SourceVertexDataFetcher.java
@@ -23,6 +23,7 @@
 import org.apache.nemo.common.ir.vertex.SourceVertex;
 import org.apache.nemo.common.punctuation.Finishmark;
 import org.apache.nemo.common.punctuation.Watermark;
+import org.apache.nemo.runtime.executor.MetricMessageSender;
 
 import java.util.concurrent.Executors;
 import java.util.concurrent.ScheduledExecutorService;
@@ -74,6 +75,11 @@ Object fetchDataElement() {
     }
   }
 
+  @Override
+  Object fetchDataElementWithTrace(final String taskId, final MetricMessageSender metricMessageSender) {
+    return fetchDataElement();
+  }
+
   final long getBoundedSourceReadTime() {
     return boundedSourceReadTime;
   }
diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java
index 2bf574d396..91e8212640 100644
--- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java
+++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java
@@ -19,6 +19,7 @@
 package org.apache.nemo.runtime.executor.task;
 
 import com.google.common.collect.Lists;
+import com.google.protobuf.ByteString;
 import org.apache.commons.lang3.SerializationUtils;
 import org.apache.commons.lang3.exception.ExceptionUtils;
 import org.apache.nemo.common.Pair;
@@ -458,7 +459,7 @@ private boolean handleDataFetchers(final List<DataFetcher> fetchers) {
       while (availableIterator.hasNext()) {
         final DataFetcher dataFetcher = availableIterator.next();
         try {
-          final Object element = dataFetcher.fetchDataElement();
+          final Object element = dataFetcher.fetchDataElementWithTrace(taskId, metricMessageSender);
           onEventFromDataFetcher(element, dataFetcher);
           if (element instanceof Finishmark) {
             availableIterator.remove();
@@ -688,12 +689,21 @@ public void setIRVertexPutOnHold(final IRVertex irVertex) {
    */
   private void finalizeOutputWriters(final VertexHarness vertexHarness) {
     final List<Long> writtenBytesList = new ArrayList<>();
+    final HashMap<Integer, Long> partitionSizeMap = new HashMap<>();
 
     // finalize OutputWriters for main children
     vertexHarness.getWritersToMainChildrenTasks().forEach(outputWriter -> {
       outputWriter.close();
       final Optional<Long> writtenBytes = outputWriter.getWrittenBytes();
       writtenBytes.ifPresent(writtenBytesList::add);
+
+      // Send partitionSizeMap to Scheduler
+      if (true) {
+        final Optional<Map<Integer, Long>> partitionSizes = outputWriter.getPartitionSizeMap();
+        if (partitionSizes.isPresent()) {
+          computePartitionSizeMap(partitionSizeMap, partitionSizes.get());
+        }
+      }
     });
 
     // finalize OutputWriters for additional tagged children
@@ -702,6 +712,14 @@ private void finalizeOutputWriters(final VertexHarness vertexHarness) {
         outputWriter.close();
         final Optional<Long> writtenBytes = outputWriter.getWrittenBytes();
         writtenBytes.ifPresent(writtenBytesList::add);
+
+        // Send partitionSizeMap to Scheduler
+        if (true) {
+          final Optional<Map<Integer, Long>> partitionSizes = outputWriter.getPartitionSizeMap();
+          if (partitionSizes.isPresent()) {
+            computePartitionSizeMap(partitionSizeMap, partitionSizes.get());
+          }
+        }
       })
     );
 
@@ -713,5 +731,57 @@
     // TODO #236: Decouple metric collection and sending logic
     metricMessageSender.send(TASK_METRIC_ID, taskId, "taskOutputBytes",
       SerializationUtils.serialize(totalWrittenBytes));
+
+    if (!partitionSizeMap.isEmpty()) {
+      persistentConnectionToMasterMap.getMessageSender(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID).send(
+        ControlMessage.Message.newBuilder()
+          .setId(RuntimeIdManager.generateMessageId())
+          .setListenerId(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID)
+          .setType(ControlMessage.MessageType.ParentTaskDataCollected)
+          .setParentTaskDataCollected(ControlMessage.ParentTaskDataCollectMsg.newBuilder()
+            .setTaskId(taskId)
+            .setPartitionSizeMap(ByteString.copyFrom(SerializationUtils.serialize(partitionSizeMap)))
+            .build())
+          .build());
+    }
   }
+
+  // Methods for work stealing
+  /**
+   * Gather the KV statistics of processed data when execution is completed.
+   * This method is for work stealing implementation: the accumulated statistics will be used to
+   * detect skewed tasks of the child stage.
+   *
+   * @param totalPartitionSizeMap  accumulated partitionSizeMap of task.
+   * @param singlePartitionSizeMap partitionSizeMap gained from single OutputWriter.
+   */
+  private void computePartitionSizeMap(final Map<Integer, Long> totalPartitionSizeMap,
+                                       final Map<Integer, Long> singlePartitionSizeMap) {
+    for (Integer hashedKey : singlePartitionSizeMap.keySet()) {
+      final Long partitionSize = singlePartitionSizeMap.get(hashedKey);
+      if (totalPartitionSizeMap.containsKey(hashedKey)) {
+        totalPartitionSizeMap.compute(hashedKey, (existingKey, existingValue) -> existingValue + partitionSize);
+      } else {
+        totalPartitionSizeMap.put(hashedKey, partitionSize);
+      }
+    }
+  }
+
+  /**
+   * Send the temporally processed bytes of the current task on request from the scheduler.
+   * This method is for work stealing implementation.
+   */
+  public void onRequestForProcessedData() {
+    LOG.info("{}, bytes {}, replying for the request", taskId, serializedReadBytes);
+    persistentConnectionToMasterMap.getMessageSender(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID).send(
+      ControlMessage.Message.newBuilder()
+        .setId(RuntimeIdManager.generateMessageId())
+        .setListenerId(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID)
+        .setType(ControlMessage.MessageType.CurrentlyProcessedBytesCollected)
+        .setCurrentlyProcessedBytesCollected(ControlMessage.CurrentlyProcessedBytesCollectMsg.newBuilder()
+          .setTaskId(this.taskId)
+          .setProcessedDataBytes(serializedReadBytes)
+          .build())
+        .build());
+  }
 }
diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java
index d3b48f266a..40fb5e86fc 100644
--- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java
+++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java
@@ -58,10 +58,7 @@
 import javax.inject.Inject;
 import java.io.Serializable;
 import java.nio.file.Paths;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Optional;
-import java.util.Set;
+import java.util.*;
 import java.util.concurrent.*;
 import java.util.concurrent.atomic.AtomicInteger;
 
@@ -481,6 +478,23 @@ private void handleControlMessage(final ControlMessage.Message message) {
             .setDataCollected(ControlMessage.DataCollectMessage.newBuilder().setData(serializedData).build())
             .build());
         break;
+      case ParentTaskDataCollected:
+        if (scheduler instanceof BatchScheduler) {
+          final ControlMessage.ParentTaskDataCollectMsg workStealingMsg = message.getParentTaskDataCollected();
+          final String taskId = workStealingMsg.getTaskId();
+          final Map<Integer, Long> partitionSizeMap = SerializationUtils
+            .deserialize(workStealingMsg.getPartitionSizeMap().toByteArray());
+          ((BatchScheduler) scheduler).aggregateStageIdToPartitionSizeMap(taskId, partitionSizeMap);
+        }
+        break;
+      case CurrentlyProcessedBytesCollected:
+        if (scheduler instanceof BatchScheduler) {
+          ((BatchScheduler) scheduler).aggregateTaskIdToProcessedBytes(
+            message.getCurrentlyProcessedBytesCollected().getTaskId(),
+            message.getCurrentlyProcessedBytesCollected().getProcessedDataBytes()
+          );
+        }
+        break;
       case MetricFlushed:
         metricCountDownLatch.countDown();
         break;
diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java
index 086c9d08bd..8941aa093c 100644
--- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java
+++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java
@@ -77,6 +77,12 @@ public final class BatchScheduler implements Scheduler {
    */
   private List<List<Stage>> sortedScheduleGroups;  // Stages, sorted in the order to be scheduled.
 
+  /**
+   * Data Structures for work stealing.
+   */
+  private final Map<String, Map<Integer, Long>> stageIdToOutputPartitionSizeMap = new HashMap<>();
+  private final Map<String, Long> taskIdToProcessedBytes = new HashMap<>();
+
   @Inject
   private BatchScheduler(final PlanRewriter planRewriter,
                          final TaskDispatcher taskDispatcher,
@@ -383,4 +389,39 @@ private boolean modifyStageNumCloneUsingMedianTime(final String stageId,
 
     return false;
   }
+
+  // Methods for work stealing
+
+  /**
+   * Accumulate the execution result of each stage in Map[STAGE ID, Map[KEY, SIZE]] format.
+   * KEY is assumed to be Integer because of the HashPartition.
+   *
+   * @param taskId           id of task to accumulate.
+   * @param partitionSizeMap map of (K) - (partition size) of the task.
+   */
+  public void aggregateStageIdToPartitionSizeMap(final String taskId,
+                                                 final Map<Integer, Long> partitionSizeMap) {
+    final Map<Integer, Long> partitionSizeMapForThisStage = stageIdToOutputPartitionSizeMap
+      .getOrDefault(RuntimeIdManager.getStageIdFromTaskId(taskId), new HashMap<>());
+    for (Integer hashedKey : partitionSizeMap.keySet()) {
+      final Long partitionSize = partitionSizeMap.get(hashedKey);
+      if (partitionSizeMapForThisStage.containsKey(hashedKey)) {
+        partitionSizeMapForThisStage.put(hashedKey, partitionSize + partitionSizeMapForThisStage.get(hashedKey));
+      } else {
+        partitionSizeMapForThisStage.put(hashedKey, partitionSize);
+      }
+    }
+    stageIdToOutputPartitionSizeMap.put(RuntimeIdManager.getStageIdFromTaskId(taskId), partitionSizeMapForThisStage);
+  }
+
+  /**
+   * Store the tracked processed bytes per task by the current time.
+   *
+   * @param taskId         id of task to track.
+   * @param processedBytes size of the processed bytes till now.
+   */
+  public void aggregateTaskIdToProcessedBytes(final String taskId,
+                                              final long processedBytes) {
+    taskIdToProcessedBytes.put(taskId, processedBytes);
+  }
 }