diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/OperatorVertex.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/OperatorVertex.java index 1b2136d070..be44a8af3d 100644 --- a/common/src/main/java/org/apache/nemo/common/ir/vertex/OperatorVertex.java +++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/OperatorVertex.java @@ -27,6 +27,7 @@ */ public class OperatorVertex extends IRVertex { private final Transform transform; + private final String transformFullName; /** * Constructor of OperatorVertex. @@ -36,6 +37,12 @@ public class OperatorVertex extends IRVertex { public OperatorVertex(final Transform t) { super(); this.transform = t; + this.transformFullName = ""; + } + + public OperatorVertex(final Transform t, final String transformFullName) { + this.transform = t; + this.transformFullName = transformFullName; } /** @@ -46,6 +53,7 @@ public OperatorVertex(final Transform t) { private OperatorVertex(final OperatorVertex that) { super(that); this.transform = that.transform; + this.transformFullName = that.transformFullName; } @Override @@ -60,6 +68,10 @@ public final Transform getTransform() { return transform; } + public final String getTransformFullName() { + return transformFullName; + } + @Override public final ObjectNode getPropertiesAsJsonNode() { final ObjectNode node = getIRVertexPropertiesAsJsonNode(); diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/executionproperty/WorkStealingStateProperty.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/executionproperty/WorkStealingStateProperty.java new file mode 100644 index 0000000000..def6b481ce --- /dev/null +++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/executionproperty/WorkStealingStateProperty.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.nemo.common.ir.vertex.executionproperty; + +import org.apache.nemo.common.ir.executionproperty.VertexExecutionProperty; + +/** + * Marks Work Stealing Strategy of the vertex. + * + * Currently, there are three types: + * SPLIT : vertex which is the subject of work stealing + * MERGE : vertex which merges the effect of work stealing + * DEFAULT : vertex which is not the subject of work stealing + */ +public class WorkStealingStateProperty extends VertexExecutionProperty<String> { + /** + * Default constructor. + * + * @param value value of the VertexExecutionProperty. + */ + public WorkStealingStateProperty(final String value) { + super(value); + } + + /** + * Static method exposing the constructor. + * + * @param value value of the new execution property. + * @return the newly created execution property.
+ */ + public static WorkStealingStateProperty of(final String value) { + return new WorkStealingStateProperty(value); + } +} diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/executionproperty/WorkStealingSubSplitProperty.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/executionproperty/WorkStealingSubSplitProperty.java new file mode 100644 index 0000000000..fe9db872a7 --- /dev/null +++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/executionproperty/WorkStealingSubSplitProperty.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.nemo.common.ir.vertex.executionproperty; + +import org.apache.nemo.common.ir.executionproperty.VertexExecutionProperty; + +/** + * Property to store the sub-split number of work stealing tasks. + */ +public class WorkStealingSubSplitProperty extends VertexExecutionProperty { + /** + * Default constructor. + * + * @param value value of the VertexExecutionProperty. + */ + public WorkStealingSubSplitProperty(final Integer value) { + super(value); + } + + /** + * Static method exposing the constructor. + * + * @param value value of the new execution property. + * @return the newly created execution property. 
+ */ + public static WorkStealingSubSplitProperty of(final Integer value) { + return new WorkStealingSubSplitProperty(value); + } +} diff --git a/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/PipelineTranslator.java b/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/PipelineTranslator.java index cd9d7ad223..f3063a2989 100644 --- a/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/PipelineTranslator.java +++ b/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/PipelineTranslator.java @@ -269,7 +269,8 @@ private static void parDoMultiOutputTranslator(final PipelineTranslationContext private static void groupByKeyTranslator(final PipelineTranslationContext ctx, final TransformHierarchy.Node beamNode, final GroupByKey transform) { - final IRVertex vertex = new OperatorVertex(createGBKTransform(ctx, beamNode)); + final String fullName = beamNode.getFullName(); + final IRVertex vertex = new OperatorVertex(createGBKTransform(ctx, beamNode), fullName); ctx.addVertex(vertex); beamNode.getInputs().values().forEach(input -> ctx.addEdgeTo(vertex, input)); beamNode.getOutputs().values().forEach(output -> ctx.registerMainOutputFrom(beamNode, vertex, output)); @@ -324,7 +325,8 @@ private static void createPCollectionViewTranslator(final PipelineTranslationCon private static void flattenTranslator(final PipelineTranslationContext ctx, final TransformHierarchy.Node beamNode, final Flatten.PCollections transform) { - final IRVertex vertex = new OperatorVertex(new FlattenTransform()); + final String fullName = beamNode.getFullName(); + final IRVertex vertex = new OperatorVertex(new FlattenTransform(), fullName); ctx.addVertex(vertex); beamNode.getInputs().values().forEach(input -> ctx.addEdgeTo(vertex, input)); beamNode.getOutputs().values().forEach(output -> ctx.registerMainOutputFrom(beamNode, vertex, output)); @@ -350,6 +352,7 @@ private static 
Pipeline.PipelineVisitor.CompositeBehavior combinePerKeyTranslato final PTransform transform) { final Combine.PerKey perKey = (Combine.PerKey) transform; + final String fullName = beamNode.getFullName(); // If there's any side inputs, translate each primitive transforms in this composite transform one by one. if (!perKey.getSideInputs().isEmpty()) { @@ -382,8 +385,8 @@ private static Pipeline.PipelineVisitor.CompositeBehavior combinePerKeyTranslato // Choose between batch processing and stream processing based on window type and boundedness of data if (isMainInputBounded(beamNode, ctx.getPipeline()) && isGlobalWindow(beamNode, ctx.getPipeline())) { // Batch processing, using CombinePartialTransform and CombineFinalTransform - partialCombine = new OperatorVertex(new CombineFnPartialTransform<>(combineFn)); - finalCombine = new OperatorVertex(new CombineFnFinalTransform<>(combineFn)); + partialCombine = new OperatorVertex(new CombineFnPartialTransform<>(combineFn), fullName); + finalCombine = new OperatorVertex(new CombineFnFinalTransform<>(combineFn), fullName); } else { // Stream data processing, using GBKTransform final AppliedPTransform pTransform = beamNode.toAppliedPTransform(ctx.getPipeline()); diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/annotating/WorkStealingStatePass.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/annotating/WorkStealingStatePass.java new file mode 100644 index 0000000000..15b46e5c77 --- /dev/null +++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/annotating/WorkStealingStatePass.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.nemo.compiler.optimizer.pass.compiletime.annotating; + +import org.apache.nemo.common.Pair; +import org.apache.nemo.common.dag.Edge; +import org.apache.nemo.common.ir.IRDAG; +import org.apache.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty; +import org.apache.nemo.common.ir.vertex.IRVertex; +import org.apache.nemo.common.ir.vertex.OperatorVertex; +import org.apache.nemo.common.ir.vertex.executionproperty.WorkStealingStateProperty; +import org.apache.nemo.common.ir.vertex.transform.Transform; +import org.apache.nemo.compiler.optimizer.pass.compiletime.Requires; +import org.apache.nemo.runtime.common.plan.StagePartitioner; + +import java.util.*; + +/** + * Optimization pass for annotating {@link WorkStealingStateProperty}. 
+ */ +@Annotates(WorkStealingStateProperty.class) +@Requires(CommunicationPatternProperty.class) +public final class WorkStealingStatePass extends AnnotatingPass { + private static final String SPLIT_STRATEGY = "SPLIT"; + private static final String MERGE_STRATEGY = "MERGE"; + private static final String DEFAULT_STRATEGY = "DEFAULT"; + + private final StagePartitioner stagePartitioner = new StagePartitioner(); + + public WorkStealingStatePass() { + super(WorkStealingStatePass.class); + } + + @Override + public IRDAG apply(final IRDAG irdag) { + irdag.topologicalDo(irVertex -> { + final boolean notConnectedToO2OEdge = irdag.getIncomingEdgesOf(irVertex).stream() + .map(edge -> edge.getPropertyValue(CommunicationPatternProperty.class).get()) + .noneMatch(property -> property.equals(CommunicationPatternProperty.Value.ONE_TO_ONE)); + if (irVertex instanceof OperatorVertex && notConnectedToO2OEdge) { + Transform transform = ((OperatorVertex) irVertex).getTransform(); + String transformFullName = ((OperatorVertex) irVertex).getTransformFullName(); + if (transform.toString().contains("work stealing") || transformFullName.contains("work stealing")) { + irVertex.setProperty(WorkStealingStateProperty.of(SPLIT_STRATEGY)); + } else if (transform.toString().contains("merge") || transformFullName.contains("merge")) { + irVertex.setProperty(WorkStealingStateProperty.of(MERGE_STRATEGY)); + } else { + irVertex.setProperty(WorkStealingStateProperty.of(DEFAULT_STRATEGY)); + } + } else { + irVertex.setProperty(WorkStealingStateProperty.of(DEFAULT_STRATEGY)); + } + }); + return tidyWorkStealingAnnotation(irdag); + } + + + /** + * Tidy annotated dag. + * Cleanup conditions: + * - The number of SPLIT annotations and MERGE annotations should be equal + * - SPLIT and MERGE should not be together in one stage, but needs to be in adjacent stage. + * - For now, nested work stealing optimizations are not provided. If detected, leave only the + * innermost pair. 
+ * + * @param irdag irdag to cleanup. + * @return cleaned irdag. + */ + private IRDAG tidyWorkStealingAnnotation(final IRDAG irdag) { + String splitVertexId = null; + + final List<Pair<String, String>> splitMergePairs = new ArrayList<>(); + final Set<String> pairedVertices = new HashSet<>(); + final Map<Integer, Set<IRVertex>> stageIdToStageVertices = new HashMap<>(); + + // Make SPLIT - MERGE vertex pair. + for (IRVertex vertex : irdag.getTopologicalSort()) { + if (vertex.getPropertyValue(WorkStealingStateProperty.class).get().equals(SPLIT_STRATEGY)) { + if (splitVertexId != null) { + // nested SPLIT vertex detected: delete the prior one. + irdag.getVertexById(splitVertexId).setProperty(WorkStealingStateProperty.of(DEFAULT_STRATEGY)); + } + splitVertexId = vertex.getId(); + } else if (vertex.getPropertyValue(WorkStealingStateProperty.class).get().equals(MERGE_STRATEGY)) { + if (splitVertexId != null) { + splitMergePairs.add(Pair.of(splitVertexId, vertex.getId())); + pairedVertices.add(splitVertexId); + pairedVertices.add(vertex.getId()); + splitVertexId = null; + } else { + // no corresponding SPLIT vertex: delete + vertex.setProperty(WorkStealingStateProperty.of(DEFAULT_STRATEGY)); + } + } + } + + final Map<IRVertex, Integer> vertexToStageId = stagePartitioner.apply(irdag); + + for (Pair<String, String> splitMergePair : splitMergePairs) { + IRVertex splitVertex = irdag.getVertexById(splitMergePair.left()); + IRVertex mergeVertex = irdag.getVertexById(splitMergePair.right()); + + if (vertexToStageId.get(splitVertex) >= vertexToStageId.get(mergeVertex) + || irdag.getIncomingEdgesOf(mergeVertex).stream() + .map(Edge::getSrc) + .map(vertexToStageId::get) + .noneMatch(stageId -> stageId.equals(vertexToStageId.get(splitVertex)))) { + // split vertex is descendant of merge vertex or they are in the same stage, + // or they are not in adjacent stages + splitVertex.setProperty(WorkStealingStateProperty.of(DEFAULT_STRATEGY)); + mergeVertex.setProperty(WorkStealingStateProperty.of(DEFAULT_STRATEGY)); + pairedVertices.remove(splitVertex.getId()); + 
pairedVertices.remove(mergeVertex.getId()); + } + } + + irdag.topologicalDo(vertex -> { + if (!vertex.getPropertyValue(WorkStealingStateProperty.class) + .orElse(DEFAULT_STRATEGY).equals(DEFAULT_STRATEGY)) { + if (!pairedVertices.contains(vertex.getId())) { + vertex.setProperty(WorkStealingStateProperty.of(DEFAULT_STRATEGY)); + } + } + }); + + // update execution property of other vertices in same stage. + vertexToStageId.forEach((vertex, stageId) -> { + if (!stageIdToStageVertices.containsKey(stageId)) { + stageIdToStageVertices.put(stageId, new HashSet<>()); + } + stageIdToStageVertices.get(stageId).add(vertex); + }); + + for (String vertexId : pairedVertices) { + IRVertex vertex = irdag.getVertexById(vertexId); + Set stageVertices = stageIdToStageVertices.get(vertexToStageId.get(vertex)); + String strategy = vertex.getPropertyValue(WorkStealingStateProperty.class) + .orElse(DEFAULT_STRATEGY); + for (IRVertex stageVertex : stageVertices) { + stageVertex.setProperty(WorkStealingStateProperty.of(strategy)); + } + } + + return irdag; + } +} diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/annotating/WorkStealingSubSplitPass.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/annotating/WorkStealingSubSplitPass.java new file mode 100644 index 0000000000..d0d072aab2 --- /dev/null +++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/annotating/WorkStealingSubSplitPass.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.nemo.compiler.optimizer.pass.compiletime.annotating; + +import org.apache.nemo.common.ir.IRDAG; +import org.apache.nemo.common.ir.edge.IREdge; +import org.apache.nemo.common.ir.vertex.IRVertex; +import org.apache.nemo.common.ir.vertex.executionproperty.WorkStealingStateProperty; +import org.apache.nemo.common.ir.vertex.executionproperty.ParallelismProperty; +import org.apache.nemo.common.ir.vertex.executionproperty.WorkStealingSubSplitProperty; +import org.apache.nemo.compiler.optimizer.pass.compiletime.Requires; + +import java.util.HashMap; +import java.util.Map; + +/** + * Optimization pass for tagging work stealing sub-split execution property. + */ +@Annotates(WorkStealingSubSplitProperty.class) +@Requires({WorkStealingStateProperty.class, ParallelismProperty.class}) +public final class WorkStealingSubSplitPass extends AnnotatingPass { + private static final String SPLIT_STRATEGY = "SPLIT"; + private static final String MERGE_STRATEGY = "MERGE"; + private static final String DEFAULT_STRATEGY = "DEFAULT"; + + private static final int MAX_SUB_SPLIT_NUM = 5; + private static final int DEFAULT_SUB_SPLIT_NUM = 1; + + /** + * Default Constructor. 
+ */ + public WorkStealingSubSplitPass() { + super(WorkStealingSubSplitPass.class); + } + + @Override + public IRDAG apply(final IRDAG irdag) { + final Map vertexToSplitNum = new HashMap<>(); + + for (IRVertex vertex : irdag.getTopologicalSort()) { + if (vertex.getPropertyValue(WorkStealingStateProperty.class) + .orElse(DEFAULT_STRATEGY).equals(SPLIT_STRATEGY)) { + int maxSourceParallelism = irdag.getIncomingEdgesOf(vertex).stream().map(IREdge::getSrc) + .mapToInt(v -> v.getPropertyValue(ParallelismProperty.class).orElse(DEFAULT_SUB_SPLIT_NUM)) + .max().orElse(DEFAULT_SUB_SPLIT_NUM); + if (maxSourceParallelism > MAX_SUB_SPLIT_NUM) { + vertex.setProperty(WorkStealingSubSplitProperty.of(MAX_SUB_SPLIT_NUM)); + vertexToSplitNum.put(vertex, MAX_SUB_SPLIT_NUM); + } else { + vertex.setProperty(WorkStealingSubSplitProperty.of(maxSourceParallelism)); + vertexToSplitNum.put(vertex, maxSourceParallelism); + } + } else { + vertex.setProperty(WorkStealingSubSplitProperty.of(DEFAULT_SUB_SPLIT_NUM)); + } + } + + return irdag; + } +} diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/composite/WorkStealingCompositePass.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/composite/WorkStealingCompositePass.java new file mode 100644 index 0000000000..9c8914ee6f --- /dev/null +++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/composite/WorkStealingCompositePass.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.nemo.compiler.optimizer.pass.compiletime.composite; + +import org.apache.nemo.compiler.optimizer.pass.compiletime.annotating.*; + +import java.util.Arrays; + +/** + * Composite pass for work stealing. + */ +public class WorkStealingCompositePass extends CompositePass { + /** + * Default constructor. + */ + public WorkStealingCompositePass() { + super(Arrays.asList( + new WorkStealingStatePass(), + new WorkStealingSubSplitPass() + )); + } +} diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/policy/WorkStealingPolicy.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/policy/WorkStealingPolicy.java new file mode 100644 index 0000000000..b094606f32 --- /dev/null +++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/policy/WorkStealingPolicy.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.nemo.compiler.optimizer.policy; + +import org.apache.nemo.common.ir.IRDAG; +import org.apache.nemo.compiler.optimizer.pass.compiletime.composite.DefaultCompositePass; +import org.apache.nemo.compiler.optimizer.pass.compiletime.composite.WorkStealingCompositePass; +import org.apache.nemo.compiler.optimizer.pass.runtime.Message; + +/** + * Policy for work stealing. + */ +public final class WorkStealingPolicy implements Policy { + public static final PolicyBuilder BUILDER = + new PolicyBuilder() + .registerCompileTimePass(new DefaultCompositePass()) + .registerCompileTimePass(new WorkStealingCompositePass()); + private final Policy policy; + + /** + * Default constructor. 
+ */ + public WorkStealingPolicy() { + this.policy = BUILDER.build(); + } + + @Override + public IRDAG runCompileTimeOptimization(final IRDAG dag, final String dagDirectory) { + return this.policy.runCompileTimeOptimization(dag, dagDirectory); + } + + @Override + public IRDAG runRunTimeOptimizations(final IRDAG dag, final Message message) { + return this.policy.runRunTimeOptimizations(dag, message); + } +} diff --git a/compiler/test/src/main/java/org/apache/nemo/compiler/CompilerTestUtil.java b/compiler/test/src/main/java/org/apache/nemo/compiler/CompilerTestUtil.java index 3f5001a0b7..4f8dafdc2b 100644 --- a/compiler/test/src/main/java/org/apache/nemo/compiler/CompilerTestUtil.java +++ b/compiler/test/src/main/java/org/apache/nemo/compiler/CompilerTestUtil.java @@ -104,6 +104,18 @@ public static IRDAG compileWordCountDAG() throws Exception { return compileDAG(mrArgBuilder.build()); } + public static IRDAG compileWordCountWorkStealingDAG() throws Exception { + final String input = ROOT_DIR + "/examples/resources/inputs/test_input_wordcount"; + final String output = ROOT_DIR + "/examples/resources/inputs/test_output"; + final String main = "org.apache.nemo.examples.beam.WordCount"; + + final ArgBuilder mrArgBuilder = new ArgBuilder() + .addJobId("WordCount") + .addUserMain(main) + .addUserArgs(input, output, "true"); + return compileDAG(mrArgBuilder.build()); + } + public static IRDAG compileALSDAG() throws Exception { final String input = ROOT_DIR + "/examples/resources/inputs/test_input_als"; final String numFeatures = "10"; diff --git a/compiler/test/src/test/java/org/apache/nemo/compiler/optimizer/pass/compiletime/composite/WorkStealingCompositePassTest.java b/compiler/test/src/test/java/org/apache/nemo/compiler/optimizer/pass/compiletime/composite/WorkStealingCompositePassTest.java new file mode 100644 index 0000000000..03d7a0980e --- /dev/null +++ 
b/compiler/test/src/test/java/org/apache/nemo/compiler/optimizer/pass/compiletime/composite/WorkStealingCompositePassTest.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.nemo.compiler.optimizer.pass.compiletime.composite; + +import org.apache.nemo.client.JobLauncher; +import org.apache.nemo.common.ir.IRDAG; +import org.apache.nemo.common.ir.vertex.IRVertex; +import org.apache.nemo.common.ir.vertex.executionproperty.WorkStealingStateProperty; +import org.apache.nemo.compiler.CompilerTestUtil; +import org.apache.nemo.compiler.optimizer.pass.compiletime.annotating.DefaultParallelismPass; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; + +import static junit.framework.TestCase.assertEquals; + +/** + * Test {@link WorkStealingCompositePass} with MR workload. 
+ */ +@RunWith(PowerMockRunner.class) +@PrepareForTest(JobLauncher.class) +public class WorkStealingCompositePassTest { + private IRDAG mrDAG; + + @Before + public void setUp() throws Exception { + } + + @Test + public void testWorkStealingPass() throws Exception { + mrDAG = CompilerTestUtil.compileWordCountWorkStealingDAG(); + + final IRDAG processedDAG = new WorkStealingCompositePass().apply(new DefaultParallelismPass().apply(mrDAG)); + + int numSplitVertex = 0; + int numMergeVertex = 0; + + for (IRVertex vertex : processedDAG.getTopologicalSort()) { + if (vertex.getPropertyValue(WorkStealingStateProperty.class).orElse("DEFAULT").equals("SPLIT")) { + numSplitVertex++; + } else if (vertex.getPropertyValue(WorkStealingStateProperty.class).orElse("DEFAULT").equals("MERGE")) { + numMergeVertex++; + } + } + + assertEquals(numSplitVertex, numMergeVertex); + } +} diff --git a/examples/beam/src/main/java/org/apache/nemo/examples/beam/WordCount.java b/examples/beam/src/main/java/org/apache/nemo/examples/beam/WordCount.java index 367938fd68..b1cd4ed462 100644 --- a/examples/beam/src/main/java/org/apache/nemo/examples/beam/WordCount.java +++ b/examples/beam/src/main/java/org/apache/nemo/examples/beam/WordCount.java @@ -44,10 +44,15 @@ private WordCount() { public static void main(final String[] args) { final String inputFilePath = args[0]; final String outputFilePath = args[1]; + final boolean enableWorkStealing = args.length > 2 && Boolean.parseBoolean(args[2]); final PipelineOptions options = NemoPipelineOptionsFactory.create(); + options.setJobName("WordCount"); - final Pipeline p = generateWordCountPipeline(options, inputFilePath, outputFilePath); + final Pipeline p = enableWorkStealing + ? generateWordCountPipelineForWorkStealing(options, inputFilePath, outputFilePath) + : generateWordCountPipeline(options, inputFilePath, outputFilePath); + p.run().waitUntilFinish(); } @@ -59,7 +64,7 @@ public static void main(final String[] args) { * @return the generated pipeline.
*/ static Pipeline generateWordCountPipeline(final PipelineOptions options, - final String inputFilePath, final String outputFilePath) { + final String inputFilePath, final String outputFilePath) { final Pipeline p = Pipeline.create(options); final PCollection result = GenericSourceSink.read(p, inputFilePath) .apply(MapElements.>via(new SimpleFunction>() { @@ -72,7 +77,8 @@ public KV apply(final String line) { } })) .apply(Sum.longsPerKey()) - .apply(MapElements., String>via(new SimpleFunction, String>() { + .apply(MapElements., String>via( + new SimpleFunction, String>() { @Override public String apply(final KV kv) { return kv.getKey() + ": " + kv.getValue(); @@ -81,4 +87,37 @@ public String apply(final KV kv) { GenericSourceSink.write(result, outputFilePath); return p; } + + /** + * Static method to generate the word count Beam pipeline with work stealing optimization. + * @param options options for the pipeline. + * @param inputFilePath the input file path. + * @param outputFilePath the output file path. + * @return the generated pipeline. 
+ */ + static Pipeline generateWordCountPipelineForWorkStealing(final PipelineOptions options, + final String inputFilePath, final String outputFilePath) { + final Pipeline p = Pipeline.create(options); + final PCollection result = GenericSourceSink.read(p, inputFilePath) + .apply(MapElements.>via(new SimpleFunction>() { + @Override + public KV apply(final String line) { + final String[] words = line.split(" +"); + final String documentId = words[0] + "#" + words[1]; + final Long count = Long.parseLong(words[2]); + return KV.of(documentId, count); + } + })) + .apply("work stealing", Sum.longsPerKey()) + .apply("merge", Sum.longsPerKey()) + .apply("test work stealing", MapElements., String>via( + new SimpleFunction, String>() { + @Override + public String apply(final KV kv) { + return kv.getKey() + ": " + kv.getValue(); + } + })); + GenericSourceSink.write(result, outputFilePath); + return p; + } } diff --git a/examples/beam/src/test/java/org/apache/nemo/examples/beam/WordCountITCase.java b/examples/beam/src/test/java/org/apache/nemo/examples/beam/WordCountITCase.java index cfe0571434..a609dd1469 100644 --- a/examples/beam/src/test/java/org/apache/nemo/examples/beam/WordCountITCase.java +++ b/examples/beam/src/test/java/org/apache/nemo/examples/beam/WordCountITCase.java @@ -24,6 +24,7 @@ import org.apache.nemo.common.test.ExampleTestUtil; import org.apache.nemo.compiler.optimizer.policy.ConditionalLargeShufflePolicy; import org.apache.nemo.compiler.optimizer.policy.DynamicTaskSizingPolicy; +import org.apache.nemo.compiler.optimizer.policy.WorkStealingPolicy; import org.apache.nemo.examples.beam.policy.*; import org.junit.After; import org.junit.Before; @@ -52,7 +53,7 @@ public final class WordCountITCase { public void setUp() throws Exception { builder = new ArgBuilder() .addUserMain(WordCount.class.getCanonicalName()) - .addUserArgs(inputFilePath, outputFilePath); + .addUserArgs(inputFilePath, outputFilePath, "false"); } @After @@ -73,6 +74,17 @@ public void test() 
throws Exception { .build()); } + @Test(timeout = ExampleTestArgs.TIMEOUT, expected = Test.None.class) + public void testWorkStealing() throws Exception { + JobLauncher.main(new ArgBuilder() + .addUserMain(WordCount.class.getCanonicalName()) + .addUserArgs(inputFilePath, outputFilePath, "true") + .addResourceJson(executorResourceFileName) + .addJobId(WordCountITCase.class.getSimpleName() + "_workStealing") + .addOptimizationPolicy(WorkStealingPolicy.class.getCanonicalName()) + .build()); + } + @Test(timeout = ExampleTestArgs.TIMEOUT, expected = Test.None.class) public void testLargeShuffle() throws Exception { JobLauncher.main(builder diff --git a/runtime/common/src/main/java/org/apache/nemo/runtime/common/RuntimeIdManager.java b/runtime/common/src/main/java/org/apache/nemo/runtime/common/RuntimeIdManager.java index 255cca70f5..6d69b992fc 100644 --- a/runtime/common/src/main/java/org/apache/nemo/runtime/common/RuntimeIdManager.java +++ b/runtime/common/src/main/java/org/apache/nemo/runtime/common/RuntimeIdManager.java @@ -71,7 +71,17 @@ public static String generateTaskId(final String stageId, final int index, final if (index < 0 || attempt < 0) { throw new IllegalStateException(index + ", " + attempt); } - return stageId + SPLITTER + index + SPLITTER + attempt; + return stageId + SPLITTER + index + SPLITTER + "*" + SPLITTER + attempt; + } + + public static String generateWorkStealingTaskId(final String stageId, + final int index, + final int partial, + final int attempt) { + if (index < 0 || partial < 0 || attempt < 0) { + throw new IllegalStateException(index + ", " + partial + ", " + attempt); + } + return stageId + SPLITTER + index + SPLITTER + partial + SPLITTER + attempt; } /** @@ -92,8 +102,14 @@ public static String generateExecutorId() { */ public static String generateBlockId(final String runtimeEdgeId, final String producerTaskId) { - return runtimeEdgeId + SPLITTER + getIndexFromTaskId(producerTaskId) - + SPLITTER + getAttemptFromTaskId(producerTaskId); 
+ if (isWorkStealingTask(producerTaskId)) { + return runtimeEdgeId + SPLITTER + getIndexFromTaskId(producerTaskId) + + SPLITTER + getSubSplitIndexFromTaskId(producerTaskId) + + SPLITTER + getAttemptFromTaskId(producerTaskId); + } else { + return runtimeEdgeId + SPLITTER + getIndexFromTaskId(producerTaskId) + + SPLITTER + getAttemptFromTaskId(producerTaskId); + } } /** @@ -109,8 +125,15 @@ public static String generateBlockId(final String runtimeEdgeId, * @return the generated WILDCARD ID */ public static String generateBlockIdWildcard(final String runtimeEdgeId, - final int producerTaskIndex) { - return runtimeEdgeId + SPLITTER + producerTaskIndex + SPLITTER + "*"; + final int producerTaskIndex, + final String subSplitIndex) { + if (!subSplitIndex.equals("*")) { + return runtimeEdgeId + SPLITTER + producerTaskIndex + + SPLITTER + subSplitIndex + SPLITTER + "*"; + } else { + return runtimeEdgeId + SPLITTER + producerTaskIndex + SPLITTER + "*"; + } + } /** @@ -123,6 +146,9 @@ public static long generateMessageId() { } //////////////////////////////////////////////////////////////// Parse IDs + public static boolean isWorkStealingBlock(final String blockId) { + return split(blockId).length == 4; + } /** * Extracts runtime edge ID from a block ID. @@ -144,6 +170,20 @@ public static int getTaskIndexFromBlockId(final String blockId) { return Integer.valueOf(split(blockId)[1]); } + /** + * Extracts task index from a block ID. + * + * @param blockId the block ID to extract. + * @return the task index. + */ + public static String getTaskSubSplitIndexFromBlockId(final String blockId) { + if (isWorkStealingBlock(blockId)) { + return split(blockId)[2]; + } else { + return "*"; + } + } + /** * Extracts wild card from a block ID. * @@ -151,7 +191,9 @@ public static int getTaskIndexFromBlockId(final String blockId) { * @return the wild card. 
*/ public static String getWildCardFromBlockId(final String blockId) { - return generateBlockIdWildcard(getRuntimeEdgeIdFromBlockId(blockId), getTaskIndexFromBlockId(blockId)); + return generateBlockIdWildcard(getRuntimeEdgeIdFromBlockId(blockId), + getTaskIndexFromBlockId(blockId), + getTaskSubSplitIndexFromBlockId(blockId)); } /** @@ -174,6 +216,13 @@ public static int getIndexFromTaskId(final String taskId) { return Integer.valueOf(split(taskId)[1]); } + public static boolean isWorkStealingTask(final String taskId) { + return !split(taskId)[2].equals("*"); + } + + public static int getSubSplitIndexFromTaskId(final String taskId) { + return split(taskId)[2].equals("*") ? 0 : Integer.valueOf(split(taskId)[2]); + } /** * Extracts the attempt from a task ID. * @@ -181,7 +230,7 @@ public static int getIndexFromTaskId(final String taskId) { * @return the attempt. */ public static int getAttemptFromTaskId(final String taskId) { - return Integer.valueOf(split(taskId)[2]); + return Integer.valueOf(split(taskId)[3]); } private static String[] split(final String id) { diff --git a/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/PhysicalPlanGenerator.java b/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/PhysicalPlanGenerator.java index 83728e4f03..c63134e6f3 100644 --- a/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/PhysicalPlanGenerator.java +++ b/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/PhysicalPlanGenerator.java @@ -32,8 +32,10 @@ import org.apache.nemo.common.ir.vertex.IRVertex; import org.apache.nemo.common.ir.vertex.OperatorVertex; import org.apache.nemo.common.ir.vertex.SourceVertex; +import org.apache.nemo.common.ir.vertex.executionproperty.WorkStealingStateProperty; import org.apache.nemo.common.ir.vertex.executionproperty.ParallelismProperty; import org.apache.nemo.common.ir.vertex.executionproperty.ScheduleGroupProperty; +import 
org.apache.nemo.common.ir.vertex.executionproperty.WorkStealingSubSplitProperty; import org.apache.nemo.common.ir.vertex.utility.SamplingVertex; import org.apache.nemo.conf.JobConf; import org.apache.nemo.runtime.common.RuntimeIdManager; @@ -132,6 +134,7 @@ private void handleDuplicateEdgeGroupProperty(final DAG dagOfS */ public DAG stagePartitionIrDAG(final IRDAG irDAG) { final StagePartitioner stagePartitioner = new StagePartitioner(); + final DAGBuilder dagOfStagesBuilder = new DAGBuilder<>(); final Set interStageEdges = new HashSet<>(); final Map stageIdToStageMap = new HashMap<>(); @@ -208,12 +211,20 @@ public DAG stagePartitionIrDAG(final IRDAG irDAG) { if (!stageInternalDAGBuilder.isEmpty()) { final DAG> stageInternalDAG = stageInternalDAGBuilder.buildWithoutSourceSinkCheck(); + // check if this stage is subject of work stealing optimization + boolean isWorkStealingStage = stageInternalDAG.getVertices().stream() + .anyMatch(vertex -> vertex.getPropertyValue(WorkStealingStateProperty.class) + .orElse("DEFAULT").equals("SPLIT")); + int numSubSplit = stageInternalDAG.getVertices().stream() + .mapToInt(v -> v.getPropertyValue(WorkStealingSubSplitProperty.class).orElse(1)) + .max().orElse(1); final Stage stage = new Stage( stageIdentifier, taskIndices, stageInternalDAG, stageProperties, - vertexIdToReadables); + vertexIdToReadables, + isWorkStealingStage ? 
numSubSplit : 1); dagOfStagesBuilder.addVertex(stage); stageIdToStageMap.put(stageId, stage); } diff --git a/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Stage.java b/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Stage.java index a7f472c0da..df106704e4 100644 --- a/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Stage.java +++ b/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Stage.java @@ -42,6 +42,7 @@ */ public final class Stage extends Vertex { private final List taskIndices; + private final int subSplitNum; private final DAG> irDag; private final byte[] serializedIRDag; private final List> vertexIdToReadables; @@ -61,13 +62,15 @@ public Stage(final String stageId, final List taskIndices, final DAG> irDag, final ExecutionPropertyMap executionProperties, - final List> vertexIdToReadables) { + final List> vertexIdToReadables, + final int subSplitNum) { // TODO: decide how subSplitNum should be configured (e.g. set it to 10 when this stage is marked for work stealing). super(stageId); this.taskIndices = taskIndices; this.irDag = irDag; this.serializedIRDag = SerializationUtils.serialize(irDag); this.executionProperties = executionProperties; this.vertexIdToReadables = vertexIdToReadables; + this.subSplitNum = subSplitNum; } /** @@ -155,6 +158,10 @@ public List> getVertexIdToReadables() { return vertexIdToReadables; } + public int getSubSplitNum() { + return subSplitNum; + } + @Override public ObjectNode getPropertiesAsJsonNode() { final ObjectNode node = JsonNodeFactory.instance.objectNode(); diff --git a/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/StagePartitioner.java b/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/StagePartitioner.java index bc20448220..ec4db68219 100644 --- a/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/StagePartitioner.java +++ b/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/StagePartitioner.java @@ -25,6 +25,8 @@ import
org.apache.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty; import org.apache.nemo.common.ir.executionproperty.VertexExecutionProperty; import org.apache.nemo.common.ir.vertex.IRVertex; +import org.apache.nemo.common.ir.vertex.executionproperty.WorkStealingStateProperty; +import org.apache.nemo.common.ir.vertex.executionproperty.WorkStealingSubSplitProperty; import org.apache.reef.annotations.audience.DriverSide; import java.util.HashMap; @@ -49,6 +51,14 @@ public final class StagePartitioner implements Function> ignoredPropertyKeys = ConcurrentHashMap.newKeySet(); private final MutableInt nextStageIndex = new MutableInt(0); + /** + * Default Constructor. + */ + public StagePartitioner() { + addIgnoredPropertyKey(WorkStealingStateProperty.class); + addIgnoredPropertyKey(WorkStealingSubSplitProperty.class); + } + /** * By default, the stage partitioner merges two vertices into one stage if and only if the two vertices have * same set of {@link VertexExecutionProperty}. 
diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/BlockInputReader.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/BlockInputReader.java index 7f2a7e3a53..280d6aeceb 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/BlockInputReader.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/BlockInputReader.java @@ -30,6 +30,8 @@ import org.apache.nemo.common.ir.executionproperty.EdgeExecutionProperty; import org.apache.nemo.common.ir.executionproperty.ExecutionPropertyMap; import org.apache.nemo.common.ir.vertex.IRVertex; +import org.apache.nemo.common.ir.vertex.executionproperty.WorkStealingStateProperty; +import org.apache.nemo.common.ir.vertex.executionproperty.WorkStealingSubSplitProperty; import org.apache.nemo.runtime.common.RuntimeIdManager; import org.apache.nemo.runtime.common.plan.RuntimeEdge; import org.apache.nemo.runtime.common.plan.StageEdge; @@ -53,6 +55,10 @@ public final class BlockInputReader implements InputReader { private final String dstTaskId; private final int dstTaskIndex; + private static final String SPLIT_STRATEGY = "SPLIT"; + private static final String MERGE_STRATEGY = "MERGE"; + private static final String DEFAULT_STRATEGY = "DEFAULT"; + /** * Attributes that specify how we should read the input. */ @@ -82,14 +88,52 @@ public List> read() { case ONE_TO_ONE: return Collections.singletonList(readOneToOne()); case BROADCAST: - return readBroadcast(index -> true); + return readBroadcast(index -> true, Optional.empty(), 1); case SHUFFLE: - return readDataInRange(index -> true); + return readDataInRange(index -> true, Optional.empty(), 1); default: throw new UnsupportedCommPatternException(new Exception("Communication pattern not supported")); } } + /** + * An extended version of {@link #read()} with work stealing options. 
+ * - DEFAULT STRATEGY: {@link #read()} + * - SPLIT STRATEGY: {@link #readPartial(int, int)} + * - MERGE STRATEGY: {@link #readSplitBlocks(int, int)} + * + * @param workStealingState work stealing strategy. + * @param numSubSplit number to split within a task index. + * @param subSplitIndex index of sub split task. + * @return + */ + @Override + public List> read(final String workStealingState, + final int numSubSplit, + final int subSplitIndex) { + if (workStealingState.equals(MERGE_STRATEGY) + && srcVertex.getPropertyValue(WorkStealingStateProperty.class).orElse(DEFAULT_STRATEGY) + .equals(SPLIT_STRATEGY)) { + /* MERGE case */ + return readSplitBlocks(InputReader.getSourceParallelism(this), + srcVertex.getPropertyValue(WorkStealingSubSplitProperty.class).orElse(1)); + } else { + + if (workStealingState.equals(SPLIT_STRATEGY)) { + /* SPLIT case */ + final int srcParallelism = InputReader.getSourceParallelism(this); + final int leftInterval = subSplitIndex * (srcParallelism / numSubSplit); + final int rightInterval = numSubSplit == subSplitIndex + 1 + ? srcParallelism : (subSplitIndex + 1) * (srcParallelism / numSubSplit); + return readPartial(leftInterval, rightInterval); + } + + /* DEFAULT case */ + return read(); + } + } + + @Override public CompletableFuture retry(final int desiredIndex) { final Optional comValueOptional = @@ -100,9 +144,55 @@ public CompletableFuture retry(final int desiredI case ONE_TO_ONE: return readOneToOne(); case BROADCAST: - return checkSingleElement(readBroadcast(index -> index == desiredIndex)); + return checkSingleElement(readBroadcast(index -> index == desiredIndex, Optional.empty(), 1)); + case SHUFFLE: + return checkSingleElement(readDataInRange(index -> index == desiredIndex, Optional.empty(), 1)); + default: + throw new UnsupportedCommPatternException(new Exception("Communication pattern not supported")); + } + } + + /** + * An extended version of {@link #retry(int)} with work stealing options. 
+ * + * @param workStealingState work stealing strategy. (SPLIT, MERGE, DEFAULT) + * @param numSubSplit number of sub-splits in SPLIT state, default by 1 in other states. + * @param desiredIndex desired index. + * @return iterator to retry. + */ + @Override + public CompletableFuture retry(final String workStealingState, + final int numSubSplit, + final int desiredIndex) { + + final boolean isMergeAfterSplit = workStealingState.equals(MERGE_STRATEGY) + && srcVertex.getPropertyValue(WorkStealingStateProperty.class).orElse(DEFAULT_STRATEGY) + .equals(SPLIT_STRATEGY); + + if (!isMergeAfterSplit && !workStealingState.equals(SPLIT_STRATEGY)) { + return retry(desiredIndex); + } + + final int srcParallelism = InputReader.getSourceParallelism(this); + final int srcNumSubSplit = srcVertex.getPropertyValue(WorkStealingSubSplitProperty.class).orElse(1); + final int trueIndex = translateIndex(workStealingState, numSubSplit, desiredIndex); + final int subIndex = isMergeAfterSplit ? desiredIndex % srcNumSubSplit : 0; + + final Optional comValueOptional = + runtimeEdge.getPropertyValue(CommunicationPatternProperty.class); + final CommunicationPatternProperty.Value comValue = comValueOptional.orElseThrow(IllegalStateException::new); + + switch (comValue) { + case ONE_TO_ONE: + return readOneToOne(); + case BROADCAST: + return readBroadcast(index -> index == trueIndex, + isMergeAfterSplit ? Optional.of(srcParallelism) : Optional.empty(), + isMergeAfterSplit ? srcNumSubSplit : 1).get(subIndex); case SHUFFLE: - return checkSingleElement(readDataInRange(index -> index == desiredIndex)); + return readDataInRange(index -> index == trueIndex, + isMergeAfterSplit ? Optional.of(srcParallelism) : Optional.empty(), + isMergeAfterSplit ? 
srcNumSubSplit : 1).get(subIndex); default: throw new UnsupportedCommPatternException(new Exception("Communication pattern not supported")); } @@ -127,47 +217,65 @@ private CompletableFuture checkSingleElement( } /** - * See {@link RuntimeIdManager#generateBlockIdWildcard(String, int)} for information on block wildcards. + * See {@link RuntimeIdManager#generateBlockIdWildcard(String, int, String)} for information on block wildcards. * * @param producerTaskIndex to use. * @return wildcard block id that corresponds to "ANY" task attempt of the task index. */ - private String generateWildCardBlockId(final int producerTaskIndex) { + private String generateWildCardBlockId(final int producerTaskIndex, + final String subSplitIndex) { final Optional duplicateDataProperty = runtimeEdge.getPropertyValue(DuplicateEdgeGroupProperty.class); if (!duplicateDataProperty.isPresent() || duplicateDataProperty.get().getGroupSize() <= 1) { - return RuntimeIdManager.generateBlockIdWildcard(runtimeEdge.getId(), producerTaskIndex); + return RuntimeIdManager.generateBlockIdWildcard(runtimeEdge.getId(), producerTaskIndex, subSplitIndex); } final String duplicateEdgeId = duplicateDataProperty.get().getRepresentativeEdgeId(); - return RuntimeIdManager.generateBlockIdWildcard(duplicateEdgeId, producerTaskIndex); + return RuntimeIdManager.generateBlockIdWildcard(duplicateEdgeId, producerTaskIndex, subSplitIndex); } private CompletableFuture readOneToOne() { - final String blockIdWildcard = generateWildCardBlockId(dstTaskIndex); + final String blockIdWildcard = generateWildCardBlockId(dstTaskIndex, "*"); return blockManagerWorker.readBlock( blockIdWildcard, runtimeEdge.getId(), runtimeEdge.getExecutionProperties(), HashRange.all()); } - private List> readBroadcast(final Predicate predicate) { - final int numSrcTasks = InputReader.getSourceParallelism(this); + /** + * Read data in full range of hash value. + * + * @param predicate function of the index. 
+ * @param numSrcIndex not empty only if in MERGE strategy. + * @param numSubSplit > 1 only if in MERGE strategy. + * @return the list of the completable future of the data. + */ + private List> readBroadcast(final Predicate predicate, + final Optional numSrcIndex, + final int numSubSplit) { + final int numSrcTasks = numSrcIndex.orElse(InputReader.getSourceParallelism(this)); final List> futures = new ArrayList<>(); for (int srcTaskIdx = 0; srcTaskIdx < numSrcTasks; srcTaskIdx++) { if (predicate.test(srcTaskIdx)) { - final String blockIdWildcard = generateWildCardBlockId(srcTaskIdx); - futures.add(blockManagerWorker.readBlock( - blockIdWildcard, runtimeEdge.getId(), runtimeEdge.getExecutionProperties(), HashRange.all())); + for (int subSplitIdx = 0; subSplitIdx < numSubSplit; subSplitIdx++) { + final String blockIdWildcard = generateWildCardBlockId(srcTaskIdx, + numSubSplit == 1 ? "*" : Integer.toString(subSplitIdx)); + futures.add(blockManagerWorker.readBlock( + blockIdWildcard, runtimeEdge.getId(), runtimeEdge.getExecutionProperties(), HashRange.all())); + } } } - return futures; } /** * Read data in the assigned range of hash value. * + * @param predicate function of the index. + * @param numSrcIndex not empty only if in MERGE strategy. + * @param numSubSplit > 1 only if in MERGE strategy. * @return the list of the completable future of the data. 
*/ - private List> readDataInRange(final Predicate predicate) { + private List> readDataInRange(final Predicate predicate, + final Optional numSrcIndex, + final int numSubSplit) { assert (runtimeEdge instanceof StageEdge); final List keyRangeList = ((StageEdge) runtimeEdge).getKeyRanges(); final KeyRange hashRangeToRead = keyRangeList.get(dstTaskIndex); @@ -180,16 +288,96 @@ private List> readDataInRange(f - ((HashRange) hashRangeToRead).rangeBeginInclusive(); metricMessageSender.send("TaskMetric", dstTaskId, "taskSizeRatio", SerializationUtils.serialize(partitionerProperty / taskSize)); - final int numSrcTasks = InputReader.getSourceParallelism(this); + final int numSrcTasks = numSrcIndex.orElse(InputReader.getSourceParallelism(this)); final List> futures = new ArrayList<>(); for (int srcTaskIdx = 0; srcTaskIdx < numSrcTasks; srcTaskIdx++) { if (predicate.test(srcTaskIdx)) { - final String blockIdWildcard = generateWildCardBlockId(srcTaskIdx); - futures.add(blockManagerWorker.readBlock( - blockIdWildcard, runtimeEdge.getId(), runtimeEdge.getExecutionProperties(), hashRangeToRead)); + for (int subSplitIdx = 0; subSplitIdx < numSubSplit; subSplitIdx++) { + final String blockIdWildcard = generateWildCardBlockId(srcTaskIdx, + numSubSplit == 1 ? "*" : Integer.toString(subSplitIdx)); + futures.add(blockManagerWorker.readBlock( + blockIdWildcard, runtimeEdge.getId(), runtimeEdge.getExecutionProperties(), hashRangeToRead)); + } } } - return futures; } + + // methods related to work stealing policy + + /** + * Read blocks in work stealing SPLIT strategy. + * + * @param startIndex start index (inclusive) to read. + * @param endIndex end index (exclusive) to read. + * @return the list of the completable future of the data. 
+ */ + private List> readPartial(final int startIndex, + final int endIndex) { + final Optional comValueOptional = + runtimeEdge.getPropertyValue(CommunicationPatternProperty.class); + final CommunicationPatternProperty.Value comValue = comValueOptional.orElseThrow(IllegalStateException::new); + + switch (comValue) { + case ONE_TO_ONE: + return Collections.singletonList(readOneToOne()); + case BROADCAST: + return readBroadcast(index -> startIndex <= index && index < endIndex, Optional.empty(), 1); + case SHUFFLE: + return readDataInRange(index -> startIndex <= index && index < endIndex, Optional.empty(), 1); + default: + throw new UnsupportedCommPatternException(new Exception("Communication pattern not supported")); + } + } + + /** + * Read blocks in work stealing MERGE strategy. + * + * @param srcParallelism src stage parallelism. + * @param srcNumSubSplit number of sub-split blocks per src task index. + * @return List of iterators. + */ + private List> readSplitBlocks(final int srcParallelism, + final int srcNumSubSplit) { + final Optional comValueOptional = + runtimeEdge.getPropertyValue(CommunicationPatternProperty.class); + final CommunicationPatternProperty.Value comValue = comValueOptional.orElseThrow(IllegalStateException::new); + + switch (comValue) { + case ONE_TO_ONE: + return Collections.singletonList(readOneToOne()); + case BROADCAST: + return readBroadcast(index -> true, Optional.of(srcParallelism), srcNumSubSplit); + case SHUFFLE: + return readDataInRange(index -> true, Optional.of(srcParallelism), srcNumSubSplit); + default: + throw new UnsupportedCommPatternException(new Exception("Communication pattern not supported")); + } + } + + /** + * translate index for consistency. + * + * @param workStealingState Work stealing startegy. + * @param numSubSplit number of sub-split tasks of the task with given index. + * @param index task index. 
+ * @return + */ + private int translateIndex(final String workStealingState, + final int numSubSplit, + final int index) { + if (workStealingState.equals(SPLIT_STRATEGY)) { + /* SPLIT strategy */ + int srcParallelism = InputReader.getSourceParallelism(this); + return Math.round(srcParallelism / numSubSplit) * RuntimeIdManager.getSubSplitIndexFromTaskId(dstTaskId) + index; + } else if (workStealingState.equals(MERGE_STRATEGY) + && srcVertex.getPropertyValue(WorkStealingStateProperty.class).orElse(DEFAULT_STRATEGY) + .equals(SPLIT_STRATEGY)) { + /* MERGE strategy*/ + int srcNumSubSplit = srcVertex.getPropertyValue(WorkStealingSubSplitProperty.class).orElse(1); + return index / srcNumSubSplit; + } else { + return index; + } + } } diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/BlockOutputWriter.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/BlockOutputWriter.java index 97cde037c6..2e4add6651 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/BlockOutputWriter.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/BlockOutputWriter.java @@ -72,11 +72,11 @@ public final class BlockOutputWriter implements OutputWriter { this.blockManagerWorker = blockManagerWorker; this.blockStoreValue = runtimeEdge.getPropertyValue(DataStoreProperty.class) .orElseThrow(() -> new RuntimeException("No data store property on the edge")); - blockToWrite = blockManagerWorker.createBlock( + this.blockToWrite = blockManagerWorker.createBlock( RuntimeIdManager.generateBlockId(runtimeEdge.getId(), srcTaskId), blockStoreValue); final Optional duplicateDataProperty = runtimeEdge.getPropertyValue(DuplicateEdgeGroupProperty.class); - nonDummyBlock = !duplicateDataProperty.isPresent() + this.nonDummyBlock = !duplicateDataProperty.isPresent() || duplicateDataProperty.get().getRepresentativeEdgeId().equals(runtimeEdge.getId()) || 
duplicateDataProperty.get().getGroupSize() <= 1; } diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/InputReader.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/InputReader.java index 9cd0195fdf..78c804bfc3 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/InputReader.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/InputReader.java @@ -38,6 +38,14 @@ public interface InputReader { */ List> read(); + /** Reads input data depending on the communication pattern of the srcVertex. + * + * @return the list of iterators. + */ + List> read(String workStealingState, + int maxSplitNum, + int index); + /** * Retry reading input data. * @@ -46,6 +54,17 @@ public interface InputReader { */ CompletableFuture retry(int index); + /** + * Retry reading input data during work stealing. + * @param workStealingState work stealing state (SPLIT, MERGE, DEFAULT) + * @param numSubSplit number of sub-splits in SPLIT state, default by 1 in other states. + * @param index of the failed iterator in the list returned by read(). + * @return the retried iterator. 
+ */ + CompletableFuture retry(String workStealingState, + int numSubSplit, + int index); + IRVertex getSrcIrVertex(); ExecutionPropertyMap getProperties(); diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/PipeInputReader.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/PipeInputReader.java index cab2ed2f43..e69a31b6a8 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/PipeInputReader.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/PipeInputReader.java @@ -83,11 +83,25 @@ public List> read() { } } + @Override + public List> read(final String workStealingState, + final int maxSplitNum, + final int index) { + return read(); + } + @Override public CompletableFuture retry(final int index) { throw new UnsupportedOperationException(String.valueOf(index)); } + @Override + public CompletableFuture retry(final String workStealingState, + final int numSubSplit, + final int index) { + return retry(index); + } + @Override public ExecutionPropertyMap getProperties() { return runtimeEdge.getExecutionProperties(); diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcher.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcher.java index a8ae4a9306..990b57b68f 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcher.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcher.java @@ -22,6 +22,7 @@ import org.apache.nemo.common.ir.edge.executionproperty.BlockFetchFailureProperty; import org.apache.nemo.common.ir.vertex.IRVertex; import org.apache.nemo.common.punctuation.Finishmark; +import org.apache.nemo.runtime.common.RuntimeIdManager; import org.apache.nemo.runtime.executor.data.DataUtil; import org.apache.nemo.runtime.executor.datatransfer.InputReader; 
import org.slf4j.Logger; @@ -40,6 +41,7 @@ class ParentTaskDataFetcher extends DataFetcher { private static final Logger LOG = LoggerFactory.getLogger(ParentTaskDataFetcher.class); + private final String taskId; private final InputReader inputReader; private final LinkedBlockingQueue iteratorQueue; @@ -51,11 +53,21 @@ class ParentTaskDataFetcher extends DataFetcher { private long serBytes = 0; private long encodedBytes = 0; + private final int subSplitNum; + private final String workStealingState; + + ParentTaskDataFetcher(final IRVertex dataSource, final InputReader inputReader, - final OutputCollector outputCollector) { + final OutputCollector outputCollector, + final String workStealingState, + final int subSplitNum, + final String taskId) { super(dataSource, outputCollector); + this.taskId = taskId; this.inputReader = inputReader; + this.workStealingState = workStealingState; + this.subSplitNum = subSplitNum; this.firstFetch = true; this.currentIteratorIndex = 0; this.iteratorQueue = new LinkedBlockingQueue<>(); @@ -135,7 +147,8 @@ private void handleIncomingBlock(final int index, inputReader.getSrcIrVertex().getId(), index); final int twoSecondsInMs = 2 * 1000; Thread.sleep(twoSecondsInMs); - final CompletableFuture retryFuture = inputReader.retry(index); + final CompletableFuture retryFuture = inputReader.retry( + workStealingState, subSplitNum, index); handleIncomingBlock(index, retryFuture); } else if (fetchFailure.equals(BlockFetchFailureProperty.Value.CANCEL_TASK)) { // Retry the entire task @@ -155,7 +168,8 @@ private void handleIncomingBlock(final int index, } private void fetchDataLazily() { - final List> futures = inputReader.read(); + final List> futures = inputReader + .read(workStealingState, subSplitNum, RuntimeIdManager.getSubSplitIndexFromTaskId(taskId)); this.expectedNumOfIterators = futures.size(); for (int i = 0; i < futures.size(); i++) { final int index = i; diff --git 
a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java index 2bf574d396..22bede2ed2 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java @@ -30,6 +30,8 @@ import org.apache.nemo.common.ir.vertex.IRVertex; import org.apache.nemo.common.ir.vertex.OperatorVertex; import org.apache.nemo.common.ir.vertex.SourceVertex; +import org.apache.nemo.common.ir.vertex.executionproperty.WorkStealingStateProperty; +import org.apache.nemo.common.ir.vertex.executionproperty.WorkStealingSubSplitProperty; import org.apache.nemo.common.ir.vertex.transform.MessageAggregatorTransform; import org.apache.nemo.common.ir.vertex.transform.SignalTransform; import org.apache.nemo.common.ir.vertex.transform.Transform; @@ -82,6 +84,7 @@ public final class TaskExecutor { // Dynamic optimization private String idOfVertexPutOnHold; + private String workStealingStrategy; private final PersistentConnectionToMasterMap persistentConnectionToMasterMap; @@ -120,6 +123,7 @@ public TaskExecutor(final Task task, // Prepare data structures final Pair, List> pair = prepare(task, irVertexDag, intermediateDataIOFactory); + this.workStealingStrategy = getWorkStealingStrategy(irVertexDag); this.dataFetchers = pair.left(); this.sortedHarnesses = pair.right(); @@ -287,11 +291,21 @@ irVertex, outputCollector, new TransformContextImpl(broadcastManagerWorker), parentTaskReader, dataFetcherOutputCollector)); } else { + final String workStealingState = irVertexDag.getVertices().stream() + .map(v -> v.getPropertyValue(WorkStealingStateProperty.class).orElse("DEFAULT")) + .filter(s -> !s.equals("DEFAULT")) + .findFirst().orElse("DEFAULT"); + final int numSubSplit = irVertexDag.getVertices().stream() + .mapToInt(v -> 
v.getPropertyValue(WorkStealingSubSplitProperty.class).orElse(1)) + .max().orElse(1); dataFetcherList.add( new ParentTaskDataFetcher( parentTaskReader.getSrcIrVertex(), parentTaskReader, - dataFetcherOutputCollector)); + dataFetcherOutputCollector, + workStealingState, + numSubSplit, + taskId)); } } }); @@ -714,4 +728,17 @@ private void finalizeOutputWriters(final VertexHarness vertexHarness) { metricMessageSender.send(TASK_METRIC_ID, taskId, "taskOutputBytes", SerializationUtils.serialize(totalWrittenBytes)); } + + private String getWorkStealingStrategy(final DAG> irVertexDag) { + Set strategy = irVertexDag.getVertices().stream() + .map(vertex -> vertex.getPropertyValue(WorkStealingStateProperty.class).orElse("DEFAULT")) + .collect(Collectors.toSet()); + if (strategy.contains("SPLIT")) { + return "SPLIT"; + } else if (strategy.contains("MERGE")) { + return "MERGE"; + } else { + return "DEFAULT"; + } + } } diff --git a/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/datatransfer/DataTransferTest.java b/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/datatransfer/DataTransferTest.java index 3ce99197e6..d3cfe82c15 100644 --- a/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/datatransfer/DataTransferTest.java +++ b/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/datatransfer/DataTransferTest.java @@ -532,7 +532,8 @@ private Stage setupStages(final String stageId) { IntStream.range(0, PARALLELISM_TEN).boxed().collect(Collectors.toList()), emptyDag, stageExecutionProperty, - Collections.emptyList()); + Collections.emptyList(), + 1); } /** diff --git a/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcherTest.java b/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcherTest.java index ab774e968a..832c676fdf 100644 --- a/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcherTest.java +++ 
b/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcherTest.java @@ -43,8 +43,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.*; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -59,7 +58,6 @@ public final class ParentTaskDataFetcherTest { public void testEmpty() throws Exception { final List empty = new ArrayList<>(0); // empty data final InputReader inputReader = generateInputReader(generateCompletableFuture(empty.iterator())); - // Fetcher final ParentTaskDataFetcher fetcher = createFetcher(inputReader); assertEquals(Finishmark.getInstance(), fetcher.fetchDataElement()); @@ -70,7 +68,6 @@ public void testNull() throws Exception { final List oneNull = new ArrayList<>(1); // empty data oneNull.add(null); final InputReader inputReader = generateInputReader(generateCompletableFuture(oneNull.iterator())); - // Fetcher final ParentTaskDataFetcher fetcher = createFetcher(inputReader); @@ -119,7 +116,9 @@ public void testErrorWhenFutureWithRetry() throws Exception { when(inputReader.retry(anyInt())) .thenReturn(generateCompletableFuture( empty.iterator())); // success upon retry - + when(inputReader.retry(anyString(), anyInt(), anyInt())) + .thenReturn(generateCompletableFuture( + empty.iterator())); // success upon retry // Fetcher should work on retry final ParentTaskDataFetcher fetcher = createFetcher(inputReader); assertEquals(Finishmark.getInstance(), fetcher.fetchDataElement()); @@ -141,7 +140,10 @@ private ParentTaskDataFetcher createFetcher(final InputReader readerForParentTas return new ParentTaskDataFetcher( mock(IRVertex.class), readerForParentTask, // This is the only argument that affects the behavior of ParentTaskDataFetcher - mock(OutputCollector.class)); + mock(OutputCollector.class), + "DEFAULT", 
+ 1, + "DUMMY-0-*-0"); } private InputReader generateInputReader(final CompletableFuture completableFuture, @@ -149,6 +151,7 @@ private InputReader generateInputReader(final CompletableFuture completableFutur final InputReader inputReader = mock(InputReader.class, Mockito.CALLS_REAL_METHODS); when(inputReader.getSrcIrVertex()).thenReturn(mock(IRVertex.class)); when(inputReader.read()).thenReturn(Arrays.asList(completableFuture)); + when(inputReader.read(anyString(), anyInt(), anyInt())).thenReturn(Arrays.asList(completableFuture)); final ExecutionPropertyMap propertyMap = new ExecutionPropertyMap<>(""); for (final EdgeExecutionProperty p : properties) { propertyMap.put(p); diff --git a/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/task/TaskExecutorTest.java b/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/task/TaskExecutorTest.java index 3e33fec1f4..6c6eeb79dd 100644 --- a/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/task/TaskExecutorTest.java +++ b/runtime/executor/src/test/java/org/apache/nemo/runtime/executor/task/TaskExecutorTest.java @@ -589,6 +589,7 @@ public InputReader answer(final InvocationOnMock invocationOnMock) throws Throwa srcVertex.setProperty(ParallelismProperty.of(SOURCE_PARALLELISM)); when(inputReader.getSrcIrVertex()).thenReturn(srcVertex); when(inputReader.read()).thenReturn(inputFutures); + when(inputReader.read(anyString(), anyInt(), anyInt())).thenReturn(inputFutures); when(inputReader.getProperties()).thenReturn(new ExecutionPropertyMap<>("")); return inputReader; } diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/BlockManagerMaster.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/BlockManagerMaster.java index 410ae7168e..7db30cb30b 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/BlockManagerMaster.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/BlockManagerMaster.java @@ -58,7 +58,7 @@ public final 
class BlockManagerMaster { private final Map> producerTaskIdToBlockIds; // a task can have multiple out-edges /** - * See {@link RuntimeIdManager#generateBlockIdWildcard(String, int)} for information on block wildcards. + * See {@link RuntimeIdManager#generateBlockIdWildcard(String, String)} for information on block wildcards. */ private final Map> blockIdWildcardToMetadataSet; // a metadata = a task attempt output diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/PlanStateManager.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/PlanStateManager.java index fd02da95b3..23d19f52a7 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/PlanStateManager.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/PlanStateManager.java @@ -76,7 +76,7 @@ public final class PlanStateManager { private final Map stageIdToState; // list of attempt states sorted by attempt idx - private final Map>> stageIdToTaskIdxToAttemptStates; + private final Map>>> stageIdToTaskIdxToAttemptStates; /** * Used for speculative cloning. (in the unit of milliseconds - ms) @@ -102,6 +102,11 @@ public final class PlanStateManager { private final String dagDirectory; private MetricStore metricStore; + /** + * For dynamic optimization. + */ + private final int maxSubTaskSplitNum = 5; + /** * Constructor. 
*/ @@ -168,7 +173,7 @@ private void initializeStates() { // for each task idx of this stage stage.getTaskIndices().forEach(taskIndex -> - stageIdToTaskIdxToAttemptStates.get(stage.getId()).putIfAbsent(taskIndex, new ArrayList<>())); + stageIdToTaskIdxToAttemptStates.get(stage.getId()).putIfAbsent(taskIndex, new ArrayList<>(maxSubTaskSplitNum))); // task states will be initialized lazily in getTaskAttemptsToSchedule() }); } @@ -183,6 +188,7 @@ private void initializeStates() { * @return executable task attempts */ public synchronized List getTaskAttemptsToSchedule(final String stageId) { + // initialization: 첫번째로 만들어지는케이스를 따로 생각해야 함. if (getStageState(stageId).equals(StageState.State.COMPLETE)) { // This stage is done return new ArrayList<>(0); @@ -192,46 +198,62 @@ public synchronized List getTaskAttemptsToSchedule(final String stageId) final List taskAttemptsToSchedule = new ArrayList<>(); final Stage stage = physicalPlan.getStageDAG().getVertexById(stageId); for (final int taskIndex : stage.getTaskIndices()) { - final List attemptStatesForThisTaskIndex = + final List> attemptStatesPerPartialTaskForThisTaskIndex = stageIdToTaskIdxToAttemptStates.get(stageId).get(taskIndex); - - // If one of the attempts is COMPLETE, do not schedule - if (attemptStatesForThisTaskIndex - .stream() - .noneMatch(state -> state.getStateMachine().getCurrentState().equals(TaskState.State.COMPLETE))) { - - // (Step 1) Create new READY attempts, as many as - // # of numOfConcurrentAttempts(including clones) - # of 'not-done' attempts - stageIdToTaskIndexToNumOfClones.putIfAbsent(stageId, new HashMap<>()); - final Optional cloneConf = - stage.getPropertyValue(ClonedSchedulingProperty.class); - final int numOfConcurrentAttempts = cloneConf.isPresent() && cloneConf.get().isUpFrontCloning() - // For now we support up to 1 clone (2 concurrent = 1 original + 1 clone) - ? 
2 - // If the property is not set, then we do not clone (= 1 concurrent) - : stageIdToTaskIndexToNumOfClones.get(stageId).getOrDefault(stageId, 1); - final long numOfNotDoneAttempts = attemptStatesForThisTaskIndex.stream().filter(this::isTaskNotDone).count(); - for (int i = 0; i < numOfConcurrentAttempts - numOfNotDoneAttempts; i++) { - attemptStatesForThisTaskIndex.add(new TaskState()); + if (attemptStatesPerPartialTaskForThisTaskIndex.size() == 0) { + // initialize in here + for (int i = 0; i < stage.getSubSplitNum(); i++) { + attemptStatesPerPartialTaskForThisTaskIndex.add(new ArrayList<>()); } + } + for (List attemptStatesForThisPartialTaskIndex : attemptStatesPerPartialTaskForThisTaskIndex) { - // (Step 2) Check max attempt - if (attemptStatesForThisTaskIndex.size() > maxScheduleAttempt) { - throw new RuntimeException( - attemptStatesForThisTaskIndex.size() + " exceeds max attempt " + maxScheduleAttempt); - } + // If one of the attempts is COMPLETE, do not schedule + if (attemptStatesForThisPartialTaskIndex + .stream() + .noneMatch(state -> state.getStateMachine().getCurrentState().equals(TaskState.State.COMPLETE))) { + + // (Step 1) Create new READY attempts, as many as + // # of numOfConcurrentAttempts(including clones) - # of 'not-done' attempts + stageIdToTaskIndexToNumOfClones.putIfAbsent(stageId, new HashMap<>()); + final Optional cloneConf = + stage.getPropertyValue(ClonedSchedulingProperty.class); + final int numOfConcurrentAttempts = cloneConf.isPresent() && cloneConf.get().isUpFrontCloning() + // For now we support up to 1 clone (2 concurrent = 1 original + 1 clone) + ? 
2 + // If the property is not set, then we do not clone (= 1 concurrent) + : stageIdToTaskIndexToNumOfClones.get(stageId).getOrDefault(stageId, 1); + final long numOfNotDoneAttempts = attemptStatesForThisPartialTaskIndex + .stream().filter(this::isTaskNotDone).count(); + for (int i = 0; i < numOfConcurrentAttempts - numOfNotDoneAttempts; i++) { + attemptStatesForThisPartialTaskIndex.add(new TaskState()); + } - // (Step 3) Return all READY attempts - for (int attempt = 0; attempt < attemptStatesForThisTaskIndex.size(); attempt++) { - if (attemptStatesForThisTaskIndex.get(attempt).getStateMachine().getCurrentState() - .equals(TaskState.State.READY)) { - taskAttemptsToSchedule.add(RuntimeIdManager.generateTaskId(stageId, taskIndex, attempt)); + // (Step 2) Check max attempt + if (attemptStatesForThisPartialTaskIndex.size() > maxScheduleAttempt) { + throw new RuntimeException( + attemptStatesForThisPartialTaskIndex.size() + " exceeds max attempt " + maxScheduleAttempt); } - } + // (Step 3) Return all READY attempts + for (int attempt = 0; attempt < attemptStatesForThisPartialTaskIndex.size(); attempt++) { + if (attemptStatesForThisPartialTaskIndex.get(attempt).getStateMachine().getCurrentState() + .equals(TaskState.State.READY)) { + if (attemptStatesPerPartialTaskForThisTaskIndex.size() > 1) { + + taskAttemptsToSchedule.add(RuntimeIdManager.generateWorkStealingTaskId(stageId, taskIndex, + attemptStatesPerPartialTaskForThisTaskIndex.indexOf(attemptStatesForThisPartialTaskIndex), attempt)); + } else { + taskAttemptsToSchedule.add(RuntimeIdManager.generateTaskId(stageId, taskIndex, attempt)); + } + + } + } } - } + + } + } return taskAttemptsToSchedule; } @@ -254,13 +276,15 @@ public synchronized Map getExecutingTaskToRunningTimeMs(final Stri final long curTime = System.currentTimeMillis(); final Map result = new HashMap<>(); - final Map> taskIdToState = stageIdToTaskIdxToAttemptStates.get(stageId); + final Map>> taskIdToState = 
stageIdToTaskIdxToAttemptStates.get(stageId); for (final int taskIndex : taskIdToState.keySet()) { - final List attemptStates = taskIdToState.get(taskIndex); - for (int attempt = 0; attempt < attemptStates.size(); attempt++) { - if (TaskState.State.EXECUTING.equals(attemptStates.get(attempt).getStateMachine().getCurrentState())) { - final String taskId = RuntimeIdManager.generateTaskId(stageId, taskIndex, attempt); - result.put(taskId, curTime - taskIdToStartTimeMs.get(taskId)); + final List> listOfAttemptStates = taskIdToState.get(taskIndex); + for (List attemptStates: listOfAttemptStates) { + for (int attempt = 0; attempt < attemptStates.size(); attempt++) { + if (TaskState.State.EXECUTING.equals(attemptStates.get(attempt).getStateMachine().getCurrentState())) { + final String taskId = RuntimeIdManager.generateTaskId(stageId, taskIndex, attempt); + result.put(taskId, curTime - taskIdToStartTimeMs.get(taskId)); + } } } } @@ -325,8 +349,9 @@ public synchronized void onTaskStateChanged(final String taskId, final TaskState // Log not-yet-completed tasks for us humans to track progress final String stageId = RuntimeIdManager.getStageIdFromTaskId(taskId); - final Map> taskStatesOfThisStage = stageIdToTaskIdxToAttemptStates.get(stageId); + final Map>> taskStatesOfThisStage = stageIdToTaskIdxToAttemptStates.get(stageId); final long numOfCompletedTaskIndicesInThisStage = taskStatesOfThisStage.values().stream() + .flatMap(Collection::stream) .filter(attempts -> { final List states = attempts .stream() @@ -337,8 +362,11 @@ public synchronized void onTaskStateChanged(final String taskId, final TaskState }) .count(); if (newTaskState.equals(TaskState.State.COMPLETE)) { + // 여기 나중에 고쳐야 함 + final int numOfTasksOfThisStage = taskStatesOfThisStage.values().stream() + .mapToInt(partialTasks -> partialTasks.size()).sum(); LOG.info("{} completed: {} Task(s) out of {} are remaining in this stage", - taskId, taskStatesOfThisStage.size() - numOfCompletedTaskIndicesInThisStage, 
taskStatesOfThisStage.size()); + taskId, numOfTasksOfThisStage - numOfCompletedTaskIndicesInThisStage, numOfTasksOfThisStage); } // Maintain info for speculative execution @@ -361,11 +389,12 @@ public synchronized void onTaskStateChanged(final String taskId, final TaskState } break; - // COMPLETE stage + // COMPLETE stage 여기도 고쳐야 함 case COMPLETE: case ON_HOLD: + Stage currentStage = physicalPlan.getStageDAG().getVertexById(stageId); if (numOfCompletedTaskIndicesInThisStage - == physicalPlan.getStageDAG().getVertexById(stageId).getTaskIndices().size()) { + == (long) currentStage.getTaskIndices().size() * currentStage.getSubSplitNum()) { onStageStateChanged(stageId, StageState.State.COMPLETE); } break; @@ -538,13 +567,26 @@ public synchronized TaskState.State getTaskState(final String taskId) { private Map getTaskAttemptIdsToItsState(final String stageId) { final Map result = new HashMap<>(); - final Map> taskIdToState = stageIdToTaskIdxToAttemptStates.get(stageId); - for (final int taskIndex : taskIdToState.keySet()) { - final List attemptStates = taskIdToState.get(taskIndex); - for (int attempt = 0; attempt < attemptStates.size(); attempt++) { - result.put(RuntimeIdManager.generateTaskId(stageId, taskIndex, attempt), - (TaskState.State) attemptStates.get(attempt).getStateMachine().getCurrentState()); + final Map>> taskIdToState = stageIdToTaskIdxToAttemptStates.get(stageId); + for (int taskIndex : taskIdToState.keySet()) { + List> partialIdxAttempts = taskIdToState.get(taskIndex); + if (partialIdxAttempts.size() > 1) { + for (int partialIdx = 0; partialIdx < partialIdxAttempts.size(); partialIdx++) { + List attemptStates = partialIdxAttempts.get(partialIdx); + for (int attempt = 0; attempt < attemptStates.size(); attempt++) { + result.put(RuntimeIdManager.generateWorkStealingTaskId(stageId, taskIndex, partialIdx, attempt), + (TaskState.State) attemptStates.get(attempt).getStateMachine().getCurrentState()); + } + } + } else { + for (List attemptStates : 
partialIdxAttempts) { + for (int attempt = 0; attempt < attemptStates.size(); attempt++) { + result.put(RuntimeIdManager.generateTaskId(stageId, taskIndex, attempt), + (TaskState.State) attemptStates.get(attempt).getStateMachine().getCurrentState()); + } + } } + } return result; } @@ -553,6 +595,7 @@ private TaskState getTaskStateHelper(final String taskId) { return stageIdToTaskIdxToAttemptStates .get(RuntimeIdManager.getStageIdFromTaskId(taskId)) .get(RuntimeIdManager.getIndexFromTaskId(taskId)) + .get(RuntimeIdManager.getSubSplitIndexFromTaskId(taskId)) .get(RuntimeIdManager.getAttemptFromTaskId(taskId)); } @@ -566,10 +609,11 @@ private boolean isTaskNotDone(final TaskState taskState) { private List getPeerAttemptsForTheSameTaskIndex(final String taskId) { final String stageId = RuntimeIdManager.getStageIdFromTaskId(taskId); final int taskIndex = RuntimeIdManager.getIndexFromTaskId(taskId); + final int partialIndex = RuntimeIdManager.getSubSplitIndexFromTaskId(taskId); final int attempt = RuntimeIdManager.getAttemptFromTaskId(taskId); final List otherAttemptsforTheSameTaskIndex = - new ArrayList<>(stageIdToTaskIdxToAttemptStates.get(stageId).get(taskIndex)); + new ArrayList<>(stageIdToTaskIdxToAttemptStates.get(stageId).get(taskIndex).get(partialIndex)); otherAttemptsforTheSameTaskIndex.remove(attempt); return otherAttemptsforTheSameTaskIndex.stream() diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java index d3b48f266a..a5393b3ad9 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java @@ -242,7 +242,7 @@ public void flushMetrics() { // save metric to file metricStore.dumpAllMetricToFile(Paths.get(dagDirectory, - "Metric_" + jobId + "_" + System.currentTimeMillis() + ".json").toString()); + "Metric_" + jobId + 
".json").toString()); // save metric to database if (this.dbEnabled) {