diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 00000000..4e09ff5b --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,5 @@ +## WHAT + + +## WHY + diff --git a/CLAUDE.md b/CLAUDE.md index c2725412..1fffd00e 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -42,6 +42,19 @@ export $(cat .env | xargs) && uv run pytest tests/pyathena/test_file.py -v - Use pytest fixtures from `conftest.py` - New features require tests; changes to SQLAlchemy dialects must pass `make test-sqla` +#### Test Conventions +- **Class-based tests** for integration tests that use fixtures (cursors, engines): `class TestCursor:` with methods like `def test_fetchone(self, cursor):` +- **Standalone functions** for unit tests of pure logic (converters, parsers, utils): `def test_to_struct_json_formats(input_value, expected):` +- Test file naming mirrors source: `pyathena/parser.py` → `tests/pyathena/test_parser.py` +- **Fixtures**: Cursor/engine fixtures are defined in `conftest.py` and injected by name (e.g., `cursor`, `engine`, `async_cursor`). Use `indirect=True` parametrization to pass connection options: + ```python + @pytest.mark.parametrize("engine", [{"driver": "rest"}], indirect=True) + def test_query(self, engine): + engine, conn = engine + ``` +- **Parametrize** with `@pytest.mark.parametrize(("input", "expected"), [...])` for data-driven tests +- **Integration tests** (need AWS) use cursor/engine fixtures with real Athena queries; **unit tests** (no AWS) call functions directly with test data + ## Architecture — Key Design Decisions These are non-obvious conventions that can't be discovered by reading code alone. 
diff --git a/docs/usage.md b/docs/usage.md index 8149b8a5..549d1d2c 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -389,6 +389,112 @@ The `on_start_query_execution` callback is supported by the following cursor typ Note: `AsyncCursor` and its variants do not support this callback as they already return the query ID immediately through their different execution model. +## Type hints for complex types + +*New in version 3.30.0.* + +The Athena API does not return element-level type information for complex types +(array, map, row/struct). PyAthena parses the string representation returned by +Athena, but without type metadata the converter can only apply heuristics — which +may produce incorrect Python types for nested values (e.g. integers left as strings +inside a struct). + +The `result_set_type_hints` parameter solves this by letting you provide Athena DDL +type signatures for specific columns. The converter then uses precise, recursive +type-aware conversion instead of heuristics. + +```python +from pyathena import connect + +cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2").cursor() +cursor.execute( + "SELECT col_array, col_map, col_struct FROM one_row_complex", + result_set_type_hints={ + "col_array": "array(integer)", + "col_map": "map(integer, integer)", + "col_struct": "row(a integer, b integer)", + }, +) +row = cursor.fetchone() +# col_struct values are now integers, not strings: +# {"a": 1, "b": 2} instead of {"a": "1", "b": "2"} +``` + +Column name matching is case-insensitive. 
Type hints support arbitrarily nested types:
+
+```python
+cursor.execute(
+    """
+    SELECT CAST(
+        ROW(ROW('2024-01-01', 123), 4.736, 0.583)
+        AS ROW(header ROW(stamp VARCHAR, seq INTEGER), x DOUBLE, y DOUBLE)
+    ) AS positions
+    """,
+    result_set_type_hints={
+        "positions": "row(header row(stamp varchar, seq integer), x double, y double)",
+    },
+)
+row = cursor.fetchone()
+positions = row[0]
+# positions["header"]["seq"] == 123 (int, not "123")
+# positions["x"] == 4.736 (float, not "4.736")
+```
+
+### Hive-style syntax
+
+You can paste type signatures from Hive DDL or ``DESCRIBE TABLE`` output directly.
+Hive-style angle brackets and colons are automatically converted to Trino-style syntax:
+
+```python
+# Both are equivalent:
+result_set_type_hints={"col": "array(struct(a integer, b varchar))"}  # Trino
+result_set_type_hints={"col": "array<struct<a:integer,b:varchar>>"}  # Hive
+```
+
+The ``int`` alias is also supported and resolves to ``integer``.
+
+### Index-based hints for duplicate column names
+
+When a query produces columns with the same alias (e.g. ``SELECT a AS x, b AS x``),
+name-based hints cannot distinguish between them. Use integer keys to specify hints
+by zero-based column position:
+
+```python
+cursor.execute(
+    "SELECT a AS x, b AS x FROM my_table",
+    result_set_type_hints={
+        0: "array(integer)",  # first "x" column
+        1: "map(varchar, integer)",  # second "x" column
+    },
+)
+```
+
+Integer (index-based) hints take priority over string (name-based) hints for the same
+column. You can mix both styles in the same dictionary.
+
+### Constraints
+
+* **Nested arrays in native format** — Athena's native (non-JSON) string representation
+  does not clearly delimit nested arrays. If your query returns nested arrays
+  (e.g. `array(array(integer))`), use `CAST(... AS JSON)` in your query to get
+  JSON-formatted output, which is parsed reliably.
+* **Arrow, Pandas, and Polars cursors** — These cursors accept `result_set_type_hints` + but their converters do not currently use the hints because they rely on their own + type systems. The parameter is passed through for forward compatibility and for + result sets that fall back to the default conversion path. + +### Breaking change in 3.30.0 + +Prior to 3.30.0, PyAthena attempted to infer Python types for scalar values inside +complex types using heuristics (e.g. `"123"` → `123`). Starting with 3.30.0, values +inside complex types are **kept as strings** unless `result_set_type_hints` is provided. +This change avoids silent misconversion but means existing code that relied on the +heuristic behavior may see string values where it previously saw integers or floats. + +To restore typed conversion, pass `result_set_type_hints` with the appropriate type +signatures for the affected columns. + ## Environment variables Support [Boto3 environment variables](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-environment-variables). diff --git a/pyathena/aio/cursor.py b/pyathena/aio/cursor.py index 30738f8f..07bf507e 100644 --- a/pyathena/aio/cursor.py +++ b/pyathena/aio/cursor.py @@ -79,6 +79,7 @@ async def execute( # type: ignore[override] result_reuse_enable: bool | None = None, result_reuse_minutes: int | None = None, paramstyle: str | None = None, + result_set_type_hints: dict[str | int, str] | None = None, **kwargs, ) -> AioCursor: """Execute a SQL query asynchronously. @@ -93,6 +94,9 @@ async def execute( # type: ignore[override] result_reuse_enable: Enable result reuse (optional). result_reuse_minutes: Result reuse duration in minutes (optional). paramstyle: Parameter style to use (optional). + result_set_type_hints: Optional dictionary mapping column names to + Athena DDL type signatures for precise type conversion within + complex types. **kwargs: Additional execution parameters. 
Returns: @@ -119,6 +123,7 @@ async def execute( # type: ignore[override] query_execution, self.arraysize, self._retry_config, + result_set_type_hints=result_set_type_hints, ) else: raise OperationalError(query_execution.state_change_reason) diff --git a/pyathena/aio/result_set.py b/pyathena/aio/result_set.py index 000337bd..47fe6508 100644 --- a/pyathena/aio/result_set.py +++ b/pyathena/aio/result_set.py @@ -35,6 +35,7 @@ def __init__( query_execution: AthenaQueryExecution, arraysize: int, retry_config: RetryConfig, + result_set_type_hints: dict[str | int, str] | None = None, ) -> None: super().__init__( connection=connection, @@ -43,6 +44,7 @@ def __init__( arraysize=arraysize, retry_config=retry_config, _pre_fetch=False, + result_set_type_hints=result_set_type_hints, ) @classmethod @@ -53,6 +55,7 @@ async def create( query_execution: AthenaQueryExecution, arraysize: int, retry_config: RetryConfig, + result_set_type_hints: dict[str | int, str] | None = None, ) -> AthenaAioResultSet: """Async factory method. @@ -64,11 +67,20 @@ async def create( query_execution: Query execution metadata. arraysize: Number of rows to fetch per request. retry_config: Retry configuration for API calls. + result_set_type_hints: Optional dictionary mapping column names to + Athena DDL type signatures for precise type conversion. Returns: A fully initialized ``AthenaAioResultSet``. 
""" - result_set = cls(connection, converter, query_execution, arraysize, retry_config) + result_set = cls( + connection, + converter, + query_execution, + arraysize, + retry_config, + result_set_type_hints=result_set_type_hints, + ) if result_set.state == AthenaQueryExecution.STATE_SUCCEEDED: await result_set._async_pre_fetch() return result_set diff --git a/pyathena/arrow/async_cursor.py b/pyathena/arrow/async_cursor.py index 39d2950b..da84b5c9 100644 --- a/pyathena/arrow/async_cursor.py +++ b/pyathena/arrow/async_cursor.py @@ -149,6 +149,7 @@ def arraysize(self, value: int) -> None: def _collect_result_set( self, query_id: str, + result_set_type_hints: dict[str | int, str] | None = None, unload_location: str | None = None, kwargs: dict[str, Any] | None = None, ) -> AthenaArrowResultSet: @@ -165,6 +166,7 @@ def _collect_result_set( unload_location=unload_location, connect_timeout=self._connect_timeout, request_timeout=self._request_timeout, + result_set_type_hints=result_set_type_hints, **kwargs, ) @@ -179,6 +181,7 @@ def execute( result_reuse_enable: bool | None = None, result_reuse_minutes: int | None = None, paramstyle: str | None = None, + result_set_type_hints: dict[str | int, str] | None = None, **kwargs, ) -> tuple[str, Future[AthenaArrowResultSet | Any]]: operation, unload_location = self._prepare_unload(operation, s3_staging_dir) @@ -198,6 +201,7 @@ def execute( self._executor.submit( self._collect_result_set, query_id, + result_set_type_hints, unload_location, kwargs, ), diff --git a/pyathena/arrow/converter.py b/pyathena/arrow/converter.py index 1e5d5d91..26774186 100644 --- a/pyathena/arrow/converter.py +++ b/pyathena/arrow/converter.py @@ -90,7 +90,7 @@ def _dtypes(self) -> dict[str, type[Any]]: } return self.__dtypes - def convert(self, type_: str, value: str | None) -> Any | None: + def convert(self, type_: str, value: str | None, type_hint: str | None = None) -> Any | None: converter = self.get(type_) return converter(value) @@ -114,5 +114,5 @@ 
def __init__(self) -> None: default=_to_default, ) - def convert(self, type_: str, value: str | None) -> Any | None: + def convert(self, type_: str, value: str | None, type_hint: str | None = None) -> Any | None: pass diff --git a/pyathena/arrow/cursor.py b/pyathena/arrow/cursor.py index 8f831831..9d3d879f 100644 --- a/pyathena/arrow/cursor.py +++ b/pyathena/arrow/cursor.py @@ -137,6 +137,7 @@ def execute( result_reuse_minutes: int | None = None, paramstyle: str | None = None, on_start_query_execution: Callable[[str], None] | None = None, + result_set_type_hints: dict[str | int, str] | None = None, **kwargs, ) -> ArrowCursor: """Execute a SQL query and return results as Apache Arrow Tables. @@ -156,6 +157,9 @@ def execute( result_reuse_minutes: Minutes to reuse cached results. paramstyle: Parameter style ('qmark' or 'pyformat'). on_start_query_execution: Callback called when query starts. + result_set_type_hints: Optional dictionary mapping column names to + Athena DDL type signatures for precise type conversion within + complex types. **kwargs: Additional execution parameters. 
Returns: @@ -197,6 +201,7 @@ def execute( unload_location=unload_location, connect_timeout=self._connect_timeout, request_timeout=self._request_timeout, + result_set_type_hints=result_set_type_hints, **kwargs, ) else: diff --git a/pyathena/arrow/result_set.py b/pyathena/arrow/result_set.py index 12fab1ac..922159e4 100644 --- a/pyathena/arrow/result_set.py +++ b/pyathena/arrow/result_set.py @@ -91,6 +91,7 @@ def __init__( unload_location: str | None = None, connect_timeout: float | None = None, request_timeout: float | None = None, + result_set_type_hints: dict[str | int, str] | None = None, **kwargs, ) -> None: super().__init__( @@ -99,6 +100,7 @@ def __init__( query_execution=query_execution, arraysize=1, # Fetch one row to retrieve metadata retry_config=retry_config, + result_set_type_hints=result_set_type_hints, ) self._rows.clear() # Clear pre_fetch data self._arraysize = arraysize diff --git a/pyathena/async_cursor.py b/pyathena/async_cursor.py index 1716cc3c..8ba89073 100644 --- a/pyathena/async_cursor.py +++ b/pyathena/async_cursor.py @@ -144,7 +144,11 @@ def poll(self, query_id: str) -> Future[AthenaQueryExecution]: """ return cast("Future[AthenaQueryExecution]", self._executor.submit(self._poll, query_id)) - def _collect_result_set(self, query_id: str) -> AthenaResultSet: + def _collect_result_set( + self, + query_id: str, + result_set_type_hints: dict[str | int, str] | None = None, + ) -> AthenaResultSet: query_execution = cast(AthenaQueryExecution, self._poll(query_id)) return self._result_set_class( connection=self._connection, @@ -152,6 +156,7 @@ def _collect_result_set(self, query_id: str) -> AthenaResultSet: query_execution=query_execution, arraysize=self._arraysize, retry_config=self._retry_config, + result_set_type_hints=result_set_type_hints, ) def execute( @@ -165,6 +170,7 @@ def execute( result_reuse_enable: bool | None = None, result_reuse_minutes: int | None = None, paramstyle: str | None = None, + result_set_type_hints: dict[str | int, str] | 
None = None, **kwargs, ) -> tuple[str, Future[AthenaResultSet | Any]]: """Execute a SQL query asynchronously. @@ -183,6 +189,9 @@ def execute( result_reuse_enable: Enable result reuse for identical queries (optional). result_reuse_minutes: Result reuse duration in minutes (optional). paramstyle: Parameter style to use (optional). + result_set_type_hints: Optional dictionary mapping column names to + Athena DDL type signatures for precise type conversion within + complex types. **kwargs: Additional execution parameters. Returns: @@ -207,7 +216,9 @@ def execute( result_reuse_minutes=result_reuse_minutes, paramstyle=paramstyle, ) - return query_id, self._executor.submit(self._collect_result_set, query_id) + return query_id, self._executor.submit( + self._collect_result_set, query_id, result_set_type_hints + ) def executemany( self, diff --git a/pyathena/converter.py b/pyathena/converter.py index ce664c81..26c43f77 100644 --- a/pyathena/converter.py +++ b/pyathena/converter.py @@ -12,6 +12,13 @@ from dateutil.tz import gettz +from pyathena.parser import ( + TypedValueConverter, + TypeNode, + TypeSignatureParser, + _normalize_hive_syntax, + _split_array_items, +) from pyathena.util import strtobool _logger = logging.getLogger(__name__) @@ -266,44 +273,6 @@ def _parse_array_native(inner: str) -> list[Any] | None: return result if result else None -def _split_array_items(inner: str) -> list[str]: - """Split array items by comma, respecting brace and bracket groupings. - - Args: - inner: Interior content of array without brackets. - - Returns: - List of item strings. 
- """ - items = [] - current_item = "" - brace_depth = 0 - bracket_depth = 0 - - for char in inner: - if char == "{": - brace_depth += 1 - elif char == "}": - brace_depth -= 1 - elif char == "[": - bracket_depth += 1 - elif char == "]": - bracket_depth -= 1 - elif char == "," and brace_depth == 0 and bracket_depth == 0: - # Top-level comma - end current item - items.append(current_item.strip()) - current_item = "" - continue - - current_item += char - - # Add the last item - if current_item.strip(): - items.append(current_item.strip()) - - return items - - def _parse_map_native(inner: str) -> dict[str, Any] | None: """Parse map native format: key1=value1, key2=value2. @@ -395,24 +364,22 @@ def _parse_unnamed_struct(inner: str) -> dict[str, Any]: def _convert_value(value: str) -> Any: - """Convert string value to appropriate Python type. + """Convert string value without type inference. + + Returns the string as-is, except for null which becomes None. + This is a safe default that avoids incorrect type conversions + (e.g., converting varchar "1234" to int 1234 inside complex types). + + Use :class:`~pyathena.parser.TypedValueConverter` for type-aware conversion. Args: value: String value to convert. Returns: - Converted value as int, float, bool, None, or string. + None for "null" values, otherwise the original string. """ if value.lower() == "null": return None - if value.lower() == "true": - return True - if value.lower() == "false": - return False - if value.isdigit() or (value.startswith("-") and value[1:].isdigit()): - return int(value) - if "." 
in value and value.replace(".", "", 1).replace("-", "", 1).isdigit(): - return float(value) return value @@ -549,7 +516,7 @@ def update(self, mappings: dict[str, Callable[[str | None], Any | None]]) -> Non self.mappings.update(mappings) @abstractmethod - def convert(self, type_: str, value: str | None) -> Any | None: + def convert(self, type_: str, value: str | None, type_hint: str | None = None) -> Any | None: raise NotImplementedError # pragma: no cover @@ -569,17 +536,76 @@ class DefaultTypeConverter(Converter): - Complex types: array, map, row/struct - JSON: json + When ``type_hint`` is provided (an Athena DDL type signature string like + ``"array(row(name varchar, age integer))"``), nested values within complex + types are converted according to the specified types instead of using + heuristic inference. + Example: >>> converter = DefaultTypeConverter() >>> converter.convert('integer', '42') 42 >>> converter.convert('date', '2023-01-15') datetime.date(2023, 1, 15) + >>> converter.convert('array', '[1, 2, 3]', type_hint='array(varchar)') + ['1', '2', '3'] """ def __init__(self) -> None: super().__init__(mappings=deepcopy(_DEFAULT_CONVERTERS), default=_to_default) + self._parser = TypeSignatureParser() + self._typed_converter = TypedValueConverter( + converters=_DEFAULT_CONVERTERS, + default_converter=_to_default, + struct_parser=_to_struct, + ) + self._parsed_hints: dict[str, TypeNode] = {} + + def convert(self, type_: str, value: str | None, type_hint: str | None = None) -> Any | None: + """Convert a string value to the appropriate Python type. + + When ``type_hint`` is provided, uses the typed converter for precise + conversion of complex types. If the typed converter returns ``None`` + (indicating a parse failure), falls back to the standard untyped + converter so that data is never silently lost. + + Args: + type_: The Athena data type name (e.g., "integer", "varchar", "array"). + value: The string value to convert, or None. 
+            type_hint: Optional Athena DDL type signature for precise complex type
+                conversion (e.g., "array(varchar)", "row(name varchar, age integer)").
-    def convert(self, type_: str, value: str | None) -> Any | None:
+        Returns:
+            The converted Python value, or None if the input value was None.
+        """
+        if value is None:
+            return None
+        if type_hint:
+            type_node = self._parse_type_hint(type_hint)
+            result = self._typed_converter.convert(value, type_node)
+            if result is not None:
+                return result
+            # Typed conversion returned None — this means a parse failure
+            # (actual SQL NULLs are caught by the `value is None` check above).
+            # Fall back to untyped conversion to avoid silent data loss.
+            return self.get(type_)(value)
         converter = self.get(type_)
         return converter(value)
+
+    def _parse_type_hint(self, type_hint: str) -> TypeNode:
+        """Parse a type hint string into a TypeNode, with caching.
+
+        Normalizes Hive-style syntax (``array<integer>``) to Trino-style
+        (``array(integer)``) before parsing, so both syntaxes share the
+        same cache entry.
+
+        Args:
+            type_hint: Athena DDL type signature string.
+
+        Returns:
+            Parsed TypeNode.
+        """
+        normalized = _normalize_hive_syntax(type_hint)
+        if normalized not in self._parsed_hints:
+            self._parsed_hints[normalized] = self._parser.parse(normalized)
+        return self._parsed_hints[normalized]
diff --git a/pyathena/cursor.py b/pyathena/cursor.py
index 9557dc7f..d113b387 100644
--- a/pyathena/cursor.py
+++ b/pyathena/cursor.py
@@ -95,32 +95,36 @@ def execute(
         result_reuse_minutes: int | None = None,
         paramstyle: str | None = None,
         on_start_query_execution: Callable[[str], None] | None = None,
+        result_set_type_hints: dict[str | int, str] | None = None,
         **kwargs,
     ) -> Cursor:
         """Execute a SQL query.
 
         Args:
-            operation: SQL query string to execute
-            parameters: Query parameters (optional)
+            operation: SQL query string to execute.
+            parameters: Query parameters (optional).
on_start_query_execution: Callback function called immediately after - start_query_execution API is called. - Function signature: (query_id: str) -> None - This allows early access to query_id for - monitoring/cancellation. - **kwargs: Additional execution parameters + start_query_execution API is called. + Function signature: (query_id: str) -> None + This allows early access to query_id for + monitoring/cancellation. + result_set_type_hints: Optional dictionary mapping column names to + Athena DDL type signatures for precise type conversion within + complex types. For example: + ``{"tags": "array(varchar)", "metadata": "map(varchar, integer)"}`` + **kwargs: Additional execution parameters. Returns: - Cursor: Self reference for method chaining - - Example with callback for early query ID access: - def on_execution_started(query_id): - print(f"Query execution started: {query_id}") - # Store query_id for potential cancellation from another thread - global current_query_id - current_query_id = query_id - - cursor.execute("SELECT * FROM large_table", - on_start_query_execution=on_execution_started) + Self reference for method chaining. + + Example: + >>> cursor.execute( + ... "SELECT * FROM table_with_complex_types", + ... result_set_type_hints={ + ... "tags": "array(varchar)", + ... "metadata": "map(varchar, integer)", + ... } + ... 
) """ self._reset_state() self.query_id = self._execute( @@ -150,6 +154,7 @@ def on_execution_started(query_id): query_execution, self.arraysize, self._retry_config, + result_set_type_hints=result_set_type_hints, ) else: raise OperationalError(query_execution.state_change_reason) diff --git a/pyathena/pandas/async_cursor.py b/pyathena/pandas/async_cursor.py index 48d3d217..db9f3f9e 100644 --- a/pyathena/pandas/async_cursor.py +++ b/pyathena/pandas/async_cursor.py @@ -118,6 +118,7 @@ def arraysize(self, value: int) -> None: def _collect_result_set( self, query_id: str, + result_set_type_hints: dict[str | int, str] | None = None, keep_default_na: bool = False, na_values: Iterable[str] | None = ("",), quoting: int = 1, @@ -140,6 +141,7 @@ def _collect_result_set( unload_location=unload_location, engine=kwargs.pop("engine", self._engine), chunksize=kwargs.pop("chunksize", self._chunksize), + result_set_type_hints=result_set_type_hints, **kwargs, ) @@ -154,6 +156,7 @@ def execute( result_reuse_enable: bool | None = None, result_reuse_minutes: int | None = None, paramstyle: str | None = None, + result_set_type_hints: dict[str | int, str] | None = None, keep_default_na: bool = False, na_values: Iterable[str] | None = ("",), quoting: int = 1, @@ -176,6 +179,7 @@ def execute( self._executor.submit( self._collect_result_set, query_id, + result_set_type_hints, keep_default_na, na_values, quoting, diff --git a/pyathena/pandas/converter.py b/pyathena/pandas/converter.py index 576d2f53..29519bb3 100644 --- a/pyathena/pandas/converter.py +++ b/pyathena/pandas/converter.py @@ -80,7 +80,7 @@ def _dtypes(self) -> dict[str, type[Any]]: } return self.__dtypes - def convert(self, type_: str, value: str | None) -> Any | None: + def convert(self, type_: str, value: str | None, type_hint: str | None = None) -> Any | None: pass @@ -103,5 +103,5 @@ def __init__(self) -> None: default=_to_default, ) - def convert(self, type_: str, value: str | None) -> Any | None: + def convert(self, type_: 
str, value: str | None, type_hint: str | None = None) -> Any | None: pass diff --git a/pyathena/pandas/cursor.py b/pyathena/pandas/cursor.py index 39b1cf32..22a7d8ac 100644 --- a/pyathena/pandas/cursor.py +++ b/pyathena/pandas/cursor.py @@ -153,6 +153,7 @@ def execute( na_values: Iterable[str] | None = ("",), quoting: int = 1, on_start_query_execution: Callable[[str], None] | None = None, + result_set_type_hints: dict[str | int, str] | None = None, **kwargs, ) -> PandasCursor: """Execute a SQL query and return results as pandas DataFrames. @@ -175,6 +176,9 @@ def execute( na_values: Additional values to treat as NA. quoting: CSV quoting behavior (pandas csv.QUOTE_* constants). on_start_query_execution: Callback called when query starts. + result_set_type_hints: Optional dictionary mapping column names to + Athena DDL type signatures for precise type conversion within + complex types. **kwargs: Additional pandas read_csv/read_parquet parameters. Returns: @@ -224,6 +228,7 @@ def execute( cache_type=kwargs.pop("cache_type", self._cache_type), max_workers=kwargs.pop("max_workers", self._max_workers), auto_optimize_chunksize=self._auto_optimize_chunksize, + result_set_type_hints=result_set_type_hints, **kwargs, ) else: diff --git a/pyathena/pandas/result_set.py b/pyathena/pandas/result_set.py index 1486b45d..592fdb2d 100644 --- a/pyathena/pandas/result_set.py +++ b/pyathena/pandas/result_set.py @@ -229,6 +229,7 @@ def __init__( cache_type: str | None = None, max_workers: int = (cpu_count() or 1) * 5, auto_optimize_chunksize: bool = False, + result_set_type_hints: dict[str | int, str] | None = None, **kwargs, ) -> None: """Initialize AthenaPandasResultSet with pandas-specific configurations. @@ -252,6 +253,8 @@ def __init__( max_workers: Maximum worker threads for parallel operations. auto_optimize_chunksize: Enable automatic chunksize determination for large files when chunksize is None. 
+            result_set_type_hints: Optional dictionary mapping column names to
+                Athena DDL type signatures for precise type conversion.
             **kwargs: Additional arguments passed to pandas.read_csv/read_parquet.
         """
         super().__init__(
@@ -260,6 +263,7 @@ def __init__(
             query_execution=query_execution,
             arraysize=1,  # Fetch one row to retrieve metadata
             retry_config=retry_config,
+            result_set_type_hints=result_set_type_hints,
         )
         self._rows.clear()  # Clear pre_fetch data
         self._arraysize = arraysize
diff --git a/pyathena/parser.py b/pyathena/parser.py
new file mode 100644
index 00000000..c6b70eb8
--- /dev/null
+++ b/pyathena/parser.py
@@ -0,0 +1,537 @@
+from __future__ import annotations
+
+import json
+import re
+from collections.abc import Callable
+from dataclasses import dataclass, field
+from typing import Any
+
+# Aliases for Athena type names that differ between Hive DDL and Trino DDL.
+_TYPE_ALIASES: dict[str, str] = {
+    "int": "integer",
+}
+
+# Pattern for normalizing Hive-style type signatures to Trino-style.
+# Matches angle brackets and colons used in Hive DDL (e.g., array<struct<a:int>>).
+_HIVE_SYNTAX_RE: re.Pattern[str] = re.compile(r"[<>:]")
+_HIVE_REPLACEMENTS: dict[str, str] = {"<": "(", ">": ")", ":": " "}
+
+
+def _normalize_hive_syntax(type_str: str) -> str:
+    """Normalize Hive-style DDL syntax to Trino-style.
+
+    Converts angle-bracket notation (``array<struct<a:int>>``) to
+    parenthesized notation (``array(struct(a int))``).
+
+    Args:
+        type_str: Type signature string, possibly using Hive syntax.
+
+    Returns:
+        Normalized type signature using Trino-style parenthesized notation.
+    """
+    if "<" not in type_str:
+        return type_str
+    return _HIVE_SYNTAX_RE.sub(lambda m: _HIVE_REPLACEMENTS[m.group()], type_str)
+
+
+def _split_array_items(inner: str) -> list[str]:
+    """Split array items by comma, respecting brace and bracket groupings.
+
+    Args:
+        inner: Interior content of array without brackets.
+
+    Returns:
+        List of item strings.
+ """ + items: list[str] = [] + current_item = "" + brace_depth = 0 + bracket_depth = 0 + + for char in inner: + if char == "{": + brace_depth += 1 + elif char == "}": + brace_depth -= 1 + elif char == "[": + bracket_depth += 1 + elif char == "]": + bracket_depth -= 1 + elif char == "," and brace_depth == 0 and bracket_depth == 0: + items.append(current_item.strip()) + current_item = "" + continue + + current_item += char + + if current_item.strip(): + items.append(current_item.strip()) + + return items + + +@dataclass +class TypeNode: + """Parsed representation of an Athena DDL type signature. + + Represents a node in a type tree, where complex types (array, map, row) + have children representing their element/field types. + + Attributes: + type_name: The base type name (e.g., "array", "map", "row", "varchar"). + children: Child type nodes for complex types. + field_names: Field names for row/struct types (parallel to children). + """ + + type_name: str + children: list[TypeNode] = field(default_factory=list) + field_names: list[str] | None = None + _field_type_map: dict[str, TypeNode] | None = field(default=None, repr=False) + + def get_field_type(self, name: str) -> TypeNode | None: + """Look up a child type node by field name using a cached dict. + + Returns: + The TypeNode for the named field, or None if not found. + """ + if self._field_type_map is None and self.field_names: + self._field_type_map = { + fn: self.children[i] + for i, fn in enumerate(self.field_names) + if i < len(self.children) + } + if self._field_type_map: + return self._field_type_map.get(name) + return None + + +class TypeSignatureParser: + """Parse Athena DDL type signature strings into a type tree.""" + + def parse(self, type_str: str) -> TypeNode: + """Parse an Athena DDL type signature string into a TypeNode tree. + + Handles simple types (varchar, integer), parameterized types (decimal(10,2)), + and complex types (array, map, row/struct) with arbitrary nesting. 
+ + Args: + type_str: Athena DDL type string (e.g., "array(row(name varchar, age integer))"). + + Returns: + TypeNode representing the parsed type tree. + """ + type_str = type_str.strip() + + paren_idx = type_str.find("(") + if paren_idx == -1: + name = type_str.lower() + return TypeNode(type_name=_TYPE_ALIASES.get(name, name)) + + type_name = type_str[:paren_idx].strip().lower() + type_name = _TYPE_ALIASES.get(type_name, type_name) + + close_idx = self._find_matching_paren(type_str, paren_idx) + inner = type_str[paren_idx + 1 : close_idx].strip() + + if type_name in ("row", "struct"): + parts = self._split_type_args(inner) + field_names: list[str] = [] + children: list[TypeNode] = [] + for part in parts: + part = part.strip() + space_idx = self._find_field_name_boundary(part) + if space_idx == -1: + children.append(self.parse(part)) + field_names.append(part) + else: + field_name = part[:space_idx].strip() + type_part = part[space_idx + 1 :].strip() + field_names.append(field_name) + children.append(self.parse(type_part)) + return TypeNode(type_name=type_name, children=children, field_names=field_names) + + if type_name == "array": + child = self.parse(inner) + return TypeNode(type_name=type_name, children=[child]) + + if type_name == "map": + parts = self._split_type_args(inner) + if len(parts) == 2: + key_type = self.parse(parts[0]) + value_type = self.parse(parts[1]) + return TypeNode(type_name=type_name, children=[key_type, value_type]) + return TypeNode(type_name=type_name) + + # Types with parameters like decimal(10, 2), varchar(255) + return TypeNode(type_name=type_name) + + def _split_type_args(self, s: str) -> list[str]: + """Split a type signature argument string by comma, respecting nested parentheses. + + Args: + s: Type signature argument string to split. + + Returns: + List of type argument strings. 
+ """ + parts: list[str] = [] + current: list[str] = [] + depth = 0 + + for char in s: + if char == "(": + depth += 1 + elif char == ")": + depth -= 1 + elif char == "," and depth == 0: + parts.append("".join(current).strip()) + current = [] + continue + current.append(char) + + if current: + parts.append("".join(current).strip()) + return parts + + @staticmethod + def _find_matching_paren(s: str, open_idx: int) -> int: + """Find the index of the closing parenthesis matching the one at *open_idx*. + + Args: + s: The full string. + open_idx: Index of the opening ``(``. + + Returns: + Index of the matching ``)``. + """ + depth = 0 + for i in range(open_idx, len(s)): + if s[i] == "(": + depth += 1 + elif s[i] == ")": + depth -= 1 + if depth == 0: + return i + return len(s) - 1 + + def _find_field_name_boundary(self, part: str) -> int: + """Find the boundary between field name and type in a row field definition. + + Handles cases like "name varchar" and "data row(x integer, y integer)". + + Args: + part: A single field definition string. + + Returns: + Index of the space separating field name from type, or -1 if not found. + """ + depth = 0 + for i, char in enumerate(part): + if char == "(": + depth += 1 + elif char == ")": + depth -= 1 + elif char == " " and depth == 0: + return i + return -1 + + +class TypedValueConverter: + """Convert values using TypeNode type information. + + Dependencies are injected via the constructor to avoid circular imports + between parser.py and converter.py. + + Args: + converters: Mapping of type names to conversion functions. + default_converter: Fallback conversion function for unknown types. + struct_parser: Function to parse untyped struct values. 
+ """ + + def __init__( + self, + converters: dict[str, Callable[[str | None], Any | None]], + default_converter: Callable[[str | None], Any | None], + struct_parser: Callable[[str | None], dict[str, Any] | None], + ) -> None: + self._converters = converters + self._default_converter = default_converter + self._struct_parser = struct_parser + + def convert(self, value: str, type_node: TypeNode) -> Any: + """Convert a value using type information from a TypeNode. + + For complex types (array, map, row), parses the structure and + recursively converts elements using child type information. + For simple types, uses the standard converter function. + + Args: + value: String value to convert. + type_node: Parsed type information. + + Returns: + Converted value. + """ + if type_node.type_name == "array": + return self._convert_typed_array(value, type_node) + if type_node.type_name == "map": + return self._convert_typed_map(value, type_node) + if type_node.type_name in ("row", "struct"): + return self._convert_typed_struct(value, type_node) + converter_fn = self._converters.get(type_node.type_name, self._default_converter) + return converter_fn(value) + + @staticmethod + def _to_json_str(value: Any) -> str: + """Convert a JSON-parsed value back to a string for further conversion. + + Uses json.dumps for dict/list to produce valid JSON, and str() for + scalar types to produce converter-compatible strings. + + Args: + value: A value from json.loads output. + + Returns: + String representation suitable for type conversion. + """ + if isinstance(value, (dict, list)): + return json.dumps(value) + return str(value) + + def _convert_element(self, value: str, type_node: TypeNode) -> Any: + """Convert a single element within a complex type using type information. + + Handles null values before delegating to type-specific conversion. + + Args: + value: String value to convert. + type_node: Type information for this element. + + Returns: + Converted value, or None for null. 
+ """ + if value.lower() == "null": + return None + return self.convert(value, type_node) + + def _convert_typed_array(self, value: str, type_node: TypeNode) -> list[Any] | None: + """Convert an array value using type information. + + Args: + value: String representation of the array. + type_node: Type node with array element type as first child. + + Returns: + List of converted elements, or None if parsing fails. + """ + if not (value.startswith("[") and value.endswith("]")): + return None + + element_type = type_node.children[0] if type_node.children else TypeNode("varchar") + + # Try JSON first (only if content looks like JSON) + inner_preview = value[1:10] if len(value) > 10 else value[1:-1] + if '"' in inner_preview or value.startswith(("[{", "[null", "[[")): + try: + parsed = json.loads(value) + if isinstance(parsed, list): + return [ + None + if elem is None + else self.convert(self._to_json_str(elem), element_type) + for elem in parsed + ] + except json.JSONDecodeError: + pass + + # Native format + inner = value[1:-1].strip() + if not inner: + return [] + + if "[" in inner: + return None # Nested arrays not supported in native format + + items = _split_array_items(inner) + result: list[Any] = [] + for item in items: + item = item.strip() + if not item: + continue + if item.startswith("{") and item.endswith("}"): + if element_type.type_name in ("row", "struct"): + result.append(self._convert_typed_struct(item, element_type)) + elif element_type.type_name == "map": + result.append(self._convert_typed_map(item, element_type)) + else: + result.append(self._struct_parser(item)) + else: + result.append(self._convert_element(item, element_type)) + + return result if result else None + + def _convert_typed_map(self, value: str, type_node: TypeNode) -> dict[str, Any] | None: + """Convert a map value using type information. + + Args: + value: String representation of the map. + type_node: Type node with key type and value type as children. 
+ + Returns: + Dictionary of converted key-value pairs, or None if parsing fails. + """ + if not (value.startswith("{") and value.endswith("}")): + return None + + key_type = type_node.children[0] if len(type_node.children) > 0 else TypeNode("varchar") + value_type = type_node.children[1] if len(type_node.children) > 1 else TypeNode("varchar") + + # Try JSON first + inner_preview = value[1:10] if len(value) > 10 else value[1:-1] + if '"' in inner_preview or value.startswith('{"'): + try: + parsed = json.loads(value) + if isinstance(parsed, dict): + return { + str(self.convert(self._to_json_str(k), key_type) if k is not None else k): ( + self.convert(self._to_json_str(v), value_type) + if v is not None + else None + ) + for k, v in parsed.items() + } + except json.JSONDecodeError: + pass + + # Native format + inner = value[1:-1].strip() + if not inner: + return {} + + pairs = _split_array_items(inner) + result: dict[str, Any] = {} + for pair in pairs: + if "=" not in pair: + continue + k, v = pair.split("=", 1) + k = k.strip() + v = v.strip() + if any(char in k for char in '{}="'): + continue + if v.startswith("{") and v.endswith("}"): + if value_type.type_name in ("row", "struct"): + result[str(self._convert_element(k, key_type))] = self._convert_typed_struct( + v, value_type + ) + elif value_type.type_name == "map": + result[str(self._convert_element(k, key_type))] = self._convert_typed_map( + v, value_type + ) + else: + result[str(self._convert_element(k, key_type))] = self._struct_parser(v) + else: + converted_key = self._convert_element(k, key_type) + converted_value = self._convert_element(v, value_type) + result[str(converted_key)] = converted_value + + return result if result else None + + def _convert_typed_struct(self, value: str, type_node: TypeNode) -> dict[str, Any] | None: + """Convert a struct/row value using type information. + + Args: + value: String representation of the struct. + type_node: Type node with field types and names. 
+ + Returns: + Dictionary of converted field values, or None if parsing fails. + """ + if not (value.startswith("{") and value.endswith("}")): + return None + + field_types = type_node.children or [] + + # Try JSON first + inner_preview = value[1:10] if len(value) > 10 else value[1:-1] + if '"' in inner_preview or value.startswith('{"'): + try: + parsed = json.loads(value) + if isinstance(parsed, dict): + result: dict[str, Any] = {} + for i, (k, v) in enumerate(parsed.items()): + ft = self._get_field_type(k, type_node, i) + result[k] = ( + self.convert(self._to_json_str(v), ft) if v is not None else None + ) + return result + except json.JSONDecodeError: + pass + + inner = value[1:-1].strip() + if not inner: + return {} + + if "=" in inner: + # Named struct + pairs = _split_array_items(inner) + result = {} + field_index = 0 + for pair in pairs: + if "=" not in pair: + continue + k, v = pair.split("=", 1) + k = k.strip() + v = v.strip() + if any(char in k for char in '{}="'): + continue + + ft = self._get_field_type(k, type_node, field_index) + field_index += 1 + + if v.startswith("{") and v.endswith("}"): + if ft.type_name in ("row", "struct"): + result[k] = self._convert_typed_struct(v, ft) + elif ft.type_name == "map": + result[k] = self._convert_typed_map(v, ft) + else: + result[k] = self._struct_parser(v) + else: + result[k] = self._convert_element(v, ft) + return result if result else None + + # Unnamed struct + field_names = type_node.field_names or [] + values = _split_array_items(inner) + result = {} + for i, v in enumerate(values): + ft = field_types[i] if i < len(field_types) else TypeNode("varchar") + name = field_names[i] if i < len(field_names) else str(i) + result[name] = self._convert_element(v, ft) + return result + + @staticmethod + def _get_field_type( + field_name: str, + type_node: TypeNode, + field_index: int, + ) -> TypeNode: + """Look up the type for a struct field by name or index. 
+ + Uses the TypeNode's cached dict for O(1) name lookup, then falls + back to positional index. + + Args: + field_name: Name of the field to look up. + type_node: The parent row/struct TypeNode. + field_index: Current positional index as fallback. + + Returns: + TypeNode for the field, defaulting to varchar if not found. + """ + ft = type_node.get_field_type(field_name) + if ft is not None: + return ft + field_types = type_node.children or [] + if field_index < len(field_types): + return field_types[field_index] + return TypeNode("varchar") diff --git a/pyathena/polars/async_cursor.py b/pyathena/polars/async_cursor.py index 9349f61a..3774a47e 100644 --- a/pyathena/polars/async_cursor.py +++ b/pyathena/polars/async_cursor.py @@ -161,6 +161,7 @@ def arraysize(self, value: int) -> None: def _collect_result_set( self, query_id: str, + result_set_type_hints: dict[str | int, str] | None = None, unload_location: str | None = None, kwargs: dict[str, Any] | None = None, ) -> AthenaPolarsResultSet: @@ -179,6 +180,7 @@ def _collect_result_set( cache_type=self._cache_type, max_workers=self._max_workers, chunksize=self._chunksize, + result_set_type_hints=result_set_type_hints, **kwargs, ) @@ -193,6 +195,7 @@ def execute( result_reuse_enable: bool | None = None, result_reuse_minutes: int | None = None, paramstyle: str | None = None, + result_set_type_hints: dict[str | int, str] | None = None, **kwargs, ) -> tuple[str, Future[AthenaPolarsResultSet | Any]]: """Execute a SQL query asynchronously and return results as Polars DataFrames. @@ -210,6 +213,9 @@ def execute( result_reuse_enable: Enable Athena result reuse for this query. result_reuse_minutes: Minutes to reuse cached results. paramstyle: Parameter style ('qmark' or 'pyformat'). + result_set_type_hints: Optional dictionary mapping column names to + Athena DDL type signatures for precise type conversion within + complex types. **kwargs: Additional execution parameters passed to Polars read functions. 
Returns: @@ -237,6 +243,7 @@ def execute( self._executor.submit( self._collect_result_set, query_id, + result_set_type_hints, unload_location, kwargs, ), diff --git a/pyathena/polars/converter.py b/pyathena/polars/converter.py index 627a4ffd..356deaf2 100644 --- a/pyathena/polars/converter.py +++ b/pyathena/polars/converter.py @@ -103,7 +103,7 @@ def get_dtype(self, type_: str, precision: int = 0, scale: int = 0) -> Any: return pl.Decimal(precision=precision, scale=scale) return self._types.get(type_) - def convert(self, type_: str, value: str | None) -> Any | None: + def convert(self, type_: str, value: str | None, type_hint: str | None = None) -> Any | None: converter = self.get(type_) return converter(value) @@ -127,5 +127,5 @@ def __init__(self) -> None: default=_to_default, ) - def convert(self, type_: str, value: str | None) -> Any | None: + def convert(self, type_: str, value: str | None, type_hint: str | None = None) -> Any | None: pass diff --git a/pyathena/polars/cursor.py b/pyathena/polars/cursor.py index b788feac..efc738c4 100644 --- a/pyathena/polars/cursor.py +++ b/pyathena/polars/cursor.py @@ -157,6 +157,7 @@ def execute( result_reuse_minutes: int | None = None, paramstyle: str | None = None, on_start_query_execution: Callable[[str], None] | None = None, + result_set_type_hints: dict[str | int, str] | None = None, **kwargs, ) -> PolarsCursor: """Execute a SQL query and return results as Polars DataFrames. @@ -175,6 +176,9 @@ def execute( result_reuse_minutes: Minutes to reuse cached results. paramstyle: Parameter style ('qmark' or 'pyformat'). on_start_query_execution: Callback called when query starts. + result_set_type_hints: Optional dictionary mapping column names to + Athena DDL type signatures for precise type conversion within + complex types. **kwargs: Additional execution parameters passed to Polars read functions. 
Returns: @@ -218,6 +222,7 @@ def execute( cache_type=self._cache_type, max_workers=self._max_workers, chunksize=self._chunksize, + result_set_type_hints=result_set_type_hints, **kwargs, ) else: diff --git a/pyathena/polars/result_set.py b/pyathena/polars/result_set.py index 6818b84e..adc4e2ac 100644 --- a/pyathena/polars/result_set.py +++ b/pyathena/polars/result_set.py @@ -202,6 +202,7 @@ def __init__( cache_type: str | None = None, max_workers: int = (cpu_count() or 1) * 5, chunksize: int | None = None, + result_set_type_hints: dict[str | int, str] | None = None, **kwargs, ) -> None: """Initialize the Polars result set. @@ -220,6 +221,8 @@ def __init__( chunksize: Number of rows per chunk for memory-efficient processing. If specified, data is loaded lazily in chunks for all data access methods including fetchone(), fetchmany(), and iter_chunks(). + result_set_type_hints: Optional dictionary mapping column names to + Athena DDL type signatures for precise type conversion. **kwargs: Additional arguments passed to Polars read functions. """ super().__init__( @@ -228,6 +231,7 @@ def __init__( query_execution=query_execution, arraysize=1, # Fetch one row to retrieve metadata retry_config=retry_config, + result_set_type_hints=result_set_type_hints, ) self._rows.clear() # Clear pre_fetch data self._arraysize = arraysize diff --git a/pyathena/result_set.py b/pyathena/result_set.py index 7f579d06..ea1fe974 100644 --- a/pyathena/result_set.py +++ b/pyathena/result_set.py @@ -53,6 +53,10 @@ class AthenaResultSet(CursorIterator): https://docs.aws.amazon.com/athena/latest/APIReference/API_GetQueryResults.html """ + # https://docs.aws.amazon.com/athena/latest/ug/data-types.html + # Athena complex types that benefit from type hint conversion. 
+ _COMPLEX_TYPES: frozenset[str] = frozenset({"array", "map", "row", "struct"}) + def __init__( self, connection: Connection[Any], @@ -61,6 +65,7 @@ def __init__( arraysize: int, retry_config: RetryConfig, _pre_fetch: bool = True, + result_set_type_hints: dict[str | int, str] | None = None, ) -> None: super().__init__(arraysize=arraysize) self._connection: Connection[Any] | None = connection @@ -69,6 +74,14 @@ def __init__( if not self._query_execution: raise ProgrammingError("Required argument `query_execution` not found.") self._retry_config = retry_config + self._hints_by_name: dict[str, str] = {} + self._hints_by_index: dict[int, str] = {} + if result_set_type_hints: + for k, v in result_set_type_hints.items(): + if isinstance(k, int): + self._hints_by_index[k] = v + else: + self._hints_by_name[k.lower()] = v self._client = connection.session.client( "s3", region_name=connection.region_name, @@ -77,6 +90,9 @@ def __init__( ) self._metadata: tuple[dict[str, Any], ...] | None = None + self._column_types: tuple[str, ...] | None = None + self._column_names: tuple[str, ...] | None = None + self._column_type_hints: tuple[str | None, ...] | None = None self._rows: collections.deque[tuple[Any | None, ...] 
| dict[Any, Any | None]] = ( collections.deque() ) @@ -418,6 +434,41 @@ def _process_metadata(self, response: dict[str, Any]) -> None: if column_info is None: raise DataError("KeyError `ColumnInfo`") self._metadata = tuple(column_info) + self._column_types = tuple(m.get("Type", "") for m in self._metadata) + self._column_names = tuple(m.get("Name", "") for m in self._metadata) + if (self._hints_by_name or self._hints_by_index) and any( + t.lower() in self._COMPLEX_TYPES for t in self._column_types + ): + hints = tuple( + self._resolve_type_hint(i, m.get("Name", "").lower(), t.lower()) + for i, (m, t) in enumerate(zip(self._metadata, self._column_types, strict=True)) + ) + if any(hints): + self._column_type_hints = hints + + def _resolve_type_hint( + self, index: int, col_name_lower: str, col_type_lower: str + ) -> str | None: + """Look up the type hint for a column by index then by name. + + Index-based hints take priority over name-based hints, allowing + callers to disambiguate duplicate column names. + + Args: + index: Zero-based column position. + col_name_lower: Lowercased column name from metadata. + col_type_lower: Lowercased column type from metadata. + + Returns: + The type hint string, or None if the column has no hint or + is not a complex type. + """ + if col_type_lower not in self._COMPLEX_TYPES: + return None + hint = self._hints_by_index.get(index) + if hint is not None: + return hint + return self._hints_by_name.get(col_name_lower) def _process_update_count(self, response: dict[str, Any]) -> None: update_count = response.get("UpdateCount") @@ -443,12 +494,32 @@ def _get_rows( converter: Converter | None = None, ) -> list[tuple[Any | None, ...] 
| dict[Any, Any | None]]: conv = converter or self._converter + col_types = self._column_types + col_hints = self._column_type_hints + if col_hints and col_types: + return [ + tuple( + conv.convert(col_type, row.get("VarCharValue"), type_hint=hint) + if hint + else conv.convert(col_type, row.get("VarCharValue")) + for col_type, row, hint in zip( + col_types, rows[i].get("Data", []), col_hints, strict=False + ) + ) + for i in range(offset, len(rows)) + ] + if col_types: + return [ + tuple( + conv.convert(col_type, row.get("VarCharValue")) + for col_type, row in zip(col_types, rows[i].get("Data", []), strict=False) + ) + for i in range(offset, len(rows)) + ] return [ tuple( - [ - conv.convert(meta.get("Type"), row.get("VarCharValue")) - for meta, row in zip(metadata, rows[i].get("Data", []), strict=False) - ] + conv.convert(meta.get("Type"), row.get("VarCharValue")) + for meta, row in zip(metadata, rows[i].get("Data", []), strict=False) ) for i in range(offset, len(rows)) ] @@ -606,6 +677,8 @@ def close(self) -> None: self._connection = None self._query_execution = None self._metadata = None + self._column_types = None + self._column_names = None self._rows.clear() self._next_token = None self._rownumber = None @@ -630,15 +703,48 @@ def _get_rows( converter: Converter | None = None, ) -> list[tuple[Any | None, ...] 
| dict[Any, Any | None]]: conv = converter or self._converter - return [ - self.dict_type( - [ + col_types = self._column_types + col_names = self._column_names + col_hints = self._column_type_hints + if col_hints and col_types and col_names: + return [ + self.dict_type( + ( + name, + conv.convert(col_type, row.get("VarCharValue"), type_hint=hint) + if hint + else conv.convert(col_type, row.get("VarCharValue")), + ) + for name, col_type, row, hint in zip( + col_names, + col_types, + rows[i].get("Data", []), + col_hints, + strict=False, + ) + ) + for i in range(offset, len(rows)) + ] + if col_types and col_names: + return [ + self.dict_type( ( - meta.get("Name"), - conv.convert(meta.get("Type"), row.get("VarCharValue")), + name, + conv.convert(col_type, row.get("VarCharValue")), + ) + for name, col_type, row in zip( + col_names, col_types, rows[i].get("Data", []), strict=False ) - for meta, row in zip(metadata, rows[i].get("Data", []), strict=False) - ] + ) + for i in range(offset, len(rows)) + ] + return [ + self.dict_type( + ( + meta.get("Name"), + conv.convert(meta.get("Type"), row.get("VarCharValue")), + ) + for meta, row in zip(metadata, rows[i].get("Data", []), strict=False) ) for i in range(offset, len(rows)) ] diff --git a/pyathena/s3fs/async_cursor.py b/pyathena/s3fs/async_cursor.py index a98c4558..73acede2 100644 --- a/pyathena/s3fs/async_cursor.py +++ b/pyathena/s3fs/async_cursor.py @@ -142,12 +142,16 @@ def arraysize(self, value: int) -> None: def _collect_result_set( self, query_id: str, + result_set_type_hints: dict[str | int, str] | None = None, kwargs: dict[str, Any] | None = None, ) -> AthenaS3FSResultSet: """Collect result set after query execution. Args: query_id: The Athena query execution ID. + result_set_type_hints: Optional dictionary mapping column names to + Athena DDL type signatures for precise type conversion within + complex types. kwargs: Additional keyword arguments for result set. 
Returns: @@ -163,6 +167,7 @@ def _collect_result_set( arraysize=self._arraysize, retry_config=self._retry_config, csv_reader=self._csv_reader, + result_set_type_hints=result_set_type_hints, **kwargs, ) @@ -177,6 +182,7 @@ def execute( result_reuse_enable: bool | None = None, result_reuse_minutes: int | None = None, paramstyle: str | None = None, + result_set_type_hints: dict[str | int, str] | None = None, **kwargs, ) -> tuple[str, Future[AthenaS3FSResultSet | Any]]: """Execute a SQL query asynchronously. @@ -194,6 +200,9 @@ def execute( result_reuse_enable: Enable Athena result reuse for this query. result_reuse_minutes: Minutes to reuse cached results. paramstyle: Parameter style ('qmark' or 'pyformat'). + result_set_type_hints: Optional dictionary mapping column names to + Athena DDL type signatures for precise type conversion within + complex types. **kwargs: Additional execution parameters. Returns: @@ -220,6 +229,7 @@ def execute( self._executor.submit( self._collect_result_set, query_id, + result_set_type_hints, kwargs, ), ) diff --git a/pyathena/s3fs/converter.py b/pyathena/s3fs/converter.py index 853f7a7d..00fcd5ed 100644 --- a/pyathena/s3fs/converter.py +++ b/pyathena/s3fs/converter.py @@ -2,7 +2,7 @@ import logging from copy import deepcopy -from typing import Any +from typing import TYPE_CHECKING, Any from pyathena.converter import ( _DEFAULT_CONVERTERS, @@ -10,6 +10,9 @@ _to_default, ) +if TYPE_CHECKING: + from pyathena.converter import DefaultTypeConverter + _logger = logging.getLogger(__name__) @@ -43,8 +46,9 @@ def __init__(self) -> None: mappings=deepcopy(_DEFAULT_CONVERTERS), default=_to_default, ) + self._default_type_converter: DefaultTypeConverter | None = None - def convert(self, type_: str, value: str | None) -> Any | None: + def convert(self, type_: str, value: str | None, type_hint: str | None = None) -> Any | None: """Convert a string value to the appropriate Python type. 
Looks up the converter function for the given Athena type and applies @@ -53,9 +57,19 @@ def convert(self, type_: str, value: str | None) -> Any | None: Args: type_: The Athena data type name (e.g., "integer", "varchar", "date"). value: The string value to convert, or None. + type_hint: Optional Athena DDL type signature for precise complex type + conversion (e.g., "array(varchar)"). Returns: The converted Python value, or None if the input value was None. """ + if value is None: + return None + if type_hint: + if self._default_type_converter is None: + from pyathena.converter import DefaultTypeConverter + + self._default_type_converter = DefaultTypeConverter() + return self._default_type_converter.convert(type_, value, type_hint=type_hint) converter = self.get(type_) return converter(value) diff --git a/pyathena/s3fs/cursor.py b/pyathena/s3fs/cursor.py index fbb85fab..dfc5dd5e 100644 --- a/pyathena/s3fs/cursor.py +++ b/pyathena/s3fs/cursor.py @@ -133,6 +133,7 @@ def execute( result_reuse_minutes: int | None = None, paramstyle: str | None = None, on_start_query_execution: Callable[[str], None] | None = None, + result_set_type_hints: dict[str | int, str] | None = None, **kwargs, ) -> S3FSCursor: """Execute a SQL query and return results. @@ -151,6 +152,9 @@ def execute( result_reuse_minutes: Minutes to reuse cached results. paramstyle: Parameter style ('qmark' or 'pyformat'). on_start_query_execution: Callback called when query starts. + result_set_type_hints: Optional dictionary mapping column names to + Athena DDL type signatures for precise type conversion within + complex types. **kwargs: Additional execution parameters. 
Returns: @@ -188,6 +192,7 @@ def execute( arraysize=self.arraysize, retry_config=self._retry_config, csv_reader=self._csv_reader, + result_set_type_hints=result_set_type_hints, **kwargs, ) else: diff --git a/pyathena/s3fs/result_set.py b/pyathena/s3fs/result_set.py index 4173f25b..04828c00 100644 --- a/pyathena/s3fs/result_set.py +++ b/pyathena/s3fs/result_set.py @@ -64,6 +64,7 @@ def __init__( block_size: int | None = None, csv_reader: CSVReaderType | None = None, filesystem_class: type[AbstractFileSystem] | None = None, + result_set_type_hints: dict[str | int, str] | None = None, **kwargs, ) -> None: super().__init__( @@ -72,6 +73,7 @@ def __init__( query_execution=query_execution, arraysize=1, # Fetch one row to retrieve metadata retry_config=retry_config, + result_set_type_hints=result_set_type_hints, ) # Save pre-fetched rows (from Athena API) in case CSV reading is not available pre_fetched_rows = list(self._rows) @@ -147,8 +149,11 @@ def _fetch(self) -> None: if not self._csv_reader: return - description = self.description if self.description else [] - column_types = [d[1] for d in description] + col_types = self._column_types + if not col_types: + description = self.description if self.description else [] + col_types = tuple(d[1] for d in description) + col_hints = self._column_type_hints rows_fetched = 0 while rows_fetched < self._arraysize: @@ -161,15 +166,33 @@ def _fetch(self) -> None: # AthenaCSVReader returns None for NULL values directly, # DefaultCSVReader returns empty string which needs conversion if self._csv_reader_class is DefaultCSVReader: - converted_row = tuple( - self._converter.convert(col_type, value if value != "" else None) - for col_type, value in zip(column_types, row, strict=False) - ) + if col_hints: + converted_row = tuple( + self._converter.convert( + col_type, value if value != "" else None, type_hint=hint + ) + if hint + else self._converter.convert(col_type, value if value != "" else None) + for col_type, value, hint in 
zip(col_types, row, col_hints, strict=False) + ) + else: + converted_row = tuple( + self._converter.convert(col_type, value if value != "" else None) + for col_type, value in zip(col_types, row, strict=False) + ) else: - converted_row = tuple( - self._converter.convert(col_type, value) - for col_type, value in zip(column_types, row, strict=False) - ) + if col_hints: + converted_row = tuple( + self._converter.convert(col_type, value, type_hint=hint) + if hint + else self._converter.convert(col_type, value) + for col_type, value, hint in zip(col_types, row, col_hints, strict=False) + ) + else: + converted_row = tuple( + self._converter.convert(col_type, value) + for col_type, value in zip(col_types, row, strict=False) + ) self._rows.append(converted_row) rows_fetched += 1 diff --git a/tests/pyathena/pandas/test_util.py b/tests/pyathena/pandas/test_util.py index a34d2341..788308da 100644 --- a/tests/pyathena/pandas/test_util.py +++ b/tests/pyathena/pandas/test_util.py @@ -116,9 +116,9 @@ def test_as_pandas(cursor): b"123", [1, 2], [1, 2], + {"1": "2", "3": "4"}, {"1": 2, "3": 4}, - {"1": 2, "3": 4}, - {"a": 1, "b": 2}, + {"a": "1", "b": "2"}, Decimal("0.1"), ) ] diff --git a/tests/pyathena/s3fs/test_async_cursor.py b/tests/pyathena/s3fs/test_async_cursor.py index f7046fa3..4363367a 100644 --- a/tests/pyathena/s3fs/test_async_cursor.py +++ b/tests/pyathena/s3fs/test_async_cursor.py @@ -118,9 +118,9 @@ def test_complex(self, async_s3fs_cursor): b"123", [1, 2], [1, 2], + {"1": "2", "3": "4"}, {"1": 2, "3": 4}, - {"1": 2, "3": 4}, - {"a": 1, "b": 2}, + {"a": "1", "b": "2"}, Decimal("0.1"), ) ] diff --git a/tests/pyathena/s3fs/test_cursor.py b/tests/pyathena/s3fs/test_cursor.py index fd8f6572..40edaa07 100644 --- a/tests/pyathena/s3fs/test_cursor.py +++ b/tests/pyathena/s3fs/test_cursor.py @@ -118,9 +118,9 @@ def test_complex(self, s3fs_cursor): b"123", [1, 2], [1, 2], + {"1": "2", "3": "4"}, {"1": 2, "3": 4}, - {"1": 2, "3": 4}, - {"a": 1, "b": 2}, + {"a": "1", "b": "2"}, 
Decimal("0.1"), ) ] diff --git a/tests/pyathena/sqlalchemy/test_base.py b/tests/pyathena/sqlalchemy/test_base.py index 7c4c44fc..a42219ee 100644 --- a/tests/pyathena/sqlalchemy/test_base.py +++ b/tests/pyathena/sqlalchemy/test_base.py @@ -130,9 +130,9 @@ def test_select_nested_struct_query(self, engine): assert "header" in result.positions assert isinstance(result.positions["header"], dict) assert result.positions["header"]["stamp"] == "2024-01-01" - assert result.positions["header"]["seq"] == 123 - assert result.positions["x"] == 4.736 - assert result.positions["y"] == 0.583 + assert result.positions["header"]["seq"] == "123" + assert result.positions["x"] == "4.736" + assert result.positions["y"] == "0.583" # Test double nested struct query = sqlalchemy.text( @@ -147,7 +147,7 @@ def test_select_nested_struct_query(self, engine): result = conn.execute(query).fetchone() assert result is not None assert result.data["level1"]["level2"]["level3"] == "value" - assert result.data["field"] == 123 + assert result.data["field"] == "123" # Test multiple nested fields query = sqlalchemy.text( @@ -166,11 +166,11 @@ def test_select_nested_struct_query(self, engine): ) result = conn.execute(query).fetchone() assert result is not None - assert result.data["pos"]["x"] == 1 - assert result.data["pos"]["y"] == 2 - assert result.data["vel"]["x"] == 0.5 - assert result.data["vel"]["y"] == 0.3 - assert result.data["timestamp"] == 12345 + assert result.data["pos"]["x"] == "1" + assert result.data["pos"]["y"] == "2" + assert result.data["vel"]["x"] == "0.5" + assert result.data["vel"]["y"] == "0.3" + assert result.data["timestamp"] == "12345" def test_select_array_with_nested_struct(self, engine): """Test SELECT query with ARRAY containing nested STRUCT (Issue #627).""" @@ -197,8 +197,8 @@ def test_select_array_with_nested_struct(self, engine): assert "header" in result.positions[0] assert isinstance(result.positions[0]["header"], dict) assert result.positions[0]["header"]["stamp"] == 
"2024-01-01" - assert result.positions[0]["header"]["seq"] == 123 - assert result.positions[0]["x"] == 4.736 + assert result.positions[0]["header"]["seq"] == "123" + assert result.positions[0]["x"] == "4.736" # Multiple elements with nested structs query = sqlalchemy.text( @@ -213,12 +213,12 @@ def test_select_array_with_nested_struct(self, engine): result = conn.execute(query).fetchone() assert result is not None assert len(result.data) == 2 - assert result.data[0]["pos"]["x"] == 1 - assert result.data[0]["pos"]["y"] == 2 - assert result.data[0]["vel"]["x"] == 0.5 - assert result.data[1]["pos"]["x"] == 3 - assert result.data[1]["pos"]["y"] == 4 - assert result.data[1]["vel"]["x"] == 1.5 + assert result.data[0]["pos"]["x"] == "1" + assert result.data[0]["pos"]["y"] == "2" + assert result.data[0]["vel"]["x"] == "0.5" + assert result.data[1]["pos"]["x"] == "3" + assert result.data[1]["pos"]["y"] == "4" + assert result.data[1]["vel"]["x"] == "1.5" def test_reflect_no_such_table(self, engine): engine, conn = engine @@ -513,8 +513,8 @@ def test_reflect_select(self, engine): date(2017, 1, 2), b"123", [1, 2], - {"1": 2, "3": 4}, # map type now converted to dict - {"a": 1, "b": 2}, # row type now converted to dict + {"1": "2", "3": "4"}, # map type now converted to dict + {"a": "1", "b": "2"}, # row type now converted to dict Decimal("0.1"), ] assert isinstance(one_row_complex.c.col_boolean.type, types.BOOLEAN) diff --git a/tests/pyathena/test_converter.py b/tests/pyathena/test_converter.py index fbf78884..39583a24 100644 --- a/tests/pyathena/test_converter.py +++ b/tests/pyathena/test_converter.py @@ -1,6 +1,11 @@ import pytest -from pyathena.converter import DefaultTypeConverter, _to_array, _to_struct +from pyathena.converter import ( + DefaultTypeConverter, + _to_array, + _to_map, + _to_struct, +) @pytest.mark.parametrize( @@ -20,195 +25,103 @@ ], ) def test_to_struct_json_formats(input_value, expected): - """Test STRUCT conversion for various JSON formats and edge 
cases.""" - result = _to_struct(input_value) - assert result == expected + assert _to_struct(input_value) == expected @pytest.mark.parametrize( ("input_value", "expected"), [ - ("{a=1, b=2}", {"a": 1, "b": 2}), + ("{a=1, b=2}", {"a": "1", "b": "2"}), ("{}", {}), ("{name=John, city=Tokyo}", {"name": "John", "city": "Tokyo"}), - ("{Alice, 25}", {"0": "Alice", "1": 25}), - ("{John, 30, true}", {"0": "John", "1": 30, "2": True}), - ("{name=John, age=30}", {"name": "John", "age": 30}), - ("{x=1, y=2, z=3}", {"x": 1, "y": 2, "z": 3}), - ("{active=true, count=42}", {"active": True, "count": 42}), + ("{Alice, 25}", {"0": "Alice", "1": "25"}), + ("{John, 30, true}", {"0": "John", "1": "30", "2": "true"}), + ("{name=John, age=30}", {"name": "John", "age": "30"}), + ("{x=1, y=2, z=3}", {"x": "1", "y": "2", "z": "3"}), + ("{active=true, count=42}", {"active": "true", "count": "42"}), ], ) def test_to_struct_athena_native_formats(input_value, expected): - """Test STRUCT conversion for Athena native formats.""" - result = _to_struct(input_value) - assert result == expected + assert _to_struct(input_value) == expected @pytest.mark.parametrize( ("input_value", "expected"), [ - # Single level nesting (Issue #627) ( "{header={stamp=2024-01-01, seq=123}, x=4.736, y=0.583}", - {"header": {"stamp": "2024-01-01", "seq": 123}, "x": 4.736, "y": 0.583}, + {"header": {"stamp": "2024-01-01", "seq": "123"}, "x": "4.736", "y": "0.583"}, ), - # Double nesting ( "{outer={middle={inner=value}}, field=123}", - {"outer": {"middle": {"inner": "value"}}, "field": 123}, + {"outer": {"middle": {"inner": "value"}}, "field": "123"}, ), - # Multiple nested fields ( "{pos={x=1, y=2}, vel={x=0.5, y=0.3}, timestamp=12345}", - {"pos": {"x": 1, "y": 2}, "vel": {"x": 0.5, "y": 0.3}, "timestamp": 12345}, + { + "pos": {"x": "1", "y": "2"}, + "vel": {"x": "0.5", "y": "0.3"}, + "timestamp": "12345", + }, ), - # Triple nesting ( "{level1={level2={level3={value=deep}}}}", {"level1": {"level2": {"level3": {"value": 
"deep"}}}}, ), - # Mixed types in nested struct ( "{metadata={id=123, active=true, name=test}, count=5}", - {"metadata": {"id": 123, "active": True, "name": "test"}, "count": 5}, + {"metadata": {"id": "123", "active": "true", "name": "test"}, "count": "5"}, ), - # Nested struct with null value ( "{data={value=null, status=ok}, flag=true}", - {"data": {"value": None, "status": "ok"}, "flag": True}, + {"data": {"value": None, "status": "ok"}, "flag": "true"}, ), - # Complex nesting with multiple levels and fields ( "{a={b={c=1, d=2}, e=3}, f=4, g={h=5}}", - {"a": {"b": {"c": 1, "d": 2}, "e": 3}, "f": 4, "g": {"h": 5}}, + {"a": {"b": {"c": "1", "d": "2"}, "e": "3"}, "f": "4", "g": {"h": "5"}}, ), ], ) def test_to_struct_athena_nested_formats(input_value, expected): - """Test STRUCT conversion for nested struct formats (Issue #627).""" - result = _to_struct(input_value) - assert result == expected + assert _to_struct(input_value) == expected @pytest.mark.parametrize( "input_value", [ - "{formula=x=y+1, status=active}", # Equals in value - '{json={"key": "value"}, name=test}', # Braces in value - '{message=He said "hello", name=John}', # Quotes in value + "{formula=x=y+1, status=active}", + '{json={"key": "value"}, name=test}', + '{message=He said "hello", name=John}', ], ) def test_to_struct_athena_complex_cases(input_value): - """Test complex cases with special characters return None or partial dict (safe fallback).""" result = _to_struct(input_value) - # With the new continue logic, these may return partial results instead of None - # Check if they return None (strict safety) or partial results (lenient approach) - assert result is None or isinstance(result, dict), ( - f"Complex case should return None or dict: {input_value} -> {result}" - ) - - -def test_to_map_athena_numeric_keys(): - """Test Athena map with numeric keys""" - from pyathena.converter import _to_map - - map_value = "{1=2, 3=4}" - result = _to_map(map_value) - expected = {"1": 2, "3": 4} - assert 
result == expected - - -def test_to_array_athena_numeric_elements(): - """Test Athena array with numeric elements""" - array_value = "[1, 2, 3, 4]" - result = _to_array(array_value) - expected = [1, 2, 3, 4] - assert result == expected - - -def test_to_array_athena_mixed_elements(): - """Test Athena array with mixed type elements""" - array_value = "[1, hello, true, null]" - result = _to_array(array_value) - expected = [1, "hello", True, None] - assert result == expected - - -def test_to_array_athena_struct_elements(): - """Test Athena array with struct elements""" - array_value = "[{name=John, age=30}, {name=Jane, age=25}]" - result = _to_array(array_value) - expected = [{"name": "John", "age": 30}, {"name": "Jane", "age": 25}] - assert result == expected - - -def test_to_array_athena_unnamed_struct_elements(): - """Test Athena array with unnamed struct elements""" - array_value = "[{Alice, 25}, {Bob, 30}]" - result = _to_array(array_value) - expected = [{"0": "Alice", "1": 25}, {"0": "Bob", "1": 30}] - assert result == expected - - -@pytest.mark.parametrize( - ("input_value", "expected"), - [ - # Array with nested structs (Issue #627) - ( - "[{header={stamp=2024-01-01, seq=123}, x=4.736}]", - [{"header": {"stamp": "2024-01-01", "seq": 123}, "x": 4.736}], - ), - # Multiple elements with nested structs - ( - "[{pos={x=1, y=2}, vel={x=0.5}}, {pos={x=3, y=4}, vel={x=1.5}}]", - [ - {"pos": {"x": 1, "y": 2}, "vel": {"x": 0.5}}, - {"pos": {"x": 3, "y": 4}, "vel": {"x": 1.5}}, - ], - ), - # Array with deeply nested structs - ( - "[{data={meta={id=1, active=true}}}]", - [{"data": {"meta": {"id": 1, "active": True}}}], - ), - ], -) -def test_to_array_athena_nested_struct_elements(input_value, expected): - """Test Athena array with nested struct elements (Issue #627).""" - result = _to_array(input_value) - assert result == expected + assert result is None or isinstance(result, dict) @pytest.mark.parametrize( "input_value", [ - "[1, 2, 3]", # Array JSON - '"just a string"', 
# String JSON - "42", # Number JSON + "[1, 2, 3]", + '"just a string"', + "42", ], ) def test_to_struct_non_dict_json(input_value): - """Test that non-dict JSON formats return None.""" - result = _to_struct(input_value) - assert result is None + assert _to_struct(input_value) is None + + +def test_to_map_athena_numeric_keys(): + assert _to_map("{1=2, 3=4}") == {"1": "2", "3": "4"} @pytest.mark.parametrize( ("input_value", "expected"), [ (None, None), - ( - "[1, 2, 3, 4, 5]", - [1, 2, 3, 4, 5], - ), - ( - '["apple", "banana", "cherry"]', - ["apple", "banana", "cherry"], - ), - ( - "[true, false, null]", - [True, False, None], - ), + ("[1, 2, 3, 4, 5]", [1, 2, 3, 4, 5]), + ('["apple", "banana", "cherry"]', ["apple", "banana", "cherry"]), + ("[true, false, null]", [True, False, None]), ( '[{"name": "John", "age": 30}, {"name": "Jane", "age": 25}]', [{"name": "John", "age": 30}, {"name": "Jane", "age": 25}], @@ -219,9 +132,7 @@ def test_to_struct_non_dict_json(input_value): ], ) def test_to_array_json_formats(input_value, expected): - """Test ARRAY conversion for various JSON formats and edge cases.""" - result = _to_array(input_value) - assert result == expected + assert _to_array(input_value) == expected @pytest.mark.parametrize( @@ -229,63 +140,82 @@ def test_to_array_json_formats(input_value, expected): [ ("[1, 2, 3]", [1, 2, 3]), ("[]", []), + ("[true, false, null]", [True, False, None]), ("[apple, banana, cherry]", ["apple", "banana", "cherry"]), - ("[{Alice, 25}, {Bob, 30}]", [{"0": "Alice", "1": 25}, {"0": "Bob", "1": 30}]), + ( + "[{Alice, 25}, {Bob, 30}]", + [{"0": "Alice", "1": "25"}, {"0": "Bob", "1": "30"}], + ), ( "[{name=John, age=30}, {name=Jane, age=25}]", - [{"name": "John", "age": 30}, {"name": "Jane", "age": 25}], + [{"name": "John", "age": "30"}, {"name": "Jane", "age": "25"}], ), - ("[true, false, null]", [True, False, None]), - ("[1, 2.5, hello]", [1, 2.5, "hello"]), + ("[1, 2.5, hello]", ["1", "2.5", "hello"]), ], ) def 
test_to_array_athena_native_formats(input_value, expected): - """Test ARRAY conversion for Athena native formats.""" - result = _to_array(input_value) - assert result == expected + assert _to_array(input_value) == expected + + +@pytest.mark.parametrize( + ("input_value", "expected"), + [ + ( + "[{header={stamp=2024-01-01, seq=123}, x=4.736}]", + [{"header": {"stamp": "2024-01-01", "seq": "123"}, "x": "4.736"}], + ), + ( + "[{pos={x=1, y=2}, vel={x=0.5}}, {pos={x=3, y=4}, vel={x=1.5}}]", + [ + {"pos": {"x": "1", "y": "2"}, "vel": {"x": "0.5"}}, + {"pos": {"x": "3", "y": "4"}, "vel": {"x": "1.5"}}, + ], + ), + ( + "[{data={meta={id=1, active=true}}}]", + [{"data": {"meta": {"id": "1", "active": "true"}}}], + ), + ], +) +def test_to_array_athena_nested_struct_elements(input_value, expected): + assert _to_array(input_value) == expected @pytest.mark.parametrize( ("input_value", "expected"), [ - ("[ARRAY[1, 2], ARRAY[3, 4]]", None), # Nested arrays (native format) - ("[[1, 2], [3, 4]]", [[1, 2], [3, 4]]), # Nested arrays (JSON format - parseable) - ("[MAP(ARRAY['key'], ARRAY['value'])]", None), # Complex nested structures + ("[ARRAY[1, 2], ARRAY[3, 4]]", None), + ("[[1, 2], [3, 4]]", [[1, 2], [3, 4]]), + ("[MAP(ARRAY['key'], ARRAY['value'])]", None), ], ) def test_to_array_complex_nested_cases(input_value, expected): - """Test complex nested array cases behavior.""" - result = _to_array(input_value) - assert result == expected + assert _to_array(input_value) == expected @pytest.mark.parametrize( "input_value", [ - '"just a string"', # String JSON - "42", # Number JSON - '{"key": "value"}', # Object JSON + '"just a string"', + "42", + '{"key": "value"}', ], ) def test_to_array_non_array_json(input_value): - """Test that non-array JSON formats return None.""" - result = _to_array(input_value) - assert result is None + assert _to_array(input_value) is None @pytest.mark.parametrize( "input_value", [ - "not an array", # Not bracketed - "[unclosed array", # Malformed - "closed 
array]", # Malformed - "[{malformed struct}", # Malformed struct + "not an array", + "[unclosed array", + "closed array]", + "[{malformed struct}", ], ) def test_to_array_invalid_formats(input_value): - """Test that invalid array formats return None.""" - result = _to_array(input_value) - assert result is None + assert _to_array(input_value) is None class TestDefaultTypeConverter: @@ -296,14 +226,12 @@ class TestDefaultTypeConverter: (None, None), ("", None), ("invalid json", None), - ("{a=1, b=2}", {"a": 1, "b": 2}), + ("{a=1, b=2}", {"a": "1", "b": "2"}), ], ) def test_struct_conversion(self, input_value, expected): - """Test DefaultTypeConverter STRUCT conversion for various input formats.""" converter = DefaultTypeConverter() - result = converter.convert("row", input_value) - assert result == expected + assert converter.convert("row", input_value) == expected @pytest.mark.parametrize( ("input_value", "expected"), @@ -318,7 +246,165 @@ def test_struct_conversion(self, input_value, expected): ], ) def test_array_conversion(self, input_value, expected): - """Test DefaultTypeConverter ARRAY conversion for various input formats.""" converter = DefaultTypeConverter() - result = converter.convert("array", input_value) - assert result == expected + assert converter.convert("array", input_value) == expected + + def test_array_varchar_keeps_strings(self): + converter = DefaultTypeConverter() + result = converter.convert("array", "[1234, 5678]", type_hint="array(varchar)") + assert result == ["1234", "5678"] + + def test_array_integer_converts_to_int(self): + converter = DefaultTypeConverter() + result = converter.convert("array", "[1, 2, 3]", type_hint="array(integer)") + assert result == [1, 2, 3] + + def test_array_boolean_converts(self): + converter = DefaultTypeConverter() + result = converter.convert("array", "[true, false]", type_hint="array(boolean)") + assert result == [True, False] + + def test_array_with_null(self): + converter = DefaultTypeConverter() + result 
= converter.convert("array", "[1, null, 3]", type_hint="array(integer)") + assert result == [1, None, 3] + + def test_map_varchar_integer(self): + converter = DefaultTypeConverter() + result = converter.convert( + "map", '{"key1": 1, "key2": 2}', type_hint="map(varchar, integer)" + ) + assert result == {"key1": 1, "key2": 2} + + def test_map_native_format_with_hints(self): + converter = DefaultTypeConverter() + result = converter.convert("map", "{a=1, b=2}", type_hint="map(varchar, integer)") + assert result == {"a": 1, "b": 2} + + def test_row_type_hint(self): + converter = DefaultTypeConverter() + result = converter.convert( + "row", + '{"name": "Alice", "age": 25}', + type_hint="row(name varchar, age integer)", + ) + assert result == {"name": "Alice", "age": 25} + + def test_row_native_format_with_hints(self): + converter = DefaultTypeConverter() + result = converter.convert( + "row", + "{name=Alice, age=25}", + type_hint="row(name varchar, age integer)", + ) + assert result == {"name": "Alice", "age": 25} + + def test_nested_array_of_row(self): + converter = DefaultTypeConverter() + result = converter.convert( + "array", + "[{name=Alice, age=25}, {name=Bob, age=30}]", + type_hint="array(row(name varchar, age integer))", + ) + assert result == [ + {"name": "Alice", "age": 25}, + {"name": "Bob", "age": 30}, + ] + + def test_array_varchar_prevents_number_inference(self): + converter = DefaultTypeConverter() + result = converter.convert( + "array", + "[1234, 5678, hello]", + type_hint="array(varchar)", + ) + assert result == ["1234", "5678", "hello"] + + def test_none_value_with_type_hint(self): + converter = DefaultTypeConverter() + assert converter.convert("array", None, type_hint="array(varchar)") is None + + def test_simple_type_hint(self): + converter = DefaultTypeConverter() + assert converter.convert("varchar", "hello", type_hint="varchar") == "hello" + + def test_type_hint_caching(self): + converter = DefaultTypeConverter() + converter.convert("array", "[1, 
2]", type_hint="array(integer)")
+        assert "array(integer)" in converter._parsed_hints
+        converter.convert("array", "[3, 4]", type_hint="array(integer)")
+        assert len(converter._parsed_hints) == 1
+
+    def test_empty_array_with_type_hint(self):
+        converter = DefaultTypeConverter()
+        assert converter.convert("array", "[]", type_hint="array(varchar)") == []
+
+    def test_map_varchar_varchar(self):
+        converter = DefaultTypeConverter()
+        result = converter.convert("map", "{key1=123, key2=456}", type_hint="map(varchar, varchar)")
+        assert result == {"key1": "123", "key2": "456"}
+
+    def test_row_with_nested_struct(self):
+        converter = DefaultTypeConverter()
+        result = converter.convert(
+            "row",
+            "{header={seq=123, stamp=2024}, x=4.5}",
+            type_hint="row(header row(seq integer, stamp varchar), x double)",
+        )
+        assert result == {"header": {"seq": 123, "stamp": "2024"}, "x": 4.5}
+
+    def test_fallback_on_malformed_value(self):
+        """When typed conversion fails (returns None), fall back to untyped conversion."""
+        converter = DefaultTypeConverter()
+        # "not-an-array" doesn't look like an array — typed converter returns None.
+        # Untyped _to_array also returns None for this input, which is correct.
+        result = converter.convert("array", "not-an-array", type_hint="array(integer)")
+        assert result is None
+
+    def test_fallback_preserves_struct_value(self):
+        """Malformed struct with type_hint still falls back to untyped parsing."""
+        converter = DefaultTypeConverter()
+        # Struct with no closing brace — typed converter returns None.
+        # Untyped _to_struct also returns None here.
+        result = converter.convert("row", "{unclosed", type_hint="row(a integer)")
+        assert result is None
+
+    def test_fallback_returns_untyped_result(self):
+        """Well-formed JSON struct input is handled by the typed path; no fallback occurs."""
+        converter = DefaultTypeConverter()
+        # JSON-formatted struct input parses under the type hint directly,
+        # so the untyped fallback is never reached for this input.
+        result = converter.convert(
+            "row",
+            '{"a": 1}',
+            type_hint="row(a varchar)",
+        )
+        # Typed conversion succeeds here — "a" is varchar so "1" stays a string
+        assert result == {"a": "1"}
+
+    def test_hive_syntax_through_converter(self):
+        """Hive-style syntax works end-to-end through DefaultTypeConverter."""
+        converter = DefaultTypeConverter()
+        result = converter.convert("array", "[1, 2, 3]", type_hint="array<int>")
+        assert result == [1, 2, 3]
+
+    def test_hive_syntax_struct_through_converter(self):
+        """Hive struct syntax works end-to-end."""
+        converter = DefaultTypeConverter()
+        result = converter.convert(
+            "row",
+            "{name=Alice, age=25}",
+            type_hint="struct<name:varchar, age:int>",
+        )
+        assert result == {"name": "Alice", "age": 25}
+
+    def test_hive_syntax_caching(self):
+        """Hive syntax is normalized before cache lookup."""
+        converter = DefaultTypeConverter()
+        converter.convert("array", "[1]", type_hint="array<int>")
+        converter.convert("array", "[2]", type_hint="array(integer)")
+        # Both should normalize to "array(integer)" in the cache
+        assert "array(integer)" in converter._parsed_hints
+        assert len(converter._parsed_hints) == 1
diff --git a/tests/pyathena/test_cursor.py b/tests/pyathena/test_cursor.py
index 27c34950..3dd15fc2 100644
--- a/tests/pyathena/test_cursor.py
+++ b/tests/pyathena/test_cursor.py
@@ -538,9 +538,9 @@ def test_complex(self, cursor):
                 b"123",
                 [1, 2],
                 [1, 2],
+                {"1": "2", "3": "4"},
                 {"1": 2, "3": 4},
-                {"1": 2, "3": 4},
-                {"a": 1, "b": 2},
+                {"a": "1", "b": "2"},
                 Decimal("0.1"),
             )
         ]
@@ -571,6 +571,78 @@ def test_complex(self, cursor):
             NUMBER,
         ]
 
+    def test_complex_with_type_hints(self, cursor):
+        # 1. 
Basic complex columns from one_row_complex + cursor.execute( + """ + SELECT col_array, col_map, col_struct + FROM one_row_complex + """, + result_set_type_hints={ + "col_array": "array(integer)", + "col_map": "map(integer, integer)", + "col_struct": "row(a integer, b integer)", + }, + ) + row = cursor.fetchone() + assert row[0] == [1, 2] + assert row[1] == {"1": 2, "3": 4} + assert row[2] == {"a": 1, "b": 2} + # Without type hints col_struct values are strings; with hints they are ints + assert isinstance(row[2]["a"], int) + assert isinstance(row[2]["b"], int) + # Map values should also be ints + assert isinstance(row[1]["1"], int) + + # 2. Nested struct + cursor.execute( + """ + SELECT CAST( + ROW(ROW('2024-01-01', 123), 4.736, 0.583) + AS ROW(header ROW(stamp VARCHAR, seq INTEGER), x DOUBLE, y DOUBLE) + ) AS positions + """, + result_set_type_hints={ + "positions": "row(header row(stamp varchar, seq integer), x double, y double)", + }, + ) + row = cursor.fetchone() + positions = row[0] + assert positions["header"]["stamp"] == "2024-01-01" + assert positions["header"]["seq"] == 123 + assert isinstance(positions["header"]["seq"], int) + assert positions["x"] == 4.736 + assert isinstance(positions["x"], float) + assert positions["y"] == 0.583 + assert isinstance(positions["y"], float) + + # 3. 
Array of nested structs + cursor.execute( + """ + SELECT CAST( + ARRAY[ + ROW(ROW(1, 2), ROW(CAST(0.5 AS DOUBLE))), + ROW(ROW(3, 4), ROW(CAST(1.5 AS DOUBLE))) + ] + AS ARRAY(ROW(pos ROW(x INTEGER, y INTEGER), vel ROW(x DOUBLE))) + ) AS data + """, + result_set_type_hints={ + "data": "array(row(pos row(x integer, y integer), vel row(x double)))", + }, + ) + row = cursor.fetchone() + data = row[0] + assert len(data) == 2 + assert data[0]["pos"]["x"] == 1 + assert isinstance(data[0]["pos"]["x"], int) + assert data[0]["pos"]["y"] == 2 + assert isinstance(data[0]["pos"]["y"], int) + assert data[0]["vel"]["x"] == 0.5 + assert isinstance(data[0]["vel"]["x"], float) + assert data[1]["pos"]["x"] == 3 + assert data[1]["vel"]["x"] == 1.5 + def test_cancel(self, cursor): def cancel(c): time.sleep(randint(5, 10)) @@ -1346,7 +1418,7 @@ def test_array_converter_behavior(self, cursor): test_cases = [ ("[1, 2, 3]", [1, 2, 3]), ('["a", "b", "c"]', ["a", "b", "c"]), - ("[{a=1, b=2}]", [{"a": 1, "b": 2}]), + ("[{a=1, b=2}]", [{"a": "1", "b": "2"}]), ("[]", []), (None, None), ("invalid", None), diff --git a/tests/pyathena/test_parser.py b/tests/pyathena/test_parser.py new file mode 100644 index 00000000..bd6ab9d7 --- /dev/null +++ b/tests/pyathena/test_parser.py @@ -0,0 +1,277 @@ +import pytest + +from pyathena.converter import _DEFAULT_CONVERTERS, _to_default, _to_struct +from pyathena.parser import ( + TypedValueConverter, + TypeNode, + TypeSignatureParser, + _normalize_hive_syntax, +) + + +class TestTypeSignatureParser: + def test_simple_type(self): + parser = TypeSignatureParser() + node = parser.parse("varchar") + assert node.type_name == "varchar" + assert node.children == [] + assert node.field_names is None + + def test_simple_type_case_insensitive(self): + parser = TypeSignatureParser() + node = parser.parse("VARCHAR") + assert node.type_name == "varchar" + + def test_array_type(self): + parser = TypeSignatureParser() + node = parser.parse("array(varchar)") + assert 
node.type_name == "array" + assert len(node.children) == 1 + assert node.children[0].type_name == "varchar" + + def test_array_of_integer(self): + parser = TypeSignatureParser() + node = parser.parse("array(integer)") + assert node.type_name == "array" + assert node.children[0].type_name == "integer" + + def test_map_type(self): + parser = TypeSignatureParser() + node = parser.parse("map(varchar, integer)") + assert node.type_name == "map" + assert len(node.children) == 2 + assert node.children[0].type_name == "varchar" + assert node.children[1].type_name == "integer" + + def test_row_type(self): + parser = TypeSignatureParser() + node = parser.parse("row(name varchar, age integer)") + assert node.type_name == "row" + assert len(node.children) == 2 + assert node.field_names == ["name", "age"] + assert node.children[0].type_name == "varchar" + assert node.children[1].type_name == "integer" + + def test_struct_type(self): + parser = TypeSignatureParser() + node = parser.parse("struct(name varchar, age integer)") + assert node.type_name == "struct" + assert node.field_names == ["name", "age"] + + def test_nested_array_of_row(self): + parser = TypeSignatureParser() + node = parser.parse("array(row(name varchar, age integer))") + assert node.type_name == "array" + assert len(node.children) == 1 + row_node = node.children[0] + assert row_node.type_name == "row" + assert row_node.field_names == ["name", "age"] + assert row_node.children[0].type_name == "varchar" + assert row_node.children[1].type_name == "integer" + + def test_map_with_complex_value(self): + parser = TypeSignatureParser() + node = parser.parse("map(varchar, row(x integer, y double))") + assert node.type_name == "map" + assert node.children[0].type_name == "varchar" + assert node.children[1].type_name == "row" + assert node.children[1].field_names == ["x", "y"] + + def test_deeply_nested(self): + parser = TypeSignatureParser() + node = parser.parse("array(row(data row(x integer, y integer), name 
varchar))")
+        assert node.type_name == "array"
+        row_node = node.children[0]
+        assert row_node.type_name == "row"
+        assert row_node.field_names == ["data", "name"]
+        assert row_node.children[0].type_name == "row"
+        assert row_node.children[0].field_names == ["x", "y"]
+        assert row_node.children[1].type_name == "varchar"
+
+    def test_parameterized_type(self):
+        parser = TypeSignatureParser()
+        node = parser.parse("decimal(10, 2)")
+        assert node.type_name == "decimal"
+
+    def test_varchar_with_length(self):
+        parser = TypeSignatureParser()
+        node = parser.parse("varchar(255)")
+        assert node.type_name == "varchar"
+
+    def test_type_alias_int(self):
+        parser = TypeSignatureParser()
+        node = parser.parse("int")
+        assert node.type_name == "integer"
+
+    def test_type_alias_in_complex_type(self):
+        parser = TypeSignatureParser()
+        node = parser.parse("array(int)")
+        assert node.type_name == "array"
+        assert node.children[0].type_name == "integer"
+
+    def test_hive_syntax_simple(self):
+        parser = TypeSignatureParser()
+        node = parser.parse(_normalize_hive_syntax("array<int>"))
+        assert node.type_name == "array"
+        assert node.children[0].type_name == "integer"
+
+    def test_hive_syntax_struct(self):
+        parser = TypeSignatureParser()
+        node = parser.parse(_normalize_hive_syntax("struct<a:int, b:varchar>"))
+        assert node.type_name == "struct"
+        assert node.field_names == ["a", "b"]
+        assert node.children[0].type_name == "integer"
+        assert node.children[1].type_name == "varchar"
+
+    def test_hive_syntax_nested(self):
+        parser = TypeSignatureParser()
+        node = parser.parse(_normalize_hive_syntax("array<struct<a:int, b:varchar>>"))
+        assert node.type_name == "array"
+        struct_node = node.children[0]
+        assert struct_node.type_name == "struct"
+        assert struct_node.field_names == ["a", "b"]
+        assert struct_node.children[0].type_name == "integer"
+        assert struct_node.children[1].type_name == "varchar"
+
+    def test_hive_syntax_map(self):
+        parser = TypeSignatureParser()
+        node = parser.parse(_normalize_hive_syntax("map<string, int>"))
+        assert node.type_name == "map"
+        assert node.children[0].type_name == "string"
+        assert node.children[1].type_name == "integer"
+
+    def test_mixed_syntax(self):
+        """Hive angle brackets wrapping Trino-style parenthesized inner type."""
+        parser = TypeSignatureParser()
+        node = parser.parse(_normalize_hive_syntax("array<row(a integer, b varchar)>"))
+        assert node.type_name == "array"
+        row_node = node.children[0]
+        assert row_node.type_name == "row"
+        assert row_node.field_names == ["a", "b"]
+        assert row_node.children[0].type_name == "integer"
+        assert row_node.children[1].type_name == "varchar"
+
+    def test_normalize_hive_syntax_noop(self):
+        """Trino-style input passes through unchanged."""
+        assert _normalize_hive_syntax("array(integer)") == "array(integer)"
+
+    def test_normalize_hive_syntax_replaces(self):
+        assert _normalize_hive_syntax("array<struct<a:int>>") == "array(struct(a int))"
+
+    def test_trailing_modifier_after_paren(self):
+        """Type with content after closing paren should not break parsing."""
+        parser = TypeSignatureParser()
+        # Simulates a hypothetical "timestamp(3) with time zone" style input
+        node = parser.parse("decimal(10, 2) extra")
+        assert node.type_name == "decimal"
+
+
+class TestTypedValueConverter:
+    @pytest.fixture
+    def converter(self):
+        return TypedValueConverter(
+            converters=_DEFAULT_CONVERTERS,
+            default_converter=_to_default,
+            struct_parser=_to_struct,
+        )
+
+    def test_simple_varchar(self, converter):
+        node = TypeNode("varchar")
+        assert converter.convert("hello", node) == "hello"
+
+    def test_simple_integer(self, converter):
+        node = TypeNode("integer")
+        assert converter.convert("42", node) == 42
+
+    def test_array_of_varchar(self, converter):
+        parser = TypeSignatureParser()
+        node = parser.parse("array(varchar)")
+        assert converter.convert("[1234, 5678]", node) == ["1234", "5678"]
+
+    def test_array_of_integer(self, converter):
+        parser = TypeSignatureParser()
+        node = parser.parse("array(integer)")
+        assert converter.convert("[1, 2, 3]", node) == [1, 2, 3]
+
+    def 
test_map_varchar_integer(self, converter): + parser = TypeSignatureParser() + node = parser.parse("map(varchar, integer)") + assert converter.convert('{"a": 1, "b": 2}', node) == {"a": 1, "b": 2} + + def test_row_named_fields(self, converter): + parser = TypeSignatureParser() + node = parser.parse("row(name varchar, age integer)") + assert converter.convert("{name=Alice, age=25}", node) == {"name": "Alice", "age": 25} + + def test_nested_row(self, converter): + parser = TypeSignatureParser() + node = parser.parse("row(header row(seq integer, stamp varchar), x double)") + result = converter.convert("{header={seq=123, stamp=2024}, x=4.5}", node) + assert result == {"header": {"seq": 123, "stamp": "2024"}, "x": 4.5} + + def test_array_of_row_json(self, converter): + """JSON path: array(row(...)) with nested dict elements.""" + parser = TypeSignatureParser() + node = parser.parse("array(row(x integer, y double))") + result = converter.convert('[{"x": 1, "y": 2.5}, {"x": 3, "y": 4.0}]', node) + assert result == [{"x": 1, "y": 2.5}, {"x": 3, "y": 4.0}] + assert isinstance(result[0]["x"], int) + assert isinstance(result[0]["y"], float) + + def test_null_string_preserved_in_json(self, converter): + """JSON path: string "null" in array(varchar) must not become None.""" + parser = TypeSignatureParser() + node = parser.parse("array(varchar)") + result = converter.convert('["null", "x"]', node) + assert result == ["null", "x"] + + def test_map_with_row_value_native(self, converter): + """Native path: map(varchar, row(...)) with nested struct values.""" + parser = TypeSignatureParser() + node = parser.parse("map(varchar, row(x integer, y integer))") + result = converter.convert("{a={x=1, y=2}, b={x=3, y=4}}", node) + assert result == {"a": {"x": 1, "y": 2}, "b": {"x": 3, "y": 4}} + assert isinstance(result["a"]["x"], int) + + def test_nested_row_json(self, converter): + """JSON path: row containing row with nested dict values.""" + parser = TypeSignatureParser() + node = 
parser.parse("row(inner row(a integer, b varchar), val double)") + result = converter.convert('{"inner": {"a": 10, "b": "hello"}, "val": 3.14}', node) + assert result == {"inner": {"a": 10, "b": "hello"}, "val": 3.14} + assert isinstance(result["inner"]["a"], int) + assert isinstance(result["val"], float) + + def test_struct_json_name_based_type_matching(self, converter): + """JSON path: field types are matched by name, not position order.""" + parser = TypeSignatureParser() + node = parser.parse("row(name varchar, age integer)") + # JSON keys in reverse order compared to type definition + result = converter.convert('{"age": 25, "name": "Alice"}', node) + assert result == {"age": 25, "name": "Alice"} + assert isinstance(result["age"], int) + assert isinstance(result["name"], str) + + def test_nested_array_json(self, converter): + """JSON path: nested array like [[1,2],[3]] must be parsed via json.loads.""" + parser = TypeSignatureParser() + node = parser.parse("array(array(integer))") + result = converter.convert("[[1, 2], [3]]", node) + assert result == [[1, 2], [3]] + assert isinstance(result[0], list) + assert isinstance(result[0][0], int) + + def test_map_json_null_value_preserved(self, converter): + """JSON path: map with null values vs "null" string values.""" + parser = TypeSignatureParser() + node = parser.parse("map(varchar, varchar)") + result = converter.convert('{"a": null, "b": "null"}', node) + assert result["a"] is None + assert result["b"] == "null" + + def test_unnamed_struct_with_nested_value(self, converter): + """Unnamed struct split must respect nested braces.""" + parser = TypeSignatureParser() + node = parser.parse("row(inner row(x integer, y integer), val varchar)") + result = converter.convert("{inner={x=1, y=2}, val=hello}", node) + assert result == {"inner": {"x": 1, "y": 2}, "val": "hello"}