-
Notifications
You must be signed in to change notification settings - Fork 127
Expand file tree
/
Copy pathsql_cleaner.py
More file actions
255 lines (211 loc) · 8.94 KB
/
sql_cleaner.py
File metadata and controls
255 lines (211 loc) · 8.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
"""Raw SQL preprocessing before AST construction.
Pure string transformations — no sqlglot dependency. Handles comment
stripping, ``REPLACE INTO`` rewriting, qualified CTE name normalisation,
DB2 isolation-level clauses, malformed-query rejection, and redundant
outer-parenthesis removal.
"""
import re
from typing import NamedTuple
from sqlglot.errors import TokenError
from sqlglot.tokens import Tokenizer, TokenType
from sql_metadata.comments import strip_comments_for_parsing as _strip_comments
from sql_metadata.exceptions import InvalidQueryDefinition
from sql_metadata.utils import DOT_PLACEHOLDER
class CleanResult(NamedTuple):
"""Result of :meth:`SqlCleaner.clean`."""
sql: str | None
is_replace: bool
cte_name_map: dict[str, str]
def _is_wrapped(text: str) -> bool:
"""Check whether *text* is wrapped in balanced outer parentheses.
:param text: SQL string to check.
:type text: str
:returns: ``True`` if *text* has balanced outer parentheses.
:rtype: bool
"""
if len(text) < 2 or text[0] != "(" or text[-1] != ")":
return False
depth = 0
for c in text[1:-1]:
if c == "(":
depth += 1
elif c == ")":
depth -= 1
if depth < 0:
return False
return True
def _strip_outer_parens(sql: str) -> str:
    """Remove redundant parentheses that enclose the whole of *sql*.

    sqlglot cannot parse doubly-wrapped non-SELECT statements such as
    ``((UPDATE ...))``, so every enclosing pair is peeled off in turn, with
    surrounding whitespace trimmed before each check. Peeling is capped at
    101 layers so pathological input cannot run away; when the cap is hit,
    the innermost slice is returned without a final trim.

    :param sql: SQL string that may be wrapped in outer parentheses.
    :type sql: str
    :returns: The unwrapped SQL string.
    :rtype: str
    """
    remaining = sql
    for _ in range(101):
        remaining = remaining.strip()
        if not _is_wrapped(remaining):
            return remaining
        remaining = remaining[1:-1]
    # Layer cap reached: hand back the innermost slice as-is.
    return remaining
def _normalize_cte_names(sql: str) -> tuple[str, dict[str, str]]:
    """Swap schema-qualified CTE names for parser-friendly placeholders.

    sqlglot reads ``WITH db.cte_name AS (...)`` as a table reference and
    fails. Every qualified name defined after ``WITH`` or a comma and
    followed by ``AS (`` is rewritten with its dot replaced by
    ``DOT_PLACEHOLDER``. Each later occurrence of that name (for example in
    FROM/JOIN clauses) is rewritten the same way. The returned mapping lets
    callers restore the original names after extraction.

    :param sql: SQL string that may contain qualified CTE names.
    :type sql: str
    :returns: A 2-tuple of ``(modified_sql, {placeholder: original_name})``.
    :rtype: tuple
    """
    placeholders: dict[str, str] = {}
    definition_re = re.compile(
        r"(\bWITH\s+|,\s*)(\w+\.\w+)(\s+AS\s*\()",
        re.IGNORECASE,
    )

    def _swap(m: re.Match[str]) -> str:
        # Record the mapping as a side effect while rewriting the definition.
        lead, dotted, tail = m.groups()
        alias = dotted.replace(".", DOT_PLACEHOLDER)
        placeholders[alias] = dotted
        return lead + alias + tail

    rewritten = definition_re.sub(_swap, sql)

    # Definitions are already rewritten; now rewrite every remaining
    # reference. Word boundaries keep longer identifiers intact.
    for alias, dotted in placeholders.items():
        rewritten = re.sub(rf"\b{re.escape(dotted)}\b", alias, rewritten)

    return rewritten, placeholders
