Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions isthmus/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ dependencies {
testImplementation(platform(libs.junit.bom))
testImplementation(libs.junit.jupiter)
testRuntimeOnly(libs.junit.platform.launcher)
testRuntimeOnly(libs.slf4j.jdk14)
implementation(libs.guava)
implementation(libs.protobuf.java.util) {
exclude("com.google.guava", "guava")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
package io.substrait.isthmus;

import io.substrait.extension.DefaultExtensionCatalog;
import io.substrait.extension.SimpleExtension;
import io.substrait.extension.SimpleExtension.Function;
import io.substrait.isthmus.expression.AggregateFunctionConverter;
import io.substrait.isthmus.expression.FunctionMappings;
import io.substrait.isthmus.expression.ScalarFunctionConverter;
import io.substrait.isthmus.expression.WindowFunctionConverter;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlOperatorTable;
import org.apache.calcite.sql.util.SqlOperatorTables;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* A ConverterProvider that automatically creates dynamic function mappings for unmapped extension
* functions.
*
* <p>This provider identifies functions in the extension collection that don't have explicit
* mappings in FunctionMappings and automatically generates SqlOperators and function signatures for
* them. This enables SQL queries to use extension functions without requiring manual mapping
* configuration.
*
* <p>Example use case: Using strftime() from functions_datetime.yaml without adding it to
* FunctionMappings.SCALAR_SIGS.
*
* @see ConverterProvider
* @see SimpleExtensionToSqlOperator
*/
public class AutomaticDynamicFunctionMappingConverterProvider extends ConverterProvider {

  private static final Logger LOGGER =
      LoggerFactory.getLogger(AutomaticDynamicFunctionMappingConverterProvider.class);

  /**
   * Operator table handed out by {@link #getSqlOperatorTable()}: the inherited table, chained with
   * the dynamically generated operators when there are any.
   */
  private final SqlOperatorTable operatorTable;

  /**
   * Creates a provider over {@code DefaultExtensionCatalog.DEFAULT_COLLECTION} using {@code
   * SubstraitTypeSystem.TYPE_FACTORY}.
   */
  public AutomaticDynamicFunctionMappingConverterProvider() {
    this(DefaultExtensionCatalog.DEFAULT_COLLECTION, SubstraitTypeSystem.TYPE_FACTORY);
  }

  /**
   * Creates a provider over the given extensions using {@code SubstraitTypeSystem.TYPE_FACTORY}.
   *
   * @param extensions the extension collection to scan for unmapped functions
   */
  public AutomaticDynamicFunctionMappingConverterProvider(
      SimpleExtension.ExtensionCollection extensions) {
    this(extensions, SubstraitTypeSystem.TYPE_FACTORY);
  }

  /**
   * Creates a provider over the given extensions and type factory.
   *
   * <p>For scalar, aggregate and window functions in turn, the unmapped functions are converted to
   * operators and the matching converter field is replaced with one that also knows the derived
   * signatures. Finally the operator table is built from all generated operators.
   *
   * @param extensions the extension collection to scan for unmapped functions
   * @param typeFactory the Calcite type factory passed to the superclass and the converters
   */
  public AutomaticDynamicFunctionMappingConverterProvider(
      SimpleExtension.ExtensionCollection extensions, RelDataTypeFactory typeFactory) {
    super(extensions, typeFactory);

    final List<SqlOperator> scalarOps = getDynamicScalarOperators();
    this.scalarFunctionConverter = createScalarFunctionConverter(scalarOps);

    final List<SqlOperator> aggregateOps = getDynamicAggregateOperators();
    this.aggregateFunctionConverter = createAggregateFunctionConverter(aggregateOps);

    final List<SqlOperator> windowOps = getDynamicWindowOperators();
    this.windowFunctionConverter = createWindowFunctionConverter(windowOps);

    // Scalar first, then aggregate, then window: the same order in which they were generated.
    final List<SqlOperator> combined =
        Stream.concat(
                Stream.concat(scalarOps.stream(), aggregateOps.stream()), windowOps.stream())
            .collect(Collectors.toList());
    this.operatorTable = buildOperatorTable(combined);
  }

  /** Returns the operator table computed at construction time. */
  @Override
  public SqlOperatorTable getSqlOperatorTable() {
    return operatorTable;
  }

  /** Collects the names of the given functions, in list order, for logging. */
  private static List<String> functionNames(List<? extends Function> functions) {
    return functions.stream().map(fn -> fn.name()).collect(Collectors.toList());
  }

  /** Converts the scalar functions that have no entry in SCALAR_SIGS to operators. */
  private List<SqlOperator> getDynamicScalarOperators() {
    final List<SimpleExtension.ScalarFunctionVariant> unmapped =
        io.substrait.isthmus.expression.FunctionConverter.getUnmappedFunctions(
            extensions.scalarFunctions(), FunctionMappings.SCALAR_SIGS);

    final int count = unmapped.size();
    final List<String> names = functionNames(unmapped);
    LOGGER.info("Dynamically mapping {} unmapped scalar functions: {}", count, names);

    return SimpleExtensionToSqlOperator.from(unmapped, typeFactory);
  }

  /** Converts the aggregate functions that have no entry in AGGREGATE_SIGS to operators. */
  private List<SqlOperator> getDynamicAggregateOperators() {
    final List<SimpleExtension.AggregateFunctionVariant> unmapped =
        io.substrait.isthmus.expression.FunctionConverter.getUnmappedFunctions(
            extensions.aggregateFunctions(), FunctionMappings.AGGREGATE_SIGS);

    final int count = unmapped.size();
    final List<String> names = functionNames(unmapped);
    LOGGER.info("Dynamically mapping {} unmapped aggregate functions: {}", count, names);

    return SimpleExtensionToSqlOperator.from(unmapped, typeFactory);
  }

  /** Converts the window functions that have no entry in WINDOW_SIGS to operators. */
  private List<SqlOperator> getDynamicWindowOperators() {
    final List<SimpleExtension.WindowFunctionVariant> unmapped =
        io.substrait.isthmus.expression.FunctionConverter.getUnmappedFunctions(
            extensions.windowFunctions(), FunctionMappings.WINDOW_SIGS);

    final int count = unmapped.size();
    final List<String> names = functionNames(unmapped);
    LOGGER.info("Dynamically mapping {} unmapped window functions: {}", count, names);

    return SimpleExtensionToSqlOperator.from(unmapped, typeFactory);
  }

  /** Builds a scalar converter that additionally knows the signatures of the given operators. */
  private ScalarFunctionConverter createScalarFunctionConverter(
      List<SqlOperator> dynamicOperators) {
    final List<FunctionMappings.Sig> extraSigs = createDynamicSignatures(dynamicOperators);
    return new ScalarFunctionConverter(
        extensions.scalarFunctions(), extraSigs, typeFactory, typeConverter);
  }

  /** Builds an aggregate converter that additionally knows the signatures of the operators. */
  private AggregateFunctionConverter createAggregateFunctionConverter(
      List<SqlOperator> dynamicOperators) {
    final List<FunctionMappings.Sig> extraSigs = createDynamicSignatures(dynamicOperators);
    return new AggregateFunctionConverter(
        extensions.aggregateFunctions(), extraSigs, typeFactory, typeConverter);
  }

  /** Builds a window converter that additionally knows the signatures of the given operators. */
  private WindowFunctionConverter createWindowFunctionConverter(
      List<SqlOperator> dynamicOperators) {
    final List<FunctionMappings.Sig> extraSigs = createDynamicSignatures(dynamicOperators);
    return new WindowFunctionConverter(
        extensions.windowFunctions(), extraSigs, typeFactory, typeConverter);
  }

  /**
   * Produces one signature per distinct operator name (compared in lower case, {@link
   * Locale#ROOT}).
   *
   * <p>When several operators share a name, the entry keeps the position of the first one seen but
   * holds the last one seen.
   */
  private List<FunctionMappings.Sig> createDynamicSignatures(List<SqlOperator> dynamicOperators) {
    final Map<String, SqlOperator> byLowerCaseName =
        dynamicOperators.stream()
            .collect(
                Collectors.toMap(
                    operator -> operator.getName().toLowerCase(Locale.ROOT),
                    operator -> operator,
                    (earlier, later) -> later,
                    LinkedHashMap::new));

    return byLowerCaseName.values().stream()
        .map(operator -> FunctionMappings.s(operator))
        .collect(Collectors.toList());
  }

  /**
   * Returns the inherited operator table, chained with a table of the given operators unless that
   * list is empty.
   */
  private SqlOperatorTable buildOperatorTable(List<SqlOperator> additionalOperators) {
    final SqlOperatorTable inherited = super.getSqlOperatorTable();
    return additionalOperators.isEmpty()
        ? inherited
        : SqlOperatorTables.chain(inherited, SqlOperatorTables.of(additionalOperators));
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Utility class for converting Substrait {@link SimpleExtension} function definitions (scalar and
Expand All @@ -40,6 +42,8 @@
*/
public final class SimpleExtensionToSqlOperator {

private static final Logger LOGGER = LoggerFactory.getLogger(SimpleExtensionToSqlOperator.class);

private static final RelDataTypeFactory DEFAULT_TYPE_FACTORY =
new JavaTypeFactoryImpl(SubstraitTypeSystem.TYPE_SYSTEM);

Expand Down Expand Up @@ -85,9 +89,47 @@ public static List<SqlOperator> from(
SimpleExtension.ExtensionCollection collection,
RelDataTypeFactory typeFactory,
TypeConverter typeConverter) {
// TODO: add support for windows functions
return Stream.concat(
collection.scalarFunctions().stream(), collection.aggregateFunctions().stream())
List<? extends SimpleExtension.Function> functions =
Stream.of(
collection.scalarFunctions(),
collection.aggregateFunctions(),
collection.windowFunctions())
.flatMap(List::stream)
.collect(Collectors.toList());
return from(functions, typeFactory, typeConverter);
}

/**
 * Converts the given functions to SqlOperators using {@link TypeConverter#DEFAULT}.
 *
 * <p>Delegates to {@link #from(List, RelDataTypeFactory, TypeConverter)}.
 *
 * @param functions the functions to convert
 * @param typeFactory the Calcite type factory
 * @return the SqlOperators produced by the three-argument overload
 */
public static List<SqlOperator> from(
    List<? extends SimpleExtension.Function> functions, RelDataTypeFactory typeFactory) {
  final TypeConverter defaultConverter = TypeConverter.DEFAULT;
  return from(functions, typeFactory, defaultConverter);
}

/**
 * Converts each of the given functions to a SqlOperator.
 *
 * <p>Every element is passed to {@code toSqlFunction} on its own, so the result has exactly one
 * operator per input function, in the same order. No deduplication by name happens here:
 * variants that share a base name but differ in signature (e.g. strftime:ts_str and
 * strftime:ts_string) each yield their own operator.
 *
 * @param functions the functions to convert
 * @param typeFactory the Calcite type factory
 * @param typeConverter the type converter
 * @return one SqlOperator per input function
 */
public static List<SqlOperator> from(
    List<? extends SimpleExtension.Function> functions,
    RelDataTypeFactory typeFactory,
    TypeConverter typeConverter) {
  final Stream<SqlOperator> operators =
      functions.stream().map(fn -> toSqlFunction(fn, typeFactory, typeConverter));
  return operators.collect(Collectors.toList());
}
Expand Down Expand Up @@ -375,7 +417,8 @@ public SqlTypeName visit(ParameterizedType.StringLiteral expr) {
if (type.startsWith("LIST")) {
return SqlTypeName.ARRAY;
}
return super.visit(expr);
LOGGER.warn("Unsupported type literal for Calcite conversion: {}", type);
return SqlTypeName.ANY;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import io.substrait.extension.SimpleExtension.Argument;
import io.substrait.function.ParameterizedType;
import io.substrait.function.ToTypeString;
import io.substrait.function.TypeExpression;
import io.substrait.isthmus.TypeConverter;
import io.substrait.isthmus.Utils;
import io.substrait.isthmus.expression.FunctionMappings.Sig;
Expand Down Expand Up @@ -238,18 +239,42 @@ public boolean allowedArgCount(int count) {

private Optional<F> signatureMatch(List<Type> inputTypes, Type outputType) {
for (F function : functions) {
List<SimpleExtension.Argument> args = function.requiredArguments();
// Make sure that arguments & return are within bounds and match the types
if (function.returnType() instanceof ParameterizedType
&& isMatch(outputType, (ParameterizedType) function.returnType())
&& inputTypesMatchDefinedArguments(inputTypes, args)) {
TypeExpression funcReturnType = function.returnType();
boolean returnTypeMatches = isReturnTypeMatch(outputType, funcReturnType);

List<SimpleExtension.Argument> args = function.requiredArguments();

if (returnTypeMatches && inputTypesMatchDefinedArguments(inputTypes, args)) {
return Optional.of(function);
}
}

return Optional.empty();
}

/**
 * Decides whether a call's output type is compatible with a function's declared return type.
 *
 * <ul>
 *   <li>A {@link ParameterizedType} return type is compared via {@code isMatch}.
 *   <li>Otherwise, a {@link Type} return type is compared through an {@code
 *       IgnoreNullableAndParameters} visitor when the output type is itself a {@link
 *       ParameterizedType}, or by runtime class equality when it is not.
 *   <li>Any other kind of return type expression never matches.
 * </ul>
 *
 * @param outputType the type of the call's output
 * @param funcReturnType the return type expression declared by the function
 * @return {@code true} if the types are considered a match
 */
private boolean isReturnTypeMatch(final Type outputType, final TypeExpression funcReturnType) {
  if (funcReturnType instanceof ParameterizedType) {
    final ParameterizedType parameterizedReturn = (ParameterizedType) funcReturnType;
    return isMatch(outputType, parameterizedReturn);
  }

  if (!(funcReturnType instanceof Type)) {
    return false;
  }

  final Type concreteReturn = (Type) funcReturnType;

  if (!(outputType instanceof ParameterizedType)) {
    // Neither side is parameterized: only the concrete classes are compared.
    return outputType.getClass().equals(concreteReturn.getClass());
  }

  // The output type is parameterized while the declared return type is not.
  final ParameterizedType parameterizedOutput = (ParameterizedType) outputType;
  return parameterizedOutput.accept(new IgnoreNullableAndParameters(concreteReturn));
}

/**
* Checks to see if the given input types satisfy the function arguments given. Checks that
*
Expand Down Expand Up @@ -467,16 +492,17 @@ public Optional<T> attemptMatch(C call, Function<RexNode, Expression> topLevelCo
}
}

// Try matchCoerced even if singularInputType is empty.
// This handles functions with mixed argument types like strftime(timestamp, string)
Optional<T> coerced = matchCoerced(call, outputType, operands);
if (coerced.isPresent()) {
return coerced;
}

if (singularInputType.isPresent()) {
Optional<T> coerced = matchCoerced(call, outputType, operands);
if (coerced.isPresent()) {
return coerced;
}
Optional<T> leastRestrictive = matchByLeastRestrictive(call, outputType, operands);
if (leastRestrictive.isPresent()) {
return leastRestrictive;
}
return matchByLeastRestrictive(call, outputType, operands);
}

return Optional.empty();
}

Expand Down Expand Up @@ -565,4 +591,25 @@ private static boolean isMatch(ParameterizedType actualType, ParameterizedType t
}
return actualType.accept(new IgnoreNullableAndParameters(targetType));
}

/**
 * Returns the functions whose names have no counterpart among the given Sig mappings.
 *
 * <p>Names are compared case-insensitively by lower-casing both sides with {@link Locale#ROOT}.
 *
 * @param functions the function variants to check
 * @param sigs the Sig signatures that are already mapped
 * @param <F> the function variant type
 * @return the elements of {@code functions}, in their original order, whose lower-cased name is
 *     not among the lower-cased Sig names
 */
public static <F extends SimpleExtension.Function> List<F> getUnmappedFunctions(
    List<F> functions, ImmutableList<FunctionMappings.Sig> sigs) {
  final Set<String> knownNames =
      sigs.stream().map(sig -> sig.name().toLowerCase(Locale.ROOT)).collect(Collectors.toSet());

  return functions.stream()
      .filter(
          candidate -> {
            final String lowered = candidate.name().toLowerCase(Locale.ROOT);
            return !knownNames.contains(lowered);
          })
      .collect(Collectors.toList());
}
}
Loading
Loading