Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions sql_metadata/column_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,12 @@ def add_alias(self, name: str, target: Any, clause: str) -> None:
self.alias_names.append(name)
if clause:
self.alias_dict.setdefault(clause, UniqueList()).append(name)
if target is not None:
self.alias_map[name] = target
if target is None:
return
existing = self.alias_map.get(name, [])
merged = UniqueList(existing if isinstance(existing, list) else [existing])
merged.extend(target if isinstance(target, list) else [target])
self.alias_map[name] = merged if len(merged) > 1 else merged[0]


# ---------------------------------------------------------------------------
Expand Down
35 changes: 35 additions & 0 deletions test/test_unions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,41 @@
from sql_metadata import Parser


def test_union_column_aliases():
# https://github.com/macbre/sql-metadata/issues/401
# When UNION combines queries with the same alias,
# columns_aliases should aggregate all source columns
query = """
select a.A as M
from tab1 a
union all
select b.B as M
from tab2 b
"""
parser = Parser(query)
assert parser.columns_aliases == {"M": ["tab1.A", "tab2.B"]}
assert parser.columns == ["tab1.A", "tab2.B"]
assert parser.tables == ["tab1", "tab2"]


def test_union_alias_with_expression_targets():
# Regression: scalar then list-target must not nest
q1 = """
SELECT a AS x FROM t1
UNION ALL
SELECT b + c AS x FROM t2
"""
assert Parser(q1).columns_aliases == {"x": ["a", "b", "c"]}

# Regression: list then list-target must not raise TypeError on UniqueList
q2 = """
SELECT a + b AS x FROM t1
UNION ALL
SELECT c + d AS x FROM t2
"""
assert Parser(q2).columns_aliases == {"x": ["a", "b", "c", "d"]}


def test_union():
query = """
SELECT
Expand Down