Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
256 changes: 172 additions & 84 deletions dev/check_pyspark_custom_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,178 @@
# limitations under the License.
#

import ast
import dataclasses
import functools
import json
import sys
import textwrap

sys.path.insert(0, "python")
import os


@dataclasses.dataclass
class CustomErrorFailure:
    """A single violation found by the custom-error checker.

    Attributes hold where the violation occurred (file and line), the raw
    source of the offending ``raise`` call, and a human-readable explanation.
    """

    file_path: str
    line_number: int
    source_code: str
    error_message: str

    def __str__(self):
        # Dedent the captured snippet so it lines up with the message below it.
        snippet = textwrap.dedent(self.source_code)
        location = "{}:({}):".format(self.file_path, self.line_number)
        return "{}\n{}{}".format(location, snippet, self.error_message)


class CustomErrorAnalyzer:
    """Analyze one Python source file for misuse of PySpark custom errors.

    Reports a :class:`CustomErrorFailure` for each ``raise`` call that:

    - raises an exception type not defined in ``pyspark.errors``,
    - passes an ``errorClass`` not present in ``error-conditions.json``, or
    - passes a ``messageParameters`` key with no matching ``<placeholder>``
      in the error condition's message template.
    """

    # NOTE: the original carried a pointless ``@dataclasses.dataclass`` here;
    # with a hand-written ``__init__`` and no fields it only added a
    # field-less ``__eq__`` that made every instance compare equal, so it
    # was dropped.

    def __init__(self, file_path):
        self.file_path = file_path
        # Populated by analyze(); initialized so get_source() never hits an
        # AttributeError if called early.
        self.file_lines = []
        self.pyspark_error_list = self._load_error_list()
        self.error_conditions = self._load_error_conditions()

    @staticmethod
    @functools.cache
    def _load_error_conditions():
        """Load and memoize the error-conditions.json registry."""
        with open("python/pyspark/errors/error-conditions.json", "r") as file:
            return json.load(file)

    @staticmethod
    @functools.cache
    def _load_error_list():
        """Return the memoized list of exception names allowed in raises."""
        # PySpark-specific errors
        pyspark_error_list = [
            "AnalysisException",
            "ArithmeticException",
            "ArrayIndexOutOfBoundsException",
            "DateTimeException",
            "IllegalArgumentException",
            "NumberFormatException",
            "ParseException",
            "PySparkAssertionError",
            "PySparkAttributeError",
            "PySparkException",
            "PySparkImportError",
            "PySparkIndexError",
            "PySparkKeyError",
            "PySparkNotImplementedError",
            "PySparkPicklingError",
            "PySparkRuntimeError",
            "PySparkTypeError",
            "PySparkValueError",
            "PythonException",
            "QueryExecutionException",
            "SessionNotSameException",
            "SparkNoSuchElementException",
            "SparkRuntimeException",
            "SparkUpgradeException",
            "StreamingQueryException",
            "TempTableAlreadyExistsException",
            "UnknownException",
            "UnsupportedOperationException",
        ]
        # Spark Connect client-side errors.
        connect_error_list = [
            "AnalysisException",
            "ArithmeticException",
            "ArrayIndexOutOfBoundsException",
            "BaseAnalysisException",
            "BaseArithmeticException",
            "BaseArrayIndexOutOfBoundsException",
            "BaseDateTimeException",
            "BaseIllegalArgumentException",
            "BaseNoSuchElementException",
            "BaseNumberFormatException",
            "BaseParseException",
            "BasePythonException",
            "BaseQueryExecutionException",
            "BaseSparkRuntimeException",
            "BaseSparkUpgradeException",
            "BaseStreamingQueryException",
            "BaseUnsupportedOperationException",
            "DateTimeException",
            "IllegalArgumentException",
            "NumberFormatException",
            "ParseException",
            "PySparkException",
            "PythonException",
            "QueryExecutionException",
            "SparkConnectException",
            "SparkConnectGrpcException",
            "SparkException",
            "SparkNoSuchElementException",
            "SparkRuntimeException",
            "SparkUpgradeException",
            "StreamingQueryException",
            "UnsupportedOperationException",
        ]
        # Errors used for internal control flow; allowed but never user-facing.
        internal_error_list = ["RetryException", "StopIteration"]
        pyspark_error_list += connect_error_list
        pyspark_error_list += internal_error_list
        return pyspark_error_list

    def analyze(self):
        """Parse the file and return a list of CustomErrorFailure objects."""
        with open(self.file_path, "r") as file:
            self.file_lines = file.readlines()
        tree = ast.parse("".join(self.file_lines))

        failures = []
        for node in ast.walk(tree):
            # Only direct ``raise SomeError(...)`` calls are checked;
            # bare re-raises and ``raise exc_var`` are skipped.
            if isinstance(node, ast.Raise) and isinstance(node.exc, ast.Call):
                if failure := self.analyze_call(node.exc):
                    failures.append(failure)
        return failures

    def analyze_call(self, call_node):
        """Check one raise-call AST node; return a failure or None.

        Parameters
        ----------
        call_node : ast.Call
            The call expression of a ``raise`` statement.
        """
        if not isinstance(call_node.func, ast.Name):
            return None

        exc_type = call_node.func.id
        if exc_type[0].isupper() and exc_type not in self.pyspark_error_list:
            return CustomErrorFailure(
                file_path=self.file_path,
                line_number=call_node.lineno,
                source_code=self.get_source(call_node),
                error_message=f"custom error '{exc_type}' is not defined in pyspark.errors",
            )

        # We need to parse errorClass before messageParameters regardless of
        # the order they were written at the call site. Keywords passed via
        # ``**kwargs`` have ``arg is None`` and cannot be inspected, so they
        # are skipped (they also previously crashed the sort with a TypeError).
        keywords = sorted(
            (kw for kw in call_node.keywords if kw.arg is not None),
            key=lambda kw: kw.arg != "errorClass",
        )
        # Stays None unless a known, constant errorClass is seen; previously
        # this was unbound (NameError) when no errorClass keyword was given.
        error_class = None
        for keyword in keywords:
            if keyword.arg == "errorClass":
                if isinstance(keyword.value, ast.Constant):
                    error_class = keyword.value.value
                    if error_class not in self.error_conditions:
                        return CustomErrorFailure(
                            file_path=self.file_path,
                            line_number=call_node.lineno,
                            source_code=self.get_source(call_node),
                            error_message=f"errorClass '{error_class}' is not defined in error-conditions.json",
                        )
            elif keyword.arg == "messageParameters" and error_class is not None:
                if isinstance(keyword.value, ast.Dict):
                    # Message templates are lists of lines; join them so
                    # placeholders on any line are found (previously only the
                    # first line was searched).
                    template = "\n".join(self.error_conditions[error_class]["message"])
                    for key_node in keyword.value.keys:
                        if isinstance(key_node, ast.Constant):
                            key = key_node.value
                            if f"<{key}>" not in template:
                                return CustomErrorFailure(
                                    file_path=self.file_path,
                                    line_number=call_node.lineno,
                                    source_code=self.get_source(call_node),
                                    error_message=f"messageParameter '{key}' is not defined in {error_class}",
                                )
        return None

    def get_source(self, node):
        """Return the exact source lines spanned by *node*.

        ast line numbers are 1-indexed and ``end_lineno`` is inclusive.
        """
        return "".join(self.file_lines[node.lineno - 1 : node.end_lineno])


