2222from abc import ABC , abstractmethod
2323from copy import copy
2424from dataclasses import dataclass
25- from datetime import date , datetime
25+ from datetime import datetime
2626from enum import Enum
2727from functools import cached_property , singledispatch
2828from itertools import chain
6767 write_manifest ,
6868 write_manifest_list ,
6969)
70- from pyiceberg .partitioning import PartitionSpec
70+ from pyiceberg .partitioning import PartitionFieldValue , PartitionKey , PartitionSpec
7171from pyiceberg .schema import (
7272 PartnerAccessor ,
7373 Schema ,
107107 Identifier ,
108108 KeyDefaultDict ,
109109 Properties ,
110- Record ,
111110)
112111from pyiceberg .types import (
113112 IcebergType ,
118117 StructType ,
119118)
120119from pyiceberg .utils .concurrent import ExecutorFactory
121- from pyiceberg .utils .datetime import date_to_days , datetime_to_micros , datetime_to_millis
120+ from pyiceberg .utils .datetime import datetime_to_millis
122121
123122if TYPE_CHECKING :
124123 import pandas as pd
@@ -2257,7 +2256,7 @@ class WriteTask:
def generate_data_file_partition_path(self) -> str:
    """Return the partition-path segment for this task's data file.

    Delegates to ``PartitionKey.to_path()``; the key now carries its own
    spec and schema, so no schema argument is passed (the older signature
    ``to_path(schema)`` was dropped in this change).

    Raises:
        ValueError: If this WriteTask has no partition key, i.e. the write
            targets an unpartitioned table.
    """
    if self.partition_key is None:
        raise ValueError("Cannot generate partition path based on non-partitioned WriteTask")
    return self.partition_key.to_path()
