12 changes: 2 additions & 10 deletions src/amp/loaders/implementations/deltalake_loader.py
@@ -203,11 +203,7 @@ def _load_batch_impl(self, batch: pa.RecordBatch, table_name: str, **kwargs) ->
         self._refresh_table_reference()
 
         # Post-write optimizations
-        optimization_results = self._perform_post_write_optimizations(table.num_rows)
-
-        # Store optimization results in base class metadata
-        if hasattr(self, '_last_batch_metadata'):
-            self._last_batch_metadata = optimization_results
+        _optimization_results = self._perform_post_write_optimizations()
 
         return batch.num_rows
 
@@ -283,7 +279,7 @@ def _refresh_table_reference(self) -> None:
             self.logger.error(f'Failed to refresh table reference: {e}')
             # Don't set _table_exists = False here as the table might still exist
 
-    def _perform_post_write_optimizations(self, rows_written: int) -> Dict[str, Any]:
+    def _perform_post_write_optimizations(self) -> Dict[str, Any]:
         """Perform post-write optimizations with robust API handling"""
         optimization_results = {}
 
@@ -455,10 +451,6 @@ def _get_loader_batch_metadata(self, batch: pa.RecordBatch, duration: float, **k
         if self._table_exists and self._delta_table is not None:
             metadata['table_version'] = self._delta_table.version()
 
-        # Add optimization results if available
-        if hasattr(self, '_last_batch_metadata'):
-            metadata['optimization_results'] = self._last_batch_metadata
-
         return metadata
 
     def _get_loader_table_metadata(
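Note on the deltalake change above: the rows_written parameter was unused, and dropping the _last_batch_metadata side channel keeps optimization results local to the write path. As a rough sketch of what a parameterless post-write optimization pass can do, here is a standalone function against the deltalake package; DeltaTable.optimize.compact() and DeltaTable.vacuum() are real deltalake APIs, but the function itself and its error handling are illustrative, not this loader's actual body:

from typing import Any, Dict

from deltalake import DeltaTable


def perform_post_write_optimizations(dt: DeltaTable) -> Dict[str, Any]:
    """Best-effort compaction and vacuum after a batch write (illustrative)."""
    results: Dict[str, Any] = {}
    try:
        # Merge the small files that frequent batch writes tend to produce
        results['compact'] = dt.optimize.compact()
        # Remove unreferenced files older than the 7-day retention window
        results['vacuum'] = dt.vacuum(retention_hours=168, dry_run=False)
    except Exception as exc:
        # Optimizations are best-effort; never fail the load because of them
        results['error'] = str(exc)
    return results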
123 changes: 58 additions & 65 deletions src/amp/loaders/implementations/iceberg_loader.py
@@ -88,6 +88,7 @@ def __init__(self, config: Dict[str, Any]):
         self._current_table: Optional[IcebergTable] = None
         self._namespace_exists: bool = False
         self.enable_statistics: bool = config.get('enable_statistics', True)
+        self._table_cache: Dict[str, IcebergTable] = {}  # Cache tables by identifier
 
     def _get_required_config_fields(self) -> list[str]:
         """Return required configuration fields"""
@@ -119,6 +120,7 @@ def disconnect(self) -> None:
         if self._catalog:
             self._catalog = None
 
+        self._table_cache.clear()  # Clear table cache on disconnect
         self._is_connected = False
         self.logger.info('Iceberg loader disconnected')
 
@@ -130,9 +132,16 @@ def _load_batch_impl(self, batch: pa.RecordBatch, table_name: str, **kwargs) ->
         # Fix timestamps for Iceberg compatibility
         table = self._fix_timestamps(table)
 
-        # Get or create the Iceberg table
+        # Get the Iceberg table (already created by _create_table_from_schema if needed)
         mode = kwargs.get('mode', LoadMode.APPEND)
-        iceberg_table = self._get_or_create_table(table_name, table.schema)
+        table_identifier = f'{self.config.namespace}.{table_name}'
+
+        # Use cached table if available
+        if table_identifier in self._table_cache:
+            iceberg_table = self._table_cache[table_identifier]
+        else:
+            iceberg_table = self._catalog.load_table(table_identifier)
+            self._table_cache[table_identifier] = iceberg_table
 
         # Validate schema compatibility (unless overwriting)
         if mode != LoadMode.OVERWRITE:
@@ -143,15 +152,28 @@ def _load_batch_impl(self, batch: pa.RecordBatch, table_name: str, **kwargs) ->
 
         return rows_written
 
-    def _create_table_from_schema(self, schema: pa.Schema, table_name: str) -> None:
-        """Create table from Arrow schema"""
-        # Iceberg handles table creation in _get_or_create_table
-        self.logger.info(f"Iceberg will create table '{table_name}' on first write with appropriate schema")
-
     def _clear_table(self, table_name: str) -> None:
         """Clear table for overwrite mode"""
-        # Iceberg handles overwrites internally
         self.logger.info(f"Iceberg will handle overwrite for table '{table_name}'")
+        # Clear from cache to ensure fresh state after overwrite
+        table_identifier = f'{self.config.namespace}.{table_name}'
+        if table_identifier in self._table_cache:
+            del self._table_cache[table_identifier]
+
+    def _fix_schema_timestamps(self, schema: pa.Schema) -> pa.Schema:
+        """Convert nanosecond timestamps to microseconds in schema for Iceberg compatibility"""
+        # Check if conversion is needed
+        if not any(pa.types.is_timestamp(f.type) and f.type.unit == 'ns' for f in schema):
+            return schema
+
+        new_fields = []
+        for field in schema:
+            if pa.types.is_timestamp(field.type) and field.type.unit == 'ns':
+                new_fields.append(pa.field(field.name, pa.timestamp('us', tz=field.type.tz)))
+            else:
+                new_fields.append(field)
+
+        return pa.schema(new_fields)
 
     def _fix_timestamps(self, arrow_table: pa.Table) -> pa.Table:
         """Convert nanosecond timestamps to microseconds for Iceberg compatibility"""
@@ -219,33 +241,36 @@ def _check_namespace_exists(self, namespace: str) -> None:
         except Exception as e:
             raise NoSuchNamespaceError(f"Failed to verify namespace '{namespace}': {str(e)}") from e
 
-    def _get_or_create_table(self, table_name: str, schema: pa.Schema) -> IcebergTable:
-        """Get existing table or create new one"""
-        table_identifier = f'{self.config.namespace}.{table_name}'
+    def _create_table_from_schema(self, schema: pa.Schema, table_name: str) -> None:
+        """Create table if it doesn't exist - called once by base class before first batch"""
+        if not self.config.create_table:
+            # If create_table is False, just verify table exists
+            table_identifier = f'{self.config.namespace}.{table_name}'
+            try:
+                table = self._catalog.load_table(table_identifier)
+                # Cache the existing table
+                self._table_cache[table_identifier] = table
+                self.logger.debug(f'Table already exists: {table_identifier}')
+            except (NoSuchTableError, NoSuchIcebergTableError) as e:
+                raise NoSuchTableError(f"Table '{table_identifier}' not found and create_table=False") from e
+            return
 
-        try:
-            table = self._catalog.load_table(table_identifier)
-            self.logger.debug(f'Loaded existing table: {table_identifier}')
-            return table
+        table_identifier = f'{self.config.namespace}.{table_name}'
 
-        except (NoSuchTableError, NoSuchIcebergTableError) as e:
-            if not self.config.create_table:
-                raise NoSuchTableError(f"Table '{table_identifier}' not found and create_table=False") from e
+        # Fix timestamps in schema before creating table
+        fixed_schema = self._fix_schema_timestamps(schema)
 
-            try:
-                # Use partition_spec if provided
-                if self.config.partition_spec:
-                    table = self._catalog.create_table(
-                        identifier=table_identifier, schema=schema, partition_spec=self.config.partition_spec
-                    )
-                else:
-                    # Create table without partitioning
-                    table = self._catalog.create_table(identifier=table_identifier, schema=schema)
-                self.logger.info(f'Created new table: {table_identifier}')
-                return table
+        # Use create_table_if_not_exists for simpler logic
+        if self.config.partition_spec:
+            table = self._catalog.create_table_if_not_exists(
+                identifier=table_identifier, schema=fixed_schema, partition_spec=self.config.partition_spec
+            )
+        else:
+            table = self._catalog.create_table_if_not_exists(identifier=table_identifier, schema=fixed_schema)
 
-        except Exception as e:
-            raise RuntimeError(f"Failed to create table '{table_identifier}': {str(e)}") from e
+        # Cache the newly created/loaded table
+        self._table_cache[table_identifier] = table
+        self.logger.info(f"Table '{table_identifier}' ready (created if needed)")
 
     def _validate_schema_compatibility(self, iceberg_table: IcebergTable, arrow_schema: pa.Schema) -> None:
         """Validate that Arrow schema is compatible with Iceberg table schema and perform schema evolution if enabled"""
@@ -383,45 +408,13 @@ def _perform_load_operation(self, iceberg_table: IcebergTable, arrow_table: pa.T
 
     def _get_loader_batch_metadata(self, batch: pa.RecordBatch, duration: float, **kwargs) -> Dict[str, Any]:
         """Get Iceberg-specific metadata for batch operation"""
-        metadata = {'namespace': self.config.namespace}
-
-        # Add partition columns if available
-        table_name = kwargs.get('table_name')
-        if table_name and self._table_exists(table_name):
-            try:
-                table_info = self.get_table_info(table_name)
-                metadata['partition_columns'] = table_info.get('partition_columns', [])
-            except Exception:
-                metadata['partition_columns'] = []
-        else:
-            # For new tables, get partition fields from partition_spec if available
-            metadata['partition_columns'] = []
-
-        return metadata
+        return {'namespace': self.config.namespace}
 
     def _get_loader_table_metadata(
         self, table: pa.Table, duration: float, batch_count: int, **kwargs
     ) -> Dict[str, Any]:
         """Get Iceberg-specific metadata for table operation"""
-        metadata = {'namespace': self.config.namespace}
-
-        # Add partition columns if available
-        table_name = kwargs.get('table_name')
-        if table_name and self._table_exists(table_name):
-            try:
-                table_info = self.get_table_info(table_name)
-                metadata['partition_columns'] = table_info.get('partition_columns', [])
-            except Exception:
-                metadata['partition_columns'] = []
-        else:
-            # For new tables, get partition fields from partition_spec if available
-            metadata['partition_columns'] = []
-            if self.config.partition_spec and hasattr(self.config.partition_spec, 'fields'):
-                # partition_spec.fields contains partition field definitions
-                # We'll extract them during table creation
-                metadata['partition_columns'] = []  # Will be populated after table creation
-
-        return metadata
+        return {'namespace': self.config.namespace}
 
     def _table_exists(self, table_name: str) -> bool:
         """Check if a table exists"""