Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
strategy:
fail-fast: false
matrix:
producer: [datafusion, duckdb, ibis, isthmus]
producer: [datafusion, duckdb, ibis, isthmus, spark]
steps:
- name: Checkout repository
uses: actions/checkout@v4
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,6 @@ dmypy.json

# PyCharm
.idea/

jars/
substrait_consumer/data/
1 change: 1 addition & 0 deletions requirements-unlocked.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ JPype1
pandas
protobuf==5.28.3
pyarrow
pyspark
pytest
pytest-csv
pytest-snapshot
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ parsy==2.1
pluggy==1.5.0
protobuf==5.28.3
pyarrow==18.1.0
pyspark==3.5.5
pytest==8.3.3
pytest-csv==3.0.0
pytest-snapshot==0.9.0
Expand Down
2 changes: 1 addition & 1 deletion substrait-java
Submodule substrait-java updated 322 files
3 changes: 2 additions & 1 deletion substrait_consumer/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from substrait_consumer.producers.duckdb_producer import DuckDBProducer
from substrait_consumer.producers.ibis_producer import IbisProducer
from substrait_consumer.producers.isthmus_producer import IsthmusProducer
from substrait_consumer.producers.spark_producer import SparkProducer


@pytest.fixture
Expand Down Expand Up @@ -105,7 +106,7 @@ def saveplan(request):

