Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions .circleci/continue_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -297,16 +297,16 @@ workflows:
matrix:
parameters:
engine:
- snowflake
- databricks
- redshift
- bigquery
- clickhouse-cloud
#- snowflake
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: revert this CI change (re-enable the disabled snowflake/databricks/redshift/bigquery/clickhouse-cloud engines and restore the branch filters) prior to merge

#- databricks
#- redshift
#- bigquery
#- clickhouse-cloud
- athena
filters:
branches:
only:
- main
#filters:
# branches:
# only:
# - main
- ui_style
- ui_test
- vscode_test
Expand Down
9 changes: 9 additions & 0 deletions sqlmesh/core/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from sqlglot import Dialect, Generator, ParseError, Parser, Tokenizer, TokenType, exp
from sqlglot.dialects.dialect import DialectType
from sqlglot.dialects import DuckDB, Snowflake
import sqlglot.dialects.athena as athena
from sqlglot.helper import seq_get
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
from sqlglot.optimizer.qualify_columns import quote_identifiers
Expand Down Expand Up @@ -1014,6 +1015,14 @@ def extend_sqlglot() -> None:
generators = {Generator}

for dialect in Dialect.classes.values():
# Athena picks a different Tokenizer / Parser / Generator depending on the query
# so this ensures that the extra ones it defines are also extended
if dialect == athena.Athena:
tokenizers.add(athena._TrinoTokenizer)
parsers.add(athena._TrinoParser)
generators.add(athena._TrinoGenerator)
generators.add(athena._HiveGenerator)

if hasattr(dialect, "Tokenizer"):
tokenizers.add(dialect.Tokenizer)
if hasattr(dialect, "Parser"):
Expand Down
17 changes: 17 additions & 0 deletions tests/core/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
select_from_values_for_batch_range,
text_diff,
)
import sqlmesh.core.dialect as d
from sqlmesh.core.model import SqlModel, load_sql_based_model
from sqlmesh.core.config.connection import DIALECT_TO_TYPE


def test_format_model_expressions():
Expand Down Expand Up @@ -700,3 +702,18 @@ def test_model_name_cannot_be_string():

def test_parse_snowflake_create_schema_ddl():
    # A qualified CREATE SCHEMA statement must survive a parse/generate roundtrip.
    parsed = parse_one("CREATE SCHEMA d.s", dialect="snowflake")
    assert parsed.sql() == "CREATE SCHEMA d.s"


@pytest.mark.parametrize("dialect", sorted(set(DIALECT_TO_TYPE.values())))
def test_sqlglot_extended_correctly(dialect: str) -> None:
    # MODEL is a SQLMesh extension, not part of stock SQLGlot. If an expression
    # containing MODEL roundtrips through every supported dialect, the SQLMesh
    # extensions were registered correctly for each one.
    parsed = d.parse_one("MODEL (name foo)", dialect=dialect)
    assert isinstance(parsed, d.Model)

    prop = parsed.find(exp.Property)
    assert isinstance(prop, exp.Property)
    assert prop.this == "name"

    prop_value = prop.args["value"]
    assert isinstance(prop_value, exp.Table)
    assert prop_value.sql() == "foo"

    assert parsed.sql(dialect=dialect) == "MODEL (\nname foo\n)"