Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 174 additions & 4 deletions isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java
Original file line number Diff line number Diff line change
Expand Up @@ -73,28 +73,57 @@ public class SubstraitRelVisitor extends RelNodeVisitor<Rel, RuntimeException> {

private static final Expression.BoolLiteral TRUE = ExpressionCreator.bool(false, true);

/** Converter for Calcite {@link RexNode} to Substrait {@link Expression}. */
protected final RexExpressionConverter rexExpressionConverter;

/** Converter for {@link AggregateCall} to Substrait aggregate invocation. */
protected final AggregateFunctionConverter aggregateFunctionConverter;

/** Converter for Calcite {@link RelDataType} to Substrait {@link Type}. */
protected final TypeConverter typeConverter;

private Map<RexFieldAccess, Integer> fieldAccessDepthMap;

/** Use {@link SubstraitRelVisitor#SubstraitRelVisitor(ConverterProvider)} */
/**
* Creates a new SubstraitRelVisitor with the specified type factory and extensions.
*
* @param typeFactory the Calcite type factory
* @param extensions the Substrait extension collection
* @deprecated Use {@link SubstraitRelVisitor#SubstraitRelVisitor(ConverterProvider)}
*/
@Deprecated
public SubstraitRelVisitor(
RelDataTypeFactory typeFactory, SimpleExtension.ExtensionCollection extensions) {
this(new ConverterProvider(extensions, typeFactory));
}

/**
* Creates a new SubstraitRelVisitor with the specified converter provider.
*
* @param converterProvider the converter provider containing configuration and converters
*/
public SubstraitRelVisitor(ConverterProvider converterProvider) {
this.typeConverter = converterProvider.getTypeConverter();
this.aggregateFunctionConverter = converterProvider.getAggregateFunctionConverter();
this.rexExpressionConverter = converterProvider.getRexExpressionConverter(this);
}

/**
* Converts a {@link RexNode} to a Substrait {@link Expression}.
*
* @param node Rex expression node
* @return Substrait expression
*/
protected Expression toExpression(RexNode node) {
return node.accept(rexExpressionConverter);
}

/**
* Converts a Calcite {@link org.apache.calcite.rel.core.TableScan}.
*
* @param scan Calcite table scan
* @return Substrait named scan
*/
@Override
public Rel visit(org.apache.calcite.rel.core.TableScan scan) {
NamedStruct type = typeConverter.toNamedStruct(scan.getRowType());
Expand All @@ -104,11 +133,23 @@ public Rel visit(org.apache.calcite.rel.core.TableScan scan) {
.build();
}

/**
* Converts a Calcite {@link org.apache.calcite.rel.core.TableFunctionScan}.
*
* @param scan Calcite table function scan
* @return Converted relation or {@code super.visit(scan)}
*/
@Override
public Rel visit(org.apache.calcite.rel.core.TableFunctionScan scan) {
return super.visit(scan);
}

/**
* Converts a Calcite {@link org.apache.calcite.rel.core.Values}.
*
* @param values Calcite values relation
* @return Substrait scan (empty or virtual table)
*/
@Override
public Rel visit(org.apache.calcite.rel.core.Values values) {
NamedStruct type = typeConverter.toNamedStruct(values.getRowType());
Expand All @@ -134,17 +175,35 @@ public Rel visit(org.apache.calcite.rel.core.Values values) {
return VirtualTableScan.builder().initialSchema(type).addAllRows(structs).build();
}

/**
* Converts a Calcite {@link org.apache.calcite.rel.core.Filter}.
*
* @param filter Calcite filter relation
* @return Substrait filter
*/
@Override
public Rel visit(org.apache.calcite.rel.core.Filter filter) {
Expression condition = toExpression(filter.getCondition());
return Filter.builder().condition(condition).input(apply(filter.getInput())).build();
}

/**
* Converts a Calcite {@link org.apache.calcite.rel.core.Calc}.
*
* @param calc Calcite calc relation
* @return Converted relation
*/
@Override
public Rel visit(org.apache.calcite.rel.core.Calc calc) {
return super.visit(calc);
}

/**
* Converts a Calcite {@link org.apache.calcite.rel.core.Project}.
*
* @param project Calcite project relation
* @return Substrait project
*/
@Override
public Rel visit(org.apache.calcite.rel.core.Project project) {
List<Expression> expressions =
Expand All @@ -166,6 +225,12 @@ public Rel visit(org.apache.calcite.rel.core.Project project) {
.build();
}

/**
* Converts a Calcite {@link org.apache.calcite.rel.core.Join}.
*
* @param join Calcite join relation
* @return Substrait join or cross
*/
@Override
public Rel visit(org.apache.calcite.rel.core.Join join) {
Rel left = apply(join.getLeft());
Expand Down Expand Up @@ -200,6 +265,12 @@ private Join.JoinType asJoinType(org.apache.calcite.rel.core.Join join) {
throw new UnsupportedOperationException("Unsupported join type: " + join.getJoinType());
}

/**
* Converts a Calcite {@link org.apache.calcite.rel.core.Correlate}.
*
* @param correlate Calcite correlate relation
* @return Converted relation
*/
@Override
public Rel visit(org.apache.calcite.rel.core.Correlate correlate) {
// left input of correlated-join is similar to the left input of a logical join
Expand All @@ -211,13 +282,25 @@ public Rel visit(org.apache.calcite.rel.core.Correlate correlate) {
return super.visit(correlate);
}

/**
* Converts a Calcite {@link org.apache.calcite.rel.core.Union}.
*
* @param union Calcite union relation
* @return Substrait set-union
*/
@Override
public Rel visit(org.apache.calcite.rel.core.Union union) {
List<Rel> inputs = apply(union.getInputs());
Set.SetOp setOp = union.all ? Set.SetOp.UNION_ALL : Set.SetOp.UNION_DISTINCT;
return Set.builder().inputs(inputs).setOp(setOp).build();
}

/**
* Converts a Calcite {@link org.apache.calcite.rel.core.Intersect}.
*
* @param intersect Calcite intersect relation
* @return Substrait set-intersection
*/
@Override
public Rel visit(org.apache.calcite.rel.core.Intersect intersect) {
List<Rel> inputs = apply(intersect.getInputs());
Expand All @@ -226,13 +309,26 @@ public Rel visit(org.apache.calcite.rel.core.Intersect intersect) {
return Set.builder().inputs(inputs).setOp(setOp).build();
}

/**
* Converts a Calcite {@link org.apache.calcite.rel.core.Minus}.
*
* @param minus Calcite minus relation
* @return Substrait set-minus
*/
@Override
public Rel visit(org.apache.calcite.rel.core.Minus minus) {
List<Rel> inputs = apply(minus.getInputs());
Set.SetOp setOp = minus.all ? Set.SetOp.MINUS_PRIMARY_ALL : Set.SetOp.MINUS_PRIMARY;
return Set.builder().inputs(inputs).setOp(setOp).build();
}

/**
* Converts a Calcite {@link org.apache.calcite.rel.core.Aggregate}.
*
* @param aggregate Calcite aggregate relation
* @return Substrait aggregate
* @throws IllegalStateException if unexpected remap state is encountered.
*/
@Override
public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) {
Rel input = apply(aggregate.getInput());
Expand Down Expand Up @@ -331,11 +427,23 @@ Aggregate.Measure fromAggCall(RelNode input, Type.Struct inputType, AggregateCal
return builder.build();
}

/**
* Converts a Calcite {@link org.apache.calcite.rel.core.Match}.
*
* @param match Calcite match relation
* @return Converted relation
*/
@Override
public Rel visit(org.apache.calcite.rel.core.Match match) {
return super.visit(match);
}

/**
* Converts a Calcite {@link org.apache.calcite.rel.core.Sort}.
*
* @param sort Calcite sort relation
* @return Substrait sort/fetch chain
*/
@Override
public Rel visit(org.apache.calcite.rel.core.Sort sort) {
Rel input = apply(sort.getInput());
Expand Down Expand Up @@ -377,6 +485,13 @@ private long asLong(RexNode rex) {
throw new UnsupportedOperationException("Unknown type: " + rex);
}

/**
* Converts a Calcite sort collation to a Substrait {@link Expression.SortField}.
*
* @param collation Calcite field collation
* @param inputType Input record type
* @return Substrait sort field
*/
public static Expression.SortField toSortField(
RelFieldCollation collation, Type.Struct inputType) {
Expression.SortDirection direction = asSortDirection(collation);
Expand Down Expand Up @@ -405,11 +520,24 @@ private static Expression.SortDirection asSortDirection(RelFieldCollation collat
throw new IllegalArgumentException("Unsupported collation direction: " + direction);
}

/**
* Converts a Calcite {@link org.apache.calcite.rel.core.Exchange}.
*
* @param exchange Calcite exchange relation
* @return Converted relation
*/
@Override
public Rel visit(org.apache.calcite.rel.core.Exchange exchange) {
return super.visit(exchange);
}

/**
* Converts a Calcite {@link TableModify} (INSERT/DELETE/UPDATE).
*
* @param modify Calcite table modify node
* @return Substrait write/update relation
* @throws IllegalStateException if an update column is not found in the table schema.
*/
@Override
public Rel visit(TableModify modify) {
switch (modify.getOperation()) {
Expand Down Expand Up @@ -518,6 +646,12 @@ private NamedStruct getSchema(final RelNode queryRelRoot) {
return typeConverter.toNamedStruct(rowType);
}

/**
* Handles Calcite {@link CreateTable} as Substrait CTAS. (Create Table As Select)
*
* @param createTable Calcite create-table node
* @return Substrait CTAS write relation
*/
public Rel handleCreateTable(CreateTable createTable) {
RelNode input = createTable.getInput();
Rel inputRel = apply(input);
Expand All @@ -532,6 +666,12 @@ public Rel handleCreateTable(CreateTable createTable) {
.build();
}

/**
* Handles Calcite {@link CreateView} as Substrait view DDL.
*
* @param createView Calcite create-view node
* @return Substrait view DDL relation
*/
public Rel handleCreateView(CreateView createView) {
RelNode input = createView.getInput();
Rel inputRel = apply(input);
Expand All @@ -548,6 +688,13 @@ public Rel handleCreateView(CreateView createView) {
.build();
}

/**
* Visits other Calcite nodes (e.g., DDL wrappers).
*
* @param other Calcite node
* @return Converted relation
* @throws UnsupportedOperationException if the node type is unsupported.
*/
@Override
public Rel visitOther(RelNode other) {
if (other instanceof CreateTable) {
Expand All @@ -559,19 +706,42 @@ public Rel visitOther(RelNode other) {
throw new UnsupportedOperationException("Unable to handle node: " + other);
}

/**
* Precomputes depth for outer field accesses used by correlated expressions.
*
* @param root Root Calcite node to analyze
*/
protected void popFieldAccessDepthMap(RelNode root) {
final OuterReferenceResolver resolver = new OuterReferenceResolver();
fieldAccessDepthMap = resolver.apply(root);
}

/**
* Returns the depth of a field access for correlated expressions.
*
* @param fieldAccess Rex field access
* @return Depth value, or {@code null} if unknown
*/
public Integer getFieldAccessDepth(RexFieldAccess fieldAccess) {
return fieldAccessDepthMap.get(fieldAccess);
}

/**
* Applies the visitor to a Calcite {@link RelNode}.
*
* @param r Calcite node
* @return Converted Substrait relation
*/
public Rel apply(RelNode r) {
return reverseAccept(r);
}

/**
* Applies the visitor to a list of Calcite {@link RelNode}s.
*
* @param inputs Calcite input relations
* @return Converted Substrait relations
*/
public List<Rel> apply(List<RelNode> inputs) {
return inputs.stream()
.map(inputRel -> apply(inputRel))
Expand All @@ -581,9 +751,9 @@ public List<Rel> apply(List<RelNode> inputs) {
/**
* Deprecated, use {@link #convert(RelRoot, ConverterProvider)} directly
*
* @param relRoot The Calcite RelRoot to convert.
* @param extensions The extension collection to use for the conversion.
* @return The resulting Substrait Plan.Root.
* @param relRoot The Calcite RelRoot to convert
* @param extensions The extension collection to use for the conversion
* @return The resulting Substrait Plan.Root
*/
@Deprecated
public static Plan.Root convert(RelRoot relRoot, SimpleExtension.ExtensionCollection extensions) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ public SubstraitToCalcite(ConverterProvider converterProvider) {
this(converterProvider, null);
}

/**
* Creates a Substrait-to-Calcite converter with default type converter and a catalog reader.
*
* @param converterProvider the converter provider containing configuration and converters
* @param catalogReader Calcite catalog reader for schema resolution.
*/
public SubstraitToCalcite(
ConverterProvider converterProvider, Prepare.CatalogReader catalogReader) {
this.converterProvider = converterProvider;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,15 @@ public SubstraitToSql(ConverterProvider converterProvider) {
substraitToCalcite = converterProvider.getSubstraitToCalcite();
}

/**
* Converts a Substrait {@link Rel} to a Calcite {@link RelNode}.
*
* <p>This is the first step before generating SQL from Substrait plans.
*
* @param relRoot The Substrait relational root to convert.
* @param catalog The Calcite catalog reader for schema resolution.
* @return A Calcite {@link RelNode} representing the converted Substrait plan.
*/
public RelNode substraitRelToCalciteRel(Rel relRoot, Prepare.CatalogReader catalog) {
return SubstraitRelNodeConverter.convert(relRoot, catalog, converterProvider);
}
Expand Down
Loading