diff --git a/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java b/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java index 6cba80781..0d5d5bf0e 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java +++ b/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java @@ -11,31 +11,35 @@ import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction; import org.apache.calcite.sql.type.ReturnTypes; +/** + * Provides Substrait-specific variants of Calcite aggregate functions to ensure type inference + * matches Substrait expectations. + * + *
Default Calcite implementations may infer return types that differ from Substrait, causing + * conversion issues. This class overrides those behaviors. + */ public class AggregateFunctions { - // For some arithmetic aggregate functions, the default Calcite aggregate function implementations - // will infer return types that differ from those expected by Substrait. - // This type mismatch can cause conversion and planning failures. - + /** Substrait-specific MIN aggregate function (nullable return type). */ public static SqlAggFunction MIN = new SubstraitSqlMinMaxAggFunction(SqlKind.MIN); + + /** Substrait-specific MAX aggregate function (nullable return type). */ public static SqlAggFunction MAX = new SubstraitSqlMinMaxAggFunction(SqlKind.MAX); + + /** Substrait-specific AVG aggregate function (nullable return type). */ public static SqlAggFunction AVG = new SubstraitAvgAggFunction(SqlKind.AVG); + + /** Substrait-specific SUM aggregate function (nullable return type). */ public static SqlAggFunction SUM = new SubstraitSumAggFunction(); + + /** Substrait-specific SUM0 aggregate function (non-null BIGINT return type). */ public static SqlAggFunction SUM0 = new SubstraitSumEmptyIsZeroAggFunction(); /** - * Some Calcite rules, like {@link - * org.apache.calcite.rel.rules.AggregateExpandDistinctAggregatesRule}, introduce the default - * Calcite aggregate functions into plans. - * - *
When converting these Calcite plans to Substrait, we need to convert the default Calcite - * aggregate calls to the Substrait specific variants. - * - *
This function attempts to convert the given {@code aggFunction} to its Substrait equivalent
+ * Converts default Calcite aggregate functions to Substrait-specific variants when needed.
*
- * @param aggFunction the {@link SqlAggFunction} to convert to a Substrait specific variant
- * @return an optional containing the Substrait equivalent of the given {@code aggFunction} if
- * conversion was needed, empty otherwise.
+ * @param aggFunction the Calcite aggregate function
+ * @return an {@link Optional} with the Substrait variant, or empty if no conversion is needed
*/
public static Optional Implementations should return an {@link Optional} containing the converted expression, or
+ * {@link Optional#empty()} if the call is not handled.
+ */
@FunctionalInterface
public interface CallConverter {
+
+ /**
+ * Converts a Calcite {@link RexCall} into a Substrait {@link Expression}.
+ *
+ * @param call the Calcite function/operator call to convert
+ * @param topLevelConverter a function for converting nested {@link RexNode} operands
+ * @return an {@link Optional} containing the converted expression, or empty if not applicable
+ */
Optional Provides helpers to identify and extract dynamic (custom/user-defined) functions from an
+ * {@link io.substrait.extension.SimpleExtension.ExtensionCollection}.
+ */
public class ExtensionUtils {
/**
diff --git a/isthmus/src/main/java/io/substrait/isthmus/OuterReferenceResolver.java b/isthmus/src/main/java/io/substrait/isthmus/OuterReferenceResolver.java
index 79410ea47..8071468c0 100644
--- a/isthmus/src/main/java/io/substrait/isthmus/OuterReferenceResolver.java
+++ b/isthmus/src/main/java/io/substrait/isthmus/OuterReferenceResolver.java
@@ -15,8 +15,20 @@
import org.apache.calcite.rex.RexSubQuery;
import org.apache.calcite.rex.RexUtil.SubQueryCollector;
-/** Resolve correlated variable and get Depth map for RexFieldAccess */
-// See OuterReferenceResolver.md for explanation how the Depth map is computed.
+/**
+ * Resolves correlated variables and computes a depth map for {@link RexFieldAccess}.
+ *
+ * Traverses a {@link RelNode} tree and:
+ *
+ * <p>Special case: the right side is a correlated subquery in the rel tree (not a REX), so we
+ * manually adjust depth before/after visiting it.
+ *
+ * @param correlate the correlate (correlated join) node
+ * @return the correlate node
+ * @throws RuntimeException if traversal fails
+ */
@Override
public RelNode visit(Correlate correlate) throws RuntimeException {
for (CorrelationId id : correlate.getVariablesSet()) {
@@ -63,6 +100,13 @@ public RelNode visit(Correlate correlate) throws RuntimeException {
return correlate;
}
+ /**
+ * Visits a generic {@link RelNode}, applying the resolver to each of the node's inputs.
+ *
+ * @param other the node to visit
+ * @return the node
+ * @throws RuntimeException if traversal fails
+ */
@Override
public RelNode visitOther(RelNode other) throws RuntimeException {
for (RelNode child : other.getInputs()) {
@@ -71,6 +115,14 @@ public RelNode visitOther(RelNode other) throws RuntimeException {
return other;
}
+ /**
+ * Visits a {@link Project}, registering correlation variables and visiting any subqueries within
+ * its expressions.
+ *
+ * @param project the project node
+ * @return the result of {@link RelNodeVisitor#visit(Project)}
+ * @throws RuntimeException if traversal fails
+ */
@Override
public RelNode visit(Project project) throws RuntimeException {
for (CorrelationId id : project.getVariablesSet()) {
@@ -84,13 +136,25 @@ public RelNode visit(Project project) throws RuntimeException {
return super.visit(project);
}
+ /** Rex visitor used to track correlation depth within expressions and subqueries. */
private static class RexVisitor extends RexShuttle {
final OuterReferenceResolver referenceResolver;
+ /**
+ * Creates a new Rex visitor bound to the given reference resolver.
+ *
+ * @param referenceResolver the parent resolver maintaining depth maps
+ */
RexVisitor(OuterReferenceResolver referenceResolver) {
this.referenceResolver = referenceResolver;
}
+ /**
+ * Increments correlation depth when entering a subquery and decrements when exiting.
+ *
+ * @param subQuery the subquery expression
+ * @return the same subquery
+ */
@Override
public RexNode visitSubQuery(RexSubQuery subQuery) {
referenceResolver.nestedDepth.replaceAll((k, v) -> v + 1);
@@ -101,6 +165,12 @@ public RexNode visitSubQuery(RexSubQuery subQuery) {
return subQuery;
}
+ /**
+ * Records depth for {@link RexFieldAccess} referencing a {@link RexCorrelVariable}.
+ *
+ * @param fieldAccess the field access expression
+ * @return the same field access
+ */
@Override
public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
if (fieldAccess.getReferenceExpr() instanceof RexCorrelVariable) {
+ *
+ *
+ * <p>See OuterReferenceResolver.md for details on how the depth map is computed.
+ */
public class OuterReferenceResolver extends RelNodeVisitor