Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,48 @@
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;

/**
* Converts Calcite {@link RexNode} trees to Substrait {@link Expression}s.
*
* <p>Delegates function calls to registered {@link CallConverter}s and supports window function
* conversion via {@link WindowFunctionConverter}. Some Rex node kinds are intentionally unsupported
* and will throw {@link UnsupportedOperationException}.
*/
public class RexExpressionConverter implements RexVisitor<Expression> {

private final List<CallConverter> callConverters;
private final SubstraitRelVisitor relVisitor;
private final TypeConverter typeConverter;
private WindowFunctionConverter windowFunctionConverter;

/**
* Creates a converter with an explicit {@link SubstraitRelVisitor} and one or more call
* converters.
*
* @param relVisitor visitor used to convert subqueries/relations
* @param callConverters converters for Rex calls
*/
public RexExpressionConverter(SubstraitRelVisitor relVisitor, CallConverter... callConverters) {
this(relVisitor, Arrays.asList(callConverters), null, TypeConverter.DEFAULT);
}

/**
* Creates a converter with the given call converters and default {@link TypeConverter}.
*
* @param callConverters converters for Rex calls
*/
public RexExpressionConverter(CallConverter... callConverters) {
this(null, Arrays.asList(callConverters), null, TypeConverter.DEFAULT);
}

/**
* Creates a converter with full configuration.
*
* @param relVisitor visitor used to convert subqueries/relations; may be {@code null}
* @param callConverters converters for Rex calls
* @param windowFunctionConverter converter for window functions; may be {@code null}
* @param typeConverter converter from Calcite types to Substrait types
*/
public RexExpressionConverter(
SubstraitRelVisitor relVisitor,
List<CallConverter> callConverters,
Expand All @@ -59,19 +86,34 @@ public RexExpressionConverter(
}

/**
* Only used for testing. Missing `ScalarFunctionConverter`, `CallConverters.CREATE_SEARCH_CONV`
* Testing-only constructor that wires default converters.
*
* <p>Missing {@code ScalarFunctionConverter} and {@code CallConverters.CREATE_SEARCH_CONV}.
*/
public RexExpressionConverter() {
this(null, CallConverters.defaults(TypeConverter.DEFAULT), null, TypeConverter.DEFAULT);
// TODO: Hide this AND/OR UPDATE tests
}

/**
* Converts a {@link RexInputRef} to a root struct field reference.
*
* @param inputRef the input reference
* @return a Substrait field reference expression
*/
@Override
public Expression visitInputRef(RexInputRef inputRef) {
return FieldReference.newRootStructReference(
inputRef.getIndex(), typeConverter.toSubstrait(inputRef.getType()));
}

/**
* Converts a {@link RexCall} using registered {@link CallConverter}s.
*
* @param call the Rex call node
* @return the converted Substrait expression
* @throws IllegalArgumentException if no converter can handle the call
*/
@Override
public Expression visitCall(RexCall call) {
for (CallConverter c : callConverters) {
Expand All @@ -84,6 +126,12 @@ public Expression visitCall(RexCall call) {
throw new IllegalArgumentException(callConversionFailureMessage(call));
}

/**
* Builds a concise failure message for an unsupported call conversion.
*
* @param call the Rex call node
* @return a human-readable message describing the failure
*/
private String callConversionFailureMessage(RexCall call) {
return String.format(
"Unable to convert call %s(%s).",
Expand All @@ -93,11 +141,24 @@ private String callConversionFailureMessage(RexCall call) {
.collect(Collectors.joining(", ")));
}

/**
* Converts a {@link RexLiteral} to a Substrait literal expression.
*
* @param literal the Rex literal
* @return the converted Substrait expression
*/
@Override
public Expression visitLiteral(RexLiteral literal) {
return (new LiteralConverter(typeConverter)).convert(literal);
}

/**
* Converts a {@link RexOver} window function call.
*
* @param over the windowed call
* @return the converted Substrait expression
* @throws IllegalArgumentException if {@code IGNORE NULLS} is used or conversion fails
*/
@Override
public Expression visitOver(RexOver over) {
if (over.ignoreNulls()) {
Expand All @@ -109,21 +170,49 @@ public Expression visitOver(RexOver over) {
.orElseThrow(() -> new IllegalArgumentException(callConversionFailureMessage(over)));
}

/**
* Not supported.
*
* @param correlVariable the correl variable
* @return never returns
* @throws UnsupportedOperationException always
*/
@Override
public Expression visitCorrelVariable(RexCorrelVariable correlVariable) {
throw new UnsupportedOperationException("RexCorrelVariable not supported");
}

/**
* Not supported.
*
* @param dynamicParam the dynamic parameter
* @return never returns
* @throws UnsupportedOperationException always
*/
@Override
public Expression visitDynamicParam(RexDynamicParam dynamicParam) {
throw new UnsupportedOperationException("RexDynamicParam not supported");
}

/**
* Not supported.
*
* @param rangeRef the range ref
* @return never returns
* @throws UnsupportedOperationException always
*/
@Override
public Expression visitRangeRef(RexRangeRef rangeRef) {
throw new UnsupportedOperationException("RexRangeRef not supported");
}

/**
* Converts a {@link RexFieldAccess} to a Substrait field reference expression.
*
* @param fieldAccess the field access
* @return the converted Substrait expression
* @throws UnsupportedOperationException for unsupported reference kinds
*/
@Override
public Expression visitFieldAccess(RexFieldAccess fieldAccess) {
SqlKind kind = fieldAccess.getReferenceExpr().getKind();
Expand Down Expand Up @@ -155,6 +244,13 @@ public Expression visitFieldAccess(RexFieldAccess fieldAccess) {
}
}

/**
* Converts a {@link RexSubQuery} into a Substrait set or scalar subquery expression.
*
* @param subQuery the subquery node
* @return the converted Substrait expression
* @throws UnsupportedOperationException for unsupported subquery operators
*/
@Override
public Expression visitSubQuery(RexSubQuery subQuery) {
Rel rel = relVisitor.apply(subQuery.rel);
Expand Down Expand Up @@ -185,31 +281,73 @@ public Expression visitSubQuery(RexSubQuery subQuery) {
throw new UnsupportedOperationException("RexSubQuery not supported");
}

/**
* Not supported.
*
* @param fieldRef the table input reference
* @return never returns
* @throws UnsupportedOperationException always
*/
@Override
public Expression visitTableInputRef(RexTableInputRef fieldRef) {
throw new UnsupportedOperationException("RexTableInputRef not supported");
}

/**
* Not supported.
*
* @param localRef the local reference
* @return never returns
* @throws UnsupportedOperationException always
*/
@Override
public Expression visitLocalRef(RexLocalRef localRef) {
throw new UnsupportedOperationException("RexLocalRef not supported");
}

/**
* Not supported.
*
* @param fieldRef the pattern field reference
* @return never returns
* @throws UnsupportedOperationException always
*/
@Override
public Expression visitPatternFieldRef(RexPatternFieldRef fieldRef) {
throw new UnsupportedOperationException("RexPatternFieldRef not supported");
}

/**
* Not supported.
*
* @param rexLambda the lambda
* @return never returns
* @throws UnsupportedOperationException always
*/
@Override
public Expression visitLambda(RexLambda rexLambda) {
throw new UnsupportedOperationException("RexLambda not supported");
}

/**
* Not supported.
*
* @param rexLambdaRef the lambda reference
* @return never returns
* @throws UnsupportedOperationException always
*/
@Override
public Expression visitLambdaRef(RexLambdaRef rexLambdaRef) {
throw new UnsupportedOperationException("RexLambdaRef not supported");
}

/**
* Not supported.
*
* @param nodeAndFieldIndex the node/field index wrapper
* @return never returns
* @throws UnsupportedOperationException always
*/
@Override
public Expression visitNodeAndFieldIndex(RexNodeAndFieldIndex nodeAndFieldIndex) {
throw new UnsupportedOperationException("RexNodeAndFieldIndex not supported");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;

/**
* Converts Calcite {@link RexCall} scalar functions to Substrait {@link Expression} using known
* Substrait {@link SimpleExtension.ScalarFunctionVariant} declarations.
*
* <p>Supports custom function mappers for special cases (e.g., TRIM, SQRT), and falls back to
* default signature-based matching. Produces {@link Expression.ScalarFunctionInvocation}.
*/
public class ScalarFunctionConverter
extends FunctionConverter<
SimpleExtension.ScalarFunctionVariant,
Expand All @@ -30,11 +37,25 @@ public class ScalarFunctionConverter
*/
private final List<ScalarFunctionMapper> mappers;

/**
* Creates a converter with the given functions and type factory.
*
* @param functions available Substrait scalar function variants
* @param typeFactory Calcite type factory for type conversions
*/
public ScalarFunctionConverter(
List<SimpleExtension.ScalarFunctionVariant> functions, RelDataTypeFactory typeFactory) {
this(functions, Collections.emptyList(), typeFactory, TypeConverter.DEFAULT);
}

/**
* Creates a converter with additional signatures and a custom type converter.
*
* @param functions available Substrait scalar function variants
* @param additionalSignatures extra Calcite-to-Substrait signature mappings
* @param typeFactory Calcite type factory for type conversions
* @param typeConverter converter for Calcite {@link RelDataType} to Substrait {@link Type}
*/
public ScalarFunctionConverter(
List<SimpleExtension.ScalarFunctionVariant> functions,
List<FunctionMappings.Sig> additionalSignatures,
Expand All @@ -53,11 +74,24 @@ public ScalarFunctionConverter(
new StrptimeTimestampFunctionMapper(functions));
}

/**
* Returns the set of known scalar function signatures.
*
* @return immutable list of scalar signatures
*/
@Override
protected ImmutableList<FunctionMappings.Sig> getSigs() {
return FunctionMappings.SCALAR_SIGS;
}

/**
* Converts a {@link RexCall} into a Substrait {@link Expression}, applying any registered custom
* mapping first, then default matching if needed.
*
* @param call the Calcite function call to convert
* @param topLevelConverter converter for nested operands
* @return the converted expression if a match is found; otherwise {@link Optional#empty()}
*/
@Override
public Optional<Expression> convert(
RexCall call, Function<RexNode, Expression> topLevelConverter) {
Expand Down Expand Up @@ -117,6 +151,15 @@ private boolean isPotentialFunctionMatch(FunctionFinder finder, WrappedScalarCal
return Objects.nonNull(finder) && finder.allowedArgCount((int) call.getOperands().count());
}

/**
* Builds an {@link Expression.ScalarFunctionInvocation} for a matched function.
*
* @param call the wrapped Calcite call providing operands and type
* @param function the Substrait scalar function declaration to invoke
* @param arguments converted argument list for the invocation
* @param outputType the Substrait output type for the invocation
* @return a scalar function invocation expression
*/
@Override
protected Expression generateBinding(
WrappedScalarCall call,
Expand All @@ -130,6 +173,13 @@ protected Expression generateBinding(
.build();
}

/**
* Returns the Substrait arguments for a given scalar invocation, applying any custom mapping if
* present; otherwise returns the invocation's own arguments.
*
* @param expression the scalar function invocation
* @return the argument list, possibly remapped; never {@code null}
*/
public List<FunctionArg> getExpressionArguments(Expression.ScalarFunctionInvocation expression) {
// If a mapping applies to this expression, use it to get the arguments; otherwise default
// behavior.
Expand All @@ -145,6 +195,11 @@ private Optional<List<FunctionArg>> getMappedExpressionArguments(
.orElse(Optional.empty());
}

/**
* Wrapped view of a {@link RexCall} for signature matching.
*
* <p>Provides operand stream and type info used by {@link FunctionFinder}.
*/
protected static class WrappedScalarCall implements FunctionConverter.GenericCall {

private final RexCall delegate;
Expand All @@ -153,11 +208,21 @@ private WrappedScalarCall(RexCall delegate) {
this.delegate = delegate;
}

/**
* Returns the operand stream of the underlying {@link RexCall}.
*
* @return stream of operands
*/
@Override
public Stream<RexNode> getOperands() {
return delegate.getOperands().stream();
}

/**
* Returns the Calcite type of the underlying {@link RexCall}.
*
* @return call type
*/
@Override
public RelDataType getType() {
return delegate.getType();
Expand Down
Loading
Loading