diff --git a/sqlit/domains/query/app/multi_statement.py b/sqlit/domains/query/app/multi_statement.py index fd13d715..381a97f6 100644 --- a/sqlit/domains/query/app/multi_statement.py +++ b/sqlit/domains/query/app/multi_statement.py @@ -19,7 +19,8 @@ def _iter_sql_chars(sql: str) -> Iterator[tuple[int, str, bool]]: """Iterate through SQL characters, tracking string literal context. - Handles escape sequences (backslash) and SQL-style doubled quotes. + Handles escape sequences (backslash), SQL-style doubled quotes, + and PostgreSQL dollar-quoted strings ($$ or $tag$). Yields: (index, char, outside_string) tuples where outside_string is True @@ -27,9 +28,24 @@ def _iter_sql_chars(sql: str) -> Iterator[tuple[int, str, bool]]: """ in_single_quote = False in_double_quote = False + in_dollar_tag: str | None = None i = 0 while i < len(sql): + # If inside a dollar-quoted string, check for the closing tag + if in_dollar_tag is not None: + if sql[i:].startswith(in_dollar_tag): + # Yield the characters of the closing tag as inside string + for offset in range(len(in_dollar_tag)): + yield (i + offset, sql[i + offset], False) + i += len(in_dollar_tag) + in_dollar_tag = None + continue + else: + yield (i, sql[i], False) + i += 1 + continue + char = sql[i] # Handle escape sequences in strings @@ -51,6 +67,18 @@ def _iter_sql_chars(sql: str) -> Iterator[tuple[int, str, bool]]: i += 2 continue + # Check for PostgreSQL dollar-quoted string start + if char == "$" and not in_single_quote and not in_double_quote: + # Match $[a-zA-Z_][a-zA-Z0-9_]*$ or $$ + match = re.match(r"^\$([a-zA-Z_][a-zA-Z0-9_]*)?\$", sql[i:]) + if match: + delimiter = match.group(0) + in_dollar_tag = delimiter + for offset in range(len(delimiter)): + yield (i + offset, sql[i + offset], False) + i += len(delimiter) + continue + # Toggle quote state and yield if char == "'" and not in_double_quote: in_single_quote = not in_single_quote diff --git a/tests/unit/test_multi_statement.py b/tests/unit/test_multi_statement.py index 21e33077..617abb2c 100644 --- a/tests/unit/test_multi_statement.py +++ b/tests/unit/test_multi_statement.py @@ -76,6 +76,55 @@ def test_handles_multiline_statements(self): assert len(statements) == 2 + def test_preserves_semicolons_in_dollar_quoted_strings(self): + """Should not split on semicolons inside dollar-quoted strings.""" + from sqlit.domains.query.app.multi_statement import split_statements + + query = """ + CREATE OR REPLACE FUNCTION example() + RETURNS void AS $$ + BEGIN + INSERT INTO t (x) VALUES ('a;b'); + END; + $$ LANGUAGE plpgsql; + SELECT 1; + """ + statements = split_statements(query) + + assert len(statements) == 2 + assert "CREATE OR REPLACE FUNCTION" in statements[0] + assert "SELECT 1" in statements[1] + + def test_preserves_semicolons_in_named_dollar_quoted_strings(self): + """Should not split on semicolons inside named dollar-quoted strings.""" + from sqlit.domains.query.app.multi_statement import split_statements + + query = """ + CREATE OR REPLACE FUNCTION example() + RETURNS void AS $func_tag$ + BEGIN + INSERT INTO t (x) VALUES ('a;b'); + END; + $func_tag$ LANGUAGE plpgsql; + SELECT 1; + """ + statements = split_statements(query) + + assert len(statements) == 2 + assert "CREATE OR REPLACE FUNCTION" in statements[0] + assert "SELECT 1" in statements[1] + + def test_dollar_quotes_inside_standard_strings_are_ignored(self): + """Should ignore dollar quote delimiters when inside standard string literals.""" + from sqlit.domains.query.app.multi_statement import split_statements + + query = "INSERT INTO t (x) VALUES ('$$'); SELECT 1" + statements = split_statements(query) + + assert len(statements) == 2 + assert "INSERT" in statements[0] + assert "SELECT 1" in statements[1] + class TestMultiStatementResult: """Tests for MultiStatementResult data structure."""