Skip to content

Commit c288cc1

Browse files
committed
Add missing extension YAMLs and test for completeness
Closes #722
1 parent d765711 commit c288cc1

4 files changed

Lines changed: 82 additions & 0 deletions

File tree

core/build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ dependencies {
8787
testImplementation(platform(libs.junit.bom))
8888
testImplementation(libs.protobuf.java.util)
8989
testImplementation(libs.guava)
90+
testImplementation(libs.bundles.jackson)
9091

9192
testImplementation(libs.junit.jupiter)
9293
testRuntimeOnly(libs.junit.platform.launcher)

core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ public class DefaultExtensionCatalog {
1414
public static final String FUNCTIONS_AGGREGATE_APPROX =
1515
"extension:io.substrait:functions_aggregate_approx";
1616

17+
/** Extension identifier for aggregate functions with decimal output. */
18+
public static final String FUNCTIONS_AGGREGATE_DECIMAL_OUTPUT =
19+
"extension:io.substrait:functions_aggregate_decimal_output";
20+
1721
/** Extension identifier for generic aggregate functions. */
1822
public static final String FUNCTIONS_AGGREGATE_GENERIC =
1923
"extension:io.substrait:functions_aggregate_generic";
@@ -37,6 +41,9 @@ public class DefaultExtensionCatalog {
3741
/** Extension identifier for geometry functions. */
3842
public static final String FUNCTIONS_GEOMETRY = "extension:io.substrait:functions_geometry";
3943

44+
/** Extension identifier for list functions. */
45+
public static final String FUNCTIONS_LIST = "extension:io.substrait:functions_list";
46+
4047
/** Extension identifier for logarithmic functions. */
4148
public static final String FUNCTIONS_LOGARITHMIC = "extension:io.substrait:functions_logarithmic";
4249

@@ -78,7 +85,10 @@ private static SimpleExtension.ExtensionCollection loadDefaultCollection() {
7885
"logarithmic",
7986
"rounding",
8087
"rounding_decimal",
88+
"set",
8189
"string")
90+
// TODO(#688): functions_list.yaml is not loaded here because it uses lambda type
91+
// expressions (e.g. func<any1 -> any2>) that are not yet supported by the type parser.
8292
.stream()
8393
.map(c -> String.format("/functions_%s.yaml", c))
8494
.collect(Collectors.toList());

core/src/main/java/io/substrait/extension/SimpleExtension.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,11 @@ public ScalarFunctionVariant getScalarFunction(FunctionAnchor anchor) {
680680
anchor.key(), anchor.urn()));
681681
}
682682

683+
/** Returns true if the given URN has any functions or types loaded in this collection. */
684+
public boolean containsUrn(String urn) {
685+
return urnSupplier.get().contains(urn) || types().stream().anyMatch(t -> t.urn().equals(urn));
686+
}
687+
683688
private void checkUrn(String name) {
684689
if (urnSupplier.get().contains(name)) {
685690
return;
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package io.substrait.extension;
2+
3+
import static io.substrait.extension.DefaultExtensionCatalog.DEFAULT_COLLECTION;
4+
import static org.junit.jupiter.api.Assertions.assertTrue;
5+
6+
import com.fasterxml.jackson.databind.JsonNode;
7+
import com.fasterxml.jackson.databind.ObjectMapper;
8+
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
9+
import java.io.File;
10+
import java.io.IOException;
11+
import java.util.List;
12+
import java.util.Set;
13+
import org.junit.jupiter.api.Test;
14+
15+
/**
16+
* Verifies that every extension YAML in substrait/extensions is loaded by {@link
17+
* DefaultExtensionCatalog}.
18+
*/
19+
class DefaultExtensionCatalogTest {
20+
21+
private static final Set<String> UNSUPPORTED_FILES =
22+
Set.of(
23+
// TODO: aggregate_decimal_output defines count and approx_count_distinct with
24+
// decimal<38,0> return types instead of i64. When loaded alongside aggregate_generic,
25+
// the same function key (e.g. count:any) maps to the same Calcite operator twice,
26+
// which breaks the reverse lookup in FunctionConverter.getSqlOperatorFromSubstraitFunc.
27+
// Fixing this requires either deduplicating the operator map or adding type-based
28+
// disambiguation for aggregate functions.
29+
"functions_aggregate_decimal_output.yaml",
30+
"functions_geometry.yaml", // user-defined types not supported in Calcite type conversion
31+
"functions_list.yaml", // TODO(#688): remove once lambda types are supported
32+
"type_variations.yaml", // type variations not yet supported by extension loader
33+
"unknown.yaml" // unknown type extension not yet loaded
34+
);
35+
36+
private static final ObjectMapper YAML_MAPPER = new ObjectMapper(new YAMLFactory());
37+
38+
@Test
39+
void allExtensionYamlFilesAreLoaded() throws IOException {
40+
List<File> yamlFiles = getExtensionYamlFiles();
41+
42+
for (File file : yamlFiles) {
43+
if (UNSUPPORTED_FILES.contains(file.getName())) {
44+
continue;
45+
}
46+
String urn = parseUrn(file);
47+
assertTrue(
48+
DEFAULT_COLLECTION.containsUrn(urn),
49+
file.getName() + " not loaded by DefaultExtensionCatalog (urn: " + urn + ")");
50+
}
51+
}
52+
53+
private static String parseUrn(File yamlFile) throws IOException {
54+
JsonNode doc = YAML_MAPPER.readTree(yamlFile);
55+
JsonNode urnNode = doc.get("urn");
56+
return urnNode == null ? null : urnNode.asText();
57+
}
58+
59+
private static List<File> getExtensionYamlFiles() {
60+
File extensionsDir = new File("../substrait/extensions");
61+
assertTrue(extensionsDir.isDirectory(), "substrait/extensions directory not found");
62+
File[] files = extensionsDir.listFiles((dir, name) -> name.endsWith(".yaml"));
63+
assertTrue(files != null && files.length > 0, "No YAML files found");
64+
return List.of(files);
65+
}
66+
}

0 commit comments

Comments
 (0)