Skip to content

Commit 5ee2b63

Browse files
fix(javadoc): partial isthmus (#758)
Signed-off-by: MBWhite <whitemat@uk.ibm.com> Co-authored-by: Mark S. Lewis <Mark.S.Lewis@outlook.com>
1 parent 9718899 commit 5ee2b63

4 files changed

Lines changed: 119 additions & 27 deletions

File tree

isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,31 +11,35 @@
1111
import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction;
1212
import org.apache.calcite.sql.type.ReturnTypes;
1313

14+
/**
15+
* Provides Substrait-specific variants of Calcite aggregate functions to ensure type inference
16+
* matches Substrait expectations.
17+
*
18+
* <p>Default Calcite implementations may infer return types that differ from Substrait, causing
19+
* conversion issues. This class overrides those behaviors.
20+
*/
1421
public class AggregateFunctions {
1522

16-
// For some arithmetic aggregate functions, the default Calcite aggregate function implementations
17-
// will infer return types that differ from those expected by Substrait.
18-
// This type mismatch can cause conversion and planning failures.
19-
23+
/** Substrait-specific MIN aggregate function (nullable return type). */
2024
public static SqlAggFunction MIN = new SubstraitSqlMinMaxAggFunction(SqlKind.MIN);
25+
26+
/** Substrait-specific MAX aggregate function (nullable return type). */
2127
public static SqlAggFunction MAX = new SubstraitSqlMinMaxAggFunction(SqlKind.MAX);
28+
29+
/** Substrait-specific AVG aggregate function (nullable return type). */
2230
public static SqlAggFunction AVG = new SubstraitAvgAggFunction(SqlKind.AVG);
31+
32+
/** Substrait-specific SUM aggregate function (nullable return type). */
2333
public static SqlAggFunction SUM = new SubstraitSumAggFunction();
34+
35+
/** Substrait-specific SUM0 aggregate function (non-null BIGINT return type). */
2436
public static SqlAggFunction SUM0 = new SubstraitSumEmptyIsZeroAggFunction();
2537

2638
/**
27-
* Some Calcite rules, like {@link
28-
* org.apache.calcite.rel.rules.AggregateExpandDistinctAggregatesRule}, introduce the default
29-
* Calcite aggregate functions into plans.
30-
*
31-
* <p>When converting these Calcite plans to Substrait, we need to convert the default Calcite
32-
* aggregate calls to the Substrait specific variants.
33-
*
34-
* <p>This function attempts to convert the given {@code aggFunction} to its Substrait equivalent
39+
* Converts default Calcite aggregate functions to Substrait-specific variants when needed.
3540
*
36-
* @param aggFunction the {@link SqlAggFunction} to convert to a Substrait specific variant
37-
* @return an optional containing the Substrait equivalent of the given {@code aggFunction} if
38-
* conversion was needed, empty otherwise.
41+
* @param aggFunction the Calcite aggregate function
42+
* @return an optional containing the Substrait equivalent if conversion applies, empty otherwise
3943
*/
4044
public static Optional<SqlAggFunction> toSubstraitAggVariant(SqlAggFunction aggFunction) {
4145
if (aggFunction instanceof SqlMinMaxAggFunction) {
@@ -53,7 +57,7 @@ public static Optional<SqlAggFunction> toSubstraitAggVariant(SqlAggFunction aggF
5357
}
5458
}
5559

56-
/** Extension of {@link SqlMinMaxAggFunction} that ALWAYS infers a nullable return type */
60+
/** Substrait variant of {@link SqlMinMaxAggFunction} that forces nullable return type. */
5761
private static class SubstraitSqlMinMaxAggFunction extends SqlMinMaxAggFunction {
5862
public SubstraitSqlMinMaxAggFunction(SqlKind kind) {
5963
super(kind);
@@ -65,12 +69,10 @@ public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
6569
}
6670
}
6771

68-
/** Extension of {@link SqlSumAggFunction} that ALWAYS infers a nullable return type */
72+
/** Substrait variant of {@link SqlSumAggFunction} that forces nullable return type. */
6973
private static class SubstraitSumAggFunction extends SqlSumAggFunction {
7074
public SubstraitSumAggFunction() {
71-
// This is intentionally null
72-
// See the instantiation of SqlSumAggFunction in SqlStdOperatorTable
73-
super(null);
75+
super(null); // Matches Calcite's instantiation pattern
7476
}
7577

7678
@Override
@@ -79,7 +81,7 @@ public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
7981
}
8082
}
8183

82-
/** Extension of {@link SqlAvgAggFunction} that ALWAYS infers a nullable return type */
84+
/** Substrait variant of {@link SqlAvgAggFunction} that forces nullable return type. */
8385
private static class SubstraitAvgAggFunction extends SqlAvgAggFunction {
8486
public SubstraitAvgAggFunction(SqlKind kind) {
8587
super(kind);
@@ -92,8 +94,8 @@ public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
9294
}
9395

9496
/**
95-
* Extension of {@link SqlSumEmptyIsZeroAggFunction} that ALWAYS infers a NOT NULL BIGINT return
96-
* type
97+
* Substrait variant of {@link SqlSumEmptyIsZeroAggFunction} that forces a non-null BIGINT return
98+
* type and uses a user-friendly name.
9799
*/
98100
private static class SubstraitSumEmptyIsZeroAggFunction
99101
extends org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction {
@@ -103,8 +105,7 @@ public SubstraitSumEmptyIsZeroAggFunction() {
103105

104106
@Override
105107
public String getName() {
106-
// the default name for this function is `$sum0`
107-
// override this to `sum0` which is a nicer name to use in queries
108+
// Override default `$sum0` with `sum0` for readability
108109
return "sum0";
109110
}
110111

isthmus/src/main/java/io/substrait/isthmus/CallConverter.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,22 @@
66
import org.apache.calcite.rex.RexCall;
77
import org.apache.calcite.rex.RexNode;
88

9+
/**
10+
* Functional interface for converting Calcite {@link RexCall} expressions into Substrait {@link
11+
* Expression}s.
12+
*
13+
* <p>Implementations should return an {@link Optional} containing the converted expression, or
14+
* {@link Optional#empty()} if the call is not handled.
15+
*/
916
@FunctionalInterface
1017
public interface CallConverter {
18+
19+
/**
20+
* Converts a Calcite {@link RexCall} into a Substrait {@link Expression}.
21+
*
22+
* @param call the Calcite function/operator call to convert
23+
* @param topLevelConverter a function for converting nested {@link RexNode} operands
24+
* @return an {@link Optional} containing the converted expression, or empty if not applicable
25+
*/
1126
Optional<Expression> convert(RexCall call, Function<RexNode, Expression> topLevelConverter);
1227
}

isthmus/src/main/java/io/substrait/isthmus/ExtensionUtils.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@
77
import java.util.Set;
88
import java.util.stream.Collectors;
99

10+
/**
11+
* Utility methods for working with Substrait extensions.
12+
*
13+
* <p>Provides helpers to identify and extract dynamic (custom/user-defined) functions from an
14+
* {@link io.substrait.extension.SimpleExtension.ExtensionCollection}.
15+
*/
1016
public class ExtensionUtils {
1117

1218
/**

isthmus/src/main/java/io/substrait/isthmus/OuterReferenceResolver.java

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,52 @@
1515
import org.apache.calcite.rex.RexSubQuery;
1616
import org.apache.calcite.rex.RexUtil.SubQueryCollector;
1717

18-
/** Resolve correlated variable and get Depth map for RexFieldAccess */
19-
// See OuterReferenceResolver.md for explanation how the Depth map is computed.
18+
/**
19+
* Resolve correlated variables and compute a depth map for {@link RexFieldAccess}.
20+
*
21+
* <p>Traverses a {@link RelNode} tree and:
22+
*
23+
* <ul>
24+
* <li>Tracks nesting depth of {@link CorrelationId}s across filters, projects, subqueries, and
25+
* correlates
26+
* <li>Computes "steps out" for each {@link RexFieldAccess} referencing a {@link
27+
* RexCorrelVariable}
28+
* </ul>
29+
*
30+
* <p>See OuterReferenceResolver.md for details on how the depth map is computed.
31+
*/
2032
public class OuterReferenceResolver extends RelNodeVisitor<RelNode, RuntimeException> {
2133

2234
private final Map<CorrelationId, Integer> nestedDepth;
2335
private final Map<RexFieldAccess, Integer> fieldAccessDepthMap;
2436

2537
private final RexVisitor rexVisitor = new RexVisitor(this);
2638

39+
/** Creates a new resolver with empty depth tracking maps. */
2740
public OuterReferenceResolver() {
2841
nestedDepth = new HashMap<>();
2942
fieldAccessDepthMap = new IdentityHashMap<>();
3043
}
3144

45+
/**
46+
* Applies the resolver to a {@link RelNode} tree, computing the depth map.
47+
*
48+
* @param r the root relational node
49+
* @return the map from each visited {@link RexFieldAccess} to its computed depth
50+
* @throws RuntimeException if the visitor encounters an unrecoverable condition
51+
*/
3252
public Map<RexFieldAccess, Integer> apply(RelNode r) {
3353
reverseAccept(r);
3454
return fieldAccessDepthMap;
3555
}
3656

57+
/**
58+
* Visits a {@link Filter}, registering any correlation variables and visiting its condition.
59+
*
60+
* @param filter the filter node
61+
* @return the result of {@link RelNodeVisitor#visit(Filter)}
62+
* @throws RuntimeException if traversal fails
63+
*/
3764
@Override
3865
public RelNode visit(Filter filter) throws RuntimeException {
3966
for (CorrelationId id : filter.getVariablesSet()) {
@@ -43,6 +70,16 @@ public RelNode visit(Filter filter) throws RuntimeException {
4370
return super.visit(filter);
4471
}
4572

73+
/**
74+
* Visits a {@link Correlate}, handling correlation depth for both sides.
75+
*
76+
* <p>Special case: the right side is a correlated subquery in the rel tree (not a REX), so we
77+
* manually adjust depth before/after visiting it.
78+
*
79+
* @param correlate the correlate (correlated join) node
80+
* @return the correlate node
81+
* @throws RuntimeException if traversal fails
82+
*/
4683
@Override
4784
public RelNode visit(Correlate correlate) throws RuntimeException {
4885
for (CorrelationId id : correlate.getVariablesSet()) {
@@ -63,6 +100,13 @@ public RelNode visit(Correlate correlate) throws RuntimeException {
63100
return correlate;
64101
}
65102

103+
/**
104+
* Visits a generic {@link RelNode}, applying the resolver to all the node inputs.
105+
*
106+
* @param other the node to visit
107+
* @return the node
108+
* @throws RuntimeException if traversal fails
109+
*/
66110
@Override
67111
public RelNode visitOther(RelNode other) throws RuntimeException {
68112
for (RelNode child : other.getInputs()) {
@@ -71,6 +115,14 @@ public RelNode visitOther(RelNode other) throws RuntimeException {
71115
return other;
72116
}
73117

118+
/**
119+
* Visits a {@link Project}, registering correlation variables and visiting any subqueries within
120+
* its expressions.
121+
*
122+
* @param project the project node
123+
* @return the result of {@link RelNodeVisitor#visit(Project)}
124+
* @throws RuntimeException if traversal fails
125+
*/
74126
@Override
75127
public RelNode visit(Project project) throws RuntimeException {
76128
for (CorrelationId id : project.getVariablesSet()) {
@@ -84,13 +136,25 @@ public RelNode visit(Project project) throws RuntimeException {
84136
return super.visit(project);
85137
}
86138

139+
/** Rex visitor used to track correlation depth within expressions and subqueries. */
87140
private static class RexVisitor extends RexShuttle {
88141
final OuterReferenceResolver referenceResolver;
89142

143+
/**
144+
* Creates a new Rex visitor bound to the given reference resolver.
145+
*
146+
* @param referenceResolver the parent resolver maintaining depth maps
147+
*/
90148
RexVisitor(OuterReferenceResolver referenceResolver) {
91149
this.referenceResolver = referenceResolver;
92150
}
93151

152+
/**
153+
* Increments correlation depth when entering a subquery and decrements when exiting.
154+
*
155+
* @param subQuery the subquery expression
156+
* @return the same subquery
157+
*/
94158
@Override
95159
public RexNode visitSubQuery(RexSubQuery subQuery) {
96160
referenceResolver.nestedDepth.replaceAll((k, v) -> v + 1);
@@ -101,6 +165,12 @@ public RexNode visitSubQuery(RexSubQuery subQuery) {
101165
return subQuery;
102166
}
103167

168+
/**
169+
* Records depth for {@link RexFieldAccess} referencing a {@link RexCorrelVariable}.
170+
*
171+
* @param fieldAccess the field access expression
172+
* @return the same field access
173+
*/
104174
@Override
105175
public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
106176
if (fieldAccess.getReferenceExpr() instanceof RexCorrelVariable) {

0 commit comments

Comments
 (0)