Skip to content

Commit 7b88251

Browse files
authored
Fix(athena): Properly extend Athena dialect (#5077)
1 parent 3bbb819 commit 7b88251

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

sqlmesh/core/dialect.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from sqlglot import Dialect, Generator, ParseError, Parser, Tokenizer, TokenType, exp
1414
from sqlglot.dialects.dialect import DialectType
1515
from sqlglot.dialects import DuckDB, Snowflake
16+
import sqlglot.dialects.athena as athena
1617
from sqlglot.helper import seq_get
1718
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
1819
from sqlglot.optimizer.qualify_columns import quote_identifiers
@@ -1014,6 +1015,14 @@ def extend_sqlglot() -> None:
10141015
generators = {Generator}
10151016

10161017
for dialect in Dialect.classes.values():
1018+
# Athena picks a different Tokenizer / Parser / Generator depending on the query,
1019+
# so also add the extra classes it defines to make sure they are extended as well
1020+
if dialect == athena.Athena:
1021+
tokenizers.add(athena._TrinoTokenizer)
1022+
parsers.add(athena._TrinoParser)
1023+
generators.add(athena._TrinoGenerator)
1024+
generators.add(athena._HiveGenerator)
1025+
10171026
if hasattr(dialect, "Tokenizer"):
10181027
tokenizers.add(dialect.Tokenizer)
10191028
if hasattr(dialect, "Parser"):

tests/core/test_dialect.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
select_from_values_for_batch_range,
1313
text_diff,
1414
)
15+
import sqlmesh.core.dialect as d
1516
from sqlmesh.core.model import SqlModel, load_sql_based_model
17+
from sqlmesh.core.config.connection import DIALECT_TO_TYPE
1618

1719

1820
def test_format_model_expressions():
@@ -700,3 +702,18 @@ def test_model_name_cannot_be_string():
700702

701703
def test_parse_snowflake_create_schema_ddl():
702704
# Parse a qualified CREATE SCHEMA statement with the Snowflake dialect and check that
# the SQL generated without a dialect argument is identical to the input text.
assert parse_one("CREATE SCHEMA d.s", dialect="snowflake").sql() == "CREATE SCHEMA d.s"
705+
706+
707+
# Parametrized over every distinct value of DIALECT_TO_TYPE, in sorted order.
@pytest.mark.parametrize("dialect", sorted(set(DIALECT_TO_TYPE.values())))
708+
def test_sqlglot_extended_correctly(dialect: str) -> None:
709+
# MODEL is a SQLMesh extension and not part of SQLGlot.
710+
# If we can round-trip an expression containing MODEL across every dialect, then the SQLMesh extensions have been registered correctly.
711+
# Parsing must produce the SQLMesh Model node for the given dialect.
ast = d.parse_one("MODEL (name foo)", dialect=dialect)
712+
assert isinstance(ast, d.Model)
713+
# The Property found inside the parsed MODEL block must be the `name` entry...
name_prop = ast.find(exp.Property)
714+
assert isinstance(name_prop, exp.Property)
715+
assert name_prop.this == "name"
716+
# ...and its value must be a Table expression that renders as `foo`.
value = name_prop.args["value"]
717+
assert isinstance(value, exp.Table)
718+
assert value.sql() == "foo"
719+
# Generating SQL with the same dialect must emit the MODEL block with the property on its own line.
assert ast.sql(dialect=dialect) == "MODEL (\nname foo\n)"

0 commit comments

Comments
 (0)