From 72a7485e5f3e796ef7efacb09cb0fe79cbadebd3 Mon Sep 17 00:00:00 2001 From: xuyang Date: Wed, 25 Mar 2026 19:06:45 +0800 Subject: [PATCH 1/3] [FLINK-39233][table-runtime] Support cascaded delta join runtime --- .../exec/stream/StreamExecDeltaJoin.java | 142 +- ...Case.scala => BinaryDeltaJoinITCase.scala} | 73 +- .../stream/sql/CascadedDeltaJoinITCase.scala | 1128 ++++++++ .../stream/sql/DeltaJoinITCaseBase.scala | 106 + .../join/deltajoin/AsyncDeltaJoinRunner.java | 15 + .../join/deltajoin/CascadedLookupHandler.java | 322 +++ .../join/deltajoin/TailOutputDataHandler.java | 66 + .../join/lookup/CalcCollectionCollector.java | 6 + .../StreamingBinaryDeltaJoinOperatorTest.java | 22 +- ...treamingCascadedDeltaJoinOperatorTest.java | 2293 +++++++++++++++++ .../StreamingDeltaJoinOperatorTestBase.java | 4 +- 11 files changed, 4075 insertions(+), 102 deletions(-) rename flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/{DeltaJoinITCase.scala => BinaryDeltaJoinITCase.scala} (94%) create mode 100644 flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/CascadedDeltaJoinITCase.scala create mode 100644 flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/DeltaJoinITCaseBase.scala create mode 100644 flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/deltajoin/CascadedLookupHandler.java create mode 100644 flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/deltajoin/TailOutputDataHandler.java create mode 100644 flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/deltajoin/StreamingCascadedDeltaJoinOperatorTest.java diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java index 409ff0c0ea1fc..bf80cf73d662d 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java @@ -49,6 +49,7 @@ import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodeUtil; import org.apache.flink.table.planner.plan.schema.TableSourceTable; import org.apache.flink.table.planner.plan.utils.DeltaJoinUtil; +import org.apache.flink.table.planner.plan.utils.FunctionCallUtil; import org.apache.flink.table.planner.plan.utils.FunctionCallUtil.AsyncOptions; import org.apache.flink.table.planner.plan.utils.KeySelectorUtil; import org.apache.flink.table.planner.plan.utils.LookupJoinUtil; @@ -60,9 +61,12 @@ import org.apache.flink.table.runtime.operators.join.FlinkJoinType; import org.apache.flink.table.runtime.operators.join.deltajoin.AsyncDeltaJoinRunner; import org.apache.flink.table.runtime.operators.join.deltajoin.BinaryLookupHandler; +import org.apache.flink.table.runtime.operators.join.deltajoin.CascadedLookupHandler; +import org.apache.flink.table.runtime.operators.join.deltajoin.DeltaJoinHandlerBase; import org.apache.flink.table.runtime.operators.join.deltajoin.DeltaJoinHandlerChain; import org.apache.flink.table.runtime.operators.join.deltajoin.DeltaJoinRuntimeTree; import org.apache.flink.table.runtime.operators.join.deltajoin.LookupHandlerBase; +import org.apache.flink.table.runtime.operators.join.deltajoin.TailOutputDataHandler; import org.apache.flink.table.runtime.typeutils.InternalSerializers; import org.apache.flink.table.runtime.typeutils.InternalTypeInfo; import org.apache.flink.table.types.logical.RowType; @@ -499,6 +503,7 @@ protected Transformation translateToPlanInternal( private static LookupHandlerBase 
generateLookupHandler( boolean isBinaryLookup, + @Nullable Integer id, // used for debug. `null` if it is a binary lookup DeltaJoinLookupChain.Node node, Map>> generatedFetcherCollector, @@ -507,6 +512,9 @@ private static LookupHandlerBase generateLookupHandler( FlinkTypeFactory typeFactory, ClassLoader classLoader, ExecNodeConfig config) { + Preconditions.checkArgument( + isBinaryLookup == (id == null), "Id should be null if it is binary lookup"); + final int[] sourceInputOrdinals = node.inputTableBinaryInputOrdinals; final int lookupTableOrdinal = node.lookupTableBinaryInputOrdinal; final RowType sourceStreamType = @@ -584,8 +592,61 @@ private static LookupHandlerBase generateLookupHandler( node.lookupTableBinaryInputOrdinal); } - // TODO FLINK-39233 Support cascaded delta join in runtime - throw new IllegalStateException("Support later"); + final RowType lookupResultPassThroughCalcRowType; + if (node.isLeftLookupRight()) { + lookupResultPassThroughCalcRowType = + combineOutputRowType( + sourceStreamType, + lookupSidePassThroughCalcRowType, + node.joinType, + typeFactory); + } else { + lookupResultPassThroughCalcRowType = + combineOutputRowType( + lookupSidePassThroughCalcRowType, + sourceStreamType, + swapJoinType(node.joinType), + typeFactory); + } + + GeneratedFilterCondition generatedRemainingCondition = + node.deltaJoinSpec + .getRemainingCondition() + .map( + remainCond -> + FilterCodeGenerator.generateFilterCondition( + config, + planner.getFlinkContext().getClassLoader(), + remainCond, + lookupResultPassThroughCalcRowType, + GENERATED_JOIN_CONDITION_CLASS_NAME)) + .orElse(null); + + final RowDataKeySelector streamSideLookupKeySelector = + KeySelectorUtil.getRowDataSelector( + classLoader, + lookupKeysOnInputSide.stream() + .mapToInt( + key -> { + Preconditions.checkState( + key instanceof FunctionCallUtil.FieldRef); + return ((FunctionCallUtil.FieldRef) key).index; + }) + .toArray(), + InternalTypeInfo.of(sourceStreamType)); + + return new 
CascadedLookupHandler( + id, + TypeConversions.fromLogicalToDataType(sourceStreamType), + lookupSideGeneratedFetcherWithType.dataType(), + TypeConversions.fromLogicalToDataType(lookupSidePassThroughCalcRowType), + InternalSerializers.create(lookupSidePassThroughCalcRowType), + lookupSideGeneratedCalc, + generatedRemainingCondition, + streamSideLookupKeySelector, + node.inputTableBinaryInputOrdinals, + node.lookupTableBinaryInputOrdinal, + node.isLeftLookupRight()); } private static RowDataKeySelector getUpsertKeySelector( @@ -600,23 +661,6 @@ private static RowDataKeySelector getUpsertKeySelector( classLoader, finalUpsertKeys, InternalTypeInfo.of(rowType)); } - private boolean enableCache(ReadableConfig config) { - return config.get(ExecutionConfigOptions.TABLE_EXEC_DELTA_JOIN_CACHE_ENABLED); - } - - /** Get the left cache size and right size. */ - private Tuple2 getCacheSize(ReadableConfig config) { - long leftCacheSize = - config.get(ExecutionConfigOptions.TABLE_EXEC_DELTA_JOIN_LEFT_CACHE_SIZE); - long rightCacheSize = - config.get(ExecutionConfigOptions.TABLE_EXEC_DELTA_JOIN_RIGHT_CACHE_SIZE); - if ((leftCacheSize <= 0 || rightCacheSize <= 0) && enableCache(config)) { - throw new IllegalArgumentException( - "Cache size in delta join must be positive when enabling cache."); - } - return Tuple2.of(leftCacheSize, rightCacheSize); - } - private abstract static class DeltaJoinOperatorFactoryBuilder { protected final PlannerBase planner; protected final ExecNodeConfig config; @@ -651,6 +695,23 @@ public DeltaJoinOperatorFactoryBuilder( } protected abstract StreamOperatorFactory build(); + + /** Get the left cache size and right size. 
*/ + protected Tuple2 getCacheSize(ReadableConfig config) { + long leftCacheSize = + config.get(ExecutionConfigOptions.TABLE_EXEC_DELTA_JOIN_LEFT_CACHE_SIZE); + long rightCacheSize = + config.get(ExecutionConfigOptions.TABLE_EXEC_DELTA_JOIN_RIGHT_CACHE_SIZE); + if ((leftCacheSize <= 0 || rightCacheSize <= 0) && enableCache(config)) { + throw new IllegalArgumentException( + "Cache size in delta join must be positive when enabling cache."); + } + return Tuple2.of(leftCacheSize, rightCacheSize); + } + + protected boolean enableCache(ReadableConfig config) { + return config.get(ExecutionConfigOptions.TABLE_EXEC_DELTA_JOIN_CACHE_ENABLED); + } } private class DeltaJoinOperatorFactoryBuilderV1 extends DeltaJoinOperatorFactoryBuilder { @@ -798,6 +859,7 @@ private DeltaJoinHandlerChain buildBinaryLookupHandlerChain( Collections.singletonList( generateLookupHandler( true, // isBinaryLookup + null, node, generatedFetcherCollector, deltaJoinTree, @@ -926,9 +988,10 @@ public StreamOperatorFactory build() { Map>> generatedFetcherCollector = new HashMap<>(); DeltaJoinHandlerChain left2RightHandlerChain = - generateDeltaJoinHandlerChain(true, generatedFetcherCollector); + generateDeltaJoinHandlerChain(true, leftStreamType, generatedFetcherCollector); DeltaJoinHandlerChain right2LeftHandlerChain = - generateDeltaJoinHandlerChain(false, generatedFetcherCollector); + generateDeltaJoinHandlerChain( + false, rightStreamType, generatedFetcherCollector); Preconditions.checkState( generatedFetcherCollector.size() == leftAllBinaryInputOrdinals.size() @@ -1008,6 +1071,7 @@ public StreamOperatorFactory build() { private DeltaJoinHandlerChain generateDeltaJoinHandlerChain( boolean lookupRight, + RowType streamRowType, Map>> generatedFetcherCollector) { int[] streamOwnedSourceOrdinals = @@ -1029,6 +1093,7 @@ private DeltaJoinHandlerChain generateDeltaJoinHandlerChain( Collections.singletonList( generateLookupHandler( true, // isBinaryLookup + null, // debug id nodes.get(0), 
generatedFetcherCollector, deltaJoinTree, @@ -1039,7 +1104,40 @@ private DeltaJoinHandlerChain generateDeltaJoinHandlerChain( streamOwnedSourceOrdinals); } - throw new UnsupportedOperationException("Support cascaded delta join operator later"); + final List lookupJoinHandlers = new ArrayList<>(); + + // build delta join handler chain + for (int i = 0; i < nodes.size(); i++) { + DeltaJoinLookupChain.Node node = nodes.get(i); + LookupHandlerBase lookupHandler = + generateLookupHandler( + false, // isBinaryLookup + i + 1, // debug id + node, + generatedFetcherCollector, + deltaJoinTree, + planner, + typeFactory, + classLoader, + config); + lookupJoinHandlers.add(lookupHandler); + } + List lookupSideAllBinaryInputOrdinals = + lookupRight ? rightAllBinaryInputOrdinals : leftAllBinaryInputOrdinals; + int lookupSideTableOffset = lookupRight ? leftAllBinaryInputOrdinals.size() : 0; + lookupJoinHandlers.add( + new TailOutputDataHandler( + lookupSideAllBinaryInputOrdinals.stream() + .mapToInt(i -> i + lookupSideTableOffset) + .toArray())); + + Preconditions.checkArgument( + streamRowType.getFieldCount() + == deltaJoinTree + .getOutputRowTypeOnNode(streamOwnedSourceOrdinals, typeFactory) + .getFieldCount()); + + return DeltaJoinHandlerChain.build(lookupJoinHandlers, streamOwnedSourceOrdinals); } private Set> getAllDrivenInputsWhenLookup(boolean lookupRight) { diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/DeltaJoinITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/BinaryDeltaJoinITCase.scala similarity index 94% rename from flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/DeltaJoinITCase.scala rename to flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/BinaryDeltaJoinITCase.scala index be5f311edbe8b..0ff1493753600 100644 --- 
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/DeltaJoinITCase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/BinaryDeltaJoinITCase.scala @@ -18,23 +18,16 @@ package org.apache.flink.table.planner.runtime.stream.sql import org.apache.flink.core.execution.CheckpointingMode -import org.apache.flink.table.api.Schema -import org.apache.flink.table.api.bridge.scala.internal.StreamTableEnvironmentImpl -import org.apache.flink.table.api.config.{ExecutionConfigOptions, OptimizerConfigOptions} -import org.apache.flink.table.api.config.OptimizerConfigOptions.DeltaJoinStrategy -import org.apache.flink.table.catalog.{CatalogTable, ObjectPath, ResolvedCatalogTable} import org.apache.flink.table.planner.{JHashMap, JMap} import org.apache.flink.table.planner.factories.TestValuesRuntimeFunctions.AsyncTestValueLookupFunction import org.apache.flink.table.planner.factories.TestValuesTableFactory import org.apache.flink.table.planner.factories.TestValuesTableFactory.changelogRow -import org.apache.flink.table.planner.runtime.utils.{FailingCollectionSource, StreamingTestBase} -import org.apache.flink.testutils.junit.extensions.parameterized.{ParameterizedTestExtension, Parameters} +import org.apache.flink.table.planner.runtime.utils.FailingCollectionSource import org.apache.flink.types.Row import org.assertj.core.api.Assertions.assertThat import org.assertj.core.util.Maps -import org.junit.jupiter.api.{BeforeEach, TestTemplate} -import org.junit.jupiter.api.extension.ExtendWith +import org.junit.jupiter.api.TestTemplate import javax.annotation.Nullable @@ -46,23 +39,8 @@ import java.util.concurrent.TimeUnit import scala.collection.JavaConversions._ import scala.collection.JavaConverters.mapAsScalaMapConverter -@ExtendWith(Array(classOf[ParameterizedTestExtension])) -class DeltaJoinITCase(enableCache: Boolean) extends StreamingTestBase { - - @BeforeEach - override def 
before(): Unit = { - super.before() - - tEnv.getConfig.set( - OptimizerConfigOptions.TABLE_OPTIMIZER_DELTA_JOIN_STRATEGY, - DeltaJoinStrategy.FORCE) - - tEnv.getConfig.set( - ExecutionConfigOptions.TABLE_EXEC_DELTA_JOIN_CACHE_ENABLED, - Boolean.box(enableCache)) - - AsyncTestValueLookupFunction.invokeCount.set(0) - } +/** Tests for binary delta join with two tables. */ +class BinaryDeltaJoinITCase(enableCache: Boolean) extends DeltaJoinITCaseBase(enableCache) { @TestTemplate def testJoinKeyEqualsIndex(): Unit = { @@ -886,38 +864,6 @@ class DeltaJoinITCase(enableCache: Boolean) extends StreamingTestBase { .build()) } - /** TODO add index in DDL. */ - private def addIndex(tableName: String, indexColumns: List[String]): Unit = { - if (indexColumns.isEmpty) { - return - } - - val catalogName = tEnv.getCurrentCatalog - val databaseName = tEnv.getCurrentDatabase - val tablePath = new ObjectPath(databaseName, tableName) - val catalog = tEnv.getCatalog(catalogName).get() - val catalogManager = tEnv.asInstanceOf[StreamTableEnvironmentImpl].getCatalogManager - val schemaResolver = catalogManager.getSchemaResolver - - val resolvedTable = catalog.getTable(tablePath).asInstanceOf[ResolvedCatalogTable] - val originTable = resolvedTable.getOrigin - val originSchema = originTable.getUnresolvedSchema - - val newSchema = Schema.newBuilder().fromSchema(originSchema).index(indexColumns).build() - - val newTable = CatalogTable - .newBuilder() - .schema(newSchema) - .comment(originTable.getComment) - .partitionKeys(originTable.getPartitionKeys) - .options(originTable.getOptions) - .build() - val newResolvedTable = new ResolvedCatalogTable(newTable, schemaResolver.resolve(newSchema)) - - catalog.dropTable(tablePath, false) - catalog.createTable(tablePath, newResolvedTable, false) - } - private def testUpsertResult(testSpec: TestSpec): Unit = { prepareTable( testSpec.leftIndex, @@ -1046,7 +992,7 @@ class DeltaJoinITCase(enableCache: Boolean) extends StreamingTestBase { | 
$leftExtraOptionsStr |) |""".stripMargin) - addIndex("testLeft", leftIndex) + addIndexesAndImmutableCols("testLeft", List(leftIndex), List()) tEnv.executeSql("drop table if exists testRight") val rightExtraOptionsStr = @@ -1080,7 +1026,7 @@ class DeltaJoinITCase(enableCache: Boolean) extends StreamingTestBase { | $rightExtraOptionsStr |) |""".stripMargin) - addIndex("testRight", rightIndex) + addIndexesAndImmutableCols("testRight", List(rightIndex), List()) tEnv.executeSql("drop table if exists testSnk") tEnv.executeSql(s""" @@ -1268,10 +1214,3 @@ class DeltaJoinITCase(enableCache: Boolean) extends StreamingTestBase { } } - -object DeltaJoinITCase { - @Parameters(name = "EnableCache={0}") - def parameters(): java.util.Collection[Boolean] = { - Seq[Boolean](true, false) - } -} diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/CascadedDeltaJoinITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/CascadedDeltaJoinITCase.scala new file mode 100644 index 0000000000000..b10edc827a46d --- /dev/null +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/CascadedDeltaJoinITCase.scala @@ -0,0 +1,1128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.runtime.stream.sql + +import org.apache.flink.table.api.config.{ExecutionConfigOptions, OptimizerConfigOptions} +import org.apache.flink.table.api.config.OptimizerConfigOptions.DeltaJoinStrategy +import org.apache.flink.table.planner.factories.TestValuesRuntimeFunctions.AsyncTestValueLookupFunction +import org.apache.flink.table.planner.factories.TestValuesTableFactory +import org.apache.flink.table.planner.factories.TestValuesTableFactory.changelogRow +import org.apache.flink.types.Row +import org.apache.flink.util.Preconditions + +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.{BeforeEach, TestTemplate} + +import scala.collection.JavaConversions._ + +/** Tests for multi cascaded delta join with multi tables. */ +class CascadedDeltaJoinITCase(enableCache: Boolean) extends DeltaJoinITCaseBase(enableCache) { + + @BeforeEach + override def before(): Unit = { + super.before() + + tEnv.getConfig.set( + OptimizerConfigOptions.TABLE_OPTIMIZER_DELTA_JOIN_STRATEGY, + DeltaJoinStrategy.FORCE) + + tEnv.getConfig.set( + ExecutionConfigOptions.TABLE_EXEC_DELTA_JOIN_CACHE_ENABLED, + Boolean.box(enableCache)) + + AsyncTestValueLookupFunction.invokeCount.set(0) + + tEnv.executeSql(s""" + |create table sink1( + | a1 int, + | c0 double, + | c2 string, + | a2 string, + | b2 string, + | primary key (a1, c0) not enforced + |) with ( + | 'connector' = 'values', + | 'bounded' = 'false', + | 'sink-insert-only' = 'false' + |) + |""".stripMargin) + + tEnv.executeSql(s""" + |create table sink2( + | a1 int, + | c0 double, + | d1 int, + | c2 string, + | d2 string, + | a2 string, + | b2 string, + | primary key (a1, c0, d1) not enforced + |) with ( + | 'connector' = 'values', + | 'bounded' = 'false', + | 'sink-insert-only' = 'false' + |) + |""".stripMargin) + } + + @TestTemplate + def testLHS1(): Unit = { + // 
DT2 + // / \ + // DT1 C + // / \ + // A B + // when records from C come, lookup chain is: + // C -> B -> A + prepareThreeTables() + + val sql = + """ + |insert into sink1 + |select a1, c0, c2, a2, b2 + | from A + |join B + | on a1 = b1 and a0 = b0 + |join C + | on c1 = b1 and c1 <> 1 + | + |""".stripMargin + tEnv.executeSql(sql).await() + + val expected = List("+I[2, 2.0, c-2-2, a-2, b-2-2]", "+I[2, 3.0, c-3, a-2, b-2-2]") + + val result = TestValuesTableFactory.getResultsAsStrings("sink1") + assertThat(result.sorted).isEqualTo(expected.sorted) + + // | DT | Lookup Count Without Cache | Lookup Count With Cache | + // | ------------ | --------------------------- | ----------------------- | + // | DT1 | 2 + 4 = 6 | 2 + 3 = 5 | + // | DT2 | 3 + 9 = 12 | 1 + 5 = 6 | + // | TOTAL | 18 | 11 | + // | ------------ | --------------------------- | ----------------------- | + assertThat(AsyncTestValueLookupFunction.invokeCount.get()) + .isEqualTo(if (enableCache) 11 else 18) + } + + @TestTemplate + def testLHS2(): Unit = { + // DT2 + // / \ + // DT1 C + // / \ + // A B + // when records from C come, lookup chain is: + // C -> A -> B + prepareThreeTables() + + val sql = + """ + |insert into sink1 + |select a1, c0, c2, a2, b2 + | from A + |join B + | on a1 = b1 and a0 = b0 + |join C + | on c1 = a1 and c1 <> 1 + | + |""".stripMargin + tEnv.executeSql(sql).await() + + val expected = List("+I[2, 2.0, c-2-2, a-2, b-2-2]", "+I[2, 3.0, c-3, a-2, b-2-2]") + + val result = TestValuesTableFactory.getResultsAsStrings("sink1") + assertThat(result.sorted).isEqualTo(expected.sorted) + + // | DT | Lookup Count Without Cache | Lookup Count With Cache | + // | ------------ | --------------------------- | ----------------------- | + // | DT1 | 2 + 4 = 6 | 2 + 3 = 5 | + // | DT2 | 3 + 9 = 12 | 1 + 5 = 6 | + // | TOTAL | 18 | 11 | + // | ------------ | --------------------------- | ----------------------- | + assertThat(AsyncTestValueLookupFunction.invokeCount.get()) + .isEqualTo(if (enableCache) 
11 else 18) + } + + @TestTemplate + def testMultiLHS1(): Unit = { + // DT3 + // / \ + // DT2 D + // / \ + // DT1 C + // / \ + // A B + // when records from C come, lookup chain is: + // C -> B -> A + // when records from D come, lookup chain is: + // D -> C -> B -> A + prepareFourTables() + val sql = + """ + |insert into sink2 + |select a1, c0, d1, c2, d2, a2, b2 + | from A + |join B + | on a1 = b1 and a0 = b0 + |join C + | on c1 = b1 and c1 <> 2 + |join D + | on d1 = c1 + | + |""".stripMargin + tEnv.executeSql(sql).await() + + val expected = + List("+I[1, 1.0, 1, c-1, d-2, a-1-2, b-1-2]", "+I[1, 3.0, 1, c-3, d-2, a-1-2, b-1-2]") + + val result = TestValuesTableFactory.getResultsAsStrings("sink2") + assertThat(result.sorted).isEqualTo(expected.sorted) + + // | DT | Lookup Count Without Cache | Lookup Count With Cache | + // | ------------ | --------------------------- | ----------------------- | + // | DT1 | 3 + 4 = 7 | 2 + 3 = 5 | + // | DT2 | 4 + 9 = 13 | 1 + 5 = 6 | + // | DT3 | 10 + 10 = 20 | 1 + 7 = 8 | + // | TOTAL | 40 | 19 | + // | ------------ | --------------------------- | ----------------------- | + assertThat(AsyncTestValueLookupFunction.invokeCount.get()) + .isEqualTo(if (enableCache) 19 else 40) + } + + @TestTemplate + def testMultiLHS2(): Unit = { + // DT3 + // / \ + // DT2 D + // / \ + // DT1 C + // / \ + // A B + // when records from C come, lookup chain is: + // C -> B -> A + // when records from D come, lookup chain is: + // D -> A -> B -> C + prepareFourTables() + val sql = + """ + |insert into sink2 + |select a1, c0, d1, c2, d2, a2, b2 + | from A + |join B + | on a1 = b1 and a0 = b0 + |join C + | on c1 = b1 and c1 <> 2 + |join D + | on d1 = a1 + | + |""".stripMargin + tEnv.executeSql(sql).await() + + val expected = + List("+I[1, 1.0, 1, c-1, d-2, a-1-2, b-1-2]", "+I[1, 3.0, 1, c-3, d-2, a-1-2, b-1-2]") + + val result = TestValuesTableFactory.getResultsAsStrings("sink2") + assertThat(result.sorted).isEqualTo(expected.sorted) + + // | DT | 
Lookup Count Without Cache | Lookup Count With Cache | + // | ------------ | --------------------------- | ----------------------- | + // | DT1 | 3 + 4 = 7 | 2 + 3 = 5 | + // | DT2 | 4 + 9 = 13 | 1 + 5 = 6 | + // | DT3 | 10 + 9 = 19 | 1 + 6 = 7 | + // | TOTAL | 39 | 18 | + // | ------------ | --------------------------- | ----------------------- | + assertThat(AsyncTestValueLookupFunction.invokeCount.get()) + .isEqualTo(if (enableCache) 18 else 39) + } + + @TestTemplate + def testMultiLHS3(): Unit = { + // DT3 + // / \ + // DT2 D + // / \ + // DT1 C + // / \ + // A B + // when records from C come, lookup chain is: + // C -> B -> A + // when records from D come, lookup chain is: + // D -> B -> A -> C + prepareFourTables() + val sql = + """ + |insert into sink2 + |select a1, c0, d1, c2, d2, a2, b2 + | from A + |join B + | on a1 = b1 and a0 = b0 + |join C + | on c1 = b1 and c1 <> 2 + |join D + | on d1 = b1 + | + |""".stripMargin + tEnv.executeSql(sql).await() + + val expected = + List("+I[1, 1.0, 1, c-1, d-2, a-1-2, b-1-2]", "+I[1, 3.0, 1, c-3, d-2, a-1-2, b-1-2]") + + val result = TestValuesTableFactory.getResultsAsStrings("sink2") + assertThat(result.sorted).isEqualTo(expected.sorted) + + // | DT | Lookup Count Without Cache | Lookup Count With Cache | + // | ------------ | --------------------------- | ----------------------- | + // | DT1 | 3 + 4 = 7 | 2 + 3 = 5 | + // | DT2 | 4 + 9 = 13 | 1 + 5 = 6 | + // | DT3 | 10 + 9 = 19 | 1 + 6 = 7 | + // | TOTAL | 39 | 18 | + // | ------------ | --------------------------- | ----------------------- | + assertThat(AsyncTestValueLookupFunction.invokeCount.get()) + .isEqualTo(if (enableCache) 18 else 39) + } + + @TestTemplate + def testRHS1(): Unit = { + // DT2 + // / \ + // C DT1 + // / \ + // A B + // when records from C come, lookup chain is: + // C -> A -> B + prepareThreeTables() + + tEnv.executeSql(""" + |create temporary view myv as + |select * + | from A + |join B + | on a1 = b1 and a0 <> b0 + |""".stripMargin) + + tEnv 
+ .executeSql(""" + |insert into sink1 + |select a1, c0, c2, a2, b2 + | from C + |join myv + | on c1 = a1 and c1 <> 2 + |""".stripMargin) + .await() + + val expected = List("+I[3, 32.0, c-3-2, a-3, b-3]") + + val result = TestValuesTableFactory.getResultsAsStrings("sink1") + assertThat(result.sorted).isEqualTo(expected.sorted) + + // | DT | Lookup Count Without Cache | Lookup Count With Cache | + // | ------------ | --------------------------- | ----------------------- | + // | DT1 | 2 + 4 = 6 | 2 + 3 = 5 | + // | DT2 | 5 + 2 = 7 | 5 + 1 = 6 | + // | TOTAL | 13 | 11 | + // | ------------ | --------------------------- | ----------------------- | + assertThat(AsyncTestValueLookupFunction.invokeCount.get()) + .isEqualTo(if (enableCache) 11 else 13) + } + + @TestTemplate + def testRHS2(): Unit = { + // DT2 + // / \ + // C DT1 + // / \ + // A B + // when records from C come, lookup chain is: + // C -> B -> A + prepareThreeTables() + + tEnv.executeSql(""" + |create temporary view myv as + |select * + | from A + |join B + | on a1 = b1 and a0 <> b0 + |""".stripMargin) + + tEnv + .executeSql(""" + |insert into sink1 + |select a1, c0, c2, a2, b2 + | from C + |join myv + | on c1 = b1 and c1 <> 2 + |""".stripMargin) + .await() + + val expected = List("+I[3, 32.0, c-3-2, a-3, b-3]") + + val result = TestValuesTableFactory.getResultsAsStrings("sink1") + assertThat(result.sorted).isEqualTo(expected.sorted) + + // | DT | Lookup Count Without Cache | Lookup Count With Cache | + // | ------------ | --------------------------- | ----------------------- | + // | DT1 | 2 + 4 = 6 | 2 + 3 = 5 | + // | DT2 | 5 + 2 = 7 | 5 + 1 = 6 | + // | TOTAL | 13 | 11 | + // | ------------ | --------------------------- | ----------------------- | + assertThat(AsyncTestValueLookupFunction.invokeCount.get()) + .isEqualTo(if (enableCache) 11 else 13) + } + + @TestTemplate + def testMultiRHS1(): Unit = { + // DT3 + // / \ + // D DT2 + // / \ + // C DT1 + // / \ + // A B + // when records from C come, 
lookup chain is: + // C -> B -> A + // when records from D come, lookup chain is: + // D -> C -> B -> A + prepareFourTables() + + tEnv.executeSql(""" + |create temporary view dt1 as + |select * from A + |join B + | on a1 = b1 and a0 <> b0 + |""".stripMargin) + + tEnv.executeSql(""" + |create temporary view dt2 as + |select * from C + |join dt1 + | on c1 = b1 and c1 <> 2 + |""".stripMargin) + + tEnv + .executeSql(""" + |insert into sink2 + |select a1, c0, d1, c2, d2, a2, b2 + | from D + |join dt2 + | on d1 = c1 + |""".stripMargin) + .await() + + val expected = + List("+I[3, 32.0, 3, c-3-2, d-3, a-3, b-3]", "+I[3, 33.0, 3, c-3-3, d-3, a-3, b-3]") + + val result = TestValuesTableFactory.getResultsAsStrings("sink2") + assertThat(result.sorted).isEqualTo(expected.sorted) + + // | DT | Lookup Count Without Cache | Lookup Count With Cache | + // | ------------ | --------------------------- | ----------------------- | + // | DT1 | 3 + 4 = 7 | 2 + 3 = 5 | + // | DT2 | 9 + 2 = 11 | 5 + 1 = 6 | + // | DT3 | 10 + 6 = 16 | 7 + 1 = 8 | + // | TOTAL | 34 | 19 | + // | ------------ | --------------------------- | ----------------------- | + assertThat(AsyncTestValueLookupFunction.invokeCount.get()) + .isEqualTo(if (enableCache) 19 else 34) + } + + @TestTemplate + def testMultiRHS2(): Unit = { + // DT3 + // / \ + // D DT2 + // / \ + // C DT1 + // / \ + // A B + // when records from C come, lookup chain is: + // C -> B -> A + // when records from D come, lookup chain is: + // D -> A -> B -> C + prepareFourTables() + + tEnv.executeSql(""" + |create temporary view dt1 as + |select * from A + |join B + | on a1 = b1 and a0 <> b0 + |""".stripMargin) + + tEnv.executeSql(""" + |create temporary view dt2 as + |select * from C + |join dt1 + | on c1 = b1 and c1 <> 2 + |""".stripMargin) + + tEnv + .executeSql(""" + |insert into sink2 + |select a1, c0, d1, c2, d2, a2, b2 + | from D + |join dt2 + | on d1 = a1 + |""".stripMargin) + .await() + + val expected = + List("+I[3, 32.0, 3, c-3-2, d-3, 
a-3, b-3]", "+I[3, 33.0, 3, c-3-3, d-3, a-3, b-3]") + + val result = TestValuesTableFactory.getResultsAsStrings("sink2") + assertThat(result.sorted).isEqualTo(expected.sorted) + + // | DT | Lookup Count Without Cache | Lookup Count With Cache | + // | ------------ | --------------------------- | ----------------------- | + // | DT1 | 3 + 4 = 7 | 2 + 3 = 5 | + // | DT2 | 9 + 2 = 11 | 5 + 1 = 6 | + // | DT3 | 8 + 6 = 14 | 6 + 1 = 7 | + // | TOTAL | 32 | 18 | + // | ------------ | --------------------------- | ----------------------- | + assertThat(AsyncTestValueLookupFunction.invokeCount.get()) + .isEqualTo(if (enableCache) 18 else 32) + } + + @TestTemplate + def testMultiRHS3(): Unit = { + // DT3 + // / \ + // D DT2 + // / \ + // C DT1 + // / \ + // A B + // when records from C come, lookup chain is: + // C -> B -> A + // when records from D come, lookup chain is: + // D -> B -> A -> C + prepareFourTables() + + tEnv.executeSql(""" + |create temporary view dt1 as + |select * from A + |join B + | on a1 = b1 and a0 <> b0 + |""".stripMargin) + + tEnv.executeSql(""" + |create temporary view dt2 as + |select * from C + |join dt1 + | on c1 = b1 and c1 <> 2 + |""".stripMargin) + + tEnv + .executeSql(""" + |insert into sink2 + |select a1, c0, d1, c2, d2, a2, b2 + | from D + |join dt2 + | on d1 = b1 + |""".stripMargin) + .await() + + val expected = + List("+I[3, 32.0, 3, c-3-2, d-3, a-3, b-3]", "+I[3, 33.0, 3, c-3-3, d-3, a-3, b-3]") + + val result = TestValuesTableFactory.getResultsAsStrings("sink2") + assertThat(result.sorted).isEqualTo(expected.sorted) + + // | DT | Lookup Count Without Cache | Lookup Count With Cache | + // | ------------ | --------------------------- | ----------------------- | + // | DT1 | 3 + 4 = 7 | 2 + 3 = 5 | + // | DT2 | 9 + 2 = 11 | 5 + 1 = 6 | + // | DT3 | 8 + 6 = 14 | 6 + 1 = 7 | + // | TOTAL | 32 | 18 | + // | ------------ | --------------------------- | ----------------------- | + assertThat(AsyncTestValueLookupFunction.invokeCount.get()) + 
.isEqualTo(if (enableCache) 18 else 32) + } + + @TestTemplate + def testBushy1(): Unit = { + // DT3 + // / \ + // DT1 DT2 + // / \ / \ + // A B C D + // when records from DT1 come, lookup chain is: + // DT-1 -> C -> D + // when records from DT2 come, lookup chain is: + // DT-2 -> B -> A + prepareFourTables() + + tEnv.executeSql(""" + |create temporary view dt1 as + |select * from A + |join B + | on a1 = b1 and a0 <> b0 + |""".stripMargin) + + tEnv.executeSql(""" + |create temporary view dt2 as + |select * from C + |join D + | on c1 = d1 and c0 <> 32.0 + |""".stripMargin) + + tEnv + .executeSql( + """ + |insert into sink2 + |select a1, c0, d1, c2, d2, a2, b2 + | from dt1 + |join dt2 + | on c1 = b1 + |""".stripMargin + ) + .await() + + val expected = List("+I[3, 33.0, 3, c-3-3, d-3, a-3, b-3]") + + val result = TestValuesTableFactory.getResultsAsStrings("sink2") + assertThat(result.sorted).isEqualTo(expected.sorted) + + // | DT | Lookup Count Without Cache | Lookup Count With Cache | + // | ------------ | --------------------------- | ----------------------- | + // | DT1 | 4 + 5 = 9 | 3 + 4 = 7 | + // | DT2 | 5 + 4 = 9 | 4 + 3 = 7 | + // | DT3 | 4 + 16 = 20 | 2 + 4 = 6 | + // | TOTAL | 38 | 20 | + // | ------------ | --------------------------- | ----------------------- | + assertThat(AsyncTestValueLookupFunction.invokeCount.get()) + .isEqualTo(if (enableCache) 20 else 38) + } + + @TestTemplate + def testBushy2(): Unit = { + // DT3 + // / \ + // DT1 DT2 + // / \ / \ + // A B C D + // when records from DT1 come, lookup chain is: + // DT-1 -> D -> C + // when records from DT2 come, lookup chain is: + // DT-2 -> A -> B + prepareFourTables() + + tEnv.executeSql(""" + |create temporary view dt1 as + |select * from A + |join B + | on a1 = b1 and a0 <> b0 + |""".stripMargin) + + tEnv.executeSql(""" + |create temporary view dt2 as + |select * from C + |join D + | on c1 = d1 and c0 <> 32.0 + |""".stripMargin) + + tEnv + .executeSql( + """ + |insert into sink2 + |select a1, 
c0, d1, c2, d2, a2, b2 + | from dt1 + |join dt2 + | on d1 = a1 + |""".stripMargin + ) + .await() + + val expected = List("+I[3, 33.0, 3, c-3-3, d-3, a-3, b-3]") + + val result = TestValuesTableFactory.getResultsAsStrings("sink2") + assertThat(result.sorted).isEqualTo(expected.sorted) + + // | DT | Lookup Count Without Cache | Lookup Count With Cache | + // | ------------ | --------------------------- | ----------------------- | + // | DT1 | 4 + 5 = 9 | 3 + 4 = 7 | + // | DT2 | 5 + 4 = 9 | 4 + 3 = 7 | + // | DT3 | 4 + 16 = 20 | 2 + 4 = 6 | + // | TOTAL | 38 | 20 | + // | ------------ | --------------------------- | ----------------------- | + assertThat(AsyncTestValueLookupFunction.invokeCount.get()) + .isEqualTo(if (enableCache) 20 else 38) + } + + @TestTemplate + def testCalcExistsBothBetweenSourceAndJoinAndCascadedJoins(): Unit = { + // DT2 + // / \ + // DT1 C + // / \ + // A B + // when records from C come, lookup chain is: + // C -> B -> A + prepareThreeTables() + + tEnv.executeSql(""" + |create temporary view dt1 as + |select + | a0, a1, b1, + | trim(a2) as new_a2, + | concat_ws('~', a2, b2) as ab2 + | from A + |join B + | on a1 = b1 and b0 <> 1.0 + |""".stripMargin) + + val sql = + """ + |insert into sink1 + |select a1, c0, c2, new_a2, ab2 + | from dt1 + |join C + | on c1 = b1 and c0 <> 3.0 + |""".stripMargin + tEnv.executeSql(sql).await() + + val expected = List("+I[2, 2.0, c-2-2, a-2, a-2~b-2-2]", "+I[3, 32.0, c-3-2, a-3, a-3~b-3]") + + val result = TestValuesTableFactory.getResultsAsStrings("sink1") + assertThat(result.sorted).isEqualTo(expected.sorted) + + // | DT | Lookup Count Without Cache | Lookup Count With Cache | + // | ------------ | --------------------------- | ----------------------- | + // | DT1 | 3 + 4 = 7 | 3 + 3 = 6 | + // | DT2 | 4 + 9 = 13 | 2 + 6 = 8 | + // | TOTAL | 20 | 14 | + // | ------------ | --------------------------- | ----------------------- | + assertThat(AsyncTestValueLookupFunction.invokeCount.get()) + .isEqualTo(if 
(enableCache) 14 else 20) + } + + @TestTemplate + def testConsecutiveOneToManyJoins(): Unit = { + // DT2 + // / \ + // DT1 C + // / \ + // A B + // when records from C come, lookup chain is: + // C -> B -> A + prepareSrcTableWithData( + "A", + List("a0 double primary key not enforced", "a1 int", "a2 string"), + List( + changelogRow("+I", Double.box(1.0), Int.box(1), String.valueOf("a-1")), + changelogRow("+I", Double.box(2.0), Int.box(2), String.valueOf("a-2")) + ), + List(List("a1")), + List("a1") + ) + prepareSrcTableWithData( + "B", + List("b0 double primary key not enforced", "b1 int", "b2 string"), + List( + changelogRow("+I", Double.box(1.0), Int.box(1), String.valueOf("b-1")), + changelogRow("+I", Double.box(11.0), Int.box(1), String.valueOf("b-11")), + changelogRow("+I", Double.box(2.0), Int.box(2), String.valueOf("b-2")), + changelogRow("+I", Double.box(22.0), Int.box(2), String.valueOf("b-22")) + ), + List(List("b1")), + List("b1") + ) + prepareSrcTableWithData( + "C", + List("c0 double primary key not enforced", "c1 int", "c2 string"), + List( + changelogRow("+I", Double.box(1.0), Int.box(1), String.valueOf("c-1")), + changelogRow("+I", Double.box(11.0), Int.box(1), String.valueOf("c-11")), + changelogRow("+I", Double.box(2.0), Int.box(2), String.valueOf("c-2")), + changelogRow("+I", Double.box(22.0), Int.box(2), String.valueOf("c-22")) + ), + List(List("c1")), + List("c1") + ) + + tEnv.executeSql(s""" + |create table tmp_sink( + | a0 double, + | b0 double, + | c0 double, + | abc2 string, + | primary key (a0, b0, c0) not enforced + |) with ( + | 'connector' = 'values', + | 'bounded' = 'false', + | 'sink-insert-only' = 'false' + |) + |""".stripMargin) + + val sql = + """ + |insert into tmp_sink + |select a0, b0, c0, concat_ws('~', a2, b2, c2) as abc2 + | from A + |join B + | on a1 = b1 + |join C + | on c1 = b1 + | + |""".stripMargin + tEnv.executeSql(sql).await() + + val expected = List( + "+I[1.0, 1.0, 1.0, a-1~b-1~c-1]", + "+I[1.0, 1.0, 11.0, 
a-1~b-1~c-11]", + "+I[1.0, 11.0, 1.0, a-1~b-11~c-1]", + "+I[1.0, 11.0, 11.0, a-1~b-11~c-11]", + "+I[2.0, 2.0, 2.0, a-2~b-2~c-2]", + "+I[2.0, 2.0, 22.0, a-2~b-2~c-22]", + "+I[2.0, 22.0, 2.0, a-2~b-22~c-2]", + "+I[2.0, 22.0, 22.0, a-2~b-22~c-22]" + ) + + val result = TestValuesTableFactory.getResultsAsStrings("tmp_sink") + assertThat(result.sorted).isEqualTo(expected.sorted) + + // | DT | Lookup Count Without Cache | Lookup Count With Cache | + // | ------------ | --------------------------- | ----------------------- | + // | DT1 | 2 + 4 = 6 | 2 + 2 = 4 | + // | DT2 | 8 + 8 = 16 | 2 + 4 = 6 | + // | TOTAL | 22 | 10 | + // | ------------ | --------------------------- | ----------------------- | + assertThat(AsyncTestValueLookupFunction.invokeCount.get()) + .isEqualTo(if (enableCache) 10 else 22) + } + + @TestTemplate + def testJoinTwoTablesWhileJoinKeyChanged1(): Unit = { + // DT2 + // / \ + // DT1 C + // / \ + // A B + // when records from C come, lookup chain is: + // C -> B -> A + prepareSrcTableWithData( + "A", + List("a0 double primary key not enforced", "a_key string", "a_b_key string"), + List( + changelogRow("+I", Double.box(1.0), String.valueOf("a-1"), String.valueOf("b-1")), + changelogRow("+I", Double.box(2.0), String.valueOf("a-2"), String.valueOf("b-2")) + ), + List(List("a_b_key")), + List("a_b_key") + ) + prepareSrcTableWithData( + "B", + List("b0 double primary key not enforced", "b_key string", "b_c_key string"), + List( + changelogRow("+I", Double.box(1.0), String.valueOf("b-1"), String.valueOf("c-1")), + changelogRow("+I", Double.box(11.0), String.valueOf("b-1"), String.valueOf("c-2")), + changelogRow("+I", Double.box(2.0), String.valueOf("b-2"), String.valueOf("c-1")), + changelogRow("+I", Double.box(22.0), String.valueOf("b-2"), String.valueOf("c-2")) + ), + List(List("b_key"), List("b_c_key")), + List("b_key", "b_c_key") + ) + prepareSrcTableWithData( + "C", + List("c0 double primary key not enforced", "c_key string"), + List( + 
changelogRow("+I", Double.box(1.0), String.valueOf("c-1")), + changelogRow("+I", Double.box(2.0), String.valueOf("c-2")) + ), + List(List("c_key")), + List("c_key") + ) + + tEnv.executeSql(s""" + |create table tmp_sink( + | a0 double, + | b0 double, + | c0 double, + | abc_key string, + | primary key (a0, b0, c0) not enforced + |) with ( + | 'connector' = 'values', + | 'bounded' = 'false', + | 'sink-insert-only' = 'false' + |) + |""".stripMargin) + + val sql = + """ + |insert into tmp_sink + |select a0, b0, c0, concat_ws('~', a_key, b_key, c_key) as abc_key + | from A + |join B + | on a_b_key = b_key + |join C + | on c_key = b_c_key + | + |""".stripMargin + tEnv.executeSql(sql).await() + + val expected = List( + "+I[1.0, 1.0, 1.0, a-1~b-1~c-1]", + "+I[1.0, 11.0, 2.0, a-1~b-1~c-2]", + "+I[2.0, 2.0, 1.0, a-2~b-2~c-1]", + "+I[2.0, 22.0, 2.0, a-2~b-2~c-2]" + ) + + val result = TestValuesTableFactory.getResultsAsStrings("tmp_sink") + assertThat(result.sorted).isEqualTo(expected.sorted) + + // | DT | Lookup Count Without Cache | Lookup Count With Cache | + // | ------------ | --------------------------- | ----------------------- | + // | DT1 | 2 + 4 = 6 | 2 + 2 = 4 | + // | DT2 | 8 + 6 = 14 | 2 + 6 = 8 | + // | TOTAL | 20 | 12 | + // | ------------ | --------------------------- | ----------------------- | + assertThat(AsyncTestValueLookupFunction.invokeCount.get()) + .isEqualTo(if (enableCache) 12 else 20) + } + + @TestTemplate + def testJoinTwoTablesWhileJoinKeyChanged2(): Unit = { + // DT2 + // / \ + // DT1 C + // / \ + // A B + // when records from C come, lookup chain is: + // C -> A -> B + prepareSrcTableWithData( + "A", + List( + "a0 double primary key not enforced", + "a_key string", + "a_b_key string", + "a_c_key string"), + List( + changelogRow( + "+I", + Double.box(1.0), + String.valueOf("a-1"), + String.valueOf("b-1"), + String.valueOf("c-1")), + changelogRow( + "+I", + Double.box(2.0), + String.valueOf("a-2"), + String.valueOf("b-2"), + String.valueOf("c-2")) 
+ ), + List(List("a_b_key"), List("a_c_key")), + List("a_b_key", "a_c_key") + ) + prepareSrcTableWithData( + "B", + List("b0 double primary key not enforced", "b_key string"), + List( + changelogRow("+I", Double.box(1.0), String.valueOf("b-1")), + changelogRow("+I", Double.box(11.0), String.valueOf("b-1")), + changelogRow("+I", Double.box(2.0), String.valueOf("b-2")), + changelogRow("+I", Double.box(22.0), String.valueOf("b-2")) + ), + List(List("b_key")), + List("b_key") + ) + prepareSrcTableWithData( + "C", + List("c0 double primary key not enforced", "c_key string"), + List( + changelogRow("+I", Double.box(1.0), String.valueOf("c-1")), + changelogRow("+I", Double.box(2.0), String.valueOf("c-2")) + ), + List(List("c_key")), + List("c_key") + ) + + tEnv.executeSql(s""" + |create table tmp_sink( + | a0 double, + | b0 double, + | c0 double, + | abc_key string, + | primary key (a0, b0, c0) not enforced + |) with ( + | 'connector' = 'values', + | 'bounded' = 'false', + | 'sink-insert-only' = 'false' + |) + |""".stripMargin) + + val sql = + """ + |insert into tmp_sink + |select a0, b0, c0, concat_ws('~', a_key, a_b_key, a_c_key) as abc_key + | from A + |join B + | on a_b_key = b_key + |join C + | on a_c_key = c_key + | + |""".stripMargin + tEnv.executeSql(sql).await() + + val expected = List( + "+I[1.0, 1.0, 1.0, a-1~b-1~c-1]", + "+I[1.0, 11.0, 1.0, a-1~b-1~c-1]", + "+I[2.0, 2.0, 2.0, a-2~b-2~c-2]", + "+I[2.0, 22.0, 2.0, a-2~b-2~c-2]" + ) + + val result = TestValuesTableFactory.getResultsAsStrings("tmp_sink") + assertThat(result.sorted).isEqualTo(expected.sorted) + + // | DT | Lookup Count Without Cache | Lookup Count With Cache | + // | ------------ | --------------------------- | ----------------------- | + // | DT1 | 2 + 4 = 6 | 2 + 2 = 4 | + // | DT2 | 8 + 4 = 12 | 2 + 4 = 6 | + // | TOTAL | 18 | 10 | + // | ------------ | --------------------------- | ----------------------- | + assertThat(AsyncTestValueLookupFunction.invokeCount.get()) + .isEqualTo(if 
(enableCache) 10 else 18) + } + + private def prepareThreeTables(): Unit = { + prepareSrcTableWithData( + "A", + List("a0 double", "a1 int primary key not enforced", "a2 string"), + List( + changelogRow("+I", Double.box(1.0), Int.box(1), String.valueOf("a-1")), + changelogRow("+I", Double.box(2.0), Int.box(2), String.valueOf("a-2")), + changelogRow("+I", Double.box(3.0), Int.box(3), String.valueOf("a-3")) + ), + List(List("a1")), + List("a0") + ) + prepareSrcTableWithData( + "B", + List("b1 int primary key not enforced", "b0 double", "b2 string"), + List( + changelogRow("+I", Int.box(1), Double.box(1.0), String.valueOf("b-1")), + changelogRow("-U", Int.box(1), Double.box(1.0), String.valueOf("b-1")), + changelogRow("+U", Int.box(1), Double.box(1.0), String.valueOf("b-1-2")), + changelogRow("+I", Int.box(2), Double.box(2.0), String.valueOf("b-2")), + changelogRow("-U", Int.box(2), Double.box(2.0), String.valueOf("b-2")), + changelogRow("+U", Int.box(2), Double.box(2.0), String.valueOf("b-2-2")), + changelogRow("+I", Int.box(3), Double.box(4.0), String.valueOf("b-3")), + changelogRow("+I", Int.box(13), Double.box(13.0), String.valueOf("b-13")) + ), + List(List("b1")), + List("b0") + ) + prepareSrcTableWithData( + "C", + List("c1 int", "c2 string", "c0 double primary key not enforced"), + List( + changelogRow("+I", Int.box(1), String.valueOf("c-1"), Double.box(1.0)), + changelogRow("+I", Int.box(2), String.valueOf("c-2"), Double.box(2.0)), + changelogRow("-U", Int.box(2), String.valueOf("c-2"), Double.box(2.0)), + changelogRow("+U", Int.box(2), String.valueOf("c-2-2"), Double.box(2.0)), + changelogRow("+I", Int.box(2), String.valueOf("c-3"), Double.box(3.0)), + changelogRow("+I", Int.box(3), String.valueOf("c-3-2"), Double.box(32.0)), + changelogRow("+I", Int.box(23), String.valueOf("c-23"), Double.box(23.0)) + ), + List(List("c1")), + List("c1") + ) + } + + private def prepareFourTables(): Unit = { + prepareSrcTableWithData( + "A", + List("a0 double", "a1 int primary 
key not enforced", "a2 string"), + List( + changelogRow("+I", Double.box(1.0), Int.box(1), String.valueOf("a-1")), + changelogRow("-U", Double.box(1.0), Int.box(1), String.valueOf("a-1")), + changelogRow("+U", Double.box(1.0), Int.box(1), String.valueOf("a-1-2")), + changelogRow("+I", Double.box(2.0), Int.box(2), String.valueOf("a-2")), + changelogRow("+I", Double.box(3.0), Int.box(3), String.valueOf("a-3")) + ), + List(List("a1")), + List("a0") + ) + + prepareSrcTableWithData( + "B", + List("b1 int primary key not enforced", "b0 double", "b2 string"), + List( + changelogRow("+I", Int.box(1), Double.box(1.0), String.valueOf("b-1")), + changelogRow("-U", Int.box(1), Double.box(1.0), String.valueOf("b-1")), + changelogRow("+U", Int.box(1), Double.box(1.0), String.valueOf("b-1-2")), + changelogRow("+I", Int.box(2), Double.box(2.0), String.valueOf("b-2")), + changelogRow("+I", Int.box(3), Double.box(4.0), String.valueOf("b-3")), + changelogRow("+I", Int.box(13), Double.box(13.0), String.valueOf("b-13")) + ), + List(List("b1")), + List("b0") + ) + + prepareSrcTableWithData( + "C", + List("c1 int", "c2 string", "c0 double primary key not enforced"), + List( + changelogRow("+I", Int.box(1), String.valueOf("c-1"), Double.box(1.0)), + changelogRow("+I", Int.box(2), String.valueOf("c-2"), Double.box(2.0)), + changelogRow("+I", Int.box(1), String.valueOf("c-3"), Double.box(3.0)), + changelogRow("+I", Int.box(3), String.valueOf("c-3-2"), Double.box(32.0)), + changelogRow("+I", Int.box(3), String.valueOf("c-3-3"), Double.box(33.0)), + changelogRow("+I", Int.box(99), String.valueOf("c-99"), Double.box(99.0)) + ), + List(List("c1")), + List("c1") + ) + + prepareSrcTableWithData( + "D", + List("d2 string", "d1 int primary key not enforced", "d0 double"), + List( + changelogRow("+I", String.valueOf("d-1"), Int.box(1), Double.box(1.0)), + changelogRow("-U", String.valueOf("d-1"), Int.box(1), Double.box(1.0)), + changelogRow("+U", String.valueOf("d-2"), Int.box(1), Double.box(1.0)), 
+ changelogRow("+I", String.valueOf("d-3"), Int.box(3), Double.box(3.0)), + changelogRow("+I", String.valueOf("d-100"), Int.box(100), Double.box(100.0)) + ), + List(List("d1")), + List() + ) + } + + private def prepareSrcTableWithData( + tableName: String, + fields: List[String], + data: List[Row], + indexes: List[List[String]], + immutableCols: List[String]): Unit = { + Preconditions.checkArgument(fields.nonEmpty) + tEnv.executeSql(s"drop table if exists $tableName") + tEnv.executeSql(s""" + |create table $tableName( + | ${fields.mkString(",")} + |) with ( + | 'connector' = 'values', + | 'bounded' = 'false', + | 'changelog-mode' = 'I,UA,UB', + | 'data-id' = '${TestValuesTableFactory.registerData(data)}', + | 'async' = 'true' + |) + |""".stripMargin) + + addIndexesAndImmutableCols(tableName, indexes, immutableCols) + } + +} diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/DeltaJoinITCaseBase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/DeltaJoinITCaseBase.scala new file mode 100644 index 0000000000000..f2586cf30a8c9 --- /dev/null +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/DeltaJoinITCaseBase.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.planner.runtime.stream.sql + +import org.apache.flink.table.api.Schema +import org.apache.flink.table.api.bridge.scala.internal.StreamTableEnvironmentImpl +import org.apache.flink.table.api.config.{ExecutionConfigOptions, OptimizerConfigOptions} +import org.apache.flink.table.api.config.OptimizerConfigOptions.DeltaJoinStrategy +import org.apache.flink.table.catalog._ +import org.apache.flink.table.planner.factories.TestValuesRuntimeFunctions.AsyncTestValueLookupFunction +import org.apache.flink.table.planner.runtime.utils.StreamingTestBase +import org.apache.flink.testutils.junit.extensions.parameterized.{ParameterizedTestExtension, Parameters} + +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.extension.ExtendWith + +import scala.collection.JavaConversions._ + +/** Base class for delta join integration tests. */ +@ExtendWith(Array(classOf[ParameterizedTestExtension])) +class DeltaJoinITCaseBase(enableCache: Boolean) extends StreamingTestBase { + + @BeforeEach + override def before(): Unit = { + super.before() + + tEnv.getConfig.set( + OptimizerConfigOptions.TABLE_OPTIMIZER_DELTA_JOIN_STRATEGY, + DeltaJoinStrategy.FORCE) + + tEnv.getConfig.set( + ExecutionConfigOptions.TABLE_EXEC_DELTA_JOIN_CACHE_ENABLED, + Boolean.box(enableCache)) + + AsyncTestValueLookupFunction.invokeCount.set(0) + } + + /** TODO add [[Index]] and [[ImmutableColumnsConstraint]] in DDL. 
*/ + protected def addIndexesAndImmutableCols( + tableName: String, + indexes: List[List[String]], + immutableCols: List[String]): Unit = { + if (indexes.isEmpty && immutableCols.isEmpty) { + return + } + + val catalogName = tEnv.getCurrentCatalog + val databaseName = tEnv.getCurrentDatabase + val tablePath = new ObjectPath(databaseName, tableName) + val catalog = tEnv.getCatalog(catalogName).get() + val catalogManager = tEnv.asInstanceOf[StreamTableEnvironmentImpl].getCatalogManager + val schemaResolver = catalogManager.getSchemaResolver + + val resolvedTable = catalog.getTable(tablePath).asInstanceOf[ResolvedCatalogTable] + val originTable = resolvedTable.getOrigin + val originSchema = originTable.getUnresolvedSchema + + val newSchemaBuilder = Schema + .newBuilder() + .fromSchema(originSchema) + + if (indexes.nonEmpty) { + indexes.foreach(index => newSchemaBuilder.index(index)) + } + if (immutableCols.nonEmpty) { + newSchemaBuilder.immutableColumns(immutableCols) + } + + val newSchema = newSchemaBuilder.build() + + val newTable = CatalogTable + .newBuilder() + .schema(newSchema) + .comment(originTable.getComment) + .partitionKeys(originTable.getPartitionKeys) + .options(originTable.getOptions) + .build() + val newResolvedTable = new ResolvedCatalogTable(newTable, schemaResolver.resolve(newSchema)) + + catalog.dropTable(tablePath, false) + catalog.createTable(tablePath, newResolvedTable, false) + } + +} + +object DeltaJoinITCaseBase { + @Parameters(name = "EnableCache={0}") + def parameters(): java.util.Collection[Boolean] = { + Seq[Boolean](true, false) + } +} diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/deltajoin/AsyncDeltaJoinRunner.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/deltajoin/AsyncDeltaJoinRunner.java index 7c184b31522ad..d7de605f45541 100644 --- 
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/deltajoin/AsyncDeltaJoinRunner.java +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/deltajoin/AsyncDeltaJoinRunner.java @@ -237,6 +237,11 @@ public DeltaJoinCache getCache() { return cache; } + @VisibleForTesting + public List getAllProcessors() { + return allProcessors; + } + private Optional> tryGetDataFromCache(RowData joinKey) { Preconditions.checkState(enableCache); @@ -419,6 +424,16 @@ public void complete(CollectionSupplier supplier) { throw new UnsupportedOperationException(); } + @VisibleForTesting + public DeltaJoinHandlerChain getDeltaJoinHandlerChain() { + return handlerChain; + } + + @VisibleForTesting + public MultiInputRowDataBuffer getMultiInputRowDataBuffer() { + return multiInputRowDataBuffer; + } + private void updateCacheIfNecessary(Collection lookupRows) throws Exception { if (!enableCache) { return; diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/deltajoin/CascadedLookupHandler.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/deltajoin/CascadedLookupHandler.java new file mode 100644 index 0000000000000..3e915ed06d73f --- /dev/null +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/deltajoin/CascadedLookupHandler.java @@ -0,0 +1,322 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.operators.join.deltajoin; + +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.OpenContext; +import org.apache.flink.api.common.functions.RuntimeContext; +import org.apache.flink.api.common.functions.util.FunctionUtils; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.data.utils.JoinedRowData; +import org.apache.flink.table.runtime.generated.FilterCondition; +import org.apache.flink.table.runtime.generated.GeneratedFilterCondition; +import org.apache.flink.table.runtime.generated.GeneratedFunction; +import org.apache.flink.table.runtime.keyselector.RowDataKeySelector; +import org.apache.flink.table.runtime.typeutils.RowDataSerializer; +import org.apache.flink.table.types.DataType; +import org.apache.flink.util.Preconditions; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * This handler represents one of the lookup actions when performing cascaded lookups on multi + * dimension tables. + * + *

For example, if the lookup chain is `[D -> A, A -> B, (A, B) -> C]`, there are three {@link + * CascadedLookupHandler}, each representing one of the lookup actions within the sequence. + */ +public class CascadedLookupHandler extends LookupHandlerBase { + + private static final long serialVersionUID = 1L; + + private static final Logger LOG = LoggerFactory.getLogger(CascadedLookupHandler.class); + + // used for debug and start with 1 + protected final int id; + private @Nullable final GeneratedFilterCondition generatedRemainingCondition; + private final RowDataKeySelector streamSideLookupKeySelector; + private final boolean leftLookupRight; + + private @Nullable transient FilterCondition remainingCondition; + private transient Map> allInputsWithLookupKey; + private transient Map> lookupResults; + + protected @Nullable transient Integer totalNumShouldBeHandledThisRound = null; + protected @Nullable transient Integer handledNum = null; + + public CascadedLookupHandler( + int id, + DataType streamSideType, + DataType lookupResultType, + DataType lookupSidePassThroughCalcType, + RowDataSerializer lookupSidePassThroughCalcRowSerializer, + @Nullable GeneratedFunction> lookupSideGeneratedCalc, + @Nullable GeneratedFilterCondition generatedRemainingCondition, + RowDataKeySelector streamSideLookupKeySelector, + int[] ownedSourceOrdinals, + int ownedLookupOrdinal, + boolean leftLookupRight) { + super( + streamSideType, + lookupResultType, + lookupSidePassThroughCalcType, + lookupSidePassThroughCalcRowSerializer, + lookupSideGeneratedCalc, + ownedSourceOrdinals, + ownedLookupOrdinal, + "CascadedLookupHandler-" + id); + this.id = id; + this.generatedRemainingCondition = generatedRemainingCondition; + this.streamSideLookupKeySelector = streamSideLookupKeySelector; + + if (leftLookupRight) { + Preconditions.checkArgument( + Arrays.stream(ownedSourceOrdinals).allMatch(i -> i < ownedLookupOrdinal)); + } else { + Preconditions.checkArgument( + 
Arrays.stream(ownedSourceOrdinals).allMatch(i -> i > ownedLookupOrdinal)); + } + this.leftLookupRight = leftLookupRight; + } + + @Override + public void setNext(@Nullable DeltaJoinHandlerBase next) { + super.setNext(next); + + Preconditions.checkArgument( + next != null, "This cascaded lookup handler must have a concrete handler after it"); + } + + @Override + public void open(OpenContext openContext, DeltaJoinHandlerContext handlerContext) + throws Exception { + super.open(openContext, handlerContext); + + RuntimeContext runtimeContext = handlerContext.getRuntimeContext(); + + if (generatedRemainingCondition != null) { + this.remainingCondition = + generatedRemainingCondition.newInstance( + runtimeContext.getUserCodeClassLoader()); + FunctionUtils.setFunctionRuntimeContext(remainingCondition, runtimeContext); + FunctionUtils.openFunction(remainingCondition, openContext); + } + + this.lookupResults = new HashMap<>(); + this.allInputsWithLookupKey = new HashMap<>(); + + Preconditions.checkState( + next != null, "This cascaded lookup handler must have a concrete handler after it"); + } + + @Override + public void asyncHandle() throws Exception { + Preconditions.checkState( + totalNumShouldBeHandledThisRound == null && handledNum == null, + "This handler is handled without being reset"); + + Collection allSourceRowData = + handlerContext.getSharedMultiInputRowDataBuffer().getData(ownedSourceOrdinals); + + for (RowData input : allSourceRowData) { + RowData lookupKey = streamSideLookupKeySelector.getKey(input); + allInputsWithLookupKey.compute( + lookupKey, + (k, v) -> { + if (v == null) { + v = new ArrayList<>(); + } + v.add(input); + return v; + }); + } + + totalNumShouldBeHandledThisRound = allInputsWithLookupKey.size(); + handledNum = 0; + + if (LOG.isDebugEnabled()) { + LOG.debug( + "Begin to lookup from {} to {}, total round: {}, current round: {}", + Arrays.toString(ownedSourceOrdinals), + ownedLookupOrdinal, + totalNumShouldBeHandledThisRound, + handledNum); + } 
+ + if (allSourceRowData.isEmpty()) { + finish(); + return; + } + + for (List inputsOnThisLookupKey : allInputsWithLookupKey.values()) { + // pick the first row as input to lookup + RowData chosenToLookupInput = inputsOnThisLookupKey.get(0); + fetcher.asyncInvoke(chosenToLookupInput, createLookupResultFuture(chosenToLookupInput)); + } + } + + @Override + protected void completeResultsInMailbox(RowData input, Collection result) { + Preconditions.checkState( + totalNumShouldBeHandledThisRound != null && handledNum != null, + "This handler is completed without being handled"); + + lookupResults.put(input, result); + + handledNum++; + + if (LOG.isDebugEnabled()) { + LOG.debug( + "End to lookup from {} to {}, total round: {}, next round: {}", + Arrays.toString(ownedSourceOrdinals), + ownedLookupOrdinal, + totalNumShouldBeHandledThisRound, + handledNum); + } + + if (noFurtherInput()) { + if (LOG.isDebugEnabled()) { + LOG.debug( + "There is no further input when looking up from {} to {}, total round: {}, next round: {}", + Arrays.toString(ownedSourceOrdinals), + ownedLookupOrdinal, + totalNumShouldBeHandledThisRound, + handledNum); + } + try { + finish(); + } catch (Exception t) { + LOG.error("Error happened in the lookup chain when finishing", t); + completeExceptionally(t); + } + } + } + + private boolean noFurtherInput() { + Preconditions.checkState( + totalNumShouldBeHandledThisRound != null && handledNum != null, + "This function is called without be handled"); + Preconditions.checkState( + handledNum <= totalNumShouldBeHandledThisRound, + String.format( + "The handled num is greater than the total num. 
The handledNum is %d, the totalNumShouldBeHandledThisRound is %d", + handledNum, totalNumShouldBeHandledThisRound)); + + return handledNum.equals(totalNumShouldBeHandledThisRound); + } + + private void finish() throws Exception { + Preconditions.checkState(noFurtherInput()); + totalNumShouldBeHandledThisRound = null; + handledNum = null; + + Set finalLookupResults = new HashSet<>(); + + JoinedRowData reusedRowData = new JoinedRowData(); + + // the time complexity is O(inputs * lookup results) + for (RowData lookupKey : allInputsWithLookupKey.keySet()) { + List allInputsOnThisLookupKey = allInputsWithLookupKey.get(lookupKey); + RowData chosenToLookupInput = allInputsOnThisLookupKey.get(0); + + Collection lookupResult = lookupResults.get(chosenToLookupInput); + for (RowData input : allInputsOnThisLookupKey) { + for (RowData lookedUp : lookupResult) { + if (remainingCondition != null) { + if (leftLookupRight) { + reusedRowData.replace(input, lookedUp); + } else { + reusedRowData.replace(lookedUp, input); + } + if (!remainingCondition.apply( + FilterCondition.Context.INVALID_CONTEXT, reusedRowData)) { + continue; + } + } + + finalLookupResults.add(lookedUp); + } + } + } + + MultiInputRowDataBuffer sharedMultiInputRowDataBuffer = + handlerContext.getSharedMultiInputRowDataBuffer(); + sharedMultiInputRowDataBuffer.setRowData(finalLookupResults, ownedLookupOrdinal); + + if (next != null) { + next.asyncHandle(); + } + } + + @Override + public CascadedLookupHandler copyInternal() { + return new CascadedLookupHandler( + id, + streamSideType, + lookupResultType, + lookupSidePassThroughCalcType, + lookupSidePassThroughCalcRowSerializer, + lookupSideGeneratedCalc, + generatedRemainingCondition, + streamSideLookupKeySelector.copy(), + ownedSourceOrdinals, + ownedLookupOrdinal, + leftLookupRight); + } + + @Override + public void reset() { + this.allInputsWithLookupKey.clear(); + this.lookupResults.clear(); + + this.totalNumShouldBeHandledThisRound = null; + this.handledNum = 
null; + + super.reset(); + } + + @Override + public void close() throws Exception { + if (remainingCondition != null) { + FunctionUtils.closeFunction(remainingCondition); + } + + if (this.allInputsWithLookupKey != null) { + this.allInputsWithLookupKey.clear(); + } + + if (this.lookupResults != null) { + this.lookupResults.clear(); + } + + super.close(); + } +} diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/deltajoin/TailOutputDataHandler.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/deltajoin/TailOutputDataHandler.java new file mode 100644 index 0000000000000..2a6e6488c770e --- /dev/null +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/deltajoin/TailOutputDataHandler.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.operators.join.deltajoin; + +import org.apache.flink.table.data.RowData; +import org.apache.flink.util.Preconditions; + +import javax.annotation.Nullable; + +import java.util.Collection; + +/** + * A tail handler to output the data. + * + *

Note: this handler is only used in the tail of the delta join chain and should not have a + * concrete handler after it. + */ +public class TailOutputDataHandler extends DeltaJoinHandlerBase { + + private static final long serialVersionUID = 1L; + + private final int[] allLookupSideBinaryInputOrdinals; + + public TailOutputDataHandler(int[] allLookupSideBinaryInputOrdinals) { + this.allLookupSideBinaryInputOrdinals = allLookupSideBinaryInputOrdinals; + } + + @Override + public void setNext(@Nullable DeltaJoinHandlerBase next) { + super.setNext(next); + + Preconditions.checkArgument( + next == null, "This tail handler should not have a concrete handler after it"); + } + + @Override + public void asyncHandle() throws Exception { + MultiInputRowDataBuffer sharedMultiInputRowDataBuffer = + handlerContext.getSharedMultiInputRowDataBuffer(); + + Collection allData = + sharedMultiInputRowDataBuffer.getData(allLookupSideBinaryInputOrdinals); + handlerContext.getRealOutputResultFuture().complete(allData); + } + + @Override + protected DeltaJoinHandlerBase copyInternal() { + return new TailOutputDataHandler(allLookupSideBinaryInputOrdinals); + } +} diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/lookup/CalcCollectionCollector.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/lookup/CalcCollectionCollector.java index 59dd86d77c884..298a76cfcde5b 100644 --- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/lookup/CalcCollectionCollector.java +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/lookup/CalcCollectionCollector.java @@ -18,6 +18,7 @@ package org.apache.flink.table.runtime.operators.join.lookup; +import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.table.data.RowData; import org.apache.flink.table.runtime.typeutils.RowDataSerializer; import 
org.apache.flink.util.Collector; @@ -51,4 +52,9 @@ public void collect(RowData record) { @Override public void close() {} + + @VisibleForTesting + public RowDataSerializer getLookupResultRowSerializer() { + return lookupResultRowSerializer; + } } diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/deltajoin/StreamingBinaryDeltaJoinOperatorTest.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/deltajoin/StreamingBinaryDeltaJoinOperatorTest.java index cc6d8b173c720..4e44935fea3d5 100644 --- a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/deltajoin/StreamingBinaryDeltaJoinOperatorTest.java +++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/deltajoin/StreamingBinaryDeltaJoinOperatorTest.java @@ -1056,12 +1056,12 @@ void testMeetExceptionWhenLookup() throws Exception { .isEqualTo(expectedException); } - private void initTestHarness(AbstractBinaryTestSpec testSpec) throws Exception { + private void initTestHarness(BinaryTestSpec testSpec) throws Exception { initTestHarness(testSpec, null, true); } private void initTestHarness( - AbstractBinaryTestSpec testSpec, + BinaryTestSpec testSpec, @Nullable Throwable expectedThrownException, boolean insertTableDataAfterEmit) throws Exception { @@ -1087,12 +1087,12 @@ private void initTestHarness( })); } - private void initAssertor(AbstractBinaryTestSpec testSpec) { + private void initAssertor(BinaryTestSpec testSpec) { assertor = createAssertor(testSpec.getOutputRowType()); } private void verifyCacheData( - AbstractBinaryTestSpec testSpec, + BinaryTestSpec testSpec, DeltaJoinCache actualCache, Map> expectedLeftCacheData, Map> expectedRightCacheData, @@ -1133,7 +1133,7 @@ private void waitAllDataProcessed() throws Exception { private KeyedTwoInputStreamOperatorTestHarness createBinaryDeltaJoinOperatorTestHarness( - AbstractBinaryTestSpec 
testSpec, @Nullable Throwable expectedThrownException) + BinaryTestSpec testSpec, @Nullable Throwable expectedThrownException) throws Exception { int[] eachBinaryInputFieldSize = new int[] { @@ -1241,16 +1241,16 @@ private void waitAllDataProcessed() throws Exception { enableCache); } - private void insertLeftTable(AbstractBinaryTestSpec testSpec, StreamRecord record) { + private void insertLeftTable(BinaryTestSpec testSpec, StreamRecord record) { insertTableData(testSpec, record.getValue(), true); } - private void insertRightTable(AbstractBinaryTestSpec testSpec, StreamRecord record) { + private void insertRightTable(BinaryTestSpec testSpec, StreamRecord record) { insertTableData(testSpec, record.getValue(), false); } private void insertTableData( - AbstractBinaryTestSpec testSpec, RowData rowData, boolean insertLeftTable) { + BinaryTestSpec testSpec, RowData rowData, boolean insertLeftTable) { try { synchronized (tableCurrentDataMap) { if (insertLeftTable) { @@ -1279,7 +1279,7 @@ private RowData toBinary(RowData row, RowType rowType) { return binaryrow(fields); } - private abstract static class AbstractBinaryTestSpec extends AbstractBaseTestSpec { + private abstract static class BinaryTestSpec extends AbstractTestSpec { abstract Optional> getFilterOnLeftTable(); @@ -1347,7 +1347,7 @@ final GeneratedFilterCondition getGeneratedJoinCondition() { * and left_jk2_index = right_jk2 * */ - private static class LogLogTableJoinTestSpec extends AbstractBinaryTestSpec { + private static class LogLogTableJoinTestSpec extends BinaryTestSpec { private static final LogLogTableJoinTestSpec WITHOUT_FILTER_ON_TABLE = new LogLogTableJoinTestSpec(false); @@ -1451,7 +1451,7 @@ Optional> getFilterOnRightTable() { * ) on left_pk2_jk_index = right_pk2_jk_index * */ - private static class PkPkTableJoinTestSpec extends AbstractBinaryTestSpec { + private static class PkPkTableJoinTestSpec extends BinaryTestSpec { private static final PkPkTableJoinTestSpec WITHOUT_FILTER_ON_TABLE = new 
PkPkTableJoinTestSpec(false); diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/deltajoin/StreamingCascadedDeltaJoinOperatorTest.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/deltajoin/StreamingCascadedDeltaJoinOperatorTest.java new file mode 100644 index 0000000000000..ba1052d930cae --- /dev/null +++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/deltajoin/StreamingCascadedDeltaJoinOperatorTest.java @@ -0,0 +1,2293 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.operators.join.deltajoin; + +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.runtime.operators.testutils.ExpectedTestException; +import org.apache.flink.streaming.api.functions.async.AsyncFunction; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.util.KeyedTwoInputStreamOperatorTestHarness; +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.runtime.generated.GeneratedFilterCondition; +import org.apache.flink.table.runtime.generated.GeneratedFunction; +import org.apache.flink.table.runtime.keyselector.RowDataKeySelector; +import org.apache.flink.table.runtime.operators.join.FlinkJoinType; +import org.apache.flink.table.runtime.operators.join.deltajoin.DeltaJoinRuntimeTree.JoinNode; +import org.apache.flink.table.runtime.operators.join.deltajoin.LookupHandlerBase.Object2RowDataConverterResultFuture; +import org.apache.flink.table.runtime.operators.join.lookup.keyordered.TableAsyncExecutionController; +import org.apache.flink.table.runtime.typeutils.InternalSerializers; +import org.apache.flink.table.runtime.util.RowDataHarnessAssertor; +import org.apache.flink.table.types.logical.DoubleType; +import org.apache.flink.table.types.logical.IntType; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.RowType; +import org.apache.flink.table.types.logical.VarCharType; +import org.apache.flink.testutils.junit.extensions.parameterized.Parameter; +import org.apache.flink.testutils.junit.extensions.parameterized.ParameterizedTestExtension; +import org.apache.flink.testutils.junit.extensions.parameterized.Parameters; +import org.apache.flink.util.Preconditions; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +import 
javax.annotation.Nullable; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicReference; + +import static java.util.Objects.requireNonNull; +import static org.apache.flink.table.runtime.util.StreamRecordUtils.insertRecord; +import static org.apache.flink.table.runtime.util.StreamRecordUtils.row; +import static org.apache.flink.table.runtime.util.StreamRecordUtils.updateAfterRecord; +import static org.apache.flink.table.types.utils.DataTypeUtils.toInternalDataType; +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Test class for cascaded {@link StreamingDeltaJoinOperator}. + * + *

Compared to tests on {@link StreamingBinaryDeltaJoinOperatorTest} that tests the binary delta + * join logic and aec logic, this class focuses on cascaded delta join. + */ +@ExtendWith(ParameterizedTestExtension.class) +public class StreamingCascadedDeltaJoinOperatorTest extends StreamingDeltaJoinOperatorTestBase { + + /** + * Mock ddl like the following. + * + *

+     *      CREATE TABLE A(
+     *          a0 DOUBLE,
+     *          a1 INT PRIMARY KEY NOT ENFORCED,
+     *          a2 STRING
+     *      )
+     * 
+ * + *
+     *      CREATE TABLE B(
+     *          b1 INT,
+     *          b0 DOUBLE,
+     *          b2 STRING,
+     *          INDEX(b1)
+     *      )
+     * 
+ * + *
+     *      CREATE TABLE C(
+     *          c1 INT,
+     *          c2 STRING,
+     *          c0 DOUBLE PRIMARY KEY NOT ENFORCED,
+     *          INDEX(c1)
+     *      )
+     * 
+ * + *
+     *      CREATE TABLE D(
+     *          d2 STRING,
+     *          d0 DOUBLE,
+     *          d1 INT,
+     *          INDEX(d1),
+     *          INDEX(d0)
+     *      )
+     * 
+ */ + private static final Map tableRowTypeMap = new HashMap<>(); + + private static final Map tableUpsertKeySelector = new HashMap<>(); + + static { + tableRowTypeMap.put( + 0, + RowType.of( + new LogicalType[] { + new DoubleType(), new IntType(), VarCharType.STRING_TYPE + }, + new String[] {"a0", "a1", "a2"})); + tableUpsertKeySelector.put(0, getKeySelector(new int[] {1}, tableRowTypeMap.get(0))); + + tableRowTypeMap.put( + 1, + RowType.of( + new LogicalType[] { + new IntType(), new DoubleType(), VarCharType.STRING_TYPE + }, + new String[] {"b1", "b0", "b2"})); + tableUpsertKeySelector.put(1, getKeySelector(new int[] {0, 1, 2}, tableRowTypeMap.get(1))); + + tableRowTypeMap.put( + 2, + RowType.of( + new LogicalType[] { + new IntType(), VarCharType.STRING_TYPE, new DoubleType() + }, + new String[] {"c1", "c2", "c0"})); + tableUpsertKeySelector.put(2, getKeySelector(new int[] {2}, tableRowTypeMap.get(2))); + + tableRowTypeMap.put( + 3, + RowType.of( + new LogicalType[] { + VarCharType.STRING_TYPE, new DoubleType(), new IntType() + }, + new String[] {"d2", "d0", "d1"})); + tableUpsertKeySelector.put(3, getKeySelector(new int[] {0, 1, 2}, tableRowTypeMap.get(3))); + } + + // the data snapshot of the tables when joining + // > + private final Map> tableCurrentDataMap = + new HashMap<>(); + + private StreamingDeltaJoinOperator operator; + private KeyedTwoInputStreamOperatorTestHarness testHarness; + private RowDataHarnessAssertor assertor; + + @Parameter public boolean enableCache; + + @Parameters(name = "EnableCache = {0}") + public static List parameters() { + return Arrays.asList(false, true); + } + + @AfterEach + public void afterEach() throws Exception { + if (testHarness != null) { + testHarness.close(); + } + + tableCurrentDataMap.clear(); + + MyAsyncFunction.getLookupInvokeCount().clear(); + } + + /** + * The Join tree used to test is as following. + * + *
+     *             DT2
+     *           /    \
+     *         DT1     C
+     *       /    \
+     *      A      B
+     *
+     *    when records from C come, the lookup chain is as follows:
+     *    C -> B -> A
+     * 
+ * + *

Here we mainly test DT2 while two inputs come data. + */ + @TestTemplate + void testLHSWithTwoInputsProcessData() throws Exception { + LHSTestSpec spec = new LHSTestSpec(false, false, false); + prepareEnv(spec); + + StreamRecord leftRecordK1V1 = insertRecord(1.0, 1, "a-1", 1, 1.0, "b-1"); + // this record exists in table B but is filtered out in DT1 + insertTableData(1, row(1, 1.0, "b-2")); + StreamRecord leftRecordK1V2 = insertRecord(1.0, 1, "a-2", 1, 1.0, "b-3"); + StreamRecord leftRecordK1V3 = updateAfterRecord(1.0, 1, "a-3", 1, 1.0, "b-3"); + + testHarness.processElement1(leftRecordK1V1); + testHarness.processElement1(leftRecordK1V2); + testHarness.processElement1(leftRecordK1V3); + testHarness.endAllInputs(); + final ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue<>(); + assertor.assertOutputEqualsSorted( + "result mismatch", expectedOutput, testHarness.getOutput()); + + StreamRecord rightRecordK1V1 = insertRecord(1, "c-1", 1.0); + testHarness.processElement2(rightRecordK1V1); + testHarness.endAllInputs(); + assertor.assertOutputEqualsSorted( + "result mismatch", expectedOutput, testHarness.getOutput()); + + insertTableData(0, row(1.0, 1, "a-4")); + + StreamRecord rightRecordK1V2 = insertRecord(1, "c-2", 1.0); + testHarness.processElement2(rightRecordK1V2); + testHarness.endAllInputs(); + if (enableCache) { + expectedOutput.add(insertRecord(1.0, 1, "a-3", 1, 1.0, "b-1", 1, "c-2", 1.0)); + expectedOutput.add(insertRecord(1.0, 1, "a-3", 1, 1.0, "b-3", 1, "c-2", 1.0)); + } else { + expectedOutput.add(insertRecord(1.0, 1, "a-4", 1, 1.0, "b-1", 1, "c-2", 1.0)); + expectedOutput.add(insertRecord(1.0, 1, "a-4", 1, 1.0, "b-3", 1, "c-2", 1.0)); + } + assertor.assertOutputEqualsSorted( + "result mismatch", expectedOutput, testHarness.getOutput()); + + insertTableData(2, row(1, "c-3", 1.0)); + + StreamRecord leftRecordK1V4 = updateAfterRecord(1.0, 1, "a-4", 1, 1.0, "b-3"); + testHarness.processElement1(leftRecordK1V4); + testHarness.endAllInputs(); + if 
(enableCache) { + expectedOutput.add(updateAfterRecord(1.0, 1, "a-4", 1, 1.0, "b-3", 1, "c-2", 1.0)); + } else { + expectedOutput.add(updateAfterRecord(1.0, 1, "a-4", 1, 1.0, "b-3", 1, "c-3", 1.0)); + } + assertor.assertOutputEqualsSorted( + "result mismatch", expectedOutput, testHarness.getOutput()); + + StreamRecord rightRecordK1V3 = insertRecord(1, "c-3", 1.0); + testHarness.processElement2(rightRecordK1V3); + testHarness.endAllInputs(); + if (enableCache) { + expectedOutput.add(insertRecord(1.0, 1, "a-3", 1, 1.0, "b-1", 1, "c-3", 1.0)); + expectedOutput.add(insertRecord(1.0, 1, "a-3", 1, 1.0, "b-3", 1, "c-3", 1.0)); + expectedOutput.add(insertRecord(1.0, 1, "a-4", 1, 1.0, "b-3", 1, "c-3", 1.0)); + } else { + expectedOutput.add(insertRecord(1.0, 1, "a-4", 1, 1.0, "b-1", 1, "c-3", 1.0)); + expectedOutput.add(insertRecord(1.0, 1, "a-4", 1, 1.0, "b-3", 1, "c-3", 1.0)); + } + assertor.assertOutputEqualsSorted( + "result mismatch", expectedOutput, testHarness.getOutput()); + + // validate aec + TableAsyncExecutionController aec = unwrapAEC(testHarness); + assertThat(aec.getBlockingSize()).isEqualTo(0); + assertThat(aec.getInFlightSize()).isEqualTo(0); + assertThat(aec.getFinishSize()).isEqualTo(0); + + // validate cache + DeltaJoinCache cache = unwrapCache(testHarness); + if (enableCache) { + RowDataKeySelector leftJoinKeySelector = spec.getLeftJoinKeySelector(); + RowDataKeySelector leftUpsertKeySelector = spec.getLeftUpsertKeySelector(); + RowDataKeySelector rightJoinKeySelector = spec.getRightJoinKeySelector(); + RowDataKeySelector rightUpsertKeySelector = spec.getRightUpsertKeySelector(); + Map> expectedLeftCacheData = + Map.of( + rightJoinKeySelector.getKey(rightRecordK1V3.getValue()), + Map.of( + leftUpsertKeySelector.getKey( + insertRecord(1.0, 1, "a-3", 1, 1.0, "b-1").getValue()), + insertRecord(1.0, 1, "a-3", 1, 1.0, "b-1").getValue(), + leftUpsertKeySelector.getKey( + insertRecord(1.0, 1, "a-3", 1, 1.0, "b-3").getValue()), + insertRecord(1.0, 1, "a-3", 1, 
1.0, "b-3").getValue(), + leftUpsertKeySelector.getKey( + updateAfterRecord(1.0, 1, "a-4", 1, 1.0, "b-3") + .getValue()), + updateAfterRecord(1.0, 1, "a-4", 1, 1.0, "b-3").getValue())); + + Map> expectedRightCacheData = + Map.of( + leftJoinKeySelector.getKey(leftRecordK1V4.getValue()), + Map.of( + rightUpsertKeySelector.getKey(rightRecordK1V3.getValue()), + rightRecordK1V3.getValue())); + + verifyCacheData(spec, cache, expectedLeftCacheData, expectedRightCacheData, 3, 2, 4, 3); + assertThat(MyAsyncFunction.getLookupInvokeCount().get(0).get()).isEqualTo(1); + assertThat(MyAsyncFunction.getLookupInvokeCount().get(1).get()).isEqualTo(1); + assertThat(MyAsyncFunction.getLookupInvokeCount().get(2).get()).isEqualTo(1); + } else { + verifyCacheData( + spec, cache, Collections.emptyMap(), Collections.emptyMap(), 0, 0, 0, 0); + assertThat(MyAsyncFunction.getLookupInvokeCount().get(0).get()).isEqualTo(3); + assertThat(MyAsyncFunction.getLookupInvokeCount().get(1).get()).isEqualTo(3); + assertThat(MyAsyncFunction.getLookupInvokeCount().get(2).get()).isEqualTo(4); + } + + AsyncDeltaJoinRunner openedLeft2RightAsyncRunner = operator.getLeftTriggeredUserFunction(); + AsyncDeltaJoinRunner openedRight2LeftAsyncRunner = operator.getRightTriggeredUserFunction(); + + assertThat(openedLeft2RightAsyncRunner.getAllProcessors().size()) + .isEqualTo(AEC_CAPACITY + 1); + openedLeft2RightAsyncRunner + .getAllProcessors() + .forEach( + processor -> { + DeltaJoinHandlerBase handler = + processor.getDeltaJoinHandlerChain().getHead(); + assertThat(handler).isInstanceOf(BinaryLookupHandler.class); + }); + + assertThat(openedRight2LeftAsyncRunner.getAllProcessors().size()) + .isEqualTo(AEC_CAPACITY + 1); + openedRight2LeftAsyncRunner + .getAllProcessors() + .forEach( + processor -> { + DeltaJoinHandlerBase handler = + processor.getDeltaJoinHandlerChain().getHead(); + assertThat(handler).isInstanceOf(CascadedLookupHandler.class); + assertThat(handler.getNext()).isNotNull(); + + handler = 
handler.getNext(); + assertThat(handler).isInstanceOf(CascadedLookupHandler.class); + assertThat(handler.getNext()).isNotNull(); + + handler = handler.getNext(); + assertThat(handler).isInstanceOf(TailOutputDataHandler.class); + }); + } + + /** + * The Join tree used to test is as following. + * + *
+     *             DT2
+     *           /    \
+     *         DT1     C
+     *       /    \
+     *      A      B
+     *
+     *    when records from C come, the lookup chain is as follows:
+     *    C -> B -> A
+     * 
+ * + *

Here we mainly test DT2 while two inputs come data. + */ + @TestTemplate + void testLHSWithTwoInputsProcessDataWithFilterBetweenJoinAndSource() throws Exception { + prepareEnv(new LHSTestSpec(true, false, false)); + + // this record exists in table A, but is filtered out in filter on A + insertTableData(0, row(1.0, 1, "a-1")); + // this record exists in table B, but is filtered out in filter on B + insertTableData(1, row(1, 1.0, "b-1")); + // this record exists in table B, but is filtered out in dt1 + insertTableData(1, row(1, 1000.0, "b-2")); + + StreamRecord leftRecordK1V1 = insertRecord(1000.0, 1, "a-2", 1, 1000.0, "b-3"); + StreamRecord leftRecordK1V2 = insertRecord(1000.0, 1, "a-2", 1, 1000.0, "b-4"); + + testHarness.processElement1(leftRecordK1V1); + testHarness.processElement1(leftRecordK1V2); + testHarness.endAllInputs(); + final ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue<>(); + assertor.assertOutputEqualsSorted( + "result mismatch", expectedOutput, testHarness.getOutput()); + + StreamRecord rightRecordK1V1 = insertRecord(1, "c-1", 1000.0); + testHarness.processElement2(rightRecordK1V1); + testHarness.endAllInputs(); + assertor.assertOutputEqualsSorted( + "result mismatch", expectedOutput, testHarness.getOutput()); + + StreamRecord rightRecordK1V2 = updateAfterRecord(1, "c-2", 1200.0); + testHarness.processElement2(rightRecordK1V2); + testHarness.endAllInputs(); + expectedOutput.add(updateAfterRecord(1000.0, 1, "a-2", 1, 1000.0, "b-3", 1, "c-2", 1200.0)); + expectedOutput.add(updateAfterRecord(1000.0, 1, "a-2", 1, 1000.0, "b-4", 1, "c-2", 1200.0)); + assertor.assertOutputEqualsSorted( + "result mismatch", expectedOutput, testHarness.getOutput()); + + insertTableData(0, row(1000.0, 1, "a-3")); + + StreamRecord rightRecordK1V3 = updateAfterRecord(1, "c-3", 1300.0); + testHarness.processElement2(rightRecordK1V3); + testHarness.endAllInputs(); + if (enableCache) { + expectedOutput.add( + updateAfterRecord(1000.0, 1, "a-2", 1, 1000.0, 
"b-3", 1, "c-3", 1300.0)); + expectedOutput.add( + updateAfterRecord(1000.0, 1, "a-2", 1, 1000.0, "b-4", 1, "c-3", 1300.0)); + } else { + expectedOutput.add( + updateAfterRecord(1000.0, 1, "a-3", 1, 1000.0, "b-3", 1, "c-3", 1300.0)); + expectedOutput.add( + updateAfterRecord(1000.0, 1, "a-3", 1, 1000.0, "b-4", 1, "c-3", 1300.0)); + } + assertor.assertOutputEqualsSorted( + "result mismatch", expectedOutput, testHarness.getOutput()); + + StreamRecord leftRecordK1V3 = + updateAfterRecord(1000.0, 1, "a-3", 1, 1000.0, "b-3"); + StreamRecord leftRecordK1V4 = + updateAfterRecord(1000.0, 1, "a-3", 1, 1000.0, "b-4"); + testHarness.processElement1(leftRecordK1V3); + testHarness.processElement1(leftRecordK1V4); + testHarness.endAllInputs(); + expectedOutput.add(updateAfterRecord(1000.0, 1, "a-3", 1, 1000.0, "b-3", 1, "c-2", 1200.0)); + expectedOutput.add(updateAfterRecord(1000.0, 1, "a-3", 1, 1000.0, "b-3", 1, "c-3", 1300.0)); + expectedOutput.add(updateAfterRecord(1000.0, 1, "a-3", 1, 1000.0, "b-4", 1, "c-2", 1200.0)); + expectedOutput.add(updateAfterRecord(1000.0, 1, "a-3", 1, 1000.0, "b-4", 1, "c-3", 1300.0)); + assertor.assertOutputEqualsSorted( + "result mismatch", expectedOutput, testHarness.getOutput()); + } + + /** + * The Join tree used to test is as following. + * + *
+     *             DT2
+     *           /    \
+     *         DT1     C
+     *       /    \
+     *      A      B
+     *
+     *    when records from C come, the lookup chain is as follows:
+     *    C -> B -> A
+     * 
+ * + *

Here we mainly test DT2 while two inputs come data. + */ + @TestTemplate + void testLHSWithTwoInputsProcessDataWithFilterBothBetweenJoinAndSourceAndCascadedJoins() + throws Exception { + prepareEnv(new LHSTestSpec(true, true, false)); + + // this record exists in table A, but is filtered out in filter on A + insertTableData(0, row(99.0, 1, "a-1")); + // this record exists in table B, but is filtered out in filter on B + insertTableData(1, row(1, 199.0, "b-1")); + // this record exists in table B, but is filtered out in dt1 + insertTableData(1, row(1, 1000.0, "b-2")); + // this record exists in table B, but is filtered out in filter after dt1 + insertTableData(1, row(1, 599.0, "b-3")); + // this record exists in table B, but is filtered out in filter after dt1 + insertTableData(1, row(1, 800.0, "b-4")); + + StreamRecord rightRecordK1V1 = insertRecord(1, "c-1", 1000.0); + testHarness.processElement2(rightRecordK1V1); + testHarness.endAllInputs(); + final ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue<>(); + assertor.assertOutputEqualsSorted( + "result mismatch", expectedOutput, testHarness.getOutput()); + + StreamRecord rightRecordK1V2 = updateAfterRecord(1, "c-2", 1200.0); + testHarness.processElement2(rightRecordK1V2); + testHarness.endAllInputs(); + assertor.assertOutputEqualsSorted( + "result mismatch", expectedOutput, testHarness.getOutput()); + + insertTableData(0, row(300.0, 1, "a-2")); + + StreamRecord rightRecordK1V3 = updateAfterRecord(1, "c-3", 1300.0); + testHarness.processElement2(rightRecordK1V3); + testHarness.endAllInputs(); + if (!enableCache) { + expectedOutput.add( + updateAfterRecord(300.0, 1, "a-2", 1, 800.0, "b-4", 1, "c-3", 1300.0)); + } + assertor.assertOutputEqualsSorted( + "result mismatch", expectedOutput, testHarness.getOutput()); + + insertTableData(2, row(1, "c-4", 200.0)); + + StreamRecord leftRecordK1V1 = updateAfterRecord(300.0, 1, "a-2", 1, 800.0, "b-4"); + testHarness.processElement1(leftRecordK1V1); + 
testHarness.endAllInputs(); + expectedOutput.add(updateAfterRecord(300.0, 1, "a-2", 1, 800.0, "b-4", 1, "c-2", 1200.0)); + expectedOutput.add(updateAfterRecord(300.0, 1, "a-2", 1, 800.0, "b-4", 1, "c-3", 1300.0)); + assertor.assertOutputEqualsSorted( + "result mismatch", expectedOutput, testHarness.getOutput()); + } + + /** + * The Join tree used to test is as following. + * + *
+     *       DT2
+     *     /    \
+     *    C     DT1
+     *        /    \
+     *       A      B
+     *
+     *    when records from C come, the lookup chain is as follows:
+     *    C -> B -> A
+     * 
+ * + *

Here we mainly test DT2 while two inputs come data. + */ + @TestTemplate + void testRHSWithTwoInputsProcessData() throws Exception { + RHSTestSpec spec = new RHSTestSpec(false, false); + prepareEnv(spec); + + StreamRecord rightRecordK1V1 = insertRecord(1.0, 1, "a-1", 1, 1.0, "b-1"); + // this record exists in table B but is filtered out in DT1 + insertTableData(1, row(1, 1.0, "b-2")); + StreamRecord rightRecordK1V2 = insertRecord(1.0, 1, "a-2", 1, 1.0, "b-3"); + StreamRecord rightRecordK1V3 = updateAfterRecord(1.0, 1, "a-3", 1, 1.0, "b-3"); + + testHarness.processElement2(rightRecordK1V1); + testHarness.processElement2(rightRecordK1V2); + testHarness.processElement2(rightRecordK1V3); + testHarness.endAllInputs(); + final ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue<>(); + assertor.assertOutputEqualsSorted( + "result mismatch", expectedOutput, testHarness.getOutput()); + + StreamRecord leftRecordK1V1 = insertRecord(1, "c-1", 1.0); + testHarness.processElement1(leftRecordK1V1); + testHarness.endAllInputs(); + assertor.assertOutputEqualsSorted( + "result mismatch", expectedOutput, testHarness.getOutput()); + + insertTableData(0, row(1.0, 1, "a-4")); + + StreamRecord leftRecordK1V2 = insertRecord(1, "c-2", 1.0); + testHarness.processElement1(leftRecordK1V2); + testHarness.endAllInputs(); + if (enableCache) { + expectedOutput.add(insertRecord(1, "c-2", 1.0, 1.0, 1, "a-3", 1, 1.0, "b-1")); + expectedOutput.add(insertRecord(1, "c-2", 1.0, 1.0, 1, "a-3", 1, 1.0, "b-3")); + } else { + expectedOutput.add(insertRecord(1, "c-2", 1.0, 1.0, 1, "a-4", 1, 1.0, "b-1")); + expectedOutput.add(insertRecord(1, "c-2", 1.0, 1.0, 1, "a-4", 1, 1.0, "b-3")); + } + assertor.assertOutputEqualsSorted( + "result mismatch", expectedOutput, testHarness.getOutput()); + + insertTableData(2, row(1, "c-3", 1.0)); + + StreamRecord rightRecordK1V4 = updateAfterRecord(1.0, 1, "a-4", 1, 1.0, "b-3"); + testHarness.processElement2(rightRecordK1V4); + testHarness.endAllInputs(); + if 
(enableCache) { + expectedOutput.add(updateAfterRecord(1, "c-2", 1.0, 1.0, 1, "a-4", 1, 1.0, "b-3")); + } else { + expectedOutput.add(updateAfterRecord(1, "c-3", 1.0, 1.0, 1, "a-4", 1, 1.0, "b-3")); + } + assertor.assertOutputEqualsSorted( + "result mismatch", expectedOutput, testHarness.getOutput()); + + StreamRecord leftRecordK1V3 = insertRecord(1, "c-3", 1.0); + testHarness.processElement1(leftRecordK1V3); + testHarness.endAllInputs(); + if (enableCache) { + expectedOutput.add(insertRecord(1, "c-3", 1.0, 1.0, 1, "a-3", 1, 1.0, "b-1")); + expectedOutput.add(insertRecord(1, "c-3", 1.0, 1.0, 1, "a-3", 1, 1.0, "b-3")); + expectedOutput.add(insertRecord(1, "c-3", 1.0, 1.0, 1, "a-4", 1, 1.0, "b-3")); + } else { + expectedOutput.add(insertRecord(1, "c-3", 1.0, 1.0, 1, "a-4", 1, 1.0, "b-1")); + expectedOutput.add(insertRecord(1, "c-3", 1.0, 1.0, 1, "a-4", 1, 1.0, "b-3")); + } + assertor.assertOutputEqualsSorted( + "result mismatch", expectedOutput, testHarness.getOutput()); + + // validate aec + TableAsyncExecutionController aec = unwrapAEC(testHarness); + assertThat(aec.getBlockingSize()).isEqualTo(0); + assertThat(aec.getInFlightSize()).isEqualTo(0); + assertThat(aec.getFinishSize()).isEqualTo(0); + + // validate cache + DeltaJoinCache cache = unwrapCache(testHarness); + if (enableCache) { + RowDataKeySelector leftJoinKeySelector = spec.getLeftJoinKeySelector(); + RowDataKeySelector leftUpsertKeySelector = spec.getLeftUpsertKeySelector(); + RowDataKeySelector rightJoinKeySelector = spec.getRightJoinKeySelector(); + RowDataKeySelector rightUpsertKeySelector = spec.getRightUpsertKeySelector(); + + Map> expectedLeftCacheData = + Map.of( + rightJoinKeySelector.getKey(rightRecordK1V4.getValue()), + Map.of( + leftUpsertKeySelector.getKey(leftRecordK1V3.getValue()), + leftRecordK1V3.getValue())); + + Map> expectedRightCacheData = + Map.of( + leftJoinKeySelector.getKey(leftRecordK1V3.getValue()), + Map.of( + rightUpsertKeySelector.getKey( + insertRecord(1.0, 1, "a-3", 1, 1.0, 
"b-1").getValue()), + insertRecord(1.0, 1, "a-3", 1, 1.0, "b-1").getValue(), + rightUpsertKeySelector.getKey( + insertRecord(1.0, 1, "a-3", 1, 1.0, "b-3").getValue()), + insertRecord(1.0, 1, "a-3", 1, 1.0, "b-3").getValue(), + rightUpsertKeySelector.getKey( + updateAfterRecord(1.0, 1, "a-4", 1, 1.0, "b-3") + .getValue()), + updateAfterRecord(1.0, 1, "a-4", 1, 1.0, "b-3").getValue())); + + verifyCacheData(spec, cache, expectedLeftCacheData, expectedRightCacheData, 4, 3, 3, 2); + assertThat(MyAsyncFunction.getLookupInvokeCount().get(0).get()).isEqualTo(1); + assertThat(MyAsyncFunction.getLookupInvokeCount().get(1).get()).isEqualTo(1); + assertThat(MyAsyncFunction.getLookupInvokeCount().get(2).get()).isEqualTo(1); + } else { + verifyCacheData( + spec, cache, Collections.emptyMap(), Collections.emptyMap(), 0, 0, 0, 0); + assertThat(MyAsyncFunction.getLookupInvokeCount().get(0).get()).isEqualTo(3); + assertThat(MyAsyncFunction.getLookupInvokeCount().get(1).get()).isEqualTo(3); + assertThat(MyAsyncFunction.getLookupInvokeCount().get(2).get()).isEqualTo(4); + } + + AsyncDeltaJoinRunner openedLeft2RightAsyncRunner = operator.getLeftTriggeredUserFunction(); + AsyncDeltaJoinRunner openedRight2LeftAsyncRunner = operator.getRightTriggeredUserFunction(); + + assertThat(openedLeft2RightAsyncRunner.getAllProcessors().size()) + .isEqualTo(AEC_CAPACITY + 1); + openedLeft2RightAsyncRunner + .getAllProcessors() + .forEach( + processor -> { + DeltaJoinHandlerBase handler = + processor.getDeltaJoinHandlerChain().getHead(); + assertThat(handler).isInstanceOf(CascadedLookupHandler.class); + assertThat(handler.getNext()).isNotNull(); + + handler = handler.getNext(); + assertThat(handler).isInstanceOf(CascadedLookupHandler.class); + assertThat(handler.getNext()).isNotNull(); + + handler = handler.getNext(); + assertThat(handler).isInstanceOf(TailOutputDataHandler.class); + }); + + assertThat(openedRight2LeftAsyncRunner.getAllProcessors().size()) + .isEqualTo(AEC_CAPACITY + 1); + 
openedRight2LeftAsyncRunner + .getAllProcessors() + .forEach( + processor -> { + DeltaJoinHandlerBase handler = + processor.getDeltaJoinHandlerChain().getHead(); + assertThat(handler).isInstanceOf(BinaryLookupHandler.class); + }); + } + + /** + * The Join tree used to test is as following. + * + *
+     *       DT2
+     *     /    \
+     *    C     DT1
+     *        /    \
+     *       A      B
+     *
+     *    when records from C come, the lookup chain is as follows:
+     *    C -> B -> A
+     * 
+ * + *

Here we mainly test DT2 while two inputs come data. + */ + @TestTemplate + void testRHSWithTwoInputsProcessDataWithFilterBetweenJoinAndSource() throws Exception { + prepareEnv(new RHSTestSpec(true, false)); + + // this record exists in table A, but is filtered out in filter on A + insertTableData(0, row(1.0, 1, "a-1")); + // this record exists in table B, but is filtered out in filter on B + insertTableData(1, row(1, 1.0, "b-1")); + // this record exists in table B, but is filtered out in dt1 + insertTableData(1, row(1, 1000.0, "b-2")); + + StreamRecord rightRecordK1V1 = insertRecord(1000.0, 1, "a-2", 1, 1000.0, "b-3"); + StreamRecord rightRecordK1V2 = insertRecord(1000.0, 1, "a-2", 1, 1000.0, "b-4"); + + testHarness.processElement2(rightRecordK1V1); + testHarness.processElement2(rightRecordK1V2); + testHarness.endAllInputs(); + final ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue<>(); + assertor.assertOutputEqualsSorted( + "result mismatch", expectedOutput, testHarness.getOutput()); + + StreamRecord leftRecordK1V1 = insertRecord(1, "c-1", 1000.0); + testHarness.processElement1(leftRecordK1V1); + testHarness.endAllInputs(); + assertor.assertOutputEqualsSorted( + "result mismatch", expectedOutput, testHarness.getOutput()); + + StreamRecord leftRecordK1V2 = updateAfterRecord(1, "c-2", 1200.0); + testHarness.processElement1(leftRecordK1V2); + testHarness.endAllInputs(); + expectedOutput.add(updateAfterRecord(1, "c-2", 1200.0, 1000.0, 1, "a-2", 1, 1000.0, "b-3")); + expectedOutput.add(updateAfterRecord(1, "c-2", 1200.0, 1000.0, 1, "a-2", 1, 1000.0, "b-4")); + assertor.assertOutputEqualsSorted( + "result mismatch", expectedOutput, testHarness.getOutput()); + + insertTableData(0, row(1000.0, 1, "a-3")); + + StreamRecord leftRecordK1V3 = updateAfterRecord(1, "c-3", 1300.0); + testHarness.processElement1(leftRecordK1V3); + testHarness.endAllInputs(); + if (enableCache) { + expectedOutput.add( + updateAfterRecord(1, "c-3", 1300.0, 1000.0, 1, "a-2", 1, 
1000.0, "b-3")); + expectedOutput.add( + updateAfterRecord(1, "c-3", 1300.0, 1000.0, 1, "a-2", 1, 1000.0, "b-4")); + } else { + expectedOutput.add( + updateAfterRecord(1, "c-3", 1300.0, 1000.0, 1, "a-3", 1, 1000.0, "b-3")); + expectedOutput.add( + updateAfterRecord(1, "c-3", 1300.0, 1000.0, 1, "a-3", 1, 1000.0, "b-4")); + } + assertor.assertOutputEqualsSorted( + "result mismatch", expectedOutput, testHarness.getOutput()); + + StreamRecord rightRecordK1V3 = + updateAfterRecord(1000.0, 1, "a-3", 1, 1000.0, "b-3"); + StreamRecord rightRecordK1V4 = + updateAfterRecord(1000.0, 1, "a-3", 1, 1000.0, "b-4"); + testHarness.processElement2(rightRecordK1V3); + testHarness.processElement2(rightRecordK1V4); + testHarness.endAllInputs(); + expectedOutput.add(updateAfterRecord(1, "c-2", 1200.0, 1000.0, 1, "a-3", 1, 1000.0, "b-3")); + expectedOutput.add(updateAfterRecord(1, "c-3", 1300.0, 1000.0, 1, "a-3", 1, 1000.0, "b-3")); + expectedOutput.add(updateAfterRecord(1, "c-2", 1200.0, 1000.0, 1, "a-3", 1, 1000.0, "b-4")); + expectedOutput.add(updateAfterRecord(1, "c-3", 1300.0, 1000.0, 1, "a-3", 1, 1000.0, "b-4")); + assertor.assertOutputEqualsSorted( + "result mismatch", expectedOutput, testHarness.getOutput()); + } + + /** + * The Join tree used to test is as following. + * + *
+     *       DT2
+     *     /    \
+     *    C     DT1
+     *        /    \
+     *       A      B
+     *
+     *    when records from C come, the lookup chain is as follows:
+     *    C -> B -> A
+     * 
+ * + *

Here we mainly test DT2 while two inputs come data. + */ + @TestTemplate + void testRHSWithTwoInputsProcessDataWithFilterBothBetweenJoinAndSourceAndCascadedJoins() + throws Exception { + prepareEnv(new RHSTestSpec(true, true)); + + // this record exists in table A, but is filtered out in filter on A + insertTableData(0, row(99.0, 1, "a-1")); + // this record exists in table B, but is filtered out in filter on B + insertTableData(1, row(1, 199.0, "b-1")); + // this record exists in table B, but is filtered out in dt1 + insertTableData(1, row(1, 1000.0, "b-2")); + // this record exists in table B, but is filtered out in filter after dt1 + insertTableData(1, row(1, 599.0, "b-3")); + // this record exists in table B, but is filtered out in filter after dt1 + insertTableData(1, row(1, 800.0, "b-4")); + + StreamRecord leftRecordK1V1 = insertRecord(1, "c-1", 1000.0); + testHarness.processElement1(leftRecordK1V1); + testHarness.endAllInputs(); + final ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue<>(); + assertor.assertOutputEqualsSorted( + "result mismatch", expectedOutput, testHarness.getOutput()); + + StreamRecord leftRecordK1V2 = updateAfterRecord(1, "c-2", 1200.0); + testHarness.processElement1(leftRecordK1V2); + testHarness.endAllInputs(); + assertor.assertOutputEqualsSorted( + "result mismatch", expectedOutput, testHarness.getOutput()); + + insertTableData(0, row(300.0, 1, "a-2")); + + StreamRecord leftRecordK1V3 = updateAfterRecord(1, "c-3", 1300.0); + testHarness.processElement1(leftRecordK1V3); + testHarness.endAllInputs(); + if (!enableCache) { + expectedOutput.add( + updateAfterRecord(1, "c-3", 1300.0, 300.0, 1, "a-2", 1, 800.0, "b-4")); + } + assertor.assertOutputEqualsSorted( + "result mismatch", expectedOutput, testHarness.getOutput()); + + insertTableData(2, row(1, "c-4", 200.0)); + + StreamRecord rightRecordK1V1 = updateAfterRecord(300.0, 1, "a-2", 1, 800.0, "b-4"); + testHarness.processElement2(rightRecordK1V1); + 
testHarness.endAllInputs(); + expectedOutput.add(updateAfterRecord(1, "c-2", 1200.0, 300.0, 1, "a-2", 1, 800.0, "b-4")); + expectedOutput.add(updateAfterRecord(1, "c-3", 1300.0, 300.0, 1, "a-2", 1, 800.0, "b-4")); + assertor.assertOutputEqualsSorted( + "result mismatch", expectedOutput, testHarness.getOutput()); + } + + /** + * Test lookup chain with the order "C -> D -> B -> A" and the following join tree. + * + *
+     *
+     *   inner/left/right/full DT3
+     *     /               \
+     *    C           inner/left/right/full DT2
+     *                  /               \
+     *                 D      inner/left/right/full DT1
+     *                                /    \
+     *                               A      B
+     * 
+ */ + @TestTemplate + void testMultiCascadedHandlers() throws Exception { + prepareEnv(new ThirdCascadedWithOrderCDBATestSpec()); + + insertTableData(3, row("d-1", 1.0, 1)); + insertTableData(3, row("d-2", 1.0, 2)); + insertTableData(3, row("d-3", 1.0, 3)); + insertTableData(3, row("d-4", 2.0, 4)); + insertTableData(3, row("d-5", 1.0, 89)); + insertTableData(3, row("d-99", 99.0, 99)); + + insertTableData(1, row(1, 1.0, "b-1")); + insertTableData(1, row(1, 2.0, "b-2")); + insertTableData(1, row(1, 3.0, "b-3")); + insertTableData(1, row(2, 4.0, "b-1")); + insertTableData(1, row(2, 5.0, "b-2")); + insertTableData(1, row(2, 6.0, "b-3")); + insertTableData(1, row(3, 7.0, "b-1")); + insertTableData(1, row(3, 8.0, "b-2")); + insertTableData(1, row(3, 9.0, "b-3")); + insertTableData(1, row(4, 10.0, "b-4")); + insertTableData(1, row(199, 199.0, "b-199")); + + insertTableData(0, row(1.0, 1, "a-1")); + insertTableData(0, row(2.0, 2, "a-2")); + insertTableData(0, row(3.0, 3, "a-3")); + insertTableData(0, row(4.0, 4, "a-4")); + insertTableData(0, row(299.0, 299, "a-299")); + + StreamRecord leftInput1 = insertRecord(1, "c-1", 1.0); + testHarness.processElement1(leftInput1); + testHarness.endAllInputs(); + + // /- b-1 -- a-1 x filtered + // - d-1 -- b-2 -- a-1 x filtered + // / \- b-3 -- a-1 x filtered + // / /- b-1 -- a-2 + // c-1 ---- d-2 -- b-2 -- a-2 x filtered + // \ \- b-3 -- a-2 + // | \ /- b-1 -- a-3 + // \ - d-3 -- b-2 -- a-3 x filtered + // | \- b-3 -- a-3 + // \ + // - d-5 -- N/A -- N/A x filtered + final ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue<>(); + expectedOutput.add( + insertRecord(1, "c-1", 1.0, "d-2", 1.0, 2, 2.0, 2, "a-2", 2, 4.0, "b-1")); + expectedOutput.add( + insertRecord(1, "c-1", 1.0, "d-2", 1.0, 2, 2.0, 2, "a-2", 2, 6.0, "b-3")); + expectedOutput.add( + insertRecord(1, "c-1", 1.0, "d-3", 1.0, 3, 3.0, 3, "a-3", 3, 7.0, "b-1")); + expectedOutput.add( + insertRecord(1, "c-1", 1.0, "d-3", 1.0, 3, 3.0, 3, "a-3", 3, 9.0, "b-3")); + + 
assertor.assertOutputEqualsSorted( + "result mismatch", expectedOutput, testHarness.getOutput()); + + // all are filtered by "c2 <> 'c-2'" + StreamRecord leftInput2 = updateAfterRecord(1, "c-2", 1.0); + testHarness.processElement1(leftInput2); + testHarness.endAllInputs(); + + assertor.assertOutputEqualsSorted( + "result mismatch", expectedOutput, testHarness.getOutput()); + + StreamRecord leftInput3 = updateAfterRecord(1, "c-3", 1.0); + testHarness.processElement1(leftInput3); + testHarness.endAllInputs(); + + expectedOutput.add( + updateAfterRecord(1, "c-3", 1.0, "d-2", 1.0, 2, 2.0, 2, "a-2", 2, 4.0, "b-1")); + expectedOutput.add( + updateAfterRecord(1, "c-3", 1.0, "d-2", 1.0, 2, 2.0, 2, "a-2", 2, 6.0, "b-3")); + expectedOutput.add( + updateAfterRecord(1, "c-3", 1.0, "d-3", 1.0, 3, 3.0, 3, "a-3", 3, 7.0, "b-1")); + expectedOutput.add( + updateAfterRecord(1, "c-3", 1.0, "d-3", 1.0, 3, 3.0, 3, "a-3", 3, 9.0, "b-3")); + assertor.assertOutputEqualsSorted( + "result mismatch", expectedOutput, testHarness.getOutput()); + + if (enableCache) { + assertThat(MyAsyncFunction.getLookupInvokeCount().get(0).get()).isEqualTo(3); + assertThat(MyAsyncFunction.getLookupInvokeCount().get(1).get()).isEqualTo(4); + assertThat(MyAsyncFunction.getLookupInvokeCount().get(2)).isNull(); + assertThat(MyAsyncFunction.getLookupInvokeCount().get(3).get()).isEqualTo(1); + } else { + assertThat(MyAsyncFunction.getLookupInvokeCount().get(0).get()).isEqualTo(9); + assertThat(MyAsyncFunction.getLookupInvokeCount().get(1).get()).isEqualTo(12); + assertThat(MyAsyncFunction.getLookupInvokeCount().get(2)).isNull(); + assertThat(MyAsyncFunction.getLookupInvokeCount().get(3).get()).isEqualTo(3); + } + } + + /** + * The Join tree used to test is as following. + * + *
+     *             DT2
+     *           /    \
+     *         DT1     C
+     *       /    \
+     *      A      B
+     *
+     *    when records from C come, the lookup chain is as follows:
+     *    C -> B -> A
+     * 
+ * + *

Here we mainly test DT2 while two inputs come data. + */ + @TestTemplate + void testLookupFunctionThrowsException() throws Exception { + prepareEnv(new LHSTestSpec(false, false, true)); + + AtomicReference latestException = new AtomicReference<>(null); + // set external failure cause consumer to prevent hang + // DO NOT throw exception up again to avoid hang + testHarness + .getEnvironment() + // DO NOT throw exception up again to avoid hang + .setExternalFailureCauseConsumer(latestException::set); + + insertTableData(1, row(1, 1.0, "b-2")); + + StreamRecord rightRecordK1V1 = insertRecord(1, "c-1", 1000.0); + testHarness.processElement2(rightRecordK1V1); + + testHarness.endAllInputs(); + + // Exception: Could not complete the stream element: ... + // +- RuntimeException: Failed to look up table + // +- ExpectedTestException + assertThat(latestException.get()).isNotNull(); + assertThat(latestException.get().getCause().getCause()) + .isInstanceOf(ExpectedTestException.class); + } + + /** + * The Join tree used to test is as following. + * + *

+     *             DT2
+     *           /    \
+     *         DT1     C
+     *       /    \
+     *      A      B
+     *
+     *    when records from C come, the lookup chain is as follows:
+     *    C -> B -> A
+     * 
+ * + *

Here we mainly test DT2 while C comes data. + */ + @TestTemplate + void testOmitCalcCollectorWhenLookupIfNecessary() throws Exception { + insertTableData(0, row(1000.0, 1, "a")); + insertTableData(1, row(1, 1000.0, "b")); + + StreamRecord rightRecord = insertRecord(1, "c", 1000.0); + + // if there are no calc between source and delta join, the calc collector can be omitted + prepareEnv(new LHSTestSpec(false, false, false)); + testHarness.processElement2(rightRecord); + testHarness.endAllInputs(); + + MyAsyncFunction rightAsyncFunc1 = unwrapAsyncFunctions(operator, false, 0).get(0); + validateCalcFunctionAndCollectorWhenLookup(rightAsyncFunc1.getLastResultFuture(), true); + MyAsyncFunction rightAsyncFunc2 = unwrapAsyncFunctions(operator, false, 1).get(0); + validateCalcFunctionAndCollectorWhenLookup(rightAsyncFunc2.getLastResultFuture(), true); + testHarness.close(); + + // if there is a calc between source and delta join, the calc collector cannot be omitted + prepareEnv(new LHSTestSpec(true, true, false)); + testHarness.processElement2(rightRecord); + testHarness.endAllInputs(); + + rightAsyncFunc1 = unwrapAsyncFunctions(operator, false, 0).get(0); + validateCalcFunctionAndCollectorWhenLookup(rightAsyncFunc1.getLastResultFuture(), false); + rightAsyncFunc2 = unwrapAsyncFunctions(operator, false, 1).get(0); + validateCalcFunctionAndCollectorWhenLookup(rightAsyncFunc2.getLastResultFuture(), false); + testHarness.close(); + } + + /** + * The Join tree used to test is as following. + * + *

+     *       DT2
+     *     /    \
+     *    C     DT1
+     *        /    \
+     *       A      B
+     *
+     *    when records from C come, the lookup chain is as follows:
+     *    C -> B -> A
+     * 
+ * + *

Here we mainly test DT2 while two inputs come data. + */ + @TestTemplate + void testRowDataSerializerAreAlwaysSameInCalcCollector() throws Exception { + prepareEnv(new RHSTestSpec(true, true)); + + insertTableData(0, row(1000.0, 1, "a1")); + insertTableData(0, row(2000.0, 2, "a2")); + insertTableData(1, row(1, 1000.0, "b1")); + insertTableData(1, row(2, 2000.0, "b2")); + + StreamRecord leftRecord1 = insertRecord(1, "c1", 1001.0); + testHarness.processElement1(leftRecord1); + testHarness.endAllInputs(); + + MyAsyncFunction firstHandlerAsyncFunc1 = unwrapAsyncFunctions(operator, true, 0).get(0); + assertThat(firstHandlerAsyncFunc1.getLastResultFuture()).isNotNull(); + Object2RowDataConverterResultFuture firstHandlerResultFuture1 = + (Object2RowDataConverterResultFuture) firstHandlerAsyncFunc1.getLastResultFuture(); + + MyAsyncFunction secondHandlerAsyncFunc1 = unwrapAsyncFunctions(operator, true, 1).get(0); + assertThat(secondHandlerAsyncFunc1.getLastResultFuture()).isNotNull(); + Object2RowDataConverterResultFuture secondHandlerResultFuture1 = + (Object2RowDataConverterResultFuture) secondHandlerAsyncFunc1.getLastResultFuture(); + + DeltaJoinRuntimeTree joinTree1 = + unwrapProcessor(operator, true).get(0).getMultiInputRowDataBuffer().getJoinTree(); + + StreamRecord leftRecord2 = insertRecord(2, "c2", 2002.0); + testHarness.processElement1(leftRecord2); + testHarness.endAllInputs(); + + final ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue<>(); + expectedOutput.add(insertRecord(1, "c1", 1001.0, 1000.0, 1, "a1", 1, 1000.0, "b1")); + expectedOutput.add(insertRecord(2, "c2", 2002.0, 2000.0, 2, "a2", 2, 2000.0, "b2")); + assertor.assertOutputEqualsSorted( + "result mismatch", expectedOutput, testHarness.getOutput()); + + MyAsyncFunction firstHandlerAsyncFunc2 = unwrapAsyncFunctions(operator, true, 0).get(1); + assertThat(firstHandlerAsyncFunc2.getLastResultFuture()).isNotNull(); + Object2RowDataConverterResultFuture firstHandlerResultFuture2 = + 
(Object2RowDataConverterResultFuture) firstHandlerAsyncFunc2.getLastResultFuture(); + + MyAsyncFunction secondHandlerAsyncFunc2 = unwrapAsyncFunctions(operator, true, 1).get(1); + assertThat(secondHandlerAsyncFunc2.getLastResultFuture()).isNotNull(); + Object2RowDataConverterResultFuture secondHandlerResultFuture2 = + (Object2RowDataConverterResultFuture) secondHandlerAsyncFunc2.getLastResultFuture(); + + DeltaJoinRuntimeTree joinTree2 = + unwrapProcessor(operator, true).get(1).getMultiInputRowDataBuffer().getJoinTree(); + + // validate first lookup handler + assertThat(firstHandlerResultFuture1).isNotSameAs(firstHandlerResultFuture2); + assertThat(firstHandlerResultFuture1.getCalcCollector()).isNotNull(); + assertThat(firstHandlerResultFuture2.getCalcCollector()).isNotNull(); + assertThat(firstHandlerResultFuture1.getCalcCollector().getLookupResultRowSerializer()) + .isSameAs( + firstHandlerResultFuture2 + .getCalcCollector() + .getLookupResultRowSerializer()); + + // validate second lookup handler + assertThat(secondHandlerResultFuture1).isNotSameAs(secondHandlerResultFuture2); + assertThat(secondHandlerResultFuture1.getCalcCollector()).isNotNull(); + assertThat(secondHandlerResultFuture2.getCalcCollector()).isNotNull(); + assertThat(secondHandlerResultFuture1.getCalcCollector().getLookupResultRowSerializer()) + .isSameAs( + secondHandlerResultFuture2 + .getCalcCollector() + .getLookupResultRowSerializer()); + + // validate join tree + // + // Join without calc + // / \ + // Binary with calc Join without calc + // / \ + // Binary with calc Binary with calc + assertThat(joinTree1).isNotSameAs(joinTree2); + + JoinNode topJoin1 = (JoinNode) joinTree1.root; + JoinNode topJoin2 = (JoinNode) joinTree2.root; + assertThat(topJoin1.rowDataSerializerPassThroughCalc) + .isSameAs(topJoin2.rowDataSerializerPassThroughCalc); + assertThat(topJoin1.left.rowDataSerializerPassThroughCalc) + .isSameAs(topJoin2.left.rowDataSerializerPassThroughCalc); + + JoinNode bottomJoin1 = 
(JoinNode) topJoin1.right; + JoinNode bottomJoin2 = (JoinNode) topJoin2.right; + assertThat(bottomJoin1.rowDataSerializerPassThroughCalc) + .isSameAs(bottomJoin2.rowDataSerializerPassThroughCalc); + assertThat(bottomJoin1.left.rowDataSerializerPassThroughCalc) + .isSameAs(bottomJoin2.left.rowDataSerializerPassThroughCalc); + assertThat(bottomJoin1.right.rowDataSerializerPassThroughCalc) + .isSameAs(bottomJoin2.right.rowDataSerializerPassThroughCalc); + } + + /** Abstract test specification for cascaded delta join operator tests. */ + private abstract static class CascadedTestSpec extends AbstractTestSpec { + + abstract int[] getEachBinaryInputFieldSize(); + + abstract DeltaJoinHandlerChain getLeft2RightHandlerChain( + Map>> fetcherCollector); + + abstract DeltaJoinHandlerChain getRight2LeftHandlerChain( + Map>> fetcherCollector); + + abstract @Nullable GeneratedFilterCondition getRemainingJoinCondition(); + + abstract DeltaJoinRuntimeTree getJoinRuntimeTree(); + + abstract Set> getLeft2RightDrivenSideInfo(); + + abstract Set> getRight2LeftDrivenSideInfo(); + + void insertTableDataOnEmit(int inputIdx, RowData rowData) {} + } + + private void prepareEnv(CascadedTestSpec spec) throws Exception { + testHarness = createCascadedDeltaJoinOperatorTestHarness(spec); + + testHarness.setup(); + testHarness.open(); + operator = unwrapOperator(testHarness); + operator.setAsyncExecutionController( + new MyAsyncExecutionControllerDelegate( + operator.getAsyncExecutionController(), true, spec::insertTableDataOnEmit)); + + assertor = createAssertor(spec.getOutputRowType()); + } + + private KeyedTwoInputStreamOperatorTestHarness + createCascadedDeltaJoinOperatorTestHarness(CascadedTestSpec spec) throws Exception { + Map>> fetcherCollector = + new HashMap<>(); + DeltaJoinHandlerChain left2RightHandlerChain = + spec.getLeft2RightHandlerChain(fetcherCollector); + DeltaJoinHandlerChain right2LeftHandlerChain = + spec.getRight2LeftHandlerChain(fetcherCollector); + + return 
createDeltaJoinOperatorTestHarness( + spec.getEachBinaryInputFieldSize(), + left2RightHandlerChain, + right2LeftHandlerChain, + spec.getRemainingJoinCondition(), + spec.getJoinRuntimeTree(), + spec.getLeft2RightDrivenSideInfo(), + spec.getRight2LeftDrivenSideInfo(), + spec.getLeftJoinKeySelector(), + spec.getLeftUpsertKeySelector(), + spec.getRightJoinKeySelector(), + spec.getRightUpsertKeySelector(), + fetcherCollector, + spec.getLeftInputTypeInfo(), + spec.getRightInputTypeInfo(), + enableCache); + } + + /** + * Test spec for LHS cascaded join. + * + *

If there is no filter on binary input, the sql is: + + 

+     *  insert into snk
+     *      select *
+     *          from A
+     *      join B
+     *          on a1 = b1 and b2 <> 'b-2'
+     *      join C
+     *          on c1 = b1 and c2 <> 'c-1'
+     * 
+ * + *

If there is a filter on binary input and no calc on cascaded joins, the sql is: + * + *

+     *  insert into snk
+     *      select * from (
+     *          select * from A where a0 >= 100.0
+     *      )
+     *      join (
+     *          select * from B where b0 >= 200.0
+     *      )
+     *      on a1 = b1 and b2 <> 'b-2'
+     *      join (
+     *          select * from C where c0 >= 300.0
+     *      )
+     *      on c1 = b1 and c2 <> 'c-1'
+     * 
+ * + *

If there is a filter on binary input and a calc on cascaded joins, the sql is: + * + *

+     *  insert into snk
+     *      select * from (
+     *          select * from (
+     *              select * from A where a0 >= 100.0
+     *          ) join (
+     *              select * from B where b0 >= 200.0
+     *          )
+     *          on a1 = b1 and b2 <> 'b-2'
+     *          where a0 + b0 >= 900.0
+     *      )
+     *      join (
+     *          select * from C where c0 >= 300.0
+     *      )
+     *      on c1 = b1 and c2 <> 'c-1'
+     * 
+ * + *

The join tree is like: + * + *

+     *             DT2
+     *           /    \
+     *         DT1     C
+     *       /    \
+     *      A      B
+     * 
+ */ + private class LHSTestSpec extends CascadedTestSpec { + + private final boolean containsFilterOnTable; + private final boolean containsFilterBetweenCascadedJoins; + private final boolean throwExceptionWhenLookingUpFromB2A; + + LHSTestSpec( + boolean containsFilterOnTable, + boolean containsFilterBetweenCascadedJoins, + boolean throwExceptionWhenLookingUpFromB2A) { + Preconditions.checkArgument( + containsFilterOnTable || !containsFilterBetweenCascadedJoins, + "Unsupported pattern in LHSTestSpec"); + this.containsFilterOnTable = containsFilterOnTable; + this.containsFilterBetweenCascadedJoins = containsFilterBetweenCascadedJoins; + this.throwExceptionWhenLookingUpFromB2A = throwExceptionWhenLookingUpFromB2A; + } + + @Override + int[] getEachBinaryInputFieldSize() { + return new int[] {3, 3, 3}; + } + + @Override + RowType getLeftInputRowType() { + return combineSourceRowTypes(0, 1); + } + + @Override + RowType getRightInputRowType() { + return tableRowTypeMap.get(2); + } + + @Override + int[] getLeftJoinKeyIndices() { + // left jk: b1 + return new int[] {3}; + } + + @Override + Optional getLeftUpsertKey() { + // left uk: none + return Optional.empty(); + } + + @Override + int[] getRightJoinKeyIndices() { + // right jk: c1 + return new int[] {0}; + } + + @Override + Optional getRightUpsertKey() { + // right uk: c0 + return Optional.of(new int[] {2}); + } + + @Override + DeltaJoinHandlerChain getLeft2RightHandlerChain( + Map>> fetcherCollector) { + RowType dt1OutType = combineSourceRowTypes(0, 1); + RowType cInOutType = tableRowTypeMap.get(2); + GeneratedFunction> generatedCalcOnC = + containsFilterOnTable ? 
createDoubleGreaterThanFilter(300.0, 2) : null; + return buildBinaryChain( + LookupTestSpec.builder() + .withSourceInputs(0, 1) + .withTargetInput(2) + .withTargetTableIdx(2) + .withSourceLookupKeyIdx(3) + .withSourceRowType(dt1OutType) + .withTargetLookupKeyIdx(0) + .withTargetRowType(cInOutType) + .withTargetGeneratedCalc(generatedCalcOnC) + .build(), + fetcherCollector); + } + + @Override + DeltaJoinHandlerChain getRight2LeftHandlerChain( + Map>> fetcherCollector) { + RowType aInOutType = tableRowTypeMap.get(0); + RowType bInOutType = tableRowTypeMap.get(1); + RowType cInOutType = tableRowTypeMap.get(2); + GeneratedFunction> generatedCalcOnA = + containsFilterOnTable ? createDoubleGreaterThanFilter(100.0, 0) : null; + GeneratedFunction> generatedCalcOnB = + containsFilterOnTable ? createDoubleGreaterThanFilter(200.0, 1) : null; + return buildCascadedChain( + Arrays.asList( + LookupTestSpec.builder() + .withSourceInputs(2) + .withTargetInput(1) + .withTargetTableIdx(1) + .withSourceLookupKeyIdx(0) + .withSourceRowType(cInOutType) + .withTargetLookupKeyIdx(0) + .withTargetRowType(bInOutType) + .withTargetGeneratedCalc(generatedCalcOnB) + .build(), + LookupTestSpec.builder() + .withSourceInputs(1) + .withTargetInput(0) + .withTargetTableIdx(0) + .withSourceLookupKeyIdx(0) + .withSourceRowType(bInOutType) + .withTargetLookupKeyIdx(1) + .withTargetRowType(aInOutType) + .withTargetGeneratedCalc(generatedCalcOnA) + // b2 <> 'b-2' + .withGeneratedRemainingCondition( + getFilterCondition( + // A + B + combineSourceRowTypes(0, 1), + new int[] {5}, + lookupResult -> + !lookupResult + .getString(0) + .toString() + .equals("b-2"))) + .expectedThrownException( + throwExceptionWhenLookingUpFromB2A + ? 
new ExpectedTestException() + : null) + .build()), + new int[] {0, 1}, + new int[] {2}, + fetcherCollector); + } + + @Override + @Nullable + GeneratedFilterCondition getRemainingJoinCondition() { + // c1 = b1 and c2 <> 'c-1' + return getFilterCondition( + // A + B + C + getOutputRowType(), + new int[] {6, 3, 7}, + row -> + row.getInt(0) == row.getInt(1) + && !row.getString(2).toString().equals("c-1")); + } + + @Override + DeltaJoinRuntimeTree getJoinRuntimeTree() { + RowType aInOutType = tableRowTypeMap.get(0); + RowType bInOutType = tableRowTypeMap.get(1); + RowType cInOutType = tableRowTypeMap.get(2); + + GeneratedFunction> generatedCalcOnA = + containsFilterOnTable ? createDoubleGreaterThanFilter(100.0, 0) : null; + GeneratedFunction> generatedCalcOnB = + containsFilterOnTable ? createDoubleGreaterThanFilter(200.0, 1) : null; + GeneratedFunction> generatedCalcOnC = + containsFilterOnTable ? createDoubleGreaterThanFilter(300.0, 2) : null; + + DeltaJoinRuntimeTree.BinaryInputNode nodeA = + new DeltaJoinRuntimeTree.BinaryInputNode( + 0, generatedCalcOnA, InternalSerializers.create(aInOutType)); + DeltaJoinRuntimeTree.BinaryInputNode nodeB = + new DeltaJoinRuntimeTree.BinaryInputNode( + 1, generatedCalcOnB, InternalSerializers.create(bInOutType)); + DeltaJoinRuntimeTree.BinaryInputNode nodeC = + new DeltaJoinRuntimeTree.BinaryInputNode( + 2, generatedCalcOnC, InternalSerializers.create(cInOutType)); + + GeneratedFunction> generatedCalcOnDT1 = + containsFilterBetweenCascadedJoins + ? createFlatMap( + rowData -> + rowData.getDouble(0) + rowData.getDouble(4) > 900.0 + ? 
Optional.of(rowData) + : Optional.empty()) + : null; + + DeltaJoinRuntimeTree.JoinNode nodeDT1 = + new DeltaJoinRuntimeTree.JoinNode( + FlinkJoinType.INNER, + getFilterCondition( + getLeftInputRowType(), + new int[] {1, 3, 5}, + row -> + row.getInt(0) == row.getInt(1) + && !row.getString(2).toString().equals("b-2")), + generatedCalcOnDT1, + nodeA, + nodeB, + InternalSerializers.create(getLeftInputRowType())); + DeltaJoinRuntimeTree.JoinNode nodeDT2 = + new DeltaJoinRuntimeTree.JoinNode( + FlinkJoinType.INNER, + getFilterCondition( + getOutputRowType(), + new int[] {6, 3, 7}, + row -> + row.getInt(0) == row.getInt(1) + && !row.getString(2).toString().equals("c-1")), + null, + nodeDT1, + nodeC, + InternalSerializers.create(getOutputRowType())); + return new DeltaJoinRuntimeTree(nodeDT2); + } + + @Override + Set> getLeft2RightDrivenSideInfo() { + return Set.of(Set.of(0, 1)); + } + + @Override + Set> getRight2LeftDrivenSideInfo() { + return Set.of(Set.of(2), Set.of(1)); + } + + @Override + void insertTableDataOnEmit(int inputIdx, RowData rowData) { + if (inputIdx == 0) { + // split A and B + RowType dt1OutType = combineSourceRowTypes(0, 1); + RowData dataOnA = projectRowData(rowData, dt1OutType, new int[] {0, 1, 2}); + insertTableData(0, dataOnA); + RowData dataOnB = projectRowData(rowData, dt1OutType, new int[] {3, 4, 5}); + insertTableData(1, dataOnB); + } else { + insertTableData(2, rowData); + } + } + } + + /** + * Test spec for RHS cascaded join. + * + *

If there is no filter on binary input, the sql is: + + 

+     *  insert into snk
+     *      select *
+     *          from C
+     *      join (
+     *          select * from A
+     *          join B
+     *          on a1 = b1 and b2 <> 'b-2'
+     *      )
+     *      on c1 = b1 and c2 <> 'c-1'
+     * 
+ * + *

If there is a filter on binary input, the sql is: + * + *

+     *  insert into snk
+     *      select *
+     *          from (
+     *              select * from C where c0 > 300.0
+     *          )
+     *      join (
+     *          select * from (
+     *              select * from A where a0 > 100.0
+     *          )
+     *          join (
+     *              select * from B where b0 > 200.0
+     *          )
+     *          on a1 = b1 and b2 <> 'b-2'
+     *      )
+     *      on c1 = b1 and c2 <> 'c-1'
+     * 
+ * + *

If there is a filter on binary input and a calc on cascaded joins, the sql is: + * + *

+     *  insert into snk
+     *      select *
+     *          from (
+     *              select * from C where c0 > 300.0
+     *          )
+     *      join (
+     *          select * from (
+     *              select * from A where a0 > 100.0
+     *          )
+     *          join (
+     *              select * from B where b0 > 200.0
+     *          )
+     *          on a1 = b1 and b2 <> 'b-2'
+     *          where a0 + b0 >= 900.0
+     *      )
+     *      on c1 = b1 and c2 <> 'c-1'
+     * 
+ * + *

The join tree is like: + * + *

+     *       DT2
+     *     /    \
+     *    C     DT1
+     *        /    \
+     *       A      B
+     * 
+ */ + private class RHSTestSpec extends CascadedTestSpec { + + private final boolean containsFilterOnTable; + private final boolean containsFilterBetweenCascadedJoins; + + RHSTestSpec(boolean containsFilterOnTable, boolean containsFilterBetweenCascadedJoins) { + Preconditions.checkArgument( + containsFilterOnTable || !containsFilterBetweenCascadedJoins, + "Unsupported pattern in RHSTestSpec"); + this.containsFilterOnTable = containsFilterOnTable; + this.containsFilterBetweenCascadedJoins = containsFilterBetweenCascadedJoins; + } + + @Override + int[] getEachBinaryInputFieldSize() { + return new int[] {3, 3, 3}; + } + + @Override + RowType getLeftInputRowType() { + return tableRowTypeMap.get(2); + } + + @Override + RowType getRightInputRowType() { + return combineSourceRowTypes(0, 1); + } + + @Override + int[] getLeftJoinKeyIndices() { + // left jk: c1 + return new int[] {0}; + } + + @Override + int[] getRightJoinKeyIndices() { + // right jk: b1 + return new int[] {3}; + } + + @Override + Optional getLeftUpsertKey() { + // left uk: c0 + return Optional.of(new int[] {2}); + } + + @Override + Optional getRightUpsertKey() { + // right uk: none + return Optional.empty(); + } + + @Override + DeltaJoinHandlerChain getLeft2RightHandlerChain( + Map>> fetcherCollector) { + RowType aInOutType = tableRowTypeMap.get(0); + RowType bInOutType = tableRowTypeMap.get(1); + RowType cInOutType = tableRowTypeMap.get(2); + GeneratedFunction> generatedCalcOnA = + containsFilterOnTable ? createDoubleGreaterThanFilter(100.0, 0) : null; + GeneratedFunction> generatedCalcOnB = + containsFilterOnTable ? 
createDoubleGreaterThanFilter(200.0, 1) : null; + return buildCascadedChain( + Arrays.asList( + LookupTestSpec.builder() + .withSourceInputs(0) + .withTargetInput(2) + .withTargetTableIdx(1) + .withSourceLookupKeyIdx(0) + .withSourceRowType(cInOutType) + .withTargetLookupKeyIdx(0) + .withTargetRowType(bInOutType) + .withTargetGeneratedCalc(generatedCalcOnB) + .build(), + LookupTestSpec.builder() + .withSourceInputs(2) + .withTargetInput(1) + .withTargetTableIdx(0) + .withSourceLookupKeyIdx(0) + .withSourceRowType(bInOutType) + .withTargetLookupKeyIdx(1) + .withTargetRowType(aInOutType) + .withTargetGeneratedCalc(generatedCalcOnA) + // b2 <> 'b-2' + .withGeneratedRemainingCondition( + getFilterCondition( + // A + B + combineSourceRowTypes(0, 1), + new int[] {5}, + lookupResult -> + !lookupResult + .getString(0) + .toString() + .equals("b-2"))) + .build()), + new int[] {1, 2}, + new int[] {0}, + fetcherCollector); + } + + @Override + DeltaJoinHandlerChain getRight2LeftHandlerChain( + Map>> fetcherCollector) { + RowType cInOutType = tableRowTypeMap.get(2); + GeneratedFunction> generatedCalcOnC = + containsFilterOnTable ? 
createDoubleGreaterThanFilter(300.0, 2) : null; + return buildBinaryChain( + LookupTestSpec.builder() + .withSourceInputs(1, 2) + .withTargetInput(0) + .withTargetTableIdx(2) + .withSourceLookupKeyIdx(3) + .withSourceRowType(getRightInputRowType()) + .withTargetLookupKeyIdx(0) + .withTargetRowType(cInOutType) + .withTargetGeneratedCalc(generatedCalcOnC) + .build(), + fetcherCollector); + } + + @Override + @Nullable + GeneratedFilterCondition getRemainingJoinCondition() { + // c1 = b1 and c2 <> 'c-1' + return getFilterCondition( + // C + A & B + getOutputRowType(), + new int[] {0, 6, 1}, + row -> + row.getInt(0) == row.getInt(1) + && !row.getString(2).toString().equals("c-1")); + } + + @Override + DeltaJoinRuntimeTree getJoinRuntimeTree() { + RowType aInOutType = tableRowTypeMap.get(0); + RowType bInOutType = tableRowTypeMap.get(1); + RowType dt1OutType = combineSourceRowTypes(0, 1); + RowType cInOutType = tableRowTypeMap.get(2); + + GeneratedFunction> generatedCalcOnA = + containsFilterOnTable ? createDoubleGreaterThanFilter(100.0, 0) : null; + GeneratedFunction> generatedCalcOnB = + containsFilterOnTable ? createDoubleGreaterThanFilter(200.0, 1) : null; + GeneratedFunction> generatedCalcOnC = + containsFilterOnTable ? createDoubleGreaterThanFilter(300.0, 2) : null; + + DeltaJoinRuntimeTree.BinaryInputNode nodeA = + new DeltaJoinRuntimeTree.BinaryInputNode( + 1, generatedCalcOnA, InternalSerializers.create(aInOutType)); + DeltaJoinRuntimeTree.BinaryInputNode nodeB = + new DeltaJoinRuntimeTree.BinaryInputNode( + 2, generatedCalcOnB, InternalSerializers.create(bInOutType)); + DeltaJoinRuntimeTree.BinaryInputNode nodeC = + new DeltaJoinRuntimeTree.BinaryInputNode( + 0, generatedCalcOnC, InternalSerializers.create(cInOutType)); + + GeneratedFunction> generatedCalcOnDT1 = + containsFilterBetweenCascadedJoins + ? createFlatMap( + rowData -> + rowData.getDouble(0) + rowData.getDouble(4) > 900.0 + ? 
Optional.of(rowData) + : Optional.empty()) + : null; + DeltaJoinRuntimeTree.JoinNode nodeDT1 = + new DeltaJoinRuntimeTree.JoinNode( + FlinkJoinType.INNER, + getFilterCondition( + dt1OutType, + new int[] {1, 3, 5}, + row -> + row.getInt(0) == row.getInt(1) + && !row.getString(2).toString().equals("b-2")), + generatedCalcOnDT1, + nodeA, + nodeB, + InternalSerializers.create(dt1OutType)); + + DeltaJoinRuntimeTree.JoinNode nodeDT2 = + new DeltaJoinRuntimeTree.JoinNode( + FlinkJoinType.INNER, + // c1 = b1 and c2 <> 'c-1' + getFilterCondition( + // C + A & B + getOutputRowType(), + new int[] {0, 6, 1}, + row -> + row.getInt(0) == row.getInt(1) + && !row.getString(2).toString().equals("c-1")), + null, + nodeC, + nodeDT1, + InternalSerializers.create(getOutputRowType())); + + return new DeltaJoinRuntimeTree(nodeDT2); + } + + @Override + Set> getLeft2RightDrivenSideInfo() { + return Set.of(Set.of(0), Set.of(2)); + } + + @Override + Set> getRight2LeftDrivenSideInfo() { + return Set.of(Set.of(1, 2)); + } + + @Override + void insertTableDataOnEmit(int inputIdx, RowData rowData) { + if (inputIdx == 1) { + // split A and B + RowType dt1OutType = combineSourceRowTypes(0, 1); + RowData dataOnA = projectRowData(rowData, dt1OutType, new int[] {0, 1, 2}); + insertTableData(0, dataOnA); + RowData dataOnB = projectRowData(rowData, dt1OutType, new int[] {3, 4, 5}); + insertTableData(1, dataOnB); + } else { + insertTableData(2, rowData); + } + } + } + + /** + * Test spec for third-cascaded join with lookup order C -> D -> B -> A. + * + *

The sql is like: + * + *

+     *  insert into snk
+     *      select * from C
+     *      join (
+     *          select * from D
+     *          join (
+     *              select * from A
+     *              join B
+     *              on a1 = b1 and b2 <> 'b-2'
+     *          )
+     *          on d1 = b1
+     *      )
+     *      on d0 = c0 and d2 <> 'd-1' and c2 <> 'c-2'
+     * 
+ * + *

The join tree is like: + * + *

+     *     DT3
+     *   /    \
+     *  C     DT2
+     *     /    \
+     *    D     DT1
+     *        /    \
+     *       A      B
+     * 
+ */ + private class ThirdCascadedWithOrderCDBATestSpec extends CascadedTestSpec { + + @Override + int[] getEachBinaryInputFieldSize() { + return new int[] {3, 3, 3, 3}; + } + + @Override + RowType getLeftInputRowType() { + return tableRowTypeMap.get(2); + } + + @Override + RowType getRightInputRowType() { + return combineSourceRowTypes(3, 0, 1); + } + + @Override + int[] getLeftJoinKeyIndices() { + // left jk: c0 + return new int[] {2}; + } + + @Override + int[] getRightJoinKeyIndices() { + // right jk: d0 + return new int[] {1}; + } + + @Override + Optional getLeftUpsertKey() { + // left uk: c0 + return Optional.of(new int[] {2}); + } + + @Override + Optional getRightUpsertKey() { + // right uk: none + return Optional.empty(); + } + + @Override + DeltaJoinHandlerChain getLeft2RightHandlerChain( + Map>> fetcherCollector) { + RowType aInOutType = tableRowTypeMap.get(0); + RowType bInOutType = tableRowTypeMap.get(1); + RowType cInOutType = tableRowTypeMap.get(2); + RowType dInOutType = tableRowTypeMap.get(3); + return buildCascadedChain( + Arrays.asList( + LookupTestSpec.builder() + .withSourceInputs(0) + .withTargetInput(1) + .withTargetTableIdx(3) + .withSourceLookupKeyIdx(2) + .withSourceRowType(cInOutType) + .withTargetLookupKeyIdx(1) + .withTargetRowType(dInOutType) + .build(), + LookupTestSpec.builder() + .withSourceInputs(1) + .withTargetInput(3) + .withTargetTableIdx(1) + .withSourceLookupKeyIdx(2) + .withSourceRowType(dInOutType) + .withTargetLookupKeyIdx(0) + .withTargetRowType(bInOutType) + .build(), + LookupTestSpec.builder() + .withSourceInputs(3) + .withTargetInput(2) + .withTargetTableIdx(0) + .withSourceLookupKeyIdx(0) + .withSourceRowType(bInOutType) + .withTargetLookupKeyIdx(1) + .withTargetRowType(aInOutType) + // b2 <> 'b-2' + .withGeneratedRemainingCondition( + getFilterCondition( + // B + A + combineSourceRowTypes(0, 1), + new int[] {5}, + row -> + !row.getString(0) + .toString() + .equals("b-2"))) + .build()), + new int[] {1, 2, 3}, + new 
int[] {0}, + fetcherCollector); + } + + @Override + DeltaJoinHandlerChain getRight2LeftHandlerChain( + Map>> fetcherCollector) { + RowType dt2OutType = getRightInputRowType(); + RowType cInOutType = getLeftInputRowType(); + return buildBinaryChain( + LookupTestSpec.builder() + .withSourceInputs(1, 2, 3) + .withTargetInput(0) + .withTargetTableIdx(2) + .withSourceLookupKeyIdx(1) + .withSourceRowType(dt2OutType) + .withTargetLookupKeyIdx(2) + .withTargetRowType(cInOutType) + .build(), + fetcherCollector); + } + + @Override + @Nullable + GeneratedFilterCondition getRemainingJoinCondition() { + // d0 = c0 and d2 <> 'd-1' and c2 <> 'c-2' + return getFilterCondition( + // C + D & A & B + getOutputRowType(), + new int[] {2, 4, 3, 1}, + row -> + row.getDouble(0) == row.getDouble(1) + && !row.getString(2).toString().equals("d-1") + && !row.getString(3).toString().equals("c-2")); + } + + @Override + DeltaJoinRuntimeTree getJoinRuntimeTree() { + RowType aInOutType = tableRowTypeMap.get(0); + RowType bInOutType = tableRowTypeMap.get(1); + RowType cInOutType = tableRowTypeMap.get(2); + RowType dInOutType = tableRowTypeMap.get(3); + RowType dt1OutType = combineSourceRowTypes(0, 1); + RowType dt2OutType = getRightInputRowType(); + RowType dt3OutType = getOutputRowType(); + + DeltaJoinRuntimeTree.BinaryInputNode nodeA = + new DeltaJoinRuntimeTree.BinaryInputNode( + 2, null, InternalSerializers.create(aInOutType)); + DeltaJoinRuntimeTree.BinaryInputNode nodeB = + new DeltaJoinRuntimeTree.BinaryInputNode( + 3, null, InternalSerializers.create(bInOutType)); + DeltaJoinRuntimeTree.BinaryInputNode nodeC = + new DeltaJoinRuntimeTree.BinaryInputNode( + 0, null, InternalSerializers.create(cInOutType)); + DeltaJoinRuntimeTree.BinaryInputNode nodeD = + new DeltaJoinRuntimeTree.BinaryInputNode( + 1, null, InternalSerializers.create(dInOutType)); + DeltaJoinRuntimeTree.JoinNode nodeDT1 = + new DeltaJoinRuntimeTree.JoinNode( + FlinkJoinType.INNER, + getFilterCondition( + dt1OutType, + new 
int[] {1, 3, 5}, + row -> + row.getInt(0) == row.getInt(1) + && !row.getString(2).toString().equals("b-2")), + null, + nodeA, + nodeB, + InternalSerializers.create(dt1OutType)); + DeltaJoinRuntimeTree.JoinNode nodeDT2 = + new DeltaJoinRuntimeTree.JoinNode( + FlinkJoinType.INNER, + getFilterCondition( + dt2OutType, + new int[] {2, 6}, + row -> row.getInt(0) == row.getInt(1)), + null, + nodeD, + nodeDT1, + InternalSerializers.create(dt2OutType)); + DeltaJoinRuntimeTree.JoinNode nodeDT3 = + new DeltaJoinRuntimeTree.JoinNode( + FlinkJoinType.INNER, + getFilterCondition( + dt3OutType, + new int[] {2, 4, 3, 1}, + row -> + row.getDouble(0) == row.getDouble(1) + && !row.getString(2).toString().equals("d-1") + && !row.getString(3).toString().equals("c-2")), + null, + nodeC, + nodeDT2, + InternalSerializers.create(dt3OutType)); + return new DeltaJoinRuntimeTree(nodeDT3); + } + + @Override + Set> getLeft2RightDrivenSideInfo() { + return Set.of(Set.of(0), Set.of(1), Set.of(3)); + } + + @Override + Set> getRight2LeftDrivenSideInfo() { + return Set.of(Set.of(1, 2, 3)); + } + } + + private void insertTableData(int tableIdx, RowData data) { + try { + RowData uk = tableUpsertKeySelector.get(tableIdx).getKey(data); + + tableCurrentDataMap.compute( + tableIdx, + (k, v) -> { + if (v == null) { + v = new LinkedHashMap<>(); + } + v.put(uk, data); + return v; + }); + } catch (Exception e) { + throw new IllegalStateException("Failed to add data to table", e); + } + } + + private List unwrapAsyncFunctions( + StreamingDeltaJoinOperator operator, boolean unwrapLeft, int handlerIdxInChain) { + List processors = + unwrapProcessor(operator, unwrapLeft); + List results = new ArrayList<>(); + for (AsyncDeltaJoinRunner.DeltaJoinProcessor processor : processors) { + int idx = 0; + DeltaJoinHandlerBase handler = processor.getDeltaJoinHandlerChain().getHead(); + while (idx < handlerIdxInChain) { + handler = Objects.requireNonNull(handler.getNext()); + idx++; + } + if (handler instanceof 
LookupHandlerBase) { + results.add((MyAsyncFunction) ((LookupHandlerBase) handler).getFetcher()); + } else { + throw new IllegalStateException("The handler is not a lookup handler"); + } + } + return results; + } + + private List unwrapProcessor( + StreamingDeltaJoinOperator operator, boolean unwrapLeft) { + if (unwrapLeft) { + return operator.getLeftTriggeredUserFunction().getAllProcessors(); + } else { + return operator.getRightTriggeredUserFunction().getAllProcessors(); + } + } + + private RowData projectRowData(RowData rowData, RowType rowType, int[] fields) { + List inputTypes = rowType.getChildren(); + GenericRowData data = new GenericRowData(rowData.getRowKind(), fields.length); + for (int i = 0; i < fields.length; i++) { + RowData.FieldGetter fieldGetter = + RowData.createFieldGetter(inputTypes.get(fields[i]), fields[i]); + data.setField(i, fieldGetter.getFieldOrNull(rowData)); + } + return data; + } + + private RowType combineSourceRowTypes(int... sourceIdx) { + return combineRowTypes( + Arrays.stream(sourceIdx).boxed().map(tableRowTypeMap::get).toArray(RowType[]::new)); + } + + private void verifyCacheData( + CascadedTestSpec testSpec, + DeltaJoinCache actualCache, + Map> expectedLeftCacheData, + Map> expectedRightCacheData, + long expectedLeftCacheRequestCount, + long expectedLeftCacheHitCount, + long expectedRightCacheRequestCount, + long expectedRightCacheHitCount) { + // assert left cache + verifyCacheData( + actualCache, + expectedLeftCacheData, + expectedLeftCacheRequestCount, + expectedLeftCacheHitCount, + testSpec.getLeftJoinKeySelector().getProducedType().toRowType(), + testSpec.getLeftUpsertKeySelector().getProducedType().toRowType(), + testSpec.getLeftInputRowType(), + true); + + // assert right cache + verifyCacheData( + actualCache, + expectedRightCacheData, + expectedRightCacheRequestCount, + expectedRightCacheHitCount, + testSpec.getRightJoinKeySelector().getProducedType().toRowType(), + 
testSpec.getRightUpsertKeySelector().getProducedType().toRowType(), + testSpec.getRightInputRowType(), + false); + } + + private DeltaJoinHandlerChain buildBinaryChain( + LookupTestSpec lookupTestSpec, + Map>> fetcherCollector) { + fetcherCollector.put( + lookupTestSpec.targetInput, + createFetcherFunction( + tableCurrentDataMap, + getKeySelector( + lookupTestSpec.sourceLookupKeyIdx, lookupTestSpec.sourceRowType), + getKeySelector( + lookupTestSpec.targetLookupKeyIdx, lookupTestSpec.targetRowType), + lookupTestSpec.targetTableIdx, + lookupTestSpec.expectedThrownException)); + BinaryLookupHandler handler = + new BinaryLookupHandler( + toInternalDataType(lookupTestSpec.sourceRowType), + toInternalDataType(lookupTestSpec.targetRowType), + toInternalDataType(lookupTestSpec.targetRowType), + InternalSerializers.create(lookupTestSpec.targetRowType), + lookupTestSpec.targetGeneratedCalc, + lookupTestSpec.sourceInputs, + lookupTestSpec.targetInput); + return DeltaJoinHandlerChain.build( + Collections.singletonList(handler), lookupTestSpec.sourceInputs); + } + + private DeltaJoinHandlerChain buildCascadedChain( + List lookupChain, + int[] allLookupSideBinaryInputOrdinals, + int[] streamInputOrdinals, + Map>> fetcherCollector) { + Preconditions.checkArgument(lookupChain.size() > 1); + List handlers = new ArrayList<>(); + for (int i = 0; i < lookupChain.size(); i++) { + LookupTestSpec lookupTestSpec = lookupChain.get(i); + fetcherCollector.put( + lookupTestSpec.targetInput, + createFetcherFunction( + tableCurrentDataMap, + getKeySelector( + lookupTestSpec.sourceLookupKeyIdx, + lookupTestSpec.sourceRowType), + getKeySelector( + lookupTestSpec.targetLookupKeyIdx, + lookupTestSpec.targetRowType), + lookupTestSpec.targetTableIdx, + lookupTestSpec.expectedThrownException)); + + handlers.add( + new CascadedLookupHandler( + i + 1, + toInternalDataType(lookupTestSpec.sourceRowType), + toInternalDataType(lookupTestSpec.targetRowType), + 
toInternalDataType(lookupTestSpec.targetRowType), + InternalSerializers.create(lookupTestSpec.targetRowType), + lookupTestSpec.targetGeneratedCalc, + lookupTestSpec.generatedRemainingCondition, + getKeySelector( + lookupTestSpec.sourceLookupKeyIdx, + lookupTestSpec.sourceRowType), + lookupTestSpec.sourceInputs, + lookupTestSpec.targetInput, + Arrays.stream(lookupTestSpec.sourceInputs) + .allMatch(src -> src < lookupTestSpec.targetInput))); + } + + handlers.add(new TailOutputDataHandler(allLookupSideBinaryInputOrdinals)); + return DeltaJoinHandlerChain.build(handlers, streamInputOrdinals); + } + + private static class LookupTestSpec { + private final int[] sourceInputs; + private final int targetInput; + private final int targetTableIdx; + private final int[] sourceLookupKeyIdx; + private final RowType sourceRowType; + private final int[] targetLookupKeyIdx; + private final RowType targetRowType; + private final @Nullable GeneratedFilterCondition generatedRemainingCondition; + private final @Nullable GeneratedFunction> + targetGeneratedCalc; + private final @Nullable Throwable expectedThrownException; + + private LookupTestSpec( + int[] sourceInputs, + int targetInput, + int targetTableIdx, + int[] sourceLookupKeyIdx, + RowType sourceRowType, + int[] targetLookupKeyIdx, + RowType targetRowType, + @Nullable GeneratedFilterCondition generatedRemainingCondition, + @Nullable GeneratedFunction> targetGeneratedCalc, + @Nullable Throwable expectedThrownException) { + this.sourceInputs = sourceInputs; + this.targetInput = targetInput; + this.targetTableIdx = targetTableIdx; + this.sourceLookupKeyIdx = sourceLookupKeyIdx; + this.sourceRowType = sourceRowType; + this.targetLookupKeyIdx = targetLookupKeyIdx; + this.targetRowType = targetRowType; + this.generatedRemainingCondition = generatedRemainingCondition; + this.targetGeneratedCalc = targetGeneratedCalc; + this.expectedThrownException = expectedThrownException; + } + + public static Builder builder() { + return new 
Builder(); + } + + private static class Builder { + private int[] sourceInputs; + private int targetInput; + private int targetTableIdx; + private int[] sourceLookupKeyIdx; + private RowType sourceRowType; + private int[] targetLookupKeyIdx; + private RowType targetRowType; + private @Nullable GeneratedFilterCondition generatedRemainingCondition; + private @Nullable GeneratedFunction> + targetGeneratedCalc; + private @Nullable Throwable expectedThrownException; + + public Builder withSourceInputs(int... sources) { + this.sourceInputs = sources; + return this; + } + + public Builder withTargetInput(int target) { + this.targetInput = target; + return this; + } + + public Builder withTargetTableIdx(int targetTableIdx) { + this.targetTableIdx = targetTableIdx; + return this; + } + + public Builder withSourceLookupKeyIdx(int... sourceLookupKeyIdx) { + this.sourceLookupKeyIdx = sourceLookupKeyIdx; + return this; + } + + public Builder withSourceRowType(RowType sourceRowType) { + this.sourceRowType = sourceRowType; + return this; + } + + public Builder withTargetLookupKeyIdx(int... 
targetLookupKeyIdx) { + this.targetLookupKeyIdx = targetLookupKeyIdx; + return this; + } + + public Builder withTargetRowType(RowType targetRowType) { + this.targetRowType = targetRowType; + return this; + } + + public Builder withGeneratedRemainingCondition( + GeneratedFilterCondition generatedRemainingCondition) { + this.generatedRemainingCondition = generatedRemainingCondition; + return this; + } + + public Builder withTargetGeneratedCalc( + @Nullable + GeneratedFunction> + lookupSideGeneratedCalc) { + this.targetGeneratedCalc = lookupSideGeneratedCalc; + return this; + } + + public Builder expectedThrownException(@Nullable Throwable expectedThrownException) { + this.expectedThrownException = expectedThrownException; + return this; + } + + public LookupTestSpec build() { + return new LookupTestSpec( + requireNonNull(sourceInputs), + targetInput, + targetTableIdx, + requireNonNull(sourceLookupKeyIdx), + requireNonNull(sourceRowType), + requireNonNull(targetLookupKeyIdx), + requireNonNull(targetRowType), + generatedRemainingCondition, + targetGeneratedCalc, + expectedThrownException); + } + } + } +} diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/deltajoin/StreamingDeltaJoinOperatorTestBase.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/deltajoin/StreamingDeltaJoinOperatorTestBase.java index 50b8d85297c67..809da3db98e20 100644 --- a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/deltajoin/StreamingDeltaJoinOperatorTestBase.java +++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/deltajoin/StreamingDeltaJoinOperatorTestBase.java @@ -634,8 +634,8 @@ private void freeExecutor() { } } - /** Base test specification shared by binary and cascaded delta join test specs. 
*/ - protected abstract static class AbstractBaseTestSpec { + /** Base test specification shared by binary and cascaded delta join. */ + protected abstract static class AbstractTestSpec { abstract RowType getLeftInputRowType(); From 73c0292834f26ef8b0819f0eb1d635c68dd6943c Mon Sep 17 00:00:00 2001 From: xuyang Date: Thu, 26 Mar 2026 09:59:37 +0800 Subject: [PATCH 2/3] spotless --- .../planner/runtime/stream/sql/CascadedDeltaJoinITCase.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/CascadedDeltaJoinITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/CascadedDeltaJoinITCase.scala index b10edc827a46d..db680dd1aab23 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/CascadedDeltaJoinITCase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/CascadedDeltaJoinITCase.scala @@ -15,7 +15,6 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ - package org.apache.flink.table.planner.runtime.stream.sql import org.apache.flink.table.api.config.{ExecutionConfigOptions, OptimizerConfigOptions} From 5057152a95272a8e1c99f19f57ab45009bc8eb70 Mon Sep 17 00:00:00 2001 From: xuyang Date: Thu, 26 Mar 2026 14:28:05 +0800 Subject: [PATCH 3/3] address comment --- .../exec/stream/StreamExecDeltaJoin.java | 5 +- .../stream/sql/BinaryDeltaJoinITCase.scala | 136 +++++++++++++++++- .../join/deltajoin/CascadedLookupHandler.java | 18 +-- 3 files changed, 143 insertions(+), 16 deletions(-) diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java index bf80cf73d662d..c540c543a6d0e 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java @@ -629,7 +629,10 @@ private static LookupHandlerBase generateLookupHandler( .mapToInt( key -> { Preconditions.checkState( - key instanceof FunctionCallUtil.FieldRef); + key instanceof FunctionCallUtil.FieldRef, + "Currently, delta join only supports to use field " + + "reference as lookup key, but found %s", + key.getClass().getName()); return ((FunctionCallUtil.FieldRef) key).index; }) .toArray(), diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/BinaryDeltaJoinITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/BinaryDeltaJoinITCase.scala index 0ff1493753600..0ff483dda7ddb 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/BinaryDeltaJoinITCase.scala +++ 
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/BinaryDeltaJoinITCase.scala @@ -864,10 +864,123 @@ class BinaryDeltaJoinITCase(enableCache: Boolean) extends DeltaJoinITCaseBase(en .build()) } + @TestTemplate + def testJoinKeysContainNull(): Unit = { + val data1 = List( + // both join keys are null + changelogRow( + "+I", + null.asInstanceOf[java.lang.Double], + null.asInstanceOf[java.lang.Integer], + LocalDateTime.of(2021, 1, 1, 1, 1, 1)), + // one of join keys is null + changelogRow( + "+I", + null.asInstanceOf[java.lang.Double], + Int.box(21), + LocalDateTime.of(2022, 2, 2, 2, 2, 21)), + changelogRow( + "+I", + Double.box(22.0), + null.asInstanceOf[java.lang.Integer], + LocalDateTime.of(2022, 2, 2, 2, 2, 22)), + // both join keys are not null + changelogRow("+I", Double.box(3.0), Int.box(3), LocalDateTime.of(2033, 3, 3, 3, 3, 3)) + ) + + val data2 = List( + // both join keys are null + changelogRow( + "+I", + null.asInstanceOf[java.lang.Integer], + null.asInstanceOf[java.lang.Double], + LocalDateTime.of(2021, 1, 1, 1, 1, 1)), + // one of join keys is null + changelogRow( + "+I", + Int.box(21), + null.asInstanceOf[java.lang.Double], + LocalDateTime.of(2022, 2, 2, 2, 2, 21)), + changelogRow( + "+I", + null.asInstanceOf[java.lang.Integer], + Double.box(22.0), + LocalDateTime.of(2022, 2, 2, 2, 2, 22)), + // both join keys are not null + changelogRow("+I", Int.box(3), Double.box(3.0), LocalDateTime.of(2033, 3, 3, 3, 3, 3)) + ) + + val expected = List("+I[3.0, 3, 2033-03-03T03:03:03, 3, 3.0, 2033-03-03T03:03:03]") + + testUpsertResult( + newTestSpecBuilder() + .withLeftIndex(List("a0", "a1")) + .withRightIndex(List("b0", "b1")) + .withLeftPk(List("a2")) + .withRightPk(List("b2")) + .withLeftImmutableCols(List("a0", "a1")) + .withRightImmutableCols(List("b0", "b1")) + .withLeftData(data1) + .withRightData(data2) + .withJoinCondition("a0 = b0 and a1 = b1") + .withSinkPk(List("l2", "r2")) + .withExpectedData(expected) + 
.build()) + } + + @TestTemplate + def testLeftTableEmpty(): Unit = { + testOneTableEmpty(true) + } + + @TestTemplate + def testRightTableEmpty(): Unit = { + testOneTableEmpty(false) + } + + def testOneTableEmpty(isLeftTableEmpty: Boolean): Unit = { + val data1 = if (isLeftTableEmpty) { + List() + } else { + List( + // both join keys are null + changelogRow("+I", Double.box(1.0), Int.box(1), LocalDateTime.of(2021, 1, 1, 1, 1, 1)) + ) + } + + val data2 = if (isLeftTableEmpty) { + List( + // both join keys are null + changelogRow("+I", Int.box(1), Double.box(1), LocalDateTime.of(2021, 1, 1, 1, 1, 1)) + ) + } else { + List() + } + + val expected = List() + + testUpsertResult( + newTestSpecBuilder() + .withLeftIndex(List("a0", "a1")) + .withRightIndex(List("b0", "b1")) + .withLeftPk(List("a2")) + .withRightPk(List("b2")) + .withLeftImmutableCols(List("a0", "a1")) + .withRightImmutableCols(List("b0", "b1")) + .withLeftData(data1) + .withRightData(data2) + .withJoinCondition("a0 = b0 and a1 = b1") + .withSinkPk(List("l2", "r2")) + .withExpectedData(expected) + .build()) + } + private def testUpsertResult(testSpec: TestSpec): Unit = { prepareTable( testSpec.leftIndex, testSpec.rightIndex, + testSpec.leftImmutableCols.getOrElse(List()), + testSpec.rightImmutableCols.getOrElse(List()), testSpec.leftPk.orNull, testSpec.rightPk.orNull, testSpec.sinkPk, @@ -932,6 +1045,8 @@ class BinaryDeltaJoinITCase(enableCache: Boolean) extends DeltaJoinITCaseBase(en prepareTable( leftIndex, rightIndex, + List(), + List(), null, null, List("l0", "r0"), @@ -948,6 +1063,8 @@ class BinaryDeltaJoinITCase(enableCache: Boolean) extends DeltaJoinITCaseBase(en private def prepareTable( leftIndex: List[String], rightIndex: List[String], + leftImmutableCols: List[String], + rightImmutableCols: List[String], @Nullable leftPk: List[String], @Nullable rightPk: List[String], sinkPk: List[String], @@ -992,7 +1109,7 @@ class BinaryDeltaJoinITCase(enableCache: Boolean) extends DeltaJoinITCaseBase(en | 
$leftExtraOptionsStr |) |""".stripMargin) - addIndexesAndImmutableCols("testLeft", List(leftIndex), List()) + addIndexesAndImmutableCols("testLeft", List(leftIndex), leftImmutableCols) tEnv.executeSql("drop table if exists testRight") val rightExtraOptionsStr = @@ -1026,7 +1143,7 @@ class BinaryDeltaJoinITCase(enableCache: Boolean) extends DeltaJoinITCaseBase(en | $rightExtraOptionsStr |) |""".stripMargin) - addIndexesAndImmutableCols("testRight", List(rightIndex), List()) + addIndexesAndImmutableCols("testRight", List(rightIndex), rightImmutableCols) tEnv.executeSql("drop table if exists testSnk") tEnv.executeSql(s""" @@ -1053,6 +1170,8 @@ class BinaryDeltaJoinITCase(enableCache: Boolean) extends DeltaJoinITCaseBase(en private case class TestSpec( leftIndex: List[String], rightIndex: List[String], + leftImmutableCols: Option[List[String]], + rightImmutableCols: Option[List[String]], leftPk: Option[List[String]], rightPk: Option[List[String]], partialInsertCols: Option[List[String]], @@ -1075,6 +1194,8 @@ class BinaryDeltaJoinITCase(enableCache: Boolean) extends DeltaJoinITCaseBase(en private class TestSpecBuilder { private var leftIndex: Option[List[String]] = None private var rightIndex: Option[List[String]] = None + private var leftImmutableCols: Option[List[String]] = None + private var rightImmutableCols: Option[List[String]] = None private var leftPk: Option[List[String]] = None private var rightPk: Option[List[String]] = None private var partialInsertCols: Option[List[String]] = None @@ -1102,6 +1223,15 @@ class BinaryDeltaJoinITCase(enableCache: Boolean) extends DeltaJoinITCaseBase(en rightIndex = Some(requireNonNull(index)) this } + def withLeftImmutableCols(immutableCols: List[String]): TestSpecBuilder = { + leftImmutableCols = Some(requireNonNull(immutableCols)) + this + } + + def withRightImmutableCols(immutableCols: List[String]): TestSpecBuilder = { + rightImmutableCols = Some(requireNonNull(immutableCols)) + this + } def withLeftPk(pk: List[String]): 
TestSpecBuilder = { leftPk = Some(requireNonNull(pk)) @@ -1192,6 +1322,8 @@ class BinaryDeltaJoinITCase(enableCache: Boolean) extends DeltaJoinITCaseBase(en TestSpec( requireNonNull(leftIndex.orNull), requireNonNull(rightIndex.orNull), + leftImmutableCols, + rightImmutableCols, leftPk, rightPk, partialInsertCols, diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/deltajoin/CascadedLookupHandler.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/deltajoin/CascadedLookupHandler.java index 3e915ed06d73f..e71839cae9a09 100644 --- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/deltajoin/CascadedLookupHandler.java +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/deltajoin/CascadedLookupHandler.java @@ -44,6 +44,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; /** @@ -60,7 +61,7 @@ public class CascadedLookupHandler extends LookupHandlerBase { private static final Logger LOG = LoggerFactory.getLogger(CascadedLookupHandler.class); // used for debug and start with 1 - protected final int id; + private final int id; private @Nullable final GeneratedFilterCondition generatedRemainingCondition; private final RowDataKeySelector streamSideLookupKeySelector; private final boolean leftLookupRight; @@ -69,8 +70,8 @@ public class CascadedLookupHandler extends LookupHandlerBase { private transient Map> allInputsWithLookupKey; private transient Map> lookupResults; - protected @Nullable transient Integer totalNumShouldBeHandledThisRound = null; - protected @Nullable transient Integer handledNum = null; + private @Nullable transient Integer totalNumShouldBeHandledThisRound = null; + private @Nullable transient Integer handledNum = null; public CascadedLookupHandler( int id, @@ -221,16 +222,7 @@ protected void 
completeResultsInMailbox(RowData input, Collection resul } private boolean noFurtherInput() { - Preconditions.checkState( - totalNumShouldBeHandledThisRound != null && handledNum != null, - "This function is called without be handled"); - Preconditions.checkState( - handledNum <= totalNumShouldBeHandledThisRound, - String.format( - "The handled num is greater than the total num. The handledNum is %d, the totalNumShouldBeHandledThisRound is %d", - handledNum, totalNumShouldBeHandledThisRound)); - - return handledNum.equals(totalNumShouldBeHandledThisRound); + return Objects.equals(handledNum, totalNumShouldBeHandledThisRound); } private void finish() throws Exception {