diff --git a/src/amp/loaders/implementations/deltalake_loader.py b/src/amp/loaders/implementations/deltalake_loader.py
index f18696c..7dfbfc9 100644
--- a/src/amp/loaders/implementations/deltalake_loader.py
+++ b/src/amp/loaders/implementations/deltalake_loader.py
@@ -203,11 +203,7 @@ def _load_batch_impl(self, batch: pa.RecordBatch, table_name: str, **kwargs) ->
             self._refresh_table_reference()

             # Post-write optimizations
-            optimization_results = self._perform_post_write_optimizations(table.num_rows)
-
-            # Store optimization results in base class metadata
-            if hasattr(self, '_last_batch_metadata'):
-                self._last_batch_metadata = optimization_results
+            _optimization_results = self._perform_post_write_optimizations()

             return batch.num_rows

@@ -283,7 +279,7 @@ def _refresh_table_reference(self) -> None:
             self.logger.error(f'Failed to refresh table reference: {e}')
             # Don't set _table_exists = False here as the table might still exist

-    def _perform_post_write_optimizations(self, rows_written: int) -> Dict[str, Any]:
+    def _perform_post_write_optimizations(self) -> Dict[str, Any]:
         """Perform post-write optimizations with robust API handling"""
         optimization_results = {}

@@ -455,10 +451,6 @@ def _get_loader_batch_metadata(self, batch: pa.RecordBatch, duration: float, **k
         if self._table_exists and self._delta_table is not None:
             metadata['table_version'] = self._delta_table.version()

-        # Add optimization results if available
-        if hasattr(self, '_last_batch_metadata'):
-            metadata['optimization_results'] = self._last_batch_metadata
-
         return metadata

     def _get_loader_table_metadata(
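
For context on what the now parameter-free helper usually covers, here is a minimal sketch of post-write optimization against a Delta table. It assumes the `deltalake` package's `DeltaTable` API; the function name, file-count threshold, and return shape are illustrative and are not the repository's actual `_perform_post_write_optimizations`.

# Illustrative sketch, not the repository's implementation: compact small files and
# list vacuum candidates after a write, swallowing failures so the write itself still succeeds.
from typing import Any, Dict

from deltalake import DeltaTable


def post_write_optimizations(table_path: str, compact_min_files: int = 50) -> Dict[str, Any]:
    results: Dict[str, Any] = {}
    try:
        dt = DeltaTable(table_path)
        if len(dt.files()) >= compact_min_files:  # compact only once fragmentation builds up
            results['compaction'] = dt.optimize.compact()
        # dry_run=True reports reclaimable files without deleting anything
        results['vacuum_candidates'] = dt.vacuum(retention_hours=168, dry_run=True)
    except Exception as exc:  # optimizations are best-effort
        results['error'] = str(exc)
    return results
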
diff --git a/src/amp/loaders/implementations/iceberg_loader.py b/src/amp/loaders/implementations/iceberg_loader.py
index 18e2b45..afb80b5 100644
--- a/src/amp/loaders/implementations/iceberg_loader.py
+++ b/src/amp/loaders/implementations/iceberg_loader.py
@@ -88,6 +88,7 @@ def __init__(self, config: Dict[str, Any]):
         self._current_table: Optional[IcebergTable] = None
         self._namespace_exists: bool = False
         self.enable_statistics: bool = config.get('enable_statistics', True)
+        self._table_cache: Dict[str, IcebergTable] = {}  # Cache tables by identifier

     def _get_required_config_fields(self) -> list[str]:
         """Return required configuration fields"""

@@ -119,6 +120,7 @@ def disconnect(self) -> None:
         if self._catalog:
             self._catalog = None

+        self._table_cache.clear()  # Clear table cache on disconnect
         self._is_connected = False
         self.logger.info('Iceberg loader disconnected')

@@ -130,9 +132,16 @@ def _load_batch_impl(self, batch: pa.RecordBatch, table_name: str, **kwargs) ->
         # Fix timestamps for Iceberg compatibility
         table = self._fix_timestamps(table)

-        # Get or create the Iceberg table
+        # Get the Iceberg table (already created by _create_table_from_schema if needed)
         mode = kwargs.get('mode', LoadMode.APPEND)
-        iceberg_table = self._get_or_create_table(table_name, table.schema)
+        table_identifier = f'{self.config.namespace}.{table_name}'
+
+        # Use cached table if available
+        if table_identifier in self._table_cache:
+            iceberg_table = self._table_cache[table_identifier]
+        else:
+            iceberg_table = self._catalog.load_table(table_identifier)
+            self._table_cache[table_identifier] = iceberg_table

         # Validate schema compatibility (unless overwriting)
         if mode != LoadMode.OVERWRITE:

@@ -143,15 +152,28 @@ def _load_batch_impl(self, batch: pa.RecordBatch, table_name: str, **kwargs) ->

         return rows_written

-    def _create_table_from_schema(self, schema: pa.Schema, table_name: str) -> None:
-        """Create table from Arrow schema"""
-        # Iceberg handles table creation in _get_or_create_table
-        self.logger.info(f"Iceberg will create table '{table_name}' on first write with appropriate schema")
-
     def _clear_table(self, table_name: str) -> None:
         """Clear table for overwrite mode"""
         # Iceberg handles overwrites internally
-        self.logger.info(f"Iceberg will handle overwrite for table '{table_name}'")
+        # Clear from cache to ensure fresh state after overwrite
+        table_identifier = f'{self.config.namespace}.{table_name}'
+        if table_identifier in self._table_cache:
+            del self._table_cache[table_identifier]
+
+    def _fix_schema_timestamps(self, schema: pa.Schema) -> pa.Schema:
+        """Convert nanosecond timestamps to microseconds in schema for Iceberg compatibility"""
+        # Check if conversion is needed
+        if not any(pa.types.is_timestamp(f.type) and f.type.unit == 'ns' for f in schema):
+            return schema
+
+        new_fields = []
+        for field in schema:
+            if pa.types.is_timestamp(field.type) and field.type.unit == 'ns':
+                new_fields.append(pa.field(field.name, pa.timestamp('us', tz=field.type.tz)))
+            else:
+                new_fields.append(field)
+
+        return pa.schema(new_fields)

     def _fix_timestamps(self, arrow_table: pa.Table) -> pa.Table:
         """Convert nanosecond timestamps to microseconds for Iceberg compatibility"""

@@ -219,33 +241,36 @@ def _check_namespace_exists(self, namespace: str) -> None:
         except Exception as e:
             raise NoSuchNamespaceError(f"Failed to verify namespace '{namespace}': {str(e)}") from e

-    def _get_or_create_table(self, table_name: str, schema: pa.Schema) -> IcebergTable:
-        """Get existing table or create new one"""
-        table_identifier = f'{self.config.namespace}.{table_name}'
+    def _create_table_from_schema(self, schema: pa.Schema, table_name: str) -> None:
+        """Create table if it doesn't exist - called once by base class before first batch"""
+        if not self.config.create_table:
+            # If create_table is False, just verify table exists
+            table_identifier = f'{self.config.namespace}.{table_name}'
+            try:
+                table = self._catalog.load_table(table_identifier)
+                # Cache the existing table
+                self._table_cache[table_identifier] = table
+                self.logger.debug(f'Table already exists: {table_identifier}')
+            except (NoSuchTableError, NoSuchIcebergTableError) as e:
+                raise NoSuchTableError(f"Table '{table_identifier}' not found and create_table=False") from e
+            return

-        try:
-            table = self._catalog.load_table(table_identifier)
-            self.logger.debug(f'Loaded existing table: {table_identifier}')
-            return table
+        table_identifier = f'{self.config.namespace}.{table_name}'

-        except (NoSuchTableError, NoSuchIcebergTableError) as e:
-            if not self.config.create_table:
-                raise NoSuchTableError(f"Table '{table_identifier}' not found and create_table=False") from e
+        # Fix timestamps in schema before creating table
+        fixed_schema = self._fix_schema_timestamps(schema)

-            try:
-                # Use partition_spec if provided
-                if self.config.partition_spec:
-                    table = self._catalog.create_table(
-                        identifier=table_identifier, schema=schema, partition_spec=self.config.partition_spec
-                    )
-                else:
-                    # Create table without partitioning
-                    table = self._catalog.create_table(identifier=table_identifier, schema=schema)
-                self.logger.info(f'Created new table: {table_identifier}')
-                return table
+        # Use create_table_if_not_exists for simpler logic
+        if self.config.partition_spec:
+            table = self._catalog.create_table_if_not_exists(
+                identifier=table_identifier, schema=fixed_schema, partition_spec=self.config.partition_spec
+            )
+        else:
+            table = self._catalog.create_table_if_not_exists(identifier=table_identifier, schema=fixed_schema)

-        except Exception as e:
-            raise RuntimeError(f"Failed to create table '{table_identifier}': {str(e)}") from e
+        # Cache the newly created/loaded table
+        self._table_cache[table_identifier] = table
+        self.logger.info(f"Table '{table_identifier}' ready (created if needed)")

     def _validate_schema_compatibility(self, iceberg_table: IcebergTable, arrow_schema: pa.Schema) -> None:
         """Validate that Arrow schema is compatible with Iceberg table schema and perform schema evolution if enabled"""

@@ -383,45 +408,13 @@ def _perform_load_operation(self, iceberg_table: IcebergTable, arrow_table: pa.T

     def _get_loader_batch_metadata(self, batch: pa.RecordBatch, duration: float, **kwargs) -> Dict[str, Any]:
         """Get Iceberg-specific metadata for batch operation"""
-        metadata = {'namespace': self.config.namespace}
-
-        # Add partition columns if available
-        table_name = kwargs.get('table_name')
-        if table_name and self._table_exists(table_name):
-            try:
-                table_info = self.get_table_info(table_name)
-                metadata['partition_columns'] = table_info.get('partition_columns', [])
-            except Exception:
-                metadata['partition_columns'] = []
-        else:
-            # For new tables, get partition fields from partition_spec if available
-            metadata['partition_columns'] = []
-
-        return metadata
+        return {'namespace': self.config.namespace}

     def _get_loader_table_metadata(
         self, table: pa.Table, duration: float, batch_count: int, **kwargs
     ) -> Dict[str, Any]:
         """Get Iceberg-specific metadata for table operation"""
-        metadata = {'namespace': self.config.namespace}
-
-        # Add partition columns if available
-        table_name = kwargs.get('table_name')
-        if table_name and self._table_exists(table_name):
-            try:
-                table_info = self.get_table_info(table_name)
-                metadata['partition_columns'] = table_info.get('partition_columns', [])
-            except Exception:
-                metadata['partition_columns'] = []
-        else:
-            # For new tables, get partition fields from partition_spec if available
-            metadata['partition_columns'] = []
-            if self.config.partition_spec and hasattr(self.config.partition_spec, 'fields'):
-                # partition_spec.fields contains partition field definitions
-                # We'll extract them during table creation
-                metadata['partition_columns'] = []  # Will be populated after table creation
-
-        return metadata
+        return {'namespace': self.config.namespace}

     def _table_exists(self, table_name: str) -> bool:
         """Check if a table exists"""
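
As a usage sketch of the new create-then-cache flow (not code from this repository): it assumes a pyiceberg catalog named 'default' configured via .pyiceberg.yaml, an existing 'analytics' namespace, and a recent pyiceberg release that provides create_table_if_not_exists. The table is created once with a microsecond-precision schema and the handle is cached for later appends.

# Illustrative sketch, not the loader's code: downcast ns -> us, create the table once,
# then reuse the cached handle for every batch instead of re-loading it from the catalog.
from datetime import datetime, timezone

import pyarrow as pa
from pyiceberg.catalog import load_catalog

catalog = load_catalog('default')  # assumed: catalog settings live in .pyiceberg.yaml
identifier = 'analytics.events'    # assumed: the 'analytics' namespace already exists

source_schema = pa.schema([
    pa.field('id', pa.int64()),
    pa.field('ts', pa.timestamp('ns', tz='UTC')),  # Iceberg stores microsecond timestamps
])

# Mirrors _fix_schema_timestamps: replace ns timestamp fields with us before creation
fixed_schema = pa.schema([
    pa.field(f.name, pa.timestamp('us', tz=f.type.tz))
    if pa.types.is_timestamp(f.type) and f.type.unit == 'ns'
    else f
    for f in source_schema
])

table = catalog.create_table_if_not_exists(identifier, schema=fixed_schema)
table_cache = {identifier: table}  # later batches reuse this handle

batch = pa.Table.from_pylist([{'id': 1, 'ts': datetime.now(timezone.utc)}], schema=fixed_schema)
table_cache[identifier].append(batch)
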
diff --git a/src/amp/loaders/utils.py b/src/amp/loaders/utils.py
deleted file mode 100644
index 5aa9332..0000000
--- a/src/amp/loaders/utils.py
+++ /dev/null
@@ -1,364 +0,0 @@
-"""
-Common utilities for data loaders to reduce code duplication.
-"""
-
-from typing import Any, Dict, Set
-
-import pyarrow as pa
-
-
-class ArrowTypeConverter:
-    """
-    Centralized Arrow type conversion utilities used across multiple loaders.
-    """
-
-    @staticmethod
-    def get_postgresql_type_mapping() -> Dict[pa.DataType, str]:
-        """Get Arrow to PostgreSQL type mapping"""
-        return {
-            # Integer types
-            pa.int8(): 'SMALLINT',
-            pa.int16(): 'SMALLINT',
-            pa.int32(): 'INTEGER',
-            pa.int64(): 'BIGINT',
-            pa.uint8(): 'SMALLINT',
-            pa.uint16(): 'INTEGER',
-            pa.uint32(): 'BIGINT',
-            pa.uint64(): 'BIGINT',
-            # Floating point types
-            pa.float32(): 'REAL',
-            pa.float64(): 'DOUBLE PRECISION',
-            pa.float16(): 'REAL',
-            # String types - use TEXT for blockchain data which can be large
-            pa.string(): 'TEXT',
-            pa.large_string(): 'TEXT',
-            pa.utf8(): 'TEXT',
-            # Binary types - use BYTEA for efficient storage
-            pa.binary(): 'BYTEA',
-            pa.large_binary(): 'BYTEA',
-            # Boolean type
-            pa.bool_(): 'BOOLEAN',
-            # Date and time types
-            pa.date32(): 'DATE',
-            pa.date64(): 'DATE',
-            pa.time32('s'): 'TIME',
-            pa.time32('ms'): 'TIME',
-            pa.time64('us'): 'TIME',
-            pa.time64('ns'): 'TIME',
-        }
-
-    @staticmethod
-    def get_snowflake_type_mapping() -> Dict[pa.DataType, str]:
-        """Get Arrow to Snowflake type mapping"""
-        return {
-            # Integer types
-            pa.int8(): 'TINYINT',
-            pa.int16(): 'SMALLINT',
-            pa.int32(): 'INTEGER',
-            pa.int64(): 'BIGINT',
-            pa.uint8(): 'SMALLINT',
-            pa.uint16(): 'INTEGER',
-            pa.uint32(): 'BIGINT',
-            pa.uint64(): 'BIGINT',
-            # Floating point types
-            pa.float32(): 'FLOAT',
-            pa.float64(): 'DOUBLE',
-            pa.float16(): 'FLOAT',
-            # String types
-            pa.string(): 'VARCHAR',
-            pa.large_string(): 'VARCHAR',
-            pa.utf8(): 'VARCHAR',
-            # Binary types
-            pa.binary(): 'BINARY',
-            pa.large_binary(): 'BINARY',
-            # Boolean type
-            pa.bool_(): 'BOOLEAN',
-            # Date and time types
-            pa.date32(): 'DATE',
-            pa.date64(): 'DATE',
-            pa.time32('s'): 'TIME',
-            pa.time32('ms'): 'TIME',
-            pa.time64('us'): 'TIME',
-            pa.time64('ns'): 'TIME',
-        }
-
-    @staticmethod
-    def convert_arrow_to_iceberg_type(arrow_type: pa.DataType) -> Any:
-        """
-        Convert Arrow data type to equivalent Iceberg type.
-        Extracted from IcebergLoader implementation.
-        """
-        # Import here to avoid circular dependencies
-        try:
-            from pyiceberg.types import (
-                BinaryType,
-                BooleanType,
-                DateType,
-                DecimalType,
-                DoubleType,
-                FloatType,
-                IntegerType,
-                ListType,
-                LongType,
-                NestedField,
-                StringType,
-                StructType,
-                TimestampType,
-            )
-        except ImportError as e:
-            raise ImportError("Iceberg type conversion requires 'pyiceberg' package") from e
-
-        if pa.types.is_string(arrow_type) or pa.types.is_large_string(arrow_type):
-            return StringType()
-        elif pa.types.is_int32(arrow_type):
-            return IntegerType()
-        elif pa.types.is_int64(arrow_type):
-            return LongType()
-        elif pa.types.is_float32(arrow_type):
-            return FloatType()
-        elif pa.types.is_float64(arrow_type):
-            return DoubleType()
-        elif pa.types.is_boolean(arrow_type):
-            return BooleanType()
-        elif pa.types.is_decimal(arrow_type):
-            return DecimalType(arrow_type.precision, arrow_type.scale)
-        elif pa.types.is_date32(arrow_type) or pa.types.is_date64(arrow_type):
-            return DateType()
-        elif pa.types.is_timestamp(arrow_type):
-            return TimestampType()
-        elif pa.types.is_binary(arrow_type) or pa.types.is_large_binary(arrow_type):
-            return BinaryType()
-        elif pa.types.is_list(arrow_type):
-            element_type = ArrowTypeConverter.convert_arrow_to_iceberg_type(arrow_type.value_type)
-            return ListType(1, element_type, element_required=False)
-        elif pa.types.is_struct(arrow_type):
-            nested_fields = []
-            for i, field in enumerate(arrow_type):
-                field_type = ArrowTypeConverter.convert_arrow_to_iceberg_type(field.type)
-                nested_fields.append(NestedField(i + 1, field.name, field_type, required=not field.nullable))
-            return StructType(nested_fields)
-        else:
-            # Fallback to string for unsupported types
-            return StringType()
-
-    @staticmethod
-    def convert_arrow_field_to_sql(field: pa.Field, target_system: str) -> str:
-        """
-        Convert an Arrow field to SQL column definition.
-
-        Args:
-            field: Arrow field to convert
-            target_system: Target system ('postgresql', 'snowflake', etc.)
-
-        Returns:
-            SQL column definition string
-        """
-        type_mappings = {
-            'postgresql': ArrowTypeConverter.get_postgresql_type_mapping(),
-            'snowflake': ArrowTypeConverter.get_snowflake_type_mapping(),
-        }
-
-        if target_system not in type_mappings:
-            raise ValueError(f'Unsupported target system: {target_system}')
-
-        type_mapping = type_mappings[target_system]
-
-        # Handle complex types
-        if pa.types.is_timestamp(field.type):
-            if target_system == 'postgresql':
-                if field.type.tz is not None:
-                    sql_type = 'TIMESTAMPTZ'
-                else:
-                    sql_type = 'TIMESTAMP'
-            else:  # snowflake
-                if field.type.tz is not None:
-                    sql_type = 'TIMESTAMP_TZ'
-                else:
-                    sql_type = 'TIMESTAMP_NTZ'
-        elif pa.types.is_date(field.type):
-            sql_type = 'DATE'
-        elif pa.types.is_time(field.type):
-            sql_type = 'TIME'
-        elif pa.types.is_decimal(field.type):
-            decimal_type = field.type
-            if target_system == 'postgresql':
-                sql_type = f'NUMERIC({decimal_type.precision},{decimal_type.scale})'
-            else:  # snowflake
-                sql_type = f'NUMBER({decimal_type.precision},{decimal_type.scale})'
-        elif pa.types.is_list(field.type) or pa.types.is_large_list(field.type):
-            if target_system == 'postgresql':
-                sql_type = 'TEXT'  # JSON-like data
-            else:  # snowflake
-                sql_type = 'VARIANT'
-        elif pa.types.is_struct(field.type):
-            if target_system == 'postgresql':
-                sql_type = 'TEXT'  # JSON-like data
-            else:  # snowflake
-                sql_type = 'OBJECT'
-        elif pa.types.is_binary(field.type) or pa.types.is_large_binary(field.type):
-            if target_system == 'postgresql':
-                sql_type = 'BYTEA'
-            else:  # snowflake
-                sql_type = 'BINARY'
-        elif pa.types.is_fixed_size_binary(field.type):
-            if target_system == 'postgresql':
-                sql_type = 'BYTEA'
-            else:  # snowflake
-                sql_type = f'BINARY({field.type.byte_width})'
-        else:
-            # Use mapping or default
-            if target_system == 'postgresql':
-                sql_type = type_mapping.get(field.type, 'TEXT')
-            else:  # snowflake
-                sql_type = type_mapping.get(field.type, 'VARCHAR')
-
-        # Handle nullability
-        nullable = '' if field.nullable else ' NOT NULL'
-
-        # Quote column name for safety
-        return f'"{field.name}" {sql_type}{nullable}'
-
-
-class LoaderConfigValidator:
-    """
-    Common configuration validation utilities.
-    """
-
-    @staticmethod
-    def validate_required_fields(config: Dict[str, Any], required_fields: Set[str]) -> None:
-        """
-        Validate that all required fields are present in config.
-
-        Args:
-            config: Configuration dictionary
-            required_fields: Set of required field names
-
-        Raises:
-            ValueError: If any required fields are missing
-        """
-        missing_fields = required_fields - set(config.keys())
-        if missing_fields:
-            raise ValueError(f'Missing required configuration fields: {missing_fields}')
-
-    @staticmethod
-    def validate_connection_config(config: Dict[str, Any], loader_type: str) -> None:
-        """
-        Validate connection-specific configuration.
-
-        Args:
-            config: Configuration dictionary
-            loader_type: Type of loader (postgresql, redis, etc.)
-        """
-        connection_requirements = {
-            'postgresql': {'host', 'database', 'user', 'password'},
-            'redis': {'host'},  # port is optional with default
-            'snowflake': {'account', 'user', 'warehouse', 'database'},
-            'lmdb': {'db_path'},
-            'deltalake': {'table_path'},
-            'iceberg': {'catalog_config', 'namespace'},
-        }
-
-        if loader_type in connection_requirements:
-            required_fields = connection_requirements[loader_type]
-            LoaderConfigValidator.validate_required_fields(config, required_fields)
-
-
-class TableNameUtils:
-    """
-    Utilities for table name handling and validation.
-    """
-
-    @staticmethod
-    def sanitize_table_name(table_name: str, target_system: str) -> str:
-        """
-        Sanitize table name for target system.
-
-        Args:
-            table_name: Original table name
-            target_system: Target system name
-
-        Returns:
-            Sanitized table name
-        """
-        # Basic sanitization - can be extended for specific systems
-        sanitized = table_name.replace('-', '_').replace(' ', '_')
-
-        # System-specific rules
-        if target_system == 'postgresql':
-            # PostgreSQL is case-sensitive when quoted, prefer lowercase
-            sanitized = sanitized.lower()
-        elif target_system == 'snowflake':
-            # Snowflake prefers uppercase
-            sanitized = sanitized.upper()
-
-        return sanitized
-
-    @staticmethod
-    def quote_identifier(identifier: str, target_system: str) -> str:
-        """
-        Quote identifier for target system.
-
-        Args:
-            identifier: Identifier to quote
-            target_system: Target system name
-
-        Returns:
-            Quoted identifier
-        """
-        if target_system in ['postgresql', 'snowflake']:
-            return f'"{identifier}"'
-        else:
-            return identifier
-
-
-class CommonPatterns:
-    """
-    Common patterns extracted from loader implementations.
-    """
-
-    @staticmethod
-    def get_connection_info_logger_template() -> str:
-        """Get template for logging connection info"""
-        return 'Connected to {system} {version} at {host}:{port} database: {database}'
-
-    @staticmethod
-    def get_standard_metadata_fields() -> Set[str]:
-        """Get standard metadata fields that should be included in LoadResult"""
-        return {
-            'batch_size',
-            'schema_fields',
-            'throughput_rows_per_sec',
-            'total_rows',
-            'batches_processed',
-            'table_size_mb',
-        }
-
-    @staticmethod
-    def convert_bytes_for_redis(value: Any) -> bytes:
-        """
-        Convert value to bytes for Redis storage.
-        Extracted from RedisLoader pattern.
-        """
-        if isinstance(value, bytes):
-            return value
-        elif isinstance(value, (int, float)):
-            return str(value).encode('utf-8')
-        elif isinstance(value, bool):
-            return b'1' if value else b'0'
-        elif isinstance(value, str):
-            return value.encode('utf-8')
-        else:
-            return str(value).encode('utf-8')
-
-    @staticmethod
-    def has_binary_columns(schema: pa.Schema) -> bool:
-        """
-        Check if schema contains any binary column types.
-        Extracted from PostgreSQL helpers.
-        """
-        return any(
-            pa.types.is_binary(field.type)
-            or pa.types.is_large_binary(field.type)
-            or pa.types.is_fixed_size_binary(field.type)
-            for field in schema
-        )
diff --git a/tests/integration/test_iceberg_loader.py b/tests/integration/test_iceberg_loader.py
index bce0aa6..94786ab 100644
--- a/tests/integration/test_iceberg_loader.py
+++ b/tests/integration/test_iceberg_loader.py
@@ -228,7 +228,6 @@ def test_partitioning(self, iceberg_partitioned_config, small_test_data):
         assert result.success == True

         # Note: Partitioning requires creating PartitionSpec objects now
-        assert result.metadata['partition_columns'] == []
         assert result.metadata['namespace'] == iceberg_partitioned_config['namespace']

     def test_timestamp_conversion(self, iceberg_basic_config):
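
The dropped assertion follows from the slimmed-down loader metadata above. Below is a hedged sketch of the remaining expectation, not part of the actual test suite: the result and fixture names are assumed, and the absence check presumes no other layer injects partition_columns into the metadata.

# Sketch only: after the change, batch metadata reports the namespace, and partition
# details are no longer populated by the loader.
def test_metadata_reports_namespace_only(result, iceberg_partitioned_config):
    assert result.success is True
    assert result.metadata['namespace'] == iceberg_partitioned_config['namespace']
    assert 'partition_columns' not in result.metadata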