diff --git a/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java b/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java index 6cba80781..0d5d5bf0e 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java +++ b/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java @@ -11,31 +11,35 @@ import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction; import org.apache.calcite.sql.type.ReturnTypes; +/** + * Provides Substrait-specific variants of Calcite aggregate functions to ensure type inference + * matches Substrait expectations. + * + *

Default Calcite implementations may infer return types that differ from Substrait, causing + * conversion issues. This class overrides those behaviors. + */ public class AggregateFunctions { - // For some arithmetic aggregate functions, the default Calcite aggregate function implementations - // will infer return types that differ from those expected by Substrait. - // This type mismatch can cause conversion and planning failures. - + /** Substrait-specific MIN aggregate function (nullable return type). */ public static SqlAggFunction MIN = new SubstraitSqlMinMaxAggFunction(SqlKind.MIN); + + /** Substrait-specific MAX aggregate function (nullable return type). */ public static SqlAggFunction MAX = new SubstraitSqlMinMaxAggFunction(SqlKind.MAX); + + /** Substrait-specific AVG aggregate function (nullable return type). */ public static SqlAggFunction AVG = new SubstraitAvgAggFunction(SqlKind.AVG); + + /** Substrait-specific SUM aggregate function (nullable return type). */ public static SqlAggFunction SUM = new SubstraitSumAggFunction(); + + /** Substrait-specific SUM0 aggregate function (non-null BIGINT return type). */ public static SqlAggFunction SUM0 = new SubstraitSumEmptyIsZeroAggFunction(); /** - * Some Calcite rules, like {@link - * org.apache.calcite.rel.rules.AggregateExpandDistinctAggregatesRule}, introduce the default - * Calcite aggregate functions into plans. - * - *

When converting these Calcite plans to Substrait, we need to convert the default Calcite - * aggregate calls to the Substrait specific variants. - * - *

This function attempts to convert the given {@code aggFunction} to its Substrait equivalent + * Converts default Calcite aggregate functions to Substrait-specific variants when needed. * - * @param aggFunction the {@link SqlAggFunction} to convert to a Substrait specific variant - * @return an optional containing the Substrait equivalent of the given {@code aggFunction} if - * conversion was needed, empty otherwise. + * @param aggFunction the Calcite aggregate function + * @return optional containing Substrait equivalent if conversion applies */ public static Optional toSubstraitAggVariant(SqlAggFunction aggFunction) { if (aggFunction instanceof SqlMinMaxAggFunction) { @@ -53,7 +57,7 @@ public static Optional toSubstraitAggVariant(SqlAggFunction aggF } } - /** Extension of {@link SqlMinMaxAggFunction} that ALWAYS infers a nullable return type */ + /** Substrait variant of {@link SqlMinMaxAggFunction} that forces nullable return type. */ private static class SubstraitSqlMinMaxAggFunction extends SqlMinMaxAggFunction { public SubstraitSqlMinMaxAggFunction(SqlKind kind) { super(kind); @@ -65,12 +69,10 @@ public RelDataType inferReturnType(SqlOperatorBinding opBinding) { } } - /** Extension of {@link SqlSumAggFunction} that ALWAYS infers a nullable return type */ + /** Substrait variant of {@link SqlSumAggFunction} that forces nullable return type. */ private static class SubstraitSumAggFunction extends SqlSumAggFunction { public SubstraitSumAggFunction() { - // This is intentionally null - // See the instantiation of SqlSumAggFunction in SqlStdOperatorTable - super(null); + super(null); // Matches Calcite's instantiation pattern } @Override @@ -79,7 +81,7 @@ public RelDataType inferReturnType(SqlOperatorBinding opBinding) { } } - /** Extension of {@link SqlAvgAggFunction} that ALWAYS infers a nullable return type */ + /** Substrait variant of {@link SqlAvgAggFunction} that forces nullable return type. */ private static class SubstraitAvgAggFunction extends SqlAvgAggFunction { public SubstraitAvgAggFunction(SqlKind kind) { super(kind); @@ -92,8 +94,8 @@ public RelDataType inferReturnType(SqlOperatorBinding opBinding) { } /** - * Extension of {@link SqlSumEmptyIsZeroAggFunction} that ALWAYS infers a NOT NULL BIGINT return - * type + * Substrait variant of {@link SqlSumEmptyIsZeroAggFunction} that forces BIGINT return type and + * uses a user-friendly name. */ private static class SubstraitSumEmptyIsZeroAggFunction extends org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction { @@ -103,8 +105,7 @@ public SubstraitSumEmptyIsZeroAggFunction() { @Override public String getName() { - // the default name for this function is `$sum0` - // override this to `sum0` which is a nicer name to use in queries + // Override default `$sum0` with `sum0` for readability return "sum0"; } diff --git a/isthmus/src/main/java/io/substrait/isthmus/CallConverter.java b/isthmus/src/main/java/io/substrait/isthmus/CallConverter.java index 8d68ef612..bc1f465c5 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/CallConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/CallConverter.java @@ -6,7 +6,22 @@ import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexNode; +/** + * Functional interface for converting Calcite {@link RexCall} expressions into Substrait {@link + * Expression}s. + * + *

Implementations should return an {@link Optional} containing the converted expression, or + * {@link Optional#empty()} if the call is not handled. + */ @FunctionalInterface public interface CallConverter { + + /** + * Converts a Calcite {@link RexCall} into a Substrait {@link Expression}. + * + * @param call the Calcite function/operator call to convert + * @param topLevelConverter a function for converting nested {@link RexNode} operands + * @return an {@link Optional} containing the converted expression, or empty if not applicable + */ Optional convert(RexCall call, Function topLevelConverter); } diff --git a/isthmus/src/main/java/io/substrait/isthmus/ExtensionUtils.java b/isthmus/src/main/java/io/substrait/isthmus/ExtensionUtils.java index 377020bb3..ba273f0a6 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/ExtensionUtils.java +++ b/isthmus/src/main/java/io/substrait/isthmus/ExtensionUtils.java @@ -7,6 +7,12 @@ import java.util.Set; import java.util.stream.Collectors; +/** + * Utility methods for working with Substrait extensions. + * + *

Provides helpers to identify and extract dynamic (custom/user-defined) functions from an + * {@link io.substrait.extension.SimpleExtension.ExtensionCollection}. + */ public class ExtensionUtils { /** diff --git a/isthmus/src/main/java/io/substrait/isthmus/OuterReferenceResolver.java b/isthmus/src/main/java/io/substrait/isthmus/OuterReferenceResolver.java index 79410ea47..8071468c0 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/OuterReferenceResolver.java +++ b/isthmus/src/main/java/io/substrait/isthmus/OuterReferenceResolver.java @@ -15,8 +15,20 @@ import org.apache.calcite.rex.RexSubQuery; import org.apache.calcite.rex.RexUtil.SubQueryCollector; -/** Resolve correlated variable and get Depth map for RexFieldAccess */ -// See OuterReferenceResolver.md for explanation how the Depth map is computed. +/** + * Resolve correlated variables and compute a depth map for {@link RexFieldAccess}. + * + *

Traverses a {@link RelNode} tree and: + * + *

+ * + * See OuterReferenceResolver.md for details on how the depth map is computed. + */ public class OuterReferenceResolver extends RelNodeVisitor { private final Map nestedDepth; @@ -24,16 +36,31 @@ public class OuterReferenceResolver extends RelNodeVisitor(); fieldAccessDepthMap = new IdentityHashMap<>(); } + /** + * Applies the resolver to a {@link RelNode} tree, computing the depth map. + * + * @param r the root relational node + * @return the same node after traversal + * @throws RuntimeException if the visitor encounters an unrecoverable condition + */ public Map apply(RelNode r) { reverseAccept(r); return fieldAccessDepthMap; } + /** + * Visits a {@link Filter}, registering any correlation variables and visiting its condition. + * + * @param filter the filter node + * @return the result of {@link RelNodeVisitor#visit(Filter)} + * @throws RuntimeException if traversal fails + */ @Override public RelNode visit(Filter filter) throws RuntimeException { for (CorrelationId id : filter.getVariablesSet()) { @@ -43,6 +70,16 @@ public RelNode visit(Filter filter) throws RuntimeException { return super.visit(filter); } + /** + * Visits a {@link Correlate}, handling correlation depth for both sides. + * + *

Special case: the right side is a correlated subquery in the rel tree (not a REX), so we + * manually adjust depth before/after visiting it. + * + * @param correlate the correlate (correlated join) node + * @return the correlate node + * @throws RuntimeException if traversal fails + */ @Override public RelNode visit(Correlate correlate) throws RuntimeException { for (CorrelationId id : correlate.getVariablesSet()) { @@ -63,6 +100,13 @@ public RelNode visit(Correlate correlate) throws RuntimeException { return correlate; } + /** + * Visits a generic {@link RelNode}, applying the resolver to all the node inputs. + * + * @param other the node to visit + * @return the node + * @throws RuntimeException if traversal fails + */ @Override public RelNode visitOther(RelNode other) throws RuntimeException { for (RelNode child : other.getInputs()) { @@ -71,6 +115,14 @@ public RelNode visitOther(RelNode other) throws RuntimeException { return other; } + /** + * Visits a {@link Project}, registering correlation variables and visiting any subqueries within + * its expressions. + * + * @param project the project node + * @return the result of {@link RelNodeVisitor#visit(Project)} + * @throws RuntimeException if traversal fails + */ @Override public RelNode visit(Project project) throws RuntimeException { for (CorrelationId id : project.getVariablesSet()) { @@ -84,13 +136,25 @@ public RelNode visit(Project project) throws RuntimeException { return super.visit(project); } + /** Rex visitor used to track correlation depth within expressions and subqueries. */ private static class RexVisitor extends RexShuttle { final OuterReferenceResolver referenceResolver; + /** + * Creates a new Rex visitor bound to the given reference resolver. + * + * @param referenceResolver the parent resolver maintaining depth maps + */ RexVisitor(OuterReferenceResolver referenceResolver) { this.referenceResolver = referenceResolver; } + /** + * Increments correlation depth when entering a subquery and decrements when exiting. + * + * @param subQuery the subquery expression + * @return the same subquery + */ @Override public RexNode visitSubQuery(RexSubQuery subQuery) { referenceResolver.nestedDepth.replaceAll((k, v) -> v + 1); @@ -101,6 +165,12 @@ public RexNode visitSubQuery(RexSubQuery subQuery) { return subQuery; } + /** + * Records depth for {@link RexFieldAccess} referencing a {@link RexCorrelVariable}. + * + * @param fieldAccess the field access expression + * @return the same field access + */ @Override public RexNode visitFieldAccess(RexFieldAccess fieldAccess) { if (fieldAccess.getReferenceExpr() instanceof RexCorrelVariable) {