class SqlCleaner:
    """Preprocess raw SQL strings before dialect parsing.

    All methods are ``@staticmethod`` — the class serves as a namespace
    grouping :meth:`clean` (full preprocessing pipeline consumed by
    :class:`ASTParser`) and :meth:`preprocess_query` (quoting/whitespace
    normalisation consumed by :attr:`Parser.query`).
    """

    # Translation table mapping both line-break characters (LF and CR) to a
    # space. Handling CR as well as LF keeps CRLF / CR-only input single-line.
    _LINE_BREAKS_TO_SPACES = str.maketrans("\r\n", "  ")

    @staticmethod
    def preprocess_query(sql: str) -> str:
        """Normalise quoting and whitespace in raw SQL.

        Walks sqlglot's tokenizer output, emitting each token's original
        text verbatim except double-quoted identifiers, which are rewritten
        to backtick-quoted. Because string literals are ``STRING`` tokens,
        any ``"`` characters inside them are preserved automatically — no
        sentinel substitution needed. Line-feed and carriage-return
        characters in the gaps between tokens (which is where whitespace
        and comments live) become spaces, and runs of spaces are collapsed.

        :param sql: Raw SQL string.
        :type sql: str
        :returns: The normalised SQL string, or ``""`` for empty input.
        :rtype: str
        """
        if not sql:
            return ""
        line_breaks = SqlCleaner._LINE_BREAKS_TO_SPACES
        try:
            # Use the default tokenizer unconditionally — the MySQL
            # tokenizer reclassifies ``"X"`` as a STRING token (because
            # MySQL with ANSI_QUOTES off treats double-quotes as strings),
            # which would skip the identifier rewrite below.
            tokens = list(Tokenizer().tokenize(sql))
        except TokenError:
            # Malformed SQL — fall back to plain whitespace collapse.
            return re.sub(r" {2,}", " ", sql.translate(line_breaks)).strip()
        parts: list[str] = []
        prev_end = 0
        for tok in tokens:
            # Gap before the token holds whitespace and comments —
            # flatten line breaks so the output stays single-line.
            parts.append(sql[prev_end:tok.start].translate(line_breaks))
            # Token end offsets are inclusive, hence the +1.
            raw = sql[tok.start:tok.end + 1]
            if tok.token_type == TokenType.IDENTIFIER and raw.startswith('"'):
                # e.g. "col" → `col`. String literals are STRING tokens
                # and never enter this branch, so embedded " are safe.
                parts.append(f"`{tok.text}`")
            else:
                parts.append(raw)
            prev_end = tok.end + 1
        parts.append(sql[prev_end:].translate(line_breaks))
        return re.sub(r" {2,}", " ", "".join(parts))

    @staticmethod
    def clean(sql: str) -> CleanResult:
        """Apply all preprocessing steps to raw SQL.

        Steps (in order):

        1. Rewrite ``REPLACE INTO`` → ``INSERT INTO``.
        2. Rewrite ``SELECT...INTO v1, v2 FROM`` → ``SELECT...FROM`` (plain
           or ``@``-prefixed variable names).
        3. Strip comments.
        4. Normalise qualified CTE names.
        5. Strip DB2 isolation-level clauses.
        6. Detect malformed ``WITH...AS(...) AS`` patterns.
        7. Strip redundant outer parentheses.

        :param sql: Raw SQL string.
        :type sql: str
        :returns: Cleaning result with preprocessed SQL (``None`` if
            effectively empty), replace flag, and CTE name map.
        :rtype: CleanResult
        :raises InvalidQueryDefinition: If a malformed WITH pattern is
            detected.
        """
        is_replace = False
        if re.match(r"\s*REPLACE\b", sql, re.IGNORECASE):
            sql = re.sub(
                r"\bREPLACE\s+INTO\b",
                "INSERT INTO",
                sql,
                count=1,
                flags=re.IGNORECASE,
            )
            is_replace = True

        # Rewrite SELECT...INTO var1,var2 FROM → SELECT...FROM
        # so sqlglot doesn't treat variables as tables.
        # Only match when INTO target has a comma (variable assignment),
        # not MSSQL's SELECT...INTO new_table FROM (table creation).
        # The optional "@" covers MySQL user variables (INTO @x, @y).
        sql = re.sub(
            r"(?i)(\bSELECT\b.+?)\bINTO\b\s+@?\w+\s*,.*?\bFROM\b",
            r"\1FROM",
            sql,
            count=1,
            flags=re.DOTALL,
        )

        clean_sql = _strip_comments(sql)
        if not clean_sql.strip():
            return CleanResult(sql=None, is_replace=is_replace, cte_name_map={})

        clean_sql, cte_name_map = _normalize_cte_names(clean_sql)

        # Trailing DB2 isolation clause (WITH UR / CS / RS / RR).
        clean_sql = re.sub(
            r"\bwith\s+(ur|cs|rs|rr)\s*$", "", clean_sql, flags=re.IGNORECASE
        ).strip()

        SqlCleaner._detect_malformed_with(clean_sql)

        clean_sql = _strip_outer_parens(clean_sql)
        if not clean_sql.strip():
            return CleanResult(
                sql=None, is_replace=is_replace, cte_name_map=cte_name_map
            )
        return CleanResult(
            sql=clean_sql, is_replace=is_replace, cte_name_map=cte_name_map
        )

    @staticmethod
    def _detect_malformed_with(clean_sql: str) -> None:
        """Reject a malformed WITH statement.

        Detects ``WITH...AS(...) AS <keyword>`` or
        ``WITH...AS(...) AS <word> <keyword>`` — an extra ``AS`` token
        after the CTE body that indicates malformed SQL. Statements that do
        not start with ``WITH`` are ignored.

        :param clean_sql: Preprocessed SQL string.
        :type clean_sql: str
        :raises InvalidQueryDefinition: If a malformed WITH pattern is found.
        """
        if not re.match(r"\s*WITH\b", clean_sql, re.IGNORECASE):
            return
        main_kw = r"(?:SELECT|INSERT|UPDATE|DELETE)"
        # The optional "<word> " group covers both ") AS <kw>" and
        # ") AS <word> <kw>" in a single scan.
        if re.search(
            r"\)\s+AS\s+(?:\w+\s+)?" + main_kw + r"\b", clean_sql, re.IGNORECASE
        ):
            raise InvalidQueryDefinition(
                "Malformed WITH clause — extra AS keyword after CTE body"
            )