def find_py_files(path, exclude_paths):
"""
Find all .py files in a directory, excluding files in specified subdirectories.
Expand All @@ -49,7 +215,7 @@ def find_py_files(path, exclude_paths):
return py_files


def check_errors_in_file(file_path, pyspark_error_list):
def check_errors_in_file(file_path):
"""
Check if a file uses PySpark-specific errors correctly.

Expand All @@ -65,16 +231,8 @@ def check_errors_in_file(file_path, pyspark_error_list):
list of str
A list of strings describing the errors found in the file, with line numbers.
"""
errors_found = []
with open(file_path, "r") as file:
for line_num, line in enumerate(file, start=1):
if line.strip().startswith("raise"):
parts = line.split()
# Check for 'raise' statement and ensure the error raised is a capitalized word.
if len(parts) > 1 and parts[1][0].isupper():
if not any(pyspark_error in line for pyspark_error in pyspark_error_list):
errors_found.append(f"{file_path}:{line_num}: {line.strip()}")
return errors_found
analyzer = CustomErrorAnalyzer(file_path)
return analyzer.analyze()


def check_pyspark_custom_errors(target_paths, exclude_paths):
Expand All @@ -96,81 +254,12 @@ def check_pyspark_custom_errors(target_paths, exclude_paths):
all_errors = []
for path in target_paths:
for py_file in find_py_files(path, exclude_paths):
file_errors = check_errors_in_file(py_file, pyspark_error_list)
file_errors = check_errors_in_file(py_file)
all_errors.extend(file_errors)
return all_errors


if __name__ == "__main__":
# PySpark-specific errors
pyspark_error_list = [
"AnalysisException",
"ArithmeticException",
"ArrayIndexOutOfBoundsException",
"DateTimeException",
"IllegalArgumentException",
"NumberFormatException",
"ParseException",
"PySparkAssertionError",
"PySparkAttributeError",
"PySparkException",
"PySparkImportError",
"PySparkIndexError",
"PySparkKeyError",
"PySparkNotImplementedError",
"PySparkPicklingError",
"PySparkRuntimeError",
"PySparkTypeError",
"PySparkValueError",
"PythonException",
"QueryExecutionException",
"SessionNotSameException",
"SparkNoSuchElementException",
"SparkRuntimeException",
"SparkUpgradeException",
"StreamingQueryException",
"TempTableAlreadyExistsException",
"UnknownException",
"UnsupportedOperationException",
]
connect_error_list = [
"AnalysisException",
"ArithmeticException",
"ArrayIndexOutOfBoundsException",
"BaseAnalysisException",
"BaseArithmeticException",
"BaseArrayIndexOutOfBoundsException",
"BaseDateTimeException",
"BaseIllegalArgumentException",
"BaseNoSuchElementException",
"BaseNumberFormatException",
"BaseParseException",
"BasePythonException",
"BaseQueryExecutionException",
"BaseSparkRuntimeException",
"BaseSparkUpgradeException",
"BaseStreamingQueryException",
"BaseUnsupportedOperationException",
"DateTimeException",
"IllegalArgumentException",
"NumberFormatException",
"ParseException",
"PySparkException",
"PythonException",
"QueryExecutionException",
"SparkConnectException",
"SparkConnectGrpcException",
"SparkException",
"SparkNoSuchElementException",
"SparkRuntimeException",
"SparkUpgradeException",
"StreamingQueryException",
"UnsupportedOperationException",
]
internal_error_list = ["RetryException", "StopIteration"]
pyspark_error_list += connect_error_list
pyspark_error_list += internal_error_list

# Target paths and exclude paths
TARGET_PATHS = ["python/pyspark/sql"]
EXCLUDE_PATHS = [
Expand All @@ -182,8 +271,7 @@ def check_pyspark_custom_errors(target_paths, exclude_paths):
# Check errors
errors_found = check_pyspark_custom_errors(TARGET_PATHS, EXCLUDE_PATHS)
if errors_found:
print("\nPySpark custom errors check found issues in the following files:", file=sys.stderr)
print("\nPySpark custom errors check found issues:", file=sys.stderr)
for error in errors_found:
print(error, file=sys.stderr)
print("\nUse existing or create a new custom error from pyspark.errors.", file=sys.stderr)
print(error, file=sys.stderr, end="\n\n")
sys.exit(1)
19 changes: 17 additions & 2 deletions python/pyspark/errors/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@
},
"DATA_SOURCE_EXTRANEOUS_FILTERS": {
"message": [
"<type>.pushFilters() returned filters that are not part of the input. Make sure that each returned filter is one of the input filters by reference."
"<type>.pushFilters() returned filters that are not part of the input <input>. Extraneous filters: <extraneous>. Make sure that each returned filter is one of the input filters by reference."
]
},
"DATA_SOURCE_INVALID_RETURN_TYPE": {
Expand Down Expand Up @@ -417,7 +417,7 @@
},
"INVALID_MULTIPLE_ARGUMENT_CONDITIONS": {
"message": [
"[{arg_names}] cannot be <condition>."
"[<arg_names>] cannot be <condition>."
]
},
"INVALID_NDARRAY_DIMENSION": {
Expand All @@ -430,6 +430,11 @@
"Invalid number of dataframes in group <dataframes_in_group>."
]
},
"INVALID_OPERATION_UUID_ID": {
"message": [
"Parameter value <arg_name> must be a valid UUID format: <origin>"
]
},
"INVALID_PANDAS_UDF": {
"message": [
"Invalid function: <detail>"
Expand Down Expand Up @@ -1089,6 +1094,11 @@
"<data_type> is not supported in conversion to Arrow."
]
},
"UNSUPPORTED_FILTER": {
"message": [
"<name> is not a supported filter."
]
},
"UNSUPPORTED_JOIN_TYPE": {
"message": [
"Unsupported join type: '<typ>'. Supported join types include: <supported>."
Expand Down Expand Up @@ -1209,6 +1219,11 @@
"Value for `<arg_name>` must be between <lower_bound> and <upper_bound> (inclusive), got <actual>"
]
},
"VARIANT_SIZE_LIMIT_EXCEEDED": {
"message": [
"Variant size (<actual_size> bytes) exceeds the limit (<size_limit> bytes)."
]
},
"WKB_PARSE_ERROR" : {
"message" : [
"Error parsing WKB: <parseError> at position <pos>"
Expand Down
2 changes: 1 addition & 1 deletion python/pyspark/sql/classic/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1331,7 +1331,7 @@ def dropna(
if how is not None and how not in ["any", "all"]:
raise PySparkValueError(
errorClass="VALUE_NOT_ANY_OR_ALL",
messageParameters={"arg_name": "how", "arg_type": how},
messageParameters={"arg_name": "how", "arg_value": how},
)

if subset is None:
Expand Down
4 changes: 2 additions & 2 deletions python/pyspark/sql/connect/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def __init__(
if self._session is None:
raise PySparkRuntimeError(
errorClass="NO_ACTIVE_SESSION",
messageParameters={"operator": "__init__"},
messageParameters={},
)

# Check whether _repr_html is supported or not, we use it to avoid calling RPC twice
Expand Down Expand Up @@ -1363,7 +1363,7 @@ def dropna(
min_non_nulls = None
else:
raise PySparkValueError(
errorClass="CANNOT_BE_EMPTY",
errorClass="VALUE_NOT_ANY_OR_ALL",
messageParameters={"arg_name": "how", "arg_value": str(how)},
)

Expand Down
Loading