101 changes: 98 additions & 3 deletions pyiceberg/partitioning.py
@@ -16,9 +16,21 @@
# under the License.
from __future__ import annotations

import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import date, datetime
from functools import cached_property, singledispatch
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar
from typing import (
Any,
Dict,
Generic,
List,
Optional,
Tuple,
TypeVar,
)
from urllib.parse import quote

from pydantic import (
BeforeValidator,
@@ -41,8 +53,18 @@
YearTransform,
parse_transform,
)
from pyiceberg.typedef import IcebergBaseModel
from pyiceberg.types import NestedField, StructType
from pyiceberg.typedef import IcebergBaseModel, Record
from pyiceberg.types import (
DateType,
IcebergType,
NestedField,
PrimitiveType,
StructType,
TimestampType,
TimestamptzType,
UUIDType,
)
from pyiceberg.utils.datetime import date_to_days, datetime_to_micros

INITIAL_PARTITION_SPEC_ID = 0
PARTITION_FIELD_ID_START: int = 1000
@@ -199,6 +221,23 @@ def partition_type(self, schema: Schema) -> StructType:
nested_fields.append(NestedField(field.field_id, field.name, result_type, required=False))
return StructType(*nested_fields)

def partition_to_path(self, data: Record, schema: Schema) -> str:
partition_type = self.partition_type(schema)
field_types = partition_type.fields

field_strs = []
value_strs = []
for pos, value in enumerate(data.record_fields()):
partition_field = self.fields[pos]
value_str = partition_field.transform.to_human_string(field_types[pos].field_type, value=value)

value_str = quote(value_str, safe='')
value_strs.append(value_str)
field_strs.append(partition_field.name)

path = "/".join([field_str + "=" + value_str for field_str, value_str in zip(field_strs, value_strs)])
return path


UNPARTITIONED_PARTITION_SPEC = PartitionSpec(spec_id=0)

Expand Down Expand Up @@ -326,3 +365,59 @@ def _visit_partition_field(schema: Schema, field: PartitionField, visitor: Parti
return visitor.unknown(field.field_id, source_name, field.source_id, repr(transform))
else:
raise ValueError(f"Unknown transform {transform}")


@dataclass(frozen=True)
class PartitionFieldValue:
field: PartitionField
value: Any


@dataclass(frozen=True)
class PartitionKey:
raw_partition_field_values: List[PartitionFieldValue]
partition_spec: PartitionSpec
schema: Schema

@cached_property
    def partition(self) -> Record:  # partition key in Iceberg's internal representation, with the spec's transforms applied
iceberg_typed_key_values = {}
for raw_partition_field_value in self.raw_partition_field_values:
partition_fields = self.partition_spec.source_id_to_fields_map[raw_partition_field_value.field.source_id]
if len(partition_fields) != 1:
raise ValueError("partition_fields must contain exactly one field.")
partition_field = partition_fields[0]
iceberg_type = self.schema.find_field(name_or_id=raw_partition_field_value.field.source_id).field_type
iceberg_typed_value = _to_partition_representation(iceberg_type, raw_partition_field_value.value)
transformed_value = partition_field.transform.transform(iceberg_type)(iceberg_typed_value)
iceberg_typed_key_values[partition_field.name] = transformed_value
return Record(**iceberg_typed_key_values)

def to_path(self) -> str:
return self.partition_spec.partition_to_path(self.partition, self.schema)


@singledispatch
def _to_partition_representation(type: IcebergType, value: Any) -> Any:
    raise TypeError(f"Unsupported partition field type: {type}")


@_to_partition_representation.register(TimestampType)
@_to_partition_representation.register(TimestamptzType)
def _(type: IcebergType, value: Optional[datetime]) -> Optional[int]:
return datetime_to_micros(value) if value is not None else None


@_to_partition_representation.register(DateType)
def _(type: IcebergType, value: Optional[date]) -> Optional[int]:
return date_to_days(value) if value is not None else None


@_to_partition_representation.register(UUIDType)
def _(type: IcebergType, value: Optional[uuid.UUID]) -> Optional[str]:
return str(value) if value is not None else None


@_to_partition_representation.register(PrimitiveType)
def _(type: IcebergType, value: Optional[Any]) -> Optional[Any]:
return value
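
To make the new pieces concrete, here is a minimal sketch (not part of the diff) of how `PartitionKey` and `to_path()` might be exercised; the schema, spec, field names, and the expected path string are all assumptions for illustration.

```python
# Hypothetical usage sketch: a day-partitioned timestamp column rendered as a
# Hive-style partition path. Names and the expected output are illustrative only.
from datetime import datetime

from pyiceberg.partitioning import (
    PartitionField,
    PartitionFieldValue,
    PartitionKey,
    PartitionSpec,
)
from pyiceberg.schema import Schema
from pyiceberg.transforms import DayTransform
from pyiceberg.types import NestedField, TimestampType

schema = Schema(NestedField(field_id=1, name="ts", field_type=TimestampType(), required=False))
spec = PartitionSpec(
    PartitionField(source_id=1, field_id=1000, transform=DayTransform(), name="ts_day")
)

key = PartitionKey(
    raw_partition_field_values=[
        PartitionFieldValue(field=spec.fields[0], value=datetime(2023, 1, 1, 12, 0, 0))
    ],
    partition_spec=spec,
    schema=schema,
)

# The raw datetime is converted to Iceberg's internal microseconds, transformed
# to a day ordinal, then rendered through the transform's human-readable form.
print(key.partition)  # Record holding the transformed day value
print(key.to_path())  # e.g. "ts_day=2023-01-01"
```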
5 changes: 5 additions & 0 deletions pyiceberg/transforms.py
@@ -655,6 +655,11 @@ def _(value: int, _type: IcebergType) -> str:
return _int_to_human_string(_type, value)


@_human_string.register(bool)
def _(value: bool, _type: IcebergType) -> str:
return str(value).lower()


@singledispatch
def _int_to_human_string(_type: IcebergType, value: int) -> str:
return str(value)
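
The `bool` registration above means boolean partition values render as lowercase `true`/`false` when a transform builds a human-readable partition path. A small check of the intended behaviour, assuming `IdentityTransform` and `BooleanType` from the existing API:

```python
# Illustrative check (not part of the diff): booleans now have their own
# _human_string handler, so they no longer fall through to the int handler.
from pyiceberg.transforms import IdentityTransform
from pyiceberg.types import BooleanType

assert IdentityTransform().to_human_string(BooleanType(), value=True) == "true"
assert IdentityTransform().to_human_string(BooleanType(), value=False) == "false"
```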
51 changes: 50 additions & 1 deletion tests/conftest.py
@@ -46,9 +46,10 @@
import boto3
import pytest
from moto import mock_aws
from pyspark.sql import SparkSession

from pyiceberg import schema
from pyiceberg.catalog import Catalog
from pyiceberg.catalog import Catalog, load_catalog
from pyiceberg.catalog.noop import NoopCatalog
from pyiceberg.expressions import BoundReference
from pyiceberg.io import (
@@ -1925,3 +1926,51 @@ def table_v2(example_table_metadata_v2: Dict[str, Any]) -> Table:
@pytest.fixture
def bound_reference_str() -> BoundReference[str]:
return BoundReference(field=NestedField(1, "field", StringType(), required=False), accessor=Accessor(position=0, inner=None))


@pytest.fixture(scope="session")
def session_catalog() -> Catalog:
return load_catalog(
"local",
**{
"type": "rest",
"uri": "http://localhost:8181",
"s3.endpoint": "http://localhost:9000",
"s3.access-key-id": "admin",
"s3.secret-access-key": "password",
},
)


@pytest.fixture(scope="session")
def spark() -> SparkSession:
import importlib.metadata
import os

spark_version = ".".join(importlib.metadata.version("pyspark").split(".")[:2])
scala_version = "2.12"
iceberg_version = "1.4.3"

os.environ["PYSPARK_SUBMIT_ARGS"] = (
f"--packages org.apache.iceberg:iceberg-spark-runtime-{spark_version}_{scala_version}:{iceberg_version},"
f"org.apache.iceberg:iceberg-aws-bundle:{iceberg_version} pyspark-shell"
)
os.environ["AWS_REGION"] = "us-east-1"
os.environ["AWS_ACCESS_KEY_ID"] = "admin"
os.environ["AWS_SECRET_ACCESS_KEY"] = "password"

spark = (
SparkSession.builder.appName("PyIceberg integration test")
.config("spark.sql.extensions", "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions")
.config("spark.sql.catalog.integration", "org.apache.iceberg.spark.SparkCatalog")
.config("spark.sql.catalog.integration.catalog-impl", "org.apache.iceberg.rest.RESTCatalog")
.config("spark.sql.catalog.integration.uri", "http://localhost:8181")
.config("spark.sql.catalog.integration.io-impl", "org.apache.iceberg.aws.s3.S3FileIO")
.config("spark.sql.catalog.integration.warehouse", "s3://warehouse/wh/")
.config("spark.sql.catalog.integration.s3.endpoint", "http://localhost:9000")
.config("spark.sql.catalog.integration.s3.path-style-access", "true")
.config("spark.sql.defaultCatalog", "integration")
.getOrCreate()
)

return spark
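
For context, a sketch of how the two session-scoped fixtures might be combined in an integration test; the namespace and table name below are hypothetical and assumed to be created elsewhere in the suite.

```python
# Hypothetical integration test (not part of the diff): pyiceberg talks to the
# REST catalog directly while Spark reads through the "integration" catalog.
import pytest
from pyspark.sql import SparkSession

from pyiceberg.catalog import Catalog


@pytest.mark.integration
def test_table_visible_from_both_sides(session_catalog: Catalog, spark: SparkSession) -> None:
    # Assumes "default.test_table" was created earlier in the test session.
    table = session_catalog.load_table("default.test_table")
    assert table.schema() is not None

    # The Spark session points at the same REST catalog, so the table should be
    # queryable through Spark SQL as well.
    assert spark.table("default.test_table").count() >= 0
```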