Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.utils;

import org.apache.flink.annotation.Internal;
import org.apache.flink.util.FlinkRuntimeException;

/**
 * Exception signaling that a MultiJoin node could not derive a common join key.
 *
 * <p>Raised during planning when none of the inputs of a multi-way join share a join key, which
 * makes the MultiJoin operator inapplicable.
 */
@Internal
public class NoCommonJoinKeyException extends FlinkRuntimeException {

    private static final long serialVersionUID = 1L;

    /**
     * Creates the exception with a detail message.
     *
     * @param message description of the missing common join key
     */
    public NoCommonJoinKeyException(String message) {
        super(message);
    }

    /**
     * Creates the exception with a detail message and an underlying cause.
     *
     * @param message description of the missing common join key
     * @param cause the originating throwable
     */
    public NoCommonJoinKeyException(String message, Throwable cause) {
        super(message, cause);
    }

    /**
     * Creates the exception from an underlying cause only.
     *
     * @param cause the originating throwable
     */
    public NoCommonJoinKeyException(Throwable cause) {
        super(cause);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,9 @@ private RelNode visitMultiJoin(FlinkLogicalMultiJoin multiJoin) {
.collect(Collectors.toList());

final List<RelDataType> allFields =
newInputs.stream().map(RelNode::getRowType).collect(Collectors.toList());
newInputs.stream()
.flatMap(input -> RelOptUtil.getFieldTypeList(input.getRowType()).stream())
.collect(Collectors.toList());

RexTimeIndicatorMaterializer materializer = new RexTimeIndicatorMaterializer(allFields);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public class StreamExecMultiJoin extends ExecNodeBase<RowData>
private static final String FIELD_NAME_JOIN_TYPES = "joinTypes";
private static final String FIELD_NAME_JOIN_CONDITIONS = "joinConditions";
private static final String FIELD_NAME_JOIN_ATTRIBUTE_MAP = "joinAttributeMap";
private static final String FIELD_NAME_INPUT_UPSERT_KEYS = "inputUpsertKeys";
private static final String FIELD_NAME_INPUT_UNIQUE_KEYS = "inputUniqueKeys";
private static final String FIELD_NAME_MULTI_JOIN_CONDITION = "multiJoinCondition";

@JsonProperty(FIELD_NAME_JOIN_TYPES)
Expand All @@ -100,12 +100,13 @@ public class StreamExecMultiJoin extends ExecNodeBase<RowData>
@JsonInclude(JsonInclude.Include.NON_EMPTY)
private final Map<Integer, List<ConditionAttributeRef>> joinAttributeMap;

@JsonProperty(FIELD_NAME_INPUT_UPSERT_KEYS)
// Why List<List<int[]>> as a type?
// Each unique key can also be a composite key with multiple fields, thus -> int[].
// Theoretically, each input can have multiple unique keys, thus -> List<int[]>.
// Since we have multiple inputs -> List<List<int[]>>.
@JsonProperty(FIELD_NAME_INPUT_UNIQUE_KEYS)
@JsonInclude(JsonInclude.Include.NON_EMPTY)
// List of upsert keys for each input, where each inner list corresponds to an input
// The reason it's a List<List<int[]>> is that SQL allows only one primary key but
// multiple upsert (unique) keys per input
private final List<List<int[]>> inputUpsertKeys;
private final List<List<int[]>> inputUniqueKeys;

@JsonProperty(FIELD_NAME_STATE)
@JsonInclude(JsonInclude.Include.NON_NULL)
Expand All @@ -117,7 +118,7 @@ public StreamExecMultiJoin(
final List<? extends @Nullable RexNode> joinConditions,
@Nullable final RexNode multiJoinCondition,
final Map<Integer, List<ConditionAttributeRef>> joinAttributeMap,
final List<List<int[]>> inputUpsertKeys,
final List<List<int[]>> inputUniqueKeys,
final Map<Integer, Long> stateTtlFromHint,
final List<InputProperty> inputProperties,
final RowType outputType,
Expand All @@ -130,7 +131,7 @@ public StreamExecMultiJoin(
joinConditions,
multiJoinCondition,
joinAttributeMap,
inputUpsertKeys,
inputUniqueKeys,
StateMetadata.getMultiInputOperatorDefaultMeta(
stateTtlFromHint, tableConfig, generateStateNames(inputProperties.size())),
inputProperties,
Expand All @@ -150,26 +151,26 @@ public StreamExecMultiJoin(
final RexNode multiJoinCondition,
@JsonProperty(FIELD_NAME_JOIN_ATTRIBUTE_MAP)
final Map<Integer, List<ConditionAttributeRef>> joinAttributeMap,
@JsonProperty(FIELD_NAME_INPUT_UPSERT_KEYS) final List<List<int[]>> inputUpsertKeys,
@JsonProperty(FIELD_NAME_INPUT_UNIQUE_KEYS) final List<List<int[]>> inputUniqueKeys,
@Nullable @JsonProperty(FIELD_NAME_STATE) final List<StateMetadata> stateMetadataList,
@JsonProperty(FIELD_NAME_INPUT_PROPERTIES) final List<InputProperty> inputProperties,
@JsonProperty(FIELD_NAME_OUTPUT_TYPE) final RowType outputType,
@JsonProperty(FIELD_NAME_DESCRIPTION) final String description) {
super(id, context, persistedConfig, inputProperties, outputType, description);
validateInputs(inputProperties, joinTypes, joinConditions, inputUpsertKeys);
validateInputs(inputProperties, joinTypes, joinConditions, inputUniqueKeys);
this.joinTypes = checkNotNull(joinTypes);
this.joinConditions = checkNotNull(joinConditions);
this.multiJoinCondition = multiJoinCondition;
this.inputUpsertKeys = checkNotNull(inputUpsertKeys);
this.joinAttributeMap = Objects.requireNonNullElseGet(joinAttributeMap, Map::of);
this.inputUniqueKeys = checkNotNull(inputUniqueKeys);
this.stateMetadataList = stateMetadataList;
}

private void validateInputs(
final List<InputProperty> inputProperties,
final List<FlinkJoinType> joinTypes,
final List<? extends @Nullable RexNode> joinConditions,
final List<List<int[]>> inputUpsertKeys) {
final List<List<int[]>> inputUniqueKeys) {
checkArgument(
inputProperties.size() >= 2, "Multi-input join operator needs at least 2 inputs.");
checkArgument(
Expand All @@ -179,8 +180,8 @@ private void validateInputs(
joinConditions.size() == inputProperties.size(),
"Size of joinConditions must match the number of inputs.");
checkArgument(
inputUpsertKeys.size() == inputProperties.size(),
"Size of inputUpsertKeys must match the number of inputs.");
inputUniqueKeys.size() == inputProperties.size(),
"Size of inputUniqueKeys must match the number of inputs.");
}

private static String[] generateStateNames(int numInputs) {
Expand Down Expand Up @@ -220,7 +221,7 @@ protected Transformation<RowData> translateToPlanInternal(
planner.getFlinkContext().getClassLoader(),
inputTypeInfos.get(i),
keyExtractor.getJoinKeyIndices(i),
inputUpsertKeys.get(i)));
inputUniqueKeys.get(i)));
}

final GeneratedJoinCondition[] generatedJoinConditions =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ public RelOptCost computeSelfCost(final RelOptPlanner planner, final RelMetadata

final Double averageRowSize = mq.getAverageRowSize(input);
final double dAverageRowSize = averageRowSize == null ? 100.0 : averageRowSize;
rowCount *= inputRowCount;
rowCount += inputRowCount;
cpu += inputRowCount;
io += inputRowCount * dAverageRowSize;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ public class StreamPhysicalMultiJoin extends AbstractRelNode implements StreamPh
private final @Nullable RexNode postJoinFilter;
private final List<RelHint> hints;

// Cached derived properties to avoid recomputation
private @Nullable RexNode multiJoinCondition;
private @Nullable List<List<int[]>> inputUniqueKeys;

public StreamPhysicalMultiJoin(
final RelOptCluster cluster,
final RelTraitSet traitSet,
Expand All @@ -101,6 +105,8 @@ public StreamPhysicalMultiJoin(
this.postJoinFilter = postJoinFilter;
this.hints = hints;
this.keyExtractor = keyExtractor;
this.multiJoinCondition = getMultiJoinCondition();
this.inputUniqueKeys = getUniqueKeysForInputs();
}

@Override
Expand All @@ -119,6 +125,9 @@ public void replaceInput(final int ordinalInParent, final RelNode p) {
final List<RelNode> newInputs = new ArrayList<>(inputs);
newInputs.set(ordinalInParent, p);
this.inputs = List.copyOf(newInputs);
// Invalidate cached derived properties since inputs changed
this.multiJoinCondition = null;
this.inputUniqueKeys = null;
recomputeDigest();
}

Expand Down Expand Up @@ -166,18 +175,18 @@ protected RelDataType deriveRowType() {

@Override
public ExecNode<?> translateToExecNode() {
final RexNode multiJoinCondition = createMultiJoinCondition();
final List<List<int[]>> inputUpsertKeys = getUpsertKeysForInputs();
final RexNode multijoinCondition = getMultiJoinCondition();
final List<List<int[]>> localInputUniqueKeys = getUniqueKeysForInputs();
final List<FlinkJoinType> execJoinTypes = getExecJoinTypes();
final List<InputProperty> inputProperties = createInputProperties();

return new StreamExecMultiJoin(
unwrapTableConfig(this),
execJoinTypes,
joinConditions,
multiJoinCondition,
multijoinCondition,
joinAttributeMap,
inputUpsertKeys,
localInputUniqueKeys,
Collections.emptyMap(), // TODO Enable hint-based state ttl. See ticket
// TODO https://issues.apache.org/jira/browse/FLINK-37936
inputProperties,
Expand All @@ -187,33 +196,56 @@ public ExecNode<?> translateToExecNode() {

private RexNode createMultiJoinCondition() {
final List<RexNode> conjunctions = new ArrayList<>();

for (RexNode joinCondition : joinConditions) {
if (joinCondition != null) {
conjunctions.add(joinCondition);
}
}

conjunctions.add(joinFilter);

if (postJoinFilter != null) {
conjunctions.add(postJoinFilter);
}

return RexUtil.composeConjunction(getCluster().getRexBuilder(), conjunctions, true);
}

private List<List<int[]>> getUpsertKeysForInputs() {
return inputs.stream()
.map(
input -> {
final Set<ImmutableBitSet> upsertKeys = getUpsertKeys(input);

if (upsertKeys == null) {
return Collections.<int[]>emptyList();
}
return upsertKeys.stream()
.map(ImmutableBitSet::toArray)
.collect(Collectors.toList());
})
.collect(Collectors.toList());
public List<List<int[]>> getUniqueKeysForInputs() {
if (inputUniqueKeys == null) {
final List<List<int[]>> computed =
inputs.stream()
.map(
input -> {
final Set<ImmutableBitSet> uniqueKeys =
getUniqueKeys(input);

if (uniqueKeys == null) {
return Collections.<int[]>emptyList();
}

return uniqueKeys.stream()
.map(ImmutableBitSet::toArray)
.collect(Collectors.toList());
})
.collect(Collectors.toList());
inputUniqueKeys = Collections.unmodifiableList(computed);
}
return inputUniqueKeys;
}

private @Nullable Set<ImmutableBitSet> getUpsertKeys(RelNode input) {
private @Nullable Set<ImmutableBitSet> getUniqueKeys(RelNode input) {
final FlinkRelMetadataQuery fmq =
FlinkRelMetadataQuery.reuseOrCreate(input.getCluster().getMetadataQuery());
return fmq.getUpsertKeys(input);
return fmq.getUniqueKeys(input);
}

public RexNode getMultiJoinCondition() {
if (multiJoinCondition == null) {
multiJoinCondition = createMultiJoinCondition();
}
return multiJoinCondition;
}

private List<FlinkJoinType> getExecJoinTypes() {
Expand Down Expand Up @@ -255,8 +287,8 @@ public List<JoinRelType> getJoinTypes() {
*/
public boolean inputUniqueKeyContainsCommonJoinKey(int inputId) {
final RelNode input = getInputs().get(inputId);
final Set<ImmutableBitSet> inputUniqueKeys = getUpsertKeys(input);
if (inputUniqueKeys == null || inputUniqueKeys.isEmpty()) {
final Set<ImmutableBitSet> inputUniqueKeysSet = getUniqueKeys(input);
if (inputUniqueKeysSet == null || inputUniqueKeysSet.isEmpty()) {
return false;
}

Expand All @@ -266,7 +298,8 @@ public boolean inputUniqueKeyContainsCommonJoinKey(int inputId) {
}

final ImmutableBitSet commonJoinKeys = ImmutableBitSet.of(commonJoinKeyIndices);
return inputUniqueKeys.stream().anyMatch(uniqueKey -> uniqueKey.contains(commonJoinKeys));
return inputUniqueKeysSet.stream()
.anyMatch(uniqueKey -> uniqueKey.contains(commonJoinKeys));
}

private List<InputProperty> createInputProperties() {
Expand Down
Loading