PRODUCERS = {
cls.name(): cls
for cls in [DataFusionProducer, DuckDBProducer, IbisProducer, IsthmusProducer]
for cls in [DataFusionProducer, DuckDBProducer, IbisProducer, IsthmusProducer, SparkProducer]
}
CONSUMERS = {
cls.name(): cls for cls in [AceroConsumer, DataFusionConsumer, DuckDBConsumer]
Expand Down
82 changes: 82 additions & 0 deletions substrait_consumer/producers/spark_producer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from pyspark.sql import SparkSession, DataFrame
from pathlib import Path
import os
import substrait_validator as sv
from .producer import SQLProducer

class SparkProducer(SQLProducer):
"""
Adapts the Spark Substrait producer to the test framework.
"""

@classmethod
def name(self):
return "spark"

def __init__(self):
jars = "io.substrait:core:0.66.0,io.substrait:spark:0.66.0,com.google.protobuf:protobuf-java-util:4.33.0"
self._spark = SparkSession.builder.master("local").appName("SparkProducer").config("spark.jars.packages", jars).getOrCreate()

def _setup(
self, db_connection, local_files: dict[str, str], named_tables: dict[str, str]
):
self._tables = named_tables.keys()

# since Spark always stores the absolute path name in its filesystem read relation,
# we need to add a symlink to a fixed temp directory to avoid different path names on
# different systems.
temp_dir = "/tmp/substrait-io"
link = temp_dir + "/consumer-testing"
Path(temp_dir).mkdir(parents=True, exist_ok=True)
if os.path.exists(link):
os.remove(link)
cwd = os.getcwd()
os.symlink(cwd, link, target_is_directory=True)

for name, path in named_tables.items():
table = self._spark.read.load(path.replace(cwd, link))
table.createOrReplaceTempView(name)

def _produce_substrait(self, sql_query: str, validate=False) -> str:
"""
Produce the Spark substrait plan using the given SQL query.

Parameters:
sql_query:
SQL query.
Returns:
Substrait query plan in json format.
"""

df = self._spark.sql(sql_query)
jvm = self._spark.sparkContext._jvm

sparkPlan = df._jdf.queryExecution().optimizedPlan()

toSubstrait = jvm.io.substrait.spark.logical.ToSubstraitRel()
substrait_plan = toSubstrait.convert(sparkPlan)

proto_converter = jvm.io.substrait.plan.PlanProtoConverter()
proto_plan = proto_converter.toProto(substrait_plan)
proto_json = jvm.com.google.protobuf.util.JsonFormat.printer().print(proto_plan)

if validate:
config = sv.Config()
# Warning: cannot automatically determine whether plan version
# is compatible with the Substrait version
config.override_diagnostic_level(7, "info", "info") # warning
# Warning: did not attempt to resolve YAML: configured recursion
# limit for URI resolution has been reached
config.override_diagnostic_level(2001, "info", "info")
sv.check_plan_valid(proto_json, config)

return proto_json

def _format_sql(self, sql_query: str) -> str:
# The table names are enclosed in single quotes (i.e. string literals)
# These need to be removed because Spark won't tolerate them.
sql = sql_query
for table in ("{"+ t + "}" for t in self._tables):
sql = sql.replace(f"'{table}'", table)

return sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
True
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
{
"extensionUris": [{
"extensionUriAnchor": 2,
"uri": "/functions_arithmetic.yaml"
}, {
"extensionUriAnchor": 1,
"uri": "/functions_rounding.yaml"
}],
"extensions": [{
"extensionFunction": {
"extensionUriReference": 1,
"functionAnchor": 1,
"name": "round:fp64_i32",
"extensionUrnReference": 1
}
}, {
"extensionFunction": {
"extensionUriReference": 2,
"functionAnchor": 2,
"name": "acos:fp64",
"extensionUrnReference": 2
}
}],
"relations": [{
"root": {
"input": {
"fetch": {
"common": {
"direct": {
}
},
"input": {
"project": {
"common": {
"emit": {
"outputMapping": [16]
},
"hint": {
"outputNames": ["ACOS_TAX"]
}
},
"input": {
"read": {
"common": {
"direct": {
}
},
"baseSchema": {
"names": ["l_orderkey", "l_partkey", "l_suppkey", "l_linenumber", "l_quantity", "l_extendedprice", "l_discount", "l_tax", "l_returnflag", "l_linestatus", "l_shipdate", "l_commitdate", "l_receiptdate", "l_shipinstruct", "l_shipmode", "l_comment"],
"struct": {
"types": [{
"i64": {
"nullability": "NULLABILITY_NULLABLE"
}
}, {
"i64": {
"nullability": "NULLABILITY_NULLABLE"
}
}, {
"i64": {
"nullability": "NULLABILITY_NULLABLE"
}
}, {
"i64": {
"nullability": "NULLABILITY_NULLABLE"
}
}, {
"decimal": {
"scale": 2,
"precision": 15,
"nullability": "NULLABILITY_NULLABLE"
}
}, {
"decimal": {
"scale": 2,
"precision": 15,
"nullability": "NULLABILITY_NULLABLE"
}
}, {
"decimal": {
"scale": 2,
"precision": 15,
"nullability": "NULLABILITY_NULLABLE"
}
}, {
"decimal": {
"scale": 2,
"precision": 15,
"nullability": "NULLABILITY_NULLABLE"
}
}, {
"string": {
"nullability": "NULLABILITY_NULLABLE"
}
}, {
"string": {
"nullability": "NULLABILITY_NULLABLE"
}
}, {
"date": {
"nullability": "NULLABILITY_NULLABLE"
}
}, {
"date": {
"nullability": "NULLABILITY_NULLABLE"
}
}, {
"date": {
"nullability": "NULLABILITY_NULLABLE"
}
}, {
"string": {
"nullability": "NULLABILITY_NULLABLE"
}
}, {
"string": {
"nullability": "NULLABILITY_NULLABLE"
}
}, {
"string": {
"nullability": "NULLABILITY_NULLABLE"
}
}],
"nullability": "NULLABILITY_REQUIRED"
}
},
"localFiles": {
"items": [{
"uriFile": "file:///tmp/substrait-io/consumer-testing/substrait_consumer/data/tpch_parquet/lineitem.parquet",
"length": "26368539",
"parquet": {
}
}]
}
}
},
"expressions": [{
"scalarFunction": {
"functionReference": 1,
"outputType": {
"fp64": {
"nullability": "NULLABILITY_NULLABLE"
}
},
"arguments": [{
"value": {
"scalarFunction": {
"functionReference": 2,
"outputType": {
"fp64": {
"nullability": "NULLABILITY_NULLABLE"
}
},
"arguments": [{
"value": {
"cast": {
"type": {
"fp64": {
"nullability": "NULLABILITY_NULLABLE"
}
},
"input": {
"selection": {
"directReference": {
"structField": {
"field": 7
}
},
"rootReference": {
}
}
},
"failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION"
}
}
}]
}
}
}, {
"value": {
"literal": {
"i32": 2
}
}
}]
}
}]
}
},
"offset": "0",
"count": "10"
}
},
"names": ["ACOS_TAX"]
}
}],
"version": {
"minorNumber": 77,
"producer": "substrait-spark"
},
"extensionUrns": [{
"extensionUrnAnchor": 1,
"urn": "extension:io.substrait:functions_rounding"
}, {
"extensionUrnAnchor": 2,
"urn": "extension:io.substrait:functions_arithmetic"
}]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
True
Loading
Loading