Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,47 @@
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;

/**
* Converts Calcite {@link AggregateCall} instances into Substrait aggregate {@link
* AggregateFunctionInvocation}s using configured function variants and signatures.
*
* <p>Handles special cases (e.g., approximate distinct count) and collation/sort fields.
*/
public class AggregateFunctionConverter
extends FunctionConverter<
SimpleExtension.AggregateFunctionVariant,
AggregateFunctionInvocation,
AggregateFunctionConverter.WrappedAggregateCall> {

/**
* Returns the supported aggregate signatures used for matching functions.
*
* @return immutable list of aggregate signatures
*/
@Override
protected ImmutableList<FunctionMappings.Sig> getSigs() {
return FunctionMappings.AGGREGATE_SIGS;
}

/**
* Creates a converter with the given function variants and type factory.
*
* @param functions available aggregate function variants
* @param typeFactory Calcite type factory
*/
public AggregateFunctionConverter(
List<SimpleExtension.AggregateFunctionVariant> functions, RelDataTypeFactory typeFactory) {
super(functions, typeFactory);
}

/**
* Creates a converter with additional signatures and a type converter.
*
* @param functions available aggregate function variants
* @param additionalSignatures extra signatures to consider
* @param typeFactory Calcite type factory
* @param typeConverter Substrait type converter
*/
public AggregateFunctionConverter(
List<SimpleExtension.AggregateFunctionVariant> functions,
List<FunctionMappings.Sig> additionalSignatures,
Expand All @@ -48,6 +73,15 @@ public AggregateFunctionConverter(
super(functions, additionalSignatures, typeFactory, typeConverter);
}

/**
* Builds a Substrait aggregate invocation from the matched call and arguments.
*
* @param call wrapped aggregate call
* @param function matched Substrait function variant
* @param arguments converted arguments
* @param outputType result type of the invocation
* @return aggregate function invocation
*/
@Override
protected AggregateFunctionInvocation generateBinding(
WrappedAggregateCall call,
Expand Down Expand Up @@ -75,6 +109,15 @@ protected AggregateFunctionInvocation generateBinding(
arguments);
}

/**
* Attempts to convert a Calcite aggregate call to a Substrait invocation.
*
* @param input input relational node
* @param inputType Substrait input struct type
* @param call Calcite aggregate call
* @param topLevelConverter converter for RexNodes to Expressions
* @return optional Substrait aggregate invocation
*/
public Optional<AggregateFunctionInvocation> convert(
RelNode input,
Type.Struct inputType,
Expand All @@ -93,6 +136,12 @@ public Optional<AggregateFunctionInvocation> convert(
return m.attemptMatch(wrapped, topLevelConverter);
}

/**
* Resolves the appropriate function finder, applying Substrait-specific variants when needed.
*
* @param call Calcite aggregate call
* @return function finder for the resolved aggregate function, or {@code null} if none
*/
protected FunctionFinder getFunctionFinder(AggregateCall call) {
// replace COUNT() + distinct == true and approximate == true with APPROX_COUNT_DISTINCT
// before converting into substrait function
Expand All @@ -108,12 +157,21 @@ protected FunctionFinder getFunctionFinder(AggregateCall call) {
return signatures.get(lookupFunction);
}

/** Lightweight wrapper around {@link AggregateCall} providing operands and type access. */
static class WrappedAggregateCall implements FunctionConverter.GenericCall {
private final AggregateCall call;
private final RelNode input;
private final RexBuilder rexBuilder;
private final Type.Struct inputType;

/**
* Creates a new wrapped aggregate call.
*
* @param call underlying Calcite aggregate call
* @param input input relational node
* @param rexBuilder Rex builder for operand construction
* @param inputType Substrait input struct type
*/
private WrappedAggregateCall(
AggregateCall call, RelNode input, RexBuilder rexBuilder, Type.Struct inputType) {
this.call = call;
Expand All @@ -122,15 +180,30 @@ private WrappedAggregateCall(
this.inputType = inputType;
}

/**
* Returns operands as input references over the argument list.
*
* @return stream of RexNode operands
*/
@Override
public Stream<RexNode> getOperands() {
return call.getArgList().stream().map(r -> rexBuilder.makeInputRef(input, r));
}

/**
* Exposes the underlying Calcite aggregate call.
*
* @return the aggregate call
*/
public AggregateCall getUnderlying() {
return call;
}

/**
* Returns the type of the aggregate call result.
*
* @return Calcite result type
*/
@Override
public RelDataType getType() {
return call.getType();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,24 @@
import org.apache.calcite.sql.SqlKind;
import org.jspecify.annotations.Nullable;

/**
* Collection of small, composable {@link CallConverter}s for common Calcite {@link RexCall}s (e.g.,
* CAST, CASE, REINTERPRET, SEARCH). Each converter returns a Substrait {@link Expression} or {@code
* null} when the call is not handled.
*
* <p>Use {@link #defaults(TypeConverter)} to get a standard set.
*/
public class CallConverters {

/**
* Converter for {@link SqlKind#CAST} and {@link SqlKind#SAFE_CAST} to Substrait {@link
* Expression.Cast}.
*
* <p>On SAFE_CAST, sets {@link Expression.FailureBehavior#RETURN_NULL}; otherwise
* THROW_EXCEPTION.
*
* @see ExpressionCreator#cast(Type, Expression, Expression.FailureBehavior)
*/
public static Function<TypeConverter, SimpleCallConverter> CAST =
typeConverter ->
(call, visitor) -> {
Expand Down Expand Up @@ -182,6 +198,9 @@ else if (operand instanceof Expression.StructLiteral
* Expand {@link org.apache.calcite.util.Sarg} values in a calcite `SqlSearchOperator` into
* simpler expressions. The expansion logic is encoded in {@link RexUtil#expandSearch(RexBuilder,
* RexProgram, RexNode)}
*
* <p>Returns a factory of {@link SimpleCallConverter} that expands SEARCH calls using the
* provided {@link RexBuilder}
*/
public static Function<RexBuilder, SimpleCallConverter> CREATE_SEARCH_CONV =
(RexBuilder rexBuilder) ->
Expand All @@ -195,6 +214,12 @@ else if (operand instanceof Expression.StructLiteral
}
};

/**
* Returns the default set of converters for common calls.
*
* @param typeConverter type mapper between Substrait and Calcite types
* @return list of default {@link CallConverter}s
*/
public static List<CallConverter> defaults(TypeConverter typeConverter) {
return ImmutableList.of(
new FieldSelectionConverter(typeConverter),
Expand All @@ -206,10 +231,27 @@ public static List<CallConverter> defaults(TypeConverter typeConverter) {
new SqlMapValueConstructorCallConverter());
}

/** Minimal interface for single-call converters used by {@link CallConverter}. */
public interface SimpleCallConverter extends CallConverter {

/**
* Converts a given {@link RexCall} to a Substrait {@link Expression}, or returns {@code null}
* if not handled.
*
* @param call the Calcite call to convert
* @param topLevelConverter converter for nested {@link RexNode} operands
* @return converted expression, or {@code null} if not applicable
*/
@Nullable Expression apply(RexCall call, Function<RexNode, Expression> topLevelConverter);

/**
* Default adapter to {@link CallConverter#convert(RexCall, Function)} returning {@link
* Optional#empty()} when {@link #apply(RexCall, Function)} returns {@code null}.
*
* @param call the Calcite call to convert
* @param topLevelConverter converter for nested {@link RexNode} operands
* @return optional converted expression
*/
@Override
default Optional<Expression> convert(
RexCall call, Function<RexNode, Expression> topLevelConverter) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,34 @@ public class ExpressionRexConverter

private static final long MILLIS_IN_DAY = TimeUnit.DAYS.toMillis(1);

/** Calcite {@link RelDataTypeFactory} used for creating and managing relational types. */
protected final RelDataTypeFactory typeFactory;

/** Converter for mapping between Substrait and Calcite types. */
protected final TypeConverter typeConverter;

/** Calcite {@link RexBuilder} for constructing {@link org.apache.calcite.rex.RexNode}s. */
protected final RexBuilder rexBuilder;

/** Converter for Substrait scalar function invocations to Calcite {@link SqlOperator}s. */
protected final ScalarFunctionConverter scalarFunctionConverter;

/** Converter for Substrait window function invocations to Calcite {@link SqlOperator}s. */
protected final WindowFunctionConverter windowFunctionConverter;

/** Converter for Substrait relational nodes to Calcite {@link RelNode}s, used for subqueries. */
protected SubstraitRelNodeConverter relNodeConverter;

/**
* Creates an {@code ExpressionRexConverter} for converting Substrait expressions to Calcite Rex
* nodes.
*
* @param typeFactory Calcite {@link org.apache.calcite.rel.type.RelDataTypeFactory} for type
* creation
* @param scalarFunctionConverter converter for scalar function invocations
* @param windowFunctionConverter converter for window function invocations
* @param typeConverter converter for Substrait ↔ Calcite type mappings
*/
public ExpressionRexConverter(
RelDataTypeFactory typeFactory,
ScalarFunctionConverter scalarFunctionConverter,
Expand All @@ -101,6 +122,11 @@ public ExpressionRexConverter(
this.windowFunctionConverter = windowFunctionConverter;
}

/**
* Sets the {@link SubstraitRelNodeConverter} used for converting subqueries.
*
* @param substraitRelNodeConverter converter for Substrait relational nodes
*/
public void setRelNodeConverter(final SubstraitRelNodeConverter substraitRelNodeConverter) {
this.relNodeConverter = substraitRelNodeConverter;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,42 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** Converts field selections from Calcite representation. */
/**
* Converts Calcite {@link RexCall} ITEM operators into Substrait {@link FieldReference}
* expressions.
*
* <p>Handles dereferencing of ROW, ARRAY, and MAP types using literal indices or keys.
*/
public class FieldSelectionConverter implements CallConverter {
private static final Logger LOGGER = LoggerFactory.getLogger(FieldSelectionConverter.class);

private final TypeConverter typeConverter;

/**
* Creates a converter for field selection operations.
*
* @param typeConverter converter for Substrait ↔ Calcite type mappings
*/
public FieldSelectionConverter(TypeConverter typeConverter) {
super();
this.typeConverter = typeConverter;
}

/**
* Converts a Calcite ITEM operator into a Substrait {@link FieldReference}, if applicable.
*
* <p>Supports:
*
* <ul>
* <li>ROW dereference by integer index
* <li>ARRAY dereference by integer index
* <li>MAP dereference by string key
* </ul>
*
* @param call the Calcite ITEM operator call
* @param topLevelConverter function to convert nested operands
* @return an {@link Optional} containing the converted expression, or empty if not applicable
*/
@Override
public Optional<Expression> convert(
RexCall call, Function<RexNode, Expression> topLevelConverter) {
Expand Down Expand Up @@ -96,6 +121,12 @@ public Optional<Expression> convert(
return Optional.empty();
}

/**
* Converts a numeric literal to an integer index.
*
* @param l literal to convert
* @return optional integer value, empty if not numeric
*/
private Optional<Integer> toInt(Expression.Literal l) {
if (l instanceof Expression.I8Literal) {
return Optional.of(((Expression.I8Literal) l).value());
Expand All @@ -110,6 +141,12 @@ private Optional<Integer> toInt(Expression.Literal l) {
return Optional.empty();
}

/**
* Converts a fixed-char literal to a string key.
*
* @param l literal to convert
* @return optional string value, empty if not a fixed-char literal
*/
public Optional<String> toString(Expression.Literal l) {
if (!(l instanceof Expression.FixedCharLiteral)) {
LOGGER.atWarn().log("Literal expected to be char type but was not. {}", l);
Expand Down
Loading