22612260
22622261 def generate_data_file_filename (self , extension : str ) -> str :
22632262 # Mimics the behavior in the Java API:
@@ -2467,41 +2466,6 @@ class TablePartition:
24672466 arrow_table_partition : pa .Table
24682467
24692468
2470- @dataclass (frozen = True )
2471- class PartitionKey :
2472- raw_partition_key : Record # partition key in raw python type
2473- partition_spec : PartitionSpec
2474-
2475- # this only supports identity transform now
2476- @property
2477- def partition (self ) -> Record : # partition key in iceberg type
2478- iceberg_typed_key_values = {
2479- field_name : iceberg_typed_value (getattr (self .raw_partition_key , field_name , None ))
2480- for field_name in self .raw_partition_key ._position_to_field_name
2481- }
2482-
2483- return Record (** iceberg_typed_key_values )
2484-
2485- def to_path (self , schema : Schema ) -> str :
2486- return self .partition_spec .partition_to_path (self .partition , schema )
2487-
2488-
2489- @singledispatch
2490- def iceberg_typed_value (value : Any ) -> Any :
2491- return value
2492-
2493-
2494- @iceberg_typed_value .register (datetime )
2495- def _ (value : Any ) -> int :
2496- val = datetime_to_micros (value )
2497- return val
2498-
2499-
2500- @iceberg_typed_value .register (date )
2501- def _ (value : Any ) -> int :
2502- return date_to_days (value )
2503-
2504-
25052469def _get_partition_sort_order (partition_columns : list [str ], reverse : bool = False ) -> dict [str , Any ]:
25062470 order = 'ascending' if not reverse else 'descending'
25072471 null_placement = 'at_start' if reverse else 'at_end'
@@ -2538,15 +2502,35 @@ def _get_partition_columns(iceberg_table: Table, arrow_table: pa.Table) -> list[
25382502 return partition_cols
25392503
25402504
2541- def _get_partition_key (
2542- arrow_table : pa .Table , partition_columns : list [str ], offset : int , partition_spec : PartitionSpec
2543- ) -> PartitionKey :
2544- # todo: Instead of fetching partition keys one at a time, try filtering by a mask made of offsets, and convert to py together,
2545- # possibly slightly more efficient.
2546- return PartitionKey (
2547- raw_partition_key = Record (** {col : arrow_table .column (col )[offset ].as_py () for col in partition_columns }),
2548- partition_spec = partition_spec ,
2549- )
def _get_table_partitions(
    arrow_table: pa.Table,
    partition_spec: PartitionSpec,
    schema: Schema,
    slice_instructions: list[dict[str, Any]],
) -> list[TablePartition]:
    """Split an Arrow table into one TablePartition per slice instruction.

    Args:
        arrow_table: The full Arrow table, already sorted so that each slice
            instruction (``{"offset": ..., "length": ...}``) covers one
            partition group of contiguous rows.
        partition_spec: The table's partition spec; its fields determine which
            source columns contribute partition values.
        schema: The Iceberg schema, used to resolve partition-field source ids
            to column names.
        slice_instructions: Keyword dicts accepted by ``pa.Table.slice``.

    Returns:
        A list of TablePartition, one per slice, each carrying the
        PartitionKey built from that slice's first row.
    """
    sorted_slice_instructions = sorted(slice_instructions, key=lambda x: x['offset'])

    partition_fields = partition_spec.fields

    # Fetch each partition column's value at every slice's first row in one
    # vectorized `take`, instead of one .as_py() call per slice per column.
    offsets = [inst["offset"] for inst in sorted_slice_instructions]
    projected_and_filtered = {
        partition_field.source_id: arrow_table[schema.find_field(name_or_id=partition_field.source_id).name]
        .take(offsets)
        .to_pylist()
        for partition_field in partition_fields
    }

    table_partitions = []
    # BUGFIX: `.take(offsets)` yields lists indexed by POSITION (0..len-1),
    # not by absolute row offset. Indexing them with `inst["offset"]` raised
    # IndexError (or picked the wrong value) whenever offsets != [0, 1, 2, ...].
    # Position i in the taken lists corresponds to sorted_slice_instructions[i].
    for i, inst in enumerate(sorted_slice_instructions):
        partition_slice = arrow_table.slice(**inst)
        fieldvalues = [
            PartitionFieldValue(partition_field.source_id, projected_and_filtered[partition_field.source_id][i])
            for partition_field in partition_fields
        ]
        partition_key = PartitionKey(raw_partition_field_values=fieldvalues, partition_spec=partition_spec, schema=schema)
        table_partitions.append(TablePartition(partition_key=partition_key, arrow_table_partition=partition_slice))

    return table_partitions
25502534
25512535
25522536def _partition (iceberg_table : Table , arrow_table : pa .Table ) -> Iterable [TablePartition ]:
@@ -2584,7 +2568,7 @@ def _partition(iceberg_table: Table, arrow_table: pa.Table) -> Iterable[TablePar
25842568 reversing_sort_order_options = _get_partition_sort_order (partition_columns , reverse = True )
25852569 reversed_indices = pa .compute .sort_indices (arrow_table , ** reversing_sort_order_options ).to_pylist ()
25862570
2587- slice_instructions = []
2571+ slice_instructions : list [ dict [ str , Any ]] = []
25882572 last = len (reversed_indices )
25892573 reversed_indices_size = len (reversed_indices )
25902574 ptr = 0
@@ -2595,13 +2579,10 @@ def _partition(iceberg_table: Table, arrow_table: pa.Table) -> Iterable[TablePar
25952579 last = reversed_indices [ptr ]
25962580 ptr = ptr + group_size
25972581
2598- table_partitions : list [TablePartition ] = [
2599- TablePartition (
2600- partition_key = _get_partition_key (arrow_table , partition_columns , inst ["offset" ], iceberg_table .spec ()),
2601- arrow_table_partition = arrow_table .slice (** inst ),
2602- )
2603- for inst in slice_instructions
2604- ]
2582+ table_partitions : list [TablePartition ] = _get_table_partitions (
2583+ arrow_table , iceberg_table .spec (), iceberg_table .schema (), slice_instructions
2584+ )
2585+
26052586 return table_partitions
26062587
26072588
0 commit comments