From edcea63587b7d22f86dff97e14ead95b8f24079f Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Sat, 28 Feb 2026 13:11:13 +0900 Subject: [PATCH 01/18] Add result_set_type_hints parameter for precise complex type conversion The Athena GetQueryResults API only returns base type names (e.g., "array", "map", "row") without nested type signatures, causing _convert_value() to use heuristic inference that incorrectly converts varchar values like "1234" to int(1234) inside complex types. This adds a result_set_type_hints parameter to all cursor execute() methods so users can provide full Athena DDL type signatures for precise conversion. Also changes the default behavior so nested elements without type hints remain as strings instead of being heuristically inferred (breaking change). Closes #689 Co-Authored-By: Claude Opus 4.6 --- pyathena/aio/cursor.py | 5 + pyathena/aio/result_set.py | 14 +- pyathena/arrow/async_cursor.py | 4 + pyathena/arrow/converter.py | 4 +- pyathena/arrow/cursor.py | 5 + pyathena/arrow/result_set.py | 2 + pyathena/async_cursor.py | 15 +- pyathena/converter.py | 427 ++++++++++++++++++++++++++++++- pyathena/cursor.py | 41 +-- pyathena/pandas/async_cursor.py | 4 + pyathena/pandas/converter.py | 4 +- pyathena/pandas/cursor.py | 5 + pyathena/pandas/result_set.py | 4 + pyathena/polars/async_cursor.py | 7 + pyathena/polars/converter.py | 4 +- pyathena/polars/cursor.py | 5 + pyathena/polars/result_set.py | 4 + pyathena/result_set.py | 16 +- pyathena/s3fs/async_cursor.py | 10 + pyathena/s3fs/converter.py | 12 +- pyathena/s3fs/cursor.py | 5 + pyathena/s3fs/result_set.py | 24 +- tests/pyathena/test_converter.py | 356 ++++++++++++++++++++++---- 23 files changed, 883 insertions(+), 94 deletions(-) diff --git a/pyathena/aio/cursor.py b/pyathena/aio/cursor.py index 30738f8f..d7998857 100644 --- a/pyathena/aio/cursor.py +++ b/pyathena/aio/cursor.py @@ -79,6 +79,7 @@ async def execute( # type: ignore[override] result_reuse_enable: bool | None = None, 
result_reuse_minutes: int | None = None, paramstyle: str | None = None, + result_set_type_hints: dict[str, str] | None = None, **kwargs, ) -> AioCursor: """Execute a SQL query asynchronously. @@ -93,6 +94,9 @@ async def execute( # type: ignore[override] result_reuse_enable: Enable result reuse (optional). result_reuse_minutes: Result reuse duration in minutes (optional). paramstyle: Parameter style to use (optional). + result_set_type_hints: Optional dictionary mapping column names to + Athena DDL type signatures for precise type conversion within + complex types. **kwargs: Additional execution parameters. Returns: @@ -119,6 +123,7 @@ async def execute( # type: ignore[override] query_execution, self.arraysize, self._retry_config, + result_set_type_hints=result_set_type_hints, ) else: raise OperationalError(query_execution.state_change_reason) diff --git a/pyathena/aio/result_set.py b/pyathena/aio/result_set.py index 000337bd..97f4ad2b 100644 --- a/pyathena/aio/result_set.py +++ b/pyathena/aio/result_set.py @@ -35,6 +35,7 @@ def __init__( query_execution: AthenaQueryExecution, arraysize: int, retry_config: RetryConfig, + result_set_type_hints: dict[str, str] | None = None, ) -> None: super().__init__( connection=connection, @@ -43,6 +44,7 @@ def __init__( arraysize=arraysize, retry_config=retry_config, _pre_fetch=False, + result_set_type_hints=result_set_type_hints, ) @classmethod @@ -53,6 +55,7 @@ async def create( query_execution: AthenaQueryExecution, arraysize: int, retry_config: RetryConfig, + result_set_type_hints: dict[str, str] | None = None, ) -> AthenaAioResultSet: """Async factory method. @@ -64,11 +67,20 @@ async def create( query_execution: Query execution metadata. arraysize: Number of rows to fetch per request. retry_config: Retry configuration for API calls. + result_set_type_hints: Optional dictionary mapping column names to + Athena DDL type signatures for precise type conversion. Returns: A fully initialized ``AthenaAioResultSet``. 
""" - result_set = cls(connection, converter, query_execution, arraysize, retry_config) + result_set = cls( + connection, + converter, + query_execution, + arraysize, + retry_config, + result_set_type_hints=result_set_type_hints, + ) if result_set.state == AthenaQueryExecution.STATE_SUCCEEDED: await result_set._async_pre_fetch() return result_set diff --git a/pyathena/arrow/async_cursor.py b/pyathena/arrow/async_cursor.py index 39d2950b..bf0e538b 100644 --- a/pyathena/arrow/async_cursor.py +++ b/pyathena/arrow/async_cursor.py @@ -149,6 +149,7 @@ def arraysize(self, value: int) -> None: def _collect_result_set( self, query_id: str, + result_set_type_hints: dict[str, str] | None = None, unload_location: str | None = None, kwargs: dict[str, Any] | None = None, ) -> AthenaArrowResultSet: @@ -165,6 +166,7 @@ def _collect_result_set( unload_location=unload_location, connect_timeout=self._connect_timeout, request_timeout=self._request_timeout, + result_set_type_hints=result_set_type_hints, **kwargs, ) @@ -179,6 +181,7 @@ def execute( result_reuse_enable: bool | None = None, result_reuse_minutes: int | None = None, paramstyle: str | None = None, + result_set_type_hints: dict[str, str] | None = None, **kwargs, ) -> tuple[str, Future[AthenaArrowResultSet | Any]]: operation, unload_location = self._prepare_unload(operation, s3_staging_dir) @@ -198,6 +201,7 @@ def execute( self._executor.submit( self._collect_result_set, query_id, + result_set_type_hints, unload_location, kwargs, ), diff --git a/pyathena/arrow/converter.py b/pyathena/arrow/converter.py index 1e5d5d91..26774186 100644 --- a/pyathena/arrow/converter.py +++ b/pyathena/arrow/converter.py @@ -90,7 +90,7 @@ def _dtypes(self) -> dict[str, type[Any]]: } return self.__dtypes - def convert(self, type_: str, value: str | None) -> Any | None: + def convert(self, type_: str, value: str | None, type_hint: str | None = None) -> Any | None: converter = self.get(type_) return converter(value) @@ -114,5 +114,5 @@ def 
__init__(self) -> None: default=_to_default, ) - def convert(self, type_: str, value: str | None) -> Any | None: + def convert(self, type_: str, value: str | None, type_hint: str | None = None) -> Any | None: pass diff --git a/pyathena/arrow/cursor.py b/pyathena/arrow/cursor.py index 8f831831..1876eebb 100644 --- a/pyathena/arrow/cursor.py +++ b/pyathena/arrow/cursor.py @@ -137,6 +137,7 @@ def execute( result_reuse_minutes: int | None = None, paramstyle: str | None = None, on_start_query_execution: Callable[[str], None] | None = None, + result_set_type_hints: dict[str, str] | None = None, **kwargs, ) -> ArrowCursor: """Execute a SQL query and return results as Apache Arrow Tables. @@ -156,6 +157,9 @@ def execute( result_reuse_minutes: Minutes to reuse cached results. paramstyle: Parameter style ('qmark' or 'pyformat'). on_start_query_execution: Callback called when query starts. + result_set_type_hints: Optional dictionary mapping column names to + Athena DDL type signatures for precise type conversion within + complex types. **kwargs: Additional execution parameters. 
Returns: @@ -197,6 +201,7 @@ def execute( unload_location=unload_location, connect_timeout=self._connect_timeout, request_timeout=self._request_timeout, + result_set_type_hints=result_set_type_hints, **kwargs, ) else: diff --git a/pyathena/arrow/result_set.py b/pyathena/arrow/result_set.py index 12fab1ac..33027497 100644 --- a/pyathena/arrow/result_set.py +++ b/pyathena/arrow/result_set.py @@ -91,6 +91,7 @@ def __init__( unload_location: str | None = None, connect_timeout: float | None = None, request_timeout: float | None = None, + result_set_type_hints: dict[str, str] | None = None, **kwargs, ) -> None: super().__init__( @@ -99,6 +100,7 @@ def __init__( query_execution=query_execution, arraysize=1, # Fetch one row to retrieve metadata retry_config=retry_config, + result_set_type_hints=result_set_type_hints, ) self._rows.clear() # Clear pre_fetch data self._arraysize = arraysize diff --git a/pyathena/async_cursor.py b/pyathena/async_cursor.py index 1716cc3c..986a413c 100644 --- a/pyathena/async_cursor.py +++ b/pyathena/async_cursor.py @@ -144,7 +144,11 @@ def poll(self, query_id: str) -> Future[AthenaQueryExecution]: """ return cast("Future[AthenaQueryExecution]", self._executor.submit(self._poll, query_id)) - def _collect_result_set(self, query_id: str) -> AthenaResultSet: + def _collect_result_set( + self, + query_id: str, + result_set_type_hints: dict[str, str] | None = None, + ) -> AthenaResultSet: query_execution = cast(AthenaQueryExecution, self._poll(query_id)) return self._result_set_class( connection=self._connection, @@ -152,6 +156,7 @@ def _collect_result_set(self, query_id: str) -> AthenaResultSet: query_execution=query_execution, arraysize=self._arraysize, retry_config=self._retry_config, + result_set_type_hints=result_set_type_hints, ) def execute( @@ -165,6 +170,7 @@ def execute( result_reuse_enable: bool | None = None, result_reuse_minutes: int | None = None, paramstyle: str | None = None, + result_set_type_hints: dict[str, str] | None = None, 
**kwargs, ) -> tuple[str, Future[AthenaResultSet | Any]]: """Execute a SQL query asynchronously. @@ -183,6 +189,9 @@ def execute( result_reuse_enable: Enable result reuse for identical queries (optional). result_reuse_minutes: Result reuse duration in minutes (optional). paramstyle: Parameter style to use (optional). + result_set_type_hints: Optional dictionary mapping column names to + Athena DDL type signatures for precise type conversion within + complex types. **kwargs: Additional execution parameters. Returns: @@ -207,7 +216,9 @@ def execute( result_reuse_minutes=result_reuse_minutes, paramstyle=paramstyle, ) - return query_id, self._executor.submit(self._collect_result_set, query_id) + return query_id, self._executor.submit( + self._collect_result_set, query_id, result_set_type_hints + ) def executemany( self, diff --git a/pyathena/converter.py b/pyathena/converter.py index ce664c81..b83500b4 100644 --- a/pyathena/converter.py +++ b/pyathena/converter.py @@ -6,6 +6,7 @@ from abc import ABCMeta, abstractmethod from collections.abc import Callable from copy import deepcopy +from dataclasses import dataclass, field from datetime import date, datetime, time from decimal import Decimal from typing import Any @@ -395,24 +396,22 @@ def _parse_unnamed_struct(inner: str) -> dict[str, Any]: def _convert_value(value: str) -> Any: - """Convert string value to appropriate Python type. + """Convert string value without type inference. + + Returns the string as-is, except for null which becomes None. + This is a safe default that avoids incorrect type conversions + (e.g., converting varchar "1234" to int 1234 inside complex types). + + Use :func:`_convert_value_with_type` for type-aware conversion. Args: value: String value to convert. Returns: - Converted value as int, float, bool, None, or string. + None for "null" values, otherwise the original string. 
""" if value.lower() == "null": return None - if value.lower() == "true": - return True - if value.lower() == "false": - return False - if value.isdigit() or (value.startswith("-") and value[1:].isdigit()): - return int(value) - if "." in value and value.replace(".", "", 1).replace("-", "", 1).isdigit(): - return float(value) return value @@ -420,6 +419,384 @@ def _to_default(varchar_value: str | None) -> str | None: return varchar_value +@dataclass +class TypeNode: + """Parsed representation of an Athena DDL type signature. + + Represents a node in a type tree, where complex types (array, map, row) + have children representing their element/field types. + + Attributes: + type_name: The base type name (e.g., "array", "map", "row", "varchar"). + children: Child type nodes for complex types. + field_names: Field names for row/struct types (parallel to children). + """ + + type_name: str + children: list[TypeNode] = field(default_factory=list) + field_names: list[str] | None = None + + +def _split_top_level(s: str, delimiter: str = ",") -> list[str]: + """Split a string by delimiter, respecting nested parentheses. + + Args: + s: String to split. + delimiter: Character to split on. + + Returns: + List of parts split at top-level delimiters. + """ + parts: list[str] = [] + current: list[str] = [] + depth = 0 + + for char in s: + if char == "(": + depth += 1 + elif char == ")": + depth -= 1 + elif char == delimiter and depth == 0: + parts.append("".join(current).strip()) + current = [] + continue + current.append(char) + + if current: + parts.append("".join(current).strip()) + return parts + + +def parse_type_signature(type_str: str) -> TypeNode: + """Parse an Athena DDL type signature string into a TypeNode tree. + + Handles simple types (varchar, integer), parameterized types (decimal(10,2)), + and complex types (array, map, row/struct) with arbitrary nesting. + + Args: + type_str: Athena DDL type string (e.g., "array(row(name varchar, age integer))"). 
+ + Returns: + TypeNode representing the parsed type tree. + """ + type_str = type_str.strip() + + paren_idx = type_str.find("(") + if paren_idx == -1: + return TypeNode(type_name=type_str.lower()) + + type_name = type_str[:paren_idx].strip().lower() + + # Find matching closing paren + inner = type_str[paren_idx + 1 : -1].strip() + + if type_name in ("row", "struct"): + parts = _split_top_level(inner) + field_names: list[str] = [] + children: list[TypeNode] = [] + for part in parts: + part = part.strip() + # Split into field_name and type at first space + space_idx = _find_field_name_boundary(part) + if space_idx == -1: + children.append(parse_type_signature(part)) + field_names.append(part) + else: + field_name = part[:space_idx].strip() + type_part = part[space_idx + 1 :].strip() + field_names.append(field_name) + children.append(parse_type_signature(type_part)) + return TypeNode(type_name=type_name, children=children, field_names=field_names) + + if type_name == "array": + child = parse_type_signature(inner) + return TypeNode(type_name=type_name, children=[child]) + + if type_name == "map": + parts = _split_top_level(inner) + if len(parts) == 2: + key_type = parse_type_signature(parts[0]) + value_type = parse_type_signature(parts[1]) + return TypeNode(type_name=type_name, children=[key_type, value_type]) + return TypeNode(type_name=type_name) + + # Types with parameters like decimal(10, 2), varchar(255) + return TypeNode(type_name=type_name) + + +def _find_field_name_boundary(part: str) -> int: + """Find the boundary between field name and type in a row field definition. + + Handles cases like "name varchar" and "data row(x integer, y integer)". + + Args: + part: A single field definition string. + + Returns: + Index of the space separating field name from type, or -1 if not found. 
+ """ + depth = 0 + for i, char in enumerate(part): + if char == "(": + depth += 1 + elif char == ")": + depth -= 1 + elif char == " " and depth == 0: + return i + return -1 + + +def _convert_value_with_type(value: str, type_node: TypeNode) -> Any: + """Convert a value using type information from a TypeNode. + + For complex types (array, map, row), parses the structure and + recursively converts elements using child type information. + For simple types, uses the standard converter function. + + Args: + value: String value to convert. + type_node: Parsed type information. + + Returns: + Converted value. + """ + if type_node.type_name == "array": + return _convert_typed_array(value, type_node) + if type_node.type_name == "map": + return _convert_typed_map(value, type_node) + if type_node.type_name in ("row", "struct"): + return _convert_typed_struct(value, type_node) + # Simple type: use the standard converter + converter_fn = _DEFAULT_CONVERTERS.get(type_node.type_name, _to_default) + return converter_fn(value) + + +def _convert_element(value: str, type_node: TypeNode) -> Any: + """Convert a single element within a complex type using type information. + + Handles null values before delegating to type-specific conversion. + + Args: + value: String value to convert. + type_node: Type information for this element. + + Returns: + Converted value, or None for null. + """ + if value.lower() == "null": + return None + return _convert_value_with_type(value, type_node) + + +def _convert_typed_array(value: str, type_node: TypeNode) -> list[Any] | None: + """Convert an array value using type information. + + Args: + value: String representation of the array. + type_node: Type node with array element type as first child. + + Returns: + List of converted elements, or None if parsing fails. 
+ """ + if not (value.startswith("[") and value.endswith("]")): + return None + + element_type = type_node.children[0] if type_node.children else TypeNode("varchar") + + # Try JSON first + try: + parsed = json.loads(value) + if isinstance(parsed, list): + return [ + None if elem is None else _convert_element(str(elem), element_type) + for elem in parsed + ] + except json.JSONDecodeError: + pass + + # Native format + inner = value[1:-1].strip() + if not inner: + return [] + + if "[" in inner: + return None # Nested arrays not supported in native format + + items = _split_array_items(inner) + result: list[Any] = [] + for item in items: + item = item.strip() + if not item: + continue + if item.startswith("{") and item.endswith("}"): + if element_type.type_name in ("row", "struct"): + result.append(_convert_typed_struct(item, element_type)) + elif element_type.type_name == "map": + result.append(_convert_typed_map(item, element_type)) + else: + result.append(_to_struct(item)) + else: + result.append(_convert_element(item, element_type)) + + return result if result else None + + +def _convert_typed_map(value: str, type_node: TypeNode) -> dict[str, Any] | None: + """Convert a map value using type information. + + Args: + value: String representation of the map. + type_node: Type node with key type and value type as children. + + Returns: + Dictionary of converted key-value pairs, or None if parsing fails. 
+ """ + if not (value.startswith("{") and value.endswith("}")): + return None + + key_type = type_node.children[0] if len(type_node.children) > 0 else TypeNode("varchar") + value_type = type_node.children[1] if len(type_node.children) > 1 else TypeNode("varchar") + + # Try JSON first + inner_preview = value[1:10] if len(value) > 10 else value[1:-1] + if '"' in inner_preview or value.startswith('{"'): + try: + parsed = json.loads(value) + if isinstance(parsed, dict): + return { + str( + _convert_element(str(k), key_type) if k is not None else k + ): _convert_element(str(v), value_type) if v is not None else None + for k, v in parsed.items() + } + except json.JSONDecodeError: + pass + + # Native format + inner = value[1:-1].strip() + if not inner: + return {} + + if any(char in inner for char in "()[]"): + return None + + pairs = [pair.strip() for pair in inner.split(",")] + result: dict[str, Any] = {} + for pair in pairs: + if "=" not in pair: + continue + k, v = pair.split("=", 1) + k = k.strip() + v = v.strip() + if any(char in k for char in '{}="') or any(char in v for char in '{}="'): + continue + converted_key = _convert_element(k, key_type) + converted_value = _convert_element(v, value_type) + result[str(converted_key)] = converted_value + + return result if result else None + + +def _convert_typed_struct(value: str, type_node: TypeNode) -> dict[str, Any] | None: + """Convert a struct/row value using type information. + + Args: + value: String representation of the struct. + type_node: Type node with field types and names. + + Returns: + Dictionary of converted field values, or None if parsing fails. 
+ """ + if not (value.startswith("{") and value.endswith("}")): + return None + + field_names = type_node.field_names or [] + field_types = type_node.children or [] + + # Try JSON first + inner_preview = value[1:10] if len(value) > 10 else value[1:-1] + if '"' in inner_preview or value.startswith('{"'): + try: + parsed = json.loads(value) + if isinstance(parsed, dict): + result: dict[str, Any] = {} + for i, (k, v) in enumerate(parsed.items()): + ft = field_types[i] if i < len(field_types) else TypeNode("varchar") + result[k] = _convert_element(str(v), ft) if v is not None else None + return result + except json.JSONDecodeError: + pass + + inner = value[1:-1].strip() + if not inner: + return {} + + if "=" in inner: + # Named struct + pairs = _split_array_items(inner) + result = {} + field_index = 0 + for pair in pairs: + if "=" not in pair: + continue + k, v = pair.split("=", 1) + k = k.strip() + v = v.strip() + if any(char in k for char in '{}="'): + continue + + ft = _get_field_type(k, field_names, field_types, field_index) + field_index += 1 + + if v.startswith("{") and v.endswith("}"): + if ft.type_name in ("row", "struct"): + result[k] = _convert_typed_struct(v, ft) + elif ft.type_name == "map": + result[k] = _convert_typed_map(v, ft) + else: + result[k] = _to_struct(v) + else: + result[k] = _convert_element(v, ft) + return result if result else None + + # Unnamed struct + values = [v.strip() for v in inner.split(",")] + result = {} + for i, v in enumerate(values): + ft = field_types[i] if i < len(field_types) else TypeNode("varchar") + name = field_names[i] if i < len(field_names) else str(i) + result[name] = _convert_element(v, ft) + return result + + +def _get_field_type( + field_name: str, + field_names: list[str], + field_types: list[TypeNode], + field_index: int, +) -> TypeNode: + """Look up the type for a struct field by name or index. + + Tries name-based lookup first, then falls back to positional index. 
+ + Args: + field_name: Name of the field to look up. + field_names: List of known field names from the type hint. + field_types: List of corresponding field types. + field_index: Current positional index as fallback. + + Returns: + TypeNode for the field, defaulting to varchar if not found. + """ + if field_name in field_names: + idx = field_names.index(field_name) + if idx < len(field_types): + return field_types[idx] + if field_index < len(field_types): + return field_types[field_index] + return TypeNode("varchar") + + _DEFAULT_CONVERTERS: dict[str, Callable[[str | None], Any | None]] = { "boolean": _to_boolean, "tinyint": _to_int, @@ -549,7 +926,7 @@ def update(self, mappings: dict[str, Callable[[str | None], Any | None]]) -> Non self.mappings.update(mappings) @abstractmethod - def convert(self, type_: str, value: str | None) -> Any | None: + def convert(self, type_: str, value: str | None, type_hint: str | None = None) -> Any | None: raise NotImplementedError # pragma: no cover @@ -569,17 +946,43 @@ class DefaultTypeConverter(Converter): - Complex types: array, map, row/struct - JSON: json + When ``type_hint`` is provided (an Athena DDL type signature string like + ``"array(row(name varchar, age integer))"``), nested values within complex + types are converted according to the specified types instead of using + heuristic inference. 
+ Example: >>> converter = DefaultTypeConverter() >>> converter.convert('integer', '42') 42 >>> converter.convert('date', '2023-01-15') datetime.date(2023, 1, 15) + >>> converter.convert('array', '[1, 2, 3]', type_hint='array(varchar)') + ['1', '2', '3'] """ def __init__(self) -> None: super().__init__(mappings=deepcopy(_DEFAULT_CONVERTERS), default=_to_default) + self._parsed_hints: dict[str, TypeNode] = {} - def convert(self, type_: str, value: str | None) -> Any | None: + def convert(self, type_: str, value: str | None, type_hint: str | None = None) -> Any | None: + if value is None: + return None + if type_hint: + type_node = self._get_or_parse_hint(type_hint) + return _convert_value_with_type(value, type_node) converter = self.get(type_) return converter(value) + + def _get_or_parse_hint(self, type_hint: str) -> TypeNode: + """Get or parse a type hint string into a TypeNode, with caching. + + Args: + type_hint: Athena DDL type signature string. + + Returns: + Parsed TypeNode. + """ + if type_hint not in self._parsed_hints: + self._parsed_hints[type_hint] = parse_type_signature(type_hint) + return self._parsed_hints[type_hint] diff --git a/pyathena/cursor.py b/pyathena/cursor.py index 9557dc7f..c730f9bc 100644 --- a/pyathena/cursor.py +++ b/pyathena/cursor.py @@ -95,32 +95,36 @@ def execute( result_reuse_minutes: int | None = None, paramstyle: str | None = None, on_start_query_execution: Callable[[str], None] | None = None, + result_set_type_hints: dict[str, str] | None = None, **kwargs, ) -> Cursor: """Execute a SQL query. Args: - operation: SQL query string to execute - parameters: Query parameters (optional) + operation: SQL query string to execute. + parameters: Query parameters (optional). on_start_query_execution: Callback function called immediately after - start_query_execution API is called. - Function signature: (query_id: str) -> None - This allows early access to query_id for - monitoring/cancellation. 
- **kwargs: Additional execution parameters + start_query_execution API is called. + Function signature: (query_id: str) -> None + This allows early access to query_id for + monitoring/cancellation. + result_set_type_hints: Optional dictionary mapping column names to + Athena DDL type signatures for precise type conversion within + complex types. For example: + ``{"tags": "array(varchar)", "metadata": "map(varchar, integer)"}`` + **kwargs: Additional execution parameters. Returns: - Cursor: Self reference for method chaining - - Example with callback for early query ID access: - def on_execution_started(query_id): - print(f"Query execution started: {query_id}") - # Store query_id for potential cancellation from another thread - global current_query_id - current_query_id = query_id - - cursor.execute("SELECT * FROM large_table", - on_start_query_execution=on_execution_started) + Self reference for method chaining. + + Example: + >>> cursor.execute( + ... "SELECT * FROM table_with_complex_types", + ... result_set_type_hints={ + ... "tags": "array(varchar)", + ... "metadata": "map(varchar, integer)", + ... } + ... 
) """ self._reset_state() self.query_id = self._execute( @@ -150,6 +154,7 @@ def on_execution_started(query_id): query_execution, self.arraysize, self._retry_config, + result_set_type_hints=result_set_type_hints, ) else: raise OperationalError(query_execution.state_change_reason) diff --git a/pyathena/pandas/async_cursor.py b/pyathena/pandas/async_cursor.py index 48d3d217..faebe32f 100644 --- a/pyathena/pandas/async_cursor.py +++ b/pyathena/pandas/async_cursor.py @@ -118,6 +118,7 @@ def arraysize(self, value: int) -> None: def _collect_result_set( self, query_id: str, + result_set_type_hints: dict[str, str] | None = None, keep_default_na: bool = False, na_values: Iterable[str] | None = ("",), quoting: int = 1, @@ -140,6 +141,7 @@ def _collect_result_set( unload_location=unload_location, engine=kwargs.pop("engine", self._engine), chunksize=kwargs.pop("chunksize", self._chunksize), + result_set_type_hints=result_set_type_hints, **kwargs, ) @@ -154,6 +156,7 @@ def execute( result_reuse_enable: bool | None = None, result_reuse_minutes: int | None = None, paramstyle: str | None = None, + result_set_type_hints: dict[str, str] | None = None, keep_default_na: bool = False, na_values: Iterable[str] | None = ("",), quoting: int = 1, @@ -176,6 +179,7 @@ def execute( self._executor.submit( self._collect_result_set, query_id, + result_set_type_hints, keep_default_na, na_values, quoting, diff --git a/pyathena/pandas/converter.py b/pyathena/pandas/converter.py index 576d2f53..29519bb3 100644 --- a/pyathena/pandas/converter.py +++ b/pyathena/pandas/converter.py @@ -80,7 +80,7 @@ def _dtypes(self) -> dict[str, type[Any]]: } return self.__dtypes - def convert(self, type_: str, value: str | None) -> Any | None: + def convert(self, type_: str, value: str | None, type_hint: str | None = None) -> Any | None: pass @@ -103,5 +103,5 @@ def __init__(self) -> None: default=_to_default, ) - def convert(self, type_: str, value: str | None) -> Any | None: + def convert(self, type_: str, value: 
str | None, type_hint: str | None = None) -> Any | None: pass diff --git a/pyathena/pandas/cursor.py b/pyathena/pandas/cursor.py index 39b1cf32..4413fea9 100644 --- a/pyathena/pandas/cursor.py +++ b/pyathena/pandas/cursor.py @@ -153,6 +153,7 @@ def execute( na_values: Iterable[str] | None = ("",), quoting: int = 1, on_start_query_execution: Callable[[str], None] | None = None, + result_set_type_hints: dict[str, str] | None = None, **kwargs, ) -> PandasCursor: """Execute a SQL query and return results as pandas DataFrames. @@ -175,6 +176,9 @@ def execute( na_values: Additional values to treat as NA. quoting: CSV quoting behavior (pandas csv.QUOTE_* constants). on_start_query_execution: Callback called when query starts. + result_set_type_hints: Optional dictionary mapping column names to + Athena DDL type signatures for precise type conversion within + complex types. **kwargs: Additional pandas read_csv/read_parquet parameters. Returns: @@ -224,6 +228,7 @@ def execute( cache_type=kwargs.pop("cache_type", self._cache_type), max_workers=kwargs.pop("max_workers", self._max_workers), auto_optimize_chunksize=self._auto_optimize_chunksize, + result_set_type_hints=result_set_type_hints, **kwargs, ) else: diff --git a/pyathena/pandas/result_set.py b/pyathena/pandas/result_set.py index 1486b45d..c88b295d 100644 --- a/pyathena/pandas/result_set.py +++ b/pyathena/pandas/result_set.py @@ -229,6 +229,7 @@ def __init__( cache_type: str | None = None, max_workers: int = (cpu_count() or 1) * 5, auto_optimize_chunksize: bool = False, + result_set_type_hints: dict[str, str] | None = None, **kwargs, ) -> None: """Initialize AthenaPandasResultSet with pandas-specific configurations. @@ -252,6 +253,8 @@ def __init__( max_workers: Maximum worker threads for parallel operations. auto_optimize_chunksize: Enable automatic chunksize determination for large files when chunksize is None. 
+ result_set_type_hints: Optional dictionary mapping column names to + Athena DDL type signatures for precise type conversion. **kwargs: Additional arguments passed to pandas.read_csv/read_parquet. """ super().__init__( @@ -260,6 +263,7 @@ def __init__( query_execution=query_execution, arraysize=1, # Fetch one row to retrieve metadata retry_config=retry_config, + result_set_type_hints=result_set_type_hints, ) self._rows.clear() # Clear pre_fetch data self._arraysize = arraysize diff --git a/pyathena/polars/async_cursor.py b/pyathena/polars/async_cursor.py index 9349f61a..c973b736 100644 --- a/pyathena/polars/async_cursor.py +++ b/pyathena/polars/async_cursor.py @@ -161,6 +161,7 @@ def arraysize(self, value: int) -> None: def _collect_result_set( self, query_id: str, + result_set_type_hints: dict[str, str] | None = None, unload_location: str | None = None, kwargs: dict[str, Any] | None = None, ) -> AthenaPolarsResultSet: @@ -179,6 +180,7 @@ def _collect_result_set( cache_type=self._cache_type, max_workers=self._max_workers, chunksize=self._chunksize, + result_set_type_hints=result_set_type_hints, **kwargs, ) @@ -193,6 +195,7 @@ def execute( result_reuse_enable: bool | None = None, result_reuse_minutes: int | None = None, paramstyle: str | None = None, + result_set_type_hints: dict[str, str] | None = None, **kwargs, ) -> tuple[str, Future[AthenaPolarsResultSet | Any]]: """Execute a SQL query asynchronously and return results as Polars DataFrames. @@ -210,6 +213,9 @@ def execute( result_reuse_enable: Enable Athena result reuse for this query. result_reuse_minutes: Minutes to reuse cached results. paramstyle: Parameter style ('qmark' or 'pyformat'). + result_set_type_hints: Optional dictionary mapping column names to + Athena DDL type signatures for precise type conversion within + complex types. **kwargs: Additional execution parameters passed to Polars read functions. 
Returns: @@ -237,6 +243,7 @@ def execute( self._executor.submit( self._collect_result_set, query_id, + result_set_type_hints, unload_location, kwargs, ), diff --git a/pyathena/polars/converter.py b/pyathena/polars/converter.py index 627a4ffd..356deaf2 100644 --- a/pyathena/polars/converter.py +++ b/pyathena/polars/converter.py @@ -103,7 +103,7 @@ def get_dtype(self, type_: str, precision: int = 0, scale: int = 0) -> Any: return pl.Decimal(precision=precision, scale=scale) return self._types.get(type_) - def convert(self, type_: str, value: str | None) -> Any | None: + def convert(self, type_: str, value: str | None, type_hint: str | None = None) -> Any | None: converter = self.get(type_) return converter(value) @@ -127,5 +127,5 @@ def __init__(self) -> None: default=_to_default, ) - def convert(self, type_: str, value: str | None) -> Any | None: + def convert(self, type_: str, value: str | None, type_hint: str | None = None) -> Any | None: pass diff --git a/pyathena/polars/cursor.py b/pyathena/polars/cursor.py index b788feac..0b95ba8f 100644 --- a/pyathena/polars/cursor.py +++ b/pyathena/polars/cursor.py @@ -157,6 +157,7 @@ def execute( result_reuse_minutes: int | None = None, paramstyle: str | None = None, on_start_query_execution: Callable[[str], None] | None = None, + result_set_type_hints: dict[str, str] | None = None, **kwargs, ) -> PolarsCursor: """Execute a SQL query and return results as Polars DataFrames. @@ -175,6 +176,9 @@ def execute( result_reuse_minutes: Minutes to reuse cached results. paramstyle: Parameter style ('qmark' or 'pyformat'). on_start_query_execution: Callback called when query starts. + result_set_type_hints: Optional dictionary mapping column names to + Athena DDL type signatures for precise type conversion within + complex types. **kwargs: Additional execution parameters passed to Polars read functions. 
Returns: @@ -218,6 +222,7 @@ def execute( cache_type=self._cache_type, max_workers=self._max_workers, chunksize=self._chunksize, + result_set_type_hints=result_set_type_hints, **kwargs, ) else: diff --git a/pyathena/polars/result_set.py b/pyathena/polars/result_set.py index 6818b84e..aedc8e33 100644 --- a/pyathena/polars/result_set.py +++ b/pyathena/polars/result_set.py @@ -202,6 +202,7 @@ def __init__( cache_type: str | None = None, max_workers: int = (cpu_count() or 1) * 5, chunksize: int | None = None, + result_set_type_hints: dict[str, str] | None = None, **kwargs, ) -> None: """Initialize the Polars result set. @@ -220,6 +221,8 @@ def __init__( chunksize: Number of rows per chunk for memory-efficient processing. If specified, data is loaded lazily in chunks for all data access methods including fetchone(), fetchmany(), and iter_chunks(). + result_set_type_hints: Optional dictionary mapping column names to + Athena DDL type signatures for precise type conversion. **kwargs: Additional arguments passed to Polars read functions. 
""" super().__init__( @@ -228,6 +231,7 @@ def __init__( query_execution=query_execution, arraysize=1, # Fetch one row to retrieve metadata retry_config=retry_config, + result_set_type_hints=result_set_type_hints, ) self._rows.clear() # Clear pre_fetch data self._arraysize = arraysize diff --git a/pyathena/result_set.py b/pyathena/result_set.py index 7f579d06..86a71bb6 100644 --- a/pyathena/result_set.py +++ b/pyathena/result_set.py @@ -61,6 +61,7 @@ def __init__( arraysize: int, retry_config: RetryConfig, _pre_fetch: bool = True, + result_set_type_hints: dict[str, str] | None = None, ) -> None: super().__init__(arraysize=arraysize) self._connection: Connection[Any] | None = connection @@ -69,6 +70,7 @@ def __init__( if not self._query_execution: raise ProgrammingError("Required argument `query_execution` not found.") self._retry_config = retry_config + self._result_set_type_hints = result_set_type_hints self._client = connection.session.client( "s3", region_name=connection.region_name, @@ -443,10 +445,15 @@ def _get_rows( converter: Converter | None = None, ) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]: conv = converter or self._converter + hints = self._result_set_type_hints return [ tuple( [ - conv.convert(meta.get("Type"), row.get("VarCharValue")) + conv.convert( + meta.get("Type"), + row.get("VarCharValue"), + type_hint=hints.get(meta.get("Name", "")) if hints else None, + ) for meta, row in zip(metadata, rows[i].get("Data", []), strict=False) ] ) @@ -630,12 +637,17 @@ def _get_rows( converter: Converter | None = None, ) -> list[tuple[Any | None, ...] 
| dict[Any, Any | None]]: conv = converter or self._converter + hints = self._result_set_type_hints return [ self.dict_type( [ ( meta.get("Name"), - conv.convert(meta.get("Type"), row.get("VarCharValue")), + conv.convert( + meta.get("Type"), + row.get("VarCharValue"), + type_hint=hints.get(meta.get("Name", "")) if hints else None, + ), ) for meta, row in zip(metadata, rows[i].get("Data", []), strict=False) ] diff --git a/pyathena/s3fs/async_cursor.py b/pyathena/s3fs/async_cursor.py index a98c4558..c4d3c3fe 100644 --- a/pyathena/s3fs/async_cursor.py +++ b/pyathena/s3fs/async_cursor.py @@ -142,12 +142,16 @@ def arraysize(self, value: int) -> None: def _collect_result_set( self, query_id: str, + result_set_type_hints: dict[str, str] | None = None, kwargs: dict[str, Any] | None = None, ) -> AthenaS3FSResultSet: """Collect result set after query execution. Args: query_id: The Athena query execution ID. + result_set_type_hints: Optional dictionary mapping column names to + Athena DDL type signatures for precise type conversion within + complex types. kwargs: Additional keyword arguments for result set. Returns: @@ -163,6 +167,7 @@ def _collect_result_set( arraysize=self._arraysize, retry_config=self._retry_config, csv_reader=self._csv_reader, + result_set_type_hints=result_set_type_hints, **kwargs, ) @@ -177,6 +182,7 @@ def execute( result_reuse_enable: bool | None = None, result_reuse_minutes: int | None = None, paramstyle: str | None = None, + result_set_type_hints: dict[str, str] | None = None, **kwargs, ) -> tuple[str, Future[AthenaS3FSResultSet | Any]]: """Execute a SQL query asynchronously. @@ -194,6 +200,9 @@ def execute( result_reuse_enable: Enable Athena result reuse for this query. result_reuse_minutes: Minutes to reuse cached results. paramstyle: Parameter style ('qmark' or 'pyformat'). + result_set_type_hints: Optional dictionary mapping column names to + Athena DDL type signatures for precise type conversion within + complex types. 
**kwargs: Additional execution parameters. Returns: @@ -220,6 +229,7 @@ def execute( self._executor.submit( self._collect_result_set, query_id, + result_set_type_hints, kwargs, ), ) diff --git a/pyathena/s3fs/converter.py b/pyathena/s3fs/converter.py index 853f7a7d..a46e570b 100644 --- a/pyathena/s3fs/converter.py +++ b/pyathena/s3fs/converter.py @@ -44,7 +44,7 @@ def __init__(self) -> None: default=_to_default, ) - def convert(self, type_: str, value: str | None) -> Any | None: + def convert(self, type_: str, value: str | None, type_hint: str | None = None) -> Any | None: """Convert a string value to the appropriate Python type. Looks up the converter function for the given Athena type and applies @@ -53,9 +53,19 @@ def convert(self, type_: str, value: str | None) -> Any | None: Args: type_: The Athena data type name (e.g., "integer", "varchar", "date"). value: The string value to convert, or None. + type_hint: Optional Athena DDL type signature for precise complex type + conversion (e.g., "array(varchar)"). Returns: The converted Python value, or None if the input value was None. """ + if value is None: + return None + if type_hint: + from pyathena.converter import DefaultTypeConverter + + # Delegate to DefaultTypeConverter for type_hint-based conversion + dtc = DefaultTypeConverter() + return dtc.convert(type_, value, type_hint=type_hint) converter = self.get(type_) return converter(value) diff --git a/pyathena/s3fs/cursor.py b/pyathena/s3fs/cursor.py index fbb85fab..c1bf47e4 100644 --- a/pyathena/s3fs/cursor.py +++ b/pyathena/s3fs/cursor.py @@ -133,6 +133,7 @@ def execute( result_reuse_minutes: int | None = None, paramstyle: str | None = None, on_start_query_execution: Callable[[str], None] | None = None, + result_set_type_hints: dict[str, str] | None = None, **kwargs, ) -> S3FSCursor: """Execute a SQL query and return results. @@ -151,6 +152,9 @@ def execute( result_reuse_minutes: Minutes to reuse cached results. 
paramstyle: Parameter style ('qmark' or 'pyformat'). on_start_query_execution: Callback called when query starts. + result_set_type_hints: Optional dictionary mapping column names to + Athena DDL type signatures for precise type conversion within + complex types. **kwargs: Additional execution parameters. Returns: @@ -188,6 +192,7 @@ def execute( arraysize=self.arraysize, retry_config=self._retry_config, csv_reader=self._csv_reader, + result_set_type_hints=result_set_type_hints, **kwargs, ) else: diff --git a/pyathena/s3fs/result_set.py b/pyathena/s3fs/result_set.py index 4173f25b..c289605b 100644 --- a/pyathena/s3fs/result_set.py +++ b/pyathena/s3fs/result_set.py @@ -64,6 +64,7 @@ def __init__( block_size: int | None = None, csv_reader: CSVReaderType | None = None, filesystem_class: type[AbstractFileSystem] | None = None, + result_set_type_hints: dict[str, str] | None = None, **kwargs, ) -> None: super().__init__( @@ -72,6 +73,7 @@ def __init__( query_execution=query_execution, arraysize=1, # Fetch one row to retrieve metadata retry_config=retry_config, + result_set_type_hints=result_set_type_hints, ) # Save pre-fetched rows (from Athena API) in case CSV reading is not available pre_fetched_rows = list(self._rows) @@ -149,6 +151,8 @@ def _fetch(self) -> None: description = self.description if self.description else [] column_types = [d[1] for d in description] + column_names = [d[0] for d in description] + hints = self._result_set_type_hints rows_fetched = 0 while rows_fetched < self._arraysize: @@ -162,13 +166,25 @@ def _fetch(self) -> None: # DefaultCSVReader returns empty string which needs conversion if self._csv_reader_class is DefaultCSVReader: converted_row = tuple( - self._converter.convert(col_type, value if value != "" else None) - for col_type, value in zip(column_types, row, strict=False) + self._converter.convert( + col_type, + value if value != "" else None, + type_hint=hints.get(col_name) if hints else None, + ) + for col_type, col_name, value in 
zip( + column_types, column_names, row, strict=False + ) ) else: converted_row = tuple( - self._converter.convert(col_type, value) - for col_type, value in zip(column_types, row, strict=False) + self._converter.convert( + col_type, + value, + type_hint=hints.get(col_name) if hints else None, + ) + for col_type, col_name, value in zip( + column_types, column_names, row, strict=False + ) ) self._rows.append(converted_row) rows_fetched += 1 diff --git a/tests/pyathena/test_converter.py b/tests/pyathena/test_converter.py index fbf78884..8ef99833 100644 --- a/tests/pyathena/test_converter.py +++ b/tests/pyathena/test_converter.py @@ -1,6 +1,15 @@ import pytest -from pyathena.converter import DefaultTypeConverter, _to_array, _to_struct +from pyathena.converter import ( + DefaultTypeConverter, + _to_array, + _to_struct, + parse_type_signature, +) + +# ============================================================================ +# Tests for _to_struct (JSON format - unchanged by breaking change) +# ============================================================================ @pytest.mark.parametrize( @@ -25,21 +34,31 @@ def test_to_struct_json_formats(input_value, expected): assert result == expected +# ============================================================================ +# Tests for _to_struct (Athena native format - affected by _convert_value change) +# ============================================================================ + + @pytest.mark.parametrize( ("input_value", "expected"), [ - ("{a=1, b=2}", {"a": 1, "b": 2}), + # Values that were previously inferred as int/bool now stay as strings + ("{a=1, b=2}", {"a": "1", "b": "2"}), ("{}", {}), ("{name=John, city=Tokyo}", {"name": "John", "city": "Tokyo"}), - ("{Alice, 25}", {"0": "Alice", "1": 25}), - ("{John, 30, true}", {"0": "John", "1": 30, "2": True}), - ("{name=John, age=30}", {"name": "John", "age": 30}), - ("{x=1, y=2, z=3}", {"x": 1, "y": 2, "z": 3}), - ("{active=true, count=42}", {"active": True, 
"count": 42}), + ("{Alice, 25}", {"0": "Alice", "1": "25"}), + ("{John, 30, true}", {"0": "John", "1": "30", "2": "true"}), + ("{name=John, age=30}", {"name": "John", "age": "30"}), + ("{x=1, y=2, z=3}", {"x": "1", "y": "2", "z": "3"}), + ("{active=true, count=42}", {"active": "true", "count": "42"}), ], ) def test_to_struct_athena_native_formats(input_value, expected): - """Test STRUCT conversion for Athena native formats.""" + """Test STRUCT conversion for Athena native formats. + + After the breaking change, _convert_value returns strings by default + (no heuristic type inference for numbers/bools). + """ result = _to_struct(input_value) assert result == expected @@ -47,40 +66,44 @@ def test_to_struct_athena_native_formats(input_value, expected): @pytest.mark.parametrize( ("input_value", "expected"), [ - # Single level nesting (Issue #627) + # Single level nesting (Issue #627) - leaf values are strings ( "{header={stamp=2024-01-01, seq=123}, x=4.736, y=0.583}", - {"header": {"stamp": "2024-01-01", "seq": 123}, "x": 4.736, "y": 0.583}, + {"header": {"stamp": "2024-01-01", "seq": "123"}, "x": "4.736", "y": "0.583"}, ), # Double nesting ( "{outer={middle={inner=value}}, field=123}", - {"outer": {"middle": {"inner": "value"}}, "field": 123}, + {"outer": {"middle": {"inner": "value"}}, "field": "123"}, ), # Multiple nested fields ( "{pos={x=1, y=2}, vel={x=0.5, y=0.3}, timestamp=12345}", - {"pos": {"x": 1, "y": 2}, "vel": {"x": 0.5, "y": 0.3}, "timestamp": 12345}, + { + "pos": {"x": "1", "y": "2"}, + "vel": {"x": "0.5", "y": "0.3"}, + "timestamp": "12345", + }, ), # Triple nesting ( "{level1={level2={level3={value=deep}}}}", {"level1": {"level2": {"level3": {"value": "deep"}}}}, ), - # Mixed types in nested struct + # Mixed types in nested struct - values are now strings ( "{metadata={id=123, active=true, name=test}, count=5}", - {"metadata": {"id": 123, "active": True, "name": "test"}, "count": 5}, + {"metadata": {"id": "123", "active": "true", "name": "test"}, 
"count": "5"}, ), - # Nested struct with null value + # Nested struct with null value - null still becomes None ( "{data={value=null, status=ok}, flag=true}", - {"data": {"value": None, "status": "ok"}, "flag": True}, + {"data": {"value": None, "status": "ok"}, "flag": "true"}, ), # Complex nesting with multiple levels and fields ( "{a={b={c=1, d=2}, e=3}, f=4, g={h=5}}", - {"a": {"b": {"c": 1, "d": 2}, "e": 3}, "f": 4, "g": {"h": 5}}, + {"a": {"b": {"c": "1", "d": "2"}, "e": "3"}, "f": "4", "g": {"h": "5"}}, ), ], ) @@ -108,18 +131,28 @@ def test_to_struct_athena_complex_cases(input_value): ) +# ============================================================================ +# Tests for _to_map (affected by _convert_value change) +# ============================================================================ + + def test_to_map_athena_numeric_keys(): - """Test Athena map with numeric keys""" + """Test Athena map with numeric keys - values are now strings.""" from pyathena.converter import _to_map map_value = "{1=2, 3=4}" result = _to_map(map_value) - expected = {"1": 2, "3": 4} + expected = {"1": "2", "3": "4"} assert result == expected +# ============================================================================ +# Tests for _to_array (JSON format - unchanged) +# ============================================================================ + + def test_to_array_athena_numeric_elements(): - """Test Athena array with numeric elements""" + """Test Athena array with numeric elements (JSON-parseable, unchanged).""" array_value = "[1, 2, 3, 4]" result = _to_array(array_value) expected = [1, 2, 3, 4] @@ -127,49 +160,51 @@ def test_to_array_athena_numeric_elements(): def test_to_array_athena_mixed_elements(): - """Test Athena array with mixed type elements""" + """Test Athena array with mixed type elements (native format, affected by change).""" array_value = "[1, hello, true, null]" result = _to_array(array_value) - expected = [1, "hello", True, None] + # JSON parsing 
fails (unquoted hello), so native format is used + # _convert_value now returns strings + expected = ["1", "hello", "true", None] assert result == expected def test_to_array_athena_struct_elements(): - """Test Athena array with struct elements""" + """Test Athena array with struct elements (native format, affected by change).""" array_value = "[{name=John, age=30}, {name=Jane, age=25}]" result = _to_array(array_value) - expected = [{"name": "John", "age": 30}, {"name": "Jane", "age": 25}] + expected = [{"name": "John", "age": "30"}, {"name": "Jane", "age": "25"}] assert result == expected def test_to_array_athena_unnamed_struct_elements(): - """Test Athena array with unnamed struct elements""" + """Test Athena array with unnamed struct elements (native format, affected by change).""" array_value = "[{Alice, 25}, {Bob, 30}]" result = _to_array(array_value) - expected = [{"0": "Alice", "1": 25}, {"0": "Bob", "1": 30}] + expected = [{"0": "Alice", "1": "25"}, {"0": "Bob", "1": "30"}] assert result == expected @pytest.mark.parametrize( ("input_value", "expected"), [ - # Array with nested structs (Issue #627) + # Array with nested structs (Issue #627) - leaf values are strings ( "[{header={stamp=2024-01-01, seq=123}, x=4.736}]", - [{"header": {"stamp": "2024-01-01", "seq": 123}, "x": 4.736}], + [{"header": {"stamp": "2024-01-01", "seq": "123"}, "x": "4.736"}], ), # Multiple elements with nested structs ( "[{pos={x=1, y=2}, vel={x=0.5}}, {pos={x=3, y=4}, vel={x=1.5}}]", [ - {"pos": {"x": 1, "y": 2}, "vel": {"x": 0.5}}, - {"pos": {"x": 3, "y": 4}, "vel": {"x": 1.5}}, + {"pos": {"x": "1", "y": "2"}, "vel": {"x": "0.5"}}, + {"pos": {"x": "3", "y": "4"}, "vel": {"x": "1.5"}}, ], ), # Array with deeply nested structs ( "[{data={meta={id=1, active=true}}}]", - [{"data": {"meta": {"id": 1, "active": True}}}], + [{"data": {"meta": {"id": "1", "active": "true"}}}], ), ], ) @@ -197,18 +232,10 @@ def test_to_struct_non_dict_json(input_value): ("input_value", "expected"), [ (None, 
None), - ( - "[1, 2, 3, 4, 5]", - [1, 2, 3, 4, 5], - ), - ( - '["apple", "banana", "cherry"]', - ["apple", "banana", "cherry"], - ), - ( - "[true, false, null]", - [True, False, None], - ), + # JSON-parseable arrays keep JSON types + ("[1, 2, 3, 4, 5]", [1, 2, 3, 4, 5]), + ('["apple", "banana", "cherry"]', ["apple", "banana", "cherry"]), + ("[true, false, null]", [True, False, None]), ( '[{"name": "John", "age": 30}, {"name": "Jane", "age": 25}]', [{"name": "John", "age": 30}, {"name": "Jane", "age": 25}], @@ -227,16 +254,22 @@ def test_to_array_json_formats(input_value, expected): @pytest.mark.parametrize( ("input_value", "expected"), [ + # JSON-parseable: keep JSON types ("[1, 2, 3]", [1, 2, 3]), ("[]", []), + ("[true, false, null]", [True, False, None]), + # Native format: strings (no heuristic inference) ("[apple, banana, cherry]", ["apple", "banana", "cherry"]), - ("[{Alice, 25}, {Bob, 30}]", [{"0": "Alice", "1": 25}, {"0": "Bob", "1": 30}]), + ( + "[{Alice, 25}, {Bob, 30}]", + [{"0": "Alice", "1": "25"}, {"0": "Bob", "1": "30"}], + ), ( "[{name=John, age=30}, {name=Jane, age=25}]", - [{"name": "John", "age": 30}, {"name": "Jane", "age": 25}], + [{"name": "John", "age": "30"}, {"name": "Jane", "age": "25"}], ), - ("[true, false, null]", [True, False, None]), - ("[1, 2.5, hello]", [1, 2.5, "hello"]), + # Mixed native: numbers and strings stay as strings + ("[1, 2.5, hello]", ["1", "2.5", "hello"]), ], ) def test_to_array_athena_native_formats(input_value, expected): @@ -288,15 +321,22 @@ def test_to_array_invalid_formats(input_value): assert result is None +# ============================================================================ +# Tests for DefaultTypeConverter (without type_hint) +# ============================================================================ + + class TestDefaultTypeConverter: @pytest.mark.parametrize( ("input_value", "expected"), [ + # JSON format keeps JSON types ('{"name": "Alice", "age": 25}', {"name": "Alice", "age": 25}), (None, 
None), ("", None), ("invalid json", None), - ("{a=1, b=2}", {"a": 1, "b": 2}), + # Native format: values are strings (breaking change) + ("{a=1, b=2}", {"a": "1", "b": "2"}), ], ) def test_struct_conversion(self, input_value, expected): @@ -308,11 +348,13 @@ def test_struct_conversion(self, input_value, expected): @pytest.mark.parametrize( ("input_value", "expected"), [ + # JSON format keeps JSON types ("[1, 2, 3]", [1, 2, 3]), ('["a", "b", "c"]', ["a", "b", "c"]), (None, None), ("", None), ("invalid json", None), + # Native format: values are strings ("[apple, banana]", ["apple", "banana"]), ("[]", []), ], @@ -322,3 +364,221 @@ def test_array_conversion(self, input_value, expected): converter = DefaultTypeConverter() result = converter.convert("array", input_value) assert result == expected + + +# ============================================================================ +# Tests for parse_type_signature +# ============================================================================ + + +class TestParseTypeSignature: + def test_simple_type(self): + node = parse_type_signature("varchar") + assert node.type_name == "varchar" + assert node.children == [] + assert node.field_names is None + + def test_simple_type_case_insensitive(self): + node = parse_type_signature("VARCHAR") + assert node.type_name == "varchar" + + def test_array_type(self): + node = parse_type_signature("array(varchar)") + assert node.type_name == "array" + assert len(node.children) == 1 + assert node.children[0].type_name == "varchar" + + def test_array_of_integer(self): + node = parse_type_signature("array(integer)") + assert node.type_name == "array" + assert node.children[0].type_name == "integer" + + def test_map_type(self): + node = parse_type_signature("map(varchar, integer)") + assert node.type_name == "map" + assert len(node.children) == 2 + assert node.children[0].type_name == "varchar" + assert node.children[1].type_name == "integer" + + def test_row_type(self): + node = 
parse_type_signature("row(name varchar, age integer)") + assert node.type_name == "row" + assert len(node.children) == 2 + assert node.field_names == ["name", "age"] + assert node.children[0].type_name == "varchar" + assert node.children[1].type_name == "integer" + + def test_struct_type(self): + node = parse_type_signature("struct(name varchar, age integer)") + assert node.type_name == "struct" + assert node.field_names == ["name", "age"] + + def test_nested_array_of_row(self): + node = parse_type_signature("array(row(name varchar, age integer))") + assert node.type_name == "array" + assert len(node.children) == 1 + row_node = node.children[0] + assert row_node.type_name == "row" + assert row_node.field_names == ["name", "age"] + assert row_node.children[0].type_name == "varchar" + assert row_node.children[1].type_name == "integer" + + def test_map_with_complex_value(self): + node = parse_type_signature("map(varchar, row(x integer, y double))") + assert node.type_name == "map" + assert node.children[0].type_name == "varchar" + assert node.children[1].type_name == "row" + assert node.children[1].field_names == ["x", "y"] + + def test_deeply_nested(self): + node = parse_type_signature("array(row(data row(x integer, y integer), name varchar))") + assert node.type_name == "array" + row_node = node.children[0] + assert row_node.type_name == "row" + assert row_node.field_names == ["data", "name"] + assert row_node.children[0].type_name == "row" + assert row_node.children[0].field_names == ["x", "y"] + assert row_node.children[1].type_name == "varchar" + + def test_parameterized_type(self): + node = parse_type_signature("decimal(10, 2)") + assert node.type_name == "decimal" + + def test_varchar_with_length(self): + node = parse_type_signature("varchar(255)") + assert node.type_name == "varchar" + + +# ============================================================================ +# Tests for typed conversion with type_hint +# 
============================================================================ + + +class TestTypedConversion: + def test_array_varchar_keeps_strings(self): + """Test that array(varchar) type hint keeps elements as strings.""" + converter = DefaultTypeConverter() + # JSON format: [1234, 5678] would be parsed as ints by JSON, + # but type_hint says varchar, so they should be strings + result = converter.convert("array", "[1234, 5678]", type_hint="array(varchar)") + assert result == ["1234", "5678"] + + def test_array_integer_converts_to_int(self): + """Test that array(integer) type hint converts elements to ints.""" + converter = DefaultTypeConverter() + result = converter.convert("array", "[1, 2, 3]", type_hint="array(integer)") + assert result == [1, 2, 3] + + def test_array_boolean_converts(self): + """Test that array(boolean) type hint converts elements to bools.""" + converter = DefaultTypeConverter() + result = converter.convert("array", "[true, false]", type_hint="array(boolean)") + assert result == [True, False] + + def test_array_with_null(self): + """Test that nulls in arrays are preserved regardless of type hint.""" + converter = DefaultTypeConverter() + result = converter.convert("array", "[1, null, 3]", type_hint="array(integer)") + assert result == [1, None, 3] + + def test_map_varchar_integer(self): + """Test map(varchar, integer) type hint.""" + converter = DefaultTypeConverter() + result = converter.convert( + "map", '{"key1": 1, "key2": 2}', type_hint="map(varchar, integer)" + ) + assert result == {"key1": 1, "key2": 2} + + def test_map_native_format_with_hints(self): + """Test map type hint with native format.""" + converter = DefaultTypeConverter() + result = converter.convert("map", "{a=1, b=2}", type_hint="map(varchar, integer)") + assert result == {"a": 1, "b": 2} + + def test_row_type_hint(self): + """Test row type hint with named fields.""" + converter = DefaultTypeConverter() + result = converter.convert( + "row", + '{"name": "Alice", "age": 
25}', + type_hint="row(name varchar, age integer)", + ) + assert result == {"name": "Alice", "age": 25} + + def test_row_native_format_with_hints(self): + """Test row type hint with native format.""" + converter = DefaultTypeConverter() + result = converter.convert( + "row", + "{name=Alice, age=25}", + type_hint="row(name varchar, age integer)", + ) + assert result == {"name": "Alice", "age": 25} + + def test_nested_array_of_row(self): + """Test array(row(...)) type hint preserves correct types.""" + converter = DefaultTypeConverter() + result = converter.convert( + "array", + "[{name=Alice, age=25}, {name=Bob, age=30}]", + type_hint="array(row(name varchar, age integer))", + ) + assert result == [ + {"name": "Alice", "age": 25}, + {"name": "Bob", "age": 30}, + ] + + def test_array_varchar_prevents_number_inference(self): + """Test the core use case: array(varchar) prevents "1234" -> 1234.""" + converter = DefaultTypeConverter() + # This is the key issue from #689: varchar values like "1234" should + # not be converted to numbers when type_hint specifies varchar + result = converter.convert( + "array", + "[1234, 5678, hello]", + type_hint="array(varchar)", + ) + assert result == ["1234", "5678", "hello"] + + def test_none_value_with_type_hint(self): + """Test that None value returns None even with type hint.""" + converter = DefaultTypeConverter() + result = converter.convert("array", None, type_hint="array(varchar)") + assert result is None + + def test_simple_type_hint(self): + """Test type hint for simple types.""" + converter = DefaultTypeConverter() + result = converter.convert("varchar", "hello", type_hint="varchar") + assert result == "hello" + + def test_type_hint_caching(self): + """Test that parsed type hints are cached.""" + converter = DefaultTypeConverter() + converter.convert("array", "[1, 2]", type_hint="array(integer)") + assert "array(integer)" in converter._parsed_hints + # Second call should use cache + converter.convert("array", "[3, 4]", 
type_hint="array(integer)") + assert len(converter._parsed_hints) == 1 + + def test_empty_array_with_type_hint(self): + """Test empty array with type hint.""" + converter = DefaultTypeConverter() + result = converter.convert("array", "[]", type_hint="array(varchar)") + assert result == [] + + def test_map_varchar_varchar(self): + """Test map(varchar, varchar) keeps all values as strings.""" + converter = DefaultTypeConverter() + result = converter.convert("map", "{key1=123, key2=456}", type_hint="map(varchar, varchar)") + assert result == {"key1": "123", "key2": "456"} + + def test_row_with_nested_struct(self): + """Test row with nested struct field.""" + converter = DefaultTypeConverter() + result = converter.convert( + "row", + "{header={seq=123, stamp=2024}, x=4.5}", + type_hint="row(header row(seq integer, stamp varchar), x double)", + ) + assert result == {"header": {"seq": 123, "stamp": "2024"}, "x": 4.5} From ce3854019b7b73387980bded8f0c3759d424c399 Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Sat, 28 Feb 2026 13:14:57 +0900 Subject: [PATCH 02/18] Add pull request template Co-Authored-By: Claude Opus 4.6 --- .github/PULL_REQUEST_TEMPLATE.md | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 .github/PULL_REQUEST_TEMPLATE.md diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 00000000..4e09ff5b --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,5 @@ +## WHAT + + +## WHY + From 5d057916ac767132fe2cfe6c23683ab940437bc0 Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Sat, 28 Feb 2026 13:37:22 +0900 Subject: [PATCH 03/18] Refactor: Extract type parsing and typed conversion into parser.py Move TypeNode, TypeSignatureParser, and TypedValueConverter into a new pyathena/parser.py module. TypedValueConverter receives converter dependencies via constructor injection to avoid circular imports. Also moves _split_array_items to parser.py as a shared parsing utility. 
Co-Authored-By: Claude Opus 4.6 --- pyathena/converter.py | 430 +-------------------------- pyathena/parser.py | 436 +++++++++++++++++++++++++++ tests/pyathena/test_converter.py | 487 +++++++++++++------------------ 3 files changed, 643 insertions(+), 710 deletions(-) create mode 100644 pyathena/parser.py diff --git a/pyathena/converter.py b/pyathena/converter.py index b83500b4..a95db4e2 100644 --- a/pyathena/converter.py +++ b/pyathena/converter.py @@ -6,13 +6,13 @@ from abc import ABCMeta, abstractmethod from collections.abc import Callable from copy import deepcopy -from dataclasses import dataclass, field from datetime import date, datetime, time from decimal import Decimal from typing import Any from dateutil.tz import gettz +from pyathena.parser import TypedValueConverter, TypeNode, TypeSignatureParser, _split_array_items from pyathena.util import strtobool _logger = logging.getLogger(__name__) @@ -267,44 +267,6 @@ def _parse_array_native(inner: str) -> list[Any] | None: return result if result else None -def _split_array_items(inner: str) -> list[str]: - """Split array items by comma, respecting brace and bracket groupings. - - Args: - inner: Interior content of array without brackets. - - Returns: - List of item strings. - """ - items = [] - current_item = "" - brace_depth = 0 - bracket_depth = 0 - - for char in inner: - if char == "{": - brace_depth += 1 - elif char == "}": - brace_depth -= 1 - elif char == "[": - bracket_depth += 1 - elif char == "]": - bracket_depth -= 1 - elif char == "," and brace_depth == 0 and bracket_depth == 0: - # Top-level comma - end current item - items.append(current_item.strip()) - current_item = "" - continue - - current_item += char - - # Add the last item - if current_item.strip(): - items.append(current_item.strip()) - - return items - - def _parse_map_native(inner: str) -> dict[str, Any] | None: """Parse map native format: key1=value1, key2=value2. 
@@ -402,7 +364,7 @@ def _convert_value(value: str) -> Any: This is a safe default that avoids incorrect type conversions (e.g., converting varchar "1234" to int 1234 inside complex types). - Use :func:`_convert_value_with_type` for type-aware conversion. + Use :class:`~pyathena.parser.TypedValueConverter` for type-aware conversion. Args: value: String value to convert. @@ -419,384 +381,6 @@ def _to_default(varchar_value: str | None) -> str | None: return varchar_value -@dataclass -class TypeNode: - """Parsed representation of an Athena DDL type signature. - - Represents a node in a type tree, where complex types (array, map, row) - have children representing their element/field types. - - Attributes: - type_name: The base type name (e.g., "array", "map", "row", "varchar"). - children: Child type nodes for complex types. - field_names: Field names for row/struct types (parallel to children). - """ - - type_name: str - children: list[TypeNode] = field(default_factory=list) - field_names: list[str] | None = None - - -def _split_top_level(s: str, delimiter: str = ",") -> list[str]: - """Split a string by delimiter, respecting nested parentheses. - - Args: - s: String to split. - delimiter: Character to split on. - - Returns: - List of parts split at top-level delimiters. - """ - parts: list[str] = [] - current: list[str] = [] - depth = 0 - - for char in s: - if char == "(": - depth += 1 - elif char == ")": - depth -= 1 - elif char == delimiter and depth == 0: - parts.append("".join(current).strip()) - current = [] - continue - current.append(char) - - if current: - parts.append("".join(current).strip()) - return parts - - -def parse_type_signature(type_str: str) -> TypeNode: - """Parse an Athena DDL type signature string into a TypeNode tree. - - Handles simple types (varchar, integer), parameterized types (decimal(10,2)), - and complex types (array, map, row/struct) with arbitrary nesting. 
- - Args: - type_str: Athena DDL type string (e.g., "array(row(name varchar, age integer))"). - - Returns: - TypeNode representing the parsed type tree. - """ - type_str = type_str.strip() - - paren_idx = type_str.find("(") - if paren_idx == -1: - return TypeNode(type_name=type_str.lower()) - - type_name = type_str[:paren_idx].strip().lower() - - # Find matching closing paren - inner = type_str[paren_idx + 1 : -1].strip() - - if type_name in ("row", "struct"): - parts = _split_top_level(inner) - field_names: list[str] = [] - children: list[TypeNode] = [] - for part in parts: - part = part.strip() - # Split into field_name and type at first space - space_idx = _find_field_name_boundary(part) - if space_idx == -1: - children.append(parse_type_signature(part)) - field_names.append(part) - else: - field_name = part[:space_idx].strip() - type_part = part[space_idx + 1 :].strip() - field_names.append(field_name) - children.append(parse_type_signature(type_part)) - return TypeNode(type_name=type_name, children=children, field_names=field_names) - - if type_name == "array": - child = parse_type_signature(inner) - return TypeNode(type_name=type_name, children=[child]) - - if type_name == "map": - parts = _split_top_level(inner) - if len(parts) == 2: - key_type = parse_type_signature(parts[0]) - value_type = parse_type_signature(parts[1]) - return TypeNode(type_name=type_name, children=[key_type, value_type]) - return TypeNode(type_name=type_name) - - # Types with parameters like decimal(10, 2), varchar(255) - return TypeNode(type_name=type_name) - - -def _find_field_name_boundary(part: str) -> int: - """Find the boundary between field name and type in a row field definition. - - Handles cases like "name varchar" and "data row(x integer, y integer)". - - Args: - part: A single field definition string. - - Returns: - Index of the space separating field name from type, or -1 if not found. 
- """ - depth = 0 - for i, char in enumerate(part): - if char == "(": - depth += 1 - elif char == ")": - depth -= 1 - elif char == " " and depth == 0: - return i - return -1 - - -def _convert_value_with_type(value: str, type_node: TypeNode) -> Any: - """Convert a value using type information from a TypeNode. - - For complex types (array, map, row), parses the structure and - recursively converts elements using child type information. - For simple types, uses the standard converter function. - - Args: - value: String value to convert. - type_node: Parsed type information. - - Returns: - Converted value. - """ - if type_node.type_name == "array": - return _convert_typed_array(value, type_node) - if type_node.type_name == "map": - return _convert_typed_map(value, type_node) - if type_node.type_name in ("row", "struct"): - return _convert_typed_struct(value, type_node) - # Simple type: use the standard converter - converter_fn = _DEFAULT_CONVERTERS.get(type_node.type_name, _to_default) - return converter_fn(value) - - -def _convert_element(value: str, type_node: TypeNode) -> Any: - """Convert a single element within a complex type using type information. - - Handles null values before delegating to type-specific conversion. - - Args: - value: String value to convert. - type_node: Type information for this element. - - Returns: - Converted value, or None for null. - """ - if value.lower() == "null": - return None - return _convert_value_with_type(value, type_node) - - -def _convert_typed_array(value: str, type_node: TypeNode) -> list[Any] | None: - """Convert an array value using type information. - - Args: - value: String representation of the array. - type_node: Type node with array element type as first child. - - Returns: - List of converted elements, or None if parsing fails. 
- """ - if not (value.startswith("[") and value.endswith("]")): - return None - - element_type = type_node.children[0] if type_node.children else TypeNode("varchar") - - # Try JSON first - try: - parsed = json.loads(value) - if isinstance(parsed, list): - return [ - None if elem is None else _convert_element(str(elem), element_type) - for elem in parsed - ] - except json.JSONDecodeError: - pass - - # Native format - inner = value[1:-1].strip() - if not inner: - return [] - - if "[" in inner: - return None # Nested arrays not supported in native format - - items = _split_array_items(inner) - result: list[Any] = [] - for item in items: - item = item.strip() - if not item: - continue - if item.startswith("{") and item.endswith("}"): - if element_type.type_name in ("row", "struct"): - result.append(_convert_typed_struct(item, element_type)) - elif element_type.type_name == "map": - result.append(_convert_typed_map(item, element_type)) - else: - result.append(_to_struct(item)) - else: - result.append(_convert_element(item, element_type)) - - return result if result else None - - -def _convert_typed_map(value: str, type_node: TypeNode) -> dict[str, Any] | None: - """Convert a map value using type information. - - Args: - value: String representation of the map. - type_node: Type node with key type and value type as children. - - Returns: - Dictionary of converted key-value pairs, or None if parsing fails. 
- """ - if not (value.startswith("{") and value.endswith("}")): - return None - - key_type = type_node.children[0] if len(type_node.children) > 0 else TypeNode("varchar") - value_type = type_node.children[1] if len(type_node.children) > 1 else TypeNode("varchar") - - # Try JSON first - inner_preview = value[1:10] if len(value) > 10 else value[1:-1] - if '"' in inner_preview or value.startswith('{"'): - try: - parsed = json.loads(value) - if isinstance(parsed, dict): - return { - str( - _convert_element(str(k), key_type) if k is not None else k - ): _convert_element(str(v), value_type) if v is not None else None - for k, v in parsed.items() - } - except json.JSONDecodeError: - pass - - # Native format - inner = value[1:-1].strip() - if not inner: - return {} - - if any(char in inner for char in "()[]"): - return None - - pairs = [pair.strip() for pair in inner.split(",")] - result: dict[str, Any] = {} - for pair in pairs: - if "=" not in pair: - continue - k, v = pair.split("=", 1) - k = k.strip() - v = v.strip() - if any(char in k for char in '{}="') or any(char in v for char in '{}="'): - continue - converted_key = _convert_element(k, key_type) - converted_value = _convert_element(v, value_type) - result[str(converted_key)] = converted_value - - return result if result else None - - -def _convert_typed_struct(value: str, type_node: TypeNode) -> dict[str, Any] | None: - """Convert a struct/row value using type information. - - Args: - value: String representation of the struct. - type_node: Type node with field types and names. - - Returns: - Dictionary of converted field values, or None if parsing fails. 
- """ - if not (value.startswith("{") and value.endswith("}")): - return None - - field_names = type_node.field_names or [] - field_types = type_node.children or [] - - # Try JSON first - inner_preview = value[1:10] if len(value) > 10 else value[1:-1] - if '"' in inner_preview or value.startswith('{"'): - try: - parsed = json.loads(value) - if isinstance(parsed, dict): - result: dict[str, Any] = {} - for i, (k, v) in enumerate(parsed.items()): - ft = field_types[i] if i < len(field_types) else TypeNode("varchar") - result[k] = _convert_element(str(v), ft) if v is not None else None - return result - except json.JSONDecodeError: - pass - - inner = value[1:-1].strip() - if not inner: - return {} - - if "=" in inner: - # Named struct - pairs = _split_array_items(inner) - result = {} - field_index = 0 - for pair in pairs: - if "=" not in pair: - continue - k, v = pair.split("=", 1) - k = k.strip() - v = v.strip() - if any(char in k for char in '{}="'): - continue - - ft = _get_field_type(k, field_names, field_types, field_index) - field_index += 1 - - if v.startswith("{") and v.endswith("}"): - if ft.type_name in ("row", "struct"): - result[k] = _convert_typed_struct(v, ft) - elif ft.type_name == "map": - result[k] = _convert_typed_map(v, ft) - else: - result[k] = _to_struct(v) - else: - result[k] = _convert_element(v, ft) - return result if result else None - - # Unnamed struct - values = [v.strip() for v in inner.split(",")] - result = {} - for i, v in enumerate(values): - ft = field_types[i] if i < len(field_types) else TypeNode("varchar") - name = field_names[i] if i < len(field_names) else str(i) - result[name] = _convert_element(v, ft) - return result - - -def _get_field_type( - field_name: str, - field_names: list[str], - field_types: list[TypeNode], - field_index: int, -) -> TypeNode: - """Look up the type for a struct field by name or index. - - Tries name-based lookup first, then falls back to positional index. 
- - Args: - field_name: Name of the field to look up. - field_names: List of known field names from the type hint. - field_types: List of corresponding field types. - field_index: Current positional index as fallback. - - Returns: - TypeNode for the field, defaulting to varchar if not found. - """ - if field_name in field_names: - idx = field_names.index(field_name) - if idx < len(field_types): - return field_types[idx] - if field_index < len(field_types): - return field_types[field_index] - return TypeNode("varchar") - - _DEFAULT_CONVERTERS: dict[str, Callable[[str | None], Any | None]] = { "boolean": _to_boolean, "tinyint": _to_int, @@ -963,6 +547,12 @@ class DefaultTypeConverter(Converter): def __init__(self) -> None: super().__init__(mappings=deepcopy(_DEFAULT_CONVERTERS), default=_to_default) + self._parser = TypeSignatureParser() + self._typed_converter = TypedValueConverter( + converters=_DEFAULT_CONVERTERS, + default_converter=_to_default, + struct_parser=_to_struct, + ) self._parsed_hints: dict[str, TypeNode] = {} def convert(self, type_: str, value: str | None, type_hint: str | None = None) -> Any | None: @@ -970,7 +560,7 @@ def convert(self, type_: str, value: str | None, type_hint: str | None = None) - return None if type_hint: type_node = self._get_or_parse_hint(type_hint) - return _convert_value_with_type(value, type_node) + return self._typed_converter.convert(value, type_node) converter = self.get(type_) return converter(value) @@ -984,5 +574,5 @@ def _get_or_parse_hint(self, type_hint: str) -> TypeNode: Parsed TypeNode. 
""" if type_hint not in self._parsed_hints: - self._parsed_hints[type_hint] = parse_type_signature(type_hint) + self._parsed_hints[type_hint] = self._parser.parse(type_hint) return self._parsed_hints[type_hint] diff --git a/pyathena/parser.py b/pyathena/parser.py new file mode 100644 index 00000000..0ef39313 --- /dev/null +++ b/pyathena/parser.py @@ -0,0 +1,436 @@ +from __future__ import annotations + +import json +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any + + +@dataclass +class TypeNode: + """Parsed representation of an Athena DDL type signature. + + Represents a node in a type tree, where complex types (array, map, row) + have children representing their element/field types. + + Attributes: + type_name: The base type name (e.g., "array", "map", "row", "varchar"). + children: Child type nodes for complex types. + field_names: Field names for row/struct types (parallel to children). + """ + + type_name: str + children: list[TypeNode] = field(default_factory=list) + field_names: list[str] | None = None + + +def _split_array_items(inner: str) -> list[str]: + """Split array items by comma, respecting brace and bracket groupings. + + Args: + inner: Interior content of array without brackets. + + Returns: + List of item strings. 
+ """ + items: list[str] = [] + current_item = "" + brace_depth = 0 + bracket_depth = 0 + + for char in inner: + if char == "{": + brace_depth += 1 + elif char == "}": + brace_depth -= 1 + elif char == "[": + bracket_depth += 1 + elif char == "]": + bracket_depth -= 1 + elif char == "," and brace_depth == 0 and bracket_depth == 0: + items.append(current_item.strip()) + current_item = "" + continue + + current_item += char + + if current_item.strip(): + items.append(current_item.strip()) + + return items + + +class TypeSignatureParser: + """Parse Athena DDL type signature strings into a type tree.""" + + def parse(self, type_str: str) -> TypeNode: + """Parse an Athena DDL type signature string into a TypeNode tree. + + Handles simple types (varchar, integer), parameterized types (decimal(10,2)), + and complex types (array, map, row/struct) with arbitrary nesting. + + Args: + type_str: Athena DDL type string (e.g., "array(row(name varchar, age integer))"). + + Returns: + TypeNode representing the parsed type tree. 
+ """ + type_str = type_str.strip() + + paren_idx = type_str.find("(") + if paren_idx == -1: + return TypeNode(type_name=type_str.lower()) + + type_name = type_str[:paren_idx].strip().lower() + + inner = type_str[paren_idx + 1 : -1].strip() + + if type_name in ("row", "struct"): + parts = self._split_top_level(inner) + field_names: list[str] = [] + children: list[TypeNode] = [] + for part in parts: + part = part.strip() + space_idx = self._find_field_name_boundary(part) + if space_idx == -1: + children.append(self.parse(part)) + field_names.append(part) + else: + field_name = part[:space_idx].strip() + type_part = part[space_idx + 1 :].strip() + field_names.append(field_name) + children.append(self.parse(type_part)) + return TypeNode(type_name=type_name, children=children, field_names=field_names) + + if type_name == "array": + child = self.parse(inner) + return TypeNode(type_name=type_name, children=[child]) + + if type_name == "map": + parts = self._split_top_level(inner) + if len(parts) == 2: + key_type = self.parse(parts[0]) + value_type = self.parse(parts[1]) + return TypeNode(type_name=type_name, children=[key_type, value_type]) + return TypeNode(type_name=type_name) + + # Types with parameters like decimal(10, 2), varchar(255) + return TypeNode(type_name=type_name) + + def _split_top_level(self, s: str, delimiter: str = ",") -> list[str]: + """Split a string by delimiter, respecting nested parentheses. + + Args: + s: String to split. + delimiter: Character to split on. + + Returns: + List of parts split at top-level delimiters. 
+ """ + parts: list[str] = [] + current: list[str] = [] + depth = 0 + + for char in s: + if char == "(": + depth += 1 + elif char == ")": + depth -= 1 + elif char == delimiter and depth == 0: + parts.append("".join(current).strip()) + current = [] + continue + current.append(char) + + if current: + parts.append("".join(current).strip()) + return parts + + def _find_field_name_boundary(self, part: str) -> int: + """Find the boundary between field name and type in a row field definition. + + Handles cases like "name varchar" and "data row(x integer, y integer)". + + Args: + part: A single field definition string. + + Returns: + Index of the space separating field name from type, or -1 if not found. + """ + depth = 0 + for i, char in enumerate(part): + if char == "(": + depth += 1 + elif char == ")": + depth -= 1 + elif char == " " and depth == 0: + return i + return -1 + + +class TypedValueConverter: + """Convert values using TypeNode type information. + + Dependencies are injected via the constructor to avoid circular imports + between parser.py and converter.py. + + Args: + converters: Mapping of type names to conversion functions. + default_converter: Fallback conversion function for unknown types. + struct_parser: Function to parse untyped struct values. + """ + + def __init__( + self, + converters: dict[str, Callable[[str | None], Any | None]], + default_converter: Callable[[str | None], Any | None], + struct_parser: Callable[[str | None], dict[str, Any] | None], + ) -> None: + self._converters = converters + self._default_converter = default_converter + self._struct_parser = struct_parser + + def convert(self, value: str, type_node: TypeNode) -> Any: + """Convert a value using type information from a TypeNode. + + For complex types (array, map, row), parses the structure and + recursively converts elements using child type information. + For simple types, uses the standard converter function. + + Args: + value: String value to convert. 
+ type_node: Parsed type information. + + Returns: + Converted value. + """ + if type_node.type_name == "array": + return self._convert_typed_array(value, type_node) + if type_node.type_name == "map": + return self._convert_typed_map(value, type_node) + if type_node.type_name in ("row", "struct"): + return self._convert_typed_struct(value, type_node) + converter_fn = self._converters.get(type_node.type_name, self._default_converter) + return converter_fn(value) + + def _convert_element(self, value: str, type_node: TypeNode) -> Any: + """Convert a single element within a complex type using type information. + + Handles null values before delegating to type-specific conversion. + + Args: + value: String value to convert. + type_node: Type information for this element. + + Returns: + Converted value, or None for null. + """ + if value.lower() == "null": + return None + return self.convert(value, type_node) + + def _convert_typed_array(self, value: str, type_node: TypeNode) -> list[Any] | None: + """Convert an array value using type information. + + Args: + value: String representation of the array. + type_node: Type node with array element type as first child. + + Returns: + List of converted elements, or None if parsing fails. 
+ """ + if not (value.startswith("[") and value.endswith("]")): + return None + + element_type = type_node.children[0] if type_node.children else TypeNode("varchar") + + # Try JSON first + try: + parsed = json.loads(value) + if isinstance(parsed, list): + return [ + None if elem is None else self._convert_element(str(elem), element_type) + for elem in parsed + ] + except json.JSONDecodeError: + pass + + # Native format + inner = value[1:-1].strip() + if not inner: + return [] + + if "[" in inner: + return None # Nested arrays not supported in native format + + items = _split_array_items(inner) + result: list[Any] = [] + for item in items: + item = item.strip() + if not item: + continue + if item.startswith("{") and item.endswith("}"): + if element_type.type_name in ("row", "struct"): + result.append(self._convert_typed_struct(item, element_type)) + elif element_type.type_name == "map": + result.append(self._convert_typed_map(item, element_type)) + else: + result.append(self._struct_parser(item)) + else: + result.append(self._convert_element(item, element_type)) + + return result if result else None + + def _convert_typed_map(self, value: str, type_node: TypeNode) -> dict[str, Any] | None: + """Convert a map value using type information. + + Args: + value: String representation of the map. + type_node: Type node with key type and value type as children. + + Returns: + Dictionary of converted key-value pairs, or None if parsing fails. 
+ """ + if not (value.startswith("{") and value.endswith("}")): + return None + + key_type = type_node.children[0] if len(type_node.children) > 0 else TypeNode("varchar") + value_type = type_node.children[1] if len(type_node.children) > 1 else TypeNode("varchar") + + # Try JSON first + inner_preview = value[1:10] if len(value) > 10 else value[1:-1] + if '"' in inner_preview or value.startswith('{"'): + try: + parsed = json.loads(value) + if isinstance(parsed, dict): + return { + str( + self._convert_element(str(k), key_type) if k is not None else k + ): self._convert_element(str(v), value_type) if v is not None else None + for k, v in parsed.items() + } + except json.JSONDecodeError: + pass + + # Native format + inner = value[1:-1].strip() + if not inner: + return {} + + if any(char in inner for char in "()[]"): + return None + + pairs = [pair.strip() for pair in inner.split(",")] + result: dict[str, Any] = {} + for pair in pairs: + if "=" not in pair: + continue + k, v = pair.split("=", 1) + k = k.strip() + v = v.strip() + if any(char in k for char in '{}="') or any(char in v for char in '{}="'): + continue + converted_key = self._convert_element(k, key_type) + converted_value = self._convert_element(v, value_type) + result[str(converted_key)] = converted_value + + return result if result else None + + def _convert_typed_struct(self, value: str, type_node: TypeNode) -> dict[str, Any] | None: + """Convert a struct/row value using type information. + + Args: + value: String representation of the struct. + type_node: Type node with field types and names. + + Returns: + Dictionary of converted field values, or None if parsing fails. 
+ """ + if not (value.startswith("{") and value.endswith("}")): + return None + + field_names = type_node.field_names or [] + field_types = type_node.children or [] + + # Try JSON first + inner_preview = value[1:10] if len(value) > 10 else value[1:-1] + if '"' in inner_preview or value.startswith('{"'): + try: + parsed = json.loads(value) + if isinstance(parsed, dict): + result: dict[str, Any] = {} + for i, (k, v) in enumerate(parsed.items()): + ft = field_types[i] if i < len(field_types) else TypeNode("varchar") + result[k] = self._convert_element(str(v), ft) if v is not None else None + return result + except json.JSONDecodeError: + pass + + inner = value[1:-1].strip() + if not inner: + return {} + + if "=" in inner: + # Named struct + pairs = _split_array_items(inner) + result = {} + field_index = 0 + for pair in pairs: + if "=" not in pair: + continue + k, v = pair.split("=", 1) + k = k.strip() + v = v.strip() + if any(char in k for char in '{}="'): + continue + + ft = self._get_field_type(k, field_names, field_types, field_index) + field_index += 1 + + if v.startswith("{") and v.endswith("}"): + if ft.type_name in ("row", "struct"): + result[k] = self._convert_typed_struct(v, ft) + elif ft.type_name == "map": + result[k] = self._convert_typed_map(v, ft) + else: + result[k] = self._struct_parser(v) + else: + result[k] = self._convert_element(v, ft) + return result if result else None + + # Unnamed struct + values = [v.strip() for v in inner.split(",")] + result = {} + for i, v in enumerate(values): + ft = field_types[i] if i < len(field_types) else TypeNode("varchar") + name = field_names[i] if i < len(field_names) else str(i) + result[name] = self._convert_element(v, ft) + return result + + def _get_field_type( + self, + field_name: str, + field_names: list[str], + field_types: list[TypeNode], + field_index: int, + ) -> TypeNode: + """Look up the type for a struct field by name or index. + + Tries name-based lookup first, then falls back to positional index. 
+ + Args: + field_name: Name of the field to look up. + field_names: List of known field names from the type hint. + field_types: List of corresponding field types. + field_index: Current positional index as fallback. + + Returns: + TypeNode for the field, defaulting to varchar if not found. + """ + if field_name in field_names: + idx = field_names.index(field_name) + if idx < len(field_types): + return field_types[idx] + if field_index < len(field_types): + return field_types[field_index] + return TypeNode("varchar") diff --git a/tests/pyathena/test_converter.py b/tests/pyathena/test_converter.py index 8ef99833..9c19349f 100644 --- a/tests/pyathena/test_converter.py +++ b/tests/pyathena/test_converter.py @@ -1,15 +1,14 @@ import pytest from pyathena.converter import ( + _DEFAULT_CONVERTERS, DefaultTypeConverter, _to_array, + _to_default, + _to_map, _to_struct, - parse_type_signature, ) - -# ============================================================================ -# Tests for _to_struct (JSON format - unchanged by breaking change) -# ============================================================================ +from pyathena.parser import TypedValueConverter, TypeNode, TypeSignatureParser @pytest.mark.parametrize( @@ -29,20 +28,12 @@ ], ) def test_to_struct_json_formats(input_value, expected): - """Test STRUCT conversion for various JSON formats and edge cases.""" - result = _to_struct(input_value) - assert result == expected - - -# ============================================================================ -# Tests for _to_struct (Athena native format - affected by _convert_value change) -# ============================================================================ + assert _to_struct(input_value) == expected @pytest.mark.parametrize( ("input_value", "expected"), [ - # Values that were previously inferred as int/bool now stay as strings ("{a=1, b=2}", {"a": "1", "b": "2"}), ("{}", {}), ("{name=John, city=Tokyo}", {"name": "John", "city": "Tokyo"}), @@ -54,29 
+45,20 @@ def test_to_struct_json_formats(input_value, expected): ], ) def test_to_struct_athena_native_formats(input_value, expected): - """Test STRUCT conversion for Athena native formats. - - After the breaking change, _convert_value returns strings by default - (no heuristic type inference for numbers/bools). - """ - result = _to_struct(input_value) - assert result == expected + assert _to_struct(input_value) == expected @pytest.mark.parametrize( ("input_value", "expected"), [ - # Single level nesting (Issue #627) - leaf values are strings ( "{header={stamp=2024-01-01, seq=123}, x=4.736, y=0.583}", {"header": {"stamp": "2024-01-01", "seq": "123"}, "x": "4.736", "y": "0.583"}, ), - # Double nesting ( "{outer={middle={inner=value}}, field=123}", {"outer": {"middle": {"inner": "value"}}, "field": "123"}, ), - # Multiple nested fields ( "{pos={x=1, y=2}, vel={x=0.5, y=0.3}, timestamp=12345}", { @@ -85,22 +67,18 @@ def test_to_struct_athena_native_formats(input_value, expected): "timestamp": "12345", }, ), - # Triple nesting ( "{level1={level2={level3={value=deep}}}}", {"level1": {"level2": {"level3": {"value": "deep"}}}}, ), - # Mixed types in nested struct - values are now strings ( "{metadata={id=123, active=true, name=test}, count=5}", {"metadata": {"id": "123", "active": "true", "name": "test"}, "count": "5"}, ), - # Nested struct with null value - null still becomes None ( "{data={value=null, status=ok}, flag=true}", {"data": {"value": None, "status": "ok"}, "flag": "true"}, ), - # Complex nesting with multiple levels and fields ( "{a={b={c=1, d=2}, e=3}, f=4, g={h=5}}", {"a": {"b": {"c": "1", "d": "2"}, "e": "3"}, "f": "4", "g": {"h": "5"}}, @@ -108,131 +86,42 @@ def test_to_struct_athena_native_formats(input_value, expected): ], ) def test_to_struct_athena_nested_formats(input_value, expected): - """Test STRUCT conversion for nested struct formats (Issue #627).""" - result = _to_struct(input_value) - assert result == expected + assert _to_struct(input_value) 
== expected @pytest.mark.parametrize( "input_value", [ - "{formula=x=y+1, status=active}", # Equals in value - '{json={"key": "value"}, name=test}', # Braces in value - '{message=He said "hello", name=John}', # Quotes in value + "{formula=x=y+1, status=active}", + '{json={"key": "value"}, name=test}', + '{message=He said "hello", name=John}', ], ) def test_to_struct_athena_complex_cases(input_value): - """Test complex cases with special characters return None or partial dict (safe fallback).""" result = _to_struct(input_value) - # With the new continue logic, these may return partial results instead of None - # Check if they return None (strict safety) or partial results (lenient approach) - assert result is None or isinstance(result, dict), ( - f"Complex case should return None or dict: {input_value} -> {result}" - ) - - -# ============================================================================ -# Tests for _to_map (affected by _convert_value change) -# ============================================================================ - - -def test_to_map_athena_numeric_keys(): - """Test Athena map with numeric keys - values are now strings.""" - from pyathena.converter import _to_map - - map_value = "{1=2, 3=4}" - result = _to_map(map_value) - expected = {"1": "2", "3": "4"} - assert result == expected - - -# ============================================================================ -# Tests for _to_array (JSON format - unchanged) -# ============================================================================ - - -def test_to_array_athena_numeric_elements(): - """Test Athena array with numeric elements (JSON-parseable, unchanged).""" - array_value = "[1, 2, 3, 4]" - result = _to_array(array_value) - expected = [1, 2, 3, 4] - assert result == expected - - -def test_to_array_athena_mixed_elements(): - """Test Athena array with mixed type elements (native format, affected by change).""" - array_value = "[1, hello, true, null]" - result = _to_array(array_value) - # 
JSON parsing fails (unquoted hello), so native format is used - # _convert_value now returns strings - expected = ["1", "hello", "true", None] - assert result == expected - - -def test_to_array_athena_struct_elements(): - """Test Athena array with struct elements (native format, affected by change).""" - array_value = "[{name=John, age=30}, {name=Jane, age=25}]" - result = _to_array(array_value) - expected = [{"name": "John", "age": "30"}, {"name": "Jane", "age": "25"}] - assert result == expected - - -def test_to_array_athena_unnamed_struct_elements(): - """Test Athena array with unnamed struct elements (native format, affected by change).""" - array_value = "[{Alice, 25}, {Bob, 30}]" - result = _to_array(array_value) - expected = [{"0": "Alice", "1": "25"}, {"0": "Bob", "1": "30"}] - assert result == expected - - -@pytest.mark.parametrize( - ("input_value", "expected"), - [ - # Array with nested structs (Issue #627) - leaf values are strings - ( - "[{header={stamp=2024-01-01, seq=123}, x=4.736}]", - [{"header": {"stamp": "2024-01-01", "seq": "123"}, "x": "4.736"}], - ), - # Multiple elements with nested structs - ( - "[{pos={x=1, y=2}, vel={x=0.5}}, {pos={x=3, y=4}, vel={x=1.5}}]", - [ - {"pos": {"x": "1", "y": "2"}, "vel": {"x": "0.5"}}, - {"pos": {"x": "3", "y": "4"}, "vel": {"x": "1.5"}}, - ], - ), - # Array with deeply nested structs - ( - "[{data={meta={id=1, active=true}}}]", - [{"data": {"meta": {"id": "1", "active": "true"}}}], - ), - ], -) -def test_to_array_athena_nested_struct_elements(input_value, expected): - """Test Athena array with nested struct elements (Issue #627).""" - result = _to_array(input_value) - assert result == expected + assert result is None or isinstance(result, dict) @pytest.mark.parametrize( "input_value", [ - "[1, 2, 3]", # Array JSON - '"just a string"', # String JSON - "42", # Number JSON + "[1, 2, 3]", + '"just a string"', + "42", ], ) def test_to_struct_non_dict_json(input_value): - """Test that non-dict JSON formats return 
None.""" - result = _to_struct(input_value) - assert result is None + assert _to_struct(input_value) is None + + +def test_to_map_athena_numeric_keys(): + assert _to_map("{1=2, 3=4}") == {"1": "2", "3": "4"} @pytest.mark.parametrize( ("input_value", "expected"), [ (None, None), - # JSON-parseable arrays keep JSON types ("[1, 2, 3, 4, 5]", [1, 2, 3, 4, 5]), ('["apple", "banana", "cherry"]', ["apple", "banana", "cherry"]), ("[true, false, null]", [True, False, None]), @@ -246,19 +135,15 @@ def test_to_struct_non_dict_json(input_value): ], ) def test_to_array_json_formats(input_value, expected): - """Test ARRAY conversion for various JSON formats and edge cases.""" - result = _to_array(input_value) - assert result == expected + assert _to_array(input_value) == expected @pytest.mark.parametrize( ("input_value", "expected"), [ - # JSON-parseable: keep JSON types ("[1, 2, 3]", [1, 2, 3]), ("[]", []), ("[true, false, null]", [True, False, None]), - # Native format: strings (no heuristic inference) ("[apple, banana, cherry]", ["apple", "banana", "cherry"]), ( "[{Alice, 25}, {Bob, 30}]", @@ -268,221 +153,126 @@ def test_to_array_json_formats(input_value, expected): "[{name=John, age=30}, {name=Jane, age=25}]", [{"name": "John", "age": "30"}, {"name": "Jane", "age": "25"}], ), - # Mixed native: numbers and strings stay as strings ("[1, 2.5, hello]", ["1", "2.5", "hello"]), ], ) def test_to_array_athena_native_formats(input_value, expected): - """Test ARRAY conversion for Athena native formats.""" - result = _to_array(input_value) - assert result == expected + assert _to_array(input_value) == expected + + +@pytest.mark.parametrize( + ("input_value", "expected"), + [ + ( + "[{header={stamp=2024-01-01, seq=123}, x=4.736}]", + [{"header": {"stamp": "2024-01-01", "seq": "123"}, "x": "4.736"}], + ), + ( + "[{pos={x=1, y=2}, vel={x=0.5}}, {pos={x=3, y=4}, vel={x=1.5}}]", + [ + {"pos": {"x": "1", "y": "2"}, "vel": {"x": "0.5"}}, + {"pos": {"x": "3", "y": "4"}, "vel": {"x": "1.5"}}, 
+ ], + ), + ( + "[{data={meta={id=1, active=true}}}]", + [{"data": {"meta": {"id": "1", "active": "true"}}}], + ), + ], +) +def test_to_array_athena_nested_struct_elements(input_value, expected): + assert _to_array(input_value) == expected @pytest.mark.parametrize( ("input_value", "expected"), [ - ("[ARRAY[1, 2], ARRAY[3, 4]]", None), # Nested arrays (native format) - ("[[1, 2], [3, 4]]", [[1, 2], [3, 4]]), # Nested arrays (JSON format - parseable) - ("[MAP(ARRAY['key'], ARRAY['value'])]", None), # Complex nested structures + ("[ARRAY[1, 2], ARRAY[3, 4]]", None), + ("[[1, 2], [3, 4]]", [[1, 2], [3, 4]]), + ("[MAP(ARRAY['key'], ARRAY['value'])]", None), ], ) def test_to_array_complex_nested_cases(input_value, expected): - """Test complex nested array cases behavior.""" - result = _to_array(input_value) - assert result == expected + assert _to_array(input_value) == expected @pytest.mark.parametrize( "input_value", [ - '"just a string"', # String JSON - "42", # Number JSON - '{"key": "value"}', # Object JSON + '"just a string"', + "42", + '{"key": "value"}', ], ) def test_to_array_non_array_json(input_value): - """Test that non-array JSON formats return None.""" - result = _to_array(input_value) - assert result is None + assert _to_array(input_value) is None @pytest.mark.parametrize( "input_value", [ - "not an array", # Not bracketed - "[unclosed array", # Malformed - "closed array]", # Malformed - "[{malformed struct}", # Malformed struct + "not an array", + "[unclosed array", + "closed array]", + "[{malformed struct}", ], ) def test_to_array_invalid_formats(input_value): - """Test that invalid array formats return None.""" - result = _to_array(input_value) - assert result is None - - -# ============================================================================ -# Tests for DefaultTypeConverter (without type_hint) -# ============================================================================ + assert _to_array(input_value) is None class TestDefaultTypeConverter: 
@pytest.mark.parametrize( ("input_value", "expected"), [ - # JSON format keeps JSON types ('{"name": "Alice", "age": 25}', {"name": "Alice", "age": 25}), (None, None), ("", None), ("invalid json", None), - # Native format: values are strings (breaking change) ("{a=1, b=2}", {"a": "1", "b": "2"}), ], ) def test_struct_conversion(self, input_value, expected): - """Test DefaultTypeConverter STRUCT conversion for various input formats.""" converter = DefaultTypeConverter() - result = converter.convert("row", input_value) - assert result == expected + assert converter.convert("row", input_value) == expected @pytest.mark.parametrize( ("input_value", "expected"), [ - # JSON format keeps JSON types ("[1, 2, 3]", [1, 2, 3]), ('["a", "b", "c"]', ["a", "b", "c"]), (None, None), ("", None), ("invalid json", None), - # Native format: values are strings ("[apple, banana]", ["apple", "banana"]), ("[]", []), ], ) def test_array_conversion(self, input_value, expected): - """Test DefaultTypeConverter ARRAY conversion for various input formats.""" converter = DefaultTypeConverter() - result = converter.convert("array", input_value) - assert result == expected - - -# ============================================================================ -# Tests for parse_type_signature -# ============================================================================ - - -class TestParseTypeSignature: - def test_simple_type(self): - node = parse_type_signature("varchar") - assert node.type_name == "varchar" - assert node.children == [] - assert node.field_names is None - - def test_simple_type_case_insensitive(self): - node = parse_type_signature("VARCHAR") - assert node.type_name == "varchar" - - def test_array_type(self): - node = parse_type_signature("array(varchar)") - assert node.type_name == "array" - assert len(node.children) == 1 - assert node.children[0].type_name == "varchar" - - def test_array_of_integer(self): - node = parse_type_signature("array(integer)") - assert node.type_name == 
"array" - assert node.children[0].type_name == "integer" - - def test_map_type(self): - node = parse_type_signature("map(varchar, integer)") - assert node.type_name == "map" - assert len(node.children) == 2 - assert node.children[0].type_name == "varchar" - assert node.children[1].type_name == "integer" - - def test_row_type(self): - node = parse_type_signature("row(name varchar, age integer)") - assert node.type_name == "row" - assert len(node.children) == 2 - assert node.field_names == ["name", "age"] - assert node.children[0].type_name == "varchar" - assert node.children[1].type_name == "integer" - - def test_struct_type(self): - node = parse_type_signature("struct(name varchar, age integer)") - assert node.type_name == "struct" - assert node.field_names == ["name", "age"] - - def test_nested_array_of_row(self): - node = parse_type_signature("array(row(name varchar, age integer))") - assert node.type_name == "array" - assert len(node.children) == 1 - row_node = node.children[0] - assert row_node.type_name == "row" - assert row_node.field_names == ["name", "age"] - assert row_node.children[0].type_name == "varchar" - assert row_node.children[1].type_name == "integer" - - def test_map_with_complex_value(self): - node = parse_type_signature("map(varchar, row(x integer, y double))") - assert node.type_name == "map" - assert node.children[0].type_name == "varchar" - assert node.children[1].type_name == "row" - assert node.children[1].field_names == ["x", "y"] - - def test_deeply_nested(self): - node = parse_type_signature("array(row(data row(x integer, y integer), name varchar))") - assert node.type_name == "array" - row_node = node.children[0] - assert row_node.type_name == "row" - assert row_node.field_names == ["data", "name"] - assert row_node.children[0].type_name == "row" - assert row_node.children[0].field_names == ["x", "y"] - assert row_node.children[1].type_name == "varchar" - - def test_parameterized_type(self): - node = parse_type_signature("decimal(10, 
2)") - assert node.type_name == "decimal" - - def test_varchar_with_length(self): - node = parse_type_signature("varchar(255)") - assert node.type_name == "varchar" - - -# ============================================================================ -# Tests for typed conversion with type_hint -# ============================================================================ - + assert converter.convert("array", input_value) == expected -class TestTypedConversion: def test_array_varchar_keeps_strings(self): - """Test that array(varchar) type hint keeps elements as strings.""" converter = DefaultTypeConverter() - # JSON format: [1234, 5678] would be parsed as ints by JSON, - # but type_hint says varchar, so they should be strings result = converter.convert("array", "[1234, 5678]", type_hint="array(varchar)") assert result == ["1234", "5678"] def test_array_integer_converts_to_int(self): - """Test that array(integer) type hint converts elements to ints.""" converter = DefaultTypeConverter() result = converter.convert("array", "[1, 2, 3]", type_hint="array(integer)") assert result == [1, 2, 3] def test_array_boolean_converts(self): - """Test that array(boolean) type hint converts elements to bools.""" converter = DefaultTypeConverter() result = converter.convert("array", "[true, false]", type_hint="array(boolean)") assert result == [True, False] def test_array_with_null(self): - """Test that nulls in arrays are preserved regardless of type hint.""" converter = DefaultTypeConverter() result = converter.convert("array", "[1, null, 3]", type_hint="array(integer)") assert result == [1, None, 3] def test_map_varchar_integer(self): - """Test map(varchar, integer) type hint.""" converter = DefaultTypeConverter() result = converter.convert( "map", '{"key1": 1, "key2": 2}', type_hint="map(varchar, integer)" @@ -490,13 +280,11 @@ def test_map_varchar_integer(self): assert result == {"key1": 1, "key2": 2} def test_map_native_format_with_hints(self): - """Test map type hint with 
native format.""" converter = DefaultTypeConverter() result = converter.convert("map", "{a=1, b=2}", type_hint="map(varchar, integer)") assert result == {"a": 1, "b": 2} def test_row_type_hint(self): - """Test row type hint with named fields.""" converter = DefaultTypeConverter() result = converter.convert( "row", @@ -506,7 +294,6 @@ def test_row_type_hint(self): assert result == {"name": "Alice", "age": 25} def test_row_native_format_with_hints(self): - """Test row type hint with native format.""" converter = DefaultTypeConverter() result = converter.convert( "row", @@ -516,7 +303,6 @@ def test_row_native_format_with_hints(self): assert result == {"name": "Alice", "age": 25} def test_nested_array_of_row(self): - """Test array(row(...)) type hint preserves correct types.""" converter = DefaultTypeConverter() result = converter.convert( "array", @@ -529,10 +315,7 @@ def test_nested_array_of_row(self): ] def test_array_varchar_prevents_number_inference(self): - """Test the core use case: array(varchar) prevents "1234" -> 1234.""" converter = DefaultTypeConverter() - # This is the key issue from #689: varchar values like "1234" should - # not be converted to numbers when type_hint specifies varchar result = converter.convert( "array", "[1234, 5678, hello]", @@ -541,40 +324,30 @@ def test_array_varchar_prevents_number_inference(self): assert result == ["1234", "5678", "hello"] def test_none_value_with_type_hint(self): - """Test that None value returns None even with type hint.""" converter = DefaultTypeConverter() - result = converter.convert("array", None, type_hint="array(varchar)") - assert result is None + assert converter.convert("array", None, type_hint="array(varchar)") is None def test_simple_type_hint(self): - """Test type hint for simple types.""" converter = DefaultTypeConverter() - result = converter.convert("varchar", "hello", type_hint="varchar") - assert result == "hello" + assert converter.convert("varchar", "hello", type_hint="varchar") == "hello" def 
test_type_hint_caching(self): - """Test that parsed type hints are cached.""" converter = DefaultTypeConverter() converter.convert("array", "[1, 2]", type_hint="array(integer)") assert "array(integer)" in converter._parsed_hints - # Second call should use cache converter.convert("array", "[3, 4]", type_hint="array(integer)") assert len(converter._parsed_hints) == 1 def test_empty_array_with_type_hint(self): - """Test empty array with type hint.""" converter = DefaultTypeConverter() - result = converter.convert("array", "[]", type_hint="array(varchar)") - assert result == [] + assert converter.convert("array", "[]", type_hint="array(varchar)") == [] def test_map_varchar_varchar(self): - """Test map(varchar, varchar) keeps all values as strings.""" converter = DefaultTypeConverter() result = converter.convert("map", "{key1=123, key2=456}", type_hint="map(varchar, varchar)") assert result == {"key1": "123", "key2": "456"} def test_row_with_nested_struct(self): - """Test row with nested struct field.""" converter = DefaultTypeConverter() result = converter.convert( "row", @@ -582,3 +355,137 @@ def test_row_with_nested_struct(self): type_hint="row(header row(seq integer, stamp varchar), x double)", ) assert result == {"header": {"seq": 123, "stamp": "2024"}, "x": 4.5} + + +class TestTypeSignatureParser: + def test_simple_type(self): + parser = TypeSignatureParser() + node = parser.parse("varchar") + assert node.type_name == "varchar" + assert node.children == [] + assert node.field_names is None + + def test_simple_type_case_insensitive(self): + parser = TypeSignatureParser() + node = parser.parse("VARCHAR") + assert node.type_name == "varchar" + + def test_array_type(self): + parser = TypeSignatureParser() + node = parser.parse("array(varchar)") + assert node.type_name == "array" + assert len(node.children) == 1 + assert node.children[0].type_name == "varchar" + + def test_array_of_integer(self): + parser = TypeSignatureParser() + node = parser.parse("array(integer)") 
+ assert node.type_name == "array" + assert node.children[0].type_name == "integer" + + def test_map_type(self): + parser = TypeSignatureParser() + node = parser.parse("map(varchar, integer)") + assert node.type_name == "map" + assert len(node.children) == 2 + assert node.children[0].type_name == "varchar" + assert node.children[1].type_name == "integer" + + def test_row_type(self): + parser = TypeSignatureParser() + node = parser.parse("row(name varchar, age integer)") + assert node.type_name == "row" + assert len(node.children) == 2 + assert node.field_names == ["name", "age"] + assert node.children[0].type_name == "varchar" + assert node.children[1].type_name == "integer" + + def test_struct_type(self): + parser = TypeSignatureParser() + node = parser.parse("struct(name varchar, age integer)") + assert node.type_name == "struct" + assert node.field_names == ["name", "age"] + + def test_nested_array_of_row(self): + parser = TypeSignatureParser() + node = parser.parse("array(row(name varchar, age integer))") + assert node.type_name == "array" + assert len(node.children) == 1 + row_node = node.children[0] + assert row_node.type_name == "row" + assert row_node.field_names == ["name", "age"] + assert row_node.children[0].type_name == "varchar" + assert row_node.children[1].type_name == "integer" + + def test_map_with_complex_value(self): + parser = TypeSignatureParser() + node = parser.parse("map(varchar, row(x integer, y double))") + assert node.type_name == "map" + assert node.children[0].type_name == "varchar" + assert node.children[1].type_name == "row" + assert node.children[1].field_names == ["x", "y"] + + def test_deeply_nested(self): + parser = TypeSignatureParser() + node = parser.parse("array(row(data row(x integer, y integer), name varchar))") + assert node.type_name == "array" + row_node = node.children[0] + assert row_node.type_name == "row" + assert row_node.field_names == ["data", "name"] + assert row_node.children[0].type_name == "row" + assert 
row_node.children[0].field_names == ["x", "y"] + assert row_node.children[1].type_name == "varchar" + + def test_parameterized_type(self): + parser = TypeSignatureParser() + node = parser.parse("decimal(10, 2)") + assert node.type_name == "decimal" + + def test_varchar_with_length(self): + parser = TypeSignatureParser() + node = parser.parse("varchar(255)") + assert node.type_name == "varchar" + + +class TestTypedValueConverter: + @pytest.fixture + def converter(self): + return TypedValueConverter( + converters=_DEFAULT_CONVERTERS, + default_converter=_to_default, + struct_parser=_to_struct, + ) + + def test_simple_varchar(self, converter): + node = TypeNode("varchar") + assert converter.convert("hello", node) == "hello" + + def test_simple_integer(self, converter): + node = TypeNode("integer") + assert converter.convert("42", node) == 42 + + def test_array_of_varchar(self, converter): + parser = TypeSignatureParser() + node = parser.parse("array(varchar)") + assert converter.convert("[1234, 5678]", node) == ["1234", "5678"] + + def test_array_of_integer(self, converter): + parser = TypeSignatureParser() + node = parser.parse("array(integer)") + assert converter.convert("[1, 2, 3]", node) == [1, 2, 3] + + def test_map_varchar_integer(self, converter): + parser = TypeSignatureParser() + node = parser.parse("map(varchar, integer)") + assert converter.convert('{"a": 1, "b": 2}', node) == {"a": 1, "b": 2} + + def test_row_named_fields(self, converter): + parser = TypeSignatureParser() + node = parser.parse("row(name varchar, age integer)") + assert converter.convert("{name=Alice, age=25}", node) == {"name": "Alice", "age": 25} + + def test_nested_row(self, converter): + parser = TypeSignatureParser() + node = parser.parse("row(header row(seq integer, stamp varchar), x double)") + result = converter.convert("{header={seq=123, stamp=2024}, x=4.5}", node) + assert result == {"header": {"seq": 123, "stamp": "2024"}, "x": 4.5} From 2d0f222d4ef5e601c3fc1c7d2b94ac538e9af988 
Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Sat, 28 Feb 2026 14:03:26 +0900 Subject: [PATCH 04/18] Update test expected values for native format string conversion behavior MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Native format complex types (map, struct) now return string values instead of type-inferred values to prevent incorrect conversions (e.g., varchar "1234" → int 1234). JSON format paths are unaffected. Co-Authored-By: Claude Opus 4.6 --- tests/pyathena/pandas/test_util.py | 4 +-- tests/pyathena/s3fs/test_async_cursor.py | 4 +-- tests/pyathena/s3fs/test_cursor.py | 4 +-- tests/pyathena/sqlalchemy/test_base.py | 38 ++++++++++++------------ tests/pyathena/test_cursor.py | 6 ++-- 5 files changed, 28 insertions(+), 28 deletions(-) diff --git a/tests/pyathena/pandas/test_util.py b/tests/pyathena/pandas/test_util.py index a34d2341..788308da 100644 --- a/tests/pyathena/pandas/test_util.py +++ b/tests/pyathena/pandas/test_util.py @@ -116,9 +116,9 @@ def test_as_pandas(cursor): b"123", [1, 2], [1, 2], + {"1": "2", "3": "4"}, {"1": 2, "3": 4}, - {"1": 2, "3": 4}, - {"a": 1, "b": 2}, + {"a": "1", "b": "2"}, Decimal("0.1"), ) ] diff --git a/tests/pyathena/s3fs/test_async_cursor.py b/tests/pyathena/s3fs/test_async_cursor.py index f7046fa3..4363367a 100644 --- a/tests/pyathena/s3fs/test_async_cursor.py +++ b/tests/pyathena/s3fs/test_async_cursor.py @@ -118,9 +118,9 @@ def test_complex(self, async_s3fs_cursor): b"123", [1, 2], [1, 2], + {"1": "2", "3": "4"}, {"1": 2, "3": 4}, - {"1": 2, "3": 4}, - {"a": 1, "b": 2}, + {"a": "1", "b": "2"}, Decimal("0.1"), ) ] diff --git a/tests/pyathena/s3fs/test_cursor.py b/tests/pyathena/s3fs/test_cursor.py index fd8f6572..40edaa07 100644 --- a/tests/pyathena/s3fs/test_cursor.py +++ b/tests/pyathena/s3fs/test_cursor.py @@ -118,9 +118,9 @@ def test_complex(self, s3fs_cursor): b"123", [1, 2], [1, 2], + {"1": "2", "3": "4"}, {"1": 2, "3": 4}, - {"1": 2, "3": 4}, - {"a": 1, "b": 2}, + 
{"a": "1", "b": "2"}, Decimal("0.1"), ) ] diff --git a/tests/pyathena/sqlalchemy/test_base.py b/tests/pyathena/sqlalchemy/test_base.py index 7c4c44fc..a42219ee 100644 --- a/tests/pyathena/sqlalchemy/test_base.py +++ b/tests/pyathena/sqlalchemy/test_base.py @@ -130,9 +130,9 @@ def test_select_nested_struct_query(self, engine): assert "header" in result.positions assert isinstance(result.positions["header"], dict) assert result.positions["header"]["stamp"] == "2024-01-01" - assert result.positions["header"]["seq"] == 123 - assert result.positions["x"] == 4.736 - assert result.positions["y"] == 0.583 + assert result.positions["header"]["seq"] == "123" + assert result.positions["x"] == "4.736" + assert result.positions["y"] == "0.583" # Test double nested struct query = sqlalchemy.text( @@ -147,7 +147,7 @@ def test_select_nested_struct_query(self, engine): result = conn.execute(query).fetchone() assert result is not None assert result.data["level1"]["level2"]["level3"] == "value" - assert result.data["field"] == 123 + assert result.data["field"] == "123" # Test multiple nested fields query = sqlalchemy.text( @@ -166,11 +166,11 @@ def test_select_nested_struct_query(self, engine): ) result = conn.execute(query).fetchone() assert result is not None - assert result.data["pos"]["x"] == 1 - assert result.data["pos"]["y"] == 2 - assert result.data["vel"]["x"] == 0.5 - assert result.data["vel"]["y"] == 0.3 - assert result.data["timestamp"] == 12345 + assert result.data["pos"]["x"] == "1" + assert result.data["pos"]["y"] == "2" + assert result.data["vel"]["x"] == "0.5" + assert result.data["vel"]["y"] == "0.3" + assert result.data["timestamp"] == "12345" def test_select_array_with_nested_struct(self, engine): """Test SELECT query with ARRAY containing nested STRUCT (Issue #627).""" @@ -197,8 +197,8 @@ def test_select_array_with_nested_struct(self, engine): assert "header" in result.positions[0] assert isinstance(result.positions[0]["header"], dict) assert 
result.positions[0]["header"]["stamp"] == "2024-01-01" - assert result.positions[0]["header"]["seq"] == 123 - assert result.positions[0]["x"] == 4.736 + assert result.positions[0]["header"]["seq"] == "123" + assert result.positions[0]["x"] == "4.736" # Multiple elements with nested structs query = sqlalchemy.text( @@ -213,12 +213,12 @@ def test_select_array_with_nested_struct(self, engine): result = conn.execute(query).fetchone() assert result is not None assert len(result.data) == 2 - assert result.data[0]["pos"]["x"] == 1 - assert result.data[0]["pos"]["y"] == 2 - assert result.data[0]["vel"]["x"] == 0.5 - assert result.data[1]["pos"]["x"] == 3 - assert result.data[1]["pos"]["y"] == 4 - assert result.data[1]["vel"]["x"] == 1.5 + assert result.data[0]["pos"]["x"] == "1" + assert result.data[0]["pos"]["y"] == "2" + assert result.data[0]["vel"]["x"] == "0.5" + assert result.data[1]["pos"]["x"] == "3" + assert result.data[1]["pos"]["y"] == "4" + assert result.data[1]["vel"]["x"] == "1.5" def test_reflect_no_such_table(self, engine): engine, conn = engine @@ -513,8 +513,8 @@ def test_reflect_select(self, engine): date(2017, 1, 2), b"123", [1, 2], - {"1": 2, "3": 4}, # map type now converted to dict - {"a": 1, "b": 2}, # row type now converted to dict + {"1": "2", "3": "4"}, # map type now converted to dict + {"a": "1", "b": "2"}, # row type now converted to dict Decimal("0.1"), ] assert isinstance(one_row_complex.c.col_boolean.type, types.BOOLEAN) diff --git a/tests/pyathena/test_cursor.py b/tests/pyathena/test_cursor.py index 27c34950..e8687b64 100644 --- a/tests/pyathena/test_cursor.py +++ b/tests/pyathena/test_cursor.py @@ -538,9 +538,9 @@ def test_complex(self, cursor): b"123", [1, 2], [1, 2], + {"1": "2", "3": "4"}, {"1": 2, "3": 4}, - {"1": 2, "3": 4}, - {"a": 1, "b": 2}, + {"a": "1", "b": "2"}, Decimal("0.1"), ) ] @@ -1346,7 +1346,7 @@ def test_array_converter_behavior(self, cursor): test_cases = [ ("[1, 2, 3]", [1, 2, 3]), ('["a", "b", "c"]', ["a", "b", "c"]), 
- ("[{a=1, b=2}]", [{"a": 1, "b": 2}]), + ("[{a=1, b=2}]", [{"a": "1", "b": "2"}]), ("[]", []), (None, None), ("invalid", None), From 9cd535fe6503c75a773ad561b4f52426e12e54f0 Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Sat, 28 Feb 2026 14:27:18 +0900 Subject: [PATCH 05/18] Move parser tests from test_converter.py to test_parser.py TestTypeSignatureParser and TestTypedValueConverter test the parser module directly, so they belong in a dedicated test file. Co-Authored-By: Claude Opus 4.6 --- tests/pyathena/test_converter.py | 137 ------------------------------ tests/pyathena/test_parser.py | 138 +++++++++++++++++++++++++++++++ 2 files changed, 138 insertions(+), 137 deletions(-) create mode 100644 tests/pyathena/test_parser.py diff --git a/tests/pyathena/test_converter.py b/tests/pyathena/test_converter.py index 9c19349f..1c4ff3dc 100644 --- a/tests/pyathena/test_converter.py +++ b/tests/pyathena/test_converter.py @@ -1,14 +1,11 @@ import pytest from pyathena.converter import ( - _DEFAULT_CONVERTERS, DefaultTypeConverter, _to_array, - _to_default, _to_map, _to_struct, ) -from pyathena.parser import TypedValueConverter, TypeNode, TypeSignatureParser @pytest.mark.parametrize( @@ -355,137 +352,3 @@ def test_row_with_nested_struct(self): type_hint="row(header row(seq integer, stamp varchar), x double)", ) assert result == {"header": {"seq": 123, "stamp": "2024"}, "x": 4.5} - - -class TestTypeSignatureParser: - def test_simple_type(self): - parser = TypeSignatureParser() - node = parser.parse("varchar") - assert node.type_name == "varchar" - assert node.children == [] - assert node.field_names is None - - def test_simple_type_case_insensitive(self): - parser = TypeSignatureParser() - node = parser.parse("VARCHAR") - assert node.type_name == "varchar" - - def test_array_type(self): - parser = TypeSignatureParser() - node = parser.parse("array(varchar)") - assert node.type_name == "array" - assert len(node.children) == 1 - assert node.children[0].type_name == 
"varchar" - - def test_array_of_integer(self): - parser = TypeSignatureParser() - node = parser.parse("array(integer)") - assert node.type_name == "array" - assert node.children[0].type_name == "integer" - - def test_map_type(self): - parser = TypeSignatureParser() - node = parser.parse("map(varchar, integer)") - assert node.type_name == "map" - assert len(node.children) == 2 - assert node.children[0].type_name == "varchar" - assert node.children[1].type_name == "integer" - - def test_row_type(self): - parser = TypeSignatureParser() - node = parser.parse("row(name varchar, age integer)") - assert node.type_name == "row" - assert len(node.children) == 2 - assert node.field_names == ["name", "age"] - assert node.children[0].type_name == "varchar" - assert node.children[1].type_name == "integer" - - def test_struct_type(self): - parser = TypeSignatureParser() - node = parser.parse("struct(name varchar, age integer)") - assert node.type_name == "struct" - assert node.field_names == ["name", "age"] - - def test_nested_array_of_row(self): - parser = TypeSignatureParser() - node = parser.parse("array(row(name varchar, age integer))") - assert node.type_name == "array" - assert len(node.children) == 1 - row_node = node.children[0] - assert row_node.type_name == "row" - assert row_node.field_names == ["name", "age"] - assert row_node.children[0].type_name == "varchar" - assert row_node.children[1].type_name == "integer" - - def test_map_with_complex_value(self): - parser = TypeSignatureParser() - node = parser.parse("map(varchar, row(x integer, y double))") - assert node.type_name == "map" - assert node.children[0].type_name == "varchar" - assert node.children[1].type_name == "row" - assert node.children[1].field_names == ["x", "y"] - - def test_deeply_nested(self): - parser = TypeSignatureParser() - node = parser.parse("array(row(data row(x integer, y integer), name varchar))") - assert node.type_name == "array" - row_node = node.children[0] - assert row_node.type_name == 
"row" - assert row_node.field_names == ["data", "name"] - assert row_node.children[0].type_name == "row" - assert row_node.children[0].field_names == ["x", "y"] - assert row_node.children[1].type_name == "varchar" - - def test_parameterized_type(self): - parser = TypeSignatureParser() - node = parser.parse("decimal(10, 2)") - assert node.type_name == "decimal" - - def test_varchar_with_length(self): - parser = TypeSignatureParser() - node = parser.parse("varchar(255)") - assert node.type_name == "varchar" - - -class TestTypedValueConverter: - @pytest.fixture - def converter(self): - return TypedValueConverter( - converters=_DEFAULT_CONVERTERS, - default_converter=_to_default, - struct_parser=_to_struct, - ) - - def test_simple_varchar(self, converter): - node = TypeNode("varchar") - assert converter.convert("hello", node) == "hello" - - def test_simple_integer(self, converter): - node = TypeNode("integer") - assert converter.convert("42", node) == 42 - - def test_array_of_varchar(self, converter): - parser = TypeSignatureParser() - node = parser.parse("array(varchar)") - assert converter.convert("[1234, 5678]", node) == ["1234", "5678"] - - def test_array_of_integer(self, converter): - parser = TypeSignatureParser() - node = parser.parse("array(integer)") - assert converter.convert("[1, 2, 3]", node) == [1, 2, 3] - - def test_map_varchar_integer(self, converter): - parser = TypeSignatureParser() - node = parser.parse("map(varchar, integer)") - assert converter.convert('{"a": 1, "b": 2}', node) == {"a": 1, "b": 2} - - def test_row_named_fields(self, converter): - parser = TypeSignatureParser() - node = parser.parse("row(name varchar, age integer)") - assert converter.convert("{name=Alice, age=25}", node) == {"name": "Alice", "age": 25} - - def test_nested_row(self, converter): - parser = TypeSignatureParser() - node = parser.parse("row(header row(seq integer, stamp varchar), x double)") - result = converter.convert("{header={seq=123, stamp=2024}, x=4.5}", node) - 
assert result == {"header": {"seq": 123, "stamp": "2024"}, "x": 4.5} diff --git a/tests/pyathena/test_parser.py b/tests/pyathena/test_parser.py new file mode 100644 index 00000000..5ca81364 --- /dev/null +++ b/tests/pyathena/test_parser.py @@ -0,0 +1,138 @@ +import pytest + +from pyathena.converter import _DEFAULT_CONVERTERS, _to_default, _to_struct +from pyathena.parser import TypedValueConverter, TypeNode, TypeSignatureParser + + +class TestTypeSignatureParser: + def test_simple_type(self): + parser = TypeSignatureParser() + node = parser.parse("varchar") + assert node.type_name == "varchar" + assert node.children == [] + assert node.field_names is None + + def test_simple_type_case_insensitive(self): + parser = TypeSignatureParser() + node = parser.parse("VARCHAR") + assert node.type_name == "varchar" + + def test_array_type(self): + parser = TypeSignatureParser() + node = parser.parse("array(varchar)") + assert node.type_name == "array" + assert len(node.children) == 1 + assert node.children[0].type_name == "varchar" + + def test_array_of_integer(self): + parser = TypeSignatureParser() + node = parser.parse("array(integer)") + assert node.type_name == "array" + assert node.children[0].type_name == "integer" + + def test_map_type(self): + parser = TypeSignatureParser() + node = parser.parse("map(varchar, integer)") + assert node.type_name == "map" + assert len(node.children) == 2 + assert node.children[0].type_name == "varchar" + assert node.children[1].type_name == "integer" + + def test_row_type(self): + parser = TypeSignatureParser() + node = parser.parse("row(name varchar, age integer)") + assert node.type_name == "row" + assert len(node.children) == 2 + assert node.field_names == ["name", "age"] + assert node.children[0].type_name == "varchar" + assert node.children[1].type_name == "integer" + + def test_struct_type(self): + parser = TypeSignatureParser() + node = parser.parse("struct(name varchar, age integer)") + assert node.type_name == "struct" + assert 
node.field_names == ["name", "age"] + + def test_nested_array_of_row(self): + parser = TypeSignatureParser() + node = parser.parse("array(row(name varchar, age integer))") + assert node.type_name == "array" + assert len(node.children) == 1 + row_node = node.children[0] + assert row_node.type_name == "row" + assert row_node.field_names == ["name", "age"] + assert row_node.children[0].type_name == "varchar" + assert row_node.children[1].type_name == "integer" + + def test_map_with_complex_value(self): + parser = TypeSignatureParser() + node = parser.parse("map(varchar, row(x integer, y double))") + assert node.type_name == "map" + assert node.children[0].type_name == "varchar" + assert node.children[1].type_name == "row" + assert node.children[1].field_names == ["x", "y"] + + def test_deeply_nested(self): + parser = TypeSignatureParser() + node = parser.parse("array(row(data row(x integer, y integer), name varchar))") + assert node.type_name == "array" + row_node = node.children[0] + assert row_node.type_name == "row" + assert row_node.field_names == ["data", "name"] + assert row_node.children[0].type_name == "row" + assert row_node.children[0].field_names == ["x", "y"] + assert row_node.children[1].type_name == "varchar" + + def test_parameterized_type(self): + parser = TypeSignatureParser() + node = parser.parse("decimal(10, 2)") + assert node.type_name == "decimal" + + def test_varchar_with_length(self): + parser = TypeSignatureParser() + node = parser.parse("varchar(255)") + assert node.type_name == "varchar" + + +class TestTypedValueConverter: + @pytest.fixture + def converter(self): + return TypedValueConverter( + converters=_DEFAULT_CONVERTERS, + default_converter=_to_default, + struct_parser=_to_struct, + ) + + def test_simple_varchar(self, converter): + node = TypeNode("varchar") + assert converter.convert("hello", node) == "hello" + + def test_simple_integer(self, converter): + node = TypeNode("integer") + assert converter.convert("42", node) == 42 + + def 
test_array_of_varchar(self, converter): + parser = TypeSignatureParser() + node = parser.parse("array(varchar)") + assert converter.convert("[1234, 5678]", node) == ["1234", "5678"] + + def test_array_of_integer(self, converter): + parser = TypeSignatureParser() + node = parser.parse("array(integer)") + assert converter.convert("[1, 2, 3]", node) == [1, 2, 3] + + def test_map_varchar_integer(self, converter): + parser = TypeSignatureParser() + node = parser.parse("map(varchar, integer)") + assert converter.convert('{"a": 1, "b": 2}', node) == {"a": 1, "b": 2} + + def test_row_named_fields(self, converter): + parser = TypeSignatureParser() + node = parser.parse("row(name varchar, age integer)") + assert converter.convert("{name=Alice, age=25}", node) == {"name": "Alice", "age": 25} + + def test_nested_row(self, converter): + parser = TypeSignatureParser() + node = parser.parse("row(header row(seq integer, stamp varchar), x double)") + result = converter.convert("{header={seq=123, stamp=2024}, x=4.5}", node) + assert result == {"header": {"seq": 123, "stamp": "2024"}, "x": 4.5} From cd9230edbf25050439e8f875ea608b85efd0e250 Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Sat, 28 Feb 2026 14:29:39 +0900 Subject: [PATCH 06/18] Reorder parser.py: move TypeNode after _split_array_items Place the private helper function before public classes for clearer top-down reading order. Co-Authored-By: Claude Opus 4.6 --- pyathena/parser.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/pyathena/parser.py b/pyathena/parser.py index 0ef39313..bc6443dc 100644 --- a/pyathena/parser.py +++ b/pyathena/parser.py @@ -6,24 +6,6 @@ from typing import Any -@dataclass -class TypeNode: - """Parsed representation of an Athena DDL type signature. - - Represents a node in a type tree, where complex types (array, map, row) - have children representing their element/field types. 
- - Attributes: - type_name: The base type name (e.g., "array", "map", "row", "varchar"). - children: Child type nodes for complex types. - field_names: Field names for row/struct types (parallel to children). - """ - - type_name: str - children: list[TypeNode] = field(default_factory=list) - field_names: list[str] | None = None - - def _split_array_items(inner: str) -> list[str]: """Split array items by comma, respecting brace and bracket groupings. @@ -60,6 +42,24 @@ def _split_array_items(inner: str) -> list[str]: return items +@dataclass +class TypeNode: + """Parsed representation of an Athena DDL type signature. + + Represents a node in a type tree, where complex types (array, map, row) + have children representing their element/field types. + + Attributes: + type_name: The base type name (e.g., "array", "map", "row", "varchar"). + children: Child type nodes for complex types. + field_names: Field names for row/struct types (parallel to children). + """ + + type_name: str + children: list[TypeNode] = field(default_factory=list) + field_names: list[str] | None = None + + class TypeSignatureParser: """Parse Athena DDL type signature strings into a type tree.""" From 2bc66a6d5f794177aaeaeca881e9446d3f3459de Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Sat, 28 Feb 2026 14:31:25 +0900 Subject: [PATCH 07/18] Add test conventions to CLAUDE.md Document class-based vs standalone function test patterns, fixture usage with indirect parametrization, and integration vs unit test distinction. 
Co-Authored-By: Claude Opus 4.6 --- CLAUDE.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/CLAUDE.md b/CLAUDE.md index c2725412..1fffd00e 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -42,6 +42,19 @@ export $(cat .env | xargs) && uv run pytest tests/pyathena/test_file.py -v - Use pytest fixtures from `conftest.py` - New features require tests; changes to SQLAlchemy dialects must pass `make test-sqla` +#### Test Conventions +- **Class-based tests** for integration tests that use fixtures (cursors, engines): `class TestCursor:` with methods like `def test_fetchone(self, cursor):` +- **Standalone functions** for unit tests of pure logic (converters, parsers, utils): `def test_to_struct_json_formats(input_value, expected):` +- Test file naming mirrors source: `pyathena/parser.py` → `tests/pyathena/test_parser.py` +- **Fixtures**: Cursor/engine fixtures are defined in `conftest.py` and injected by name (e.g., `cursor`, `engine`, `async_cursor`). Use `indirect=True` parametrization to pass connection options: + ```python + @pytest.mark.parametrize("engine", [{"driver": "rest"}], indirect=True) + def test_query(self, engine): + engine, conn = engine + ``` +- **Parametrize** with `@pytest.mark.parametrize(("input", "expected"), [...])` for data-driven tests +- **Integration tests** (need AWS) use cursor/engine fixtures with real Athena queries; **unit tests** (no AWS) call functions directly with test data + ## Architecture — Key Design Decisions These are non-obvious conventions that can't be discovered by reading code alone. 
From 40a6d93fbb7d8538f1ddc386949e5a876334ba97 Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Sat, 28 Feb 2026 14:44:07 +0900 Subject: [PATCH 08/18] Add integration test for result_set_type_hints with default cursor Co-Authored-By: Claude Opus 4.6 --- tests/pyathena/test_cursor.py | 72 +++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/tests/pyathena/test_cursor.py b/tests/pyathena/test_cursor.py index e8687b64..3dd15fc2 100644 --- a/tests/pyathena/test_cursor.py +++ b/tests/pyathena/test_cursor.py @@ -571,6 +571,78 @@ def test_complex(self, cursor): NUMBER, ] + def test_complex_with_type_hints(self, cursor): + # 1. Basic complex columns from one_row_complex + cursor.execute( + """ + SELECT col_array, col_map, col_struct + FROM one_row_complex + """, + result_set_type_hints={ + "col_array": "array(integer)", + "col_map": "map(integer, integer)", + "col_struct": "row(a integer, b integer)", + }, + ) + row = cursor.fetchone() + assert row[0] == [1, 2] + assert row[1] == {"1": 2, "3": 4} + assert row[2] == {"a": 1, "b": 2} + # Without type hints col_struct values are strings; with hints they are ints + assert isinstance(row[2]["a"], int) + assert isinstance(row[2]["b"], int) + # Map values should also be ints + assert isinstance(row[1]["1"], int) + + # 2. Nested struct + cursor.execute( + """ + SELECT CAST( + ROW(ROW('2024-01-01', 123), 4.736, 0.583) + AS ROW(header ROW(stamp VARCHAR, seq INTEGER), x DOUBLE, y DOUBLE) + ) AS positions + """, + result_set_type_hints={ + "positions": "row(header row(stamp varchar, seq integer), x double, y double)", + }, + ) + row = cursor.fetchone() + positions = row[0] + assert positions["header"]["stamp"] == "2024-01-01" + assert positions["header"]["seq"] == 123 + assert isinstance(positions["header"]["seq"], int) + assert positions["x"] == 4.736 + assert isinstance(positions["x"], float) + assert positions["y"] == 0.583 + assert isinstance(positions["y"], float) + + # 3. 
Array of nested structs + cursor.execute( + """ + SELECT CAST( + ARRAY[ + ROW(ROW(1, 2), ROW(CAST(0.5 AS DOUBLE))), + ROW(ROW(3, 4), ROW(CAST(1.5 AS DOUBLE))) + ] + AS ARRAY(ROW(pos ROW(x INTEGER, y INTEGER), vel ROW(x DOUBLE))) + ) AS data + """, + result_set_type_hints={ + "data": "array(row(pos row(x integer, y integer), vel row(x double)))", + }, + ) + row = cursor.fetchone() + data = row[0] + assert len(data) == 2 + assert data[0]["pos"]["x"] == 1 + assert isinstance(data[0]["pos"]["x"], int) + assert data[0]["pos"]["y"] == 2 + assert isinstance(data[0]["pos"]["y"], int) + assert data[0]["vel"]["x"] == 0.5 + assert isinstance(data[0]["vel"]["x"], float) + assert data[1]["pos"]["x"] == 3 + assert data[1]["vel"]["x"] == 1.5 + def test_cancel(self, cursor): def cancel(c): time.sleep(randint(5, 10)) From a2155a6628720bda2f2fe26678ab04359c597c35 Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Sat, 28 Feb 2026 15:40:31 +0900 Subject: [PATCH 09/18] Rename _split_top_level to _split_type_args for clarity Co-Authored-By: Claude Opus 4.6 --- pyathena/parser.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyathena/parser.py b/pyathena/parser.py index bc6443dc..65f62bdd 100644 --- a/pyathena/parser.py +++ b/pyathena/parser.py @@ -86,7 +86,7 @@ def parse(self, type_str: str) -> TypeNode: inner = type_str[paren_idx + 1 : -1].strip() if type_name in ("row", "struct"): - parts = self._split_top_level(inner) + parts = self._split_type_args(inner) field_names: list[str] = [] children: list[TypeNode] = [] for part in parts: @@ -107,7 +107,7 @@ def parse(self, type_str: str) -> TypeNode: return TypeNode(type_name=type_name, children=[child]) if type_name == "map": - parts = self._split_top_level(inner) + parts = self._split_type_args(inner) if len(parts) == 2: key_type = self.parse(parts[0]) value_type = self.parse(parts[1]) @@ -117,7 +117,7 @@ def parse(self, type_str: str) -> TypeNode: # Types with parameters like decimal(10, 2), varchar(255) 
return TypeNode(type_name=type_name) - def _split_top_level(self, s: str, delimiter: str = ",") -> list[str]: + def _split_type_args(self, s: str, delimiter: str = ",") -> list[str]: """Split a string by delimiter, respecting nested parentheses. Args: From 73ce2f0d58785f65d156334367ebe2ba69224805 Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Sat, 28 Feb 2026 16:00:47 +0900 Subject: [PATCH 10/18] Fix edge cases in typed conversion: backward compat, JSON nesting, null string - Only pass type_hint kwarg when hint exists (avoids breaking custom Converters) - Use json.dumps for dict/list in JSON paths instead of str() (fixes nested structs) - Use convert() instead of _convert_element() in JSON paths (preserves "null" strings) - Use _split_array_items in typed map native path (supports nested row/map values) - Normalize result_set_type_hints keys to lowercase for case-insensitive lookup - Cache DefaultTypeConverter instance in S3FS converter - Add unit tests for all fixed edge cases Co-Authored-By: Claude Opus 4.6 --- pyathena/parser.py | 56 +++++++++++++++++++++++++++-------- pyathena/result_set.py | 18 +++++++++-- pyathena/s3fs/converter.py | 9 +++--- pyathena/s3fs/result_set.py | 12 ++++++-- tests/pyathena/test_parser.py | 41 +++++++++++++++++++++++++ 5 files changed, 114 insertions(+), 22 deletions(-) diff --git a/pyathena/parser.py b/pyathena/parser.py index 65f62bdd..a9edc5fe 100644 --- a/pyathena/parser.py +++ b/pyathena/parser.py @@ -213,6 +213,23 @@ def convert(self, value: str, type_node: TypeNode) -> Any: converter_fn = self._converters.get(type_node.type_name, self._default_converter) return converter_fn(value) + @staticmethod + def _to_json_str(value: Any) -> str: + """Convert a JSON-parsed value back to a string for further conversion. + + Uses json.dumps for dict/list to produce valid JSON, and str() for + scalar types to produce converter-compatible strings. + + Args: + value: A value from json.loads output. 
+ + Returns: + String representation suitable for type conversion. + """ + if isinstance(value, (dict, list)): + return json.dumps(value) + return str(value) + def _convert_element(self, value: str, type_node: TypeNode) -> Any: """Convert a single element within a complex type using type information. @@ -249,7 +266,7 @@ def _convert_typed_array(self, value: str, type_node: TypeNode) -> list[Any] | N parsed = json.loads(value) if isinstance(parsed, list): return [ - None if elem is None else self._convert_element(str(elem), element_type) + None if elem is None else self.convert(self._to_json_str(elem), element_type) for elem in parsed ] except json.JSONDecodeError: @@ -304,9 +321,11 @@ def _convert_typed_map(self, value: str, type_node: TypeNode) -> dict[str, Any] parsed = json.loads(value) if isinstance(parsed, dict): return { - str( - self._convert_element(str(k), key_type) if k is not None else k - ): self._convert_element(str(v), value_type) if v is not None else None + str(self.convert(self._to_json_str(k), key_type) if k is not None else k): ( + self.convert(self._to_json_str(v), value_type) + if v is not None + else None + ) for k, v in parsed.items() } except json.JSONDecodeError: @@ -317,10 +336,7 @@ def _convert_typed_map(self, value: str, type_node: TypeNode) -> dict[str, Any] if not inner: return {} - if any(char in inner for char in "()[]"): - return None - - pairs = [pair.strip() for pair in inner.split(",")] + pairs = _split_array_items(inner) result: dict[str, Any] = {} for pair in pairs: if "=" not in pair: @@ -328,11 +344,23 @@ def _convert_typed_map(self, value: str, type_node: TypeNode) -> dict[str, Any] k, v = pair.split("=", 1) k = k.strip() v = v.strip() - if any(char in k for char in '{}="') or any(char in v for char in '{}="'): + if any(char in k for char in '{}="'): continue - converted_key = self._convert_element(k, key_type) - converted_value = self._convert_element(v, value_type) - result[str(converted_key)] = converted_value + if 
v.startswith("{") and v.endswith("}"): + if value_type.type_name in ("row", "struct"): + result[str(self._convert_element(k, key_type))] = self._convert_typed_struct( + v, value_type + ) + elif value_type.type_name == "map": + result[str(self._convert_element(k, key_type))] = self._convert_typed_map( + v, value_type + ) + else: + result[str(self._convert_element(k, key_type))] = self._struct_parser(v) + else: + converted_key = self._convert_element(k, key_type) + converted_value = self._convert_element(v, value_type) + result[str(converted_key)] = converted_value return result if result else None @@ -361,7 +389,9 @@ def _convert_typed_struct(self, value: str, type_node: TypeNode) -> dict[str, An result: dict[str, Any] = {} for i, (k, v) in enumerate(parsed.items()): ft = field_types[i] if i < len(field_types) else TypeNode("varchar") - result[k] = self._convert_element(str(v), ft) if v is not None else None + result[k] = ( + self.convert(self._to_json_str(v), ft) if v is not None else None + ) return result except json.JSONDecodeError: pass diff --git a/pyathena/result_set.py b/pyathena/result_set.py index 86a71bb6..190a33e1 100644 --- a/pyathena/result_set.py +++ b/pyathena/result_set.py @@ -70,7 +70,11 @@ def __init__( if not self._query_execution: raise ProgrammingError("Required argument `query_execution` not found.") self._retry_config = retry_config - self._result_set_type_hints = result_set_type_hints + self._result_set_type_hints = ( + {k.lower(): v for k, v in result_set_type_hints.items()} + if result_set_type_hints + else None + ) self._client = connection.session.client( "s3", region_name=connection.region_name, @@ -452,7 +456,11 @@ def _get_rows( conv.convert( meta.get("Type"), row.get("VarCharValue"), - type_hint=hints.get(meta.get("Name", "")) if hints else None, + **( + {"type_hint": hints[meta.get("Name", "").lower()]} + if hints and meta.get("Name", "").lower() in hints + else {} + ), ) for meta, row in zip(metadata, rows[i].get("Data", []), 
strict=False) ] @@ -646,7 +654,11 @@ def _get_rows( conv.convert( meta.get("Type"), row.get("VarCharValue"), - type_hint=hints.get(meta.get("Name", "")) if hints else None, + **( + {"type_hint": hints[meta.get("Name", "").lower()]} + if hints and meta.get("Name", "").lower() in hints + else {} + ), ), ) for meta, row in zip(metadata, rows[i].get("Data", []), strict=False) diff --git a/pyathena/s3fs/converter.py b/pyathena/s3fs/converter.py index a46e570b..ad3c7b3d 100644 --- a/pyathena/s3fs/converter.py +++ b/pyathena/s3fs/converter.py @@ -43,6 +43,7 @@ def __init__(self) -> None: mappings=deepcopy(_DEFAULT_CONVERTERS), default=_to_default, ) + self._default_type_converter: Any | None = None def convert(self, type_: str, value: str | None, type_hint: str | None = None) -> Any | None: """Convert a string value to the appropriate Python type. @@ -62,10 +63,10 @@ def convert(self, type_: str, value: str | None, type_hint: str | None = None) - if value is None: return None if type_hint: - from pyathena.converter import DefaultTypeConverter + if self._default_type_converter is None: + from pyathena.converter import DefaultTypeConverter - # Delegate to DefaultTypeConverter for type_hint-based conversion - dtc = DefaultTypeConverter() - return dtc.convert(type_, value, type_hint=type_hint) + self._default_type_converter = DefaultTypeConverter() + return self._default_type_converter.convert(type_, value, type_hint=type_hint) converter = self.get(type_) return converter(value) diff --git a/pyathena/s3fs/result_set.py b/pyathena/s3fs/result_set.py index c289605b..760fd169 100644 --- a/pyathena/s3fs/result_set.py +++ b/pyathena/s3fs/result_set.py @@ -169,7 +169,11 @@ def _fetch(self) -> None: self._converter.convert( col_type, value if value != "" else None, - type_hint=hints.get(col_name) if hints else None, + **( + {"type_hint": hints[col_name.lower()]} + if hints and col_name.lower() in hints + else {} + ), ) for col_type, col_name, value in zip( column_types, 
column_names, row, strict=False @@ -180,7 +184,11 @@ def _fetch(self) -> None: self._converter.convert( col_type, value, - type_hint=hints.get(col_name) if hints else None, + **( + {"type_hint": hints[col_name.lower()]} + if hints and col_name.lower() in hints + else {} + ), ) for col_type, col_name, value in zip( column_types, column_names, row, strict=False diff --git a/tests/pyathena/test_parser.py b/tests/pyathena/test_parser.py index 5ca81364..94fe1d66 100644 --- a/tests/pyathena/test_parser.py +++ b/tests/pyathena/test_parser.py @@ -136,3 +136,44 @@ def test_nested_row(self, converter): node = parser.parse("row(header row(seq integer, stamp varchar), x double)") result = converter.convert("{header={seq=123, stamp=2024}, x=4.5}", node) assert result == {"header": {"seq": 123, "stamp": "2024"}, "x": 4.5} + + def test_array_of_row_json(self, converter): + """JSON path: array(row(...)) with nested dict elements.""" + parser = TypeSignatureParser() + node = parser.parse("array(row(x integer, y double))") + result = converter.convert('[{"x": 1, "y": 2.5}, {"x": 3, "y": 4.0}]', node) + assert result == [{"x": 1, "y": 2.5}, {"x": 3, "y": 4.0}] + assert isinstance(result[0]["x"], int) + assert isinstance(result[0]["y"], float) + + def test_null_string_preserved_in_json(self, converter): + """JSON path: string "null" in array(varchar) must not become None.""" + parser = TypeSignatureParser() + node = parser.parse("array(varchar)") + result = converter.convert('["null", "x"]', node) + assert result == ["null", "x"] + + def test_map_with_row_value_native(self, converter): + """Native path: map(varchar, row(...)) with nested struct values.""" + parser = TypeSignatureParser() + node = parser.parse("map(varchar, row(x integer, y integer))") + result = converter.convert("{a={x=1, y=2}, b={x=3, y=4}}", node) + assert result == {"a": {"x": 1, "y": 2}, "b": {"x": 3, "y": 4}} + assert isinstance(result["a"]["x"], int) + + def test_nested_row_json(self, converter): + """JSON 
path: row containing row with nested dict values.""" + parser = TypeSignatureParser() + node = parser.parse("row(inner row(a integer, b varchar), val double)") + result = converter.convert('{"inner": {"a": 10, "b": "hello"}, "val": 3.14}', node) + assert result == {"inner": {"a": 10, "b": "hello"}, "val": 3.14} + assert isinstance(result["inner"]["a"], int) + assert isinstance(result["val"], float) + + def test_map_json_null_value_preserved(self, converter): + """JSON path: map with null values vs "null" string values.""" + parser = TypeSignatureParser() + node = parser.parse("map(varchar, varchar)") + result = converter.convert('{"a": null, "b": "null"}', node) + assert result["a"] is None + assert result["b"] == "null" From 13a380545bd8f9c8f3f97c46f65c21dbf135b169 Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Sat, 28 Feb 2026 17:31:22 +0900 Subject: [PATCH 11/18] Clean up PR: fix docstrings, remove unused param, improve type annotation - Fix _parse_type_hint docstring to match renamed method - Add docstring to DefaultTypeConverter.convert - Remove unused delimiter parameter from _split_type_args - Use TYPE_CHECKING for DefaultTypeConverter type annotation in S3FS converter Co-Authored-By: Claude Opus 4.6 --- pyathena/converter.py | 21 ++++++++++++++++++--- pyathena/parser.py | 11 +++++------ pyathena/s3fs/converter.py | 7 +++++-- 3 files changed, 28 insertions(+), 11 deletions(-) diff --git a/pyathena/converter.py b/pyathena/converter.py index a95db4e2..3a486783 100644 --- a/pyathena/converter.py +++ b/pyathena/converter.py @@ -556,16 +556,31 @@ def __init__(self) -> None: self._parsed_hints: dict[str, TypeNode] = {} def convert(self, type_: str, value: str | None, type_hint: str | None = None) -> Any | None: + """Convert a string value to the appropriate Python type. + + When ``type_hint`` is provided, uses the typed converter for precise + conversion of complex types. Otherwise, uses the standard converter + for the given Athena type. 
+ + Args: + type_: The Athena data type name (e.g., "integer", "varchar", "array"). + value: The string value to convert, or None. + type_hint: Optional Athena DDL type signature for precise complex type + conversion (e.g., "array(varchar)", "row(name varchar, age integer)"). + + Returns: + The converted Python value, or None if the input value was None. + """ if value is None: return None if type_hint: - type_node = self._get_or_parse_hint(type_hint) + type_node = self._parse_type_hint(type_hint) return self._typed_converter.convert(value, type_node) converter = self.get(type_) return converter(value) - def _get_or_parse_hint(self, type_hint: str) -> TypeNode: - """Get or parse a type hint string into a TypeNode, with caching. + def _parse_type_hint(self, type_hint: str) -> TypeNode: + """Parse a type hint string into a TypeNode, with caching. Args: type_hint: Athena DDL type signature string. diff --git a/pyathena/parser.py b/pyathena/parser.py index a9edc5fe..946ee325 100644 --- a/pyathena/parser.py +++ b/pyathena/parser.py @@ -117,15 +117,14 @@ def parse(self, type_str: str) -> TypeNode: # Types with parameters like decimal(10, 2), varchar(255) return TypeNode(type_name=type_name) - def _split_type_args(self, s: str, delimiter: str = ",") -> list[str]: - """Split a string by delimiter, respecting nested parentheses. + def _split_type_args(self, s: str) -> list[str]: + """Split a type signature argument string by comma, respecting nested parentheses. Args: - s: String to split. - delimiter: Character to split on. + s: Type signature argument string to split. Returns: - List of parts split at top-level delimiters. + List of type argument strings. 
""" parts: list[str] = [] current: list[str] = [] @@ -136,7 +135,7 @@ def _split_type_args(self, s: str, delimiter: str = ",") -> list[str]: depth += 1 elif char == ")": depth -= 1 - elif char == delimiter and depth == 0: + elif char == "," and depth == 0: parts.append("".join(current).strip()) current = [] continue diff --git a/pyathena/s3fs/converter.py b/pyathena/s3fs/converter.py index ad3c7b3d..00fcd5ed 100644 --- a/pyathena/s3fs/converter.py +++ b/pyathena/s3fs/converter.py @@ -2,7 +2,7 @@ import logging from copy import deepcopy -from typing import Any +from typing import TYPE_CHECKING, Any from pyathena.converter import ( _DEFAULT_CONVERTERS, @@ -10,6 +10,9 @@ _to_default, ) +if TYPE_CHECKING: + from pyathena.converter import DefaultTypeConverter + _logger = logging.getLogger(__name__) @@ -43,7 +46,7 @@ def __init__(self) -> None: mappings=deepcopy(_DEFAULT_CONVERTERS), default=_to_default, ) - self._default_type_converter: Any | None = None + self._default_type_converter: DefaultTypeConverter | None = None def convert(self, type_: str, value: str | None, type_hint: str | None = None) -> Any | None: """Convert a string value to the appropriate Python type. From 5f5e6d5d378b537e23d7ec754d9b5515a4482fdf Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Sat, 28 Feb 2026 17:52:59 +0900 Subject: [PATCH 12/18] Use name-based type matching in JSON struct conversion path The JSON parse path in _convert_typed_struct used positional indexing (field_types[i]) to assign types to fields. This breaks when JSON key order differs from the type definition order. Use _get_field_type() which matches by field name first, falling back to positional index. 
Co-Authored-By: Claude Opus 4.6 --- pyathena/parser.py | 2 +- tests/pyathena/test_parser.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/pyathena/parser.py b/pyathena/parser.py index 946ee325..33e80c7a 100644 --- a/pyathena/parser.py +++ b/pyathena/parser.py @@ -387,7 +387,7 @@ def _convert_typed_struct(self, value: str, type_node: TypeNode) -> dict[str, An if isinstance(parsed, dict): result: dict[str, Any] = {} for i, (k, v) in enumerate(parsed.items()): - ft = field_types[i] if i < len(field_types) else TypeNode("varchar") + ft = self._get_field_type(k, field_names, field_types, i) result[k] = ( self.convert(self._to_json_str(v), ft) if v is not None else None ) diff --git a/tests/pyathena/test_parser.py b/tests/pyathena/test_parser.py index 94fe1d66..1ae5256d 100644 --- a/tests/pyathena/test_parser.py +++ b/tests/pyathena/test_parser.py @@ -170,6 +170,16 @@ def test_nested_row_json(self, converter): assert isinstance(result["inner"]["a"], int) assert isinstance(result["val"], float) + def test_struct_json_name_based_type_matching(self, converter): + """JSON path: field types are matched by name, not position order.""" + parser = TypeSignatureParser() + node = parser.parse("row(name varchar, age integer)") + # JSON keys in reverse order compared to type definition + result = converter.convert('{"age": 25, "name": "Alice"}', node) + assert result == {"age": 25, "name": "Alice"} + assert isinstance(result["age"], int) + assert isinstance(result["name"], str) + def test_map_json_null_value_preserved(self, converter): """JSON path: map with null values vs "null" string values.""" parser = TypeSignatureParser() From 3140978e3708aabf23340100daccf8a427c8097b Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Sat, 28 Feb 2026 17:53:07 +0900 Subject: [PATCH 13/18] Add documentation for result_set_type_hints feature Document the motivation (Athena API lacks nested type info), usage, constraints (nested arrays in native format, 
Arrow/Pandas/Polars), and the breaking change in 3.30.0 (complex type internals kept as strings without hints). Co-Authored-By: Claude Opus 4.6 --- docs/usage.md | 74 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/docs/usage.md b/docs/usage.md index 8149b8a5..80c24d71 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -389,6 +389,80 @@ The `on_start_query_execution` callback is supported by the following cursor typ Note: `AsyncCursor` and its variants do not support this callback as they already return the query ID immediately through their different execution model. +## Type hints for complex types + +*New in version 3.30.0.* + +The Athena API does not return element-level type information for complex types +(array, map, row/struct). PyAthena parses the string representation returned by +Athena, but without type metadata the converter can only apply heuristics — which +may produce incorrect Python types for nested values (e.g. integers left as strings +inside a struct). + +The `result_set_type_hints` parameter solves this by letting you provide Athena DDL +type signatures for specific columns. The converter then uses precise, recursive +type-aware conversion instead of heuristics. + +```python +from pyathena import connect + +cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2").cursor() +cursor.execute( + "SELECT col_array, col_map, col_struct FROM one_row_complex", + result_set_type_hints={ + "col_array": "array(integer)", + "col_map": "map(integer, integer)", + "col_struct": "row(a integer, b integer)", + }, +) +row = cursor.fetchone() +# col_struct values are now integers, not strings: +# {"a": 1, "b": 2} instead of {"a": "1", "b": "2"} +``` + +Column name matching is case-insensitive. 
Type hints support arbitrarily nested types: + +```python +cursor.execute( + """ + SELECT CAST( + ROW(ROW('2024-01-01', 123), 4.736, 0.583) + AS ROW(header ROW(stamp VARCHAR, seq INTEGER), x DOUBLE, y DOUBLE) + ) AS positions + """, + result_set_type_hints={ + "positions": "row(header row(stamp varchar, seq integer), x double, y double)", + }, +) +row = cursor.fetchone() +positions = row[0] +# positions["header"]["seq"] == 123 (int, not "123") +# positions["x"] == 4.736 (float, not "4.736") +``` + +### Constraints + +* **Nested arrays in native format** — Athena's native (non-JSON) string representation + does not clearly delimit nested arrays. If your query returns nested arrays + (e.g. `array(array(integer))`), use `CAST(... AS JSON)` in your query to get + JSON-formatted output, which is parsed reliably. +* **Arrow, Pandas, and Polars cursors** — These cursors accept `result_set_type_hints` + but their converters do not currently use the hints because they rely on their own + type systems. The parameter is passed through for forward compatibility and for + result sets that fall back to the default conversion path. + +### Breaking change in 3.30.0 + +Prior to 3.30.0, PyAthena attempted to infer Python types for scalar values inside +complex types using heuristics (e.g. `"123"` → `123`). Starting with 3.30.0, values +inside complex types are **kept as strings** unless `result_set_type_hints` is provided. +This change avoids silent misconversion but means existing code that relied on the +heuristic behavior may see string values where it previously saw integers or floats. + +To restore typed conversion, pass `result_set_type_hints` with the appropriate type +signatures for the affected columns. + ## Environment variables Support [Boto3 environment variables](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-environment-variables). 
From 2ff410c331eb20763c31aef7dfd383bd4ab7ac02 Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Sat, 28 Feb 2026 19:22:02 +0900 Subject: [PATCH 14/18] Optimize hot path performance for type hint conversion - ResultSet: Pre-compute column_type_hints tuple once in _process_metadata instead of per-cell dict creation and .lower() lookup. Replace **({} if ... else {}) with simple if/else branching. Applied to AthenaResultSet, AthenaDictResultSet, and S3FS. - Array JSON guard: Add JSON detection heuristic (check for '"', '[{', '[null') before json.loads in _convert_typed_array, matching the existing pattern in map/struct to avoid JSONDecodeError exceptions on native format strings. - TypeNode field lookup: Add cached _field_type_map dict for O(1) name-based field type resolution, replacing O(n) list.index() in _get_field_type. Co-Authored-By: Claude Opus 4.6 --- pyathena/parser.py | 66 ++++++++++++++++++++++------------- pyathena/result_set.py | 68 ++++++++++++++++++++++--------------- pyathena/s3fs/result_set.py | 49 +++++++++++++------------- 3 files changed, 106 insertions(+), 77 deletions(-) diff --git a/pyathena/parser.py b/pyathena/parser.py index 33e80c7a..09ef6da8 100644 --- a/pyathena/parser.py +++ b/pyathena/parser.py @@ -58,6 +58,23 @@ class TypeNode: type_name: str children: list[TypeNode] = field(default_factory=list) field_names: list[str] | None = None + _field_type_map: dict[str, TypeNode] | None = field(default=None, repr=False) + + def get_field_type(self, name: str) -> TypeNode | None: + """Look up a child type node by field name using a cached dict. + + Returns: + The TypeNode for the named field, or None if not found. 
+ """ + if self._field_type_map is None and self.field_names: + self._field_type_map = { + fn: self.children[i] + for i, fn in enumerate(self.field_names) + if i < len(self.children) + } + if self._field_type_map: + return self._field_type_map.get(name) + return None class TypeSignatureParser: @@ -260,16 +277,20 @@ def _convert_typed_array(self, value: str, type_node: TypeNode) -> list[Any] | N element_type = type_node.children[0] if type_node.children else TypeNode("varchar") - # Try JSON first - try: - parsed = json.loads(value) - if isinstance(parsed, list): - return [ - None if elem is None else self.convert(self._to_json_str(elem), element_type) - for elem in parsed - ] - except json.JSONDecodeError: - pass + # Try JSON first (only if content looks like JSON) + inner_preview = value[1:10] if len(value) > 10 else value[1:-1] + if '"' in inner_preview or value.startswith(("[{", "[null")): + try: + parsed = json.loads(value) + if isinstance(parsed, list): + return [ + None + if elem is None + else self.convert(self._to_json_str(elem), element_type) + for elem in parsed + ] + except json.JSONDecodeError: + pass # Native format inner = value[1:-1].strip() @@ -376,7 +397,6 @@ def _convert_typed_struct(self, value: str, type_node: TypeNode) -> dict[str, An if not (value.startswith("{") and value.endswith("}")): return None - field_names = type_node.field_names or [] field_types = type_node.children or [] # Try JSON first @@ -387,7 +407,7 @@ def _convert_typed_struct(self, value: str, type_node: TypeNode) -> dict[str, An if isinstance(parsed, dict): result: dict[str, Any] = {} for i, (k, v) in enumerate(parsed.items()): - ft = self._get_field_type(k, field_names, field_types, i) + ft = self._get_field_type(k, type_node, i) result[k] = ( self.convert(self._to_json_str(v), ft) if v is not None else None ) @@ -413,7 +433,7 @@ def _convert_typed_struct(self, value: str, type_node: TypeNode) -> dict[str, An if any(char in k for char in '{}="'): continue - ft = 
self._get_field_type(k, field_names, field_types, field_index) + ft = self._get_field_type(k, type_node, field_index) field_index += 1 if v.startswith("{") and v.endswith("}"): @@ -428,6 +448,7 @@ def _convert_typed_struct(self, value: str, type_node: TypeNode) -> dict[str, An return result if result else None # Unnamed struct + field_names = type_node.field_names or [] values = [v.strip() for v in inner.split(",")] result = {} for i, v in enumerate(values): @@ -436,30 +457,29 @@ def _convert_typed_struct(self, value: str, type_node: TypeNode) -> dict[str, An result[name] = self._convert_element(v, ft) return result + @staticmethod def _get_field_type( - self, field_name: str, - field_names: list[str], - field_types: list[TypeNode], + type_node: TypeNode, field_index: int, ) -> TypeNode: """Look up the type for a struct field by name or index. - Tries name-based lookup first, then falls back to positional index. + Uses the TypeNode's cached dict for O(1) name lookup, then falls + back to positional index. Args: field_name: Name of the field to look up. - field_names: List of known field names from the type hint. - field_types: List of corresponding field types. + type_node: The parent row/struct TypeNode. field_index: Current positional index as fallback. Returns: TypeNode for the field, defaulting to varchar if not found. """ - if field_name in field_names: - idx = field_names.index(field_name) - if idx < len(field_types): - return field_types[idx] + ft = type_node.get_field_type(field_name) + if ft is not None: + return ft + field_types = type_node.children or [] if field_index < len(field_types): return field_types[field_index] return TypeNode("varchar") diff --git a/pyathena/result_set.py b/pyathena/result_set.py index 190a33e1..89c951ed 100644 --- a/pyathena/result_set.py +++ b/pyathena/result_set.py @@ -83,6 +83,7 @@ def __init__( ) self._metadata: tuple[dict[str, Any], ...] | None = None + self._column_type_hints: tuple[str | None, ...] 
| None = None self._rows: collections.deque[tuple[Any | None, ...] | dict[Any, Any | None]] = ( collections.deque() ) @@ -424,6 +425,10 @@ def _process_metadata(self, response: dict[str, Any]) -> None: if column_info is None: raise DataError("KeyError `ColumnInfo`") self._metadata = tuple(column_info) + if self._result_set_type_hints: + self._column_type_hints = tuple( + self._result_set_type_hints.get(m.get("Name", "").lower()) for m in self._metadata + ) def _process_update_count(self, response: dict[str, Any]) -> None: update_count = response.get("UpdateCount") @@ -449,21 +454,23 @@ def _get_rows( converter: Converter | None = None, ) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]: conv = converter or self._converter - hints = self._result_set_type_hints + col_hints = self._column_type_hints + if col_hints: + return [ + tuple( + conv.convert(meta.get("Type"), row.get("VarCharValue"), type_hint=hint) + if hint + else conv.convert(meta.get("Type"), row.get("VarCharValue")) + for meta, row, hint in zip( + metadata, rows[i].get("Data", []), col_hints, strict=False + ) + ) + for i in range(offset, len(rows)) + ] return [ tuple( - [ - conv.convert( - meta.get("Type"), - row.get("VarCharValue"), - **( - {"type_hint": hints[meta.get("Name", "").lower()]} - if hints and meta.get("Name", "").lower() in hints - else {} - ), - ) - for meta, row in zip(metadata, rows[i].get("Data", []), strict=False) - ] + conv.convert(meta.get("Type"), row.get("VarCharValue")) + for meta, row in zip(metadata, rows[i].get("Data", []), strict=False) ) for i in range(offset, len(rows)) ] @@ -645,24 +652,29 @@ def _get_rows( converter: Converter | None = None, ) -> list[tuple[Any | None, ...] 
| dict[Any, Any | None]]: conv = converter or self._converter - hints = self._result_set_type_hints - return [ - self.dict_type( - [ + col_hints = self._column_type_hints + if col_hints: + return [ + self.dict_type( ( meta.get("Name"), - conv.convert( - meta.get("Type"), - row.get("VarCharValue"), - **( - {"type_hint": hints[meta.get("Name", "").lower()]} - if hints and meta.get("Name", "").lower() in hints - else {} - ), - ), + conv.convert(meta.get("Type"), row.get("VarCharValue"), type_hint=hint) + if hint + else conv.convert(meta.get("Type"), row.get("VarCharValue")), ) - for meta, row in zip(metadata, rows[i].get("Data", []), strict=False) - ] + for meta, row, hint in zip( + metadata, rows[i].get("Data", []), col_hints, strict=False + ) + ) + for i in range(offset, len(rows)) + ] + return [ + self.dict_type( + ( + meta.get("Name"), + conv.convert(meta.get("Type"), row.get("VarCharValue")), + ) + for meta, row in zip(metadata, rows[i].get("Data", []), strict=False) ) for i in range(offset, len(rows)) ] diff --git a/pyathena/s3fs/result_set.py b/pyathena/s3fs/result_set.py index 760fd169..fe15e97d 100644 --- a/pyathena/s3fs/result_set.py +++ b/pyathena/s3fs/result_set.py @@ -151,8 +151,7 @@ def _fetch(self) -> None: description = self.description if self.description else [] column_types = [d[1] for d in description] - column_names = [d[0] for d in description] - hints = self._result_set_type_hints + col_hints = self._column_type_hints rows_fetched = 0 while rows_fetched < self._arraysize: @@ -165,35 +164,33 @@ def _fetch(self) -> None: # AthenaCSVReader returns None for NULL values directly, # DefaultCSVReader returns empty string which needs conversion if self._csv_reader_class is DefaultCSVReader: - converted_row = tuple( - self._converter.convert( - col_type, - value if value != "" else None, - **( - {"type_hint": hints[col_name.lower()]} - if hints and col_name.lower() in hints - else {} - ), + if col_hints: + converted_row = tuple( + 
self._converter.convert( + col_type, value if value != "" else None, type_hint=hint + ) + if hint + else self._converter.convert(col_type, value if value != "" else None) + for col_type, value, hint in zip(column_types, row, col_hints, strict=False) ) - for col_type, col_name, value in zip( - column_types, column_names, row, strict=False + else: + converted_row = tuple( + self._converter.convert(col_type, value if value != "" else None) + for col_type, value in zip(column_types, row, strict=False) ) - ) else: - converted_row = tuple( - self._converter.convert( - col_type, - value, - **( - {"type_hint": hints[col_name.lower()]} - if hints and col_name.lower() in hints - else {} - ), + if col_hints: + converted_row = tuple( + self._converter.convert(col_type, value, type_hint=hint) + if hint + else self._converter.convert(col_type, value) + for col_type, value, hint in zip(column_types, row, col_hints, strict=False) ) - for col_type, col_name, value in zip( - column_types, column_names, row, strict=False + else: + converted_row = tuple( + self._converter.convert(col_type, value) + for col_type, value in zip(column_types, row, strict=False) ) - ) self._rows.append(converted_row) rows_fetched += 1 From 311f0391ebea0b59bdae7e3648736756efa270ef Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Sat, 28 Feb 2026 19:30:35 +0900 Subject: [PATCH 15/18] Skip type hint path when no complex type columns exist Check column metadata types against _COMPLEX_TYPES (array, map, row, struct) in _process_metadata. Only compute and store column type hints when the result set actually contains complex type columns with matching hints. This eliminates all hint-related overhead in the hot loop for queries that return only scalar types. 
Co-Authored-By: Claude Opus 4.6 --- pyathena/result_set.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/pyathena/result_set.py b/pyathena/result_set.py index 89c951ed..5c892501 100644 --- a/pyathena/result_set.py +++ b/pyathena/result_set.py @@ -53,6 +53,10 @@ class AthenaResultSet(CursorIterator): https://docs.aws.amazon.com/athena/latest/APIReference/API_GetQueryResults.html """ + # https://docs.aws.amazon.com/athena/latest/ug/data-types.html + # Athena complex types that benefit from type hint conversion. + _COMPLEX_TYPES: frozenset[str] = frozenset({"array", "map", "row", "struct"}) + def __init__( self, connection: Connection[Any], @@ -425,10 +429,17 @@ def _process_metadata(self, response: dict[str, Any]) -> None: if column_info is None: raise DataError("KeyError `ColumnInfo`") self._metadata = tuple(column_info) - if self._result_set_type_hints: - self._column_type_hints = tuple( - self._result_set_type_hints.get(m.get("Name", "").lower()) for m in self._metadata + if self._result_set_type_hints and any( + m.get("Type", "").lower() in self._COMPLEX_TYPES for m in self._metadata + ): + hints = tuple( + self._result_set_type_hints.get(m.get("Name", "").lower()) + if m.get("Type", "").lower() in self._COMPLEX_TYPES + else None + for m in self._metadata ) + if any(hints): + self._column_type_hints = hints def _process_update_count(self, response: dict[str, Any]) -> None: update_count = response.get("UpdateCount") From 989aab2f6e25dd4f068a5152ea2b72c85d46bd81 Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Sat, 28 Feb 2026 20:06:00 +0900 Subject: [PATCH 16/18] Fix nested array JSON detection and pre-compute column metadata - Add '[[' to JSON detection guard in _convert_typed_array so nested arrays like [[1,2],[3]] are parsed via json.loads instead of falling through to native format (which returns None for nested arrays). - Pre-compute _column_types and _column_names tuples once in _process_metadata. 
Use them in _get_rows to eliminate per-cell meta.get("Type") and meta.get("Name") dict lookups. - S3FSResultSet._fetch() reuses _column_types from parent instead of rebuilding from self.description on every call. Co-Authored-By: Claude Opus 4.6 --- pyathena/parser.py | 2 +- pyathena/result_set.py | 62 +++++++++++++++++++++++++++-------- pyathena/s3fs/result_set.py | 14 ++++---- tests/pyathena/test_parser.py | 9 +++++ 4 files changed, 66 insertions(+), 21 deletions(-) diff --git a/pyathena/parser.py b/pyathena/parser.py index 09ef6da8..6b8f7c12 100644 --- a/pyathena/parser.py +++ b/pyathena/parser.py @@ -279,7 +279,7 @@ def _convert_typed_array(self, value: str, type_node: TypeNode) -> list[Any] | N # Try JSON first (only if content looks like JSON) inner_preview = value[1:10] if len(value) > 10 else value[1:-1] - if '"' in inner_preview or value.startswith(("[{", "[null")): + if '"' in inner_preview or value.startswith(("[{", "[null", "[[")): try: parsed = json.loads(value) if isinstance(parsed, list): diff --git a/pyathena/result_set.py b/pyathena/result_set.py index 5c892501..9248e1a2 100644 --- a/pyathena/result_set.py +++ b/pyathena/result_set.py @@ -87,6 +87,8 @@ def __init__( ) self._metadata: tuple[dict[str, Any], ...] | None = None + self._column_types: tuple[str, ...] | None = None + self._column_names: tuple[str, ...] | None = None self._column_type_hints: tuple[str | None, ...] | None = None self._rows: collections.deque[tuple[Any | None, ...] 
| dict[Any, Any | None]] = ( collections.deque() @@ -429,14 +431,16 @@ def _process_metadata(self, response: dict[str, Any]) -> None: if column_info is None: raise DataError("KeyError `ColumnInfo`") self._metadata = tuple(column_info) + self._column_types = tuple(m.get("Type", "") for m in self._metadata) + self._column_names = tuple(m.get("Name", "") for m in self._metadata) if self._result_set_type_hints and any( - m.get("Type", "").lower() in self._COMPLEX_TYPES for m in self._metadata + t.lower() in self._COMPLEX_TYPES for t in self._column_types ): hints = tuple( self._result_set_type_hints.get(m.get("Name", "").lower()) - if m.get("Type", "").lower() in self._COMPLEX_TYPES + if t.lower() in self._COMPLEX_TYPES else None - for m in self._metadata + for m, t in zip(self._metadata, self._column_types, strict=True) ) if any(hints): self._column_type_hints = hints @@ -465,19 +469,28 @@ def _get_rows( converter: Converter | None = None, ) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]: conv = converter or self._converter + col_types = self._column_types col_hints = self._column_type_hints - if col_hints: + if col_hints and col_types: return [ tuple( - conv.convert(meta.get("Type"), row.get("VarCharValue"), type_hint=hint) + conv.convert(col_type, row.get("VarCharValue"), type_hint=hint) if hint - else conv.convert(meta.get("Type"), row.get("VarCharValue")) - for meta, row, hint in zip( - metadata, rows[i].get("Data", []), col_hints, strict=False + else conv.convert(col_type, row.get("VarCharValue")) + for col_type, row, hint in zip( + col_types, rows[i].get("Data", []), col_hints, strict=False ) ) for i in range(offset, len(rows)) ] + if col_types: + return [ + tuple( + conv.convert(col_type, row.get("VarCharValue")) + for col_type, row in zip(col_types, rows[i].get("Data", []), strict=False) + ) + for i in range(offset, len(rows)) + ] return [ tuple( conv.convert(meta.get("Type"), row.get("VarCharValue")) @@ -639,6 +652,8 @@ def close(self) -> None: 
self._connection = None self._query_execution = None self._metadata = None + self._column_types = None + self._column_names = None self._rows.clear() self._next_token = None self._rownumber = None @@ -663,18 +678,37 @@ def _get_rows( converter: Converter | None = None, ) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]: conv = converter or self._converter + col_types = self._column_types + col_names = self._column_names col_hints = self._column_type_hints - if col_hints: + if col_hints and col_types and col_names: return [ self.dict_type( ( - meta.get("Name"), - conv.convert(meta.get("Type"), row.get("VarCharValue"), type_hint=hint) + name, + conv.convert(col_type, row.get("VarCharValue"), type_hint=hint) if hint - else conv.convert(meta.get("Type"), row.get("VarCharValue")), + else conv.convert(col_type, row.get("VarCharValue")), + ) + for name, col_type, row, hint in zip( + col_names, + col_types, + rows[i].get("Data", []), + col_hints, + strict=False, + ) + ) + for i in range(offset, len(rows)) + ] + if col_types and col_names: + return [ + self.dict_type( + ( + name, + conv.convert(col_type, row.get("VarCharValue")), ) - for meta, row, hint in zip( - metadata, rows[i].get("Data", []), col_hints, strict=False + for name, col_type, row in zip( + col_names, col_types, rows[i].get("Data", []), strict=False ) ) for i in range(offset, len(rows)) diff --git a/pyathena/s3fs/result_set.py b/pyathena/s3fs/result_set.py index fe15e97d..84e3e3e0 100644 --- a/pyathena/s3fs/result_set.py +++ b/pyathena/s3fs/result_set.py @@ -149,8 +149,10 @@ def _fetch(self) -> None: if not self._csv_reader: return - description = self.description if self.description else [] - column_types = [d[1] for d in description] + col_types = self._column_types + if not col_types: + description = self.description if self.description else [] + col_types = tuple(d[1] for d in description) col_hints = self._column_type_hints rows_fetched = 0 @@ -171,12 +173,12 @@ def _fetch(self) -> None: ) if 
hint else self._converter.convert(col_type, value if value != "" else None) - for col_type, value, hint in zip(column_types, row, col_hints, strict=False) + for col_type, value, hint in zip(col_types, row, col_hints, strict=False) ) else: converted_row = tuple( self._converter.convert(col_type, value if value != "" else None) - for col_type, value in zip(column_types, row, strict=False) + for col_type, value in zip(col_types, row, strict=False) ) else: if col_hints: @@ -184,12 +186,12 @@ def _fetch(self) -> None: self._converter.convert(col_type, value, type_hint=hint) if hint else self._converter.convert(col_type, value) - for col_type, value, hint in zip(column_types, row, col_hints, strict=False) + for col_type, value, hint in zip(col_types, row, col_hints, strict=False) ) else: converted_row = tuple( self._converter.convert(col_type, value) - for col_type, value in zip(column_types, row, strict=False) + for col_type, value in zip(col_types, row, strict=False) ) self._rows.append(converted_row) rows_fetched += 1 diff --git a/tests/pyathena/test_parser.py b/tests/pyathena/test_parser.py index 1ae5256d..7a29c5f3 100644 --- a/tests/pyathena/test_parser.py +++ b/tests/pyathena/test_parser.py @@ -180,6 +180,15 @@ def test_struct_json_name_based_type_matching(self, converter): assert isinstance(result["age"], int) assert isinstance(result["name"], str) + def test_nested_array_json(self, converter): + """JSON path: nested array like [[1,2],[3]] must be parsed via json.loads.""" + parser = TypeSignatureParser() + node = parser.parse("array(array(integer))") + result = converter.convert("[[1, 2], [3]]", node) + assert result == [[1, 2], [3]] + assert isinstance(result[0], list) + assert isinstance(result[0][0], int) + def test_map_json_null_value_preserved(self, converter): """JSON path: map with null values vs "null" string values.""" parser = TypeSignatureParser() From ab33667ca19c9f44fbbd2a79f46d5ec276188a8e Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Sat, 28 
Feb 2026 20:52:19 +0900
Subject: [PATCH 17/18] Add Hive syntax support, type aliases, parse fallback,
 and index-based hints

- Normalize Hive-style DDL syntax (array<struct<...>>) to Trino-style so
  users can paste DESCRIBE TABLE output directly as type hints
- Resolve type alias "int" to "integer" in the parser
- Fall back to untyped conversion when typed converter returns None,
  preventing silent data loss on parse failures
- Support integer keys in result_set_type_hints for index-based column
  resolution, enabling hints for duplicate column names
- Update type annotations across all cursor/result_set files

Co-Authored-By: Claude Opus 4.6
---
 docs/usage.md                    | 32 +++++++++++++++
 pyathena/aio/cursor.py           |  2 +-
 pyathena/aio/result_set.py       |  4 +-
 pyathena/arrow/async_cursor.py   |  4 +-
 pyathena/arrow/cursor.py         |  2 +-
 pyathena/arrow/result_set.py     |  2 +-
 pyathena/async_cursor.py         |  4 +-
 pyathena/converter.py            | 32 +++++++++++----
 pyathena/cursor.py               |  2 +-
 pyathena/pandas/async_cursor.py  |  4 +-
 pyathena/pandas/cursor.py        |  2 +-
 pyathena/pandas/result_set.py    |  2 +-
 pyathena/parser.py               | 32 ++++++++++++++-
 pyathena/polars/async_cursor.py  |  4 +-
 pyathena/polars/cursor.py        |  2 +-
 pyathena/polars/result_set.py    |  2 +-
 pyathena/result_set.py           | 47 ++++++++++++++++------
 pyathena/s3fs/async_cursor.py    |  4 +-
 pyathena/s3fs/cursor.py          |  2 +-
 pyathena/s3fs/result_set.py      |  2 +-
 tests/pyathena/test_converter.py | 56 ++++++++++++++++++++++++++
 tests/pyathena/test_parser.py    | 67 +++++++++++++++++++++++++++++++-
 22 files changed, 268 insertions(+), 42 deletions(-)

diff --git a/docs/usage.md b/docs/usage.md
index 80c24d71..549d1d2c 100644
--- a/docs/usage.md
+++ b/docs/usage.md
@@ -441,6 +441,38 @@ positions = row[0]
 # positions["x"] == 4.736 (float, not "4.736")
 ```
 
+### Hive-style syntax
+
+You can paste type signatures from Hive DDL or ``DESCRIBE TABLE`` output directly.
+Hive-style angle brackets and colons are automatically converted to Trino-style syntax:
+
+```python
+# Both are equivalent:
+result_set_type_hints={"col": "array(struct(a integer, b varchar))"}  # Trino
+result_set_type_hints={"col": "array<struct<a:int, b:varchar>>"}  # Hive
+```
+
+The ``int`` alias is also supported and resolves to ``integer``.
+
+### Index-based hints for duplicate column names
+
+When a query produces columns with the same alias (e.g. ``SELECT a AS x, b AS x``),
+name-based hints cannot distinguish between them. Use integer keys to specify hints
+by zero-based column position:
+
+```python
+cursor.execute(
+    "SELECT a AS x, b AS x FROM my_table",
+    result_set_type_hints={
+        0: "array(integer)",  # first "x" column
+        1: "map(varchar, integer)",  # second "x" column
+    },
+)
+```
+
+Integer (index-based) hints take priority over string (name-based) hints for the same
+column. You can mix both styles in the same dictionary.
+
 ### Constraints
 
 * **Nested arrays in native format** — Athena's native (non-JSON) string representation
diff --git a/pyathena/aio/cursor.py b/pyathena/aio/cursor.py
index d7998857..07bf507e 100644
--- a/pyathena/aio/cursor.py
+++ b/pyathena/aio/cursor.py
@@ -79,7 +79,7 @@ async def execute(  # type: ignore[override]
         result_reuse_enable: bool | None = None,
         result_reuse_minutes: int | None = None,
         paramstyle: str | None = None,
-        result_set_type_hints: dict[str, str] | None = None,
+        result_set_type_hints: dict[str | int, str] | None = None,
         **kwargs,
     ) -> AioCursor:
         """Execute a SQL query asynchronously.
diff --git a/pyathena/aio/result_set.py b/pyathena/aio/result_set.py index 97f4ad2b..47fe6508 100644 --- a/pyathena/aio/result_set.py +++ b/pyathena/aio/result_set.py @@ -35,7 +35,7 @@ def __init__( query_execution: AthenaQueryExecution, arraysize: int, retry_config: RetryConfig, - result_set_type_hints: dict[str, str] | None = None, + result_set_type_hints: dict[str | int, str] | None = None, ) -> None: super().__init__( connection=connection, @@ -55,7 +55,7 @@ async def create( query_execution: AthenaQueryExecution, arraysize: int, retry_config: RetryConfig, - result_set_type_hints: dict[str, str] | None = None, + result_set_type_hints: dict[str | int, str] | None = None, ) -> AthenaAioResultSet: """Async factory method. diff --git a/pyathena/arrow/async_cursor.py b/pyathena/arrow/async_cursor.py index bf0e538b..da84b5c9 100644 --- a/pyathena/arrow/async_cursor.py +++ b/pyathena/arrow/async_cursor.py @@ -149,7 +149,7 @@ def arraysize(self, value: int) -> None: def _collect_result_set( self, query_id: str, - result_set_type_hints: dict[str, str] | None = None, + result_set_type_hints: dict[str | int, str] | None = None, unload_location: str | None = None, kwargs: dict[str, Any] | None = None, ) -> AthenaArrowResultSet: @@ -181,7 +181,7 @@ def execute( result_reuse_enable: bool | None = None, result_reuse_minutes: int | None = None, paramstyle: str | None = None, - result_set_type_hints: dict[str, str] | None = None, + result_set_type_hints: dict[str | int, str] | None = None, **kwargs, ) -> tuple[str, Future[AthenaArrowResultSet | Any]]: operation, unload_location = self._prepare_unload(operation, s3_staging_dir) diff --git a/pyathena/arrow/cursor.py b/pyathena/arrow/cursor.py index 1876eebb..9d3d879f 100644 --- a/pyathena/arrow/cursor.py +++ b/pyathena/arrow/cursor.py @@ -137,7 +137,7 @@ def execute( result_reuse_minutes: int | None = None, paramstyle: str | None = None, on_start_query_execution: Callable[[str], None] | None = None, - result_set_type_hints: 
dict[str, str] | None = None, + result_set_type_hints: dict[str | int, str] | None = None, **kwargs, ) -> ArrowCursor: """Execute a SQL query and return results as Apache Arrow Tables. diff --git a/pyathena/arrow/result_set.py b/pyathena/arrow/result_set.py index 33027497..922159e4 100644 --- a/pyathena/arrow/result_set.py +++ b/pyathena/arrow/result_set.py @@ -91,7 +91,7 @@ def __init__( unload_location: str | None = None, connect_timeout: float | None = None, request_timeout: float | None = None, - result_set_type_hints: dict[str, str] | None = None, + result_set_type_hints: dict[str | int, str] | None = None, **kwargs, ) -> None: super().__init__( diff --git a/pyathena/async_cursor.py b/pyathena/async_cursor.py index 986a413c..8ba89073 100644 --- a/pyathena/async_cursor.py +++ b/pyathena/async_cursor.py @@ -147,7 +147,7 @@ def poll(self, query_id: str) -> Future[AthenaQueryExecution]: def _collect_result_set( self, query_id: str, - result_set_type_hints: dict[str, str] | None = None, + result_set_type_hints: dict[str | int, str] | None = None, ) -> AthenaResultSet: query_execution = cast(AthenaQueryExecution, self._poll(query_id)) return self._result_set_class( @@ -170,7 +170,7 @@ def execute( result_reuse_enable: bool | None = None, result_reuse_minutes: int | None = None, paramstyle: str | None = None, - result_set_type_hints: dict[str, str] | None = None, + result_set_type_hints: dict[str | int, str] | None = None, **kwargs, ) -> tuple[str, Future[AthenaResultSet | Any]]: """Execute a SQL query asynchronously. 
diff --git a/pyathena/converter.py b/pyathena/converter.py
index 3a486783..26c43f77 100644
--- a/pyathena/converter.py
+++ b/pyathena/converter.py
@@ -12,7 +12,13 @@
 
 from dateutil.tz import gettz
 
-from pyathena.parser import TypedValueConverter, TypeNode, TypeSignatureParser, _split_array_items
+from pyathena.parser import (
+    TypedValueConverter,
+    TypeNode,
+    TypeSignatureParser,
+    _normalize_hive_syntax,
+    _split_array_items,
+)
 from pyathena.util import strtobool
 
 _logger = logging.getLogger(__name__)
@@ -559,8 +565,9 @@ def convert(self, type_: str, value: str | None, type_hint: str | None = None) -
         """Convert a string value to the appropriate Python type.
 
         When ``type_hint`` is provided, uses the typed converter for precise
-        conversion of complex types. Otherwise, uses the standard converter
-        for the given Athena type.
+        conversion of complex types. If the typed converter returns ``None``
+        (indicating a parse failure), falls back to the standard untyped
+        converter so that data is never silently lost.
 
         Args:
             type_: The Athena data type name (e.g., "integer", "varchar", "array").
@@ -575,19 +582,30 @@ def convert(self, type_: str, value: str | None, type_hint: str | None = None) -
             return None
         if type_hint:
             type_node = self._parse_type_hint(type_hint)
-            return self._typed_converter.convert(value, type_node)
+            result = self._typed_converter.convert(value, type_node)
+            if result is not None:
+                return result
+            # Typed conversion returned None — this means a parse failure
+            # (actual SQL NULLs are caught by the `value is None` check above).
+            # Fall back to untyped conversion to avoid silent data loss.
+            return self.get(type_)(value)
         converter = self.get(type_)
         return converter(value)
 
     def _parse_type_hint(self, type_hint: str) -> TypeNode:
         """Parse a type hint string into a TypeNode, with caching.
 
+        Normalizes Hive-style syntax (``array<integer>``) to Trino-style
+        (``array(integer)``) before parsing, so both syntaxes share the
+        same cache entry.
+ Args: type_hint: Athena DDL type signature string. Returns: Parsed TypeNode. """ - if type_hint not in self._parsed_hints: - self._parsed_hints[type_hint] = self._parser.parse(type_hint) - return self._parsed_hints[type_hint] + normalized = _normalize_hive_syntax(type_hint) + if normalized not in self._parsed_hints: + self._parsed_hints[normalized] = self._parser.parse(normalized) + return self._parsed_hints[normalized] diff --git a/pyathena/cursor.py b/pyathena/cursor.py index c730f9bc..d113b387 100644 --- a/pyathena/cursor.py +++ b/pyathena/cursor.py @@ -95,7 +95,7 @@ def execute( result_reuse_minutes: int | None = None, paramstyle: str | None = None, on_start_query_execution: Callable[[str], None] | None = None, - result_set_type_hints: dict[str, str] | None = None, + result_set_type_hints: dict[str | int, str] | None = None, **kwargs, ) -> Cursor: """Execute a SQL query. diff --git a/pyathena/pandas/async_cursor.py b/pyathena/pandas/async_cursor.py index faebe32f..db9f3f9e 100644 --- a/pyathena/pandas/async_cursor.py +++ b/pyathena/pandas/async_cursor.py @@ -118,7 +118,7 @@ def arraysize(self, value: int) -> None: def _collect_result_set( self, query_id: str, - result_set_type_hints: dict[str, str] | None = None, + result_set_type_hints: dict[str | int, str] | None = None, keep_default_na: bool = False, na_values: Iterable[str] | None = ("",), quoting: int = 1, @@ -156,7 +156,7 @@ def execute( result_reuse_enable: bool | None = None, result_reuse_minutes: int | None = None, paramstyle: str | None = None, - result_set_type_hints: dict[str, str] | None = None, + result_set_type_hints: dict[str | int, str] | None = None, keep_default_na: bool = False, na_values: Iterable[str] | None = ("",), quoting: int = 1, diff --git a/pyathena/pandas/cursor.py b/pyathena/pandas/cursor.py index 4413fea9..22a7d8ac 100644 --- a/pyathena/pandas/cursor.py +++ b/pyathena/pandas/cursor.py @@ -153,7 +153,7 @@ def execute( na_values: Iterable[str] | None = ("",), quoting: int = 1, 
         on_start_query_execution: Callable[[str], None] | None = None,
-        result_set_type_hints: dict[str, str] | None = None,
+        result_set_type_hints: dict[str | int, str] | None = None,
         **kwargs,
     ) -> PandasCursor:
         """Execute a SQL query and return results as pandas DataFrames.
diff --git a/pyathena/pandas/result_set.py b/pyathena/pandas/result_set.py
index c88b295d..592fdb2d 100644
--- a/pyathena/pandas/result_set.py
+++ b/pyathena/pandas/result_set.py
@@ -229,7 +229,7 @@ def __init__(
         cache_type: str | None = None,
         max_workers: int = (cpu_count() or 1) * 5,
         auto_optimize_chunksize: bool = False,
-        result_set_type_hints: dict[str, str] | None = None,
+        result_set_type_hints: dict[str | int, str] | None = None,
         **kwargs,
     ) -> None:
         """Initialize AthenaPandasResultSet with pandas-specific configurations.
diff --git a/pyathena/parser.py b/pyathena/parser.py
index 6b8f7c12..ec358dee 100644
--- a/pyathena/parser.py
+++ b/pyathena/parser.py
@@ -1,10 +1,38 @@
 from __future__ import annotations
 
 import json
+import re
 from collections.abc import Callable
 from dataclasses import dataclass, field
 from typing import Any
 
+# Aliases for Athena type names that differ between Hive DDL and Trino DDL.
+_TYPE_ALIASES: dict[str, str] = {
+    "int": "integer",
+}
+
+# Pattern for normalizing Hive-style type signatures to Trino-style.
+# Matches angle brackets and colons used in Hive DDL (e.g., array<struct<a:int>>).
+_HIVE_SYNTAX_RE: re.Pattern[str] = re.compile(r"[<>:]")
+_HIVE_REPLACEMENTS: dict[str, str] = {"<": "(", ">": ")", ":": " "}
+
+
+def _normalize_hive_syntax(type_str: str) -> str:
+    """Normalize Hive-style DDL syntax to Trino-style.
+
+    Converts angle-bracket notation (``array<struct<a:int>>``) to
+    parenthesized notation (``array(struct(a int))``).
+
+    Args:
+        type_str: Type signature string, possibly using Hive syntax.
+
+    Returns:
+        Normalized type signature using Trino-style parenthesized notation.
+ """ + if "<" not in type_str: + return type_str + return _HIVE_SYNTAX_RE.sub(lambda m: _HIVE_REPLACEMENTS[m.group()], type_str) + def _split_array_items(inner: str) -> list[str]: """Split array items by comma, respecting brace and bracket groupings. @@ -96,9 +124,11 @@ def parse(self, type_str: str) -> TypeNode: paren_idx = type_str.find("(") if paren_idx == -1: - return TypeNode(type_name=type_str.lower()) + name = type_str.lower() + return TypeNode(type_name=_TYPE_ALIASES.get(name, name)) type_name = type_str[:paren_idx].strip().lower() + type_name = _TYPE_ALIASES.get(type_name, type_name) inner = type_str[paren_idx + 1 : -1].strip() diff --git a/pyathena/polars/async_cursor.py b/pyathena/polars/async_cursor.py index c973b736..3774a47e 100644 --- a/pyathena/polars/async_cursor.py +++ b/pyathena/polars/async_cursor.py @@ -161,7 +161,7 @@ def arraysize(self, value: int) -> None: def _collect_result_set( self, query_id: str, - result_set_type_hints: dict[str, str] | None = None, + result_set_type_hints: dict[str | int, str] | None = None, unload_location: str | None = None, kwargs: dict[str, Any] | None = None, ) -> AthenaPolarsResultSet: @@ -195,7 +195,7 @@ def execute( result_reuse_enable: bool | None = None, result_reuse_minutes: int | None = None, paramstyle: str | None = None, - result_set_type_hints: dict[str, str] | None = None, + result_set_type_hints: dict[str | int, str] | None = None, **kwargs, ) -> tuple[str, Future[AthenaPolarsResultSet | Any]]: """Execute a SQL query asynchronously and return results as Polars DataFrames. 
diff --git a/pyathena/polars/cursor.py b/pyathena/polars/cursor.py index 0b95ba8f..efc738c4 100644 --- a/pyathena/polars/cursor.py +++ b/pyathena/polars/cursor.py @@ -157,7 +157,7 @@ def execute( result_reuse_minutes: int | None = None, paramstyle: str | None = None, on_start_query_execution: Callable[[str], None] | None = None, - result_set_type_hints: dict[str, str] | None = None, + result_set_type_hints: dict[str | int, str] | None = None, **kwargs, ) -> PolarsCursor: """Execute a SQL query and return results as Polars DataFrames. diff --git a/pyathena/polars/result_set.py b/pyathena/polars/result_set.py index aedc8e33..adc4e2ac 100644 --- a/pyathena/polars/result_set.py +++ b/pyathena/polars/result_set.py @@ -202,7 +202,7 @@ def __init__( cache_type: str | None = None, max_workers: int = (cpu_count() or 1) * 5, chunksize: int | None = None, - result_set_type_hints: dict[str, str] | None = None, + result_set_type_hints: dict[str | int, str] | None = None, **kwargs, ) -> None: """Initialize the Polars result set. 
diff --git a/pyathena/result_set.py b/pyathena/result_set.py index 9248e1a2..ea1fe974 100644 --- a/pyathena/result_set.py +++ b/pyathena/result_set.py @@ -65,7 +65,7 @@ def __init__( arraysize: int, retry_config: RetryConfig, _pre_fetch: bool = True, - result_set_type_hints: dict[str, str] | None = None, + result_set_type_hints: dict[str | int, str] | None = None, ) -> None: super().__init__(arraysize=arraysize) self._connection: Connection[Any] | None = connection @@ -74,11 +74,14 @@ def __init__( if not self._query_execution: raise ProgrammingError("Required argument `query_execution` not found.") self._retry_config = retry_config - self._result_set_type_hints = ( - {k.lower(): v for k, v in result_set_type_hints.items()} - if result_set_type_hints - else None - ) + self._hints_by_name: dict[str, str] = {} + self._hints_by_index: dict[int, str] = {} + if result_set_type_hints: + for k, v in result_set_type_hints.items(): + if isinstance(k, int): + self._hints_by_index[k] = v + else: + self._hints_by_name[k.lower()] = v self._client = connection.session.client( "s3", region_name=connection.region_name, @@ -433,18 +436,40 @@ def _process_metadata(self, response: dict[str, Any]) -> None: self._metadata = tuple(column_info) self._column_types = tuple(m.get("Type", "") for m in self._metadata) self._column_names = tuple(m.get("Name", "") for m in self._metadata) - if self._result_set_type_hints and any( + if (self._hints_by_name or self._hints_by_index) and any( t.lower() in self._COMPLEX_TYPES for t in self._column_types ): hints = tuple( - self._result_set_type_hints.get(m.get("Name", "").lower()) - if t.lower() in self._COMPLEX_TYPES - else None - for m, t in zip(self._metadata, self._column_types, strict=True) + self._resolve_type_hint(i, m.get("Name", "").lower(), t.lower()) + for i, (m, t) in enumerate(zip(self._metadata, self._column_types, strict=True)) ) if any(hints): self._column_type_hints = hints + def _resolve_type_hint( + self, index: int, 
col_name_lower: str, col_type_lower: str + ) -> str | None: + """Look up the type hint for a column by index then by name. + + Index-based hints take priority over name-based hints, allowing + callers to disambiguate duplicate column names. + + Args: + index: Zero-based column position. + col_name_lower: Lowercased column name from metadata. + col_type_lower: Lowercased column type from metadata. + + Returns: + The type hint string, or None if the column has no hint or + is not a complex type. + """ + if col_type_lower not in self._COMPLEX_TYPES: + return None + hint = self._hints_by_index.get(index) + if hint is not None: + return hint + return self._hints_by_name.get(col_name_lower) + def _process_update_count(self, response: dict[str, Any]) -> None: update_count = response.get("UpdateCount") if ( diff --git a/pyathena/s3fs/async_cursor.py b/pyathena/s3fs/async_cursor.py index c4d3c3fe..73acede2 100644 --- a/pyathena/s3fs/async_cursor.py +++ b/pyathena/s3fs/async_cursor.py @@ -142,7 +142,7 @@ def arraysize(self, value: int) -> None: def _collect_result_set( self, query_id: str, - result_set_type_hints: dict[str, str] | None = None, + result_set_type_hints: dict[str | int, str] | None = None, kwargs: dict[str, Any] | None = None, ) -> AthenaS3FSResultSet: """Collect result set after query execution. @@ -182,7 +182,7 @@ def execute( result_reuse_enable: bool | None = None, result_reuse_minutes: int | None = None, paramstyle: str | None = None, - result_set_type_hints: dict[str, str] | None = None, + result_set_type_hints: dict[str | int, str] | None = None, **kwargs, ) -> tuple[str, Future[AthenaS3FSResultSet | Any]]: """Execute a SQL query asynchronously. 
diff --git a/pyathena/s3fs/cursor.py b/pyathena/s3fs/cursor.py index c1bf47e4..dfc5dd5e 100644 --- a/pyathena/s3fs/cursor.py +++ b/pyathena/s3fs/cursor.py @@ -133,7 +133,7 @@ def execute( result_reuse_minutes: int | None = None, paramstyle: str | None = None, on_start_query_execution: Callable[[str], None] | None = None, - result_set_type_hints: dict[str, str] | None = None, + result_set_type_hints: dict[str | int, str] | None = None, **kwargs, ) -> S3FSCursor: """Execute a SQL query and return results. diff --git a/pyathena/s3fs/result_set.py b/pyathena/s3fs/result_set.py index 84e3e3e0..04828c00 100644 --- a/pyathena/s3fs/result_set.py +++ b/pyathena/s3fs/result_set.py @@ -64,7 +64,7 @@ def __init__( block_size: int | None = None, csv_reader: CSVReaderType | None = None, filesystem_class: type[AbstractFileSystem] | None = None, - result_set_type_hints: dict[str, str] | None = None, + result_set_type_hints: dict[str | int, str] | None = None, **kwargs, ) -> None: super().__init__( diff --git a/tests/pyathena/test_converter.py b/tests/pyathena/test_converter.py index 1c4ff3dc..39583a24 100644 --- a/tests/pyathena/test_converter.py +++ b/tests/pyathena/test_converter.py @@ -352,3 +352,59 @@ def test_row_with_nested_struct(self): type_hint="row(header row(seq integer, stamp varchar), x double)", ) assert result == {"header": {"seq": 123, "stamp": "2024"}, "x": 4.5} + + def test_fallback_on_malformed_value(self): + """When typed conversion fails (returns None), fall back to untyped conversion.""" + converter = DefaultTypeConverter() + # "not-an-array" doesn't look like an array — typed converter returns None. + # Untyped _to_array also returns None for this input, which is correct. 
+ result = converter.convert("array", "not-an-array", type_hint="array(integer)") + assert result is None + + def test_fallback_preserves_struct_value(self): + """Malformed struct with type_hint still falls back to untyped parsing.""" + converter = DefaultTypeConverter() + # Struct with no closing brace — typed converter returns None. + # Untyped _to_struct also returns None here. + result = converter.convert("row", "{unclosed", type_hint="row(a integer)") + assert result is None + + def test_fallback_returns_untyped_result(self): + """When typed conversion returns None, untyped conversion is used.""" + converter = DefaultTypeConverter() + # The typed converter returns None for a struct that doesn't start with "{". + # The untyped _to_struct also returns None for non-struct input. + # Use an array example where typed converter returns None (not a bracket-wrapped + # value), but untyped _to_array can still parse it via JSON. + result = converter.convert( + "row", + '{"a": 1}', + type_hint="row(a varchar)", + ) + # Typed conversion succeeds here — "a" is varchar so "1" stays a string + assert result == {"a": "1"} + + def test_hive_syntax_through_converter(self): + """Hive-style syntax works end-to-end through DefaultTypeConverter.""" + converter = DefaultTypeConverter() + result = converter.convert("array", "[1, 2, 3]", type_hint="array") + assert result == [1, 2, 3] + + def test_hive_syntax_struct_through_converter(self): + """Hive struct syntax works end-to-end.""" + converter = DefaultTypeConverter() + result = converter.convert( + "row", + "{name=Alice, age=25}", + type_hint="struct", + ) + assert result == {"name": "Alice", "age": 25} + + def test_hive_syntax_caching(self): + """Hive syntax is normalized before cache lookup.""" + converter = DefaultTypeConverter() + converter.convert("array", "[1]", type_hint="array") + converter.convert("array", "[2]", type_hint="array(integer)") + # Both should normalize to "array(integer)" in the cache + assert 
"array(integer)" in converter._parsed_hints + assert len(converter._parsed_hints) == 1 diff --git a/tests/pyathena/test_parser.py b/tests/pyathena/test_parser.py index 7a29c5f3..e3981476 100644 --- a/tests/pyathena/test_parser.py +++ b/tests/pyathena/test_parser.py @@ -1,7 +1,12 @@ import pytest from pyathena.converter import _DEFAULT_CONVERTERS, _to_default, _to_struct -from pyathena.parser import TypedValueConverter, TypeNode, TypeSignatureParser +from pyathena.parser import ( + TypedValueConverter, + TypeNode, + TypeSignatureParser, + _normalize_hive_syntax, +) class TestTypeSignatureParser: @@ -93,6 +98,66 @@ def test_varchar_with_length(self): node = parser.parse("varchar(255)") assert node.type_name == "varchar" + def test_type_alias_int(self): + parser = TypeSignatureParser() + node = parser.parse("int") + assert node.type_name == "integer" + + def test_type_alias_in_complex_type(self): + parser = TypeSignatureParser() + node = parser.parse("array(int)") + assert node.type_name == "array" + assert node.children[0].type_name == "integer" + + def test_hive_syntax_simple(self): + parser = TypeSignatureParser() + node = parser.parse(_normalize_hive_syntax("array")) + assert node.type_name == "array" + assert node.children[0].type_name == "integer" + + def test_hive_syntax_struct(self): + parser = TypeSignatureParser() + node = parser.parse(_normalize_hive_syntax("struct")) + assert node.type_name == "struct" + assert node.field_names == ["a", "b"] + assert node.children[0].type_name == "integer" + assert node.children[1].type_name == "varchar" + + def test_hive_syntax_nested(self): + parser = TypeSignatureParser() + node = parser.parse(_normalize_hive_syntax("array>")) + assert node.type_name == "array" + struct_node = node.children[0] + assert struct_node.type_name == "struct" + assert struct_node.field_names == ["a", "b"] + assert struct_node.children[0].type_name == "integer" + assert struct_node.children[1].type_name == "varchar" + + def 
test_hive_syntax_map(self):
+ parser = TypeSignatureParser()
+ node = parser.parse(_normalize_hive_syntax("map<string,int>"))
+ assert node.type_name == "map"
+ assert node.children[0].type_name == "string"
+ assert node.children[1].type_name == "integer"
+
+ def test_mixed_syntax(self):
+ """Hive angle brackets wrapping Trino-style parenthesized inner type."""
+ parser = TypeSignatureParser()
+ node = parser.parse(_normalize_hive_syntax("array<row(a integer, b varchar)>"))
+ assert node.type_name == "array"
+ row_node = node.children[0]
+ assert row_node.type_name == "row"
+ assert row_node.field_names == ["a", "b"]
+ assert row_node.children[0].type_name == "integer"
+ assert row_node.children[1].type_name == "varchar"
+
+ def test_normalize_hive_syntax_noop(self):
+ """Trino-style input passes through unchanged."""
+ assert _normalize_hive_syntax("array(integer)") == "array(integer)"
+
+ def test_normalize_hive_syntax_replaces(self):
+ assert _normalize_hive_syntax("array<struct<a:int>>") == "array(struct(a int))"
+

class TestTypedValueConverter:

@pytest.fixture
From add20b48052e923fda416ac5f1d9a8c2f9677014 Mon Sep 17 00:00:00 2001
From: laughingman7743
Date: Sat, 28 Feb 2026 21:23:06 +0900
Subject: [PATCH 18/18] Fix parser robustness: matching paren lookup and unnamed struct split

- Use _find_matching_paren() instead of assuming closing ')' is at end
of string, so trailing modifiers don't break parsing
- Replace naive comma split with _split_array_items() in unnamed struct
path to handle nested values correctly

Closes #693, closes #694.
Co-Authored-By: Claude Opus 4.6 --- pyathena/parser.py | 26 ++++++++++++++++++++++++-- tests/pyathena/test_parser.py | 14 ++++++++++++++ 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/pyathena/parser.py b/pyathena/parser.py index ec358dee..c6b70eb8 100644 --- a/pyathena/parser.py +++ b/pyathena/parser.py @@ -130,7 +130,8 @@ def parse(self, type_str: str) -> TypeNode: type_name = type_str[:paren_idx].strip().lower() type_name = _TYPE_ALIASES.get(type_name, type_name) - inner = type_str[paren_idx + 1 : -1].strip() + close_idx = self._find_matching_paren(type_str, paren_idx) + inner = type_str[paren_idx + 1 : close_idx].strip() if type_name in ("row", "struct"): parts = self._split_type_args(inner) @@ -192,6 +193,27 @@ def _split_type_args(self, s: str) -> list[str]: parts.append("".join(current).strip()) return parts + @staticmethod + def _find_matching_paren(s: str, open_idx: int) -> int: + """Find the index of the closing parenthesis matching the one at *open_idx*. + + Args: + s: The full string. + open_idx: Index of the opening ``(``. + + Returns: + Index of the matching ``)``. + """ + depth = 0 + for i in range(open_idx, len(s)): + if s[i] == "(": + depth += 1 + elif s[i] == ")": + depth -= 1 + if depth == 0: + return i + return len(s) - 1 + def _find_field_name_boundary(self, part: str) -> int: """Find the boundary between field name and type in a row field definition. 
@@ -479,7 +501,7 @@ def _convert_typed_struct(self, value: str, type_node: TypeNode) -> dict[str, An # Unnamed struct field_names = type_node.field_names or [] - values = [v.strip() for v in inner.split(",")] + values = _split_array_items(inner) result = {} for i, v in enumerate(values): ft = field_types[i] if i < len(field_types) else TypeNode("varchar") diff --git a/tests/pyathena/test_parser.py b/tests/pyathena/test_parser.py index e3981476..bd6ab9d7 100644 --- a/tests/pyathena/test_parser.py +++ b/tests/pyathena/test_parser.py @@ -158,6 +158,13 @@ def test_normalize_hive_syntax_noop(self): def test_normalize_hive_syntax_replaces(self): assert _normalize_hive_syntax("array>") == "array(struct(a int))" + def test_trailing_modifier_after_paren(self): + """Type with content after closing paren should not break parsing.""" + parser = TypeSignatureParser() + # Simulates a hypothetical "timestamp(3) with time zone" style input + node = parser.parse("decimal(10, 2) extra") + assert node.type_name == "decimal" + class TestTypedValueConverter: @pytest.fixture @@ -261,3 +268,10 @@ def test_map_json_null_value_preserved(self, converter): result = converter.convert('{"a": null, "b": "null"}', node) assert result["a"] is None assert result["b"] == "null" + + def test_unnamed_struct_with_nested_value(self, converter): + """Unnamed struct split must respect nested braces.""" + parser = TypeSignatureParser() + node = parser.parse("row(inner row(x integer, y integer), val varchar)") + result = converter.convert("{inner={x=1, y=2}, val=hello}", node) + assert result == {"inner": {"x": 1, "y": 2}, "val": "hello"}