From db6e692c582049451f9093be09010ad695f9f967 Mon Sep 17 00:00:00 2001 From: Tian Gao Date: Wed, 25 Mar 2026 13:06:15 -0700 Subject: [PATCH 1/3] Add errorClass and messageParameters check for custom errors --- dev/check_pyspark_custom_errors.py | 256 +++++++++++++++++++---------- 1 file changed, 172 insertions(+), 84 deletions(-) diff --git a/dev/check_pyspark_custom_errors.py b/dev/check_pyspark_custom_errors.py index db152c77d1b86..b181f19b5157a 100644 --- a/dev/check_pyspark_custom_errors.py +++ b/dev/check_pyspark_custom_errors.py @@ -17,12 +17,178 @@ # limitations under the License. # +import ast +import dataclasses +import functools +import json import sys +import textwrap sys.path.insert(0, "python") import os +@dataclasses.dataclass +class CustomErrorFailure: + file_path: str + line_number: int + source_code: str + error_message: str + + def __str__(self): + return f"{self.file_path}:({self.line_number}):\n{textwrap.dedent(self.source_code)}{self.error_message}" + + +@dataclasses.dataclass +class CustomErrorAnalyzer: + def __init__(self, file_path): + self.file_path = file_path + self.file_content = None + self.pyspark_error_list = self._load_error_list() + self.error_conditions = self._load_error_conditions() + + @staticmethod + @functools.cache + def _load_error_conditions(): + with open("python/pyspark/errors/error-conditions.json", "r") as file: + return json.load(file) + + @staticmethod + @functools.cache + def _load_error_list(): + # PySpark-specific errors + pyspark_error_list = [ + "AnalysisException", + "ArithmeticException", + "ArrayIndexOutOfBoundsException", + "DateTimeException", + "IllegalArgumentException", + "NumberFormatException", + "ParseException", + "PySparkAssertionError", + "PySparkAttributeError", + "PySparkException", + "PySparkImportError", + "PySparkIndexError", + "PySparkKeyError", + "PySparkNotImplementedError", + "PySparkPicklingError", + "PySparkRuntimeError", + "PySparkTypeError", + "PySparkValueError", + "PythonException", + 
"QueryExecutionException", + "SessionNotSameException", + "SparkNoSuchElementException", + "SparkRuntimeException", + "SparkUpgradeException", + "StreamingQueryException", + "TempTableAlreadyExistsException", + "UnknownException", + "UnsupportedOperationException", + ] + connect_error_list = [ + "AnalysisException", + "ArithmeticException", + "ArrayIndexOutOfBoundsException", + "BaseAnalysisException", + "BaseArithmeticException", + "BaseArrayIndexOutOfBoundsException", + "BaseDateTimeException", + "BaseIllegalArgumentException", + "BaseNoSuchElementException", + "BaseNumberFormatException", + "BaseParseException", + "BasePythonException", + "BaseQueryExecutionException", + "BaseSparkRuntimeException", + "BaseSparkUpgradeException", + "BaseStreamingQueryException", + "BaseUnsupportedOperationException", + "DateTimeException", + "IllegalArgumentException", + "NumberFormatException", + "ParseException", + "PySparkException", + "PythonException", + "QueryExecutionException", + "SparkConnectException", + "SparkConnectGrpcException", + "SparkException", + "SparkNoSuchElementException", + "SparkRuntimeException", + "SparkUpgradeException", + "StreamingQueryException", + "UnsupportedOperationException", + ] + internal_error_list = ["RetryException", "StopIteration"] + pyspark_error_list += connect_error_list + pyspark_error_list += internal_error_list + return pyspark_error_list + + def analyze(self): + with open(self.file_path, "r") as file: + self.file_lines = file.readlines() + tree = ast.parse("".join(self.file_lines)) + + failures = [] + + for node in ast.walk(tree): + if isinstance(node, ast.Raise): + if isinstance(node.exc, ast.Call): + if failure := self.analyze_call(node.exc): + failures.append(failure) + return failures + + def analyze_call(self, call_node): + if not isinstance(call_node.func, ast.Name): + return None + + exc_type = call_node.func.id + if exc_type[0].isupper() and exc_type not in self.pyspark_error_list: + return CustomErrorFailure( + 
file_path=self.file_path, + line_number=call_node.lineno, + source_code=self.get_source(call_node), + error_message=f"custom error '{exc_type}' is not defined in pyspark.errors", + ) + + keywords = list(call_node.keywords) + # We need to parse errorClass first + keywords.sort(key=lambda x: x.arg) + for keyword in keywords: + if keyword.arg == "errorClass": + if isinstance(keyword.value, ast.Constant): + error_class = keyword.value.value + if error_class not in self.error_conditions: + return CustomErrorFailure( + file_path=self.file_path, + line_number=call_node.lineno, + source_code=self.get_source(call_node), + error_message=f"errorClass '{error_class}' is not defined in error-conditions.json", + ) + else: + error_class = None + continue + elif keyword.arg == "messageParameters" and error_class is not None: + if isinstance(keyword.value, ast.Dict): + for key_node in keyword.value.keys: + if isinstance(key_node, ast.Constant): + key = key_node.value + if f"<{key}>" not in self.error_conditions[error_class]["message"][0]: + return CustomErrorFailure( + file_path=self.file_path, + line_number=call_node.lineno, + source_code=self.get_source(call_node), + error_message=f"messageParameter '{key}' is not defined in {error_class}", + ) + return None + + def get_source(self, node): + # lineno is 1-indexed + return "".join(self.file_lines[node.lineno - 1 : node.end_lineno]) + + def find_py_files(path, exclude_paths): """ Find all .py files in a directory, excluding files in specified subdirectories. @@ -49,7 +215,7 @@ def find_py_files(path, exclude_paths): return py_files -def check_errors_in_file(file_path, pyspark_error_list): +def check_errors_in_file(file_path): """ Check if a file uses PySpark-specific errors correctly. @@ -65,16 +231,8 @@ def check_errors_in_file(file_path, pyspark_error_list): list of str A list of strings describing the errors found in the file, with line numbers. 
""" - errors_found = [] - with open(file_path, "r") as file: - for line_num, line in enumerate(file, start=1): - if line.strip().startswith("raise"): - parts = line.split() - # Check for 'raise' statement and ensure the error raised is a capitalized word. - if len(parts) > 1 and parts[1][0].isupper(): - if not any(pyspark_error in line for pyspark_error in pyspark_error_list): - errors_found.append(f"{file_path}:{line_num}: {line.strip()}") - return errors_found + analyzer = CustomErrorAnalyzer(file_path) + return analyzer.analyze() def check_pyspark_custom_errors(target_paths, exclude_paths): @@ -96,81 +254,12 @@ def check_pyspark_custom_errors(target_paths, exclude_paths): all_errors = [] for path in target_paths: for py_file in find_py_files(path, exclude_paths): - file_errors = check_errors_in_file(py_file, pyspark_error_list) + file_errors = check_errors_in_file(py_file) all_errors.extend(file_errors) return all_errors if __name__ == "__main__": - # PySpark-specific errors - pyspark_error_list = [ - "AnalysisException", - "ArithmeticException", - "ArrayIndexOutOfBoundsException", - "DateTimeException", - "IllegalArgumentException", - "NumberFormatException", - "ParseException", - "PySparkAssertionError", - "PySparkAttributeError", - "PySparkException", - "PySparkImportError", - "PySparkIndexError", - "PySparkKeyError", - "PySparkNotImplementedError", - "PySparkPicklingError", - "PySparkRuntimeError", - "PySparkTypeError", - "PySparkValueError", - "PythonException", - "QueryExecutionException", - "SessionNotSameException", - "SparkNoSuchElementException", - "SparkRuntimeException", - "SparkUpgradeException", - "StreamingQueryException", - "TempTableAlreadyExistsException", - "UnknownException", - "UnsupportedOperationException", - ] - connect_error_list = [ - "AnalysisException", - "ArithmeticException", - "ArrayIndexOutOfBoundsException", - "BaseAnalysisException", - "BaseArithmeticException", - "BaseArrayIndexOutOfBoundsException", - "BaseDateTimeException", 
- "BaseIllegalArgumentException", - "BaseNoSuchElementException", - "BaseNumberFormatException", - "BaseParseException", - "BasePythonException", - "BaseQueryExecutionException", - "BaseSparkRuntimeException", - "BaseSparkUpgradeException", - "BaseStreamingQueryException", - "BaseUnsupportedOperationException", - "DateTimeException", - "IllegalArgumentException", - "NumberFormatException", - "ParseException", - "PySparkException", - "PythonException", - "QueryExecutionException", - "SparkConnectException", - "SparkConnectGrpcException", - "SparkException", - "SparkNoSuchElementException", - "SparkRuntimeException", - "SparkUpgradeException", - "StreamingQueryException", - "UnsupportedOperationException", - ] - internal_error_list = ["RetryException", "StopIteration"] - pyspark_error_list += connect_error_list - pyspark_error_list += internal_error_list - # Target paths and exclude paths TARGET_PATHS = ["python/pyspark/sql"] EXCLUDE_PATHS = [ @@ -182,8 +271,7 @@ def check_pyspark_custom_errors(target_paths, exclude_paths): # Check errors errors_found = check_pyspark_custom_errors(TARGET_PATHS, EXCLUDE_PATHS) if errors_found: - print("\nPySpark custom errors check found issues in the following files:", file=sys.stderr) + print("\nPySpark custom errors check found issues:", file=sys.stderr) for error in errors_found: - print(error, file=sys.stderr) - print("\nUse existing or create a new custom error from pyspark.errors.", file=sys.stderr) + print(error, file=sys.stderr, end="\n\n") sys.exit(1) From 6eb665aa5f36c5b491ca4a2771d255adc5956d07 Mon Sep 17 00:00:00 2001 From: Tian Gao Date: Wed, 25 Mar 2026 13:13:08 -0700 Subject: [PATCH 2/3] Add operation uuid error condition --- python/pyspark/errors/error-conditions.json | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json index 5f7d2da6398cc..bdaeb8343b3cc 100644 --- a/python/pyspark/errors/error-conditions.json +++ 
b/python/pyspark/errors/error-conditions.json
@@ -425,6 +425,11 @@
       "Invalid number of dataframes in group <dataframes_in_group>."
     ]
   },
+  "INVALID_OPERATION_UUID_ID": {
+    "message": [
+      "Parameter value <arg_name> must be a valid UUID format: <origin>"
+    ]
+  },
   "INVALID_PANDAS_UDF": {
     "message": [
       "Invalid function: <detail>"

From 5431b03c5875bf3950f8d611ac95b47519a7598c Mon Sep 17 00:00:00 2001
From: Tian Gao <gaogaotiantian@hotmail.com>
Date: Wed, 25 Mar 2026 16:17:58 -0700
Subject: [PATCH 3/3] Fix error conditions and exceptions

---
 python/pyspark/errors/error-conditions.json | 14 ++++++++++++--
 python/pyspark/sql/classic/dataframe.py     |  3 +--
 python/pyspark/sql/connect/dataframe.py     |  4 ++--
 python/pyspark/sql/variant_utils.py         | 15 ++++++++++++---
 4 files changed, 27 insertions(+), 9 deletions(-)

diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json
index bdaeb8343b3cc..0d5a95713aa05 100644
--- a/python/pyspark/errors/error-conditions.json
+++ b/python/pyspark/errors/error-conditions.json
@@ -194,7 +194,7 @@
   },
   "DATA_SOURCE_EXTRANEOUS_FILTERS": {
     "message": [
-      "<type>.pushFilters() returned filters that are not part of the input. Make sure that each returned filter is one of the input filters by reference."
+      "<type>.pushFilters() returned filters that are not part of the input <input>. Extraneous filters: <extraneous>. Make sure that each returned filter is one of the input filters by reference."
     ]
   },
   "DATA_SOURCE_INVALID_RETURN_TYPE": {
     "message": [
@@ -412,7 +412,7 @@
   },
   "INVALID_MULTIPLE_ARGUMENT_CONDITIONS": {
     "message": [
-      "[{arg_names}] cannot be <condition>."
+      "[<arg_names>] cannot be <condition>."
     ]
   },
   "INVALID_NDARRAY_DIMENSION": {
     "message": [
@@ -1249,6 +1249,11 @@
       "<data_type> is not supported in conversion to Arrow."
     ]
   },
+  "UNSUPPORTED_FILTER": {
+    "message": [
+      "<name> is not a supported filter."
+    ]
+  },
   "UNSUPPORTED_JOIN_TYPE": {
     "message": [
       "Unsupported join type: '<typ>'. Supported join types include: <supported>."
@@ -1369,6 +1374,11 @@
       "Value for `<arg_name>` must be between <min> and <max> (inclusive), got <arg_value>"
     ]
   },
+  "VARIANT_SIZE_LIMIT_EXCEEDED": {
+    "message": [
+      "Variant size (<actual_size> bytes) exceeds the limit (<size_limit> bytes)."
+    ]
+  },
   "WKB_PARSE_ERROR" : {
     "message" : [
       "Error parsing WKB: <reason> at position <pos>"
diff --git a/python/pyspark/sql/classic/dataframe.py b/python/pyspark/sql/classic/dataframe.py
index 1d8a074e43f88..2d99afa05b212 100644
--- a/python/pyspark/sql/classic/dataframe.py
+++ b/python/pyspark/sql/classic/dataframe.py
@@ -1199,7 +1199,6 @@ def observe(
                 errorClass="NOT_LIST_OF_COLUMN",
                 messageParameters={
                     "arg_name": "observation",
-                    "arg_type": type(observation).__name__,
                 },
             )
 
@@ -1270,7 +1269,7 @@ def dropna(
         if how is not None and how not in ["any", "all"]:
             raise PySparkValueError(
                 errorClass="VALUE_NOT_ANY_OR_ALL",
-                messageParameters={"arg_name": "how", "arg_type": how},
+                messageParameters={"arg_name": "how", "arg_value": how},
             )
 
         if subset is None:
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index cf2c3b0de1a20..49d66bcd21fd0 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -136,7 +136,7 @@ def __init__(
         if self._session is None:
             raise PySparkRuntimeError(
                 errorClass="NO_ACTIVE_SESSION",
-                messageParameters={"operator": "__init__"},
+                messageParameters={},
             )
 
         # Check whether _repr_html is supported or not, we use it to avoid calling RPC twice
@@ -1277,7 +1277,7 @@ def dropna(
             min_non_nulls = None
         else:
             raise PySparkValueError(
-                errorClass="CANNOT_BE_EMPTY",
+                errorClass="VALUE_NOT_ANY_OR_ALL",
                 messageParameters={"arg_name": "how", "arg_value": str(how)},
             )
 
diff --git a/python/pyspark/sql/variant_utils.py b/python/pyspark/sql/variant_utils.py
index 3025523064e1d..46886df32b28a 100644
--- a/python/pyspark/sql/variant_utils.py
+++ b/python/pyspark/sql/variant_utils.py
@@ -563,14 +563,20 @@ def build(self, json_str: str) -> Tuple[bytes, bytes]:
         # calculation in case of pathological data.
max_size = max(dictionary_string_size, num_keys) if max_size > self.size_limit: - raise PySparkValueError(errorClass="VARIANT_SIZE_LIMIT_EXCEEDED", messageParameters={}) + raise PySparkValueError( + errorClass="VARIANT_SIZE_LIMIT_EXCEEDED", + messageParameters={"actual_size": max_size, "size_limit": self.size_limit}, + ) offset_size = self._get_integer_size(max_size) offset_start = 1 + offset_size string_start = offset_start + (num_keys + 1) * offset_size metadata_size = string_start + dictionary_string_size if metadata_size > self.size_limit: - raise PySparkValueError(errorClass="VARIANT_SIZE_LIMIT_EXCEEDED", messageParameters={}) + raise PySparkValueError( + errorClass="VARIANT_SIZE_LIMIT_EXCEEDED", + messageParameters={"actual_size": metadata_size, "size_limit": self.size_limit}, + ) metadata = bytearray() header_byte = VariantUtils.VERSION | ((offset_size - 1) << 6) @@ -631,7 +637,10 @@ def _get_integer_size(self, value: int) -> int: def _check_capacity(self, additional: int) -> None: required = len(self.value) + additional if required > self.size_limit: - raise PySparkValueError(errorClass="VARIANT_SIZE_LIMIT_EXCEEDED", messageParameters={}) + raise PySparkValueError( + errorClass="VARIANT_SIZE_LIMIT_EXCEEDED", + messageParameters={"actual_size": required, "size_limit": self.size_limit}, + ) def _primitive_header(self, type: int) -> bytes: return bytes([(type << 2) | VariantUtils.PRIMITIVE])