1111from sqlglot import Dialect , Generator , ParseError , Parser , Tokenizer , TokenType , exp
1212from sqlglot .dialects .dialect import DialectType
1313from sqlglot .dialects .snowflake import Snowflake
14+ from sqlglot .helper import seq_get
1415from sqlglot .optimizer .normalize_identifiers import normalize_identifiers
1516from sqlglot .optimizer .scope import traverse_scope
1617from sqlglot .tokens import Token
@@ -607,8 +608,10 @@ class ChunkType(Enum):
607608 SQL = auto ()
608609
609610
610- def parse_one (sql : str , dialect : t .Optional [str ] = None ) -> exp .Expression :
611- expressions = parse (sql , default_dialect = dialect , match_dialect = False )
611+ def parse_one (
612+ sql : str , dialect : t .Optional [str ] = None , into : t .Optional [exp .IntoType ] = None
613+ ) -> exp .Expression :
614+ expressions = parse (sql , default_dialect = dialect , match_dialect = False , into = into )
612615 if not expressions :
613616 raise SQLMeshError (f"No expressions found in '{ sql } '" )
614617 elif len (expressions ) > 1 :
@@ -617,7 +620,10 @@ def parse_one(sql: str, dialect: t.Optional[str] = None) -> exp.Expression:
617620
618621
619622def parse (
620- sql : str , default_dialect : t .Optional [str ] = None , match_dialect : bool = True
623+ sql : str ,
624+ default_dialect : t .Optional [str ] = None ,
625+ match_dialect : bool = True ,
626+ into : t .Optional [exp .IntoType ] = None ,
621627) -> t .List [exp .Expression ]:
622628 """Parse a sql string.
623629
@@ -668,7 +674,10 @@ def parse(
668674
669675 for chunk , chunk_type in chunks :
670676 if chunk_type == ChunkType .SQL :
671- for expression in parser .parse (chunk , sql ):
677+ parsed_expressions : t .List [t .Optional [exp .Expression ]] = (
678+ parser .parse (chunk , sql ) if into is None else parser .parse_into (into , chunk , sql )
679+ )
680+ for expression in parsed_expressions :
672681 if expression :
673682 expression .meta ["sql" ] = parser ._find_sql (chunk [0 ], chunk [- 1 ])
674683 expressions .append (expression )
@@ -706,6 +715,10 @@ def extend_sqlglot() -> None:
706715 parser .QUERY_MODIFIER_PARSERS .update (
707716 {TokenType .PARAMETER : lambda self : _parse_body_macro (self )}
708717 )
718+ # FIXME: Delete the extension below after upgrading to SQLGlot >= 20.3.0.
719+ parser .EXPRESSION_PARSERS .update (
720+ {exp .When : lambda self : seq_get (self ._parse_when_matched (), 0 )}
721+ )
709722
710723 for generator in generators :
711724 if MacroFunc not in generator .TRANSFORMS :
0 commit comments