Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,55 @@ class MiscOperatorSuite extends VeloxWholeStageTransformerSuite with AdaptiveSpa
}
}

// Runs a two-level UNION ALL with the native union enabled and compares the
// result with vanilla Spark.
//
// Query shape:
//   1. Inner union: union_src_a and union_src_b are unioned, each branch
//      adding a constant col5 (98 and 100 respectively).
//   2. Dedup: row_number() partitioned by col2 and ordered by col5 desc keeps
//      one row per col2 value (rn = 1). All col2 values in the fixtures are
//      distinct, so all four rows survive.
//   3. Outer union: the deduplicated CTE is unioned with itself under two
//      complementary filters on col1. No fixture row has col1 = 'valueC', so
//      the second branch contributes no rows.
// NOTE(review): judging by the test name, this guards against the unioned
// output losing its distinct col1..col4 columns - confirm against the fix.
test("native union_all with two level union keeps distinct output columns") {
  // Only the views that are actually created below are registered for cleanup.
  withTempView("union_src_a", "union_src_b") {
    Seq(
      ("valueA", "value1", "value11", "value111"),
      ("valueA", "value2", "value22", "value222")
    ).toDF("col1", "col2", "col3", "col4")
      .createOrReplaceTempView("union_src_a")
    Seq(
      ("valueB", "value3", "value33", "value333"),
      ("valueB", "value4", "value44", "value444")
    ).toDF("col1", "col2", "col3", "col4")
      .createOrReplaceTempView("union_src_b")

    withSQLConf(GlutenConfig.NATIVE_UNION_ENABLED.key -> "true") {
      // NOTE(review): presumably compares row contents with vanilla Spark and
      // checks that the plan contains a UnionExecTransformer - confirm against
      // the helper definitions.
      compareDfResultsAgainstVanillaSpark(
        () =>
          spark.sql("""
                      |with deduplicated_data as (
                      | select col1, col2, col3, col4
                      | from (
                      | select
                      | u.col1,
                      | u.col2,
                      | u.col3,
                      | u.col4,
                      | row_number() over (partition by u.col2 order by u.col5 desc) as rn
                      | from (
                      | select col1, col2, col3, col4, 98 as col5 from union_src_a
                      | union all
                      | select col1, col2, col3, col4, 100 as col5 from union_src_b
                      | ) u
                      | ) t
                      | where t.rn = 1
                      |)
                      |select col1, col2, col3, col4
                      |from deduplicated_data
                      |where col1 != 'valueC'
                      |union all
                      |select col1, col2, col3, col4
                      |from deduplicated_data
                      |where col1 = 'valueC'
                      |""".stripMargin),
        compareResult = true,
        checkGlutenPlan[UnionExecTransformer]
      )
    }
  }
}

test("union two tables") {
runQueryAndCompare("""
|select count(orderkey) from (
Expand Down
3 changes: 2 additions & 1 deletion cpp/velox/substrait/SubstraitToVeloxPlan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1244,7 +1244,8 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
const RowTypePtr outRowType = asRowType(children[0]->outputType());
std::vector<std::string> outNames;
for (int32_t colIdx = 0; colIdx < outRowType->size(); ++colIdx) {
const auto name = outRowType->childAt(colIdx)->name();
// Use the field names from the unified output row type instead of the child type names.
const auto name = outRowType->nameOf(colIdx);
outNames.push_back(name);
}

Expand Down
Loading