diff --git a/packages/bigframes/bigframes/core/pyformat.py b/packages/bigframes/bigframes/core/pyformat.py index 8f3c94054094..dfd91ba1ad00 100644 --- a/packages/bigframes/bigframes/core/pyformat.py +++ b/packages/bigframes/bigframes/core/pyformat.py @@ -162,6 +162,160 @@ def _parse_fields(sql_template: str) -> list[str]: ] +def _is_escaped_open_brace(sql_template: str, idx: int, literal_char: str) -> bool: + """Checks if the character at idx in sql_template is an escaped open brace '{{'.""" + return sql_template[idx : idx + 2] == "{{" and literal_char == "{" + + +def _is_escaped_close_brace(sql_template: str, idx: int, literal_char: str) -> bool: + """Checks if the character at idx in sql_template is an escaped close brace '}}'.""" + return sql_template[idx : idx + 2] == "}}" and literal_char == "}" + + +def _consume_literal(sql_template: str, current_idx: int, literal_text: str) -> int: + """Advances current_idx past literal_text in sql_template, accounting for escaped braces. + + A **literal** (or literal text) is the static part of the template string that + does not contain formatting placeholders. The string.Formatter parser resolves + escaped braces ('{{' and '}}') into single braces ('{' and '}') in its output + literal_text. + + This function aligns the resolved literal_text back to the original + sql_template by consuming 2 characters from sql_template ('{{' or '}}') for + every single escaped brace character in literal_text, and 1 character for + everything else. + + Returns: + int: the advanced current_idx in sql_template. + """ + lit_idx = 0 + while lit_idx < len(literal_text): + if _is_escaped_open_brace(sql_template, current_idx, literal_text[lit_idx]): + current_idx += 2 + lit_idx += 1 + elif _is_escaped_close_brace(sql_template, current_idx, literal_text[lit_idx]): + current_idx += 2 + lit_idx += 1 + elif ( + current_idx < len(sql_template) + and sql_template[current_idx] == literal_text[lit_idx] + ): + current_idx += 1 + lit_idx += 1 + else: + raise RuntimeError( + "Internal error: failed to align parsed SQL template with original query. " + f"Expected {literal_text[lit_idx]!r} at position {current_idx} in template, " + f"but found {sql_template[current_idx : current_idx + 2]!r}." + ) + return current_idx + + +def _is_escaped_brace(sql_template: str, idx: int) -> bool: + """Checks if the template has an escaped brace ('{{' or '}}') at the given index.""" + return sql_template[idx : idx + 2] in ("{{", "}}") + + +def _advance_past_field(sql_template: str, current_idx: int) -> int: + """Advances current_idx past the format field starting at current_idx. + + A **field** (or replacement field) is a placeholder in the template enclosed + in braces (e.g., "{my_var}" or "{json_col: { "val": 1 } }"). + + This function assumes current_idx points to the opening '{' of a field. + It parses forward, tracking nested braces to find the matching closing '}' + that terminates the field, while ignoring escaped braces ('{{' and '}}') + which do not affect the nesting level. + + Returns: + int: the index immediately after the closing '}' of the field. + """ + assert sql_template[current_idx] == "{" + brace_count = 1 + current_idx += 1 # past '{' + + while brace_count > 0 and current_idx < len(sql_template): + if _is_escaped_brace(sql_template, current_idx): + current_idx += 2 + elif sql_template[current_idx] == "{": + brace_count += 1 + current_idx += 1 + elif sql_template[current_idx] == "}": + brace_count -= 1 + current_idx += 1 + else: + current_idx += 1 + + return current_idx + + +def _find_all_field_positions(sql_template: str) -> dict[tuple[str, int], int]: + """Finds the character positions of all fields in the sql_template. + + Returns: + dict: a dict mapping (field_name, occurrence_idx) to character index. + """ + formatter = string.Formatter() + current_idx = 0 + seen_counts: dict[str, int] = {} + positions: dict[tuple[str, int], int] = {} + + for literal_text, field_name, _, _ in formatter.parse(sql_template): + current_idx = _consume_literal(sql_template, current_idx, literal_text) + + if field_name is not None: + occurrence_idx = seen_counts.get(field_name, 0) + seen_counts[field_name] = occurrence_idx + 1 + + positions[(field_name, occurrence_idx)] = current_idx + + current_idx = _advance_past_field(sql_template, current_idx) + + return positions + + +def get_error_context_at_pos(sql_template: str, pos: int) -> str: + """Create a helpful 'pointer' to where the problematic position is + in the original SQL. + + This should make the error message a lot friendlier, by providing more + context towards the problematic syntax. + """ + if pos == -1: + return "" + + lines = sql_template.splitlines(keepends=True) + + char_count = 0 + target_line_idx = -1 + for i, line in enumerate(lines): + if char_count <= pos < char_count + len(line): + target_line_idx = i + break + char_count += len(line) + + if target_line_idx == -1: + return "" + + col_offset = pos - char_count + + context_lines = [] + start_line = max(0, target_line_idx - 2) + end_line = min(len(lines), target_line_idx + 3) + + for i in range(start_line, end_line): + line_num = i + 1 + line_content = lines[i].rstrip("\r\n") + if i == target_line_idx: + context_lines.append(f"{line_num:4d}: {line_content}") + indent = 6 + col_offset + context_lines.append(" " * indent + "^") + else: + context_lines.append(f"{line_num:4d}: {line_content}") + + return "\n".join(context_lines) + + def pyformat( sql_template: str, *, @@ -185,13 +339,36 @@ def pyformat( Raises: TypeError: if a referenced variable is not of a supported type. - KeyError: if a referenced variable is not found. + ValueError: + if a referenced variable is not found (KeyError is caught and raised + as ValueError with context). """ - fields = _parse_fields(sql_template) - - format_kwargs = {} + try: + fields = _parse_fields(sql_template) + except ValueError as e: + raise ValueError( + "Failed to parse SQL template. " + "Did you mean to escape '{' and '}' by doubling them?\n" + f"Error details: {e}" + ) from e + + format_kwargs: dict[str, str] = {} + seen_counts: dict[str, int] = {} for name in fields: - value = pyformat_args[name] + seen_counts[name] = seen_counts.get(name, 0) + 1 + try: + value = pyformat_args[name] + except KeyError as e: + positions = _find_all_field_positions(sql_template) + occurrence_idx = seen_counts[name] - 1 + pos = positions.get((name, occurrence_idx), -1) + context = get_error_context_at_pos(sql_template, pos) + raise ValueError( + f"Undetected variable {name!r} in SQL template. " + "Did you mean to escape '{' and '}' by doubling them?\n" + f"{context}" + ) from e + format_kwargs[name] = _field_to_template_value( name, value, session=session, dry_run=dry_run ) diff --git a/packages/bigframes/tests/unit/core/test_pyformat.py b/packages/bigframes/tests/unit/core/test_pyformat.py index be7f52f4d5d4..416d1b351a31 100644 --- a/packages/bigframes/tests/unit/core/test_pyformat.py +++ b/packages/bigframes/tests/unit/core/test_pyformat.py @@ -62,6 +62,72 @@ def test_parse_fields(sql_template: str, expected: List[str]): assert fields == expected +def test_get_error_context_at_pos_invalid_pos(): + assert pyformat.get_error_context_at_pos("SELECT 1", -1) == "" + assert pyformat.get_error_context_at_pos("SELECT 1", 100) == "" + + +def test_get_error_context_at_pos_single_line(): + sql = "SELECT {foo}" + # pos of '{' is 7 + context = pyformat.get_error_context_at_pos(sql, 7) + expected = " 1: SELECT {foo}\n ^" + assert context == expected + + +def test_get_error_context_at_pos_multi_line(): + sql = "SELECT 1\nFROM my_table\nWHERE col = {foo}\nAND active = True\nLIMIT 10" + # Lines: + # 1: SELECT 1 (len 9 including \n) + # 2: FROM my_table (len 14 including \n) -> total 23 + # 3: WHERE col = {foo} -> '{' is at 23 + 12 = 35 + + context = pyformat.get_error_context_at_pos(sql, 35) + expected = ( + " 1: SELECT 1\n" + " 2: FROM my_table\n" + " 3: WHERE col = {foo}\n" + " ^\n" + " 4: AND active = True\n" + " 5: LIMIT 10" + ) + assert context == expected + + +def test_get_error_context_at_pos_multi_line_limits(): + # Test that it only shows at most 2 lines before and 2 lines after + sql = ( + "LINE 1\n" + "LINE 2\n" + "LINE 3\n" + "LINE 4\n" + "LINE 5\n" + "TARGET {foo}\n" + "LINE 7\n" + "LINE 8\n" + "LINE 9\n" + "LINE 10" + ) + # Line lengths: + # LINE 1\n (7) + # LINE 2\n (7) -> 14 + # LINE 3\n (7) -> 21 + # LINE 4\n (7) -> 28 + # LINE 5\n (7) -> 35 + # TARGET {foo}\n -> '{' is at 35 + 7 = 42 + + context = pyformat.get_error_context_at_pos(sql, 42) + expected = ( + " 4: LINE 4\n" + " 5: LINE 5\n" + " 6: TARGET {foo}\n" + " ^\n" + " 7: LINE 7\n" + " 8: LINE 8" + ) + assert context == expected + + def test_pyformat_with_unsupported_type_raises_typeerror(session): pyformat_args = {"my_object": object()} sql = "SELECT {my_object}" @@ -70,13 +136,67 @@ def test_pyformat_with_unsupported_type_raises_typeerror(session): pyformat.pyformat(sql, pyformat_args=pyformat_args, session=session) -def test_pyformat_with_missing_variable_raises_keyerror(session): +def test_pyformat_with_missing_variable_raises_valueerror(session): pyformat_args: Dict[str, Any] = {} sql = "SELECT {my_object}" - with pytest.raises(KeyError, match="my_object"): + with pytest.raises(ValueError) as exc_info: + pyformat.pyformat(sql, pyformat_args=pyformat_args, session=session) + + err_msg = str(exc_info.value) + assert "Undetected variable 'my_object' in SQL template" in err_msg + assert "Did you mean to escape '{' and '}'" in err_msg + assert " 1: SELECT {my_object}" in err_msg + assert " ^" in err_msg + + +def test_pyformat_with_unescaped_braces_raises_valueerror_with_context(session): + pyformat_args = {"active": True} + sql = """SELECT * FROM my_table +WHERE json_col = { "generation_config": { "temperature": 0.9 } } +AND active = {active} +""" + + with pytest.raises(ValueError) as exc_info: pyformat.pyformat(sql, pyformat_args=pyformat_args, session=session) + err_msg = str(exc_info.value) + assert "Undetected variable ' \"generation_config\"' in SQL template" in err_msg + assert "Did you mean to escape '{' and '}'" in err_msg + # The triple quote string starts with SELECT immediately, so lines are: + # 1: SELECT * FROM my_table + # 2: WHERE json_col = { "generation_config": { "temperature": 0.9 } } + # 3: AND active = {active} + assert " 1: SELECT * FROM my_table" in err_msg + assert ( + ' 2: WHERE json_col = { "generation_config": { "temperature": 0.9 } }' + in err_msg + ) + assert " ^" in err_msg + assert " 3: AND active = {active}" in err_msg + + +def test_pyformat_with_malformed_template_raises_valueerror(session): + pyformat_args: Dict[str, Any] = {} + + # Case 1: Single '{' (unmatched) + sql_1 = "SELECT {foo" + with pytest.raises(ValueError) as exc_info: + pyformat.pyformat(sql_1, pyformat_args=pyformat_args, session=session) + err_msg_1 = str(exc_info.value) + assert "Failed to parse SQL template" in err_msg_1 + assert "Did you mean to escape '{' and '}'" in err_msg_1 + assert "expected '}' before end of string" in err_msg_1 + + # Case 2: Single '}' (unmatched) + sql_2 = "SELECT foo}" + with pytest.raises(ValueError) as exc_info: + pyformat.pyformat(sql_2, pyformat_args=pyformat_args, session=session) + err_msg_2 = str(exc_info.value) + assert "Failed to parse SQL template" in err_msg_2 + assert "Did you mean to escape '{' and '}'" in err_msg_2 + assert "Single '}' encountered in format string" in err_msg_2 + def test_pyformat_with_no_variables(session): pyformat_args: Dict[str, Any] = {}